forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
1926 lines (1544 loc) · 65.2 KB
/
utils.py
File metadata and controls
1926 lines (1544 loc) · 65.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import operator
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple, Union
import torch
from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
VkDataType,
VkMemoryLayout,
VkStorageType,
)
from executorch.exir.backend.canonical_partitioners.config_partitioner import (
format_target_name,
)
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.tensor import TensorSpec
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorConverter
from torch.export import ExportedProgram
from torch.export.exported_program import InputKind
from torch.export.graph_signature import TensorArgument
# Anything that can identify an operator: an edge-dialect overload, an ATen
# OpOverload, or a string.
TorchOpType = Union[EdgeOpOverload, torch._ops.OpOverload, str]
# Operator names (as produced by format_target_name) that is_dequant_node()
# treats as dequantize ops.
_DQ_OPS = {
    "dequantize_per_tensor.tensor",
    "dequantize_per_tensor.default",
    "dequantize_per_channel.default",
    "dequantize_per_channel_group.default",
    "dequantize_per_token.default",
    "dequantize_affine.default",
}
# Operator names (as produced by format_target_name) that is_quant_node()
# treats as quantize ops.
_Q_OPS = {
    "quantize_per_tensor.tensor",
    "quantize_per_tensor.default",
    "quantize_per_channel.default",
    "quantize_per_token.default",
    "quantize_affine.default",
}
# torch dtypes the Vulkan backend can serialize, mapped to their VkDataType.
# Any dtype missing from this map is rejected by get_vk_datatype() and reported
# as unsupported by output_dtypes_are_supported().
_VULKAN_DTYPES: Dict[torch.dtype, VkDataType] = {
    torch.bool: VkDataType.BOOL,
    torch.uint8: VkDataType.UINT8,
    torch.int8: VkDataType.INT8,
    torch.int32: VkDataType.INT32,
    torch.int64: VkDataType.INT64,
    torch.float16: VkDataType.FLOAT16,
    torch.float32: VkDataType.FLOAT32,
    torch.float64: VkDataType.FLOAT64,
}
##
## Dtype sets for per-operator dtype constraints
##
# The set of dtypes allowed at one tensor position of an operator.
DtypeSet = Set[torch.dtype]
FP_T: DtypeSet = {torch.float16, torch.float32}
INT_T: DtypeSet = {torch.int32, torch.int64}
QINT8_T: DtypeSet = {torch.int8}
BOOL_T: DtypeSet = {torch.bool}
# Every dtype that has a Vulkan mapping.
ALL_T: DtypeSet = set(_VULKAN_DTYPES.keys())
NONE_T: DtypeSet = set()  # Marker for non-tensor args (skip validation)
# Composite dtype sets for specific operator requirements
FP_INT_T: DtypeSet = FP_T | INT_T
FP_INT_BOOL_T: DtypeSet = FP_T | INT_T | BOOL_T
class DtypeSetList:
    """
    Allowed-dtype sets indexed by tensor position, with broadcasting.

    A single DtypeSet, or a one-element list, applies to every position. For a
    longer list, any non-negative index past the end yields an empty set, which
    means "unconstrained".
    """

    def __init__(self, dtype_sets: Union[DtypeSet, List[DtypeSet]]):
        # Wrap a bare set; a list is stored as-is (not copied).
        if not isinstance(dtype_sets, list):
            dtype_sets = [dtype_sets]
        self.vals: List[DtypeSet] = dtype_sets

    def __len__(self) -> int:
        return len(self.vals)

    def __getitem__(self, idx: int) -> DtypeSet:
        count = len(self.vals)
        # A single entry is broadcast to every position.
        if count == 1 and idx > 0:
            return self.vals[0]
        return self.vals[idx] if idx < count else set()

    def is_empty(self) -> bool:
        return not self.vals

    def any_constrained(self) -> bool:
        """Return True if at least one position restricts the allowed dtypes."""
        return any(map(len, self.vals))
##
## Node type determination
##
# Convenience type: a single FX node, or a list / tuple of FX nodes of any
# length. The tuple form must be ``Tuple[X, ...]``; a bare ``Tuple[X]`` would
# only describe 1-element tuples.
MaybeNodeList = Union[
    torch.fx.Node, List[torch.fx.Node], Tuple[torch.fx.Node, ...]
]
def is_torch_op_node(node: torch.fx.Node) -> bool:
    """
    Return True if ``node`` is a call_function node whose target is an operator
    overload: either an ``EdgeOpOverload`` or a ``torch._ops.OpOverload``.
    """
    if node.op == "call_function":
        return isinstance(node.target, (EdgeOpOverload, torch._ops.OpOverload))
    return False
