From 02fe9718461813d2b984c55ce09813ef83d4666d Mon Sep 17 00:00:00 2001 From: Sonam Pankaj Date: Fri, 3 Oct 2025 23:51:59 +0200 Subject: [PATCH 1/5] Gemma3 embeddings (cherry picked from commit 30952310620913cd1e89b2c28ad344b663b3c07f) --- rust/src/embeddings/embed/embedder.rs | 7 + rust/src/embeddings/embed/text.rs | 9 + rust/src/embeddings/local/gemma3.rs | 195 +++++++++++ rust/src/embeddings/local/mod.rs | 1 + rust/src/models/gemma3.rs | 475 ++++++++++++++++++++++++++ rust/src/models/mod.rs | 1 + 6 files changed, 688 insertions(+) create mode 100644 rust/src/embeddings/local/gemma3.rs create mode 100644 rust/src/models/gemma3.rs diff --git a/rust/src/embeddings/embed/embedder.rs b/rust/src/embeddings/embed/embedder.rs index fdcbc1e5..28e525c1 100644 --- a/rust/src/embeddings/embed/embedder.rs +++ b/rust/src/embeddings/embed/embedder.rs @@ -114,6 +114,13 @@ impl Embedder { token, dtype, )?)), + "Gemma3TextModel" => Ok(Self::Text(TextEmbedder::from_pretrained_hf( + architecture, + model_id, + revision, + token, + dtype, + )?)), _ => Err(anyhow!("Model not supported")), } } diff --git a/rust/src/embeddings/embed/text.rs b/rust/src/embeddings/embed/text.rs index 354be4f2..8779d55b 100644 --- a/rust/src/embeddings/embed/text.rs +++ b/rust/src/embeddings/embed/text.rs @@ -2,6 +2,7 @@ use crate::embeddings::cloud::cohere::CohereEmbedder; use crate::embeddings::cloud::gemini::GeminiEmbedder; use crate::embeddings::cloud::openai::OpenAIEmbedder; use crate::embeddings::local::bert::{BertEmbed, BertEmbedder, SparseBertEmbedder}; +use crate::embeddings::local::gemma3::{Gemma3Embed, Gemma3Embedder}; use crate::embeddings::local::jina::{JinaEmbed, JinaEmbedder}; use crate::embeddings::local::model2vec::Model2VecEmbedder; use crate::embeddings::local::modernbert::ModernBertEmbedder; @@ -29,6 +30,7 @@ pub enum TextEmbedder { Model2Vec(Box), Bert(Box), Qwen3(Box), + Gemma3(Box), ColBert(Box), ModernBert(Box), } @@ -48,6 +50,7 @@ impl TextEmbedder { TextEmbedder::Jina(embedder) => embedder.embed(text_batch, batch_size, late_chunking), TextEmbedder::Bert(embedder) => embedder.embed(text_batch, batch_size, late_chunking), TextEmbedder::Qwen3(embedder) => embedder.embed(text_batch, batch_size, late_chunking), + TextEmbedder::Gemma3(embedder) => embedder.embed(text_batch, batch_size, late_chunking), TextEmbedder::ColBert(embedder) => { embedder.embed(text_batch, batch_size, late_chunking) } @@ -95,6 +98,12 @@ impl TextEmbedder { token, dtype, )?))), + "Gemma3TextModel" => Ok(Self::Gemma3(Box::new(Gemma3Embedder::new( + model_id, + revision.map(|s| s.to_string()), + token, + dtype, + )?))), _ => Err(anyhow!("Model not supported")), } } diff --git a/rust/src/embeddings/local/gemma3.rs b/rust/src/embeddings/local/gemma3.rs new file mode 100644 index 00000000..1c0e7593 --- /dev/null +++ b/rust/src/embeddings/local/gemma3.rs @@ -0,0 +1,195 @@ +use crate::{ + embeddings::{embed::EmbeddingResult, normalize_l2, select_device, utils::tokenize_batch}, + models::gemma3::{Config, Model}, +}; +use anyhow::Error; +use candle_core::{DType, Device, Tensor}; +use candle_nn::VarBuilder; +use hf_hub::{api::sync::ApiBuilder, Repo}; +use tokenizers::{PaddingParams, Tokenizer, TruncationParams}; + +use super::{ + colpali::hub_load_safetensors, + pooling::{ModelOutput, PooledOutputType, Pooling}, +}; + +pub trait Gemma3Embed { + fn embed( + &self, + text_batch: &[&str], + batch_size: Option, + late_chunking: Option, + ) -> Result, anyhow::Error>; +} + +pub struct Gemma3Embedder { + pub model: std::sync::RwLock, + pub tokenizer: Tokenizer, + pub device: Device, +} + +impl Gemma3Embedder { + pub fn new( + model_id: &str, + revision: Option, + token: Option<&str>, + dtype: Option, + ) -> Result { + let api = ApiBuilder::new() + .with_token(token.map(|s| s.to_string())) + .build() + .unwrap(); + + let repo = match revision { + Some(rev) => api.repo(Repo::with_revision( + model_id.to_string(), + hf_hub::RepoType::Model, + rev, + )), + None => api.repo(hf_hub::Repo::new( + model_id.to_string(), + hf_hub::RepoType::Model, + )), + }; + let (config_filename, tokenizer_filename, weights_filename) = { + let config = repo.get("config.json")?; + let tokenizer = repo.get("tokenizer.json")?; + let weights = repo.get("model.safetensors"); + + (config, tokenizer, weights) + }; + + let config = std::fs::read_to_string(config_filename)?; + let config: Config = serde_json::from_str(&config)?; + + let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(Error::msg)?; + let pp = PaddingParams { + strategy: tokenizers::PaddingStrategy::BatchLongest, + direction: tokenizers::PaddingDirection::Left, + ..Default::default() + }; + let trunc = TruncationParams { + strategy: tokenizers::TruncationStrategy::LongestFirst, + max_length: 1024, + ..Default::default() + }; + + tokenizer + .with_padding(Some(pp)) + .with_truncation(Some(trunc)) + .map_err(Error::msg)?; + + let device = select_device(); + + let dtype = match dtype { + Some(crate::Dtype::F16) => DType::F16, + Some(crate::Dtype::F32) => DType::F32, + Some(crate::Dtype::BF16) => DType::BF16, + _ => DType::F32, + }; + + let vb = match weights_filename { + Ok(weights) => unsafe { + VarBuilder::from_mmaped_safetensors(&[weights], dtype, &device)? + }, + Err(_) => { + let weights = hub_load_safetensors(&repo, "model.safetensors.index.json")?; + unsafe { VarBuilder::from_mmaped_safetensors(&weights, dtype, &device)? } + } + }; + + let model = Model::new(false, &config, vb)?; // use_flash_attn = false by default + + Ok(Self { + model: std::sync::RwLock::new(model), + tokenizer, + device, + }) + } +} + +impl Gemma3Embed for Gemma3Embedder { + fn embed( + &self, + text_batch: &[&str], + batch_size: Option, + _late_chunking: Option, + ) -> Result, anyhow::Error> { + let batch_size = batch_size.unwrap_or(32); + let mut encodings: Vec = Vec::new(); + + for mini_text_batch in text_batch.chunks(batch_size) { + let (token_ids, attention_mask) = + tokenize_batch(&self.tokenizer, mini_text_batch, &self.device)?; + + // Forward pass through the model - Gemma3 forward takes input_ids and seqlen_offset + let embeddings: Tensor = self + .model + .write() + .unwrap() + .forward(&token_ids, 0) + .unwrap() + .to_dtype(DType::F32)?; + + self.model.write().unwrap().clear_kv_cache(); + + // Convert attention_mask to the expected format for pooling + let attention_mask = PooledOutputType::from(attention_mask); + let attention_mask = Some(&attention_mask); + let model_output = ModelOutput::Tensor(embeddings.clone()); + let pooled_output = Pooling::LastToken + .pool(&model_output, attention_mask) + .unwrap(); + let pooled_output = pooled_output.to_tensor()?; + let embeddings = normalize_l2(pooled_output)?; + let batch_encodings = embeddings.to_vec2::()?; + + encodings.extend( + batch_encodings + .iter() + .map(|x| EmbeddingResult::DenseVector(x.to_vec())), + ); + } + + Ok(encodings) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_gemma3_embed() { + // Test with a small Gemma3 model if available + // Note: You may need to adjust the model_id based on available models + let embedder = Gemma3Embedder::new( + "google/embeddinggemma-300m", // Adjust model ID as needed + None, + None, + Some(crate::Dtype::F32), + ); + + // Only run test if model is available + if let Ok(embedder) = embedder { + let embeddings = embedder + .embed( + &["Hello, world!", "I am a rust programmer now"], + Some(2), + None, + ) + .unwrap(); + + // Basic assertions - embeddings should not be empty + assert!(!embeddings.is_empty()); + assert_eq!(embeddings.len(), 2); + + // Check that embeddings have reasonable dimensions + for embedding in &embeddings { + let dense = embedding.to_dense().unwrap(); + assert!(!dense.is_empty()); + assert!(dense.len() > 100); // Gemma3 should have reasonable embedding dimensions + } + } + } +} \ No newline at end of file diff --git a/rust/src/embeddings/local/mod.rs b/rust/src/embeddings/local/mod.rs index 30e174d2..e12e490b 100644 --- a/rust/src/embeddings/local/mod.rs +++ b/rust/src/embeddings/local/mod.rs @@ -21,5 +21,6 @@ pub mod ort_bert; pub mod ort_jina; pub mod pooling; pub mod qwen3; +pub mod gemma3; pub mod text_embedding; pub mod vision_encoder; diff --git a/rust/src/models/gemma3.rs b/rust/src/models/gemma3.rs new file mode 100644 index 00000000..58f7bf4d --- /dev/null +++ b/rust/src/models/gemma3.rs @@ -0,0 +1,475 @@ +//! Gemma LLM architecture (Google) inference implementation. +//! +//! See ["Introducing Gemma 3: The most capable model you can run on a single GPU or TPU"](https://blog.google/technology/developers/gemma-3/) +//! +//! Based on implementations from HuggingFace transformers. + +use std::sync::Arc; + +use candle_core::{DType, Device, Module, Result, Tensor, D}; +use candle_nn::{linear_b as linear, Activation, Linear, VarBuilder}; + +use crate::models::qwen3::repeat_kv; + +#[derive(serde::Deserialize, Debug, Clone)] +pub struct Config { + pub attention_bias: bool, + pub head_dim: usize, + pub hidden_activation: Activation, + pub hidden_size: usize, + pub intermediate_size: usize, + pub num_attention_heads: usize, + pub num_hidden_layers: usize, + pub num_key_value_heads: usize, + pub rms_norm_eps: f64, + pub rope_theta: f64, + pub vocab_size: usize, + pub final_logit_softcapping: Option, + pub attn_logit_softcapping: Option, + pub query_pre_attn_scalar: usize, + pub sliding_window: usize, + pub _sliding_window_pattern: usize, + pub max_position_embeddings: usize, +} + +#[derive(Debug, Clone)] +struct RmsNorm { + weight: Tensor, + eps: f64, +} + +impl RmsNorm { + fn new(dim: usize, eps: f64, vb: VarBuilder) -> Result { + let weight = vb.get(dim, "weight")?; + Ok(Self { weight, eps }) + } +} + +impl Module for RmsNorm { + fn forward(&self, x: &Tensor) -> Result { + let x_dtype = x.dtype(); + let internal_dtype = match x_dtype { + DType::F16 | DType::BF16 => DType::F32, + d => d, + }; + let hidden_size = x.dim(D::Minus1)?; + let x = x.to_dtype(internal_dtype)?; + let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?; + let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?; + x_normed + .to_dtype(x_dtype)? + .broadcast_mul(&(&self.weight + 1.0)?) + } +} + +#[derive(Debug, Clone)] +struct RotaryEmbedding { + sin: Tensor, + cos: Tensor, +} + +impl RotaryEmbedding { + fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result { + let dim = cfg.head_dim; + let max_seq_len = cfg.max_position_embeddings; + let inv_freq: Vec<_> = (0..dim) + .step_by(2) + .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?; + let t = Tensor::arange(0u32, max_seq_len as u32, dev)? + .to_dtype(dtype)? + .reshape((max_seq_len, 1))?; + let freqs = t.matmul(&inv_freq)?; + Ok(Self { + sin: freqs.sin()?, + cos: freqs.cos()?, + }) + } + + fn apply_rotary_emb_qkv( + &self, + q: &Tensor, + k: &Tensor, + seqlen_offset: usize, + ) -> Result<(Tensor, Tensor)> { + let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?; + let cos = self.cos.narrow(0, seqlen_offset, seq_len)?; + let sin = self.sin.narrow(0, seqlen_offset, seq_len)?; + let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?; + let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?; + Ok((q_embed, k_embed)) + } +} + +#[derive(Debug, Clone)] +#[allow(clippy::upper_case_acronyms)] +struct MLP { + gate_proj: Linear, + up_proj: Linear, + down_proj: Linear, + act_fn: candle_nn::Activation, +} + +impl MLP { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let hidden_sz = cfg.hidden_size; + let intermediate_sz = cfg.intermediate_size; + let gate_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("gate_proj"))?; + let up_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("up_proj"))?; + let down_proj = linear(intermediate_sz, hidden_sz, false, vb.pp("down_proj"))?; + Ok(Self { + gate_proj, + up_proj, + down_proj, + act_fn: cfg.hidden_activation, + }) + } +} + +impl Module for MLP { + fn forward(&self, xs: &Tensor) -> Result { + let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?; + let rhs = xs.apply(&self.up_proj)?; + (lhs * rhs)?.apply(&self.down_proj) + } +} + +#[derive(Debug, Clone)] +enum KvCache { + Normal(candle_nn::kv_cache::KvCache), + Rotating(candle_nn::kv_cache::RotatingKvCache), +} + +#[derive(Debug, Clone)] +struct Attention { + q_proj: Linear, + k_proj: Linear, + v_proj: Linear, + o_proj: Linear, + q_norm: RmsNorm, + k_norm: RmsNorm, + num_heads: usize, + num_kv_heads: usize, + num_kv_groups: usize, + head_dim: usize, + attn_logit_softcapping: Option, + rotary_emb: Arc, + kv_cache: KvCache, + use_flash_attn: bool, +} + +impl Attention { + fn new( + rotary_emb: Arc, + use_flash_attn: bool, + is_sliding: bool, + cfg: &Config, + vb: VarBuilder, + ) -> Result { + let hidden_sz = cfg.hidden_size; + let num_heads = cfg.num_attention_heads; + let num_kv_heads = cfg.num_key_value_heads; + let num_kv_groups = num_heads / num_kv_heads; + let head_dim = cfg.head_dim; + let bias = cfg.attention_bias; + let q_proj = linear(hidden_sz, num_heads * head_dim, bias, vb.pp("q_proj"))?; + let k_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("k_proj"))?; + let v_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("v_proj"))?; + let o_proj = linear(num_heads * head_dim, hidden_sz, bias, vb.pp("o_proj"))?; + let q_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("q_norm"))?; + let k_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("k_norm"))?; + let kv_cache = if is_sliding { + KvCache::Rotating(candle_nn::kv_cache::RotatingKvCache::new( + 2, + cfg.sliding_window, + )) + } else { + KvCache::Normal(candle_nn::kv_cache::KvCache::new(2, cfg.sliding_window)) + }; + Ok(Self { + q_proj, + k_proj, + v_proj, + o_proj, + q_norm, + k_norm, + num_heads, + num_kv_heads, + num_kv_groups, + head_dim, + attn_logit_softcapping: cfg.attn_logit_softcapping, + rotary_emb, + kv_cache, + use_flash_attn, + }) + } + + fn forward( + &mut self, + xs: &Tensor, + attention_mask: Option<&Tensor>, + seqlen_offset: usize, + ) -> Result { + let (b_sz, q_len, _) = xs.dims3()?; + + let query_states = self.q_proj.forward(xs)?; + let key_states = self.k_proj.forward(xs)?; + let value_states = self.v_proj.forward(xs)?; + + let query_states = query_states + .reshape((b_sz, q_len, self.num_heads, self.head_dim))? + .transpose(1, 2)?; + let key_states = key_states + .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + let value_states = value_states + .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + let query_states = self.q_norm.forward(&query_states)?; + let key_states = self.k_norm.forward(&key_states)?; + + let (query_states, key_states) = + self.rotary_emb + .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?; + + let (key_states, value_states) = match &mut self.kv_cache { + KvCache::Normal(cache) => cache.append(&key_states, &value_states)?, + KvCache::Rotating(cache) => cache.append(&key_states, &value_states)?, + }; + + let key_states = repeat_kv(key_states, self.num_kv_groups)?.contiguous()?; + let value_states = + repeat_kv(value_states, self.num_kv_groups)?.contiguous()?; + + let attn_output = if self.use_flash_attn { + // flash-attn expects (b_sz, seq_len, nheads, head_dim) + let q = query_states.transpose(1, 2)?; + let k = key_states.transpose(1, 2)?; + let v = value_states.transpose(1, 2)?; + let scale = 1f32 / (self.head_dim as f32).sqrt(); + flash_attn(&q, &k, &v, scale, attention_mask.is_some())?.transpose(1, 2)? + } else { + let scale = 1f64 / f64::sqrt(self.head_dim as f64); + let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?; + + let attn_weights = match self.attn_logit_softcapping { + None => attn_weights, + Some(sc) => ((attn_weights / sc)?.tanh()? * sc)?, + }; + + let attn_weights = match attention_mask { + None => attn_weights, + Some(mask) => attn_weights.broadcast_add(mask)?, + }; + let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; + attn_weights.matmul(&value_states)? + }; + attn_output + .transpose(1, 2)? + .reshape((b_sz, q_len, ()))? + .apply(&self.o_proj) + } + + fn clear_kv_cache(&mut self) { + match &mut self.kv_cache { + KvCache::Normal(c) => c.reset(), + KvCache::Rotating(c) => c.reset(), + } + } +} + +#[cfg(feature = "flash-attn")] +fn flash_attn( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + causal: bool, +) -> Result { + candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal) +} + +#[cfg(not(feature = "flash-attn"))] +fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result { + unimplemented!("compile with '--features flash-attn'") +} + +#[derive(Debug, Clone)] +struct DecoderLayer { + self_attn: Attention, + mlp: MLP, + input_layernorm: RmsNorm, + pre_feedforward_layernorm: RmsNorm, + post_feedforward_layernorm: RmsNorm, + post_attention_layernorm: RmsNorm, +} + +impl DecoderLayer { + fn new( + rotary_emb: Arc, + use_flash_attn: bool, + is_sliding: bool, + cfg: &Config, + vb: VarBuilder, + ) -> Result { + let self_attn = Attention::new( + rotary_emb, + use_flash_attn, + is_sliding, + cfg, + vb.pp("self_attn"), + )?; + let mlp = MLP::new(cfg, vb.pp("mlp"))?; + let input_layernorm = + RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; + let pre_feedforward_layernorm = RmsNorm::new( + cfg.hidden_size, + cfg.rms_norm_eps, + vb.pp("pre_feedforward_layernorm"), + )?; + let post_feedforward_layernorm = RmsNorm::new( + cfg.hidden_size, + cfg.rms_norm_eps, + vb.pp("post_feedforward_layernorm"), + )?; + let post_attention_layernorm = RmsNorm::new( + cfg.hidden_size, + cfg.rms_norm_eps, + vb.pp("post_attention_layernorm"), + )?; + Ok(Self { + self_attn, + mlp, + input_layernorm, + pre_feedforward_layernorm, + post_feedforward_layernorm, + post_attention_layernorm, + }) + } + + fn forward( + &mut self, + xs: &Tensor, + attention_mask: Option<&Tensor>, + seqlen_offset: usize, + ) -> Result { + let residual = xs; + let xs = self.input_layernorm.forward(xs)?; + let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?; + let xs = xs.apply(&self.post_attention_layernorm)?; + let xs = (xs + residual)?; + let residual = &xs; + let xs = xs.apply(&self.pre_feedforward_layernorm)?; + let xs = xs.apply(&self.mlp)?; + let xs = xs.apply(&self.post_feedforward_layernorm)?; + residual + xs + } + + fn clear_kv_cache(&mut self) { + self.self_attn.clear_kv_cache() + } +} + +#[derive(Debug, Clone)] +pub struct Model { + embed_tokens: candle_nn::Embedding, + layers: Vec, + norm: RmsNorm, + device: Device, + dtype: DType, + hidden_size: usize, + sliding_window: usize, +} + +impl Model { + pub fn new(use_flash_attn: bool, cfg: &Config, vb: VarBuilder) -> Result { + let embed_tokens = + candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("embed_tokens"))?; + let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?); + let mut layers = Vec::with_capacity(cfg.num_hidden_layers); + let vb_l = vb.pp("layers"); + for layer_idx in 0..cfg.num_hidden_layers { + let is_sliding = (layer_idx + 1) % cfg._sliding_window_pattern > 0; + let layer = DecoderLayer::new( + rotary_emb.clone(), + use_flash_attn, + is_sliding, + cfg, + vb_l.pp(layer_idx), + )?; + layers.push(layer) + } + let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("norm"))?; + Ok(Self { + embed_tokens, + layers, + norm, + device: vb.device().clone(), + dtype: vb.dtype(), + hidden_size: cfg.hidden_size, + sliding_window: cfg.sliding_window, + }) + } + + fn prepare_decoder_attention_mask( + &self, + b_size: usize, + tgt_len: usize, + seqlen_offset: usize, + ) -> Result { + let mask: Vec<_> = match Some(self.sliding_window) { + None => (0..tgt_len) + .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. })) + .collect(), + Some(sliding_window) => (0..tgt_len) + .flat_map(|i| { + (0..tgt_len).map(move |j| { + if i < j || j + sliding_window < i { + f32::NEG_INFINITY + } else { + 0. + } + }) + }) + .collect(), + }; + let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?; + let mask = if seqlen_offset > 0 { + let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?; + Tensor::cat(&[&mask0, &mask], D::Minus1)? + } else { + mask + }; + mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))? + .to_dtype(self.dtype) + } + + pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result { + let (b_size, seq_len) = input_ids.dims2()?; + let attention_mask = if seq_len <= 1 { + None + } else { + let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?; + Some(mask) + }; + let xs = self.embed_tokens.forward(input_ids)?; + let mut xs = (xs * (self.hidden_size as f64).sqrt())?; + for layer in self.layers.iter_mut() { + xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)? + } + let logits = xs + .narrow(1, seq_len - 1, 1)? + .apply(&self.norm)?; + + + Ok(logits) + } + + pub fn clear_kv_cache(&mut self) { + for layer in self.layers.iter_mut() { + layer.clear_kv_cache() + } + } +} \ No newline at end of file diff --git a/rust/src/models/mod.rs b/rust/src/models/mod.rs index 695a15c0..78c3fe9f 100644 --- a/rust/src/models/mod.rs +++ b/rust/src/models/mod.rs @@ -21,6 +21,7 @@ pub mod clip; pub mod colpali; pub mod dinov2; pub mod gemma; +pub mod gemma3; pub mod idefics3; pub mod jina_bert; pub mod llama; From a59abf3494e4de5482d504cb4690502387f89d90 Mon Sep 17 00:00:00 2001 From: Sonam Pankaj Date: Sat, 4 Oct 2025 19:24:05 +0200 Subject: [PATCH 2/5] gemma3-embeddings (cherry picked from commit 056c4a69bc9ef2ae785209464444a458ea9a8fb2) --- examples/gemma3.py | 48 +++++++++++++++++++++++++++++ rust/src/embeddings/local/gemma3.rs | 2 +- 2 files changed, 49 insertions(+), 1 deletion(-) create mode 100644 examples/gemma3.py diff --git a/examples/gemma3.py b/examples/gemma3.py new file mode 100644 index 00000000..be974cbd --- /dev/null +++ b/examples/gemma3.py @@ -0,0 +1,48 @@ +import heapq +from embed_anything import EmbedData, EmbeddingModel, TextEmbedConfig, WhichModel, Dtype + +from embed_anything import Dtype, ONNXModel +import numpy as np +import os +from huggingface_hub import login + +# Add your HuggingFace token here or set it as an environment variable + + + +# model:EmbeddingModel = EmbeddingModel.from_pretrained_hf( +# WhichModel.Qwen3, model_id="Qwen/Qwen3-Embedding-0.6B", dtype=Dtype.F16 +# ) + +model:EmbeddingModel = EmbeddingModel.from_pretrained_hf( + WhichModel.Gemma3, model_id="google/embeddinggemma-300m", dtype=Dtype.F32, token="hf_key" +) + +config = TextEmbedConfig( + chunk_size=1000, + batch_size=2, + splitting_strategy="sentence", + late_chunking=False, +) + +# Embed a single file +data: list[EmbedData] = model.embed_file("test_files/attention.pdf", config=config) + + +query = "Which GPU is used for training" + +query_embedding = np.array(model.embed_query([query])[0].embedding) + +embedding_array = np.array([e.embedding for e in data]) + +similarities = np.matmul(query_embedding, embedding_array.T) + +# get top 5 similarities and its index +top_5_similarities = np.argsort(similarities)[-10:][::-1] + +# Print the top 5 similarities with sentences +for i in top_5_similarities: + print(f"Score: {similarities[i]:.2} | {data[i].text}") + print("---" * 20) + +context = "\n".join([data[i].text for i in top_5_similarities]) \ No newline at end of file diff --git a/rust/src/embeddings/local/gemma3.rs b/rust/src/embeddings/local/gemma3.rs index 1c0e7593..c93ee98b 100644 --- a/rust/src/embeddings/local/gemma3.rs +++ b/rust/src/embeddings/local/gemma3.rs @@ -137,7 +137,7 @@ impl Gemma3Embed for Gemma3Embedder { let attention_mask = PooledOutputType::from(attention_mask); let attention_mask = Some(&attention_mask); let model_output = ModelOutput::Tensor(embeddings.clone()); - let pooled_output = Pooling::LastToken + let pooled_output = Pooling::Mean .pool(&model_output, attention_mask) .unwrap(); let pooled_output = pooled_output.to_tensor()?; From 0f8dc770002d25fe0ac7bf5aae7df25c566f45d5 Mon Sep 17 00:00:00 2001 From: Akshay Ballal Date: Sun, 5 Oct 2025 20:39:16 +0200 Subject: [PATCH 3/5] Refactor gemma3 example and update token usage in BERT example - Cleaned up gemma3.py by removing commented-out code related to HuggingFace token. - Updated the token placeholder in bert.rs to reflect the correct usage of "hf_key". - Minor adjustments in gemma3.rs to remove unnecessary code for clarity. (cherry picked from commit 2b3c8da067956f3eff1b77fe05f88f5603780bfa) --- examples/gemma3.py | 7 ------- rust/src/models/gemma3.rs | 2 -- 2 files changed, 9 deletions(-) diff --git a/examples/gemma3.py b/examples/gemma3.py index be974cbd..2dcdfc8b 100644 --- a/examples/gemma3.py +++ b/examples/gemma3.py @@ -6,13 +6,6 @@ import os from huggingface_hub import login -# Add your HuggingFace token here or set it as an environment variable - - - -# model:EmbeddingModel = EmbeddingModel.from_pretrained_hf( -# WhichModel.Qwen3, model_id="Qwen/Qwen3-Embedding-0.6B", dtype=Dtype.F16 -# ) model:EmbeddingModel = EmbeddingModel.from_pretrained_hf( WhichModel.Gemma3, model_id="google/embeddinggemma-300m", dtype=Dtype.F32, token="hf_key" diff --git a/rust/src/models/gemma3.rs b/rust/src/models/gemma3.rs index 58f7bf4d..f544e721 100644 --- a/rust/src/models/gemma3.rs +++ b/rust/src/models/gemma3.rs @@ -460,10 +460,8 @@ impl Model { xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)? } let logits = xs - .narrow(1, seq_len - 1, 1)? .apply(&self.norm)?; - Ok(logits) } From ad5b32eb791d5970975899ebd0357ebd04936874 Mon Sep 17 00:00:00 2001 From: Wes Chow Date: Tue, 26 May 2026 18:37:07 -0400 Subject: [PATCH 4/5] Adapt gemma3 example to current from_pretrained_hf API The dev merge reworked EmbeddingModel.from_pretrained_hf to drop the WhichModel argument and auto-detect the architecture from config.json (Gemma3TextModel). Update the example call to the new signature. Co-Authored-By: Claude Opus 4.7 (1M context) --- examples/gemma3.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/gemma3.py b/examples/gemma3.py index 2dcdfc8b..e337c6d1 100644 --- a/examples/gemma3.py +++ b/examples/gemma3.py @@ -7,8 +7,8 @@ from huggingface_hub import login -model:EmbeddingModel = EmbeddingModel.from_pretrained_hf( - WhichModel.Gemma3, model_id="google/embeddinggemma-300m", dtype=Dtype.F32, token="hf_key" +model: EmbeddingModel = EmbeddingModel.from_pretrained_hf( + model_id="google/embeddinggemma-300m", dtype=Dtype.F32, token="hf_key" ) config = TextEmbedConfig( From 755d297dc595aff01703958c1a9c45d06240ad3b Mon Sep 17 00:00:00 2001 From: Wes Chow Date: Tue, 26 May 2026 21:04:05 -0400 Subject: [PATCH 5/5] Make Gemma3 embeddings match sentence-transformers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add 768→3072→768 dense projection head between mean-pool and L2-norm - Switch the transformer to bidirectional attention with a padding mask - Apply per-layer RoPE: local base freq for sliding layers, global for full-attention layers - Add golden test against sentence-transformers output verified to ~1e-6 --- rust/src/embeddings/local/gemma3.rs | 124 ++++++++++++++++++--------- rust/src/models/gemma3.rs | 126 ++++++++++++++++++---------- 2 files changed, 169 insertions(+), 81 deletions(-) diff --git a/rust/src/embeddings/local/gemma3.rs b/rust/src/embeddings/local/gemma3.rs index c93ee98b..917e8ac3 100644 --- a/rust/src/embeddings/local/gemma3.rs +++ b/rust/src/embeddings/local/gemma3.rs @@ -4,7 +4,7 @@ use crate::{ }; use anyhow::Error; use candle_core::{DType, Device, Tensor}; -use candle_nn::VarBuilder; +use candle_nn::{Linear, Module, VarBuilder}; use hf_hub::{api::sync::ApiBuilder, Repo}; use tokenizers::{PaddingParams, Tokenizer, TruncationParams}; @@ -26,6 +26,24 @@ pub struct Gemma3Embedder { pub model: std::sync::RwLock, pub tokenizer: Tokenizer, pub device: Device, + dense1: Linear, + dense2: Linear, +} + +/// Loads a sentence-transformers `Dense` module (a single no-bias `nn.Linear` +/// stored under the key `linear.weight`) from a safetensors file in the repo. +fn load_dense( + repo: &hf_hub::api::sync::ApiRepo, + path: &str, + device: &Device, +) -> Result { + let weights_path = repo.get(path)?; + let mut tensors = candle_core::safetensors::load(weights_path, device)?; + let weight = tensors + .remove("linear.weight") + .ok_or_else(|| anyhow::anyhow!("missing 'linear.weight' in {path}"))? + .to_dtype(DType::F32)?; + Ok(Linear::new(weight, None)) } impl Gemma3Embedder { @@ -35,10 +53,11 @@ impl Gemma3Embedder { token: Option<&str>, dtype: Option, ) -> Result { - let api = ApiBuilder::new() - .with_token(token.map(|s| s.to_string())) - .build() - .unwrap(); + let mut api_builder = ApiBuilder::from_env(); + if let Some(token) = token { + api_builder = api_builder.with_token(Some(token.to_string())); + } + let api = api_builder.build()?; let repo = match revision { Some(rev) => api.repo(Repo::with_revision( @@ -100,10 +119,19 @@ impl Gemma3Embedder { let model = Model::new(false, &config, vb)?; // use_flash_attn = false by default + // EmbeddingGemma applies a 768 -> 3072 -> 768 dense projection head + // (sentence-transformers modules 2_Dense / 3_Dense, both no-bias with + // Identity activation) between mean pooling and L2 normalization. These + // weights live outside model.safetensors and must be applied separately. + let dense1 = load_dense(&repo, "2_Dense/model.safetensors", &device)?; + let dense2 = load_dense(&repo, "3_Dense/model.safetensors", &device)?; + Ok(Self { model: std::sync::RwLock::new(model), tokenizer, device, + dense1, + dense2, }) } } @@ -121,27 +149,27 @@ impl Gemma3Embed for Gemma3Embedder { for mini_text_batch in text_batch.chunks(batch_size) { let (token_ids, attention_mask) = tokenize_batch(&self.tokenizer, mini_text_batch, &self.device)?; - - // Forward pass through the model - Gemma3 forward takes input_ids and seqlen_offset + + // Forward pass through the model. EmbeddingGemma uses bidirectional + // attention, so the padding mask must be passed in (not just used for pooling). let embeddings: Tensor = self .model .write() .unwrap() - .forward(&token_ids, 0) + .forward(&token_ids, Some(&attention_mask), 0) .unwrap() .to_dtype(DType::F32)?; self.model.write().unwrap().clear_kv_cache(); - + // Convert attention_mask to the expected format for pooling let attention_mask = PooledOutputType::from(attention_mask); let attention_mask = Some(&attention_mask); let model_output = ModelOutput::Tensor(embeddings.clone()); - let pooled_output = Pooling::Mean - .pool(&model_output, attention_mask) - .unwrap(); + let pooled_output = Pooling::Mean.pool(&model_output, attention_mask).unwrap(); let pooled_output = pooled_output.to_tensor()?; - let embeddings = normalize_l2(pooled_output)?; + let projected = self.dense2.forward(&self.dense1.forward(pooled_output)?)?; + let embeddings = normalize_l2(&projected)?; let batch_encodings = embeddings.to_vec2::()?; encodings.extend( @@ -161,35 +189,53 @@ mod tests { #[test] fn test_gemma3_embed() { - // Test with a small Gemma3 model if available - // Note: You may need to adjust the model_id based on available models let embedder = Gemma3Embedder::new( - "google/embeddinggemma-300m", // Adjust model ID as needed + "google/embeddinggemma-300m", None, None, Some(crate::Dtype::F32), ); - - // Only run test if model is available - if let Ok(embedder) = embedder { - let embeddings = embedder - .embed( - &["Hello, world!", "I am a rust programmer now"], - Some(2), - None, - ) - .unwrap(); - - // Basic assertions - embeddings should not be empty - assert!(!embeddings.is_empty()); - assert_eq!(embeddings.len(), 2); - - // Check that embeddings have reasonable dimensions - for embedding in &embeddings { - let dense = embedding.to_dense().unwrap(); - assert!(!dense.is_empty()); - assert!(dense.len() > 100); // Gemma3 should have reasonable embedding dimensions - } - } + let Ok(embedder) = embedder else { + return; + }; + let embeddings = embedder + .embed( + &["Hello, world!", "I am a rust programmer now"], + Some(2), + None, + ) + .unwrap(); + + // Exercise full pipeline and check first 6 dims. + // bidirectional transformer -> mean pool + // -> dense(768->3072) -> dense(3072->768) -> L2 norm + let test_embeddings: Vec = vec![ + -0.17819002, + 0.02147517, + 0.06739803, + -0.03160102, + 0.02198322, + -0.00981485, + ]; + let first_embeddings = embeddings[0].to_dense().unwrap()[0..6].to_vec(); + println!("{:?}", first_embeddings); + assert!(first_embeddings + .iter() + .zip(test_embeddings.iter()) + .all(|(a, b)| (a.abs() - b.abs()).abs() < 1e-6)); + let test_embeddings: Vec = vec![ + -0.19414055, + -0.01050718, + 0.02919163, + 0.0027125, + 0.037645, + 0.04710986, + ]; + let second_embeddings = embeddings[1].to_dense().unwrap()[0..6].to_vec(); + println!("{:?}", second_embeddings); + assert!(second_embeddings + .iter() + .zip(test_embeddings.iter()) + .all(|(a, b)| (a.abs() - b.abs()).abs() < 1e-6)); } -} \ No newline at end of file +} diff --git a/rust/src/models/gemma3.rs b/rust/src/models/gemma3.rs index f544e721..0ca08b2e 100644 --- a/rust/src/models/gemma3.rs +++ b/rust/src/models/gemma3.rs @@ -11,6 +11,10 @@ use candle_nn::{linear_b as linear, Activation, Linear, VarBuilder}; use crate::models::qwen3::repeat_kv; +fn default_rope_local_base_freq() -> f64 { + 10000.0 +} + #[derive(serde::Deserialize, Debug, Clone)] pub struct Config { pub attention_bias: bool, @@ -23,6 +27,8 @@ pub struct Config { pub num_key_value_heads: usize, pub rms_norm_eps: f64, pub rope_theta: f64, + #[serde(default = "default_rope_local_base_freq")] + pub rope_local_base_freq: f64, pub vocab_size: usize, pub final_logit_softcapping: Option, pub attn_logit_softcapping: Option, @@ -69,12 +75,12 @@ struct RotaryEmbedding { } impl RotaryEmbedding { - fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result { + fn new(dtype: DType, rope_theta: f64, cfg: &Config, dev: &Device) -> Result { let dim = cfg.head_dim; let max_seq_len = cfg.max_position_embeddings; let inv_freq: Vec<_> = (0..dim) .step_by(2) - .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32) + .map(|i| 1f32 / rope_theta.powf(i as f64 / dim as f64) as f32) .collect(); let inv_freq_len = inv_freq.len(); let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?; @@ -240,8 +246,7 @@ impl Attention { }; let key_states = repeat_kv(key_states, self.num_kv_groups)?.contiguous()?; - let value_states = - repeat_kv(value_states, self.num_kv_groups)?.contiguous()?; + let value_states = repeat_kv(value_states, self.num_kv_groups)?.contiguous()?; let attn_output = if self.use_flash_attn { // flash-attn expects (b_sz, seq_len, nheads, head_dim) @@ -304,6 +309,7 @@ struct DecoderLayer { pre_feedforward_layernorm: RmsNorm, post_feedforward_layernorm: RmsNorm, post_attention_layernorm: RmsNorm, + is_sliding: bool, } impl DecoderLayer { @@ -346,6 +352,7 @@ impl DecoderLayer { pre_feedforward_layernorm, post_feedforward_layernorm, post_attention_layernorm, + is_sliding, }) } @@ -387,13 +394,32 @@ impl Model { pub fn new(use_flash_attn: bool, cfg: &Config, vb: VarBuilder) -> Result { let embed_tokens = candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("embed_tokens"))?; - let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?); + // Gemma3 uses two RoPE base frequencies: a global one (rope_theta) for + // full-attention layers and a local one (rope_local_base_freq) for + // sliding-attention layers. + let global_rotary_emb = Arc::new(RotaryEmbedding::new( + vb.dtype(), + cfg.rope_theta, + cfg, + vb.device(), + )?); + let local_rotary_emb = Arc::new(RotaryEmbedding::new( + vb.dtype(), + cfg.rope_local_base_freq, + cfg, + vb.device(), + )?); let mut layers = Vec::with_capacity(cfg.num_hidden_layers); let vb_l = vb.pp("layers"); for layer_idx in 0..cfg.num_hidden_layers { let is_sliding = (layer_idx + 1) % cfg._sliding_window_pattern > 0; + let rotary_emb = if is_sliding { + local_rotary_emb.clone() + } else { + global_rotary_emb.clone() + }; let layer = DecoderLayer::new( - rotary_emb.clone(), + rotary_emb, use_flash_attn, is_sliding, cfg, @@ -413,56 +439,72 @@ impl Model { }) } - fn prepare_decoder_attention_mask( + /// Builds an additive attention bias for EmbeddingGemma's bidirectional + /// attention (`use_bidirectional_attention=true`): a token may attend to every + /// non-padding token. Sliding-attention layers are additionally restricted to a + /// symmetric band `|i - j| < window`; full-attention layers pass `window = None`. + fn bidirectional_mask( &self, + attention_mask: Option<&Tensor>, b_size: usize, tgt_len: usize, - seqlen_offset: usize, + window: Option, ) -> Result { - let mask: Vec<_> = match Some(self.sliding_window) { - None => (0..tgt_len) - .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. })) - .collect(), - Some(sliding_window) => (0..tgt_len) - .flat_map(|i| { - (0..tgt_len).map(move |j| { - if i < j || j + sliding_window < i { - f32::NEG_INFINITY - } else { - 0. - } - }) - }) + let real: Vec> = match attention_mask { + Some(am) => am + .to_vec2::()? + .into_iter() + .map(|row| row.into_iter().map(|v| v != 0).collect()) .collect(), + None => vec![vec![true; tgt_len]; b_size], }; - let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?; - let mask = if seqlen_offset > 0 { - let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?; - Tensor::cat(&[&mask0, &mask], D::Minus1)? - } else { - mask - }; - mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))? - .to_dtype(self.dtype) + let mut data = Vec::with_capacity(b_size * tgt_len * tgt_len); + for row in real.iter().take(b_size) { + for i in 0..tgt_len { + for (j, &is_real) in row.iter().enumerate() { + let out_of_band = match window { + Some(w) => ((i as i64) - (j as i64)).unsigned_abs() as usize >= w, + None => false, + }; + let masked = !is_real || out_of_band; + data.push(if masked { f32::NEG_INFINITY } else { 0.0 }); + } + } + } + Tensor::from_vec(data, (b_size, 1, tgt_len, tgt_len), &self.device)?.to_dtype(self.dtype) } - pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result { + pub fn forward( + &mut self, + input_ids: &Tensor, + attention_mask: Option<&Tensor>, + seqlen_offset: usize, + ) -> Result { let (b_size, seq_len) = input_ids.dims2()?; - let attention_mask = if seq_len <= 1 { - None + let (full_mask, sliding_mask) = if seq_len <= 1 { + (None, None) } else { - let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?; - Some(mask) + ( + Some(self.bidirectional_mask(attention_mask, b_size, seq_len, None)?), + Some(self.bidirectional_mask( + attention_mask, + b_size, + seq_len, + Some(self.sliding_window), + )?), + ) }; let xs = self.embed_tokens.forward(input_ids)?; let mut xs = (xs * (self.hidden_size as f64).sqrt())?; for layer in self.layers.iter_mut() { - xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)? + let mask = if layer.is_sliding { + sliding_mask.as_ref() + } else { + full_mask.as_ref() + }; + xs = layer.forward(&xs, mask, seqlen_offset)?; } - let logits = xs - .apply(&self.norm)?; - - Ok(logits) + xs.apply(&self.norm) } pub fn clear_kv_cache(&mut self) { @@ -470,4 +512,4 @@ impl Model { layer.clear_kv_cache() } } -} \ No newline at end of file +}