diff --git a/backends/cadence/runtime/runtime.py b/backends/cadence/runtime/runtime.py index a7d35fbd0c9..3a139e415ea 100644 --- a/backends/cadence/runtime/runtime.py +++ b/backends/cadence/runtime/runtime.py @@ -45,7 +45,7 @@ def get_op_names(program: et_schema.Program, execution_plan_id: int = 0) -> set[ op_names |= get_op_names( deserialize_pte_binary( program.backend_delegate_data[delegate.processed.index].data - ) + ).program ) return op_names diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index 11a2f57a64f..c4def2dc474 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -197,7 +197,7 @@ def dump_context_from_pte(pte_path) -> List[str]: with open(pte_path, "rb") as f: program_data = f.read() - program = deserialize_pte_binary(program_data) + program = deserialize_pte_binary(program_data).program ctx_path = os.path.dirname(pte_path) dummy_compiler_specs = generate_qnn_executorch_compiler_spec( diff --git a/codegen/tools/gen_ops_def.py b/codegen/tools/gen_ops_def.py index aba3f9242ac..98fdab73fd1 100644 --- a/codegen/tools/gen_ops_def.py +++ b/codegen/tools/gen_ops_def.py @@ -23,7 +23,7 @@ def get_operators(model_file: str) -> List[Operator]: print("Processing model file: ", model_file) with open(model_file, "rb") as f: flatbuffer = f.read() - program = _deserialize_pte_binary(flatbuffer) + program = _deserialize_pte_binary(flatbuffer).program print(f"Program loaded from model file: {model_file}") operators = program.execution_plan[0].operators return operators diff --git a/examples/qualcomm/oss_scripts/llama/decoder_utils.py b/examples/qualcomm/oss_scripts/llama/decoder_utils.py index 3a9f0f39fac..d41d9d32120 100644 --- a/examples/qualcomm/oss_scripts/llama/decoder_utils.py +++ b/examples/qualcomm/oss_scripts/llama/decoder_utils.py @@ -276,7 +276,7 @@ def __init__( # noqa: C901 with open(pte_path, "rb") as f: program_data = f.read() - program = deserialize_pte_binary(program_data) + program = deserialize_pte_binary(program_data).program # Retrieve vocab_size from get_metadata under static_llama that is passed to edge manager self.output_vocab_size = None diff --git a/exir/_serialize/__init__.py b/exir/_serialize/__init__.py index 5a5ec315b7f..242f254ca46 100644 --- a/exir/_serialize/__init__.py +++ b/exir/_serialize/__init__.py @@ -8,6 +8,7 @@ from executorch.exir._serialize._program import ( deserialize_pte_binary as _deserialize_pte_binary, + PTEFile as _PTEFile, serialize_pte_binary as _serialize_pte_binary, ) @@ -15,4 +16,5 @@ __all__ = [ "_deserialize_pte_binary", "_serialize_pte_binary", + "_PTEFile", ] diff --git a/exir/_serialize/_program.py b/exir/_serialize/_program.py index bee5b3438b0..93e769c565c 100644 --- a/exir/_serialize/_program.py +++ b/exir/_serialize/_program.py @@ -21,7 +21,10 @@ _program_flatbuffer_to_json, _program_json_to_flatbuffer, ) -from executorch.exir._serialize._named_data_store import NamedDataStoreOutput +from executorch.exir._serialize._named_data_store import ( + NamedDataStore, + NamedDataStoreOutput, +) from executorch.exir._serialize.data_serializer import DataEntry @@ -46,6 +49,19 @@ _HEADER_BYTEORDER: Literal["little"] = "little" +@dataclass +class PTEFile: + """ + Wraps together the data required to serialize into a PTE file. + """ + + program: Program + # TODO(lfq): add constant data (currently restored in the program) + # TODO(lfq): update this to List[bytes] + mutable_data: Optional[List[Buffer]] = None + named_data: Optional[NamedDataStoreOutput] = None + + @dataclass class AlignedData: """ @@ -575,7 +591,91 @@ def serialize_pte_binary( return pte_data -def _restore_segments(program: Program, segment_data: bytes) -> Program: +def _restore_delegates(program: Program, segments: List[bytes]) -> Program: + """Find and replace the Program's references to these segments, inlining + the data. + + Args: + program: The Program holding non-inlined delegates. Modified in-place. + segments: List of bytes containing the delegate data. Not modified. + + Returns: The Program with delegates restored. + """ + for plan_index, plan in enumerate(program.execution_plan): + for delegate_index, delegate in enumerate(plan.delegates): + if delegate.processed.location == DataLocation.INLINE: + continue + assert delegate.processed.location == DataLocation.SEGMENT + index = delegate.processed.index + if index >= len(segments): + raise ValueError( + f"Plan {plan_index} delegate {delegate_index} " + + f"segment index {index} >= num segments {len(segments)}" + ) + + data_index: int = len(program.backend_delegate_data) + program.backend_delegate_data.append( + BackendDelegateInlineData(data=segments[index]) + ) + delegate.processed = BackendDelegateDataReference( + location=DataLocation.INLINE, index=data_index + ) + return program + + +def _restore_constant_segment( + constant_segment: SubsegmentOffsets, segment_data: bytes +) -> List[Buffer]: + """Convert constant and mutable tensors from a single byte-blob into a list of individual tensors. + + Args: + constant_segment: SubsegmentOffset with the offsets of each tensor. + segment_data: byte data containing the tensors and padding. Not modified. + + Returns: + List[Buffer] containing each tensor in a separate object. + """ + buffers: List[Buffer] = [] + for i in range(len(constant_segment.offsets)): + start_offset = constant_segment.offsets[i] + # Note: this is the original end offset plus any padding between it and the next start offset + end_offset = ( + constant_segment.offsets[i + 1] + if i < len(constant_segment.offsets) - 1 + else len(segment_data) + ) + buffers.append(Buffer(storage=segment_data[start_offset:end_offset])) + return buffers + + +def _restore_named_data( + program: Program, + segments: List[bytes], +) -> NamedDataStoreOutput: + """Moves named data from `segments` and `program` into the + NamedDataStoreOutput class. + + Args: + program: The Program holding named data references. Not modified. + segments: The data containing the segments. Not modified. + """ + named_data_store = NamedDataStore() + for entry in program.named_data: + if entry.segment_index >= len(segments): + raise ValueError( + "Named data segment index " + f"{entry.segment_index} >= num segments {len(segments)}" + ) + named_data_store.add_named_data( + key=entry.key, + data=segments[entry.segment_index], + alignment=1, # Deserialization does not preserve alignment. + tensor_layout=None, # PTE file currently does not serialize this. + ) + return named_data_store.get_named_data_store_output() + + +def _restore_segments(program: Program, segment_data: bytes) -> PTEFile: """Moves segments from `segment_data` into `program`. This should recreate the original Program that the segments were extracted @@ -589,7 +689,7 @@ def _restore_segments(program: Program, segment_data: bytes) -> Program: the preceding data has been stripped off so that the first segment begins at offset zero. Returns: - The Program with segments restored. + PTEFile, containing the Program with delegate and constant segments restored, mutable data segment, and named data segment. """ # Extract the list of segment data blobs, which parallel program.segments. segments: List[bytes] = [] @@ -600,53 +700,51 @@ def _restore_segments(program: Program, segment_data: bytes) -> Program: ) segments.append(segment_data[segment.offset : segment.offset + segment.size]) - # Find and replace the Program's references to these segments, inlining the - # data. - for plan_index, plan in enumerate(program.execution_plan): - for delegate_index, delegate in enumerate(plan.delegates): - if delegate.processed.location == DataLocation.INLINE: - continue - assert delegate.processed.location == DataLocation.SEGMENT - index = delegate.processed.index - if index >= len(segments): - raise ValueError( - f"Plan {plan_index} delegate {delegate_index} " - + f"segment index {index} >= num segments {len(segments)}" - ) - - data_index: int = len(program.backend_delegate_data) - program.backend_delegate_data.append( - BackendDelegateInlineData(data=segments[index]) - ) - delegate.processed = BackendDelegateDataReference( - location=DataLocation.INLINE, index=data_index - ) + # Restore delegate segments that weren't inlined previously. + program = _restore_delegates(program, segments) # Replace constants from constant_segment into constant_buffer. if program.constant_segment and len(program.constant_segment.offsets) > 0: - buffers: List[Buffer] = [] - constant_segment = segments[program.constant_segment.segment_index] - for i in range(len(program.constant_segment.offsets)): - start_offset = program.constant_segment.offsets[i] - # Note: this is the original end offset plus any padding between - # it and the next start offset. - end_offset = ( - program.constant_segment.offsets[i + 1] - if i < len(program.constant_segment.offsets) - 1 - else len(constant_segment) + if program.constant_segment.segment_index >= len(segments): + raise ValueError( + f"Constant segment index {program.constant_segment.segment_index} >= num segments {len(segments)}" ) - buffers.append(Buffer(storage=constant_segment[start_offset:end_offset])) - program.constant_buffer = buffers + program.constant_buffer = _restore_constant_segment( + program.constant_segment, segments[program.constant_segment.segment_index] + ) program.constant_segment.segment_index = 0 program.constant_segment.offsets = [] - # Clear out the segments list since the original Program didn't have one. + # Extract mutable segments. + mutable_data = None + if program.mutable_data_segments and len(program.mutable_data_segments) > 0: + if len(program.mutable_data_segments) > 1: + raise ValueError("Can't handle more than 1 mutable data segment.") + segment_index = program.mutable_data_segments[0].segment_index + if segment_index >= len(segments): + raise ValueError( + f"Mutable data segment index {segment_index} >= num segments {len(segments)}" + ) + mutable_data = _restore_constant_segment( + program.mutable_data_segments[0], + segments[segment_index], + ) + program.mutable_data_segments = None + + # Extract named data. + named_data = None + if program.named_data: + named_data = _restore_named_data(program, segments) + + # Clear named_data and segments, which are empty pre-serialization. + program.named_data = [] program.segments = [] - return program + return PTEFile(program=program, mutable_data=mutable_data, named_data=named_data) -def deserialize_pte_binary(program_data: bytes) -> Program: - """Returns a Program deserialized from the given runtime binary data.""" + +def deserialize_pte_binary(program_data: bytes) -> PTEFile: + """Returns a PTEFile deserialized from the given runtime binary data.""" program_size = len(program_data) segment_base_offset = 0 @@ -664,8 +762,8 @@ def deserialize_pte_binary(program_data: bytes) -> Program: if segment_base_offset != 0: # Move segment data back into the Program. - program = _restore_segments( + return _restore_segments( program=program, segment_data=program_data[segment_base_offset:] ) - return program + return PTEFile(program=program, mutable_data=None, named_data=None) diff --git a/exir/_serialize/test/test_program.py b/exir/_serialize/test/test_program.py index 80f4b8ca49f..b2a22694245 100644 --- a/exir/_serialize/test/test_program.py +++ b/exir/_serialize/test/test_program.py @@ -8,12 +8,13 @@ # pyre-unsafe import copy +import dataclasses import difflib import json import math import unittest -from typing import List, Sequence +from typing import Dict, List, Sequence from executorch.exir._serialize._flatbuffer import _program_flatbuffer_to_json from executorch.exir._serialize._named_data_store import NamedDataStoreOutput @@ -281,13 +282,43 @@ def constant_segment_with_tensor_alignment( ) # Convert back. - program2 = deserialize_pte_binary(pte_data) + deserialized = deserialize_pte_binary(pte_data) # Programs are the same besides constant_buffer, as deserialization # does not preserve constant segment; padding may be added # during serialization. - self.assertEqual(program2.execution_plan, program.execution_plan) + self.assertEqual(deserialized.program.execution_plan, program.execution_plan) # Number of constant tensors should be the same. - self.assertEqual(len(program2.constant_buffer), len(program.constant_buffer)) + self.assertEqual( + len(deserialized.program.constant_buffer), len(program.constant_buffer) + ) + self.assertEqual(deserialized.mutable_data, None) + self.assertEqual(deserialized.named_data, None) + + def _check_named_data_entries( + self, reference: Dict[str, DataEntry], actual: Dict[str, DataEntry] + ) -> None: + self.assertEqual(reference.keys(), actual.keys()) + SKIP_FIELDS = {"alignment"} # Fields to ignore in comparison. + for key in reference.keys(): + ref_entry = reference[key] + actual_entry = actual[key] + for field in dataclasses.fields(ref_entry): + if field.name not in SKIP_FIELDS: + self.assertEqual( + getattr(ref_entry, field.name), + getattr(actual_entry, field.name), + f"Named data record {key}.{field.name} does not match.", + ) + + def _check_named_data_store_output( + self, reference: NamedDataStoreOutput, actual: NamedDataStoreOutput + ) -> None: + # Check buffers. + self.assertEqual(reference.buffers, actual.buffers) + # Check pte_data. + self._check_named_data_entries(reference.pte_data, actual.pte_data) + # Should be empty. + self.assertEqual(reference.external_data, actual.external_data) def test_canonicalize_delegate_indices(self) -> None: def make_execution_plan( @@ -426,10 +457,12 @@ def test_round_trip_no_header_no_segments(self) -> None: self.assertIsNone(eh) # Convert back. - program2 = deserialize_pte_binary(pte_data) + deserialized = deserialize_pte_binary(pte_data) # Programs should be the same. - self.assert_programs_equal(program, program2) + self.assert_programs_equal(program, deserialized.program) + self.assertEqual(deserialized.mutable_data, None) + self.assertEqual(deserialized.named_data, None) def test_round_trip_large_buffer_sizes(self) -> None: """Tests that when the non_const_buffer_sizes contains integers @@ -439,7 +472,9 @@ def test_round_trip_large_buffer_sizes(self) -> None: program = get_test_program() program.execution_plan[0].non_const_buffer_sizes = [0, 2**48] flatbuffer_from_py = bytes(serialize_pte_binary(program)) - self.assert_programs_equal(program, deserialize_pte_binary(flatbuffer_from_py)) + self.assert_programs_equal( + program, deserialize_pte_binary(flatbuffer_from_py).program + ) def test_round_trip_no_segments_and_no_header(self) -> None: """Tests that a Program serialized with extract_delegate_segments=True @@ -463,10 +498,12 @@ def test_round_trip_no_segments_and_no_header(self) -> None: self.assertEqual(program_with_segments.segments, []) # Convert back. - program2 = deserialize_pte_binary(pte_data) + deserialized = deserialize_pte_binary(pte_data) # Programs should be the same. - self.assert_programs_equal(program, program2) + self.assert_programs_equal(program, deserialized.program) + self.assertEqual(deserialized.mutable_data, None) + self.assertEqual(deserialized.named_data, None) @staticmethod def gen_blob_data(size: int, pattern: bytes) -> bytes: @@ -598,8 +635,10 @@ def test_round_trip_with_segments(self) -> None: # meaning that the segments were moved back to inline. This also # demonstrates that the contents of all segments survived, and weren't # truncated or corrupted. - program2 = deserialize_pte_binary(pte_data) - self.assert_programs_equal(program, program2) + deserialized = deserialize_pte_binary(pte_data) + self.assert_programs_equal(program, deserialized.program) + self.assertEqual(deserialized.mutable_data, None) + self.assertEqual(deserialized.named_data, None) def test_no_constants(self) -> None: program = get_test_program() @@ -884,13 +923,17 @@ def test_constant_delegate_and_named_data_segments(self) -> None: ) # Convert back. - program2 = deserialize_pte_binary(pte_data) + deserialized = deserialize_pte_binary(pte_data) # Programs are the same besides constant_buffer, as deserialization # does not preserve constant segment; padding may be added # during serialization. - self.assertEqual(program2.execution_plan, program.execution_plan) + self.assertEqual(deserialized.program.execution_plan, program.execution_plan) # Number of constant tensors should be the same. - self.assertEqual(len(program2.constant_buffer), len(program.constant_buffer)) + self.assertEqual( + len(deserialized.program.constant_buffer), len(program.constant_buffer) + ) + self.assertEqual(deserialized.mutable_data, None) + self._check_named_data_store_output(deserialized.named_data, named_data) def test_named_data_segments(self) -> None: # Set segment alignment to 12 to test the padding. @@ -995,6 +1038,27 @@ def test_named_data_segments(self) -> None: buffers[2], ) + # Test roundtrip + deserialized = deserialize_pte_binary(pte_data) + self.assert_programs_equal(deserialized.program, program) + self.assertEqual(deserialized.mutable_data, None) + self._check_named_data_store_output(deserialized.named_data, named_data) + + # Test re-serialize + pte_data2 = serialize_pte_binary( + deserialized.program, + extract_delegate_segments=True, + segment_alignment=SEGMENT_ALIGNMENT, + constant_tensor_alignment=CONSTANT_TENSOR_ALIGNMENT, + named_data=deserialized.named_data, + ) + # pte_data2 is not going to be the same as pte_data due to alignment; + # directly test the deserialized one. + deserialized2 = deserialize_pte_binary(bytes(pte_data2)) + self.assert_programs_equal(deserialized2.program, program) + self.assertEqual(deserialized2.mutable_data, None) + self._check_named_data_store_output(deserialized2.named_data, named_data) + # Common data for extended header tests. The two example values should produce # the example data. diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py index 199a667ab64..4844088c0c2 100644 --- a/exir/emit/test/test_emit.py +++ b/exir/emit/test/test_emit.py @@ -1665,9 +1665,10 @@ def forward(self, x): self.assertEqual(values[5].val, Double(double_val=float("-inf"))) # Confirm that we can also deserialize the model with infinity in it. - pte_data = deserialize_pte_binary(model.buffer) + deserialize = deserialize_pte_binary(model.buffer) self.assertEqual( - pte_data.execution_plan, model.executorch_program.execution_plan + deserialize.program.execution_plan, + model.executorch_program.execution_plan, ) def test_mutate_input_tensor(self) -> None: