HiveBrain v1.2.0
Get Started
← Back to all entries
patternpythonModerate

Game of Life with NumPy

Submitted by: @import:stackexchange-codereview··
0
Viewed 0 times
lifegamewithnumpy

Problem

I started this exercise with NumPy with the goal of finding each cell's neighbors and returning the new matrix, and I would like your feedback. Here's an example from this website. It looks like it's $O(N^2)$, and I'm adding an internal loop to look at the neighbors of each cell.

import numpy as np
import pprint

# Starting position: a vertical "blinker" (three live cells in column 2)
# in the middle of an otherwise empty 5x5 board.
world = np.zeros((5, 5), dtype=int)
world[1:4, 2] = 1

pprint.pprint(world)
size = len(world)  # side length of the square board

def next_state(world, verbose=False):
    """Return the next Game of Life generation of ``world``.

    Only interior cells are updated: every cell on the outer border of the
    returned array is 0 (edges are ignored, as in the original loop version).

    :param world: 2-D array-like of 0/1 cells; it does not have to be square.
    :param verbose: if true, pretty-print the neighbor-count array (interior
        counts, zero border) before returning -- the debug output the
        original version always printed.
    :return: a new integer array with the same shape as ``world``.
    """
    world = np.asarray(world)
    n_rows, n_cols = world.shape
    neighbors = np.zeros((n_rows, n_cols), dtype=int)
    new_world = np.zeros((n_rows, n_cols), dtype=int)

    # A grid smaller than 3x3 has no interior cells to update.
    if n_rows >= 3 and n_cols >= 3:
        # ``counts`` is a view, so adding into it fills ``neighbors`` in place.
        counts = neighbors[1:-1, 1:-1]
        for di in (-1, 0, 1):
            for dj in (-1, 0, 1):
                if di == 0 and dj == 0:
                    continue  # a cell is not its own neighbor
                # Shifted view that lines up the (di, dj) neighbor of every
                # interior cell with that cell.
                counts += world[1 + di:n_rows - 1 + di,
                                1 + dj:n_cols - 1 + dj]
        alive = world[1:-1, 1:-1] == 1
        # Born with exactly 3 neighbors, or survives with exactly 2.
        new_world[1:-1, 1:-1] = (counts == 3) | (alive & (counts == 2))

    if verbose:
        pprint.pprint(neighbors)
    return new_world

# Parenthesised call form works on both Python 2 and Python 3.
print(next_state(world))

Solution

-
The `next_state` function creates two brand-new NumPy arrays on every call. Creating NumPy arrays is slow; it is better to update an existing array in place.

-
The code can be divided into two classes: one for the world and one for the engine. The World class can hold the world array and its visualization; the Engine can hold the neighbor array.

-
In fact, the neighbor array could be much smaller than the world if the world were updated from left to right.

-
Looping over each element in Python (the row and col loops) is much slower than using NumPy's vectorized operations. Neighbor counting can be vectorized by shifting the world array in each of the eight directions and adding it to the neighbor array:

.

# Count the eight neighbors of every cell by adding shifted copies of the
# world; cells beyond the border contribute nothing (treated as dead).
neighbor = np.zeros(world.shape, dtype=int)
neighbor[1:] += world[:-1]  # North
neighbor[:-1] += world[1:]  # South
neighbor[:, 1:] += world[:, :-1]  # West
neighbor[:, :-1] += world[:, 1:]  # East

neighbor[1:, 1:] += world[:-1, :-1]  # NW
neighbor[1:, :-1] += world[:-1, 1:]  # NE
neighbor[:-1, 1:] += world[1:, :-1]  # SW
neighbor[:-1, :-1] += world[1:, 1:]  # SE


A complete version that draws an animation of the world with matplotlib:

import numpy as np
import matplotlib.pyplot as plt

class World(object):
    """A Game of Life board together with the engine that advances it.

    :param shape: shape of the cell grid, e.g. ``(rows, cols)``.
    :param random: when true, seed the grid with random 0/1 cells;
        otherwise start with every cell dead.
    :param dtype: numpy dtype used to store the cells.
    """

    def __init__(self, shape, random=True, dtype=np.int8):
        if not random:
            cells = np.zeros(shape, dtype=dtype)
        else:
            cells = np.random.randint(0, 2, size=shape, dtype=dtype)
        self.data = cells
        self.shape = cells.shape
        self.dtype = dtype
        self._engine = Engine(self)

        self.step = 0  # generation counter, starts at 0

    def animate(self):
        """Return a generator that displays and advances this world."""
        animator = Animate(self)
        return animator.animate()

    def __str__(self):
        # Plain numpy rendering of the cell grid.
        return str(self.data)

