Avin's Blog

Four Number Sum [Python]

December 27, 2019
Tags: leetcode, algorithmic question, arrays, python

We are given an array of distinct integers, for example [5, 8, -2, 3, 6, 7], and a target sum, for example 14. We have to find all quadruplets of numbers in the array that sum up to the target sum. In this example the answer should be [[5, 8, -2, 3], [-2, 3, 6, 7]].

We can always start with a brute-force solution: use 4 nested loops to generate every possible quadruplet and check whether it sums up to the target sum. This results in a running time of O(n^4).
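A minimal sketch of the brute-force idea, assuming the input is a plain Python list. Instead of spelling out four nested loops, it uses `itertools.combinations`, which yields every choice of four positions exactly once (the same O(n^4) work). The function name `four_number_sum_brute_force` is hypothetical.

```python
from itertools import combinations


def four_number_sum_brute_force(array, target_sum):
    """Check every combination of four positions: O(n^4) time."""
    result = []
    # combinations() yields each set of 4 positions once, keeping the input order
    for quad in combinations(array, 4):
        if sum(quad) == target_sum:
            result.append(list(quad))
    return result


print(four_number_sum_brute_force([5, 8, -2, 3, 6, 7], 14))
# [[5, 8, -2, 3], [-2, 3, 6, 7]]
```

Slow as it is, a version like this is handy as a reference to test faster solutions against.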

Another possibility is to sort the array and use 3 loops: two nested loops pick the first two numbers, and a third loop finds the remaining two numbers with two pointers moving inward from both ends, which works because the array is sorted. This brings the running time down to O(n^3).
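A sketch of the sort-plus-two-pointers approach, assuming distinct elements as in the rest of the post (with repeated values it could report the same quadruplet more than once). The function name is hypothetical, and each quadruplet comes back in sorted order rather than in input order.

```python
def four_number_sum_two_pointers(array, target_sum):
    """Sort, fix two numbers with two loops, then close in with two pointers: O(n^3)."""
    nums = sorted(array)  # sorted copy; the caller's list is left untouched
    n = len(nums)
    result = []
    for a in range(n - 3):
        for b in range(a + 1, n - 2):
            left, right = b + 1, n - 1
            while left < right:
                total = nums[a] + nums[b] + nums[left] + nums[right]
                if total == target_sum:
                    result.append([nums[a], nums[b], nums[left], nums[right]])
                    left += 1
                    right -= 1
                elif total < target_sum:
                    left += 1   # sum too small: move the low pointer up
                else:
                    right -= 1  # sum too large: move the high pointer down
    return result


print(four_number_sum_two_pointers([5, 8, -2, 3, 6, 7], 14))
# [[-2, 3, 5, 8], [-2, 3, 6, 7]]
```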

It's possible to solve this in O(n^2) time on average. The idea is to treat this like a two-number sum, P + Q = targetSum, where P is the sum of a pair of numbers we have seen previously and Q is the sum of the pair we are currently on.
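For reference, the two-number sum this idea builds on can be sketched as follows (a hypothetical helper, not part of the solution). The four-number version swaps each single value for the sum of a pair, and the set for a dictionary mapping a pair sum to the pairs that produce it.

```python
def two_number_sum(array, target_sum):
    """One-pass two-number sum: remember the values already seen in a set."""
    seen = set()
    for q in array:
        p = target_sum - q  # the partner that q needs
        if p in seen:
            return [p, q]
        seen.add(q)
    return []


print(two_number_sum([5, 8, -2, 3, 6, 7], 14))  # [8, 6]
```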

Take a quick look at the solution first to better understand what is going on here.

def fourNumberSum(array, targetSum):
    seen = {}
    result = [] # quadruplets

    # Q = sum of the current pair (array[i], array[j]); check whether we have already seen P = targetSum - Q
    for i in range(1, len(array) - 1):
        for j in range(i+1, len(array)):
            currentSum = array[i] + array[j]
            difference = targetSum - currentSum

            if difference in seen:
                for pair in seen[difference]:
                    result.append(pair + [array[i], array[j]])
        
        # Only now store the pairs (array[k], array[i]) with k < i; they become P candidates for later i
        for k in range(i):
            currentSum = array[k] + array[i]
            if currentSum not in seen:
                seen[currentSum] = []
            seen[currentSum].append([array[k], array[i]])
        
    return result

One way to implement this is to simply use two loops over all the pairs and add every pair we come across to the seen dictionary right away, but this can report the same quadruplet several times (and can even combine two pairs that share an element).
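To see the duplicates, here is a hypothetical one-pass variant that inserts each pair into `seen` right after checking it. It stores index pairs and skips any stored pair that overlaps the current one, an extra check this variant needs because a stored pair can share an element with the current pair. Even so, every quadruplet at positions a < b < c < d is reported three times, once for each way of splitting it into two pairs, because the pair containing a is always stored before the other pair is checked.

```python
def four_number_sum_one_pass(array, target_sum):
    """Hypothetical naive variant: each pair goes into `seen` right after it is checked."""
    seen = {}  # pair sum -> list of (k, m) index pairs stored so far
    result = []
    n = len(array)
    for i in range(n):
        for j in range(i + 1, n):
            current_sum = array[i] + array[j]
            for k, m in seen.get(target_sum - current_sum, []):
                if len({k, m, i, j}) == 4:  # skip pairs that reuse a position
                    result.append([array[k], array[m], array[i], array[j]])
            seen.setdefault(current_sum, []).append((i, j))
    return result


quads = four_number_sum_one_pass([5, 8, -2, 3, 6, 7], 14)
print(len(quads))                            # 6
print(len({frozenset(q) for q in quads}))    # 2
```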

To avoid generating duplicate quadruplets, we tweak the strategy a little: instead of adding each pair to the dictionary as soon as we see it, we add the pairs formed by the current number and each number before it only AFTER we have checked all the pairs that start at the current number.

For example, in the array [5, 8, -2, 3, 6, 7], when we are at 8 we look at all the pairs [8, -2], [8, 3], ..., [8, 7] and, for each one, check whether the difference between the target sum and the sum of the pair is a key in our seen dictionary. If it is, we join every pair stored under that difference with our current pair and append the resulting quadruplets to the result array.

After we are done looking at all the current pairs, we update the dictionary with the pairs formed by the current number and every number before it. When we are on 8 that is just [5, 8]; when we are on -2 the new pairs are [5, -2] and [8, -2]. So after -2 has been processed our dictionary looks like this:

seen = {
    13: [[5, 8]],
    3: [[5, -2]],
    6: [[8, -2]],
}
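The walkthrough can be reproduced by running only the first two iterations of the outer loop (i = 1 for 8, i = 2 for -2). This is a standalone trace of the same steps, not the full function:

```python
array = [5, 8, -2, 3, 6, 7]
target_sum = 14
seen = {}
result = []

for i in (1, 2):  # only the first two outer iterations
    # Check every current pair (array[i], array[j]) against earlier pairs
    for j in range(i + 1, len(array)):
        difference = target_sum - (array[i] + array[j])
        for pair in seen.get(difference, []):
            result.append(pair + [array[i], array[j]])
    # Then store the pairs that end at array[i]
    for k in range(i):
        seen.setdefault(array[k] + array[i], []).append([array[k], array[i]])

print(result)  # [[5, 8, -2, 3]]
print(seen)    # {13: [[5, 8]], 3: [[5, -2]], 6: [[8, -2]]}
```

The first quadruplet appears while we are on -2: the current pair [-2, 3] sums to 1, and 14 - 1 = 13 is already a key thanks to [5, 8].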

By doing this we make sure that every pair in the dictionary ends before the current index, so each quadruplet is found exactly once: when the outer loop is on its third number (in array order) and the inner loop is on its fourth. That is why we never repeat a quadruplet and can store the results in an array instead of a set. (This assumes that the input array does not contain duplicate elements.)
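One way to gain confidence in the no-duplicates argument is a small randomized test against brute force. The sketch below assumes distinct input values (drawn with `random.sample`), re-implements the algorithm above under the hypothetical name `four_number_sum`, and checks that both methods return the same quadruplets with the same multiplicity, so a repeated quadruplet would make the assertion fail.

```python
import random
from itertools import combinations


def four_number_sum(array, target_sum):
    # Same algorithm as fourNumberSum above, with snake_case names
    seen = {}
    result = []
    for i in range(1, len(array) - 1):
        for j in range(i + 1, len(array)):
            for pair in seen.get(target_sum - (array[i] + array[j]), []):
                result.append(pair + [array[i], array[j]])
        for k in range(i):
            seen.setdefault(array[k] + array[i], []).append([array[k], array[i]])
    return result


def brute_force(array, target_sum):
    return [list(q) for q in combinations(array, 4) if sum(q) == target_sum]


random.seed(0)
for _ in range(200):
    array = random.sample(range(-20, 21), random.randint(0, 12))  # distinct values
    target_sum = random.randint(-30, 30)
    fast = four_number_sum(array, target_sum)
    slow = brute_force(array, target_sum)
    # Compare as multisets of sorted quadruplets
    assert sorted(map(sorted, fast)) == sorted(map(sorted, slow))
print("ok")
```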

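The problem is fairly simple to implement in code once we grasp the idea of treating each pair as a single number, which turns the problem into a two-number sum. The reason this is O(n^2) rather than O(n) like a two-number sum is that generating the pairs needs an inner loop. That bound is the average case: in the worst case many stored pairs share the same sum, and looping over all of them for every current pair pushes the running time to O(n^3).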

The space complexity is O(n^2). The dictionary seen takes most of the space: every pair (array[k], array[i]) with k < i is stored in it exactly once (except the pairs involving the last element, which are never needed), and there are roughly n^2/2 such pairs.
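The quadratic counts can be checked by tallying loop iterations. The hypothetical helper `count_work` below counts how many current pairs are checked and how many pairs are stored in `seen` for an array of length n; both come out to (n - 1)(n - 2) / 2. It ignores the extra time spent looping over stored pairs when a difference is found.

```python
def count_work(n):
    """Count pair checks (Q) and pairs stored in `seen` (P) for an array of length n."""
    checks = stored = 0
    for i in range(1, n - 1):
        checks += n - 1 - i  # inner j loop: j = i+1 .. n-1
        stored += i          # inner k loop: k = 0 .. i-1
    return checks, stored


for n in (6, 100, 1000):
    print(n, count_work(n), (n - 1) * (n - 2) // 2)
# 6 (10, 10) 10
# 100 (4851, 4851) 4851
# 1000 (498501, 498501) 498501
```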

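Here is a similar question on LeetCode (4Sum). The only difference between the LeetCode version and the one we solved above is that the array may contain duplicate elements, and the same solution works with a few tweaks: we sort the array so that every quadruplet comes out in the same order, and we collect the results as tuples in a set so that quadruplets built from equal values at different positions collapse into one. With many repeated values the lists in seen can get long, so this version can be slow in the worst case.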

from typing import List

class Solution:
    def fourSum(self, array: List[int], targetSum: int) -> List[List[int]]:
        seen = {}  # pair sum -> pairs (earlier index first) stored so far
        result = set()  # quadruplets as tuples, so equal-valued ones collapse
        array.sort()  # sort so every quadruplet tuple comes out in sorted order

        # Q = sum of the current pair (array[i], array[j]); look up P = targetSum - Q among earlier pairs
        for i in range(1, len(array) - 1):
            for j in range(i + 1, len(array)):
                currentSum = array[i] + array[j]
                difference = targetSum - currentSum

                if difference in seen:
                    for pair in seen[difference]:
                        result.add((pair[0], pair[1], array[i], array[j]))

            # Only now store the pairs (array[k], array[i]) with k < i as future P candidates
            for k in range(i):
                currentSum = array[k] + array[i]
                if currentSum not in seen:
                    seen[currentSum] = []
                seen[currentSum].append((array[k], array[i]))

        return [list(quad) for quad in result]
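To try the duplicate-tolerant version outside LeetCode, here is a standalone sketch of the same idea under a hypothetical function name. It sorts a copy instead of the caller's list, uses `collections.defaultdict(list)` for `seen`, and sorts the output so the result is deterministic (set iteration order is not). The two inputs are small examples, the second made entirely of repeated values.

```python
from collections import defaultdict
from typing import List


def four_sum(nums: List[int], target: int) -> List[List[int]]:
    """Standalone version of the Solution.fourSum idea above (hypothetical name)."""
    nums = sorted(nums)          # sorted copy, so each quadruplet comes out in order
    seen = defaultdict(list)     # pair sum -> pairs (earlier index first) stored so far
    result = set()               # tuples, so value-duplicates collapse
    for i in range(1, len(nums) - 1):
        for j in range(i + 1, len(nums)):
            # .get() avoids creating empty lists for sums we never stored
            for a, b in seen.get(target - nums[i] - nums[j], []):
                result.add((a, b, nums[i], nums[j]))
        for k in range(i):
            seen[nums[k] + nums[i]].append((nums[k], nums[i]))
    return sorted(list(quad) for quad in result)


print(four_sum([1, 0, -1, 0, -2, 2], 0))  # [[-2, -1, 1, 2], [-2, 0, 0, 2], [-1, 0, 0, 1]]
print(four_sum([2, 2, 2, 2, 2], 8))       # [[2, 2, 2, 2]]
```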