Breadth-first tree cutting

Submitted by: @import:stackexchange-codereview

Problem

I have a working solution for the 'Cut the tree' problem, but it is not fast enough.

Problem Statement:


Atul is into graph theory, and he is learning about trees nowadays. He observed that the removal of an edge from a given tree T will result in the formation of two separate trees, T1 and T2.


Each vertex of the tree T is assigned a positive integer. Your task is to remove an edge, such that the Tree_diff of the resultant trees is minimized. Tree_diff is defined as the following:

F(T) = Sum of numbers written on each vertex of a tree T
 Tree_diff(T) = abs(F(T1) - F(T2))



Input Format:


The first line will contain an integer \$N\$, i.e. the number of vertices in the tree.


The next line will contain \$N\$ integers separated by a single space, i.e. the values assigned to each of the vertices.


The next \$N−1\$ lines contain a pair of integers each, separated by a single space, that denote the edges of the tree.


In the above input, the vertices are numbered from \$1\$ to \$N\$.
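As a rough sketch of how this input format could be read into a graph, the hypothetical helper `parse_tree` below (not part of the original post) takes the whole input as text and builds an adjacency dict plus a weight list. It assumes well-formed input and shifts the vertices to 0-based indices.

```python
def parse_tree(text):
    """Parse the problem input into (graph, weights) with 0-based vertices."""
    tokens = text.split()
    n = int(tokens[0])
    weights = [int(tok) for tok in tokens[1:n + 1]]
    graph = {vertex: [] for vertex in range(n)}
    # The remaining 2 * (n - 1) tokens are the edge endpoints, in pairs.
    edge_tokens = tokens[n + 1:n + 1 + 2 * (n - 1)]
    for i in range(0, len(edge_tokens), 2):
        a = int(edge_tokens[i]) - 1      # input is 1-based
        b = int(edge_tokens[i + 1]) - 1
        graph[a].append(b)
        graph[b].append(a)
    return graph, weights


SAMPLE = """6
100 200 100 500 100 600
1 2
2 3
2 5
4 5
5 6
"""

graph, weights = parse_tree(SAMPLE)
print(graph)
print(weights)
```

For large inputs, reading everything at once (for example from `sys.stdin.read()`) and splitting once tends to be cheaper than reading line by line.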

Output Format:


A single line containing the minimum value of Tree_diff.

Constraints:


\$3≤N≤10^5\$
\$1≤\$ number written on each vertex \$≤1001\$

Sample Input:

6  
100 200 100 500 100 600  
1 2  
2 3  
2 5  
4 5  
5 6



Sample Output:

400



Explanation:


Originally, we can represent tree as

1(100)
          \
           2(200)
          / \
    (100)5   3(100)
        / \
  (500)4   6(600)

Cutting the edge at 1 2 would result in Tree_diff = 1500-100 = 1400 
Cutting the edge at 2 3 would result in Tree_diff = 1500-100 = 1400 
Cutting the edge at 2 5 would result in Tree_diff = 1200-400 = 800 
Cutting the edge at 4 5 would result in Tree_diff = 1100-500 = 600 
Cutting the edge at 5 6 would result in Tree_diff = 1000-600 = 400




Hence, the answer is 400.

My solution is in Python and uses an iterative BFS. It passes only the first five tests and times out on the rest.

Solution

For every edge in your graph, you run a BFS over the tree on one side of the edge to sum all of its vertex weights, so you are repeating a lot of work over and over again. Imagine that every vertex in your tree was connected to at most 2 other vertices, so that it basically looked like a doubly linked list. If we store the list in an array, your code would be doing something akin to:

def min_diff_cut(list_):
    total = sum(list_)
    best_diff = total
    for idx in range(1, len(list_)):
        # Re-summing the prefix for every cut position is the O(n^2) part
        best_diff = min(best_diff, abs(total - 2 * sum(list_[:idx])))
    return best_diff


This of course has \$O(n^2)\$ complexity, which is rarely a good thing. A much more efficient implementation running in \$O(n)\$ time would do something like:

def min_diff_cut(list_):
    total = sum(list_)
    best_diff = total
    subtotal = 0
    for value in list_:
        # Keep a running prefix sum instead of re-summing each time
        subtotal += value
        best_diff = min(best_diff, abs(total - 2 * subtotal))
    return best_diff
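The same running-sum idea can also be written with `itertools.accumulate` from the standard library. This is only an illustrative variant, and the name `min_diff_cut_prefix` is made up here. It assumes the list has at least two elements and relies on the fact that if one side of a cut sums to s, the other side sums to total - s, so the difference is |total - 2s|.

```python
from itertools import accumulate


def min_diff_cut_prefix(values):
    """Smallest |left - right| over all cuts of a list into two non-empty parts."""
    total = sum(values)
    # Drop the last prefix sum: it covers the whole list, which is not a cut.
    prefixes = list(accumulate(values))[:-1]
    return min(abs(total - 2 * prefix) for prefix in prefixes)


print(min_diff_cut_prefix([100, 200, 100, 500, 100, 600]))  # → 200
```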


How can you expand this idea to work on your possibly branched tree? Well, if instead of breadth first, you go depth first, you can recursively find out what the best diff and sum of weights down an edge is, and come up with a solution from there. This would look something like:

def min_diff_cut(graph, weights):
    total = sum(weights)
    visited = set()

    def subtree_total_weight(vertex):
        visited.add(vertex)
        subdiff = total
        subtotal = weights[vertex]
        for next_vtx in graph[vertex]:
            if next_vtx in visited:
                continue
            subt, subd = subtree_total_weight(next_vtx)
            subtotal += subt
            subdiff = min(subdiff, subd)
        return subtotal, min(subdiff, abs(2*subtotal - total))

    # We are assuming vertices are 0-based indices
    _, best_diff = subtree_total_weight(0)

    return best_diff

if __name__ == '__main__':
    print(min_diff_cut({0: [1], 1: [0, 2, 4], 2: [1], 3: [4],
                        4: [1, 3, 5], 5: [4]},
                       [100, 200, 100, 500, 100, 600]))
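One caveat with the recursive version: CPython's default recursion limit is 1000, and with \$N\$ up to \$10^5\$ a long, path-like tree can recurse much deeper than that. A possible workaround, sketched here under the same assumptions (0-based vertices, `graph` as an adjacency mapping), is to do the depth-first traversal with an explicit stack and then accumulate subtree sums in reverse visiting order. The name `min_diff_cut_iterative` and the `root` parameter are introduced for illustration and are not from the original answer.

```python
def min_diff_cut_iterative(graph, weights, root=0):
    """Minimum Tree_diff over all edges, without recursion."""
    total = sum(weights)

    # Pass 1: iterative DFS recording each vertex's parent and the visit order.
    parent = {root: None}
    order = []
    stack = [root]
    while stack:
        vertex = stack.pop()
        order.append(vertex)
        for next_vtx in graph[vertex]:
            if next_vtx != parent[vertex]:
                parent[next_vtx] = vertex
                stack.append(next_vtx)

    # Pass 2: in reverse visit order every vertex is handled before its
    # parent, so its subtree sum is complete when we reach it.
    subtotal = list(weights)
    best_diff = total
    for vertex in reversed(order):
        if parent[vertex] is None:
            continue  # the root has no edge above it to cut
        # Cutting the edge (parent, vertex) splits off subtotal[vertex].
        best_diff = min(best_diff, abs(total - 2 * subtotal[vertex]))
        subtotal[parent[vertex]] += subtotal[vertex]
    return best_diff


sample_graph = {0: [1], 1: [0, 2, 4], 2: [1], 3: [4], 4: [1, 3, 5], 5: [4]}
sample_weights = [100, 200, 100, 500, 100, 600]
print(min_diff_cut_iterative(sample_graph, sample_weights))  # → 400
```

Both passes touch each vertex and edge a constant number of times, so the running time stays linear in the size of the tree.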

Context

StackExchange Code Review Q#98735, answer score: 4