class Animate(object):
    """Render successive generations of a World with matplotlib.

    :param world: the World to display; it is advanced in place through its
        engine on every frame.
    """

    def __init__(self, world):
        self.world = world
        self.im = None  # matplotlib image, created lazily on the first frame

    def animate(self):
        """Yield the world forever, advancing it one generation per frame.

        Each iteration draws the current state, advances the world by one
        generation, pauses briefly so matplotlib can redraw, and then yields
        the (already advanced) world.
        """
        while True:
            # Create the image the first time *this* animator draws. Testing
            # self.im instead of world.step means an animator created for a
            # world that has already been stepped (e.g. a second call to
            # World.animate()) still builds its image instead of calling
            # set_data on None.
            if self.im is None:
                plt.ion()
                # vmin=0, vmax=2 with a gray colormap draws live cells (1)
                # as mid-grey.
                self.im = plt.imshow(self.world.data, vmin=0, vmax=2,
                                     cmap=plt.cm.gray)
            else:
                self.im.set_data(self.world.data)

            self.world.step += 1
            self.world._engine.next_state()
            plt.pause(0.01)
            yield self.world

class Engine(object):
    """Advance a World by one Game of Life generation at a time.

    The update is applied in place to ``world.data``. A neighbor-count buffer
    allocated once in ``__init__`` is reused on every step. Cells outside
    the grid contribute nothing to the counts, so they behave as dead
    (there is no wrap-around).

    :param world: object with a 2-D ``data`` array of 0/1 cells and a
        matching ``shape`` attribute (e.g. a World).
    :param dtype: dtype of the neighbor-count buffer; a cell has at most 8
        neighbors, so a small integer type is sufficient.
    """

    def __init__(self, world, dtype=np.int8):
        self._world = world
        self.shape = world.shape
        self.neighbor = np.zeros(world.shape, dtype=dtype)
        self._neighbor_id = self._make_neighbor_indices()

    def _make_neighbor_indices(self):
        """Return 8 index tuples describing shifted views of a 2-D array.

        ``out[i]`` and ``out[7 - i]`` are opposite shifts. Adding
        ``w[out[7 - i]]`` into ``n[out[i]]`` therefore adds, for every cell,
        the value of its neighbor in one of the eight directions. With
        (x, y) = (axis 0, axis 1), the entries are N, NE, E, SE, S, SW, W, NW.
        """
        # slice(None): whole axis; slice(1, None): drop the first element;
        # slice(0, -1): drop the last. d[-1] is d[2], so negating an offset
        # selects the opposite slice.
        d = [slice(None), slice(1, None), slice(0, -1)]
        d2 = [
            (0, 1), (1, 1), (1, 0), (1, -1)
        ]
        out = [None] * 8
        for i, (x, y) in enumerate(d2):
            # Tuples, not lists: NumPy >= 1.23 no longer interprets a list of
            # slices as a multidimensional index.
            out[i] = (d[x], d[y])
            out[7 - i] = (d[-x], d[-y])
        return out

    def _count_neighbors(self):
        """Fill ``self.neighbor`` with the live-neighbor count of each cell."""
        self.neighbor.fill(0)  # reset counts from the previous step
        w = self._world.data
        n = self.neighbor
        for i, target in enumerate(self._neighbor_id):
            n[target] += w[self._neighbor_id[7 - i]]

    def _update_world(self):
        """Apply the Game of Life rules to ``world.data`` in place."""
        w = self._world.data
        n = self.neighbor

        # Conway's rules:
        #    cell        neighbors   cell's next state
        #    ---------   ---------   -----------------
        # 1. live        < 2         dead
        # 2. live        2 or 3      live
        # 3. live        > 3         dead
        # 4. dead        3           live

        # Equivalent simplified rules:
        #    cell        neighbors   cell's next state
        #    ---------   ---------   -----------------
        # 1. live        2           live
        # 2. live/dead   3           live
        # 3. Otherwise, dead.

        # These bitwise updates rely on cells holding only 0 or 1.
        w &= (n == 2)  # keep live cells that have exactly 2 neighbors
        w |= (n == 3)  # any cell with exactly 3 neighbors is live

    def next_state(self):
        """Advance the world by one generation."""
        self._count_neighbors()
        self._update_world()

def main():
    """Animate a random 1000x1000 world until the process is stopped."""
    board = World((1000, 1000))
    frames = board.animate()
    for _ in frames:
        pass  # drawing and stepping happen inside the generator

if __name__ == '__main__':
    main()

Code Snippets

# Count the eight neighbors of every cell by adding shifted copies of the
# world; cells beyond the border contribute nothing (treated as dead).
neighbor = np.zeros(world.shape, dtype=int)
neighbor[1:] += world[:-1]  # North
neighbor[:-1] += world[1:]  # South
neighbor[:, 1:] += world[:, :-1]  # West
neighbor[:, :-1] += world[:, 1:]  # East

neighbor[1:, 1:] += world[:-1, :-1]  # NW
neighbor[1:, :-1] += world[:-1, 1:]  # NE
neighbor[:-1, 1:] += world[1:, :-1]  # SW
neighbor[:-1, :-1] += world[1:, 1:]  # SE
import numpy as np
import matplotlib.pyplot as plt


