Skip to content
This repository was archived by the owner on Oct 31, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions trax/data/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,11 +920,12 @@ def _f(generator):
n_tokens = len(example)
if n_tokens <= max_length:
yield example
n_segments = int(math.ceil(float(n_tokens) / float(max_length)))
for i in range(n_segments):
start = max_length * i
end = min(start + max_length, n_tokens)
yield example[start:end]
else:
n_segments = int(math.ceil(float(n_tokens) / float(max_length)))
for i in range(n_segments):
start = max_length * i
end = min(start + max_length, n_tokens)
yield example[start:end]
return _f


Expand Down
1 change: 1 addition & 0 deletions trax/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,4 +122,5 @@ def layer_configure(*args, **kwargs):
LinearPooling = layer_configure(LinearPooling)
LinearUpsampling = layer_configure(LinearUpsampling)
NoUpsampling = layer_configure(NoUpsampling)
NaiveUpsampling = layer_configure(NaiveUpsampling)
AttentionResampling = layer_configure(AttentionResampling)
7 changes: 6 additions & 1 deletion trax/layers/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,12 @@ def init_weights_and_state(self, input_signature):
def PrintShape(n_in=1, msg=''):
"""Prints the shapes of `n_in` inputs and returns then unchanged."""
def Fwd(xs):
shapes_and_dtypes = ', '.join([str(x.shape) + f'[{x.dtype}]' for x in xs])
def format_shape(x):
return str(x.shape) + f'[{x.dtype}]'
if n_in > 1:
shapes_and_dtypes = ', '.join([format_shape(x) for x in xs])
else:
shapes_and_dtypes = format_shape(xs)
info = f'PrintShape: {msg}: [{shapes_and_dtypes}]'
print(info)
logging.info(info)
Expand Down
18 changes: 11 additions & 7 deletions trax/layers/research/resampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,19 +61,23 @@ def LinearUpsampling(shorten_factor, d_model, *args, dropout=0.0, mode='train',
)


def NaiveUpsampling(shorten_factor, d_model, *args, **kwargs):
return core.Fn('Repeat', lambda x: jnp.repeat(x, shorten_factor, axis=1))


def NoUpsampling(shorten_factor, d_model, *args, **kwargs):
del d_model, args, kwargs

return core.Fn('ReturnZero', lambda x: jnp.zeros( # pylint: disable=g-long-lambda
(x.shape[0], x.shape[1] * shorten_factor, x.shape[2]), dtype=x.dtype))


def _FeedForwardBlock(d_model,
d_ff,
dropout,
dropout_shared_axes,
mode,
activation):
def FeedForwardBlock(d_model,
d_ff,
dropout,
dropout_shared_axes,
mode,
activation):
# We copy the ff block function because we cannot import it from models
return [
core.Dense(d_ff),
Expand All @@ -95,7 +99,7 @@ def AttentionResampling(shorten_factor, d_model, is_upsampling, d_ff, n_heads,
total_pooling, n_heads=n_heads, dropout=dropout,
mode=mode)

feed_forward = _FeedForwardBlock(
feed_forward = FeedForwardBlock(
d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation)

resampling = resampling_fn(shorten_factor, d_model,
Expand Down
10 changes: 2 additions & 8 deletions trax/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from trax.models.reformer import reformer
from trax.models.research import bert
from trax.models.research import configurable_transformer
from trax.models.research import funnel_transformer
from trax.models.research import hourglass
from trax.models.research import layerdrop_transformer
from trax.models.research import rezero
from trax.models.research import rse
Expand Down Expand Up @@ -88,11 +88,5 @@ def model_configure(*args, **kwargs):
RNNLM = model_configure(rnn.RNNLM)
GRULM = model_configure(rnn.GRULM)
LSTMSeq2SeqAttn = model_configure(rnn.LSTMSeq2SeqAttn)
FunnelTransformerEncoder = model_configure(
funnel_transformer.FunnelTransformerEncoder)
FunnelTransformer = model_configure(
funnel_transformer.FunnelTransformer)
ResidualShuffleExchange = model_configure(rse.ResidualShuffleExchange)
RelformerLM = model_configure(
funnel_transformer.RelformerLM)
RelformerChunkedLM = model_configure(funnel_transformer.RelformerChunkedLM)
HourglassLM = model_configure(hourglass.HourglassLM)
Loading