Skip to content
Closed

T5Gemma #1955

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ The project implements a custom runtime that applies many performance optimizati

The following model types are currently supported:

* Encoder-decoder models: Transformer base/big, M2M-100, NLLB, BART, mBART, Pegasus, T5, Whisper
* Encoder-decoder models: Transformer base/big, M2M-100, NLLB, BART, mBART, Pegasus, T5, T5Gemma, Whisper
* Decoder-only models: GPT-2, GPT-J, GPT-NeoX, OPT, BLOOM, MPT, Llama, Mistral, Gemma, CodeGen, GPTBigCode, Falcon, Qwen2
* Encoder-only models: BERT, DistilBERT, XLM-RoBERTa

Expand Down
6 changes: 6 additions & 0 deletions include/ctranslate2/layers/transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ namespace ctranslate2 {

private:
std::unique_ptr<AttentionLayer> _self_attention;
const std::unique_ptr<const LayerNorm> _input_layer_norm;
const std::unique_ptr<const LayerNorm> _post_attention_layer_norm;
const std::unique_ptr<const LayerNorm> _pre_feedforward_layer_norm;
const std::unique_ptr<const LayerNorm> _post_feedforward_layer_norm;
const FeedForwardNetwork _ff;
};

Expand Down Expand Up @@ -121,6 +125,8 @@ namespace ctranslate2 {
const std::unique_ptr<const LayerNorm> _post_attention_layer_norm;
const std::unique_ptr<const LayerNorm> _pre_feedforward_layer_norm;
const std::unique_ptr<const LayerNorm> _post_feedforward_layer_norm;
const std::unique_ptr<const LayerNorm> _pre_cross_attention_layer_norm;
const std::unique_ptr<const LayerNorm> _post_cross_attention_layer_norm;
const std::unique_ptr<const AttentionLayer> _encoder_attention;
const FeedForwardNetwork _ff;
};
Expand Down
269 changes: 269 additions & 0 deletions python/ctranslate2/converters/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1325,6 +1325,275 @@ def architecture_name(self):
return "MT5ForConditionalGeneration"


@register_loader("T5GemmaConfig")
class T5GemmaLoader(ModelLoader):
@property
def architecture_name(self):
return "T5GemmaForConditionalGeneration"

def get_model_spec(self, model):
# Get encoder and decoder configs
encoder_config = model.config.encoder
decoder_config = model.config.decoder

# Extract attention head configuration for encoder
num_heads_enc = encoder_config.num_attention_heads
num_heads_kv_enc = getattr(
encoder_config,
"num_key_value_heads",
num_heads_enc,
)
if num_heads_kv_enc == num_heads_enc:
num_heads_kv_enc = None
Comment thread
jncraton marked this conversation as resolved.

# Extract attention head configuration for decoder
num_heads_dec = decoder_config.num_attention_heads
num_heads_kv_dec = getattr(
decoder_config,
"num_key_value_heads",
num_heads_dec,
)
if num_heads_kv_dec == num_heads_dec:
num_heads_kv_dec = None

# Extract head dimension for decoder (needed for RoPE)
head_dim_dec = getattr(
decoder_config,
"head_dim",
decoder_config.hidden_size // num_heads_dec,
)

# Get activation function
activation_config = getattr(
decoder_config, "hidden_activation", "gelu_pytorch_tanh"
)

# Create encoder-decoder spec with Gemma2 features
# T5Gemma is based on Gemma2 architecture adapted for
# encoder-decoder. Note: We can't use TransformerSpec.from_config
# for RoPE, so we create specs manually

activation = (
common_spec.Activation.GELU
if activation_config == "gelu"
else common_spec.Activation.GELUTanh
)
Comment thread
jncraton marked this conversation as resolved.

# Create encoder spec (T5Gemma encoder doesn't use RoPE)
encoder_spec = transformer_spec.TransformerEncoderSpec(
encoder_config.num_hidden_layers,
num_heads_enc,
pre_norm=True,
activation=activation,
ffn_glu=True,
rms_norm=True,
pre_post_layer_norm=True,
)

