1577. 数的平方等于两数乘积的方法数

December 26, 2024 · View on GitHub

English Version

题目描述

给你两个整数数组 nums1nums2 ,请你返回根据以下规则形成的三元组的数目(类型 1 和类型 2 ):

  • 类型 1:三元组 (i, j, k) ,如果 nums1[i]2 == nums2[j] * nums2[k] 其中 0 <= i < nums1.length0 <= j < k < nums2.length
  • 类型 2:三元组 (i, j, k) ,如果 nums2[i]2 == nums1[j] * nums1[k] 其中 0 <= i < nums2.length0 <= j < k < nums1.length

 

示例 1:

输入:nums1 = [7,4], nums2 = [5,2,8,9]
输出:1
解释:类型 1:(1,1,2), nums1[1]^2 = nums2[1] * nums2[2] ($4^{2}$ = 2 * 8)

示例 2:

输入:nums1 = [1,1], nums2 = [1,1,1]
输出:9
解释:所有三元组都符合题目要求,因为 $1^{2}$ = 1 * 1
类型 1:(0,0,1), (0,0,2), (0,1,2), (1,0,1), (1,0,2), (1,1,2), nums1[i]^2 = nums2[j] * nums2[k]
类型 2:(0,0,1), (1,0,1), (2,0,1), nums2[i]^2 = nums1[j] * nums1[k]

示例 3:

输入:nums1 = [7,7,8,3], nums2 = [1,2,9,7]
输出:2
解释:有两个符合题目要求的三元组
类型 1:(3,0,2), nums1[3]^2 = nums2[0] * nums2[2]
类型 2:(3,0,1), nums2[3]^2 = nums1[0] * nums1[1]

示例 4:

输入:nums1 = [4,7,9,11,23], nums2 = [3,5,1024,12,18]
输出:0
解释:不存在符合题目要求的三元组

 

提示:

  • 1 <= nums1.length, nums2.length <= 1000
  • 1 <= nums1[i], nums2[i] <= $10^{5}$

解法

方法一:哈希表 + 枚举

我们用哈希表 cnt1\textit{cnt1} 统计 nums1\textit{nums1} 中每个数对 (nums[j],nums[k])(\textit{nums}[j], \textit{nums}[k]) 出现的次数,其中 $0 \leq j \lt k < m,其中,其中 m为数组为数组\textit{nums1}的长度。用哈希表的长度。用哈希表\textit{cnt2}统计统计\textit{nums2}中每个数对中每个数对(\textit{nums}[j], \textit{nums}[k]) 出现的次数,其中 \0 \leq j \lt k < n,其中,其中 n为数组为数组\textit{nums2}$ 的长度。

接下来,我们枚举数组 nums1\textit{nums1} 中的每个数 xx,计算 cnt2[x2]\textit{cnt2}[x^2] 的值,即 nums2\textit{nums2} 中有多少对数 (nums[j],nums[k])(\textit{nums}[j], \textit{nums}[k]) 满足 nums[j]×nums[k]=x2\textit{nums}[j] \times \textit{nums}[k] = x^2。同理,我们枚举数组 nums2\textit{nums2} 中的每个数 xx,计算 cnt1[x2]\textit{cnt1}[x^2] 的值,即 nums1\textit{nums1} 中有多少对数 (nums[j],nums[k])(\textit{nums}[j], \textit{nums}[k]) 满足 nums[j]×nums[k]=x2\textit{nums}[j] \times \textit{nums}[k] = x^2,最后将两者相加返回即可。

时间复杂度 O(m2+n2+m+n)O(m^2 + n^2 + m + n),空间复杂度 O(m2+n2)O(m^2 + n^2)。其中 mmnn 分别为数组 nums1\textit{nums1}nums2\textit{nums2} 的长度。

Python3