def is_dequant_node(node: torch.fx.Node) -> bool:
    """Return True if ``node`` calls one of the dequantize ops in ``_DQ_OPS``."""
    if node.op == "call_function":
        target_name = format_target_name(node.target.__name__)  # pyre-ignore
        return target_name in _DQ_OPS
    return False
def is_quant_node(node: torch.fx.Node) -> bool:
    """Return True if ``node`` calls one of the quantize ops in ``_Q_OPS``."""
    if node.op == "call_function":
        target_name = format_target_name(node.target.__name__)  # pyre-ignore
        return target_name in _Q_OPS
    return False
def is_choose_qparams_node(node: torch.fx.Node) -> bool:
    """
    Return True if ``node`` is a call_function node whose formatted target name
    contains the substring "choose_qparams".
    """
    if node.op == "call_function":
        target_name = format_target_name(node.target.__name__)  # pyre-ignore
        return "choose_qparams" in target_name
    return False
def is_dynamic_qscale(node: Any) -> bool:
    """
    Return True if ``node`` is a scale computed at runtime, i.e. an
    ``operator.getitem`` node that indexes into the output of a choose_qparams op.
    """
    if not isinstance(node, torch.fx.Node):
        return False
    if node.target != operator.getitem:
        return False
    return is_choose_qparams_node(node.args[0])
def is_dequant_per_channel_node(node: torch.fx.Node) -> bool:
    """Return True if ``node`` calls ``dequantize_per_channel.default``."""
    if node.op == "call_function":
        target_name = format_target_name(node.target.__name__)  # pyre-ignore
        return target_name == "dequantize_per_channel.default"
    return False
def is_view_copy_node(node: torch.fx.Node) -> bool:
    """
    Return True if ``node`` is a call_function node whose formatted target name
    contains "view_copy".
    """
    if node.op == "call_function":
        target_name = format_target_name(node.target.__name__)  # pyre-ignore
        return "view_copy" in target_name
    return False
def is_linear_node(node: torch.fx.Node) -> bool:
    """Return True if ``node`` calls ``linear.default``."""
    if node.op == "call_function":
        target_name = format_target_name(node.target.__name__)  # pyre-ignore
        return target_name == "linear.default"
    return False
def is_get_attr_node(node: torch.fx.Node) -> bool:
    """Return True if ``node`` is an FX node with op "get_attr"."""
    if isinstance(node, torch.fx.Node):
        return node.op == "get_attr"
    return False
