diff --git a/.github/workflows/check-types.yml b/.github/workflows/check-types.yml
index f8754189..c4490611 100644
--- a/.github/workflows/check-types.yml
+++ b/.github/workflows/check-types.yml
@@ -29,7 +29,7 @@ jobs:
         run: python -m pip install --upgrade mypy
 
       - name: Install packages
-        run: python -m pip install pytest ./livekit-api ./livekit-protocol ./livekit-rtc
+        run: python -m pip install pytest ./livekit-api ./livekit-protocol ./livekit-rtc pydantic
 
       - name: Check Types
         run: python -m mypy --install-type --non-interactive -p 'livekit-protocol' -p 'livekit-api' -p 'livekit-rtc'
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 56d3c705..db7a89fe 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -23,5 +23,5 @@ jobs:
       - name: Run tests
         run: |
           python3 ./livekit-rtc/rust-sdks/download_ffi.py --output livekit-rtc/livekit/rtc/resources
-          pip3 install pytest ./livekit-protocol ./livekit-api ./livekit-rtc
+          pip3 install pytest ./livekit-protocol ./livekit-api ./livekit-rtc pydantic
           pytest . --ignore=livekit-rtc/rust-sdks
diff --git a/livekit-rtc/livekit/rtc/audio_frame.py b/livekit-rtc/livekit/rtc/audio_frame.py
index 9e9931ab..4d83c7a1 100644
--- a/livekit-rtc/livekit/rtc/audio_frame.py
+++ b/livekit-rtc/livekit/rtc/audio_frame.py
@@ -17,7 +17,7 @@
 from ._proto import audio_frame_pb2 as proto_audio
 from ._proto import ffi_pb2 as proto_ffi
 from ._utils import get_address
-from typing import Union
+from typing import Any, Union
 
 
 class AudioFrame:
@@ -55,6 +55,10 @@ def __init__(
                 "data length must be >= num_channels * samples_per_channel * sizeof(int16)"
             )
 
+        if len(data) % ctypes.sizeof(ctypes.c_int16) != 0:
+            # can happen if data is bigger than needed
+            raise ValueError("data length must be a multiple of sizeof(int16)")
+
         self._data = bytearray(data)
         self._sample_rate = sample_rate
         self._num_channels = num_channels
@@ -197,3 +201,70 @@ def __repr__(self) -> str:
             f"samples_per_channel={self.samples_per_channel}, "
             f"duration={self.duration:.3f})"
         )
+
+    @classmethod
+    def __get_pydantic_core_schema__(cls, *_: Any):
+        """Return the pydantic-core schema for validating/serializing AudioFrame.
+
+        JSON input is an object with a base64 ``data`` string plus integer
+        ``sample_rate``, ``num_channels`` and ``samples_per_channel``; Python
+        input may be an AudioFrame or a dict with those keys. Serialization
+        produces the same object shape.
+        """
+        from pydantic_core import core_schema
+        import base64
+
+        def validate_audio_frame(value: Any) -> "AudioFrame":
+            if isinstance(value, AudioFrame):
+                return value
+
+            # the model_fields_schema step passes a tuple; item 0 is the field dict
+            if isinstance(value, tuple):
+                value = value[0]
+
+            if isinstance(value, dict):
+                try:
+                    return AudioFrame(
+                        data=base64.b64decode(value["data"]),
+                        sample_rate=value["sample_rate"],
+                        num_channels=value["num_channels"],
+                        samples_per_channel=value["samples_per_channel"],
+                    )
+                except KeyError as e:
+                    raise ValueError(f"missing AudioFrame field: {e}") from e
+
+            # pydantic only converts ValueError/AssertionError into ValidationError
+            raise ValueError("Invalid type for AudioFrame")
+
+        return core_schema.json_or_python_schema(
+            json_schema=core_schema.chain_schema(
+                [
+                    core_schema.model_fields_schema(
+                        {
+                            "data": core_schema.model_field(core_schema.str_schema()),
+                            "sample_rate": core_schema.model_field(
+                                core_schema.int_schema()
+                            ),
+                            "num_channels": core_schema.model_field(
+                                core_schema.int_schema()
+                            ),
+                            "samples_per_channel": core_schema.model_field(
+                                core_schema.int_schema()
+                            ),
+                        },
+                    ),
+                    core_schema.no_info_plain_validator_function(validate_audio_frame),
+                ]
+            ),
+            python_schema=core_schema.no_info_plain_validator_function(
+                validate_audio_frame
+            ),
+            serialization=core_schema.plain_serializer_function_ser_schema(
+                lambda instance: {
+                    "data": base64.b64encode(instance.data).decode("utf-8"),
+                    "sample_rate": instance.sample_rate,
+                    "num_channels": instance.num_channels,
+                    "samples_per_channel": instance.samples_per_channel,
+                }
+            ),
+        )
diff --git a/livekit-rtc/livekit/rtc/video_frame.py b/livekit-rtc/livekit/rtc/video_frame.py
index 1aa2a3d9..cf34a8bd 100644
--- a/livekit-rtc/livekit/rtc/video_frame.py
+++ b/livekit-rtc/livekit/rtc/video_frame.py
@@ -20,6 +20,8 @@
 from ._ffi_client import FfiClient, FfiHandle
 from ._utils import get_address
 
