再探有序二维矩阵搜索问题之分治解法

今天重新做了下搜索有序二维矩阵的问题:

在一个大小为 m x n 的二维矩阵中查找目标值 target 是否存在, 已知此矩阵的每一行、每一列的数字都是有序的。

流行的做法是,从矩阵右上角出发进行搜索, 其时间复杂度是 O(m+n)。这个解法简洁而优美, 可以见我之前写过的这篇文章 - 有序二维矩阵搜索问题

搜索有序二维矩阵的常规做法

但是,在矩阵规模增大的时候,它并非最快的解法。

本文将分享一个更优的分治解法。

思路分析

首先,矩阵的左上顶点一定是最小值,右下顶点一定是最大值。

在这两个顶点的对角线上,取中点 m ,把矩阵分为四部分, 如下图所示:

我们知道,四个部分之中,左上的 I 一定不大于右下的 IV

更确切的说,四个部分的大小关系是:

  • I <= II <= IV
  • I <= III <= IV

但是,IIIII 中元素的大小关系并无法确定

那总的思路就是:

先排除 IIIII ,再根据 mtarget 的大小关系,检查 IIV 之一即可

左上角的位置 a 和 右下角的位置 z 可以表示一个子矩阵, 同时对划分的交点进行标号 b, d, f, h ,那么这个思路的伪代码可以写为:

bool find(a, z, target) {
    // 检查右上角 II
    if find(b, f, target) return true
    // 检查左下角 III
    if find(d, h, target) return true
    // 根据 m 和 target 的大小关系,检查 I 或者 IV
    if target < m
        return find(a, m, target)
    else
        return find(m, z, target)
}

可以看到,每次分四份,最差需要检查三份, 面积减少的速度是 3/4 。递归下去,就是 log(m*n) 级别的时间复杂度。

复杂度分析

假设矩阵的面积是 $S = m\times n$ 。

最差的情况下是,每次减少到原来的 $3/4$ ,假设需要 $k$ 次才可以减到 1,那么:

\[S\left( \frac{3}{4} \right)^{k} =1 \\\Longrightarrow \ \ S = \left( \frac{4}{3} \right)^{k} \\\Longrightarrow \ \ k = \log_{4/3} S\]

可以进一步通过换底公式来换到 $2$ 的底, 大约是:

\[k = \frac {\log_{2} S} {\log_{2} {4/3}} \\\approx 2.40942\ log_{2} S\]

总的说,其时间杂度是 $\log\left(m\times n\right)$ ,即 $\log {m} +\log {n}$。

特殊地,当 $m = n$ 时,大约是 $4.8\ \log_{2} n $ 。

代码实现

在每次搜索子矩阵的时候,检查下两个顶点 a, z 和中点 m 的值, 即可判断 target 的存在性。 最终的实现见下:

搜索有序二维矩阵的 C++ 代码实现
class Solution {
   public:
    bool find(vector<vector<int>>& x, int ai, int aj, int zi, int zj,
              int target) {
        if (ai > zi || aj > zj) return false;

        auto a = x[ai][aj];  // 当前矩阵最小值
        auto z = x[zi][zj];  // 当前矩阵最大值

        // 左上角和右下角相等的时候
        if (a == z) {
            if (target == a) return true;
            return false;
        }

        // 恰好在顶点
        if (target == a) return true;
        if (target == z) return true;

        // target 不在范围内的时候
        if (target < a || target > z) return false;

        // 检查顶点 c 和 g
        if (target == x[ai][zj]) return true;
        if (target == x[zi][aj]) return true;

        // 如果矩阵大小就是 4,那么直接 false
        // 这里也排除了下面的二分的死递归可能
        if (zi <= ai + 1 && zj <= aj + 1) return false;

        // 取对角线中点
        auto mi = (ai + zi) / 2;
        auto mj = (aj + zj) / 2;

        auto m = x[mi][mj];

        if (m == target) return true;

        // 检查各个子模块,每次最差排除 1/4 部分
        // 也就是说,以 3/4 的速度减少面积
        // 需要先排除 II 和 III

        // 检查 II:  b => (ai, mj), f => (mi, zj)
        // 不包含 m 所在行和列
        if (find(x, ai, mj + 1, mi - 1, zj, target)) return true;

        // 检查 III: d => (mi, aj), h => (zi, mj)
        // 不包含 m 所在行和列
        if (find(x, mi + 1, aj, zi, mj - 1, target)) return true;

        // 检查 I (包括 m 所在行和列)
        if (target < m) return find(x, ai, aj, mi, mj, target);

        // 检查 IV  (包括 m 所在行和列)
        return find(x, mi, mj, zi, zj, target);
    }

