
C++ class for disjoint-set/union-find on integers

Submitted by: @import:stackexchange-codereview

Problem

I have implemented the disjoint-set data structure. Its purpose is to group integers together. For example, suppose I want to find the groups of integers in which adjacent neighbors have the same value:

vector<vector<int>> test = {
{1,0,0,1},
{0,0,0,1},
{0,3,3,1}};


There are 4 groups of integers, namely:

  • Group 1: 1 at (0,0)
  • Group 2: 0's at (0,1), (0,2), (1,0), (1,1), (1,2), (2,0)
  • Group 3: 1's at (0,3), (1,3), (2,3)
  • Group 4: 3's at (2,1), (2,2)

Here is the code in C++:

//This union-find class implements two optimization ideas:
//1)path compression
//2)union by rank

class DisjointSet {
    public:
    DisjointSet(int num_nodes);
    int find(int element);
    void do_union(int elment_a, int element_b);
    void dup(vector<int> &v);
    private:
    vector<int> parent_, rank_;
};

DisjointSet::DisjointSet(int num_nodes):
    parent_(num_nodes),rank_(num_nodes, 0)
{
    for(int i=0; i<num_nodes; ++i) parent_[i] = i;
}

int DisjointSet::find(int element)
{
    if(parent_[element] == element) return element;
    //1)path compression
    return parent_[element] = find(parent_[element]);
}

void DisjointSet::do_union(int a, int b)
{
    if(parent_[a] == parent_[b]) return;
    int fa = find(a), fb = find(b);
    //2)union by rank
    if(rank_[fa] < rank_[fb]) {
    parent_[fa] = fb;
    } else {
    parent_[fb] = fa;
    if(rank_[fa] == rank_[fb]) rank_[fa]++;
    }
}
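
To tie the class back to the grid example, here is a rough sketch of how it might be applied. It assumes that "adjacent" means horizontal or vertical neighbors and that the grid is rectangular. The helper `count_groups` is a hypothetical name, not part of the original code, and the class is a compact stand-in repeated only so that the sketch compiles on its own.

```cpp
#include <vector>

// Compact stand-in for the DisjointSet class above, repeated here only so
// that this sketch is self-contained.
class DisjointSet {
public:
    DisjointSet(int num_nodes) : parent_(num_nodes), rank_(num_nodes, 0)
    {
        for (int i = 0; i < num_nodes; ++i) parent_[i] = i;  // every node is its own root
    }

    int find(int element)
    {
        if (parent_[element] != element) {
            parent_[element] = find(parent_[element]);  // path compression
        }
        return parent_[element];
    }

    void do_union(int a, int b)
    {
        int fa = find(a);
        int fb = find(b);
        if (fa == fb) return;  // already in the same set
        if (rank_[fa] < rank_[fb]) {
            parent_[fa] = fb;
        } else {
            parent_[fb] = fa;
            if (rank_[fa] == rank_[fb]) rank_[fa]++;
        }
    }

private:
    std::vector<int> parent_;
    std::vector<int> rank_;
};

// Hypothetical helper: counts groups of equal values that are connected
// horizontally or vertically. Assumes every row has the same length.
int count_groups(const std::vector<std::vector<int>>& grid)
{
    if (grid.empty()) return 0;
    const int rows = static_cast<int>(grid.size());
    const int cols = static_cast<int>(grid[0].size());
    DisjointSet sets(rows * cols);

    for (int r = 0; r < rows; ++r) {
        for (int c = 0; c < cols; ++c) {
            const int id = r * cols + c;  // flatten (r, c) into one index
            // Looking only right and down visits each neighboring pair once.
            if (c + 1 < cols && grid[r][c] == grid[r][c + 1]) sets.do_union(id, id + 1);
            if (r + 1 < rows && grid[r][c] == grid[r + 1][c]) sets.do_union(id, id + cols);
        }
    }

    // Each set has exactly one root, so counting roots counts groups.
    int groups = 0;
    for (int id = 0; id < rows * cols; ++id) {
        if (sets.find(id) == id) ++groups;
    }
    return groups;
}
```

Under these assumptions, calling `count_groups(test)` on the 3x4 grid above should report the four groups listed.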

Solution

Seems pretty straightforward. By the way, "disjoint-set" is a less common name for this data structure than "union-find", at least in my experience.

You have several typos in your code; you should always fix those. Sure, the computer doesn't care, but we don't write C++ code for the computer; we write it for the human reader. elment_a should be element_a; you should either implement dup or remove its declaration; and you should use proper indentation.

class DisjointSet {
public:
    DisjointSet(int num_nodes);
    int find(int element);
    void do_union(int element_a, int element_b);
private:
    std::vector<int> parent_;
    std::vector<int> rank_;
};


Notice also that I added the missing std::s to your vectors (and by the way you should #include <vector>), and split each member variable declaration onto its own line so that the class layout is easier to see at a glance.

Who says using rank_ is an optimization? It feels to me like a pessimization, because you're trading O(N) extra space (and more convoluted code) for something like O(lg N) faster lookups. I do see that e.g. Boost.DisjointSet implements "union by rank" as well, but I'd be interested in seeing some benchmark numbers on your particular use-case.
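
For comparison, here is a sketch of what the rank-free alternative might look like, keeping path compression and dropping rank_ entirely. `SimpleDisjointSet` is a hypothetical name used only for illustration. The commonly cited amortized bound for path compression alone is logarithmic per operation, versus near-constant (inverse Ackermann) with both heuristics, so whether the simpler version is fast enough for a given workload is something only a benchmark can settle.

```cpp
#include <vector>

// Hypothetical rank-free variant: path compression only, no rank_ vector.
class SimpleDisjointSet {
public:
    SimpleDisjointSet(int num_nodes) : parent_(num_nodes)
    {
        for (int i = 0; i < num_nodes; ++i) parent_[i] = i;  // every node is its own root
    }

    int find(int element)
    {
        if (parent_[element] != element) {
            parent_[element] = find(parent_[element]);  // path compression
        }
        return parent_[element];
    }

    void do_union(int a, int b)
    {
        int fa = find(a);
        int fb = find(b);
        // Without ranks, the root of b's set is always attached under the
        // root of a's set.
        if (fa != fb) parent_[fb] = fa;
    }

private:
    std::vector<int> parent_;
};
```

The interface matches the original class, so either version could be dropped into the same calling code when measuring.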

return parent_[element] = find(parent_[element]);


Modern C++ (like modern any-other-language) strongly prefers not to mix side effects into the middle of expressions, except in idiomatic cases such as *p++ = .... Prefer to put the assignment in its own statement.

parent_[element] = find(parent_[element]);
return parent_[element];


int fa = find(a), fb = find(b);


Same deal as above: prefer to put each variable declaration on its own line for readability.

int fa = find(a);
int fb = find(b);


And then use proper indentation in what follows:

if (rank_[fa] < rank_[fb]) {
    parent_[fa] = fb;
} else {
    parent_[fb] = fa;
    if (rank_[fa] == rank_[fb]) rank_[fa]++;
}
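
Putting these suggestions together, the revised class might look roughly like the sketch below. One detail goes beyond the points above and should be treated as an assumption: do_union compares the two roots rather than the two direct parents before returning early. In the original, two elements that share a root but have different parents would fall through and bump rank_ even though nothing was merged.

```cpp
#include <vector>

class DisjointSet {
public:
    DisjointSet(int num_nodes);
    int find(int element);
    void do_union(int element_a, int element_b);
private:
    std::vector<int> parent_;
    std::vector<int> rank_;
};

DisjointSet::DisjointSet(int num_nodes)
    : parent_(num_nodes), rank_(num_nodes, 0)
{
    for (int i = 0; i < num_nodes; ++i) {
        parent_[i] = i;  // every node starts as its own root
    }
}

int DisjointSet::find(int element)
{
    if (parent_[element] == element) {
        return element;
    }
    // Path compression, with the assignment in its own statement.
    parent_[element] = find(parent_[element]);
    return parent_[element];
}

void DisjointSet::do_union(int element_a, int element_b)
{
    int fa = find(element_a);
    int fb = find(element_b);
    // Assumption (not in the original): compare roots, not direct parents,
    // so rank_ is never incremented when nothing is merged.
    if (fa == fb) {
        return;
    }
    // Union by rank: attach the shallower tree under the deeper one.
    if (rank_[fa] < rank_[fb]) {
        parent_[fa] = fb;
    } else {
        parent_[fb] = fa;
        if (rank_[fa] == rank_[fb]) rank_[fa]++;
    }
}
```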


Basically, though, looks good. Looks like you picked the right level of abstraction for this particular problem: not too scattered (keeping related things in a nice neat class), but not too libraryish (you hard-code int, for example, instead of going all in with templates).

Code Snippets

class DisjointSet {
public:
    DisjointSet(int num_nodes);
    int find(int element);
    void do_union(int element_a, int element_b);
private:
    std::vector<int> parent_;
    std::vector<int> rank_;
};

return parent_[element] = find(parent_[element]);

parent_[element] = find(parent_[element]);
return parent_[element];

int fa = find(a), fb = find(b);

int fa = find(a);
int fb = find(b);

Context

StackExchange Code Review Q#120824, answer score: 5
