今天重新做了下搜索有序二维矩阵的问题:
在一个大小为 m x n 的二维矩阵中查找目标值 target 是否存在, 已知此矩阵的每一行、每一列的数字都是有序的。
流行的做法是,从矩阵右上角出发进行搜索, 其时间复杂度是 O(m+n)
。这个解法简洁而优美, 可以见我之前写过的这篇文章 - 有序二维矩阵搜索问题 。
但是,在矩阵规模增大的时候,它并非最快的解法。
本文将分享一个更优的分治解法。
思路分析 ¶
首先,矩阵的左上顶点一定是最小值,右下顶点一定是最大值。
在这两个顶点的对角线上,取中点 m
,把矩阵分为四部分, 如下图所示:
我们知道,四个部分之中,左上的 I
一定不大于右下的 IV
。
更确切的说,四个部分的大小关系是:
I <= II <= IV
I <= III <= IV
但是,II
和 III
中元素的大小关系并无法确定。
那总的思路就是:
先排除 II
和 III
,再根据 m
和 target
的大小关系,检查 I
和 IV
之一即可。
左上角的位置 a
和 右下角的位置 z
可以表示一个子矩阵, 同时对划分的交点进行标号 b, d, f, h
,那么这个思路的伪代码可以写为:
bool find(a, z, target) {
// 检查右上角 II
if find(b, f, target) return true
// 检查左下角 III
if find(d, h, target) return true
// 根据 m 和 target 的大小关系,检查 I 或者 IV
if target < m
return find(a, m, target)
else
return find(m, z, target)
}
可以看到,每次分四份,最差需要检查三份, 面积减少的速度是 3/4
。递归下去,就是 log(m*n)
级别的时间复杂度。
复杂度分析 ¶
假设矩阵的面积是 $S = m\times n$ 。
最差的情况下是,每次减少到原来的 $3/4$ ,假设需要 $k$ 次才可以减到 1
,那么:
可以进一步通过换底公式来换到 $2$ 的底, 大约是:
\[k = \frac {\log_{2} S} {\log_{2} {4/3}} \\\approx 2.40942\ log_{2} S\]总的说,其时间杂度是 $\log\left(m\times n\right)$ ,即 $\log {m} +\log {n}$。
特殊地,当 $m = n$ 时,大约是 $4.8\ \log_{2} n $ 。
代码实现 ¶
在每次搜索子矩阵的时候,检查下两个顶点 a
, z
和中点 m
的值, 即可判断 target
的存在性。 最终的实现见下:
搜索有序二维矩阵的 C++ 代码实现
class Solution {
public:
bool find(vector<vector<int>>& x, int ai, int aj, int zi, int zj,
int target) {
if (ai > zi || aj > zj) return false;
auto a = x[ai][aj]; // 当前矩阵最小值
auto z = x[zi][zj]; // 当前矩阵最大值
// 左上角和右下角相等的时候
if (a == z) {
if (target == a) return true;
return false;
}
// 恰好在顶点
if (target == a) return true;
if (target == z) return true;
// target 不在范围内的时候
if (target < a || target > z) return false;
// 检查顶点 c 和 g
if (target == x[ai][zj]) return true;
if (target == x[zi][aj]) return true;
// 如果矩阵大小就是 4,那么直接 false
// 这里也排除了下面的二分的死递归可能
if (zi <= ai + 1 && zj <= aj + 1) return false;
// 取对角线中点
auto mi = (ai + zi) / 2;
auto mj = (aj + zj) / 2;
auto m = x[mi][mj];
if (m == target) return true;
// 检查各个子模块,每次最差排除 1/4 部分
// 也就是说,以 3/4 的速度减少面积
// 需要先排除 II 和 III
// 检查 II: b => (ai, mj), f => (mi, zj)
// 不包含 m 所在行和列
if (find(x, ai, mj + 1, mi - 1, zj, target)) return true;
// 检查 III: d => (mi, aj), h => (zi, mj)
// 不包含 m 所在行和列
if (find(x, mi + 1, aj, zi, mj - 1, target)) return true;
// 检查 I (包括 m 所在行和列)
if (target < m) return find(x, ai, aj, mi, mj, target);
// 检查 IV (包括 m 所在行和列)
return find(x, mi, mj, zi, zj, target);
}
bool searchMatrix(vector<vector<int>>& matrix, int target) {
int m = matrix.size();
int n = matrix[0].size();
return find(matrix, 0, 0, m - 1, n - 1, target);
}
};
关于性能 ¶
奇怪的的是,在 leetcode 上,这个算法的速度比不上传统的 $O(m+n)$ 的算法, 因为 leetcode 的数据集保证了 1 <= n, m <= 300
,就是说,矩阵规模比较小的时候,传统线性方法更快。
因此,我写了一个 benchmark 测试,实际对比了两种方案,在矩阵规模比较大的时候,分治方法的速度优势就很明显了。
对比 分治算法 和 传统线性方法的速度的压测结果 - github 。
相似问题 ¶
leetcode 上有一个相似的问题 1351. 统计有序矩阵中的负数:
给你一个 m * n 的矩阵 grid,矩阵中的元素无论是按行还是按列,都以非递增顺序排列。 请你统计并返回 grid 中 负数 的数目。
也可以用这个思路来做,奇怪的是,这个题的提交来看,这个方法的速度上去了。
统计有序矩阵中的负数 C++ 代码实现
class Solution {
public:
// 统计当前矩阵中的负数, a 和 z 是当前矩阵的左上和右下的顶点
int count(vector<vector<int>>& x, int ai, int aj, int zi, int zj) {
// 递归终止: 越界
if (ai > zi || aj > zj) return 0;
auto a = x[ai][aj]; // 左上角
auto z = x[zi][zj]; // 右下角
// 顶点特判:
// 递归终止: 只有一个元素的时候
if (ai == zi && aj == zj) return a < 0 ? 1 : 0;
// 最大的就是负数,则整个矩阵就是负数
if (a < 0) return (zi - ai + 1) * (zj - aj + 1);
// 最小的是非负数,则整个矩阵非负, 计数 0
if (z >= 0) return 0;
// 特判矩阵对角线 <= 2 个元素的情况,以防止取中点死递归
// 现在已知 a >= 0 且 z < 0, 所以初始计数 1
if (zi == ai + 1 && zj == aj) return 1; // 两行一列, z 是负数
if (zi == ai && zj == aj + 1) return 1; // 一行两列, z 是负数
if (zi == ai + 1 && zj == aj + 1) {
// 两行两列, 分别检查 c 和 g 点是否是负数
return 1 + (x[ai][zj] < 0 ? 1 : 0) + (x[zi][aj] < 0 ? 1 : 0);
}
// 取对角线中点
// m 会最终算入 I 或者 IV 模块
// 但是无论如何,m 这个点只会被考察一次
auto mi = (ai + zi) / 2;
auto mj = (aj + zj) / 2;
auto m = x[mi][mj];
int ans = 0;
// 检查各个子模块,每次最差排除 1/4 部分
// 也就是说,以 3/4 的速度减少面积
// 必须检查 II (不含中点所在行和列)
ans += count(x, ai, mj + 1, mi - 1, zj);
// 必须检查 III (不含中点所在行和列)
ans += count(x, mi + 1, aj, zi, mj - 1);
if (m >= 0) {
// 此时不必检查 I, 因为 I 都不小于 m, 肯定都非负
// 只需要检查 IV (包括 m 所在的行和列)
ans += count(x, mi, mj, zi, zj);
} else { // m < 0
// 此时 IV 全部负数, 全部计入 (包括 m 所在的行和列)
// 但是为了防止 m 被多算一次,所以要 -1
ans += (zi - mi + 1) * (zj - mj + 1) - 1;
// 然后只需要检查 I (包括 m 所在的行和列)
ans += count(x, ai, aj, mi, mj);
}
return ans;
}
int countNegatives(vector<vector<int>>& grid) {
int m = grid.size();
int n = grid[0].size();
return count(grid, 0, 0, m - 1, n - 1);
}
};
(完)
本文原始链接地址: https://writings.sh/post/search-sorted-2d-matrix-revisited