Find number of neighbour pairs in the input tree
Problem
My code works on small binary trees, but on big ones (around 10000 nodes) it takes 15 seconds to run.
For small binary trees I always get the correct answer, and it is fast.
How can I make my program faster for big binary trees?
(x, y) is a neighbour pair when x and y are at the same depth, have the same node color and have equal node keys,
and there is no other node of the same color between x and y on the horizontal line connecting them in the standard drawing.
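To make the definition concrete, here is a minimal brute-force sketch that applies it literally to one level of the tree. It assumes (this is not part of the problem statement) that a level is given as a list of `(key, color)` tuples in left-to-right drawing order; `count_pairs_brute_force` is a hypothetical helper. It is O(k²) per level, so it is only meant as a reference for checking a faster solution on small inputs.

```python
from typing import List, Tuple


def count_pairs_brute_force(level: List[Tuple[int, int]]) -> int:
    """Count neighbour pairs on one level by checking the definition directly.

    Assumption of this sketch: `level` holds (key, color) tuples
    in left-to-right order.
    """
    count = 0
    for i in range(len(level)):
        for j in range(i + 1, len(level)):
            key_i, color_i = level[i]
            key_j, color_j = level[j]
            if color_i != color_j or key_i != key_j:
                continue
            # Reject the pair if another node of the same color sits between them.
            blocked = any(color == color_i for _, color in level[i + 1:j])
            if not blocked:
                count += 1
    return count


# Keys 5, 5, 7, 5 with colors white, black, white, black: the two black 5s
# have no black node between them, so they form the only pair.
print(count_pairs_brute_force([(5, 0), (5, 1), (7, 0), (5, 1)]))  # 1
```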
Input
The first line of input contains the number of nodes and the label of the root node. Each following line describes one node.
The first value is the node label, the second value is the node key,
the third and fourth values are the labels of the left and
right child respectively, and the fifth value is the node
color: white is 0, black is 1. If a child does not exist,
its label is given as 0.
Output
The number of neighbour pairs in the input tree.
My code
```
from collections import defaultdict

def read_input(inputstring):
    inputs = inputstring.split(" ")
    nodes_number = int(inputs[0])
    root_node = int(inputs[1])
    input_tree = [list(map(int, input().split())) for _ in range(nodes_number)]
    tree = {}
    result = {}

    def depth(node, count):
        for i in input_tree:
            node_label = i[0]
            if node_label == node:
                tree.setdefault(count, []).append(i)
                l_node, r_node = i[2], i[3]
                if l_node != 0 and r_node != 0:
                    depth(l_node, count+1)
                    depth(r_node, count+1)
                elif l_node != 0:
                    depth(l_node, count+1)
                elif r_node != 0:
                    depth(r_node, count+1)

    depth(root_node, 0)

    def neighbour_pair(input_tree):
        for value in input_tree.values():
            for i i…
```
Solution
Readability counts
Overall, conventions are respected: case, spaces… But there are still some things missing to make your code feel readable:
- vertical spacing: you should add blank lines before function definitions and important logical sections of the code so we can grasp the structure at a glance.
- naming: one-letter variable names are not great, especially when convention dictates a different use for that letter than the one you make (I’m looking at you, `for i in input_tree:`).
- namedtuples can help give meaning to your lines of input:

```
from collections import namedtuple

Node = namedtuple('Node', 'id key left_child right_child color')

def read_input(inputstring):
    …
    input_tree = [Node(*map(int, input().split())) for _ in range(nodes_number)]
    …
        for node in input_tree:
            node_label = node.id
            …
                l_node, r_node = node.left_child, node.right_child
```

and so on.
- You don't need nested functions; pass the data they use as parameters instead. That makes them easier to test (for correctness or speed). In fact, you should go further and separate the computation from the parsing of the input.
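As one illustration of separating parsing from computation, the parser below is a minimal sketch that takes any iterable of lines instead of calling `input()` itself, so a test can feed it a string while the real program passes `sys.stdin`. It reuses the `Node` fields from the namedtuple snippet above; `parse_tree` and the sample input are hypothetical.

```python
from collections import namedtuple
from typing import Dict, Iterable, Tuple

Node = namedtuple('Node', 'id key left_child right_child color')


def parse_tree(lines: Iterable[str]) -> Tuple[int, Dict[int, Node]]:
    """Hypothetical parser for the problem's input format.

    Returns the root label and a dict mapping each node label to its Node,
    so later lookups by label are O(1).
    """
    it = iter(lines)
    nodes_count, root = map(int, next(it).split())
    tree = {}
    for _ in range(nodes_count):
        node = Node(*map(int, next(it).split()))
        tree[node.id] = node
    return root, tree


# In a test, feed a string; in the real program, pass sys.stdin instead.
sample = """3 1
1 10 2 3 0
2 4 0 0 1
3 4 0 0 1"""
root, tree = parse_tree(sample.splitlines())
print(root, tree[2])  # 1 Node(id=2, key=4, left_child=0, right_child=0, color=1)
```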
You also import `defaultdict` but never use it. Instead, your code calls `setdefault(…, [])` on regular dictionaries in a few places. You should declare `tree = defaultdict(list)` and `result = defaultdict(list)` instead.

Building the depth of nodes
For each node (technically each node id), you iterate over the whole `input_tree` and then recurse, iterating over the whole `input_tree` again for each of the two children. Since `depth` is called once per node and every call scans all n entries, the traversal performs about n² comparisons: roughly 100 million for 10000 nodes. This is a huge waste of time. Instead of scanning a list to find a node by its id, use a data structure that lets you access a node by its id directly. A simple dictionary whose keys are ids and whose values are nodes should suffice.
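A minimal sketch of the dictionary-based approach, here swapped to an iterative breadth-first traversal (a substitution, not what the proposed code in this answer does). Assumptions: the tree is a dict mapping each label to a `(key, left, right, color)` tuple, 0 means "no child", and `nodes_by_depth` is a hypothetical helper. Each node is visited once with an O(1) lookup, so the pass is O(n). Because it uses a queue instead of recursion, a degenerate 10000-node tree (a single chain) also stays clear of CPython's default recursion limit of 1000.

```python
from collections import defaultdict, deque
from typing import Dict, List, Tuple

# Assumed layout: label -> (key, left_label, right_label, color); 0 = no child.
Tree = Dict[int, Tuple[int, int, int, int]]


def nodes_by_depth(root: int, tree: Tree) -> Dict[int, List[int]]:
    """Group node labels by depth, left to right, in a single O(n) pass."""
    levels = defaultdict(list)
    queue = deque([(root, 0)])
    while queue:
        label, depth = queue.popleft()
        levels[depth].append(label)
        _key, left, right, _color = tree[label]  # O(1) dict lookup
        # Enqueue left before right so each level stays in drawing order.
        for child in (left, right):
            if child:
                queue.append((child, depth + 1))
    return levels


tree = {1: (10, 2, 3, 0), 2: (4, 4, 0, 1), 3: (4, 0, 5, 1),
        4: (7, 0, 0, 0), 5: (7, 0, 0, 0)}
print(dict(nodes_by_depth(1, tree)))  # {0: [1], 1: [2, 3], 2: [4, 5]}

# A 10000-node chain would overflow a naive recursive traversal; the queue copes.
chain = {i: (0, i + 1 if i < 10000 else 0, 0, 0) for i in range(1, 10001)}
print(len(nodes_by_depth(1, chain)))  # 10000
```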
Computing neighbour pairs
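First off, you could start your last loop at `i + 1` so you don't need the `if j > i:` test. But you don't even need that second loop at all: by keeping track of the last white node and the last black node, you only need to iterate once over each depth. This works because, by the definition, a node can only pair with the nearest node of the same color to its left, as long as each depth's nodes are stored in left-to-right order, which a traversal that visits the left child before the right child guarantees.

Proposed improvements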
```
from collections import defaultdict, namedtuple

Node = namedtuple('Node', 'id key left_child right_child color')


def read_nodes(nodes_count):
    for _ in range(nodes_count):
        yield Node(*map(int, input().split()))


def parse_input():
    nodes_count, root_node = map(int, input().split())
    # Index nodes by id so each lookup is O(1).
    tree = {node.id: node for node in read_nodes(nodes_count)}
    return root_node, tree


def build_depth(node_id, tree, storage, current_depth=0):
    node = tree[node_id]
    storage[current_depth].append(node)
    # Left child first, so each depth list ends up in left-to-right order.
    for child in (node.left_child, node.right_child):
        if child:
            build_depth(child, tree, storage, current_depth + 1)


def neighbour_pairs(layered_tree):
    neighbours_count = 0  # neighbours = defaultdict(list)
    for nodes in layered_tree.values():
        # Sentinels: a key of None never matches a real key.
        last_black = Node(0, None, 0, 0, 1)
        last_white = Node(0, None, 0, 0, 0)
        for node in nodes:
            if node.color == 1:
                last_node, last_black = last_black, node
            else:
                last_node, last_white = last_white, node
            if last_node.key == node.key:
                neighbours_count += 1  # neighbours[last_node.id] = node.id
    return neighbours_count  # return neighbours


if __name__ == '__main__':
    root, tree = parse_input()
    nodes_by_depth = defaultdict(list)
    build_depth(root, tree, nodes_by_depth)
    neighbours = neighbour_pairs(nodes_by_depth)
    print(neighbours)  # print(len(neighbours))
```

You can see that I’m not storing the pairs of neighbours; that helps speed things up. But I left comments in case you still want to retrieve them to test for correctness.
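For comparison, here is a self-contained sketch of the same single-pass idea, written with a dictionary that remembers the last key seen for each color instead of two sentinel nodes, and cross-checked against a direct application of the definition on random levels. The `(key, color)` level format and the names `count_level_pairs` and `brute_force` are assumptions for this illustration. A side benefit of the dictionary is that it would work unchanged if more than two colors ever appeared.

```python
import random
from typing import Dict, List, Tuple

Level = List[Tuple[int, int]]  # assumed: (key, color) pairs in left-to-right order


def count_level_pairs(level: Level) -> int:
    """One pass: each node is compared only with the previous node of its color."""
    last_key: Dict[int, int] = {}
    count = 0
    for key, color in level:
        if last_key.get(color) == key:  # None (no previous node) never matches
            count += 1
        last_key[color] = key
    return count


def brute_force(level: Level) -> int:
    """O(k^2) check of the definition, used only to validate the fast version."""
    count = 0
    for i, (key_i, color_i) in enumerate(level):
        for j in range(i + 1, len(level)):
            key_j, color_j = level[j]
            if color_j != color_i:
                continue
            # j is the nearest node of the same color to the right of i.
            if key_j == key_i:
                count += 1
            break
    return count


random.seed(0)
for _ in range(1000):
    level = [(random.randint(0, 2), random.randint(0, 1))
             for _ in range(random.randint(0, 12))]
    assert count_level_pairs(level) == brute_force(level)

print(count_level_pairs([(5, 0), (5, 1), (7, 0), (5, 1)]))  # 1
```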
Context
StackExchange Code Review Q#126687, answer score: 2