
Finding the maximum distance from the starting node

Submitted by: @import:stackexchange-codereview

Problem

I've tried everything I can think of to make my program faster, but with 250 nodes my code takes around 9 seconds to print the result, and with 5000 nodes it takes around 260 seconds.

Is there a way to make my program faster?

I took the BFS function from here.


A site should consist of three towns, each of which is directly
connected to the other two by a road. The distance between two towns A and
B is the minimum number of pairs of directly connected towns on the
way from A to B. Two towns are connected directly by a road R if there
is no other town on R between A and B. The distance of any possible
site from the capital city is equal to the sum of the distances of
the three towns representing the site from the capital city.

Output: the maximum possible distance of a site from the capital city, followed by the number of sites located at that maximum distance.


In the example below, with the capital at node 9, there are two most distant sites,
{0, 1, 3} and {2, 3, 5}, and their distance from the capital is 10.

Input: the first line contains nodes_count, pairs_number and capital.
Each of the following pairs_number lines contains one pair of directly connected towns.

Input

10 17 9
0 1
0 3
1 3
1 4
2 3
3 4
2 5
3 5
3 6
4 6
4 7
5 6
6 7
6 8
7 8
7 9
8 9


Output

10 2
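As a sanity check of the sample, here is a minimal brute-force sketch (not the code under review). It assumes towns are numbered 0 to nodes_count - 1, and the names `SAMPLE` and `brute_force` are made up for this illustration. It tries every triple of towns, which is only practical for small inputs, but it is handy as a reference when testing a faster version.

```python
from collections import deque
from itertools import combinations

SAMPLE = """10 17 9
0 1
0 3
1 3
1 4
2 3
3 4
2 5
3 5
3 6
4 6
4 7
5 6
6 7
6 8
7 8
7 9
8 9"""


def brute_force(text):
    """Return (max_site_distance, site_count) for an input given as text."""
    lines = text.strip().splitlines()
    nodes_count, pairs_number, capital = map(int, lines[0].split())
    # Assumption: towns are numbered 0 .. nodes_count - 1.
    adjacent = {town: set() for town in range(nodes_count)}
    for line in lines[1:1 + pairs_number]:
        a, b = map(int, line.split())
        adjacent[a].add(b)
        adjacent[b].add(a)

    # One breadth-first search from the capital gives every town's distance.
    dist = {capital: 0}
    queue = deque([capital])
    while queue:
        town = queue.popleft()
        for neighbour in adjacent[town]:
            if neighbour not in dist:
                dist[neighbour] = dist[town] + 1
                queue.append(neighbour)

    # Try every triple of reachable towns and keep the mutually connected ones.
    best, count = 0, 0
    for a, b, c in combinations(sorted(dist), 3):
        if b in adjacent[a] and c in adjacent[a] and c in adjacent[b]:
            total = dist[a] + dist[b] + dist[c]
            if total > best:
                best, count = total, 1
            elif total == best:
                count += 1
    return best, count


print(brute_force(SAMPLE))  # (10, 2)
```

Towns 0 and 2 are at distance 4 from the capital and towns 1, 3 and 5 are at distance 3, which is where the two sites of distance 10 come from.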


My code

```
from collections import defaultdict
from queue import Queue


def read_nodes(pairs_number):
    for _ in range(pairs_number):
        yield map(int, input().split())


def parse_input(tree):
    nodes_count, pairs_number, capital = map(int, input().split())
    for node1, node2 in read_nodes(pairs_number):
        tree[node1].append(node2)
        tree[node2].append(node1)
    return tree, capital, nodes_count


def traverse_path(fromNode, toNode, nodes):
    def getNeighbours(current, nodes):
        return nodes[current] if current in nodes else []

    def make_path(toNode, graph):
        result = []
        while 'Root' != toNode:
            result.append(toNode)
            toNode = graph[toNode]
        # ... (the rest of the listing is cut off in the source)
```

Solution

- The code does not work because of a typo: there is a call to distance_sites but there is no such function.

- There are no docstrings. What do these functions do? How do I call them?

- In order to improve the performance, we need to measure it, and to do that it helps to be able to make test cases of arbitrary sizes. So let's write a test case generator:

from collections import defaultdict
from itertools import product

def test_case(n):
    """Construct a graph with n**2 nodes and O(n**2) triangles, and return
    a tuple (graph, capital, number of nodes).

    """
    graph = defaultdict(list)
    for i, j in product(range(n), repeat=2):
        k = i * n + j
        if i < n - 1:
            graph[k].append(k + n)
            graph[k + n].append(k)
        if j < n - 1:
            graph[k].append(k + 1)
            graph[k + 1].append(k)
        if i < n - 1 and j < n - 1:
            graph[k].append(k + n + 1)
            graph[k + n + 1].append(k)
    return graph, 0, n * n


Then we can easily measure the performance of the code using timeit.timeit:

>>> from timeit import timeit
>>> timeit(lambda:distant_sites(*test_case(50)), number=1)
51.7171316598542


With $n=50$ the graph has 2,500 nodes and 4,802 triangles.
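The node and triangle counts can be checked independently. The sketch below rebuilds the same grid-with-diagonals graph using neighbour sets and counts each triangle once. The names `grid_graph` and `count_triangles` are hypothetical helpers for this check, not part of the reviewed code.

```python
from itertools import product


def grid_graph(n):
    """Build the same n-by-n grid with diagonals, using neighbour sets."""
    graph = {k: set() for k in range(n * n)}

    def link(a, b):
        graph[a].add(b)
        graph[b].add(a)

    for i, j in product(range(n), repeat=2):
        k = i * n + j
        if i < n - 1:
            link(k, k + n)          # edge to the node one row down
        if j < n - 1:
            link(k, k + 1)          # edge to the node one column right
        if i < n - 1 and j < n - 1:
            link(k, k + n + 1)      # diagonal edge
    return graph


def count_triangles(graph):
    """Count each triangle once by requiring a < b < c."""
    total = 0
    for a in graph:
        for b in graph[a]:
            if b > a:
                total += sum(1 for c in graph[a] & graph[b] if c > b)
    return total


graph = grid_graph(50)
print(len(graph), count_triangles(graph))  # 2500 4802
```

Each of the $(n-1)^2$ grid cells is split by its diagonal into two triangles, which gives $2 \cdot 49^2 = 4802$ for $n = 50$.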

- queue.Queue is a thread-safe data structure intended for use by multi-threaded programs. It has to take and release a lock for every operation, so it is overkill to use this in a single-threaded program. It is more than ten times faster to use collections.deque instead:

>>> timeit(lambda:distant_sites(*test_case(50)), number=1)
4.451224883086979
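To see the difference in isolation, here is a small sketch that pushes and pops the same items through both containers. The function names and the item count are arbitrary choices for this illustration, and the timings will vary by machine, so no figures are quoted.

```python
from collections import deque
from queue import Queue
from timeit import timeit


def drain_queue(n):
    """Push n items through a thread-safe queue.Queue (locks on every call)."""
    q = Queue()
    for i in range(n):
        q.put(i)
    out = []
    while not q.empty():
        out.append(q.get())
    return out


def drain_deque(n):
    """Push n items through a plain collections.deque (no locking)."""
    q = deque()
    for i in range(n):
        q.append(i)
    out = []
    while q:
        out.append(q.popleft())
    return out


if __name__ == "__main__":
    n = 100_000
    print("Queue:", timeit(lambda: drain_queue(n), number=1))
    print("deque:", timeit(lambda: drain_deque(n), number=1))
```

Both functions return the items in the same first-in, first-out order, so for a single-threaded breadth-first search it should be possible to swap put/get for append/popleft without changing the result.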


- The code computes the distance from the capital to each town by running a separate breadth-first search for each town. But this repeats a lot of work: in the course of finding the distance to town A, the breadth-first search will have to visit towns B, C, D, and so on. It would make sense to remember the distance to each town as we visit it, and so compute the distances from the capital to all the towns in one go:

from collections import deque

def distances(graph, origin):
    """Return a dictionary mapping each node in graph to its distance from
    the origin.

    """
    result = {origin: 0}
    visited = set([origin])
    queue = deque([origin])
    while queue:
        node = queue.popleft()
        distance = result[node] + 1
        for neighbour in graph[node]:
            if neighbour not in visited:
                result[neighbour] = distance
                visited.add(neighbour)
                queue.append(neighbour)
    return result


This gives a couple of orders of magnitude speedup on the test case:

>>> timeit(lambda:distant_sites(*test_case(50)), number=1)
0.030185375828295946


- The code for finding the most distant sites looks at all triangles. But this likely involves a lot of wasted effort. For example, suppose we find a triangle whose nodes have distance 48, 49 and 50 from the origin (with sum 147). Now there is no need to look at any triangle unless it contains a node with distance $147/3 = 49$ or more. So if we sort the nodes in reverse order by their distance from the origin and just remember the best-scoring site, then we may not have to consider very many sites before we know that we have found the most distant one.
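The cut-off can be written as a short inequality. The notation is introduced here for illustration: $d_1 \ge d_2 \ge d_3$ are the distances of a site's three nodes, and $M$ is the best site distance found so far.

$$d_1 + d_2 + d_3 \le 3 d_1$$

So if $3 d_1 < M$ for the farthest node of a site, that site cannot reach $M$. With $M = 147$, a node at distance 48 gives $3 \cdot 48 = 144 < 147$, so once the sorted scan reaches such a node it can stop. The comparison has to be strict, because a site with distance exactly $M$ still needs to be counted.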

def distant_sites(graph, origin):
    """Return the pair (max_site_dist, site_count), where max_site_dist is
    the maximum distance of any site in the graph from the origin, and
    site_count is the number of sites at that distance. A "site" is a
    triangle of nodes, and its distance from the origin is the sum of
    the distances of the three nodes.

    """
    distance = distances(graph, origin)
    nodes = sorted(((d, n) for n, d in distance.items()), reverse=True)
    max_site_dist = 0
    site_count = 0
    for dist1, node1 in nodes:
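        if dist1 * 3 < max_site_dist:
            break
        neighbours = graph[node1]
        for node2 in neighbours:
            dist2 = distance[node2]
            if (dist2, node2) >= (dist1, node1):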
                continue
            for node3 in graph[node2]:
                if node3 not in neighbours:
                    continue
                dist3 = distance[node3]
                if (dist3, node3) >= (dist2, node2):
                    continue
                site_dist = dist1 + dist2 + dist3
                if site_dist > max_site_dist:
                    max_site_dist = site_dist
                    site_count = 1
                elif site_dist == max_site_dist:
                    site_count += 1
    return max_site_dist, site_count


The speedup you get from this optimization depends on the kinds of graph you feed it (if there are few sites then it won't make much difference). For my test case we get about 40% speedup:

>>> timeit(lambda:distant_sites(*test_case(50)[:2]), number=1)
0.017997111193835735

Context

StackExchange Code Review Q#129055, answer score: 6
