一篇讓你學會組合問題!
組合
力扣題目鏈接:https://leetcode-cn.com/problems/combinations/
給定兩個整數 n 和 k,返回 1 ... n 中所有可能的 k 個數的組合。
示例: 輸入: n = 4, k = 2 輸出: [ [2,4], [3,4], [2,3], [1,2], [1,3], [1,4], ]
本題這是回溯法的經典題目。
直接的解法當然是使用for循環,例如示例中k為2,很容易想到 用兩個for循環,這樣就可以輸出 和示例中一樣的結果。
代碼如下:
- int n = 4;
- for (int i = 1; i <= n; i++) {
- for (int j = i + 1; j <= n; j++) {
- cout << i << " " << j << endl;
- }
- }
輸入:n = 100, k = 3 那么就三層for循環,代碼如下:
- int n = 100;
- for (int i = 1; i <= n; i++) {
- for (int j = i + 1; j <= n; j++) {
- for (int u = j + 1; u <= n; n++) {
- cout << i << " " << j << " " << u << endl;
- }
- }
- }
如果n為100,k為50呢,那就50層for循環,是不是開始窒息。
此時就會發現雖然想暴力搜索,但是用for循環嵌套連暴力都寫不出來!
咋整?
回溯搜索法來了,雖然回溯法也是暴力,但至少能寫出來,不像for循環嵌套k層讓人絕望。
那么回溯法怎么暴力搜呢?
上面我們說了要解決 n為100,k為50的情況,暴力寫法需要嵌套50層for循環,那么回溯法就用遞歸來解決嵌套層數的問題。
遞歸來做層疊嵌套(可以理解是開k層for循環),每一次的遞歸中嵌套一個for循環,那么遞歸就可以用于解決多層嵌套循環的問題了。
此時遞歸的層數大家應該知道了,例如:n為100,k為50的情況下,就是遞歸50層。
一些同學本來對遞歸就懵,回溯法中遞歸還要嵌套for循環,可能就直接暈倒了!
如果腦洞模擬回溯搜索的過程,絕對可以讓人窒息,所以需要抽象圖形結構來進一步理解。
我們在關于回溯算法,你該了解這些!中說道回溯法解決的問題都可以抽象為樹形結構(N叉樹),用樹形結構來理解回溯就容易多了。
那么我把組合問題抽象為如下樹形結構:
77.組合
可以看出這個棵樹,一開始集合是 1,2,3,4, 從左向右取數,取過的數,不在重復取。
第一次取1,集合變為2,3,4 ,因為k為2,我們只需要再取一個數就可以了,分別取2,3,4,得到集合[1,2] [1,3] [1,4],以此類推。
每次從集合中選取元素,可選擇的范圍隨著選擇的進行而收縮,調整可選擇的范圍。
圖中可以發現n相當于樹的寬度,k相當于樹的深度。
那么如何在這個樹上遍歷,然后收集到我們要的結果集呢?
圖中每次搜索到了葉子節點,我們就找到了一個結果。
相當于只需要把達到葉子節點的結果收集起來,就可以求得 n個數中k個數的組合集合。
在關于回溯算法,你該了解這些!中我們提到了回溯法三部曲,那么我們按照回溯法三部曲開始正式講解代碼了。
回溯法三部曲
遞歸函數的返回值以及參數
在這里要定義兩個全局變量,一個用來存放符合條件單一結果,一個用來存放符合條件結果的集合。
代碼如下:
- vector<vector<int>> result; // 存放符合條件結果的集合
- vector<int> path; // 用來存放符合條件結果
其實不定義這兩個全局遍歷也是可以的,把這兩個變量放進遞歸函數的參數里,但函數里參數太多影響可讀性,所以我定義全局變量了。
函數里一定有兩個參數,既然是集合n里面取k的數,那么n和k是兩個int型的參數。
然后還需要一個參數,為int型變量startIndex,這個參數用來記錄本層遞歸的中,集合從哪里開始遍歷(集合就是[1,...,n] )。
為什么要有這個startIndex呢?
每次從集合中選取元素,可選擇的范圍隨著選擇的進行而收縮,調整可選擇的范圍,就是要靠startIndex。
從下圖中紅線部分可以看出,在集合[1,2,3,4]取1之后,下一層遞歸,就要在[2,3,4]中取數了,那么下一層遞歸如何知道從[2,3,4]中取數呢,靠的就是startIndex。
組合2
所以需要startIndex來記錄下一層遞歸,搜索的起始位置。
那么整體代碼如下:
- vector<vector<int>> result; // 存放符合條件結果的集合
- vector<int> path; // 用來存放符合條件單一結果
- void backtracking(int n, int k, int startIndex)
回溯函數終止條件
什么時候到達所謂的葉子節點了呢?
path這個數組的大小如果達到k,說明我們找到了一個子集大小為k的組合了,在圖中path存的就是根節點到葉子節點的路徑。
如圖紅色部分:
組合3
此時用result二維數組,把path保存起來,并終止本層遞歸。
所以終止條件代碼如下:
- if (path.size() == k) {
- result.push_back(path);
- return;
- }
- 單層搜索的過程
回溯法的搜索過程就是一個樹型結構的遍歷過程,在如下圖中,可以看出for循環用來橫向遍歷,遞歸的過程是縱向遍歷。
組合1
如此我們才遍歷完圖中的這棵樹。
for循環每次從startIndex開始遍歷,然后用path保存取到的節點i。
代碼如下:
- for (int i = startIndex; i <= n; i++) { // 控制樹的橫向遍歷
- path.push_back(i); // 處理節點
- backtracking(n, k, i + 1); // 遞歸:控制樹的縱向遍歷,注意下一層搜索要從i+1開始
- path.pop_back(); // 回溯,撤銷處理的節點
- }
可以看出backtracking(遞歸函數)通過不斷調用自己一直往深處遍歷,總會遇到葉子節點,遇到了葉子節點就要返回。
backtracking的下面部分就是回溯的操作了,撤銷本次處理的結果。
關鍵地方都講完了,組合問題C++完整代碼如下:
- class Solution {
- private:
- vector<vector<int>> result; // 存放符合條件結果的集合
- vector<int> path; // 用來存放符合條件結果
- void backtracking(int n, int k, int startIndex) {
- if (path.size() == k) {
- result.push_back(path);
- return;
- }
- for (int i = startIndex; i <= n; i++) {
- path.push_back(i); // 處理節點
- backtracking(n, k, i + 1); // 遞歸
- path.pop_back(); // 回溯,撤銷處理的節點
- }
- }
- public:
- vector<vector<int>> combine(int n, int k) {
- result.clear(); // 可以不寫
- path.clear(); // 可以不寫
- backtracking(n, k, 1);
- return result;
- }
- };
還記得我們在關于回溯算法,你該了解這些!中給出的回溯法模板么?
如下:
- void backtracking(參數) {
- if (終止條件) {
- 存放結果;
- return;
- }
- for (選擇:本層集合中元素(樹中節點孩子的數量就是集合的大?。? {
- 處理節點;
- backtracking(路徑,選擇列表); // 遞歸
- 回溯,撤銷處理結果
- }
- }
對比一下本題的代碼,是不是發現有點像! 所以有了這個模板,就有解題的大體方向,不至于毫無頭緒。
總結
組合問題是回溯法解決的經典問題,我們開始的時候給大家列舉一個很形象的例子,就是n為100,k為50的話,直接想法就需要50層for循環。
從而引出了回溯法就是解決這種k層for循環嵌套的問題。
然后進一步把回溯法的搜索過程抽象為樹形結構,可以直觀的看出搜索的過程。
接著用回溯法三部曲,逐步分析了函數參數、終止條件和單層搜索的過程。
剪枝優化
我們說過,回溯法雖然是暴力搜索,但也有時候可以有點剪枝優化一下的。
在遍歷的過程中有如下代碼:
- for (int i = startIndex; i <= n; i++) {
- path.push_back(i);
- backtracking(n, k, i + 1);
- path.pop_back();
- }
這個遍歷的范圍是可以剪枝優化的,怎么優化呢?
來舉一個例子,n = 4,k = 4的話,那么第一層for循環的時候,從元素2開始的遍歷都沒有意義了。在第二層for循環,從元素3開始的遍歷都沒有意義了。
這么說有點抽象,如圖所示:
組合4
圖中每一個節點(圖中為矩形),就代表本層的一個for循環,那么每一層的for循環從第二個數開始遍歷的話,都沒有意義,都是無效遍歷。
所以,可以剪枝的地方就在遞歸中每一層的for循環所選擇的起始位置。
如果for循環選擇的起始位置之后的元素個數 已經不足 我們需要的元素個數了,那么就沒有必要搜索了。
注意代碼中i,就是for循環里選擇的起始位置。
- for (int i = startIndex; i <= n; i++) {
接下來看一下優化過程如下:
已經選擇的元素個數:path.size();
還需要的元素個數為: k - path.size();
在集合n中至多要從該起始位置 : n - (k - path.size()) + 1,開始遍歷
為什么有個+1呢,因為包括起始位置,我們要是一個左閉的集合。
舉個例子,n = 4,k = 3, 目前已經選取的元素為0(path.size為0),n - (k - 0) + 1 即 4 - ( 3 - 0) + 1 = 2。
從2開始搜索都是合理的,可以是組合[2, 3, 4]。
這里大家想不懂的話,建議也舉一個例子,就知道是不是要+1了。
所以優化之后的for循環是:
- for (int i = startIndex; i <= n - (k - path.size()) + 1; i++) // i為本次搜索的起始位置
優化后整體代碼如下:
- class Solution {
- private:
- vector<vector<int>> result;
- vector<int> path;
- void backtracking(int n, int k, int startIndex) {
- if (path.size() == k) {
- result.push_back(path);
- return;
- }
- for (int i = startIndex; i <= n - (k - path.size()) + 1; i++) { // 優化的地方
- path.push_back(i); // 處理節點
- backtracking(n, k, i + 1);
- path.pop_back(); // 回溯,撤銷處理的節點
- }
- }
- public:
- vector<vector<int>> combine(int n, int k) {
- backtracking(n, k, 1);
- return result;
- }
- };
剪枝總結
本篇我們準對求組合問題的回溯法代碼做了剪枝優化,這個優化如果不畫圖的話,其實不好理解,也不好講清楚。
所以我依然是把整個回溯過程抽象為一顆樹形結構,然后可以直觀的看出,剪枝究竟是剪的哪里。
其他語言版本
Java
- class Solution {
- List<List<Integer>> result = new ArrayList<>();
- LinkedList<Integer> path = new LinkedList<>();
- public List<List<Integer>> combine(int n, int k) {
- combineHelper(n, k, 1);
- return result;
- }
- /**
- * 每次從集合中選取元素,可選擇的范圍隨著選擇的進行而收縮,調整可選擇的范圍,就是要靠startIndex
- * @param startIndex 用來記錄本層遞歸的中,集合從哪里開始遍歷(集合就是[1,...,n] )。
- */
- private void combineHelper(int n, int k, int startIndex){
- //終止條件
- if (path.size() == k){
- result.add(new ArrayList<>(path));
- return;
- }
- for (int i = startIndex; i <= n - (k - path.size()) + 1; i++){
- path.add(i);
- combineHelper(n, k, i + 1);
- path.removeLast();
- }
- }
- }
Python
- class Solution:
- def combine(self, n: int, k: int) -> List[List[int]]:
- res=[] #存放符合條件結果的集合
- path=[] #用來存放符合條件結果
- def backtrack(n,k,startIndex):
- if len(path) == k:
- res.append(path[:])
- return
- for i in range(startIndex,n+1):
- path.append(i) #處理節點
- backtrack(n,k,i+1) #遞歸
- path.pop() #回溯,撤銷處理的節點
- backtrack(n,k,1)
- return res
Go
- var res [][]int
- func combine(n int, k int) [][]int {
- res=[][]int{}
- if n <= 0 || k <= 0 || k > n {
- return res
- }
- backtrack(n, k, 1, []int{})
- return res
- }
- func backtrack(n,k,start int,track []int){
- if len(track)==k{
- temp:=make([]int,k)
- copy(temp,track)
- res=append(res,temp)
- }
- if len(track)+n-start+1 < k {
- return
- }
- for i:=start;i<=n;i++{
- track=append(track,i)
- backtrack(n,k,i+1,track)
- track=track[:len(track)-1]
- }
- }




































