18. 4Sum

Given an array nums of n integers, return an array of all the unique quadruplets [nums[a], nums[b], nums[c], nums[d]] such that:

  • 0 <= a, b, c, d < n
  • a, b, c, and d are distinct.
  • nums[a] + nums[b] + nums[c] + nums[d] == target

You may return the answer in any order.

Example 1:

Input: nums = [1,0,-1,0,-2,2], target = 0
Output: [[-2,-1,1,2],[-2,0,0,2],[-1,0,0,1]]

Example 2:

Input: nums = [2,2,2,2,2], target = 8
Output: [[2,2,2,2]]

Solution:

class Solution {
    /**
     * Returns every unique quadruplet of values from {@code nums} whose sum equals {@code target}.
     *
     * <p>Sorts {@code nums} in place. The first two values are fixed with nested loops, and the
     * remaining pair is found with two pointers. Duplicate values are skipped at every level so
     * each quadruplet is reported once.
     *
     * @param nums   input values; sorted in place by this method
     * @param target required sum of each quadruplet
     * @return the quadruplets, each listed in non-decreasing order
     */
    public List<List<Integer>> fourSum(int[] nums, int target) {
        List<List<Integer>> result = new ArrayList<>();
        Arrays.sort(nums);
        int n = nums.length;

        for (int i = 0; i < n - 3; i++) {
            int curI = nums[i];
            // Skip duplicate first values so the same quadruplet is not produced twice.
            if (i > 0 && curI == nums[i - 1]) {
                continue;
            }

            // Even with the three largest values, curI cannot reach target: try a larger curI.
            if ((long) curI + nums[n - 1] + nums[n - 2] + nums[n - 3] < target) {
                continue;
            }

            // The smallest sum starting at curI already exceeds target; larger i only grows it.
            if ((long) curI + nums[i + 1] + nums[i + 2] + nums[i + 3] > target) {
                break;
            }

            for (int j = i + 1; j < n - 2; j++) {
                int curJ = nums[j];
                // Skip duplicate second values for the same reason as above.
                if (j > i + 1 && curJ == nums[j - 1]) {
                    continue;
                }

                int k = j + 1;
                int l = n - 1;

                // Largest reachable sum for this (i, j) is still too small.
                if ((long) curI + curJ + nums[l - 1] + nums[l] < target) {
                    continue;
                }

                // Smallest reachable sum for this (i, j) is already too large; larger j is worse.
                if ((long) curI + curJ + nums[k] + nums[k + 1] > target) {
                    break;
                }

                while (k < l) {
                    int curK = nums[k];
                    int curL = nums[l];

                    // Sum in long: four ints can overflow an int sum, which would both report
                    // wrong quadruplets and steer the two pointers in the wrong direction.
                    long curSum = (long) curI + curJ + curK + curL;
                    if (curSum == target) {
                        result.add(Arrays.asList(curI, curJ, curK, curL));
                        k++;
                        // Skip equal third values so this (i, j, k) combination is not repeated.
                        while (k < l && nums[k] == nums[k - 1]) {
                            k++;
                        }
                    } else if (curSum < target) {
                        k++;
                    } else {
                        l--;
                    }
                }
            }
        }

        return result;
    }
}

// TC: O(n^3)
// SC: O(n) — recursion depth k (k can approach n in the general kSum) plus sorting; output not counted
class Solution {
    /**
     * Returns every unique quadruplet in {@code nums} that sums to {@code target}.
     * Sorts {@code nums} in place and delegates to the general {@code kSum} with k = 4.
     */
    public List<List<Integer>> fourSum(int[] nums, int target) {
        Arrays.sort(nums);
        return kSum(nums, target, 0, 4);
    }

    /**
     * Finds all unique k-element combinations from {@code nums[start..]} that sum to {@code target}.
     *
     * @param nums   array sorted in non-decreasing order
     * @param target required sum; a long so that {@code target - nums[i]} in recursion cannot overflow
     * @param start  first index that may be used
     * @param k      number of values still to choose (callers in this class pass k >= 2)
     * @return the combinations, each listed in non-decreasing order
     */
    public List<List<Integer>> kSum(int[] nums, long target, int start, int k) {
        List<List<Integer>> result = new ArrayList<List<Integer>>();

        // No numbers left to choose from.
        if (start == nums.length) {
            return result;
        }

        // The k remaining values would have to average target / k. If the smallest remaining
        // value is above that average, or the largest is below it, no combination can work.
        // (Division truncates toward zero; with integer values the prune remains safe.)
        long averageValue = target / k;
        if (nums[start] > averageValue || averageValue > nums[nums.length - 1]) {
            return result;
        }

        // Base case: exactly two values remain, solved with two pointers.
        if (k == 2) {
            return twoSum(nums, target, start);
        }

        // Fix nums[i] as the next value, then recursively solve for the remaining k - 1
        // with target reduced by nums[i] and the search starting at i + 1.
        for (int i = start; i < nums.length; i++) {
            // Skip duplicate choices at this level so combinations stay unique.
            if (i > start && nums[i] == nums[i - 1]) {
                continue;
            }
            for (List<Integer> subset : kSum(nums, target - nums[i], i + 1, k - 1)) {
                result.add(new ArrayList<Integer>(Arrays.asList(nums[i])));
                result.get(result.size() - 1).addAll(subset);
            }
        }

        return result;
    }