+from typing import Any
+
 
 class VideoFrame:
     """
@@ -203,6 +205,66 @@ def convert(
     def __repr__(self) -> str:
         return f"rtc.VideoFrame(width={self.width}, height={self.height}, type={self.type})"
 
+    @classmethod
+    def __get_pydantic_core_schema__(cls, *_: Any):
+        """Return the pydantic-core schema for validating/serializing VideoFrame.
+
+        JSON input is an object with integer ``width``, ``height`` and ``type``
+        plus a base64 ``data`` string; Python input may be a VideoFrame or a
+        dict with those keys. Serialization produces the same object shape.
+        """
+        from pydantic_core import core_schema
+        import base64
+
+        def validate_video_frame(value: Any) -> "VideoFrame":
+            if isinstance(value, VideoFrame):
+                return value
+
+            # the model_fields_schema step passes a tuple; item 0 is the field dict
+            if isinstance(value, tuple):
+                value = value[0]
+
+            if isinstance(value, dict):
+                try:
+                    return VideoFrame(
+                        width=value["width"],
+                        height=value["height"],
+                        type=proto_video.VideoBufferType.ValueType(value["type"]),
+                        data=base64.b64decode(value["data"]),
+                    )
+                except KeyError as e:
+                    raise ValueError(f"missing VideoFrame field: {e}") from e
+
+            # pydantic only converts ValueError/AssertionError into ValidationError
+            raise ValueError("Invalid type for VideoFrame")
+
+        return core_schema.json_or_python_schema(
+            json_schema=core_schema.chain_schema(
+                [
+                    core_schema.model_fields_schema(
+                        {
+                            "width": core_schema.model_field(core_schema.int_schema()),
+                            "height": core_schema.model_field(core_schema.int_schema()),
+                            "type": core_schema.model_field(core_schema.int_schema()),
+                            "data": core_schema.model_field(core_schema.str_schema()),
+                        },
+                    ),
+                    core_schema.no_info_plain_validator_function(validate_video_frame),
+                ]
+            ),
+            python_schema=core_schema.no_info_plain_validator_function(
+                validate_video_frame
+            ),
+            serialization=core_schema.plain_serializer_function_ser_schema(
+                lambda instance: {
+                    "width": instance.width,
+                    "height": instance.height,
+                    "type": instance.type,
+                    "data": base64.b64encode(instance.data).decode("utf-8"),
+                }
+            ),
+        )
+
 
 def _component_info(
     data_ptr: int, stride: int, size: int