def is_constant(program: ExportedProgram, node: torch.fx.Node) -> bool:
return node.name in program.graph_signature.inputs_to_lifted_tensor_constants
def is_param_node(program: ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Check if the given node is a parameter-like input of the exported program.

    That is, a get_attr node, or a node that torch's ``is_param``,
    ``is_buffer`` or ``is_lifted_tensor_constant`` helpers recognize for
    ``program``.
    """
    if is_get_attr_node(node):
        return True
    signature_checks = (is_param, is_buffer, is_lifted_tensor_constant)
    return any(check(program, node) for check in signature_checks)
def is_mutable_buffer_node(
node: torch.fx.Node, exported_program: ExportedProgram
) -> bool:
if node.target not in exported_program.graph_signature.inputs_to_buffers:
return False
buf = exported_program.graph_signature.inputs_to_buffers[node.target]
return buf in exported_program.graph_signature.buffers_to_mutate.values()
def is_symint_node(node: torch.fx.Node) -> bool:
    """
    Return True if the node's ``meta["val"]`` exists and is a ``torch.SymInt``.
    """
    return "val" in node.meta and isinstance(node.meta["val"], torch.SymInt)
def is_single_tensor_node(node: torch.fx.Node) -> bool:
    """
    Return True if the node's ``meta["val"]`` exists and is a single FakeTensor,
    rather than a collection.
    """
    return "val" in node.meta and isinstance(node.meta["val"], FakeTensor)
def is_tensor_collection_node(node: Any) -> bool:
    """
    Return True if ``node`` is an FX node whose ``meta["val"]`` is a list or
    tuple made up entirely of FakeTensors. An empty list/tuple also counts.
    """
    if not isinstance(node, torch.fx.Node) or "val" not in node.meta:
        return False
    val = node.meta["val"]
    if not isinstance(val, (list, tuple)):
        return False
    return all(isinstance(item, FakeTensor) for item in val)
def is_tensor_node(node: Any) -> bool:
    """
    Return True if ``node`` is an FX node whose ``meta["val"]`` is a FakeTensor,
    or a list/tuple made up entirely of FakeTensors (an empty one included).
    """
    if not isinstance(node, torch.fx.Node) or "val" not in node.meta:
        return False
    val = node.meta["val"]
    if isinstance(val, (list, tuple)):
        return all(isinstance(item, FakeTensor) for item in val)
    return isinstance(val, FakeTensor)
def is_tensor_arg_node(node: Any) -> bool:
    """
    Return True if an operator argument is a tensor node, or a non-empty
    list/tuple in which every element is a tensor node.
    """
    if isinstance(node, (list, tuple)):
        return len(node) > 0 and all(is_tensor_node(item) for item in node)
    if isinstance(node, torch.fx.Node):
        return is_tensor_node(node)
    return False
def num_tensor_arg_nodes(node: torch.fx.Node) -> int:
    """
    Count the positional arguments of ``node`` that are themselves FX nodes
    producing tensors. Lists of nodes (e.g. for cat) are not counted.
    """
    return sum(
        1
        for arg in node.args
        if isinstance(arg, torch.fx.Node) and is_tensor_node(arg)
    )
def num_tensors_in_node(node: torch.fx.Node) -> int:
    """
    Return how many tensors the node's ``meta["val"]`` holds: 1 for a single
    FakeTensor, the length of a list/tuple made up entirely of FakeTensors, and
    0 in every other case (including a missing ``val``).
    """
    if "val" not in node.meta:
        return 0
    val = node.meta["val"]
    if isinstance(val, FakeTensor):
        return 1
    holds_only_tensors = isinstance(val, (list, tuple)) and all(
        isinstance(item, FakeTensor) for item in val
    )
    return len(val) if holds_only_tensors else 0
def get_vk_datatype(torch_dtype: torch.dtype) -> VkDataType:
    """
    Map a torch dtype to its Vulkan serialization dtype.

    Raises:
        AssertionError: if ``torch_dtype`` has no entry in ``_VULKAN_DTYPES``.
    """
    if torch_dtype in _VULKAN_DTYPES:
        return _VULKAN_DTYPES[torch_dtype]
    raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})")
def output_dtypes_are_supported(node: torch.fx.Node) -> bool:
    """
    Return True if every tensor produced by ``node`` has a dtype listed in
    ``_VULKAN_DTYPES``. Nodes that do not produce tensors trivially pass.
    """
    if not is_tensor_node(node):
        return True
    # is_tensor_node() guarantees that meta["val"] is present.
    node_val = node.meta.get("val", None)
    assert node_val is not None
    if isinstance(node_val, FakeTensor):
        produced_dtypes = [node_val.dtype]
    elif isinstance(node_val, (list, tuple)):
        produced_dtypes = [tensor.dtype for tensor in node_val]
    else:
        produced_dtypes = []
    return all(dtype in _VULKAN_DTYPES for dtype in produced_dtypes)
def input_dtypes_are_supported(node: torch.fx.Node) -> bool:
    """
    Return True if every positional argument of a tensor-producing ``node``
    only yields dtypes listed in ``_VULKAN_DTYPES``. Both single-node arguments
    and list/tuple arguments (e.g. the first argument of cat) are checked.
    Non-tensor nodes trivially pass.
    """
    if not is_tensor_node(node):
        return True
    for arg in node.args:
        if isinstance(arg, torch.fx.Node):
            group = (arg,)
        elif isinstance(arg, (list, tuple)):
            group = arg
        else:
            continue
        if not all(output_dtypes_are_supported(item) for item in group):
            return False
    return True
def io_dtypes_are_supported(node: torch.fx.Node) -> bool:
    """
    Return True if both the outputs and the inputs of ``node`` only involve
    dtypes supported by the Vulkan backend. Outputs are checked first.
    """
    return output_dtypes_are_supported(node) and input_dtypes_are_supported(node)
def _node_tensor_dtypes(node: torch.fx.Node) -> List[torch.dtype]:
    """Return the dtype of each tensor in a tensor node's ``meta["val"]``."""
    val = node.meta["val"]
    if isinstance(val, (list, tuple)):
        return [tensor.dtype for tensor in val]
    return [val.dtype]


def check_node_dtypes(  # noqa: C901
    node: torch.fx.Node,
    inputs_dtypes: DtypeSetList,
    outputs_dtypes: DtypeSetList,
) -> Tuple[bool, str]:
    """
    Check that every tensor input and output of ``node`` has an allowed dtype.

    Args:
        node: the FX node to validate.
        inputs_dtypes: allowed dtypes per positional argument. An empty set at
            a position means that argument is not checked.
        outputs_dtypes: allowed dtypes per output tensor. An empty set means
            that output is not checked.

    Returns:
        ``(True, "dtypes valid")`` if every check passes. Otherwise ``False``
        and a message describing the first offending input or output.
    """
    # Check input tensor dtypes
    for i, arg in enumerate(node.args):
        allowed_dtypes = inputs_dtypes[i]
        # Skip non-constrained positions (NO_DTYPE = empty set)
        if len(allowed_dtypes) == 0:
            continue
        if is_tensor_node(arg):
            # Check every tensor of a collection-valued arg, not only the first
            # one. This also avoids an IndexError when the collection is empty.
            for arg_dtype in _node_tensor_dtypes(arg):
                if arg_dtype not in allowed_dtypes:
                    return False, f"input[{i}] dtype {arg_dtype} not in {allowed_dtypes}"
        elif isinstance(arg, (list, tuple)):
            # Handle tensor list inputs (e.g., cat)
            for j, sub_arg in enumerate(arg):
                if not is_tensor_node(sub_arg):
                    continue
                # A sub-arg may itself be collection-valued; check all of it.
                for sub_dtype in _node_tensor_dtypes(sub_arg):
                    if sub_dtype not in allowed_dtypes:
                        return (
                            False,
                            f"input[{i}][{j}] dtype {sub_dtype} not in {allowed_dtypes}",
                        )
    # Check output tensor dtypes
    out_val = node.meta.get("val")
    if isinstance(out_val, FakeTensor):
        allowed_dtypes = outputs_dtypes[0]
        if len(allowed_dtypes) > 0 and out_val.dtype not in allowed_dtypes:
            return False, f"output dtype {out_val.dtype} not in {allowed_dtypes}"
    elif isinstance(out_val, (list, tuple)):
        for i, t in enumerate(out_val):
            if isinstance(t, FakeTensor):
                allowed_dtypes = outputs_dtypes[i]
                if len(allowed_dtypes) > 0 and t.dtype not in allowed_dtypes:
                    return False, f"output[{i}] dtype {t.dtype} not in {allowed_dtypes}"
    return True, "dtypes valid"
def tensor_node_is_bool(node: torch.fx.Node) -> bool:
    """
    Return True if the node's ``meta["val"]`` is a bool FakeTensor, or a
    list/tuple that contains at least one bool FakeTensor.
    Raises KeyError if ``meta["val"]`` is missing.
    """
    val = node.meta["val"]
    if isinstance(val, FakeTensor):
        return val.dtype == torch.bool
    if isinstance(val, (list, tuple)):
        return any(
            isinstance(item, FakeTensor) and item.dtype == torch.bool for item in val
        )
    return False
def ndim_of(node: Any) -> Optional[int]:
    """
    Return the rank of the single tensor produced by ``node``, or None if the
    node does not produce exactly one tensor.
    """
    return node.meta["val"].ndim if is_single_tensor_node(node) else None
def is_unsqueezed_vector(node: torch.fx.Node) -> bool:
    """
    Return True if ``node`` produces a single tensor of rank >= 1 in which
    every dimension except the last has size 1. The last dimension may have
    any size.
    """
    if not is_single_tensor_node(node):
        return False
    fake = node.meta["val"]
    assert isinstance(fake, FakeTensor)
    shape = fake.shape
    return len(shape) >= 1 and all(size == 1 for size in shape[:-1])
def op_contains_bool_tensor(node: torch.fx.Node) -> bool:
    """
    Return True if the output of ``node``, or any of its positional arguments,
    is a tensor node holding a bool tensor. The output is checked first.
    """
    candidates = [node, *node.args]
    return any(
        # pyre-ignore[6]
        is_tensor_node(candidate) and tensor_node_is_bool(candidate)
        for candidate in candidates
    )
def op_contains_high_dim_tensor(node: torch.fx.Node) -> bool:
    """
    Return True if the output of ``node``, or any of its positional arguments,
    is a tensor node holding a tensor with more than 4 dimensions.
    """
    candidates = [node, *node.args]
    return any(
        # pyre-ignore[6]
        is_tensor_node(candidate) and tensor_node_is_high_dim(candidate)
        for candidate in candidates
    )
def get_primary_arg_idx(self, node: torch.fx.Node) -> Optional[int]:
    """
    Return the index of the first positional argument of ``node`` for which
    ``self.is_non_constant_tensor_node(arg)`` is truthy, or None if none is.

    Although this is defined at module level, it takes a ``self`` object that
    must provide an ``is_non_constant_tensor_node`` method.
    """
    return next(
        (
            position
            for position, arg in enumerate(node.args)
            if self.is_non_constant_tensor_node(arg)
        ),
        None,
    )
def node_comes_from_any_nn_module_in_set(
    node,
    nn_module_typenames: Set[str],
) -> bool:
    """
    Return True if ``node`` was traced from an nn.Module whose type name
    contains any of the strings in ``nn_module_typenames``.

    The check uses the node's ``meta["nn_module_stack"]``, whose values are
    unpacked as ``(_, typename)`` pairs. A list/tuple of nodes passes only if
    every element passes, so an empty one passes. Anything that is not an FX
    node, and any node without a module stack, fails.
    """
    if isinstance(node, (list, tuple)):
        return all(
            node_comes_from_any_nn_module_in_set(member, nn_module_typenames)
            for member in node
        )
    if not isinstance(node, torch.fx.Node):
        return False
    module_stack = node.meta.get("nn_module_stack", None)
    if module_stack is None:
        return False
    return any(
        fragment in typename
        for _, typename in module_stack.values()
        for fragment in nn_module_typenames
    )
def get_tensor_name(exp_prog: ExportedProgram, node: torch.fx.Node) -> str:
if node is None:
return ""
if is_param(exp_prog, node):
return exp_prog.graph_signature.inputs_to_parameters[node.name]
elif is_buffer(exp_prog, node):
return exp_prog.graph_signature.inputs_to_buffers[node.name]
elif is_lifted_tensor_constant(exp_prog, node):
return exp_prog.graph_signature.inputs_to_lifted_tensor_constants[node.name]
else:
assert isinstance(node.target, str)
return node.target
return ""
def find_dequant_user(node: torch.fx.Node) -> Optional[torch.fx.Node]:
    """
    Return the first direct user of ``node``, in ``node.users`` order, that is
    a dequantization op, or None if no user is.
    """
    return next((user for user in node.users if is_dequant_node(user)), None)
def find_quant_user(node: torch.fx.Node) -> Optional[torch.fx.Node]:
    """
    Return the first direct user of ``node``, in ``node.users`` order, that is
    a quantization op, or None if no user is.
    """
    return next((user for user in node.users if is_quant_node(user)), None)
def node_has_target(node: Any, target: str):
    """
    Return True if ``node`` has a ``target`` matching ``target``.

    A string target is compared directly. Any other target that has a ``name``
    attribute is compared through ``target.name()``. Objects without a
    ``target``, and targets with neither form, never match.
    """
    missing = object()
    node_target = getattr(node, "target", missing)
    if node_target is missing:
        return False
    if isinstance(node_target, str):
        return node_target == target
    return hasattr(node_target, "name") and node_target.name() == target
##
## Memory Layout, Storage Type Determination
##
# (width, height, depth) extents of a 3D image, the form returned by
# required_image_extents().
ImageExtents = Tuple[int, int, int]
# Fallback texture extent limits, ordered like ImageExtents.
# NOTE(review): presumably used when actual device limits are unknown — confirm.
DEFAULT_TEXTURE_LIMITS = (16384, 16384, 2048)
# Fallback buffer limit. within_buffer_limit() compares tensor element counts
# against such a limit, so this is presumably 128 Mi elements — TODO confirm.
DEFAULT_BUFFER_LIMIT = 128 * (1024 * 1024)
# Storage types considered here: GPU buffers and 3D textures.
all_storage_types: Set[VkStorageType] = {
    VkStorageType.BUFFER,
    VkStorageType.TEXTURE_3D,
}
# Memory layouts available to non-quantized tensors
all_memory_layouts: Set[VkMemoryLayout] = {
    VkMemoryLayout.TENSOR_WIDTH_PACKED,
    VkMemoryLayout.TENSOR_HEIGHT_PACKED,
    VkMemoryLayout.TENSOR_CHANNELS_PACKED,
}
# Memory layouts available to quantized tensors
all_quantized_memory_layouts: Set[VkMemoryLayout] = {
    VkMemoryLayout.PACKED_INT8_4W4C,
    VkMemoryLayout.PACKED_INT8_4H4W,
    VkMemoryLayout.PACKED_INT8_4W,
    VkMemoryLayout.PACKED_INT8_4C1W,
    VkMemoryLayout.PACKED_INT8_CONV2D,
}
# Union of the non-quantized and quantized memory layouts.
universal_memory_layout_set: Set[VkMemoryLayout] = (
    all_memory_layouts | all_quantized_memory_layouts
)
MemoryLayoutSet = Set[VkMemoryLayout]
# Either one MemoryLayoutSet, or (presumably) one per tensor position.
MemoryLayoutSetList = Union[MemoryLayoutSet, List[MemoryLayoutSet]]
# Packed dimension of each memory layout: 0 = width, 1 = height, 2 = channels,
# as implied by the TENSOR_*_PACKED entries. Must stay consistent with
# PackedDimInfo.from_repr, which gives PACKED_INT8_4W a packed_dim of 0.
_LAYOUT_TO_PACKED_DIM: Dict[VkMemoryLayout, int] = {
    VkMemoryLayout.TENSOR_WIDTH_PACKED: 0,
    VkMemoryLayout.TENSOR_HEIGHT_PACKED: 1,
    VkMemoryLayout.TENSOR_CHANNELS_PACKED: 2,
    VkMemoryLayout.PACKED_INT8_4W4C: 2,
    VkMemoryLayout.PACKED_INT8_4H4W: 0,
    VkMemoryLayout.PACKED_INT8_4W: 0,
    VkMemoryLayout.PACKED_INT8_4C1W: 2,
    VkMemoryLayout.PACKED_INT8_CONV2D: 2,
}


def packed_dim_of(layout: VkMemoryLayout) -> int:
    """
    Return the packed dimension (0 = width, 1 = height, 2 = channels) of
    ``layout``.

    Raises:
        KeyError: if ``layout`` has no entry in ``_LAYOUT_TO_PACKED_DIM``.
    """
    return _LAYOUT_TO_PACKED_DIM[layout]
@dataclass(frozen=True)
class PackedDimInfo:
    """
    Describes how tensor data is organized in physical memory, mirroring the
    C++ PackedDimInfo struct in runtime/api/containers/Tensor.h.

    Attributes:
        packed_dim: the packed dimension (0 = width, 1 = height, 2 = channels).
        packed_dim_block_size: block size along ``packed_dim``. It depends on
            the layout and, for some layouts, on the storage type.
    """

    packed_dim: int
    packed_dim_block_size: int

    @classmethod
    def from_repr(
        cls,
        memory_layout: VkMemoryLayout,
        storage_type: VkStorageType = VkStorageType.BUFFER,
    ) -> "PackedDimInfo":
        """
        Build a PackedDimInfo from a memory layout and storage type, mirroring
        calculate_packed_dim_info in runtime/api/containers/Tensor.cpp.

        Raises:
            ValueError: if ``memory_layout`` is not one of the known layouts.
        """
        uses_buffer = storage_type == VkStorageType.BUFFER
        # (layout, packed_dim, block size for buffers, block size otherwise)
        layout_table = (
            (VkMemoryLayout.TENSOR_WIDTH_PACKED, 0, 1, 4),
            (VkMemoryLayout.TENSOR_HEIGHT_PACKED, 1, 1, 4),
            (VkMemoryLayout.TENSOR_CHANNELS_PACKED, 2, 1, 4),
            (VkMemoryLayout.PACKED_INT8_4W, 0, 4, 4),
            (VkMemoryLayout.PACKED_INT8_4W4C, 2, 4, 4),
            (VkMemoryLayout.PACKED_INT8_4H4W, 0, 4, 4),
            (VkMemoryLayout.PACKED_INT8_4C1W, 2, 4, 16),
            (VkMemoryLayout.PACKED_INT8_CONV2D, 2, 4, 4),
        )
        for layout, dim, buffer_block, other_block in layout_table:
            if memory_layout == layout:
                return cls(
                    packed_dim=dim,
                    packed_dim_block_size=buffer_block if uses_buffer else other_block,
                )
        raise ValueError(f"Unknown memory layout: {memory_layout}")
def within_buffer_limit(node: torch.fx.Node, buffer_limit: int) -> bool:
    """
    Check whether every tensor produced by ``node`` fits in a GPU buffer.

    ``buffer_limit`` is the maximum number of elements a GPU buffer can store.
    Each tensor must have strictly fewer than ``buffer_limit`` elements.

    Returns:
        True if all tensors of the node are under the limit. (The return type
        was previously annotated ``int``; a bool is what is actually returned.)

    Raises:
        AssertionError: if ``node`` is not a tensor node.
        RuntimeError: if the node's ``meta["val"]`` is neither a tensor nor a
            list/tuple of tensors.
    """
    assert is_tensor_node(node)
    val = node.meta["val"]
    if isinstance(val, FakeTensor):
        return val.numel() < buffer_limit
    if isinstance(val, (list, tuple)):
        return all(tensor.numel() < buffer_limit for tensor in val)
    raise RuntimeError(f"Cannot get numel for val of type {type(node.meta['val'])}")
def tensor_node_is_high_dim(node: torch.fx.Node) -> bool:
    """
    Return True if the node's ``meta["val"]`` is a FakeTensor with more than 4
    dimensions, or a list/tuple containing at least one such FakeTensor.
    Raises KeyError if ``meta["val"]`` is missing.
    """
    val = node.meta["val"]
    if isinstance(val, FakeTensor):
        return len(val.shape) > 4
    if isinstance(val, (list, tuple)):
        return any(
            isinstance(item, FakeTensor) and len(item.shape) > 4 for item in val
        )
    return False
def required_image_extents(sizes: torch.Size, layout: VkMemoryLayout) -> ImageExtents:
    """
    Compute the (width, height, depth) image extents needed to represent a
    tensor of ``sizes`` with memory ``layout`` in the Vulkan Delegate.

    The last three dims map to width, height and channels; missing dims count
    as 1. For tensors with at least 4 dims, dim 0 is treated as the batch and
    folded into depth as ``channels * batch``. Each dim that the layout packs
    is divided by 4, rounding up.

    Raises:
        RuntimeError: if ``layout`` is not supported here.
    """
    ndim = len(sizes)
    width = sizes[-1] if ndim >= 1 else 1
    height = sizes[-2] if ndim >= 2 else 1
    channels = sizes[-3] if ndim >= 3 else 1
    batch = sizes[0] if ndim >= 4 else 1

    # For each supported layout: whether (width, height, channels) are packed.
    packed_axes_by_layout = (
        (VkMemoryLayout.TENSOR_WIDTH_PACKED, (True, False, False)),
        (VkMemoryLayout.TENSOR_HEIGHT_PACKED, (False, True, False)),
        (VkMemoryLayout.TENSOR_CHANNELS_PACKED, (False, False, True)),
        (VkMemoryLayout.PACKED_INT8_4W4C, (True, False, True)),
        (VkMemoryLayout.PACKED_INT8_4H4W, (True, True, False)),
        # Use conservative extents (same as 4W4C) since this is buffer-only
        (VkMemoryLayout.PACKED_INT8_CONV2D, (True, False, True)),
    )
    for candidate, (pack_w, pack_h, pack_c) in packed_axes_by_layout:
        if layout == candidate:
            break
    else:
        raise RuntimeError(f"Unsupported memory layout {layout}")

    if pack_w:
        width = (width + 3) // 4
    if pack_h:
        height = (height + 3) // 4
    if pack_c:
        channels = (channels + 3) // 4
    return width, height, channels * batch
def extents_are_valid(extents: ImageExtents, limits: ImageExtents) -> bool:
    """Return True if every extent is at most the corresponding limit."""
    for axis, extent in enumerate(extents):
        if not extent <= limits[axis]:
            return False
    return True
def valid_texture_memory_layouts(
    tensor_sizes: torch.Size, texture_limits: ImageExtents
) -> Set[VkMemoryLayout]:
    """
    Given tensor sizes, determine which non-quantized memory layouts produce
    a texture that fits within the specified device limits.
    """
    return {
        layout
        for layout in all_memory_layouts
        if extents_are_valid(
            required_image_extents(tensor_sizes, layout), texture_limits
        )
    }
class TensorRepr:
    """
    A (VkStorageType, VkMemoryLayout) pair that describes how a tensor should be
    represented in the Vulkan Delegate.
    """

    def __init__(self, storage_type: VkStorageType, memory_layout: VkMemoryLayout):
        self.storage_type = storage_type
        self.memory_layout = memory_layout

    def __str__(self) -> str:
        return f"TensorRepr({self.storage_type}, {self.memory_layout})"

    def __repr__(self) -> str:
        # Same text as __str__, so reprs inside containers and debuggers are readable.
        return self.__str__()

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, TensorRepr):
            return NotImplemented
        return (
            self.storage_type == other.storage_type
            and self.memory_layout == other.memory_layout
        )

    def __ne__(self, other: object) -> bool:
        # Propagate NotImplemented instead of negating it. ``not NotImplemented``
        # is False (and deprecated since Python 3.9), which made comparing
        # against a non-TensorRepr with ``!=`` wrongly report equality.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result
