From 28f92af4a63fb5eec9b59dcebff8f40c01c775d7 Mon Sep 17 00:00:00 2001 From: lucylq Date: Wed, 12 Nov 2025 18:23:48 -0800 Subject: [PATCH 1/5] Introduce PTEFile class PTEFile class holds the components of a PTE file: the program, mutable constants and named data. Currently, the `program` definition does not contain mutable constants and named data; they are always stored in segments and not inline. This means when we deserialize, they are lost, because we only deserialize into the `program` concept. Now, segment data is included in the PTEFile class. Differential Revision: [D86814175](https://our.internmc.facebook.com/intern/diff/D86814175/) [ghstack-poisoned] --- backends/cadence/runtime/runtime.py | 2 +- backends/qualcomm/utils/utils.py | 2 +- codegen/tools/gen_ops_def.py | 2 +- .../oss_scripts/llama/decoder_utils.py | 2 +- exir/_serialize/__init__.py | 2 + exir/_serialize/_program.py | 81 +++++++++++++--- exir/_serialize/test/test_program.py | 92 ++++++++++++++++--- exir/emit/test/test_emit.py | 5 +- 8 files changed, 157 insertions(+), 31 deletions(-) diff --git a/backends/cadence/runtime/runtime.py b/backends/cadence/runtime/runtime.py index a7d35fbd0c9..3a139e415ea 100644 --- a/backends/cadence/runtime/runtime.py +++ b/backends/cadence/runtime/runtime.py @@ -45,7 +45,7 @@ def get_op_names(program: et_schema.Program, execution_plan_id: int = 0) -> set[ op_names |= get_op_names( deserialize_pte_binary( program.backend_delegate_data[delegate.processed.index].data - ) + ).program ) return op_names diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index d26e9530f0b..0cd9aff247f 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -197,7 +197,7 @@ def dump_context_from_pte(pte_path) -> List[str]: with open(pte_path, "rb") as f: program_data = f.read() - program = deserialize_pte_binary(program_data) + program = deserialize_pte_binary(program_data).program ctx_path = os.path.dirname(pte_path) dummy_compiler_specs = generate_qnn_executorch_compiler_spec( diff --git a/codegen/tools/gen_ops_def.py b/codegen/tools/gen_ops_def.py index aba3f9242ac..98fdab73fd1 100644 --- a/codegen/tools/gen_ops_def.py +++ b/codegen/tools/gen_ops_def.py @@ -23,7 +23,7 @@ def get_operators(model_file: str) -> List[Operator]: print("Processing model file: ", model_file) with open(model_file, "rb") as f: flatbuffer = f.read() - program = _deserialize_pte_binary(flatbuffer) + program = _deserialize_pte_binary(flatbuffer).program print(f"Program loaded from model file: {model_file}") operators = program.execution_plan[0].operators return operators diff --git a/examples/qualcomm/oss_scripts/llama/decoder_utils.py b/examples/qualcomm/oss_scripts/llama/decoder_utils.py index 085e2a6c07e..352758902de 100644 --- a/examples/qualcomm/oss_scripts/llama/decoder_utils.py +++ b/examples/qualcomm/oss_scripts/llama/decoder_utils.py @@ -276,7 +276,7 @@ def __init__( # noqa: C901 with open(pte_path, "rb") as f: program_data = f.read() - program = deserialize_pte_binary(program_data) + program = deserialize_pte_binary(program_data).program # Retrieve vocab_size from get_metadata under static_llama that is passed to edge manager self.output_vocab_size = None diff --git a/exir/_serialize/__init__.py b/exir/_serialize/__init__.py index 5a5ec315b7f..242f254ca46 100644 --- a/exir/_serialize/__init__.py +++ b/exir/_serialize/__init__.py @@ -8,6 +8,7 @@ from executorch.exir._serialize._program import ( deserialize_pte_binary as _deserialize_pte_binary, + PTEFile as _PTEFile, serialize_pte_binary as _serialize_pte_binary, ) @@ -15,4 +16,5 @@ __all__ = [ "_deserialize_pte_binary", "_serialize_pte_binary", + "_PTEFile", ] diff --git a/exir/_serialize/_program.py b/exir/_serialize/_program.py index bee5b3438b0..e4af45c08ce 100644 --- a/exir/_serialize/_program.py +++ b/exir/_serialize/_program.py @@ -21,7 +21,10 @@ _program_flatbuffer_to_json, _program_json_to_flatbuffer, ) -from executorch.exir._serialize._named_data_store import NamedDataStoreOutput +from executorch.exir._serialize._named_data_store import ( + NamedDataStore, + NamedDataStoreOutput, +) from executorch.exir._serialize.data_serializer import DataEntry @@ -46,6 +49,19 @@ _HEADER_BYTEORDER: Literal["little"] = "little" +@dataclass +class PTEFile: + """ + Wraps together the data required to serialize into a PTE file. + """ + + program: Program + # TODO(lfq): add constant data (currently restored in the program) + # TODO(lfq): update this to List[bytes] + mutable_data: Optional[List[Buffer]] = None + named_data: Optional[NamedDataStoreOutput] = None + + @dataclass class AlignedData: """ @@ -575,7 +591,7 @@ def serialize_pte_binary( return pte_data -def _restore_segments(program: Program, segment_data: bytes) -> Program: +def _restore_segments(program: Program, segment_data: bytes) -> PTEFile: """Moves segments from `segment_data` into `program`. This should recreate the original Program that the segments were extracted @@ -589,7 +605,7 @@ def _restore_segments(program: Program, segment_data: bytes) -> Program: the preceding data has been stripped off so that the first segment begins at offset zero. Returns: - The Program with segments restored. + PTEFile, containing the Program with delegate and constant segments restored, mutable data segment, and named data segment. """ # Extract the list of segment data blobs, which parallel program.segments. segments: List[bytes] = [] @@ -624,7 +640,7 @@ def _restore_segments(program: Program, segment_data: bytes) -> Program: # Replace constants from constant_segment into constant_buffer. if program.constant_segment and len(program.constant_segment.offsets) > 0: - buffers: List[Buffer] = [] + constant_buffers: List[Buffer] = [] constant_segment = segments[program.constant_segment.segment_index] for i in range(len(program.constant_segment.offsets)): start_offset = program.constant_segment.offsets[i] @@ -635,17 +651,60 @@ def _restore_segments(program: Program, segment_data: bytes) -> Program: if i < len(program.constant_segment.offsets) - 1 else len(constant_segment) ) - buffers.append(Buffer(storage=constant_segment[start_offset:end_offset])) - program.constant_buffer = buffers + constant_buffers.append( + Buffer(storage=constant_segment[start_offset:end_offset]) + ) + program.constant_buffer = constant_buffers program.constant_segment.segment_index = 0 program.constant_segment.offsets = [] - # Clear out the segments list since the original Program didn't have one. + # Extract mutable segments. + mutable_data = None + if program.mutable_data_segments and len(program.mutable_data_segments.offsets) > 0: + mutable_buffers: List[Buffer] = [] + mutable_segment = segments[program.mutable_segment.segment_index] + for i in range(len(program.mutable_segments.offsets)): + start_offset = program.mutable_segment.offsets[i] + # Note: this is the original end offset plus any padding between + # it and the next start offset. + end_offset = ( + program.mutable_segment.offsets[i + 1] + if i < len(program.mutable_segment.offsets) - 1 + else len(mutable_segment) + ) + mutable_buffers.append( + Buffer(storage=mutable_segment[start_offset:end_offset]) + ) + mutable_data = mutable_buffers + program.mutable_segment.segment_index = 0 + program.mutable_segment.offsets = [] + + # Extract named data. + named_data = None + if program.named_data: + named_data_store = NamedDataStore() + named_data_buffers: List[bytes] = [] + pte_data: Dict[str, DataEntry] = {} + + for entry in program.named_data: + if entry.segment_index >= len(segments): + raise ValueError( + "Named data segment index " + f"{entry.segment_index} >= num segments {len(segments)}" + ) + named_data_store.add_named_data( + key=entry.key, + data=segments[entry.segment_index], + alignment=1, # Deserialization does not preserve alignment. + tensor_layout=None, # PTE file currently does not serialize this. + ) + named_data = named_data_store.get_named_data_store_output() + program.named_data = [] program.segments = [] - return program + return PTEFile(program=program, mutable_data=mutable_data, named_data=named_data) -def deserialize_pte_binary(program_data: bytes) -> Program: +def deserialize_pte_binary(program_data: bytes) -> PTEFile: """Returns a Program deserialized from the given runtime binary data.""" program_size = len(program_data) segment_base_offset = 0 @@ -664,8 +723,8 @@ def deserialize_pte_binary(program_data: bytes) -> Program: if segment_base_offset != 0: # Move segment data back into the Program. - program = _restore_segments( + return _restore_segments( program=program, segment_data=program_data[segment_base_offset:] ) - return program + return PTEFile(program=program, mutable_data=None, named_data=None) diff --git a/exir/_serialize/test/test_program.py b/exir/_serialize/test/test_program.py index 80f4b8ca49f..b2a22694245 100644 --- a/exir/_serialize/test/test_program.py +++ b/exir/_serialize/test/test_program.py @@ -8,12 +8,13 @@ # pyre-unsafe import copy +import dataclasses import difflib import json import math import unittest -from typing import List, Sequence +from typing import Dict, List, Sequence from executorch.exir._serialize._flatbuffer import _program_flatbuffer_to_json from executorch.exir._serialize._named_data_store import NamedDataStoreOutput @@ -281,13 +282,43 @@ def constant_segment_with_tensor_alignment( ) # Convert back. - program2 = deserialize_pte_binary(pte_data) + deserialized = deserialize_pte_binary(pte_data) # Programs are the same besides constant_buffer, as deserialization # does not preserve constant segment; padding may be added # during serialization. - self.assertEqual(program2.execution_plan, program.execution_plan) + self.assertEqual(deserialized.program.execution_plan, program.execution_plan) # Number of constant tensors should be the same. - self.assertEqual(len(program2.constant_buffer), len(program.constant_buffer)) + self.assertEqual( + len(deserialized.program.constant_buffer), len(program.constant_buffer) + ) + self.assertEqual(deserialized.mutable_data, None) + self.assertEqual(deserialized.named_data, None) + + def _check_named_data_entries( + self, reference: Dict[str, DataEntry], actual: Dict[str, DataEntry] + ) -> None: + self.assertEqual(reference.keys(), actual.keys()) + SKIP_FIELDS = {"alignment"} # Fields to ignore in comparison. + for key in reference.keys(): + ref_entry = reference[key] + actual_entry = actual[key] + for field in dataclasses.fields(ref_entry): + if field.name not in SKIP_FIELDS: + self.assertEqual( + getattr(ref_entry, field.name), + getattr(actual_entry, field.name), + f"Named data record {key}.{field.name} does not match.", + ) + + def _check_named_data_store_output( + self, reference: NamedDataStoreOutput, actual: NamedDataStoreOutput + ) -> None: + # Check buffers. + self.assertEqual(reference.buffers, actual.buffers) + # Check pte_data. + self._check_named_data_entries(reference.pte_data, actual.pte_data) + # Should be empty. + self.assertEqual(reference.external_data, actual.external_data) def test_canonicalize_delegate_indices(self) -> None: def make_execution_plan( @@ -426,10 +457,12 @@ def test_round_trip_no_header_no_segments(self) -> None: self.assertIsNone(eh) # Convert back. - program2 = deserialize_pte_binary(pte_data) + deserialized = deserialize_pte_binary(pte_data) # Programs should be the same. - self.assert_programs_equal(program, program2) + self.assert_programs_equal(program, deserialized.program) + self.assertEqual(deserialized.mutable_data, None) + self.assertEqual(deserialized.named_data, None) def test_round_trip_large_buffer_sizes(self) -> None: """Tests that when the non_const_buffer_sizes contains integers @@ -439,7 +472,9 @@ def test_round_trip_large_buffer_sizes(self) -> None: program = get_test_program() program.execution_plan[0].non_const_buffer_sizes = [0, 2**48] flatbuffer_from_py = bytes(serialize_pte_binary(program)) - self.assert_programs_equal(program, deserialize_pte_binary(flatbuffer_from_py)) + self.assert_programs_equal( + program, deserialize_pte_binary(flatbuffer_from_py).program + ) def test_round_trip_no_segments_and_no_header(self) -> None: """Tests that a Program serialized with extract_delegate_segments=True @@ -463,10 +498,12 @@ def test_round_trip_no_segments_and_no_header(self) -> None: self.assertEqual(program_with_segments.segments, []) # Convert back. - program2 = deserialize_pte_binary(pte_data) + deserialized = deserialize_pte_binary(pte_data) # Programs should be the same. - self.assert_programs_equal(program, program2) + self.assert_programs_equal(program, deserialized.program) + self.assertEqual(deserialized.mutable_data, None) + self.assertEqual(deserialized.named_data, None) @staticmethod def gen_blob_data(size: int, pattern: bytes) -> bytes: @@ -598,8 +635,10 @@ def test_round_trip_with_segments(self) -> None: # meaning that the segments were moved back to inline. This also # demonstrates that the contents of all segments survived, and weren't # truncated or corrupted. - program2 = deserialize_pte_binary(pte_data) - self.assert_programs_equal(program, program2) + deserialized = deserialize_pte_binary(pte_data) + self.assert_programs_equal(program, deserialized.program) + self.assertEqual(deserialized.mutable_data, None) + self.assertEqual(deserialized.named_data, None) def test_no_constants(self) -> None: program = get_test_program() @@ -884,13 +923,17 @@ def test_constant_delegate_and_named_data_segments(self) -> None: ) # Convert back. - program2 = deserialize_pte_binary(pte_data) + deserialized = deserialize_pte_binary(pte_data) # Programs are the same besides constant_buffer, as deserialization # does not preserve constant segment; padding may be added # during serialization. - self.assertEqual(program2.execution_plan, program.execution_plan) + self.assertEqual(deserialized.program.execution_plan, program.execution_plan) # Number of constant tensors should be the same. - self.assertEqual(len(program2.constant_buffer), len(program.constant_buffer)) + self.assertEqual( + len(deserialized.program.constant_buffer), len(program.constant_buffer) + ) + self.assertEqual(deserialized.mutable_data, None) + self._check_named_data_store_output(deserialized.named_data, named_data) def test_named_data_segments(self) -> None: # Set segment alignment to 12 to test the padding. @@ -995,6 +1038,27 @@ def test_named_data_segments(self) -> None: buffers[2], ) + # Test roundtrip + deserialized = deserialize_pte_binary(pte_data) + self.assert_programs_equal(deserialized.program, program) + self.assertEqual(deserialized.mutable_data, None) + self._check_named_data_store_output(deserialized.named_data, named_data) + + # Test re-serialize + pte_data2 = serialize_pte_binary( + deserialized.program, + extract_delegate_segments=True, + segment_alignment=SEGMENT_ALIGNMENT, + constant_tensor_alignment=CONSTANT_TENSOR_ALIGNMENT, + named_data=deserialized.named_data, + ) + # pte_data2 is not going to be the same as pte_data due to alignment; + # directly test the deserialized one. + deserialized2 = deserialize_pte_binary(bytes(pte_data2)) + self.assert_programs_equal(deserialized2.program, program) + self.assertEqual(deserialized2.mutable_data, None) + self._check_named_data_store_output(deserialized2.named_data, named_data) + # Common data for extended header tests. The two example values should produce # the example data. diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py index 199a667ab64..4844088c0c2 100644 --- a/exir/emit/test/test_emit.py +++ b/exir/emit/test/test_emit.py @@ -1665,9 +1665,10 @@ def forward(self, x): self.assertEqual(values[5].val, Double(double_val=float("-inf"))) # Confirm that we can also deserialize the model with infinity in it. - pte_data = deserialize_pte_binary(model.buffer) + deserialize = deserialize_pte_binary(model.buffer) self.assertEqual( - pte_data.execution_plan, model.executorch_program.execution_plan + deserialize.program.execution_plan, + model.executorch_program.execution_plan, ) def test_mutate_input_tensor(self) -> None: From e9a1e37cdec90a7df21d93b37ded9a7ff5e179ad Mon Sep 17 00:00:00 2001 From: lucylq Date: Thu, 13 Nov 2025 15:55:03 -0800 Subject: [PATCH 2/5] Update on "Introduce PTEFile class" PTEFile class holds the components of a PTE file: the program, mutable constants and named data. Currently, the `program` definition does not contain mutable constants and named data; they are always stored in segments and not inline. This means when we deserialize, they are lost, because we only deserialize into the `program` concept. Now, segment data is included in the PTEFile class. Differential Revision: [D86814175](https://our.internmc.facebook.com/intern/diff/D86814175/) [ghstack-poisoned] From 9867dafdc3099fb8335d07a0d2c9e9aa6ddb2231 Mon Sep 17 00:00:00 2001 From: lucylq Date: Fri, 14 Nov 2025 15:01:04 -0800 Subject: [PATCH 3/5] Update on "Introduce PTEFile class" PTEFile class holds the components of a PTE file: the program, mutable constants and named data. Currently, the `program` definition does not contain mutable constants and named data; they are always stored in segments and not inline. This means when we deserialize, they are lost, because we only deserialize into the `program` concept. Now, segment data is included in the PTEFile class. Differential Revision: [D86814175](https://our.internmc.facebook.com/intern/diff/D86814175/) [ghstack-poisoned] --- exir/_serialize/_program.py | 72 ++++++++++++++++++------------------- 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/exir/_serialize/_program.py b/exir/_serialize/_program.py index e4af45c08ce..bf10f07245f 100644 --- a/exir/_serialize/_program.py +++ b/exir/_serialize/_program.py @@ -591,6 +591,31 @@ def serialize_pte_binary( return pte_data +def _restore_constant_segment( + constant_segment: SubsegmentOffsets, segment_data: bytes +) -> List[Buffer]: + """Convert constant and mutable tensors from a single byte-blob into a list of individual tensors. + + Args: + constant_segment: SubsegmentOffset with the offsets of each tensor. + segment_data: byte data containing the tensors and padding. + + Returns: + List[Buffer] containing each tensor in a separate object. + """ + buffers: List[Buffer] = [] + for i in range(len(constant_segment.offsets)): + start_offset = constant_segment.offsets[i] + # Note: this is the original end offset plus any padding between it and the next start offset + end_offset = ( + constant_segment.offsets[i + 1] + if i < len(constant_segment.offsets) - 1 + else len(segment_data) + ) + buffers.append(Buffer(storage=segment_data[start_offset:end_offset])) + return buffers + + def _restore_segments(program: Program, segment_data: bytes) -> PTEFile: """Moves segments from `segment_data` into `program`. @@ -640,52 +665,27 @@ def _restore_segments(program: Program, segment_data: bytes) -> PTEFile: # Replace constants from constant_segment into constant_buffer. if program.constant_segment and len(program.constant_segment.offsets) > 0: - constant_buffers: List[Buffer] = [] - constant_segment = segments[program.constant_segment.segment_index] - for i in range(len(program.constant_segment.offsets)): - start_offset = program.constant_segment.offsets[i] - # Note: this is the original end offset plus any padding between - # it and the next start offset. - end_offset = ( - program.constant_segment.offsets[i + 1] - if i < len(program.constant_segment.offsets) - 1 - else len(constant_segment) - ) - constant_buffers.append( - Buffer(storage=constant_segment[start_offset:end_offset]) - ) - program.constant_buffer = constant_buffers + program.constant_buffer = _restore_constant_segment( + program.constant_segment, segments[program.constant_segment.segment_index] + ) program.constant_segment.segment_index = 0 program.constant_segment.offsets = [] # Extract mutable segments. mutable_data = None - if program.mutable_data_segments and len(program.mutable_data_segments.offsets) > 0: - mutable_buffers: List[Buffer] = [] - mutable_segment = segments[program.mutable_segment.segment_index] - for i in range(len(program.mutable_segments.offsets)): - start_offset = program.mutable_segment.offsets[i] - # Note: this is the original end offset plus any padding between - # it and the next start offset. - end_offset = ( - program.mutable_segment.offsets[i + 1] - if i < len(program.mutable_segment.offsets) - 1 - else len(mutable_segment) - ) - mutable_buffers.append( - Buffer(storage=mutable_segment[start_offset:end_offset]) - ) - mutable_data = mutable_buffers - program.mutable_segment.segment_index = 0 - program.mutable_segment.offsets = [] + if program.mutable_data_segments and len(program.mutable_data_segments) > 0: + if len(program.mutable_data_segments) > 1: + raise ValueError("Cant't handle more than 1 mutable data segment.") + mutable_data = _restore_constant_segment( + program.mutable_data_segments[0], + segments[program.mutable_data_segments[0].segment_index], + ) + program.mutable_data_segments = None # Extract named data. named_data = None if program.named_data: named_data_store = NamedDataStore() - named_data_buffers: List[bytes] = [] - pte_data: Dict[str, DataEntry] = {} - for entry in program.named_data: if entry.segment_index >= len(segments): raise ValueError( From d83e35c08a6e4d11686e7126459135c3db8ff772 Mon Sep 17 00:00:00 2001 From: lucylq Date: Fri, 14 Nov 2025 15:04:28 -0800 Subject: [PATCH 4/5] Update on "Introduce PTEFile class" PTEFile class holds the components of a PTE file: the program, mutable constants and named data. Currently, the `program` definition does not contain mutable constants and named data; they are always stored in segments and not inline. This means when we deserialize, they are lost, because we only deserialize into the `program` concept. Now, segment data is included in the PTEFile class. Differential Revision: [D86814175](https://our.internmc.facebook.com/intern/diff/D86814175/) [ghstack-poisoned] From c6d72f53c9f25e06e6bf38016c6609164beb2855 Mon Sep 17 00:00:00 2001 From: lucylq Date: Fri, 14 Nov 2025 15:16:22 -0800 Subject: [PATCH 5/5] Update on "Introduce PTEFile class" PTEFile class holds the components of a PTE file: the program, mutable constants and named data. Currently, the `program` definition does not contain mutable constants and named data; they are always stored in segments and not inline. This means when we deserialize, they are lost, because we only deserialize into the `program` concept. Now, segment data is included in the PTEFile class. Differential Revision: [D86814175](https://our.internmc.facebook.com/intern/diff/D86814175/) [ghstack-poisoned]