diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 538d2a44be33..6156dc836ad3 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -1131,6 +1131,8 @@ title: PE Audio - local: model_doc/pop2piano title: Pop2Piano + - local: model_doc/qwen3_asr + title: Qwen3 ASR - local: model_doc/seamless_m4t title: Seamless-M4T - local: model_doc/seamless_m4t_v2 diff --git a/docs/source/en/model_doc/qwen3_asr.md b/docs/source/en/model_doc/qwen3_asr.md new file mode 100644 index 000000000000..49a3fc86009d --- /dev/null +++ b/docs/source/en/model_doc/qwen3_asr.md @@ -0,0 +1,511 @@ + +*This model was published in HF papers on 2026-01-29 and contributed to Hugging Face Transformers on 2026-06-25.* + +# Qwen3 ASR + +
+FlashAttention +SDPA +
+ +## Overview + +Qwen3 ASR is an automatic speech recognition model from Alibaba's Qwen team that combines a Whisper-style audio encoder with a Qwen3 language model decoder for speech-to-text transcription. The model supports automatic language detection and multilingual transcription. + +A forced aligner model is also included. It can be used to timestamp a provided transcript and its audio. It uses the same audio encoder model with a classification head that predicts a word's length. This model can be used with the transcript from any ASR model (see the example below with Parakeet CTC). + +Available checkpoints: +- [bezzam/Qwen3-ASR-1.7B-hf](https://huggingface.co/bezzam/Qwen3-ASR-1.7B-hf) +- [bezzam/Qwen3-ASR-0.6B-hf](https://huggingface.co/bezzam/Qwen3-ASR-0.6B-hf) +- [bezzam/Qwen3-ForcedAligner-0.6B-hf](https://huggingface.co/bezzam/Qwen3-ForcedAligner-0.6B-hf) + +The following languages are supported: +- `Qwen3-ASR-1.7B` and `Qwen3-ASR-0.6B`: Chinese (zh), English (en), Cantonese (yue), Arabic (ar), German (de), French (fr), Spanish (es), Portuguese (pt), Indonesian (id), Italian (it), Korean (ko), Russian (ru), Thai (th), Vietnamese (vi), Japanese (ja), Turkish (tr), Hindi (hi), Malay (ms), Dutch (nl), Swedish (sv), Danish (da), Finnish (fi), Polish (pl), Czech (cs), Filipino (fil), Persian (fa), Greek (el), Hungarian (hu), Macedonian (mk), Romanian (ro). +- `Qwen3-ForcedAligner-0.6B`: Chinese (zh), English (en), Cantonese (yue), French (fr), German (de), Italian (it), Japanese (ja), Korean (ko), Portuguese (pt), Russian (ru), Spanish (es). + +See the original repository at [QwenLM/Qwen3-ASR](https://github.com/QwenLM/Qwen3-ASR) and the [report](https://huggingface.co/papers/2601.21337) for more details. + +This model was contributed by [Eric Bezzam](https://huggingface.co/bezzam) and [Muhammed Tariq](https://huggingface.co/mbtariq82). 
+ +## Usage + +### Simple transcription + +The simplest way to transcribe audio is with `apply_transcription_request`, which handles the chat template formatting for you. It is a convenience wrapper for `apply_chat_template` (see [Chat template](#chat-template) below). + +```python +from transformers import AutoProcessor, AutoModelForMultimodalLM + +model_id = "bezzam/Qwen3-ASR-1.7B-hf" +processor = AutoProcessor.from_pretrained(model_id) +model = AutoModelForMultimodalLM.from_pretrained(model_id, device_map="auto") +print(f"Model loaded on {model.device} with dtype {model.dtype}") + +inputs = processor.apply_transcription_request( + audio="https://huggingface.co/datasets/bezzam/audio_samples/resolve/main/librispeech_mr_quilter.wav", +).to(model.device, model.dtype) + +output_ids = model.generate(**inputs, max_new_tokens=256) +generated_ids = output_ids[:, inputs["input_ids"].shape[1]:] + +# Raw output includes language tag and marker +raw = processor.decode(generated_ids)[0] +print(f"Raw: {raw}") + +# Parsed output: dict with "language" and "transcription" +parsed = processor.decode(generated_ids, return_format="parsed")[0] +print(f"Parsed: {parsed}") + +# Extract only the transcription text +transcription = processor.decode(generated_ids, return_format="transcription_only")[0] +print(f"Transcription: {transcription}") + +""" +Raw: language EnglishMr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. +Parsed: {'language': 'English', 'transcription': 'Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'} +Transcription: Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. +""" +``` + +### Language hint + +You can provide a language hint to guide the model. 
+ +```python +from transformers import AutoProcessor, AutoModelForMultimodalLM + +model_id = "bezzam/Qwen3-ASR-1.7B-hf" +processor = AutoProcessor.from_pretrained(model_id) +model = AutoModelForMultimodalLM.from_pretrained(model_id, device_map="auto") + +# Without language hint (auto-detect) +inputs = processor.apply_transcription_request( + audio="https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-ASR-Repo/asr_zh.wav", +).to(model.device, model.dtype) +output_ids = model.generate(**inputs, max_new_tokens=256) +generated_ids = output_ids[:, inputs["input_ids"].shape[1]:] +print(f"Auto-detect: {processor.decode(generated_ids, return_format='transcription_only')[0]}") + +# With language hint +inputs = processor.apply_transcription_request( + audio="https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-ASR-Repo/asr_zh.wav", + language="Chinese", # or language code "zh" +).to(model.device, model.dtype) +output_ids = model.generate(**inputs, max_new_tokens=256) +generated_ids = output_ids[:, inputs["input_ids"].shape[1]:] +print(f"With hint: {processor.decode(generated_ids, return_format='transcription_only')[0]}") +``` + +### Batch inference + +Batch inference is possible by passing a list of audios and, if provided, a list of languages. 
+ +```python +from transformers import AutoProcessor, AutoModelForMultimodalLM + +model_id = "bezzam/Qwen3-ASR-1.7B-hf" +audio = [ + "https://huggingface.co/datasets/bezzam/audio_samples/resolve/main/librispeech_mr_quilter.wav", + "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-ASR-Repo/asr_zh.wav", +] + +processor = AutoProcessor.from_pretrained(model_id) +model = AutoModelForMultimodalLM.from_pretrained(model_id, device_map="auto") + +inputs = processor.apply_transcription_request( + audio, language=[None, "zh"], # language codes ("zh") and full names ("Chinese") are both accepted +).to(model.device, model.dtype) + +output_ids = model.generate(**inputs, max_new_tokens=256) +generated_ids = output_ids[:, inputs["input_ids"].shape[1]:] +transcriptions = processor.decode(generated_ids, return_format="transcription_only") + +for i, text in enumerate(transcriptions): + print(f"Audio {i + 1}: {text}") +``` + +### Chat template + +Qwen3 ASR also accepts chat template inputs. The `apply_transcription_request` usage [above](#simple-transcription) is a convenience wrapper for `apply_chat_template`. 
+ +```python +from transformers import AutoProcessor, Qwen3ASRForConditionalGeneration + +model_id = "bezzam/Qwen3-ASR-1.7B-hf" +processor = AutoProcessor.from_pretrained(model_id) +model = Qwen3ASRForConditionalGeneration.from_pretrained(model_id, device_map="auto") + +# With language hint as system message +chat_template = [ + [ + {"role": "system", "content": [{"type": "text", "text": "English"}]}, + { + "role": "user", + "content": [ + { + "type": "audio", + "path": "https://huggingface.co/datasets/bezzam/audio_samples/resolve/main/librispeech_mr_quilter.wav", + }, + ], + }, + ], + [ + { + "role": "user", + "content": [ + { + "type": "audio", + "path": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-ASR-Repo/asr_zh.wav", + }, + ], + }, + ], +] + +inputs = processor.apply_chat_template( + chat_template, tokenize=True, return_dict=True, +).to(model.device, model.dtype) + +output_ids = model.generate(**inputs, max_new_tokens=256) +generated_ids = output_ids[:, inputs["input_ids"].shape[1]:] +transcriptions = processor.decode(generated_ids, return_format="transcription_only") +for text in transcriptions: + print(text) +``` + +### Training + +Qwen3 ASR can be trained with the loss outputted by the model. + +```python +from transformers import AutoProcessor, Qwen3ASRForConditionalGeneration + +model_id = "bezzam/Qwen3-ASR-1.7B-hf" +processor = AutoProcessor.from_pretrained(model_id) +model = Qwen3ASRForConditionalGeneration.from_pretrained(model_id, device_map="auto") +model.train() + +chat_template = [ + [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Mr. 
Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.", + }, + { + "type": "audio", + "path": "https://huggingface.co/datasets/bezzam/audio_samples/resolve/main/librispeech_mr_quilter.wav", + }, + ], + } + ], +] + +inputs = processor.apply_chat_template( + chat_template, tokenize=True, return_dict=True, output_labels=True, +).to(model.device, model.dtype) + +loss = model(**inputs).loss +print("Loss:", loss.item()) +loss.backward() +``` + +### Forced alignment (word-level timestamping) + +Use `Qwen3ASRForTokenClassification` to obtain word-level timestamps from a transcript. First transcribe with the ASR model, then align with the forced aligner. + +The following languages are supported: Chinese (zh), English (en), Cantonese (yue), French (fr), German (de), Italian (it), Japanese (ja), Korean (ko), Portuguese (pt), Russian (ru), Spanish (es). + +Japanese requires the `nagisa` library, while Korean requires the `soynlp` library: +``` +pip install nagisa soynlp +``` + +#### With Qwen3 ASR + +```python +import torch +from transformers import AutoProcessor, AutoModelForMultimodalLM, AutoModelForTokenClassification + +asr_model_id = "bezzam/Qwen3-ASR-0.6B-hf" +aligner_model_id = "bezzam/Qwen3-ForcedAligner-0.6B-hf" + +asr_processor = AutoProcessor.from_pretrained(asr_model_id) +asr_model = AutoModelForMultimodalLM.from_pretrained(asr_model_id, device_map="auto") + +aligner_processor = AutoProcessor.from_pretrained(aligner_model_id) +aligner_model = AutoModelForTokenClassification.from_pretrained( + aligner_model_id, dtype=torch.bfloat16, device_map="auto" +) + +audio_url = "https://huggingface.co/datasets/bezzam/audio_samples/resolve/main/librispeech_mr_quilter.wav" + +# Step 1: Transcribe +inputs = asr_processor.apply_transcription_request(audio=audio_url) +inputs = inputs.to(asr_model.device, asr_model.dtype) +output_ids = asr_model.generate(**inputs, max_new_tokens=256) +generated_ids = output_ids[:, inputs["input_ids"].shape[1]:] 
+parsed = asr_processor.decode(generated_ids, return_format="parsed")[0] +transcript = parsed["transcription"] +language = parsed["language"] or "English" + +# Step 2: Prepare alignment inputs +aligner_inputs, word_lists = aligner_processor.prepare_forced_aligner_inputs( + audio=audio_url, transcript=transcript, language=language, +) +aligner_inputs = aligner_inputs.to(aligner_model.device, aligner_model.dtype) + +# Step 3: Run forced aligner +with torch.inference_mode(): + outputs = aligner_model(**aligner_inputs) + +# Step 4: Decode timestamps +timestamps = aligner_processor.decode_forced_alignment( + logits=outputs.logits, + input_ids=aligner_inputs["input_ids"], + word_lists=word_lists, + timestamp_token_id=aligner_model.config.timestamp_token_id, +)[0] + +for item in timestamps: + print(f"{item['text']:<20} {item['start_time']:>8.3f}s → {item['end_time']:>8.3f}s") + +""" +Word Start (s) End (s) +------------------------------------------ +Mr 0.560 0.800 +Quilter 0.800 1.280 +is 1.280 1.440 +the 1.440 1.520 +apostle 1.520 2.080 +... +""" +``` + +#### With another ASR model + +The forced aligner is model-agnostic, meaning the transcripts from any ASR system can be provided. Below is a batch inference example using [NVIDIA Parakeet CTC](https://huggingface.co/nvidia/parakeet-ctc-1.1b) for transcription. 
+ + +```python +import torch +from datasets import Audio, load_dataset +from transformers import AutoModelForCTC, AutoProcessor, AutoModelForTokenClassification + +parakeet_processor = AutoProcessor.from_pretrained("nvidia/parakeet-ctc-1.1b") +parakeet_model = AutoModelForCTC.from_pretrained( + "nvidia/parakeet-ctc-1.1b", dtype="auto", device_map="cuda", +) + +aligner_model_id = "bezzam/Qwen3-ForcedAligner-0.6B-hf" +aligner_processor = AutoProcessor.from_pretrained(aligner_model_id) +aligner_model = AutoModelForTokenClassification.from_pretrained( + aligner_model_id, dtype=torch.bfloat16, device_map="cuda", +) + +ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") +ds = ds.cast_column("audio", Audio(sampling_rate=parakeet_processor.feature_extractor.sampling_rate)) +audio_arrays = [ds[i]["audio"]["array"] for i in range(3)] +sr = ds[0]["audio"]["sampling_rate"] + +# Batch transcribe with Parakeet +inputs = parakeet_processor(audio_arrays, sampling_rate=sr, return_tensors="pt", padding=True).to( + parakeet_model.device, dtype=parakeet_model.dtype +) +with torch.inference_mode(): + outputs = parakeet_model.generate(**inputs) +transcripts = parakeet_processor.decode(outputs) + +# Batch align with Qwen3 Forced Aligner +aligner_inputs, word_lists = aligner_processor.prepare_forced_aligner_inputs( + audio=audio_arrays, transcript=transcripts, language="English", +) +aligner_inputs = aligner_inputs.to(aligner_model.device, aligner_model.dtype) + +with torch.inference_mode(): + aligner_outputs = aligner_model(**aligner_inputs) + +batch_timestamps = aligner_processor.decode_forced_alignment( + logits=aligner_outputs.logits, + input_ids=aligner_inputs["input_ids"], + word_lists=word_lists, + timestamp_token_id=aligner_model.config.timestamp_token_id, +) + +for i, (transcript, timestamps) in enumerate(zip(transcripts, batch_timestamps)): + print(f"\n[Sample {i}] {transcript}") + for item in timestamps[:5]: + print(f" {item['text']:<20} 
{item['start_time']:>8.3f}s → {item['end_time']:>8.3f}s") + if len(timestamps) > 5: + print(f" ... ({len(timestamps) - 5} more words)") +``` + +### Torch compile + +Both the ASR and forced aligner models support `torch.compile` for faster inference. The forced aligner is an especially good fit for compilation because it runs a single forward pass (no autoregressive decoding). This makes it ideal for **bulk audio timestamping**: transcribe with any ASR model, then batch-align with the compiled forced aligner for maximum throughput. + +#### Forced aligner + +On an A100, we observed a speed-up of ~1.2 for a batch size of 4 ([script](https://gist.github.com/ebezzam/3e0551708631784aeb684e0e838299f3#file-benchmark_compile_forced_alignment-py)). + +```python +import torch +from transformers import AutoProcessor, AutoModelForTokenClassification + +model_id = "bezzam/Qwen3-ForcedAligner-0.6B-hf" +num_warmup = 5 +batch_size = 4 + +processor = AutoProcessor.from_pretrained(model_id) +model = AutoModelForTokenClassification.from_pretrained(model_id, dtype=torch.bfloat16).to("cuda") + +# Prepare a batch of 4 samples +audio_url = "https://huggingface.co/datasets/bezzam/audio_samples/resolve/main/librispeech_mr_quilter.wav" +transcript = "Mr. Quilter is the apostle of the middle classes." + +aligner_inputs, word_lists = processor.prepare_forced_aligner_inputs( + audio=[audio_url] * batch_size, + transcript=[transcript] * batch_size, + language=["English"] * batch_size, +) +aligner_inputs = aligner_inputs.to("cuda", torch.bfloat16) + +# Warm-up and apply model +model = torch.compile(model) +with torch.no_grad(): + for _ in range(num_warmup): + _ = model(**aligner_inputs) +with torch.no_grad(): + _ = model(**aligner_inputs) +``` + +#### ASR model (generate) + +For autoregressive transcription, `torch.compile` accelerates the per-token forward passes inside `generate` when a `CompileConfig` object is provided. 
+ +On an A100, we observed a speed-up of ~3.8 for a batch size of 4 ([script](https://gist.github.com/ebezzam/3e0551708631784aeb684e0e838299f3#file-benchmark_compile_generate-py)). + +```python +import torch +from transformers import AutoProcessor, AutoModelForMultimodalLM, CompileConfig + +model_id = "bezzam/Qwen3-ASR-1.7B-hf" +num_warmup = 3 +max_new_tokens = 256 + +processor = AutoProcessor.from_pretrained(model_id) +model = AutoModelForMultimodalLM.from_pretrained(model_id, dtype=torch.bfloat16).to("cuda").eval() + +audio_url = "https://huggingface.co/datasets/bezzam/audio_samples/resolve/main/librispeech_mr_quilter.wav" +inputs = processor.apply_transcription_request( + audio=[audio_url] * 4, # batch of 4 +).to("cuda", torch.bfloat16) + +compile_config = CompileConfig() + +# Warmup +with torch.inference_mode(): + for _ in range(num_warmup): + _ = model.generate( + **inputs, max_new_tokens=max_new_tokens, do_sample=False, + cache_implementation="static", compile_config=compile_config, + ) +torch.cuda.synchronize() + +# Apply model +with torch.inference_mode(): + output_ids = model.generate( + **inputs, max_new_tokens=max_new_tokens, do_sample=False, + cache_implementation="static", compile_config=compile_config, + ) +generated_ids = output_ids[:, inputs["input_ids"].shape[1] :] +text_compiled = processor.decode(generated_ids, return_format="transcription_only")[0] +print(f"Output: {text_compiled}") +``` + +### Pipeline usage + +```python +from transformers import pipeline + +model_id = "bezzam/Qwen3-ASR-1.7B-hf" +pipe = pipeline("any-to-any", model=model_id, device_map="auto") + +chat_template = [ + { + "role": "user", + "content": [ + { + "type": "audio", + "path": "https://huggingface.co/datasets/bezzam/audio_samples/resolve/main/librispeech_mr_quilter.wav", + }, + ], + } +] +outputs = pipe(text=chat_template, return_full_text=False) +raw_text = outputs[0]["generated_text"] +print(f"Raw: {raw_text}") + +# Use processor helper to extract transcription 
+transcription = pipe.processor.extract_transcription(raw_text) +print(f"Transcription: {transcription}") +``` + +## Qwen3ASRConfig + +[[autodoc]] Qwen3ASRConfig + + +## Qwen3ASREncoderConfig + +[[autodoc]] Qwen3ASREncoderConfig + + +## Qwen3ASRFeatureExtractor + +[[autodoc]] Qwen3ASRFeatureExtractor + - __call__ + +## Qwen3ASRProcessor + +[[autodoc]] Qwen3ASRProcessor + - __call__ + - apply_transcription_request + - prepare_forced_aligner_inputs + - decode_forced_alignment + - decode + +## Qwen3ASREncoder + +[[autodoc]] Qwen3ASREncoder + +## Qwen3ASRModel + +[[autodoc]] Qwen3ASRModel + +## Qwen3ASRForConditionalGeneration + +[[autodoc]] Qwen3ASRForConditionalGeneration + - forward + - get_audio_features + +## Qwen3ASRForTokenClassification + +[[autodoc]] Qwen3ASRForTokenClassification + - forward diff --git a/src/transformers/audio_utils.py b/src/transformers/audio_utils.py index 04af54759735..74e0348ff30b 100644 --- a/src/transformers/audio_utils.py +++ b/src/transformers/audio_utils.py @@ -402,6 +402,30 @@ def make_list_of_audio( raise ValueError("Invalid input type. Must be a single audio or a list of audio") +def make_list_of_audio_chat_template( + audio: list[AudioInput] | AudioInput | str | list[str], +) -> AudioInput: + """ + Ensure that the output is a list of audio. Unlike `make_list_of_audio`, this function also accepts a URL string or + local path, as accepted by chat templates. + + Args: + audio (`Union[list[AudioInput], AudioInput]`): + The input audio. Can be a URL string, local path, numpy/torch array, or a list of these. + Returns: + list: A list of audio. 
+ """ + + # Handle string inputs + if isinstance(audio, str): + return [audio] + if isinstance(audio, (list, tuple)) and audio and all(isinstance(a, str) for a in audio): + return list(audio) + + # Handle numpy/torch array inputs + return make_list_of_audio(audio) + + def hertz_to_mel(freq: float | np.ndarray, mel_scale: str = "htk") -> float | np.ndarray: """ Convert frequency from hertz to mels. diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index e3776af32e53..e69a56912377 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -263,7 +263,7 @@ def __post_init__(self, **kwargs): # Our configs prev wouldn't save `id2label` for 2 labels because it is the default. In all other # cases we expect the config dict to have an `id2label` field if it's a clf model, or not otherwise if self.id2label is None: - self.num_labels = kwargs.get("num_labels", 2) + self.num_labels = kwargs.get("num_labels", self.num_labels if self.num_labels is not None else 2) else: if kwargs.get("num_labels") is not None and len(self.id2label) != kwargs.get("num_labels"): logger.warning( diff --git a/src/transformers/modeling_layers.py b/src/transformers/modeling_layers.py index d1406ce070a4..ed043bf34285 100644 --- a/src/transformers/modeling_layers.py +++ b/src/transformers/modeling_layers.py @@ -245,7 +245,11 @@ def __init__(self, config): else: classifier_dropout = 0.1 self.dropout = nn.Dropout(classifier_dropout) - self.score = nn.Linear(config.get_text_config().hidden_size, config.num_labels) + self.score = nn.Linear( + config.get_text_config().hidden_size, + config.num_labels, + bias=getattr(config, "token_classification_bias", True), + ) # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/audioflamingo3/processing_audioflamingo3.py b/src/transformers/models/audioflamingo3/processing_audioflamingo3.py index bb3a73e836c9..af8e957f9c1d 100644 --- 
a/src/transformers/models/audioflamingo3/processing_audioflamingo3.py +++ b/src/transformers/models/audioflamingo3/processing_audioflamingo3.py @@ -16,7 +16,7 @@ import numpy as np -from ...audio_utils import AudioInput, make_list_of_audio +from ...audio_utils import AudioInput, make_list_of_audio_chat_template from ...feature_extraction_utils import BatchFeature from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import TextInput @@ -200,14 +200,9 @@ def apply_transcription_request( """ - if isinstance(audio, str): - audio_items: list[str | np.ndarray] = [audio] - elif isinstance(audio, (list, tuple)) and audio and all(isinstance(el, str) for el in audio): - audio_items = list(audio) - else: - audio_items = list(make_list_of_audio(audio)) - if is_torch_available(): - audio_items = [el.detach().cpu().numpy() if isinstance(el, torch.Tensor) else el for el in audio_items] + audio_items: list[str | np.ndarray] = list(make_list_of_audio_chat_template(audio)) + if is_torch_available(): + audio_items = [el.detach().cpu().numpy() if isinstance(el, torch.Tensor) else el for el in audio_items] batch_size = len(audio_items) if batch_size == 0: diff --git a/src/transformers/models/auto/auto_mappings.py b/src/transformers/models/auto/auto_mappings.py index 5ba2485be0b8..f6b1c258c6d6 100644 --- a/src/transformers/models/auto/auto_mappings.py +++ b/src/transformers/models/auto/auto_mappings.py @@ -508,6 +508,8 @@ ("qwen3_5_moe_vision", "Qwen3_5MoeVisionConfig"), ("qwen3_5_text", "Qwen3_5TextConfig"), ("qwen3_5_vision", "Qwen3_5VisionConfig"), + ("qwen3_asr", "Qwen3ASRConfig"), + ("qwen3_asr_encoder", "Qwen3ASREncoderConfig"), ("qwen3_moe", "Qwen3MoeConfig"), ("qwen3_next", "Qwen3NextConfig"), ("qwen3_omni_moe", "Qwen3OmniMoeConfig"), @@ -850,6 +852,7 @@ ("qwen3_5_moe_vision", "qwen3_5_moe"), ("qwen3_5_text", "qwen3_5"), ("qwen3_5_vision", "qwen3_5"), + ("qwen3_asr_encoder", "qwen3_asr"), ("qwen3_omni_moe_audio_encoder", 
"qwen3_omni_moe"), ("qwen3_omni_moe_talker_code_predictor", "qwen3_omni_moe"), ("qwen3_omni_moe_talker_text", "qwen3_omni_moe"), @@ -961,6 +964,7 @@ ("pe_audio", "PeAudioFeatureExtractor"), ("phi4_multimodal", "Phi4MultimodalFeatureExtractor"), ("pop2piano", "Pop2PianoFeatureExtractor"), + ("qwen3_asr", "Qwen3ASRFeatureExtractor"), ("seamless_m4t", "SeamlessM4TFeatureExtractor"), ("speech_to_text", "Speech2TextFeatureExtractor"), ("speecht5", "SpeechT5FeatureExtractor"), @@ -1074,6 +1078,7 @@ ("qwen2_5_vl", "Qwen2_5_VLProcessor"), ("qwen2_audio", "Qwen2AudioProcessor"), ("qwen2_vl", "Qwen2VLProcessor"), + ("qwen3_asr", "Qwen3ASRProcessor"), ("qwen3_omni_moe", "Qwen3OmniMoeProcessor"), ("qwen3_vl", "Qwen3VLProcessor"), ("sam", "SamProcessor"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index f2660b89732e..b4e1159d60eb 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -411,6 +411,8 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("qwen3_5_moe_vision", "Qwen3_5MoeVisionModel"), ("qwen3_5_text", "Qwen3_5TextModel"), ("qwen3_5_vision", "Qwen3_5VisionModel"), + ("qwen3_asr", "Qwen3ASRModel"), + ("qwen3_asr_encoder", "Qwen3ASREncoder"), ("qwen3_moe", "Qwen3MoeModel"), ("qwen3_next", "Qwen3NextModel"), ("qwen3_vl", "Qwen3VLModel"), @@ -614,6 +616,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("openai-gpt", "OpenAIGPTLMHeadModel"), ("paligemma", "PaliGemmaForConditionalGeneration"), ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), + ("qwen3_asr", "Qwen3ASRForConditionalGeneration"), ("roberta", "RobertaForMaskedLM"), ("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"), ("roc_bert", "RoCBertForPreTraining"), @@ -1115,6 +1118,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("phi4_multimodal", "Phi4MultimodalForCausalLM"), ("qwen2_5_omni", 
"Qwen2_5OmniForConditionalGeneration"), ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), + ("qwen3_asr", "Qwen3ASRForConditionalGeneration"), ("qwen3_omni_moe", "Qwen3OmniMoeForConditionalGeneration"), ("vibevoice_asr", "VibeVoiceAsrForConditionalGeneration"), ("voxtral", "VoxtralForConditionalGeneration"), @@ -1269,6 +1273,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("plbart", "PLBartForConditionalGeneration"), ("prophetnet", "ProphetNetForConditionalGeneration"), ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), + ("qwen3_asr", "Qwen3ASRForConditionalGeneration"), ("seamless_m4t", "SeamlessM4TForTextToText"), ("seamless_m4t_v2", "SeamlessM4Tv2ForTextToText"), ("switch_transformers", "SwitchTransformersForConditionalGeneration"), @@ -1293,6 +1298,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("moonshine", "MoonshineForConditionalGeneration"), ("moonshine_streaming", "MoonshineStreamingForConditionalGeneration"), ("pop2piano", "Pop2PianoForConditionalGeneration"), + ("qwen3_asr", "Qwen3ASRForConditionalGeneration"), ("seamless_m4t", "SeamlessM4TForSpeechToText"), ("seamless_m4t_v2", "SeamlessM4Tv2ForSpeechToText"), ("speech-encoder-decoder", "SpeechEncoderDecoderModel"), @@ -1612,6 +1618,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("qwen2_moe", "Qwen2MoeForTokenClassification"), ("qwen3", "Qwen3ForTokenClassification"), ("qwen3_5", "Qwen3_5ForTokenClassification"), + ("qwen3_asr", "Qwen3ASRForTokenClassification"), ("qwen3_moe", "Qwen3MoeForTokenClassification"), ("qwen3_next", "Qwen3NextForTokenClassification"), ("rembert", "RemBertForTokenClassification"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 86764669b193..f5c21726abae 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -282,6 +282,7 @@ ("qwen3", "Qwen2Tokenizer" if 
is_tokenizers_available() else None), ("qwen3_5", "Qwen3_5Tokenizer" if is_tokenizers_available() else None), ("qwen3_5_moe", "Qwen3_5Tokenizer" if is_tokenizers_available() else None), + ("qwen3_asr", "Qwen2Tokenizer" if is_tokenizers_available() else None), ("qwen3_moe", "Qwen2Tokenizer" if is_tokenizers_available() else None), ("qwen3_next", "Qwen2Tokenizer" if is_tokenizers_available() else None), ("qwen3_omni_moe", "Qwen2Tokenizer" if is_tokenizers_available() else None), diff --git a/src/transformers/models/glmasr/modular_glmasr.py b/src/transformers/models/glmasr/modular_glmasr.py index ff7d40c12efd..708f9dc64c4d 100644 --- a/src/transformers/models/glmasr/modular_glmasr.py +++ b/src/transformers/models/glmasr/modular_glmasr.py @@ -17,7 +17,7 @@ import numpy as np from ...activations import ACT2FN -from ...audio_utils import AudioInput, make_list_of_audio +from ...audio_utils import AudioInput, make_list_of_audio_chat_template from ...cache_utils import Cache from ...feature_extraction_utils import BatchFeature from ...modeling_layers import GradientCheckpointingLayer @@ -112,14 +112,9 @@ def apply_transcription_request( """ - if isinstance(audio, str): - audio_items: list[str | np.ndarray] = [audio] - elif isinstance(audio, (list, tuple)) and audio and all(isinstance(el, str) for el in audio): - audio_items = list(audio) - else: - audio_items = list(make_list_of_audio(audio)) - if is_torch_available(): - audio_items = [el.detach().cpu().numpy() if isinstance(el, torch.Tensor) else el for el in audio_items] + audio_items: list[str | np.ndarray] = list(make_list_of_audio_chat_template(audio)) + if is_torch_available(): + audio_items = [el.detach().cpu().numpy() if isinstance(el, torch.Tensor) else el for el in audio_items] batch_size = len(audio_items) if batch_size == 0: diff --git a/src/transformers/models/glmasr/processing_glmasr.py b/src/transformers/models/glmasr/processing_glmasr.py index 8cd9f9a941a7..2945b059d819 100644 --- 
a/src/transformers/models/glmasr/processing_glmasr.py +++ b/src/transformers/models/glmasr/processing_glmasr.py @@ -21,7 +21,7 @@ import numpy as np -from ...audio_utils import AudioInput, make_list_of_audio +from ...audio_utils import AudioInput, make_list_of_audio_chat_template from ...feature_extraction_utils import BatchFeature from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import TextInput @@ -209,14 +209,9 @@ def apply_transcription_request( """ - if isinstance(audio, str): - audio_items: list[str | np.ndarray] = [audio] - elif isinstance(audio, (list, tuple)) and audio and all(isinstance(el, str) for el in audio): - audio_items = list(audio) - else: - audio_items = list(make_list_of_audio(audio)) - if is_torch_available(): - audio_items = [el.detach().cpu().numpy() if isinstance(el, torch.Tensor) else el for el in audio_items] + audio_items: list[str | np.ndarray] = list(make_list_of_audio_chat_template(audio)) + if is_torch_available(): + audio_items = [el.detach().cpu().numpy() if isinstance(el, torch.Tensor) else el for el in audio_items] batch_size = len(audio_items) if batch_size == 0: diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index 9b6bc9f30642..3fa477e1f3d5 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -144,10 +144,8 @@ class Qwen2_5OmniPreTrainedModel(PreTrainedModel): def _init_weights(self, module): super()._init_weights(module) if isinstance(module, SinusoidsPositionEmbedding): - log_timescale_increment = np.log(module.max_timescale) / (module.channels // 2 - 1) - inv_timescales = torch.exp(-log_timescale_increment * torch.arange(module.channels // 2).float()) - scaled_time = torch.arange(module.length)[:, np.newaxis] * inv_timescales[np.newaxis, :] - init.copy_(module.positional_embedding, 
torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)) + position_embeddings = module.compute_default_singular_positional_embedding() + init.copy_(module.positional_embedding, position_embeddings) elif isinstance(module, Qwen2_5OmniUpSample1d): filter_tensor = kaiser_sinc_filter1d(0.5 / module.ratio, 0.6 / module.ratio, module.kernel_size) init.copy_(module.filter, filter_tensor) @@ -718,14 +716,14 @@ def __init__(self, length, channels, max_timescale=10000): self.max_timescale = max_timescale if channels % 2 != 0: raise ValueError("SinusoidsPositionEmbedding needs even channels input") - log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) - inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2).float()) - scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] - self.register_buffer( - "positional_embedding", - torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1), - persistent=False, - ) + position_embedding = self.compute_default_singular_positional_embedding() + self.register_buffer("positional_embedding", position_embedding, persistent=False) + + def compute_default_singular_positional_embedding(self): + log_timescale_increment = np.log(self.max_timescale) / (self.channels // 2 - 1) + inv_timescales = torch.exp(-log_timescale_increment * torch.arange(self.channels // 2).float()) + scaled_time = torch.arange(self.length)[:, np.newaxis] * inv_timescales[np.newaxis, :] + return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) def forward(self, seqlen: int): return self.positional_embedding[:seqlen, :] diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index 7501f7372cfd..d011c427e4bf 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -760,10 +760,8 @@ class 
Qwen2_5OmniPreTrainedModel(Qwen2_5_VLPreTrainedModel): def _init_weights(self, module): PreTrainedModel._init_weights(self, module) if isinstance(module, SinusoidsPositionEmbedding): - log_timescale_increment = np.log(module.max_timescale) / (module.channels // 2 - 1) - inv_timescales = torch.exp(-log_timescale_increment * torch.arange(module.channels // 2).float()) - scaled_time = torch.arange(module.length)[:, np.newaxis] * inv_timescales[np.newaxis, :] - init.copy_(module.positional_embedding, torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)) + position_embeddings = module.compute_default_singular_positional_embedding() + init.copy_(module.positional_embedding, position_embeddings) elif isinstance(module, Qwen2_5OmniUpSample1d): filter_tensor = kaiser_sinc_filter1d(0.5 / module.ratio, 0.6 / module.ratio, module.kernel_size) init.copy_(module.filter, filter_tensor) @@ -1283,14 +1281,14 @@ def __init__(self, length, channels, max_timescale=10000): self.max_timescale = max_timescale if channels % 2 != 0: raise ValueError("SinusoidsPositionEmbedding needs even channels input") - log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) - inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2).float()) - scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] - self.register_buffer( - "positional_embedding", - torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1), - persistent=False, - ) + position_embedding = self.compute_default_singular_positional_embedding() + self.register_buffer("positional_embedding", position_embedding, persistent=False) + + def compute_default_singular_positional_embedding(self): + log_timescale_increment = np.log(self.max_timescale) / (self.channels // 2 - 1) + inv_timescales = torch.exp(-log_timescale_increment * torch.arange(self.channels // 2).float()) + scaled_time = torch.arange(self.length)[:, np.newaxis] * inv_timescales[np.newaxis, :] + 
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) def forward(self, seqlen: int): return self.positional_embedding[:seqlen, :] diff --git a/src/transformers/models/qwen3_asr/__init__.py b/src/transformers/models/qwen3_asr/__init__.py new file mode 100644 index 000000000000..19df31aaf924 --- /dev/null +++ b/src/transformers/models/qwen3_asr/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_qwen3_asr import * + from .feature_extraction_qwen3_asr import * + from .modeling_qwen3_asr import * + from .processing_qwen3_asr import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/qwen3_asr/configuration_qwen3_asr.py b/src/transformers/models/qwen3_asr/configuration_qwen3_asr.py new file mode 100644 index 000000000000..73c42030cd8d --- /dev/null +++ b/src/transformers/models/qwen3_asr/configuration_qwen3_asr.py @@ -0,0 +1,136 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen3_asr/modular_qwen3_asr.py. 
+# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen3_asr.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from huggingface_hub.dataclasses import strict + +from ...configuration_utils import PreTrainedConfig +from ...utils import auto_docstring +from ..auto import CONFIG_MAPPING, AutoConfig + + +@auto_docstring(checkpoint="bezzam/Qwen3-ASR-1.7B-hf") +@strict +class Qwen3ASREncoderConfig(PreTrainedConfig): + r""" + n_window (`int`, *optional*, defaults to 50): + Half the number of mel frames in one encoder chunk. Each chunk processed by the conv stack has + ``2 * n_window`` mel frames (1 second of audio at 16 kHz with a 10 ms hop). + output_dim (`int`, *optional*, defaults to 3584): + Dimensionality of the output. + n_window_infer (`int`, *optional*, defaults to 800): + Number of mel frames worth of audio over which each attention window spans. Must be a multiple + of ``n_window * 2`` so attention windows align with encoder chunks. + downsample_hidden_size (`int`, *optional*, defaults to 480): + Hidden size of the convolutional downsampling stack. 
+ """ + + model_type = "qwen3_asr_encoder" + attribute_map = { + "num_hidden_layers": "encoder_layers", + "hidden_size": "d_model", + "num_attention_heads": "encoder_attention_heads", + "intermediate_size": "encoder_ffn_dim", + } + + num_mel_bins: int = 128 + encoder_layers: int = 24 + encoder_attention_heads: int = 16 + encoder_ffn_dim: int = 4096 + d_model: int = 1024 + dropout: float | int = 0.0 + attention_dropout: float | int = 0.0 + activation_function: str = "gelu" + activation_dropout: float | int = 0.0 + scale_embedding: bool = False + initializer_range: float = 0.02 + + n_window: int = 50 + output_dim: int = 3584 + n_window_infer: int = 800 + downsample_hidden_size: int = 480 + max_position_embeddings: int = 13 + + +@auto_docstring(checkpoint="bezzam/Qwen3-ASR-1.7B-hf") +@strict +class Qwen3ASRConfig(PreTrainedConfig): + r""" + audio_token_id (`int`, *optional*, defaults to 151676): + The audio token id to encode the audio prompt. + timestamp_token_id (`int`, *optional*, defaults to 151705): + Token ID of the ``<timestamp>`` marker in the tokenizer vocabulary. These markers + delimit word boundaries in the forced-alignment input sequence. + token_classification_bias (`bool`, *optional*, defaults to False): + Whether the token classification head for forced alignment should have a bias term. 
+ + Example: + + ```python + >>> from transformers import Qwen3ASRForConditionalGeneration, Qwen3ASRConfig + + >>> # Initializing a Qwen3ASR style configuration + >>> configuration = Qwen3ASRConfig() + + >>> # Initializing a model from the configuration + >>> model = Qwen3ASRForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen3_asr" + sub_configs = {"audio_config": AutoConfig, "text_config": AutoConfig} + + audio_config: dict | PreTrainedConfig | None = None + text_config: dict | PreTrainedConfig | None = None + audio_token_id: int = 151676 + timestamp_token_id: int = 151705 + pad_token_id: int = 151645 + eos_token_id: list[int] | tuple[int, ...] | int = (151643, 151645) + initializer_range: float = 0.02 + tie_word_embeddings: bool = True + token_classification_bias: bool = False + + def __post_init__(self, **kwargs): + if isinstance(self.audio_config, dict): + self.audio_config["model_type"] = self.audio_config.get("model_type", "qwen3_asr_encoder") + self.audio_config = CONFIG_MAPPING[self.audio_config["model_type"]](**self.audio_config) + elif self.audio_config is None: + self.audio_config = CONFIG_MAPPING["qwen3_asr_encoder"]() + + if isinstance(self.text_config, dict): + self.text_config["model_type"] = self.text_config.get("model_type", "qwen3") + self.text_config = CONFIG_MAPPING[self.text_config["model_type"]](**self.text_config) + elif self.text_config is None: + self.text_config = CONFIG_MAPPING["qwen3"]( + hidden_size=2048, + intermediate_size=6144, + num_hidden_layers=28, + num_attention_heads=16, + num_key_value_heads=8, + head_dim=128, + max_position_embeddings=65536, + tie_word_embeddings=True, + ) + + super().__post_init__(**kwargs) + + +__all__ = ["Qwen3ASREncoderConfig", "Qwen3ASRConfig"] diff --git a/src/transformers/models/qwen3_asr/convert_qwen3_asr_to_hf.py b/src/transformers/models/qwen3_asr/convert_qwen3_asr_to_hf.py new file mode 100644 index 
000000000000..85a3d5bb15f1 --- /dev/null +++ b/src/transformers/models/qwen3_asr/convert_qwen3_asr_to_hf.py @@ -0,0 +1,390 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Convert Qwen3 ASR or Qwen3 Forced Aligner checkpoints to Hugging Face format. + +The script auto-detects the model type from the source checkpoint's config.json +(by looking for a ``classify_num`` field inside ``thinker_config``). You can +also force the type with ``--model_type asr`` or ``--model_type forced_aligner``. 
+ +Reproducible Usage +================== + +1) Convert a Qwen3 ASR model: + +``` +python src/transformers/models/qwen3_asr/convert_qwen3_asr_to_hf.py \ + --model_id Qwen/Qwen3-ASR-0.6B \ + --dst_dir qwen3-asr-hf \ + --push_to_hub /Qwen3-ASR-0.6B +``` + +2) Convert a Qwen3 Forced Aligner model: + +``` +python src/transformers/models/qwen3_asr/convert_qwen3_asr_to_hf.py \ + --model_id Qwen/Qwen3-ForcedAligner-0.6B \ + --dst_dir qwen3-forced-aligner-hf \ + --push_to_hub /Qwen3-ForcedAligner-0.6B +``` + +3) Convert from a local directory with explicit model type: + +``` +python src/transformers/models/qwen3_asr/convert_qwen3_asr_to_hf.py \ + --src_dir /path/to/local/model \ + --dst_dir output-hf \ + --model_type forced_aligner +``` +""" + +import argparse +import json +import logging +import shutil +import tempfile +from pathlib import Path +from typing import Any + +import torch +from huggingface_hub import snapshot_download +from safetensors.torch import safe_open + +from transformers import ( + AutoTokenizer, + GenerationConfig, + Qwen3ASRConfig, + Qwen3ASRFeatureExtractor, + Qwen3ASRForConditionalGeneration, + Qwen3ASRForTokenClassification, + Qwen3ASRProcessor, +) + + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + +# fmt: off +STATE_DICT_MAPPING_ASR = { + "thinker.model.": "model.language_model.", + "thinker.lm_head.": "lm_head.", + "thinker.": "model.", + "model.audio_tower.proj1.": "model.multi_modal_projector.linear_1.", + "model.audio_tower.proj2.": "model.multi_modal_projector.linear_2.", +} + +STATE_DICT_MAPPING_FORCED_ALIGNER = { + "thinker.model.": "model.language_model.", + "thinker.lm_head.": "score.", + "thinker.": "model.", + "model.audio_tower.proj1.": "model.multi_modal_projector.linear_1.", + "model.audio_tower.proj2.": "model.multi_modal_projector.linear_2.", +} +# fmt: on + + +def convert_state_dict(original_state_dict: dict[str, Any], mapping: dict[str, str]) -> dict[str, 
Any]: + """Convert checkpoint state dict to transformers format.""" + converted = {} + for k, v in original_state_dict.items(): + for old_prefix, new_prefix in mapping.items(): + k = k.replace(old_prefix, new_prefix) + converted[k] = v + return converted + + +def detect_model_type(src_root: Path) -> str: + """Auto-detect model type from the source checkpoint's config.json.""" + config_path = src_root / "config.json" + with open(config_path, "r") as f: + config = json.load(f) + + thinker = config.get("thinker_config", {}) + if "classify_num" in thinker: + logger.info("Auto-detected model type: forced_aligner (found classify_num in thinker_config)") + return "forced_aligner" + + logger.info("Auto-detected model type: asr (no classify_num in thinker_config)") + return "asr" + + +def clean_config(src_root: Path, model_type: str) -> dict: + """Load and clean up the source config for transformers compatibility.""" + config_path = src_root / "config.json" + with open(config_path, "r") as f: + model_config = json.load(f) + + config_dict = model_config.copy() + + # fmt: off + # Remove unused top-level keys + for key in ["support_languages"]: + config_dict.pop(key, None) + + # Flatten thinker_config structure + if "thinker_config" in config_dict: + thinker_config = config_dict.pop("thinker_config") + if "audio_config" in thinker_config: + config_dict["audio_config"] = thinker_config["audio_config"] + if "text_config" in thinker_config: + config_dict["text_config"] = thinker_config["text_config"] + if "audio_token_id" in thinker_config: + config_dict["audio_token_id"] = thinker_config["audio_token_id"] + if "initializer_range" in thinker_config: + config_dict["initializer_range"] = thinker_config["initializer_range"] + # Forced aligner specific + if model_type == "forced_aligner" and "classify_num" in thinker_config: + config_dict["num_labels"] = thinker_config["classify_num"] + + # Audio config: rename Whisper-style field names + if "audio_config" in config_dict: + 
audio_renames = {"encoder_attention_heads": "num_attention_heads"} + for old_name, new_name in audio_renames.items(): + if old_name in config_dict["audio_config"]: + config_dict["audio_config"][new_name] = config_dict["audio_config"].pop(old_name) + + # Also set num_key_value_heads = num_attention_heads (MHA, no GQA in the encoder) + if "num_key_value_heads" not in config_dict["audio_config"] and "num_attention_heads" in config_dict["audio_config"]: + config_dict["audio_config"]["num_key_value_heads"] = config_dict["audio_config"]["num_attention_heads"] + + # Override max_source_positions: the original checkpoint uses 1500 (inherited from Whisper/OmniMoe), + # but Qwen3ASR chunks are fixed at n_window*2=100 mel frames → 13 post-CNN positions. + config_dict["audio_config"]["max_source_positions"] = 13 + + # Audio config: strip non-standard fields + if "audio_config" in config_dict: + audio_unused = [ + "_name_or_path", "architectures", "dtype", "model_type", "use_bfloat16", "add_cross_attention", + "chunk_size_feed_forward", "cross_attention_hidden_size", "decoder_start_token_id", + "finetuning_task", "id2label", "label2id", "is_decoder", "is_encoder_decoder", + "output_attentions", "output_hidden_states", "pad_token_id", "bos_token_id", "eos_token_id", + "prefix", "problem_type", "pruned_heads", "return_dict", "sep_token_id", "task_specific_params", + "tf_legacy_loss", "tie_encoder_decoder", "tie_word_embeddings", "tokenizer_class", "torchscript", + ] + for key in audio_unused: + config_dict["audio_config"].pop(key, None) + + # Text config: strip non-standard fields + MoE fields + M-RoPE fields + if "text_config" in config_dict: + text_unused = [ + "_name_or_path", "architectures", "dtype", "model_type", "use_bfloat16", "add_cross_attention", + "chunk_size_feed_forward", "cross_attention_hidden_size", "decoder_start_token_id", + "finetuning_task", "id2label", "label2id", "is_decoder", "is_encoder_decoder", + "output_attentions", "output_hidden_states", "prefix", 
"problem_type", "pruned_heads", + "return_dict", "sep_token_id", "task_specific_params", "tf_legacy_loss", "tie_encoder_decoder", + "tokenizer_class", "torchscript", + # MoE-specific fields + "decoder_sparse_step", "moe_intermediate_size", "num_experts_per_tok", "num_experts", + "norm_topk_prob", "output_router_logits", "router_aux_loss_coef", "mlp_only_layers", + ] + for key in text_unused: + config_dict["text_config"].pop(key, None) + + # Strip M-RoPE fields from rope_scaling + rope_cfg = config_dict["text_config"].get("rope_scaling") + if isinstance(rope_cfg, dict): + for mrope_key in ["mrope_interleaved", "interleaved", "mrope_section", "type"]: + rope_cfg.pop(mrope_key, None) + # fmt: on + + return config_dict + + +# fmt: off +FORCED_ALIGNER_CHAT_TEMPLATE = ( + "{%- set ns = namespace(audio_tokens='', words=[]) -%}" + "{%- for m in messages -%}" + "{%- if m.content is not string -%}" + "{%- for c in m.content -%}" + "{%- if c.type == 'audio' or ('audio' in c) or ('audio_url' in c) -%}" + "{%- set ns.audio_tokens = ns.audio_tokens + '<|audio_start|><|audio_pad|><|audio_end|>' -%}" + "{%- endif -%}" + "{%- if c.type == 'text' and (c.text is defined) -%}" + "{%- set ns.words = ns.words + [c.text] -%}" + "{%- endif -%}" + "{%- endfor -%}" + "{%- endif -%}" + "{%- endfor -%}" + "{{- ns.audio_tokens + ns.words | join('') + '' -}}" +) +# fmt: on + + +def write_processor(src_root: Path, dst_root: Path, model_type: str): + """Write processor (shared by both ASR and Forced Aligner).""" + tokenizer = AutoTokenizer.from_pretrained(src_root) + + if model_type == "forced_aligner": + chat_template = FORCED_ALIGNER_CHAT_TEMPLATE + else: + # Load chat template from separate file if it exists + chat_template_file = src_root / "chat_template.json" + chat_template = None + if chat_template_file.exists(): + logger.info("Loading chat template from %s", chat_template_file) + with open(chat_template_file, "r", encoding="utf-8") as f: + chat_template_data = json.load(f) + 
chat_template = chat_template_data.get("chat_template") + + processor = Qwen3ASRProcessor( + feature_extractor=Qwen3ASRFeatureExtractor(), + tokenizer=tokenizer, + chat_template=chat_template, + ) + processor.save_pretrained(str(dst_root)) + logger.info("Processor saved to %s", dst_root) + return processor + + +def load_state_dict(src_root: Path) -> dict[str, torch.Tensor]: + """Load safetensors state dict from source directory.""" + state = {} + shard_files = sorted(src_root.glob("model-*.safetensors")) + single_file = src_root / "model.safetensors" + + if shard_files: + logger.info("Found %d sharded safetensor files", len(shard_files)) + safetensor_paths = shard_files + elif single_file.exists(): + safetensor_paths = [single_file] + else: + raise FileNotFoundError(f"No safetensor files found in {src_root}") + + for path in safetensor_paths: + logger.info("Loading %s", path.name) + with safe_open(path, framework="pt", device="cpu") as f: + for key in f.keys(): + state[key] = f.get_tensor(key) + + return state + + +def write_asr_model(src_root: Path, dst_root: Path): + """Convert and write a Qwen3 ASR model.""" + config_dict = clean_config(src_root, "asr") + config = Qwen3ASRConfig(**config_dict) + model = Qwen3ASRForConditionalGeneration(config).to(torch.bfloat16) + + state = load_state_dict(src_root) + state = convert_state_dict(state, STATE_DICT_MAPPING_ASR) + + load_res = model.load_state_dict(state, strict=True) + if load_res.missing_keys: + raise ValueError(f"Missing keys: {load_res.missing_keys}") + if load_res.unexpected_keys: + raise ValueError(f"Unexpected keys: {load_res.unexpected_keys}") + + model.to(torch.bfloat16) + # max_new_tokens=512 matches the default in the original Qwen3-ASR library: + # https://github.com/QwenLM/Qwen3-ASR/blob/c17a131fe028b2e428b6e80a33d30bb4fa57b8df/qwen_asr/inference/qwen3_asr.py#L153 + model.generation_config = GenerationConfig( + eos_token_id=(151643, 151645), + pad_token_id=151645, + do_sample=False, + 
max_new_tokens=512, + ) + model.save_pretrained(str(dst_root)) + logger.info("ASR model saved to %s", dst_root) + return model + + +def write_forced_aligner_model(src_root: Path, dst_root: Path): + """Convert and write a Qwen3 Forced Aligner model.""" + config_dict = clean_config(src_root, "forced_aligner") + config = Qwen3ASRConfig(**config_dict) + model = Qwen3ASRForTokenClassification(config).to(torch.bfloat16) + + state = load_state_dict(src_root) + state = convert_state_dict(state, STATE_DICT_MAPPING_FORCED_ALIGNER) + + load_res = model.load_state_dict(state, strict=True) + if load_res.missing_keys: + raise ValueError(f"Missing keys: {load_res.missing_keys}") + if load_res.unexpected_keys: + raise ValueError(f"Unexpected keys: {load_res.unexpected_keys}") + + model.to(torch.bfloat16) + model.save_pretrained(str(dst_root)) + logger.info("Forced Aligner model saved to %s", dst_root) + return model + + +def main() -> None: + ap = argparse.ArgumentParser( + description="Convert Qwen3 ASR or Qwen3 Forced Aligner checkpoints to Hugging Face format." + ) + ap.add_argument("--model_id", default=None, type=str, help="Hugging Face model ID") + ap.add_argument("--src_dir", default=None, help="Source model root directory (alternative to --model_id)") + ap.add_argument("--dst_dir", required=True, help="Destination directory for converted model") + ap.add_argument( + "--model_type", + default=None, + choices=["asr", "forced_aligner"], + help="Model type to convert. 
If not specified, auto-detected from the source config.", + ) + ap.add_argument("--push_to_hub", default=None, type=str, help="Push to Hub repo ID") + args = ap.parse_args() + + # Determine source directory + if args.model_id: + logger.info("Downloading model from Hugging Face Hub: %s", args.model_id) + src_root = Path(tempfile.mkdtemp()) + src_root = Path(snapshot_download(args.model_id, cache_dir=str(src_root))) + logger.info("Model downloaded to: %s", src_root) + elif args.src_dir: + src_root = Path(args.src_dir).resolve() + else: + raise ValueError("Either --model_id or --src_dir must be provided") + + if not src_root.is_dir(): + raise FileNotFoundError(f"Source directory not found: {src_root}") + + # Auto-detect or use provided model type + model_type = args.model_type or detect_model_type(src_root) + logger.info("Converting model type: %s", model_type) + + dst_root = Path(args.dst_dir).resolve() + if dst_root.exists(): + logger.info("Removing existing destination directory: %s", dst_root) + shutil.rmtree(dst_root) + + # Write processor (shared class, model-type-specific chat template) + processor = write_processor(src_root, dst_root, model_type) + + # Write model + if model_type == "asr": + model = write_asr_model(src_root, dst_root) + else: + model = write_forced_aligner_model(src_root, dst_root) + + # Optionally push to Hub + if args.push_to_hub: + logger.info("Pushing processor to the Hub ...") + processor.push_to_hub(args.push_to_hub) + logger.info("Pushing model to the Hub ...") + model.push_to_hub(args.push_to_hub) + + # Verify upload + logger.info("Verifying upload by loading from Hub: %s", args.push_to_hub) + _ = Qwen3ASRProcessor.from_pretrained(args.push_to_hub) + if model_type == "asr": + _ = Qwen3ASRForConditionalGeneration.from_pretrained(args.push_to_hub) + else: + _ = Qwen3ASRForTokenClassification.from_pretrained(args.push_to_hub) + logger.info("Verification successful!") + + +if __name__ == "__main__": + main() diff --git 
a/src/transformers/models/qwen3_asr/feature_extraction_qwen3_asr.py b/src/transformers/models/qwen3_asr/feature_extraction_qwen3_asr.py new file mode 100644 index 000000000000..6f22a8c8f4ff --- /dev/null +++ b/src/transformers/models/qwen3_asr/feature_extraction_qwen3_asr.py @@ -0,0 +1,239 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +from ...audio_utils import mel_filter_bank +from ...feature_extraction_sequence_utils import SequenceFeatureExtractor +from ...feature_extraction_utils import BatchFeature +from ...utils import logging +from ...utils.import_utils import is_torch_available, requires + + +if is_torch_available(): + import torch + +logger = logging.get_logger(__name__) + + +@requires(backends=("torch",)) +class Qwen3ASRFeatureExtractor(SequenceFeatureExtractor): + r""" + Constructs a Qwen3 ASR feature extractor. + + Extracts 128-bin log-mel features from raw speech, then right-pads the mel time axis to a multiple of ``2 * n_window``. + + Args: + feature_size (`int`, *optional*, defaults to 128): + Number of mel filter banks. + sampling_rate (`int`, *optional*, defaults to 16000): + Audio sampling rate in Hz. + hop_length (`int`, *optional*, defaults to 160): + Length of the overlapping windows for the STFT used to obtain the Mel Frequency coefficients. 
+ chunk_length (`int`, *optional*, defaults to 30): + Maximum audio length (in seconds) used to trim/pad when ``padding="max_length"``. + n_fft (`int`, *optional*, defaults to 400): + Size of the Fourier transform. + padding_value (`float`, *optional*, defaults to 0.0): + Padding value used to pad the raw audio. + dither (`float`, *optional*, defaults to 0.0): + If non-zero, adds Gaussian noise (`std = dither`) to each STFT frame. + return_attention_mask (`bool`, *optional*, defaults to `True`): + Whether to return the attention mask corresponding to the padded mel frames. + n_window (`int`, *optional*, defaults to 50): + Half the mel-frame chunk size used for padding. The log-mel time axis is right-padded to a + multiple of ``2 * n_window``. + min_length (`int`, *optional*, defaults to 8000): + Minimum number of samples for each audio clip. Clips shorter than this are zero-padded, matching the + original Qwen3-ASR library behaviour. + """ + + model_input_names = ["input_features"] + + def __init__( + self, + feature_size=128, + sampling_rate=16000, + hop_length=160, + chunk_length=30, + n_fft=400, + padding_value=0.0, + dither=0.0, + return_attention_mask=True, + n_window=50, + min_length=8000, + **kwargs, + ): + super().__init__( + feature_size=feature_size, + sampling_rate=sampling_rate, + padding_value=padding_value, + return_attention_mask=return_attention_mask, + **kwargs, + ) + self.n_fft = n_fft + self.hop_length = hop_length + self.min_length = min_length + self.chunk_length = chunk_length + self.n_samples = chunk_length * sampling_rate + self.nb_max_frames = self.n_samples // hop_length + self.sampling_rate = sampling_rate + self.dither = dither + self.n_window = n_window + self.mel_filters = mel_filter_bank( + num_frequency_bins=1 + n_fft // 2, + num_mel_filters=feature_size, + min_frequency=0.0, + max_frequency=8000.0, + sampling_rate=sampling_rate, + norm="slaney", + mel_scale="slaney", + ) + + def _torch_extract_fbank_features(self, waveform: 
np.ndarray, device: str = "cpu") -> np.ndarray: + """Compute log-mel spectrograms using PyTorch's (optionally GPU-accelerated) STFT.""" + waveform = torch.from_numpy(waveform).to(device, torch.float32) + window = torch.hann_window(self.n_fft, device=device) + + if self.dither != 0.0: + waveform += self.dither * torch.randn(waveform.shape, dtype=waveform.dtype, device=waveform.device) + + stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True) + magnitudes = stft[..., :-1].abs() ** 2 + + mel_filters = torch.from_numpy(self.mel_filters).to(device, torch.float32) + mel_spec = mel_filters.T @ magnitudes + + log_spec = torch.clamp(mel_spec, min=1e-10).log10() + if waveform.dim() == 2: + max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0] + log_spec = torch.maximum(log_spec, max_val - 8.0) + else: + log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + if device != "cpu": + log_spec = log_spec.detach().cpu() + return log_spec.numpy() + + def __call__( + self, + raw_speech: np.ndarray | list[float] | list[np.ndarray] | list[list[float]], + truncation: bool = False, + pad_to_multiple_of: int | None = None, + return_tensors: str | None = "pt", + return_attention_mask: bool | None = None, + padding: str | None = "max_length", + max_length: int | None = None, + sampling_rate: int | None = None, + n_window: int | None = None, + device: str | None = "cpu", + **kwargs, + ) -> BatchFeature: + r""" + Prepare log-mel features from one or several audio sequences. + + Args: + raw_speech (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`): + The sequence or batch of sequences to be padded. Mono-channel audio only. + pad_to_multiple_of (`int`, *optional*): + If set, pads the raw audio to a multiple of this value (in samples). Separate from + ``n_window``, which applies to the mel-frame axis. 
+ n_window (`int`, *optional*): + Override the instance's ``n_window`` for this call. The mel axis is padded to a multiple + of ``2 * n_window``. Set to ``0`` to skip mel-axis padding entirely. + device (`str`, *optional*, defaults to `"cpu"`): + Device used to compute the log-mel spectrogram. + """ + if sampling_rate is not None: + if sampling_rate != self.sampling_rate: + raise ValueError( + f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a" + f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input" + f" was sampled with {self.sampling_rate} and not {sampling_rate}." + ) + else: + logger.warning( + f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. " + "Failing to do so can result in silent errors that might be hard to debug." + ) + + is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1 + if is_batched_numpy and len(raw_speech.shape) > 2: + raise ValueError(f"Only mono-channel audio is supported for input to {self}") + is_batched = is_batched_numpy or ( + isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list))) + ) + + if is_batched: + raw_speech = [np.asarray([speech], dtype=np.float32).T for speech in raw_speech] + elif not is_batched and not isinstance(raw_speech, np.ndarray): + raw_speech = np.asarray(raw_speech, dtype=np.float32) + elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64): + raw_speech = raw_speech.astype(np.float32) + + # always return batch + if not is_batched: + raw_speech = [np.asarray([raw_speech]).T] + + # Zero-pad clips shorter than min_length before batching, matching the original Qwen3-ASR library: + # https://github.com/QwenLM/Qwen3-ASR/blob/c17a131fe028b2e428b6e80a33d30bb4fa57b8df/qwen_asr/inference/utils.py#L322 + # NOTE: as original, do not adjust padding/attention masks (hurts performance on AMI) 
+ if self.min_length > 0: + raw_speech = [ + np.pad(s, ((0, self.min_length - s.shape[0]), (0, 0))) if s.shape[0] < self.min_length else s + for s in raw_speech + ] + + batched_speech = BatchFeature({"input_features": raw_speech}) + + padded_inputs = self.pad( + batched_speech, + padding=padding, + max_length=max_length if max_length else self.n_samples, + truncation=truncation, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + ) + + input_features = padded_inputs["input_features"].transpose(2, 0, 1) + input_features = self._torch_extract_fbank_features(input_features[0], device) + padded_inputs["input_features"] = input_features + + # Rescale raw-sample attention mask to mel-frame resolution. + rescaled_attention_mask = padded_inputs["attention_mask"][:, :: self.hop_length] + if padded_inputs["attention_mask"].shape[1] % self.hop_length != 0: + rescaled_attention_mask = rescaled_attention_mask[:, :-1] + padded_inputs["attention_mask"] = rescaled_attention_mask + + # Right-pad the mel time axis to a multiple of `2 * n_window` (needed by `Qwen3ASREncoder`). 
+ if n_window is None: + n_window = self.n_window + multiple = n_window * 2 + if multiple and multiple > 1: + remainder = padded_inputs["input_features"].shape[-1] % multiple + pad = (multiple - remainder) if remainder else 0 + if pad: + padded_inputs["input_features"] = np.pad(padded_inputs["input_features"], [(0, 0), (0, 0), (0, pad)]) + padded_inputs["attention_mask"] = np.pad(padded_inputs["attention_mask"], [(0, 0), (0, pad)]) + + if not return_attention_mask: + padded_inputs.pop("attention_mask", None) + + if return_tensors is not None: + padded_inputs = padded_inputs.convert_to_tensors(return_tensors) + + return padded_inputs + + +__all__ = ["Qwen3ASRFeatureExtractor"] diff --git a/src/transformers/models/qwen3_asr/modeling_qwen3_asr.py b/src/transformers/models/qwen3_asr/modeling_qwen3_asr.py new file mode 100644 index 000000000000..96e28ed0cd5b --- /dev/null +++ b/src/transformers/models/qwen3_asr/modeling_qwen3_asr.py @@ -0,0 +1,708 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen3_asr/modular_qwen3_asr.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen3_asr.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@auto_docstring
class Qwen3ASRPreTrainedModel(PreTrainedModel):
    # Shared loading / initialisation settings for every Qwen3 ASR model class in this file.
    config: Qwen3ASRConfig
    base_model_prefix = "model"
    input_modalities = ("audio", "text")
    supports_gradient_checkpointing = True
    # Entries are matched against module class names. The audio encoder layer defined in this
    # file is `Qwen3ASRAudioEncoderLayer`; the former entry "Qwen3ASREncoderLayer" matched no
    # class, so encoder layers were not protected from being split.
    # NOTE: this file is generated from modular_qwen3_asr.py - mirror the change there.
    _no_split_modules = ["Qwen3ASRAudioEncoderLayer", "Qwen3DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _can_compile_fullgraph = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """Initialise `module`, then rebuild the sinusoidal position table where needed.

        `SinusoidsPositionEmbedding` registers its table as a non-persistent buffer, so the
        table is recomputed here and copied into the buffer instead of being left to the
        generic initialisation.
        """
        super()._init_weights(module)
        if isinstance(module, SinusoidsPositionEmbedding):
            position_embeddings = module.compute_default_singular_positional_embedding()
            init.copy_(module.positional_embedding, position_embeddings)
class Qwen3ASRAudioAttention(nn.Module):
    """Non-causal multi-head self-attention over a packed sequence of audio frames.

    The input is a single 2-D tensor holding all attention windows back to back; `cu_seqlens`
    gives the cumulative window boundaries, and frames only attend to frames of their own window.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.d_model
        self.num_heads = config.encoder_attention_heads
        self.dropout = config.attention_dropout
        self.head_dim = self.embed_dim // self.num_heads
        # `eager_attention_forward` reads this attribute; there is no grouped-query sharing here.
        self.num_key_value_groups = 1

        if self.embed_dim != self.head_dim * self.num_heads:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.scaling = self.head_dim**-0.5
        # The probability handed to the attention kernels is fixed at zero;
        # `config.attention_dropout` is only stored in `self.dropout` above.
        self.attention_dropout = 0.0
        self.is_decoder = False
        self.is_causal = False

        # Creation order (k, v, q, out) is kept so random initialisation is unchanged.
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.Tensor`): packed frames of shape `(total_frames, embed_dim)`.
            cu_seqlens (`torch.Tensor`): cumulative window boundaries; consecutive differences are
                the window lengths and must sum to `total_frames`.
            **kwargs: forwarded untouched to the selected attention implementation.

        Returns:
            `torch.Tensor` of shape `(total_frames, embed_dim)` after the output projection.
        """
        seq_length, _ = hidden_states.size()

        def split_heads(projection: nn.Linear) -> torch.Tensor:
            # (total_frames, embed_dim) -> (1, num_heads, total_frames, head_dim)
            projected = projection(hidden_states).reshape(seq_length, self.num_heads, -1)
            return projected.transpose(0, 1).unsqueeze(0)

        query_states = split_heads(self.q_proj)
        key_states = split_heads(self.k_proj)
        value_states = split_heads(self.v_proj)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )
        dropout_p = self.attention_dropout if self.training else 0.0
        window_lengths = cu_seqlens[1:] - cu_seqlens[:-1]

        if is_flash_attention_requested(self.config):
            # Flash attention handles all windows in one variable-length call.
            max_seqlen = window_lengths.max()
            attn_output, _ = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask=None,
                scaling=self.scaling,
                dropout=dropout_p,
                cu_seq_lens_q=cu_seqlens,
                cu_seq_lens_k=cu_seqlens,
                max_length_q=max_seqlen,
                max_length_k=max_seqlen,
                is_causal=False,
                **kwargs,
            )
        else:
            # Other backends: cut q/k/v along the time axis and attend window by window.
            sizes = window_lengths.tolist()
            per_window = zip(
                torch.split(query_states, sizes, dim=2),
                torch.split(key_states, sizes, dim=2),
                torch.split(value_states, sizes, dim=2),
            )
            window_outputs = []
            for window_q, window_k, window_v in per_window:
                window_out = attention_interface(
                    self,
                    window_q,
                    window_k,
                    window_v,
                    attention_mask=None,
                    scaling=self.scaling,
                    dropout=dropout_p,
                    is_causal=False,
                    **kwargs,
                )[0]
                window_outputs.append(window_out)
            # Windows are concatenated back along dim 1 of the backend output.
            attn_output = torch.cat(window_outputs, dim=1)

        merged_heads = attn_output.reshape(seq_length, -1).contiguous()
        return self.out_proj(merged_heads)
+ """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + cu_seqlens=cu_seqlens, + **kwargs, + ) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16: + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + return outputs + + +class SinusoidsPositionEmbedding(nn.Module): + def __init__(self, length, channels, max_timescale=10000): + super().__init__() + self.length = length + self.channels = channels + self.max_timescale = max_timescale + if channels % 2 != 0: + raise ValueError("SinusoidsPositionEmbedding needs even channels input") + position_embedding = self.compute_default_singular_positional_embedding() + self.register_buffer("positional_embedding", position_embedding, persistent=False) + + def compute_default_singular_positional_embedding(self): + log_timescale_increment = np.log(self.max_timescale) / (self.channels // 2 - 1) + inv_timescales = torch.exp(-log_timescale_increment * torch.arange(self.channels // 2).float()) + scaled_time = torch.arange(self.length)[:, np.newaxis] * inv_timescales[np.newaxis, :] + return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) + + def forward(self, seqlen: int): + return self.positional_embedding[:seqlen, :] + + +def _get_feat_extract_output_lengths(input_lengths, n_window=50): + """ + Computes the output length of the convolutional layers and the output length of the audio encoder + """ + chunk_len = n_window * 2 + input_lengths_leave = input_lengths % chunk_len + feat_lengths = 
(input_lengths_leave - 1) // 2 + 1 + return ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // chunk_len) * 13 + + +def get_audio_cu_seqlens( + chunk_lengths: torch.Tensor, + feature_lens: torch.Tensor, + n_window_infer: int, + n_window: int, + kwargs: dict | None = None, +) -> torch.Tensor: + """Compute cumulative sequence lengths for audio attention windowing, or pop `"cu_seqlens"` from `kwargs` if precomputed. + + Splits each sample's post-CNN features into inference windows and returns + cumulative boundaries for flash-attention-style sequence packing. + + Args: + chunk_lengths: `(num_chunks,)` pre-CNN chunk lengths. + feature_lens: `(batch_size,)` per-sample frame counts. + n_window_infer: inference window size (in raw frames). + n_window: half the chunk size (in raw frames). + kwargs: optional caller kwargs β€” if it contains `"cu_seqlens"` it is popped and returned. + + Returns: + `(num_windows + 1,)` int32 cumulative sequence boundaries. + """ + if kwargs is not None and (cu_seqlens := kwargs.pop("cu_seqlens", None)) is not None: + return cu_seqlens + + aftercnn_lens = _get_feat_extract_output_lengths(feature_lens, n_window) + feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths, n_window) + max_len_after_cnn = feature_lens_after_cnn.max().item() + + n_window_ratio = n_window_infer // (n_window * 2) + window_aftercnn = max_len_after_cnn * n_window_ratio + + cu_chunk_lens = [0] + for cnn_len in aftercnn_lens: + cu_chunk_lens += [window_aftercnn] * (cnn_len // window_aftercnn) + remainder = cnn_len % window_aftercnn + if remainder != 0: + cu_chunk_lens += [remainder] + + return torch.tensor(cu_chunk_lens, device=feature_lens.device).cumsum(-1, dtype=torch.int32) + + +@auto_docstring( + custom_intro=""" + The audio model for Qwen3 ASR without any head or projection on top. 
+ """ +) +class Qwen3ASREncoder(Qwen3ASRPreTrainedModel): + config: Qwen3ASREncoderConfig + main_input_name = "input_features" + input_modalities = "audio" + _no_split_modules = ["Qwen3ASREncoderLayer"] + _supports_sdpa = True + _can_record_outputs = { + "hidden_states": Qwen3ASRAudioEncoderLayer, + "attentions": Qwen3ASRAudioAttention, + } + + def __init__(self, config: Qwen3ASREncoderConfig): + super().__init__(config) + self.dropout = config.dropout + embed_dim = config.d_model + self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + self.n_window = config.n_window + self.positional_embedding = SinusoidsPositionEmbedding(config.max_position_embeddings, embed_dim) + self.layers = nn.ModuleList([Qwen3ASRAudioEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.ln_post = nn.LayerNorm(config.d_model) + self.gradient_checkpointing = False + self.conv2d1 = nn.Conv2d(1, config.downsample_hidden_size, 3, 2, padding=1) + self.conv2d2 = nn.Conv2d(config.downsample_hidden_size, config.downsample_hidden_size, 3, 2, padding=1) + self.conv2d3 = nn.Conv2d(config.downsample_hidden_size, config.downsample_hidden_size, 3, 2, padding=1) + self.conv_out = nn.Linear( + config.downsample_hidden_size * ((((config.num_mel_bins + 1) // 2 + 1) // 2 + 1) // 2), + config.d_model, + bias=False, + ) + self.n_window_infer = self.config.n_window_infer + # Initialize weights and apply final processing + self.post_init() + + def _freeze_parameters(self): + for param in self.parameters(): + param.requires_grad = False + self._requires_grad = False + + def get_input_embeddings(self) -> nn.Module: + return self.conv2d1 + + def set_input_embeddings(self, value): + self.conv2d1 = value + + @capture_outputs(tie_last_hidden_states=False) + @auto_docstring + def forward( + self, + input_features: torch.Tensor, + input_features_mask: torch.Tensor, + **kwargs, + ) -> BaseModelOutputWithPooling: + r""" + input_features_mask (`torch.LongTensor` of shape `(batch_size, 
padded_feature_length)`): + 1 for valid mel frames and 0 for padding. + """ + batch_size, num_mel_bins, padded_feature_length = input_features.shape + chunk_len = self.n_window * 2 + + if padded_feature_length % chunk_len != 0: + raise ValueError( + f"Qwen3ASREncoder expects `padded_feature_length` to be a multiple of " + f"`n_window * 2` ({chunk_len}), but got {padded_feature_length}." + ) + + num_chunks = padded_feature_length // chunk_len + + # Compute cu_seqlens for windowed attention + feature_lens = input_features_mask.sum(-1).to(torch.long) + chunk_lengths = ( + input_features_mask.view(batch_size, num_chunks, chunk_len).sum(dim=-1).reshape(-1).to(torch.long) + ) + cu_seqlens = get_audio_cu_seqlens(chunk_lengths, feature_lens, self.n_window_infer, self.n_window) + + # Chunk and process through CNN + chunked = ( + input_features.view(batch_size, num_mel_bins, num_chunks, chunk_len) + .permute(0, 2, 1, 3) + .reshape(batch_size * num_chunks, 1, num_mel_bins, chunk_len) + ) + + conv_out = F.gelu(self.conv2d1(chunked)) + conv_out = F.gelu(self.conv2d2(conv_out)) + conv_out = F.gelu(self.conv2d3(conv_out)) + total_chunks, conv_channels, freq_bins, time_steps = conv_out.size() + conv_out = self.conv_out( + conv_out.permute(0, 3, 1, 2).contiguous().view(total_chunks, time_steps, conv_channels * freq_bins) + ) + conv_out += self.positional_embedding.positional_embedding[:time_steps].to(conv_out.dtype) + + # Select only valid (non-padding) post-CNN positions into a flat packed sequence + chunk_post_cnn_lens = self._post_cnn_length( + input_features_mask.view(batch_size, num_chunks, chunk_len).sum(dim=-1).reshape(-1).to(torch.long) + ) + valid_mask = torch.arange(time_steps, device=input_features.device) < chunk_post_cnn_lens.unsqueeze(1) + valid_indices = valid_mask.flatten().nonzero().squeeze(-1) + hidden_states = torch.index_select(conv_out.reshape(-1, conv_out.shape[-1]), 0, valid_indices) + + for encoder_layer in self.layers: + layer_outputs = 
class Qwen3ASRMultiModalProjector(nn.Module):
    """Two-layer MLP that projects audio encoder states from `d_model` to `output_dim`."""

    def __init__(self, config: Qwen3ASRConfig):
        super().__init__()
        audio_config = config.audio_config
        encoder_dim = audio_config.d_model
        # Layers are created in the original order (linear_1, act, linear_2).
        self.linear_1 = nn.Linear(encoder_dim, encoder_dim)
        self.act = ACT2FN[audio_config.activation_function]
        self.linear_2 = nn.Linear(encoder_dim, audio_config.output_dim)

    def forward(self, audio_features):
        """Apply `linear_1`, the configured activation, then `linear_2` to `audio_features`."""
        return self.linear_2(self.act(self.linear_1(audio_features)))
+ """ +) +class Qwen3ASRModel(Qwen3ASRPreTrainedModel): + _tp_plan = None + _pp_plan = None + _keep_in_fp32_modules_strict = None + + def __init__(self, config): + super().__init__(config) + self.audio_tower = AutoModel.from_config(config.audio_config) + self.language_model = AutoModel.from_config(config.text_config) + self.multi_modal_projector = Qwen3ASRMultiModalProjector(config) + self.post_init() + + @can_return_tuple + @auto_docstring( + custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram)." + ) + def get_audio_features( + self, + input_features: torch.FloatTensor, + input_features_mask: torch.LongTensor, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPooling: + r""" + input_features_mask (`torch.LongTensor` of shape `(batch_size, padded_feature_length)`): + 1 for valid mel frames and 0 for padding. + """ + audio_output = self.audio_tower( + input_features=input_features, + input_features_mask=input_features_mask, + **kwargs, + ) + audio_output.pooler_output = self.multi_modal_projector(audio_output.last_hidden_state) + return audio_output + + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, audio_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. 
+ """ + if input_ids is None: + special_audio_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_audio_mask = special_audio_mask.all(-1) + else: + special_audio_mask = input_ids == self.config.audio_token_id + + n_audio_tokens = special_audio_mask.sum() + n_audio_features = audio_features.shape[0] + special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + torch_compilable_check( + inputs_embeds[special_audio_mask].numel() == audio_features.numel(), + f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {n_audio_features}", + ) + return special_audio_mask + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + input_features_mask: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | Qwen3ASRModelOutputWithPast: + r""" + input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): + Mask to avoid performing attention on padding feature indices. 
+ """ + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + audio_embeds = None + if input_features is not None and input_ids is not None: + audio_embeds = self.get_audio_features(input_features, input_features_mask, return_dict=True).pooler_output + + # replace text-audio token placeholders with audio embeddings + special_audio_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, audio_features=audio_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device)) + + outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + **kwargs, + ) + + return Qwen3ASRModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, + ) + + +@auto_docstring( + custom_intro=""" + Base class for Qwen3ASR causal language model (or autoregressive) outputs. + """ +) +@dataclass +class Qwen3ASRCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head. + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Hidden states of the audio encoder after projection. 
@auto_docstring(
    custom_intro="""
    The Qwen3ASR model which consists of an audio encoder and a language model.
    """
)
class Qwen3ASRForConditionalGeneration(Qwen3ASRPreTrainedModel, GenerationMixin):
    # Maps the LM head weight to the text embedding matrix of the wrapped language model.
    _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}

    def __init__(self, config):
        super().__init__(config)
        # Audio tower + projector + text decoder live in `Qwen3ASRModel`; only the vocabulary
        # projection is added here.
        self.model = Qwen3ASRModel(config)
        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
        self.post_init()

    def get_audio_features(self, input_features, input_features_mask, **kwargs):
        """Delegate to `Qwen3ASRModel.get_audio_features` with the same arguments."""
        return self.model.get_audio_features(input_features, input_features_mask, **kwargs)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        input_features: torch.FloatTensor | None = None,
        input_features_mask: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Qwen3ASRCausalLMOutputWithPast:
        r"""
        input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
            Mask to avoid performing attention on padding feature indices.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss.

        Example:

        ```python
        >>> from transformers import Qwen3ASRForConditionalGeneration, AutoProcessor

        >>> model_id = "bezzam/Qwen3-ASR-1.7B-hf"
        >>> processor = AutoProcessor.from_pretrained(model_id)
        >>> model = Qwen3ASRForConditionalGeneration.from_pretrained(model_id, device_map="auto")
        ```"""
        # Audio/text fusion and decoding happen in the base model; `labels` and
        # `logits_to_keep` are handled only in this method.
        outputs = self.model(
            input_ids=input_ids,
            input_features=input_features,
            input_features_mask=input_features_mask,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # An int keeps the last `logits_to_keep` positions; the default 0 gives `slice(0, None)`,
        # i.e. every position. A tensor is used directly as an index along the sequence axis.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            # The loss is computed on the (possibly sliced) logits.
            loss = self.loss_function(
                logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
            )

        return Qwen3ASRCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            audio_hidden_states=outputs.audio_hidden_states,
        )

    def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, **kwargs):
        """Build the per-step model inputs, attaching the audio tensors only when they are needed.

        `input_features` / `input_features_mask` are removed before delegating to the parent
        implementation and re-attached only on the first iteration or when caching is disabled.
        """
        input_features = kwargs.pop("input_features", None)
        input_features_mask = kwargs.pop("input_features_mask", None)

        model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)

        # NOTE(review): presumably later cached steps no longer contain audio placeholder
        # tokens, so re-encoding the audio would be wasted work - confirm.
        if is_first_iteration or not model_inputs.get("use_cache", False):
            if input_features is not None:
                model_inputs["input_features"] = input_features
            if input_features_mask is not None:
                model_inputs["input_features_mask"] = input_features_mask

        return model_inputs
+ """ +) +class Qwen3ASRForTokenClassification(GenericForTokenClassification, Qwen3ASRPreTrainedModel): + pass + + +__all__ = [ + "Qwen3ASREncoder", + "Qwen3ASRForConditionalGeneration", + "Qwen3ASRModel", + "Qwen3ASRPreTrainedModel", + "Qwen3ASRForTokenClassification", +] diff --git a/src/transformers/models/qwen3_asr/modular_qwen3_asr.py b/src/transformers/models/qwen3_asr/modular_qwen3_asr.py new file mode 100644 index 000000000000..e09ffb030fd1 --- /dev/null +++ b/src/transformers/models/qwen3_asr/modular_qwen3_asr.py @@ -0,0 +1,316 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn.functional as F +from huggingface_hub.dataclasses import strict +from torch import nn + +from ... 
# Configuration for the Qwen3 ASR audio encoder. It subclasses the Qwen3-Omni-MoE audio encoder
# config and overrides the defaults declared below.
@auto_docstring(checkpoint="bezzam/Qwen3-ASR-1.7B-hf")
@strict
class Qwen3ASREncoderConfig(Qwen3OmniMoeAudioEncoderConfig):
    r"""
    n_window (`int`, *optional*, defaults to 50):
        Half the number of mel frames in one encoder chunk. Each chunk processed by the conv stack has
        ``2 * n_window`` mel frames (1 second of audio at 16 kHz with a 10 ms hop).
    output_dim (`int`, *optional*, defaults to 3584):
        Dimensionality of the output.
    n_window_infer (`int`, *optional*, defaults to 800):
        Number of mel frames worth of audio over which each attention window spans. Must be a multiple
        of ``n_window * 2`` so attention windows align with encoder chunks.
    downsample_hidden_size (`int`, *optional*, defaults to 480):
        Hidden size of the convolutional downsampling stack.
    """

    model_type = "qwen3_asr_encoder"
    encoder_layers: int = 24
    encoder_attention_heads: int = 16
    encoder_ffn_dim: int = 4096
    d_model: int = 1024
    # Number of rows in the encoder's sinusoidal position table. Qwen3ASREncoder adds the first
    # `time_steps` rows to every chunk after the conv stack; with the documented default
    # `n_window` of 50, a 100-frame chunk shrinks 100 -> 50 -> 25 -> 13 through the three
    # stride-2 convolutions.
    max_position_embeddings: int = 13
    # NOTE(review): assigning `AttributeError()` presumably tells the modular converter to drop
    # these inherited fields from the generated config (the encoder also deletes the matching
    # module attributes) — confirm against the modular-transformers documentation.
    conv_chunksize = AttributeError()
    max_source_positions = AttributeError()
@auto_docstring(
    custom_intro="""
    The audio model for Qwen3 ASR without any head or projection on top.
    """
)
class Qwen3ASREncoder(Qwen3OmniMoeAudioEncoder):
    config: Qwen3ASREncoderConfig

    def __init__(self, config: Qwen3ASREncoderConfig):
        super().__init__(config)
        embed_dim = config.d_model
        # Replace the inherited position table with one sized by `max_position_embeddings`.
        self.positional_embedding = SinusoidsPositionEmbedding(config.max_position_embeddings, embed_dim)
        # Remove parent attributes that the forward pass below never reads.
        del self.conv_chunksize
        del self.proj1
        del self.act
        del self.proj2
        del self.max_source_positions
        del self.num_mel_bins

    @staticmethod
    def _post_cnn_length(lengths: torch.Tensor) -> torch.Tensor:
        """Length after three (k=3, s=2, p=1) convolutions; zero-length input stays zero."""
        for _ in range(3):
            lengths = torch.where(lengths > 0, (lengths - 1) // 2 + 1, torch.zeros_like(lengths))
        return lengths

    @capture_outputs(tie_last_hidden_states=False)
    @auto_docstring
    def forward(
        self,
        input_features: torch.Tensor,
        input_features_mask: torch.Tensor,
        **kwargs,
    ) -> BaseModelOutputWithPooling:
        r"""
        input_features_mask (`torch.LongTensor` of shape `(batch_size, padded_feature_length)`):
            1 for valid mel frames and 0 for padding.
        """
        batch_size, num_mel_bins, padded_feature_length = input_features.shape
        chunk_len = self.n_window * 2

        if padded_feature_length % chunk_len != 0:
            raise ValueError(
                f"Qwen3ASREncoder expects `padded_feature_length` to be a multiple of "
                f"`n_window * 2` ({chunk_len}), but got {padded_feature_length}."
            )

        num_chunks = padded_feature_length // chunk_len

        # Valid mel frames per sample and per chunk. The per-chunk counts feed both the attention
        # windows and the post-CNN packing, so they are reduced from the mask only once.
        feature_lens = input_features_mask.sum(-1).to(torch.long)
        chunk_lengths = (
            input_features_mask.view(batch_size, num_chunks, chunk_len).sum(dim=-1).reshape(-1).to(torch.long)
        )
        # Derived before `chunk_lengths` is handed to the helper below, so the result does not
        # depend on how that helper treats its argument.
        chunk_post_cnn_lens = self._post_cnn_length(chunk_lengths)

        # Compute cu_seqlens for windowed attention
        cu_seqlens = get_audio_cu_seqlens(chunk_lengths, feature_lens, self.n_window_infer, self.n_window)

        # Chunk and process through CNN: every chunk becomes its own single-channel "image"
        # of shape (num_mel_bins, chunk_len).
        chunked = (
            input_features.view(batch_size, num_mel_bins, num_chunks, chunk_len)
            .permute(0, 2, 1, 3)
            .reshape(batch_size * num_chunks, 1, num_mel_bins, chunk_len)
        )

        conv_out = F.gelu(self.conv2d1(chunked))
        conv_out = F.gelu(self.conv2d2(conv_out))
        conv_out = F.gelu(self.conv2d3(conv_out))
        total_chunks, conv_channels, freq_bins, time_steps = conv_out.size()
        # Flatten (channels, freq) into one feature axis per time step before the linear projection.
        conv_out = self.conv_out(
            conv_out.permute(0, 3, 1, 2).contiguous().view(total_chunks, time_steps, conv_channels * freq_bins)
        )
        # Positions restart at 0 inside every chunk.
        conv_out += self.positional_embedding.positional_embedding[:time_steps].to(conv_out.dtype)

        # Select only valid (non-padding) post-CNN positions into a flat packed sequence
        valid_mask = torch.arange(time_steps, device=input_features.device) < chunk_post_cnn_lens.unsqueeze(1)
        valid_indices = valid_mask.flatten().nonzero().squeeze(-1)
        hidden_states = torch.index_select(conv_out.reshape(-1, conv_out.shape[-1]), 0, valid_indices)

        for encoder_layer in self.layers:
            layer_outputs = encoder_layer(hidden_states, cu_seqlens, **kwargs)
            hidden_states = layer_outputs[0]

        hidden_states = self.ln_post(hidden_states)
        return BaseModelOutputWithPooling(last_hidden_state=hidden_states)
class Qwen3ASRModel(AudioFlamingo3Model):
    @can_return_tuple
    @auto_docstring(
        custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram)."
    )
    def get_audio_features(
        self,
        input_features: torch.FloatTensor,
        input_features_mask: torch.LongTensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        r"""
        input_features_mask (`torch.LongTensor` of shape `(batch_size, padded_feature_length)`):
            Binary mask over mel frames: 1 marks a valid frame, 0 marks padding.
        """
        encoder_outputs = self.audio_tower(
            input_features=input_features, input_features_mask=input_features_mask, **kwargs
        )
        # The projected embeddings are exposed as `pooler_output`; `last_hidden_state` keeps the
        # unprojected encoder states.
        projected = self.multi_modal_projector(encoder_outputs.last_hidden_state)
        encoder_outputs.pooler_output = projected
        return encoder_outputs
+ + Example: + + ```python + >>> from transformers import Qwen3ASRForConditionalGeneration, AutoProcessor + + >>> model_id = "bezzam/Qwen3-ASR-1.7B-hf" + >>> processor = AutoProcessor.from_pretrained(model_id) + >>> model = Qwen3ASRForConditionalGeneration.from_pretrained(model_id, device_map="auto") + ```""" + return super().forward(**super_kwargs) + + +@auto_docstring( + custom_intro=""" + The Qwen3 ASR model with a token classification head for timestamp prediction (forced alignment). + """ +) +class Qwen3ASRForTokenClassification(GenericForTokenClassification, Qwen3ASRPreTrainedModel): + pass + + +__all__ = [ + "Qwen3ASREncoderConfig", + "Qwen3ASRConfig", + "Qwen3ASREncoder", + "Qwen3ASRForConditionalGeneration", + "Qwen3ASRModel", + "Qwen3ASRPreTrainedModel", + "Qwen3ASRForTokenClassification", +] diff --git a/src/transformers/models/qwen3_asr/processing_qwen3_asr.py b/src/transformers/models/qwen3_asr/processing_qwen3_asr.py new file mode 100644 index 000000000000..2c20b3a39d77 --- /dev/null +++ b/src/transformers/models/qwen3_asr/processing_qwen3_asr.py @@ -0,0 +1,812 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unicodedata + +import numpy as np + +from ...audio_utils import AudioInput, make_list_of_audio_chat_template +from ...feature_extraction_utils import BatchFeature +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import TextInput +from ...utils import auto_docstring +from ...utils.import_utils import is_nagisa_available, is_soynlp_available + + +# fmt: off +# The ASR model was trained with these full names as system prompts. +LANGUAGE_CODE_TO_NAME = { + "ar": "Arabic", + "yue": "Cantonese", + "zh": "Chinese", + "cs": "Czech", + "da": "Danish", + "nl": "Dutch", + "en": "English", + "fil": "Filipino", + "fi": "Finnish", + "fr": "French", + "de": "German", + "el": "Greek", + "hi": "Hindi", + "hu": "Hungarian", + "id": "Indonesian", + "it": "Italian", + "ja": "Japanese", + "ko": "Korean", + "mk": "Macedonian", + "ms": "Malay", + "fa": "Persian", + "pl": "Polish", + "pt": "Portuguese", + "ro": "Romanian", + "ru": "Russian", + "es": "Spanish", + "sv": "Swedish", + "th": "Thai", + "tr": "Turkish", + "vi": "Vietnamese", +} + +# The forced aligner supports a subset of the ASR languages. +FORCED_ALIGNER_LANGUAGES = { + "Chinese", "English", "Cantonese", "French", "German", + "Italian", "Japanese", "Korean", "Portuguese", "Russian", "Spanish", +} +# fmt: on + +SUPPORTED_LANGUAGE_NAMES = set(LANGUAGE_CODE_TO_NAME.values()) + + +def resolve_language(language: str | None) -> str | None: + """Map a language code or name to the canonical full name, with validation. + + Accepts language codes (e.g. ``"zh"``, ``"en"``) or full names + (e.g. ``"Chinese"``, ``"English"``). Returns the full name. + Raises ``ValueError`` if the language is not recognized. + ``None`` passes through unchanged (auto-detect). 
+ """ + if language is None: + return None + # Try code lookup first + resolved = LANGUAGE_CODE_TO_NAME.get(language.lower()) + if resolved is not None: + return resolved + # Check if it's already a valid full name (case-insensitive) + for name in SUPPORTED_LANGUAGE_NAMES: + if language.lower() == name.lower(): + return name + raise ValueError( + f"Unsupported language: {language!r}. Use a language code " + f"(e.g. 'en', 'zh') or full name (e.g. 'English', 'Chinese'). " + f"Supported codes: {sorted(LANGUAGE_CODE_TO_NAME.keys())}. " + f"Supported names: {sorted(SUPPORTED_LANGUAGE_NAMES)}." + ) + + +def _prepare_language_inputs( + language: str | list[str] | None, batch_size: int, allow_broadcast: bool = False +) -> list[str | None]: + """Broadcast / validate a language argument to match batch_size. + + Accepts language codes (e.g. ``"zh"``, ``"en"``) or full names + (e.g. ``"Chinese"``, ``"English"``). Each value is resolved to the + canonical full language name via :func:`resolve_language`. + """ + if language is None: + return [None] * batch_size + if isinstance(language, str): + return [resolve_language(language)] * batch_size + if isinstance(language, (list, tuple)): + if allow_broadcast and len(language) == 1 and batch_size > 1: + return [resolve_language(language[0])] * batch_size + if len(language) != batch_size: + raise ValueError(f"Got {len(language)} language(s) for {batch_size} sample(s); counts must match.") + return [resolve_language(lang) for lang in language] + raise TypeError("`language` must be a string, a list of strings, or `None`.") + + +def _audio_content_item(audio_item) -> dict: + """Build a chat-template content dict for a single audio item.""" + if isinstance(audio_item, str): + return {"type": "audio", "path": audio_item} + return {"type": "audio", "audio": audio_item} + + +def _is_cjk_char(char: str) -> bool: + """ + Return True for Chinese-Japanese-Korean (CJK) ideograph characters. 
+ Original: https://github.com/QwenLM/Qwen3-ASR/blob/c17a131fe028b2e428b6e80a33d30bb4fa57b8df/qwen_asr/inference/qwen3_forced_aligner.py#L62 + """ + codepoint = ord(char) + return ( + (0x4E00 <= codepoint <= 0x9FFF) + or (0x3400 <= codepoint <= 0x4DBF) + or (0x20000 <= codepoint <= 0x2A6DF) + or (0x2A700 <= codepoint <= 0x2B73F) + or (0x2B740 <= codepoint <= 0x2B81F) + or (0x2B820 <= codepoint <= 0x2CEAF) + or (0xF900 <= codepoint <= 0xFAFF) + or (0x2F800 <= codepoint <= 0x2FA1F) + ) + + +def _is_kept_char(char: str) -> bool: + """Return True for characters kept during forced-alignment tokenisation (letters, numbers, apostrophes, CJK).""" + if char == "'": + return True + category = unicodedata.category(char) + return category.startswith("L") or category.startswith("N") or _is_cjk_char(char) + + +def _clean_tokens(raw_tokens) -> list[str]: + """Filter each raw token to kept characters, dropping empty results.""" + return [cleaned for token in raw_tokens if (cleaned := "".join(char for char in token if _is_kept_char(char)))] + + +def _parse_single_output(text: str) -> dict: + """Parse a single decoded ASR string into language + transcription like the original implementation.""" + if text is None or not str(text).strip(): + return {"language": None, "transcription": ""} + text = str(text).strip() + + if "assistant\n" in text: + text = text.split("assistant\n", 1)[-1] + + # Apply repetition fix from original implementation + text = _detect_and_fix_repetitions(text) + + marker = "" + if marker not in text: + # No tag β€” treat the whole string as plain transcription + return {"language": None, "transcription": text.strip()} + + prefix, transcription = text.split(marker, 1) + prefix = prefix.strip() + + # Empty-audio heuristic: "language None" + if prefix.lower() == "language none": + return {"language": None, "transcription": transcription.strip()} + + language = None + for line in prefix.splitlines(): + line = line.strip() + if not line: + continue + if 
line.lower().startswith("language "): + val = line[len("language ") :].strip() + if val: + language = val + else: + language = line + break # only inspect the first non-empty line, matching the original + + return {"language": language or None, "transcription": transcription.strip()} + + +def _fix_timestamps(raw: np.ndarray) -> list[int]: + """ + Ensure predicted timestamps are monotonically increasing. + + The model may predict out-of-order timestamps. This method: + 1. Finds the longest increasing subsequence (LIS) β€” these are "good" timestamps. + 2. Marks everything else as an outlier. + 3. Fills outlier blocks by snapping short blocks (\u22642) to the nearest + good neighbour, or linearly interpolating longer blocks between + the surrounding good values. + + Original: https://github.com/QwenLM/Qwen3-ASR/blob/c17a131fe028b2e428b6e80a33d30bb4fa57b8df/qwen_asr/inference/qwen3_forced_aligner.py#L147 + """ + data = raw.tolist() + num_values = len(data) + + # --- Step 1: find the longest increasing subsequence (LIS) via O(n^2) DP --- + # dp[idx] = length of the LIS ending at index idx + # parent[idx] = previous index in that LIS (-1 = start of chain) + dp = [1] * num_values + parent = [-1] * num_values + + for current in range(1, num_values): + for prev in range(current): + if data[prev] <= data[current] and dp[prev] + 1 > dp[current]: + dp[current] = dp[prev] + 1 + parent[current] = prev + + # --- Step 2: backtrack to recover LIS indices and mark them as "normal" --- + max_length = max(dp) + max_idx = dp.index(max_length) + + lis_indices = [] + idx = max_idx + while idx != -1: + lis_indices.append(idx) + idx = parent[idx] + lis_indices.reverse() + + is_normal = [False] * num_values + for idx in lis_indices: + is_normal[idx] = True + + # --- Step 3: replace outlier blocks with interpolated / snapped values --- + result = data.copy() + block_start = 0 + + while block_start < num_values: + if is_normal[block_start]: + block_start += 1 + continue + + # Scan forward to 
find the end of this contiguous outlier block + block_end = block_start + while block_end < num_values and not is_normal[block_end]: + block_end += 1 + + anomaly_count = block_end - block_start + + if anomaly_count <= 2: + # Short block: snap each position to the closer good neighbour + left_val = None + for scan in range(block_start - 1, -1, -1): + if is_normal[scan]: + left_val = result[scan] + break + + right_val = None + for scan in range(block_end, num_values): + if is_normal[scan]: + right_val = result[scan] + break + + for pos in range(block_start, block_end): + if left_val is None: + result[pos] = right_val + elif right_val is None: + result[pos] = left_val + else: + result[pos] = left_val if (pos - (block_start - 1)) <= (block_end - pos) else right_val + + else: + # Long block: linearly interpolate between the surrounding good values + left_val = None + for scan in range(block_start - 1, -1, -1): + if is_normal[scan]: + left_val = result[scan] + break + + right_val = None + for scan in range(block_end, num_values): + if is_normal[scan]: + right_val = result[scan] + break + + if left_val is not None and right_val is not None: + step = (right_val - left_val) / (anomaly_count + 1) + for pos in range(block_start, block_end): + result[pos] = left_val + step * (pos - block_start + 1) + elif left_val is not None: + for pos in range(block_start, block_end): + result[pos] = left_val + elif right_val is not None: + for pos in range(block_start, block_end): + result[pos] = right_val + + block_start = block_end + + return [int(val) for val in result] + + +def _detect_and_fix_repetitions(text, threshold=20): + """ + Original implementation uses this post-processing to remove repeated characters and patterns in the ASR output + https://github.com/QwenLM/Qwen3-ASR/blob/c17a131fe028b2e428b6e80a33d30bb4fa57b8df/qwen_asr/inference/utils.py#L432 + """ + + def fix_char_repeats(s, thresh): + res = [] + i = 0 + n = len(s) + while i < n: + count = 1 + while i + count < n and s[i 
+ count] == s[i]: + count += 1 + + if count > thresh: + res.append(s[i]) + i += count + else: + res.append(s[i : i + count]) + i += count + return "".join(res) + + def fix_pattern_repeats(s, thresh, max_len=20): + n = len(s) + min_repeat_chars = thresh * 2 + if n < min_repeat_chars: + return s + + i = 0 + result = [] + while i <= n - min_repeat_chars: + found = False + for k in range(1, max_len + 1): + if i + k * thresh > n: + break + + pattern = s[i : i + k] + valid = True + for rep in range(1, thresh): + start_idx = i + rep * k + if s[start_idx : start_idx + k] != pattern: + valid = False + break + + if valid: + total_rep = thresh + end_index = i + thresh * k + while end_index + k <= n and s[end_index : end_index + k] == pattern: + total_rep += 1 + end_index += k + result.append(pattern) + result.append(fix_pattern_repeats(s[end_index:], thresh, max_len)) + i = n + found = True + break + + if found: + break + else: + result.append(s[i]) + i += 1 + + if not found: + result.append(s[i:]) + return "".join(result) + + text_raw = text + text = fix_char_repeats(text_raw, threshold) + text = fix_pattern_repeats(text, threshold) + return text + + +class Qwen3ASRProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": True, + "padding_side": "left", + }, + "audio_kwargs": { + "sampling_rate": 16000, + "padding": True, + "truncation": False, + "return_attention_mask": True, + "n_window": 50, # should match config.n_window + }, + "common_kwargs": {"return_tensors": "pt"}, + } + + +@auto_docstring +class Qwen3ASRProcessor(ProcessorMixin): + valid_processor_kwargs = Qwen3ASRProcessorKwargs + + def __init__( + self, + feature_extractor=None, + tokenizer=None, + chat_template=None, + timestamp_segment_time: float = 80, + ): + r""" + timestamp_segment_time (`float`, *optional*): + Milliseconds per timestamp class. Defaults to 80 ms. 
+ """ + super().__init__(feature_extractor, tokenizer, chat_template=chat_template) + self.timestamp_segment_time = timestamp_segment_time + self.audio_token = self.tokenizer.audio_token + self.audio_token_id = self.tokenizer.convert_tokens_to_ids(self.audio_token) + self.audio_bos_token = self.tokenizer.audio_bos_token + self.audio_bos_token_id = self.tokenizer.convert_tokens_to_ids(self.audio_bos_token) + self.audio_eos_token = self.tokenizer.audio_eos_token + self.audio_eos_token_id = self.tokenizer.convert_tokens_to_ids(self.audio_eos_token) + + @auto_docstring + def __call__( + self, + text: TextInput | list[TextInput], + audio: AudioInput, + output_labels: bool | None = False, + **kwargs: Unpack[Qwen3ASRProcessorKwargs], + ) -> BatchFeature: + r""" + output_labels (bool, *optional*, default=False): + Whether to return labels for training. + + Returns: + [`BatchFeature`]: A dictionary with tokenized text (`input_ids`, `attention_mask`) and + audio features (`input_features`, `input_features_mask`). 
+ """ + if "return_tensors" in kwargs and kwargs["return_tensors"] != "pt": + raise ValueError(f"{self.__class__.__name__} only supports `return_tensors='pt'`.") + + if output_labels: + kwargs["return_mm_token_type_ids"] = True + model_inputs = super().__call__(audio=audio, text=text, **kwargs) + + if output_labels: + labels = model_inputs.pop("mm_token_type_ids") + for token_id in [ + self.audio_token_id, + self.tokenizer.pad_token_id, + self.audio_bos_token_id, + self.audio_eos_token_id, + ]: + labels[labels == token_id] = -100 + model_inputs["labels"] = labels + + return BatchFeature(data=model_inputs, tensor_type="pt") + + def validate_inputs( + self, + audio: AudioInput | None = None, + text: TextInput | list[TextInput] | None = None, + **kwargs: Unpack[ProcessingKwargs], + ): + super().validate_inputs(audio=audio, text=text, **kwargs) + + if text is not None and audio is not None and len(text) != len(audio): + raise ValueError(f"Got {len(text)} text but {len(audio)} audios; they must match 1:1.") + + def _get_audio_token_length(self, audio_lengths, n_window=50): + chunk_len = n_window * 2 + remainder = audio_lengths % chunk_len # mel frames in the final partial chunk + feat_lengths = (remainder - 1) // 2 + 1 # after first conv (stride 2) + per_chunk_tokens = (feat_lengths - 1) // 2 + 1 # after second conv (stride 2) + token_lengths = ( + (per_chunk_tokens - 1) // 2 + 1 + (audio_lengths // chunk_len) * 13 + ) # after third conv + full chunks + return token_lengths.cpu().numpy() + + def _process_audio(self, audio: AudioInput, **kwargs): + n_window = kwargs.get("n_window", 50) + audio_inputs = self.feature_extractor(audio, **kwargs) + audio_inputs["input_features_mask"] = audio_inputs.pop("attention_mask") + + audio_lengths = self._get_audio_token_length(audio_inputs["input_features_mask"].sum(-1), n_window) + audio_inputs["num_audio_tokens"] = audio_lengths + + audio_replacements = [self.replace_audio_token(audio_inputs, idx) for idx in range(len(audio))] + 
def replace_audio_token(self, audio_inputs: dict, audio_idx: int) -> str:
    """Return the audio placeholder for sample ``audio_idx``.

    The placeholder is ``self.audio_token`` repeated ``audio_inputs["num_audio_tokens"][audio_idx]``
    times.
    """
    repeat_count = int(audio_inputs["num_audio_tokens"][audio_idx])
    placeholder = self.audio_token * repeat_count
    return placeholder
def decode(self, *args, return_format="raw", **kwargs):
    """
    Decode token ids with the tokenizer and optionally post-process the ASR output.

    Args:
        *args: Positional arguments forwarded to the tokenizer's ``decode``.
        return_format (`str`, *optional*, defaults to `"raw"`):
            Options:

            - ``"raw"``: Return the tokenizer output unchanged.
            - ``"parsed"``: Return the result of ``parse_output`` (a dict, or list of dicts, with
              ``"language"`` and ``"transcription"`` keys).
            - ``"transcription_only"``: Return the result of ``extract_transcription`` (only the
              transcribed text).

            ``skip_special_tokens`` is forced to ``True`` for ``"parsed"`` and
            ``"transcription_only"``, overriding any value passed by the caller.
        **kwargs: Keyword arguments forwarded to the tokenizer's ``decode``.

    Raises:
        ValueError: If ``return_format`` is not one of the options above.
    """
    valid_formats = ["raw", "parsed", "transcription_only"]
    if return_format not in valid_formats:
        raise ValueError(f"return_format must be one of {valid_formats}.")

    if return_format == "raw":
        return self.tokenizer.decode(*args, **kwargs)

    # Both structured formats parse plain text, so special tokens are always stripped.
    kwargs["skip_special_tokens"] = True
    raw_text = self.tokenizer.decode(*args, **kwargs)
    if return_format == "parsed":
        return self.parse_output(raw_text)
    return self.extract_transcription(raw_text)
+ + The model outputs ``language transcribed text``. + This method returns a dict with ``"language"`` and ``"transcription"`` keys. + + Args: + text (`str` or `list[str]`): Raw decoded output(s). + + Returns: + `dict` or `list[dict]`: Parsed output(s). Each dict has keys + ``"language"`` (str or None) and ``"transcription"`` (str). + Returns the original string as the transcription if parsing fails. + """ + if isinstance(text, str): + return _parse_single_output(text) + return [_parse_single_output(raw_text) for raw_text in text] + + def extract_transcription(self, text: str | list[str]) -> str | list[str]: + """ + Extract transcription text from Qwen3 ASR raw output. + + The model outputs ``language transcribed text``. + This method extracts the text after ````. + + Args: + text (`str` or `list[str]`): Raw decoded output(s). + + Returns: + `str` or `list[str]`: Extracted transcription(s). Returns the + original string if ```` is not found. + """ + if isinstance(text, str): + return _parse_single_output(text)["transcription"] + return [_parse_single_output(raw_text)["transcription"] for raw_text in text] + + def split_words_for_alignment(self, text: str | list[str], language: str | None = None) -> list[str]: + """ + Split text into word-level tokens suitable for forced alignment. + Original: https://github.com/QwenLM/Qwen3-ASR/blob/c17a131fe028b2e428b6e80a33d30bb4fa57b8df/qwen_asr/inference/qwen3_forced_aligner.py#L101-L145 + + The tokenization strategy depends on the language: + + - **Japanese**: Uses the ``nagisa`` library for morphological analysis + (install with ``pip install nagisa``). + - **Korean**: Uses the ``soynlp`` library for tokenization + (install with ``pip install soynlp``). + - **All other languages** (including Chinese): CJK characters are emitted + individually; space-delimited scripts produce whole words. Punctuation + is dropped. + + Args: + text (`str`): Transcript text. + language (`str` or `None`, *optional*): + Language of the transcript. 
Accepts full names (e.g. ``"Japanese"``, + ``"English"``) or codes (e.g. ``"ja"``, ``"en"``). When ``None``, + falls back to the default CJK / space-based tokenizer. + + Returns: + `list[str]`: Word-level tokens. + """ + text = text.strip() + lang = language.lower() if language else "" + + if lang == "japanese": + if not is_nagisa_available(): + raise ImportError( + "Japanese forced alignment requires the `nagisa` package. Install it with: pip install nagisa" + ) + import nagisa + + return _clean_tokens(nagisa.tagging(text).words) + + if lang == "korean": + if not is_soynlp_available(): + raise ImportError( + "Korean forced alignment requires the `soynlp` package. Install it with: pip install soynlp" + ) + from soynlp.tokenizer import LTokenizer + + return _clean_tokens(LTokenizer().tokenize(text)) + + # Default: CJK characters individually, space-delimited words otherwise + tokens: list[str] = [] + char_buffer: list[str] = [] + + def flush_buffer(): + if char_buffer: + word = "".join(char_buffer) + if word: + tokens.append(word) + char_buffer.clear() + + for char in text: + if _is_cjk_char(char): + flush_buffer() + tokens.append(char) + elif char.isspace(): + flush_buffer() + elif _is_kept_char(char): + char_buffer.append(char) + flush_buffer() + return tokens + + def prepare_forced_aligner_inputs( + self, + audio: AudioInput, + transcript: str | list[str], + language: str | list[str] | None = None, + **kwargs, + ) -> tuple[BatchFeature, list[list[str]]]: + """ + Prepare inputs for the forced aligner model. + + Args: + audio (`AudioInput`): + Audio input(s). Accepts paths, URLs, numpy arrays, or a list of these. + transcript (`str` or `list[str]`): + Transcript(s) to align against the audio. + language (`str`, `list[str]`, or `None`, *optional*): + Language hint(s). Currently unused in tokenization but reserved for + language-specific tokenizers (e.g. Japanese, Korean). 
+ **kwargs: + Additional keyword arguments forwarded to + [`~Qwen3ASRProcessor.apply_chat_template`]. + + Returns: + `tuple[BatchFeature, list[list[str]]]`: + - ``inputs``: A [`BatchFeature`] with ``input_ids``, ``attention_mask``, + ``input_features``, and ``input_features_mask`` ready for the forced + aligner model. + - ``word_lists``: A list (one per sample) of word-level token lists used + to build the input. Pass these to + [`~Qwen3ASRProcessor.decode_forced_alignment`] to pair timestamps + with words. + """ + if isinstance(transcript, str): + transcript = [transcript] + + audio_items = make_list_of_audio_chat_template(audio) + batch_size = len(audio_items) + if len(transcript) != batch_size: + raise ValueError(f"Got {len(transcript)} transcript(s) but {batch_size} audio(s); they must match 1:1.") + + languages = _prepare_language_inputs(language, batch_size, allow_broadcast=True) + + # Validate that all languages are supported by the forced aligner + for lang in languages: + if lang is not None and lang not in FORCED_ALIGNER_LANGUAGES: + aligner_codes = sorted( + code for code, name in LANGUAGE_CODE_TO_NAME.items() if name in FORCED_ALIGNER_LANGUAGES + ) + raise ValueError( + f"Language {lang!r} is not supported by the forced aligner. " + f"Supported languages: {sorted(FORCED_ALIGNER_LANGUAGES)} " + f"(codes: {aligner_codes})." 
+ ) + + word_lists = [self.split_words_for_alignment(t, lang) for t, lang in zip(transcript, languages)] + + conversations = [] + for wl, audio_item in zip(word_lists, audio_items): + content = [_audio_content_item(audio_item)] + content.extend({"type": "text", "text": word} for word in wl) + conversations.append([{"role": "user", "content": content}]) + + inputs = self.apply_chat_template( + conversations, + tokenize=True, + return_dict=True, + **kwargs, + ) + + return inputs, word_lists + + def decode_forced_alignment( + self, + logits, + input_ids, + word_lists: list[list[str]], + timestamp_token_id: int, + timestamp_segment_time: float | None = None, + ) -> list[list[dict]]: + """ + Decode forced aligner model outputs into word-level timestamps. + + Args: + logits (`torch.Tensor` of shape `(batch_size, seq_len, num_labels)`): + Classification logits from [`Qwen3ASRForTokenClassification`]. + input_ids (`torch.LongTensor` of shape `(batch_size, seq_len)`): + Input token IDs used for the forward pass. + word_lists (`list[list[str]]`): + Word-level token lists as returned by + [`~Qwen3ASRProcessor.prepare_forced_aligner_inputs`]. + timestamp_token_id (`int`): + Token ID of the ```` marker (from + ``model.config.timestamp_token_id``). + timestamp_segment_time (`float`, *optional*): + Milliseconds per timestamp class. If not provided, uses `self.timestamp_segment_time`. + + Returns: + `list[list[dict]]`: One list per sample. Each inner list contains dicts + with keys ``"text"`` (`str`), ``"start_time"`` (`float`, seconds), and + ``"end_time"`` (`float`, seconds). 
+ """ + if timestamp_segment_time is None: + timestamp_segment_time = self.timestamp_segment_time + pred_ids = logits.argmax(dim=-1) + batch_results = [] + + for sample_idx, word_list in enumerate(word_lists): + mask = input_ids[sample_idx] == timestamp_token_id + masked_pred = pred_ids[sample_idx][mask] + raw_ms = (masked_pred.float() * timestamp_segment_time).cpu().numpy() + fixed_ms = _fix_timestamps(raw_ms) + + items = [ + { + "text": word, + "start_time": round(fixed_ms[word_idx * 2] / 1000.0, 3), + "end_time": round(fixed_ms[word_idx * 2 + 1] / 1000.0, 3), + } + for word_idx, word in enumerate(word_list) + ] + batch_results.append(items) + + return batch_results + + @property + def unused_input_names(self) -> list[str]: + "Input names returned always by subprocessors but not used in model's `forward`" + return ["num_audio_tokens"] + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + feature_extractor_input_names = self.feature_extractor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + feature_extractor_input_names + ["input_features_mask"])) + + +__all__ = ["Qwen3ASRProcessor"] diff --git a/src/transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py index df960a27c1bb..1cd303b7b1cf 100644 --- a/src/transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py @@ -28,22 +28,22 @@ logger = logging.get_logger(__name__) -@auto_docstring(checkpoint="Qwen/Qwen2.5-Omni-7B") +@auto_docstring(checkpoint="Qwen/Qwen3-Omni-30B-A3B-Instruct") @strict class Qwen3OmniMoeAudioEncoderConfig(PreTrainedConfig): r""" max_source_positions (`int`, *optional*, defaults to 1500): Maximum sequence length for the inputs - n_window (`int`, *optional*, defaults to 100): + n_window (`int`, *optional*, defaults to 50): Number of windows output_dim 
(`int`, *optional*, defaults to 3584): Dimensionality of the output - n_window_infer (`int`, *optional*, defaults to `400`): + n_window_infer (`int`, *optional*, defaults to `800`): Number of windows during inference conv_chunksize (`int`, *optional*, defaults to `500`): Chunk size of each input to convolutional layer downsample_hidden_size (`int`, *optional*, defaults to `480`): - Hidden size in donwsampling layer + Hidden size in downsampling layer """ model_type = "qwen3_omni_moe_audio_encoder" @@ -67,15 +67,14 @@ class Qwen3OmniMoeAudioEncoderConfig(PreTrainedConfig): initializer_range: float = 0.02 max_source_positions: int = 1500 - n_window: int = 100 + n_window: int = 50 output_dim: int = 3584 - - n_window_infer: int = 400 + n_window_infer: int = 800 conv_chunksize: int = 500 downsample_hidden_size: int = 480 -@auto_docstring(checkpoint="Qwen/Qwen3-30B-A3B-Base") +@auto_docstring(checkpoint="Qwen/Qwen3-Omni-30B-A3B-Instruct") @strict class Qwen3OmniMoeVisionEncoderConfig(PreTrainedConfig): r""" @@ -105,7 +104,7 @@ class Qwen3OmniMoeVisionEncoderConfig(PreTrainedConfig): initializer_range: float = 0.02 -@auto_docstring(checkpoint="Qwen/Qwen3-30B-A3B-Base") +@auto_docstring(checkpoint="Qwen/Qwen3-Omni-30B-A3B-Instruct") @strict class Qwen3OmniMoeTextConfig(PreTrainedConfig): r""" @@ -185,7 +184,7 @@ def __post_init__(self, **kwargs): super().__post_init__(**kwargs) -@auto_docstring(checkpoint="Qwen/Qwen3-30B-A3B-Base") +@auto_docstring(checkpoint="Qwen/Qwen3-Omni-30B-A3B-Instruct") @strict class Qwen3OmniMoeThinkerConfig(PreTrainedConfig): r""" @@ -252,7 +251,7 @@ def __post_init__(self, **kwargs): super().__post_init__(**kwargs) -@auto_docstring(checkpoint="Qwen/Qwen3OmniMoeTalkerCodePredictor-8B") +@auto_docstring(checkpoint="Qwen/Qwen3-Omni-30B-A3B-Instruct") @strict class Qwen3OmniMoeTalkerCodePredictorConfig(PreTrainedConfig): r""" @@ -320,7 +319,7 @@ def __post_init__(self, **kwargs): super().__post_init__(**kwargs) 
-@auto_docstring(checkpoint="Qwen/Qwen3-30B-A3B-Base") +@auto_docstring(checkpoint="Qwen/Qwen3-Omni-30B-A3B-Instruct") @strict class Qwen3OmniMoeTalkerTextConfig(PreTrainedConfig): r""" @@ -413,7 +412,7 @@ def __post_init__(self, **kwargs): super().__post_init__(**kwargs) -@auto_docstring(checkpoint="Qwen/Qwen3-30B-A3B-Base") +@auto_docstring(checkpoint="Qwen/Qwen3-Omni-30B-A3B-Instruct") @strict class Qwen3OmniMoeTalkerConfig(PreTrainedConfig): r""" @@ -508,7 +507,7 @@ def __post_init__(self, **kwargs): super().__post_init__(**kwargs) -@auto_docstring(checkpoint="Qwen/Qwen3-30B-A3B-Base") +@auto_docstring(checkpoint="Qwen/Qwen3-Omni-30B-A3B-Instruct") @strict class Qwen3OmniMoeCode2WavConfig(PreTrainedConfig): r""" @@ -564,7 +563,7 @@ def layer_types(self): return ["sliding_attention"] * self.num_hidden_layers -@auto_docstring(checkpoint="Qwen/Qwen3-30B-A3B-Base") +@auto_docstring(checkpoint="Qwen/Qwen3-Omni-30B-A3B-Instruct") @strict class Qwen3OmniMoeConfig(PreTrainedConfig): r""" @@ -675,10 +674,10 @@ def get_text_config(self, decoder=False) -> "PreTrainedConfig": __all__ = [ + "Qwen3OmniMoeAudioEncoderConfig", "Qwen3OmniMoeConfig", "Qwen3OmniMoeThinkerConfig", "Qwen3OmniMoeTalkerConfig", - "Qwen3OmniMoeAudioEncoderConfig", "Qwen3OmniMoeTalkerCodePredictorConfig", "Qwen3OmniMoeTalkerTextConfig", "Qwen3OmniMoeTextConfig", diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index 55a35c594748..fe774490ff0b 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -100,14 +100,14 @@ def __init__(self, length, channels, max_timescale=10000): self.max_timescale = max_timescale if channels % 2 != 0: raise ValueError("SinusoidsPositionEmbedding needs even channels input") - log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) - inv_timescales = 
torch.exp(-log_timescale_increment * torch.arange(channels // 2).float()) - scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] - self.register_buffer( - "positional_embedding", - torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1), - persistent=False, - ) + position_embedding = self.compute_default_singular_positional_embedding() + self.register_buffer("positional_embedding", position_embedding, persistent=False) + + def compute_default_singular_positional_embedding(self): + log_timescale_increment = np.log(self.max_timescale) / (self.channels // 2 - 1) + inv_timescales = torch.exp(-log_timescale_increment * torch.arange(self.channels // 2).float()) + scaled_time = torch.arange(self.length)[:, np.newaxis] * inv_timescales[np.newaxis, :] + return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) def forward(self, seqlen: int): return self.positional_embedding[:seqlen, :] @@ -129,7 +129,7 @@ class Qwen3OmniMoePreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) - std = self.config.initializer_range + std = getattr(self.config, "initializer_range", 0.02) if isinstance(module, Qwen3OmniMoeThinkerTextSparseMoeBlock): init.normal_(module.experts.gate_up_proj, mean=0.0, std=std) init.normal_(module.experts.down_proj, mean=0.0, std=std) @@ -140,24 +140,21 @@ def _init_weights(self, module): torch.arange(module.config.num_quantizers).view(1, -1, 1) * module.config.codebook_size, ) elif isinstance(module, SinusoidsPositionEmbedding): - log_timescale_increment = np.log(module.max_timescale) / (module.channels // 2 - 1) - inv_timescales = torch.exp(-log_timescale_increment * torch.arange(module.channels // 2).float()) - scaled_time = torch.arange(module.length)[:, np.newaxis] * inv_timescales[np.newaxis, :] - init.copy_(module.positional_embedding, torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)) + position_embeddings = 
module.compute_default_singular_positional_embedding() + init.copy_(module.positional_embedding, position_embeddings) elif isinstance(module, Qwen3OmniMoeVisionRotaryEmbedding): inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim)) init.copy_(module.inv_freq, inv_freq) -def _get_feat_extract_output_lengths(input_lengths): - """Compute output lengths after the 3-layer CNN feature extractor with deepstack. - - Three stride-2 convolutions within each 100-frame block, plus 13 output frames - per full block from the deepstack path. +def _get_feat_extract_output_lengths(input_lengths, n_window=50): + """ + Computes the output length of the convolutional layers and the output length of the audio encoder """ - input_lengths_leave = input_lengths % 100 + chunk_len = n_window * 2 + input_lengths_leave = input_lengths % chunk_len feat_lengths = (input_lengths_leave - 1) // 2 + 1 - return ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13 + return ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // chunk_len) * 13 class Qwen3OmniMoePreTrainedModelForConditionalGeneration(Qwen3OmniMoePreTrainedModel): @@ -355,7 +352,9 @@ def get_rope_index( st_idx += bos_len # Audio Only if min_ed == ed_audio_start: - audio_len = _get_feat_extract_output_lengths(audio_seqlens[audio_idx]) + audio_len = _get_feat_extract_output_lengths( + audio_seqlens[audio_idx], self.config.audio_config.n_window + ) llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx llm_pos_ids_list.append(llm_pos_ids) @@ -399,7 +398,9 @@ def get_rope_index( # Audio in Video elif min_ed == ed_vision_start and ed_vision_start + 1 == ed_audio_start: - audio_len = _get_feat_extract_output_lengths(audio_seqlens[audio_idx]) + audio_len = _get_feat_extract_output_lengths( + audio_seqlens[audio_idx], self.config.audio_config.n_window + ) audio_llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx grid_t = 
video_grid_thw[video_idx][0] grid_hs = video_grid_thw[:, 1] @@ -678,11 +679,12 @@ def chunk_and_pad_features( return padded_feature, chunk_lengths -def get_valid_indices(chunk_lengths: torch.Tensor, kwargs: dict | None = None) -> torch.Tensor: +def get_valid_indices(chunk_lengths: torch.Tensor, n_window: int, kwargs: dict | None = None) -> torch.Tensor: """Compute flat indices of valid (non-padding) positions after CNN extraction, or pop `"valid_indices"` from `kwargs` if precomputed. Args: chunk_lengths: `(num_chunks,)` pre-CNN chunk lengths. + n_window: half the chunk size (in raw frames). kwargs: optional caller kwargs β€” if it contains `"valid_indices"` it is popped and returned. Returns: @@ -690,7 +692,7 @@ def get_valid_indices(chunk_lengths: torch.Tensor, kwargs: dict | None = None) - """ if kwargs is not None and (valid_indices := kwargs.pop("valid_indices", None)) is not None: return valid_indices - feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths) + feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths, n_window) max_len_after_cnn = feature_lens_after_cnn.max().item() mask = torch.arange(max_len_after_cnn, device=chunk_lengths.device) < feature_lens_after_cnn.unsqueeze(1) return mask.flatten().nonzero().squeeze(-1) @@ -721,8 +723,8 @@ def get_audio_cu_seqlens( if kwargs is not None and (cu_seqlens := kwargs.pop("cu_seqlens", None)) is not None: return cu_seqlens - aftercnn_lens = _get_feat_extract_output_lengths(feature_lens) - feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths) + aftercnn_lens = _get_feat_extract_output_lengths(feature_lens, n_window) + feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths, n_window) max_len_after_cnn = feature_lens_after_cnn.max().item() n_window_ratio = n_window_infer // (n_window * 2) @@ -806,7 +808,7 @@ def forward(self, input_features=None, feature_lens=None, **kwargs: Unpack[Trans padded_feature, chunk_lengths = chunk_and_pad_features( 
input_features, feature_lens, self.n_window, kwargs=kwargs ) - valid_indices = get_valid_indices(chunk_lengths, kwargs=kwargs) + valid_indices = get_valid_indices(chunk_lengths, self.n_window, kwargs=kwargs) cu_seqlens = get_audio_cu_seqlens( chunk_lengths, feature_lens, self.n_window_infer, self.n_window, kwargs=kwargs ) @@ -846,58 +848,6 @@ def forward(self, input_features=None, feature_lens=None, **kwargs: Unpack[Trans hidden_states = self.proj2(hidden_states) return BaseModelOutputWithPooling(last_hidden_state=hidden_states) - # Ignore copy - def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): - """ - Computes the output length of the convolutional layers and the output length of the audio encoder - """ - input_lengths = (input_lengths - 1) // 2 + 1 - output_lengths = (input_lengths - 2) // 2 + 1 - return input_lengths, output_lengths - - def padded_and_mask_function(self, tensor_list, tensor_len, padding_value=0, padding_side="right"): - """ - Pads a sequence of tensors to their maximum length on indicated `padding_side`. - Then prepares a mask so that pad tokens are not attended to. - """ - warnings.warn( - f"`{self.__class__.__name__}.padded_and_mask_function` is deprecated and will be removed in v5.11. 
Use `chunk_and_pad_features` and `get_audio_cu_seqlens` helpers instead.", - FutureWarning, - stacklevel=2, - ) - max_len = tensor_len.max() - dim = tensor_list[0].shape[0] - padded_tensor = torch.full( - size=(len(tensor_list), dim, max_len), - fill_value=padding_value, - dtype=self.dtype, - device=tensor_list[0].device, - ) - - batch_mask = torch.zeros( - (len(tensor_len), max_len), - dtype=torch.long, - device=padded_tensor.device, - ) - for i, length in enumerate(tensor_len): - batch_mask[i, :length] = 1 - padded_tensor[i, :, :length] = tensor_list[i] - - feature_lens_after_cnn = (tensor_len - 1) // 2 + 1 - max_len_after_cnn = feature_lens_after_cnn.max() - batch_mask_after_cnn = torch.zeros( - (len(tensor_len), max_len_after_cnn), - dtype=torch.long, - device=padded_tensor.device, - ) - for i, length in enumerate(feature_lens_after_cnn): - batch_mask_after_cnn[i, :length] = 1 - return ( - padded_tensor, - batch_mask.unsqueeze(1), - batch_mask_after_cnn.bool(), - ) - def rotate_half(x): """Rotates half the hidden dims of the input.""" diff --git a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py index 39f2551e255b..bebc9fc8ccf6 100644 --- a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py @@ -103,15 +103,14 @@ logger = logging.get_logger(__name__) -def _get_feat_extract_output_lengths(input_lengths): - """Compute output lengths after the 3-layer CNN feature extractor with deepstack. - - Three stride-2 convolutions within each 100-frame block, plus 13 output frames - per full block from the deepstack path. 
+def _get_feat_extract_output_lengths(input_lengths, n_window=50): + """ + Computes the output length of the convolutional layers and the output length of the audio encoder """ - input_lengths_leave = input_lengths % 100 + chunk_len = n_window * 2 + input_lengths_leave = input_lengths % chunk_len feat_lengths = (input_lengths_leave - 1) // 2 + 1 - return ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13 + return ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // chunk_len) * 13 def chunk_and_pad_features( @@ -149,11 +148,12 @@ def chunk_and_pad_features( return padded_feature, chunk_lengths -def get_valid_indices(chunk_lengths: torch.Tensor, kwargs: dict | None = None) -> torch.Tensor: +def get_valid_indices(chunk_lengths: torch.Tensor, n_window: int, kwargs: dict | None = None) -> torch.Tensor: """Compute flat indices of valid (non-padding) positions after CNN extraction, or pop `"valid_indices"` from `kwargs` if precomputed. Args: chunk_lengths: `(num_chunks,)` pre-CNN chunk lengths. + n_window: half the chunk size (in raw frames). kwargs: optional caller kwargs β€” if it contains `"valid_indices"` it is popped and returned. 
Returns: @@ -161,7 +161,7 @@ def get_valid_indices(chunk_lengths: torch.Tensor, kwargs: dict | None = None) - """ if kwargs is not None and (valid_indices := kwargs.pop("valid_indices", None)) is not None: return valid_indices - feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths) + feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths, n_window) max_len_after_cnn = feature_lens_after_cnn.max().item() mask = torch.arange(max_len_after_cnn, device=chunk_lengths.device) < feature_lens_after_cnn.unsqueeze(1) return mask.flatten().nonzero().squeeze(-1) @@ -192,8 +192,8 @@ def get_audio_cu_seqlens( if kwargs is not None and (cu_seqlens := kwargs.pop("cu_seqlens", None)) is not None: return cu_seqlens - aftercnn_lens = _get_feat_extract_output_lengths(feature_lens) - feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths) + aftercnn_lens = _get_feat_extract_output_lengths(feature_lens, n_window) + feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths, n_window) max_len_after_cnn = feature_lens_after_cnn.max().item() n_window_ratio = n_window_infer // (n_window * 2) @@ -220,34 +220,37 @@ class BaseModelOutputWithDeepstackFeatures(BaseModelOutputWithPooling): deepstack_features: list[torch.FloatTensor] | None = None +@auto_docstring(checkpoint="Qwen/Qwen3-Omni-30B-A3B-Instruct") +@strict class Qwen3OmniMoeAudioEncoderConfig(Qwen2_5OmniAudioEncoderConfig): r""" max_source_positions (`int`, *optional*, defaults to 1500): Maximum sequence length for the inputs - n_window (`int`, *optional*, defaults to 100): + n_window (`int`, *optional*, defaults to 50): Number of windows output_dim (`int`, *optional*, defaults to 3584): Dimensionality of the output - n_window_infer (`int`, *optional*, defaults to `400`): + n_window_infer (`int`, *optional*, defaults to `800`): Number of windows during inference conv_chunksize (`int`, *optional*, defaults to `500`): Chunk size of each input to convolutional layer 
downsample_hidden_size (`int`, *optional*, defaults to `480`): - Hidden size in donwsampling layer + Hidden size in downsampling layer """ - n_window_infer: int = 400 + n_window: int = 50 + n_window_infer: int = 800 conv_chunksize: int = 500 downsample_hidden_size: int = 480 -@auto_docstring(checkpoint="Qwen/Qwen3-30B-A3B-Base") +@auto_docstring(checkpoint="Qwen/Qwen3-Omni-30B-A3B-Instruct") @strict class Qwen3OmniMoeVisionEncoderConfig(Qwen3VLMoeVisionConfig): pass -@auto_docstring(checkpoint="Qwen/Qwen3-30B-A3B-Base") +@auto_docstring(checkpoint="Qwen/Qwen3-Omni-30B-A3B-Instruct") @strict class Qwen3OmniMoeTextConfig(PreTrainedConfig): r""" @@ -327,7 +330,7 @@ def __post_init__(self, **kwargs): super().__post_init__(**kwargs) -@auto_docstring(checkpoint="Qwen/Qwen3-30B-A3B-Base") +@auto_docstring(checkpoint="Qwen/Qwen3-Omni-30B-A3B-Instruct") @strict class Qwen3OmniMoeThinkerConfig(Qwen2_5OmniThinkerConfig): r""" @@ -368,6 +371,8 @@ class Qwen3OmniMoeThinkerConfig(Qwen2_5OmniThinkerConfig): audio_end_token_id = AttributeError() +@auto_docstring(checkpoint="Qwen/Qwen3-Omni-30B-A3B-Instruct") +@strict class Qwen3OmniMoeTalkerCodePredictorConfig(Qwen3Config): r""" num_code_groups (`int`, *optional*, defaults to 32): @@ -389,6 +394,8 @@ def __post_init__(self, **kwargs): self.sliding_window = self.sliding_window +@auto_docstring(checkpoint="Qwen/Qwen3-Omni-30B-A3B-Instruct") +@strict class Qwen3OmniMoeTalkerTextConfig(Qwen3MoeConfig): base_model_ep_plan = { "layers.*.mlp.gate": "ep_router", @@ -412,7 +419,7 @@ def __post_init__(self, **kwargs): self.sliding_window = self.sliding_window -@auto_docstring(checkpoint="Qwen/Qwen3-30B-A3B-Base") +@auto_docstring(checkpoint="Qwen/Qwen3-Omni-30B-A3B-Instruct") @strict class Qwen3OmniMoeTalkerConfig(PreTrainedConfig): r""" @@ -507,7 +514,7 @@ def __post_init__(self, **kwargs): super().__post_init__(**kwargs) -@auto_docstring(checkpoint="Qwen/Qwen3-30B-A3B-Base") +@auto_docstring(checkpoint="Qwen/Qwen3-Omni-30B-A3B-Instruct") 
@strict class Qwen3OmniMoeCode2WavConfig(PreTrainedConfig): r""" @@ -563,7 +570,7 @@ def layer_types(self): return ["sliding_attention"] * self.num_hidden_layers -@auto_docstring(checkpoint="Qwen/Qwen3-30B-A3B-Base") +@auto_docstring(checkpoint="Qwen/Qwen3-Omni-30B-A3B-Instruct") @strict class Qwen3OmniMoeConfig(PreTrainedConfig): r""" @@ -677,7 +684,7 @@ class Qwen3OmniMoePreTrainedModel(Qwen2_5OmniPreTrainedModel, PreTrainedModel): @torch.no_grad() def _init_weights(self, module): PreTrainedModel._init_weights(self, module) - std = self.config.initializer_range + std = getattr(self.config, "initializer_range", 0.02) if isinstance(module, Qwen3OmniMoeThinkerTextSparseMoeBlock): init.normal_(module.experts.gate_up_proj, mean=0.0, std=std) init.normal_(module.experts.down_proj, mean=0.0, std=std) @@ -688,10 +695,8 @@ def _init_weights(self, module): torch.arange(module.config.num_quantizers).view(1, -1, 1) * module.config.codebook_size, ) elif isinstance(module, SinusoidsPositionEmbedding): - log_timescale_increment = np.log(module.max_timescale) / (module.channels // 2 - 1) - inv_timescales = torch.exp(-log_timescale_increment * torch.arange(module.channels // 2).float()) - scaled_time = torch.arange(module.length)[:, np.newaxis] * inv_timescales[np.newaxis, :] - init.copy_(module.positional_embedding, torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)) + position_embeddings = module.compute_default_singular_positional_embedding() + init.copy_(module.positional_embedding, position_embeddings) elif isinstance(module, Qwen3OmniMoeVisionRotaryEmbedding): inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim)) init.copy_(module.inv_freq, inv_freq) @@ -853,7 +858,9 @@ def get_rope_index( st_idx += bos_len # Audio Only if min_ed == ed_audio_start: - audio_len = _get_feat_extract_output_lengths(audio_seqlens[audio_idx]) + audio_len = _get_feat_extract_output_lengths( + audio_seqlens[audio_idx], 
self.config.audio_config.n_window + ) llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx llm_pos_ids_list.append(llm_pos_ids) @@ -897,7 +904,9 @@ def get_rope_index( # Audio in Video elif min_ed == ed_vision_start and ed_vision_start + 1 == ed_audio_start: - audio_len = _get_feat_extract_output_lengths(audio_seqlens[audio_idx]) + audio_len = _get_feat_extract_output_lengths( + audio_seqlens[audio_idx], self.config.audio_config.n_window + ) audio_llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx grid_t = video_grid_thw[video_idx][0] grid_hs = video_grid_thw[:, 1] @@ -989,6 +998,12 @@ def __init__(self, config: Qwen3OmniMoeAudioEncoderConfig): self.n_window_infer = self.config.n_window_infer self.conv_chunksize = self.config.conv_chunksize + def _get_feat_extract_output_lengths(self, input_lengths): + raise NotImplementedError("Using the standalone function _get_feat_extract_output_lengths instead.") + + def padded_and_mask_function(self, tensor_list, tensor_len, padding_value=0, padding_side="right"): + raise NotImplementedError("Not needed") + def get_input_embeddings(self): return self.conv2d1 @@ -1003,7 +1018,7 @@ def forward(self, input_features=None, feature_lens=None, **kwargs: Unpack[Trans padded_feature, chunk_lengths = chunk_and_pad_features( input_features, feature_lens, self.n_window, kwargs=kwargs ) - valid_indices = get_valid_indices(chunk_lengths, kwargs=kwargs) + valid_indices = get_valid_indices(chunk_lengths, self.n_window, kwargs=kwargs) cu_seqlens = get_audio_cu_seqlens( chunk_lengths, feature_lens, self.n_window_infer, self.n_window, kwargs=kwargs ) @@ -2538,6 +2553,7 @@ class Qwen3OmniMoeProcessorKwargs(Qwen2_5OmniProcessorKwargs): }, }, "audio_kwargs": { + "n_window": 50, # should match model config "sampling_rate": 16000, "padding": True, "truncation": False, @@ -2646,6 +2662,7 @@ def __call__( position_id_per_seconds = output_kwargs["videos_kwargs"].pop("position_id_per_seconds") use_audio_in_video 
= output_kwargs["videos_kwargs"].pop("use_audio_in_video") fps = output_kwargs["videos_kwargs"].get("fps", 1.0) + n_window = output_kwargs["audio_kwargs"].pop("n_window", 50) if audio is not None: audio_inputs = self.feature_extractor(audio, **output_kwargs["audio_kwargs"]) @@ -2655,7 +2672,9 @@ def __call__( audio_inputs["input_features"] = audio_inputs.pop( "input_features" ) # rename input_features to prevent conflicts later on - audio_lengths = iter(_get_feat_extract_output_lengths(audio_inputs["feature_attention_mask"].sum(-1))) + audio_lengths = iter( + _get_feat_extract_output_lengths(audio_inputs["feature_attention_mask"].sum(-1), n_window) + ) else: audio_inputs = {} audio_lengths = iter([]) @@ -2707,10 +2726,10 @@ def apply_chat_template(self, conversations, chat_template=None, **kwargs): __all__ = [ + "Qwen3OmniMoeAudioEncoderConfig", "Qwen3OmniMoeConfig", "Qwen3OmniMoeThinkerConfig", "Qwen3OmniMoeTalkerConfig", - "Qwen3OmniMoeAudioEncoderConfig", "Qwen3OmniMoeTalkerCodePredictorConfig", "Qwen3OmniMoeTalkerTextConfig", "Qwen3OmniMoeTextConfig", diff --git a/src/transformers/models/qwen3_omni_moe/processing_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/processing_qwen3_omni_moe.py index 966f399c856c..3a4ad96c723f 100644 --- a/src/transformers/models/qwen3_omni_moe/processing_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/processing_qwen3_omni_moe.py @@ -96,6 +96,7 @@ class Qwen3OmniMoeProcessorKwargs(ProcessingKwargs, total=False): }, }, "audio_kwargs": { + "n_window": 50, # should match model config "sampling_rate": 16000, "padding": True, "truncation": False, @@ -104,15 +105,14 @@ class Qwen3OmniMoeProcessorKwargs(ProcessingKwargs, total=False): } -def _get_feat_extract_output_lengths(input_lengths): - """Compute output lengths after the 3-layer CNN feature extractor with deepstack. - - Three stride-2 convolutions within each 100-frame block, plus 13 output frames - per full block from the deepstack path. 
+def _get_feat_extract_output_lengths(input_lengths, n_window=50): + """ + Computes the output length of the convolutional layers and the output length of the audio encoder """ - input_lengths_leave = input_lengths % 100 + chunk_len = n_window * 2 + input_lengths_leave = input_lengths % chunk_len feat_lengths = (input_lengths_leave - 1) // 2 + 1 - return ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13 + return ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // chunk_len) * 13 @auto_docstring @@ -151,6 +151,7 @@ def __call__( position_id_per_seconds = output_kwargs["videos_kwargs"].pop("position_id_per_seconds") use_audio_in_video = output_kwargs["videos_kwargs"].pop("use_audio_in_video") fps = output_kwargs["videos_kwargs"].get("fps", 1.0) + n_window = output_kwargs["audio_kwargs"].pop("n_window", 50) if audio is not None: audio_inputs = self.feature_extractor(audio, **output_kwargs["audio_kwargs"]) @@ -160,7 +161,9 @@ def __call__( audio_inputs["input_features"] = audio_inputs.pop( "input_features" ) # rename input_features to prevent conflicts later on - audio_lengths = iter(_get_feat_extract_output_lengths(audio_inputs["feature_attention_mask"].sum(-1))) + audio_lengths = iter( + _get_feat_extract_output_lengths(audio_inputs["feature_attention_mask"].sum(-1), n_window) + ) else: audio_inputs = {} audio_lengths = iter([]) diff --git a/src/transformers/models/vibevoice_asr/processing_vibevoice_asr.py b/src/transformers/models/vibevoice_asr/processing_vibevoice_asr.py index a694fbc99366..b0635ac89667 100644 --- a/src/transformers/models/vibevoice_asr/processing_vibevoice_asr.py +++ b/src/transformers/models/vibevoice_asr/processing_vibevoice_asr.py @@ -17,7 +17,7 @@ import numpy as np -from ...audio_utils import AudioInput, make_list_of_audio +from ...audio_utils import AudioInput, make_list_of_audio, make_list_of_audio_chat_template from ...feature_extraction_utils import BatchFeature from ...processing_utils import 
ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import TextInput @@ -196,14 +196,9 @@ def apply_transcription_request( [`BatchFeature`]: Processor outputs ready to be passed to [`VibeVoiceAsrForConditionalGeneration.generate`]. """ - if isinstance(audio, str): - audio_items: list[str | np.ndarray] = [audio] - elif isinstance(audio, (list, tuple)) and audio and all(isinstance(el, str) for el in audio): - audio_items = list(audio) - else: - audio_items = list(make_list_of_audio(audio)) - if is_torch_available(): - audio_items = [el.detach().cpu().numpy() if isinstance(el, torch.Tensor) else el for el in audio_items] + audio_items: list[str | np.ndarray] = list(make_list_of_audio_chat_template(audio)) + if is_torch_available(): + audio_items = [el.detach().cpu().numpy() if isinstance(el, torch.Tensor) else el for el in audio_items] batch_size = len(audio_items) if batch_size == 0: diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 4f152e3a173c..e640d5d40590 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -796,6 +796,16 @@ def is_librosa_available() -> bool: return _is_package_available("librosa")[0] +@lru_cache +def is_nagisa_available() -> bool: + return _is_package_available("nagisa")[0] + + +@lru_cache +def is_soynlp_available() -> bool: + return _is_package_available("soynlp")[0] + + @lru_cache def is_multipart_available() -> bool: return _is_package_available("multipart")[0] diff --git a/tests/fixtures/qwen3_asr/expected_results_batched.json b/tests/fixtures/qwen3_asr/expected_results_batched.json new file mode 100644 index 000000000000..ff256f4a163d --- /dev/null +++ b/tests/fixtures/qwen3_asr/expected_results_batched.json @@ -0,0 +1 @@ +{"transcriptions": ["system\n\nuser\n\nassistant\nlanguage EnglishMr. 
Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.", "system\n\nuser\n\nassistant\nlanguage Chineseη”šθ‡³ε‡ΊηŽ°δΊ€ζ˜“ε‡ δΉŽεœζ»žηš„ζƒ…ε†΅γ€‚"], "token_ids": [[11528, 6364, 151704, 12275, 13, 3406, 2044, 374, 279, 38471, 273, 315, 279, 6149, 6846, 11, 323, 582, 525, 15713, 311, 10565, 806, 41482, 13, 151645], [11528, 8453, 151704, 100636, 100347, 99886, 100740, 118083, 102072, 1773, 151645, 151645, 151645, 151645, 151645, 151645, 151645, 151645, 151645, 151645, 151645, 151645, 151645, 151645, 151645, 151645]]} \ No newline at end of file diff --git a/tests/fixtures/qwen3_asr/expected_results_single.json b/tests/fixtures/qwen3_asr/expected_results_single.json new file mode 100644 index 000000000000..bb48e15f757e --- /dev/null +++ b/tests/fixtures/qwen3_asr/expected_results_single.json @@ -0,0 +1 @@ +{"transcriptions": ["system\n\nuser\n\nassistant\nlanguage EnglishMr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."], "token_ids": [[11528, 6364, 151704, 12275, 13, 3406, 2044, 374, 279, 38471, 273, 315, 279, 6149, 6846, 11, 323, 582, 525, 15713, 311, 10565, 806, 41482, 13, 151645]]} \ No newline at end of file diff --git a/tests/fixtures/qwen3_asr/expected_timestamps_batched.json b/tests/fixtures/qwen3_asr/expected_timestamps_batched.json new file mode 100644 index 000000000000..35b893354446 --- /dev/null +++ b/tests/fixtures/qwen3_asr/expected_timestamps_batched.json @@ -0,0 +1,164 @@ +[ + { + "language": "English", + "text": "Mr. 
Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.", + "time_stamps": [ + { + "text": "Mr", + "start_time": 0.56, + "end_time": 0.8 + }, + { + "text": "Quilter", + "start_time": 0.8, + "end_time": 1.28 + }, + { + "text": "is", + "start_time": 1.28, + "end_time": 1.44 + }, + { + "text": "the", + "start_time": 1.44, + "end_time": 1.52 + }, + { + "text": "apostle", + "start_time": 1.52, + "end_time": 2.08 + }, + { + "text": "of", + "start_time": 2.08, + "end_time": 2.32 + }, + { + "text": "the", + "start_time": 2.32, + "end_time": 2.32 + }, + { + "text": "middle", + "start_time": 2.32, + "end_time": 2.56 + }, + { + "text": "classes", + "start_time": 2.56, + "end_time": 3.28 + }, + { + "text": "and", + "start_time": 3.36, + "end_time": 3.52 + }, + { + "text": "we", + "start_time": 3.52, + "end_time": 3.6 + }, + { + "text": "are", + "start_time": 3.6, + "end_time": 3.68 + }, + { + "text": "glad", + "start_time": 3.68, + "end_time": 4.08 + }, + { + "text": "to", + "start_time": 4.16, + "end_time": 4.16 + }, + { + "text": "welcome", + "start_time": 4.16, + "end_time": 4.64 + }, + { + "text": "his", + "start_time": 4.64, + "end_time": 4.8 + }, + { + "text": "gospel", + "start_time": 4.8, + "end_time": 5.44 + } + ] + }, + { + "language": "Chinese", + "text": "η”šθ‡³ε‡ΊηŽ°δΊ€ζ˜“ε‡ δΉŽεœζ»žηš„ζƒ…ε†΅γ€‚", + "time_stamps": [ + { + "text": "η”š", + "start_time": 0.4, + "end_time": 0.72 + }, + { + "text": "至", + "start_time": 0.72, + "end_time": 0.96 + }, + { + "text": "ε‡Ί", + "start_time": 0.96, + "end_time": 1.12 + }, + { + "text": "现", + "start_time": 1.12, + "end_time": 1.52 + }, + { + "text": "δΊ€", + "start_time": 1.52, + "end_time": 1.76 + }, + { + "text": "ζ˜“", + "start_time": 1.76, + "end_time": 2.0 + }, + { + "text": "ε‡ ", + "start_time": 2.0, + "end_time": 2.24 + }, + { + "text": "乎", + "start_time": 2.24, + "end_time": 2.48 + }, + { + "text": "停", + "start_time": 2.48, + "end_time": 2.72 + }, + { + "text": "滞", + "start_time": 
2.72, + "end_time": 2.88 + }, + { + "text": "ηš„", + "start_time": 2.88, + "end_time": 3.04 + }, + { + "text": "ζƒ…", + "start_time": 3.04, + "end_time": 3.36 + }, + { + "text": "冡", + "start_time": 3.36, + "end_time": 3.68 + } + ] + } +] \ No newline at end of file diff --git a/tests/fixtures/qwen3_asr/expected_timestamps_single.json b/tests/fixtures/qwen3_asr/expected_timestamps_single.json new file mode 100644 index 000000000000..1786d4a86ae3 --- /dev/null +++ b/tests/fixtures/qwen3_asr/expected_timestamps_single.json @@ -0,0 +1,91 @@ +{ + "language": "English", + "text": "Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.", + "time_stamps": [ + { + "text": "Mr", + "start_time": 0.56, + "end_time": 0.8 + }, + { + "text": "Quilter", + "start_time": 0.8, + "end_time": 1.28 + }, + { + "text": "is", + "start_time": 1.28, + "end_time": 1.44 + }, + { + "text": "the", + "start_time": 1.44, + "end_time": 1.52 + }, + { + "text": "apostle", + "start_time": 1.52, + "end_time": 2.08 + }, + { + "text": "of", + "start_time": 2.08, + "end_time": 2.32 + }, + { + "text": "the", + "start_time": 2.32, + "end_time": 2.32 + }, + { + "text": "middle", + "start_time": 2.32, + "end_time": 2.56 + }, + { + "text": "classes", + "start_time": 2.56, + "end_time": 3.28 + }, + { + "text": "and", + "start_time": 3.36, + "end_time": 3.52 + }, + { + "text": "we", + "start_time": 3.52, + "end_time": 3.6 + }, + { + "text": "are", + "start_time": 3.6, + "end_time": 3.68 + }, + { + "text": "glad", + "start_time": 3.68, + "end_time": 4.08 + }, + { + "text": "to", + "start_time": 4.16, + "end_time": 4.16 + }, + { + "text": "welcome", + "start_time": 4.16, + "end_time": 4.64 + }, + { + "text": "his", + "start_time": 4.64, + "end_time": 4.8 + }, + { + "text": "gospel", + "start_time": 4.8, + "end_time": 5.44 + } + ] +} \ No newline at end of file diff --git a/tests/models/dac/test_feature_extraction_dac.py b/tests/models/dac/test_feature_extraction_dac.py index 
2620804b2cd2..c1684edd704d 100644 --- a/tests/models/dac/test_feature_extraction_dac.py +++ b/tests/models/dac/test_feature_extraction_dac.py @@ -14,7 +14,6 @@ """Tests for the dac feature extractor.""" import itertools -import random import unittest import numpy as np @@ -23,6 +22,7 @@ from transformers.testing_utils import require_torch from transformers.utils.import_utils import is_torch_available +from ...test_processing_common import floats_list from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin @@ -30,24 +30,6 @@ import torch -global_rng = random.Random() - - -# Copied from tests.models.whisper.test_feature_extraction_whisper.floats_list -def floats_list(shape, scale=1.0, rng=None, name=None): - """Creates a random float32 tensor""" - if rng is None: - rng = global_rng - - values = [] - for batch_idx in range(shape[0]): - values.append([]) - for _ in range(shape[1]): - values[-1].append(rng.random() * scale) - - return values - - @require_torch # Copied from transformers.tests.encodec.test_feature_extraction_encodec.EncodecFeatureExtractionTester with Encodec->Dac class DacFeatureExtractionTester: diff --git a/tests/models/qwen3_asr/__init__.py b/tests/models/qwen3_asr/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/qwen3_asr/test_feature_extraction_qwen3_asr.py b/tests/models/qwen3_asr/test_feature_extraction_qwen3_asr.py new file mode 100644 index 000000000000..83a2b3c6e3c7 --- /dev/null +++ b/tests/models/qwen3_asr/test_feature_extraction_qwen3_asr.py @@ -0,0 +1,171 @@ +# Copyright 2026 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import unittest + +import numpy as np + +from transformers import Qwen3ASRFeatureExtractor + +from ...test_processing_common import floats_list +from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin + + +class Qwen3ASRFeatureExtractionTester: + def __init__( + self, + parent, + batch_size=7, + min_seq_length=400, + max_seq_length=2000, + feature_size=10, + hop_length=160, + chunk_length=8, + padding_value=0.0, + sampling_rate=4_000, + return_attention_mask=False, + n_window=13, + ): + self.parent = parent + self.batch_size = batch_size + self.min_seq_length = min_seq_length + self.max_seq_length = max_seq_length + self.seq_length_diff = (self.max_seq_length - self.min_seq_length) // (self.batch_size - 1) + self.feature_size = feature_size + self.hop_length = hop_length + self.chunk_length = chunk_length + self.padding_value = padding_value + self.sampling_rate = sampling_rate + self.return_attention_mask = return_attention_mask + self.n_window = n_window + + def prepare_feat_extract_dict(self): + return { + "feature_size": self.feature_size, + "hop_length": self.hop_length, + "chunk_length": self.chunk_length, + "padding_value": self.padding_value, + "sampling_rate": self.sampling_rate, + "return_attention_mask": self.return_attention_mask, + "n_window": self.n_window, + } + + def prepare_inputs_for_common(self, equal_length=False, numpify=False): + def _flatten(list_of_lists): + return list(itertools.chain(*list_of_lists)) + + if equal_length: + speech_inputs = [floats_list((self.max_seq_length, 
self.feature_size)) for _ in range(self.batch_size)] + else: + speech_inputs = [ + floats_list((x, self.feature_size)) + for x in range(self.min_seq_length, self.max_seq_length, self.seq_length_diff) + ] + if numpify: + speech_inputs = [np.asarray(x) for x in speech_inputs] + return speech_inputs + + +class Qwen3ASRFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase): + feature_extraction_class = Qwen3ASRFeatureExtractor + + def setUp(self): + self.feat_extract_tester = Qwen3ASRFeatureExtractionTester(self) + + def test_default_feature_size_is_128(self): + """Qwen3 ASR uses 128-bin mel filters by default.""" + fe = Qwen3ASRFeatureExtractor() + self.assertEqual(fe.feature_size, 128) + self.assertEqual(fe.mel_filters.shape[1], 128) + + def test_default_n_window_is_50(self): + fe = Qwen3ASRFeatureExtractor() + self.assertEqual(fe.n_window, 50) + + def test_mel_padding_aligns_to_chunk(self): + """The mel time axis is right-padded to a multiple of `2 * n_window`.""" + fe = Qwen3ASRFeatureExtractor() + # 5.85 s at 16 kHz -> 585 mel frames before padding -> 600 after (multiple of 100). 
+ audio = np.random.randn(int(5.85 * 16_000)).astype(np.float32) + out = fe( + audio, + sampling_rate=16_000, + padding="longest", + truncation=False, + return_attention_mask=True, + return_tensors="np", + ) + self.assertEqual(out["input_features"].shape, (1, 128, 600)) + self.assertEqual(out["attention_mask"].shape, (1, 600)) + self.assertEqual(int(out["attention_mask"].sum(-1)), 585) + self.assertEqual(out["input_features"].shape[-1] % 100, 0) + + def test_n_window_kwarg_override(self): + fe = Qwen3ASRFeatureExtractor() + audio = np.random.randn(int(5.85 * 16_000)).astype(np.float32) + out = fe( + audio, + sampling_rate=16_000, + padding="longest", + truncation=False, + return_attention_mask=True, + return_tensors="np", + n_window=25, + ) + self.assertEqual(out["input_features"].shape[-1] % 50, 0) + + def test_n_window_disabled(self): + """`n_window=0` disables mel-axis padding.""" + fe = Qwen3ASRFeatureExtractor() + audio = np.random.randn(int(5.85 * 16_000)).astype(np.float32) + out = fe( + audio, + sampling_rate=16_000, + padding="longest", + truncation=False, + return_attention_mask=True, + return_tensors="np", + n_window=0, + ) + self.assertEqual(out["input_features"].shape[-1], 585) + self.assertEqual(out["attention_mask"].shape[-1], 585) + + def test_batched_call_shape(self): + fe = Qwen3ASRFeatureExtractor() + # Two clips of different lengths; padded to the longer one (rounded up to 2 * n_window). 
+ audio = [ + np.random.randn(int(2.0 * 16_000)).astype(np.float32), + np.random.randn(int(5.5 * 16_000)).astype(np.float32), + ] + out = fe( + audio, + sampling_rate=16_000, + padding="longest", + truncation=False, + return_attention_mask=True, + return_tensors="np", + ) + self.assertEqual(out["input_features"].ndim, 3) + self.assertEqual(out["input_features"].shape[0], 2) + self.assertEqual(out["input_features"].shape[1], 128) + self.assertEqual(out["input_features"].shape[-1] % 100, 0) + per_sample_valid = out["attention_mask"].sum(-1).tolist() + self.assertEqual(per_sample_valid, [200, 550]) + + def test_mismatched_sampling_rate_raises(self): + fe = Qwen3ASRFeatureExtractor(sampling_rate=16_000) + audio = np.random.randn(16_000).astype(np.float32) + with self.assertRaises(ValueError): + fe(audio, sampling_rate=8_000, return_tensors="np") diff --git a/tests/models/qwen3_asr/test_modeling_qwen3_asr.py b/tests/models/qwen3_asr/test_modeling_qwen3_asr.py new file mode 100644 index 000000000000..987636eab9c9 --- /dev/null +++ b/tests/models/qwen3_asr/test_modeling_qwen3_asr.py @@ -0,0 +1,310 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +import unittest +from pathlib import Path + +from transformers import ( + AutoProcessor, + Qwen3ASRConfig, + Qwen3ASREncoderConfig, + Qwen3ASRForConditionalGeneration, + Qwen3ASRForTokenClassification, + Qwen3ASRModel, + Qwen3Config, + is_torch_available, +) +from transformers.testing_utils import ( + cleanup, + require_torch, + slow, + torch_device, +) + +from ...alm_tester import ALMModelTest, ALMModelTester + + +if is_torch_available(): + import torch + + +class Qwen3ASRModelTester(ALMModelTester): + config_class = Qwen3ASRConfig + conditional_generation_class = Qwen3ASRForConditionalGeneration + text_config_class = Qwen3Config + audio_config_class = Qwen3ASREncoderConfig + audio_mask_key = "input_features_mask" + + def __init__(self, parent, **kwargs): + kwargs.setdefault("num_mel_bins", 20) + kwargs.setdefault("feat_seq_length", 100) + kwargs.setdefault("d_model", 16) + kwargs.setdefault("hidden_size", 16) + kwargs.setdefault("encoder_layers", 1) + kwargs.setdefault("num_attention_heads", 2) + kwargs.setdefault("num_key_value_heads", 2) + kwargs.setdefault("encoder_ffn_dim", 16) + kwargs.setdefault("output_dim", 16) + kwargs.setdefault("downsample_hidden_size", 4) + kwargs.setdefault("head_dim", 8) + kwargs.setdefault("n_window", 50) + kwargs.setdefault("max_position_embeddings", 13) + super().__init__(parent, **kwargs) + + def create_audio_mask(self): + return torch.ones([self.batch_size, self.feat_seq_length], dtype=torch.long).to(torch_device) + + def get_audio_embeds_mask(self, audio_mask): + from transformers.models.qwen3_asr.modeling_qwen3_asr import _get_feat_extract_output_lengths + + input_lengths = audio_mask.sum(-1) + output_lengths = _get_feat_extract_output_lengths(input_lengths, n_window=self.n_window) + max_len = int(output_lengths.max().item()) + positions = torch.arange(max_len, device=audio_mask.device)[None, :] + return (positions < output_lengths[:, None]).long() + + +@require_torch +class 
Qwen3ASRForConditionalGenerationModelTest(ALMModelTest, unittest.TestCase): + model_tester_class = Qwen3ASRModelTester + all_model_classes = ( + (Qwen3ASRForConditionalGeneration, Qwen3ASRModel, Qwen3ASRForTokenClassification) + if is_torch_available() + else () + ) + pipeline_model_mapping = ( + { + "audio-text-to-text": Qwen3ASRForConditionalGeneration, + } + if is_torch_available() + else {} + ) + + # The audio encoder merges batch_size and output_lengths in dim 0 + skip_test_audio_features_output_shape = True + + def _audio_features_get_expected_num_attentions(self, model_tester=None): + return self.model_tester.encoder_layers + + def _audio_features_get_expected_num_hidden_states(self, model_tester=None): + return self.model_tester.encoder_layers + 1 + + @unittest.skip( + reason="Like other audio LMs (Audio Flamingo, Voxtral) inputs_embeds corresponding to audio tokens are replaced when input features are provided." + ) + def test_inputs_embeds_matches_input_ids(self): + pass + + +@require_torch +class Qwen3ASRForConditionalGenerationIntegrationTest(unittest.TestCase): + @classmethod + def setUp(cls): + cleanup(torch_device, gc_collect=True) + cls.checkpoint = "bezzam/Qwen3-ASR-0.6B-hf" + cls.processor = AutoProcessor.from_pretrained(cls.checkpoint) + cls.fixtures_path = Path(__file__).parent.parent.parent / "fixtures/qwen3_asr" + + def tearDown(self): + cleanup(torch_device, gc_collect=True) + + @slow + def test_fixture_single_matches(self): + """ + reproducer (creates JSON directly in repo): https://gist.github.com/ebezzam/3e0551708631784aeb684e0e838299f3#file-reproducer-py + """ + path = self.fixtures_path / "expected_results_single.json" + with open(path, "r", encoding="utf-8") as f: + raw = json.load(f) + exp_ids = torch.tensor(raw["token_ids"]) + exp_txt = raw["transcriptions"] + + conversation = [ + { + "role": "user", + "content": [ + { + "type": "audio", + "path": 
"https://huggingface.co/datasets/bezzam/audio_samples/resolve/main/librispeech_mr_quilter.wav", + }, + ], + } + ] + + model = Qwen3ASRForConditionalGeneration.from_pretrained( + self.checkpoint, device_map="auto", dtype=torch.bfloat16 + ).eval() + + batch = self.processor.apply_chat_template( + conversation, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt" + ).to(model.device, dtype=model.dtype) + seq = model.generate(**batch, max_new_tokens=32) + + inp_len = batch["input_ids"].shape[1] + gen_ids = seq[:, inp_len:] if seq.shape[1] >= inp_len else seq + torch.testing.assert_close(gen_ids.cpu(), exp_ids) + txt = self.processor.decode(seq, skip_special_tokens=True) + self.assertListEqual(txt, exp_txt) + + @slow + def test_fixture_batch_matches(self): + """ + reproducer (creates JSON directly in repo): https://gist.github.com/ebezzam/3e0551708631784aeb684e0e838299f3#file-reproducer-py + """ + path = self.fixtures_path / "expected_results_batched.json" + with open(path, "r", encoding="utf-8") as f: + raw = json.load(f) + exp_ids = torch.tensor(raw["token_ids"]) + exp_txt = raw["transcriptions"] + + conversation = [ + [ + { + "role": "user", + "content": [ + { + "type": "audio", + "path": "https://huggingface.co/datasets/bezzam/audio_samples/resolve/main/librispeech_mr_quilter.wav", + }, + ], + } + ], + [ + { + "role": "user", + "content": [ + { + "type": "audio", + "path": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-ASR-Repo/asr_zh.wav", + }, + ], + } + ], + ] + + model = Qwen3ASRForConditionalGeneration.from_pretrained( + self.checkpoint, device_map="auto", dtype=torch.bfloat16 + ).eval() + batch = self.processor.apply_chat_template( + conversation, + tokenize=True, + add_generation_prompt=True, + return_dict=True, + return_tensors="pt", + padding=True, + truncation=True, + ).to(model.device, dtype=model.dtype) + + seq = model.generate(**batch, max_new_tokens=32) + + inp_len = batch["input_ids"].shape[1] + gen_ids = seq[:, 
inp_len:] if seq.shape[1] >= inp_len else seq + torch.testing.assert_close(gen_ids.cpu(), exp_ids) + txt = self.processor.decode(seq, skip_special_tokens=True) + self.assertListEqual(txt, exp_txt) + + +@require_torch +class Qwen3ForcedAlignerIntegrationTest(unittest.TestCase): + """ + reproducer scripts (create JSON fixtures directly in repo): https://gist.github.com/ebezzam/3e0551708631784aeb684e0e838299f3#file-reproducer_timestamps-py + """ + + @classmethod + def setUp(cls): + cleanup(torch_device, gc_collect=True) + cls.aligner_checkpoint = "bezzam/Qwen3-ForcedAligner-0.6B-hf" + cls.aligner_processor = AutoProcessor.from_pretrained(cls.aligner_checkpoint) + cls.fixtures_path = Path(__file__).parent.parent.parent / "fixtures/qwen3_asr" + + def tearDown(self): + cleanup(torch_device, gc_collect=True) + + def _load_aligner(self): + return Qwen3ASRForTokenClassification.from_pretrained( + self.aligner_checkpoint, + device_map="auto", + torch_dtype=torch.bfloat16, + ).eval() + + def _run_alignment(self, model, audio, transcript, language): + """Run forced alignment and return list of timestamp dicts.""" + aligner_inputs, word_lists = self.aligner_processor.prepare_forced_aligner_inputs( + audio=audio, + transcript=transcript, + language=language, + ) + aligner_inputs = aligner_inputs.to(model.device, model.dtype) + + with torch.inference_mode(): + outputs = model(**aligner_inputs) + + return self.aligner_processor.decode_forced_alignment( + logits=outputs.logits, + input_ids=aligner_inputs["input_ids"], + word_lists=word_lists, + timestamp_token_id=model.config.timestamp_token_id, + ) + + @slow + def test_fixture_timestamps_single(self): + path = self.fixtures_path / "expected_timestamps_single.json" + with open(path, "r", encoding="utf-8") as f: + expected = json.load(f) + + model = self._load_aligner() + audio_url = "https://huggingface.co/datasets/bezzam/audio_samples/resolve/main/librispeech_mr_quilter.wav" + + timestamps = self._run_alignment( + model, + 
audio=audio_url, + transcript=expected["text"], + language=expected["language"], + )[0] + + self.assertEqual(len(timestamps), len(expected["time_stamps"])) + for pred, exp in zip(timestamps, expected["time_stamps"]): + self.assertAlmostEqual(pred["start_time"], exp["start_time"], places=2) + self.assertAlmostEqual(pred["end_time"], exp["end_time"], places=2) + + @slow + def test_fixture_timestamps_batched(self): + path = self.fixtures_path / "expected_timestamps_batched.json" + with open(path, "r", encoding="utf-8") as f: + expected_batch = json.load(f) + + model = self._load_aligner() + audio_urls = [ + "https://huggingface.co/datasets/bezzam/audio_samples/resolve/main/librispeech_mr_quilter.wav", + "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-ASR-Repo/asr_zh.wav", + ] + + batch_timestamps = self._run_alignment( + model, + audio=audio_urls, + transcript=[e["text"] for e in expected_batch], + language=[e["language"] for e in expected_batch], + ) + + self.assertEqual(len(batch_timestamps), len(expected_batch)) + for sample_idx, (pred_ts, exp) in enumerate(zip(batch_timestamps, expected_batch)): + self.assertEqual( + len(pred_ts), + len(exp["time_stamps"]), + f"Sample {sample_idx}: expected {len(exp['time_stamps'])} timestamps, got {len(pred_ts)}", + ) + for pred, exp_ts in zip(pred_ts, exp["time_stamps"]): + self.assertAlmostEqual(pred["start_time"], exp_ts["start_time"]) + self.assertAlmostEqual(pred["end_time"], exp_ts["end_time"]) diff --git a/tests/models/qwen3_asr/test_processor_qwen3_asr.py b/tests/models/qwen3_asr/test_processor_qwen3_asr.py new file mode 100644 index 000000000000..b149d3927214 --- /dev/null +++ b/tests/models/qwen3_asr/test_processor_qwen3_asr.py @@ -0,0 +1,180 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import shutil +import tempfile +import unittest + +from parameterized import parameterized + +from transformers import ( + AutoProcessor, + AutoTokenizer, + Qwen2TokenizerFast, + Qwen3ASRFeatureExtractor, +) +from transformers.models.qwen3_asr.processing_qwen3_asr import Qwen3ASRProcessor +from transformers.testing_utils import ( + require_torch, + require_torchaudio, +) + +from ...test_processing_common import ProcessorTesterMixin + + +class Qwen3ASRProcessorTest(ProcessorTesterMixin, unittest.TestCase): + processor_class = Qwen3ASRProcessor + + @classmethod + @require_torch + @require_torchaudio + def setUpClass(cls): + cls.checkpoint = "bezzam/Qwen3-ASR-0.6B-hf" + cls.tmpdirname = tempfile.mkdtemp() + processor = Qwen3ASRProcessor.from_pretrained(cls.checkpoint) + processor.save_pretrained(cls.tmpdirname) + + @require_torch + @require_torchaudio + def get_tokenizer(self, **kwargs): + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer + + @require_torch + @require_torchaudio + def get_feature_extractor(self, **kwargs): + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).feature_extractor + + @require_torch + @require_torchaudio + def get_processor(self, **kwargs): + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tmpdirname, ignore_errors=True) + + @require_torch + @require_torchaudio + def test_can_load_various_tokenizers(self): + processor = Qwen3ASRProcessor.from_pretrained(self.checkpoint) + tokenizer = 
AutoTokenizer.from_pretrained(self.checkpoint) + self.assertEqual(processor.tokenizer.__class__, tokenizer.__class__) + + @require_torch + @require_torchaudio + def test_save_load_pretrained_default(self): + tokenizer = AutoTokenizer.from_pretrained(self.checkpoint) + processor = Qwen3ASRProcessor.from_pretrained(self.checkpoint) + feature_extractor = processor.feature_extractor + + processor = Qwen3ASRProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + with tempfile.TemporaryDirectory() as tmpdir: + processor.save_pretrained(tmpdir) + reloaded = Qwen3ASRProcessor.from_pretrained(tmpdir) + + self.assertEqual(reloaded.tokenizer.get_vocab(), tokenizer.get_vocab()) + self.assertEqual(reloaded.feature_extractor.to_json_string(), feature_extractor.to_json_string()) + self.assertIsInstance(reloaded.feature_extractor, Qwen3ASRFeatureExtractor) + self.assertIsInstance(reloaded.tokenizer, Qwen2TokenizerFast) + + @require_torch + @require_torchaudio + def test_chat_template(self): + processor = AutoProcessor.from_pretrained(self.checkpoint) + expected_prompt = ( + "<|im_start|>system\n" + "<|im_end|>\n" + "<|im_start|>user\n" + "<|audio_start|><|audio_pad|><|audio_end|><|im_end|>\n" + "<|im_start|>assistant\n" + ) + messages = [ + { + "role": "user", + "content": [ + { + "type": "audio", + "path": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-ASR-Repo/asr_en.wav", + }, + ], + }, + ] + formatted_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + self.assertEqual(expected_prompt, formatted_prompt) + + @require_torch + @require_torchaudio + def test_apply_transcription_request_single(self): + processor = AutoProcessor.from_pretrained(self.checkpoint) + + audio_url = "https://huggingface.co/datasets/bezzam/audio_samples/resolve/main/librispeech_mr_quilter.wav" + helper_outputs = processor.apply_transcription_request(audio=audio_url) + + conversation = [ + { + "role": "user", + "content": [ + {"type": 
"audio", "path": audio_url}, + ], + } + ] + manual_outputs = processor.apply_chat_template( + conversation, + tokenize=True, + add_generation_prompt=True, + return_dict=True, + ) + + for key in ("input_ids", "attention_mask", "input_features", "input_features_mask"): + self.assertIn(key, helper_outputs) + self.assertTrue(helper_outputs[key].equal(manual_outputs[key])) + + @require_torch + @require_torchaudio + def test_apply_transcription_request_with_language(self): + processor = AutoProcessor.from_pretrained(self.checkpoint) + + audio_url = "https://huggingface.co/datasets/bezzam/audio_samples/resolve/main/librispeech_mr_quilter.wav" + outputs = processor.apply_transcription_request(audio=audio_url, language="English") + + for key in ("input_ids", "attention_mask", "input_features", "input_features_mask"): + self.assertIn(key, outputs) + + @require_torch + @require_torchaudio + def test_decode_formats(self): + processor = AutoProcessor.from_pretrained(self.checkpoint) + + raw_text = "language EnglishMr. Quilter is the apostle of the middle classes." + + # raw + self.assertEqual(raw_text, raw_text) + + # parsed + parsed = processor.parse_output(raw_text) + self.assertIsInstance(parsed, dict) + self.assertEqual(parsed["language"], "English") + self.assertEqual(parsed["transcription"], "Mr. Quilter is the apostle of the middle classes.") + + # transcription_only + transcription = processor.extract_transcription(raw_text) + self.assertEqual(transcription, "Mr. 
Quilter is the apostle of the middle classes.") + + @parameterized.expand([(1, "np"), (1, "pt"), (2, "np"), (2, "pt")]) + def test_apply_chat_template_audio(self, batch_size: int, return_tensors: str): + self.skipTest("Qwen3ASR processor requires audio; not compatible with text-only chat template tests.") + + def test_apply_chat_template_assistant_mask(self): + self.skipTest("Qwen3ASR processor requires audio; not compatible with text-only chat template tests.") diff --git a/tests/test_processing_common.py b/tests/test_processing_common.py index b410b80fec22..559f9d616e59 100644 --- a/tests/test_processing_common.py +++ b/tests/test_processing_common.py @@ -79,7 +79,6 @@ def prepare_image_inputs(): return image_inputs -# Copied from tests.models.whisper.test_feature_extraction_whisper.floats_list def floats_list(shape, scale=1.0, rng=None, name=None): """Creates a random float32 tensor""" if rng is None: diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index 98932d6c6680..62b640138ae4 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -266,6 +266,8 @@ "vision_feature_layer", "vision_feature_select_strategy", "vision_aspect_ratio", + # used by GenericForTokenClassification in modeling_layers.py via getattr + "token_classification_bias", )