Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/cadence/runtime/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def get_op_names(program: et_schema.Program, execution_plan_id: int = 0) -> set[
op_names |= get_op_names(
deserialize_pte_binary(
program.backend_delegate_data[delegate.processed.index].data
)
).program
)
return op_names

Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def dump_context_from_pte(pte_path) -> List[str]:
with open(pte_path, "rb") as f:
program_data = f.read()

program = deserialize_pte_binary(program_data)
program = deserialize_pte_binary(program_data).program

ctx_path = os.path.dirname(pte_path)
dummy_compiler_specs = generate_qnn_executorch_compiler_spec(
Expand Down
2 changes: 1 addition & 1 deletion codegen/tools/gen_ops_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def get_operators(model_file: str) -> List[Operator]:
print("Processing model file: ", model_file)
with open(model_file, "rb") as f:
flatbuffer = f.read()
program = _deserialize_pte_binary(flatbuffer)
program = _deserialize_pte_binary(flatbuffer).program
print(f"Program loaded from model file: {model_file}")
operators = program.execution_plan[0].operators
return operators
Expand Down
2 changes: 1 addition & 1 deletion examples/qualcomm/oss_scripts/llama/decoder_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def __init__( # noqa: C901

with open(pte_path, "rb") as f:
program_data = f.read()
program = deserialize_pte_binary(program_data)
program = deserialize_pte_binary(program_data).program

# Retrieve vocab_size from get_metadata under static_llama that is passed to edge manager
self.output_vocab_size = None
Expand Down
2 changes: 2 additions & 0 deletions exir/_serialize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@

from executorch.exir._serialize._program import (
deserialize_pte_binary as _deserialize_pte_binary,
PTEFile as _PTEFile,
serialize_pte_binary as _serialize_pte_binary,
)

# Internal APIs that should not be used outside of exir.
__all__ = [
"_deserialize_pte_binary",
"_serialize_pte_binary",
"_PTEFile",
]
182 changes: 140 additions & 42 deletions exir/_serialize/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
_program_flatbuffer_to_json,
_program_json_to_flatbuffer,
)
from executorch.exir._serialize._named_data_store import NamedDataStoreOutput
from executorch.exir._serialize._named_data_store import (
NamedDataStore,
NamedDataStoreOutput,
)

from executorch.exir._serialize.data_serializer import DataEntry

Expand All @@ -46,6 +49,19 @@
_HEADER_BYTEORDER: Literal["little"] = "little"


@dataclass
class PTEFile:
    """
    Wraps together the data required to serialize into a PTE file.

    Also produced by deserialization, so that segment data (mutable and
    named) extracted from a PTE binary is returned alongside the Program
    instead of being lost.
    """

    # The flatbuffer Program. Constant and delegate segments are restored
    # inline into it during deserialization.
    program: Program
    # TODO(lfq): add constant data (currently restored in the program)
    # TODO(lfq): update this to List[bytes]
    mutable_data: Optional[List[Buffer]] = None
    # Named-data entries carried in the PTE file, if any.
    named_data: Optional[NamedDataStoreOutput] = None


@dataclass
class AlignedData:
"""
Expand Down Expand Up @@ -575,7 +591,91 @@ def serialize_pte_binary(
return pte_data


def _restore_segments(program: Program, segment_data: bytes) -> Program:
def _restore_delegates(program: Program, segments: List[bytes]) -> Program:
    """Inline segment-resident delegate blobs back into the Program.

    Args:
        program: The Program holding non-inlined delegates. Modified in-place.
        segments: List of bytes containing the delegate data. Not modified.

    Returns:
        The Program with every delegate reference restored to INLINE.

    Raises:
        ValueError: If a delegate references a segment index that is out of
            range for `segments`.
    """
    num_segments = len(segments)
    for plan_index, plan in enumerate(program.execution_plan):
        for delegate_index, delegate in enumerate(plan.delegates):
            reference = delegate.processed
            if reference.location == DataLocation.INLINE:
                continue
            assert reference.location == DataLocation.SEGMENT
            segment_index = reference.index
            if segment_index >= num_segments:
                raise ValueError(
                    f"Plan {plan_index} delegate {delegate_index} "
                    f"segment index {segment_index} >= num segments {num_segments}"
                )

            # Append the blob as inline data and repoint the delegate at it.
            inline_index = len(program.backend_delegate_data)
            program.backend_delegate_data.append(
                BackendDelegateInlineData(data=segments[segment_index])
            )
            delegate.processed = BackendDelegateDataReference(
                location=DataLocation.INLINE, index=inline_index
            )
    return program


def _restore_constant_segment(
    constant_segment: SubsegmentOffsets, segment_data: bytes
) -> List[Buffer]:
    """Convert constant and mutable tensors from a single byte-blob into a list of individual tensors.

    Args:
        constant_segment: SubsegmentOffset with the offsets of each tensor.
        segment_data: byte data containing the tensors and padding. Not modified.

    Returns:
        List[Buffer] containing each tensor in a separate object.
    """
    starts = constant_segment.offsets
    # Each tensor runs from its start offset to the next tensor's start; the
    # final tensor runs to the end of the blob, so it absorbs any trailing
    # padding just as the original end offset plus padding did.
    ends = list(starts[1:]) + [len(segment_data)]
    return [
        Buffer(storage=segment_data[start:end]) for start, end in zip(starts, ends)
    ]


def _restore_named_data(
    program: Program,
    segments: List[bytes],
) -> NamedDataStoreOutput:
    """Moves named data from `segments` and `program` into the
    NamedDataStoreOutput class.

    Args:
        program: The Program holding named data references. Not modified.
        segments: The data containing the segments. Not modified.

    Raises:
        ValueError: If an entry references a segment index that is out of
            range for `segments`.
    """
    store = NamedDataStore()
    num_segments = len(segments)
    for entry in program.named_data:
        index = entry.segment_index
        if index >= num_segments:
            raise ValueError(
                "Named data segment index "
                f"{index} >= num segments {num_segments}"
            )
        store.add_named_data(
            key=entry.key,
            data=segments[index],
            alignment=1,  # Deserialization does not preserve alignment.
            tensor_layout=None,  # PTE file currently does not serialize this.
        )
    return store.get_named_data_store_output()


def _restore_segments(program: Program, segment_data: bytes) -> PTEFile:
"""Moves segments from `segment_data` into `program`.

This should recreate the original Program that the segments were extracted
Expand All @@ -589,7 +689,7 @@ def _restore_segments(program: Program, segment_data: bytes) -> Program:
the preceding data has been stripped off so that the first segment
begins at offset zero.
Returns:
The Program with segments restored.
PTEFile, containing the Program with delegate and constant segments restored, mutable data segment, and named data segment.
"""
# Extract the list of segment data blobs, which parallel program.segments.
segments: List[bytes] = []
Expand All @@ -600,53 +700,51 @@ def _restore_segments(program: Program, segment_data: bytes) -> Program:
)
segments.append(segment_data[segment.offset : segment.offset + segment.size])

# Find and replace the Program's references to these segments, inlining the
# data.
for plan_index, plan in enumerate(program.execution_plan):
for delegate_index, delegate in enumerate(plan.delegates):
if delegate.processed.location == DataLocation.INLINE:
continue
assert delegate.processed.location == DataLocation.SEGMENT
index = delegate.processed.index
if index >= len(segments):
raise ValueError(
f"Plan {plan_index} delegate {delegate_index} "
+ f"segment index {index} >= num segments {len(segments)}"
)

data_index: int = len(program.backend_delegate_data)
program.backend_delegate_data.append(
BackendDelegateInlineData(data=segments[index])
)
delegate.processed = BackendDelegateDataReference(
location=DataLocation.INLINE, index=data_index
)
# Restore delegate segments that weren't inlined previously.
program = _restore_delegates(program, segments)

# Replace constants from constant_segment into constant_buffer.
if program.constant_segment and len(program.constant_segment.offsets) > 0:
buffers: List[Buffer] = []
constant_segment = segments[program.constant_segment.segment_index]
for i in range(len(program.constant_segment.offsets)):
start_offset = program.constant_segment.offsets[i]
# Note: this is the original end offset plus any padding between
# it and the next start offset.
end_offset = (
program.constant_segment.offsets[i + 1]
if i < len(program.constant_segment.offsets) - 1
else len(constant_segment)
if program.constant_segment.segment_index >= len(segments):
raise ValueError(
f"Constant segment index {program.constant_segment.segment_index} >= num segments {len(segments)}"
)
buffers.append(Buffer(storage=constant_segment[start_offset:end_offset]))
program.constant_buffer = buffers
program.constant_buffer = _restore_constant_segment(
program.constant_segment, segments[program.constant_segment.segment_index]
)
program.constant_segment.segment_index = 0
program.constant_segment.offsets = []

# Clear out the segments list since the original Program didn't have one.
# Extract mutable segments.
mutable_data = None
if program.mutable_data_segments and len(program.mutable_data_segments) > 0:
if len(program.mutable_data_segments) > 1:
raise ValueError("Can't handle more than 1 mutable data segment.")
segment_index = program.mutable_data_segments[0].segment_index
if segment_index >= len(segments):
raise ValueError(
f"Mutable data segment index {segment_index} >= num segments {len(segments)}"
)
mutable_data = _restore_constant_segment(
program.mutable_data_segments[0],
segments[segment_index],
)
program.mutable_data_segments = None

# Extract named data.
named_data = None
if program.named_data:
named_data = _restore_named_data(program, segments)

# Clear named_data and segments, which are empty pre-serialization.
program.named_data = []
program.segments = []
return program

return PTEFile(program=program, mutable_data=mutable_data, named_data=named_data)

def deserialize_pte_binary(program_data: bytes) -> Program:
"""Returns a Program deserialized from the given runtime binary data."""

def deserialize_pte_binary(program_data: bytes) -> PTEFile:
"""Returns a PTEFile deserialized from the given runtime binary data."""
program_size = len(program_data)
segment_base_offset = 0

Expand All @@ -664,8 +762,8 @@ def deserialize_pte_binary(program_data: bytes) -> Program:

if segment_base_offset != 0:
# Move segment data back into the Program.
program = _restore_segments(
return _restore_segments(
program=program, segment_data=program_data[segment_base_offset:]
)

return program
return PTEFile(program=program, mutable_data=None, named_data=None)
Loading
Loading