2040. 两个有序数组的第 K 小乘积

December 15, 2025 · View on GitHub

English Version

题目描述

给你两个 从小到大排好序 且下标从 0 开始的整数数组 nums1 和 nums2 以及一个整数 k ,请你返回第 k (从 1 开始编号)小的 nums1[i] * nums2[j] 的乘积,其中 0 <= i < nums1.length 0 <= j < nums2.length 。

 

示例 1:

输入:nums1 = [2,5], nums2 = [3,4], k = 2
输出:8
解释:第 2 小的乘积计算如下:
- nums1[0] * nums2[0] = 2 * 3 = 6
- nums1[0] * nums2[1] = 2 * 4 = 8
第 2 小的乘积为 8 。

示例 2:

输入:nums1 = [-4,-2,0,3], nums2 = [2,4], k = 6
输出:0
解释:第 6 小的乘积计算如下:
- nums1[0] * nums2[1] = (-4) * 4 = -16
- nums1[0] * nums2[0] = (-4) * 2 = -8
- nums1[1] * nums2[1] = (-2) * 4 = -8
- nums1[1] * nums2[0] = (-2) * 2 = -4
- nums1[2] * nums2[0] = 0 * 2 = 0
- nums1[2] * nums2[1] = 0 * 4 = 0
第 6 小的乘积为 0 。

示例 3:

输入:nums1 = [-2,-1,0,1,2], nums2 = [-3,-1,2,4,5], k = 3
输出:-6
解释:第 3 小的乘积计算如下:
- nums1[0] * nums2[4] = (-2) * 5 = -10
- nums1[0] * nums2[3] = (-2) * 4 = -8
- nums1[4] * nums2[0] = 2 * (-3) = -6
第 3 小的乘积为 -6 。

 

提示:

  • 1 <= nums1.length, nums2.length <= 5 * 104
  • -105 <= nums1[i], nums2[j] <= 105
  • 1 <= k <= nums1.length * nums2.length
  • nums1 和 nums2 都是从小到大排好序的。

解法

方法一:二分查找

我们可以二分枚举乘积的值 pp,定义二分的区间为 [l,r][l, r],其中 l=max(nums1[0],nums1[n1])×max(nums2[0],nums2[n1])l = -\textit{max}(|\textit{nums1}[0]|, |\textit{nums1}[n - 1]|) \times \textit{max}(|\textit{nums2}[0]|, |\textit{nums2}[n - 1]|), r=lr = -l

对于每个 pp,我们计算出乘积小于等于 pp 的乘积的个数,如果这个个数大于等于 kk,那么说明第 kk 小的乘积一定小于等于 pp,我们就可以将区间右端点缩小到 pp,否则我们将区间左端点增大到 p+1p + 1

那么问题的关键就是如何计算乘积小于等于 pp 的乘积的个数。我们可以枚举 nums1\textit{nums1} 中的每个数 xx,分类讨论:

  • 如果 x>0x > 0,那么 x×nums2[i]x \times \textit{nums2}[i] 随着 ii 的增大是单调递增的,我们可以使用二分查找找到最小的 ii,使得 x×nums2[i]>px \times \textit{nums2}[i] > p,那么 ii 就是小于等于 pp 的乘积的个数,累加到个数 cnt\textit{cnt} 中;
  • 如果 x<0x < 0,那么 x×nums2[i]x \times \textit{nums2}[i] 随着 ii 的增大是单调递减的,我们可以使用二分查找找到最小的 ii,使得 x×nums2[i]px \times \textit{nums2}[i] \leq p,那么 nin - i 就是小于等于 pp 的乘积的个数,累加到个数 cnt\textit{cnt} 中;
  • 如果 x=0x = 0,那么 x×nums2[i]=0x \times \textit{nums2}[i] = 0,如果 p0p \geq 0,那么 nn 就是小于等于 pp 的乘积的个数,累加到个数 cnt\textit{cnt} 中。

这样我们就可以通过二分查找找到第 kk 小的乘积。

