3Sum implementation

Submitted by: @import:stackexchange-codereview

Problem

I have recently started to work through the problems on LeetCode, for the sake of bettering my own skills as well as preparing for interviews. I was faced with the 3Sum problem, which is:


Given an array S of n integers, are there elements a, b, c in S such that a + b + c = 0? Find all unique triplets in the array which gives the sum of zero.

Now I thought about this: the brute-force solution is to iterate through the array twice for every number, looking for the other two numbers that make the sum of all three equal 0. This has a time complexity of $O(n^3)$, which is less than ideal. So I thought it might be better to instead store all the pairs in the given array.

Then I iterate through the array, check whether the current number plus any pair sums to 0, and if so push the pair along with that number into the output vector. It works slightly backwards, and I thought this would give me a time complexity of $O(n^2)$.
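For comparison, the $O(n^3)$ brute force described above can be sketched as follows. This is a minimal sketch, not the asker's code, and the name `threeSumBrute` is hypothetical. It tries every index triple $i < j < k$, sorts each matching triplet, and removes duplicates at the end:

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Brute force: test every index triple i < j < k, which is O(n^3) time.
// Each matching triplet is sorted so duplicates can be removed at the end.
std::vector<std::vector<int>> threeSumBrute(const std::vector<int>& nums) {
    std::vector<std::vector<int>> out;
    const std::size_t n = nums.size();
    for (std::size_t i = 0; i < n; ++i)
        for (std::size_t j = i + 1; j < n; ++j)
            for (std::size_t k = j + 1; k < n; ++k)
                if (nums[i] + nums[j] + nums[k] == 0) {
                    std::vector<int> t{nums[i], nums[j], nums[k]};
                    std::sort(t.begin(), t.end());
                    out.push_back(t);
                }
    // Keep each unique triplet once.
    std::sort(out.begin(), out.end());
    out.erase(std::unique(out.begin(), out.end()), out.end());
    return out;
}
```

The three nested loops perform $\binom{n}{3}$ checks, which is where the cubic cost comes from.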

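```cpp
vector<vector<int>> threeSum(vector<int>& nums) {
    vector<pair<int,int>> pairs;
    vector<vector<int>> output;

    if(nums.size()<3){
        return output;
    }
    for(int i=0; i<nums.size();i++){
        for(int j=i+1; j<nums.size();j++){
            pairs.push_back(make_pair(nums.at(i),nums.at(j)));
        }
    }
    sort(pairs.begin(),pairs.end());
    for(int i=0; i<nums.size();i++){
        for(auto x:pairs){
            if(x.first!=nums.at(i) && x.second!=nums.at(i)){
                if(x.first+x.second+nums.at(i)==0){
                    vector<int> curr;
                    curr.push_back(x.first);
                    curr.push_back(x.second);
                    curr.push_back(nums.at(i));
                    sort(curr.begin(),curr.end());
                    output.push_back(curr);

                }
            }

        }
    }
    int count = 0;
    for(int i:nums){
        if(i==0){
            count++;
            if(count==3){
                vector<int> zeros;
                zeros.push_back(0);
                zeros.push_back(0);
                zeros.push_back(0);
                output.push_back(zeros);
                break;
            }
        }
    }

    sort(output.begin(),output.end());
    output.erase(unique(output.begin(),output.end()),output.end());
    return output;

}
```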


At the end in the lines:

```cpp
int count = 0;
for(int i:nums){
    if(i==0){
        count++;
        if(count==3){
            vector<int> zeros;
            zeros.push_back(0);
```

Solution

Avoid unnecessary memory allocations.

Memory allocations take time, and significant time if you do them in your innermost loop.

For starters, change:

```cpp
vector<vector<int>> output;
```

to:

```cpp
vector<std::tuple<int, int, int>> output;
```


This converts your output vector from "a vector of pointers to fixed-size dynamic elements" (note the oxymoron: fixed-size dynamic elements; there's your problem) to "a vector of fixed-size elements". Each inner `vector<int>` owns its own heap buffer, so every triplet you store costs a separate allocation, whereas a `std::tuple<int, int, int>` is stored inline in the outer vector. Not only will you avoid lots and lots of memory allocations, you will also improve your cache performance considerably, because your entire output vector is now contiguous and cache friendly.
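The change above alters the element type, but the original `threeSum` is declared to return `vector<vector<int>>`. One option, sketched here with a hypothetical helper `toNested`, is to keep the flat tuple vector inside the search loops and convert once at the end. Because `std::tuple` compares lexicographically, the final `sort`/`unique` pass also works directly on the tuples:

```cpp
#include <algorithm>
#include <tuple>
#include <vector>

// Hypothetical helper: takes the flat tuple triplets collected in the search loops,
// removes duplicates, and converts them to the vector<vector<int>> return type.
std::vector<std::vector<int>> toNested(std::vector<std::tuple<int, int, int>> triplets) {
    // std::tuple compares lexicographically, so sort + unique removes repeats.
    std::sort(triplets.begin(), triplets.end());
    triplets.erase(std::unique(triplets.begin(), triplets.end()), triplets.end());

    std::vector<std::vector<int>> nested;
    nested.reserve(triplets.size());  // the outer vector is allocated once
    for (const auto& t : triplets) {
        // One small allocation per unique triplet, done once, outside the search loops.
        nested.push_back({std::get<0>(t), std::get<1>(t), std::get<2>(t)});
    }
    return nested;
}
```

The per-triplet allocations still happen, but only for the unique results and only after the expensive loops have finished.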

Then change:

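```cpp
for(int i=0; i<nums.size();i++){
    for(auto x:pairs){
        if(x.first!=nums.at(i) && x.second!=nums.at(i)){
            if(x.first+x.second+nums.at(i)==0){
                vector<int> curr;
                curr.push_back(x.first);
                curr.push_back(x.second);
                curr.push_back(nums.at(i));
                sort(curr.begin(),curr.end());
                output.push_back(curr);
            }
        }
    }
}
```

to:

```cpp
for(int i=0; i<nums.size();i++){
    for(auto& x:pairs){
        if(x.first!=nums.at(i) && x.second!=nums.at(i)){
            if(x.first+x.second+nums.at(i)==0){
                int a = x.first;
                int b = x.second;
                int c = nums.at(i);
                // You could write out the if-else tree to 
                // possibly optimise this if your compiler doesn't 
                // already do that.
                if(a > b)
                    swap(a,b);
                if(a > c)
                    swap(a,c);
                if(b > c)
                    swap(b,c);
                output.emplace_back(a,b,c);
            }
        }
    }
}
```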
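The three compare-and-swap steps form a small sorting network for three values. Pulled out into a standalone function (the name `sort3` is hypothetical), the ordering logic is easy to check in isolation:

```cpp
#include <tuple>
#include <utility>

// Orders three ints with three compare-and-swap steps.
// After step 1, a <= b; after step 2, a holds the minimum of all three;
// step 3 orders the remaining two values in b and c.
std::tuple<int, int, int> sort3(int a, int b, int c) {
    if (a > b) std::swap(a, b);
    if (a > c) std::swap(a, c);
    if (b > c) std::swap(b, c);
    return std::make_tuple(a, b, c);
}
```

Unlike sorting a `vector<int>`, this touches no heap memory at all.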


Return value

You are returning a vector by value. Before C++11, a compiler that did not apply Return Value Optimization (RVO) could copy the entire vector contents, and with a vector of vectors that is really bad. Most modern compilers perform RVO, and since C++11 a local vector returned by value is moved rather than copied whenever the copy is not elided, so its element buffer is handed over instead of duplicated. If you still don't want to be at the mercy of a possibly dodgy compiler, pass a reference to the vector where you want to store the results as an argument to the function.
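A quick way to observe the C++11 behaviour is to check whether the returned vector reuses the callee's element buffer; if it does, nothing was copied. This is a sketch with hypothetical names (`makeResults`, `returnedWithoutCopy`) and assumes a C++11-or-later compiler; under C++03 without RVO the check could fail.

```cpp
#include <vector>

// Hypothetical demo: fill a large local vector, record where its element
// buffer lives, and return the vector by value.
std::vector<int> makeResults(const int** bufferOut) {
    std::vector<int> local(100000, 7);
    *bufferOut = local.data();  // heap buffer address before returning
    return local;               // copy elided (NRVO) or, since C++11, moved
}

// True if the caller ends up owning the very same buffer, i.e. the
// 100000 elements were handed over rather than copied.
bool returnedWithoutCopy() {
    const int* before = nullptr;
    std::vector<int> result = makeResults(&before);
    return result.data() == before && result.size() == 100000;
}
```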

Time complexity

So you do this:

```cpp
for(int i=0; i<nums.size();i++){
    for(int j=i+1; j<nums.size();j++){
        pairs.push_back(make_pair(nums.at(i),nums.at(j)));
```


which means that `pairs.size()` is $n(n-1)/2$, i.e. $O(n^2)$. Then you proceed with sorting:

```cpp
sort(pairs.begin(),pairs.end());
```


Because you are sorting $m = O(n^2)$ entries, the time complexity becomes $O(m \log m) = O(n^2 \log n)$.
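Spelling out that substitution with the exact pair count (for $n \ge 2$, so that $m \ge 1$):

$$
m = \binom{n}{2} = \frac{n(n-1)}{2} \le n^2
\quad\Longrightarrow\quad
m \log m \le n^2 \log n^2 = 2\,n^2 \log n = O(n^2 \log n)
$$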

But then you do this...

```cpp
for(int i=0; i<nums.size();i++){
    for(auto x:pairs){
        if(x.first!=nums.at(i) && x.second!=nums.at(i)){
```


Okay, so this becomes $O(n \cdot n^2) = O(n^3)$: for each of the $n$ numbers you scan all $O(n^2)$ pairs. Note that $O(n^2)$ is easily achievable by other algorithms.

This is why you get TLE (Time Limit Exceeded).

Better algorithm

We can easily reach $O(n^2)$ time complexity.

We need to find all $a, b, c$ such that $a + b + c = 0$. Note that this is equivalent to $c = -(a + b)$. Hence, if we can check whether $c$ exists in the input in $O(1)$ time, we just need to try each pair $a$, $b$ and see whether a matching $c$ exists. Since there are $O(n^2)$ pairs, that gives $O(n^2 \cdot 1) = O(n^2)$ time.

A hash set provides the necessary expected $O(1)$ check for whether $c$ is present in the input.

(I'm not going to handle the three zeros; you can figure that out.)

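As a C++ fragment (assuming `input` is the `std::vector<int>` of numbers and `output` is a `std::vector<std::tuple<int, int, int>>`):

```cpp
// Needs <unordered_set>; builds the set of all input values.
std::unordered_set<int> hashset(input.begin(), input.end());

for(std::size_t i = 0; i < input.size(); ++i){
    for(std::size_t j = i+1; j < input.size(); ++j){
        int c = -(input[i] + input[j]);   // the value that completes the triplet
        if(hashset.count(c)){             // expected O(1) lookup
            output.emplace_back(input[i], input[j], c);
        }
    }
}
```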
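The fragment above still reports duplicate triplets, can match `c` against the same element as `a` or `b` (for input `{1, -2}` it finds `c = 1`), and leaves out the three-zeros case. One way to close those gaps while keeping expected $O(n^2)$ time is to store a count per value instead of a plain set, and to emit each triplet only in the order $a \le b \le c$ over the distinct values. This is a sketch, not the answer's code; `threeSumHash` is a hypothetical name, and it assumes `a + b` does not overflow `int`.

```cpp
#include <algorithm>
#include <cstddef>
#include <unordered_map>
#include <vector>

// Hash-count variant: each unique triplet is produced once as a <= b <= c,
// so no dedup pass is needed, and the counts stop a value from being used
// more often than it occurs (which also covers the three-zeros case).
std::vector<std::vector<int>> threeSumHash(const std::vector<int>& nums) {
    std::unordered_map<int, int> count;          // value -> number of occurrences
    for (int x : nums) ++count[x];

    std::vector<int> vals;                       // distinct values, ascending
    vals.reserve(count.size());
    for (const auto& kv : count) vals.push_back(kv.first);
    std::sort(vals.begin(), vals.end());

    std::vector<std::vector<int>> out;
    for (std::size_t i = 0; i < vals.size(); ++i) {
        for (std::size_t j = i; j < vals.size(); ++j) {
            const int a = vals[i], b = vals[j];  // a <= b
            const int c = -(a + b);              // assumes a + b does not overflow
            if (c < b) break;                    // c only shrinks as b grows
            auto it = count.find(c);             // expected O(1) hash lookup
            if (it == count.end()) continue;
            // Copies of c (and of b, when b != c) that this triplet uses.
            const int needC = 1 + (b == c) + (a == c);
            const int needB = 1 + (a == b);
            if (it->second < needC) continue;
            if (b != c && count.at(b) < needB) continue;
            out.push_back({a, b, c});
        }
    }
    return out;
}
```

The double loop runs over distinct values only, and the early `break` skips pairs whose completing value would be smaller than `b`, which is what guarantees each triplet appears exactly once.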


Context

StackExchange Code Review Q#135918, answer score: 5