class World(object):
    """A Game of Life board together with the engine that advances it.

    :param shape: shape of the cell grid, e.g. ``(rows, cols)``.
    :param random: when true, seed the grid with random 0/1 cells;
        otherwise start with every cell dead.
    :param dtype: numpy dtype used to store the cells.
    """

    def __init__(self, shape, random=True, dtype=np.int8):
        if not random:
            cells = np.zeros(shape, dtype=dtype)
        else:
            cells = np.random.randint(0, 2, size=shape, dtype=dtype)
        self.data = cells
        self.shape = cells.shape
        self.dtype = dtype
        self._engine = Engine(self)

        self.step = 0  # generation counter, starts at 0

    def animate(self):
        """Return a generator that displays and advances this world."""
        animator = Animate(self)
        return animator.animate()

    def __str__(self):
        # Plain numpy rendering of the cell grid.
        return str(self.data)


class Animate(object):
    """Render successive generations of a World with matplotlib.

    :param world: the World to display; it is advanced in place through its
        engine on every frame.
    """

    def __init__(self, world):
        self.world = world
        self.im = None  # matplotlib image, created lazily on the first frame

    def animate(self):
        """Yield the world forever, advancing it one generation per frame.

        Each iteration draws the current state, advances the world by one
        generation, pauses briefly so matplotlib can redraw, and then yields
        the (already advanced) world.
        """
        while True:
            # Create the image the first time *this* animator draws. Testing
            # self.im instead of world.step means an animator created for a
            # world that has already been stepped (e.g. a second call to
            # World.animate()) still builds its image instead of calling
            # set_data on None.
            if self.im is None:
                plt.ion()
                # vmin=0, vmax=2 with a gray colormap draws live cells (1)
                # as mid-grey.
                self.im = plt.imshow(self.world.data, vmin=0, vmax=2,
                                     cmap=plt.cm.gray)
            else:
                self.im.set_data(self.world.data)

            self.world.step += 1
            self.world._engine.next_state()
            plt.pause(0.01)
            yield self.world


class Engine(object):
    """Advance a World by one Game of Life generation at a time.

    The update is applied in place to ``world.data``. A neighbor-count buffer
    allocated once in ``__init__`` is reused on every step. Cells outside
    the grid contribute nothing to the counts, so they behave as dead
    (there is no wrap-around).

    :param world: object with a 2-D ``data`` array of 0/1 cells and a
        matching ``shape`` attribute (e.g. a World).
    :param dtype: dtype of the neighbor-count buffer; a cell has at most 8
        neighbors, so a small integer type is sufficient.
    """

    def __init__(self, world, dtype=np.int8):
        self._world = world
        self.shape = world.shape
        self.neighbor = np.zeros(world.shape, dtype=dtype)
        self._neighbor_id = self._make_neighbor_indices()

    def _make_neighbor_indices(self):
        """Return 8 index tuples describing shifted views of a 2-D array.

        ``out[i]`` and ``out[7 - i]`` are opposite shifts. Adding
        ``w[out[7 - i]]`` into ``n[out[i]]`` therefore adds, for every cell,
        the value of its neighbor in one of the eight directions. With
        (x, y) = (axis 0, axis 1), the entries are N, NE, E, SE, S, SW, W, NW.
        """
        # slice(None): whole axis; slice(1, None): drop the first element;
        # slice(0, -1): drop the last. d[-1] is d[2], so negating an offset
        # selects the opposite slice.
        d = [slice(None), slice(1, None), slice(0, -1)]
        d2 = [
            (0, 1), (1, 1), (1, 0), (1, -1)
        ]
        out = [None] * 8
        for i, (x, y) in enumerate(d2):
            # Tuples, not lists: NumPy >= 1.23 no longer interprets a list of
            # slices as a multidimensional index.
            out[i] = (d[x], d[y])
            out[7 - i] = (d[-x], d[-y])
        return out

    def _count_neighbors(self):
        """Fill ``self.neighbor`` with the live-neighbor count of each cell."""
        self.neighbor.fill(0)  # reset counts from the previous step
        w = self._world.data
        n = self.neighbor
        for i, target in enumerate(self._neighbor_id):
            n[target] += w[self._neighbor_id[7 - i]]

    def _update_world(self):
        """Apply the Game of Life rules to ``world.data`` in place."""
        w = self._world.data
        n = self.neighbor

        # Conway's rules:
        #    cell        neighbors   cell's next state
        #    ---------   ---------   -----------------
        # 1. live        < 2         dead
        # 2. live        2 or 3      live
        # 3. live        > 3         dead
        # 4. dead        3           live

        # Equivalent simplified rules:
        #    cell        neighbors   cell's next state
        #    ---------   ---------   -----------------
        # 1. live        2           live
        # 2. live/dead   3           live
        # 3. Otherwise, dead.

        # These bitwise updates rely on cells holding only 0 or 1.
        w &= (n == 2)  # keep live cells that have exactly 2 neighbors
        w |= (n == 3)  # any cell with exactly 3 neighbors is live

    def next_state(self):
        """Advance the world by one generation."""
        self._count_neighbors()
        self._update_world()

Context

StackExchange Code Review Q#160802, answer score: 10

Revisions (0)

No revisions yet.