diff --git a/trax/data/inputs.py b/trax/data/inputs.py index 128e9b15e..fc3260dab 100644 --- a/trax/data/inputs.py +++ b/trax/data/inputs.py @@ -920,11 +920,12 @@ def _f(generator): n_tokens = len(example) if n_tokens <= max_length: yield example - n_segments = int(math.ceil(float(n_tokens) / float(max_length))) - for i in range(n_segments): - start = max_length * i - end = min(start + max_length, n_tokens) - yield example[start:end] + else: + n_segments = int(math.ceil(float(n_tokens) / float(max_length))) + for i in range(n_segments): + start = max_length * i + end = min(start + max_length, n_tokens) + yield example[start:end] return _f diff --git a/trax/layers/__init__.py b/trax/layers/__init__.py index eae30fa9f..f16ccf495 100644 --- a/trax/layers/__init__.py +++ b/trax/layers/__init__.py @@ -122,4 +122,5 @@ def layer_configure(*args, **kwargs): LinearPooling = layer_configure(LinearPooling) LinearUpsampling = layer_configure(LinearUpsampling) NoUpsampling = layer_configure(NoUpsampling) +NaiveUpsampling = layer_configure(NaiveUpsampling) AttentionResampling = layer_configure(AttentionResampling) diff --git a/trax/layers/core.py b/trax/layers/core.py index 2ba7a5ed7..fd9cba23a 100644 --- a/trax/layers/core.py +++ b/trax/layers/core.py @@ -311,7 +311,12 @@ def init_weights_and_state(self, input_signature): def PrintShape(n_in=1, msg=''): """Prints the shapes of `n_in` inputs and returns then unchanged.""" def Fwd(xs): - shapes_and_dtypes = ', '.join([str(x.shape) + f'[{x.dtype}]' for x in xs]) + def format_shape(x): + return str(x.shape) + f'[{x.dtype}]' + if n_in > 1: + shapes_and_dtypes = ', '.join([format_shape(x) for x in xs]) + else: + shapes_and_dtypes = format_shape(xs) info = f'PrintShape: {msg}: [{shapes_and_dtypes}]' print(info) logging.info(info) diff --git a/trax/layers/research/resampling.py b/trax/layers/research/resampling.py index 2569d8976..7b4f5e46d 100644 --- a/trax/layers/research/resampling.py +++ b/trax/layers/research/resampling.py @@ -61,6 +61,10 @@ def LinearUpsampling(shorten_factor, d_model, *args, dropout=0.0, mode='train', ) +def NaiveUpsampling(shorten_factor, d_model, *args, **kwargs): + return core.Fn('Repeat', lambda x: jnp.repeat(x, shorten_factor, axis=1)) + + def NoUpsampling(shorten_factor, d_model, *args, **kwargs): del d_model, args, kwargs @@ -68,12 +72,12 @@ def NoUpsampling(shorten_factor, d_model, *args, **kwargs): (x.shape[0], x.shape[1] * shorten_factor, x.shape[2]), dtype=x.dtype)) -def _FeedForwardBlock(d_model, - d_ff, - dropout, - dropout_shared_axes, - mode, - activation): +def FeedForwardBlock(d_model, + d_ff, + dropout, + dropout_shared_axes, + mode, + activation): # We copy the ff block function because we cannot import it from models return [ core.Dense(d_ff), @@ -95,7 +99,7 @@ def AttentionResampling(shorten_factor, d_model, is_upsampling, d_ff, n_heads, total_pooling, n_heads=n_heads, dropout=dropout, mode=mode) - feed_forward = _FeedForwardBlock( + feed_forward = FeedForwardBlock( d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation) resampling = resampling_fn(shorten_factor, d_model, diff --git a/trax/models/__init__.py b/trax/models/__init__.py index df3871551..01778a508 100644 --- a/trax/models/__init__.py +++ b/trax/models/__init__.py @@ -26,7 +26,7 @@ from trax.models.reformer import reformer from trax.models.research import bert from trax.models.research import configurable_transformer -from trax.models.research import funnel_transformer +from trax.models.research import hourglass from trax.models.research import layerdrop_transformer from trax.models.research import rezero from trax.models.research import rse @@ -88,11 +88,5 @@ def model_configure(*args, **kwargs): RNNLM = model_configure(rnn.RNNLM) GRULM = model_configure(rnn.GRULM) LSTMSeq2SeqAttn = model_configure(rnn.LSTMSeq2SeqAttn) -FunnelTransformerEncoder = model_configure( - funnel_transformer.FunnelTransformerEncoder) -FunnelTransformer = model_configure( - funnel_transformer.FunnelTransformer) ResidualShuffleExchange = model_configure(rse.ResidualShuffleExchange) -RelformerLM = model_configure( - funnel_transformer.RelformerLM) -RelformerChunkedLM = model_configure(funnel_transformer.RelformerChunkedLM) +HourglassLM = model_configure(hourglass.HourglassLM) diff --git a/trax/models/research/funnel_transformer.py b/trax/models/research/funnel_transformer.py deleted file mode 100644 index 3c958fa55..000000000 --- a/trax/models/research/funnel_transformer.py +++ /dev/null @@ -1,975 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Lint as: python3 -"""Funnel Transformer model. - -Funnel-Transformer: Filtering out Sequential Redundancy for Efficient -Language Processing https://arxiv.org/abs/2006.03236 -""" -import functools - -from trax import fastmath -from trax import layers as tl -from trax.fastmath import numpy as jnp -from trax.fastmath.ops import index_add -from trax.layers.assert_shape import assert_shape -from trax.layers.research.rel_attention import get_rel_att_inputs -from trax.layers.research.rel_attention import RelativeAttentionWrapper -from trax.layers.research.resampling import _FeedForwardBlock -from trax.layers.research.resampling import AttentionResampling -from trax.layers.research.resampling import AveragePooling -from trax.layers.research.resampling import LinearPooling -from trax.layers.research.resampling import LinearUpsampling -from trax.models.reformer.reformer import DecoderBlock -from trax.models.research.configurable_transformer import ApplyAttentionLayer -from trax.models.research.configurable_transformer import PositionalEncoder -from trax.models.transformer import _EncoderBlock - - -@assert_shape('bld->bSd') -def PoolLayer(pool_layer=tl.AvgPool, - pool_size=(2,), - strides=(2,), - separate_cls=True): - """Returns a pool layer for Funnel Transformer. - - Args: - pool_layer: Type of pooling layer used for downsampling; - should be `tl.AvgPool` or `tl.MaxPool`. - pool_size: Shape of window that gets reduced to a single vector value. - If the layer inputs are :math:`n`-dimensional arrays, then `pool_size` - must be a tuple of length :math:`n-2`. - strides: Offsets from the location of one window to the locations of - neighboring windows along each axis. If specified, must be a tuple of - the same length as `pool_size`. If None, then offsets of 1 along each - window axis, :math:`(1, ..., 1)`, will be used. - separate_cls: If `True`, pooling in funnel blocks is not applied to - embeddings of the first token (`cls` from BERT paper). - """ - if separate_cls: - cls_selection = tl.Fn('select_cls_token', lambda x: x[:, :1, :]) - tokens_after_cls = tl.Fn('rest_tokens', lambda x: x[:, 1:, :]) - - return tl.Serial( - tl.Branch( - cls_selection, - tl.Serial( - tokens_after_cls, - pool_layer(pool_size, strides) - ) - ), - tl.Concatenate(axis=1) - ) - else: - return pool_layer(pool_size, strides) - - -@assert_shape('b11l->b11S') -def MaskPool(pool_size=(2,), strides=(2,), separate_cls=True): - return tl.Serial( - tl.Fn('reshape', lambda x: x.swapaxes(1, -1).squeeze(axis=-1)), - PoolLayer(tl.MaxPool, pool_size, strides, separate_cls), - tl.Fn('reshape_back', lambda x: x[..., None].swapaxes(1, -1)) - ) - - -@assert_shape('bld->bd') -def SelectFirst(): - return tl.Fn('select_first', lambda x: x[:, 0, :]) - - -def _Upsampler(total_pool_size, separate_cls): - """Returns an upsampling layer for Funnel Transformer. - - Args: - total_pool_size: The combined pool size of previously used funnel blocks. - separate_cls: If `True`, pooling in funnel blocks is not applied to - embeddings of the first token (`cls` from BERT paper). - """ - - def _Upsample(short, long): - if separate_cls: - upsampled_short = jnp.concatenate( - (short[:, :1, :], - short[:, 1:, :].repeat(total_pool_size, axis=1)), - axis=1) - return index_add( - long, - (slice(None), - slice(None, upsampled_short.shape[1]), - slice(None)), - upsampled_short) - else: - upsampled_short = short.repeat(total_pool_size, axis=1) - return long + upsampled_short - - return tl.Fn('Upsampler', _Upsample) - - -def _FunnelBlock(d_model, d_ff, n_heads, - dropout, dropout_shared_axes, mode, ff_activation, - pool_layer, pool_size, strides, separate_cls): - """Internal funnel block. Returns a list of layers implementing it. - - The input is an activation tensor. - - Args: - d_model: Final dimension of tensors at most points in the model, including - the initial embedding output. - d_ff: Size of special dense layer in the feed-forward part of each block. - n_heads: Number of attention heads. - dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout within a block. - dropout_shared_axes: Tensor axes on which to share a dropout mask. - Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is - a useful way to save memory and apply consistent masks to activation - vectors at different sequence positions. - mode: If `'train'`, each block will include dropout; else, it will - pass all values through unaltered. - ff_activation: Type of activation function at the end of each block; must - be an activation-type subclass of `Layer`. - pool_layer: Type of pooling layer used for downsampling; - should be `tl.AvgPool` or `tl.MaxPool`. - pool_size: Shape of window that gets reduced to a single vector value. - If the layer inputs are :math:`n`-dimensional arrays, then `pool_size` - must be a tuple of length :math:`n-2`. - strides: Offsets from the location of one window to the locations of - neighboring windows along each axis. If specified, must be a tuple of - the same length as `pool_size`. If None, then offsets of 1 along each - window axis, :math:`(1, ..., 1)`, will be used. - separate_cls: If `True`, pooling in funnel blocks is not applied to - embeddings of the first token (`cls` from BERT paper). - Returns: - A list of layers that maps (activations, mask) to (activations', mask). - """ - pooling = PoolLayer(pool_layer, pool_size, strides, separate_cls) - mask_pooling = MaskPool(pool_size, strides, separate_cls) - - attention = tl.AttentionQKV(d_model, n_heads=n_heads, dropout=dropout, - mode=mode) - hidden_dropout = tl.Dropout( - rate=dropout, shared_axes=dropout_shared_axes, mode=mode) - - feed_forward = _FeedForwardBlock( - d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation) - - return [ # h, mask - tl.LayerNorm(), # h, mask - tl.Branch(pooling, None), # h', h, mask - tl.Residual( - tl.Select([0, 1, 1, 2]), # h', h, h, mask - attention, # attn, mask - tl.Parallel(None, mask_pooling), # attn, mask' - hidden_dropout # attn, mask' - ), # funnel_activations, mask' - tl.Residual( - tl.LayerNorm(), - feed_forward, - hidden_dropout, - ) - ] - - -def FunnelTransformerEncoder(vocab_size, - n_classes=10, - d_model=512, - d_ff=2048, - encoder_segment_lengths=(2, 2, 2), - n_heads=8, - max_len=2048, - dropout=0.1, - dropout_shared_axes=None, - mode='train', - ff_activation=tl.Relu, - pool_layer=tl.AvgPool, - pool_size=(2,), - strides=(2,), - separate_cls=True): - """Returns a Funnel Encoder. - - This model performs text categorization: - - - input: rank 2 tensor representing a batch of text strings via token IDs - plus padding markers; shape is (batch_size, sequence_length). The tensor - elements are integers in `range(vocab_size)`, and `0` values mark padding - positions. - - - output: rank 2 tensor representing a batch of log-probability - distributions over N categories; shape is (batch_size, `n_classes`). - - Args: - vocab_size: Input vocabulary size -- each element of the input tensor - should be an integer in `range(vocab_size)`. These integers typically - represent token IDs from a vocabulary-based tokenizer. - n_classes: Final dimension of the output tensors, representing N-way - classification. - d_model: Final dimension of tensors at most points in the model, including - the initial embedding output. - d_ff: Size of special dense layer in the feed-forward part of each encoder - block. - encoder_segment_lengths: Tuple, where each element denotes the number of - transformer encoder blocks preceding a funnel transformer block. - There is no funnel block after the last sequence of encoder blocks, - therefore the total number of blocks in the model is equal to - `sum(encoder_segment_lengths) + len(encoder_segment_lengths) - 1`. - n_heads: Number of attention heads. - max_len: Maximum symbol length for positional encoding. - dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout within an encoder block. - dropout_shared_axes: Tensor axes on which to share a dropout mask. - Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is - a useful way to save memory and apply consistent masks to activation - vectors at different sequence positions. - mode: If `'train'`, each encoder block will include dropout; else, it will - pass all values through unaltered. - ff_activation: Type of activation function at the end of each encoder - block; must be an activation-type subclass of `Layer`. - pool_layer: Type of pooling layer used for downsampling in each of the - funnel blocks; should be `tl.AvgPool` or `tl.MaxPool`. - pool_size: Shape of window that gets reduced to a single vector value. - If the layer inputs are :math:`n`-dimensional arrays, then `pool_size` - must be a tuple of length :math:`n-2`. - strides: Offsets from the location of one window to the locations of - neighboring windows along each axis. If specified, must be a tuple of - the same length as `pool_size`. If None, then offsets of 1 along each - window axis, :math:`(1, ..., 1)`, will be used. - separate_cls: If `True`, pooling in funnel blocks is not applied to - embeddings of the first token (`cls` from BERT paper) and only final - embedding of this token is used for categorization - the rest are - discarded. If `False`, each token from the beginning is pooled and - all embeddings are averaged and mapped to output categories like in - original `TransformerEncoder` model. - Returns: - A Transformer model that maps strings (conveyed via token IDs) to - probability-like activations over a range of output classes. - """ - assert encoder_segment_lengths - - positional_encoder = [ - tl.Embedding(vocab_size, d_model), - tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode), - tl.PositionalEncoding(max_len=max_len)] - - encoder_blocks = [] - n_encoder_segments = len(encoder_segment_lengths) - - for i in range(n_encoder_segments): - # Building i'th segment - for _ in range(encoder_segment_lengths[i]): - # Create segment_size encoder blocks - encoder_blocks.append( - _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, - mode, ff_activation)) - - # If not last segment, add funnel block - if i != n_encoder_segments - 1: - encoder_blocks.append( - _FunnelBlock(d_model, d_ff, n_heads, dropout, - dropout_shared_axes, mode, - ff_activation, pool_layer, pool_size, - strides, separate_cls)) - - cls_pooling = SelectFirst() if separate_cls else tl.Mean(axis=1) - - # Assemble and return the model. - return tl.Serial( # toks - # Encode. - tl.Branch( - positional_encoder, tl.PaddingMask()), # vecs masks - encoder_blocks, # vecs masks - tl.Select([0], n_in=2), # vecs - tl.LayerNorm(), # vecs - - # Map to output categories. - cls_pooling, # cls - tl.Dense(n_classes), # cls - ) - - -def FunnelTransformer(vocab_size, - d_model=512, - d_ff=2048, - encoder_segment_lengths=(2, 2, 2), - n_decoder_blocks=2, - n_heads=8, - max_len=2048, - dropout=0.1, - dropout_shared_axes=None, - mode='train', - ff_activation=tl.Relu, - pool_layer=tl.AvgPool, - pool_size=(2,), - separate_cls=True): - """Returns a Full Funnel Transformer, that can be used for example for BERT. - - This model outputs token-level categorical distributions over all vocab: - - - input: rank 2 tensor representing a batch of text strings via token IDs - plus padding markers; shape is (batch_size, sequence_length). The tensor - elements are integers in `range(vocab_size)`, and `0` values mark padding - positions. - - - output: rank 3 tensor representing a batch of log-probability - distributions over `vocab_size` categories for each token; shape is - (batch_size, sequence_length, vocab_size). - - - Args: - vocab_size: Input vocabulary size -- each element of the input tensor - should be an integer in `range(vocab_size)`. These integers typically - represent token IDs from a vocabulary-based tokenizer. - d_model: Final dimension of tensors at most points in the model, including - the initial embedding output. - d_ff: Size of special dense layer in the feed-forward part of each encoder - block. - encoder_segment_lengths: Tuple, where each element denotes the number of - transformer encoder blocks preceding a funnel transformer block. - There is no funnel block after the last sequence of encoder blocks, - therefore the total number of blocks in the model is equal to - `sum(encoder_segment_lengths) + len(encoder_segment_lengths) - 1`. - n_decoder_blocks: Number of transformer blocks in the upsampling decoder. - n_heads: Number of attention heads. - max_len: Maximum symbol length for positional encoding. - dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout within an encoder block. - dropout_shared_axes: Tensor axes on which to share a dropout mask. - Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is - a useful way to save memory and apply consistent masks to activation - vectors at different sequence positions. - mode: If `'train'`, each encoder block will include dropout; else, it will - pass all values through unaltered. - ff_activation: Type of activation function at the end of each encoder - block; must be an activation-type subclass of `Layer`. - pool_layer: Type of pooling layer used for downsampling in each of the - funnel blocks; should be `tl.AvgPool` or `tl.MaxPool`. - pool_size: Shape of window that gets reduced to a single vector value. - If the layer inputs are :math:`n`-dimensional arrays, then `pool_size` - must be a tuple of length :math:`n-2`. - separate_cls: If `True`, pooling in funnel blocks is not applied to - embeddings of the first token (`cls` from BERT paper) and only final - embedding of this token is used for categorization - the rest are - discarded. If `False`, each token from the beginning is pooled and - all embeddings are averaged and mapped to output categories like in - original `TransformerEncoder` model. - """ - assert encoder_segment_lengths - - positional_encoder = [ - tl.Embedding(vocab_size, d_model), - tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode), - tl.PositionalEncoding(max_len=max_len)] - - n_encoder_segments = len(encoder_segment_lengths) - - encoder_blocks_before_first_pooling = [ - _EncoderBlock(d_model, d_ff, n_heads, dropout, - dropout_shared_axes, mode, ff_activation) - for _ in range(encoder_segment_lengths[0])] - encoder_blocks_from_first_pooling = [] - - for i in range(1, n_encoder_segments): - # Building i'th segment - - # Add funnel block between segments - encoder_blocks_from_first_pooling.append( - _FunnelBlock(d_model, d_ff, n_heads, dropout, - dropout_shared_axes, mode, - ff_activation, pool_layer, - pool_size=pool_size, strides=pool_size, - separate_cls=separate_cls)) - - for _ in range(encoder_segment_lengths[i]): - # Create segment_size encoder blocks - encoder_blocks_from_first_pooling.append( - _EncoderBlock(d_model, d_ff, n_heads, dropout, - dropout_shared_axes, mode, ff_activation)) - - decoder_blocks = [_EncoderBlock(d_model, d_ff, n_heads, dropout, - dropout_shared_axes, mode, ff_activation) - for _ in range(n_decoder_blocks)] - - total_pool_size = pool_size[0] ** (len(encoder_segment_lengths) - 1) - - # Assemble and return the model. - return tl.Serial( # toks - tl.Branch( - positional_encoder, tl.PaddingMask()), # vecs masks - encoder_blocks_before_first_pooling, # vecs masks - tl.Select([0, 1, 0, 1]), - # vecs masks residual = vecs old_masks - encoder_blocks_from_first_pooling, # vecs masks residual masks - tl.Select([0, 2, 3]), # vecs residual masks - tl.Parallel( - # residual from first segment is taken before - # normalization, so apply it now - None, tl.LayerNorm(), None), # vecs norm(residual) masks - _Upsampler(total_pool_size, separate_cls), # vecs masks - decoder_blocks, - tl.Select([0], n_in=2), # vecs - tl.LayerNorm(), - tl.Dense(vocab_size), - ) - - -def _RelativeDecoderBlock(attention_type, d_model, - d_ff, n_heads, dropout, dropout_shared_axes, - mode, ff_activation, context_bias_layer, - location_bias_layer, total_pooling): - """Returns a list of layers that implements a Transformer encoder block. - - The input to the block is a pair, (activations, mask), where the mask was - created from the original source tokens to prevent attending to the padding - part of the input. - - Args: - attention_type: attention type. - d_model: Final dimension of tensors at most points in the model, including - the initial embedding output. - d_ff: Size of special dense layer in the feed-forward part of each block. - n_heads: Number of attention heads. - dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout within a block. - dropout_shared_axes: Tensor axes on which to share a dropout mask. - Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is - a useful way to save memory and apply consistent masks to activation - vectors at different sequence positions. - mode: If `'train'`, each block will include dropout; else, it will - pass all values through unaltered. - ff_activation: Type of activation function at the end of each block; must - be an activation-type subclass of `Layer`. - context_bias_layer: context bias layer. - location_bias_layer: location bias layer. - total_pooling: The combined pool size of previously used funnel blocks. - Returns: - A list of layers that maps (activations, att_vecs, mask) to - (activations, att_vecs, mask). - """ - if attention_type == RelativeAttentionWrapper: - attention = RelativeAttentionWrapper( - d_model, - n_heads, - dropout, - mode=mode, - context_bias_layer=context_bias_layer, - location_bias_layer=location_bias_layer, - total_pooling=total_pooling - ) - else: - attention = ApplyAttentionLayer( - attention_type, - d_model, - n_heads, - d_model // n_heads, - d_model // n_heads, - causal=True, - masked=False, - attention_dropout=dropout, - output_dropout=dropout, - attention_chunk_size=0, # Disables tl.Chunk in ApplyAttentionLayer. - mode=mode, - ) - - feed_forward = _FeedForwardBlock( - d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation) - - def _Dropout(): - return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) - - return [ - tl.Residual( # vecs - tl.LayerNorm(), - attention, - _Dropout(), - ), # vecs - tl.Residual( - tl.LayerNorm(), - feed_forward, - _Dropout(), - ), # vecs - ] - - -def RelformerLM(vocab_size, - d_model=512, - d_ff=2048, - vanilla_layers=(0, 1), - shorten_factors=(3,), - n_funnel_blocks=(6,), - n_heads=8, - dropout=0.1, - dropout_shared_axes=None, - mode='train', - ff_activation=tl.FastGelu, - vanilla_attn_type=RelativeAttentionWrapper, - middle_attn_type=RelativeAttentionWrapper, - downsampling_fn=AttentionResampling, - upsampling_fn=AttentionResampling, - attention_downsampling_fn=AveragePooling, - attention_upsampling_fn=LinearUpsampling): - """Returns a Transformer language model. - - This model performs autoregressive language modeling: - - - input: rank 2 tensor representing a batch of text strings via token IDs - plus padding markers; shape is (batch_size, sequence_length). The tensor - elements are integers in `range(vocab_size)`, and `0` values mark padding - positions. - - - output: rank 3 tensor representing a batch of log-probability - distributions for each sequence position over possible token IDs; - shape is (batch_size, sequence_length, `vocab_size`). - - This model uses only the decoder part of the overall Transformer. - - Args: - vocab_size: Input vocabulary size -- each element of the input tensor - should be an integer in `range(vocab_size)`. These integers typically - represent token IDs from a vocabulary-based tokenizer. - d_model: Final dimension of tensors at most points in the model, including - the initial embedding output. - d_ff: Size of special dense layer in the feed-forward part of each encoder - block. - vanilla_layers: (pre_layers, post_layers) tuple - number of full token-level - Transformer decoder layers before and after shortening. - shorten_factors: by how much to shorten at each step - tuple of arbitrary - length denoting by how much shorten at each pooling stage. - n_funnel_blocks: number of Transformer decoder blocks after each stage of - pooling - tuple of the same length as `shorten_factors`. - n_heads: Number of attention heads. - dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout within an encoder block. - dropout_shared_axes: Tensor axes on which to share a dropout mask. - Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is - a useful way to save memory and apply consistent masks to activation - vectors at different sequence positions. - mode: str: 'train' or 'eval'. - ff_activation: Type of activation function at the end of each encoder - block; must be an activation-type subclass of `Layer`. - vanilla_attn_type: class: attention class such as SelfAttention to use in - the layers before and after shortening (vanilla layers). - middle_attn_type: class: attention class to use in the middle layers - (these operating on the shortened sequence). - downsampling_fn: function that takes full token-level vectors of - length `l` and transforms them into `l` / `k` vectors, where `k` - denotes `shorten_factor` parameter. - upsampling_fn: function that takes shortened representations of a sequence, - consisting of `l` / `k` vectors and transforms them into full - token-level representations of length `l`. - attention_downsampling_fn: Downsampling function that transforms token-level - vectors into query vectors with reduced length. Necessary only when - AttentionResampling is used as `downsampling_fn`. - attention_upsampling_fn: Upsampling function for AttentionResampling. - Valid only when AttentionResampling is used as a `upsampling_fn`. - - - Returns: - A Transformer language model as a layer that maps from a tensor of tokens - to activations over a vocab set. - """ - assert mode != 'predict' # For now, 'predict' mode is unsupported. - assert len(n_funnel_blocks) == len(shorten_factors) - - token_encoder = [ - tl.Embedding(vocab_size, d_model), - tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode)] - - context_bias_layer, location_bias_layer = get_rel_att_inputs(d_model, - n_heads) - - n_pre_decoder_blocks, n_post_decoder_blocks = vanilla_layers - - def create_decoder_blocks(n_layers, total_pooling, attention_type): # pylint: disable=invalid-name - decoder_blocks = [ - # pylint: disable=g-complex-comprehension - _RelativeDecoderBlock(attention_type, d_model, d_ff, n_heads, dropout, - dropout_shared_axes, mode, ff_activation, - context_bias_layer, location_bias_layer, - total_pooling) - for _ in range(n_layers)] - return decoder_blocks + [tl.LayerNorm()] - - total_pooling_acc = 1 - pre_decoder_blocks = create_decoder_blocks(n_pre_decoder_blocks, - total_pooling_acc, - vanilla_attn_type) - - funnel_blocks = [] - - for shorten_factor, block_len in zip(shorten_factors, n_funnel_blocks): - funnel_blocks = funnel_blocks + [ - downsampling_fn(shorten_factor, d_model, - is_upsampling=False, d_ff=d_ff, n_heads=n_heads, - dropout=dropout, - dropout_shared_axes=dropout_shared_axes, mode=mode, - ff_activation=ff_activation, - context_bias_layer=context_bias_layer, - location_bias_layer=location_bias_layer, - total_pooling=total_pooling_acc, - resampling_fn=attention_downsampling_fn)] - total_pooling_acc *= shorten_factor - funnel_blocks = funnel_blocks + create_decoder_blocks( - block_len, - total_pooling_acc, - middle_attn_type) - - upsampling_layer = upsampling_fn(shorten_factor=total_pooling_acc, - d_model=d_model, is_upsampling=True, - d_ff=d_ff, n_heads=n_heads, - dropout=dropout, - dropout_shared_axes=dropout_shared_axes, - mode=mode, ff_activation=ff_activation, - context_bias_layer=context_bias_layer, - location_bias_layer=location_bias_layer, - total_pooling=total_pooling_acc, - resampling_fn=attention_upsampling_fn) - - post_decoder_blocks = create_decoder_blocks(n_post_decoder_blocks, 1, - vanilla_attn_type) - - # Assemble and return the model. - return tl.Serial( # tokens (or chunked tuple of tokens) - tl.ShiftRight(mode=mode), # toks - token_encoder, # vecs - pre_decoder_blocks, # vecs - tl.Dup(), # vecs - tl.ShiftRight(n_positions=total_pooling_acc - 1), # shifted_vecs, vecs - funnel_blocks, # shifted_vecs, vecs - upsampling_layer, # vecs, vecs - tl.LayerNorm(), # vecs, vecs - tl.Add(), # vecs - post_decoder_blocks, # vecs - tl.Dense(vocab_size), # vecs - ) - - -class RelformerCacher(tl.Layer): - """Cache for Relformer. - - A class for caching tokens going through model to provide fast inference - for Relformer model. - """ - - def __init__(self, - total_kv_pooling, - n_raw_tokens_generated=1, - max_inference_length=64 * 64 * 3, - shift=0, - sliding=False, - mode='train'): - super().__init__(n_in=1, n_out=1) - self._total_kv_pooling = total_kv_pooling - self._n_raw_tokens_generated = n_raw_tokens_generated - self._max_len = max_inference_length - self._shift = shift - self._sliding = sliding - self._mode = mode - - def forward(self, inputs): - if self._mode != 'predict': - return inputs - return self.update_state(inputs=inputs) - - def init_weights_and_state(self, input_signature): - if self._mode == 'predict': - shape, dtype = input_signature.as_tuple() - batch_size, _, d_feature = shape - cache = jnp.zeros((batch_size, 2 * self._total_kv_pooling, d_feature), - dtype=dtype) - self.state = cache, jnp.array(0) - - def update_state(self, inputs): - cache, idx = self.state - cache = fastmath.dynamic_update_slice_in_dim( - cache, - inputs, (idx + self._shift) % (2 * self._total_kv_pooling), - axis=1) - - if self._sliding: - cache = fastmath.dynamic_update_slice_in_dim( - cache, - inputs, - (idx + self._total_kv_pooling * 2 - 1) % (2 * self._total_kv_pooling), - axis=1) - - if self._sliding: - left_index = idx % self._total_kv_pooling - else: - left_index = (idx - - (idx % self._total_kv_pooling)) % (2 * - self._total_kv_pooling) - - output = fastmath.dynamic_slice( - cache, [0, left_index, 0], - [cache.shape[0], self._total_kv_pooling, cache.shape[2]]) - - self.state = cache, idx + self._n_raw_tokens_generated - return output - - -class RelformerPicker(tl.Layer): - """Relformer Picker. - - A class for picking tokens going through model to provide fast inference - for Relformer model. - """ - - def __init__(self, total_kv_pooling, n_raw_tokens_generated=1, mode='train'): - super().__init__(n_in=1, n_out=1) - self._total_kv_pooling = total_kv_pooling - self._n_raw_tokens_generated = n_raw_tokens_generated - self._mode = mode - - def forward(self, inputs): - if self._mode != 'predict': - return inputs - - output = fastmath.dynamic_slice( - inputs, [0, self.state, 0], - [inputs.shape[0], self._n_raw_tokens_generated, inputs.shape[2]]) - self.state = (self.state + - self._n_raw_tokens_generated) % self._total_kv_pooling - return output - - def init_weights_and_state(self, input_signature): - if self._mode == 'predict': - self.state = jnp.array(0) - - -def PickLastTokenInPredict(mode='train'): - """Picks the last token logits. - - Self-descriptive layer for picking the last token logits in predict mode - for fast inference. - - Args: - mode: the model mode (train, predict, ...) - - Returns: - The last token logits. - """ - - def last_token(x): # pylint: disable=invalid-name - if mode == 'predict': - return x[:, -1:, :] - return x - - return tl.Fn('Pick last token in predict', last_token) - - -def RelformerChunkedLM(vocab_size, - d_model=512, - d_ff=2048, - vanilla_layers=(1, 1), - shorten_factor=3, - n_rel_layers=6, - rel_chunk_len=None, - vanilla_chunk_len=None, - n_heads=8, - dropout=0.1, - dropout_shared_axes=None, - vanilla_attn_type=tl.LSHSelfAttention, - pos_type='fixed-base', - max_len=3072, - n_raw_tokens_generated=1, - mode='train', - ff_activation=tl.FastGelu): - """Returns a Transformer language model. - - This model performs autoregressive language modeling: - - - input: rank 2 tensor representing a batch of text strings via token IDs - plus padding markers; shape is (batch_size, sequence_length). The tensor - elements are integers in `range(vocab_size)`, and `0` values mark padding - positions. - - - output: rank 3 tensor representing a batch of log-probability - distributions for each sequence position over possible token IDs; - shape is (batch_size, sequence_length, `vocab_size`). - - This model uses only the decoder part of the overall Transformer. - - Args: - vocab_size: Input vocabulary size -- each element of the input tensor - should be an integer in `range(vocab_size)`. These integers typically - represent token IDs from a vocabulary-based tokenizer. - d_model: Final dimension of tensors at most points in the model, including - the initial embedding output. - d_ff: Size of special dense layer in the feed-forward part of each encoder - block. - vanilla_layers: (pre_layers, post_layers) tuple - number of full token-level - Transformer decoder layers before and after shortening. - shorten_factor: by how much to shorten - n_rel_layers: number of Transformer blocks after the pooling. These blocks - use relative attention. - rel_chunk_len (optional): Number of tokens per chunk. Setting this option - will enable chunked relative attention. - vanilla_chunk_len (optional): If set, enables chunked relative attention - also in layers before and after shortening. - n_heads: Number of attention heads. - dropout: Stochastic rate (probability) for dropping an activation value - when applying dropout within an encoder block. - dropout_shared_axes: Tensor axes on which to share a dropout mask. - Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is - a useful way to save memory and apply consistent masks to activation - vectors at different sequence positions. - vanilla_attn_type: class: attention class such as SelfAttention to use in - the layers before and after shortening (vanilla layers). - pos_type: string, the type of positional embeddings to use. - max_len: int: maximum symbol length both for positional encoding and it is - also the maximum length of the possible inference in 'predict' mode - n_raw_tokens_generated: int: number of tokens generated with every pass - through model in 'predict' mode. Number of tokens should be smaller and - divisible by the first shorten factor we are using in the model. - It cannot be larger than one if we use vanilla layers because we would - lose autoregressive property of the model. - mode: str: 'train' or 'eval' or 'predict'. - ff_activation: Type of activation function at the end of each encoder - block; must be an activation-type subclass of `Layer`. - - Returns: - A Transformer language model as a layer that maps from a tensor of tokens - to activations over a vocab set. - """ - - token_encoder = [ - tl.Embedding(vocab_size, d_model), - tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode)] - - if vanilla_chunk_len is None: - positional_encoder = PositionalEncoder(mode, dropout, max_len, pos_type) - else: - positional_encoder = [] - - n_pre_decoder_blocks, n_post_decoder_blocks = vanilla_layers - - def create_reformer_blocks( # pylint: disable=invalid-name - n_layers, - total_kv_pooling=1, - layer_chunk_len=None, - force_relative=False, - dense=True): - if n_layers == 0: - return [tl.LayerNorm()] - - def determine_attn_type(layer_number): # pylint: disable=invalid-name - if layer_chunk_len is None and not force_relative: - return vanilla_attn_type - - if layer_chunk_len is not None: - chunk_offset = (layer_number % 2) * (layer_chunk_len // 2) - else: - chunk_offset = None - - return functools.partial( - RelativeAttentionWrapper, - n_raw_tokens_generated=n_raw_tokens_generated, - max_inference_length=max_len, - total_kv_pooling=total_kv_pooling, - chunk_len=layer_chunk_len, - chunk_offset=chunk_offset) - - d_per_head = d_model // n_heads - - decoder_blocks = [] - for i in range(n_layers): - layer_attn_type = determine_attn_type(i) - - decoder_blocks.append( - DecoderBlock( - d_model, - d_ff, - d_per_head, - d_per_head, - n_heads, - layer_attn_type, - dropout, - ff_activation, - dropout, - ff_use_sru=0, - ff_chunk_size=0, - ff_sparsity=0, - attention_chunk_size=0, - mode=mode)) - - return [ - tl.Dup(), - tl.ReversibleSerial(decoder_blocks), - tl.Concatenate(), - tl.LayerNorm(), - tl.Dense(d_model) if dense else [], - ] - - pre_decoder_blocks = create_reformer_blocks( - n_pre_decoder_blocks, layer_chunk_len=vanilla_chunk_len) - - relative_decoder_blocks = create_reformer_blocks( - n_rel_layers, - total_kv_pooling=shorten_factor, - layer_chunk_len=rel_chunk_len, - force_relative=True) - - conv_layer = tl.Serial( - tl.CausalConv(d_model, shorten_factor), - ff_activation() - ) - - post_decoder_blocks = create_reformer_blocks( - n_post_decoder_blocks, layer_chunk_len=vanilla_chunk_len, dense=False) - - cacher = RelformerCacher( - total_kv_pooling=shorten_factor, - n_raw_tokens_generated=n_raw_tokens_generated, - max_inference_length=max_len, - shift=shorten_factor - 1, - mode=mode) - - picker = RelformerPicker( - total_kv_pooling=shorten_factor, - n_raw_tokens_generated=n_raw_tokens_generated, - mode=mode) - - cacher_conv = RelformerCacher( - total_kv_pooling=shorten_factor, - n_raw_tokens_generated=n_raw_tokens_generated, - max_inference_length=max_len, - shift=shorten_factor - 1, - sliding=True, - mode=mode) - - picker_conv = PickLastTokenInPredict(mode=mode) - - # Assemble and return the model. - return tl.Serial( # tokens (or chunked tuple of tokens) - tl.ShiftRight(mode=mode), # toks - token_encoder, # vecs - positional_encoder, - pre_decoder_blocks, # vecs - tl.Dup(), - cacher, - tl.ShiftRight(n_positions=shorten_factor - 1, mode=mode), - LinearPooling(shorten_factor, d_model), - relative_decoder_blocks, - tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode), - LinearUpsampling(shorten_factor, d_model), - tl.LayerNorm(), - picker, - tl.Concatenate(), - cacher_conv, - conv_layer, - picker_conv, - post_decoder_blocks, - tl.Dense(vocab_size), # vecs - ) diff --git a/trax/models/research/funnel_transformer_test.py b/trax/models/research/funnel_transformer_test.py deleted file mode 100644 index f6fcbf0a4..000000000 --- a/trax/models/research/funnel_transformer_test.py +++ /dev/null @@ -1,371 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Lint as: python3 -"""Tests for Funnel-Transformer models.""" - -from absl.testing import absltest -from absl.testing import parameterized -import gin -import jax -import numpy as np -from trax import fastmath -from trax import layers as tl -from trax import shapes -import trax.models.research.funnel_transformer as ft -from trax.supervised import decoding - - -# pylint: disable=g-unreachable-test-method - - -class FunnelTransformerTest(parameterized.TestCase): - - def test_mean_pool(self): - x = np.ones((1, 4, 1)) - x[0, :3, 0] = [5., 2., 4.] - - pooling = ft.PoolLayer(tl.AvgPool, (2,), (2,)) - y = pooling(x) - - self.assertEqual(y.shape, (1, 2, 1)) - self.assertEqual(y.tolist(), [[[5.], [3.]]]) - - def test_mask_pool(self): - x = np.array([1, 0, 0, 1], dtype=bool).reshape((1, 1, 1, 4)) - pooling_cls = ft.MaskPool((2,), (2,)) - y1 = pooling_cls(x) - - self.assertEqual(y1.shape, (1, 1, 1, 2)) - self.assertEqual(y1.squeeze().tolist(), [True, False]) - - pooling_without_cls = ft.MaskPool((2,), (2,), separate_cls=False) - y2 = pooling_without_cls(x) - - self.assertEqual(y2.shape, (1, 1, 1, 2)) - self.assertEqual(y2.squeeze().tolist(), [True, True]) - - def test_upsampler(self): - long = np.ones((1, 8, 1)) - short = np.ones((1, 2, 1)) - total_pool_size = long.shape[1] // short.shape[1] - up_cls = ft._Upsampler(total_pool_size, separate_cls=True) - up = ft._Upsampler(total_pool_size, separate_cls=False) - - y_cls = up_cls([short, long]) - y = up((short, long)) - self.assertEqual(y_cls.shape, long.shape) - self.assertEqual(y.shape, long.shape) - - self.assertEqual(y_cls.squeeze().tolist(), 5*[2] + 3*[1]) - self.assertEqual(y.squeeze().tolist(), 8*[2]) - - def test_funnel_block_forward_shape(self): - n_even = 4 - d_model = 8 - - x = np.ones((1, n_even, d_model), dtype=np.float) - mask = np.ones((1, n_even), dtype=np.int32) - - masker = tl.PaddingMask() - mask = masker(mask) - - block = tl.Serial( - ft._FunnelBlock(d_model, 8, 2, 0.1, None, 'train', tl.Relu, - tl.AvgPool, (2,), (2,), separate_cls=True)) - - xs = [x, mask] - _, _ = block.init(shapes.signature(xs)) - - y, _ = block(xs) - - self.assertEqual(y.shape, (1, n_even // 2, d_model)) - - def test_funnel_transformer_encoder_forward_shape(self): - n_classes = 5 - model = ft.FunnelTransformerEncoder(2, n_classes=n_classes, d_model=8, - d_ff=8, encoder_segment_lengths=(1, 1), - n_heads=2, max_len=8) - - batch_size = 2 - n_tokens = 4 - x = np.ones((batch_size, n_tokens), dtype=np.int32) - _ = model.init(shapes.signature(x)) - y = model(x) - - self.assertEqual(y.shape, (batch_size, n_classes)) - - def test_funnel_transformer_forward_shape(self): - d_model = 8 - vocab_size = 7 - model = ft.FunnelTransformer(7, d_model=d_model, d_ff=8, - encoder_segment_lengths=(1, 1), - n_decoder_blocks=1, n_heads=2, max_len=8) - - batch_size = 2 - n_tokens = 4 - x = np.ones((batch_size, n_tokens), dtype=np.int32) - _ = model.init(shapes.signature(x)) - y = model(x) - - self.assertEqual(y.shape, (batch_size, n_tokens, vocab_size)) - - def _check_forward_shape(self, model, input_shape, output_vocab_size): - x = np.ones(input_shape).astype(np.int32) - model.init(shapes.signature(x)) - y = model(x) - self.assertEqual(y.shape, (*input_shape, output_vocab_size)) - - def test_funnel_transformer_lm_forward_shape(self): - d_model = 16 - vocab_size = 7 - model = ft.RelformerLM( - vocab_size, - shorten_factors=(3,), - n_funnel_blocks=(2,), - vanilla_layers=(1, 1), - d_model=d_model, - d_ff=d_model, - n_heads=2, - ) - - batch_size, seq_len = 3, 12 - self._check_forward_shape( - model, input_shape=(batch_size, seq_len), output_vocab_size=vocab_size) - - def test_lsh_attention_in_vanilla(self): - d_model = 16 - vocab_size = 7 - - gin.bind_parameter('PureLSHSelfAttentionWrapper.pure_lsh_implementation', - tl.PureLSHSelfAttention) - gin.bind_parameter('PureLSHSelfAttention.chunk_len', 2) - - model = ft.RelformerLM( - vocab_size, - shorten_factors=(3,), - n_funnel_blocks=(2,), - vanilla_layers=(1, 1), - d_model=d_model, - d_ff=d_model, - n_heads=2, - vanilla_attn_type=tl.PureLSHSelfAttentionWrapper, - downsampling_fn=ft.LinearPooling, - upsampling_fn=ft.LinearUpsampling, - ) - - batch_size, seq_len = 3, 12 - self._check_forward_shape( - model, input_shape=(batch_size, seq_len), output_vocab_size=vocab_size) - - def _test_autoregressive_property(self, model, input_shape, - output_vocab_size): - rng_1 = jax.random.PRNGKey(0) - rng_2 = jax.random.PRNGKey(1) - - def _get_output_logits(unitialized_eval_model: tl.Layer, x): - input_signature = shapes.signature(x) - unitialized_eval_model.init(input_signature, rng=rng_1, use_cache=False) - - output_logits, *_ = unitialized_eval_model(x, rng=rng_1) - return output_logits - - def check_autoregressive_property(model): - with fastmath.use_backend(fastmath.Backend.JAX): - x_1 = jax.random.randint(rng_1, input_shape, 0, output_vocab_size) - y_1 = _get_output_logits(model, x_1) - - x_2 = jax.random.randint(rng_2, input_shape, 0, output_vocab_size) - - for i in range(input_shape[1]): - masked_x_2 = np.concatenate((x_1[:, :i], x_2[:, i:]), axis=1) - - y_2 = _get_output_logits(model, masked_x_2) - self.assertEqual(y_2.shape[0], input_shape[1]) - np.testing.assert_array_almost_equal(y_1[:i + 1], y_2[:i + 1]) - - check_autoregressive_property(model) - - def test_funnel_transformer_lm_autoregressive_property(self): - d_model = 8 - vocab_size = 26 - - model = ft.RelformerLM( - vocab_size, - shorten_factors=(3,), - n_funnel_blocks=(2,), - vanilla_layers=(1, 1), - d_model=d_model, - d_ff=d_model, - n_heads=2, - ) - - input_shape = (1, 12) - self._test_autoregressive_property( - model, input_shape, output_vocab_size=vocab_size) - - def test_autoregressive_property_vanilla(self): - d_model = 8 - vocab_size = 26 - - gin.bind_parameter('trax.layers.SelfAttention.chunk_len', 2) - model = ft.RelformerLM( - vocab_size, - shorten_factors=(3,), - n_funnel_blocks=(2,), - vanilla_layers=(1, 1), - d_model=d_model, - d_ff=d_model, - n_heads=2, - vanilla_attn_type=tl.SelfAttention, - downsampling_fn=ft.LinearPooling, - upsampling_fn=ft.LinearUpsampling, - ) - input_shape = (1, 12) - self._test_autoregressive_property( - model, input_shape, output_vocab_size=vocab_size) - - def _test_funnel_transformer_lm_forward_shape_predict(self): - d_model = 8 - vocab_size = 4 - batch_size = 1 - n_len_eval = 42 - attention_type = tl.SelfAttention - - shorten_factor = 3 - n_rel_layers = 2 - vanilla_layers = (1, 1) - n_heads = 2 - - rel_chunk_len, vanilla_chunk_len = 2, 6 - - x = np.ones((batch_size, 1)).astype(np.int32) - gin.bind_parameter('trax.layers.SelfAttention.chunk_len', 20) - - predict_funnel = ft.RelformerChunkedLM( - vocab_size, - shorten_factor=shorten_factor, - n_rel_layers=n_rel_layers, - vanilla_layers=vanilla_layers, - d_model=d_model, - d_ff=d_model, - n_heads=n_heads, - vanilla_attn_type=attention_type, - rel_chunk_len=rel_chunk_len, - vanilla_chunk_len=vanilla_chunk_len, - max_len=n_len_eval, - mode='predict') - - _, _ = predict_funnel.init(shapes.signature(x)) - - for _ in range(5): - y = predict_funnel(x) - self.assertEqual(y.shape, (batch_size, 1, vocab_size)) - gin.clear_config() - - def _test_funnel_transformer_lm_predict_eval_equal(self): - - def _test_for_chunk_lens(rel_chunk_len, vanilla_chunk_len): - d_model = 8 - vocab_size = 4 - batch_size = 1 - n_len_eval = 42 - attention_type = tl.SelfAttention - - shorten_factor = 3 - n_rel_layers = 2 - vanilla_layers = (1, 1) - n_heads = 2 - - eval_funnel = ft.RelformerChunkedLM( - vocab_size, - shorten_factor=shorten_factor, - n_rel_layers=n_rel_layers, - vanilla_layers=vanilla_layers, - d_model=d_model, - d_ff=d_model, - n_heads=n_heads, - vanilla_attn_type=attention_type, - rel_chunk_len=rel_chunk_len, - vanilla_chunk_len=vanilla_chunk_len, - mode='eval') - - inputs = jax.random.randint( - key=jax.random.PRNGKey(0), - minval=0, - maxval=vocab_size, - shape=(batch_size, n_len_eval)).astype(np.int32) - _, _ = eval_funnel.init( - shapes.signature(inputs), rng=jax.random.PRNGKey(0)) - y_eval = eval_funnel(inputs) - self.assertEqual(y_eval.shape, (batch_size, n_len_eval, vocab_size)) - - if attention_type == tl.SelfAttention: - gin.bind_parameter('trax.layers.SelfAttention.chunk_len', n_len_eval) - - predict_funnel = ft.RelformerChunkedLM( - vocab_size, - shorten_factor=shorten_factor, - n_rel_layers=n_rel_layers, - vanilla_layers=vanilla_layers, - d_model=d_model, - d_ff=d_model, - n_heads=n_heads, - vanilla_attn_type=attention_type, - rel_chunk_len=rel_chunk_len, - vanilla_chunk_len=vanilla_chunk_len, - mode='predict') - - inputs = np.concatenate( - [np.zeros((batch_size, 1)).astype(np.int32), inputs], axis=1) - inputs = inputs[:, :-1] - - _, _ = predict_funnel.init( - shapes.signature(inputs[:, 0:1]), - rng=jax.random.PRNGKey(0), - use_cache=False) - - for i in range(n_len_eval): - y = predict_funnel(inputs[:, i:i + 1]) - np.testing.assert_array_almost_equal( - y, y_eval[:, i:i + 1, :], decimal=5) - - _test_for_chunk_lens(rel_chunk_len=None, vanilla_chunk_len=None) - _test_for_chunk_lens(rel_chunk_len=2, vanilla_chunk_len=6) - - def _test_autoregressive_sample_relformerlm(self): - batch_size = 4 - max_length = 5 - model = ft.RelformerChunkedLM( - 10, - d_model=8, - d_ff=16, - n_rel_layers=1, - vanilla_layers=(1, 1), - shorten_factor=3, - n_heads=2, - mode='predict') - model.init(shapes.ShapeDtype((batch_size, 1), dtype=np.int32)) - s1 = decoding.autoregressive_sample( - model, - batch_size=batch_size, - eos_id=-1, - max_length=max_length, - accelerate=False) - self.assertEqual(s1.shape, (batch_size, max_length)) - - -if __name__ == '__main__': - absltest.main() diff --git a/trax/models/research/hourglass.py b/trax/models/research/hourglass.py new file mode 100644 index 000000000..58f5bb066 --- /dev/null +++ b/trax/models/research/hourglass.py @@ -0,0 +1,311 @@ +# coding=utf-8 +# Copyright 2021 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Hourglass - a hierarchical Transformer language model.""" + +import trax.layers as tl +from trax.layers.research.rel_attention import RelativeAttentionWrapper, \ + get_rel_att_inputs +from trax.layers.research.resampling import AttentionResampling, \ + AveragePooling, LinearUpsampling, FeedForwardBlock +from trax.models.research.configurable_transformer import ApplyAttentionLayer + + +def _RelativeDecoderBlock(attention_type, d_model, + d_ff, n_heads, dropout, dropout_shared_axes, + mode, ff_activation, context_bias_layer, + location_bias_layer, total_pooling): + """Returns a list of layers that implements a Transformer decoder block with + relative attention parametrization. + + The input to the block is a pair, (activations, mask), where the mask was + created from the original source tokens to prevent attending to the padding + part of the input. + + Args: + attention_type: attention type. + d_model: Final dimension of tensors at most points in the model, including + the initial embedding output. + d_ff: Size of special dense layer in the feed-forward part of each block. + n_heads: Number of attention heads. + dropout: Stochastic rate (probability) for dropping an activation value + when applying dropout within a block. + dropout_shared_axes: Tensor axes on which to share a dropout mask. + Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is + a useful way to save memory and apply consistent masks to activation + vectors at different sequence positions. + mode: If `'train'`, each block will include dropout; else, it will + pass all values through unaltered. + ff_activation: Type of activation function at the end of each block; must + be an activation-type subclass of `Layer`. + context_bias_layer: context bias layer. + location_bias_layer: location bias layer. + total_pooling: The combined pool size of previously used funnel blocks. + Returns: + A list of layers that maps (activations, att_vecs, mask) to + (activations, att_vecs, mask). + """ + if attention_type == RelativeAttentionWrapper: + attention = RelativeAttentionWrapper( + d_model, + n_heads, + dropout, + mode=mode, + context_bias_layer=context_bias_layer, + location_bias_layer=location_bias_layer, + total_pooling=total_pooling + ) + else: + attention = ApplyAttentionLayer( + attention_type, + d_model, + n_heads, + d_model // n_heads, + d_model // n_heads, + causal=True, + masked=False, + attention_dropout=dropout, + output_dropout=dropout, + attention_chunk_size=0, # Disables tl.Chunk in ApplyAttentionLayer. + mode=mode, + ) + + feed_forward = FeedForwardBlock( + d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation) + + def _Dropout(): + return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) + + return [ + tl.Residual( # vecs + tl.LayerNorm(), + attention, + _Dropout(), + ), # vecs + tl.Residual( + tl.LayerNorm(), + feed_forward, + _Dropout(), + ), # vecs + ] + + +def _parse_hierarchy(hierarchy_str): + levels = hierarchy_str.split(' ') + if levels != levels[::-1]: + raise ValueError('Hierarchy is not a palindrome') + layer_level_pairs = [(x.split('@')) for x in levels[:1 + (len(levels) // 2)]] + hierarchy_n_layers = [int(x[0]) for x in layer_level_pairs] + total_sf_per_level = [int(x[1]) for x in layer_level_pairs] + + hierarchy_shorten_factors = [] + for current_sf, prev_sf in zip(total_sf_per_level, + [1] + total_sf_per_level[:-1]): + if current_sf % prev_sf != 0: + raise ValueError( + f'Hierarchy not divisible by previous level: {current_sf}, {prev_sf}') + hierarchy_shorten_factors.append(current_sf // prev_sf) + + return hierarchy_n_layers, hierarchy_shorten_factors + + +def HourglassLM(vocab_size, + d_model=512, + d_ff=2048, + vanilla_layers=(1, 1), + hierarchy='6@3', + n_heads=8, + dropout=0.1, + dropout_shared_axes=None, + mode='train', + ff_activation=tl.FastGelu, + vanilla_attn_type=RelativeAttentionWrapper, + middle_attn_type=RelativeAttentionWrapper, + downsampling_fn=AttentionResampling, + upsampling_fn=AttentionResampling, + attention_downsampling_fn=AveragePooling, + attention_upsampling_fn=LinearUpsampling): + """Returns a hierarchical Transformer language model. + + This model performs autoregressive language modeling: + + - input: rank 2 tensor representing a batch of text strings via token IDs + plus padding markers; shape is (batch_size, sequence_length). The tensor + elements are integers in `range(vocab_size)`, and `0` values mark padding + positions. + + - output: rank 3 tensor representing a batch of log-probability + distributions for each sequence position over possible token IDs; + shape is (batch_size, sequence_length, `vocab_size`). + + This model uses only the decoder part of the overall Transformer. + + Args: + vocab_size: Input vocabulary size -- each element of the input tensor + should be an integer in `range(vocab_size)`. These integers typically + represent token IDs from a vocabulary-based tokenizer. + d_model: Final dimension of tensors at most points in the model, including + the initial embedding output. + d_ff: Size of special dense layer in the feed-forward part of each encoder + block. + vanilla_layers: (pre_layers, post_layers) tuple - number of full token-level + Transformer decoder layers before and after shortening. + hierarchy: string - shortening hierarchy, as described in the paper. + Hierarchy levels must form a palindrome, e.g. '1@2 2@6 1@2'. + n_heads: Number of attention heads. + dropout: Stochastic rate (probability) for dropping an activation value + when applying dropout within an encoder block. + dropout_shared_axes: Tensor axes on which to share a dropout mask. + Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is + a useful way to save memory and apply consistent masks to activation + vectors at different sequence positions. + mode: str: 'train' or 'eval'. + ff_activation: Type of activation function at the end of each encoder + block; must be an activation-type subclass of `Layer`. + vanilla_attn_type: class: attention class such as SelfAttention to use in + the layers before and after shortening (vanilla layers). + middle_attn_type: class: attention class to use in the middle layers + (these operating on the shortened sequence). + downsampling_fn: function that takes full token-level vectors of + length `l` and transforms them into `l` / `k` vectors, where `k` + denotes `shorten_factor` parameter. + upsampling_fn: function that takes shortened representations of a sequence, + consisting of `l` / `k` vectors and transforms them into full + token-level representations of length `l`. + attention_downsampling_fn: Downsampling function that transforms token-level + vectors into query vectors with reduced length. Necessary only when + AttentionResampling is used as `downsampling_fn`. + attention_upsampling_fn: Upsampling function for AttentionResampling. + Valid only when AttentionResampling is used as a `upsampling_fn`. + + + Returns: + A Transformer language model as a layer that maps from a tensor of tokens + to activations over a vocab set. + """ + assert mode != 'predict' # For now, 'predict' mode is unsupported. + hierarchy_n_layers, hierarchy_shorten_factors = _parse_hierarchy(hierarchy) + + token_encoder = [ + tl.Embedding(vocab_size, d_model), + tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode)] + + context_bias_layer, location_bias_layer = get_rel_att_inputs(d_model, + n_heads) + + n_pre_decoder_blocks, n_post_decoder_blocks = vanilla_layers + + def create_decoder_blocks(n_layers, total_pooling, attention_type): + decoder_blocks = [ + # pylint: disable=g-complex-comprehension + _RelativeDecoderBlock(attention_type, d_model, d_ff, n_heads, dropout, + dropout_shared_axes, mode, ff_activation, + context_bias_layer, location_bias_layer, + total_pooling) + for _ in range(n_layers)] + return decoder_blocks + [tl.LayerNorm()] + + def create_hourglass_valley(rest_shorten_factors, + rest_n_funnel_blocks, + current_total_pooling): + assert (len(rest_shorten_factors) > 0) + assert (len(rest_shorten_factors) == len(rest_n_funnel_blocks)) + + current_sf = rest_shorten_factors[0] + current_n_layers = rest_n_funnel_blocks[0] + + shortening_layer = downsampling_fn(current_sf, d_model, + is_upsampling=False, d_ff=d_ff, + n_heads=n_heads, + dropout=dropout, + dropout_shared_axes=dropout_shared_axes, + mode=mode, + ff_activation=ff_activation, + context_bias_layer=context_bias_layer, + location_bias_layer=location_bias_layer, + total_pooling=current_total_pooling, + resampling_fn=attention_downsampling_fn) + + upsampling_layer = upsampling_fn(current_sf, + d_model=d_model, is_upsampling=True, + d_ff=d_ff, n_heads=n_heads, + dropout=dropout, + dropout_shared_axes=dropout_shared_axes, + mode=mode, ff_activation=ff_activation, + context_bias_layer=context_bias_layer, + location_bias_layer=location_bias_layer, + total_pooling=current_total_pooling, + resampling_fn=attention_upsampling_fn) + + if len(rest_shorten_factors) > 1: # we need to go deeper again + pre_stage_blocks = create_decoder_blocks( + current_n_layers, + current_total_pooling * current_sf, + middle_attn_type + ) + + post_stage_blocks = create_decoder_blocks( + current_n_layers, + current_total_pooling * current_sf, + middle_attn_type + ) + + return [ + tl.Dup(), + tl.ShiftRight(current_sf - 1, mode=mode), + shortening_layer, + pre_stage_blocks, + *create_hourglass_valley(rest_shorten_factors[1:], + rest_n_funnel_blocks[1:], + current_total_pooling * current_sf), + post_stage_blocks, + upsampling_layer, + tl.LayerNorm(), + tl.Add() + ] + else: + blocks = create_decoder_blocks(current_n_layers, + current_total_pooling * current_sf, + middle_attn_type) + + return [ + tl.Dup(), + tl.ShiftRight(current_sf - 1), + shortening_layer, + blocks, + upsampling_layer, + tl.LayerNorm(), + tl.Add() + ] + + pre_decoder_blocks = create_decoder_blocks(n_pre_decoder_blocks, 1, + vanilla_attn_type) + + post_decoder_blocks = create_decoder_blocks(n_post_decoder_blocks, 1, + vanilla_attn_type) + + valley = create_hourglass_valley(hierarchy_shorten_factors, + hierarchy_n_layers, 1) + + # Assemble and return the model. + return tl.Serial( # tokens (or chunked tuple of tokens) + tl.ShiftRight(mode=mode), # toks + token_encoder, # vecs + pre_decoder_blocks, # vecs + valley, # shortened vecs + post_decoder_blocks, # vecs + tl.Dense(vocab_size), # vecs + ) diff --git a/trax/models/research/hourglass_test.py b/trax/models/research/hourglass_test.py new file mode 100644 index 000000000..b20114472 --- /dev/null +++ b/trax/models/research/hourglass_test.py @@ -0,0 +1,145 @@ +# coding=utf-8 +# Copyright 2021 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Tests for Hourglass model.""" + +from absl.testing import absltest +from absl.testing import parameterized +import gin +import jax +import numpy as np +from trax import fastmath +from trax import layers as tl +from trax import shapes +import trax.models.research.hourglass as hourglass +import trax.layers.research.resampling as resampling + + +class HourglassTest(parameterized.TestCase): + def _check_forward_shape(self, model, input_shape, output_vocab_size): + x = np.ones(input_shape).astype(np.int32) + model.init(shapes.signature(x)) + y = model(x) + self.assertEqual(y.shape, (*input_shape, output_vocab_size)) + + def test_hourglass_lm_forward_shape(self): + d_model = 16 + vocab_size = 7 + model = hourglass.HourglassLM( + vocab_size, + hierarchy='2@3 2@6 2@3', + vanilla_layers=(1, 1), + d_model=d_model, + d_ff=d_model, + n_heads=2, + ) + + batch_size, seq_len = 3, 24 + self._check_forward_shape(model, + input_shape=(batch_size, seq_len), + output_vocab_size=vocab_size) + + def test_lsh_attention_in_vanilla(self): + d_model = 16 + vocab_size = 7 + + gin.bind_parameter('PureLSHSelfAttentionWrapper.pure_lsh_implementation', + tl.PureLSHSelfAttention) + gin.bind_parameter('PureLSHSelfAttention.chunk_len', 2) + + model = hourglass.HourglassLM( + vocab_size, + hierarchy='2@3', + vanilla_layers=(1, 1), + d_model=d_model, + d_ff=d_model, + n_heads=2, + vanilla_attn_type=tl.PureLSHSelfAttentionWrapper, + downsampling_fn=resampling.LinearPooling, + upsampling_fn=resampling.LinearUpsampling, + ) + + batch_size, seq_len = 3, 12 + self._check_forward_shape( + model, input_shape=(batch_size, seq_len), output_vocab_size=vocab_size) + + def _test_autoregressive_property(self, model, input_shape, + output_vocab_size): + rng_1 = jax.random.PRNGKey(0) + rng_2 = jax.random.PRNGKey(1) + + def _get_output_logits(unitialized_eval_model: tl.Layer, x): + input_signature = shapes.signature(x) + unitialized_eval_model.init(input_signature, rng=rng_1, use_cache=False) + + output_logits, *_ = unitialized_eval_model(x, rng=rng_1) + return output_logits + + def check_autoregressive_property(model): + with fastmath.use_backend(fastmath.Backend.JAX): + x_1 = jax.random.randint(rng_1, input_shape, 0, output_vocab_size) + y_1 = _get_output_logits(model, x_1) + + x_2 = jax.random.randint(rng_2, input_shape, 0, output_vocab_size) + + for i in range(input_shape[1]): + masked_x_2 = np.concatenate((x_1[:, :i], x_2[:, i:]), axis=1) + + y_2 = _get_output_logits(model, masked_x_2) + self.assertEqual(y_2.shape[0], input_shape[1]) + np.testing.assert_array_almost_equal(y_1[:i + 1], y_2[:i + 1]) + + check_autoregressive_property(model) + + def test_hourglass_lm_autoregressive_property(self): + d_model = 8 + vocab_size = 26 + + model_single_stage = hourglass.HourglassLM( + vocab_size, + hierarchy='2@4', + vanilla_layers=(1, 1), + d_model=d_model, + d_ff=d_model, + n_heads=2, + ) + + model_multi_stage = hourglass.HourglassLM( + vocab_size, + hierarchy='2@3 2@6 2@3', + vanilla_layers=(1, 1), + d_model=d_model, + d_ff=d_model, + n_heads=2, + ) + + input_shape = (1, 12) + self._test_autoregressive_property(model_single_stage, input_shape, + output_vocab_size=vocab_size) + self._test_autoregressive_property(model_multi_stage, input_shape, + output_vocab_size=vocab_size) + + def test_parse_hourglass_hierarchy(self): + self.assertEqual(hourglass._parse_hierarchy('6@3'), ([6], [3])) + self.assertEqual(hourglass._parse_hierarchy('3@2 2@6 5@24 2@6 3@2'), ( + [3, 2, 5], [2, 3, 4] + )) + self.assertRaises(ValueError, hourglass._parse_hierarchy, '1@2 2@3 1@2') + self.assertRaises(ValueError, hourglass._parse_hierarchy, '1@2 2@3') + + +if __name__ == '__main__': + absltest.main() diff --git a/trax/supervised/configs/relformer_cifar10.gin b/trax/supervised/configs/hourglass_cifar10.gin similarity index 52% rename from trax/supervised/configs/relformer_cifar10.gin rename to trax/supervised/configs/hourglass_cifar10.gin index 2fa18c332..0eac45735 100644 --- a/trax/supervised/configs/relformer_cifar10.gin +++ b/trax/supervised/configs/hourglass_cifar10.gin @@ -18,10 +18,12 @@ import trax.models import trax.optimizers import trax.supervised.trainer_lib +train_steps = 100000 + # Parameters for batcher: # ============================================================================== batcher.data_streams = @data.data_streams -batcher.batch_size_per_device = 4 +batcher.batch_size_per_device = 1 batcher.eval_batch_size = 8 batcher.max_eval_length = 3072 # 32 * 32 * 3 batcher.variable_shapes = False @@ -35,58 +37,38 @@ data_streams.target_name = 'image' data_streams.bare_preprocess_fn = \ @data.downsampled_imagenet_flatten_bare_preprocess -# Parameters for multifactor: -# ============================================================================== -# 0.0442 ~= 512^-0.5 = d_model^-0.5 -multifactor.constant = 0.01 -multifactor.factors = 'constant * rsqrt_normalized_decay' -multifactor.warmup_steps = 10000 +# Parameters for multifactor: # ================================================ +multifactor.constant = 1e-3 +multifactor.factors = 'constant * linear_warmup * cosine_decay' +multifactor.warmup_steps = 5000 +multifactor.steps_per_cycle = %train_steps -# Parameters for Adafactor: +# Parameters for Adam: # ============================================================================== -Adafactor.learning_rate=0.1 -Adafactor.factored=True -Adafactor.multiply_by_parameter_scale=True -Adafactor.do_clipping=True -Adafactor.do_momentum=True -Adafactor.momentum_in_bfloat16=False -Adafactor.beta1=0.9 -Adafactor.decay_rate=1.0 -Adafactor.clipping_threshold=1.0 -Adafactor.weight_decay_rate=0 -Adafactor.weight_decay_n_steps=0 -Adafactor.epsilon1=1e-16 -Adafactor.epsilon2=1e-3 +Adam.weight_decay_rate=0.0 +Adam.b1 = 0.9 +Adam.b2 = 0.98 +Adam.eps = 1e-9 # Parameters for train: # ============================================================================== train.eval_frequency = 2000 train.eval_steps = 625 +train.checkpoints_at = [100000] +train.model = @trax.models.HourglassLM +train.optimizer = @trax.optimizers.Adam +train.steps = %train_steps -train.model = @trax.models.RelformerLM # @trax.models.TransformerLM -train.optimizer = @trax.optimizers.Adafactor -train.steps = 200000 -train.checkpoints_at = [50000, 100000, 200000] - -# Parameters for RelformerLM: -# ============================================================================== -RelformerLM.d_model = 512 -RelformerLM.d_ff = 2048 -RelformerLM.vanilla_layers=(3, 3) -RelformerLM.n_funnel_blocks=(24,) -RelformerLM.shorten_factors=(3,) -RelformerLM.dropout = 0.0 -RelformerLM.mode = 'train' -RelformerLM.n_heads = 8 -RelformerLM.vocab_size = 256 -# Parameters for TransformerLM: +# Parameters for HourglassLM: # ============================================================================== -TransformerLM.d_model = 512 -TransformerLM.d_ff = 2048 -TransformerLM.dropout = 0.0 -TransformerLM.max_len = 3072 -TransformerLM.mode = 'train' -TransformerLM.n_heads = 8 -TransformerLM.n_layers = 12 -TransformerLM.vocab_size = 256 +HourglassLM.d_model = 512 +HourglassLM.d_ff = 2048 +HourglassLM.vanilla_layers = (1, 1) +HourglassLM.hierarchy = '8@3' +HourglassLM.dropout = 0.0 +HourglassLM.mode = 'train' +HourglassLM.n_heads = 8 +HourglassLM.vocab_size = 256 +HourglassLM.attention_downsampling_fn = @LinearPooling +HourglassLM.attention_upsampling_fn = @LinearUpsampling diff --git a/trax/supervised/configs/relformer_enwik8.gin b/trax/supervised/configs/hourglass_enwik8.gin similarity index 75% rename from trax/supervised/configs/relformer_enwik8.gin rename to trax/supervised/configs/hourglass_enwik8.gin index 5c7fa2902..e65b64720 100644 --- a/trax/supervised/configs/relformer_enwik8.gin +++ b/trax/supervised/configs/hourglass_enwik8.gin @@ -22,8 +22,8 @@ import trax.supervised.trainer_lib # Parameters for batcher: # ============================================================================== batcher.data_streams = @data.data_streams -batcher.max_eval_length = 2048 -batcher.buckets = ([2048], [8]) +batcher.max_eval_length = 2049 +batcher.buckets = ([2049], [8]) batcher.id_to_mask = 0 # Parameters for data_streams: @@ -38,11 +38,11 @@ data_streams.input_name = 'targets' multifactor.constant = 4.1e-4 multifactor.factors = 'constant * linear_warmup * cosine_decay' multifactor.warmup_steps = 4000 -multifactor.steps_per_cycle = 200000 +multifactor.steps_per_cycle = 350000 # Parameters for Adam: # ============================================================================== -Adam.weight_decay_rate=0.0 +Adam.weight_decay_rate = 0.0 Adam.b1 = 0.9 Adam.b2 = 0.98 Adam.eps = 1e-9 @@ -51,22 +51,21 @@ Adam.eps = 1e-9 # ============================================================================== train.eval_frequency = 2000 train.eval_steps = 305 -train.model = @trax.models.RelformerLM +train.model = @trax.models.HourglassLM train.optimizer = @trax.optimizers.Adam -train.steps = 400000 +train.steps = 263000 train.save_graphs = False -train.checkpoints_at = [200000, 300000, 400000] +train.checkpoints_at = [150000, 175000, 263000] -# Parameters for RelformerLM: +# Parameters for HourglassLM: # ============================================================================== -RelformerLM.d_ff = 2048 -RelformerLM.d_model = 512 -RelformerLM.dropout = 0.15 -RelformerLM.vanilla_layers = (6,6) -RelformerLM.n_funnel_blocks = (8,) -RelformerLM.shorten_factors = (2,) -RelformerLM.n_heads = 8 -RelformerLM.vocab_size = 256 -RelformerLM.attention_upsampling_fn = @NoUpsampling -RelformerLM.ff_activation = @trax.layers.FastGelu +HourglassLM.d_ff = 2048 +HourglassLM.d_model = 512 +HourglassLM.dropout = 0.2 +HourglassLM.vanilla_layers = (5,5) +HourglassLM.hierarchy = '24@3' +HourglassLM.n_heads = 8 +HourglassLM.vocab_size = 256 +HourglassLM.attention_upsampling_fn = @NaiveUpsampling +HourglassLM.ff_activation = @trax.layers.FastGelu diff --git a/trax/supervised/configs/relformer_imagenet32.gin b/trax/supervised/configs/hourglass_imagenet32.gin similarity index 84% rename from trax/supervised/configs/relformer_imagenet32.gin rename to trax/supervised/configs/hourglass_imagenet32.gin index 563d6129a..36cfa452f 100644 --- a/trax/supervised/configs/relformer_imagenet32.gin +++ b/trax/supervised/configs/hourglass_imagenet32.gin @@ -53,7 +53,7 @@ Adam.eps = 1e-09 # ============================================================================== train.eval_frequency = 1000 train.eval_steps = 512 -train.model = @trax.models.RelformerLM +train.model = @trax.models.HourglassLM train.optimizer = @trax.optimizers.Adam train.steps = 400000 train.checkpoints_at = \ @@ -62,15 +62,14 @@ train.checkpoints_at = \ train.permanent_checkpoints_at = \ [100000, 150000, 200000, 250000, 300000, 350000] -# Parameters for RelformerLM: +# Parameters for HourglassLM: # ============================================================================== -RelformerLM.d_ff = 2048 -RelformerLM.d_model = 512 -RelformerLM.dropout = 0.01 -RelformerLM.n_funnel_blocks = (24,) -RelformerLM.n_heads = 8 -RelformerLM.shorten_factors = (3,) -RelformerLM.vanilla_layers = (3, 3) -RelformerLM.vocab_size = 256 -RelformerLM.attention_downsampling_fn = @LinearPooling -RelformerLM.ff_activation = @trax.layers.FastGelu +HourglassLM.d_ff = 2048 +HourglassLM.d_model = 512 +HourglassLM.dropout = 0.01 +HourglassLM.n_heads = 8 +HourglassLM.vanilla_layers = (3, 3) +HourglassLM.hierarchy = '24@3' +HourglassLM.vocab_size = 256 +HourglassLM.attention_downsampling_fn = @LinearPooling +HourglassLM.ff_activation = @trax.layers.FastGelu diff --git a/trax/supervised/configs/relformer_imagenet64.gin b/trax/supervised/configs/hourglass_imagenet64.gin similarity index 86% rename from trax/supervised/configs/relformer_imagenet64.gin rename to trax/supervised/configs/hourglass_imagenet64.gin index df117dc30..dca4cd5af 100644 --- a/trax/supervised/configs/relformer_imagenet64.gin +++ b/trax/supervised/configs/hourglass_imagenet64.gin @@ -56,7 +56,7 @@ Adam.eps = 1e-9 # ============================================================================== train.eval_frequency = 1000 train.eval_steps = 512 -train.model = @trax.models.RelformerLM +train.model = @trax.models.HourglassLM train.optimizer = @trax.optimizers.Adam train.steps = 500000 train.save_graphs = False @@ -86,18 +86,17 @@ layers.SelfAttention.chunk_len = 512 layers.SelfAttention.n_chunks_after = 0 layers.SelfAttention.n_chunks_before = 1 -# Parameters for RelformerLM: +# Parameters for HourglassLM: # ============================================================================== -RelformerLM.d_model = 768 -RelformerLM.d_ff = 3072 -RelformerLM.dropout = 0.0 -RelformerLM.ff_activation = @trax.layers.FastGelu -RelformerLM.mode = 'train' -RelformerLM.n_heads = 8 -RelformerLM.n_funnel_blocks = (12,) -RelformerLM.vanilla_layers = (3, 3) -RelformerLM.shorten_factors = (3,) -RelformerLM.vocab_size = 256 -RelformerLM.vanilla_attn_type = %vanilla_attn_type -RelformerLM.downsampling_fn = @LinearPooling -RelformerLM.upsampling_fn = @LinearUpsampling +HourglassLM.d_model = 768 +HourglassLM.d_ff = 3072 +HourglassLM.dropout = 0.0 +HourglassLM.ff_activation = @trax.layers.FastGelu +HourglassLM.mode = 'train' +HourglassLM.n_heads = 8 +HourglassLM.vanilla_layers = (3, 3) +HourglassLM.hierarchy = '12@3' +HourglassLM.vocab_size = 256 +HourglassLM.vanilla_attn_type = %vanilla_attn_type +HourglassLM.downsampling_fn = @LinearPooling +HourglassLM.upsampling_fn = @LinearUpsampling diff --git a/trax/supervised/configs/relformer_wiki40b.gin b/trax/supervised/configs/hourglass_wiki40b.gin similarity index 89% rename from trax/supervised/configs/relformer_wiki40b.gin rename to trax/supervised/configs/hourglass_wiki40b.gin index 2c49b5c94..98f7ee1cf 100644 --- a/trax/supervised/configs/relformer_wiki40b.gin +++ b/trax/supervised/configs/hourglass_wiki40b.gin @@ -20,7 +20,7 @@ import trax.supervised.trainer_lib # Macros: # ============================================================================== train_batch = 128 -valid_batch = 128 +valid_batch = 8 max_length = 2048 vocab_size = 32000 @@ -47,6 +47,7 @@ make_inputs.eval_stream = [ @data.PadToLength(), @data.ConcatenateToLMInput(), @data.AddLossWeights(), + @data.Shuffle(), @validation/data.Batch(), ] @@ -80,7 +81,7 @@ validation/data.Batch.batch_size = %valid_batch # ============================================================================== train.eval_frequency = 2500 train.eval_steps = 600 -train.model = @trax.models.TransformerLM # @trax.models.RelformerLM +train.model = @trax.models.HourglassLM # @trax.models.TransformerLM train.steps = 125000 train.optimizer = @trax.optimizers.Adam train.permanent_checkpoints_at = [50000, 100000, 125000] @@ -112,15 +113,14 @@ TransformerLM.n_heads = 12 TransformerLM.n_layers = 12 TransformerLM.vocab_size = %vocab_size -# Parameters for RelformerLM: +# Parameters for HourglassLM: # ============================================================================== -RelformerLM.d_model = 768 -RelformerLM.d_ff = 3072 -RelformerLM.vanilla_layers=(4, 4) -RelformerLM.n_funnel_blocks=(8,) -RelformerLM.shorten_factors=(2,) -RelformerLM.dropout = 0.0 -RelformerLM.mode = 'train' -RelformerLM.n_heads = 12 -RelformerLM.vocab_size = %vocab_size -RelformerLM.attention_upsampling_fn = @NoUpsampling +HourglassLM.d_model = 512 +HourglassLM.d_ff = 2048 +HourglassLM.vanilla_layers=(4, 4) +HourglassLM.hierarchy = '8@4' +HourglassLM.dropout = 0.0 +HourglassLM.mode = 'train' +HourglassLM.n_heads = 8 +HourglassLM.vocab_size = %vocab_size +HourglassLM.attention_upsampling_fn = @LinearUpsampling diff --git a/trax/supervised/configs/relformer_enwik8_sweep.yaml b/trax/supervised/configs/relformer_enwik8_sweep.yaml deleted file mode 100644 index 7c7a807a7..000000000 --- a/trax/supervised/configs/relformer_enwik8_sweep.yaml +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2021 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Recommended hyper-parameters sweep: - -sweep1: - n_funnel_blocks: [16] - shorten_factors: [4] - d_model: [768] - dropout: [0.15, 0.3] - d_ff: [3072] - -sweep2: - n_funnel_blocks: [12, 16] - shorten_factors: [2] - d_model: [512, 768] - dropout: [0.15, 0.3] - d_ff: [3072] - diff --git a/trax/supervised/configs/relformer_imagenet32_sweep.yaml b/trax/supervised/configs/relformer_imagenet32_sweep.yaml deleted file mode 100644 index fae1006c1..000000000 --- a/trax/supervised/configs/relformer_imagenet32_sweep.yaml +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright 2021 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -sweep1: - d_model: [512] - d_ff: [2048] - n_heads: [8] - -sweep2: - d_model: [768] - d_ff: [3072] - n_heads: [12] \ No newline at end of file diff --git a/trax/supervised/configs/scientific_papers_relformer_lm.gin b/trax/supervised/configs/scientific_papers_relformer_lm.gin deleted file mode 100644 index 964dbb36a..000000000 --- a/trax/supervised/configs/scientific_papers_relformer_lm.gin +++ /dev/null @@ -1,126 +0,0 @@ -# Copyright 2021 The Trax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import trax.data -import trax.models -import trax.optimizers -import trax.supervised.trainer_lib - -# Macros: -# ============================================================================== -# Maximum length of an input sequence. -max_len = 16384 -pos_axial_shape = (128, 128) # should multiply out to max_len -attn_kv = 128 -n_layers = 8 # TODO(wgaj): used to be 12 -dropout = 0.2 -d_model = 1024 -pos_d_axial_embs = (512, 512) -d_ff = 4096 -n_heads = 8 -ff_chunk_size = 0 -attn_type = @trax.layers.SelfAttention - -# Parameters for TFDS data pipeline: -# ============================================================================== -# TODO(wgaj): Add shuffling. -make_inputs.train_stream = [ - @train/data.TFDS(), - @data.ConvertToUnicode(), - @data.Tokenize(), - @data.FilterEmptyExamples(), - @data.TruncateToLength(), - @data.AppendValue(), - @data.ConcatenateToLMInput(), - @data.Batch(), -] -train/data.TFDS.dataset_name = 'scientific_papers/arxiv:1.1.1' -train/data.TFDS.keys = ('article', 'abstract') -train/data.TFDS.train = True -data.ConvertToUnicode.keys = [0, 1] -data.Tokenize.vocab_file = 'gs://t5-data/vocabs/cc_all.32000/sentencepiece.model' -data.Tokenize.keys = [0, 1] -data.Tokenize.vocab_type = 'sentencepiece' -data.TruncateToLength.len_map = {0: (15359, ), 1: (1023, )} -data.AppendValue.val = {0:[0], 1:[1]} -data.ConcatenateToLMInput.pad_to_length = 16384 -data.Batch.batch_size = 8 - -make_inputs.eval_stream = [ - @eval/data.TFDS(), - @data.ConvertToUnicode(), - @data.Tokenize(), - @data.FilterEmptyExamples(), - @data.TruncateToLength(), - @data.AppendValue(), - @data.ConcatenateToLMInput(), - @data.Batch(), -] -eval/data.TFDS.dataset_name = 'scientific_papers/arxiv:1.1.1' -eval/data.TFDS.keys = ('article', 'abstract') -eval/data.TFDS.train = False - - -# Parameters for multifactor: -# ============================================================================== -multifactor.constant = 1.0 -multifactor.factors = 'constant * linear_warmup * rsqrt_decay' -multifactor.warmup_steps = 10000 - - -# Parameters for Adafactor: -# ============================================================================== -Adafactor.beta1 = 0.0 -Adafactor.decay_rate = 0.95 # Used to be 0.8 -Adafactor.clipping_threshold = 1.0 -#Adafactor.epsilon1 = 1e-16 -Adafactor.epsilon1 = 1e-25 -Adafactor.epsilon2 = 0.001 -Adafactor.factored = True -Adafactor.multiply_by_parameter_scale = True - - -# Parameters for train: -# ============================================================================== -train.eval_frequency = 500 -train.eval_steps = 10 -train.model = @trax.models.RelformerLM -train.steps = 1000000 -train.optimizer = @trax.optimizers.Adafactor -train.checkpoint_highest = 'neg_log_perplexity' -train.checkpoint_lowest = 'loss' -# train.use_memory_efficient_trainer = True -train.inputs = @trax.data.make_inputs - - -# We are using T5's 32k SPM model by default. -vocab_size = 32000 - -# Parameters for SelfAttention: -# ============================================================================== -trax.layers.SelfAttention.chunk_len = 128 -trax.layers.SelfAttention.n_chunks_after = 0 -trax.layers.SelfAttention.n_chunks_before = 0 -trax.layers.SelfAttention.n_parallel_heads = 1 -trax.layers.SelfAttention.attention_dropout = 0.2 - -# Parameters for RelformerLM: -# ============================================================================== -RelformerLM.vocab_size=%vocab_size # Includes pad token and unused EOS token -RelformerLM.d_model = 768 -RelformerLM.d_ff = 3072 -RelformerLM.vanilla_layers = (3, 3) -RelformerLM.shorten_factors = (4,) -RelformerLM.n_heads = 8 -RelformerLM.dropout = 0.1