Attempting to run multiple simulations of the Gillespie algorithm for a set of stochastic chemical reactions in less than 10 minutes

Submitted by: @import:stackexchange-codereview

Problem

I have written Python code that generates a plot. In the code below, when I set maxtime = 0.1, the program takes ~50s and when I set maxtime = 1, it takes ~420s. I need to run it for maxtime = 1000.

I am familiar with Python syntax and writing "Matlabic" code, but am lost in writing natively "Pythonic" code. As a result, I need help in optimizing this code for runtime, specifically in the two outer for loops and inner while loop.

  • How can I make the code suitable for use with Numba or Cython?
  • If that's not possible, do I need to use functions, or map, or lambda statements?

Unfortunately, my Spyder IDE for Python is freezing up every time I try to profile the code. I would include those details if I could!

```python
import numpy as np
import matplotlib.pyplot as plt
import pylab as pl
import math

maxtime = 1
Qpvec = np.logspace(-2,1,101,10)
lenQp = len(Qpvec)
delta = 1e-3
P0vec = Qpvec/delta
SimNum = 100
PTotStoch = np.zeros((SimNum,lenQp))
k_minus = 1e-3
k_cat = 1-1e-3
k_plus = 1e-3
zeta = 1e-4
D0 = 10000
kappa_M = (k_cat+k_minus)/zeta
QpDeg = np.true_divide(1000*D0*Qpvec*k_cat,1000*Qpvec + kappa_M)

for lenQpInd in range(0,lenQp):
    for SimNumInd in range(0,SimNum):
        Qp = Qpvec[lenQpInd]
        P0 = P0vec[lenQpInd]
        DP0 = 0
        P = math.floor(P0)
        DP = DP0
        D = D0
        time = 0

        while time < maxtime:
            # ASSUMPTION: the lines from here down to the `if time > maxtime`
            # test were lost from this copy. They are reconstructed from the
            # names the answer refers to (pl.rand, np.array, np.cumsum,
            # u_event, kT) and from the answer's rewritten loop.
            u_time = pl.rand()
            u_event = pl.rand()
            rates = np.array([Qp,zeta*P*D,k_minus*DP,k_cat*DP])
            PTot = P + DP
            kT = np.sum(rates)
            tot = np.cumsum(rates)
            time += -np.log(1-u_time)/kT
            if time > maxtime:
                PTotStoch[SimNumInd,lenQpInd] = PTot
                break
            elif u_event*kT < tot[0]:
                P += 1
            elif u_event*kT < tot[1]:
                P -= 1
                DP += 1
                D -= 1
            elif u_event*kT < tot[2]:
                P += 1
                DP -= 1
                D += 1
            elif u_event*kT < tot[3]:
                DP -= 1
                D += 1

PMean = PTotStoch.mean(axis=0)
PStd = PTotStoch.std(axis=0)

plt.figure(0)
plt.plot(Qpvec,PMean,marker=".")
plt.errorbar(Qpvec,PMean,yerr = 2*PStd, xerr = None)
plt.show()
```

Solution

Pythonic code follows PEP 8. Mostly, you need to stick to one naming style for each kind of name.
In Python they are:

  • snake_case for variables (functions are variables too).
  • UPPER_SNAKE_CASE for constants.
  • CamelCase for classes.

There are a few whitespace conventions that you didn't follow, such as spaces around operators and one space after commas.
But all in all you have some nice style.
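
To make the three naming styles concrete, here is a minimal sketch. The class `ReactionState` and the function `total_protein` are hypothetical names invented for illustration; they do not appear in the question's code.

```python
# Constants: UPPER_SNAKE_CASE
MAX_TIME = 1
SIM_NUM = 100


# Classes: CamelCase (hypothetical class, for illustration only)
class ReactionState:
    def __init__(self, p, d, dp):
        self.p = p
        self.d = d
        self.dp = dp


# Functions and variables: snake_case, with spaces around operators
# and one space after each comma
def total_protein(state):
    """Return the sum of p and dp for the given state."""
    return state.p + state.dp


initial_state = ReactionState(p=10, d=10000, dp=0)
p_tot = total_protein(initial_state)
print(p_tot)
```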

First things first, don't use pylab. It's not supported;
in fact, I could only find one archive page about it, and that page said to use matplotlib instead.

I assume that pylab.rand is the same as numpy.random.

Due to the above style concerns I found your code hard to read.
There are also a lot of variables that are written once and read once.
This makes the code harder to follow, as you then have to remember opaque names such as u_event or kT.

I found the inner part of the while loop hard to read. Instead, split it into two sections: the time update, and the changes to p, d and dp.
This made it easier for me to read and understand.
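
One way to picture that split is as two small helpers, one per section. This is only a sketch using the standard library: the function names `waiting_time` and `apply_event` are assumptions made for illustration, and the reaction labels in the comments are one reading of the update rules, not something the question states.

```python
import math


def waiting_time(total_rate, u):
    """Time section: exponential waiting time from a uniform draw u in [0, 1)."""
    return -math.log(1 - u) / total_rate


def apply_event(p, d, dp, event):
    """State section: return the updated (p, d, dp) after reaction `event`."""
    if event == 0:                       # assumed: a new P is produced
        return p + 1, d, dp
    if event == 1:                       # assumed: P binds D to form DP
        return p - 1, d - 1, dp + 1
    if event == 2:                       # assumed: DP splits back into P and D
        return p + 1, d + 1, dp - 1
    return p, d + 1, dp - 1              # assumed: DP is consumed, D is freed


# Hypothetical usage: advance time, then change the state.
time = waiting_time(total_rate=2.0, u=0.5)
p, d, dp = apply_event(5, 10, 2, event=1)
print(time, p, d, dp)
```

Function calls carry their own overhead in CPython, so in a hot loop it may be preferable to keep the two sections inline and separate them visually, as the rewritten code further down does.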

After doing the above I ran a profiler over the code: NumPy's array and cumsum were dramatic performance sinks.
To resolve this I implemented cumsum in plain Python and used a Python list.
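
The replacement for cumsum can be sketched in isolation as below. The helper name `running_totals` and the sample rates are made up for illustration; `itertools.accumulate` from the standard library is shown only as a cross-check and is not what the answer uses.

```python
from itertools import accumulate


def running_totals(rates):
    """Cumulative sum of a short list, computed in plain Python."""
    totals = list(rates)
    for i in range(len(totals) - 1):
        totals[i + 1] += totals[i]
    return totals


rates = [0.5, 1.0, 0.25, 2.0]          # illustrative values only
totals = running_totals(rates)
print(totals)                           # the last entry is the total rate
print(totals == list(accumulate(rates)))
```

For a list of only four elements, the cost of creating a NumPy array and dispatching to its C routines can outweigh the work itself, which is consistent with what the profiler showed.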

This led to a 260% increase in speed for max_time = 0.1.
Note that I don't time the generation of P0_VEC, as I made it a global constant in my code but didn't in yours.
(So mine has an unfair advantage.)

Also, as @oliverpool says in the comments, you can get a speed-up from using functions. (I put both my code and your code in a function, so the timings weren't extremely unfair.)

I doubt that you can make the code much faster, as it seems to be $O(n)$ already.
With my change you are doing 309057 function calls, or 214921 a second.
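
The speed-up from functions comes from how CPython resolves names: inside a function, variables are local slots, while module-level names go through a dictionary lookup. The sketch below runs the same toy loop both ways. The loop body is arbitrary and unrelated to the simulation, and the measured gap will vary with the machine and the interpreter version.

```python
import timeit

LOOP_SOURCE = """
total = 0
for i in range(100_000):
    total += i * 2
"""
module_code = compile(LOOP_SOURCE, "<module-level>", "exec")


def run_at_module_level():
    """Execute the loop as module-level code: names live in a dict."""
    namespace = {}
    exec(module_code, namespace)
    return namespace["total"]


def run_in_function():
    """Execute the same loop inside a function: names are local slots."""
    total = 0
    for i in range(100_000):
        total += i * 2
    return total


module_seconds = timeit.timeit(run_at_module_level, number=5)
function_seconds = timeit.timeit(run_in_function, number=5)
print(f"module level: {module_seconds:.3f}s, in function: {function_seconds:.3f}s")
```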

import numpy as np
import matplotlib.pyplot as plt
random = np.random

# Testing switches: set both to False (or remove them) for normal runs.
PROFILE = True
TIME = True

# Global constants
DELTA = 1e-3
SIM_NUM = 100
K_MINUS = 1e-3
K_CAT = 1 - 1e-3
ZETA = 1e-4
D = 10000
QP_VEC = np.logspace(-2, 1, 101, base=10)  # 4th positional arg is endpoint, not base
KAPPA_M = (K_CAT + K_MINUS) / ZETA
P0_VEC = QP_VEC / DELTA

def gillespie(max_time=0.01):
    p_tot_stoch = np.zeros((SIM_NUM, len(QP_VEC)))
    for len_qp_ind in range(len(QP_VEC)):
        qp = QP_VEC[len_qp_ind]
        p0 = int(P0_VEC[len_qp_ind])
        for sim_num_ind in range(SIM_NUM):
            p = p0
            d = D
            dp = time = 0

            while True:
                tot = [qp, ZETA * p * d, K_MINUS * dp, K_CAT * dp]
                for i in range(3):
                    tot[i + 1] += tot[i]
                kt = tot[-1]
                time += -np.log(1 - random.random()) / kt
                if time > max_time:
                    p_tot_stoch[sim_num_ind, len_qp_ind] = p + dp
                    break

                event_kt = random.random() * kt
                if event_kt < tot[0]:
                    p += 1
                elif event_kt < tot[1]:
                    p -= 1
                    dp += 1
                    d -= 1
                elif event_kt < tot[2]:
                    p += 1
                    dp -= 1
                    d += 1
                elif event_kt < tot[3]:
                    dp -= 1
                    d += 1

    return p_tot_stoch

if __name__ == '__main__' and PROFILE:
    import cProfile
    cProfile.run('gillespie()')

if __name__ == '__main__' and TIME:
    import timeit
    def gillespie_dot_one():
        return gillespie(0.1)
    def gillespie_one():
        return gillespie(1)
    print('gillespie(0.01)*100 = ', timeit.timeit(gillespie, number=100))
    print('gillespie(0.1)*10 = ', timeit.timeit(gillespie_dot_one, number=10))
    print('gillespie(1) = ', timeit.timeit(gillespie_one, number=1))

if __name__ == '__main__':
    p_tot_stoch = gillespie()
    p_mean = p_tot_stoch.mean(axis=0)
    p_std = p_tot_stoch.std(axis=0)

    plt.figure(0)
    plt.plot(QP_VEC, p_mean, marker=".")
    plt.errorbar(QP_VEC, p_mean, yerr=2 * p_std, xerr=None)
    plt.show()

Context

StackExchange Code Review Q#117903, answer score: 4