class TensorReprList:
    """
    A list of TensorRepr instances with "broadcasting": when the list holds a
    single TensorRepr, that one entry represents every position, so any
    index > 0 reads or writes element 0.
    """

    def __init__(self, tensor_reprs: Union[TensorRepr, List[TensorRepr]]):
        # A list is stored as-is (not copied); a bare TensorRepr is wrapped.
        self.vals: List[TensorRepr] = (
            tensor_reprs if isinstance(tensor_reprs, list) else [tensor_reprs]
        )

    def __len__(self):
        return len(self.vals)

    def __getitem__(self, idx: int) -> TensorRepr:
        if idx > 0 and len(self) == 1:
            return self.vals[0]
        return self.vals[idx]

    def __setitem__(self, idx: int, val: TensorRepr) -> None:
        if idx > 0 and len(self) == 1:
            self.vals[0] = val
        else:
            self.vals[idx] = val

    def __str__(self) -> str:
        return f"[{', '.join(str(ts) for ts in self.vals)}]"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, TensorReprList):
            return NotImplemented
        return len(self) == len(other) and all(
            self_val == other_val for self_val, other_val in zip(self.vals, other.vals)
        )

    def __ne__(self, other: object) -> bool:
        # Propagate NotImplemented rather than returning ``not NotImplemented``
        # (False), which wrongly reported equality with non-TensorReprList values.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def append(self, val: TensorRepr) -> None:
        self.vals.append(val)

    def storage_type(self, idx: int = 0) -> VkStorageType:
        """Storage type at ``idx``, using the same broadcasting as ``self[idx]``."""
        return self[idx].storage_type

    def memory_layout(self, idx: int = 0) -> VkMemoryLayout:
        """Memory layout at ``idx``, using the same broadcasting as ``self[idx]``."""
        return self[idx].memory_layout
