2020年6月15日 星期一

LeetCode 4. Median of Two Sorted Arrays [Hard] [C++] 解題筆記

這題給定兩個排序好的 array nums1 和 nums2,其長度分別為 m, n,要求我們在 O(log(m+n)) 的時間複雜度限制內找出這兩個 array 合併後的中位數。

EX:
        nums1 = [1, 3]
        nums2 = [2]

    The median is 2.0


想法:
    這題如果沒有時間複雜度限制的話就非常簡單,最簡單的方式就是將 nums1, nums2 都放入一個新的 array nums3, 
然後對 nums3 進行排序,之後再找 median 即可,這樣的時間複雜度是 O((m+n)log(m+n))。另一種也很直覺的方式是
直接對 nums1 和 nums2 做 merge,因為 nums1 和 nums2 都是排序好的,所以只需要逐個比大小像 merge sort 在
merge 時的概念一樣,就可以得到排序好的新 array nums3,時間複雜度是 O(m+n)。上述兩種方法皆可通過,但顯然不符
合題目要求的複雜度,注意到題目限制是 O(log(m+n)) 而 nums1 和 nums2 皆是排序好的,這時有沒有想到甚摸? 沒錯!
sorted array + 對數級時間複雜度 的搜索方式就是 binary search 拉,不過這邊實際操作與一般的常規 binary search 
有些差異,可以算是運用 binary search 的概念而已。首先這題可以想像成我們目標是找出 nums1 , nums2 合併後 array nums3 的第 k 項(median),
那摸我們可以先分別在 nums1 和 nums2 中找出第 k / 2 項,假設 nums1 中的第 k / 2 項為 k1,nums2 中的第 k / 2 項為 k2,
則若 k1 < k2 表示 nums1 的前 k / 2 項(1~k/2)都比 nums2的第 k / 2 項還要小,這就表示我們的目標 nums3 中的第 k 項不會出現
在 nums1 中的前 k / 2 項,因為 k1 < k2,所以就算最極端情況 nums2 中的前 1~ k / 2 - 1 都小於 k1,則合併後的 nums3 中,1~k1
最多也只有 k / 2 + k / 2 - 1 = k -1 個數字不會包含 nums3 中的第 k 項,因此可以將 nums1 中的 1 ~ k / 2 項剔除,接著繼續在 
nums1剩下的元素與 nums2 中分別找第 k - k / 2 個數(因為已經確定前面有 k / 2 個數了),依此類推直到 nums1 或 nums2 為空,
反之若 k2 < k1 亦然。
另外這裡可以特別思考一下若是其中一個 array 不存在第 k / 2 項怎摸辦呢?
有沒有可能兩個 array 都不存在第 k / 2 項? 第一個問題假設 nums2 不存在第 k / 2 項,那我們就必須剔除 nums1 的 1 ~ k / 2 項,
因為 nums2 不足 k / 2 個元素,因此就算所有的元素都小於 nums1 中的第 k / 2 項合併後 k / 2 項之前的元素也不到 k 個,
因此合併後第 k 項不會存在 nums1 的 1~ k / 2 項之中。第二個問題不可能會發生,因為這裡的 k 初始值是 nums1 + nums2 的中位數,
因此至少其中一個 array 會存在第 k / 2 個元素,否則就表示兩個 array 長度都不到 k / 2 ,與假設矛盾。

Complexity: O(m+n) / O(log(m+n))

完整程式碼:
解法一(merge two arrays and find median):
class Solution {
public:
    double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
        int p1 = 0;
        int p2 = 0;
        double mid = 0;
        vector<int> nums1andnums2;
        if (nums1.empty() && nums2.empty()) { return 0.0; }
        else if (nums1.empty()) { nums1andnums2 = nums2; }
        else if (nums2.empty()) { nums1andnums2 = nums1; }
        else {
            while (p1 < nums1.size() || p2 < nums2.size()) {
                if (p1 >= nums1.size()) { 
                    nums1andnums2.emplace_back(nums2.at(p2));
                    p2++;
                }
                else if (p2 >= nums2.size()) {
                    nums1andnums2.emplace_back(nums1.at(p1));
                    p1++;
                }
                else if (nums1.at(p1) < nums2.at(p2)) {
                    nums1andnums2.emplace_back(nums1.at(p1));
                    p1++;
                }
                else if (nums1.at(p1) > nums2.at(p2)) {
                    nums1andnums2.emplace_back(nums2.at(p2));
                    p2++;
                }
                else {
                    nums1andnums2.emplace_back(nums1.at(p1));
                    nums1andnums2.emplace_back(nums2.at(p2));
                    p1++;
                    p2++;
                }
            }
        }
        
        auto size = nums1andnums2.size();
        // even length
        if (size % 2 == 0) {
            if (size == 1) { mid = nums1andnums2.at(0); }
            else {
                mid = nums1andnums2.at(size/2)+nums1andnums2.at(size/2-1);
                mid = mid / 2;
            }
        }
        // odd length
        else {
            mid = nums1andnums2.at(size/2);
        }
        return mid;
    }
};

解法二(binary search / divide and conquer):
class Solution {
public:
    double findKthNumber(const vector<int>& nums1, int start1, const vector<int>& nums2, int start2, int k) {
            if (start1 >= nums1.size()) {
                return nums2[start2 + k - 1];
            }
            if (start2 >= nums2.size()) {
                return nums1[start1 + k - 1];
            }
            if (k == 1) {
                return min(nums1[start1], nums2[start2]);
            }
            int idx1 = start1 + k / 2 - 1;
            int idx2 = start2 + k / 2 - 1;
            int mid1 = (idx1 >= nums1.size())? INT_MAX : nums1[idx1];
            int mid2 = (idx2 >= nums2.size())? INT_MAX : nums2[idx2];
            if (mid1 < mid2) {
                return findKthNumber(nums1, start1 + k / 2, nums2, start2, k - k / 2);
            }
            else {
                return findKthNumber(nums1, start1, nums2, start2 + k / 2, k - k / 2);
            }
        }
    double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
        int len = nums1.size() + nums2.size();
        // odd length
        if (len % 2) {
            return findKthNumber(nums1, 0, nums2, 0, len / 2 + 1);
        }
        // even length
        else {
            return (findKthNumber(nums1, 0, nums2, 0, len / 2) + findKthNumber(nums1, 0, nums2, 0, len / 2 + 1)) / 2.0;
        }
        
        return 0.0;
    }
};

沒有留言:

張貼留言