HiveBrain v1.2.0
pattern · python · Minor

Fast Python spring network solver

Submitted by: @import:stackexchange-codereview
Tags: fast, spring, python, solver, network

Problem

I wanted a very simple spring system written in Python. The system is defined as a simple network of knots connected by links, following these rules:

- A knot is a massless connection between links. Each knot is affected only by the push/pull forces it receives from the links it is connected to (no gravity, viscosity, etc.). Its only attribute is whether or not it is anchored: an anchored knot affects the system through its movement. An unanchored knot can also affect the system when it is moved, but the resulting push/pull forces will pull it back.

- A link is a connection between two knots. It has no mass, and it applies a force to the knots at each end, derived from the difference between its current length and its initial length.

- The system takes as input the initial position of each knot, each knot's anchored state, a list of links (given as arrays of knot indices), and each link's initial length. It then iterates over the network: it adds up all the forces acting on each knot, moves the knots to their new positions (dampened for stability), and keeps iterating until an iteration limit is reached or the largest force applied in an iteration falls below a given threshold. I don't care about solving over time and I don't need velocity; all I want is the final "relaxed" position of each knot.
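The rules above can be sketched as a small relaxation loop. This is a minimal illustration under stated assumptions, not the poster's solver: the name relax is hypothetical, every link is assumed to have unit stiffness (force magnitude equals current length minus initial length), kAnchor == 1 is taken to mean "fixed", and no link is assumed to collapse to zero length. Forces are accumulated per knot with np.add.at / np.subtract.at.

```python
import numpy as np


def relax(kPos, kAnchor, link0, link1, w0, cycles=1000, precision=0.001, dampening=0.1):
    """Relax a massless spring network (hypothetical sketch of the rules above).

    Assumes unit stiffness (force magnitude = current length - initial length),
    kAnchor == 1 for fixed knots, and no link of zero current length.
    """
    pos = np.array(kPos, dtype=float)                       # copy of knot positions, (n_knots, ndim)
    free = 1.0 - np.asarray(kAnchor, dtype=float)[:, None]  # 0 for anchored knots, 1 otherwise
    w0 = np.asarray(w0, dtype=float)                        # initial link lengths
    for _ in range(cycles):
        d = pos[link1] - pos[link0]                # link vectors, (n_links, ndim)
        length = np.linalg.norm(d, axis=1)         # current link lengths
        # Stretched links pull their ends together, compressed links push them apart.
        f = ((length - w0) / length)[:, None] * d
        F = np.zeros_like(pos)
        np.add.at(F, link0, f)                     # link0 end moves along +d when stretched
        np.subtract.at(F, link1, f)                # link1 end gets the opposite force
        F *= free                                  # anchored knots stay put
        if np.linalg.norm(F, axis=1).max() < precision:
            break                                  # largest force is below the threshold
        pos += dampening * F                       # damped step toward equilibrium
    return pos


# One link of initial length 1 between an anchored knot at (0, 0) and a free knot at (2, 0).
p = relax([[0.0, 0.0], [2.0, 0.0]], [1, 0], [0], [1], [1.0])
```

In this one-link case the free knot's distance from its rest position shrinks by a factor of 1 - dampening per iteration, so the loop stops with p[1] within precision of (1, 0) and the anchored knot unchanged.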

Leveraging NumPy's vectorization, I came up with this code:

```python
import numpy as np
from numpy.core.umath_tests import inner1d

def solver(kPos, kAnchor, link0, link1, w0, cycles=1000, precision=0.001, dampening=0.1, debug=False):
    """
    kPos : vector array - knot position
    kAnchor : float array - knot's anchor state, 0 = moves freely, 1 = anchored (not moving)
    link0 : int array - array of links connecting each knot. each index corresponds to a knot
    link1 : int array - array of links connecting each knot. each index corresponds to a knot
    w0
    """
    ...
```
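Side note on the import: inner1d comes from numpy.core.umath_tests, an internal NumPy test module that NumPy has deprecated. Assuming it is used for row-by-row dot products (for example squared link lengths; the function body is not shown), np.einsum computes the same thing with public API. The helper name rowwise_dot and the sample data are hypothetical.

```python
import numpy as np


def rowwise_dot(a, b):
    """Row-by-row dot product of two (n, ndim) arrays, like inner1d(a, b)."""
    return np.einsum("ij,ij->i", a, b)


# Hypothetical example: two link vectors in 2-D.
d = np.array([[3.0, 4.0], [1.0, 1.0]])
sq_len = rowwise_dot(d, d)    # squared link lengths: [25.0, 2.0]
lengths = np.sqrt(sq_len)     # link lengths: [5.0, 1.414...]
```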

Solution

You should profile your code to find out exactly what is slowing it down. It's hard to tell without actual measurements, but my bet is on your calls to np.add.at and np.subtract.at, because the ufunc .at() method is notoriously (very) slow. For some operations there is really no alternative, but for addition and subtraction you can use np.bincount, whose weights argument sums the values that fall on each index. The transformed code would look something like:

```python
knots, ndim = kPos.shape
...
for dim in range(ndim):
    # Apply force vectors on each dimension of each knot:
    # bincount sums the weights per knot index; minlength keeps one entry per knot
    F[:, dim] = (np.bincount(link0, weights=f[:, dim], minlength=knots) -
                 np.bincount(link1, weights=f[:, dim], minlength=knots))
```
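Assuming f holds one force vector per link (the ... in the snippet hides where it comes from), the bincount version can be checked against the ufunc.at approach it replaces. The function names below are hypothetical, and the sign convention (+f on the link0 end, -f on the link1 end) is the one the snippet implies.

```python
import numpy as np


def accumulate_forces(f, link0, link1, n_knots):
    """Sum per-link force vectors f, shape (n_links, ndim), onto the knots.

    np.bincount(idx, weights=w, minlength=n) returns an array of length n
    whose entry k is the sum of w over all positions where idx == k.
    The link0 end of each link receives +f and the link1 end receives -f.
    """
    F = np.empty((n_knots, f.shape[1]))
    for dim in range(f.shape[1]):
        F[:, dim] = (np.bincount(link0, weights=f[:, dim], minlength=n_knots)
                     - np.bincount(link1, weights=f[:, dim], minlength=n_knots))
    return F


def accumulate_forces_at(f, link0, link1, n_knots):
    """Reference version using the unbuffered (and slower) ufunc.at methods."""
    F = np.zeros((n_knots, f.shape[1]))
    np.add.at(F, link0, f)
    np.subtract.at(F, link1, f)
    return F


# Random test network: 200 links between 50 knots in 3-D.
rng = np.random.default_rng(0)
n_knots, n_links = 50, 200
link0 = rng.integers(0, n_knots, n_links)
link1 = rng.integers(0, n_knots, n_links)
f = rng.normal(size=(n_links, 3))

F_bc = accumulate_forces(f, link0, link1, n_knots)
F_at = accumulate_forces_at(f, link0, link1, n_knots)
print(np.allclose(F_bc, F_at))  # True (results differ only by floating-point summation order)
```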


As I said, you'll need to test it, but I wouldn't be surprised if this little change makes your code 5-10x faster.
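A minimal timeit comparison of the two accumulation approaches could look like this. The array sizes are arbitrary assumptions, and the ratio you see depends on NumPy version, array sizes, and hardware, so no particular speedup should be taken for granted.

```python
import timeit

import numpy as np

# Arbitrary problem size for the benchmark.
rng = np.random.default_rng(0)
n_knots, n_links = 10_000, 40_000
link0 = rng.integers(0, n_knots, n_links)
w = rng.normal(size=n_links)


def with_add_at():
    """Accumulate w per knot with the unbuffered np.add.at."""
    out = np.zeros(n_knots)
    np.add.at(out, link0, w)
    return out


def with_bincount():
    """Accumulate w per knot with np.bincount."""
    return np.bincount(link0, weights=w, minlength=n_knots)


t_at = timeit.timeit(with_add_at, number=20)
t_bc = timeit.timeit(with_bincount, number=20)
print(f"np.add.at:   {t_at:.4f} s for 20 runs")
print(f"np.bincount: {t_bc:.4f} s for 20 runs")
```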

Context

StackExchange Code Review Q#121138, answer score: 6