时间复杂度 O(m×logn×logM)O(m \times \log n \times \log M),其中 mmnn 分别为 nums1\textit{nums1}nums2\textit{nums2} 的长度,而 MMnums1\textit{nums1}nums2\textit{nums2} 中的最大值的绝对值。

Python3

class Solution:
    def kthSmallestProduct(self, nums1: List[int], nums2: List[int], k: int) -> int:
        def count(p: int) -> int:
            cnt = 0
            n = len(nums2)
            for x in nums1:
                if x > 0:
                    cnt += bisect_right(nums2, p / x)
                elif x < 0:
                    cnt += n - bisect_left(nums2, p / x)
                else:
                    cnt += n * int(p >= 0)
            return cnt

        mx = max(abs(nums1[0]), abs(nums1[-1])) * max(abs(nums2[0]), abs(nums2[-1]))
        return bisect_left(range(-mx, mx + 1), k, key=count) - mx

Java

class Solution {
    private int[] nums1;
    private int[] nums2;

    public long kthSmallestProduct(int[] nums1, int[] nums2, long k) {
        this.nums1 = nums1;
        this.nums2 = nums2;
        int m = nums1.length;
        int n = nums2.length;
        int a = Math.max(Math.abs(nums1[0]), Math.abs(nums1[m - 1]));
        int b = Math.max(Math.abs(nums2[0]), Math.abs(nums2[n - 1]));
        long r = (long) a * b;
        long l = (long) -a * b;
        while (l < r) {
            long mid = (l + r) >> 1;
            if (count(mid) >= k) {
                r = mid;
            } else {
                l = mid + 1;
            }
        }
        return l;
    }

    private long count(long p) {
        long cnt = 0;
        int n = nums2.length;
        for (int x : nums1) {
            if (x > 0) {
                int l = 0, r = n;
                while (l < r) {
                    int mid = (l + r) >> 1;
                    if ((long) x * nums2[mid] > p) {
                        r = mid;
                    } else {
                        l = mid + 1;
                    }
                }
                cnt += l;
            } else if (x < 0) {
                int l = 0, r = n;
                while (l < r) {
                    int mid = (l + r) >> 1;
                    if ((long) x * nums2[mid] <= p) {
                        r = mid;
                    } else {
                        l = mid + 1;
                    }
                }
                cnt += n - l;
            } else if (p >= 0) {
                cnt += n;
            }
        }
        return cnt;
    }
}

C++

class Solution {
public:
    long long kthSmallestProduct(vector<int>& nums1, vector<int>& nums2, long long k) {
        int m = nums1.size(), n = nums2.size();
        int a = max(abs(nums1[0]), abs(nums1[m - 1]));
        int b = max(abs(nums2[0]), abs(nums2[n - 1]));
        long long r = 1LL * a * b;
        long long l = -r;
        auto count = [&](long long p) {
            long long cnt = 0;
            for (int x : nums1) {
                if (x > 0) {
                    int l = 0, r = n;
                    while (l < r) {
                        int mid = (l + r) >> 1;
                        if (1LL * x * nums2[mid] > p) {
                            r = mid;
                        } else {
                            l = mid + 1;
                        }
                    }
                    cnt += l;
                } else if (x < 0) {
                    int l = 0, r = n;
                    while (l < r) {
                        int mid = (l + r) >> 1;
                        if (1LL * x * nums2[mid] <= p) {
                            r = mid;
                        } else {
                            l = mid + 1;
                        }
                    }
                    cnt += n - l;
                } else if (p >= 0) {
                    cnt += n;
                }
            }
            return cnt;
        };
        while (l < r) {
            long long mid = (l + r) >> 1;
            if (count(mid) >= k) {
                r = mid;
            } else {
                l = mid + 1;
            }
        }
        return l;
    }
};

Go

