HiveBrain v1.2.0
pattern · python · Minor

Python factory method with easy registry

Submitted by: @import:stackexchange-codereview

Problem

My aim is to define a set of classes, each providing methods for comparing a particular type of file. My idea is to use some kind of factory method to instantiate the class based upon a string, which could allow new classes to be added easily. Then it would be simple to loop over a dictionary like:

```
files = {
    'csv': ('file1.csv', 'file2.csv'),
    'bin': ('file3.bin', 'file4.bin')
}
```


Here is what I have so far:

```
# results/__init__.py
class ResultDiffException(Exception):
    pass

class ResultDiff(object):
    """Base class that enables comparison of result files."""
    def __init__(self, path_test, path_ref):
        self.path_t = path_test
        self.path_r = path_ref

    def max(self):
        raise NotImplementedError('abstract method')

    def min(self):
        raise NotImplementedError('abstract method')

    def mean(self):
        raise NotImplementedError('abstract method')

# results/numeric.py
import numpy as np
from results import ResultDiff, ResultDiffException

class NumericArrayDiff(ResultDiff):

    def __init__(self, *args, **kwargs):
        super(NumericArrayDiff, self).__init__(*args, **kwargs)

        self.data_t = self._read_file(self.path_t)
        self.data_r = self._read_file(self.path_r)

        if self.data_t.shape != self.data_r.shape:
            raise ResultDiffException('Inconsistent array shape')

        np.seterr(divide='ignore', invalid='ignore')
        self.diff = (self.data_t - self.data_r) / self.data_r
        both_zero_ind = np.nonzero((self.data_t == 0) & (self.data_r == 0))
        self.diff[both_zero_ind] = 0

    def _read_file(self, path):
        return np.loadtxt(path, ndmin=1)

    def max(self):
        return np.amax(self.diff)

    def min(self):
        return np.amin(self.diff)

    def mean(self):
        return np.mean(self.diff)

class CsvDiff(NumericArrayDiff):

    def __init__(self, *args, **kwargs):
        super(CsvDiff, self).__init__(*args, **kwargs)

    def _read_file(self, pa
```

Solution


  1. Stop writing classes



The title for this section comes from Jack Diederich's PyCon 2012 talk.

A class represents a group of objects with similar behaviour, and an object represents some kind of persistent thing. So when deciding what classes a program is going to need, the first question to ask is, "what kind of persistent things does this program need to represent?"

In this case the program:

  • knows how to load NumPy arrays from different kinds of file format (CSV and plain text); and



  • knows how to compute the relative difference between two NumPy arrays (so long as they come from files with the same format).



The only persistent things here are files (represented by Python file objects) and NumPy arrays (represented by numpy.ndarray objects). So there's no need for any more classes.
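Applied to the question's `files` dictionary, the "registry" can then be a plain dict that maps each file-type key to a function returning an array. A minimal sketch, assuming that `'csv'` files are comma-delimited text, `'txt'` files are whitespace-delimited text, and `'bin'` files hold raw native-endian float64 values; the names `LOADERS` and `load_pair` are hypothetical:

```python
import numpy as np

# Hypothetical registry: file-type key -> function that returns an ndarray.
# Assumptions: 'csv' is comma-delimited text, 'txt' is whitespace-delimited
# text, and 'bin' is raw float64 data as written by ndarray.tofile().
LOADERS = {
    'csv': lambda path: np.loadtxt(path, delimiter=',', ndmin=1),
    'txt': lambda path: np.loadtxt(path, ndmin=1),
    'bin': lambda path: np.fromfile(path, dtype=np.float64),
}


def load_pair(kind, paths):
    """Load each file in `paths` with the loader registered for `kind`."""
    load = LOADERS[kind]  # raises KeyError for an unregistered file type
    return tuple(load(path) for path in paths)
```

Supporting a new format means adding one entry to the dict, with no subclass. With the question's dictionary, the loop becomes `for kind, paths in files.items(): t, r = load_pair(kind, paths)`.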

  2. Other review points



-
The code calls numpy.seterr to suppress the warning:

```
RuntimeWarning: invalid value encountered in true_divide
```


but it fails to restore the original error state, whatever it was. This might be an unpleasant surprise for the caller. It would be better to use the numpy.errstate context manager to ensure that the original error state is restored.
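A small sketch of the difference, using NumPy's `np.seterr`, `np.geterr` and `np.errstate`; the two helper names are made up for illustration. `np.seterr` changes the global error state for the rest of the program, whereas `np.errstate` restores the previous state when the `with` block exits, even if an exception is raised inside it:

```python
import numpy as np


def divide_leaky(t, r):
    # Mirrors the original code: the global error state is changed
    # and never restored.
    np.seterr(divide='ignore', invalid='ignore')
    return t / r


def divide_scoped(t, r):
    # The error state is only changed inside the with-block; the previous
    # settings come back automatically when the block exits.
    with np.errstate(divide='ignore', invalid='ignore'):
        return t / r
```

If `np.seterr` must be used, note that it returns the previous settings, so `old = np.seterr(...)` followed by `np.seterr(**old)` in a `finally` clause also works, but the context manager is harder to get wrong.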

-
When dispatching to NumPy functions, it is usually unnecessary to check shapes for compatibility and raise your own error. Instead, just pass the arrays to NumPy. If they can't be combined, NumPy will raise:

```
ValueError: operands could not be broadcast together ...
```
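For example (a sketch; `subtract_or_error` is a made-up helper), arrays of shapes `(3,)` and `(2,)` make NumPy raise `ValueError`. One caveat: under NumPy's broadcasting rules some different shapes, such as `(3,)` and `(1,)`, combine without any error, so an explicit shape check still earns its keep if that kind of mismatch should count as a failure (for instance, a one-value file loaded with `ndmin=1` has shape `(1,)`):

```python
import numpy as np


def subtract_or_error(t, r):
    """Return t - r, or the text of NumPy's ValueError if the shapes clash."""
    try:
        return np.asarray(t) - np.asarray(r)
    except ValueError as exc:
        return str(exc)


# Shapes (3,) and (2,) cannot be broadcast: NumPy raises ValueError.
print(subtract_or_error([1.0, 2.0, 3.0], [1.0, 2.0]))

# Shapes (3,) and (1,) broadcast silently, so no error is raised.
print(subtract_or_error([1.0, 2.0, 3.0], [1.0]).tolist())  # [0.0, 1.0, 2.0]
```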


  3. Revised code



Instead of classes, write a function!

```
import numpy as np

def relative_difference(t, r):
    """Return the relative difference between arrays t and r, that is:

    0              where t == 0 and r == 0
    (t - r) / r    otherwise

    """
    t, r = np.asarray(t), np.asarray(r)
    with np.errstate(divide='ignore', invalid='ignore'):
        return np.where((t == 0) & (r == 0), 0, (t - r) / r)
```


Note the following advantages over the original code:

-
It's much shorter, and so there's much less code to maintain.

-
It can find the difference between arrays that come from files with different formats:

```
relative_difference(np.loadtxt(path1), np.fromfile(path2))
```


-
It can find the difference between arrays that don't come from files at all:

```
relative_difference(np.random.randint(0, 10, (10,)), np.arange(1, 11))
```
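The `max`, `min` and `mean` methods of the original classes then become ordinary NumPy reductions on the returned array. A self-contained sketch (the function is repeated so the snippet runs on its own; the input values are arbitrary). Note that a nonzero test value against a zero reference yields `inf`, which then dominates `max` and `mean`:

```python
import numpy as np


def relative_difference(t, r):
    """Return 0 where t == 0 and r == 0, and (t - r) / r otherwise."""
    t, r = np.asarray(t), np.asarray(r)
    with np.errstate(divide='ignore', invalid='ignore'):
        return np.where((t == 0) & (r == 0), 0, (t - r) / r)


diff = relative_difference([0, 3, 1], [0, 2, 2])
print(diff.tolist())                          # [0.0, 0.5, -0.5]
print(diff.max(), diff.min(), diff.mean())    # 0.5 -0.5 0.0

# Zero reference, nonzero test value: division by zero gives inf.
print(relative_difference([1], [0]).tolist())  # [inf]
```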


Context

StackExchange Code Review Q#91228, answer score: 3
