Skip to content

[Draft] Add Llasa TTS family of models#39760

Draft
ebezzam wants to merge 13 commits into
huggingface:mainfrom
ebezzam:add_llasa
Draft

[Draft] Add Llasa TTS family of models#39760
ebezzam wants to merge 13 commits into
huggingface:mainfrom
ebezzam:add_llasa

Conversation

@ebezzam

@ebezzam ebezzam commented Jul 29, 2025

Copy link
Copy Markdown
Contributor

What does this PR do?

This PR adds the Llasa TTS family of models:

Reproducers for integration tests: https://gist.github.com/ebezzam/1863ec8eb7ec4afff02c26bdcb7691f9

TODO

  • Batch integration tests
  • Tokenizer and processing tests like Dia?
  • Create public model cards (update text and add relevant tags and labels). Atm under my account (1B, 3B, 8B).
  • Integrate with XCodec2 (Transformer version) when Add xcodec2 model #37868 merged

Example usage

Below is example usage with my Hub checkpoint (compared to that of original authors)

"""
pip install torchao xcodec2==0.1.3
"""

import torch
from transformers import LlasaTokenizer, LlasaForCausalLM, LlasaProcessor
import soundfile as sf
from xcodec2.modeling_xcodec2 import XCodec2Model

model_repo = "bezzam/Llasa-1B"
# model_repo = "bezzam/Llasa-3B"
# model_repo = "bezzam/Llasa-8B"
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

# load processor (tokenizer + audio codec)
processor = LlasaProcessor(
    LlasaTokenizer.from_pretrained(model_repo),
    XCodec2Model.from_pretrained("HKUSTAudio/xcodec2").eval().to(torch_device)
)
# # -- use below when `XCodec2Model` integrated into `transformers`
# processor = LlasaProcessor.from_pretrained(model_repo)

# load model
model = LlasaForCausalLM.from_pretrained(model_repo)
model.eval().to(torch_device)

# TTS, some text inputs don't work which shows limitations of this approach
input_text = "How much wood would a woodchuck chuck if a woodchuck could chuck speech tokens?"
with torch.no_grad():

    # Tokenize the text
    encoded_text = processor(input_text).to(torch_device)

    # Generate the speech autoregressively
    outputs = model.generate(
        encoded_text["input_ids"],
        do_sample=False,
        max_length=600,    # generates up to ~10s. Max allowed length is 2048, as Llasa was trained with max length 2048
        top_p=1,           # Adjusts the diversity of generated content
        temperature=0.8,   # Controls randomness in output
    )

# decode to audio
gen_wav = processor.decode(outputs, input_offset=encoded_text["input_offset"])
fn = f"gen_{model_repo.split('/')[-1]}.wav"
sf.write(fn, gen_wav.cpu().numpy(), model.config.sampling_rate)
print(f"Generated speech saved to {fn}")

@ebezzam ebezzam marked this pull request as draft July 29, 2025 14:42
model_config.max_length = config.original_model.model_max_length
model = LlasaForCausalLM(model_config)
if config.remote_repo.dtype == "bfloat16":
model.to(torch.bfloat16)

@ebezzam ebezzam Jul 29, 2025

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is bf16 fine? Original models are trained in bf16 (config) and their Hub checkpoints are also in bf16 (e.g., 1B)

Comment thread src/transformers/models/llasa/convert_llasa_to_hf.py Outdated
Comment on lines +68 to +74
def from_pretrained_llm(cls, *args, **kwargs):
"""
Load the tokenizer from a pre-trained LLM model, and add relevant speech and Llasa tokens.
"""
tokenizer = super().from_pretrained(*args, **kwargs)
tokenizer.add_tokens(list(tokenizer.llasa_token.values()) + tokenizer.speech_tokens)
return tokenizer

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is something like this fine? (also for LlasaConfig)

The difference with conventional from_pretrained is that this one increases the vocab size according to the (speech and llasa tokens). These methods are useful for the conversion script to copy the tokenizer and config from Llama (an LLM).

But when using Llasa, from_pretrained will be used as usual, loading from actual Llasa tokeniers and configs that don't need explicit adding of tokens.

@HuggingFaceDocBuilderDev

Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment thread src/transformers/models/llasa/modular_llasa.py
@ebezzam ebezzam requested a review from eustlb July 29, 2025 15:06
@Rocketknight1

Copy link
Copy Markdown
Member

cc @eustlb for TTS

Comment on lines +388 to +414
# TODO: how to overwrite generate method?
# Not necessary but could be nice to check max_length <= 2048 (what model was trained on)
# I get the following error (I think because `generate` isn't method of LlamaForCausalLM but its parent):
# ```
# File "/home/eric_bezzam/transformers/utils/modular_model_converter.py", line 355, in replace_super_calls
# original_modeling_method_body = self.original_modeling_methods[func_name].body.body
# KeyError: 'generate'
# ```
# """
# @torch.no_grad()
# def generate(
# inputs,
# max_length=2048,
# **kwargs,
# ):
# """
# Set specific parameters from Llasa processor output
# """
# if max_length > 2048:
# raise ValueError("Max length should be less than or equal to 2048.")

