3-level deep if-else tree for constructing a neural network
Problem
The following code fragment constructs different types of neural network outputs based on the options supplied. Currently, my code just has a huge note that describes what all the options are supposed to do. I wanted to learn how I can make this code less complex, or perhaps break its complexity into simpler pieces. Any suggestions?
```
#------------------------------------------------------------#
# NOTE: Meaning of all the options.                          #
# stagger_schedule=extended: We copy input vec to output.    #
# stagger_schedule=external: We don't copy input to output.  #
# -----------------------------------------------------------#
# do_backward_pass: We use the output of the backward LSTM.  #
# Default: True.                                             #
# -----------------------------------------------------------#
# chop_bilstm: Should we chop the first and last vectors from#
# the sequence? Default: False.                              #
#------------------------------------------------------------#
# extended_multiplicative: Multiply the forward and back LSTM#
# and concatenate the input embedding.                       #
# external_multiplicative: Multiply the forward and back LSTM#
# but don't concatenate the input embedding.                 #
#------------------------------------------------------------#
if (self.prm('stagger_schedule') == 'extended'):
    if self.prm('chop_bilstm'):
        if self.prm('do_backward_pass'):
            self.output_tv = T.concatenate(
                [forward, backward, input_tv], axis=1)[1:-1]
            pass
        else:
            self.output_tv = T.concatenate(
                [forward, input_tv], axis=1)[1:-1]
            pass
        pass
    else:
        if self.prm('do_backward_pass'):
            self.output_tv = T.concatenate(
                [forward, backward, input_tv], axis=1)
            pass
        else:
            self.output_tv = T.concatenate(
                [forward, input_tv], axis=1)
```
Solution
Here are some ways you can tidy up this code:
- Add some comments. It’s quite hard for me to tell whether this code is correct, or whether I’ve introduced a bug by refactoring, because I don’t know what this code is supposed to do.

- Get rid of the unnecessary `pass` statements. The `pass` statement literally does nothing except provide a placeholder for unwritten code. If you delete them all, you’ll save a lot of lines and be able to fit more code on screen.

- Under the `extended` branch, the two `chop_bilstm` sub-branches are identical except that one removes the first and last vectors of `self.output_tv`. If we defer that step until the end, we need only one set of branches:

  ```
  if self.prm('stagger_schedule') == 'extended':
      if self.prm('do_backward_pass'):
          self.output_tv = T.concatenate([forward, backward, input_tv], axis=1)
      else:
          self.output_tv = T.concatenate([forward, input_tv], axis=1)
      if self.prm('chop_bilstm'):
          self.output_tv = self.output_tv[1:-1]
  ```

  22 lines cut down to 7, and this is only 2 levels deep.

- The code in the `external` branch can be consolidated the same way: both sub-branches do the same thing, but the chopped one works on slightly truncated `forward` and `backward` variables. Here’s an alternative version, continuing the same `if` chain:

  ```
  elif self.prm('stagger_schedule') == 'external':
      if self.prm('chop_bilstm'):
          forward = forward[1:-1]
          backward = backward[2:]
      if self.prm('do_backward_pass'):
          self.output_tv = T.concatenate([forward, backward], axis=1)
      else:
          self.output_tv = forward
  ```

  20 lines cut down to 8, and this is only 2 levels deep.

- You could consider wrapping each branch into its own method and calling into that. For example, something like:

  ```
  if self.prm('stagger_schedule') == 'extended':
      self._set_output_tv_extended(forward, backward, input_tv)
  elif self.prm('stagger_schedule') == 'external':
      self._set_output_tv_external(forward, backward, input_tv)
  elif self.prm('stagger_schedule') == 'extended_multiplicative':
      self._set_output_tv_extended_multiplicative(forward, backward, input_tv)
  elif self.prm('stagger_schedule') == 'external_multiplicative':
      self._set_output_tv_external_multiplicative(forward, backward, input_tv)
  else:
      raise NotImplementedError()
  ```

  That immediately removes one level of nesting and pushes the specific logic of each branch out of this method.
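Putting the last three suggestions together, here is a minimal runnable sketch under a few assumptions. Theano is not required: a hypothetical pure-Python helper `concat_axis1` stands in for `T.concatenate(..., axis=1)`, with each "tensor" represented as a list of rows. The class name `BiLSTMOutput`, its `params` dict and the `set_output_tv` entry point are made up for illustration; only `prm`, `output_tv` and the `_set_output_tv_*` method names come from the post. It also swaps the `if`/`elif` chain for a dict lookup (one common alternative, not the answer's exact code) and leaves out the two multiplicative schedules, since their code isn't shown. Because the question only shows the original `extended` branch, only that branch is checked against the original tree.

```python
def concat_axis1(parts):
    """Hypothetical stand-in for T.concatenate(parts, axis=1).

    Each "tensor" is a list of rows; rows at the same index are joined
    side by side, so the sketch runs without Theano.
    """
    return [sum(rows, []) for rows in zip(*parts)]


class BiLSTMOutput:
    """Hypothetical stand-in for the asker's layer; prm() mirrors the question."""

    def __init__(self, **params):
        # Defaults taken from the NOTE in the question.
        self.params = {"do_backward_pass": True, "chop_bilstm": False}
        self.params.update(params)
        self.output_tv = None

    def prm(self, name):
        return self.params[name]

    def set_output_tv(self, forward, backward, input_tv):
        # A lookup table replaces the if/elif chain over schedule names.
        handlers = {
            "extended": self._set_output_tv_extended,
            "external": self._set_output_tv_external,
        }
        schedule = self.prm("stagger_schedule")
        if schedule not in handlers:
            raise NotImplementedError(schedule)
        handlers[schedule](forward, backward, input_tv)

    def _set_output_tv_extended(self, forward, backward, input_tv):
        # Copy the input vectors to the output, then chop once at the end.
        parts = [forward, backward] if self.prm("do_backward_pass") else [forward]
        self.output_tv = concat_axis1(parts + [input_tv])
        if self.prm("chop_bilstm"):
            self.output_tv = self.output_tv[1:-1]

    def _set_output_tv_external(self, forward, backward, input_tv):
        # The input is not copied; the slices are the ones from the refactor above.
        if self.prm("chop_bilstm"):
            forward = forward[1:-1]
            backward = backward[2:]
        if self.prm("do_backward_pass"):
            self.output_tv = concat_axis1([forward, backward])
        else:
            self.output_tv = forward


def original_extended(forward, backward, input_tv, chop, backward_pass):
    """The question's three-level tree for stagger_schedule == 'extended'."""
    if chop:
        if backward_pass:
            return concat_axis1([forward, backward, input_tv])[1:-1]
        return concat_axis1([forward, input_tv])[1:-1]
    if backward_pass:
        return concat_axis1([forward, backward, input_tv])
    return concat_axis1([forward, input_tv])


def check_extended_matches_original():
    """Try every flag combination and compare the refactor with the original tree."""
    fwd = [[r, r] for r in range(5)]
    bwd = [[10 + r] for r in range(5)]
    inp = [[100 + r] for r in range(5)]
    for chop in (False, True):
        for backward_pass in (False, True):
            layer = BiLSTMOutput(stagger_schedule="extended", chop_bilstm=chop,
                                 do_backward_pass=backward_pass)
            layer.set_output_tv(fwd, bwd, inp)
            if layer.output_tv != original_extended(fwd, bwd, inp, chop, backward_pass):
                return False
    return True


if __name__ == "__main__":
    print(check_extended_matches_original())  # True
```

Exhaustively trying the four flag combinations like this is a cheap way to confirm a refactor hasn't changed behavior. The `external` slices can't be checked the same way without the original branch.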
Context
StackExchange Code Review Q#117028, answer score: 3