Time-dependent state machine
Problem
I keep having to write state machines that depend on time for various experiments I run and I'd like to know how to write them better. This state machine is for training a neural network by feeding in keys and expected values.
import numpy as np

dt = 0.001
period = 0.1


class SimpleEnv(object):
    def __init__(self, keys, values, env_period=0.1):
        self.keys = keys
        self.values = values
        self.env_idx = np.arange(len(keys))
        self.idx = 0
        self.shuffled = False
        self.i_every = int(round(env_period/dt))
        if self.i_every != env_period/dt:
            raise ValueError("dt (%s) does not divide period (%s)" % (dt, period))

    def get_key(self):
        return self.keys[self.idx]

    def get_val(self):
        return self.values[self.idx]

    def step(self, t):
        i = int(round((t - dt)/dt))  # t starts at dt
        ix = (i/self.i_every) % len(self.keys)
        if ix == 0 and not self.shuffled:
            print("shuffling")
            np.random.shuffle(self.env_idx)
            self.shuffled = True
        elif ix == 1:
            self.shuffled = False
        self.idx = self.env_idx[ix]
        return ix

# note the toy keys and values for testing purposes
s_env = SimpleEnv(np.arange(4), np.arange(1, 5), env_period=period)
key = -1
val = -1
ix = -1
# iterate through keys and values twice
run_time = 4 * period * 2

# the event loop
# starts at dt because of reasons
for t in np.arange(dt, run_time, dt):
    last_ix = ix
    ix = s_env.step(t)
    key = s_env.get_key()
    val = s_env.get_val()
    assert key + 1 == val
    if last_ix != ix:
        print("Key: %s, Value: %s" % (key, val))

The results should look something like:
shuffling
Key: 2, Value: 3
Key: 0, Value: 1
Key: 3, Value: 4
Key: 1, Value: 2
shuffling
Key: 2, Value: 3
Key: 1, Value: 2
Key: 3, Value: 4
Key: 0, Value: 1

How can I write this better or more efficiently? Is there a state machine library in Python that would stop
Solution
- (Possible) bugs
There are several problematic aspects of the code in the post. However, I cannot tell whether they really are bugs: the code has no documentation, so I don't know exactly what it is supposed to do.
- The code is not portable to Python 3, because in Python 3 the / operator performs true division:
    ix = (i/self.i_every) % len(self.keys)
  This causes ix to be a float, which results in a warning when evaluating self.env_idx[ix]:
    VisibleDeprecationWarning: using a non-integer number instead of an integer will result in an error in the future
  (From NumPy 1.12 onwards, indexing with a float raises an IndexError instead.) The code needs to use the floor division operator, i // self.i_every, to ensure that ix is an integer.
- The shuffling logic is fragile: it relies on ix successively taking the value 0 (so that the array is shuffled) and then 1 (so that the shuffled flag is cleared). That only happens if exactly the right sequence of values of t is passed to step. If other values are passed, the array may never be re-shuffled:
    >>> e = SimpleEnv([1, 2, 3], [10, 20, 30])
    >>> for t in np.arange(dt, 100, period * 3):
    ...     i = e.step(t)
    ...
  Here ix is 0 at every step and never 1, so "shuffling" is printed only on the first call: the shuffled flag is never cleared, and the array is never re-shuffled.
- There is nothing about the interface that stops t from going backwards in time. But if you go backwards past a shuffle, you get different results for the same values of t:
    >>> e = SimpleEnv([1, 2, 3], [10, 20, 30])
    >>> e.step(period * 1), e.get_key()
    (0, 1)
    >>> e.step(period * 4), e.get_key()
    shuffling
    (0, 2)
    >>> e.step(period * 1), e.get_key()
    (0, 2)
  It would be better if the interface were designed so that unsupported actions (like going backwards in time) cannot be attempted.
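One way to get an interface that cannot go backwards in time is a Python generator, which can only be advanced. The following is a minimal sketch, not the post's design: the name key_value_stream and its parameters are hypothetical. It keeps an integer step counter and uses divmod, so the index is always an integer, and it reshuffles at the start of every pass through the keys without relying on a flag.

```python
import numpy as np
from itertools import islice


def key_value_stream(keys, values, dt=0.001, period=0.1, state=None):
    """Hypothetical sketch: yield (t, key, value) for t = dt, 2*dt, 3*dt, ...

    A generator can only be advanced, so callers cannot step backwards in time.
    The order of keys is re-shuffled at the start of every pass through them.
    """
    if state is None:
        state = np.random.RandomState()
    steps_per_period = int(round(period / dt))
    if not np.isclose(steps_per_period * dt, period):
        raise ValueError("dt (%s) does not divide period (%s)" % (dt, period))
    order = np.arange(len(keys))
    step = 0  # integer step counter, so no float is ever used as an index
    while True:
        period_index, offset = divmod(step, steps_per_period)
        pos = period_index % len(keys)
        if offset == 0 and pos == 0:
            order = state.permutation(len(keys))  # start of a new pass
        idx = order[pos]
        yield (step + 1) * dt, keys[idx], values[idx]
        step += 1


# Two passes through four toy keys, as in the post: 8 periods of 100 steps.
stream = key_value_stream(np.arange(4), np.arange(1, 5),
                          state=np.random.RandomState(0))
samples = list(islice(stream, 800))
for t, key, val in samples[::100]:  # first step of each period
    print("Key: %s, Value: %s" % (key, val))
```

The caller receives (t, key, value) triples strictly in order; there is no step(t) method that could be called with an earlier t. If random access by time is needed, precomputed arrays are a better fit.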
- Separation of concerns
An important software design principle is separation of concerns. The idea is that you separate your code into pieces, each of which has a single concern. This makes the pieces easier to understand and easier to test, and makes it more likely that you can reuse some of them.
In this code there are three concerns:
- Maintaining corresponding arrays of keys and values.
- Generating a sequence of shuffled indexes into the arrays of keys and values, where the indexes are re-shuffled after they have been used.
- Converting timestamps into entries in the sequence of shuffled indexes.
Let's see how these might be implemented:
- This concern seems very simple. We can just have two variables:
    keys = np.arange(4)
    values = np.arange(1, 5)
- In some circumstances it would make sense to have a state object which varies over time. But we're using NumPy, and a rule of thumb with NumPy is to do as much as possible with uniform arrays (even if these are big) rather than with iterators that have to evaluate code at each step. So let's write a function that constructs the sequence we need:
    def reshuffled_indexes(n, m, state=np.random):
        """Return an array of shape (n,) consisting of shuffled indexes in the
        range [0, m). The indexes are re-shuffled each time they have all
        been used. Optional argument state may be a np.random.RandomState
        instance for reproducibility.
        """
        out = np.empty((n,), dtype=int)
        r = n % m
        for i in range(0, n - r, m):
            out[i: i + m] = state.permutation(m)
        if r:
            out[-r:] = state.permutation(m)[:r]
        return out
  For example:
    >>> reshuffled_indexes(17, 4)
    array([1, 0, 2, 3, 2, 1, 0, 3, 0, 3, 1, 2, 0, 1, 3, 2, 2])
  Note the optional state argument. This is a useful argument to have on any function that produces a random result, because to test such a function you need to be able to make the random number sequence reproducible.
- The time series t is converted to an array of indexes by subtracting dt, dividing by dt and rounding to get an integer step count (as step does in the post), and then floor-dividing by the number of steps per period, so that each index stays the same for a whole period:
    t = np.arange(dt, run_time, dt)
    steps_per_period = int(round(period / dt))
    index_t = ((t - dt) / dt).round().astype(int) // steps_per_period
  Now that you have the array of indexes, you know how many reshuffled indexes you are going to need:
    r = reshuffled_indexes(index_t[-1] + 1, len(keys))
  and you can use this to index the arrays of keys and values to get time series for each:
    r_t = r[index_t]
    keys_t = keys[r_t]
    values_t = values[r_t]
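Putting the three pieces together, here is a self-contained sketch that reproduces the shape of the post's printed output (one line per period, two passes through four keys) from precomputed arrays instead of a stepped state machine. The fixed RandomState(0) seed and the use of np.searchsorted to find the first step of each period are illustrative choices, not part of the answer above.

```python
import numpy as np

dt = 0.001
period = 0.1
run_time = 4 * period * 2  # iterate through the four keys twice, as in the post


def reshuffled_indexes(n, m, state=np.random):
    """Shuffled indexes in [0, m), re-shuffled each time all have been used."""
    out = np.empty((n,), dtype=int)
    r = n % m
    for i in range(0, n - r, m):
        out[i:i + m] = state.permutation(m)
    if r:
        out[-r:] = state.permutation(m)[:r]
    return out


keys = np.arange(4)
values = np.arange(1, 5)

# Timestamps -> integer step counts -> period indexes.
t = np.arange(dt, run_time, dt)
steps_per_period = int(round(period / dt))
index_t = ((t - dt) / dt).round().astype(int) // steps_per_period

# One shuffled index per period, then look up keys and values per time step.
r = reshuffled_indexes(index_t[-1] + 1, len(keys), state=np.random.RandomState(0))
r_t = r[index_t]
keys_t = keys[r_t]
values_t = values[r_t]

# index_t is non-decreasing, so searchsorted finds the first step of each period.
starts = np.searchsorted(index_t, np.arange(index_t[-1] + 1))
for s in starts:
    print("Key: %s, Value: %s" % (keys_t[s], values_t[s]))
```

The printed keys will not match the post's listing, because the order is random; with a fixed RandomState the sequence is the same on every run, which is what makes it testable.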
Context
StackExchange Code Review Q#145471, answer score: 4