Closest distance between points in a list
Problem
In a course I'm doing, I was given the task of finding the closest pair of points among the given points. The program that passed all test cases was the following:
    import math

    def square(x):
        return x * x

    def square_distance(p0, p1):
        return square(p0[0] - p1[0]) + square(p0[1] - p1[1])

    def closest(P, n):
        P.sort()  # sort by x coordinates
        return math.sqrt(_closest_square_distance(P, n))

    def _closest_square_distance(P, n):
        if n == 2:
            return square_distance(P[0], P[1])
        if n == 3:
            return min(square_distance(P[0], P[1]), square_distance(P[0], P[2]), square_distance(P[1], P[2]))
        mid = n // 2
        dl = _closest_square_distance(P[:mid], mid)
        dr = _closest_square_distance(P[mid:], n - mid)
        closest_square_distance = min(dl, dr)
        closest_distance_so_far = math.sqrt(closest_square_distance)
        mid_x = P[mid][0]
        strip = []
        strip_length = 0
        for i in range(n):
            p = P[i]
            if abs(p[0] - mid_x) < closest_distance_so_far:
                strip.append(p)
                strip_length += 1
        strip.sort(key=lambda x: x[1])  # sort strip by y coordinates
        for i in range(strip_length):
            js = min(strip_length, i + 7)  # sufficient to compute next 6 neighbors
            for j in range(i + 1, js):
                ds = square_distance(strip[i], strip[j])
                if ds < closest_square_distance:
                    closest_square_distance = ds
        return closest_square_distance

    if __name__ == '__main__':
        n = int(input())
        P = []
        for i in range(n):
            line = input().split()
            x = int(line[0])
            y = int(line[1])
            p = (x, y)
            P.append(p)
        print(closest(P, n))

This code was based on the algorithm found at https://en.wikipedia.org/wiki/Closest_pair_of_points_problem#Planar_case . What do you think I can do better? Is there a more efficient way to implement this algorithm? (In this one, I just tried to delay the sqrt computation as long as possible, and avoided
Solution
- Review
-
There are no docstrings. What do these functions do? What arguments do they take and what do they return?
-
You write:
"I just tried to delay the sqrt computation as long as possible"
This suggests that your cost model for Python is not quite right. In CPython, it's more expensive to call square than it is to call sqrt:

    >>> from timeit import timeit
    >>> timeit('square(999)', 'from __main__ import square')
    0.11685514450073242
    >>> timeit('sqrt(999)', 'from math import sqrt')
    0.057652950286865234

That's because most of the time executing a program in CPython is spent in the execution of bytecode. The actual cost of computing the square root is small by comparison.
Because bytecode is expensive, it sometimes pays (in CPython) to do more computation, if you can arrange for it to be done in the fast C implementation rather than in slow Python code. In this case I would consider using math.hypot, like this:

    from math import hypot

    def distance(p, q):
        "Return the Euclidean distance between points p and q."
        return hypot(p[0] - q[0], p[1] - q[1])

This is faster than calling square_distance:

    >>> timeit('square_distance((1, 2), (3, 4))', 'from __main__ import square_distance')
    0.4481959342956543
    >>> timeit('distance((1, 2), (3, 4))', 'from __main__ import distance')
    0.32285499572753906

-
In closest, the argument n must be the length of the list of points P. So it would be simpler to assign it yourself and avoid the requirement for the caller to pass it in (which is risky, because the caller might get it wrong).
-
The function closest has a side effect: as well as computing the shortest distance, it sorts the list P. Unexpected side effects should be avoided, because it is easier to understand code if it only has "local" effects on data structures. In this case, there are two alternatives: you could require the caller to sort the points before passing them in, or you could take a copy and sort the copy, like this:

    def closest(P):
        "Return the closest Euclidean distance between two points in the list P."
        return _closest_distance(sorted(P), len(P))

-
When recursing:

    dl = _closest_square_distance(P[:mid], mid)
    dr = _closest_square_distance(P[mid:], n - mid)

the list slice P[:mid] has to copy out the contents of the slice. It would be cheaper to remember the endpoints of the slice you are working on and avoid the copying:

    def _closest(P, start, stop):
        # closest Euclidean distance between two points in the slice P[start:stop]
        # handle base cases here
        mid = (start + stop) // 2
        dl = _closest(P, start, mid)
        dr = _closest(P, mid, stop)

-
In the base cases, you could save some duplication by using itertools.combinations and writing:

    if n <= 3:
        return min(distance(p, q) for p, q in combinations(P, 2))

-
The base case logic could be extended to larger values of n. I find that for large random point sets, n
-
The function:

    lambda x: x[1]

is the same every time, so consider storing it in a global variable (or using operator.itemgetter(1)).
-
The code finds the points in the strip by iterating over all the points in P. But since you have the points sorted by their x coordinate, you can find the endpoints of the strip using the bisect module.
-
The code does not take advantage of the fact that the two points being compared can only be closer than closest_distance if they are in different halves (if they are in the same half, then they have already been compared by the recursive calls and closest_distance already takes them into account).
- Revised code
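One of the review points above suggests locating the strip with the bisect module instead of scanning every point. A minimal sketch of that lookup in isolation, assuming a sorted list of x coordinates; the helper name strip_bounds and the sample data are made up for illustration and are not part of the reviewed code:

```python
from bisect import bisect_left, bisect_right


def strip_bounds(xs, mid_x, dist):
    """Return (lo, hi) such that xs[lo:hi] holds every x with
    mid_x - dist <= x <= mid_x + dist.

    xs must be sorted in ascending order.
    strip_bounds is a hypothetical helper used only for this sketch.
    """
    lo = bisect_left(xs, mid_x - dist)    # first index with x >= mid_x - dist
    hi = bisect_right(xs, mid_x + dist)   # first index with x > mid_x + dist
    return lo, hi


xs = [0, 2, 3, 5, 8, 9, 12]  # made-up sorted x coordinates
lo, hi = strip_bounds(xs, mid_x=5, dist=3)
print(xs[lo:hi])  # [2, 3, 5, 8]
```

The bounds here are inclusive, whereas the loop in the question uses a strict comparison. Points lying exactly on the boundary cannot be closer than the current best distance to anything on the other side, so including them only adds candidates and does not change the result.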
    from bisect import bisect_left, bisect_right
    from itertools import combinations, islice
    from math import hypot
    from operator import itemgetter

    def distance(p, q):
        "Return the Euclidean distance between points p and q."
        return hypot(p[0] - q[0], p[1] - q[1])

    _Y_COORD = itemgetter(1)

    def closest_distance(P):
        "Return the closest Euclidean distance between two points in the list P."
        P = sorted(P)
        PX = [x for x, _ in P]

        def _closest(start, stop):
            if stop - start <= 8:
                return min(distance(p, q)
                           for p, q in combinations(P[start:stop], 2))
            mid = (start + stop) // 2
            dist = min(_closest(start, mid), _closest(mid, stop))
            mid_x = PX[mid]
            left = bisect_left(PX, mid_x - dist, lo=start, hi=mid)
            right = bisect_right(PX, mid_x + dist, lo=mid, hi=stop)
            strip = sorted(P[mid:right], key=_Y_COORD)
            strip_y = list(map(_Y_COORD, strip))
            for p in P[left:mid]:
                y = p[1]
                i = bisect_left(strip_y, y - dist)
                j = bisect_right(strip_y, y + dist)
                assert j - i <= 6
                for q in strip[i:j]:
                    dist = min(dist, distance(p, q))
            return dist

        return _closest(0, len(P))
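A divide-and-conquer rewrite like this is easy to get subtly wrong, so it may be worth comparing it against a quadratic brute-force search on random inputs. The following is a minimal sketch and not part of the original answer: brute_force_closest and agrees_with_brute_force are hypothetical helper names, and the grid size, trial count, and tolerance are arbitrary assumptions. Points are sampled without repetition, an assumption that keeps the minimum distance positive.

```python
import random
from itertools import combinations
from math import hypot


def brute_force_closest(points):
    """O(n^2) reference: smallest Euclidean distance over all pairs."""
    return min(hypot(p[0] - q[0], p[1] - q[1])
               for p, q in combinations(points, 2))


def agrees_with_brute_force(candidate, trials=100, size=30, seed=0):
    """Return True if candidate(points) matches the brute-force answer
    on `trials` random sets of `size` distinct integer points."""
    rng = random.Random(seed)
    # Assumed test domain: a 41 x 41 integer grid.
    grid = [(x, y) for x in range(-20, 21) for y in range(-20, 21)]
    for _ in range(trials):
        points = rng.sample(grid, size)  # distinct points, random order
        if abs(candidate(points) - brute_force_closest(points)) > 1e-9:
            return False
    return True


# Sanity check of the harness itself: the reference agrees with itself.
print(agrees_with_brute_force(brute_force_closest))  # True
```

With the revised code in scope, the same comparison could be run as agrees_with_brute_force(closest_distance).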
Context
StackExchange Code Review Q#159020, answer score: 5