HiveBrain v1.2.0
pattern · python · Minor

Flatten double-linked list

Submitted by: @import:stackexchange-codereview
listdoublelinkedflatten

Problem

Consider this interview question:


Given a Node with a reference to a child, its next, its previous, and a variable for its value (half tree, half doubly linked list structure), you have to find a way to flatten the structure.

1 = 2 = 3 = 5
    |       |
    6 = 7   8
        |
        9

(a reference down is child, and a reference across is next)

The diagram above becomes 1 - 2 - 6 - 7 - 9 - 3 - 5 - 8, but:

  • 2 still has a child reference to 6
  • 7 still has a child reference to 9
  • 5 still has a child reference to 8

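The required order is a pre-order walk: each node comes first, then its whole child branch, then its next node. A minimal read-only sketch of that order follows. It uses an explicit stack instead of recursion (an alternative chosen only for illustration), and `flattened_order` is a hypothetical name, not part of the original code.

```python
class Node(object):
    def __init__(self, value, next_node=None, prev=None, child=None):
        self.value = value
        self.next = next_node
        self.prev = prev
        self.child = child


def flattened_order(head):
    """Yield node values in flattened order without changing any links."""
    stack = [head] if head is not None else []
    while stack:
        node = stack.pop()
        yield node.value
        # Push next before child: the stack is LIFO, so the child branch
        # is fully visited before the walk continues along next.
        if node.next is not None:
            stack.append(node.next)
        if node.child is not None:
            stack.append(node.child)


# The example from the diagram; prev links are omitted because this walk never reads them.
nine, eight = Node(9), Node(8)
seven = Node(7, child=nine)
six = Node(6, next_node=seven)
five = Node(5, child=eight)
three = Node(3, next_node=five)
two = Node(2, next_node=three, child=six)
one = Node(1, next_node=two)

print(" - ".join(str(v) for v in flattened_order(one)))  # 1 - 2 - 6 - 7 - 9 - 3 - 5 - 8
```

Because it only reads links, this is handy as a reference for checking any in-place flatten.
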
Here is my code:

class Node(object):
    def __init__(self, value, next_node=None, prev=None, child=None):
        self.value = value
        self.next = next_node
        self.prev = prev
        self.child = child

    def __repr__(self):
        return str(self.value)

one = Node(1)
two = Node(2)
three = Node(3)
five = Node(5)
six = Node(6)
seven = Node(7)
eight = Node(8)
nine = Node(9)
one.next = two
two.prev = one
two.next = three
two.child = six
three.prev = two
three.next = five
five.prev = three
five.child = eight
six.next = seven
seven.prev = six
seven.child = nine

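def flatten(head):
    if not head:
        return head

    # Flatten the child branch and the rest of this level first.
    c = flatten(head.child)
    n = flatten(head.next)
    if c:
        # Splice the flattened child list in right after head;
        # head.child is left in place, as the problem requires.
        head.next = head.child
        head.child.prev = head
    if c and n:
        # Walk to the end of the child list, then attach the rest of the level.
        tail = c
        while tail.next:
            tail = tail.next
        tail.next = n
        n.prev = tail
    return head

n = flatten(one)
while n:
    print("%d -" % n.value, end=" ")
    n = n.next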


I don't like the while loop that walks to the end of the flattened child list. I could store the tail and pass it around instead. Is there a better solution?

Solution

Your current implementation is quite nice.

The only way you are going to avoid the while loop that finds the right-most node is to have each branch keep track of its own tail Node. That is a workable approach. However, once you do this, a Node starts becoming more than just a Node: it now holds information about the entire structure instead of only what is directly around it.
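
The tail-tracking idea can also be kept out of the Node entirely by returning the tail from each recursive call, which is one way to "pass it around" as the question suggests. This is a minimal sketch; `_flatten_with_tail` and `link` are hypothetical names, not from the original code. Each call returns both ends of the branch it flattened, so the caller can splice without walking to the end.

```python
class Node(object):
    def __init__(self, value, next_node=None, prev=None, child=None):
        self.value = value
        self.next = next_node
        self.prev = prev
        self.child = child


def _flatten_with_tail(head):
    """Flatten the structure starting at head in place; return (head, tail).

    child references are left untouched, as the problem requires.
    """
    if head is None:
        return None, None
    child_head, child_tail = _flatten_with_tail(head.child)
    next_head, next_tail = _flatten_with_tail(head.next)
    tail = head
    if child_head is not None:
        # Splice the flattened child branch in right after head.
        head.next = child_head
        child_head.prev = head
        tail = child_tail
    if next_head is not None:
        # Attach the rest of this level after the current tail: no while loop.
        tail.next = next_head
        next_head.prev = tail
        tail = next_tail
    return head, tail


def flatten(head):
    return _flatten_with_tail(head)[0]


def link(a, b):
    # Hypothetical helper: make a and b next/prev neighbours.
    a.next, b.prev = b, a


one, two, three, five, six, seven, eight, nine = [
    Node(v) for v in (1, 2, 3, 5, 6, 7, 8, 9)
]
link(one, two)
link(two, three)
link(three, five)
link(six, seven)
two.child, five.child, seven.child = six, eight, nine

values = []
node = flatten(one)
while node:
    values.append(node.value)
    node = node.next
print(" - ".join(map(str, values)))  # 1 - 2 - 6 - 7 - 9 - 3 - 5 - 8
```

Like the original, this recurses along next as well as child, so a very long list can hit Python's recursion limit.
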

So instead of modifying the Node class, you could create a Tree class.

This change makes sense:

  • Right now you create your 'tree' manually, which is not very portable. A Tree class can handle its own creation.
  • You need to store information about the entire branch structure somewhere, and inside a Node doesn't feel right.
  • You don't really flatten a Node; you flatten a tree of nodes. So it again makes sense to move the flatten function into a Tree class (or at least have it take a Tree parameter).

Here is a skeleton version of my Tree class. I have stubbed out every method except flatten, since that is the function in question:

class Tree(object):
    def __init__(self, root=None):
        self.root = root
        self.tail = root        
        self.length = 1 if root else 0

    def __iter__(self):
        ''' Allows iteration through the tree using `for node in tree`. '''
        pass

    def __getitem__(self, index):
        ''' Fetches the `Node` at the passed index. '''
        pass

    def grow(self, nodes):
        ''' Appends node/nodes/Tree onto itself. Be careful when adding `Node`
            we need to append a *copy* of a `Node`, not the reference. '''
        pass

    def graft(self, tree, node_pos=0):
        ''' Grafts in another tree to this one. By default the new tree
            is grafted to the root of the tree. '''
        pass

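    def flatten(self):
        ''' Returns the flattened version of itself. '''
        flat = Tree()
        # `node` rather than `next`, to avoid shadowing the built-in next().
        for node in self:
            flat.grow(node)
            # In this design `child` holds a Tree (attached via graft).
            if node.child:
                flat.grow(node.child.flatten())
        return flat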

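The stubs can be filled in several ways. As one possibility, here is a minimal self-contained sketch. It assumes that child holds a Tree, that grow accepts a Node, a list of Nodes, or a Tree, and that the "copy" in grow's docstring means a shallow copy whose next/prev links are reset. `_append` is a hypothetical helper that the skeleton does not have.

```python
import copy


class Node(object):
    def __init__(self, value, next_node=None, prev=None, child=None):
        self.value = value
        self.next = next_node
        self.prev = prev
        self.child = child  # assumed here to hold a Tree, not a Node

    def __repr__(self):
        return str(self.value)


class Tree(object):
    def __init__(self, root=None):
        self.root = root
        self.tail = root
        self.length = 1 if root else 0

    def __iter__(self):
        # Walk the top-level branch only; child trees are not entered here.
        node = self.root
        while node:
            yield node
            node = node.next

    def __getitem__(self, index):
        for i, node in enumerate(self):
            if i == index:
                return node
        raise IndexError(index)

    def _append(self, node):
        # Shallow copy: keeps value and child, but gets fresh next/prev links.
        node = copy.copy(node)
        node.next = node.prev = None
        if self.tail is None:
            self.root = node
        else:
            self.tail.next = node
            node.prev = self.tail
        self.tail = node
        self.length += 1

    def grow(self, nodes):
        # Accept a single Node, a list of Nodes, or another Tree.
        if isinstance(nodes, Node):
            nodes = [nodes]
        for node in list(nodes):
            self._append(node)

    def graft(self, tree, node_pos=0):
        # Attach `tree` as the child of the node at index node_pos.
        self[node_pos].child = tree

    def flatten(self):
        flat = Tree()
        for node in self:
            flat.grow(node)
            if node.child:
                flat.grow(node.child.flatten())
        return flat


one, two, three, five, six, seven, eight, nine = [Node(n) for n in (1, 2, 3, 5, 6, 7, 8, 9)]
tree = Tree(one)
tree.grow([two, three, five])
tree.graft(Tree(eight), 3)   # index 3 is the copy of five
branch = Tree(six)
branch.grow(seven)
branch.graft(Tree(nine), 1)  # index 1 is the copy of seven
tree.graft(branch, 1)        # index 1 is the copy of two

print(" - ".join(str(node.value) for node in tree.flatten()))  # 1 - 2 - 6 - 7 - 9 - 3 - 5 - 8
```

Because grow copies, flatten builds a new Tree and leaves the original structure untouched.
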

In this version of flatten, we keep the recursive structure: calling flatten on a child Tree invokes the same Tree.flatten function, just bound to a different instance. That is still true recursion. Binding a method to another object does not make it a different function, and each nested branch adds one more frame to the call stack.

In any case, the flatten function is really quite simple:

  • Create the new Tree we are going to fill and return.
  • Loop through the Nodes of the current tree:
      • Append the current Node.
      • If that Node has a child Tree, append the flattened version of that Tree.
  • Return the final flattened Tree.

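The same three steps can be seen without any Node or Tree machinery. In this toy sketch (an assumption made only for illustration), a branch is a plain list of (value, child_branch) pairs, with None meaning "no child":

```python
def flatten(branch):
    flat = []                            # 1. create the new, empty result
    for value, child in branch:          # 2. loop over the current branch
        flat.append(value)               #    append the current node
        if child:                        #    append the flattened child branch
            flat.extend(flatten(child))
    return flat                          # 3. return the flattened result


tree = [
    (1, None),
    (2, [(6, None), (7, [(9, None)])]),
    (3, None),
    (5, [(8, None)]),
]
print(flatten(tree))  # [1, 2, 6, 7, 9, 3, 5, 8]
```
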
Here is my Tree class plus your Node class in action:

one, two, three, four, five, six, seven, eight, nine = [Node(num) for num in range(1, 10)]

tree = Tree(one)
tree.grow([two, three, five])

# Graft the new `Tree` onto the node at index 3 (`five`).
tree.graft(Tree(eight), 3)

branch = Tree(six)
branch.grow(seven)

# Graft the new `Tree` onto the node at index 1 (`seven`).
branch.graft(Tree(nine), 1)

# Graft the whole branch onto the node at index 1 (`two`).
tree.graft(branch, 1)

def flatten(tree):
    ''' A non-class version of the `flatten` function. '''
    flat = Tree()
    for node in tree:
        flat.grow(node)
        if node.child:
            flat.grow(flatten(node.child))
    return flat

print(' - '.join([str(node.value) for node in flatten(tree)]))
print(' - '.join([str(node.value) for node in tree.flatten()]))

Context

StackExchange Code Review Q#54796, answer score: 2
