HiveBrain v1.2.0

Finding the k-th element in a BST

Submitted by: @import:stackexchange-codereview

Problem

Given a binary search tree, find its k-th element in in-order traversal (that is, its k-th smallest element).

Here's my node structure:

```
#include <cstddef>  // NULL

struct node
{
    int elem;
    node *left, *right;

    static node* create(int elem)
    {
        node *newnode = new node;
        newnode->elem = elem;
        newnode->left = newnode->right = NULL;
        return newnode;
    }

    // Forget freeing up memory
};
```


Here's my BST:

```
class tree
{
private:
    node *root;

public:
    tree(int elem) { root = node::create(elem); }

    bool insert(int elem)
    {
        node *travnode = root;

        while (1)
        {
            if (elem < travnode->elem)
            {
                if (travnode->left == NULL)
                {
                    travnode->left = node::create(elem);
                    return true;
                }
                else
                    travnode = travnode->left;
            } // if (elem < travnode->elem)
            else if (elem > travnode->elem)
            {
                if (travnode->right == NULL)
                {
                    travnode->right = node::create(elem);
                    return true;
                }
                else
                    travnode = travnode->right;
            }
            else
                return false;   // duplicate: nothing inserted
        }
    }

    /* findKthInorder
       @param mynode [in]     -- the root of the tree whose k-th smallest element is to be found
       @param k      [in]     -- value k (1-based)
       @param count  [in,out] -- a counter to keep track of which node we're in
       @param result [out]    -- set to mynode->elem once we're in the k-th node
    */
    void findKthInorder(const node *mynode, const int k, int &count, int &result) const
    {
        if (mynode != NULL)
        {
            findKthInorder(mynode->left, k, count, result);
            if (!--count)
            {
                result = mynode->elem;
                return;
            } // if (!--count)

            findKthInorder(mynode->right, k, count, result);
        } // if (mynode != NULL)
    } // findKthInorder

    /* findKthInorder
       abstracts away the previous function and is exposed to the outside world
    */
    int findKthInorder(const int k) const
    {
        int count = k, result = 0;
        findKthInorder(root, k, count, result);
        return result;
    }
};
```

Solution

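If you add a field (say, total) to each node holding the number of nodes in its subtree, you can find the k-th element in time proportional to the height of the tree, which is logarithmic if the tree is balanced, by writing a method like this (untested):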

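```
// Member of node; total is the number of nodes in the subtree rooted here.
node *kth(int k)
{
    assert(k >= 0 && k < total);

    if (left != NULL) {
        if (k < left->total)
            return left->kth(k);   // the k-th element is in the left subtree
        k -= left->total;          // skip the whole left subtree
    }

    if (k == 0)
        return this;               // this node is the k-th element

    assert(right != NULL);
    return right->kth(k - 1);      // skip this node, continue on the right
}
```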
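The method above only works if total really equals the size of each node's subtree, so every insertion has to increment total along its search path. A minimal sketch of that bookkeeping follows; the Node constructor and the free insert function are assumptions for illustration, not part of the question or the answer. Duplicates are rejected, as in the question's tree::insert, and total is incremented only after an insertion has actually happened.

```cpp
#include <cassert>

// Sketch of a size-augmented BST node. total (assumed, as in the answer) is
// the number of nodes in the subtree rooted here, including the node itself.
struct Node {
    int elem;
    int total;
    Node *left, *right;

    explicit Node(int e) : elem(e), total(1), left(nullptr), right(nullptr) {}

    // Zero-based k-th smallest element of this subtree.
    // Each call descends one level, so the cost is O(height).
    const Node *kth(int k) const {
        assert(k >= 0 && k < total);
        if (left != nullptr) {
            if (k < left->total)
                return left->kth(k);   // answer lies in the left subtree
            k -= left->total;          // skip every node in the left subtree
        }
        if (k == 0)
            return this;               // this node is the k-th
        assert(right != nullptr);
        return right->kth(k - 1);      // skip this node as well
    }
};

// Hypothetical insert helper (not from the answer). Returns false for a
// duplicate and increments total on the search path only when a node was
// actually added. Memory is never freed, as in the question.
bool insert(Node *&node, int elem) {
    if (node == nullptr) {
        node = new Node(elem);
        return true;
    }
    bool inserted;
    if (elem < node->elem)
        inserted = insert(node->left, elem);
    else if (elem > node->elem)
        inserted = insert(node->right, elem);
    else
        return false;
    if (inserted)
        ++node->total;
    return inserted;
}
```

Inserting 50, 30, 70, 20, 40, 60, 80 and then reading root->kth(k)->elem for k = 0 through 6 yields 20, 30, 40, 50, 60, 70, 80 in that order.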


Otherwise, the recursive algorithm you used for findKthInorder is the most elegant way I can think of to do this. I would clean it up a bit, though:

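```
// Members of struct Node; needs <cassert>, <cstdlib> and <iostream>.

// Zero-based in-order search: k is decremented once per visited node.
static const Node *kth_(const Node *node, int &k)
{
    if (node == NULL)
        return NULL;

    const Node *tmp = kth_(node->left, k);
    if (tmp != NULL)
        return tmp;   // already found in the left subtree

    if (k-- == 0)
        return node;

    return kth_(node->right, k);
}

int kth(int k) const
{
    assert(k >= 0);

    const Node *node = kth_(this, k);
    if (node == NULL) {
        std::cerr << "kth: k is too large\n";
        exit(1);
    }

    return node->elem;
}
```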
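Because kth_ is called with this, the public kth is written as a member of the node type. In the question's design the public entry point lives on the tree class, so the search would start at root instead. A sketch of that adaptation follows, using a hypothetical minimal Tree class: unlike the question's tree it starts empty, and its pointer-to-link insert is an assumption rather than the question's code.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <iostream>

struct Node
{
    int elem;
    Node *left, *right;
};

// The answer's helper: zero-based in-order search. k is decremented once per
// visited node; the result is NULL when the tree has k or fewer nodes.
static const Node *kth_(const Node *node, int &k)
{
    if (node == NULL)
        return NULL;

    const Node *tmp = kth_(node->left, k);
    if (tmp != NULL)
        return tmp;   // already found in the left subtree

    if (k-- == 0)
        return node;

    return kth_(node->right, k);
}

// Hypothetical minimal tree, standing in for the question's tree class.
class Tree
{
    Node *root = NULL;

public:
    // Same rules as the question's insert: duplicates are rejected.
    // Nodes are never freed, as in the question.
    bool insert(int elem)
    {
        Node **link = &root;   // the child pointer the new node may go into
        while (*link != NULL) {
            if (elem < (*link)->elem)
                link = &(*link)->left;
            else if (elem > (*link)->elem)
                link = &(*link)->right;
            else
                return false;
        }
        Node *newnode = new Node;
        newnode->elem = elem;
        newnode->left = newnode->right = NULL;
        *link = newnode;
        return true;
    }

    // The answer's kth(), except that the search starts at root, not this.
    int kth(int k) const
    {
        assert(k >= 0);

        const Node *node = kth_(root, k);
        if (node == NULL) {
            std::cerr << "kth: k is too large\n";
            std::exit(1);
        }

        return node->elem;
    }
};
```

With 50, 30, 70, 20, 40, 60, 80 inserted, kth(0) is 20, kth(3) is 50 and kth(6) is 80.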


Returning a node pointer instead of an element from the helper function has three advantages:

  • We can use NULL to indicate failure.
  • We get to drop an argument from the helper function.
  • In the future, it will be easier to write a function to update the k-th element.
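To illustrate the first point, here is a sketch of a caller that turns the NULL from kth_ into a return value instead of printing and calling exit, so the caller can recover. The tryKth name and its bool-plus-out-parameter signature are assumptions, not from the answer; kth_ is the answer's helper, written here as a free function.

```cpp
#include <cassert>
#include <cstddef>

struct Node
{
    int elem;
    Node *left, *right;
};

// The answer's helper: zero-based in-order search returning NULL when
// the tree has k or fewer nodes.
static const Node *kth_(const Node *node, int &k)
{
    if (node == NULL)
        return NULL;

    const Node *tmp = kth_(node->left, k);
    if (tmp != NULL)
        return tmp;

    if (k-- == 0)
        return node;

    return kth_(node->right, k);
}

// Hypothetical recoverable variant: instead of printing and exiting, report
// failure through the return value and leave out untouched on failure.
bool tryKth(const Node *root, int k, int &out)
{
    assert(k >= 0);
    const Node *node = kth_(root, k);   // modifies only the local copy of k
    if (node == NULL)
        return false;
    out = node->elem;
    return true;
}
```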



In your findKthInorder helper function, the k argument is never actually used and can be dropped as well.
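For illustration, here is the question's helper with k removed, written as free functions over the question's node layout. The two-argument wrapper is a hypothetical stand-in for the public tree::findKthInorder; it keeps the question's one-based k and its convention of returning 0 when k is out of range, and adds an assert on k that the question does not have.

```cpp
#include <cassert>
#include <cstddef>

struct node
{
    int elem;
    node *left, *right;
};

// The question's helper without the unused k parameter. count starts at the
// one-based k and is decremented at each node visited in order; result is
// written when count reaches zero.
void findKthInorder(const node *mynode, int &count, int &result)
{
    if (mynode != NULL)
    {
        findKthInorder(mynode->left, count, result);
        if (!--count)
        {
            result = mynode->elem;
            return;
        }
        findKthInorder(mynode->right, count, result);
    }
}

// Hypothetical free-function stand-in for tree::findKthInorder(k).
// Returns 0 when k exceeds the number of nodes, as the question's version does.
int findKthInorder(const node *root, int k)
{
    assert(k >= 1);   // added precondition; the question does not check k
    int count = k, result = 0;
    findKthInorder(root, count, result);
    return result;
}
```

One consequence of the void-returning design shows up here: after the k-th node is found, the enclosing calls still walk their right subtrees (count just goes negative, so result is not overwritten). The pointer-returning kth_ stops as soon as tmp is non-NULL.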

A couple of cleanups on the side:

  • I renamed the struct node to Node so that a variable can simply be called node instead of mynode. I suppose this is just a matter of taste, seeing how the STL uses lowercase type names.
  • I switched to zero-based indexing. Again, this is a matter of taste, but zero-based indexing is far more common in C++ and is easier to work with in many cases.


Context

StackExchange Code Review Q#1750, answer score: 5
