Generate dictionary of points on n-sphere
Problem
This is a long shot, but my question is simply to optimize this particular function in some code I have written:
import numpy as np
from numpy.core.umath_tests import inner1d  # private NumPy module; removed in newer releases

def getSpherePoint(dim, pList, psetLen):
    npgauss = np.random.standard_normal       # cache the lookup locally
    pt = abs(npgauss(psetLen))
    pt = pt / np.sqrt(inner1d(pt, pt))
    d = dict(zip(pList, pt))
    return d

Note that dim is some integer \$n\$, pList is simply a list of \$2^n-1\$ strings, and psetLen is \$2^n-1\$. This function is designed to return a dictionary keyed by elements of pList with values determined by a point on the unit \$n\$-hypersphere.

At this point, I need to run this function around a billion times or so to get the results I want, which will take hours. I've done as much optimization as possible here, and my micro-optimization seems to indicate that the three major parts of the function (the random number, the normalization, and the dictionary association) each take approximately a third of the run-time, and I'm not sure how to reduce it further. I've squeezed a bit out of Cythonizing it, but I'm not sure how much of an improvement that will give me compared to actually writing it in C and then importing the C function. Unfortunately, I don't know C (or C++) and haven't written Java in years, so I'm pretty stuck using Python (or extensions thereof).
Is it possible to optimize this further by any significant factor? Would changing languages help dramatically? What improvements could be made in Python?
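For context, here is a minimal sanity-check sketch of what one call produces, using a hypothetical pList of \$2^2-1 = 3\$ strings (inner1d is swapped for pt.dot(pt), which computes the same dot product for 1-D arrays, since numpy.core.umath_tests is unavailable in recent NumPy releases):

```python
import numpy as np

def getSpherePoint(dim, pList, psetLen):
    # Function under review: sample |N(0, 1)| coordinates, normalize
    # to unit length, and key the result by the given strings.
    pt = abs(np.random.standard_normal(psetLen))
    pt = pt / np.sqrt(pt.dot(pt))  # pt.dot(pt) == inner1d(pt, pt) for 1-D pt
    return dict(zip(pList, pt))

labels = ["a", "b", "c"]               # hypothetical pList for n = 2
d = getSpherePoint(2, labels, len(labels))
vec = np.array([d[k] for k in labels])
print(sorted(d) == labels)             # dictionary is keyed by pList
print(np.isclose(vec.dot(vec), 1.0))   # values form a point on the unit sphere
```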
Solution
Style

Following PEP 8, the official Python style guide, you should use snake_case for your variable and function names. You should also put a space after commas.

I would also rename the function sphere_points, as the "get" is implied by the fact that you call a function (to get a result).

You also happen not to use the dim parameter and, looking at your description, the last parameter is just len(pList). Unless this length is common to a lot of calls and there is a real benefit to caching it, you can drop 2 out of 3 parameters.

numpy

numpy has vectorized functions for your main operations: numpy.fabs (or numpy.absolute if you're dealing with complex numbers) and numpy.linalg.norm. Using these might speed things up. Unfortunately, the part that goes back into the pure-Python realm (dict + zip) can't be sped up that way.

However, as pointed out by @200_success in a comment, these answers suggest using np.sqrt(pt.dot(pt)) or np.sqrt(np.einsum('i,i', pt, pt)) for faster results than np.linalg.norm. As always, profile and choose what's best for your use case.

Cache function access

In Python, local symbols are faster to resolve, which means np.linalg.norm is somewhat slow to look up, while a local norm is faster; and provided you performed norm = np.linalg.norm beforehand, it produces the same result.

But caching a complex lookup within the same call that uses it does not add much, since the complex lookup still needs to be performed. One way to do it efficiently is to use default values for arguments: the lookup is performed only once, at function definition, and the speedup occurs at each call.
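To back up the "profile and choose what's best" advice above, here is a minimal timeit sketch comparing the three norm computations (the array size and repeat count are arbitrary choices, not from the original answer):

```python
import timeit
import numpy as np

pt = np.abs(np.random.standard_normal(1023))  # 2**10 - 1 coordinates
norm = np.linalg.norm                         # local alias, as suggested above

candidates = {
    "np.linalg.norm":    lambda: norm(pt),
    "sqrt(pt.dot(pt))":  lambda: np.sqrt(pt.dot(pt)),
    "sqrt(einsum)":      lambda: np.sqrt(np.einsum('i,i', pt, pt)),
}

for name, fn in candidates.items():
    t = timeit.timeit(fn, number=10_000)
    print(f"{name:18s} {t:.4f}s")

# All three agree numerically, so the choice is purely about speed.
results = [fn() for fn in candidates.values()]
print(np.allclose(results, results[0]))
```

Relative timings vary by NumPy version and array size, which is exactly why measuring on your own workload matters.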
Proposed improvements
import numpy as np

def sphere_points(keys, random_gen=np.random.standard_normal,
                  absolute=np.fabs, norm=np.linalg.norm):
    pt = absolute(random_gen(len(keys)))
    pt = pt / norm(pt)
    return dict(zip(keys, pt))
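A quick usage sketch of the rewritten function, with a hypothetical key list (the definition is restated so the snippet is self-contained). Note that the defaulted arguments are bound once, when def executes, which is where the per-call speedup comes from:

```python
import numpy as np

def sphere_points(keys, random_gen=np.random.standard_normal,
                  absolute=np.fabs, norm=np.linalg.norm):
    # Default arguments are evaluated once at definition time, so each
    # call resolves random_gen/absolute/norm as fast local names.
    pt = absolute(random_gen(len(keys)))
    pt = pt / norm(pt)
    return dict(zip(keys, pt))

d = sphere_points(["x", "y", "z"])      # hypothetical key list
vec = np.fromiter(d.values(), dtype=float)
print(np.isclose(vec.dot(vec), 1.0))    # unit-length point, as before
```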
Context
StackExchange Code Review Q#136009, answer score: 4