func kthSmallestProduct(nums1 []int, nums2 []int, k int64) int64 {
	m := len(nums1)
	n := len(nums2)
	a := max(abs(nums1[0]), abs(nums1[m-1]))
	b := max(abs(nums2[0]), abs(nums2[n-1]))
	r := int64(a) * int64(b)
	l := -r

	count := func(p int64) int64 {
		var cnt int64
		for _, x := range nums1 {
			if x > 0 {
				l, r := 0, n
				for l < r {
					mid := (l + r) >> 1
					if int64(x)*int64(nums2[mid]) > p {
						r = mid
					} else {
						l = mid + 1
					}
				}
				cnt += int64(l)
			} else if x < 0 {
				l, r := 0, n
				for l < r {
					mid := (l + r) >> 1
					if int64(x)*int64(nums2[mid]) <= p {
						r = mid
					} else {
						l = mid + 1
					}
				}
				cnt += int64(n - l)
			} else if p >= 0 {
				cnt += int64(n)
			}
		}
		return cnt
	}

	for l < r {
		mid := (l + r) >> 1
		if count(mid) >= k {
			r = mid
		} else {
			l = mid + 1
		}
	}
	return l
}

func abs(x int) int {
	if x < 0 {
		return -x
	}
	return x
}

TypeScript

function kthSmallestProduct(nums1: number[], nums2: number[], k: number): number {
    const m = nums1.length;
    const n = nums2.length;

    const a = BigInt(Math.max(Math.abs(nums1[0]), Math.abs(nums1[m - 1])));
    const b = BigInt(Math.max(Math.abs(nums2[0]), Math.abs(nums2[n - 1])));

    let l = -a * b;
    let r = a * b;

    const count = (p: bigint): bigint => {
        let cnt = 0n;
        for (const x of nums1) {
            const bx = BigInt(x);
            if (bx > 0n) {
                let l = 0,
                    r = n;
                while (l < r) {
                    const mid = (l + r) >> 1;
                    const prod = bx * BigInt(nums2[mid]);
                    if (prod > p) {
                        r = mid;
                    } else {
                        l = mid + 1;
                    }
                }
                cnt += BigInt(l);
            } else if (bx < 0n) {
                let l = 0,
                    r = n;
                while (l < r) {
                    const mid = (l + r) >> 1;
                    const prod = bx * BigInt(nums2[mid]);
                    if (prod <= p) {
                        r = mid;
                    } else {
                        l = mid + 1;
                    }
                }
                cnt += BigInt(n - l);
            } else if (p >= 0n) {
                cnt += BigInt(n);
            }
        }
        return cnt;
    };

    while (l < r) {
        const mid = (l + r) >> 1n;
        if (count(mid) >= BigInt(k)) {
            r = mid;
        } else {
            l = mid + 1n;
        }
    }

    return Number(l);
}

Rust

impl Solution {
    pub fn kth_smallest_product(nums1: Vec<i32>, nums2: Vec<i32>, k: i64) -> i64 {
        let m = nums1.len();
        let n = nums2.len();
        let a = nums1[0].abs().max(nums1[m - 1].abs()) as i64;
        let b = nums2[0].abs().max(nums2[n - 1].abs()) as i64;
        let mut l = -a * b;
        let mut r = a * b;

        let count = |p: i64| -> i64 {
            let mut cnt = 0i64;
            for &x in &nums1 {
                if x > 0 {
                    let mut left = 0;
                    let mut right = n;
                    while left < right {
                        let mid = (left + right) / 2;
                        if (x as i64) * (nums2[mid] as i64) > p {
                            right = mid;
                        } else {
                            left = mid + 1;
                        }
                    }
                    cnt += left as i64;
                } else if x < 0 {
                    let mut left = 0;
                    let mut right = n;
                    while left < right {
                        let mid = (left + right) / 2;
                        if (x as i64) * (nums2[mid] as i64) <= p {
                            right = mid;
                        } else {
                            left = mid + 1;
                        }
                    }
                    cnt += (n - left) as i64;
                } else if p >= 0 {
                    cnt += n as i64;
                }
            }
            cnt
        };

        while l < r {
            let mid = l + (r - l) / 2;
            if count(mid) >= k {
                r = mid;
            } else {
                l = mid + 1;
            }
        }
        l
    }
}