diff --git a/backends/qualcomm/_passes/fuse_consecutive_transpose.py b/backends/qualcomm/_passes/fuse_consecutive_transpose.py
index 16ce3803076..04d96462c9f 100644
--- a/backends/qualcomm/_passes/fuse_consecutive_transpose.py
+++ b/backends/qualcomm/_passes/fuse_consecutive_transpose.py
@@ -55,12 +55,6 @@ def _clone_transpose(
             clone_permute_node.meta = n.meta
             users[i].replace_input_with(n, clone_permute_node)
 
-    def _is_dispensable(self, axis_order):
-        for index, value in enumerate(axis_order):
-            if index != value:
-                return False
-        return True
-
     def _traverse(self, node):
         if node in self.visited or node.target not in self.op_map:
             return
@@ -87,25 +81,22 @@ def _fuse(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
             axis_order = torch.arange(len(input_shape)).tolist()
             for node in self.nodes:
                 axis_order = [axis_order[i] for i in node.args[1]]
-            # If axis order is just [0,1,2,3], we ignore permute node
-            if self._is_dispensable(axis_order):
-                for user in output_node.users.copy():
-                    user.replace_input_with(output_node, n.args[0])
-            else:
-                with graph.inserting_after(input_node):
-                    permute_op = exir_ops.edge.aten.permute_copy.default
-                    permute_node = graph.create_node(
-                        "call_function", permute_op, (input_node, axis_order)
-                    )
-                    users = output_node.users.copy()
-                    for user in users:
-                        user.replace_input_with(output_node, permute_node)
-
-                    # copy metadata
-                    permute_node.meta = output_node.meta
-                    # Without "qnn_permute", we might obtain wrong input shape
-                    if [pn.meta.get(QCOM_INSERTED_PERMUTE) for pn in self.nodes]:
-                        permute_node.meta[QCOM_INSERTED_PERMUTE] = True
+
+            # Keep the [0,1,2,3] permute node to ensure the next node gets the right axis order.
+            with graph.inserting_after(input_node):
+                permute_op = exir_ops.edge.aten.permute_copy.default
+                permute_node = graph.create_node(
+                    "call_function", permute_op, (input_node, axis_order)
+                )
+                users = output_node.users.copy()
+                for user in users:
+                    user.replace_input_with(output_node, permute_node)
+
+                # copy metadata
+                permute_node.meta = output_node.meta
+                # Without "qnn_permute", we might obtain the wrong input shape
+                if any(pn.meta.get(QCOM_INSERTED_PERMUTE) for pn in self.nodes):
+                    permute_node.meta[QCOM_INSERTED_PERMUTE] = True
 
             # clear current stack
             self.nodes = []
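A note on the axis-order folding above: composing consecutive permutes amounts to indexing one axis order by the next, exactly as `axis_order = [axis_order[i] for i in node.args[1]]` does. A minimal standalone sketch of that folding (plain PyTorch; `fold` and the sample orders are illustrative, not part of the pass):

```python
import torch

def fold(first, second):
    # Mirrors the pass: applying `first` then `second` equals one permute
    # with axis order [first[i] for i in second].
    return [first[i] for i in second]

x = torch.randn(2, 3, 4, 5)
first, second = [0, 2, 3, 1], [0, 3, 1, 2]

composed = fold(first, second)  # [0, 1, 2, 3] here, i.e. the identity
assert torch.equal(x.permute(first).permute(second), x.permute(composed))
```

Note that with `_is_dispensable` removed, the pass now materializes the permute even when the folded order is the identity, so consumers that rely on the `QCOM_INSERTED_PERMUTE` marker still see a node.
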
""" - def __init__(self): - super().__init__() + def __init__(self, edge_program: torch.export.ExportedProgram): + super(RecomposeRmsNorm, self).__init__() + self.edge_program = edge_program def _get_eps_node(self, nodes): # eps: one of inputs of add node @@ -47,11 +49,15 @@ def call(self, graph_module: torch.fx.GraphModule): input_node = inp_0 if len(inp_0.users) == 2 else inp_1 else: raise RuntimeError( - f"Found a edge case of rms_node partitoin {src_partition}, which has {input_len} inputs" + f"Found a edge case of rms_node partition {src_partition}, which has {input_len} inputs" ) output_node = src_partition.output_nodes[0] - eps_node = self._get_eps_node(src_partition.nodes) + eps = self._get_eps_node(src_partition.nodes) + if isinstance(eps, torch.fx.Node) and is_parameter( + eps, self.edge_program + ): + eps = get_parameter(eps, self.edge_program).item() gamma_node = self._get_gamma_node(output_node) with graph.inserting_before(output_node): @@ -64,7 +70,7 @@ def call(self, graph_module: torch.fx.GraphModule): input_node, list(gamma_node.meta["val"].shape), gamma_node, - eps_node, + eps, ), ) users = output_node.users.copy() diff --git a/backends/qualcomm/builders/op_rms_norm.py b/backends/qualcomm/builders/op_rms_norm.py index e5b4778312e..d224e34feb5 100644 --- a/backends/qualcomm/builders/op_rms_norm.py +++ b/backends/qualcomm/builders/op_rms_norm.py @@ -12,7 +12,11 @@ import torch from executorch.backends.qualcomm.builders.utils import get_parameter -from executorch.backends.qualcomm.utils.constants import QCOM_DATA, QCOM_QUANT_ATTRS +from executorch.backends.qualcomm.utils.constants import ( + QCOM_DATA, + QCOM_QUANT_ATTRS, + QCOM_ZERO_POINT, +) from executorch.exir.dialects._ops import ops as exir_ops from .node_visitor import NodeVisitor, register_node_visitor @@ -66,7 +70,7 @@ def define_node( nodes_to_wrappers, ) - # Fake node, nn module seems to be inconsistant with document + # Fake node, nn module seems to be inconsistent with document bias_tensor = torch.zeros(weight_tensor.shape) bias_node = torch.fx.Node( node.graph, @@ -78,6 +82,7 @@ def define_node( ) if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS): bias_node.meta[QCOM_QUANT_ATTRS] = quant_attrs + bias_node.meta[QCOM_QUANT_ATTRS][QCOM_ZERO_POINT] = 0 bias_tensor_wrapper = self.define_tensor( bias_node, node, @@ -87,14 +92,6 @@ def define_node( ) epsilon = node.args[3] - if isinstance(epsilon, torch.fx.Node): - epsilon = get_parameter(epsilon, self.edge_program) - epsilon = ( - epsilon - if isinstance(epsilon, float) - else torch.finfo(epsilon.dtype).eps - ) - output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( node, diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py index 0829d99d57a..a999270c15b 100755 --- a/examples/qualcomm/oss_scripts/llama/llama.py +++ b/examples/qualcomm/oss_scripts/llama/llama.py @@ -539,6 +539,28 @@ def compile(args, pte_filename, tokenizer): if "model" in state_dict: state_dict = state_dict["model"] + # Change to HuggingFace weight to improve the performance of RoPE in HTP backend. 
diff --git a/examples/qualcomm/oss_scripts/llama/model/static_llama.py b/examples/qualcomm/oss_scripts/llama/model/static_llama.py
index ea8e2f5d319..f7893792e00 100755
--- a/examples/qualcomm/oss_scripts/llama/model/static_llama.py
+++ b/examples/qualcomm/oss_scripts/llama/model/static_llama.py
@@ -19,12 +19,14 @@ def apply_rotary_emb_single(
     x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
 ) -> torch.Tensor:
-    x_r, x_i = x[..., ::2], x[..., 1::2]
-
-    # brodcast for batch_prefill mode input x
+    # HuggingFace's RoPE implementation processes query and key as two halves instead of in an interleaved way.
+    # The main difference is the stride in the StrideSlice op: the interleaved way uses a stride of two, which is not friendly to the HTP backend.
+    # Ref: https://github.com/huggingface/transformers/issues/25199
+    x_r, x_i = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
+    # broadcast for batch_prefill mode input x
     if x.dim() == 4:
-        freqs_cos = freqs_cos[None, :, None, :]
-        freqs_sin = freqs_sin[None, :, None, :]
+        freqs_cos = freqs_cos[None, None, :, :]
+        freqs_sin = freqs_sin[None, None, :, :]
 
     x_out_r = x_r * freqs_cos - x_i * freqs_sin
     x_out_i = x_r * freqs_sin + x_i * freqs_cos
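To see why the half-split slicing is equivalent: once channels are reordered from interleaved to half-split layout (which the weight permutation in `llama.py` bakes into wq/wk), both variants rotate the same (real, imaginary) pairs. A minimal sketch (shapes and the frequency construction are illustrative):

```python
import torch

seq_len, head_dim = 4, 8  # illustrative shapes
half = head_dim // 2
theta = 1.0 / (10000 ** (torch.arange(half) / half))
pos = torch.arange(seq_len).unsqueeze(-1)
freqs_cos, freqs_sin = torch.cos(pos * theta), torch.sin(pos * theta)

def rope_interleaved(x):
    # Stride-2 slices: pairs (x[2i], x[2i+1]) rotate together.
    x_r, x_i = x[..., 0::2], x[..., 1::2]
    return x_r * freqs_cos - x_i * freqs_sin, x_r * freqs_sin + x_i * freqs_cos

def rope_half(x):
    # Contiguous halves: pairs (x[i], x[i + half]) rotate together.
    x_r, x_i = x[..., :half], x[..., half:]
    return x_r * freqs_cos - x_i * freqs_sin, x_r * freqs_sin + x_i * freqs_cos

x = torch.randn(seq_len, head_dim)
x_half_layout = torch.cat([x[..., 0::2], x[..., 1::2]], dim=-1)
for a, b in zip(rope_interleaved(x), rope_half(x_half_layout)):
    assert torch.equal(a, b)
```

The HTP-side win is that `rope_half` only needs contiguous slices, avoiding the stride-two StrideSlice mentioned in the comment above.
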
@@ -104,25 +106,33 @@ def forward_sha(
         v_caches: Optional[List[torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         bsz, seq_len, _ = hidden_states.shape
+        # In the HTP backend, the input axis order for the convolution operation
+        # is more efficient as [1, 1, seq_len, dim] than as [1, seq_len, 1, dim].
         hidden_states = torch.reshape(
             hidden_states, (bsz, seq_len, 1, self.dim)
         ).transpose(1, 3)
         q = [
-            wq_sha(hidden_states).reshape(bsz, self.head_dim, seq_len).transpose(1, 2)
+            wq_sha(hidden_states)
+            .permute(0, 2, 3, 1)
+            .reshape(bsz, seq_len, self.head_dim)
             for wq_sha in self.wq_sha
         ]
         k = [
-            wk_sha(hidden_states).reshape(bsz, self.head_dim, seq_len).transpose(1, 2)
+            wk_sha(hidden_states)
+            .permute(0, 2, 3, 1)
+            .reshape(bsz, seq_len, self.head_dim)
             for wk_sha in self.wk_sha
        ]
         v = [
-            wv_sha(hidden_states).reshape(bsz, self.head_dim, seq_len).transpose(1, 2)
+            wv_sha(hidden_states)
+            .permute(0, 2, 3, 1)
+            .reshape(bsz, seq_len, self.head_dim)
             for wv_sha in self.wv_sha
         ]
         for i in range(len(q)):
             q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin)
         for i in range(len(k)):
-            k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin).permute(0, 2, 1)
+            k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin).transpose(1, 2)
 
         output_y = []
         kh, vh = [], []
@@ -249,10 +259,10 @@ def prepare_feedfoward_conv(self):
 
     def forward_feedfoward_conv(self, x):
         bsz, _, _ = x.size()
-        x = torch.reshape(x, (bsz, -1, self.dim, 1))
-        x = x.transpose(1, 2)  # Transpose right before and after Conv
+        x = torch.reshape(x, (bsz, -1, 1, self.dim))
+        x = x.transpose(1, 3)  # Transpose right before and after Conv
         x = self.w2_conv(F.silu(self.w1_conv(x)) * self.w3_conv(x))
-        x = x.transpose(1, 2)
+        x = x.transpose(1, 3)
         x = torch.reshape(x, (bsz, -1, self.dim))
         return x
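For completeness, the FFN path relies on a 1x1 convolution being equivalent to a linear layer once activations are laid out as `(bsz, dim, 1, seq_len)`. A minimal sketch of that equivalence under the new layout (the `linear`/`conv` pair here is illustrative, not the module's real `w1_conv`/`w2_conv`/`w3_conv` setup):

```python
import torch
import torch.nn as nn

bsz, seq_len, dim, hidden = 1, 16, 32, 64  # illustrative sizes

linear = nn.Linear(dim, hidden, bias=False)
conv = nn.Conv2d(dim, hidden, kernel_size=1, bias=False)
# A 1x1 conv with the linear weight viewed as (out, in, 1, 1) computes
# the same per-token projection.
conv.weight.data = linear.weight.data.view(hidden, dim, 1, 1)

x = torch.randn(bsz, seq_len, dim)

# Layout used by the updated forward_feedfoward_conv:
# (bsz, seq, dim) -> (bsz, seq, 1, dim) -> transpose(1, 3) -> (bsz, dim, 1, seq)
x_conv = torch.reshape(x, (bsz, -1, 1, dim)).transpose(1, 3)
y = conv(x_conv).transpose(1, 3).reshape(bsz, -1, hidden)

assert torch.allclose(y, linear(x), atol=1e-5)
```

Assuming QNN's NHWC view, this `(bsz, dim, 1, seq_len)` input reads as `[1, 1, seq_len, dim]`, which appears to be the axis order the comment in `forward_sha` recommends.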