diff --git a/.gitmodules b/.gitmodules index e21abf3baea..2897a3c4a50 100644 --- a/.gitmodules +++ b/.gitmodules @@ -59,3 +59,6 @@ path = third-party/lm-evaluation-harness url = https://github.com/EleutherAI/lm-evaluation-harness branch = v0.4.1 +[submodule "examples/third-party/llama.cpp"] + path = examples/third-party/llama.cpp + url = https://github.com/ggerganov/llama.cpp.git diff --git a/examples/models/llama2/playground/CMakeLists.txt b/examples/models/llama2/playground/CMakeLists.txt new file mode 100644 index 00000000000..cc858df0241 --- /dev/null +++ b/examples/models/llama2/playground/CMakeLists.txt @@ -0,0 +1,68 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# +# Simple CMake build system for selective build demo. +# +# ### Editing this file ### +# +# This file should be formatted with +# ~~~ +# cmake-format --first-comment-is-literal=True CMakeLists.txt +# ~~~ +# It should also be cmake-lint clean. +# + +cmake_minimum_required(VERSION 3.19) +project(QuantizedLinearOp) +if(NOT CMAKE_CXX_STANDARD) + set(CMAKE_CXX_STANDARD 17) +endif() + +set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../..) +set(TORCH_ROOT ${EXECUTORCH_ROOT}/third-party/pytorch) + +set(_common_compile_options -Wno-deprecated-declarations -fPIC) + +# Let files say "include ". +set(_common_include_directories ${EXECUTORCH_ROOT}/..) 
+
+find_package(Llama REQUIRED)
+find_package(ExecuTorch REQUIRED)
+find_package(Torch CONFIG REQUIRED)
+
+target_include_directories(executorch INTERFACE ${_common_include_directories})
+
+
+set(kernel_sources ${EXECUTORCH_ROOT}/examples/models/llama2/playground/op_linear.cpp)
+#
+# custom_kernels: C++ kernel implementations of custom ops
+#
+add_library(custom_kernels SHARED ${kernel_sources})
+target_link_libraries(custom_kernels PRIVATE executorch ${LLAMA_LIBRARY})
+target_compile_options(custom_kernels PUBLIC ${_common_compile_options})
+target_include_directories(custom_kernels PRIVATE ${EXECUTORCH_ROOT}/examples/third-party ${TORCH_INCLUDE_DIRS})
+
+
+if(EXECUTORCH_BUILD_GTEST)
+  find_package(
+    gflags REQUIRED PATHS ${CMAKE_CURRENT_BINARY_DIR}/../../../../third-party
+  )
+  #
+  # llama_cpp_test: test binary to run llama.cpp kernel ggml_mul_mat
+  #
+  add_executable(llama_cpp_test ${EXECUTORCH_ROOT}/examples/models/llama2/playground/test_op_linear.cpp)
+
+  target_link_libraries(llama_cpp_test PRIVATE executorch gflags custom_kernels)
+  target_compile_options(llama_cpp_test PUBLIC ${_common_compile_options})
+endif()
+
+# Install libraries
+install(
+  TARGETS custom_kernels
+  DESTINATION lib
+  INCLUDES
+  DESTINATION ${_common_include_directories})
diff --git a/examples/models/llama2/playground/op_linear.cpp b/examples/models/llama2/playground/op_linear.cpp
new file mode 100644
index 00000000000..e1b7ccc403c
--- /dev/null
+++ b/examples/models/llama2/playground/op_linear.cpp
@@ -0,0 +1,324 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+#include
+#include
+#include
+#include
+ #define restrict __restrict__
+ extern "C"
+{
+ #include
+ #include
+ #include
+}
+
+#undef MAX
+#undef MIN
+
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define MAX(a, b) ((a) > (b) ?
(a) : (b))
+
+namespace llama_cpp {
+namespace native {
+
+using Tensor = exec_aten::Tensor;
+using RuntimeContext = exec_aten::RuntimeContext;
+using Error = torch::executor::Error;
+
+static void ggml_compute_forward_mul_mat(
+    const void * src0,
+    int64_t * ne0s,
+    size_t * nb0s,
+    const float * src1,
+    int64_t * ne1s,
+    size_t * nb1s,
+    float * dst,
+    int64_t * nes,
+    size_t * nbs) {
+  // Takes a q4_0 weight (src0) and a float activation (src1)
+
+  // src0 dim, this is the weight
+  int64_t ne00 = ne0s[0];
+  int64_t ne01 = ne0s[1];
+  int64_t ne02 = ne0s[2];
+  int64_t ne03 = ne0s[3];
+
+  size_t nb00 = nb0s[0];
+  size_t nb01 = nb0s[1];
+  size_t nb02 = nb0s[2];
+  size_t nb03 = nb0s[3];
+
+  // src1 dim, this is the activation
+  int64_t ne10 = ne1s[0];
+  int64_t ne11 = ne1s[1];
+  int64_t ne12 = ne1s[2];
+  int64_t ne13 = ne1s[3];
+  // FIX(review): src1 strides were copy/pasted from nb0s; they must come from nb1s.
+  size_t nb10 = nb1s[0];
+  size_t nb11 = nb1s[1];
+  size_t nb12 = nb1s[2];
+  size_t nb13 = nb1s[3];
+  // dst dim
+  int64_t ne0 = nes[0];
+  int64_t ne1 = nes[1];
+  int64_t ne2 = nes[2];
+  int64_t ne3 = nes[3];
+
+  size_t nb0 = nbs[0];
+  size_t nb1 = nbs[1];
+  size_t nb2 = nbs[2];
+  size_t nb3 = nbs[3];
+
+  // single thread
+  const int ith = 0;
+  const int nth = 1;
+
+  // const enum ggml_type type = src0->type;
+
+  // const bool src1_cont = ggml_is_contiguous(src1);
+
+  GGML_ASSERT(ne0 == ne01);
+  GGML_ASSERT(ne1 == ne11);
+  GGML_ASSERT(ne2 == ne12);
+  GGML_ASSERT(ne3 == ne13);
+
+  // we don't support permuted src0 or src1
+  // GGML_ASSERT(nb00 == ggml_type_size(type));
+  // GGML_ASSERT(nb10 == ggml_type_size(src1->type));
+
+  // dst cannot be transposed or permuted
+  GGML_ASSERT(nb0 == sizeof(float));
+  GGML_ASSERT(nb0 <= nb1);
+  GGML_ASSERT(nb1 <= nb2);
+  GGML_ASSERT(nb2 <= nb3);
+
+  // broadcast factors
+  const int64_t r2 = ne12/ne02;
+  const int64_t r3 = ne13/ne03;
+
+  // nb01 >= nb00 - src0 is not transposed
+  // compute by src0 rows
+
+  // quantize activation
+  const size_t row_size = ggml_row_size(GGML_TYPE_Q8_0, ne10);
+  char * buffer
= (char *) malloc(ne11*ne12*ne13*row_size); + char * wdata_itr = buffer; + + for (int64_t i13 = 0; i13 < ne13; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + quantize_row_q8_0((float *)((char *) src1 + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata_itr, ne10); + wdata_itr += row_size; + } + } + } + + + const char * wdata = buffer; + + const int64_t nr0 = ne01; // src0 rows + const int64_t nr1 = ne1*ne12*ne13; // src1 rows + + //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1); + + // distribute the thread work across the inner or outer loop based on which one is larger + + const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows + const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows + + const int64_t ith0 = ith % nth0; + const int64_t ith1 = ith / nth0; + + const int64_t dr0 = (nr0 + nth0 - 1)/nth0; + const int64_t dr1 = (nr1 + nth1 - 1)/nth1; + + const int64_t ir010 = dr0*ith0; + const int64_t ir011 = MIN(ir010 + dr0, nr0); + + const int64_t ir110 = dr1*ith1; + const int64_t ir111 = MIN(ir110 + dr1, nr1); + + //printf("ir010 = %6lld, ir011 = %6lld, ir110 = %6lld, ir111 = %6lld\n", ir010, ir011, ir110, ir111); + + // threads with no work simply yield (not sure if it helps) + // if (ir010 >= ir011 || ir110 >= ir111) { + // sched_yield(); + // return; + // } + + assert(ne12 % ne02 == 0); + assert(ne13 % ne03 == 0); + + // block-tiling attempt + const int64_t blck_0 = 16; + const int64_t blck_1 = 16; + + // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols + int64_t nrc = 1; + // TODO: currently the mmla kernels support only even numbered rows/cols. 
+ // this check can be removed once they are extended to support odd numbered rows/cols too + if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) { + nrc = 1; + } + + const size_t src1_col_stride = row_size; + + // attempt to reduce false-sharing (does not seem to make a difference) + // 16 * 2, accounting for mmla kernels + float tmp[32]; + + for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) { + for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) { + for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ir1 += nrc) { + const int64_t i13 = (ir1/(ne12*ne1)); + const int64_t i12 = (ir1 - i13*ne12*ne1)/ne1; + const int64_t i11 = (ir1 - i13*ne12*ne1 - i12*ne1); + + // broadcast src0 into src1 + const int64_t i03 = i13/r3; + const int64_t i02 = i12/r2; + + const int64_t i1 = i11; + const int64_t i2 = i12; + const int64_t i3 = i13; + + const char * src0_row = (const char *) src0 + (0 + i02*nb02 + i03*nb03); + + // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides + // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using + // the original src1 data pointer, so we should index using the indices directly + // TODO: this is a bit of a hack, we should probably have a better way to handle this + const char * src1_col = (const char *) wdata + (i11+i12*ne11 + i13*ne12*ne11)*row_size; + float * dst_col = (float *) ((char *) dst + (i1*nb1 + i2*nb2 + i3*nb3)); + + //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { + // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col); + //} + + for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ir0 += nrc) { + ggml_vec_dot_q4_0_q8_0(ne00, &tmp[ir0 - iir0], (nrc>1 ? 16 : 0), src0_row + ir0*nb01, (nrc>1 ? nb01 : 0), src1_col, (nrc>1 ? 
src1_col_stride : 0), nrc); + } + + for (int cn = 0; cn < nrc; ++cn) { + memcpy(&dst_col[iir0 + cn*nb1/nb0], tmp + (cn*16), (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float)); + } + } + } + } + free(buffer); +} + +// Helper function to create a ggml q4_0 tensor with preallocated memory +static void * pack_q4_0(const Tensor & t, const Tensor & scale) { + int n_dims = t.dim(); + ET_CHECK_MSG(n_dims >= 1 && n_dims <= GGML_MAX_DIMS, "dimension %d is not within range (1, %d)", n_dims, GGML_MAX_DIMS); + + enum ggml_type type = GGML_TYPE_Q4_0; + + // TODO use memory from context to create tensor + struct ggml_tensor * const result = (struct ggml_tensor *) malloc(sizeof (struct ggml_tensor)); + ET_CHECK_MSG(t.scalar_type() == exec_aten::ScalarType::Byte, "Expected t to be Byte tensor but got %hdd", t.scalar_type()); + ET_CHECK_MSG(scale.scalar_type() == exec_aten::ScalarType::Float, "Expected scale to be Float tensor but got %hdd", scale.scalar_type()); + + // prepare a temp buffer to store the packed quantized values. Each block_q4_0 contains half of the group size (32 / 2 = 16) of uint8_t values and a fp16 scale value. + ET_CHECK_MSG(t.numel() % QK4_0 == 0, "Expecting numel to be multiple of %d but got %zu", QK4_0, t.numel()); + static const int qk = QK4_0; + + size_t group_num = t.numel() / qk; + block_q4_0 buf[group_num]; + int8_t* data = t.mutable_data_ptr(); + float* scales = scale.mutable_data_ptr(); + + // data here is int8 unpacked quantized values, need to convert to packed int4 format + for (size_t i = 0; i < group_num; ++i) { + int8_t* group_start = data + i * qk; + int8_t* group_end = data + (i+1) * qk; + block_q4_0* block = buf + i; + + block->d = GGML_FP32_TO_FP16(scales[i]); + for (int j = 0; j < QK4_0/2; ++j) { + block->qs[j] = group_start[j]; + block->qs[j] |= group_start[qk/2 + j] << 4; + } + } + + // memcopy the packed data into a new data from heap. 
This is safe because sizeof(block_q4_0) * group_num is smaller than t.numel() + void * dest = malloc(sizeof(block_q4_0) * group_num); + memcpy(dest, buf, sizeof(block_q4_0) * group_num); + return dest; +} + +Tensor& +linear_q4_0_out(const Tensor& weights, const Tensor& scale, const Tensor& activation, Tensor& out) { + // weights are int4 quantized values stored in int8 tensors, i.e., first 4 bits are 0. + // scale contains scales for groupwise (32) quantized values, numel = weights.numel() / 32 + // activation and out are float32 tensor + void * weights_packed = pack_q4_0(weights, scale); + int64_t weights_sizes[4]; + for (int i = 0; i < 4; i++) { + weights_sizes[i] = weights.size(i); + } + size_t weights_byte_sizes[4]; // strides * sizeof(block_q4_0) + weights_byte_sizes[0] = sizeof(block_q4_0); + for (int i = 1; i < 4; i++) { + weights_byte_sizes[i] = weights.size(i-1) / QK4_0 * weights_byte_sizes[i-1]; + } + // activation + const float * input = activation.const_data_ptr(); + int64_t input_sizes[4]; + for (int i = 0; i < 4; i++) { + input_sizes[i] = activation.size(i); + } + size_t input_byte_sizes[4]; + input_byte_sizes[0] = sizeof(float); + for (int i = 1; i < 4; i++) { + input_byte_sizes[i] = activation.size(i-1) * input_byte_sizes[i-1]; + } + // out + float * out_data = out.mutable_data_ptr(); + int64_t out_sizes[4]; + for (int i = 0; i < 4; i++) { + out_sizes[i] = out.size(i); + } + size_t out_byte_sizes[4]; + out_byte_sizes[0] = sizeof(float); + for (int i = 1; i < 4; i++) { + out_byte_sizes[i] = out.size(i-1) * out_byte_sizes[i-1]; + } + + ggml_compute_forward_mul_mat(weights_packed, weights_sizes, weights_byte_sizes, input, input_sizes, input_byte_sizes, out_data, out_sizes, out_byte_sizes); + + free(weights_packed); + return out; +} + +Tensor& +linear_q4_0_out_with_context(RuntimeContext& context, const Tensor& weights, const Tensor& scale, const Tensor& input, Tensor& out) { + (void)context; + return linear_q4_0_out(weights, scale, input, out); +} + + 
+EXECUTORCH_LIBRARY(ggml, "linear_q4_0.out", linear_q4_0_out_with_context); + +at::Tensor linear_q4_0(const at::Tensor& weight, const at::Tensor& scale, const at::Tensor& input) { + auto output = at::empty({input.size(0), weight.size(0)}, input.options().dtype(at::kHalf)); + WRAP_TO_ATEN(linear_q4_0_out, 3)(weight, scale, input, output); + return output; +} + + +TORCH_LIBRARY(ggml, m) { + m.def("linear_q4_0(Tensor weight, Tensor scale, Tensor input) -> Tensor", linear_q4_0); + m.def("linear_q4_0.out(Tensor weight, Tensor scale, Tensor input, *, Tensor(a!) out) -> Tensor(a!)", WRAP_TO_ATEN(linear_q4_0_out, 3)); +} + +} // namespace native +} // namespace llama_cpp diff --git a/examples/models/llama2/playground/quantize.py b/examples/models/llama2/playground/quantize.py new file mode 100644 index 00000000000..b9cc425a12a --- /dev/null +++ b/examples/models/llama2/playground/quantize.py @@ -0,0 +1,277 @@ + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from .ops.quantized_ops import * # noqa +from ..quantize import _check_linear_int4_k, find_multiple # noqa + + +def get_group_qparams_symmetric(w, n_bit=4, groupsize=32, precision=torch.float32): + # GGML Q4_0 quantization. 
+ if groupsize > w.shape[-1]: + groupsize = w.shape[-1] + assert groupsize > 1 + assert w.shape[-1] % groupsize == 0 + assert w.dim() == 2 + + to_quant = w.reshape(-1, groupsize) + assert torch.isnan(to_quant).sum() == 0 + + max_val = to_quant.amax(dim=1, keepdim=True) + min_val = to_quant.amin(dim=1, keepdim=True) + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + + max_val_abs = torch.max(-min_val_neg, max_val_pos) + max_int = 2 ** (n_bit - 1) - 1 + min_int = -(2 ** (n_bit - 1)) + + scales = max_val_abs / (float(min_int - max_int) / 2) # for 4 bit this is max / -8 + scales = torch.min(scales, torch.full_like(scales, -torch.finfo(precision).eps)) # scale can't be larger than -eps + # TODO: make sure abs(scales) is not too small? + zeros = torch.full_like(scales, 8.5) + return scales.to(precision).reshape(w.shape[0], -1), zeros.to(precision).reshape( + w.shape[0], -1 + ) + + +def group_quantize_tensor_symmetric( + w, n_bit=4, group_size=32, precision=torch.float32 +): + scales, zeros = get_group_qparams_symmetric(w, n_bit, group_size, precision) + n_bit = 4 + max_int = 2 ** (n_bit - 1) - 1 + min_int = -(2 ** (n_bit - 1)) + # TODO: currently we don't know how to express torch.int4, we'll + # add torch.int4 to core later + w_int8 = torch.ops.quantized_decomposed.quantize_per_channel_group( + w, scales, zeros, min_int, max_int, torch.int8, group_size + ) + + return w_int8, scales, zeros + +def prepare_int4_weight_and_scales_and_zeros(weight, group_size, precision): + """ + llama.cpp Q4_0 quantization scheme. Symmetric groupwise 4bit quant with group + size 32 and zero point being fixed to 8.5. 
+ """ + weight_int8, scales, zeros = group_quantize_tensor_symmetric( + weight, + n_bit=4, + group_size=group_size, + precision=precision, + ) + # weight_int4packed = torch.ops.quantized_decomposed.pack_int4_from_int8(weight_int8) + return weight_int8, scales, zeros + + + +def replace_linear_4w( + module, + group_size, + padding_allowed, + precision, + scales_precision, +): + for name, child in module.named_children(): + if isinstance(child, nn.Linear): + if _check_linear_int4_k(child.in_features, group_size) or padding_allowed: + setattr( + module, + name, + Int8DynActInt4WeightLinear( + child.in_features, + child.out_features, + bias=False, + group_size=group_size, + precision=precision, + scales_precision=scales_precision, + ), + ) + else: + replace_linear_4w( + child, + group_size, + padding_allowed, + precision, + scales_precision, + ) + + +class Int8DynActInt4WeightQuantHandler: + def __init__( + self, + mod, + group_size=32, + padding_allowed=False, + precision=torch.float32, + scales_precision=torch.float32, + ): + self.mod = mod + self.group_size = group_size + self.padding_allowed = padding_allowed + self.precision = precision + self.scales_precision = scales_precision + # assert group_size in [32, 64, 128, 256] + + @torch.no_grad() + def create_quantized_state_dict(self): + cur_state_dict = self.mod.state_dict() + for fqn, mod in self.mod.named_modules(): + if isinstance(mod, torch.nn.Linear): + assert not mod.bias + out_features = mod.out_features + in_features = mod.in_features + print("in features:", in_features, " out features:", out_features) + # assert out_features % 8 == 0, "require out_features % 8 == 0" + print(f"linear: {fqn}, in={in_features}, out={out_features}") + + assert ( + in_features % self.group_size == 0 + ), f"require in_features:{in_features} % self.group_size:{self.group_size} == 0" + + weight = mod.weight.data + """ + if not _check_linear_int4_k( + in_features, self.group_size + ): + if self.padding_allowed: + print( + f"warning: 
{fqn} is padded to satisfy in_features % 1024 == 0" + ) + padded_in_features = _calc_padded_size_linear_int4( + in_features, self.group_size + ) + weight = F.pad( + weight, pad=(0, padded_in_features - in_features) + ) + else: + raise RuntimeError( + f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " + + "and that group_size" + ) + """ + ( + weight_int4pack, + scales, + zeros, + ) = prepare_int4_weight_and_scales_and_zeros( + weight.to(self.precision), + self.group_size, + self.scales_precision, + ) + cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu") + cur_state_dict[f"{fqn}.scales"] = scales.to("cpu") + cur_state_dict[f"{fqn}.zeros"] = zeros.to("cpu") + + return cur_state_dict + + def convert_for_runtime(self): + replace_linear_4w( + self.mod, + self.group_size, + self.padding_allowed, + self.precision, + self.scales_precision, + ) + return self.mod + + def quantized_model(self) -> nn.Module: + model_updated_state_dict = self.create_quantized_state_dict() + self.convert_for_runtime() + self.mod.load_state_dict(model_updated_state_dict) + return self.mod + + +class Int4WeightLinear(torch.nn.Module): + __constants__ = ["in_features", "out_features"] + + in_features: int + out_features: int + weight: torch.Tensor + + """ + This module implements a dynamic quantized linear layer with int4 weight. + Weights are per channel groupwise quantized. Activations will be quantized + into int8 in custom op. + + Parameters of importance: + + group_size: the number of elements in each quantized group + precision: precision of input and output. e.g. torch.float32 means input + activation is float32 and output is float32. + scales_precision: precision of per group scale. 
+ """ + + def __init__( + self, + in_features: int, + out_features: int, + bias=True, + device=None, + dtype=None, + group_size: int = 32, + precision: torch.dtype = torch.float32, + scales_precision: torch.dtype = torch.float32, + ) -> None: + super().__init__() + # always pad if needed since it becomes a noop at runtime if not needed + # self.origin_in_features = in_features + assert ( + in_features % group_size == 0 + ), f"require in_features:{in_features} % group_size:{group_size} == 0" + # in_features = _calc_padded_size_linear_int4( + # in_features, group_size + # ) + self.in_features = in_features + self.out_features = out_features + assert not bias, "require bias=False" + self.group_size = group_size + # Precision of the activation which also indicates + # output precision of the dynamically quantized linear layer + # that his module represents. + self.precision = precision + + # currently storing unpacked int8 weights + self.register_buffer( + "weight", + torch.empty((out_features, in_features), dtype=torch.int8), + ) + self.register_buffer( + "scales", + torch.empty( + (out_features, in_features // group_size), + dtype=scales_precision, + ), + ) + self.register_buffer( + "zeros", + torch.empty( + (out_features, in_features // group_size), + dtype=scales_precision, + ), + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + input = input.to(self.precision) + # Change this to pad if needed later + # else this op will always show up + # input = F.pad(input, pad=(0, self.in_features - self.origin_in_features)) + + """ + TODO: add a custom op here that takes quantized weights (int4 but unpacked into int8) + and fp32 activation, return fp32 result. Inside the op will convert activation to + int8, weights to int4 and perform dot product on int4 and int8. 
+
+        ggml::linear_q4_0(Tensor weights, Tensor scale, Tensor activation) -> Tensor
+
+        """
+        return torch.ops.ggml.linear_q4_0(
+            self.weight, self.scales, input
+        ).to(self.precision)
diff --git a/examples/models/llama2/playground/test.sh b/examples/models/llama2/playground/test.sh
new file mode 100644
index 00000000000..0fe9a97fe16
--- /dev/null
+++ b/examples/models/llama2/playground/test.sh
@@ -0,0 +1,54 @@
+#!/bin/bash
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+set -exu
+
+if [[ -z "${BUCK:-}" ]]; then
+  BUCK=buck2
+fi
+
+if [[ -z "${PYTHON_EXECUTABLE:-}" ]]; then
+  PYTHON_EXECUTABLE=python3
+fi
+
+cmake_install_ggml() {
+  cmake \
+    -DCMAKE_BUILD_TYPE=Debug \
+    -DCMAKE_INSTALL_PREFIX=cmake-out \
+    -Bcmake-out/examples/third-party/llama.cpp \
+    examples/third-party/llama.cpp
+
+  cmake --build cmake-out/examples/third-party/llama.cpp -j9 --config Debug --target install
+}
+
+cmake_install_executorch() {
+  cmake \
+    -DBUCK2="${BUCK}" \
+    -DCMAKE_BUILD_TYPE=Debug \
+    -DCMAKE_INSTALL_PREFIX=cmake-out \
+    -DCMAKE_PREFIX_PATH=$("${PYTHON_EXECUTABLE}" -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())') \
+    -DPYTHON_EXECUTABLE="${PYTHON_EXECUTABLE}" \
+    -DEXECUTORCH_BUILD_PYBIND=ON \
+    -Bcmake-out .
+ + cmake --build cmake-out -j9 --config Debug --target install +} + +cmake_install_custom_op() { + cmake \ + -DBUCK2=BUCK \ + -DCMAKE_BUILD_TYPE=Debug \ + -DCMAKE_INSTALL_PREFIX=cmake-out \ + -Bcmake-out/examples/models/llama2/playground \ + examples/models/llama2/playground + + cmake --build cmake-out/examples/models/llama2/playground -j9 --config Debug --target install +} + +cmake_install_ggml +cmake_install_executorch +cmake_install_custom_op diff --git a/examples/models/llama2/playground/test_op_linear.cpp b/examples/models/llama2/playground/test_op_linear.cpp new file mode 100644 index 00000000000..bf6b60fe63c --- /dev/null +++ b/examples/models/llama2/playground/test_op_linear.cpp @@ -0,0 +1,105 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace executor { + +using namespace ::testing; + +Tensor& my_op_out(const Tensor& a, Tensor& out) { + (void)a; + return out; +} + +Tensor& add_1_out(const Tensor& a, Tensor& out) { + (void)a; + out.mutable_data_ptr()[0] += 1; + return out; +} + +Tensor& quantized_embedding_byte_out( + const Tensor& weight, + const Tensor& weight_scales, + const Tensor& weight_zero_points, + int64_t weight_quant_min, + int64_t weight_quant_max, + const Tensor& indices, + Tensor& out) { + (void)weight; + (void)weight_scales; + (void)weight_zero_points; + (void)weight_quant_min; + (void)indices; + out.mutable_data_ptr()[0] -= static_cast(weight_quant_max); + return out; +} + +class MakeATenFunctorFromETFunctorTest : public ::testing::Test { + public: + void SetUp() override { + torch::executor::runtime_init(); + } +}; + +TEST_F(MakeATenFunctorFromETFunctorTest, Basic) { + auto function = WRAP_TO_ATEN(my_op_out, 1); + at::Tensor a = torch::tensor({1.0f}); + at::Tensor b = 
torch::tensor({2.0f}); + at::Tensor c = function(a, b); + EXPECT_EQ(c.const_data_ptr()[0], 2.0f); +} + +TORCH_LIBRARY(my_op, m) { + m.def("add_1.out", WRAP_TO_ATEN(add_1_out, 1)); + m.def( + "embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)", + WRAP_TO_ATEN(quantized_embedding_byte_out, 6)); +}; + +TEST_F(MakeATenFunctorFromETFunctorTest, RegisterWrappedFunction) { + auto op = c10::Dispatcher::singleton().findSchema({"my_op::add_1", "out"}); + EXPECT_TRUE(op.has_value()); + at::Tensor a = + torch::tensor({1}, torch::TensorOptions().dtype(torch::kInt32)); + at::Tensor b = + torch::tensor({2}, torch::TensorOptions().dtype(torch::kInt32)); + torch::jit::Stack stack = {a, b}; + op.value().callBoxed(&stack); + EXPECT_EQ(stack.size(), 1); + EXPECT_EQ(stack[0].toTensor().const_data_ptr()[0], 3); +} + +TEST_F(MakeATenFunctorFromETFunctorTest, TestEmbeddingByte) { + auto op = + c10::Dispatcher::singleton().findSchema({"my_op::embedding_byte", "out"}); + EXPECT_TRUE(op.has_value()); + at::Tensor weight = + torch::tensor({1}, torch::TensorOptions().dtype(torch::kInt32)); + at::Tensor scale = + torch::tensor({2}, torch::TensorOptions().dtype(torch::kInt32)); + at::Tensor zero_point = + torch::tensor({2}, torch::TensorOptions().dtype(torch::kInt32)); + at::Tensor indices = + torch::tensor({2}, torch::TensorOptions().dtype(torch::kInt32)); + at::Tensor out = + torch::tensor({4}, torch::TensorOptions().dtype(torch::kInt32)); + torch::jit::Stack stack = {weight, scale, zero_point, 0, 1, indices, out}; + op.value().callBoxed(&stack); + EXPECT_EQ(stack.size(), 1); + EXPECT_EQ(stack[0].toTensor().const_data_ptr()[0], 3); +} + +} // namespace executor +} // namespace torch diff --git a/examples/third-party/TARGETS b/examples/third-party/TARGETS new file mode 100644 index 00000000000..da79620f185 --- /dev/null +++ b/examples/third-party/TARGETS @@ -0,0 +1,18 @@ 
+load("@fbsource//xplat/executorch/build/runtime_wrapper.bzl", "runtime") + +runtime.cxx_library( + name = "ggml", + headers = glob([ + "llama.cpp/*.h", + ]), + srcs = [ + "llama.cpp/ggml.c", + "llama.cpp/ggml-alloc.c", + "llama.cpp/ggml-backend.c", + "llama.cpp/ggml-quants.c", + ], + _is_external_target = True, + exported_external_deps = [ + ("glibc", None, "pthread"), + ], +) diff --git a/examples/third-party/llama.cpp b/examples/third-party/llama.cpp new file mode 160000 index 00000000000..6cdabe65269 --- /dev/null +++ b/examples/third-party/llama.cpp @@ -0,0 +1 @@ +Subproject commit 6cdabe652695167263c8b447520987b11856f7ca diff --git a/extension/aten_util/make_aten_functor_from_et_functor.h b/extension/aten_util/make_aten_functor_from_et_functor.h index 976099f88fa..288fb6d156c 100644 --- a/extension/aten_util/make_aten_functor_from_et_functor.h +++ b/extension/aten_util/make_aten_functor_from_et_functor.h @@ -18,7 +18,7 @@ #error "This header requires C++17" #endif #include -#include +#include #include #include #include @@ -164,15 +164,15 @@ struct wrapper_impl { } }; +} // namespace executor +} // namespace torch + // Wrapper macro for out variant function. N is the index of the out tensor. // We need N to know how to preserve the semantics of modifying out tensor and // return the reference without allocating a new memory buffer for out tensor. #define _WRAP_2(func, N) \ - wrapper_impl::wrap -#define _WRAP_1(func) wrapper_impl::wrap + torch::executor::wrapper_impl::wrap +#define _WRAP_1(func) torch::executor::wrapper_impl::wrap #define GET_MACRO(_1, _2, NAME, ...) NAME #define WRAP_TO_ATEN(...) 
GET_MACRO(__VA_ARGS__, _WRAP_2, _WRAP_1)(__VA_ARGS__) - -} // namespace executor -} // namespace torch diff --git a/extension/kernel_util/make_boxed_from_unboxed_functor.h b/extension/kernel_util/make_boxed_from_unboxed_functor.h index fa69ed944a7..732bf80e727 100644 --- a/extension/kernel_util/make_boxed_from_unboxed_functor.h +++ b/extension/kernel_util/make_boxed_from_unboxed_functor.h @@ -138,8 +138,9 @@ static Kernel make_boxed_kernel(const char* name, FuncType) { return Kernel(name, WrapUnboxedIntoFunctor::call); } -#define EXECUTORCH_LIBRARY(ns, op_name, func) \ - static auto res_##ns = register_kernels( \ - make_boxed_kernel(#ns "::" op_name, EXECUTORCH_FN(func))) } // namespace executor } // namespace torch + +#define EXECUTORCH_LIBRARY(ns, op_name, func) \ + static auto res_##ns = torch::executor::register_kernels( \ + torch::executor::make_boxed_kernel(#ns "::" op_name, EXECUTORCH_FN(func))) diff --git a/extension/kernel_util/meta_programming.h b/extension/kernel_util/meta_programming.h index 46262b843ea..7b51d1721b9 100644 --- a/extension/kernel_util/meta_programming.h +++ b/extension/kernel_util/meta_programming.h @@ -48,11 +48,6 @@ template struct is_compile_time_function_pointer< CompileTimeFunctionPointer> : std::true_type {}; -#define EXECUTORCH_FN_TYPE(func) \ - CompileTimeFunctionPointer< \ - std::remove_pointer_t>, \ - func> -#define EXECUTORCH_FN(func) EXECUTORCH_FN_TYPE(func)() /** * strip_class: helper to remove the class type from pointers to `operator()`. 
@@ -113,3 +108,9 @@ using infer_function_traits_t = typename infer_function_traits::type; } // namespace executor } // namespace torch + +#define EXECUTORCH_FN_TYPE(func) \ + torch::executor::CompileTimeFunctionPointer< \ + std::remove_pointer_t>, \ + func> +#define EXECUTORCH_FN(func) EXECUTORCH_FN_TYPE(func)() diff --git a/sdk/CMakeLists.txt b/sdk/CMakeLists.txt index 05df3ba6ded..e4db8ffb728 100644 --- a/sdk/CMakeLists.txt +++ b/sdk/CMakeLists.txt @@ -52,7 +52,7 @@ endforeach() add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../third-party/flatcc ${CMAKE_BINARY_DIR}/third-party/flatcc) - +target_compile_options(flatcc PUBLIC -fPIC) # Assume we are cross-compiling and the CMAKE_TOOLCHAIN_FILE is set include(ExternalProject)