Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
356264f
[Build] Fix html links, Add <function>.html as source in documentation
yonishelach Mar 8, 2023
2202caf
Update CI temporarily and update index
yonishelach Apr 19, 2023
b0f2b5d
Merge pull request #615 from yonishelach/fix-links-html
yonishelach Apr 19, 2023
8b17c8f
[XGB-Custom] Fix test artifact key name
yonishelach Apr 19, 2023
f42f3e3
Merge pull request #616 from yonishelach/fix-xgb-custom-test
yonishelach Apr 19, 2023
a469dca
[XGB-Serving][XGB-Test][XGB-Trainer] Fix tests - artifact key
yonishelach Apr 19, 2023
80c7de2
Merge pull request #617 from yonishelach/fix-xgb-custom-test
yonishelach Apr 19, 2023
3301415
[Build] Install python 3.9 when testing (#618)
yonishelach Apr 19, 2023
0cd1f15
[Build] Update python version in CI (#620)
yonishelach Apr 19, 2023
33e7ab8
Revert "[Build] Update python version in CI (#620)" (#621)
yonishelach Apr 19, 2023
7a7473b
Revert "[Build] Install python 3.9 when testing (#618)" (#619)
yonishelach Apr 19, 2023
81437da
[Build] Build with python 3.9 (#622)
yonishelach Apr 19, 2023
c0a1bb8
Merge branch 'development' of https://github.com/mlrun/functions into…
guy1992l Oct 1, 2023
10cf170
Merge branch 'development' of https://github.com/mlrun/functions into…
guy1992l Dec 28, 2023
81b45ae
Merge branch 'development' of https://github.com/mlrun/functions into…
guy1992l Jan 4, 2024
f1aa62a
Merge branch 'development' of https://github.com/mlrun/functions into…
guy1992l Jan 8, 2024
5bb578c
Merge branch 'development' of https://github.com/mlrun/functions into…
guy1992l Jan 10, 2024
17b6eb0
fix1
guy1992l Jan 10, 2024
c3a141c
fix2
guy1992l Jan 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 36 additions & 7 deletions silero_vad/function.yaml

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion silero_vad/item.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@ spec:
- tqdm
- onnxruntime
url: ''
version: 1.0.0
version: 1.1.0
129 changes: 89 additions & 40 deletions silero_vad/silero_vad.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@ def get_result(self) -> Tuple[str, list]:
"""
return self._audio_file.name, self._result

def to_tuple(self) -> Tuple[str, dict]:
    """
    Serialize the task into a plain tuple so it can be reconstructed later (used for multiprocessing, where the
    task is passed through a queue).

    :returns: A tuple of the task's class name and the keyword arguments needed to construct it again.
    """
    class_name = self.__class__.__name__
    init_kwargs = {"audio_file": self._audio_file}
    return class_name, init_kwargs


class SpeechDiarizationTask(BaseTask):
"""
Expand Down Expand Up @@ -105,12 +113,27 @@ def do_task(self, speech_timestamps: List[List[Dict[str, int]]]):
speech_diarization.sort()
self._result = speech_diarization

def to_tuple(self) -> Tuple[str, dict]:
    """
    Serialize the task into a plain tuple so it can be reconstructed later (used for multiprocessing, where the
    task is passed through a queue).

    :returns: A tuple of the task's class name and its construction keyword arguments, which are the parent
              class's keyword arguments plus the speaker labels.
    """
    # Start from what the parent class serializes, then add this task's own argument:
    class_name, base_kwargs = super().to_tuple()
    init_kwargs = dict(base_kwargs)
    init_kwargs["speaker_labels"] = self._speaker_labels
    return class_name, init_kwargs


class TaskCreator:
"""
A task creator to create different tasks to run after the VAD.
"""

#: A map from task class name to task class to use in `from_tuple`:
_MAP = {
BaseTask.__name__: BaseTask,
SpeechDiarizationTask.__name__: SpeechDiarizationTask,
}

def __init__(self, task_type: Type[BaseTask], task_kwargs: dict = None):
"""
Initialize the task creator.
Expand All @@ -120,7 +143,7 @@ def __init__(self, task_type: Type[BaseTask], task_kwargs: dict = None):
self._task_type = task_type
self._task_kwargs = task_kwargs or {}

def create_task(self, audio_file: Path):
def create_task(self, audio_file: Path) -> BaseTask:
"""
Create a task with the given audio file.

Expand All @@ -130,6 +153,18 @@ def create_task(self, audio_file: Path):
"""
return self._task_type(audio_file=audio_file, **self._task_kwargs)

@classmethod
def from_tuple(cls, task_tuple: Tuple[str, dict]) -> BaseTask:
    """
    Create a task from a tuple of the task class name and the task kwargs (the format produced by a task's
    `to_tuple` method).

    :param task_tuple: The task tuple to create the task from: the task class name and the keyword arguments to
                       initialize the task with.

    :returns: The created task.

    :raises KeyError: If the task class name is not registered in the class map.
    """
    task_class, task_kwargs = task_tuple

    # Look up the task class by name, raising a descriptive error instead of a bare `KeyError('<name>')`:
    try:
        task_type = cls._MAP[task_class]
    except KeyError:
        raise KeyError(
            f"Unknown task class '{task_class}'. Known task classes: {sorted(cls._MAP)}"
        ) from None

    return task_type(**task_kwargs)


class VoiceActivityDetector:
"""
Expand Down Expand Up @@ -196,14 +231,16 @@ def __init__(
self._model: torch.Module = None
self._get_speech_timestamps: FunctionType = None

def load(self):
def load(self, force_reload: bool = True):
"""
Load the VAD model.

:param force_reload: Whether to force reload the model even if it was already loaded. Default is True.
"""
model, utils = torch.hub.load(
repo_or_dir="snakers4/silero-vad",
model="silero_vad",
force_reload=True,
force_reload=force_reload,
onnx=self._use_onnx,
force_onnx_cpu=self._force_onnx_cpu,
)
Expand Down Expand Up @@ -317,31 +354,40 @@ def _multiprocessing_complete_tasks(
"""
# Initialize and load the VAD:
vad = VoiceActivityDetector(**vad_init_kwargs)
vad.load()
vad.load(force_reload=False)

# Start listening to the tasks queue:
while True:
# Get the task:
task: BaseTask = tasks_queue.get()
task: Tuple[str, dict] = tasks_queue.get()
if task == _MULTIPROCESSING_STOP_MARK:
break
try:
# Create the task:
task = TaskCreator.from_tuple(task_tuple=task)
# Run the file through the VAD:
speech_timestamps = vad.detect_voice(audio_file=task.audio_file)
# Complete the task:
task.do_task(speech_timestamps=speech_timestamps)
# Collect the result:
results_queue.put((False, task.get_result()))
# Build the result:
result = (False, task.get_result())
except Exception as exception:
# Collect the error:
results_queue.put((True, (task.audio_file.name, str(exception))))
# Build the error:
result = (True, (task.audio_file.name, str(exception)))
# Collect the result / error:
results_queue.put(result)

# Mark the end of the tasks:
results_queue.put(_MULTIPROCESSING_STOP_MARK)


# Get the global logger:
_LOGGER = logging.getLogger()
try:
import mlrun

_LOGGER = mlrun.get_or_create_ctx("silero_vad").logger
except ModuleNotFoundError:
_LOGGER = logging.getLogger()
Comment thread
yonishelach marked this conversation as resolved.


def detect_voice(
Expand Down Expand Up @@ -702,6 +748,14 @@ def _parallel_run(

:returns: The collected results.
"""
# Load the VAD (download once, and it will be loaded then per process later on):
if verbose:
_LOGGER.info(f"Loading the VAD model.")
vad = VoiceActivityDetector(**vad_init_kwargs)
vad.load()
if verbose:
_LOGGER.info("VAD model loaded.")

# Check the number of workers:
if n_workers > len(audio_files):
_LOGGER.warning(
Expand All @@ -711,7 +765,7 @@ def _parallel_run(
n_workers = len(audio_files)

# Initialize the multiprocessing queues:
tasks_queue = Queue(maxsize=n_workers * 2)
tasks_queue = Queue()
results_queue = Queue()

# Initialize the multiprocessing processes:
Expand All @@ -728,46 +782,41 @@ def _parallel_run(
]

# Start the multiprocessing processes:
if verbose:
_LOGGER.info(f"Loading the VAD model (per process).")
for p in task_completion_processes:
p.start()
if verbose:
_LOGGER.info("VAD model loaded.")

# Put the tasks in the queue (the progress bar is not accurate as it is updating by the queue which has a max size
# of 2*n_workers so the progress bar won't be too off - better than nothing):
for audio_file in tqdm(
audio_files,
desc=description,
unit="file",
total=len(audio_files),
disable=not verbose,
):
# Put the task in the queue:
tasks_queue.put(task_creator.create_task(audio_file=audio_file))
# Put the tasks in the queue:
for audio_file in audio_files:
tasks_queue.put(task_creator.create_task(audio_file=audio_file).to_tuple())

# Put the stop marks in the queue:
for _ in range(n_workers):
tasks_queue.put(_MULTIPROCESSING_STOP_MARK)

# Wait for the processes to finish:
for p in task_completion_processes:
p.join()

# Collect the results:
results = []
stop_marks_counter = 0
while True:
# Get a result from the queue:
result: Tuple[bool, Tuple[str, list]] = results_queue.get()
if result == _MULTIPROCESSING_STOP_MARK:
stop_marks_counter += 1
if stop_marks_counter == n_workers:
break
else:
# Collect the result:
results.append(result)
with tqdm(
desc=description,
unit="file",
total=len(audio_files),
disable=not verbose,
) as progressbar:
while True:
# Get a result from the queue:
result: Tuple[bool, Tuple[str, list]] = results_queue.get()
if result == _MULTIPROCESSING_STOP_MARK:
stop_marks_counter += 1
if stop_marks_counter == n_workers:
break
else:
# Collect the result:
results.append(result)
progressbar.update(1)

# Wait for the processes to finish:
for p in task_completion_processes:
p.join()

return results

Expand Down