From 7b86449f39cb9a8c13c72a914cb24e213acd93bd Mon Sep 17 00:00:00 2001 From: riccardo ughi Date: Sun, 24 Oct 2021 13:55:39 +0200 Subject: [PATCH] Update rnn.py Add capability for LSTN layer: 1) getting a complete state RNN (c, h) from the stack as input instead to use a zero initialization . Use initial_state=True parameter for that. Default False. 2) putting a complete state RNN (c, h) in the stack as output instead to discard it. Use return_state=True parameter for that . Default False. --- trax/layers/rnn.py | 44 ++++++++++++++++++++++++++++++++++---------- 1 file changed, 34 insertions(+), 10 deletions(-) diff --git a/trax/layers/rnn.py b/trax/layers/rnn.py index 51969b417..f7b05480f 100644 --- a/trax/layers/rnn.py +++ b/trax/layers/rnn.py @@ -90,17 +90,41 @@ def f(x): # pylint: disable=invalid-name return base.Fn('MakeZeroState', f) -def LSTM(n_units, mode='train'): - """LSTM running on axis 1.""" - zero_state = MakeZeroState(depth_multiplier=2) # pylint: disable=no-value-for-parameter - return cb.Serial( - cb.Branch([], zero_state), - cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode), - cb.Select([0], n_in=2), # Drop RNN state. - # Set the name to LSTM and don't print sublayers. - name=f'LSTM_{n_units}', sublayers_to_print=[] - ) +def LSTM(n_units, mode='train', return_state=False, initial_state=False): + """LSTM running on axis 1. + + Args: + mode: if 'predict' then we save the previous state for one-by-one inference. + return_state: Boolean. Whether to return the latest status in addition to the output. Default: False. + initial_state: Boolean. If the state RNN (c, h) is to be obtained from the stack. Default: False. + Returns: + A LSTM layer. + """ + + if not initial_state: + zero_state = MakeZeroState(depth_multiplier=2) # pylint: disable=no-value-for-parameter + if return_state: + return cb.Serial( + cb.Branch([], zero_state), + cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode), + name=f'LSTM_{n_units}', sublayers_to_print=[]) + else: + return cb.Serial( + cb.Branch([], zero_state), # fill state RNN with zero. + cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode), + cb.Select([0], n_in=2), + name=f'LSTM_{n_units}', sublayers_to_print=[]) + else: + if return_state: + return cb.Serial( + cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode), + name=f'LSTM_{n_units}', sublayers_to_print=[]) + else: + return cb.Serial( + cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode), + cb.Select([0], n_in=2), + name=f'LSTM_{n_units}', sublayers_to_print=[]) class GRUCell(base.Layer): """Builds a traditional GRU cell with dense internal transformations.