diff --git a/src/openbench/cli/commands/evaluate.py b/src/openbench/cli/commands/evaluate.py index 3569da8..1b4f5f7 100644 --- a/src/openbench/cli/commands/evaluate.py +++ b/src/openbench/cli/commands/evaluate.py @@ -175,6 +175,7 @@ def run_alias_mode( wandb_run_name: str | None, wandb_tags: list[str] | None, use_keywords: bool | None, + force_language: bool, verbose: bool, ) -> BenchmarkResult: """Run evaluation using pipeline and dataset aliases.""" @@ -201,6 +202,12 @@ def run_alias_mode( if verbose: typer.echo(f"✅ Keywords: {'enabled' if use_keywords else 'disabled'} (override)") + # Handle force_language override + if force_language: + pipeline_config_override["force_language"] = force_language + if verbose: + typer.echo("✅ Force language: enabled") + pipeline = PipelineRegistry.create_pipeline(pipeline_name, config=pipeline_config_override) ######### Build Benchmark Config ######### @@ -333,6 +340,11 @@ def evaluate( "--use-keywords", help="Enable keyword boosting for compatible pipelines (overrides default config)", ), + force_language: bool = typer.Option( + False, + "--force-language", + help="Force language hinting for compatible pipelines", + ), verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable verbose output"), ) -> None: """Run evaluation benchmarks. 
@@ -393,6 +405,7 @@ def evaluate( wandb_run_name=wandb_run_name, wandb_tags=wandb_tags, use_keywords=use_keywords, + force_language=force_language, verbose=verbose, ) display_result(result) diff --git a/src/openbench/dataset/dataset_aliases.py b/src/openbench/dataset/dataset_aliases.py index 5d20a2a..0787913 100644 --- a/src/openbench/dataset/dataset_aliases.py +++ b/src/openbench/dataset/dataset_aliases.py @@ -344,7 +344,11 @@ def register_dataset_aliases() -> None: # Portuguese DatasetRegistry.register_alias( "common-voice-pt", - DatasetConfig(dataset_id="argmaxinc/common_voice_17_0-argmax_subset-400-openbench", split="test", subset="pt"), + DatasetConfig( + dataset_id="argmaxinc/common_voice_17_0-argmax_subset-400-openbench", + split="test", + subset="pt", + ), supported_pipeline_types={ PipelineType.TRANSCRIPTION, }, diff --git a/src/openbench/dataset/dataset_orchestration.py b/src/openbench/dataset/dataset_orchestration.py index a12d7d4..0ee3dda 100644 --- a/src/openbench/dataset/dataset_orchestration.py +++ b/src/openbench/dataset/dataset_orchestration.py @@ -1,8 +1,6 @@ # For licensing see accompanying LICENSE.md file. # Copyright (C) 2025 Argmax, Inc. All Rights Reserved. 
-from typing import Any - from datasets import Audio as HfAudio from pydantic import model_validator from typing_extensions import NotRequired, TypedDict @@ -11,6 +9,12 @@ from .dataset_base import BaseDataset, BaseSample +class OrchestrationExtraInfo(TypedDict, total=False): + """Extra info for orchestration samples.""" + + language: str + + class OrchestrationRow(TypedDict): """Expected row structure for orchestration datasets.""" @@ -19,9 +23,10 @@ class OrchestrationRow(TypedDict): word_speakers: list[str] word_timestamps_start: NotRequired[list[float]] word_timestamps_end: NotRequired[list[float]] + language: NotRequired[str] -class OrchestrationSample(BaseSample[Transcript, dict[str, Any]]): +class OrchestrationSample(BaseSample[Transcript, OrchestrationExtraInfo]): """Orchestration sample with speaker validation.""" @model_validator(mode="after") @@ -31,6 +36,11 @@ def validate_speaker_labels(self) -> "OrchestrationSample": raise ValueError("Orchestration samples require transcript with speaker labels") return self + @property + def language(self) -> str | None: + """Convenience property to access language from extra_info.""" + return self.extra_info.get("language") + class OrchestrationDataset(BaseDataset[OrchestrationSample]): """Dataset for orchestration pipelines.""" @@ -38,7 +48,7 @@ class OrchestrationDataset(BaseDataset[OrchestrationSample]): _expected_columns = ["audio", "transcript", "word_speakers"] _sample_class = OrchestrationSample - def prepare_sample(self, row: OrchestrationRow) -> tuple[Transcript, dict[str, Any]]: + def prepare_sample(self, row: OrchestrationRow) -> tuple[Transcript, OrchestrationExtraInfo]: """Prepare transcript with speaker labels and extra info from dataset row.""" transcript_words = row["transcript"] word_speakers = row["word_speakers"] @@ -52,5 +62,7 @@ def prepare_sample(self, row: OrchestrationRow) -> tuple[Transcript, dict[str, A end=row.get("word_timestamps_end"), speaker=word_speakers, ) - extra_info: dict[str, 
Any] = {} + extra_info: OrchestrationExtraInfo = {} + if "language" in row: + extra_info["language"] = row["language"] return reference, extra_info diff --git a/src/openbench/engine/deepgram_engine.py b/src/openbench/engine/deepgram_engine.py index da75286..1da0745 100644 --- a/src/openbench/engine/deepgram_engine.py +++ b/src/openbench/engine/deepgram_engine.py @@ -38,6 +38,9 @@ def __init__(self, options: PrerecordedOptions, timeout: Timeout = Timeout(300)) self.client = DeepgramClient(os.getenv("DEEPGRAM_API_KEY")) + def set_language(self, language: str) -> None: + self.options.language = language + # Only intended to be used with offiline transcription def transcribe(self, audio_path: Path | str, keyterm: str | None = None) -> DeepgramApiResponse: # Manually construct URL with keyterm parameter using + separator diff --git a/src/openbench/engine/openai_engine.py b/src/openbench/engine/openai_engine.py index 9f24ddd..258fa00 100644 --- a/src/openbench/engine/openai_engine.py +++ b/src/openbench/engine/openai_engine.py @@ -37,7 +37,7 @@ def get_transcription_kwargs(self) -> dict[str, Any]: } def transcribe( - self, audio_path: Path | str, prompt: str | None = None + self, audio_path: Path | str, prompt: str | None = None, language: str | None = None ) -> TranscriptionVerbose | TranscriptionDiarized: if isinstance(audio_path, str): audio_path = Path(audio_path) @@ -50,6 +50,9 @@ def transcribe( if prompt is not None: kwargs["prompt"] = prompt + if language is not None: + kwargs["language"] = language + response = self.client.audio.transcriptions.create(**kwargs) return response diff --git a/src/openbench/engine/whisperkitpro_engine.py b/src/openbench/engine/whisperkitpro_engine.py index bd56fcd..8a855e9 100644 --- a/src/openbench/engine/whisperkitpro_engine.py +++ b/src/openbench/engine/whisperkitpro_engine.py @@ -227,7 +227,6 @@ def download_and_prepare_model(self) -> Path: raise RuntimeError(f"Model download succeeded but path doesn't exist: {model_path}") 
return model_path - except Exception as e: raise RuntimeError(f"Failed to download model from {self.repo_id}: {e}") from e @@ -238,6 +237,7 @@ class WhisperKitProInput(BaseModel): audio_path: Path keep_audio: bool = False custom_vocabulary_path: str | None = Field(None, description="Optional path to custom vocabulary file") + language: str | None = Field(None, description="Optional language hint for transcription") class WhisperKitProOutput(BaseModel): @@ -295,6 +295,10 @@ def __call__(self, input: WhisperKitProInput) -> WhisperKitProOutput: if input.custom_vocabulary_path: cmd.extend(["--custom-vocabulary-path", input.custom_vocabulary_path]) + # Add language hint if provided + if input.language: + cmd.extend(["--language", input.language]) + if "WHISPERKITPRO_API_KEY" in os.environ: cmd.extend(["--api-key", os.environ["WHISPERKITPRO_API_KEY"]]) else: diff --git a/src/openbench/pipeline/orchestration/common.py b/src/openbench/pipeline/orchestration/common.py index 5d94b7e..5e2820d 100644 --- a/src/openbench/pipeline/orchestration/common.py +++ b/src/openbench/pipeline/orchestration/common.py @@ -23,6 +23,15 @@ logger = get_logger(__name__) +class OrchestrationConfig(PipelineConfig): + """Base configuration for orchestration pipelines.""" + + force_language: bool = Field( + False, + description="Force the language of the audio files i.e. 
hint the model to use the correct language.", + ) + + class OrchestrationOutput(PipelineOutput[Transcript]): transcription_output: TranscriptionOutput | None = Field( default=None, diff --git a/src/openbench/pipeline/orchestration/nemo/orchestration_mt_parakeet.py b/src/openbench/pipeline/orchestration/nemo/orchestration_mt_parakeet.py index d352ae6..e656fc9 100644 --- a/src/openbench/pipeline/orchestration/nemo/orchestration_mt_parakeet.py +++ b/src/openbench/pipeline/orchestration/nemo/orchestration_mt_parakeet.py @@ -5,7 +5,7 @@ from typing import Callable import torch -from argmaxtools.utils import get_fastest_device +from argmaxtools.utils import get_fastest_device, get_logger from nemo.collections.asr.models import ASRModel, SortformerEncLabelModel # Use the helper class `SpeakerTaggedASR`, which handles all ASR and diarization cache data for streaming. @@ -16,22 +16,25 @@ from ....dataset import OrchestrationSample from ....pipeline_prediction import Transcript, Word -from ...base import Pipeline, PipelineConfig, PipelineType, register_pipeline +from ...base import Pipeline, PipelineType, register_pipeline from ...diarization.nemo.sortformer_pipeline import NeMoSortformerPipelineInput -from ..common import OrchestrationOutput +from ..common import OrchestrationConfig, OrchestrationOutput # Use the pre-defined dataclass template `MultitalkerTranscriptionConfig` from `multitalker_transcript_config.py`. 
# Configure the diarization model using streaming parameters: from .multitalker_transcript_config import MultitalkerTranscriptionConfig +logger = get_logger(__name__) + + # Constants TEMP_AUDIO_DIR = Path("./temp_audio") __all__ = ["NeMoMTParakeetPipeline", "NeMoMTParakeetPipelineConfig"] -class NeMoMTParakeetPipelineConfig(PipelineConfig): +class NeMoMTParakeetPipelineConfig(OrchestrationConfig): diar_model_id: str = Field( default="nvidia/diar_streaming_sortformer_4spk-v2.1", description="The ID of the diarization model to use.", @@ -97,6 +100,14 @@ def inference(sample: NeMoSortformerPipelineInput) -> list[dict[str, str]]: def parse_input(self, input_sample: OrchestrationSample) -> OrchestrationSample: assert input_sample.sample_rate == 16000, "Sample rate must be 16kHz" + + # Warn if force_language is enabled (not currently supported) + if self.config.force_language: + logger.warning( + f"{self.__class__.__name__} does not support language hinting. " + "The force_language flag will be ignored." 
+ ) + parsed_input = NeMoSortformerPipelineInput( audio_path=input_sample.save_audio(TEMP_AUDIO_DIR), keep_audio=False, diff --git a/src/openbench/pipeline/orchestration/orchestration_deepgram.py b/src/openbench/pipeline/orchestration/orchestration_deepgram.py index 2ac5518..3a95404 100644 --- a/src/openbench/pipeline/orchestration/orchestration_deepgram.py +++ b/src/openbench/pipeline/orchestration/orchestration_deepgram.py @@ -1,21 +1,24 @@ from pathlib import Path from typing import Callable +from argmaxtools.utils import get_logger from deepgram import PrerecordedOptions from pydantic import Field from ...dataset import DiarizationSample from ...engine import DeepgramApi, DeepgramApiResponse -from ...pipeline import Pipeline, PipelineConfig, register_pipeline +from ...pipeline import Pipeline, register_pipeline from ...pipeline_prediction import Transcript from ...types import PipelineType -from .common import OrchestrationOutput +from .common import OrchestrationConfig, OrchestrationOutput +logger = get_logger(__name__) + TEMP_AUDIO_DIR = Path("temp_audio_dir") -class DeepgramOrchestrationPipelineConfig(PipelineConfig): +class DeepgramOrchestrationPipelineConfig(OrchestrationConfig): model_version: str = Field( default="nova-3", description="The version of the Deepgram model to use", @@ -28,14 +31,22 @@ class DeepgramOrchestrationPipeline(Pipeline): pipeline_type = PipelineType.ORCHESTRATION def build_pipeline(self) -> Callable[[Path], DeepgramApiResponse]: - deepgram_api = DeepgramApi( + # Create base API; auto language detection is disabled when force_language is set + base_api = DeepgramApi( options=PrerecordedOptions( - model=self.config.model_version, smart_format=True, diarize=True, detect_language=True + model=self.config.model_version, + smart_format=True, + diarize=True, + detect_language=not self.config.force_language, ) ) def transcribe(audio_path: Path) -> DeepgramApiResponse: - response = deepgram_api.transcribe(audio_path) + # Use language-specific API if language is set, 
otherwise use base API + if self.current_language: + base_api.set_language(self.current_language) + + response = base_api.transcribe(audio_path) # Remove temporary audio path audio_path.unlink(missing_ok=True) return response @@ -43,6 +54,11 @@ def transcribe(audio_path: Path) -> DeepgramApiResponse: return transcribe def parse_input(self, input_sample: DiarizationSample) -> Path: + # Extract language if force_language is enabled + self.current_language = None + if self.config.force_language: + self.current_language = input_sample.language + return input_sample.save_audio(TEMP_AUDIO_DIR) def parse_output(self, output: DeepgramApiResponse) -> OrchestrationOutput: diff --git a/src/openbench/pipeline/orchestration/orchestration_openai.py b/src/openbench/pipeline/orchestration/orchestration_openai.py index 9a3fb32..867a67d 100644 --- a/src/openbench/pipeline/orchestration/orchestration_openai.py +++ b/src/openbench/pipeline/orchestration/orchestration_openai.py @@ -1,21 +1,24 @@ from pathlib import Path from typing import Callable, Literal +from argmaxtools.utils import get_logger from openai.types.audio import TranscriptionDiarized from pydantic import Field from ...dataset import OrchestrationSample from ...engine import OpenAIApi -from ...pipeline import Pipeline, PipelineConfig, register_pipeline +from ...pipeline import Pipeline, register_pipeline from ...pipeline_prediction import Transcript, Word from ...types import PipelineType -from .common import OrchestrationOutput +from .common import OrchestrationConfig, OrchestrationOutput +logger = get_logger(__name__) + TEMP_AUDIO_DIR = Path("temp_audio_dir") -class OpenAIOrchestrationPipelineConfig(PipelineConfig): +class OpenAIOrchestrationPipelineConfig(OrchestrationConfig): model_version: Literal["gpt-4o-transcribe-diarize"] = Field( default="gpt-4o-transcribe-diarize", description="The version of the OpenAI model to use. 
Currently only `gpt-4o-transcribe-diarize` is supported.", @@ -31,7 +34,7 @@ def build_pipeline(self) -> Callable[[Path], TranscriptionDiarized]: openai_api = OpenAIApi(model=self.config.model_version) def orchestrate(audio_path: Path) -> TranscriptionDiarized: - response = openai_api.transcribe(audio_path) + response = openai_api.transcribe(audio_path, language=self.current_language) # Remove temporary audio path audio_path.unlink(missing_ok=True) return response @@ -39,6 +42,11 @@ def orchestrate(audio_path: Path) -> TranscriptionDiarized: return orchestrate def parse_input(self, input_sample: OrchestrationSample) -> Path: + # Extract language if force_language is enabled + self.current_language = None + if self.config.force_language: + self.current_language = input_sample.language + return input_sample.save_audio(TEMP_AUDIO_DIR) def parse_output(self, output: TranscriptionDiarized) -> OrchestrationOutput: diff --git a/src/openbench/pipeline/orchestration/orchestration_whisperkitpro.py b/src/openbench/pipeline/orchestration/orchestration_whisperkitpro.py index 52fa7b1..577b229 100644 --- a/src/openbench/pipeline/orchestration/orchestration_whisperkitpro.py +++ b/src/openbench/pipeline/orchestration/orchestration_whisperkitpro.py @@ -11,8 +11,8 @@ from ...dataset import OrchestrationSample from ...engine import WhisperKitPro, WhisperKitProConfig, WhisperKitProInput, WhisperKitProOutput from ...pipeline_prediction import Transcript, Word -from ..base import Pipeline, PipelineConfig, PipelineType, register_pipeline -from .common import OrchestrationOutput +from ..base import Pipeline, PipelineType, register_pipeline +from .common import OrchestrationConfig, OrchestrationOutput logger = get_logger(__name__) @@ -20,7 +20,7 @@ TEMP_AUDIO_DIR = Path("./temp_audio") -class WhisperKitProOrchestrationConfig(PipelineConfig): +class WhisperKitProOrchestrationConfig(OrchestrationConfig): cli_path: str = Field( ..., description="The path to the WhisperKitPro CLI", @@ -93,9 
+93,15 @@ def build_pipeline(self) -> WhisperKitPro: return engine def parse_input(self, input_sample: OrchestrationSample) -> WhisperKitProInput: + # Extract language if force_language is enabled + language = None + if self.config.force_language: + language = input_sample.language + return WhisperKitProInput( audio_path=input_sample.save_audio(TEMP_AUDIO_DIR), keep_audio=False, + language=language, ) def parse_output(self, output: WhisperKitProOutput) -> OrchestrationOutput: diff --git a/src/openbench/pipeline/orchestration/whisperx.py b/src/openbench/pipeline/orchestration/whisperx.py index b73513d..498ca34 100644 --- a/src/openbench/pipeline/orchestration/whisperx.py +++ b/src/openbench/pipeline/orchestration/whisperx.py @@ -11,8 +11,8 @@ from ...dataset import DiarizationSample from ...pipeline_prediction import Transcript -from ..base import Pipeline, PipelineConfig, PipelineType, register_pipeline -from .common import OrchestrationOutput +from ..base import Pipeline, PipelineType, register_pipeline +from .common import OrchestrationConfig, OrchestrationOutput logger = get_logger(__name__) @@ -20,7 +20,7 @@ TEMP_AUDIO_DIR = Path("audio_temp") -class WhisperXPipelineConfig(PipelineConfig): +class WhisperXPipelineConfig(OrchestrationConfig): model_name: str = Field( default="tiny", description="The name of the Whisper model to use", @@ -43,11 +43,20 @@ class WhisperXPipelineConfig(PipelineConfig): ) +class WhisperXInput: + """Input for WhisperX CLI.""" + + def __init__(self, audio_path: Path, language: str | None = None): + self.audio_path = audio_path + self.language = language + + class WhisperX: def __init__(self, config: WhisperXPipelineConfig): self.config = config - def __call__(self, audio_path: Path | str) -> pd.DataFrame: + def __call__(self, input: WhisperXInput) -> pd.DataFrame: + audio_path = input.audio_path if isinstance(audio_path, str): audio_path = Path(audio_path) @@ -72,6 +81,10 @@ def __call__(self, audio_path: Path | str) -> pd.DataFrame: 
"--diarize", ] + # Add language if provided + if input.language: + args.extend(["--language", input.language]) + # Run whisperx CLI subprocess.run(args) @@ -127,11 +140,19 @@ class WhisperXPipeline(Pipeline): _config_class = WhisperXPipelineConfig pipeline_type = PipelineType.ORCHESTRATION - def build_pipeline(self) -> Callable[[Path], pd.DataFrame]: + def build_pipeline(self) -> Callable[[WhisperXInput], pd.DataFrame]: return WhisperX(self.config) - def parse_input(self, input_sample: DiarizationSample) -> Path: - return input_sample.save_audio(TEMP_AUDIO_DIR) + def parse_input(self, input_sample: DiarizationSample) -> WhisperXInput: + # Extract language if force_language is enabled + language = None + if self.config.force_language: + language = input_sample.language + + return WhisperXInput( + audio_path=input_sample.save_audio(TEMP_AUDIO_DIR), + language=language, + ) def parse_output(self, output: pd.DataFrame) -> OrchestrationOutput: output = output.assign(words=lambda df: df["text"].str.split()).explode("words") diff --git a/src/openbench/pipeline/transcription/apple_speech_analyzer.py b/src/openbench/pipeline/transcription/apple_speech_analyzer.py index 7153cc7..992658b 100644 --- a/src/openbench/pipeline/transcription/apple_speech_analyzer.py +++ b/src/openbench/pipeline/transcription/apple_speech_analyzer.py @@ -122,9 +122,15 @@ def build_pipeline(self) -> Callable[[SpeechAnalyzerCliInput], Path]: return engine.transcribe def parse_input(self, input_sample: TranscriptionSample) -> SpeechAnalyzerCliInput: + # Extract language if force_language is enabled + language = None + if self.config.force_language: + language = input_sample.language + return SpeechAnalyzerCliInput( audio_path=input_sample.save_audio(TEMP_AUDIO_DIR), keep_audio=False, + language=language, ) def parse_output(self, output: Path) -> TranscriptionOutput: diff --git a/src/openbench/pipeline/transcription/transcription_assemblyai.py 
b/src/openbench/pipeline/transcription/transcription_assemblyai.py index 9c7a07b..bed1612 100644 --- a/src/openbench/pipeline/transcription/transcription_assemblyai.py +++ b/src/openbench/pipeline/transcription/transcription_assemblyai.py @@ -99,6 +99,13 @@ def parse_input(self, input_sample: TranscriptionSample) -> Path: if keywords: self.current_keywords = keywords + # Warn if force_language is enabled (not currently supported) + if self.config.force_language: + logger.warning( + f"{self.__class__.__name__} does not support language hinting. " + "The force_language flag will be ignored." + ) + return input_sample.save_audio(TEMP_AUDIO_DIR) def parse_output(self, output: str) -> TranscriptionOutput: diff --git a/src/openbench/pipeline/transcription/transcription_deepgram.py b/src/openbench/pipeline/transcription/transcription_deepgram.py index 02dc146..75973fa 100644 --- a/src/openbench/pipeline/transcription/transcription_deepgram.py +++ b/src/openbench/pipeline/transcription/transcription_deepgram.py @@ -27,10 +27,19 @@ class DeepgramTranscriptionPipeline(Pipeline): pipeline_type = PipelineType.TRANSCRIPTION def build_pipeline(self) -> Callable[[Path], DeepgramApiResponse]: - deepgram_api = DeepgramApi(options=PrerecordedOptions(model=self.config.model_version, smart_format=True)) + # Create base API; auto language detection stays on unless force_language is set + base_api = DeepgramApi( + options=PrerecordedOptions( + model=self.config.model_version, smart_format=True, detect_language=not self.config.force_language + ) + ) def transcribe(audio_path: Path) -> DeepgramApiResponse: - response = deepgram_api.transcribe(audio_path, keyterm=self.current_keywords) + # Apply the sample's language hint to the shared base_api options; it stays set until the next hint overwrites it + if self.current_language: + base_api.set_language(self.current_language) + + response = base_api.transcribe(audio_path, keyterm=self.current_keywords) # Remove temporary audio path audio_path.unlink(missing_ok=True) return response @@ -38,7 +47,7 @@ def 
transcribe(audio_path: Path) -> DeepgramApiResponse: return transcribe def parse_input(self, input_sample) -> Path: - """Override to extract keywords from sample before processing.""" + """Override to extract keywords and language from sample before processing.""" self.current_keywords = None if self.config.use_keywords: keywords = input_sample.extra_info.get("dictionary", []) @@ -46,6 +55,11 @@ def parse_input(self, input_sample) -> Path: # Add + between keywords for Deepgram URL self.current_keywords = "+".join(keywords) + # Extract language if force_language is enabled + self.current_language = None + if self.config.force_language: + self.current_language = input_sample.language + return input_sample.save_audio(TEMP_AUDIO_DIR) def parse_output(self, output: DeepgramApiResponse) -> TranscriptionOutput: diff --git a/src/openbench/pipeline/transcription/transcription_nemo.py b/src/openbench/pipeline/transcription/transcription_nemo.py index 7556946..68a1205 100644 --- a/src/openbench/pipeline/transcription/transcription_nemo.py +++ b/src/openbench/pipeline/transcription/transcription_nemo.py @@ -219,6 +219,13 @@ def parse_input(self, input_sample) -> Path: self.context_graph = context_biasing.ContextGraphCTC(blank_id=self.blank_idx) self.context_graph.add_to_graph(context_transcripts) + # Warn if force_language is enabled (not currently supported) + if self.config.force_language: + logger.warning( + f"{self.__class__.__name__} does not support language hinting. " + "The force_language flag will be ignored." 
+ ) + return input_sample.save_audio(TEMP_AUDIO_DIR) def parse_output(self, output: TranscriptionOutput) -> TranscriptionOutput: diff --git a/src/openbench/pipeline/transcription/transcription_openai.py b/src/openbench/pipeline/transcription/transcription_openai.py index f12de12..1c6d6aa 100644 --- a/src/openbench/pipeline/transcription/transcription_openai.py +++ b/src/openbench/pipeline/transcription/transcription_openai.py @@ -33,6 +33,7 @@ def transcribe(audio_path: Path) -> TranscriptionVerbose: response = openai_api.transcribe( audio_path, prompt=self.current_keywords_prompt, + language=self.current_language, ) # Remove temporary audio path audio_path.unlink(missing_ok=True) @@ -41,7 +42,7 @@ def transcribe(audio_path: Path) -> TranscriptionVerbose: return transcribe def parse_input(self, input_sample) -> Path: - """Override to extract keywords from sample before processing.""" + """Override to extract keywords and language from sample before processing.""" # Extract keywords from sample's extra_info if flag is enabled self.current_keywords_prompt = None if self.config.use_keywords: @@ -50,6 +51,11 @@ def parse_input(self, input_sample) -> Path: # Format keywords as comma-separated prompt for OpenAI self.current_keywords_prompt = ", ".join(keywords) + # Extract language if force_language is enabled + self.current_language = None + if self.config.force_language: + self.current_language = input_sample.language + return input_sample.save_audio(TEMP_AUDIO_DIR) def parse_output(self, output: TranscriptionVerbose) -> TranscriptionOutput: diff --git a/src/openbench/pipeline/transcription/transcription_oss_whisper.py b/src/openbench/pipeline/transcription/transcription_oss_whisper.py index 8bda785..622f549 100644 --- a/src/openbench/pipeline/transcription/transcription_oss_whisper.py +++ b/src/openbench/pipeline/transcription/transcription_oss_whisper.py @@ -157,7 +157,7 @@ def parse_input(self, input_sample: TranscriptionSample) -> Path: # Extract language if 
force_language is enabled self.current_language = None if self.config.force_language: - self.current_language = input_sample.extra_info.get("language", None) + self.current_language = input_sample.language return input_sample.save_audio(TEMP_AUDIO_DIR) diff --git a/src/openbench/pipeline/transcription/transcription_whisperkitpro.py b/src/openbench/pipeline/transcription/transcription_whisperkitpro.py index f57213e..bc0f88a 100644 --- a/src/openbench/pipeline/transcription/transcription_whisperkitpro.py +++ b/src/openbench/pipeline/transcription/transcription_whisperkitpro.py @@ -110,7 +110,7 @@ def build_pipeline(self) -> WhisperKitPro: return engine def parse_input(self, input_sample: TranscriptionSample) -> WhisperKitProInput: - """Override to extract keywords from sample before processing.""" + """Override to extract keywords and language from sample before processing.""" # Extract keywords from sample's extra_info if flag is enabled custom_vocab_path = None if self.config.use_keywords: @@ -129,10 +129,16 @@ def parse_input(self, input_sample: TranscriptionSample) -> WhisperKitProInput: custom_vocab_path = str(vocab_file) logger.debug(f"Created custom vocabulary file: {custom_vocab_path} with {len(keywords)} keywords") + # Extract language if force_language is enabled + language = None + if self.config.force_language: + language = input_sample.language + return WhisperKitProInput( audio_path=input_sample.save_audio(TEMP_AUDIO_DIR), keep_audio=False, custom_vocabulary_path=custom_vocab_path, + language=language, ) def parse_output(self, output: WhisperKitProOutput) -> TranscriptionOutput: diff --git a/src/openbench/pipeline/transcription/whisperkit.py b/src/openbench/pipeline/transcription/whisperkit.py index 789ab7d..2e0f444 100644 --- a/src/openbench/pipeline/transcription/whisperkit.py +++ b/src/openbench/pipeline/transcription/whisperkit.py @@ -257,9 +257,15 @@ def build_pipeline(self) -> Callable[[TranscriptionCliInput], TranscriptionCliOu return 
engine.transcribe def parse_input(self, input_sample: TranscriptionSample) -> TranscriptionCliInput: + # Extract language if force_language is enabled + language = None + if self.config.force_language: + language = input_sample.language + return TranscriptionCliInput( audio_path=input_sample.save_audio(TEMP_AUDIO_DIR), keep_audio=False, + language=language, ) def parse_output(self, output: TranscriptionCliOutput) -> TranscriptionOutput: