patternpythonModerate
Game of Life with NumPy
Viewed 0 times
lifegamewithnumpy
Problem
I started this exercise with NumPy with the goal of counting each cell's neighbors and returning the new matrix, and I would like your feedback. Here's an example from this website. It looks like it's $O(N^2)$, and I'm adding an inner loop to look at each cell's neighbors.
import numpy as np
import pprint
# Initial 5x5 world: a vertical "blinker" of three live cells in the middle
# column, surrounded by dead cells.
world = np.zeros((5, 5), dtype=int)
world[1:4, 2] = 1
pprint.pprint(world)
size = world.shape[0]
def next_state(world):
    """Return the next Game of Life generation of *world*.

    :param world: 2-D array-like of 0/1 cells (1 = alive); any rectangular
        shape is accepted, not only square grids.
    :return: a new int array of the same shape holding the next generation.
        Cells outside the grid are treated as permanently dead, so border
        cells are updated like every other cell.
    """
    world = np.asarray(world)
    rows, cols = world.shape
    # Surround the grid with a one-cell border of dead cells so every cell
    # has exactly 8 neighbor positions and no bounds checks are needed.
    padded = np.pad(world, 1, mode='constant')
    neighbors = np.zeros((rows, cols), dtype=int)
    for dr in (-1, 0, 1):
        for dc in (-1, 0, 1):
            if dr == 0 and dc == 0:
                continue  # a cell is not its own neighbor
            # padded[r + 1 + dr, c + 1 + dc] is world[r + dr, c + dc].
            neighbors += padded[1 + dr:1 + dr + rows, 1 + dc:1 + dc + cols]
    alive = world == 1
    # Born with exactly 3 neighbors; survives with 2 or 3.
    next_alive = (neighbors == 3) | (alive & (neighbors == 2))
    return next_alive.astype(int)
# Print the generation after the initial world; print() with a single
# argument behaves the same on Python 2 and Python 3.
print(next_state(world))
-
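That `next_state` function creates two brand-new NumPy arrays on every call. Creating NumPy arrays is slow; it should update an existing array instead.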
-
The code can be divided into two classes: one for the world and one for the engine. `World` can hold the world array and its visualization; `Engine` can hold the neighbor array.
-
In fact, the neighbor array can be much smaller than the world if the world is updated from left to right.
-
Looping over each element in Python (the `rows` and `cols` loops) is much slower than NumPy's vectorized operations. Neighbor counting can be vectorized by adding shifted copies of the world into the neighbor array:
.
Draw an animation of the world with matplotlib:
That `next_state` function creates two brand-new NumPy arrays on every call. Creating NumPy arrays is slow; it should update an existing array instead.
-
The code can be divided into two classes: one for the world and one for the engine. `World` can hold the world array and its visualization; `Engine` can hold the neighbor array.
-
In fact, the neighbor array can be much smaller than the world if the world is updated from left to right.
-
Looping over each element in Python (the `rows` and `cols` loops) is much slower than NumPy's vectorized operations. Neighbor counting can be vectorized by adding shifted copies of the world into the neighbor array:
.
# Vectorized neighbor count: add a shifted copy of ``world`` for each of the
# 8 directions. For example ``neighbor[1:] += world[:-1]`` adds
# world[r - 1, c] (the North neighbor) into neighbor[r, c]. Positions past
# the edge are simply never added, i.e. they count as dead.
neighbor = np.zeros(world.shape, dtype=int)
neighbor[1:] += world[:-1]  # North
neighbor[:-1] += world[1:]  # South
neighbor[:, 1:] += world[:, :-1]  # West
neighbor[:, :-1] += world[:, 1:]  # East
neighbor[1:, 1:] += world[:-1, :-1]  # NW
neighbor[1:, :-1] += world[:-1, 1:]  # NE
neighbor[:-1, 1:] += world[1:, :-1]  # SW
neighbor[:-1, :-1] += world[1:, 1:]  # SE
import numpy as np
import matplotlib.pyplot as plt
class World(object):
    """A grid of Game of Life cells held in a NumPy array.

    :param shape: shape of the cell array, passed straight to NumPy.
    :param random: if true, every cell starts as 0 or 1 at random;
        otherwise every cell starts dead (0).
    :param dtype: NumPy dtype of the cell array.
    """

    def __init__(self, shape, random=True, dtype=np.int8):
        if not random:
            cells = np.zeros(shape, dtype=dtype)
        else:
            cells = np.random.randint(0, 2, size=shape, dtype=dtype)
        self.data = cells
        self.shape = cells.shape
        self.dtype = dtype
        # Engine reads self.shape, so it must be created after shape is set.
        self._engine = Engine(self)
        self.step = 0

    def animate(self):
        """Return the generator produced by ``Animate(self).animate()``."""
        animator = Animate(self)
        return animator.animate()

    def __str__(self):
        # A nicer text rendering could go here; for now defer to NumPy.
        return str(self.data)
class Animate(object):
    """Drive a live matplotlib animation of a World.

    Each iteration of :meth:`animate` draws the world's current cells,
    advances the world by one generation, pauses briefly and yields the
    world.
    """

    def __init__(self, world):
        self.world = world
        self.im = None  # matplotlib image; created on the first frame

    def animate(self):
        """Yield the world forever, redrawing and advancing it each step."""
        while True:
            # Create the image on this animator's first frame. Testing
            # ``self.im`` rather than ``world.step == 0`` keeps a second
            # World.animate() call (step already > 0, im still None) from
            # calling set_data on None.
            if self.im is None:
                plt.ion()
                # NOTE(review): cells are 0/1, so vmax=2 renders live cells
                # mid-gray rather than white; vmax=1 may be what is intended.
                self.im = plt.imshow(self.world.data, vmin=0, vmax=2,
                                     cmap=plt.cm.gray)
            else:
                self.im.set_data(self.world.data)
            self.world.step += 1
            self.world._engine.next_state()
            plt.pause(0.01)
            yield self.world
class Engine(object):
    """Advance a world by one Game of Life generation, in place.

    :param world: object exposing ``shape`` and a 2-D ``data`` array of 0/1
        cells; ``data`` is modified in place by :meth:`next_state`.
    :param dtype: dtype of the reusable neighbor-count buffer.
    """

    def __init__(self, world, dtype=np.int8):
        self._world = world
        self.shape = world.shape
        # Allocated once and reused on every step.
        self.neighbor = np.zeros(world.shape, dtype=dtype)
        self._neighbor_id = self._make_neighbor_indices()

    def _make_neighbor_indices(self):
        """Build the 8 index tuples used to shift the grid by one cell.

        ``out[i]`` and ``out[7 - i]`` select opposite shifts, so adding
        ``data[out[7 - i]]`` into ``neighbor[out[i]]`` adds, for every cell,
        its neighbor in one of the 8 directions. Tuples (not lists) of
        slices are required: NumPy 1.23+ no longer treats a list of slices
        as a multidimensional index.
        """
        d = [slice(None), slice(1, None), slice(0, -1)]
        # Per axis: 0 = no shift, 1 = drop the first element, -1 = drop the
        # last element. These are 4 of the 8 directions; the other 4 are
        # obtained by negating both selectors.
        d2 = [(0, 1), (1, 1), (1, 0), (1, -1)]
        out = [None] * 8
        for i, (x, y) in enumerate(d2):
            out[i] = (d[x], d[y])
            out[7 - i] = (d[-x], d[-y])
        return out

    def _count_neighbors(self):
        """Fill ``self.neighbor`` with each cell's number of live neighbors."""
        n = self.neighbor
        n[...] = 0  # reset the counts from the previous step
        w = self._world.data
        n_id = self._neighbor_id
        for i in range(8):
            n[n_id[i]] += w[n_id[7 - i]]

    def _update_world(self):
        """Apply the Game of Life rules to the world's cells, in place."""
        w = self._world.data
        n = self.neighbor
        # The rules:
        #     cell        neighbors   cell's next state
        #     ---------   ---------   -----------------
        #  1. live        < 2         dead
        #  2. live        2 or 3      live
        #  3. live        > 3         dead
        #  4. dead        3           live
        # Simplified rules:
        #  1. live        2           live
        #  2. live/dead   3           live
        #  3. otherwise               dead
        w &= (n == 2)  # alive if it was alive and has 2 neighbors
        w |= (n == 3)  # alive if it has 3 neighbors

    def next_state(self):
        """Advance the world by one generation."""
        self._count_neighbors()
        self._update_world()
def main():
    """Animate a random 1000 x 1000 world until the process is stopped."""
    frames = World((1000, 1000)).animate()
    for _ in frames:
        pass  # each iteration draws and advances one generation


if __name__ == '__main__':
    main()
# Vectorized neighbor count: add a shifted copy of ``world`` for each of the
# 8 directions. For example ``neighbor[1:] += world[:-1]`` adds
# world[r - 1, c] (the North neighbor) into neighbor[r, c]. Positions past
# the edge are simply never added, i.e. they count as dead.
neighbor = np.zeros(world.shape, dtype=int)
neighbor[1:] += world[:-1]  # North
neighbor[:-1] += world[1:]  # South
neighbor[:, 1:] += world[:, :-1]  # West
neighbor[:, :-1] += world[:, 1:]  # East
neighbor[1:, 1:] += world[:-1, :-1]  # NW
neighbor[1:, :-1] += world[:-1, 1:]  # NE
neighbor[:-1, 1:] += world[1:, :-1]  # SW
neighbor[:-1, :-1] += world[1:, 1:]  # SE

import numpy as np
import matplotlib.pyplot as plt
class World(object):
    """A grid of Game of Life cells held in a NumPy array.

    :param shape: shape of the cell array, passed straight to NumPy.
    :param random: if true, every cell starts as 0 or 1 at random;
        otherwise every cell starts dead (0).
    :param dtype: NumPy dtype of the cell array.
    """

    def __init__(self, shape, random=True, dtype=np.int8):
        if not random:
            cells = np.zeros(shape, dtype=dtype)
        else:
            cells = np.random.randint(0, 2, size=shape, dtype=dtype)
        self.data = cells
        self.shape = cells.shape
        self.dtype = dtype
        # Engine reads self.shape, so it must be created after shape is set.
        self._engine = Engine(self)
        self.step = 0

    def animate(self):
        """Return the generator produced by ``Animate(self).animate()``."""
        animator = Animate(self)
        return animator.animate()

    def __str__(self):
        # A nicer text rendering could go here; for now defer to NumPy.
        return str(self.data)
class Animate(object):
    """Drive a live matplotlib animation of a World.

    Each iteration of :meth:`animate` draws the world's current cells,
    advances the world by one generation, pauses briefly and yields the
    world.
    """

    def __init__(self, world):
        self.world = world
        self.im = None  # matplotlib image; created on the first frame

    def animate(self):
        """Yield the world forever, redrawing and advancing it each step."""
        while True:
            # Create the image on this animator's first frame. Testing
            # ``self.im`` rather than ``world.step == 0`` keeps a second
            # World.animate() call (step already > 0, im still None) from
            # calling set_data on None.
            if self.im is None:
                plt.ion()
                # NOTE(review): cells are 0/1, so vmax=2 renders live cells
                # mid-gray rather than white; vmax=1 may be what is intended.
                self.im = plt.imshow(self.world.data, vmin=0, vmax=2,
                                     cmap=plt.cm.gray)
            else:
                self.im.set_data(self.world.data)
            self.world.step += 1
            self.world._engine.next_state()
            plt.pause(0.01)
            yield self.world
class Engine(object):
    """Advance a world by one Game of Life generation, in place.

    :param world: object exposing ``shape`` and a 2-D ``data`` array of 0/1
        cells; ``data`` is modified in place by :meth:`next_state`.
    :param dtype: dtype of the reusable neighbor-count buffer.
    """

    def __init__(self, world, dtype=np.int8):
        self._world = world
        self.shape = world.shape
        # Allocated once and reused on every step.
        self.neighbor = np.zeros(world.shape, dtype=dtype)
        self._neighbor_id = self._make_neighbor_indices()

    def _make_neighbor_indices(self):
        """Build the 8 index tuples used to shift the grid by one cell.

        ``out[i]`` and ``out[7 - i]`` select opposite shifts, so adding
        ``data[out[7 - i]]`` into ``neighbor[out[i]]`` adds, for every cell,
        its neighbor in one of the 8 directions. Tuples (not lists) of
        slices are required: NumPy 1.23+ no longer treats a list of slices
        as a multidimensional index.
        """
        d = [slice(None), slice(1, None), slice(0, -1)]
        # Per axis: 0 = no shift, 1 = drop the first element, -1 = drop the
        # last element. These are 4 of the 8 directions; the other 4 are
        # obtained by negating both selectors.
        d2 = [(0, 1), (1, 1), (1, 0), (1, -1)]
        out = [None] * 8
        for i, (x, y) in enumerate(d2):
            out[i] = (d[x], d[y])
            out[7 - i] = (d[-x], d[-y])
        return out

    def _count_neighbors(self):
        """Fill ``self.neighbor`` with each cell's number of live neighbors."""
        n = self.neighbor
        n[...] = 0  # reset the counts from the previous step
        w = self._world.data
        n_id = self._neighbor_id
        for i in range(8):
            n[n_id[i]] += w[n_id[7 - i]]

    def _update_world(self):
        """Apply the Game of Life rules to the world's cells, in place."""
        w = self._world.data
        n = self.neighbor
        # The rules:
        #     cell        neighbors   cell's next state
        #     ---------   ---------   -----------------
        #  1. live        < 2         dead
        #  2. live        2 or 3      live
        #  3. live        > 3         dead
        #  4. dead        3           live
        # Simplified rules:
        #  1. live        2           live
        #  2. live/dead   3           live
        #  3. otherwise               dead
        w &= (n == 2)  # alive if it was alive and has 2 neighbors
        w |= (n == 3)  # alive if it has 3 neighbors

    def next_state(self):
        """Advance the world by one generation (called by the animator)."""
        self._count_neighbors()
        self._update_world()
StackExchange Code Review Q#160802, answer score: 10
Revisions (0)
No revisions yet.