From 356264fa9f566e8edecbcbd4f7a7952e5a161c9d Mon Sep 17 00:00:00 2001 From: yonishelach Date: Wed, 8 Mar 2023 11:01:28 +0200 Subject: [PATCH 01/11] [Build] Fix html links, Add .html as source in documentation --- cli/marketplace/build.py | 116 ++++++++++++++++++++++++++------------- 1 file changed, 78 insertions(+), 38 deletions(-) diff --git a/cli/marketplace/build.py b/cli/marketplace/build.py index 1a6c56b76..bc6d91293 100644 --- a/cli/marketplace/build.py +++ b/cli/marketplace/build.py @@ -18,7 +18,7 @@ import subprocess import uuid from pathlib import Path -from typing import Union, Optional, Set, Dict, List +from typing import Dict, List, Optional, Set, Union import click import yaml @@ -26,13 +26,8 @@ from sphinx.cmd.build import main as sphinx_build_cmd from sphinx.ext.apidoc import main as sphinx_apidoc_cmd -from cli.helpers import ( - is_item_dir, - render_jinja, - PROJECT_ROOT, - get_item_yaml_values, - get_mock_requirements, -) +from cli.helpers import (PROJECT_ROOT, get_item_yaml_values, + get_mock_requirements, is_item_dir, render_jinja) from cli.marketplace.changelog import ChangeLog from cli.path_iterator import PathIterator @@ -68,6 +63,14 @@ default=False, help="When this flag is set, the process will output extra information", ) +@click.option( + "-f", + "--force-update", + "force_update_items", + is_flag=True, + default=False, + help="When this flag is set, item pages will be created even if the item did not changed", +) def build_marketplace_cli( source_dir: str, source_name: str, @@ -75,6 +78,7 @@ def build_marketplace_cli( temp_dir: str, channel: str, verbose: bool, + force_update_items: bool, ): build_marketplace( source_dir, @@ -83,6 +87,7 @@ def build_marketplace_cli( temp_dir, channel, verbose, + force_update_items, ) @@ -93,6 +98,7 @@ def build_marketplace( temp_dir: str = "/tmp", channel: str = "development", verbose: bool = False, + force_update_items: bool = False, ): """Main entry point to marketplace building @@ -103,6 +109,8 @@ 
def build_marketplace( if not provided '/tmp/' will be used :param channel: The name of the marketplace channel to write to :param verbose: When True, additional debug information will be written to stdout + :param force_update_items: If True, items will be updated unrelated if they are not changed. + The purpose of this flag is to fix existed broken pages (e.g. broken links) """ global _verbose _verbose = verbose @@ -152,9 +160,15 @@ def build_marketplace( render_html_files(temp_docs) change_log = ChangeLog() - copy_static_resources(marketplace_dir, temp_docs) + copy_resources(marketplace_dir, temp_docs) - update_or_create_items(source_dir, marketplace_dir, temp_docs, change_log) + update_or_create_items( + source_dir, + marketplace_dir, + temp_docs, + change_log, + force_update=force_update_items, + ) build_catalog_json( marketplace_dir=marketplace_dir, catalog_path=(marketplace_root / "catalog.json"), @@ -212,18 +226,22 @@ def write_index_html(marketplace_root: Union[str, Path]): shutil.copy(template_path, index_path) -def copy_static_resources(marketplace_dir, temp_docs): +def copy_resources(marketplace_dir, temp_docs): marketplace_static = marketplace_dir / "_static" - if not marketplace_static.exists(): - click.echo("Copying static resources...") - shutil.copytree(temp_docs / "_build/_static", marketplace_static) - shutil.copytree(temp_docs / "_build/_modules", marketplace_dir / "_modules") + click.echo("Copying static resources...") + shutil.copytree( + temp_docs / "_build/_static", marketplace_static, dirs_exist_ok=True + ) -def update_or_create_items(source_dir, marketplace_dir, temp_docs, change_log): +def update_or_create_items( + source_dir, marketplace_dir, temp_docs, change_log, force_update: bool = False +): click.echo("Creating items...") for item_dir in PathIterator(root=source_dir, rule=is_item_dir, as_path=True): - update_or_create_item(item_dir, marketplace_dir, temp_docs, change_log) + update_or_create_item( + item_dir, marketplace_dir, 
temp_docs, change_log, force_update + ) def build_catalog_json( @@ -329,17 +347,22 @@ def add_assets(item_yaml: dict): def update_or_create_item( - item_dir: Path, marketplace_dir: Path, temp_docs: Path, change_log: ChangeLog + item_dir: Path, + marketplace_dir: Path, + temp_docs: Path, + change_log: ChangeLog, + force_update: bool = False, ): # Copy source directories to target directories, if target already has the directory, archive previous version item_yaml = yaml.full_load(open(item_dir / "item.yaml", "r")) source_version = item_yaml["version"] + relative_path = "../../../" marketplace_item = marketplace_dir / item_dir.stem target_latest = marketplace_item / "latest" target_version = marketplace_item / source_version - if target_version.exists(): + if target_version.exists() and not force_update: latest_item_yaml = yaml.full_load( open(target_latest / "src" / "item.yaml", "r") ) @@ -351,16 +374,21 @@ def update_or_create_item( example_html_name = f"{item_dir.stem}_example.html" build_path = temp_docs / "_build" - source_html = marketplace_dir / "_modules" / item_dir.stem / f"{item_dir.stem}.html" - update_html_resource_paths(source_html, relative_path="../") + source_html = ( + temp_docs / "_build" / "_modules" / item_dir.stem / f"{item_dir.stem}.html" + ) + update_html_resource_paths(source_html, relative_path=relative_path) documentation_html = build_path / documentation_html_name update_html_resource_paths( - documentation_html, relative_path="../../../", with_download=False, item_name=item_dir.stem + documentation_html, + relative_path=relative_path, + with_download=False, + item_name=item_dir.stem, ) example_html = build_path / example_html_name - update_html_resource_paths(example_html, relative_path="../../../") + update_html_resource_paths(example_html, relative_path=relative_path) latest_src = target_latest / "src" version_src = target_version / "src" @@ -445,34 +473,38 @@ def update_or_create_item( def update_html_resource_paths( - html_path: Path, 
relative_path: str, with_download: bool = True, item_name: str = None + html_path: Path, + relative_path: str, + with_download: bool = True, + item_name: str = None, ): if html_path.exists(): with open(html_path, "r", encoding="utf8") as html: parsed = BeautifulSoup(html.read(), features="html.parser") # Update back to docs link (from source page) - back_to_docs_nodes = parsed.find_all(lambda node: "viewcode-back" in node.get("class", "")) - pattern = r"^.*?(?=.html)" + back_to_docs_nodes = parsed.find_all( + lambda node: "viewcode-back" in node.get("class", "") + ) + pattern = r"^.*?(?={})" for node in back_to_docs_nodes: - node["href"] = re.sub(pattern, "documentation", node["href"]) + node["href"] = re.sub( + pattern.format(".html"), "documentation", node["href"] + ) - # Remove _static from links and replace with src + # Fix links with relative paths: nodes = parsed.find_all( - lambda node: node.name == "link" and "_static" in node.get("href", "") + lambda node: "_static" in node.get("src", "") + or "_static" in node.get("href", "") ) for node in nodes: - node["href"] = f"{relative_path}{node['href']}" + key = "href" if "_static" in node.get("href", "") else "src" + node[key] = re.sub(pattern.format("_static"), relative_path, node[key]) - nodes = parsed.find_all( - lambda node: node.name == "script" - and node.get("src", "").startswith("_static") - ) - for node in nodes: - node["src"] = f"{relative_path}{node['src']}" if with_download: nodes = parsed.find_all(lambda node: "_sources" in node.get("href", "")) for node in nodes: + # fix path and remove example from name: node[ "href" ] = f'../{node["href"].replace("_sources", "src").replace("_example", "")}' @@ -487,7 +519,9 @@ def update_html_resource_paths( # Fix links in source page: if item_name: - nodes = parsed.find_all(lambda node: node.name == "a" and "_modules" in node.get("href", "")) + nodes = parsed.find_all( + lambda node: node.name == "a" and "_modules" in node.get("href", "") + ) for node in nodes: 
node["href"] = node["href"].replace(f"_modules/{item_name}/", "") @@ -653,4 +687,10 @@ def build_temp_docs(temp_root, temp_docs): if __name__ == "__main__": # build_marketplace_cli() - build_marketplace("../../", "../../../marketp", verbose=True) + build_marketplace( + source_dir="../../../functions", + marketplace_dir="../../../marketplace", + verbose=True, + channel="development", + force_update_items=True, + ) From 2202cafb5adb8fcc8cc93bb21627143a7df494fa Mon Sep 17 00:00:00 2001 From: yonishelach Date: Wed, 19 Apr 2023 09:36:25 +0300 Subject: [PATCH 02/11] Update CI temporarily and update index --- .github/workflows/test-all.yaml | 2 +- cli/marketplace/index.html | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test-all.yaml b/.github/workflows/test-all.yaml index e946b11b3..d0eb03579 100644 --- a/.github/workflows/test-all.yaml +++ b/.github/workflows/test-all.yaml @@ -167,7 +167,7 @@ jobs: pwd git pull origin cd .. - python functions/functions.py build-marketplace -s functions -m marketplace -c $CHANNEL -v + python functions/functions.py build-marketplace -s functions -m marketplace -c $CHANNEL -v -f - name: Publish marketplace release env: GITHUB_TOKEN: ${{ secrets.MARKETPLACE_ACCESS_TOKEN_V3 }} diff --git a/cli/marketplace/index.html b/cli/marketplace/index.html index b6e38cfa3..d1030e92f 100644 --- a/cli/marketplace/index.html +++ b/cli/marketplace/index.html @@ -115,6 +115,7 @@ item.example = item.example ? `${base_url}/static/example.html` : null; item.functionPath = `${base_url}/static/function.html`; item.itemPath = `${base_url}/static/item.html`; + item.code = `${base_url}/static/${item.rawName}.html`; table.addRow(item); } @@ -189,6 +190,18 @@ }, }, width: 150 + }, + { + title: "Source Code", + field: "code", + headerSort: false, + formatter: "link", + formatterParams: { + label: (cell) => { + return (cell._cell.value === undefined ? 
'': 'View'); + }, + }, + width: 150 }, { title: "Deployment", From 8b17c8f3dea43adcde9f9a389f16564eac082dca Mon Sep 17 00:00:00 2001 From: yonishelach Date: Wed, 19 Apr 2023 10:20:07 +0300 Subject: [PATCH 03/11] [XGB-Custom] Fix test artifact key name --- xgb_custom/test_xgb_custom.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xgb_custom/test_xgb_custom.py b/xgb_custom/test_xgb_custom.py index 6cdada74b..81b77a2e0 100644 --- a/xgb_custom/test_xgb_custom.py +++ b/xgb_custom/test_xgb_custom.py @@ -41,7 +41,7 @@ def test_local_xgb_custom(): "verbose_eval": False, "XGB_max_depth": 2, "XGB_subsample": 0.9, - "test_set_key": "./artifacts/inputs/test-set", + "test_set_key": "test-set", }, inputs={"dataset": run.artifact('xgb-outs').url}, handler="fit", From a469dcaebeea7cb8874fc05eb18b9728019cde84 Mon Sep 17 00:00:00 2001 From: yonishelach Date: Wed, 19 Apr 2023 13:43:33 +0300 Subject: [PATCH 04/11] [XGB-Serving][XGB-Test][XGB-Trainer] Fix tests - artifact key --- xgb_serving/test_xgb_serving.py | 3 +-- xgb_test/test_xgb_test.py | 2 -- xgb_trainer/test_xgb_trainer.py | 7 +++---- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/xgb_serving/test_xgb_serving.py b/xgb_serving/test_xgb_serving.py index 9f4dd8244..ce5e8aaa8 100644 --- a/xgb_serving/test_xgb_serving.py +++ b/xgb_serving/test_xgb_serving.py @@ -37,8 +37,7 @@ def test_local_xgb_serving(): "CLASS_objective": "binary:logistic", "CLASS_booster": "gbtree", "FIT_verbose": 0, - "label_column": "labels", - "test_set": "./"}, + "label_column": "labels"}, local=True, inputs={"dataset": gen_data_run.artifact('classifier-data').url}, artifact_path='./') diff --git a/xgb_test/test_xgb_test.py b/xgb_test/test_xgb_test.py index 71dfee3e2..e56b9db01 100644 --- a/xgb_test/test_xgb_test.py +++ b/xgb_test/test_xgb_test.py @@ -51,7 +51,6 @@ def xgb_trainer(): "CLASS_booster": "gbtree", "FIT_verbose": 0, "label_column": "labels", - "test_set": "./artifacts/test-set", }, local=True, 
inputs={"dataset": data}, @@ -111,7 +110,6 @@ def test_local_xgb_test_import_local_function(): "CLASS_booster": "gbtree", "FIT_verbose": 0, "label_column": "labels", - "test_set": "./artifacts/test-set", }, local=True, inputs={"dataset": data}, diff --git a/xgb_trainer/test_xgb_trainer.py b/xgb_trainer/test_xgb_trainer.py index 1356f72c7..52df8db48 100644 --- a/xgb_trainer/test_xgb_trainer.py +++ b/xgb_trainer/test_xgb_trainer.py @@ -29,6 +29,7 @@ def get_class_data(): 'file_ext': 'csv'}, local=True, artifact_path='./') return run + def test_xgb_trainer_code_to_function(): gen_data_run = get_class_data() fn = code_to_function(name='test_xgb_trainer', @@ -41,8 +42,7 @@ def test_xgb_trainer_code_to_function(): 'CLASS_objective': 'binary:logistic', 'CLASS_booster': 'gbtree', 'FIT_verbose': 0, - 'label_column': 'labels', - 'test_set': './'}, + 'label_column': 'labels'}, local=False, inputs={'dataset': gen_data_run.artifact('classifier-data').url}) @@ -59,8 +59,7 @@ def test_local_xgb_trainer_import_function(): 'CLASS_objective': 'binary:logistic', 'CLASS_booster': 'gbtree', 'FIT_verbose': 0, - 'label_column': 'labels', - 'test_set': './'}, + 'label_column': 'labels'}, local=True, inputs={'dataset': gen_data_run.artifact('classifier-data').url}) assert (run.artifact('model')) \ No newline at end of file From 3301415200e52326bade1e17f99cb6b6d3880860 Mon Sep 17 00:00:00 2001 From: Yoni Shelach <92271540+yonishelach@users.noreply.github.com> Date: Wed, 19 Apr 2023 19:05:40 +0300 Subject: [PATCH 05/11] [Build] Install python 3.9 when testing (#618) --- cli/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cli/helpers.py b/cli/helpers.py index 64d5505c5..75661f345 100644 --- a/cli/helpers.py +++ b/cli/helpers.py @@ -65,7 +65,7 @@ def install_pipenv(): def install_python(directory: Union[str, Path]): print(f"Installing python for {directory}...") python_install: subprocess.CompletedProcess = subprocess.run( - f"pipenv --rm;pipenv --python 3.7", + 
f"pipenv --rm;pipenv --python 3.9.13", stdout=sys.stdout, stderr=subprocess.PIPE, cwd=directory, From 0cd1f1585a618c253f201b6f5a63502cdbddb591 Mon Sep 17 00:00:00 2001 From: Yoni Shelach <92271540+yonishelach@users.noreply.github.com> Date: Wed, 19 Apr 2023 19:41:19 +0300 Subject: [PATCH 06/11] [Build] Update python version in CI (#620) * [Build] Install python 3.9 when testing * [Build] Update python version in CI * . --- .github/workflows/ci.yaml | 2 +- .github/workflows/test-all.yaml | 6 +++--- cli/helpers.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 155634dfc..1ab67ffc7 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -107,7 +107,7 @@ jobs: - name: Install python uses: actions/setup-python@v2 with: - python-version: 3.7 + python-version: 3.9 # Install dependencies - uses: actions/cache@v1 id: cache diff --git a/.github/workflows/test-all.yaml b/.github/workflows/test-all.yaml index d0eb03579..fe7248bcd 100644 --- a/.github/workflows/test-all.yaml +++ b/.github/workflows/test-all.yaml @@ -70,7 +70,7 @@ jobs: - name: Install python uses: actions/setup-python@v2 with: - python-version: 3.7 + python-version: 3.9 # Install dependencies - uses: actions/cache@v1 id: cache @@ -106,7 +106,7 @@ jobs: # - name: Install python # uses: actions/setup-python@v2 # with: -# python-version: 3.7 +# python-version: 3.9 # # Install dependencies # - uses: actions/cache@v1 # id: cache @@ -153,7 +153,7 @@ jobs: - name: Install python uses: actions/setup-python@v2 with: - python-version: 3.7 + python-version: 3.9 - name: Install requirements run: | cd functions diff --git a/cli/helpers.py b/cli/helpers.py index 75661f345..b44ebc92e 100644 --- a/cli/helpers.py +++ b/cli/helpers.py @@ -65,7 +65,7 @@ def install_pipenv(): def install_python(directory: Union[str, Path]): print(f"Installing python for {directory}...") python_install: subprocess.CompletedProcess = subprocess.run( - 
f"pipenv --rm;pipenv --python 3.9.13", + f"pipenv --rm;pipenv --python 3.9", stdout=sys.stdout, stderr=subprocess.PIPE, cwd=directory, From 33e7ab8c43579b8609ed4f9654cc7b0d0f06671a Mon Sep 17 00:00:00 2001 From: Yoni Shelach <92271540+yonishelach@users.noreply.github.com> Date: Wed, 19 Apr 2023 19:47:50 +0300 Subject: [PATCH 07/11] Revert "[Build] Update python version in CI (#620)" (#621) This reverts commit 0cd1f1585a618c253f201b6f5a63502cdbddb591. --- .github/workflows/ci.yaml | 2 +- .github/workflows/test-all.yaml | 6 +++--- cli/helpers.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 1ab67ffc7..155634dfc 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -107,7 +107,7 @@ jobs: - name: Install python uses: actions/setup-python@v2 with: - python-version: 3.9 + python-version: 3.7 # Install dependencies - uses: actions/cache@v1 id: cache diff --git a/.github/workflows/test-all.yaml b/.github/workflows/test-all.yaml index fe7248bcd..d0eb03579 100644 --- a/.github/workflows/test-all.yaml +++ b/.github/workflows/test-all.yaml @@ -70,7 +70,7 @@ jobs: - name: Install python uses: actions/setup-python@v2 with: - python-version: 3.9 + python-version: 3.7 # Install dependencies - uses: actions/cache@v1 id: cache @@ -106,7 +106,7 @@ jobs: # - name: Install python # uses: actions/setup-python@v2 # with: -# python-version: 3.9 +# python-version: 3.7 # # Install dependencies # - uses: actions/cache@v1 # id: cache @@ -153,7 +153,7 @@ jobs: - name: Install python uses: actions/setup-python@v2 with: - python-version: 3.9 + python-version: 3.7 - name: Install requirements run: | cd functions diff --git a/cli/helpers.py b/cli/helpers.py index b44ebc92e..75661f345 100644 --- a/cli/helpers.py +++ b/cli/helpers.py @@ -65,7 +65,7 @@ def install_pipenv(): def install_python(directory: Union[str, Path]): print(f"Installing python for {directory}...") python_install: 
subprocess.CompletedProcess = subprocess.run( - f"pipenv --rm;pipenv --python 3.9", + f"pipenv --rm;pipenv --python 3.9.13", stdout=sys.stdout, stderr=subprocess.PIPE, cwd=directory, From 7a7473b8f41e80032f381d927214a9076a4a55b8 Mon Sep 17 00:00:00 2001 From: Yoni Shelach <92271540+yonishelach@users.noreply.github.com> Date: Wed, 19 Apr 2023 19:48:09 +0300 Subject: [PATCH 08/11] Revert "[Build] Install python 3.9 when testing (#618)" (#619) This reverts commit 3301415200e52326bade1e17f99cb6b6d3880860. --- cli/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cli/helpers.py b/cli/helpers.py index 75661f345..64d5505c5 100644 --- a/cli/helpers.py +++ b/cli/helpers.py @@ -65,7 +65,7 @@ def install_pipenv(): def install_python(directory: Union[str, Path]): print(f"Installing python for {directory}...") python_install: subprocess.CompletedProcess = subprocess.run( - f"pipenv --rm;pipenv --python 3.9.13", + f"pipenv --rm;pipenv --python 3.7", stdout=sys.stdout, stderr=subprocess.PIPE, cwd=directory, From 81437da88e99ed48a9e1b0b0b367c4c02db80140 Mon Sep 17 00:00:00 2001 From: Yoni Shelach <92271540+yonishelach@users.noreply.github.com> Date: Wed, 19 Apr 2023 20:26:27 +0300 Subject: [PATCH 09/11] [Build] Build with python 3.9 (#622) * [Build] Build with python 3.9 * . 
--- .github/workflows/ci.yaml | 5 +++++ .github/workflows/test-all.yaml | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 155634dfc..5b4bfcd79 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -108,6 +108,11 @@ jobs: uses: actions/setup-python@v2 with: python-version: 3.7 + # Install python 3.9 + - name: Install python 3.9 + uses: actions/setup-python@v2 + with: + python-version: 3.9 # Install dependencies - uses: actions/cache@v1 id: cache diff --git a/.github/workflows/test-all.yaml b/.github/workflows/test-all.yaml index d0eb03579..5eff03b0f 100644 --- a/.github/workflows/test-all.yaml +++ b/.github/workflows/test-all.yaml @@ -150,10 +150,10 @@ jobs: with: repository: mlrun/marketplace path: marketplace - - name: Install python + - name: Install python 3.9 uses: actions/setup-python@v2 with: - python-version: 3.7 + python-version: 3.9 - name: Install requirements run: | cd functions From 17b6eb054b9542ccc72c2b9a618f01bfcf4393f0 Mon Sep 17 00:00:00 2001 From: guy1992l Date: Wed, 10 Jan 2024 18:01:47 +0200 Subject: [PATCH 10/11] fix1 --- silero_vad/silero_vad.py | 80 ++++++++++++++++++++++++++++++++-------- 1 file changed, 65 insertions(+), 15 deletions(-) diff --git a/silero_vad/silero_vad.py b/silero_vad/silero_vad.py index a7a44e055..08ff585fa 100644 --- a/silero_vad/silero_vad.py +++ b/silero_vad/silero_vad.py @@ -66,6 +66,14 @@ def get_result(self) -> Tuple[str, list]: """ return self._audio_file.name, self._result + def to_tuple(self) -> Tuple[str, dict]: + """ + Convert the task to a tuple to reconstruct it later (used for multiprocessing to pass in queue). + + :returns: The converted task. 
+ """ + return self.__class__.__name__, {"audio_file": self._audio_file} + class SpeechDiarizationTask(BaseTask): """ @@ -105,12 +113,27 @@ def do_task(self, speech_timestamps: List[List[Dict[str, int]]]): speech_diarization.sort() self._result = speech_diarization + def to_tuple(self) -> Tuple[str, dict]: + """ + Convert the task to a tuple to reconstruct it later (used for multiprocessing to pass in queue). + + :returns: The converted task. + """ + task_class, task_kwargs = super().to_tuple() + return task_class, {**task_kwargs, "speaker_labels": self._speaker_labels} + class TaskCreator: """ A task creator to create different tasks to run after the VAD. """ + #: A map from task class name to task class to use in `from_tuple`: + _MAP = { + BaseTask.__name__: BaseTask, + SpeechDiarizationTask.__name__: SpeechDiarizationTask, + } + def __init__(self, task_type: Type[BaseTask], task_kwargs: dict = None): """ Initialize the task creator. @@ -120,7 +143,7 @@ def __init__(self, task_type: Type[BaseTask], task_kwargs: dict = None): self._task_type = task_type self._task_kwargs = task_kwargs or {} - def create_task(self, audio_file: Path): + def create_task(self, audio_file: Path) -> BaseTask: """ Create a task with the given audio file. @@ -130,6 +153,18 @@ def create_task(self, audio_file: Path): """ return self._task_type(audio_file=audio_file, **self._task_kwargs) + @classmethod + def from_tuple(cls, task_tuple: Tuple[str, dict]) -> BaseTask: + """ + Create a task from a tuple of the audio file name and the task kwargs. + + :param task_tuple: The task tuple to create the task from. + + :returns: The created task. + """ + task_class, task_kwargs = task_tuple + return cls._MAP[task_class](**task_kwargs) + class VoiceActivityDetector: """ @@ -196,14 +231,16 @@ def __init__( self._model: torch.Module = None self._get_speech_timestamps: FunctionType = None - def load(self): + def load(self, force_reload: bool = True): """ Load the VAD model. 
+ + :param force_reload: Whether to force reload the model even if it was already loaded. Default is True. """ model, utils = torch.hub.load( repo_or_dir="snakers4/silero-vad", model="silero_vad", - force_reload=True, + force_reload=force_reload, onnx=self._use_onnx, force_onnx_cpu=self._force_onnx_cpu, ) @@ -317,31 +354,40 @@ def _multiprocessing_complete_tasks( """ # Initialize and load the VAD: vad = VoiceActivityDetector(**vad_init_kwargs) - vad.load() + vad.load(force_reload=False) # Start listening to the tasks queue: while True: # Get the task: - task: BaseTask = tasks_queue.get() + task: Tuple[str, dict] = tasks_queue.get() if task == _MULTIPROCESSING_STOP_MARK: break try: + # Create the task: + task = TaskCreator.from_tuple(task_tuple=task) # Run the file through the VAD: speech_timestamps = vad.detect_voice(audio_file=task.audio_file) # Complete the task: task.do_task(speech_timestamps=speech_timestamps) - # Collect the result: - results_queue.put((False, task.get_result())) + # Build the result: + result = (False, task.get_result()) except Exception as exception: - # Collect the error: - results_queue.put((True, (task.audio_file.name, str(exception)))) + # Build the error: + result = (True, (task.audio_file.name, str(exception))) + # Collect the result / error: + results_queue.put(result) # Mark the end of the tasks: results_queue.put(_MULTIPROCESSING_STOP_MARK) # Get the global logger: -_LOGGER = logging.getLogger() +try: + import mlrun + + _LOGGER = mlrun.get_or_create_ctx("silero_vad").logger +except ModuleNotFoundError: + _LOGGER = logging.getLogger() def detect_voice( @@ -702,6 +748,14 @@ def _parallel_run( :returns: The collected results. 
""" + # Load the VAD (download once, and it will be loaded then per process later on): + if verbose: + _LOGGER.info(f"Loading the VAD model.") + vad = VoiceActivityDetector(**vad_init_kwargs) + vad.load() + if verbose: + _LOGGER.info("VAD model loaded.") + # Check the number of workers: if n_workers > len(audio_files): _LOGGER.warning( @@ -728,12 +782,8 @@ def _parallel_run( ] # Start the multiprocessing processes: - if verbose: - _LOGGER.info(f"Loading the VAD model (per process).") for p in task_completion_processes: p.start() - if verbose: - _LOGGER.info("VAD model loaded.") # Put the tasks in the queue (the progress bar is not accurate as it is updating by the queue which has a max size # of 2*n_workers so the progress bar won't be too off - better than nothing): @@ -745,7 +795,7 @@ def _parallel_run( disable=not verbose, ): # Put the task in the queue: - tasks_queue.put(task_creator.create_task(audio_file=audio_file)) + tasks_queue.put(task_creator.create_task(audio_file=audio_file).to_tuple()) # Put the stop marks in the queue: for _ in range(n_workers): From c3a141c9668953c595c14d286b7beede6a0d15a1 Mon Sep 17 00:00:00 2001 From: guy1992l Date: Wed, 10 Jan 2024 18:39:00 +0200 Subject: [PATCH 11/11] fix2 --- silero_vad/function.yaml | 43 +++++++++++++++++++++++++++++------ silero_vad/item.yaml | 2 +- silero_vad/silero_vad.py | 49 ++++++++++++++++++++-------------------- 3 files changed, 61 insertions(+), 33 deletions(-) diff --git a/silero_vad/function.yaml b/silero_vad/function.yaml index 731d280d8..75d1ce0cc 100644 --- a/silero_vad/function.yaml +++ b/silero_vad/function.yaml @@ -2,7 +2,7 @@ kind: job metadata: name: silero-vad tag: '' - hash: 064d82f265f7a42c937584473e6092dbbfd8ffe8 + hash: bc0ad5572cc391fcdc93baaee48e1ef949a7984d project: '' labels: author: guyl @@ -15,7 +15,7 @@ spec: args: [] image: '' build: - functionSourceCode: # Copyright 2024 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from multiprocessing import Process, Queue
from pathlib import Path
from types import FunctionType
from typing import Dict, List, Tuple, Type, Union

import torch
import torchaudio
from tqdm import tqdm


class BaseTask:
    """
    The basic task to run on an audio file once the VAD produced its speech timestamps. This base implementation
    keeps the speech timestamps untouched as its result.
    """

    def __init__(self, audio_file: Path):
        """
        Create a task for a single audio file.

        :param audio_file: The audio file this task is responsible for.
        """
        # No result is available until `do_task` is called:
        self._result = None

        # Keep the audio file the task belongs to:
        self._audio_file = audio_file

    @property
    def audio_file(self) -> Path:
        """
        The audio file this task was created for.

        :returns: The path of the task's audio file.
        """
        return self._audio_file

    def do_task(
        self, speech_timestamps: Union[List[Dict[str, int]], List[List[Dict[str, int]]]]
    ):
        """
        Run the task over the VAD output. In the base task there is no processing - the given speech timestamps are
        stored as the result as they are.

        :param speech_timestamps: The speech timestamps the VAD returned for the task's audio file.
        """
        self._result = speech_timestamps

    def get_result(self) -> Tuple[str, list]:
        """
        Collect the outcome of the task.

        :returns: A tuple of the audio file name (the name only, not the full path) and the stored result. The result
                  is None as long as `do_task` was not called.
        """
        file_name = self._audio_file.name
        return file_name, self._result


class SpeechDiarizationTask(BaseTask):
    """
    A speech diarization task. The task will diarize the VAD speech timestamps into speakers.
    """

    def __init__(self, audio_file: Path, speaker_labels: List[str] = None):
        """
        Initialize the speech diarization task.

        :param audio_file:     The audio file assigned to the task.
        :param speaker_labels: The speaker labels to use for the diarization, one per channel (in channel order). If
                               not given, the speakers will be named "speaker_0", "speaker_1", etc. If fewer labels
                               than channels are given, the remaining channels get the default "speaker_<index>"
                               names. Labels beyond the number of channels are ignored.
        """
        super().__init__(audio_file=audio_file)
        self._speaker_labels = speaker_labels

    def do_task(self, speech_timestamps: List[List[Dict[str, int]]]):
        """
        Do the task on the given speech timestamps. The task will diarize the VAD speech timestamps into speakers.

        The result is a single list of `(start, end, speaker_label)` tuples sorted by start time (ties are broken by
        the end time and then by the label, as the tuples are sorted as a whole).

        :param speech_timestamps: The speech timestamps per channel to do the task on as outputted from the VAD.
        """
        # Get the speaker labels. Every channel that was not given a label falls back to the default
        # "speaker_<channel index>" name, so the `zip` below never drops a channel silently:
        given_labels = list(self._speaker_labels or [])
        speaker_labels = given_labels + [
            f"speaker_{i}" for i in range(len(given_labels), len(speech_timestamps))
        ]

        # Diarize - organize the speech timestamps into a single list of speakers and sort it by start time:
        speech_diarization = [
            (speech_timestamp["start"], speech_timestamp["end"], speaker_label)
            for speaker_label, channel_speech_timestamps in zip(
                speaker_labels, speech_timestamps
            )
            for speech_timestamp in channel_speech_timestamps
        ]
        speech_diarization.sort()
        self._result = speech_diarization


class TaskCreator:
    """
    A factory of post-VAD tasks: holds a task type and its keyword arguments and builds a task per audio file.
    """

    def __init__(self, task_type: Type[BaseTask], task_kwargs: dict = None):
        """
        Set up the creator with the kind of task it should produce.

        :param task_type:   The class of the tasks to create - a `BaseTask` subclass.
        :param task_kwargs: Extra keyword arguments to hand to each created task in addition to its audio file. When
                            not given (or empty), no extra arguments are passed.
        """
        self._task_kwargs = task_kwargs if task_kwargs else {}
        self._task_type = task_type

    def create_task(self, audio_file: Path):
        """
        Build a new task for the given audio file using the stored task type and keyword arguments.

        :param audio_file: The audio file the new task will be assigned to.

        :returns: The newly created task.
        """
        task_class = self._task_type
        return task_class(audio_file=audio_file, **self._task_kwargs)


class VoiceActivityDetector:
    """
    A voice activity detection wrapper for the silero VAD model - https://github.com/snakers4/silero-vad.

    The detector must be loaded with `load` before `detect_voice` can be used.
    """

    def __init__(
        self,
        # Model loading kwargs:
        use_onnx: bool = True,
        force_onnx_cpu: bool = True,
        # Detection kwargs:
        threshold: float = 0.5,
        sampling_rate: int = 16_000,
        min_speech_duration_ms: int = 250,
        max_speech_duration_s: float = float("inf"),
        min_silence_duration_ms: int = 100,
        window_size_samples: int = 512,
        speech_pad_ms: int = 30,
        return_seconds: bool = False,
        per_channel: bool = False,
    ):
        """
        Initialize the voice activity detector.

        :param use_onnx:                Whether to use ONNX for inference. Default is True.
        :param force_onnx_cpu:          Whether to force ONNX to use CPU for inference. Default is True.
        :param threshold:               Speech threshold. Silero VAD outputs speech probabilities for each audio chunk,
                                        probabilities ABOVE this value are considered as SPEECH. It is better to tune
                                        this parameter for each dataset separately, but "lazy" 0.5 is pretty good for
                                        most datasets.
        :param sampling_rate:           Currently, silero VAD models support 8000 and 16000 sample rates.
        :param min_speech_duration_ms:  Final speech chunks shorter min_speech_duration_ms are thrown out.
        :param max_speech_duration_s:   Maximum duration of speech chunks in seconds. Chunks longer than
                                        `max_speech_duration_s` will be split at the timestamp of the last silence that
                                        lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise,
                                        they will be split aggressively just before max_speech_duration_s.
        :param min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms before
                                        separating it.
        :param window_size_samples:     Audio chunks of window_size_samples size are fed to the silero VAD model.
                                        WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000
                                        sample rate and 256, 512, 768 samples for 8000 sample rate. Values other than
                                        these may affect model performance!
        :param speech_pad_ms:           Final speech chunks are padded by speech_pad_ms each side.
        :param return_seconds:          Whether return timestamps in seconds. False means to return timestamps in
                                        samples (default - False).
        :param per_channel:             Whether to return timestamps per channel (default - False). This will run VAD
                                        on each channel separately and return a list of timestamps per channel.
        """
        # Store configurations:
        self._use_onnx = use_onnx
        self._force_onnx_cpu = force_onnx_cpu
        self._threshold = threshold
        self._sampling_rate = sampling_rate
        self._min_speech_duration_ms = min_speech_duration_ms
        self._max_speech_duration_s = max_speech_duration_s
        self._min_silence_duration_ms = min_silence_duration_ms
        self._window_size_samples = window_size_samples
        self._speech_pad_ms = speech_pad_ms
        self._return_seconds = return_seconds
        self._per_channel = per_channel

        # Prepare the model variables (both are set by `load`):
        self._model: torch.nn.Module = None
        self._get_speech_timestamps: FunctionType = None

    def load(self):
        """
        Load the VAD model and its `get_speech_timestamps` utility through `torch.hub`.
        """
        # NOTE: `force_reload=True` makes `torch.hub` fetch a fresh copy of the repository on every call.
        model, utils = torch.hub.load(
            repo_or_dir="snakers4/silero-vad",
            model="silero_vad",
            force_reload=True,
            onnx=self._use_onnx,
            force_onnx_cpu=self._force_onnx_cpu,
        )
        self._model = model

        # Only the first utility (`get_speech_timestamps`) is needed. The rest (save_audio, read_audio, VADIterator,
        # collect_chunks) are ignored, so indexing keeps working even if more utilities are added to the tuple:
        self._get_speech_timestamps = utils[0]

    def detect_voice(
        self,
        audio_file: Path,
    ) -> Union[List[Dict[str, int]], List[List[Dict[str, int]]]]:
        """
        Infer the audio through the VAD model and return the speech timestamps.

        :param audio_file: The audio file to infer.

        :returns: The speech timestamps in the audio. A list of timestamps where each timestamp is a dictionary with the
                 following keys:

                 * "start": The start of the speech in the audio.
                 * "end":   The end of the speech in the audio.

                 The values are sample indices, or seconds if `return_seconds` is True. If `per_channel` is True, a
                 list of timestamps per channel will be returned.

        :raises RuntimeError: If the model was not loaded (`load` was not called).
        """
        # Fail early with a clear message instead of a "'NoneType' object is not callable" error:
        if self._get_speech_timestamps is None:
            raise RuntimeError(
                "The VAD model is not loaded. Call `load()` before calling `detect_voice()`."
            )

        # Read the audio as a tensor:
        audio = self._read_audio(audio_file)

        # Detect speech on the entire (mono) audio:
        if not self._per_channel:
            return self._detect(audio)

        # Per channel - iterating the audio tensor yields one channel at a time:
        return [self._detect(channel) for channel in audio]

    def _detect(self, audio: torch.Tensor) -> List[Dict[str, int]]:
        """
        Run the loaded `get_speech_timestamps` utility on a single channel audio using the stored configuration.

        :param audio: The single channel audio to detect speech in.

        :returns: The speech timestamps of the given audio.
        """
        return self._get_speech_timestamps(
            audio,
            self._model,
            threshold=self._threshold,
            min_speech_duration_ms=self._min_speech_duration_ms,
            max_speech_duration_s=self._max_speech_duration_s,
            min_silence_duration_ms=self._min_silence_duration_ms,
            speech_pad_ms=self._speech_pad_ms,
            sampling_rate=self._sampling_rate,
            window_size_samples=self._window_size_samples,
            return_seconds=self._return_seconds,
        )

    def _read_audio(
        self,
        path: Path,
    ) -> torch.Tensor:
        """
        Read the audio from the given path and return it as a tensor, resampled to the configured sampling rate.

        :param path: The path to the audio file.

        :returns: The audio as a tensor. If `per_channel` is True the channels dimension is kept, otherwise the audio
                  is averaged into a single channel and the channels dimension is squeezed.
        """
        # Read the audio:
        audio, sampling_rate = torchaudio.load(str(path))

        # Check if the audio is stereo and if so, convert it to mono (only if not per channel):
        if audio.size(0) > 1 and not self._per_channel:
            audio = audio.mean(dim=0, keepdim=True)

        # Resample the audio if needed:
        if sampling_rate != self._sampling_rate:
            transform = torchaudio.transforms.Resample(
                orig_freq=sampling_rate, new_freq=self._sampling_rate
            )
            audio = transform(audio)

        # Return the audio (squeeze if not per channel):
        return audio if self._per_channel else audio.squeeze(0)


#: The value to send into multiprocessing queues to stop the process. A worker stops consuming tasks once it gets the
#: mark from the tasks queue, and it puts the mark in the results queue to signal it is done:
_MULTIPROCESSING_STOP_MARK = "STOP"


def _multiprocessing_complete_tasks(
    vad_init_kwargs: dict, tasks_queue: Queue, results_queue: Queue
):
    """
    Complete the tasks in the given queue and put the results in the given results queue. The function will stop when
    the given tasks queue will receive the stop mark. It is aimed to be used with multiprocessing as a process.

    Each result is put as a tuple of an error flag and a payload: `(False, task_result)` on success and
    `(True, (audio_file_name, error_message))` on failure. When done, the stop mark is put in the results queue.

    :param vad_init_kwargs: The VAD initialization kwargs.
    :param tasks_queue:     A queue to get the tasks from.
    :param results_queue:   A queue to put the results in.
    """
    # Initialize and load the VAD. A failure here must not kill the worker: the parent process waits for the stop
    # mark of every worker, so dying before sending it would leave the parent waiting forever. Instead, the worker
    # keeps consuming its tasks and reports the loading failure as the error of each one of them:
    vad = None
    load_error = None
    try:
        vad = VoiceActivityDetector(**vad_init_kwargs)
        vad.load()
    except Exception as exception:
        load_error = f"Failed to load the VAD model: {exception}"

    # Start listening to the tasks queue:
    while True:
        # Get the task:
        task: BaseTask = tasks_queue.get()
        if task == _MULTIPROCESSING_STOP_MARK:
            break
        if load_error is not None:
            # The VAD is unavailable, collect the loading error for the task:
            results_queue.put((True, (task.audio_file.name, load_error)))
            continue
        try:
            # Run the file through the VAD:
            speech_timestamps = vad.detect_voice(audio_file=task.audio_file)
            # Complete the task:
            task.do_task(speech_timestamps=speech_timestamps)
            # Collect the result:
            results_queue.put((False, task.get_result()))
        except Exception as exception:
            # Collect the error (best effort - a failing file should not stop the rest of the files):
            results_queue.put((True, (task.audio_file.name, str(exception))))

    # Mark the end of the tasks:
    results_queue.put(_MULTIPROCESSING_STOP_MARK)


# Get the global logger (`logging.getLogger` with no name returns the root logger). It is used for the verbose
# messages and warnings of this module:
_LOGGER = logging.getLogger()


def detect_voice(
    # Input kwargs:
    data_path: Union[str, Path, List[Union[str, Path]]],
    # Model loading kwargs:
    use_onnx: bool = True,
    force_onnx_cpu: bool = True,
    # Detection kwargs:
    threshold: float = 0.5,
    sampling_rate: int = 16_000,
    min_speech_duration_ms: int = 250,
    max_speech_duration_s: float = float("inf"),
    min_silence_duration_ms: int = 100,
    window_size_samples: int = 512,
    speech_pad_ms: int = 30,
    return_seconds: bool = False,
    per_channel: bool = False,
    # Other kwargs:
    use_multiprocessing: int = 0,
    verbose: bool = False,
):
    """
    Perform voice activity detection on given audio files using the silero VAD model -
    https://github.com/snakers4/silero-vad. The successful results are collected into a dictionary with the file
    names as keys and their VAD timestamps as values.

    For example::

        {
            "file_1.wav": [
                {"start": 0, "end": 16000},
                {"start": 16000, "end": 32000},
                {"start": 32000, "end": 48000},
                ...
            ],
            "file_2.wav": [
                {"start": 0, "end": 16000},
                {"start": 16000, "end": 32000},
                {"start": 32000, "end": 48000},
                ...
            ],
            ...
        }


    :param data_path:               The path to the audio files to run the VAD on. Can be a path to a single file, a
                                    path to a directory or a list of paths to files.
    :param use_onnx:                Whether to use ONNX for inference. Default is True.
    :param force_onnx_cpu:          Whether to force ONNX to use CPU for inference. Default is True.
    :param threshold:               Speech threshold. Silero VAD outputs speech probabilities for each audio chunk,
                                    probabilities ABOVE this value are considered as SPEECH. It is better to tune
                                    this parameter for each dataset separately, but "lazy" 0.5 is pretty good for
                                    most datasets.
    :param sampling_rate:           Currently, silero VAD models support 8000 and 16000 sample rates.
    :param min_speech_duration_ms:  Final speech chunks shorter min_speech_duration_ms are thrown out.
    :param max_speech_duration_s:   Maximum duration of speech chunks in seconds. Chunks longer than
                                    `max_speech_duration_s` will be split at the timestamp of the last silence that
                                    lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will
                                    be split aggressively just before max_speech_duration_s.
    :param min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms before separating
                                    it.
    :param window_size_samples:     Audio chunks of window_size_samples size are fed to the silero VAD model.

                                    WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000
                                    sample rate and 256, 512, 768 samples for 8000 sample rate. Values other than
                                    these may affect model performance!
    :param speech_pad_ms:           Final speech chunks are padded by speech_pad_ms each side.
    :param return_seconds:          Whether return timestamps in seconds. False means to return timestamps in samples
                                    (default - False).
    :param per_channel:             Whether to return timestamps per channel (default - False). This will run VAD on
                                    each channel separately and return a list of timestamps per channel.
    :param use_multiprocessing:     The number of workers to use for multiprocessing. If 0, no multiprocessing will
                                    be used. Default is 0.
    :param verbose:                 Verbosity.

    :returns: A tuple of two dictionaries: the successes (file name to its VAD timestamps) and the errors (file name
              to its error message).
    """
    # Collect the audio files to run the VAD on:
    if verbose:
        _LOGGER.info("Collecting audio files.")
    audio_files = _get_audio_files(data_path=data_path)
    if verbose:
        _LOGGER.info(f"Collected {len(audio_files)} audio files.")

    # Gather the arguments shared by both runners - the VAD configuration and a creator of plain `BaseTask` tasks
    # (they keep the VAD timestamps as they are):
    run_kwargs = dict(
        audio_files=audio_files,
        description="Detecting voice",
        vad_init_kwargs=dict(
            use_onnx=use_onnx,
            force_onnx_cpu=force_onnx_cpu,
            threshold=threshold,
            sampling_rate=sampling_rate,
            min_speech_duration_ms=min_speech_duration_ms,
            max_speech_duration_s=max_speech_duration_s,
            min_silence_duration_ms=min_silence_duration_ms,
            window_size_samples=window_size_samples,
            speech_pad_ms=speech_pad_ms,
            return_seconds=return_seconds,
            per_channel=per_channel,
        ),
        task_creator=TaskCreator(task_type=BaseTask),
        verbose=verbose,
    )

    # Run the VAD - in multiple worker processes if requested, otherwise in this process:
    if use_multiprocessing:
        results = _parallel_run(n_workers=use_multiprocessing, **run_kwargs)
    else:
        results = _run(**run_kwargs)

    # Split the results into successes and errors:
    return _process_results(results=results, verbose=verbose)


def diarize(
    # Input / Output kwargs:
    data_path: Union[str, Path, List[Union[str, Path]]],
    # Model loading kwargs:
    use_onnx: bool = True,
    force_onnx_cpu: bool = True,
    # Detection kwargs:
    threshold: float = 0.5,
    sampling_rate: int = 16_000,
    min_speech_duration_ms: int = 250,
    max_speech_duration_s: float = float("inf"),
    min_silence_duration_ms: int = 100,
    window_size_samples: int = 512,
    speech_pad_ms: int = 30,
    # Diarization kwargs:
    speaker_labels: List[str] = None,
    # Other kwargs:
    use_multiprocessing: int = 0,
    verbose: bool = False,
):
    """
    Perform speech diarization on given audio files using the silero VAD model - https://github.com/snakers4/silero-vad.
    The speech diarization is performed per channel so that each channel in the audio belong to a different speaker.
    The successful results are collected into a dictionary with the file names as keys and their diarization as
    values. A diarization is a list of tuples: (start, end, speaker_label), where the times are in seconds.

    For example::

        {
            "file_1.wav": [
                (0.0, 1.0, "speaker_0"),
                (1.0, 2.0, "speaker_1"),
                (2.0, 3.0, "speaker_0"),
                ...
            ],
            "file_2.wav": [
                (0.0, 1.0, "speaker_0"),
                (1.0, 2.0, "speaker_1"),
                (2.0, 3.0, "speaker_0"),
                ...
            ],
            ...
        }


    :param data_path:               The path to the audio files to diarize. Can be a path to a single file, a path to a
                                    directory or a list of paths to files.
    :param use_onnx:                Whether to use ONNX for inference. Default is True.
    :param force_onnx_cpu:          Whether to force ONNX to use CPU for inference. Default is True.
    :param threshold:               Speech threshold. Silero VAD outputs speech probabilities for each audio chunk,
                                    probabilities ABOVE this value are considered as SPEECH. It is better to tune
                                    this parameter for each dataset separately, but "lazy" 0.5 is pretty good for
                                    most datasets.
    :param sampling_rate:           Currently, silero VAD models support 8000 and 16000 sample rates.
    :param min_speech_duration_ms:  Final speech chunks shorter min_speech_duration_ms are thrown out.
    :param max_speech_duration_s:   Maximum duration of speech chunks in seconds. Chunks longer than
                                    `max_speech_duration_s` will be split at the timestamp of the last silence that
                                    lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will
                                    be split aggressively just before max_speech_duration_s.
    :param min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms before separating
                                    it.
    :param window_size_samples:     Audio chunks of window_size_samples size are fed to the silero VAD model.

                                    WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000
                                    sample rate and 256, 512, 768 samples for 8000 sample rate. Values other than
                                    these may affect model performance!
    :param speech_pad_ms:           Final speech chunks are padded by speech_pad_ms each side.
    :param speaker_labels:          The speaker labels to use for the diarization. If not given, the speakers will be
                                    named "speaker_0", "speaker_1", etc.
    :param use_multiprocessing:     The number of workers to use for multiprocessing. If 0, no multiprocessing will
                                    be used. Default is 0.
    :param verbose:                 Verbosity.

    :returns: A tuple of two dictionaries: the successes (file name to its diarization) and the errors (file name to
              its error message).
    """
    # Collect the audio files to diarize:
    if verbose:
        _LOGGER.info("Collecting audio files.")
    audio_files = _get_audio_files(data_path=data_path)
    if verbose:
        _LOGGER.info(f"Collected {len(audio_files)} audio files.")

    # Gather the arguments shared by both runners. The VAD is always configured to return the timestamps in seconds
    # and per channel, as the diarization assigns a speaker to each channel:
    run_kwargs = dict(
        audio_files=audio_files,
        description="Diarizing",
        vad_init_kwargs=dict(
            use_onnx=use_onnx,
            force_onnx_cpu=force_onnx_cpu,
            threshold=threshold,
            sampling_rate=sampling_rate,
            min_speech_duration_ms=min_speech_duration_ms,
            max_speech_duration_s=max_speech_duration_s,
            min_silence_duration_ms=min_silence_duration_ms,
            window_size_samples=window_size_samples,
            speech_pad_ms=speech_pad_ms,
            return_seconds=True,
            per_channel=True,
        ),
        task_creator=TaskCreator(
            task_type=SpeechDiarizationTask,
            task_kwargs={"speaker_labels": speaker_labels},
        ),
        verbose=verbose,
    )

    # Run the diarization - in multiple worker processes if requested, otherwise in this process:
    if use_multiprocessing:
        results = _parallel_run(n_workers=use_multiprocessing, **run_kwargs)
    else:
        results = _run(**run_kwargs)

    # Split the results into successes and errors:
    return _process_results(results=results, verbose=verbose)


def _get_audio_files(
    data_path: Union[Path, str, list],
) -> List[Path]:
    """
    Get the audio files from the data path. If a path to a directory is given, all files in the directory will be
    collected.

    :param data_path: The data path to collect the audio files from.

    :returns: The audio files list.
    """
    # Check if given a list of paths:
    if isinstance(data_path, list):
        audio_files = []
        for path in data_path:
            audio_files.extend(_get_audio_files(data_path=path))
        return audio_files

    # Check if given a single string path to cast it to a `pathlib.Path`:
    if isinstance(data_path, str):
        data_path = Path(data_path).absolute()

    # Check if the path is of a directory or a file:
    if data_path.is_dir():
        # Get all files inside the directory:
        audio_files = list(data_path.glob("*.*"))
    elif data_path.is_file():
        audio_files = [data_path]
    else:
        raise ValueError(
            f"Unrecognized data path. The parameter `data_path` must be a valid path to either a directory path or a "
            f"file. Given: {str(data_path)} "
        )

    return audio_files


def _run(
    audio_files: List[Path],
    description: str,
    vad_init_kwargs: dict,
    task_creator: TaskCreator,
    verbose: bool,
) -> List[Tuple[bool, Tuple[str, list]]]:
    """
    Load a single VAD in the current process and use it to complete, one file after the other, the tasks created by
    the given task creator on the provided files.

    :param audio_files:     The audio files to use.
    :param description:     The description to use for the progress bar.
    :param vad_init_kwargs: The VAD initialization keyword arguments.
    :param task_creator:    The task creator to use to create the tasks.
    :param verbose:         Verbosity.

    :returns: The collected results. Each result is a tuple of an error flag and a payload: the task result on
              success, or the audio file name and the error message on failure.
    """
    # Initialize and load the VAD:
    vad = VoiceActivityDetector(**vad_init_kwargs)
    if verbose:
        _LOGGER.info("Loading the VAD model.")
    vad.load()
    if verbose:
        _LOGGER.info("VAD model loaded.")

    def _complete(audio_file: Path) -> Tuple[bool, Tuple[str, list]]:
        # Create the task, run the file through the VAD and complete the task with the detected timestamps. Any
        # failure is collected as an error so the rest of the files can still be processed:
        try:
            task = task_creator.create_task(audio_file=audio_file)
            speech_timestamps = vad.detect_voice(audio_file=audio_file)
            task.do_task(speech_timestamps=speech_timestamps)
            return False, task.get_result()
        except Exception as exception:
            return True, (audio_file.name, str(exception))

    # Go over the audio files (with a progress bar when verbose) and collect the results:
    progress_bar = tqdm(
        audio_files,
        desc=description,
        unit="file",
        total=len(audio_files),
        disable=not verbose,
    )
    return [_complete(audio_file) for audio_file in progress_bar]


def _parallel_run(
    n_workers: int,
    audio_files: List[Path],
    description: str,
    vad_init_kwargs: dict,
    task_creator: TaskCreator,
    verbose: bool,
) -> List[Tuple[bool, Tuple[str, list]]]:
    """
    Run multiple VAD workers with multiprocessing to complete the tasks that will be created on the provided files using
    the given task creator.

    :param n_workers:       The number of workers to use.
    :param audio_files:     The audio files to use.
    :param description:     The description to use for the progress bar.
    :param vad_init_kwargs: The VAD initialization keyword arguments.
    :param task_creator:    The task creator to use to create the tasks.
    :param verbose:         Verbosity.

    :returns: The collected results. Each result is a tuple of an error flag and a payload: the task result on
              success, or the audio file name and the error message on failure.
    """
    # Nothing to do without files. Returning here also avoids starting zero workers and then waiting forever for
    # stop marks that will never arrive:
    if not audio_files:
        return []

    # Check the number of workers:
    if n_workers > len(audio_files):
        _LOGGER.warning(
            f"The number of workers ({n_workers}) is larger than the number of audio files ({len(audio_files)}). "
            f"Setting the number of workers to {len(audio_files)}."
        )
        n_workers = len(audio_files)

    # Initialize the multiprocessing queues (the tasks queue is bounded so the progress bar follows the workers):
    tasks_queue = Queue(maxsize=n_workers * 2)
    results_queue = Queue()

    # Initialize the multiprocessing processes:
    task_completion_processes = [
        Process(
            target=_multiprocessing_complete_tasks,
            kwargs={
                "vad_init_kwargs": vad_init_kwargs,
                "tasks_queue": tasks_queue,
                "results_queue": results_queue,
            },
        )
        for _ in range(n_workers)
    ]

    # Start the multiprocessing processes:
    if verbose:
        _LOGGER.info("Loading the VAD model (per process).")
    for p in task_completion_processes:
        p.start()
    if verbose:
        _LOGGER.info("VAD model loaded.")

    # Put the tasks in the queue (the progress bar is not accurate as it is updating by the queue which has a max size
    # of 2*n_workers so the progress bar won't be too off - better than nothing):
    for audio_file in tqdm(
        audio_files,
        desc=description,
        unit="file",
        total=len(audio_files),
        disable=not verbose,
    ):
        # Put the task in the queue:
        tasks_queue.put(task_creator.create_task(audio_file=audio_file))

    # Put the stop marks in the queue (one per worker):
    for _ in range(n_workers):
        tasks_queue.put(_MULTIPROCESSING_STOP_MARK)

    # Collect the results until every worker sent its stop mark. The results queue must be drained BEFORE joining
    # the processes: a process that put items in a queue does not terminate until all of its buffered items were
    # flushed to the underlying pipe, so joining first may deadlock on large results:
    results = []
    stop_marks_counter = 0
    while stop_marks_counter < n_workers:
        # Get a result from the queue:
        result: Tuple[bool, Tuple[str, list]] = results_queue.get()
        if result == _MULTIPROCESSING_STOP_MARK:
            stop_marks_counter += 1
        else:
            # Collect the result:
            results.append(result)

    # Wait for the processes to finish:
    for p in task_completion_processes:
        p.join()

    return results


def _process_results(
    results: List[Tuple[bool, Tuple[str, list]]], verbose: bool
) -> Tuple[dict, dict]:
    """
    Process the results of the tasks.

    :param results: The results to process.
    :param verbose: Verbosity.

    :returns: The processed results as a tuple of successes and errors.
    """
    if verbose:
        _LOGGER.info("Summarizing the results.")
    successes = {}
    errors = {}
    for is_error, result in results:
        if is_error:
            errors[result[0]] = result[1]
        else:
            successes[result[0]] = result[1]
    if verbose:
        _LOGGER.info(f"Done ({len(successes)}/{len(successes) + len(errors)})\n")

    return successes, errors
 + functionSourceCode: # Copyright 2024 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from multiprocessing import Process, Queue
from pathlib import Path
from types import FunctionType
from typing import Dict, List, Tuple, Type, Union

import torch
import torchaudio
from tqdm import tqdm


class BaseTask:
    """
    The base of all tasks to complete after the VAD. A task is assigned to a single audio file and holds the result
    computed from the file's speech timestamps.
    """

    def __init__(self, audio_file: Path):
        """
        Initialize the base task.

        :param audio_file: The audio file assigned to the task.
        """
        # The result is available only after `do_task` was called:
        self._result = None

        # Keep the audio file the task is assigned to:
        self._audio_file = audio_file

    @property
    def audio_file(self) -> Path:
        """
        Get the audio file of the task.

        :returns: The audio file of the task.
        """
        return self._audio_file

    def do_task(
        self, speech_timestamps: Union[List[Dict[str, int]], List[List[Dict[str, int]]]]
    ):
        """
        Do the task on the given speech timestamps. The base task keeps the speech timestamps untouched as its
        result.

        :param speech_timestamps: The speech timestamps to do the task on as outputted from the VAD.
        """
        self._result = speech_timestamps

    def get_result(self) -> Tuple[str, list]:
        """
        Get the result of the task. A tuple of the audio file name and the result.

        :returns: The result of the task.
        """
        file_name = self._audio_file.name
        return file_name, self._result

    def to_tuple(self) -> Tuple[str, dict]:
        """
        Convert the task to a tuple to reconstruct it later (used for multiprocessing to pass in queue). The tuple
        holds the task's class name and the keyword arguments to initialize it with.

        :returns: The converted task.
        """
        init_kwargs = {"audio_file": self._audio_file}
        return type(self).__name__, init_kwargs


class SpeechDiarizationTask(BaseTask):
    """
    A speech diarization task. The task will diarize the VAD speech timestamps into speakers, treating each channel
    as a different speaker.
    """

    def __init__(self, audio_file: Path, speaker_labels: List[str] = None):
        """
        Initialize the speech diarization task.

        :param audio_file:     The audio file assigned to the task.
        :param speaker_labels: The speaker labels to use for the diarization. If not given, the speakers will be named
                               "speaker_0", "speaker_1", etc. Channels beyond the given labels get these default
                               names as well.
        """
        super().__init__(audio_file=audio_file)
        self._speaker_labels = speaker_labels

    def do_task(self, speech_timestamps: List[List[Dict[str, int]]]):
        """
        Do the task on the given speech timestamps. The task will diarize the VAD speech timestamps into speakers.
        The result is stored as a list of (start, end, speaker_label) tuples sorted by start time.

        :param speech_timestamps: The speech timestamps per channel to do the task on as outputted from the VAD.
        """
        # Get the speaker labels. Every channel that was not given a label falls back to the default
        # "speaker_{i}" name - otherwise the `zip` below would silently drop the speech of the unlabeled channels:
        given_labels = list(self._speaker_labels or [])
        speaker_labels = [
            given_labels[i] if i < len(given_labels) else f"speaker_{i}"
            for i in range(len(speech_timestamps))
        ]

        # Diarize - organize the speech timestamps into a single list of speakers and sort it by start time (ties
        # are broken by the end time and then by the label, as the tuples are compared element by element):
        speech_diarization = [
            (speech_timestamp["start"], speech_timestamp["end"], speaker_label)
            for speaker_label, channel_speech_timestamps in zip(
                speaker_labels, speech_timestamps
            )
            for speech_timestamp in channel_speech_timestamps
        ]
        speech_diarization.sort()
        self._result = speech_diarization

    def to_tuple(self) -> Tuple[str, dict]:
        """
        Convert the task to a tuple to reconstruct it later (used for multiprocessing to pass in queue). The speaker
        labels are added to the keyword arguments collected by the base task.

        :returns: The converted task.
        """
        task_class, task_kwargs = super().to_tuple()
        return task_class, {**task_kwargs, "speaker_labels": self._speaker_labels}


class TaskCreator:
    """
    A factory of the tasks to complete after the VAD. It creates a task per audio file and can reconstruct a task
    out of the tuple form it was converted to.
    """

    #: A map from task class name to task class to use in `from_tuple`:
    _MAP = {
        task_class.__name__: task_class
        for task_class in (BaseTask, SpeechDiarizationTask)
    }

    def __init__(self, task_type: Type[BaseTask], task_kwargs: dict = None):
        """
        Initialize the task creator.

        :param task_type:   The task type - a `BaseTask` subclass.
        :param task_kwargs: Additional keyword arguments to pass to each created task. If not given, no additional
                            keyword arguments are passed.
        """
        # Fall back to an empty dictionary when no keyword arguments were given:
        self._task_kwargs = task_kwargs if task_kwargs else {}
        self._task_type = task_type

    def create_task(self, audio_file: Path) -> BaseTask:
        """
        Create a task of the configured type for the given audio file.

        :param audio_file: The audio file to assign to the task.

        :returns: The created task.
        """
        task_class = self._task_type
        return task_class(audio_file=audio_file, **self._task_kwargs)

    @classmethod
    def from_tuple(cls, task_tuple: Tuple[str, dict]) -> BaseTask:
        """
        Create a task from a tuple of the task class name and the task kwargs.

        :param task_tuple: The task tuple to create the task from.

        :returns: The created task.
        """
        class_name, init_kwargs = task_tuple
        task_class = cls._MAP[class_name]
        return task_class(**init_kwargs)


class VoiceActivityDetector:
    """
    A voice activity detection wrapper for the silero VAD model - https://github.com/snakers4/silero-vad.
    """

    def __init__(
        self,
        # Model loading kwargs:
        use_onnx: bool = True,
        force_onnx_cpu: bool = True,
        # Detection kwargs:
        threshold: float = 0.5,
        sampling_rate: int = 16_000,
        min_speech_duration_ms: int = 250,
        max_speech_duration_s: float = float("inf"),
        min_silence_duration_ms: int = 100,
        window_size_samples: int = 512,
        speech_pad_ms: int = 30,
        return_seconds: bool = False,
        per_channel: bool = False,
    ):
        """
        Initialize the voice activity detector. The model itself is not loaded here - call `load` before detecting.

        :param use_onnx:                Whether to use ONNX for inference. Default is True.
        :param force_onnx_cpu:          Whether to force ONNX to use CPU for inference. Default is True.
        :param threshold:               Speech threshold. Silero VAD outputs speech probabilities for each audio chunk,
                                        probabilities ABOVE this value are considered as SPEECH. It is better to tune
                                        this parameter for each dataset separately, but "lazy" 0.5 is pretty good for
                                        most datasets.
        :param sampling_rate:           Currently, silero VAD models support 8000 and 16000 sample rates.
        :param min_speech_duration_ms:  Final speech chunks shorter than min_speech_duration_ms are thrown out.
        :param max_speech_duration_s:   Maximum duration of speech chunks in seconds. Chunks longer than
                                        `max_speech_duration_s` will be split at the timestamp of the last silence that
                                        lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise,
                                        they will be split aggressively just before max_speech_duration_s.
        :param min_silence_duration_ms: At the end of each speech chunk, wait for min_silence_duration_ms before
                                        separating it.
        :param window_size_samples:     Audio chunks of window_size_samples size are fed to the silero VAD model.
                                        WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000
                                        sample rate and 256, 512, 768 samples for 8000 sample rate. Values other than
                                        these may affect model performance!
        :param speech_pad_ms:           Final speech chunks are padded by speech_pad_ms on each side.
        :param return_seconds:          Whether to return timestamps in seconds. False means to return timestamps in
                                        samples (default - False).
        :param per_channel:             Whether to return timestamps per channel (default - False). This will run VAD
                                        on each channel separately and return a list of timestamps per channel.
        """
        # Model loading configurations:
        self._use_onnx = use_onnx
        self._force_onnx_cpu = force_onnx_cpu

        # Detection configurations:
        self._threshold = threshold
        self._sampling_rate = sampling_rate
        self._min_speech_duration_ms = min_speech_duration_ms
        self._max_speech_duration_s = max_speech_duration_s
        self._min_silence_duration_ms = min_silence_duration_ms
        self._window_size_samples = window_size_samples
        self._speech_pad_ms = speech_pad_ms
        self._return_seconds = return_seconds
        self._per_channel = per_channel

        # The model and its timestamps function are set only when `load` is called:
        self._model: torch.nn.Module = None
        self._get_speech_timestamps: FunctionType = None

    def load(self, force_reload: bool = True):
        """
        Load the VAD model and its speech timestamps function through `torch.hub`.

        :param force_reload: Whether to force reload the model even if it was already loaded. Default is True.
        """
        model, utils = torch.hub.load(
            repo_or_dir="snakers4/silero-vad",
            model="silero_vad",
            force_reload=force_reload,
            onnx=self._use_onnx,
            force_onnx_cpu=self._force_onnx_cpu,
        )
        self._model = model

        # Only the first utility (the speech timestamps function) is used, the rest are unpacked and ignored:
        (
            get_speech_timestamps,
            _save_audio,
            _read_audio,
            _vad_iterator,
            _collect_chunks,
        ) = utils
        self._get_speech_timestamps = get_speech_timestamps

    def detect_voice(
        self,
        audio_file: Path,
    ) -> Union[List[Dict[str, int]], List[List[Dict[str, int]]]]:
        """
        Infer the audio through the VAD model and return the speech timestamps.

        :param audio_file: The audio file to infer.

        :returns: The speech timestamps in the audio. A list of timestamps where each timestamp is a dictionary with the
                 following keys:

                 * "start": The start sample index of the speech in the audio.
                 * "end":   The end sample index of the speech in the audio.

                 If `per_channel` is True, a list of timestamps per channel will be returned.
        """
        # Read the audio as a tensor:
        audio = self._read_audio(audio_file)

        # The detection keyword arguments are the same whether running on the whole audio or on a single channel:
        detection_kwargs = {
            "threshold": self._threshold,
            "min_speech_duration_ms": self._min_speech_duration_ms,
            "max_speech_duration_s": self._max_speech_duration_s,
            "min_silence_duration_ms": self._min_silence_duration_ms,
            "speech_pad_ms": self._speech_pad_ms,
            "sampling_rate": self._sampling_rate,
            "window_size_samples": self._window_size_samples,
            "return_seconds": self._return_seconds,
        }

        # Per channel - detect speech on each channel separately:
        if self._per_channel:
            return [
                self._get_speech_timestamps(channel, self._model, **detection_kwargs)
                for channel in audio
            ]

        # Detect speech on the entire audio:
        return self._get_speech_timestamps(audio, self._model, **detection_kwargs)

    def _read_audio(
        self,
        path: Path,
    ) -> torch.Tensor:
        """
        Read the audio from the given path and return it as a tensor in the configured sampling rate.

        :param path: The path to the audio file.

        :returns: The audio as a tensor. If `per_channel` is False, the channels are averaged into one and the first
                  dimension is squeezed.
        """
        # Read the audio:
        waveform, original_rate = torchaudio.load(str(path))

        # Average the channels of a multi-channel audio into one (only if not per channel):
        if not self._per_channel and waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Resample the audio if it was not recorded in the configured sampling rate:
        if original_rate != self._sampling_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=original_rate, new_freq=self._sampling_rate
            )
            waveform = resampler(waveform)

        # Keep the channels dimension only when working per channel:
        if self._per_channel:
            return waveform
        return waveform.squeeze(0)


#: The value to send into multiprocessing queues to stop the process. A worker stops listening to its tasks queue once
#: it gets this mark, and puts the mark in the results queue to signal it has finished:
_MULTIPROCESSING_STOP_MARK = "STOP"


def _multiprocessing_complete_tasks(
    vad_init_kwargs: dict, tasks_queue: Queue, results_queue: Queue
):
    """
    Complete the tasks in the given queue and put the results in the given results queue. The function will stop when
    the given tasks queue will receive the stop mark. It is aimed to be used with multiprocessing as a process.

    Each item put in the results queue is a tuple of an error flag and a payload: `(False, <task result>)` for a
    completed task or `(True, (<file name>, <error message>))` for a failed one. Once the stop mark is received, the
    stop mark is put in the results queue as well to signal that this worker is done.

    :param vad_init_kwargs: The VAD initialization kwargs.
    :param tasks_queue:     A queue to get the tasks from.
    :param results_queue:   A queue to put the results in.
    """
    # Initialize and load the VAD:
    vad = VoiceActivityDetector(**vad_init_kwargs)
    vad.load(force_reload=False)

    # Start listening to the tasks queue:
    while True:
        # Get the task (in its tuple form):
        task_tuple: Tuple[str, dict] = tasks_queue.get()
        if task_tuple == _MULTIPROCESSING_STOP_MARK:
            break
        # The task object exists only after a successful reconstruction from the tuple:
        task = None
        try:
            # Create the task:
            task = TaskCreator.from_tuple(task_tuple=task_tuple)
            # Run the file through the VAD:
            speech_timestamps = vad.detect_voice(audio_file=task.audio_file)
            # Complete the task:
            task.do_task(speech_timestamps=speech_timestamps)
            # Build the result:
            result = (False, task.get_result())
        except Exception as exception:
            # Build the error. The file name must be resolved without raising: an exception escaping from here would
            # kill the worker before it sends its stop mark, leaving the main process waiting on the results queue
            # forever. If the task creation itself failed, there is no task object, so the name is taken from the
            # raw tuple instead:
            try:
                if task is not None:
                    file_name = task.audio_file.name
                else:
                    file_name = Path(str(task_tuple[1]["audio_file"])).name
            except Exception:
                file_name = str(task_tuple)
            result = (True, (file_name, str(exception)))
        # Collect the result / error:
        results_queue.put(result)

    # Mark the end of the tasks:
    results_queue.put(_MULTIPROCESSING_STOP_MARK)


# Get the global logger. When `mlrun` can be imported, the logger of the "silero_vad" MLRun context is used. Otherwise,
# fall back to the root logger of the standard `logging` module:
try:
    import mlrun

    _LOGGER = mlrun.get_or_create_ctx("silero_vad").logger
except ModuleNotFoundError:
    _LOGGER = logging.getLogger()


def detect_voice(
    # Input kwargs:
    data_path: Union[str, Path, List[Union[str, Path]]],
    # Model loading kwargs:
    use_onnx: bool = True,
    force_onnx_cpu: bool = True,
    # Detection kwargs:
    threshold: float = 0.5,
    sampling_rate: int = 16_000,
    min_speech_duration_ms: int = 250,
    max_speech_duration_s: float = float("inf"),
    min_silence_duration_ms: int = 100,
    window_size_samples: int = 512,
    speech_pad_ms: int = 30,
    return_seconds: bool = False,
    per_channel: bool = False,
    # Other kwargs:
    use_multiprocessing: int = 0,
    verbose: bool = False,
):
    """
    Perform voice activity detection on given audio files using the silero VAD model -
    https://github.com/snakers4/silero-vad. The detection results are collected into a dictionary with the file names
    as keys and their VAD timestamps dictionaries as value.

    For example::

        {
            "file_1.wav": [
                {"start": 0, "end": 16000},
                {"start": 16000, "end": 32000},
                {"start": 32000, "end": 48000},
                ...
            ],
            "file_2.wav": [
                {"start": 0, "end": 16000},
                {"start": 16000, "end": 32000},
                {"start": 32000, "end": 48000},
                ...
            ],
            ...
        }


    :param data_path:               The path to the audio files to run the VAD on. Can be a path to a single file, a
                                    path to a directory or a list of paths to files.
    :param use_onnx:                Whether to use ONNX for inference. Default is True.
    :param force_onnx_cpu:          Whether to force ONNX to use CPU for inference. Default is True.
    :param threshold:               Speech threshold. Silero VAD outputs speech probabilities for each audio chunk,
                                    probabilities ABOVE this value are considered as SPEECH. It is better to tune
                                    this parameter for each dataset separately, but "lazy" 0.5 is pretty good for
                                    most datasets.
    :param sampling_rate:           Currently, silero VAD models support 8000 and 16000 sample rates.
    :param min_speech_duration_ms:  Final speech chunks shorter than min_speech_duration_ms are thrown out.
    :param max_speech_duration_s:   Maximum duration of speech chunks in seconds. Chunks longer than
                                    `max_speech_duration_s` will be split at the timestamp of the last silence that
                                    lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will
                                    be split aggressively just before max_speech_duration_s.
    :param min_silence_duration_ms: At the end of each speech chunk, wait for min_silence_duration_ms before
                                    separating it.
    :param window_size_samples:     Audio chunks of window_size_samples size are fed to the silero VAD model.

                                    WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000
                                    sample rate and 256, 512, 768 samples for 8000 sample rate. Values other than
                                    these may affect model performance!
    :param speech_pad_ms:           Final speech chunks are padded by speech_pad_ms on each side.
    :param return_seconds:          Whether to return timestamps in seconds. False means to return timestamps in
                                    samples (default - False).
    :param per_channel:             Whether to return timestamps per channel (default - False). This will run VAD on
                                    each channel separately and return a list of timestamps per channel.
    :param use_multiprocessing:     The number of workers to use for multiprocessing. If 0, no multiprocessing will
                                    be used. Default is 0.
    :param verbose:                 Verbosity.

    :returns: A tuple of two dictionaries: the successes (file name to its VAD timestamps) and the errors (file name
              to the error message of the file's failed run).
    """
    # Collect the audio files to run the VAD on:
    if verbose:
        _LOGGER.info("Collecting audio files.")
    audio_files = _get_audio_files(data_path=data_path)
    if verbose:
        _LOGGER.info(f"Collected {len(audio_files)} audio files.")

    # Gather the VAD initialization keyword arguments:
    vad_init_kwargs = dict(
        use_onnx=use_onnx,
        force_onnx_cpu=force_onnx_cpu,
        threshold=threshold,
        sampling_rate=sampling_rate,
        min_speech_duration_ms=min_speech_duration_ms,
        max_speech_duration_s=max_speech_duration_s,
        min_silence_duration_ms=min_silence_duration_ms,
        window_size_samples=window_size_samples,
        speech_pad_ms=speech_pad_ms,
        return_seconds=return_seconds,
        per_channel=per_channel,
    )

    # Gather the arguments shared by both run modes (the tasks are created using the base task type):
    run_kwargs = dict(
        audio_files=audio_files,
        description="Detecting voice",
        vad_init_kwargs=vad_init_kwargs,
        task_creator=TaskCreator(task_type=BaseTask),
        verbose=verbose,
    )

    # Run the voice detection - with multiple workers if requested, otherwise in the current process:
    if use_multiprocessing:
        results = _parallel_run(n_workers=use_multiprocessing, **run_kwargs)
    else:
        results = _run(**run_kwargs)

    # Split the collected results into successes and errors:
    return _process_results(results=results, verbose=verbose)


def diarize(
    # Input / Output kwargs:
    data_path: Union[str, Path, List[Union[str, Path]]],
    # Model loading kwargs:
    use_onnx: bool = True,
    force_onnx_cpu: bool = True,
    # Detection kwargs:
    threshold: float = 0.5,
    sampling_rate: int = 16_000,
    min_speech_duration_ms: int = 250,
    max_speech_duration_s: float = float("inf"),
    min_silence_duration_ms: int = 100,
    window_size_samples: int = 512,
    speech_pad_ms: int = 30,
    # Diarization kwargs:
    speaker_labels: List[str] = None,
    # Other kwargs:
    use_multiprocessing: int = 0,
    verbose: bool = False,
):
    """
    Perform speech diarization on given audio files using the silero VAD model - https://github.com/snakers4/silero-vad.
    The speech diarization is performed per channel so that each channel in the audio belongs to a different speaker.
    The diarization results are collected into a dictionary with the file names as keys and their diarization as
    value. A diarization is a list of tuples: (start, end, speaker_label).

    For example::

        {
            "file_1.wav": [
                (0.0, 1.0, "speaker_0"),
                (1.0, 2.0, "speaker_1"),
                (2.0, 3.0, "speaker_0"),
                ...
            ],
            "file_2.wav": [
                (0.0, 1.0, "speaker_0"),
                (1.0, 2.0, "speaker_1"),
                (2.0, 3.0, "speaker_0"),
                ...
            ],
            ...
        }


    :param data_path:               The path to the audio files to diarize. Can be a path to a single file, a path to a
                                    directory or a list of paths to files.
    :param use_onnx:                Whether to use ONNX for inference. Default is True.
    :param force_onnx_cpu:          Whether to force ONNX to use CPU for inference. Default is True.
    :param threshold:               Speech threshold. Silero VAD outputs speech probabilities for each audio chunk,
                                    probabilities ABOVE this value are considered as SPEECH. It is better to tune
                                    this parameter for each dataset separately, but "lazy" 0.5 is pretty good for
                                    most datasets.
    :param sampling_rate:           Currently, silero VAD models support 8000 and 16000 sample rates.
    :param min_speech_duration_ms:  Final speech chunks shorter than min_speech_duration_ms are thrown out.
    :param max_speech_duration_s:   Maximum duration of speech chunks in seconds. Chunks longer than
                                    `max_speech_duration_s` will be split at the timestamp of the last silence that
                                    lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will
                                    be split aggressively just before max_speech_duration_s.
    :param min_silence_duration_ms: At the end of each speech chunk, wait for min_silence_duration_ms before
                                    separating it.
    :param window_size_samples:     Audio chunks of window_size_samples size are fed to the silero VAD model.

                                    WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000
                                    sample rate and 256, 512, 768 samples for 8000 sample rate. Values other than
                                    these may affect model performance!
    :param speech_pad_ms:           Final speech chunks are padded by speech_pad_ms on each side.
    :param speaker_labels:          The speaker labels to use for the diarization. If not given, the speakers will be
                                    named "speaker_0", "speaker_1", etc.
    :param use_multiprocessing:     The number of workers to use for multiprocessing. If 0, no multiprocessing will
                                    be used. Default is 0.
    :param verbose:                 Verbosity.

    :returns: A tuple of two dictionaries: the successes (file name to its diarization) and the errors (file name to
              the error message of the file's failed run).
    """
    # Collect the audio files to diarize:
    if verbose:
        _LOGGER.info("Collecting audio files.")
    audio_files = _get_audio_files(data_path=data_path)
    if verbose:
        _LOGGER.info(f"Collected {len(audio_files)} audio files.")

    # Gather the VAD initialization keyword arguments. The diarization always runs the VAD per channel and reads the
    # timestamps in seconds:
    vad_init_kwargs = dict(
        use_onnx=use_onnx,
        force_onnx_cpu=force_onnx_cpu,
        threshold=threshold,
        sampling_rate=sampling_rate,
        min_speech_duration_ms=min_speech_duration_ms,
        max_speech_duration_s=max_speech_duration_s,
        min_silence_duration_ms=min_silence_duration_ms,
        window_size_samples=window_size_samples,
        speech_pad_ms=speech_pad_ms,
        return_seconds=True,
        per_channel=True,
    )

    # Gather the arguments shared by both run modes (the tasks are speech diarization tasks with the given labels):
    run_kwargs = dict(
        audio_files=audio_files,
        description="Diarizing",
        vad_init_kwargs=vad_init_kwargs,
        task_creator=TaskCreator(
            task_type=SpeechDiarizationTask,
            task_kwargs={"speaker_labels": speaker_labels},
        ),
        verbose=verbose,
    )

    # Run the diarization - with multiple workers if requested, otherwise in the current process:
    if use_multiprocessing:
        results = _parallel_run(n_workers=use_multiprocessing, **run_kwargs)
    else:
        results = _run(**run_kwargs)

    # Split the collected results into successes and errors:
    return _process_results(results=results, verbose=verbose)


def _get_audio_files(
    data_path: Union[Path, str, list],
) -> List[Path]:
    """
    Collect the audio files of the given data path. A directory yields all of its direct entries matching the `*.*`
    pattern, a file yields itself and a list yields the collected files of each of its items.

    :param data_path: The data path to collect the audio files from.

    :returns: The audio files list.

    :raises ValueError: If the path is neither an existing directory nor an existing file.
    """
    # A list is flattened by collecting the files of each of its items:
    if isinstance(data_path, list):
        return [
            audio_file
            for item in data_path
            for audio_file in _get_audio_files(data_path=item)
        ]

    # A string path is cast into an absolute `pathlib.Path`:
    if isinstance(data_path, str):
        data_path = Path(data_path).absolute()

    # A single file is returned as is:
    if data_path.is_file():
        return [data_path]

    # A directory is expanded to the files inside of it:
    if data_path.is_dir():
        return list(data_path.glob("*.*"))

    raise ValueError(
        f"Unrecognized data path. The parameter `data_path` must be a valid path to either a directory path or a "
        f"file. Given: {str(data_path)} "
    )


def _run(
    audio_files: List[Path],
    description: str,
    vad_init_kwargs: dict,
    task_creator: TaskCreator,
    verbose: bool,
) -> List[Tuple[bool, Tuple[str, list]]]:
    """
    Load a VAD in the current process and use it to complete, one file after the other, the tasks created on the
    provided files by the given task creator.

    :param audio_files:     The audio files to use.
    :param description:     The description to use for the progress bar.
    :param vad_init_kwargs: The VAD initialization keyword arguments.
    :param task_creator:    The task creator to use to create the tasks.
    :param verbose:         Verbosity.

    :returns: The collected results. Each result is a tuple of an error flag and a payload: the task result for a
              completed task, or the file name and the error message for a failed one.
    """
    # Load the VAD:
    detector = VoiceActivityDetector(**vad_init_kwargs)
    if verbose:
        _LOGGER.info("Loading the VAD model.")
    detector.load()
    if verbose:
        _LOGGER.info("VAD model loaded.")

    # Go over the audio files (the progress bar is shown only in verbose mode):
    results = []
    progressbar = tqdm(
        audio_files,
        desc=description,
        unit="file",
        total=len(audio_files),
        disable=not verbose,
    )
    for audio_file in progressbar:
        try:
            # Create the task, run the file through the VAD and complete the task with the detected timestamps:
            task = task_creator.create_task(audio_file=audio_file)
            timestamps = detector.detect_voice(audio_file=audio_file)
            task.do_task(speech_timestamps=timestamps)
            outcome = (False, task.get_result())
        except Exception as exception:
            # A failure of one file is recorded as an error, so the run continues to the next files:
            outcome = (True, (audio_file.name, str(exception)))
        results.append(outcome)

    return results


def _parallel_run(
    n_workers: int,
    audio_files: List[Path],
    description: str,
    vad_init_kwargs: dict,
    task_creator: TaskCreator,
    verbose: bool,
) -> List[Tuple[bool, Tuple[str, list]]]:
    """
    Run multiple VAD workers with multiprocessing to complete the tasks that will be created on the provided files using
    the given task creator.

    :param n_workers:       The number of workers to use. If it is larger than the number of audio files, it is cut
                            down to the number of audio files.
    :param audio_files:     The audio files to use.
    :param description:     The description to use for the progress bar.
    :param vad_init_kwargs: The VAD initialization keyword arguments.
    :param task_creator:    The task creator to use to create the tasks.
    :param verbose:         Verbosity.

    :returns: The collected results. An empty list is returned if no audio files were given.

    :raises ValueError: If the number of workers is smaller than 1.
    """
    # Without workers, no stop mark will ever be put in the results queue, so collecting the results would block
    # forever:
    if n_workers < 1:
        raise ValueError(
            f"The number of workers must be at least 1. Given: {n_workers}"
        )

    # Without audio files there is nothing to run. Returning here also avoids cutting the number of workers down to 0,
    # in which case no process would be started and the results collection would wait forever:
    if not audio_files:
        return []

    # Load the VAD (download once, and it will be loaded then per process later on):
    if verbose:
        _LOGGER.info("Loading the VAD model.")
    vad = VoiceActivityDetector(**vad_init_kwargs)
    vad.load()
    if verbose:
        _LOGGER.info("VAD model loaded.")

    # Check the number of workers:
    if n_workers > len(audio_files):
        _LOGGER.warning(
            f"The number of workers ({n_workers}) is larger than the number of audio files ({len(audio_files)}). "
            f"Setting the number of workers to {len(audio_files)}."
        )
        n_workers = len(audio_files)

    # Initialize the multiprocessing queues:
    tasks_queue = Queue()
    results_queue = Queue()

    # Initialize the multiprocessing processes:
    task_completion_processes = [
        Process(
            target=_multiprocessing_complete_tasks,
            kwargs={
                "vad_init_kwargs": vad_init_kwargs,
                "tasks_queue": tasks_queue,
                "results_queue": results_queue,
            },
        )
        for _ in range(n_workers)
    ]

    # Start the multiprocessing processes:
    for p in task_completion_processes:
        p.start()

    # Put the tasks in the queue (in their tuple form, to be reconstructed in the workers):
    for audio_file in audio_files:
        tasks_queue.put(task_creator.create_task(audio_file=audio_file).to_tuple())

    # Put the stop marks in the queue (one per worker):
    for _ in range(n_workers):
        tasks_queue.put(_MULTIPROCESSING_STOP_MARK)

    # Collect the results until every worker has sent its stop mark. The collection is done before joining the
    # processes, so the progress bar follows the actual completion of the tasks:
    results = []
    stop_marks_counter = 0
    with tqdm(
        desc=description,
        unit="file",
        total=len(audio_files),
        disable=not verbose,
    ) as progressbar:
        while stop_marks_counter < n_workers:
            # Get a result from the queue:
            result: Tuple[bool, Tuple[str, list]] = results_queue.get()
            if result == _MULTIPROCESSING_STOP_MARK:
                stop_marks_counter += 1
                continue
            # Collect the result:
            results.append(result)
            progressbar.update(1)

    # Wait for the processes to finish:
    for p in task_completion_processes:
        p.join()

    return results


def _process_results(
    results: List[Tuple[bool, Tuple[str, list]]], verbose: bool
) -> Tuple[dict, dict]:
    """
    Split the results of the tasks into successes and errors.

    :param results: The results to process. Each result is a tuple of an error flag and a payload whose first item is
                    used as the key and second item as the value.
    :param verbose: Verbosity.

    :returns: The processed results as a tuple of successes and errors dictionaries.
    """
    if verbose:
        _LOGGER.info("Summarizing the results.")

    # Route each payload to its dictionary by the error flag:
    successes, errors = {}, {}
    for is_error, payload in results:
        target = errors if is_error else successes
        target[payload[0]] = payload[1]

    if verbose:
        _LOGGER.info(f"Done ({len(successes)}/{len(successes) + len(errors)})\n")

    return successes, errors
 base_image: mlrun/mlrun commands: [] code_origin: '' @@ -48,7 +48,7 @@ spec: the VAD. outputs: - default: '' - lineno: 86 + lineno: 94 get_result: name: get_result doc: Get the result of the task. A tuple of the audio file name and the result. @@ -58,6 +58,16 @@ spec: - doc: The result of the task. default: '' lineno: 61 + to_tuple: + name: to_tuple + doc: Convert the task to a tuple to reconstruct it later (used for multiprocessing + to pass in queue). + parameters: + - name: self + outputs: + - doc: The converted task. + default: '' + lineno: 116 create_task: name: create_task doc: Create a task with the given audio file. @@ -68,16 +78,35 @@ spec: doc: The audio file to assign to the task. outputs: - doc: The created task. + type: BaseTask + default: '' + lineno: 146 + from_tuple: + name: from_tuple + doc: Create a task from a tuple of the audio file name and the task kwargs. + parameters: + - name: cls + - name: task_tuple + type: Tuple[str, dict] + doc: The task tuple to create the task from. + outputs: + - doc: The created task. + type: BaseTask default: '' - lineno: 123 + lineno: 157 load: name: load doc: Load the VAD model. parameters: - name: self + - name: force_reload + type: bool + doc: Whether to force reload the model even if it was already loaded. Default + is True. + default: true outputs: - default: '' - lineno: 199 + lineno: 234 detect_voice: name: detect_voice doc: "Perform voice activity detection on given audio files using the silero\ @@ -159,7 +188,7 @@ spec: default: false outputs: - default: '' - lineno: 347 + lineno: 393 diarize: name: diarize doc: "Perform speech diarization on given audio files using the silero VAD model\ @@ -237,7 +266,7 @@ spec: default: false outputs: - default: '' - lineno: 471 + lineno: 517 description: Silero VAD (Voice Activity Detection) functions. 
default_handler: detect_voice disable_auto_mount: false diff --git a/silero_vad/item.yaml b/silero_vad/item.yaml index 4b6c74fc5..6f85a4c7d 100644 --- a/silero_vad/item.yaml +++ b/silero_vad/item.yaml @@ -27,4 +27,4 @@ spec: - tqdm - onnxruntime url: '' -version: 1.0.0 +version: 1.1.0 diff --git a/silero_vad/silero_vad.py b/silero_vad/silero_vad.py index 08ff585fa..a477d4ecf 100644 --- a/silero_vad/silero_vad.py +++ b/silero_vad/silero_vad.py @@ -765,7 +765,7 @@ def _parallel_run( n_workers = len(audio_files) # Initialize the multiprocessing queues: - tasks_queue = Queue(maxsize=n_workers * 2) + tasks_queue = Queue() results_queue = Queue() # Initialize the multiprocessing processes: @@ -785,39 +785,38 @@ def _parallel_run( for p in task_completion_processes: p.start() - # Put the tasks in the queue (the progress bar is not accurate as it is updating by the queue which has a max size - # of 2*n_workers so the progress bar won't be too off - better than nothing): - for audio_file in tqdm( - audio_files, - desc=description, - unit="file", - total=len(audio_files), - disable=not verbose, - ): - # Put the task in the queue: + # Put the tasks in the queue: + for audio_file in audio_files: tasks_queue.put(task_creator.create_task(audio_file=audio_file).to_tuple()) # Put the stop marks in the queue: for _ in range(n_workers): tasks_queue.put(_MULTIPROCESSING_STOP_MARK) - # Wait for the processes to finish: - for p in task_completion_processes: - p.join() - # Collect the results: results = [] stop_marks_counter = 0 - while True: - # Get a result from the queue: - result: Tuple[bool, Tuple[str, list]] = results_queue.get() - if result == _MULTIPROCESSING_STOP_MARK: - stop_marks_counter += 1 - if stop_marks_counter == n_workers: - break - else: - # Collect the result: - results.append(result) + with tqdm( + desc=description, + unit="file", + total=len(audio_files), + disable=not verbose, + ) as progressbar: + while True: + # Get a result from the queue: + result: Tuple[bool, 
Tuple[str, list]] = results_queue.get() + if result == _MULTIPROCESSING_STOP_MARK: + stop_marks_counter += 1 + if stop_marks_counter == n_workers: + break + else: + # Collect the result: + results.append(result) + progressbar.update(1) + + # Wait for the processes to finish: + for p in task_completion_processes: + p.join() return results