diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000000..7242dd283c --- /dev/null +++ b/.editorconfig @@ -0,0 +1,22 @@ +root = true + +# Unix-style newlines with a newline ending every file +[*] +end_of_line = lf +insert_final_newline = true + +[{CMakeLists.txt,*.cmake}] +indent_style = space +indent_size = 2 + +[*.{cc,h,cu,cpp}] +indent_style = space +indent_size = 2 + +[*.py] +indent_style = space +indent_size = 4 + +[*.toml] +indent_style = space +indent_size = 2 diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000..efec9cf353 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +*.dtg.cc linguist-generated=true +*.dtg.h linguist-generated=true diff --git a/.github/workflows/per-lib-check.yml b/.github/workflows/per-lib-check.yml index 8b15d68960..c554c54dc4 100644 --- a/.github/workflows/per-lib-check.yml +++ b/.github/workflows/per-lib-check.yml @@ -72,13 +72,13 @@ jobs: run: | build_libs.sh kernels - - name: Build substitutions - run: | - build_libs.sh substitutions + # - name: Build substitutions + # run: | + # build_libs.sh substitutions - - name: Build compiler - run: | - build_libs.sh compiler + # - name: Build compiler + # run: | + # build_libs.sh compiler - name: Build substitution-generator run: | @@ -88,13 +88,21 @@ jobs: run: | test_libs.sh utils - - name: Test substitutions + - name: Test op-attrs run: | - test_libs.sh substitutions + test_libs.sh op-attrs - - name: Test compiler + - name: Test pcg run: | - test_libs.sh compiler + test_libs.sh pcg + + # - name: Test substitutions + # run: | + # test_libs.sh substitutions + + # - name: Test compiler + # run: | + # test_libs.sh compiler - name: Test substitution-generator run: | @@ -102,16 +110,19 @@ jobs: - name: Generate code coverage run: | + echo "gitwork: $GITHUB_WORKSPACE" lcov --capture --directory . --output-file main_coverage.info - lcov --remove main_coverage.info '/nix/store/' --output-file main_coverage.info - lcov --list main_coverage.info + lcov --extract main_coverage.info "$GITHUB_WORKSPACE/lib/*" --output-file main_coverage_e.info + lcov --remove main_coverage_e.info "$GITHUB_WORKSPACE/lib/*.dtg.h" "$GITHUB_WORKSPACE/lib/*.dtg.cc" --output-file main_coverage_e_f.info + lcov --list main_coverage_e_f.info - name: Upload code coverage uses: codecov/codecov-action@v4 with: token: ${{ secrets.CODECOV_TOKEN }} - files: main_coverage.info + file: main_coverage_e_f.info flags: unittests + plugin: pycoverage #hope this will disable gcov name: codecov-umbrella fail_ci_if_error: false verbose: true diff --git a/.proj.toml b/.proj.toml index a4592dcccc..b076671498 100644 --- a/.proj.toml +++ b/.proj.toml @@ -7,13 +7,18 @@ build_targets = [ "utils", "op-attrs", "kernels", - "substitutions", - "compiler", + "pcg", + # "substitutions", + # "compiler", + "substitution-generator", ] test_targets = [ "utils-tests", - "substitutions-tests", - "compiler-tests", + "op-attrs-tests", + "pcg-tests", + # "substitutions-tests", + # "compiler-tests", + "substitution-generator-tests", ] [cmake_flags_extra] diff --git a/flake.lock b/flake.lock index 6562cc1c78..f0fc292a5e 100644 --- a/flake.lock +++ b/flake.lock @@ -43,11 +43,11 @@ ] }, "locked": { - "lastModified": 1717041372, - "narHash": "sha256-YW0fHKoMxpI9Bmyk4/aAz6VkCPgDnAAO34m4Vp7DLiY=", + "lastModified": 1717449667, + "narHash": "sha256-xFGnB44WadxlCa2LnlH82g1c89+7UAomVgytIewSwO0=", "owner": "lockshaw", "repo": "proj", - "rev": "ae83ac17bebe6eae4594f8e0a8b5869529693eea", + "rev": "28b37a9bd993d3de3d80695eb3834a0436c805a4", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index 70146a750c..2dc005b113 100644 --- a/flake.nix +++ b/flake.nix @@ -112,6 +112,10 @@ inputsFrom = [ ci ]; inherit (ci) CMAKE_FLAGS; + VIMPLUGINS = lib.strings.concatStringsSep "," [ + "${proj-repo.packages.${system}.proj-nvim}" + ]; + buildInputs = builtins.concatLists [ (with pkgs; [ clang-tools diff --git a/lib/compiler/test/src/test_optimal_cost.cc b/lib/compiler/test/src/test_optimal_cost.cc index 91c7a11888..8c176eb4d2 100644 --- a/lib/compiler/test/src/test_optimal_cost.cc +++ b/lib/compiler/test/src/test_optimal_cost.cc @@ -41,10 +41,9 @@ TEST_SUITE(FF_TEST_SUITE) { MultiDiEdge e{n1, pcg.add_node_port(), n0, pcg.add_node_port()}; pcg.add_edge(e); - pcg.add_output(e, - ParallelTensor(ParallelTensorDims({2, 1}), - DataType::FLOAT, - CreateGrad::YES)); + ParallelDim dim = {2, 1, false}; + ParallelTensorDims dims = {FFOrdered{dim}}; + pcg.add_output(e, ParallelTensor(dims, DataType::FLOAT, CreateGrad::YES)); auto test_allowed_machine_views = [](Operator const &, MachineSpecification const &) { diff --git a/lib/kernels/include/kernels/array_shape.h b/lib/kernels/include/kernels/array_shape.h index 15f14f8757..1cb10e8ce7 100644 --- a/lib/kernels/include/kernels/array_shape.h +++ b/lib/kernels/include/kernels/array_shape.h @@ -6,6 +6,7 @@ #include "utils/stack_vector.h" #include "utils/visitable.h" #include +#include #include namespace FlexFlow { @@ -41,8 +42,10 @@ struct ArrayShape { std::optional at_maybe(std::size_t) const; ArrayShape reversed_dim_order() const; - ArrayShape sub_shape(std::optional start, - std::optional end) const; + + ArrayShape + sub_shape(std::optional> start, + std::optional> end) const; public: LegionTensorDims dims; diff --git a/lib/kernels/include/kernels/cast_kernels.h b/lib/kernels/include/kernels/cast_kernels.h index 4e6878e318..96f9aadd52 100644 --- a/lib/kernels/include/kernels/cast_kernels.h +++ b/lib/kernels/include/kernels/cast_kernels.h @@ -4,7 +4,7 @@ #include "device.h" #include "kernels/accessor.h" #include "kernels/ff_handle.h" -#include "op-attrs/activation.h" +#include "op-attrs/activation.dtg.h" namespace FlexFlow { namespace Kernels { diff --git a/lib/kernels/include/kernels/conv_2d_kernels.h b/lib/kernels/include/kernels/conv_2d_kernels.h index b646c4b7cb..0a93125367 100644 --- a/lib/kernels/include/kernels/conv_2d_kernels.h +++ b/lib/kernels/include/kernels/conv_2d_kernels.h @@ -4,7 +4,7 @@ #include "device.h" #include "kernels/accessor.h" #include "kernels/ff_handle.h" -#include "op-attrs/activation.h" +#include "op-attrs/activation.dtg.h" #include "utils/visitable.h" namespace FlexFlow { diff --git a/lib/kernels/include/kernels/element_binary_kernels.h b/lib/kernels/include/kernels/element_binary_kernels.h index a9cbba420e..41447e98e6 100644 --- a/lib/kernels/include/kernels/element_binary_kernels.h +++ b/lib/kernels/include/kernels/element_binary_kernels.h @@ -5,7 +5,7 @@ #include "ff_handle.h" #include "kernels/array_shape.h" #include "op-attrs/datatype.h" -#include "op-attrs/op.h" +#include "op-attrs/operator_type.h" namespace FlexFlow { diff --git a/lib/kernels/include/kernels/element_unary_kernels.h b/lib/kernels/include/kernels/element_unary_kernels.h index dedfbb01ef..5044b0cdb2 100644 --- a/lib/kernels/include/kernels/element_unary_kernels.h +++ b/lib/kernels/include/kernels/element_unary_kernels.h @@ -24,18 +24,34 @@ namespace ElementUnary { ElementUnaryPerDeviceState init_kernel(ArrayShape const &input_shape, ArrayShape const &output_shape, - ElementUnaryUnifiedAttrs const &attrs); + ElementUnaryAttrs const &attrs); void forward_kernel(ffStream_t stream, ElementUnaryPerDeviceState const &device_state, - ElementUnaryUnifiedAttrs const &attrs, + ElementUnaryAttrs const &attrs, PerDeviceFFHandle &handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output); +void forward_kernel(ffStream_t stream, + ElementUnaryPerDeviceState const &device_state, + ElementScalarUnaryAttrs const &attrs, + PerDeviceFFHandle &handle, + GenericTensorAccessorR const &input, + GenericTensorAccessorW const &output); + +void backward_kernel(ffStream_t stream, + ElementUnaryPerDeviceState const &device_state, + ElementUnaryAttrs const &attrs, + PerDeviceFFHandle &handle, + GenericTensorAccessorR const &input, + GenericTensorAccessorW const &input_grad, + GenericTensorAccessorR const &output, + GenericTensorAccessorR const &output_grad); + void backward_kernel(ffStream_t stream, ElementUnaryPerDeviceState const &device_state, - ElementUnaryUnifiedAttrs const &attrs, + ElementScalarUnaryAttrs const &attrs, PerDeviceFFHandle &handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &input_grad, diff --git a/lib/kernels/include/kernels/legion_dim.h b/lib/kernels/include/kernels/legion_dim.h index f5c1d7ccc9..cf6ebfc2d4 100644 --- a/lib/kernels/include/kernels/legion_dim.h +++ b/lib/kernels/include/kernels/legion_dim.h @@ -1,14 +1,14 @@ #ifndef _FLEXFLOW_KERNELS_INCLUDE_KERNELS_LEGION_DIM_H #define _FLEXFLOW_KERNELS_INCLUDE_KERNELS_LEGION_DIM_H +#include "kernels/legion_dim_t.dtg.h" #include "op-attrs/dim_ordered.h" -#include "utils/strong_typedef.h" namespace FlexFlow { -struct legion_dim_t : strong_typedef { - using strong_typedef::strong_typedef; -}; +legion_dim_t add_to_legion_dim(legion_dim_t, int); + +legion_dim_t legion_dim_from_ff_dim(ff_dim_t, int num_dimensions); template using LegionOrdered = DimOrdered; diff --git a/lib/kernels/include/kernels/legion_dim_t.dtg.h b/lib/kernels/include/kernels/legion_dim_t.dtg.h new file mode 100644 index 0000000000..622f9c240a --- /dev/null +++ b/lib/kernels/include/kernels/legion_dim_t.dtg.h @@ -0,0 +1,54 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/kernels/include/kernels/legion_dim_t.struct.toml +/* proj-data +{ + "generated_from": "f67d6e50c53539a21d69e7162cf965f4" +} +*/ + +#ifndef _FLEXFLOW_LIB_KERNELS_INCLUDE_KERNELS_LEGION_DIM_T_DTG_H +#define _FLEXFLOW_LIB_KERNELS_INCLUDE_KERNELS_LEGION_DIM_T_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include +#include +#include + +namespace FlexFlow { +struct legion_dim_t { + legion_dim_t() = delete; + legion_dim_t(int const &value); + + bool operator==(legion_dim_t const &) const; + bool operator!=(legion_dim_t const &) const; + bool operator<(legion_dim_t const &) const; + bool operator>(legion_dim_t const &) const; + bool operator<=(legion_dim_t const &) const; + bool operator>=(legion_dim_t const &) const; + int value; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::legion_dim_t const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::legion_dim_t from_json(json const &); + static void to_json(json &, FlexFlow::legion_dim_t const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(legion_dim_t const &); +std::ostream &operator<<(std::ostream &, legion_dim_t const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_KERNELS_INCLUDE_KERNELS_LEGION_DIM_T_DTG_H diff --git a/lib/kernels/include/kernels/legion_dim_t.struct.toml b/lib/kernels/include/kernels/legion_dim_t.struct.toml new file mode 100644 index 0000000000..d2afb0d73f --- /dev/null +++ b/lib/kernels/include/kernels/legion_dim_t.struct.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "legion_dim_t" + +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +[[fields]] +name = "value" +type = "int" diff --git a/lib/kernels/include/kernels/pool_2d_kernels.h b/lib/kernels/include/kernels/pool_2d_kernels.h index 96bb6eccf9..798c0507f8 100644 --- a/lib/kernels/include/kernels/pool_2d_kernels.h +++ b/lib/kernels/include/kernels/pool_2d_kernels.h @@ -3,7 +3,7 @@ #include "device.h" #include "kernels/ff_handle.h" -#include "op-attrs/activation.h" +#include "op-attrs/activation.dtg.h" #include "op-attrs/ops/pool_2d.h" #include "utils/visitable.h" diff --git a/lib/kernels/include/kernels/reduce_kernels.h b/lib/kernels/include/kernels/reduce_kernels.h index 51730fb0cd..56241b73ce 100644 --- a/lib/kernels/include/kernels/reduce_kernels.h +++ b/lib/kernels/include/kernels/reduce_kernels.h @@ -4,7 +4,7 @@ #include "array_shape.h" #include "device.h" #include "ff_handle.h" -#include "op-attrs/op.h" +#include "op-attrs/operator_type.dtg.h" namespace FlexFlow { diff --git a/lib/kernels/include/kernels/transpose_kernels.h b/lib/kernels/include/kernels/transpose_kernels.h index cb34ff6736..fa087fada3 100644 --- a/lib/kernels/include/kernels/transpose_kernels.h +++ b/lib/kernels/include/kernels/transpose_kernels.h @@ -8,7 +8,7 @@ namespace FlexFlow { struct TransposePerDeviceState { int num_dim; - req> perm; + req> perm; }; FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(TransposePerDeviceState, diff --git a/lib/kernels/src/cuda/ops/concat_kernels.cu b/lib/kernels/src/cuda/ops/concat_kernels.cu index dcf7a41a2f..68004738d2 100644 --- a/lib/kernels/src/cuda/ops/concat_kernels.cu +++ b/lib/kernels/src/cuda/ops/concat_kernels.cu @@ -25,15 +25,8 @@ void calc_blk_size(size_t &num_blocks, size_t &blk_size, ArrayShape const &shape, ff_dim_t axis) { - num_blocks = 1; - blk_size = 1; - for (int d = 0; d < shape.num_dims(); d++) { - if (d <= axis) { - blk_size *= shape[legion_dim_t(d)]; - } else { - num_blocks *= shape[legion_dim_t(d)]; - } - } + blk_size = shape.sub_shape(legion_dim_t{0}, axis).num_elements(); + num_blocks = shape.sub_shape(axis, std::nullopt).num_elements(); } void forward_kernel(cudaStream_t stream, diff --git a/lib/kernels/src/cuda/ops/element_binary_kernels.cu b/lib/kernels/src/cuda/ops/element_binary_kernels.cu index be369ff064..45b4d43006 100644 --- a/lib/kernels/src/cuda/ops/element_binary_kernels.cu +++ b/lib/kernels/src/cuda/ops/element_binary_kernels.cu @@ -17,14 +17,12 @@ #include "kernels/element_binary_kernels.h" #include "kernels/ff_handle.h" #include "op-attrs/datatype.h" -#include "op-attrs/op.h" +#include "op-attrs/operator_type.h" namespace FlexFlow { namespace Kernels { namespace ElementBinary { -using OperatorType = Op; - __global__ void elewise_binary_backward_kernel(size_t volume, float const alpha, float const beta, @@ -36,28 +34,28 @@ __global__ void elewise_binary_backward_kernel(size_t volume, float *rhs_grad) { CUDA_KERNEL_LOOP(i, volume) { switch (type) { - case Op::EW_ADD: { + case OperatorType::EW_ADD: { lhs_grad[i] = alpha * out_grad[i] + beta * lhs_grad[i]; rhs_grad[i] = alpha * out_grad[i] + beta * rhs_grad[i]; break; } - case Op::EW_SUB: { + case OperatorType::EW_SUB: { lhs_grad[i] = alpha * out_grad[i] + beta * lhs_grad[i]; rhs_grad[i] = -alpha * out_grad[i] + beta * rhs_grad[i]; break; } - case Op::EW_MUL: { + case OperatorType::EW_MUL: { lhs_grad[i] = alpha * out_grad[i] * rhs[i] + beta * lhs_grad[i]; rhs_grad[i] = alpha * out_grad[i] * lhs[i] + beta * rhs_grad[i]; break; } - case Op::EW_DIV: { + case OperatorType::EW_DIV: { lhs_grad[i] = alpha * out_grad[i] / rhs[i] + beta * lhs_grad[i]; rhs_grad[i] = -alpha * out_grad[i] * lhs[i] / (rhs[i] * rhs[i]) + beta * rhs_grad[i]; break; } - case Op::EW_MAX: { + case OperatorType::EW_MAX: { lhs_grad[i] = (lhs[i] >= rhs[i]) ? alpha * out_grad[i] + beta * lhs_grad[i] : beta * lhs_grad[i]; @@ -66,7 +64,7 @@ __global__ void elewise_binary_backward_kernel(size_t volume, : beta * rhs_grad[i]; break; } - case Op::EW_MIN: { + case OperatorType::EW_MIN: { lhs_grad[i] = (lhs[i] <= rhs[i]) ? alpha * out_grad[i] + beta * lhs_grad[i] : beta * lhs_grad[i]; @@ -102,17 +100,17 @@ ElementBinaryPerDeviceState init_kernel(PerDeviceFFHandle handle, checkCUDNN(cudnnCreateReduceTensorDescriptor(&reduceAddDesc)); switch (op_type) { - case Op::EW_ADD: - case Op::EW_SUB: + case OperatorType::EW_ADD: + case OperatorType::EW_SUB: mode = CUDNN_OP_TENSOR_ADD; break; - case Op::EW_MUL: + case OperatorType::EW_MUL: mode = CUDNN_OP_TENSOR_MUL; break; - case Op::EW_MAX: + case OperatorType::EW_MAX: mode = CUDNN_OP_TENSOR_MAX; break; - case Op::EW_MIN: + case OperatorType::EW_MIN: mode = CUDNN_OP_TENSOR_MIN; break; default: @@ -152,13 +150,13 @@ void forward_kernel(cudaStream_t stream, checkCUDNN(cudnnSetStream(handle.dnn, stream)); float alpha1 = 1.0f, alpha2 = 1.0f, beta = 0.0f; switch (op_type) { - case Op::EW_SUB: + case OperatorType::EW_SUB: alpha2 = -1.0f; break; - case Op::EW_ADD: - case Op::EW_MUL: - case Op::EW_MAX: - case Op::EW_MIN: + case OperatorType::EW_ADD: + case OperatorType::EW_MUL: + case OperatorType::EW_MAX: + case OperatorType::EW_MIN: break; default: assert(false); @@ -167,9 +165,9 @@ void forward_kernel(cudaStream_t stream, // cudnnOpTensor if (broadcast_inputLHS) { // currently only handle add and sub - assert(op_type == Op::EW_SUB || op_type == Op::EW_ADD || - op_type == Op::EW_MUL); - if (op_type == Op::EW_SUB || op_type == Op::EW_ADD) { + assert(op_type == OperatorType::EW_SUB || op_type == OperatorType::EW_ADD || + op_type == OperatorType::EW_MUL); + if (op_type == OperatorType::EW_SUB || op_type == OperatorType::EW_ADD) { // output = (beta*output + alpha1*input1) + beta*output = input1 checkCUDNN(cudnnOpTensor(handle.dnn, m.opDesc, @@ -195,7 +193,7 @@ void forward_kernel(cudaStream_t stream, &alpha1, m.outputTensor, out_ptr)); - } else if (op_type == Op::EW_MUL) { + } else if (op_type == OperatorType::EW_MUL) { checkCUDNN(cudnnSetOpTensorDescriptor(m.opDesc, CUDNN_OP_TENSOR_ADD, CUDNN_DATA_FLOAT, @@ -258,7 +256,7 @@ void backward_kernel(cudaStream_t stream, checkCUDA(cublasSetStream(handle.blas, stream)); checkCUDNN(cudnnSetStream(handle.dnn, stream)); - if (op_type == Op::EW_ADD || op_type == Op::EW_SUB) { + if (op_type == OperatorType::EW_ADD || op_type == OperatorType::EW_SUB) { float alpha = 1.0f, beta = 1.0f; if (lhs_grad_ptr != nullptr) { if (broadcast_inputLHS) { @@ -284,7 +282,7 @@ void backward_kernel(cudaStream_t stream, lhs_grad_ptr)); } } - if (op_type == Op::EW_SUB) { + if (op_type == OperatorType::EW_SUB) { alpha = -1.0f; } if (rhs_grad_ptr != nullptr) { @@ -311,7 +309,7 @@ void backward_kernel(cudaStream_t stream, rhs_grad_ptr)); } } - } else if (op_type == Op::EW_MUL) { + } else if (op_type == OperatorType::EW_MUL) { float alpha1 = 1.0f, alpha2 = 1.0f, beta = 1.0f, zero = 0.0f; if (lhs_grad_ptr != nullptr) { if (broadcast_inputLHS) { @@ -393,7 +391,8 @@ void backward_kernel(cudaStream_t stream, rhs_grad_ptr)); } } - } else if (op_type == Op::EW_MIN || op_type == Op::EW_MAX) { + } else if (op_type == OperatorType::EW_MIN || + op_type == OperatorType::EW_MAX) { float alpha = 1.0f, beta = 1.0f; cudnnDataType_t dataType; int n; diff --git a/lib/kernels/src/cuda/ops/element_unary_kernels.cu b/lib/kernels/src/cuda/ops/element_unary_kernels.cu index 305e778726..e37d32c325 100644 --- a/lib/kernels/src/cuda/ops/element_unary_kernels.cu +++ b/lib/kernels/src/cuda/ops/element_unary_kernels.cu @@ -25,29 +25,19 @@ namespace ElementUnary { static bool use_cudnn(OperatorType op_type) { switch (op_type) { - case Op::RELU: - case Op::SIGMOID: - case Op::TANH: - case Op::ELU: + case OperatorType::RELU: + case OperatorType::SIGMOID: + case OperatorType::TANH: + case OperatorType::ELU: return true; default: return false; } } -template -T get_scalar(ElementUnaryUnifiedAttrs const &attrs) { - if (std::holds_alternative(attrs)) { - return (T)std::get(attrs).scalar; - } else { - T dummy_scalar; - return dummy_scalar; - } -} - -ElementUnaryPerDeviceState init_kernel(ArrayShape const &input_shape, - ArrayShape const &output_shape, - ElementUnaryUnifiedAttrs const &attrs) { +static ElementUnaryPerDeviceState init_kernel(ArrayShape const &input_shape, + ArrayShape const &output_shape, + OperatorType op_type) { ffTensorDescriptor_t inputTensor; ffTensorDescriptor_t outputTensor; @@ -57,21 +47,19 @@ ElementUnaryPerDeviceState init_kernel(ArrayShape const &input_shape, checkCUDNN(cudnnCreateTensorDescriptor(&outputTensor)); checkCUDNN(cudnnCreateActivationDescriptor(&actiDesc)); - Op op_type = get_op_type(attrs); - if (use_cudnn(op_type)) { cudnnActivationMode_t mode; switch (op_type) { - case Op::SIGMOID: + case OperatorType::SIGMOID: mode = CUDNN_ACTIVATION_SIGMOID; break; - case Op::RELU: + case OperatorType::RELU: mode = CUDNN_ACTIVATION_RELU; break; - case Op::TANH: + case OperatorType::TANH: mode = CUDNN_ACTIVATION_TANH; break; - case Op::ELU: + case OperatorType::ELU: mode = CUDNN_ACTIVATION_ELU; break; default: @@ -88,52 +76,64 @@ ElementUnaryPerDeviceState init_kernel(ArrayShape const &input_shape, return {inputTensor, outputTensor, actiDesc}; } +ElementUnaryPerDeviceState init_kernel(ArrayShape const &input_shape, + ArrayShape const &output_shape, + ElementUnaryAttrs const &attrs) { + return init_kernel(input_shape, output_shape, get_op_type(attrs)); +} + +ElementUnaryPerDeviceState init_kernel(ArrayShape const &input_shape, + ArrayShape const &output_shape, + ElementScalarUnaryAttrs const &attrs) { + return init_kernel(input_shape, output_shape, get_op_type(attrs)); +} + template __global__ void elewise_unary_forward_kernel( - coord_t volume, T const scalar, OperatorType type, T const *in, T *out) { + coord_t volume, T scalar, OperatorType type, T const *in, T *out) { CUDA_KERNEL_LOOP(i, volume) { switch (type) { - case Op::EXP: { + case OperatorType::EXP: { out[i] = (T)exp((float)in[i]); break; } - case Op::IDENTITY: { + case OperatorType::IDENTITY: { out[i] = in[i]; break; } - case Op::SCALAR_MULTIPLY: { + case OperatorType::SCALAR_MULTIPLY: { out[i] = in[i] * scalar; break; } - case Op::SCALAR_ADD: { + case OperatorType::SCALAR_ADD: { out[i] = in[i] + scalar; break; } - case Op::SCALAR_SUB: { + case OperatorType::SCALAR_SUB: { out[i] = in[i] - scalar; break; } - case Op::SCALAR_TRUE_DIV: { + case OperatorType::SCALAR_TRUE_DIV: { out[i] = in[i] / scalar; break; } - case Op::GELU: { + case OperatorType::GELU: { out[i] = (T)(in[i] * 0.5 * erfc(-in[i] * M_SQRT1_2)); break; } - case Op::RSQRT: { + case OperatorType::RSQRT: { out[i] = (T)(1.0f / sqrt((float)in[i])); break; } - case Op::POW: { + case OperatorType::POW: { out[i] = (T)(powf(in[i], scalar)); break; } - case Op::SIN: { + case OperatorType::SIN: { out[i] = (T)sin((float)in[i]); break; } - case Op::COS: { + case OperatorType::COS: { out[i] = (T)cos((float)in[i]); break; } @@ -145,7 +145,7 @@ __global__ void elewise_unary_forward_kernel( template __global__ void elewise_unary_backward_kernel(coord_t volume, - T const scalar, + T scalar, OperatorType type, T const *output, T const *output_grad, @@ -153,53 +153,53 @@ __global__ void elewise_unary_backward_kernel(coord_t volume, T *input_grad) { CUDA_KERNEL_LOOP(i, volume) { switch (type) { - case Op::EXP: { + case OperatorType::EXP: { // TODO: change to use output instead of recomputing input_grad[i] += (T)(output_grad[i] * exp((float)input[i])); break; } - case Op::IDENTITY: { + case OperatorType::IDENTITY: { input_grad[i] += output_grad[i]; break; } - case Op::SCALAR_MULTIPLY: { + case OperatorType::SCALAR_MULTIPLY: { input_grad[i] += output_grad[i] * scalar; break; } - case Op::SCALAR_ADD: { + case OperatorType::SCALAR_ADD: { input_grad[i] += output_grad[i]; break; } - case Op::SCALAR_SUB: { + case OperatorType::SCALAR_SUB: { input_grad[i] += output_grad[i]; break; } - case Op::SCALAR_TRUE_DIV: { + case OperatorType::SCALAR_TRUE_DIV: { input_grad[i] += output_grad[i] / scalar; break; } - case Op::GELU: { + case OperatorType::GELU: { input_grad[i] = (T)(output_grad[i] * (0.5 * erfc(-input[i] * M_SQRT1_2) - 0.5 * M_SQRT1_2 * input[i] * exp(-input[i] * input[i] * 0.5))); break; } - case Op::RSQRT: { + case OperatorType::RSQRT: { input_grad[i] = (T)(-0.5f * output_grad[i] * output[i] * output[i] * output[i]); break; } - case Op::POW: { + case OperatorType::POW: { input_grad[i] = (T)(output_grad[i] * scalar * powf(input[i], scalar - 1)); break; } - case Op::SIN: { + case OperatorType::SIN: { input_grad[i] += (T)(output_grad[i] * cos((float)input[i])); break; } - case Op::COS: { + case OperatorType::COS: { input_grad[i] += (T)(output_grad[i] * -sin((float)input[i])); break; } @@ -213,12 +213,12 @@ template struct ForwardKernel { void operator()(ffStream_t stream, ElementUnaryPerDeviceState const &m, - ElementUnaryUnifiedAttrs const &attrs, + OperatorType op_type, + std::optional scalar, PerDeviceFFHandle const &handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) const { checkCUDNN(cudnnSetStream(handle.dnn, stream)); - Op op_type = get_op_type(attrs); if (use_cudnn(op_type)) { float alpha = 1.0f, beta = 0.0f; checkCUDNN(cudnnActivationForward(handle.dnn, @@ -234,7 +234,7 @@ struct ForwardKernel { elewise_unary_forward_kernel> <<>>( num_elements, - get_scalar>(attrs), + static_cast>(scalar.value()), op_type, input.get(), output.get()); @@ -246,7 +246,8 @@ template struct BackwardKernel { void operator()(ffStream_t stream, ElementUnaryPerDeviceState const &m, - ElementUnaryUnifiedAttrs const &attrs, + OperatorType op_type, + std::optional scalar, PerDeviceFFHandle const &handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &input_grad, @@ -254,7 +255,6 @@ struct BackwardKernel { GenericTensorAccessorR const &output_grad) { checkCUDNN(cudnnSetStream(handle.dnn, stream)); - Op op_type = get_op_type(attrs); if (use_cudnn(op_type)) { float alpha = 1.0f; checkCUDNN(cudnnActivationBackward(handle.dnn, @@ -274,7 +274,7 @@ struct BackwardKernel { elewise_unary_backward_kernel> <<>>( num_elements, - get_scalar>(attrs), + static_cast>(scalar.value()), op_type, output.get(), output_grad.get(), @@ -286,17 +286,59 @@ struct BackwardKernel { void forward_kernel(ffStream_t stream, ElementUnaryPerDeviceState const &device_state, - ElementUnaryUnifiedAttrs const &attrs, + ElementUnaryAttrs const &attrs, PerDeviceFFHandle const &handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) { - DataTypeDispatch1{}( - input.data_type, stream, device_state, attrs, handle, input, output); + DataTypeDispatch1{}(input.data_type, + stream, + device_state, + get_op_type(attrs), + std::nullopt, + handle, + input, + output); +} + +void forward_kernel(ffStream_t stream, + ElementUnaryPerDeviceState const &device_state, + ElementScalarUnaryAttrs const &attrs, + PerDeviceFFHandle const &handle, + GenericTensorAccessorR const &input, + GenericTensorAccessorW const &output) { + DataTypeDispatch1{}(input.data_type, + stream, + device_state, + get_op_type(attrs), + attrs.scalar, + handle, + input, + output); +} + +void backward_kernel(ffStream_t stream, + ElementUnaryPerDeviceState const &device_state, + ElementUnaryAttrs const &attrs, + PerDeviceFFHandle const &handle, + GenericTensorAccessorR const &input, + GenericTensorAccessorW const &input_grad, + GenericTensorAccessorR const &output, + GenericTensorAccessorR const &output_grad) { + DataTypeDispatch1{}(input.data_type, + stream, + device_state, + get_op_type(attrs), + std::nullopt, + handle, + input, + input_grad, + output, + output_grad); } void backward_kernel(ffStream_t stream, ElementUnaryPerDeviceState const &device_state, - ElementUnaryUnifiedAttrs const &attrs, + ElementScalarUnaryAttrs const &attrs, PerDeviceFFHandle const &handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &input_grad, @@ -305,7 +347,8 @@ void backward_kernel(ffStream_t stream, DataTypeDispatch1{}(input.data_type, stream, device_state, - attrs, + get_op_type(attrs), + attrs.scalar, handle, input, input_grad, diff --git a/lib/kernels/src/cuda/ops/gather_kernels.cu b/lib/kernels/src/cuda/ops/gather_kernels.cu index 286acf7376..e002cf7e71 100644 --- a/lib/kernels/src/cuda/ops/gather_kernels.cu +++ b/lib/kernels/src/cuda/ops/gather_kernels.cu @@ -126,9 +126,8 @@ void forward_kernel(ffStream_t stream, checkCUDA(get_legion_stream(&stream)); coord_t stride = - output.shape - .sub_shape(std::nullopt, legion_dim_t{m.legion_dim.value() + 1}) - .get_volume(); + output.shape.sub_shape(std::nullopt, add_to_legion_dim(m.legion_dim, 1)) + .num_elements(); coord_t output_dim_size = output.shape[m.legion_dim]; coord_t input_dim_size = input.shape[m.legion_dim]; @@ -155,7 +154,7 @@ void backward_kernel(ffStream_t stream, coord_t stride = output_grad.shape - .sub_shape(std::nullopt, legion_dim_t{m.legion_dim.value() + 1}) + .sub_shape(std::nullopt, add_to_legion_dim(m.legion_dim, 1)) .get_volume(); coord_t output_dim_size = output_grad.shape[m.legion_dim]; coord_t input_dim_size = input_grad.shape[m.legion_dim]; diff --git a/lib/kernels/src/cuda/ops/linear_kernels.cu b/lib/kernels/src/cuda/ops/linear_kernels.cu index 81ab34380e..9a36534a1b 100644 --- a/lib/kernels/src/cuda/ops/linear_kernels.cu +++ b/lib/kernels/src/cuda/ops/linear_kernels.cu @@ -253,9 +253,8 @@ void backward_kernel(cudaStream_t stream, // do nothing } else { RegularizerAttrs regularizer_attrs = m.regularizer.value(); - if (std::holds_alternative(regularizer_attrs)) { - L2RegularizerAttrs l2_attrs = - std::get(regularizer_attrs); + if (regularizer_attrs.has()) { + L2RegularizerAttrs l2_attrs = regularizer_attrs.get(); float lambda = l2_attrs.lambda; checkCUDA(cublasSgeam(m.handle.blas, CUBLAS_OP_N, diff --git a/lib/kernels/src/cuda/ops/reduce_kernels.cu b/lib/kernels/src/cuda/ops/reduce_kernels.cu index 8571219648..02a89da807 100644 --- a/lib/kernels/src/cuda/ops/reduce_kernels.cu +++ b/lib/kernels/src/cuda/ops/reduce_kernels.cu @@ -71,10 +71,10 @@ void backward_kernel(cudaStream_t stream, checkCUDNN(cudnnSetStream(m.handle.dnn, stream)); float alpha = 1.0, beta = 1.0f; switch (m.op_type) { - case Op::REDUCE_SUM: + case OperatorType::REDUCE_SUM: alpha = 1.0f; break; - case Op::REDUCE_MEAN: + case OperatorType::REDUCE_MEAN: // When the output is the average of multiple input elements // we need to scale the gradients by 1.0 / reduction_size alpha = 1.0f / m.reduction_size; diff --git a/lib/kernels/src/cuda/ops/transpose_kernels.cu b/lib/kernels/src/cuda/ops/transpose_kernels.cu index 7dac25d0c9..3b3f80944d 100644 --- a/lib/kernels/src/cuda/ops/transpose_kernels.cu +++ b/lib/kernels/src/cuda/ops/transpose_kernels.cu @@ -33,10 +33,10 @@ TransposePerDeviceState init_kernel(int num_dim, std::vector const &perm) { int const length = perm.size(); - std::vector perm_vector; + std::vector perm_vector; assert(length <= MAX_TENSOR_DIM); for (int i = 0; i < length; ++i) { - perm_vector.push_back(perm[i].value()); + perm_vector.push_back(legion_dim_from_ff_dim(perm[i], num_dim)); } return {num_dim, perm_vector}; @@ -77,7 +77,7 @@ void forward_kernel(cudaStream_t stream, info.in_strides[i] = info.in_strides[i - 1] * in_dim_size; info.out_strides[i] = info.out_strides[i - 1] * out_dim_size; } - info.perm[i] = m.perm[i]; + info.perm[i] = m.perm[i].value; } transpose_simple_kernel<< #if defined(FF_USE_CUDA) diff --git a/lib/kernels/src/hip/element_binary_kernels.cpp b/lib/kernels/src/hip/element_binary_kernels.cpp index 232e377797..bc66bbff2f 100644 --- a/lib/kernels/src/hip/element_binary_kernels.cpp +++ b/lib/kernels/src/hip/element_binary_kernels.cpp @@ -17,15 +17,13 @@ #include "device.h" #include "kernels/ff_handle.h" #include "op-attrs/datatype.h" -#include "op-attrs/op.h" +#include "op-attrs/operator_type.dtg.h" #include namespace FlexFlow { namespace Kernels { namespace ElementBinary { -using OperatorType = Op; - __global__ void elewise_binary_backward_kernel(coord_t volume, float const alpha, float const beta, @@ -37,28 +35,28 @@ __global__ void elewise_binary_backward_kernel(coord_t volume, float *rhs_grad) { CUDA_KERNEL_LOOP(i, volume) { switch (type) { - case Op::EW_ADD: { + case OperatorType::EW_ADD: { lhs_grad[i] = alpha * out_grad[i] + beta * lhs_grad[i]; rhs_grad[i] = alpha * out_grad[i] + beta * rhs_grad[i]; break; } - case Op::EW_SUB: { + case OperatorType::EW_SUB: { lhs_grad[i] = alpha * out_grad[i] + beta * lhs_grad[i]; rhs_grad[i] = -alpha * out_grad[i] + beta * rhs_grad[i]; break; } - case Op::EW_MUL: { + case OperatorType::EW_MUL: { lhs_grad[i] = alpha * out_grad[i] * rhs[i] + beta * lhs_grad[i]; rhs_grad[i] = alpha * out_grad[i] * lhs[i] + beta * rhs_grad[i]; break; } - case Op::EW_DIV: { + case OperatorType::EW_DIV: { lhs_grad[i] = alpha * out_grad[i] / rhs[i] + beta * lhs_grad[i]; rhs_grad[i] = -alpha * out_grad[i] * lhs[i] / (rhs[i] * rhs[i]) + beta * rhs_grad[i]; break; } - case Op::EW_MAX: { + case OperatorType::EW_MAX: { lhs_grad[i] = (lhs[i] >= rhs[i]) ? alpha * out_grad[i] + beta * lhs_grad[i] : beta * lhs_grad[i]; @@ -67,7 +65,7 @@ __global__ void elewise_binary_backward_kernel(coord_t volume, : beta * rhs_grad[i]; break; } - case Op::EW_MIN: { + case OperatorType::EW_MIN: { lhs_grad[i] = (lhs[i] <= rhs[i]) ? alpha * out_grad[i] + beta * lhs_grad[i] : beta * lhs_grad[i]; @@ -103,17 +101,17 @@ ElementBinaryPerDeviceState init_kernel(PerDeviceFFHandle handle, checkCUDNN(miopenCreateReduceTensorDescriptor(&reduceAddDesc)); switch (op_type) { - case Op::EW_ADD: - case Op::EW_SUB: + case OperatorType::EW_ADD: + case OperatorType::EW_SUB: mode = miopenTensorOpAdd; break; - case Op::EW_MUL: + case OperatorType::EW_MUL: mode = miopenTensorOpMul; break; - case Op::EW_MAX: + case OperatorType::EW_MAX: mode = miopenOpTensorMax; break; - case Op::EW_MIN: + case OperatorType::EW_MIN: mode = miopenOpTensorMin; break; default: @@ -156,13 +154,13 @@ void forward_kernel(hipStream_t stream, float alpha1 = 1.0f, alpha2 = 1.0f, beta = 0.0f; switch (op_type) { - case Op::EW_SUB: + case OperatorType::EW_SUB: alpha2 = -1.0f; break; - case Op::EW_ADD: - case Op::EW_MUL: - case Op::EW_MAX: - case Op::EW_MIN: + case OperatorType::EW_ADD: + case OperatorType::EW_MUL: + case OperatorType::EW_MAX: + case OperatorType::EW_MIN: break; default: assert(false); @@ -171,9 +169,9 @@ void forward_kernel(hipStream_t stream, // cudnnOpTensor if (broadcast_inputLHS) { // currently only handle add and sub - assert(op_type == Op::EW_SUB || op_type == Op::EW_ADD || - op_type == Op::EW_MUL); - if (op_type == Op::EW_SUB || op_type == Op::EW_ADD) { + assert(op_type == OperatorType::EW_SUB || op_type == OperatorType::EW_ADD || + op_type == OperatorType::EW_MUL); + if (op_type == OperatorType::EW_SUB || op_type == OperatorType::EW_ADD) { checkCUDNN(miopenOpTensor(handle.dnn, m.opDesc, &beta, @@ -196,7 +194,7 @@ void forward_kernel(hipStream_t stream, &alpha1, m.outputTensor, out_ptr)); - } else if (op_type == Op::EW_MUL) { + } else if (op_type == OperatorType::EW_MUL) { checkCUDNN(cudnnSetOpTensorDescriptor(m.opDesc, CUDNN_OP_TENSOR_MUL, CUDNN_DATA_FLOAT, @@ -258,7 +256,7 @@ void backward_kernel(hipStream_t stream, checkCUDA(hipblasSetStream(handle.blas, stream)); checkCUDNN(miopenSetStream(handle.dnn, stream)); - if (m.op_type == Op::EW_ADD || m.op_type == Op::EW_SUB) { + if (m.op_type == OperatorType::EW_ADD || m.op_type == OperatorType::EW_SUB) { float alpha = 1.0f, beta = 1.0f; if (lhs_grad_ptr != nullptr) { if (broadcast_inputLHS) { @@ -288,7 +286,7 @@ void backward_kernel(hipStream_t stream, lhs_grad_ptr)); } } - if (m.op_type == Op::EW_SUB) { + if (m.op_type == OperatorType::EW_SUB) { alpha = -1.0f; } if (rhs_grad_ptr != nullptr) { @@ -319,7 +317,7 @@ void backward_kernel(hipStream_t stream, rhs_grad_ptr)); } } - } else if (m.op_type == Op::EW_MUL) { + } else if (m.op_type == OperatorType::EW_MUL) { float alpha1 = 1.0f, alpha2 = 1.0f, beta = 1.0f, zero = 0.0f; if (lhs_grad_ptr != nullptr) { if (broadcast_inputLHS) { diff --git a/lib/kernels/src/kernels/legion_dim_t.dtg.cc b/lib/kernels/src/kernels/legion_dim_t.dtg.cc new file mode 100644 index 0000000000..99c1a3b3a2 --- /dev/null +++ b/lib/kernels/src/kernels/legion_dim_t.dtg.cc @@ -0,0 +1,69 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/kernels/include/kernels/legion_dim_t.struct.toml +/* proj-data +{ + "generated_from": "f67d6e50c53539a21d69e7162cf965f4" +} +*/ + +#include "kernels/legion_dim_t.dtg.h" + +#include + +namespace FlexFlow { +legion_dim_t::legion_dim_t(int const &value) : value(value) {} +bool legion_dim_t::operator==(legion_dim_t const &other) const { + return std::tie(this->value) == std::tie(other.value); +} +bool legion_dim_t::operator!=(legion_dim_t const &other) const { + return std::tie(this->value) != std::tie(other.value); +} +bool legion_dim_t::operator<(legion_dim_t const &other) const { + return std::tie(this->value) < std::tie(other.value); +} +bool legion_dim_t::operator>(legion_dim_t const &other) const { + return std::tie(this->value) > std::tie(other.value); +} +bool legion_dim_t::operator<=(legion_dim_t const &other) const { + return std::tie(this->value) <= std::tie(other.value); +} +bool legion_dim_t::operator>=(legion_dim_t const &other) const { + return std::tie(this->value) >= std::tie(other.value); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::legion_dim_t const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.value) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::legion_dim_t + adl_serializer::from_json(json const &j) { + return {j.at("value").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::legion_dim_t const &v) { + j["__type"] = "legion_dim_t"; + j["value"] = v.value; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(legion_dim_t const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, legion_dim_t const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/CMakeLists.txt b/lib/op-attrs/CMakeLists.txt index 778be53d7c..9a9721ef2d 100644 --- a/lib/op-attrs/CMakeLists.txt +++ b/lib/op-attrs/CMakeLists.txt @@ -12,3 +12,4 @@ ff_add_library( ) add_subdirectory(ffi) +add_subdirectory(test) diff --git a/lib/op-attrs/include/op-attrs/activation.dtg.h b/lib/op-attrs/include/op-attrs/activation.dtg.h new file mode 100644 index 0000000000..a4c0e97882 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/activation.dtg.h @@ -0,0 +1,40 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/activation.enum.toml +/* proj-data +{ + "generated_from": "2b0d2e3e825732838aa5be99f2f0e6df" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_ACTIVATION_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_ACTIVATION_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +enum class Activation { RELU, SIGMOID, TANH, GELU }; +std::string format_as(Activation); +std::ostream &operator<<(std::ostream &, Activation); +void to_json(::nlohmann::json &, Activation); +void from_json(::nlohmann::json const &, Activation &); +} // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::Activation) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_ACTIVATION_DTG_H diff --git a/lib/op-attrs/include/op-attrs/activation.enum.toml b/lib/op-attrs/include/op-attrs/activation.enum.toml new file mode 100644 index 0000000000..66119da9b1 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/activation.enum.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "Activation" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "RELU" + +[[values]] +name = "SIGMOID" + +[[values]] +name = "TANH" + +[[values]] +name = "GELU" diff --git a/lib/op-attrs/include/op-attrs/activation.h b/lib/op-attrs/include/op-attrs/activation.h deleted file mode 100644 index 8fa07825fd..0000000000 --- a/lib/op-attrs/include/op-attrs/activation.h +++ /dev/null @@ -1,42 +0,0 @@ -#ifndef _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_ACTIVATION_H -#define _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_ACTIVATION_H - -#include "utils/fmt.h" - -namespace FlexFlow { - -enum class Activation { RELU, SIGMOID, TANH, GELU }; - -} - -namespace fmt { - -template <> -struct formatter<::FlexFlow::Activation> : formatter { - template - auto format(::FlexFlow::Activation a, FormatContext &ctx) const - -> decltype(ctx.out()) { - using namespace FlexFlow; - - string_view name = "unknown"; - switch (a) { - case Activation::RELU: - name = "ReLU"; - break; - case Activation::SIGMOID: - name = "Sigmoid"; - break; - case Activation::TANH: - name = "Tanh"; - break; - case Activation::GELU: - name = "GeLU"; - break; - } - return formatter::format(name, ctx); - } -}; - -} // namespace fmt - -#endif diff --git a/lib/op-attrs/include/op-attrs/aggregate_op.dtg.h b/lib/op-attrs/include/op-attrs/aggregate_op.dtg.h new file mode 100644 index 0000000000..3ff3848dca --- /dev/null +++ b/lib/op-attrs/include/op-attrs/aggregate_op.dtg.h @@ -0,0 +1,40 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/aggregate_op.enum.toml +/* proj-data +{ + "generated_from": "441fe9b0bb8f2dc2b31f74c58320ef30" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_AGGREGATE_OP_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_AGGREGATE_OP_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +enum class AggregateOp { SUM }; +std::string format_as(AggregateOp); +std::ostream &operator<<(std::ostream &, AggregateOp); +void to_json(::nlohmann::json &, AggregateOp); +void from_json(::nlohmann::json const &, AggregateOp &); +} // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::AggregateOp) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_AGGREGATE_OP_DTG_H diff --git a/lib/op-attrs/include/op-attrs/aggregate_op.enum.toml b/lib/op-attrs/include/op-attrs/aggregate_op.enum.toml new file mode 100644 index 0000000000..27aa50f38f --- /dev/null +++ b/lib/op-attrs/include/op-attrs/aggregate_op.enum.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "AggregateOp" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "SUM" + +[[value]] +name = "AVG" diff --git a/lib/op-attrs/include/op-attrs/as_dot.h b/lib/op-attrs/include/op-attrs/as_dot.h new file mode 100644 index 0000000000..d92557c2f4 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/as_dot.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_AS_DOT_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_AS_DOT_H + +#include "op-attrs/computation_graph_op_attrs.dtg.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" +#include "utils/record_formatter.h" + +namespace FlexFlow { + +RecordFormatter as_dot(ComputationGraphOpAttrs const &); +RecordFormatter as_dot(PCGOperatorAttrs const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.dtg.h b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.dtg.h new file mode 100644 index 0000000000..cc45628145 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.dtg.h @@ -0,0 +1,471 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/computation_graph_op_attrs.variant.toml +/* proj-data +{ + "generated_from": "cc0ab49405423594ffa1d8f541235a48" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_COMPUTATION_GRAPH_OP_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_COMPUTATION_GRAPH_OP_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/ops/attention_attrs.dtg.h" +#include "op-attrs/ops/batch_matmul.dtg.h" +#include "op-attrs/ops/batch_norm_attrs.dtg.h" +#include "op-attrs/ops/broadcast.dtg.h" +#include "op-attrs/ops/cast_attrs.dtg.h" +#include "op-attrs/ops/concat_attrs.dtg.h" +#include "op-attrs/ops/conv_2d_attrs.dtg.h" +#include "op-attrs/ops/dropout_attrs.dtg.h" +#include "op-attrs/ops/element_binary_attrs.dtg.h" +#include "op-attrs/ops/element_scalar_unary_attrs.dtg.h" +#include "op-attrs/ops/element_unary_attrs.dtg.h" +#include "op-attrs/ops/embedding_attrs.dtg.h" +#include "op-attrs/ops/flat_attrs.dtg.h" +#include "op-attrs/ops/gather_attrs.dtg.h" +#include "op-attrs/ops/input_attrs.dtg.h" +#include "op-attrs/ops/layer_norm_attrs.dtg.h" +#include "op-attrs/ops/linear_attrs.dtg.h" +#include "op-attrs/ops/noop_attrs.dtg.h" +#include "op-attrs/ops/pool_2d_attrs.dtg.h" +#include "op-attrs/ops/reduce_attrs.dtg.h" +#include "op-attrs/ops/reshape_attrs.dtg.h" +#include "op-attrs/ops/reverse_attrs.dtg.h" +#include "op-attrs/ops/softmax_attrs.dtg.h" +#include "op-attrs/ops/split_attrs.dtg.h" +#include "op-attrs/ops/topk_attrs.dtg.h" +#include "op-attrs/ops/transpose_attrs.dtg.h" +#include "op-attrs/ops/weight_attrs.dtg.h" +#include "rapidcheck.h" +#include +#include +#include +#include +#include + +namespace FlexFlow { +struct ComputationGraphOpAttrs { + ComputationGraphOpAttrs() = delete; + explicit ComputationGraphOpAttrs(::FlexFlow::BatchMatmulAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::BatchNormAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::BroadcastAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::CastAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::ConcatAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::Conv2DAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::DropoutAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::ElementBinaryAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::ElementUnaryAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::ElementScalarUnaryAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::EmbeddingAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::FlatAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::GatherAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::InputAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::LayerNormAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::LinearAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::MultiHeadAttentionAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::NoopAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::Pool2DAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::ReduceAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::ReverseAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::ReshapeAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::SplitAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::SoftmaxAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::TopKAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::TransposeAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::WeightAttrs const &); + template + static constexpr bool IsPartOfComputationGraphOpAttrs_v = + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v; + template + ReturnType visit(Visitor &&v) const { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::BatchMatmulAttrs>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::BatchNormAttrs>()); + return result; + } + case 2: { + ReturnType result = v(this->get<::FlexFlow::BroadcastAttrs>()); + return result; + } + case 3: { + ReturnType result = v(this->get<::FlexFlow::CastAttrs>()); + return result; + } + case 4: { + ReturnType result = v(this->get<::FlexFlow::ConcatAttrs>()); + return result; + } + case 5: { + ReturnType result = v(this->get<::FlexFlow::Conv2DAttrs>()); + return result; + } + case 6: { + ReturnType result = v(this->get<::FlexFlow::DropoutAttrs>()); + return result; + } + case 7: { + ReturnType result = v(this->get<::FlexFlow::ElementBinaryAttrs>()); + return result; + } + case 8: { + ReturnType result = v(this->get<::FlexFlow::ElementUnaryAttrs>()); + return result; + } + case 9: { + ReturnType result = v(this->get<::FlexFlow::ElementScalarUnaryAttrs>()); + return result; + } + case 10: { + ReturnType result = v(this->get<::FlexFlow::EmbeddingAttrs>()); + return result; + } + case 11: { + ReturnType result = v(this->get<::FlexFlow::FlatAttrs>()); + return result; + } + case 12: { + ReturnType result = v(this->get<::FlexFlow::GatherAttrs>()); + return result; + } + case 13: { + ReturnType result = v(this->get<::FlexFlow::InputAttrs>()); + return result; + } + case 14: { + ReturnType result = v(this->get<::FlexFlow::LayerNormAttrs>()); + return result; + } + case 15: { + ReturnType result = v(this->get<::FlexFlow::LinearAttrs>()); + return result; + } + case 16: { + ReturnType result = v(this->get<::FlexFlow::MultiHeadAttentionAttrs>()); + return result; + } + case 17: { + ReturnType result = v(this->get<::FlexFlow::NoopAttrs>()); + return result; + } + case 18: { + ReturnType result = v(this->get<::FlexFlow::Pool2DAttrs>()); + return result; + } + case 19: { + ReturnType result = v(this->get<::FlexFlow::ReduceAttrs>()); + return result; + } + case 20: { + ReturnType result = v(this->get<::FlexFlow::ReverseAttrs>()); + return result; + } + case 21: { + ReturnType result = v(this->get<::FlexFlow::ReshapeAttrs>()); + return result; + } + case 22: { + ReturnType result = v(this->get<::FlexFlow::SplitAttrs>()); + return result; + } + case 23: { + ReturnType result = v(this->get<::FlexFlow::SoftmaxAttrs>()); + return result; + } + case 24: { + ReturnType result = v(this->get<::FlexFlow::TopKAttrs>()); + return result; + } + case 25: { + ReturnType result = v(this->get<::FlexFlow::TransposeAttrs>()); + return result; + } + case 26: { + ReturnType result = v(this->get<::FlexFlow::WeightAttrs>()); + return result; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type ComputationGraphOpAttrs", + this->index())); + } + } + } + template + ReturnType visit(Visitor &&v) { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::BatchMatmulAttrs>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::BatchNormAttrs>()); + return result; + } + case 2: { + ReturnType result = v(this->get<::FlexFlow::BroadcastAttrs>()); + return result; + } + case 3: { + ReturnType result = v(this->get<::FlexFlow::CastAttrs>()); + return result; + } + case 4: { + ReturnType result = v(this->get<::FlexFlow::ConcatAttrs>()); + return result; + } + case 5: { + ReturnType result = v(this->get<::FlexFlow::Conv2DAttrs>()); + return result; + } + case 6: { + ReturnType result = v(this->get<::FlexFlow::DropoutAttrs>()); + return result; + } + case 7: { + ReturnType result = v(this->get<::FlexFlow::ElementBinaryAttrs>()); + return result; + } + case 8: { + ReturnType result = v(this->get<::FlexFlow::ElementUnaryAttrs>()); + return result; + } + case 9: { + ReturnType result = v(this->get<::FlexFlow::ElementScalarUnaryAttrs>()); + return result; + } + case 10: { + ReturnType result = v(this->get<::FlexFlow::EmbeddingAttrs>()); + return result; + } + case 11: { + ReturnType result = v(this->get<::FlexFlow::FlatAttrs>()); + return result; + } + case 12: { + ReturnType result = v(this->get<::FlexFlow::GatherAttrs>()); + return result; + } + case 13: { + ReturnType result = v(this->get<::FlexFlow::InputAttrs>()); + return result; + } + case 14: { + ReturnType result = v(this->get<::FlexFlow::LayerNormAttrs>()); + return result; + } + case 15: { + ReturnType result = v(this->get<::FlexFlow::LinearAttrs>()); + return result; + } + case 16: { + ReturnType result = v(this->get<::FlexFlow::MultiHeadAttentionAttrs>()); + return result; + } + case 17: { + ReturnType result = v(this->get<::FlexFlow::NoopAttrs>()); + return result; + } + case 18: { + ReturnType result = v(this->get<::FlexFlow::Pool2DAttrs>()); + return result; + } + case 19: { + ReturnType result = v(this->get<::FlexFlow::ReduceAttrs>()); + return result; + } + case 20: { + ReturnType result = v(this->get<::FlexFlow::ReverseAttrs>()); + return result; + } + case 21: { + ReturnType result = v(this->get<::FlexFlow::ReshapeAttrs>()); + return result; + } + case 22: { + ReturnType result = v(this->get<::FlexFlow::SplitAttrs>()); + return result; + } + case 23: { + ReturnType result = v(this->get<::FlexFlow::SoftmaxAttrs>()); + return result; + } + case 24: { + ReturnType result = v(this->get<::FlexFlow::TopKAttrs>()); + return result; + } + case 25: { + ReturnType result = v(this->get<::FlexFlow::TransposeAttrs>()); + return result; + } + case 26: { + ReturnType result = v(this->get<::FlexFlow::WeightAttrs>()); + return result; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type ComputationGraphOpAttrs", + this->index())); + } + } + } + template + bool has() const { + static_assert( + IsPartOfComputationGraphOpAttrs_v, + "ComputationGraphOpAttrs::has() expected one of " + "[::FlexFlow::BatchMatmulAttrs, ::FlexFlow::BatchNormAttrs, " + "::FlexFlow::BroadcastAttrs, ::FlexFlow::CastAttrs, " + "::FlexFlow::ConcatAttrs, ::FlexFlow::Conv2DAttrs, " + "::FlexFlow::DropoutAttrs, ::FlexFlow::ElementBinaryAttrs, " + "::FlexFlow::ElementUnaryAttrs, ::FlexFlow::ElementScalarUnaryAttrs, " + "::FlexFlow::EmbeddingAttrs, ::FlexFlow::FlatAttrs, " + "::FlexFlow::GatherAttrs, ::FlexFlow::InputAttrs, " + "::FlexFlow::LayerNormAttrs, ::FlexFlow::LinearAttrs, " + "::FlexFlow::MultiHeadAttentionAttrs, ::FlexFlow::NoopAttrs, " + "::FlexFlow::Pool2DAttrs, ::FlexFlow::ReduceAttrs, " + "::FlexFlow::ReverseAttrs, ::FlexFlow::ReshapeAttrs, " + "::FlexFlow::SplitAttrs, ::FlexFlow::SoftmaxAttrs, " + "::FlexFlow::TopKAttrs, ::FlexFlow::TransposeAttrs, " + "::FlexFlow::WeightAttrs], received T"); + return std::holds_alternative(this->raw_variant); + } + template + T const &get() const { + static_assert( + IsPartOfComputationGraphOpAttrs_v, + "ComputationGraphOpAttrs::get() expected one of " + "[::FlexFlow::BatchMatmulAttrs, ::FlexFlow::BatchNormAttrs, " + "::FlexFlow::BroadcastAttrs, ::FlexFlow::CastAttrs, " + "::FlexFlow::ConcatAttrs, ::FlexFlow::Conv2DAttrs, " + "::FlexFlow::DropoutAttrs, ::FlexFlow::ElementBinaryAttrs, " + "::FlexFlow::ElementUnaryAttrs, ::FlexFlow::ElementScalarUnaryAttrs, " + "::FlexFlow::EmbeddingAttrs, ::FlexFlow::FlatAttrs, " + "::FlexFlow::GatherAttrs, ::FlexFlow::InputAttrs, " + "::FlexFlow::LayerNormAttrs, ::FlexFlow::LinearAttrs, " + "::FlexFlow::MultiHeadAttentionAttrs, ::FlexFlow::NoopAttrs, " + "::FlexFlow::Pool2DAttrs, ::FlexFlow::ReduceAttrs, " + "::FlexFlow::ReverseAttrs, ::FlexFlow::ReshapeAttrs, " + "::FlexFlow::SplitAttrs, ::FlexFlow::SoftmaxAttrs, " + "::FlexFlow::TopKAttrs, ::FlexFlow::TransposeAttrs, " + "::FlexFlow::WeightAttrs], received T"); + return std::get(this->raw_variant); + } + template + T &get() { + static_assert( + IsPartOfComputationGraphOpAttrs_v, + "ComputationGraphOpAttrs::get() expected one of " + "[::FlexFlow::BatchMatmulAttrs, ::FlexFlow::BatchNormAttrs, " + "::FlexFlow::BroadcastAttrs, ::FlexFlow::CastAttrs, " + "::FlexFlow::ConcatAttrs, ::FlexFlow::Conv2DAttrs, " + "::FlexFlow::DropoutAttrs, ::FlexFlow::ElementBinaryAttrs, " + "::FlexFlow::ElementUnaryAttrs, ::FlexFlow::ElementScalarUnaryAttrs, " + "::FlexFlow::EmbeddingAttrs, ::FlexFlow::FlatAttrs, " + "::FlexFlow::GatherAttrs, ::FlexFlow::InputAttrs, " + "::FlexFlow::LayerNormAttrs, ::FlexFlow::LinearAttrs, " + "::FlexFlow::MultiHeadAttentionAttrs, ::FlexFlow::NoopAttrs, " + "::FlexFlow::Pool2DAttrs, ::FlexFlow::ReduceAttrs, " + "::FlexFlow::ReverseAttrs, ::FlexFlow::ReshapeAttrs, " + "::FlexFlow::SplitAttrs, ::FlexFlow::SoftmaxAttrs, " + "::FlexFlow::TopKAttrs, ::FlexFlow::TransposeAttrs, " + "::FlexFlow::WeightAttrs], received T"); + return std::get(this->raw_variant); + } + size_t index() const { + return this->raw_variant.index(); + } + bool operator==(ComputationGraphOpAttrs const &) const; + bool operator!=(ComputationGraphOpAttrs const &) const; + bool operator<(ComputationGraphOpAttrs const &) const; + bool operator>(ComputationGraphOpAttrs const &) const; + bool operator<=(ComputationGraphOpAttrs const &) const; + bool operator>=(ComputationGraphOpAttrs const &) const; + std::variant<::FlexFlow::BatchMatmulAttrs, + ::FlexFlow::BatchNormAttrs, + ::FlexFlow::BroadcastAttrs, + ::FlexFlow::CastAttrs, + ::FlexFlow::ConcatAttrs, + ::FlexFlow::Conv2DAttrs, + ::FlexFlow::DropoutAttrs, + ::FlexFlow::ElementBinaryAttrs, + ::FlexFlow::ElementUnaryAttrs, + ::FlexFlow::ElementScalarUnaryAttrs, + ::FlexFlow::EmbeddingAttrs, + ::FlexFlow::FlatAttrs, + ::FlexFlow::GatherAttrs, + ::FlexFlow::InputAttrs, + ::FlexFlow::LayerNormAttrs, + ::FlexFlow::LinearAttrs, + ::FlexFlow::MultiHeadAttentionAttrs, + ::FlexFlow::NoopAttrs, + ::FlexFlow::Pool2DAttrs, + ::FlexFlow::ReduceAttrs, + ::FlexFlow::ReverseAttrs, + ::FlexFlow::ReshapeAttrs, + ::FlexFlow::SplitAttrs, + ::FlexFlow::SoftmaxAttrs, + ::FlexFlow::TopKAttrs, + ::FlexFlow::TransposeAttrs, + ::FlexFlow::WeightAttrs> + raw_variant; +}; +} // namespace FlexFlow +namespace std { +template <> +struct hash<::FlexFlow::ComputationGraphOpAttrs> { + size_t operator()(::FlexFlow::ComputationGraphOpAttrs const &) const; +}; +} // namespace std +namespace nlohmann { +template <> +struct adl_serializer<::FlexFlow::ComputationGraphOpAttrs> { + static ::FlexFlow::ComputationGraphOpAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::ComputationGraphOpAttrs const &); +}; +} // namespace nlohmann +namespace rc { +template <> +struct Arbitrary<::FlexFlow::ComputationGraphOpAttrs> { + static Gen<::FlexFlow::ComputationGraphOpAttrs> arbitrary(); +}; +} // namespace rc +namespace FlexFlow { +std::string format_as(::FlexFlow::ComputationGraphOpAttrs const &); +std::ostream &operator<<(std::ostream &, + ::FlexFlow::ComputationGraphOpAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_COMPUTATION_GRAPH_OP_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h new file mode 100644 index 0000000000..4be17798f7 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_COMPUTATION_GRAPH_OP_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_COMPUTATION_GRAPH_OP_ATTRS_H + +#include "op-attrs/computation_graph_op_attrs.dtg.h" + +namespace FlexFlow { + +OperatorType get_op_type(ComputationGraphOpAttrs const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.variant.toml b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.variant.toml new file mode 100644 index 0000000000..bb25514e1d --- /dev/null +++ b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.variant.toml @@ -0,0 +1,148 @@ +namespace = "FlexFlow" +name = "ComputationGraphOpAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ops/attention_attrs.dtg.h", + "op-attrs/ops/batch_matmul.dtg.h", + "op-attrs/ops/batch_norm_attrs.dtg.h", + "op-attrs/ops/broadcast.dtg.h", + "op-attrs/ops/cast_attrs.dtg.h", + "op-attrs/ops/concat_attrs.dtg.h", + "op-attrs/ops/conv_2d_attrs.dtg.h", + "op-attrs/ops/dropout_attrs.dtg.h", + "op-attrs/ops/element_binary_attrs.dtg.h", + "op-attrs/ops/element_scalar_unary_attrs.dtg.h", + "op-attrs/ops/element_unary_attrs.dtg.h", + "op-attrs/ops/embedding_attrs.dtg.h", + "op-attrs/ops/flat_attrs.dtg.h", + "op-attrs/ops/gather_attrs.dtg.h", + "op-attrs/ops/input_attrs.dtg.h", + "op-attrs/ops/layer_norm_attrs.dtg.h", + "op-attrs/ops/linear_attrs.dtg.h", + "op-attrs/ops/noop_attrs.dtg.h", + "op-attrs/ops/pool_2d_attrs.dtg.h", + "op-attrs/ops/reduce_attrs.dtg.h", + "op-attrs/ops/reshape_attrs.dtg.h", + "op-attrs/ops/reverse_attrs.dtg.h", + "op-attrs/ops/softmax_attrs.dtg.h", + "op-attrs/ops/split_attrs.dtg.h", + "op-attrs/ops/topk_attrs.dtg.h", + "op-attrs/ops/transpose_attrs.dtg.h", + "op-attrs/ops/weight_attrs.dtg.h", +] + +[[values]] +type = "::FlexFlow::BatchMatmulAttrs" +key = "batch_matmul" + +[[values]] +type = "::FlexFlow::BatchNormAttrs" +key = "batch_norm" + +[[values]] +type = "::FlexFlow::BroadcastAttrs" +key = "broadcast" + +[[values]] +type = "::FlexFlow::CastAttrs" +key = "cast" + +[[values]] +type = "::FlexFlow::ConcatAttrs" +key = "concat" + +[[values]] +type = "::FlexFlow::Conv2DAttrs" +key = "conv2d" + +[[values]] +type = "::FlexFlow::DropoutAttrs" +key = "dropout" + +[[values]] +type = "::FlexFlow::ElementBinaryAttrs" +key = "element_binary" + +[[values]] +type = "::FlexFlow::ElementUnaryAttrs" +key = "element_unary" + +[[values]] +type = "::FlexFlow::ElementScalarUnaryAttrs" +key = "element_scalar_unary" + +[[values]] +type = "::FlexFlow::EmbeddingAttrs" +key = "embedding" + +[[values]] +type = "::FlexFlow::FlatAttrs" +key = "flat" + +[[values]] +type = "::FlexFlow::GatherAttrs" +key = "gather" + +[[values]] +type = "::FlexFlow::InputAttrs" +key = "input" + +[[values]] +type = "::FlexFlow::LayerNormAttrs" +key = "layer_norm" + +[[values]] +type = "::FlexFlow::LinearAttrs" +key = "linear" + +[[values]] +type = "::FlexFlow::MultiHeadAttentionAttrs" +key = "multi_head_attention" + +[[values]] +type = "::FlexFlow::NoopAttrs" +key = "noop" + +[[values]] +type = "::FlexFlow::Pool2DAttrs" +key = "pool2d" + +[[values]] +type = "::FlexFlow::ReduceAttrs" +key = "reduce" + +[[values]] +type = "::FlexFlow::ReverseAttrs" +key = "reverse" + +[[values]] +type = "::FlexFlow::ReshapeAttrs" +key = "reshape" + +[[values]] +type = "::FlexFlow::SplitAttrs" +key = "split" + +[[values]] +type = "::FlexFlow::SoftmaxAttrs" +key = "softmax" + +[[values]] +type = "::FlexFlow::TopKAttrs" +key = "topk" + +[[values]] +type = "::FlexFlow::TransposeAttrs" +key = "transpose" + +[[values]] +type = "::FlexFlow::WeightAttrs" +key = "weight" diff --git a/lib/op-attrs/include/op-attrs/datatype.dtg.h b/lib/op-attrs/include/op-attrs/datatype.dtg.h new file mode 100644 index 0000000000..7052dba3b3 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/datatype.dtg.h @@ -0,0 +1,40 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/datatype.enum.toml +/* proj-data +{ + "generated_from": "8315d0aa0a65b00c13aa580e923592ef" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DATATYPE_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DATATYPE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +enum class DataType { BOOL, INT32, INT64, HALF, FLOAT, DOUBLE }; +std::string format_as(DataType); +std::ostream &operator<<(std::ostream &, DataType); +void to_json(::nlohmann::json &, DataType); +void from_json(::nlohmann::json const &, DataType &); +} // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::DataType) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DATATYPE_DTG_H diff --git a/lib/op-attrs/include/op-attrs/datatype.enum.toml b/lib/op-attrs/include/op-attrs/datatype.enum.toml new file mode 100644 index 0000000000..15210cfe29 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/datatype.enum.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "DataType" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "BOOL" + +[[values]] +name = "INT32" + +[[values]] +name = "INT64" + +[[values]] +name = "HALF" + +[[values]] +name = "FLOAT" + +[[values]] +name = "DOUBLE" diff --git a/lib/op-attrs/include/op-attrs/datatype.h b/lib/op-attrs/include/op-attrs/datatype.h index 643fe44c41..a435c1bc12 100644 --- a/lib/op-attrs/include/op-attrs/datatype.h +++ b/lib/op-attrs/include/op-attrs/datatype.h @@ -1,14 +1,13 @@ #ifndef _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_DATATYPE_H #define _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_DATATYPE_H +#include "op-attrs/datatype.dtg.h" #include "utils/fmt.h" #include "utils/fp16.h" #include namespace FlexFlow { -enum class DataType { BOOL, INT32, INT64, HALF, FLOAT, DOUBLE }; - template struct data_type_enum_to_class; @@ -54,46 +53,11 @@ using DataTypeValue = std::variant, real_type, real_type, real_type, - real_type, + /* real_type, */ real_type>; size_t size_of_datatype(DataType); } // namespace FlexFlow -namespace fmt { -template <> -struct formatter<::FlexFlow::DataType> : formatter { - template - auto format(::FlexFlow::DataType dt, FormatContext &ctx) - -> decltype(ctx.out()) { - using namespace FlexFlow; - - string_view name = "unknown"; - switch (dt) { - case DataType::BOOL: - name = "BOOL"; - break; - case DataType::INT32: - name = "INT32"; - break; - case DataType::INT64: - name = "INT64"; - break; - case DataType::HALF: - name = "HALF"; - break; - case DataType::FLOAT: - name = "FLOAT"; - break; - case DataType::DOUBLE: - name = "DOUBLE"; - break; - } - return formatter::format(name, ctx); - } -}; - -} // namespace fmt - #endif diff --git a/lib/op-attrs/include/op-attrs/dim_ordered.h b/lib/op-attrs/include/op-attrs/dim_ordered.h index b726d0687f..dbc237a03d 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered.h @@ -1,7 +1,8 @@ #ifndef _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_FF_STACK_VECTOR_H #define _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_FF_STACK_VECTOR_H -#include "op-attrs/ff_dim.h" +#include "op-attrs/ff_dim.dtg.h" +#include "utils/json.h" #include "utils/stack_vector.h" namespace FlexFlow { @@ -28,11 +29,19 @@ struct DimOrdered { : contents(contents.begin(), contents.end()) {} T const &at(Idx idx) const { - return this->contents.at(idx.value()); + int raw = idx.value; + if (raw < 0) { + raw = this->contents.size() + raw; + } + return this->contents.at(raw); } T &at(Idx idx) { - return this->contents.at(idx.value()); + int raw = idx.value; + if (raw < 0) { + raw = this->contents.size() + raw; + } + return this->contents.at(raw); } T const &operator[](Idx idx) const { @@ -43,6 +52,14 @@ struct DimOrdered { return this->at(idx); } + bool idx_is_valid(Idx const &idx) const { + int raw = idx.value; + if (raw < 0) { + raw = this->contents.size() + raw; + } + return (raw >= 0 && raw < this->contents.size()); + } + bool operator==(DimOrdered const &other) const { return this->contents == other.contents; } @@ -133,6 +150,17 @@ struct DimOrdered { template using FFOrdered = DimOrdered; +template +std::string format_as(FFOrdered const &v) { + std::vector as_vec(v.cbegin(), v.cend()); + return fmt::format("", as_vec); +} + +template +std::ostream &operator<<(std::ostream &s, FFOrdered const &v) { + return (s << fmt::to_string(v)); +} + template auto inner_to_outer(FFOrdered const &ff_ordered) -> decltype(reversed_container(ff_ordered)) { @@ -160,6 +188,29 @@ FFOrdered const &outer_to_inner(FFOrdered const &ff_ordered) { } // namespace FlexFlow +/* template */ +/* void to_json(json &j, DimOrdered const &x) { */ +/* /1* j = std::vector{x.cbegin(), x.cend()}; *1/ */ +/* } */ + +/* template */ +/* void from_json(json const &j, DimOrdered &x) { */ +/* /1* x = DimOrdered{j.template get>()}; *1/ */ +/* } */ + +namespace nlohmann { +template +struct adl_serializer<::FlexFlow::DimOrdered> { + static ::FlexFlow::DimOrdered from_json(json const &j) { + return {j.template get>()}; + } + + static void to_json(json &j, ::FlexFlow::DimOrdered const &x) { + j = std::vector{x.cbegin(), x.cend()}; + } +}; +} // namespace nlohmann + namespace std { template @@ -174,4 +225,16 @@ struct hash<::FlexFlow::DimOrdered> { } // namespace std +namespace rc { + +template +struct Arbitrary<::FlexFlow::DimOrdered> { + static Gen<::FlexFlow::DimOrdered> arbitrary() { + return gen::construct<::FlexFlow::DimOrdered>( + gen::arbitrary<::FlexFlow::stack_vector>()); + } +}; + +} // namespace rc + #endif diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/slice.h b/lib/op-attrs/include/op-attrs/dim_ordered/slice.h new file mode 100644 index 0000000000..4d6e82b71b --- /dev/null +++ b/lib/op-attrs/include/op-attrs/dim_ordered/slice.h @@ -0,0 +1,54 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_SLICE_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_SLICE_H + +#include "op-attrs/dim_ordered.h" +#include "utils/containers.h" +#include "utils/optional.h" + +namespace FlexFlow { + +template +DimOrdered nonoverloaded_slice(DimOrdered const &d, + std::optional const &start, + std::optional const &end) { + auto to_raw_idx = [](std::optional const &idx) -> std::optional { + return transform(idx, [](Idx const &i) { return i.value; }); + }; + + return DimOrdered{ + subvec(as_vector(d), to_raw_idx(start), to_raw_idx(end))}; +} + +template +DimOrdered slice(DimOrdered const &d, + std::optional const &start, + std::optional const &end) { + return nonoverloaded_slice(d, start, end); +} + +template +DimOrdered slice(DimOrdered const &d, + std::nullopt_t const &start, + Idx const &end) { + return nonoverloaded_slice( + d, std::optional{start}, std::optional{end}); +} + +template +DimOrdered slice(DimOrdered const &d, + Idx const &start, + std::nullopt_t const &end) { + return nonoverloaded_slice( + d, std::optional{start}, std::optional{end}); +} + +template +DimOrdered + slice(DimOrdered const &d, Idx const &start, Idx const &end) { + return nonoverloaded_slice( + d, std::optional{start}, std::optional{end}); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/transform.h b/lib/op-attrs/include/op-attrs/dim_ordered/transform.h new file mode 100644 index 0000000000..880f13b4d4 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/dim_ordered/transform.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_TRANSFORM_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_TRANSFORM_H + +#include "op-attrs/dim_ordered.h" +#include "utils/containers.h" +#include "utils/containers/vector_transform.h" + +namespace FlexFlow { + +template +DimOrdered> + transform(DimOrdered const &d, F f) { + using Out = std::invoke_result_t; + + return DimOrdered{vector_transform(as_vector(d), f)}; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/ff_dim.dtg.h b/lib/op-attrs/include/op-attrs/ff_dim.dtg.h new file mode 100644 index 0000000000..1697f78196 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ff_dim.dtg.h @@ -0,0 +1,54 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ff_dim.struct.toml +/* proj-data +{ + "generated_from": "a5fa89a024e95c4f2d52681a74cab30f" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include +#include +#include + +namespace FlexFlow { +struct ff_dim_t { + ff_dim_t() = delete; + ff_dim_t(int const &value); + + bool operator==(ff_dim_t const &) const; + bool operator!=(ff_dim_t const &) const; + bool operator<(ff_dim_t const &) const; + bool operator>(ff_dim_t const &) const; + bool operator<=(ff_dim_t const &) const; + bool operator>=(ff_dim_t const &) const; + int value; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ff_dim_t const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ff_dim_t from_json(json const &); + static void to_json(json &, FlexFlow::ff_dim_t const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ff_dim_t const &); +std::ostream &operator<<(std::ostream &, ff_dim_t const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ff_dim.h b/lib/op-attrs/include/op-attrs/ff_dim.h index be1f148a70..e78ce4b51e 100644 --- a/lib/op-attrs/include/op-attrs/ff_dim.h +++ b/lib/op-attrs/include/op-attrs/ff_dim.h @@ -1,18 +1,18 @@ -#ifndef _FLEXFLOW_OPATTRS_INCLUDE_FF_DIM_H -#define _FLEXFLOW_OPATTRS_INCLUDE_FF_DIM_H -#include "utils/strong_typedef.h" -#include +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_H -namespace FlexFlow { +#include "op-attrs/ff_dim.dtg.h" +#include "rapidcheck.h" -struct ff_dim_t : public numerical_typedef { - using numerical_typedef::numerical_typedef; +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary() { + return gen::construct( + gen::inRange(0, MAX_TENSOR_DIM)); + } }; +} // namespace rc -} // namespace FlexFlow - -MAKE_TYPEDEF_HASHABLE(::FlexFlow::ff_dim_t); -MAKE_TYPEDEF_PRINTABLE(::FlexFlow::ff_dim_t, "ff_dim"); - -#endif +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_H diff --git a/lib/op-attrs/include/op-attrs/ff_dim.struct.toml b/lib/op-attrs/include/op-attrs/ff_dim.struct.toml new file mode 100644 index 0000000000..441f9826ca --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ff_dim.struct.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "ff_dim_t" + +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +[[fields]] +name = "value" +type = "int" diff --git a/lib/op-attrs/include/op-attrs/get_op_type.h b/lib/op-attrs/include/op-attrs/get_op_type.h index a2db4ab5f0..39541aa8d6 100644 --- a/lib/op-attrs/include/op-attrs/get_op_type.h +++ b/lib/op-attrs/include/op-attrs/get_op_type.h @@ -1,8 +1,37 @@ #ifndef _FLEXFLOW_OP_ATTRS_GET_OP_TYPE_H #define _FLEXFLOW_OP_ATTRS_GET_OP_TYPE_H -#include "operator_attrs.h" -#include "utils/variant.h" +#include "op-attrs/ops/attention_attrs.dtg.h" +#include "op-attrs/ops/batch_matmul.dtg.h" +#include "op-attrs/ops/batch_norm_attrs.dtg.h" +#include "op-attrs/ops/broadcast.dtg.h" +#include "op-attrs/ops/cast_attrs.dtg.h" +#include "op-attrs/ops/combine_attrs.dtg.h" +#include "op-attrs/ops/concat_attrs.dtg.h" +#include "op-attrs/ops/conv_2d_attrs.dtg.h" +#include "op-attrs/ops/dropout_attrs.dtg.h" +#include "op-attrs/ops/element_binary_attrs.dtg.h" +#include "op-attrs/ops/element_scalar_unary_attrs.dtg.h" +#include "op-attrs/ops/element_unary_attrs.dtg.h" +#include "op-attrs/ops/embedding_attrs.dtg.h" +#include "op-attrs/ops/flat_attrs.dtg.h" +#include "op-attrs/ops/gather_attrs.dtg.h" +#include "op-attrs/ops/input_attrs.dtg.h" +#include "op-attrs/ops/layer_norm_attrs.dtg.h" +#include "op-attrs/ops/linear_attrs.dtg.h" +#include "op-attrs/ops/noop_attrs.dtg.h" +#include "op-attrs/ops/pool_2d_attrs.dtg.h" +#include "op-attrs/ops/reduce_attrs.dtg.h" +#include "op-attrs/ops/reduction_attrs.dtg.h" +#include "op-attrs/ops/repartition_attrs.dtg.h" +#include "op-attrs/ops/replicate_attrs.dtg.h" +#include "op-attrs/ops/reshape_attrs.dtg.h" +#include "op-attrs/ops/reverse_attrs.dtg.h" +#include "op-attrs/ops/softmax_attrs.dtg.h" +#include "op-attrs/ops/split_attrs.dtg.h" +#include "op-attrs/ops/topk_attrs.dtg.h" +#include "op-attrs/ops/transpose_attrs.dtg.h" +#include "op-attrs/ops/weight_attrs.dtg.h" namespace FlexFlow { @@ -32,23 +61,13 @@ OperatorType get_op_type(SplitAttrs const &); OperatorType get_op_type(SoftmaxAttrs const &); OperatorType get_op_type(TopKAttrs const &); OperatorType get_op_type(TransposeAttrs const &); +OperatorType get_op_type(WeightAttrs const &); + OperatorType get_op_type(CombineAttrs const &); OperatorType get_op_type(ReductionAttrs const &); OperatorType get_op_type(RepartitionAttrs const &); OperatorType get_op_type(ReplicateAttrs const &); -struct GetOpTypeFunctor { - template - OperatorType operator()(T const &t) { - return get_op_type(t); - } -}; - -template -OperatorType get_op_type(std::variant const &attrs) { - return visit(GetOpTypeFunctor{}, attrs); -} - } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/get_output_shapes.h b/lib/op-attrs/include/op-attrs/get_output_shapes.h index 496cfbb755..a826e1cb54 100644 --- a/lib/op-attrs/include/op-attrs/get_output_shapes.h +++ b/lib/op-attrs/include/op-attrs/get_output_shapes.h @@ -112,26 +112,14 @@ std::vector get_output_shapes(Attrs const &attrs, ParallelTensorShape get_output_shape(MultiHeadAttentionAttrs const &, std::vector const &); -ParallelTensorShape get_output_shape(BatchMatmulAttrs const &, - ParallelTensorShape const &, - ParallelTensorShape const &); ParallelTensorShape get_output_shape(CastAttrs const &, ParallelTensorShape const &); -ParallelTensorShape get_output_shape(CombineAttrs const &, - ParallelTensorShape const &); ParallelTensorShape get_output_shape(ConcatAttrs const &, std::vector const &); ParallelTensorShape get_output_shape(Conv2DAttrs const &, ParallelTensorShape const &); ParallelTensorShape get_output_shape(DropoutAttrs const &, ParallelTensorShape const &); -ParallelTensorShape get_output_shape(ElementBinaryAttrs const &, - ParallelTensorShape const &, - ParallelTensorShape const &); -ParallelTensorShape get_output_shape(ElementUnaryUnifiedAttrs const &, - ParallelTensorShape const &); -ParallelTensorShape get_output_shape(EmbeddingAttrs const &, - ParallelTensorShape const &); ParallelTensorShape get_output_shape(FlatAttrs const &, ParallelTensorShape const &); std::vector get_output_shapes(GatherAttrs const &, @@ -139,20 +127,12 @@ std::vector get_output_shapes(GatherAttrs const &, ParallelTensorShape const &); ParallelTensorShape get_output_shape(LayerNormAttrs const &, ParallelTensorShape const &); -ParallelTensorShape get_output_shape(LinearAttrs const &, - ParallelTensorShape const &); ParallelTensorShape get_output_shape(Pool2DAttrs const &, ParallelTensorShape const &); ParallelTensorShape get_output_shape(ReduceAttrs const &, ParallelTensorShape const &); -ParallelTensorShape get_output_shape(ReductionAttrs const &, - ParallelTensorShape const &); -ParallelTensorShape get_output_shape(RepartitionAttrs const &, - ParallelTensorShape const &); ParallelTensorShape get_output_shape(ReplicateAttrs const &, ParallelTensorShape const &); -ParallelTensorShape get_output_shape(ReshapeAttrs const &, - ParallelTensorShape const &); ParallelTensorShape get_output_shape(ReverseAttrs const &, ParallelTensorShape const &); std::vector get_output_shapes(SplitAttrs const &, diff --git a/lib/op-attrs/include/op-attrs/l1_regularizer_attrs.dtg.h b/lib/op-attrs/include/op-attrs/l1_regularizer_attrs.dtg.h new file mode 100644 index 0000000000..1d4747db7e --- /dev/null +++ b/lib/op-attrs/include/op-attrs/l1_regularizer_attrs.dtg.h @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/l1_regularizer_attrs.struct.toml +/* proj-data +{ + "generated_from": "50968fb8a3d43395d0eab7594f4935c0" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_L1_REGULARIZER_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_L1_REGULARIZER_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct L1RegularizerAttrs { + L1RegularizerAttrs() = delete; + L1RegularizerAttrs(float const &lambda); + + bool operator==(L1RegularizerAttrs const &) const; + bool operator!=(L1RegularizerAttrs const &) const; + bool operator<(L1RegularizerAttrs const &) const; + bool operator>(L1RegularizerAttrs const &) const; + bool operator<=(L1RegularizerAttrs const &) const; + bool operator>=(L1RegularizerAttrs const &) const; + float lambda; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::L1RegularizerAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::L1RegularizerAttrs from_json(json const &); + static void to_json(json &, FlexFlow::L1RegularizerAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(L1RegularizerAttrs const &); +std::ostream &operator<<(std::ostream &, L1RegularizerAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_L1_REGULARIZER_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/l1_regularizer_attrs.struct.toml b/lib/op-attrs/include/op-attrs/l1_regularizer_attrs.struct.toml new file mode 100644 index 0000000000..60fabfb94a --- /dev/null +++ b/lib/op-attrs/include/op-attrs/l1_regularizer_attrs.struct.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "L1RegularizerAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "lambda" +type = "float" diff --git a/lib/op-attrs/include/op-attrs/l2_regularizer_attrs.dtg.h b/lib/op-attrs/include/op-attrs/l2_regularizer_attrs.dtg.h new file mode 100644 index 0000000000..981d3f4905 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/l2_regularizer_attrs.dtg.h @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/l2_regularizer_attrs.struct.toml +/* proj-data +{ + "generated_from": "c4f182e547ab6f0d5613e7eeb95d438e" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_L2_REGULARIZER_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_L2_REGULARIZER_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct L2RegularizerAttrs { + L2RegularizerAttrs() = delete; + L2RegularizerAttrs(float const &lambda); + + bool operator==(L2RegularizerAttrs const &) const; + bool operator!=(L2RegularizerAttrs const &) const; + bool operator<(L2RegularizerAttrs const &) const; + bool operator>(L2RegularizerAttrs const &) const; + bool operator<=(L2RegularizerAttrs const &) const; + bool operator>=(L2RegularizerAttrs const &) const; + float lambda; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::L2RegularizerAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::L2RegularizerAttrs from_json(json const &); + static void to_json(json &, FlexFlow::L2RegularizerAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(L2RegularizerAttrs const &); +std::ostream &operator<<(std::ostream &, L2RegularizerAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_L2_REGULARIZER_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/l2_regularizer_attrs.struct.toml b/lib/op-attrs/include/op-attrs/l2_regularizer_attrs.struct.toml new file mode 100644 index 0000000000..adce4397a4 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/l2_regularizer_attrs.struct.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "L2RegularizerAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "lambda" +type = "float" diff --git a/lib/op-attrs/include/op-attrs/op.h b/lib/op-attrs/include/op-attrs/op.h deleted file mode 100644 index 9ad83c3641..0000000000 --- a/lib/op-attrs/include/op-attrs/op.h +++ /dev/null @@ -1,369 +0,0 @@ -#ifndef _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_OP_H -#define _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_OP_H - -#include "utils/fmt.h" - -namespace FlexFlow { - -enum class Op { - NOOP, - INPUT, - WEIGHT, - CONV2D, - DROPOUT, - LINEAR, - BATCHMATMUL, - POOL2D, - SCALAR_MULTIPLY, - SCALAR_ADD, - SCALAR_FLOOR_DIV, - SCALAR_TRUE_DIV, - SCALAR_SUB, - RELU, - IDENTITY, - SIGMOID, - TANH, - ELU, - FLAT, - SOFTMAX, - BATCHNORM, - CONCAT, - SPLIT, - EMBEDDING, - CACHE, - // OP_ELEMENTWISE, - RESHAPE, - REVERSE, - TRANSPOSE, - EW_ADD, - EW_MUL, - MATMUL, - MUL, - ENLARGE, - SQUEEZE, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Squeeze - UNSQUEEZE, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Unsqueeze - EW_SUB, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Sub - EW_DIV, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Div - EW_EQUAL, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Equal - EW_GREATER, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Greater - EW_LESS, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Less - EW_MAX, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Max - EW_MIN, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Min - REDUCE_ARGMAX, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#ArgMax - REDUCE_ARGMIN, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#ArgMin - REDUCE_MAX, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#ReduceMax - REDUCE_MEAN, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#ReduceMean - REDUCE_MIN, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#ReduceMin - REDUCE_PROD, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#ReduceProd - REDUCE_SUM, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#ReduceSum - PAD, // https://github.com/dmlc/tvm/blob/master/topi/python/topi/nn/pad.py - SHAPE, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Shape - SIZE, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Size - TOPK, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#TopK - WHERE, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Where - CEIL, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Ceil - CAST, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Cast - EXP, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Exp - ROUND, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Round - LOG, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Log - LOGICAL_NOT, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Not - SQRT, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Sqrt - SIN, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Sin - COS, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Cos - LEAKYRELU, - SLICE, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Slice - RESIZE, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Resize - PRELU, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#PRelu - GELU, - MULTIHEAD_ATTENTION, - FUSED, // Fused operator type for internal fusion optimizations - RSQRT, // https://pytorch.org/docs/stable/generated/torch.rsqrt.html - POW, // https://pytorch.org/docs/stable/generated/torch.pow.html - MEAN, // https://pytorch.org/docs/stable/generated/torch.mean.html - LAYERNORM, - GATHER, // https://pytorch.org/docs/stable/generated/torch.gather.html - BROADCAST, - // Parallel Ops - REPARTITION, - COMBINE, - REPLICATE, - REDUCTION, - BATCH, - PIPELINE, - FUSED_PARALLEL, -}; - -using OperatorType = Op; - -std::string get_operator_type_name(Op op); - -} // namespace FlexFlow - -namespace fmt { - -template <> -struct formatter<::FlexFlow::Op> : formatter { - template - auto format(::FlexFlow::Op ot, FormatContext &ctx) -> decltype(ctx.out()) { - using namespace FlexFlow; - - string_view name = "unknown"; - switch (ot) { - case Op::CONV2D: - name = "Conv2D"; - break; - case Op::DROPOUT: - name = "Dropout"; - break; - case Op::LINEAR: - name = "Dense"; - break; - case Op::BATCHMATMUL: - name = "BatchMatMul"; - break; - case Op::POOL2D: - name = "Pool2D"; - break; - case Op::SCALAR_MULTIPLY: - name = "ScalarMultiply"; - break; - case Op::SCALAR_ADD: - name = "ScalarAdd"; - break; - case Op::SCALAR_FLOOR_DIV: - name = "ScalarFloorDiv"; - break; - case Op::SCALAR_TRUE_DIV: - name = "ScalarTrueDiv"; - break; - case Op::SCALAR_SUB: - name = "ScalarSub"; - break; - case Op::RELU: - name = "ReLU"; - break; - case Op::SIGMOID: - name = "Sigmoid"; - break; - case Op::TANH: - name = "Tanh"; - break; - case Op::ELU: - name = "Elu"; - break; - case Op::FLAT: - name = "Flat"; - break; - case Op::SOFTMAX: - name = "Softmax"; - break; - case Op::BATCHNORM: - name = "BatchNorm"; - break; - case Op::CONCAT: - name = "Concat"; - break; - case Op::SPLIT: - name = "Split"; - break; - case Op::EMBEDDING: - name = "Embedding"; - break; - case Op::GATHER: - name = "Gather"; - break; - case Op::CACHE: - name = "Cache"; - break; - case Op::RESHAPE: - name = "Reshape"; - break; - case Op::REVERSE: - name = "Reverse"; - break; - case Op::TRANSPOSE: - name = "Transpose"; - break; - case Op::EW_ADD: - name = "Add"; - break; - case Op::EW_MUL: - name = "Mul"; - break; - case Op::MATMUL: - name = "Matmul"; - break; - case Op::MUL: - name = "Mul"; - break; - case Op::ENLARGE: - name = "Enlarge"; - break; - case Op::SQUEEZE: - name = "Squeeze"; - break; - case Op::UNSQUEEZE: - name = "Unsqueeze"; - break; - case Op::EW_SUB: - name = "Sub"; - break; - case Op::EW_DIV: - name = "Div"; - break; - case Op::EW_EQUAL: - name = "Equal"; - break; - case Op::EW_GREATER: - name = "Greater"; - break; - case Op::EW_LESS: - name = "Less"; - break; - case Op::EW_MAX: - name = "Max"; - break; - case Op::EW_MIN: - name = "Min"; - break; - case Op::REDUCE_ARGMAX: - name = "ReduceArgMax"; - break; - case Op::REDUCE_ARGMIN: - name = "ReduceArgMin"; - break; - case Op::REDUCE_MAX: - name = "ReduceMax"; - break; - case Op::REDUCE_MEAN: - name = "ReduceMean"; - break; - case Op::REDUCE_MIN: - name = "ReduceMin"; - break; - case Op::REDUCE_PROD: - name = "ReduceProd"; - break; - case Op::REDUCE_SUM: - name = "ReduceSum"; - break; - case Op::PAD: - name = "Pad"; - break; - case Op::SHAPE: - name = "Shape"; - break; - case Op::SIZE: - name = "Size"; - break; - case Op::TOPK: - name = "TopK"; - break; - case Op::WHERE: - name = "Where"; - break; - case Op::CEIL: - name = "Ceil"; - break; - case Op::CAST: - name = "Cast"; - break; - case Op::EXP: - name = "Exp"; - break; - case Op::SIN: - name = "Sin"; - break; - case Op::COS: - name = "Cos"; - break; - case Op::ROUND: - name = "Round"; - break; - case Op::LOG: - name = "Log"; - break; - case Op::LOGICAL_NOT: - name = "LogicalNot"; - break; - case Op::SQRT: - name = "Sqrt"; - break; - case Op::LEAKYRELU: - name = "LeakyReLU"; - break; - case Op::SLICE: - name = "Slice"; - break; - case Op::RESIZE: - name = "Resize"; - break; - case Op::PRELU: - name = "PReLU"; - break; - case Op::MULTIHEAD_ATTENTION: - name = "MultiHeadAttention"; - break; - case Op::INPUT: - name = "Input"; - break; - case Op::WEIGHT: - name = "Weight"; - break; - case Op::NOOP: - name = "NoOp"; - break; - case Op::FUSED: - name = "FusedOp"; - break; - case Op::RSQRT: - name = "Rsqrt"; - break; - case Op::POW: - name = "Pow"; - break; - case Op::MEAN: - name = "Mean"; - break; - case Op::LAYERNORM: - name = "LayerNorm"; - break; - case Op::IDENTITY: - name = "Identity"; - break; - // Parallel Ops - case Op::REPARTITION: - name = "Repartition"; - break; - case Op::COMBINE: - name = "Combine"; - break; - case Op::REPLICATE: - name = "Replicate"; - break; - case Op::REDUCTION: - name = "Reduction"; - break; - case Op::PIPELINE: - name = "Pipeline"; - break; - case Op::FUSED_PARALLEL: - name = "FusedParallelOp"; - break; - case Op::GELU: - name = "GeLU"; - break; - case Op::BROADCAST: - name = "Broadcast"; - break; - case Op::BATCH: - name = "Batch"; - break; - } - return formatter::format(name, ctx); - } -}; - -} // namespace fmt - -#endif diff --git a/lib/op-attrs/include/op-attrs/operator_attrs.h b/lib/op-attrs/include/op-attrs/operator_attrs.h index b63563cd67..268554b5be 100644 --- a/lib/op-attrs/include/op-attrs/operator_attrs.h +++ b/lib/op-attrs/include/op-attrs/operator_attrs.h @@ -2,6 +2,7 @@ #define _OPERATOR_PARAMS_H #include "op-attrs/ops/core.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" #include "ops/attention.h" #include "ops/batch_matmul.h" #include "ops/batch_norm.h" @@ -31,102 +32,16 @@ #include "ops/split.h" #include "ops/topk.h" #include "ops/transpose.h" +#include "utils/record_formatter.h" #include "utils/variant.h" #include namespace FlexFlow { -using SharedOperatorAttrs = std::variant; - -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); - -using ParallelOperatorAttrs = std:: - variant; - -using ComputationGraphAttrs = - variant_join>; -using CompGraphOperatorAttrs = ComputationGraphAttrs; - -using PCGOperatorAttrs = - variant_join; - -static_assert(is_equal_comparable::value, - "ComputationGraphAttrs must support =="); -static_assert(elements_satisfy::value, - ""); -static_assert(is_neq_comparable::value, - "ComputationGraphAttrs must support !="); -static_assert(is_lt_comparable::value, - "ComputationGraphAttrs must support <"); -static_assert(is_hashable::value, - "ComputationGraphAttrs must be hashable"); - -static_assert(is_equal_comparable::value, - "PCGOperatorAttrs must support =="); -static_assert(is_neq_comparable::value, - "PCGOperatorAttrs must support !="); -static_assert(is_lt_comparable::value, - "PCGOperatorAttrs must support <"); -static_assert(is_hashable::value, - "PCGOperatorAttrs must be hashable"); - -/* OperatorType get_op_type(CompGraphOperatorAttrs const &); */ -/* OperatorType get_op_type(PCGOperatorAttrs const &); */ - -RecordFormatter as_dot(CompGraphOperatorAttrs const &); -RecordFormatter as_dot(PCGOperatorAttrs const &); - std::vector get_output_shapes( PCGOperatorAttrs const &op_params, std::vector const &input_tensor_shapes); -bool is_parallel_op(PCGOperatorAttrs const &); bool is_valid(PCGOperatorAttrs const &, std::vector const &); diff --git a/lib/op-attrs/include/op-attrs/operator_type.dtg.h b/lib/op-attrs/include/op-attrs/operator_type.dtg.h new file mode 100644 index 0000000000..3b4bd86552 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/operator_type.dtg.h @@ -0,0 +1,124 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/operator_type.enum.toml +/* proj-data +{ + "generated_from": "c1c4687ef2fbc7dad996e5c25d47124c" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_TYPE_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_TYPE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +enum class OperatorType { + NOOP, + INPUT, + WEIGHT, + CONV2D, + DROPOUT, + LINEAR, + BATCHMATMUL, + POOL2D, + SCALAR_MULTIPLY, + SCALAR_ADD, + SCALAR_FLOOR_DIV, + SCALAR_TRUE_DIV, + SCALAR_SUB, + RELU, + IDENTITY, + SIGMOID, + TANH, + ELU, + FLAT, + SOFTMAX, + BATCHNORM, + CONCAT, + SPLIT, + EMBEDDING, + CACHE, + RESHAPE, + REVERSE, + TRANSPOSE, + EW_ADD, + EW_MUL, + MATMUL, + MUL, + ENLARGE, + SQUEEZE, + UNSQUEEZE, + EW_SUB, + EW_DIV, + EW_EQUAL, + EW_GREATER, + EW_LESS, + EW_MAX, + EW_MIN, + REDUCE_ARGMAX, + REDUCE_ARGMIN, + REDUCE_MAX, + REDUCE_MEAN, + REDUCE_MIN, + REDUCE_PROD, + REDUCE_SUM, + PAD, + SHAPE, + SIZE, + TOPK, + WHERE, + CEIL, + CAST, + EXP, + ROUND, + LOG, + LOGICAL_NOT, + SQRT, + SIN, + COS, + LEAKYRELU, + SLICE, + RESIZE, + PRELU, + GELU, + MULTIHEAD_ATTENTION, + FUSED, + RSQRT, + POW, + MEAN, + LAYERNORM, + GATHER, + BROADCAST, + REPARTITION, + COMBINE, + REPLICATE, + REDUCTION, + BATCH, + PIPELINE, + FUSED_PARALLEL +}; +std::string format_as(OperatorType); +std::ostream &operator<<(std::ostream &, OperatorType); +void to_json(::nlohmann::json &, OperatorType); +void from_json(::nlohmann::json const &, OperatorType &); +} // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::OperatorType) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_TYPE_DTG_H diff --git a/lib/op-attrs/include/op-attrs/operator_type.enum.toml b/lib/op-attrs/include/op-attrs/operator_type.enum.toml new file mode 100644 index 0000000000..8815d69dda --- /dev/null +++ b/lib/op-attrs/include/op-attrs/operator_type.enum.toml @@ -0,0 +1,95 @@ +namespace = "FlexFlow" +name = "OperatorType" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +values = [ + { name = "NOOP" }, + { name = "INPUT" }, + { name = "WEIGHT" }, + { name = "CONV2D" }, + { name = "DROPOUT" }, + { name = "LINEAR" }, + { name = "BATCHMATMUL" }, + { name = "POOL2D" }, + { name = "SCALAR_MULTIPLY" }, + { name = "SCALAR_ADD" }, + { name = "SCALAR_FLOOR_DIV" }, + { name = "SCALAR_TRUE_DIV" }, + { name = "SCALAR_SUB" }, + { name = "RELU" }, + { name = "IDENTITY" }, + { name = "SIGMOID" }, + { name = "TANH" }, + { name = "ELU" }, + { name = "FLAT" }, + { name = "SOFTMAX" }, + { name = "BATCHNORM" }, + { name = "CONCAT" }, + { name = "SPLIT" }, + { name = "EMBEDDING" }, + { name = "CACHE" }, + { name = "RESHAPE" }, + { name = "REVERSE" }, + { name = "TRANSPOSE" }, + { name = "EW_ADD" }, + { name = "EW_MUL" }, + { name = "MATMUL" }, + { name = "MUL" }, + { name = "ENLARGE" }, + { name = "SQUEEZE" }, + { name = "UNSQUEEZE" }, + { name = "EW_SUB" }, + { name = "EW_DIV" }, + { name = "EW_EQUAL" }, + { name = "EW_GREATER" }, + { name = "EW_LESS" }, + { name = "EW_MAX" }, + { name = "EW_MIN" }, + { name = "REDUCE_ARGMAX" }, + { name = "REDUCE_ARGMIN" }, + { name = "REDUCE_MAX" }, + { name = "REDUCE_MEAN" }, + { name = "REDUCE_MIN" }, + { name = "REDUCE_PROD" }, + { name = "REDUCE_SUM" }, + { name = "PAD" }, + { name = "SHAPE" }, + { name = "SIZE" }, + { name = "TOPK" }, + { name = "WHERE" }, + { name = "CEIL" }, + { name = "CAST" }, + { name = "EXP" }, + { name = "ROUND" }, + { name = "LOG" }, + { name = "LOGICAL_NOT" }, + { name = "SQRT" }, + { name = "SIN" }, + { name = "COS" }, + { name = "LEAKYRELU" }, + { name = "SLICE" }, + { name = "RESIZE" }, + { name = "PRELU" }, + { name = "GELU" }, + { name = "MULTIHEAD_ATTENTION" }, + { name = "FUSED" }, + { name = "RSQRT" }, + { name = "POW" }, + { name = "MEAN" }, + { name = "LAYERNORM" }, + { name = "GATHER" }, + { name = "BROADCAST" }, + { name = "REPARTITION" }, + { name = "COMBINE" }, + { name = "REPLICATE" }, + { name = "REDUCTION" }, + { name = "BATCH" }, + { name = "PIPELINE" }, + { name = "FUSED_PARALLEL" }, +] + diff --git a/lib/op-attrs/include/op-attrs/operator_type.h b/lib/op-attrs/include/op-attrs/operator_type.h new file mode 100644 index 0000000000..4750af51ee --- /dev/null +++ b/lib/op-attrs/include/op-attrs/operator_type.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_TYPE_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_TYPE_H + +#include "op-attrs/operator_type.dtg.h" + +namespace FlexFlow { + +std::string get_operator_type_name(OperatorType); +bool is_parallel_op(OperatorType); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/ops/attention.h b/lib/op-attrs/include/op-attrs/ops/attention.h index ec3e592607..8233775e63 100644 --- a/lib/op-attrs/include/op-attrs/ops/attention.h +++ b/lib/op-attrs/include/op-attrs/ops/attention.h @@ -2,73 +2,62 @@ #define _FLEXFLOW_ATTENTION_ATTRS_H #include "core.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/attention/multihead_attention_inputs.dtg.h" +#include "op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.h" +#include "op-attrs/ops/attention_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" +#include namespace FlexFlow { -struct MultiHeadAttentionAttrs { - req embed_dim, num_heads, kdim, vdim; - req dropout; - req bias, add_bias_kv, add_zero_attn; -}; -FF_VISITABLE_STRUCT(MultiHeadAttentionAttrs, - embed_dim, - num_heads, - kdim, - vdim, - dropout, - bias, - add_bias_kv, - add_zero_attn); - -template -struct MultiHeadAttentionInputs - : public use_visitable_cmp> { -public: - MultiHeadAttentionInputs() = delete; - - MultiHeadAttentionInputs(TensorType const &query, - TensorType const &key, - TensorType const &value) - : query(query), key(key), value(value) {} - - template - MultiHeadAttentionInputs(MultiHeadAttentionInputs const &sub) - : query(sub.query), key(sub.key), value(sub.value) {} - -public: - TensorType query; - TensorType key; - TensorType value; -}; - int get_qProjSize(MultiHeadAttentionAttrs const &); int get_vProjSize(MultiHeadAttentionAttrs const &); int get_kProjSize(MultiHeadAttentionAttrs const &); int get_oProjSize(MultiHeadAttentionAttrs const &); -int get_qSize(MultiHeadAttentionInputs const &); -int get_kSize(MultiHeadAttentionInputs const &); -int get_vSize(MultiHeadAttentionInputs const &); +int get_qSize(MultiHeadAttentionParallelInputs const &); +int get_qSize(MultiHeadAttentionInputs const &); + +int get_kSize(MultiHeadAttentionParallelInputs const &); +int get_kSize(MultiHeadAttentionInputs const &); + +int get_vSize(MultiHeadAttentionParallelInputs const &); +int get_vSize(MultiHeadAttentionInputs const &); + int get_oSize(ParallelTensorShape const &); +int get_oSize(TensorShape const &); + +int get_qoSeqLength(MultiHeadAttentionParallelInputs const &); +int get_qoSeqLength(MultiHeadAttentionInputs const &); -int get_qoSeqLength(MultiHeadAttentionInputs const &); -int get_kvSeqLength(MultiHeadAttentionInputs const &); +int get_kvSeqLength(MultiHeadAttentionParallelInputs const &); +int get_kvSeqLength(MultiHeadAttentionInputs const &); -int get_num_samples(MultiHeadAttentionInputs const &); +int get_num_samples(MultiHeadAttentionParallelInputs const &); +int get_num_samples(MultiHeadAttentionInputs const &); -TensorShape get_weights_shape(MultiHeadAttentionAttrs const &, - MultiHeadAttentionInputs const &); -ParallelTensorShape +tl::expected get_weights_shape(MultiHeadAttentionAttrs const &, - MultiHeadAttentionInputs const &); + TensorShape const &input_q, + TensorShape const &input_k, + TensorShape const &input_v); +tl::expected + get_weights_shape(MultiHeadAttentionAttrs const &, + ParallelTensorShape const &input_q, + ParallelTensorShape const &input_k, + ParallelTensorShape const &input_v); -ParallelTensorShape +tl::expected + get_output_shape(MultiHeadAttentionAttrs const &, + TensorShape const &input_q, + TensorShape const &input_k, + TensorShape const &input_v); +tl::expected get_output_shape(MultiHeadAttentionAttrs const &, - MultiHeadAttentionInputs const &); -TensorShape get_output_shape(MultiHeadAttentionAttrs const &, - MultiHeadAttentionInputs const &); + ParallelTensorShape const &input_q, + ParallelTensorShape const &input_k, + ParallelTensorShape const &input_v); CHECK_VALID_OP_ATTR(MultiHeadAttentionAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.dtg.h b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.dtg.h new file mode 100644 index 0000000000..7b61305a1a --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.dtg.h @@ -0,0 +1,74 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.struct.toml +/* proj-data +{ + "generated_from": "c57a9d1d2822a726ee9d9369d22e8e72" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_MULTIHEAD_ATTENTION_INPUTS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_MULTIHEAD_ATTENTION_INPUTS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/datatype.dtg.h" +#include "rapidcheck.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct MultiHeadAttentionInputs { + MultiHeadAttentionInputs() = delete; + MultiHeadAttentionInputs(size_t const &batch_size, + size_t const &sequence_length, + size_t const &query_size, + size_t const &key_size, + size_t const &value_size, + ::FlexFlow::DataType const &datatype); + + bool operator==(MultiHeadAttentionInputs const &) const; + bool operator!=(MultiHeadAttentionInputs const &) const; + bool operator<(MultiHeadAttentionInputs const &) const; + bool operator>(MultiHeadAttentionInputs const &) const; + bool operator<=(MultiHeadAttentionInputs const &) const; + bool operator>=(MultiHeadAttentionInputs const &) const; + size_t batch_size; + size_t sequence_length; + size_t query_size; + size_t key_size; + size_t value_size; + ::FlexFlow::DataType datatype; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::MultiHeadAttentionInputs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::MultiHeadAttentionInputs from_json(json const &); + static void to_json(json &, FlexFlow::MultiHeadAttentionInputs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(MultiHeadAttentionInputs const &); +std::ostream &operator<<(std::ostream &, MultiHeadAttentionInputs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_MULTIHEAD_ATTENTION_INPUTS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.h b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.h new file mode 100644 index 0000000000..aed9f577ff --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_MULTIHEAD_ATTENTION_INPUTS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_MULTIHEAD_ATTENTION_INPUTS_H + +#include "op-attrs/ops/attention/multihead_attention_inputs.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" +#include + +namespace FlexFlow { + +tl::expected + parse_attention_input_shape(TensorShape const &input_q, + TensorShape const &input_k, + TensorShape const &input_v); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.struct.toml b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.struct.toml new file mode 100644 index 0000000000..b82b285451 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.struct.toml @@ -0,0 +1,39 @@ +namespace = "FlexFlow" +name = "MultiHeadAttentionInputs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "", + "op-attrs/datatype.dtg.h", +] + +[[fields]] +name = "batch_size" +type = "size_t" + +[[fields]] +name = "sequence_length" +type = "size_t" + +[[fields]] +name = "query_size" +type = "size_t" + +[[fields]] +name = "key_size" +type = "size_t" + +[[fields]] +name = "value_size" +type = "size_t" + +[[fields]] +name = "datatype" +type = "::FlexFlow::DataType" diff --git a/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.h b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.h new file mode 100644 index 0000000000..297b1f8f1c --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.h @@ -0,0 +1,82 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.struct.toml +/* proj-data +{ + "generated_from": "7c434445707968123a361c038a337da2" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_MULTIHEAD_ATTENTION_PARALLEL_INPUTS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_MULTIHEAD_ATTENTION_PARALLEL_INPUTS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/datatype.dtg.h" +#include "op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h" +#include "op-attrs/parallel_tensor_shape/sum_degree.dtg.h" +#include "op-attrs/shard_parallel_dim.dtg.h" +#include "rapidcheck.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct MultiHeadAttentionParallelInputs { + MultiHeadAttentionParallelInputs() = delete; + MultiHeadAttentionParallelInputs( + ::FlexFlow::ShardParallelDim const &batch_dim, + ::FlexFlow::ShardParallelDim const &sequence_dim, + ::FlexFlow::ShardParallelDim const &query_dim, + ::FlexFlow::ShardParallelDim const &key_dim, + ::FlexFlow::ShardParallelDim const &value_dim, + ::FlexFlow::DiscardCopyDegree const &discard_copy_degree, + ::FlexFlow::DataType const &datatype); + + bool operator==(MultiHeadAttentionParallelInputs const &) const; + bool operator!=(MultiHeadAttentionParallelInputs const &) const; + bool operator<(MultiHeadAttentionParallelInputs const &) const; + bool operator>(MultiHeadAttentionParallelInputs const &) const; + bool operator<=(MultiHeadAttentionParallelInputs const &) const; + bool operator>=(MultiHeadAttentionParallelInputs const &) const; + ::FlexFlow::ShardParallelDim batch_dim; + ::FlexFlow::ShardParallelDim sequence_dim; + ::FlexFlow::ShardParallelDim query_dim; + ::FlexFlow::ShardParallelDim key_dim; + ::FlexFlow::ShardParallelDim value_dim; + ::FlexFlow::DiscardCopyDegree discard_copy_degree; + ::FlexFlow::DataType datatype; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::MultiHeadAttentionParallelInputs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::MultiHeadAttentionParallelInputs from_json(json const &); + static void to_json(json &, + FlexFlow::MultiHeadAttentionParallelInputs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(MultiHeadAttentionParallelInputs const &); +std::ostream &operator<<(std::ostream &, + MultiHeadAttentionParallelInputs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_MULTIHEAD_ATTENTION_PARALLEL_INPUTS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.h b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.h new file mode 100644 index 0000000000..53cc3167f2 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_MULTIHEAD_ATTENTION_PARALLEL_INPUTS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_MULTIHEAD_ATTENTION_PARALLEL_INPUTS_H + +#include "op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include + +namespace FlexFlow { + +tl::expected + parse_attention_parallel_input_shape(ParallelTensorShape const &input_q, + ParallelTensorShape const &input_k, + ParallelTensorShape const &input_v); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.struct.toml b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.struct.toml new file mode 100644 index 0000000000..b0636db353 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.struct.toml @@ -0,0 +1,46 @@ +namespace = "FlexFlow" +name = "MultiHeadAttentionParallelInputs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "", + "op-attrs/datatype.dtg.h", + "op-attrs/shard_parallel_dim.dtg.h", + "op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h", + "op-attrs/parallel_tensor_shape/sum_degree.dtg.h", +] + +[[fields]] +name = "batch_dim" +type = "::FlexFlow::ShardParallelDim" + +[[fields]] +name = "sequence_dim" +type = "::FlexFlow::ShardParallelDim" + +[[fields]] +name = "query_dim" +type = "::FlexFlow::ShardParallelDim" + +[[fields]] +name = "key_dim" +type = "::FlexFlow::ShardParallelDim" + +[[fields]] +name = "value_dim" +type = "::FlexFlow::ShardParallelDim" + +[[fields]] +name = "discard_copy_degree" +type = "::FlexFlow::DiscardCopyDegree" + +[[fields]] +name = "datatype" +type = "::FlexFlow::DataType" diff --git a/lib/op-attrs/include/op-attrs/ops/attention_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/attention_attrs.dtg.h new file mode 100644 index 0000000000..18b2906759 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/attention_attrs.dtg.h @@ -0,0 +1,76 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/attention_attrs.struct.toml +/* proj-data +{ + "generated_from": "360324465947562229dc6632a9e9a2f3" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct MultiHeadAttentionAttrs { + MultiHeadAttentionAttrs() = delete; + MultiHeadAttentionAttrs(int const &embed_dim, + int const &num_heads, + int const &kdim, + int const &vdim, + float const &dropout, + bool const &bias, + bool const &add_bias_kv, + bool const &add_zero_attn); + + bool operator==(MultiHeadAttentionAttrs const &) const; + bool operator!=(MultiHeadAttentionAttrs const &) const; + bool operator<(MultiHeadAttentionAttrs const &) const; + bool operator>(MultiHeadAttentionAttrs const &) const; + bool operator<=(MultiHeadAttentionAttrs const &) const; + bool operator>=(MultiHeadAttentionAttrs const &) const; + int embed_dim; + int num_heads; + int kdim; + int vdim; + float dropout; + bool bias; + bool add_bias_kv; + bool add_zero_attn; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::MultiHeadAttentionAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::MultiHeadAttentionAttrs from_json(json const &); + static void to_json(json &, FlexFlow::MultiHeadAttentionAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(MultiHeadAttentionAttrs const &); +std::ostream &operator<<(std::ostream &, MultiHeadAttentionAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/attention_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/attention_attrs.struct.toml new file mode 100644 index 0000000000..d96d8af69c --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/attention_attrs.struct.toml @@ -0,0 +1,43 @@ +namespace = "FlexFlow" +name = "MultiHeadAttentionAttrs" + +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "embed_dim" +type = "int" + +[[fields]] +name = "num_heads" +type = "int" + +[[fields]] +name = "kdim" +type = "int" + +[[fields]] +name = "vdim" +type = "int" + +[[fields]] +name = "dropout" +type = "float" + +[[fields]] +name = "bias" +type = "bool" + +[[fields]] +name = "add_bias_kv" +type = "bool" + +[[fields]] +name = "add_zero_attn" +type = "bool" diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.dtg.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.dtg.h new file mode 100644 index 0000000000..a8ab52d2b3 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.dtg.h @@ -0,0 +1,63 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/batch_matmul.struct.toml +/* proj-data +{ + "generated_from": "c3bbf4c76982ef27107b74e1e6e5d360" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_MATMUL_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_MATMUL_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct BatchMatmulAttrs { + BatchMatmulAttrs() = delete; + BatchMatmulAttrs(int const &a_seq_length_dim, int const &b_seq_length_dim); + + bool operator==(BatchMatmulAttrs const &) const; + bool operator!=(BatchMatmulAttrs const &) const; + bool operator<(BatchMatmulAttrs const &) const; + bool operator>(BatchMatmulAttrs const &) const; + bool operator<=(BatchMatmulAttrs const &) const; + bool operator>=(BatchMatmulAttrs const &) const; + int a_seq_length_dim; + int b_seq_length_dim; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::BatchMatmulAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::BatchMatmulAttrs from_json(json const &); + static void to_json(json &, FlexFlow::BatchMatmulAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(BatchMatmulAttrs const &); +std::ostream &operator<<(std::ostream &, BatchMatmulAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_MATMUL_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h index b05a5eb022..57760d1110 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h @@ -1,18 +1,26 @@ -#ifndef _FF_OP_META_BATCH_MATMUL_ATTRS_H -#define _FF_OP_META_BATCH_MATMUL_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_MATMUL_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_MATMUL_H -#include "core.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/batch_matmul.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" +#include namespace FlexFlow { -struct BatchMatmulAttrs { - req a_seq_length_dim, b_seq_length_dim; -}; -FF_VISITABLE_STRUCT(BatchMatmulAttrs, a_seq_length_dim, b_seq_length_dim); +bool is_valid(BatchMatmulAttrs const &, + ParallelTensorShape const &, + ParallelTensorShape const &); -CHECK_VALID_OP_ATTR(BatchMatmulAttrs); +tl::expected + get_output_shape(BatchMatmulAttrs const &attrs, + TensorShape const &input_lhs, + TensorShape const &input_rhs); + +tl::expected + get_output_shape(BatchMatmulAttrs const &attrs, + ParallelTensorShape const &input_lhs, + ParallelTensorShape const &input_rhs); } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.struct.toml b/lib/op-attrs/include/op-attrs/ops/batch_matmul.struct.toml new file mode 100644 index 0000000000..3b1dd3f687 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.struct.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "BatchMatmulAttrs" + +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "a_seq_length_dim" +type = "int" + +[[fields]] +name = "b_seq_length_dim" +type = "int" diff --git a/lib/op-attrs/include/op-attrs/ops/batch_norm.h b/lib/op-attrs/include/op-attrs/ops/batch_norm.h index 4ec823d4ae..b9a1d87a75 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_norm.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_norm.h @@ -2,17 +2,13 @@ #define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_NORM_H #include "core.h" +#include "op-attrs/ops/batch_norm_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" namespace FlexFlow { -struct BatchNormAttrs { - req relu; -}; -FF_VISITABLE_STRUCT(BatchNormAttrs, relu); - -ParallelTensorShape get_output_shape(BatchNormAttrs const &); +ParallelTensorShape get_output_shape(BatchNormAttrs const &, + ParallelTensorShape const &); CHECK_VALID_OP_ATTR(BatchNormAttrs); diff --git a/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.dtg.h new file mode 100644 index 0000000000..f153bfde7e --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.dtg.h @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml +/* proj-data +{ + "generated_from": "f8e0219d8a3e008a73c38cf84d25f66e" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_NORM_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_NORM_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct BatchNormAttrs { + BatchNormAttrs() = delete; + BatchNormAttrs(bool const &relu); + + bool operator==(BatchNormAttrs const &) const; + bool operator!=(BatchNormAttrs const &) const; + bool operator<(BatchNormAttrs const &) const; + bool operator>(BatchNormAttrs const &) const; + bool operator<=(BatchNormAttrs const &) const; + bool operator>=(BatchNormAttrs const &) const; + bool relu; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::BatchNormAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::BatchNormAttrs from_json(json const &); + static void to_json(json &, FlexFlow::BatchNormAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(BatchNormAttrs const &); +std::ostream &operator<<(std::ostream &, BatchNormAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_NORM_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml new file mode 100644 index 0000000000..bc82f3c743 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "BatchNormAttrs" + +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "relu" +type = "bool" diff --git a/lib/op-attrs/include/op-attrs/ops/broadcast.dtg.h b/lib/op-attrs/include/op-attrs/ops/broadcast.dtg.h new file mode 100644 index 0000000000..e4de3dcc75 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/broadcast.dtg.h @@ -0,0 +1,64 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/broadcast.struct.toml +/* proj-data +{ + "generated_from": "12715c970e8416eacbd0750f338478e5" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BROADCAST_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BROADCAST_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include "utils/stack_vector.h" +#include +#include +#include + +namespace FlexFlow { +struct BroadcastAttrs { + BroadcastAttrs() = delete; + BroadcastAttrs( + ::FlexFlow::stack_vector const &target_dims); + + bool operator==(BroadcastAttrs const &) const; + bool operator!=(BroadcastAttrs const &) const; + bool operator<(BroadcastAttrs const &) const; + bool operator>(BroadcastAttrs const &) const; + bool operator<=(BroadcastAttrs const &) const; + bool operator>=(BroadcastAttrs const &) const; + ::FlexFlow::stack_vector target_dims; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::BroadcastAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::BroadcastAttrs from_json(json const &); + static void to_json(json &, FlexFlow::BroadcastAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(BroadcastAttrs const &); +std::ostream &operator<<(std::ostream &, BroadcastAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BROADCAST_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/broadcast.h b/lib/op-attrs/include/op-attrs/ops/broadcast.h index 433bf23241..ad44060400 100644 --- a/lib/op-attrs/include/op-attrs/ops/broadcast.h +++ b/lib/op-attrs/include/op-attrs/ops/broadcast.h @@ -1,18 +1,13 @@ -#ifndef _FLEXFLOW_INCLUDE_OPATTRS_OPS_BROADCAST_H -#define _FLEXFLOW_INCLUDE_OPATTRS_OPS_BROADCAST_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BROADCAST_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BROADCAST_H -#include "core.h" -#include "utils/stack_vector.h" -#include "utils/visitable.h" +#include "op-attrs/ops/broadcast.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { -struct BroadcastAttrs { - req> target_dims; -}; -FF_VISITABLE_STRUCT(BroadcastAttrs, target_dims); - -CHECK_VALID_OP_ATTR(BroadcastAttrs); +ParallelTensorShape get_output_shape(BroadcastAttrs const &, + ParallelTensorShape const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/broadcast.struct.toml b/lib/op-attrs/include/op-attrs/ops/broadcast.struct.toml new file mode 100644 index 0000000000..c87afa59b5 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/broadcast.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "BroadcastAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "utils/stack_vector.h", +] + +[[fields]] +name = "target_dims" +type = "::FlexFlow::stack_vector" diff --git a/lib/op-attrs/include/op-attrs/ops/cast.h b/lib/op-attrs/include/op-attrs/ops/cast.h index 63563f8df8..117dcb1e01 100644 --- a/lib/op-attrs/include/op-attrs/ops/cast.h +++ b/lib/op-attrs/include/op-attrs/ops/cast.h @@ -2,17 +2,10 @@ #define _FLEXFLOW_CAST_ATTRS_H #include "core.h" -#include "op-attrs/datatype.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/cast_attrs.dtg.h" namespace FlexFlow { -struct CastAttrs { - req dtype; -}; -FF_VISITABLE_STRUCT(CastAttrs, dtype); - CHECK_VALID_OP_ATTR(CastAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/cast_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/cast_attrs.dtg.h new file mode 100644 index 0000000000..33391eb221 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/cast_attrs.dtg.h @@ -0,0 +1,63 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/cast_attrs.struct.toml +/* proj-data +{ + "generated_from": "c171c87db89b9ec9ea7d52a50c153054" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CAST_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CAST_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/datatype.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct CastAttrs { + CastAttrs() = delete; + CastAttrs(DataType const &dtype); + + bool operator==(CastAttrs const &) const; + bool operator!=(CastAttrs const &) const; + bool operator<(CastAttrs const &) const; + bool operator>(CastAttrs const &) const; + bool operator<=(CastAttrs const &) const; + bool operator>=(CastAttrs const &) const; + DataType dtype; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::CastAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::CastAttrs from_json(json const &); + static void to_json(json &, FlexFlow::CastAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(CastAttrs const &); +std::ostream &operator<<(std::ostream &, CastAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CAST_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/cast_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/cast_attrs.struct.toml new file mode 100644 index 0000000000..6c12680ea1 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/cast_attrs.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "CastAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/datatype.h" +] + +[[fields]] +name = "dtype" +type = "DataType" diff --git a/lib/op-attrs/include/op-attrs/ops/combine.h b/lib/op-attrs/include/op-attrs/ops/combine.h index deaba9e093..d9b20fc2c5 100644 --- a/lib/op-attrs/include/op-attrs/ops/combine.h +++ b/lib/op-attrs/include/op-attrs/ops/combine.h @@ -1,20 +1,18 @@ -#ifndef _FLEXFLOW_COMBINE_ATTRS_H -#define _FLEXFLOW_COMBINE_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_COMBINE_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_COMBINE_H -#include "core.h" -#include "op-attrs/ff_dim.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/combine_attrs.dtg.h" +#include "op-attrs/ops/core.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include namespace FlexFlow { -struct CombineAttrs { - ff_dim_t combine_dim; - req combine_degree; -}; -FF_VISITABLE_STRUCT(CombineAttrs, combine_dim, combine_degree); CHECK_VALID_OP_ATTR(CombineAttrs); +tl::expected + get_output_shape(CombineAttrs const &, ParallelTensorShape const &); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/combine_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/combine_attrs.dtg.h new file mode 100644 index 0000000000..43db204bc5 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/combine_attrs.dtg.h @@ -0,0 +1,66 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml +/* proj-data +{ + "generated_from": "58fc5a388fd1a325ef4142094607e39a" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_COMBINE_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_COMBINE_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct CombineAttrs { + CombineAttrs() = delete; + CombineAttrs(::FlexFlow::ff_dim_t const &combine_dim, + int const &combine_degree); + + bool operator==(CombineAttrs const &) const; + bool operator!=(CombineAttrs const &) const; + bool operator<(CombineAttrs const &) const; + bool operator>(CombineAttrs const &) const; + bool operator<=(CombineAttrs const &) const; + bool operator>=(CombineAttrs const &) const; + ::FlexFlow::ff_dim_t combine_dim; + int combine_degree; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::CombineAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::CombineAttrs from_json(json const &); + static void to_json(json &, FlexFlow::CombineAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(CombineAttrs const &); +std::ostream &operator<<(std::ostream &, CombineAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_COMBINE_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml new file mode 100644 index 0000000000..585295fe1c --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "CombineAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ff_dim.h", + "op-attrs/ff_dim.dtg.h", +] + +[[fields]] +name = "combine_dim" +type = "::FlexFlow::ff_dim_t" + +[[fields]] +name = "combine_degree" +type = "int" diff --git a/lib/op-attrs/include/op-attrs/ops/concat.h b/lib/op-attrs/include/op-attrs/ops/concat.h index 78f848f18b..8a72708971 100644 --- a/lib/op-attrs/include/op-attrs/ops/concat.h +++ b/lib/op-attrs/include/op-attrs/ops/concat.h @@ -2,17 +2,10 @@ #define _FLEXFLOW_CONCAT_ATTRS_H #include "core.h" -#include "op-attrs/ff_dim.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/concat_attrs.dtg.h" namespace FlexFlow { -struct ConcatAttrs { - ff_dim_t axis; - req num_inputs; -}; -FF_VISITABLE_STRUCT(ConcatAttrs, axis, num_inputs); CHECK_VALID_OP_ATTR(ConcatAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/concat_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/concat_attrs.dtg.h new file mode 100644 index 0000000000..3c26473a4e --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/concat_attrs.dtg.h @@ -0,0 +1,65 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml +/* proj-data +{ + "generated_from": "68e0520b143e0579140a2f2cdd390759" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONCAT_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONCAT_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct ConcatAttrs { + ConcatAttrs() = delete; + ConcatAttrs(::FlexFlow::ff_dim_t const &axis, int const &num_inputs); + + bool operator==(ConcatAttrs const &) const; + bool operator!=(ConcatAttrs const &) const; + bool operator<(ConcatAttrs const &) const; + bool operator>(ConcatAttrs const &) const; + bool operator<=(ConcatAttrs const &) const; + bool operator>=(ConcatAttrs const &) const; + ::FlexFlow::ff_dim_t axis; + int num_inputs; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ConcatAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ConcatAttrs from_json(json const &); + static void to_json(json &, FlexFlow::ConcatAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(ConcatAttrs const &); +std::ostream &operator<<(std::ostream &, ConcatAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONCAT_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml new file mode 100644 index 0000000000..4faa870bc4 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "ConcatAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ff_dim.h", + "op-attrs/ff_dim.dtg.h" +] + +[[fields]] +name = "axis" +type = "::FlexFlow::ff_dim_t" + +[[fields]] +name = "num_inputs" +type = "int" diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d.h b/lib/op-attrs/include/op-attrs/ops/conv_2d.h index 79980d545d..7759380088 100644 --- a/lib/op-attrs/include/op-attrs/ops/conv_2d.h +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d.h @@ -2,35 +2,26 @@ #define _FLEXFLOW_CONV_2D_ATTRS_H #include "core.h" -#include "op-attrs/activation.h" +#include "op-attrs/ops/conv_2d_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/tensor_shape.h" -#include "utils/visitable.h" namespace FlexFlow { -struct Conv2DAttrs { - int out_channels, kernel_h, kernel_w, stride_h, stride_w, padding_h, - padding_w, groups; - std::optional activation; - req use_bias; -}; - -FF_VISITABLE_STRUCT(Conv2DAttrs, - out_channels, - kernel_h, - kernel_w, - stride_h, - stride_w, - padding_h, - padding_w, - groups, - activation, - use_bias); CHECK_VALID_OP_ATTR(Conv2DAttrs); -TensorShape get_kernel_shape(Conv2DAttrs const &, TensorShape const &); -TensorShape get_bias_shape(Conv2DAttrs const &, TensorShape const &); +TensorShape get_kernel_shape(Conv2DAttrs const &attrs, + TensorShape const &input); +TensorShape get_bias_shape(Conv2DAttrs const &attrs, TensorShape const &input); +TensorShape get_output_shape(Conv2DAttrs const &attrs, + TensorShape const &input); + +ParallelTensorShape get_kernel_shape(Conv2DAttrs const &attrs, + ParallelTensorShape const &input_shape); +ParallelTensorShape get_bias_shape(Conv2DAttrs const &attrs, + ParallelTensorShape const &input_shape); +ParallelTensorShape get_output_shape(Conv2DAttrs const &attrs, + ParallelTensorShape const &input_shape); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.h b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.h new file mode 100644 index 0000000000..2e7833064c --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.h @@ -0,0 +1,72 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.struct.toml +/* proj-data +{ + "generated_from": "51911f58c134d55b2d0245444acbae53" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_CONV_2D_INPUT_SHAPE_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_CONV_2D_INPUT_SHAPE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/datatype.dtg.h" +#include "rapidcheck.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct Conv2DInputShape { + Conv2DInputShape() = delete; + Conv2DInputShape(size_t const &num_samples, + size_t const &num_channels, + size_t const &height, + size_t const &width, + ::FlexFlow::DataType const &datatype); + + bool operator==(Conv2DInputShape const &) const; + bool operator!=(Conv2DInputShape const &) const; + bool operator<(Conv2DInputShape const &) const; + bool operator>(Conv2DInputShape const &) const; + bool operator<=(Conv2DInputShape const &) const; + bool operator>=(Conv2DInputShape const &) const; + size_t num_samples; + size_t num_channels; + size_t height; + size_t width; + ::FlexFlow::DataType datatype; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::Conv2DInputShape const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::Conv2DInputShape from_json(json const &); + static void to_json(json &, FlexFlow::Conv2DInputShape const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(Conv2DInputShape const &); +std::ostream &operator<<(std::ostream &, Conv2DInputShape const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_CONV_2D_INPUT_SHAPE_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.h b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.h new file mode 100644 index 0000000000..043f5854ae --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_CONV_2D_INPUT_SHAPE_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_CONV_2D_INPUT_SHAPE_H + +#include "op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" + +namespace FlexFlow { + +Conv2DInputShape parse_input_shape(TensorShape const &input); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.struct.toml b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.struct.toml new file mode 100644 index 0000000000..77e8c51244 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.struct.toml @@ -0,0 +1,35 @@ +namespace = "FlexFlow" +name = "Conv2DInputShape" +features = [ + "eq", + "ord", + "hash", + "fmt", + "rapidcheck", + "json", +] + +includes = [ + "", + "op-attrs/datatype.dtg.h", +] + +[[fields]] +name = "num_samples" +type = "size_t" + +[[fields]] +name = "num_channels" +type = "size_t" + +[[fields]] +name = "height" +type = "size_t" + +[[fields]] +name = "width" +type = "size_t" + +[[fields]] +name = "datatype" +type = "::FlexFlow::DataType" diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.h b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.h new file mode 100644 index 0000000000..846c9e413a --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.h @@ -0,0 +1,76 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.struct.toml +/* proj-data +{ + "generated_from": "d80394bdc90f843372760310b6d17a22" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_CONV_2D_PARALLEL_INPUT_SHAPE_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_CONV_2D_PARALLEL_INPUT_SHAPE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/datatype.dtg.h" +#include "op-attrs/shard_parallel_dim.dtg.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct Conv2DParallelInputShape { + Conv2DParallelInputShape() = delete; + Conv2DParallelInputShape(::FlexFlow::ShardParallelDim const &sample_dim, + ::FlexFlow::ShardParallelDim const &channel_dim, + ::FlexFlow::ShardParallelDim const &height_dim, + ::FlexFlow::ShardParallelDim const &width_dim, + int const &sum_reduction_degree, + int const &discard_copy_reduction_degree, + ::FlexFlow::DataType const &datatype); + + bool operator==(Conv2DParallelInputShape const &) const; + bool operator!=(Conv2DParallelInputShape const &) const; + bool operator<(Conv2DParallelInputShape const &) const; + bool operator>(Conv2DParallelInputShape const &) const; + bool operator<=(Conv2DParallelInputShape const &) const; + bool operator>=(Conv2DParallelInputShape const &) const; + ::FlexFlow::ShardParallelDim sample_dim; + ::FlexFlow::ShardParallelDim channel_dim; + ::FlexFlow::ShardParallelDim height_dim; + ::FlexFlow::ShardParallelDim width_dim; + int sum_reduction_degree; + int discard_copy_reduction_degree; + ::FlexFlow::DataType datatype; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::Conv2DParallelInputShape const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::Conv2DParallelInputShape from_json(json const &); + static void to_json(json &, FlexFlow::Conv2DParallelInputShape const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(Conv2DParallelInputShape const &); +std::ostream &operator<<(std::ostream &, Conv2DParallelInputShape const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_CONV_2D_PARALLEL_INPUT_SHAPE_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.h b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.h new file mode 100644 index 0000000000..accc64e751 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_CONV_2D_PARALLEL_INPUT_SHAPE_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_CONV_2D_PARALLEL_INPUT_SHAPE_H + +#include "op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" + +namespace FlexFlow { + +Conv2DParallelInputShape + parse_parallel_input_shape(ParallelTensorShape const &input); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.struct.toml b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.struct.toml new file mode 100644 index 0000000000..68cbd878d1 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.struct.toml @@ -0,0 +1,43 @@ +namespace = "FlexFlow" +name = "Conv2DParallelInputShape" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/shard_parallel_dim.dtg.h", + "op-attrs/datatype.dtg.h", +] + +[[fields]] +name = "sample_dim" +type = "::FlexFlow::ShardParallelDim" + +[[fields]] +name = "channel_dim" +type = "::FlexFlow::ShardParallelDim" + +[[fields]] +name = "height_dim" +type = "::FlexFlow::ShardParallelDim" + +[[fields]] +name = "width_dim" +type = "::FlexFlow::ShardParallelDim" + +[[fields]] +name = "sum_reduction_degree" +type = "int" + +[[fields]] +name = "discard_copy_reduction_degree" +type = "int" + +[[fields]] +name = "datatype" +type = "::FlexFlow::DataType" diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.dtg.h new file mode 100644 index 0000000000..06827656da --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.dtg.h @@ -0,0 +1,83 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml +/* proj-data +{ + "generated_from": "74f98e1aacb57d847bb450e1d28d3e67" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/activation.dtg.h" +#include "rapidcheck.h" +#include "utils/json.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct Conv2DAttrs { + Conv2DAttrs() = delete; + Conv2DAttrs(int const &out_channels, + int const &kernel_h, + int const &kernel_w, + int const &stride_h, + int const &stride_w, + int const &padding_h, + int const &padding_w, + int const &groups, + std::optional<::FlexFlow::Activation> const &activation, + bool const &use_bias); + + bool operator==(Conv2DAttrs const &) const; + bool operator!=(Conv2DAttrs const &) const; + bool operator<(Conv2DAttrs const &) const; + bool operator>(Conv2DAttrs const &) const; + bool operator<=(Conv2DAttrs const &) const; + bool operator>=(Conv2DAttrs const &) const; + int out_channels; + int kernel_h; + int kernel_w; + int stride_h; + int stride_w; + int padding_h; + int padding_w; + int groups; + std::optional<::FlexFlow::Activation> activation; + bool use_bias; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::Conv2DAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::Conv2DAttrs from_json(json const &); + static void to_json(json &, FlexFlow::Conv2DAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(Conv2DAttrs const &); +std::ostream &operator<<(std::ostream &, Conv2DAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml new file mode 100644 index 0000000000..353ef93004 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml @@ -0,0 +1,29 @@ +namespace = "FlexFlow" +name = "Conv2DAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "", + "op-attrs/activation.dtg.h", + "utils/json.h", +] + +fields = [ + { name = "out_channels", type = "int" }, + { name = "kernel_h", type = "int" }, + { name = "kernel_w", type = "int" }, + { name = "stride_h", type = "int" }, + { name = "stride_w", type = "int" }, + { name = "padding_h", type = "int" }, + { name = "padding_w", type = "int" }, + { name = "groups", type = "int" }, + { name = "activation", type = "std::optional<::FlexFlow::Activation>" }, + { name = "use_bias", type = "bool" }, +] diff --git a/lib/op-attrs/include/op-attrs/ops/dropout.h b/lib/op-attrs/include/op-attrs/ops/dropout.h index 8e0049f526..a0493301c4 100644 --- a/lib/op-attrs/include/op-attrs/ops/dropout.h +++ b/lib/op-attrs/include/op-attrs/ops/dropout.h @@ -2,16 +2,14 @@ #define _FLEXFLOW_DROPOUT_ATTRS_H #include "core.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/dropout_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { -struct DropoutAttrs { - req rate; - req seed; -}; -FF_VISITABLE_STRUCT(DropoutAttrs, rate, seed); +ParallelTensorShape get_output_shape(DropoutAttrs const &, + ParallelTensorShape const &); + CHECK_VALID_OP_ATTR(DropoutAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/dropout_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/dropout_attrs.dtg.h new file mode 100644 index 0000000000..ef86e49560 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/dropout_attrs.dtg.h @@ -0,0 +1,63 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/dropout_attrs.struct.toml +/* proj-data +{ + "generated_from": "4fdbf129ea59b8a7306813cfa4c46021" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_DROPOUT_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_DROPOUT_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct DropoutAttrs { + DropoutAttrs() = delete; + DropoutAttrs(float const &rate, unsigned long long const &seed); + + bool operator==(DropoutAttrs const &) const; + bool operator!=(DropoutAttrs const &) const; + bool operator<(DropoutAttrs const &) const; + bool operator>(DropoutAttrs const &) const; + bool operator<=(DropoutAttrs const &) const; + bool operator>=(DropoutAttrs const &) const; + float rate; + unsigned long long seed; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::DropoutAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::DropoutAttrs from_json(json const &); + static void to_json(json &, FlexFlow::DropoutAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(DropoutAttrs const &); +std::ostream &operator<<(std::ostream &, DropoutAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_DROPOUT_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/dropout_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/dropout_attrs.struct.toml new file mode 100644 index 0000000000..8731e0780b --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/dropout_attrs.struct.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "DropoutAttrs" + +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "rate" +type = "float" + +[[fields]] +name = "seed" +type = "unsigned long long" diff --git a/lib/op-attrs/include/op-attrs/ops/element_binary.h b/lib/op-attrs/include/op-attrs/ops/element_binary.h index c4a096166d..d51c3a3afa 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_binary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_binary.h @@ -1,25 +1,20 @@ #ifndef _FLEXFLOW_ELEMENT_BINARY_ATTRS_H #define _FLEXFLOW_ELEMENT_BINARY_ATTRS_H -#include "core.h" -#include "op-attrs/datatype.h" -#include "op-attrs/op.h" +#include "op-attrs/ops/core.h" +#include "op-attrs/ops/element_binary_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include namespace FlexFlow { -struct ElementBinaryAttrs { - req type; - req compute_type; - req should_broadcast_lhs; - req should_broadcast_rhs; -}; -FF_VISITABLE_STRUCT(ElementBinaryAttrs, - type, - compute_type, - should_broadcast_lhs, - should_broadcast_rhs); +tl::expected get_output_shape( + ElementBinaryAttrs const &, TensorShape const &, TensorShape const &); +tl::expected + get_output_shape(ElementBinaryAttrs const &, + ParallelTensorShape const &, + ParallelTensorShape const &); + CHECK_VALID_OP_ATTR(ElementBinaryAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.dtg.h new file mode 100644 index 0000000000..10d93c87d3 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.dtg.h @@ -0,0 +1,70 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/element_binary_attrs.struct.toml +/* proj-data +{ + "generated_from": "2bb947c9cc92e3833ee88c908c539629" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_BINARY_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_BINARY_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/datatype.h" +#include "op-attrs/operator_type.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct ElementBinaryAttrs { + ElementBinaryAttrs() = delete; + ElementBinaryAttrs(::FlexFlow::OperatorType const &type, + ::FlexFlow::DataType const &compute_type, + bool const &should_broadcast_lhs, + bool const &should_broadcast_rhs); + + bool operator==(ElementBinaryAttrs const &) const; + bool operator!=(ElementBinaryAttrs const &) const; + bool operator<(ElementBinaryAttrs const &) const; + bool operator>(ElementBinaryAttrs const &) const; + bool operator<=(ElementBinaryAttrs const &) const; + bool operator>=(ElementBinaryAttrs const &) const; + ::FlexFlow::OperatorType type; + ::FlexFlow::DataType compute_type; + bool should_broadcast_lhs; + bool should_broadcast_rhs; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ElementBinaryAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ElementBinaryAttrs from_json(json const &); + static void to_json(json &, FlexFlow::ElementBinaryAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(ElementBinaryAttrs const &); +std::ostream &operator<<(std::ostream &, ElementBinaryAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_BINARY_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.struct.toml new file mode 100644 index 0000000000..d167c67aed --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.struct.toml @@ -0,0 +1,32 @@ +namespace = "FlexFlow" +name = "ElementBinaryAttrs" + +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/operator_type.h", + "op-attrs/datatype.h", +] + +[[fields]] +name = "type" +type = "::FlexFlow::OperatorType" + +[[fields]] +name = "compute_type" +type = "::FlexFlow::DataType" + +[[fields]] +name = "should_broadcast_lhs" +type = "bool" + +[[fields]] +name = "should_broadcast_rhs" +type = "bool" diff --git a/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.dtg.h new file mode 100644 index 0000000000..a9fe63ca71 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.dtg.h @@ -0,0 +1,65 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.struct.toml +/* proj-data +{ + "generated_from": "aa6f98b992d46bdf7ad59158bc143a3f" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_SCALAR_UNARY_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_SCALAR_UNARY_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/operator_type.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct ElementScalarUnaryAttrs { + ElementScalarUnaryAttrs() = delete; + ElementScalarUnaryAttrs(::FlexFlow::OperatorType const &op_type, + float const &scalar); + + bool operator==(ElementScalarUnaryAttrs const &) const; + bool operator!=(ElementScalarUnaryAttrs const &) const; + bool operator<(ElementScalarUnaryAttrs const &) const; + bool operator>(ElementScalarUnaryAttrs const &) const; + bool operator<=(ElementScalarUnaryAttrs const &) const; + bool operator>=(ElementScalarUnaryAttrs const &) const; + ::FlexFlow::OperatorType op_type; + float scalar; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ElementScalarUnaryAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ElementScalarUnaryAttrs from_json(json const &); + static void to_json(json &, FlexFlow::ElementScalarUnaryAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(ElementScalarUnaryAttrs const &); +std::ostream &operator<<(std::ostream &, ElementScalarUnaryAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_SCALAR_UNARY_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.struct.toml new file mode 100644 index 0000000000..609805ab98 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "ElementScalarUnaryAttrs" + +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/operator_type.h" +] + +[[fields]] +name = "op_type" +type = "::FlexFlow::OperatorType" + +[[fields]] +name = "scalar" +type = "float" diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary.h b/lib/op-attrs/include/op-attrs/ops/element_unary.h index 6a80094dfa..471a2a30f5 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_unary.h @@ -1,24 +1,27 @@ #ifndef _FLEXFLOW_ELEMENTARY_UNARY_ATTRS_H #define _FLEXFLOW_ELEMENTARY_UNARY_ATTRS_H -#include "core.h" -#include "op-attrs/op.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/core.h" +#include "op-attrs/ops/element_scalar_unary_attrs.dtg.h" +#include "op-attrs/ops/element_unary_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" +#include namespace FlexFlow { -struct ElementUnaryAttrs { - req op_type; -}; -FF_VISITABLE_STRUCT(ElementUnaryAttrs, op_type); -CHECK_VALID_OP_ATTR(ElementUnaryAttrs); +tl::expected + get_output_shape(ElementUnaryAttrs const &, TensorShape const &); +tl::expected + get_output_shape(ElementUnaryAttrs const &, ParallelTensorShape const &); + +tl::expected + get_output_shape(ElementScalarUnaryAttrs const &, TensorShape const &); +tl::expected + get_output_shape(ElementScalarUnaryAttrs const &, + ParallelTensorShape const &); -struct ElementScalarUnaryAttrs { - Op op_type; - req scalar; -}; -FF_VISITABLE_STRUCT(ElementScalarUnaryAttrs, op_type, scalar); +CHECK_VALID_OP_ATTR(ElementUnaryAttrs); CHECK_VALID_OP_ATTR(ElementScalarUnaryAttrs); using ElementUnaryUnifiedAttrs = diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.dtg.h new file mode 100644 index 0000000000..3220234bd1 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.dtg.h @@ -0,0 +1,63 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml +/* proj-data +{ + "generated_from": "75272cff78d3db866122dbb1001aedbe" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_UNARY_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_UNARY_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/operator_type.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct ElementUnaryAttrs { + ElementUnaryAttrs() = delete; + ElementUnaryAttrs(::FlexFlow::OperatorType const &op_type); + + bool operator==(ElementUnaryAttrs const &) const; + bool operator!=(ElementUnaryAttrs const &) const; + bool operator<(ElementUnaryAttrs const &) const; + bool operator>(ElementUnaryAttrs const &) const; + bool operator<=(ElementUnaryAttrs const &) const; + bool operator>=(ElementUnaryAttrs const &) const; + ::FlexFlow::OperatorType op_type; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ElementUnaryAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ElementUnaryAttrs from_json(json const &); + static void to_json(json &, FlexFlow::ElementUnaryAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(ElementUnaryAttrs const &); +std::ostream &operator<<(std::ostream &, ElementUnaryAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_UNARY_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml new file mode 100644 index 0000000000..b0e23aa5c7 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "ElementUnaryAttrs" + +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/operator_type.h" +] + +[[fields]] +name = "op_type" +type = "::FlexFlow::OperatorType" diff --git a/lib/op-attrs/include/op-attrs/ops/embedding.h b/lib/op-attrs/include/op-attrs/ops/embedding.h index 733a6523da..aa67c6cb04 100644 --- a/lib/op-attrs/include/op-attrs/ops/embedding.h +++ b/lib/op-attrs/include/op-attrs/ops/embedding.h @@ -1,51 +1,26 @@ #ifndef _FLEXFLOW_EMBEDDING_ATTRS_H #define _FLEXFLOW_EMBEDDING_ATTRS_H -#include "core.h" -#include "op-attrs/datatype.h" +#include "op-attrs/ops/core.h" +#include "op-attrs/ops/embedding_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/tensor_shape.h" -#include "utils/fmt.h" -#include "utils/visitable.h" +#include namespace FlexFlow { -enum class AggregateOp { SUM, AVG }; - -struct EmbeddingAttrs { - int num_entries, out_channels; - std::optional aggr; - req data_type; -}; -FF_VISITABLE_STRUCT(EmbeddingAttrs, num_entries, out_channels, aggr, data_type); CHECK_VALID_OP_ATTR(EmbeddingAttrs); -TensorShape get_weights_shape(EmbeddingAttrs const &, TensorShape const &); - -} // namespace FlexFlow - -namespace fmt { - -template <> -struct formatter<::FlexFlow::AggregateOp> : formatter { - template - auto format(::FlexFlow::AggregateOp o, FormatContext &ctx) const - -> decltype(ctx.out()) { - using namespace FlexFlow; +tl::expected get_output_shape(EmbeddingAttrs const &, + TensorShape const &); +tl::expected get_weights_shape(EmbeddingAttrs const &, + TensorShape const &); - string_view name = "unknown"; - switch (o) { - case AggregateOp::SUM: - name = "Sum"; - break; - case AggregateOp::AVG: - name = "Avg"; - break; - } - return formatter::format(name, ctx); - } -}; +tl::expected + get_output_shape(EmbeddingAttrs const &, ParallelTensorShape const &); +tl::expected + get_weights_shape(EmbeddingAttrs const &, ParallelTensorShape const &); -} // namespace fmt +} // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/embedding_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.dtg.h new file mode 100644 index 0000000000..f1cae86460 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.dtg.h @@ -0,0 +1,71 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml +/* proj-data +{ + "generated_from": "f2bdea52e23dee6f674f598f8691d994" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_EMBEDDING_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_EMBEDDING_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/aggregate_op.dtg.h" +#include "op-attrs/datatype.dtg.h" +#include "rapidcheck.h" +#include "utils/stack_vector.h" +#include +#include +#include + +namespace FlexFlow { +struct EmbeddingAttrs { + EmbeddingAttrs() = delete; + EmbeddingAttrs(int const &num_entries, + int const &out_channels, + std::optional<::FlexFlow::AggregateOp> const &aggr, + ::FlexFlow::DataType const &data_type); + + bool operator==(EmbeddingAttrs const &) const; + bool operator!=(EmbeddingAttrs const &) const; + bool operator<(EmbeddingAttrs const &) const; + bool operator>(EmbeddingAttrs const &) const; + bool operator<=(EmbeddingAttrs const &) const; + bool operator>=(EmbeddingAttrs const &) const; + int num_entries; + int out_channels; + std::optional<::FlexFlow::AggregateOp> aggr; + ::FlexFlow::DataType data_type; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::EmbeddingAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::EmbeddingAttrs from_json(json const &); + static void to_json(json &, FlexFlow::EmbeddingAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(EmbeddingAttrs const &); +std::ostream &operator<<(std::ostream &, EmbeddingAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_EMBEDDING_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml new file mode 100644 index 0000000000..f0772c351e --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml @@ -0,0 +1,32 @@ +namespace = "FlexFlow" +name = "EmbeddingAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "utils/stack_vector.h", + "op-attrs/aggregate_op.dtg.h", + "op-attrs/datatype.dtg.h", +] + +[[fields]] +name = "num_entries" +type = "int" + +[[fields]] +name = "out_channels" +type = "int" + +[[fields]] +name = "aggr" +type = "std::optional<::FlexFlow::AggregateOp>" + +[[fields]] +name = "data_type" +type = "::FlexFlow::DataType" diff --git a/lib/op-attrs/include/op-attrs/ops/flat.h b/lib/op-attrs/include/op-attrs/ops/flat.h index 706689199d..d5d9069f51 100644 --- a/lib/op-attrs/include/op-attrs/ops/flat.h +++ b/lib/op-attrs/include/op-attrs/ops/flat.h @@ -2,13 +2,11 @@ #define _FLEXFLOW_FLAT_ATTRS_H #include "core.h" +#include "op-attrs/ops/flat_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" namespace FlexFlow { -struct FlatAttrs {}; -FF_VISITABLE_STRUCT(FlatAttrs); CHECK_VALID_OP_ATTR(FlatAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/flat_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/flat_attrs.dtg.h new file mode 100644 index 0000000000..a94c0aeff3 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/flat_attrs.dtg.h @@ -0,0 +1,58 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml +/* proj-data +{ + "generated_from": "b63924cd671481df30fae314a199c606" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_FLAT_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_FLAT_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct FlatAttrs { + bool operator==(FlatAttrs const &) const; + bool operator!=(FlatAttrs const &) const; + bool operator<(FlatAttrs const &) const; + bool operator>(FlatAttrs const &) const; + bool operator<=(FlatAttrs const &) const; + bool operator>=(FlatAttrs const &) const; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::FlatAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::FlatAttrs from_json(json const &); + static void to_json(json &, FlexFlow::FlatAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(FlatAttrs const &); +std::ostream &operator<<(std::ostream &, FlatAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_FLAT_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml new file mode 100644 index 0000000000..e445535e29 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml @@ -0,0 +1,11 @@ +namespace = "FlexFlow" +name = "FlatAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] +fields = [] diff --git a/lib/op-attrs/include/op-attrs/ops/gather.h b/lib/op-attrs/include/op-attrs/ops/gather.h index ca2406ef75..79516a8862 100644 --- a/lib/op-attrs/include/op-attrs/ops/gather.h +++ b/lib/op-attrs/include/op-attrs/ops/gather.h @@ -2,16 +2,11 @@ #define _FLEXFLOW_GATHER_ATTRS_H #include "core.h" -#include "op-attrs/ff_dim.h" +#include "op-attrs/ops/gather_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" namespace FlexFlow { -struct GatherAttrs { - ff_dim_t dim; -}; -FF_VISITABLE_STRUCT(GatherAttrs, dim); CHECK_VALID_OP_ATTR(GatherAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/gather_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/gather_attrs.dtg.h new file mode 100644 index 0000000000..e7a35e5800 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/gather_attrs.dtg.h @@ -0,0 +1,64 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml +/* proj-data +{ + "generated_from": "4ba46b6b494a7a52edda437d2a05fcf1" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_GATHER_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_GATHER_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct GatherAttrs { + GatherAttrs() = delete; + GatherAttrs(::FlexFlow::ff_dim_t const &dim); + + bool operator==(GatherAttrs const &) const; + bool operator!=(GatherAttrs const &) const; + bool operator<(GatherAttrs const &) const; + bool operator>(GatherAttrs const &) const; + bool operator<=(GatherAttrs const &) const; + bool operator>=(GatherAttrs const &) const; + ::FlexFlow::ff_dim_t dim; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::GatherAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::GatherAttrs from_json(json const &); + static void to_json(json &, FlexFlow::GatherAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(GatherAttrs const &); +std::ostream &operator<<(std::ostream &, GatherAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_GATHER_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml new file mode 100644 index 0000000000..c8bb88dcc7 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "GatherAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ff_dim.h", + "op-attrs/ff_dim.dtg.h" +] + +[[fields]] +name = "dim" +type = "::FlexFlow::ff_dim_t" diff --git a/lib/op-attrs/include/op-attrs/ops/input.h b/lib/op-attrs/include/op-attrs/ops/input.h index 26c486c9ac..9fe0ee2c2d 100644 --- a/lib/op-attrs/include/op-attrs/ops/input.h +++ b/lib/op-attrs/include/op-attrs/ops/input.h @@ -2,14 +2,15 @@ #define _FLEXFLOW_OP_ATTRS_OPS_OP_ATTRS_INPUT_H #include "core.h" -#include "utils/visitable.h" +#include "op-attrs/ops/input_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { -struct InputAttrs {}; -FF_VISITABLE_STRUCT(InputAttrs); CHECK_VALID_OP_ATTR(InputAttrs); +ParallelTensorShape get_output_shape(InputAttrs const &); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/input_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/input_attrs.dtg.h new file mode 100644 index 0000000000..aa2ca1e933 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/input_attrs.dtg.h @@ -0,0 +1,58 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/input_attrs.struct.toml +/* proj-data +{ + "generated_from": "139ea46d57a3c8738b31b17a8c59a0aa" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_INPUT_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_INPUT_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct InputAttrs { + bool operator==(InputAttrs const &) const; + bool operator!=(InputAttrs const &) const; + bool operator<(InputAttrs const &) const; + bool operator>(InputAttrs const &) const; + bool operator<=(InputAttrs const &) const; + bool operator>=(InputAttrs const &) const; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::InputAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::InputAttrs from_json(json const &); + static void to_json(json &, FlexFlow::InputAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(InputAttrs const &); +std::ostream &operator<<(std::ostream &, InputAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_INPUT_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/input_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/input_attrs.struct.toml new file mode 100644 index 0000000000..7e29de78df --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/input_attrs.struct.toml @@ -0,0 +1,11 @@ +namespace = "FlexFlow" +name = "InputAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] +fields = [] diff --git a/lib/op-attrs/include/op-attrs/ops/layer_norm.h b/lib/op-attrs/include/op-attrs/ops/layer_norm.h index dab055b2c9..01130139f1 100644 --- a/lib/op-attrs/include/op-attrs/ops/layer_norm.h +++ b/lib/op-attrs/include/op-attrs/ops/layer_norm.h @@ -2,18 +2,14 @@ #define _FLEXFLOW_OP_META_OPS_LAYER_NORM_ATTRS_H #include "core.h" -#include "op-attrs/ff_dim.h" +#include "op-attrs/ops/layer_norm_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" namespace FlexFlow { -struct LayerNormAttrs { - stack_vector axes; - req elementwise_affine; - req eps; -}; -FF_VISITABLE_STRUCT(LayerNormAttrs, axes, elementwise_affine, eps); +ParallelTensorShape get_output_shape(LayerNormAttrs const &, + ParallelTensorShape const &); + CHECK_VALID_OP_ATTR(LayerNormAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.dtg.h new file mode 100644 index 0000000000..c945206863 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.dtg.h @@ -0,0 +1,70 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml +/* proj-data +{ + "generated_from": "349deae8d9356d3eeacd7e7d069c3155" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LAYER_NORM_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LAYER_NORM_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include "rapidcheck.h" +#include "utils/stack_vector.h" +#include +#include +#include + +namespace FlexFlow { +struct LayerNormAttrs { + LayerNormAttrs() = delete; + LayerNormAttrs(::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, + MAX_TENSOR_DIM> const &axes, + bool const &elementwise_affine, + float const &eps); + + bool operator==(LayerNormAttrs const &) const; + bool operator!=(LayerNormAttrs const &) const; + bool operator<(LayerNormAttrs const &) const; + bool operator>(LayerNormAttrs const &) const; + bool operator<=(LayerNormAttrs const &) const; + bool operator>=(LayerNormAttrs const &) const; + ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM> axes; + bool elementwise_affine; + float eps; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::LayerNormAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::LayerNormAttrs from_json(json const &); + static void to_json(json &, FlexFlow::LayerNormAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(LayerNormAttrs const &); +std::ostream &operator<<(std::ostream &, LayerNormAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LAYER_NORM_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml new file mode 100644 index 0000000000..ec60d39f7f --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "LayerNormAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ff_dim.h", + "op-attrs/ff_dim.dtg.h", + "utils/stack_vector.h", +] + +[[fields]] +name = "axes" +type = "::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>" + +[[fields]] +name = "elementwise_affine" +type = "bool" + +[[fields]] +name = "eps" +type = "float" diff --git a/lib/op-attrs/include/op-attrs/ops/linear.h b/lib/op-attrs/include/op-attrs/ops/linear.h index 3b57a959b8..dd6948165e 100644 --- a/lib/op-attrs/include/op-attrs/ops/linear.h +++ b/lib/op-attrs/include/op-attrs/ops/linear.h @@ -1,46 +1,31 @@ #ifndef _FLEXFLOW_LINEAR_ATTRS_H #define _FLEXFLOW_LINEAR_ATTRS_H -#include "op-attrs/activation.h" -#include "op-attrs/datatype.h" #include "op-attrs/ops/core.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/linear_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" +#include namespace FlexFlow { -struct L1RegularizerAttrs { - req lambda; -}; -FF_VISITABLE_STRUCT(L1RegularizerAttrs, lambda); -CHECK_VALID_OP_ATTR(L1RegularizerAttrs); - -struct L2RegularizerAttrs { - req lambda; -}; -FF_VISITABLE_STRUCT(L2RegularizerAttrs, lambda); -CHECK_VALID_OP_ATTR(L2RegularizerAttrs); - -using RegularizerAttrs = std::variant; - -struct LinearAttrs { - int out_channels; - bool use_bias; - DataType data_type; - std::optional activation; - req> regularizer; -}; -FF_VISITABLE_STRUCT( - LinearAttrs, out_channels, use_bias, data_type, activation, regularizer); CHECK_VALID_OP_ATTR(LinearAttrs); -TensorShape get_weights_shape(LinearAttrs const &attrs, - TensorShape const &input); -ParallelTensorShape get_weights_shape(LinearAttrs const &attrs, - ParallelTensorShape const &input); -TensorShape get_bias_shape(LinearAttrs const &attrs, TensorShape const &input); -ParallelTensorShape get_bias_shape(LinearAttrs const &attrs, - ParallelTensorShape const &input); +tl::expected + get_kernel_shape(LinearAttrs const &attrs, TensorShape const &input); +tl::expected get_bias_shape(LinearAttrs const &attrs, + TensorShape const &input); +tl::expected + get_output_shape(LinearAttrs const &attrs, TensorShape const &input); + +tl::expected + get_kernel_shape(LinearAttrs const &attrs, + ParallelTensorShape const &input); +tl::expected + get_bias_shape(LinearAttrs const &attrs, ParallelTensorShape const &input); +tl::expected + get_output_shape(LinearAttrs const &attrs, + ParallelTensorShape const &input); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/linear_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/linear_attrs.dtg.h new file mode 100644 index 0000000000..28cd2a8b33 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/linear_attrs.dtg.h @@ -0,0 +1,74 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml +/* proj-data +{ + "generated_from": "7e82d282f90e08f1e0db7d5c4ce528b7" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LINEAR_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LINEAR_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/activation.dtg.h" +#include "op-attrs/datatype.dtg.h" +#include "op-attrs/regularizer_attrs.dtg.h" +#include "rapidcheck.h" +#include "utils/json.h" +#include +#include +#include + +namespace FlexFlow { +struct LinearAttrs { + LinearAttrs() = delete; + LinearAttrs(int const &out_channels, + bool const &use_bias, + ::FlexFlow::DataType const &data_type, + std::optional<::FlexFlow::Activation> const &activation, + std::optional<::FlexFlow::RegularizerAttrs> const ®ularizer); + + bool operator==(LinearAttrs const &) const; + bool operator!=(LinearAttrs const &) const; + bool operator<(LinearAttrs const &) const; + bool operator>(LinearAttrs const &) const; + bool operator<=(LinearAttrs const &) const; + bool operator>=(LinearAttrs const &) const; + int out_channels; + bool use_bias; + ::FlexFlow::DataType data_type; + std::optional<::FlexFlow::Activation> activation; + std::optional<::FlexFlow::RegularizerAttrs> regularizer; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::LinearAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::LinearAttrs from_json(json const &); + static void to_json(json &, FlexFlow::LinearAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(LinearAttrs const &); +std::ostream &operator<<(std::ostream &, LinearAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LINEAR_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml new file mode 100644 index 0000000000..4ac8f83ec9 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml @@ -0,0 +1,37 @@ +namespace = "FlexFlow" +name = "LinearAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/datatype.dtg.h", + "op-attrs/activation.dtg.h", + "op-attrs/regularizer_attrs.dtg.h", + "utils/json.h", +] + +[[fields]] +name = "out_channels" +type = "int" + +[[fields]] +name = "use_bias" +type = "bool" + +[[fields]] +name = "data_type" +type = "::FlexFlow::DataType" + +[[fields]] +name = "activation" +type = "std::optional<::FlexFlow::Activation>" + +[[fields]] +name = "regularizer" +type = "std::optional<::FlexFlow::RegularizerAttrs>" diff --git a/lib/op-attrs/include/op-attrs/ops/noop.h b/lib/op-attrs/include/op-attrs/ops/noop.h index 658e1b7d98..eb01009259 100644 --- a/lib/op-attrs/include/op-attrs/ops/noop.h +++ b/lib/op-attrs/include/op-attrs/ops/noop.h @@ -2,14 +2,16 @@ #define _FLEXFLOW_OP_ATTRS_OPS_NOOP_H #include "core.h" -#include "utils/visitable.h" +#include "op-attrs/ops/noop_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { -struct NoopAttrs {}; -FF_VISITABLE_STRUCT(NoopAttrs); CHECK_VALID_OP_ATTR(NoopAttrs); +ParallelTensorShape get_output_shape(NoopAttrs const &, + ParallelTensorShape const &); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/noop_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/noop_attrs.dtg.h new file mode 100644 index 0000000000..ed0d8c9348 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/noop_attrs.dtg.h @@ -0,0 +1,58 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/noop_attrs.struct.toml +/* proj-data +{ + "generated_from": "d440077aa598fdad0e5aa95288b63c40" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_NOOP_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_NOOP_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct NoopAttrs { + bool operator==(NoopAttrs const &) const; + bool operator!=(NoopAttrs const &) const; + bool operator<(NoopAttrs const &) const; + bool operator>(NoopAttrs const &) const; + bool operator<=(NoopAttrs const &) const; + bool operator>=(NoopAttrs const &) const; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::NoopAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::NoopAttrs from_json(json const &); + static void to_json(json &, FlexFlow::NoopAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(NoopAttrs const &); +std::ostream &operator<<(std::ostream &, NoopAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_NOOP_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/noop_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/noop_attrs.struct.toml new file mode 100644 index 0000000000..3d9202093c --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/noop_attrs.struct.toml @@ -0,0 +1,11 @@ +namespace = "FlexFlow" +name = "NoopAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] +fields = [] diff --git a/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.dtg.h b/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.dtg.h new file mode 100644 index 0000000000..d3903bd3b2 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.dtg.h @@ -0,0 +1,66 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.struct.toml +/* proj-data +{ + "generated_from": "b76a39763275090d8376e1c27668d2cb" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_PARALLEL_ATTENTION_INPUTS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_PARALLEL_ATTENTION_INPUTS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/parallel_tensor_shape.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct ParallelMultiHeadAttentionInputs { + ParallelMultiHeadAttentionInputs() = delete; + ParallelMultiHeadAttentionInputs( + ::FlexFlow::ParallelTensorShape const &query, + ::FlexFlow::ParallelTensorShape const &key, + ::FlexFlow::ParallelTensorShape const &value); + + bool operator==(ParallelMultiHeadAttentionInputs const &) const; + bool operator!=(ParallelMultiHeadAttentionInputs const &) const; + ::FlexFlow::ParallelTensorShape query; + ::FlexFlow::ParallelTensorShape key; + ::FlexFlow::ParallelTensorShape value; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ParallelMultiHeadAttentionInputs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ParallelMultiHeadAttentionInputs from_json(json const &); + static void to_json(json &, + FlexFlow::ParallelMultiHeadAttentionInputs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(ParallelMultiHeadAttentionInputs const &); +std::ostream &operator<<(std::ostream &, + ParallelMultiHeadAttentionInputs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_PARALLEL_ATTENTION_INPUTS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.struct.toml b/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.struct.toml new file mode 100644 index 0000000000..4809ee998a --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.struct.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "ParallelMultiHeadAttentionInputs" +features = [ + "eq", + # "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/parallel_tensor_shape.h" +] + +[[fields]] +name = "query" +type = "::FlexFlow::ParallelTensorShape" + +[[fields]] +name = "key" +type = "::FlexFlow::ParallelTensorShape" + +[[fields]] +name = "value" +type = "::FlexFlow::ParallelTensorShape" diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d.h b/lib/op-attrs/include/op-attrs/ops/pool_2d.h index efe29b3b2e..162f9aef05 100644 --- a/lib/op-attrs/include/op-attrs/ops/pool_2d.h +++ b/lib/op-attrs/include/op-attrs/ops/pool_2d.h @@ -2,57 +2,16 @@ #define _FLEXFLOW_POOL_2D_ATTRS_H #include "core.h" -#include "op-attrs/activation.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/pool_2d_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { -enum class PoolOp { - MAX, - AVG, -}; - -struct Pool2DAttrs { - req kernel_h, kernel_w, stride_h, stride_w, padding_h, padding_w; - req pool_type; - req activation; -}; -FF_VISITABLE_STRUCT(Pool2DAttrs, - kernel_h, - kernel_w, - stride_h, - stride_w, - padding_h, - padding_w, - pool_type, - activation); CHECK_VALID_OP_ATTR(Pool2DAttrs); -} // namespace FlexFlow - -namespace fmt { +ParallelTensorShape get_output_shape(Pool2DAttrs const &, + ParallelTensorShape const &); -template <> -struct formatter<::FlexFlow::PoolOp> : formatter { - template - auto format(::FlexFlow::PoolOp o, FormatContext &ctx) const - -> decltype(ctx.out()) { - using namespace FlexFlow; - - string_view name = "unknown"; - switch (o) { - case PoolOp::AVG: - name = "Avg"; - break; - case PoolOp::MAX: - name = "Max"; - break; - } - return formatter::format(name, ctx); - } -}; - -} // namespace fmt +} // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.dtg.h new file mode 100644 index 0000000000..a5c6603302 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.dtg.h @@ -0,0 +1,78 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml +/* proj-data +{ + "generated_from": "03aeafe335f68ff831e3e73a77f45caf" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_POOL_2D_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_POOL_2D_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/activation.dtg.h" +#include "op-attrs/pool_op.dtg.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct Pool2DAttrs { + Pool2DAttrs() = delete; + Pool2DAttrs(int const &kernel_h, + int const &kernel_w, + int const &stride_h, + int const &stride_w, + int const &padding_h, + int const &padding_w, + ::FlexFlow::PoolOp const &pool_type, + ::FlexFlow::Activation const &activation); + + bool operator==(Pool2DAttrs const &) const; + bool operator!=(Pool2DAttrs const &) const; + bool operator<(Pool2DAttrs const &) const; + bool operator>(Pool2DAttrs const &) const; + bool operator<=(Pool2DAttrs const &) const; + bool operator>=(Pool2DAttrs const &) const; + int kernel_h; + int kernel_w; + int stride_h; + int stride_w; + int padding_h; + int padding_w; + ::FlexFlow::PoolOp pool_type; + ::FlexFlow::Activation activation; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::Pool2DAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::Pool2DAttrs from_json(json const &); + static void to_json(json &, FlexFlow::Pool2DAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(Pool2DAttrs const &); +std::ostream &operator<<(std::ostream &, Pool2DAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_POOL_2D_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml new file mode 100644 index 0000000000..56bf682f50 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml @@ -0,0 +1,47 @@ +namespace = "FlexFlow" +name = "Pool2DAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/pool_op.dtg.h", + "op-attrs/activation.dtg.h", +] + +[[fields]] +name = "kernel_h" +type = "int" + +[[fields]] +name = "kernel_w" +type = "int" + +[[fields]] +name = "stride_h" +type = "int" + +[[fields]] +name = "stride_w" +type = "int" + +[[fields]] +name = "padding_h" +type = "int" + +[[fields]] +name = "padding_w" +type = "int" + +[[fields]] +name = "pool_type" +type = "::FlexFlow::PoolOp" + +[[fields]] +name = "activation" +type = "::FlexFlow::Activation" diff --git a/lib/op-attrs/include/op-attrs/ops/reduce.h b/lib/op-attrs/include/op-attrs/ops/reduce.h index 193d3b0dc8..04e44b4161 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduce.h +++ b/lib/op-attrs/include/op-attrs/ops/reduce.h @@ -1,23 +1,17 @@ #ifndef _FLEXFLOW_OP_META_OPS_REDUCE_ATTRS_H #define _FLEXFLOW_OP_META_OPS_REDUCE_ATTRS_H -#include "core.h" -#include "op-attrs/ff_dim.h" -#include "op-attrs/op.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "utils/stack_vector.h" -#include "utils/visitable.h" +#include "op-attrs/ops/core.h" +#include "op-attrs/ops/reduce_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { -struct ReduceAttrs { - stack_vector axes; - req op_type; - req keepdims; -}; -FF_VISITABLE_STRUCT(ReduceAttrs, axes, op_type, keepdims); CHECK_VALID_OP_ATTR(ReduceAttrs); +ParallelTensorShape get_output_shape(ReduceAttrs const &, + ParallelTensorShape const &input_shape); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/reduce_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/reduce_attrs.dtg.h new file mode 100644 index 0000000000..af27bf35be --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/reduce_attrs.dtg.h @@ -0,0 +1,71 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml +/* proj-data +{ + "generated_from": "097463446e254f662c7bdf5df4e12d17" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REDUCE_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REDUCE_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include "op-attrs/operator_type.dtg.h" +#include "rapidcheck.h" +#include "utils/stack_vector.h" +#include +#include +#include + +namespace FlexFlow { +struct ReduceAttrs { + ReduceAttrs() = delete; + ReduceAttrs(::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, + MAX_TENSOR_DIM> const &axes, + ::FlexFlow::OperatorType const &op_type, + bool const &keepdims); + + bool operator==(ReduceAttrs const &) const; + bool operator!=(ReduceAttrs const &) const; + bool operator<(ReduceAttrs const &) const; + bool operator>(ReduceAttrs const &) const; + bool operator<=(ReduceAttrs const &) const; + bool operator>=(ReduceAttrs const &) const; + ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM> axes; + ::FlexFlow::OperatorType op_type; + bool keepdims; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ReduceAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ReduceAttrs from_json(json const &); + static void to_json(json &, FlexFlow::ReduceAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(ReduceAttrs const &); +std::ostream &operator<<(std::ostream &, ReduceAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REDUCE_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml new file mode 100644 index 0000000000..717e7954e8 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml @@ -0,0 +1,29 @@ +namespace = "FlexFlow" +name = "ReduceAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/operator_type.dtg.h", + "op-attrs/ff_dim.h", + "op-attrs/ff_dim.dtg.h", + "utils/stack_vector.h", +] + +[[fields]] +name = "axes" +type = "::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>" + +[[fields]] +name = "op_type" +type = "::FlexFlow::OperatorType" + +[[fields]] +name = "keepdims" +type = "bool" diff --git a/lib/op-attrs/include/op-attrs/ops/reduction.h b/lib/op-attrs/include/op-attrs/ops/reduction.h index f848f879fc..a6047b38f9 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduction.h +++ b/lib/op-attrs/include/op-attrs/ops/reduction.h @@ -1,20 +1,19 @@ #ifndef _FLEXFLOW_REDUCTION_ATTRS_H #define _FLEXFLOW_REDUCTION_ATTRS_H -#include "core.h" -#include "op-attrs/ff_dim.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/core.h" +#include "op-attrs/ops/reduction_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include namespace FlexFlow { -struct ReductionAttrs { - ff_dim_t reduction_dim; - req reduction_degree; -}; -FF_VISITABLE_STRUCT(ReductionAttrs, reduction_dim, reduction_degree); CHECK_VALID_OP_ATTR(ReductionAttrs); +tl::expected + get_output_shape(ReductionAttrs const &attrs, + ParallelTensorShape const &input_shape); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/reduction_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/reduction_attrs.dtg.h new file mode 100644 index 0000000000..9de5eb2252 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/reduction_attrs.dtg.h @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/reduction_attrs.struct.toml +/* proj-data +{ + "generated_from": "1d2b5b7cf11ed04a27a6fd8215e4e2a5" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REDUCTION_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REDUCTION_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct ReductionAttrs { + ReductionAttrs() = delete; + ReductionAttrs(int const &reduction_degree); + + bool operator==(ReductionAttrs const &) const; + bool operator!=(ReductionAttrs const &) const; + bool operator<(ReductionAttrs const &) const; + bool operator>(ReductionAttrs const &) const; + bool operator<=(ReductionAttrs const &) const; + bool operator>=(ReductionAttrs const &) const; + int reduction_degree; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ReductionAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ReductionAttrs from_json(json const &); + static void to_json(json &, FlexFlow::ReductionAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(ReductionAttrs const &); +std::ostream &operator<<(std::ostream &, ReductionAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REDUCTION_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/reduction_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/reduction_attrs.struct.toml new file mode 100644 index 0000000000..ee0ae54132 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/reduction_attrs.struct.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "ReductionAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "reduction_degree" +type = "int" diff --git a/lib/op-attrs/include/op-attrs/ops/repartition.h b/lib/op-attrs/include/op-attrs/ops/repartition.h index 83c4ae870b..559e7278f5 100644 --- a/lib/op-attrs/include/op-attrs/ops/repartition.h +++ b/lib/op-attrs/include/op-attrs/ops/repartition.h @@ -1,20 +1,19 @@ #ifndef _FLEXFLOW_PARTITION_ATTRS_H #define _FLEXFLOW_PARTITION_ATTRS_H -#include "core.h" -#include "op-attrs/ff_dim.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/core.h" +#include "op-attrs/ops/repartition_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include namespace FlexFlow { -struct RepartitionAttrs { - ff_dim_t repartition_dim; - req repartition_degree; -}; -FF_VISITABLE_STRUCT(RepartitionAttrs, repartition_dim, repartition_degree); CHECK_VALID_OP_ATTR(RepartitionAttrs); +tl::expected + get_output_shape(RepartitionAttrs const &, + ParallelTensorShape const &input_shape); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/repartition_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/repartition_attrs.dtg.h new file mode 100644 index 0000000000..66c21466f4 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/repartition_attrs.dtg.h @@ -0,0 +1,66 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml +/* proj-data +{ + "generated_from": "0a4d8b435768ce3ee37013fc550c9ebb" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REPARTITION_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REPARTITION_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct RepartitionAttrs { + RepartitionAttrs() = delete; + RepartitionAttrs(::FlexFlow::ff_dim_t const &repartition_dim, + int const &repartition_degree); + + bool operator==(RepartitionAttrs const &) const; + bool operator!=(RepartitionAttrs const &) const; + bool operator<(RepartitionAttrs const &) const; + bool operator>(RepartitionAttrs const &) const; + bool operator<=(RepartitionAttrs const &) const; + bool operator>=(RepartitionAttrs const &) const; + ::FlexFlow::ff_dim_t repartition_dim; + int repartition_degree; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::RepartitionAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::RepartitionAttrs from_json(json const &); + static void to_json(json &, FlexFlow::RepartitionAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(RepartitionAttrs const &); +std::ostream &operator<<(std::ostream &, RepartitionAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REPARTITION_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml new file mode 100644 index 0000000000..25a33c0c15 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "RepartitionAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ff_dim.h", + "op-attrs/ff_dim.dtg.h", +] + +[[fields]] +name = "repartition_dim" +type = "::FlexFlow::ff_dim_t" + +[[fields]] +name = "repartition_degree" +type = "int" diff --git a/lib/op-attrs/include/op-attrs/ops/replicate.h b/lib/op-attrs/include/op-attrs/ops/replicate.h index 92e64a4120..4c46bf88a9 100644 --- a/lib/op-attrs/include/op-attrs/ops/replicate.h +++ b/lib/op-attrs/include/op-attrs/ops/replicate.h @@ -2,19 +2,16 @@ #define _FLEXFLOW_REPLICATE_ATTRS_H #include "core.h" -#include "op-attrs/ff_dim.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/replicate_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { -struct ReplicateAttrs { - ff_dim_t replicate_dim; - req replicate_degree; -}; -FF_VISITABLE_STRUCT(ReplicateAttrs, replicate_dim, replicate_degree); CHECK_VALID_OP_ATTR(ReplicateAttrs); +ParallelTensorShape get_output_shape(ReplicateAttrs const &attrs, + ParallelTensorShape const &input_shape); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/replicate_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/replicate_attrs.dtg.h new file mode 100644 index 0000000000..ea3f0d46c7 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/replicate_attrs.dtg.h @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/replicate_attrs.struct.toml +/* proj-data +{ + "generated_from": "6d3ad4d10c24dae819ffee4592a72499" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REPLICATE_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REPLICATE_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct ReplicateAttrs { + ReplicateAttrs() = delete; + ReplicateAttrs(int const &replicate_degree); + + bool operator==(ReplicateAttrs const &) const; + bool operator!=(ReplicateAttrs const &) const; + bool operator<(ReplicateAttrs const &) const; + bool operator>(ReplicateAttrs const &) const; + bool operator<=(ReplicateAttrs const &) const; + bool operator>=(ReplicateAttrs const &) const; + int replicate_degree; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ReplicateAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ReplicateAttrs from_json(json const &); + static void to_json(json &, FlexFlow::ReplicateAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(ReplicateAttrs const &); +std::ostream &operator<<(std::ostream &, ReplicateAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REPLICATE_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/replicate_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/replicate_attrs.struct.toml new file mode 100644 index 0000000000..4e43ea747a --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/replicate_attrs.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "ReplicateAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ ] + +[[fields]] +name = "replicate_degree" +type = "int" diff --git a/lib/op-attrs/include/op-attrs/ops/reshape.h b/lib/op-attrs/include/op-attrs/ops/reshape.h index b118482a2b..cd2ca80c3a 100644 --- a/lib/op-attrs/include/op-attrs/ops/reshape.h +++ b/lib/op-attrs/include/op-attrs/ops/reshape.h @@ -2,17 +2,16 @@ #define _FLEXFLOW_RESHAPE_ATTRS_H #include "core.h" -#include "op-attrs/tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/reshape_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { -struct ReshapeAttrs { - TensorShape shape; -}; -FF_VISITABLE_STRUCT(ReshapeAttrs, shape); CHECK_VALID_OP_ATTR(ReshapeAttrs); +ParallelTensorShape get_output_shape(ReshapeAttrs const &attrs, + ParallelTensorShape const &input_shape); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/reshape_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/reshape_attrs.dtg.h new file mode 100644 index 0000000000..612874790f --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/reshape_attrs.dtg.h @@ -0,0 +1,63 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/reshape_attrs.struct.toml +/* proj-data +{ + "generated_from": "015d04de0ccb982e7eaa013a842880ca" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_RESHAPE_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_RESHAPE_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/tensor_shape.dtg.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct ReshapeAttrs { + ReshapeAttrs() = delete; + ReshapeAttrs(::FlexFlow::TensorShape const &shape); + + bool operator==(ReshapeAttrs const &) const; + bool operator!=(ReshapeAttrs const &) const; + bool operator<(ReshapeAttrs const &) const; + bool operator>(ReshapeAttrs const &) const; + bool operator<=(ReshapeAttrs const &) const; + bool operator>=(ReshapeAttrs const &) const; + ::FlexFlow::TensorShape shape; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ReshapeAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ReshapeAttrs from_json(json const &); + static void to_json(json &, FlexFlow::ReshapeAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(ReshapeAttrs const &); +std::ostream &operator<<(std::ostream &, ReshapeAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_RESHAPE_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/reshape_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/reshape_attrs.struct.toml new file mode 100644 index 0000000000..69ac761859 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/reshape_attrs.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "ReshapeAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/tensor_shape.dtg.h", +] + +[[fields]] +name = "shape" +type = "::FlexFlow::TensorShape" diff --git a/lib/op-attrs/include/op-attrs/ops/reverse.h b/lib/op-attrs/include/op-attrs/ops/reverse.h index 6030285f14..adc62dc9ae 100644 --- a/lib/op-attrs/include/op-attrs/ops/reverse.h +++ b/lib/op-attrs/include/op-attrs/ops/reverse.h @@ -2,17 +2,16 @@ #define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REVERSE_H #include "core.h" -#include "op-attrs/ff_dim.h" -#include "utils/visitable.h" +#include "op-attrs/ops/reverse_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { -struct ReverseAttrs { - ff_dim_t axis; -}; -FF_VISITABLE_STRUCT(ReverseAttrs, axis); CHECK_VALID_OP_ATTR(ReverseAttrs); +ParallelTensorShape get_output_shape(ReverseAttrs const &attrs, + ParallelTensorShape const &input_shape); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/reverse_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/reverse_attrs.dtg.h new file mode 100644 index 0000000000..8c8c8a7a9e --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/reverse_attrs.dtg.h @@ -0,0 +1,64 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml +/* proj-data +{ + "generated_from": "c5a82c8a15ac3ce6f47dc054236ab69b" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REVERSE_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REVERSE_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct ReverseAttrs { + ReverseAttrs() = delete; + ReverseAttrs(::FlexFlow::ff_dim_t const &axis); + + bool operator==(ReverseAttrs const &) const; + bool operator!=(ReverseAttrs const &) const; + bool operator<(ReverseAttrs const &) const; + bool operator>(ReverseAttrs const &) const; + bool operator<=(ReverseAttrs const &) const; + bool operator>=(ReverseAttrs const &) const; + ::FlexFlow::ff_dim_t axis; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ReverseAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ReverseAttrs from_json(json const &); + static void to_json(json &, FlexFlow::ReverseAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(ReverseAttrs const &); +std::ostream &operator<<(std::ostream &, ReverseAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REVERSE_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml new file mode 100644 index 0000000000..198346e5dd --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "ReverseAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ff_dim.h", + "op-attrs/ff_dim.dtg.h", +] + +[[fields]] +name = "axis" +type = "::FlexFlow::ff_dim_t" diff --git a/lib/op-attrs/include/op-attrs/ops/softmax.h b/lib/op-attrs/include/op-attrs/ops/softmax.h index 9a776737f5..d855716cfb 100644 --- a/lib/op-attrs/include/op-attrs/ops/softmax.h +++ b/lib/op-attrs/include/op-attrs/ops/softmax.h @@ -2,18 +2,16 @@ #define _FLEXFLOW_SOFTMAX_ATTRS_H #include "core.h" -#include "op-attrs/ff_dim.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/softmax_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { -struct SoftmaxAttrs { - ff_dim_t dim; -}; -FF_VISITABLE_STRUCT(SoftmaxAttrs, dim); CHECK_VALID_OP_ATTR(SoftmaxAttrs); +ParallelTensorShape get_output_shape(SoftmaxAttrs const &attrs, + ParallelTensorShape const &input_shape); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/softmax_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/softmax_attrs.dtg.h new file mode 100644 index 0000000000..1c855d90f4 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/softmax_attrs.dtg.h @@ -0,0 +1,64 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml +/* proj-data +{ + "generated_from": "2ddf5a8b7daa32a43387f5fd5866bb3b" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_SOFTMAX_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_SOFTMAX_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct SoftmaxAttrs { + SoftmaxAttrs() = delete; + SoftmaxAttrs(::FlexFlow::ff_dim_t const &dim); + + bool operator==(SoftmaxAttrs const &) const; + bool operator!=(SoftmaxAttrs const &) const; + bool operator<(SoftmaxAttrs const &) const; + bool operator>(SoftmaxAttrs const &) const; + bool operator<=(SoftmaxAttrs const &) const; + bool operator>=(SoftmaxAttrs const &) const; + ::FlexFlow::ff_dim_t dim; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::SoftmaxAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::SoftmaxAttrs from_json(json const &); + static void to_json(json &, FlexFlow::SoftmaxAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(SoftmaxAttrs const &); +std::ostream &operator<<(std::ostream &, SoftmaxAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_SOFTMAX_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml new file mode 100644 index 0000000000..8b839c122a --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "SoftmaxAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ff_dim.h", + "op-attrs/ff_dim.dtg.h", +] + +[[fields]] +name = "dim" +type = "::FlexFlow::ff_dim_t" diff --git a/lib/op-attrs/include/op-attrs/ops/split.h b/lib/op-attrs/include/op-attrs/ops/split.h index fa66bc46f5..8fc2257760 100644 --- a/lib/op-attrs/include/op-attrs/ops/split.h +++ b/lib/op-attrs/include/op-attrs/ops/split.h @@ -2,19 +2,18 @@ #define _FLEXFLOW_SPLIT_ATTRS_H #include "core.h" -#include "op-attrs/ff_dim.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/split_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include namespace FlexFlow { -struct SplitAttrs { - req> splits; - ff_dim_t axis; -}; -FF_VISITABLE_STRUCT(SplitAttrs, splits, axis); CHECK_VALID_OP_ATTR(SplitAttrs); +std::vector + get_output_shapes(SplitAttrs const &attrs, + ParallelTensorShape const &input_shape); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/split_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/split_attrs.dtg.h new file mode 100644 index 0000000000..b602015e2e --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/split_attrs.dtg.h @@ -0,0 +1,67 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml +/* proj-data +{ + "generated_from": "cde6b5caf6739d3b02fe8fce0d8ae8c5" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_SPLIT_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_SPLIT_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include "rapidcheck.h" +#include "utils/stack_vector.h" +#include +#include +#include + +namespace FlexFlow { +struct SplitAttrs { + SplitAttrs() = delete; + SplitAttrs(::FlexFlow::stack_vector const &splits, + ::FlexFlow::ff_dim_t const &axis); + + bool operator==(SplitAttrs const &) const; + bool operator!=(SplitAttrs const &) const; + bool operator<(SplitAttrs const &) const; + bool operator>(SplitAttrs const &) const; + bool operator<=(SplitAttrs const &) const; + bool operator>=(SplitAttrs const &) const; + ::FlexFlow::stack_vector splits; + ::FlexFlow::ff_dim_t axis; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::SplitAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::SplitAttrs from_json(json const &); + static void to_json(json &, FlexFlow::SplitAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(SplitAttrs const &); +std::ostream &operator<<(std::ostream &, SplitAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_SPLIT_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml new file mode 100644 index 0000000000..8cdf7728af --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "SplitAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "utils/stack_vector.h", + "op-attrs/ff_dim.h", + "op-attrs/ff_dim.dtg.h", +] + +[[fields]] +name = "splits" +type = "::FlexFlow::stack_vector" + +[[fields]] +name = "axis" +type = "::FlexFlow::ff_dim_t" diff --git a/lib/op-attrs/include/op-attrs/ops/topk.h b/lib/op-attrs/include/op-attrs/ops/topk.h index 413855913c..c6af40dd48 100644 --- a/lib/op-attrs/include/op-attrs/ops/topk.h +++ b/lib/op-attrs/include/op-attrs/ops/topk.h @@ -2,18 +2,16 @@ #define _FLEXFLOW_TOPK_ATTRS_H #include "core.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/topk_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { -struct TopKAttrs { - req k; - req sorted; -}; -FF_VISITABLE_STRUCT(TopKAttrs, k, sorted); CHECK_VALID_OP_ATTR(TopKAttrs); +ParallelTensorShape get_output_shape(TopKAttrs const &attrs, + ParallelTensorShape const &input_shape); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/topk_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/topk_attrs.dtg.h new file mode 100644 index 0000000000..d1f32f67b7 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/topk_attrs.dtg.h @@ -0,0 +1,63 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/topk_attrs.struct.toml +/* proj-data +{ + "generated_from": "c1be9dc2acafc58690713e650663cc93" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_TOPK_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_TOPK_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct TopKAttrs { + TopKAttrs() = delete; + TopKAttrs(int const &k, bool const &sorted); + + bool operator==(TopKAttrs const &) const; + bool operator!=(TopKAttrs const &) const; + bool operator<(TopKAttrs const &) const; + bool operator>(TopKAttrs const &) const; + bool operator<=(TopKAttrs const &) const; + bool operator>=(TopKAttrs const &) const; + int k; + bool sorted; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::TopKAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::TopKAttrs from_json(json const &); + static void to_json(json &, FlexFlow::TopKAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(TopKAttrs const &); +std::ostream &operator<<(std::ostream &, TopKAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_TOPK_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/topk_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/topk_attrs.struct.toml new file mode 100644 index 0000000000..9ecbf1d725 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/topk_attrs.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "TopKAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "k" +type = "int" + +[[fields]] +name = "sorted" +type = "bool" diff --git a/lib/op-attrs/include/op-attrs/ops/transpose.h b/lib/op-attrs/include/op-attrs/ops/transpose.h index 87db435979..6e23d91d78 100644 --- a/lib/op-attrs/include/op-attrs/ops/transpose.h +++ b/lib/op-attrs/include/op-attrs/ops/transpose.h @@ -2,18 +2,16 @@ #define _FLEXFLOW_OP_META_OPS_TRANSPOSE_ATTRS_H #include "core.h" -#include "op-attrs/ff_dim.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/transpose_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { -struct TransposeAttrs { - req> perm; -}; -FF_VISITABLE_STRUCT(TransposeAttrs, perm); CHECK_VALID_OP_ATTR(TransposeAttrs); +ParallelTensorShape get_output_shape(TransposeAttrs const &op_attrs, + ParallelTensorShape const &input_shape); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/transpose_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.dtg.h new file mode 100644 index 0000000000..f4d932845f --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.dtg.h @@ -0,0 +1,65 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml +/* proj-data +{ + "generated_from": "de62a505821a59c4b77197c100e204f7" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_TRANSPOSE_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_TRANSPOSE_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/dim_ordered.h" +#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct TransposeAttrs { + TransposeAttrs() = delete; + TransposeAttrs(::FlexFlow::FFOrdered<::FlexFlow::ff_dim_t> const &perm); + + bool operator==(TransposeAttrs const &) const; + bool operator!=(TransposeAttrs const &) const; + bool operator<(TransposeAttrs const &) const; + bool operator>(TransposeAttrs const &) const; + bool operator<=(TransposeAttrs const &) const; + bool operator>=(TransposeAttrs const &) const; + ::FlexFlow::FFOrdered<::FlexFlow::ff_dim_t> perm; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::TransposeAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::TransposeAttrs from_json(json const &); + static void to_json(json &, FlexFlow::TransposeAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(TransposeAttrs const &); +std::ostream &operator<<(std::ostream &, TransposeAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_TRANSPOSE_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml new file mode 100644 index 0000000000..756091f653 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "TransposeAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ff_dim.h", + "op-attrs/ff_dim.dtg.h", + "op-attrs/dim_ordered.h", +] + +[[fields]] +name = "perm" +type = "::FlexFlow::FFOrdered<::FlexFlow::ff_dim_t>" diff --git a/lib/op-attrs/include/op-attrs/ops/weight_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/weight_attrs.dtg.h new file mode 100644 index 0000000000..4a19909c25 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/weight_attrs.dtg.h @@ -0,0 +1,58 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/weight_attrs.struct.toml +/* proj-data +{ + "generated_from": "59f49374ffca95b2117b8940af1b6cac" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_WEIGHT_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_WEIGHT_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct WeightAttrs { + bool operator==(WeightAttrs const &) const; + bool operator!=(WeightAttrs const &) const; + bool operator<(WeightAttrs const &) const; + bool operator>(WeightAttrs const &) const; + bool operator<=(WeightAttrs const &) const; + bool operator>=(WeightAttrs const &) const; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::WeightAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::WeightAttrs from_json(json const &); + static void to_json(json &, FlexFlow::WeightAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(WeightAttrs const &); +std::ostream &operator<<(std::ostream &, WeightAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_WEIGHT_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/weight_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/weight_attrs.struct.toml new file mode 100644 index 0000000000..28810a437e --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/weight_attrs.struct.toml @@ -0,0 +1,11 @@ +namespace = "FlexFlow" +name = "WeightAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] +fields = [] diff --git a/lib/op-attrs/include/op-attrs/parallel_dim.dtg.h b/lib/op-attrs/include/op-attrs/parallel_dim.dtg.h new file mode 100644 index 0000000000..4115d4ce1f --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_dim.dtg.h @@ -0,0 +1,128 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/parallel_dim.variant.toml +/* proj-data +{ + "generated_from": "f382ff547aae62777e5091f00d034d84" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_DIM_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_DIM_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/replica_parallel_dim.dtg.h" +#include "op-attrs/shard_parallel_dim.dtg.h" +#include "rapidcheck.h" +#include +#include +#include +#include +#include + +namespace FlexFlow { +struct ParallelDim { + ParallelDim() = delete; + explicit ParallelDim(::FlexFlow::ShardParallelDim const &); + explicit ParallelDim(::FlexFlow::ReplicaParallelDim const &); + template + static constexpr bool IsPartOfParallelDim_v = + std::is_same_v || + std::is_same_v; + template + ReturnType visit(Visitor &&v) const { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::ShardParallelDim>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::ReplicaParallelDim>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type ParallelDim", this->index())); + } + } + } + template + ReturnType visit(Visitor &&v) { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::ShardParallelDim>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::ReplicaParallelDim>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type ParallelDim", this->index())); + } + } + } + template + bool has() const { + static_assert( + IsPartOfParallelDim_v, + "ParallelDim::has() expected one of [::FlexFlow::ShardParallelDim, " + "::FlexFlow::ReplicaParallelDim], received T"); + return std::holds_alternative(this->raw_variant); + } + template + T const &get() const { + static_assert( + IsPartOfParallelDim_v, + "ParallelDim::get() expected one of [::FlexFlow::ShardParallelDim, " + "::FlexFlow::ReplicaParallelDim], received T"); + return std::get(this->raw_variant); + } + template + T &get() { + static_assert( + IsPartOfParallelDim_v, + "ParallelDim::get() expected one of [::FlexFlow::ShardParallelDim, " + "::FlexFlow::ReplicaParallelDim], received T"); + return std::get(this->raw_variant); + } + size_t index() const { + return this->raw_variant.index(); + } + bool operator==(ParallelDim const &) const; + bool operator!=(ParallelDim const &) const; + bool operator<(ParallelDim const &) const; + bool operator>(ParallelDim const &) const; + bool operator<=(ParallelDim const &) const; + bool operator>=(ParallelDim const &) const; + std::variant<::FlexFlow::ShardParallelDim, ::FlexFlow::ReplicaParallelDim> + raw_variant; +}; +} // namespace FlexFlow +namespace std { +template <> +struct hash<::FlexFlow::ParallelDim> { + size_t operator()(::FlexFlow::ParallelDim const &) const; +}; +} // namespace std +namespace nlohmann { +template <> +struct adl_serializer<::FlexFlow::ParallelDim> { + static ::FlexFlow::ParallelDim from_json(json const &); + static void to_json(json &, ::FlexFlow::ParallelDim const &); +}; +} // namespace nlohmann +namespace rc { +template <> +struct Arbitrary<::FlexFlow::ParallelDim> { + static Gen<::FlexFlow::ParallelDim> arbitrary(); +}; +} // namespace rc +namespace FlexFlow { +std::string format_as(::FlexFlow::ParallelDim const &); +std::ostream &operator<<(std::ostream &, ::FlexFlow::ParallelDim const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_DIM_DTG_H diff --git a/lib/op-attrs/include/op-attrs/parallel_dim.h b/lib/op-attrs/include/op-attrs/parallel_dim.h index 9d407ec469..5397ad7c68 100644 --- a/lib/op-attrs/include/op-attrs/parallel_dim.h +++ b/lib/op-attrs/include/op-attrs/parallel_dim.h @@ -1,24 +1,17 @@ #ifndef _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_DIM_H #define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_DIM_H -#include "utils/type_traits.h" -#include "utils/visitable.h" +#include "op-attrs/parallel_dim.dtg.h" namespace FlexFlow { -struct ParallelDim { - size_t size; - int degree; - req is_replica_dim; -}; -FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(ParallelDim, - size, - degree, - is_replica_dim); - bool is_valid(ParallelDim const &); bool is_replica_dim(ParallelDim const &); +ParallelDim with_size_set_to(ParallelDim const &, size_t); +ParallelDim with_degree_set_to(ParallelDim const &, int); +ParallelDim with_is_replica_set_to(ParallelDim const &, bool); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/parallel_dim.variant.toml b/lib/op-attrs/include/op-attrs/parallel_dim.variant.toml new file mode 100644 index 0000000000..e27e6509fe --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_dim.variant.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "ParallelDim" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/shard_parallel_dim.dtg.h", + "op-attrs/replica_parallel_dim.dtg.h", +] + +[[values]] +type = "::FlexFlow::ShardParallelDim" +key = "shard_dim" + +[[values]] +type = "::FlexFlow::ReplicaParallelDim" +key = "replica_dim" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.dtg.h b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.dtg.h new file mode 100644 index 0000000000..71ad517095 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.dtg.h @@ -0,0 +1,71 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml +/* proj-data +{ + "generated_from": "aec3b6b66e34be0d5ce3055822479430" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIMS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIMS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/dim_ordered.h" +#include "op-attrs/replica_parallel_dim_set.dtg.h" +#include "op-attrs/shard_parallel_dim.dtg.h" +#include "rapidcheck.h" +#include "utils/fmt/pair.h" +#include "utils/fmt/unordered_map.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct ParallelTensorDims { + ParallelTensorDims() = delete; + ParallelTensorDims( + ::FlexFlow::FFOrdered<::FlexFlow::ShardParallelDim> const &shard_dims, + ::FlexFlow::ReplicaParallelDimSet const &replica_dims); + + bool operator==(ParallelTensorDims const &) const; + bool operator!=(ParallelTensorDims const &) const; + bool operator<(ParallelTensorDims const &) const; + bool operator>(ParallelTensorDims const &) const; + bool operator<=(ParallelTensorDims const &) const; + bool operator>=(ParallelTensorDims const &) const; + ::FlexFlow::FFOrdered<::FlexFlow::ShardParallelDim> shard_dims; + ::FlexFlow::ReplicaParallelDimSet replica_dims; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ParallelTensorDims const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ParallelTensorDims from_json(json const &); + static void to_json(json &, FlexFlow::ParallelTensorDims const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(ParallelTensorDims const &); +std::ostream &operator<<(std::ostream &, ParallelTensorDims const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIMS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h index d38ba75232..8e02e3607b 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h @@ -1,53 +1,32 @@ #ifndef _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIMS_H #define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIMS_H -#include "parallel_dim.h" -#include "utils/visitable.h" +#include "op-attrs/parallel_dim.h" +#include "op-attrs/parallel_tensor_dims.dtg.h" +#include "op-attrs/tensor_dims.dtg.h" namespace FlexFlow { -struct ParallelTensorDims : public use_visitable_cmp { - explicit ParallelTensorDims(TensorDims const &); - - size_t get_volume() const; - size_t num_dims() const; - - using iterator = typename FFOrdered::iterator; - using const_iterator = typename FFOrdered::const_iterator; - using reverse_iterator = typename FFOrdered::reverse_iterator; - using const_reverse_iterator = - typename FFOrdered::const_reverse_iterator; - using value_type = typename FFOrdered::value_type; - using pointer = typename FFOrdered::pointer; - using const_pointer = typename FFOrdered::const_pointer; - - ParallelDim const &at(ff_dim_t const &) const; - ParallelDim &at(ff_dim_t const &); - - iterator begin(); - const_iterator begin() const; - const_iterator cbegin() const; - iterator end(); - const_iterator end() const; - const_iterator cend() const; - reverse_iterator rbegin(); - const_reverse_iterator rbegin() const; - const_reverse_iterator crbegin() const; - reverse_iterator rend(); - const_reverse_iterator rend() const; - const_reverse_iterator crend() const; - -public: - FFOrdered data; -}; +FFOrdered ff_ordered_shard_dims(ParallelTensorDims const &); +FFOrdered ff_ordered_shard_degrees(ParallelTensorDims const &); +std::unordered_set replica_dims(ParallelTensorDims const &); + +/* size_t get_volume(ParallelTensorDims const &); */ +size_t num_shard_dims(ParallelTensorDims const &); + +int total_replica_degree(ParallelTensorDims const &); +int total_shard_degree(ParallelTensorDims const &); +int total_parallel_degree(ParallelTensorDims const &); + +ShardParallelDim shard_dim_at_idx(ParallelTensorDims const &, ff_dim_t); +ShardParallelDim &shard_dim_at_idx(ParallelTensorDims &, ff_dim_t); bool is_valid(ParallelTensorDims const &); TensorDims get_piece_dims(ParallelTensorDims const &); TensorDims get_tensor_dims_unsafe(ParallelTensorDims const &); -} // namespace FlexFlow +TensorDims get_reduced_dims(ParallelTensorDims const &); -VISITABLE_STRUCT(::FlexFlow::ParallelTensorDims, data); -MAKE_VISIT_HASHABLE(::FlexFlow::ParallelTensorDims); +} // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml new file mode 100644 index 0000000000..ae6eab1e58 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "ParallelTensorDims" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/dim_ordered.h", + "op-attrs/shard_parallel_dim.dtg.h", + "op-attrs/replica_parallel_dim_set.dtg.h", + "", + "utils/fmt/unordered_map.h", + "utils/fmt/pair.h", +] + +[[fields]] +name = "shard_dims" +type = "::FlexFlow::FFOrdered<::FlexFlow::ShardParallelDim>" + +[[fields]] +name = "replica_dims" +type = "::FlexFlow::ReplicaParallelDimSet" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.dtg.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.dtg.h new file mode 100644 index 0000000000..62d291fa4f --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.dtg.h @@ -0,0 +1,66 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml +/* proj-data +{ + "generated_from": "06d657d1e95f34aebf4b721c768cbee8" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SHAPE_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SHAPE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/datatype.h" +#include "op-attrs/parallel_tensor_dims.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct ParallelTensorShape { + ParallelTensorShape() = delete; + ParallelTensorShape(::FlexFlow::ParallelTensorDims const &dims, + ::FlexFlow::DataType const &data_type); + + bool operator==(ParallelTensorShape const &) const; + bool operator!=(ParallelTensorShape const &) const; + bool operator<(ParallelTensorShape const &) const; + bool operator>(ParallelTensorShape const &) const; + bool operator<=(ParallelTensorShape const &) const; + bool operator>=(ParallelTensorShape const &) const; + ::FlexFlow::ParallelTensorDims dims; + ::FlexFlow::DataType data_type; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ParallelTensorShape const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ParallelTensorShape from_json(json const &); + static void to_json(json &, FlexFlow::ParallelTensorShape const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(ParallelTensorShape const &); +std::ostream &operator<<(std::ostream &, ParallelTensorShape const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SHAPE_DTG_H diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h index fd560352bb..99be635ffc 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h @@ -1,55 +1,47 @@ #ifndef _OP_META_PARALLEL_TENSOR_SHAPE_H #define _OP_META_PARALLEL_TENSOR_SHAPE_H -#include "datatype.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.h" -#include "parallel_tensor_dims.h" -#include "utils/bidict.h" -#include "utils/record_formatter.h" -#include "utils/stack_vector.h" -#include "utils/visitable.h" -#include #include namespace FlexFlow { -/** - * @brief Represent the shape of a ParallelTensor. - */ -struct ParallelTensorShape : public use_visitable_cmp { - ParallelTensorShape() = delete; +int num_shard_dims(ParallelTensorShape const &); +ShardParallelDim shard_dim_at_idx(ParallelTensorShape const &, ff_dim_t); +ShardParallelDim &shard_dim_at_idx(ParallelTensorShape &, ff_dim_t); - template - ParallelTensorShape(Dims const &dims, DataType data_type) - : dims(dims), data_type(data_type) {} +FFOrdered ff_ordered_shard_degrees(ParallelTensorShape const &); - ParallelTensorShape(TensorShape const &); +std::optional + try_get_shard_dim_at_idx(ParallelTensorShape const &, ff_dim_t); - int num_dims() const; - - ParallelDim const &at(ff_dim_t const &) const; - ParallelDim &at(ff_dim_t const &); - ParallelDim const &operator[](ff_dim_t const &) const; - ParallelDim &operator[](ff_dim_t const &); - -public: - ParallelTensorDims dims; - DataType data_type; -}; +ParallelTensorShape lift_to_parallel(TensorShape const &); +ParallelTensorShape + lift_to_parallel_with_degrees(TensorShape const &, + SumDegree sum_degree, + DiscardCopyDegree discard_copy_degree, + FFOrdered const &shard_degrees); +std::unordered_set + replica_dims(ParallelTensorShape const &); TensorShape get_piece_shape(ParallelTensorShape const &); int get_num_replica_dims(ParallelTensorShape const &); int get_num_replicas(ParallelTensorShape const &); +int get_sum_degree(ParallelTensorShape const &); +int get_discard_copy_degree(ParallelTensorShape const &); + +int get_total_parallel_degree(ParallelTensorShape const &); + bool is_valid(ParallelTensorShape const &); TensorShape get_tensor_shape_unsafe(ParallelTensorShape const &); std::vector get_tensor_shapes_unsafe(std::vector const &); -} // namespace FlexFlow +TensorShape get_reduced_shape(ParallelTensorShape const &); -VISITABLE_STRUCT(::FlexFlow::ParallelTensorShape, data_type, dims); -MAKE_VISIT_HASHABLE(::FlexFlow::ParallelTensorShape); +} // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml new file mode 100644 index 0000000000..e6197bcd51 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "ParallelTensorShape" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/parallel_tensor_dims.h", + "op-attrs/datatype.h", +] + +[[fields]] +name = "dims" +type = "::FlexFlow::ParallelTensorDims" + +[[fields]] +name = "data_type" +type = "::FlexFlow::DataType" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h new file mode 100644 index 0000000000..a820bfe81c --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/parallel_tensor_shape/discard_copy_degree.struct.toml +/* proj-data +{ + "generated_from": "e4677d1fb25d3833570ee567f5659914" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SHAPE_DISCARD_COPY_DEGREE_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SHAPE_DISCARD_COPY_DEGREE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct DiscardCopyDegree { + DiscardCopyDegree() = delete; + DiscardCopyDegree(int const &value); + + bool operator==(DiscardCopyDegree const &) const; + bool operator!=(DiscardCopyDegree const &) const; + bool operator<(DiscardCopyDegree const &) const; + bool operator>(DiscardCopyDegree const &) const; + bool operator<=(DiscardCopyDegree const &) const; + bool operator>=(DiscardCopyDegree const &) const; + int value; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::DiscardCopyDegree const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::DiscardCopyDegree from_json(json const &); + static void to_json(json &, FlexFlow::DiscardCopyDegree const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(DiscardCopyDegree const &); +std::ostream &operator<<(std::ostream &, DiscardCopyDegree const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SHAPE_DISCARD_COPY_DEGREE_DTG_H diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape/discard_copy_degree.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_shape/discard_copy_degree.struct.toml new file mode 100644 index 0000000000..b4905fb0ce --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape/discard_copy_degree.struct.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "DiscardCopyDegree" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "value" +type = "int" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape/sum_degree.dtg.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape/sum_degree.dtg.h new file mode 100644 index 0000000000..17388f8d05 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape/sum_degree.dtg.h @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/parallel_tensor_shape/sum_degree.struct.toml +/* proj-data +{ + "generated_from": "e94a05618f2ad92dd7b3328a1d9c6786" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SHAPE_SUM_DEGREE_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SHAPE_SUM_DEGREE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct SumDegree { + SumDegree() = delete; + SumDegree(int const &value); + + bool operator==(SumDegree const &) const; + bool operator!=(SumDegree const &) const; + bool operator<(SumDegree const &) const; + bool operator>(SumDegree const &) const; + bool operator<=(SumDegree const &) const; + bool operator>=(SumDegree const &) const; + int value; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::SumDegree const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::SumDegree from_json(json const &); + static void to_json(json &, FlexFlow::SumDegree const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(SumDegree const &); +std::ostream &operator<<(std::ostream &, SumDegree const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SHAPE_SUM_DEGREE_DTG_H diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape/sum_degree.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_shape/sum_degree.struct.toml new file mode 100644 index 0000000000..d86917211e --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape/sum_degree.struct.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "SumDegree" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "value" +type = "int" diff --git a/lib/op-attrs/include/op-attrs/param_sync.dtg.h b/lib/op-attrs/include/op-attrs/param_sync.dtg.h new file mode 100644 index 0000000000..785105fbc4 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/param_sync.dtg.h @@ -0,0 +1,40 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/param_sync.enum.toml +/* proj-data +{ + "generated_from": "288c6e9e256cf58ba5dbd0e3791c08df" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARAM_SYNC_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARAM_SYNC_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +enum class ParamSync { PS, NCCL }; +std::string format_as(ParamSync); +std::ostream &operator<<(std::ostream &, ParamSync); +void to_json(::nlohmann::json &, ParamSync); +void from_json(::nlohmann::json const &, ParamSync &); +} // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ParamSync) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARAM_SYNC_DTG_H diff --git a/lib/op-attrs/include/op-attrs/param_sync.enum.toml b/lib/op-attrs/include/op-attrs/param_sync.enum.toml new file mode 100644 index 0000000000..b16a47ab3c --- /dev/null +++ b/lib/op-attrs/include/op-attrs/param_sync.enum.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "ParamSync" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "PS" + +[[values]] +name = "NCCL" diff --git a/lib/op-attrs/include/op-attrs/param_sync.h b/lib/op-attrs/include/op-attrs/param_sync.h index bfae1e712b..dd7048ff36 100644 --- a/lib/op-attrs/include/op-attrs/param_sync.h +++ b/lib/op-attrs/include/op-attrs/param_sync.h @@ -1,36 +1,8 @@ #ifndef _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_PARAM_SYNC_H #define _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_PARAM_SYNC_H -#include "utils/fmt.h" +#include "param_sync_t.h" -namespace FlexFlow { - -enum class ParamSync { PS, NCCL }; - -} - -namespace fmt { - -template <> -struct formatter<::FlexFlow::ParamSync> : formatter { - template - auto format(::FlexFlow::ParamSync ps, FormatContext &ctx) const - -> decltype(ctx.out()) { - using namespace FlexFlow; - - string_view name = "unknown"; - switch (ps) { - case ParamSync::PS: - name = "ParameterServer"; - break; - case ParamSync::NCCL: - name = "NCCL"; - break; - } - return formatter::format(name, ctx); - } -}; - -} // namespace fmt +namespace FlexFlow {} #endif diff --git a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.h b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.h new file mode 100644 index 0000000000..5370773a45 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.h @@ -0,0 +1,495 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml +/* proj-data +{ + "generated_from": "e1d10b0c7c98524c27886bdae0972321" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PCG_OPERATOR_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PCG_OPERATOR_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/ops/attention_attrs.dtg.h" +#include "op-attrs/ops/batch_matmul.dtg.h" +#include "op-attrs/ops/batch_norm_attrs.dtg.h" +#include "op-attrs/ops/cast_attrs.dtg.h" +#include "op-attrs/ops/combine_attrs.dtg.h" +#include "op-attrs/ops/concat_attrs.dtg.h" +#include "op-attrs/ops/conv_2d_attrs.dtg.h" +#include "op-attrs/ops/dropout_attrs.dtg.h" +#include "op-attrs/ops/element_binary_attrs.dtg.h" +#include "op-attrs/ops/element_scalar_unary_attrs.dtg.h" +#include "op-attrs/ops/element_unary_attrs.dtg.h" +#include "op-attrs/ops/embedding_attrs.dtg.h" +#include "op-attrs/ops/flat_attrs.dtg.h" +#include "op-attrs/ops/gather_attrs.dtg.h" +#include "op-attrs/ops/input_attrs.dtg.h" +#include "op-attrs/ops/layer_norm_attrs.dtg.h" +#include "op-attrs/ops/linear_attrs.dtg.h" +#include "op-attrs/ops/noop_attrs.dtg.h" +#include "op-attrs/ops/pool_2d_attrs.dtg.h" +#include "op-attrs/ops/reduce_attrs.dtg.h" +#include "op-attrs/ops/reduction_attrs.dtg.h" +#include "op-attrs/ops/repartition_attrs.dtg.h" +#include "op-attrs/ops/replicate_attrs.dtg.h" +#include "op-attrs/ops/reshape_attrs.dtg.h" +#include "op-attrs/ops/reverse_attrs.dtg.h" +#include "op-attrs/ops/softmax_attrs.dtg.h" +#include "op-attrs/ops/split_attrs.dtg.h" +#include "op-attrs/ops/topk_attrs.dtg.h" +#include "op-attrs/ops/transpose_attrs.dtg.h" +#include "rapidcheck.h" +#include +#include +#include +#include +#include + +namespace FlexFlow { +struct PCGOperatorAttrs { + PCGOperatorAttrs() = delete; + explicit PCGOperatorAttrs(::FlexFlow::BatchMatmulAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::BatchNormAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::CastAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::CombineAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::ConcatAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::Conv2DAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::DropoutAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::ElementBinaryAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::ElementUnaryAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::ElementScalarUnaryAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::EmbeddingAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::FlatAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::GatherAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::InputAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::LayerNormAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::LinearAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::MultiHeadAttentionAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::NoopAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::Pool2DAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::ReduceAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::ReductionAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::RepartitionAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::ReplicateAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::ReverseAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::ReshapeAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::SplitAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::SoftmaxAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::TopKAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::TransposeAttrs const &); + template + static constexpr bool IsPartOfPCGOperatorAttrs_v = + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v; + template + ReturnType visit(Visitor &&v) const { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::BatchMatmulAttrs>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::BatchNormAttrs>()); + return result; + } + case 2: { + ReturnType result = v(this->get<::FlexFlow::CastAttrs>()); + return result; + } + case 3: { + ReturnType result = v(this->get<::FlexFlow::CombineAttrs>()); + return result; + } + case 4: { + ReturnType result = v(this->get<::FlexFlow::ConcatAttrs>()); + return result; + } + case 5: { + ReturnType result = v(this->get<::FlexFlow::Conv2DAttrs>()); + return result; + } + case 6: { + ReturnType result = v(this->get<::FlexFlow::DropoutAttrs>()); + return result; + } + case 7: { + ReturnType result = v(this->get<::FlexFlow::ElementBinaryAttrs>()); + return result; + } + case 8: { + ReturnType result = v(this->get<::FlexFlow::ElementUnaryAttrs>()); + return result; + } + case 9: { + ReturnType result = v(this->get<::FlexFlow::ElementScalarUnaryAttrs>()); + return result; + } + case 10: { + ReturnType result = v(this->get<::FlexFlow::EmbeddingAttrs>()); + return result; + } + case 11: { + ReturnType result = v(this->get<::FlexFlow::FlatAttrs>()); + return result; + } + case 12: { + ReturnType result = v(this->get<::FlexFlow::GatherAttrs>()); + return result; + } + case 13: { + ReturnType result = v(this->get<::FlexFlow::InputAttrs>()); + return result; + } + case 14: { + ReturnType result = v(this->get<::FlexFlow::LayerNormAttrs>()); + return result; + } + case 15: { + ReturnType result = v(this->get<::FlexFlow::LinearAttrs>()); + return result; + } + case 16: { + ReturnType result = v(this->get<::FlexFlow::MultiHeadAttentionAttrs>()); + return result; + } + case 17: { + ReturnType result = v(this->get<::FlexFlow::NoopAttrs>()); + return result; + } + case 18: { + ReturnType result = v(this->get<::FlexFlow::Pool2DAttrs>()); + return result; + } + case 19: { + ReturnType result = v(this->get<::FlexFlow::ReduceAttrs>()); + return result; + } + case 20: { + ReturnType result = v(this->get<::FlexFlow::ReductionAttrs>()); + return result; + } + case 21: { + ReturnType result = v(this->get<::FlexFlow::RepartitionAttrs>()); + return result; + } + case 22: { + ReturnType result = v(this->get<::FlexFlow::ReplicateAttrs>()); + return result; + } + case 23: { + ReturnType result = v(this->get<::FlexFlow::ReverseAttrs>()); + return result; + } + case 24: { + ReturnType result = v(this->get<::FlexFlow::ReshapeAttrs>()); + return result; + } + case 25: { + ReturnType result = v(this->get<::FlexFlow::SplitAttrs>()); + return result; + } + case 26: { + ReturnType result = v(this->get<::FlexFlow::SoftmaxAttrs>()); + return result; + } + case 27: { + ReturnType result = v(this->get<::FlexFlow::TopKAttrs>()); + return result; + } + case 28: { + ReturnType result = v(this->get<::FlexFlow::TransposeAttrs>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type PCGOperatorAttrs", this->index())); + } + } + } + template + ReturnType visit(Visitor &&v) { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::BatchMatmulAttrs>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::BatchNormAttrs>()); + return result; + } + case 2: { + ReturnType result = v(this->get<::FlexFlow::CastAttrs>()); + return result; + } + case 3: { + ReturnType result = v(this->get<::FlexFlow::CombineAttrs>()); + return result; + } + case 4: { + ReturnType result = v(this->get<::FlexFlow::ConcatAttrs>()); + return result; + } + case 5: { + ReturnType result = v(this->get<::FlexFlow::Conv2DAttrs>()); + return result; + } + case 6: { + ReturnType result = v(this->get<::FlexFlow::DropoutAttrs>()); + return result; + } + case 7: { + ReturnType result = v(this->get<::FlexFlow::ElementBinaryAttrs>()); + return result; + } + case 8: { + ReturnType result = v(this->get<::FlexFlow::ElementUnaryAttrs>()); + return result; + } + case 9: { + ReturnType result = v(this->get<::FlexFlow::ElementScalarUnaryAttrs>()); + return result; + } + case 10: { + ReturnType result = v(this->get<::FlexFlow::EmbeddingAttrs>()); + return result; + } + case 11: { + ReturnType result = v(this->get<::FlexFlow::FlatAttrs>()); + return result; + } + case 12: { + ReturnType result = v(this->get<::FlexFlow::GatherAttrs>()); + return result; + } + case 13: { + ReturnType result = v(this->get<::FlexFlow::InputAttrs>()); + return result; + } + case 14: { + ReturnType result = v(this->get<::FlexFlow::LayerNormAttrs>()); + return result; + } + case 15: { + ReturnType result = v(this->get<::FlexFlow::LinearAttrs>()); + return result; + } + case 16: { + ReturnType result = v(this->get<::FlexFlow::MultiHeadAttentionAttrs>()); + return result; + } + case 17: { + ReturnType result = v(this->get<::FlexFlow::NoopAttrs>()); + return result; + } + case 18: { + ReturnType result = v(this->get<::FlexFlow::Pool2DAttrs>()); + return result; + } + case 19: { + ReturnType result = v(this->get<::FlexFlow::ReduceAttrs>()); + return result; + } + case 20: { + ReturnType result = v(this->get<::FlexFlow::ReductionAttrs>()); + return result; + } + case 21: { + ReturnType result = v(this->get<::FlexFlow::RepartitionAttrs>()); + return result; + } + case 22: { + ReturnType result = v(this->get<::FlexFlow::ReplicateAttrs>()); + return result; + } + case 23: { + ReturnType result = v(this->get<::FlexFlow::ReverseAttrs>()); + return result; + } + case 24: { + ReturnType result = v(this->get<::FlexFlow::ReshapeAttrs>()); + return result; + } + case 25: { + ReturnType result = v(this->get<::FlexFlow::SplitAttrs>()); + return result; + } + case 26: { + ReturnType result = v(this->get<::FlexFlow::SoftmaxAttrs>()); + return result; + } + case 27: { + ReturnType result = v(this->get<::FlexFlow::TopKAttrs>()); + return result; + } + case 28: { + ReturnType result = v(this->get<::FlexFlow::TransposeAttrs>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type PCGOperatorAttrs", this->index())); + } + } + } + template + bool has() const { + static_assert( + IsPartOfPCGOperatorAttrs_v, + "PCGOperatorAttrs::has() expected one of " + "[::FlexFlow::BatchMatmulAttrs, ::FlexFlow::BatchNormAttrs, " + "::FlexFlow::CastAttrs, ::FlexFlow::CombineAttrs, " + "::FlexFlow::ConcatAttrs, ::FlexFlow::Conv2DAttrs, " + "::FlexFlow::DropoutAttrs, ::FlexFlow::ElementBinaryAttrs, " + "::FlexFlow::ElementUnaryAttrs, ::FlexFlow::ElementScalarUnaryAttrs, " + "::FlexFlow::EmbeddingAttrs, ::FlexFlow::FlatAttrs, " + "::FlexFlow::GatherAttrs, ::FlexFlow::InputAttrs, " + "::FlexFlow::LayerNormAttrs, ::FlexFlow::LinearAttrs, " + "::FlexFlow::MultiHeadAttentionAttrs, ::FlexFlow::NoopAttrs, " + "::FlexFlow::Pool2DAttrs, ::FlexFlow::ReduceAttrs, " + "::FlexFlow::ReductionAttrs, ::FlexFlow::RepartitionAttrs, " + "::FlexFlow::ReplicateAttrs, ::FlexFlow::ReverseAttrs, " + "::FlexFlow::ReshapeAttrs, ::FlexFlow::SplitAttrs, " + "::FlexFlow::SoftmaxAttrs, ::FlexFlow::TopKAttrs, " + "::FlexFlow::TransposeAttrs], received T"); + return std::holds_alternative(this->raw_variant); + } + template + T const &get() const { + static_assert( + IsPartOfPCGOperatorAttrs_v, + "PCGOperatorAttrs::get() expected one of " + "[::FlexFlow::BatchMatmulAttrs, ::FlexFlow::BatchNormAttrs, " + "::FlexFlow::CastAttrs, ::FlexFlow::CombineAttrs, " + "::FlexFlow::ConcatAttrs, ::FlexFlow::Conv2DAttrs, " + "::FlexFlow::DropoutAttrs, ::FlexFlow::ElementBinaryAttrs, " + "::FlexFlow::ElementUnaryAttrs, ::FlexFlow::ElementScalarUnaryAttrs, " + "::FlexFlow::EmbeddingAttrs, ::FlexFlow::FlatAttrs, " + "::FlexFlow::GatherAttrs, ::FlexFlow::InputAttrs, " + "::FlexFlow::LayerNormAttrs, ::FlexFlow::LinearAttrs, " + "::FlexFlow::MultiHeadAttentionAttrs, ::FlexFlow::NoopAttrs, " + "::FlexFlow::Pool2DAttrs, ::FlexFlow::ReduceAttrs, " + "::FlexFlow::ReductionAttrs, ::FlexFlow::RepartitionAttrs, " + "::FlexFlow::ReplicateAttrs, ::FlexFlow::ReverseAttrs, " + "::FlexFlow::ReshapeAttrs, ::FlexFlow::SplitAttrs, " + "::FlexFlow::SoftmaxAttrs, ::FlexFlow::TopKAttrs, " + "::FlexFlow::TransposeAttrs], received T"); + return std::get(this->raw_variant); + } + template + T &get() { + static_assert( + IsPartOfPCGOperatorAttrs_v, + "PCGOperatorAttrs::get() expected one of " + "[::FlexFlow::BatchMatmulAttrs, ::FlexFlow::BatchNormAttrs, " + "::FlexFlow::CastAttrs, ::FlexFlow::CombineAttrs, " + "::FlexFlow::ConcatAttrs, ::FlexFlow::Conv2DAttrs, " + "::FlexFlow::DropoutAttrs, ::FlexFlow::ElementBinaryAttrs, " + "::FlexFlow::ElementUnaryAttrs, ::FlexFlow::ElementScalarUnaryAttrs, " + "::FlexFlow::EmbeddingAttrs, ::FlexFlow::FlatAttrs, " + "::FlexFlow::GatherAttrs, ::FlexFlow::InputAttrs, " + "::FlexFlow::LayerNormAttrs, ::FlexFlow::LinearAttrs, " + "::FlexFlow::MultiHeadAttentionAttrs, ::FlexFlow::NoopAttrs, " + "::FlexFlow::Pool2DAttrs, ::FlexFlow::ReduceAttrs, " + "::FlexFlow::ReductionAttrs, ::FlexFlow::RepartitionAttrs, " + "::FlexFlow::ReplicateAttrs, ::FlexFlow::ReverseAttrs, " + "::FlexFlow::ReshapeAttrs, ::FlexFlow::SplitAttrs, " + "::FlexFlow::SoftmaxAttrs, ::FlexFlow::TopKAttrs, " + "::FlexFlow::TransposeAttrs], received T"); + return std::get(this->raw_variant); + } + size_t index() const { + return this->raw_variant.index(); + } + bool operator==(PCGOperatorAttrs const &) const; + bool operator!=(PCGOperatorAttrs const &) const; + bool operator<(PCGOperatorAttrs const &) const; + bool operator>(PCGOperatorAttrs const &) const; + bool operator<=(PCGOperatorAttrs const &) const; + bool operator>=(PCGOperatorAttrs const &) const; + std::variant<::FlexFlow::BatchMatmulAttrs, + ::FlexFlow::BatchNormAttrs, + ::FlexFlow::CastAttrs, + ::FlexFlow::CombineAttrs, + ::FlexFlow::ConcatAttrs, + ::FlexFlow::Conv2DAttrs, + ::FlexFlow::DropoutAttrs, + ::FlexFlow::ElementBinaryAttrs, + ::FlexFlow::ElementUnaryAttrs, + ::FlexFlow::ElementScalarUnaryAttrs, + ::FlexFlow::EmbeddingAttrs, + ::FlexFlow::FlatAttrs, + ::FlexFlow::GatherAttrs, + ::FlexFlow::InputAttrs, + ::FlexFlow::LayerNormAttrs, + ::FlexFlow::LinearAttrs, + ::FlexFlow::MultiHeadAttentionAttrs, + ::FlexFlow::NoopAttrs, + ::FlexFlow::Pool2DAttrs, + ::FlexFlow::ReduceAttrs, + ::FlexFlow::ReductionAttrs, + ::FlexFlow::RepartitionAttrs, + ::FlexFlow::ReplicateAttrs, + ::FlexFlow::ReverseAttrs, + ::FlexFlow::ReshapeAttrs, + ::FlexFlow::SplitAttrs, + ::FlexFlow::SoftmaxAttrs, + ::FlexFlow::TopKAttrs, + ::FlexFlow::TransposeAttrs> + raw_variant; +}; +} // namespace FlexFlow +namespace std { +template <> +struct hash<::FlexFlow::PCGOperatorAttrs> { + size_t operator()(::FlexFlow::PCGOperatorAttrs const &) const; +}; +} // namespace std +namespace nlohmann { +template <> +struct adl_serializer<::FlexFlow::PCGOperatorAttrs> { + static ::FlexFlow::PCGOperatorAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::PCGOperatorAttrs const &); +}; +} // namespace nlohmann +namespace rc { +template <> +struct Arbitrary<::FlexFlow::PCGOperatorAttrs> { + static Gen<::FlexFlow::PCGOperatorAttrs> arbitrary(); +}; +} // namespace rc +namespace FlexFlow { +std::string format_as(::FlexFlow::PCGOperatorAttrs const &); +std::ostream &operator<<(std::ostream &, ::FlexFlow::PCGOperatorAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PCG_OPERATOR_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h new file mode 100644 index 0000000000..0ad7a9f829 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PCG_OPERATOR_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PCG_OPERATOR_ATTRS_H + +#include "op-attrs/pcg_operator_attrs.dtg.h" + +namespace FlexFlow { + +bool is_parallel_op(PCGOperatorAttrs const &); +OperatorType get_op_type(PCGOperatorAttrs const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml new file mode 100644 index 0000000000..ddb8a109d8 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml @@ -0,0 +1,158 @@ +namespace = "FlexFlow" +name = "PCGOperatorAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ops/attention_attrs.dtg.h", + "op-attrs/ops/batch_matmul.dtg.h", + "op-attrs/ops/batch_norm_attrs.dtg.h", + "op-attrs/ops/cast_attrs.dtg.h", + "op-attrs/ops/combine_attrs.dtg.h", + "op-attrs/ops/concat_attrs.dtg.h", + "op-attrs/ops/conv_2d_attrs.dtg.h", + "op-attrs/ops/dropout_attrs.dtg.h", + "op-attrs/ops/element_binary_attrs.dtg.h", + "op-attrs/ops/element_scalar_unary_attrs.dtg.h", + "op-attrs/ops/element_unary_attrs.dtg.h", + "op-attrs/ops/embedding_attrs.dtg.h", + "op-attrs/ops/flat_attrs.dtg.h", + "op-attrs/ops/gather_attrs.dtg.h", + "op-attrs/ops/input_attrs.dtg.h", + "op-attrs/ops/layer_norm_attrs.dtg.h", + "op-attrs/ops/linear_attrs.dtg.h", + "op-attrs/ops/noop_attrs.dtg.h", + "op-attrs/ops/pool_2d_attrs.dtg.h", + "op-attrs/ops/reduce_attrs.dtg.h", + "op-attrs/ops/reduction_attrs.dtg.h", + "op-attrs/ops/repartition_attrs.dtg.h", + "op-attrs/ops/replicate_attrs.dtg.h", + "op-attrs/ops/reshape_attrs.dtg.h", + "op-attrs/ops/reverse_attrs.dtg.h", + "op-attrs/ops/softmax_attrs.dtg.h", + "op-attrs/ops/split_attrs.dtg.h", + "op-attrs/ops/topk_attrs.dtg.h", + "op-attrs/ops/transpose_attrs.dtg.h", +] + +[[values]] +type = "::FlexFlow::BatchMatmulAttrs" +key = "batch_matmul" + +[[values]] +type = "::FlexFlow::BatchNormAttrs" +key = "batch_norm" + +[[values]] +type = "::FlexFlow::CastAttrs" +key = "cast" + +[[values]] +type = "::FlexFlow::CombineAttrs" +key = "combine_distributed" + +[[values]] +type = "::FlexFlow::ConcatAttrs" +key = "concat" + +[[values]] +type = "::FlexFlow::Conv2DAttrs" +key = "conv2d" + +[[values]] +type = "::FlexFlow::DropoutAttrs" +key = "dropout" + +[[values]] +type = "::FlexFlow::ElementBinaryAttrs" +key = "element_binary" + +[[values]] +type = "::FlexFlow::ElementUnaryAttrs" +key = "element_unary" + +[[values]] +type = "::FlexFlow::ElementScalarUnaryAttrs" +key = "element_scalar_unary" + +[[values]] +type = "::FlexFlow::EmbeddingAttrs" +key = "embedding" + +[[values]] +type = "::FlexFlow::FlatAttrs" +key = "flat" + +[[values]] +type = "::FlexFlow::GatherAttrs" +key = "gather" + +[[values]] +type = "::FlexFlow::InputAttrs" +key = "input" + +[[values]] +type = "::FlexFlow::LayerNormAttrs" +key = "layer_norm" + +[[values]] +type = "::FlexFlow::LinearAttrs" +key = "linear" + +[[values]] +type = "::FlexFlow::MultiHeadAttentionAttrs" +key = "multi_head_attention" + +[[values]] +type = "::FlexFlow::NoopAttrs" +key = "noop" + +[[values]] +type = "::FlexFlow::Pool2DAttrs" +key = "pool2d" + +[[values]] +type = "::FlexFlow::ReduceAttrs" +key = "reduce" + +[[values]] +type = "::FlexFlow::ReductionAttrs" +key = "reduce_distributed" + +[[values]] +type = "::FlexFlow::RepartitionAttrs" +key = "partition_distributed" + +[[values]] +type = "::FlexFlow::ReplicateAttrs" +key = "replicate_distributed" + +[[values]] +type = "::FlexFlow::ReverseAttrs" +key = "reverse" + +[[values]] +type = "::FlexFlow::ReshapeAttrs" +key = "reshape" + +[[values]] +type = "::FlexFlow::SplitAttrs" +key = "split" + +[[values]] +type = "::FlexFlow::SoftmaxAttrs" +key = "softmax" + +[[values]] +type = "::FlexFlow::TopKAttrs" +key = "topk" + +[[values]] +type = "::FlexFlow::TransposeAttrs" +key = "transpose" diff --git a/lib/op-attrs/include/op-attrs/pool_op.dtg.h b/lib/op-attrs/include/op-attrs/pool_op.dtg.h new file mode 100644 index 0000000000..3511589b52 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/pool_op.dtg.h @@ -0,0 +1,40 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/pool_op.enum.toml +/* proj-data +{ + "generated_from": "ed1d531c6227306c909eb28eb0a66538" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_POOL_OP_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_POOL_OP_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +enum class PoolOp { MAX, AVG }; +std::string format_as(PoolOp); +std::ostream &operator<<(std::ostream &, PoolOp); +void to_json(::nlohmann::json &, PoolOp); +void from_json(::nlohmann::json const &, PoolOp &); +} // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::PoolOp) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_POOL_OP_DTG_H diff --git a/lib/op-attrs/include/op-attrs/pool_op.enum.toml b/lib/op-attrs/include/op-attrs/pool_op.enum.toml new file mode 100644 index 0000000000..88f4dfea19 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/pool_op.enum.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "PoolOp" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "MAX" + +[[values]] +name = "AVG" diff --git a/lib/op-attrs/include/op-attrs/regularizer_attrs.dtg.h b/lib/op-attrs/include/op-attrs/regularizer_attrs.dtg.h new file mode 100644 index 0000000000..2621b4b12c --- /dev/null +++ b/lib/op-attrs/include/op-attrs/regularizer_attrs.dtg.h @@ -0,0 +1,128 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/regularizer_attrs.variant.toml +/* proj-data +{ + "generated_from": "ea060a8ab344c9772102f084903883ea" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REGULARIZER_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REGULARIZER_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/l1_regularizer_attrs.dtg.h" +#include "op-attrs/l2_regularizer_attrs.dtg.h" +#include "rapidcheck.h" +#include +#include +#include +#include +#include + +namespace FlexFlow { +struct RegularizerAttrs { + RegularizerAttrs() = delete; + explicit RegularizerAttrs(::FlexFlow::L1RegularizerAttrs const &); + explicit RegularizerAttrs(::FlexFlow::L2RegularizerAttrs const &); + template + static constexpr bool IsPartOfRegularizerAttrs_v = + std::is_same_v || + std::is_same_v; + template + ReturnType visit(Visitor &&v) const { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::L1RegularizerAttrs>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::L2RegularizerAttrs>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type RegularizerAttrs", this->index())); + } + } + } + template + ReturnType visit(Visitor &&v) { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::L1RegularizerAttrs>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::L2RegularizerAttrs>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type RegularizerAttrs", this->index())); + } + } + } + template + bool has() const { + static_assert(IsPartOfRegularizerAttrs_v, + "RegularizerAttrs::has() expected one of " + "[::FlexFlow::L1RegularizerAttrs, " + "::FlexFlow::L2RegularizerAttrs], received T"); + return std::holds_alternative(this->raw_variant); + } + template + T const &get() const { + static_assert(IsPartOfRegularizerAttrs_v, + "RegularizerAttrs::get() expected one of " + "[::FlexFlow::L1RegularizerAttrs, " + "::FlexFlow::L2RegularizerAttrs], received T"); + return std::get(this->raw_variant); + } + template + T &get() { + static_assert(IsPartOfRegularizerAttrs_v, + "RegularizerAttrs::get() expected one of " + "[::FlexFlow::L1RegularizerAttrs, " + "::FlexFlow::L2RegularizerAttrs], received T"); + return std::get(this->raw_variant); + } + size_t index() const { + return this->raw_variant.index(); + } + bool operator==(RegularizerAttrs const &) const; + bool operator!=(RegularizerAttrs const &) const; + bool operator<(RegularizerAttrs const &) const; + bool operator>(RegularizerAttrs const &) const; + bool operator<=(RegularizerAttrs const &) const; + bool operator>=(RegularizerAttrs const &) const; + std::variant<::FlexFlow::L1RegularizerAttrs, ::FlexFlow::L2RegularizerAttrs> + raw_variant; +}; +} // namespace FlexFlow +namespace std { +template <> +struct hash<::FlexFlow::RegularizerAttrs> { + size_t operator()(::FlexFlow::RegularizerAttrs const &) const; +}; +} // namespace std +namespace nlohmann { +template <> +struct adl_serializer<::FlexFlow::RegularizerAttrs> { + static ::FlexFlow::RegularizerAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::RegularizerAttrs const &); +}; +} // namespace nlohmann +namespace rc { +template <> +struct Arbitrary<::FlexFlow::RegularizerAttrs> { + static Gen<::FlexFlow::RegularizerAttrs> arbitrary(); +}; +} // namespace rc +namespace FlexFlow { +std::string format_as(::FlexFlow::RegularizerAttrs const &); +std::ostream &operator<<(std::ostream &, ::FlexFlow::RegularizerAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REGULARIZER_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/regularizer_attrs.variant.toml b/lib/op-attrs/include/op-attrs/regularizer_attrs.variant.toml new file mode 100644 index 0000000000..d650c7f6a9 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/regularizer_attrs.variant.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "RegularizerAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/l1_regularizer_attrs.dtg.h", + "op-attrs/l2_regularizer_attrs.dtg.h", +] + +[[values]] +type = "::FlexFlow::L1RegularizerAttrs" +key = "l1" + +[[values]] +type = "::FlexFlow::L2RegularizerAttrs" +key = "l2" diff --git a/lib/op-attrs/include/op-attrs/replica_parallel_dim.dtg.h b/lib/op-attrs/include/op-attrs/replica_parallel_dim.dtg.h new file mode 100644 index 0000000000..250ba29947 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/replica_parallel_dim.dtg.h @@ -0,0 +1,65 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/replica_parallel_dim.struct.toml +/* proj-data +{ + "generated_from": "f501393070c8d55a05c43dd73a81a8d7" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_PARALLEL_DIM_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_PARALLEL_DIM_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/replica_type.dtg.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct ReplicaParallelDim { + ReplicaParallelDim() = delete; + ReplicaParallelDim(int const °ree, + ::FlexFlow::ReplicaType const &replica_type); + + bool operator==(ReplicaParallelDim const &) const; + bool operator!=(ReplicaParallelDim const &) const; + bool operator<(ReplicaParallelDim const &) const; + bool operator>(ReplicaParallelDim const &) const; + bool operator<=(ReplicaParallelDim const &) const; + bool operator>=(ReplicaParallelDim const &) const; + int degree; + ::FlexFlow::ReplicaType replica_type; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ReplicaParallelDim const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ReplicaParallelDim from_json(json const &); + static void to_json(json &, FlexFlow::ReplicaParallelDim const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(ReplicaParallelDim const &); +std::ostream &operator<<(std::ostream &, ReplicaParallelDim const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_PARALLEL_DIM_DTG_H diff --git a/lib/op-attrs/include/op-attrs/replica_parallel_dim.h b/lib/op-attrs/include/op-attrs/replica_parallel_dim.h new file mode 100644 index 0000000000..da3913b426 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/replica_parallel_dim.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_PARALLEL_DIM_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_PARALLEL_DIM_H + +#include "op-attrs/replica_parallel_dim.dtg.h" + +namespace FlexFlow { + +bool is_valid(ReplicaParallelDim const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/replica_parallel_dim.struct.toml b/lib/op-attrs/include/op-attrs/replica_parallel_dim.struct.toml new file mode 100644 index 0000000000..2ad442aa22 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/replica_parallel_dim.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "ReplicaParallelDim" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/replica_type.dtg.h", +] + +[[fields]] +name = "degree" +type = "int" + +[[fields]] +name = "replica_type" +type = "::FlexFlow::ReplicaType" diff --git a/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.dtg.h b/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.dtg.h new file mode 100644 index 0000000000..321029347f --- /dev/null +++ b/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.dtg.h @@ -0,0 +1,67 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/replica_parallel_dim_set.struct.toml +/* proj-data +{ + "generated_from": "74230e2d18db5c059d3e7be0f25e746e" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_PARALLEL_DIM_SET_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_PARALLEL_DIM_SET_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h" +#include "op-attrs/parallel_tensor_shape/sum_degree.dtg.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct ReplicaParallelDimSet { + ReplicaParallelDimSet() = delete; + ReplicaParallelDimSet( + ::FlexFlow::SumDegree const &sum_degree, + ::FlexFlow::DiscardCopyDegree const &discard_copy_degree); + + bool operator==(ReplicaParallelDimSet const &) const; + bool operator!=(ReplicaParallelDimSet const &) const; + bool operator<(ReplicaParallelDimSet const &) const; + bool operator>(ReplicaParallelDimSet const &) const; + bool operator<=(ReplicaParallelDimSet const &) const; + bool operator>=(ReplicaParallelDimSet const &) const; + ::FlexFlow::SumDegree sum_degree; + ::FlexFlow::DiscardCopyDegree discard_copy_degree; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ReplicaParallelDimSet const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ReplicaParallelDimSet from_json(json const &); + static void to_json(json &, FlexFlow::ReplicaParallelDimSet const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(ReplicaParallelDimSet const &); +std::ostream &operator<<(std::ostream &, ReplicaParallelDimSet const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_PARALLEL_DIM_SET_DTG_H diff --git a/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.h b/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.h new file mode 100644 index 0000000000..74a8df339b --- /dev/null +++ b/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_PARALLEL_DIM_SET_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_PARALLEL_DIM_SET_H + +#include "op-attrs/replica_parallel_dim.dtg.h" +#include "op-attrs/replica_parallel_dim_set.dtg.h" +#include "op-attrs/replica_type.dtg.h" + +namespace FlexFlow { + +ReplicaParallelDimSet empty_replica_parallel_dim_set(); +int get_degree_of_replica_type(ReplicaParallelDimSet const &, ReplicaType); +std::unordered_set + get_replica_dims(ReplicaParallelDimSet const &); +bool is_valid(ReplicaParallelDimSet const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.struct.toml b/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.struct.toml new file mode 100644 index 0000000000..66f50bee9f --- /dev/null +++ b/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "ReplicaParallelDimSet" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/parallel_tensor_shape/sum_degree.dtg.h", + "op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h", +] + +[[fields]] +name = "sum_degree" +type = "::FlexFlow::SumDegree" + +[[fields]] +name = "discard_copy_degree" +type = "::FlexFlow::DiscardCopyDegree" diff --git a/lib/op-attrs/include/op-attrs/replica_type.dtg.h b/lib/op-attrs/include/op-attrs/replica_type.dtg.h new file mode 100644 index 0000000000..3b965d3e77 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/replica_type.dtg.h @@ -0,0 +1,40 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/replica_type.enum.toml +/* proj-data +{ + "generated_from": "6ecba7a6851b8bea93705bba24661149" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_TYPE_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_TYPE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +enum class ReplicaType { SUM, DISCARD_COPY }; +std::string format_as(ReplicaType); +std::ostream &operator<<(std::ostream &, ReplicaType); +void to_json(::nlohmann::json &, ReplicaType); +void from_json(::nlohmann::json const &, ReplicaType &); +} // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ReplicaType) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_TYPE_DTG_H diff --git a/lib/op-attrs/include/op-attrs/replica_type.enum.toml b/lib/op-attrs/include/op-attrs/replica_type.enum.toml new file mode 100644 index 0000000000..0c0eb5e3ab --- /dev/null +++ b/lib/op-attrs/include/op-attrs/replica_type.enum.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "ReplicaType" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "SUM" + +[[values]] +name = "DISCARD_COPY" diff --git a/lib/op-attrs/include/op-attrs/shard_parallel_dim.dtg.h b/lib/op-attrs/include/op-attrs/shard_parallel_dim.dtg.h new file mode 100644 index 0000000000..631852c259 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/shard_parallel_dim.dtg.h @@ -0,0 +1,63 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/shard_parallel_dim.struct.toml +/* proj-data +{ + "generated_from": "18e074f80556d90b9b27d6515bbf9071" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_SHARD_PARALLEL_DIM_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_SHARD_PARALLEL_DIM_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct ShardParallelDim { + ShardParallelDim() = delete; + ShardParallelDim(size_t const &size, int const °ree); + + bool operator==(ShardParallelDim const &) const; + bool operator!=(ShardParallelDim const &) const; + bool operator<(ShardParallelDim const &) const; + bool operator>(ShardParallelDim const &) const; + bool operator<=(ShardParallelDim const &) const; + bool operator>=(ShardParallelDim const &) const; + size_t size; + int degree; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ShardParallelDim const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ShardParallelDim from_json(json const &); + static void to_json(json &, FlexFlow::ShardParallelDim const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(ShardParallelDim const &); +std::ostream &operator<<(std::ostream &, ShardParallelDim const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_SHARD_PARALLEL_DIM_DTG_H diff --git a/lib/op-attrs/include/op-attrs/shard_parallel_dim.h b/lib/op-attrs/include/op-attrs/shard_parallel_dim.h new file mode 100644 index 0000000000..0a6323192d --- /dev/null +++ b/lib/op-attrs/include/op-attrs/shard_parallel_dim.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_SHARD_PARALLEL_DIM_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_SHARD_PARALLEL_DIM_H + +#include "op-attrs/shard_parallel_dim.dtg.h" + +namespace FlexFlow { + +bool is_valid(ShardParallelDim const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/shard_parallel_dim.struct.toml b/lib/op-attrs/include/op-attrs/shard_parallel_dim.struct.toml new file mode 100644 index 0000000000..21c81396d1 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/shard_parallel_dim.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "ShardParallelDim" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "size" +type = "size_t" + +[[fields]] +name = "degree" +type = "int" diff --git a/lib/op-attrs/include/op-attrs/tensor_dims.dtg.h b/lib/op-attrs/include/op-attrs/tensor_dims.dtg.h new file mode 100644 index 0000000000..a8e46a4626 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/tensor_dims.dtg.h @@ -0,0 +1,63 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/tensor_dims.struct.toml +/* proj-data +{ + "generated_from": "5beb89eeae9eba303f90e726c794375d" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_DIMS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_DIMS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/dim_ordered.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct TensorDims { + TensorDims() = delete; + TensorDims(::FlexFlow::FFOrdered const &ff_ordered); + + bool operator==(TensorDims const &) const; + bool operator!=(TensorDims const &) const; + bool operator<(TensorDims const &) const; + bool operator>(TensorDims const &) const; + bool operator<=(TensorDims const &) const; + bool operator>=(TensorDims const &) const; + ::FlexFlow::FFOrdered ff_ordered; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::TensorDims const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::TensorDims from_json(json const &); + static void to_json(json &, FlexFlow::TensorDims const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(TensorDims const &); +std::ostream &operator<<(std::ostream &, TensorDims const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_DIMS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/tensor_dims.h b/lib/op-attrs/include/op-attrs/tensor_dims.h new file mode 100644 index 0000000000..2391197471 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/tensor_dims.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_DIMS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_DIMS_H + +#include "op-attrs/parallel_tensor_dims.dtg.h" +#include "op-attrs/tensor_dims.dtg.h" + +namespace FlexFlow { + +FFOrdered const &ff_ordered(TensorDims const &); + +size_t num_dims(TensorDims const &); +size_t dim_at_idx(TensorDims const &, ff_dim_t); +size_t &dim_at_idx(TensorDims &, ff_dim_t); + +ParallelTensorDims lift_to_parallel(TensorDims const &); +ParallelTensorDims + lift_to_parallel_with_degrees(TensorDims const &, + SumDegree sum_degree, + DiscardCopyDegree discard_copy_degree, + FFOrdered const &shard_degrees); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/tensor_dims.struct.toml b/lib/op-attrs/include/op-attrs/tensor_dims.struct.toml new file mode 100644 index 0000000000..cff8e08b0f --- /dev/null +++ b/lib/op-attrs/include/op-attrs/tensor_dims.struct.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "TensorDims" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] +includes = [ + "op-attrs/dim_ordered.h", +] + +[[fields]] +name = "ff_ordered" +type = "::FlexFlow::FFOrdered" diff --git a/lib/op-attrs/include/op-attrs/tensor_shape.dtg.h b/lib/op-attrs/include/op-attrs/tensor_shape.dtg.h new file mode 100644 index 0000000000..f36d5d1306 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/tensor_shape.dtg.h @@ -0,0 +1,66 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/tensor_shape.struct.toml +/* proj-data +{ + "generated_from": "ef6fa5088b89d6da4dc8bddf0a6d3294" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_SHAPE_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_SHAPE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/datatype.dtg.h" +#include "op-attrs/tensor_dims.dtg.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct TensorShape { + TensorShape() = delete; + TensorShape(::FlexFlow::TensorDims const &dims, + ::FlexFlow::DataType const &data_type); + + bool operator==(TensorShape const &) const; + bool operator!=(TensorShape const &) const; + bool operator<(TensorShape const &) const; + bool operator>(TensorShape const &) const; + bool operator<=(TensorShape const &) const; + bool operator>=(TensorShape const &) const; + ::FlexFlow::TensorDims dims; + ::FlexFlow::DataType data_type; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::TensorShape const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::TensorShape from_json(json const &); + static void to_json(json &, FlexFlow::TensorShape const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(TensorShape const &); +std::ostream &operator<<(std::ostream &, TensorShape const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_SHAPE_DTG_H diff --git a/lib/op-attrs/include/op-attrs/tensor_shape.h b/lib/op-attrs/include/op-attrs/tensor_shape.h index fa34860817..ad751461e8 100644 --- a/lib/op-attrs/include/op-attrs/tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/tensor_shape.h @@ -1,34 +1,14 @@ #ifndef _FLEXFLOW_OPATTRS_TENSOR_SHAPE_H #define _FLEXFLOW_OPATTRS_TENSOR_SHAPE_H -#include "datatype.h" -#include "op-attrs/dim_ordered.h" -#include "op-attrs/ff_dim.h" -#include "utils/stack_vector.h" -#include "utils/visitable.h" +#include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { -using TensorDims = FFOrdered; - -struct TensorShape : public use_visitable_cmp { - TensorShape() = delete; - - template - TensorShape(Dims const &dims, DataType data_type) - : dims(dims), data_type(data_type) {} - - size_t at(ff_dim_t) const; - size_t operator[](ff_dim_t) const; - -public: - TensorDims dims; - DataType data_type; -}; +size_t num_dims(TensorShape const &); +size_t dim_at_idx(TensorShape const &, ff_dim_t); +size_t &dim_at_idx(TensorShape &, ff_dim_t); } // namespace FlexFlow -VISITABLE_STRUCT(::FlexFlow::TensorShape, dims, data_type); -MAKE_VISIT_HASHABLE(::FlexFlow::TensorShape); - #endif diff --git a/lib/op-attrs/include/op-attrs/tensor_shape.struct.toml b/lib/op-attrs/include/op-attrs/tensor_shape.struct.toml new file mode 100644 index 0000000000..901c3b9e60 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/tensor_shape.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "TensorShape" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/tensor_dims.dtg.h", + "op-attrs/datatype.dtg.h", +] + +[[fields]] +name = "dims" +type = "::FlexFlow::TensorDims" + +[[fields]] +name = "data_type" +type = "::FlexFlow::DataType" diff --git a/lib/op-attrs/src/batch_matmul.cc b/lib/op-attrs/src/batch_matmul.cc deleted file mode 100644 index 1cc8c5cfda..0000000000 --- a/lib/op-attrs/src/batch_matmul.cc +++ /dev/null @@ -1,26 +0,0 @@ -#include "op-attrs/ops/batch_matmul.h" - -namespace FlexFlow { - -/* bool BatchMatmulAttrs::is_valid( */ -/* ParallelTensorShape const &lhs, ParallelTensorShape const &rhs) const { - */ -/* if (!lhs.is_valid() || !rhs.is_valid()) { */ -/* return false; */ -/* } */ -/* if (lhs.num_dims() != rhs.num_dims()) { */ -/* return false; */ -/* } */ -/* for (int i = lhs.num_dims() - 1; i >= 2; i--) { */ -/* if (lhs.at(i) != rhs.at(i)) { */ -/* return false; */ -/* } */ -/* } */ -/* if (lhs.at(0) != rhs.at(1)) { */ -/* return false; */ -/* } */ - -/* return true; */ -/* } */ - -} // namespace FlexFlow diff --git a/lib/op-attrs/src/batch_norm.cc b/lib/op-attrs/src/batch_norm.cc deleted file mode 100644 index 4e352d5f1c..0000000000 --- a/lib/op-attrs/src/batch_norm.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "op-attrs/ops/batch_norm.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/op-attrs/src/broadcast.cc b/lib/op-attrs/src/broadcast.cc deleted file mode 100644 index c69f480b84..0000000000 --- a/lib/op-attrs/src/broadcast.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "op-attrs/ops/broadcast.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/op-attrs/src/combine.cc b/lib/op-attrs/src/combine.cc deleted file mode 100644 index cdca524538..0000000000 --- a/lib/op-attrs/src/combine.cc +++ /dev/null @@ -1,18 +0,0 @@ -#include "op-attrs/ops/combine.h" -#include "utils/hash-utils.h" - -namespace FlexFlow { - -/* bool CombineAttrs::is_valid(ParallelTensorShape const &input) const { */ -/* return input.at(this->combine_legion_dim).degree % this->combine_degree == - * 0; */ -/* } */ - -/* ParallelTensorShape CombineAttrs::output_shape(ParallelTensorShape const - * &input_shape) const { */ -/* ParallelTensorShape output = input_shape; */ -/* output.at(this->combine_legion_dim).degree /= this->combine_degree; */ -/* return output; */ -/* } */ - -} // namespace FlexFlow diff --git a/lib/op-attrs/src/conv_2d.cc b/lib/op-attrs/src/conv_2d.cc deleted file mode 100644 index 40ba3c8b41..0000000000 --- a/lib/op-attrs/src/conv_2d.cc +++ /dev/null @@ -1,115 +0,0 @@ -#include "op-attrs/ops/conv_2d.h" -#include "parallel_dim_mapping_record.h" -#include "parallel_dim_mapping_record_solver.h" -#include "utils/vector.h" - -namespace FlexFlow { - -namespace Input { -constexpr int WIDTH = 0, HEIGHT = 1, CHANNEL = 2, SAMPLE = 3, REPLICA = 4, - NUMDIM = 5; -} - -namespace Output { -constexpr int WIDTH = 0, HEIGHT = 1, CHANNEL = 2, SAMPLE = 3, REPLICA = 4, - NUMDIM = 5; -} - -namespace Kernel { -constexpr int WIDTH = 0, HEIGHT = 1, CHANNEL_IN = 2, CHANNEL_OUT = 3, - REPLICA = 4; -constexpr int WEIGHT_IDX = 0; -} // namespace Kernel - -namespace Bias { -constexpr int CHANNEL = 0, REPLICA_1 = 1, REPLICA_2 = 2, REPLICA_3 = 3, - REPLICA_4 = 4; -constexpr int WEIGHT_IDX = 1; -} // namespace Bias - -static std::vector - construct_output_mappings(ParallelTensorShape const &input_shape) { - return construct_output_parallel_dims( - {{Input::CHANNEL, MappingOperation::REPLICATE, Output::REPLICA}, - {Input::SAMPLE, MappingOperation::PARTITION, Output::SAMPLE}, - {Input::REPLICA, MappingOperation::PARTITION, Output::CHANNEL}, - {Input::HEIGHT, MappingOperation::PARTITION, Output::HEIGHT}, - {Input::WIDTH, MappingOperation::PARTITION, Output::WIDTH}}); -} - -static std::vector - construct_kernel_mappings(ParallelTensorShape const &input_shape) { - return construct_weight_parallel_dims( - { - {Input::REPLICA, MappingOperation::PARTITION, Kernel::CHANNEL_OUT}, - {Input::SAMPLE, MappingOperation::REPLICATE, Kernel::REPLICA}, - {Input::CHANNEL, MappingOperation::PARTITION, Kernel::CHANNEL_IN}, - {Input::HEIGHT, - MappingOperation::REPLICATE, - Kernel::HEIGHT}, // Kernel::{HEIGHT, WEIGHT} would both work - // here - {Input::WIDTH, - MappingOperation::REPLICATE, - Kernel::WIDTH}, // same as above - }, - 0, - Kernel::WEIGHT_IDX); -} - -static std::vector - construct_bias_mappings(ParallelTensorShape const &input_shape) { - return construct_weight_parallel_dims({{Input::REPLICA, Bias::REPLICA_1}, - {Input::SAMPLE, Bias::REPLICA_2}, - {Input::CHANNEL, Bias::CHANNEL}, - {Input::HEIGHT, Bias::REPLICA_3}, - {Input::WIDTH, Bias::REPLICA_4}}, - 0, - Bias::WEIGHT_IDX); -} - -std::vector - construct_mappings(ParallelTensorShape const &input_shape, bool use_bias) { - std::vector mappings = - concat(construct_output_mappings(input_shape), - construct_kernel_mappings(input_shape)); - if (use_bias) { - std::vector bias_mappings = - construct_bias_mappings(input_shape); - mappings.insert(mappings.end(), bias_mappings.begin(), bias_mappings.end()); - } - - return mappings; -} - -TensorShape get_kernel_shape(Conv2DAttrs const &, TensorShape const &) { - NOT_IMPLEMENTED(); -} - -TensorShape get_bias_shape(Conv2DAttrs const &, TensorShape const &) { - NOT_IMPLEMENTED(); -} - -/* bool Conv2DAttrs::is_valid(ParallelTensorShape const &input_shape) const { */ -/* bool is_valid = true; */ -/* is_valid &= input_shape.is_valid(); */ -/* is_valid &= this->calculate_output_shape(input_shape).is_valid(); */ -/* is_valid &= this->calculate_kernel_shape(input_shape).is_valid(); */ -/* if (use_bias) { */ -/* is_valid &= this->calculate_bias_shape(input_shape).is_valid(); */ -/* } */ - -/* // TODO FIXME: Currently disable parallelizing the height and width - * dimension */ -/* if (input_shape.at(0).degree > 1 || input_shape.at(1).degree > 1) { */ -/* return false; */ -/* } */ - -/* return is_valid; */ - -/* } */ - -/* OperatorType Conv2DAttrs::op_type() const { */ -/* return OP_CONV2D; */ -/* } */ - -} // namespace FlexFlow diff --git a/lib/op-attrs/src/element_binary.cc b/lib/op-attrs/src/element_binary.cc deleted file mode 100644 index b713c6753f..0000000000 --- a/lib/op-attrs/src/element_binary.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "op-attrs/ops/element_binary.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/op-attrs/src/element_unary.cc b/lib/op-attrs/src/element_unary.cc deleted file mode 100644 index 481151fafb..0000000000 --- a/lib/op-attrs/src/element_unary.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "op-attrs/ops/element_unary.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/op-attrs/src/embedding.cc b/lib/op-attrs/src/embedding.cc deleted file mode 100644 index 56014fcc67..0000000000 --- a/lib/op-attrs/src/embedding.cc +++ /dev/null @@ -1,9 +0,0 @@ -#include "op-attrs/ops/embedding.h" - -namespace FlexFlow { - -TensorShape get_weights_shape(EmbeddingAttrs const &, TensorShape const &) { - NOT_IMPLEMENTED(); -} - -} // namespace FlexFlow diff --git a/lib/op-attrs/src/linear.cc b/lib/op-attrs/src/linear.cc deleted file mode 100644 index 16a94e7f6c..0000000000 --- a/lib/op-attrs/src/linear.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "op-attrs/ops/linear.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/op-attrs/src/noop.cc b/lib/op-attrs/src/noop.cc deleted file mode 100644 index 387660164f..0000000000 --- a/lib/op-attrs/src/noop.cc +++ /dev/null @@ -1 +0,0 @@ -#include "op-attrs/ops/noop.h" diff --git a/lib/op-attrs/src/op-attrs/activation.dtg.cc b/lib/op-attrs/src/op-attrs/activation.dtg.cc new file mode 100644 index 0000000000..5671b1720f --- /dev/null +++ b/lib/op-attrs/src/op-attrs/activation.dtg.cc @@ -0,0 +1,86 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/activation.enum.toml +/* proj-data +{ + "generated_from": "2b0d2e3e825732838aa5be99f2f0e6df" +} +*/ + +#include "op-attrs/activation.dtg.h" + +#include +#include + +namespace std { +size_t hash::operator()(FlexFlow::Activation x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(Activation x) { + switch (x) { + case Activation::RELU: + return "RELU"; + case Activation::SIGMOID: + return "SIGMOID"; + case Activation::TANH: + return "TANH"; + case Activation::GELU: + return "GELU"; + default: + std::ostringstream oss; + oss << "Unknown Activation value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, Activation x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, Activation x) { + switch (x) { + case Activation::RELU: + j = "RELU"; + break; + case Activation::SIGMOID: + j = "SIGMOID"; + break; + case Activation::TANH: + j = "TANH"; + break; + case Activation::GELU: + j = "GELU"; + break; + default: + std::ostringstream oss; + oss << "Unknown Activation value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, Activation &x) { + std::string as_str = j.get(); + if (as_str == "RELU") { + x = Activation::RELU; + } else if (as_str == "SIGMOID") { + x = Activation::SIGMOID; + } else if (as_str == "TANH") { + x = Activation::TANH; + } else if (as_str == "GELU") { + x = Activation::GELU; + } else { + std::ostringstream oss; + oss << "Unknown Activation value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::element(FlexFlow::Activation::RELU, + FlexFlow::Activation::SIGMOID, + FlexFlow::Activation::TANH, + FlexFlow::Activation::GELU); +} +} // namespace rc diff --git a/lib/op-attrs/src/op-attrs/aggregate_op.dtg.cc b/lib/op-attrs/src/op-attrs/aggregate_op.dtg.cc new file mode 100644 index 0000000000..72beeb27c8 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/aggregate_op.dtg.cc @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/aggregate_op.enum.toml +/* proj-data +{ + "generated_from": "441fe9b0bb8f2dc2b31f74c58320ef30" +} +*/ + +#include "op-attrs/aggregate_op.dtg.h" + +#include +#include + +namespace std { +size_t hash::operator()(FlexFlow::AggregateOp x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(AggregateOp x) { + switch (x) { + case AggregateOp::SUM: + return "SUM"; + default: + std::ostringstream oss; + oss << "Unknown AggregateOp value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, AggregateOp x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, AggregateOp x) { + switch (x) { + case AggregateOp::SUM: + j = "SUM"; + break; + default: + std::ostringstream oss; + oss << "Unknown AggregateOp value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, AggregateOp &x) { + std::string as_str = j.get(); + if (as_str == "SUM") { + x = AggregateOp::SUM; + } else { + std::ostringstream oss; + oss << "Unknown AggregateOp value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::element(FlexFlow::AggregateOp::SUM); +} +} // namespace rc diff --git a/lib/op-attrs/src/op-attrs/as_dot.cc b/lib/op-attrs/src/op-attrs/as_dot.cc new file mode 100644 index 0000000000..f8d05de941 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/as_dot.cc @@ -0,0 +1,13 @@ +#include "op-attrs/as_dot.h" + +namespace FlexFlow { + +RecordFormatter as_dot(ComputationGraphOpAttrs const &attrs) { + NOT_IMPLEMENTED(); +} + +RecordFormatter as_dot(PCGOperatorAttrs const &attrs) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc b/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc new file mode 100644 index 0000000000..166416cbad --- /dev/null +++ b/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc @@ -0,0 +1,11 @@ +#include "op-attrs/computation_graph_op_attrs.h" +#include "op-attrs/get_op_type.h" + +namespace FlexFlow { + +OperatorType get_op_type(ComputationGraphOpAttrs const &attrs) { + return attrs.visit( + [](auto const &x) { return get_op_type(x); }); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.dtg.cc new file mode 100644 index 0000000000..9bcde22cd9 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.dtg.cc @@ -0,0 +1,597 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/computation_graph_op_attrs.variant.toml +/* proj-data +{ + "generated_from": "cc0ab49405423594ffa1d8f541235a48" +} +*/ + +#include "op-attrs/computation_graph_op_attrs.dtg.h" + +#include "fmt/format.h" +#include +#include + +namespace FlexFlow { +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::BatchMatmulAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::BatchNormAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::BroadcastAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs(::FlexFlow::CastAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::ConcatAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::Conv2DAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::DropoutAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::ElementBinaryAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::ElementUnaryAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::ElementScalarUnaryAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::EmbeddingAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs(::FlexFlow::FlatAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::GatherAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::InputAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::LayerNormAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::LinearAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::MultiHeadAttentionAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs(::FlexFlow::NoopAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::Pool2DAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::ReduceAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::ReverseAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::ReshapeAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::SplitAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::SoftmaxAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs(::FlexFlow::TopKAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::TransposeAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::WeightAttrs const &v) + : raw_variant(v) {} +bool ComputationGraphOpAttrs::operator==( + ComputationGraphOpAttrs const &other) const { + return this->raw_variant == other.raw_variant; +} +bool ComputationGraphOpAttrs::operator!=( + ComputationGraphOpAttrs const &other) const { + return this->raw_variant != other.raw_variant; +} +bool ComputationGraphOpAttrs::operator<( + ComputationGraphOpAttrs const &other) const { + return this->raw_variant < other.raw_variant; +} +bool ComputationGraphOpAttrs::operator>( + ComputationGraphOpAttrs const &other) const { + return this->raw_variant > other.raw_variant; +} +bool ComputationGraphOpAttrs::operator<=( + ComputationGraphOpAttrs const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool ComputationGraphOpAttrs::operator>=( + ComputationGraphOpAttrs const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::ComputationGraphOpAttrs>::operator()( + ::FlexFlow::ComputationGraphOpAttrs const &x) const { + return std::hash>{}(x.raw_variant); +} +} // namespace std +namespace nlohmann { +::FlexFlow::ComputationGraphOpAttrs + adl_serializer<::FlexFlow::ComputationGraphOpAttrs>::from_json( + json const &j) { + std::string key = j.at("type").template get(); + if (key == "batch_matmul") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::BatchMatmulAttrs>()}; + } else if (key == "batch_norm") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::BatchNormAttrs>()}; + } else if (key == "broadcast") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::BroadcastAttrs>()}; + } else if (key == "cast") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::CastAttrs>()}; + } else if (key == "concat") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::ConcatAttrs>()}; + } else if (key == "conv2d") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::Conv2DAttrs>()}; + } else if (key == "dropout") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::DropoutAttrs>()}; + } else if (key == "element_binary") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::ElementBinaryAttrs>()}; + } else if (key == "element_unary") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::ElementUnaryAttrs>()}; + } else if (key == "element_scalar_unary") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::ElementScalarUnaryAttrs>()}; + } else if (key == "embedding") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::EmbeddingAttrs>()}; + } else if (key == "flat") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::FlatAttrs>()}; + } else if (key == "gather") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::GatherAttrs>()}; + } else if (key == "input") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::InputAttrs>()}; + } else if (key == "layer_norm") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::LayerNormAttrs>()}; + } else if (key == "linear") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::LinearAttrs>()}; + } else if (key == "multi_head_attention") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::MultiHeadAttentionAttrs>()}; + } else if (key == "noop") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::NoopAttrs>()}; + } else if (key == "pool2d") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::Pool2DAttrs>()}; + } else if (key == "reduce") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::ReduceAttrs>()}; + } else if (key == "reverse") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::ReverseAttrs>()}; + } else if (key == "reshape") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::ReshapeAttrs>()}; + } else if (key == "split") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::SplitAttrs>()}; + } else if (key == "softmax") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::SoftmaxAttrs>()}; + } else if (key == "topk") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::TopKAttrs>()}; + } else if (key == "transpose") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::TransposeAttrs>()}; + } else if (key == "weight") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::WeightAttrs>()}; + } else { + throw std::runtime_error(fmt::format("Unknown type key {}", key)); + } +} +void adl_serializer<::FlexFlow::ComputationGraphOpAttrs>::to_json( + json &j, ::FlexFlow::ComputationGraphOpAttrs const &x) { + j["__type"] = "ComputationGraphOpAttrs"; + switch (x.index()) { + case 0: { + j["type"] = "batch_matmul"; + j["value"] = x.get<::FlexFlow::BatchMatmulAttrs>(); + break; + } + case 1: { + j["type"] = "batch_norm"; + j["value"] = x.get<::FlexFlow::BatchNormAttrs>(); + break; + } + case 2: { + j["type"] = "broadcast"; + j["value"] = x.get<::FlexFlow::BroadcastAttrs>(); + break; + } + case 3: { + j["type"] = "cast"; + j["value"] = x.get<::FlexFlow::CastAttrs>(); + break; + } + case 4: { + j["type"] = "concat"; + j["value"] = x.get<::FlexFlow::ConcatAttrs>(); + break; + } + case 5: { + j["type"] = "conv2d"; + j["value"] = x.get<::FlexFlow::Conv2DAttrs>(); + break; + } + case 6: { + j["type"] = "dropout"; + j["value"] = x.get<::FlexFlow::DropoutAttrs>(); + break; + } + case 7: { + j["type"] = "element_binary"; + j["value"] = x.get<::FlexFlow::ElementBinaryAttrs>(); + break; + } + case 8: { + j["type"] = "element_unary"; + j["value"] = x.get<::FlexFlow::ElementUnaryAttrs>(); + break; + } + case 9: { + j["type"] = "element_scalar_unary"; + j["value"] = x.get<::FlexFlow::ElementScalarUnaryAttrs>(); + break; + } + case 10: { + j["type"] = "embedding"; + j["value"] = x.get<::FlexFlow::EmbeddingAttrs>(); + break; + } + case 11: { + j["type"] = "flat"; + j["value"] = x.get<::FlexFlow::FlatAttrs>(); + break; + } + case 12: { + j["type"] = "gather"; + j["value"] = x.get<::FlexFlow::GatherAttrs>(); + break; + } + case 13: { + j["type"] = "input"; + j["value"] = x.get<::FlexFlow::InputAttrs>(); + break; + } + case 14: { + j["type"] = "layer_norm"; + j["value"] = x.get<::FlexFlow::LayerNormAttrs>(); + break; + } + case 15: { + j["type"] = "linear"; + j["value"] = x.get<::FlexFlow::LinearAttrs>(); + break; + } + case 16: { + j["type"] = "multi_head_attention"; + j["value"] = x.get<::FlexFlow::MultiHeadAttentionAttrs>(); + break; + } + case 17: { + j["type"] = "noop"; + j["value"] = x.get<::FlexFlow::NoopAttrs>(); + break; + } + case 18: { + j["type"] = "pool2d"; + j["value"] = x.get<::FlexFlow::Pool2DAttrs>(); + break; + } + case 19: { + j["type"] = "reduce"; + j["value"] = x.get<::FlexFlow::ReduceAttrs>(); + break; + } + case 20: { + j["type"] = "reverse"; + j["value"] = x.get<::FlexFlow::ReverseAttrs>(); + break; + } + case 21: { + j["type"] = "reshape"; + j["value"] = x.get<::FlexFlow::ReshapeAttrs>(); + break; + } + case 22: { + j["type"] = "split"; + j["value"] = x.get<::FlexFlow::SplitAttrs>(); + break; + } + case 23: { + j["type"] = "softmax"; + j["value"] = x.get<::FlexFlow::SoftmaxAttrs>(); + break; + } + case 24: { + j["type"] = "topk"; + j["value"] = x.get<::FlexFlow::TopKAttrs>(); + break; + } + case 25: { + j["type"] = "transpose"; + j["value"] = x.get<::FlexFlow::TransposeAttrs>(); + break; + } + case 26: { + j["type"] = "weight"; + j["value"] = x.get<::FlexFlow::WeightAttrs>(); + break; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type ComputationGraphOpAttrs", x.index())); + } + } +} +} // namespace nlohmann +namespace rc { +Gen<::FlexFlow::ComputationGraphOpAttrs> + Arbitrary<::FlexFlow::ComputationGraphOpAttrs>::arbitrary() { + return gen::oneOf(gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::BatchMatmulAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::BatchNormAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::BroadcastAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::CastAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::ConcatAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::Conv2DAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::DropoutAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::ElementBinaryAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::ElementUnaryAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::ElementScalarUnaryAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::EmbeddingAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::FlatAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::GatherAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::InputAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::LayerNormAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::LinearAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::MultiHeadAttentionAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::NoopAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::Pool2DAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::ReduceAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::ReverseAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::ReshapeAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::SplitAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::SoftmaxAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::TopKAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::TransposeAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::WeightAttrs>())); +} +} // namespace rc +namespace FlexFlow { +std::string format_as(::FlexFlow::ComputationGraphOpAttrs const &x) { + std::ostringstream oss; + switch (x.index()) { + case 0: { + oss << ""; + break; + } + case 1: { + oss << ""; + break; + } + case 2: { + oss << ""; + break; + } + case 3: { + oss << ""; + break; + } + case 4: { + oss << ""; + break; + } + case 5: { + oss << ""; + break; + } + case 6: { + oss << ""; + break; + } + case 7: { + oss << ""; + break; + } + case 8: { + oss << ""; + break; + } + case 9: { + oss << ""; + break; + } + case 10: { + oss << ""; + break; + } + case 11: { + oss << ""; + break; + } + case 12: { + oss << ""; + break; + } + case 13: { + oss << ""; + break; + } + case 14: { + oss << ""; + break; + } + case 15: { + oss << ""; + break; + } + case 16: { + oss << ""; + break; + } + case 17: { + oss << ""; + break; + } + case 18: { + oss << ""; + break; + } + case 19: { + oss << ""; + break; + } + case 20: { + oss << ""; + break; + } + case 21: { + oss << ""; + break; + } + case 22: { + oss << ""; + break; + } + case 23: { + oss << ""; + break; + } + case 24: { + oss << ""; + break; + } + case 25: { + oss << ""; + break; + } + case 26: { + oss << ""; + break; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type ComputationGraphOpAttrs", x.index())); + break; + } + } + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::ComputationGraphOpAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/datatype.dtg.cc b/lib/op-attrs/src/op-attrs/datatype.dtg.cc new file mode 100644 index 0000000000..a9c1d54f0e --- /dev/null +++ b/lib/op-attrs/src/op-attrs/datatype.dtg.cc @@ -0,0 +1,102 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/datatype.enum.toml +/* proj-data +{ + "generated_from": "8315d0aa0a65b00c13aa580e923592ef" +} +*/ + +#include "op-attrs/datatype.dtg.h" + +#include +#include + +namespace std { +size_t hash::operator()(FlexFlow::DataType x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(DataType x) { + switch (x) { + case DataType::BOOL: + return "BOOL"; + case DataType::INT32: + return "INT32"; + case DataType::INT64: + return "INT64"; + case DataType::HALF: + return "HALF"; + case DataType::FLOAT: + return "FLOAT"; + case DataType::DOUBLE: + return "DOUBLE"; + default: + std::ostringstream oss; + oss << "Unknown DataType value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, DataType x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, DataType x) { + switch (x) { + case DataType::BOOL: + j = "BOOL"; + break; + case DataType::INT32: + j = "INT32"; + break; + case DataType::INT64: + j = "INT64"; + break; + case DataType::HALF: + j = "HALF"; + break; + case DataType::FLOAT: + j = "FLOAT"; + break; + case DataType::DOUBLE: + j = "DOUBLE"; + break; + default: + std::ostringstream oss; + oss << "Unknown DataType value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, DataType &x) { + std::string as_str = j.get(); + if (as_str == "BOOL") { + x = DataType::BOOL; + } else if (as_str == "INT32") { + x = DataType::INT32; + } else if (as_str == "INT64") { + x = DataType::INT64; + } else if (as_str == "HALF") { + x = DataType::HALF; + } else if (as_str == "FLOAT") { + x = DataType::FLOAT; + } else if (as_str == "DOUBLE") { + x = DataType::DOUBLE; + } else { + std::ostringstream oss; + oss << "Unknown DataType value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::element(FlexFlow::DataType::BOOL, + FlexFlow::DataType::INT32, + FlexFlow::DataType::INT64, + FlexFlow::DataType::HALF, + FlexFlow::DataType::FLOAT, + FlexFlow::DataType::DOUBLE); +} +} // namespace rc diff --git a/lib/op-attrs/src/op-attrs/ff_dim.dtg.cc b/lib/op-attrs/src/op-attrs/ff_dim.dtg.cc new file mode 100644 index 0000000000..8b22dfd18d --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ff_dim.dtg.cc @@ -0,0 +1,68 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ff_dim.struct.toml +/* proj-data +{ + "generated_from": "a5fa89a024e95c4f2d52681a74cab30f" +} +*/ + +#include "op-attrs/ff_dim.dtg.h" + +#include + +namespace FlexFlow { +ff_dim_t::ff_dim_t(int const &value) : value(value) {} +bool ff_dim_t::operator==(ff_dim_t const &other) const { + return std::tie(this->value) == std::tie(other.value); +} +bool ff_dim_t::operator!=(ff_dim_t const &other) const { + return std::tie(this->value) != std::tie(other.value); +} +bool ff_dim_t::operator<(ff_dim_t const &other) const { + return std::tie(this->value) < std::tie(other.value); +} +bool ff_dim_t::operator>(ff_dim_t const &other) const { + return std::tie(this->value) > std::tie(other.value); +} +bool ff_dim_t::operator<=(ff_dim_t const &other) const { + return std::tie(this->value) <= std::tie(other.value); +} +bool ff_dim_t::operator>=(ff_dim_t const &other) const { + return std::tie(this->value) >= std::tie(other.value); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()(FlexFlow::ff_dim_t const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.value) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ff_dim_t + adl_serializer::from_json(json const &j) { + return {j.at("value").template get()}; +} +void adl_serializer::to_json(json &j, + FlexFlow::ff_dim_t const &v) { + j["__type"] = "ff_dim_t"; + j["value"] = v.value; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ff_dim_t const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ff_dim_t const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/get_op_type.cc b/lib/op-attrs/src/op-attrs/get_op_type.cc similarity index 60% rename from lib/op-attrs/src/get_op_type.cc rename to lib/op-attrs/src/op-attrs/get_op_type.cc index 3fa401b647..aced8d873c 100644 --- a/lib/op-attrs/src/get_op_type.cc +++ b/lib/op-attrs/src/op-attrs/get_op_type.cc @@ -3,25 +3,25 @@ namespace FlexFlow { OperatorType get_op_type(BatchMatmulAttrs const &) { - return Op::BATCHMATMUL; + return OperatorType::BATCHMATMUL; } OperatorType get_op_type(BatchNormAttrs const &) { - return Op::BATCHNORM; + return OperatorType::BATCHNORM; } OperatorType get_op_type(BroadcastAttrs const &) { - return Op::BROADCAST; + return OperatorType::BROADCAST; } OperatorType get_op_type(CastAttrs const &) { - return Op::CAST; + return OperatorType::CAST; } OperatorType get_op_type(ConcatAttrs const &) { - return Op::CONCAT; + return OperatorType::CONCAT; } OperatorType get_op_type(Conv2DAttrs const &) { - return Op::CONV2D; + return OperatorType::CONV2D; } OperatorType get_op_type(DropoutAttrs const &) { - return Op::DROPOUT; + return OperatorType::DROPOUT; } OperatorType get_op_type(ElementBinaryAttrs const &attrs) { return attrs.type; @@ -33,64 +33,67 @@ OperatorType get_op_type(ElementScalarUnaryAttrs const &attrs) { return attrs.op_type; } OperatorType get_op_type(EmbeddingAttrs const &) { - return Op::EMBEDDING; + return OperatorType::EMBEDDING; } OperatorType get_op_type(FlatAttrs const &) { - return Op::FLAT; + return OperatorType::FLAT; } OperatorType get_op_type(GatherAttrs const &) { - return Op::GATHER; + return OperatorType::GATHER; } OperatorType get_op_type(InputAttrs const &) { - return Op::INPUT; + return OperatorType::INPUT; } OperatorType get_op_type(LayerNormAttrs const &) { - return Op::LAYERNORM; + return OperatorType::LAYERNORM; } OperatorType get_op_type(LinearAttrs const &) { - return Op::LINEAR; + return OperatorType::LINEAR; } OperatorType get_op_type(MultiHeadAttentionAttrs const &) { - return Op::MULTIHEAD_ATTENTION; + return OperatorType::MULTIHEAD_ATTENTION; } OperatorType get_op_type(NoopAttrs const &) { - return Op::NOOP; + return OperatorType::NOOP; } OperatorType get_op_type(Pool2DAttrs const &) { - return Op::POOL2D; + return OperatorType::POOL2D; } -OperatorType get_op_type(ReduceAttrs const &) { - return Op::REDUCE_SUM; +OperatorType get_op_type(ReduceAttrs const &attrs) { + return attrs.op_type; } OperatorType get_op_type(ReshapeAttrs const &) { - return Op::RESHAPE; + return OperatorType::RESHAPE; +} +OperatorType get_op_type(ReverseAttrs const &) { + return OperatorType::REVERSE; } OperatorType get_op_type(SplitAttrs const &) { - return Op::SPLIT; + return OperatorType::SPLIT; } OperatorType get_op_type(SoftmaxAttrs const &) { - return Op::SOFTMAX; + return OperatorType::SOFTMAX; } OperatorType get_op_type(TopKAttrs const &) { - return Op::TOPK; + return OperatorType::TOPK; } OperatorType get_op_type(TransposeAttrs const &) { - return Op::TRANSPOSE; + return OperatorType::TRANSPOSE; } OperatorType get_op_type(CombineAttrs const &) { - return Op::COMBINE; + return OperatorType::COMBINE; } OperatorType get_op_type(ReductionAttrs const &) { - return Op::REDUCTION; + return OperatorType::REDUCTION; } OperatorType get_op_type(RepartitionAttrs const &) { - return Op::REPARTITION; + return OperatorType::REPARTITION; } OperatorType get_op_type(ReplicateAttrs const &) { - return Op::REPLICATE; + return OperatorType::REPLICATE; } -OperatorType get_op_type(ReverseAttrs const &attrs) { - return Op::REVERSE; +OperatorType get_op_type(WeightAttrs const &) { + return OperatorType::WEIGHT; } } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/l1_regularizer_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/l1_regularizer_attrs.dtg.cc new file mode 100644 index 0000000000..ed06df2c78 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/l1_regularizer_attrs.dtg.cc @@ -0,0 +1,76 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/l1_regularizer_attrs.struct.toml +/* proj-data +{ + "generated_from": "50968fb8a3d43395d0eab7594f4935c0" +} +*/ + +#include "op-attrs/l1_regularizer_attrs.dtg.h" + +#include + +namespace FlexFlow { +L1RegularizerAttrs::L1RegularizerAttrs(float const &lambda) : lambda(lambda) {} +bool L1RegularizerAttrs::operator==(L1RegularizerAttrs const &other) const { + return std::tie(this->lambda) == std::tie(other.lambda); +} +bool L1RegularizerAttrs::operator!=(L1RegularizerAttrs const &other) const { + return std::tie(this->lambda) != std::tie(other.lambda); +} +bool L1RegularizerAttrs::operator<(L1RegularizerAttrs const &other) const { + return std::tie(this->lambda) < std::tie(other.lambda); +} +bool L1RegularizerAttrs::operator>(L1RegularizerAttrs const &other) const { + return std::tie(this->lambda) > std::tie(other.lambda); +} +bool L1RegularizerAttrs::operator<=(L1RegularizerAttrs const &other) const { + return std::tie(this->lambda) <= std::tie(other.lambda); +} +bool L1RegularizerAttrs::operator>=(L1RegularizerAttrs const &other) const { + return std::tie(this->lambda) >= std::tie(other.lambda); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::L1RegularizerAttrs const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.lambda) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::L1RegularizerAttrs + adl_serializer::from_json(json const &j) { + return {j.at("lambda").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::L1RegularizerAttrs const &v) { + j["__type"] = "L1RegularizerAttrs"; + j["lambda"] = v.lambda; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(L1RegularizerAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, L1RegularizerAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/l2_regularizer_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/l2_regularizer_attrs.dtg.cc new file mode 100644 index 0000000000..f0f3f34ee5 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/l2_regularizer_attrs.dtg.cc @@ -0,0 +1,76 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/l2_regularizer_attrs.struct.toml +/* proj-data +{ + "generated_from": "c4f182e547ab6f0d5613e7eeb95d438e" +} +*/ + +#include "op-attrs/l2_regularizer_attrs.dtg.h" + +#include + +namespace FlexFlow { +L2RegularizerAttrs::L2RegularizerAttrs(float const &lambda) : lambda(lambda) {} +bool L2RegularizerAttrs::operator==(L2RegularizerAttrs const &other) const { + return std::tie(this->lambda) == std::tie(other.lambda); +} +bool L2RegularizerAttrs::operator!=(L2RegularizerAttrs const &other) const { + return std::tie(this->lambda) != std::tie(other.lambda); +} +bool L2RegularizerAttrs::operator<(L2RegularizerAttrs const &other) const { + return std::tie(this->lambda) < std::tie(other.lambda); +} +bool L2RegularizerAttrs::operator>(L2RegularizerAttrs const &other) const { + return std::tie(this->lambda) > std::tie(other.lambda); +} +bool L2RegularizerAttrs::operator<=(L2RegularizerAttrs const &other) const { + return std::tie(this->lambda) <= std::tie(other.lambda); +} +bool L2RegularizerAttrs::operator>=(L2RegularizerAttrs const &other) const { + return std::tie(this->lambda) >= std::tie(other.lambda); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::L2RegularizerAttrs const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.lambda) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::L2RegularizerAttrs + adl_serializer::from_json(json const &j) { + return {j.at("lambda").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::L2RegularizerAttrs const &v) { + j["__type"] = "L2RegularizerAttrs"; + j["lambda"] = v.lambda; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(L2RegularizerAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, L2RegularizerAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/operator_type.cc b/lib/op-attrs/src/op-attrs/operator_type.cc new file mode 100644 index 0000000000..5a516ef122 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/operator_type.cc @@ -0,0 +1,24 @@ +#include "op-attrs/operator_type.h" + +namespace FlexFlow { + +std::string get_operator_type_name(OperatorType op) { + return fmt::to_string(op); +} + +bool is_parallel_op(OperatorType const &t) { + switch (t) { + case OperatorType::REPARTITION: + case OperatorType::COMBINE: + case OperatorType::REPLICATE: + case OperatorType::REDUCTION: + case OperatorType::BATCH: + case OperatorType::PIPELINE: + case OperatorType::FUSED_PARALLEL: + return true; + default: + return false; + } +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/operator_type.dtg.cc b/lib/op-attrs/src/op-attrs/operator_type.dtg.cc new file mode 100644 index 0000000000..07b6396a5a --- /dev/null +++ b/lib/op-attrs/src/op-attrs/operator_type.dtg.cc @@ -0,0 +1,720 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/operator_type.enum.toml +/* proj-data +{ + "generated_from": "c1c4687ef2fbc7dad996e5c25d47124c" +} +*/ + +#include "op-attrs/operator_type.dtg.h" + +#include +#include + +namespace std { +size_t + hash::operator()(FlexFlow::OperatorType x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(OperatorType x) { + switch (x) { + case OperatorType::NOOP: + return "NOOP"; + case OperatorType::INPUT: + return "INPUT"; + case OperatorType::WEIGHT: + return "WEIGHT"; + case OperatorType::CONV2D: + return "CONV2D"; + case OperatorType::DROPOUT: + return "DROPOUT"; + case OperatorType::LINEAR: + return "LINEAR"; + case OperatorType::BATCHMATMUL: + return "BATCHMATMUL"; + case OperatorType::POOL2D: + return "POOL2D"; + case OperatorType::SCALAR_MULTIPLY: + return "SCALAR_MULTIPLY"; + case OperatorType::SCALAR_ADD: + return "SCALAR_ADD"; + case OperatorType::SCALAR_FLOOR_DIV: + return "SCALAR_FLOOR_DIV"; + case OperatorType::SCALAR_TRUE_DIV: + return "SCALAR_TRUE_DIV"; + case OperatorType::SCALAR_SUB: + return "SCALAR_SUB"; + case OperatorType::RELU: + return "RELU"; + case OperatorType::IDENTITY: + return "IDENTITY"; + case OperatorType::SIGMOID: + return "SIGMOID"; + case OperatorType::TANH: + return "TANH"; + case OperatorType::ELU: + return "ELU"; + case OperatorType::FLAT: + return "FLAT"; + case OperatorType::SOFTMAX: + return "SOFTMAX"; + case OperatorType::BATCHNORM: + return "BATCHNORM"; + case OperatorType::CONCAT: + return "CONCAT"; + case OperatorType::SPLIT: + return "SPLIT"; + case OperatorType::EMBEDDING: + return "EMBEDDING"; + case OperatorType::CACHE: + return "CACHE"; + case OperatorType::RESHAPE: + return "RESHAPE"; + case OperatorType::REVERSE: + return "REVERSE"; + case OperatorType::TRANSPOSE: + return "TRANSPOSE"; + case OperatorType::EW_ADD: + return "EW_ADD"; + case OperatorType::EW_MUL: + return "EW_MUL"; + case OperatorType::MATMUL: + return "MATMUL"; + case OperatorType::MUL: + return "MUL"; + case OperatorType::ENLARGE: + return "ENLARGE"; + case OperatorType::SQUEEZE: + return "SQUEEZE"; + case OperatorType::UNSQUEEZE: + return "UNSQUEEZE"; + case OperatorType::EW_SUB: + return "EW_SUB"; + case OperatorType::EW_DIV: + return "EW_DIV"; + case OperatorType::EW_EQUAL: + return "EW_EQUAL"; + case OperatorType::EW_GREATER: + return "EW_GREATER"; + case OperatorType::EW_LESS: + return "EW_LESS"; + case OperatorType::EW_MAX: + return "EW_MAX"; + case OperatorType::EW_MIN: + return "EW_MIN"; + case OperatorType::REDUCE_ARGMAX: + return "REDUCE_ARGMAX"; + case OperatorType::REDUCE_ARGMIN: + return "REDUCE_ARGMIN"; + case OperatorType::REDUCE_MAX: + return "REDUCE_MAX"; + case OperatorType::REDUCE_MEAN: + return "REDUCE_MEAN"; + case OperatorType::REDUCE_MIN: + return "REDUCE_MIN"; + case OperatorType::REDUCE_PROD: + return "REDUCE_PROD"; + case OperatorType::REDUCE_SUM: + return "REDUCE_SUM"; + case OperatorType::PAD: + return "PAD"; + case OperatorType::SHAPE: + return "SHAPE"; + case OperatorType::SIZE: + return "SIZE"; + case OperatorType::TOPK: + return "TOPK"; + case OperatorType::WHERE: + return "WHERE"; + case OperatorType::CEIL: + return "CEIL"; + case OperatorType::CAST: + return "CAST"; + case OperatorType::EXP: + return "EXP"; + case OperatorType::ROUND: + return "ROUND"; + case OperatorType::LOG: + return "LOG"; + case OperatorType::LOGICAL_NOT: + return "LOGICAL_NOT"; + case OperatorType::SQRT: + return "SQRT"; + case OperatorType::SIN: + return "SIN"; + case OperatorType::COS: + return "COS"; + case OperatorType::LEAKYRELU: + return "LEAKYRELU"; + case OperatorType::SLICE: + return "SLICE"; + case OperatorType::RESIZE: + return "RESIZE"; + case OperatorType::PRELU: + return "PRELU"; + case OperatorType::GELU: + return "GELU"; + case OperatorType::MULTIHEAD_ATTENTION: + return "MULTIHEAD_ATTENTION"; + case OperatorType::FUSED: + return "FUSED"; + case OperatorType::RSQRT: + return "RSQRT"; + case OperatorType::POW: + return "POW"; + case OperatorType::MEAN: + return "MEAN"; + case OperatorType::LAYERNORM: + return "LAYERNORM"; + case OperatorType::GATHER: + return "GATHER"; + case OperatorType::BROADCAST: + return "BROADCAST"; + case OperatorType::REPARTITION: + return "REPARTITION"; + case OperatorType::COMBINE: + return "COMBINE"; + case OperatorType::REPLICATE: + return "REPLICATE"; + case OperatorType::REDUCTION: + return "REDUCTION"; + case OperatorType::BATCH: + return "BATCH"; + case OperatorType::PIPELINE: + return "PIPELINE"; + case OperatorType::FUSED_PARALLEL: + return "FUSED_PARALLEL"; + default: + std::ostringstream oss; + oss << "Unknown OperatorType value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, OperatorType x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, OperatorType x) { + switch (x) { + case OperatorType::NOOP: + j = "NOOP"; + break; + case OperatorType::INPUT: + j = "INPUT"; + break; + case OperatorType::WEIGHT: + j = "WEIGHT"; + break; + case OperatorType::CONV2D: + j = "CONV2D"; + break; + case OperatorType::DROPOUT: + j = "DROPOUT"; + break; + case OperatorType::LINEAR: + j = "LINEAR"; + break; + case OperatorType::BATCHMATMUL: + j = "BATCHMATMUL"; + break; + case OperatorType::POOL2D: + j = "POOL2D"; + break; + case OperatorType::SCALAR_MULTIPLY: + j = "SCALAR_MULTIPLY"; + break; + case OperatorType::SCALAR_ADD: + j = "SCALAR_ADD"; + break; + case OperatorType::SCALAR_FLOOR_DIV: + j = "SCALAR_FLOOR_DIV"; + break; + case OperatorType::SCALAR_TRUE_DIV: + j = "SCALAR_TRUE_DIV"; + break; + case OperatorType::SCALAR_SUB: + j = "SCALAR_SUB"; + break; + case OperatorType::RELU: + j = "RELU"; + break; + case OperatorType::IDENTITY: + j = "IDENTITY"; + break; + case OperatorType::SIGMOID: + j = "SIGMOID"; + break; + case OperatorType::TANH: + j = "TANH"; + break; + case OperatorType::ELU: + j = "ELU"; + break; + case OperatorType::FLAT: + j = "FLAT"; + break; + case OperatorType::SOFTMAX: + j = "SOFTMAX"; + break; + case OperatorType::BATCHNORM: + j = "BATCHNORM"; + break; + case OperatorType::CONCAT: + j = "CONCAT"; + break; + case OperatorType::SPLIT: + j = "SPLIT"; + break; + case OperatorType::EMBEDDING: + j = "EMBEDDING"; + break; + case OperatorType::CACHE: + j = "CACHE"; + break; + case OperatorType::RESHAPE: + j = "RESHAPE"; + break; + case OperatorType::REVERSE: + j = "REVERSE"; + break; + case OperatorType::TRANSPOSE: + j = "TRANSPOSE"; + break; + case OperatorType::EW_ADD: + j = "EW_ADD"; + break; + case OperatorType::EW_MUL: + j = "EW_MUL"; + break; + case OperatorType::MATMUL: + j = "MATMUL"; + break; + case OperatorType::MUL: + j = "MUL"; + break; + case OperatorType::ENLARGE: + j = "ENLARGE"; + break; + case OperatorType::SQUEEZE: + j = "SQUEEZE"; + break; + case OperatorType::UNSQUEEZE: + j = "UNSQUEEZE"; + break; + case OperatorType::EW_SUB: + j = "EW_SUB"; + break; + case OperatorType::EW_DIV: + j = "EW_DIV"; + break; + case OperatorType::EW_EQUAL: + j = "EW_EQUAL"; + break; + case OperatorType::EW_GREATER: + j = "EW_GREATER"; + break; + case OperatorType::EW_LESS: + j = "EW_LESS"; + break; + case OperatorType::EW_MAX: + j = "EW_MAX"; + break; + case OperatorType::EW_MIN: + j = "EW_MIN"; + break; + case OperatorType::REDUCE_ARGMAX: + j = "REDUCE_ARGMAX"; + break; + case OperatorType::REDUCE_ARGMIN: + j = "REDUCE_ARGMIN"; + break; + case OperatorType::REDUCE_MAX: + j = "REDUCE_MAX"; + break; + case OperatorType::REDUCE_MEAN: + j = "REDUCE_MEAN"; + break; + case OperatorType::REDUCE_MIN: + j = "REDUCE_MIN"; + break; + case OperatorType::REDUCE_PROD: + j = "REDUCE_PROD"; + break; + case OperatorType::REDUCE_SUM: + j = "REDUCE_SUM"; + break; + case OperatorType::PAD: + j = "PAD"; + break; + case OperatorType::SHAPE: + j = "SHAPE"; + break; + case OperatorType::SIZE: + j = "SIZE"; + break; + case OperatorType::TOPK: + j = "TOPK"; + break; + case OperatorType::WHERE: + j = "WHERE"; + break; + case OperatorType::CEIL: + j = "CEIL"; + break; + case OperatorType::CAST: + j = "CAST"; + break; + case OperatorType::EXP: + j = "EXP"; + break; + case OperatorType::ROUND: + j = "ROUND"; + break; + case OperatorType::LOG: + j = "LOG"; + break; + case OperatorType::LOGICAL_NOT: + j = "LOGICAL_NOT"; + break; + case OperatorType::SQRT: + j = "SQRT"; + break; + case OperatorType::SIN: + j = "SIN"; + break; + case OperatorType::COS: + j = "COS"; + break; + case OperatorType::LEAKYRELU: + j = "LEAKYRELU"; + break; + case OperatorType::SLICE: + j = "SLICE"; + break; + case OperatorType::RESIZE: + j = "RESIZE"; + break; + case OperatorType::PRELU: + j = "PRELU"; + break; + case OperatorType::GELU: + j = "GELU"; + break; + case OperatorType::MULTIHEAD_ATTENTION: + j = "MULTIHEAD_ATTENTION"; + break; + case OperatorType::FUSED: + j = "FUSED"; + break; + case OperatorType::RSQRT: + j = "RSQRT"; + break; + case OperatorType::POW: + j = "POW"; + break; + case OperatorType::MEAN: + j = "MEAN"; + break; + case OperatorType::LAYERNORM: + j = "LAYERNORM"; + break; + case OperatorType::GATHER: + j = "GATHER"; + break; + case OperatorType::BROADCAST: + j = "BROADCAST"; + break; + case OperatorType::REPARTITION: + j = "REPARTITION"; + break; + case OperatorType::COMBINE: + j = "COMBINE"; + break; + case OperatorType::REPLICATE: + j = "REPLICATE"; + break; + case OperatorType::REDUCTION: + j = "REDUCTION"; + break; + case OperatorType::BATCH: + j = "BATCH"; + break; + case OperatorType::PIPELINE: + j = "PIPELINE"; + break; + case OperatorType::FUSED_PARALLEL: + j = "FUSED_PARALLEL"; + break; + default: + std::ostringstream oss; + oss << "Unknown OperatorType value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, OperatorType &x) { + std::string as_str = j.get(); + if (as_str == "NOOP") { + x = OperatorType::NOOP; + } else if (as_str == "INPUT") { + x = OperatorType::INPUT; + } else if (as_str == "WEIGHT") { + x = OperatorType::WEIGHT; + } else if (as_str == "CONV2D") { + x = OperatorType::CONV2D; + } else if (as_str == "DROPOUT") { + x = OperatorType::DROPOUT; + } else if (as_str == "LINEAR") { + x = OperatorType::LINEAR; + } else if (as_str == "BATCHMATMUL") { + x = OperatorType::BATCHMATMUL; + } else if (as_str == "POOL2D") { + x = OperatorType::POOL2D; + } else if (as_str == "SCALAR_MULTIPLY") { + x = OperatorType::SCALAR_MULTIPLY; + } else if (as_str == "SCALAR_ADD") { + x = OperatorType::SCALAR_ADD; + } else if (as_str == "SCALAR_FLOOR_DIV") { + x = OperatorType::SCALAR_FLOOR_DIV; + } else if (as_str == "SCALAR_TRUE_DIV") { + x = OperatorType::SCALAR_TRUE_DIV; + } else if (as_str == "SCALAR_SUB") { + x = OperatorType::SCALAR_SUB; + } else if (as_str == "RELU") { + x = OperatorType::RELU; + } else if (as_str == "IDENTITY") { + x = OperatorType::IDENTITY; + } else if (as_str == "SIGMOID") { + x = OperatorType::SIGMOID; + } else if (as_str == "TANH") { + x = OperatorType::TANH; + } else if (as_str == "ELU") { + x = OperatorType::ELU; + } else if (as_str == "FLAT") { + x = OperatorType::FLAT; + } else if (as_str == "SOFTMAX") { + x = OperatorType::SOFTMAX; + } else if (as_str == "BATCHNORM") { + x = OperatorType::BATCHNORM; + } else if (as_str == "CONCAT") { + x = OperatorType::CONCAT; + } else if (as_str == "SPLIT") { + x = OperatorType::SPLIT; + } else if (as_str == "EMBEDDING") { + x = OperatorType::EMBEDDING; + } else if (as_str == "CACHE") { + x = OperatorType::CACHE; + } else if (as_str == "RESHAPE") { + x = OperatorType::RESHAPE; + } else if (as_str == "REVERSE") { + x = OperatorType::REVERSE; + } else if (as_str == "TRANSPOSE") { + x = OperatorType::TRANSPOSE; + } else if (as_str == "EW_ADD") { + x = OperatorType::EW_ADD; + } else if (as_str == "EW_MUL") { + x = OperatorType::EW_MUL; + } else if (as_str == "MATMUL") { + x = OperatorType::MATMUL; + } else if (as_str == "MUL") { + x = OperatorType::MUL; + } else if (as_str == "ENLARGE") { + x = OperatorType::ENLARGE; + } else if (as_str == "SQUEEZE") { + x = OperatorType::SQUEEZE; + } else if (as_str == "UNSQUEEZE") { + x = OperatorType::UNSQUEEZE; + } else if (as_str == "EW_SUB") { + x = OperatorType::EW_SUB; + } else if (as_str == "EW_DIV") { + x = OperatorType::EW_DIV; + } else if (as_str == "EW_EQUAL") { + x = OperatorType::EW_EQUAL; + } else if (as_str == "EW_GREATER") { + x = OperatorType::EW_GREATER; + } else if (as_str == "EW_LESS") { + x = OperatorType::EW_LESS; + } else if (as_str == "EW_MAX") { + x = OperatorType::EW_MAX; + } else if (as_str == "EW_MIN") { + x = OperatorType::EW_MIN; + } else if (as_str == "REDUCE_ARGMAX") { + x = OperatorType::REDUCE_ARGMAX; + } else if (as_str == "REDUCE_ARGMIN") { + x = OperatorType::REDUCE_ARGMIN; + } else if (as_str == "REDUCE_MAX") { + x = OperatorType::REDUCE_MAX; + } else if (as_str == "REDUCE_MEAN") { + x = OperatorType::REDUCE_MEAN; + } else if (as_str == "REDUCE_MIN") { + x = OperatorType::REDUCE_MIN; + } else if (as_str == "REDUCE_PROD") { + x = OperatorType::REDUCE_PROD; + } else if (as_str == "REDUCE_SUM") { + x = OperatorType::REDUCE_SUM; + } else if (as_str == "PAD") { + x = OperatorType::PAD; + } else if (as_str == "SHAPE") { + x = OperatorType::SHAPE; + } else if (as_str == "SIZE") { + x = OperatorType::SIZE; + } else if (as_str == "TOPK") { + x = OperatorType::TOPK; + } else if (as_str == "WHERE") { + x = OperatorType::WHERE; + } else if (as_str == "CEIL") { + x = OperatorType::CEIL; + } else if (as_str == "CAST") { + x = OperatorType::CAST; + } else if (as_str == "EXP") { + x = OperatorType::EXP; + } else if (as_str == "ROUND") { + x = OperatorType::ROUND; + } else if (as_str == "LOG") { + x = OperatorType::LOG; + } else if (as_str == "LOGICAL_NOT") { + x = OperatorType::LOGICAL_NOT; + } else if (as_str == "SQRT") { + x = OperatorType::SQRT; + } else if (as_str == "SIN") { + x = OperatorType::SIN; + } else if (as_str == "COS") { + x = OperatorType::COS; + } else if (as_str == "LEAKYRELU") { + x = OperatorType::LEAKYRELU; + } else if (as_str == "SLICE") { + x = OperatorType::SLICE; + } else if (as_str == "RESIZE") { + x = OperatorType::RESIZE; + } else if (as_str == "PRELU") { + x = OperatorType::PRELU; + } else if (as_str == "GELU") { + x = OperatorType::GELU; + } else if (as_str == "MULTIHEAD_ATTENTION") { + x = OperatorType::MULTIHEAD_ATTENTION; + } else if (as_str == "FUSED") { + x = OperatorType::FUSED; + } else if (as_str == "RSQRT") { + x = OperatorType::RSQRT; + } else if (as_str == "POW") { + x = OperatorType::POW; + } else if (as_str == "MEAN") { + x = OperatorType::MEAN; + } else if (as_str == "LAYERNORM") { + x = OperatorType::LAYERNORM; + } else if (as_str == "GATHER") { + x = OperatorType::GATHER; + } else if (as_str == "BROADCAST") { + x = OperatorType::BROADCAST; + } else if (as_str == "REPARTITION") { + x = OperatorType::REPARTITION; + } else if (as_str == "COMBINE") { + x = OperatorType::COMBINE; + } else if (as_str == "REPLICATE") { + x = OperatorType::REPLICATE; + } else if (as_str == "REDUCTION") { + x = OperatorType::REDUCTION; + } else if (as_str == "BATCH") { + x = OperatorType::BATCH; + } else if (as_str == "PIPELINE") { + x = OperatorType::PIPELINE; + } else if (as_str == "FUSED_PARALLEL") { + x = OperatorType::FUSED_PARALLEL; + } else { + std::ostringstream oss; + oss << "Unknown OperatorType value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::element( + FlexFlow::OperatorType::NOOP, + FlexFlow::OperatorType::INPUT, + FlexFlow::OperatorType::WEIGHT, + FlexFlow::OperatorType::CONV2D, + FlexFlow::OperatorType::DROPOUT, + FlexFlow::OperatorType::LINEAR, + FlexFlow::OperatorType::BATCHMATMUL, + FlexFlow::OperatorType::POOL2D, + FlexFlow::OperatorType::SCALAR_MULTIPLY, + FlexFlow::OperatorType::SCALAR_ADD, + FlexFlow::OperatorType::SCALAR_FLOOR_DIV, + FlexFlow::OperatorType::SCALAR_TRUE_DIV, + FlexFlow::OperatorType::SCALAR_SUB, + FlexFlow::OperatorType::RELU, + FlexFlow::OperatorType::IDENTITY, + FlexFlow::OperatorType::SIGMOID, + FlexFlow::OperatorType::TANH, + FlexFlow::OperatorType::ELU, + FlexFlow::OperatorType::FLAT, + FlexFlow::OperatorType::SOFTMAX, + FlexFlow::OperatorType::BATCHNORM, + FlexFlow::OperatorType::CONCAT, + FlexFlow::OperatorType::SPLIT, + FlexFlow::OperatorType::EMBEDDING, + FlexFlow::OperatorType::CACHE, + FlexFlow::OperatorType::RESHAPE, + FlexFlow::OperatorType::REVERSE, + FlexFlow::OperatorType::TRANSPOSE, + FlexFlow::OperatorType::EW_ADD, + FlexFlow::OperatorType::EW_MUL, + FlexFlow::OperatorType::MATMUL, + FlexFlow::OperatorType::MUL, + FlexFlow::OperatorType::ENLARGE, + FlexFlow::OperatorType::SQUEEZE, + FlexFlow::OperatorType::UNSQUEEZE, + FlexFlow::OperatorType::EW_SUB, + FlexFlow::OperatorType::EW_DIV, + FlexFlow::OperatorType::EW_EQUAL, + FlexFlow::OperatorType::EW_GREATER, + FlexFlow::OperatorType::EW_LESS, + FlexFlow::OperatorType::EW_MAX, + FlexFlow::OperatorType::EW_MIN, + FlexFlow::OperatorType::REDUCE_ARGMAX, + FlexFlow::OperatorType::REDUCE_ARGMIN, + FlexFlow::OperatorType::REDUCE_MAX, + FlexFlow::OperatorType::REDUCE_MEAN, + FlexFlow::OperatorType::REDUCE_MIN, + FlexFlow::OperatorType::REDUCE_PROD, + FlexFlow::OperatorType::REDUCE_SUM, + FlexFlow::OperatorType::PAD, + FlexFlow::OperatorType::SHAPE, + FlexFlow::OperatorType::SIZE, + FlexFlow::OperatorType::TOPK, + FlexFlow::OperatorType::WHERE, + FlexFlow::OperatorType::CEIL, + FlexFlow::OperatorType::CAST, + FlexFlow::OperatorType::EXP, + FlexFlow::OperatorType::ROUND, + FlexFlow::OperatorType::LOG, + FlexFlow::OperatorType::LOGICAL_NOT, + FlexFlow::OperatorType::SQRT, + FlexFlow::OperatorType::SIN, + FlexFlow::OperatorType::COS, + FlexFlow::OperatorType::LEAKYRELU, + FlexFlow::OperatorType::SLICE, + FlexFlow::OperatorType::RESIZE, + FlexFlow::OperatorType::PRELU, + FlexFlow::OperatorType::GELU, + FlexFlow::OperatorType::MULTIHEAD_ATTENTION, + FlexFlow::OperatorType::FUSED, + FlexFlow::OperatorType::RSQRT, + FlexFlow::OperatorType::POW, + FlexFlow::OperatorType::MEAN, + FlexFlow::OperatorType::LAYERNORM, + FlexFlow::OperatorType::GATHER, + FlexFlow::OperatorType::BROADCAST, + FlexFlow::OperatorType::REPARTITION, + FlexFlow::OperatorType::COMBINE, + FlexFlow::OperatorType::REPLICATE, + FlexFlow::OperatorType::REDUCTION, + FlexFlow::OperatorType::BATCH, + FlexFlow::OperatorType::PIPELINE, + FlexFlow::OperatorType::FUSED_PARALLEL); +} +} // namespace rc diff --git a/lib/op-attrs/src/attention.cc b/lib/op-attrs/src/op-attrs/ops/attention.cc similarity index 77% rename from lib/op-attrs/src/attention.cc rename to lib/op-attrs/src/op-attrs/ops/attention.cc index 2c1500a477..14ab2b9b00 100644 --- a/lib/op-attrs/src/attention.cc +++ b/lib/op-attrs/src/op-attrs/ops/attention.cc @@ -1,4 +1,9 @@ #include "op-attrs/ops/attention.h" +#include "op-attrs/ops/attention/multihead_attention_inputs.h" +#include "op-attrs/ops/attention/multihead_attention_parallel_inputs.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_shape.h" +#include "utils/integer_conversions.h" namespace FlexFlow { @@ -27,78 +32,175 @@ int get_oProjSize(MultiHeadAttentionAttrs const &attrs) { } int get_qSize(TensorShape const &query_shape) { - return query_shape.at(ff_dim_t(0)); + return dim_at_idx(query_shape, ff_dim_t(0)); } int get_kSize(TensorShape const &key_shape) { - return key_shape.at(ff_dim_t(0)); + return dim_at_idx(key_shape, ff_dim_t(0)); } int get_vSize(TensorShape const &value_shape) { - return value_shape.at(ff_dim_t(0)); + return dim_at_idx(value_shape, ff_dim_t(0)); } -int get_qSize(MultiHeadAttentionInputs const &) { +int get_qSize(MultiHeadAttentionParallelInputs const &) { NOT_IMPLEMENTED(); } -int get_kSize(MultiHeadAttentionInputs const &) { +int get_qSize(MultiHeadAttentionInputs const &) { NOT_IMPLEMENTED(); } -int get_vSize(MultiHeadAttentionInputs const &) { +int get_kSize(MultiHeadAttentionParallelInputs const &) { NOT_IMPLEMENTED(); } -TensorShape +int get_kSize(MultiHeadAttentionInputs const &) { + NOT_IMPLEMENTED(); +} + +int get_vSize(MultiHeadAttentionParallelInputs const &) { + NOT_IMPLEMENTED(); +} + +int get_vSize(MultiHeadAttentionInputs const &) { + NOT_IMPLEMENTED(); +} + +tl::expected + get_output_shape(MultiHeadAttentionAttrs const &attrs, + TensorShape const &input_q, + TensorShape const &input_k, + TensorShape const &input_v) { + tl::expected parse_result = + parse_attention_input_shape(input_q, input_k, input_v); + if (!parse_result.has_value()) { + return tl::unexpected(parse_result.error()); + } + + MultiHeadAttentionInputs parsed = parse_result.value(); + + return TensorShape{ + TensorDims{FFOrdered{ + parsed.batch_size, + parsed.sequence_length, + size_t_from_int(attrs.embed_dim), + }}, + parsed.datatype, + }; +} + +tl::expected get_weights_shape(MultiHeadAttentionAttrs const &attrs, - MultiHeadAttentionInputs const &inputs) { - size_t qParas = get_qProjSize(attrs) * get_qSize(inputs); - size_t kParas = get_kProjSize(attrs) * get_kSize(inputs); - size_t vParas = get_vProjSize(attrs) * get_vSize(inputs); - TensorShape output_shape = get_output_shape(attrs, inputs); - size_t oParas = get_oProjSize(attrs) * get_oSize(output_shape); + TensorShape const &input_q, + TensorShape const &input_k, + TensorShape const &input_v) { + tl::expected parse_result = + parse_attention_input_shape(input_q, input_k, input_v); + if (!parse_result.has_value()) { + return tl::unexpected(parse_result.error()); + } - TensorDims dims = {qParas + kParas + vParas + oParas, - static_cast(attrs.embed_dim)}; + MultiHeadAttentionInputs parsed = parse_result.value(); - return {dims, DataType::FLOAT}; + // W^Q_i in "Attention Is All You Need" top of page 5 + size_t qProjectWeightSize = parsed.query_size * attrs.kdim; + + // W^K_i in "Attention Is All You Need" top of page 5 (all i's put together) + size_t kProjectWeightSize = parsed.key_size * attrs.kdim; + + // W^V_i in "Attention Is All You Need" top of page 5 (all i's put together) + size_t vProjectWeightSize = parsed.value_size * attrs.vdim; + + // W^O in "Attention Is All You Need" top of page 5, with num_heads factored + // out + size_t outWeightSize = parsed.value_size * attrs.embed_dim; + + return TensorShape{ + TensorDims{FFOrdered{ + (qProjectWeightSize + kProjectWeightSize + vProjectWeightSize + + outWeightSize), + size_t_from_int(attrs.num_heads), + }}, + parsed.datatype, + }; } -ParallelTensorShape get_output_shape(MultiHeadAttentionAttrs const &attrs, - ParallelTensorShape const &query_shape, - ParallelTensorShape const &key_shape, - ParallelTensorShape const &value_shape) { - /* ParallelDim replica_dim = query_shape.at(ff_dim_t(query_shape.num_dims() - - * 2)); */ - /* replica_dim.size = replica_dim.degree; */ +tl::expected + get_weights_shape(MultiHeadAttentionAttrs const &attrs, + ParallelTensorShape const &input_q, + ParallelTensorShape const &input_k, + ParallelTensorShape const &input_v) { + tl::expected parse_result = + parse_attention_parallel_input_shape(input_q, input_k, input_v); + if (!parse_result.has_value()) { + return tl::unexpected(parse_result.error()); + } + MultiHeadAttentionParallelInputs parsed = parse_result.value(); + + tl::expected result_unpar_get_shape = + get_weights_shape(attrs, + get_reduced_shape(input_q), + get_reduced_shape(input_k), + get_reduced_shape(input_v)); + if (!result_unpar_get_shape.has_value()) { + return tl::unexpected(result_unpar_get_shape.error()); + } + TensorShape unpar_shape = result_unpar_get_shape.value(); - /* ParallelDim */ + int joined_dim_degree = 1; + int head_dim_degree = parsed.discard_copy_degree.value; - ParallelTensorShape output_shape = query_shape; - output_shape.at(ff_dim_t(output_shape.num_dims() - 1)).size = attrs.embed_dim; - return output_shape; + return lift_to_parallel_with_degrees( + unpar_shape, + SumDegree{1}, + DiscardCopyDegree{parsed.batch_dim.degree}, + FFOrdered{joined_dim_degree, head_dim_degree}); } -TensorShape get_output_shape(MultiHeadAttentionAttrs const &attrs, - TensorShape const &query_shape, - TensorShape const &key_shape, - TensorShape const &value_shape) { - ParallelTensorShape parallel_shape = +tl::expected + get_output_shape(MultiHeadAttentionAttrs const &attrs, + ParallelTensorShape const &input_q, + ParallelTensorShape const &input_k, + ParallelTensorShape const &input_v) { + tl::expected parse_result = + parse_attention_parallel_input_shape(input_q, input_k, input_v); + if (!parse_result.has_value()) { + return tl::unexpected(parse_result.error()); + } + MultiHeadAttentionParallelInputs parsed = parse_result.value(); + + tl::expected result_unpar_get_shape = get_output_shape(attrs, - static_cast(query_shape), - static_cast(key_shape), - static_cast(value_shape)); - return get_tensor_shape_unsafe(parallel_shape); + get_reduced_shape(input_q), + get_reduced_shape(input_k), + get_reduced_shape(input_v)); + if (!result_unpar_get_shape.has_value()) { + return tl::unexpected(result_unpar_get_shape.error()); + } + TensorShape unpar_shape = result_unpar_get_shape.value(); + + int sum_degree = parsed.discard_copy_degree.value; + int discard_copy_degree = 1; + int batch_degree = parsed.batch_dim.degree; + int seq_len_degree = 1; + int out_dim_degree = 1; + + return lift_to_parallel_with_degrees( + unpar_shape, + SumDegree{sum_degree}, + DiscardCopyDegree{discard_copy_degree}, + FFOrdered{batch_degree, seq_len_degree, out_dim_degree}); } -TensorShape get_output_shape(MultiHeadAttentionAttrs const &, - MultiHeadAttentionInputs const &) { + +int get_oSize(ParallelTensorShape const &) { NOT_IMPLEMENTED(); } -int get_oSize(ParallelTensorShape const &) { +int get_oSize(TensorShape const &) { NOT_IMPLEMENTED(); } + } // namespace FlexFlow // Tensor FFModel::multihead_attention(const Tensor query, diff --git a/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.cc b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.cc new file mode 100644 index 0000000000..65feb642e1 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.cc @@ -0,0 +1,80 @@ +#include "op-attrs/ops/attention/multihead_attention_inputs.h" +#include "op-attrs/tensor_shape.h" + +namespace FlexFlow { + +template +static bool all_same(T const &x, T const &y, T const &z) { + return x == y && y == z; +} + +tl::expected + parse_attention_input_shape(TensorShape const &input_q, + TensorShape const &input_k, + TensorShape const &input_v) { + if (num_dims(input_q) != 3) { + return tl::unexpected( + fmt::format("Query input has incorrect number of dims: {} != {}", + num_dims(input_q), + 3)); + } + if (num_dims(input_k) != 3) { + return tl::unexpected( + fmt::format("Key input has incorrect number of dims: {} != {}", + num_dims(input_k), + 3)); + } + if (num_dims(input_v) != 3) { + return tl::unexpected( + fmt::format("Value input has incorrect number of dims: {} != {}", + num_dims(input_v), + 3)); + } + + size_t seq_len_q = dim_at_idx(input_q, ff_dim_t{-2}); + size_t seq_len_k = dim_at_idx(input_k, ff_dim_t{-2}); + size_t seq_len_v = dim_at_idx(input_v, ff_dim_t{-2}); + + if (!all_same(seq_len_q, seq_len_k, seq_len_v)) { + return tl::unexpected(fmt::format( + "Q, K, V disagree on the sequence length: {} (Q) vs {} (K) vs {} (V)", + seq_len_q, + seq_len_k, + seq_len_v)); + } + + size_t batch_size_q = dim_at_idx(input_q, ff_dim_t{-3}); + size_t batch_size_k = dim_at_idx(input_k, ff_dim_t{-3}); + size_t batch_size_v = dim_at_idx(input_v, ff_dim_t{-3}); + + if (!all_same(batch_size_q, batch_size_k, batch_size_v)) { + return tl::unexpected(fmt::format( + "Q, K, V disagree on the batch size: {} (Q) vs {} (K) vs {} (V)", + batch_size_q, + batch_size_k, + batch_size_v)); + } + + if (!all_same(input_q.data_type, input_k.data_type, input_v.data_type)) { + return tl::unexpected(fmt::format( + "Q, K, V disagree on the datatype: {} (Q) vs {} (K) vs {} (V)", + input_q.data_type, + input_k.data_type, + input_v.data_type)); + } + + size_t q_size = dim_at_idx(input_q, ff_dim_t{-1}); + size_t k_size = dim_at_idx(input_k, ff_dim_t{-1}); + size_t v_size = dim_at_idx(input_v, ff_dim_t{-1}); + + return MultiHeadAttentionInputs{ + batch_size_q, + seq_len_q, + q_size, + k_size, + v_size, + input_q.data_type, + }; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.dtg.cc new file mode 100644 index 0000000000..26d3138eb4 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.dtg.cc @@ -0,0 +1,185 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.struct.toml +/* proj-data +{ + "generated_from": "c57a9d1d2822a726ee9d9369d22e8e72" +} +*/ + +#include "op-attrs/ops/attention/multihead_attention_inputs.dtg.h" + +#include "op-attrs/datatype.dtg.h" +#include +#include + +namespace FlexFlow { +MultiHeadAttentionInputs::MultiHeadAttentionInputs( + size_t const &batch_size, + size_t const &sequence_length, + size_t const &query_size, + size_t const &key_size, + size_t const &value_size, + ::FlexFlow::DataType const &datatype) + : batch_size(batch_size), sequence_length(sequence_length), + query_size(query_size), key_size(key_size), value_size(value_size), + datatype(datatype) {} +bool MultiHeadAttentionInputs::operator==( + MultiHeadAttentionInputs const &other) const { + return std::tie(this->batch_size, + this->sequence_length, + this->query_size, + this->key_size, + this->value_size, + this->datatype) == std::tie(other.batch_size, + other.sequence_length, + other.query_size, + other.key_size, + other.value_size, + other.datatype); +} +bool MultiHeadAttentionInputs::operator!=( + MultiHeadAttentionInputs const &other) const { + return std::tie(this->batch_size, + this->sequence_length, + this->query_size, + this->key_size, + this->value_size, + this->datatype) != std::tie(other.batch_size, + other.sequence_length, + other.query_size, + other.key_size, + other.value_size, + other.datatype); +} +bool MultiHeadAttentionInputs::operator<( + MultiHeadAttentionInputs const &other) const { + return std::tie(this->batch_size, + this->sequence_length, + this->query_size, + this->key_size, + this->value_size, + this->datatype) < std::tie(other.batch_size, + other.sequence_length, + other.query_size, + other.key_size, + other.value_size, + other.datatype); +} +bool MultiHeadAttentionInputs::operator>( + MultiHeadAttentionInputs const &other) const { + return std::tie(this->batch_size, + this->sequence_length, + this->query_size, + this->key_size, + this->value_size, + this->datatype) > std::tie(other.batch_size, + other.sequence_length, + other.query_size, + other.key_size, + other.value_size, + other.datatype); +} +bool MultiHeadAttentionInputs::operator<=( + MultiHeadAttentionInputs const &other) const { + return std::tie(this->batch_size, + this->sequence_length, + this->query_size, + this->key_size, + this->value_size, + this->datatype) <= std::tie(other.batch_size, + other.sequence_length, + other.query_size, + other.key_size, + other.value_size, + other.datatype); +} +bool MultiHeadAttentionInputs::operator>=( + MultiHeadAttentionInputs const &other) const { + return std::tie(this->batch_size, + this->sequence_length, + this->query_size, + this->key_size, + this->value_size, + this->datatype) >= std::tie(other.batch_size, + other.sequence_length, + other.query_size, + other.key_size, + other.value_size, + other.datatype); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::MultiHeadAttentionInputs const &x) const { + size_t result = 0; + result ^= std::hash{}(x.batch_size) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.sequence_length) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash{}(x.query_size) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.key_size) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.value_size) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash<::FlexFlow::DataType>{}(x.datatype) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::MultiHeadAttentionInputs + adl_serializer::from_json( + json const &j) { + return {j.at("batch_size").template get(), + j.at("sequence_length").template get(), + j.at("query_size").template get(), + j.at("key_size").template get(), + j.at("value_size").template get(), + j.at("datatype").template get<::FlexFlow::DataType>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::MultiHeadAttentionInputs const &v) { + j["__type"] = "MultiHeadAttentionInputs"; + j["batch_size"] = v.batch_size; + j["sequence_length"] = v.sequence_length; + j["query_size"] = v.query_size; + j["key_size"] = v.key_size; + j["value_size"] = v.value_size; + j["datatype"] = v.datatype; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary<::FlexFlow::DataType>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(MultiHeadAttentionInputs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, MultiHeadAttentionInputs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.cc b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.cc new file mode 100644 index 0000000000..2cd5b7ec00 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.cc @@ -0,0 +1,132 @@ +#include "op-attrs/ops/attention/multihead_attention_parallel_inputs.h" +#include "op-attrs/ops/attention/multihead_attention_inputs.h" +#include "op-attrs/parallel_tensor_shape.h" + +namespace FlexFlow { + +template +static bool all_same(T const &x, T const &y, T const &z) { + return x == y && y == z; +} + +tl::expected + parse_attention_parallel_input_shape(ParallelTensorShape const &input_q, + ParallelTensorShape const &input_k, + ParallelTensorShape const &input_v) { + tl::expected unpar_parse_result = + parse_attention_input_shape(get_reduced_shape(input_q), + get_reduced_shape(input_k), + get_reduced_shape(input_v)); + if (!unpar_parse_result.has_value()) { + return tl::unexpected( + fmt::format("MHA unparallel input parsing failed with message: \"{}\"", + unpar_parse_result.error())); + } + + if (num_shard_dims(input_q) != 3) { + return tl::unexpected( + fmt::format("Query input has incorrect number of dims: {} != {}", + num_shard_dims(input_q), + 3)); + } + if (num_shard_dims(input_k) != 3) { + return tl::unexpected( + fmt::format("Key input has incorrect number of dims: {} != {}", + num_shard_dims(input_k), + 3)); + } + if (num_shard_dims(input_v) != 3) { + return tl::unexpected( + fmt::format("Value input has incorrect number of dims: {} != {}", + num_shard_dims(input_v), + 3)); + } + + ShardParallelDim seq_len_q = shard_dim_at_idx(input_q, ff_dim_t{-2}); + if (seq_len_q.degree != 1) { + return tl::unexpected( + fmt::format("Query sequence length parallel degree expected to be 1, " + "but received degree {}", + seq_len_q.degree)); + } + + ShardParallelDim seq_len_k = shard_dim_at_idx(input_k, ff_dim_t{-2}); + if (seq_len_k.degree != 1) { + return tl::unexpected( + fmt::format("Key sequence length parallel degree expected to be 1, but " + "received degree {}", + seq_len_k.degree)); + } + + ShardParallelDim seq_len_v = shard_dim_at_idx(input_v, ff_dim_t{-2}); + if (seq_len_v.degree != 1) { + return tl::unexpected( + fmt::format("Value sequence length parallel degree expected to be 1, " + "but received degree {}", + seq_len_v.degree)); + } + + ShardParallelDim batch_size_q = shard_dim_at_idx(input_q, ff_dim_t{-3}); + ShardParallelDim batch_size_k = shard_dim_at_idx(input_k, ff_dim_t{-3}); + ShardParallelDim batch_size_v = shard_dim_at_idx(input_v, ff_dim_t{-3}); + + if (!all_same( + batch_size_q.degree, batch_size_k.degree, batch_size_v.degree)) { + return tl::unexpected( + fmt::format("Q, K, V disagree on the parallel degree of the batch " + "dimension: {} (Q) vs {} (K) vs {} (V)", + batch_size_q.degree, + batch_size_k.degree, + batch_size_v.degree)); + } + + ShardParallelDim query_dim = shard_dim_at_idx(input_q, ff_dim_t{-1}); + if (query_dim.degree > 1) { + return tl::unexpected( + fmt::format("Expected query tensor to have query dim parallel degree " + "1, but received degree {}", + query_dim.degree)); + } + + ShardParallelDim key_dim = shard_dim_at_idx(input_k, ff_dim_t{-1}); + if (key_dim.degree > 1) { + return tl::unexpected( + fmt::format("Expected key tensor to have key dim parallel degree 1, " + "but received degree {}", + key_dim.degree)); + } + + ShardParallelDim value_dim = shard_dim_at_idx(input_v, ff_dim_t{-1}); + if (value_dim.degree > 1) { + return tl::unexpected( + fmt::format("Expected value tensor to have value dim parallel degree " + "1, but received degree {}", + value_dim.degree)); + } + + int discard_copy_q = get_discard_copy_degree(input_q); + int discard_copy_k = get_discard_copy_degree(input_k); + int discard_copy_v = get_discard_copy_degree(input_v); + + if (!all_same(discard_copy_q, discard_copy_k, discard_copy_v)) { + return tl::unexpected(fmt::format("Q, K, V disagree on the discard-copy " + "degree: {} (Q) vs {} (K) vs {} (V)", + discard_copy_q, + discard_copy_k, + discard_copy_v)); + } + + return MultiHeadAttentionParallelInputs{ + batch_size_q, + seq_len_q, + query_dim, + key_dim, + value_dim, + discard_copy_q, + input_q.data_type, + }; + + // return; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.cc new file mode 100644 index 0000000000..94784d83cc --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.cc @@ -0,0 +1,209 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.struct.toml +/* proj-data +{ + "generated_from": "7c434445707968123a361c038a337da2" +} +*/ + +#include "op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.h" + +#include "op-attrs/datatype.dtg.h" +#include "op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h" +#include "op-attrs/parallel_tensor_shape/sum_degree.dtg.h" +#include "op-attrs/shard_parallel_dim.dtg.h" +#include +#include + +namespace FlexFlow { +MultiHeadAttentionParallelInputs::MultiHeadAttentionParallelInputs( + ::FlexFlow::ShardParallelDim const &batch_dim, + ::FlexFlow::ShardParallelDim const &sequence_dim, + ::FlexFlow::ShardParallelDim const &query_dim, + ::FlexFlow::ShardParallelDim const &key_dim, + ::FlexFlow::ShardParallelDim const &value_dim, + ::FlexFlow::DiscardCopyDegree const &discard_copy_degree, + ::FlexFlow::DataType const &datatype) + : batch_dim(batch_dim), sequence_dim(sequence_dim), query_dim(query_dim), + key_dim(key_dim), value_dim(value_dim), + discard_copy_degree(discard_copy_degree), datatype(datatype) {} +bool MultiHeadAttentionParallelInputs::operator==( + MultiHeadAttentionParallelInputs const &other) const { + return std::tie(this->batch_dim, + this->sequence_dim, + this->query_dim, + this->key_dim, + this->value_dim, + this->discard_copy_degree, + this->datatype) == std::tie(other.batch_dim, + other.sequence_dim, + other.query_dim, + other.key_dim, + other.value_dim, + other.discard_copy_degree, + other.datatype); +} +bool MultiHeadAttentionParallelInputs::operator!=( + MultiHeadAttentionParallelInputs const &other) const { + return std::tie(this->batch_dim, + this->sequence_dim, + this->query_dim, + this->key_dim, + this->value_dim, + this->discard_copy_degree, + this->datatype) != std::tie(other.batch_dim, + other.sequence_dim, + other.query_dim, + other.key_dim, + other.value_dim, + other.discard_copy_degree, + other.datatype); +} +bool MultiHeadAttentionParallelInputs::operator<( + MultiHeadAttentionParallelInputs const &other) const { + return std::tie(this->batch_dim, + this->sequence_dim, + this->query_dim, + this->key_dim, + this->value_dim, + this->discard_copy_degree, + this->datatype) < std::tie(other.batch_dim, + other.sequence_dim, + other.query_dim, + other.key_dim, + other.value_dim, + other.discard_copy_degree, + other.datatype); +} +bool MultiHeadAttentionParallelInputs::operator>( + MultiHeadAttentionParallelInputs const &other) const { + return std::tie(this->batch_dim, + this->sequence_dim, + this->query_dim, + this->key_dim, + this->value_dim, + this->discard_copy_degree, + this->datatype) > std::tie(other.batch_dim, + other.sequence_dim, + other.query_dim, + other.key_dim, + other.value_dim, + other.discard_copy_degree, + other.datatype); +} +bool MultiHeadAttentionParallelInputs::operator<=( + MultiHeadAttentionParallelInputs const &other) const { + return std::tie(this->batch_dim, + this->sequence_dim, + this->query_dim, + this->key_dim, + this->value_dim, + this->discard_copy_degree, + this->datatype) <= std::tie(other.batch_dim, + other.sequence_dim, + other.query_dim, + other.key_dim, + other.value_dim, + other.discard_copy_degree, + other.datatype); +} +bool MultiHeadAttentionParallelInputs::operator>=( + MultiHeadAttentionParallelInputs const &other) const { + return std::tie(this->batch_dim, + this->sequence_dim, + this->query_dim, + this->key_dim, + this->value_dim, + this->discard_copy_degree, + this->datatype) >= std::tie(other.batch_dim, + other.sequence_dim, + other.query_dim, + other.key_dim, + other.value_dim, + other.discard_copy_degree, + other.datatype); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::MultiHeadAttentionParallelInputs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ShardParallelDim>{}(x.batch_dim) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::ShardParallelDim>{}(x.sequence_dim) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::ShardParallelDim>{}(x.query_dim) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::ShardParallelDim>{}(x.key_dim) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::ShardParallelDim>{}(x.value_dim) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::DiscardCopyDegree>{}(x.discard_copy_degree) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::DataType>{}(x.datatype) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::MultiHeadAttentionParallelInputs + adl_serializer::from_json( + json const &j) { + return { + j.at("batch_dim").template get<::FlexFlow::ShardParallelDim>(), + j.at("sequence_dim").template get<::FlexFlow::ShardParallelDim>(), + j.at("query_dim").template get<::FlexFlow::ShardParallelDim>(), + j.at("key_dim").template get<::FlexFlow::ShardParallelDim>(), + j.at("value_dim").template get<::FlexFlow::ShardParallelDim>(), + j.at("discard_copy_degree").template get<::FlexFlow::DiscardCopyDegree>(), + j.at("datatype").template get<::FlexFlow::DataType>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::MultiHeadAttentionParallelInputs const &v) { + j["__type"] = "MultiHeadAttentionParallelInputs"; + j["batch_dim"] = v.batch_dim; + j["sequence_dim"] = v.sequence_dim; + j["query_dim"] = v.query_dim; + j["key_dim"] = v.key_dim; + j["value_dim"] = v.value_dim; + j["discard_copy_degree"] = v.discard_copy_degree; + j["datatype"] = v.datatype; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::ShardParallelDim>(), + gen::arbitrary<::FlexFlow::ShardParallelDim>(), + gen::arbitrary<::FlexFlow::ShardParallelDim>(), + gen::arbitrary<::FlexFlow::ShardParallelDim>(), + gen::arbitrary<::FlexFlow::ShardParallelDim>(), + gen::arbitrary<::FlexFlow::DiscardCopyDegree>(), + gen::arbitrary<::FlexFlow::DataType>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(MultiHeadAttentionParallelInputs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + MultiHeadAttentionParallelInputs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/attention_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/attention_attrs.dtg.cc new file mode 100644 index 0000000000..ad0c094969 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/attention_attrs.dtg.cc @@ -0,0 +1,220 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/attention_attrs.struct.toml +/* proj-data +{ + "generated_from": "360324465947562229dc6632a9e9a2f3" +} +*/ + +#include "op-attrs/ops/attention_attrs.dtg.h" + +#include + +namespace FlexFlow { +MultiHeadAttentionAttrs::MultiHeadAttentionAttrs(int const &embed_dim, + int const &num_heads, + int const &kdim, + int const &vdim, + float const &dropout, + bool const &bias, + bool const &add_bias_kv, + bool const &add_zero_attn) + : embed_dim(embed_dim), num_heads(num_heads), kdim(kdim), vdim(vdim), + dropout(dropout), bias(bias), add_bias_kv(add_bias_kv), + add_zero_attn(add_zero_attn) {} +bool MultiHeadAttentionAttrs::operator==( + MultiHeadAttentionAttrs const &other) const { + return std::tie(this->embed_dim, + this->num_heads, + this->kdim, + this->vdim, + this->dropout, + this->bias, + this->add_bias_kv, + this->add_zero_attn) == std::tie(other.embed_dim, + other.num_heads, + other.kdim, + other.vdim, + other.dropout, + other.bias, + other.add_bias_kv, + other.add_zero_attn); +} +bool MultiHeadAttentionAttrs::operator!=( + MultiHeadAttentionAttrs const &other) const { + return std::tie(this->embed_dim, + this->num_heads, + this->kdim, + this->vdim, + this->dropout, + this->bias, + this->add_bias_kv, + this->add_zero_attn) != std::tie(other.embed_dim, + other.num_heads, + other.kdim, + other.vdim, + other.dropout, + other.bias, + other.add_bias_kv, + other.add_zero_attn); +} +bool MultiHeadAttentionAttrs::operator<( + MultiHeadAttentionAttrs const &other) const { + return std::tie(this->embed_dim, + this->num_heads, + this->kdim, + this->vdim, + this->dropout, + this->bias, + this->add_bias_kv, + this->add_zero_attn) < std::tie(other.embed_dim, + other.num_heads, + other.kdim, + other.vdim, + other.dropout, + other.bias, + other.add_bias_kv, + other.add_zero_attn); +} +bool MultiHeadAttentionAttrs::operator>( + MultiHeadAttentionAttrs const &other) const { + return std::tie(this->embed_dim, + this->num_heads, + this->kdim, + this->vdim, + this->dropout, + this->bias, + this->add_bias_kv, + this->add_zero_attn) > std::tie(other.embed_dim, + other.num_heads, + other.kdim, + other.vdim, + other.dropout, + other.bias, + other.add_bias_kv, + other.add_zero_attn); +} +bool MultiHeadAttentionAttrs::operator<=( + MultiHeadAttentionAttrs const &other) const { + return std::tie(this->embed_dim, + this->num_heads, + this->kdim, + this->vdim, + this->dropout, + this->bias, + this->add_bias_kv, + this->add_zero_attn) <= std::tie(other.embed_dim, + other.num_heads, + other.kdim, + other.vdim, + other.dropout, + other.bias, + other.add_bias_kv, + other.add_zero_attn); +} +bool MultiHeadAttentionAttrs::operator>=( + MultiHeadAttentionAttrs const &other) const { + return std::tie(this->embed_dim, + this->num_heads, + this->kdim, + this->vdim, + this->dropout, + this->bias, + this->add_bias_kv, + this->add_zero_attn) >= std::tie(other.embed_dim, + other.num_heads, + other.kdim, + other.vdim, + other.dropout, + other.bias, + other.add_bias_kv, + other.add_zero_attn); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::MultiHeadAttentionAttrs const &x) const { + size_t result = 0; + result ^= std::hash{}(x.embed_dim) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.num_heads) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= + std::hash{}(x.kdim) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.vdim) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash{}(x.dropout) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= + std::hash{}(x.bias) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash{}(x.add_bias_kv) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.add_zero_attn) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::MultiHeadAttentionAttrs + adl_serializer::from_json( + json const &j) { + return {j.at("embed_dim").template get(), + j.at("num_heads").template get(), + j.at("kdim").template get(), + j.at("vdim").template get(), + j.at("dropout").template get(), + j.at("bias").template get(), + j.at("add_bias_kv").template get(), + j.at("add_zero_attn").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::MultiHeadAttentionAttrs const &v) { + j["__type"] = "MultiHeadAttentionAttrs"; + j["embed_dim"] = v.embed_dim; + j["num_heads"] = v.num_heads; + j["kdim"] = v.kdim; + j["vdim"] = v.vdim; + j["dropout"] = v.dropout; + j["bias"] = v.bias; + j["add_bias_kv"] = v.add_bias_kv; + j["add_zero_attn"] = v.add_zero_attn; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(MultiHeadAttentionAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, MultiHeadAttentionAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc b/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc new file mode 100644 index 0000000000..cbda4ea533 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc @@ -0,0 +1,176 @@ +#include "op-attrs/ops/batch_matmul.h" +#include "op-attrs/parallel_tensor_shape.h" + +namespace FlexFlow { + +// bool BatchMatmulAttrs::is_valid( +// ParallelTensorShape const &lhs, +// ParallelTensorShape const &rhs) const { +// if (!lhs.is_valid() || !rhs.is_valid()) { +// return false; +// } +// if (lhs.num_dims() != rhs.num_dims()) { +// return false; +// } +// for (int i = lhs.num_dims() - 1; i >= 2; i--) { +// if (lhs.at(i) != rhs.at(i)) { +// return false; +// } +// } +// if (lhs.at(0) != rhs.at(1)) { +// return false; +// } +// +// return true; +// } + +bool is_valid(BatchMatmulAttrs const &, + ParallelTensorShape const &, + ParallelTensorShape const &) { + NOT_IMPLEMENTED(); +} + +tl::expected + get_output_shape(BatchMatmulAttrs const &attrs, + TensorShape const &input_lhs, + TensorShape const &input_rhs) { + // If input_lhs is a (b×n×m) tensor, + // input_rhs is a (b×m×p) tensor, + // out will be a (b×n×p) tensor. + // https://pytorch.org/docs/stable/generated/torch.bmm.html + + if (num_dims(input_lhs) != 3) { + return tl::unexpected( + fmt::format("LHS input has incorrect number of shard dims: {} != {}", + num_dims(input_lhs), + 3)); + } + if (num_dims(input_rhs) != 3) { + return tl::unexpected( + fmt::format("RHS input has incorrect number of shard dims: {} != {}", + num_dims(input_rhs), + 3)); + } + if (input_lhs.data_type != input_rhs.data_type) { + return tl::unexpected(fmt::format("Input datatypes do not match: {} != {}", + input_lhs.data_type, + input_rhs.data_type)); + } + + size_t lhs_b = dim_at_idx(input_lhs, ff_dim_t{0}); + size_t n = dim_at_idx(input_lhs, ff_dim_t{1}); + size_t lhs_m = dim_at_idx(input_lhs, ff_dim_t{2}); + + size_t rhs_b = dim_at_idx(input_rhs, ff_dim_t{0}); + size_t rhs_m = dim_at_idx(input_rhs, ff_dim_t{1}); + size_t p = dim_at_idx(input_rhs, ff_dim_t{2}); + + if (lhs_b != rhs_b) { + return tl::unexpected( + fmt::format("LHS b dim ({}) != RHS b dim ({})", lhs_b, rhs_b)); + } + if (lhs_m != rhs_m) { + return tl::unexpected( + fmt::format("RHS m dim ({}) != RHS m dim ({})", lhs_m, rhs_m)); + } + + return TensorShape{ + TensorDims{ + FFOrdered{ + lhs_b, + n, + p, + }, + }, + input_lhs.data_type, + }; +} + +tl::expected + get_output_shape(BatchMatmulAttrs const &attrs, + ParallelTensorShape const &input_lhs, + ParallelTensorShape const &input_rhs) { + if (num_shard_dims(input_lhs) != 3) { + return tl::unexpected( + fmt::format("LHS input has incorrect number of shard dims: {} != {}", + num_shard_dims(input_lhs), + 3)); + } + if (num_shard_dims(input_rhs) != 3) { + return tl::unexpected( + fmt::format("RHS input has incorrect number of shard dims: {} != {}", + num_shard_dims(input_rhs), + 3)); + } + if (input_lhs.data_type != input_rhs.data_type) { + return tl::unexpected(fmt::format("Input datatypes do not match: {} != {}", + input_lhs.data_type, + input_rhs.data_type)); + } + + assert(get_total_parallel_degree(input_lhs) == + get_total_parallel_degree(input_rhs)); + + ShardParallelDim lhs_b = shard_dim_at_idx(input_lhs, ff_dim_t{0}); + ShardParallelDim n = shard_dim_at_idx(input_lhs, ff_dim_t{1}); + ShardParallelDim lhs_m = shard_dim_at_idx(input_lhs, ff_dim_t{2}); + + ShardParallelDim rhs_b = shard_dim_at_idx(input_rhs, ff_dim_t{0}); + ShardParallelDim rhs_m = shard_dim_at_idx(input_rhs, ff_dim_t{1}); + ShardParallelDim p = shard_dim_at_idx(input_rhs, ff_dim_t{2}); + + if (lhs_b != rhs_b) { + return tl::unexpected( + fmt::format("LHS b dim ({}) != RHS b dim ({})", lhs_b, rhs_b)); + } + + if (lhs_m != rhs_m) { + return tl::unexpected( + fmt::format("LHS m dim ({}) != RHS m dim ({})", lhs_m, rhs_m)); + } + + if (get_discard_copy_degree(input_lhs) != + get_sum_degree(input_rhs) * p.degree) { + return tl::unexpected(fmt::format("Unexpected number of replicas in LHS: " + "lhs.= ({}) != rhs.+ ({}) * rhs.p ({})", + get_discard_copy_degree(input_lhs), + get_sum_degree(input_rhs), + p.degree)); + } + + if (get_discard_copy_degree(input_rhs) != + get_sum_degree(input_lhs) * n.degree) { + return tl::unexpected(fmt::format("Unexpected number of replicas in RHS: " + "rhs.= ({}) != lhs.+ ({}) * lhs.n ({})", + get_discard_copy_degree(input_rhs), + get_sum_degree(input_lhs), + n.degree)); + } + + ShardParallelDim output_b = lhs_b; + ShardParallelDim output_n = n; + ShardParallelDim output_p = p; + + int output_discard_copy_degree = 1; + int output_sum_degree = get_total_parallel_degree(input_lhs) / + (output_b.degree * output_n.degree * output_p.degree); + + ParallelTensorShape result = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + output_b, + output_n, + output_p, + }, + ReplicaParallelDimSet{ + output_sum_degree, + output_discard_copy_degree, + }, + }, + input_lhs.data_type, + }; + + return result; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/batch_matmul.dtg.cc b/lib/op-attrs/src/op-attrs/ops/batch_matmul.dtg.cc new file mode 100644 index 0000000000..f178d40696 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/batch_matmul.dtg.cc @@ -0,0 +1,90 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/batch_matmul.struct.toml +/* proj-data +{ + "generated_from": "c3bbf4c76982ef27107b74e1e6e5d360" +} +*/ + +#include "op-attrs/ops/batch_matmul.dtg.h" + +#include + +namespace FlexFlow { +BatchMatmulAttrs::BatchMatmulAttrs(int const &a_seq_length_dim, + int const &b_seq_length_dim) + : a_seq_length_dim(a_seq_length_dim), b_seq_length_dim(b_seq_length_dim) {} +bool BatchMatmulAttrs::operator==(BatchMatmulAttrs const &other) const { + return std::tie(this->a_seq_length_dim, this->b_seq_length_dim) == + std::tie(other.a_seq_length_dim, other.b_seq_length_dim); +} +bool BatchMatmulAttrs::operator!=(BatchMatmulAttrs const &other) const { + return std::tie(this->a_seq_length_dim, this->b_seq_length_dim) != + std::tie(other.a_seq_length_dim, other.b_seq_length_dim); +} +bool BatchMatmulAttrs::operator<(BatchMatmulAttrs const &other) const { + return std::tie(this->a_seq_length_dim, this->b_seq_length_dim) < + std::tie(other.a_seq_length_dim, other.b_seq_length_dim); +} +bool BatchMatmulAttrs::operator>(BatchMatmulAttrs const &other) const { + return std::tie(this->a_seq_length_dim, this->b_seq_length_dim) > + std::tie(other.a_seq_length_dim, other.b_seq_length_dim); +} +bool BatchMatmulAttrs::operator<=(BatchMatmulAttrs const &other) const { + return std::tie(this->a_seq_length_dim, this->b_seq_length_dim) <= + std::tie(other.a_seq_length_dim, other.b_seq_length_dim); +} +bool BatchMatmulAttrs::operator>=(BatchMatmulAttrs const &other) const { + return std::tie(this->a_seq_length_dim, this->b_seq_length_dim) >= + std::tie(other.a_seq_length_dim, other.b_seq_length_dim); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::BatchMatmulAttrs const &x) const { + size_t result = 0; + result ^= std::hash{}(x.a_seq_length_dim) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.b_seq_length_dim) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::BatchMatmulAttrs + adl_serializer::from_json(json const &j) { + return {j.at("a_seq_length_dim").template get(), + j.at("b_seq_length_dim").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::BatchMatmulAttrs const &v) { + j["__type"] = "BatchMatmulAttrs"; + j["a_seq_length_dim"] = v.a_seq_length_dim; + j["b_seq_length_dim"] = v.b_seq_length_dim; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary(), + gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(BatchMatmulAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, BatchMatmulAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/batch_norm.cc b/lib/op-attrs/src/op-attrs/ops/batch_norm.cc new file mode 100644 index 0000000000..7be51efa22 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/batch_norm.cc @@ -0,0 +1,10 @@ +#include "op-attrs/ops/batch_norm.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(BatchNormAttrs const &, + ParallelTensorShape const &) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/batch_norm_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/batch_norm_attrs.dtg.cc new file mode 100644 index 0000000000..cb8dcadae1 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/batch_norm_attrs.dtg.cc @@ -0,0 +1,75 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml +/* proj-data +{ + "generated_from": "f8e0219d8a3e008a73c38cf84d25f66e" +} +*/ + +#include "op-attrs/ops/batch_norm_attrs.dtg.h" + +#include + +namespace FlexFlow { +BatchNormAttrs::BatchNormAttrs(bool const &relu) : relu(relu) {} +bool BatchNormAttrs::operator==(BatchNormAttrs const &other) const { + return std::tie(this->relu) == std::tie(other.relu); +} +bool BatchNormAttrs::operator!=(BatchNormAttrs const &other) const { + return std::tie(this->relu) != std::tie(other.relu); +} +bool BatchNormAttrs::operator<(BatchNormAttrs const &other) const { + return std::tie(this->relu) < std::tie(other.relu); +} +bool BatchNormAttrs::operator>(BatchNormAttrs const &other) const { + return std::tie(this->relu) > std::tie(other.relu); +} +bool BatchNormAttrs::operator<=(BatchNormAttrs const &other) const { + return std::tie(this->relu) <= std::tie(other.relu); +} +bool BatchNormAttrs::operator>=(BatchNormAttrs const &other) const { + return std::tie(this->relu) >= std::tie(other.relu); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::BatchNormAttrs const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.relu) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::BatchNormAttrs + adl_serializer::from_json(json const &j) { + return {j.at("relu").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::BatchNormAttrs const &v) { + j["__type"] = "BatchNormAttrs"; + j["relu"] = v.relu; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(BatchNormAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, BatchNormAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/broadcast.dtg.cc b/lib/op-attrs/src/op-attrs/ops/broadcast.dtg.cc new file mode 100644 index 0000000000..ec08bd6a1d --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/broadcast.dtg.cc @@ -0,0 +1,81 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/broadcast.struct.toml +/* proj-data +{ + "generated_from": "12715c970e8416eacbd0750f338478e5" +} +*/ + +#include "op-attrs/ops/broadcast.dtg.h" + +#include "utils/stack_vector.h" +#include + +namespace FlexFlow { +BroadcastAttrs::BroadcastAttrs( + ::FlexFlow::stack_vector const &target_dims) + : target_dims(target_dims) {} +bool BroadcastAttrs::operator==(BroadcastAttrs const &other) const { + return std::tie(this->target_dims) == std::tie(other.target_dims); +} +bool BroadcastAttrs::operator!=(BroadcastAttrs const &other) const { + return std::tie(this->target_dims) != std::tie(other.target_dims); +} +bool BroadcastAttrs::operator<(BroadcastAttrs const &other) const { + return std::tie(this->target_dims) < std::tie(other.target_dims); +} +bool BroadcastAttrs::operator>(BroadcastAttrs const &other) const { + return std::tie(this->target_dims) > std::tie(other.target_dims); +} +bool BroadcastAttrs::operator<=(BroadcastAttrs const &other) const { + return std::tie(this->target_dims) <= std::tie(other.target_dims); +} +bool BroadcastAttrs::operator>=(BroadcastAttrs const &other) const { + return std::tie(this->target_dims) >= std::tie(other.target_dims); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::BroadcastAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::stack_vector>{}( + x.target_dims) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::BroadcastAttrs + adl_serializer::from_json(json const &j) { + return {j.at("target_dims") + .template get<::FlexFlow::stack_vector>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::BroadcastAttrs const &v) { + j["__type"] = "BroadcastAttrs"; + j["target_dims"] = v.target_dims; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::stack_vector>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(BroadcastAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, BroadcastAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/cast.cc b/lib/op-attrs/src/op-attrs/ops/cast.cc similarity index 100% rename from lib/op-attrs/src/cast.cc rename to lib/op-attrs/src/op-attrs/ops/cast.cc diff --git a/lib/op-attrs/src/op-attrs/ops/cast_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/cast_attrs.dtg.cc new file mode 100644 index 0000000000..28367f3449 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/cast_attrs.dtg.cc @@ -0,0 +1,76 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/cast_attrs.struct.toml +/* proj-data +{ + "generated_from": "c171c87db89b9ec9ea7d52a50c153054" +} +*/ + +#include "op-attrs/ops/cast_attrs.dtg.h" + +#include "op-attrs/datatype.h" +#include + +namespace FlexFlow { +CastAttrs::CastAttrs(DataType const &dtype) : dtype(dtype) {} +bool CastAttrs::operator==(CastAttrs const &other) const { + return std::tie(this->dtype) == std::tie(other.dtype); +} +bool CastAttrs::operator!=(CastAttrs const &other) const { + return std::tie(this->dtype) != std::tie(other.dtype); +} +bool CastAttrs::operator<(CastAttrs const &other) const { + return std::tie(this->dtype) < std::tie(other.dtype); +} +bool CastAttrs::operator>(CastAttrs const &other) const { + return std::tie(this->dtype) > std::tie(other.dtype); +} +bool CastAttrs::operator<=(CastAttrs const &other) const { + return std::tie(this->dtype) <= std::tie(other.dtype); +} +bool CastAttrs::operator>=(CastAttrs const &other) const { + return std::tie(this->dtype) >= std::tie(other.dtype); +} +} // namespace FlexFlow + +namespace std { +size_t + hash::operator()(FlexFlow::CastAttrs const &x) const { + size_t result = 0; + result ^= std::hash{}(x.dtype) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::CastAttrs + adl_serializer::from_json(json const &j) { + return {j.at("dtype").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::CastAttrs const &v) { + j["__type"] = "CastAttrs"; + j["dtype"] = v.dtype; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(CastAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, CastAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/combine.cc b/lib/op-attrs/src/op-attrs/ops/combine.cc new file mode 100644 index 0000000000..e41b78c5af --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/combine.cc @@ -0,0 +1,37 @@ +#include "op-attrs/ops/combine.h" +#include "op-attrs/parallel_tensor_shape.h" + +namespace FlexFlow { + +tl::expected + get_output_shape(CombineAttrs const &attrs, + ParallelTensorShape const &input) { + ShardParallelDim input_dim = ({ + std::optional result = + try_get_shard_dim_at_idx(input, attrs.combine_dim); + if (!result.has_value()) { + return tl::unexpected(fmt::format( + "Failed to get shard dim at index {} in parallel tensor shape {}", + attrs.combine_dim, + input)); + } + + result.value(); + }); + + if (input_dim.degree % attrs.combine_degree != 0) { + return tl::unexpected( + fmt::format("Combine received tensor containing parallel dim {} with " + "degree {}, which is not divisible by combine degree {}", + attrs.combine_dim, + input_dim.degree, + attrs.combine_degree)); + } + + ParallelTensorShape output = input; + shard_dim_at_idx(output, attrs.combine_dim).degree /= attrs.combine_degree; + + return output; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/combine_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/combine_attrs.dtg.cc new file mode 100644 index 0000000000..516d3b0318 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/combine_attrs.dtg.cc @@ -0,0 +1,91 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml +/* proj-data +{ + "generated_from": "58fc5a388fd1a325ef4142094607e39a" +} +*/ + +#include "op-attrs/ops/combine_attrs.dtg.h" + +#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include + +namespace FlexFlow { +CombineAttrs::CombineAttrs(::FlexFlow::ff_dim_t const &combine_dim, + int const &combine_degree) + : combine_dim(combine_dim), combine_degree(combine_degree) {} +bool CombineAttrs::operator==(CombineAttrs const &other) const { + return std::tie(this->combine_dim, this->combine_degree) == + std::tie(other.combine_dim, other.combine_degree); +} +bool CombineAttrs::operator!=(CombineAttrs const &other) const { + return std::tie(this->combine_dim, this->combine_degree) != + std::tie(other.combine_dim, other.combine_degree); +} +bool CombineAttrs::operator<(CombineAttrs const &other) const { + return std::tie(this->combine_dim, this->combine_degree) < + std::tie(other.combine_dim, other.combine_degree); +} +bool CombineAttrs::operator>(CombineAttrs const &other) const { + return std::tie(this->combine_dim, this->combine_degree) > + std::tie(other.combine_dim, other.combine_degree); +} +bool CombineAttrs::operator<=(CombineAttrs const &other) const { + return std::tie(this->combine_dim, this->combine_degree) <= + std::tie(other.combine_dim, other.combine_degree); +} +bool CombineAttrs::operator>=(CombineAttrs const &other) const { + return std::tie(this->combine_dim, this->combine_degree) >= + std::tie(other.combine_dim, other.combine_degree); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::CombineAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ff_dim_t>{}(x.combine_dim) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash{}(x.combine_degree) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::CombineAttrs + adl_serializer::from_json(json const &j) { + return {j.at("combine_dim").template get<::FlexFlow::ff_dim_t>(), + j.at("combine_degree").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::CombineAttrs const &v) { + j["__type"] = "CombineAttrs"; + j["combine_dim"] = v.combine_dim; + j["combine_degree"] = v.combine_degree; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::ff_dim_t>(), gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(CombineAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, CombineAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/concat.cc b/lib/op-attrs/src/op-attrs/ops/concat.cc similarity index 100% rename from lib/op-attrs/src/concat.cc rename to lib/op-attrs/src/op-attrs/ops/concat.cc diff --git a/lib/op-attrs/src/op-attrs/ops/concat_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/concat_attrs.dtg.cc new file mode 100644 index 0000000000..20db25d485 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/concat_attrs.dtg.cc @@ -0,0 +1,91 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml +/* proj-data +{ + "generated_from": "68e0520b143e0579140a2f2cdd390759" +} +*/ + +#include "op-attrs/ops/concat_attrs.dtg.h" + +#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include + +namespace FlexFlow { +ConcatAttrs::ConcatAttrs(::FlexFlow::ff_dim_t const &axis, + int const &num_inputs) + : axis(axis), num_inputs(num_inputs) {} +bool ConcatAttrs::operator==(ConcatAttrs const &other) const { + return std::tie(this->axis, this->num_inputs) == + std::tie(other.axis, other.num_inputs); +} +bool ConcatAttrs::operator!=(ConcatAttrs const &other) const { + return std::tie(this->axis, this->num_inputs) != + std::tie(other.axis, other.num_inputs); +} +bool ConcatAttrs::operator<(ConcatAttrs const &other) const { + return std::tie(this->axis, this->num_inputs) < + std::tie(other.axis, other.num_inputs); +} +bool ConcatAttrs::operator>(ConcatAttrs const &other) const { + return std::tie(this->axis, this->num_inputs) > + std::tie(other.axis, other.num_inputs); +} +bool ConcatAttrs::operator<=(ConcatAttrs const &other) const { + return std::tie(this->axis, this->num_inputs) <= + std::tie(other.axis, other.num_inputs); +} +bool ConcatAttrs::operator>=(ConcatAttrs const &other) const { + return std::tie(this->axis, this->num_inputs) >= + std::tie(other.axis, other.num_inputs); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ConcatAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ff_dim_t>{}(x.axis) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash{}(x.num_inputs) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ConcatAttrs + adl_serializer::from_json(json const &j) { + return {j.at("axis").template get<::FlexFlow::ff_dim_t>(), + j.at("num_inputs").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ConcatAttrs const &v) { + j["__type"] = "ConcatAttrs"; + j["axis"] = v.axis; + j["num_inputs"] = v.num_inputs; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::ff_dim_t>(), gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(ConcatAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ConcatAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc new file mode 100644 index 0000000000..c9ec467af4 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc @@ -0,0 +1,176 @@ +#include "op-attrs/ops/conv_2d.h" +#include "op-attrs/ops/conv_2d/conv_2d_input_shape.h" +#include "op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.h" +#include "utils/integer_conversions.h" + +namespace FlexFlow { + +TensorShape get_kernel_shape(Conv2DAttrs const &attrs, + TensorShape const &raw_input_shape) { + assert(attrs.groups == 1); // TODO(@lockshaw): currently not supported + Conv2DInputShape input = parse_input_shape(raw_input_shape); + + return TensorShape{ + TensorDims{FFOrdered{ + size_t_from_int(attrs.out_channels), + input.num_channels, + size_t_from_int(attrs.kernel_h), + size_t_from_int(attrs.kernel_w), + }}, + input.datatype, + }; +} + +TensorShape get_bias_shape(Conv2DAttrs const &attrs, + TensorShape const &raw_input_shape) { + assert(attrs.groups == 1); // TODO(@lockshaw): currently not supported + Conv2DInputShape input = parse_input_shape(raw_input_shape); + + return TensorShape{ + TensorDims{ + FFOrdered{size_t_from_int(attrs.out_channels)}, + }, + input.datatype, + }; +} + +TensorShape get_output_shape(Conv2DAttrs const &attrs, + TensorShape const &raw_input_shape) { + assert(attrs.groups == 1); // TODO(@lockshaw): currently not supported + Conv2DInputShape input = parse_input_shape(raw_input_shape); + + size_t out_height = + (input.height - (2 * attrs.padding_h) - (attrs.kernel_h - 1)) / + attrs.stride_h; + size_t out_width = + (input.width - (2 * attrs.padding_w) - (attrs.kernel_w - 1)) / + attrs.stride_w; + + assert(attrs.out_channels > 0); + + return TensorShape{TensorDims{FFOrdered{ + input.num_samples, + size_t_from_int(attrs.out_channels), + out_height, + out_width, + }}, + input.datatype}; +} + +ParallelTensorShape + get_kernel_shape(Conv2DAttrs const &attrs, + ParallelTensorShape const &raw_input_shape) { + assert(attrs.groups == 1); // TODO(@lockshaw): currently not supported + Conv2DParallelInputShape input = parse_parallel_input_shape(raw_input_shape); + + ShardParallelDim output_channels_dim = {size_t_from_int(attrs.out_channels), + input.discard_copy_reduction_degree}; + ShardParallelDim input_channels_dim = { + size_t_from_int(input.channel_dim.size), input.channel_dim.degree}; + ShardParallelDim kernel_height_dim = {size_t_from_int(attrs.kernel_h), 1}; + ShardParallelDim kernel_width_dim = {size_t_from_int(attrs.kernel_w), 1}; + + int sum_degree = 1; + int discard_copy_degree = input.height_dim.degree * input.width_dim.degree * + input.sum_reduction_degree; + + ParallelTensorShape result = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + output_channels_dim, + input_channels_dim, + kernel_height_dim, + kernel_width_dim, + }, + ReplicaParallelDimSet{ + sum_degree, + discard_copy_degree, + }, + }, + input.datatype, + }; + + assert(total_parallel_degree(result.dims) == + total_parallel_degree(raw_input_shape.dims)); + + return result; +} + +ParallelTensorShape get_bias_shape(Conv2DAttrs const &attrs, + ParallelTensorShape const &raw_input_shape) { + assert(attrs.groups == 1); // TODO(@lockshaw): currently not supported + Conv2DParallelInputShape input = parse_parallel_input_shape(raw_input_shape); + + ShardParallelDim output_channels_dim = {size_t_from_int(attrs.out_channels), + input.discard_copy_reduction_degree}; + + int sum_degree = 1; + int discard_copy_degree = input.height_dim.degree * input.width_dim.degree * + input.sum_reduction_degree * + input.channel_dim.degree; + + ParallelTensorShape result = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + output_channels_dim, + }, + ReplicaParallelDimSet{ + sum_degree, + discard_copy_degree, + }, + }, + input.datatype, + }; + + assert(total_parallel_degree(result.dims) == + total_parallel_degree(raw_input_shape.dims)); + + return result; +} + +ParallelTensorShape + get_output_shape(Conv2DAttrs const &attrs, + ParallelTensorShape const &raw_input_shape) { + assert(attrs.groups == 1); // TODO(@lockshaw): currently not supported + Conv2DParallelInputShape input = parse_parallel_input_shape(raw_input_shape); + + TensorShape unpar_output_shape = + get_output_shape(attrs, get_reduced_shape(raw_input_shape)); + + size_t num_samples = dim_at_idx(unpar_output_shape, ff_dim_t{0}); + size_t num_channels = dim_at_idx(unpar_output_shape, ff_dim_t{1}); + size_t height = dim_at_idx(unpar_output_shape, ff_dim_t{2}); + size_t width = dim_at_idx(unpar_output_shape, ff_dim_t{3}); + + ShardParallelDim sample_dim = {num_samples, input.sample_dim.degree}; + ShardParallelDim channel_dim = {num_channels, + input.discard_copy_reduction_degree}; + ShardParallelDim height_dim = {height, input.height_dim.degree}; + ShardParallelDim width_dim = {width, input.width_dim.degree}; + + int sum_degree = input.channel_dim.degree * input.sum_reduction_degree; + int discard_copy_degree = 1; + + ParallelTensorShape result = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + sample_dim, + channel_dim, + height_dim, + width_dim, + }, + ReplicaParallelDimSet{ + sum_degree, + discard_copy_degree, + }, + }, + input.datatype, + }; + + assert(total_parallel_degree(result.dims) == + total_parallel_degree(raw_input_shape.dims)); + + return result; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.cc new file mode 100644 index 0000000000..a8a3b10bdf --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.cc @@ -0,0 +1,23 @@ +#include "op-attrs/ops/conv_2d/conv_2d_input_shape.h" +#include "op-attrs/tensor_shape.h" + +namespace FlexFlow { + +Conv2DInputShape parse_input_shape(TensorShape const &input) { + assert(num_dims(input) == 4); + + size_t num_samples = dim_at_idx(input, ff_dim_t{0}); + size_t in_channels = dim_at_idx(input, ff_dim_t{1}); + size_t in_height = dim_at_idx(input, ff_dim_t{2}); + size_t in_width = dim_at_idx(input, ff_dim_t{3}); + + return Conv2DInputShape{ + num_samples, + in_channels, + in_height, + in_width, + input.data_type, + }; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.cc new file mode 100644 index 0000000000..74df30e2d7 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.cc @@ -0,0 +1,157 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.struct.toml +/* proj-data +{ + "generated_from": "51911f58c134d55b2d0245444acbae53" +} +*/ + +#include "op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.h" + +#include "op-attrs/datatype.dtg.h" +#include +#include + +namespace FlexFlow { +Conv2DInputShape::Conv2DInputShape(size_t const &num_samples, + size_t const &num_channels, + size_t const &height, + size_t const &width, + ::FlexFlow::DataType const &datatype) + : num_samples(num_samples), num_channels(num_channels), height(height), + width(width), datatype(datatype) {} +bool Conv2DInputShape::operator==(Conv2DInputShape const &other) const { + return std::tie(this->num_samples, + this->num_channels, + this->height, + this->width, + this->datatype) == std::tie(other.num_samples, + other.num_channels, + other.height, + other.width, + other.datatype); +} +bool Conv2DInputShape::operator!=(Conv2DInputShape const &other) const { + return std::tie(this->num_samples, + this->num_channels, + this->height, + this->width, + this->datatype) != std::tie(other.num_samples, + other.num_channels, + other.height, + other.width, + other.datatype); +} +bool Conv2DInputShape::operator<(Conv2DInputShape const &other) const { + return std::tie(this->num_samples, + this->num_channels, + this->height, + this->width, + this->datatype) < std::tie(other.num_samples, + other.num_channels, + other.height, + other.width, + other.datatype); +} +bool Conv2DInputShape::operator>(Conv2DInputShape const &other) const { + return std::tie(this->num_samples, + this->num_channels, + this->height, + this->width, + this->datatype) > std::tie(other.num_samples, + other.num_channels, + other.height, + other.width, + other.datatype); +} +bool Conv2DInputShape::operator<=(Conv2DInputShape const &other) const { + return std::tie(this->num_samples, + this->num_channels, + this->height, + this->width, + this->datatype) <= std::tie(other.num_samples, + other.num_channels, + other.height, + other.width, + other.datatype); +} +bool Conv2DInputShape::operator>=(Conv2DInputShape const &other) const { + return std::tie(this->num_samples, + this->num_channels, + this->height, + this->width, + this->datatype) >= std::tie(other.num_samples, + other.num_channels, + other.height, + other.width, + other.datatype); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::Conv2DInputShape const &x) const { + size_t result = 0; + result ^= std::hash{}(x.num_samples) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.num_channels) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.height) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= + std::hash{}(x.width) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::DataType>{}(x.datatype) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::Conv2DInputShape + adl_serializer::from_json(json const &j) { + return {j.at("num_samples").template get(), + j.at("num_channels").template get(), + j.at("height").template get(), + j.at("width").template get(), + j.at("datatype").template get<::FlexFlow::DataType>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::Conv2DInputShape const &v) { + j["__type"] = "Conv2DInputShape"; + j["num_samples"] = v.num_samples; + j["num_channels"] = v.num_channels; + j["height"] = v.height; + j["width"] = v.width; + j["datatype"] = v.datatype; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary<::FlexFlow::DataType>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(Conv2DInputShape const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, Conv2DInputShape const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.cc new file mode 100644 index 0000000000..32ac4547f1 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.cc @@ -0,0 +1,26 @@ +#include "op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.h" +#include "op-attrs/parallel_tensor_shape.h" + +namespace FlexFlow { + +Conv2DParallelInputShape + parse_parallel_input_shape(ParallelTensorShape const &input) { + assert(num_shard_dims(input) == 4); + + ShardParallelDim sample_dim = shard_dim_at_idx(input, ff_dim_t{0}); + ShardParallelDim channel_dim = shard_dim_at_idx(input, ff_dim_t{1}); + ShardParallelDim height_dim = shard_dim_at_idx(input, ff_dim_t{2}); + ShardParallelDim width_dim = shard_dim_at_idx(input, ff_dim_t{3}); + + return Conv2DParallelInputShape{ + sample_dim, + channel_dim, + height_dim, + width_dim, + get_sum_degree(input), + get_discard_copy_degree(input), + input.data_type, + }; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.cc new file mode 100644 index 0000000000..df854c2b8f --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.cc @@ -0,0 +1,211 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.struct.toml +/* proj-data +{ + "generated_from": "d80394bdc90f843372760310b6d17a22" +} +*/ + +#include "op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.h" + +#include "op-attrs/datatype.dtg.h" +#include "op-attrs/shard_parallel_dim.dtg.h" +#include + +namespace FlexFlow { +Conv2DParallelInputShape::Conv2DParallelInputShape( + ::FlexFlow::ShardParallelDim const &sample_dim, + ::FlexFlow::ShardParallelDim const &channel_dim, + ::FlexFlow::ShardParallelDim const &height_dim, + ::FlexFlow::ShardParallelDim const &width_dim, + int const &sum_reduction_degree, + int const &discard_copy_reduction_degree, + ::FlexFlow::DataType const &datatype) + : sample_dim(sample_dim), channel_dim(channel_dim), height_dim(height_dim), + width_dim(width_dim), sum_reduction_degree(sum_reduction_degree), + discard_copy_reduction_degree(discard_copy_reduction_degree), + datatype(datatype) {} +bool Conv2DParallelInputShape::operator==( + Conv2DParallelInputShape const &other) const { + return std::tie(this->sample_dim, + this->channel_dim, + this->height_dim, + this->width_dim, + this->sum_reduction_degree, + this->discard_copy_reduction_degree, + this->datatype) == + std::tie(other.sample_dim, + other.channel_dim, + other.height_dim, + other.width_dim, + other.sum_reduction_degree, + other.discard_copy_reduction_degree, + other.datatype); +} +bool Conv2DParallelInputShape::operator!=( + Conv2DParallelInputShape const &other) const { + return std::tie(this->sample_dim, + this->channel_dim, + this->height_dim, + this->width_dim, + this->sum_reduction_degree, + this->discard_copy_reduction_degree, + this->datatype) != + std::tie(other.sample_dim, + other.channel_dim, + other.height_dim, + other.width_dim, + other.sum_reduction_degree, + other.discard_copy_reduction_degree, + other.datatype); +} +bool Conv2DParallelInputShape::operator<( + Conv2DParallelInputShape const &other) const { + return std::tie(this->sample_dim, + this->channel_dim, + this->height_dim, + this->width_dim, + this->sum_reduction_degree, + this->discard_copy_reduction_degree, + this->datatype) < + std::tie(other.sample_dim, + other.channel_dim, + other.height_dim, + other.width_dim, + other.sum_reduction_degree, + other.discard_copy_reduction_degree, + other.datatype); +} +bool Conv2DParallelInputShape::operator>( + Conv2DParallelInputShape const &other) const { + return std::tie(this->sample_dim, + this->channel_dim, + this->height_dim, + this->width_dim, + this->sum_reduction_degree, + this->discard_copy_reduction_degree, + this->datatype) > + std::tie(other.sample_dim, + other.channel_dim, + other.height_dim, + other.width_dim, + other.sum_reduction_degree, + other.discard_copy_reduction_degree, + other.datatype); +} +bool Conv2DParallelInputShape::operator<=( + Conv2DParallelInputShape const &other) const { + return std::tie(this->sample_dim, + this->channel_dim, + this->height_dim, + this->width_dim, + this->sum_reduction_degree, + this->discard_copy_reduction_degree, + this->datatype) <= + std::tie(other.sample_dim, + other.channel_dim, + other.height_dim, + other.width_dim, + other.sum_reduction_degree, + other.discard_copy_reduction_degree, + other.datatype); +} +bool Conv2DParallelInputShape::operator>=( + Conv2DParallelInputShape const &other) const { + return std::tie(this->sample_dim, + this->channel_dim, + this->height_dim, + this->width_dim, + this->sum_reduction_degree, + this->discard_copy_reduction_degree, + this->datatype) >= + std::tie(other.sample_dim, + other.channel_dim, + other.height_dim, + other.width_dim, + other.sum_reduction_degree, + other.discard_copy_reduction_degree, + other.datatype); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::Conv2DParallelInputShape const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ShardParallelDim>{}(x.sample_dim) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::ShardParallelDim>{}(x.channel_dim) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::ShardParallelDim>{}(x.height_dim) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::ShardParallelDim>{}(x.width_dim) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash{}(x.sum_reduction_degree) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash{}(x.discard_copy_reduction_degree) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::DataType>{}(x.datatype) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::Conv2DParallelInputShape + adl_serializer::from_json( + json const &j) { + return {j.at("sample_dim").template get<::FlexFlow::ShardParallelDim>(), + j.at("channel_dim").template get<::FlexFlow::ShardParallelDim>(), + j.at("height_dim").template get<::FlexFlow::ShardParallelDim>(), + j.at("width_dim").template get<::FlexFlow::ShardParallelDim>(), + j.at("sum_reduction_degree").template get(), + j.at("discard_copy_reduction_degree").template get(), + j.at("datatype").template get<::FlexFlow::DataType>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::Conv2DParallelInputShape const &v) { + j["__type"] = "Conv2DParallelInputShape"; + j["sample_dim"] = v.sample_dim; + j["channel_dim"] = v.channel_dim; + j["height_dim"] = v.height_dim; + j["width_dim"] = v.width_dim; + j["sum_reduction_degree"] = v.sum_reduction_degree; + j["discard_copy_reduction_degree"] = v.discard_copy_reduction_degree; + j["datatype"] = v.datatype; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::ShardParallelDim>(), + gen::arbitrary<::FlexFlow::ShardParallelDim>(), + gen::arbitrary<::FlexFlow::ShardParallelDim>(), + gen::arbitrary<::FlexFlow::ShardParallelDim>(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary<::FlexFlow::DataType>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(Conv2DParallelInputShape const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, Conv2DParallelInputShape const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d_attrs.dtg.cc new file mode 100644 index 0000000000..238b349cbe --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d_attrs.dtg.cc @@ -0,0 +1,256 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml +/* proj-data +{ + "generated_from": "74f98e1aacb57d847bb450e1d28d3e67" +} +*/ + +#include "op-attrs/ops/conv_2d_attrs.dtg.h" + +#include "op-attrs/activation.dtg.h" +#include "utils/json.h" +#include +#include + +namespace FlexFlow { +Conv2DAttrs::Conv2DAttrs( + int const &out_channels, + int const &kernel_h, + int const &kernel_w, + int const &stride_h, + int const &stride_w, + int const &padding_h, + int const &padding_w, + int const &groups, + std::optional<::FlexFlow::Activation> const &activation, + bool const &use_bias) + : out_channels(out_channels), kernel_h(kernel_h), kernel_w(kernel_w), + stride_h(stride_h), stride_w(stride_w), padding_h(padding_h), + padding_w(padding_w), groups(groups), activation(activation), + use_bias(use_bias) {} +bool Conv2DAttrs::operator==(Conv2DAttrs const &other) const { + return std::tie(this->out_channels, + this->kernel_h, + this->kernel_w, + this->stride_h, + this->stride_w, + this->padding_h, + this->padding_w, + this->groups, + this->activation, + this->use_bias) == std::tie(other.out_channels, + other.kernel_h, + other.kernel_w, + other.stride_h, + other.stride_w, + other.padding_h, + other.padding_w, + other.groups, + other.activation, + other.use_bias); +} +bool Conv2DAttrs::operator!=(Conv2DAttrs const &other) const { + return std::tie(this->out_channels, + this->kernel_h, + this->kernel_w, + this->stride_h, + this->stride_w, + this->padding_h, + this->padding_w, + this->groups, + this->activation, + this->use_bias) != std::tie(other.out_channels, + other.kernel_h, + other.kernel_w, + other.stride_h, + other.stride_w, + other.padding_h, + other.padding_w, + other.groups, + other.activation, + other.use_bias); +} +bool Conv2DAttrs::operator<(Conv2DAttrs const &other) const { + return std::tie(this->out_channels, + this->kernel_h, + this->kernel_w, + this->stride_h, + this->stride_w, + this->padding_h, + this->padding_w, + this->groups, + this->activation, + this->use_bias) < std::tie(other.out_channels, + other.kernel_h, + other.kernel_w, + other.stride_h, + other.stride_w, + other.padding_h, + other.padding_w, + other.groups, + other.activation, + other.use_bias); +} +bool Conv2DAttrs::operator>(Conv2DAttrs const &other) const { + return std::tie(this->out_channels, + this->kernel_h, + this->kernel_w, + this->stride_h, + this->stride_w, + this->padding_h, + this->padding_w, + this->groups, + this->activation, + this->use_bias) > std::tie(other.out_channels, + other.kernel_h, + other.kernel_w, + other.stride_h, + other.stride_w, + other.padding_h, + other.padding_w, + other.groups, + other.activation, + other.use_bias); +} +bool Conv2DAttrs::operator<=(Conv2DAttrs const &other) const { + return std::tie(this->out_channels, + this->kernel_h, + this->kernel_w, + this->stride_h, + this->stride_w, + this->padding_h, + this->padding_w, + this->groups, + this->activation, + this->use_bias) <= std::tie(other.out_channels, + other.kernel_h, + other.kernel_w, + other.stride_h, + other.stride_w, + other.padding_h, + other.padding_w, + other.groups, + other.activation, + other.use_bias); +} +bool Conv2DAttrs::operator>=(Conv2DAttrs const &other) const { + return std::tie(this->out_channels, + this->kernel_h, + this->kernel_w, + this->stride_h, + this->stride_w, + this->padding_h, + this->padding_w, + this->groups, + this->activation, + this->use_bias) >= std::tie(other.out_channels, + other.kernel_h, + other.kernel_w, + other.stride_h, + other.stride_w, + other.padding_h, + other.padding_w, + other.groups, + other.activation, + other.use_bias); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::Conv2DAttrs const &x) const { + size_t result = 0; + result ^= std::hash{}(x.out_channels) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= + std::hash{}(x.kernel_h) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.kernel_w) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.stride_h) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.stride_w) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash{}(x.padding_h) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.padding_w) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= + std::hash{}(x.groups) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash>{}(x.activation) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash{}(x.use_bias) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::Conv2DAttrs + adl_serializer::from_json(json const &j) { + return { + j.at("out_channels").template get(), + j.at("kernel_h").template get(), + j.at("kernel_w").template get(), + j.at("stride_h").template get(), + j.at("stride_w").template get(), + j.at("padding_h").template get(), + j.at("padding_w").template get(), + j.at("groups").template get(), + j.at("activation").template get>(), + j.at("use_bias").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::Conv2DAttrs const &v) { + j["__type"] = "Conv2DAttrs"; + j["out_channels"] = v.out_channels; + j["kernel_h"] = v.kernel_h; + j["kernel_w"] = v.kernel_w; + j["stride_h"] = v.stride_h; + j["stride_w"] = v.stride_w; + j["padding_h"] = v.padding_h; + j["padding_w"] = v.padding_w; + j["groups"] = v.groups; + j["activation"] = v.activation; + j["use_bias"] = v.use_bias; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary>(), + gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(Conv2DAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, Conv2DAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/dropout.cc b/lib/op-attrs/src/op-attrs/ops/dropout.cc new file mode 100644 index 0000000000..adbd144f38 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/dropout.cc @@ -0,0 +1,10 @@ +#include "op-attrs/ops/dropout.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(DropoutAttrs const &, + ParallelTensorShape const &) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/dropout_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/dropout_attrs.dtg.cc new file mode 100644 index 0000000000..284443a0e4 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/dropout_attrs.dtg.cc @@ -0,0 +1,82 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/dropout_attrs.struct.toml +/* proj-data +{ + "generated_from": "4fdbf129ea59b8a7306813cfa4c46021" +} +*/ + +#include "op-attrs/ops/dropout_attrs.dtg.h" + +#include + +namespace FlexFlow { +DropoutAttrs::DropoutAttrs(float const &rate, unsigned long long const &seed) + : rate(rate), seed(seed) {} +bool DropoutAttrs::operator==(DropoutAttrs const &other) const { + return std::tie(this->rate, this->seed) == std::tie(other.rate, other.seed); +} +bool DropoutAttrs::operator!=(DropoutAttrs const &other) const { + return std::tie(this->rate, this->seed) != std::tie(other.rate, other.seed); +} +bool DropoutAttrs::operator<(DropoutAttrs const &other) const { + return std::tie(this->rate, this->seed) < std::tie(other.rate, other.seed); +} +bool DropoutAttrs::operator>(DropoutAttrs const &other) const { + return std::tie(this->rate, this->seed) > std::tie(other.rate, other.seed); +} +bool DropoutAttrs::operator<=(DropoutAttrs const &other) const { + return std::tie(this->rate, this->seed) <= std::tie(other.rate, other.seed); +} +bool DropoutAttrs::operator>=(DropoutAttrs const &other) const { + return std::tie(this->rate, this->seed) >= std::tie(other.rate, other.seed); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::DropoutAttrs const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.rate) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash{}(x.seed) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::DropoutAttrs + adl_serializer::from_json(json const &j) { + return {j.at("rate").template get(), + j.at("seed").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::DropoutAttrs const &v) { + j["__type"] = "DropoutAttrs"; + j["rate"] = v.rate; + j["seed"] = v.seed; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary(), gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(DropoutAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, DropoutAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/element_binary.cc b/lib/op-attrs/src/op-attrs/ops/element_binary.cc new file mode 100644 index 0000000000..16957a036c --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/element_binary.cc @@ -0,0 +1,75 @@ +#include "op-attrs/ops/element_binary.h" + +namespace FlexFlow { + +tl::expected + get_output_shape(ElementBinaryAttrs const &attrs, + TensorShape const &input_lhs, + TensorShape const &input_rhs) { + assert(!(attrs.should_broadcast_lhs && attrs.should_broadcast_rhs)); + + if (attrs.should_broadcast_lhs) { + NOT_IMPLEMENTED(); + } else if (attrs.should_broadcast_rhs) { + NOT_IMPLEMENTED(); + } else { + if (input_lhs != input_rhs) { + return tl::unexpected(fmt::format( + "Expected input shapes to match, but receieved LHS ({}) != RHS ({})", + input_lhs, + input_rhs)); + } + + return input_lhs; + } +} + +tl::expected + get_output_shape(ElementBinaryAttrs const &attrs, + ParallelTensorShape const &input_lhs, + ParallelTensorShape const &input_rhs) { + assert(!(attrs.should_broadcast_lhs && attrs.should_broadcast_rhs)); + + if (attrs.should_broadcast_lhs) { + NOT_IMPLEMENTED(); + } else if (attrs.should_broadcast_rhs) { + NOT_IMPLEMENTED(); + } else { + if (input_lhs != input_rhs) { + return tl::unexpected(fmt::format( + "Expected input shapes to match, but receieved LHS ({}) != RHS ({})", + input_lhs, + input_rhs)); + } + + switch (attrs.type) { + case OperatorType::EW_ADD: { + if (get_discard_copy_degree(input_lhs) != 1) { + return tl::unexpected( + fmt::format("Elementwise Add expected discard copy degree of " + "inputs to be 1, but receieved {}", + get_discard_copy_degree(input_lhs))); + } + + break; + } + case OperatorType::EW_SUB: + NOT_IMPLEMENTED(); + case OperatorType::EW_MUL: + NOT_IMPLEMENTED(); + case OperatorType::EW_DIV: + NOT_IMPLEMENTED(); + case OperatorType::EW_MAX: + NOT_IMPLEMENTED(); + case OperatorType::EW_MIN: + NOT_IMPLEMENTED(); + default: + return tl::unexpected(fmt::format( + "Unexpected element-wise binary operator {}", attrs.type)); + } + + return input_lhs; + } +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/element_binary_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/element_binary_attrs.dtg.cc new file mode 100644 index 0000000000..a0e555cb12 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/element_binary_attrs.dtg.cc @@ -0,0 +1,145 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/element_binary_attrs.struct.toml +/* proj-data +{ + "generated_from": "2bb947c9cc92e3833ee88c908c539629" +} +*/ + +#include "op-attrs/ops/element_binary_attrs.dtg.h" + +#include "op-attrs/datatype.h" +#include "op-attrs/operator_type.h" +#include + +namespace FlexFlow { +ElementBinaryAttrs::ElementBinaryAttrs(::FlexFlow::OperatorType const &type, + ::FlexFlow::DataType const &compute_type, + bool const &should_broadcast_lhs, + bool const &should_broadcast_rhs) + : type(type), compute_type(compute_type), + should_broadcast_lhs(should_broadcast_lhs), + should_broadcast_rhs(should_broadcast_rhs) {} +bool ElementBinaryAttrs::operator==(ElementBinaryAttrs const &other) const { + return std::tie(this->type, + this->compute_type, + this->should_broadcast_lhs, + this->should_broadcast_rhs) == + std::tie(other.type, + other.compute_type, + other.should_broadcast_lhs, + other.should_broadcast_rhs); +} +bool ElementBinaryAttrs::operator!=(ElementBinaryAttrs const &other) const { + return std::tie(this->type, + this->compute_type, + this->should_broadcast_lhs, + this->should_broadcast_rhs) != + std::tie(other.type, + other.compute_type, + other.should_broadcast_lhs, + other.should_broadcast_rhs); +} +bool ElementBinaryAttrs::operator<(ElementBinaryAttrs const &other) const { + return std::tie(this->type, + this->compute_type, + this->should_broadcast_lhs, + this->should_broadcast_rhs) < + std::tie(other.type, + other.compute_type, + other.should_broadcast_lhs, + other.should_broadcast_rhs); +} +bool ElementBinaryAttrs::operator>(ElementBinaryAttrs const &other) const { + return std::tie(this->type, + this->compute_type, + this->should_broadcast_lhs, + this->should_broadcast_rhs) > + std::tie(other.type, + other.compute_type, + other.should_broadcast_lhs, + other.should_broadcast_rhs); +} +bool ElementBinaryAttrs::operator<=(ElementBinaryAttrs const &other) const { + return std::tie(this->type, + this->compute_type, + this->should_broadcast_lhs, + this->should_broadcast_rhs) <= + std::tie(other.type, + other.compute_type, + other.should_broadcast_lhs, + other.should_broadcast_rhs); +} +bool ElementBinaryAttrs::operator>=(ElementBinaryAttrs const &other) const { + return std::tie(this->type, + this->compute_type, + this->should_broadcast_lhs, + this->should_broadcast_rhs) >= + std::tie(other.type, + other.compute_type, + other.should_broadcast_lhs, + other.should_broadcast_rhs); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ElementBinaryAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::OperatorType>{}(x.type) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::DataType>{}(x.compute_type) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash{}(x.should_broadcast_lhs) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash{}(x.should_broadcast_rhs) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ElementBinaryAttrs + adl_serializer::from_json(json const &j) { + return {j.at("type").template get<::FlexFlow::OperatorType>(), + j.at("compute_type").template get<::FlexFlow::DataType>(), + j.at("should_broadcast_lhs").template get(), + j.at("should_broadcast_rhs").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ElementBinaryAttrs const &v) { + j["__type"] = "ElementBinaryAttrs"; + j["type"] = v.type; + j["compute_type"] = v.compute_type; + j["should_broadcast_lhs"] = v.should_broadcast_lhs; + j["should_broadcast_rhs"] = v.should_broadcast_rhs; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::OperatorType>(), + gen::arbitrary<::FlexFlow::DataType>(), + gen::arbitrary(), + gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(ElementBinaryAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ElementBinaryAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/element_scalar_unary_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/element_scalar_unary_attrs.dtg.cc new file mode 100644 index 0000000000..ee85474caf --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/element_scalar_unary_attrs.dtg.cc @@ -0,0 +1,98 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.struct.toml +/* proj-data +{ + "generated_from": "aa6f98b992d46bdf7ad59158bc143a3f" +} +*/ + +#include "op-attrs/ops/element_scalar_unary_attrs.dtg.h" + +#include "op-attrs/operator_type.h" +#include + +namespace FlexFlow { +ElementScalarUnaryAttrs::ElementScalarUnaryAttrs( + ::FlexFlow::OperatorType const &op_type, float const &scalar) + : op_type(op_type), scalar(scalar) {} +bool ElementScalarUnaryAttrs::operator==( + ElementScalarUnaryAttrs const &other) const { + return std::tie(this->op_type, this->scalar) == + std::tie(other.op_type, other.scalar); +} +bool ElementScalarUnaryAttrs::operator!=( + ElementScalarUnaryAttrs const &other) const { + return std::tie(this->op_type, this->scalar) != + std::tie(other.op_type, other.scalar); +} +bool ElementScalarUnaryAttrs::operator<( + ElementScalarUnaryAttrs const &other) const { + return std::tie(this->op_type, this->scalar) < + std::tie(other.op_type, other.scalar); +} +bool ElementScalarUnaryAttrs::operator>( + ElementScalarUnaryAttrs const &other) const { + return std::tie(this->op_type, this->scalar) > + std::tie(other.op_type, other.scalar); +} +bool ElementScalarUnaryAttrs::operator<=( + ElementScalarUnaryAttrs const &other) const { + return std::tie(this->op_type, this->scalar) <= + std::tie(other.op_type, other.scalar); +} +bool ElementScalarUnaryAttrs::operator>=( + ElementScalarUnaryAttrs const &other) const { + return std::tie(this->op_type, this->scalar) >= + std::tie(other.op_type, other.scalar); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ElementScalarUnaryAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::OperatorType>{}(x.op_type) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.scalar) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ElementScalarUnaryAttrs + adl_serializer::from_json( + json const &j) { + return {j.at("op_type").template get<::FlexFlow::OperatorType>(), + j.at("scalar").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ElementScalarUnaryAttrs const &v) { + j["__type"] = "ElementScalarUnaryAttrs"; + j["op_type"] = v.op_type; + j["scalar"] = v.scalar; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::OperatorType>(), gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(ElementScalarUnaryAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ElementScalarUnaryAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/element_unary.cc b/lib/op-attrs/src/op-attrs/ops/element_unary.cc new file mode 100644 index 0000000000..f703799ef3 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/element_unary.cc @@ -0,0 +1,54 @@ +#include "op-attrs/ops/element_unary.h" +#include "op-attrs/parallel_tensor_shape.h" + +namespace FlexFlow { + +tl::expected + get_output_shape(ElementUnaryAttrs const &attrs, + TensorShape const &input_shape) { + return input_shape; +} + +tl::expected + get_output_shape(ElementUnaryAttrs const &attrs, + ParallelTensorShape const &input_shape) { + if (get_sum_degree(input_shape) != 1) { + return tl::unexpected( + fmt::format("Expected sum degree 1, but receieved sum degree {}", + get_sum_degree(input_shape))); + } + + if (get_discard_copy_degree(input_shape) != 1) { + return tl::unexpected(fmt::format( + "Expected discard copy degree 1, but received discartd copy degree {}", + get_discard_copy_degree(input_shape))); + } + + return input_shape; +} + +tl::expected + get_output_shape(ElementScalarUnaryAttrs const &attrs, + TensorShape const &input_shape) { + return input_shape; +} + +tl::expected + get_output_shape(ElementScalarUnaryAttrs const &attrs, + ParallelTensorShape const &input_shape) { + if (get_sum_degree(input_shape) != 1) { + return tl::unexpected( + fmt::format("Expected sum degree 1, but receieved sum degree {}", + get_sum_degree(input_shape))); + } + + if (get_discard_copy_degree(input_shape) != 1) { + return tl::unexpected(fmt::format( + "Expected discard copy degree 1, but received discartd copy degree {}", + get_discard_copy_degree(input_shape))); + } + + return input_shape; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/element_unary_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/element_unary_attrs.dtg.cc new file mode 100644 index 0000000000..bf90a3db7d --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/element_unary_attrs.dtg.cc @@ -0,0 +1,79 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml +/* proj-data +{ + "generated_from": "75272cff78d3db866122dbb1001aedbe" +} +*/ + +#include "op-attrs/ops/element_unary_attrs.dtg.h" + +#include "op-attrs/operator_type.h" +#include + +namespace FlexFlow { +ElementUnaryAttrs::ElementUnaryAttrs(::FlexFlow::OperatorType const &op_type) + : op_type(op_type) {} +bool ElementUnaryAttrs::operator==(ElementUnaryAttrs const &other) const { + return std::tie(this->op_type) == std::tie(other.op_type); +} +bool ElementUnaryAttrs::operator!=(ElementUnaryAttrs const &other) const { + return std::tie(this->op_type) != std::tie(other.op_type); +} +bool ElementUnaryAttrs::operator<(ElementUnaryAttrs const &other) const { + return std::tie(this->op_type) < std::tie(other.op_type); +} +bool ElementUnaryAttrs::operator>(ElementUnaryAttrs const &other) const { + return std::tie(this->op_type) > std::tie(other.op_type); +} +bool ElementUnaryAttrs::operator<=(ElementUnaryAttrs const &other) const { + return std::tie(this->op_type) <= std::tie(other.op_type); +} +bool ElementUnaryAttrs::operator>=(ElementUnaryAttrs const &other) const { + return std::tie(this->op_type) >= std::tie(other.op_type); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ElementUnaryAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::OperatorType>{}(x.op_type) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ElementUnaryAttrs + adl_serializer::from_json(json const &j) { + return {j.at("op_type").template get<::FlexFlow::OperatorType>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ElementUnaryAttrs const &v) { + j["__type"] = "ElementUnaryAttrs"; + j["op_type"] = v.op_type; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::OperatorType>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(ElementUnaryAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ElementUnaryAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/embedding.cc b/lib/op-attrs/src/op-attrs/ops/embedding.cc new file mode 100644 index 0000000000..9e9ad3a194 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/embedding.cc @@ -0,0 +1,112 @@ +#include "op-attrs/ops/embedding.h" +#include "op-attrs/dim_ordered/slice.h" +#include "op-attrs/dim_ordered/transform.h" +#include "utils/containers.h" +#include "utils/integer_conversions.h" + +namespace FlexFlow { + +static std::optional basic_check(EmbeddingAttrs const &attrs, + TensorShape const &input) { + if (input.data_type != DataType::INT32 && + input.data_type != DataType::INT64) { + return fmt::format("Embedding expected input tensor to have integer " + "datatype, but receieved tensor of datatype {}", + input.data_type); + } + + if (attrs.aggr != AggregateOp::SUM) { + return fmt::format(fmt::format( + "Currently unsupported aggregation op for embedding: {}", attrs.aggr)); + } + + return std::nullopt; +} + +tl::expected + get_output_shape(EmbeddingAttrs const &attrs, TensorShape const &input) { + { + std::optional err_msg = basic_check(attrs, input); + if (err_msg.has_value()) { + return tl::unexpected(err_msg.value()); + } + } + + TensorShape output = input; + dim_at_idx(output, ff_dim_t{-1}) = attrs.out_channels; + output.data_type = attrs.data_type; + return output; +} + +tl::expected + get_weights_shape(EmbeddingAttrs const &attrs, TensorShape const &input) { + { + std::optional err_msg = basic_check(attrs, input); + if (err_msg.has_value()) { + return tl::unexpected(err_msg.value()); + } + } + + return TensorShape{ + TensorDims{ + FFOrdered{ + size_t_from_int(attrs.num_entries), + size_t_from_int(attrs.out_channels), + }, + }, + attrs.data_type, + }; +} + +tl::expected + get_output_shape(EmbeddingAttrs const &attrs, + ParallelTensorShape const &input) { + + TensorShape unpar = ({ + tl::expected result_unpar = + get_output_shape(attrs, get_reduced_shape(input)); + if (!result_unpar.has_value()) { + return tl::unexpected(result_unpar.error()); + } + result_unpar.value(); + }); + + SumDegree sum_degree = shard_dim_at_idx(input, ff_dim_t{-1}).degree; + DiscardCopyDegree discard_copy_degree = 1; + FFOrdered shard_degrees = + transform(input.dims.shard_dims, + [](ShardParallelDim const &d) { return d.degree; }); + shard_degrees.at(ff_dim_t{-1}) = get_discard_copy_degree(input); + + return lift_to_parallel_with_degrees( + unpar, sum_degree, discard_copy_degree, shard_degrees); +} + +tl::expected + get_weights_shape(EmbeddingAttrs const &attrs, + ParallelTensorShape const &input) { + TensorShape unpar = ({ + tl::expected result_unpar = + get_weights_shape(attrs, get_reduced_shape(input)); + if (!result_unpar.has_value()) { + return tl::unexpected(result_unpar.error()); + } + result_unpar.value(); + }); + + SumDegree sum_degree = 1; + DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{product( + transform(ff_ordered_shard_dims(input.dims), + [](ShardParallelDim const &d) -> int { return d.degree; }))}; + int entry_dim_degree = 1; + int out_channel_degree = get_discard_copy_degree(input); + FFOrdered shard_degrees = { + entry_dim_degree, + out_channel_degree, + }; + + return lift_to_parallel_with_degrees( + unpar, sum_degree, discard_copy_degree, shard_degrees); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/embedding_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/embedding_attrs.dtg.cc new file mode 100644 index 0000000000..b4d4657e08 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/embedding_attrs.dtg.cc @@ -0,0 +1,139 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml +/* proj-data +{ + "generated_from": "f2bdea52e23dee6f674f598f8691d994" +} +*/ + +#include "op-attrs/ops/embedding_attrs.dtg.h" + +#include "op-attrs/aggregate_op.dtg.h" +#include "op-attrs/datatype.dtg.h" +#include "utils/stack_vector.h" +#include + +namespace FlexFlow { +EmbeddingAttrs::EmbeddingAttrs( + int const &num_entries, + int const &out_channels, + std::optional<::FlexFlow::AggregateOp> const &aggr, + ::FlexFlow::DataType const &data_type) + : num_entries(num_entries), out_channels(out_channels), aggr(aggr), + data_type(data_type) {} +bool EmbeddingAttrs::operator==(EmbeddingAttrs const &other) const { + return std::tie(this->num_entries, + this->out_channels, + this->aggr, + this->data_type) == std::tie(other.num_entries, + other.out_channels, + other.aggr, + other.data_type); +} +bool EmbeddingAttrs::operator!=(EmbeddingAttrs const &other) const { + return std::tie(this->num_entries, + this->out_channels, + this->aggr, + this->data_type) != std::tie(other.num_entries, + other.out_channels, + other.aggr, + other.data_type); +} +bool EmbeddingAttrs::operator<(EmbeddingAttrs const &other) const { + return std::tie(this->num_entries, + this->out_channels, + this->aggr, + this->data_type) < std::tie(other.num_entries, + other.out_channels, + other.aggr, + other.data_type); +} +bool EmbeddingAttrs::operator>(EmbeddingAttrs const &other) const { + return std::tie(this->num_entries, + this->out_channels, + this->aggr, + this->data_type) > std::tie(other.num_entries, + other.out_channels, + other.aggr, + other.data_type); +} +bool EmbeddingAttrs::operator<=(EmbeddingAttrs const &other) const { + return std::tie(this->num_entries, + this->out_channels, + this->aggr, + this->data_type) <= std::tie(other.num_entries, + other.out_channels, + other.aggr, + other.data_type); +} +bool EmbeddingAttrs::operator>=(EmbeddingAttrs const &other) const { + return std::tie(this->num_entries, + this->out_channels, + this->aggr, + this->data_type) >= std::tie(other.num_entries, + other.out_channels, + other.aggr, + other.data_type); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::EmbeddingAttrs const &x) const { + size_t result = 0; + result ^= std::hash{}(x.num_entries) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.out_channels) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash>{}(x.aggr) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::DataType>{}(x.data_type) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::EmbeddingAttrs + adl_serializer::from_json(json const &j) { + return {j.at("num_entries").template get(), + j.at("out_channels").template get(), + j.at("aggr").template get>(), + j.at("data_type").template get<::FlexFlow::DataType>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::EmbeddingAttrs const &v) { + j["__type"] = "EmbeddingAttrs"; + j["num_entries"] = v.num_entries; + j["out_channels"] = v.out_channels; + j["aggr"] = v.aggr; + j["data_type"] = v.data_type; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary>(), + gen::arbitrary<::FlexFlow::DataType>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(EmbeddingAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, EmbeddingAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/flat.cc b/lib/op-attrs/src/op-attrs/ops/flat.cc similarity index 94% rename from lib/op-attrs/src/flat.cc rename to lib/op-attrs/src/op-attrs/ops/flat.cc index 75d31beae4..b0683c5f08 100644 --- a/lib/op-attrs/src/flat.cc +++ b/lib/op-attrs/src/op-attrs/ops/flat.cc @@ -1,6 +1,4 @@ #include "op-attrs/ops/flat.h" -#include "parallel_dim_mapping_record.h" -#include "parallel_dim_mapping_record_solver.h" #include namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/flat_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/flat_attrs.dtg.cc new file mode 100644 index 0000000000..ef34d97a89 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/flat_attrs.dtg.cc @@ -0,0 +1,70 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml +/* proj-data +{ + "generated_from": "b63924cd671481df30fae314a199c606" +} +*/ + +#include "op-attrs/ops/flat_attrs.dtg.h" + +#include + +namespace FlexFlow { +bool FlatAttrs::operator==(FlatAttrs const &other) const { + return std::tie() == std::tie(); +} +bool FlatAttrs::operator!=(FlatAttrs const &other) const { + return std::tie() != std::tie(); +} +bool FlatAttrs::operator<(FlatAttrs const &other) const { + return std::tie() < std::tie(); +} +bool FlatAttrs::operator>(FlatAttrs const &other) const { + return std::tie() > std::tie(); +} +bool FlatAttrs::operator<=(FlatAttrs const &other) const { + return std::tie() <= std::tie(); +} +bool FlatAttrs::operator>=(FlatAttrs const &other) const { + return std::tie() >= std::tie(); +} +} // namespace FlexFlow + +namespace std { +size_t + hash::operator()(FlexFlow::FlatAttrs const &x) const { + size_t result = 0; + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::FlatAttrs + adl_serializer::from_json(json const &j) { + return {}; +} +void adl_serializer::to_json( + json &j, FlexFlow::FlatAttrs const &v) { + j["__type"] = "FlatAttrs"; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct(); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(FlatAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, FlatAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/gather.cc b/lib/op-attrs/src/op-attrs/ops/gather.cc similarity index 100% rename from lib/op-attrs/src/gather.cc rename to lib/op-attrs/src/op-attrs/ops/gather.cc diff --git a/lib/op-attrs/src/op-attrs/ops/gather_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/gather_attrs.dtg.cc new file mode 100644 index 0000000000..713c0f391e --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/gather_attrs.dtg.cc @@ -0,0 +1,78 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml +/* proj-data +{ + "generated_from": "4ba46b6b494a7a52edda437d2a05fcf1" +} +*/ + +#include "op-attrs/ops/gather_attrs.dtg.h" + +#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include + +namespace FlexFlow { +GatherAttrs::GatherAttrs(::FlexFlow::ff_dim_t const &dim) : dim(dim) {} +bool GatherAttrs::operator==(GatherAttrs const &other) const { + return std::tie(this->dim) == std::tie(other.dim); +} +bool GatherAttrs::operator!=(GatherAttrs const &other) const { + return std::tie(this->dim) != std::tie(other.dim); +} +bool GatherAttrs::operator<(GatherAttrs const &other) const { + return std::tie(this->dim) < std::tie(other.dim); +} +bool GatherAttrs::operator>(GatherAttrs const &other) const { + return std::tie(this->dim) > std::tie(other.dim); +} +bool GatherAttrs::operator<=(GatherAttrs const &other) const { + return std::tie(this->dim) <= std::tie(other.dim); +} +bool GatherAttrs::operator>=(GatherAttrs const &other) const { + return std::tie(this->dim) >= std::tie(other.dim); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::GatherAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ff_dim_t>{}(x.dim) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::GatherAttrs + adl_serializer::from_json(json const &j) { + return {j.at("dim").template get<::FlexFlow::ff_dim_t>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::GatherAttrs const &v) { + j["__type"] = "GatherAttrs"; + j["dim"] = v.dim; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::ff_dim_t>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(GatherAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, GatherAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/input.cc b/lib/op-attrs/src/op-attrs/ops/input.cc new file mode 100644 index 0000000000..93606b603a --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/input.cc @@ -0,0 +1,9 @@ +#include "op-attrs/ops/input.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(InputAttrs const &) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/input_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/input_attrs.dtg.cc new file mode 100644 index 0000000000..35544402f7 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/input_attrs.dtg.cc @@ -0,0 +1,70 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/input_attrs.struct.toml +/* proj-data +{ + "generated_from": "139ea46d57a3c8738b31b17a8c59a0aa" +} +*/ + +#include "op-attrs/ops/input_attrs.dtg.h" + +#include + +namespace FlexFlow { +bool InputAttrs::operator==(InputAttrs const &other) const { + return std::tie() == std::tie(); +} +bool InputAttrs::operator!=(InputAttrs const &other) const { + return std::tie() != std::tie(); +} +bool InputAttrs::operator<(InputAttrs const &other) const { + return std::tie() < std::tie(); +} +bool InputAttrs::operator>(InputAttrs const &other) const { + return std::tie() > std::tie(); +} +bool InputAttrs::operator<=(InputAttrs const &other) const { + return std::tie() <= std::tie(); +} +bool InputAttrs::operator>=(InputAttrs const &other) const { + return std::tie() >= std::tie(); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::InputAttrs const &x) const { + size_t result = 0; + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::InputAttrs + adl_serializer::from_json(json const &j) { + return {}; +} +void adl_serializer::to_json( + json &j, FlexFlow::InputAttrs const &v) { + j["__type"] = "InputAttrs"; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct(); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(InputAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, InputAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/layer_norm.cc b/lib/op-attrs/src/op-attrs/ops/layer_norm.cc new file mode 100644 index 0000000000..437ba3638a --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/layer_norm.cc @@ -0,0 +1,10 @@ +#include "op-attrs/ops/layer_norm.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(LayerNormAttrs const &, + ParallelTensorShape const &) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/layer_norm_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/layer_norm_attrs.dtg.cc new file mode 100644 index 0000000000..163f2e2f91 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/layer_norm_attrs.dtg.cc @@ -0,0 +1,108 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml +/* proj-data +{ + "generated_from": "349deae8d9356d3eeacd7e7d069c3155" +} +*/ + +#include "op-attrs/ops/layer_norm_attrs.dtg.h" + +#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include "utils/stack_vector.h" +#include + +namespace FlexFlow { +LayerNormAttrs::LayerNormAttrs( + ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM> const &axes, + bool const &elementwise_affine, + float const &eps) + : axes(axes), elementwise_affine(elementwise_affine), eps(eps) {} +bool LayerNormAttrs::operator==(LayerNormAttrs const &other) const { + return std::tie(this->axes, this->elementwise_affine, this->eps) == + std::tie(other.axes, other.elementwise_affine, other.eps); +} +bool LayerNormAttrs::operator!=(LayerNormAttrs const &other) const { + return std::tie(this->axes, this->elementwise_affine, this->eps) != + std::tie(other.axes, other.elementwise_affine, other.eps); +} +bool LayerNormAttrs::operator<(LayerNormAttrs const &other) const { + return std::tie(this->axes, this->elementwise_affine, this->eps) < + std::tie(other.axes, other.elementwise_affine, other.eps); +} +bool LayerNormAttrs::operator>(LayerNormAttrs const &other) const { + return std::tie(this->axes, this->elementwise_affine, this->eps) > + std::tie(other.axes, other.elementwise_affine, other.eps); +} +bool LayerNormAttrs::operator<=(LayerNormAttrs const &other) const { + return std::tie(this->axes, this->elementwise_affine, this->eps) <= + std::tie(other.axes, other.elementwise_affine, other.eps); +} +bool LayerNormAttrs::operator>=(LayerNormAttrs const &other) const { + return std::tie(this->axes, this->elementwise_affine, this->eps) >= + std::tie(other.axes, other.elementwise_affine, other.eps); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::LayerNormAttrs const &x) const { + size_t result = 0; + result ^= + std::hash< + ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>>{}( + x.axes) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash{}(x.elementwise_affine) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.eps) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::LayerNormAttrs + adl_serializer::from_json(json const &j) { + return { + j.at("axes") + .template get< + ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>>(), + j.at("elementwise_affine").template get(), + j.at("eps").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::LayerNormAttrs const &v) { + j["__type"] = "LayerNormAttrs"; + j["axes"] = v.axes; + j["elementwise_affine"] = v.elementwise_affine; + j["eps"] = v.eps; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary< + ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>>(), + gen::arbitrary(), + gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(LayerNormAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, LayerNormAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/linear.cc b/lib/op-attrs/src/op-attrs/ops/linear.cc new file mode 100644 index 0000000000..8283673378 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/linear.cc @@ -0,0 +1,110 @@ +#include "op-attrs/ops/linear.h" +#include "op-attrs/dim_ordered/slice.h" +#include "op-attrs/dim_ordered/transform.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_shape.h" +#include "utils/integer_conversions.h" + +namespace FlexFlow { + +tl::expected + get_kernel_shape(LinearAttrs const &attrs, TensorShape const &input_shape) { + size_t in_channels = dim_at_idx(input_shape, ff_dim_t{-1}); + + return TensorShape{ + TensorDims{ + FFOrdered{in_channels, size_t_from_int(attrs.out_channels)}, + }, + input_shape.data_type, + }; +} + +tl::expected + get_bias_shape(LinearAttrs const &attrs, TensorShape const &input_shape) { + return TensorShape{ + TensorDims{ + FFOrdered{size_t_from_int(attrs.out_channels)}, + }, + input_shape.data_type, + }; +} + +tl::expected + get_output_shape(LinearAttrs const &attrs, TensorShape const &input_shape) { + TensorShape output_shape = input_shape; + output_shape.dims.ff_ordered.at(ff_dim_t{-1}) = + size_t_from_int(attrs.out_channels); + + return output_shape; +} + +tl::expected + get_kernel_shape(LinearAttrs const &attrs, + ParallelTensorShape const &input) { + TensorShape unpar = ({ + tl::expected result_unpar = + get_kernel_shape(attrs, get_reduced_shape(input)); + if (!result_unpar.has_value()) { + return tl::unexpected(result_unpar.error()); + } + result_unpar.value(); + }); + + SumDegree sum_degree = 1; + DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{ + get_sum_degree(input) * + product( + slice(ff_ordered_shard_degrees(input), std::nullopt, ff_dim_t{-1}))}; + FFOrdered shard_degrees = FFOrdered{ + shard_dim_at_idx(input, ff_dim_t{-1}).degree, + get_discard_copy_degree(input), + }; + + return lift_to_parallel_with_degrees( + unpar, sum_degree, discard_copy_degree, shard_degrees); +} + +tl::expected + get_bias_shape(LinearAttrs const &attrs, ParallelTensorShape const &input) { + TensorShape unpar = ({ + tl::expected result_unpar = + get_bias_shape(attrs, get_reduced_shape(input)); + if (!result_unpar.has_value()) { + return tl::unexpected(result_unpar.error()); + } + result_unpar.value(); + }); + + SumDegree sum_degree = + get_sum_degree(input) * shard_dim_at_idx(input, ff_dim_t{-1}).degree; + DiscardCopyDegree discard_copy_degree = product( + slice(ff_ordered_shard_degrees(input), std::nullopt, ff_dim_t{-1})); + FFOrdered shard_degrees = FFOrdered{get_discard_copy_degree(input)}; + + return lift_to_parallel_with_degrees( + unpar, sum_degree, discard_copy_degree, shard_degrees); +} + +tl::expected + get_output_shape(LinearAttrs const &attrs, + ParallelTensorShape const &input) { + TensorShape unpar = ({ + tl::expected result_unpar = + get_output_shape(attrs, get_reduced_shape(input)); + if (!result_unpar.has_value()) { + return tl::unexpected(result_unpar.error()); + } + result_unpar.value(); + }); + + SumDegree sum_degree = + get_sum_degree(input) * shard_dim_at_idx(input, ff_dim_t{-1}).degree; + DiscardCopyDegree discard_copy_degree = 1; + FFOrdered shard_degrees = ff_ordered_shard_degrees(input); + shard_degrees.at(ff_dim_t{-1}) = get_discard_copy_degree(input); + + return lift_to_parallel_with_degrees( + unpar, sum_degree, discard_copy_degree, shard_degrees); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/linear_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/linear_attrs.dtg.cc new file mode 100644 index 0000000000..f3359da219 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/linear_attrs.dtg.cc @@ -0,0 +1,162 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml +/* proj-data +{ + "generated_from": "7e82d282f90e08f1e0db7d5c4ce528b7" +} +*/ + +#include "op-attrs/ops/linear_attrs.dtg.h" + +#include "op-attrs/activation.dtg.h" +#include "op-attrs/datatype.dtg.h" +#include "op-attrs/regularizer_attrs.dtg.h" +#include "utils/json.h" +#include + +namespace FlexFlow { +LinearAttrs::LinearAttrs( + int const &out_channels, + bool const &use_bias, + ::FlexFlow::DataType const &data_type, + std::optional<::FlexFlow::Activation> const &activation, + std::optional<::FlexFlow::RegularizerAttrs> const ®ularizer) + : out_channels(out_channels), use_bias(use_bias), data_type(data_type), + activation(activation), regularizer(regularizer) {} +bool LinearAttrs::operator==(LinearAttrs const &other) const { + return std::tie(this->out_channels, + this->use_bias, + this->data_type, + this->activation, + this->regularizer) == std::tie(other.out_channels, + other.use_bias, + other.data_type, + other.activation, + other.regularizer); +} +bool LinearAttrs::operator!=(LinearAttrs const &other) const { + return std::tie(this->out_channels, + this->use_bias, + this->data_type, + this->activation, + this->regularizer) != std::tie(other.out_channels, + other.use_bias, + other.data_type, + other.activation, + other.regularizer); +} +bool LinearAttrs::operator<(LinearAttrs const &other) const { + return std::tie(this->out_channels, + this->use_bias, + this->data_type, + this->activation, + this->regularizer) < std::tie(other.out_channels, + other.use_bias, + other.data_type, + other.activation, + other.regularizer); +} +bool LinearAttrs::operator>(LinearAttrs const &other) const { + return std::tie(this->out_channels, + this->use_bias, + this->data_type, + this->activation, + this->regularizer) > std::tie(other.out_channels, + other.use_bias, + other.data_type, + other.activation, + other.regularizer); +} +bool LinearAttrs::operator<=(LinearAttrs const &other) const { + return std::tie(this->out_channels, + this->use_bias, + this->data_type, + this->activation, + this->regularizer) <= std::tie(other.out_channels, + other.use_bias, + other.data_type, + other.activation, + other.regularizer); +} +bool LinearAttrs::operator>=(LinearAttrs const &other) const { + return std::tie(this->out_channels, + this->use_bias, + this->data_type, + this->activation, + this->regularizer) >= std::tie(other.out_channels, + other.use_bias, + other.data_type, + other.activation, + other.regularizer); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::LinearAttrs const &x) const { + size_t result = 0; + result ^= std::hash{}(x.out_channels) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.use_bias) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash<::FlexFlow::DataType>{}(x.data_type) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash>{}(x.activation) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash>{}(x.regularizer) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::LinearAttrs + adl_serializer::from_json(json const &j) { + return { + j.at("out_channels").template get(), + j.at("use_bias").template get(), + j.at("data_type").template get<::FlexFlow::DataType>(), + j.at("activation").template get>(), + j.at("regularizer") + .template get>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::LinearAttrs const &v) { + j["__type"] = "LinearAttrs"; + j["out_channels"] = v.out_channels; + j["use_bias"] = v.use_bias; + j["data_type"] = v.data_type; + j["activation"] = v.activation; + j["regularizer"] = v.regularizer; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary<::FlexFlow::DataType>(), + gen::arbitrary>(), + gen::arbitrary>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(LinearAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, LinearAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/noop.cc b/lib/op-attrs/src/op-attrs/ops/noop.cc new file mode 100644 index 0000000000..b2b15d820c --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/noop.cc @@ -0,0 +1,10 @@ +#include "op-attrs/ops/noop.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(NoopAttrs const &, + ParallelTensorShape const &input_shape) { + return input_shape; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/noop_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/noop_attrs.dtg.cc new file mode 100644 index 0000000000..3ef3a0119b --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/noop_attrs.dtg.cc @@ -0,0 +1,70 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/noop_attrs.struct.toml +/* proj-data +{ + "generated_from": "d440077aa598fdad0e5aa95288b63c40" +} +*/ + +#include "op-attrs/ops/noop_attrs.dtg.h" + +#include + +namespace FlexFlow { +bool NoopAttrs::operator==(NoopAttrs const &other) const { + return std::tie() == std::tie(); +} +bool NoopAttrs::operator!=(NoopAttrs const &other) const { + return std::tie() != std::tie(); +} +bool NoopAttrs::operator<(NoopAttrs const &other) const { + return std::tie() < std::tie(); +} +bool NoopAttrs::operator>(NoopAttrs const &other) const { + return std::tie() > std::tie(); +} +bool NoopAttrs::operator<=(NoopAttrs const &other) const { + return std::tie() <= std::tie(); +} +bool NoopAttrs::operator>=(NoopAttrs const &other) const { + return std::tie() >= std::tie(); +} +} // namespace FlexFlow + +namespace std { +size_t + hash::operator()(FlexFlow::NoopAttrs const &x) const { + size_t result = 0; + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::NoopAttrs + adl_serializer::from_json(json const &j) { + return {}; +} +void adl_serializer::to_json( + json &j, FlexFlow::NoopAttrs const &v) { + j["__type"] = "NoopAttrs"; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct(); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(NoopAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, NoopAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/parallel_attention_inputs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/parallel_attention_inputs.dtg.cc new file mode 100644 index 0000000000..ac8da6d2d7 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/parallel_attention_inputs.dtg.cc @@ -0,0 +1,88 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.struct.toml +/* proj-data +{ + "generated_from": "b76a39763275090d8376e1c27668d2cb" +} +*/ + +#include "op-attrs/ops/parallel_attention_inputs.dtg.h" + +#include "op-attrs/parallel_tensor_shape.h" +#include + +namespace FlexFlow { +ParallelMultiHeadAttentionInputs::ParallelMultiHeadAttentionInputs( + ::FlexFlow::ParallelTensorShape const &query, + ::FlexFlow::ParallelTensorShape const &key, + ::FlexFlow::ParallelTensorShape const &value) + : query(query), key(key), value(value) {} +bool ParallelMultiHeadAttentionInputs::operator==( + ParallelMultiHeadAttentionInputs const &other) const { + return std::tie(this->query, this->key, this->value) == + std::tie(other.query, other.key, other.value); +} +bool ParallelMultiHeadAttentionInputs::operator!=( + ParallelMultiHeadAttentionInputs const &other) const { + return std::tie(this->query, this->key, this->value) != + std::tie(other.query, other.key, other.value); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ParallelMultiHeadAttentionInputs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ParallelTensorShape>{}(x.query) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::ParallelTensorShape>{}(x.key) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::ParallelTensorShape>{}(x.value) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ParallelMultiHeadAttentionInputs + adl_serializer::from_json( + json const &j) { + return {j.at("query").template get<::FlexFlow::ParallelTensorShape>(), + j.at("key").template get<::FlexFlow::ParallelTensorShape>(), + j.at("value").template get<::FlexFlow::ParallelTensorShape>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ParallelMultiHeadAttentionInputs const &v) { + j["__type"] = "ParallelMultiHeadAttentionInputs"; + j["query"] = v.query; + j["key"] = v.key; + j["value"] = v.value; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::ParallelTensorShape>(), + gen::arbitrary<::FlexFlow::ParallelTensorShape>(), + gen::arbitrary<::FlexFlow::ParallelTensorShape>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(ParallelMultiHeadAttentionInputs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + ParallelMultiHeadAttentionInputs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/pool_2d.cc b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc similarity index 65% rename from lib/op-attrs/src/pool_2d.cc rename to lib/op-attrs/src/op-attrs/ops/pool_2d.cc index 0867aeb344..cf6ed177d3 100644 --- a/lib/op-attrs/src/pool_2d.cc +++ b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc @@ -1,4 +1,16 @@ #include "op-attrs/ops/pool_2d.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(Pool2DAttrs const &, + ParallelTensorShape const &) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow + +/* +#include "op-attrs/ops/pool_2d.h" #include "parallel_dim_mapping_record.h" #include "parallel_dim_mapping_record_solver.h" @@ -14,12 +26,11 @@ constexpr int NUMDIM = 5, WIDTH = 0, HEIGHT = 1, CHANNEL = 2, SAMPLE = 3, REPLICA = 4; }; -/* bool Pool2DAttrs::is_valid(ParallelTensorShape const &input) const { */ -/* ParallelTensorShape output_shape = this->calculate_output_shape(input); */ +bool Pool2DAttrs::is_valid(ParallelTensorShape const &input) const { + ParallelTensorShape output_shape = this->calculate_output_shape(input); -/* return output_shape.is_valid() && (input.at(Input::REPLICA).degree == 1); - */ -/* } */ + return output_shape.is_valid() && (input.at(Input::REPLICA).degree == 1); +} static std::vector construct_mappings(ParallelTensorShape const &input_shape) { @@ -39,9 +50,9 @@ static ParallelDimMappingSolution return solve_parallel_dim_mappings(construct_mappings(input), {input}, 0, 1); } -/* ParallelTensorShape Pool2DAttrs::calculate_output_shape(ParallelTensorShape - * const &input) const { */ -/* return solve_mappings(input).output_shapes.at(0); */ -/* } */ +ParallelTensorShape Pool2DAttrs::calculate_output_shape(ParallelTensorShape +const &input) const { return solve_mappings(input).output_shapes.at(0); +} } // namespace FlexFlow +*/ diff --git a/lib/op-attrs/src/op-attrs/ops/pool_2d_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/pool_2d_attrs.dtg.cc new file mode 100644 index 0000000000..8c445d8b84 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/pool_2d_attrs.dtg.cc @@ -0,0 +1,214 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml +/* proj-data +{ + "generated_from": "03aeafe335f68ff831e3e73a77f45caf" +} +*/ + +#include "op-attrs/ops/pool_2d_attrs.dtg.h" + +#include "op-attrs/activation.dtg.h" +#include "op-attrs/pool_op.dtg.h" +#include + +namespace FlexFlow { +Pool2DAttrs::Pool2DAttrs(int const &kernel_h, + int const &kernel_w, + int const &stride_h, + int const &stride_w, + int const &padding_h, + int const &padding_w, + ::FlexFlow::PoolOp const &pool_type, + ::FlexFlow::Activation const &activation) + : kernel_h(kernel_h), kernel_w(kernel_w), stride_h(stride_h), + stride_w(stride_w), padding_h(padding_h), padding_w(padding_w), + pool_type(pool_type), activation(activation) {} +bool Pool2DAttrs::operator==(Pool2DAttrs const &other) const { + return std::tie(this->kernel_h, + this->kernel_w, + this->stride_h, + this->stride_w, + this->padding_h, + this->padding_w, + this->pool_type, + this->activation) == std::tie(other.kernel_h, + other.kernel_w, + other.stride_h, + other.stride_w, + other.padding_h, + other.padding_w, + other.pool_type, + other.activation); +} +bool Pool2DAttrs::operator!=(Pool2DAttrs const &other) const { + return std::tie(this->kernel_h, + this->kernel_w, + this->stride_h, + this->stride_w, + this->padding_h, + this->padding_w, + this->pool_type, + this->activation) != std::tie(other.kernel_h, + other.kernel_w, + other.stride_h, + other.stride_w, + other.padding_h, + other.padding_w, + other.pool_type, + other.activation); +} +bool Pool2DAttrs::operator<(Pool2DAttrs const &other) const { + return std::tie(this->kernel_h, + this->kernel_w, + this->stride_h, + this->stride_w, + this->padding_h, + this->padding_w, + this->pool_type, + this->activation) < std::tie(other.kernel_h, + other.kernel_w, + other.stride_h, + other.stride_w, + other.padding_h, + other.padding_w, + other.pool_type, + other.activation); +} +bool Pool2DAttrs::operator>(Pool2DAttrs const &other) const { + return std::tie(this->kernel_h, + this->kernel_w, + this->stride_h, + this->stride_w, + this->padding_h, + this->padding_w, + this->pool_type, + this->activation) > std::tie(other.kernel_h, + other.kernel_w, + other.stride_h, + other.stride_w, + other.padding_h, + other.padding_w, + other.pool_type, + other.activation); +} +bool Pool2DAttrs::operator<=(Pool2DAttrs const &other) const { + return std::tie(this->kernel_h, + this->kernel_w, + this->stride_h, + this->stride_w, + this->padding_h, + this->padding_w, + this->pool_type, + this->activation) <= std::tie(other.kernel_h, + other.kernel_w, + other.stride_h, + other.stride_w, + other.padding_h, + other.padding_w, + other.pool_type, + other.activation); +} +bool Pool2DAttrs::operator>=(Pool2DAttrs const &other) const { + return std::tie(this->kernel_h, + this->kernel_w, + this->stride_h, + this->stride_w, + this->padding_h, + this->padding_w, + this->pool_type, + this->activation) >= std::tie(other.kernel_h, + other.kernel_w, + other.stride_h, + other.stride_w, + other.padding_h, + other.padding_w, + other.pool_type, + other.activation); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::Pool2DAttrs const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.kernel_h) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.kernel_w) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.stride_h) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.stride_w) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash{}(x.padding_h) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.padding_w) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash<::FlexFlow::PoolOp>{}(x.pool_type) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::Activation>{}(x.activation) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::Pool2DAttrs + adl_serializer::from_json(json const &j) { + return {j.at("kernel_h").template get(), + j.at("kernel_w").template get(), + j.at("stride_h").template get(), + j.at("stride_w").template get(), + j.at("padding_h").template get(), + j.at("padding_w").template get(), + j.at("pool_type").template get<::FlexFlow::PoolOp>(), + j.at("activation").template get<::FlexFlow::Activation>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::Pool2DAttrs const &v) { + j["__type"] = "Pool2DAttrs"; + j["kernel_h"] = v.kernel_h; + j["kernel_w"] = v.kernel_w; + j["stride_h"] = v.stride_h; + j["stride_w"] = v.stride_w; + j["padding_h"] = v.padding_h; + j["padding_w"] = v.padding_w; + j["pool_type"] = v.pool_type; + j["activation"] = v.activation; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary<::FlexFlow::PoolOp>(), + gen::arbitrary<::FlexFlow::Activation>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(Pool2DAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, Pool2DAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/reduce.cc b/lib/op-attrs/src/op-attrs/ops/reduce.cc new file mode 100644 index 0000000000..2a8bf06ecf --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/reduce.cc @@ -0,0 +1,10 @@ +#include "op-attrs/ops/reduce.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(ReduceAttrs const &, + ParallelTensorShape const &input_shape) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/reduce_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/reduce_attrs.dtg.cc new file mode 100644 index 0000000000..2aa9546956 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/reduce_attrs.dtg.cc @@ -0,0 +1,109 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml +/* proj-data +{ + "generated_from": "097463446e254f662c7bdf5df4e12d17" +} +*/ + +#include "op-attrs/ops/reduce_attrs.dtg.h" + +#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include "op-attrs/operator_type.dtg.h" +#include "utils/stack_vector.h" +#include + +namespace FlexFlow { +ReduceAttrs::ReduceAttrs( + ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM> const &axes, + ::FlexFlow::OperatorType const &op_type, + bool const &keepdims) + : axes(axes), op_type(op_type), keepdims(keepdims) {} +bool ReduceAttrs::operator==(ReduceAttrs const &other) const { + return std::tie(this->axes, this->op_type, this->keepdims) == + std::tie(other.axes, other.op_type, other.keepdims); +} +bool ReduceAttrs::operator!=(ReduceAttrs const &other) const { + return std::tie(this->axes, this->op_type, this->keepdims) != + std::tie(other.axes, other.op_type, other.keepdims); +} +bool ReduceAttrs::operator<(ReduceAttrs const &other) const { + return std::tie(this->axes, this->op_type, this->keepdims) < + std::tie(other.axes, other.op_type, other.keepdims); +} +bool ReduceAttrs::operator>(ReduceAttrs const &other) const { + return std::tie(this->axes, this->op_type, this->keepdims) > + std::tie(other.axes, other.op_type, other.keepdims); +} +bool ReduceAttrs::operator<=(ReduceAttrs const &other) const { + return std::tie(this->axes, this->op_type, this->keepdims) <= + std::tie(other.axes, other.op_type, other.keepdims); +} +bool ReduceAttrs::operator>=(ReduceAttrs const &other) const { + return std::tie(this->axes, this->op_type, this->keepdims) >= + std::tie(other.axes, other.op_type, other.keepdims); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ReduceAttrs const &x) const { + size_t result = 0; + result ^= + std::hash< + ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>>{}( + x.axes) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::OperatorType>{}(x.op_type) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash{}(x.keepdims) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ReduceAttrs + adl_serializer::from_json(json const &j) { + return { + j.at("axes") + .template get< + ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>>(), + j.at("op_type").template get<::FlexFlow::OperatorType>(), + j.at("keepdims").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ReduceAttrs const &v) { + j["__type"] = "ReduceAttrs"; + j["axes"] = v.axes; + j["op_type"] = v.op_type; + j["keepdims"] = v.keepdims; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary< + ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>>(), + gen::arbitrary<::FlexFlow::OperatorType>(), + gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(ReduceAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ReduceAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/reduction.cc b/lib/op-attrs/src/op-attrs/ops/reduction.cc new file mode 100644 index 0000000000..0fef6f37d6 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/reduction.cc @@ -0,0 +1,22 @@ +#include "op-attrs/ops/reduction.h" +#include "op-attrs/parallel_tensor_shape.h" + +namespace FlexFlow { + +tl::expected + get_output_shape(ReductionAttrs const &attrs, + ParallelTensorShape const &input_shape) { + if (get_sum_degree(input_shape) % attrs.reduction_degree != 0) { + return tl::unexpected( + fmt::format("Reduction received tensor with sum degree {}, which is " + "not divisible by reduction degree {}", + get_sum_degree(input_shape), + attrs.reduction_degree)); + } + + ParallelTensorShape output_shape = input_shape; + output_shape.dims.replica_dims.sum_degree.value /= attrs.reduction_degree; + return output_shape; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/reduction_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/reduction_attrs.dtg.cc new file mode 100644 index 0000000000..2f1550bb66 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/reduction_attrs.dtg.cc @@ -0,0 +1,76 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/reduction_attrs.struct.toml +/* proj-data +{ + "generated_from": "1d2b5b7cf11ed04a27a6fd8215e4e2a5" +} +*/ + +#include "op-attrs/ops/reduction_attrs.dtg.h" + +#include + +namespace FlexFlow { +ReductionAttrs::ReductionAttrs(int const &reduction_degree) + : reduction_degree(reduction_degree) {} +bool ReductionAttrs::operator==(ReductionAttrs const &other) const { + return std::tie(this->reduction_degree) == std::tie(other.reduction_degree); +} +bool ReductionAttrs::operator!=(ReductionAttrs const &other) const { + return std::tie(this->reduction_degree) != std::tie(other.reduction_degree); +} +bool ReductionAttrs::operator<(ReductionAttrs const &other) const { + return std::tie(this->reduction_degree) < std::tie(other.reduction_degree); +} +bool ReductionAttrs::operator>(ReductionAttrs const &other) const { + return std::tie(this->reduction_degree) > std::tie(other.reduction_degree); +} +bool ReductionAttrs::operator<=(ReductionAttrs const &other) const { + return std::tie(this->reduction_degree) <= std::tie(other.reduction_degree); +} +bool ReductionAttrs::operator>=(ReductionAttrs const &other) const { + return std::tie(this->reduction_degree) >= std::tie(other.reduction_degree); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ReductionAttrs const &x) const { + size_t result = 0; + result ^= std::hash{}(x.reduction_degree) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ReductionAttrs + adl_serializer::from_json(json const &j) { + return {j.at("reduction_degree").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ReductionAttrs const &v) { + j["__type"] = "ReductionAttrs"; + j["reduction_degree"] = v.reduction_degree; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(ReductionAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ReductionAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/repartition.cc b/lib/op-attrs/src/op-attrs/ops/repartition.cc new file mode 100644 index 0000000000..37a0b8a168 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/repartition.cc @@ -0,0 +1,14 @@ +#include "op-attrs/ops/repartition.h" + +namespace FlexFlow { + +tl::expected + get_output_shape(RepartitionAttrs const &attrs, + ParallelTensorShape const &input_shape) { + ParallelTensorShape output_shape = input_shape; + output_shape.dims.shard_dims.at(attrs.repartition_dim).degree *= + attrs.repartition_degree; + return output_shape; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/repartition_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/repartition_attrs.dtg.cc new file mode 100644 index 0000000000..6270298c87 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/repartition_attrs.dtg.cc @@ -0,0 +1,93 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml +/* proj-data +{ + "generated_from": "0a4d8b435768ce3ee37013fc550c9ebb" +} +*/ + +#include "op-attrs/ops/repartition_attrs.dtg.h" + +#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include + +namespace FlexFlow { +RepartitionAttrs::RepartitionAttrs(::FlexFlow::ff_dim_t const &repartition_dim, + int const &repartition_degree) + : repartition_dim(repartition_dim), repartition_degree(repartition_degree) { +} +bool RepartitionAttrs::operator==(RepartitionAttrs const &other) const { + return std::tie(this->repartition_dim, this->repartition_degree) == + std::tie(other.repartition_dim, other.repartition_degree); +} +bool RepartitionAttrs::operator!=(RepartitionAttrs const &other) const { + return std::tie(this->repartition_dim, this->repartition_degree) != + std::tie(other.repartition_dim, other.repartition_degree); +} +bool RepartitionAttrs::operator<(RepartitionAttrs const &other) const { + return std::tie(this->repartition_dim, this->repartition_degree) < + std::tie(other.repartition_dim, other.repartition_degree); +} +bool RepartitionAttrs::operator>(RepartitionAttrs const &other) const { + return std::tie(this->repartition_dim, this->repartition_degree) > + std::tie(other.repartition_dim, other.repartition_degree); +} +bool RepartitionAttrs::operator<=(RepartitionAttrs const &other) const { + return std::tie(this->repartition_dim, this->repartition_degree) <= + std::tie(other.repartition_dim, other.repartition_degree); +} +bool RepartitionAttrs::operator>=(RepartitionAttrs const &other) const { + return std::tie(this->repartition_dim, this->repartition_degree) >= + std::tie(other.repartition_dim, other.repartition_degree); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::RepartitionAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ff_dim_t>{}(x.repartition_dim) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash{}(x.repartition_degree) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::RepartitionAttrs + adl_serializer::from_json(json const &j) { + return {j.at("repartition_dim").template get<::FlexFlow::ff_dim_t>(), + j.at("repartition_degree").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::RepartitionAttrs const &v) { + j["__type"] = "RepartitionAttrs"; + j["repartition_dim"] = v.repartition_dim; + j["repartition_degree"] = v.repartition_degree; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::ff_dim_t>(), gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(RepartitionAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, RepartitionAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/replicate.cc b/lib/op-attrs/src/op-attrs/ops/replicate.cc new file mode 100644 index 0000000000..9e163cb55a --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/replicate.cc @@ -0,0 +1,13 @@ +#include "op-attrs/ops/replicate.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(ReplicateAttrs const &attrs, + ParallelTensorShape const &input_shape) { + ParallelTensorShape output_shape = input_shape; + output_shape.dims.replica_dims.discard_copy_degree.value *= + attrs.replicate_degree; + return output_shape; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/replicate_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/replicate_attrs.dtg.cc new file mode 100644 index 0000000000..930c5beaf4 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/replicate_attrs.dtg.cc @@ -0,0 +1,76 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/replicate_attrs.struct.toml +/* proj-data +{ + "generated_from": "6d3ad4d10c24dae819ffee4592a72499" +} +*/ + +#include "op-attrs/ops/replicate_attrs.dtg.h" + +#include + +namespace FlexFlow { +ReplicateAttrs::ReplicateAttrs(int const &replicate_degree) + : replicate_degree(replicate_degree) {} +bool ReplicateAttrs::operator==(ReplicateAttrs const &other) const { + return std::tie(this->replicate_degree) == std::tie(other.replicate_degree); +} +bool ReplicateAttrs::operator!=(ReplicateAttrs const &other) const { + return std::tie(this->replicate_degree) != std::tie(other.replicate_degree); +} +bool ReplicateAttrs::operator<(ReplicateAttrs const &other) const { + return std::tie(this->replicate_degree) < std::tie(other.replicate_degree); +} +bool ReplicateAttrs::operator>(ReplicateAttrs const &other) const { + return std::tie(this->replicate_degree) > std::tie(other.replicate_degree); +} +bool ReplicateAttrs::operator<=(ReplicateAttrs const &other) const { + return std::tie(this->replicate_degree) <= std::tie(other.replicate_degree); +} +bool ReplicateAttrs::operator>=(ReplicateAttrs const &other) const { + return std::tie(this->replicate_degree) >= std::tie(other.replicate_degree); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ReplicateAttrs const &x) const { + size_t result = 0; + result ^= std::hash{}(x.replicate_degree) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ReplicateAttrs + adl_serializer::from_json(json const &j) { + return {j.at("replicate_degree").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ReplicateAttrs const &v) { + j["__type"] = "ReplicateAttrs"; + j["replicate_degree"] = v.replicate_degree; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(ReplicateAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ReplicateAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/reshape.cc b/lib/op-attrs/src/op-attrs/ops/reshape.cc new file mode 100644 index 0000000000..7d0600550a --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/reshape.cc @@ -0,0 +1,10 @@ +#include "op-attrs/ops/reshape.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(ReshapeAttrs const &attrs, + ParallelTensorShape const &input_shape) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/reshape_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/reshape_attrs.dtg.cc new file mode 100644 index 0000000000..b1fb350b88 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/reshape_attrs.dtg.cc @@ -0,0 +1,78 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/reshape_attrs.struct.toml +/* proj-data +{ + "generated_from": "015d04de0ccb982e7eaa013a842880ca" +} +*/ + +#include "op-attrs/ops/reshape_attrs.dtg.h" + +#include "op-attrs/tensor_shape.dtg.h" +#include + +namespace FlexFlow { +ReshapeAttrs::ReshapeAttrs(::FlexFlow::TensorShape const &shape) + : shape(shape) {} +bool ReshapeAttrs::operator==(ReshapeAttrs const &other) const { + return std::tie(this->shape) == std::tie(other.shape); +} +bool ReshapeAttrs::operator!=(ReshapeAttrs const &other) const { + return std::tie(this->shape) != std::tie(other.shape); +} +bool ReshapeAttrs::operator<(ReshapeAttrs const &other) const { + return std::tie(this->shape) < std::tie(other.shape); +} +bool ReshapeAttrs::operator>(ReshapeAttrs const &other) const { + return std::tie(this->shape) > std::tie(other.shape); +} +bool ReshapeAttrs::operator<=(ReshapeAttrs const &other) const { + return std::tie(this->shape) <= std::tie(other.shape); +} +bool ReshapeAttrs::operator>=(ReshapeAttrs const &other) const { + return std::tie(this->shape) >= std::tie(other.shape); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ReshapeAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::TensorShape>{}(x.shape) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ReshapeAttrs + adl_serializer::from_json(json const &j) { + return {j.at("shape").template get<::FlexFlow::TensorShape>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ReshapeAttrs const &v) { + j["__type"] = "ReshapeAttrs"; + j["shape"] = v.shape; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::TensorShape>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(ReshapeAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ReshapeAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/reverse.cc b/lib/op-attrs/src/op-attrs/ops/reverse.cc new file mode 100644 index 0000000000..79b5bd50fb --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/reverse.cc @@ -0,0 +1,10 @@ +#include "op-attrs/ops/reverse.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(ReverseAttrs const &attrs, + ParallelTensorShape const &input_shape) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/reverse_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/reverse_attrs.dtg.cc new file mode 100644 index 0000000000..9ac9abeb82 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/reverse_attrs.dtg.cc @@ -0,0 +1,78 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml +/* proj-data +{ + "generated_from": "c5a82c8a15ac3ce6f47dc054236ab69b" +} +*/ + +#include "op-attrs/ops/reverse_attrs.dtg.h" + +#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include + +namespace FlexFlow { +ReverseAttrs::ReverseAttrs(::FlexFlow::ff_dim_t const &axis) : axis(axis) {} +bool ReverseAttrs::operator==(ReverseAttrs const &other) const { + return std::tie(this->axis) == std::tie(other.axis); +} +bool ReverseAttrs::operator!=(ReverseAttrs const &other) const { + return std::tie(this->axis) != std::tie(other.axis); +} +bool ReverseAttrs::operator<(ReverseAttrs const &other) const { + return std::tie(this->axis) < std::tie(other.axis); +} +bool ReverseAttrs::operator>(ReverseAttrs const &other) const { + return std::tie(this->axis) > std::tie(other.axis); +} +bool ReverseAttrs::operator<=(ReverseAttrs const &other) const { + return std::tie(this->axis) <= std::tie(other.axis); +} +bool ReverseAttrs::operator>=(ReverseAttrs const &other) const { + return std::tie(this->axis) >= std::tie(other.axis); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ReverseAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ff_dim_t>{}(x.axis) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ReverseAttrs + adl_serializer::from_json(json const &j) { + return {j.at("axis").template get<::FlexFlow::ff_dim_t>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ReverseAttrs const &v) { + j["__type"] = "ReverseAttrs"; + j["axis"] = v.axis; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::ff_dim_t>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(ReverseAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ReverseAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/softmax.cc b/lib/op-attrs/src/op-attrs/ops/softmax.cc new file mode 100644 index 0000000000..2d870af50e --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/softmax.cc @@ -0,0 +1,10 @@ +#include "op-attrs/ops/softmax.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(SoftmaxAttrs const &attrs, + ParallelTensorShape const &input_shape) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/softmax_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/softmax_attrs.dtg.cc new file mode 100644 index 0000000000..4941b7438a --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/softmax_attrs.dtg.cc @@ -0,0 +1,78 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml +/* proj-data +{ + "generated_from": "2ddf5a8b7daa32a43387f5fd5866bb3b" +} +*/ + +#include "op-attrs/ops/softmax_attrs.dtg.h" + +#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include + +namespace FlexFlow { +SoftmaxAttrs::SoftmaxAttrs(::FlexFlow::ff_dim_t const &dim) : dim(dim) {} +bool SoftmaxAttrs::operator==(SoftmaxAttrs const &other) const { + return std::tie(this->dim) == std::tie(other.dim); +} +bool SoftmaxAttrs::operator!=(SoftmaxAttrs const &other) const { + return std::tie(this->dim) != std::tie(other.dim); +} +bool SoftmaxAttrs::operator<(SoftmaxAttrs const &other) const { + return std::tie(this->dim) < std::tie(other.dim); +} +bool SoftmaxAttrs::operator>(SoftmaxAttrs const &other) const { + return std::tie(this->dim) > std::tie(other.dim); +} +bool SoftmaxAttrs::operator<=(SoftmaxAttrs const &other) const { + return std::tie(this->dim) <= std::tie(other.dim); +} +bool SoftmaxAttrs::operator>=(SoftmaxAttrs const &other) const { + return std::tie(this->dim) >= std::tie(other.dim); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::SoftmaxAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ff_dim_t>{}(x.dim) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::SoftmaxAttrs + adl_serializer::from_json(json const &j) { + return {j.at("dim").template get<::FlexFlow::ff_dim_t>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::SoftmaxAttrs const &v) { + j["__type"] = "SoftmaxAttrs"; + j["dim"] = v.dim; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::ff_dim_t>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(SoftmaxAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, SoftmaxAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/split.cc b/lib/op-attrs/src/op-attrs/ops/split.cc new file mode 100644 index 0000000000..cfb4071833 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/split.cc @@ -0,0 +1,11 @@ +#include "op-attrs/ops/split.h" + +namespace FlexFlow { + +std::vector + get_output_shapes(SplitAttrs const &attrs, + ParallelTensorShape const &input_shape) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/split_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/split_attrs.dtg.cc new file mode 100644 index 0000000000..c6f7e75dbf --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/split_attrs.dtg.cc @@ -0,0 +1,96 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml +/* proj-data +{ + "generated_from": "cde6b5caf6739d3b02fe8fce0d8ae8c5" +} +*/ + +#include "op-attrs/ops/split_attrs.dtg.h" + +#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include "utils/stack_vector.h" +#include + +namespace FlexFlow { +SplitAttrs::SplitAttrs( + ::FlexFlow::stack_vector const &splits, + ::FlexFlow::ff_dim_t const &axis) + : splits(splits), axis(axis) {} +bool SplitAttrs::operator==(SplitAttrs const &other) const { + return std::tie(this->splits, this->axis) == + std::tie(other.splits, other.axis); +} +bool SplitAttrs::operator!=(SplitAttrs const &other) const { + return std::tie(this->splits, this->axis) != + std::tie(other.splits, other.axis); +} +bool SplitAttrs::operator<(SplitAttrs const &other) const { + return std::tie(this->splits, this->axis) < + std::tie(other.splits, other.axis); +} +bool SplitAttrs::operator>(SplitAttrs const &other) const { + return std::tie(this->splits, this->axis) > + std::tie(other.splits, other.axis); +} +bool SplitAttrs::operator<=(SplitAttrs const &other) const { + return std::tie(this->splits, this->axis) <= + std::tie(other.splits, other.axis); +} +bool SplitAttrs::operator>=(SplitAttrs const &other) const { + return std::tie(this->splits, this->axis) >= + std::tie(other.splits, other.axis); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::SplitAttrs const &x) const { + size_t result = 0; + result ^= + std::hash<::FlexFlow::stack_vector>{}(x.splits) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::ff_dim_t>{}(x.axis) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::SplitAttrs + adl_serializer::from_json(json const &j) { + return {j.at("splits") + .template get<::FlexFlow::stack_vector>(), + j.at("axis").template get<::FlexFlow::ff_dim_t>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::SplitAttrs const &v) { + j["__type"] = "SplitAttrs"; + j["splits"] = v.splits; + j["axis"] = v.axis; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::stack_vector>(), + gen::arbitrary<::FlexFlow::ff_dim_t>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(SplitAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, SplitAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/topk.cc b/lib/op-attrs/src/op-attrs/ops/topk.cc new file mode 100644 index 0000000000..9d2fd35a94 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/topk.cc @@ -0,0 +1,10 @@ +#include "op-attrs/ops/topk.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(TopKAttrs const &attrs, + ParallelTensorShape const &input_shape) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/topk_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/topk_attrs.dtg.cc new file mode 100644 index 0000000000..55ead7d858 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/topk_attrs.dtg.cc @@ -0,0 +1,79 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/topk_attrs.struct.toml +/* proj-data +{ + "generated_from": "c1be9dc2acafc58690713e650663cc93" +} +*/ + +#include "op-attrs/ops/topk_attrs.dtg.h" + +#include + +namespace FlexFlow { +TopKAttrs::TopKAttrs(int const &k, bool const &sorted) : k(k), sorted(sorted) {} +bool TopKAttrs::operator==(TopKAttrs const &other) const { + return std::tie(this->k, this->sorted) == std::tie(other.k, other.sorted); +} +bool TopKAttrs::operator!=(TopKAttrs const &other) const { + return std::tie(this->k, this->sorted) != std::tie(other.k, other.sorted); +} +bool TopKAttrs::operator<(TopKAttrs const &other) const { + return std::tie(this->k, this->sorted) < std::tie(other.k, other.sorted); +} +bool TopKAttrs::operator>(TopKAttrs const &other) const { + return std::tie(this->k, this->sorted) > std::tie(other.k, other.sorted); +} +bool TopKAttrs::operator<=(TopKAttrs const &other) const { + return std::tie(this->k, this->sorted) <= std::tie(other.k, other.sorted); +} +bool TopKAttrs::operator>=(TopKAttrs const &other) const { + return std::tie(this->k, this->sorted) >= std::tie(other.k, other.sorted); +} +} // namespace FlexFlow + +namespace std { +size_t + hash::operator()(FlexFlow::TopKAttrs const &x) const { + size_t result = 0; + result ^= std::hash{}(x.k) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.sorted) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::TopKAttrs + adl_serializer::from_json(json const &j) { + return {j.at("k").template get(), j.at("sorted").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::TopKAttrs const &v) { + j["__type"] = "TopKAttrs"; + j["k"] = v.k; + j["sorted"] = v.sorted; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary(), + gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(TopKAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, TopKAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/transpose.cc b/lib/op-attrs/src/op-attrs/ops/transpose.cc new file mode 100644 index 0000000000..75f7eb3c18 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/transpose.cc @@ -0,0 +1,10 @@ +#include "op-attrs/ops/transpose.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(TransposeAttrs const &op_attrs, + ParallelTensorShape const &input_shape) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/transpose_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/transpose_attrs.dtg.cc new file mode 100644 index 0000000000..0a774b992e --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/transpose_attrs.dtg.cc @@ -0,0 +1,82 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml +/* proj-data +{ + "generated_from": "de62a505821a59c4b77197c100e204f7" +} +*/ + +#include "op-attrs/ops/transpose_attrs.dtg.h" + +#include "op-attrs/dim_ordered.h" +#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include + +namespace FlexFlow { +TransposeAttrs::TransposeAttrs( + ::FlexFlow::FFOrdered<::FlexFlow::ff_dim_t> const &perm) + : perm(perm) {} +bool TransposeAttrs::operator==(TransposeAttrs const &other) const { + return std::tie(this->perm) == std::tie(other.perm); +} +bool TransposeAttrs::operator!=(TransposeAttrs const &other) const { + return std::tie(this->perm) != std::tie(other.perm); +} +bool TransposeAttrs::operator<(TransposeAttrs const &other) const { + return std::tie(this->perm) < std::tie(other.perm); +} +bool TransposeAttrs::operator>(TransposeAttrs const &other) const { + return std::tie(this->perm) > std::tie(other.perm); +} +bool TransposeAttrs::operator<=(TransposeAttrs const &other) const { + return std::tie(this->perm) <= std::tie(other.perm); +} +bool TransposeAttrs::operator>=(TransposeAttrs const &other) const { + return std::tie(this->perm) >= std::tie(other.perm); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::TransposeAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::FFOrdered<::FlexFlow::ff_dim_t>>{}(x.perm) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::TransposeAttrs + adl_serializer::from_json(json const &j) { + return { + j.at("perm").template get<::FlexFlow::FFOrdered<::FlexFlow::ff_dim_t>>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::TransposeAttrs const &v) { + j["__type"] = "TransposeAttrs"; + j["perm"] = v.perm; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::FFOrdered<::FlexFlow::ff_dim_t>>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(TransposeAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, TransposeAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/weight_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/weight_attrs.dtg.cc new file mode 100644 index 0000000000..a288161da2 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/weight_attrs.dtg.cc @@ -0,0 +1,70 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/weight_attrs.struct.toml +/* proj-data +{ + "generated_from": "59f49374ffca95b2117b8940af1b6cac" +} +*/ + +#include "op-attrs/ops/weight_attrs.dtg.h" + +#include + +namespace FlexFlow { +bool WeightAttrs::operator==(WeightAttrs const &other) const { + return std::tie() == std::tie(); +} +bool WeightAttrs::operator!=(WeightAttrs const &other) const { + return std::tie() != std::tie(); +} +bool WeightAttrs::operator<(WeightAttrs const &other) const { + return std::tie() < std::tie(); +} +bool WeightAttrs::operator>(WeightAttrs const &other) const { + return std::tie() > std::tie(); +} +bool WeightAttrs::operator<=(WeightAttrs const &other) const { + return std::tie() <= std::tie(); +} +bool WeightAttrs::operator>=(WeightAttrs const &other) const { + return std::tie() >= std::tie(); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::WeightAttrs const &x) const { + size_t result = 0; + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::WeightAttrs + adl_serializer::from_json(json const &j) { + return {}; +} +void adl_serializer::to_json( + json &j, FlexFlow::WeightAttrs const &v) { + j["__type"] = "WeightAttrs"; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct(); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(WeightAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, WeightAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/parallel_dim.dtg.cc b/lib/op-attrs/src/op-attrs/parallel_dim.dtg.cc new file mode 100644 index 0000000000..886893c90a --- /dev/null +++ b/lib/op-attrs/src/op-attrs/parallel_dim.dtg.cc @@ -0,0 +1,116 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/parallel_dim.variant.toml +/* proj-data +{ + "generated_from": "f382ff547aae62777e5091f00d034d84" +} +*/ + +#include "op-attrs/parallel_dim.dtg.h" + +#include "fmt/format.h" +#include +#include + +namespace FlexFlow { +ParallelDim::ParallelDim(::FlexFlow::ShardParallelDim const &v) + : raw_variant(v) {} +ParallelDim::ParallelDim(::FlexFlow::ReplicaParallelDim const &v) + : raw_variant(v) {} +bool ParallelDim::operator==(ParallelDim const &other) const { + return this->raw_variant == other.raw_variant; +} +bool ParallelDim::operator!=(ParallelDim const &other) const { + return this->raw_variant != other.raw_variant; +} +bool ParallelDim::operator<(ParallelDim const &other) const { + return this->raw_variant < other.raw_variant; +} +bool ParallelDim::operator>(ParallelDim const &other) const { + return this->raw_variant > other.raw_variant; +} +bool ParallelDim::operator<=(ParallelDim const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool ParallelDim::operator>=(ParallelDim const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::ParallelDim>::operator()( + ::FlexFlow::ParallelDim const &x) const { + return std::hash>{}( + x.raw_variant); +} +} // namespace std +namespace nlohmann { +::FlexFlow::ParallelDim + adl_serializer<::FlexFlow::ParallelDim>::from_json(json const &j) { + std::string key = j.at("type").template get(); + if (key == "shard_dim") { + return ::FlexFlow::ParallelDim{ + j.at("value").template get<::FlexFlow::ShardParallelDim>()}; + } else if (key == "replica_dim") { + return ::FlexFlow::ParallelDim{ + j.at("value").template get<::FlexFlow::ReplicaParallelDim>()}; + } else { + throw std::runtime_error(fmt::format("Unknown type key {}", key)); + } +} +void adl_serializer<::FlexFlow::ParallelDim>::to_json( + json &j, ::FlexFlow::ParallelDim const &x) { + j["__type"] = "ParallelDim"; + switch (x.index()) { + case 0: { + j["type"] = "shard_dim"; + j["value"] = x.get<::FlexFlow::ShardParallelDim>(); + break; + } + case 1: { + j["type"] = "replica_dim"; + j["value"] = x.get<::FlexFlow::ReplicaParallelDim>(); + break; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type ParallelDim", x.index())); + } + } +} +} // namespace nlohmann +namespace rc { +Gen<::FlexFlow::ParallelDim> Arbitrary<::FlexFlow::ParallelDim>::arbitrary() { + return gen::oneOf(gen::construct<::FlexFlow::ParallelDim>( + gen::arbitrary<::FlexFlow::ShardParallelDim>()), + gen::construct<::FlexFlow::ParallelDim>( + gen::arbitrary<::FlexFlow::ReplicaParallelDim>())); +} +} // namespace rc +namespace FlexFlow { +std::string format_as(::FlexFlow::ParallelDim const &x) { + std::ostringstream oss; + switch (x.index()) { + case 0: { + oss << ""; + break; + } + case 1: { + oss << ""; + break; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type ParallelDim", x.index())); + break; + } + } + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ::FlexFlow::ParallelDim const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc new file mode 100644 index 0000000000..ff5a8224df --- /dev/null +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc @@ -0,0 +1,72 @@ +#include "op-attrs/parallel_tensor_dims.h" +#include "op-attrs/dim_ordered/transform.h" +#include "op-attrs/replica_parallel_dim.h" +#include "op-attrs/replica_parallel_dim_set.h" +#include "op-attrs/shard_parallel_dim.h" +#include "utils/containers.h" +#include "utils/integer_conversions.h" + +namespace FlexFlow { + +FFOrdered ff_ordered_shard_dims(ParallelTensorDims const &d) { + return d.shard_dims; +} + +FFOrdered ff_ordered_shard_degrees(ParallelTensorDims const &d) { + return transform(d.shard_dims, + [](ShardParallelDim const &d) { return d.degree; }); +} + +std::unordered_set + replica_dims(ParallelTensorDims const &d) { + return get_replica_dims(d.replica_dims); +} + +size_t num_shard_dims(ParallelTensorDims const &dims) { + return dims.shard_dims.size(); +} + +int total_replica_degree(ParallelTensorDims const &dims) { + return dims.replica_dims.discard_copy_degree.value * + dims.replica_dims.sum_degree.value; +} + +int total_shard_degree(ParallelTensorDims const &dims) { + return product(transform(as_vector(dims.shard_dims), + [](ShardParallelDim const &d) { return d.degree; })); +} + +int total_parallel_degree(ParallelTensorDims const &dims) { + return total_replica_degree(dims) * total_shard_degree(dims); +} + +bool is_valid(ParallelTensorDims const &dims) { + return all_of(dims.shard_dims, + [](ShardParallelDim const &d) { return is_valid(d); }) && + all_of(replica_dims(dims), + [](ReplicaParallelDim const &d) { return is_valid(d); }); +} + +ShardParallelDim shard_dim_at_idx(ParallelTensorDims const &d, ff_dim_t idx) { + return d.shard_dims.at(idx); +} + +ShardParallelDim &shard_dim_at_idx(ParallelTensorDims &d, ff_dim_t idx) { + return d.shard_dims.at(idx); +} + +TensorDims get_piece_dims(ParallelTensorDims const &) { + NOT_IMPLEMENTED(); +} + +TensorDims get_tensor_dims_unsafe(ParallelTensorDims const &) { + NOT_IMPLEMENTED(); +} + +TensorDims get_reduced_dims(ParallelTensorDims const &dims) { + FFOrdered dim_sizes = transform( + dims.shard_dims, [](ShardParallelDim const &d) { return d.size; }); + return TensorDims{dim_sizes}; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.dtg.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.dtg.cc new file mode 100644 index 0000000000..40be73cb9f --- /dev/null +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.dtg.cc @@ -0,0 +1,101 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml +/* proj-data +{ + "generated_from": "aec3b6b66e34be0d5ce3055822479430" +} +*/ + +#include "op-attrs/parallel_tensor_dims.dtg.h" + +#include "op-attrs/dim_ordered.h" +#include "op-attrs/replica_parallel_dim_set.dtg.h" +#include "op-attrs/shard_parallel_dim.dtg.h" +#include "utils/fmt/pair.h" +#include "utils/fmt/unordered_map.h" +#include +#include + +namespace FlexFlow { +ParallelTensorDims::ParallelTensorDims( + ::FlexFlow::FFOrdered<::FlexFlow::ShardParallelDim> const &shard_dims, + ::FlexFlow::ReplicaParallelDimSet const &replica_dims) + : shard_dims(shard_dims), replica_dims(replica_dims) {} +bool ParallelTensorDims::operator==(ParallelTensorDims const &other) const { + return std::tie(this->shard_dims, this->replica_dims) == + std::tie(other.shard_dims, other.replica_dims); +} +bool ParallelTensorDims::operator!=(ParallelTensorDims const &other) const { + return std::tie(this->shard_dims, this->replica_dims) != + std::tie(other.shard_dims, other.replica_dims); +} +bool ParallelTensorDims::operator<(ParallelTensorDims const &other) const { + return std::tie(this->shard_dims, this->replica_dims) < + std::tie(other.shard_dims, other.replica_dims); +} +bool ParallelTensorDims::operator>(ParallelTensorDims const &other) const { + return std::tie(this->shard_dims, this->replica_dims) > + std::tie(other.shard_dims, other.replica_dims); +} +bool ParallelTensorDims::operator<=(ParallelTensorDims const &other) const { + return std::tie(this->shard_dims, this->replica_dims) <= + std::tie(other.shard_dims, other.replica_dims); +} +bool ParallelTensorDims::operator>=(ParallelTensorDims const &other) const { + return std::tie(this->shard_dims, this->replica_dims) >= + std::tie(other.shard_dims, other.replica_dims); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ParallelTensorDims const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::FFOrdered<::FlexFlow::ShardParallelDim>>{}( + x.shard_dims) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::ReplicaParallelDimSet>{}(x.replica_dims) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ParallelTensorDims + adl_serializer::from_json(json const &j) { + return { + j.at("shard_dims") + .template get<::FlexFlow::FFOrdered<::FlexFlow::ShardParallelDim>>(), + j.at("replica_dims").template get<::FlexFlow::ReplicaParallelDimSet>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ParallelTensorDims const &v) { + j["__type"] = "ParallelTensorDims"; + j["shard_dims"] = v.shard_dims; + j["replica_dims"] = v.replica_dims; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::FFOrdered<::FlexFlow::ShardParallelDim>>(), + gen::arbitrary<::FlexFlow::ReplicaParallelDimSet>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(ParallelTensorDims const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ParallelTensorDims const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc new file mode 100644 index 0000000000..516cbe191f --- /dev/null +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc @@ -0,0 +1,87 @@ +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_dims.h" +#include "utils/containers.h" +#include "utils/hash-utils.h" + +namespace FlexFlow { + +int num_shard_dims(ParallelTensorShape const &s) { + return num_shard_dims(s.dims); +} + +std::unordered_set + replica_dims(ParallelTensorShape const &s) { + return replica_dims(s.dims); +} + +int get_num_replicas(ParallelTensorShape const &shape) { + return product( + transform(replica_dims(shape), + [](ReplicaParallelDim const &d) -> int { return d.degree; })); +} + +int get_sum_degree(ParallelTensorShape const &shape) { + return shape.dims.replica_dims.sum_degree.value; +} + +int get_discard_copy_degree(ParallelTensorShape const &shape) { + return shape.dims.replica_dims.discard_copy_degree.value; +} + +int get_total_parallel_degree(ParallelTensorShape const &s) { + return total_parallel_degree(s.dims); +} + +bool is_valid(ParallelTensorShape const &shape) { + return is_valid(shape.dims); +} + +ShardParallelDim shard_dim_at_idx(ParallelTensorShape const &s, ff_dim_t d) { + return shard_dim_at_idx(s.dims, d); +} + +ShardParallelDim &shard_dim_at_idx(ParallelTensorShape &s, ff_dim_t d) { + return shard_dim_at_idx(s.dims, d); +} + +FFOrdered ff_ordered_shard_degrees(ParallelTensorShape const &s) { + return ff_ordered_shard_degrees(s.dims); +} + +std::optional + try_get_shard_dim_at_idx(ParallelTensorShape const &s, ff_dim_t d) { + if (s.dims.shard_dims.idx_is_valid(d)) { + return s.dims.shard_dims.at(d); + } else { + return std::nullopt; + } +} + +ParallelTensorShape lift_to_parallel(TensorShape const &s) { + return {lift_to_parallel(s.dims), s.data_type}; +} + +ParallelTensorShape + lift_to_parallel_with_degrees(TensorShape const &s, + SumDegree sum_degree, + DiscardCopyDegree discard_copy_degree, + FFOrdered const &shard_degrees) { + return ParallelTensorShape{ + lift_to_parallel_with_degrees( + s.dims, sum_degree, discard_copy_degree, shard_degrees), + s.data_type, + }; +} + +TensorShape get_tensor_shape_unsafe(ParallelTensorShape const &) { + NOT_IMPLEMENTED(); +} + +TensorShape get_reduced_shape(ParallelTensorShape const &s) { + return TensorShape{ + get_reduced_dims(s.dims), + s.data_type, + }; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.dtg.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.dtg.cc new file mode 100644 index 0000000000..1fe82ce108 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.dtg.cc @@ -0,0 +1,94 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml +/* proj-data +{ + "generated_from": "06d657d1e95f34aebf4b721c768cbee8" +} +*/ + +#include "op-attrs/parallel_tensor_shape.dtg.h" + +#include "op-attrs/datatype.h" +#include "op-attrs/parallel_tensor_dims.h" +#include + +namespace FlexFlow { +ParallelTensorShape::ParallelTensorShape( + ::FlexFlow::ParallelTensorDims const &dims, + ::FlexFlow::DataType const &data_type) + : dims(dims), data_type(data_type) {} +bool ParallelTensorShape::operator==(ParallelTensorShape const &other) const { + return std::tie(this->dims, this->data_type) == + std::tie(other.dims, other.data_type); +} +bool ParallelTensorShape::operator!=(ParallelTensorShape const &other) const { + return std::tie(this->dims, this->data_type) != + std::tie(other.dims, other.data_type); +} +bool ParallelTensorShape::operator<(ParallelTensorShape const &other) const { + return std::tie(this->dims, this->data_type) < + std::tie(other.dims, other.data_type); +} +bool ParallelTensorShape::operator>(ParallelTensorShape const &other) const { + return std::tie(this->dims, this->data_type) > + std::tie(other.dims, other.data_type); +} +bool ParallelTensorShape::operator<=(ParallelTensorShape const &other) const { + return std::tie(this->dims, this->data_type) <= + std::tie(other.dims, other.data_type); +} +bool ParallelTensorShape::operator>=(ParallelTensorShape const &other) const { + return std::tie(this->dims, this->data_type) >= + std::tie(other.dims, other.data_type); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ParallelTensorShape const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ParallelTensorDims>{}(x.dims) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::DataType>{}(x.data_type) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ParallelTensorShape + adl_serializer::from_json(json const &j) { + return {j.at("dims").template get<::FlexFlow::ParallelTensorDims>(), + j.at("data_type").template get<::FlexFlow::DataType>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ParallelTensorShape const &v) { + j["__type"] = "ParallelTensorShape"; + j["dims"] = v.dims; + j["data_type"] = v.data_type; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::ParallelTensorDims>(), + gen::arbitrary<::FlexFlow::DataType>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(ParallelTensorShape const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ParallelTensorShape const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.cc new file mode 100644 index 0000000000..4547a5df9b --- /dev/null +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.cc @@ -0,0 +1,76 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/parallel_tensor_shape/discard_copy_degree.struct.toml +/* proj-data +{ + "generated_from": "e4677d1fb25d3833570ee567f5659914" +} +*/ + +#include "op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h" + +#include + +namespace FlexFlow { +DiscardCopyDegree::DiscardCopyDegree(int const &value) : value(value) {} +bool DiscardCopyDegree::operator==(DiscardCopyDegree const &other) const { + return std::tie(this->value) == std::tie(other.value); +} +bool DiscardCopyDegree::operator!=(DiscardCopyDegree const &other) const { + return std::tie(this->value) != std::tie(other.value); +} +bool DiscardCopyDegree::operator<(DiscardCopyDegree const &other) const { + return std::tie(this->value) < std::tie(other.value); +} +bool DiscardCopyDegree::operator>(DiscardCopyDegree const &other) const { + return std::tie(this->value) > std::tie(other.value); +} +bool DiscardCopyDegree::operator<=(DiscardCopyDegree const &other) const { + return std::tie(this->value) <= std::tie(other.value); +} +bool DiscardCopyDegree::operator>=(DiscardCopyDegree const &other) const { + return std::tie(this->value) >= std::tie(other.value); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::DiscardCopyDegree const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.value) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::DiscardCopyDegree + adl_serializer::from_json(json const &j) { + return {j.at("value").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::DiscardCopyDegree const &v) { + j["__type"] = "DiscardCopyDegree"; + j["value"] = v.value; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(DiscardCopyDegree const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, DiscardCopyDegree const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape/sum_degree.dtg.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape/sum_degree.dtg.cc new file mode 100644 index 0000000000..cf159a1ea7 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape/sum_degree.dtg.cc @@ -0,0 +1,75 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/parallel_tensor_shape/sum_degree.struct.toml +/* proj-data +{ + "generated_from": "e94a05618f2ad92dd7b3328a1d9c6786" +} +*/ + +#include "op-attrs/parallel_tensor_shape/sum_degree.dtg.h" + +#include + +namespace FlexFlow { +SumDegree::SumDegree(int const &value) : value(value) {} +bool SumDegree::operator==(SumDegree const &other) const { + return std::tie(this->value) == std::tie(other.value); +} +bool SumDegree::operator!=(SumDegree const &other) const { + return std::tie(this->value) != std::tie(other.value); +} +bool SumDegree::operator<(SumDegree const &other) const { + return std::tie(this->value) < std::tie(other.value); +} +bool SumDegree::operator>(SumDegree const &other) const { + return std::tie(this->value) > std::tie(other.value); +} +bool SumDegree::operator<=(SumDegree const &other) const { + return std::tie(this->value) <= std::tie(other.value); +} +bool SumDegree::operator>=(SumDegree const &other) const { + return std::tie(this->value) >= std::tie(other.value); +} +} // namespace FlexFlow + +namespace std { +size_t + hash::operator()(FlexFlow::SumDegree const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.value) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::SumDegree + adl_serializer::from_json(json const &j) { + return {j.at("value").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::SumDegree const &v) { + j["__type"] = "SumDegree"; + j["value"] = v.value; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(SumDegree const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, SumDegree const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/param_sync.dtg.cc b/lib/op-attrs/src/op-attrs/param_sync.dtg.cc new file mode 100644 index 0000000000..e0d13fdd2e --- /dev/null +++ b/lib/op-attrs/src/op-attrs/param_sync.dtg.cc @@ -0,0 +1,70 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/param_sync.enum.toml +/* proj-data +{ + "generated_from": "288c6e9e256cf58ba5dbd0e3791c08df" +} +*/ + +#include "op-attrs/param_sync.dtg.h" + +#include +#include + +namespace std { +size_t hash::operator()(FlexFlow::ParamSync x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(ParamSync x) { + switch (x) { + case ParamSync::PS: + return "PS"; + case ParamSync::NCCL: + return "NCCL"; + default: + std::ostringstream oss; + oss << "Unknown ParamSync value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, ParamSync x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, ParamSync x) { + switch (x) { + case ParamSync::PS: + j = "PS"; + break; + case ParamSync::NCCL: + j = "NCCL"; + break; + default: + std::ostringstream oss; + oss << "Unknown ParamSync value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, ParamSync &x) { + std::string as_str = j.get(); + if (as_str == "PS") { + x = ParamSync::PS; + } else if (as_str == "NCCL") { + x = ParamSync::NCCL; + } else { + std::ostringstream oss; + oss << "Unknown ParamSync value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::element(FlexFlow::ParamSync::PS, + FlexFlow::ParamSync::NCCL); +} +} // namespace rc diff --git a/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc b/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc new file mode 100644 index 0000000000..76ad48d471 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc @@ -0,0 +1,16 @@ +#include "op-attrs/pcg_operator_attrs.h" +#include "op-attrs/get_op_type.h" + +namespace FlexFlow { + +bool is_parallel_op(PCGOperatorAttrs const &attrs) { + return (attrs.has() || attrs.has() || + attrs.has() || attrs.has()); +} + +OperatorType get_op_type(PCGOperatorAttrs const &attrs) { + return attrs.visit( + [](auto const &x) { return get_op_type(x); }); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/pcg_operator_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/pcg_operator_attrs.dtg.cc new file mode 100644 index 0000000000..5334c8a7ab --- /dev/null +++ b/lib/op-attrs/src/op-attrs/pcg_operator_attrs.dtg.cc @@ -0,0 +1,599 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml +/* proj-data +{ + "generated_from": "e1d10b0c7c98524c27886bdae0972321" +} +*/ + +#include "op-attrs/pcg_operator_attrs.dtg.h" + +#include "fmt/format.h" +#include +#include + +namespace FlexFlow { +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::BatchMatmulAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::BatchNormAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::CastAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::CombineAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::ConcatAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::Conv2DAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::DropoutAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::ElementBinaryAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::ElementUnaryAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::ElementScalarUnaryAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::EmbeddingAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::FlatAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::GatherAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::InputAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::LayerNormAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::LinearAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::MultiHeadAttentionAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::NoopAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::Pool2DAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::ReduceAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::ReductionAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::RepartitionAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::ReplicateAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::ReverseAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::ReshapeAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::SplitAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::SoftmaxAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::TopKAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::TransposeAttrs const &v) + : raw_variant(v) {} +bool PCGOperatorAttrs::operator==(PCGOperatorAttrs const &other) const { + return this->raw_variant == other.raw_variant; +} +bool PCGOperatorAttrs::operator!=(PCGOperatorAttrs const &other) const { + return this->raw_variant != other.raw_variant; +} +bool PCGOperatorAttrs::operator<(PCGOperatorAttrs const &other) const { + return this->raw_variant < other.raw_variant; +} +bool PCGOperatorAttrs::operator>(PCGOperatorAttrs const &other) const { + return this->raw_variant > other.raw_variant; +} +bool PCGOperatorAttrs::operator<=(PCGOperatorAttrs const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool PCGOperatorAttrs::operator>=(PCGOperatorAttrs const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::PCGOperatorAttrs>::operator()( + ::FlexFlow::PCGOperatorAttrs const &x) const { + return std::hash>{}(x.raw_variant); +} +} // namespace std +namespace nlohmann { +::FlexFlow::PCGOperatorAttrs + adl_serializer<::FlexFlow::PCGOperatorAttrs>::from_json(json const &j) { + std::string key = j.at("type").template get(); + if (key == "batch_matmul") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::BatchMatmulAttrs>()}; + } else if (key == "batch_norm") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::BatchNormAttrs>()}; + } else if (key == "cast") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::CastAttrs>()}; + } else if (key == "combine_distributed") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::CombineAttrs>()}; + } else if (key == "concat") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::ConcatAttrs>()}; + } else if (key == "conv2d") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::Conv2DAttrs>()}; + } else if (key == "dropout") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::DropoutAttrs>()}; + } else if (key == "element_binary") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::ElementBinaryAttrs>()}; + } else if (key == "element_unary") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::ElementUnaryAttrs>()}; + } else if (key == "element_scalar_unary") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::ElementScalarUnaryAttrs>()}; + } else if (key == "embedding") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::EmbeddingAttrs>()}; + } else if (key == "flat") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::FlatAttrs>()}; + } else if (key == "gather") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::GatherAttrs>()}; + } else if (key == "input") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::InputAttrs>()}; + } else if (key == "layer_norm") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::LayerNormAttrs>()}; + } else if (key == "linear") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::LinearAttrs>()}; + } else if (key == "multi_head_attention") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::MultiHeadAttentionAttrs>()}; + } else if (key == "noop") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::NoopAttrs>()}; + } else if (key == "pool2d") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::Pool2DAttrs>()}; + } else if (key == "reduce") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::ReduceAttrs>()}; + } else if (key == "reduce_distributed") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::ReductionAttrs>()}; + } else if (key == "partition_distributed") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::RepartitionAttrs>()}; + } else if (key == "replicate_distributed") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::ReplicateAttrs>()}; + } else if (key == "reverse") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::ReverseAttrs>()}; + } else if (key == "reshape") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::ReshapeAttrs>()}; + } else if (key == "split") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::SplitAttrs>()}; + } else if (key == "softmax") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::SoftmaxAttrs>()}; + } else if (key == "topk") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::TopKAttrs>()}; + } else if (key == "transpose") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::TransposeAttrs>()}; + } else { + throw std::runtime_error(fmt::format("Unknown type key {}", key)); + } +} +void adl_serializer<::FlexFlow::PCGOperatorAttrs>::to_json( + json &j, ::FlexFlow::PCGOperatorAttrs const &x) { + j["__type"] = "PCGOperatorAttrs"; + switch (x.index()) { + case 0: { + j["type"] = "batch_matmul"; + j["value"] = x.get<::FlexFlow::BatchMatmulAttrs>(); + break; + } + case 1: { + j["type"] = "batch_norm"; + j["value"] = x.get<::FlexFlow::BatchNormAttrs>(); + break; + } + case 2: { + j["type"] = "cast"; + j["value"] = x.get<::FlexFlow::CastAttrs>(); + break; + } + case 3: { + j["type"] = "combine_distributed"; + j["value"] = x.get<::FlexFlow::CombineAttrs>(); + break; + } + case 4: { + j["type"] = "concat"; + j["value"] = x.get<::FlexFlow::ConcatAttrs>(); + break; + } + case 5: { + j["type"] = "conv2d"; + j["value"] = x.get<::FlexFlow::Conv2DAttrs>(); + break; + } + case 6: { + j["type"] = "dropout"; + j["value"] = x.get<::FlexFlow::DropoutAttrs>(); + break; + } + case 7: { + j["type"] = "element_binary"; + j["value"] = x.get<::FlexFlow::ElementBinaryAttrs>(); + break; + } + case 8: { + j["type"] = "element_unary"; + j["value"] = x.get<::FlexFlow::ElementUnaryAttrs>(); + break; + } + case 9: { + j["type"] = "element_scalar_unary"; + j["value"] = x.get<::FlexFlow::ElementScalarUnaryAttrs>(); + break; + } + case 10: { + j["type"] = "embedding"; + j["value"] = x.get<::FlexFlow::EmbeddingAttrs>(); + break; + } + case 11: { + j["type"] = "flat"; + j["value"] = x.get<::FlexFlow::FlatAttrs>(); + break; + } + case 12: { + j["type"] = "gather"; + j["value"] = x.get<::FlexFlow::GatherAttrs>(); + break; + } + case 13: { + j["type"] = "input"; + j["value"] = x.get<::FlexFlow::InputAttrs>(); + break; + } + case 14: { + j["type"] = "layer_norm"; + j["value"] = x.get<::FlexFlow::LayerNormAttrs>(); + break; + } + case 15: { + j["type"] = "linear"; + j["value"] = x.get<::FlexFlow::LinearAttrs>(); + break; + } + case 16: { + j["type"] = "multi_head_attention"; + j["value"] = x.get<::FlexFlow::MultiHeadAttentionAttrs>(); + break; + } + case 17: { + j["type"] = "noop"; + j["value"] = x.get<::FlexFlow::NoopAttrs>(); + break; + } + case 18: { + j["type"] = "pool2d"; + j["value"] = x.get<::FlexFlow::Pool2DAttrs>(); + break; + } + case 19: { + j["type"] = "reduce"; + j["value"] = x.get<::FlexFlow::ReduceAttrs>(); + break; + } + case 20: { + j["type"] = "reduce_distributed"; + j["value"] = x.get<::FlexFlow::ReductionAttrs>(); + break; + } + case 21: { + j["type"] = "partition_distributed"; + j["value"] = x.get<::FlexFlow::RepartitionAttrs>(); + break; + } + case 22: { + j["type"] = "replicate_distributed"; + j["value"] = x.get<::FlexFlow::ReplicateAttrs>(); + break; + } + case 23: { + j["type"] = "reverse"; + j["value"] = x.get<::FlexFlow::ReverseAttrs>(); + break; + } + case 24: { + j["type"] = "reshape"; + j["value"] = x.get<::FlexFlow::ReshapeAttrs>(); + break; + } + case 25: { + j["type"] = "split"; + j["value"] = x.get<::FlexFlow::SplitAttrs>(); + break; + } + case 26: { + j["type"] = "softmax"; + j["value"] = x.get<::FlexFlow::SoftmaxAttrs>(); + break; + } + case 27: { + j["type"] = "topk"; + j["value"] = x.get<::FlexFlow::TopKAttrs>(); + break; + } + case 28: { + j["type"] = "transpose"; + j["value"] = x.get<::FlexFlow::TransposeAttrs>(); + break; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type PCGOperatorAttrs", x.index())); + } + } +} +} // namespace nlohmann +namespace rc { +Gen<::FlexFlow::PCGOperatorAttrs> + Arbitrary<::FlexFlow::PCGOperatorAttrs>::arbitrary() { + return gen::oneOf(gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::BatchMatmulAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::BatchNormAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::CastAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::CombineAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::ConcatAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::Conv2DAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::DropoutAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::ElementBinaryAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::ElementUnaryAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::ElementScalarUnaryAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::EmbeddingAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::FlatAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::GatherAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::InputAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::LayerNormAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::LinearAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::MultiHeadAttentionAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::NoopAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::Pool2DAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::ReduceAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::ReductionAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::RepartitionAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::ReplicateAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::ReverseAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::ReshapeAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::SplitAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::SoftmaxAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::TopKAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::TransposeAttrs>())); +} +} // namespace rc +namespace FlexFlow { +std::string format_as(::FlexFlow::PCGOperatorAttrs const &x) { + std::ostringstream oss; + switch (x.index()) { + case 0: { + oss << ""; + break; + } + case 1: { + oss << ""; + break; + } + case 2: { + oss << ""; + break; + } + case 3: { + oss << ""; + break; + } + case 4: { + oss << ""; + break; + } + case 5: { + oss << ""; + break; + } + case 6: { + oss << ""; + break; + } + case 7: { + oss << ""; + break; + } + case 8: { + oss << ""; + break; + } + case 9: { + oss << ""; + break; + } + case 10: { + oss << ""; + break; + } + case 11: { + oss << ""; + break; + } + case 12: { + oss << ""; + break; + } + case 13: { + oss << ""; + break; + } + case 14: { + oss << ""; + break; + } + case 15: { + oss << ""; + break; + } + case 16: { + oss << ""; + break; + } + case 17: { + oss << ""; + break; + } + case 18: { + oss << ""; + break; + } + case 19: { + oss << ""; + break; + } + case 20: { + oss << ""; + break; + } + case 21: { + oss << ""; + break; + } + case 22: { + oss << ""; + break; + } + case 23: { + oss << ""; + break; + } + case 24: { + oss << ""; + break; + } + case 25: { + oss << ""; + break; + } + case 26: { + oss << ""; + break; + } + case 27: { + oss << ""; + break; + } + case 28: { + oss << ""; + break; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type PCGOperatorAttrs", x.index())); + break; + } + } + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::PCGOperatorAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/pool_op.dtg.cc b/lib/op-attrs/src/op-attrs/pool_op.dtg.cc new file mode 100644 index 0000000000..08a6f43943 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/pool_op.dtg.cc @@ -0,0 +1,70 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/pool_op.enum.toml +/* proj-data +{ + "generated_from": "ed1d531c6227306c909eb28eb0a66538" +} +*/ + +#include "op-attrs/pool_op.dtg.h" + +#include +#include + +namespace std { +size_t hash::operator()(FlexFlow::PoolOp x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(PoolOp x) { + switch (x) { + case PoolOp::MAX: + return "MAX"; + case PoolOp::AVG: + return "AVG"; + default: + std::ostringstream oss; + oss << "Unknown PoolOp value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, PoolOp x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, PoolOp x) { + switch (x) { + case PoolOp::MAX: + j = "MAX"; + break; + case PoolOp::AVG: + j = "AVG"; + break; + default: + std::ostringstream oss; + oss << "Unknown PoolOp value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, PoolOp &x) { + std::string as_str = j.get(); + if (as_str == "MAX") { + x = PoolOp::MAX; + } else if (as_str == "AVG") { + x = PoolOp::AVG; + } else { + std::ostringstream oss; + oss << "Unknown PoolOp value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::element(FlexFlow::PoolOp::MAX, + FlexFlow::PoolOp::AVG); +} +} // namespace rc diff --git a/lib/op-attrs/src/op-attrs/regularizer_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/regularizer_attrs.dtg.cc new file mode 100644 index 0000000000..d1f844ab10 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/regularizer_attrs.dtg.cc @@ -0,0 +1,118 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/regularizer_attrs.variant.toml +/* proj-data +{ + "generated_from": "ea060a8ab344c9772102f084903883ea" +} +*/ + +#include "op-attrs/regularizer_attrs.dtg.h" + +#include "fmt/format.h" +#include +#include + +namespace FlexFlow { +RegularizerAttrs::RegularizerAttrs(::FlexFlow::L1RegularizerAttrs const &v) + : raw_variant(v) {} +RegularizerAttrs::RegularizerAttrs(::FlexFlow::L2RegularizerAttrs const &v) + : raw_variant(v) {} +bool RegularizerAttrs::operator==(RegularizerAttrs const &other) const { + return this->raw_variant == other.raw_variant; +} +bool RegularizerAttrs::operator!=(RegularizerAttrs const &other) const { + return this->raw_variant != other.raw_variant; +} +bool RegularizerAttrs::operator<(RegularizerAttrs const &other) const { + return this->raw_variant < other.raw_variant; +} +bool RegularizerAttrs::operator>(RegularizerAttrs const &other) const { + return this->raw_variant > other.raw_variant; +} +bool RegularizerAttrs::operator<=(RegularizerAttrs const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool RegularizerAttrs::operator>=(RegularizerAttrs const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::RegularizerAttrs>::operator()( + ::FlexFlow::RegularizerAttrs const &x) const { + return std::hash>{}( + x.raw_variant); +} +} // namespace std +namespace nlohmann { +::FlexFlow::RegularizerAttrs + adl_serializer<::FlexFlow::RegularizerAttrs>::from_json(json const &j) { + std::string key = j.at("type").template get(); + if (key == "l1") { + return ::FlexFlow::RegularizerAttrs{ + j.at("value").template get<::FlexFlow::L1RegularizerAttrs>()}; + } else if (key == "l2") { + return ::FlexFlow::RegularizerAttrs{ + j.at("value").template get<::FlexFlow::L2RegularizerAttrs>()}; + } else { + throw std::runtime_error(fmt::format("Unknown type key {}", key)); + } +} +void adl_serializer<::FlexFlow::RegularizerAttrs>::to_json( + json &j, ::FlexFlow::RegularizerAttrs const &x) { + j["__type"] = "RegularizerAttrs"; + switch (x.index()) { + case 0: { + j["type"] = "l1"; + j["value"] = x.get<::FlexFlow::L1RegularizerAttrs>(); + break; + } + case 1: { + j["type"] = "l2"; + j["value"] = x.get<::FlexFlow::L2RegularizerAttrs>(); + break; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type RegularizerAttrs", x.index())); + } + } +} +} // namespace nlohmann +namespace rc { +Gen<::FlexFlow::RegularizerAttrs> + Arbitrary<::FlexFlow::RegularizerAttrs>::arbitrary() { + return gen::oneOf(gen::construct<::FlexFlow::RegularizerAttrs>( + gen::arbitrary<::FlexFlow::L1RegularizerAttrs>()), + gen::construct<::FlexFlow::RegularizerAttrs>( + gen::arbitrary<::FlexFlow::L2RegularizerAttrs>())); +} +} // namespace rc +namespace FlexFlow { +std::string format_as(::FlexFlow::RegularizerAttrs const &x) { + std::ostringstream oss; + switch (x.index()) { + case 0: { + oss << ""; + break; + } + case 1: { + oss << ""; + break; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type RegularizerAttrs", x.index())); + break; + } + } + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::RegularizerAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/replica_parallel_dim.cc b/lib/op-attrs/src/op-attrs/replica_parallel_dim.cc new file mode 100644 index 0000000000..44b17c8b44 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/replica_parallel_dim.cc @@ -0,0 +1,9 @@ +#include "op-attrs/replica_parallel_dim.h" + +namespace FlexFlow { + +bool is_valid(ReplicaParallelDim const &d) { + return d.degree > 0; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/replica_parallel_dim.dtg.cc b/lib/op-attrs/src/op-attrs/replica_parallel_dim.dtg.cc new file mode 100644 index 0000000000..a1256ad79a --- /dev/null +++ b/lib/op-attrs/src/op-attrs/replica_parallel_dim.dtg.cc @@ -0,0 +1,91 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/replica_parallel_dim.struct.toml +/* proj-data +{ + "generated_from": "f501393070c8d55a05c43dd73a81a8d7" +} +*/ + +#include "op-attrs/replica_parallel_dim.dtg.h" + +#include "op-attrs/replica_type.dtg.h" +#include + +namespace FlexFlow { +ReplicaParallelDim::ReplicaParallelDim( + int const °ree, ::FlexFlow::ReplicaType const &replica_type) + : degree(degree), replica_type(replica_type) {} +bool ReplicaParallelDim::operator==(ReplicaParallelDim const &other) const { + return std::tie(this->degree, this->replica_type) == + std::tie(other.degree, other.replica_type); +} +bool ReplicaParallelDim::operator!=(ReplicaParallelDim const &other) const { + return std::tie(this->degree, this->replica_type) != + std::tie(other.degree, other.replica_type); +} +bool ReplicaParallelDim::operator<(ReplicaParallelDim const &other) const { + return std::tie(this->degree, this->replica_type) < + std::tie(other.degree, other.replica_type); +} +bool ReplicaParallelDim::operator>(ReplicaParallelDim const &other) const { + return std::tie(this->degree, this->replica_type) > + std::tie(other.degree, other.replica_type); +} +bool ReplicaParallelDim::operator<=(ReplicaParallelDim const &other) const { + return std::tie(this->degree, this->replica_type) <= + std::tie(other.degree, other.replica_type); +} +bool ReplicaParallelDim::operator>=(ReplicaParallelDim const &other) const { + return std::tie(this->degree, this->replica_type) >= + std::tie(other.degree, other.replica_type); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ReplicaParallelDim const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.degree) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::ReplicaType>{}(x.replica_type) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ReplicaParallelDim + adl_serializer::from_json(json const &j) { + return {j.at("degree").template get(), + j.at("replica_type").template get<::FlexFlow::ReplicaType>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ReplicaParallelDim const &v) { + j["__type"] = "ReplicaParallelDim"; + j["degree"] = v.degree; + j["replica_type"] = v.replica_type; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary(), gen::arbitrary<::FlexFlow::ReplicaType>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(ReplicaParallelDim const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ReplicaParallelDim const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.cc b/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.cc new file mode 100644 index 0000000000..7ef228e97e --- /dev/null +++ b/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.cc @@ -0,0 +1,36 @@ +#include "op-attrs/replica_parallel_dim_set.h" +#include "utils/exception.h" + +namespace FlexFlow { + +ReplicaParallelDimSet empty_replica_parallel_dim_set() { + return ReplicaParallelDimSet{1, 1}; +} + +int get_order_of_replica_type(ReplicaParallelDimSet const &s, + ReplicaType replica_type) { + switch (replica_type) { + case ReplicaType::SUM: + return s.sum_degree.value; + case ReplicaType::DISCARD_COPY: + return s.discard_copy_degree.value; + default: + throw mk_runtime_error(fmt::format("Unexpected ReplicaType value: {}", + static_cast(replica_type))); + } +} + +std::unordered_set + get_replica_dims(ReplicaParallelDimSet const &s) { + return std::unordered_set{ + ReplicaParallelDim{s.sum_degree.value, ReplicaType::SUM}, + ReplicaParallelDim{s.discard_copy_degree.value, + ReplicaType::DISCARD_COPY}, + }; +} + +bool is_valid(ReplicaParallelDimSet const &s) { + return s.sum_degree.value > 0 && s.discard_copy_degree.value > 0; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.dtg.cc b/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.dtg.cc new file mode 100644 index 0000000000..f8782be01b --- /dev/null +++ b/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.dtg.cc @@ -0,0 +1,101 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/replica_parallel_dim_set.struct.toml +/* proj-data +{ + "generated_from": "74230e2d18db5c059d3e7be0f25e746e" +} +*/ + +#include "op-attrs/replica_parallel_dim_set.dtg.h" + +#include "op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h" +#include "op-attrs/parallel_tensor_shape/sum_degree.dtg.h" +#include + +namespace FlexFlow { +ReplicaParallelDimSet::ReplicaParallelDimSet( + ::FlexFlow::SumDegree const &sum_degree, + ::FlexFlow::DiscardCopyDegree const &discard_copy_degree) + : sum_degree(sum_degree), discard_copy_degree(discard_copy_degree) {} +bool ReplicaParallelDimSet::operator==( + ReplicaParallelDimSet const &other) const { + return std::tie(this->sum_degree, this->discard_copy_degree) == + std::tie(other.sum_degree, other.discard_copy_degree); +} +bool ReplicaParallelDimSet::operator!=( + ReplicaParallelDimSet const &other) const { + return std::tie(this->sum_degree, this->discard_copy_degree) != + std::tie(other.sum_degree, other.discard_copy_degree); +} +bool ReplicaParallelDimSet::operator<( + ReplicaParallelDimSet const &other) const { + return std::tie(this->sum_degree, this->discard_copy_degree) < + std::tie(other.sum_degree, other.discard_copy_degree); +} +bool ReplicaParallelDimSet::operator>( + ReplicaParallelDimSet const &other) const { + return std::tie(this->sum_degree, this->discard_copy_degree) > + std::tie(other.sum_degree, other.discard_copy_degree); +} +bool ReplicaParallelDimSet::operator<=( + ReplicaParallelDimSet const &other) const { + return std::tie(this->sum_degree, this->discard_copy_degree) <= + std::tie(other.sum_degree, other.discard_copy_degree); +} +bool ReplicaParallelDimSet::operator>=( + ReplicaParallelDimSet const &other) const { + return std::tie(this->sum_degree, this->discard_copy_degree) >= + std::tie(other.sum_degree, other.discard_copy_degree); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ReplicaParallelDimSet const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::SumDegree>{}(x.sum_degree) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::DiscardCopyDegree>{}(x.discard_copy_degree) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ReplicaParallelDimSet + adl_serializer::from_json(json const &j) { + return {j.at("sum_degree").template get<::FlexFlow::SumDegree>(), + j.at("discard_copy_degree") + .template get<::FlexFlow::DiscardCopyDegree>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ReplicaParallelDimSet const &v) { + j["__type"] = "ReplicaParallelDimSet"; + j["sum_degree"] = v.sum_degree; + j["discard_copy_degree"] = v.discard_copy_degree; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::SumDegree>(), + gen::arbitrary<::FlexFlow::DiscardCopyDegree>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(ReplicaParallelDimSet const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ReplicaParallelDimSet const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/replica_type.dtg.cc b/lib/op-attrs/src/op-attrs/replica_type.dtg.cc new file mode 100644 index 0000000000..d0410c49e2 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/replica_type.dtg.cc @@ -0,0 +1,70 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/replica_type.enum.toml +/* proj-data +{ + "generated_from": "6ecba7a6851b8bea93705bba24661149" +} +*/ + +#include "op-attrs/replica_type.dtg.h" + +#include +#include + +namespace std { +size_t hash::operator()(FlexFlow::ReplicaType x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(ReplicaType x) { + switch (x) { + case ReplicaType::SUM: + return "SUM"; + case ReplicaType::DISCARD_COPY: + return "DISCARD_COPY"; + default: + std::ostringstream oss; + oss << "Unknown ReplicaType value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, ReplicaType x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, ReplicaType x) { + switch (x) { + case ReplicaType::SUM: + j = "SUM"; + break; + case ReplicaType::DISCARD_COPY: + j = "DISCARD_COPY"; + break; + default: + std::ostringstream oss; + oss << "Unknown ReplicaType value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, ReplicaType &x) { + std::string as_str = j.get(); + if (as_str == "SUM") { + x = ReplicaType::SUM; + } else if (as_str == "DISCARD_COPY") { + x = ReplicaType::DISCARD_COPY; + } else { + std::ostringstream oss; + oss << "Unknown ReplicaType value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::element( + FlexFlow::ReplicaType::SUM, FlexFlow::ReplicaType::DISCARD_COPY); +} +} // namespace rc diff --git a/lib/op-attrs/src/op-attrs/shard_parallel_dim.cc b/lib/op-attrs/src/op-attrs/shard_parallel_dim.cc new file mode 100644 index 0000000000..d27a857723 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/shard_parallel_dim.cc @@ -0,0 +1,9 @@ +#include "op-attrs/shard_parallel_dim.h" + +namespace FlexFlow { + +bool is_valid(ShardParallelDim const &d) { + return d.degree > 0 && d.size > 0 && (d.size % d.degree) == 0; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/shard_parallel_dim.dtg.cc b/lib/op-attrs/src/op-attrs/shard_parallel_dim.dtg.cc new file mode 100644 index 0000000000..9566eb486b --- /dev/null +++ b/lib/op-attrs/src/op-attrs/shard_parallel_dim.dtg.cc @@ -0,0 +1,89 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/shard_parallel_dim.struct.toml +/* proj-data +{ + "generated_from": "18e074f80556d90b9b27d6515bbf9071" +} +*/ + +#include "op-attrs/shard_parallel_dim.dtg.h" + +#include + +namespace FlexFlow { +ShardParallelDim::ShardParallelDim(size_t const &size, int const °ree) + : size(size), degree(degree) {} +bool ShardParallelDim::operator==(ShardParallelDim const &other) const { + return std::tie(this->size, this->degree) == + std::tie(other.size, other.degree); +} +bool ShardParallelDim::operator!=(ShardParallelDim const &other) const { + return std::tie(this->size, this->degree) != + std::tie(other.size, other.degree); +} +bool ShardParallelDim::operator<(ShardParallelDim const &other) const { + return std::tie(this->size, this->degree) < + std::tie(other.size, other.degree); +} +bool ShardParallelDim::operator>(ShardParallelDim const &other) const { + return std::tie(this->size, this->degree) > + std::tie(other.size, other.degree); +} +bool ShardParallelDim::operator<=(ShardParallelDim const &other) const { + return std::tie(this->size, this->degree) <= + std::tie(other.size, other.degree); +} +bool ShardParallelDim::operator>=(ShardParallelDim const &other) const { + return std::tie(this->size, this->degree) >= + std::tie(other.size, other.degree); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ShardParallelDim const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.size) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.degree) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ShardParallelDim + adl_serializer::from_json(json const &j) { + return {j.at("size").template get(), + j.at("degree").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ShardParallelDim const &v) { + j["__type"] = "ShardParallelDim"; + j["size"] = v.size; + j["degree"] = v.degree; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary(), + gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(ShardParallelDim const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ShardParallelDim const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/tensor_dims.cc b/lib/op-attrs/src/op-attrs/tensor_dims.cc new file mode 100644 index 0000000000..ed40f509d9 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/tensor_dims.cc @@ -0,0 +1,52 @@ +#include "op-attrs/tensor_dims.h" +#include "op-attrs/replica_parallel_dim_set.h" +#include "op-attrs/shard_parallel_dim.dtg.h" +#include "utils/containers.h" +#include "utils/containers/zip_vectors.h" +#include "utils/integer_conversions.h" + +namespace FlexFlow { + +FFOrdered const &ff_ordered(TensorDims const &dims) { + return dims.ff_ordered; +} + +size_t num_dims(TensorDims const &dims) { + return dims.ff_ordered.size(); +} + +size_t dim_at_idx(TensorDims const &dims, ff_dim_t idx) { + return dims.ff_ordered.at(idx); +} + +size_t &dim_at_idx(TensorDims &dims, ff_dim_t idx) { + return dims.ff_ordered.at(idx); +} + +ParallelTensorDims lift_to_parallel(TensorDims const &dims) { + std::vector shard_degrees(num_dims(dims), + 1); // 1 repeated num_dims(dims) times + return lift_to_parallel_with_degrees(dims, 1, 1, shard_degrees); +} + +ParallelTensorDims + lift_to_parallel_with_degrees(TensorDims const &dims, + SumDegree sum_degree, + DiscardCopyDegree discard_copy_degree, + FFOrdered const &shard_degrees) { + std::vector lifted = + transform(zip(as_vector(dims.ff_ordered), as_vector(shard_degrees)), + [](std::pair const &p) { + size_t size = p.first; + int degree = p.second; + return ShardParallelDim(size, degree); + }); + + return ParallelTensorDims{FFOrdered{lifted}, + ReplicaParallelDimSet{ + sum_degree, + discard_copy_degree, + }}; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/tensor_dims.dtg.cc b/lib/op-attrs/src/op-attrs/tensor_dims.dtg.cc new file mode 100644 index 0000000000..909be323ac --- /dev/null +++ b/lib/op-attrs/src/op-attrs/tensor_dims.dtg.cc @@ -0,0 +1,78 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/tensor_dims.struct.toml +/* proj-data +{ + "generated_from": "5beb89eeae9eba303f90e726c794375d" +} +*/ + +#include "op-attrs/tensor_dims.dtg.h" + +#include "op-attrs/dim_ordered.h" +#include + +namespace FlexFlow { +TensorDims::TensorDims(::FlexFlow::FFOrdered const &ff_ordered) + : ff_ordered(ff_ordered) {} +bool TensorDims::operator==(TensorDims const &other) const { + return std::tie(this->ff_ordered) == std::tie(other.ff_ordered); +} +bool TensorDims::operator!=(TensorDims const &other) const { + return std::tie(this->ff_ordered) != std::tie(other.ff_ordered); +} +bool TensorDims::operator<(TensorDims const &other) const { + return std::tie(this->ff_ordered) < std::tie(other.ff_ordered); +} +bool TensorDims::operator>(TensorDims const &other) const { + return std::tie(this->ff_ordered) > std::tie(other.ff_ordered); +} +bool TensorDims::operator<=(TensorDims const &other) const { + return std::tie(this->ff_ordered) <= std::tie(other.ff_ordered); +} +bool TensorDims::operator>=(TensorDims const &other) const { + return std::tie(this->ff_ordered) >= std::tie(other.ff_ordered); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::TensorDims const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::FFOrdered>{}(x.ff_ordered) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::TensorDims + adl_serializer::from_json(json const &j) { + return {j.at("ff_ordered").template get<::FlexFlow::FFOrdered>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::TensorDims const &v) { + j["__type"] = "TensorDims"; + j["ff_ordered"] = v.ff_ordered; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::FFOrdered>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(TensorDims const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, TensorDims const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/tensor_shape.cc b/lib/op-attrs/src/op-attrs/tensor_shape.cc new file mode 100644 index 0000000000..850bea6d00 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/tensor_shape.cc @@ -0,0 +1,18 @@ +#include "op-attrs/tensor_shape.h" +#include "op-attrs/tensor_dims.h" + +namespace FlexFlow { + +size_t num_dims(TensorShape const &s) { + return s.dims.ff_ordered.size(); +} + +size_t dim_at_idx(TensorShape const &s, ff_dim_t idx) { + return dim_at_idx(s.dims, idx); +} + +size_t &dim_at_idx(TensorShape &s, ff_dim_t idx) { + return dim_at_idx(s.dims, idx); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/tensor_shape.dtg.cc b/lib/op-attrs/src/op-attrs/tensor_shape.dtg.cc new file mode 100644 index 0000000000..92b31930fa --- /dev/null +++ b/lib/op-attrs/src/op-attrs/tensor_shape.dtg.cc @@ -0,0 +1,92 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/tensor_shape.struct.toml +/* proj-data +{ + "generated_from": "ef6fa5088b89d6da4dc8bddf0a6d3294" +} +*/ + +#include "op-attrs/tensor_shape.dtg.h" + +#include "op-attrs/datatype.dtg.h" +#include "op-attrs/tensor_dims.dtg.h" +#include + +namespace FlexFlow { +TensorShape::TensorShape(::FlexFlow::TensorDims const &dims, + ::FlexFlow::DataType const &data_type) + : dims(dims), data_type(data_type) {} +bool TensorShape::operator==(TensorShape const &other) const { + return std::tie(this->dims, this->data_type) == + std::tie(other.dims, other.data_type); +} +bool TensorShape::operator!=(TensorShape const &other) const { + return std::tie(this->dims, this->data_type) != + std::tie(other.dims, other.data_type); +} +bool TensorShape::operator<(TensorShape const &other) const { + return std::tie(this->dims, this->data_type) < + std::tie(other.dims, other.data_type); +} +bool TensorShape::operator>(TensorShape const &other) const { + return std::tie(this->dims, this->data_type) > + std::tie(other.dims, other.data_type); +} +bool TensorShape::operator<=(TensorShape const &other) const { + return std::tie(this->dims, this->data_type) <= + std::tie(other.dims, other.data_type); +} +bool TensorShape::operator>=(TensorShape const &other) const { + return std::tie(this->dims, this->data_type) >= + std::tie(other.dims, other.data_type); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::TensorShape const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::TensorDims>{}(x.dims) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::DataType>{}(x.data_type) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::TensorShape + adl_serializer::from_json(json const &j) { + return {j.at("dims").template get<::FlexFlow::TensorDims>(), + j.at("data_type").template get<::FlexFlow::DataType>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::TensorShape const &v) { + j["__type"] = "TensorShape"; + j["dims"] = v.dims; + j["data_type"] = v.data_type; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::TensorDims>(), + gen::arbitrary<::FlexFlow::DataType>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(TensorShape const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, TensorShape const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op.cc b/lib/op-attrs/src/op.cc deleted file mode 100644 index 5bc5498d6e..0000000000 --- a/lib/op-attrs/src/op.cc +++ /dev/null @@ -1,24 +0,0 @@ -#include "op-attrs/op.h" - -namespace FlexFlow { - -std::string get_operator_type_name(Op op) { - return fmt::to_string(op); -} - -bool is_parallel_op(OperatorType const &t) { - switch (t) { - case Op::REPARTITION: - case Op::COMBINE: - case Op::REPLICATE: - case Op::REDUCTION: - case Op::BATCH: - case Op::PIPELINE: - case Op::FUSED_PARALLEL: - return true; - default: - return false; - } -} - -} // namespace FlexFlow diff --git a/lib/op-attrs/src/operator_attrs.cc b/lib/op-attrs/src/operator_attrs.cc index a524ab3d14..7a0027fe61 100644 --- a/lib/op-attrs/src/operator_attrs.cc +++ b/lib/op-attrs/src/operator_attrs.cc @@ -193,7 +193,7 @@ struct IsValidFunctor { bool is_valid(PCGOperatorAttrs const &attrs, std::vector const &input_shapes) { - return visit(IsValidFunctor{input_shapes}, attrs); + NOT_IMPLEMENTED(); } /* int num_outputs(OperatorParameters const &o) { */ diff --git a/lib/op-attrs/src/parallel_dim.cc b/lib/op-attrs/src/parallel_dim.cc deleted file mode 100644 index e103625fab..0000000000 --- a/lib/op-attrs/src/parallel_dim.cc +++ /dev/null @@ -1,12 +0,0 @@ -#include "op-attrs/parallel_dim.h" - -namespace FlexFlow { - -bool is_valid(ParallelDim const &dim) { - return dim.size > 0 && dim.degree >= 1 && dim.size % dim.degree == 0; -} - -bool is_replica_dim(ParallelDim const &dim) { - return dim.is_replica_dim; -} -} // namespace FlexFlow diff --git a/lib/op-attrs/src/parallel_dim_mapping_record_solver.cc b/lib/op-attrs/src/parallel_dim_mapping_record_solver.cc deleted file mode 100644 index c7e70bb906..0000000000 --- a/lib/op-attrs/src/parallel_dim_mapping_record_solver.cc +++ /dev/null @@ -1,362 +0,0 @@ -#include "parallel_dim_mapping_record_solver.h" -#include "op-attrs/parallel_tensor_shape.h" -#include -#include - -namespace FlexFlow { - -std::vector construct_weight_parallel_dims( - std::vector &records, - std::vector> mappings, - int input_idx, - int weight_idx) { - - std::vector output; - std::transform(mappings.cbegin(), - mappings.cend(), - output.begin(), - [&](std::tuple const &mapping) { - return construct_weight_parallel_dims(std::get<0>(mapping), - std::get<2>(mapping), - input_idx, - weight_idx, - std::get<1>(mapping)); - }); - return output; -} - -std::vector construct_output_parallel_dims( - std::vector> mappings, - int input_idx, - int output_idx) { - NOT_IMPLEMENTED(); -} - -std::vector construct_weight_parallel_dims( - std::vector> mappings, - int input_idx, - int weight_idx) { - NOT_IMPLEMENTED(); -} - -ParallelDimMappingRecord - construct_output_parallel_dims(int input_dim, - int output_dim, - int input_idx, - int output_idx, - std::optional operation) { - NOT_IMPLEMENTED(); -} - -ParallelDimMappingRecord - construct_weight_parallel_dims(int input_dim, - int weight_dim, - int input_idx, - int weight_idx, - std::optional operation) { - NOT_IMPLEMENTED(); -} -/* int get_output_to_input_dim_mapping(ParallelTensorShape const &output, */ -/* int output_dim, */ -/* ParallelTensorShape const &input) { */ -/* int output_idx = -1, input_idx = -1; */ -/* for (int i = 0; i < numOutputs; i++) { */ -/* if (output == outputs[i]) { */ -/* output_idx = i; */ -/* } */ -/* } */ -/* for (int i = 0; i < numInputs; i++) { */ -/* if (input == inputs[i]) { */ -/* input_idx = i; */ -/* } */ -/* } */ -/* assert(output_idx != -1); */ -/* assert(input_idx != -1); */ -/* for (size_t i = 0; i < parallel_dims_mapping->size(); i++) { */ -/* if ((*parallel_dims_mapping)[i].output_idx != output_idx) { */ -/* continue; */ -/* } */ -/* if ((*parallel_dims_mapping)[i].output_dim != output_dim) { */ -/* continue; */ -/* } */ -/* if ((*parallel_dims_mapping)[i].input_idx != input_idx) { */ -/* continue; */ -/* } */ -/* // Check validness */ -/* assert((*parallel_dims_mapping)[i].weight_idx = -1); */ -/* assert((*parallel_dims_mapping)[i].weight_dim = -1); */ -/* return (*parallel_dims_mapping)[i].input_dim; */ -/* } */ -/* assert(false); */ -/* return -1; */ -/* } */ - -/* int get_output_to_weight_dim_mapping(const ParallelTensor output, */ -/* int output_dim, */ -/* const ParallelTensor weight) { */ -/* int output_idx = -1, weight_idx = -1; */ -/* for (int i = 0; i < numOutputs; i++) { */ -/* if (output == outputs[i]) { */ -/* output_idx = i; */ -/* } */ -/* } */ -/* for (int i = 0; i < numInputs; i++) { */ -/* if (weight == weights[i]) { */ -/* weight_idx = i; */ -/* } */ -/* } */ -/* assert(output_idx != -1); */ -/* assert(weight_idx != -1); */ -/* for (size_t i = 0; i < parallel_dims_mapping->size(); i++) { */ -/* if ((*parallel_dims_mapping)[i].output_idx != output_idx) { */ -/* continue; */ -/* } */ -/* if ((*parallel_dims_mapping)[i].output_dim != output_dim) { */ -/* continue; */ -/* } */ -/* if ((*parallel_dims_mapping)[i].weight_idx != weight_idx) { */ -/* continue; */ -/* } */ -/* // Check validness */ -/* assert((*parallel_dims_mapping)[i].input_idx = -1); */ -/* assert((*parallel_dims_mapping)[i].input_dim = -1); */ -/* return (*parallel_dims_mapping)[i].weight_dim; */ -/* } */ -/* assert(false); */ -/* return -1; */ -/* } */ - -/* bool check_output_input_weight_parallel_dims(bool allocate_weights) const { - */ -/* // if (!allocate_weights) { */ -/* // assert(this->numWeights == 0); */ -/* // } */ - -/* for (ParallelDimMappingRecord const &record : *parallel_dims_mapping) { */ -/* assert(record.input_idx < this->numInputs); */ -/* assert(record.input_dim < this->inputs[record.input_idx]->num_dims); */ -/* ParallelDim const &input_dim = */ -/* inputs[record.input_idx]->dims[record.input_dim]; */ -/* /1* assert (input_dim.degree != ParallelDim::UNKNOWN_DEGREE); *1/ */ -/* /1* assert (input_dim.parallel_idx != ParallelDim::UNKNOWN_INDEX); *1/ */ - -/* ParallelDim other_dim; */ -/* switch (record.get_type()) { */ -/* case MappingRecordType::INPUT_OUTPUT: */ -/* assert(record.output_idx < this->numOutputs); */ -/* assert(record.output_dim < - * this->outputs[record.output_idx]->num_dims); */ -/* other_dim = outputs[record.output_idx]->dims[record.output_dim]; */ -/* break; */ -/* case MappingRecordType::INPUT_WEIGHT: */ -/* if (!allocate_weights) { */ -/* continue; */ -/* } */ -/* if (record.weight_idx >= this->numWeights) { */ -/* // The case where some weights are not used (e.g., no bias for - * linear) */ -/* continue; */ -/* } */ -/* assert(record.weight_dim < - * this->weights[record.weight_idx]->num_dims); */ -/* other_dim = weights[record.weight_idx]->dims[record.weight_dim]; */ -/* break; */ -/* } */ - -/* assert(other_dim.degree == input_dim.degree); */ -/* assert(other_dim.parallel_idx == input_dim.parallel_idx); */ -/* } */ -/* return true; */ -/* } */ - -/* bool check_output_input_weight_same_machine_view() const { */ -/* assert(numOutputs > 0); */ -/* MachineView machine_view = outputs[0]->machine_view; */ -/* for (int i = 0; i < numOutputs; i++) { */ -/* if (outputs[i]->machine_view != machine_view) { */ -/* return false; */ -/* } */ -/* } */ -/* for (int i = 0; i < numInputs; i++) { */ -/* if (inputs[i]->machine_view != machine_view) { */ -/* return false; */ -/* } */ -/* } */ -/* for (int i = 0; i < numWeights; i++) { */ -/* if (weights[i]->machine_view != machine_view) { */ -/* return false; */ -/* } */ -/* } */ -/* return true; */ -/* } */ - -std::vector construct_weight_parallel_dims( - std::vector> mappings, int input_idx, int weight_idx) { - std::vector output; - std::transform(mappings.cbegin(), - mappings.cend(), - output.begin(), - [&](std::pair const &mapping) { - return construct_weight_parallel_dims( - mapping.first, mapping.second, input_idx, weight_idx); - }); - return output; -} - -void construct_weight_parallel_dims( - std::vector &records, - int input_dim, - int weight_dim, - int input_idx, - int weight_idx, - std::optional operation) { - records.push_back(ParallelDimMappingRecord::input_weight_record( - input_idx, input_dim, weight_idx, weight_dim, operation)); -} - -/* void ParallelDimMappingRecordSolver::register_weight_parallel_dims( */ -/* std::vector> mappings, int input_idx, int weight_idx) - * { */ -/* construct_weight_parallel_dims( */ -/* *this->parallel_dims_mapping, mappings, input_idx, weight_idx); */ -/* } */ - -/* void register_weight_parallel_dims( */ -/* std::vector> mappings, */ -/* int input_idx, */ -/* int weight_idx) { */ -/* construct_weight_parallel_dims( */ -/* *this->parallel_dims_mapping, mappings, input_idx, weight_idx); */ -/* } */ - -/* void register_weight_parallel_dims( */ -/* int input_dim, */ -/* int weight_dim, */ -/* int input_idx, */ -/* int weight_idx, */ -/* tl::optional operation) { */ -/* construct_weight_parallel_dims(*this->parallel_dims_mapping, */ -/* input_dim, */ -/* weight_dim, */ -/* input_idx, */ -/* weight_idx, */ -/* operation); */ -/* } */ - -void construct_output_parallel_dims( - std::vector &records, - std::vector> mappings, - int input_idx, - int output_idx) { - for (std::tuple const &mapping : mappings) { - construct_output_parallel_dims(std::get<0>(mapping), - std::get<2>(mapping), - input_idx, - output_idx, - std::get<1>(mapping)); - } -} - -void construct_output_parallel_dims( - std::vector &records, - std::vector> mappings, - int input_idx, - int output_idx) { - for (std::pair const &mapping : mappings) { - construct_output_parallel_dims( - mapping.first, mapping.second, input_idx, output_idx); - } -} - -void construct_output_parallel_dims( - std::vector &records, - int input_dim, - int output_dim, - int input_idx, - int output_idx, - std::optional operation) { - records.push_back(ParallelDimMappingRecord::input_output_record( - input_idx, input_dim, output_idx, output_dim, operation)); -} - -/* void register_output_parallel_dims( */ -/* std::vector> mappings, int input_idx, int output_idx) - * { */ -/* construct_output_parallel_dims( */ -/* *this->parallel_dims_mapping, mappings, input_idx, output_idx); */ -/* } */ - -/* void register_output_parallel_dims( */ -/* std::vector> mappings, */ -/* int input_idx, */ -/* int output_idx) { */ -/* construct_output_parallel_dims( */ -/* *this->parallel_dims_mapping, mappings, input_idx, output_idx); */ -/* } */ - -/* void register_output_parallel_dims( */ -/* int input_dim, */ -/* int output_dim, */ -/* int input_idx, */ -/* int output_idx, */ -/* tl::optional operation) { */ -/* construct_output_parallel_dims(*this->parallel_dims_mapping, */ -/* input_dim, */ -/* output_dim, */ -/* input_idx, */ -/* output_idx, */ -/* operation); */ -/* } */ - -/* ParallelDimMappingSolution solve_parallel_dim_mappings( */ -/* std::vector const &mappings, */ -/* std::vector const &inputs, */ -/* int numWeights, int numOutputs) { */ - -/* ParallelDimMappingSolution solution = [&]() -> ParallelDimMappingSolution { - */ -/* std::vector weight_shapes(numWeights); */ -/* std::vector output_shapes(numOutputs); */ -/* return { weight_shapes, output_shapes }; */ -/* }(); */ - -/* for (ParallelDimMappingRecord const &record : mappings) { */ -/* ParallelDim const &input_dim = - * inputs.at(record.input_idx).at(record.input_dim); */ - -/* switch (record.get_type()) { */ -/* case MappingRecordType::INPUT_OUTPUT: { */ -/* ParallelDim &output_dim = - * solution.output_shapes.at(record.output_idx).at(record.output_dim); */ -/* output_dim.degree = input_dim.degree; */ -/* output_dim.parallel_idx = input_dim.parallel_idx; */ - -/* if (output_dim.is_replica_dim) { */ -/* output_dim.size = input_dim.degree; */ -/* } */ -/* } break; */ -/* case MappingRecordType::INPUT_WEIGHT: { */ -/* ParallelDim &weight_dim = - * solution.weight_shapes.at(record.weight_idx).at(record.weight_dim); */ -/* weight_dim.degree = input_dim.degree; */ -/* weight_dim.parallel_idx = input_dim.parallel_idx; */ - -/* if (weight_dim.is_replica_dim) { */ -/* weight_dim.size = input_dim.degree; */ -/* } */ -/* } break; */ -/* } */ -/* } */ - -/* return solution; */ -/* } */ - -ParallelDimMappingSolution solve_parallel_dim_mappings( - std::vector const &mappings, - std::vector const &input, - int numWeights, - int numOutputs) { - NOT_IMPLEMENTED(); -} - -} // namespace FlexFlow diff --git a/lib/op-attrs/src/parallel_dim_mapping_record_solver.h b/lib/op-attrs/src/parallel_dim_mapping_record_solver.h deleted file mode 100644 index a46192edeb..0000000000 --- a/lib/op-attrs/src/parallel_dim_mapping_record_solver.h +++ /dev/null @@ -1,106 +0,0 @@ -/** - * @file - * @warning This is legacy code the should be removed - * (partially tracked in - * https://github.com/flexflow/FlexFlow/issues/519). - * @brief Helper functions for computing data dependencies of parallel - * operators. Functions based on an incorrect abstraction that should eventually - * be removed in favor of something like https://doi.org/10.1145/3302424.3303953 - */ - -#ifndef _FLEXFLOW_OP_META_SRC_PARELLEL_DIM_MAPPING_RECORD_SOLVER_H -#define _FLEXFLOW_OP_META_SRC_PARELLEL_DIM_MAPPING_RECORD_SOLVER_H - -#include "op-attrs/parallel_tensor_shape.h" -#include "parallel_dim_mapping_record.h" - -namespace FlexFlow { - -std::vector - construct_weight_parallel_dims(std::vector> mappings, - int input_idx = 0, - int weight_idx = 0); -std::vector construct_weight_parallel_dims( - std::vector> mappings, - int input_idx = 0, - int weight_idx = 0); -ParallelDimMappingRecord construct_weight_parallel_dims( - int input_dim, - int weight_dim, - int input_idx = 0, - int weight_idx = 0, - std::optional operation = std::nullopt); - -std::vector - construct_output_parallel_dims(std::vector> mappings, - int input_idx = 0, - int output_idx = 0); -std::vector construct_output_parallel_dims( - std::vector> mappings, - int input_idx = 0, - int output_idx = 0); -ParallelDimMappingRecord construct_output_parallel_dims( - int input_dim, - int output_dim, - int input_idx = 0, - int output_idx = 0, - std::optional operation = std::nullopt); - -struct ParallelDimMappingSolution { - std::vector weight_shapes; - std::vector output_shapes; -}; - -ParallelDimMappingSolution solve_parallel_dim_mappings( - std::vector const &mappings, - std::vector const &input, - int numWeights, - int numOutputs); - -/* class ParallelDimMappingRecordSolver { */ -/* /1* void register_weight_parallel_dims(std::vector> - * mappings, *1/ */ -/* /1* int input_idx = 0, *1/ */ -/* /1* int weight_idx = 0); *1/ */ - -/* /1* void register_output_parallel_dims(std::vector> - * mappings, *1/ */ -/* /1* int input_idx = 0, *1/ */ -/* /1* int output_idx = 0); *1/ */ - -/* /1* int get_output_to_input_dim_mapping(const ParallelTensor output, *1/ */ -/* /1* int output_dim, *1/ */ -/* /1* const ParallelTensor input); *1/ */ -/* /1* int get_output_to_weight_dim_mapping(const ParallelTensor output, *1/ - */ -/* /1* int output_dim, *1/ */ -/* /1* const ParallelTensor weight); *1/ - */ -/* void register_weight_parallel_dims( */ -/* std::vector> mappings, */ -/* int input_idx = 0, */ -/* int weight_idx = 0); */ -/* void register_weight_parallel_dims( */ -/* int input_dim, */ -/* int weight_dim, */ -/* int input_idx = 0, */ -/* int weight_idx = 0, */ -/* std::optional operation = std::nullopt); */ -/* void register_output_parallel_dims( */ -/* std::vector> mappings, */ -/* int input_idx = 0, */ -/* int output_idx = 0); */ -/* void register_output_parallel_dims( */ -/* int input_dim, */ -/* int output_dim, */ -/* int input_idx = 0, */ -/* int output_idx = 0, */ -/* std::optional operation = std::nullopt); */ - -/* private: */ -/* std::vector *parallel_dims_mapping; */ -/* }; */ - -} // namespace FlexFlow - -#endif diff --git a/lib/op-attrs/src/parallel_tensor_shape.cc b/lib/op-attrs/src/parallel_tensor_shape.cc deleted file mode 100644 index e226c38eac..0000000000 --- a/lib/op-attrs/src/parallel_tensor_shape.cc +++ /dev/null @@ -1,97 +0,0 @@ -#include "op-attrs/parallel_tensor_shape.h" -#include "utils/containers.h" -#include "utils/hash-utils.h" - -namespace FlexFlow { - -int ParallelTensorShape::num_dims() const { - return dims.num_dims(); -} - -static std::vector lift_dims(TensorDims const &dims) { - std::vector lifted_dims; - for (size_t dim_size : dims) { - lifted_dims.push_back({dim_size, 1, false}); - } - lifted_dims.push_back({1, 1, true}); - return lifted_dims; -} - -ParallelTensorDims::ParallelTensorDims(TensorDims const &dims) - : data(lift_dims(dims)) {} - -ParallelTensorShape::ParallelTensorShape(TensorShape const &tensor_shape) - : dims(tensor_shape.dims), data_type(tensor_shape.data_type) {} - -int get_num_replica_dims(ParallelTensorShape const &shape) { - return count(shape.dims, is_replica_dim); -} - -int get_num_replicas(ParallelTensorShape const &shape) { - return product( - transform(filter(as_vector(shape.dims), is_replica_dim), - [](ParallelDim const &d) -> int { return d.degree; })); -} - -bool is_valid(ParallelTensorDims const &dims) { - return all_of(dims, [](ParallelDim const &d) { return is_valid(d); }); -} - -bool is_valid(ParallelTensorShape const &shape) { - return is_valid(shape.dims); -} - -ParallelTensorDims::iterator ParallelTensorDims::begin() { - return data.begin(); -} - -ParallelTensorDims::const_iterator ParallelTensorDims::begin() const { - return data.begin(); -} - -ParallelTensorDims::const_iterator ParallelTensorDims::cbegin() const { - return data.cbegin(); -} - -ParallelTensorDims::iterator ParallelTensorDims::end() { - return data.end(); -} - -ParallelTensorDims::const_iterator ParallelTensorDims::end() const { - return data.end(); -} - -ParallelTensorDims::const_iterator ParallelTensorDims::cend() const { - return data.cend(); -} - -ParallelDim const &ParallelTensorDims::at(ff_dim_t const &d) const { - return data.at(d); -} - -ParallelDim &ParallelTensorDims::at(ff_dim_t const &d) { - return data.at(d); -} - -size_t ParallelTensorDims::num_dims() const { - return data.size(); -} - -ParallelDim const &ParallelTensorShape::at(ff_dim_t const &d) const { - return dims.at(d); -} - -ParallelDim &ParallelTensorShape::at(ff_dim_t const &d) { - return dims.at(d); -} -ParallelDim const &ParallelTensorShape::operator[](ff_dim_t const &d) const { - return dims.at(d); -} -ParallelDim &ParallelTensorShape::operator[](ff_dim_t const &d) { - return dims.at(d); -} - -TensorShape get_tensor_shape_unsafe(ParallelTensorShape const &) { - NOT_IMPLEMENTED(); -} -} // namespace FlexFlow diff --git a/lib/op-attrs/src/reduce.cc b/lib/op-attrs/src/reduce.cc deleted file mode 100644 index 9d1770d5be..0000000000 --- a/lib/op-attrs/src/reduce.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "op-attrs/ops/reduce.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/op-attrs/src/reduction.cc b/lib/op-attrs/src/reduction.cc deleted file mode 100644 index 22fc9bab6a..0000000000 --- a/lib/op-attrs/src/reduction.cc +++ /dev/null @@ -1,13 +0,0 @@ -#include "op-attrs/ops/reduction.h" - -namespace FlexFlow { - -/* ParallelTensorShape ReductionAttrs::output_shape(ParallelTensorShape const - * &input_shape) const { */ -/* ParallelTensorShape output = input_shape; */ -/* output.at(this->reduction_legion_dim).degree /= this->reduction_degree; */ -/* output.at(this->reduction_legion_dim).size /= this->reduction_degree; */ -/* return output; */ -/* } */ - -} // namespace FlexFlow diff --git a/lib/op-attrs/src/repartition.cc b/lib/op-attrs/src/repartition.cc deleted file mode 100644 index 672e68b4f6..0000000000 --- a/lib/op-attrs/src/repartition.cc +++ /dev/null @@ -1,11 +0,0 @@ -#include "op-attrs/ops/repartition.h" - -namespace FlexFlow { - -/* bool RepartitionAttrs::is_valid(ParallelTensorShape const &input_shape) const - * { */ -/* ParallelDim dim = input_shape.at(this->repartition_legion_dim); */ -/* return (dim.size % this->repartition_degree * dim.degree == 0); */ -/* } */ - -} // namespace FlexFlow diff --git a/lib/op-attrs/src/replicate.cc b/lib/op-attrs/src/replicate.cc deleted file mode 100644 index 73ad288d8c..0000000000 --- a/lib/op-attrs/src/replicate.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "op-attrs/ops/replicate.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/op-attrs/src/reshape.cc b/lib/op-attrs/src/reshape.cc deleted file mode 100644 index e8349e1f26..0000000000 --- a/lib/op-attrs/src/reshape.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "op-attrs/ops/reshape.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/op-attrs/src/softmax.cc b/lib/op-attrs/src/softmax.cc deleted file mode 100644 index 9f95da4fb7..0000000000 --- a/lib/op-attrs/src/softmax.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "op-attrs/ops/softmax.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/op-attrs/src/split.cc b/lib/op-attrs/src/split.cc deleted file mode 100644 index acda8f3262..0000000000 --- a/lib/op-attrs/src/split.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "op-attrs/ops/split.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/op-attrs/src/tensor_shape.cc b/lib/op-attrs/src/tensor_shape.cc deleted file mode 100644 index e456b31e3c..0000000000 --- a/lib/op-attrs/src/tensor_shape.cc +++ /dev/null @@ -1,13 +0,0 @@ -#include "op-attrs/tensor_shape.h" - -namespace FlexFlow { - -size_t TensorShape::at(ff_dim_t d) const { - return dims.at(d); -} - -size_t TensorShape::operator[](ff_dim_t d) const { - return dims[d]; -} - -} // namespace FlexFlow diff --git a/lib/op-attrs/src/topk.cc b/lib/op-attrs/src/topk.cc deleted file mode 100644 index 9d701e4868..0000000000 --- a/lib/op-attrs/src/topk.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "op-attrs/ops/topk.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/op-attrs/src/transpose.cc b/lib/op-attrs/src/transpose.cc deleted file mode 100644 index ad4a84a3d5..0000000000 --- a/lib/op-attrs/src/transpose.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "op-attrs/ops/transpose.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/op-attrs/test/CMakeLists.txt b/lib/op-attrs/test/CMakeLists.txt new file mode 100644 index 0000000000..b6ff72fc00 --- /dev/null +++ b/lib/op-attrs/test/CMakeLists.txt @@ -0,0 +1,13 @@ +ff_add_test_executable( + NAME + op-attrs-tests + SRC_PATTERNS + src/*.cc + PRIVATE_INCLUDE + src/ + DEPS + utils + op-attrs + doctest + utils-test-common +) diff --git a/lib/op-attrs/test/src/dim_ordered/slice.cc b/lib/op-attrs/test/src/dim_ordered/slice.cc new file mode 100644 index 0000000000..8640b077dc --- /dev/null +++ b/lib/op-attrs/test/src/dim_ordered/slice.cc @@ -0,0 +1,23 @@ +#include "op-attrs/dim_ordered/slice.h" +#include "test/utils/doctest.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE( + "slice(DimOrdered, std::optional, std::optional)") { + FFOrdered d = FFOrdered{ + 1, + 2, + 3, + 4, + }; + + FFOrdered result = slice(d, std::nullopt, ff_dim_t{-1}); + FFOrdered correct = FFOrdered{ + 1, + 2, + 3, + }; + + CHECK(result == correct); + } +} diff --git a/lib/op-attrs/test/src/ops/combine.cc b/lib/op-attrs/test/src/ops/combine.cc new file mode 100644 index 0000000000..a50b3b01de --- /dev/null +++ b/lib/op-attrs/test/src/ops/combine.cc @@ -0,0 +1,59 @@ +#include "op-attrs/ops/combine.h" +#include "test/utils/doctest.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Combine shape inference") { + + ParallelTensorShape input = { + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{12, 2}, + ShardParallelDim{14, 1}, + ShardParallelDim{16, 3}, + ShardParallelDim{18, 2}, + }, + ReplicaParallelDimSet{ + SumDegree{3}, + DiscardCopyDegree{2}, + }, + }, + DataType::FLOAT, + }; + + SUBCASE("valid") { + ff_dim_t dim = 2; + int degree = 3; + CombineAttrs attrs = CombineAttrs{ + /*repartition_dim=*/dim, + /*repartition_degree=*/degree, + }; + + tl::expected result = + get_output_shape(attrs, input); + + tl::expected correct = [&] { + ParallelTensorShape output = input; + output.dims.shard_dims.at(dim).degree /= degree; + return output; + }(); + + CHECK(result == correct); + } + + SUBCASE("invalid") { + ff_dim_t dim = 2; + int degree = 4; + CombineAttrs attrs = CombineAttrs{ + /*repartition_dim=*/dim, + /*repartition_degree=*/degree, + }; + + tl::expected result = + get_output_shape(attrs, input); + + CHECK_MESSAGE(!result.has_value(), + "Unexpected successful result: ", + result.error()); + } + } +} diff --git a/lib/op-attrs/test/src/ops/linear.cc b/lib/op-attrs/test/src/ops/linear.cc new file mode 100644 index 0000000000..0d23dc35df --- /dev/null +++ b/lib/op-attrs/test/src/ops/linear.cc @@ -0,0 +1,231 @@ +#include "op-attrs/ops/linear.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "test/utils/doctest.h" +#include "utils/integer_conversions.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Linear shape inference") { + int out_channels = 16; + LinearAttrs attrs = LinearAttrs{ + /*out_channels=*/out_channels, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*activation=*/Activation::RELU, + /*regularizer=*/std::nullopt, + }; + + size_t batch_size = 12; + size_t extra_dim = 16; + size_t in_channels = 8; + + TensorShape input = TensorShape{ + TensorDims{ + FFOrdered{ + batch_size, + extra_dim, + in_channels, + }, + }, + DataType::FLOAT, + }; + + TensorShape output = TensorShape{ + TensorDims{ + FFOrdered{ + batch_size, + extra_dim, + size_t_from_int(out_channels), + }, + }, + DataType::FLOAT, + }; + + TensorShape kernel = TensorShape{ + TensorDims{ + FFOrdered{ + in_channels, + size_t_from_int(out_channels), + }, + }, + DataType::FLOAT, + }; + + TensorShape bias = TensorShape{ + TensorDims{ + FFOrdered{ + size_t_from_int(out_channels), + }, + }, + DataType::FLOAT, + }; + + // get_output_shape + { + tl::expected output_result = + get_output_shape(attrs, input); + tl::expected output_correct = output; + CHECK(output_result == output_correct); + } + + // get_weight_shape + { + tl::expected kernel_result = + get_kernel_shape(attrs, input); + tl::expected kernel_correct = kernel; + CHECK(kernel_result == kernel_correct); + } + + // get_bias_shape + { + tl::expected bias_result = + get_bias_shape(attrs, input); + tl::expected bias_correct = bias; + CHECK(bias_result == bias_correct); + } + + auto make_input = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_batch, + int o_extra_dim, + int o_channel) { + return lift_to_parallel_with_degrees( + input, o_sum, o_eq, FFOrdered{o_batch, o_extra_dim, o_channel}); + }; + + auto make_output = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_batch, + int o_extra_dim, + int o_channel) { + return lift_to_parallel_with_degrees( + output, o_sum, o_eq, FFOrdered{o_batch, o_extra_dim, o_channel}); + }; + + auto make_kernel = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_inchannel, + int o_outchannel) { + return lift_to_parallel_with_degrees( + kernel, o_sum, o_eq, FFOrdered{o_inchannel, o_outchannel}); + }; + + auto make_bias = + [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o_outchannel) { + return lift_to_parallel_with_degrees( + bias, o_sum, o_eq, FFOrdered{o_outchannel}); + }; + + SUBCASE("data parallelism") { + int input_sum_degree = 2; + int extra_dim_degree = 8; + int degree = 4; + + ParallelTensorShape par_input = make_input(SumDegree{input_sum_degree}, + DiscardCopyDegree{1}, + degree, + extra_dim_degree, + 1); + + { + tl::expected result = + get_output_shape(attrs, par_input); + tl::expected correct = + make_output(SumDegree{input_sum_degree}, + DiscardCopyDegree{1}, + degree, + extra_dim_degree, + 1); + CHECK(result == correct); + } + + { + tl::expected result = + get_kernel_shape(attrs, par_input); + tl::expected correct = make_kernel( + SumDegree{1}, + DiscardCopyDegree{input_sum_degree * degree * extra_dim_degree}, + 1, + 1); + CHECK(result == correct); + } + + { + tl::expected result = + get_bias_shape(attrs, par_input); + tl::expected correct = + make_bias(SumDegree{input_sum_degree}, + DiscardCopyDegree{degree * extra_dim_degree}, + 1); + CHECK(result == correct); + } + } + + SUBCASE("reduction parallelism") { + int input_sum_degree = 2; + int degree = 4; + + ParallelTensorShape par_input = make_input( + SumDegree{input_sum_degree}, DiscardCopyDegree{1}, 1, 1, degree); + + { + tl::expected result = + get_output_shape(attrs, par_input); + tl::expected correct = + make_output(SumDegree{input_sum_degree * degree}, + DiscardCopyDegree{1}, + 1, + 1, + 1); + CHECK(result == correct); + } + + { + tl::expected result = + get_kernel_shape(attrs, par_input); + tl::expected correct = make_kernel( + SumDegree{1}, DiscardCopyDegree{input_sum_degree}, degree, 1); + CHECK(result == correct); + } + + { + tl::expected result = + get_bias_shape(attrs, par_input); + tl::expected correct = make_bias( + SumDegree{input_sum_degree * degree}, DiscardCopyDegree{1}, 1); + CHECK(result == correct); + } + } + + SUBCASE("output channel parallelism") { + int input_sum_degree = 2; + int degree = 4; + + ParallelTensorShape par_input = make_input( + SumDegree{input_sum_degree}, DiscardCopyDegree{degree}, 1, 1, 1); + + { + tl::expected result = + get_output_shape(attrs, par_input); + tl::expected correct = make_output( + SumDegree{input_sum_degree}, DiscardCopyDegree{1}, 1, 1, degree); + CHECK(result == correct); + } + + { + tl::expected result = + get_kernel_shape(attrs, par_input); + tl::expected correct = make_kernel( + SumDegree{1}, DiscardCopyDegree{input_sum_degree}, 1, degree); + CHECK(result == correct); + } + + { + tl::expected result = + get_bias_shape(attrs, par_input); + tl::expected correct = make_bias( + SumDegree{input_sum_degree}, DiscardCopyDegree{1}, degree); + CHECK(result == correct); + } + } + } +} diff --git a/lib/op-attrs/test/src/ops/reduction.cc b/lib/op-attrs/test/src/ops/reduction.cc new file mode 100644 index 0000000000..6f73951e00 --- /dev/null +++ b/lib/op-attrs/test/src/ops/reduction.cc @@ -0,0 +1,55 @@ +#include "op-attrs/ops/reduction.h" +#include "test/utils/doctest.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Reduction shape inference") { + + ParallelTensorShape input = { + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{12, 2}, + ShardParallelDim{14, 1}, + ShardParallelDim{16, 3}, + ShardParallelDim{18, 2}, + }, + ReplicaParallelDimSet{ + SumDegree{3}, + DiscardCopyDegree{2}, + }, + }, + DataType::FLOAT, + }; + + SUBCASE("valid") { + int degree = 3; + ReductionAttrs attrs = ReductionAttrs{ + /*repartition_degree=*/degree, + }; + + tl::expected result = + get_output_shape(attrs, input); + + tl::expected correct = [&] { + ParallelTensorShape output = input; + output.dims.replica_dims.sum_degree.value /= degree; + return output; + }(); + + CHECK(result == correct); + } + + SUBCASE("invalid") { + int degree = 4; + ReductionAttrs attrs = ReductionAttrs{ + /*repartition_degree=*/degree, + }; + + tl::expected result = + get_output_shape(attrs, input); + + CHECK_MESSAGE(!result.has_value(), + "Unexpected successful result: ", + result.error()); + } + } +} diff --git a/lib/op-attrs/test/src/ops/repartition.cc b/lib/op-attrs/test/src/ops/repartition.cc new file mode 100644 index 0000000000..3b3ae92b4c --- /dev/null +++ b/lib/op-attrs/test/src/ops/repartition.cc @@ -0,0 +1,40 @@ +#include "op-attrs/ops/repartition.h" +#include "test/utils/doctest.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Repartition shape inference") { + ff_dim_t dim = 2; + int degree = 4; + RepartitionAttrs attrs = RepartitionAttrs{ + /*repartition_dim=*/dim, + /*repartition_degree=*/degree, + }; + + ParallelTensorShape input = { + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{12, 2}, + ShardParallelDim{14, 1}, + ShardParallelDim{16, 3}, + ShardParallelDim{18, 2}, + }, + ReplicaParallelDimSet{ + SumDegree{3}, + DiscardCopyDegree{2}, + }, + }, + DataType::FLOAT, + }; + + tl::expected result = + get_output_shape(attrs, input); + + tl::expected correct = [&] { + ParallelTensorShape output = input; + output.dims.shard_dims.at(dim).degree *= degree; + return output; + }(); + + CHECK(result == correct); + } +} diff --git a/lib/op-attrs/test/src/ops/replicate.cc b/lib/op-attrs/test/src/ops/replicate.cc new file mode 100644 index 0000000000..b326038388 --- /dev/null +++ b/lib/op-attrs/test/src/ops/replicate.cc @@ -0,0 +1,33 @@ +#include "op-attrs/ops/replicate.h" +#include "test/utils/doctest.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Replicate shape inference") { + ReplicateAttrs attrs = ReplicateAttrs{ + /*replicate_degree=*/4, + }; + + ParallelTensorShape input = { + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{10, 2}, + ShardParallelDim{12, 1}, + ShardParallelDim{14, 2}, + ShardParallelDim{16, 2}, + }, + ReplicaParallelDimSet{ + SumDegree{3}, + DiscardCopyDegree{2}, + }, + }, + DataType::FLOAT, + }; + + ParallelTensorShape result = get_output_shape(attrs, input); + + ParallelTensorShape correct_output = input; + correct_output.dims.replica_dims.discard_copy_degree = 8; + + CHECK(result == correct_output); + } +} diff --git a/lib/op-attrs/test/src/test_attention.cc b/lib/op-attrs/test/src/test_attention.cc new file mode 100644 index 0000000000..74ae4565ca --- /dev/null +++ b/lib/op-attrs/test/src/test_attention.cc @@ -0,0 +1,272 @@ +#include "op-attrs/ops/attention.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "test/utils/doctest.h" +#include "utils/integer_conversions.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_output_shape(MultiHeadAttentionAttrs, TensorShape, " + "TensorShape, TensorShape)") { + int embed_dim = 32; + + /* Parameter meanings match those at + * https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html + */ + MultiHeadAttentionAttrs attrs = { + /*embed_dim=*/embed_dim, + /*num_heads=*/10, + /*kdim=*/embed_dim, + /*vdim=*/embed_dim, + /*dropout=*/0.0, + /*bias=*/true, + /*add_bias_kv=*/false, + /*add_zero_attn=*/false, + }; + + size_t batch_size = 40; + size_t seq_len = 48; + + TensorShape input_q = { + TensorDims{FFOrdered{ + batch_size, + seq_len, + size_t_from_int(attrs.embed_dim), + }}, + DataType::FLOAT, + }; + + TensorShape input_k = { + TensorDims{ + FFOrdered{ + batch_size, + seq_len, + size_t_from_int(attrs.kdim), + }, + }, + DataType::FLOAT, + }; + + TensorShape input_v = { + TensorDims{ + FFOrdered{ + batch_size, + seq_len, + size_t_from_int(attrs.vdim), + }, + }, + DataType::FLOAT, + }; + + SUBCASE("get_output_shape") { + tl::expected result = + get_output_shape(attrs, input_q, input_k, input_v); + + tl::expected correct = TensorShape{ + TensorDims{FFOrdered{ + batch_size, + seq_len, + size_t_from_int(attrs.embed_dim), + }}, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + + SUBCASE("get_weights_shape") { + tl::expected result = + get_weights_shape(attrs, input_q, input_k, input_v); + + int qProjPerHeadWeightSize = + attrs.kdim * dim_at_idx(input_q, ff_dim_t{-1}); + int kProjPerHeadWeightSize = + attrs.kdim * dim_at_idx(input_k, ff_dim_t{-1}); + int vProjPerHeadWeightSize = + attrs.vdim * dim_at_idx(input_v, ff_dim_t{-1}); + int oProjPerHeadWeightSize = attrs.embed_dim * attrs.vdim; + int perHeadWeightSize = qProjPerHeadWeightSize + kProjPerHeadWeightSize + + vProjPerHeadWeightSize + oProjPerHeadWeightSize; + + tl::expected correct = TensorShape{ + TensorDims{FFOrdered{ + size_t_from_int(perHeadWeightSize), + size_t_from_int(attrs.num_heads), + }}, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + } + + TEST_CASE("parallel shape inference for MultiHeadAttentionAttrs") { + int embed_dim = 32; + + /* Parameter meanings can be found at + * https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html + */ + MultiHeadAttentionAttrs attrs = { + /*embed_dim=*/embed_dim, + /*num_heads=*/10, + /*kdim=*/embed_dim, + /*vdim=*/embed_dim, + /*dropout=*/0.0, + /*bias=*/true, + /*add_bias_kv=*/false, + /*add_zero_attn=*/false, + }; + + size_t batchsize = 40; + size_t seq_len = 48; + size_t q_size = 56; + size_t k_size = 64; + size_t v_size = 72; + + TensorShape unpar_q_shape = TensorShape{ + TensorDims{ + FFOrdered{ + batchsize, + seq_len, + q_size, + }, + }, + DataType::FLOAT, + }; + + TensorShape unpar_k_shape = TensorShape{ + TensorDims{ + FFOrdered{ + batchsize, + seq_len, + k_size, + }, + }, + DataType::FLOAT, + }; + + TensorShape unpar_v_shape = TensorShape{ + TensorDims{ + FFOrdered{ + batchsize, + seq_len, + v_size, + }, + }, + DataType::FLOAT, + }; + + tl::expected result_unpar_o_shape = + get_output_shape(attrs, unpar_q_shape, unpar_k_shape, unpar_v_shape); + REQUIRE(result_unpar_o_shape.has_value()); + TensorShape unpar_o_shape = result_unpar_o_shape.value(); + + tl::expected result_unpar_w_shape = + get_weights_shape(attrs, unpar_q_shape, unpar_k_shape, unpar_v_shape); + REQUIRE(result_unpar_o_shape.has_value()); + TensorShape unpar_w_shape = result_unpar_w_shape.value(); + + auto make_q = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_batch, + int o_seq_len, + int o_q) { + return lift_to_parallel_with_degrees( + unpar_q_shape, o_sum, o_eq, FFOrdered{o_batch, o_seq_len, o_q}); + }; + + auto make_k = [&](int o_sum, + int o_eq, + int o_batch, + int o_seq_len, + int o_k) { + return lift_to_parallel_with_degrees( + unpar_k_shape, o_sum, o_eq, FFOrdered{o_batch, o_seq_len, o_k}); + }; + + auto make_v = [&](int o_sum, + int o_eq, + int o_batch, + int o_seq_len, + int o_v) { + return lift_to_parallel_with_degrees( + unpar_v_shape, o_sum, o_eq, FFOrdered{o_batch, o_seq_len, o_v}); + }; + + auto make_o = [&](int o_sum, + int o_eq, + int o_batch, + int o_seq_len, + int o_o) { + return lift_to_parallel_with_degrees( + unpar_o_shape, o_sum, o_eq, FFOrdered{o_batch, o_seq_len, o_o}); + }; + + auto make_w = [&](int o_sum, int o_eq, int o_e, int o_h) { + return lift_to_parallel_with_degrees( + unpar_w_shape, o_sum, o_eq, FFOrdered{o_e, o_h}); + }; + + SUBCASE("data parallelism") { + int o_b = 4; + ParallelTensorShape q = make_q(1, 1, o_b, 1, 1); + ParallelTensorShape k = make_k(1, 1, o_b, 1, 1); + ParallelTensorShape v = make_v(1, 1, o_b, 1, 1); + + tl::expected result_o = + get_output_shape(attrs, q, k, v); + tl::expected correct_o = + make_o(1, 1, o_b, 1, 1); + + CHECK(result_o == correct_o); + + tl::expected result_w = + get_weights_shape(attrs, q, k, v); + tl::expected correct_w = + make_w(1, o_b, 1, 1); + + CHECK(result_w == correct_w); + } + + SUBCASE("attention head parallelism") { + int o_h = 2; + ParallelTensorShape q = make_q(1, o_h, 1, 1, 1); + ParallelTensorShape k = make_k(1, o_h, 1, 1, 1); + ParallelTensorShape v = make_v(1, o_h, 1, 1, 1); + + tl::expected result_o = + get_output_shape(attrs, q, k, v); + tl::expected correct_o = + make_o(o_h, 1, 1, 1, 1); + + CHECK(result_o == correct_o); + + tl::expected result_w = + get_weights_shape(attrs, q, k, v); + tl::expected correct_w = + make_w(1, 1, 1, o_h); + + CHECK(result_w == correct_w); + } + + SUBCASE("combined data & attention head parallelism") { + int o_b = 4; + int o_h = 2; + ParallelTensorShape q = make_q(1, o_h, o_b, 1, 1); + ParallelTensorShape k = make_k(1, o_h, o_b, 1, 1); + ParallelTensorShape v = make_v(1, o_h, o_b, 1, 1); + + tl::expected result_o = + get_output_shape(attrs, q, k, v); + tl::expected correct_o = + make_o(o_h, 1, o_b, 1, 1); + + CHECK(result_o == correct_o); + + tl::expected result_w = + get_weights_shape(attrs, q, k, v); + tl::expected correct_w = + make_w(1, o_b, 1, o_h); + + CHECK(result_w == correct_w); + } + } +} diff --git a/lib/op-attrs/test/src/test_batch_matmul.cc b/lib/op-attrs/test/src/test_batch_matmul.cc new file mode 100644 index 0000000000..f48478be10 --- /dev/null +++ b/lib/op-attrs/test/src/test_batch_matmul.cc @@ -0,0 +1,268 @@ +#include "op-attrs/ops/batch_matmul.h" +#include "test/utils/doctest.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_output_shape(BatchMatmulAttrs, TensorShape)") { + size_t b = 4; + size_t m = 6; + size_t n = 8; + size_t p = 10; + + BatchMatmulAttrs attrs = { + /*a_seq_length_dim=*/0, // TODO figure out if these arguments are still + // relevant + /*b_seq_length_dim=*/0, + }; + + TensorShape input_lhs_shape = { + TensorDims{ + FFOrdered{ + b, + n, + m, + }, + }, + DataType::FLOAT, + }; + + SUBCASE("valid") { + TensorShape input_rhs_shape = { + TensorDims{ + FFOrdered{ + b, + m, + p, + }, + }, + DataType::FLOAT, + }; + + tl::expected result = + get_output_shape(attrs, input_lhs_shape, input_rhs_shape); + + tl::expected correct_output_shape = TensorShape{ + TensorDims{ + FFOrdered{ + b, + n, + p, + }, + }, + DataType::FLOAT, + }; + + CHECK(result == correct_output_shape); + } + + SUBCASE("mismatched b") { + TensorShape input_rhs_shape = { + TensorDims{ + FFOrdered{ + b + 1, + m, + p, + }, + }, + DataType::FLOAT, + }; + + tl::expected result = + get_output_shape(attrs, input_lhs_shape, input_rhs_shape); + + CHECK(!result.has_value()); + } + + SUBCASE("mismatched m") { + TensorShape input_rhs_shape = { + TensorDims{ + FFOrdered{ + b, + m + 1, + p, + }, + }, + DataType::FLOAT, + }; + + tl::expected result = + get_output_shape(attrs, input_lhs_shape, input_rhs_shape); + + CHECK(!result.has_value()); + } + } + + TEST_CASE("get_output_shape(BatchMatmulAttrs, ParallelTensorShape)") { + size_t b = 2 * 2; + int o_b = 2; + size_t m = 3 * 3; + int o_m = 3; + size_t n = 5 * 5; + int o_n = 5; + size_t p = 7 * 7; + int o_p = 7; + int o_sum = 11; + + BatchMatmulAttrs attrs = { + /*a_seq_length_dim=*/0, // TODO figure out if these arguments are still + // relevant + /*b_seq_length_dim=*/0, + }; + + auto make_lhs = [&](int o_sum, int o_eq, int o_b, int o_n, int o_m) { + return ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{b, o_b}, + ShardParallelDim{n, o_n}, + ShardParallelDim{m, o_m}, + }, + ReplicaParallelDimSet{ + o_sum, + o_eq, + }, + }, + DataType::FLOAT, + }; + }; + + auto make_rhs = [&](int o_sum, int o_eq, int o_b, int o_m, int o_p) { + return ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{b, o_b}, + ShardParallelDim{m, o_m}, + ShardParallelDim{p, o_p}, + }, + ReplicaParallelDimSet{ + o_sum, + o_eq, + }, + }, + DataType::FLOAT, + }; + }; + + auto make_output = [&](int o_sum, int o_eq, int o_b, int o_n, int o_p) { + return ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{b, o_b}, + ShardParallelDim{n, o_n}, + ShardParallelDim{p, o_p}, + }, + ReplicaParallelDimSet{ + o_sum, + o_eq, + }, + }, + DataType::FLOAT, + }; + }; + + SUBCASE("data parallel") { + tl::expected result = get_output_shape( + attrs, make_lhs(1, 1, o_b, 1, 1), make_rhs(1, 1, o_b, 1, 1)); + tl::expected correct = + make_output(1, 1, o_b, 1, 1); + + CHECK(result == correct); + } + + SUBCASE("n parallel") { + tl::expected result = get_output_shape( + attrs, make_lhs(1, 1, 1, o_n, 1), make_rhs(1, o_n, 1, 1, 1)); + tl::expected correct = + make_output(1, 1, 1, o_n, 1); + + CHECK(result == correct); + } + + SUBCASE("p parallel") { + tl::expected result = get_output_shape( + attrs, make_lhs(1, o_p, 1, 1, 1), make_rhs(1, 1, 1, 1, o_p)); + tl::expected correct = + make_output(1, 1, 1, 1, o_p); + + CHECK(result == correct); + } + + SUBCASE("reduction parallel") { + tl::expected result = get_output_shape( + attrs, make_lhs(1, 1, 1, 1, o_m), make_rhs(1, 1, 1, o_m, 1)); + tl::expected correct = + make_output(o_m, 1, 1, 1, 1); + + CHECK(result == correct); + } + + SUBCASE("propagate reduction lhs") { + tl::expected result = get_output_shape( + attrs, make_lhs(o_sum, 1, 1, 1, 1), make_rhs(1, o_sum, 1, 1, 1)); + tl::expected correct = + make_output(o_sum, 1, 1, 1, 1); + + CHECK(result == correct); + } + + SUBCASE("propagate reduction rhs") { + tl::expected result = get_output_shape( + attrs, make_lhs(1, o_sum, 1, 1, 1), make_rhs(o_sum, 1, 1, 1, 1)); + tl::expected correct = + make_output(o_sum, 1, 1, 1, 1); + + CHECK(result == correct); + } + + SUBCASE("reduction lhs & reduction rhs") { + tl::expected result = + get_output_shape(attrs, + make_lhs(o_sum, o_sum, 1, 1, 1), + make_rhs(o_sum, o_sum, 1, 1, 1)); + tl::expected correct = + make_output(o_sum * o_sum, 1, 1, 1, 1); + + CHECK(result == correct); + } + + SUBCASE("reduction lhs & rhs (invalid)") { + tl::expected result = get_output_shape( + attrs, make_lhs(o_sum, 1, 1, 1, 1), make_rhs(o_sum, 1, 1, 1, 1)); + + CHECK_MESSAGE( + !result.has_value(), "Unexpected successful value: ", result); + } + + SUBCASE("reduction lhs & n") { + tl::expected result = + get_output_shape(attrs, + make_lhs(o_sum, 1, 1, o_n, 1), + make_rhs(1, o_sum * o_n, 1, 1, 1)); + tl::expected correct = + make_output(o_sum, 1, 1, o_n, 1); + + CHECK(result == correct); + } + + SUBCASE("reduction lhs & reduction rhs & n") { + tl::expected result = + get_output_shape(attrs, + make_lhs(o_sum, o_sum, 1, o_n, 1), + make_rhs(o_sum, o_sum * o_n, 1, 1, 1)); + tl::expected correct = + make_output(o_sum * o_sum, 1, 1, o_n, 1); + + CHECK(result == correct); + } + + SUBCASE("reduction lhs & reduction rhs & n & m") { + tl::expected result = + get_output_shape(attrs, + make_lhs(o_sum, o_sum, 1, o_n, o_m), + make_rhs(o_sum, o_sum * o_n, 1, o_m, 1)); + tl::expected correct = + make_output(o_sum * o_sum * o_m, 1, 1, o_n, 1); + + CHECK(result == correct); + } + } +} diff --git a/lib/op-attrs/test/src/test_conv_2d.cc b/lib/op-attrs/test/src/test_conv_2d.cc new file mode 100644 index 0000000000..b16a26a7b1 --- /dev/null +++ b/lib/op-attrs/test/src/test_conv_2d.cc @@ -0,0 +1,62 @@ +#include "doctest/doctest.h" +#include "op-attrs/ops/conv_2d.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_output_shape(Conv2DAttrs, TensorShape)") { + int out_channels = 4; + int kernel_h = 3; + int kernel_w = 2; + int stride_h = 2; + int stride_w = 2; + int padding_h = 1; + int padding_w = 1; + int groups = 1; + std::optional activation = std::nullopt; + bool use_bias = true; + + Conv2DAttrs attrs = { + /*out_channels=*/out_channels, + /*kernel_h=*/kernel_h, + /*kernel_w=*/kernel_w, + /*stride_h=*/stride_h, + /*stride_w=*/stride_w, + /*padding_h=*/padding_h, + /*padding_w=*/padding_w, + /*groups=*/groups, + /*activation=*/activation, + /*use_bias=*/true, + }; + + size_t num_samples = 7; + size_t input_channels = 6; + size_t input_height = 10; + size_t input_width = 15; + + TensorShape input_shape = { + TensorDims{FFOrdered{ + num_samples, + input_channels, + input_height, + input_width, + }}, + DataType::FLOAT, + }; + + TensorShape result = get_output_shape(attrs, input_shape); + + size_t correct_output_height = 3; + size_t correct_output_width = 6; + + TensorShape correct_output_shape = { + TensorDims{FFOrdered{ + num_samples, + static_cast(out_channels), + correct_output_height, + correct_output_width, + }}, + DataType::FLOAT, + }; + + CHECK(result == correct_output_shape); + } +} diff --git a/lib/op-attrs/test/src/test_dim_ordered.cc b/lib/op-attrs/test/src/test_dim_ordered.cc new file mode 100644 index 0000000000..17f4bae05f --- /dev/null +++ b/lib/op-attrs/test/src/test_dim_ordered.cc @@ -0,0 +1,13 @@ +#include "doctest/doctest.h" +#include "op-attrs/dim_ordered.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE_TEMPLATE("RC", T, int, double, char) { + CHECK(rc::check("generate", + [](FFOrdered ff_dim, DimOrdered dim) {})); + } +} diff --git a/lib/op-attrs/test/src/test_element_binary.cc b/lib/op-attrs/test/src/test_element_binary.cc new file mode 100644 index 0000000000..b1aedbf6b5 --- /dev/null +++ b/lib/op-attrs/test/src/test_element_binary.cc @@ -0,0 +1,162 @@ +#include "op-attrs/ops/element_binary.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "test/utils/doctest.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("EWAdd shape inference") { + size_t d1 = 16; + size_t d2 = 32; + size_t d3 = 24; + + ElementBinaryAttrs attrs = ElementBinaryAttrs{ + OperatorType::EW_ADD, + DataType::FLOAT, + /*should_broadcast_lhs=*/false, + /*should_broadcast_rhs=*/false, + }; + + TensorShape input_lhs = TensorShape{ + TensorDims{ + FFOrdered{ + d1, + d2, + d3, + }, + }, + DataType::FLOAT, + }; + + TensorShape input_rhs = input_lhs; + + SUBCASE("correct") { + tl::expected result = + get_output_shape(attrs, input_lhs, input_rhs); + tl::expected correct = input_lhs; + + CHECK(result == correct); + } + + SUBCASE("mismatched dim size") { + TensorShape incorrect_rhs = input_lhs; + dim_at_idx(incorrect_rhs, ff_dim_t{0}) += 1; + + tl::expected result = + get_output_shape(attrs, input_lhs, incorrect_rhs); + + CHECK_MESSAGE(!result.has_value(), + "Unexpected successful result: ", + result.error()); + } + } + + TEST_CASE("EWAdd parallel shape inference") { + size_t d1 = 16; + size_t d2 = 32; + size_t d3 = 24; + + ElementBinaryAttrs attrs = ElementBinaryAttrs{ + OperatorType::EW_ADD, + DataType::FLOAT, + /*should_broadcast_lhs=*/false, + /*should_broadcast_rhs=*/false, + }; + + TensorShape unpar_lhs = TensorShape{ + TensorDims{ + FFOrdered{ + d1, + d2, + d3, + }, + }, + DataType::FLOAT, + }; + + TensorShape unpar_rhs = unpar_lhs; + tl::expected result_unpar_output = + get_output_shape(attrs, unpar_lhs, unpar_rhs); + REQUIRE(result_unpar_output.has_value()); + TensorShape unpar_output = result_unpar_output.value(); + + auto make_lhs = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_1, + int o_2, + int o_3) { + return lift_to_parallel_with_degrees( + unpar_lhs, o_sum, o_eq, FFOrdered{o_1, o_2, o_3}); + }; + + auto make_rhs = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_1, + int o_2, + int o_3) { + return lift_to_parallel_with_degrees( + unpar_rhs, o_sum, o_eq, FFOrdered{o_1, o_2, o_3}); + }; + + auto make_output = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_1, + int o_2, + int o_3) { + return lift_to_parallel_with_degrees( + unpar_output, o_sum, o_eq, FFOrdered{o_1, o_2, o_3}); + }; + + SUBCASE("data parallelism") { + int degree = 4; + + ParallelTensorShape input_lhs = make_lhs(1, 1, degree, 1, 1); + ParallelTensorShape input_rhs = make_rhs(1, 1, degree, 1, 1); + tl::expected result = + get_output_shape(attrs, input_lhs, input_rhs); + tl::expected correct = + make_output(1, 1, degree, 1, 1); + + CHECK(result == correct); + } + + SUBCASE("reduction parallelism") { + int degree = 4; + + ParallelTensorShape input_lhs = make_lhs(SumDegree{degree}, 1, 1, 1, 1); + ParallelTensorShape input_rhs = make_rhs(SumDegree{degree}, 1, 1, 1, 1); + tl::expected result = + get_output_shape(attrs, input_lhs, input_rhs); + tl::expected correct = + make_output(SumDegree{degree}, 1, 1, 1, 1); + + CHECK(result == correct); + } + + SUBCASE("invalid discard copy parallelism") { + int degree = 4; + + ParallelTensorShape input_lhs = + make_lhs(1, DiscardCopyDegree{degree}, 1, 1, 1); + ParallelTensorShape input_rhs = + make_rhs(1, DiscardCopyDegree{degree}, 1, 1, 1); + tl::expected result = + get_output_shape(attrs, input_lhs, input_rhs); + + CHECK_MESSAGE(!result.has_value(), + "Unexpected successful result: ", + result.error()); + } + + SUBCASE("invalid mismatched parallelism degrees") { + int degree = 4; + + ParallelTensorShape input_lhs = make_lhs(1, 1, 1, degree, 1); + ParallelTensorShape input_rhs = make_rhs(1, 1, 1, 1, degree); + tl::expected result = + get_output_shape(attrs, input_lhs, input_rhs); + + CHECK_MESSAGE(!result.has_value(), + "Unexpected successful result: ", + result.error()); + } + } +} diff --git a/lib/op-attrs/test/src/test_element_unary.cc b/lib/op-attrs/test/src/test_element_unary.cc new file mode 100644 index 0000000000..384dbc1a53 --- /dev/null +++ b/lib/op-attrs/test/src/test_element_unary.cc @@ -0,0 +1,73 @@ +#include "op-attrs/ops/element_unary.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "test/utils/doctest.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("ReLU shape inference") { + size_t d1 = 16; + size_t d2 = 32; + size_t d3 = 24; + + ElementUnaryAttrs attrs = ElementUnaryAttrs{OperatorType::RELU}; + + TensorShape input = TensorShape{ + TensorDims{ + FFOrdered{ + d1, + d2, + d3, + }, + }, + DataType::FLOAT, + }; + + tl::expected result = + get_output_shape(attrs, input); + tl::expected correct = input; + + CHECK(result == correct); + + auto make_i = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_1, + int o_2, + int o_3) { + return lift_to_parallel_with_degrees( + input, o_sum, o_eq, FFOrdered{o_1, o_2, o_3}); + }; + + SUBCASE("partition i.e., sharding parallelism") { + int degree1 = 4; + int degree2 = 8; + ParallelTensorShape par_input = make_i(1, 1, degree1, 1, degree2); + + tl::expected result = + get_output_shape(attrs, par_input); + tl::expected correct = par_input; + + CHECK(result == correct); + } + + SUBCASE("sum degree > 1") { + int degree = 2; + + tl::expected result = + get_output_shape(attrs, make_i(SumDegree{degree}, 1, 1, 1, 1)); + + CHECK_MESSAGE(!result.has_value(), + "Unexpected successful result: ", + result.error()); + } + + SUBCASE("discard copy degree > 1") { + int degree = 2; + + tl::expected result = get_output_shape( + attrs, make_i(1, DiscardCopyDegree{degree}, 1, 1, 1)); + + CHECK_MESSAGE(!result.has_value(), + "Unexpected successful result: ", + result.error()); + } + } +} diff --git a/lib/op-attrs/test/src/test_embedding.cc b/lib/op-attrs/test/src/test_embedding.cc new file mode 100644 index 0000000000..7bce6bd4d9 --- /dev/null +++ b/lib/op-attrs/test/src/test_embedding.cc @@ -0,0 +1,160 @@ +#include "op-attrs/ops/embedding.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "test/utils/doctest.h" +#include "utils/integer_conversions.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Sum embedding shape inference") { + int out_channels = 128; + int num_entries = 1024; + EmbeddingAttrs attrs = EmbeddingAttrs{ + /*num_entries=*/num_entries, + /*out_channels=*/out_channels, + /*aggr=*/AggregateOp::SUM, + /*data_type=*/DataType::FLOAT, + }; + + size_t batch_size = 48; + size_t features_dim = 56; + + TensorShape input = { + TensorDims{FFOrdered{ + batch_size, + features_dim, + }}, + DataType::INT32, + }; + + TensorShape output = TensorShape{ + TensorDims{ + FFOrdered{ + batch_size, + size_t_from_int(out_channels), + }, + }, + DataType::FLOAT, + }; + + TensorShape weights = TensorShape{ + TensorDims{ + FFOrdered{ + size_t_from_int(num_entries), + size_t_from_int(out_channels), + }, + }, + DataType::FLOAT, + }; + + // get_output_shape + { + tl::expected output_result = + get_output_shape(attrs, input); + tl::expected output_correct = output; + CHECK(output_result == output_correct); + } + + // get_weights_shape + { + tl::expected weight_result = + get_weights_shape(attrs, input); + tl::expected weight_correct = weights; + CHECK(weight_result == weight_correct); + } + + auto make_input = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_batch, + int o_features) { + return lift_to_parallel_with_degrees( + input, o_sum, o_eq, FFOrdered{o_batch, o_features}); + }; + + auto make_output = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_batch, + int o_outchannels) { + return lift_to_parallel_with_degrees( + output, o_sum, o_eq, FFOrdered{o_batch, o_outchannels}); + }; + + auto make_weights = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_entries, + int o_outchannels) { + return lift_to_parallel_with_degrees( + weights, o_sum, o_eq, FFOrdered{o_entries, o_outchannels}); + }; + + SUBCASE("data parallelism") { + int degree = 4; + ParallelTensorShape par_input = + make_input(SumDegree{1}, DiscardCopyDegree{1}, degree, 1); + + { + tl::expected result = + get_output_shape(attrs, par_input); + tl::expected correct = + make_output(SumDegree{1}, DiscardCopyDegree{1}, degree, 1); + CHECK(result == correct); + } + + { + tl::expected result = + get_weights_shape(attrs, par_input); + tl::expected correct = + make_weights(SumDegree{1}, DiscardCopyDegree{degree}, 1, 1); + CHECK(result == correct); + } + } + + SUBCASE("input features parallelism") { + int degree = 4; + ParallelTensorShape input = + make_input(SumDegree{1}, DiscardCopyDegree{1}, 1, degree); + + { + tl::expected result = + get_output_shape(attrs, input); + tl::expected correct = + make_output(SumDegree{degree}, DiscardCopyDegree{1}, 1, 1); + CHECK(result == correct); + } + + { + tl::expected result = + get_weights_shape(attrs, input); + tl::expected correct = + make_weights(SumDegree{1}, DiscardCopyDegree{degree}, 1, 1); + CHECK(result == correct); + } + } + + SUBCASE("output channel shard parallelism") { + // NOTE (@lockshaw): in the current (parallel shape inference from just + // input tensor) representation we have to choose between either + // parallelism in the weight channel dimension or in the weight entry + // dimension. For now we choose to represent parallelism in the channel + // dimension, but partitioning in the entry dimension is also potentially + // useful as it produces sum parallelism in the output + int degree = 4; + ParallelTensorShape input = + make_input(SumDegree{1}, DiscardCopyDegree{degree}, 1, 1); + + { + tl::expected result = + get_output_shape(attrs, input); + tl::expected correct = + make_output(SumDegree{1}, DiscardCopyDegree{1}, 1, degree); + CHECK(result == correct); + } + + { + tl::expected result = + get_weights_shape(attrs, input); + tl::expected correct = + make_weights(SumDegree{1}, DiscardCopyDegree{1}, 1, degree); + CHECK(result == correct); + } + } + } +} diff --git a/lib/op-attrs/test/src/test_operator_attrs.cc b/lib/op-attrs/test/src/test_operator_attrs.cc new file mode 100644 index 0000000000..a7724dba69 --- /dev/null +++ b/lib/op-attrs/test/src/test_operator_attrs.cc @@ -0,0 +1,35 @@ +#include "doctest/doctest.h" +#include "op-attrs/computation_graph_op_attrs.dtg.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" +#include "utils/json.h" +#include +#include + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("BatchNormAttrs to/from json") { + BatchNormAttrs correct = BatchNormAttrs{true}; + json j = correct; + auto result = j.get(); + CHECK(result == correct); + } + + TEST_CASE("ComputationGraphAttrs to/from json") { + ComputationGraphOpAttrs correct = + ComputationGraphOpAttrs{BatchNormAttrs{true}}; + json j = correct; + auto result = j.get(); + + CHECK(result == correct); + } + + TEST_CASE("PCGOperatorAttrs to/from json") { + PCGOperatorAttrs correct = PCGOperatorAttrs{RepartitionAttrs{ + /*repartition_dim=*/ff_dim_t{1}, + /*repartition_degree=*/4, + }}; + json j = correct; + auto result = j.get(); + + CHECK(result == correct); + } +} diff --git a/lib/op-attrs/test/src/test_regularizer_attrs.cc b/lib/op-attrs/test/src/test_regularizer_attrs.cc new file mode 100644 index 0000000000..198c3add38 --- /dev/null +++ b/lib/op-attrs/test/src/test_regularizer_attrs.cc @@ -0,0 +1,14 @@ +#include "doctest/doctest.h" +#include "op-attrs/regularizer_attrs.dtg.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("RC") { + CHECK(rc::check("valid variant", [](RegularizerAttrs reg) { + return reg.has() || reg.has(); + })); + } +} diff --git a/lib/pcg/CMakeLists.txt b/lib/pcg/CMakeLists.txt index 81009b0f1f..e1875ca694 100644 --- a/lib/pcg/CMakeLists.txt +++ b/lib/pcg/CMakeLists.txt @@ -13,3 +13,4 @@ ff_add_library( ) add_subdirectory(ffi) +add_subdirectory(test) diff --git a/lib/pcg/include/pcg/computation_graph.dtg.h b/lib/pcg/include/pcg/computation_graph.dtg.h new file mode 100644 index 0000000000..217b940ce6 --- /dev/null +++ b/lib/pcg/include/pcg/computation_graph.dtg.h @@ -0,0 +1,29 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/computation_graph.struct.toml +/* proj-data +{ + "generated_from": "8f1f0e13d75065944f7fe307e12fe280" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_DTG_H + +#include "pcg/dataflow_graph.h" +#include "pcg/layer_attrs.dtg.h" +#include "pcg/tensor_attrs.dtg.h" + +namespace FlexFlow { +struct ComputationGraph { + ComputationGraph() = delete; + ComputationGraph( + ::FlexFlow::DataflowGraph<::FlexFlow::LayerAttrs, + ::FlexFlow::TensorAttrs> const &raw_graph); + + ::FlexFlow::DataflowGraph<::FlexFlow::LayerAttrs, ::FlexFlow::TensorAttrs> + raw_graph; +}; +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_DTG_H diff --git a/lib/pcg/include/pcg/computation_graph.h b/lib/pcg/include/pcg/computation_graph.h index c051fcc8c3..23003641cf 100644 --- a/lib/pcg/include/pcg/computation_graph.h +++ b/lib/pcg/include/pcg/computation_graph.h @@ -1,59 +1,23 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_H #define _FLEXFLOW_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_H -#include "layer.h" -#include "operator_guid_t.h" -#include "tensor.h" -#include "tensor_guid_t.h" -#include "utils/containers.h" -#include "utils/graph.h" -#include "utils/strong_typedef.h" -#include "visit_struct/visit_struct.hpp" +#include "pcg/computation_graph.dtg.h" +#include "pcg/computation_graph/layer_added_result.dtg.h" +#include "pcg/layer_guid_t.dtg.h" +#include "pcg/tensor_attrs.dtg.h" +#include "pcg/tensor_guid_t.dtg.h" namespace FlexFlow { -struct ComputationGraph - : public strong_typedef> { - using strong_typedef::strong_typedef; +ComputationGraph make_empty_computation_graph(); - Layer &at(operator_guid_t const &n) { - return this->value().at(n.value()); - } +std::unordered_set get_layers(ComputationGraph const &); - Layer const &at(operator_guid_t const &n) const { - return this->value().at(n.value()); - } - - Tensor &at(tensor_guid_t const &e) { - return this->value().at(e.value()); - } - - Tensor const &at(tensor_guid_t const &e) const { - return this->value().at(e.value()); - } -}; -CHECK_WELL_BEHAVED_VALUE_TYPE_NO_HASH(ComputationGraph); - -std::vector - traverse_comp_graph_forward(ComputationGraph const &comp_graph); -std::vector - traverse_comp_graph_backward(ComputationGraph const &comp_graph); -std::vector - get_outgoing_tensors(ComputationGraph const &comp_graph, operator_guid_t n); -std::vector - get_incoming_tensors(ComputationGraph const &comp_graph, operator_guid_t n); -operator_guid_t create_node(ComputationGraph &comp_graph, Layer const &layer); -tensor_guid_t create_outgoing_edge(ComputationGraph &comp_graph, - operator_guid_t node, - int idx, - Tensor tensor); - -void connect_incoming_edges(ComputationGraph &comp_graph, - std::vector const &incoming_edges, - operator_guid_t node); -CompGraphOperatorAttrs get_layer_attrs(ComputationGraph const &comp_graph, - operator_guid_t const &n); +LayerAddedResult add_layer(ComputationGraph &computation_graph, + LayerAttrs const &attrs, + std::vector const &inputs, + std::vector const &outputs); +TensorAttrs get_tensor_attrs(ComputationGraph const &, tensor_guid_t const &); } // namespace FlexFlow diff --git a/lib/pcg/include/pcg/computation_graph.struct.toml b/lib/pcg/include/pcg/computation_graph.struct.toml new file mode 100644 index 0000000000..a270cb8fbe --- /dev/null +++ b/lib/pcg/include/pcg/computation_graph.struct.toml @@ -0,0 +1,13 @@ +namespace = "FlexFlow" +name = "ComputationGraph" +features = [ ] + +includes = [ + "pcg/layer_attrs.dtg.h", + "pcg/tensor_attrs.dtg.h", + "pcg/dataflow_graph.h", +] + +[[fields]] +name = "raw_graph" +type = "::FlexFlow::DataflowGraph<::FlexFlow::LayerAttrs, ::FlexFlow::TensorAttrs>" diff --git a/lib/pcg/include/pcg/computation_graph/layer_added_result.dtg.h b/lib/pcg/include/pcg/computation_graph/layer_added_result.dtg.h new file mode 100644 index 0000000000..4fd78f2d44 --- /dev/null +++ b/lib/pcg/include/pcg/computation_graph/layer_added_result.dtg.h @@ -0,0 +1,37 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/computation_graph/layer_added_result.struct.toml +/* proj-data +{ + "generated_from": "15bf9d73ef934599c9b11807d86ae5d4" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_LAYER_ADDED_RESULT_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_LAYER_ADDED_RESULT_DTG_H + +#include "fmt/format.h" +#include "pcg/layer_guid_t.dtg.h" +#include "pcg/tensor_guid_t.dtg.h" +#include +#include + +namespace FlexFlow { +struct LayerAddedResult { + LayerAddedResult() = delete; + LayerAddedResult(::FlexFlow::layer_guid_t const &layer, + std::vector<::FlexFlow::tensor_guid_t> const &outputs); + + bool operator==(LayerAddedResult const &) const; + bool operator!=(LayerAddedResult const &) const; + ::FlexFlow::layer_guid_t layer; + std::vector<::FlexFlow::tensor_guid_t> outputs; +}; +} // namespace FlexFlow + +namespace FlexFlow { +std::string format_as(LayerAddedResult const &); +std::ostream &operator<<(std::ostream &, LayerAddedResult const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_LAYER_ADDED_RESULT_DTG_H diff --git a/lib/pcg/include/pcg/computation_graph/layer_added_result.struct.toml b/lib/pcg/include/pcg/computation_graph/layer_added_result.struct.toml new file mode 100644 index 0000000000..b02e992ba1 --- /dev/null +++ b/lib/pcg/include/pcg/computation_graph/layer_added_result.struct.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "LayerAddedResult" +features = [ + "eq", + "fmt", +] + +includes = [ + "pcg/layer_guid_t.dtg.h", + "pcg/tensor_guid_t.dtg.h", +] + +[[fields]] +name = "layer" +type = "::FlexFlow::layer_guid_t" + +[[fields]] +name = "outputs" +type = "std::vector<::FlexFlow::tensor_guid_t>" diff --git a/lib/pcg/include/pcg/computation_graph_builder.h b/lib/pcg/include/pcg/computation_graph_builder.h index 7ba95d701b..3a1526e9c8 100644 --- a/lib/pcg/include/pcg/computation_graph_builder.h +++ b/lib/pcg/include/pcg/computation_graph_builder.h @@ -1,13 +1,13 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_BUILDER_H #define _FLEXFLOW_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_BUILDER_H -#include "computation_graph.h" -#include "optimizer.h" +#include "pcg/computation_graph.dtg.h" +#include "pcg/initializer_attrs.dtg.h" +#include "pcg/tensor_guid_t.dtg.h" namespace FlexFlow { -struct ComputationGraphBuilder - : public use_visitable_cmp { +struct ComputationGraphBuilder { public: ComputationGraphBuilder(); @@ -95,8 +95,8 @@ struct ComputationGraphBuilder std::optional const &activation = std::nullopt, int groups = 1, bool use_bias = true, - std::optional const &kernel_initializer = std::nullopt, - std::optional const &bias_initializer = std::nullopt, + std::optional const &kernel_initializer = std::nullopt, + std::optional const &bias_initializer = std::nullopt, std::optional const &kernel_regularizer = std::nullopt, std::optional const &name = std::nullopt); // Add a dropout layer @@ -111,7 +111,7 @@ struct ComputationGraphBuilder int outDim, AggregateOp aggr, DataType dtype = DataType::FLOAT, - std::optional const &kernel_initializer = std::nullopt, + std::optional const &kernel_initializer = std::nullopt, std::optional const &name = std::nullopt); // Add a gather layer std::vector @@ -154,15 +154,15 @@ struct ComputationGraphBuilder int a_seq_length_dim = -1, int b_seq_length_dim = -1, std::optional const &name = std::nullopt); - tensor_guid_t - dense(tensor_guid_t const &input, - int outDim, - std::optional activation = std::nullopt, - bool use_bias = true, - DataType data_type = DataType::FLOAT, - std::optional const &kernel_initializer = std::nullopt, - std::optional const &bias_initializer = std::nullopt, - std::optional const &name = std::nullopt); + tensor_guid_t dense( + tensor_guid_t const &input, + int outDim, + std::optional activation = std::nullopt, + bool use_bias = true, + DataType data_type = DataType::FLOAT, + std::optional const &kernel_initializer = std::nullopt, + std::optional const &bias_initializer = std::nullopt, + std::optional const &name = std::nullopt); // Add a cast layer tensor_guid_t cast(tensor_guid_t const &input, DataType dtype, @@ -178,11 +178,11 @@ struct ComputationGraphBuilder bool keepdims, char const *name); // Add a split layer - void split(tensor_guid_t const &input, - tensor_guid_t *outputs, - std::vector const &split, - int axis, - std::optional const &name = std::nullopt); + std::vector + split(tensor_guid_t const &input, + std::vector const &split, + int axis, + std::optional const &name = std::nullopt); // Add a flat layer tensor_guid_t flat(tensor_guid_t const &input, std::optional const &name = std::nullopt); @@ -191,8 +191,6 @@ struct ComputationGraphBuilder int dim = -1, std::optional const &name = std::nullopt); // Create input tensors and constants - tensor_guid_t input(Tensor const &input_tensor, - std::optional const &name = std::nullopt); tensor_guid_t transpose(tensor_guid_t const &input, std::vector const &perm, @@ -208,11 +206,11 @@ struct ComputationGraphBuilder tensor_guid_t reverse(tensor_guid_t const &input, int axis, std::optional const &name = std::nullopt); - void top_k(tensor_guid_t const &input, - tensor_guid_t *outputs, - int k, - bool sorted, - std::optional const &name = std::nullopt); + std::vector + top_k(tensor_guid_t const &input, + int k, + bool sorted, + std::optional const &name = std::nullopt); tensor_guid_t multihead_attention( tensor_guid_t const &query, tensor_guid_t const &key, @@ -225,44 +223,48 @@ struct ComputationGraphBuilder bool bias = true, bool add_bias_kv = false, bool add_zero_attn = false, - std::optional initializer = std::nullopt, + std::optional initializer = std::nullopt, std::optional const &name = std::nullopt); tensor_guid_t create_tensor(TensorShape const &, bool create_grad = true); - Parameter create_weight( + tensor_guid_t create_weight( TensorShape const &, bool create_grad = true, - std::optional const &initializer = std::nullopt, + std::optional const &initializer = std::nullopt, std::optional sync_type = std::nullopt); - std::vector get_outputs(operator_guid_t const &) const; - tensor_guid_t get_output(operator_guid_t const &, int idx) const; - Tensor get_tensor(tensor_guid_t const &) const; + std::vector get_outputs(LayerAttrs const &) const; + tensor_guid_t get_output(LayerAttrs const &, int idx) const; private: - tensor_guid_t broadcast(tensor_guid_t const &, TensorShape const &); + TensorShape get_shape(tensor_guid_t const &) const; - void add_layer(Layer const &layer, - std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs); - tensor_guid_t add_layer( - Layer const &layer, - std::vector const &inputs, - std::vector>> const - &weight_shapes, - TensorShape const &output_shape); - std::vector add_layer( - Layer const &layer, - std::vector const &inputs, - std::vector>> const - &weight_shapes, - std::vector const &output_shapes); + tensor_guid_t broadcast(tensor_guid_t const &, TensorShape const &); tensor_guid_t as_type(tensor_guid_t const &, DataType, std::string const &); + std::vector add_layer(LayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + std::vector const &outputs); + + tensor_guid_t add_layer(LayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + TensorAttrs const &output); + + std::vector add_layer(LayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + std::vector const &outputs); + + tensor_guid_t add_layer(LayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + TensorShape const &output); + + TensorShape get_broadcast_target_shape(std::vector const &); TensorShape get_broadcast_target_shape(std::vector const &); - TensorShape get_shape(tensor_guid_t const &t); - std::vector get_shapes(std::vector const &t); + tensor_guid_t element_binary(OperatorType, tensor_guid_t const &lhs, @@ -293,11 +295,4 @@ struct ComputationGraphBuilder } // namespace FlexFlow -VISITABLE_STRUCT(::FlexFlow::ComputationGraphBuilder, computation_graph); - -namespace FlexFlow { -static_assert( - is_well_behaved_value_type_no_hash::value, ""); -} - #endif diff --git a/lib/pcg/include/pcg/cpu_id_t.dtg.h b/lib/pcg/include/pcg/cpu_id_t.dtg.h new file mode 100644 index 0000000000..a6c81e80b0 --- /dev/null +++ b/lib/pcg/include/pcg/cpu_id_t.dtg.h @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/cpu_id_t.struct.toml +/* proj-data +{ + "generated_from": "a0faf78831febfa3a02929169943d9f5" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_CPU_ID_T_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_CPU_ID_T_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct cpu_id_t { + cpu_id_t() = delete; + cpu_id_t(int const &cpu_index); + + bool operator==(cpu_id_t const &) const; + bool operator!=(cpu_id_t const &) const; + bool operator<(cpu_id_t const &) const; + bool operator>(cpu_id_t const &) const; + bool operator<=(cpu_id_t const &) const; + bool operator>=(cpu_id_t const &) const; + int cpu_index; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::cpu_id_t const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::cpu_id_t from_json(json const &); + static void to_json(json &, FlexFlow::cpu_id_t const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(cpu_id_t const &); +std::ostream &operator<<(std::ostream &, cpu_id_t const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_CPU_ID_T_DTG_H diff --git a/lib/pcg/include/pcg/cpu_id_t.struct.toml b/lib/pcg/include/pcg/cpu_id_t.struct.toml new file mode 100644 index 0000000000..0492a937be --- /dev/null +++ b/lib/pcg/include/pcg/cpu_id_t.struct.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "cpu_id_t" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "cpu_index" +type = "int" diff --git a/lib/pcg/include/pcg/create_grad.dtg.h b/lib/pcg/include/pcg/create_grad.dtg.h new file mode 100644 index 0000000000..494ff06b75 --- /dev/null +++ b/lib/pcg/include/pcg/create_grad.dtg.h @@ -0,0 +1,40 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/create_grad.enum.toml +/* proj-data +{ + "generated_from": "9fd617027e850b6d6db476a49b3e0334" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_CREATE_GRAD_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_CREATE_GRAD_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +enum class CreateGrad { YES, NO }; +std::string format_as(CreateGrad); +std::ostream &operator<<(std::ostream &, CreateGrad); +void to_json(::nlohmann::json &, CreateGrad); +void from_json(::nlohmann::json const &, CreateGrad &); +} // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::CreateGrad) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_CREATE_GRAD_DTG_H diff --git a/lib/pcg/include/pcg/create_grad.enum.toml b/lib/pcg/include/pcg/create_grad.enum.toml new file mode 100644 index 0000000000..20febe49fb --- /dev/null +++ b/lib/pcg/include/pcg/create_grad.enum.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "CreateGrad" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "YES" + +[[values]] +name = "NO" diff --git a/lib/pcg/include/pcg/create_grad.h b/lib/pcg/include/pcg/create_grad.h index 7dd843b76d..5a12d310c2 100644 --- a/lib/pcg/include/pcg/create_grad.h +++ b/lib/pcg/include/pcg/create_grad.h @@ -1,36 +1,8 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_CREATE_GRAD_H #define _FLEXFLOW_PCG_INCLUDE_PCG_CREATE_GRAD_H -#include "utils/fmt.h" +#include "pcg/create_grad_t.h" -namespace FlexFlow { - -enum class CreateGrad { YES, NO }; - -} - -namespace fmt { - -template <> -struct formatter<::FlexFlow::CreateGrad> : formatter { - template - auto format(::FlexFlow::CreateGrad ps, FormatContext &ctx) const - -> decltype(ctx.out()) { - using namespace FlexFlow; - - string_view name = "unknown"; - switch (ps) { - case CreateGrad::YES: - name = "yes"; - break; - case CreateGrad::NO: - name = "no"; - break; - } - return formatter::format(name, ctx); - } -}; - -} // namespace fmt +namespace FlexFlow {} #endif diff --git a/lib/pcg/include/pcg/dataflow_graph.h b/lib/pcg/include/pcg/dataflow_graph.h new file mode 100644 index 0000000000..f649c0444c --- /dev/null +++ b/lib/pcg/include/pcg/dataflow_graph.h @@ -0,0 +1,77 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_DATAFLOW_GRAPH_H + +#include "utils/containers/enumerate_vector.h" +#include "utils/graph.h" + +namespace FlexFlow { + +template +struct DataflowGraph { +public: + DataflowGraph() + : g(OutputLabelledMultiDiGraph::template create< + UnorderedOutputLabelledMultiDiGraph>()) {} + + std::vector + add_operator(NodeLabel const &func, + std::vector const &inputs, + std::vector const &outputs) { + Node n = this->g.add_node(func); + for (auto const &[idx, input] : enumerate_vector(inputs)) { + this->g.add_edge(MultiDiEdge{ + input.src, input.src_idx, n, this->make_port_for_idx(idx)}); + } + + std::vector result; + for (auto const &[idx, label] : enumerate_vector(outputs)) { + MultiDiOutput output = MultiDiOutput{n, this->make_port_for_idx(idx)}; + this->g.add_output(output, label); + result.push_back(output); + } + + return result; + } + + NodePort make_port_for_idx(int idx) { + if (!this->port_mapping.contains_l(idx)) { + this->port_mapping.equate(idx, this->g.add_node_port()); + } + return this->port_mapping.at_l(idx); + } + + NodePort port_for_idx(int idx) const { + return this->port_mapping.at_l(idx); + } + + int idx_for_port(NodePort const &p) const { + return this->port_mapping.at_r(p); + } + + OutputLabelledMultiDiGraphView const & + get_raw_graph() const { + return this->g; + } + + NodeLabel const &at(Node const &n) const { + return this->g.at(n); + } + + OutputLabel const &at(MultiDiOutput const &o) const { + return this->g.at(o); + } + +private: + OutputLabelledMultiDiGraph g; + bidict port_mapping; +}; + +template +std::unordered_set + get_nodes(DataflowGraph const &g) { + return get_nodes(g.get_raw_graph()); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/dataflow_input.dtg.h b/lib/pcg/include/pcg/dataflow_input.dtg.h new file mode 100644 index 0000000000..c698c75c25 --- /dev/null +++ b/lib/pcg/include/pcg/dataflow_input.dtg.h @@ -0,0 +1,101 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/dataflow_input.variant.toml +/* proj-data +{ + "generated_from": "d6a7f4570e36e257383529e9bf9390ec" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DATAFLOW_INPUT_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DATAFLOW_INPUT_DTG_H + +#include "utils/graph/multidiedge.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct DataflowInput { + DataflowInput() = delete; + explicit DataflowInput(::FlexFlow::MultiDiOutput const &); + explicit DataflowInput(int const &); + template + static constexpr bool IsPartOfDataflowInput_v = + std::is_same_v || std::is_same_v; + template + ReturnType visit(Visitor &&v) const { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::MultiDiOutput>()); + return result; + } + case 1: { + ReturnType result = v(this->get()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type DataflowInput", this->index())); + } + } + } + template + ReturnType visit(Visitor &&v) { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::MultiDiOutput>()); + return result; + } + case 1: { + ReturnType result = v(this->get()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type DataflowInput", this->index())); + } + } + } + template + bool has() const { + static_assert(IsPartOfDataflowInput_v, + "DataflowInput::has() expected one of " + "[::FlexFlow::MultiDiOutput, int], received T"); + return std::holds_alternative(this->raw_variant); + } + template + T const &get() const { + static_assert(IsPartOfDataflowInput_v, + "DataflowInput::get() expected one of " + "[::FlexFlow::MultiDiOutput, int], received T"); + return std::get(this->raw_variant); + } + template + T &get() { + static_assert(IsPartOfDataflowInput_v, + "DataflowInput::get() expected one of " + "[::FlexFlow::MultiDiOutput, int], received T"); + return std::get(this->raw_variant); + } + size_t index() const { + return this->raw_variant.index(); + } + bool operator==(DataflowInput const &) const; + bool operator!=(DataflowInput const &) const; + bool operator<(DataflowInput const &) const; + bool operator>(DataflowInput const &) const; + bool operator<=(DataflowInput const &) const; + bool operator>=(DataflowInput const &) const; + std::variant<::FlexFlow::MultiDiOutput, int> raw_variant; +}; +} // namespace FlexFlow +namespace std { +template <> +struct hash<::FlexFlow::DataflowInput> { + size_t operator()(::FlexFlow::DataflowInput const &) const; +}; +} // namespace std + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DATAFLOW_INPUT_DTG_H diff --git a/lib/pcg/include/pcg/dataflow_input.variant.toml b/lib/pcg/include/pcg/dataflow_input.variant.toml new file mode 100644 index 0000000000..ac7c3ae5d7 --- /dev/null +++ b/lib/pcg/include/pcg/dataflow_input.variant.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "DataflowInput" +features = [ + "eq", + "ord", + "hash", + # "json", + # "fmt", +] + +includes = [ + "utils/graph/multidiedge.h" , +] + +[[values]] +type = "::FlexFlow::MultiDiOutput" +key = "internal" + +[[values]] +type = "int" +key = "external" diff --git a/lib/pcg/include/pcg/device_id.h b/lib/pcg/include/pcg/device_id.h index b118d69259..be92be7081 100644 --- a/lib/pcg/include/pcg/device_id.h +++ b/lib/pcg/include/pcg/device_id.h @@ -1,35 +1,21 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_DEVICE_ID_H #define _FLEXFLOW_PCG_INCLUDE_PCG_DEVICE_ID_H -#include "device_type.h" -#include "utils/strong_typedef.h" -#include +#include "pcg/cpu_id_t.dtg.h" +#include "pcg/device_id_t.dtg.h" +#include "pcg/device_type.dtg.h" +#include "pcg/gpu_id_t.dtg.h" namespace FlexFlow { -struct gpu_id_t : strong_typedef { - using strong_typedef::strong_typedef; -}; - -struct cpu_id_t : strong_typedef { - using strong_typedef::strong_typedef; -}; - -using device_id_t = std::variant; device_id_t operator+(device_id_t, size_t); DeviceType get_device_type(device_id_t); gpu_id_t unwrap_gpu(device_id_t); cpu_id_t unwrap_cpu(device_id_t); -device_id_t from_index(int, DeviceType); +device_id_t device_id_from_index(int, DeviceType); } // namespace FlexFlow -MAKE_TYPEDEF_HASHABLE(::FlexFlow::gpu_id_t); -MAKE_TYPEDEF_PRINTABLE(::FlexFlow::gpu_id_t, "gpu_id"); - -MAKE_TYPEDEF_HASHABLE(::FlexFlow::cpu_id_t); -MAKE_TYPEDEF_PRINTABLE(::FlexFlow::cpu_id_t, "cpu_id"); - #endif diff --git a/lib/pcg/include/pcg/device_id_t.dtg.h b/lib/pcg/include/pcg/device_id_t.dtg.h new file mode 100644 index 0000000000..d46f3dd079 --- /dev/null +++ b/lib/pcg/include/pcg/device_id_t.dtg.h @@ -0,0 +1,117 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/device_id_t.variant.toml +/* proj-data +{ + "generated_from": "85870050c742b0159775399ec2be67e3" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DEVICE_ID_T_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DEVICE_ID_T_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "pcg/cpu_id_t.dtg.h" +#include "pcg/gpu_id_t.dtg.h" +#include +#include +#include +#include +#include + +namespace FlexFlow { +struct device_id_t { + device_id_t() = delete; + explicit device_id_t(::FlexFlow::gpu_id_t const &); + explicit device_id_t(::FlexFlow::cpu_id_t const &); + template + static constexpr bool IsPartOfdevice_id_t_v = + std::is_same_v || + std::is_same_v; + template + ReturnType visit(Visitor &&v) const { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::gpu_id_t>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::cpu_id_t>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type device_id_t", this->index())); + } + } + } + template + ReturnType visit(Visitor &&v) { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::gpu_id_t>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::cpu_id_t>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type device_id_t", this->index())); + } + } + } + template + bool has() const { + static_assert(IsPartOfdevice_id_t_v, + "device_id_t::has() expected one of [::FlexFlow::gpu_id_t, " + "::FlexFlow::cpu_id_t], received T"); + return std::holds_alternative(this->raw_variant); + } + template + T const &get() const { + static_assert(IsPartOfdevice_id_t_v, + "device_id_t::get() expected one of [::FlexFlow::gpu_id_t, " + "::FlexFlow::cpu_id_t], received T"); + return std::get(this->raw_variant); + } + template + T &get() { + static_assert(IsPartOfdevice_id_t_v, + "device_id_t::get() expected one of [::FlexFlow::gpu_id_t, " + "::FlexFlow::cpu_id_t], received T"); + return std::get(this->raw_variant); + } + size_t index() const { + return this->raw_variant.index(); + } + bool operator==(device_id_t const &) const; + bool operator!=(device_id_t const &) const; + bool operator<(device_id_t const &) const; + bool operator>(device_id_t const &) const; + bool operator<=(device_id_t const &) const; + bool operator>=(device_id_t const &) const; + std::variant<::FlexFlow::gpu_id_t, ::FlexFlow::cpu_id_t> raw_variant; +}; +} // namespace FlexFlow +namespace std { +template <> +struct hash<::FlexFlow::device_id_t> { + size_t operator()(::FlexFlow::device_id_t const &) const; +}; +} // namespace std +namespace nlohmann { +template <> +struct adl_serializer<::FlexFlow::device_id_t> { + static ::FlexFlow::device_id_t from_json(json const &); + static void to_json(json &, ::FlexFlow::device_id_t const &); +}; +} // namespace nlohmann +namespace FlexFlow { +std::string format_as(::FlexFlow::device_id_t const &); +std::ostream &operator<<(std::ostream &, ::FlexFlow::device_id_t const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DEVICE_ID_T_DTG_H diff --git a/lib/pcg/include/pcg/device_id_t.variant.toml b/lib/pcg/include/pcg/device_id_t.variant.toml new file mode 100644 index 0000000000..71af18919f --- /dev/null +++ b/lib/pcg/include/pcg/device_id_t.variant.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "device_id_t" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "pcg/cpu_id_t.dtg.h", + "pcg/gpu_id_t.dtg.h", +] + +[[values]] +type = "::FlexFlow::gpu_id_t" +key = "gpu" + +[[values]] +type = "::FlexFlow::cpu_id_t" +key = "cpu" diff --git a/lib/pcg/include/pcg/device_type.dtg.h b/lib/pcg/include/pcg/device_type.dtg.h new file mode 100644 index 0000000000..f5e90dc193 --- /dev/null +++ b/lib/pcg/include/pcg/device_type.dtg.h @@ -0,0 +1,40 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/device_type.enum.toml +/* proj-data +{ + "generated_from": "cfe4bc5e9f7c5796b9b90b420c33935f" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DEVICE_TYPE_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DEVICE_TYPE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +enum class DeviceType { GPU, CPU }; +std::string format_as(DeviceType); +std::ostream &operator<<(std::ostream &, DeviceType); +void to_json(::nlohmann::json &, DeviceType); +void from_json(::nlohmann::json const &, DeviceType &); +} // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::DeviceType) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DEVICE_TYPE_DTG_H diff --git a/lib/pcg/include/pcg/device_type.enum.toml b/lib/pcg/include/pcg/device_type.enum.toml new file mode 100644 index 0000000000..67f89fbc6f --- /dev/null +++ b/lib/pcg/include/pcg/device_type.enum.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "DeviceType" +features = [ + "hash", + "json", + "fmt", + "rapidcheck", +] + +[[values]] +name = "GPU" + +[[values]] +name = "CPU" diff --git a/lib/pcg/include/pcg/device_type.h b/lib/pcg/include/pcg/device_type.h deleted file mode 100644 index 3ae374c5ea..0000000000 --- a/lib/pcg/include/pcg/device_type.h +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_DEVICE_TYPE_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_DEVICE_TYPE_H - -#include "utils/fmt.h" - -namespace FlexFlow { - -enum class DeviceType { GPU, CPU }; - -} - -namespace fmt { - -template <> -struct formatter<::FlexFlow::DeviceType> : formatter { - template - auto format(::FlexFlow::DeviceType d, FormatContext &ctx) const - -> decltype(ctx.out()) { - using ::FlexFlow::DeviceType; - - string_view name = "unknown"; - switch (d) { - case DeviceType::GPU: - name = "GPU"; - break; - case DeviceType::CPU: - name = "CPU"; - break; - } - return formatter::format(name, ctx); - } -}; - -} // namespace fmt - -#endif diff --git a/lib/pcg/include/pcg/file_format/v1/data_type.h b/lib/pcg/include/pcg/file_format/v1/data_type_value.h similarity index 64% rename from lib/pcg/include/pcg/file_format/v1/data_type.h rename to lib/pcg/include/pcg/file_format/v1/data_type_value.h index eab188155f..6e4e5abc54 100644 --- a/lib/pcg/include/pcg/file_format/v1/data_type.h +++ b/lib/pcg/include/pcg/file_format/v1/data_type_value.h @@ -9,23 +9,6 @@ namespace FlexFlow { using V1DataTypeValue = std::variant; -enum class V1DataType { - BOOL, - INT32, - INT64, - HALF, - FLOAT, - DOUBLE, -}; - -NLOHMANN_JSON_SERIALIZE_ENUM(V1DataType, - {{V1DataType::BOOL, "BOOL"}, - {V1DataType::INT32, "INT32"}, - {V1DataType::INT64, "INT64"}, - {V1DataType::HALF, "HALF"}, - {V1DataType::FLOAT, "FLOAT"}, - {V1DataType::DOUBLE, "DOUBLE"}}); - } // namespace FlexFlow namespace nlohmann { diff --git a/lib/pcg/include/pcg/file_format/v1/graphs.h b/lib/pcg/include/pcg/file_format/v1/graphs.h index 6bc852b0f1..dad73ce142 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs.h +++ b/lib/pcg/include/pcg/file_format/v1/graphs.h @@ -1,73 +1,23 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_H #define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_H -#include "operator_attrs.h" -#include "parallel_tensor.h" -#include "pcg/computation_graph.h" -#include "pcg/parallel_computation_graph.h" -#include "tensor.h" +#include "pcg/computation_graph.dtg.h" +#include "pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.h" +#include "pcg/layer_attrs.dtg.h" +#include "pcg/parallel_computation_graph.dtg.h" +#include "pcg/parallel_layer_attrs.dtg.h" +#include "pcg/parallel_tensor_attrs.dtg.h" +#include "pcg/tensor_attrs.dtg.h" #include "utils/json.h" -#include "utils/required.h" -#include "utils/visitable.h" namespace FlexFlow { -struct V1GraphOutput { - req srcNode; - req srcIdx; -}; -FF_VISITABLE_STRUCT(V1GraphOutput, srcNode, srcIdx); -CHECK_IS_JSONABLE(V1GraphOutput); - -struct V1GraphEdge { - req srcNode; - req srcIdx; - req dstNode; - req dstIdx; -}; -FF_VISITABLE_STRUCT(V1GraphEdge, srcNode, srcIdx, dstNode, dstIdx); -CHECK_IS_JSONABLE(V1GraphEdge); - -struct V1MultiDiGraph { - req> nodes; - req> ports; - req> edges; -}; -FF_VISITABLE_STRUCT(V1MultiDiGraph, nodes, ports, edges); -CHECK_IS_JSONABLE(V1MultiDiGraph); -V1MultiDiGraph to_v1(MultiDiGraphView const &); -V1MultiDiGraph to_v1(MultiDiGraphView const &, - std::unordered_map const &, - std::unordered_map const &); - -template -struct V1JsonableGraph { - using node_id = size_t; - using tensor_id = size_t; - - req> node_labels; - req> outputs; - req> output_labels; - V1MultiDiGraph graph; -}; - -struct V1Layer { - V1CompGraphOperatorAttrs attrs; - req> name; -}; -FF_VISITABLE_STRUCT(V1Layer, attrs, name); -V1Layer to_v1(Layer const &); - -using V1ComputationGraph = V1JsonableGraph; -FF_VISITABLE_STRUCT( - V1ComputationGraph, node_labels, outputs, output_labels, graph); +using V1ComputationGraph = V1JsonableGraph; CHECK_IS_JSONABLE(V1ComputationGraph); V1ComputationGraph to_v1(ComputationGraph const &); using V1ParallelComputationGraph = - V1JsonableGraph; -FF_VISITABLE_STRUCT( - V1ParallelComputationGraph, node_labels, outputs, output_labels, graph); + V1JsonableGraph; CHECK_IS_JSONABLE(V1ParallelComputationGraph); V1ParallelComputationGraph to_v1(ParallelComputationGraph const &); diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.dtg.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.dtg.h new file mode 100644 index 0000000000..e9238301d0 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.dtg.h @@ -0,0 +1,60 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.struct.toml +/* proj-data +{ + "generated_from": "865097b569b831af049343e933834329" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_GRAPH_EDGE_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_GRAPH_EDGE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include +#include +#include + +namespace FlexFlow { +struct V1GraphEdge { + V1GraphEdge() = delete; + V1GraphEdge(size_t const &srcNode, + size_t const &srcIdx, + size_t const &dstNode, + size_t const &dstIdx); + + bool operator==(V1GraphEdge const &) const; + bool operator!=(V1GraphEdge const &) const; + bool operator<(V1GraphEdge const &) const; + bool operator>(V1GraphEdge const &) const; + bool operator<=(V1GraphEdge const &) const; + bool operator>=(V1GraphEdge const &) const; + size_t srcNode; + size_t srcIdx; + size_t dstNode; + size_t dstIdx; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::V1GraphEdge const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::V1GraphEdge from_json(json const &); + static void to_json(json &, FlexFlow::V1GraphEdge const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(V1GraphEdge const &); +std::ostream &operator<<(std::ostream &, V1GraphEdge const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_GRAPH_EDGE_DTG_H diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.struct.toml new file mode 100644 index 0000000000..b0d2546977 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.struct.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "V1GraphEdge" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +[[fields]] +name = "srcNode" +type = "size_t" + +[[fields]] +name = "srcIdx" +type = "size_t" + +[[fields]] +name = "dstNode" +type = "size_t" + +[[fields]] +name = "dstIdx" +type = "size_t" diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_output.dtg.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_output.dtg.h new file mode 100644 index 0000000000..730282bdb9 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_output.dtg.h @@ -0,0 +1,55 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_output.struct.toml +/* proj-data +{ + "generated_from": "05ff8401c3d976ea2220899edb8dfe3a" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_GRAPH_OUTPUT_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_GRAPH_OUTPUT_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include +#include +#include + +namespace FlexFlow { +struct V1GraphOutput { + V1GraphOutput() = delete; + V1GraphOutput(size_t const &srcNode, size_t const &srcIdx); + + bool operator==(V1GraphOutput const &) const; + bool operator!=(V1GraphOutput const &) const; + bool operator<(V1GraphOutput const &) const; + bool operator>(V1GraphOutput const &) const; + bool operator<=(V1GraphOutput const &) const; + bool operator>=(V1GraphOutput const &) const; + size_t srcNode; + size_t srcIdx; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::V1GraphOutput const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::V1GraphOutput from_json(json const &); + static void to_json(json &, FlexFlow::V1GraphOutput const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(V1GraphOutput const &); +std::ostream &operator<<(std::ostream &, V1GraphOutput const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_GRAPH_OUTPUT_DTG_H diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_output.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_output.struct.toml new file mode 100644 index 0000000000..ba41f7e43f --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_output.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "V1GraphOutput" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +[[fields]] +name = "srcNode" +type = "size_t" + +[[fields]] +name = "srcIdx" +type = "size_t" diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.h new file mode 100644 index 0000000000..f183a14a9e --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.h @@ -0,0 +1,109 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.struct.toml +/* proj-data +{ + "generated_from": "0595a9f5a6bc19f9a170cb0e42c4202d" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_JSONABLE_GRAPH_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_JSONABLE_GRAPH_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "pcg/file_format/v1/graphs/v1_graph_output.dtg.h" +#include "pcg/file_format/v1/graphs/v1_multidigraph.dtg.h" +#include +#include +#include + +namespace FlexFlow { +template +struct V1JsonableGraph { + V1JsonableGraph() = delete; + V1JsonableGraph( + std::unordered_map const &node_labels, + std::unordered_map const &outputs, + std::unordered_map const &output_labels, + ::FlexFlow::V1MultiDiGraph const &graph); + + std::unordered_map node_labels; + std::unordered_map outputs; + std::unordered_map output_labels; + ::FlexFlow::V1MultiDiGraph graph; +}; +} // namespace FlexFlow + +namespace nlohmann { +template +struct adl_serializer> { + static FlexFlow::V1JsonableGraph from_json(json const &); + static void to_json(json &, + FlexFlow::V1JsonableGraph const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +template +std::string format_as(V1JsonableGraph const &); +template +std::ostream &operator<<(std::ostream &, + V1JsonableGraph const &); +} // namespace FlexFlow + +namespace FlexFlow { +template +V1JsonableGraph::V1JsonableGraph( + std::unordered_map const &node_labels, + std::unordered_map const &outputs, + std::unordered_map const &output_labels, + ::FlexFlow::V1MultiDiGraph const &graph) + : node_labels(node_labels), outputs(outputs), output_labels(output_labels), + graph(graph) {} +} // namespace FlexFlow + +namespace nlohmann { +template +FlexFlow::V1JsonableGraph + adl_serializer>::from_json( + json const &j) { + return { + j.at("node_labels").template get>(), + j.at("outputs") + .template get< + std::unordered_map>(), + j.at("output_labels").template get>(), + j.at("graph").template get<::FlexFlow::V1MultiDiGraph>()}; +} +template +void adl_serializer>::to_json( + json &j, FlexFlow::V1JsonableGraph const &v) { + j["__type"] = "V1JsonableGraph"; + j["node_labels"] = v.node_labels; + j["outputs"] = v.outputs; + j["output_labels"] = v.output_labels; + j["graph"] = v.graph; +} +} // namespace nlohmann + +namespace FlexFlow { +template +std::string format_as(V1JsonableGraph const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +template +std::ostream &operator<<(std::ostream &s, + V1JsonableGraph const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_JSONABLE_GRAPH_DTG_H diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.struct.toml new file mode 100644 index 0000000000..ad9ba21c60 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.struct.toml @@ -0,0 +1,38 @@ +namespace = "FlexFlow" +name = "V1JsonableGraph" +features = [ + # "eq", + # "ord", + # "hash", + "json", + # "rapidcheck", + "fmt", +] + +template_params = [ + "NodeT", + "TensorT", +] + +includes = [ + "", + "pcg/file_format/v1/graphs/v1_multidigraph.dtg.h", + "pcg/file_format/v1/graphs/v1_graph_output.dtg.h", +] + +[[fields]] +name = "node_labels" +type = "std::unordered_map" + +[[fields]] +name = "outputs" +type = "std::unordered_map" + +[[fields]] +name = "output_labels" +type = "std::unordered_map" + +[[fields]] +name = "graph" +type = "::FlexFlow::V1MultiDiGraph" + diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.dtg.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.dtg.h new file mode 100644 index 0000000000..5d7edcf1d8 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.dtg.h @@ -0,0 +1,47 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.struct.toml +/* proj-data +{ + "generated_from": "fb1033385645e54a19c9b44cef0be04b" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_MULTIDIGRAPH_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_MULTIDIGRAPH_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "pcg/file_format/v1/graphs/v1_graph_edge.dtg.h" +#include "utils/fmt.h" +#include +#include +#include + +namespace FlexFlow { +struct V1MultiDiGraph { + V1MultiDiGraph() = delete; + V1MultiDiGraph(std::vector const &nodes, + std::vector const &ports, + std::unordered_set<::FlexFlow::V1GraphEdge> const &edges); + + std::vector nodes; + std::vector ports; + std::unordered_set<::FlexFlow::V1GraphEdge> edges; +}; +} // namespace FlexFlow + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::V1MultiDiGraph from_json(json const &); + static void to_json(json &, FlexFlow::V1MultiDiGraph const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(V1MultiDiGraph const &); +std::ostream &operator<<(std::ostream &, V1MultiDiGraph const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_MULTIDIGRAPH_DTG_H diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.h new file mode 100644 index 0000000000..49ff850a29 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_MULTIDIGRAPH_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_MULTIDIGRAPH_H + +#include "pcg/file_format/v1/graphs/v1_multidigraph.dtg.h" +#include "utils/graph.h" + +namespace FlexFlow { + +V1MultiDiGraph to_v1(MultiDiGraphView const &); +V1MultiDiGraph to_v1(MultiDiGraphView const &, + std::unordered_map const &, + std::unordered_map const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.struct.toml new file mode 100644 index 0000000000..9650f3bd43 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.struct.toml @@ -0,0 +1,29 @@ +namespace = "FlexFlow" +name = "V1MultiDiGraph" +features = [ + # "eq", + # "ord", + # "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "", + "", + "pcg/file_format/v1/graphs/v1_graph_edge.dtg.h", + "utils/fmt.h", +] + +[[fields]] +name = "nodes" +type = "std::vector" + +[[fields]] +name = "ports" +type = "std::vector" + +[[fields]] +name = "edges" +type = "std::unordered_set<::FlexFlow::V1GraphEdge>" diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.dtg.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.dtg.h new file mode 100644 index 0000000000..7e5554d44a --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.dtg.h @@ -0,0 +1,45 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.struct.toml +/* proj-data +{ + "generated_from": "5bfd7d8755cfd8cd9dbf57d5c367038e" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_OPERATOR_GRAPH_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_OPERATOR_GRAPH_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "pcg/file_format/v1/graphs/v1_graph_edge.dtg.h" +#include "utils/fmt.h" +#include +#include +#include + +namespace FlexFlow { +struct V1OperatorGraph { + V1OperatorGraph() = delete; + V1OperatorGraph(std::vector const &nodes, + std::unordered_set<::FlexFlow::V1GraphEdge> const &edges); + + std::vector nodes; + std::unordered_set<::FlexFlow::V1GraphEdge> edges; +}; +} // namespace FlexFlow + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::V1OperatorGraph from_json(json const &); + static void to_json(json &, FlexFlow::V1OperatorGraph const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(V1OperatorGraph const &); +std::ostream &operator<<(std::ostream &, V1OperatorGraph const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_OPERATOR_GRAPH_DTG_H diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.struct.toml new file mode 100644 index 0000000000..61dc45ae2e --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.struct.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "V1OperatorGraph" +features = [ + # "eq", + # "ord", + # "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "", + "", + "pcg/file_format/v1/graphs/v1_graph_edge.dtg.h", + "utils/fmt.h", +] + +[[fields]] +name = "nodes" +type = "std::vector" + +[[fields]] +name = "edges" +type = "std::unordered_set<::FlexFlow::V1GraphEdge>" diff --git a/lib/pcg/include/pcg/file_format/v1/initializer.h b/lib/pcg/include/pcg/file_format/v1/initializer.h deleted file mode 100644 index 21af7d55e0..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/initializer.h +++ /dev/null @@ -1,57 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_INITIALIZER_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_INITIALIZER_H - -#include "data_type.h" -#include "utils/json.h" -#include "utils/required.h" -#include "utils/variant.h" -#include "utils/visitable.h" -#include "visit_struct/visit_struct_intrusive.hpp" - -namespace FlexFlow { - -struct V1GlorotInitializer { - req seed; -}; -FF_VISITABLE_STRUCT(V1GlorotInitializer, seed); - -struct V1ZeroInitializer {}; -FF_VISITABLE_STRUCT(V1ZeroInitializer); - -struct V1UniformInitializer { - int seed; - float min_val; - req max_val; -}; -FF_VISITABLE_STRUCT(V1UniformInitializer, seed, min_val, max_val); - -struct V1NormInitializer { - int seed; - float mean; - req stddev; -}; -FF_VISITABLE_STRUCT(V1NormInitializer, seed, mean, stddev); - -struct V1ConstantInitializer { - req value; -}; -FF_VISITABLE_STRUCT(V1ConstantInitializer, value); - -using V1Initializer = std::variant; - -} // namespace FlexFlow - -namespace FlexFlow { -CHECK_IS_JSONABLE(V1GlorotInitializer); -CHECK_IS_JSONABLE(V1ZeroInitializer); -CHECK_IS_JSONABLE(V1UniformInitializer); -CHECK_IS_JSONABLE(V1NormInitializer); -CHECK_IS_JSONABLE(V1ConstantInitializer); -CHECK_IS_JSONABLE(V1Initializer); -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/file_format/v1/operator_attrs.h b/lib/pcg/include/pcg/file_format/v1/operator_attrs.h deleted file mode 100644 index 2830fbd301..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/operator_attrs.h +++ /dev/null @@ -1,20 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPERATOR_ATTRS_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPERATOR_ATTRS_H - -#include "utils/json.h" -#include - -namespace FlexFlow { - -struct V1Conv2DAttrs {}; -FF_VISITABLE_STRUCT(V1Conv2DAttrs); - -static_assert( - std::is_same, std::tuple<>>::value, ""); - -using V1CompGraphOperatorAttrs = std::variant; -using V1PCGOperatorAttrs = std::variant; - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/file_format/v1/parallel_tensor.h b/lib/pcg/include/pcg/file_format/v1/parallel_tensor.h deleted file mode 100644 index c215569b21..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/parallel_tensor.h +++ /dev/null @@ -1,37 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_PARALLEL_TENSOR_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_PARALLEL_TENSOR_H - -#include "data_type.h" -#include "initializer.h" -#include "param_sync.h" -#include "utils/json.h" -#include "utils/variant.h" -#include "utils/visitable.h" - -namespace FlexFlow { - -struct V1ParallelDim { - size_t size; - int degree; - req is_replica_dim; -}; -FF_VISITABLE_STRUCT(V1ParallelDim, size, degree, is_replica_dim); - -struct V1ParallelTensorShape { - std::vector dims; - req data_type; -}; -FF_VISITABLE_STRUCT(V1ParallelTensorShape, dims, data_type); - -struct V1ParallelTensor { - V1ParallelTensorShape shape; - std::optional sync_type; - std::optional initializer; - req create_grad; -}; -FF_VISITABLE_STRUCT( - V1ParallelTensor, shape, sync_type, initializer, create_grad); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/file_format/v1/param_sync.h b/lib/pcg/include/pcg/file_format/v1/param_sync.h deleted file mode 100644 index 32769a8d20..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/param_sync.h +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef _FLEXFLOW_PCG_FILE_FORMAT_V1_PARAM_SYNC_H -#define _FLEXFLOW_PCG_FILE_FORMAT_V1_PARAM_SYNC_H - -#include "utils/json.h" - -namespace FlexFlow { - -enum class V1ParamSync { PARAM_SERVER, NCCL }; - -NLOHMANN_JSON_SERIALIZE_ENUM(V1ParamSync, - {{V1ParamSync::PARAM_SERVER, "PARAM_SERVER"}, - {V1ParamSync::NCCL, "NCCL"}}); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/file_format/v1/tensor.h b/lib/pcg/include/pcg/file_format/v1/tensor.h deleted file mode 100644 index c304a41401..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/tensor.h +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_TENSOR_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_TENSOR_H - -#include "data_type.h" -#include "initializer.h" -#include "op-attrs/tensor_shape.h" -#include "param_sync.h" -#include "pcg/tensor.h" -#include "utils/visitable.h" -#include - -namespace FlexFlow { - -struct V1TensorShape { - std::vector dims; - req data_type; -}; -FF_VISITABLE_STRUCT(V1TensorShape, dims, data_type); -CHECK_IS_JSONABLE(V1TensorShape); -V1TensorShape to_v1(TensorShape const &); - -struct V1Tensor { - V1TensorShape shape; - std::optional initializer; - bool create_gradients; - std::optional sync_type; - req> name; -}; -FF_VISITABLE_STRUCT( - V1Tensor, shape, initializer, create_gradients, sync_type, name); -CHECK_IS_JSONABLE(V1Tensor); -V1Tensor to_v1(Tensor const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/gpu_id_t.dtg.h b/lib/pcg/include/pcg/gpu_id_t.dtg.h new file mode 100644 index 0000000000..f0847848ca --- /dev/null +++ b/lib/pcg/include/pcg/gpu_id_t.dtg.h @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/gpu_id_t.struct.toml +/* proj-data +{ + "generated_from": "022355e43f43141d332be50ea3080ee2" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_GPU_ID_T_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_GPU_ID_T_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct gpu_id_t { + gpu_id_t() = delete; + gpu_id_t(int const &gpu_index); + + bool operator==(gpu_id_t const &) const; + bool operator!=(gpu_id_t const &) const; + bool operator<(gpu_id_t const &) const; + bool operator>(gpu_id_t const &) const; + bool operator<=(gpu_id_t const &) const; + bool operator>=(gpu_id_t const &) const; + int gpu_index; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::gpu_id_t const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::gpu_id_t from_json(json const &); + static void to_json(json &, FlexFlow::gpu_id_t const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(gpu_id_t const &); +std::ostream &operator<<(std::ostream &, gpu_id_t const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_GPU_ID_T_DTG_H diff --git a/lib/pcg/include/pcg/gpu_id_t.struct.toml b/lib/pcg/include/pcg/gpu_id_t.struct.toml new file mode 100644 index 0000000000..170dbb96fa --- /dev/null +++ b/lib/pcg/include/pcg/gpu_id_t.struct.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "gpu_id_t" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "gpu_index" +type = "int" diff --git a/lib/pcg/include/pcg/initializer.h b/lib/pcg/include/pcg/initializer.h deleted file mode 100644 index 6913289653..0000000000 --- a/lib/pcg/include/pcg/initializer.h +++ /dev/null @@ -1,50 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_INITIALIZER_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_INITIALIZER_H - -#include "op-attrs/datatype.h" -#include "utils/required.h" -#include "utils/visitable.h" - -namespace FlexFlow { - -struct GlorotUniform { - req seed; - /* float scale; */ - /* DataType data_type; */ -}; -FF_VISITABLE_STRUCT(GlorotUniform, seed); - -struct ZeroInitializer { - ZeroInitializer() = default; -}; -FF_VISITABLE_STRUCT(ZeroInitializer); - -struct UniformInitializer { - int seed; - float min_val; - req max_val; -}; -FF_VISITABLE_STRUCT(UniformInitializer, seed, min_val, max_val); - -struct NormInitializer { - int seed; - float mean; - req stddev; -}; -FF_VISITABLE_STRUCT(NormInitializer, seed, mean, stddev); - -struct ConstantInitializer { - req value; -}; -FF_VISITABLE_STRUCT(ConstantInitializer, value); - -using Initializer = std::variant; -CHECK_WELL_BEHAVED_VALUE_TYPE(Initializer); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/initializer_attrs.dtg.h b/lib/pcg/include/pcg/initializer_attrs.dtg.h new file mode 100644 index 0000000000..7f5a470a90 --- /dev/null +++ b/lib/pcg/include/pcg/initializer_attrs.dtg.h @@ -0,0 +1,169 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/initializer_attrs.variant.toml +/* proj-data +{ + "generated_from": "f66f3a89ea937e96a058d83ab52e2826" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZER_ATTRS_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZER_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "pcg/initializers/constant_initializer_attrs.dtg.h" +#include "pcg/initializers/glorot_uniform_attrs.dtg.h" +#include "pcg/initializers/norm_initializer_attrs.dtg.h" +#include "pcg/initializers/uniform_initializer_attrs.dtg.h" +#include "pcg/initializers/zero_initializer_attrs.dtg.h" +#include +#include +#include +#include +#include + +namespace FlexFlow { +struct InitializerAttrs { + InitializerAttrs() = delete; + explicit InitializerAttrs(::FlexFlow::GlorotUniformAttrs const &); + explicit InitializerAttrs(::FlexFlow::ZeroInitializerAttrs const &); + explicit InitializerAttrs(::FlexFlow::UniformInitializerAttrs const &); + explicit InitializerAttrs(::FlexFlow::NormInitializerAttrs const &); + explicit InitializerAttrs(::FlexFlow::ConstantInitializerAttrs const &); + template + static constexpr bool IsPartOfInitializerAttrs_v = + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v; + template + ReturnType visit(Visitor &&v) const { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::GlorotUniformAttrs>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::ZeroInitializerAttrs>()); + return result; + } + case 2: { + ReturnType result = v(this->get<::FlexFlow::UniformInitializerAttrs>()); + return result; + } + case 3: { + ReturnType result = v(this->get<::FlexFlow::NormInitializerAttrs>()); + return result; + } + case 4: { + ReturnType result = + v(this->get<::FlexFlow::ConstantInitializerAttrs>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type InitializerAttrs", this->index())); + } + } + } + template + ReturnType visit(Visitor &&v) { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::GlorotUniformAttrs>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::ZeroInitializerAttrs>()); + return result; + } + case 2: { + ReturnType result = v(this->get<::FlexFlow::UniformInitializerAttrs>()); + return result; + } + case 3: { + ReturnType result = v(this->get<::FlexFlow::NormInitializerAttrs>()); + return result; + } + case 4: { + ReturnType result = + v(this->get<::FlexFlow::ConstantInitializerAttrs>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type InitializerAttrs", this->index())); + } + } + } + template + bool has() const { + static_assert( + IsPartOfInitializerAttrs_v, + "InitializerAttrs::has() expected one of " + "[::FlexFlow::GlorotUniformAttrs, ::FlexFlow::ZeroInitializerAttrs, " + "::FlexFlow::UniformInitializerAttrs, " + "::FlexFlow::NormInitializerAttrs, " + "::FlexFlow::ConstantInitializerAttrs], received T"); + return std::holds_alternative(this->raw_variant); + } + template + T const &get() const { + static_assert( + IsPartOfInitializerAttrs_v, + "InitializerAttrs::get() expected one of " + "[::FlexFlow::GlorotUniformAttrs, ::FlexFlow::ZeroInitializerAttrs, " + "::FlexFlow::UniformInitializerAttrs, " + "::FlexFlow::NormInitializerAttrs, " + "::FlexFlow::ConstantInitializerAttrs], received T"); + return std::get(this->raw_variant); + } + template + T &get() { + static_assert( + IsPartOfInitializerAttrs_v, + "InitializerAttrs::get() expected one of " + "[::FlexFlow::GlorotUniformAttrs, ::FlexFlow::ZeroInitializerAttrs, " + "::FlexFlow::UniformInitializerAttrs, " + "::FlexFlow::NormInitializerAttrs, " + "::FlexFlow::ConstantInitializerAttrs], received T"); + return std::get(this->raw_variant); + } + size_t index() const { + return this->raw_variant.index(); + } + bool operator==(InitializerAttrs const &) const; + bool operator!=(InitializerAttrs const &) const; + bool operator<(InitializerAttrs const &) const; + bool operator>(InitializerAttrs const &) const; + bool operator<=(InitializerAttrs const &) const; + bool operator>=(InitializerAttrs const &) const; + std::variant<::FlexFlow::GlorotUniformAttrs, + ::FlexFlow::ZeroInitializerAttrs, + ::FlexFlow::UniformInitializerAttrs, + ::FlexFlow::NormInitializerAttrs, + ::FlexFlow::ConstantInitializerAttrs> + raw_variant; +}; +} // namespace FlexFlow +namespace std { +template <> +struct hash<::FlexFlow::InitializerAttrs> { + size_t operator()(::FlexFlow::InitializerAttrs const &) const; +}; +} // namespace std +namespace nlohmann { +template <> +struct adl_serializer<::FlexFlow::InitializerAttrs> { + static ::FlexFlow::InitializerAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::InitializerAttrs const &); +}; +} // namespace nlohmann +namespace FlexFlow { +std::string format_as(::FlexFlow::InitializerAttrs const &); +std::ostream &operator<<(std::ostream &, ::FlexFlow::InitializerAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZER_ATTRS_DTG_H diff --git a/lib/pcg/include/pcg/initializer_attrs.variant.toml b/lib/pcg/include/pcg/initializer_attrs.variant.toml new file mode 100644 index 0000000000..14a5cfdcac --- /dev/null +++ b/lib/pcg/include/pcg/initializer_attrs.variant.toml @@ -0,0 +1,37 @@ +namespace = "FlexFlow" +name = "InitializerAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "pcg/initializers/glorot_uniform_attrs.dtg.h", + "pcg/initializers/zero_initializer_attrs.dtg.h", + "pcg/initializers/uniform_initializer_attrs.dtg.h", + "pcg/initializers/norm_initializer_attrs.dtg.h", + "pcg/initializers/constant_initializer_attrs.dtg.h", +] + +[[values]] +type = "::FlexFlow::GlorotUniformAttrs" +key = "glorot_uniform" + +[[values]] +type = "::FlexFlow::ZeroInitializerAttrs" +key = "zero" + +[[values]] +type = "::FlexFlow::UniformInitializerAttrs" +key = "uniform" + +[[values]] +type = "::FlexFlow::NormInitializerAttrs" +key = "normal" + +[[values]] +type = "::FlexFlow::ConstantInitializerAttrs" +key = "constant" diff --git a/lib/pcg/include/pcg/initializers/constant_initializer_attrs.dtg.h b/lib/pcg/include/pcg/initializers/constant_initializer_attrs.dtg.h new file mode 100644 index 0000000000..1eb9eb8834 --- /dev/null +++ b/lib/pcg/include/pcg/initializers/constant_initializer_attrs.dtg.h @@ -0,0 +1,56 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml +/* proj-data +{ + "generated_from": "0162b9c49fe6cbfc65410c6fa8dec427" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_CONSTANT_INITIALIZER_ATTRS_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_CONSTANT_INITIALIZER_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/datatype.h" +#include "utils/json.h" +#include +#include +#include + +namespace FlexFlow { +struct ConstantInitializerAttrs { + ConstantInitializerAttrs() = delete; + ConstantInitializerAttrs(::FlexFlow::DataTypeValue const &value); + + bool operator==(ConstantInitializerAttrs const &) const; + bool operator!=(ConstantInitializerAttrs const &) const; + bool operator<(ConstantInitializerAttrs const &) const; + bool operator>(ConstantInitializerAttrs const &) const; + bool operator<=(ConstantInitializerAttrs const &) const; + bool operator>=(ConstantInitializerAttrs const &) const; + ::FlexFlow::DataTypeValue value; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ConstantInitializerAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ConstantInitializerAttrs from_json(json const &); + static void to_json(json &, FlexFlow::ConstantInitializerAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ConstantInitializerAttrs const &); +std::ostream &operator<<(std::ostream &, ConstantInitializerAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_CONSTANT_INITIALIZER_ATTRS_DTG_H diff --git a/lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml b/lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml new file mode 100644 index 0000000000..3a80559d7b --- /dev/null +++ b/lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "ConstantInitializerAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/datatype.h", + "utils/json.h", +] + +[[fields]] +name = "value" +type = "::FlexFlow::DataTypeValue" diff --git a/lib/pcg/include/pcg/initializers/glorot_uniform_attrs.dtg.h b/lib/pcg/include/pcg/initializers/glorot_uniform_attrs.dtg.h new file mode 100644 index 0000000000..04851fb333 --- /dev/null +++ b/lib/pcg/include/pcg/initializers/glorot_uniform_attrs.dtg.h @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/initializers/glorot_uniform_attrs.struct.toml +/* proj-data +{ + "generated_from": "a268b411b6d378faa11e60c8517d7be5" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_GLOROT_UNIFORM_ATTRS_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_GLOROT_UNIFORM_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct GlorotUniformAttrs { + GlorotUniformAttrs() = delete; + GlorotUniformAttrs(int const &seed); + + bool operator==(GlorotUniformAttrs const &) const; + bool operator!=(GlorotUniformAttrs const &) const; + bool operator<(GlorotUniformAttrs const &) const; + bool operator>(GlorotUniformAttrs const &) const; + bool operator<=(GlorotUniformAttrs const &) const; + bool operator>=(GlorotUniformAttrs const &) const; + int seed; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::GlorotUniformAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::GlorotUniformAttrs from_json(json const &); + static void to_json(json &, FlexFlow::GlorotUniformAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(GlorotUniformAttrs const &); +std::ostream &operator<<(std::ostream &, GlorotUniformAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_GLOROT_UNIFORM_ATTRS_DTG_H diff --git a/lib/pcg/include/pcg/initializers/glorot_uniform_attrs.struct.toml b/lib/pcg/include/pcg/initializers/glorot_uniform_attrs.struct.toml new file mode 100644 index 0000000000..de7f9141b0 --- /dev/null +++ b/lib/pcg/include/pcg/initializers/glorot_uniform_attrs.struct.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "GlorotUniformAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "seed" +type = "int" diff --git a/lib/pcg/include/pcg/initializers/norm_initializer_attrs.dtg.h b/lib/pcg/include/pcg/initializers/norm_initializer_attrs.dtg.h new file mode 100644 index 0000000000..e1d3e59ed7 --- /dev/null +++ b/lib/pcg/include/pcg/initializers/norm_initializer_attrs.dtg.h @@ -0,0 +1,64 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/initializers/norm_initializer_attrs.struct.toml +/* proj-data +{ + "generated_from": "6843fc9ca02aea2b40e57dbc497f99ac" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_NORM_INITIALIZER_ATTRS_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_NORM_INITIALIZER_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct NormInitializerAttrs { + NormInitializerAttrs() = delete; + NormInitializerAttrs(int const &seed, float const &mean, float const &stddev); + + bool operator==(NormInitializerAttrs const &) const; + bool operator!=(NormInitializerAttrs const &) const; + bool operator<(NormInitializerAttrs const &) const; + bool operator>(NormInitializerAttrs const &) const; + bool operator<=(NormInitializerAttrs const &) const; + bool operator>=(NormInitializerAttrs const &) const; + int seed; + float mean; + float stddev; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::NormInitializerAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::NormInitializerAttrs from_json(json const &); + static void to_json(json &, FlexFlow::NormInitializerAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(NormInitializerAttrs const &); +std::ostream &operator<<(std::ostream &, NormInitializerAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_NORM_INITIALIZER_ATTRS_DTG_H diff --git a/lib/pcg/include/pcg/initializers/norm_initializer_attrs.struct.toml b/lib/pcg/include/pcg/initializers/norm_initializer_attrs.struct.toml new file mode 100644 index 0000000000..ec138de63e --- /dev/null +++ b/lib/pcg/include/pcg/initializers/norm_initializer_attrs.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "NormInitializerAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "seed" +type = "int" + +[[fields]] +name = "mean" +type = "float" + +[[fields]] +name = "stddev" +type = "float" diff --git a/lib/pcg/include/pcg/initializers/uniform_initializer_attrs.dtg.h b/lib/pcg/include/pcg/initializers/uniform_initializer_attrs.dtg.h new file mode 100644 index 0000000000..1f4deada06 --- /dev/null +++ b/lib/pcg/include/pcg/initializers/uniform_initializer_attrs.dtg.h @@ -0,0 +1,58 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/initializers/uniform_initializer_attrs.struct.toml +/* proj-data +{ + "generated_from": "f887e1db5d5dc710793ec5fa99bb7cd4" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_UNIFORM_INITIALIZER_ATTRS_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_UNIFORM_INITIALIZER_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include +#include +#include + +namespace FlexFlow { +struct UniformInitializerAttrs { + UniformInitializerAttrs() = delete; + UniformInitializerAttrs(int const &seed, + float const &min_val, + float const &max_val); + + bool operator==(UniformInitializerAttrs const &) const; + bool operator!=(UniformInitializerAttrs const &) const; + bool operator<(UniformInitializerAttrs const &) const; + bool operator>(UniformInitializerAttrs const &) const; + bool operator<=(UniformInitializerAttrs const &) const; + bool operator>=(UniformInitializerAttrs const &) const; + int seed; + float min_val; + float max_val; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::UniformInitializerAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::UniformInitializerAttrs from_json(json const &); + static void to_json(json &, FlexFlow::UniformInitializerAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(UniformInitializerAttrs const &); +std::ostream &operator<<(std::ostream &, UniformInitializerAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_UNIFORM_INITIALIZER_ATTRS_DTG_H diff --git a/lib/pcg/include/pcg/initializers/uniform_initializer_attrs.struct.toml b/lib/pcg/include/pcg/initializers/uniform_initializer_attrs.struct.toml new file mode 100644 index 0000000000..11a6597c0a --- /dev/null +++ b/lib/pcg/include/pcg/initializers/uniform_initializer_attrs.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "UniformInitializerAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +[[fields]] +name = "seed" +type = "int" + +[[fields]] +name = "min_val" +type = "float" + +[[fields]] +name = "max_val" +type = "float" diff --git a/lib/pcg/include/pcg/initializers/zero_initializer_attrs.dtg.h b/lib/pcg/include/pcg/initializers/zero_initializer_attrs.dtg.h new file mode 100644 index 0000000000..f3086ea087 --- /dev/null +++ b/lib/pcg/include/pcg/initializers/zero_initializer_attrs.dtg.h @@ -0,0 +1,58 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/initializers/zero_initializer_attrs.struct.toml +/* proj-data +{ + "generated_from": "a19d5a2cdc67a2840d6ba55250a10411" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_ZERO_INITIALIZER_ATTRS_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_ZERO_INITIALIZER_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct ZeroInitializerAttrs { + bool operator==(ZeroInitializerAttrs const &) const; + bool operator!=(ZeroInitializerAttrs const &) const; + bool operator<(ZeroInitializerAttrs const &) const; + bool operator>(ZeroInitializerAttrs const &) const; + bool operator<=(ZeroInitializerAttrs const &) const; + bool operator>=(ZeroInitializerAttrs const &) const; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ZeroInitializerAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ZeroInitializerAttrs from_json(json const &); + static void to_json(json &, FlexFlow::ZeroInitializerAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(ZeroInitializerAttrs const &); +std::ostream &operator<<(std::ostream &, ZeroInitializerAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_ZERO_INITIALIZER_ATTRS_DTG_H diff --git a/lib/pcg/include/pcg/initializers/zero_initializer_attrs.struct.toml b/lib/pcg/include/pcg/initializers/zero_initializer_attrs.struct.toml new file mode 100644 index 0000000000..db1b6238d5 --- /dev/null +++ b/lib/pcg/include/pcg/initializers/zero_initializer_attrs.struct.toml @@ -0,0 +1,11 @@ +namespace = "FlexFlow" +name = "ZeroInitializerAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] +fields = [] diff --git a/lib/pcg/include/pcg/layer.h b/lib/pcg/include/pcg/layer.h deleted file mode 100644 index 9749cb9d06..0000000000 --- a/lib/pcg/include/pcg/layer.h +++ /dev/null @@ -1,33 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_LAYER_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_LAYER_H - -#include "op-attrs/operator_attrs.h" -#include "utils/stack_string.h" -#include "utils/visitable.h" - -namespace FlexFlow { - -struct Layer { -public: - Layer() = delete; - Layer(CompGraphOperatorAttrs const &attrs, - std::optional const &name); - -public: - std::optional> name; - CompGraphOperatorAttrs attrs; -}; - -} // namespace FlexFlow - -VISITABLE_STRUCT(::FlexFlow::Layer, attrs, name); -MAKE_VISIT_HASHABLE(::FlexFlow::Layer); - -namespace FlexFlow { - -FF_VISIT_FMTABLE(Layer); -// CHECK_FMTABLE(Layer); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/layer_attrs.dtg.h b/lib/pcg/include/pcg/layer_attrs.dtg.h new file mode 100644 index 0000000000..6afa1757dc --- /dev/null +++ b/lib/pcg/include/pcg/layer_attrs.dtg.h @@ -0,0 +1,60 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/layer_attrs.struct.toml +/* proj-data +{ + "generated_from": "b3e4f0c07a906139b599bd4696cb5e65" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_LAYER_ATTRS_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_LAYER_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/computation_graph_op_attrs.dtg.h" +#include "utils/json.h" +#include "utils/stack_string.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct LayerAttrs { + LayerAttrs() = delete; + LayerAttrs(::FlexFlow::ComputationGraphOpAttrs const &attrs, + std::optional<::FlexFlow::stack_string> const &name); + + bool operator==(LayerAttrs const &) const; + bool operator!=(LayerAttrs const &) const; + bool operator<(LayerAttrs const &) const; + bool operator>(LayerAttrs const &) const; + bool operator<=(LayerAttrs const &) const; + bool operator>=(LayerAttrs const &) const; + ::FlexFlow::ComputationGraphOpAttrs attrs; + std::optional<::FlexFlow::stack_string> name; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::LayerAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::LayerAttrs from_json(json const &); + static void to_json(json &, FlexFlow::LayerAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(LayerAttrs const &); +std::ostream &operator<<(std::ostream &, LayerAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_LAYER_ATTRS_DTG_H diff --git a/lib/pcg/include/pcg/layer_attrs.struct.toml b/lib/pcg/include/pcg/layer_attrs.struct.toml new file mode 100644 index 0000000000..9f8aaa5ba3 --- /dev/null +++ b/lib/pcg/include/pcg/layer_attrs.struct.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "LayerAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/computation_graph_op_attrs.dtg.h", + "utils/stack_string.h", + "", + "utils/json.h" +] + +[[fields]] +name = "attrs" +type = "::FlexFlow::ComputationGraphOpAttrs" + +[[fields]] +name = "name" +type = "std::optional<::FlexFlow::stack_string>" + diff --git a/lib/pcg/include/pcg/layer_guid_t.dtg.h b/lib/pcg/include/pcg/layer_guid_t.dtg.h new file mode 100644 index 0000000000..4bbdd36fed --- /dev/null +++ b/lib/pcg/include/pcg/layer_guid_t.dtg.h @@ -0,0 +1,46 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/layer_guid_t.struct.toml +/* proj-data +{ + "generated_from": "a672ffe470fd1dde8299f91f3038ca7a" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_LAYER_GUID_T_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_LAYER_GUID_T_DTG_H + +#include "fmt/format.h" +#include "utils/graph.h" +#include +#include +#include + +namespace FlexFlow { +struct layer_guid_t { + layer_guid_t() = delete; + layer_guid_t(::FlexFlow::Node const &raw_node); + + bool operator==(layer_guid_t const &) const; + bool operator!=(layer_guid_t const &) const; + bool operator<(layer_guid_t const &) const; + bool operator>(layer_guid_t const &) const; + bool operator<=(layer_guid_t const &) const; + bool operator>=(layer_guid_t const &) const; + ::FlexFlow::Node raw_node; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::layer_guid_t const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(layer_guid_t const &); +std::ostream &operator<<(std::ostream &, layer_guid_t const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_LAYER_GUID_T_DTG_H diff --git a/lib/pcg/include/pcg/layer_guid_t.struct.toml b/lib/pcg/include/pcg/layer_guid_t.struct.toml new file mode 100644 index 0000000000..c6d4073f58 --- /dev/null +++ b/lib/pcg/include/pcg/layer_guid_t.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "layer_guid_t" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph.h", +] + +[[fields]] +name = "raw_node" +type = "::FlexFlow::Node" diff --git a/lib/pcg/include/pcg/machine_specification.dtg.h b/lib/pcg/include/pcg/machine_specification.dtg.h new file mode 100644 index 0000000000..cd6ffe6c0f --- /dev/null +++ b/lib/pcg/include/pcg/machine_specification.dtg.h @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/machine_specification.struct.toml +/* proj-data +{ + "generated_from": "72c3ae372af189d0c8bae74c2dbbc531" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_MACHINE_SPECIFICATION_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_MACHINE_SPECIFICATION_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include +#include +#include + +namespace FlexFlow { +struct MachineSpecification { + MachineSpecification() = delete; + MachineSpecification(int const &num_nodes, + int const &num_cpus_per_node, + int const &num_gpus_per_node, + float const &inter_node_bandwidth, + float const &intra_node_bandwidth); + + bool operator==(MachineSpecification const &) const; + bool operator!=(MachineSpecification const &) const; + bool operator<(MachineSpecification const &) const; + bool operator>(MachineSpecification const &) const; + bool operator<=(MachineSpecification const &) const; + bool operator>=(MachineSpecification const &) const; + int num_nodes; + int num_cpus_per_node; + int num_gpus_per_node; + float inter_node_bandwidth; + float intra_node_bandwidth; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::MachineSpecification const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::MachineSpecification from_json(json const &); + static void to_json(json &, FlexFlow::MachineSpecification const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(MachineSpecification const &); +std::ostream &operator<<(std::ostream &, MachineSpecification const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_MACHINE_SPECIFICATION_DTG_H diff --git a/lib/pcg/include/pcg/machine_specification.h b/lib/pcg/include/pcg/machine_specification.h index 1b2a02b070..cf84bf5048 100644 --- a/lib/pcg/include/pcg/machine_specification.h +++ b/lib/pcg/include/pcg/machine_specification.h @@ -1,31 +1,8 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_MACHINE_SPECIFICATION_H #define _FLEXFLOW_PCG_INCLUDE_PCG_MACHINE_SPECIFICATION_H -#include "machine_view.h" -#include "utils/visitable.h" +#include "machine_specification_t.h" -namespace FlexFlow { - -struct BandwidthNetworkModelConfig - : public use_visitable_cmp { - int bandwidth; -}; - -struct MachineSpecification { - int num_nodes; - int num_cpus_per_node; - int num_gpus_per_node; - float inter_node_bandwidth; - req intra_node_bandwidth; -}; - -FF_VISITABLE_STRUCT(MachineSpecification, - num_nodes, - num_cpus_per_node, - num_gpus_per_node, - inter_node_bandwidth, - intra_node_bandwidth); - -} // namespace FlexFlow +namespace FlexFlow {} // namespace FlexFlow #endif diff --git a/lib/pcg/include/pcg/machine_specification.struct.toml b/lib/pcg/include/pcg/machine_specification.struct.toml new file mode 100644 index 0000000000..e75b5018cb --- /dev/null +++ b/lib/pcg/include/pcg/machine_specification.struct.toml @@ -0,0 +1,30 @@ +namespace = "FlexFlow" +name = "MachineSpecification" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +[[fields]] +name = "num_nodes" +type = "int" + +[[fields]] +name = "num_cpus_per_node" +type = "int" + +[[fields]] +name = "num_gpus_per_node" +type = "int" + +[[fields]] +name = "inter_node_bandwidth" +type = "float" + +[[fields]] +name = "intra_node_bandwidth" +type = "float" diff --git a/lib/pcg/include/pcg/machine_view.dtg.h b/lib/pcg/include/pcg/machine_view.dtg.h new file mode 100644 index 0000000000..2eae6e2c8b --- /dev/null +++ b/lib/pcg/include/pcg/machine_view.dtg.h @@ -0,0 +1,58 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/machine_view.struct.toml +/* proj-data +{ + "generated_from": "16c571e6bb82d7ef88e5d2a9146638f4" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_MACHINE_VIEW_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_MACHINE_VIEW_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "pcg/device_id_t.dtg.h" +#include "pcg/strided_rectangle.dtg.h" +#include +#include +#include + +namespace FlexFlow { +struct MachineView { + MachineView() = delete; + MachineView(::FlexFlow::device_id_t const &start, + ::FlexFlow::StridedRectangle const &rect); + + bool operator==(MachineView const &) const; + bool operator!=(MachineView const &) const; + bool operator<(MachineView const &) const; + bool operator>(MachineView const &) const; + bool operator<=(MachineView const &) const; + bool operator>=(MachineView const &) const; + ::FlexFlow::device_id_t start; + ::FlexFlow::StridedRectangle rect; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::MachineView const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::MachineView from_json(json const &); + static void to_json(json &, FlexFlow::MachineView const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(MachineView const &); +std::ostream &operator<<(std::ostream &, MachineView const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_MACHINE_VIEW_DTG_H diff --git a/lib/pcg/include/pcg/machine_view.h b/lib/pcg/include/pcg/machine_view.h index 7521cd209a..625b128d35 100644 --- a/lib/pcg/include/pcg/machine_view.h +++ b/lib/pcg/include/pcg/machine_view.h @@ -1,30 +1,19 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_MACHINE_VIEW_H #define _FLEXFLOW_PCG_INCLUDE_PCG_MACHINE_VIEW_H -#include "device_id.h" -#include "device_type.h" -#include "strided_rectangle.h" -#include "utils/graph.h" -#include "utils/visitable.h" +#include "pcg/cpu_id_t.dtg.h" +#include "pcg/device_id_t.dtg.h" +#include "pcg/device_type.dtg.h" +#include "pcg/gpu_id_t.dtg.h" +#include "pcg/machine_view.dtg.h" +#include "pcg/num_points_t.dtg.h" +#include "pcg/side_size_t.dtg.h" #include -#include #include namespace FlexFlow { -struct MachineView { - std::vector device_ids() const; - - device_id_t at(FFOrdered const &coord) const; - StridedRectangleSide at(size_t) const; - -public: - device_id_t start; - StridedRectangle rect; -}; - -FF_VISITABLE_STRUCT(MachineView, start, rect); - +std::vector device_ids(MachineView const &); std::size_t num_dims(MachineView const &); std::size_t num_devices(MachineView const &); DeviceType get_device_type(MachineView const &); diff --git a/lib/pcg/include/pcg/machine_view.struct.toml b/lib/pcg/include/pcg/machine_view.struct.toml new file mode 100644 index 0000000000..c97731991f --- /dev/null +++ b/lib/pcg/include/pcg/machine_view.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "MachineView" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "pcg/device_id_t.dtg.h", + "pcg/strided_rectangle.dtg.h", +] + +[[fields]] +name = "start" +type = "::FlexFlow::device_id_t" + +[[fields]] +name = "rect" +type = "::FlexFlow::StridedRectangle" diff --git a/lib/pcg/include/pcg/num_points_t.dtg.h b/lib/pcg/include/pcg/num_points_t.dtg.h new file mode 100644 index 0000000000..3b8e0e0c6c --- /dev/null +++ b/lib/pcg/include/pcg/num_points_t.dtg.h @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/num_points_t.struct.toml +/* proj-data +{ + "generated_from": "2a862b92055eda0508447d2f4df52f71" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_NUM_POINTS_T_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_NUM_POINTS_T_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct num_points_t { + num_points_t() = delete; + num_points_t(int const &unwrapped); + + bool operator==(num_points_t const &) const; + bool operator!=(num_points_t const &) const; + bool operator<(num_points_t const &) const; + bool operator>(num_points_t const &) const; + bool operator<=(num_points_t const &) const; + bool operator>=(num_points_t const &) const; + int unwrapped; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::num_points_t const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::num_points_t from_json(json const &); + static void to_json(json &, FlexFlow::num_points_t const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(num_points_t const &); +std::ostream &operator<<(std::ostream &, num_points_t const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_NUM_POINTS_T_DTG_H diff --git a/lib/pcg/include/pcg/num_points_t.struct.toml b/lib/pcg/include/pcg/num_points_t.struct.toml new file mode 100644 index 0000000000..b389245c63 --- /dev/null +++ b/lib/pcg/include/pcg/num_points_t.struct.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "num_points_t" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "unwrapped" +type = "int" diff --git a/lib/pcg/include/pcg/open_dataflow_graph.h b/lib/pcg/include/pcg/open_dataflow_graph.h new file mode 100644 index 0000000000..b3367686b3 --- /dev/null +++ b/lib/pcg/include/pcg/open_dataflow_graph.h @@ -0,0 +1,81 @@ +// #ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPEN_DATAFLOW_GRAPH_H +// #define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPEN_DATAFLOW_GRAPH_H +// +// #include "utils/containers/enumerate_vector.h" +// #include "utils/graph.h" +// #include "pcg/dataflow_input.dtg.h" +// +// namespace FlexFlow { +// +// template +// struct OpenDataflowGraph { +// public: +// OpenDataflowGraph() +// : g(OutputLabelledOpenMultiDiGraph::template +// create< +// UnorderedOutputLabelledOpenMultiDiGraph>()) +// { } +// +// DataflowInput add_external_input(OutputLabel const &label) { +// /* size_t src_node_idx = edge_uid_ctr; */ +// /* edge_uid_ctr++; */ +// /* size_t src_port_idx = 0; */ +// /* edge_uid_t edge_uid = { src_node_idx, src_port_idx }; */ +// /* return MultiDiOutput{edge_uid}; */ +// } +// +// std::vector add_operator(NodeLabel const &func, +// std::vector const &inputs, std::vector const +// &outputs) { +// Node n = this->g.add_node(func); +// for (auto const &[idx, input] : enumerate_vector(inputs)) { +// this->g.add_edge(MultiDiEdge{input.src, input.src_idx, n, +// this->make_port_for_idx(idx)}); +// } +// +// std::vector result; +// for (auto const &[idx, label] : enumerate_vector(outputs)) { +// MultiDiOutput output = MultiDiOutput{n, this->make_port_for_idx(idx)}; +// this->g.add_output(output, label); +// result.push_back(output); +// } +// +// return result; +// } +// +// NodePort make_port_for_idx(int idx) { +// if (!this->port_mapping.contains_l(idx)) { +// this->port_mapping.equate(idx, this->g.add_node_port()); +// } +// return this->port_mapping.at_l(idx); +// } +// +// NodePort port_for_idx(int idx) const { +// return this->port_mapping.at_l(idx); +// } +// +// int idx_for_port(NodePort const &p) const { +// return this->port_mapping.at_r(p); +// } +// +// OutputLabelledMultiDiGraphView const +// &get_raw_graph() const { +// return this->g; +// } +// +// NodeLabel const &at(Node const &n) const { +// return this->g.at(n); +// } +// +// OutputLabel const &at(MultiDiOutput const &o) const { +// return this->g.at(o); +// } +// private: +// OutputLabelledOpenMultiDiGraph g; +// bidict port_mapping; +// size_t edge_uid_ctr = 0; +// }; +// +// } // namespace FlexFlow +// +// #endif diff --git a/lib/pcg/include/pcg/operator.h b/lib/pcg/include/pcg/operator.h deleted file mode 100644 index bb9a4cf5e4..0000000000 --- a/lib/pcg/include/pcg/operator.h +++ /dev/null @@ -1,27 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_OPERATOR_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_OPERATOR_H - -#include "op-attrs/operator_attrs.h" -#include "utils/stack_string.h" -#include "utils/visitable.h" - -#include - -namespace FlexFlow { - -struct Operator { -public: - operator PCGOperatorAttrs() const; - -public: - PCGOperatorAttrs attrs; - req> name; -}; - -FF_VISITABLE_STRUCT(Operator, attrs, name); - -static_assert(is_well_behaved_value_type::value); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/operator_graph/operator_graph.h b/lib/pcg/include/pcg/operator_graph/operator_graph.h new file mode 100644 index 0000000000..5fca50d4c7 --- /dev/null +++ b/lib/pcg/include/pcg/operator_graph/operator_graph.h @@ -0,0 +1,80 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_H + +#include "pcg/operator_graph/operator_graph_input.dtg.h" +#include "pcg/operator_graph/operator_graph_output.dtg.h" +#include "utils/graph.h" + +namespace FlexFlow { + +struct OperatorGraphOutputQuery {}; +struct OperatorGraphEdge {}; + +Node get_src_node(OperatorGraphEdge const &); +Node get_dst_node(OperatorGraphEdge const &); +int get_src_idx(OperatorGraphEdge const &); +int get_dst_idx(OperatorGraphEdge const &); + +struct OperatorGraphEdgeQuery; + +struct OperatorGraphView { +public: + using Edge = OperatorGraphEdge; + using EdgeQuery = OperatorGraphEdgeQuery; + + OperatorGraphView(OperatorGraphView const &); + OperatorGraphView &operator=(OperatorGraphView const &); + + OperatorGraphView(OperatorGraphView &&); + OperatorGraphView &&operator=(OperatorGraphView &&); + + std::unordered_set query_nodes(NodeQuery const &) const; + std::unordered_set + query_outputs(OperatorGraphOutputQuery const &) const; + std::unordered_set + query_edges(OperatorGraphEdgeQuery const &) const; + + struct Impl; + std::unique_ptr impl; +}; +CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(OperatorGraphView); + +std::unordered_set get_outputs(OperatorGraphView const &); +std::vector get_outputs(OperatorGraphView const &, + Node const &); +std::unordered_set get_uses(OperatorGraphView const &, + OperatorGraphOutput const &); + +struct OperatorGraph { +public: + OperatorGraph(); + OperatorGraph(OperatorGraph const &) = default; + OperatorGraph &operator=(OperatorGraph const &) = default; + + Node add_node(std::vector const &inputs, + int num_outputs); + +private: + struct Impl; + std::unique_ptr impl; +}; + +struct value_t; + +template +struct LabelledOperatorGraphView : virtual OperatorGraphView { + NodeLabel const &at(Node const &) const; + OutputLabel const &at(OperatorGraphOutput const &) const; +}; + +template +struct LabelledOperatorGraph + : virtual LabelledOperatorGraphView { + Node add_node(NodeLabel const &, + std::vector const &inputs, + std::vector const &output_labels); +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/operator_graph/operator_graph_input.dtg.h b/lib/pcg/include/pcg/operator_graph/operator_graph_input.dtg.h new file mode 100644 index 0000000000..13904f220d --- /dev/null +++ b/lib/pcg/include/pcg/operator_graph/operator_graph_input.dtg.h @@ -0,0 +1,47 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/operator_graph/operator_graph_input.struct.toml +/* proj-data +{ + "generated_from": "57d9c9afc86f43049c6f035c74477afd" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_OPERATOR_GRAPH_INPUT_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_OPERATOR_GRAPH_INPUT_DTG_H + +#include "fmt/format.h" +#include "utils/graph.h" +#include +#include +#include + +namespace FlexFlow { +struct OperatorGraphInput { + OperatorGraphInput() = delete; + OperatorGraphInput(::FlexFlow::Node const &node, int const &idx); + + bool operator==(OperatorGraphInput const &) const; + bool operator!=(OperatorGraphInput const &) const; + bool operator<(OperatorGraphInput const &) const; + bool operator>(OperatorGraphInput const &) const; + bool operator<=(OperatorGraphInput const &) const; + bool operator>=(OperatorGraphInput const &) const; + ::FlexFlow::Node node; + int idx; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::OperatorGraphInput const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(OperatorGraphInput const &); +std::ostream &operator<<(std::ostream &, OperatorGraphInput const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_OPERATOR_GRAPH_INPUT_DTG_H diff --git a/lib/pcg/include/pcg/operator_graph/operator_graph_input.h b/lib/pcg/include/pcg/operator_graph/operator_graph_input.h new file mode 100644 index 0000000000..18e7710186 --- /dev/null +++ b/lib/pcg/include/pcg/operator_graph/operator_graph_input.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_OPERATOR_GRAPH_INPUT_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_OPERATOR_GRAPH_INPUT_H + +#include "pcg/operator_graph/operator_graph_input.dtg.h" + +namespace FlexFlow { + +Node get_node(OperatorGraphInput const &); +int get_idx(OperatorGraphInput const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/operator_graph/operator_graph_input.struct.toml b/lib/pcg/include/pcg/operator_graph/operator_graph_input.struct.toml new file mode 100644 index 0000000000..a729f75bae --- /dev/null +++ b/lib/pcg/include/pcg/operator_graph/operator_graph_input.struct.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "OperatorGraphInput" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph.h" +] + +[[fields]] +name = "node" +type = "::FlexFlow::Node" + +[[fields]] +name = "idx" +type = "int" diff --git a/lib/pcg/include/pcg/operator_graph/operator_graph_output.dtg.h b/lib/pcg/include/pcg/operator_graph/operator_graph_output.dtg.h new file mode 100644 index 0000000000..40bdc245b8 --- /dev/null +++ b/lib/pcg/include/pcg/operator_graph/operator_graph_output.dtg.h @@ -0,0 +1,47 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/operator_graph/operator_graph_output.struct.toml +/* proj-data +{ + "generated_from": "3931cb388b00e0634495cdb89cb2af54" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_OPERATOR_GRAPH_OUTPUT_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_OPERATOR_GRAPH_OUTPUT_DTG_H + +#include "fmt/format.h" +#include "utils/graph.h" +#include +#include +#include + +namespace FlexFlow { +struct OperatorGraphOutput { + OperatorGraphOutput() = delete; + OperatorGraphOutput(::FlexFlow::Node const &node, int const &idx); + + bool operator==(OperatorGraphOutput const &) const; + bool operator!=(OperatorGraphOutput const &) const; + bool operator<(OperatorGraphOutput const &) const; + bool operator>(OperatorGraphOutput const &) const; + bool operator<=(OperatorGraphOutput const &) const; + bool operator>=(OperatorGraphOutput const &) const; + ::FlexFlow::Node node; + int idx; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::OperatorGraphOutput const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(OperatorGraphOutput const &); +std::ostream &operator<<(std::ostream &, OperatorGraphOutput const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_OPERATOR_GRAPH_OUTPUT_DTG_H diff --git a/lib/pcg/include/pcg/operator_graph/operator_graph_output.h b/lib/pcg/include/pcg/operator_graph/operator_graph_output.h new file mode 100644 index 0000000000..d50b74f496 --- /dev/null +++ b/lib/pcg/include/pcg/operator_graph/operator_graph_output.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_OPERATOR_GRAPH_OUTPUT_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_OPERATOR_GRAPH_OUTPUT_H + +#include "pcg/operator_graph/operator_graph_output.dtg.h" + +namespace FlexFlow { + +Node get_node(OperatorGraphOutput const &); +int get_idx(OperatorGraphOutput const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/operator_graph/operator_graph_output.struct.toml b/lib/pcg/include/pcg/operator_graph/operator_graph_output.struct.toml new file mode 100644 index 0000000000..044d4c8df3 --- /dev/null +++ b/lib/pcg/include/pcg/operator_graph/operator_graph_output.struct.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "OperatorGraphOutput" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph.h" +] + +[[fields]] +name = "node" +type = "::FlexFlow::Node" + +[[fields]] +name = "idx" +type = "int" diff --git a/lib/pcg/include/pcg/operator_guid_t.dtg.h b/lib/pcg/include/pcg/operator_guid_t.dtg.h new file mode 100644 index 0000000000..bf08150e5e --- /dev/null +++ b/lib/pcg/include/pcg/operator_guid_t.dtg.h @@ -0,0 +1,46 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/operator_guid_t.struct.toml +/* proj-data +{ + "generated_from": "348b5a610f4ff6f545884564ee9a1e6a" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GUID_T_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GUID_T_DTG_H + +#include "fmt/format.h" +#include "utils/graph.h" +#include +#include +#include + +namespace FlexFlow { +struct operator_guid_t { + operator_guid_t() = delete; + operator_guid_t(::FlexFlow::Node const &raw_graph_node); + + bool operator==(operator_guid_t const &) const; + bool operator!=(operator_guid_t const &) const; + bool operator<(operator_guid_t const &) const; + bool operator>(operator_guid_t const &) const; + bool operator<=(operator_guid_t const &) const; + bool operator>=(operator_guid_t const &) const; + ::FlexFlow::Node raw_graph_node; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::operator_guid_t const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(operator_guid_t const &); +std::ostream &operator<<(std::ostream &, operator_guid_t const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GUID_T_DTG_H diff --git a/lib/pcg/include/pcg/operator_guid_t.h b/lib/pcg/include/pcg/operator_guid_t.h deleted file mode 100644 index 46b640774a..0000000000 --- a/lib/pcg/include/pcg/operator_guid_t.h +++ /dev/null @@ -1,18 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_OPERATOR_GUID_T_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_OPERATOR_GUID_T_H - -#include "utils/graph.h" -#include "utils/strong_typedef.h" - -namespace FlexFlow { - -struct operator_guid_t : strong_typedef { - using strong_typedef::strong_typedef; -}; - -} // namespace FlexFlow - -MAKE_TYPEDEF_PRINTABLE(::FlexFlow::operator_guid_t, "operator_guid"); -MAKE_TYPEDEF_HASHABLE(::FlexFlow::operator_guid_t); - -#endif diff --git a/lib/pcg/include/pcg/operator_guid_t.struct.toml b/lib/pcg/include/pcg/operator_guid_t.struct.toml new file mode 100644 index 0000000000..f89d30137e --- /dev/null +++ b/lib/pcg/include/pcg/operator_guid_t.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "operator_guid_t" +features = [ + "eq", + "ord", + "hash", + # "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "utils/graph.h" +] + +[[fields]] +name = "raw_graph_node" +type = "::FlexFlow::Node" diff --git a/lib/pcg/include/pcg/optimizer.h b/lib/pcg/include/pcg/optimizer.h deleted file mode 100644 index 0bb3fab974..0000000000 --- a/lib/pcg/include/pcg/optimizer.h +++ /dev/null @@ -1,41 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_OPTIMIZER_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_OPTIMIZER_H - -#include "utils/variant.h" -#include "utils/visitable.h" - -namespace FlexFlow { - -struct SGDOptimizer { - double lr; - double momentum; - bool nesterov; - req weight_decay; -}; -FF_VISITABLE_STRUCT(SGDOptimizer, lr, momentum, nesterov, weight_decay); - -struct AdamOptimizer { - double alpha; - double beta1; - double beta2; - double weight_decay; - double epsilon; - double alpha_t; - double beta_t; - req beta2_t; -}; -FF_VISITABLE_STRUCT(AdamOptimizer, - alpha, - beta1, - beta2, - weight_decay, - epsilon, - alpha_t, - beta_t, - beta2_t); - -using Optimizer = std::variant; - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/optimizer_attrs.h b/lib/pcg/include/pcg/optimizer_attrs.h new file mode 100644 index 0000000000..4bac74b999 --- /dev/null +++ b/lib/pcg/include/pcg/optimizer_attrs.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_OPTIMIZER_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_OPTIMIZER_H + +#include "pcg/optimizers/adam_optimizer_attrs.h" +#include "pcg/optimizers/sgd_optimizer_attrs.h" +#include "utils/variant.h" + +namespace FlexFlow { + +using OptimizerAttrs = std::variant; + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.dtg.h b/lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.dtg.h new file mode 100644 index 0000000000..a5a6a5ed0a --- /dev/null +++ b/lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.dtg.h @@ -0,0 +1,74 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.struct.toml +/* proj-data +{ + "generated_from": "f49e1bebcb0ef2bc3c210073e3183d4d" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPTIMIZERS_ADAM_OPTIMIZER_ATTRS_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPTIMIZERS_ADAM_OPTIMIZER_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct AdamOptimizerAttrs { + AdamOptimizerAttrs() = delete; + AdamOptimizerAttrs(double const &alpha, + double const &beta1, + double const &beta2, + double const &weight_decay, + double const &alpha_t, + double const &beta_t, + double const &beta2_t); + + bool operator==(AdamOptimizerAttrs const &) const; + bool operator!=(AdamOptimizerAttrs const &) const; + bool operator<(AdamOptimizerAttrs const &) const; + bool operator>(AdamOptimizerAttrs const &) const; + bool operator<=(AdamOptimizerAttrs const &) const; + bool operator>=(AdamOptimizerAttrs const &) const; + double alpha; + double beta1; + double beta2; + double weight_decay; + double alpha_t; + double beta_t; + double beta2_t; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::AdamOptimizerAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::AdamOptimizerAttrs from_json(json const &); + static void to_json(json &, FlexFlow::AdamOptimizerAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(AdamOptimizerAttrs const &); +std::ostream &operator<<(std::ostream &, AdamOptimizerAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPTIMIZERS_ADAM_OPTIMIZER_ATTRS_DTG_H diff --git a/lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.struct.toml b/lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.struct.toml new file mode 100644 index 0000000000..fd3e83cc4a --- /dev/null +++ b/lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.struct.toml @@ -0,0 +1,38 @@ +namespace = "FlexFlow" +name = "AdamOptimizerAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "alpha" +type = "double" + +[[fields]] +name = "beta1" +type = "double" + +[[fields]] +name = "beta2" +type = "double" + +[[fields]] +name = "weight_decay" +type = "double" + +[[fields]] +name = "alpha_t" +type = "double" + +[[fields]] +name = "beta_t" +type = "double" + +[[fields]] +name = "beta2_t" +type = "double" diff --git a/lib/pcg/include/pcg/optimizers/sgd_optimizer_attrs.dtg.h b/lib/pcg/include/pcg/optimizers/sgd_optimizer_attrs.dtg.h new file mode 100644 index 0000000000..f6a17f2354 --- /dev/null +++ b/lib/pcg/include/pcg/optimizers/sgd_optimizer_attrs.dtg.h @@ -0,0 +1,68 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/optimizers/sgd_optimizer_attrs.struct.toml +/* proj-data +{ + "generated_from": "d18c91cdddc760f1fb3990d2c817ee87" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPTIMIZERS_SGD_OPTIMIZER_ATTRS_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPTIMIZERS_SGD_OPTIMIZER_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct SGDOptimizerAttrs { + SGDOptimizerAttrs() = delete; + SGDOptimizerAttrs(double const &lr, + double const &momentum, + bool const &nesterov, + double const &weight_decay); + + bool operator==(SGDOptimizerAttrs const &) const; + bool operator!=(SGDOptimizerAttrs const &) const; + bool operator<(SGDOptimizerAttrs const &) const; + bool operator>(SGDOptimizerAttrs const &) const; + bool operator<=(SGDOptimizerAttrs const &) const; + bool operator>=(SGDOptimizerAttrs const &) const; + double lr; + double momentum; + bool nesterov; + double weight_decay; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::SGDOptimizerAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::SGDOptimizerAttrs from_json(json const &); + static void to_json(json &, FlexFlow::SGDOptimizerAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(SGDOptimizerAttrs const &); +std::ostream &operator<<(std::ostream &, SGDOptimizerAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPTIMIZERS_SGD_OPTIMIZER_ATTRS_DTG_H diff --git a/lib/pcg/include/pcg/optimizers/sgd_optimizer_attrs.struct.toml b/lib/pcg/include/pcg/optimizers/sgd_optimizer_attrs.struct.toml new file mode 100644 index 0000000000..37affb0e1f --- /dev/null +++ b/lib/pcg/include/pcg/optimizers/sgd_optimizer_attrs.struct.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "SGDOptimizerAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "lr" +type = "double" + +[[fields]] +name = "momentum" +type = "double" + +[[fields]] +name = "nesterov" +type = "bool" + +[[fields]] +name = "weight_decay" +type = "double" diff --git a/lib/pcg/include/pcg/parallel_computation_graph.dtg.h b/lib/pcg/include/pcg/parallel_computation_graph.dtg.h new file mode 100644 index 0000000000..01fbb7d30c --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph.dtg.h @@ -0,0 +1,31 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/parallel_computation_graph.struct.toml +/* proj-data +{ + "generated_from": "e4db0f603f7b8947dda13e01f96c40fb" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_DTG_H + +#include "pcg/dataflow_graph.h" +#include "pcg/parallel_layer_attrs.dtg.h" +#include "pcg/parallel_tensor_attrs.dtg.h" + +namespace FlexFlow { +struct ParallelComputationGraph { + ParallelComputationGraph() = delete; + ParallelComputationGraph( + ::FlexFlow::DataflowGraph<::FlexFlow::ParallelLayerAttrs, + ::FlexFlow::ParallelTensorAttrs> const + &raw_graph); + + ::FlexFlow::DataflowGraph<::FlexFlow::ParallelLayerAttrs, + ::FlexFlow::ParallelTensorAttrs> + raw_graph; +}; +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_DTG_H diff --git a/lib/pcg/include/pcg/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph.h index 39a69a80ab..9d7103f4fd 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph.h @@ -1,31 +1,8 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_H #define _FLEXFLOW_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_H -#include "operator.h" -#include "parallel_tensor.h" -#include "utils/graph.h" +#include "pcg/parallel_computation_graph_t.h" -namespace FlexFlow { - -struct ParallelComputationGraph - : public strong_typedef< - ParallelComputationGraph, - OutputLabelledMultiDiGraph> { - using strong_typedef::strong_typedef; -}; -CHECK_WELL_BEHAVED_VALUE_TYPE_NO_HASH(ParallelComputationGraph); - -bool operator==(ParallelComputationGraph const &, - ParallelComputationGraph const &); - -} // namespace FlexFlow - -namespace std { - -template <> -struct hash { - size_t operator()(FlexFlow::ParallelComputationGraph const &g) const; -}; -} // namespace std +namespace FlexFlow {} #endif diff --git a/lib/pcg/include/pcg/parallel_computation_graph.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph.struct.toml new file mode 100644 index 0000000000..d4e305abe5 --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph.struct.toml @@ -0,0 +1,13 @@ +namespace = "FlexFlow" +name = "ParallelComputationGraph" +features = [ ] + +includes = [ + "pcg/dataflow_graph.h", + "pcg/parallel_tensor_attrs.dtg.h", + "pcg/parallel_layer_attrs.dtg.h", +] + +[[fields]] +name = "raw_graph" +type = "::FlexFlow::DataflowGraph<::FlexFlow::ParallelLayerAttrs, ::FlexFlow::ParallelTensorAttrs>" diff --git a/lib/pcg/include/pcg/parallel_layer_attrs.dtg.h b/lib/pcg/include/pcg/parallel_layer_attrs.dtg.h new file mode 100644 index 0000000000..4c7fce4038 --- /dev/null +++ b/lib/pcg/include/pcg/parallel_layer_attrs.dtg.h @@ -0,0 +1,60 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/parallel_layer_attrs.struct.toml +/* proj-data +{ + "generated_from": "97fa0b11c59ae892a8a530ffd67e33ad" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_LAYER_ATTRS_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_LAYER_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/operator_attrs.h" +#include "utils/stack_string.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct ParallelLayerAttrs { + ParallelLayerAttrs() = delete; + ParallelLayerAttrs( + ::FlexFlow::PCGOperatorAttrs const &attrs, + std::optional<::FlexFlow::stack_string> const &name); + + bool operator==(ParallelLayerAttrs const &) const; + bool operator!=(ParallelLayerAttrs const &) const; + bool operator<(ParallelLayerAttrs const &) const; + bool operator>(ParallelLayerAttrs const &) const; + bool operator<=(ParallelLayerAttrs const &) const; + bool operator>=(ParallelLayerAttrs const &) const; + ::FlexFlow::PCGOperatorAttrs attrs; + std::optional<::FlexFlow::stack_string> name; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ParallelLayerAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ParallelLayerAttrs from_json(json const &); + static void to_json(json &, FlexFlow::ParallelLayerAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ParallelLayerAttrs const &); +std::ostream &operator<<(std::ostream &, ParallelLayerAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_LAYER_ATTRS_DTG_H diff --git a/lib/pcg/include/pcg/parallel_layer_attrs.struct.toml b/lib/pcg/include/pcg/parallel_layer_attrs.struct.toml new file mode 100644 index 0000000000..9b1f8f47aa --- /dev/null +++ b/lib/pcg/include/pcg/parallel_layer_attrs.struct.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "ParallelLayerAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/operator_attrs.h", + "utils/stack_string.h", + "", +] + +[[fields]] +name = "attrs" +type = "::FlexFlow::PCGOperatorAttrs" + +[[fields]] +name = "name" +type = "std::optional<::FlexFlow::stack_string>" diff --git a/lib/pcg/include/pcg/parallel_tensor.h b/lib/pcg/include/pcg/parallel_tensor.h index 652b408c15..de41e0fb21 100644 --- a/lib/pcg/include/pcg/parallel_tensor.h +++ b/lib/pcg/include/pcg/parallel_tensor.h @@ -21,56 +21,12 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_PARALLEL_TENSOR_H #define _FLEXFLOW_PCG_INCLUDE_PCG_PARALLEL_TENSOR_H -#include "create_grad.h" -#include "initializer.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "op-attrs/param_sync.h" +#include "pcg/parallel_tensor_attrs.h" -namespace FlexFlow { - -/** - * @brief Base structure of the parallel tensor representation. - * - * @details Parallel tensor is the fundamental component to support the - * representation and exploration of parallelization strategies. - */ -struct ParallelTensor : public use_visitable_cmp { - ParallelTensor() = delete; - - ParallelTensor(ParallelTensorShape const &, - CreateGrad create_gradients, - std::optional sync_type = std::nullopt, - std::optional initializer = std::nullopt); - ParallelTensor(ParallelTensorDims const &, - DataType, - CreateGrad create_gradients, - std::optional sync_type = std::nullopt, - std::optional initializer = std::nullopt); - - ParallelTensorShape get_shape() const; - -public: - ParallelTensorDims dims; - DataType data_type; - std::optional sync_type = std::nullopt; - std::optional initializer = std::nullopt; - CreateGrad create_gradients; -}; - -using ParallelParameter = ParallelTensor; - -} // namespace FlexFlow - -VISITABLE_STRUCT(::FlexFlow::ParallelTensor, - dims, - data_type, - sync_type, - initializer, - create_gradients); -MAKE_VISIT_HASHABLE(::FlexFlow::ParallelTensor); +namespace FlexFlow {} // namespace FlexFlow namespace FlexFlow { -static_assert(is_well_behaved_value_type::value, ""); +static_assert(is_well_behaved_value_type::value, ""); } #endif diff --git a/lib/pcg/include/pcg/parallel_tensor_attrs.dtg.h b/lib/pcg/include/pcg/parallel_tensor_attrs.dtg.h new file mode 100644 index 0000000000..fa6b153b0a --- /dev/null +++ b/lib/pcg/include/pcg/parallel_tensor_attrs.dtg.h @@ -0,0 +1,66 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/parallel_tensor_attrs.struct.toml +/* proj-data +{ + "generated_from": "b3e086b380bbc41d99332e1463a34b28" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_TENSOR_ATTRS_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_TENSOR_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/param_sync.dtg.h" +#include "pcg/create_grad.dtg.h" +#include "pcg/initializer_attrs.dtg.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct ParallelTensorAttrs { + ParallelTensorAttrs() = delete; + ParallelTensorAttrs( + ::FlexFlow::ParallelTensorShape const &shape, + std::optional<::FlexFlow::ParamSync> const &sync_type, + std::optional<::FlexFlow::InitializerAttrs> const &initializer, + ::FlexFlow::CreateGrad const &create_gradients); + + bool operator==(ParallelTensorAttrs const &) const; + bool operator!=(ParallelTensorAttrs const &) const; + bool operator<(ParallelTensorAttrs const &) const; + bool operator>(ParallelTensorAttrs const &) const; + bool operator<=(ParallelTensorAttrs const &) const; + bool operator>=(ParallelTensorAttrs const &) const; + ::FlexFlow::ParallelTensorShape shape; + std::optional<::FlexFlow::ParamSync> sync_type; + std::optional<::FlexFlow::InitializerAttrs> initializer; + ::FlexFlow::CreateGrad create_gradients; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ParallelTensorAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ParallelTensorAttrs from_json(json const &); + static void to_json(json &, FlexFlow::ParallelTensorAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ParallelTensorAttrs const &); +std::ostream &operator<<(std::ostream &, ParallelTensorAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_TENSOR_ATTRS_DTG_H diff --git a/lib/pcg/include/pcg/parallel_tensor_attrs.struct.toml b/lib/pcg/include/pcg/parallel_tensor_attrs.struct.toml new file mode 100644 index 0000000000..1f81b56ec8 --- /dev/null +++ b/lib/pcg/include/pcg/parallel_tensor_attrs.struct.toml @@ -0,0 +1,34 @@ +namespace = "FlexFlow" +name = "ParallelTensorAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/parallel_tensor_shape.dtg.h", + "op-attrs/param_sync.dtg.h", + "pcg/initializer_attrs.dtg.h", + "pcg/create_grad.dtg.h", + "", +] + +[[fields]] +name = "shape" +type = "::FlexFlow::ParallelTensorShape" + +[[fields]] +name = "sync_type" +type = "std::optional<::FlexFlow::ParamSync>" + +[[fields]] +name = "initializer" +type = "std::optional<::FlexFlow::InitializerAttrs>" + +[[fields]] +name = "create_gradients" +type = "::FlexFlow::CreateGrad" diff --git a/lib/pcg/include/pcg/serialization.h b/lib/pcg/include/pcg/serialization.h deleted file mode 100644 index 28e16aeb1e..0000000000 --- a/lib/pcg/include/pcg/serialization.h +++ /dev/null @@ -1,61 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_SERIALIZATION_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_SERIALIZATION_H - -#include "computation_graph.h" -#include "layer.h" -#include "machine_specification.h" -#include "parallel_computation_graph.h" -#include "parallel_tensor.h" -#include "tensor_mapping.h" -#include "utils/json.h" - -namespace FlexFlow { - -void from_json(json const &, ComputationGraph &); -void to_json(json &, ComputationGraph const &); - -void from_json(json const &, ParallelComputationGraph &); -void to_json(json &, ParallelComputationGraph const &); - -void from_json(json const &, Layer &); -void to_json(json &, Layer const &); - -void from_json(json const &, ParallelTensor &); -void to_json(json &, ParallelTensor const &); - -void from_json(json const &, Tensor &); -void to_json(json &, Tensor const &); - -void from_json(json const &, Initializer &); -void to_json(json &, Initializer const &); - -void from_json(json const &, MachineSpecification &); -void to_json(json &, MachineSpecification const &); - -void from_json(json const &, Operator &); -void to_json(json &, Operator const &); - -void from_json(json const &, MachineView &); -void to_json(json &, MachineView const &); - -void from_json(json const &, StridedRectangle &); -void to_json(json &, StridedRectangle const &); - -void from_json(json const &, StridedRectangleSide &); -void to_json(json &, StridedRectangleSide const &); - -void from_json(json const &, ParallelTensorDims &); -void to_json(json &, ParallelTensorDims const &); - -void from_json(json const &, TensorDims &); -void to_json(json &, TensorDims const &); - -void from_json(json const &, TensorMapping &); -void to_json(json &, TensorMapping const &); - -void from_json(json const &, ParallelTensorShape &); -void to_json(json &, ParallelTensorShape const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/side_size_t.dtg.h b/lib/pcg/include/pcg/side_size_t.dtg.h new file mode 100644 index 0000000000..fce31b1c9d --- /dev/null +++ b/lib/pcg/include/pcg/side_size_t.dtg.h @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/side_size_t.struct.toml +/* proj-data +{ + "generated_from": "6a1669890e547dcc7a4ddb90be05be15" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_SIDE_SIZE_T_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_SIDE_SIZE_T_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct side_size_t { + side_size_t() = delete; + side_size_t(int const &unwrapped); + + bool operator==(side_size_t const &) const; + bool operator!=(side_size_t const &) const; + bool operator<(side_size_t const &) const; + bool operator>(side_size_t const &) const; + bool operator<=(side_size_t const &) const; + bool operator>=(side_size_t const &) const; + int unwrapped; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::side_size_t const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::side_size_t from_json(json const &); + static void to_json(json &, FlexFlow::side_size_t const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(side_size_t const &); +std::ostream &operator<<(std::ostream &, side_size_t const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_SIDE_SIZE_T_DTG_H diff --git a/lib/pcg/include/pcg/side_size_t.struct.toml b/lib/pcg/include/pcg/side_size_t.struct.toml new file mode 100644 index 0000000000..dbaad4fedb --- /dev/null +++ b/lib/pcg/include/pcg/side_size_t.struct.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "side_size_t" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "unwrapped" +type = "int" diff --git a/lib/pcg/include/pcg/strided_rectangle.dtg.h b/lib/pcg/include/pcg/strided_rectangle.dtg.h new file mode 100644 index 0000000000..df6a16a0ad --- /dev/null +++ b/lib/pcg/include/pcg/strided_rectangle.dtg.h @@ -0,0 +1,65 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/strided_rectangle.struct.toml +/* proj-data +{ + "generated_from": "817bbe017d179aa469822a4032d08836" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/dim_ordered.h" +#include "pcg/strided_rectangle_side.dtg.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct StridedRectangle { + StridedRectangle() = delete; + StridedRectangle( + ::FlexFlow::FFOrdered<::FlexFlow::StridedRectangleSide> const &sides); + + bool operator==(StridedRectangle const &) const; + bool operator!=(StridedRectangle const &) const; + bool operator<(StridedRectangle const &) const; + bool operator>(StridedRectangle const &) const; + bool operator<=(StridedRectangle const &) const; + bool operator>=(StridedRectangle const &) const; + ::FlexFlow::FFOrdered<::FlexFlow::StridedRectangleSide> sides; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::StridedRectangle const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::StridedRectangle from_json(json const &); + static void to_json(json &, FlexFlow::StridedRectangle const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(StridedRectangle const &); +std::ostream &operator<<(std::ostream &, StridedRectangle const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_DTG_H diff --git a/lib/pcg/include/pcg/strided_rectangle.h b/lib/pcg/include/pcg/strided_rectangle.h index d123d7c6ac..24ae51ac41 100644 --- a/lib/pcg/include/pcg/strided_rectangle.h +++ b/lib/pcg/include/pcg/strided_rectangle.h @@ -1,62 +1,16 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_H #define _FLEXFLOW_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_H -#include "op-attrs/dim_ordered.h" -#include "op-attrs/ff_dim.h" -#include "utils/stack_vector.h" -#include "utils/strong_typedef.h" -#include "utils/visitable.h" +#include "op-attrs/ff_dim.dtg.h" +#include "pcg/side_size_t.dtg.h" +#include "pcg/strided_rectangle.dtg.h" namespace FlexFlow { -struct num_points_t : public strong_typedef { - using strong_typedef::strong_typedef; -}; - -struct side_size_t : public strong_typedef { - using strong_typedef::strong_typedef; -}; - -struct StridedRectangleSide { -public: - StridedRectangleSide() = delete; - StridedRectangleSide(num_points_t const &, int stride); - StridedRectangleSide(side_size_t const &, int stride); - - num_points_t get_num_points() const; - side_size_t get_size() const; - int get_stride() const; - - side_size_t at(num_points_t) const; - num_points_t at(side_size_t) const; - -public: - num_points_t num_points; - req stride; -}; - -FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(StridedRectangleSide, - num_points, - stride); - -struct StridedRectangle { -public: - size_t at(FFOrdered const &) const; - StridedRectangleSide at(ff_dim_t const &) const; - size_t num_dims() const; - -public: - FFOrdered sides; -}; - -FF_VISITABLE_STRUCT(StridedRectangle, sides); +size_t get_num_dims(StridedRectangle const &); +StridedRectangleSide get_side_at_idx(StridedRectangle const &, + ff_dim_t const &); } // namespace FlexFlow -MAKE_TYPEDEF_HASHABLE(::FlexFlow::num_points_t); -MAKE_TYPEDEF_PRINTABLE(::FlexFlow::num_points_t, "num_points"); - -MAKE_TYPEDEF_HASHABLE(::FlexFlow::side_size_t); -MAKE_TYPEDEF_PRINTABLE(::FlexFlow::side_size_t, "side_size"); - #endif diff --git a/lib/pcg/include/pcg/strided_rectangle.struct.toml b/lib/pcg/include/pcg/strided_rectangle.struct.toml new file mode 100644 index 0000000000..3dfd90e296 --- /dev/null +++ b/lib/pcg/include/pcg/strided_rectangle.struct.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "StridedRectangle" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "pcg/strided_rectangle_side.dtg.h", + "op-attrs/dim_ordered.h", +] + +[[fields]] +name = "sides" +type = "::FlexFlow::FFOrdered<::FlexFlow::StridedRectangleSide>" diff --git a/lib/pcg/include/pcg/strided_rectangle_side.dtg.h b/lib/pcg/include/pcg/strided_rectangle_side.dtg.h new file mode 100644 index 0000000000..3e4365c24d --- /dev/null +++ b/lib/pcg/include/pcg/strided_rectangle_side.dtg.h @@ -0,0 +1,65 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/strided_rectangle_side.struct.toml +/* proj-data +{ + "generated_from": "b14fcf1e28c262d22b92fac691ede3d4" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_SIDE_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_SIDE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "pcg/num_points_t.dtg.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct StridedRectangleSide { + StridedRectangleSide() = delete; + StridedRectangleSide(::FlexFlow::num_points_t const &num_points, + int const &stride); + + bool operator==(StridedRectangleSide const &) const; + bool operator!=(StridedRectangleSide const &) const; + bool operator<(StridedRectangleSide const &) const; + bool operator>(StridedRectangleSide const &) const; + bool operator<=(StridedRectangleSide const &) const; + bool operator>=(StridedRectangleSide const &) const; + ::FlexFlow::num_points_t num_points; + int stride; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::StridedRectangleSide const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::StridedRectangleSide from_json(json const &); + static void to_json(json &, FlexFlow::StridedRectangleSide const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(StridedRectangleSide const &); +std::ostream &operator<<(std::ostream &, StridedRectangleSide const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_SIDE_DTG_H diff --git a/lib/pcg/include/pcg/strided_rectangle_side.h b/lib/pcg/include/pcg/strided_rectangle_side.h new file mode 100644 index 0000000000..1486b73143 --- /dev/null +++ b/lib/pcg/include/pcg/strided_rectangle_side.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_SIDE_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_SIDE_H + +#include "pcg/side_size_t.dtg.h" +#include "pcg/strided_rectangle_side.dtg.h" + +namespace FlexFlow { + +StridedRectangleSide strided_side_from_size_and_stride(side_size_t, int stride); + +side_size_t get_side_size(StridedRectangleSide const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/strided_rectangle_side.struct.toml b/lib/pcg/include/pcg/strided_rectangle_side.struct.toml new file mode 100644 index 0000000000..f26adfafd5 --- /dev/null +++ b/lib/pcg/include/pcg/strided_rectangle_side.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "StridedRectangleSide" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "pcg/num_points_t.dtg.h", +] + +[[fields]] +name = "num_points" +type = "::FlexFlow::num_points_t" + +[[fields]] +name = "stride" +type = "int" diff --git a/lib/pcg/include/pcg/tensor.h b/lib/pcg/include/pcg/tensor.h deleted file mode 100644 index b5ff857a6c..0000000000 --- a/lib/pcg/include/pcg/tensor.h +++ /dev/null @@ -1,40 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_TENSOR_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_TENSOR_H - -#include "create_grad.h" -#include "initializer.h" -#include "op-attrs/param_sync.h" -#include "op-attrs/tensor_shape.h" - -namespace FlexFlow { - -struct Tensor { - /* Tensor() = delete; */ - /* Tensor(TensorShape const &, */ - /* CreateGrad create_gradients, */ - /* optional initializer = nullopt, */ - /* optional sync_type = nullopt); */ - - size_t get_volume() const; - TensorShape get_shape() const; - int num_dims() const; - - operator TensorShape() const; - -public: - TensorDims dims; - DataType data_type; - std::optional initializer; - bool create_gradients; - req> sync_type; -}; -FF_VISITABLE_STRUCT( - Tensor, dims, data_type, initializer, create_gradients, sync_type); - -using Parameter = Tensor; - -Tensor construct_tensor_from_output_shape(TensorShape const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/tensor_attrs.dtg.h b/lib/pcg/include/pcg/tensor_attrs.dtg.h new file mode 100644 index 0000000000..8bc9d3ce9d --- /dev/null +++ b/lib/pcg/include/pcg/tensor_attrs.dtg.h @@ -0,0 +1,64 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/tensor_attrs.struct.toml +/* proj-data +{ + "generated_from": "68447a4357476647ef25dd39dfd12578" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_TENSOR_ATTRS_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_TENSOR_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/param_sync.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" +#include "pcg/initializer_attrs.dtg.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct TensorAttrs { + TensorAttrs() = delete; + TensorAttrs(::FlexFlow::TensorShape const &shape, + std::optional<::FlexFlow::InitializerAttrs> const &initializer, + bool const &create_gradients, + std::optional<::FlexFlow::ParamSync> const &sync_type); + + bool operator==(TensorAttrs const &) const; + bool operator!=(TensorAttrs const &) const; + bool operator<(TensorAttrs const &) const; + bool operator>(TensorAttrs const &) const; + bool operator<=(TensorAttrs const &) const; + bool operator>=(TensorAttrs const &) const; + ::FlexFlow::TensorShape shape; + std::optional<::FlexFlow::InitializerAttrs> initializer; + bool create_gradients; + std::optional<::FlexFlow::ParamSync> sync_type; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::TensorAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::TensorAttrs from_json(json const &); + static void to_json(json &, FlexFlow::TensorAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(TensorAttrs const &); +std::ostream &operator<<(std::ostream &, TensorAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_TENSOR_ATTRS_DTG_H diff --git a/lib/pcg/include/pcg/tensor_attrs.struct.toml b/lib/pcg/include/pcg/tensor_attrs.struct.toml new file mode 100644 index 0000000000..eefb6da702 --- /dev/null +++ b/lib/pcg/include/pcg/tensor_attrs.struct.toml @@ -0,0 +1,33 @@ +namespace = "FlexFlow" +name = "TensorAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/tensor_shape.dtg.h", + "pcg/initializer_attrs.dtg.h", + "op-attrs/param_sync.dtg.h", + "", +] + +[[fields]] +name = "shape" +type = "::FlexFlow::TensorShape" + +[[fields]] +name = "initializer" +type = "std::optional<::FlexFlow::InitializerAttrs>" + +[[fields]] +name = "create_gradients" +type = "bool" + +[[fields]] +name = "sync_type" +type = "std::optional<::FlexFlow::ParamSync>" diff --git a/lib/pcg/include/pcg/tensor_guid_t.dtg.h b/lib/pcg/include/pcg/tensor_guid_t.dtg.h new file mode 100644 index 0000000000..c6109c6103 --- /dev/null +++ b/lib/pcg/include/pcg/tensor_guid_t.dtg.h @@ -0,0 +1,46 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/tensor_guid_t.struct.toml +/* proj-data +{ + "generated_from": "dc15fcbb876ec70509dfa8b662963bc3" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_TENSOR_GUID_T_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_TENSOR_GUID_T_DTG_H + +#include "fmt/format.h" +#include "utils/graph.h" +#include +#include +#include + +namespace FlexFlow { +struct tensor_guid_t { + tensor_guid_t() = delete; + tensor_guid_t(::FlexFlow::MultiDiOutput const &raw_graph_output); + + bool operator==(tensor_guid_t const &) const; + bool operator!=(tensor_guid_t const &) const; + bool operator<(tensor_guid_t const &) const; + bool operator>(tensor_guid_t const &) const; + bool operator<=(tensor_guid_t const &) const; + bool operator>=(tensor_guid_t const &) const; + ::FlexFlow::MultiDiOutput raw_graph_output; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::tensor_guid_t const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(tensor_guid_t const &); +std::ostream &operator<<(std::ostream &, tensor_guid_t const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_TENSOR_GUID_T_DTG_H diff --git a/lib/pcg/include/pcg/tensor_guid_t.h b/lib/pcg/include/pcg/tensor_guid_t.h deleted file mode 100644 index 3e4e840a5f..0000000000 --- a/lib/pcg/include/pcg/tensor_guid_t.h +++ /dev/null @@ -1,17 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_TENSOR_GUID_T_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_TENSOR_GUID_T_H - -#include "utils/graph.h" - -namespace FlexFlow { - -struct tensor_guid_t : strong_typedef { - using strong_typedef::strong_typedef; -}; - -} // namespace FlexFlow - -MAKE_TYPEDEF_PRINTABLE(::FlexFlow::tensor_guid_t, "tensor_guid"); -MAKE_TYPEDEF_HASHABLE(::FlexFlow::tensor_guid_t); - -#endif diff --git a/lib/pcg/include/pcg/tensor_guid_t.struct.toml b/lib/pcg/include/pcg/tensor_guid_t.struct.toml new file mode 100644 index 0000000000..aea4fad108 --- /dev/null +++ b/lib/pcg/include/pcg/tensor_guid_t.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "tensor_guid_t" +features = [ + "eq", + "ord", + "hash", + # "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "utils/graph.h" +] + +[[fields]] +name = "raw_graph_output" +type = "::FlexFlow::MultiDiOutput" diff --git a/lib/pcg/src/computation_graph.cc b/lib/pcg/src/computation_graph.cc deleted file mode 100644 index 18fded6d3e..0000000000 --- a/lib/pcg/src/computation_graph.cc +++ /dev/null @@ -1,76 +0,0 @@ -#include "pcg/computation_graph.h" - -namespace FlexFlow { - -std::vector - traverse_comp_graph_forward(ComputationGraph const &comp_graph) { - std::vector layers = get_topological_ordering(comp_graph.value()); - return transform(layers, [&](Node const &e) -> operator_guid_t { - return operator_guid_t{e}; - }); -} - -std::vector - traverse_comp_graph_backward(ComputationGraph const &comp_graph) { - std::vector layers = - reversed>(get_topological_ordering(comp_graph.value())); - return transform(layers, [&](Node const &e) -> operator_guid_t { - return operator_guid_t{e}; - }); -} - -std::vector - sort_edge_set(std::unordered_set edges) { - return transform( - sorted_by(edges, compare_by([](MultiDiEdge const &e) { - return e.src_idx; - })), - [&](MultiDiEdge const &e) -> tensor_guid_t { return tensor_guid_t{e}; }); -} - -std::vector - get_outgoing_tensors(ComputationGraph const &comp_graph, - operator_guid_t n) { - return sort_edge_set(get_outgoing_edges(comp_graph.value(), n.value())); -} - -std::vector - get_incoming_tensors(ComputationGraph const &comp_graph, - operator_guid_t n) { - return sort_edge_set(get_incoming_edges(comp_graph.value(), n.value())); -} - -operator_guid_t create_node(ComputationGraph &comp_graph, Layer const &layer) { - Node added_node = comp_graph.value().add_node(layer); - return operator_guid_t{added_node}; -} - -tensor_guid_t create_outgoing_edge(ComputationGraph &comp_graph, - operator_guid_t node, - int idx, - Tensor tensor) { - MultiDiOutput edge = {node.value(), NodePort{idx}}; - comp_graph.value().add_output(edge, tensor); - return tensor_guid_t{edge}; -} - -void connect_incoming_edges(ComputationGraph &comp_graph, - std::vector const &incoming_edges, - operator_guid_t node) { - size_t incoming_edge_dst_port = 0; - for (tensor_guid_t input : incoming_edges) { - MultiDiOutput input_view = input.value(); - MultiDiEdge edge = {node.value(), - NodePort{incoming_edge_dst_port++}, - input_view.src, - input_view.src_idx}; - comp_graph.value().add_edge(edge); - } -} - -CompGraphOperatorAttrs get_layer_attrs(ComputationGraph const &comp_graph, - operator_guid_t const &n) { - return comp_graph.at(n).attrs; -} - -} // namespace FlexFlow diff --git a/lib/pcg/src/device_id.cc b/lib/pcg/src/device_id.cc deleted file mode 100644 index 2849df7c3c..0000000000 --- a/lib/pcg/src/device_id.cc +++ /dev/null @@ -1,19 +0,0 @@ -#include "pcg/device_id.h" -#include "utils/exception.h" -#include - -namespace FlexFlow { - -DeviceType get_device_type(device_id_t const &id) { - if (std::holds_alternative(id)) { - return DeviceType::GPU; - } else { - assert(std::holds_alternative(id)); - return DeviceType::CPU; - } -} - -device_id_t operator+(device_id_t, size_t) { - NOT_IMPLEMENTED(); -} -} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/graphs.cc b/lib/pcg/src/file_format/v1/graphs.cc index d00de7b0c1..eabd266e25 100644 --- a/lib/pcg/src/file_format/v1/graphs.cc +++ b/lib/pcg/src/file_format/v1/graphs.cc @@ -1,23 +1,39 @@ #include "pcg/file_format/v1/graphs.h" +#include "pcg/dataflow_graph.h" +#include "pcg/file_format/v1/graphs/v1_multidigraph.h" +#include "pcg/file_format/v1/graphs/v1_operator_graph.dtg.h" #include "utils/graph/algorithms.h" +#include "utils/integer_conversions.h" namespace FlexFlow { -V1MultiDiGraph to_v1(MultiDiGraphView const &g) { - return to_v1(g, - enumerate(get_nodes(g)).reversed(), - enumerate(get_present_node_ports(g)).reversed()); -} +/* static V1OperatorGraph to_v1(OperatorGraphView const &g, bidict + * const &nodes) { */ +/* std::unordered_set edges; */ +/* for (MultiDiEdge const &e : get_edges(g)) { */ +/* size_t src_node = nodes.at_l(get_src_node(e)); */ +/* size_t dst_node = nodes.at_l(get_dst_node(e)); */ +/* size_t src_idx = size_t_from_int(get_src_idx(e)); */ +/* size_t dst_idx = size_t_from_int(get_dst_idx(e)); */ +/* V1GraphEdge v1_e = {src_node, src_idx, dst_node, dst_idx}; */ +/* edges.insert(v1_e); */ +/* } */ + +/* return V1OperatorGraph{ */ +/* count(nodes.size()), */ +/* edges, */ +/* }; */ +/* } */ -V1MultiDiGraph to_v1(MultiDiGraphView const &g, - std::unordered_map const &nodes, - std::unordered_map const &node_ports) { +static V1MultiDiGraph to_v1(MultiDiGraphView const &g, + bidict const &nodes, + bidict const &node_ports) { std::unordered_set edges; for (MultiDiEdge const &e : get_edges(g)) { - edges.insert({nodes.at(e.src), - node_ports.at(e.src_idx), - nodes.at(e.dst), - node_ports.at(e.dst_idx)}); + edges.insert({nodes.at_l(e.src), + node_ports.at_l(e.src_idx), + nodes.at_l(e.dst), + node_ports.at_l(e.dst_idx)}); } return V1MultiDiGraph{ @@ -27,32 +43,101 @@ V1MultiDiGraph to_v1(MultiDiGraphView const &g, }; } +/* static V1MultiDiGraph to_v1(MultiDiGraphView const &g) { */ +/* return to_v1(g, */ +/* enumerate(get_nodes(g)).reversed(), */ +/* enumerate(get_present_node_ports(g)).reversed()); */ +/* } */ + +/* template */ +/* static V1JsonableGraph */ +/* to_v1(LabelledOperatorGraphView const &g) { */ + +/* bidict nodes = enumerate(get_nodes(g)); */ + +/* V1OperatorGraph unlabelled = to_v1(g, nodes.reversed()); */ +/* std::unordered_map node_labels = */ +/* map_values(nodes, [&](Node const &n) { return g.at(n); }); */ + +/* bidict outputs_bidict = + * enumerate(get_outputs(g)); */ +/* std::unordered_map outputs = */ +/* map_values(outputs_bidict, [&](OperatorGraphOutput const &o) { */ +/* return V1GraphOutput{nodes.at_r(get_node(o)), + * size_t_from_int(get_idx(o))}; */ +/* }); */ + +/* std::unordered_map output_labels = map_values( */ +/* outputs_bidict, [&](OperatorGraphOutput const &o) { return g.at(o); }); + */ + +/* return {node_labels, outputs, output_labels, unlabelled}; */ +/* } */ + template -V1JsonableGraph())), - decltype(to_v1(std::declval()))> - to_v1(OutputLabelledMultiDiGraph const &g) { - using V1NodeLabel = decltype(to_v1(std::declval())); - using V1OutputLabel = decltype(to_v1(std::declval())); +static bidict + get_ports_by_idx(DataflowGraph const &g) { + bidict result; + for (NodePort const &p : get_present_node_ports(g.get_raw_graph())) { + size_t idx = size_t_from_int(g.idx_for_port(p)); + result.equate(idx, p); + } + return result; +} + +template +static V1JsonableGraph + to_v1(DataflowGraph const &g) { + + bidict nodes = enumerate(get_nodes(g.get_raw_graph())); + bidict node_ports = get_ports_by_idx(g); + + V1MultiDiGraph unlabelled = + to_v1(g.get_raw_graph(), nodes.reversed(), node_ports.reversed()); + std::unordered_map node_labels = + map_values(nodes, [&](Node const &n) { return g.at(n); }); + bidict outputs_bidict = + enumerate(get_outputs(g.get_raw_graph())); + std::unordered_map outputs = + map_values(outputs_bidict, [&](MultiDiOutput const &o) { + return V1GraphOutput{nodes.at_r(o.src), node_ports.at_r(o.src_idx)}; + }); + + std::unordered_map output_labels = map_values( + outputs_bidict, [&](MultiDiOutput const &o) { return g.at(o); }); + + return {node_labels, outputs, output_labels, unlabelled}; +} + +template +static V1JsonableGraph + to_v1(OutputLabelledMultiDiGraphView const &g) { bidict nodes = enumerate(get_nodes(g)); bidict node_ports = enumerate(get_present_node_ports(g)); V1MultiDiGraph unlabelled = to_v1(g, nodes.reversed(), node_ports.reversed()); - std::unordered_map node_labels = - map_values(nodes, [&](Node const &n) { return to_v1(g.at(n)); }); + std::unordered_map node_labels = + map_values(nodes, [&](Node const &n) { return g.at(n); }); + bidict outputs_bidict = enumerate(get_outputs(g)); std::unordered_map outputs = map_values(outputs_bidict, [&](MultiDiOutput const &o) { return V1GraphOutput{nodes.at_r(o.src), node_ports.at_r(o.src_idx)}; }); - std::unordered_map output_labels = map_values( - outputs_bidict, [&](MultiDiOutput const &o) { return to_v1(g.at(o)); }); + + std::unordered_map output_labels = map_values( + outputs_bidict, [&](MultiDiOutput const &o) { return g.at(o); }); return {node_labels, outputs, output_labels, unlabelled}; } V1ComputationGraph to_v1(ComputationGraph const &g) { - return to_v1(g.value()); + return to_v1(g.raw_graph); +} + +V1ParallelComputationGraph to_v1(ParallelComputationGraph const &g) { + return to_v1(g.raw_graph); } } // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/v1.cc b/lib/pcg/src/file_format/v1/v1.cc deleted file mode 100644 index 7715985eed..0000000000 --- a/lib/pcg/src/file_format/v1/v1.cc +++ /dev/null @@ -1,13 +0,0 @@ -#include "pcg/file_format/v1/v1.h" - -namespace FlexFlow { - -V1Tensor to_v1(Tensor const &) { - NOT_IMPLEMENTED(); -} - -V1Layer to_v1(Layer const &) { - NOT_IMPLEMENTED(); -} - -} // namespace FlexFlow diff --git a/lib/pcg/src/layer.cc b/lib/pcg/src/layer.cc deleted file mode 100644 index 00fb07a8c5..0000000000 --- a/lib/pcg/src/layer.cc +++ /dev/null @@ -1,9 +0,0 @@ -#include "pcg/layer.h" - -namespace FlexFlow { - -Layer::Layer(CompGraphOperatorAttrs const &_attrs, - std::optional const &_name) - : attrs(_attrs), name(_name) {} - -} // namespace FlexFlow diff --git a/lib/pcg/src/machine_view.cc b/lib/pcg/src/machine_view.cc deleted file mode 100644 index 46f87833f0..0000000000 --- a/lib/pcg/src/machine_view.cc +++ /dev/null @@ -1,29 +0,0 @@ -#include "pcg/machine_view.h" -#include "utils/utils.h" - -namespace FlexFlow { - -static StridedRectangle make_1d_rect(int start, int stop, int stride) { - assert(stop > start); - assert(stride > 0); - StridedRectangleSide side = {side_size_t(stop - start), stride}; - StridedRectangle rect = {{side}}; - return rect; -} - -MachineView make_1d_machine_view(gpu_id_t start, gpu_id_t stop, int stride) { - StridedRectangle rect = make_1d_rect(start.value(), stop.value(), stride); - return {start, rect}; -} - -MachineView make_1d_machine_view(cpu_id_t start, cpu_id_t stop, int stride) { - StridedRectangle rect = make_1d_rect(start.value(), stop.value(), stride); - return {start, rect}; -} - -device_id_t MachineView::at(FFOrdered const &coord) const { - size_t offset = this->rect.at(coord); - return this->start + offset; -} - -} // namespace FlexFlow diff --git a/lib/pcg/src/operator.cc b/lib/pcg/src/operator.cc deleted file mode 100644 index 9d36ae1b25..0000000000 --- a/lib/pcg/src/operator.cc +++ /dev/null @@ -1,9 +0,0 @@ -#include "pcg/operator.h" - -namespace FlexFlow { - -Operator::operator PCGOperatorAttrs() const { - return attrs; -} - -} // namespace FlexFlow diff --git a/lib/pcg/src/parallel_computation_graph.cc b/lib/pcg/src/parallel_computation_graph.cc deleted file mode 100644 index 011c40eb4c..0000000000 --- a/lib/pcg/src/parallel_computation_graph.cc +++ /dev/null @@ -1,40 +0,0 @@ -#include "pcg/parallel_computation_graph.h" -#include "utils/graph/algorithms.h" - -namespace FlexFlow { - -bool operator==(ParallelComputationGraph const &lhs, - ParallelComputationGraph const &rhs) { - return std::hash{}(lhs) == - std::hash{}(rhs); -} - -} // namespace FlexFlow - -namespace std { - -size_t hash::operator()( - FlexFlow::ParallelComputationGraph const &g) const { - using namespace FlexFlow; - - size_t h = 0; - - std::vector ordered_nodes = get_topological_ordering(g.value()); - hash_combine(h, ordered_nodes.size()); - - std::unordered_map node_index; - for (int i = 0; i < ordered_nodes.size(); ++i) { - node_index[ordered_nodes[i]] = i; - hash_combine(h, g.value().at(ordered_nodes[i])); - } - - for (MultiDiEdge const &edge : get_edges(g.value())) { - hash_combine(h, node_index.at(edge.src)); - hash_combine(h, node_index.at(edge.dst)); - hash_combine(h, g.value().at(edge)); - } - - return h; -} - -} // namespace std diff --git a/lib/pcg/src/parallel_tensor.cc b/lib/pcg/src/parallel_tensor.cc deleted file mode 100644 index ff53e456ec..0000000000 --- a/lib/pcg/src/parallel_tensor.cc +++ /dev/null @@ -1,17 +0,0 @@ -#include "pcg/parallel_tensor.h" - -namespace FlexFlow { - -ParallelTensor::ParallelTensor(ParallelTensorDims const &dims, - DataType data_type, - CreateGrad create_gradients, - std::optional sync_type, - std::optional initializer) - : dims(dims), data_type(data_type), sync_type(sync_type), - initializer(initializer), create_gradients(create_gradients) {} - -ParallelTensorShape ParallelTensor::get_shape() const { - return ParallelTensorShape(dims, data_type); -} - -} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/computation_graph.cc b/lib/pcg/src/pcg/computation_graph.cc new file mode 100644 index 0000000000..12a72ca837 --- /dev/null +++ b/lib/pcg/src/pcg/computation_graph.cc @@ -0,0 +1,20 @@ +#include "pcg/computation_graph.h" +#include "utils/containers.h" + +namespace FlexFlow { + +ComputationGraph make_empty_computation_graph() { + return ComputationGraph{DataflowGraph{}}; +} + +std::unordered_set get_layers(ComputationGraph const &cg) { + return transform(get_nodes(cg.raw_graph), + [&](Node const &n) { return layer_guid_t{n}; }); +} + +TensorAttrs get_tensor_attrs(ComputationGraph const &cg, + tensor_guid_t const &t) { + return cg.raw_graph.at(t.raw_graph_output); +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/computation_graph.dtg.cc b/lib/pcg/src/pcg/computation_graph.dtg.cc new file mode 100644 index 0000000000..bb6233a910 --- /dev/null +++ b/lib/pcg/src/pcg/computation_graph.dtg.cc @@ -0,0 +1,21 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/computation_graph.struct.toml +/* proj-data +{ + "generated_from": "8f1f0e13d75065944f7fe307e12fe280" +} +*/ + +#include "pcg/computation_graph.dtg.h" + +#include "pcg/dataflow_graph.h" +#include "pcg/layer_attrs.dtg.h" +#include "pcg/tensor_attrs.dtg.h" + +namespace FlexFlow { +ComputationGraph::ComputationGraph( + ::FlexFlow::DataflowGraph<::FlexFlow::LayerAttrs, + ::FlexFlow::TensorAttrs> const &raw_graph) + : raw_graph(raw_graph) {} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/computation_graph/layer_added_result.dtg.cc b/lib/pcg/src/pcg/computation_graph/layer_added_result.dtg.cc new file mode 100644 index 0000000000..18b394f6d0 --- /dev/null +++ b/lib/pcg/src/pcg/computation_graph/layer_added_result.dtg.cc @@ -0,0 +1,43 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/computation_graph/layer_added_result.struct.toml +/* proj-data +{ + "generated_from": "15bf9d73ef934599c9b11807d86ae5d4" +} +*/ + +#include "pcg/computation_graph/layer_added_result.dtg.h" + +#include "pcg/layer_guid_t.dtg.h" +#include "pcg/tensor_guid_t.dtg.h" +#include + +namespace FlexFlow { +LayerAddedResult::LayerAddedResult( + ::FlexFlow::layer_guid_t const &layer, + std::vector<::FlexFlow::tensor_guid_t> const &outputs) + : layer(layer), outputs(outputs) {} +bool LayerAddedResult::operator==(LayerAddedResult const &other) const { + return std::tie(this->layer, this->outputs) == + std::tie(other.layer, other.outputs); +} +bool LayerAddedResult::operator!=(LayerAddedResult const &other) const { + return std::tie(this->layer, this->outputs) != + std::tie(other.layer, other.outputs); +} +} // namespace FlexFlow + +namespace FlexFlow { +std::string format_as(LayerAddedResult const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, LayerAddedResult const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc similarity index 52% rename from lib/pcg/src/computation_graph_builder.cc rename to lib/pcg/src/pcg/computation_graph_builder.cc index f237232a76..8c69b3a724 100644 --- a/lib/pcg/src/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -1,94 +1,151 @@ #include "pcg/computation_graph_builder.h" +#include "op-attrs/computation_graph_op_attrs.h" #include "op-attrs/get_op_type.h" #include "op-attrs/get_output_shapes.h" +#include "op-attrs/ops/element_binary.h" +#include "op-attrs/ops/embedding.h" +#include "op-attrs/ops/weight_attrs.dtg.h" +#include "pcg/computation_graph.h" +#include "utils/containers.h" +#include "utils/containers/concat_vectors.h" +#include "utils/containers/enumerate_vector.h" #include "utils/expected.h" #include "utils/fmt.h" namespace FlexFlow { -tensor_guid_t ComputationGraphBuilder::add_layer( - Layer const &layer, - std::vector const &inputs, - std::vector>> const - &weight_shapes, - TensorShape const &output_shape) { - operator_guid_t node = create_node(computation_graph, layer); - connect_incoming_edges(computation_graph, inputs, node); - return create_outgoing_edge(computation_graph, - node, - 0, - construct_tensor_from_output_shape(output_shape)); +ComputationGraphBuilder::ComputationGraphBuilder() + : computation_graph(make_empty_computation_graph()) {} + +TensorShape ComputationGraphBuilder::get_shape(tensor_guid_t const &t) const { + return get_tensor_attrs(this->computation_graph, t).shape; +} + +tensor_guid_t ComputationGraphBuilder::create_tensor(TensorShape const &shape, + bool create_grad) { + TensorAttrs tensor_attrs = {shape, std::nullopt, create_grad, std::nullopt}; + LayerAttrs layer_attrs = LayerAttrs{ + ComputationGraphOpAttrs{InputAttrs{}}, + std::nullopt, + }; + + return this->add_layer(layer_attrs, {}, {}, tensor_attrs); } std::vector ComputationGraphBuilder::add_layer( - Layer const &layer, + LayerAttrs const &layer, std::vector const &inputs, - std::vector>> const - &weight_shapes, - std::vector const &output_shapes) { - operator_guid_t node = create_node(computation_graph, layer); - connect_incoming_edges(computation_graph, inputs, node); - std::vector output_tensor_guids; - for (int i = 0; i < output_shapes.size(); ++i) { - output_tensor_guids.push_back(create_outgoing_edge( - computation_graph, - node, - i, - construct_tensor_from_output_shape(output_shapes[i]))); + std::vector const &weights, + std::vector const &outputs) { + std::vector raw_weight_tensors; + for (auto const &kv : enumerate_vector(weights)) { + int weight_idx = kv.first; + TensorAttrs weight_tensor_attrs = kv.second; + + std::optional weight_name = + transform(layer.name, [&](std::string const &layer_name) { + return fmt::format("{}.weights[{}]", layer_name, weight_idx); + }); + LayerAttrs weight_layer_attrs = LayerAttrs{ + ComputationGraphOpAttrs{WeightAttrs{}}, + weight_name, + }; + std::vector weight_layer_inputs = {}; + std::vector weight_output_attrs = {weight_tensor_attrs}; + raw_weight_tensors.push_back( + get_only(this->computation_graph.raw_graph.add_operator( + weight_layer_attrs, weight_layer_inputs, weight_output_attrs))); } - return output_tensor_guids; + + std::vector raw_inputs = transform( + inputs, [](tensor_guid_t const &t) { return t.raw_graph_output; }); + std::vector raw_outputs = + this->computation_graph.raw_graph.add_operator( + layer, concat_vectors(raw_inputs, raw_weight_tensors), outputs); + return transform(raw_outputs, + [](MultiDiOutput const &o) { return tensor_guid_t{o}; }); } -tensor_guid_t ComputationGraphBuilder::broadcast(tensor_guid_t const &, - TensorShape const &) { - NOT_IMPLEMENTED(); +tensor_guid_t + ComputationGraphBuilder::add_layer(LayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + TensorAttrs const &output) { + std::vector outputs = {output}; + return get_only(this->add_layer(layer, inputs, weights, outputs)); } + +std::vector ComputationGraphBuilder::add_layer( + LayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + std::vector const &outputs) { + return this->add_layer( + layer, inputs, weights, transform(outputs, [](TensorShape const &s) { + return TensorAttrs{s, std::nullopt, true, std::nullopt}; + })); +} + tensor_guid_t - ComputationGraphBuilder::cast(tensor_guid_t const &input, - DataType dtype, - std::optional const &name){ - NOT_IMPLEMENTED()} + ComputationGraphBuilder::add_layer(LayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + TensorShape const &output) { + return get_only(this->add_layer( + layer, inputs, weights, std::vector{output})); +} tensor_guid_t ComputationGraphBuilder::as_type(tensor_guid_t const &x, DataType data_type, std::string const &name) { - Tensor tensor = computation_graph.at(x); - if (tensor.data_type < data_type) { + DataType x_datatype = this->get_shape(x).data_type; + if (x_datatype < data_type) { return this->cast(x, data_type, name); - } else if (tensor.data_type > data_type) { - throw mk_runtime_error("Could not convert provided tensor data type {} to " - "desired data type {}", - tensor.data_type, - data_type); + } else if (x_datatype > data_type) { + throw mk_runtime_error( + fmt::format("Could not convert provided tensor data type {} to " + "desired data type {}", + x_datatype, + data_type)); + } else { + return x; } - return x; +} + +tensor_guid_t ComputationGraphBuilder::broadcast(tensor_guid_t const &, + TensorShape const &) { + NOT_IMPLEMENTED(); +} + +tensor_guid_t + ComputationGraphBuilder::cast(tensor_guid_t const &input, + DataType dtype, + std::optional const &name) { + NOT_IMPLEMENTED() } static std::string get_default_name(OperatorType op_type) { return get_operator_type_name(op_type); } -static std::string get_default_name(ComputationGraphAttrs const &attrs) { +static std::string get_default_name(ComputationGraphOpAttrs const &attrs) { return get_default_name(get_op_type(attrs)); } -template -static std::string get_default_name(std::variant const &attrs) { - return get_default_name(widen(attrs)); -} - tensor_guid_t ComputationGraphBuilder::element_unary( ElementUnaryAttrs const &attrs, tensor_guid_t const &x, std::optional const &maybe_name) { - std::string name = maybe_name.value_or(get_default_name(attrs)); + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); tensor_guid_t input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); - Layer layer = {attrs, name}; + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; + TensorShape output_shape = - get_output_shape(attrs, computation_graph.at(input)); + throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); return this->add_layer(layer, {input}, {}, output_shape); } @@ -97,14 +154,16 @@ tensor_guid_t ComputationGraphBuilder::element_scalar_unary( ElementScalarUnaryAttrs const &attrs, tensor_guid_t const &x, std::optional const &maybe_name) { - std::string name = maybe_name.value_or(get_default_name(attrs)); + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); tensor_guid_t input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); - Layer layer = {attrs, name}; + LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; + TensorShape output_shape = - get_output_shape(attrs, computation_graph.at(input)); + throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); return this->add_layer(layer, {input}, {}, output_shape); } @@ -133,146 +192,109 @@ tensor_guid_t ComputationGraphBuilder::element_binary( std::optional const &maybe_name) { std::string name = maybe_name.value_or(get_default_name(op_type)); - Tensor lhs_tensor = computation_graph.at(lhs); - Tensor rhs_tensor = computation_graph.at(rhs); - - TensorShape compute_shape = - this->get_broadcast_target_shape({lhs_tensor, rhs_tensor}); - DataType compute_type = std::max(lhs_tensor.data_type, rhs_tensor.data_type); + TensorShape compute_shape = this->get_broadcast_target_shape({lhs, rhs}); + DataType compute_type = + std::max(this->get_shape(lhs).data_type, this->get_shape(rhs).data_type); - tensor_guid_t const lhs_input = - this->as_type(this->broadcast(lhs, compute_shape), - compute_type, - name + "_inputl_pre_cast"); - tensor_guid_t const rhs_input = - this->as_type(this->broadcast(rhs, compute_shape), - compute_type, - name + "_inputr_pre_cast"); + tensor_guid_t lhs_input = this->as_type(this->broadcast(lhs, compute_shape), + compute_type, + name + "_inputl_pre_cast"); + tensor_guid_t rhs_input = this->as_type(this->broadcast(rhs, compute_shape), + compute_type, + name + "_inputr_pre_cast"); ElementBinaryAttrs attrs = {op_type, compute_type, false, false}; - Layer layer = {attrs, name}; - TensorShape output_shape = get_output_shape( - attrs, computation_graph.at(lhs_input), computation_graph.at(rhs_input)); + LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; - return this->add_layer(layer, {lhs_input, rhs_input}, {}, output_shape); -} + TensorShape output_shape = throw_if_unexpected(get_output_shape( + attrs, this->get_shape(lhs_input), this->get_shape(rhs_input))); -tensor_guid_t ComputationGraphBuilder::dense( - tensor_guid_t const &input, - int outDim, - std::optional activation, - bool use_bias, - DataType data_type, - std::optional const &kernel_initializer, - std::optional const &bias_initializer, - std::optional const &name) { - LinearAttrs attrs = { - outDim, use_bias, data_type, activation.value(), std::nullopt}; - std::string unwrapped_name = name.value_or(get_default_name(attrs)); - - tensor_guid_t input_recast = - this->as_type(input, data_type, unwrapped_name + "input_recast"); - - Tensor input_recast_tensor = computation_graph.at(input_recast); - Layer layer = {attrs, name}; - TensorShape output_shape = get_output_shape(attrs, input_recast_tensor); - Tensor output = { - output_shape.dims, data_type, std::nullopt, false, std::nullopt}; - - std::vector>> weights; - - weights.push_back( - {get_weights_shape(attrs, input_recast_tensor), kernel_initializer}); - - if (use_bias) { - weights.push_back( - {get_bias_shape(attrs, input_recast_tensor), bias_initializer}); - } - - return this->add_layer(layer, {input_recast}, weights, output); + return this->add_layer(layer, {lhs_input, rhs_input}, {}, output_shape); } tensor_guid_t ComputationGraphBuilder::exp(tensor_guid_t const &input, std::optional const &name) { - return this->element_unary(Op::EXP, input, name); + return this->element_unary(OperatorType::EXP, input, name); } tensor_guid_t ComputationGraphBuilder::add(tensor_guid_t const &lhs, tensor_guid_t const &rhs, std::optional const &name) { - return this->element_binary(Op::EW_ADD, lhs, rhs, name); + return this->element_binary(OperatorType::EW_ADD, lhs, rhs, name); } tensor_guid_t ComputationGraphBuilder::subtract(tensor_guid_t const &lhs, tensor_guid_t const &rhs, std::optional const &name) { - return this->element_binary(Op::EW_SUB, lhs, rhs, name); + return this->element_binary(OperatorType::EW_SUB, lhs, rhs, name); } tensor_guid_t ComputationGraphBuilder::multiply(tensor_guid_t const &lhs, tensor_guid_t const &rhs, std::optional const &name) { - return this->element_binary(Op::EW_MUL, lhs, rhs, name); + return this->element_binary(OperatorType::EW_MUL, lhs, rhs, name); } tensor_guid_t ComputationGraphBuilder::divide(tensor_guid_t const &lhs, tensor_guid_t const &rhs, std::optional const &name) { - return this->element_binary(Op::EW_DIV, lhs, rhs, name); + return this->element_binary(OperatorType::EW_DIV, lhs, rhs, name); } tensor_guid_t ComputationGraphBuilder::max(tensor_guid_t const &lhs, tensor_guid_t const &rhs, std::optional const &name) { - return this->element_binary(Op::EW_MAX, lhs, rhs, name); + return this->element_binary(OperatorType::EW_MAX, lhs, rhs, name); } tensor_guid_t ComputationGraphBuilder::min(tensor_guid_t const &lhs, tensor_guid_t const &rhs, std::optional const &name) { - return this->element_binary(Op::EW_MIN, lhs, rhs, name); + return this->element_binary(OperatorType::EW_MIN, lhs, rhs, name); } tensor_guid_t ComputationGraphBuilder::rsqrt(tensor_guid_t const &input, std::optional const &name) { - return this->element_unary(Op::RSQRT, input, name); + return this->element_unary(OperatorType::RSQRT, input, name); } tensor_guid_t ComputationGraphBuilder::pow(tensor_guid_t const &input, float exponent, std::optional const &name) { - return this->element_scalar_unary(Op::POW, input, exponent, name); + return this->element_scalar_unary(OperatorType::POW, input, exponent, name); } tensor_guid_t ComputationGraphBuilder::scalar_multiply( tensor_guid_t const &input, float scalar, std::optional const &name) { - return this->element_scalar_unary(Op::SCALAR_MULTIPLY, input, scalar, name); + return this->element_scalar_unary( + OperatorType::SCALAR_MULTIPLY, input, scalar, name); } tensor_guid_t ComputationGraphBuilder::scalar_add( tensor_guid_t const &input, float scalar, std::optional const &name) { - return this->element_scalar_unary(Op::SCALAR_ADD, input, scalar, name); + return this->element_scalar_unary( + OperatorType::SCALAR_ADD, input, scalar, name); } tensor_guid_t ComputationGraphBuilder::scalar_sub( tensor_guid_t const &lhs, float rhs, std::optional const &name) { - return this->element_scalar_unary(Op::SCALAR_SUB, lhs, rhs, name); + return this->element_scalar_unary(OperatorType::SCALAR_SUB, lhs, rhs, name); } tensor_guid_t ComputationGraphBuilder::scalar_truediv( @@ -280,55 +302,61 @@ tensor_guid_t ComputationGraphBuilder::scalar_truediv( float denominator, std::optional const &name) { return this->element_scalar_unary( - Op::SCALAR_TRUE_DIV, numerator, denominator, name); + OperatorType::SCALAR_TRUE_DIV, numerator, denominator, name); } tensor_guid_t ComputationGraphBuilder::sin(tensor_guid_t const &input, std::optional const &name) { - return this->element_unary(Op::SIN, input, name); + return this->element_unary(OperatorType::SIN, input, name); } tensor_guid_t ComputationGraphBuilder::cos(tensor_guid_t const &input, std::optional const &name) { - return this->element_unary(Op::COS, input, name); + return this->element_unary(OperatorType::COS, input, name); } tensor_guid_t ComputationGraphBuilder::relu(tensor_guid_t const &input, std::optional const &name) { - return this->element_unary(Op::RELU, input, name); + return this->element_unary(OperatorType::RELU, input, name); } tensor_guid_t ComputationGraphBuilder::identity(tensor_guid_t const &input, std::optional const &name) { - return this->element_unary(Op::IDENTITY, input, name); + return this->element_unary(OperatorType::IDENTITY, input, name); } tensor_guid_t ComputationGraphBuilder::gelu(tensor_guid_t const &input, std::optional const &name) { - return this->element_unary(Op::GELU, input, name); + return this->element_unary(OperatorType::GELU, input, name); } tensor_guid_t ComputationGraphBuilder::sigmoid(tensor_guid_t const &input, std::optional const &name) { - return this->element_unary(Op::SIGMOID, input, name); + return this->element_unary(OperatorType::SIGMOID, input, name); } tensor_guid_t ComputationGraphBuilder::tanh(tensor_guid_t const &input, std::optional const &name) { - return this->element_unary(Op::TANH, input, name); + return this->element_unary(OperatorType::TANH, input, name); } tensor_guid_t ComputationGraphBuilder::elu(tensor_guid_t const &input, std::optional const &name) { - return this->element_unary(Op::ELU, input, name); + return this->element_unary(OperatorType::ELU, input, name); +} + +static TensorAttrs make_weight_attrs( + TensorShape const &shape, + std::optional const &initializer_attrs) { + return TensorAttrs{shape, initializer_attrs, true, std::nullopt}; } tensor_guid_t ComputationGraphBuilder::conv2d( @@ -343,8 +371,8 @@ tensor_guid_t ComputationGraphBuilder::conv2d( std::optional const &activation, int groups, bool use_bias, - std::optional const &kernel_initializer, - std::optional const &bias_initializer, + std::optional const &kernel_initializer, + std::optional const &bias_initializer, std::optional const &kernel_regularizer, std::optional const &maybe_name) { Conv2DAttrs attrs = {outChannels, @@ -357,23 +385,26 @@ tensor_guid_t ComputationGraphBuilder::conv2d( groups, activation, use_bias}; - std::string name = maybe_name.value_or(get_default_name(attrs)); + + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); tensor_guid_t input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); - Tensor input_tensor = computation_graph.at(input); + LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; - Layer layer = {attrs, name}; - TensorShape output_shape = get_output_shape(attrs, input_tensor); + TensorShape input_shape = this->get_shape(input); + TensorShape output_shape = get_output_shape(attrs, input_shape); - std::vector>> weights; + std::vector weights; - weights.push_back( - {get_kernel_shape(attrs, input_tensor), kernel_initializer}); + weights.push_back(make_weight_attrs(get_kernel_shape(attrs, input_shape), + kernel_initializer)); if (use_bias) { - weights.push_back({get_bias_shape(attrs, input_tensor), bias_initializer}); + weights.push_back(make_weight_attrs(get_bias_shape(attrs, input_shape), + bias_initializer)); } return this->add_layer(layer, {input}, weights, output_shape); @@ -385,14 +416,14 @@ tensor_guid_t ComputationGraphBuilder::dropout( unsigned long long seed, std::optional const &maybe_name) { DropoutAttrs attrs = {rate, seed}; - std::string name = maybe_name.value_or(get_default_name(attrs)); + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); - Layer layer = {attrs, name}; + LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; tensor_guid_t input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); - TensorShape output_shape = - get_output_shape(attrs, computation_graph.at(input)); + TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); return this->add_layer(layer, {input}, {}, output_shape); } @@ -403,21 +434,26 @@ tensor_guid_t ComputationGraphBuilder::embedding( int outDim, AggregateOp aggr, DataType dtype, - std::optional const &kernel_initializer, + std::optional const &kernel_initializer, std::optional const &maybe_name) { EmbeddingAttrs attrs = {num_entries, outDim, aggr, dtype}; - std::string name = maybe_name.value_or(get_default_name(attrs)); + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); - Layer layer = {attrs, name}; + LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; tensor_guid_t input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); - Tensor input_tensor = computation_graph.at(input); - TensorShape output_shape = get_output_shape(attrs, input_tensor); - TensorShape weights_shape = get_weights_shape(attrs, input_tensor); + TensorShape input_shape = this->get_shape(input); - return this->add_layer( - layer, {input}, {{weights_shape, kernel_initializer}}, output_shape); + TensorAttrs weight_attrs = make_weight_attrs( + throw_if_unexpected(get_weights_shape(attrs, input_shape)), + kernel_initializer); + + TensorShape output_shape = + throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); + + return this->add_layer(layer, {input}, {weight_attrs}, output_shape); } std::vector ComputationGraphBuilder::gather( @@ -426,42 +462,30 @@ std::vector ComputationGraphBuilder::gather( ff_dim_t dim, std::optional const &maybe_name) { GatherAttrs attrs = {dim}; - std::string name = maybe_name.value_or(get_default_name(attrs)); + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); - Layer layer = {attrs, name}; - Tensor index_tensor = computation_graph.at(index); - if (index_tensor.data_type != DataType::INT32 && - index_tensor.data_type != DataType::INT64) { + LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; + if (this->get_shape(index).data_type != DataType::INT32 && + this->get_shape(index).data_type != DataType::INT64) { throw mk_runtime_error("Invalid data type for input tensor 2 for Gather: " "{} (should be {} or {})", - index_tensor.data_type, + this->get_shape(input).data_type, DataType::INT32, DataType::INT64); } std::vector output_shapes = - get_output_shapes(attrs, computation_graph.at(input), index_tensor); + get_output_shapes(attrs, this->get_shape(input), this->get_shape(index)); return this->add_layer(layer, {input}, {}, output_shapes); } -tensor_guid_t - ComputationGraphBuilder::input(Tensor const &input_tensor, - std::optional const &name) { - InputAttrs input_attrs = {}; - std::string str_name = name.value_or(get_default_name(input_attrs)); - - Layer layer = {input_attrs, str_name}; - - return this->add_layer(layer, {}, {}, input_tensor); -} - -TensorShape ComputationGraphBuilder::get_shape(tensor_guid_t const &t) { - return computation_graph.at(t).get_shape(); -} -std::vector - ComputationGraphBuilder::get_shapes(std::vector const &) { - NOT_IMPLEMENTED(); -} +/* std::vector + * ComputationGraphBuilder::get_shapes(std::vector const &ts) + * const { */ +/* return transform(ts, [&](tensor_guid_t const &t) { return + * this->get_shape(t); }); */ +/* } */ // tensor_guid_t ComputationGraphBuilder::aggregate( // tensor_guid_t const &gate_preds, @@ -475,7 +499,7 @@ std::vector // AggregateAttrs attrs = {n, lambda_bal}; // std::string name = maybe_name.value_or(get_default_name(attrs)); -// Layer layer = {attrs, name}; +// LayerAttrs layer = {attrs, name}; // TensorShape output_shape = get_output_shape(attrs, // get_shape(gate_preds), // get_shape(gate_assign), @@ -494,15 +518,21 @@ tensor_guid_t ComputationGraphBuilder::batch_norm( bool relu, std::optional const &maybe_name) { BatchNormAttrs attrs = BatchNormAttrs{relu}; - std::string name = maybe_name.value_or(get_default_name(attrs)); + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); - Layer layer = {attrs, name}; + LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; - TensorShape output_shape = get_output_shape(attrs, get_shape(input)); + TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); return this->add_layer(layer, {input}, {}, output_shape); } +TensorShape ComputationGraphBuilder::get_broadcast_target_shape( + std::vector const &) { + NOT_IMPLEMENTED(); +} + TensorShape ComputationGraphBuilder::get_broadcast_target_shape( std::vector const &) { NOT_IMPLEMENTED(); diff --git a/lib/pcg/src/pcg/cpu_id_t.dtg.cc b/lib/pcg/src/pcg/cpu_id_t.dtg.cc new file mode 100644 index 0000000000..f865442eb0 --- /dev/null +++ b/lib/pcg/src/pcg/cpu_id_t.dtg.cc @@ -0,0 +1,74 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/cpu_id_t.struct.toml +/* proj-data +{ + "generated_from": "a0faf78831febfa3a02929169943d9f5" +} +*/ + +#include "pcg/cpu_id_t.dtg.h" + +#include + +namespace FlexFlow { +cpu_id_t::cpu_id_t(int const &cpu_index) : cpu_index(cpu_index) {} +bool cpu_id_t::operator==(cpu_id_t const &other) const { + return std::tie(this->cpu_index) == std::tie(other.cpu_index); +} +bool cpu_id_t::operator!=(cpu_id_t const &other) const { + return std::tie(this->cpu_index) != std::tie(other.cpu_index); +} +bool cpu_id_t::operator<(cpu_id_t const &other) const { + return std::tie(this->cpu_index) < std::tie(other.cpu_index); +} +bool cpu_id_t::operator>(cpu_id_t const &other) const { + return std::tie(this->cpu_index) > std::tie(other.cpu_index); +} +bool cpu_id_t::operator<=(cpu_id_t const &other) const { + return std::tie(this->cpu_index) <= std::tie(other.cpu_index); +} +bool cpu_id_t::operator>=(cpu_id_t const &other) const { + return std::tie(this->cpu_index) >= std::tie(other.cpu_index); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()(FlexFlow::cpu_id_t const &x) const { + size_t result = 0; + result ^= std::hash{}(x.cpu_index) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::cpu_id_t + adl_serializer::from_json(json const &j) { + return {j.at("cpu_index").template get()}; +} +void adl_serializer::to_json(json &j, + FlexFlow::cpu_id_t const &v) { + j["__type"] = "cpu_id_t"; + j["cpu_index"] = v.cpu_index; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(cpu_id_t const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, cpu_id_t const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/create_grad.dtg.cc b/lib/pcg/src/pcg/create_grad.dtg.cc new file mode 100644 index 0000000000..b2b7e3233b --- /dev/null +++ b/lib/pcg/src/pcg/create_grad.dtg.cc @@ -0,0 +1,70 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/create_grad.enum.toml +/* proj-data +{ + "generated_from": "9fd617027e850b6d6db476a49b3e0334" +} +*/ + +#include "pcg/create_grad.dtg.h" + +#include +#include + +namespace std { +size_t hash::operator()(FlexFlow::CreateGrad x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(CreateGrad x) { + switch (x) { + case CreateGrad::YES: + return "YES"; + case CreateGrad::NO: + return "NO"; + default: + std::ostringstream oss; + oss << "Unknown CreateGrad value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, CreateGrad x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, CreateGrad x) { + switch (x) { + case CreateGrad::YES: + j = "YES"; + break; + case CreateGrad::NO: + j = "NO"; + break; + default: + std::ostringstream oss; + oss << "Unknown CreateGrad value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, CreateGrad &x) { + std::string as_str = j.get(); + if (as_str == "YES") { + x = CreateGrad::YES; + } else if (as_str == "NO") { + x = CreateGrad::NO; + } else { + std::ostringstream oss; + oss << "Unknown CreateGrad value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::element(FlexFlow::CreateGrad::YES, + FlexFlow::CreateGrad::NO); +} +} // namespace rc diff --git a/lib/pcg/src/pcg/dataflow_input.dtg.cc b/lib/pcg/src/pcg/dataflow_input.dtg.cc new file mode 100644 index 0000000000..bd5a43dfa9 --- /dev/null +++ b/lib/pcg/src/pcg/dataflow_input.dtg.cc @@ -0,0 +1,41 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/dataflow_input.variant.toml +/* proj-data +{ + "generated_from": "d6a7f4570e36e257383529e9bf9390ec" +} +*/ + +#include "pcg/dataflow_input.dtg.h" + +namespace FlexFlow { +DataflowInput::DataflowInput(::FlexFlow::MultiDiOutput const &v) + : raw_variant(v) {} +DataflowInput::DataflowInput(int const &v) : raw_variant(v) {} +bool DataflowInput::operator==(DataflowInput const &other) const { + return this->raw_variant == other.raw_variant; +} +bool DataflowInput::operator!=(DataflowInput const &other) const { + return this->raw_variant != other.raw_variant; +} +bool DataflowInput::operator<(DataflowInput const &other) const { + return this->raw_variant < other.raw_variant; +} +bool DataflowInput::operator>(DataflowInput const &other) const { + return this->raw_variant > other.raw_variant; +} +bool DataflowInput::operator<=(DataflowInput const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool DataflowInput::operator>=(DataflowInput const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::DataflowInput>::operator()( + ::FlexFlow::DataflowInput const &x) const { + return std::hash>{}( + x.raw_variant); +} +} // namespace std diff --git a/lib/pcg/src/pcg/device_id.cc b/lib/pcg/src/pcg/device_id.cc new file mode 100644 index 0000000000..35b0c9aeda --- /dev/null +++ b/lib/pcg/src/pcg/device_id.cc @@ -0,0 +1,32 @@ +#include "pcg/device_id.h" +#include "utils/exception.h" +#include + +namespace FlexFlow { + +device_id_t operator+(device_id_t, size_t) { + NOT_IMPLEMENTED(); +} + +DeviceType get_device_type(device_id_t const &device_id) { + if (device_id.has()) { + return DeviceType::GPU; + } else { + assert(device_id.has()); + return DeviceType::CPU; + } +} + +gpu_id_t unwrap_gpu(device_id_t device_id) { + return device_id.get(); +} + +cpu_id_t unwrap_cpu(device_id_t device_id) { + return device_id.get(); +} + +device_id_t device_id_from_index(int, DeviceType) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/device_id_t.dtg.cc b/lib/pcg/src/pcg/device_id_t.dtg.cc new file mode 100644 index 0000000000..517c6c198c --- /dev/null +++ b/lib/pcg/src/pcg/device_id_t.dtg.cc @@ -0,0 +1,103 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/device_id_t.variant.toml +/* proj-data +{ + "generated_from": "85870050c742b0159775399ec2be67e3" +} +*/ + +#include "pcg/device_id_t.dtg.h" + +#include "fmt/format.h" +#include +#include + +namespace FlexFlow { +device_id_t::device_id_t(::FlexFlow::gpu_id_t const &v) : raw_variant(v) {} +device_id_t::device_id_t(::FlexFlow::cpu_id_t const &v) : raw_variant(v) {} +bool device_id_t::operator==(device_id_t const &other) const { + return this->raw_variant == other.raw_variant; +} +bool device_id_t::operator!=(device_id_t const &other) const { + return this->raw_variant != other.raw_variant; +} +bool device_id_t::operator<(device_id_t const &other) const { + return this->raw_variant < other.raw_variant; +} +bool device_id_t::operator>(device_id_t const &other) const { + return this->raw_variant > other.raw_variant; +} +bool device_id_t::operator<=(device_id_t const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool device_id_t::operator>=(device_id_t const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::device_id_t>::operator()( + ::FlexFlow::device_id_t const &x) const { + return std::hash>{}( + x.raw_variant); +} +} // namespace std +namespace nlohmann { +::FlexFlow::device_id_t + adl_serializer<::FlexFlow::device_id_t>::from_json(json const &j) { + std::string key = j.at("type").template get(); + if (key == "gpu") { + return ::FlexFlow::device_id_t{ + j.at("value").template get<::FlexFlow::gpu_id_t>()}; + } else if (key == "cpu") { + return ::FlexFlow::device_id_t{ + j.at("value").template get<::FlexFlow::cpu_id_t>()}; + } else { + throw std::runtime_error(fmt::format("Unknown type key {}", key)); + } +} +void adl_serializer<::FlexFlow::device_id_t>::to_json( + json &j, ::FlexFlow::device_id_t const &x) { + j["__type"] = "device_id_t"; + switch (x.index()) { + case 0: { + j["type"] = "gpu"; + j["value"] = x.get<::FlexFlow::gpu_id_t>(); + break; + } + case 1: { + j["type"] = "cpu"; + j["value"] = x.get<::FlexFlow::cpu_id_t>(); + break; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type device_id_t", x.index())); + } + } +} +} // namespace nlohmann +namespace FlexFlow { +std::string format_as(::FlexFlow::device_id_t const &x) { + std::ostringstream oss; + switch (x.index()) { + case 0: { + oss << ""; + break; + } + case 1: { + oss << ""; + break; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type device_id_t", x.index())); + break; + } + } + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ::FlexFlow::device_id_t const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/device_type.dtg.cc b/lib/pcg/src/pcg/device_type.dtg.cc new file mode 100644 index 0000000000..8279cc4c16 --- /dev/null +++ b/lib/pcg/src/pcg/device_type.dtg.cc @@ -0,0 +1,70 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/device_type.enum.toml +/* proj-data +{ + "generated_from": "cfe4bc5e9f7c5796b9b90b420c33935f" +} +*/ + +#include "pcg/device_type.dtg.h" + +#include +#include + +namespace std { +size_t hash::operator()(FlexFlow::DeviceType x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(DeviceType x) { + switch (x) { + case DeviceType::GPU: + return "GPU"; + case DeviceType::CPU: + return "CPU"; + default: + std::ostringstream oss; + oss << "Unknown DeviceType value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, DeviceType x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, DeviceType x) { + switch (x) { + case DeviceType::GPU: + j = "GPU"; + break; + case DeviceType::CPU: + j = "CPU"; + break; + default: + std::ostringstream oss; + oss << "Unknown DeviceType value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, DeviceType &x) { + std::string as_str = j.get(); + if (as_str == "GPU") { + x = DeviceType::GPU; + } else if (as_str == "CPU") { + x = DeviceType::CPU; + } else { + std::ostringstream oss; + oss << "Unknown DeviceType value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::element(FlexFlow::DeviceType::GPU, + FlexFlow::DeviceType::CPU); +} +} // namespace rc diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_graph_edge.dtg.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_graph_edge.dtg.cc new file mode 100644 index 0000000000..713aa941d2 --- /dev/null +++ b/lib/pcg/src/pcg/file_format/v1/graphs/v1_graph_edge.dtg.cc @@ -0,0 +1,94 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.struct.toml +/* proj-data +{ + "generated_from": "865097b569b831af049343e933834329" +} +*/ + +#include "pcg/file_format/v1/graphs/v1_graph_edge.dtg.h" + +#include + +namespace FlexFlow { +V1GraphEdge::V1GraphEdge(size_t const &srcNode, + size_t const &srcIdx, + size_t const &dstNode, + size_t const &dstIdx) + : srcNode(srcNode), srcIdx(srcIdx), dstNode(dstNode), dstIdx(dstIdx) {} +bool V1GraphEdge::operator==(V1GraphEdge const &other) const { + return std::tie(this->srcNode, this->srcIdx, this->dstNode, this->dstIdx) == + std::tie(other.srcNode, other.srcIdx, other.dstNode, other.dstIdx); +} +bool V1GraphEdge::operator!=(V1GraphEdge const &other) const { + return std::tie(this->srcNode, this->srcIdx, this->dstNode, this->dstIdx) != + std::tie(other.srcNode, other.srcIdx, other.dstNode, other.dstIdx); +} +bool V1GraphEdge::operator<(V1GraphEdge const &other) const { + return std::tie(this->srcNode, this->srcIdx, this->dstNode, this->dstIdx) < + std::tie(other.srcNode, other.srcIdx, other.dstNode, other.dstIdx); +} +bool V1GraphEdge::operator>(V1GraphEdge const &other) const { + return std::tie(this->srcNode, this->srcIdx, this->dstNode, this->dstIdx) > + std::tie(other.srcNode, other.srcIdx, other.dstNode, other.dstIdx); +} +bool V1GraphEdge::operator<=(V1GraphEdge const &other) const { + return std::tie(this->srcNode, this->srcIdx, this->dstNode, this->dstIdx) <= + std::tie(other.srcNode, other.srcIdx, other.dstNode, other.dstIdx); +} +bool V1GraphEdge::operator>=(V1GraphEdge const &other) const { + return std::tie(this->srcNode, this->srcIdx, this->dstNode, this->dstIdx) >= + std::tie(other.srcNode, other.srcIdx, other.dstNode, other.dstIdx); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::V1GraphEdge const &x) const { + size_t result = 0; + result ^= std::hash{}(x.srcNode) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.srcIdx) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.dstNode) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.dstIdx) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::V1GraphEdge + adl_serializer::from_json(json const &j) { + return {j.at("srcNode").template get(), + j.at("srcIdx").template get(), + j.at("dstNode").template get(), + j.at("dstIdx").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::V1GraphEdge const &v) { + j["__type"] = "V1GraphEdge"; + j["srcNode"] = v.srcNode; + j["srcIdx"] = v.srcIdx; + j["dstNode"] = v.dstNode; + j["dstIdx"] = v.dstIdx; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(V1GraphEdge const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, V1GraphEdge const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_graph_output.dtg.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_graph_output.dtg.cc new file mode 100644 index 0000000000..fa0b792a37 --- /dev/null +++ b/lib/pcg/src/pcg/file_format/v1/graphs/v1_graph_output.dtg.cc @@ -0,0 +1,81 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_output.struct.toml +/* proj-data +{ + "generated_from": "05ff8401c3d976ea2220899edb8dfe3a" +} +*/ + +#include "pcg/file_format/v1/graphs/v1_graph_output.dtg.h" + +#include + +namespace FlexFlow { +V1GraphOutput::V1GraphOutput(size_t const &srcNode, size_t const &srcIdx) + : srcNode(srcNode), srcIdx(srcIdx) {} +bool V1GraphOutput::operator==(V1GraphOutput const &other) const { + return std::tie(this->srcNode, this->srcIdx) == + std::tie(other.srcNode, other.srcIdx); +} +bool V1GraphOutput::operator!=(V1GraphOutput const &other) const { + return std::tie(this->srcNode, this->srcIdx) != + std::tie(other.srcNode, other.srcIdx); +} +bool V1GraphOutput::operator<(V1GraphOutput const &other) const { + return std::tie(this->srcNode, this->srcIdx) < + std::tie(other.srcNode, other.srcIdx); +} +bool V1GraphOutput::operator>(V1GraphOutput const &other) const { + return std::tie(this->srcNode, this->srcIdx) > + std::tie(other.srcNode, other.srcIdx); +} +bool V1GraphOutput::operator<=(V1GraphOutput const &other) const { + return std::tie(this->srcNode, this->srcIdx) <= + std::tie(other.srcNode, other.srcIdx); +} +bool V1GraphOutput::operator>=(V1GraphOutput const &other) const { + return std::tie(this->srcNode, this->srcIdx) >= + std::tie(other.srcNode, other.srcIdx); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::V1GraphOutput const &x) const { + size_t result = 0; + result ^= std::hash{}(x.srcNode) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.srcIdx) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::V1GraphOutput + adl_serializer::from_json(json const &j) { + return {j.at("srcNode").template get(), + j.at("srcIdx").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::V1GraphOutput const &v) { + j["__type"] = "V1GraphOutput"; + j["srcNode"] = v.srcNode; + j["srcIdx"] = v.srcIdx; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(V1GraphOutput const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, V1GraphOutput const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.cc new file mode 100644 index 0000000000..7f7e670782 --- /dev/null +++ b/lib/pcg/src/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.cc @@ -0,0 +1,10 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.struct.toml +/* proj-data +{ + "generated_from": "0595a9f5a6bc19f9a170cb0e42c4202d" +} +*/ + +#include "pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.h" diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_multidigraph.dtg.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_multidigraph.dtg.cc new file mode 100644 index 0000000000..0f5a83b02f --- /dev/null +++ b/lib/pcg/src/pcg/file_format/v1/graphs/v1_multidigraph.dtg.cc @@ -0,0 +1,56 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.struct.toml +/* proj-data +{ + "generated_from": "fb1033385645e54a19c9b44cef0be04b" +} +*/ + +#include "pcg/file_format/v1/graphs/v1_multidigraph.dtg.h" + +#include "pcg/file_format/v1/graphs/v1_graph_edge.dtg.h" +#include "utils/fmt.h" +#include +#include +#include + +namespace FlexFlow { +V1MultiDiGraph::V1MultiDiGraph( + std::vector const &nodes, + std::vector const &ports, + std::unordered_set<::FlexFlow::V1GraphEdge> const &edges) + : nodes(nodes), ports(ports), edges(edges) {} +} // namespace FlexFlow + +namespace nlohmann { +FlexFlow::V1MultiDiGraph + adl_serializer::from_json(json const &j) { + return {j.at("nodes").template get>(), + j.at("ports").template get>(), + j.at("edges") + .template get>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::V1MultiDiGraph const &v) { + j["__type"] = "V1MultiDiGraph"; + j["nodes"] = v.nodes; + j["ports"] = v.ports; + j["edges"] = v.edges; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(V1MultiDiGraph const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, V1MultiDiGraph const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_operator_graph.dtg.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_operator_graph.dtg.cc new file mode 100644 index 0000000000..19f1e09d07 --- /dev/null +++ b/lib/pcg/src/pcg/file_format/v1/graphs/v1_operator_graph.dtg.cc @@ -0,0 +1,52 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.struct.toml +/* proj-data +{ + "generated_from": "5bfd7d8755cfd8cd9dbf57d5c367038e" +} +*/ + +#include "pcg/file_format/v1/graphs/v1_operator_graph.dtg.h" + +#include "pcg/file_format/v1/graphs/v1_graph_edge.dtg.h" +#include "utils/fmt.h" +#include +#include +#include + +namespace FlexFlow { +V1OperatorGraph::V1OperatorGraph( + std::vector const &nodes, + std::unordered_set<::FlexFlow::V1GraphEdge> const &edges) + : nodes(nodes), edges(edges) {} +} // namespace FlexFlow + +namespace nlohmann { +FlexFlow::V1OperatorGraph + adl_serializer::from_json(json const &j) { + return {j.at("nodes").template get>(), + j.at("edges") + .template get>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::V1OperatorGraph const &v) { + j["__type"] = "V1OperatorGraph"; + j["nodes"] = v.nodes; + j["edges"] = v.edges; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(V1OperatorGraph const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, V1OperatorGraph const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/gpu_id_t.dtg.cc b/lib/pcg/src/pcg/gpu_id_t.dtg.cc new file mode 100644 index 0000000000..e2385a83ce --- /dev/null +++ b/lib/pcg/src/pcg/gpu_id_t.dtg.cc @@ -0,0 +1,74 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/gpu_id_t.struct.toml +/* proj-data +{ + "generated_from": "022355e43f43141d332be50ea3080ee2" +} +*/ + +#include "pcg/gpu_id_t.dtg.h" + +#include + +namespace FlexFlow { +gpu_id_t::gpu_id_t(int const &gpu_index) : gpu_index(gpu_index) {} +bool gpu_id_t::operator==(gpu_id_t const &other) const { + return std::tie(this->gpu_index) == std::tie(other.gpu_index); +} +bool gpu_id_t::operator!=(gpu_id_t const &other) const { + return std::tie(this->gpu_index) != std::tie(other.gpu_index); +} +bool gpu_id_t::operator<(gpu_id_t const &other) const { + return std::tie(this->gpu_index) < std::tie(other.gpu_index); +} +bool gpu_id_t::operator>(gpu_id_t const &other) const { + return std::tie(this->gpu_index) > std::tie(other.gpu_index); +} +bool gpu_id_t::operator<=(gpu_id_t const &other) const { + return std::tie(this->gpu_index) <= std::tie(other.gpu_index); +} +bool gpu_id_t::operator>=(gpu_id_t const &other) const { + return std::tie(this->gpu_index) >= std::tie(other.gpu_index); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()(FlexFlow::gpu_id_t const &x) const { + size_t result = 0; + result ^= std::hash{}(x.gpu_index) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::gpu_id_t + adl_serializer::from_json(json const &j) { + return {j.at("gpu_index").template get()}; +} +void adl_serializer::to_json(json &j, + FlexFlow::gpu_id_t const &v) { + j["__type"] = "gpu_id_t"; + j["gpu_index"] = v.gpu_index; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(gpu_id_t const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, gpu_id_t const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/initializer_attrs.dtg.cc b/lib/pcg/src/pcg/initializer_attrs.dtg.cc new file mode 100644 index 0000000000..2a4e97db1e --- /dev/null +++ b/lib/pcg/src/pcg/initializer_attrs.dtg.cc @@ -0,0 +1,158 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/initializer_attrs.variant.toml +/* proj-data +{ + "generated_from": "f66f3a89ea937e96a058d83ab52e2826" +} +*/ + +#include "pcg/initializer_attrs.dtg.h" + +#include "fmt/format.h" +#include +#include + +namespace FlexFlow { +InitializerAttrs::InitializerAttrs(::FlexFlow::GlorotUniformAttrs const &v) + : raw_variant(v) {} +InitializerAttrs::InitializerAttrs(::FlexFlow::ZeroInitializerAttrs const &v) + : raw_variant(v) {} +InitializerAttrs::InitializerAttrs(::FlexFlow::UniformInitializerAttrs const &v) + : raw_variant(v) {} +InitializerAttrs::InitializerAttrs(::FlexFlow::NormInitializerAttrs const &v) + : raw_variant(v) {} +InitializerAttrs::InitializerAttrs( + ::FlexFlow::ConstantInitializerAttrs const &v) + : raw_variant(v) {} +bool InitializerAttrs::operator==(InitializerAttrs const &other) const { + return this->raw_variant == other.raw_variant; +} +bool InitializerAttrs::operator!=(InitializerAttrs const &other) const { + return this->raw_variant != other.raw_variant; +} +bool InitializerAttrs::operator<(InitializerAttrs const &other) const { + return this->raw_variant < other.raw_variant; +} +bool InitializerAttrs::operator>(InitializerAttrs const &other) const { + return this->raw_variant > other.raw_variant; +} +bool InitializerAttrs::operator<=(InitializerAttrs const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool InitializerAttrs::operator>=(InitializerAttrs const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::InitializerAttrs>::operator()( + ::FlexFlow::InitializerAttrs const &x) const { + return std::hash>{}( + x.raw_variant); +} +} // namespace std +namespace nlohmann { +::FlexFlow::InitializerAttrs + adl_serializer<::FlexFlow::InitializerAttrs>::from_json(json const &j) { + std::string key = j.at("type").template get(); + if (key == "glorot_uniform") { + return ::FlexFlow::InitializerAttrs{ + j.at("value").template get<::FlexFlow::GlorotUniformAttrs>()}; + } else if (key == "zero") { + return ::FlexFlow::InitializerAttrs{ + j.at("value").template get<::FlexFlow::ZeroInitializerAttrs>()}; + } else if (key == "uniform") { + return ::FlexFlow::InitializerAttrs{ + j.at("value").template get<::FlexFlow::UniformInitializerAttrs>()}; + } else if (key == "normal") { + return ::FlexFlow::InitializerAttrs{ + j.at("value").template get<::FlexFlow::NormInitializerAttrs>()}; + } else if (key == "constant") { + return ::FlexFlow::InitializerAttrs{ + j.at("value").template get<::FlexFlow::ConstantInitializerAttrs>()}; + } else { + throw std::runtime_error(fmt::format("Unknown type key {}", key)); + } +} +void adl_serializer<::FlexFlow::InitializerAttrs>::to_json( + json &j, ::FlexFlow::InitializerAttrs const &x) { + j["__type"] = "InitializerAttrs"; + switch (x.index()) { + case 0: { + j["type"] = "glorot_uniform"; + j["value"] = x.get<::FlexFlow::GlorotUniformAttrs>(); + break; + } + case 1: { + j["type"] = "zero"; + j["value"] = x.get<::FlexFlow::ZeroInitializerAttrs>(); + break; + } + case 2: { + j["type"] = "uniform"; + j["value"] = x.get<::FlexFlow::UniformInitializerAttrs>(); + break; + } + case 3: { + j["type"] = "normal"; + j["value"] = x.get<::FlexFlow::NormInitializerAttrs>(); + break; + } + case 4: { + j["type"] = "constant"; + j["value"] = x.get<::FlexFlow::ConstantInitializerAttrs>(); + break; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type InitializerAttrs", x.index())); + } + } +} +} // namespace nlohmann +namespace FlexFlow { +std::string format_as(::FlexFlow::InitializerAttrs const &x) { + std::ostringstream oss; + switch (x.index()) { + case 0: { + oss << ""; + break; + } + case 1: { + oss << ""; + break; + } + case 2: { + oss << ""; + break; + } + case 3: { + oss << ""; + break; + } + case 4: { + oss << ""; + break; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type InitializerAttrs", x.index())); + break; + } + } + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::InitializerAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/initializers/constant_initializer_attrs.dtg.cc b/lib/pcg/src/pcg/initializers/constant_initializer_attrs.dtg.cc new file mode 100644 index 0000000000..9770c35248 --- /dev/null +++ b/lib/pcg/src/pcg/initializers/constant_initializer_attrs.dtg.cc @@ -0,0 +1,80 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml +/* proj-data +{ + "generated_from": "0162b9c49fe6cbfc65410c6fa8dec427" +} +*/ + +#include "pcg/initializers/constant_initializer_attrs.dtg.h" + +#include "op-attrs/datatype.h" +#include "utils/json.h" +#include + +namespace FlexFlow { +ConstantInitializerAttrs::ConstantInitializerAttrs( + ::FlexFlow::DataTypeValue const &value) + : value(value) {} +bool ConstantInitializerAttrs::operator==( + ConstantInitializerAttrs const &other) const { + return std::tie(this->value) == std::tie(other.value); +} +bool ConstantInitializerAttrs::operator!=( + ConstantInitializerAttrs const &other) const { + return std::tie(this->value) != std::tie(other.value); +} +bool ConstantInitializerAttrs::operator<( + ConstantInitializerAttrs const &other) const { + return std::tie(this->value) < std::tie(other.value); +} +bool ConstantInitializerAttrs::operator>( + ConstantInitializerAttrs const &other) const { + return std::tie(this->value) > std::tie(other.value); +} +bool ConstantInitializerAttrs::operator<=( + ConstantInitializerAttrs const &other) const { + return std::tie(this->value) <= std::tie(other.value); +} +bool ConstantInitializerAttrs::operator>=( + ConstantInitializerAttrs const &other) const { + return std::tie(this->value) >= std::tie(other.value); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ConstantInitializerAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::DataTypeValue>{}(x.value) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ConstantInitializerAttrs + adl_serializer::from_json( + json const &j) { + return {j.at("value").template get<::FlexFlow::DataTypeValue>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ConstantInitializerAttrs const &v) { + j["__type"] = "ConstantInitializerAttrs"; + j["value"] = v.value; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ConstantInitializerAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ConstantInitializerAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/initializers/glorot_uniform_attrs.dtg.cc b/lib/pcg/src/pcg/initializers/glorot_uniform_attrs.dtg.cc new file mode 100644 index 0000000000..0c8ae6e60c --- /dev/null +++ b/lib/pcg/src/pcg/initializers/glorot_uniform_attrs.dtg.cc @@ -0,0 +1,76 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/initializers/glorot_uniform_attrs.struct.toml +/* proj-data +{ + "generated_from": "a268b411b6d378faa11e60c8517d7be5" +} +*/ + +#include "pcg/initializers/glorot_uniform_attrs.dtg.h" + +#include + +namespace FlexFlow { +GlorotUniformAttrs::GlorotUniformAttrs(int const &seed) : seed(seed) {} +bool GlorotUniformAttrs::operator==(GlorotUniformAttrs const &other) const { + return std::tie(this->seed) == std::tie(other.seed); +} +bool GlorotUniformAttrs::operator!=(GlorotUniformAttrs const &other) const { + return std::tie(this->seed) != std::tie(other.seed); +} +bool GlorotUniformAttrs::operator<(GlorotUniformAttrs const &other) const { + return std::tie(this->seed) < std::tie(other.seed); +} +bool GlorotUniformAttrs::operator>(GlorotUniformAttrs const &other) const { + return std::tie(this->seed) > std::tie(other.seed); +} +bool GlorotUniformAttrs::operator<=(GlorotUniformAttrs const &other) const { + return std::tie(this->seed) <= std::tie(other.seed); +} +bool GlorotUniformAttrs::operator>=(GlorotUniformAttrs const &other) const { + return std::tie(this->seed) >= std::tie(other.seed); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::GlorotUniformAttrs const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.seed) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::GlorotUniformAttrs + adl_serializer::from_json(json const &j) { + return {j.at("seed").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::GlorotUniformAttrs const &v) { + j["__type"] = "GlorotUniformAttrs"; + j["seed"] = v.seed; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(GlorotUniformAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, GlorotUniformAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/initializers/norm_initializer_attrs.dtg.cc b/lib/pcg/src/pcg/initializers/norm_initializer_attrs.dtg.cc new file mode 100644 index 0000000000..aceac12212 --- /dev/null +++ b/lib/pcg/src/pcg/initializers/norm_initializer_attrs.dtg.cc @@ -0,0 +1,96 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/initializers/norm_initializer_attrs.struct.toml +/* proj-data +{ + "generated_from": "6843fc9ca02aea2b40e57dbc497f99ac" +} +*/ + +#include "pcg/initializers/norm_initializer_attrs.dtg.h" + +#include + +namespace FlexFlow { +NormInitializerAttrs::NormInitializerAttrs(int const &seed, + float const &mean, + float const &stddev) + : seed(seed), mean(mean), stddev(stddev) {} +bool NormInitializerAttrs::operator==(NormInitializerAttrs const &other) const { + return std::tie(this->seed, this->mean, this->stddev) == + std::tie(other.seed, other.mean, other.stddev); +} +bool NormInitializerAttrs::operator!=(NormInitializerAttrs const &other) const { + return std::tie(this->seed, this->mean, this->stddev) != + std::tie(other.seed, other.mean, other.stddev); +} +bool NormInitializerAttrs::operator<(NormInitializerAttrs const &other) const { + return std::tie(this->seed, this->mean, this->stddev) < + std::tie(other.seed, other.mean, other.stddev); +} +bool NormInitializerAttrs::operator>(NormInitializerAttrs const &other) const { + return std::tie(this->seed, this->mean, this->stddev) > + std::tie(other.seed, other.mean, other.stddev); +} +bool NormInitializerAttrs::operator<=(NormInitializerAttrs const &other) const { + return std::tie(this->seed, this->mean, this->stddev) <= + std::tie(other.seed, other.mean, other.stddev); +} +bool NormInitializerAttrs::operator>=(NormInitializerAttrs const &other) const { + return std::tie(this->seed, this->mean, this->stddev) >= + std::tie(other.seed, other.mean, other.stddev); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::NormInitializerAttrs const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.seed) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.mean) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.stddev) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::NormInitializerAttrs + adl_serializer::from_json(json const &j) { + return {j.at("seed").template get(), + j.at("mean").template get(), + j.at("stddev").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::NormInitializerAttrs const &v) { + j["__type"] = "NormInitializerAttrs"; + j["seed"] = v.seed; + j["mean"] = v.mean; + j["stddev"] = v.stddev; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary(), gen::arbitrary(), gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(NormInitializerAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, NormInitializerAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/initializers/uniform_initializer_attrs.dtg.cc b/lib/pcg/src/pcg/initializers/uniform_initializer_attrs.dtg.cc new file mode 100644 index 0000000000..a9c62675d0 --- /dev/null +++ b/lib/pcg/src/pcg/initializers/uniform_initializer_attrs.dtg.cc @@ -0,0 +1,95 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/initializers/uniform_initializer_attrs.struct.toml +/* proj-data +{ + "generated_from": "f887e1db5d5dc710793ec5fa99bb7cd4" +} +*/ + +#include "pcg/initializers/uniform_initializer_attrs.dtg.h" + +#include + +namespace FlexFlow { +UniformInitializerAttrs::UniformInitializerAttrs(int const &seed, + float const &min_val, + float const &max_val) + : seed(seed), min_val(min_val), max_val(max_val) {} +bool UniformInitializerAttrs::operator==( + UniformInitializerAttrs const &other) const { + return std::tie(this->seed, this->min_val, this->max_val) == + std::tie(other.seed, other.min_val, other.max_val); +} +bool UniformInitializerAttrs::operator!=( + UniformInitializerAttrs const &other) const { + return std::tie(this->seed, this->min_val, this->max_val) != + std::tie(other.seed, other.min_val, other.max_val); +} +bool UniformInitializerAttrs::operator<( + UniformInitializerAttrs const &other) const { + return std::tie(this->seed, this->min_val, this->max_val) < + std::tie(other.seed, other.min_val, other.max_val); +} +bool UniformInitializerAttrs::operator>( + UniformInitializerAttrs const &other) const { + return std::tie(this->seed, this->min_val, this->max_val) > + std::tie(other.seed, other.min_val, other.max_val); +} +bool UniformInitializerAttrs::operator<=( + UniformInitializerAttrs const &other) const { + return std::tie(this->seed, this->min_val, this->max_val) <= + std::tie(other.seed, other.min_val, other.max_val); +} +bool UniformInitializerAttrs::operator>=( + UniformInitializerAttrs const &other) const { + return std::tie(this->seed, this->min_val, this->max_val) >= + std::tie(other.seed, other.min_val, other.max_val); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::UniformInitializerAttrs const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.seed) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash{}(x.min_val) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.max_val) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::UniformInitializerAttrs + adl_serializer::from_json( + json const &j) { + return {j.at("seed").template get(), + j.at("min_val").template get(), + j.at("max_val").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::UniformInitializerAttrs const &v) { + j["__type"] = "UniformInitializerAttrs"; + j["seed"] = v.seed; + j["min_val"] = v.min_val; + j["max_val"] = v.max_val; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(UniformInitializerAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, UniformInitializerAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/initializers/zero_initializer_attrs.dtg.cc b/lib/pcg/src/pcg/initializers/zero_initializer_attrs.dtg.cc new file mode 100644 index 0000000000..933501a734 --- /dev/null +++ b/lib/pcg/src/pcg/initializers/zero_initializer_attrs.dtg.cc @@ -0,0 +1,71 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/initializers/zero_initializer_attrs.struct.toml +/* proj-data +{ + "generated_from": "a19d5a2cdc67a2840d6ba55250a10411" +} +*/ + +#include "pcg/initializers/zero_initializer_attrs.dtg.h" + +#include + +namespace FlexFlow { +bool ZeroInitializerAttrs::operator==(ZeroInitializerAttrs const &other) const { + return std::tie() == std::tie(); +} +bool ZeroInitializerAttrs::operator!=(ZeroInitializerAttrs const &other) const { + return std::tie() != std::tie(); +} +bool ZeroInitializerAttrs::operator<(ZeroInitializerAttrs const &other) const { + return std::tie() < std::tie(); +} +bool ZeroInitializerAttrs::operator>(ZeroInitializerAttrs const &other) const { + return std::tie() > std::tie(); +} +bool ZeroInitializerAttrs::operator<=(ZeroInitializerAttrs const &other) const { + return std::tie() <= std::tie(); +} +bool ZeroInitializerAttrs::operator>=(ZeroInitializerAttrs const &other) const { + return std::tie() >= std::tie(); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ZeroInitializerAttrs const &x) const { + size_t result = 0; + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ZeroInitializerAttrs + adl_serializer::from_json(json const &j) { + return {}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ZeroInitializerAttrs const &v) { + j["__type"] = "ZeroInitializerAttrs"; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct(); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(ZeroInitializerAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ZeroInitializerAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/layer_attrs.dtg.cc b/lib/pcg/src/pcg/layer_attrs.dtg.cc new file mode 100644 index 0000000000..21c53ad4e8 --- /dev/null +++ b/lib/pcg/src/pcg/layer_attrs.dtg.cc @@ -0,0 +1,84 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/layer_attrs.struct.toml +/* proj-data +{ + "generated_from": "b3e4f0c07a906139b599bd4696cb5e65" +} +*/ + +#include "pcg/layer_attrs.dtg.h" + +#include "op-attrs/computation_graph_op_attrs.dtg.h" +#include "utils/json.h" +#include "utils/stack_string.h" +#include +#include + +namespace FlexFlow { +LayerAttrs::LayerAttrs( + ::FlexFlow::ComputationGraphOpAttrs const &attrs, + std::optional<::FlexFlow::stack_string> const &name) + : attrs(attrs), name(name) {} +bool LayerAttrs::operator==(LayerAttrs const &other) const { + return std::tie(this->attrs, this->name) == std::tie(other.attrs, other.name); +} +bool LayerAttrs::operator!=(LayerAttrs const &other) const { + return std::tie(this->attrs, this->name) != std::tie(other.attrs, other.name); +} +bool LayerAttrs::operator<(LayerAttrs const &other) const { + return std::tie(this->attrs, this->name) < std::tie(other.attrs, other.name); +} +bool LayerAttrs::operator>(LayerAttrs const &other) const { + return std::tie(this->attrs, this->name) > std::tie(other.attrs, other.name); +} +bool LayerAttrs::operator<=(LayerAttrs const &other) const { + return std::tie(this->attrs, this->name) <= std::tie(other.attrs, other.name); +} +bool LayerAttrs::operator>=(LayerAttrs const &other) const { + return std::tie(this->attrs, this->name) >= std::tie(other.attrs, other.name); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::LayerAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ComputationGraphOpAttrs>{}(x.attrs) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash>>{}(x.name) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::LayerAttrs + adl_serializer::from_json(json const &j) { + return { + j.at("attrs").template get<::FlexFlow::ComputationGraphOpAttrs>(), + j.at("name") + .template get>>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::LayerAttrs const &v) { + j["__type"] = "LayerAttrs"; + j["attrs"] = v.attrs; + j["name"] = v.name; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(LayerAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, LayerAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/layer_guid_t.dtg.cc b/lib/pcg/src/pcg/layer_guid_t.dtg.cc new file mode 100644 index 0000000000..9d92608569 --- /dev/null +++ b/lib/pcg/src/pcg/layer_guid_t.dtg.cc @@ -0,0 +1,59 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/layer_guid_t.struct.toml +/* proj-data +{ + "generated_from": "a672ffe470fd1dde8299f91f3038ca7a" +} +*/ + +#include "pcg/layer_guid_t.dtg.h" + +#include "utils/graph.h" +#include + +namespace FlexFlow { +layer_guid_t::layer_guid_t(::FlexFlow::Node const &raw_node) + : raw_node(raw_node) {} +bool layer_guid_t::operator==(layer_guid_t const &other) const { + return std::tie(this->raw_node) == std::tie(other.raw_node); +} +bool layer_guid_t::operator!=(layer_guid_t const &other) const { + return std::tie(this->raw_node) != std::tie(other.raw_node); +} +bool layer_guid_t::operator<(layer_guid_t const &other) const { + return std::tie(this->raw_node) < std::tie(other.raw_node); +} +bool layer_guid_t::operator>(layer_guid_t const &other) const { + return std::tie(this->raw_node) > std::tie(other.raw_node); +} +bool layer_guid_t::operator<=(layer_guid_t const &other) const { + return std::tie(this->raw_node) <= std::tie(other.raw_node); +} +bool layer_guid_t::operator>=(layer_guid_t const &other) const { + return std::tie(this->raw_node) >= std::tie(other.raw_node); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::layer_guid_t const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::Node>{}(x.raw_node) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(layer_guid_t const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, layer_guid_t const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/machine_specification.dtg.cc b/lib/pcg/src/pcg/machine_specification.dtg.cc new file mode 100644 index 0000000000..238c61a014 --- /dev/null +++ b/lib/pcg/src/pcg/machine_specification.dtg.cc @@ -0,0 +1,151 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/machine_specification.struct.toml +/* proj-data +{ + "generated_from": "72c3ae372af189d0c8bae74c2dbbc531" +} +*/ + +#include "pcg/machine_specification.dtg.h" + +#include + +namespace FlexFlow { +MachineSpecification::MachineSpecification(int const &num_nodes, + int const &num_cpus_per_node, + int const &num_gpus_per_node, + float const &inter_node_bandwidth, + float const &intra_node_bandwidth) + : num_nodes(num_nodes), num_cpus_per_node(num_cpus_per_node), + num_gpus_per_node(num_gpus_per_node), + inter_node_bandwidth(inter_node_bandwidth), + intra_node_bandwidth(intra_node_bandwidth) {} +bool MachineSpecification::operator==(MachineSpecification const &other) const { + return std::tie(this->num_nodes, + this->num_cpus_per_node, + this->num_gpus_per_node, + this->inter_node_bandwidth, + this->intra_node_bandwidth) == + std::tie(other.num_nodes, + other.num_cpus_per_node, + other.num_gpus_per_node, + other.inter_node_bandwidth, + other.intra_node_bandwidth); +} +bool MachineSpecification::operator!=(MachineSpecification const &other) const { + return std::tie(this->num_nodes, + this->num_cpus_per_node, + this->num_gpus_per_node, + this->inter_node_bandwidth, + this->intra_node_bandwidth) != + std::tie(other.num_nodes, + other.num_cpus_per_node, + other.num_gpus_per_node, + other.inter_node_bandwidth, + other.intra_node_bandwidth); +} +bool MachineSpecification::operator<(MachineSpecification const &other) const { + return std::tie(this->num_nodes, + this->num_cpus_per_node, + this->num_gpus_per_node, + this->inter_node_bandwidth, + this->intra_node_bandwidth) < + std::tie(other.num_nodes, + other.num_cpus_per_node, + other.num_gpus_per_node, + other.inter_node_bandwidth, + other.intra_node_bandwidth); +} +bool MachineSpecification::operator>(MachineSpecification const &other) const { + return std::tie(this->num_nodes, + this->num_cpus_per_node, + this->num_gpus_per_node, + this->inter_node_bandwidth, + this->intra_node_bandwidth) > + std::tie(other.num_nodes, + other.num_cpus_per_node, + other.num_gpus_per_node, + other.inter_node_bandwidth, + other.intra_node_bandwidth); +} +bool MachineSpecification::operator<=(MachineSpecification const &other) const { + return std::tie(this->num_nodes, + this->num_cpus_per_node, + this->num_gpus_per_node, + this->inter_node_bandwidth, + this->intra_node_bandwidth) <= + std::tie(other.num_nodes, + other.num_cpus_per_node, + other.num_gpus_per_node, + other.inter_node_bandwidth, + other.intra_node_bandwidth); +} +bool MachineSpecification::operator>=(MachineSpecification const &other) const { + return std::tie(this->num_nodes, + this->num_cpus_per_node, + this->num_gpus_per_node, + this->inter_node_bandwidth, + this->intra_node_bandwidth) >= + std::tie(other.num_nodes, + other.num_cpus_per_node, + other.num_gpus_per_node, + other.inter_node_bandwidth, + other.intra_node_bandwidth); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::MachineSpecification const &x) const { + size_t result = 0; + result ^= std::hash{}(x.num_nodes) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.num_cpus_per_node) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.num_gpus_per_node) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.inter_node_bandwidth) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash{}(x.intra_node_bandwidth) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::MachineSpecification + adl_serializer::from_json(json const &j) { + return {j.at("num_nodes").template get(), + j.at("num_cpus_per_node").template get(), + j.at("num_gpus_per_node").template get(), + j.at("inter_node_bandwidth").template get(), + j.at("intra_node_bandwidth").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::MachineSpecification const &v) { + j["__type"] = "MachineSpecification"; + j["num_nodes"] = v.num_nodes; + j["num_cpus_per_node"] = v.num_cpus_per_node; + j["num_gpus_per_node"] = v.num_gpus_per_node; + j["inter_node_bandwidth"] = v.inter_node_bandwidth; + j["intra_node_bandwidth"] = v.intra_node_bandwidth; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(MachineSpecification const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, MachineSpecification const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/machine_view.cc b/lib/pcg/src/pcg/machine_view.cc new file mode 100644 index 0000000000..ff1d34852b --- /dev/null +++ b/lib/pcg/src/pcg/machine_view.cc @@ -0,0 +1,63 @@ +#include "pcg/machine_view.h" +#include "pcg/strided_rectangle.dtg.h" +#include "pcg/strided_rectangle_side.h" + +namespace FlexFlow { + +std::vector device_ids(MachineView const &) { + NOT_IMPLEMENTED(); +} + +std::size_t num_dims(MachineView const &) { + NOT_IMPLEMENTED(); +} + +std::size_t num_devices(MachineView const &) { + NOT_IMPLEMENTED(); +} + +DeviceType get_device_type(MachineView const &) { + NOT_IMPLEMENTED(); +} + +static StridedRectangle make_1d_rect(int start, int stop, int stride) { + assert(stop > start); + assert(stride > 0); + StridedRectangleSide side = + strided_side_from_size_and_stride(side_size_t{stop - start}, stride); + StridedRectangle rect = {{side}}; + return rect; +} + +MachineView make_1d_machine_view(gpu_id_t start, gpu_id_t stop, int stride) { + StridedRectangle rect = make_1d_rect(start.gpu_index, stop.gpu_index, stride); + return {device_id_t{start}, rect}; +} + +MachineView make_1d_machine_view(cpu_id_t start, cpu_id_t stop, int stride) { + StridedRectangle rect = make_1d_rect(start.cpu_index, stop.cpu_index, stride); + return {device_id_t{start}, rect}; +} + +MachineView make_1d_machine_view(device_id_t start, + num_points_t num_points, + int stride) { + NOT_IMPLEMENTED(); +} + +MachineView make_1d_machine_view(device_id_t start, + side_size_t interval_size, + int stride) { + NOT_IMPLEMENTED(); +} + +MachineView make_1d_machine_view(device_id_t start, size_t interval_size) { + NOT_IMPLEMENTED(); +} + +/* device_id_t MachineView::at(FFOrdered const &coord) const { */ +/* size_t offset = this->rect.at(coord); */ +/* return this->start + offset; */ +/* } */ + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/machine_view.dtg.cc b/lib/pcg/src/pcg/machine_view.dtg.cc new file mode 100644 index 0000000000..edab125e3d --- /dev/null +++ b/lib/pcg/src/pcg/machine_view.dtg.cc @@ -0,0 +1,78 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/machine_view.struct.toml +/* proj-data +{ + "generated_from": "16c571e6bb82d7ef88e5d2a9146638f4" +} +*/ + +#include "pcg/machine_view.dtg.h" + +#include "pcg/device_id_t.dtg.h" +#include "pcg/strided_rectangle.dtg.h" +#include + +namespace FlexFlow { +MachineView::MachineView(::FlexFlow::device_id_t const &start, + ::FlexFlow::StridedRectangle const &rect) + : start(start), rect(rect) {} +bool MachineView::operator==(MachineView const &other) const { + return std::tie(this->start, this->rect) == std::tie(other.start, other.rect); +} +bool MachineView::operator!=(MachineView const &other) const { + return std::tie(this->start, this->rect) != std::tie(other.start, other.rect); +} +bool MachineView::operator<(MachineView const &other) const { + return std::tie(this->start, this->rect) < std::tie(other.start, other.rect); +} +bool MachineView::operator>(MachineView const &other) const { + return std::tie(this->start, this->rect) > std::tie(other.start, other.rect); +} +bool MachineView::operator<=(MachineView const &other) const { + return std::tie(this->start, this->rect) <= std::tie(other.start, other.rect); +} +bool MachineView::operator>=(MachineView const &other) const { + return std::tie(this->start, this->rect) >= std::tie(other.start, other.rect); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::MachineView const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::device_id_t>{}(x.start) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::StridedRectangle>{}(x.rect) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::MachineView + adl_serializer::from_json(json const &j) { + return {j.at("start").template get<::FlexFlow::device_id_t>(), + j.at("rect").template get<::FlexFlow::StridedRectangle>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::MachineView const &v) { + j["__type"] = "MachineView"; + j["start"] = v.start; + j["rect"] = v.rect; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(MachineView const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, MachineView const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/num_points_t.dtg.cc b/lib/pcg/src/pcg/num_points_t.dtg.cc new file mode 100644 index 0000000000..7a0a849814 --- /dev/null +++ b/lib/pcg/src/pcg/num_points_t.dtg.cc @@ -0,0 +1,75 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/num_points_t.struct.toml +/* proj-data +{ + "generated_from": "2a862b92055eda0508447d2f4df52f71" +} +*/ + +#include "pcg/num_points_t.dtg.h" + +#include + +namespace FlexFlow { +num_points_t::num_points_t(int const &unwrapped) : unwrapped(unwrapped) {} +bool num_points_t::operator==(num_points_t const &other) const { + return std::tie(this->unwrapped) == std::tie(other.unwrapped); +} +bool num_points_t::operator!=(num_points_t const &other) const { + return std::tie(this->unwrapped) != std::tie(other.unwrapped); +} +bool num_points_t::operator<(num_points_t const &other) const { + return std::tie(this->unwrapped) < std::tie(other.unwrapped); +} +bool num_points_t::operator>(num_points_t const &other) const { + return std::tie(this->unwrapped) > std::tie(other.unwrapped); +} +bool num_points_t::operator<=(num_points_t const &other) const { + return std::tie(this->unwrapped) <= std::tie(other.unwrapped); +} +bool num_points_t::operator>=(num_points_t const &other) const { + return std::tie(this->unwrapped) >= std::tie(other.unwrapped); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::num_points_t const &x) const { + size_t result = 0; + result ^= std::hash{}(x.unwrapped) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::num_points_t + adl_serializer::from_json(json const &j) { + return {j.at("unwrapped").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::num_points_t const &v) { + j["__type"] = "num_points_t"; + j["unwrapped"] = v.unwrapped; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(num_points_t const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, num_points_t const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/operator_graph/operator_graph.cc b/lib/pcg/src/pcg/operator_graph/operator_graph.cc new file mode 100644 index 0000000000..461fc8027c --- /dev/null +++ b/lib/pcg/src/pcg/operator_graph/operator_graph.cc @@ -0,0 +1,48 @@ +#include "pcg/operator_graph/operator_graph.h" +#include "utils/graph.h" + +namespace FlexFlow { + +/* struct OperatorGraphView::Impl { */ +/* MultiDiGraphView raw_graph; */ +/* }; */ + +/* struct OperatorGraph::Impl { */ +/* MultiDiGraph raw_graph; */ +/* }; */ + +/* std::unordered_set get_outputs(OperatorGraphView const + * &g) { */ +/* return transform(get_outputs(g.impl->raw_graph), [](MultiDiOutput const &o) + * {}); */ +/* } */ + +/* std::vector get_outputs(OperatorGraphView const &, Node + * const &) { */ +/* NOT_IMPLEMENTED(); */ +/* } */ + +/* std::unordered_set get_uses(OperatorGraphView const &, + * OperatorGraphOutput const &) { */ +/* NOT_IMPLEMENTED(); */ +/* } */ + +/* Node get_src_node(OperatorGraphEdge const &) { */ +/* NOT_IMPLEMENTED(); */ +/* } */ + +/* Node get_dst_node(OperatorGraphEdge const &) { */ +/* NOT_IMPLEMENTED(); */ +/* } */ + +/* int get_src_idx(OperatorGraphEdge const &) { */ +/* NOT_IMPLEMENTED(); */ +/* } */ + +/* int get_dst_idx(OperatorGraphEdge const &) { */ +/* NOT_IMPLEMENTED(); */ +/* } */ + +/* OperatorGraphView::query_nodes */ + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/operator_graph/operator_graph_input.cc b/lib/pcg/src/pcg/operator_graph/operator_graph_input.cc new file mode 100644 index 0000000000..945034dd73 --- /dev/null +++ b/lib/pcg/src/pcg/operator_graph/operator_graph_input.cc @@ -0,0 +1,13 @@ +#include "pcg/operator_graph/operator_graph_input.h" + +namespace FlexFlow { + +Node get_node(OperatorGraphInput const &i) { + return i.node; +} + +int get_idx(OperatorGraphInput const &i) { + return i.idx; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/operator_graph/operator_graph_input.dtg.cc b/lib/pcg/src/pcg/operator_graph/operator_graph_input.dtg.cc new file mode 100644 index 0000000000..381c948ad0 --- /dev/null +++ b/lib/pcg/src/pcg/operator_graph/operator_graph_input.dtg.cc @@ -0,0 +1,63 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/operator_graph/operator_graph_input.struct.toml +/* proj-data +{ + "generated_from": "57d9c9afc86f43049c6f035c74477afd" +} +*/ + +#include "pcg/operator_graph/operator_graph_input.dtg.h" + +#include "utils/graph.h" +#include + +namespace FlexFlow { +OperatorGraphInput::OperatorGraphInput(::FlexFlow::Node const &node, + int const &idx) + : node(node), idx(idx) {} +bool OperatorGraphInput::operator==(OperatorGraphInput const &other) const { + return std::tie(this->node, this->idx) == std::tie(other.node, other.idx); +} +bool OperatorGraphInput::operator!=(OperatorGraphInput const &other) const { + return std::tie(this->node, this->idx) != std::tie(other.node, other.idx); +} +bool OperatorGraphInput::operator<(OperatorGraphInput const &other) const { + return std::tie(this->node, this->idx) < std::tie(other.node, other.idx); +} +bool OperatorGraphInput::operator>(OperatorGraphInput const &other) const { + return std::tie(this->node, this->idx) > std::tie(other.node, other.idx); +} +bool OperatorGraphInput::operator<=(OperatorGraphInput const &other) const { + return std::tie(this->node, this->idx) <= std::tie(other.node, other.idx); +} +bool OperatorGraphInput::operator>=(OperatorGraphInput const &other) const { + return std::tie(this->node, this->idx) >= std::tie(other.node, other.idx); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::OperatorGraphInput const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::Node>{}(x.node) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= + std::hash{}(x.idx) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(OperatorGraphInput const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, OperatorGraphInput const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/operator_graph/operator_graph_output.cc b/lib/pcg/src/pcg/operator_graph/operator_graph_output.cc new file mode 100644 index 0000000000..bdfe1a9795 --- /dev/null +++ b/lib/pcg/src/pcg/operator_graph/operator_graph_output.cc @@ -0,0 +1,13 @@ +#include "pcg/operator_graph/operator_graph_output.h" + +namespace FlexFlow { + +Node get_node(OperatorGraphOutput const &o) { + return o.node; +} + +int get_idx(OperatorGraphOutput const &o) { + return o.idx; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/operator_graph/operator_graph_output.dtg.cc b/lib/pcg/src/pcg/operator_graph/operator_graph_output.dtg.cc new file mode 100644 index 0000000000..88c23c0c67 --- /dev/null +++ b/lib/pcg/src/pcg/operator_graph/operator_graph_output.dtg.cc @@ -0,0 +1,63 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/operator_graph/operator_graph_output.struct.toml +/* proj-data +{ + "generated_from": "3931cb388b00e0634495cdb89cb2af54" +} +*/ + +#include "pcg/operator_graph/operator_graph_output.dtg.h" + +#include "utils/graph.h" +#include + +namespace FlexFlow { +OperatorGraphOutput::OperatorGraphOutput(::FlexFlow::Node const &node, + int const &idx) + : node(node), idx(idx) {} +bool OperatorGraphOutput::operator==(OperatorGraphOutput const &other) const { + return std::tie(this->node, this->idx) == std::tie(other.node, other.idx); +} +bool OperatorGraphOutput::operator!=(OperatorGraphOutput const &other) const { + return std::tie(this->node, this->idx) != std::tie(other.node, other.idx); +} +bool OperatorGraphOutput::operator<(OperatorGraphOutput const &other) const { + return std::tie(this->node, this->idx) < std::tie(other.node, other.idx); +} +bool OperatorGraphOutput::operator>(OperatorGraphOutput const &other) const { + return std::tie(this->node, this->idx) > std::tie(other.node, other.idx); +} +bool OperatorGraphOutput::operator<=(OperatorGraphOutput const &other) const { + return std::tie(this->node, this->idx) <= std::tie(other.node, other.idx); +} +bool OperatorGraphOutput::operator>=(OperatorGraphOutput const &other) const { + return std::tie(this->node, this->idx) >= std::tie(other.node, other.idx); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::OperatorGraphOutput const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::Node>{}(x.node) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= + std::hash{}(x.idx) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(OperatorGraphOutput const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, OperatorGraphOutput const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/operator_guid_t.dtg.cc b/lib/pcg/src/pcg/operator_guid_t.dtg.cc new file mode 100644 index 0000000000..46b031f7e1 --- /dev/null +++ b/lib/pcg/src/pcg/operator_guid_t.dtg.cc @@ -0,0 +1,59 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/operator_guid_t.struct.toml +/* proj-data +{ + "generated_from": "348b5a610f4ff6f545884564ee9a1e6a" +} +*/ + +#include "pcg/operator_guid_t.dtg.h" + +#include "utils/graph.h" +#include + +namespace FlexFlow { +operator_guid_t::operator_guid_t(::FlexFlow::Node const &raw_graph_node) + : raw_graph_node(raw_graph_node) {} +bool operator_guid_t::operator==(operator_guid_t const &other) const { + return std::tie(this->raw_graph_node) == std::tie(other.raw_graph_node); +} +bool operator_guid_t::operator!=(operator_guid_t const &other) const { + return std::tie(this->raw_graph_node) != std::tie(other.raw_graph_node); +} +bool operator_guid_t::operator<(operator_guid_t const &other) const { + return std::tie(this->raw_graph_node) < std::tie(other.raw_graph_node); +} +bool operator_guid_t::operator>(operator_guid_t const &other) const { + return std::tie(this->raw_graph_node) > std::tie(other.raw_graph_node); +} +bool operator_guid_t::operator<=(operator_guid_t const &other) const { + return std::tie(this->raw_graph_node) <= std::tie(other.raw_graph_node); +} +bool operator_guid_t::operator>=(operator_guid_t const &other) const { + return std::tie(this->raw_graph_node) >= std::tie(other.raw_graph_node); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::operator_guid_t const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::Node>{}(x.raw_graph_node) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(operator_guid_t const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, operator_guid_t const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/optimizers/adam_optimizer_attrs.dtg.cc b/lib/pcg/src/pcg/optimizers/adam_optimizer_attrs.dtg.cc new file mode 100644 index 0000000000..d362459cc3 --- /dev/null +++ b/lib/pcg/src/pcg/optimizers/adam_optimizer_attrs.dtg.cc @@ -0,0 +1,192 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.struct.toml +/* proj-data +{ + "generated_from": "f49e1bebcb0ef2bc3c210073e3183d4d" +} +*/ + +#include "pcg/optimizers/adam_optimizer_attrs.dtg.h" + +#include + +namespace FlexFlow { +AdamOptimizerAttrs::AdamOptimizerAttrs(double const &alpha, + double const &beta1, + double const &beta2, + double const &weight_decay, + double const &alpha_t, + double const &beta_t, + double const &beta2_t) + : alpha(alpha), beta1(beta1), beta2(beta2), weight_decay(weight_decay), + alpha_t(alpha_t), beta_t(beta_t), beta2_t(beta2_t) {} +bool AdamOptimizerAttrs::operator==(AdamOptimizerAttrs const &other) const { + return std::tie(this->alpha, + this->beta1, + this->beta2, + this->weight_decay, + this->alpha_t, + this->beta_t, + this->beta2_t) == std::tie(other.alpha, + other.beta1, + other.beta2, + other.weight_decay, + other.alpha_t, + other.beta_t, + other.beta2_t); +} +bool AdamOptimizerAttrs::operator!=(AdamOptimizerAttrs const &other) const { + return std::tie(this->alpha, + this->beta1, + this->beta2, + this->weight_decay, + this->alpha_t, + this->beta_t, + this->beta2_t) != std::tie(other.alpha, + other.beta1, + other.beta2, + other.weight_decay, + other.alpha_t, + other.beta_t, + other.beta2_t); +} +bool AdamOptimizerAttrs::operator<(AdamOptimizerAttrs const &other) const { + return std::tie(this->alpha, + this->beta1, + this->beta2, + this->weight_decay, + this->alpha_t, + this->beta_t, + this->beta2_t) < std::tie(other.alpha, + other.beta1, + other.beta2, + other.weight_decay, + other.alpha_t, + other.beta_t, + other.beta2_t); +} +bool AdamOptimizerAttrs::operator>(AdamOptimizerAttrs const &other) const { + return std::tie(this->alpha, + this->beta1, + this->beta2, + this->weight_decay, + this->alpha_t, + this->beta_t, + this->beta2_t) > std::tie(other.alpha, + other.beta1, + other.beta2, + other.weight_decay, + other.alpha_t, + other.beta_t, + other.beta2_t); +} +bool AdamOptimizerAttrs::operator<=(AdamOptimizerAttrs const &other) const { + return std::tie(this->alpha, + this->beta1, + this->beta2, + this->weight_decay, + this->alpha_t, + this->beta_t, + this->beta2_t) <= std::tie(other.alpha, + other.beta1, + other.beta2, + other.weight_decay, + other.alpha_t, + other.beta_t, + other.beta2_t); +} +bool AdamOptimizerAttrs::operator>=(AdamOptimizerAttrs const &other) const { + return std::tie(this->alpha, + this->beta1, + this->beta2, + this->weight_decay, + this->alpha_t, + this->beta_t, + this->beta2_t) >= std::tie(other.alpha, + other.beta1, + other.beta2, + other.weight_decay, + other.alpha_t, + other.beta_t, + other.beta2_t); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::AdamOptimizerAttrs const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.alpha) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.beta1) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.beta2) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash{}(x.weight_decay) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.alpha_t) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.beta_t) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.beta2_t) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::AdamOptimizerAttrs + adl_serializer::from_json(json const &j) { + return {j.at("alpha").template get(), + j.at("beta1").template get(), + j.at("beta2").template get(), + j.at("weight_decay").template get(), + j.at("alpha_t").template get(), + j.at("beta_t").template get(), + j.at("beta2_t").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::AdamOptimizerAttrs const &v) { + j["__type"] = "AdamOptimizerAttrs"; + j["alpha"] = v.alpha; + j["beta1"] = v.beta1; + j["beta2"] = v.beta2; + j["weight_decay"] = v.weight_decay; + j["alpha_t"] = v.alpha_t; + j["beta_t"] = v.beta_t; + j["beta2_t"] = v.beta2_t; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(AdamOptimizerAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, AdamOptimizerAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/optimizers/sgd_optimizer_attrs.dtg.cc b/lib/pcg/src/pcg/optimizers/sgd_optimizer_attrs.dtg.cc new file mode 100644 index 0000000000..d5e668917b --- /dev/null +++ b/lib/pcg/src/pcg/optimizers/sgd_optimizer_attrs.dtg.cc @@ -0,0 +1,111 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/optimizers/sgd_optimizer_attrs.struct.toml +/* proj-data +{ + "generated_from": "d18c91cdddc760f1fb3990d2c817ee87" +} +*/ + +#include "pcg/optimizers/sgd_optimizer_attrs.dtg.h" + +#include + +namespace FlexFlow { +SGDOptimizerAttrs::SGDOptimizerAttrs(double const &lr, + double const &momentum, + bool const &nesterov, + double const &weight_decay) + : lr(lr), momentum(momentum), nesterov(nesterov), + weight_decay(weight_decay) {} +bool SGDOptimizerAttrs::operator==(SGDOptimizerAttrs const &other) const { + return std::tie( + this->lr, this->momentum, this->nesterov, this->weight_decay) == + std::tie(other.lr, other.momentum, other.nesterov, other.weight_decay); +} +bool SGDOptimizerAttrs::operator!=(SGDOptimizerAttrs const &other) const { + return std::tie( + this->lr, this->momentum, this->nesterov, this->weight_decay) != + std::tie(other.lr, other.momentum, other.nesterov, other.weight_decay); +} +bool SGDOptimizerAttrs::operator<(SGDOptimizerAttrs const &other) const { + return std::tie( + this->lr, this->momentum, this->nesterov, this->weight_decay) < + std::tie(other.lr, other.momentum, other.nesterov, other.weight_decay); +} +bool SGDOptimizerAttrs::operator>(SGDOptimizerAttrs const &other) const { + return std::tie( + this->lr, this->momentum, this->nesterov, this->weight_decay) > + std::tie(other.lr, other.momentum, other.nesterov, other.weight_decay); +} +bool SGDOptimizerAttrs::operator<=(SGDOptimizerAttrs const &other) const { + return std::tie( + this->lr, this->momentum, this->nesterov, this->weight_decay) <= + std::tie(other.lr, other.momentum, other.nesterov, other.weight_decay); +} +bool SGDOptimizerAttrs::operator>=(SGDOptimizerAttrs const &other) const { + return std::tie( + this->lr, this->momentum, this->nesterov, this->weight_decay) >= + std::tie(other.lr, other.momentum, other.nesterov, other.weight_decay); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::SGDOptimizerAttrs const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.lr) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash{}(x.momentum) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.nesterov) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.weight_decay) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::SGDOptimizerAttrs + adl_serializer::from_json(json const &j) { + return {j.at("lr").template get(), + j.at("momentum").template get(), + j.at("nesterov").template get(), + j.at("weight_decay").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::SGDOptimizerAttrs const &v) { + j["__type"] = "SGDOptimizerAttrs"; + j["lr"] = v.lr; + j["momentum"] = v.momentum; + j["nesterov"] = v.nesterov; + j["weight_decay"] = v.weight_decay; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(SGDOptimizerAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, SGDOptimizerAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/parallel_computation_graph.dtg.cc b/lib/pcg/src/pcg/parallel_computation_graph.dtg.cc new file mode 100644 index 0000000000..e4e1555b4a --- /dev/null +++ b/lib/pcg/src/pcg/parallel_computation_graph.dtg.cc @@ -0,0 +1,21 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/parallel_computation_graph.struct.toml +/* proj-data +{ + "generated_from": "e4db0f603f7b8947dda13e01f96c40fb" +} +*/ + +#include "pcg/parallel_computation_graph.dtg.h" + +#include "pcg/dataflow_graph.h" +#include "pcg/parallel_layer_attrs.dtg.h" +#include "pcg/parallel_tensor_attrs.dtg.h" + +namespace FlexFlow { +ParallelComputationGraph::ParallelComputationGraph( + ::FlexFlow::DataflowGraph<::FlexFlow::ParallelLayerAttrs, + ::FlexFlow::ParallelTensorAttrs> const &raw_graph) + : raw_graph(raw_graph) {} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/parallel_layer_attrs.dtg.cc b/lib/pcg/src/pcg/parallel_layer_attrs.dtg.cc new file mode 100644 index 0000000000..455fb22baf --- /dev/null +++ b/lib/pcg/src/pcg/parallel_layer_attrs.dtg.cc @@ -0,0 +1,83 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/parallel_layer_attrs.struct.toml +/* proj-data +{ + "generated_from": "97fa0b11c59ae892a8a530ffd67e33ad" +} +*/ + +#include "pcg/parallel_layer_attrs.dtg.h" + +#include "op-attrs/operator_attrs.h" +#include "utils/stack_string.h" +#include +#include + +namespace FlexFlow { +ParallelLayerAttrs::ParallelLayerAttrs( + ::FlexFlow::PCGOperatorAttrs const &attrs, + std::optional<::FlexFlow::stack_string> const &name) + : attrs(attrs), name(name) {} +bool ParallelLayerAttrs::operator==(ParallelLayerAttrs const &other) const { + return std::tie(this->attrs, this->name) == std::tie(other.attrs, other.name); +} +bool ParallelLayerAttrs::operator!=(ParallelLayerAttrs const &other) const { + return std::tie(this->attrs, this->name) != std::tie(other.attrs, other.name); +} +bool ParallelLayerAttrs::operator<(ParallelLayerAttrs const &other) const { + return std::tie(this->attrs, this->name) < std::tie(other.attrs, other.name); +} +bool ParallelLayerAttrs::operator>(ParallelLayerAttrs const &other) const { + return std::tie(this->attrs, this->name) > std::tie(other.attrs, other.name); +} +bool ParallelLayerAttrs::operator<=(ParallelLayerAttrs const &other) const { + return std::tie(this->attrs, this->name) <= std::tie(other.attrs, other.name); +} +bool ParallelLayerAttrs::operator>=(ParallelLayerAttrs const &other) const { + return std::tie(this->attrs, this->name) >= std::tie(other.attrs, other.name); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ParallelLayerAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::PCGOperatorAttrs>{}(x.attrs) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= + std::hash>>{}(x.name) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ParallelLayerAttrs + adl_serializer::from_json(json const &j) { + return { + j.at("attrs").template get<::FlexFlow::PCGOperatorAttrs>(), + j.at("name") + .template get>>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ParallelLayerAttrs const &v) { + j["__type"] = "ParallelLayerAttrs"; + j["attrs"] = v.attrs; + j["name"] = v.name; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ParallelLayerAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ParallelLayerAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/parallel_tensor_attrs.dtg.cc b/lib/pcg/src/pcg/parallel_tensor_attrs.dtg.cc new file mode 100644 index 0000000000..ae5d618172 --- /dev/null +++ b/lib/pcg/src/pcg/parallel_tensor_attrs.dtg.cc @@ -0,0 +1,134 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/parallel_tensor_attrs.struct.toml +/* proj-data +{ + "generated_from": "b3e086b380bbc41d99332e1463a34b28" +} +*/ + +#include "pcg/parallel_tensor_attrs.dtg.h" + +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/param_sync.dtg.h" +#include "pcg/create_grad.dtg.h" +#include "pcg/initializer_attrs.dtg.h" +#include +#include + +namespace FlexFlow { +ParallelTensorAttrs::ParallelTensorAttrs( + ::FlexFlow::ParallelTensorShape const &shape, + std::optional<::FlexFlow::ParamSync> const &sync_type, + std::optional<::FlexFlow::InitializerAttrs> const &initializer, + ::FlexFlow::CreateGrad const &create_gradients) + : shape(shape), sync_type(sync_type), initializer(initializer), + create_gradients(create_gradients) {} +bool ParallelTensorAttrs::operator==(ParallelTensorAttrs const &other) const { + return std::tie(this->shape, + this->sync_type, + this->initializer, + this->create_gradients) == std::tie(other.shape, + other.sync_type, + other.initializer, + other.create_gradients); +} +bool ParallelTensorAttrs::operator!=(ParallelTensorAttrs const &other) const { + return std::tie(this->shape, + this->sync_type, + this->initializer, + this->create_gradients) != std::tie(other.shape, + other.sync_type, + other.initializer, + other.create_gradients); +} +bool ParallelTensorAttrs::operator<(ParallelTensorAttrs const &other) const { + return std::tie(this->shape, + this->sync_type, + this->initializer, + this->create_gradients) < std::tie(other.shape, + other.sync_type, + other.initializer, + other.create_gradients); +} +bool ParallelTensorAttrs::operator>(ParallelTensorAttrs const &other) const { + return std::tie(this->shape, + this->sync_type, + this->initializer, + this->create_gradients) > std::tie(other.shape, + other.sync_type, + other.initializer, + other.create_gradients); +} +bool ParallelTensorAttrs::operator<=(ParallelTensorAttrs const &other) const { + return std::tie(this->shape, + this->sync_type, + this->initializer, + this->create_gradients) <= std::tie(other.shape, + other.sync_type, + other.initializer, + other.create_gradients); +} +bool ParallelTensorAttrs::operator>=(ParallelTensorAttrs const &other) const { + return std::tie(this->shape, + this->sync_type, + this->initializer, + this->create_gradients) >= std::tie(other.shape, + other.sync_type, + other.initializer, + other.create_gradients); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ParallelTensorAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ParallelTensorShape>{}(x.shape) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash>{}(x.sync_type) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash>{}(x.initializer) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::CreateGrad>{}(x.create_gradients) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ParallelTensorAttrs + adl_serializer::from_json(json const &j) { + return { + j.at("shape").template get<::FlexFlow::ParallelTensorShape>(), + j.at("sync_type").template get>(), + j.at("initializer") + .template get>(), + j.at("create_gradients").template get<::FlexFlow::CreateGrad>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ParallelTensorAttrs const &v) { + j["__type"] = "ParallelTensorAttrs"; + j["shape"] = v.shape; + j["sync_type"] = v.sync_type; + j["initializer"] = v.initializer; + j["create_gradients"] = v.create_gradients; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ParallelTensorAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ParallelTensorAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/side_size_t.dtg.cc b/lib/pcg/src/pcg/side_size_t.dtg.cc new file mode 100644 index 0000000000..54db2974fe --- /dev/null +++ b/lib/pcg/src/pcg/side_size_t.dtg.cc @@ -0,0 +1,75 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/side_size_t.struct.toml +/* proj-data +{ + "generated_from": "6a1669890e547dcc7a4ddb90be05be15" +} +*/ + +#include "pcg/side_size_t.dtg.h" + +#include + +namespace FlexFlow { +side_size_t::side_size_t(int const &unwrapped) : unwrapped(unwrapped) {} +bool side_size_t::operator==(side_size_t const &other) const { + return std::tie(this->unwrapped) == std::tie(other.unwrapped); +} +bool side_size_t::operator!=(side_size_t const &other) const { + return std::tie(this->unwrapped) != std::tie(other.unwrapped); +} +bool side_size_t::operator<(side_size_t const &other) const { + return std::tie(this->unwrapped) < std::tie(other.unwrapped); +} +bool side_size_t::operator>(side_size_t const &other) const { + return std::tie(this->unwrapped) > std::tie(other.unwrapped); +} +bool side_size_t::operator<=(side_size_t const &other) const { + return std::tie(this->unwrapped) <= std::tie(other.unwrapped); +} +bool side_size_t::operator>=(side_size_t const &other) const { + return std::tie(this->unwrapped) >= std::tie(other.unwrapped); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::side_size_t const &x) const { + size_t result = 0; + result ^= std::hash{}(x.unwrapped) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::side_size_t + adl_serializer::from_json(json const &j) { + return {j.at("unwrapped").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::side_size_t const &v) { + j["__type"] = "side_size_t"; + j["unwrapped"] = v.unwrapped; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(side_size_t const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, side_size_t const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/strided_rectangle.dtg.cc b/lib/pcg/src/pcg/strided_rectangle.dtg.cc new file mode 100644 index 0000000000..e743a2722a --- /dev/null +++ b/lib/pcg/src/pcg/strided_rectangle.dtg.cc @@ -0,0 +1,86 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/strided_rectangle.struct.toml +/* proj-data +{ + "generated_from": "817bbe017d179aa469822a4032d08836" +} +*/ + +#include "pcg/strided_rectangle.dtg.h" + +#include "op-attrs/dim_ordered.h" +#include "pcg/strided_rectangle_side.dtg.h" +#include + +namespace FlexFlow { +StridedRectangle::StridedRectangle( + ::FlexFlow::FFOrdered<::FlexFlow::StridedRectangleSide> const &sides) + : sides(sides) {} +bool StridedRectangle::operator==(StridedRectangle const &other) const { + return std::tie(this->sides) == std::tie(other.sides); +} +bool StridedRectangle::operator!=(StridedRectangle const &other) const { + return std::tie(this->sides) != std::tie(other.sides); +} +bool StridedRectangle::operator<(StridedRectangle const &other) const { + return std::tie(this->sides) < std::tie(other.sides); +} +bool StridedRectangle::operator>(StridedRectangle const &other) const { + return std::tie(this->sides) > std::tie(other.sides); +} +bool StridedRectangle::operator<=(StridedRectangle const &other) const { + return std::tie(this->sides) <= std::tie(other.sides); +} +bool StridedRectangle::operator>=(StridedRectangle const &other) const { + return std::tie(this->sides) >= std::tie(other.sides); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::StridedRectangle const &x) const { + size_t result = 0; + result ^= + std::hash<::FlexFlow::FFOrdered<::FlexFlow::StridedRectangleSide>>{}( + x.sides) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::StridedRectangle + adl_serializer::from_json(json const &j) { + return {j.at("sides") + .template get< + ::FlexFlow::FFOrdered<::FlexFlow::StridedRectangleSide>>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::StridedRectangle const &v) { + j["__type"] = "StridedRectangle"; + j["sides"] = v.sides; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary< + ::FlexFlow::FFOrdered<::FlexFlow::StridedRectangleSide>>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(StridedRectangle const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, StridedRectangle const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/strided_rectangle_side.cc b/lib/pcg/src/pcg/strided_rectangle_side.cc new file mode 100644 index 0000000000..80258886d7 --- /dev/null +++ b/lib/pcg/src/pcg/strided_rectangle_side.cc @@ -0,0 +1,15 @@ +#include "pcg/strided_rectangle_side.h" +#include "utils/exception.h" + +namespace FlexFlow { + +StridedRectangleSide strided_side_from_size_and_stride(side_size_t, + int stride) { + NOT_IMPLEMENTED(); +} + +side_size_t get_side_size(StridedRectangleSide const &s) { + return s.num_points.unwrapped * s.stride; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/strided_rectangle_side.dtg.cc b/lib/pcg/src/pcg/strided_rectangle_side.dtg.cc new file mode 100644 index 0000000000..0bb31b0496 --- /dev/null +++ b/lib/pcg/src/pcg/strided_rectangle_side.dtg.cc @@ -0,0 +1,91 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/strided_rectangle_side.struct.toml +/* proj-data +{ + "generated_from": "b14fcf1e28c262d22b92fac691ede3d4" +} +*/ + +#include "pcg/strided_rectangle_side.dtg.h" + +#include "pcg/num_points_t.dtg.h" +#include + +namespace FlexFlow { +StridedRectangleSide::StridedRectangleSide( + ::FlexFlow::num_points_t const &num_points, int const &stride) + : num_points(num_points), stride(stride) {} +bool StridedRectangleSide::operator==(StridedRectangleSide const &other) const { + return std::tie(this->num_points, this->stride) == + std::tie(other.num_points, other.stride); +} +bool StridedRectangleSide::operator!=(StridedRectangleSide const &other) const { + return std::tie(this->num_points, this->stride) != + std::tie(other.num_points, other.stride); +} +bool StridedRectangleSide::operator<(StridedRectangleSide const &other) const { + return std::tie(this->num_points, this->stride) < + std::tie(other.num_points, other.stride); +} +bool StridedRectangleSide::operator>(StridedRectangleSide const &other) const { + return std::tie(this->num_points, this->stride) > + std::tie(other.num_points, other.stride); +} +bool StridedRectangleSide::operator<=(StridedRectangleSide const &other) const { + return std::tie(this->num_points, this->stride) <= + std::tie(other.num_points, other.stride); +} +bool StridedRectangleSide::operator>=(StridedRectangleSide const &other) const { + return std::tie(this->num_points, this->stride) >= + std::tie(other.num_points, other.stride); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::StridedRectangleSide const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::num_points_t>{}(x.num_points) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.stride) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::StridedRectangleSide + adl_serializer::from_json(json const &j) { + return {j.at("num_points").template get<::FlexFlow::num_points_t>(), + j.at("stride").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::StridedRectangleSide const &v) { + j["__type"] = "StridedRectangleSide"; + j["num_points"] = v.num_points; + j["stride"] = v.stride; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::num_points_t>(), gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(StridedRectangleSide const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, StridedRectangleSide const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/tensor_attrs.dtg.cc b/lib/pcg/src/pcg/tensor_attrs.dtg.cc new file mode 100644 index 0000000000..46a6fb8d50 --- /dev/null +++ b/lib/pcg/src/pcg/tensor_attrs.dtg.cc @@ -0,0 +1,133 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/tensor_attrs.struct.toml +/* proj-data +{ + "generated_from": "68447a4357476647ef25dd39dfd12578" +} +*/ + +#include "pcg/tensor_attrs.dtg.h" + +#include "op-attrs/param_sync.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" +#include "pcg/initializer_attrs.dtg.h" +#include +#include + +namespace FlexFlow { +TensorAttrs::TensorAttrs( + ::FlexFlow::TensorShape const &shape, + std::optional<::FlexFlow::InitializerAttrs> const &initializer, + bool const &create_gradients, + std::optional<::FlexFlow::ParamSync> const &sync_type) + : shape(shape), initializer(initializer), + create_gradients(create_gradients), sync_type(sync_type) {} +bool TensorAttrs::operator==(TensorAttrs const &other) const { + return std::tie(this->shape, + this->initializer, + this->create_gradients, + this->sync_type) == std::tie(other.shape, + other.initializer, + other.create_gradients, + other.sync_type); +} +bool TensorAttrs::operator!=(TensorAttrs const &other) const { + return std::tie(this->shape, + this->initializer, + this->create_gradients, + this->sync_type) != std::tie(other.shape, + other.initializer, + other.create_gradients, + other.sync_type); +} +bool TensorAttrs::operator<(TensorAttrs const &other) const { + return std::tie(this->shape, + this->initializer, + this->create_gradients, + this->sync_type) < std::tie(other.shape, + other.initializer, + other.create_gradients, + other.sync_type); +} +bool TensorAttrs::operator>(TensorAttrs const &other) const { + return std::tie(this->shape, + this->initializer, + this->create_gradients, + this->sync_type) > std::tie(other.shape, + other.initializer, + other.create_gradients, + other.sync_type); +} +bool TensorAttrs::operator<=(TensorAttrs const &other) const { + return std::tie(this->shape, + this->initializer, + this->create_gradients, + this->sync_type) <= std::tie(other.shape, + other.initializer, + other.create_gradients, + other.sync_type); +} +bool TensorAttrs::operator>=(TensorAttrs const &other) const { + return std::tie(this->shape, + this->initializer, + this->create_gradients, + this->sync_type) >= std::tie(other.shape, + other.initializer, + other.create_gradients, + other.sync_type); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::TensorAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::TensorShape>{}(x.shape) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= + std::hash>{}(x.initializer) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash{}(x.create_gradients) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash>{}(x.sync_type) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::TensorAttrs + adl_serializer::from_json(json const &j) { + return { + j.at("shape").template get<::FlexFlow::TensorShape>(), + j.at("initializer") + .template get>(), + j.at("create_gradients").template get(), + j.at("sync_type").template get>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::TensorAttrs const &v) { + j["__type"] = "TensorAttrs"; + j["shape"] = v.shape; + j["initializer"] = v.initializer; + j["create_gradients"] = v.create_gradients; + j["sync_type"] = v.sync_type; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(TensorAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, TensorAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/tensor_guid_t.dtg.cc b/lib/pcg/src/pcg/tensor_guid_t.dtg.cc new file mode 100644 index 0000000000..9d57291112 --- /dev/null +++ b/lib/pcg/src/pcg/tensor_guid_t.dtg.cc @@ -0,0 +1,59 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/tensor_guid_t.struct.toml +/* proj-data +{ + "generated_from": "dc15fcbb876ec70509dfa8b662963bc3" +} +*/ + +#include "pcg/tensor_guid_t.dtg.h" + +#include "utils/graph.h" +#include + +namespace FlexFlow { +tensor_guid_t::tensor_guid_t(::FlexFlow::MultiDiOutput const &raw_graph_output) + : raw_graph_output(raw_graph_output) {} +bool tensor_guid_t::operator==(tensor_guid_t const &other) const { + return std::tie(this->raw_graph_output) == std::tie(other.raw_graph_output); +} +bool tensor_guid_t::operator!=(tensor_guid_t const &other) const { + return std::tie(this->raw_graph_output) != std::tie(other.raw_graph_output); +} +bool tensor_guid_t::operator<(tensor_guid_t const &other) const { + return std::tie(this->raw_graph_output) < std::tie(other.raw_graph_output); +} +bool tensor_guid_t::operator>(tensor_guid_t const &other) const { + return std::tie(this->raw_graph_output) > std::tie(other.raw_graph_output); +} +bool tensor_guid_t::operator<=(tensor_guid_t const &other) const { + return std::tie(this->raw_graph_output) <= std::tie(other.raw_graph_output); +} +bool tensor_guid_t::operator>=(tensor_guid_t const &other) const { + return std::tie(this->raw_graph_output) >= std::tie(other.raw_graph_output); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::tensor_guid_t const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::MultiDiOutput>{}(x.raw_graph_output) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(tensor_guid_t const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, tensor_guid_t const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/serialization.cc b/lib/pcg/src/serialization.cc deleted file mode 100644 index 439c03916d..0000000000 --- a/lib/pcg/src/serialization.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "pcg/serialization.h" - -namespace FlexFlow {} diff --git a/lib/pcg/src/strided_rectangle.cc b/lib/pcg/src/strided_rectangle.cc index 27ef9a7f5b..9c8ff69b42 100644 --- a/lib/pcg/src/strided_rectangle.cc +++ b/lib/pcg/src/strided_rectangle.cc @@ -3,34 +3,23 @@ namespace FlexFlow { -size_t StridedRectangle::at(FFOrdered const &coord) const { - assert(coord.size() == this->num_dims()); - - size_t _1d_stride = 1; - size_t idx = 0; - for (auto dim : inner_to_outer_idxs(this->sides)) { - idx += this->sides.at(dim).at(coord.at(dim)).value() * _1d_stride; - _1d_stride *= this->sides.at(dim).get_size().value(); - } - return idx; -} - -StridedRectangleSide::StridedRectangleSide(side_size_t const &num, int stride) - : num_points(num.value()), stride(stride) {} - -side_size_t StridedRectangleSide::at(num_points_t) const { - NOT_IMPLEMENTED(); -} - -num_points_t StridedRectangleSide::at(side_size_t) const { - NOT_IMPLEMENTED(); -} - -side_size_t StridedRectangleSide::get_size() const { +/* size_t StridedRectangle::at(FFOrdered const &coord) const { */ +/* assert(coord.size() == this->num_dims()); */ + +/* size_t _1d_stride = 1; */ +/* size_t idx = 0; */ +/* for (auto dim : inner_to_outer_idxs(this->sides)) { */ +/* idx += this->sides.at(dim).at(coord.at(dim)).value() * _1d_stride; */ +/* _1d_stride *= this->sides.at(dim).get_size().value(); */ +/* } */ +/* return idx; */ +/* } */ + +size_t get_num_dims(StridedRectangle const &) { NOT_IMPLEMENTED(); } -size_t StridedRectangle::num_dims() const { +size_t get_side_at_idx(StridedRectangle const &) { NOT_IMPLEMENTED(); } diff --git a/lib/pcg/src/tensor.cc b/lib/pcg/src/tensor.cc deleted file mode 100644 index df29ee0065..0000000000 --- a/lib/pcg/src/tensor.cc +++ /dev/null @@ -1,17 +0,0 @@ -#include "pcg/tensor.h" - -namespace FlexFlow { - -Tensor::operator TensorShape() const { - return TensorShape{dims, data_type}; -} - -TensorShape Tensor::get_shape() const { - return TensorShape(*this); -} - -Tensor construct_tensor_from_output_shape(TensorShape const &shape) { - return Tensor{shape.dims, shape.data_type, std::nullopt, false, std::nullopt}; -} - -} // namespace FlexFlow diff --git a/lib/pcg/test/CMakeLists.txt b/lib/pcg/test/CMakeLists.txt new file mode 100644 index 0000000000..685d1d8b88 --- /dev/null +++ b/lib/pcg/test/CMakeLists.txt @@ -0,0 +1,13 @@ +ff_add_test_executable( + NAME + pcg-tests + SRC_PATTERNS + src/*.cc + PRIVATE_INCLUDE + src/ + DEPS + utils + pcg + doctest + utils-test-common +) diff --git a/lib/pcg/test/src/test_computation_graph_builder.cc b/lib/pcg/test/src/test_computation_graph_builder.cc new file mode 100644 index 0000000000..e88e231bd0 --- /dev/null +++ b/lib/pcg/test/src/test_computation_graph_builder.cc @@ -0,0 +1,28 @@ +#include "doctest/doctest.h" +#include "pcg/computation_graph.h" +#include "pcg/computation_graph_builder.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("ComputationGraphBuilder") { + ComputationGraphBuilder b; + + size_t batch_size = 2; + + TensorShape input_shape = { + TensorDims{FFOrdered{batch_size, 3, 10, 10}}, + DataType::FLOAT, + }; + + tensor_guid_t input = b.create_tensor(input_shape, /*create_grad=*/true); + tensor_guid_t output = b.conv2d(input, + /*outChannels=*/5, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/1, + /*strideW=*/1, + /*paddingH=*/0, + /*paddingW=*/0); + // ComputationGraph cg = b.computation_graph; + // CHECK(get_layers(cg).size() == 1); + } +} diff --git a/lib/runtime/src/parallel_op_info.h b/lib/runtime/src/parallel_op_info.h index 11e2f03477..ebd44f012b 100644 --- a/lib/runtime/src/parallel_op_info.h +++ b/lib/runtime/src/parallel_op_info.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_PARALLEL_OPS_PARALLEL_OP_INFO_H #include "op-attrs/ff_dim.h" -#include "op-attrs/op.h" +#include "op-attrs/operator_type.h" #include "utils/visitable.h" #include #include diff --git a/lib/substitution-generator/include/substitution-generator/json.h b/lib/substitution-generator/include/substitution-generator/json.h index dbde110f8d..5563d8a835 100644 --- a/lib/substitution-generator/include/substitution-generator/json.h +++ b/lib/substitution-generator/include/substitution-generator/json.h @@ -1,166 +1,16 @@ #ifndef _FLEXFLOW_SUBSTITUTION_LOADER_H #define _FLEXFLOW_SUBSTITUTION_LOADER_H -#include "op-attrs/op.h" +#include "substitution-generator/legacy_operator_type.dtg.h" +#include "substitution-generator/legacy_pm_parameter.dtg.h" #include #include +#include namespace FlexFlow { -enum PMParameter { - PM_OP_TYPE, // AnyOp - PM_NUM_INPUTS, // AnyOp - PM_NUM_OUTPUTS, // AnyOp - PM_GROUP, // Conv2D - PM_KERNEL_H, // Conv2D, Pool2D - PM_KERNEL_W, // Conv2D, Pool2D - PM_STRIDE_H, // Conv2D, Pool2D - PM_STRIDE_W, // Conv2D, Pool2D - PM_PADDING_H, // Conv2D, Pool2D - PM_PADDING_W, // Conv2D, Pool2D - PM_ACTI, // Conv2D, Pool2D - PM_NUMDIM, // Concat, Transpose - PM_AXIS, // Concat, Split - PM_PERM, // Transpose - PM_OUTSHUFFLE, // Transpose - PM_MERGE_GCONV_COUNT, // MergeGConv - PM_AXES, // Squeeze, Unsqueeze, Reduce* - PM_KEEP_DIMS, // Reduce* - PM_EPSILON, // BatchNorm - PM_REPARTITION_DIM, // Repartition - PM_REPARTITION_DEGREE, // Repartition - PM_REPLICATE_DIM, // Replicate - PM_REPLICATE_DEGREE, // Replicate - PM_COMBINE_DIM, // Combine - PM_COMBINE_DEGREE, // Combine - PM_REDUCTION_DIM, // Reduction - PM_REDUCTION_DEGREE, // Reduction - PM_SOFTMAX_DIM, // Softmax - PM_NUM_HEADS, // MultiHeadAttention - PM_INVALID, - PM_PARALLEL_DIM, - PM_PARALLEL_DEGREE, - PM_PAD, -}; - -NLOHMANN_JSON_SERIALIZE_ENUM(PMParameter, - {{PM_INVALID, nullptr}, - {PM_OP_TYPE, "PM_OP_TYPE"}, - {PM_NUM_INPUTS, "PM_NUM_INPUTS"}, - {PM_NUM_OUTPUTS, "PM_NUM_OUTPUTS"}, - {PM_GROUP, "PM_GROUP"}, - {PM_KERNEL_H, "PM_KERNEL_H"}, - {PM_KERNEL_W, "PM_KERNEL_W"}, - {PM_STRIDE_H, "PM_STRIDE_H"}, - {PM_STRIDE_W, "PM_STRIDE_W"}, - {PM_PADDING_H, "PM_PADDING_H"}, - {PM_PADDING_W, "PM_PADDING_W"}, - {PM_ACTI, "PM_ACTI"}, - {PM_NUMDIM, "PM_NUMDIM"}, - {PM_AXIS, "PM_AXIS"}, - {PM_PERM, "PM_PERM"}, - {PM_OUTSHUFFLE, "PM_OUTSHUFFLE"}, - {PM_MERGE_GCONV_COUNT, "PM_MERGE_GCONV_COUNT"}, - {PM_AXES, "PM_AXES"}, - {PM_KEEP_DIMS, "PM_KEEP_DIMS"}, - {PM_EPSILON, "PM_EPSILON"}, - {PM_REPARTITION_DIM, "PM_REPARTITION_DIM"}, - {PM_REPARTITION_DEGREE, "PM_REPARTITION_DEGREE"}, - {PM_REPLICATE_DIM, "PM_REPLICATE_DIM"}, - {PM_REPLICATE_DEGREE, "PM_REPLICATE_DEGREE"}, - {PM_COMBINE_DIM, "PM_COMBINE_DIM"}, - {PM_COMBINE_DEGREE, "PM_COMBINE_DEGREE"}, - {PM_REDUCTION_DIM, "PM_REDUCTION_DIM"}, - {PM_REDUCTION_DEGREE, "PM_REDUCTION_DEGREE"}, - {PM_SOFTMAX_DIM, "PM_SOFTMAX_DIM"}, - {PM_NUM_HEADS, "PM_NUM_HEADS"}, - {PM_PARALLEL_DIM, "PM_PARALLEL_DIM"}, - {PM_PARALLEL_DEGREE, "PM_PARALLEL_DEGREE"}, - {PM_PAD, "PM_PAD"}}) - -NLOHMANN_JSON_SERIALIZE_ENUM(Op, - {{Op::NOOP, "OP_NOOP"}, - {Op::CONV2D, "OP_CONV2D"}, - {Op::DROPOUT, "OP_DROPOUT"}, - {Op::LINEAR, "OP_LINEAR"}, - {Op::BATCHMATMUL, "OP_BATCHMATMUL"}, - {Op::POOL2D, "OP_POOL2D_MAX"}, - {Op::SCALAR_MULTIPLY, "OP_SCALAR_MULTIPLY"}, - {Op::SCALAR_ADD, "OP_SCALAR_ADD"}, - {Op::SCALAR_FLOOR_DIV, "OP_SCALAR_FLOOR_DIV"}, - {Op::SCALAR_TRUE_DIV, "OP_SCALAR_TRUE_DIV"}, - {Op::SCALAR_SUB, "OP_SCALAR_SUB"}, - {Op::RELU, "OP_RELU"}, - {Op::IDENTITY, "OP_IDENTITY"}, - {Op::SIGMOID, "OP_SIGMOID"}, - {Op::TANH, "OP_TANH"}, - {Op::ELU, "OP_ELU"}, - {Op::FLAT, "OP_FLAT"}, - {Op::SOFTMAX, "OP_SOFTMAX"}, - {Op::BATCHNORM, "OP_BATCHNORM"}, - {Op::CONCAT, "OP_CONCAT"}, - {Op::SPLIT, "OP_SPLIT"}, - {Op::EMBEDDING, "OP_EMBEDDING"}, - {Op::CACHE, "OP_CACHE"}, - {Op::RESHAPE, "OP_RESHAPE"}, - {Op::REVERSE, "OP_REVERSE"}, - {Op::TRANSPOSE, "OP_TRANSPOSE"}, - {Op::EW_ADD, "OP_EW_ADD"}, - {Op::EW_MUL, "OP_EW_MUL"}, - {Op::MATMUL, "OP_MATMUL"}, - {Op::MUL, "OP_MUL"}, - {Op::ENLARGE, "OP_ENLARGE"}, - {Op::SQUEEZE, "OP_SQUEEZE"}, - {Op::UNSQUEEZE, "OP_UNSQUEEZE"}, - {Op::EW_SUB, "OP_EW_SUB"}, - {Op::EW_DIV, "OP_EW_DIV"}, - {Op::EW_EQUAL, "OP_EW_EQUAL"}, - {Op::EW_GREATER, "OP_EW_GREATER"}, - {Op::EW_LESS, "OP_EW_LESS"}, - {Op::EW_MAX, "OP_EW_MAX"}, - {Op::EW_MIN, "OP_EW_MIN"}, - {Op::REDUCE_ARGMAX, "OP_REDUCE_ARGMAX"}, - {Op::REDUCE_ARGMIN, "OP_REDUCE_ARGMIN"}, - {Op::REDUCE_MAX, "OP_REDUCE_MAX"}, - {Op::REDUCE_MEAN, "OP_REDUCE_MEAN"}, - {Op::REDUCE_MIN, "OP_REDUCE_MIN"}, - {Op::REDUCE_PROD, "OP_REDUCE_PROD"}, - {Op::REDUCE_SUM, "OP_REDUCE_SUM"}, - {Op::PAD, "OP_PAD"}, - {Op::SHAPE, "OP_SHAPE"}, - {Op::SIZE, "OP_SIZE"}, - {Op::TOPK, "OP_TOPK"}, - {Op::WHERE, "OP_WHERE"}, - {Op::CEIL, "OP_CEIL"}, - {Op::CAST, "OP_CAST"}, - {Op::EXP, "OP_EXP"}, - {Op::ROUND, "OP_ROUND"}, - {Op::LOG, "OP_LOG"}, - {Op::LOGICAL_NOT, "OP_LOGICAL_NOT"}, - {Op::SQRT, "OP_SQRT"}, - {Op::SIN, "OP_SIN"}, - {Op::COS, "OP_COS"}, - {Op::LEAKYRELU, "OP_LEAKYRELU"}, - {Op::SLICE, "OP_SLICE"}, - {Op::RESIZE, "OP_RESIZE"}, - {Op::PRELU, "OP_PRELU"}, - {Op::GELU, "OP_GELU"}, - {Op::MULTIHEAD_ATTENTION, - "OP_MULTIHEAD_ATTENTION"}, - {Op::FUSED, "OP_FUSED"}, - {Op::RSQRT, "OP_RSQRT"}, - {Op::POW, "OP_POW"}, - {Op::MEAN, "OP_MEAN"}, - {Op::LAYERNORM, "OP_LAYERNORM"}, - {Op::REPARTITION, "OP_PARTITION"}, - {Op::COMBINE, "OP_COMBINE"}, - {Op::REPLICATE, "OP_REPLICATE"}, - {Op::REDUCTION, "OP_REDUCE"}, - {Op::PIPELINE, "OP_PIPELINE"}, - {Op::FUSED_PARALLEL, "OP_FUSED_PARALLEL"}}) - struct Parameter { - PMParameter key; + LegacyPMParameter key; int value; }; void from_json(nlohmann::json const &j, Parameter &p); @@ -172,11 +22,11 @@ struct Tensor { void from_json(nlohmann::json const &j, Tensor &t); struct Operator { - OperatorType op_type; + LegacyOperatorType op_type; std::vector input; std::vector para; - std::optional at(PMParameter key) const; + std::optional at(LegacyPMParameter key) const; }; void from_json(nlohmann::json const &j, Operator &t); diff --git a/lib/substitution-generator/include/substitution-generator/legacy_operator_type.dtg.h b/lib/substitution-generator/include/substitution-generator/legacy_operator_type.dtg.h new file mode 100644 index 0000000000..a74f6ef4f3 --- /dev/null +++ b/lib/substitution-generator/include/substitution-generator/legacy_operator_type.dtg.h @@ -0,0 +1,124 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitution-generator/include/substitution-generator/legacy_operator_type.enum.toml +/* proj-data +{ + "generated_from": "d6ba52e2b0d58b7cb533dae3894b0486" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTION_GENERATOR_INCLUDE_SUBSTITUTION_GENERATOR_LEGACY_OPERATOR_TYPE_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTION_GENERATOR_INCLUDE_SUBSTITUTION_GENERATOR_LEGACY_OPERATOR_TYPE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +enum class LegacyOperatorType { + NOOP, + INPUT, + WEIGHT, + CONV2D, + DROPOUT, + LINEAR, + BATCHMATMUL, + POOL2D, + SCALAR_MULTIPLY, + SCALAR_ADD, + SCALAR_FLOOR_DIV, + SCALAR_TRUE_DIV, + SCALAR_SUB, + RELU, + IDENTITY, + SIGMOID, + TANH, + ELU, + FLAT, + SOFTMAX, + BATCHNORM, + CONCAT, + SPLIT, + EMBEDDING, + CACHE, + RESHAPE, + REVERSE, + TRANSPOSE, + EW_ADD, + EW_MUL, + MATMUL, + MUL, + ENLARGE, + SQUEEZE, + UNSQUEEZE, + EW_SUB, + EW_DIV, + EW_EQUAL, + EW_GREATER, + EW_LESS, + EW_MAX, + EW_MIN, + REDUCE_ARGMAX, + REDUCE_ARGMIN, + REDUCE_MAX, + REDUCE_MEAN, + REDUCE_MIN, + REDUCE_PROD, + REDUCE_SUM, + PAD, + SHAPE, + SIZE, + TOPK, + WHERE, + CEIL, + CAST, + EXP, + ROUND, + LOG, + LOGICAL_NOT, + SQRT, + SIN, + COS, + LEAKYRELU, + SLICE, + RESIZE, + PRELU, + GELU, + MULTIHEAD_ATTENTION, + FUSED, + RSQRT, + POW, + MEAN, + LAYERNORM, + GATHER, + BROADCAST, + REPARTITION, + COMBINE, + REPLICATE, + REDUCTION, + BATCH, + PIPELINE, + FUSED_PARALLEL +}; +std::string format_as(LegacyOperatorType); +std::ostream &operator<<(std::ostream &, LegacyOperatorType); +void to_json(::nlohmann::json &, LegacyOperatorType); +void from_json(::nlohmann::json const &, LegacyOperatorType &); +} // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::LegacyOperatorType) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_SUBSTITUTION_GENERATOR_INCLUDE_SUBSTITUTION_GENERATOR_LEGACY_OPERATOR_TYPE_DTG_H diff --git a/lib/substitution-generator/include/substitution-generator/legacy_operator_type.enum.toml b/lib/substitution-generator/include/substitution-generator/legacy_operator_type.enum.toml new file mode 100644 index 0000000000..3f0bcccf6f --- /dev/null +++ b/lib/substitution-generator/include/substitution-generator/legacy_operator_type.enum.toml @@ -0,0 +1,95 @@ +namespace = "FlexFlow" +name = "LegacyOperatorType" + +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +values = [ + { name = "NOOP", json_key = "OP_NOOP" }, + { name = "INPUT", json_key = "OP_INPUT" }, + { name = "WEIGHT", json_key = "OP_WEIGHT" }, + { name = "CONV2D", json_key = "OP_CONV2D" }, + { name = "DROPOUT", json_key = "OP_DROPOUT" }, + { name = "LINEAR", json_key = "OP_LINEAR" }, + { name = "BATCHMATMUL", json_key = "OP_BATCHMATMUL" }, + { name = "POOL2D", json_key = "OP_POOL2D" }, + { name = "SCALAR_MULTIPLY", json_key = "OP_SCALAR_MULTIPLY" }, + { name = "SCALAR_ADD", json_key = "OP_SCALAR_ADD" }, + { name = "SCALAR_FLOOR_DIV", json_key = "OP_SCALAR_FLOOR_DIV" }, + { name = "SCALAR_TRUE_DIV", json_key = "OP_SCALAR_TRUE_DIV" }, + { name = "SCALAR_SUB", json_key = "OP_SCALAR_SUB" }, + { name = "RELU", json_key = "OP_RELU" }, + { name = "IDENTITY", json_key = "OP_IDENTITY" }, + { name = "SIGMOID", json_key = "OP_SIGMOID" }, + { name = "TANH", json_key = "OP_TANH" }, + { name = "ELU", json_key = "OP_ELU" }, + { name = "FLAT", json_key = "OP_FLAT" }, + { name = "SOFTMAX", json_key = "OP_SOFTMAX" }, + { name = "BATCHNORM", json_key = "OP_BATCHNORM" }, + { name = "CONCAT", json_key = "OP_CONCAT" }, + { name = "SPLIT", json_key = "OP_SPLIT" }, + { name = "EMBEDDING", json_key = "OP_EMBEDDING" }, + { name = "CACHE", json_key = "OP_CACHE" }, + { name = "RESHAPE", json_key = "OP_RESHAPE" }, + { name = "REVERSE", json_key = "OP_REVERSE" }, + { name = "TRANSPOSE", json_key = "OP_TRANSPOSE" }, + { name = "EW_ADD", json_key = "OP_EW_ADD" }, + { name = "EW_MUL", json_key = "OP_EW_MUL" }, + { name = "MATMUL", json_key = "OP_MATMUL" }, + { name = "MUL", json_key = "OP_MUL" }, + { name = "ENLARGE", json_key = "OP_ENLARGE" }, + { name = "SQUEEZE", json_key = "OP_SQUEEZE" }, + { name = "UNSQUEEZE", json_key = "OP_UNSQUEEZE" }, + { name = "EW_SUB", json_key = "OP_EW_SUB" }, + { name = "EW_DIV", json_key = "OP_EW_DIV" }, + { name = "EW_EQUAL", json_key = "OP_EW_EQUAL" }, + { name = "EW_GREATER", json_key = "OP_EW_GREATER" }, + { name = "EW_LESS", json_key = "OP_EW_LESS" }, + { name = "EW_MAX", json_key = "OP_EW_MAX" }, + { name = "EW_MIN", json_key = "OP_EW_MIN" }, + { name = "REDUCE_ARGMAX", json_key = "OP_REDUCE_ARGMAX" }, + { name = "REDUCE_ARGMIN", json_key = "OP_REDUCE_ARGMIN" }, + { name = "REDUCE_MAX", json_key = "OP_REDUCE_MAX" }, + { name = "REDUCE_MEAN", json_key = "OP_REDUCE_MEAN" }, + { name = "REDUCE_MIN", json_key = "OP_REDUCE_MIN" }, + { name = "REDUCE_PROD", json_key = "OP_REDUCE_PROD" }, + { name = "REDUCE_SUM", json_key = "OP_REDUCE_SUM" }, + { name = "PAD", json_key = "OP_PAD" }, + { name = "SHAPE", json_key = "OP_SHAPE" }, + { name = "SIZE", json_key = "OP_SIZE" }, + { name = "TOPK", json_key = "OP_TOPK" }, + { name = "WHERE", json_key = "OP_WHERE" }, + { name = "CEIL", json_key = "OP_CEIL" }, + { name = "CAST", json_key = "OP_CAST" }, + { name = "EXP", json_key = "OP_EXP" }, + { name = "ROUND", json_key = "OP_ROUND" }, + { name = "LOG", json_key = "OP_LOG" }, + { name = "LOGICAL_NOT", json_key = "OP_LOGICAL_NOT" }, + { name = "SQRT", json_key = "OP_SQRT" }, + { name = "SIN", json_key = "OP_SIN" }, + { name = "COS", json_key = "OP_COS" }, + { name = "LEAKYRELU", json_key = "OP_LEAKYRELU" }, + { name = "SLICE", json_key = "OP_SLICE" }, + { name = "RESIZE", json_key = "OP_RESIZE" }, + { name = "PRELU", json_key = "OP_PRELU" }, + { name = "GELU", json_key = "OP_GELU" }, + { name = "MULTIHEAD_ATTENTION", json_key = "OP_MULTIHEAD_ATTENTION" }, + { name = "FUSED", json_key = "OP_FUSED" }, + { name = "RSQRT", json_key = "OP_RSQRT" }, + { name = "POW", json_key = "OP_POW" }, + { name = "MEAN", json_key = "OP_MEAN" }, + { name = "LAYERNORM", json_key = "OP_LAYERNORM" }, + { name = "GATHER", json_key = "OP_GATHER" }, + { name = "BROADCAST", json_key = "OP_BROADCAST" }, + { name = "REPARTITION", json_key = "OP_PARTITION" }, + { name = "COMBINE", json_key = "OP_COMBINE" }, + { name = "REPLICATE", json_key = "OP_REPLICATE" }, + { name = "REDUCTION", json_key = "OP_REDUCE" }, + { name = "BATCH", json_key = "OP_BATCH" }, + { name = "PIPELINE", json_key = "OP_PIPELINE" }, + { name = "FUSED_PARALLEL", json_key = "OP_FUSED_PARALLEL" }, +] diff --git a/lib/substitution-generator/include/substitution-generator/legacy_pm_parameter.dtg.h b/lib/substitution-generator/include/substitution-generator/legacy_pm_parameter.dtg.h new file mode 100644 index 0000000000..2435024b9a --- /dev/null +++ b/lib/substitution-generator/include/substitution-generator/legacy_pm_parameter.dtg.h @@ -0,0 +1,73 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitution-generator/include/substitution-generator/legacy_pm_parameter.enum.toml +/* proj-data +{ + "generated_from": "e8dda0047c91e576878b86df2fec0b6b" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTION_GENERATOR_INCLUDE_SUBSTITUTION_GENERATOR_LEGACY_PM_PARAMETER_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTION_GENERATOR_INCLUDE_SUBSTITUTION_GENERATOR_LEGACY_PM_PARAMETER_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +enum class LegacyPMParameter { + OP_TYPE, + NUM_INPUTS, + NUM_OUTPUTS, + GROUP, + KERNEL_H, + KERNEL_W, + STRIDE_H, + STRIDE_W, + PADDING_H, + PADDING_W, + ACTI, + NUMDIM, + AXIS, + PERM, + OUTSHUFFLE, + MERGE_GCONV_COUNT, + AXES, + KEEP_DIMS, + EPSILON, + REPARTITION_DIM, + REPARTITION_DEGREE, + REPLICATE_DIM, + REPLICATE_DEGREE, + COMBINE_DIM, + COMBINE_DEGREE, + REDUCTION_DIM, + REDUCTION_DEGREE, + SOFTMAX_DIM, + NUM_HEADS, + PARALLEL_DIM, + PARALLEL_DEGREE, + PAD +}; +std::string format_as(LegacyPMParameter); +std::ostream &operator<<(std::ostream &, LegacyPMParameter); +void to_json(::nlohmann::json &, LegacyPMParameter); +void from_json(::nlohmann::json const &, LegacyPMParameter &); +} // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::LegacyPMParameter) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_SUBSTITUTION_GENERATOR_INCLUDE_SUBSTITUTION_GENERATOR_LEGACY_PM_PARAMETER_DTG_H diff --git a/lib/substitution-generator/include/substitution-generator/legacy_pm_parameter.enum.toml b/lib/substitution-generator/include/substitution-generator/legacy_pm_parameter.enum.toml new file mode 100644 index 0000000000..e71a71a5a8 --- /dev/null +++ b/lib/substitution-generator/include/substitution-generator/legacy_pm_parameter.enum.toml @@ -0,0 +1,44 @@ +namespace = "FlexFlow" +name = "LegacyPMParameter" + +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +values = [ + { name = "OP_TYPE", json_key = "PM_OP_TYPE" }, + { name = "NUM_INPUTS", json_key = "PM_NUM_INPUTS" }, + { name = "NUM_OUTPUTS", json_key = "PM_NUM_OUTPUTS" }, + { name = "GROUP", json_key = "PM_GROUP" }, + { name = "KERNEL_H", json_key = "PM_KERNEL_H" }, + { name = "KERNEL_W", json_key = "PM_KERNEL_W" }, + { name = "STRIDE_H", json_key = "PM_STRIDE_H" }, + { name = "STRIDE_W", json_key = "PM_STRIDE_W" }, + { name = "PADDING_H", json_key = "PM_PADDING_H" }, + { name = "PADDING_W", json_key = "PM_PADDING_W" }, + { name = "ACTI", json_key = "PM_ACTI" }, + { name = "NUMDIM", json_key = "PM_NUMDIM" }, + { name = "AXIS", json_key = "PM_AXIS" }, + { name = "PERM", json_key = "PM_PERM" }, + { name = "OUTSHUFFLE", json_key = "PM_OUTSHUFFLE" }, + { name = "MERGE_GCONV_COUNT", json_key = "PM_MERGE_GCONV_COUNT" }, + { name = "AXES", json_key = "PM_AXES" }, + { name = "KEEP_DIMS", json_key = "PM_KEEP_DIMS" }, + { name = "EPSILON", json_key = "PM_EPSILON" }, + { name = "REPARTITION_DIM", json_key = "PM_REPARTITION_DIM" }, + { name = "REPARTITION_DEGREE", json_key = "PM_REPARTITION_DEGREE" }, + { name = "REPLICATE_DIM", json_key = "PM_REPLICATE_DIM" }, + { name = "REPLICATE_DEGREE", json_key = "PM_REPLICATE_DEGREE" }, + { name = "COMBINE_DIM", json_key = "PM_COMBINE_DIM" }, + { name = "COMBINE_DEGREE", json_key = "PM_COMBINE_DEGREE" }, + { name = "REDUCTION_DIM", json_key = "PM_REDUCTION_DIM" }, + { name = "REDUCTION_DEGREE", json_key = "PM_REDUCTION_DEGREE" }, + { name = "SOFTMAX_DIM", json_key = "PM_SOFTMAX_DIM" }, + { name = "NUM_HEADS", json_key = "PM_NUM_HEADS" }, + { name = "PARALLEL_DIM", json_key = "PM_PARALLEL_DIM" }, + { name = "PARALLEL_DEGREE", json_key = "PM_PARALLEL_DEGREE" }, + { name = "PAD", json_key = "PM_PAD" }, +] diff --git a/lib/substitution-generator/src/substitution-generator/json.cc b/lib/substitution-generator/src/substitution-generator/json.cc index 7e6a93b863..940ecb3e36 100644 --- a/lib/substitution-generator/src/substitution-generator/json.cc +++ b/lib/substitution-generator/src/substitution-generator/json.cc @@ -10,11 +10,6 @@ namespace FlexFlow { void from_json(json const &j, Parameter &p) { j.at("key").get_to(p.key); j.at("value").get_to(p.value); - if (p.key == PM_INVALID) { - std::ostringstream oss; - oss << "Attempted to load invalid PMParameter: " << j.at("key"); - throw std::runtime_error(oss.str()); - } } void from_json(json const &j, Tensor &t) { @@ -22,17 +17,17 @@ void from_json(json const &j, Tensor &t) { j.at("tsId").get_to(t.tsId); } -std::optional Operator::at(PMParameter key) const { - std::optional value = std::nullopt; - for (Parameter const &p : this->para) { - if (p.key == key) { - assert(!value.has_value()); - value = p.key; - } - } +/* std::optional Operator::at(LegacyPMParameter key) const { */ +/* std::optional value = std::nullopt; */ +/* for (Parameter const &p : this->para) { */ +/* if (p.key == key) { */ +/* assert(!value.has_value()); */ +/* value = p.key; */ +/* } */ +/* } */ - return value; -} +/* return value; */ +/* } */ void from_json(json const &j, Operator &o) { j.at("type").get_to(o.op_type); diff --git a/lib/substitution-generator/src/substitution-generator/legacy_operator_type.dtg.cc b/lib/substitution-generator/src/substitution-generator/legacy_operator_type.dtg.cc new file mode 100644 index 0000000000..94c65e33fd --- /dev/null +++ b/lib/substitution-generator/src/substitution-generator/legacy_operator_type.dtg.cc @@ -0,0 +1,721 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitution-generator/include/substitution-generator/legacy_operator_type.enum.toml +/* proj-data +{ + "generated_from": "d6ba52e2b0d58b7cb533dae3894b0486" +} +*/ + +#include "substitution-generator/legacy_operator_type.dtg.h" + +#include +#include + +namespace std { +size_t hash::operator()( + FlexFlow::LegacyOperatorType x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(LegacyOperatorType x) { + switch (x) { + case LegacyOperatorType::NOOP: + return "NOOP"; + case LegacyOperatorType::INPUT: + return "INPUT"; + case LegacyOperatorType::WEIGHT: + return "WEIGHT"; + case LegacyOperatorType::CONV2D: + return "CONV2D"; + case LegacyOperatorType::DROPOUT: + return "DROPOUT"; + case LegacyOperatorType::LINEAR: + return "LINEAR"; + case LegacyOperatorType::BATCHMATMUL: + return "BATCHMATMUL"; + case LegacyOperatorType::POOL2D: + return "POOL2D"; + case LegacyOperatorType::SCALAR_MULTIPLY: + return "SCALAR_MULTIPLY"; + case LegacyOperatorType::SCALAR_ADD: + return "SCALAR_ADD"; + case LegacyOperatorType::SCALAR_FLOOR_DIV: + return "SCALAR_FLOOR_DIV"; + case LegacyOperatorType::SCALAR_TRUE_DIV: + return "SCALAR_TRUE_DIV"; + case LegacyOperatorType::SCALAR_SUB: + return "SCALAR_SUB"; + case LegacyOperatorType::RELU: + return "RELU"; + case LegacyOperatorType::IDENTITY: + return "IDENTITY"; + case LegacyOperatorType::SIGMOID: + return "SIGMOID"; + case LegacyOperatorType::TANH: + return "TANH"; + case LegacyOperatorType::ELU: + return "ELU"; + case LegacyOperatorType::FLAT: + return "FLAT"; + case LegacyOperatorType::SOFTMAX: + return "SOFTMAX"; + case LegacyOperatorType::BATCHNORM: + return "BATCHNORM"; + case LegacyOperatorType::CONCAT: + return "CONCAT"; + case LegacyOperatorType::SPLIT: + return "SPLIT"; + case LegacyOperatorType::EMBEDDING: + return "EMBEDDING"; + case LegacyOperatorType::CACHE: + return "CACHE"; + case LegacyOperatorType::RESHAPE: + return "RESHAPE"; + case LegacyOperatorType::REVERSE: + return "REVERSE"; + case LegacyOperatorType::TRANSPOSE: + return "TRANSPOSE"; + case LegacyOperatorType::EW_ADD: + return "EW_ADD"; + case LegacyOperatorType::EW_MUL: + return "EW_MUL"; + case LegacyOperatorType::MATMUL: + return "MATMUL"; + case LegacyOperatorType::MUL: + return "MUL"; + case LegacyOperatorType::ENLARGE: + return "ENLARGE"; + case LegacyOperatorType::SQUEEZE: + return "SQUEEZE"; + case LegacyOperatorType::UNSQUEEZE: + return "UNSQUEEZE"; + case LegacyOperatorType::EW_SUB: + return "EW_SUB"; + case LegacyOperatorType::EW_DIV: + return "EW_DIV"; + case LegacyOperatorType::EW_EQUAL: + return "EW_EQUAL"; + case LegacyOperatorType::EW_GREATER: + return "EW_GREATER"; + case LegacyOperatorType::EW_LESS: + return "EW_LESS"; + case LegacyOperatorType::EW_MAX: + return "EW_MAX"; + case LegacyOperatorType::EW_MIN: + return "EW_MIN"; + case LegacyOperatorType::REDUCE_ARGMAX: + return "REDUCE_ARGMAX"; + case LegacyOperatorType::REDUCE_ARGMIN: + return "REDUCE_ARGMIN"; + case LegacyOperatorType::REDUCE_MAX: + return "REDUCE_MAX"; + case LegacyOperatorType::REDUCE_MEAN: + return "REDUCE_MEAN"; + case LegacyOperatorType::REDUCE_MIN: + return "REDUCE_MIN"; + case LegacyOperatorType::REDUCE_PROD: + return "REDUCE_PROD"; + case LegacyOperatorType::REDUCE_SUM: + return "REDUCE_SUM"; + case LegacyOperatorType::PAD: + return "PAD"; + case LegacyOperatorType::SHAPE: + return "SHAPE"; + case LegacyOperatorType::SIZE: + return "SIZE"; + case LegacyOperatorType::TOPK: + return "TOPK"; + case LegacyOperatorType::WHERE: + return "WHERE"; + case LegacyOperatorType::CEIL: + return "CEIL"; + case LegacyOperatorType::CAST: + return "CAST"; + case LegacyOperatorType::EXP: + return "EXP"; + case LegacyOperatorType::ROUND: + return "ROUND"; + case LegacyOperatorType::LOG: + return "LOG"; + case LegacyOperatorType::LOGICAL_NOT: + return "LOGICAL_NOT"; + case LegacyOperatorType::SQRT: + return "SQRT"; + case LegacyOperatorType::SIN: + return "SIN"; + case LegacyOperatorType::COS: + return "COS"; + case LegacyOperatorType::LEAKYRELU: + return "LEAKYRELU"; + case LegacyOperatorType::SLICE: + return "SLICE"; + case LegacyOperatorType::RESIZE: + return "RESIZE"; + case LegacyOperatorType::PRELU: + return "PRELU"; + case LegacyOperatorType::GELU: + return "GELU"; + case LegacyOperatorType::MULTIHEAD_ATTENTION: + return "MULTIHEAD_ATTENTION"; + case LegacyOperatorType::FUSED: + return "FUSED"; + case LegacyOperatorType::RSQRT: + return "RSQRT"; + case LegacyOperatorType::POW: + return "POW"; + case LegacyOperatorType::MEAN: + return "MEAN"; + case LegacyOperatorType::LAYERNORM: + return "LAYERNORM"; + case LegacyOperatorType::GATHER: + return "GATHER"; + case LegacyOperatorType::BROADCAST: + return "BROADCAST"; + case LegacyOperatorType::REPARTITION: + return "REPARTITION"; + case LegacyOperatorType::COMBINE: + return "COMBINE"; + case LegacyOperatorType::REPLICATE: + return "REPLICATE"; + case LegacyOperatorType::REDUCTION: + return "REDUCTION"; + case LegacyOperatorType::BATCH: + return "BATCH"; + case LegacyOperatorType::PIPELINE: + return "PIPELINE"; + case LegacyOperatorType::FUSED_PARALLEL: + return "FUSED_PARALLEL"; + default: + std::ostringstream oss; + oss << "Unknown LegacyOperatorType value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, LegacyOperatorType x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, LegacyOperatorType x) { + switch (x) { + case LegacyOperatorType::NOOP: + j = "OP_NOOP"; + break; + case LegacyOperatorType::INPUT: + j = "OP_INPUT"; + break; + case LegacyOperatorType::WEIGHT: + j = "OP_WEIGHT"; + break; + case LegacyOperatorType::CONV2D: + j = "OP_CONV2D"; + break; + case LegacyOperatorType::DROPOUT: + j = "OP_DROPOUT"; + break; + case LegacyOperatorType::LINEAR: + j = "OP_LINEAR"; + break; + case LegacyOperatorType::BATCHMATMUL: + j = "OP_BATCHMATMUL"; + break; + case LegacyOperatorType::POOL2D: + j = "OP_POOL2D"; + break; + case LegacyOperatorType::SCALAR_MULTIPLY: + j = "OP_SCALAR_MULTIPLY"; + break; + case LegacyOperatorType::SCALAR_ADD: + j = "OP_SCALAR_ADD"; + break; + case LegacyOperatorType::SCALAR_FLOOR_DIV: + j = "OP_SCALAR_FLOOR_DIV"; + break; + case LegacyOperatorType::SCALAR_TRUE_DIV: + j = "OP_SCALAR_TRUE_DIV"; + break; + case LegacyOperatorType::SCALAR_SUB: + j = "OP_SCALAR_SUB"; + break; + case LegacyOperatorType::RELU: + j = "OP_RELU"; + break; + case LegacyOperatorType::IDENTITY: + j = "OP_IDENTITY"; + break; + case LegacyOperatorType::SIGMOID: + j = "OP_SIGMOID"; + break; + case LegacyOperatorType::TANH: + j = "OP_TANH"; + break; + case LegacyOperatorType::ELU: + j = "OP_ELU"; + break; + case LegacyOperatorType::FLAT: + j = "OP_FLAT"; + break; + case LegacyOperatorType::SOFTMAX: + j = "OP_SOFTMAX"; + break; + case LegacyOperatorType::BATCHNORM: + j = "OP_BATCHNORM"; + break; + case LegacyOperatorType::CONCAT: + j = "OP_CONCAT"; + break; + case LegacyOperatorType::SPLIT: + j = "OP_SPLIT"; + break; + case LegacyOperatorType::EMBEDDING: + j = "OP_EMBEDDING"; + break; + case LegacyOperatorType::CACHE: + j = "OP_CACHE"; + break; + case LegacyOperatorType::RESHAPE: + j = "OP_RESHAPE"; + break; + case LegacyOperatorType::REVERSE: + j = "OP_REVERSE"; + break; + case LegacyOperatorType::TRANSPOSE: + j = "OP_TRANSPOSE"; + break; + case LegacyOperatorType::EW_ADD: + j = "OP_EW_ADD"; + break; + case LegacyOperatorType::EW_MUL: + j = "OP_EW_MUL"; + break; + case LegacyOperatorType::MATMUL: + j = "OP_MATMUL"; + break; + case LegacyOperatorType::MUL: + j = "OP_MUL"; + break; + case LegacyOperatorType::ENLARGE: + j = "OP_ENLARGE"; + break; + case LegacyOperatorType::SQUEEZE: + j = "OP_SQUEEZE"; + break; + case LegacyOperatorType::UNSQUEEZE: + j = "OP_UNSQUEEZE"; + break; + case LegacyOperatorType::EW_SUB: + j = "OP_EW_SUB"; + break; + case LegacyOperatorType::EW_DIV: + j = "OP_EW_DIV"; + break; + case LegacyOperatorType::EW_EQUAL: + j = "OP_EW_EQUAL"; + break; + case LegacyOperatorType::EW_GREATER: + j = "OP_EW_GREATER"; + break; + case LegacyOperatorType::EW_LESS: + j = "OP_EW_LESS"; + break; + case LegacyOperatorType::EW_MAX: + j = "OP_EW_MAX"; + break; + case LegacyOperatorType::EW_MIN: + j = "OP_EW_MIN"; + break; + case LegacyOperatorType::REDUCE_ARGMAX: + j = "OP_REDUCE_ARGMAX"; + break; + case LegacyOperatorType::REDUCE_ARGMIN: + j = "OP_REDUCE_ARGMIN"; + break; + case LegacyOperatorType::REDUCE_MAX: + j = "OP_REDUCE_MAX"; + break; + case LegacyOperatorType::REDUCE_MEAN: + j = "OP_REDUCE_MEAN"; + break; + case LegacyOperatorType::REDUCE_MIN: + j = "OP_REDUCE_MIN"; + break; + case LegacyOperatorType::REDUCE_PROD: + j = "OP_REDUCE_PROD"; + break; + case LegacyOperatorType::REDUCE_SUM: + j = "OP_REDUCE_SUM"; + break; + case LegacyOperatorType::PAD: + j = "OP_PAD"; + break; + case LegacyOperatorType::SHAPE: + j = "OP_SHAPE"; + break; + case LegacyOperatorType::SIZE: + j = "OP_SIZE"; + break; + case LegacyOperatorType::TOPK: + j = "OP_TOPK"; + break; + case LegacyOperatorType::WHERE: + j = "OP_WHERE"; + break; + case LegacyOperatorType::CEIL: + j = "OP_CEIL"; + break; + case LegacyOperatorType::CAST: + j = "OP_CAST"; + break; + case LegacyOperatorType::EXP: + j = "OP_EXP"; + break; + case LegacyOperatorType::ROUND: + j = "OP_ROUND"; + break; + case LegacyOperatorType::LOG: + j = "OP_LOG"; + break; + case LegacyOperatorType::LOGICAL_NOT: + j = "OP_LOGICAL_NOT"; + break; + case LegacyOperatorType::SQRT: + j = "OP_SQRT"; + break; + case LegacyOperatorType::SIN: + j = "OP_SIN"; + break; + case LegacyOperatorType::COS: + j = "OP_COS"; + break; + case LegacyOperatorType::LEAKYRELU: + j = "OP_LEAKYRELU"; + break; + case LegacyOperatorType::SLICE: + j = "OP_SLICE"; + break; + case LegacyOperatorType::RESIZE: + j = "OP_RESIZE"; + break; + case LegacyOperatorType::PRELU: + j = "OP_PRELU"; + break; + case LegacyOperatorType::GELU: + j = "OP_GELU"; + break; + case LegacyOperatorType::MULTIHEAD_ATTENTION: + j = "OP_MULTIHEAD_ATTENTION"; + break; + case LegacyOperatorType::FUSED: + j = "OP_FUSED"; + break; + case LegacyOperatorType::RSQRT: + j = "OP_RSQRT"; + break; + case LegacyOperatorType::POW: + j = "OP_POW"; + break; + case LegacyOperatorType::MEAN: + j = "OP_MEAN"; + break; + case LegacyOperatorType::LAYERNORM: + j = "OP_LAYERNORM"; + break; + case LegacyOperatorType::GATHER: + j = "OP_GATHER"; + break; + case LegacyOperatorType::BROADCAST: + j = "OP_BROADCAST"; + break; + case LegacyOperatorType::REPARTITION: + j = "OP_PARTITION"; + break; + case LegacyOperatorType::COMBINE: + j = "OP_COMBINE"; + break; + case LegacyOperatorType::REPLICATE: + j = "OP_REPLICATE"; + break; + case LegacyOperatorType::REDUCTION: + j = "OP_REDUCE"; + break; + case LegacyOperatorType::BATCH: + j = "OP_BATCH"; + break; + case LegacyOperatorType::PIPELINE: + j = "OP_PIPELINE"; + break; + case LegacyOperatorType::FUSED_PARALLEL: + j = "OP_FUSED_PARALLEL"; + break; + default: + std::ostringstream oss; + oss << "Unknown LegacyOperatorType value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, LegacyOperatorType &x) { + std::string as_str = j.get(); + if (as_str == "OP_NOOP") { + x = LegacyOperatorType::NOOP; + } else if (as_str == "OP_INPUT") { + x = LegacyOperatorType::INPUT; + } else if (as_str == "OP_WEIGHT") { + x = LegacyOperatorType::WEIGHT; + } else if (as_str == "OP_CONV2D") { + x = LegacyOperatorType::CONV2D; + } else if (as_str == "OP_DROPOUT") { + x = LegacyOperatorType::DROPOUT; + } else if (as_str == "OP_LINEAR") { + x = LegacyOperatorType::LINEAR; + } else if (as_str == "OP_BATCHMATMUL") { + x = LegacyOperatorType::BATCHMATMUL; + } else if (as_str == "OP_POOL2D") { + x = LegacyOperatorType::POOL2D; + } else if (as_str == "OP_SCALAR_MULTIPLY") { + x = LegacyOperatorType::SCALAR_MULTIPLY; + } else if (as_str == "OP_SCALAR_ADD") { + x = LegacyOperatorType::SCALAR_ADD; + } else if (as_str == "OP_SCALAR_FLOOR_DIV") { + x = LegacyOperatorType::SCALAR_FLOOR_DIV; + } else if (as_str == "OP_SCALAR_TRUE_DIV") { + x = LegacyOperatorType::SCALAR_TRUE_DIV; + } else if (as_str == "OP_SCALAR_SUB") { + x = LegacyOperatorType::SCALAR_SUB; + } else if (as_str == "OP_RELU") { + x = LegacyOperatorType::RELU; + } else if (as_str == "OP_IDENTITY") { + x = LegacyOperatorType::IDENTITY; + } else if (as_str == "OP_SIGMOID") { + x = LegacyOperatorType::SIGMOID; + } else if (as_str == "OP_TANH") { + x = LegacyOperatorType::TANH; + } else if (as_str == "OP_ELU") { + x = LegacyOperatorType::ELU; + } else if (as_str == "OP_FLAT") { + x = LegacyOperatorType::FLAT; + } else if (as_str == "OP_SOFTMAX") { + x = LegacyOperatorType::SOFTMAX; + } else if (as_str == "OP_BATCHNORM") { + x = LegacyOperatorType::BATCHNORM; + } else if (as_str == "OP_CONCAT") { + x = LegacyOperatorType::CONCAT; + } else if (as_str == "OP_SPLIT") { + x = LegacyOperatorType::SPLIT; + } else if (as_str == "OP_EMBEDDING") { + x = LegacyOperatorType::EMBEDDING; + } else if (as_str == "OP_CACHE") { + x = LegacyOperatorType::CACHE; + } else if (as_str == "OP_RESHAPE") { + x = LegacyOperatorType::RESHAPE; + } else if (as_str == "OP_REVERSE") { + x = LegacyOperatorType::REVERSE; + } else if (as_str == "OP_TRANSPOSE") { + x = LegacyOperatorType::TRANSPOSE; + } else if (as_str == "OP_EW_ADD") { + x = LegacyOperatorType::EW_ADD; + } else if (as_str == "OP_EW_MUL") { + x = LegacyOperatorType::EW_MUL; + } else if (as_str == "OP_MATMUL") { + x = LegacyOperatorType::MATMUL; + } else if (as_str == "OP_MUL") { + x = LegacyOperatorType::MUL; + } else if (as_str == "OP_ENLARGE") { + x = LegacyOperatorType::ENLARGE; + } else if (as_str == "OP_SQUEEZE") { + x = LegacyOperatorType::SQUEEZE; + } else if (as_str == "OP_UNSQUEEZE") { + x = LegacyOperatorType::UNSQUEEZE; + } else if (as_str == "OP_EW_SUB") { + x = LegacyOperatorType::EW_SUB; + } else if (as_str == "OP_EW_DIV") { + x = LegacyOperatorType::EW_DIV; + } else if (as_str == "OP_EW_EQUAL") { + x = LegacyOperatorType::EW_EQUAL; + } else if (as_str == "OP_EW_GREATER") { + x = LegacyOperatorType::EW_GREATER; + } else if (as_str == "OP_EW_LESS") { + x = LegacyOperatorType::EW_LESS; + } else if (as_str == "OP_EW_MAX") { + x = LegacyOperatorType::EW_MAX; + } else if (as_str == "OP_EW_MIN") { + x = LegacyOperatorType::EW_MIN; + } else if (as_str == "OP_REDUCE_ARGMAX") { + x = LegacyOperatorType::REDUCE_ARGMAX; + } else if (as_str == "OP_REDUCE_ARGMIN") { + x = LegacyOperatorType::REDUCE_ARGMIN; + } else if (as_str == "OP_REDUCE_MAX") { + x = LegacyOperatorType::REDUCE_MAX; + } else if (as_str == "OP_REDUCE_MEAN") { + x = LegacyOperatorType::REDUCE_MEAN; + } else if (as_str == "OP_REDUCE_MIN") { + x = LegacyOperatorType::REDUCE_MIN; + } else if (as_str == "OP_REDUCE_PROD") { + x = LegacyOperatorType::REDUCE_PROD; + } else if (as_str == "OP_REDUCE_SUM") { + x = LegacyOperatorType::REDUCE_SUM; + } else if (as_str == "OP_PAD") { + x = LegacyOperatorType::PAD; + } else if (as_str == "OP_SHAPE") { + x = LegacyOperatorType::SHAPE; + } else if (as_str == "OP_SIZE") { + x = LegacyOperatorType::SIZE; + } else if (as_str == "OP_TOPK") { + x = LegacyOperatorType::TOPK; + } else if (as_str == "OP_WHERE") { + x = LegacyOperatorType::WHERE; + } else if (as_str == "OP_CEIL") { + x = LegacyOperatorType::CEIL; + } else if (as_str == "OP_CAST") { + x = LegacyOperatorType::CAST; + } else if (as_str == "OP_EXP") { + x = LegacyOperatorType::EXP; + } else if (as_str == "OP_ROUND") { + x = LegacyOperatorType::ROUND; + } else if (as_str == "OP_LOG") { + x = LegacyOperatorType::LOG; + } else if (as_str == "OP_LOGICAL_NOT") { + x = LegacyOperatorType::LOGICAL_NOT; + } else if (as_str == "OP_SQRT") { + x = LegacyOperatorType::SQRT; + } else if (as_str == "OP_SIN") { + x = LegacyOperatorType::SIN; + } else if (as_str == "OP_COS") { + x = LegacyOperatorType::COS; + } else if (as_str == "OP_LEAKYRELU") { + x = LegacyOperatorType::LEAKYRELU; + } else if (as_str == "OP_SLICE") { + x = LegacyOperatorType::SLICE; + } else if (as_str == "OP_RESIZE") { + x = LegacyOperatorType::RESIZE; + } else if (as_str == "OP_PRELU") { + x = LegacyOperatorType::PRELU; + } else if (as_str == "OP_GELU") { + x = LegacyOperatorType::GELU; + } else if (as_str == "OP_MULTIHEAD_ATTENTION") { + x = LegacyOperatorType::MULTIHEAD_ATTENTION; + } else if (as_str == "OP_FUSED") { + x = LegacyOperatorType::FUSED; + } else if (as_str == "OP_RSQRT") { + x = LegacyOperatorType::RSQRT; + } else if (as_str == "OP_POW") { + x = LegacyOperatorType::POW; + } else if (as_str == "OP_MEAN") { + x = LegacyOperatorType::MEAN; + } else if (as_str == "OP_LAYERNORM") { + x = LegacyOperatorType::LAYERNORM; + } else if (as_str == "OP_GATHER") { + x = LegacyOperatorType::GATHER; + } else if (as_str == "OP_BROADCAST") { + x = LegacyOperatorType::BROADCAST; + } else if (as_str == "OP_PARTITION") { + x = LegacyOperatorType::REPARTITION; + } else if (as_str == "OP_COMBINE") { + x = LegacyOperatorType::COMBINE; + } else if (as_str == "OP_REPLICATE") { + x = LegacyOperatorType::REPLICATE; + } else if (as_str == "OP_REDUCE") { + x = LegacyOperatorType::REDUCTION; + } else if (as_str == "OP_BATCH") { + x = LegacyOperatorType::BATCH; + } else if (as_str == "OP_PIPELINE") { + x = LegacyOperatorType::PIPELINE; + } else if (as_str == "OP_FUSED_PARALLEL") { + x = LegacyOperatorType::FUSED_PARALLEL; + } else { + std::ostringstream oss; + oss << "Unknown LegacyOperatorType value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::element( + FlexFlow::LegacyOperatorType::NOOP, + FlexFlow::LegacyOperatorType::INPUT, + FlexFlow::LegacyOperatorType::WEIGHT, + FlexFlow::LegacyOperatorType::CONV2D, + FlexFlow::LegacyOperatorType::DROPOUT, + FlexFlow::LegacyOperatorType::LINEAR, + FlexFlow::LegacyOperatorType::BATCHMATMUL, + FlexFlow::LegacyOperatorType::POOL2D, + FlexFlow::LegacyOperatorType::SCALAR_MULTIPLY, + FlexFlow::LegacyOperatorType::SCALAR_ADD, + FlexFlow::LegacyOperatorType::SCALAR_FLOOR_DIV, + FlexFlow::LegacyOperatorType::SCALAR_TRUE_DIV, + FlexFlow::LegacyOperatorType::SCALAR_SUB, + FlexFlow::LegacyOperatorType::RELU, + FlexFlow::LegacyOperatorType::IDENTITY, + FlexFlow::LegacyOperatorType::SIGMOID, + FlexFlow::LegacyOperatorType::TANH, + FlexFlow::LegacyOperatorType::ELU, + FlexFlow::LegacyOperatorType::FLAT, + FlexFlow::LegacyOperatorType::SOFTMAX, + FlexFlow::LegacyOperatorType::BATCHNORM, + FlexFlow::LegacyOperatorType::CONCAT, + FlexFlow::LegacyOperatorType::SPLIT, + FlexFlow::LegacyOperatorType::EMBEDDING, + FlexFlow::LegacyOperatorType::CACHE, + FlexFlow::LegacyOperatorType::RESHAPE, + FlexFlow::LegacyOperatorType::REVERSE, + FlexFlow::LegacyOperatorType::TRANSPOSE, + FlexFlow::LegacyOperatorType::EW_ADD, + FlexFlow::LegacyOperatorType::EW_MUL, + FlexFlow::LegacyOperatorType::MATMUL, + FlexFlow::LegacyOperatorType::MUL, + FlexFlow::LegacyOperatorType::ENLARGE, + FlexFlow::LegacyOperatorType::SQUEEZE, + FlexFlow::LegacyOperatorType::UNSQUEEZE, + FlexFlow::LegacyOperatorType::EW_SUB, + FlexFlow::LegacyOperatorType::EW_DIV, + FlexFlow::LegacyOperatorType::EW_EQUAL, + FlexFlow::LegacyOperatorType::EW_GREATER, + FlexFlow::LegacyOperatorType::EW_LESS, + FlexFlow::LegacyOperatorType::EW_MAX, + FlexFlow::LegacyOperatorType::EW_MIN, + FlexFlow::LegacyOperatorType::REDUCE_ARGMAX, + FlexFlow::LegacyOperatorType::REDUCE_ARGMIN, + FlexFlow::LegacyOperatorType::REDUCE_MAX, + FlexFlow::LegacyOperatorType::REDUCE_MEAN, + FlexFlow::LegacyOperatorType::REDUCE_MIN, + FlexFlow::LegacyOperatorType::REDUCE_PROD, + FlexFlow::LegacyOperatorType::REDUCE_SUM, + FlexFlow::LegacyOperatorType::PAD, + FlexFlow::LegacyOperatorType::SHAPE, + FlexFlow::LegacyOperatorType::SIZE, + FlexFlow::LegacyOperatorType::TOPK, + FlexFlow::LegacyOperatorType::WHERE, + FlexFlow::LegacyOperatorType::CEIL, + FlexFlow::LegacyOperatorType::CAST, + FlexFlow::LegacyOperatorType::EXP, + FlexFlow::LegacyOperatorType::ROUND, + FlexFlow::LegacyOperatorType::LOG, + FlexFlow::LegacyOperatorType::LOGICAL_NOT, + FlexFlow::LegacyOperatorType::SQRT, + FlexFlow::LegacyOperatorType::SIN, + FlexFlow::LegacyOperatorType::COS, + FlexFlow::LegacyOperatorType::LEAKYRELU, + FlexFlow::LegacyOperatorType::SLICE, + FlexFlow::LegacyOperatorType::RESIZE, + FlexFlow::LegacyOperatorType::PRELU, + FlexFlow::LegacyOperatorType::GELU, + FlexFlow::LegacyOperatorType::MULTIHEAD_ATTENTION, + FlexFlow::LegacyOperatorType::FUSED, + FlexFlow::LegacyOperatorType::RSQRT, + FlexFlow::LegacyOperatorType::POW, + FlexFlow::LegacyOperatorType::MEAN, + FlexFlow::LegacyOperatorType::LAYERNORM, + FlexFlow::LegacyOperatorType::GATHER, + FlexFlow::LegacyOperatorType::BROADCAST, + FlexFlow::LegacyOperatorType::REPARTITION, + FlexFlow::LegacyOperatorType::COMBINE, + FlexFlow::LegacyOperatorType::REPLICATE, + FlexFlow::LegacyOperatorType::REDUCTION, + FlexFlow::LegacyOperatorType::BATCH, + FlexFlow::LegacyOperatorType::PIPELINE, + FlexFlow::LegacyOperatorType::FUSED_PARALLEL); +} +} // namespace rc diff --git a/lib/substitution-generator/src/substitution-generator/legacy_pm_parameter.dtg.cc b/lib/substitution-generator/src/substitution-generator/legacy_pm_parameter.dtg.cc new file mode 100644 index 0000000000..c8df4ccd7d --- /dev/null +++ b/lib/substitution-generator/src/substitution-generator/legacy_pm_parameter.dtg.cc @@ -0,0 +1,313 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitution-generator/include/substitution-generator/legacy_pm_parameter.enum.toml +/* proj-data +{ + "generated_from": "e8dda0047c91e576878b86df2fec0b6b" +} +*/ + +#include "substitution-generator/legacy_pm_parameter.dtg.h" + +#include +#include + +namespace std { +size_t hash::operator()( + FlexFlow::LegacyPMParameter x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(LegacyPMParameter x) { + switch (x) { + case LegacyPMParameter::OP_TYPE: + return "OP_TYPE"; + case LegacyPMParameter::NUM_INPUTS: + return "NUM_INPUTS"; + case LegacyPMParameter::NUM_OUTPUTS: + return "NUM_OUTPUTS"; + case LegacyPMParameter::GROUP: + return "GROUP"; + case LegacyPMParameter::KERNEL_H: + return "KERNEL_H"; + case LegacyPMParameter::KERNEL_W: + return "KERNEL_W"; + case LegacyPMParameter::STRIDE_H: + return "STRIDE_H"; + case LegacyPMParameter::STRIDE_W: + return "STRIDE_W"; + case LegacyPMParameter::PADDING_H: + return "PADDING_H"; + case LegacyPMParameter::PADDING_W: + return "PADDING_W"; + case LegacyPMParameter::ACTI: + return "ACTI"; + case LegacyPMParameter::NUMDIM: + return "NUMDIM"; + case LegacyPMParameter::AXIS: + return "AXIS"; + case LegacyPMParameter::PERM: + return "PERM"; + case LegacyPMParameter::OUTSHUFFLE: + return "OUTSHUFFLE"; + case LegacyPMParameter::MERGE_GCONV_COUNT: + return "MERGE_GCONV_COUNT"; + case LegacyPMParameter::AXES: + return "AXES"; + case LegacyPMParameter::KEEP_DIMS: + return "KEEP_DIMS"; + case LegacyPMParameter::EPSILON: + return "EPSILON"; + case LegacyPMParameter::REPARTITION_DIM: + return "REPARTITION_DIM"; + case LegacyPMParameter::REPARTITION_DEGREE: + return "REPARTITION_DEGREE"; + case LegacyPMParameter::REPLICATE_DIM: + return "REPLICATE_DIM"; + case LegacyPMParameter::REPLICATE_DEGREE: + return "REPLICATE_DEGREE"; + case LegacyPMParameter::COMBINE_DIM: + return "COMBINE_DIM"; + case LegacyPMParameter::COMBINE_DEGREE: + return "COMBINE_DEGREE"; + case LegacyPMParameter::REDUCTION_DIM: + return "REDUCTION_DIM"; + case LegacyPMParameter::REDUCTION_DEGREE: + return "REDUCTION_DEGREE"; + case LegacyPMParameter::SOFTMAX_DIM: + return "SOFTMAX_DIM"; + case LegacyPMParameter::NUM_HEADS: + return "NUM_HEADS"; + case LegacyPMParameter::PARALLEL_DIM: + return "PARALLEL_DIM"; + case LegacyPMParameter::PARALLEL_DEGREE: + return "PARALLEL_DEGREE"; + case LegacyPMParameter::PAD: + return "PAD"; + default: + std::ostringstream oss; + oss << "Unknown LegacyPMParameter value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, LegacyPMParameter x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, LegacyPMParameter x) { + switch (x) { + case LegacyPMParameter::OP_TYPE: + j = "PM_OP_TYPE"; + break; + case LegacyPMParameter::NUM_INPUTS: + j = "PM_NUM_INPUTS"; + break; + case LegacyPMParameter::NUM_OUTPUTS: + j = "PM_NUM_OUTPUTS"; + break; + case LegacyPMParameter::GROUP: + j = "PM_GROUP"; + break; + case LegacyPMParameter::KERNEL_H: + j = "PM_KERNEL_H"; + break; + case LegacyPMParameter::KERNEL_W: + j = "PM_KERNEL_W"; + break; + case LegacyPMParameter::STRIDE_H: + j = "PM_STRIDE_H"; + break; + case LegacyPMParameter::STRIDE_W: + j = "PM_STRIDE_W"; + break; + case LegacyPMParameter::PADDING_H: + j = "PM_PADDING_H"; + break; + case LegacyPMParameter::PADDING_W: + j = "PM_PADDING_W"; + break; + case LegacyPMParameter::ACTI: + j = "PM_ACTI"; + break; + case LegacyPMParameter::NUMDIM: + j = "PM_NUMDIM"; + break; + case LegacyPMParameter::AXIS: + j = "PM_AXIS"; + break; + case LegacyPMParameter::PERM: + j = "PM_PERM"; + break; + case LegacyPMParameter::OUTSHUFFLE: + j = "PM_OUTSHUFFLE"; + break; + case LegacyPMParameter::MERGE_GCONV_COUNT: + j = "PM_MERGE_GCONV_COUNT"; + break; + case LegacyPMParameter::AXES: + j = "PM_AXES"; + break; + case LegacyPMParameter::KEEP_DIMS: + j = "PM_KEEP_DIMS"; + break; + case LegacyPMParameter::EPSILON: + j = "PM_EPSILON"; + break; + case LegacyPMParameter::REPARTITION_DIM: + j = "PM_REPARTITION_DIM"; + break; + case LegacyPMParameter::REPARTITION_DEGREE: + j = "PM_REPARTITION_DEGREE"; + break; + case LegacyPMParameter::REPLICATE_DIM: + j = "PM_REPLICATE_DIM"; + break; + case LegacyPMParameter::REPLICATE_DEGREE: + j = "PM_REPLICATE_DEGREE"; + break; + case LegacyPMParameter::COMBINE_DIM: + j = "PM_COMBINE_DIM"; + break; + case LegacyPMParameter::COMBINE_DEGREE: + j = "PM_COMBINE_DEGREE"; + break; + case LegacyPMParameter::REDUCTION_DIM: + j = "PM_REDUCTION_DIM"; + break; + case LegacyPMParameter::REDUCTION_DEGREE: + j = "PM_REDUCTION_DEGREE"; + break; + case LegacyPMParameter::SOFTMAX_DIM: + j = "PM_SOFTMAX_DIM"; + break; + case LegacyPMParameter::NUM_HEADS: + j = "PM_NUM_HEADS"; + break; + case LegacyPMParameter::PARALLEL_DIM: + j = "PM_PARALLEL_DIM"; + break; + case LegacyPMParameter::PARALLEL_DEGREE: + j = "PM_PARALLEL_DEGREE"; + break; + case LegacyPMParameter::PAD: + j = "PM_PAD"; + break; + default: + std::ostringstream oss; + oss << "Unknown LegacyPMParameter value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, LegacyPMParameter &x) { + std::string as_str = j.get(); + if (as_str == "PM_OP_TYPE") { + x = LegacyPMParameter::OP_TYPE; + } else if (as_str == "PM_NUM_INPUTS") { + x = LegacyPMParameter::NUM_INPUTS; + } else if (as_str == "PM_NUM_OUTPUTS") { + x = LegacyPMParameter::NUM_OUTPUTS; + } else if (as_str == "PM_GROUP") { + x = LegacyPMParameter::GROUP; + } else if (as_str == "PM_KERNEL_H") { + x = LegacyPMParameter::KERNEL_H; + } else if (as_str == "PM_KERNEL_W") { + x = LegacyPMParameter::KERNEL_W; + } else if (as_str == "PM_STRIDE_H") { + x = LegacyPMParameter::STRIDE_H; + } else if (as_str == "PM_STRIDE_W") { + x = LegacyPMParameter::STRIDE_W; + } else if (as_str == "PM_PADDING_H") { + x = LegacyPMParameter::PADDING_H; + } else if (as_str == "PM_PADDING_W") { + x = LegacyPMParameter::PADDING_W; + } else if (as_str == "PM_ACTI") { + x = LegacyPMParameter::ACTI; + } else if (as_str == "PM_NUMDIM") { + x = LegacyPMParameter::NUMDIM; + } else if (as_str == "PM_AXIS") { + x = LegacyPMParameter::AXIS; + } else if (as_str == "PM_PERM") { + x = LegacyPMParameter::PERM; + } else if (as_str == "PM_OUTSHUFFLE") { + x = LegacyPMParameter::OUTSHUFFLE; + } else if (as_str == "PM_MERGE_GCONV_COUNT") { + x = LegacyPMParameter::MERGE_GCONV_COUNT; + } else if (as_str == "PM_AXES") { + x = LegacyPMParameter::AXES; + } else if (as_str == "PM_KEEP_DIMS") { + x = LegacyPMParameter::KEEP_DIMS; + } else if (as_str == "PM_EPSILON") { + x = LegacyPMParameter::EPSILON; + } else if (as_str == "PM_REPARTITION_DIM") { + x = LegacyPMParameter::REPARTITION_DIM; + } else if (as_str == "PM_REPARTITION_DEGREE") { + x = LegacyPMParameter::REPARTITION_DEGREE; + } else if (as_str == "PM_REPLICATE_DIM") { + x = LegacyPMParameter::REPLICATE_DIM; + } else if (as_str == "PM_REPLICATE_DEGREE") { + x = LegacyPMParameter::REPLICATE_DEGREE; + } else if (as_str == "PM_COMBINE_DIM") { + x = LegacyPMParameter::COMBINE_DIM; + } else if (as_str == "PM_COMBINE_DEGREE") { + x = LegacyPMParameter::COMBINE_DEGREE; + } else if (as_str == "PM_REDUCTION_DIM") { + x = LegacyPMParameter::REDUCTION_DIM; + } else if (as_str == "PM_REDUCTION_DEGREE") { + x = LegacyPMParameter::REDUCTION_DEGREE; + } else if (as_str == "PM_SOFTMAX_DIM") { + x = LegacyPMParameter::SOFTMAX_DIM; + } else if (as_str == "PM_NUM_HEADS") { + x = LegacyPMParameter::NUM_HEADS; + } else if (as_str == "PM_PARALLEL_DIM") { + x = LegacyPMParameter::PARALLEL_DIM; + } else if (as_str == "PM_PARALLEL_DEGREE") { + x = LegacyPMParameter::PARALLEL_DEGREE; + } else if (as_str == "PM_PAD") { + x = LegacyPMParameter::PAD; + } else { + std::ostringstream oss; + oss << "Unknown LegacyPMParameter value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::element( + FlexFlow::LegacyPMParameter::OP_TYPE, + FlexFlow::LegacyPMParameter::NUM_INPUTS, + FlexFlow::LegacyPMParameter::NUM_OUTPUTS, + FlexFlow::LegacyPMParameter::GROUP, + FlexFlow::LegacyPMParameter::KERNEL_H, + FlexFlow::LegacyPMParameter::KERNEL_W, + FlexFlow::LegacyPMParameter::STRIDE_H, + FlexFlow::LegacyPMParameter::STRIDE_W, + FlexFlow::LegacyPMParameter::PADDING_H, + FlexFlow::LegacyPMParameter::PADDING_W, + FlexFlow::LegacyPMParameter::ACTI, + FlexFlow::LegacyPMParameter::NUMDIM, + FlexFlow::LegacyPMParameter::AXIS, + FlexFlow::LegacyPMParameter::PERM, + FlexFlow::LegacyPMParameter::OUTSHUFFLE, + FlexFlow::LegacyPMParameter::MERGE_GCONV_COUNT, + FlexFlow::LegacyPMParameter::AXES, + FlexFlow::LegacyPMParameter::KEEP_DIMS, + FlexFlow::LegacyPMParameter::EPSILON, + FlexFlow::LegacyPMParameter::REPARTITION_DIM, + FlexFlow::LegacyPMParameter::REPARTITION_DEGREE, + FlexFlow::LegacyPMParameter::REPLICATE_DIM, + FlexFlow::LegacyPMParameter::REPLICATE_DEGREE, + FlexFlow::LegacyPMParameter::COMBINE_DIM, + FlexFlow::LegacyPMParameter::COMBINE_DEGREE, + FlexFlow::LegacyPMParameter::REDUCTION_DIM, + FlexFlow::LegacyPMParameter::REDUCTION_DEGREE, + FlexFlow::LegacyPMParameter::SOFTMAX_DIM, + FlexFlow::LegacyPMParameter::NUM_HEADS, + FlexFlow::LegacyPMParameter::PARALLEL_DIM, + FlexFlow::LegacyPMParameter::PARALLEL_DEGREE, + FlexFlow::LegacyPMParameter::PAD); +} +} // namespace rc diff --git a/lib/substitution-generator/test/substitution-generator/json.cc b/lib/substitution-generator/test/substitution-generator/json.cc index d12b294a2e..befdaf1308 100644 --- a/lib/substitution-generator/test/substitution-generator/json.cc +++ b/lib/substitution-generator/test/substitution-generator/json.cc @@ -18,7 +18,7 @@ TEST_SUITE(FF_TEST_SUITE) { Operator o; from_json(j, o); - CHECK(o.op_type == Op::EW_ADD); + CHECK(o.op_type == LegacyOperatorType::EW_ADD); CHECK(o.input.size() == 2); CHECK(o.input[0].opId == -2); CHECK(o.input[0].tsId == 0); diff --git a/lib/substitutions/include/substitutions/attribute_expr.h b/lib/substitutions/include/substitutions/attribute_expr.h deleted file mode 100644 index 0afd48b431..0000000000 --- a/lib/substitutions/include/substitutions/attribute_expr.h +++ /dev/null @@ -1,40 +0,0 @@ -#ifndef _FLEXFLOW_SUBSTITUTIONS_CONSTRAINT_H -#define _FLEXFLOW_SUBSTITUTIONS_CONSTRAINT_H - -#include "utils/variant.h" - -namespace FlexFlow { - -enum class ConstraintType { EQUAL }; - -template -struct ListIndexAccess { - T attribute_key; - req index; -}; - -template -struct ListSize { - req attribute_key; -}; - -template -using AttributeExpr = std::variant, ListSize>; - -template -struct AttributeConstraint { - ConstraintType constraint_type; - AttributeExpr attribute_expr; - V attribute_value; -}; - -template -struct AttributePattern { - std::vector> attribute_constraints; - // TODO: Revert to unordered_set once we have visitable for templates - // std::unordered_set> attribute_constraints; -}; - -} // namespace FlexFlow - -#endif diff --git a/lib/substitutions/include/substitutions/constraint_type.dtg.h b/lib/substitutions/include/substitutions/constraint_type.dtg.h new file mode 100644 index 0000000000..f99b794e46 --- /dev/null +++ b/lib/substitutions/include/substitutions/constraint_type.dtg.h @@ -0,0 +1,40 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/constraint_type.enum.toml +/* proj-data +{ + "generated_from": "06b029d76658cb434abf08b1fdb86137" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_CONSTRAINT_TYPE_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_CONSTRAINT_TYPE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +enum class ConstraintType { EQUAL }; +std::string format_as(ConstraintType); +std::ostream &operator<<(std::ostream &, ConstraintType); +void to_json(::nlohmann::json &, ConstraintType); +void from_json(::nlohmann::json const &, ConstraintType &); +} // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ConstraintType) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_CONSTRAINT_TYPE_DTG_H diff --git a/lib/substitutions/include/substitutions/constraint_type.enum.toml b/lib/substitutions/include/substitutions/constraint_type.enum.toml new file mode 100644 index 0000000000..8646ba1c83 --- /dev/null +++ b/lib/substitutions/include/substitutions/constraint_type.enum.toml @@ -0,0 +1,11 @@ +namespace = "FlexFlow" +name = "ConstraintType" +features = [ + "json", + "hash", + "rapidcheck", + "fmt", +] + +[[values]] +name = "EQUAL" diff --git a/lib/substitutions/include/substitutions/graph_pattern.h b/lib/substitutions/include/substitutions/graph_pattern.h index 4f4021203b..5f03a6e92e 100644 --- a/lib/substitutions/include/substitutions/graph_pattern.h +++ b/lib/substitutions/include/substitutions/graph_pattern.h @@ -1,32 +1,24 @@ #ifndef _FLEXFLOW_SUBSTITUTIONS_SUBSTITUTIONS_H #define _FLEXFLOW_SUBSTITUTIONS_SUBSTITUTIONS_H -#include "graph_pattern_match.h" -#include "operator_pattern.h" -#include "parallel_tensor_pattern.h" -#include "sub_parallel_computation_graph.h" +#include "substitutions/pcg_pattern.dtg.h" +#include "substitutions/sub_parallel_computation_graph.dtg.h" +#include "substitutions/unlabelled/pattern_edge.dtg.h" +#include "substitutions/unlabelled/pattern_matching.h" +#include "substitutions/unlabelled/pattern_node.dtg.h" +#include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" namespace FlexFlow { -struct GraphPattern - : public strong_typedef< - GraphPattern, - OutputLabelledOpenMultiDiGraph> { - using strong_typedef::strong_typedef; -}; +UnlabelledGraphPattern get_unlabelled_pattern(PCGPattern const &); -GraphSplit split_pattern(OpenMultiDiGraphView const &pattern); - -bool is_singleton_pattern(OpenMultiDiGraphView const &); - -bool operator_satisfies(Operator const ¶ms, OperatorPattern const &pattern); - -bool parallel_tensor_satisfies(ParallelTensor const ¶ms, - ParallelTensorPattern const &pattern); +TensorAttributePattern get_tensor_pattern(PCGPattern const &, + PatternEdge const &); +OperatorAttributePattern get_operator_pattern(PCGPattern const &, + PatternNode const &); bool assignment_satisfies(SubParallelComputationGraph const &, - GraphPattern const &, + PCGPattern const &, MultiDiGraphPatternMatch const &); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/graph_pattern_match.h b/lib/substitutions/include/substitutions/graph_pattern_match.h deleted file mode 100644 index bf6d6b6921..0000000000 --- a/lib/substitutions/include/substitutions/graph_pattern_match.h +++ /dev/null @@ -1,42 +0,0 @@ -#ifndef _FLEXFLOW_SUBSTITUTIONS_GRAPH_PATTERN_MATCH_H -#define _FLEXFLOW_SUBSTITUTIONS_GRAPH_PATTERN_MATCH_H - -#include "utils/graph.h" -#include "utils/visitable.h" - -namespace FlexFlow { - -struct MultiDiGraphPatternMatch { - using PatternNode = Node; - using PCGNode = Node; - using PatternEdge = OpenMultiDiEdge; - using PCGEdge = OpenMultiDiEdge; - - bidict node_assignment; - bidict edge_assignment; -}; - -struct MatchSplit { - MultiDiGraphPatternMatch prefix_submatch; - MultiDiGraphPatternMatch postfix_submatch; -}; - -struct MatchAdditionalCriterion { - std::function node_criterion; - std::function - edge_criterion; -}; - -bool pattern_matches(OpenMultiDiGraphView const &pattern, - OpenMultiDiGraphView const &graph, - MultiDiGraphPatternMatch const &match, - MatchAdditionalCriterion const &additional_criterion); - -std::vector - find_pattern_matches(OpenMultiDiGraphView const &pattern, - OpenMultiDiGraphView const &graph, - MatchAdditionalCriterion const &additional_criterion); - -} // namespace FlexFlow - -#endif diff --git a/lib/substitutions/include/substitutions/operator_pattern.h b/lib/substitutions/include/substitutions/operator_pattern.h deleted file mode 100644 index 8fc4ebefc2..0000000000 --- a/lib/substitutions/include/substitutions/operator_pattern.h +++ /dev/null @@ -1,107 +0,0 @@ -#ifndef _FLEXFLOW_SUBSTITUTIONS_OPERATOR_PATTERN_H -#define _FLEXFLOW_SUBSTITUTIONS_OPERATOR_PATTERN_H - -#include "attribute_expr.h" -#include "op-attrs/activation.h" -#include "op-attrs/datatype.h" -#include "op-attrs/op.h" -#include "pcg/operator.h" -#include -#include - -namespace FlexFlow { - -enum class OperatorAttributeKey { - OP_TYPE, // AnyOp - USE_BIAS, - GROUPS, - POOL_TYPE, - KERNEL_H, - KERNEL_W, - DATA_TYPE, - SCALAR, - STRIDE_H, - STRIDE_W, - PADDING_H, - PADDING_W, - AGGR, - NUM_ENTRIES, - OUT_CHANNELS, - ACTIVATION, - NUMDIM, - AXIS, - PERMUTATION, - OUTSHUFFLE, - MERGE_GCONV_COUNT, - AXES, - KEEP_DIMS, - EPSILON, - PARALLEL_OP_DIM, - PARALLEL_OP_DEGREE, - SOFTMAX_DIM, - NUM_HEADS, - PARALLEL_DIM, - PARALLEL_DEGREE, - PAD, - EMBED_DIM, - KDIM, - VDIM, - DROPOUT, - BIAS, - ADD_BIAS_KV, - ADD_ZERO_ATTN, - A_SEQ_LENGTH_DIM, - B_SEQ_LENGTH_DIM, - RELU, - TARGET_DIMS, - RATE, - SEED, - SHOULD_BROADCAST_LHS, - SHOULD_BROADCAST_RHS, - DIM, - ELEMENTWISE_AFFINE, - REGULARIZER, - SHAPE, - SPLITS, - K, - SORTED, - COMBINE_DIM, - COMBINE_DEGREE, - NUM_INPUTS -}; - -using OperatorAttributeValue = - std::variant, - stack_vector, - OperatorType, - Activation, - ff_dim_t, - unsigned long long, - AggregateOp, - stack_vector, - std::optional, - PoolOp, - TensorShape, - DataType>; - -FF_VISITABLE_STRUCT(ListIndexAccess, - attribute_key, - index); -FF_VISITABLE_STRUCT(ListSize, attribute_key); - -using OperatorAttributeConstraint = - AttributeConstraint; - -using OperatorPattern = - AttributePattern; - -std::optional - evaluate_attribute_expr(Operator const &attrs, - AttributeExpr const &expr); - -} // namespace FlexFlow - -#endif diff --git a/lib/substitutions/include/substitutions/operator_pattern/eval_list_access.h b/lib/substitutions/include/substitutions/operator_pattern/eval_list_access.h new file mode 100644 index 0000000000..93d2d56384 --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/eval_list_access.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_LIST_ACCESS_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_LIST_ACCESS_H + +#include "op-attrs/pcg_operator_attrs.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_list_access.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_value.dtg.h" +#include + +namespace FlexFlow { + +std::optional + eval_list_access(PCGOperatorAttrs const &attrs, + OperatorAttributeListIndexAccess const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/operator_pattern/eval_list_size.h b/lib/substitutions/include/substitutions/operator_pattern/eval_list_size.h new file mode 100644 index 0000000000..236a248945 --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/eval_list_size.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_LIST_SIZE_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_LIST_SIZE_H + +#include "op-attrs/pcg_operator_attrs.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_list_size.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_value.dtg.h" + +namespace FlexFlow { + +std::optional + eval_list_size(PCGOperatorAttrs const &attrs, + OperatorAttributeListSize const &acc); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/get_attribute.h b/lib/substitutions/include/substitutions/operator_pattern/get_attribute.h similarity index 83% rename from lib/substitutions/include/substitutions/get_attribute.h rename to lib/substitutions/include/substitutions/operator_pattern/get_attribute.h index 0e6dd4c69b..93f4a2bc0f 100644 --- a/lib/substitutions/include/substitutions/get_attribute.h +++ b/lib/substitutions/include/substitutions/operator_pattern/get_attribute.h @@ -1,9 +1,10 @@ #ifndef _FLEXFLOW_SUBSTITUTIONS_GET_ATTRIBUTES_H #define _FLEXFLOW_SUBSTITUTIONS_GET_ATTRIBUTES_H -#include "op-attrs/operator_attrs.h" -#include "operator_pattern.h" -#include "utils/optional.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_key.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_value.dtg.h" +#include namespace FlexFlow { @@ -11,6 +12,8 @@ std::optional get_attribute(PCGOperatorAttrs const &, OperatorAttributeKey); std::optional get_attribute(BatchMatmulAttrs const &p, OperatorAttributeKey); +std::optional get_attribute(BatchNormAttrs const &p, + OperatorAttributeKey); std::optional get_attribute(CastAttrs const &p, OperatorAttributeKey); std::optional get_attribute(CombineAttrs const &p, @@ -33,12 +36,17 @@ std::optional get_attribute(FlatAttrs const &p, OperatorAttributeKey); std::optional get_attribute(GatherAttrs const &p, OperatorAttributeKey); +std::optional get_attribute(InputAttrs const &p, + OperatorAttributeKey); std::optional get_attribute(LayerNormAttrs const &p, OperatorAttributeKey); std::optional get_attribute(LinearAttrs const &p, OperatorAttributeKey); std::optional get_attribute(MultiHeadAttentionAttrs const &p, OperatorAttributeKey); + +std::optional get_attribute(NoopAttrs const &p, + OperatorAttributeKey); std::optional get_attribute(Pool2DAttrs const &p, OperatorAttributeKey); std::optional get_attribute(ReduceAttrs const &p, @@ -51,6 +59,8 @@ std::optional get_attribute(ReplicateAttrs const &p, OperatorAttributeKey); std::optional get_attribute(ReshapeAttrs const &p, OperatorAttributeKey); +std::optional get_attribute(ReverseAttrs const &p, + OperatorAttributeKey); std::optional get_attribute(SplitAttrs const &p, OperatorAttributeKey); std::optional get_attribute(SoftmaxAttrs const &p, diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.dtg.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.dtg.h new file mode 100644 index 0000000000..35ec9e499f --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.dtg.h @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.struct.toml +/* proj-data +{ + "generated_from": "7867bd0f403866c13417171bb5ec364c" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_CONSTRAINT_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_CONSTRAINT_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "substitutions/constraint_type.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_expr.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_value.dtg.h" +#include +#include +#include + +namespace FlexFlow { +struct OperatorAttributeConstraint { + OperatorAttributeConstraint() = delete; + OperatorAttributeConstraint( + ::FlexFlow::ConstraintType const &constraint_type, + ::FlexFlow::OperatorAttributeExpr const &attribute_expr, + ::FlexFlow::OperatorAttributeValue const &attribute_value); + + bool operator==(OperatorAttributeConstraint const &) const; + bool operator!=(OperatorAttributeConstraint const &) const; + bool operator<(OperatorAttributeConstraint const &) const; + bool operator>(OperatorAttributeConstraint const &) const; + bool operator<=(OperatorAttributeConstraint const &) const; + bool operator>=(OperatorAttributeConstraint const &) const; + ::FlexFlow::ConstraintType constraint_type; + ::FlexFlow::OperatorAttributeExpr attribute_expr; + ::FlexFlow::OperatorAttributeValue attribute_value; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::OperatorAttributeConstraint const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::OperatorAttributeConstraint from_json(json const &); + static void to_json(json &, FlexFlow::OperatorAttributeConstraint const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(OperatorAttributeConstraint const &); +std::ostream &operator<<(std::ostream &, OperatorAttributeConstraint const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_CONSTRAINT_DTG_H diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.struct.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.struct.toml new file mode 100644 index 0000000000..646faf878e --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.struct.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "OperatorAttributeConstraint" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "substitutions/constraint_type.dtg.h", + "substitutions/operator_pattern/operator_attribute_expr.dtg.h", + "substitutions/operator_pattern/operator_attribute_value.dtg.h", +] + +[[fields]] +name = "constraint_type" +type = "::FlexFlow::ConstraintType" + +[[fields]] +name = "attribute_expr" +type = "::FlexFlow::OperatorAttributeExpr" + +[[fields]] +name = "attribute_value" +type = "::FlexFlow::OperatorAttributeValue" diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.dtg.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.dtg.h new file mode 100644 index 0000000000..a66a035ba8 --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.dtg.h @@ -0,0 +1,143 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.variant.toml +/* proj-data +{ + "generated_from": "15d26dd1f08092ecc82b725aa9411597" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_EXPR_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_EXPR_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "substitutions/operator_pattern/operator_attribute_key.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_list_access.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_list_size.dtg.h" +#include +#include +#include +#include +#include + +namespace FlexFlow { +struct OperatorAttributeExpr { + OperatorAttributeExpr() = delete; + explicit OperatorAttributeExpr(::FlexFlow::OperatorAttributeKey const &); + explicit OperatorAttributeExpr(::FlexFlow::OperatorAttributeListSize const &); + explicit OperatorAttributeExpr( + ::FlexFlow::OperatorAttributeListIndexAccess const &); + template + static constexpr bool IsPartOfOperatorAttributeExpr_v = + std::is_same_v || + std::is_same_v || + std::is_same_v; + template + ReturnType visit(Visitor &&v) const { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::OperatorAttributeKey>()); + return result; + } + case 1: { + ReturnType result = + v(this->get<::FlexFlow::OperatorAttributeListSize>()); + return result; + } + case 2: { + ReturnType result = + v(this->get<::FlexFlow::OperatorAttributeListIndexAccess>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type OperatorAttributeExpr", this->index())); + } + } + } + template + ReturnType visit(Visitor &&v) { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::OperatorAttributeKey>()); + return result; + } + case 1: { + ReturnType result = + v(this->get<::FlexFlow::OperatorAttributeListSize>()); + return result; + } + case 2: { + ReturnType result = + v(this->get<::FlexFlow::OperatorAttributeListIndexAccess>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type OperatorAttributeExpr", this->index())); + } + } + } + template + bool has() const { + static_assert(IsPartOfOperatorAttributeExpr_v, + "OperatorAttributeExpr::has() expected one of " + "[::FlexFlow::OperatorAttributeKey, " + "::FlexFlow::OperatorAttributeListSize, " + "::FlexFlow::OperatorAttributeListIndexAccess], received T"); + return std::holds_alternative(this->raw_variant); + } + template + T const &get() const { + static_assert(IsPartOfOperatorAttributeExpr_v, + "OperatorAttributeExpr::get() expected one of " + "[::FlexFlow::OperatorAttributeKey, " + "::FlexFlow::OperatorAttributeListSize, " + "::FlexFlow::OperatorAttributeListIndexAccess], received T"); + return std::get(this->raw_variant); + } + template + T &get() { + static_assert(IsPartOfOperatorAttributeExpr_v, + "OperatorAttributeExpr::get() expected one of " + "[::FlexFlow::OperatorAttributeKey, " + "::FlexFlow::OperatorAttributeListSize, " + "::FlexFlow::OperatorAttributeListIndexAccess], received T"); + return std::get(this->raw_variant); + } + size_t index() const { + return this->raw_variant.index(); + } + bool operator==(OperatorAttributeExpr const &) const; + bool operator!=(OperatorAttributeExpr const &) const; + bool operator<(OperatorAttributeExpr const &) const; + bool operator>(OperatorAttributeExpr const &) const; + bool operator<=(OperatorAttributeExpr const &) const; + bool operator>=(OperatorAttributeExpr const &) const; + std::variant<::FlexFlow::OperatorAttributeKey, + ::FlexFlow::OperatorAttributeListSize, + ::FlexFlow::OperatorAttributeListIndexAccess> + raw_variant; +}; +} // namespace FlexFlow +namespace std { +template <> +struct hash<::FlexFlow::OperatorAttributeExpr> { + size_t operator()(::FlexFlow::OperatorAttributeExpr const &) const; +}; +} // namespace std +namespace nlohmann { +template <> +struct adl_serializer<::FlexFlow::OperatorAttributeExpr> { + static ::FlexFlow::OperatorAttributeExpr from_json(json const &); + static void to_json(json &, ::FlexFlow::OperatorAttributeExpr const &); +}; +} // namespace nlohmann +namespace FlexFlow { +std::string format_as(::FlexFlow::OperatorAttributeExpr const &); +std::ostream &operator<<(std::ostream &, + ::FlexFlow::OperatorAttributeExpr const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_EXPR_DTG_H diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.h new file mode 100644 index 0000000000..4528847771 --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_EXPR_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_EXPR_H + +#include "pcg/parallel_layer_attrs.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_expr.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_value.dtg.h" +#include + +namespace FlexFlow { + +std::optional + evaluate_attribute_expr(PCGOperatorAttrs const &attrs, + OperatorAttributeExpr const &expr); +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.variant.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.variant.toml new file mode 100644 index 0000000000..ff79ecaaa5 --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.variant.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "OperatorAttributeExpr" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "substitutions/operator_pattern/operator_attribute_key.dtg.h", + "substitutions/operator_pattern/operator_attribute_list_access.dtg.h", + "substitutions/operator_pattern/operator_attribute_list_size.dtg.h", +] + +[[values]] +type = "::FlexFlow::OperatorAttributeKey" +key = "key" + +[[values]] +type = "::FlexFlow::OperatorAttributeListSize" +key = "list_size" + +[[values]] +type = "::FlexFlow::OperatorAttributeListIndexAccess" +key = "list_idx" diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.dtg.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.dtg.h new file mode 100644 index 0000000000..49a5ccbbe6 --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.dtg.h @@ -0,0 +1,97 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml +/* proj-data +{ + "generated_from": "e637388397720b328b1f4b9ba6b14611" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_KEY_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_KEY_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +enum class OperatorAttributeKey { + OP_TYPE, + USE_BIAS, + GROUPS, + POOL_TYPE, + KERNEL_H, + KERNEL_W, + DATA_TYPE, + SCALAR, + STRIDE_H, + STRIDE_W, + PADDING_H, + PADDING_W, + AGGR, + NUM_ENTRIES, + OUT_CHANNELS, + ACTIVATION, + NUMDIM, + AXIS, + PERMUTATION, + OUTSHUFFLE, + MERGE_GCONV_COUNT, + AXES, + KEEP_DIMS, + EPSILON, + PARALLEL_OP_DIM, + PARALLEL_OP_DEGREE, + SOFTMAX_DIM, + NUM_HEADS, + PARALLEL_DIM, + PARALLEL_DEGREE, + PAD, + EMBED_DIM, + KDIM, + VDIM, + DROPOUT, + BIAS, + ADD_BIAS_KV, + ADD_ZERO_ATTN, + A_SEQ_LENGTH_DIM, + B_SEQ_LENGTH_DIM, + RELU, + TARGET_DIMS, + RATE, + SEED, + SHOULD_BROADCAST_LHS, + SHOULD_BROADCAST_RHS, + DIM, + ELEMENTWISE_AFFINE, + REGULARIZER, + SHAPE, + SPLITS, + K, + SORTED, + COMBINE_DIM, + COMBINE_DEGREE, + NUM_INPUTS +}; +std::string format_as(OperatorAttributeKey); +std::ostream &operator<<(std::ostream &, OperatorAttributeKey); +void to_json(::nlohmann::json &, OperatorAttributeKey); +void from_json(::nlohmann::json const &, OperatorAttributeKey &); +} // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::OperatorAttributeKey) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_KEY_DTG_H diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml new file mode 100644 index 0000000000..59e913750e --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml @@ -0,0 +1,67 @@ +namespace = "FlexFlow" +name = "OperatorAttributeKey" +features = [ + "json", + "hash", + "fmt", + "rapidcheck", +] + +values = [ + { name = "OP_TYPE" }, + { name = "USE_BIAS" }, + { name = "GROUPS" }, + { name = "POOL_TYPE" }, + { name = "KERNEL_H" }, + { name = "KERNEL_W" }, + { name = "DATA_TYPE" }, + { name = "SCALAR" }, + { name = "STRIDE_H" }, + { name = "STRIDE_W" }, + { name = "PADDING_H" }, + { name = "PADDING_W" }, + { name = "AGGR" }, + { name = "NUM_ENTRIES" }, + { name = "OUT_CHANNELS" }, + { name = "ACTIVATION" }, + { name = "NUMDIM" }, + { name = "AXIS" }, + { name = "PERMUTATION" }, + { name = "OUTSHUFFLE" }, + { name = "MERGE_GCONV_COUNT" }, + { name = "AXES" }, + { name = "KEEP_DIMS" }, + { name = "EPSILON" }, + { name = "PARALLEL_OP_DIM" }, + { name = "PARALLEL_OP_DEGREE" }, + { name = "SOFTMAX_DIM" }, + { name = "NUM_HEADS" }, + { name = "PARALLEL_DIM" }, + { name = "PARALLEL_DEGREE" }, + { name = "PAD" }, + { name = "EMBED_DIM" }, + { name = "KDIM" }, + { name = "VDIM" }, + { name = "DROPOUT" }, + { name = "BIAS" }, + { name = "ADD_BIAS_KV" }, + { name = "ADD_ZERO_ATTN" }, + { name = "A_SEQ_LENGTH_DIM" }, + { name = "B_SEQ_LENGTH_DIM" }, + { name = "RELU" }, + { name = "TARGET_DIMS" }, + { name = "RATE" }, + { name = "SEED" }, + { name = "SHOULD_BROADCAST_LHS" }, + { name = "SHOULD_BROADCAST_RHS" }, + { name = "DIM" }, + { name = "ELEMENTWISE_AFFINE" }, + { name = "REGULARIZER" }, + { name = "SHAPE" }, + { name = "SPLITS" }, + { name = "K" }, + { name = "SORTED" }, + { name = "COMBINE_DIM" }, + { name = "COMBINE_DEGREE" }, + { name = "NUM_INPUTS" }, +] diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_access.dtg.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_access.dtg.h new file mode 100644 index 0000000000..5a30c40f8d --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_access.dtg.h @@ -0,0 +1,67 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_access.struct.toml +/* proj-data +{ + "generated_from": "1dc90d1e823f05b82c1a5ff433fbf000" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_LIST_ACCESS_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_LIST_ACCESS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include "substitutions/operator_pattern/operator_attribute_key.dtg.h" +#include +#include +#include + +namespace FlexFlow { +struct OperatorAttributeListIndexAccess { + OperatorAttributeListIndexAccess() = delete; + OperatorAttributeListIndexAccess( + ::FlexFlow::OperatorAttributeKey const &attribute_key, int const &index); + + bool operator==(OperatorAttributeListIndexAccess const &) const; + bool operator!=(OperatorAttributeListIndexAccess const &) const; + bool operator<(OperatorAttributeListIndexAccess const &) const; + bool operator>(OperatorAttributeListIndexAccess const &) const; + bool operator<=(OperatorAttributeListIndexAccess const &) const; + bool operator>=(OperatorAttributeListIndexAccess const &) const; + ::FlexFlow::OperatorAttributeKey attribute_key; + int index; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::OperatorAttributeListIndexAccess const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::OperatorAttributeListIndexAccess from_json(json const &); + static void to_json(json &, + FlexFlow::OperatorAttributeListIndexAccess const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(OperatorAttributeListIndexAccess const &); +std::ostream &operator<<(std::ostream &, + OperatorAttributeListIndexAccess const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_LIST_ACCESS_DTG_H diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_access.struct.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_access.struct.toml new file mode 100644 index 0000000000..bceff393d2 --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_access.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "OperatorAttributeListIndexAccess" +features = [ + "eq", + "ord", + "hash", + "rapidcheck", + "json", + "fmt", +] + +includes = [ + "substitutions/operator_pattern/operator_attribute_key.dtg.h" +] + +[[fields]] +name = "attribute_key" +type = "::FlexFlow::OperatorAttributeKey" + +[[fields]] +name = "index" +type = "int" diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_size.dtg.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_size.dtg.h new file mode 100644 index 0000000000..17d76a08f1 --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_size.dtg.h @@ -0,0 +1,64 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_size.struct.toml +/* proj-data +{ + "generated_from": "30999ad6b0603e380bc33d32fa088e45" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_LIST_SIZE_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_LIST_SIZE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include "substitutions/operator_pattern/operator_attribute_key.dtg.h" +#include +#include +#include + +namespace FlexFlow { +struct OperatorAttributeListSize { + OperatorAttributeListSize() = delete; + OperatorAttributeListSize( + ::FlexFlow::OperatorAttributeKey const &attribute_key); + + bool operator==(OperatorAttributeListSize const &) const; + bool operator!=(OperatorAttributeListSize const &) const; + bool operator<(OperatorAttributeListSize const &) const; + bool operator>(OperatorAttributeListSize const &) const; + bool operator<=(OperatorAttributeListSize const &) const; + bool operator>=(OperatorAttributeListSize const &) const; + ::FlexFlow::OperatorAttributeKey attribute_key; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::OperatorAttributeListSize const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::OperatorAttributeListSize from_json(json const &); + static void to_json(json &, FlexFlow::OperatorAttributeListSize const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(OperatorAttributeListSize const &); +std::ostream &operator<<(std::ostream &, OperatorAttributeListSize const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_LIST_SIZE_DTG_H diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_size.struct.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_size.struct.toml new file mode 100644 index 0000000000..271b545fda --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_size.struct.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "OperatorAttributeListSize" +features = [ + "eq", + "ord", + "hash", + "rapidcheck", + "json", + "fmt", +] + +includes = [ + "substitutions/operator_pattern/operator_attribute_key.dtg.h", +] + + +[[fields]] +name = "attribute_key" +type = "::FlexFlow::OperatorAttributeKey" diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.dtg.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.dtg.h new file mode 100644 index 0000000000..7bce198f3d --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.dtg.h @@ -0,0 +1,56 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.struct.toml +/* proj-data +{ + "generated_from": "968d7a3e93303a7fa7482bbcd50246b6" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_PATTERN_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_PATTERN_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "substitutions/operator_pattern/operator_attribute_constraint.dtg.h" +#include "utils/fmt.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct OperatorAttributePattern { + OperatorAttributePattern() = delete; + OperatorAttributePattern( + std::unordered_set<::FlexFlow::OperatorAttributeConstraint> const + &attribute_constraints); + + bool operator==(OperatorAttributePattern const &) const; + bool operator!=(OperatorAttributePattern const &) const; + std::unordered_set<::FlexFlow::OperatorAttributeConstraint> + attribute_constraints; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::OperatorAttributePattern const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::OperatorAttributePattern from_json(json const &); + static void to_json(json &, FlexFlow::OperatorAttributePattern const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(OperatorAttributePattern const &); +std::ostream &operator<<(std::ostream &, OperatorAttributePattern const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_PATTERN_DTG_H diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.struct.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.struct.toml new file mode 100644 index 0000000000..6facf7d3bc --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.struct.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "OperatorAttributePattern" +features = [ + "eq", + # "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "", + "utils/fmt.h", + "substitutions/operator_pattern/operator_attribute_constraint.dtg.h", +] + +[[fields]] +name = "attribute_constraints" +type = "std::unordered_set<::FlexFlow::OperatorAttributeConstraint>" diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.dtg.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.dtg.h new file mode 100644 index 0000000000..080909d147 --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.dtg.h @@ -0,0 +1,264 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml +/* proj-data +{ + "generated_from": "de14592f1f4bcfb52689bc95e9d3b55f" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_VALUE_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_VALUE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/activation.dtg.h" +#include "op-attrs/aggregate_op.dtg.h" +#include "op-attrs/datatype.dtg.h" +#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/operator_type.dtg.h" +#include "op-attrs/pool_op.dtg.h" +#include "op-attrs/regularizer_attrs.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" +#include +#include +#include +#include +#include +#include +#include + +namespace FlexFlow { +struct OperatorAttributeValue { + OperatorAttributeValue() = delete; + OperatorAttributeValue(int const &); + OperatorAttributeValue(bool const &); + OperatorAttributeValue(std::vector const &); + OperatorAttributeValue(std::vector<::FlexFlow::ff_dim_t> const &); + OperatorAttributeValue(::FlexFlow::OperatorType const &); + OperatorAttributeValue(::FlexFlow::Activation const &); + OperatorAttributeValue(::FlexFlow::ff_dim_t const &); + OperatorAttributeValue(size_t const &); + OperatorAttributeValue(::FlexFlow::AggregateOp const &); + OperatorAttributeValue(std::optional<::FlexFlow::RegularizerAttrs> const &); + OperatorAttributeValue(::FlexFlow::PoolOp const &); + OperatorAttributeValue(::FlexFlow::TensorShape const &); + OperatorAttributeValue(::FlexFlow::DataType const &); + template + static constexpr bool IsPartOfOperatorAttributeValue_v = + std::is_same_v || std::is_same_v || + std::is_same_v> || + std::is_same_v> || + std::is_same_v || + std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || + std::is_same_v> || + std::is_same_v || + std::is_same_v || + std::is_same_v; + template + ReturnType visit(Visitor &&v) const { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get()); + return result; + } + case 1: { + ReturnType result = v(this->get()); + return result; + } + case 2: { + ReturnType result = v(this->get>()); + return result; + } + case 3: { + ReturnType result = v(this->get>()); + return result; + } + case 4: { + ReturnType result = v(this->get<::FlexFlow::OperatorType>()); + return result; + } + case 5: { + ReturnType result = v(this->get<::FlexFlow::Activation>()); + return result; + } + case 6: { + ReturnType result = v(this->get<::FlexFlow::ff_dim_t>()); + return result; + } + case 7: { + ReturnType result = v(this->get()); + return result; + } + case 8: { + ReturnType result = v(this->get<::FlexFlow::AggregateOp>()); + return result; + } + case 9: { + ReturnType result = + v(this->get>()); + return result; + } + case 10: { + ReturnType result = v(this->get<::FlexFlow::PoolOp>()); + return result; + } + case 11: { + ReturnType result = v(this->get<::FlexFlow::TensorShape>()); + return result; + } + case 12: { + ReturnType result = v(this->get<::FlexFlow::DataType>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type OperatorAttributeValue", this->index())); + } + } + } + template + ReturnType visit(Visitor &&v) { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get()); + return result; + } + case 1: { + ReturnType result = v(this->get()); + return result; + } + case 2: { + ReturnType result = v(this->get>()); + return result; + } + case 3: { + ReturnType result = v(this->get>()); + return result; + } + case 4: { + ReturnType result = v(this->get<::FlexFlow::OperatorType>()); + return result; + } + case 5: { + ReturnType result = v(this->get<::FlexFlow::Activation>()); + return result; + } + case 6: { + ReturnType result = v(this->get<::FlexFlow::ff_dim_t>()); + return result; + } + case 7: { + ReturnType result = v(this->get()); + return result; + } + case 8: { + ReturnType result = v(this->get<::FlexFlow::AggregateOp>()); + return result; + } + case 9: { + ReturnType result = + v(this->get>()); + return result; + } + case 10: { + ReturnType result = v(this->get<::FlexFlow::PoolOp>()); + return result; + } + case 11: { + ReturnType result = v(this->get<::FlexFlow::TensorShape>()); + return result; + } + case 12: { + ReturnType result = v(this->get<::FlexFlow::DataType>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type OperatorAttributeValue", this->index())); + } + } + } + template + bool has() const { + static_assert( + IsPartOfOperatorAttributeValue_v, + "OperatorAttributeValue::has() expected one of [int, bool, " + "std::vector, std::vector<::FlexFlow::ff_dim_t>, " + "::FlexFlow::OperatorType, ::FlexFlow::Activation, " + "::FlexFlow::ff_dim_t, size_t, ::FlexFlow::AggregateOp, " + "std::optional<::FlexFlow::RegularizerAttrs>, ::FlexFlow::PoolOp, " + "::FlexFlow::TensorShape, ::FlexFlow::DataType], received T"); + return std::holds_alternative(this->raw_variant); + } + template + T const &get() const { + static_assert( + IsPartOfOperatorAttributeValue_v, + "OperatorAttributeValue::get() expected one of [int, bool, " + "std::vector, std::vector<::FlexFlow::ff_dim_t>, " + "::FlexFlow::OperatorType, ::FlexFlow::Activation, " + "::FlexFlow::ff_dim_t, size_t, ::FlexFlow::AggregateOp, " + "std::optional<::FlexFlow::RegularizerAttrs>, ::FlexFlow::PoolOp, " + "::FlexFlow::TensorShape, ::FlexFlow::DataType], received T"); + return std::get(this->raw_variant); + } + template + T &get() { + static_assert( + IsPartOfOperatorAttributeValue_v, + "OperatorAttributeValue::get() expected one of [int, bool, " + "std::vector, std::vector<::FlexFlow::ff_dim_t>, " + "::FlexFlow::OperatorType, ::FlexFlow::Activation, " + "::FlexFlow::ff_dim_t, size_t, ::FlexFlow::AggregateOp, " + "std::optional<::FlexFlow::RegularizerAttrs>, ::FlexFlow::PoolOp, " + "::FlexFlow::TensorShape, ::FlexFlow::DataType], received T"); + return std::get(this->raw_variant); + } + size_t index() const { + return this->raw_variant.index(); + } + bool operator==(OperatorAttributeValue const &) const; + bool operator!=(OperatorAttributeValue const &) const; + bool operator<(OperatorAttributeValue const &) const; + bool operator>(OperatorAttributeValue const &) const; + bool operator<=(OperatorAttributeValue const &) const; + bool operator>=(OperatorAttributeValue const &) const; + std::variant, + std::vector<::FlexFlow::ff_dim_t>, + ::FlexFlow::OperatorType, + ::FlexFlow::Activation, + ::FlexFlow::ff_dim_t, + size_t, + ::FlexFlow::AggregateOp, + std::optional<::FlexFlow::RegularizerAttrs>, + ::FlexFlow::PoolOp, + ::FlexFlow::TensorShape, + ::FlexFlow::DataType> + raw_variant; +}; +} // namespace FlexFlow +namespace std { +template <> +struct hash<::FlexFlow::OperatorAttributeValue> { + size_t operator()(::FlexFlow::OperatorAttributeValue const &) const; +}; +} // namespace std +namespace nlohmann { +template <> +struct adl_serializer<::FlexFlow::OperatorAttributeValue> { + static ::FlexFlow::OperatorAttributeValue from_json(json const &); + static void to_json(json &, ::FlexFlow::OperatorAttributeValue const &); +}; +} // namespace nlohmann +namespace FlexFlow { +std::string format_as(::FlexFlow::OperatorAttributeValue const &); +std::ostream &operator<<(std::ostream &, + ::FlexFlow::OperatorAttributeValue const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_VALUE_DTG_H diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml new file mode 100644 index 0000000000..9ab88e63c2 --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml @@ -0,0 +1,63 @@ +namespace = "FlexFlow" +name = "OperatorAttributeValue" +features = [ + "eq", + "ord", + "hash", + "fmt", + "json", +] +explicit_constructors = false + +includes = [ + "", + "", + "op-attrs/operator_type.dtg.h", + "op-attrs/ff_dim.dtg.h", + "op-attrs/activation.dtg.h", + "op-attrs/aggregate_op.dtg.h", + "op-attrs/regularizer_attrs.dtg.h", + "op-attrs/pool_op.dtg.h", + "op-attrs/tensor_shape.dtg.h", + "op-attrs/datatype.dtg.h", + "", +] + +[[values]] +type = "int" + +[[values]] +type = "bool" + +[[values]] +type = "std::vector" + +[[values]] +type = "std::vector<::FlexFlow::ff_dim_t>" + +[[values]] +type = "::FlexFlow::OperatorType" + +[[values]] +type = "::FlexFlow::Activation" + +[[values]] +type = "::FlexFlow::ff_dim_t" + +[[values]] +type = "size_t" + +[[values]] +type = "::FlexFlow::AggregateOp" + +[[values]] +type = "std::optional<::FlexFlow::RegularizerAttrs>" + +[[values]] +type = "::FlexFlow::PoolOp" + +[[values]] +type = "::FlexFlow::TensorShape" + +[[values]] +type = "::FlexFlow::DataType" diff --git a/lib/substitutions/include/substitutions/operator_pattern/satisfies_constraint.h b/lib/substitutions/include/substitutions/operator_pattern/satisfies_constraint.h new file mode 100644 index 0000000000..7ddda2219c --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/satisfies_constraint.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_SATISFIES_CONSTRAINT_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_SATISFIES_CONSTRAINT_H + +#include "op-attrs/pcg_operator_attrs.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_constraint.dtg.h" + +namespace FlexFlow { + +bool operator_satisfies_constraint( + PCGOperatorAttrs const ¶ms, + OperatorAttributeConstraint const &constraint); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/operator_pattern/satisfies_pattern.h b/lib/substitutions/include/substitutions/operator_pattern/satisfies_pattern.h new file mode 100644 index 0000000000..ca4d5c13fa --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/satisfies_pattern.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_SATISFIES_PATTERN_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_SATISFIES_PATTERN_H + +#include "op-attrs/pcg_operator_attrs.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_pattern.dtg.h" + +namespace FlexFlow { + +bool operator_satisfies_pattern(PCGOperatorAttrs const &attrs, + OperatorAttributePattern const &pattern); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/output_graph.h b/lib/substitutions/include/substitutions/output_graph.h deleted file mode 100644 index 4ed90aed06..0000000000 --- a/lib/substitutions/include/substitutions/output_graph.h +++ /dev/null @@ -1,35 +0,0 @@ -#ifndef _FLEXFLOW_SUBSTITUTIONS_OUTPUT_GRAPH_H -#define _FLEXFLOW_SUBSTITUTIONS_OUTPUT_GRAPH_H - -#include "utils/graph.h" - -namespace FlexFlow { - -// NOTE(@wmdi) I am not sure whether these should be part of attribute expr. -struct OperatorAttrAccess { - Node node; - AttributeExpr attr_expr; -}; - -struct AttrConstant { - OperatorAttributeValue value; -}; - -using OperatorAttributeExpr = std::variant; - -// NOTE(@wmdi): Not sure if it aligns with other design. Or alternatively we can -// define the assignment for each operator type. -struct OperatorAttrAssignment { - std::unordered_map assignments; -}; - -struct OutputGraphExpr - : public strong_typedef< - OutputGraphExpr, - NodeLabelledOpenMultiDiGraph> { - using strong_typedef::strong_typedef; -}; - -} // namespace FlexFlow - -#endif diff --git a/lib/substitutions/include/substitutions/output_graph/attr_constant.dtg.h b/lib/substitutions/include/substitutions/output_graph/attr_constant.dtg.h new file mode 100644 index 0000000000..9dd20bb10e --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/attr_constant.dtg.h @@ -0,0 +1,46 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/output_graph/attr_constant.struct.toml +/* proj-data +{ + "generated_from": "1e5beabcb8e3657d8fe9c9c8b1310cb1" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_ATTR_CONSTANT_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_ATTR_CONSTANT_DTG_H + +#include "fmt/format.h" +#include "substitutions/operator_pattern/operator_attribute_value.dtg.h" +#include +#include +#include + +namespace FlexFlow { +struct AttrConstant { + AttrConstant() = delete; + AttrConstant(::FlexFlow::OperatorAttributeValue const &value); + + bool operator==(AttrConstant const &) const; + bool operator!=(AttrConstant const &) const; + bool operator<(AttrConstant const &) const; + bool operator>(AttrConstant const &) const; + bool operator<=(AttrConstant const &) const; + bool operator>=(AttrConstant const &) const; + ::FlexFlow::OperatorAttributeValue value; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::AttrConstant const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(AttrConstant const &); +std::ostream &operator<<(std::ostream &, AttrConstant const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_ATTR_CONSTANT_DTG_H diff --git a/lib/substitutions/include/substitutions/output_graph/attr_constant.struct.toml b/lib/substitutions/include/substitutions/output_graph/attr_constant.struct.toml new file mode 100644 index 0000000000..68973f9c0c --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/attr_constant.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "AttrConstant" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "substitutions/operator_pattern/operator_attribute_value.dtg.h", +] + +[[fields]] +name = "value" +type = "::FlexFlow::OperatorAttributeValue" diff --git a/lib/substitutions/include/substitutions/output_graph/output_graph_expr.dtg.h b/lib/substitutions/include/substitutions/output_graph/output_graph_expr.dtg.h new file mode 100644 index 0000000000..3d6fb21574 --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/output_graph_expr.dtg.h @@ -0,0 +1,28 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/output_graph/output_graph_expr.struct.toml +/* proj-data +{ + "generated_from": "9084c9afb2724504a6f4db4288a83a0d" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_GRAPH_EXPR_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_GRAPH_EXPR_DTG_H + +#include "substitutions/output_graph/output_operator_attrs_assignment.dtg.h" +#include "utils/graph.h" + +namespace FlexFlow { +struct OutputGraphExpr { + OutputGraphExpr() = delete; + OutputGraphExpr(::FlexFlow::NodeLabelledOpenMultiDiGraph< + ::FlexFlow::OutputOperatorAttrsAssignment> const &raw_graph); + + ::FlexFlow::NodeLabelledOpenMultiDiGraph< + ::FlexFlow::OutputOperatorAttrsAssignment> + raw_graph; +}; +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_GRAPH_EXPR_DTG_H diff --git a/lib/substitutions/include/substitutions/output_graph/output_graph_expr.struct.toml b/lib/substitutions/include/substitutions/output_graph/output_graph_expr.struct.toml new file mode 100644 index 0000000000..37d87f7820 --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/output_graph_expr.struct.toml @@ -0,0 +1,12 @@ +namespace = "FlexFlow" +name = "OutputGraphExpr" +features = [] + +includes = [ + "utils/graph.h", + "substitutions/output_graph/output_operator_attrs_assignment.dtg.h", +] + +[[fields]] +name = "raw_graph" +type = "::FlexFlow::NodeLabelledOpenMultiDiGraph<::FlexFlow::OutputOperatorAttrsAssignment>" diff --git a/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.dtg.h b/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.dtg.h new file mode 100644 index 0000000000..0d585f0aa0 --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.dtg.h @@ -0,0 +1,49 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.struct.toml +/* proj-data +{ + "generated_from": "e3b3a741183fcb38cfa68aacb82e12d1" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTR_ACCESS_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTR_ACCESS_DTG_H + +#include "fmt/format.h" +#include "substitutions/operator_pattern/operator_attribute_expr.dtg.h" +#include "utils/graph.h" +#include +#include +#include + +namespace FlexFlow { +struct OutputOperatorAttrAccess { + OutputOperatorAttrAccess() = delete; + OutputOperatorAttrAccess(::FlexFlow::Node const &node, + ::FlexFlow::OperatorAttributeExpr const &attr_expr); + + bool operator==(OutputOperatorAttrAccess const &) const; + bool operator!=(OutputOperatorAttrAccess const &) const; + bool operator<(OutputOperatorAttrAccess const &) const; + bool operator>(OutputOperatorAttrAccess const &) const; + bool operator<=(OutputOperatorAttrAccess const &) const; + bool operator>=(OutputOperatorAttrAccess const &) const; + ::FlexFlow::Node node; + ::FlexFlow::OperatorAttributeExpr attr_expr; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::OutputOperatorAttrAccess const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(OutputOperatorAttrAccess const &); +std::ostream &operator<<(std::ostream &, OutputOperatorAttrAccess const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTR_ACCESS_DTG_H diff --git a/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.struct.toml b/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.struct.toml new file mode 100644 index 0000000000..51aae54730 --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "OutputOperatorAttrAccess" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph.h", + "substitutions/operator_pattern/operator_attribute_expr.dtg.h", +] + +[[fields]] +name = "node" +type = "::FlexFlow::Node" + +# NOTE(@wmdi) I am not sure whether these should be part of attribute expr. +[[fields]] +name = "attr_expr" +type = "::FlexFlow::OperatorAttributeExpr" + diff --git a/lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.dtg.h b/lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.dtg.h new file mode 100644 index 0000000000..327c230b61 --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.dtg.h @@ -0,0 +1,119 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.variant.toml +/* proj-data +{ + "generated_from": "89ebf777a5b909eef78ab5a5a177e041" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTRIBUTE_EXPR_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTRIBUTE_EXPR_DTG_H + +#include "fmt/format.h" +#include "substitutions/output_graph/attr_constant.dtg.h" +#include "substitutions/output_graph/output_operator_attr_access.dtg.h" +#include +#include +#include +#include +#include + +namespace FlexFlow { +struct OutputOperatorAttributeExpr { + OutputOperatorAttributeExpr() = delete; + explicit OutputOperatorAttributeExpr( + ::FlexFlow::OutputOperatorAttrAccess const &); + explicit OutputOperatorAttributeExpr(::FlexFlow::AttrConstant const &); + template + static constexpr bool IsPartOfOutputOperatorAttributeExpr_v = + std::is_same_v || + std::is_same_v; + template + ReturnType visit(Visitor &&v) const { + switch (this->index()) { + case 0: { + ReturnType result = + v(this->get<::FlexFlow::OutputOperatorAttrAccess>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::AttrConstant>()); + return result; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type OutputOperatorAttributeExpr", + this->index())); + } + } + } + template + ReturnType visit(Visitor &&v) { + switch (this->index()) { + case 0: { + ReturnType result = + v(this->get<::FlexFlow::OutputOperatorAttrAccess>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::AttrConstant>()); + return result; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type OutputOperatorAttributeExpr", + this->index())); + } + } + } + template + bool has() const { + static_assert(IsPartOfOutputOperatorAttributeExpr_v, + "OutputOperatorAttributeExpr::has() expected one of " + "[::FlexFlow::OutputOperatorAttrAccess, " + "::FlexFlow::AttrConstant], received T"); + return std::holds_alternative(this->raw_variant); + } + template + T const &get() const { + static_assert(IsPartOfOutputOperatorAttributeExpr_v, + "OutputOperatorAttributeExpr::get() expected one of " + "[::FlexFlow::OutputOperatorAttrAccess, " + "::FlexFlow::AttrConstant], received T"); + return std::get(this->raw_variant); + } + template + T &get() { + static_assert(IsPartOfOutputOperatorAttributeExpr_v, + "OutputOperatorAttributeExpr::get() expected one of " + "[::FlexFlow::OutputOperatorAttrAccess, " + "::FlexFlow::AttrConstant], received T"); + return std::get(this->raw_variant); + } + size_t index() const { + return this->raw_variant.index(); + } + bool operator==(OutputOperatorAttributeExpr const &) const; + bool operator!=(OutputOperatorAttributeExpr const &) const; + bool operator<(OutputOperatorAttributeExpr const &) const; + bool operator>(OutputOperatorAttributeExpr const &) const; + bool operator<=(OutputOperatorAttributeExpr const &) const; + bool operator>=(OutputOperatorAttributeExpr const &) const; + std::variant<::FlexFlow::OutputOperatorAttrAccess, ::FlexFlow::AttrConstant> + raw_variant; +}; +} // namespace FlexFlow +namespace std { +template <> +struct hash<::FlexFlow::OutputOperatorAttributeExpr> { + size_t operator()(::FlexFlow::OutputOperatorAttributeExpr const &) const; +}; +} // namespace std +namespace FlexFlow { +std::string format_as(::FlexFlow::OutputOperatorAttributeExpr const &); +std::ostream &operator<<(std::ostream &, + ::FlexFlow::OutputOperatorAttributeExpr const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTRIBUTE_EXPR_DTG_H diff --git a/lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.variant.toml b/lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.variant.toml new file mode 100644 index 0000000000..19810a0151 --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.variant.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "OutputOperatorAttributeExpr" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "substitutions/output_graph/attr_constant.dtg.h", + "substitutions/output_graph/output_operator_attr_access.dtg.h", +] + +[[values]] +type = "::FlexFlow::OutputOperatorAttrAccess" +key = "attr_ref" + +[[values]] +type = "::FlexFlow::AttrConstant" +key = "constant" diff --git a/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.dtg.h b/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.dtg.h new file mode 100644 index 0000000000..5586a90a08 --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.dtg.h @@ -0,0 +1,49 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.struct.toml +/* proj-data +{ + "generated_from": "bbfb309c5a39a729da23dace4df4a9de" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTRS_ASSIGNMENT_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTRS_ASSIGNMENT_DTG_H + +#include "fmt/format.h" +#include "substitutions/operator_pattern/operator_attribute_key.dtg.h" +#include "substitutions/output_graph/output_operator_attribute_expr.dtg.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct OutputOperatorAttrsAssignment { + OutputOperatorAttrsAssignment() = delete; + OutputOperatorAttrsAssignment( + std::unordered_map<::FlexFlow::OperatorAttributeKey, + ::FlexFlow::OutputOperatorAttributeExpr> const + &assignments); + + bool operator==(OutputOperatorAttrsAssignment const &) const; + bool operator!=(OutputOperatorAttrsAssignment const &) const; + std::unordered_map<::FlexFlow::OperatorAttributeKey, + ::FlexFlow::OutputOperatorAttributeExpr> + assignments; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::OutputOperatorAttrsAssignment const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(OutputOperatorAttrsAssignment const &); +std::ostream &operator<<(std::ostream &, OutputOperatorAttrsAssignment const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTRS_ASSIGNMENT_DTG_H diff --git a/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.struct.toml b/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.struct.toml new file mode 100644 index 0000000000..9781515803 --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.struct.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "OutputOperatorAttrsAssignment" +features = [ + "eq", + # "ord", + "hash", + # "json", + "fmt", +] + +includes = [ + "substitutions/operator_pattern/operator_attribute_key.dtg.h", + "substitutions/output_graph/output_operator_attribute_expr.dtg.h", + "", +] + +# NOTE(@wmdi): Not sure if it aligns with other design. Or alternatively we can +# define the assignment for each operator type. +[[fields]] +name = "assignments" +type = "std::unordered_map<::FlexFlow::OperatorAttributeKey, ::FlexFlow::OutputOperatorAttributeExpr>" diff --git a/lib/substitutions/include/substitutions/parallel_tensor_pattern.h b/lib/substitutions/include/substitutions/parallel_tensor_pattern.h deleted file mode 100644 index 741554142f..0000000000 --- a/lib/substitutions/include/substitutions/parallel_tensor_pattern.h +++ /dev/null @@ -1,25 +0,0 @@ -#ifndef _FLEXFLOW_SUBSTITUTIONS_TENSOR_PATTERN_H -#define _FLEXFLOW_SUBSTITUTIONS_TENSOR_PATTERN_H - -#include "attribute_expr.h" -#include "pcg/parallel_tensor.h" - -namespace FlexFlow { - -enum class TensorAttributeKey { DIM_SIZES, DIM_DEGREES }; - -using TensorAttributeValue = std::variant>; - -using TensorAttributeConstraint = - AttributeConstraint; - -using ParallelTensorPattern = - AttributePattern; - -std::optional - evaluate_attribute_expr(ParallelTensor const &tensor_shape, - AttributeExpr const &expr); - -} // namespace FlexFlow - -#endif diff --git a/lib/substitutions/include/substitutions/pcg_pattern.dtg.h b/lib/substitutions/include/substitutions/pcg_pattern.dtg.h new file mode 100644 index 0000000000..0c0cc41891 --- /dev/null +++ b/lib/substitutions/include/substitutions/pcg_pattern.dtg.h @@ -0,0 +1,31 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/pcg_pattern.struct.toml +/* proj-data +{ + "generated_from": "f536f846828ba39266dd4a1fbaeec0e6" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_PCG_PATTERN_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_PCG_PATTERN_DTG_H + +#include "substitutions/operator_pattern/operator_attribute_pattern.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h" +#include "utils/graph.h" + +namespace FlexFlow { +struct PCGPattern { + PCGPattern() = delete; + PCGPattern(::FlexFlow::OutputLabelledOpenMultiDiGraph< + ::FlexFlow::OperatorAttributePattern, + ::FlexFlow::TensorAttributePattern> const &raw_graph); + + ::FlexFlow::OutputLabelledOpenMultiDiGraph< + ::FlexFlow::OperatorAttributePattern, + ::FlexFlow::TensorAttributePattern> + raw_graph; +}; +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_PCG_PATTERN_DTG_H diff --git a/lib/substitutions/include/substitutions/pcg_pattern.struct.toml b/lib/substitutions/include/substitutions/pcg_pattern.struct.toml new file mode 100644 index 0000000000..191d66a38c --- /dev/null +++ b/lib/substitutions/include/substitutions/pcg_pattern.struct.toml @@ -0,0 +1,12 @@ +namespace = "FlexFlow" +name = "PCGPattern" +features = [] +includes = [ + "utils/graph.h", + "substitutions/operator_pattern/operator_attribute_pattern.dtg.h", + "substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h", +] + +[[fields]] +name = "raw_graph" +type = "::FlexFlow::OutputLabelledOpenMultiDiGraph<::FlexFlow::OperatorAttributePattern, ::FlexFlow::TensorAttributePattern>" diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.dtg.h b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.dtg.h new file mode 100644 index 0000000000..d31d65d83b --- /dev/null +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.dtg.h @@ -0,0 +1,31 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/sub_parallel_computation_graph.struct.toml +/* proj-data +{ + "generated_from": "0022d1b2c1447667695a120c154a0168" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUB_PARALLEL_COMPUTATION_GRAPH_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUB_PARALLEL_COMPUTATION_GRAPH_DTG_H + +#include "pcg/parallel_layer_attrs.dtg.h" +#include "pcg/parallel_tensor_attrs.dtg.h" +#include "utils/graph.h" + +namespace FlexFlow { +struct SubParallelComputationGraph { + SubParallelComputationGraph() = delete; + SubParallelComputationGraph( + ::FlexFlow::OutputLabelledOpenMultiDiGraph< + ::FlexFlow::ParallelLayerAttrs, + ::FlexFlow::ParallelTensorAttrs> const &raw_graph); + + ::FlexFlow::OutputLabelledOpenMultiDiGraph<::FlexFlow::ParallelLayerAttrs, + ::FlexFlow::ParallelTensorAttrs> + raw_graph; +}; +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUB_PARALLEL_COMPUTATION_GRAPH_DTG_H diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h index 0d6bfe7628..5d40f3f975 100644 --- a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h @@ -1,18 +1,17 @@ -#ifndef _FLEXFLOW_SUBSTITUTIONS_SUB_PARALLEL_COMPUTATION_GRAPH_H -#define _FLEXFLOW_SUBSTITUTIONS_SUB_PARALLEL_COMPUTATION_GRAPH_H +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUB_PARALLEL_COMPUTATION_GRAPH_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUB_PARALLEL_COMPUTATION_GRAPH_H -#include "op-attrs/operator_attrs.h" -#include "pcg/machine_view.h" -#include "pcg/operator.h" -#include "pcg/parallel_tensor.h" -#include "utils/graph.h" +#include "substitutions/sub_parallel_computation_graph.dtg.h" namespace FlexFlow { -using SubParallelComputationGraph = - OutputLabelledOpenMultiDiGraph; - -CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(SubParallelComputationGraph); +ParallelLayerAttrs get_parallel_layer_attrs(SubParallelComputationGraph const &, + Node const &); +PCGOperatorAttrs get_operator_attrs(SubParallelComputationGraph const &, + Node const &); +ParallelTensorAttrs + get_parallel_tensor_attrs(SubParallelComputationGraph const &, + OpenMultiDiEdge const &); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.struct.toml b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.struct.toml new file mode 100644 index 0000000000..1ba04b544c --- /dev/null +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.struct.toml @@ -0,0 +1,13 @@ +namespace = "FlexFlow" +name = "SubParallelComputationGraph" +features = [ ] + +includes = [ + "pcg/parallel_layer_attrs.dtg.h", + "pcg/parallel_tensor_attrs.dtg.h", + "utils/graph.h", +] + +[[fields]] +name = "raw_graph" +type = "::FlexFlow::OutputLabelledOpenMultiDiGraph<::FlexFlow::ParallelLayerAttrs, ::FlexFlow::ParallelTensorAttrs>" diff --git a/lib/substitutions/include/substitutions/substitution.dtg.h b/lib/substitutions/include/substitutions/substitution.dtg.h new file mode 100644 index 0000000000..5f50d9bafc --- /dev/null +++ b/lib/substitutions/include/substitutions/substitution.dtg.h @@ -0,0 +1,38 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/substitution.struct.toml +/* proj-data +{ + "generated_from": "c101f1d63e2d8d80a0ec9c5f5db4fa12" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUBSTITUTION_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUBSTITUTION_DTG_H + +#include "substitutions/output_graph/output_graph_expr.dtg.h" +#include "substitutions/pcg_pattern.dtg.h" + +namespace FlexFlow { +struct Substitution { + Substitution() = delete; + Substitution(::FlexFlow::PCGPattern const &pcg_pattern, + ::FlexFlow::OutputGraphExpr const &output_graph_expr, + ::FlexFlow::bidict<::FlexFlow::InputMultiDiEdge, + ::FlexFlow::InputMultiDiEdge> const + &input_edge_match_to_output, + ::FlexFlow::bidict<::FlexFlow::OutputMultiDiEdge, + ::FlexFlow::OutputMultiDiEdge> const + &output_edge_match_to_output); + + ::FlexFlow::PCGPattern pcg_pattern; + ::FlexFlow::OutputGraphExpr output_graph_expr; + ::FlexFlow::bidict<::FlexFlow::InputMultiDiEdge, ::FlexFlow::InputMultiDiEdge> + input_edge_match_to_output; + ::FlexFlow::bidict<::FlexFlow::OutputMultiDiEdge, + ::FlexFlow::OutputMultiDiEdge> + output_edge_match_to_output; +}; +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUBSTITUTION_DTG_H diff --git a/lib/substitutions/include/substitutions/substitution.h b/lib/substitutions/include/substitutions/substitution.h index 8dbe4e66cf..1aa2b2946b 100644 --- a/lib/substitutions/include/substitutions/substitution.h +++ b/lib/substitutions/include/substitutions/substitution.h @@ -1,24 +1,12 @@ #ifndef _FLEXFLOW_SUBSTITUTIONS_SUBSTITUTION_H #define _FLEXFLOW_SUBSTITUTIONS_SUBSTITUTION_H -#include "graph_pattern.h" -#include "output_graph.h" -#include "sub_parallel_computation_graph.h" +#include "sub_parallel_computation_graph.dtg.h" +#include "substitutions/substitution.dtg.h" +#include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h" namespace FlexFlow { -struct Substitution { - using InputPatternInput = InputMultiDiEdge; - using InputPatternOutput = OutputMultiDiEdge; - using OutputPatternInput = InputMultiDiEdge; - using OutputPatternOutput = OutputMultiDiEdge; - - GraphPattern input_graph; - OutputGraphExpr output_graph_expr; - bidict input_mapping; - bidict output_mapping; -}; - bool is_valid_substitution(Substitution const &); SubParallelComputationGraph @@ -28,12 +16,4 @@ SubParallelComputationGraph } // namespace FlexFlow -namespace std { -template <> -struct hash { - size_t operator()(FlexFlow::Substitution const &) const; -}; - -}; // namespace std - #endif diff --git a/lib/substitutions/include/substitutions/substitution.struct.toml b/lib/substitutions/include/substitutions/substitution.struct.toml new file mode 100644 index 0000000000..eb630e9308 --- /dev/null +++ b/lib/substitutions/include/substitutions/substitution.struct.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "Substitution" +features = [] + +includes = [ + "substitutions/pcg_pattern.dtg.h", + "substitutions/output_graph/output_graph_expr.dtg.h", +] + +[[fields]] +name = "pcg_pattern" +type = "::FlexFlow::PCGPattern" + +[[fields]] +name = "output_graph_expr" +type = "::FlexFlow::OutputGraphExpr" + +[[fields]] +name = "input_edge_match_to_output" +type = "::FlexFlow::bidict<::FlexFlow::InputMultiDiEdge, ::FlexFlow::InputMultiDiEdge>" + +[[fields]] +name = "output_edge_match_to_output" +type = "::FlexFlow::bidict<::FlexFlow::OutputMultiDiEdge, ::FlexFlow::OutputMultiDiEdge>" diff --git a/lib/substitutions/include/substitutions/tensor_pattern/eval_list_access.h b/lib/substitutions/include/substitutions/tensor_pattern/eval_list_access.h new file mode 100644 index 0000000000..e245e800b2 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/eval_list_access.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_EVAL_LIST_ACCESS_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_EVAL_LIST_ACCESS_H + +#include "pcg/parallel_tensor_attrs.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_list_access.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_value.dtg.h" + +namespace FlexFlow { + +TensorAttributeValue eval_list_access(ParallelTensorAttrs const &attrs, + TensorAttributeListIndexAccess const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/tensor_pattern/eval_list_size.h b/lib/substitutions/include/substitutions/tensor_pattern/eval_list_size.h new file mode 100644 index 0000000000..de0d58e14f --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/eval_list_size.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_EVAL_LIST_SIZE_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_EVAL_LIST_SIZE_H + +#include "pcg/parallel_tensor_attrs.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_list_size.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_value.dtg.h" + +namespace FlexFlow { + +TensorAttributeValue eval_list_size(ParallelTensorAttrs const &attrs, + TensorAttributeListSize const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/tensor_pattern/get_attribute.h b/lib/substitutions/include/substitutions/tensor_pattern/get_attribute.h new file mode 100644 index 0000000000..eedca2da82 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/get_attribute.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_GET_ATTRIBUTE_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_GET_ATTRIBUTE_H + +#include "pcg/parallel_tensor_attrs.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_key.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_value.dtg.h" + +namespace FlexFlow { + +TensorAttributeValue get_attribute(ParallelTensorAttrs const &, + TensorAttributeKey); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/tensor_pattern/satisfies_constraint.h b/lib/substitutions/include/substitutions/tensor_pattern/satisfies_constraint.h new file mode 100644 index 0000000000..6c11b421a8 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/satisfies_constraint.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_SATISFIES_CONSTRAINT_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_SATISFIES_CONSTRAINT_H + +#include "pcg/parallel_tensor_attrs.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_constraint.dtg.h" + +namespace FlexFlow { + +bool parallel_tensor_satisfies_constraint( + ParallelTensorAttrs const ¶ms, + TensorAttributeConstraint const &constraint); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/tensor_pattern/satisfies_pattern.h b/lib/substitutions/include/substitutions/tensor_pattern/satisfies_pattern.h new file mode 100644 index 0000000000..b8b46669c6 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/satisfies_pattern.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_SATISFIES_PATTERN_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_SATISFIES_PATTERN_H + +#include "pcg/parallel_tensor_attrs.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h" + +namespace FlexFlow { + +bool parallel_tensor_satisfies_pattern(ParallelTensorAttrs const &attrs, + TensorAttributePattern const &pattern); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_constraint.dtg.h b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_constraint.dtg.h new file mode 100644 index 0000000000..ba705a5d35 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_constraint.dtg.h @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_constraint.struct.toml +/* proj-data +{ + "generated_from": "29dbf81668bc864b06af52261060335e" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_CONSTRAINT_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_CONSTRAINT_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "substitutions/constraint_type.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_expr.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_value.dtg.h" +#include +#include +#include + +namespace FlexFlow { +struct TensorAttributeConstraint { + TensorAttributeConstraint() = delete; + TensorAttributeConstraint( + ::FlexFlow::ConstraintType const &constraint_type, + ::FlexFlow::TensorAttributeExpr const &attribute_expr, + ::FlexFlow::TensorAttributeValue const &attribute_value); + + bool operator==(TensorAttributeConstraint const &) const; + bool operator!=(TensorAttributeConstraint const &) const; + bool operator<(TensorAttributeConstraint const &) const; + bool operator>(TensorAttributeConstraint const &) const; + bool operator<=(TensorAttributeConstraint const &) const; + bool operator>=(TensorAttributeConstraint const &) const; + ::FlexFlow::ConstraintType constraint_type; + ::FlexFlow::TensorAttributeExpr attribute_expr; + ::FlexFlow::TensorAttributeValue attribute_value; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::TensorAttributeConstraint const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::TensorAttributeConstraint from_json(json const &); + static void to_json(json &, FlexFlow::TensorAttributeConstraint const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(TensorAttributeConstraint const &); +std::ostream &operator<<(std::ostream &, TensorAttributeConstraint const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_CONSTRAINT_DTG_H diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_constraint.struct.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_constraint.struct.toml new file mode 100644 index 0000000000..6aba719e08 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_constraint.struct.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "TensorAttributeConstraint" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "substitutions/constraint_type.dtg.h", + "substitutions/tensor_pattern/tensor_attribute_expr.dtg.h", + "substitutions/tensor_pattern/tensor_attribute_value.dtg.h", +] + +[[fields]] +name = "constraint_type" +type = "::FlexFlow::ConstraintType" + +[[fields]] +name = "attribute_expr" +type = "::FlexFlow::TensorAttributeExpr" + +[[fields]] +name = "attribute_value" +type = "::FlexFlow::TensorAttributeValue" diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.dtg.h b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.dtg.h new file mode 100644 index 0000000000..d34be357c5 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.dtg.h @@ -0,0 +1,141 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.variant.toml +/* proj-data +{ + "generated_from": "b91285329f12f1b409805cbf9be575b2" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_EXPR_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_EXPR_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "substitutions/tensor_pattern/tensor_attribute_key.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_list_access.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_list_size.dtg.h" +#include +#include +#include +#include +#include + +namespace FlexFlow { +struct TensorAttributeExpr { + TensorAttributeExpr() = delete; + explicit TensorAttributeExpr(::FlexFlow::TensorAttributeKey const &); + explicit TensorAttributeExpr(::FlexFlow::TensorAttributeListSize const &); + explicit TensorAttributeExpr( + ::FlexFlow::TensorAttributeListIndexAccess const &); + template + static constexpr bool IsPartOfTensorAttributeExpr_v = + std::is_same_v || + std::is_same_v || + std::is_same_v; + template + ReturnType visit(Visitor &&v) const { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::TensorAttributeKey>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::TensorAttributeListSize>()); + return result; + } + case 2: { + ReturnType result = + v(this->get<::FlexFlow::TensorAttributeListIndexAccess>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type TensorAttributeExpr", this->index())); + } + } + } + template + ReturnType visit(Visitor &&v) { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::TensorAttributeKey>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::TensorAttributeListSize>()); + return result; + } + case 2: { + ReturnType result = + v(this->get<::FlexFlow::TensorAttributeListIndexAccess>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type TensorAttributeExpr", this->index())); + } + } + } + template + bool has() const { + static_assert( + IsPartOfTensorAttributeExpr_v, + "TensorAttributeExpr::has() expected one of " + "[::FlexFlow::TensorAttributeKey, ::FlexFlow::TensorAttributeListSize, " + "::FlexFlow::TensorAttributeListIndexAccess], received T"); + return std::holds_alternative(this->raw_variant); + } + template + T const &get() const { + static_assert( + IsPartOfTensorAttributeExpr_v, + "TensorAttributeExpr::get() expected one of " + "[::FlexFlow::TensorAttributeKey, ::FlexFlow::TensorAttributeListSize, " + "::FlexFlow::TensorAttributeListIndexAccess], received T"); + return std::get(this->raw_variant); + } + template + T &get() { + static_assert( + IsPartOfTensorAttributeExpr_v, + "TensorAttributeExpr::get() expected one of " + "[::FlexFlow::TensorAttributeKey, ::FlexFlow::TensorAttributeListSize, " + "::FlexFlow::TensorAttributeListIndexAccess], received T"); + return std::get(this->raw_variant); + } + size_t index() const { + return this->raw_variant.index(); + } + bool operator==(TensorAttributeExpr const &) const; + bool operator!=(TensorAttributeExpr const &) const; + bool operator<(TensorAttributeExpr const &) const; + bool operator>(TensorAttributeExpr const &) const; + bool operator<=(TensorAttributeExpr const &) const; + bool operator>=(TensorAttributeExpr const &) const; + std::variant<::FlexFlow::TensorAttributeKey, + ::FlexFlow::TensorAttributeListSize, + ::FlexFlow::TensorAttributeListIndexAccess> + raw_variant; +}; +} // namespace FlexFlow +namespace std { +template <> +struct hash<::FlexFlow::TensorAttributeExpr> { + size_t operator()(::FlexFlow::TensorAttributeExpr const &) const; +}; +} // namespace std +namespace nlohmann { +template <> +struct adl_serializer<::FlexFlow::TensorAttributeExpr> { + static ::FlexFlow::TensorAttributeExpr from_json(json const &); + static void to_json(json &, ::FlexFlow::TensorAttributeExpr const &); +}; +} // namespace nlohmann +namespace FlexFlow { +std::string format_as(::FlexFlow::TensorAttributeExpr const &); +std::ostream &operator<<(std::ostream &, + ::FlexFlow::TensorAttributeExpr const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_EXPR_DTG_H diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.h b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.h new file mode 100644 index 0000000000..98d4394530 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_EXPR_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_EXPR_H + +#include "pcg/parallel_tensor_attrs.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_expr.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_value.dtg.h" + +namespace FlexFlow { + +TensorAttributeValue evaluate_attribute_expr(ParallelTensorAttrs const &attrs, + TensorAttributeExpr const &expr); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.variant.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.variant.toml new file mode 100644 index 0000000000..03ec0eb624 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.variant.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "TensorAttributeExpr" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "substitutions/tensor_pattern/tensor_attribute_key.dtg.h", + "substitutions/tensor_pattern/tensor_attribute_list_access.dtg.h", + "substitutions/tensor_pattern/tensor_attribute_list_size.dtg.h", +] + +[[values]] +type = "::FlexFlow::TensorAttributeKey" +key = "key" + +[[values]] +type = "::FlexFlow::TensorAttributeListSize" +key = "list_size" + +[[values]] +type = "::FlexFlow::TensorAttributeListIndexAccess" +key = "list_idx" diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.dtg.h b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.dtg.h new file mode 100644 index 0000000000..50a0aa49e8 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.dtg.h @@ -0,0 +1,40 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.enum.toml +/* proj-data +{ + "generated_from": "63a7c40c1e5b582f98b59750a35f0a08" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_KEY_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_KEY_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +enum class TensorAttributeKey { DIM_SIZES, DIM_DEGREES }; +std::string format_as(TensorAttributeKey); +std::ostream &operator<<(std::ostream &, TensorAttributeKey); +void to_json(::nlohmann::json &, TensorAttributeKey); +void from_json(::nlohmann::json const &, TensorAttributeKey &); +} // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::TensorAttributeKey) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_KEY_DTG_H diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.enum.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.enum.toml new file mode 100644 index 0000000000..3df36d13ac --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.enum.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "TensorAttributeKey" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "DIM_SIZES" + +[[values]] +name = "DIM_DEGREES" diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.h b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.h new file mode 100644 index 0000000000..473f4e1698 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.h @@ -0,0 +1,66 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.struct.toml +/* proj-data +{ + "generated_from": "41f5449cd700b6d7ab017f3efa39dc1d" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_LIST_ACCESS_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_LIST_ACCESS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include "substitutions/tensor_pattern/tensor_attribute_key.dtg.h" +#include +#include +#include + +namespace FlexFlow { +struct TensorAttributeListIndexAccess { + TensorAttributeListIndexAccess() = delete; + TensorAttributeListIndexAccess( + ::FlexFlow::TensorAttributeKey const &attribute_key, int const &index); + + bool operator==(TensorAttributeListIndexAccess const &) const; + bool operator!=(TensorAttributeListIndexAccess const &) const; + bool operator<(TensorAttributeListIndexAccess const &) const; + bool operator>(TensorAttributeListIndexAccess const &) const; + bool operator<=(TensorAttributeListIndexAccess const &) const; + bool operator>=(TensorAttributeListIndexAccess const &) const; + ::FlexFlow::TensorAttributeKey attribute_key; + int index; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::TensorAttributeListIndexAccess const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::TensorAttributeListIndexAccess from_json(json const &); + static void to_json(json &, FlexFlow::TensorAttributeListIndexAccess const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(TensorAttributeListIndexAccess const &); +std::ostream &operator<<(std::ostream &, + TensorAttributeListIndexAccess const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_LIST_ACCESS_DTG_H diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.struct.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.struct.toml new file mode 100644 index 0000000000..a57dd25845 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "TensorAttributeListIndexAccess" +features = [ + "eq", + "ord", + "hash", + "rapidcheck", + "json", + "fmt", +] + +includes = [ + "substitutions/tensor_pattern/tensor_attribute_key.dtg.h" +] + +[[fields]] +name = "attribute_key" +type = "::FlexFlow::TensorAttributeKey" + +[[fields]] +name = "index" +type = "int" diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_size.dtg.h b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_size.dtg.h new file mode 100644 index 0000000000..1630014bdf --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_size.dtg.h @@ -0,0 +1,63 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_size.struct.toml +/* proj-data +{ + "generated_from": "ec72cd39de5d1c0f0478696d7b83e4e9" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_LIST_SIZE_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_LIST_SIZE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include "substitutions/tensor_pattern/tensor_attribute_key.dtg.h" +#include +#include +#include + +namespace FlexFlow { +struct TensorAttributeListSize { + TensorAttributeListSize() = delete; + TensorAttributeListSize(::FlexFlow::TensorAttributeKey const &attribute_key); + + bool operator==(TensorAttributeListSize const &) const; + bool operator!=(TensorAttributeListSize const &) const; + bool operator<(TensorAttributeListSize const &) const; + bool operator>(TensorAttributeListSize const &) const; + bool operator<=(TensorAttributeListSize const &) const; + bool operator>=(TensorAttributeListSize const &) const; + ::FlexFlow::TensorAttributeKey attribute_key; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::TensorAttributeListSize const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::TensorAttributeListSize from_json(json const &); + static void to_json(json &, FlexFlow::TensorAttributeListSize const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(TensorAttributeListSize const &); +std::ostream &operator<<(std::ostream &, TensorAttributeListSize const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_LIST_SIZE_DTG_H diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_size.struct.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_size.struct.toml new file mode 100644 index 0000000000..c876696343 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_size.struct.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "TensorAttributeListSize" +features = [ + "eq", + "ord", + "hash", + "rapidcheck", + "json", + "fmt", +] + +includes = [ + "substitutions/tensor_pattern/tensor_attribute_key.dtg.h", +] + + +[[fields]] +name = "attribute_key" +type = "::FlexFlow::TensorAttributeKey" diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h new file mode 100644 index 0000000000..ecc4bc7da0 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h @@ -0,0 +1,56 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.struct.toml +/* proj-data +{ + "generated_from": "42a51afce383f1ddc3d70827aa94a68f" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_PATTERN_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_PATTERN_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "substitutions/tensor_pattern/tensor_attribute_constraint.dtg.h" +#include "utils/hash-utils.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct TensorAttributePattern { + TensorAttributePattern() = delete; + TensorAttributePattern( + std::unordered_set<::FlexFlow::TensorAttributeConstraint> const + &attribute_constraints); + + bool operator==(TensorAttributePattern const &) const; + bool operator!=(TensorAttributePattern const &) const; + std::unordered_set<::FlexFlow::TensorAttributeConstraint> + attribute_constraints; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::TensorAttributePattern const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::TensorAttributePattern from_json(json const &); + static void to_json(json &, FlexFlow::TensorAttributePattern const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(TensorAttributePattern const &); +std::ostream &operator<<(std::ostream &, TensorAttributePattern const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_PATTERN_DTG_H diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.struct.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.struct.toml new file mode 100644 index 0000000000..43f45e95b9 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.struct.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "TensorAttributePattern" +features = [ + "eq", + # "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "", + "substitutions/tensor_pattern/tensor_attribute_constraint.dtg.h", + "utils/hash-utils.h", +] + +[[fields]] +name = "attribute_constraints" +type = "std::unordered_set<::FlexFlow::TensorAttributeConstraint>" diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.dtg.h b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.dtg.h new file mode 100644 index 0000000000..948a7abae6 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.dtg.h @@ -0,0 +1,118 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.variant.toml +/* proj-data +{ + "generated_from": "d80cf2e618d64df284c2647430a12a86" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_VALUE_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_VALUE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "utils/fmt.h" +#include "utils/hash-utils-core.h" +#include +#include +#include +#include +#include +#include + +namespace FlexFlow { +struct TensorAttributeValue { + TensorAttributeValue() = delete; + explicit TensorAttributeValue(size_t const &); + explicit TensorAttributeValue(std::vector const &); + template + static constexpr bool IsPartOfTensorAttributeValue_v = + std::is_same_v || std::is_same_v>; + template + ReturnType visit(Visitor &&v) const { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get()); + return result; + } + case 1: { + ReturnType result = v(this->get>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type TensorAttributeValue", this->index())); + } + } + } + template + ReturnType visit(Visitor &&v) { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get()); + return result; + } + case 1: { + ReturnType result = v(this->get>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type TensorAttributeValue", this->index())); + } + } + } + template + bool has() const { + static_assert(IsPartOfTensorAttributeValue_v, + "TensorAttributeValue::has() expected one of [size_t, " + "std::vector], received T"); + return std::holds_alternative(this->raw_variant); + } + template + T const &get() const { + static_assert(IsPartOfTensorAttributeValue_v, + "TensorAttributeValue::get() expected one of [size_t, " + "std::vector], received T"); + return std::get(this->raw_variant); + } + template + T &get() { + static_assert(IsPartOfTensorAttributeValue_v, + "TensorAttributeValue::get() expected one of [size_t, " + "std::vector], received T"); + return std::get(this->raw_variant); + } + size_t index() const { + return this->raw_variant.index(); + } + bool operator==(TensorAttributeValue const &) const; + bool operator!=(TensorAttributeValue const &) const; + bool operator<(TensorAttributeValue const &) const; + bool operator>(TensorAttributeValue const &) const; + bool operator<=(TensorAttributeValue const &) const; + bool operator>=(TensorAttributeValue const &) const; + std::variant> raw_variant; +}; +} // namespace FlexFlow +namespace std { +template <> +struct hash<::FlexFlow::TensorAttributeValue> { + size_t operator()(::FlexFlow::TensorAttributeValue const &) const; +}; +} // namespace std +namespace nlohmann { +template <> +struct adl_serializer<::FlexFlow::TensorAttributeValue> { + static ::FlexFlow::TensorAttributeValue from_json(json const &); + static void to_json(json &, ::FlexFlow::TensorAttributeValue const &); +}; +} // namespace nlohmann +namespace FlexFlow { +std::string format_as(::FlexFlow::TensorAttributeValue const &); +std::ostream &operator<<(std::ostream &, + ::FlexFlow::TensorAttributeValue const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_VALUE_DTG_H diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.variant.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.variant.toml new file mode 100644 index 0000000000..91313f159b --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.variant.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "TensorAttributeValue" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "", + "utils/hash-utils-core.h", + "utils/fmt.h", +] + +[[values]] +type = "size_t" + +[[values]] +type = "std::vector" diff --git a/lib/substitutions/include/substitutions/unlabelled/closed_pattern_edge.dtg.h b/lib/substitutions/include/substitutions/unlabelled/closed_pattern_edge.dtg.h new file mode 100644 index 0000000000..6bf815791d --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/closed_pattern_edge.dtg.h @@ -0,0 +1,39 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/closed_pattern_edge.struct.toml +/* proj-data +{ + "generated_from": "b4086fd78ca7ec0475ed7abfd034304c" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_CLOSED_PATTERN_EDGE_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_CLOSED_PATTERN_EDGE_DTG_H + +#include "utils/graph.h" +#include +#include + +namespace FlexFlow { +struct ClosedPatternEdge { + ClosedPatternEdge() = delete; + ClosedPatternEdge(::FlexFlow::MultiDiEdge const &raw_edge); + + bool operator==(ClosedPatternEdge const &) const; + bool operator!=(ClosedPatternEdge const &) const; + bool operator<(ClosedPatternEdge const &) const; + bool operator>(ClosedPatternEdge const &) const; + bool operator<=(ClosedPatternEdge const &) const; + bool operator>=(ClosedPatternEdge const &) const; + ::FlexFlow::MultiDiEdge raw_edge; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ClosedPatternEdge const &) const; +}; +} // namespace std + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_CLOSED_PATTERN_EDGE_DTG_H diff --git a/lib/substitutions/include/substitutions/unlabelled/closed_pattern_edge.struct.toml b/lib/substitutions/include/substitutions/unlabelled/closed_pattern_edge.struct.toml new file mode 100644 index 0000000000..d609ca1c27 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/closed_pattern_edge.struct.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "ClosedPatternEdge" +features = [ + "eq", + "ord", + "hash", +] + +includes = [ + "utils/graph.h", +] + +[[fields]] +name = "raw_edge" +type = "::FlexFlow::MultiDiEdge" diff --git a/lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.dtg.h b/lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.dtg.h new file mode 100644 index 0000000000..5ce0e63073 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.dtg.h @@ -0,0 +1,39 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.struct.toml +/* proj-data +{ + "generated_from": "c67ec363a91ce090dc538dcf76fa1f12" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_DOWNWARD_OPEN_PATTERN_EDGE_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_DOWNWARD_OPEN_PATTERN_EDGE_DTG_H + +#include "utils/graph.h" +#include +#include + +namespace FlexFlow { +struct DownwardOpenPatternEdge { + DownwardOpenPatternEdge() = delete; + DownwardOpenPatternEdge(::FlexFlow::DownwardOpenMultiDiEdge const &raw_edge); + + bool operator==(DownwardOpenPatternEdge const &) const; + bool operator!=(DownwardOpenPatternEdge const &) const; + bool operator<(DownwardOpenPatternEdge const &) const; + bool operator>(DownwardOpenPatternEdge const &) const; + bool operator<=(DownwardOpenPatternEdge const &) const; + bool operator>=(DownwardOpenPatternEdge const &) const; + ::FlexFlow::DownwardOpenMultiDiEdge raw_edge; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::DownwardOpenPatternEdge const &) const; +}; +} // namespace std + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_DOWNWARD_OPEN_PATTERN_EDGE_DTG_H diff --git a/lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.h b/lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.h new file mode 100644 index 0000000000..9855d96e46 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_DOWNWARD_OPEN_PATTERN_EDGE_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_DOWNWARD_OPEN_PATTERN_EDGE_H + +#include "substitutions/unlabelled/downward_open_pattern_edge.dtg.h" + +namespace FlexFlow { + +int get_src_idx(DownwardOpenPatternEdge const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.struct.toml b/lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.struct.toml new file mode 100644 index 0000000000..2dda7498f0 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.struct.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "DownwardOpenPatternEdge" +features = [ + "eq", + "ord", + "hash", +] + +includes = [ + "utils/graph.h", +] + +[[fields]] +name = "raw_edge" +type = "::FlexFlow::DownwardOpenMultiDiEdge" diff --git a/lib/substitutions/include/substitutions/unlabelled/edge_splits.dtg.h b/lib/substitutions/include/substitutions/unlabelled/edge_splits.dtg.h new file mode 100644 index 0000000000..e92fe547b1 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/edge_splits.dtg.h @@ -0,0 +1,36 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/edge_splits.struct.toml +/* proj-data +{ + "generated_from": "f172b041a99f4de1d396e5d451a5e64d" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_EDGE_SPLITS_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_EDGE_SPLITS_DTG_H + +#include "utils/bidict.h" +#include "utils/graph.h" +#include +#include + +namespace FlexFlow { +struct UnlabelledPatternEdgeSplits { + UnlabelledPatternEdgeSplits() = delete; + UnlabelledPatternEdgeSplits( + ::FlexFlow::bidict<::FlexFlow::MultiDiEdge, + std::pair<::FlexFlow::OutputMultiDiEdge, + ::FlexFlow::InputMultiDiEdge>> const + &unwrapped); + + bool operator==(UnlabelledPatternEdgeSplits const &) const; + bool operator!=(UnlabelledPatternEdgeSplits const &) const; + ::FlexFlow::bidict< + ::FlexFlow::MultiDiEdge, + std::pair<::FlexFlow::OutputMultiDiEdge, ::FlexFlow::InputMultiDiEdge>> + unwrapped; +}; +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_EDGE_SPLITS_DTG_H diff --git a/lib/substitutions/include/substitutions/unlabelled/edge_splits.h b/lib/substitutions/include/substitutions/unlabelled/edge_splits.h new file mode 100644 index 0000000000..58704500ac --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/edge_splits.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_EDGE_SPLITS_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_EDGE_SPLITS_H + +#include "substitutions/unlabelled/closed_pattern_edge.dtg.h" +#include "substitutions/unlabelled/edge_splits.dtg.h" +#include "substitutions/unlabelled/input_pattern_edge.dtg.h" +#include "substitutions/unlabelled/output_pattern_edge.dtg.h" +#include + +namespace FlexFlow { + +std::pair + get_split_edges(UnlabelledPatternEdgeSplits const &, + ClosedPatternEdge const &); + +std::vector> + as_closed_output_input_tuples(UnlabelledPatternEdgeSplits const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/unlabelled/edge_splits.struct.toml b/lib/substitutions/include/substitutions/unlabelled/edge_splits.struct.toml new file mode 100644 index 0000000000..fa714296c8 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/edge_splits.struct.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "UnlabelledPatternEdgeSplits" +features = [ + "eq", +] + +includes = [ + "utils/bidict.h", + "utils/graph.h", + "", +] + +[[fields]] +name = "unwrapped" +type = "::FlexFlow::bidict<::FlexFlow::MultiDiEdge, std::pair<::FlexFlow::OutputMultiDiEdge, ::FlexFlow::InputMultiDiEdge>>" diff --git a/lib/substitutions/include/substitutions/unlabelled/find_pattern_matches.h b/lib/substitutions/include/substitutions/unlabelled/find_pattern_matches.h new file mode 100644 index 0000000000..29c5740c0e --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/find_pattern_matches.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_FIND_PATTERN_MATCHES_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_FIND_PATTERN_MATCHES_H + +#include "substitutions/unlabelled/match_additional_criterion.dtg.h" +#include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h" +#include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" +#include "utils/graph.h" + +namespace FlexFlow { + +std::vector + find_pattern_matches(UnlabelledGraphPattern const &pattern, + OpenMultiDiGraphView const &graph, + MatchAdditionalCriterion const &additional_criterion); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.dtg.h b/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.dtg.h new file mode 100644 index 0000000000..f292acba14 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.dtg.h @@ -0,0 +1,39 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.struct.toml +/* proj-data +{ + "generated_from": "d0cc0e65c4e3feb2e9b8435947c99e5f" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_INPUT_PATTERN_EDGE_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_INPUT_PATTERN_EDGE_DTG_H + +#include "utils/graph.h" +#include +#include + +namespace FlexFlow { +struct InputPatternEdge { + InputPatternEdge() = delete; + InputPatternEdge(::FlexFlow::InputMultiDiEdge const &raw_edge); + + bool operator==(InputPatternEdge const &) const; + bool operator!=(InputPatternEdge const &) const; + bool operator<(InputPatternEdge const &) const; + bool operator>(InputPatternEdge const &) const; + bool operator<=(InputPatternEdge const &) const; + bool operator>=(InputPatternEdge const &) const; + ::FlexFlow::InputMultiDiEdge raw_edge; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::InputPatternEdge const &) const; +}; +} // namespace std + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_INPUT_PATTERN_EDGE_DTG_H diff --git a/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.h b/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.h new file mode 100644 index 0000000000..b05fa479db --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_INPUT_PATTERN_EDGE_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_INPUT_PATTERN_EDGE_H + +#include "substitutions/unlabelled/input_pattern_edge.dtg.h" +#include "substitutions/unlabelled/pattern_node.dtg.h" + +namespace FlexFlow { + +PatternNode get_dst_node(InputPatternEdge const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.struct.toml b/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.struct.toml new file mode 100644 index 0000000000..6da52b58aa --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.struct.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "InputPatternEdge" +features = [ + "eq", + "ord", + "hash", +] + +includes = [ + "utils/graph.h" +] + +[[fields]] +name = "raw_edge" +type = "::FlexFlow::InputMultiDiEdge" diff --git a/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.dtg.h b/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.dtg.h new file mode 100644 index 0000000000..e910be21ba --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.dtg.h @@ -0,0 +1,36 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.struct.toml +/* proj-data +{ + "generated_from": "2dff356c85dccda1fce8f714d41c6202" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MATCH_ADDITIONAL_CRITERION_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MATCH_ADDITIONAL_CRITERION_DTG_H + +#include "substitutions/unlabelled/pattern_edge.dtg.h" +#include "substitutions/unlabelled/pattern_node.dtg.h" +#include "utils/graph.h" +#include + +namespace FlexFlow { +struct MatchAdditionalCriterion { + MatchAdditionalCriterion() = delete; + MatchAdditionalCriterion( + std::function const &node_criterion, + std::function const + &edge_criterion); + + std::function + node_criterion; + std::function + edge_criterion; +}; +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MATCH_ADDITIONAL_CRITERION_DTG_H diff --git a/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.struct.toml b/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.struct.toml new file mode 100644 index 0000000000..c0107d84e9 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "MatchAdditionalCriterion" +features = [] + +includes = [ + "", + "utils/graph.h", + "substitutions/unlabelled/pattern_node.dtg.h", + "substitutions/unlabelled/pattern_edge.dtg.h", +] + +[[fields]] +name = "node_criterion" +type = "std::function" + +[[fields]] +name = "edge_criterion" +type = "std::function" diff --git a/lib/substitutions/include/substitutions/unlabelled/match_split.dtg.h b/lib/substitutions/include/substitutions/unlabelled/match_split.dtg.h new file mode 100644 index 0000000000..aa17814c52 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/match_split.dtg.h @@ -0,0 +1,29 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/match_split.struct.toml +/* proj-data +{ + "generated_from": "e44c4347e07263a493cbbd5caccedd22" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MATCH_SPLIT_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MATCH_SPLIT_DTG_H + +#include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h" +#include + +namespace FlexFlow { +struct MatchSplit { + MatchSplit() = delete; + MatchSplit(MultiDiGraphPatternMatch const &prefix_submatch, + MultiDiGraphPatternMatch const &postfix_submatch); + + bool operator==(MatchSplit const &) const; + bool operator!=(MatchSplit const &) const; + MultiDiGraphPatternMatch prefix_submatch; + MultiDiGraphPatternMatch postfix_submatch; +}; +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MATCH_SPLIT_DTG_H diff --git a/lib/substitutions/include/substitutions/unlabelled/match_split.h b/lib/substitutions/include/substitutions/unlabelled/match_split.h new file mode 100644 index 0000000000..a23bc3f89a --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/match_split.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MATCH_SPLIT_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MATCH_SPLIT_H + +#include "substitutions/unlabelled/match_split.dtg.h" +#include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h" +#include "substitutions/unlabelled/pattern_split.dtg.h" +#include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" + +namespace FlexFlow { + +MatchSplit empty_match_split(); +MatchSplit apply_split(UnlabelledGraphPattern const &pattern, + MultiDiGraphPatternMatch const &match, + PatternSplit const &split); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/unlabelled/match_split.struct.toml b/lib/substitutions/include/substitutions/unlabelled/match_split.struct.toml new file mode 100644 index 0000000000..3fd77e7b4a --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/match_split.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "MatchSplit" +features = [ + "eq", + # "ord", +] + +includes = [ + "substitutions/unlabelled/multidigraph_pattern_match.dtg.h" +] + +[[fields]] +name = "prefix_submatch" +type = "MultiDiGraphPatternMatch" + +[[fields]] +name = "postfix_submatch" +type = "MultiDiGraphPatternMatch" diff --git a/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.dtg.h b/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.dtg.h new file mode 100644 index 0000000000..30f81504fe --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.dtg.h @@ -0,0 +1,36 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.struct.toml +/* proj-data +{ + "generated_from": "9842661a5d4e7d717f12d2c27da7df0d" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MULTIDIGRAPH_PATTERN_MATCH_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MULTIDIGRAPH_PATTERN_MATCH_DTG_H + +#include "substitutions/unlabelled/pattern_edge.dtg.h" +#include "substitutions/unlabelled/pattern_node.dtg.h" +#include "utils/bidict.h" +#include "utils/graph.h" +#include + +namespace FlexFlow { +struct MultiDiGraphPatternMatch { + MultiDiGraphPatternMatch() = delete; + MultiDiGraphPatternMatch( + ::FlexFlow::bidict<::FlexFlow::PatternNode, ::FlexFlow::Node> const + &node_assignment, + ::FlexFlow::bidict<::FlexFlow::PatternEdge, + ::FlexFlow::OpenMultiDiEdge> const &edge_assignment); + + bool operator==(MultiDiGraphPatternMatch const &) const; + bool operator!=(MultiDiGraphPatternMatch const &) const; + ::FlexFlow::bidict<::FlexFlow::PatternNode, ::FlexFlow::Node> node_assignment; + ::FlexFlow::bidict<::FlexFlow::PatternEdge, ::FlexFlow::OpenMultiDiEdge> + edge_assignment; +}; +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MULTIDIGRAPH_PATTERN_MATCH_DTG_H diff --git a/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.h b/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.h new file mode 100644 index 0000000000..aacae6d42a --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MULTIDIGRAPH_PATTERN_MATCH_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MULTIDIGRAPH_PATTERN_MATCH_H + +#include "substitutions/unlabelled/edge_splits.dtg.h" +#include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h" + +namespace FlexFlow { + +MultiDiGraphPatternMatch empty_multidigraph_pattern_match(); +std::optional + unsplit_matches(MultiDiGraphPatternMatch const &prefix, + MultiDiGraphPatternMatch const &postfix, + UnlabelledPatternEdgeSplits const &edge_splits); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.struct.toml b/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.struct.toml new file mode 100644 index 0000000000..778767ab62 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.struct.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +# TODO(@lockshaw): rename to UnlabelledGraphPatternMatch +name = "MultiDiGraphPatternMatch" +features = [ + "eq", + # "ord", + # "hash", + # "fmt", +] + +includes = [ + "utils/bidict.h", + "utils/graph.h", + "substitutions/unlabelled/pattern_edge.dtg.h", + "substitutions/unlabelled/pattern_node.dtg.h", +] + +[[fields]] +name = "node_assignment" +type = "::FlexFlow::bidict<::FlexFlow::PatternNode, ::FlexFlow::Node>" + +[[fields]] +name = "edge_assignment" +type = "::FlexFlow::bidict<::FlexFlow::PatternEdge, ::FlexFlow::OpenMultiDiEdge>" diff --git a/lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.dtg.h b/lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.dtg.h new file mode 100644 index 0000000000..04ec8c656d --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.dtg.h @@ -0,0 +1,39 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.struct.toml +/* proj-data +{ + "generated_from": "3222696e351c3e203e008714245c737f" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_OUTPUT_PATTERN_EDGE_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_OUTPUT_PATTERN_EDGE_DTG_H + +#include "utils/graph.h" +#include +#include + +namespace FlexFlow { +struct OutputPatternEdge { + OutputPatternEdge() = delete; + OutputPatternEdge(::FlexFlow::OutputMultiDiEdge const &raw_edge); + + bool operator==(OutputPatternEdge const &) const; + bool operator!=(OutputPatternEdge const &) const; + bool operator<(OutputPatternEdge const &) const; + bool operator>(OutputPatternEdge const &) const; + bool operator<=(OutputPatternEdge const &) const; + bool operator>=(OutputPatternEdge const &) const; + ::FlexFlow::OutputMultiDiEdge raw_edge; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::OutputPatternEdge const &) const; +}; +} // namespace std + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_OUTPUT_PATTERN_EDGE_DTG_H diff --git a/lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.h b/lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.h new file mode 100644 index 0000000000..72e8ff02cf --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_OUTPUT_PATTERN_EDGE_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_OUTPUT_PATTERN_EDGE_H + +#include "substitutions/unlabelled/output_pattern_edge.dtg.h" +#include "substitutions/unlabelled/pattern_node.dtg.h" + +namespace FlexFlow { + +PatternNode get_src_node(OutputPatternEdge const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.struct.toml b/lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.struct.toml new file mode 100644 index 0000000000..362cbc3265 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.struct.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "OutputPatternEdge" +features = [ + "eq", + "ord", + "hash", +] + +includes = [ + "utils/graph.h", +] + +[[fields]] +name = "raw_edge" +type = "::FlexFlow::OutputMultiDiEdge" diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_edge.dtg.h b/lib/substitutions/include/substitutions/unlabelled/pattern_edge.dtg.h new file mode 100644 index 0000000000..4883590130 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_edge.dtg.h @@ -0,0 +1,39 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/pattern_edge.struct.toml +/* proj-data +{ + "generated_from": "a3eff166b0c8be2ddf3f7305eec094fd" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_EDGE_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_EDGE_DTG_H + +#include "utils/graph.h" +#include +#include + +namespace FlexFlow { +struct PatternEdge { + PatternEdge() = delete; + PatternEdge(::FlexFlow::OpenMultiDiEdge const &raw_edge); + + bool operator==(PatternEdge const &) const; + bool operator!=(PatternEdge const &) const; + bool operator<(PatternEdge const &) const; + bool operator>(PatternEdge const &) const; + bool operator<=(PatternEdge const &) const; + bool operator>=(PatternEdge const &) const; + ::FlexFlow::OpenMultiDiEdge raw_edge; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::PatternEdge const &) const; +}; +} // namespace std + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_EDGE_DTG_H diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_edge.h b/lib/substitutions/include/substitutions/unlabelled/pattern_edge.h new file mode 100644 index 0000000000..79db533d4e --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_edge.h @@ -0,0 +1,27 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_PATTERN_EDGE_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_PATTERN_EDGE_H + +#include "substitutions/unlabelled/closed_pattern_edge.dtg.h" +#include "substitutions/unlabelled/input_pattern_edge.dtg.h" +#include "substitutions/unlabelled/output_pattern_edge.dtg.h" +#include "substitutions/unlabelled/pattern_edge.dtg.h" +#include "substitutions/unlabelled/pattern_node.dtg.h" + +namespace FlexFlow { + +std::unordered_set get_nodes(PatternEdge const &); +bool is_closed_edge(PatternEdge const &); +bool is_input_edge(PatternEdge const &); +bool is_output_edge(PatternEdge const &); + +ClosedPatternEdge require_closed_edge(PatternEdge const &); +InputPatternEdge require_input_edge(PatternEdge const &); +OutputPatternEdge require_output_edge(PatternEdge const &); + +PatternEdge pattern_edge_from_input_edge(InputPatternEdge const &); +PatternEdge pattern_edge_from_output_edge(OutputPatternEdge const &); +PatternEdge pattern_edge_from_closed_edge(ClosedPatternEdge const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_edge.struct.toml b/lib/substitutions/include/substitutions/unlabelled/pattern_edge.struct.toml new file mode 100644 index 0000000000..4abfa1c0db --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_edge.struct.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "PatternEdge" +features = [ + "eq", + "ord", + "hash", +] + +includes = [ + "utils/graph.h", +] + +[[fields]] +name = "raw_edge" +type = "::FlexFlow::OpenMultiDiEdge" diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_matching.h b/lib/substitutions/include/substitutions/unlabelled/pattern_matching.h new file mode 100644 index 0000000000..223886b411 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_matching.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_MATCHING_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_MATCHING_H + +#include "substitutions/unlabelled/match_additional_criterion.dtg.h" +#include "substitutions/unlabelled/match_split.dtg.h" +#include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h" +#include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" +#include "utils/graph.h" + +namespace FlexFlow { + +bool unlabelled_pattern_does_match( + UnlabelledGraphPattern const &pattern, + OpenMultiDiGraphView const &graph, + MultiDiGraphPatternMatch const &match, + MatchAdditionalCriterion const &additional_criterion); + +std::vector + find_pattern_matches(UnlabelledGraphPattern const &pattern, + OpenMultiDiGraphView const &graph, + MatchAdditionalCriterion const &additional_criterion); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_node.dtg.h b/lib/substitutions/include/substitutions/unlabelled/pattern_node.dtg.h new file mode 100644 index 0000000000..56471c2e08 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_node.dtg.h @@ -0,0 +1,39 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/pattern_node.struct.toml +/* proj-data +{ + "generated_from": "a0e58ade010a9b250d2c1c378fde2639" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_NODE_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_NODE_DTG_H + +#include "utils/graph.h" +#include +#include + +namespace FlexFlow { +struct PatternNode { + PatternNode() = delete; + PatternNode(::FlexFlow::Node const &raw_node); + + bool operator==(PatternNode const &) const; + bool operator!=(PatternNode const &) const; + bool operator<(PatternNode const &) const; + bool operator>(PatternNode const &) const; + bool operator<=(PatternNode const &) const; + bool operator>=(PatternNode const &) const; + ::FlexFlow::Node raw_node; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::PatternNode const &) const; +}; +} // namespace std + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_NODE_DTG_H diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_node.struct.toml b/lib/substitutions/include/substitutions/unlabelled/pattern_node.struct.toml new file mode 100644 index 0000000000..ecd0253516 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_node.struct.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "PatternNode" +features = [ + "eq", + "ord", + "hash", +] + +includes = [ + "utils/graph.h", +] + +[[fields]] +name = "raw_node" +type = "::FlexFlow::Node" diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_split.dtg.h b/lib/substitutions/include/substitutions/unlabelled/pattern_split.dtg.h new file mode 100644 index 0000000000..453c4020a8 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_split.dtg.h @@ -0,0 +1,47 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/pattern_split.struct.toml +/* proj-data +{ + "generated_from": "8604edb5bd1a546ffa94ef496888e46d" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_SPLIT_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_SPLIT_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "substitutions/unlabelled/pattern_node.dtg.h" +#include "utils/graph.h" +#include +#include +#include + +namespace FlexFlow { +struct PatternSplit { + PatternSplit() = delete; + PatternSplit(std::unordered_set<::FlexFlow::PatternNode> const &first, + std::unordered_set<::FlexFlow::PatternNode> const &second); + + bool operator==(PatternSplit const &) const; + bool operator!=(PatternSplit const &) const; + std::unordered_set<::FlexFlow::PatternNode> first; + std::unordered_set<::FlexFlow::PatternNode> second; +}; +} // namespace FlexFlow + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::PatternSplit from_json(json const &); + static void to_json(json &, FlexFlow::PatternSplit const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(PatternSplit const &); +std::ostream &operator<<(std::ostream &, PatternSplit const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_SPLIT_DTG_H diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_split.h b/lib/substitutions/include/substitutions/unlabelled/pattern_split.h new file mode 100644 index 0000000000..3fcc5cb12f --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_split.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_SPLIT_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_SPLIT_H + +#include "substitutions/unlabelled/edge_splits.dtg.h" +#include "substitutions/unlabelled/pattern_split.dtg.h" +#include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" + +namespace FlexFlow { + +PatternSplit find_even_split(UnlabelledGraphPattern const &); + +GraphSplit get_raw_split(PatternSplit const &); + +UnlabelledPatternEdgeSplits + get_edge_splits(UnlabelledGraphPattern const &pattern, + PatternSplit const &split); + +std::pair + apply_split(UnlabelledGraphPattern const &, PatternSplit const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_split.struct.toml b/lib/substitutions/include/substitutions/unlabelled/pattern_split.struct.toml new file mode 100644 index 0000000000..04d1080ff7 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_split.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "PatternSplit" +features = [ + "eq", + # "ord", + "json", + "fmt", +] + +includes = [ + "utils/graph.h", + "", + "substitutions/unlabelled/pattern_node.dtg.h", +] + +[[fields]] +name = "first" +type = "std::unordered_set<::FlexFlow::PatternNode>" + +[[fields]] +name = "second" +type = "std::unordered_set<::FlexFlow::PatternNode>" diff --git a/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.dtg.h b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.dtg.h new file mode 100644 index 0000000000..a2ba6c26d2 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.dtg.h @@ -0,0 +1,24 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.struct.toml +/* proj-data +{ + "generated_from": "f494ed79eb1ba4010155e456b452157f" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UNLABELLED_GRAPH_PATTERN_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UNLABELLED_GRAPH_PATTERN_DTG_H + +#include "utils/graph.h" + +namespace FlexFlow { +struct UnlabelledGraphPattern { + UnlabelledGraphPattern() = delete; + UnlabelledGraphPattern(::FlexFlow::OpenMultiDiGraphView const &raw_graph); + + ::FlexFlow::OpenMultiDiGraphView raw_graph; +}; +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UNLABELLED_GRAPH_PATTERN_DTG_H diff --git a/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.h b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.h new file mode 100644 index 0000000000..9bb63037be --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UNLABELLED_GRAPH_PATTERN_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UNLABELLED_GRAPH_PATTERN_H + +#include "substitutions/unlabelled/downward_open_pattern_edge.dtg.h" +#include "substitutions/unlabelled/pattern_edge.dtg.h" +#include "substitutions/unlabelled/pattern_node.dtg.h" +#include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" +#include "substitutions/unlabelled/upward_open_pattern_edge.dtg.h" + +namespace FlexFlow { + +size_t num_nodes(UnlabelledGraphPattern const &); +bool is_singleton_pattern(UnlabelledGraphPattern const &); +std::unordered_set get_nodes(UnlabelledGraphPattern const &); +std::unordered_set get_edges(UnlabelledGraphPattern const &); +std::vector + get_topological_ordering(UnlabelledGraphPattern const &); + +std::unordered_set + get_incoming_edges(UnlabelledGraphPattern const &, PatternNode const &); +std::unordered_set + get_outgoing_edges(UnlabelledGraphPattern const &, PatternNode const &); + +UnlabelledGraphPattern get_subgraph(UnlabelledGraphPattern const &, + std::unordered_set const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.struct.toml b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.struct.toml new file mode 100644 index 0000000000..03f4bd5523 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.struct.toml @@ -0,0 +1,10 @@ +namespace = "FlexFlow" +name = "UnlabelledGraphPattern" +features = [] +includes = [ + "utils/graph.h" +] + +[[fields]] +name = "raw_graph" +type = "::FlexFlow::OpenMultiDiGraphView" diff --git a/lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.dtg.h b/lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.dtg.h new file mode 100644 index 0000000000..82440b5820 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.dtg.h @@ -0,0 +1,39 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.struct.toml +/* proj-data +{ + "generated_from": "a1d4c9d1dd94eb456c5e29d80ad579da" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UPWARD_OPEN_PATTERN_EDGE_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UPWARD_OPEN_PATTERN_EDGE_DTG_H + +#include "utils/graph.h" +#include +#include + +namespace FlexFlow { +struct UpwardOpenPatternEdge { + UpwardOpenPatternEdge() = delete; + UpwardOpenPatternEdge(::FlexFlow::UpwardOpenMultiDiEdge const &raw_edge); + + bool operator==(UpwardOpenPatternEdge const &) const; + bool operator!=(UpwardOpenPatternEdge const &) const; + bool operator<(UpwardOpenPatternEdge const &) const; + bool operator>(UpwardOpenPatternEdge const &) const; + bool operator<=(UpwardOpenPatternEdge const &) const; + bool operator>=(UpwardOpenPatternEdge const &) const; + ::FlexFlow::UpwardOpenMultiDiEdge raw_edge; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::UpwardOpenPatternEdge const &) const; +}; +} // namespace std + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UPWARD_OPEN_PATTERN_EDGE_DTG_H diff --git a/lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.h b/lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.h new file mode 100644 index 0000000000..998cf1a519 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UPWARD_OPEN_PATTERN_EDGE_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UPWARD_OPEN_PATTERN_EDGE_H + +#include "substitutions/unlabelled/upward_open_pattern_edge.dtg.h" + +namespace FlexFlow { + +int get_dst_idx(UpwardOpenPatternEdge const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.struct.toml b/lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.struct.toml new file mode 100644 index 0000000000..a4c3bad809 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.struct.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "UpwardOpenPatternEdge" +features = [ + "eq", + "ord", + "hash", +] + +includes = [ + "utils/graph.h", +] + +[[fields]] +name = "raw_edge" +type = "::FlexFlow::UpwardOpenMultiDiEdge" diff --git a/lib/substitutions/src/graph_pattern.cc b/lib/substitutions/src/graph_pattern.cc deleted file mode 100644 index 296a975626..0000000000 --- a/lib/substitutions/src/graph_pattern.cc +++ /dev/null @@ -1,257 +0,0 @@ -#include "substitutions/graph_pattern.h" -#include "op-attrs/operator_attrs.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "pcg/parallel_computation_graph.h" -#include "substitutions/get_attribute.h" -#include "substitutions/graph_pattern_match.h" -#include "substitutions/operator_pattern.h" -#include "substitutions/parallel_tensor_pattern.h" - -namespace FlexFlow { - -std::optional - evaluate_list_index_access(int index, - std::optional const &v) { - if (!v.has_value() || - !std::holds_alternative>(v.value()) || - !std::holds_alternative>( - v.value())) { - return std::nullopt; - } - - if (index >= MAX_TENSOR_DIM) { - return std::nullopt; - } - - if (std::holds_alternative>(v.value())) { - return get>(v.value()).at(index); - } else { - return get>(v.value()).at(index); - } -} - -std::optional - evaluate_list_index_access(int const &index, - std::optional const &v) { - if (!v.has_value() || !std::holds_alternative>(v.value())) { - return std::nullopt; - } - - auto vec = get>(v.value()); - - if (index >= vec.size()) { - return std::nullopt; - } - - return vec.at(index); -} - -std::optional - evaluate_list_size(std::optional const &v) { - return MAX_TENSOR_DIM; -} - -std::optional - evaluate_list_size(std::optional const &v) { - if (!v.has_value() || !std::holds_alternative>(v.value())) { - return std::nullopt; - } - - return (int)get>(v.value()).size(); -} - -struct EvaluateOperatorAttributeExpr { - EvaluateOperatorAttributeExpr(Operator const &attrs) : attrs(attrs) {} - - std::optional - operator()(OperatorAttributeKey const &key) { - return get_attribute(this->attrs.attrs, key); - } - - std::optional - operator()(ListIndexAccess const &index_access) { - std::optional v = - get_attribute(this->attrs.attrs, index_access.attribute_key); - return evaluate_list_index_access(index_access.index, v); - } - - std::optional - operator()(ListSize const &list_size) { - std::optional v = - get_attribute(this->attrs.attrs, list_size.attribute_key); - return evaluate_list_size(v); - } - -private: - Operator attrs; -}; - -std::optional - evaluate_tensor_attribute_expr(ParallelTensor const &, - AttributeExpr const &); - -struct EvaluateTensorAttributeExpr { - EvaluateTensorAttributeExpr(ParallelTensor const &tensor_shape) - : tensor_shape(tensor_shape) {} - - template - std::optional evaluate(T const &t) { - return this->operator()(t); - } - - std::optional operator()(TensorAttributeKey key) { - switch (key) { - case TensorAttributeKey::DIM_SIZES: { - std::vector result; - for (ParallelDim const &dim : this->tensor_shape.dims) { - result.push_back(dim.size); - } - return result; - } - case TensorAttributeKey::DIM_DEGREES: { - std::vector result; - for (ParallelDim const &dim : this->tensor_shape.dims) { - result.push_back(dim.degree); - } - return result; - } - default: - throw std::runtime_error("Unknown TensorAttributeKey"); - } - } - - std::optional - operator()(ListIndexAccess const &index_access) { - std::optional v = - this->evaluate(index_access.attribute_key); - return evaluate_list_index_access(index_access.index, v); - } - - std::optional - operator()(ListSize const &list_size) { - return evaluate_list_size(this->evaluate(list_size.attribute_key)); - } - -private: - ParallelTensor tensor_shape; -}; - -std::optional - evaluate_attribute_expr(ParallelTensor const &tensor_shape, - AttributeExpr const &expr) { - return visit(EvaluateTensorAttributeExpr(tensor_shape), expr); -} - -std::optional - evaluate_attribute_expr(Operator const &attrs, - AttributeExpr const &expr) { - return visit(EvaluateOperatorAttributeExpr(attrs), expr); -} - -template -std::optional satisfies(ConstraintType constraint_type, - V const &constraint_value, - std::optional const &maybe_attribute_value) { - if (!maybe_attribute_value.has_value()) { - return std::nullopt; - } - V attr_val = maybe_attribute_value.value(); - - if (attr_val.index() != constraint_value.index()) { - return std::nullopt; - } - - if (constraint_type == ConstraintType::EQUAL) { - return attr_val == constraint_value; - } else { - throw std::runtime_error("Unknown constraint_type"); - } -} - -std::optional satisfies(ParallelTensor const &tensor_shape, - TensorAttributeConstraint const &constraint) { - auto value = evaluate_attribute_expr(tensor_shape, constraint.attribute_expr); - return satisfies( - constraint.constraint_type, constraint.attribute_value, value); -} - -std::optional satisfies(Operator const ¶ms, - OperatorAttributeConstraint const &constraint) { - auto value = evaluate_attribute_expr(params, constraint.attribute_expr); - OperatorAttributeValue v = value.value(); - return satisfies( - constraint.constraint_type, constraint.attribute_value, value); -} - -template -std::optional optional_all_of(Container const &container, - Function const &func) { - for (auto const &element : container) { - std::optional condition = func(element); - if (!condition.has_value()) { - return std::nullopt; - } - - if (!condition.value()) { - return false; - } - } - return true; -} - -std::optional satisfies(Operator const ¶ms, - OperatorPattern const &pattern) { - return optional_all_of(pattern.attribute_constraints, - [&](OperatorAttributeConstraint const &c) { - return satisfies(params, c); - }); -} - -std::optional satisfies(ParallelTensor const ¶ms, - ParallelTensorPattern const &pattern) { - return optional_all_of( - pattern.attribute_constraints, - [&](TensorAttributeConstraint const &c) { return satisfies(params, c); }); -} - -bool operator_satisfies(Operator const ¶ms, - OperatorPattern const &pattern) { - return satisfies(params, pattern).value_or(false); -} - -bool parallel_tensor_satisfies(ParallelTensor const ¶ms, - ParallelTensorPattern const &pattern) { - return satisfies(params, pattern).value_or(false); -} - -bool assignment_satisfies(SubParallelComputationGraph const &pcg, - GraphPattern const &pattern, - MultiDiGraphPatternMatch const &patternMatch) { - bool result = true; - for (auto const &kv : patternMatch.node_assignment) { - Node patternNode = kv.first; - Node pcgNode = kv.second; - std::optional constraintResult = - satisfies(pcg.at(pcgNode), pattern.value().at(patternNode)); - result &= constraintResult.value_or(false); - } - - for (auto const &kv : patternMatch.edge_assignment) { - OpenMultiDiEdge patternEdge = kv.first; - OpenMultiDiEdge pcgEdge = kv.second; - std::optional constraintResult = - satisfies(pcg.at(pcgEdge), pattern.value().at(patternEdge)); - result &= constraintResult.value_or(false); - } - - result &= pattern_matches( - pattern, - pcg, - patternMatch, - MatchAdditionalCriterion{[](Node const &, Node const &) { return true; }, - [](OpenMultiDiEdge const &, - OpenMultiDiEdge const &) { return true; }}); - - return result; -} -} // namespace FlexFlow diff --git a/lib/substitutions/src/graph_pattern_match.cc b/lib/substitutions/src/graph_pattern_match.cc deleted file mode 100644 index f9c6b9a773..0000000000 --- a/lib/substitutions/src/graph_pattern_match.cc +++ /dev/null @@ -1,305 +0,0 @@ -#include "substitutions/graph_pattern.h" -#include "utils/hash-utils.h" -#include - -namespace FlexFlow { - -GraphSplit split_pattern(OpenMultiDiGraphView const &pattern) { - std::vector topological_ordering = get_topological_ordering(pattern); - assert(topological_ordering.size() >= 2); - - int split_point = topological_ordering.size() / 2; - auto split = vector_split(topological_ordering, split_point); - std::unordered_set prefix(split.first.begin(), split.first.end()); - std::unordered_set postfix(split.second.begin(), split.second.end()); - return {prefix, postfix}; -} - -std::pair - apply_split(OpenMultiDiGraphView const &pattern, GraphSplit const &split) { - return {get_subgraph(pattern, split.first), - get_subgraph(pattern, split.second)}; -} - -/* -Given a match and a pattern split, gets the submatches in subpatterns. -*/ -MatchSplit apply_split(OpenMultiDiGraphView const &pattern, - MultiDiGraphPatternMatch const &match, - GraphSplit const &split) { - auto prefix = split.first; - auto postfix = split.second; - - MatchSplit result; - - for (auto const &kv : match.node_assignment) { - Node pattern_node = kv.first; - Node graph_node = kv.second; - if (contains(split.first, pattern_node)) { - result.prefix_submatch.node_assignment.equate(pattern_node, graph_node); - } else { - assert(contains(split.second, pattern_node)); - result.postfix_submatch.node_assignment.equate(pattern_node, graph_node); - } - } - - auto edge_splits = get_edge_splits(pattern, split); - - std::function - handle_edge = [&](OpenMultiDiEdge const &pattern_edge, - OpenMultiDiEdge const &graph_edge) -> void { - auto edge_nodes = get_nodes(pattern_edge); - if (is_subseteq_of(edge_nodes, prefix)) { - result.prefix_submatch.edge_assignment.equate(pattern_edge, graph_edge); - } else if (is_subseteq_of(edge_nodes, postfix)) { - result.postfix_submatch.edge_assignment.equate(pattern_edge, graph_edge); - } else { - assert(is_standard_edge(pattern_edge)); - assert(is_standard_edge(graph_edge)); - auto standard_edge = std::get(pattern_edge); - auto divided = edge_splits.at_l(standard_edge); - auto divided_graph_edge = split_edge(get(graph_edge)); - handle_edge(divided.first, divided_graph_edge.first); - handle_edge(divided.second, divided_graph_edge.second); - } - }; - - for (auto const &kv : match.edge_assignment) { - OpenMultiDiEdge pattern_edge = kv.first; - OpenMultiDiEdge graph_edge = match.edge_assignment.at_l(pattern_edge); - handle_edge(pattern_edge, graph_edge); - } - - return result; -} - -bool is_singleton_pattern(OpenMultiDiGraphView const &pattern) { - return num_nodes(pattern) == 1; -} - -bool pattern_matches(OpenMultiDiGraphView const &pattern, - OpenMultiDiGraphView const &graph, - MultiDiGraphPatternMatch const &match, - MatchAdditionalCriterion const &additional_criterion) { - if (is_singleton_pattern(pattern)) { - Node pattern_node = get_only(get_nodes(pattern)); - Node graph_matched_node = match.node_assignment.at_l(pattern_node); - if (!additional_criterion.node_criterion(pattern_node, - graph_matched_node)) { - return false; - } - for (OpenMultiDiEdge const &e : get_edges(pattern)) { - OpenMultiDiEdge graph_matched_edge = match.edge_assignment.at_l(e); - - assert(is_input_edge(e) || is_output_edge(e)); - if (is_input_edge(e)) { - if (is_output_edge(graph_matched_edge)) { - return false; - } - UpwardOpenMultiDiEdge matched_edge = - narrow(graph_matched_edge).value(); - InputMultiDiEdge input_edge = std::get(e); - if (match.node_assignment.at_l(input_edge.dst) != - get_dst_node(matched_edge)) { - return false; - } - } else { - if (is_input_edge(graph_matched_edge)) { - return false; - } - DownwardOpenMultiDiEdge matched_edge = - narrow(graph_matched_edge).value(); - OutputMultiDiEdge output_edge = std::get(e); - if (match.node_assignment.at_l(output_edge.src) != - get_src_node(matched_edge)) { - return false; - } - } - - if (!additional_criterion.edge_criterion(e, graph_matched_edge)) { - return false; - } - } - - return true; - } - - auto split = split_pattern(pattern); - auto subpatterns = apply_split(pattern, split); - auto submatches = apply_split(pattern, match, split); - - return pattern_matches(subpatterns.first, - graph, - submatches.prefix_submatch, - additional_criterion) && - pattern_matches(subpatterns.second, - graph, - submatches.postfix_submatch, - additional_criterion); -} - -template -bool dst_compare(T const &lhs, T const &rhs) { - return get_dst_idx(lhs) < get_dst_idx(rhs); -} - -template -bool src_compare(T const &lhs, T const &rhs) { - return get_src_idx(lhs) < get_src_idx(rhs); -} - -std::optional - get_candidate_singleton_match(OpenMultiDiGraphView const &pattern, - OpenMultiDiGraphView const &graph, - Node const &graph_node) { - assert(is_singleton_pattern(pattern)); - - Node pattern_node = get_only(get_nodes(pattern)); - - MultiDiGraphPatternMatch match; - match.node_assignment.equate(pattern_node, graph_node); - - std::unordered_set incoming = - get_incoming_edges(graph, graph_node); - std::unordered_set outgoing = - get_outgoing_edges(graph, graph_node); - - std::unordered_set pattern_incoming = - get_incoming_edges(pattern, pattern_node); - std::unordered_set pattern_outgoing = - get_outgoing_edges(pattern, pattern_node); - - if (!pattern_incoming.empty() && pattern_incoming.size() != incoming.size()) { - return std::nullopt; - } - - if (!pattern_outgoing.empty() && pattern_outgoing.size() != outgoing.size()) { - return std::nullopt; - } - - std::vector incoming_ordered = - sorted_by(incoming, dst_compare); - std::vector outgoing_ordered = - sorted_by(outgoing, src_compare); - - std::vector pattern_incoming_ordered = - sorted_by(pattern_incoming, dst_compare); - std::vector pattern_outgoing_ordered = - sorted_by(pattern_outgoing, src_compare); - - if (pattern_incoming.size()) { - std::unordered_map node_port_mapping; - for (int i = 0; i < incoming_ordered.size(); ++i) { - UpwardOpenMultiDiEdge graph_edge = incoming_ordered[i], - pattern_edge = pattern_incoming_ordered[i]; - NodePort graph_port = get_dst_idx(graph_edge), - pattern_port = get_dst_idx(pattern_edge); - if (!contains_key(node_port_mapping, graph_port)) { - node_port_mapping.emplace(graph_port, pattern_port); - } else { - if (pattern_port != node_port_mapping.at(graph_port)) { - return std::nullopt; - } - } - match.edge_assignment.equate(widen(pattern_edge), - widen(graph_edge)); - } - } - - if (pattern_outgoing.size()) { - std::unordered_map node_port_mapping; - for (int i = 0; i < outgoing_ordered.size(); ++i) { - DownwardOpenMultiDiEdge graph_edge = outgoing_ordered[i], - pattern_edge = pattern_outgoing_ordered[i]; - NodePort graph_port = get_src_idx(graph_edge), - pattern_port = get_src_idx(pattern_edge); - if (!contains_key(node_port_mapping, graph_port)) { - node_port_mapping.insert({graph_port, pattern_port}); - } else { - if (pattern_port != node_port_mapping.at(graph_port)) { - return std::nullopt; - } - } - match.edge_assignment.equate(widen(pattern_edge), - widen(graph_edge)); - } - } - - return match; -} - -std::optional unsplit_matches( - MultiDiGraphPatternMatch const &prefix, - MultiDiGraphPatternMatch const &postfix, - bidict> const - &edge_splits) { - MultiDiGraphPatternMatch result; - std::unordered_set handled; - for (auto const &kv : edge_splits) { - MultiDiEdge standard_edge = kv.first; - OutputMultiDiEdge output_edge = kv.second.first; - InputMultiDiEdge input_edge = kv.second.second; - handled.insert(output_edge); - handled.insert(input_edge); - - OpenMultiDiEdge output_graph_edge = - prefix.edge_assignment.at_l(output_edge); - OpenMultiDiEdge input_graph_edge = postfix.edge_assignment.at_l(input_edge); - if (output_graph_edge == input_graph_edge) { - result.edge_assignment.equate(standard_edge, output_graph_edge); - } else { - return std::nullopt; - } - } - - for (auto const &kv : - merge_maps(prefix.edge_assignment, postfix.edge_assignment)) { - if (!contains(handled, kv.first)) { - result.edge_assignment.equate(kv.first, kv.second); - } - } - - result.node_assignment = - merge_maps(prefix.node_assignment, postfix.node_assignment); - - return result; -} - -std::vector - find_pattern_matches(OpenMultiDiGraphView const &pattern, - OpenMultiDiGraphView const &graph, - MatchAdditionalCriterion const &additional_criterion) { - std::vector matches; - if (is_singleton_pattern(pattern)) { - for (Node const &graph_node : get_nodes(graph)) { - std::optional candidate = - get_candidate_singleton_match(pattern, graph, graph_node); - if (candidate.has_value() && - pattern_matches( - pattern, graph, candidate.value(), additional_criterion)) { - matches.push_back(candidate.value()); - } - } - } else { - GraphSplit split = split_pattern(pattern); - auto subpatterns = apply_split(pattern, split); - auto prefix_matches = - find_pattern_matches(subpatterns.first, graph, additional_criterion); - auto postfix_matches = - find_pattern_matches(subpatterns.second, graph, additional_criterion); - auto edge_splits = get_edge_splits(pattern, split); - for (MultiDiGraphPatternMatch const &prefix_match : prefix_matches) { - for (MultiDiGraphPatternMatch const &postfix_match : postfix_matches) { - std::optional unsplit = - unsplit_matches(prefix_match, postfix_match, edge_splits); - if (unsplit.has_value()) { - matches.push_back(unsplit.value()); - } - } - } - } - - return matches; -} - -} // namespace FlexFlow diff --git a/lib/substitutions/src/sub_parallel_computation_graph.cc b/lib/substitutions/src/sub_parallel_computation_graph.cc deleted file mode 100644 index e8cb093222..0000000000 --- a/lib/substitutions/src/sub_parallel_computation_graph.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "substitutions/sub_parallel_computation_graph.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/substitutions/src/substitution.cc b/lib/substitutions/src/substitution.cc index 15816185ee..94993f3c90 100644 --- a/lib/substitutions/src/substitution.cc +++ b/lib/substitutions/src/substitution.cc @@ -2,475 +2,386 @@ namespace FlexFlow { -struct DeriveValidOperatorAttributeExpr { - template - std::unordered_set> - operator()(T const &t) { - return derive_valid_operator_attribute_expr(t); - } - - std::unordered_set> - derive_valid_operator_attribute_expr(OperatorAttributeKey const &key) { - return {key}; - } - - std::unordered_set> - derive_valid_operator_attribute_expr( - ListIndexAccess const &access) { - return {access, access.attribute_key}; - } - - std::unordered_set> - derive_valid_operator_attribute_expr( - ListSize const &ls) { - return {ls, ls.attribute_key}; - } -}; - -std::unordered_set> - get_valid_operator_attribute_exprs(OperatorPattern const &pattern) { - return set_union(transform( - pattern.attribute_constraints, [](OperatorAttributeConstraint const &t) { - return visit(DeriveValidOperatorAttributeExpr{}, t.attribute_expr); - })); -} - -bool is_valid_operator_attribute_expr( - OperatorPattern const &pattern, - AttributeExpr const &expr) { - return contains(get_valid_operator_attribute_exprs(pattern), expr); -} - -struct IsValidOperatorAttributeExprFunctor { - GraphPattern const &graph_pattern; - - template - bool operator()(T const &t) const { - return is_valid(t); - } - - bool is_valid(OperatorAttrAccess const &t) const { - return is_valid_operator_attribute_expr(graph_pattern.value().at(t.node), - t.attr_expr); - } - - bool is_valid(AttrConstant const &t) const { - return true; - } -}; - -bool is_valid_operator_attribute_expr(GraphPattern const &pattern, - OperatorAttributeExpr const &expr) { - return visit(IsValidOperatorAttributeExprFunctor{pattern}, expr); -} - -bool is_valid_substitution(Substitution const &s) { - for (Node const &node : get_nodes(s.output_graph_expr.value())) { - for (OperatorAttributeExpr expr : - values(s.output_graph_expr.value().at(node).assignments)) { - if (!is_valid_operator_attribute_expr(s.input_graph, expr)) { - return false; - } - } - } - return true; -} - -struct EvaluateOperatorAttributeExpr { - SubParallelComputationGraph const &graph; - MultiDiGraphPatternMatch const &match; - - template - OperatorAttributeValue operator()(T const &t) { - return evaluate(t); - } - - OperatorAttributeValue evaluate(OperatorAttrAccess const &t) { - Node node_in_pattern = t.node; - Node node_in_pcg = match.node_assignment.at_l(node_in_pattern); - return evaluate_attribute_expr(graph.at(node_in_pcg), t.attr_expr).value(); - } - - OperatorAttributeValue evaluate(AttrConstant const &t) { - return t.value; - } -}; - -OperatorAttributeValue - evaluate_graph_attribute_expr(SubParallelComputationGraph const &g, - MultiDiGraphPatternMatch const &match, - OperatorAttributeExpr const &expr) { - return visit(EvaluateOperatorAttributeExpr{g, match}, expr); -} - -Operator get_operator_attrs(SubParallelComputationGraph const &graph, - MultiDiGraphPatternMatch const &match, - OperatorAttrAssignment const &assignment) { - std::unordered_map assignments; - for (auto const &[key, expr] : assignment.assignments) { - OperatorAttributeValue value = - evaluate_graph_attribute_expr(graph, match, expr); - assignments.emplace(key, value); - } - assert(contains_key(assignments, OperatorAttributeKey::OP_TYPE)); - assert(std::holds_alternative( - assignments.at(OperatorAttributeKey::OP_TYPE))); - OperatorType op_type = - std::get(assignments.at(OperatorAttributeKey::OP_TYPE)); - switch (op_type) { - case Op::BATCHMATMUL: - return Operator{ - BatchMatmulAttrs{std::get(assignments.at( - OperatorAttributeKey::A_SEQ_LENGTH_DIM)), - std::get(assignments.at( - OperatorAttributeKey::B_SEQ_LENGTH_DIM))}, - std::nullopt}; - case Op::BATCHNORM: - return Operator{BatchNormAttrs{std::get( - assignments.at(OperatorAttributeKey::RELU))}, - std::nullopt}; - case Op::CAST: - return Operator{CastAttrs{std::get( - assignments.at(OperatorAttributeKey::DATA_TYPE))}, - std::nullopt}; - case Op::CONCAT: - return Operator{ - ConcatAttrs{ - std::get(assignments.at(OperatorAttributeKey::AXIS)), - std::get(assignments.at(OperatorAttributeKey::NUM_INPUTS))}, - std::nullopt}; - case Op::CONV2D: - return Operator{ - Conv2DAttrs{ - std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), - std::get(assignments.at(OperatorAttributeKey::KERNEL_H)), - std::get(assignments.at(OperatorAttributeKey::KERNEL_W)), - std::get(assignments.at(OperatorAttributeKey::STRIDE_H)), - std::get(assignments.at(OperatorAttributeKey::STRIDE_W)), - std::get(assignments.at(OperatorAttributeKey::PADDING_H)), - std::get(assignments.at(OperatorAttributeKey::PADDING_W)), - std::get(assignments.at(OperatorAttributeKey::GROUPS)), - std::get( - assignments.at(OperatorAttributeKey::ACTIVATION)), - std::get(assignments.at(OperatorAttributeKey::USE_BIAS))}, - std::nullopt}; - case Op::DROPOUT: - return Operator{DropoutAttrs{std::get(assignments.at( - OperatorAttributeKey::RATE)), - std::get(assignments.at( - OperatorAttributeKey::SEED))}, - std::nullopt}; - case Op::EW_ADD: - case Op::EW_DIV: - case Op::EW_EQUAL: - case Op::EW_GREATER: - case Op::EW_LESS: - case Op::EW_MAX: - case Op::EW_MIN: - case Op::EW_MUL: - case Op::EW_SUB: - return Operator{ - ElementBinaryAttrs{op_type, - std::get(assignments.at( - OperatorAttributeKey::DATA_TYPE)), - std::get(assignments.at( - OperatorAttributeKey::SHOULD_BROADCAST_LHS)), - std::get(assignments.at( - OperatorAttributeKey::SHOULD_BROADCAST_RHS))}, - std::nullopt}; - case Op::SCALAR_ADD: - case Op::SCALAR_FLOOR_DIV: - case Op::SCALAR_MULTIPLY: - case Op::SCALAR_SUB: - case Op::SCALAR_TRUE_DIV: - return Operator{ - ElementScalarUnaryAttrs{ - op_type, - std::get(assignments.at(OperatorAttributeKey::SCALAR))}, - std::nullopt}; - case Op::EXP: - case Op::IDENTITY: - case Op::GELU: - case Op::RSQRT: - case Op::POW: - case Op::SIN: - case Op::COS: - return Operator{ElementUnaryAttrs{op_type}, std::nullopt}; - case Op::EMBEDDING: - return Operator{ - EmbeddingAttrs{ - std::get(assignments.at(OperatorAttributeKey::NUM_ENTRIES)), - std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), - std::get(assignments.at(OperatorAttributeKey::AGGR)), - std::get( - assignments.at(OperatorAttributeKey::OP_TYPE))}, - std::nullopt}; - case Op::FLAT: - return Operator{FlatAttrs{}, std::nullopt}; - case Op::GATHER: - return Operator{GatherAttrs{std::get( - assignments.at(OperatorAttributeKey::DIM))}, - std::nullopt}; - case Op::INPUT: - return Operator{InputAttrs{}, std::nullopt}; - case Op::LAYERNORM: - return Operator{ - LayerNormAttrs{ - std::get>( - assignments.at(OperatorAttributeKey::AXES)), - std::get( - assignments.at(OperatorAttributeKey::ELEMENTWISE_AFFINE)), - std::get(assignments.at(OperatorAttributeKey::EPSILON))}, - std::nullopt}; - case Op::LINEAR: - return Operator{ - LinearAttrs{ - std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), - std::get(assignments.at(OperatorAttributeKey::USE_BIAS)), - std::get( - assignments.at(OperatorAttributeKey::DATA_TYPE)), - std::get( - assignments.at(OperatorAttributeKey::ACTIVATION)), - std::get>( - assignments.at(OperatorAttributeKey::REGULARIZER))}, - std::nullopt}; - case Op::MULTIHEAD_ATTENTION: - return Operator{ - MultiHeadAttentionAttrs{ - std::get(assignments.at(OperatorAttributeKey::EMBED_DIM)), - std::get(assignments.at(OperatorAttributeKey::NUM_HEADS)), - std::get(assignments.at(OperatorAttributeKey::NUM_HEADS)), - std::get(assignments.at(OperatorAttributeKey::VDIM)), - std::get(assignments.at(OperatorAttributeKey::DROPOUT)), - std::get(assignments.at(OperatorAttributeKey::BIAS)), - std::get(assignments.at(OperatorAttributeKey::ADD_BIAS_KV)), - std::get( - assignments.at(OperatorAttributeKey::ADD_ZERO_ATTN))}, - std::nullopt}; - case Op::NOOP: - return Operator{NoopAttrs{}, std::nullopt}; - case Op::POOL2D: - return Operator{ - Pool2DAttrs{ - std::get(assignments.at(OperatorAttributeKey::KERNEL_H)), - std::get(assignments.at(OperatorAttributeKey::KERNEL_W)), - std::get(assignments.at(OperatorAttributeKey::STRIDE_H)), - std::get(assignments.at(OperatorAttributeKey::STRIDE_W)), - std::get(assignments.at(OperatorAttributeKey::PADDING_H)), - std::get(assignments.at(OperatorAttributeKey::PADDING_W)), - std::get(assignments.at(OperatorAttributeKey::POOL_TYPE)), - std::get( - assignments.at(OperatorAttributeKey::ACTIVATION))}, - std::nullopt}; - case Op::REDUCE_ARGMAX: - case Op::REDUCE_ARGMIN: - case Op::REDUCE_MAX: - case Op::REDUCE_MEAN: - case Op::REDUCE_MIN: - case Op::REDUCE_PROD: - case Op::REDUCE_SUM: - return Operator{ - ReduceAttrs{ - std::get>( - assignments.at(OperatorAttributeKey::AXES)), - op_type, - std::get(assignments.at(OperatorAttributeKey::KEEP_DIMS))}, - std::nullopt}; - case Op::REVERSE: - return Operator{ReverseAttrs{std::get( - assignments.at(OperatorAttributeKey::AXIS))}, - std::nullopt}; - case Op::RESHAPE: - return Operator{ReshapeAttrs{std::get( - assignments.at(OperatorAttributeKey::SHAPE))}, - std::nullopt}; - case Op::SPLIT: - return Operator{ - SplitAttrs{ - std::get>( - assignments.at(OperatorAttributeKey::SPLITS)), - std::get(assignments.at(OperatorAttributeKey::AXIS))}, - std::nullopt}; - case Op::SOFTMAX: - return Operator{SoftmaxAttrs{std::get( - assignments.at(OperatorAttributeKey::DIM))}, - std::nullopt}; - case Op::TOPK: - return Operator{ - TopKAttrs{ - std::get(assignments.at(OperatorAttributeKey::K)), - std::get(assignments.at(OperatorAttributeKey::SORTED))}, - std::nullopt}; - case Op::TRANSPOSE: - return Operator{ - TransposeAttrs{std::get>( - assignments.at(OperatorAttributeKey::PERMUTATION))}, - std::nullopt}; - case Op::COMBINE: - return Operator{CombineAttrs{std::get(assignments.at( - OperatorAttributeKey::PARALLEL_DIM)), - std::get(assignments.at( - OperatorAttributeKey::PARALLEL_DEGREE))}, - std::nullopt}; - case Op::REDUCTION: - return Operator{ - ReductionAttrs{std::get(assignments.at( - OperatorAttributeKey::PARALLEL_DIM)), - std::get(assignments.at( - OperatorAttributeKey::PARALLEL_DEGREE))}, - std::nullopt}; - case Op::REPARTITION: - return Operator{ - RepartitionAttrs{std::get(assignments.at( - OperatorAttributeKey::PARALLEL_DIM)), - std::get(assignments.at( - OperatorAttributeKey::PARALLEL_DEGREE))}, - std::nullopt}; - case Op::REPLICATE: - return Operator{ - ReplicateAttrs{std::get(assignments.at( - OperatorAttributeKey::PARALLEL_DIM)), - std::get(assignments.at( - OperatorAttributeKey::PARALLEL_DEGREE))}, - std::nullopt}; - default: - throw mk_runtime_error("Unknown Operator"); - } -} - -struct AddMappedEdgeFunctor { - bidict const &node_mapping; - SubParallelComputationGraph &new_pcg; - - template - void operator()(T const &t) { - return add_mapped_edge(t); - } - - void add_mapped_edge(InputMultiDiEdge const &e) { - new_pcg.add_edge(InputMultiDiEdge{ - node_mapping.at_l(e.dst), new_pcg.add_node_port(), e.uid}); - } - - void add_mapped_edge(OutputMultiDiEdge const &e) { - new_pcg.add_edge(OutputMultiDiEdge{ - node_mapping.at_l(e.src), new_pcg.add_node_port(), e.uid}); - } - - void add_mapped_edge(MultiDiEdge const &e) { - new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(e.dst), - new_pcg.add_node_port(), - node_mapping.at_l(e.src), - new_pcg.add_node_port()}); - } -}; - -struct AddNewEdgeFunctor { - SubParallelComputationGraph const &old_pcg; - SubParallelComputationGraph &new_pcg; - MultiDiGraphPatternMatch const &match; - bidict node_mapping; - - template - void operator()(TO const &old_edge, TN const &new_edge) { - return add_new_edge(old_edge, new_edge); - } - - void add_new_edge(InputMultiDiEdge const &old_edge, - InputMultiDiEdge const &new_edge) { - new_pcg.add_edge(InputMultiDiEdge{node_mapping.at_l(new_edge.dst), - new_pcg.add_node_port(), - old_edge.uid}); - } - - void add_new_edge(MultiDiEdge const &old_edge, - InputMultiDiEdge const &new_edge) { - new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(new_edge.dst), - new_pcg.add_node_port(), - node_mapping.at_l(old_edge.src), - new_pcg.add_node_port()}); - } - - void add_new_edge(OutputMultiDiEdge const &old_edge, - OutputMultiDiEdge const &new_edge) { - new_pcg.add_edge(OutputMultiDiEdge{node_mapping.at_l(new_edge.src), - new_pcg.add_node_port(), - old_edge.uid}); - } - - void add_new_edge(MultiDiEdge const &old_edge, - OutputMultiDiEdge const &new_edge) { - new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(old_edge.dst), - new_pcg.add_node_port(), - node_mapping.at_l(new_edge.src), - new_pcg.add_node_port()}); - } - - void add_new_edge(InputMultiDiEdge const &, OutputMultiDiEdge const &) { - assert(false); - } - - void add_new_edge(OpenMultiDiEdge const &, MultiDiEdge const &) { - assert(false); - } - - void add_new_edge(OutputMultiDiEdge const &, InputMultiDiEdge const &) { - assert(false); - } -}; - -SubParallelComputationGraph - apply_substitution(SubParallelComputationGraph const &pcg, - Substitution const &substitution, - MultiDiGraphPatternMatch const &match) { - SubParallelComputationGraph new_pcg = - OutputLabelledOpenMultiDiGraph::template create< - UnorderedOutputLabelledOpenMultiDiGraph>(); - bidict node_mapping; // Refactor it with global nodes - for (Node const &node : get_nodes(pcg)) { - if (!contains_r(match.node_assignment, node)) { - node_mapping.equate(node, new_pcg.add_node(pcg.at(node))); - } - } - for (OpenMultiDiEdge const &edge : get_edges(pcg)) { - if (!contains_r(match.edge_assignment, edge)) { - visit(AddMappedEdgeFunctor{node_mapping, new_pcg}, edge); - } - } - for (Node const &output_node : - get_nodes(substitution.output_graph_expr.value())) { - Operator new_op = get_operator_attrs( - pcg, match, substitution.output_graph_expr.value().at(output_node)); - Node new_node = new_pcg.add_node(new_op); - node_mapping.equate(output_node, new_node); - } - for (OpenMultiDiEdge const &output_edge : - get_edges(substitution.output_graph_expr.value())) { - if (std::holds_alternative(output_edge)) { - InputMultiDiEdge e = std::get(output_edge); - OpenMultiDiEdge original_edge = - match.edge_assignment.at_l(substitution.input_mapping.at_r(e)); - visit(AddNewEdgeFunctor{pcg, new_pcg, match, node_mapping}, - original_edge, - output_edge); - } else if (std::holds_alternative(output_edge)) { - OutputMultiDiEdge e = std::get(output_edge); - OpenMultiDiEdge original_edge = - match.edge_assignment.at_l(substitution.output_mapping.at_r(e)); - visit(AddNewEdgeFunctor{pcg, new_pcg, match, node_mapping}, - original_edge, - output_edge); - } else { - assert(std::holds_alternative(output_edge)); - MultiDiEdge e = std::get(output_edge); - new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(e.dst), - new_pcg.add_node_port(), - node_mapping.at_l(e.src), - new_pcg.add_node_port()}); - } - } - - return new_pcg; -} +/* struct DeriveValidOperatorAttributeExpr { */ +/* template */ +/* std::unordered_set> */ +/* operator()(T const &t) { */ +/* return derive_valid_operator_attribute_expr(t); */ +/* } */ + +/* std::unordered_set> */ +/* derive_valid_operator_attribute_expr(OperatorAttributeKey const &key) { + */ +/* return {key}; */ +/* } */ + +/* std::unordered_set> */ +/* derive_valid_operator_attribute_expr( */ +/* ListIndexAccess const &access) { */ +/* return {access, access.attribute_key}; */ +/* } */ + +/* std::unordered_set> */ +/* derive_valid_operator_attribute_expr( */ +/* ListSize const &ls) { */ +/* return {ls, ls.attribute_key}; */ +/* } */ +/* }; */ + +/* std::unordered_set> */ +/* get_valid_operator_attribute_exprs(OperatorPattern const &pattern) { */ +/* return set_union(transform( */ +/* pattern.attribute_constraints, [](OperatorAttributeConstraint const &t) + * { */ +/* return visit(DeriveValidOperatorAttributeExpr{}, t.attribute_expr); + */ +/* })); */ +/* } */ + +/* bool is_valid_operator_attribute_expr( */ +/* OperatorPattern const &pattern, */ +/* AttributeExpr const &expr) { */ +/* return contains(get_valid_operator_attribute_exprs(pattern), expr); */ +/* } */ + +/* struct IsValidOperatorAttributeExprFunctor { */ +/* GraphPattern const &graph_pattern; */ + +/* template */ +/* bool operator()(T const &t) const { */ +/* return is_valid(t); */ +/* } */ + +/* bool is_valid(OperatorAttrAccess const &t) const { */ +/* return is_valid_operator_attribute_expr(graph_pattern.value().at(t.node), + */ +/* t.attr_expr); */ +/* } */ + +/* bool is_valid(AttrConstant const &t) const { */ +/* return true; */ +/* } */ +/* }; */ + +/* bool is_valid_operator_attribute_expr(GraphPattern const &pattern, */ +/* OperatorAttributeExpr const &expr) { */ +/* return visit(IsValidOperatorAttributeExprFunctor{pattern}, expr); */ +/* } */ + +/* bool is_valid_substitution(Substitution const &s) { */ +/* for (Node const &node : get_nodes(s.output_graph_expr.value())) { */ +/* for (OperatorAttributeExpr expr : */ +/* values(s.output_graph_expr.value().at(node).assignments)) { */ +/* if (!is_valid_operator_attribute_expr(s.input_graph, expr)) { */ +/* return false; */ +/* } */ +/* } */ +/* } */ +/* return true; */ +/* } */ + +/* struct EvaluateOperatorAttributeExpr { */ +/* SubParallelComputationGraph const &graph; */ +/* MultiDiGraphPatternMatch const &match; */ + +/* template */ +/* OperatorAttributeValue operator()(T const &t) { */ +/* return evaluate(t); */ +/* } */ + +/* OperatorAttributeValue evaluate(OperatorAttrAccess const &t) { */ +/* Node node_in_pattern = t.node; */ +/* Node node_in_pcg = match.node_assignment.at_l(node_in_pattern); */ +/* return evaluate_attribute_expr(graph.at(node_in_pcg), + * t.attr_expr).value(); */ +/* } */ + +/* OperatorAttributeValue evaluate(AttrConstant const &t) { */ +/* return t.value; */ +/* } */ +/* }; */ + +/* OperatorAttributeValue */ +/* evaluate_graph_attribute_expr(SubParallelComputationGraph const &g, */ +/* MultiDiGraphPatternMatch const &match, */ +/* OperatorAttributeExpr const &expr) { */ +/* return visit(EvaluateOperatorAttributeExpr{g, match}, expr); */ +/* } */ + +/* Operator get_operator_attrs(SubParallelComputationGraph const &graph, */ +/* MultiDiGraphPatternMatch const &match, */ +/* OperatorAttrAssignment const &assignment) { */ +/* std::unordered_map + * assignments; */ +/* for (auto const &[key, expr] : assignment.assignments) { */ +/* OperatorAttributeValue value = */ +/* evaluate_graph_attribute_expr(graph, match, expr); */ +/* assignments.emplace(key, value); */ +/* } */ +/* assert(contains_key(assignments, OperatorAttributeKey::OP_TYPE)); */ +/* assert(std::holds_alternative( */ +/* assignments.at(OperatorAttributeKey::OP_TYPE))); */ +/* OperatorType op_type = */ +/* std::get(assignments.at(OperatorAttributeKey::OP_TYPE)); + */ +/* switch (op_type) { */ +/* case OperatorType::BATCHMATMUL: */ +/* return Operator{ */ +/* BatchMatmulAttrs{std::get(assignments.at( */ +/* OperatorAttributeKey::A_SEQ_LENGTH_DIM)), */ +/* std::get(assignments.at( */ +/* OperatorAttributeKey::B_SEQ_LENGTH_DIM))}, */ +/* std::nullopt}; */ +/* case OperatorType::BATCHNORM: */ +/* return Operator{BatchNormAttrs{std::get( */ +/* assignments.at(OperatorAttributeKey::RELU))}, */ +/* std::nullopt}; */ +/* case OperatorType::CAST: */ +/* return Operator{CastAttrs{std::get( */ +/* assignments.at(OperatorAttributeKey::DATA_TYPE))}, + */ +/* std::nullopt}; */ +/* case OperatorType::CONCAT: */ +/* return Operator{ */ +/* ConcatAttrs{ */ +/* std::get(assignments.at(OperatorAttributeKey::AXIS)), + */ +/* std::get(assignments.at(OperatorAttributeKey::NUM_INPUTS))}, + */ +/* std::nullopt}; */ +/* case OperatorType::CONV2D: */ +/* return Operator{ */ +/* Conv2DAttrs{ */ +/* std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), + */ +/* std::get(assignments.at(OperatorAttributeKey::KERNEL_H)), + */ +/* std::get(assignments.at(OperatorAttributeKey::KERNEL_W)), + */ +/* std::get(assignments.at(OperatorAttributeKey::STRIDE_H)), + */ +/* std::get(assignments.at(OperatorAttributeKey::STRIDE_W)), + */ +/* std::get(assignments.at(OperatorAttributeKey::PADDING_H)), + */ +/* std::get(assignments.at(OperatorAttributeKey::PADDING_W)), + */ +/* std::get(assignments.at(OperatorAttributeKey::GROUPS)), */ +/* std::get( */ +/* assignments.at(OperatorAttributeKey::ACTIVATION)), */ +/* std::get(assignments.at(OperatorAttributeKey::USE_BIAS))}, + */ +/* std::nullopt}; */ +/* case OperatorType::DROPOUT: */ +/* return Operator{DropoutAttrs{std::get(assignments.at( */ +/* OperatorAttributeKey::RATE)), */ +/* std::get(assignments.at( */ +/* OperatorAttributeKey::SEED))}, */ +/* std::nullopt}; */ +/* case OperatorType::EW_ADD: */ +/* case OperatorType::EW_DIV: */ +/* case OperatorType::EW_EQUAL: */ +/* case OperatorType::EW_GREATER: */ +/* case OperatorType::EW_LESS: */ +/* case OperatorType::EW_MAX: */ +/* case OperatorType::EW_MIN: */ +/* case OperatorType::EW_MUL: */ +/* case OperatorType::EW_SUB: */ +/* return Operator{ */ +/* ElementBinaryAttrs{op_type, */ +/* std::get(assignments.at( */ +/* OperatorAttributeKey::DATA_TYPE)), */ +/* std::get(assignments.at( */ +/* OperatorAttributeKey::SHOULD_BROADCAST_LHS)), + */ +/* std::get(assignments.at( */ +/* OperatorAttributeKey::SHOULD_BROADCAST_RHS))}, + */ +/* std::nullopt}; */ +/* case OperatorType::SCALAR_ADD: */ +/* case OperatorType::SCALAR_FLOOR_DIV: */ +/* case OperatorType::SCALAR_MULTIPLY: */ +/* case OperatorType::SCALAR_SUB: */ +/* case OperatorType::SCALAR_TRUE_DIV: */ +/* return Operator{ */ +/* ElementScalarUnaryAttrs{ */ +/* op_type, */ +/* std::get(assignments.at(OperatorAttributeKey::SCALAR))}, + */ +/* std::nullopt}; */ +/* case OperatorType::EXP: */ +/* case OperatorType::IDENTITY: */ +/* case OperatorType::GELU: */ +/* case OperatorType::RSQRT: */ +/* case OperatorType::POW: */ +/* case OperatorType::SIN: */ +/* case OperatorType::COS: */ +/* return Operator{ElementUnaryAttrs{op_type}, std::nullopt}; */ +/* case OperatorType::EMBEDDING: */ +/* return Operator{ */ +/* EmbeddingAttrs{ */ +/* std::get(assignments.at(OperatorAttributeKey::NUM_ENTRIES)), + */ +/* std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), + */ +/* std::get(assignments.at(OperatorAttributeKey::AGGR)), + */ +/* std::get( */ +/* assignments.at(OperatorAttributeKey::OP_TYPE))}, */ +/* std::nullopt}; */ +/* case OperatorType::FLAT: */ +/* return Operator{FlatAttrs{}, std::nullopt}; */ +/* case OperatorType::GATHER: */ +/* return Operator{GatherAttrs{std::get( */ +/* assignments.at(OperatorAttributeKey::DIM))}, */ +/* std::nullopt}; */ +/* case OperatorType::INPUT: */ +/* return Operator{InputAttrs{}, std::nullopt}; */ +/* case OperatorType::LAYERNORM: */ +/* return Operator{ */ +/* LayerNormAttrs{ */ +/* std::get>( */ +/* assignments.at(OperatorAttributeKey::AXES)), */ +/* std::get( */ +/* assignments.at(OperatorAttributeKey::ELEMENTWISE_AFFINE)), + */ +/* std::get(assignments.at(OperatorAttributeKey::EPSILON))}, + */ +/* std::nullopt}; */ +/* case OperatorType::LINEAR: */ +/* return Operator{ */ +/* LinearAttrs{ */ +/* std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), + */ +/* std::get(assignments.at(OperatorAttributeKey::USE_BIAS)), + */ +/* std::get( */ +/* assignments.at(OperatorAttributeKey::DATA_TYPE)), */ +/* std::get( */ +/* assignments.at(OperatorAttributeKey::ACTIVATION)), */ +/* std::get>( */ +/* assignments.at(OperatorAttributeKey::REGULARIZER))}, */ +/* std::nullopt}; */ +/* case OperatorType::MULTIHEAD_ATTENTION: */ +/* return Operator{ */ +/* MultiHeadAttentionAttrs{ */ +/* std::get(assignments.at(OperatorAttributeKey::EMBED_DIM)), + */ +/* std::get(assignments.at(OperatorAttributeKey::NUM_HEADS)), + */ +/* std::get(assignments.at(OperatorAttributeKey::NUM_HEADS)), + */ +/* std::get(assignments.at(OperatorAttributeKey::VDIM)), */ +/* std::get(assignments.at(OperatorAttributeKey::DROPOUT)), + */ +/* std::get(assignments.at(OperatorAttributeKey::BIAS)), */ +/* std::get(assignments.at(OperatorAttributeKey::ADD_BIAS_KV)), + */ +/* std::get( */ +/* assignments.at(OperatorAttributeKey::ADD_ZERO_ATTN))}, */ +/* std::nullopt}; */ +/* case OperatorType::NOOP: */ +/* return Operator{NoopAttrs{}, std::nullopt}; */ +/* case OperatorType::POOL2D: */ +/* return Operator{ */ +/* Pool2DAttrs{ */ +/* std::get(assignments.at(OperatorAttributeKey::KERNEL_H)), + */ +/* std::get(assignments.at(OperatorAttributeKey::KERNEL_W)), + */ +/* std::get(assignments.at(OperatorAttributeKey::STRIDE_H)), + */ +/* std::get(assignments.at(OperatorAttributeKey::STRIDE_W)), + */ +/* std::get(assignments.at(OperatorAttributeKey::PADDING_H)), + */ +/* std::get(assignments.at(OperatorAttributeKey::PADDING_W)), + */ +/* std::get(assignments.at(OperatorAttributeKey::POOL_TYPE)), + */ +/* std::get( */ +/* assignments.at(OperatorAttributeKey::ACTIVATION))}, */ +/* std::nullopt}; */ +/* case OperatorType::REDUCE_ARGMAX: */ +/* case OperatorType::REDUCE_ARGMIN: */ +/* case OperatorType::REDUCE_MAX: */ +/* case OperatorType::REDUCE_MEAN: */ +/* case OperatorType::REDUCE_MIN: */ +/* case OperatorType::REDUCE_PROD: */ +/* case OperatorType::REDUCE_SUM: */ +/* return Operator{ */ +/* ReduceAttrs{ */ +/* std::get>( */ +/* assignments.at(OperatorAttributeKey::AXES)), */ +/* op_type, */ +/* std::get(assignments.at(OperatorAttributeKey::KEEP_DIMS))}, + */ +/* std::nullopt}; */ +/* case OperatorType::REVERSE: */ +/* return Operator{ReverseAttrs{std::get( */ +/* assignments.at(OperatorAttributeKey::AXIS))}, */ +/* std::nullopt}; */ +/* case OperatorType::RESHAPE: */ +/* return Operator{ReshapeAttrs{std::get( */ +/* assignments.at(OperatorAttributeKey::SHAPE))}, */ +/* std::nullopt}; */ +/* case OperatorType::SPLIT: */ +/* return Operator{ */ +/* SplitAttrs{ */ +/* std::get>( */ +/* assignments.at(OperatorAttributeKey::SPLITS)), */ +/* std::get(assignments.at(OperatorAttributeKey::AXIS))}, + */ +/* std::nullopt}; */ +/* case OperatorType::SOFTMAX: */ +/* return Operator{SoftmaxAttrs{std::get( */ +/* assignments.at(OperatorAttributeKey::DIM))}, */ +/* std::nullopt}; */ +/* case OperatorType::TOPK: */ +/* return Operator{ */ +/* TopKAttrs{ */ +/* std::get(assignments.at(OperatorAttributeKey::K)), */ +/* std::get(assignments.at(OperatorAttributeKey::SORTED))}, + */ +/* std::nullopt}; */ +/* case OperatorType::TRANSPOSE: */ +/* return Operator{ */ +/* TransposeAttrs{std::get>( */ +/* assignments.at(OperatorAttributeKey::PERMUTATION))}, */ +/* std::nullopt}; */ +/* case OperatorType::COMBINE: */ +/* return Operator{CombineAttrs{std::get(assignments.at( */ +/* OperatorAttributeKey::PARALLEL_DIM)), + */ +/* std::get(assignments.at( */ +/* OperatorAttributeKey::PARALLEL_DEGREE))}, + */ +/* std::nullopt}; */ +/* case OperatorType::REDUCTION: */ +/* return Operator{ */ +/* ReductionAttrs{std::get(assignments.at( */ +/* OperatorAttributeKey::PARALLEL_DIM)), */ +/* std::get(assignments.at( */ +/* OperatorAttributeKey::PARALLEL_DEGREE))}, */ +/* std::nullopt}; */ +/* case OperatorType::REPARTITION: */ +/* return Operator{ */ +/* RepartitionAttrs{std::get(assignments.at( */ +/* OperatorAttributeKey::PARALLEL_DIM)), */ +/* std::get(assignments.at( */ +/* OperatorAttributeKey::PARALLEL_DEGREE))}, */ +/* std::nullopt}; */ +/* case OperatorType::REPLICATE: */ +/* return Operator{ */ +/* ReplicateAttrs{std::get(assignments.at( */ +/* OperatorAttributeKey::PARALLEL_DIM)), */ +/* std::get(assignments.at( */ +/* OperatorAttributeKey::PARALLEL_DEGREE))}, */ +/* std::nullopt}; */ +/* default: */ +/* throw mk_runtime_error("Unknown Operator"); */ +/* } */ +/* } */ } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/constraint_type.dtg.cc b/lib/substitutions/src/substitutions/constraint_type.dtg.cc new file mode 100644 index 0000000000..aa5c30dbe9 --- /dev/null +++ b/lib/substitutions/src/substitutions/constraint_type.dtg.cc @@ -0,0 +1,64 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/constraint_type.enum.toml +/* proj-data +{ + "generated_from": "06b029d76658cb434abf08b1fdb86137" +} +*/ + +#include "substitutions/constraint_type.dtg.h" + +#include +#include + +namespace std { +size_t hash::operator()( + FlexFlow::ConstraintType x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(ConstraintType x) { + switch (x) { + case ConstraintType::EQUAL: + return "EQUAL"; + default: + std::ostringstream oss; + oss << "Unknown ConstraintType value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, ConstraintType x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, ConstraintType x) { + switch (x) { + case ConstraintType::EQUAL: + j = "EQUAL"; + break; + default: + std::ostringstream oss; + oss << "Unknown ConstraintType value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, ConstraintType &x) { + std::string as_str = j.get(); + if (as_str == "EQUAL") { + x = ConstraintType::EQUAL; + } else { + std::ostringstream oss; + oss << "Unknown ConstraintType value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::element( + FlexFlow::ConstraintType::EQUAL); +} +} // namespace rc diff --git a/lib/substitutions/src/substitutions/graph_pattern.cc b/lib/substitutions/src/substitutions/graph_pattern.cc new file mode 100644 index 0000000000..22cf12b4cf --- /dev/null +++ b/lib/substitutions/src/substitutions/graph_pattern.cc @@ -0,0 +1,42 @@ +#include "substitutions/graph_pattern.h" +#include "substitutions/operator_pattern/satisfies_pattern.h" +#include "substitutions/sub_parallel_computation_graph.h" +#include "substitutions/tensor_pattern/satisfies_pattern.h" + +namespace FlexFlow { + +UnlabelledGraphPattern get_unlabelled_pattern(PCGPattern const &p) { + return UnlabelledGraphPattern{p.raw_graph}; +} + +TensorAttributePattern get_tensor_pattern(PCGPattern const &p, + PatternEdge const &e) { + return p.raw_graph.at(e.raw_edge); +} + +OperatorAttributePattern get_operator_pattern(PCGPattern const &p, + PatternNode const &n) { + return p.raw_graph.at(n.raw_node); +} + +bool assignment_satisfies(SubParallelComputationGraph const &pcg, + PCGPattern const &pattern, + MultiDiGraphPatternMatch const &patternMatch) { + return unlabelled_pattern_does_match( + get_unlabelled_pattern(pattern), + pcg.raw_graph, + patternMatch, + MatchAdditionalCriterion{ + [&](PatternNode const &patternNode, Node const &pcgNode) { + return operator_satisfies_pattern( + get_operator_attrs(pcg, pcgNode), + get_operator_pattern(pattern, patternNode)); + }, + [&](PatternEdge const &patternEdge, OpenMultiDiEdge const &pcgEdge) { + return parallel_tensor_satisfies_pattern( + get_parallel_tensor_attrs(pcg, pcgEdge), + get_tensor_pattern(pattern, patternEdge)); + }}); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/operator_pattern/eval_list_access.cc b/lib/substitutions/src/substitutions/operator_pattern/eval_list_access.cc new file mode 100644 index 0000000000..53973dc1cb --- /dev/null +++ b/lib/substitutions/src/substitutions/operator_pattern/eval_list_access.cc @@ -0,0 +1,41 @@ +#include "substitutions/operator_pattern/eval_list_access.h" +#include "substitutions/operator_pattern/get_attribute.h" +#include "utils/overload.h" + +namespace FlexFlow { + +std::optional + eval_list_access(PCGOperatorAttrs const &attrs, + OperatorAttributeListIndexAccess const &acc) { + std::optional from_attr = + get_attribute(attrs, acc.attribute_key); + + if (!from_attr.has_value()) { + return std::nullopt; + } + + return from_attr.value().visit>( + [&](auto const &v) -> std::optional { + using T = std::decay_t; + + if constexpr (std::is_same_v>) { + if (acc.index >= v.size()) { + return std::nullopt; + } else { + int value = v.at(acc.index); + return OperatorAttributeValue{value}; + } + } else if constexpr (std::is_same_v>) { + if (acc.index >= v.size()) { + return std::nullopt; + } else { + ff_dim_t value = v.at(acc.index); + return OperatorAttributeValue{value}; + } + } else { + throw mk_runtime_error("Invalid operand"); + } + }); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/operator_pattern/eval_list_size.cc b/lib/substitutions/src/substitutions/operator_pattern/eval_list_size.cc new file mode 100644 index 0000000000..a3ae9c84d1 --- /dev/null +++ b/lib/substitutions/src/substitutions/operator_pattern/eval_list_size.cc @@ -0,0 +1,31 @@ +#include "substitutions/operator_pattern/eval_list_size.h" +#include "substitutions/operator_pattern/get_attribute.h" +#include "utils/overload.h" + +namespace FlexFlow { + +std::optional + eval_list_size(PCGOperatorAttrs const &attrs, + OperatorAttributeListSize const &acc) { + std::optional from_attr = + get_attribute(attrs, acc.attribute_key); + + if (!from_attr.has_value()) { + return std::nullopt; + } + + return from_attr.value().visit>( + [&](auto const &v) -> std::optional { + using T = std::decay_t; + + if constexpr (std::is_same_v> || + std::is_same_v>) { + size_t size = v.size(); + return OperatorAttributeValue{size}; + } else { + throw mk_runtime_error("Invalid operand"); + } + }); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/operator_attributes.cc b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc similarity index 72% rename from lib/substitutions/src/operator_attributes.cc rename to lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc index 8bd8688194..e168760c3b 100644 --- a/lib/substitutions/src/operator_attributes.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc @@ -1,11 +1,26 @@ +#include "substitutions/operator_pattern/get_attribute.h" #include "op-attrs/get_op_type.h" -#include "substitutions/get_attribute.h" +#include "utils/containers.h" namespace FlexFlow { std::optional get_attribute(BatchMatmulAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); + default: + return std::nullopt; + } +} + +std::optional get_attribute(BatchNormAttrs const &p, + OperatorAttributeKey key) { + switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); + case OperatorAttributeKey::RELU: + return p.relu; default: return std::nullopt; } @@ -14,6 +29,8 @@ std::optional get_attribute(BatchMatmulAttrs const &p, std::optional get_attribute(CastAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); case OperatorAttributeKey::DATA_TYPE: return p.dtype; default: @@ -24,6 +41,8 @@ std::optional get_attribute(CastAttrs const &p, std::optional get_attribute(CombineAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); case OperatorAttributeKey::PARALLEL_OP_DIM: return p.combine_dim; case OperatorAttributeKey::PARALLEL_DIM: @@ -36,6 +55,8 @@ std::optional get_attribute(CombineAttrs const &p, std::optional get_attribute(ConcatAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); case OperatorAttributeKey::AXIS: return p.axis; default: @@ -46,6 +67,8 @@ std::optional get_attribute(ConcatAttrs const &p, std::optional get_attribute(Conv2DAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); case OperatorAttributeKey::KERNEL_H: return p.kernel_h; case OperatorAttributeKey::KERNEL_W: @@ -72,6 +95,8 @@ std::optional get_attribute(Conv2DAttrs const &p, std::optional get_attribute(ElementBinaryAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); default: return std::nullopt; } @@ -80,6 +105,8 @@ std::optional get_attribute(ElementBinaryAttrs const &p, std::optional get_attribute(ElementUnaryAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); default: return std::nullopt; } @@ -88,6 +115,8 @@ std::optional get_attribute(ElementUnaryAttrs const &p, std::optional get_attribute(ElementScalarUnaryAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); default: return std::nullopt; } @@ -96,6 +125,8 @@ std::optional std::optional get_attribute(DropoutAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); default: return std::nullopt; } @@ -104,6 +135,8 @@ std::optional get_attribute(DropoutAttrs const &p, std::optional get_attribute(EmbeddingAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); case OperatorAttributeKey::DATA_TYPE: return p.data_type; case OperatorAttributeKey::AGGR: @@ -120,6 +153,8 @@ std::optional get_attribute(EmbeddingAttrs const &p, std::optional get_attribute(FlatAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); default: return std::nullopt; } @@ -128,6 +163,8 @@ std::optional get_attribute(FlatAttrs const &p, std::optional get_attribute(GatherAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); case OperatorAttributeKey::AXIS: return p.dim; default: @@ -135,9 +172,21 @@ std::optional get_attribute(GatherAttrs const &p, } } +std::optional get_attribute(InputAttrs const &p, + OperatorAttributeKey key) { + switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); + default: + return std::nullopt; + } +} + std::optional get_attribute(LayerNormAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); default: return std::nullopt; } @@ -146,6 +195,8 @@ std::optional get_attribute(LayerNormAttrs const &p, std::optional get_attribute(LinearAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); case OperatorAttributeKey::OUT_CHANNELS: return p.out_channels; case OperatorAttributeKey::USE_BIAS: @@ -166,6 +217,8 @@ std::optional get_attribute(LinearAttrs const &p, std::optional get_attribute(MultiHeadAttentionAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); case OperatorAttributeKey::NUM_HEADS: return p.num_heads; case OperatorAttributeKey::USE_BIAS: @@ -175,9 +228,21 @@ std::optional } } +std::optional get_attribute(NoopAttrs const &p, + OperatorAttributeKey key) { + switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); + default: + return std::nullopt; + } +} + std::optional get_attribute(Pool2DAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); case OperatorAttributeKey::KERNEL_H: return p.kernel_h; case OperatorAttributeKey::KERNEL_W: @@ -202,6 +267,8 @@ std::optional get_attribute(Pool2DAttrs const &p, std::optional get_attribute(ReduceAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); default: return std::nullopt; } @@ -210,8 +277,8 @@ std::optional get_attribute(ReduceAttrs const &p, std::optional get_attribute(ReductionAttrs const &p, OperatorAttributeKey key) { switch (key) { - case OperatorAttributeKey::PARALLEL_OP_DIM: - return p.reduction_dim; + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); case OperatorAttributeKey::PARALLEL_OP_DEGREE: return p.reduction_degree; default: @@ -222,6 +289,8 @@ std::optional get_attribute(ReductionAttrs const &p, std::optional get_attribute(RepartitionAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); case OperatorAttributeKey::PARALLEL_OP_DIM: return p.repartition_dim; case OperatorAttributeKey::PARALLEL_OP_DEGREE: @@ -234,8 +303,8 @@ std::optional get_attribute(RepartitionAttrs const &p, std::optional get_attribute(ReplicateAttrs const &p, OperatorAttributeKey key) { switch (key) { - case OperatorAttributeKey::PARALLEL_OP_DIM: - return p.replicate_dim; + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); case OperatorAttributeKey::PARALLEL_OP_DEGREE: return p.replicate_degree; default: @@ -246,6 +315,20 @@ std::optional get_attribute(ReplicateAttrs const &p, std::optional get_attribute(ReshapeAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); + default: + return std::nullopt; + } +} + +std::optional get_attribute(ReverseAttrs const &p, + OperatorAttributeKey key) { + switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); + case OperatorAttributeKey::AXIS: + return p.axis; default: return std::nullopt; } @@ -254,6 +337,8 @@ std::optional get_attribute(ReshapeAttrs const &p, std::optional get_attribute(SplitAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); case OperatorAttributeKey::AXIS: return p.axis; default: @@ -264,6 +349,8 @@ std::optional get_attribute(SplitAttrs const &p, std::optional get_attribute(SoftmaxAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); case OperatorAttributeKey::AXIS: return p.dim; default: @@ -274,6 +361,8 @@ std::optional get_attribute(SoftmaxAttrs const &p, std::optional get_attribute(TopKAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); default: return std::nullopt; } @@ -282,38 +371,19 @@ std::optional get_attribute(TopKAttrs const &p, std::optional get_attribute(TransposeAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); case OperatorAttributeKey::PERMUTATION: - return p.perm; + return as_vector(p.perm); default: return std::nullopt; } } -struct GetAttribute { - GetAttribute(OperatorAttributeKey key) : key(key) {} - - template - std::optional operator()(T const &t) { - return get_attribute(t, this->key); - } - -private: - OperatorAttributeKey key; -}; - -struct GetOpType { - template - std::optional operator()(T const &t) { - return get_op_type(t); - } -}; - std::optional get_attribute(PCGOperatorAttrs const &p, OperatorAttributeKey key) { - if (key == OperatorAttributeKey::OP_TYPE) { - return std::visit(GetOpType{}, p); - } - return std::visit(GetAttribute(key), p); + return p.visit>( + [&](auto const &attrs) { return get_attribute(attrs, key); }); } } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.dtg.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.dtg.cc new file mode 100644 index 0000000000..bc913b7c1a --- /dev/null +++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.dtg.cc @@ -0,0 +1,121 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.struct.toml +/* proj-data +{ + "generated_from": "7867bd0f403866c13417171bb5ec364c" +} +*/ + +#include "substitutions/operator_pattern/operator_attribute_constraint.dtg.h" + +#include "substitutions/constraint_type.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_expr.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_value.dtg.h" +#include + +namespace FlexFlow { +OperatorAttributeConstraint::OperatorAttributeConstraint( + ::FlexFlow::ConstraintType const &constraint_type, + ::FlexFlow::OperatorAttributeExpr const &attribute_expr, + ::FlexFlow::OperatorAttributeValue const &attribute_value) + : constraint_type(constraint_type), attribute_expr(attribute_expr), + attribute_value(attribute_value) {} +bool OperatorAttributeConstraint::operator==( + OperatorAttributeConstraint const &other) const { + return std::tie(this->constraint_type, + this->attribute_expr, + this->attribute_value) == std::tie(other.constraint_type, + other.attribute_expr, + other.attribute_value); +} +bool OperatorAttributeConstraint::operator!=( + OperatorAttributeConstraint const &other) const { + return std::tie(this->constraint_type, + this->attribute_expr, + this->attribute_value) != std::tie(other.constraint_type, + other.attribute_expr, + other.attribute_value); +} +bool OperatorAttributeConstraint::operator<( + OperatorAttributeConstraint const &other) const { + return std::tie(this->constraint_type, + this->attribute_expr, + this->attribute_value) < std::tie(other.constraint_type, + other.attribute_expr, + other.attribute_value); +} +bool OperatorAttributeConstraint::operator>( + OperatorAttributeConstraint const &other) const { + return std::tie(this->constraint_type, + this->attribute_expr, + this->attribute_value) > std::tie(other.constraint_type, + other.attribute_expr, + other.attribute_value); +} +bool OperatorAttributeConstraint::operator<=( + OperatorAttributeConstraint const &other) const { + return std::tie(this->constraint_type, + this->attribute_expr, + this->attribute_value) <= std::tie(other.constraint_type, + other.attribute_expr, + other.attribute_value); +} +bool OperatorAttributeConstraint::operator>=( + OperatorAttributeConstraint const &other) const { + return std::tie(this->constraint_type, + this->attribute_expr, + this->attribute_value) >= std::tie(other.constraint_type, + other.attribute_expr, + other.attribute_value); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::OperatorAttributeConstraint const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ConstraintType>{}(x.constraint_type) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::OperatorAttributeExpr>{}(x.attribute_expr) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::OperatorAttributeValue>{}(x.attribute_value) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::OperatorAttributeConstraint + adl_serializer::from_json( + json const &j) { + return { + j.at("constraint_type").template get<::FlexFlow::ConstraintType>(), + j.at("attribute_expr").template get<::FlexFlow::OperatorAttributeExpr>(), + j.at("attribute_value") + .template get<::FlexFlow::OperatorAttributeValue>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::OperatorAttributeConstraint const &v) { + j["__type"] = "OperatorAttributeConstraint"; + j["constraint_type"] = v.constraint_type; + j["attribute_expr"] = v.attribute_expr; + j["attribute_value"] = v.attribute_value; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(OperatorAttributeConstraint const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + OperatorAttributeConstraint const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_expr.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_expr.cc new file mode 100644 index 0000000000..4a55fa3de3 --- /dev/null +++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_expr.cc @@ -0,0 +1,23 @@ +#include "substitutions/operator_pattern/operator_attribute_expr.h" +#include "substitutions/operator_pattern/eval_list_access.h" +#include "substitutions/operator_pattern/eval_list_size.h" +#include "substitutions/operator_pattern/get_attribute.h" +#include "utils/overload.h" + +namespace FlexFlow { + +std::optional + evaluate_attribute_expr(PCGOperatorAttrs const &attrs, + OperatorAttributeExpr const &expr) { + return expr.visit>(overload{ + [&](OperatorAttributeKey const &k) { return get_attribute(attrs, k); }, + [&](OperatorAttributeListSize const &k) { + return eval_list_size(attrs, k); + }, + [&](OperatorAttributeListIndexAccess const &k) { + return eval_list_access(attrs, k); + }, + }); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_expr.dtg.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_expr.dtg.cc new file mode 100644 index 0000000000..60c77d8d0f --- /dev/null +++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_expr.dtg.cc @@ -0,0 +1,137 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.variant.toml +/* proj-data +{ + "generated_from": "15d26dd1f08092ecc82b725aa9411597" +} +*/ + +#include "substitutions/operator_pattern/operator_attribute_expr.dtg.h" + +#include "fmt/format.h" +#include +#include + +namespace FlexFlow { +OperatorAttributeExpr::OperatorAttributeExpr( + ::FlexFlow::OperatorAttributeKey const &v) + : raw_variant(v) {} +OperatorAttributeExpr::OperatorAttributeExpr( + ::FlexFlow::OperatorAttributeListSize const &v) + : raw_variant(v) {} +OperatorAttributeExpr::OperatorAttributeExpr( + ::FlexFlow::OperatorAttributeListIndexAccess const &v) + : raw_variant(v) {} +bool OperatorAttributeExpr::operator==( + OperatorAttributeExpr const &other) const { + return this->raw_variant == other.raw_variant; +} +bool OperatorAttributeExpr::operator!=( + OperatorAttributeExpr const &other) const { + return this->raw_variant != other.raw_variant; +} +bool OperatorAttributeExpr::operator<( + OperatorAttributeExpr const &other) const { + return this->raw_variant < other.raw_variant; +} +bool OperatorAttributeExpr::operator>( + OperatorAttributeExpr const &other) const { + return this->raw_variant > other.raw_variant; +} +bool OperatorAttributeExpr::operator<=( + OperatorAttributeExpr const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool OperatorAttributeExpr::operator>=( + OperatorAttributeExpr const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::OperatorAttributeExpr>::operator()( + ::FlexFlow::OperatorAttributeExpr const &x) const { + return std::hash< + std::variant<::FlexFlow::OperatorAttributeKey, + ::FlexFlow::OperatorAttributeListSize, + ::FlexFlow::OperatorAttributeListIndexAccess>>{}( + x.raw_variant); +} +} // namespace std +namespace nlohmann { +::FlexFlow::OperatorAttributeExpr + adl_serializer<::FlexFlow::OperatorAttributeExpr>::from_json( + json const &j) { + std::string key = j.at("type").template get(); + if (key == "key") { + return ::FlexFlow::OperatorAttributeExpr{ + j.at("value").template get<::FlexFlow::OperatorAttributeKey>()}; + } else if (key == "list_size") { + return ::FlexFlow::OperatorAttributeExpr{ + j.at("value").template get<::FlexFlow::OperatorAttributeListSize>()}; + } else if (key == "list_idx") { + return ::FlexFlow::OperatorAttributeExpr{ + j.at("value") + .template get<::FlexFlow::OperatorAttributeListIndexAccess>()}; + } else { + throw std::runtime_error(fmt::format("Unknown type key {}", key)); + } +} +void adl_serializer<::FlexFlow::OperatorAttributeExpr>::to_json( + json &j, ::FlexFlow::OperatorAttributeExpr const &x) { + j["__type"] = "OperatorAttributeExpr"; + switch (x.index()) { + case 0: { + j["type"] = "key"; + j["value"] = x.get<::FlexFlow::OperatorAttributeKey>(); + break; + } + case 1: { + j["type"] = "list_size"; + j["value"] = x.get<::FlexFlow::OperatorAttributeListSize>(); + break; + } + case 2: { + j["type"] = "list_idx"; + j["value"] = x.get<::FlexFlow::OperatorAttributeListIndexAccess>(); + break; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type OperatorAttributeExpr", x.index())); + } + } +} +} // namespace nlohmann +namespace FlexFlow { +std::string format_as(::FlexFlow::OperatorAttributeExpr const &x) { + std::ostringstream oss; + switch (x.index()) { + case 0: { + oss << ""; + break; + } + case 1: { + oss << ""; + break; + } + case 2: { + oss << ""; + break; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type OperatorAttributeExpr", x.index())); + break; + } + } + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::OperatorAttributeExpr const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_key.dtg.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_key.dtg.cc new file mode 100644 index 0000000000..a24e1c12e4 --- /dev/null +++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_key.dtg.cc @@ -0,0 +1,505 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml +/* proj-data +{ + "generated_from": "e637388397720b328b1f4b9ba6b14611" +} +*/ + +#include "substitutions/operator_pattern/operator_attribute_key.dtg.h" + +#include +#include + +namespace std { +size_t hash::operator()( + FlexFlow::OperatorAttributeKey x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(OperatorAttributeKey x) { + switch (x) { + case OperatorAttributeKey::OP_TYPE: + return "OP_TYPE"; + case OperatorAttributeKey::USE_BIAS: + return "USE_BIAS"; + case OperatorAttributeKey::GROUPS: + return "GROUPS"; + case OperatorAttributeKey::POOL_TYPE: + return "POOL_TYPE"; + case OperatorAttributeKey::KERNEL_H: + return "KERNEL_H"; + case OperatorAttributeKey::KERNEL_W: + return "KERNEL_W"; + case OperatorAttributeKey::DATA_TYPE: + return "DATA_TYPE"; + case OperatorAttributeKey::SCALAR: + return "SCALAR"; + case OperatorAttributeKey::STRIDE_H: + return "STRIDE_H"; + case OperatorAttributeKey::STRIDE_W: + return "STRIDE_W"; + case OperatorAttributeKey::PADDING_H: + return "PADDING_H"; + case OperatorAttributeKey::PADDING_W: + return "PADDING_W"; + case OperatorAttributeKey::AGGR: + return "AGGR"; + case OperatorAttributeKey::NUM_ENTRIES: + return "NUM_ENTRIES"; + case OperatorAttributeKey::OUT_CHANNELS: + return "OUT_CHANNELS"; + case OperatorAttributeKey::ACTIVATION: + return "ACTIVATION"; + case OperatorAttributeKey::NUMDIM: + return "NUMDIM"; + case OperatorAttributeKey::AXIS: + return "AXIS"; + case OperatorAttributeKey::PERMUTATION: + return "PERMUTATION"; + case OperatorAttributeKey::OUTSHUFFLE: + return "OUTSHUFFLE"; + case OperatorAttributeKey::MERGE_GCONV_COUNT: + return "MERGE_GCONV_COUNT"; + case OperatorAttributeKey::AXES: + return "AXES"; + case OperatorAttributeKey::KEEP_DIMS: + return "KEEP_DIMS"; + case OperatorAttributeKey::EPSILON: + return "EPSILON"; + case OperatorAttributeKey::PARALLEL_OP_DIM: + return "PARALLEL_OP_DIM"; + case OperatorAttributeKey::PARALLEL_OP_DEGREE: + return "PARALLEL_OP_DEGREE"; + case OperatorAttributeKey::SOFTMAX_DIM: + return "SOFTMAX_DIM"; + case OperatorAttributeKey::NUM_HEADS: + return "NUM_HEADS"; + case OperatorAttributeKey::PARALLEL_DIM: + return "PARALLEL_DIM"; + case OperatorAttributeKey::PARALLEL_DEGREE: + return "PARALLEL_DEGREE"; + case OperatorAttributeKey::PAD: + return "PAD"; + case OperatorAttributeKey::EMBED_DIM: + return "EMBED_DIM"; + case OperatorAttributeKey::KDIM: + return "KDIM"; + case OperatorAttributeKey::VDIM: + return "VDIM"; + case OperatorAttributeKey::DROPOUT: + return "DROPOUT"; + case OperatorAttributeKey::BIAS: + return "BIAS"; + case OperatorAttributeKey::ADD_BIAS_KV: + return "ADD_BIAS_KV"; + case OperatorAttributeKey::ADD_ZERO_ATTN: + return "ADD_ZERO_ATTN"; + case OperatorAttributeKey::A_SEQ_LENGTH_DIM: + return "A_SEQ_LENGTH_DIM"; + case OperatorAttributeKey::B_SEQ_LENGTH_DIM: + return "B_SEQ_LENGTH_DIM"; + case OperatorAttributeKey::RELU: + return "RELU"; + case OperatorAttributeKey::TARGET_DIMS: + return "TARGET_DIMS"; + case OperatorAttributeKey::RATE: + return "RATE"; + case OperatorAttributeKey::SEED: + return "SEED"; + case OperatorAttributeKey::SHOULD_BROADCAST_LHS: + return "SHOULD_BROADCAST_LHS"; + case OperatorAttributeKey::SHOULD_BROADCAST_RHS: + return "SHOULD_BROADCAST_RHS"; + case OperatorAttributeKey::DIM: + return "DIM"; + case OperatorAttributeKey::ELEMENTWISE_AFFINE: + return "ELEMENTWISE_AFFINE"; + case OperatorAttributeKey::REGULARIZER: + return "REGULARIZER"; + case OperatorAttributeKey::SHAPE: + return "SHAPE"; + case OperatorAttributeKey::SPLITS: + return "SPLITS"; + case OperatorAttributeKey::K: + return "K"; + case OperatorAttributeKey::SORTED: + return "SORTED"; + case OperatorAttributeKey::COMBINE_DIM: + return "COMBINE_DIM"; + case OperatorAttributeKey::COMBINE_DEGREE: + return "COMBINE_DEGREE"; + case OperatorAttributeKey::NUM_INPUTS: + return "NUM_INPUTS"; + default: + std::ostringstream oss; + oss << "Unknown OperatorAttributeKey value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, OperatorAttributeKey x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, OperatorAttributeKey x) { + switch (x) { + case OperatorAttributeKey::OP_TYPE: + j = "OP_TYPE"; + break; + case OperatorAttributeKey::USE_BIAS: + j = "USE_BIAS"; + break; + case OperatorAttributeKey::GROUPS: + j = "GROUPS"; + break; + case OperatorAttributeKey::POOL_TYPE: + j = "POOL_TYPE"; + break; + case OperatorAttributeKey::KERNEL_H: + j = "KERNEL_H"; + break; + case OperatorAttributeKey::KERNEL_W: + j = "KERNEL_W"; + break; + case OperatorAttributeKey::DATA_TYPE: + j = "DATA_TYPE"; + break; + case OperatorAttributeKey::SCALAR: + j = "SCALAR"; + break; + case OperatorAttributeKey::STRIDE_H: + j = "STRIDE_H"; + break; + case OperatorAttributeKey::STRIDE_W: + j = "STRIDE_W"; + break; + case OperatorAttributeKey::PADDING_H: + j = "PADDING_H"; + break; + case OperatorAttributeKey::PADDING_W: + j = "PADDING_W"; + break; + case OperatorAttributeKey::AGGR: + j = "AGGR"; + break; + case OperatorAttributeKey::NUM_ENTRIES: + j = "NUM_ENTRIES"; + break; + case OperatorAttributeKey::OUT_CHANNELS: + j = "OUT_CHANNELS"; + break; + case OperatorAttributeKey::ACTIVATION: + j = "ACTIVATION"; + break; + case OperatorAttributeKey::NUMDIM: + j = "NUMDIM"; + break; + case OperatorAttributeKey::AXIS: + j = "AXIS"; + break; + case OperatorAttributeKey::PERMUTATION: + j = "PERMUTATION"; + break; + case OperatorAttributeKey::OUTSHUFFLE: + j = "OUTSHUFFLE"; + break; + case OperatorAttributeKey::MERGE_GCONV_COUNT: + j = "MERGE_GCONV_COUNT"; + break; + case OperatorAttributeKey::AXES: + j = "AXES"; + break; + case OperatorAttributeKey::KEEP_DIMS: + j = "KEEP_DIMS"; + break; + case OperatorAttributeKey::EPSILON: + j = "EPSILON"; + break; + case OperatorAttributeKey::PARALLEL_OP_DIM: + j = "PARALLEL_OP_DIM"; + break; + case OperatorAttributeKey::PARALLEL_OP_DEGREE: + j = "PARALLEL_OP_DEGREE"; + break; + case OperatorAttributeKey::SOFTMAX_DIM: + j = "SOFTMAX_DIM"; + break; + case OperatorAttributeKey::NUM_HEADS: + j = "NUM_HEADS"; + break; + case OperatorAttributeKey::PARALLEL_DIM: + j = "PARALLEL_DIM"; + break; + case OperatorAttributeKey::PARALLEL_DEGREE: + j = "PARALLEL_DEGREE"; + break; + case OperatorAttributeKey::PAD: + j = "PAD"; + break; + case OperatorAttributeKey::EMBED_DIM: + j = "EMBED_DIM"; + break; + case OperatorAttributeKey::KDIM: + j = "KDIM"; + break; + case OperatorAttributeKey::VDIM: + j = "VDIM"; + break; + case OperatorAttributeKey::DROPOUT: + j = "DROPOUT"; + break; + case OperatorAttributeKey::BIAS: + j = "BIAS"; + break; + case OperatorAttributeKey::ADD_BIAS_KV: + j = "ADD_BIAS_KV"; + break; + case OperatorAttributeKey::ADD_ZERO_ATTN: + j = "ADD_ZERO_ATTN"; + break; + case OperatorAttributeKey::A_SEQ_LENGTH_DIM: + j = "A_SEQ_LENGTH_DIM"; + break; + case OperatorAttributeKey::B_SEQ_LENGTH_DIM: + j = "B_SEQ_LENGTH_DIM"; + break; + case OperatorAttributeKey::RELU: + j = "RELU"; + break; + case OperatorAttributeKey::TARGET_DIMS: + j = "TARGET_DIMS"; + break; + case OperatorAttributeKey::RATE: + j = "RATE"; + break; + case OperatorAttributeKey::SEED: + j = "SEED"; + break; + case OperatorAttributeKey::SHOULD_BROADCAST_LHS: + j = "SHOULD_BROADCAST_LHS"; + break; + case OperatorAttributeKey::SHOULD_BROADCAST_RHS: + j = "SHOULD_BROADCAST_RHS"; + break; + case OperatorAttributeKey::DIM: + j = "DIM"; + break; + case OperatorAttributeKey::ELEMENTWISE_AFFINE: + j = "ELEMENTWISE_AFFINE"; + break; + case OperatorAttributeKey::REGULARIZER: + j = "REGULARIZER"; + break; + case OperatorAttributeKey::SHAPE: + j = "SHAPE"; + break; + case OperatorAttributeKey::SPLITS: + j = "SPLITS"; + break; + case OperatorAttributeKey::K: + j = "K"; + break; + case OperatorAttributeKey::SORTED: + j = "SORTED"; + break; + case OperatorAttributeKey::COMBINE_DIM: + j = "COMBINE_DIM"; + break; + case OperatorAttributeKey::COMBINE_DEGREE: + j = "COMBINE_DEGREE"; + break; + case OperatorAttributeKey::NUM_INPUTS: + j = "NUM_INPUTS"; + break; + default: + std::ostringstream oss; + oss << "Unknown OperatorAttributeKey value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, OperatorAttributeKey &x) { + std::string as_str = j.get(); + if (as_str == "OP_TYPE") { + x = OperatorAttributeKey::OP_TYPE; + } else if (as_str == "USE_BIAS") { + x = OperatorAttributeKey::USE_BIAS; + } else if (as_str == "GROUPS") { + x = OperatorAttributeKey::GROUPS; + } else if (as_str == "POOL_TYPE") { + x = OperatorAttributeKey::POOL_TYPE; + } else if (as_str == "KERNEL_H") { + x = OperatorAttributeKey::KERNEL_H; + } else if (as_str == "KERNEL_W") { + x = OperatorAttributeKey::KERNEL_W; + } else if (as_str == "DATA_TYPE") { + x = OperatorAttributeKey::DATA_TYPE; + } else if (as_str == "SCALAR") { + x = OperatorAttributeKey::SCALAR; + } else if (as_str == "STRIDE_H") { + x = OperatorAttributeKey::STRIDE_H; + } else if (as_str == "STRIDE_W") { + x = OperatorAttributeKey::STRIDE_W; + } else if (as_str == "PADDING_H") { + x = OperatorAttributeKey::PADDING_H; + } else if (as_str == "PADDING_W") { + x = OperatorAttributeKey::PADDING_W; + } else if (as_str == "AGGR") { + x = OperatorAttributeKey::AGGR; + } else if (as_str == "NUM_ENTRIES") { + x = OperatorAttributeKey::NUM_ENTRIES; + } else if (as_str == "OUT_CHANNELS") { + x = OperatorAttributeKey::OUT_CHANNELS; + } else if (as_str == "ACTIVATION") { + x = OperatorAttributeKey::ACTIVATION; + } else if (as_str == "NUMDIM") { + x = OperatorAttributeKey::NUMDIM; + } else if (as_str == "AXIS") { + x = OperatorAttributeKey::AXIS; + } else if (as_str == "PERMUTATION") { + x = OperatorAttributeKey::PERMUTATION; + } else if (as_str == "OUTSHUFFLE") { + x = OperatorAttributeKey::OUTSHUFFLE; + } else if (as_str == "MERGE_GCONV_COUNT") { + x = OperatorAttributeKey::MERGE_GCONV_COUNT; + } else if (as_str == "AXES") { + x = OperatorAttributeKey::AXES; + } else if (as_str == "KEEP_DIMS") { + x = OperatorAttributeKey::KEEP_DIMS; + } else if (as_str == "EPSILON") { + x = OperatorAttributeKey::EPSILON; + } else if (as_str == "PARALLEL_OP_DIM") { + x = OperatorAttributeKey::PARALLEL_OP_DIM; + } else if (as_str == "PARALLEL_OP_DEGREE") { + x = OperatorAttributeKey::PARALLEL_OP_DEGREE; + } else if (as_str == "SOFTMAX_DIM") { + x = OperatorAttributeKey::SOFTMAX_DIM; + } else if (as_str == "NUM_HEADS") { + x = OperatorAttributeKey::NUM_HEADS; + } else if (as_str == "PARALLEL_DIM") { + x = OperatorAttributeKey::PARALLEL_DIM; + } else if (as_str == "PARALLEL_DEGREE") { + x = OperatorAttributeKey::PARALLEL_DEGREE; + } else if (as_str == "PAD") { + x = OperatorAttributeKey::PAD; + } else if (as_str == "EMBED_DIM") { + x = OperatorAttributeKey::EMBED_DIM; + } else if (as_str == "KDIM") { + x = OperatorAttributeKey::KDIM; + } else if (as_str == "VDIM") { + x = OperatorAttributeKey::VDIM; + } else if (as_str == "DROPOUT") { + x = OperatorAttributeKey::DROPOUT; + } else if (as_str == "BIAS") { + x = OperatorAttributeKey::BIAS; + } else if (as_str == "ADD_BIAS_KV") { + x = OperatorAttributeKey::ADD_BIAS_KV; + } else if (as_str == "ADD_ZERO_ATTN") { + x = OperatorAttributeKey::ADD_ZERO_ATTN; + } else if (as_str == "A_SEQ_LENGTH_DIM") { + x = OperatorAttributeKey::A_SEQ_LENGTH_DIM; + } else if (as_str == "B_SEQ_LENGTH_DIM") { + x = OperatorAttributeKey::B_SEQ_LENGTH_DIM; + } else if (as_str == "RELU") { + x = OperatorAttributeKey::RELU; + } else if (as_str == "TARGET_DIMS") { + x = OperatorAttributeKey::TARGET_DIMS; + } else if (as_str == "RATE") { + x = OperatorAttributeKey::RATE; + } else if (as_str == "SEED") { + x = OperatorAttributeKey::SEED; + } else if (as_str == "SHOULD_BROADCAST_LHS") { + x = OperatorAttributeKey::SHOULD_BROADCAST_LHS; + } else if (as_str == "SHOULD_BROADCAST_RHS") { + x = OperatorAttributeKey::SHOULD_BROADCAST_RHS; + } else if (as_str == "DIM") { + x = OperatorAttributeKey::DIM; + } else if (as_str == "ELEMENTWISE_AFFINE") { + x = OperatorAttributeKey::ELEMENTWISE_AFFINE; + } else if (as_str == "REGULARIZER") { + x = OperatorAttributeKey::REGULARIZER; + } else if (as_str == "SHAPE") { + x = OperatorAttributeKey::SHAPE; + } else if (as_str == "SPLITS") { + x = OperatorAttributeKey::SPLITS; + } else if (as_str == "K") { + x = OperatorAttributeKey::K; + } else if (as_str == "SORTED") { + x = OperatorAttributeKey::SORTED; + } else if (as_str == "COMBINE_DIM") { + x = OperatorAttributeKey::COMBINE_DIM; + } else if (as_str == "COMBINE_DEGREE") { + x = OperatorAttributeKey::COMBINE_DEGREE; + } else if (as_str == "NUM_INPUTS") { + x = OperatorAttributeKey::NUM_INPUTS; + } else { + std::ostringstream oss; + oss << "Unknown OperatorAttributeKey value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::element( + FlexFlow::OperatorAttributeKey::OP_TYPE, + FlexFlow::OperatorAttributeKey::USE_BIAS, + FlexFlow::OperatorAttributeKey::GROUPS, + FlexFlow::OperatorAttributeKey::POOL_TYPE, + FlexFlow::OperatorAttributeKey::KERNEL_H, + FlexFlow::OperatorAttributeKey::KERNEL_W, + FlexFlow::OperatorAttributeKey::DATA_TYPE, + FlexFlow::OperatorAttributeKey::SCALAR, + FlexFlow::OperatorAttributeKey::STRIDE_H, + FlexFlow::OperatorAttributeKey::STRIDE_W, + FlexFlow::OperatorAttributeKey::PADDING_H, + FlexFlow::OperatorAttributeKey::PADDING_W, + FlexFlow::OperatorAttributeKey::AGGR, + FlexFlow::OperatorAttributeKey::NUM_ENTRIES, + FlexFlow::OperatorAttributeKey::OUT_CHANNELS, + FlexFlow::OperatorAttributeKey::ACTIVATION, + FlexFlow::OperatorAttributeKey::NUMDIM, + FlexFlow::OperatorAttributeKey::AXIS, + FlexFlow::OperatorAttributeKey::PERMUTATION, + FlexFlow::OperatorAttributeKey::OUTSHUFFLE, + FlexFlow::OperatorAttributeKey::MERGE_GCONV_COUNT, + FlexFlow::OperatorAttributeKey::AXES, + FlexFlow::OperatorAttributeKey::KEEP_DIMS, + FlexFlow::OperatorAttributeKey::EPSILON, + FlexFlow::OperatorAttributeKey::PARALLEL_OP_DIM, + FlexFlow::OperatorAttributeKey::PARALLEL_OP_DEGREE, + FlexFlow::OperatorAttributeKey::SOFTMAX_DIM, + FlexFlow::OperatorAttributeKey::NUM_HEADS, + FlexFlow::OperatorAttributeKey::PARALLEL_DIM, + FlexFlow::OperatorAttributeKey::PARALLEL_DEGREE, + FlexFlow::OperatorAttributeKey::PAD, + FlexFlow::OperatorAttributeKey::EMBED_DIM, + FlexFlow::OperatorAttributeKey::KDIM, + FlexFlow::OperatorAttributeKey::VDIM, + FlexFlow::OperatorAttributeKey::DROPOUT, + FlexFlow::OperatorAttributeKey::BIAS, + FlexFlow::OperatorAttributeKey::ADD_BIAS_KV, + FlexFlow::OperatorAttributeKey::ADD_ZERO_ATTN, + FlexFlow::OperatorAttributeKey::A_SEQ_LENGTH_DIM, + FlexFlow::OperatorAttributeKey::B_SEQ_LENGTH_DIM, + FlexFlow::OperatorAttributeKey::RELU, + FlexFlow::OperatorAttributeKey::TARGET_DIMS, + FlexFlow::OperatorAttributeKey::RATE, + FlexFlow::OperatorAttributeKey::SEED, + FlexFlow::OperatorAttributeKey::SHOULD_BROADCAST_LHS, + FlexFlow::OperatorAttributeKey::SHOULD_BROADCAST_RHS, + FlexFlow::OperatorAttributeKey::DIM, + FlexFlow::OperatorAttributeKey::ELEMENTWISE_AFFINE, + FlexFlow::OperatorAttributeKey::REGULARIZER, + FlexFlow::OperatorAttributeKey::SHAPE, + FlexFlow::OperatorAttributeKey::SPLITS, + FlexFlow::OperatorAttributeKey::K, + FlexFlow::OperatorAttributeKey::SORTED, + FlexFlow::OperatorAttributeKey::COMBINE_DIM, + FlexFlow::OperatorAttributeKey::COMBINE_DEGREE, + FlexFlow::OperatorAttributeKey::NUM_INPUTS); +} +} // namespace rc diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_list_access.dtg.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_list_access.dtg.cc new file mode 100644 index 0000000000..71b71d4a51 --- /dev/null +++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_list_access.dtg.cc @@ -0,0 +1,101 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_access.struct.toml +/* proj-data +{ + "generated_from": "1dc90d1e823f05b82c1a5ff433fbf000" +} +*/ + +#include "substitutions/operator_pattern/operator_attribute_list_access.dtg.h" + +#include "substitutions/operator_pattern/operator_attribute_key.dtg.h" +#include + +namespace FlexFlow { +OperatorAttributeListIndexAccess::OperatorAttributeListIndexAccess( + ::FlexFlow::OperatorAttributeKey const &attribute_key, int const &index) + : attribute_key(attribute_key), index(index) {} +bool OperatorAttributeListIndexAccess::operator==( + OperatorAttributeListIndexAccess const &other) const { + return std::tie(this->attribute_key, this->index) == + std::tie(other.attribute_key, other.index); +} +bool OperatorAttributeListIndexAccess::operator!=( + OperatorAttributeListIndexAccess const &other) const { + return std::tie(this->attribute_key, this->index) != + std::tie(other.attribute_key, other.index); +} +bool OperatorAttributeListIndexAccess::operator<( + OperatorAttributeListIndexAccess const &other) const { + return std::tie(this->attribute_key, this->index) < + std::tie(other.attribute_key, other.index); +} +bool OperatorAttributeListIndexAccess::operator>( + OperatorAttributeListIndexAccess const &other) const { + return std::tie(this->attribute_key, this->index) > + std::tie(other.attribute_key, other.index); +} +bool OperatorAttributeListIndexAccess::operator<=( + OperatorAttributeListIndexAccess const &other) const { + return std::tie(this->attribute_key, this->index) <= + std::tie(other.attribute_key, other.index); +} +bool OperatorAttributeListIndexAccess::operator>=( + OperatorAttributeListIndexAccess const &other) const { + return std::tie(this->attribute_key, this->index) >= + std::tie(other.attribute_key, other.index); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::OperatorAttributeListIndexAccess const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::OperatorAttributeKey>{}(x.attribute_key) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.index) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::OperatorAttributeListIndexAccess + adl_serializer::from_json( + json const &j) { + return { + j.at("attribute_key").template get<::FlexFlow::OperatorAttributeKey>(), + j.at("index").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::OperatorAttributeListIndexAccess const &v) { + j["__type"] = "OperatorAttributeListIndexAccess"; + j["attribute_key"] = v.attribute_key; + j["index"] = v.index; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::OperatorAttributeKey>(), + gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(OperatorAttributeListIndexAccess const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + OperatorAttributeListIndexAccess const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_list_size.dtg.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_list_size.dtg.cc new file mode 100644 index 0000000000..eb7ae28131 --- /dev/null +++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_list_size.dtg.cc @@ -0,0 +1,88 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_size.struct.toml +/* proj-data +{ + "generated_from": "30999ad6b0603e380bc33d32fa088e45" +} +*/ + +#include "substitutions/operator_pattern/operator_attribute_list_size.dtg.h" + +#include "substitutions/operator_pattern/operator_attribute_key.dtg.h" +#include + +namespace FlexFlow { +OperatorAttributeListSize::OperatorAttributeListSize( + ::FlexFlow::OperatorAttributeKey const &attribute_key) + : attribute_key(attribute_key) {} +bool OperatorAttributeListSize::operator==( + OperatorAttributeListSize const &other) const { + return std::tie(this->attribute_key) == std::tie(other.attribute_key); +} +bool OperatorAttributeListSize::operator!=( + OperatorAttributeListSize const &other) const { + return std::tie(this->attribute_key) != std::tie(other.attribute_key); +} +bool OperatorAttributeListSize::operator<( + OperatorAttributeListSize const &other) const { + return std::tie(this->attribute_key) < std::tie(other.attribute_key); +} +bool OperatorAttributeListSize::operator>( + OperatorAttributeListSize const &other) const { + return std::tie(this->attribute_key) > std::tie(other.attribute_key); +} +bool OperatorAttributeListSize::operator<=( + OperatorAttributeListSize const &other) const { + return std::tie(this->attribute_key) <= std::tie(other.attribute_key); +} +bool OperatorAttributeListSize::operator>=( + OperatorAttributeListSize const &other) const { + return std::tie(this->attribute_key) >= std::tie(other.attribute_key); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::OperatorAttributeListSize const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::OperatorAttributeKey>{}(x.attribute_key) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::OperatorAttributeListSize + adl_serializer::from_json( + json const &j) { + return { + j.at("attribute_key").template get<::FlexFlow::OperatorAttributeKey>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::OperatorAttributeListSize const &v) { + j["__type"] = "OperatorAttributeListSize"; + j["attribute_key"] = v.attribute_key; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::OperatorAttributeKey>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(OperatorAttributeListSize const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, OperatorAttributeListSize const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_pattern.dtg.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_pattern.dtg.cc new file mode 100644 index 0000000000..5eaf54bb5f --- /dev/null +++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_pattern.dtg.cc @@ -0,0 +1,73 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.struct.toml +/* proj-data +{ + "generated_from": "968d7a3e93303a7fa7482bbcd50246b6" +} +*/ + +#include "substitutions/operator_pattern/operator_attribute_pattern.dtg.h" + +#include "substitutions/operator_pattern/operator_attribute_constraint.dtg.h" +#include "utils/fmt.h" +#include +#include + +namespace FlexFlow { +OperatorAttributePattern::OperatorAttributePattern( + std::unordered_set<::FlexFlow::OperatorAttributeConstraint> const + &attribute_constraints) + : attribute_constraints(attribute_constraints) {} +bool OperatorAttributePattern::operator==( + OperatorAttributePattern const &other) const { + return std::tie(this->attribute_constraints) == + std::tie(other.attribute_constraints); +} +bool OperatorAttributePattern::operator!=( + OperatorAttributePattern const &other) const { + return std::tie(this->attribute_constraints) != + std::tie(other.attribute_constraints); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::OperatorAttributePattern const &x) const { + size_t result = 0; + result ^= + std::hash>{}( + x.attribute_constraints) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::OperatorAttributePattern + adl_serializer::from_json( + json const &j) { + return { + j.at("attribute_constraints") + .template get< + std::unordered_set<::FlexFlow::OperatorAttributeConstraint>>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::OperatorAttributePattern const &v) { + j["__type"] = "OperatorAttributePattern"; + j["attribute_constraints"] = v.attribute_constraints; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(OperatorAttributePattern const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, OperatorAttributePattern const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_value.dtg.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_value.dtg.cc new file mode 100644 index 0000000000..376a9c2ce8 --- /dev/null +++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_value.dtg.cc @@ -0,0 +1,292 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml +/* proj-data +{ + "generated_from": "de14592f1f4bcfb52689bc95e9d3b55f" +} +*/ + +#include "substitutions/operator_pattern/operator_attribute_value.dtg.h" + +#include "fmt/format.h" +#include +#include + +namespace FlexFlow { +OperatorAttributeValue::OperatorAttributeValue(int const &v) : raw_variant(v) {} +OperatorAttributeValue::OperatorAttributeValue(bool const &v) + : raw_variant(v) {} +OperatorAttributeValue::OperatorAttributeValue(std::vector const &v) + : raw_variant(v) {} +OperatorAttributeValue::OperatorAttributeValue( + std::vector<::FlexFlow::ff_dim_t> const &v) + : raw_variant(v) {} +OperatorAttributeValue::OperatorAttributeValue( + ::FlexFlow::OperatorType const &v) + : raw_variant(v) {} +OperatorAttributeValue::OperatorAttributeValue(::FlexFlow::Activation const &v) + : raw_variant(v) {} +OperatorAttributeValue::OperatorAttributeValue(::FlexFlow::ff_dim_t const &v) + : raw_variant(v) {} +OperatorAttributeValue::OperatorAttributeValue(size_t const &v) + : raw_variant(v) {} +OperatorAttributeValue::OperatorAttributeValue(::FlexFlow::AggregateOp const &v) + : raw_variant(v) {} +OperatorAttributeValue::OperatorAttributeValue( + std::optional<::FlexFlow::RegularizerAttrs> const &v) + : raw_variant(v) {} +OperatorAttributeValue::OperatorAttributeValue(::FlexFlow::PoolOp const &v) + : raw_variant(v) {} +OperatorAttributeValue::OperatorAttributeValue(::FlexFlow::TensorShape const &v) + : raw_variant(v) {} +OperatorAttributeValue::OperatorAttributeValue(::FlexFlow::DataType const &v) + : raw_variant(v) {} +bool OperatorAttributeValue::operator==( + OperatorAttributeValue const &other) const { + return this->raw_variant == other.raw_variant; +} +bool OperatorAttributeValue::operator!=( + OperatorAttributeValue const &other) const { + return this->raw_variant != other.raw_variant; +} +bool OperatorAttributeValue::operator<( + OperatorAttributeValue const &other) const { + return this->raw_variant < other.raw_variant; +} +bool OperatorAttributeValue::operator>( + OperatorAttributeValue const &other) const { + return this->raw_variant > other.raw_variant; +} +bool OperatorAttributeValue::operator<=( + OperatorAttributeValue const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool OperatorAttributeValue::operator>=( + OperatorAttributeValue const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::OperatorAttributeValue>::operator()( + ::FlexFlow::OperatorAttributeValue const &x) const { + return std::hash, + std::vector<::FlexFlow::ff_dim_t>, + ::FlexFlow::OperatorType, + ::FlexFlow::Activation, + ::FlexFlow::ff_dim_t, + size_t, + ::FlexFlow::AggregateOp, + std::optional<::FlexFlow::RegularizerAttrs>, + ::FlexFlow::PoolOp, + ::FlexFlow::TensorShape, + ::FlexFlow::DataType>>{}(x.raw_variant); +} +} // namespace std +namespace nlohmann { +::FlexFlow::OperatorAttributeValue + adl_serializer<::FlexFlow::OperatorAttributeValue>::from_json( + json const &j) { + std::string key = j.at("type").template get(); + if (key == "int") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value").template get()}; + } else if (key == "bool") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value").template get()}; + } else if (key == "std::vector") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value").template get>()}; + } else if (key == "std::vector<::FlexFlow::ff_dim_t>") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value").template get>()}; + } else if (key == "::FlexFlow::OperatorType") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value").template get<::FlexFlow::OperatorType>()}; + } else if (key == "::FlexFlow::Activation") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value").template get<::FlexFlow::Activation>()}; + } else if (key == "::FlexFlow::ff_dim_t") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value").template get<::FlexFlow::ff_dim_t>()}; + } else if (key == "size_t") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value").template get()}; + } else if (key == "::FlexFlow::AggregateOp") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value").template get<::FlexFlow::AggregateOp>()}; + } else if (key == "std::optional<::FlexFlow::RegularizerAttrs>") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value") + .template get>()}; + } else if (key == "::FlexFlow::PoolOp") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value").template get<::FlexFlow::PoolOp>()}; + } else if (key == "::FlexFlow::TensorShape") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value").template get<::FlexFlow::TensorShape>()}; + } else if (key == "::FlexFlow::DataType") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value").template get<::FlexFlow::DataType>()}; + } else { + throw std::runtime_error(fmt::format("Unknown type key {}", key)); + } +} +void adl_serializer<::FlexFlow::OperatorAttributeValue>::to_json( + json &j, ::FlexFlow::OperatorAttributeValue const &x) { + j["__type"] = "OperatorAttributeValue"; + switch (x.index()) { + case 0: { + j["type"] = "int"; + j["value"] = x.get(); + break; + } + case 1: { + j["type"] = "bool"; + j["value"] = x.get(); + break; + } + case 2: { + j["type"] = "std::vector"; + j["value"] = x.get>(); + break; + } + case 3: { + j["type"] = "std::vector<::FlexFlow::ff_dim_t>"; + j["value"] = x.get>(); + break; + } + case 4: { + j["type"] = "::FlexFlow::OperatorType"; + j["value"] = x.get<::FlexFlow::OperatorType>(); + break; + } + case 5: { + j["type"] = "::FlexFlow::Activation"; + j["value"] = x.get<::FlexFlow::Activation>(); + break; + } + case 6: { + j["type"] = "::FlexFlow::ff_dim_t"; + j["value"] = x.get<::FlexFlow::ff_dim_t>(); + break; + } + case 7: { + j["type"] = "size_t"; + j["value"] = x.get(); + break; + } + case 8: { + j["type"] = "::FlexFlow::AggregateOp"; + j["value"] = x.get<::FlexFlow::AggregateOp>(); + break; + } + case 9: { + j["type"] = "std::optional<::FlexFlow::RegularizerAttrs>"; + j["value"] = x.get>(); + break; + } + case 10: { + j["type"] = "::FlexFlow::PoolOp"; + j["value"] = x.get<::FlexFlow::PoolOp>(); + break; + } + case 11: { + j["type"] = "::FlexFlow::TensorShape"; + j["value"] = x.get<::FlexFlow::TensorShape>(); + break; + } + case 12: { + j["type"] = "::FlexFlow::DataType"; + j["value"] = x.get<::FlexFlow::DataType>(); + break; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type OperatorAttributeValue", x.index())); + } + } +} +} // namespace nlohmann +namespace FlexFlow { +std::string format_as(::FlexFlow::OperatorAttributeValue const &x) { + std::ostringstream oss; + switch (x.index()) { + case 0: { + oss << ""; + break; + } + case 1: { + oss << ""; + break; + } + case 2: { + oss << "=" + << x.get>() << ">"; + break; + } + case 3: { + oss << "=" + << x.get>() << ">"; + break; + } + case 4: { + oss << ""; + break; + } + case 5: { + oss << ""; + break; + } + case 6: { + oss << ""; + break; + } + case 7: { + oss << ""; + break; + } + case 8: { + oss << ""; + break; + } + case 9: { + oss << "=" + << x.get>() << ">"; + break; + } + case 10: { + oss << ""; + break; + } + case 11: { + oss << ""; + break; + } + case 12: { + oss << ""; + break; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type OperatorAttributeValue", x.index())); + break; + } + } + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::OperatorAttributeValue const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc b/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc new file mode 100644 index 0000000000..ae42515cc8 --- /dev/null +++ b/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc @@ -0,0 +1,26 @@ +#include "substitutions/operator_pattern/satisfies_constraint.h" +#include "substitutions/operator_pattern/operator_attribute_expr.h" + +namespace FlexFlow { + +bool operator_satisfies_constraint( + PCGOperatorAttrs const &attrs, + OperatorAttributeConstraint const &constraint) { + std::optional expr_val = + evaluate_attribute_expr(attrs, constraint.attribute_expr); + + if (!expr_val.has_value()) { + return false; + } + + switch (constraint.constraint_type) { + case ConstraintType::EQUAL: + return expr_val.value() == constraint.attribute_value; + default: + throw mk_runtime_error( + fmt::format("Unknown constraint type {}", + static_cast(constraint.constraint_type))); + } +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/operator_pattern/satisfies_pattern.cc b/lib/substitutions/src/substitutions/operator_pattern/satisfies_pattern.cc new file mode 100644 index 0000000000..60ab363cc6 --- /dev/null +++ b/lib/substitutions/src/substitutions/operator_pattern/satisfies_pattern.cc @@ -0,0 +1,14 @@ +#include "substitutions/operator_pattern/satisfies_pattern.h" +#include "substitutions/operator_pattern/satisfies_constraint.h" + +namespace FlexFlow { + +bool operator_satisfies_pattern(PCGOperatorAttrs const &attrs, + OperatorAttributePattern const &pattern) { + return all_of(pattern.attribute_constraints, + [&](OperatorAttributeConstraint const &c) { + return operator_satisfies_constraint(attrs, c); + }); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/output_graph/attr_constant.dtg.cc b/lib/substitutions/src/substitutions/output_graph/attr_constant.dtg.cc new file mode 100644 index 0000000000..f20afc1164 --- /dev/null +++ b/lib/substitutions/src/substitutions/output_graph/attr_constant.dtg.cc @@ -0,0 +1,59 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/output_graph/attr_constant.struct.toml +/* proj-data +{ + "generated_from": "1e5beabcb8e3657d8fe9c9c8b1310cb1" +} +*/ + +#include "substitutions/output_graph/attr_constant.dtg.h" + +#include "substitutions/operator_pattern/operator_attribute_value.dtg.h" +#include + +namespace FlexFlow { +AttrConstant::AttrConstant(::FlexFlow::OperatorAttributeValue const &value) + : value(value) {} +bool AttrConstant::operator==(AttrConstant const &other) const { + return std::tie(this->value) == std::tie(other.value); +} +bool AttrConstant::operator!=(AttrConstant const &other) const { + return std::tie(this->value) != std::tie(other.value); +} +bool AttrConstant::operator<(AttrConstant const &other) const { + return std::tie(this->value) < std::tie(other.value); +} +bool AttrConstant::operator>(AttrConstant const &other) const { + return std::tie(this->value) > std::tie(other.value); +} +bool AttrConstant::operator<=(AttrConstant const &other) const { + return std::tie(this->value) <= std::tie(other.value); +} +bool AttrConstant::operator>=(AttrConstant const &other) const { + return std::tie(this->value) >= std::tie(other.value); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::AttrConstant const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::OperatorAttributeValue>{}(x.value) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(AttrConstant const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, AttrConstant const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/output_graph/output_graph_expr.dtg.cc b/lib/substitutions/src/substitutions/output_graph/output_graph_expr.dtg.cc new file mode 100644 index 0000000000..7d07bf9218 --- /dev/null +++ b/lib/substitutions/src/substitutions/output_graph/output_graph_expr.dtg.cc @@ -0,0 +1,20 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/output_graph/output_graph_expr.struct.toml +/* proj-data +{ + "generated_from": "9084c9afb2724504a6f4db4288a83a0d" +} +*/ + +#include "substitutions/output_graph/output_graph_expr.dtg.h" + +#include "substitutions/output_graph/output_operator_attrs_assignment.dtg.h" +#include "utils/graph.h" + +namespace FlexFlow { +OutputGraphExpr::OutputGraphExpr( + ::FlexFlow::NodeLabelledOpenMultiDiGraph< + ::FlexFlow::OutputOperatorAttrsAssignment> const &raw_graph) + : raw_graph(raw_graph) {} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/output_graph/output_operator_attr_access.dtg.cc b/lib/substitutions/src/substitutions/output_graph/output_operator_attr_access.dtg.cc new file mode 100644 index 0000000000..0c6abc925d --- /dev/null +++ b/lib/substitutions/src/substitutions/output_graph/output_operator_attr_access.dtg.cc @@ -0,0 +1,77 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.struct.toml +/* proj-data +{ + "generated_from": "e3b3a741183fcb38cfa68aacb82e12d1" +} +*/ + +#include "substitutions/output_graph/output_operator_attr_access.dtg.h" + +#include "substitutions/operator_pattern/operator_attribute_expr.dtg.h" +#include "utils/graph.h" +#include + +namespace FlexFlow { +OutputOperatorAttrAccess::OutputOperatorAttrAccess( + ::FlexFlow::Node const &node, + ::FlexFlow::OperatorAttributeExpr const &attr_expr) + : node(node), attr_expr(attr_expr) {} +bool OutputOperatorAttrAccess::operator==( + OutputOperatorAttrAccess const &other) const { + return std::tie(this->node, this->attr_expr) == + std::tie(other.node, other.attr_expr); +} +bool OutputOperatorAttrAccess::operator!=( + OutputOperatorAttrAccess const &other) const { + return std::tie(this->node, this->attr_expr) != + std::tie(other.node, other.attr_expr); +} +bool OutputOperatorAttrAccess::operator<( + OutputOperatorAttrAccess const &other) const { + return std::tie(this->node, this->attr_expr) < + std::tie(other.node, other.attr_expr); +} +bool OutputOperatorAttrAccess::operator>( + OutputOperatorAttrAccess const &other) const { + return std::tie(this->node, this->attr_expr) > + std::tie(other.node, other.attr_expr); +} +bool OutputOperatorAttrAccess::operator<=( + OutputOperatorAttrAccess const &other) const { + return std::tie(this->node, this->attr_expr) <= + std::tie(other.node, other.attr_expr); +} +bool OutputOperatorAttrAccess::operator>=( + OutputOperatorAttrAccess const &other) const { + return std::tie(this->node, this->attr_expr) >= + std::tie(other.node, other.attr_expr); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::OutputOperatorAttrAccess const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::Node>{}(x.node) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash<::FlexFlow::OperatorAttributeExpr>{}(x.attr_expr) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(OutputOperatorAttrAccess const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, OutputOperatorAttrAccess const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/output_graph/output_operator_attribute_expr.dtg.cc b/lib/substitutions/src/substitutions/output_graph/output_operator_attribute_expr.dtg.cc new file mode 100644 index 0000000000..bf1b07c825 --- /dev/null +++ b/lib/substitutions/src/substitutions/output_graph/output_operator_attribute_expr.dtg.cc @@ -0,0 +1,79 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.variant.toml +/* proj-data +{ + "generated_from": "89ebf777a5b909eef78ab5a5a177e041" +} +*/ + +#include "substitutions/output_graph/output_operator_attribute_expr.dtg.h" + +#include + +namespace FlexFlow { +OutputOperatorAttributeExpr::OutputOperatorAttributeExpr( + ::FlexFlow::OutputOperatorAttrAccess const &v) + : raw_variant(v) {} +OutputOperatorAttributeExpr::OutputOperatorAttributeExpr( + ::FlexFlow::AttrConstant const &v) + : raw_variant(v) {} +bool OutputOperatorAttributeExpr::operator==( + OutputOperatorAttributeExpr const &other) const { + return this->raw_variant == other.raw_variant; +} +bool OutputOperatorAttributeExpr::operator!=( + OutputOperatorAttributeExpr const &other) const { + return this->raw_variant != other.raw_variant; +} +bool OutputOperatorAttributeExpr::operator<( + OutputOperatorAttributeExpr const &other) const { + return this->raw_variant < other.raw_variant; +} +bool OutputOperatorAttributeExpr::operator>( + OutputOperatorAttributeExpr const &other) const { + return this->raw_variant > other.raw_variant; +} +bool OutputOperatorAttributeExpr::operator<=( + OutputOperatorAttributeExpr const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool OutputOperatorAttributeExpr::operator>=( + OutputOperatorAttributeExpr const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::OutputOperatorAttributeExpr>::operator()( + ::FlexFlow::OutputOperatorAttributeExpr const &x) const { + return std::hash>{}(x.raw_variant); +} +} // namespace std +namespace FlexFlow { +std::string format_as(::FlexFlow::OutputOperatorAttributeExpr const &x) { + std::ostringstream oss; + switch (x.index()) { + case 0: { + oss << ""; + break; + } + case 1: { + oss << ""; + break; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type OutputOperatorAttributeExpr", x.index())); + break; + } + } + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::OutputOperatorAttributeExpr const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/output_graph/output_operator_attrs_assignment.dtg.cc b/lib/substitutions/src/substitutions/output_graph/output_operator_attrs_assignment.dtg.cc new file mode 100644 index 0000000000..7a1950482a --- /dev/null +++ b/lib/substitutions/src/substitutions/output_graph/output_operator_attrs_assignment.dtg.cc @@ -0,0 +1,58 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.struct.toml +/* proj-data +{ + "generated_from": "bbfb309c5a39a729da23dace4df4a9de" +} +*/ + +#include "substitutions/output_graph/output_operator_attrs_assignment.dtg.h" + +#include "substitutions/operator_pattern/operator_attribute_key.dtg.h" +#include "substitutions/output_graph/output_operator_attribute_expr.dtg.h" +#include +#include + +namespace FlexFlow { +OutputOperatorAttrsAssignment::OutputOperatorAttrsAssignment( + std::unordered_map<::FlexFlow::OperatorAttributeKey, + ::FlexFlow::OutputOperatorAttributeExpr> const + &assignments) + : assignments(assignments) {} +bool OutputOperatorAttrsAssignment::operator==( + OutputOperatorAttrsAssignment const &other) const { + return std::tie(this->assignments) == std::tie(other.assignments); +} +bool OutputOperatorAttrsAssignment::operator!=( + OutputOperatorAttrsAssignment const &other) const { + return std::tie(this->assignments) != std::tie(other.assignments); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::OutputOperatorAttrsAssignment const &x) const { + size_t result = 0; + result ^= + std::hash>{}( + x.assignments) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(OutputOperatorAttrsAssignment const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + OutputOperatorAttrsAssignment const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/pcg_pattern.dtg.cc b/lib/substitutions/src/substitutions/pcg_pattern.dtg.cc new file mode 100644 index 0000000000..7133ab42a7 --- /dev/null +++ b/lib/substitutions/src/substitutions/pcg_pattern.dtg.cc @@ -0,0 +1,21 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/pcg_pattern.struct.toml +/* proj-data +{ + "generated_from": "f536f846828ba39266dd4a1fbaeec0e6" +} +*/ + +#include "substitutions/pcg_pattern.dtg.h" + +#include "substitutions/operator_pattern/operator_attribute_pattern.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h" +#include "utils/graph.h" + +namespace FlexFlow { +PCGPattern::PCGPattern(::FlexFlow::OutputLabelledOpenMultiDiGraph< + ::FlexFlow::OperatorAttributePattern, + ::FlexFlow::TensorAttributePattern> const &raw_graph) + : raw_graph(raw_graph) {} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc new file mode 100644 index 0000000000..7736113819 --- /dev/null +++ b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc @@ -0,0 +1,22 @@ +#include "substitutions/sub_parallel_computation_graph.h" + +namespace FlexFlow { + +ParallelLayerAttrs + get_parallel_layer_attrs(SubParallelComputationGraph const &spcg, + Node const &n) { + return spcg.raw_graph.at(n); +} + +PCGOperatorAttrs get_operator_attrs(SubParallelComputationGraph const &spcg, + Node const &n) { + return get_parallel_layer_attrs(spcg, n).attrs; +} + +ParallelTensorAttrs + get_parallel_tensor_attrs(SubParallelComputationGraph const &spcg, + OpenMultiDiEdge const &e) { + return spcg.raw_graph.at(e); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.dtg.cc b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.dtg.cc new file mode 100644 index 0000000000..83baef2cfc --- /dev/null +++ b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.dtg.cc @@ -0,0 +1,22 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/sub_parallel_computation_graph.struct.toml +/* proj-data +{ + "generated_from": "0022d1b2c1447667695a120c154a0168" +} +*/ + +#include "substitutions/sub_parallel_computation_graph.dtg.h" + +#include "pcg/parallel_layer_attrs.dtg.h" +#include "pcg/parallel_tensor_attrs.dtg.h" +#include "utils/graph.h" + +namespace FlexFlow { +SubParallelComputationGraph::SubParallelComputationGraph( + ::FlexFlow::OutputLabelledOpenMultiDiGraph< + ::FlexFlow::ParallelLayerAttrs, + ::FlexFlow::ParallelTensorAttrs> const &raw_graph) + : raw_graph(raw_graph) {} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/substitution.cc b/lib/substitutions/src/substitutions/substitution.cc new file mode 100644 index 0000000000..e900175bc6 --- /dev/null +++ b/lib/substitutions/src/substitutions/substitution.cc @@ -0,0 +1,154 @@ +#include "substitutions/substitution.h" + +namespace FlexFlow { + +/* struct AddMappedEdgeFunctor { */ +/* bidict const &node_mapping; */ +/* SubParallelComputationGraph &new_pcg; */ + +/* template */ +/* void operator()(T const &t) { */ +/* return add_mapped_edge(t); */ +/* } */ + +/* void add_mapped_edge(InputMultiDiEdge const &e) { */ +/* new_pcg.add_edge(InputMultiDiEdge{ */ +/* node_mapping.at_l(e.dst), new_pcg.add_node_port(), e.uid}); */ +/* } */ + +/* void add_mapped_edge(OutputMultiDiEdge const &e) { */ +/* new_pcg.add_edge(OutputMultiDiEdge{ */ +/* node_mapping.at_l(e.src), new_pcg.add_node_port(), e.uid}); */ +/* } */ + +/* void add_mapped_edge(MultiDiEdge const &e) { */ +/* new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(e.dst), */ +/* new_pcg.add_node_port(), */ +/* node_mapping.at_l(e.src), */ +/* new_pcg.add_node_port()}); */ +/* } */ +/* }; */ + +/* struct AddNewEdgeFunctor { */ +/* SubParallelComputationGraph const &old_pcg; */ +/* SubParallelComputationGraph &new_pcg; */ +/* MultiDiGraphPatternMatch const &match; */ +/* bidict node_mapping; */ + +/* template */ +/* void operator()(TO const &old_edge, TN const &new_edge) { */ +/* return add_new_edge(old_edge, new_edge); */ +/* } */ + +/* void add_new_edge(InputMultiDiEdge const &old_edge, */ +/* InputMultiDiEdge const &new_edge) { */ +/* new_pcg.add_edge(InputMultiDiEdge{node_mapping.at_l(new_edge.dst), */ +/* new_pcg.add_node_port(), */ +/* old_edge.uid}); */ +/* } */ + +/* void add_new_edge(MultiDiEdge const &old_edge, */ +/* InputMultiDiEdge const &new_edge) { */ +/* new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(new_edge.dst), */ +/* new_pcg.add_node_port(), */ +/* node_mapping.at_l(old_edge.src), */ +/* new_pcg.add_node_port()}); */ +/* } */ + +/* void add_new_edge(OutputMultiDiEdge const &old_edge, */ +/* OutputMultiDiEdge const &new_edge) { */ +/* new_pcg.add_edge(OutputMultiDiEdge{node_mapping.at_l(new_edge.src), */ +/* new_pcg.add_node_port(), */ +/* old_edge.uid}); */ +/* } */ + +/* void add_new_edge(MultiDiEdge const &old_edge, */ +/* OutputMultiDiEdge const &new_edge) { */ +/* new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(old_edge.dst), */ +/* new_pcg.add_node_port(), */ +/* node_mapping.at_l(new_edge.src), */ +/* new_pcg.add_node_port()}); */ +/* } */ + +/* void add_new_edge(InputMultiDiEdge const &, OutputMultiDiEdge const &) { */ +/* assert(false); */ +/* } */ + +/* void add_new_edge(OpenMultiDiEdge const &, MultiDiEdge const &) { */ +/* assert(false); */ +/* } */ + +/* void add_new_edge(OutputMultiDiEdge const &, InputMultiDiEdge const &) { */ +/* assert(false); */ +/* } */ +/* }; */ + +/* SubParallelComputationGraph */ +/* apply_substitution(SubParallelComputationGraph const &pcg, */ +/* Substitution const &substitution, */ +/* MultiDiGraphPatternMatch const &match) { */ +/* SubParallelComputationGraph new_pcg = */ +/* OutputLabelledOpenMultiDiGraph::template + * create< */ +/* UnorderedOutputLabelledOpenMultiDiGraph>(); */ +/* bidict node_mapping; // Refactor it with global nodes */ +/* for (Node const &node : get_nodes(pcg)) { */ +/* if (!contains_r(match.node_assignment, node)) { */ +/* node_mapping.equate(node, new_pcg.add_node(pcg.at(node))); */ +/* } */ +/* } */ +/* for (OpenMultiDiEdge const &edge : get_edges(pcg)) { */ +/* if (!contains_r(match.edge_assignment, edge)) { */ +/* visit(AddMappedEdgeFunctor{node_mapping, new_pcg}, edge); */ +/* } */ +/* } */ +/* for (Node const &output_node : */ +/* get_nodes(substitution.output_graph_expr.value())) { */ +/* Operator new_op = get_operator_attrs( */ +/* pcg, match, substitution.output_graph_expr.value().at(output_node)); + */ +/* Node new_node = new_pcg.add_node(new_op); */ +/* node_mapping.equate(output_node, new_node); */ +/* } */ +/* for (OpenMultiDiEdge const &output_edge : */ +/* get_edges(substitution.output_graph_expr.value())) { */ +/* if (std::holds_alternative(output_edge)) { */ +/* InputMultiDiEdge e = std::get(output_edge); */ +/* OpenMultiDiEdge original_edge = */ +/* match.edge_assignment.at_l(substitution.input_mapping.at_r(e)); */ +/* visit(AddNewEdgeFunctor{pcg, new_pcg, match, node_mapping}, */ +/* original_edge, */ +/* output_edge); */ +/* } else if (std::holds_alternative(output_edge)) { */ +/* OutputMultiDiEdge e = std::get(output_edge); */ +/* OpenMultiDiEdge original_edge = */ +/* match.edge_assignment.at_l(substitution.output_mapping.at_r(e)); */ +/* visit(AddNewEdgeFunctor{pcg, new_pcg, match, node_mapping}, */ +/* original_edge, */ +/* output_edge); */ +/* } else { */ +/* assert(std::holds_alternative(output_edge)); */ +/* MultiDiEdge e = std::get(output_edge); */ +/* new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(e.dst), */ +/* new_pcg.add_node_port(), */ +/* node_mapping.at_l(e.src), */ +/* new_pcg.add_node_port()}); */ +/* } */ +/* } */ + +/* return new_pcg; */ +/* } */ + +bool is_valid_substitution(Substitution const &) { + NOT_IMPLEMENTED(); +} + +SubParallelComputationGraph + apply_substitution(SubParallelComputationGraph const &, + Substitution const &, + MultiDiGraphPatternMatch const &) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/substitution.dtg.cc b/lib/substitutions/src/substitutions/substitution.dtg.cc new file mode 100644 index 0000000000..67d39d6ff7 --- /dev/null +++ b/lib/substitutions/src/substitutions/substitution.dtg.cc @@ -0,0 +1,28 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/substitution.struct.toml +/* proj-data +{ + "generated_from": "c101f1d63e2d8d80a0ec9c5f5db4fa12" +} +*/ + +#include "substitutions/substitution.dtg.h" + +#include "substitutions/output_graph/output_graph_expr.dtg.h" +#include "substitutions/pcg_pattern.dtg.h" + +namespace FlexFlow { +Substitution::Substitution( + ::FlexFlow::PCGPattern const &pcg_pattern, + ::FlexFlow::OutputGraphExpr const &output_graph_expr, + ::FlexFlow::bidict<::FlexFlow::InputMultiDiEdge, + ::FlexFlow::InputMultiDiEdge> const + &input_edge_match_to_output, + ::FlexFlow::bidict<::FlexFlow::OutputMultiDiEdge, + ::FlexFlow::OutputMultiDiEdge> const + &output_edge_match_to_output) + : pcg_pattern(pcg_pattern), output_graph_expr(output_graph_expr), + input_edge_match_to_output(input_edge_match_to_output), + output_edge_match_to_output(output_edge_match_to_output) {} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/tensor_pattern/eval_list_access.cc b/lib/substitutions/src/substitutions/tensor_pattern/eval_list_access.cc new file mode 100644 index 0000000000..ea4833d36a --- /dev/null +++ b/lib/substitutions/src/substitutions/tensor_pattern/eval_list_access.cc @@ -0,0 +1,24 @@ +#include "substitutions/tensor_pattern/eval_list_access.h" +#include "substitutions/tensor_pattern/get_attribute.h" +#include "utils/containers.h" +#include "utils/overload.h" + +namespace FlexFlow { + +TensorAttributeValue + eval_list_access(ParallelTensorAttrs const &attrs, + TensorAttributeListIndexAccess const &acc) { + TensorAttributeValue from_attr = get_attribute(attrs, acc.attribute_key); + + return from_attr.visit(overload{ + [&](std::vector const &v) -> TensorAttributeValue { + return TensorAttributeValue{ + static_cast(at_idx(v, acc.index).value())}; + }, + [](auto &&) -> TensorAttributeValue { + throw mk_runtime_error("Invalid operand"); + }, + }); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/tensor_pattern/eval_list_size.cc b/lib/substitutions/src/substitutions/tensor_pattern/eval_list_size.cc new file mode 100644 index 0000000000..d1e97adc37 --- /dev/null +++ b/lib/substitutions/src/substitutions/tensor_pattern/eval_list_size.cc @@ -0,0 +1,21 @@ +#include "substitutions/tensor_pattern/eval_list_size.h" +#include "substitutions/tensor_pattern/get_attribute.h" +#include "utils/overload.h" + +namespace FlexFlow { + +TensorAttributeValue eval_list_size(ParallelTensorAttrs const &attrs, + TensorAttributeListSize const &acc) { + TensorAttributeValue from_attr = get_attribute(attrs, acc.attribute_key); + + return from_attr.visit(overload{ + [](std::vector const &v) -> TensorAttributeValue { + return TensorAttributeValue{v.size()}; + }, + [](auto &&) -> TensorAttributeValue { + throw mk_runtime_error("Invalid operand"); + }, + }); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc b/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc new file mode 100644 index 0000000000..7c42bdd904 --- /dev/null +++ b/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc @@ -0,0 +1,28 @@ +#include "substitutions/tensor_pattern/get_attribute.h" +#include "utils/containers.h" +#include "utils/integer_conversions.h" + +namespace FlexFlow { + +TensorAttributeValue get_attribute(ParallelTensorAttrs const &attrs, + TensorAttributeKey key) { + switch (key) { + case TensorAttributeKey::DIM_SIZES: { + std::vector sizes = + transform(as_vector(ff_ordered_shard_dims(attrs.shape.dims)), + [](ShardParallelDim const &d) { return d.size; }); + return TensorAttributeValue{sizes}; + } + case TensorAttributeKey::DIM_DEGREES: { + std::vector degrees = transform( + as_vector(ff_ordered_shard_dims(attrs.shape.dims)), + [](ShardParallelDim const &d) { return size_t_from_int(d.degree); }); + return TensorAttributeValue{degrees}; + } + default: + throw std::runtime_error( + fmt::format("Unknown TensorAttributeKey {}", static_cast(key))); + } +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc b/lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc new file mode 100644 index 0000000000..974bfcabc0 --- /dev/null +++ b/lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc @@ -0,0 +1,22 @@ +#include "substitutions/tensor_pattern/satisfies_constraint.h" +#include "substitutions/tensor_pattern/tensor_attribute_expr.h" + +namespace FlexFlow { + +bool parallel_tensor_satisfies_constraint( + ParallelTensorAttrs const &attrs, + TensorAttributeConstraint const &constraint) { + TensorAttributeValue expr_val = + evaluate_attribute_expr(attrs, constraint.attribute_expr); + + switch (constraint.constraint_type) { + case ConstraintType::EQUAL: + return expr_val == constraint.attribute_value; + default: + throw mk_runtime_error( + fmt::format("Unknown constraint type {}", + static_cast(constraint.constraint_type))); + } +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/tensor_pattern/satisfies_pattern.cc b/lib/substitutions/src/substitutions/tensor_pattern/satisfies_pattern.cc new file mode 100644 index 0000000000..35fec2dfea --- /dev/null +++ b/lib/substitutions/src/substitutions/tensor_pattern/satisfies_pattern.cc @@ -0,0 +1,13 @@ +#include "substitutions/tensor_pattern/satisfies_pattern.h" +#include "substitutions/tensor_pattern/satisfies_constraint.h" + +namespace FlexFlow { + +bool parallel_tensor_satisfies_pattern(ParallelTensorAttrs const &attrs, + TensorAttributePattern const &pattern) { + return all_of(pattern.attribute_constraints, + [&](TensorAttributeConstraint const &c) { + return parallel_tensor_satisfies_constraint(attrs, c); + }); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_constraint.dtg.cc b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_constraint.dtg.cc new file mode 100644 index 0000000000..6f9df90fb2 --- /dev/null +++ b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_constraint.dtg.cc @@ -0,0 +1,119 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_constraint.struct.toml +/* proj-data +{ + "generated_from": "29dbf81668bc864b06af52261060335e" +} +*/ + +#include "substitutions/tensor_pattern/tensor_attribute_constraint.dtg.h" + +#include "substitutions/constraint_type.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_expr.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_value.dtg.h" +#include + +namespace FlexFlow { +TensorAttributeConstraint::TensorAttributeConstraint( + ::FlexFlow::ConstraintType const &constraint_type, + ::FlexFlow::TensorAttributeExpr const &attribute_expr, + ::FlexFlow::TensorAttributeValue const &attribute_value) + : constraint_type(constraint_type), attribute_expr(attribute_expr), + attribute_value(attribute_value) {} +bool TensorAttributeConstraint::operator==( + TensorAttributeConstraint const &other) const { + return std::tie(this->constraint_type, + this->attribute_expr, + this->attribute_value) == std::tie(other.constraint_type, + other.attribute_expr, + other.attribute_value); +} +bool TensorAttributeConstraint::operator!=( + TensorAttributeConstraint const &other) const { + return std::tie(this->constraint_type, + this->attribute_expr, + this->attribute_value) != std::tie(other.constraint_type, + other.attribute_expr, + other.attribute_value); +} +bool TensorAttributeConstraint::operator<( + TensorAttributeConstraint const &other) const { + return std::tie(this->constraint_type, + this->attribute_expr, + this->attribute_value) < std::tie(other.constraint_type, + other.attribute_expr, + other.attribute_value); +} +bool TensorAttributeConstraint::operator>( + TensorAttributeConstraint const &other) const { + return std::tie(this->constraint_type, + this->attribute_expr, + this->attribute_value) > std::tie(other.constraint_type, + other.attribute_expr, + other.attribute_value); +} +bool TensorAttributeConstraint::operator<=( + TensorAttributeConstraint const &other) const { + return std::tie(this->constraint_type, + this->attribute_expr, + this->attribute_value) <= std::tie(other.constraint_type, + other.attribute_expr, + other.attribute_value); +} +bool TensorAttributeConstraint::operator>=( + TensorAttributeConstraint const &other) const { + return std::tie(this->constraint_type, + this->attribute_expr, + this->attribute_value) >= std::tie(other.constraint_type, + other.attribute_expr, + other.attribute_value); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::TensorAttributeConstraint const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ConstraintType>{}(x.constraint_type) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::TensorAttributeExpr>{}(x.attribute_expr) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::TensorAttributeValue>{}(x.attribute_value) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::TensorAttributeConstraint + adl_serializer::from_json( + json const &j) { + return { + j.at("constraint_type").template get<::FlexFlow::ConstraintType>(), + j.at("attribute_expr").template get<::FlexFlow::TensorAttributeExpr>(), + j.at("attribute_value").template get<::FlexFlow::TensorAttributeValue>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::TensorAttributeConstraint const &v) { + j["__type"] = "TensorAttributeConstraint"; + j["constraint_type"] = v.constraint_type; + j["attribute_expr"] = v.attribute_expr; + j["attribute_value"] = v.attribute_value; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(TensorAttributeConstraint const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, TensorAttributeConstraint const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_expr.cc b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_expr.cc new file mode 100644 index 0000000000..33bcc1a082 --- /dev/null +++ b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_expr.cc @@ -0,0 +1,22 @@ +#include "substitutions/tensor_pattern/tensor_attribute_expr.h" +#include "substitutions/tensor_pattern/eval_list_access.h" +#include "substitutions/tensor_pattern/eval_list_size.h" +#include "substitutions/tensor_pattern/get_attribute.h" +#include "utils/overload.h" + +namespace FlexFlow { + +TensorAttributeValue evaluate_attribute_expr(ParallelTensorAttrs const &attrs, + TensorAttributeExpr const &expr) { + + return expr.visit(overload{ + [&](TensorAttributeKey const &key) { return get_attribute(attrs, key); }, + [&](TensorAttributeListSize const &s) { + return eval_list_size(attrs, s); + }, + [&](TensorAttributeListIndexAccess const &s) { + return eval_list_access(attrs, s); + }}); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_expr.dtg.cc b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_expr.dtg.cc new file mode 100644 index 0000000000..a42f18bf26 --- /dev/null +++ b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_expr.dtg.cc @@ -0,0 +1,129 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.variant.toml +/* proj-data +{ + "generated_from": "b91285329f12f1b409805cbf9be575b2" +} +*/ + +#include "substitutions/tensor_pattern/tensor_attribute_expr.dtg.h" + +#include "fmt/format.h" +#include +#include + +namespace FlexFlow { +TensorAttributeExpr::TensorAttributeExpr( + ::FlexFlow::TensorAttributeKey const &v) + : raw_variant(v) {} +TensorAttributeExpr::TensorAttributeExpr( + ::FlexFlow::TensorAttributeListSize const &v) + : raw_variant(v) {} +TensorAttributeExpr::TensorAttributeExpr( + ::FlexFlow::TensorAttributeListIndexAccess const &v) + : raw_variant(v) {} +bool TensorAttributeExpr::operator==(TensorAttributeExpr const &other) const { + return this->raw_variant == other.raw_variant; +} +bool TensorAttributeExpr::operator!=(TensorAttributeExpr const &other) const { + return this->raw_variant != other.raw_variant; +} +bool TensorAttributeExpr::operator<(TensorAttributeExpr const &other) const { + return this->raw_variant < other.raw_variant; +} +bool TensorAttributeExpr::operator>(TensorAttributeExpr const &other) const { + return this->raw_variant > other.raw_variant; +} +bool TensorAttributeExpr::operator<=(TensorAttributeExpr const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool TensorAttributeExpr::operator>=(TensorAttributeExpr const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::TensorAttributeExpr>::operator()( + ::FlexFlow::TensorAttributeExpr const &x) const { + return std::hash>{}( + x.raw_variant); +} +} // namespace std +namespace nlohmann { +::FlexFlow::TensorAttributeExpr + adl_serializer<::FlexFlow::TensorAttributeExpr>::from_json(json const &j) { + std::string key = j.at("type").template get(); + if (key == "key") { + return ::FlexFlow::TensorAttributeExpr{ + j.at("value").template get<::FlexFlow::TensorAttributeKey>()}; + } else if (key == "list_size") { + return ::FlexFlow::TensorAttributeExpr{ + j.at("value").template get<::FlexFlow::TensorAttributeListSize>()}; + } else if (key == "list_idx") { + return ::FlexFlow::TensorAttributeExpr{ + j.at("value") + .template get<::FlexFlow::TensorAttributeListIndexAccess>()}; + } else { + throw std::runtime_error(fmt::format("Unknown type key {}", key)); + } +} +void adl_serializer<::FlexFlow::TensorAttributeExpr>::to_json( + json &j, ::FlexFlow::TensorAttributeExpr const &x) { + j["__type"] = "TensorAttributeExpr"; + switch (x.index()) { + case 0: { + j["type"] = "key"; + j["value"] = x.get<::FlexFlow::TensorAttributeKey>(); + break; + } + case 1: { + j["type"] = "list_size"; + j["value"] = x.get<::FlexFlow::TensorAttributeListSize>(); + break; + } + case 2: { + j["type"] = "list_idx"; + j["value"] = x.get<::FlexFlow::TensorAttributeListIndexAccess>(); + break; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type TensorAttributeExpr", x.index())); + } + } +} +} // namespace nlohmann +namespace FlexFlow { +std::string format_as(::FlexFlow::TensorAttributeExpr const &x) { + std::ostringstream oss; + switch (x.index()) { + case 0: { + oss << ""; + break; + } + case 1: { + oss << ""; + break; + } + case 2: { + oss << ""; + break; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type TensorAttributeExpr", x.index())); + break; + } + } + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::TensorAttributeExpr const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_key.dtg.cc b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_key.dtg.cc new file mode 100644 index 0000000000..fe87c63777 --- /dev/null +++ b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_key.dtg.cc @@ -0,0 +1,73 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.enum.toml +/* proj-data +{ + "generated_from": "63a7c40c1e5b582f98b59750a35f0a08" +} +*/ + +#include "substitutions/tensor_pattern/tensor_attribute_key.dtg.h" + +#include +#include + +namespace std { +size_t hash::operator()( + FlexFlow::TensorAttributeKey x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(TensorAttributeKey x) { + switch (x) { + case TensorAttributeKey::DIM_SIZES: + return "DIM_SIZES"; + case TensorAttributeKey::DIM_DEGREES: + return "DIM_DEGREES"; + default: + std::ostringstream oss; + oss << "Unknown TensorAttributeKey value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, TensorAttributeKey x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, TensorAttributeKey x) { + switch (x) { + case TensorAttributeKey::DIM_SIZES: + j = "DIM_SIZES"; + break; + case TensorAttributeKey::DIM_DEGREES: + j = "DIM_DEGREES"; + break; + default: + std::ostringstream oss; + oss << "Unknown TensorAttributeKey value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, TensorAttributeKey &x) { + std::string as_str = j.get(); + if (as_str == "DIM_SIZES") { + x = TensorAttributeKey::DIM_SIZES; + } else if (as_str == "DIM_DEGREES") { + x = TensorAttributeKey::DIM_DEGREES; + } else { + std::ostringstream oss; + oss << "Unknown TensorAttributeKey value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::element( + FlexFlow::TensorAttributeKey::DIM_SIZES, + FlexFlow::TensorAttributeKey::DIM_DEGREES); +} +} // namespace rc diff --git a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.cc b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.cc new file mode 100644 index 0000000000..4e28de2c28 --- /dev/null +++ b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.cc @@ -0,0 +1,99 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.struct.toml +/* proj-data +{ + "generated_from": "41f5449cd700b6d7ab017f3efa39dc1d" +} +*/ + +#include "substitutions/tensor_pattern/tensor_attribute_list_access.dtg.h" + +#include "substitutions/tensor_pattern/tensor_attribute_key.dtg.h" +#include + +namespace FlexFlow { +TensorAttributeListIndexAccess::TensorAttributeListIndexAccess( + ::FlexFlow::TensorAttributeKey const &attribute_key, int const &index) + : attribute_key(attribute_key), index(index) {} +bool TensorAttributeListIndexAccess::operator==( + TensorAttributeListIndexAccess const &other) const { + return std::tie(this->attribute_key, this->index) == + std::tie(other.attribute_key, other.index); +} +bool TensorAttributeListIndexAccess::operator!=( + TensorAttributeListIndexAccess const &other) const { + return std::tie(this->attribute_key, this->index) != + std::tie(other.attribute_key, other.index); +} +bool TensorAttributeListIndexAccess::operator<( + TensorAttributeListIndexAccess const &other) const { + return std::tie(this->attribute_key, this->index) < + std::tie(other.attribute_key, other.index); +} +bool TensorAttributeListIndexAccess::operator>( + TensorAttributeListIndexAccess const &other) const { + return std::tie(this->attribute_key, this->index) > + std::tie(other.attribute_key, other.index); +} +bool TensorAttributeListIndexAccess::operator<=( + TensorAttributeListIndexAccess const &other) const { + return std::tie(this->attribute_key, this->index) <= + std::tie(other.attribute_key, other.index); +} +bool TensorAttributeListIndexAccess::operator>=( + TensorAttributeListIndexAccess const &other) const { + return std::tie(this->attribute_key, this->index) >= + std::tie(other.attribute_key, other.index); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::TensorAttributeListIndexAccess const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::TensorAttributeKey>{}(x.attribute_key) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.index) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::TensorAttributeListIndexAccess + adl_serializer::from_json( + json const &j) { + return {j.at("attribute_key").template get<::FlexFlow::TensorAttributeKey>(), + j.at("index").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::TensorAttributeListIndexAccess const &v) { + j["__type"] = "TensorAttributeListIndexAccess"; + j["attribute_key"] = v.attribute_key; + j["index"] = v.index; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::TensorAttributeKey>(), gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(TensorAttributeListIndexAccess const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + TensorAttributeListIndexAccess const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_list_size.dtg.cc b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_list_size.dtg.cc new file mode 100644 index 0000000000..24d8b6c025 --- /dev/null +++ b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_list_size.dtg.cc @@ -0,0 +1,87 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_size.struct.toml +/* proj-data +{ + "generated_from": "ec72cd39de5d1c0f0478696d7b83e4e9" +} +*/ + +#include "substitutions/tensor_pattern/tensor_attribute_list_size.dtg.h" + +#include "substitutions/tensor_pattern/tensor_attribute_key.dtg.h" +#include + +namespace FlexFlow { +TensorAttributeListSize::TensorAttributeListSize( + ::FlexFlow::TensorAttributeKey const &attribute_key) + : attribute_key(attribute_key) {} +bool TensorAttributeListSize::operator==( + TensorAttributeListSize const &other) const { + return std::tie(this->attribute_key) == std::tie(other.attribute_key); +} +bool TensorAttributeListSize::operator!=( + TensorAttributeListSize const &other) const { + return std::tie(this->attribute_key) != std::tie(other.attribute_key); +} +bool TensorAttributeListSize::operator<( + TensorAttributeListSize const &other) const { + return std::tie(this->attribute_key) < std::tie(other.attribute_key); +} +bool TensorAttributeListSize::operator>( + TensorAttributeListSize const &other) const { + return std::tie(this->attribute_key) > std::tie(other.attribute_key); +} +bool TensorAttributeListSize::operator<=( + TensorAttributeListSize const &other) const { + return std::tie(this->attribute_key) <= std::tie(other.attribute_key); +} +bool TensorAttributeListSize::operator>=( + TensorAttributeListSize const &other) const { + return std::tie(this->attribute_key) >= std::tie(other.attribute_key); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::TensorAttributeListSize const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::TensorAttributeKey>{}(x.attribute_key) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::TensorAttributeListSize + adl_serializer::from_json( + json const &j) { + return {j.at("attribute_key").template get<::FlexFlow::TensorAttributeKey>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::TensorAttributeListSize const &v) { + j["__type"] = "TensorAttributeListSize"; + j["attribute_key"] = v.attribute_key; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::TensorAttributeKey>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(TensorAttributeListSize const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, TensorAttributeListSize const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_pattern.dtg.cc b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_pattern.dtg.cc new file mode 100644 index 0000000000..121549d4dc --- /dev/null +++ b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_pattern.dtg.cc @@ -0,0 +1,71 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.struct.toml +/* proj-data +{ + "generated_from": "42a51afce383f1ddc3d70827aa94a68f" +} +*/ + +#include "substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h" + +#include "substitutions/tensor_pattern/tensor_attribute_constraint.dtg.h" +#include "utils/hash-utils.h" +#include +#include + +namespace FlexFlow { +TensorAttributePattern::TensorAttributePattern( + std::unordered_set<::FlexFlow::TensorAttributeConstraint> const + &attribute_constraints) + : attribute_constraints(attribute_constraints) {} +bool TensorAttributePattern::operator==( + TensorAttributePattern const &other) const { + return std::tie(this->attribute_constraints) == + std::tie(other.attribute_constraints); +} +bool TensorAttributePattern::operator!=( + TensorAttributePattern const &other) const { + return std::tie(this->attribute_constraints) != + std::tie(other.attribute_constraints); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::TensorAttributePattern const &x) const { + size_t result = 0; + result ^= + std::hash>{}( + x.attribute_constraints) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::TensorAttributePattern + adl_serializer::from_json(json const &j) { + return {j.at("attribute_constraints") + .template get< + std::unordered_set<::FlexFlow::TensorAttributeConstraint>>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::TensorAttributePattern const &v) { + j["__type"] = "TensorAttributePattern"; + j["attribute_constraints"] = v.attribute_constraints; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(TensorAttributePattern const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, TensorAttributePattern const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_value.dtg.cc b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_value.dtg.cc new file mode 100644 index 0000000000..27a82c4ffe --- /dev/null +++ b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_value.dtg.cc @@ -0,0 +1,105 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.variant.toml +/* proj-data +{ + "generated_from": "d80cf2e618d64df284c2647430a12a86" +} +*/ + +#include "substitutions/tensor_pattern/tensor_attribute_value.dtg.h" + +#include "fmt/format.h" +#include +#include + +namespace FlexFlow { +TensorAttributeValue::TensorAttributeValue(size_t const &v) : raw_variant(v) {} +TensorAttributeValue::TensorAttributeValue(std::vector const &v) + : raw_variant(v) {} +bool TensorAttributeValue::operator==(TensorAttributeValue const &other) const { + return this->raw_variant == other.raw_variant; +} +bool TensorAttributeValue::operator!=(TensorAttributeValue const &other) const { + return this->raw_variant != other.raw_variant; +} +bool TensorAttributeValue::operator<(TensorAttributeValue const &other) const { + return this->raw_variant < other.raw_variant; +} +bool TensorAttributeValue::operator>(TensorAttributeValue const &other) const { + return this->raw_variant > other.raw_variant; +} +bool TensorAttributeValue::operator<=(TensorAttributeValue const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool TensorAttributeValue::operator>=(TensorAttributeValue const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::TensorAttributeValue>::operator()( + ::FlexFlow::TensorAttributeValue const &x) const { + return std::hash>>{}(x.raw_variant); +} +} // namespace std +namespace nlohmann { +::FlexFlow::TensorAttributeValue + adl_serializer<::FlexFlow::TensorAttributeValue>::from_json(json const &j) { + std::string key = j.at("type").template get(); + if (key == "size_t") { + return ::FlexFlow::TensorAttributeValue{ + j.at("value").template get()}; + } else if (key == "std::vector") { + return ::FlexFlow::TensorAttributeValue{ + j.at("value").template get>()}; + } else { + throw std::runtime_error(fmt::format("Unknown type key {}", key)); + } +} +void adl_serializer<::FlexFlow::TensorAttributeValue>::to_json( + json &j, ::FlexFlow::TensorAttributeValue const &x) { + j["__type"] = "TensorAttributeValue"; + switch (x.index()) { + case 0: { + j["type"] = "size_t"; + j["value"] = x.get(); + break; + } + case 1: { + j["type"] = "std::vector"; + j["value"] = x.get>(); + break; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type TensorAttributeValue", x.index())); + } + } +} +} // namespace nlohmann +namespace FlexFlow { +std::string format_as(::FlexFlow::TensorAttributeValue const &x) { + std::ostringstream oss; + switch (x.index()) { + case 0: { + oss << ""; + break; + } + case 1: { + oss << "=" + << x.get>() << ">"; + break; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type TensorAttributeValue", x.index())); + break; + } + } + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::TensorAttributeValue const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/closed_pattern_edge.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/closed_pattern_edge.dtg.cc new file mode 100644 index 0000000000..fbefc6f01a --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/closed_pattern_edge.dtg.cc @@ -0,0 +1,45 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/closed_pattern_edge.struct.toml +/* proj-data +{ + "generated_from": "b4086fd78ca7ec0475ed7abfd034304c" +} +*/ + +#include "substitutions/unlabelled/closed_pattern_edge.dtg.h" + +#include "utils/graph.h" + +namespace FlexFlow { +ClosedPatternEdge::ClosedPatternEdge(::FlexFlow::MultiDiEdge const &raw_edge) + : raw_edge(raw_edge) {} +bool ClosedPatternEdge::operator==(ClosedPatternEdge const &other) const { + return std::tie(this->raw_edge) == std::tie(other.raw_edge); +} +bool ClosedPatternEdge::operator!=(ClosedPatternEdge const &other) const { + return std::tie(this->raw_edge) != std::tie(other.raw_edge); +} +bool ClosedPatternEdge::operator<(ClosedPatternEdge const &other) const { + return std::tie(this->raw_edge) < std::tie(other.raw_edge); +} +bool ClosedPatternEdge::operator>(ClosedPatternEdge const &other) const { + return std::tie(this->raw_edge) > std::tie(other.raw_edge); +} +bool ClosedPatternEdge::operator<=(ClosedPatternEdge const &other) const { + return std::tie(this->raw_edge) <= std::tie(other.raw_edge); +} +bool ClosedPatternEdge::operator>=(ClosedPatternEdge const &other) const { + return std::tie(this->raw_edge) >= std::tie(other.raw_edge); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ClosedPatternEdge const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::MultiDiEdge>{}(x.raw_edge) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std diff --git a/lib/substitutions/src/substitutions/unlabelled/downward_open_pattern_edge.cc b/lib/substitutions/src/substitutions/unlabelled/downward_open_pattern_edge.cc new file mode 100644 index 0000000000..704e0aea1a --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/downward_open_pattern_edge.cc @@ -0,0 +1,9 @@ +#include "substitutions/unlabelled/downward_open_pattern_edge.h" + +namespace FlexFlow { + +int get_src_idx(DownwardOpenPatternEdge const &e) { + return get_src_idx(e.raw_edge); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/downward_open_pattern_edge.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/downward_open_pattern_edge.dtg.cc new file mode 100644 index 0000000000..30c52fbbb2 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/downward_open_pattern_edge.dtg.cc @@ -0,0 +1,52 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.struct.toml +/* proj-data +{ + "generated_from": "c67ec363a91ce090dc538dcf76fa1f12" +} +*/ + +#include "substitutions/unlabelled/downward_open_pattern_edge.dtg.h" + +#include "utils/graph.h" + +namespace FlexFlow { +DownwardOpenPatternEdge::DownwardOpenPatternEdge( + ::FlexFlow::DownwardOpenMultiDiEdge const &raw_edge) + : raw_edge(raw_edge) {} +bool DownwardOpenPatternEdge::operator==( + DownwardOpenPatternEdge const &other) const { + return std::tie(this->raw_edge) == std::tie(other.raw_edge); +} +bool DownwardOpenPatternEdge::operator!=( + DownwardOpenPatternEdge const &other) const { + return std::tie(this->raw_edge) != std::tie(other.raw_edge); +} +bool DownwardOpenPatternEdge::operator<( + DownwardOpenPatternEdge const &other) const { + return std::tie(this->raw_edge) < std::tie(other.raw_edge); +} +bool DownwardOpenPatternEdge::operator>( + DownwardOpenPatternEdge const &other) const { + return std::tie(this->raw_edge) > std::tie(other.raw_edge); +} +bool DownwardOpenPatternEdge::operator<=( + DownwardOpenPatternEdge const &other) const { + return std::tie(this->raw_edge) <= std::tie(other.raw_edge); +} +bool DownwardOpenPatternEdge::operator>=( + DownwardOpenPatternEdge const &other) const { + return std::tie(this->raw_edge) >= std::tie(other.raw_edge); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::DownwardOpenPatternEdge const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::DownwardOpenMultiDiEdge>{}(x.raw_edge) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std diff --git a/lib/substitutions/src/substitutions/unlabelled/edge_splits.cc b/lib/substitutions/src/substitutions/unlabelled/edge_splits.cc new file mode 100644 index 0000000000..33ea7dc9f6 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/edge_splits.cc @@ -0,0 +1,35 @@ +#include "substitutions/unlabelled/edge_splits.h" + +namespace FlexFlow { + +std::pair + get_split_edges(UnlabelledPatternEdgeSplits const &splits, + ClosedPatternEdge const &e) { + std::pair raw_result = + splits.unwrapped.at_l(e.raw_edge); + return { + OutputPatternEdge{raw_result.first}, + InputPatternEdge{raw_result.second}, + }; +} + +std::vector> + as_closed_output_input_tuples(UnlabelledPatternEdgeSplits const &s) { + std::vector< + std::tuple> + result; + + for (auto const &kv : s.unwrapped) { + MultiDiEdge standard_edge = kv.first; + OutputMultiDiEdge output_edge = kv.second.first; + InputMultiDiEdge input_edge = kv.second.second; + + result.push_back({ClosedPatternEdge{standard_edge}, + OutputPatternEdge{output_edge}, + InputPatternEdge{input_edge}}); + } + + return result; +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/edge_splits.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/edge_splits.dtg.cc new file mode 100644 index 0000000000..4da15179da --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/edge_splits.dtg.cc @@ -0,0 +1,31 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/edge_splits.struct.toml +/* proj-data +{ + "generated_from": "f172b041a99f4de1d396e5d451a5e64d" +} +*/ + +#include "substitutions/unlabelled/edge_splits.dtg.h" + +#include "utils/bidict.h" +#include "utils/graph.h" +#include + +namespace FlexFlow { +UnlabelledPatternEdgeSplits::UnlabelledPatternEdgeSplits( + ::FlexFlow::bidict<::FlexFlow::MultiDiEdge, + std::pair<::FlexFlow::OutputMultiDiEdge, + ::FlexFlow::InputMultiDiEdge>> const + &unwrapped) + : unwrapped(unwrapped) {} +bool UnlabelledPatternEdgeSplits::operator==( + UnlabelledPatternEdgeSplits const &other) const { + return std::tie(this->unwrapped) == std::tie(other.unwrapped); +} +bool UnlabelledPatternEdgeSplits::operator!=( + UnlabelledPatternEdgeSplits const &other) const { + return std::tie(this->unwrapped) != std::tie(other.unwrapped); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc b/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc new file mode 100644 index 0000000000..8c787ca255 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc @@ -0,0 +1,161 @@ +#include "substitutions/unlabelled/find_pattern_matches.h" +#include "substitutions/unlabelled/downward_open_pattern_edge.h" +#include "substitutions/unlabelled/multidigraph_pattern_match.h" +#include "substitutions/unlabelled/unlabelled_graph_pattern.h" +#include "substitutions/unlabelled/upward_open_pattern_edge.h" +#include "utils/containers.h" + +namespace FlexFlow { + +static std::vector + sorted_by_dst_idx(std::unordered_set const &in) { + return sorted_by( + in, compare_by([](UpwardOpenPatternEdge const &e) { + return get_dst_idx(e); + })); +} + +static std::vector + sorted_by_src_idx(std::unordered_set const &in) { + return sorted_by( + in, + compare_by( + [](DownwardOpenPatternEdge const &e) { return get_src_idx(e); })); +} + +static std::vector + sorted_by_dst_idx(std::unordered_set const &in) { + return sorted_by( + in, compare_by([](UpwardOpenPatternEdge const &e) { + return get_dst_idx(e); + })); +} + +static std::vector + sorted_by_src_idx(std::unordered_set const &in) { + return sorted_by( + in, + compare_by( + [](DownwardOpenMultiDiEdge const &e) { return get_src_idx(e); })); +} + +static std::optional + get_candidate_singleton_match(UnlabelledGraphPattern const &pattern, + OpenMultiDiGraphView const &graph, + Node const &graph_node) { + assert(is_singleton_pattern(pattern)); + + PatternNode pattern_node = get_only(get_nodes(pattern)); + + MultiDiGraphPatternMatch match = empty_multidigraph_pattern_match(); + match.node_assignment.equate(pattern_node, graph_node); + + std::unordered_set incoming = + get_incoming_edges(graph, graph_node); + std::unordered_set outgoing = + get_outgoing_edges(graph, graph_node); + + std::unordered_set pattern_incoming = + get_incoming_edges(pattern, pattern_node); + std::unordered_set pattern_outgoing = + get_outgoing_edges(pattern, pattern_node); + + if (!pattern_incoming.empty() && pattern_incoming.size() != incoming.size()) { + return std::nullopt; + } + + if (!pattern_outgoing.empty() && pattern_outgoing.size() != outgoing.size()) { + return std::nullopt; + } + + std::vector incoming_ordered = + sorted_by_dst_idx(incoming); + std::vector outgoing_ordered = + sorted_by_src_idx(outgoing); + + std::vector pattern_incoming_ordered = + sorted_by_dst_idx(pattern_incoming); + std::vector pattern_outgoing_ordered = + sorted_by_src_idx(pattern_outgoing); + + if (pattern_incoming.size() > 0) { + std::unordered_map node_port_mapping; + for (int i = 0; i < incoming_ordered.size(); ++i) { + UpwardOpenMultiDiEdge graph_edge = incoming_ordered[i]; + UpwardOpenPatternEdge pattern_edge = pattern_incoming_ordered[i]; + NodePort graph_port = get_dst_idx(graph_edge), + pattern_port = get_dst_idx(pattern_edge); + if (!contains_key(node_port_mapping, graph_port)) { + node_port_mapping.emplace(graph_port, pattern_port); + } else { + if (pattern_port != node_port_mapping.at(graph_port)) { + return std::nullopt; + } + } + match.edge_assignment.equate(widen(pattern_edge), + widen(graph_edge)); + } + } + + if (pattern_outgoing.size() > 0) { + std::unordered_map node_port_mapping; + for (int i = 0; i < outgoing_ordered.size(); ++i) { + DownwardOpenMultiDiEdge graph_edge = outgoing_ordered[i], + DownwardOpenPatternEdge pattern_edge = + pattern_outgoing_ordered[i]; + + NodePort graph_port = get_src_idx(graph_edge), + pattern_port = get_src_idx(pattern_edge); + if (!contains_key(node_port_mapping, graph_port)) { + node_port_mapping.insert({graph_port, pattern_port}); + } else { + if (pattern_port != node_port_mapping.at(graph_port)) { + return std::nullopt; + } + } + match.edge_assignment.equate(widen(pattern_edge), + widen(graph_edge)); + } + } + + return match; +} + +std::vector + find_pattern_matches(UnlabelledGraphPattern const &pattern, + OpenMultiDiGraphView const &graph, + MatchAdditionalCriterion const &additional_criterion) { + std::vector matches; + if (is_singleton_pattern(pattern)) { + for (Node const &graph_node : get_nodes(graph)) { + std::optional candidate = + get_candidate_singleton_match(pattern, graph, graph_node); + if (candidate.has_value() && + pattern_does_match( + pattern, graph, candidate.value(), additional_criterion)) { + matches.push_back(candidate.value()); + } + } + } else { + GraphSplit split = split_pattern(pattern); + auto subpatterns = apply_split(pattern, split); + auto prefix_matches = + find_pattern_matches(subpatterns.first, graph, additional_criterion); + auto postfix_matches = + find_pattern_matches(subpatterns.second, graph, additional_criterion); + auto edge_splits = get_edge_splits(pattern, split); + for (MultiDiGraphPatternMatch const &prefix_match : prefix_matches) { + for (MultiDiGraphPatternMatch const &postfix_match : postfix_matches) { + std::optional unsplit = + unsplit_matches(prefix_match, postfix_match, edge_splits); + if (unsplit.has_value()) { + matches.push_back(unsplit.value()); + } + } + } + } + + return matches; +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.cc b/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.cc new file mode 100644 index 0000000000..2eff39bb1e --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.cc @@ -0,0 +1,9 @@ +#include "substitutions/unlabelled/input_pattern_edge.h" + +namespace FlexFlow { + +PatternNode get_dst_node(InputPatternEdge const &e) { + return PatternNode{e.raw_edge.dst}; +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.dtg.cc new file mode 100644 index 0000000000..f3f5a8ce45 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.dtg.cc @@ -0,0 +1,45 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.struct.toml +/* proj-data +{ + "generated_from": "d0cc0e65c4e3feb2e9b8435947c99e5f" +} +*/ + +#include "substitutions/unlabelled/input_pattern_edge.dtg.h" + +#include "utils/graph.h" + +namespace FlexFlow { +InputPatternEdge::InputPatternEdge(::FlexFlow::InputMultiDiEdge const &raw_edge) + : raw_edge(raw_edge) {} +bool InputPatternEdge::operator==(InputPatternEdge const &other) const { + return std::tie(this->raw_edge) == std::tie(other.raw_edge); +} +bool InputPatternEdge::operator!=(InputPatternEdge const &other) const { + return std::tie(this->raw_edge) != std::tie(other.raw_edge); +} +bool InputPatternEdge::operator<(InputPatternEdge const &other) const { + return std::tie(this->raw_edge) < std::tie(other.raw_edge); +} +bool InputPatternEdge::operator>(InputPatternEdge const &other) const { + return std::tie(this->raw_edge) > std::tie(other.raw_edge); +} +bool InputPatternEdge::operator<=(InputPatternEdge const &other) const { + return std::tie(this->raw_edge) <= std::tie(other.raw_edge); +} +bool InputPatternEdge::operator>=(InputPatternEdge const &other) const { + return std::tie(this->raw_edge) >= std::tie(other.raw_edge); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::InputPatternEdge const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::InputMultiDiEdge>{}(x.raw_edge) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std diff --git a/lib/substitutions/src/substitutions/unlabelled/match_additional_criterion.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/match_additional_criterion.dtg.cc new file mode 100644 index 0000000000..613159ad83 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/match_additional_criterion.dtg.cc @@ -0,0 +1,25 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.struct.toml +/* proj-data +{ + "generated_from": "2dff356c85dccda1fce8f714d41c6202" +} +*/ + +#include "substitutions/unlabelled/match_additional_criterion.dtg.h" + +#include "substitutions/unlabelled/pattern_edge.dtg.h" +#include "substitutions/unlabelled/pattern_node.dtg.h" +#include "utils/graph.h" +#include + +namespace FlexFlow { +MatchAdditionalCriterion::MatchAdditionalCriterion( + std::function const &node_criterion, + std::function const + &edge_criterion) + : node_criterion(node_criterion), edge_criterion(edge_criterion) {} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/match_split.cc b/lib/substitutions/src/substitutions/unlabelled/match_split.cc new file mode 100644 index 0000000000..ef0397d6a8 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/match_split.cc @@ -0,0 +1,69 @@ +#include "substitutions/unlabelled/match_split.h" +#include "substitutions/unlabelled/edge_splits.h" +#include "substitutions/unlabelled/multidigraph_pattern_match.h" +#include "substitutions/unlabelled/pattern_edge.h" +#include "substitutions/unlabelled/pattern_split.h" + +namespace FlexFlow { + +MatchSplit empty_match_split() { + return MatchSplit{empty_multidigraph_pattern_match(), + empty_multidigraph_pattern_match()}; +} + +MatchSplit apply_split(UnlabelledGraphPattern const &pattern, + MultiDiGraphPatternMatch const &match, + PatternSplit const &split) { + std::unordered_set prefix = split.first; + std::unordered_set postfix = split.second; + + MatchSplit result = empty_match_split(); + + for (auto const &[pattern_node, match_node] : match.node_assignment) { + if (contains(split.first, pattern_node)) { + result.prefix_submatch.node_assignment.equate(pattern_node, match_node); + } else { + assert(contains(split.second, pattern_node)); + result.postfix_submatch.node_assignment.equate(pattern_node, match_node); + } + } + + UnlabelledPatternEdgeSplits edge_splits = get_edge_splits(pattern, split); + + std::function + handle_edge = [&](PatternEdge const &pattern_edge, + OpenMultiDiEdge const &graph_edge) -> void { + std::unordered_set edge_nodes = get_nodes(pattern_edge); + + if (is_subseteq_of(edge_nodes, prefix)) { + result.prefix_submatch.edge_assignment.equate(pattern_edge, graph_edge); + } else if (is_subseteq_of(edge_nodes, postfix)) { + result.postfix_submatch.edge_assignment.equate(pattern_edge, graph_edge); + } else { + assert(is_standard_edge(graph_edge)); + + ClosedPatternEdge closed_edge = require_closed_edge(pattern_edge); + + auto split = get_split_edges(edge_splits, closed_edge); + OutputPatternEdge output_edge = split.first; + InputPatternEdge input_edge = split.second; + + auto split_graph_edge = split_edge(std::get(graph_edge)); + OutputMultiDiEdge output_graph_edge = split_graph_edge.first; + InputMultiDiEdge input_graph_edge = split_graph_edge.second; + + handle_edge(pattern_edge_from_input_edge(input_edge), + OpenMultiDiEdge{input_graph_edge}); + handle_edge(pattern_edge_from_output_edge(output_edge), + OpenMultiDiEdge{output_graph_edge}); + } + }; + + for (auto const &[pattern_edge, match_edge] : match.edge_assignment) { + handle_edge(pattern_edge, match_edge); + } + + return result; +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/match_split.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/match_split.dtg.cc new file mode 100644 index 0000000000..ffbdf96912 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/match_split.dtg.cc @@ -0,0 +1,26 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/match_split.struct.toml +/* proj-data +{ + "generated_from": "e44c4347e07263a493cbbd5caccedd22" +} +*/ + +#include "substitutions/unlabelled/match_split.dtg.h" + +#include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h" + +namespace FlexFlow { +MatchSplit::MatchSplit(MultiDiGraphPatternMatch const &prefix_submatch, + MultiDiGraphPatternMatch const &postfix_submatch) + : prefix_submatch(prefix_submatch), postfix_submatch(postfix_submatch) {} +bool MatchSplit::operator==(MatchSplit const &other) const { + return std::tie(this->prefix_submatch, this->postfix_submatch) == + std::tie(other.prefix_submatch, other.postfix_submatch); +} +bool MatchSplit::operator!=(MatchSplit const &other) const { + return std::tie(this->prefix_submatch, this->postfix_submatch) != + std::tie(other.prefix_submatch, other.postfix_submatch); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.cc b/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.cc new file mode 100644 index 0000000000..8f4fd7f535 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.cc @@ -0,0 +1,56 @@ +#include "substitutions/unlabelled/multidigraph_pattern_match.h" +#include "substitutions/unlabelled/edge_splits.h" +#include "substitutions/unlabelled/pattern_edge.h" +#include "utils/containers.h" + +namespace FlexFlow { + +MultiDiGraphPatternMatch empty_multidigraph_pattern_match() { + return MultiDiGraphPatternMatch{ + bidict{}, + bidict{}, + }; +} + +std::optional + unsplit_matches(MultiDiGraphPatternMatch const &prefix, + MultiDiGraphPatternMatch const &postfix, + UnlabelledPatternEdgeSplits const &edge_splits) { + + MultiDiGraphPatternMatch result = empty_multidigraph_pattern_match(); + + std::unordered_set handled; + for (auto const &coi : as_closed_output_input_tuples(edge_splits)) { + ClosedPatternEdge closed_edge = std::get(coi); + OutputPatternEdge output_edge = std::get(coi); + InputPatternEdge input_edge = std::get(coi); + + handled.insert(pattern_edge_from_output_edge(output_edge)); + handled.insert(pattern_edge_from_input_edge(input_edge)); + + OpenMultiDiEdge output_graph_edge = + prefix.edge_assignment.at_l(pattern_edge_from_output_edge(output_edge)); + OpenMultiDiEdge input_graph_edge = + postfix.edge_assignment.at_l(pattern_edge_from_input_edge(input_edge)); + if (output_graph_edge == input_graph_edge) { + result.edge_assignment.equate(pattern_edge_from_closed_edge(closed_edge), + output_graph_edge); + } else { + return std::nullopt; + } + } + + for (auto const &kv : + merge_maps(prefix.edge_assignment, postfix.edge_assignment)) { + if (!contains(handled, kv.first)) { + result.edge_assignment.equate(kv.first, kv.second); + } + } + + result.node_assignment = + merge_maps(prefix.node_assignment, postfix.node_assignment); + + return result; +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.dtg.cc new file mode 100644 index 0000000000..9fc2169dd7 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.dtg.cc @@ -0,0 +1,34 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.struct.toml +/* proj-data +{ + "generated_from": "9842661a5d4e7d717f12d2c27da7df0d" +} +*/ + +#include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h" + +#include "substitutions/unlabelled/pattern_edge.dtg.h" +#include "substitutions/unlabelled/pattern_node.dtg.h" +#include "utils/bidict.h" +#include "utils/graph.h" + +namespace FlexFlow { +MultiDiGraphPatternMatch::MultiDiGraphPatternMatch( + ::FlexFlow::bidict<::FlexFlow::PatternNode, ::FlexFlow::Node> const + &node_assignment, + ::FlexFlow::bidict<::FlexFlow::PatternEdge, + ::FlexFlow::OpenMultiDiEdge> const &edge_assignment) + : node_assignment(node_assignment), edge_assignment(edge_assignment) {} +bool MultiDiGraphPatternMatch::operator==( + MultiDiGraphPatternMatch const &other) const { + return std::tie(this->node_assignment, this->edge_assignment) == + std::tie(other.node_assignment, other.edge_assignment); +} +bool MultiDiGraphPatternMatch::operator!=( + MultiDiGraphPatternMatch const &other) const { + return std::tie(this->node_assignment, this->edge_assignment) != + std::tie(other.node_assignment, other.edge_assignment); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/output_pattern_edge.cc b/lib/substitutions/src/substitutions/unlabelled/output_pattern_edge.cc new file mode 100644 index 0000000000..6e70fc8df6 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/output_pattern_edge.cc @@ -0,0 +1,9 @@ +#include "substitutions/unlabelled/output_pattern_edge.h" + +namespace FlexFlow { + +PatternNode get_src_node(OutputPatternEdge const &e) { + return PatternNode{e.raw_edge.src}; +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/output_pattern_edge.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/output_pattern_edge.dtg.cc new file mode 100644 index 0000000000..fb9de06135 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/output_pattern_edge.dtg.cc @@ -0,0 +1,46 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.struct.toml +/* proj-data +{ + "generated_from": "3222696e351c3e203e008714245c737f" +} +*/ + +#include "substitutions/unlabelled/output_pattern_edge.dtg.h" + +#include "utils/graph.h" + +namespace FlexFlow { +OutputPatternEdge::OutputPatternEdge( + ::FlexFlow::OutputMultiDiEdge const &raw_edge) + : raw_edge(raw_edge) {} +bool OutputPatternEdge::operator==(OutputPatternEdge const &other) const { + return std::tie(this->raw_edge) == std::tie(other.raw_edge); +} +bool OutputPatternEdge::operator!=(OutputPatternEdge const &other) const { + return std::tie(this->raw_edge) != std::tie(other.raw_edge); +} +bool OutputPatternEdge::operator<(OutputPatternEdge const &other) const { + return std::tie(this->raw_edge) < std::tie(other.raw_edge); +} +bool OutputPatternEdge::operator>(OutputPatternEdge const &other) const { + return std::tie(this->raw_edge) > std::tie(other.raw_edge); +} +bool OutputPatternEdge::operator<=(OutputPatternEdge const &other) const { + return std::tie(this->raw_edge) <= std::tie(other.raw_edge); +} +bool OutputPatternEdge::operator>=(OutputPatternEdge const &other) const { + return std::tie(this->raw_edge) >= std::tie(other.raw_edge); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::OutputPatternEdge const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::OutputMultiDiEdge>{}(x.raw_edge) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_edge.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_edge.cc new file mode 100644 index 0000000000..3dd4987705 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_edge.cc @@ -0,0 +1,50 @@ +#include "substitutions/unlabelled/pattern_edge.h" +#include "utils/containers.h" + +namespace FlexFlow { + +std::unordered_set get_nodes(PatternEdge const &e) { + return transform(get_nodes(e.raw_edge), + [](Node const &n) { return PatternNode{n}; }); +} + +bool is_standard_edge(PatternEdge const &e) { + return is_standard_edge(e.raw_edge); +} + +bool is_input_edge(PatternEdge const &e) { + return is_input_edge(e.raw_edge); +} + +bool is_output_edge(PatternEdge const &e) { + return is_output_edge(e.raw_edge); +} + +ClosedPatternEdge require_closed_edge(PatternEdge const &e) { + assert(is_closed_edge(e)); + return ClosedPatternEdge{std::get(e.raw_edge)}; +} + +InputPatternEdge require_input_edge(PatternEdge const &e) { + assert(is_input_edge(e)); + return InputPatternEdge{std::get(e.raw_edge)}; +} + +OutputPatternEdge require_output_edge(PatternEdge const &e) { + assert(is_output_edge(e)); + return OutputPatternEdge{std::get(e.raw_edge)}; +} + +PatternEdge pattern_edge_from_input_edge(InputPatternEdge const &e) { + return PatternEdge{OpenMultiDiEdge{e.raw_edge}}; +} + +PatternEdge pattern_edge_from_output_edge(OutputPatternEdge const &e) { + return PatternEdge{OpenMultiDiEdge{e.raw_edge}}; +} + +PatternEdge pattern_edge_from_closed_edge(ClosedPatternEdge const &e) { + return PatternEdge{OpenMultiDiEdge{e.raw_edge}}; +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_edge.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_edge.dtg.cc new file mode 100644 index 0000000000..e4d11d0d7e --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_edge.dtg.cc @@ -0,0 +1,45 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/pattern_edge.struct.toml +/* proj-data +{ + "generated_from": "a3eff166b0c8be2ddf3f7305eec094fd" +} +*/ + +#include "substitutions/unlabelled/pattern_edge.dtg.h" + +#include "utils/graph.h" + +namespace FlexFlow { +PatternEdge::PatternEdge(::FlexFlow::OpenMultiDiEdge const &raw_edge) + : raw_edge(raw_edge) {} +bool PatternEdge::operator==(PatternEdge const &other) const { + return std::tie(this->raw_edge) == std::tie(other.raw_edge); +} +bool PatternEdge::operator!=(PatternEdge const &other) const { + return std::tie(this->raw_edge) != std::tie(other.raw_edge); +} +bool PatternEdge::operator<(PatternEdge const &other) const { + return std::tie(this->raw_edge) < std::tie(other.raw_edge); +} +bool PatternEdge::operator>(PatternEdge const &other) const { + return std::tie(this->raw_edge) > std::tie(other.raw_edge); +} +bool PatternEdge::operator<=(PatternEdge const &other) const { + return std::tie(this->raw_edge) <= std::tie(other.raw_edge); +} +bool PatternEdge::operator>=(PatternEdge const &other) const { + return std::tie(this->raw_edge) >= std::tie(other.raw_edge); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::PatternEdge const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::OpenMultiDiEdge>{}(x.raw_edge) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc new file mode 100644 index 0000000000..335b9664ea --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc @@ -0,0 +1,74 @@ +#include "substitutions/unlabelled/pattern_matching.h" +#include "substitutions/unlabelled/input_pattern_edge.h" +#include "substitutions/unlabelled/match_split.h" +#include "substitutions/unlabelled/output_pattern_edge.h" +#include "substitutions/unlabelled/pattern_edge.h" +#include "substitutions/unlabelled/pattern_split.h" +#include "substitutions/unlabelled/unlabelled_graph_pattern.h" +#include + +namespace FlexFlow { + +bool unlabelled_pattern_does_match( + UnlabelledGraphPattern const &pattern, + OpenMultiDiGraphView const &graph, + MultiDiGraphPatternMatch const &match, + MatchAdditionalCriterion const &additional_criterion) { + if (is_singleton_pattern(pattern)) { + PatternNode pattern_node = get_only(get_nodes(pattern)); + Node matched_node = match.node_assignment.at_l(pattern_node); + if (!additional_criterion.node_criterion(pattern_node, matched_node)) { + return false; + } + for (PatternEdge const &e : get_edges(pattern)) { + OpenMultiDiEdge matched_edge = match.edge_assignment.at_l(e); + + assert(is_input_edge(e) || is_output_edge(e)); + if (is_input_edge(e)) { + if (is_output_edge(matched_edge)) { + return false; + } + UpwardOpenMultiDiEdge matched_edge = + narrow(matched_edge).value(); + InputPatternEdge input_edge = require_input_edge(e); + if (match.node_assignment.at_l(get_dst_node(input_edge)) != + get_dst_node(matched_edge)) { + return false; + } + } else { + if (is_input_edge(matched_edge)) { + return false; + } + DownwardOpenMultiDiEdge matched_edge = + narrow(matched_edge).value(); + OutputPatternEdge output_edge = require_output_edge(e); + if (match.node_assignment.at_l(get_src_node(output_edge)) != + get_src_node(matched_edge)) { + return false; + } + } + + if (!additional_criterion.edge_criterion(e, matched_edge)) { + return false; + } + } + + return true; + } + + PatternSplit split = find_even_split(pattern); + std::pair subpatterns = + apply_split(pattern, split); + auto submatches = apply_split(pattern, match, split); + + return unlabelled_pattern_does_match(subpatterns.first, + graph, + submatches.prefix_submatch, + additional_criterion) && + unlabelled_pattern_does_match(subpatterns.second, + graph, + submatches.postfix_submatch, + additional_criterion); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_node.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_node.dtg.cc new file mode 100644 index 0000000000..6ea64de69e --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_node.dtg.cc @@ -0,0 +1,45 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/pattern_node.struct.toml +/* proj-data +{ + "generated_from": "a0e58ade010a9b250d2c1c378fde2639" +} +*/ + +#include "substitutions/unlabelled/pattern_node.dtg.h" + +#include "utils/graph.h" + +namespace FlexFlow { +PatternNode::PatternNode(::FlexFlow::Node const &raw_node) + : raw_node(raw_node) {} +bool PatternNode::operator==(PatternNode const &other) const { + return std::tie(this->raw_node) == std::tie(other.raw_node); +} +bool PatternNode::operator!=(PatternNode const &other) const { + return std::tie(this->raw_node) != std::tie(other.raw_node); +} +bool PatternNode::operator<(PatternNode const &other) const { + return std::tie(this->raw_node) < std::tie(other.raw_node); +} +bool PatternNode::operator>(PatternNode const &other) const { + return std::tie(this->raw_node) > std::tie(other.raw_node); +} +bool PatternNode::operator<=(PatternNode const &other) const { + return std::tie(this->raw_node) <= std::tie(other.raw_node); +} +bool PatternNode::operator>=(PatternNode const &other) const { + return std::tie(this->raw_node) >= std::tie(other.raw_node); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::PatternNode const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::Node>{}(x.raw_node) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_split.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_split.cc new file mode 100644 index 0000000000..e116c062df --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_split.cc @@ -0,0 +1,42 @@ +#include "substitutions/unlabelled/pattern_split.h" + +namespace FlexFlow { + +PatternSplit find_even_split(UnlabelledGraphPattern const &p) { + std::vector topological_ordering = + get_topological_ordering(pattern.raw_graph); + assert(topological_ordering.size() >= 2); + + int split_point = topological_ordering.size() / 2; + auto split = vector_split(topological_ordering, split_point); + std::unordered_set prefix(split.first.begin(), + split.first.end()); + std::unordered_set postfix(split.second.begin(), + split.second.end()); + return {prefix, postfix}; +} + +GraphSplit get_raw_split(PatternSplit const &s) { + return std::pair{ + transform(s.first, [](PatternNode const &n) { return n.raw_node; }), + transform(s.second, [](PatternNode const &n) { return n.raw_node; }), + }; +} + +UnlabelledPatternEdgeSplits + get_edge_splits(UnlabelledGraphPattern const &pattern, + PatternSplit const &split) { + bidict> + raw_result = get_edge_splits(pattern.raw_graph, get_raw_split(split), ); + return UnlabelledPatternEdgeSplits{raw_result}; +} + +std::pair + apply_split(UnlabelledGraphPattern const &p, PatternSplit const &s) { + return { + get_subgraph(p, s.left); + get_subgraph(p, s.right); + }; +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_split.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_split.dtg.cc new file mode 100644 index 0000000000..bbcd4c3902 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_split.dtg.cc @@ -0,0 +1,60 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/pattern_split.struct.toml +/* proj-data +{ + "generated_from": "8604edb5bd1a546ffa94ef496888e46d" +} +*/ + +#include "substitutions/unlabelled/pattern_split.dtg.h" + +#include "substitutions/unlabelled/pattern_node.dtg.h" +#include "utils/graph.h" +#include +#include + +namespace FlexFlow { +PatternSplit::PatternSplit( + std::unordered_set<::FlexFlow::PatternNode> const &first, + std::unordered_set<::FlexFlow::PatternNode> const &second) + : first(first), second(second) {} +bool PatternSplit::operator==(PatternSplit const &other) const { + return std::tie(this->first, this->second) == + std::tie(other.first, other.second); +} +bool PatternSplit::operator!=(PatternSplit const &other) const { + return std::tie(this->first, this->second) != + std::tie(other.first, other.second); +} +} // namespace FlexFlow + +namespace nlohmann { +FlexFlow::PatternSplit + adl_serializer::from_json(json const &j) { + return { + j.at("first").template get>(), + j.at("second") + .template get>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::PatternSplit const &v) { + j["__type"] = "PatternSplit"; + j["first"] = v.first; + j["second"] = v.second; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(PatternSplit const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, PatternSplit const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc b/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc new file mode 100644 index 0000000000..df10507a04 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc @@ -0,0 +1,52 @@ +#include "substitutions/unlabelled/unlabelled_graph_pattern.h" +#include "utils/containers.h" + +namespace FlexFlow { + +size_t num_nodes(UnlabelledGraphPattern const &p) { + return num_nodes(p.raw_graph); +} + +bool is_singleton_pattern(UnlabelledGraphPattern const &pattern) { + return num_nodes(pattern) == 1; +} + +std::unordered_set get_nodes(UnlabelledGraphPattern const &p) { + return transform(get_nodes(p.raw_graph), + [](Node const &n) { + return PatternNode{n}; }}); +} + +std::unordered_set get_edges(UnlabelledGraphPattern const &p) { + return transform(get_nodes(p.raw_graph), + [](OpenMultiDiEdge const &e) { + return PatternEdge{e}; }}); +} + +std::vector get_topological_ordering(UnlabelledGraphPattern const &p) { + return transform(get_topological_ordering(p), + [](Node const &n) { + return PatternNode{n}; }}); +} + +UnlabelledGraphPattern get_subgraph(UnlabelledGraphPattern const &p, + std::unordered_set const &n) { + return { + get_subgraph(p.raw_graph, + transform(n, [](PatternNode const &n) { return n.raw_node; })); + }; +} + +std::unordered_set + get_incoming_edges(UnlabelledGraphPattern const &p, PatternNode const &n) { + return transform(get_incoming_edges(p.raw_graph, n.raw_node), + [](Node const &n) { return PatternNode{n}; }); +} + +std::unordered_set + get_outgoing_edges(UnlabelledGraphPattern const &p, PatternNode const &n) { + return transform(get_outgoing_edges(p.raw_graph, n.raw_node), + [](Node const &n) { return PatternNode{n}; }); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.dtg.cc new file mode 100644 index 0000000000..019209ee86 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.dtg.cc @@ -0,0 +1,18 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.struct.toml +/* proj-data +{ + "generated_from": "f494ed79eb1ba4010155e456b452157f" +} +*/ + +#include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" + +#include "utils/graph.h" + +namespace FlexFlow { +UnlabelledGraphPattern::UnlabelledGraphPattern( + ::FlexFlow::OpenMultiDiGraphView const &raw_graph) + : raw_graph(raw_graph) {} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/upward_open_pattern_edge.cc b/lib/substitutions/src/substitutions/unlabelled/upward_open_pattern_edge.cc new file mode 100644 index 0000000000..8664f3c66c --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/upward_open_pattern_edge.cc @@ -0,0 +1,9 @@ +#include "substitutions/unlabelled/upward_open_pattern_edge.h" + +namespace FlexFlow { + +int get_dst_idx(UpwardOpenPatternEdge const &e) { + return get_src_idx(e.raw_edge); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/upward_open_pattern_edge.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/upward_open_pattern_edge.dtg.cc new file mode 100644 index 0000000000..ca8dd6c020 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/upward_open_pattern_edge.dtg.cc @@ -0,0 +1,52 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.struct.toml +/* proj-data +{ + "generated_from": "a1d4c9d1dd94eb456c5e29d80ad579da" +} +*/ + +#include "substitutions/unlabelled/upward_open_pattern_edge.dtg.h" + +#include "utils/graph.h" + +namespace FlexFlow { +UpwardOpenPatternEdge::UpwardOpenPatternEdge( + ::FlexFlow::UpwardOpenMultiDiEdge const &raw_edge) + : raw_edge(raw_edge) {} +bool UpwardOpenPatternEdge::operator==( + UpwardOpenPatternEdge const &other) const { + return std::tie(this->raw_edge) == std::tie(other.raw_edge); +} +bool UpwardOpenPatternEdge::operator!=( + UpwardOpenPatternEdge const &other) const { + return std::tie(this->raw_edge) != std::tie(other.raw_edge); +} +bool UpwardOpenPatternEdge::operator<( + UpwardOpenPatternEdge const &other) const { + return std::tie(this->raw_edge) < std::tie(other.raw_edge); +} +bool UpwardOpenPatternEdge::operator>( + UpwardOpenPatternEdge const &other) const { + return std::tie(this->raw_edge) > std::tie(other.raw_edge); +} +bool UpwardOpenPatternEdge::operator<=( + UpwardOpenPatternEdge const &other) const { + return std::tie(this->raw_edge) <= std::tie(other.raw_edge); +} +bool UpwardOpenPatternEdge::operator>=( + UpwardOpenPatternEdge const &other) const { + return std::tie(this->raw_edge) >= std::tie(other.raw_edge); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::UpwardOpenPatternEdge const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::UpwardOpenMultiDiEdge>{}(x.raw_edge) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std diff --git a/lib/substitutions/test/src/test_substitution.cc b/lib/substitutions/test/src/test_substitution.cc index df22d8a620..2d9320275d 100644 --- a/lib/substitutions/test/src/test_substitution.cc +++ b/lib/substitutions/test/src/test_substitution.cc @@ -8,8 +8,10 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("apply_substitution") { OperatorPattern operator_pattern_n0{ - std::vector{OperatorAttributeConstraint{ - ConstraintType::EQUAL, OperatorAttributeKey::OP_TYPE, Op::LINEAR}}}; + std::vector{ + OperatorAttributeConstraint{ConstraintType::EQUAL, + OperatorAttributeKey::OP_TYPE, + OperatorType::LINEAR}}}; ParallelTensorPattern tensor_pattern_e0{ std::vector{ @@ -38,12 +40,13 @@ TEST_SUITE(FF_TEST_SUITE) { GraphPattern input_graph{ig}; OperatorAttrAssignment op_ass_n1{ - {{OperatorAttributeKey::OP_TYPE, AttrConstant{Op::REPARTITION}}, + {{OperatorAttributeKey::OP_TYPE, + AttrConstant{OperatorType::REPARTITION}}, {OperatorAttributeKey::PARALLEL_DIM, AttrConstant{ff_dim_t{0}}}, {OperatorAttributeKey::PARALLEL_DEGREE, AttrConstant{2}}}}; OperatorAttrAssignment op_ass_n2{ - {{OperatorAttributeKey::OP_TYPE, AttrConstant{Op::LINEAR}}, + {{OperatorAttributeKey::OP_TYPE, AttrConstant{OperatorType::LINEAR}}, {OperatorAttributeKey::OUT_CHANNELS, OperatorAttrAccess{n0, OperatorAttributeKey::OUT_CHANNELS}}, {OperatorAttributeKey::USE_BIAS, @@ -56,7 +59,7 @@ TEST_SUITE(FF_TEST_SUITE) { OperatorAttrAccess{n0, OperatorAttributeKey::REGULARIZER}}}}; OperatorAttrAssignment op_ass_n3{ - {{OperatorAttributeKey::OP_TYPE, AttrConstant{Op::REDUCTION}}, + {{OperatorAttributeKey::OP_TYPE, AttrConstant{OperatorType::REDUCTION}}, {OperatorAttributeKey::PARALLEL_DIM, AttrConstant{ff_dim_t{0}}}, {OperatorAttributeKey::PARALLEL_DEGREE, AttrConstant{2}}}}; @@ -100,10 +103,9 @@ TEST_SUITE(FF_TEST_SUITE) { MultiDiEdge e4{n5, p5, n4, p4}; pcg.add_edge(e4); - pcg.add_label(e4, - ParallelTensor(ParallelTensorDims({2, 1}), - DataType::FLOAT, - CreateGrad::YES)); + ParallelDim dim = {2, 1, false}; + ParallelTensorDims dims = {FFOrdered{dim}}; + pcg.add_label(e4, ParallelTensor(dims, DataType::FLOAT, CreateGrad::YES)); MatchAdditionalCriterion criterion{ [&](Node const &pattern_node, Node const &graph_node) { diff --git a/lib/utils/include/utils/bidict.h b/lib/utils/include/utils/bidict.h index 0869b0f9e8..6af18c2a4a 100644 --- a/lib/utils/include/utils/bidict.h +++ b/lib/utils/include/utils/bidict.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_UTILS_BIDICT_H #define _FLEXFLOW_UTILS_BIDICT_H +#include "utils/fmt/unordered_map.h" #include #include @@ -55,9 +56,22 @@ struct bidict { bwd_map.insert({lr.second, lr.first}); } + bool operator==(bidict const &other) const { + bool result = this->fwd_map == other.fwd_map; + assert(result == (this->bwd_map == other.bwd_map)); + return result; + } + + bool operator!=(bidict const &other) const { + bool result = this->fwd_map != other.fwd_map; + assert(result == (this->bwd_map != other.bwd_map)); + return result; + } + R const &at_l(L const &l) const { return fwd_map.at(l); } + L const &at_r(R const &r) const { return bwd_map.at(r); } @@ -163,6 +177,22 @@ struct bidict { std::unordered_map bwd_map; }; +template +std::unordered_map format_as(bidict const &b) { + return b; +} + } // namespace FlexFlow +namespace std { + +template +struct hash<::FlexFlow::bidict> { + size_t operator()(::FlexFlow::bidict const &b) const { + return hash>{}(b); + } +}; + +} // namespace std + #endif diff --git a/lib/utils/include/utils/check_fmtable.h b/lib/utils/include/utils/check_fmtable.h new file mode 100644 index 0000000000..3b4e55c459 --- /dev/null +++ b/lib/utils/include/utils/check_fmtable.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CHECK_FMTABLE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CHECK_FMTABLE_H + +#define CHECK_FMTABLE(...) \ + static_assert(::FlexFlow::is_fmtable<__VA_ARGS__>::value, \ + #__VA_ARGS__ " must be fmtable"); + +namespace FlexFlow { + +template +using is_fmtable = ::fmt::is_formattable; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers.decl.h b/lib/utils/include/utils/containers.decl.h index 0332a331b2..b02c95bf77 100644 --- a/lib/utils/include/utils/containers.decl.h +++ b/lib/utils/include/utils/containers.decl.h @@ -16,19 +16,6 @@ struct get_element_type; template using get_element_type_t = typename get_element_type::type; -template -std::string join_strings(InputIt first, - InputIt last, - std::string const &delimiter, - F const &f); - -template -std::string - join_strings(InputIt first, InputIt last, std::string const &delimiter); - -template -std::string join_strings(Container const &c, std::string const &delimiter); - template typename Container::const_iterator find(Container const &c, typename Container::value_type const &e); @@ -154,6 +141,9 @@ template > bidict generate_bidict(C const &c, F const &f); +template +std::optional at_idx(std::vector const &v, size_t idx); + template std::function lookup_in(std::unordered_map const &m); @@ -208,6 +198,9 @@ void extend(C &lhs, std::optional const &e); template bool all_of(C const &c, F const &f); +template +std::optional optional_all_of(Container const &, Function const &); + template int count(C const &c, F const &f); @@ -226,11 +219,6 @@ template auto transform(req const &c, F const &f) -> decltype(transform(std::declval(), std::declval())); -template ()(std::declval()))> -std::vector vector_transform(F const &f, std::vector const &v); - template ()(std::declval()))> diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index 1606eb0605..fbaf572df1 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -5,6 +5,8 @@ #include "containers.decl.h" #include "required_core.h" #include "type_traits_core.h" +#include "utils/containers/extend_vector.h" +#include "utils/containers/vector_transform.h" #include "utils/exception.h" #include "utils/type_traits.h" #include @@ -21,38 +23,6 @@ namespace FlexFlow { -template -std::string join_strings(InputIt first, - InputIt last, - std::string const &delimiter, - F const &f) { - std::ostringstream oss; - bool first_iter = true; - /* int i = 0; */ - for (; first != last; first++) { - if (!first_iter) { - oss << delimiter; - } - oss << *first; - /* break; */ - first_iter = false; - /* i++; */ - } - return oss.str(); -} - -template -std::string - join_strings(InputIt first, InputIt last, std::string const &delimiter) { - using Ref = typename InputIt::reference; - return join_strings(first, last, delimiter, [](Ref r) { return r; }); -} - -template -std::string join_strings(Container const &c, std::string const &delimiter) { - return join_strings(c.cbegin(), c.cend(), delimiter); -} - template typename Container::const_iterator find(Container const &c, typename Container::value_type const &e) { @@ -346,6 +316,15 @@ bidict generate_bidict(C const &c, F const &f) { return {transformed.cbegin(), transformed.cend()}; } +template +std::optional at_idx(std::vector const &v, size_t idx) { + if (idx >= v.size()) { + return std::nullopt; + } else { + return v.at(idx); + } +} + template std::function lookup_in(std::unordered_map const &m) { return [&m](K const &k) -> V { return m.at(k); }; @@ -441,6 +420,7 @@ T get_first(std::unordered_set const &s) { template void extend(std::vector &lhs, C const &rhs) { + extend_vector(lhs, rhs); lhs.reserve(lhs.size() + std::distance(rhs.begin(), rhs.end())); lhs.insert(lhs.end(), rhs.begin(), rhs.end()); } @@ -468,6 +448,22 @@ bool all_of(C const &c, F const &f) { return true; } +template +std::optional optional_all_of(Container const &container, + Function const &func) { + for (auto const &element : container) { + std::optional condition = func(element); + if (!condition.has_value()) { + return std::nullopt; + } + + if (!condition.value()) { + return false; + } + } + return true; +} + template int count(C const &c, F const &f) { int result = 0; @@ -509,11 +505,6 @@ auto transform(req const &c, F const &f) return transform(static_cast(c), f); } -template -std::vector vector_transform(F const &f, std::vector const &v) { - return transform(v, f); -} - template std::unordered_set transform(std::unordered_set const &v, F const &f) { std::unordered_set result; @@ -693,7 +684,7 @@ std::vector subvec(std::vector const &v, auto resolve_loc = [&](int idx) -> typename std::vector::iterator::difference_type { if (idx < 0) { - return v.size() - idx; + return v.size() + idx; } else { return idx; } diff --git a/lib/utils/include/utils/containers/concat_vectors.h b/lib/utils/include/utils/containers/concat_vectors.h new file mode 100644 index 0000000000..7940a37510 --- /dev/null +++ b/lib/utils/include/utils/containers/concat_vectors.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_CONCAT_VECTORS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_CONCAT_VECTORS_H + +#include "utils/containers/extend_vector.h" + +namespace FlexFlow { + +template +std::vector concat_vectors(std::vector const &prefix, + std::vector const &postfix) { + std::vector result = prefix; + extend_vector(result, postfix); + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/enumerate_vector.h b/lib/utils/include/utils/containers/enumerate_vector.h new file mode 100644 index 0000000000..8d36a5fe3b --- /dev/null +++ b/lib/utils/include/utils/containers/enumerate_vector.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ENUMERATE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ENUMERATE_H + +#include +#include + +namespace FlexFlow { + +template +std::vector> enumerate_vector(std::vector const &v) { + std::vector> result; + for (int i = 0; i < v.size(); i++) { + result.push_back({i, v.at(i)}); + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/extend_vector.h b/lib/utils/include/utils/containers/extend_vector.h new file mode 100644 index 0000000000..62ce94e49c --- /dev/null +++ b/lib/utils/include/utils/containers/extend_vector.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_EXTEND_VECTOR_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_EXTEND_VECTOR_H + +#include + +namespace FlexFlow { + +template +void extend_vector(std::vector &lhs, C const &rhs) { + lhs.reserve(lhs.size() + std::distance(rhs.begin(), rhs.end())); + lhs.insert(lhs.end(), rhs.begin(), rhs.end()); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/vector_transform.h b/lib/utils/include/utils/containers/vector_transform.h new file mode 100644 index 0000000000..6d13584775 --- /dev/null +++ b/lib/utils/include/utils/containers/vector_transform.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VECTOR_TRANSFORM_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VECTOR_TRANSFORM_H + +#include +#include + +namespace FlexFlow { + +template +std::vector> + vector_transform(std::vector const &v, F const &f) { + using Out = std::invoke_result_t; + + std::vector result; + std::transform(v.cbegin(), v.cend(), std::back_inserter(result), f); + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/zip_vectors.h b/lib/utils/include/utils/containers/zip_vectors.h new file mode 100644 index 0000000000..d32e539bef --- /dev/null +++ b/lib/utils/include/utils/containers/zip_vectors.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_VECTORS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_VECTORS_H + +#include +#include + +namespace FlexFlow { + +template +std::vector> zip(std::vector const &l, + std::vector const &r) { + std::vector> result; + for (int i = 0; i < std::min(l.size(), r.size()); i++) { + result.push_back(std::make_pair(l.at(i), r.at(i))); + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/exception.decl.h b/lib/utils/include/utils/exception.decl.h index d27174f474..93c450294b 100644 --- a/lib/utils/include/utils/exception.decl.h +++ b/lib/utils/include/utils/exception.decl.h @@ -3,20 +3,30 @@ #include "utils/fmt.decl.h" #include +#include namespace FlexFlow { #ifdef FF_REQUIRE_IMPLEMENTED -#define NOT_IMPLEMENTED() static_assert(false, "Function not yet implemented"); +#define NOT_IMPLEMENTED() \ + static_assert(false, \ + "Function " __FUNC__ " not yet implemented " __FILE__ \ + ":" __LINE__); #else -#define NOT_IMPLEMENTED() throw not_implemented(); +#define NOT_IMPLEMENTED() \ + throw not_implemented(__PRETTY_FUNCTION__, __FILE__, __LINE__); #endif class not_implemented : public std::logic_error { public: - not_implemented(); + not_implemented(std::string const &function_name, + std::string const &file_name, + int line); }; +template +T throw_if_unexpected(tl::expected const &r); + template std::runtime_error mk_runtime_error(fmt::format_string fmt_str, T &&...args); diff --git a/lib/utils/include/utils/exception.h b/lib/utils/include/utils/exception.h index fd3a0b7ee0..a00d2dba2b 100644 --- a/lib/utils/include/utils/exception.h +++ b/lib/utils/include/utils/exception.h @@ -4,9 +4,19 @@ #include "utils/exception.decl.h" #include "utils/fmt.h" #include +#include namespace FlexFlow { +template +T throw_if_unexpected(tl::expected const &r) { + if (r.has_value()) { + return r.value(); + } else { + throw std::runtime_error(fmt::to_string(r.error())); + } +} + template std::runtime_error mk_runtime_error(fmt::format_string fmt_str, T &&...args) { diff --git a/lib/utils/include/utils/fmt.decl.h b/lib/utils/include/utils/fmt.decl.h index 7adb2052ad..04902c8240 100644 --- a/lib/utils/include/utils/fmt.decl.h +++ b/lib/utils/include/utils/fmt.decl.h @@ -2,22 +2,18 @@ #define _FLEXFLOW_UTILS_INCLUDE_UTILS_FMT_DECL_H #include "fmt/format.h" +#include "utils/check_fmtable.h" #include +#include +#include #include -#define CHECK_FMTABLE(...) \ - static_assert(::FlexFlow::is_fmtable<__VA_ARGS__>::value, \ - #__VA_ARGS__ " must be fmtable"); - #define DELEGATE_OSTREAM(...) \ template <> \ struct delegate_ostream_operator<__VA_ARGS__> : std::true_type {} namespace FlexFlow { -template -using is_fmtable = ::fmt::is_formattable; - template struct delegate_ostream_operator : std::false_type {}; @@ -30,20 +26,38 @@ typename std::enable_if>::value, namespace fmt { -template -struct formatter<::std::unordered_set> : formatter<::std::string> { +template +struct formatter< + ::std::unordered_set, + Char, + std::enable_if_t>::value>> + : formatter<::std::string, Char> { template auto format(::std::unordered_set const &m, FormatContext &ctx) -> decltype(ctx.out()); }; -template -struct formatter<::std::vector> : formatter<::std::string> { +/* template */ +/* std::string format_as(::std::unordered_set const &); */ + +template +struct formatter< + ::std::vector, + Char, + std::enable_if_t>::value>> + : formatter<::std::string> { template auto format(::std::vector const &m, FormatContext &ctx) -> decltype(ctx.out()); }; +template +struct formatter<::std::variant> : formatter<::std::string> { + template + auto format(::std::variant const &m, FormatContext &ctx) + -> decltype(ctx.out()); +}; + } // namespace fmt #endif diff --git a/lib/utils/include/utils/fmt.h b/lib/utils/include/utils/fmt.h index 9cb56e4e2b..967a41f22b 100644 --- a/lib/utils/include/utils/fmt.h +++ b/lib/utils/include/utils/fmt.h @@ -1,16 +1,69 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_FMT_H #define _FLEXFLOW_UTILS_INCLUDE_FMT_H -#include "utils/containers.decl.h" +#include "utils/containers.h" #include "utils/fmt.decl.h" #include "utils/test_types.h" #include "utils/type_traits_core.h" +#include +#include #include +#include -#include +namespace fmt { + +template +template +auto formatter< + ::std::unordered_set, + Char, + std::enable_if_t>::value>>:: + format(::std::unordered_set const &m, FormatContext &ctx) + -> decltype(ctx.out()) { + /* CHECK_FMTABLE(T); */ + + /* std::string result = ::FlexFlow::join_strings( */ + /* m.cbegin(), m.cend(), ", ", [](T const &t) { return fmt::to_string(t); + * }); */ + std::string result = ""; + return formatter::format(result, ctx); +} + +template +template +auto formatter< + ::std::vector, + Char, + std::enable_if_t>::value>>:: + format(::std::vector const &m, FormatContext &ctx) + -> decltype(ctx.out()) { + CHECK_FMTABLE(T); + + std::string result = ::FlexFlow::join_strings( + m.cbegin(), m.cend(), ", ", [](T const &t) { return fmt::to_string(t); }); + return formatter::format("[" + result + "]", ctx); +} + +template +template +auto formatter<::std::variant>::format(::std::variant const &m, + FormatContext &ctx) + -> decltype(ctx.out()) { + + std::string result = + std::visit([](auto &&x) { return fmt::to_string(x); }, m); + return formatter::format(result, ctx); +} +} // namespace fmt namespace FlexFlow { +template +struct delegate_ostream_operator> : std::true_type {}; + +template +struct delegate_ostream_operator> : std::true_type {}; + template struct delegate_ostream_operator> : std::true_type {}; @@ -31,32 +84,4 @@ typename std::enable_if>::value, } // namespace FlexFlow -namespace fmt { - -template -template -auto formatter<::std::unordered_set>::format( - ::std::unordered_set const &m, FormatContext &ctx) - -> decltype(ctx.out()) { - CHECK_FMTABLE(T); - - std::string result = join_strings( - m.cbegin(), m.cend(), ", ", [](T const &t) { return fmt::to_string(t); }); - return formatter::format(result, ctx); -} - -template -template -auto formatter<::std::vector>::format(::std::vector const &m, - FormatContext &ctx) - -> decltype(ctx.out()) { - CHECK_FMTABLE(T); - - std::string result = join_strings( - m.cbegin(), m.cend(), ", ", [](T const &t) { return fmt::to_string(t); }); - return formatter::format(result, ctx); -} - -} // namespace fmt - #endif diff --git a/lib/utils/include/utils/fmt/expected.h b/lib/utils/include/utils/fmt/expected.h new file mode 100644 index 0000000000..5edd054ebe --- /dev/null +++ b/lib/utils/include/utils/fmt/expected.h @@ -0,0 +1,32 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_EXPECTED_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_EXPECTED_H + +#include "fmt/format.h" +#include "utils/check_fmtable.h" +#include +#include + +namespace fmt { + +template +struct formatter<::tl::expected, Char> + /* std::enable_if_t>::value>> */ + : formatter<::std::string> { + template + auto format(::tl::expected const &m, FormatContext &ctx) + -> decltype(ctx.out()) { + + std::string result; + if (m.has_value()) { + result = fmt::format("expected({})", m.value()); + } else { + result = fmt::format("unexpected({})", m.error()); + } + + return formatter::format(result, ctx); + } +}; + +} // namespace fmt + +#endif diff --git a/lib/utils/include/utils/fmt/pair.h b/lib/utils/include/utils/fmt/pair.h new file mode 100644 index 0000000000..eb1147ae3c --- /dev/null +++ b/lib/utils/include/utils/fmt/pair.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_PAIR_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_PAIR_H + +#include "fmt/format.h" +#include "utils/check_fmtable.h" +#include + +namespace FlexFlow { + +template +std::ostream &operator<<(std::ostream &s, std::pair const &m) { + CHECK_FMTABLE(L); + CHECK_FMTABLE(R); + + return s << fmt::to_string(m); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/fmt/unordered_map.h b/lib/utils/include/utils/fmt/unordered_map.h new file mode 100644 index 0000000000..19701bfb0c --- /dev/null +++ b/lib/utils/include/utils/fmt/unordered_map.h @@ -0,0 +1,45 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_UNORDERED_MAP_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_UNORDERED_MAP_H + +#include "fmt/format.h" +#include "utils/check_fmtable.h" +#include "utils/join_strings.h" +#include + +namespace fmt { + +template +struct formatter< + ::std::unordered_map, + Char, + std::enable_if_t>::value>> + : formatter<::std::string> { + template + auto format(::std::unordered_map const &m, FormatContext &ctx) + -> decltype(ctx.out()) { + /* CHECK_FMTABLE(K); */ + /* CHECK_FMTABLE(V); */ + + /* std::string result = ::FlexFlow::join_strings( */ + /* m.cbegin(), m.cend(), ", ", [](std::pair const &p) { return + * fmt::to_string(p); }); */ + std::string result = ""; + return formatter::format(result, ctx); + } +}; + +} // namespace fmt + +namespace FlexFlow { + +template +std::ostream &operator<<(std::ostream &s, std::unordered_map const &m) { + CHECK_FMTABLE(K); + CHECK_FMTABLE(V); + + return s << fmt::to_string(m); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/multidiedge.h b/lib/utils/include/utils/graph/multidiedge.h index 808981afa1..d7c2c1590b 100644 --- a/lib/utils/include/utils/graph/multidiedge.h +++ b/lib/utils/include/utils/graph/multidiedge.h @@ -17,6 +17,10 @@ FF_VISIT_FMTABLE(MultiDiInput); struct MultiDiOutput : DiOutput { NodePort src_idx; + + bool operator>(MultiDiOutput const &) const; + bool operator>=(MultiDiOutput const &) const; + bool operator<=(MultiDiOutput const &) const; }; FF_VISITABLE_STRUCT(MultiDiOutput, src, src_idx); FF_VISIT_FMTABLE(MultiDiOutput); diff --git a/lib/utils/include/utils/integer_conversions.h b/lib/utils/include/utils/integer_conversions.h new file mode 100644 index 0000000000..154aaa2a67 --- /dev/null +++ b/lib/utils/include/utils/integer_conversions.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_INTEGER_CONVERSIONS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_INTEGER_CONVERSIONS_H + +#include + +namespace FlexFlow { + +size_t size_t_from_int(int); +int int_from_size_t(size_t); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/join_strings.h b/lib/utils/include/utils/join_strings.h new file mode 100644 index 0000000000..db82004317 --- /dev/null +++ b/lib/utils/include/utils/join_strings.h @@ -0,0 +1,43 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JOIN_STRINGS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JOIN_STRINGS_H + +#include +#include + +namespace FlexFlow { + +template +std::string join_strings(InputIt first, + InputIt last, + std::string const &delimiter, + F const &f) { + std::ostringstream oss; + bool first_iter = true; + /* int i = 0; */ + for (; first != last; first++) { + if (!first_iter) { + oss << delimiter; + } + oss << *first; + /* break; */ + first_iter = false; + /* i++; */ + } + return oss.str(); +} + +template +std::string + join_strings(InputIt first, InputIt last, std::string const &delimiter) { + using Ref = typename InputIt::reference; + return join_strings(first, last, delimiter, [](Ref r) { return r; }); +} + +template +std::string join_strings(Container const &c, std::string const &delimiter) { + return join_strings(c.cbegin(), c.cend(), delimiter); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/json.h b/lib/utils/include/utils/json.h index 010943a9f9..f56917e329 100644 --- a/lib/utils/include/utils/json.h +++ b/lib/utils/include/utils/json.h @@ -143,40 +143,53 @@ struct VariantToJsonFunctor { void operator()(T const &t) { static_assert(is_jsonable::value, ""); - j["type"] = get_name(t); - j["value"] = t; + j = t; } }; template void variant_to_json(json &j, std::variant const &v) { - visit(::FlexFlow::VariantToJsonFunctor{j}, v.value); + json jval; + visit(::FlexFlow::VariantToJsonFunctor{jval}, v); + j["value"] = jval; + j["index"] = v.index(); } -template -struct VariantFromJsonFunctor { - VariantFromJsonFunctor(json const &j) : j(j) {} +template +std::optional variant_from_json_impl(json const &j) { + using Type = typename std::variant_alternative::type; - json const &j; - - template - std::optional - operator()(std::integral_constant const &) const { - using Type = typename std::variant_alternative::type; + if (j.at("index").get() == Idx) { + return j.at("value").get(); + } + return std::nullopt; +} - if (visit_struct::get_name()) { - return j.at("value").get(); +template +std::optional variant_from_json_impl(json const &j, + std::index_sequence) { + // If there were no errors when parsing, all but one element of the array + // will be nullopt. This is because each call to variant_from_json_impl will + // have a unique index and exactly one of them will match the index in the + // json object. + std::array, sizeof...(Is)> results{ + variant_from_json_impl(j)...}; + for (std::optional &maybe : results) { + if (maybe) { + return maybe.value(); } } -}; + return std::nullopt; +} template std::variant variant_from_json(json const &j) { - ::FlexFlow::VariantFromJsonFunctor> func(j); - auto result = seq_map(func, seq_enumerate_args_t{}); + using Variant = std::variant; + std::optional result = variant_from_json_impl( + j, std::make_index_sequence()); if (!result.has_value()) { throw ::FlexFlow::mk_runtime_error("Invalid type {} found in json", - j.at("type").get()); + j.at("index").get()); } return result.value(); } diff --git a/lib/utils/include/utils/optional.h b/lib/utils/include/utils/optional.h index 71b6d9d975..2594a96c8e 100644 --- a/lib/utils/include/utils/optional.h +++ b/lib/utils/include/utils/optional.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_UTILS_INCLUDE_UTILS_OPTIONAL_H #include "fmt.h" +#include "rapidcheck.h" #include "utils/exception.h" #include "utils/optional.decl" @@ -23,12 +24,28 @@ T const &assert_unwrap(std::optional const &o) { return o.value(); } +template +std::optional> transform(std::optional const &o, + F &&f) { + using Return = std::invoke_result_t; + if (o.has_value()) { + Return r = f(o.value()); + return std::optional{r}; + } else { + return std::nullopt; + } +} + } // namespace FlexFlow namespace fmt { -template -struct formatter<::std::optional> : formatter { +template +struct formatter< + ::std::optional, + Char, + std::enable_if_t>::value>> + : formatter { template auto format(::std::optional const &q, FormatContext &ctx) -> decltype(ctx.out()) { @@ -44,4 +61,18 @@ struct formatter<::std::optional> : formatter { } // namespace fmt +namespace rc { + +template +struct Arbitrary> { + static Gen> arbitrary() { + return gen::map( + gen::maybe(std::move(gen::arbitrary())), [](Maybe &&m) { + return m ? std::optional(std::move(*m)) : std::optional(); + }); + } +}; + +} // namespace rc + #endif diff --git a/lib/utils/include/utils/overload.h b/lib/utils/include/utils/overload.h new file mode 100644 index 0000000000..7e0431eba5 --- /dev/null +++ b/lib/utils/include/utils/overload.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_OVERLOAD_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_OVERLOAD_H + +namespace FlexFlow { + +template +struct overload : Ts... { + using Ts::operator()...; +}; +template +overload(Ts...) -> overload; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/stack_string.h b/lib/utils/include/utils/stack_string.h index 71b092d2c1..0074877768 100644 --- a/lib/utils/include/utils/stack_string.h +++ b/lib/utils/include/utils/stack_string.h @@ -4,6 +4,7 @@ #include "fmt/core.h" #include "stack_vector.h" #include "utils/fmt.h" +#include "utils/json.h" #include "utils/type_traits.h" #include #include @@ -64,6 +65,19 @@ struct stack_basic_string { template using stack_string = stack_basic_string; +template +void to_json(json &j, stack_string const &v) { + std::string as_string = v; + j = as_string; +} + +template +void from_json(json const &j, stack_string &v) { + std::string as_string; + j.get_to(as_string); + v = stack_string{as_string}; +} + } // namespace FlexFlow namespace std { diff --git a/lib/utils/include/utils/stack_vector.h b/lib/utils/include/utils/stack_vector.h index fe665ed749..d47886b055 100644 --- a/lib/utils/include/utils/stack_vector.h +++ b/lib/utils/include/utils/stack_vector.h @@ -3,7 +3,9 @@ #include "containers.h" #include "hash-utils.h" +#include "rapidcheck.h" #include "utils/fmt.h" +#include "utils/json.h" #include "utils/test_types.h" #include "utils/type_traits.h" #include @@ -38,9 +40,10 @@ struct stack_vector { template stack_vector(Iterator start, Iterator end) { - assert(end - start >= 0); - assert(end - start <= MAXSIZE); - for (; start < end; start++) { + size_t elements_added = 0; + for (; start != end; start++) { + elements_added++; + assert(elements_added <= MAXSIZE); this->push_back(static_cast(*start)); } } @@ -310,8 +313,24 @@ struct stack_vector { implies, is_lt_comparable>::value, ""); }; +template +struct delegate_ostream_operator> : std::true_type {}; + // CHECK_FMTABLE(stack_vector); +template +void to_json(json &j, stack_vector const &v) { + std::vector as_vec(v.begin(), v.end()); + j = as_vec; +} + +template +void from_json(json const &j, stack_vector &v) { + std::vector as_vec; + j.get_to(as_vec); + v = stack_vector{as_vec.begin(), as_vec.end()}; +} + } // namespace FlexFlow namespace std { @@ -329,4 +348,18 @@ struct hash<::FlexFlow::stack_vector> { } // namespace std +namespace rc { + +template +struct Arbitrary<::FlexFlow::stack_vector> { + static Gen<::FlexFlow::stack_vector> arbitrary() { + return gen::mapcat(gen::inRange(0, MAXSIZE), [](size_t size) { + return gen::container<::FlexFlow::stack_vector>( + size, gen::arbitrary()); + }); + } +}; + +} // namespace rc + #endif diff --git a/lib/utils/include/utils/variant.h b/lib/utils/include/utils/variant.h index 272caaffde..bb2286a9cd 100644 --- a/lib/utils/include/utils/variant.h +++ b/lib/utils/include/utils/variant.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_UTILS_VARIANT_H #define _FLEXFLOW_UTILS_VARIANT_H +#include "rapidcheck.h" #include "utils/type_traits.h" #include #include @@ -212,4 +213,15 @@ std::optional cast(VariantIn const &v) { } // namespace FlexFlow +namespace rc { + +template +struct Arbitrary> { + static Gen> arbitrary() { + return gen::oneOf(gen::cast>(gen::arbitrary())...); + } +}; + +} // namespace rc + #endif diff --git a/lib/utils/src/exception.cc b/lib/utils/src/exception.cc index 7dccdc3074..5f78491ef2 100644 --- a/lib/utils/src/exception.cc +++ b/lib/utils/src/exception.cc @@ -2,7 +2,12 @@ namespace FlexFlow { -not_implemented::not_implemented() - : std::logic_error("Function not yet implemented"){}; +not_implemented::not_implemented(std::string const &function_name, + std::string const &file_name, + int line) + : std::logic_error(fmt::format("Function '{}' not yet implemented at {}:{}", + function_name, + file_name, + line)){}; } diff --git a/lib/utils/src/utils/graph/multidiedge.cc b/lib/utils/src/utils/graph/multidiedge.cc new file mode 100644 index 0000000000..cd3655c8e6 --- /dev/null +++ b/lib/utils/src/utils/graph/multidiedge.cc @@ -0,0 +1,17 @@ +#include "utils/graph/multidiedge.h" + +namespace FlexFlow { + +bool MultiDiOutput::operator>(MultiDiOutput const &other) const { + return !(*this < other) && !(*this == other); +} + +bool MultiDiOutput::operator>=(MultiDiOutput const &other) const { + return !(*this < other); +} + +bool MultiDiOutput::operator<=(MultiDiOutput const &other) const { + return (*this < other) || (*this == other); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/integer_conversions.cc b/lib/utils/src/utils/integer_conversions.cc new file mode 100644 index 0000000000..34ee3109bf --- /dev/null +++ b/lib/utils/src/utils/integer_conversions.cc @@ -0,0 +1,17 @@ +#include "utils/integer_conversions.h" +#include +#include + +namespace FlexFlow { + +size_t size_t_from_int(int x) { + assert(x >= 0); + return static_cast(x); +} + +int int_from_size_t(size_t x) { + assert(x < std::numeric_limits::max()); + return static_cast(x); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/overload.cc b/lib/utils/src/utils/overload.cc new file mode 100644 index 0000000000..55bfbdc08d --- /dev/null +++ b/lib/utils/src/utils/overload.cc @@ -0,0 +1 @@ +#include "utils/overload.h" diff --git a/lib/utils/test/common/include/test/utils/doctest.h b/lib/utils/test/common/include/test/utils/doctest.h index 47c7ebde6d..ff7683dbcd 100644 --- a/lib/utils/test/common/include/test/utils/doctest.h +++ b/lib/utils/test/common/include/test/utils/doctest.h @@ -1,7 +1,9 @@ -#define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN #include "doctest/doctest.h" #include "utils/containers.decl.h" +#include "utils/fmt/expected.h" +#include #include +#include #include #include #include @@ -64,10 +66,11 @@ namespace doctest { // } // }; -// template -// struct StringMaker> { -// static String convert(std::vector const &vec) { -// return doctest_print_container(vec, "[ ", ", ", " ]").c_str(); -// } -// }; +template +struct StringMaker> { + static String convert(tl::expected const &m) { + return toString(fmt::to_string(m)); + } +}; + } // namespace doctest diff --git a/lib/utils/test/common/src/main.cc b/lib/utils/test/common/src/main.cc new file mode 100644 index 0000000000..9522fa7fdb --- /dev/null +++ b/lib/utils/test/common/src/main.cc @@ -0,0 +1,2 @@ +#define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN +#include "doctest/doctest.h" diff --git a/lib/utils/test/src/test_optional.cc b/lib/utils/test/src/test_optional.cc new file mode 100644 index 0000000000..8ef9e18f18 --- /dev/null +++ b/lib/utils/test/src/test_optional.cc @@ -0,0 +1,10 @@ +#include "test/utils/doctest.h" +#include "utils/optional.h" +#include + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE_TEMPLATE("RC arbitrary", T, int, double, char) { + CHECK(rc::check("generate", [](std::optional o) {})); + } +} diff --git a/lib/utils/test/src/test_stack_vector.cc b/lib/utils/test/src/test_stack_vector.cc index 6c0ecf36f3..141cd30e95 100644 --- a/lib/utils/test/src/test_stack_vector.cc +++ b/lib/utils/test/src/test_stack_vector.cc @@ -1,6 +1,7 @@ #include "test/utils/doctest.h" #include "utils/stack_vector.h" #include +#include using namespace FlexFlow; @@ -76,4 +77,11 @@ TEST_SUITE(FF_TEST_SUITE) { vector.push_back(20); CHECK(vector.back() == 20); } + + TEST_CASE_TEMPLATE("RC arbitrary", T, int, double, char) { + constexpr std::size_t MAXSIZE = 10; + CHECK(rc::check("within bound", [](stack_vector v) { + return v.size() <= MAXSIZE; + })); + } } diff --git a/lib/utils/test/src/test_variant.cc b/lib/utils/test/src/test_variant.cc index 0fef782c0e..7cffe9fbe4 100644 --- a/lib/utils/test/src/test_variant.cc +++ b/lib/utils/test/src/test_variant.cc @@ -1,5 +1,6 @@ #include "test/utils/doctest.h" #include "utils/variant.h" +#include TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("widen and narrow functions") { @@ -69,4 +70,10 @@ TEST_SUITE(FF_TEST_SUITE) { // Check the result CHECK(get(wider_variant) == 42); } + + TEST_CASE("RC arbitrary") { + CHECK(rc::check("valid type", [](std::variant v) { + return std::holds_alternative(v) || std::holds_alternative(v); + })); + } }