Speed up large calculation of intersections

Submitted by: @import:stackexchange-codereview

Problem

I have the following data:

- trains is a dictionary with 1700 elements. The keys are the IDs of trains and the value for each train is an array with every station ID where that train stops.

- departures is a dictionary with the same keys as trains, so also 1700 elements. Each value is the departure time of the train.

Now, I would like to compute intersections between trains. When train A and train B have overlapping stops, I look at the departure time of both trains. When train A departs before train B, then (A, B) is put in the resulting set, otherwise (B, A).

trains = {90: [240, 76, 18, ...], 91: [2, 17, 98, 76, ...], ...}
departures = {90: 1418732160, 91: 1418711580, ...}
intersections = []

for i in trains:
    trA = trains[i]
    for j in trains:
        if i != j:
            trB = trains[j]
            intersect = [val for val in trA if val in trB]
            if intersect:
                if departures[i] < departures[j]:
                    if (i, j) not in intersections:
                        intersections.append((i, j))
                else:
                    if (j, i) not in intersections:
                        intersections.append((j, i))


When finished, the intersections list contains 500,000 elements.

This, however, takes a very long time to compute! I'm guessing it is because of the (i, j) not in intersections and (j, i) not in intersections checks.

Is there any way I could alter my code to speed up this calculation?

Solution

I would use a somewhat different algorithm for this task: build a dictionary of all the trains that call at each stop, sort those trains by departure time, and then take all in-order pairs of trains at each stop. Because the pairs are collected in a set, a pair that shares several stops is still stored only once.

The basic algorithm is like this:

from collections import defaultdict
from itertools import combinations

trains = {...}
departures = {...}

intersections = set()

stations = defaultdict(list)
for t, train in trains.items():
    for s in train:
        stations[s].append(t)

for station in stations.values():
    intersections.update(combinations(sorted(station, key=lambda t: departures[t]), 2))


(This version of the code was vastly improved by @Veedrac.)
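
To make the approach easier to try out, here is a minimal sketch that wraps it in a function and cross-checks it against a pairwise reference version on a tiny data set. The function names intersections_by_station and intersections_brute_force, the train 92 and 93 entries, and their departure times are all hypothetical and not part of the original question or answer. The sketch also assumes that each train lists a station at most once and that no two trains share a departure time; with repeated stations or tied departures the two versions can disagree.

```python
from collections import defaultdict
from itertools import combinations


def intersections_by_station(trains, departures):
    """Return ordered (earlier, later) pairs of trains that share a stop."""
    # Invert the mapping: station ID -> list of train IDs stopping there.
    stations = defaultdict(list)
    for t, stops in trains.items():
        for s in stops:
            stations[s].append(t)

    pairs = set()
    for station in stations.values():
        # Sorting by departure means every combination comes out as
        # (earlier train, later train).
        ordered = sorted(station, key=lambda t: departures[t])
        pairs.update(combinations(ordered, 2))
    return pairs


def intersections_brute_force(trains, departures):
    """Pairwise reference version, using sets for the overlap test."""
    stop_sets = {t: set(stops) for t, stops in trains.items()}
    pairs = set()
    for a, b in combinations(trains, 2):
        if not stop_sets[a].isdisjoint(stop_sets[b]):
            pairs.add((a, b) if departures[a] < departures[b] else (b, a))
    return pairs


# Hypothetical sample data (assumed for illustration only).
trains = {90: [240, 76, 18], 91: [2, 17, 98, 76], 92: [18, 5], 93: [7]}
departures = {90: 1418732160, 91: 1418711580, 92: 1418740000, 93: 1418700000}

result = intersections_by_station(trains, departures)
print(sorted(result))
```

The reference version still visits every pair of trains, but it replaces the list membership tests with set operations, which is the part the question suspected of being slow. The station-based version only looks at trains that actually meet at some stop.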

Context

StackExchange Code Review Q#90309, answer score: 5