class Solution:
    def numTriplets(self, nums1: List[int], nums2: List[int]) -> int:
        def count(nums: List[int]) -> Counter:
            cnt = Counter()
            for j in range(len(nums)):
                for k in range(j + 1, len(nums)):
                    cnt[nums[j] * nums[k]] += 1
            return cnt

        def cal(nums: List[int], cnt: Counter) -> int:
            return sum(cnt[x * x] for x in nums)

        cnt1 = count(nums1)
        cnt2 = count(nums2)
        return cal(nums1, cnt2) + cal(nums2, cnt1)

Java

class Solution {
    public int numTriplets(int[] nums1, int[] nums2) {
        var cnt1 = count(nums1);
        var cnt2 = count(nums2);
        return cal(cnt1, nums2) + cal(cnt2, nums1);
    }

    private Map<Long, Integer> count(int[] nums) {
        Map<Long, Integer> cnt = new HashMap<>();
        int n = nums.length;
        for (int j = 0; j < n; ++j) {
            for (int k = j + 1; k < n; ++k) {
                long x = (long) nums[j] * nums[k];
                cnt.merge(x, 1, Integer::sum);
            }
        }
        return cnt;
    }

    private int cal(Map<Long, Integer> cnt, int[] nums) {
        int ans = 0;
        for (int x : nums) {
            long y = (long) x * x;
            ans += cnt.getOrDefault(y, 0);
        }
        return ans;
    }
}

C++

class Solution {
public:
    int numTriplets(vector<int>& nums1, vector<int>& nums2) {
        auto cnt1 = count(nums1);
        auto cnt2 = count(nums2);
        return cal(cnt1, nums2) + cal(cnt2, nums1);
    }

    unordered_map<long long, int> count(vector<int>& nums) {
        unordered_map<long long, int> cnt;
        for (int i = 0; i < nums.size(); i++) {
            for (int j = i + 1; j < nums.size(); j++) {
                cnt[(long long) nums[i] * nums[j]]++;
            }
        }
        return cnt;
    }

    int cal(unordered_map<long long, int>& cnt, vector<int>& nums) {
        int ans = 0;
        for (int x : nums) {
            ans += cnt[(long long) x * x];
        }
        return ans;
    }
};

Go

func numTriplets(nums1 []int, nums2 []int) int {
	cnt1 := count(nums1)
	cnt2 := count(nums2)
	return cal(cnt1, nums2) + cal(cnt2, nums1)
}

func count(nums []int) map[int]int {
	cnt := map[int]int{}
	for j, x := range nums {
		for _, y := range nums[j+1:] {
			cnt[x*y]++
		}
	}
	return cnt
}

func cal(cnt map[int]int, nums []int) (ans int) {
	for _, x := range nums {
		ans += cnt[x*x]
	}
	return
}

TypeScript

function numTriplets(nums1: number[], nums2: number[]): number {
    const cnt1 = count(nums1);
    const cnt2 = count(nums2);
    return cal(cnt1, nums2) + cal(cnt2, nums1);
}

function count(nums: number[]): Map<number, number> {
    const cnt: Map<number, number> = new Map();
    for (let j = 0; j < nums.length; ++j) {
        for (let k = j + 1; k < nums.length; ++k) {
            const x = nums[j] * nums[k];
            cnt.set(x, (cnt.get(x) || 0) + 1);
        }
    }
    return cnt;
}

function cal(cnt: Map<number, number>, nums: number[]): number {
    return nums.reduce((acc, x) => acc + (cnt.get(x * x) || 0), 0);
}

方法二:哈希表 + 枚举优化

我们用哈希表 cnt1\textit{cnt1} 统计 nums1\textit{nums1} 中每个数出现的次数,用哈希表 cnt2\textit{cnt2} 统计 nums2\textit{nums2} 中每个数出现的次数。