# # Call the parent class's generate method
# return super().generate(
# inputs,
# max_length=inputs["max_length"],
# **kwargs
# )

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was trying to overwrite the generate method for adding a check that users don't request outputs larger than model.generation_config.max_length (=2048), which is the max length the models were trained on. But maybe there's another way to restrict output sizes that users request? Or maybe it's not needed to add such a check?

I was getting the error (below) when running the modular script. Maybe the issue is that generate is not a method of LlamaForCausalLM but of its parent class GenerationMixin?

# command: python utils/modular_model_converter.py --files-to-parse src/transformers/models/llasa/modular_llasa.py
Traceback (most recent call last):
  File "/home/eric_bezzam/transformers/utils/modular_model_converter.py", line 1779, in <module>
    converted_files = convert_modular_file(file_name)
  File "/home/eric_bezzam/transformers/utils/modular_model_converter.py", line 1693, in convert_modular_file
    for file, module in create_modules(cst_transformers).items():
  File "/home/eric_bezzam/transformers/utils/modular_model_converter.py", line 1634, in create_modules
    nodes_to_add, file_type, new_imports = get_class_node_and_dependencies(modular_mapper, class_name, node, files)
  File "/home/eric_bezzam/transformers/utils/modular_model_converter.py", line 1577, in get_class_node_and_dependencies
    updated_node = replace_class_node(mapper, node, renamed_super_class, super_class)
  File "/home/eric_bezzam/transformers/utils/modular_model_converter.py", line 1064, in replace_class_node
    new_replacement_class = new_module.visit(
  File "/home/eric_bezzam/transformers/py310/lib/python3.10/site-packages/libcst/metadata/wrapper.py", line 204, in visit
    return self.module.visit(visitor)
  File "/home/eric_bezzam/transformers/py310/lib/python3.10/site-packages/libcst/_nodes/module.py", line 89, in visit
    result = super(Module, self).visit(visitor)
  File "/home/eric_bezzam/transformers/py310/lib/python3.10/site-packages/libcst/_nodes/base.py", line 228, in visit
    _CSTNodeSelfT, self._visit_and_replace_children(visitor)
  File "/home/eric_bezzam/transformers/py310/lib/python3.10/site-packages/libcst/_nodes/module.py", line 74, in _visit_and_replace_children
    body=visit_body_sequence(self, "body", self.body, visitor),
  File "/home/eric_bezzam/transformers/py310/lib/python3.10/site-packages/libcst/_nodes/internal.py", line 227, in visit_body_sequence
    return tuple(visit_body_iterable(parent, fieldname, children, visitor))
  File "/home/eric_bezzam/transformers/py310/lib/python3.10/site-packages/libcst/_nodes/internal.py", line 193, in visit_body_iterable
    new_child = child.visit(visitor)
  File "/home/eric_bezzam/transformers/py310/lib/python3.10/site-packages/libcst/_nodes/base.py", line 228, in visit
    _CSTNodeSelfT, self._visit_and_replace_children(visitor)
  File "/home/eric_bezzam/transformers/py310/lib/python3.10/site-packages/libcst/_nodes/statement.py", line 1989, in _visit_and_replace_children
    body=visit_required(self, "body", self.body, visitor),
  File "/home/eric_bezzam/transformers/py310/lib/python3.10/site-packages/libcst/_nodes/internal.py", line 81, in visit_required
    result = node.visit(visitor)
  File "/home/eric_bezzam/transformers/py310/lib/python3.10/site-packages/libcst/_nodes/base.py", line 228, in visit
    _CSTNodeSelfT, self._visit_and_replace_children(visitor)
  File "/home/eric_bezzam/transformers/py310/lib/python3.10/site-packages/libcst/_nodes/statement.py", line 704, in _visit_and_replace_children
    body=visit_body_sequence(self, "body", self.body, visitor),
  File "/home/eric_bezzam/transformers/py310/lib/python3.10/site-packages/libcst/_nodes/internal.py", line 227, in visit_body_sequence
    return tuple(visit_body_iterable(parent, fieldname, children, visitor))
  File "/home/eric_bezzam/transformers/py310/lib/python3.10/site-packages/libcst/_nodes/internal.py", line 193, in visit_body_iterable
    new_child = child.visit(visitor)
  File "/home/eric_bezzam/transformers/py310/lib/python3.10/site-packages/libcst/_nodes/base.py", line 237, in visit
    leave_result = visitor.on_leave(self, with_updated_children)
  File "/home/eric_bezzam/transformers/py310/lib/python3.10/site-packages/libcst/_visitors.py", line 71, in on_leave
    updated_node = leave_func(original_node, updated_node)
  File "/home/eric_bezzam/transformers/utils/modular_model_converter.py", line 369, in leave_FunctionDef
    new_body = self.replace_super_calls(updated_node.body, name)
  File "/home/eric_bezzam/transformers/utils/modular_model_converter.py", line 355, in replace_super_calls
    original_modeling_method_body = self.original_modeling_methods[func_name].body.body
KeyError: 'generate'

Comment on lines +38 to +39
# TODO use "audio_tokenizer_class" when merged https://github.com/huggingface/transformers/pull/37868
# audio_tokenizer_class = "XCodec2Model"

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Several TODOs like this to switch to XCodec2 model from Transformers when #37868 is merged

@github-actions

Copy link
Copy Markdown
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, csm, llasa

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants