HiveBrain v1.2.0
pattern · python · Minor

3-level deep if-else tree for constructing a neural network

Submitted by: @import:stackexchange-codereview

Problem

The following code fragment constructs different types of neural network outputs based on the options supplied. Currently, my code just has a huge note that describes what all the options are supposed to do. I'd like to learn how to make this code less complex, or how to break its complexity into simpler pieces. Any suggestions?

```
#------------------------------------------------------------#
# NOTE: Meaning of all the options.                          #
# stagger_schedule=extended: We copy input vec to output.    #
# stagger_schedule=external: We don't copy input to output.  #
#------------------------------------------------------------#
# do_backward_pass: We use the output of the backward LSTM.  #
#   Default: True.                                           #
#------------------------------------------------------------#
# chop_bilstm: Should we chop the first and last vectors     #
#   from the sequence? Default: False.                       #
#------------------------------------------------------------#
# extended_multiplicative: Multiply the forward and back LSTM#
#   and concatenate the input embedding.                     #
# external_multiplicative: Multiply the forward and back LSTM#
#   but don't concatenate the input embedding.               #
#------------------------------------------------------------#
if (self.prm('stagger_schedule') == 'extended'):
    if self.prm('chop_bilstm'):
        if self.prm('do_backward_pass'):
            self.output_tv = T.concatenate(
                [forward, backward, input_tv], axis=1)[1:-1]
            pass
        else:
            self.output_tv = T.concatenate(
                [forward, input_tv], axis=1)[1:-1]
            pass
        pass
    else:
        if self.prm('do_backward_pass'):
            self.output_tv = T.concatenate(
                [forward, backward, input_tv], axis=1)
            pass
        else:
            self.output_tv = T.concatenate(
                [forward, input_tv], axis=1)
```

Solution

Here are some ways you can tidy up this code:

- Add some comments. It’s quite hard for me to tell whether this code is correct, or whether I’ve introduced a bug by refactoring, because I don’t know what this code is supposed to do.

- Get rid of the unnecessary pass statements. The pass statement literally does nothing except provide a placeholder for unwritten code. If you delete them all, you’ll save a lot of lines and be able to fit more code on screen.

- In the extended branch, the two sub-branches are almost identical, except that when chop_bilstm is set we also drop the first and last rows (time steps) of self.output_tv. If we defer that slice until the end, we need only one set of branches:

```
if self.prm('stagger_schedule') == 'extended':
    if self.prm('do_backward_pass'):
        self.output_tv = T.concatenate([forward, backward, input_tv], axis=1)
    else:
        self.output_tv = T.concatenate([forward, input_tv], axis=1)

    if self.prm('chop_bilstm'):
        self.output_tv = self.output_tv[1:-1]
```


22 lines cut down to 7, and this is only 2 levels deep.
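
One way to check that this refactoring preserves behavior is to run both versions on every combination of the two flags. This is a hedged sketch, not the author's code: it assumes NumPy arrays can stand in for the Theano tensors (np.concatenate joins along axis=1 the same way for this purpose), and the helper names original_extended and refactored_extended are hypothetical.

```python
import itertools

import numpy as np


def original_extended(forward, backward, input_tv, chop_bilstm, do_backward_pass):
    # The question's nested 'extended' branch, minus the pass statements.
    if chop_bilstm:
        if do_backward_pass:
            return np.concatenate([forward, backward, input_tv], axis=1)[1:-1]
        else:
            return np.concatenate([forward, input_tv], axis=1)[1:-1]
    else:
        if do_backward_pass:
            return np.concatenate([forward, backward, input_tv], axis=1)
        else:
            return np.concatenate([forward, input_tv], axis=1)


def refactored_extended(forward, backward, input_tv, chop_bilstm, do_backward_pass):
    # The flattened version: build the output first, chop it afterwards.
    if do_backward_pass:
        output = np.concatenate([forward, backward, input_tv], axis=1)
    else:
        output = np.concatenate([forward, input_tv], axis=1)
    if chop_bilstm:
        output = output[1:-1]  # drop the first and last time steps (rows)
    return output


# Toy data: 5 time steps, 3 hidden units per direction, 4-dim embeddings.
forward = np.arange(15).reshape(5, 3)
backward = np.arange(15, 30).reshape(5, 3)
input_tv = np.arange(30, 50).reshape(5, 4)

# Compare both versions on all four combinations of the boolean options.
for chop, do_back in itertools.product([True, False], repeat=2):
    before = original_extended(forward, backward, input_tv, chop, do_back)
    after = refactored_extended(forward, backward, input_tv, chop, do_back)
    assert np.array_equal(before, after), (chop, do_back)
print("refactored version matches on all option combinations")
```

Slicing after the concatenation is safe because [1:-1] acts on axis 0 (time steps) while the concatenation joins along axis 1 (features), so the two operations commute.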

- The code in the external branch can be consolidated the same way: both sub-branches do the same thing, just with slightly truncated forward and backward variables. Here’s an alternative version:

```
elif self.prm('stagger_schedule') == 'external':
    if self.prm('chop_bilstm'):
        forward = forward[1:-1]
        backward = backward[2:]

    if self.prm('do_backward_pass'):
        self.output_tv = T.concatenate([forward, backward], axis=1)
    else:
        self.output_tv = forward
```


20 lines cut down to 8, and this is only 2 levels deep.
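
The question excerpt doesn't include the original external branch, so this version can't be checked for equivalence, but a quick shape check shows that the concatenate call runs and that the two truncations line up. A minimal NumPy sketch, assuming forward and backward are (time steps × hidden units) matrices as the axis=1 concatenation implies; external is a hypothetical helper name, and the pairing of forward[1:-1] with backward[2:] is copied from the answer unchanged.

```python
import numpy as np


def external(forward, backward, chop_bilstm, do_backward_pass):
    # The answer's 'external' branch as a standalone function.
    if chop_bilstm:
        # forward[1:-1] and backward[2:] both keep n - 2 rows,
        # so they can still be joined column-wise.
        forward = forward[1:-1]
        backward = backward[2:]
    if do_backward_pass:
        return np.concatenate([forward, backward], axis=1)
    return forward


# Toy data: 5 time steps, 2 hidden units per direction.
forward = np.arange(10).reshape(5, 2)
backward = 100 + np.arange(10).reshape(5, 2)

out = external(forward, backward, chop_bilstm=True, do_backward_pass=True)
print(out.shape)  # (3, 4)
print(out[0])     # forward row 1 followed by backward row 2
```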

- You could wrap each branch in its own method and call that instead. For example:

```
if self.prm('stagger_schedule') == 'extended':
    self._set_output_tv_extended(forward, backward, input_tv)
elif self.prm('stagger_schedule') == 'external':
    self._set_output_tv_external(forward, backward, input_tv)
elif self.prm('stagger_schedule') == 'extended_multiplicative':
    self._set_output_tv_extended_multiplicative(forward, backward, input_tv)
elif self.prm('stagger_schedule') == 'external_multiplicative':
    self._set_output_tv_external_multiplicative(forward, backward, input_tv)
else:
    raise NotImplementedError()
```


That saves you an immediate level of nesting and pushes the specific logic of each branch out of this method.
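
If the elif chain keeps growing, a common alternative is a dictionary that maps each schedule name to its handler method (a dispatch table). The sketch below is hedged: BiLSTMOutput is a hypothetical stand-in whose prm() simply reads from a dict, NumPy replaces Theano, and only the extended and external handlers are registered, because the question describes the multiplicative variants only in prose.

```python
import numpy as np


class BiLSTMOutput:
    """Hypothetical stand-in for the question's class; prm() just reads a dict."""

    def __init__(self, **params):
        self.params = params
        self.output_tv = None

    def prm(self, name):
        return self.params[name]

    def _set_output_tv_extended(self, forward, backward, input_tv):
        # Concatenate first, then optionally drop the first and last time steps.
        if self.prm('do_backward_pass'):
            self.output_tv = np.concatenate([forward, backward, input_tv], axis=1)
        else:
            self.output_tv = np.concatenate([forward, input_tv], axis=1)
        if self.prm('chop_bilstm'):
            self.output_tv = self.output_tv[1:-1]

    def _set_output_tv_external(self, forward, backward, input_tv):
        # input_tv is accepted for a uniform signature but not used here.
        if self.prm('chop_bilstm'):
            forward, backward = forward[1:-1], backward[2:]
        if self.prm('do_backward_pass'):
            self.output_tv = np.concatenate([forward, backward], axis=1)
        else:
            self.output_tv = forward

    def set_output_tv(self, forward, backward, input_tv):
        # Dispatch table: schedule name -> bound handler method.
        handlers = {
            'extended': self._set_output_tv_extended,
            'external': self._set_output_tv_external,
        }
        schedule = self.prm('stagger_schedule')
        if schedule not in handlers:
            raise NotImplementedError(schedule)
        handlers[schedule](forward, backward, input_tv)


# Usage: 5 time steps, 3 hidden units per direction, 4-dimensional embeddings.
layer = BiLSTMOutput(stagger_schedule='extended',
                     chop_bilstm=True, do_backward_pass=False)
states = np.ones((5, 3))
layer.set_output_tv(states, states, np.ones((5, 4)))
print(layer.output_tv.shape)  # (3, 7)
```

With this layout, adding extended_multiplicative or external_multiplicative means writing one method and adding one dictionary entry, rather than another elif.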


Context

StackExchange Code Review Q#117028, answer score: 3
