Fastest computation of N likelihoods on normal distributions
computation, normal-distributions, likelihoods, fastest
Problem
In the context of a Gibbs sampler, I profiled my code and my major bottleneck is the following:
I need to compute the likelihood of N points assuming they have been drawn from N normal distributions (with different means but same variance).
Here are two ways to compute it:
import numpy as np
from scipy.stats import multivariate_normal
from scipy.stats import norm
# Toy data
y = np.random.uniform(low=-1, high=1, size=100) # data points
loc = np.zeros(len(y)) # means
# Two alternatives
%timeit multivariate_normal.logpdf(y, mean=loc, cov=1)
%timeit sum(norm.logpdf(y, loc=loc, scale=1))

- The first: use the recently implemented multivariate_normal of scipy. Build the equivalent N-dimensional Gaussian and compute the (log-)probability of an N-dimensional y.
  1000 loops, best of 3: 1.33 ms per loop
- The second: use the traditional norm function of scipy. Compute the individual (log-)probability of every point y and then sum the results.
  10000 loops, best of 3: 130 µs per loop
Since this is part of a Gibbs sampler, I need to repeat this computation around 10,000 times, and therefore I need it to be as fast as possible.
How can I improve it (either from Python, or by calling Cython, R, or whatever)?
Solution
You should use a line profiler tool to examine what the slowest parts of the code are. It sounds like you did that for your own code, but you could keep going and profile the source code that NumPy and SciPy use when calculating your quantity of interest. The
Line profiler module is my favorite.

import numpy as np
from scipy.stats import multivariate_normal
from scipy.stats import norm
%lprun -f norm.logpdf norm.logpdf(x=np.random.random(1000000), \
loc=np.random.random(1000000), \
                          scale = np.random.random())

Timer unit: 1e-06 s
Total time: 0.14831 s
File: /opt/local/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/scipy/stats/_distn_infrastructure.py
Function: logpdf at line 1578
Line # Hits Time Per Hit % Time Line Contents
==============================================================
1578 def logpdf(self, x, *args, **kwds):
1579 """
1580 Log of the probability density function at x of the given RV.
1581
1582 This uses a more numerically accurate calculation if available.
1583
1584 Parameters
1585 ----------
1586 x : array_like
1587 quantiles
1588 arg1, arg2, arg3,... : array_like
1589 The shape parameter(s) for the distribution (see docstring of the
1590 instance object for more information)
1591 loc : array_like, optional
1592 location parameter (default=0)
1593 scale : array_like, optional
1594 scale parameter (default=1)
1595
1596 Returns
1597 -------
1598 logpdf : array_like
1599 Log of the probability density function evaluated at x
1600
1601 """
1602 1 14 14.0 0.0 args, loc, scale = self._parse_args(*args, **kwds)
1603 1 23 23.0 0.0 x, loc, scale = map(asarray, (x, loc, scale))
1604 1 2 2.0 0.0 args = tuple(map(asarray, args))
1605 1 13706 13706.0 9.2 x = asarray((x-loc)*1.0/scale)
1606 1 33 33.0 0.0 cond0 = self._argcheck(*args) & (scale > 0)
1607 1 5331 5331.0 3.6 cond1 = (scale > 0) & (x >= self.a) & (x <= self.b)
1608 1 5625 5625.0 3.8 cond = cond0 & cond1
1609 1 84 84.0 0.1 output = empty(shape(cond), 'd')
1610 1 6029 6029.0 4.1 output.fill(NINF)
1611 1 11459 11459.0 7.7 putmask(output, (1-cond0)+np.isnan(x), self.badvalue)
1612 1 1093 1093.0 0.7 if any(cond):
1613 1 58499 58499.0 39.4 goodargs = argsreduce(cond, *((x,)+args+(scale,)))
1614 1 6 6.0 0.0 scale, goodargs = goodargs[-1], goodargs[:-1]
1615 1 46401 46401.0 31.3 place(output, cond, self._logpdf(*goodargs) - log(scale))
1616 1 4 4.0 0.0 if output.ndim == 0:
1617 return output[()]
1618 1 1 1.0 0.0 return output

It looks like a not-insignificant amount of time is being spent checking and removing invalid arguments from the function input. If you can be sure you will never need to use that feature, just write your own function to calculate the logpdf.
Plus, if you are going to be multiplying probabilities (i.e. adding log probabilities), you could use algebra to simplify and factor out common terms from the summand for the normal distribution's pdf. That will lower the number of function calls to `np.log`.
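To make the factoring concrete, here is a quick sanity check (a sketch on made-up toy values; the names `direct` and `factored` are illustrative) showing that pulling the common -log(scale * sqrt(2*pi)) term out of the sum gives the same answer as summing norm.logpdf directly, while calling np.log once instead of N times:

```python
import numpy as np
from scipy.stats import norm

# Toy data: N points with per-point means and a shared standard deviation
x = np.random.uniform(low=-1, high=1, size=100)
loc = np.random.normal(size=100)
scale = 0.7

# Direct sum of per-point log-densities via scipy
direct = norm.logpdf(x, loc=loc, scale=scale).sum()

# Factored form: the term -log(scale * sqrt(2*pi)) is identical in every
# summand, so compute it once and multiply by N, then add the quadratic part
factored = (-x.size * np.log(scale * np.sqrt(2 * np.pi))
            - np.sum(np.square((x - loc) / scale)) / 2)

assert np.allclose(direct, factored)
```

The same identity is what the hand-rolled function in the Code Snippets section relies on.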
Code Snippets
def my_logpdf_sum(x, loc, scale):
root2 = np.sqrt(2)
root2pi = np.sqrt(2*np.pi)
prefactor = - x.size * np.log(scale * root2pi)
summand = -np.square((x - loc)/(root2 * scale))
return prefactor + summand.sum()
# toy data
y = np.random.uniform(low=-1, high=1, size=1000) # data points
loc = np.zeros(y.shape)
# timing
%timeit multivariate_normal.logpdf(y, mean=loc, cov=1)
%timeit np.sum(norm.logpdf(y, loc=loc, scale=1))
%timeit my_logpdf_sum(y, loc, 1)
1 loops, best of 3: 156 ms per loop
10000 loops, best of 3: 125 µs per loop
The slowest run took 4.55 times longer than the fastest. This could mean that an intermediate result is being cached
100000 loops, best of 3: 16.3 µs per loop
Context
StackExchange Code Review Q#69718, answer score: 4