    /**
     * Finds all unique pairs in {@code nums[start..]} whose sum equals {@code target}, using two
     * pointers that move inward from both ends of the sorted range.
     *
     * @param nums   array sorted in non-decreasing order
     * @param target required pair sum
     * @param start  first index that may be used
     * @return the pairs, each listed as [smaller, larger]
     */
    public List<List<Integer>> twoSum(int[] nums, long target, int start) {
        List<List<Integer>> result = new ArrayList<>();
        int left = start;
        int right = nums.length - 1;

        while (left < right) {
            // Widen to long: two ints near the extremes overflow an int sum.
            long curSum = (long) nums[left] + nums[right];
            if (left > start && nums[left] == nums[left - 1]) {
                // Skip a left value already paired at this level.
                left++;
            } else if (right < nums.length - 1 && nums[right] == nums[right + 1]) {
                // Skip a right value already paired at this level.
                right--;
            } else if (curSum < target) {
                left++;
            } else if (curSum > target) {
                right--;
            } else {
                result.add(Arrays.asList(nums[left], nums[right]));
                left++;
                right--;
            }
        }

        return result;
    }
}

// TC: O(n^3)
// SC: O(n)
class Solution {
    /**
     * Returns every unique quadruplet in {@code nums} that sums to {@code target}.
     * Sorts {@code nums} in place, then solves the general k-sum problem with k = 4.
     */
    public List<List<Integer>> fourSum(int[] nums, int target) {
        Arrays.sort(nums);
        int start = 0;
        int k = 4;
        return kSum(nums, target, start, k);
    }

    /**
     * Finds all unique k-element combinations from {@code nums[start..]} that sum to {@code target}.
     *
     * @param nums   array sorted in non-decreasing order
     * @param target required sum; a long so recursive subtraction cannot overflow
     * @param start  first index that may be used
     * @param k      number of values still to choose (callers in this class pass k >= 2)
     * @return the combinations, each listed in non-decreasing order
     */
    public List<List<Integer>> kSum(int[] nums, long target, int start, int k) {
        List<List<Integer>> result = new ArrayList<List<Integer>>();

        if (start == nums.length) {
            return result;
        }

        // The k values must average target / k: impossible if every remaining value is
        // above that average (smallest too big) or below it (largest too small).
        long average = target / k;
        if (nums[start] > average || average > nums[nums.length - 1]) {
            return result;
        }

        if (k == 2) {
            return twoSum(nums, target, start);
        }

        // i stops once fewer than k elements remain from index i onward.
        for (int i = start; i < nums.length - k + 1; i++) {
            // Skip duplicate choices at this level so combinations stay unique.
            if (i > start && nums[i] == nums[i - 1]) {
                continue;
            }

            // Prefix nums[i] to every (k-1)-combination found in the suffix after i.
            for (List<Integer> subset : kSum(nums, target - nums[i], i + 1, k - 1)) {
                List<Integer> combination = new ArrayList<Integer>(k);
                combination.add(nums[i]);
                combination.addAll(subset);
                result.add(combination);
            }
        }
        return result;
    }

    /**
     * Finds all unique pairs in {@code nums[start..]} whose sum equals {@code target} using two
     * pointers. Duplicates are skipped on the left side only, which suffices because the right
     * partner of a given left value is fixed by the target.
     *
     * @param nums   array sorted in non-decreasing order
     * @param target required pair sum
     * @param start  first index that may be used
     * @return the pairs, each listed as [smaller, larger]
     */
    public List<List<Integer>> twoSum(int[] nums, long target, int start) {
        List<List<Integer>> result = new ArrayList<List<Integer>>();

        int left = start;
        int right = nums.length - 1;
        while (left < right) {
            if (left > start && nums[left] == nums[left - 1]) {
                left++;
                continue;
            }

            // Widen to long: two ints near the extremes overflow an int sum.
            long sum = (long) nums[left] + nums[right];
            if (sum == target) {
                result.add(Arrays.asList(nums[left], nums[right]));
                left++;
                right--;
            } else if (sum < target) {
                left++;
            } else {
                right--;
            }
        }

        return result;
    }
}