class TensorRepSet:
"""
This class describes the possible set of representations (i.e. TensorRepr) that may
be used to represent a tensor. This set is determined by the implementation of the
operator that the tensor participates in as well as the texture extents of the GPU.
"""
def __init__(
self,
buffer_memory_layouts: Set[VkMemoryLayout],
texture_memory_layouts: Set[VkMemoryLayout],
):
self.valid_buffer_layouts = buffer_memory_layouts
self.valid_texture_layouts = texture_memory_layouts
def __str__(self) -> str:
buffer_layouts = ", ".join(layout.name for layout in self.valid_buffer_layouts)
texture_layouts = ", ".join(
layout.name for layout in self.valid_texture_layouts
)
return f"TensorRepSet(Buffer Layouts: [{buffer_layouts}], Texture Layouts: [{texture_layouts}])"
def __eq__(self, other: object) -> bool:
if not isinstance(other, TensorRepSet):
return NotImplemented
return (
self.valid_buffer_layouts == other.valid_buffer_layouts
and self.valid_texture_layouts == other.valid_texture_layouts
)
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
def copy(self) -> "TensorRepSet":
return TensorRepSet(
set(self.valid_buffer_layouts), set(self.valid_texture_layouts)
)
def is_empty(self) -> bool:
"""
A TensorRepSet is "empty" if there are no valid representations of the tensor.
"""
return (
len(self.valid_buffer_layouts) == 0 and len(self.valid_texture_layouts) == 0
)
def make_intersect(self, other: "TensorRepSet") -> "TensorRepSet":
"""
Merge this TensorRepr with another TensorRepr, returning a new TensorRepr
with the intersection of the two.
"""
return TensorRepSet(
self.valid_buffer_layouts & other.valid_buffer_layouts,
self.valid_texture_layouts & other.valid_texture_layouts,
)
def make_union(self, other: "TensorRepSet") -> "TensorRepSet":
"""
Merge this TensorRepSet with another TensorRepSet, returning a new TensorRepSet
with the union of the two.
"""
return TensorRepSet(
self.valid_buffer_layouts | other.valid_buffer_layouts,
self.valid_texture_layouts | other.valid_texture_layouts,
)
def is_compatible(self, storage: TensorRepr) -> bool:
"""
Check if this TensorRepr is compatible with the given TensorRepSet.
"""
if storage.storage_type == VkStorageType.BUFFER:
return storage.memory_layout in self.valid_buffer_layouts
elif storage.storage_type == VkStorageType.TEXTURE_3D:
return storage.memory_layout in self.valid_texture_layouts
else:
raise RuntimeError(f"Unsupported storage type {storage.storage_type}")
def any_in_common(self, other: "TensorRepSet") -> bool:
"""
Check if this TensorRepr has any representations in common with another
TensorRepr.
"""
return (
len(self.valid_buffer_layouts & other.valid_buffer_layouts) > 0
or len(self.valid_texture_layouts & other.valid_texture_layouts) > 0
)
def texture_is_valid(self):
return len(self.valid_texture_layouts) > 0
def buffer_is_valid(self):
return len(self.valid_buffer_layouts) > 0
def first_valid_buffer_layout(self):
return list(self.valid_buffer_layouts)[0]
def first_valid_texture_layout(self):
return list(self.valid_texture_layouts)[0]
def make_tensor_repr(self) -> TensorRepr:
"""
Pick a representation (i.e. TensorRepr) from the set of possible representations.
If there are multiple valid representations, then:
1. Prefer texture storage over buffer storage
2. Pick the first available memory layout.
"""
if self.is_empty():
# An empty repset typically means that it is associated with a weight tensor
# or non tensor argument. In this case, just return default storage and
# layout as placeholder.
return TensorRepr(
VkStorageType.DEFAULT_STORAGE, VkMemoryLayout.DEFAULT_LAYOUT
)
if self.texture_is_valid():
return TensorRepr(
VkStorageType.TEXTURE_3D, self.first_valid_texture_layout()
)
else:
return TensorRepr(VkStorageType.BUFFER, self.first_valid_buffer_layout())
def is_constrained(self) -> bool:
"""
A "constrained" RepSet is one that has either:
1. A single valid texture memory layout, and no valid buffer memory layouts
2. No valid texture memory layouts, and a single valid buffer memory layout
3. Is empty
In this case, it is unambiguous which representation should be used for the
tensor.
"""
if self.is_empty():
return True
elif (
len(self.valid_texture_layouts) == 1 and len(self.valid_buffer_layouts) == 0