diff --git a/trax/layers/rnn.py b/trax/layers/rnn.py index 51969b417..f7b05480f 100644 --- a/trax/layers/rnn.py +++ b/trax/layers/rnn.py @@ -90,17 +90,41 @@ def f(x): # pylint: disable=invalid-name return base.Fn('MakeZeroState', f) -def LSTM(n_units, mode='train'): - """LSTM running on axis 1.""" - zero_state = MakeZeroState(depth_multiplier=2) # pylint: disable=no-value-for-parameter - return cb.Serial( - cb.Branch([], zero_state), - cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode), - cb.Select([0], n_in=2), # Drop RNN state. - # Set the name to LSTM and don't print sublayers. - name=f'LSTM_{n_units}', sublayers_to_print=[] - ) +def LSTM(n_units, mode='train', return_state=False, initial_state=False): + """LSTM running on axis 1. + + Args: + mode: if 'predict' then we save the previous state for one-by-one inference. + return_state: Boolean. Whether to return the latest status in addition to the output. Default: False. + initial_state: Boolean. If the state RNN (c, h) is to be obtained from the stack. Default: False. + Returns: + A LSTM layer. + """ + + if not initial_state: + zero_state = MakeZeroState(depth_multiplier=2) # pylint: disable=no-value-for-parameter + if return_state: + return cb.Serial( + cb.Branch([], zero_state), + cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode), + name=f'LSTM_{n_units}', sublayers_to_print=[]) + else: + return cb.Serial( + cb.Branch([], zero_state), # fill state RNN with zero. + cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode), + cb.Select([0], n_in=2), + name=f'LSTM_{n_units}', sublayers_to_print=[]) + else: + if return_state: + return cb.Serial( + cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode), + name=f'LSTM_{n_units}', sublayers_to_print=[]) + else: + return cb.Serial( + cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode), + cb.Select([0], n_in=2), + name=f'LSTM_{n_units}', sublayers_to_print=[]) class GRUCell(base.Layer): """Builds a traditional GRU cell with dense internal transformations.