Binary Search

最常见的算法之一。因为每次写出来的都不一样 XD ,所以在这里留一个笔记。

算法输出与循环不变量

这里介绍的二分搜索算法可以寻找目标元素的最小索引(假设数组已经从小到大排序),当无法找到时,则返回目标元素可以插入的位置(也就是比多少个元素大)。

已知已排序的数组 elem ,定义两个变量 left & right ,满足

每次循环都要缩小 left - right 的范围来找到目标元素。因为需要对数组进行排序和比较大小,所以数组元素 / 数组元素对应的 key 必须是全序的。

缩小范围的步骤为

缩小范围的过程一直保持后两个不变量,所以跳出循环的条件应该是第一个不变量不成立。此时满足 target > elem[..left] and target <= elem[right..] (注意这里 ..left 不包括 left ),由于不满足第一个不变量,此时 left >= right 。由于缩小范围的过程中,每次缩小时都满足 right' = mid or left' = mid + 1 ,而 left <= mid < right ,所以最后一次缩小后,left == right 。考虑到后两个不变量成立,此时 left 就是目标元素的最小索引。

Code

fn binary_search<T>(elem: &[T], target: T) -> Result<usize, usize>
where
    T: Ord
{
    use std::cmp::Ordering;

    let len = elem.len();
    let mut left = 0;
    let mut right = len;

    while left < right {
        let mid = left + ((right - left) >> 1);
        match target.cmp(&elem[mid]) {
            Ordering::Less | Ordering::Equal => {
                right = mid;
            }
            Ordering::Greater => {
                left = mid + 1;
            }
        }
    }

    if elem[left] == mid { Ok(left) } else { Err(left) }
}

例题:寻找两个正序数组的中位数

题目描述

本题的关键是确定如何进行二分搜索。我们假设最终的中位数大于等于 nums1 , nums2 的前 id1 , id2 个元素。

pub fn find_median_sorted_arrays(mut nums1: Vec<i32>, mut nums2: Vec<i32>) -> f64 {
    fn binary_search_by_key(elems: &[usize], target: bool, key_fn: impl Fn(usize) -> bool) -> Result<usize, usize> {
        use std::cmp::Ordering;

        let len = elems.len();
        let (mut left, mut right) = (0, len);

        while left < right {
            println!("{left} - {right}");
            let mid = left + ((right - left) >> 1);
            let mid_key = key_fn(elems[mid]);

            match target.cmp(&mid_key) {
                Ordering::Less | Ordering::Equal => {
                    right = mid;
                }
                Ordering::Greater => {
                    left = mid + 1;
                }
            }
        }

        let left_key = key_fn(elems[left]);
        if left_key == target { Ok(left) } else { Err(left) }
    }

    let (mut len1, mut len2) = (nums1.len(), nums2.len());
    let len = len1 + len2;
    let mid = len >> 1;

    if len1 > len2 {
        std::mem::swap(&mut len1, &mut len2);
        std::mem::swap(&mut nums1, &mut nums2);
    }

    let elems: Vec<usize> = (0..=len1).collect();
    let search_res = binary_search_by_key(&elems[..], true, |id| {
        let id1 = id;
        let id2 = mid - id1;

        if id1 == len1 {
            true
        } else {
            nums1[id1] >= nums2[id2 - 1]
        }
    }).unwrap();
    let id1 = search_res;
    let id2 = mid - id1;

    if len & 1 == 1 {
        *(nums1.get(id1).unwrap_or(&i32::MAX).min(nums2.get(id2).unwrap_or(&i32::MAX))) as f64
    } else {
        let left = *(nums1.get(id1 - 1).unwrap_or(&i32::MIN).max(nums2.get(id2 - 1).unwrap_or(&i32::MIN)));
        let right = *(nums1.get(id1).unwrap_or(&i32::MAX).min(nums2.get(id2).unwrap_or(&i32::MAX)));
        (left as f64 + right as f64) / 2.
    }
}