LeetCode Problem: Unique Binary Search Trees 詳解

Description

給定一個數字 n,回傳 1, 2, ...n 這些數字可以組成多少不同的 binary search tree。

Example 1:

Input: n = 3
Output: 5

Example 2:

Input: n = 1
Output: 1

Solution

準備工作

我們將 $f(i, j)$ 定義為數字 $i$ 到數字 $j$ 可以組合出 binary search tree 的數量。舉例來說 $f(1, 3) = 5$,因為這三個數字能組出的 binary search tree 有五個:

再舉一個例子,$f(2, 4) = 5$:

從我們舉的兩個例子,我們可以很容易的注意到一個規律:$f$ 函數只會受到 $i$、$j$ 的差值影響

$$
\begin{aligned}
&f(1, 1) = f(2, 2) = \cdots = g(0) \\
&f(1, 2) = f(2, 3) = \cdots = g(1) \\
&f(1, 3) = f(2, 4) = \cdots = g(2) \\
\end{aligned}
$$

有了以上的特性之後,我們就可以用函數 $g(i)$ 代表 $f(x, x + i)$。

通用公式

接下來我們要利用先前定義的函數推導出通用的公式。假設我們考慮 $f(1, 3)$ 的狀況,那數字 $1, \cdots, 3$ 的組合有三種,分別以 $3$、$2$、$1$ 為根節點:

可以注意到的是,若我們以 $3$ 作為根節點,那數字 $1$、$2$ 必定會放在左邊子樹,才能符合 binary search tree 的要求。另外,我們還可以輕易的推出以 $3$ 作為根節點的數量,其實剛好就等於 $f(1, 2)$。

若我們讓 $2$ 作為根節點,那剩下的數字 $1$ 必定需要放在左子樹,數字 $3$ 放在右子樹,數量分別為 $f(1, 1)$ 和 $f(3, 3)$。所以以 $2$ 作為根節點可以組成的 binary search tree 數量是:$f(1, 1) \times f(3, 3)$。

接下來我們可以推展到 $f(1, n)$,整體的概念跟 $f(1, 3)$ 如出一轍:

因此我們可以推導出公式:

$$
\begin{aligned}
f(1, n) &= f(1, n – 1) + \sum_{i=2}^{n – 1} f(1, n – i) \times f(n – (i – 2), n) + f(2, n) \\
\Rightarrow g(n – 1) &= g(n – 2) + \sum_{i=2}^{n – 1} g(n – i – 1) \times g(i – 2) + g(n – 2)
\end{aligned}
$$

從上述公式可以發現,我們如果要計算目標 $g(n – 1)$,可以從起始值開始慢慢往下計算,時間複雜度為 $\mathcal{O} (n)$:

$$
\begin{aligned}
g(0) &= 1 \\
g(1) &= 2 \\
g(2) &= 2 \cdot g(1) + g(0) \cdot g(0) = 5 \\
g(3) &= 2 \cdot g(2) + g(1) \cdot g(0) + g(0) \cdot g(1) = 14 \\
g(4) &= 2 \cdot g(3) + g(2) \cdot g(0) + g(1) \cdot g(1) + g(0) \cdot g(2) = 42 \\
\dots
\end{aligned}
$$

完整演算法如下:


class Solution {
public:
    int numTrees(int n) {
        // Border condition
        if (1 == n) return 1;
        else if (2 == n) return 2;
        
        // Start
        int *record = new int[n]{ 0 };
        record[0] = 1; // n = 1;
        record[1] = 2; // n = 2;
        
        for (int differ = 2; differ < n; ++differ) {
            record[differ] = 2*record[differ - 1]; // n = differ + 1
            
            for (int idx = 2; idx < differ + 1; ++idx) {
                record[differ] += (record[differ - idx]*record[idx - 2]);
            }
        }
        
        return record[n - 1];
    }
};