Skip to content
This repository was archived by the owner on Oct 31, 2025. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 34 additions & 10 deletions trax/layers/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,17 +90,41 @@ def f(x): # pylint: disable=invalid-name
return base.Fn('MakeZeroState', f)


def LSTM(n_units, mode='train'):
  """Builds an LSTM layer that scans an `LSTMCell` along axis 1.

  The layer pairs its input with a zero state, runs the cell over axis 1 and
  then keeps only the first of the two resulting values (the RNN state is
  dropped).

  Args:
    n_units: Forwarded to `LSTMCell(n_units=...)`; also used in the layer name.
    mode: Forwarded unchanged to `cb.Scan`.

  Returns:
    A `cb.Serial` layer named `LSTM_<n_units>` that prints no sublayers.
  """
  initial = MakeZeroState(depth_multiplier=2)  # pylint: disable=no-value-for-parameter
  attach_state = cb.Branch([], initial)
  recurrence = cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode)
  outputs_only = cb.Select([0], n_in=2)  # Drop RNN state.
  return cb.Serial(
      attach_state,
      recurrence,
      outputs_only,
      # Set the name to LSTM and don't print sublayers.
      name=f'LSTM_{n_units}',
      sublayers_to_print=[],
  )
def LSTM(n_units, mode='train', return_state=False, initial_state=False):
  """LSTM running on axis 1.

  Args:
    n_units: Forwarded to `LSTMCell(n_units=...)`; also used in the layer name.
    mode: Forwarded to `cb.Scan`. If 'predict' then we save the previous state
      for one-by-one inference.
    return_state: Boolean. If True, the RNN state produced by the scan is left
      in the layer's outputs alongside the output; if False it is dropped by a
      trailing `cb.Select([0], n_in=2)`. Default: False.
    initial_state: Boolean. If True, no zero state is created and the RNN state
      (c, h) is expected to be supplied from the stack; if False, a zero state
      from `MakeZeroState(depth_multiplier=2)` is used. Default: False.

  Returns:
    A LSTM layer (a `cb.Serial` named `LSTM_<n_units>`).
  """
  # Assemble the sublayers once instead of repeating the full Serial(...) call
  # for each combination of the two flags.
  sublayers = []

  if not initial_state:
    zero_state = MakeZeroState(depth_multiplier=2)  # pylint: disable=no-value-for-parameter
    sublayers.append(cb.Branch([], zero_state))  # Fill RNN state with zeros.

  sublayers.append(cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode))

  if not return_state:
    sublayers.append(cb.Select([0], n_in=2))  # Drop RNN state.

  # Set the name to LSTM and don't print sublayers.
  return cb.Serial(*sublayers, name=f'LSTM_{n_units}', sublayers_to_print=[])

class GRUCell(base.Layer):
"""Builds a traditional GRU cell with dense internal transformations.
Expand Down