# Create decoder spec with RoPE (T5Gemma decoder uses RoPE)
decoder_spec = transformer_spec.TransformerDecoderSpec(
decoder_config.num_hidden_layers,
num_heads_dec,
pre_norm=True,
activation=activation,
with_encoder_attention=True,
ffn_glu=True,
rms_norm=True,
pre_post_layer_norm=True,
rotary_dim=0, # Apply RoPE to all dimensions
rotary_interleave=False, # Gemma2-style RoPE
rotary_base=getattr(decoder_config, "rope_theta", 10000),
num_heads_kv=num_heads_kv_dec,
head_dim=head_dim_dec,
)

# Create full transformer spec
spec = transformer_spec.TransformerSpec(encoder_spec, decoder_spec)

# Set encoder and decoder following Gemma2 pattern
# T5Gemma has encoder/decoder under model.model, not directly
# under model
self.set_encoder(spec.encoder, model.model.encoder, encoder_config)
self.set_decoder(spec.decoder, model.model.decoder, decoder_config)

# Handle lm_head - in T5Gemma it has an out_proj wrapper
# (T5GemmaLMHead). Handle tied word embeddings like Gemma2
if model.config.tie_word_embeddings:
# When embeddings are tied, set projection to use decoder
# embeddings
decoder_emb = (
spec.decoder.embeddings[0]
if isinstance(spec.decoder.embeddings, list)
else spec.decoder.embeddings
)
spec.decoder.projection.weight = decoder_emb.weight
# T5-style models require output scaling when embeddings are tied
spec.decoder.scale_outputs = decoder_config.hidden_size**-0.5
else:
# If not tied, set projection weights explicitly from lm_head
lm_head = model.lm_head
if hasattr(lm_head, "out_proj"):
lm_head = lm_head.out_proj
self.set_linear(spec.decoder.projection, lm_head)

return spec

def get_vocabulary(self, model, tokenizer):
tokens = super().get_vocabulary(model, tokenizer)

extra_ids = model.config.vocab_size - len(tokens)
for i in range(extra_ids):
tokens.append("<extra_id_%d>" % i)
if model.config.vocab_size < len(tokens):
tokens = tokens[: model.config.vocab_size]
Comment thread
jncraton marked this conversation as resolved.

return tokens

def set_vocabulary(self, spec, tokens):
spec.register_source_vocabulary(tokens)
spec.register_target_vocabulary(tokens)

def set_config(self, config, model, tokenizer):
config.bos_token = tokenizer.bos_token
config.eos_token = tokenizer.eos_token
config.unk_token = tokenizer.unk_token
config.layer_norm_epsilon = model.config.decoder.rms_norm_eps
if (
hasattr(model.config, "decoder_start_token_id")
and model.config.decoder_start_token_id is not None
):
config.decoder_start_token = tokenizer.convert_ids_to_tokens(
model.config.decoder_start_token_id
)
else:
config.decoder_start_token = tokenizer.bos_token

def set_layer_norm(self, spec, layer_norm):
spec.gamma = layer_norm.weight
spec.layer_norm_use_residual = True

def set_encoder(self, spec, encoder, encoder_config):
spec.scale_embeddings = True
# Set Gemma2-style embedding scaling
embeddings = (
spec.embeddings[0] if isinstance(spec.embeddings, list) else spec.embeddings
)
self.set_embeddings(embeddings, encoder.embed_tokens)
embeddings.multiply_by_sqrt_depth = encoder_config.hidden_size**0.5

# T5Gemma encoder has a final layer norm
self.set_layer_norm(spec.layer_norm, encoder.norm)

for i, (layer_spec, layer) in enumerate(zip(spec.layer, encoder.layers)):
# T5Gemma encoder follows Gemma2 pattern with pre+post layer norms
# Map T5Gemma naming to CTranslate2 spec naming:
# - pre_self_attn_layernorm -> input_layer_norm
# - post_self_attn_layernorm -> post_attention_layer_norm
# - pre_feedforward_layernorm -> pre_feedforward_layer_norm
# - post_feedforward_layernorm -> post_feedforward_layer_norm