    bool searchMatrix(vector<vector<int>>& matrix, int target) {
        int m = matrix.size();
        int n = matrix[0].size();
        return find(matrix, 0, 0, m - 1, n - 1, target);
    }
};

关于性能

奇怪的的是,在 leetcode 上,这个算法的速度比不上传统的 $O(m+n)$ 的算法, 因为 leetcode 的数据集保证了 1 <= n, m <= 300,就是说,矩阵规模比较小的时候,传统线性方法更快。

因此,我写了一个 benchmark 测试,实际对比了两种方案,在矩阵规模比较大的时候,分治方法的速度优势就很明显了

对比 分治算法 和 传统线性方法的速度的压测结果 - github

相似问题

leetcode 上有一个相似的问题 1351. 统计有序矩阵中的负数

给你一个 m * n 的矩阵 grid,矩阵中的元素无论是按行还是按列,都以非递增顺序排列。 请你统计并返回 grid 中 负数 的数目。

也可以用这个思路来做,奇怪的是,这个题的提交来看,这个方法的速度上去了。

统计有序矩阵中的负数 C++ 代码实现
class Solution {
   public:
    // 统计当前矩阵中的负数, a 和 z 是当前矩阵的左上和右下的顶点
    int count(vector<vector<int>>& x, int ai, int aj, int zi, int zj) {
        // 递归终止: 越界
        if (ai > zi || aj > zj) return 0;

        auto a = x[ai][aj];  // 左上角
        auto z = x[zi][zj];  // 右下角

        // 顶点特判:

        // 递归终止: 只有一个元素的时候
        if (ai == zi && aj == zj) return a < 0 ? 1 : 0;

        // 最大的就是负数,则整个矩阵就是负数
        if (a < 0) return (zi - ai + 1) * (zj - aj + 1);
        // 最小的是非负数,则整个矩阵非负, 计数 0
        if (z >= 0) return 0;

        // 特判矩阵对角线 <= 2 个元素的情况,以防止取中点死递归
        // 现在已知 a >= 0 且 z < 0, 所以初始计数 1
        if (zi == ai + 1 && zj == aj) return 1;  // 两行一列, z 是负数
        if (zi == ai && zj == aj + 1) return 1;  // 一行两列, z 是负数
        if (zi == ai + 1 && zj == aj + 1) {
            // 两行两列, 分别检查 c 和 g 点是否是负数
            return 1 + (x[ai][zj] < 0 ? 1 : 0) + (x[zi][aj] < 0 ? 1 : 0);
        }

        // 取对角线中点
        // m 会最终算入 I 或者 IV 模块
        // 但是无论如何,m 这个点只会被考察一次
        auto mi = (ai + zi) / 2;
        auto mj = (aj + zj) / 2;
        auto m = x[mi][mj];

        int ans = 0;

        // 检查各个子模块,每次最差排除 1/4 部分
        // 也就是说,以 3/4 的速度减少面积
        // 必须检查 II (不含中点所在行和列)
        ans += count(x, ai, mj + 1, mi - 1, zj);

        // 必须检查 III (不含中点所在行和列)
        ans += count(x, mi + 1, aj, zi, mj - 1);

        if (m >= 0) {
            // 此时不必检查 I, 因为 I 都不小于 m, 肯定都非负
            // 只需要检查 IV (包括 m 所在的行和列)
            ans += count(x, mi, mj, zi, zj);
        } else {  // m < 0
            // 此时 IV 全部负数, 全部计入 (包括 m 所在的行和列)
            // 但是为了防止 m 被多算一次,所以要 -1
            ans += (zi - mi + 1) * (zj - mj + 1) - 1;
            // 然后只需要检查 I (包括 m 所在的行和列)
            ans += count(x, ai, aj, mi, mj);
        }
        return ans;
    }

    int countNegatives(vector<vector<int>>& grid) {
        int m = grid.size();
        int n = grid[0].size();
        return count(grid, 0, 0, m - 1, n - 1);
    }
};

(完)

本文原始链接地址: https://writings.sh/post/search-sorted-2d-matrix-revisited

王超 ·
微信扫码赞赏
评论 首页 | 归档 | 算法 | 订阅