From e04da2e35c29a13177bafa7df45d287a0f02f77c Mon Sep 17 00:00:00 2001 From: Jon Craton Date: Sat, 20 Dec 2025 09:02:08 -0500 Subject: [PATCH] Initial t5gemma implementation --- README.md | 2 +- include/ctranslate2/layers/transformer.h | 6 + python/ctranslate2/converters/transformers.py | 269 ++++++++++++++++++ python/ctranslate2/specs/transformer_spec.py | 27 ++ src/layers/transformer.cc | 89 ++++++ 5 files changed, 392 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index fb91f5eb3..2955c62c1 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ The project implements a custom runtime that applies many performance optimizati The following model types are currently supported: -* Encoder-decoder models: Transformer base/big, M2M-100, NLLB, BART, mBART, Pegasus, T5, Whisper +* Encoder-decoder models: Transformer base/big, M2M-100, NLLB, BART, mBART, Pegasus, T5, T5Gemma, Whisper * Decoder-only models: GPT-2, GPT-J, GPT-NeoX, OPT, BLOOM, MPT, Llama, Mistral, Gemma, CodeGen, GPTBigCode, Falcon, Qwen2 * Encoder-only models: BERT, DistilBERT, XLM-RoBERTa diff --git a/include/ctranslate2/layers/transformer.h b/include/ctranslate2/layers/transformer.h index 017f9b675..08ff23bdb 100644 --- a/include/ctranslate2/layers/transformer.h +++ b/include/ctranslate2/layers/transformer.h @@ -68,6 +68,10 @@ namespace ctranslate2 { private: std::unique_ptr _self_attention; + const std::unique_ptr _input_layer_norm; + const std::unique_ptr _post_attention_layer_norm; + const std::unique_ptr _pre_feedforward_layer_norm; + const std::unique_ptr _post_feedforward_layer_norm; const FeedForwardNetwork _ff; }; @@ -121,6 +125,8 @@ namespace ctranslate2 { const std::unique_ptr _post_attention_layer_norm; const std::unique_ptr _pre_feedforward_layer_norm; const std::unique_ptr _post_feedforward_layer_norm; + const std::unique_ptr _pre_cross_attention_layer_norm; + const std::unique_ptr _post_cross_attention_layer_norm; const std::unique_ptr _encoder_attention; const FeedForwardNetwork _ff; }; diff --git a/python/ctranslate2/converters/transformers.py b/python/ctranslate2/converters/transformers.py index b20ccbf62..294525c2d 100644 --- a/python/ctranslate2/converters/transformers.py +++ b/python/ctranslate2/converters/transformers.py @@ -1325,6 +1325,275 @@ def architecture_name(self): return "MT5ForConditionalGeneration" +@register_loader("T5GemmaConfig") +class T5GemmaLoader(ModelLoader): + @property + def architecture_name(self): + return "T5GemmaForConditionalGeneration" + + def get_model_spec(self, model): + # Get encoder and decoder configs + encoder_config = model.config.encoder + decoder_config = model.config.decoder + + # Extract attention head configuration for encoder + num_heads_enc = encoder_config.num_attention_heads + num_heads_kv_enc = getattr( + encoder_config, + "num_key_value_heads", + num_heads_enc, + ) + if num_heads_kv_enc == num_heads_enc: + num_heads_kv_enc = None + + # Extract attention head configuration for decoder + num_heads_dec = decoder_config.num_attention_heads + num_heads_kv_dec = getattr( + decoder_config, + "num_key_value_heads", + num_heads_dec, + ) + if num_heads_kv_dec == num_heads_dec: + num_heads_kv_dec = None + + # Extract head dimension for decoder (needed for RoPE) + head_dim_dec = getattr( + decoder_config, + "head_dim", + decoder_config.hidden_size // num_heads_dec, + ) + + # Get activation function + activation_config = getattr( + decoder_config, "hidden_activation", "gelu_pytorch_tanh" + ) + + # Create encoder-decoder spec with Gemma2 features + # T5Gemma is based on Gemma2 architecture adapted for + # encoder-decoder. Note: We can't use TransformerSpec.from_config + # for RoPE, so we create specs manually + + activation = ( + common_spec.Activation.GELU + if activation_config == "gelu" + else common_spec.Activation.GELUTanh + ) + + # Create encoder spec (T5Gemma encoder doesn't use RoPE) + encoder_spec = transformer_spec.TransformerEncoderSpec( + encoder_config.num_hidden_layers, + num_heads_enc, + pre_norm=True, + activation=activation, + ffn_glu=True, + rms_norm=True, + pre_post_layer_norm=True, + ) + + # Create decoder spec with RoPE (T5Gemma decoder uses RoPE) + decoder_spec = transformer_spec.TransformerDecoderSpec( + decoder_config.num_hidden_layers, + num_heads_dec, + pre_norm=True, + activation=activation, + with_encoder_attention=True, + ffn_glu=True, + rms_norm=True, + pre_post_layer_norm=True, + rotary_dim=0, # Apply RoPE to all dimensions + rotary_interleave=False, # Gemma2-style RoPE + rotary_base=getattr(decoder_config, "rope_theta", 10000), + num_heads_kv=num_heads_kv_dec, + head_dim=head_dim_dec, + ) + + # Create full transformer spec + spec = transformer_spec.TransformerSpec(encoder_spec, decoder_spec) + + # Set encoder and decoder following Gemma2 pattern + # T5Gemma has encoder/decoder under model.model, not directly + # under model + self.set_encoder(spec.encoder, model.model.encoder, encoder_config) + self.set_decoder(spec.decoder, model.model.decoder, decoder_config) + + # Handle lm_head - in T5Gemma it has an out_proj wrapper + # (T5GemmaLMHead). Handle tied word embeddings like Gemma2 + if model.config.tie_word_embeddings: + # When embeddings are tied, set projection to use decoder + # embeddings + decoder_emb = ( + spec.decoder.embeddings[0] + if isinstance(spec.decoder.embeddings, list) + else spec.decoder.embeddings + ) + spec.decoder.projection.weight = decoder_emb.weight + # T5-style models require output scaling when embeddings are tied + spec.decoder.scale_outputs = decoder_config.hidden_size**-0.5 + else: + # If not tied, set projection weights explicitly from lm_head + lm_head = model.lm_head + if hasattr(lm_head, "out_proj"): + lm_head = lm_head.out_proj + self.set_linear(spec.decoder.projection, lm_head) + + return spec + + def get_vocabulary(self, model, tokenizer): + tokens = super().get_vocabulary(model, tokenizer) + + extra_ids = model.config.vocab_size - len(tokens) + for i in range(extra_ids): + tokens.append("" % i) + if model.config.vocab_size < len(tokens): + tokens = tokens[: model.config.vocab_size] + + return tokens + + def set_vocabulary(self, spec, tokens): + spec.register_source_vocabulary(tokens) + spec.register_target_vocabulary(tokens) + + def set_config(self, config, model, tokenizer): + config.bos_token = tokenizer.bos_token + config.eos_token = tokenizer.eos_token + config.unk_token = tokenizer.unk_token + config.layer_norm_epsilon = model.config.decoder.rms_norm_eps + if ( + hasattr(model.config, "decoder_start_token_id") + and model.config.decoder_start_token_id is not None + ): + config.decoder_start_token = tokenizer.convert_ids_to_tokens( + model.config.decoder_start_token_id + ) + else: + config.decoder_start_token = tokenizer.bos_token + + def set_layer_norm(self, spec, layer_norm): + spec.gamma = layer_norm.weight + spec.layer_norm_use_residual = True + + def set_encoder(self, spec, encoder, encoder_config): + spec.scale_embeddings = True + # Set Gemma2-style embedding scaling + embeddings = ( + spec.embeddings[0] if isinstance(spec.embeddings, list) else spec.embeddings + ) + self.set_embeddings(embeddings, encoder.embed_tokens) + embeddings.multiply_by_sqrt_depth = encoder_config.hidden_size**0.5 + + # T5Gemma encoder has a final layer norm + self.set_layer_norm(spec.layer_norm, encoder.norm) + + for i, (layer_spec, layer) in enumerate(zip(spec.layer, encoder.layers)): + # T5Gemma encoder follows Gemma2 pattern with pre+post layer norms + # Map T5Gemma naming to CTranslate2 spec naming: + # - pre_self_attn_layernorm -> input_layer_norm + # - post_self_attn_layernorm -> post_attention_layer_norm + # - pre_feedforward_layernorm -> pre_feedforward_layer_norm + # - post_feedforward_layernorm -> post_feedforward_layer_norm + + self.set_layer_norm( + layer_spec.input_layer_norm, layer.pre_self_attn_layernorm + ) + self.set_layer_norm( + layer_spec.post_attention_layer_norm, layer.post_self_attn_layernorm + ) + self.set_layer_norm( + layer_spec.pre_feedforward_layer_norm, layer.pre_feedforward_layernorm + ) + self.set_layer_norm( + layer_spec.post_feedforward_layer_norm, layer.post_feedforward_layernorm + ) + + # Set attention weights + wq = layer.self_attn.q_proj.weight + wk = layer.self_attn.k_proj.weight + wv = layer.self_attn.v_proj.weight + wo = layer.self_attn.o_proj.weight + + layer_spec.self_attention.linear[0].weight = torch.cat([wq, wk, wv]) + layer_spec.self_attention.linear[1].weight = wo + + # Set FFN weights (GeGLU activation) + self.set_linear(layer_spec.ffn.linear_0, layer.mlp.gate_proj) + self.set_linear(layer_spec.ffn.linear_0_noact, layer.mlp.up_proj) + self.set_linear(layer_spec.ffn.linear_1, layer.mlp.down_proj) + + delattr(layer, "self_attn") + delattr(layer, "mlp") + gc.collect() + + def set_decoder(self, spec, decoder, decoder_config): + spec.scale_embeddings = True + # Set Gemma2-style embedding scaling + embeddings = ( + spec.embeddings[0] if isinstance(spec.embeddings, list) else spec.embeddings + ) + self.set_embeddings(embeddings, decoder.embed_tokens) + embeddings.multiply_by_sqrt_depth = decoder_config.hidden_size**0.5 + + self.set_layer_norm(spec.layer_norm, decoder.norm) + + for i, (layer_spec, layer) in enumerate(zip(spec.layer, decoder.layers)): + # T5Gemma decoder follows Gemma2 pattern with pre+post layer norms + # Map T5Gemma naming to CTranslate2 spec naming: + # - pre_self_attn_layernorm -> input_layer_norm + # - post_self_attn_layernorm -> post_attention_layer_norm + # - pre_cross_attn_layernorm -> attention.layer_norm + # - post_cross_attn_layernorm -> post_cross_attention_layer_norm + # - pre_feedforward_layernorm -> pre_feedforward_layer_norm + # - post_feedforward_layernorm -> post_feedforward_layer_norm + + self.set_layer_norm( + layer_spec.input_layer_norm, layer.pre_self_attn_layernorm + ) + self.set_layer_norm( + layer_spec.post_attention_layer_norm, layer.post_self_attn_layernorm + ) + self.set_layer_norm( + layer_spec.attention.layer_norm, layer.pre_cross_attn_layernorm + ) + self.set_layer_norm( + layer_spec.post_cross_attention_layer_norm, + layer.post_cross_attn_layernorm, + ) + self.set_layer_norm( + layer_spec.pre_feedforward_layer_norm, layer.pre_feedforward_layernorm + ) + self.set_layer_norm( + layer_spec.post_feedforward_layer_norm, layer.post_feedforward_layernorm + ) + + # Set self-attention weights + wq = layer.self_attn.q_proj.weight + wk = layer.self_attn.k_proj.weight + wv = layer.self_attn.v_proj.weight + wo = layer.self_attn.o_proj.weight + + layer_spec.self_attention.linear[0].weight = torch.cat([wq, wk, wv]) + layer_spec.self_attention.linear[1].weight = wo + + # Set cross-attention weights + wq_cross = layer.cross_attn.q_proj.weight + wk_cross = layer.cross_attn.k_proj.weight + wv_cross = layer.cross_attn.v_proj.weight + wo_cross = layer.cross_attn.o_proj.weight + + layer_spec.attention.linear[0].weight = wq_cross + layer_spec.attention.linear[1].weight = torch.cat([wk_cross, wv_cross]) + layer_spec.attention.linear[2].weight = wo_cross + + # Set FFN weights (GeGLU activation) + self.set_linear(layer_spec.ffn.linear_0, layer.mlp.gate_proj) + self.set_linear(layer_spec.ffn.linear_0_noact, layer.mlp.up_proj) + self.set_linear(layer_spec.ffn.linear_1, layer.mlp.down_proj) + + delattr(layer, "self_attn") + delattr(layer, "cross_attn") + delattr(layer, "mlp") + gc.collect() + + @register_loader("BloomConfig") class BloomLoader(ModelLoader): @property diff --git a/python/ctranslate2/specs/transformer_spec.py b/python/ctranslate2/specs/transformer_spec.py index 334ecc041..2fc74dfce 100644 --- a/python/ctranslate2/specs/transformer_spec.py +++ b/python/ctranslate2/specs/transformer_spec.py @@ -23,6 +23,7 @@ def __init__( ffn_glu: bool = False, rms_norm: bool = False, multi_query_attention: bool = False, + pre_post_layer_norm: bool = False, ): """Initializes a Transformer encoder specification. @@ -44,6 +45,7 @@ def __init__( https://arxiv.org/abs/2002.05202. rms_norm: Use the root mean square layer normalization. multi_query_attention: Use multi-query attention. + pre_post_layer_norm: Add post layer norm for each pre norm layer. """ self.multi_query_attention = multi_query_attention self.num_heads = np.dtype("int16").type(num_heads) @@ -67,6 +69,7 @@ def __init__( ffn_glu=ffn_glu, rms_norm=rms_norm, num_heads_kv=1 if multi_query_attention else None, + pre_post_layer_norm=pre_post_layer_norm, ) for _ in range(num_layers) ] @@ -255,6 +258,7 @@ def __init__( rms_norm=False, num_heads_kv=None, sliding_window=None, + pre_post_layer_norm=False, ): self.self_attention = attention_spec.MultiHeadAttentionSpec( self_attention=True, @@ -266,6 +270,21 @@ def __init__( ) self.ffn = FeedForwardSpec(glu=ffn_glu, rms_norm=rms_norm) + if pre_post_layer_norm: + self.input_layer_norm = common_spec.LayerNormSpec(rms_norm=rms_norm) + self.post_attention_layer_norm = common_spec.LayerNormSpec( + rms_norm=rms_norm + ) + self.pre_feedforward_layer_norm = common_spec.LayerNormSpec( + rms_norm=rms_norm + ) + self.post_feedforward_layer_norm = common_spec.LayerNormSpec( + rms_norm=rms_norm + ) + + delattr(self.self_attention, "layer_norm") + delattr(self.ffn, "layer_norm") + class TransformerDecoderLayerSpec(model_spec.LayerSpec): def __init__( @@ -333,6 +352,10 @@ def __init__( self.post_attention_layer_norm = common_spec.LayerNormSpec( rms_norm=rms_norm ) + if with_encoder_attention: + self.post_cross_attention_layer_norm = common_spec.LayerNormSpec( + rms_norm=rms_norm + ) self.pre_feedforward_layer_norm = common_spec.LayerNormSpec( rms_norm=rms_norm ) @@ -417,6 +440,7 @@ def from_config( ffn_glu: bool = False, rms_norm: bool = False, multi_query_attention: bool = False, + pre_post_layer_norm: bool = False, ): """Creates a Transformer model specification. @@ -441,6 +465,7 @@ def from_config( https://arxiv.org/abs/2002.05202. rms_norm: Use the root mean square layer normalization. multi_query_attention: Use multi-query attention. + pre_post_layer_norm: Add post layer norm for each pre norm layer. """ if isinstance(num_layers, (list, tuple)): num_encoder_layers, num_decoder_layers = num_layers @@ -461,6 +486,7 @@ def from_config( ffn_glu=ffn_glu, rms_norm=rms_norm, multi_query_attention=multi_query_attention, + pre_post_layer_norm=pre_post_layer_norm, ) decoder = TransformerDecoderSpec( @@ -477,6 +503,7 @@ def from_config( ffn_glu=ffn_glu, rms_norm=rms_norm, multi_query_attention=multi_query_attention, + pre_post_layer_norm=pre_post_layer_norm, ) return cls(encoder, decoder) diff --git a/src/layers/transformer.cc b/src/layers/transformer.cc index 5ac5bfa36..fcf51512b 100644 --- a/src/layers/transformer.cc +++ b/src/layers/transformer.cc @@ -70,6 +70,13 @@ namespace ctranslate2 { num_heads, /*self_attention=*/true, pre_norm))) + , _input_layer_norm(build_optional_layer(model, scope + "/input_layer_norm")) + , _post_attention_layer_norm(build_optional_layer( + model, scope + "/post_attention_layer_norm")) + , _pre_feedforward_layer_norm(build_optional_layer( + model, scope + "/pre_feedforward_layer_norm")) + , _post_feedforward_layer_norm(build_optional_layer( + model, scope + "/post_feedforward_layer_norm")) , _ff(model, scope + "/ffn", pre_norm, activation_type) { } @@ -79,6 +86,44 @@ namespace ctranslate2 { const Padder* padder, StorageView* position_bias) const { PROFILE("TransformerEncoderLayer"); + + const DataType dtype = input.dtype(); + const Device device = input.device(); + + const bool pre_post_layer_norm = _post_feedforward_layer_norm && _pre_feedforward_layer_norm; + if (pre_post_layer_norm) { + StorageView hidden(dtype, device); + StorageView context(dtype, device); + (*_input_layer_norm)(input, hidden); + + if (_self_attention) + (*_self_attention)(hidden, + hidden, + lengths, + context, + nullptr, + nullptr, + nullptr, + padder, + padder, + true, + position_bias); + + (*_post_attention_layer_norm)(context, output); + ops::Add()(output, input, output); + + context = std::move(output); + (*_pre_feedforward_layer_norm)(context, output); + hidden = std::move(output); + + _ff(hidden, output); + + hidden = std::move(output); + (*_post_feedforward_layer_norm)(hidden, output); + ops::Add()(output, context, output); + return; + } + StorageView context(input.dtype(), input.device()); if (_self_attention) (*_self_attention)(input, @@ -124,6 +169,10 @@ namespace ctranslate2 { model, scope + "/pre_feedforward_layer_norm")) , _post_feedforward_layer_norm(build_optional_layer( model, scope + "/post_feedforward_layer_norm")) + , _pre_cross_attention_layer_norm(build_optional_layer( + model, scope + "/attention/layer_norm")) + , _post_cross_attention_layer_norm(build_optional_layer( + model, scope + "/post_cross_attention_layer_norm")) , _encoder_attention(build_optional_layer(model, scope + "/attention", num_heads, @@ -157,8 +206,11 @@ namespace ctranslate2 { if (pre_post_layer_norm) { StorageView hidden(dtype, device); StorageView context(dtype, device); + + // Pre-self-attention layer norm (*_input_layer_norm)(input, hidden); + // Self-attention if (_self_attention) (*_self_attention)(hidden, hidden, @@ -173,15 +225,52 @@ namespace ctranslate2 { position_bias, offset); + // Post-self-attention layer norm + residual (*_post_attention_layer_norm)(context, output); ops::Add()(output, input, output); + // Cross-attention (if present) + if (_encoder_attention) { + context = std::move(output); + + // Pre-cross-attention layer norm + if (_pre_cross_attention_layer_norm) { + (*_pre_cross_attention_layer_norm)(context, hidden); + } else { + hidden = std::move(context); + context = std::move(output); + } + + (*_encoder_attention)(hidden, + *memory, + memory_lengths, + output, + cached_attn_keys, + cached_attn_values, + attention, + input_padder, + memory_padder, + return_normalized_attention); + + // Post-cross-attention layer norm + residual + if (_post_cross_attention_layer_norm) { + hidden = std::move(output); + (*_post_cross_attention_layer_norm)(hidden, output); + ops::Add()(output, context, output); + } else { + ops::Add()(output, context, output); + } + } + + // Pre-FFN layer norm context = std::move(output); (*_pre_feedforward_layer_norm)(context, output); hidden = std::move(output); + // Feed-forward network _ff(hidden, output); + // Post-FFN layer norm + residual hidden = std::move(output); (*_post_feedforward_layer_norm)(hidden, output); ops::Add()(output, context, output);