Fast Python spring network solver
Problem
I wanted a very simple spring system written in Python. The system would be defined as a simple network of knots, linked by links using the following rules:

- A knot is a massless connection between links. Each knot is affected only by the push/pull forces it receives from the links it is connected to (no gravity, viscosity, etc.). Its only attribute is whether or not it is anchored; an anchored knot affects the system through its movement. An unanchored knot can also affect the system if it is moved, but it will be pulled back by the resulting push/pull forces.
- A link is a connection between 2 knots. It has no mass, and it applies a force on the knots at each of its ends, derived from the difference between its current length and its initial length.
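The link rule above leaves the exact force law open. A minimal sketch, assuming a linear (Hooke-like) law with a hypothetical `stiffness` parameter and the convention that the returned vector acts on the `link0` knot (the `link1` knot receives its negative), can compute every link force at once. `link_forces` is a hypothetical helper, not part of the original code:

```python
import numpy as np

def link_forces(kPos, link0, link1, w0, stiffness=1.0):
    """Force each link exerts on its link0 knot; the link1 knot gets the negative.

    Assumption: magnitude = stiffness * (current length - w0), directed along the link,
    so a stretched link pulls its ends together and a compressed one pushes them apart.
    Zero-length links would divide by zero and are not handled in this sketch.
    """
    d = kPos[link1] - kPos[link0]           # vector from each link0 knot to its link1 knot
    length = np.linalg.norm(d, axis=1)      # current length of each link
    stretch = length - w0                   # > 0 stretched, < 0 compressed
    return (stiffness * stretch / length)[:, None] * d

# Two knots 2 units apart, joined by a link whose initial length is 1:
kPos = np.array([[0.0, 0.0], [2.0, 0.0]])
f = link_forces(kPos, np.array([0]), np.array([1]), np.array([1.0]))
print(f)  # [[1. 0.]] -> knot 0 is pulled toward knot 1; knot 1 receives [-1, 0]
```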
The system takes as input the initial position of each knot, each knot's anchored state, a list of links (presented as arrays of knot indices), and each link's initial length. It then iterates over the network: it adds up all the forces acting on each knot, moves the knots to their new positions (dampened for stability), and keeps iterating until an iteration limit is reached or the highest force applied in a given iteration is below a given threshold. I don't care about solving over time and I don't need velocity; all I want is the final "relaxed" position of each knot.

Leveraging NumPy's vectorization, I came up with this code:
```
import numpy as np
from numpy.core.umath_tests import inner1d

def solver(kPos, kAnchor, link0, link1, w0, cycles=1000, precision=0.001, dampening=0.1, debug=False):
    """
    kPos : vector array - knot position
    kAnchor : float array - knot's anchor state, 0 = moves freely, 1 = anchored (not moving)
    link0 : int array - array of links connecting each knot. each index corresponds to a knot
    link1 : int array - array of links connecting each knot. each index corresponds to a knot
    w0
```
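For reference, the relaxation loop described above can be sketched as follows. This is not the original `solver` body: it assumes a linear force law with unit stiffness, accumulates forces with `np.add.at`/`np.subtract.at`, zeroes the force on anchored knots (`kAnchor == 1`), and stops when the largest per-knot force norm drops below `precision`. The name `relax` is hypothetical:

```python
import numpy as np

def relax(kPos, kAnchor, link0, link1, w0, cycles=1000, precision=0.001, dampening=0.1):
    """Sketch of the described relaxation loop (assumed unit-stiffness linear springs)."""
    kPos = np.array(kPos, dtype=float)                         # work on a float copy
    free = (1.0 - np.asarray(kAnchor, dtype=float))[:, None]   # 0 for anchored knots
    for _ in range(cycles):
        d = kPos[link1] - kPos[link0]                          # link vectors
        length = np.linalg.norm(d, axis=1)                     # current link lengths
        f = ((length - w0) / length)[:, None] * d              # force on each link0 knot
        F = np.zeros_like(kPos)
        np.add.at(F, link0, f)                                 # scatter-add onto link0 knots
        np.subtract.at(F, link1, f)                            # opposite force onto link1 knots
        F *= free                                              # anchored knots do not move
        if np.linalg.norm(F, axis=1).max() < precision:        # highest force below threshold
            break
        kPos += dampening * F                                  # dampened step, no velocity
    return kPos

# Free knot between two anchors 3 apart, both links with initial length 1:
kPos = [[0.0, 0.0], [0.5, 0.0], [3.0, 0.0]]
pos = relax(kPos, [1, 0, 1], np.array([0, 1]), np.array([1, 2]), np.array([1.0, 1.0]))
print(pos[1])  # approximately [1.5, 0.]
```

Because only the final positions matter, the dampened step is a plain gradient-style update; there is no velocity state to carry between iterations.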
Solution
You should profile your code to figure out exactly what is slowing it down. It's hard to tell without actual measurements, but my bet is on your calls to np.add.at and np.subtract.at, as the .at() method is notoriously (very) slow. For some operations there is really no alternative, but for addition/subtraction you can use np.bincount. The transformed code would look something like:

```
knots, ndim = kPos.shape
...
for dim in range(ndim):
    # Apply force vectors on each dimension of each knot
    F[:, dim] = (np.bincount(link0, weights=f[:, dim], minlength=knots) -
                 np.bincount(link1, weights=f[:, dim], minlength=knots))
```

As I said, you'll need to test it, but I wouldn't be surprised if this little change makes your code 5-10x faster.
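To test the change before trusting it, a small self-contained comparison can confirm that the bincount version returns the same net forces as the ufunc.at version and time both. The random network and the helper names `accumulate_at` and `accumulate_bincount` are hypothetical, made up for this sketch:

```python
import timeit
import numpy as np

def accumulate_at(f, link0, link1, knots):
    """Net force per knot via unbuffered ufunc.at scatter."""
    F = np.zeros((knots, f.shape[1]))
    np.add.at(F, link0, f)
    np.subtract.at(F, link1, f)
    return F

def accumulate_bincount(f, link0, link1, knots):
    """Same result using one np.bincount pair per dimension."""
    F = np.empty((knots, f.shape[1]))
    for dim in range(f.shape[1]):
        F[:, dim] = (np.bincount(link0, weights=f[:, dim], minlength=knots) -
                     np.bincount(link1, weights=f[:, dim], minlength=knots))
    return F

# Hypothetical random network: 1000 knots, 5000 links, 3 dimensions
rng = np.random.default_rng(0)
knots, links, ndim = 1000, 5000, 3
link0 = rng.integers(0, knots, links)
link1 = rng.integers(0, knots, links)
f = rng.normal(size=(links, ndim))

assert np.allclose(accumulate_at(f, link0, link1, knots),
                   accumulate_bincount(f, link0, link1, knots))

for fn in (accumulate_at, accumulate_bincount):
    t = timeit.timeit(lambda: fn(f, link0, link1, knots), number=20)
    print(f"{fn.__name__}: {t / 20 * 1e3:.3f} ms per call")
```

A plain `F[link0] += f` is not an option here: fancy-index assignment is buffered, so when a knot appears several times in `link0` only one of its contributions survives. Both `np.add.at` and `np.bincount` sum every occurrence. The actual speedup depends on the NumPy version and the network size, so measure it on your own setup.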
Context
StackExchange Code Review Q#121138, answer score: 6