self.set_layer_norm(
layer_spec.input_layer_norm, layer.pre_self_attn_layernorm
)
self.set_layer_norm(
layer_spec.post_attention_layer_norm, layer.post_self_attn_layernorm
)
self.set_layer_norm(
layer_spec.pre_feedforward_layer_norm, layer.pre_feedforward_layernorm
)
self.set_layer_norm(
layer_spec.post_feedforward_layer_norm, layer.post_feedforward_layernorm
)

# Set attention weights
wq = layer.self_attn.q_proj.weight
wk = layer.self_attn.k_proj.weight
wv = layer.self_attn.v_proj.weight
wo = layer.self_attn.o_proj.weight

layer_spec.self_attention.linear[0].weight = torch.cat([wq, wk, wv])
layer_spec.self_attention.linear[1].weight = wo

# Set FFN weights (GeGLU activation)
self.set_linear(layer_spec.ffn.linear_0, layer.mlp.gate_proj)
self.set_linear(layer_spec.ffn.linear_0_noact, layer.mlp.up_proj)
self.set_linear(layer_spec.ffn.linear_1, layer.mlp.down_proj)

delattr(layer, "self_attn")
delattr(layer, "mlp")
gc.collect()

def set_decoder(self, spec, decoder, decoder_config):
spec.scale_embeddings = True
# Set Gemma2-style embedding scaling
embeddings = (
spec.embeddings[0] if isinstance(spec.embeddings, list) else spec.embeddings
)
self.set_embeddings(embeddings, decoder.embed_tokens)
embeddings.multiply_by_sqrt_depth = decoder_config.hidden_size**0.5

self.set_layer_norm(spec.layer_norm, decoder.norm)

for i, (layer_spec, layer) in enumerate(zip(spec.layer, decoder.layers)):
# T5Gemma decoder follows Gemma2 pattern with pre+post layer norms
# Map T5Gemma naming to CTranslate2 spec naming:
# - pre_self_attn_layernorm -> input_layer_norm
# - post_self_attn_layernorm -> post_attention_layer_norm
# - pre_cross_attn_layernorm -> attention.layer_norm
# - post_cross_attn_layernorm -> post_cross_attention_layer_norm
# - pre_feedforward_layernorm -> pre_feedforward_layer_norm
# - post_feedforward_layernorm -> post_feedforward_layer_norm

self.set_layer_norm(
layer_spec.input_layer_norm, layer.pre_self_attn_layernorm
)
self.set_layer_norm(
layer_spec.post_attention_layer_norm, layer.post_self_attn_layernorm
)
self.set_layer_norm(
layer_spec.attention.layer_norm, layer.pre_cross_attn_layernorm
)
self.set_layer_norm(
layer_spec.post_cross_attention_layer_norm,
layer.post_cross_attn_layernorm,
)
self.set_layer_norm(
layer_spec.pre_feedforward_layer_norm, layer.pre_feedforward_layernorm
)
self.set_layer_norm(
layer_spec.post_feedforward_layer_norm, layer.post_feedforward_layernorm
)

# Set self-attention weights
wq = layer.self_attn.q_proj.weight
wk = layer.self_attn.k_proj.weight
wv = layer.self_attn.v_proj.weight
wo = layer.self_attn.o_proj.weight

layer_spec.self_attention.linear[0].weight = torch.cat([wq, wk, wv])
layer_spec.self_attention.linear[1].weight = wo

# Set cross-attention weights
wq_cross = layer.cross_attn.q_proj.weight
wk_cross = layer.cross_attn.k_proj.weight
wv_cross = layer.cross_attn.v_proj.weight
wo_cross = layer.cross_attn.o_proj.weight

layer_spec.attention.linear[0].weight = wq_cross
layer_spec.attention.linear[1].weight = torch.cat([wk_cross, wv_cross])
layer_spec.attention.linear[2].weight = wo_cross

# Set FFN weights (GeGLU activation)
self.set_linear(layer_spec.ffn.linear_0, layer.mlp.gate_proj)
self.set_linear(layer_spec.ffn.linear_0_noact, layer.mlp.up_proj)
self.set_linear(layer_spec.ffn.linear_1, layer.mlp.down_proj)

delattr(layer, "self_attn")
delattr(layer, "cross_attn")
delattr(layer, "mlp")
gc.collect()


@register_loader("BloomConfig")
class BloomLoader(ModelLoader):
@property
Expand Down
Loading
Loading