
How do I optimize the solving of ODEs?

Submitted by: @import:stackexchange-codereview

Problem

I have been looking for a way to speed up ODE solving in Python, without success. I also put a bounty on a Stack Overflow question about doing this with Cython, but nothing came of it.

The code below isn't too slow, but it is typical of code I run a lot. Some of my other code takes far longer to run, since I am simulating space flight trajectories. If I can get some help with this code, I can apply the ideas and knowledge I gain to the rest myself.

The code is of moderate length. The crux of the matter is the performance of the ODE solving; the rest computes constants that go into the ODE or the initial conditions.

```
import numpy as np
from scipy.integrate import ode
import pylab
from mpl_toolkits.mplot3d import Axes3D
from scipy.optimize import brentq
from scipy.optimize import fsolve

me = 5.974 * 10 ** 24 # mass of the earth
mm = 7.348 * 10 ** 22 # mass of the moon
G = 6.67259 * 10 ** -20 # gravitational constant in km^3/(kg s^2)
re = 6378.0 # radius of the earth in km
rm = 1737.0 # radius of the moon in km
r12 = 384400.0 # distance between the CoM of the earth and moon
rs = 66100.0 # distance to the moon SOI
Lambda = np.pi / 6 # angle at arrival to SOI
M = me + mm
d = 300 # distance the spacecraft is above the Earth
pi1 = me / M
pi2 = mm / M
mue = 398600.0 # gravitational parameter of earth km^3/sec^2
mum = G * mm # grav param of the moon
mu = mue + mum
omega = np.sqrt(mu / r12 ** 3)
# distance from the earth to Lambda on the SOI
r1 = np.sqrt(r12 ** 2 + rs ** 2 - 2 * r12 * rs * np.cos(Lambda))
vbo = 10.8
```

Solution

I would manually "eliminate common sub-expressions" in the deriv:

def deriv(u, dt):
    # Denominators shared by the dotu[3] and dotu[4] equations, computed once.
    norm1 = np.sqrt(((u[0] + pi2 * r12) ** 2 + u[1] ** 2) ** 3)
    norm2 = np.sqrt(((u[0] - pi1 * r12) ** 2 + u[1] ** 2) ** 3)
    return [
        u[3],  # dotu[0] = u[3]
        u[4],  # dotu[1] = u[4]
        u[5],  # dotu[2] = u[5]
        # dotu[3]
        (2 * omega * u[4] + omega ** 2 * u[0]
         - mue * (u[0] + pi2 * r12) / norm1
         - mum * (u[0] - pi1 * r12) / norm2),
        # dotu[4]
        (-2 * omega * u[3] + omega ** 2 * u[1]
         - mue * u[1] / norm1
         - mum * u[1] / norm2),
        0,  # dotu[5] = 0
    ]
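
The same idea can probably be pushed a little further. The following is only a sketch under a few assumptions: it uses `math.sqrt` instead of `np.sqrt` so that it runs with the standard library alone, it assumes the state `u` is a plain sequence of floats, and the names `deriv_answer`, `deriv_cse`, `omega_sq`, `two_omega`, `x_earth`, `x_moon` and the sample `state` are made up for illustration and do not appear in the question. It hoists everything that does not depend on `u` out of the function and checks that the result still agrees with the version above.

```python
import math

# Constants as given in the question (km, kg, s).
me = 5.974e24    # mass of the Earth
mm = 7.348e22    # mass of the Moon
G = 6.67259e-20  # gravitational constant, km^3 / (kg s^2)
r12 = 384400.0   # distance between the Earth and Moon centres of mass
M = me + mm
pi1 = me / M
pi2 = mm / M
mue = 398600.0   # gravitational parameter of the Earth
mum = G * mm     # gravitational parameter of the Moon
omega = math.sqrt((mue + mum) / r12 ** 3)

# Values that never change between calls, computed once at import time.
# (Hypothetical names, not from the original post.)
omega_sq = omega ** 2
two_omega = 2 * omega
x_earth = pi2 * r12  # offset added to u[0] in the mue term
x_moon = pi1 * r12   # offset subtracted from u[0] in the mum term


def deriv_answer(u, dt):
    """Right-hand side as written in the answer, with math.sqrt swapped in."""
    norm1 = math.sqrt(((u[0] + pi2 * r12) ** 2 + u[1] ** 2) ** 3)
    norm2 = math.sqrt(((u[0] - pi1 * r12) ** 2 + u[1] ** 2) ** 3)
    return [
        u[3],
        u[4],
        u[5],
        (2 * omega * u[4] + omega ** 2 * u[0]
         - mue * (u[0] + pi2 * r12) / norm1
         - mum * (u[0] - pi1 * r12) / norm2),
        (-2 * omega * u[3] + omega ** 2 * u[1]
         - mue * u[1] / norm1
         - mum * u[1] / norm2),
        0,
    ]


def deriv_cse(u, dt):
    """Same equations, with every repeated sub-expression evaluated once."""
    x, y, vx, vy = u[0], u[1], u[3], u[4]
    dx1 = x + x_earth
    dx2 = x - x_moon
    y_sq = y * y
    r1_sq = dx1 * dx1 + y_sq
    r2_sq = dx2 * dx2 + y_sq
    # r_sq * sqrt(r_sq) equals sqrt(r_sq ** 3) up to rounding.
    k1 = mue / (r1_sq * math.sqrt(r1_sq))
    k2 = mum / (r2_sq * math.sqrt(r2_sq))
    return [
        vx,
        vy,
        u[5],
        two_omega * vy + omega_sq * x - k1 * dx1 - k2 * dx2,
        -two_omega * vx + omega_sq * y - (k1 + k2) * y,
        0.0,
    ]


# Arbitrary sample state for the comparison; NOT the question's initial conditions.
state = [6678.0, 100.0, 0.0, 1.0, 10.8, 0.0]
reference = deriv_answer(state, 0.0)
candidate = deriv_cse(state, 0.0)
matches = all(
    math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-15)
    for a, b in zip(reference, candidate)
)
print("results agree:", matches)
```

Whether this is measurably faster depends on the interpreter and on how the solver calls the function, so it is worth timing both versions, for example with `timeit.timeit(lambda: deriv_cse(state, 0.0), number=100_000)`, before committing to either.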


By the way, should the tan(0) in f(E0) actually be tan(nu0 / 2)?

I am not a Python programmer.

Context

StackExchange Code Review Q#26314, answer score: 3