接下来,我们枚举数组 nums1\textit{nums1} 中的每个数 xx,然后枚举 cnt2\textit{cnt2} 中的每个数对 (y,v1)(y, v1),其中 yycnt2\textit{cnt2} 的键,v1v1cnt2\textit{cnt2} 的值。我们计算 z=x2/yz = x^2 / y,如果 y×z=x2y \times z = x^2,此时如果 y=zy = z,说明 yyzz 是同一个数,那么 v1=v2v1 = v2,从 v1v1 个数中任选两个数的方案数为 v1×(v11)=v1×(v21)v1 \times (v1 - 1) = v1 \times (v2 - 1);如果 yzy \neq z,那么 v1v1 个数中任选两个数的方案数为 v1×v2v1 \times v2。最后将所有方案数相加并除以 $2 即可。这里除以 \2是因为我们统计的是对数对是因为我们统计的是对数对(j, k)的方案数,而实际上的方案数,而实际上(j, k)(k, j)$ 是同一种方案。

时间复杂度 O(m×n)O(m \times n),空间复杂度 O(m+n)O(m + n)。其中 mmnn 分别为数组 nums1\textit{nums1}nums2\textit{nums2} 的长度。

Python3

class Solution:
    def numTriplets(self, nums1: List[int], nums2: List[int]) -> int:
        def cal(nums: List[int], cnt: Counter) -> int:
            ans = 0
            for x in nums:
                for y, v1 in cnt.items():
                    z = x * x // y
                    if y * z == x * x:
                        v2 = cnt[z]
                        ans += v1 * (v2 - int(y == z))
            return ans // 2

        cnt1 = Counter(nums1)
        cnt2 = Counter(nums2)
        return cal(nums1, cnt2) + cal(nums2, cnt1)

Java

class Solution {
    public int numTriplets(int[] nums1, int[] nums2) {
        var cnt1 = count(nums1);
        var cnt2 = count(nums2);
        return cal(cnt1, nums2) + cal(cnt2, nums1);
    }

    private Map<Integer, Integer> count(int[] nums) {
        Map<Integer, Integer> cnt = new HashMap<>();
        for (int x : nums) {
            cnt.merge(x, 1, Integer::sum);
        }
        return cnt;
    }

    private int cal(Map<Integer, Integer> cnt, int[] nums) {
        long ans = 0;
        for (int x : nums) {
            for (var e : cnt.entrySet()) {
                int y = e.getKey(), v1 = e.getValue();
                int z = (int) (1L * x * x / y);
                if (y * z == x * x) {
                    int v2 = cnt.getOrDefault(z, 0);
                    ans += v1 * (y == z ? v2 - 1 : v2);
                }
            }
        }
        return (int) (ans / 2);
    }
}

Go

func numTriplets(nums1 []int, nums2 []int) int {
	cnt1 := count(nums1)
	cnt2 := count(nums2)
	return cal(cnt1, nums2) + cal(cnt2, nums1)
}

func count(nums []int) map[int]int {
	cnt := map[int]int{}
	for _, x := range nums {
		cnt[x]++
	}
	return cnt
}

func cal(cnt map[int]int, nums []int) (ans int) {
	for _, x := range nums {
		for y, v1 := range cnt {
			z := x * x / y
			if y*z == x*x {
				if v2, ok := cnt[z]; ok {
					if y == z {
						v2--
					}
					ans += v1 * v2
				}
			}
		}
	}
	ans /= 2
	return
}

TypeScript

function numTriplets(nums1: number[], nums2: number[]): number {
    const cnt1 = count(nums1);
    const cnt2 = count(nums2);
    return cal(cnt1, nums2) + cal(cnt2, nums1);
}

function count(nums: number[]): Map<number, number> {
    const cnt: Map<number, number> = new Map();
    for (const x of nums) {
        cnt.set(x, (cnt.get(x) || 0) + 1);
    }
    return cnt;
}

function cal(cnt: Map<number, number>, nums: number[]): number {
    let ans: number = 0;
    for (const x of nums) {
        for (const [y, v1] of cnt) {
            const z = Math.floor((x * x) / y);
            if (y * z == x * x) {
                const v2 = cnt.get(z) || 0;
                ans += v1 * (y === z ? v2 - 1 : v2);
            }
        }
    }
    return ans / 2;
}