From 3786e0ed987b9f67a738f8fcb575c466acc000f2 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Thu, 4 Apr 2024 04:38:28 -0700 Subject: [PATCH 01/43] Add initial lib/substitution-generator and bin/substitutions-to-dot --- .github/workflows/per-lib-check.yml | 8 + CMakeLists.txt | 2 +- bin/CMakeLists.txt | 2 +- bin/substitutions-to-dot/CMakeLists.txt | 8 + .../substitution_to_dot.cc | 14 +- bin/substitutions_to_dot/CMakeLists.txt | 12 - cmake/flexflow-utils.cmake | 32 +++ lib/CMakeLists.txt | 1 + lib/op-attrs/include/op-attrs/op.h | 83 +++++++ lib/substitution-generator/CMakeLists.txt | 17 ++ .../include/substitution-generator/json.h | 209 ++++++++++++++++++ .../src/substitution-generator/json.cc | 73 ++++++ .../test/CMakeLists.txt | 20 ++ .../test/substitution-generator/json.cc | 35 +++ lib/utils/include/utils/fmt.decl.h | 14 +- lib/utils/include/utils/fmt.h | 69 ++---- lib/utils/include/utils/graph/query_set.h | 3 + lib/utils/include/utils/required_core.h | 3 + lib/utils/include/utils/strong_typedef.h | 1 + lib/utils/include/utils/visitable.h | 1 + 20 files changed, 528 insertions(+), 79 deletions(-) create mode 100644 bin/substitutions-to-dot/CMakeLists.txt rename bin/{substitutions_to_dot => substitutions-to-dot}/substitution_to_dot.cc (91%) delete mode 100644 bin/substitutions_to_dot/CMakeLists.txt create mode 100644 lib/substitution-generator/CMakeLists.txt create mode 100644 lib/substitution-generator/include/substitution-generator/json.h create mode 100644 lib/substitution-generator/src/substitution-generator/json.cc create mode 100644 lib/substitution-generator/test/CMakeLists.txt create mode 100644 lib/substitution-generator/test/substitution-generator/json.cc diff --git a/.github/workflows/per-lib-check.yml b/.github/workflows/per-lib-check.yml index 874a298587..a53a6afc11 100644 --- a/.github/workflows/per-lib-check.yml +++ b/.github/workflows/per-lib-check.yml @@ -80,6 +80,10 @@ jobs: run: | build_libs.sh compiler + - name: Build substitution-generator + run: | + build_libs.sh substitution-generator + - name: Test utils run: | test_libs.sh utils @@ -91,3 +95,7 @@ jobs: - name: Test compiler run: | test_libs.sh compiler + + - name: Test substitution-generator + run: | + test_libs.sh substitution-generator diff --git a/CMakeLists.txt b/CMakeLists.txt index d6f43f366c..3ae04341fb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -71,7 +71,7 @@ option(FF_BUILD_SPLIT_TEST_2 "build split test 2 example" OFF) option(FF_BUILD_ALL_EXAMPLES "build all examples. Overrides others" OFF) option(FF_BUILD_UNIT_TESTS "build non-operator unit tests" OFF) option(FF_BUILD_SUBSTITUTION_TOOL "build substitution conversion tool" OFF) -option(FF_BUILD_VISUALIZATION_TOOL "build substitution visualization tool" OFF) +option(FF_BUILD_VISUALIZATION_TOOL "build substitution visualization tool" ON) option(FF_BUILD_ARG_PARSER "build command line argument parser" OFF) set(FF_CUDA_ARCH "autodetect" CACHE STRING "Target CUDA Arch") diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt index 2d18e77620..fcc19b33b9 100644 --- a/bin/CMakeLists.txt +++ b/bin/CMakeLists.txt @@ -7,7 +7,7 @@ if(FF_BUILD_SUBSTITUTION_TOOL) endif() if(FF_BUILD_VISUALIZATION_TOOL) - add_subdirectory(substitutions_to_dot) + add_subdirectory(substitutions-to-dot) endif() if(FF_BUILD_ARG_PARSER) diff --git a/bin/substitutions-to-dot/CMakeLists.txt b/bin/substitutions-to-dot/CMakeLists.txt new file mode 100644 index 0000000000..ed9b017d52 --- /dev/null +++ b/bin/substitutions-to-dot/CMakeLists.txt @@ -0,0 +1,8 @@ +ff_add_executable( + NAME + substitution-to-dot + SRC_PATTERNS + *.cc + DEPS + substitution-generator +) diff --git a/bin/substitutions_to_dot/substitution_to_dot.cc b/bin/substitutions-to-dot/substitution_to_dot.cc similarity index 91% rename from bin/substitutions_to_dot/substitution_to_dot.cc rename to bin/substitutions-to-dot/substitution_to_dot.cc index a2ee8af815..6b48f140b2 100644 --- a/bin/substitutions_to_dot/substitution_to_dot.cc +++ b/bin/substitutions-to-dot/substitution_to_dot.cc @@ -1,12 +1,9 @@ -#include "ffc/substitution_loader.h" -#include "op-meta/ffconst.h" -#include "tl/optional.hpp" +#include "substitution-generator/json.h" #include "utils/dot_file.h" #include #include -using namespace FlexFlow::substitution_loader; -using FlexFlow::opmeta::get_operator_type_name; +using namespace FlexFlow; enum class NodeType { SRC, @@ -29,7 +26,7 @@ int main(int argc, char **argv) { RuleCollection rule_collection = load_rule_collection_from_path(json_path); - tl::optional found = tl::nullopt; + std::optional found = std::nullopt; for (Rule const &r : rule_collection.rules) { if (r.name == rule_name) { found = r; @@ -90,8 +87,7 @@ int main(int argc, char **argv) { { dot.add_node( srcOpNode, - label_map(FlexFlow::opmeta::get_operator_type_name(o.op_type), - srcOpNode)); + label_map(fmt::to_string(o.op_type), srcOpNode)); dot.add_node_to_subgraph(srcOpNode, src_body_subgraph); } @@ -117,7 +113,7 @@ int main(int argc, char **argv) { { dot.add_node( dstOpNode, - label_map(FlexFlow::opmeta::get_operator_type_name(o.op_type), + label_map(fmt::to_string(o.op_type), dstOpNode)); dot.add_node_to_subgraph(dstOpNode, dst_body_subgraph); } diff --git a/bin/substitutions_to_dot/CMakeLists.txt b/bin/substitutions_to_dot/CMakeLists.txt deleted file mode 100644 index 7cdc4627ef..0000000000 --- a/bin/substitutions_to_dot/CMakeLists.txt +++ /dev/null @@ -1,12 +0,0 @@ -cmake_minimum_required(VERSION 3.6) - -include(json) - -project(visualizeTool) -set(project_target visualize) - -add_executable(${project_target} substitution_to_dot.cc) -#include_directories(${CMAKE_CURRENT_BINARY_DIR}) -message("flexflow include dirs: ${FLEXFLOW_INCLUDE_DIRS}") -target_include_directories(${project_target} PRIVATE ${FLEXFLOW_INCLUDE_DIRS} ${CMAKE_INSTALL_INCLUDEDIR}) -target_link_libraries(${project_target} -Wl,--whole-archive ${FLEXFLOW_LIBS} -Wl,--no-whole-archive ${FLEXFLOW_EXT_LIBRARIES} nlohmann_json::nlohmann_json substitution_loader) diff --git a/cmake/flexflow-utils.cmake b/cmake/flexflow-utils.cmake index 4cf5450942..32798e6833 100644 --- a/cmake/flexflow-utils.cmake +++ b/cmake/flexflow-utils.cmake @@ -124,3 +124,35 @@ function(ff_add_test_executable) ff_set_cxx_properties(${FF_TEST_EXEC_NAME}) doctest_discover_tests(${FF_TEST_EXEC_NAME} ADD_LABELS 1) endfunction() + +function(ff_add_executable) + ff_parse_args( + PREFIX + FF_EXEC + ARGS + NAME + VARIADIC_ARGS + SRC_PATTERNS + PRIVATE_INCLUDE + DEPS + PARSE + ${ARGN} + ) + + project(${FF_EXEC_NAME}) + file(GLOB_RECURSE SRC + CONFIGURE_DEPENDS + LIST_DIRECTORIES False + ${FF_EXEC_SRC_PATTERNS}) + + add_executable( + ${FF_EXEC_NAME} + ${SRC}) + + target_link_libraries( + ${FF_EXEC_NAME} + ${FF_EXEC_DEPS}) + + define_ff_vars(${FF_EXEC_NAME}) + ff_set_cxx_properties(${FF_EXEC_NAME}) +endfunction() diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index f7c166f0dd..a73da48fac 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -6,3 +6,4 @@ add_subdirectory(kernels) add_subdirectory(utils) add_subdirectory(ffi) add_subdirectory(substitutions) +add_subdirectory(substitution-generator) diff --git a/lib/op-attrs/include/op-attrs/op.h b/lib/op-attrs/include/op-attrs/op.h index fee3deee56..9ad83c3641 100644 --- a/lib/op-attrs/include/op-attrs/op.h +++ b/lib/op-attrs/include/op-attrs/op.h @@ -111,171 +111,254 @@ struct formatter<::FlexFlow::Op> : formatter { switch (ot) { case Op::CONV2D: name = "Conv2D"; + break; case Op::DROPOUT: name = "Dropout"; + break; case Op::LINEAR: name = "Dense"; + break; case Op::BATCHMATMUL: name = "BatchMatMul"; + break; case Op::POOL2D: name = "Pool2D"; + break; case Op::SCALAR_MULTIPLY: name = "ScalarMultiply"; + break; case Op::SCALAR_ADD: name = "ScalarAdd"; + break; case Op::SCALAR_FLOOR_DIV: name = "ScalarFloorDiv"; + break; case Op::SCALAR_TRUE_DIV: name = "ScalarTrueDiv"; + break; case Op::SCALAR_SUB: name = "ScalarSub"; + break; case Op::RELU: name = "ReLU"; + break; case Op::SIGMOID: name = "Sigmoid"; + break; case Op::TANH: name = "Tanh"; + break; case Op::ELU: name = "Elu"; + break; case Op::FLAT: name = "Flat"; + break; case Op::SOFTMAX: name = "Softmax"; + break; case Op::BATCHNORM: name = "BatchNorm"; + break; case Op::CONCAT: name = "Concat"; + break; case Op::SPLIT: name = "Split"; + break; case Op::EMBEDDING: name = "Embedding"; + break; case Op::GATHER: name = "Gather"; + break; case Op::CACHE: name = "Cache"; + break; case Op::RESHAPE: name = "Reshape"; + break; case Op::REVERSE: name = "Reverse"; + break; case Op::TRANSPOSE: name = "Transpose"; + break; case Op::EW_ADD: name = "Add"; + break; case Op::EW_MUL: name = "Mul"; + break; case Op::MATMUL: name = "Matmul"; + break; case Op::MUL: name = "Mul"; + break; case Op::ENLARGE: name = "Enlarge"; + break; case Op::SQUEEZE: name = "Squeeze"; + break; case Op::UNSQUEEZE: name = "Unsqueeze"; + break; case Op::EW_SUB: name = "Sub"; + break; case Op::EW_DIV: name = "Div"; + break; case Op::EW_EQUAL: name = "Equal"; + break; case Op::EW_GREATER: name = "Greater"; + break; case Op::EW_LESS: name = "Less"; + break; case Op::EW_MAX: name = "Max"; + break; case Op::EW_MIN: name = "Min"; + break; case Op::REDUCE_ARGMAX: name = "ReduceArgMax"; + break; case Op::REDUCE_ARGMIN: name = "ReduceArgMin"; + break; case Op::REDUCE_MAX: name = "ReduceMax"; + break; case Op::REDUCE_MEAN: name = "ReduceMean"; + break; case Op::REDUCE_MIN: name = "ReduceMin"; + break; case Op::REDUCE_PROD: name = "ReduceProd"; + break; case Op::REDUCE_SUM: name = "ReduceSum"; + break; case Op::PAD: name = "Pad"; + break; case Op::SHAPE: name = "Shape"; + break; case Op::SIZE: name = "Size"; + break; case Op::TOPK: name = "TopK"; + break; case Op::WHERE: name = "Where"; + break; case Op::CEIL: name = "Ceil"; + break; case Op::CAST: name = "Cast"; + break; case Op::EXP: name = "Exp"; + break; case Op::SIN: name = "Sin"; + break; case Op::COS: name = "Cos"; + break; case Op::ROUND: name = "Round"; + break; case Op::LOG: name = "Log"; + break; case Op::LOGICAL_NOT: name = "LogicalNot"; + break; case Op::SQRT: name = "Sqrt"; + break; case Op::LEAKYRELU: name = "LeakyReLU"; + break; case Op::SLICE: name = "Slice"; + break; case Op::RESIZE: name = "Resize"; + break; case Op::PRELU: name = "PReLU"; + break; case Op::MULTIHEAD_ATTENTION: name = "MultiHeadAttention"; + break; case Op::INPUT: name = "Input"; + break; case Op::WEIGHT: name = "Weight"; + break; case Op::NOOP: name = "NoOp"; + break; case Op::FUSED: name = "FusedOp"; + break; case Op::RSQRT: name = "Rsqrt"; + break; case Op::POW: name = "Pow"; + break; case Op::MEAN: name = "Mean"; + break; case Op::LAYERNORM: name = "LayerNorm"; + break; case Op::IDENTITY: name = "Identity"; + break; // Parallel Ops case Op::REPARTITION: name = "Repartition"; + break; case Op::COMBINE: name = "Combine"; + break; case Op::REPLICATE: name = "Replicate"; + break; case Op::REDUCTION: name = "Reduction"; + break; case Op::PIPELINE: name = "Pipeline"; + break; case Op::FUSED_PARALLEL: name = "FusedParallelOp"; + break; case Op::GELU: name = "GeLU"; + break; case Op::BROADCAST: name = "Broadcast"; + break; case Op::BATCH: name = "Batch"; + break; } return formatter::format(name, ctx); } diff --git a/lib/substitution-generator/CMakeLists.txt b/lib/substitution-generator/CMakeLists.txt new file mode 100644 index 0000000000..41005e6a4e --- /dev/null +++ b/lib/substitution-generator/CMakeLists.txt @@ -0,0 +1,17 @@ +ff_add_library( + NAME + substitution-generator + SRC_PATTERNS + src/*.cc + PUBLIC_INCLUDE + include/ + PRIVATE_INCLUDE + src/ + DEPS + utils + op-attrs + pcg +) + +# add_subdirectory(ffi) +add_subdirectory(test) diff --git a/lib/substitution-generator/include/substitution-generator/json.h b/lib/substitution-generator/include/substitution-generator/json.h new file mode 100644 index 0000000000..7011345aae --- /dev/null +++ b/lib/substitution-generator/include/substitution-generator/json.h @@ -0,0 +1,209 @@ +#ifndef _FLEXFLOW_SUBSTITUTION_LOADER_H +#define _FLEXFLOW_SUBSTITUTION_LOADER_H + +#include +#include +#include "op-attrs/op.h" + +namespace FlexFlow { + +enum PMParameter { + PM_OP_TYPE, // AnyOp + PM_NUM_INPUTS, // AnyOp + PM_NUM_OUTPUTS, // AnyOp + PM_GROUP, // Conv2D + PM_KERNEL_H, // Conv2D, Pool2D + PM_KERNEL_W, // Conv2D, Pool2D + PM_STRIDE_H, // Conv2D, Pool2D + PM_STRIDE_W, // Conv2D, Pool2D + PM_PADDING_H, // Conv2D, Pool2D + PM_PADDING_W, // Conv2D, Pool2D + PM_ACTI, // Conv2D, Pool2D + PM_NUMDIM, // Concat, Transpose + PM_AXIS, // Concat, Split + PM_PERM, // Transpose + PM_OUTSHUFFLE, // Transpose + PM_MERGE_GCONV_COUNT, // MergeGConv + PM_AXES, // Squeeze, Unsqueeze, Reduce* + PM_KEEP_DIMS, // Reduce* + PM_EPSILON, // BatchNorm + PM_REPARTITION_DIM, // Repartition + PM_REPARTITION_DEGREE, // Repartition + PM_REPLICATE_DIM, // Replicate + PM_REPLICATE_DEGREE, // Replicate + PM_COMBINE_DIM, // Combine + PM_COMBINE_DEGREE, // Combine + PM_REDUCTION_DIM, // Reduction + PM_REDUCTION_DEGREE, // Reduction + PM_SOFTMAX_DIM, // Softmax + PM_NUM_HEADS, // MultiHeadAttention + PM_INVALID, + PM_PARALLEL_DIM, + PM_PARALLEL_DEGREE, + PM_PAD, +}; + +NLOHMANN_JSON_SERIALIZE_ENUM(PMParameter, + {{PM_INVALID, nullptr}, + {PM_OP_TYPE, "PM_OP_TYPE"}, + {PM_NUM_INPUTS, "PM_NUM_INPUTS"}, + {PM_NUM_OUTPUTS, "PM_NUM_OUTPUTS"}, + {PM_GROUP, "PM_GROUP"}, + {PM_KERNEL_H, "PM_KERNEL_H"}, + {PM_KERNEL_W, "PM_KERNEL_W"}, + {PM_STRIDE_H, "PM_STRIDE_H"}, + {PM_STRIDE_W, "PM_STRIDE_W"}, + {PM_PADDING_H, "PM_PADDING_H"}, + {PM_PADDING_W, "PM_PADDING_W"}, + {PM_ACTI, "PM_ACTI"}, + {PM_NUMDIM, "PM_NUMDIM"}, + {PM_AXIS, "PM_AXIS"}, + {PM_PERM, "PM_PERM"}, + {PM_OUTSHUFFLE, "PM_OUTSHUFFLE"}, + {PM_MERGE_GCONV_COUNT, "PM_MERGE_GCONV_COUNT"}, + {PM_AXES, "PM_AXES"}, + {PM_KEEP_DIMS, "PM_KEEP_DIMS"}, + {PM_EPSILON, "PM_EPSILON"}, + {PM_REPARTITION_DIM, "PM_REPARTITION_DIM"}, + {PM_REPARTITION_DEGREE, "PM_REPARTITION_DEGREE"}, + {PM_REPLICATE_DIM, "PM_REPLICATE_DIM"}, + {PM_REPLICATE_DEGREE, "PM_REPLICATE_DEGREE"}, + {PM_COMBINE_DIM, "PM_COMBINE_DIM"}, + {PM_COMBINE_DEGREE, "PM_COMBINE_DEGREE"}, + {PM_REDUCTION_DIM, "PM_REDUCTION_DIM"}, + {PM_REDUCTION_DEGREE, "PM_REDUCTION_DEGREE"}, + {PM_SOFTMAX_DIM, "PM_SOFTMAX_DIM"}, + {PM_NUM_HEADS, "PM_NUM_HEADS"}, + {PM_PARALLEL_DIM, "PM_PARALLEL_DIM"}, + {PM_PARALLEL_DEGREE, "PM_PARALLEL_DEGREE"}, + {PM_PAD, "PM_PAD"}}) + +NLOHMANN_JSON_SERIALIZE_ENUM(Op, + {{Op::NOOP, "OP_NOOP"}, + {Op::CONV2D, "OP_CONV2D"}, + {Op::DROPOUT, "OP_DROPOUT"}, + {Op::LINEAR, "OP_LINEAR"}, + {Op::BATCHMATMUL, "OP_BATCHMATMUL"}, + {Op::POOL2D, "OP_POOL2D_MAX"}, + {Op::SCALAR_MULTIPLY, "OP_SCALAR_MULTIPLY"}, + {Op::SCALAR_ADD, "OP_SCALAR_ADD"}, + {Op::SCALAR_FLOOR_DIV, "OP_SCALAR_FLOOR_DIV"}, + {Op::SCALAR_TRUE_DIV, "OP_SCALAR_TRUE_DIV"}, + {Op::SCALAR_SUB, "OP_SCALAR_SUB"}, + {Op::RELU, "OP_RELU"}, + {Op::IDENTITY, "OP_IDENTITY"}, + {Op::SIGMOID, "OP_SIGMOID"}, + {Op::TANH, "OP_TANH"}, + {Op::ELU, "OP_ELU"}, + {Op::FLAT, "OP_FLAT"}, + {Op::SOFTMAX, "OP_SOFTMAX"}, + {Op::BATCHNORM, "OP_BATCHNORM"}, + {Op::CONCAT, "OP_CONCAT"}, + {Op::SPLIT, "OP_SPLIT"}, + {Op::EMBEDDING, "OP_EMBEDDING"}, + {Op::CACHE, "OP_CACHE"}, + {Op::RESHAPE, "OP_RESHAPE"}, + {Op::REVERSE, "OP_REVERSE"}, + {Op::TRANSPOSE, "OP_TRANSPOSE"}, + {Op::EW_ADD, "OP_EW_ADD"}, + {Op::EW_MUL, "OP_EW_MUL"}, + {Op::MATMUL, "OP_MATMUL"}, + {Op::MUL, "OP_MUL"}, + {Op::ENLARGE, "OP_ENLARGE"}, + {Op::SQUEEZE, "OP_SQUEEZE"}, + {Op::UNSQUEEZE, "OP_UNSQUEEZE"}, + {Op::EW_SUB, "OP_EW_SUB"}, + {Op::EW_DIV, "OP_EW_DIV"}, + {Op::EW_EQUAL, "OP_EW_EQUAL"}, + {Op::EW_GREATER, "OP_EW_GREATER"}, + {Op::EW_LESS, "OP_EW_LESS"}, + {Op::EW_MAX, "OP_EW_MAX"}, + {Op::EW_MIN, "OP_EW_MIN"}, + {Op::REDUCE_ARGMAX, "OP_REDUCE_ARGMAX"}, + {Op::REDUCE_ARGMIN, "OP_REDUCE_ARGMIN"}, + {Op::REDUCE_MAX, "OP_REDUCE_MAX"}, + {Op::REDUCE_MEAN, "OP_REDUCE_MEAN"}, + {Op::REDUCE_MIN, "OP_REDUCE_MIN"}, + {Op::REDUCE_PROD, "OP_REDUCE_PROD"}, + {Op::REDUCE_SUM, "OP_REDUCE_SUM"}, + {Op::PAD, "OP_PAD"}, + {Op::SHAPE, "OP_SHAPE"}, + {Op::SIZE, "OP_SIZE"}, + {Op::TOPK, "OP_TOPK"}, + {Op::WHERE, "OP_WHERE"}, + {Op::CEIL, "OP_CEIL"}, + {Op::CAST, "OP_CAST"}, + {Op::EXP, "OP_EXP"}, + {Op::ROUND, "OP_ROUND"}, + {Op::LOG, "OP_LOG"}, + {Op::LOGICAL_NOT, "OP_LOGICAL_NOT"}, + {Op::SQRT, "OP_SQRT"}, + {Op::SIN, "OP_SIN"}, + {Op::COS, "OP_COS"}, + {Op::LEAKYRELU, "OP_LEAKYRELU"}, + {Op::SLICE, "OP_SLICE"}, + {Op::RESIZE, "OP_RESIZE"}, + {Op::PRELU, "OP_PRELU"}, + {Op::GELU, "OP_GELU"}, + {Op::MULTIHEAD_ATTENTION, + "OP_MULTIHEAD_ATTENTION"}, + {Op::FUSED, "OP_FUSED"}, + {Op::RSQRT, "OP_RSQRT"}, + {Op::POW, "OP_POW"}, + {Op::MEAN, "OP_MEAN"}, + {Op::LAYERNORM, "OP_LAYERNORM"}, + {Op::REPARTITION, "OP_PARTITION"}, + {Op::COMBINE, "OP_COMBINE"}, + {Op::REPLICATE, "OP_REPLICATE"}, + {Op::REDUCTION, "OP_REDUCE"}, + {Op::PIPELINE, "OP_PIPELINE"}, + {Op::FUSED_PARALLEL, "OP_FUSED_PARALLEL"}}) + +struct Parameter { + PMParameter key; + int value; +}; +void from_json(nlohmann::json const &j, Parameter &p); + +struct Tensor { + int opId; + int tsId; +}; +void from_json(nlohmann::json const &j, Tensor &t); + +struct Operator { + OperatorType op_type; + std::vector input; + std::vector para; + + std::optional at(PMParameter key) const; +}; +void from_json(nlohmann::json const &j, Operator &t); + +struct MapOutput { + int dstOpId; + int dstTsId; + int srcOpId; + int srcTsId; +}; +void from_json(nlohmann::json const &j, MapOutput &t); + +struct Rule { + std::string name; + std::vector srcOp; + std::vector dstOp; + std::vector mappedOutput; +}; +void from_json(nlohmann::json const &j, Rule &t); + +struct RuleCollection { + std::vector rules; +}; +void from_json(nlohmann::json const &j, RuleCollection &c); + +RuleCollection load_rule_collection(std::istream &s); +RuleCollection load_rule_collection_from_path(std::string const &path); + +} // namespace FlexFlow + +#endif // _FLEXFLOW_SUBSTITUTION_LOADER_H diff --git a/lib/substitution-generator/src/substitution-generator/json.cc b/lib/substitution-generator/src/substitution-generator/json.cc new file mode 100644 index 0000000000..7e6a93b863 --- /dev/null +++ b/lib/substitution-generator/src/substitution-generator/json.cc @@ -0,0 +1,73 @@ +#include "substitution-generator/json.h" +#include +#include +#include + +using json = nlohmann::json; + +namespace FlexFlow { + +void from_json(json const &j, Parameter &p) { + j.at("key").get_to(p.key); + j.at("value").get_to(p.value); + if (p.key == PM_INVALID) { + std::ostringstream oss; + oss << "Attempted to load invalid PMParameter: " << j.at("key"); + throw std::runtime_error(oss.str()); + } +} + +void from_json(json const &j, Tensor &t) { + j.at("opId").get_to(t.opId); + j.at("tsId").get_to(t.tsId); +} + +std::optional Operator::at(PMParameter key) const { + std::optional value = std::nullopt; + for (Parameter const &p : this->para) { + if (p.key == key) { + assert(!value.has_value()); + value = p.key; + } + } + + return value; +} + +void from_json(json const &j, Operator &o) { + j.at("type").get_to(o.op_type); + j.at("input").get_to(o.input); + j.at("para").get_to(o.para); +} + +void from_json(json const &j, MapOutput &m) { + j.at("dstOpId").get_to(m.dstOpId); + j.at("dstTsId").get_to(m.dstTsId); + j.at("srcOpId").get_to(m.srcOpId); + j.at("srcTsId").get_to(m.srcTsId); +} + +void from_json(json const &j, Rule &r) { + j.at("name").get_to(r.name); + j.at("srcOp").get_to(r.srcOp); + j.at("dstOp").get_to(r.dstOp); + j.at("mappedOutput").get_to(r.mappedOutput); +} + +void from_json(json const &j, RuleCollection &c) { + j.at("rule").get_to(c.rules); +} + +RuleCollection load_rule_collection(std::istream &s) { + json j; + s >> j; + RuleCollection rule_collection = j; + return rule_collection; +} + +RuleCollection load_rule_collection_from_path(std::string const &path) { + std::ifstream input(path); + return load_rule_collection(input); +} + +} // namespace FlexFlow diff --git a/lib/substitution-generator/test/CMakeLists.txt b/lib/substitution-generator/test/CMakeLists.txt new file mode 100644 index 0000000000..a7374cdf78 --- /dev/null +++ b/lib/substitution-generator/test/CMakeLists.txt @@ -0,0 +1,20 @@ +add_custom_target(copy-test-makefile ALL DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/graph_subst_3_v2.json) +add_custom_command(OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/graph_subst_3_v2.json + COMMAND ${CMAKE_COMMAND} -E copy ${PROJECT_SOURCE_DIR}/substitutions/graph_subst_3_v2.json + ${CMAKE_CURRENT_BINARY_DIR}/graph_subst_3_v2.json + DEPENDS ${PROJECT_SOURCE_DIR}/substitutions/graph_subst_3_v2.json) + +ff_add_test_executable( + NAME + substitution-generator-tests + SRC_PATTERNS + substitution-generator/*.cc + PRIVATE_INCLUDE + src/ + DEPS + utils + doctest + substitution-generator + utils-test-common +) +add_dependencies(substitution-generator-tests copy-test-makefile) diff --git a/lib/substitution-generator/test/substitution-generator/json.cc b/lib/substitution-generator/test/substitution-generator/json.cc new file mode 100644 index 0000000000..ed048da1e1 --- /dev/null +++ b/lib/substitution-generator/test/substitution-generator/json.cc @@ -0,0 +1,35 @@ +#include "doctest/doctest.h" +#include "substitution-generator/json.h" + +using namespace FlexFlow; +using nlohmann::json; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("operator json deserialization") { + json j = { + {"_t", "Operator"}, + {"input", + std::vector{{{"_t", "Tensor"}, {"opId", -2}, {"tsId", 0}}, + {{"_t", "Tensor"}, {"opId", -3}, {"tsId", 0}}}}, + {"para", std::vector{}}, + {"type", "OP_EW_ADD"}, + }; + + Operator o; + from_json(j, o); + + CHECK(o.op_type == Op::EW_ADD); + CHECK(o.input.size() == 2); + CHECK(o.input[0].opId == -2); + CHECK(o.input[0].tsId == 0); + CHECK(o.input[1].opId == -3); + CHECK(o.input[1].tsId == 0); + CHECK(o.para.size() == 0); + } + + TEST_CASE("deserialize full file") { + RuleCollection collection = + load_rule_collection_from_path("graph_subst_3_v2.json"); + CHECK(collection.rules.size() == 640); + } +} diff --git a/lib/utils/include/utils/fmt.decl.h b/lib/utils/include/utils/fmt.decl.h index 367a712b87..6b0260eb15 100644 --- a/lib/utils/include/utils/fmt.decl.h +++ b/lib/utils/include/utils/fmt.decl.h @@ -5,18 +5,26 @@ #include #include +#define CHECK_FMTABLE(...) \ + static_assert(::FlexFlow::is_fmtable<__VA_ARGS__>::value, \ + #__VA_ARGS__ " must be fmtable"); + +#define DELEGATE_OSTREAM(...) \ + template <> \ + struct delegate_ostream_operator<__VA_ARGS__> : std::true_type {} + namespace FlexFlow { template using is_fmtable = ::fmt::is_formattable; template -struct already_has_ostream_operator; +struct delegate_ostream_operator : std::false_type {}; template -typename std::enable_if::value, +typename std::enable_if>::value, std::ostream &>::type - operator<<(std::ostream &s, T const &t); + operator<<(std::ostream &s, T); } // namespace FlexFlow diff --git a/lib/utils/include/utils/fmt.h b/lib/utils/include/utils/fmt.h index 905b4622f1..9cb56e4e2b 100644 --- a/lib/utils/include/utils/fmt.h +++ b/lib/utils/include/utils/fmt.h @@ -5,64 +5,29 @@ #include "utils/fmt.decl.h" #include "utils/test_types.h" #include "utils/type_traits_core.h" +#include #include namespace FlexFlow { -template -struct already_has_ostream_operator : std::false_type {}; +template +struct delegate_ostream_operator> : std::true_type {}; -template <> -struct already_has_ostream_operator : std::true_type {}; - -template <> -struct already_has_ostream_operator : std::true_type {}; - -template <> -struct already_has_ostream_operator : std::true_type {}; - -template -struct already_has_ostream_operator : std::true_type {}; - -template <> -struct already_has_ostream_operator : std::true_type {}; - -template <> -struct already_has_ostream_operator> : std::true_type {}; - -template <> -struct already_has_ostream_operator : std::true_type {}; - -// This will create an error -/* template -std::ostream & -operator<<(std::ostream &s, T const &t) { - return s << "FlexFlow::ostream<<"; -} -*/ - -#define CHECK_FMTABLE(...) \ - static_assert(::FlexFlow::is_fmtable<__VA_ARGS__>::value, \ - #__VA_ARGS__ " must be fmtable"); +struct delegate_ostream_operator> : std::true_type {}; -// This will not -/* template */ -/* typename std::enable_if::value, */ -/* std::ostream &>::type */ -/* operator<<(std::ostream &s, T const &t) { */ -/* // CHECK_FMTABLE(T); */ +template +struct delegate_ostream_operator> : std::true_type {}; -/* std::string result = fmt::to_string(t); */ -/* return s << result; */ -/* } */ +template +typename std::enable_if>::value, + std::ostream &>::type + operator<<(std::ostream &s, T t) { + CHECK_FMTABLE(T); -// template -// typename std::enable_if::value, std::ostream &>::type -// operator<<(std::ostream &s, T const &t) { -// return s << fmt::to_string(t); -// } + return s << fmt::to_string(t); +} } // namespace FlexFlow @@ -73,7 +38,7 @@ template auto formatter<::std::unordered_set>::format( ::std::unordered_set const &m, FormatContext &ctx) -> decltype(ctx.out()) { - // CHECK_FMTABLE(T); + CHECK_FMTABLE(T); std::string result = join_strings( m.cbegin(), m.cend(), ", ", [](T const &t) { return fmt::to_string(t); }); @@ -85,15 +50,13 @@ template auto formatter<::std::vector>::format(::std::vector const &m, FormatContext &ctx) -> decltype(ctx.out()) { - // CHECK_FMTABLE(T); + CHECK_FMTABLE(T); + std::string result = join_strings( m.cbegin(), m.cend(), ", ", [](T const &t) { return fmt::to_string(t); }); return formatter::format(result, ctx); } -// CHECK_FMTABLE(std::vector); -// CHECK_FMTABLE(std::unordered_set); - } // namespace fmt #endif diff --git a/lib/utils/include/utils/graph/query_set.h b/lib/utils/include/utils/graph/query_set.h index c835afa6a6..dda06e997f 100644 --- a/lib/utils/include/utils/graph/query_set.h +++ b/lib/utils/include/utils/graph/query_set.h @@ -59,6 +59,9 @@ std::string format_as(query_set const &q) { } } +template +struct delegate_ostream_operator> : std::true_type {}; + template query_set matchall() { return query_set::matchall(); diff --git a/lib/utils/include/utils/required_core.h b/lib/utils/include/utils/required_core.h index 5677e84b7d..643315ff64 100644 --- a/lib/utils/include/utils/required_core.h +++ b/lib/utils/include/utils/required_core.h @@ -190,6 +190,9 @@ struct required::value>::type> template using req = required; +template +struct delegate_ostream_operator> : std::true_type {}; + template struct remove_req { using type = T; diff --git a/lib/utils/include/utils/strong_typedef.h b/lib/utils/include/utils/strong_typedef.h index f700a20c79..85f4ea742f 100644 --- a/lib/utils/include/utils/strong_typedef.h +++ b/lib/utils/include/utils/strong_typedef.h @@ -226,6 +226,7 @@ struct numerical_typedef : strong_typedef { } \ MAKE_TYPEDEF_PRINTABLE(::FlexFlow::TYPEDEF_NAME, TYPEDEF_SHORTNAME); \ namespace FlexFlow { \ + DELEGATE_OSTREAM(TYPEDEF_NAME); \ static_assert(true, ""); #endif diff --git a/lib/utils/include/utils/visitable.h b/lib/utils/include/utils/visitable.h index 6859b7a785..14f7679940 100644 --- a/lib/utils/include/utils/visitable.h +++ b/lib/utils/include/utils/visitable.h @@ -333,6 +333,7 @@ struct visitable_formatter : public ::fmt::formatter { : ::FlexFlow::visitable_formatter<::FlexFlow::TYPENAME> {}; \ } \ namespace FlexFlow { \ + DELEGATE_OSTREAM(::FlexFlow::TYPENAME); \ static_assert(is_fmtable::value, \ #TYPENAME \ " failed sanity check on is_fmtable and FF_VISIT_FMTABLE"); \ From 2890740abe09df8b05fb9bb84d33cd7977626fe8 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Thu, 4 Apr 2024 04:39:25 -0700 Subject: [PATCH 02/43] Format --- bin/substitutions-to-dot/substitution_to_dot.cc | 9 ++------- .../include/substitution-generator/json.h | 2 +- .../test/substitution-generator/json.cc | 2 +- lib/utils/include/utils/fmt.decl.h | 4 ++-- lib/utils/include/utils/strong_typedef.h | 2 +- lib/utils/include/utils/visitable.h | 2 +- 6 files changed, 8 insertions(+), 13 deletions(-) diff --git a/bin/substitutions-to-dot/substitution_to_dot.cc b/bin/substitutions-to-dot/substitution_to_dot.cc index 6b48f140b2..49a199ddd3 100644 --- a/bin/substitutions-to-dot/substitution_to_dot.cc +++ b/bin/substitutions-to-dot/substitution_to_dot.cc @@ -85,9 +85,7 @@ int main(int argc, char **argv) { Operator const &o = r.srcOp[i]; Node srcOpNode = {NodeType::SRC, i, 0}; { - dot.add_node( - srcOpNode, - label_map(fmt::to_string(o.op_type), srcOpNode)); + dot.add_node(srcOpNode, label_map(fmt::to_string(o.op_type), srcOpNode)); dot.add_node_to_subgraph(srcOpNode, src_body_subgraph); } @@ -111,10 +109,7 @@ int main(int argc, char **argv) { Operator const &o = r.dstOp[j]; Node dstOpNode = {NodeType::DST, j, 0}; { - dot.add_node( - dstOpNode, - label_map(fmt::to_string(o.op_type), - dstOpNode)); + dot.add_node(dstOpNode, label_map(fmt::to_string(o.op_type), dstOpNode)); dot.add_node_to_subgraph(dstOpNode, dst_body_subgraph); } diff --git a/lib/substitution-generator/include/substitution-generator/json.h b/lib/substitution-generator/include/substitution-generator/json.h index 7011345aae..dbde110f8d 100644 --- a/lib/substitution-generator/include/substitution-generator/json.h +++ b/lib/substitution-generator/include/substitution-generator/json.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_SUBSTITUTION_LOADER_H #define _FLEXFLOW_SUBSTITUTION_LOADER_H +#include "op-attrs/op.h" #include #include -#include "op-attrs/op.h" namespace FlexFlow { diff --git a/lib/substitution-generator/test/substitution-generator/json.cc b/lib/substitution-generator/test/substitution-generator/json.cc index ed048da1e1..d12b294a2e 100644 --- a/lib/substitution-generator/test/substitution-generator/json.cc +++ b/lib/substitution-generator/test/substitution-generator/json.cc @@ -1,5 +1,5 @@ -#include "doctest/doctest.h" #include "substitution-generator/json.h" +#include "doctest/doctest.h" using namespace FlexFlow; using nlohmann::json; diff --git a/lib/utils/include/utils/fmt.decl.h b/lib/utils/include/utils/fmt.decl.h index 6b0260eb15..7adb2052ad 100644 --- a/lib/utils/include/utils/fmt.decl.h +++ b/lib/utils/include/utils/fmt.decl.h @@ -9,8 +9,8 @@ static_assert(::FlexFlow::is_fmtable<__VA_ARGS__>::value, \ #__VA_ARGS__ " must be fmtable"); -#define DELEGATE_OSTREAM(...) \ - template <> \ +#define DELEGATE_OSTREAM(...) \ + template <> \ struct delegate_ostream_operator<__VA_ARGS__> : std::true_type {} namespace FlexFlow { diff --git a/lib/utils/include/utils/strong_typedef.h b/lib/utils/include/utils/strong_typedef.h index 85f4ea742f..a2e4a5f57b 100644 --- a/lib/utils/include/utils/strong_typedef.h +++ b/lib/utils/include/utils/strong_typedef.h @@ -226,7 +226,7 @@ struct numerical_typedef : strong_typedef { } \ MAKE_TYPEDEF_PRINTABLE(::FlexFlow::TYPEDEF_NAME, TYPEDEF_SHORTNAME); \ namespace FlexFlow { \ - DELEGATE_OSTREAM(TYPEDEF_NAME); \ + DELEGATE_OSTREAM(TYPEDEF_NAME); \ static_assert(true, ""); #endif diff --git a/lib/utils/include/utils/visitable.h b/lib/utils/include/utils/visitable.h index 14f7679940..ba9ceac053 100644 --- a/lib/utils/include/utils/visitable.h +++ b/lib/utils/include/utils/visitable.h @@ -333,7 +333,7 @@ struct visitable_formatter : public ::fmt::formatter { : ::FlexFlow::visitable_formatter<::FlexFlow::TYPENAME> {}; \ } \ namespace FlexFlow { \ - DELEGATE_OSTREAM(::FlexFlow::TYPENAME); \ + DELEGATE_OSTREAM(::FlexFlow::TYPENAME); \ static_assert(is_fmtable::value, \ #TYPENAME \ " failed sanity check on is_fmtable and FF_VISIT_FMTABLE"); \ From a342be7bcad0a279ddd38192eabd47b8b8474b8d Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Thu, 4 Apr 2024 02:43:48 -0700 Subject: [PATCH 03/43] Update proj version and add .proj.toml file to repo directly --- .proj.toml | 21 +++++++++++++++++++++ flake.lock | 6 +++--- flake.nix | 6 ++---- 3 files changed, 26 insertions(+), 7 deletions(-) create mode 100644 .proj.toml diff --git a/.proj.toml b/.proj.toml new file mode 100644 index 0000000000..a4592dcccc --- /dev/null +++ b/.proj.toml @@ -0,0 +1,21 @@ +project_name = "flexflow" +testsuite_macro = "FF_TEST_SUITE" +namespace_name = "FlexFlow" +header_extension = ".h" + +build_targets = [ + "utils", + "op-attrs", + "kernels", + "substitutions", + "compiler", +] +test_targets = [ + "utils-tests", + "substitutions-tests", + "compiler-tests", +] + +[cmake_flags_extra] +FF_CUDA_ARCH = "60" +CMAKE_CUDA_ARCHITECTURES = "60" diff --git a/flake.lock b/flake.lock index ea4187e13c..ffd4a02962 100644 --- a/flake.lock +++ b/flake.lock @@ -43,11 +43,11 @@ ] }, "locked": { - "lastModified": 1711832134, - "narHash": "sha256-2KceZmXOOELnFiVH/wjndH2QtKro+B0W2SEkjkzuDD0=", + "lastModified": 1712222904, + "narHash": "sha256-FRI/RdOTtmo9o7iwZiACD0lSSlgvKqcpppjliXUHyRU=", "owner": "lockshaw", "repo": "proj", - "rev": "1c7c809a6cab8360620bb27470a615a1a0b03a17", + "rev": "5b7a82dc01fa25076a8b3db96c1f2ea4752ae990", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index bd372e4cbf..c3e415cf89 100644 --- a/flake.nix +++ b/flake.nix @@ -16,7 +16,7 @@ inputs = { nixpkgs.url = "nixpkgs/nixos-23.11"; flake-utils.url = "github:numtide/flake-utils"; - + proj-repo = { url = "github:lockshaw/proj"; inputs.nixpkgs.follows = "nixpkgs"; @@ -60,7 +60,7 @@ devShells = rec { ci = mkShell { shellHook = '' - export PATH="$HOME/ff/.scripts/:$PATH" + export PATH="$HOME/ff/.scripts/:$HOME/ff/.modules/proj/bin/:$PATH" ''; CMAKE_FLAGS = lib.strings.concatStringsSep " " [ @@ -71,7 +71,6 @@ "-DFF_USE_EXTERNAL_SPDLOG=ON" "-DFF_USE_EXTERNAL_DOCTEST=ON" "-DFF_USE_EXTERNAL_RAPIDCHECK=ON" - "-DFF_USE_EXTERNAL_EXPECTED=ON" "-DFF_USE_EXTERNAL_RANGEV3=ON" "-DFF_USE_EXTERNAL_BOOST_PREPROCESSOR=ON" "-DFF_USE_EXTERNAL_TYPE_INDEX=ON" @@ -95,7 +94,6 @@ cudaPackages.nccl cudaPackages.libcublas cudaPackages.cuda_cudart - tl-expected ]) (with self.packages.${system}; [ legion From b54699ba070eb86dc2532f1ca9603331fd76d9aa Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Thu, 4 Apr 2024 02:45:15 -0700 Subject: [PATCH 04/43] Revert changes to flake.nix --- flake.nix | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/flake.nix b/flake.nix index c3e415cf89..bd372e4cbf 100644 --- a/flake.nix +++ b/flake.nix @@ -16,7 +16,7 @@ inputs = { nixpkgs.url = "nixpkgs/nixos-23.11"; flake-utils.url = "github:numtide/flake-utils"; - + proj-repo = { url = "github:lockshaw/proj"; inputs.nixpkgs.follows = "nixpkgs"; @@ -60,7 +60,7 @@ devShells = rec { ci = mkShell { shellHook = '' - export PATH="$HOME/ff/.scripts/:$HOME/ff/.modules/proj/bin/:$PATH" + export PATH="$HOME/ff/.scripts/:$PATH" ''; CMAKE_FLAGS = lib.strings.concatStringsSep " " [ @@ -71,6 +71,7 @@ "-DFF_USE_EXTERNAL_SPDLOG=ON" "-DFF_USE_EXTERNAL_DOCTEST=ON" "-DFF_USE_EXTERNAL_RAPIDCHECK=ON" + "-DFF_USE_EXTERNAL_EXPECTED=ON" "-DFF_USE_EXTERNAL_RANGEV3=ON" "-DFF_USE_EXTERNAL_BOOST_PREPROCESSOR=ON" "-DFF_USE_EXTERNAL_TYPE_INDEX=ON" @@ -94,6 +95,7 @@ cudaPackages.nccl cudaPackages.libcublas cudaPackages.cuda_cudart + tl-expected ]) (with self.packages.${system}; [ legion From 0eb172a021ff86241c198622675a79bc857bd1e8 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Sun, 7 Apr 2024 02:04:27 -0700 Subject: [PATCH 05/43] Prototype implementation of dtgen --- flake.lock | 7 +-- flake.nix | 2 +- lib/op-attrs/include/op-attrs/ops/broadcast.h | 43 +++++++++++---- .../op-attrs/ops/broadcast.struct.toml | 17 ++++++ lib/op-attrs/src/broadcast.cc | 3 -- lib/op-attrs/src/op-attrs/ops/broadcast.cc | 52 +++++++++++++++++++ lib/utils/include/utils/stack_vector.h | 14 +++++ 7 files changed, 122 insertions(+), 16 deletions(-) create mode 100644 lib/op-attrs/include/op-attrs/ops/broadcast.struct.toml delete mode 100644 lib/op-attrs/src/broadcast.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/broadcast.cc diff --git a/flake.lock b/flake.lock index ffd4a02962..c62a25de80 100644 --- a/flake.lock +++ b/flake.lock @@ -43,15 +43,16 @@ ] }, "locked": { - "lastModified": 1712222904, - "narHash": "sha256-FRI/RdOTtmo9o7iwZiACD0lSSlgvKqcpppjliXUHyRU=", + "lastModified": 1712480160, + "narHash": "sha256-CbuzEbFxmgI0Kd7mOwrrDhY1QGMLY5m6HG5dLHkVUus=", "owner": "lockshaw", "repo": "proj", - "rev": "5b7a82dc01fa25076a8b3db96c1f2ea4752ae990", + "rev": "041fa8ecf00dec6276d71ed70b275b53dccb1c93", "type": "github" }, "original": { "owner": "lockshaw", + "ref": "dtgen", "repo": "proj", "type": "github" } diff --git a/flake.nix b/flake.nix index bd372e4cbf..09a1dfcae5 100644 --- a/flake.nix +++ b/flake.nix @@ -18,7 +18,7 @@ flake-utils.url = "github:numtide/flake-utils"; proj-repo = { - url = "github:lockshaw/proj"; + url = "github:lockshaw/proj/dtgen"; inputs.nixpkgs.follows = "nixpkgs"; inputs.flake-utils.follows = "flake-utils"; }; diff --git a/lib/op-attrs/include/op-attrs/ops/broadcast.h b/lib/op-attrs/include/op-attrs/ops/broadcast.h index 433bf23241..a370224e4e 100644 --- a/lib/op-attrs/include/op-attrs/ops/broadcast.h +++ b/lib/op-attrs/include/op-attrs/ops/broadcast.h @@ -1,19 +1,44 @@ -#ifndef _FLEXFLOW_INCLUDE_OPATTRS_OPS_BROADCAST_H -#define _FLEXFLOW_INCLUDE_OPATTRS_OPS_BROADCAST_H +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/broadcast.struct.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BROADCAST_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BROADCAST_H #include "core.h" +#include "nlohmann/json.hpp" #include "utils/stack_vector.h" -#include "utils/visitable.h" +#include +#include namespace FlexFlow { - struct BroadcastAttrs { - req> target_dims; + BroadcastAttrs() = delete; + BroadcastAttrs(stack_vector const &target_dims); + + bool operator==(BroadcastAttrs const &) const; + bool operator!=(BroadcastAttrs const &) const; + bool operator<(BroadcastAttrs const &) const; + bool operator>(BroadcastAttrs const &) const; + bool operator<=(BroadcastAttrs const &) const; + bool operator>=(BroadcastAttrs const &) const; + stack_vector target_dims; }; -FF_VISITABLE_STRUCT(BroadcastAttrs, target_dims); +} // namespace FlexFlow -CHECK_VALID_OP_ATTR(BroadcastAttrs); +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::BroadcastAttrs const &) const; +}; +} // namespace std -} // namespace FlexFlow +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::BroadcastAttrs from_json(json const &); + static void to_json(json &, FlexFlow::BroadcastAttrs const &); +}; +} // namespace nlohmann -#endif +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BROADCAST_H diff --git a/lib/op-attrs/include/op-attrs/ops/broadcast.struct.toml b/lib/op-attrs/include/op-attrs/ops/broadcast.struct.toml new file mode 100644 index 0000000000..b46c080e28 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/broadcast.struct.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "BroadcastAttrs" +features = [ + "eq", + "ord", + "hash", + "json", +] + +includes = [ + "core.h", + "utils/stack_vector.h", +] + +[[fields]] +name = "target_dims" +type = "stack_vector" diff --git a/lib/op-attrs/src/broadcast.cc b/lib/op-attrs/src/broadcast.cc deleted file mode 100644 index c69f480b84..0000000000 --- a/lib/op-attrs/src/broadcast.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "op-attrs/ops/broadcast.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/broadcast.cc b/lib/op-attrs/src/op-attrs/ops/broadcast.cc new file mode 100644 index 0000000000..2553ac5801 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/broadcast.cc @@ -0,0 +1,52 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/broadcast.struct.toml + +#include "op-attrs/ops/broadcast.h" + +namespace FlexFlow { +BroadcastAttrs::BroadcastAttrs( + stack_vector const &target_dims) + : target_dims(target_dims) {} +bool BroadcastAttrs::operator==(BroadcastAttrs const &other) const { + return std::tie(this->target_dims) == std::tie(other.target_dims); +} +bool BroadcastAttrs::operator!=(BroadcastAttrs const &other) const { + return std::tie(this->target_dims) != std::tie(other.target_dims); +} +bool BroadcastAttrs::operator<(BroadcastAttrs const &other) const { + return std::tie(this->target_dims) < std::tie(other.target_dims); +} +bool BroadcastAttrs::operator>(BroadcastAttrs const &other) const { + return std::tie(this->target_dims) > std::tie(other.target_dims); +} +bool BroadcastAttrs::operator<=(BroadcastAttrs const &other) const { + return std::tie(this->target_dims) <= std::tie(other.target_dims); +} +bool BroadcastAttrs::operator>=(BroadcastAttrs const &other) const { + return std::tie(this->target_dims) >= std::tie(other.target_dims); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::BroadcastAttrs const &x) const { + size_t result = 0; + result ^= std::hash>{}(x.target_dims) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::BroadcastAttrs + adl_serializer::from_json(json const &j) { + return { + j.at("target_dims").template get>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::BroadcastAttrs const &v) { + j["__type"] = "BroadcastAttrs"; + j["target_dims"] = v.target_dims; +} +} // namespace nlohmann diff --git a/lib/utils/include/utils/stack_vector.h b/lib/utils/include/utils/stack_vector.h index fe665ed749..ce371adeba 100644 --- a/lib/utils/include/utils/stack_vector.h +++ b/lib/utils/include/utils/stack_vector.h @@ -4,6 +4,7 @@ #include "containers.h" #include "hash-utils.h" #include "utils/fmt.h" +#include "utils/json.h" #include "utils/test_types.h" #include "utils/type_traits.h" #include @@ -312,6 +313,19 @@ struct stack_vector { // CHECK_FMTABLE(stack_vector); +template +void to_json(json &j, stack_vector const &v) { + std::vector as_vec(v.begin(), v.end()); + j = as_vec; +} + +template +void from_json(json const &j, stack_vector &v) { + std::vector as_vec; + j.get_to(as_vec); + v = stack_vector{as_vec.begin(), as_vec.end()}; +} + } // namespace FlexFlow namespace std { From 2586f780ae7f4e6cccfa7b3f8610fac5bf12554b Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Sun, 7 Apr 2024 22:56:34 -0700 Subject: [PATCH 06/43] Refactor op-attrs to use dtgen --- .editorconfig | 22 ++ flake.lock | 6 +- flake.nix | 4 + lib/op-attrs/include/op-attrs/activation.h | 7 + lib/op-attrs/include/op-attrs/aggregate_op.h | 22 ++ lib/op-attrs/include/op-attrs/dim_ordered.h | 4 +- lib/op-attrs/include/op-attrs/ff_dim.h | 50 +++- .../include/op-attrs/ff_dim.struct.toml | 15 ++ .../include/op-attrs/l1_regularizer_attrs.h | 58 +++++ .../op-attrs/l1_regularizer_attrs.struct.toml | 14 ++ .../include/op-attrs/l2_regularizer_attrs.h | 58 +++++ .../op-attrs/l2_regularizer_attrs.struct.toml | 14 ++ lib/op-attrs/include/op-attrs/ops/attention.h | 16 +- .../include/op-attrs/ops/attention_attrs.h | 72 ++++++ .../op-attrs/ops/attention_attrs.struct.toml | 43 ++++ .../include/op-attrs/ops/batch_matmul.h | 61 ++++- .../op-attrs/ops/batch_matmul.struct.toml | 19 ++ .../include/op-attrs/ops/batch_norm.h | 6 +- .../include/op-attrs/ops/batch_norm_attrs.h | 58 +++++ .../op-attrs/ops/batch_norm_attrs.struct.toml | 15 ++ lib/op-attrs/include/op-attrs/ops/broadcast.h | 14 +- .../op-attrs/ops/broadcast.struct.toml | 5 +- lib/op-attrs/include/op-attrs/ops/cast.h | 9 +- .../include/op-attrs/ops/cast_attrs.h | 51 ++++ .../op-attrs/ops/cast_attrs.struct.toml | 18 ++ lib/op-attrs/include/op-attrs/ops/combine.h | 9 +- .../include/op-attrs/ops/combine_attrs.h | 53 ++++ .../op-attrs/ops/combine_attrs.struct.toml | 22 ++ lib/op-attrs/include/op-attrs/ops/concat.h | 9 +- .../include/op-attrs/ops/concat_attrs.h | 52 ++++ .../op-attrs/ops/concat_attrs.struct.toml | 22 ++ lib/op-attrs/include/op-attrs/ops/conv_2d.h | 21 +- .../include/op-attrs/ops/conv_2d_attrs.h | 71 ++++++ .../op-attrs/ops/conv_2d_attrs.struct.toml | 29 +++ lib/op-attrs/include/op-attrs/ops/dropout.h | 7 +- .../include/op-attrs/ops/dropout_attrs.h | 59 +++++ .../op-attrs/ops/dropout_attrs.struct.toml | 19 ++ .../include/op-attrs/ops/element_binary.h | 15 +- .../op-attrs/ops/element_binary_attrs.h | 58 +++++ .../ops/element_binary_attrs.struct.toml | 32 +++ .../op-attrs/ops/element_scalar_unary_attrs.h | 52 ++++ .../element_scalar_unary_attrs.struct.toml | 23 ++ .../include/op-attrs/ops/element_unary.h | 14 +- .../op-attrs/ops/element_unary_attrs.h | 51 ++++ .../ops/element_unary_attrs.struct.toml | 19 ++ lib/op-attrs/include/op-attrs/ops/embedding.h | 39 +-- .../include/op-attrs/ops/embedding_attrs.h | 59 +++++ .../op-attrs/ops/embedding_attrs.struct.toml | 32 +++ lib/op-attrs/include/op-attrs/ops/flat.h | 4 +- .../include/op-attrs/ops/flat_attrs.h | 54 ++++ .../op-attrs/ops/flat_attrs.struct.toml | 11 + lib/op-attrs/include/op-attrs/ops/gather.h | 7 +- .../include/op-attrs/ops/gather_attrs.h | 51 ++++ .../op-attrs/ops/gather_attrs.struct.toml | 18 ++ lib/op-attrs/include/op-attrs/ops/input.h | 3 +- .../include/op-attrs/ops/input_attrs.h | 54 ++++ .../op-attrs/ops/input_attrs.struct.toml | 11 + .../include/op-attrs/ops/layer_norm.h | 9 +- .../include/op-attrs/ops/layer_norm_attrs.h | 57 +++++ .../op-attrs/ops/layer_norm_attrs.struct.toml | 27 ++ lib/op-attrs/include/op-attrs/ops/linear.h | 27 +- .../include/op-attrs/ops/linear_attrs.h | 62 +++++ .../op-attrs/ops/linear_attrs.struct.toml | 37 +++ lib/op-attrs/include/op-attrs/ops/noop.h | 4 +- .../include/op-attrs/ops/noop_attrs.h | 54 ++++ .../op-attrs/ops/noop_attrs.struct.toml | 11 + lib/op-attrs/include/op-attrs/ops/pool_2d.h | 46 +--- .../include/op-attrs/ops/pool_2d_attrs.h | 66 +++++ .../op-attrs/ops/pool_2d_attrs.struct.toml | 47 ++++ lib/op-attrs/include/op-attrs/ops/reduce.h | 11 +- .../include/op-attrs/ops/reduce_attrs.h | 58 +++++ .../op-attrs/ops/reduce_attrs.struct.toml | 28 +++ lib/op-attrs/include/op-attrs/ops/reduction.h | 8 +- .../include/op-attrs/ops/reduction_attrs.h | 53 ++++ .../op-attrs/ops/reduction_attrs.struct.toml | 22 ++ .../include/op-attrs/ops/repartition.h | 8 +- .../include/op-attrs/ops/repartition_attrs.h | 53 ++++ .../ops/repartition_attrs.struct.toml | 22 ++ lib/op-attrs/include/op-attrs/ops/replicate.h | 8 +- .../include/op-attrs/ops/replicate_attrs.h | 53 ++++ .../op-attrs/ops/replicate_attrs.struct.toml | 22 ++ lib/op-attrs/include/op-attrs/ops/reshape.h | 7 +- .../include/op-attrs/ops/reshape_attrs.h | 51 ++++ .../op-attrs/ops/reshape_attrs.struct.toml | 18 ++ lib/op-attrs/include/op-attrs/ops/reverse.h | 7 +- .../include/op-attrs/ops/reverse_attrs.h | 51 ++++ .../op-attrs/ops/reverse_attrs.struct.toml | 18 ++ lib/op-attrs/include/op-attrs/ops/softmax.h | 7 +- .../include/op-attrs/ops/softmax_attrs.h | 51 ++++ .../op-attrs/ops/softmax_attrs.struct.toml | 18 ++ lib/op-attrs/include/op-attrs/ops/split.h | 8 +- .../include/op-attrs/ops/split_attrs.h | 54 ++++ .../op-attrs/ops/split_attrs.struct.toml | 23 ++ lib/op-attrs/include/op-attrs/ops/topk.h | 7 +- .../include/op-attrs/ops/topk_attrs.h | 59 +++++ .../op-attrs/ops/topk_attrs.struct.toml | 18 ++ lib/op-attrs/include/op-attrs/ops/transpose.h | 7 +- .../include/op-attrs/ops/transpose_attrs.h | 53 ++++ .../op-attrs/ops/transpose_attrs.struct.toml | 19 ++ lib/op-attrs/include/op-attrs/pool_op.h | 21 ++ .../include/op-attrs/regularizer_attrs.h | 13 + lib/op-attrs/src/op-attrs/aggregate_op.cc | 17 ++ lib/op-attrs/src/op-attrs/ff_dim.cc | 61 +++++ .../src/op-attrs/l1_regularizer_attrs.cc | 69 ++++++ .../src/op-attrs/l2_regularizer_attrs.cc | 69 ++++++ .../src/op-attrs/ops/attention_attrs.cc | 213 ++++++++++++++++ lib/op-attrs/src/op-attrs/ops/batch_matmul.cc | 83 +++++++ .../src/op-attrs/ops/batch_norm_attrs.cc | 68 ++++++ lib/op-attrs/src/op-attrs/ops/broadcast.cc | 22 +- lib/op-attrs/src/op-attrs/ops/cast_attrs.cc | 62 +++++ .../src/op-attrs/ops/combine_attrs.cc | 75 ++++++ lib/op-attrs/src/op-attrs/ops/concat_attrs.cc | 75 ++++++ .../src/op-attrs/ops/conv_2d_attrs.cc | 230 ++++++++++++++++++ .../src/op-attrs/ops/dropout_attrs.cc | 75 ++++++ .../src/op-attrs/ops/element_binary_attrs.cc | 125 ++++++++++ .../ops/element_scalar_unary_attrs.cc | 82 +++++++ .../src/op-attrs/ops/element_unary_attrs.cc | 63 +++++ .../src/op-attrs/ops/embedding_attrs.cc | 118 +++++++++ lib/op-attrs/src/op-attrs/ops/flat_attrs.cc | 63 +++++ lib/op-attrs/src/op-attrs/ops/gather_attrs.cc | 62 +++++ lib/op-attrs/src/op-attrs/ops/input_attrs.cc | 63 +++++ .../src/op-attrs/ops/layer_norm_attrs.cc | 88 +++++++ lib/op-attrs/src/op-attrs/ops/linear_attrs.cc | 139 +++++++++++ lib/op-attrs/src/op-attrs/ops/noop_attrs.cc | 63 +++++ .../src/op-attrs/ops/pool_2d_attrs.cc | 191 +++++++++++++++ lib/op-attrs/src/op-attrs/ops/reduce_attrs.cc | 88 +++++++ .../src/op-attrs/ops/reduction_attrs.cc | 75 ++++++ .../src/op-attrs/ops/repartition_attrs.cc | 76 ++++++ .../src/op-attrs/ops/replicate_attrs.cc | 75 ++++++ .../src/op-attrs/ops/reshape_attrs.cc | 63 +++++ .../src/op-attrs/ops/reverse_attrs.cc | 62 +++++ .../src/op-attrs/ops/softmax_attrs.cc | 62 +++++ lib/op-attrs/src/op-attrs/ops/split_attrs.cc | 78 ++++++ lib/op-attrs/src/op-attrs/ops/topk_attrs.cc | 72 ++++++ .../src/op-attrs/ops/transpose_attrs.cc | 69 ++++++ lib/op-attrs/src/op-attrs/pool_op.cc | 17 ++ 136 files changed, 5529 insertions(+), 328 deletions(-) create mode 100644 .editorconfig create mode 100644 lib/op-attrs/include/op-attrs/aggregate_op.h create mode 100644 lib/op-attrs/include/op-attrs/ff_dim.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/l1_regularizer_attrs.h create mode 100644 lib/op-attrs/include/op-attrs/l1_regularizer_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/l2_regularizer_attrs.h create mode 100644 lib/op-attrs/include/op-attrs/l2_regularizer_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/attention_attrs.h create mode 100644 lib/op-attrs/include/op-attrs/ops/attention_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/batch_matmul.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.h create mode 100644 lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/cast_attrs.h create mode 100644 lib/op-attrs/include/op-attrs/ops/cast_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/combine_attrs.h create mode 100644 lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/concat_attrs.h create mode 100644 lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.h create mode 100644 lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/dropout_attrs.h create mode 100644 lib/op-attrs/include/op-attrs/ops/dropout_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/element_binary_attrs.h create mode 100644 lib/op-attrs/include/op-attrs/ops/element_binary_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.h create mode 100644 lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/element_unary_attrs.h create mode 100644 lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/embedding_attrs.h create mode 100644 lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/flat_attrs.h create mode 100644 lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/gather_attrs.h create mode 100644 lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/input_attrs.h create mode 100644 lib/op-attrs/include/op-attrs/ops/input_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.h create mode 100644 lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/linear_attrs.h create mode 100644 lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/noop_attrs.h create mode 100644 lib/op-attrs/include/op-attrs/ops/noop_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.h create mode 100644 lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/reduce_attrs.h create mode 100644 lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/reduction_attrs.h create mode 100644 lib/op-attrs/include/op-attrs/ops/reduction_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/repartition_attrs.h create mode 100644 lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/replicate_attrs.h create mode 100644 lib/op-attrs/include/op-attrs/ops/replicate_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/reshape_attrs.h create mode 100644 lib/op-attrs/include/op-attrs/ops/reshape_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/reverse_attrs.h create mode 100644 lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/softmax_attrs.h create mode 100644 lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/split_attrs.h create mode 100644 lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/topk_attrs.h create mode 100644 lib/op-attrs/include/op-attrs/ops/topk_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/transpose_attrs.h create mode 100644 lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/pool_op.h create mode 100644 lib/op-attrs/include/op-attrs/regularizer_attrs.h create mode 100644 lib/op-attrs/src/op-attrs/aggregate_op.cc create mode 100644 lib/op-attrs/src/op-attrs/ff_dim.cc create mode 100644 lib/op-attrs/src/op-attrs/l1_regularizer_attrs.cc create mode 100644 lib/op-attrs/src/op-attrs/l2_regularizer_attrs.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/attention_attrs.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/batch_matmul.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/batch_norm_attrs.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/cast_attrs.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/combine_attrs.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/concat_attrs.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/conv_2d_attrs.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/dropout_attrs.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/element_binary_attrs.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/element_scalar_unary_attrs.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/element_unary_attrs.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/embedding_attrs.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/flat_attrs.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/gather_attrs.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/input_attrs.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/layer_norm_attrs.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/linear_attrs.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/noop_attrs.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/pool_2d_attrs.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/reduce_attrs.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/reduction_attrs.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/repartition_attrs.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/replicate_attrs.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/reshape_attrs.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/reverse_attrs.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/softmax_attrs.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/split_attrs.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/topk_attrs.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/transpose_attrs.cc create mode 100644 lib/op-attrs/src/op-attrs/pool_op.cc diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000000..7242dd283c --- /dev/null +++ b/.editorconfig @@ -0,0 +1,22 @@ +root = true + +# Unix-style newlines with a newline ending every file +[*] +end_of_line = lf +insert_final_newline = true + +[{CMakeLists.txt,*.cmake}] +indent_style = space +indent_size = 2 + +[*.{cc,h,cu,cpp}] +indent_style = space +indent_size = 2 + +[*.py] +indent_style = space +indent_size = 4 + +[*.toml] +indent_style = space +indent_size = 2 diff --git a/flake.lock b/flake.lock index c62a25de80..b42ff59e74 100644 --- a/flake.lock +++ b/flake.lock @@ -43,11 +43,11 @@ ] }, "locked": { - "lastModified": 1712480160, - "narHash": "sha256-CbuzEbFxmgI0Kd7mOwrrDhY1QGMLY5m6HG5dLHkVUus=", + "lastModified": 1712555489, + "narHash": "sha256-V7Ck7y0BC18HR+CHd8fSp9i3ObGwdJsu22NpxQvTGVs=", "owner": "lockshaw", "repo": "proj", - "rev": "041fa8ecf00dec6276d71ed70b275b53dccb1c93", + "rev": "6b7312e2079178332ffefc24f54c11879fc85e7a", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index 09a1dfcae5..3300d7dbba 100644 --- a/flake.nix +++ b/flake.nix @@ -109,6 +109,10 @@ inputsFrom = [ ci ]; inherit (ci) CMAKE_FLAGS; + VIMPLUGINS = lib.strings.concatStringsSep "," [ + "${proj-repo.packages.${system}.proj-nvim}" + ]; + buildInputs = builtins.concatLists [ (with pkgs; [ clang-tools diff --git a/lib/op-attrs/include/op-attrs/activation.h b/lib/op-attrs/include/op-attrs/activation.h index 8fa07825fd..87729b1206 100644 --- a/lib/op-attrs/include/op-attrs/activation.h +++ b/lib/op-attrs/include/op-attrs/activation.h @@ -2,11 +2,18 @@ #define _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_ACTIVATION_H #include "utils/fmt.h" +#include "nlohmann/json.hpp" namespace FlexFlow { enum class Activation { RELU, SIGMOID, TANH, GELU }; +NLOHMANN_JSON_SERIALIZE_ENUM(Activation, + {{Activation::RELU, "RELU"}, + {Activation::SIGMOID, "SIGMOID"}, + {Activation::TANH, "TANH"}, + {Activation::GELU, "GELU"}}); + } namespace fmt { diff --git a/lib/op-attrs/include/op-attrs/aggregate_op.h b/lib/op-attrs/include/op-attrs/aggregate_op.h new file mode 100644 index 0000000000..b0d1e6cf93 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/aggregate_op.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_AGGREGATE_OP_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_AGGREGATE_OP_H + +#include "utils/fmt.h" +#include "nlohmann/json.hpp" + +namespace FlexFlow { + +enum class AggregateOp { + SUM, + AVG, +}; + +NLOHMANN_JSON_SERIALIZE_ENUM(AggregateOp, + {{AggregateOp::SUM, "SUM"}, + {AggregateOp::AVG, "AVG"}}); + +std::string format_as(AggregateOp); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/dim_ordered.h b/lib/op-attrs/include/op-attrs/dim_ordered.h index b726d0687f..b03667466d 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered.h @@ -28,11 +28,11 @@ struct DimOrdered { : contents(contents.begin(), contents.end()) {} T const &at(Idx idx) const { - return this->contents.at(idx.value()); + return this->contents.at(idx.value); } T &at(Idx idx) { - return this->contents.at(idx.value()); + return this->contents.at(idx.value); } T const &operator[](Idx idx) const { diff --git a/lib/op-attrs/include/op-attrs/ff_dim.h b/lib/op-attrs/include/op-attrs/ff_dim.h index be1f148a70..d7f590aeac 100644 --- a/lib/op-attrs/include/op-attrs/ff_dim.h +++ b/lib/op-attrs/include/op-attrs/ff_dim.h @@ -1,18 +1,50 @@ -#ifndef _FLEXFLOW_OPATTRS_INCLUDE_FF_DIM_H -#define _FLEXFLOW_OPATTRS_INCLUDE_FF_DIM_H +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ff_dim.struct.toml -#include "utils/strong_typedef.h" +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include #include +#include +#include namespace FlexFlow { +struct ff_dim_t { + ff_dim_t() = delete; + ff_dim_t(int const &value); -struct ff_dim_t : public numerical_typedef { - using numerical_typedef::numerical_typedef; + bool operator==(ff_dim_t const &) const; + bool operator!=(ff_dim_t const &) const; + bool operator<(ff_dim_t const &) const; + bool operator>(ff_dim_t const &) const; + bool operator<=(ff_dim_t const &) const; + bool operator>=(ff_dim_t const &) const; + int value; }; - } // namespace FlexFlow -MAKE_TYPEDEF_HASHABLE(::FlexFlow::ff_dim_t); -MAKE_TYPEDEF_PRINTABLE(::FlexFlow::ff_dim_t, "ff_dim"); +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ff_dim_t const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ff_dim_t from_json(json const &); + static void to_json(json &, FlexFlow::ff_dim_t const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ff_dim_t const &); +std::ostream &operator<<(std::ostream &, ff_dim_t const &); +} // namespace FlexFlow -#endif +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_H diff --git a/lib/op-attrs/include/op-attrs/ff_dim.struct.toml b/lib/op-attrs/include/op-attrs/ff_dim.struct.toml new file mode 100644 index 0000000000..feae1e4b21 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ff_dim.struct.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "ff_dim_t" + +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +[[fields]] +name = "value" +type = "int" diff --git a/lib/op-attrs/include/op-attrs/l1_regularizer_attrs.h b/lib/op-attrs/include/op-attrs/l1_regularizer_attrs.h new file mode 100644 index 0000000000..18afd8a38b --- /dev/null +++ b/lib/op-attrs/include/op-attrs/l1_regularizer_attrs.h @@ -0,0 +1,58 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/l1_regularizer_attrs.struct.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_L1_REGULARIZER_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_L1_REGULARIZER_ATTRS_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct L1RegularizerAttrs { + L1RegularizerAttrs() = delete; + L1RegularizerAttrs(float const &lambda); + + bool operator==(L1RegularizerAttrs const &) const; + bool operator!=(L1RegularizerAttrs const &) const; + bool operator<(L1RegularizerAttrs const &) const; + bool operator>(L1RegularizerAttrs const &) const; + bool operator<=(L1RegularizerAttrs const &) const; + bool operator>=(L1RegularizerAttrs const &) const; + float lambda; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::L1RegularizerAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::L1RegularizerAttrs from_json(json const &); + static void to_json(json &, FlexFlow::L1RegularizerAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(L1RegularizerAttrs const &); +std::ostream &operator<<(std::ostream &, L1RegularizerAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_L1_REGULARIZER_ATTRS_H diff --git a/lib/op-attrs/include/op-attrs/l1_regularizer_attrs.struct.toml b/lib/op-attrs/include/op-attrs/l1_regularizer_attrs.struct.toml new file mode 100644 index 0000000000..60fabfb94a --- /dev/null +++ b/lib/op-attrs/include/op-attrs/l1_regularizer_attrs.struct.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "L1RegularizerAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "lambda" +type = "float" diff --git a/lib/op-attrs/include/op-attrs/l2_regularizer_attrs.h b/lib/op-attrs/include/op-attrs/l2_regularizer_attrs.h new file mode 100644 index 0000000000..3b403334dc --- /dev/null +++ b/lib/op-attrs/include/op-attrs/l2_regularizer_attrs.h @@ -0,0 +1,58 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/l2_regularizer_attrs.struct.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_L2_REGULARIZER_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_L2_REGULARIZER_ATTRS_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct L2RegularizerAttrs { + L2RegularizerAttrs() = delete; + L2RegularizerAttrs(float const &lambda); + + bool operator==(L2RegularizerAttrs const &) const; + bool operator!=(L2RegularizerAttrs const &) const; + bool operator<(L2RegularizerAttrs const &) const; + bool operator>(L2RegularizerAttrs const &) const; + bool operator<=(L2RegularizerAttrs const &) const; + bool operator>=(L2RegularizerAttrs const &) const; + float lambda; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::L2RegularizerAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::L2RegularizerAttrs from_json(json const &); + static void to_json(json &, FlexFlow::L2RegularizerAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(L2RegularizerAttrs const &); +std::ostream &operator<<(std::ostream &, L2RegularizerAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_L2_REGULARIZER_ATTRS_H diff --git a/lib/op-attrs/include/op-attrs/l2_regularizer_attrs.struct.toml b/lib/op-attrs/include/op-attrs/l2_regularizer_attrs.struct.toml new file mode 100644 index 0000000000..adce4397a4 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/l2_regularizer_attrs.struct.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "L2RegularizerAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "lambda" +type = "float" diff --git a/lib/op-attrs/include/op-attrs/ops/attention.h b/lib/op-attrs/include/op-attrs/ops/attention.h index ec3e592607..4000732ddf 100644 --- a/lib/op-attrs/include/op-attrs/ops/attention.h +++ b/lib/op-attrs/include/op-attrs/ops/attention.h @@ -4,24 +4,10 @@ #include "core.h" #include "op-attrs/parallel_tensor_shape.h" #include "utils/visitable.h" +#include "op-attrs/ops/attention_attrs.h" namespace FlexFlow { -struct MultiHeadAttentionAttrs { - req embed_dim, num_heads, kdim, vdim; - req dropout; - req bias, add_bias_kv, add_zero_attn; -}; -FF_VISITABLE_STRUCT(MultiHeadAttentionAttrs, - embed_dim, - num_heads, - kdim, - vdim, - dropout, - bias, - add_bias_kv, - add_zero_attn); - template struct MultiHeadAttentionInputs : public use_visitable_cmp> { diff --git a/lib/op-attrs/include/op-attrs/ops/attention_attrs.h b/lib/op-attrs/include/op-attrs/ops/attention_attrs.h new file mode 100644 index 0000000000..029ddc08ac --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/attention_attrs.h @@ -0,0 +1,72 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/attention_attrs.struct.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_ATTRS_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct MultiHeadAttentionAttrs { + MultiHeadAttentionAttrs() = delete; + MultiHeadAttentionAttrs(int const &embed_dim, + int const &num_heads, + int const &kdim, + int const &vdim, + float const &dropout, + bool const &bias, + bool const &add_bias_kv, + bool const &add_zero_attn); + + bool operator==(MultiHeadAttentionAttrs const &) const; + bool operator!=(MultiHeadAttentionAttrs const &) const; + bool operator<(MultiHeadAttentionAttrs const &) const; + bool operator>(MultiHeadAttentionAttrs const &) const; + bool operator<=(MultiHeadAttentionAttrs const &) const; + bool operator>=(MultiHeadAttentionAttrs const &) const; + int embed_dim; + int num_heads; + int kdim; + int vdim; + float dropout; + bool bias; + bool add_bias_kv; + bool add_zero_attn; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::MultiHeadAttentionAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::MultiHeadAttentionAttrs from_json(json const &); + static void to_json(json &, FlexFlow::MultiHeadAttentionAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(MultiHeadAttentionAttrs const &); +std::ostream &operator<<(std::ostream &, MultiHeadAttentionAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_ATTRS_H diff --git a/lib/op-attrs/include/op-attrs/ops/attention_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/attention_attrs.struct.toml new file mode 100644 index 0000000000..d96d8af69c --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/attention_attrs.struct.toml @@ -0,0 +1,43 @@ +namespace = "FlexFlow" +name = "MultiHeadAttentionAttrs" + +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "embed_dim" +type = "int" + +[[fields]] +name = "num_heads" +type = "int" + +[[fields]] +name = "kdim" +type = "int" + +[[fields]] +name = "vdim" +type = "int" + +[[fields]] +name = "dropout" +type = "float" + +[[fields]] +name = "bias" +type = "bool" + +[[fields]] +name = "add_bias_kv" +type = "bool" + +[[fields]] +name = "add_zero_attn" +type = "bool" diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h index b05a5eb022..db781f547b 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h @@ -1,18 +1,59 @@ -#ifndef _FF_OP_META_BATCH_MATMUL_ATTRS_H -#define _FF_OP_META_BATCH_MATMUL_ATTRS_H +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/batch_matmul.struct.toml -#include "core.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_MATMUL_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_MATMUL_H -namespace FlexFlow { +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include +#include +namespace FlexFlow { struct BatchMatmulAttrs { - req a_seq_length_dim, b_seq_length_dim; + BatchMatmulAttrs() = delete; + BatchMatmulAttrs(int const &a_seq_length_dim, int const &b_seq_length_dim); + + bool operator==(BatchMatmulAttrs const &) const; + bool operator!=(BatchMatmulAttrs const &) const; + bool operator<(BatchMatmulAttrs const &) const; + bool operator>(BatchMatmulAttrs const &) const; + bool operator<=(BatchMatmulAttrs const &) const; + bool operator>=(BatchMatmulAttrs const &) const; + int a_seq_length_dim; + int b_seq_length_dim; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::BatchMatmulAttrs const &) const; }; -FF_VISITABLE_STRUCT(BatchMatmulAttrs, a_seq_length_dim, b_seq_length_dim); +} // namespace std -CHECK_VALID_OP_ATTR(BatchMatmulAttrs); +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::BatchMatmulAttrs from_json(json const &); + static void to_json(json &, FlexFlow::BatchMatmulAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(BatchMatmulAttrs const &); +std::ostream &operator<<(std::ostream &, BatchMatmulAttrs const &); } // namespace FlexFlow -#endif +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_MATMUL_H diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.struct.toml b/lib/op-attrs/include/op-attrs/ops/batch_matmul.struct.toml new file mode 100644 index 0000000000..3b1dd3f687 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.struct.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "BatchMatmulAttrs" + +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "a_seq_length_dim" +type = "int" + +[[fields]] +name = "b_seq_length_dim" +type = "int" diff --git a/lib/op-attrs/include/op-attrs/ops/batch_norm.h b/lib/op-attrs/include/op-attrs/ops/batch_norm.h index 4ec823d4ae..1683d2a30c 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_norm.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_norm.h @@ -4,14 +4,10 @@ #include "core.h" #include "op-attrs/parallel_tensor_shape.h" #include "utils/visitable.h" +#include "op-attrs/ops/batch_norm_attrs.h" namespace FlexFlow { -struct BatchNormAttrs { - req relu; -}; -FF_VISITABLE_STRUCT(BatchNormAttrs, relu); - ParallelTensorShape get_output_shape(BatchNormAttrs const &); CHECK_VALID_OP_ATTR(BatchNormAttrs); diff --git a/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.h b/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.h new file mode 100644 index 0000000000..f786d730c8 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.h @@ -0,0 +1,58 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_NORM_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_NORM_ATTRS_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct BatchNormAttrs { + BatchNormAttrs() = delete; + BatchNormAttrs(bool const &relu); + + bool operator==(BatchNormAttrs const &) const; + bool operator!=(BatchNormAttrs const &) const; + bool operator<(BatchNormAttrs const &) const; + bool operator>(BatchNormAttrs const &) const; + bool operator<=(BatchNormAttrs const &) const; + bool operator>=(BatchNormAttrs const &) const; + bool relu; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::BatchNormAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::BatchNormAttrs from_json(json const &); + static void to_json(json &, FlexFlow::BatchNormAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(BatchNormAttrs const &); +std::ostream &operator<<(std::ostream &, BatchNormAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_NORM_ATTRS_H diff --git a/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml new file mode 100644 index 0000000000..bc82f3c743 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "BatchNormAttrs" + +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "relu" +type = "bool" diff --git a/lib/op-attrs/include/op-attrs/ops/broadcast.h b/lib/op-attrs/include/op-attrs/ops/broadcast.h index a370224e4e..2d26a5a51d 100644 --- a/lib/op-attrs/include/op-attrs/ops/broadcast.h +++ b/lib/op-attrs/include/op-attrs/ops/broadcast.h @@ -5,16 +5,19 @@ #ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BROADCAST_H #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BROADCAST_H -#include "core.h" +#include "fmt/format.h" #include "nlohmann/json.hpp" #include "utils/stack_vector.h" #include +#include +#include #include namespace FlexFlow { struct BroadcastAttrs { BroadcastAttrs() = delete; - BroadcastAttrs(stack_vector const &target_dims); + BroadcastAttrs( + ::FlexFlow::stack_vector const &target_dims); bool operator==(BroadcastAttrs const &) const; bool operator!=(BroadcastAttrs const &) const; @@ -22,7 +25,7 @@ struct BroadcastAttrs { bool operator>(BroadcastAttrs const &) const; bool operator<=(BroadcastAttrs const &) const; bool operator>=(BroadcastAttrs const &) const; - stack_vector target_dims; + ::FlexFlow::stack_vector target_dims; }; } // namespace FlexFlow @@ -41,4 +44,9 @@ struct adl_serializer { }; } // namespace nlohmann +namespace FlexFlow { +std::string format_as(BroadcastAttrs const &); +std::ostream &operator<<(std::ostream &, BroadcastAttrs const &); +} // namespace FlexFlow + #endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BROADCAST_H diff --git a/lib/op-attrs/include/op-attrs/ops/broadcast.struct.toml b/lib/op-attrs/include/op-attrs/ops/broadcast.struct.toml index b46c080e28..ae5549c9b9 100644 --- a/lib/op-attrs/include/op-attrs/ops/broadcast.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/broadcast.struct.toml @@ -5,13 +5,14 @@ features = [ "ord", "hash", "json", + # "rapidcheck", + "fmt", ] includes = [ - "core.h", "utils/stack_vector.h", ] [[fields]] name = "target_dims" -type = "stack_vector" +type = "::FlexFlow::stack_vector" diff --git a/lib/op-attrs/include/op-attrs/ops/cast.h b/lib/op-attrs/include/op-attrs/ops/cast.h index 63563f8df8..e86f5a1c82 100644 --- a/lib/op-attrs/include/op-attrs/ops/cast.h +++ b/lib/op-attrs/include/op-attrs/ops/cast.h @@ -2,17 +2,10 @@ #define _FLEXFLOW_CAST_ATTRS_H #include "core.h" -#include "op-attrs/datatype.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/cast_attrs.h" namespace FlexFlow { -struct CastAttrs { - req dtype; -}; -FF_VISITABLE_STRUCT(CastAttrs, dtype); - CHECK_VALID_OP_ATTR(CastAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/cast_attrs.h b/lib/op-attrs/include/op-attrs/ops/cast_attrs.h new file mode 100644 index 0000000000..ca70701261 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/cast_attrs.h @@ -0,0 +1,51 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/cast_attrs.struct.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CAST_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CAST_ATTRS_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/datatype.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct CastAttrs { + CastAttrs() = delete; + CastAttrs(DataType const &dtype); + + bool operator==(CastAttrs const &) const; + bool operator!=(CastAttrs const &) const; + bool operator<(CastAttrs const &) const; + bool operator>(CastAttrs const &) const; + bool operator<=(CastAttrs const &) const; + bool operator>=(CastAttrs const &) const; + DataType dtype; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::CastAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::CastAttrs from_json(json const &); + static void to_json(json &, FlexFlow::CastAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(CastAttrs const &); +std::ostream &operator<<(std::ostream &, CastAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CAST_ATTRS_H diff --git a/lib/op-attrs/include/op-attrs/ops/cast_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/cast_attrs.struct.toml new file mode 100644 index 0000000000..75231ebc45 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/cast_attrs.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "CastAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/datatype.h" +] + +[[fields]] +name = "dtype" +type = "DataType" diff --git a/lib/op-attrs/include/op-attrs/ops/combine.h b/lib/op-attrs/include/op-attrs/ops/combine.h index deaba9e093..3c6b951462 100644 --- a/lib/op-attrs/include/op-attrs/ops/combine.h +++ b/lib/op-attrs/include/op-attrs/ops/combine.h @@ -2,17 +2,10 @@ #define _FLEXFLOW_COMBINE_ATTRS_H #include "core.h" -#include "op-attrs/ff_dim.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/combine_attrs.h" namespace FlexFlow { -struct CombineAttrs { - ff_dim_t combine_dim; - req combine_degree; -}; -FF_VISITABLE_STRUCT(CombineAttrs, combine_dim, combine_degree); CHECK_VALID_OP_ATTR(CombineAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/combine_attrs.h b/lib/op-attrs/include/op-attrs/ops/combine_attrs.h new file mode 100644 index 0000000000..4663fd495d --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/combine_attrs.h @@ -0,0 +1,53 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_COMBINE_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_COMBINE_ATTRS_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/ff_dim.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct CombineAttrs { + CombineAttrs() = delete; + CombineAttrs(::FlexFlow::ff_dim_t const &combine_dim, + int const &combine_degree); + + bool operator==(CombineAttrs const &) const; + bool operator!=(CombineAttrs const &) const; + bool operator<(CombineAttrs const &) const; + bool operator>(CombineAttrs const &) const; + bool operator<=(CombineAttrs const &) const; + bool operator>=(CombineAttrs const &) const; + ::FlexFlow::ff_dim_t combine_dim; + int combine_degree; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::CombineAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::CombineAttrs from_json(json const &); + static void to_json(json &, FlexFlow::CombineAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(CombineAttrs const &); +std::ostream &operator<<(std::ostream &, CombineAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_COMBINE_ATTRS_H diff --git a/lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml new file mode 100644 index 0000000000..dea30897bb --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "CombineAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ff_dim.h", +] + +[[fields]] +name = "combine_dim" +type = "::FlexFlow::ff_dim_t" + +[[fields]] +name = "combine_degree" +type = "int" diff --git a/lib/op-attrs/include/op-attrs/ops/concat.h b/lib/op-attrs/include/op-attrs/ops/concat.h index 78f848f18b..e01164eb5b 100644 --- a/lib/op-attrs/include/op-attrs/ops/concat.h +++ b/lib/op-attrs/include/op-attrs/ops/concat.h @@ -2,17 +2,10 @@ #define _FLEXFLOW_CONCAT_ATTRS_H #include "core.h" -#include "op-attrs/ff_dim.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/concat_attrs.h" namespace FlexFlow { -struct ConcatAttrs { - ff_dim_t axis; - req num_inputs; -}; -FF_VISITABLE_STRUCT(ConcatAttrs, axis, num_inputs); CHECK_VALID_OP_ATTR(ConcatAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/concat_attrs.h b/lib/op-attrs/include/op-attrs/ops/concat_attrs.h new file mode 100644 index 0000000000..93cafadcb2 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/concat_attrs.h @@ -0,0 +1,52 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONCAT_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONCAT_ATTRS_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/ff_dim.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct ConcatAttrs { + ConcatAttrs() = delete; + ConcatAttrs(::FlexFlow::ff_dim_t const &axis, int const &num_inputs); + + bool operator==(ConcatAttrs const &) const; + bool operator!=(ConcatAttrs const &) const; + bool operator<(ConcatAttrs const &) const; + bool operator>(ConcatAttrs const &) const; + bool operator<=(ConcatAttrs const &) const; + bool operator>=(ConcatAttrs const &) const; + ::FlexFlow::ff_dim_t axis; + int num_inputs; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ConcatAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ConcatAttrs from_json(json const &); + static void to_json(json &, FlexFlow::ConcatAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ConcatAttrs const &); +std::ostream &operator<<(std::ostream &, ConcatAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONCAT_ATTRS_H diff --git a/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml new file mode 100644 index 0000000000..032657b60c --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "ConcatAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ff_dim.h" +] + +[[fields]] +name = "axis" +type = "::FlexFlow::ff_dim_t" + +[[fields]] +name = "num_inputs" +type = "int" diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d.h b/lib/op-attrs/include/op-attrs/ops/conv_2d.h index 79980d545d..d515c76048 100644 --- a/lib/op-attrs/include/op-attrs/ops/conv_2d.h +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d.h @@ -2,31 +2,12 @@ #define _FLEXFLOW_CONV_2D_ATTRS_H #include "core.h" -#include "op-attrs/activation.h" #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/conv_2d_attrs.h" namespace FlexFlow { -struct Conv2DAttrs { - int out_channels, kernel_h, kernel_w, stride_h, stride_w, padding_h, - padding_w, groups; - std::optional activation; - req use_bias; -}; - -FF_VISITABLE_STRUCT(Conv2DAttrs, - out_channels, - kernel_h, - kernel_w, - stride_h, - stride_w, - padding_h, - padding_w, - groups, - activation, - use_bias); CHECK_VALID_OP_ATTR(Conv2DAttrs); TensorShape get_kernel_shape(Conv2DAttrs const &, TensorShape const &); diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.h b/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.h new file mode 100644 index 0000000000..d437e4a95a --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.h @@ -0,0 +1,71 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_ATTRS_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/activation.h" +#include "utils/json.h" +#include +#include +#include +#include +#include + +namespace FlexFlow { +struct Conv2DAttrs { + Conv2DAttrs() = delete; + Conv2DAttrs(int const &out_channels, + int const &kernel_h, + int const &kernel_w, + int const &stride_h, + int const &stride_w, + int const &padding_h, + int const &padding_w, + int const &groups, + std::optional<::FlexFlow::Activation> const &activation, + bool const &use_bias); + + bool operator==(Conv2DAttrs const &) const; + bool operator!=(Conv2DAttrs const &) const; + bool operator<(Conv2DAttrs const &) const; + bool operator>(Conv2DAttrs const &) const; + bool operator<=(Conv2DAttrs const &) const; + bool operator>=(Conv2DAttrs const &) const; + int out_channels; + int kernel_h; + int kernel_w; + int stride_h; + int stride_w; + int padding_h; + int padding_w; + int groups; + std::optional<::FlexFlow::Activation> activation; + bool use_bias; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::Conv2DAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::Conv2DAttrs from_json(json const &); + static void to_json(json &, FlexFlow::Conv2DAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(Conv2DAttrs const &); +std::ostream &operator<<(std::ostream &, Conv2DAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_ATTRS_H diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml new file mode 100644 index 0000000000..1d0b59ce87 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml @@ -0,0 +1,29 @@ +namespace = "FlexFlow" +name = "Conv2DAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "", + "op-attrs/activation.h", + "utils/json.h", +] + +fields = [ + { name = "out_channels", type = "int" }, + { name = "kernel_h", type = "int" }, + { name = "kernel_w", type = "int" }, + { name = "stride_h", type = "int" }, + { name = "stride_w", type = "int" }, + { name = "padding_h", type = "int" }, + { name = "padding_w", type = "int" }, + { name = "groups", type = "int" }, + { name = "activation", type = "std::optional<::FlexFlow::Activation>" }, + { name = "use_bias", type = "bool" }, +] diff --git a/lib/op-attrs/include/op-attrs/ops/dropout.h b/lib/op-attrs/include/op-attrs/ops/dropout.h index 8e0049f526..54def8a6c4 100644 --- a/lib/op-attrs/include/op-attrs/ops/dropout.h +++ b/lib/op-attrs/include/op-attrs/ops/dropout.h @@ -3,15 +3,10 @@ #include "core.h" #include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/dropout_attrs.h" namespace FlexFlow { -struct DropoutAttrs { - req rate; - req seed; -}; -FF_VISITABLE_STRUCT(DropoutAttrs, rate, seed); CHECK_VALID_OP_ATTR(DropoutAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/dropout_attrs.h b/lib/op-attrs/include/op-attrs/ops/dropout_attrs.h new file mode 100644 index 0000000000..6d17870138 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/dropout_attrs.h @@ -0,0 +1,59 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/dropout_attrs.struct.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_DROPOUT_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_DROPOUT_ATTRS_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct DropoutAttrs { + DropoutAttrs() = delete; + DropoutAttrs(float const &rate, unsigned long long const &seed); + + bool operator==(DropoutAttrs const &) const; + bool operator!=(DropoutAttrs const &) const; + bool operator<(DropoutAttrs const &) const; + bool operator>(DropoutAttrs const &) const; + bool operator<=(DropoutAttrs const &) const; + bool operator>=(DropoutAttrs const &) const; + float rate; + unsigned long long seed; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::DropoutAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::DropoutAttrs from_json(json const &); + static void to_json(json &, FlexFlow::DropoutAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(DropoutAttrs const &); +std::ostream &operator<<(std::ostream &, DropoutAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_DROPOUT_ATTRS_H diff --git a/lib/op-attrs/include/op-attrs/ops/dropout_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/dropout_attrs.struct.toml new file mode 100644 index 0000000000..8731e0780b --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/dropout_attrs.struct.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "DropoutAttrs" + +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "rate" +type = "float" + +[[fields]] +name = "seed" +type = "unsigned long long" diff --git a/lib/op-attrs/include/op-attrs/ops/element_binary.h b/lib/op-attrs/include/op-attrs/ops/element_binary.h index c4a096166d..b46c66807d 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_binary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_binary.h @@ -2,24 +2,11 @@ #define _FLEXFLOW_ELEMENT_BINARY_ATTRS_H #include "core.h" -#include "op-attrs/datatype.h" -#include "op-attrs/op.h" #include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/element_binary_attrs.h" namespace FlexFlow { -struct ElementBinaryAttrs { - req type; - req compute_type; - req should_broadcast_lhs; - req should_broadcast_rhs; -}; -FF_VISITABLE_STRUCT(ElementBinaryAttrs, - type, - compute_type, - should_broadcast_lhs, - should_broadcast_rhs); CHECK_VALID_OP_ATTR(ElementBinaryAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.h b/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.h new file mode 100644 index 0000000000..df25e3d4d8 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.h @@ -0,0 +1,58 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/element_binary_attrs.struct.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_BINARY_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_BINARY_ATTRS_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/datatype.h" +#include "op-attrs/op.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct ElementBinaryAttrs { + ElementBinaryAttrs() = delete; + ElementBinaryAttrs(::FlexFlow::Op const &type, + ::FlexFlow::DataType const &compute_type, + bool const &should_broadcast_lhs, + bool const &should_broadcast_rhs); + + bool operator==(ElementBinaryAttrs const &) const; + bool operator!=(ElementBinaryAttrs const &) const; + bool operator<(ElementBinaryAttrs const &) const; + bool operator>(ElementBinaryAttrs const &) const; + bool operator<=(ElementBinaryAttrs const &) const; + bool operator>=(ElementBinaryAttrs const &) const; + ::FlexFlow::Op type; + ::FlexFlow::DataType compute_type; + bool should_broadcast_lhs; + bool should_broadcast_rhs; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ElementBinaryAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ElementBinaryAttrs from_json(json const &); + static void to_json(json &, FlexFlow::ElementBinaryAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ElementBinaryAttrs const &); +std::ostream &operator<<(std::ostream &, ElementBinaryAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_BINARY_ATTRS_H diff --git a/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.struct.toml new file mode 100644 index 0000000000..e09f1d551b --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.struct.toml @@ -0,0 +1,32 @@ +namespace = "FlexFlow" +name = "ElementBinaryAttrs" + +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/op.h", + "op-attrs/datatype.h", +] + +[[fields]] +name = "type" +type = "::FlexFlow::Op" + +[[fields]] +name = "compute_type" +type = "::FlexFlow::DataType" + +[[fields]] +name = "should_broadcast_lhs" +type = "bool" + +[[fields]] +name = "should_broadcast_rhs" +type = "bool" diff --git a/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.h b/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.h new file mode 100644 index 0000000000..445b2b7849 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.h @@ -0,0 +1,52 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.struct.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_SCALAR_UNARY_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_SCALAR_UNARY_ATTRS_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/op.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct ElementScalarUnaryAttrs { + ElementScalarUnaryAttrs() = delete; + ElementScalarUnaryAttrs(::FlexFlow::Op const &op_type, float const &scalar); + + bool operator==(ElementScalarUnaryAttrs const &) const; + bool operator!=(ElementScalarUnaryAttrs const &) const; + bool operator<(ElementScalarUnaryAttrs const &) const; + bool operator>(ElementScalarUnaryAttrs const &) const; + bool operator<=(ElementScalarUnaryAttrs const &) const; + bool operator>=(ElementScalarUnaryAttrs const &) const; + ::FlexFlow::Op op_type; + float scalar; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ElementScalarUnaryAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ElementScalarUnaryAttrs from_json(json const &); + static void to_json(json &, FlexFlow::ElementScalarUnaryAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ElementScalarUnaryAttrs const &); +std::ostream &operator<<(std::ostream &, ElementScalarUnaryAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_SCALAR_UNARY_ATTRS_H diff --git a/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.struct.toml new file mode 100644 index 0000000000..3f20ea7a51 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "ElementScalarUnaryAttrs" + +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/op.h" +] + +[[fields]] +name = "op_type" +type = "::FlexFlow::Op" + +[[fields]] +name = "scalar" +type = "float" diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary.h b/lib/op-attrs/include/op-attrs/ops/element_unary.h index 5e19b81c8c..e1c874ed5a 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_unary.h @@ -2,23 +2,13 @@ #define _FLEXFLOW_ELEMENTARY_UNARY_ATTRS_H #include "core.h" -#include "op-attrs/op.h" #include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/element_unary_attrs.h" +#include "op-attrs/ops/element_scalar_unary_attrs.h" namespace FlexFlow { -struct ElementUnaryAttrs { - req op_type; -}; -FF_VISITABLE_STRUCT(ElementUnaryAttrs, op_type); CHECK_VALID_OP_ATTR(ElementUnaryAttrs); - -struct ElementScalarUnaryAttrs { - Op op_type; - req scalar; -}; -FF_VISITABLE_STRUCT(ElementScalarUnaryAttrs, op_type, scalar); CHECK_VALID_OP_ATTR(ElementScalarUnaryAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.h b/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.h new file mode 100644 index 0000000000..163e2824cb --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.h @@ -0,0 +1,51 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_UNARY_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_UNARY_ATTRS_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/op.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct ElementUnaryAttrs { + ElementUnaryAttrs() = delete; + ElementUnaryAttrs(::FlexFlow::Op const &op_type); + + bool operator==(ElementUnaryAttrs const &) const; + bool operator!=(ElementUnaryAttrs const &) const; + bool operator<(ElementUnaryAttrs const &) const; + bool operator>(ElementUnaryAttrs const &) const; + bool operator<=(ElementUnaryAttrs const &) const; + bool operator>=(ElementUnaryAttrs const &) const; + ::FlexFlow::Op op_type; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ElementUnaryAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ElementUnaryAttrs from_json(json const &); + static void to_json(json &, FlexFlow::ElementUnaryAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ElementUnaryAttrs const &); +std::ostream &operator<<(std::ostream &, ElementUnaryAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_UNARY_ATTRS_H diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml new file mode 100644 index 0000000000..60ec81cd66 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "ElementUnaryAttrs" + +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/op.h" +] + +[[fields]] +name = "op_type" +type = "::FlexFlow::Op" diff --git a/lib/op-attrs/include/op-attrs/ops/embedding.h b/lib/op-attrs/include/op-attrs/ops/embedding.h index 8b00fa22ce..8ad95a7fca 100644 --- a/lib/op-attrs/include/op-attrs/ops/embedding.h +++ b/lib/op-attrs/include/op-attrs/ops/embedding.h @@ -2,53 +2,16 @@ #define _FLEXFLOW_EMBEDDING_ATTRS_H #include "core.h" -#include "op-attrs/datatype.h" #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/tensor_shape.h" -#include "utils/fmt.h" -#include "utils/visitable.h" +#include "op-attrs/ops/embedding_attrs.h" namespace FlexFlow { -enum class AggregateOp { - SUM, - AVG, -}; - -struct EmbeddingAttrs { - req num_entries, out_channels; - req aggr; - req data_type; -}; -FF_VISITABLE_STRUCT(EmbeddingAttrs, num_entries, out_channels, aggr, data_type); CHECK_VALID_OP_ATTR(EmbeddingAttrs); TensorShape get_weights_shape(EmbeddingAttrs const &, TensorShape const &); } // namespace FlexFlow -namespace fmt { - -template <> -struct formatter<::FlexFlow::AggregateOp> : formatter { - template - auto format(::FlexFlow::AggregateOp o, FormatContext &ctx) const - -> decltype(ctx.out()) { - using namespace FlexFlow; - - string_view name = "unknown"; - switch (o) { - case AggregateOp::SUM: - name = "Sum"; - break; - case AggregateOp::AVG: - name = "Avg"; - break; - } - return formatter::format(name, ctx); - } -}; - -} // namespace fmt - #endif diff --git a/lib/op-attrs/include/op-attrs/ops/embedding_attrs.h b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.h new file mode 100644 index 0000000000..3c72b4c12f --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.h @@ -0,0 +1,59 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_EMBEDDING_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_EMBEDDING_ATTRS_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/aggregate_op.h" +#include "op-attrs/datatype.h" +#include "utils/stack_vector.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct EmbeddingAttrs { + EmbeddingAttrs() = delete; + EmbeddingAttrs(int const &num_entries, + int const &out_channels, + ::FlexFlow::AggregateOp const &aggr, + ::FlexFlow::DataType const &data_type); + + bool operator==(EmbeddingAttrs const &) const; + bool operator!=(EmbeddingAttrs const &) const; + bool operator<(EmbeddingAttrs const &) const; + bool operator>(EmbeddingAttrs const &) const; + bool operator<=(EmbeddingAttrs const &) const; + bool operator>=(EmbeddingAttrs const &) const; + int num_entries; + int out_channels; + ::FlexFlow::AggregateOp aggr; + ::FlexFlow::DataType data_type; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::EmbeddingAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::EmbeddingAttrs from_json(json const &); + static void to_json(json &, FlexFlow::EmbeddingAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(EmbeddingAttrs const &); +std::ostream &operator<<(std::ostream &, EmbeddingAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_EMBEDDING_ATTRS_H diff --git a/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml new file mode 100644 index 0000000000..7985ca694c --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml @@ -0,0 +1,32 @@ +namespace = "FlexFlow" +name = "EmbeddingAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "utils/stack_vector.h", + "op-attrs/aggregate_op.h", + "op-attrs/datatype.h", +] + +[[fields]] +name = "num_entries" +type = "int" + +[[fields]] +name = "out_channels" +type = "int" + +[[fields]] +name = "aggr" +type = "::FlexFlow::AggregateOp" + +[[fields]] +name = "data_type" +type = "::FlexFlow::DataType" diff --git a/lib/op-attrs/include/op-attrs/ops/flat.h b/lib/op-attrs/include/op-attrs/ops/flat.h index 706689199d..dac75e4aa3 100644 --- a/lib/op-attrs/include/op-attrs/ops/flat.h +++ b/lib/op-attrs/include/op-attrs/ops/flat.h @@ -3,12 +3,10 @@ #include "core.h" #include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/flat_attrs.h" namespace FlexFlow { -struct FlatAttrs {}; -FF_VISITABLE_STRUCT(FlatAttrs); CHECK_VALID_OP_ATTR(FlatAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/flat_attrs.h b/lib/op-attrs/include/op-attrs/ops/flat_attrs.h new file mode 100644 index 0000000000..bc7d8f4a62 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/flat_attrs.h @@ -0,0 +1,54 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_FLAT_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_FLAT_ATTRS_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct FlatAttrs { + bool operator==(FlatAttrs const &) const; + bool operator!=(FlatAttrs const &) const; + bool operator<(FlatAttrs const &) const; + bool operator>(FlatAttrs const &) const; + bool operator<=(FlatAttrs const &) const; + bool operator>=(FlatAttrs const &) const; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::FlatAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::FlatAttrs from_json(json const &); + static void to_json(json &, FlexFlow::FlatAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(FlatAttrs const &); +std::ostream &operator<<(std::ostream &, FlatAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_FLAT_ATTRS_H diff --git a/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml new file mode 100644 index 0000000000..e445535e29 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml @@ -0,0 +1,11 @@ +namespace = "FlexFlow" +name = "FlatAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] +fields = [] diff --git a/lib/op-attrs/include/op-attrs/ops/gather.h b/lib/op-attrs/include/op-attrs/ops/gather.h index ca2406ef75..596d266bb4 100644 --- a/lib/op-attrs/include/op-attrs/ops/gather.h +++ b/lib/op-attrs/include/op-attrs/ops/gather.h @@ -2,16 +2,11 @@ #define _FLEXFLOW_GATHER_ATTRS_H #include "core.h" -#include "op-attrs/ff_dim.h" #include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/gather_attrs.h" namespace FlexFlow { -struct GatherAttrs { - ff_dim_t dim; -}; -FF_VISITABLE_STRUCT(GatherAttrs, dim); CHECK_VALID_OP_ATTR(GatherAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/gather_attrs.h b/lib/op-attrs/include/op-attrs/ops/gather_attrs.h new file mode 100644 index 0000000000..a91ad5fb29 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/gather_attrs.h @@ -0,0 +1,51 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_GATHER_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_GATHER_ATTRS_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/ff_dim.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct GatherAttrs { + GatherAttrs() = delete; + GatherAttrs(::FlexFlow::ff_dim_t const &dim); + + bool operator==(GatherAttrs const &) const; + bool operator!=(GatherAttrs const &) const; + bool operator<(GatherAttrs const &) const; + bool operator>(GatherAttrs const &) const; + bool operator<=(GatherAttrs const &) const; + bool operator>=(GatherAttrs const &) const; + ::FlexFlow::ff_dim_t dim; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::GatherAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::GatherAttrs from_json(json const &); + static void to_json(json &, FlexFlow::GatherAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(GatherAttrs const &); +std::ostream &operator<<(std::ostream &, GatherAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_GATHER_ATTRS_H diff --git a/lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml new file mode 100644 index 0000000000..141e41bc24 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "GatherAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ff_dim.h" +] + +[[fields]] +name = "dim" +type = "::FlexFlow::ff_dim_t" diff --git a/lib/op-attrs/include/op-attrs/ops/input.h b/lib/op-attrs/include/op-attrs/ops/input.h index 26c486c9ac..73730c76d3 100644 --- a/lib/op-attrs/include/op-attrs/ops/input.h +++ b/lib/op-attrs/include/op-attrs/ops/input.h @@ -3,11 +3,10 @@ #include "core.h" #include "utils/visitable.h" +#include "op-attrs/ops/input_attrs.h" namespace FlexFlow { -struct InputAttrs {}; -FF_VISITABLE_STRUCT(InputAttrs); CHECK_VALID_OP_ATTR(InputAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/input_attrs.h b/lib/op-attrs/include/op-attrs/ops/input_attrs.h new file mode 100644 index 0000000000..b700799318 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/input_attrs.h @@ -0,0 +1,54 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/input_attrs.struct.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_INPUT_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_INPUT_ATTRS_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct InputAttrs { + bool operator==(InputAttrs const &) const; + bool operator!=(InputAttrs const &) const; + bool operator<(InputAttrs const &) const; + bool operator>(InputAttrs const &) const; + bool operator<=(InputAttrs const &) const; + bool operator>=(InputAttrs const &) const; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::InputAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::InputAttrs from_json(json const &); + static void to_json(json &, FlexFlow::InputAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(InputAttrs const &); +std::ostream &operator<<(std::ostream &, InputAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_INPUT_ATTRS_H diff --git a/lib/op-attrs/include/op-attrs/ops/input_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/input_attrs.struct.toml new file mode 100644 index 0000000000..7e29de78df --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/input_attrs.struct.toml @@ -0,0 +1,11 @@ +namespace = "FlexFlow" +name = "InputAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] +fields = [] diff --git a/lib/op-attrs/include/op-attrs/ops/layer_norm.h b/lib/op-attrs/include/op-attrs/ops/layer_norm.h index dab055b2c9..97fd4990d5 100644 --- a/lib/op-attrs/include/op-attrs/ops/layer_norm.h +++ b/lib/op-attrs/include/op-attrs/ops/layer_norm.h @@ -2,18 +2,11 @@ #define _FLEXFLOW_OP_META_OPS_LAYER_NORM_ATTRS_H #include "core.h" -#include "op-attrs/ff_dim.h" #include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/layer_norm_attrs.h" namespace FlexFlow { -struct LayerNormAttrs { - stack_vector axes; - req elementwise_affine; - req eps; -}; -FF_VISITABLE_STRUCT(LayerNormAttrs, axes, elementwise_affine, eps); CHECK_VALID_OP_ATTR(LayerNormAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.h b/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.h new file mode 100644 index 0000000000..b5839df513 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.h @@ -0,0 +1,57 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LAYER_NORM_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LAYER_NORM_ATTRS_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/ff_dim.h" +#include "utils/stack_vector.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct LayerNormAttrs { + LayerNormAttrs() = delete; + LayerNormAttrs(::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, + MAX_TENSOR_DIM> const &axes, + bool const &elementwise_affine, + float const &eps); + + bool operator==(LayerNormAttrs const &) const; + bool operator!=(LayerNormAttrs const &) const; + bool operator<(LayerNormAttrs const &) const; + bool operator>(LayerNormAttrs const &) const; + bool operator<=(LayerNormAttrs const &) const; + bool operator>=(LayerNormAttrs const &) const; + ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM> axes; + bool elementwise_affine; + float eps; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::LayerNormAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::LayerNormAttrs from_json(json const &); + static void to_json(json &, FlexFlow::LayerNormAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(LayerNormAttrs const &); +std::ostream &operator<<(std::ostream &, LayerNormAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LAYER_NORM_ATTRS_H diff --git a/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml new file mode 100644 index 0000000000..5be7f82256 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "LayerNormAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ff_dim.h", + "utils/stack_vector.h", +] + +[[fields]] +name = "axes" +type = "::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>" + +[[fields]] +name = "elementwise_affine" +type = "bool" + +[[fields]] +name = "eps" +type = "float" diff --git a/lib/op-attrs/include/op-attrs/ops/linear.h b/lib/op-attrs/include/op-attrs/ops/linear.h index 2c27b09f7c..d9ce4d354e 100644 --- a/lib/op-attrs/include/op-attrs/ops/linear.h +++ b/lib/op-attrs/include/op-attrs/ops/linear.h @@ -1,37 +1,12 @@ #ifndef _FLEXFLOW_LINEAR_ATTRS_H #define _FLEXFLOW_LINEAR_ATTRS_H -#include "op-attrs/activation.h" -#include "op-attrs/datatype.h" #include "op-attrs/ops/core.h" #include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/linear_attrs.h" namespace FlexFlow { -struct L1RegularizerAttrs { - req lambda; -}; -FF_VISITABLE_STRUCT(L1RegularizerAttrs, lambda); -CHECK_VALID_OP_ATTR(L1RegularizerAttrs); - -struct L2RegularizerAttrs { - req lambda; -}; -FF_VISITABLE_STRUCT(L2RegularizerAttrs, lambda); -CHECK_VALID_OP_ATTR(L2RegularizerAttrs); - -using RegularizerAttrs = std::variant; - -struct LinearAttrs { - int out_channels; - bool use_bias; - DataType data_type; - Activation activation; - req> regularizer; -}; -FF_VISITABLE_STRUCT( - LinearAttrs, out_channels, use_bias, data_type, activation, regularizer); CHECK_VALID_OP_ATTR(LinearAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/linear_attrs.h b/lib/op-attrs/include/op-attrs/ops/linear_attrs.h new file mode 100644 index 0000000000..1c1fae52c4 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/linear_attrs.h @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LINEAR_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LINEAR_ATTRS_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/activation.h" +#include "op-attrs/datatype.h" +#include "op-attrs/regularizer_attrs.h" +#include "utils/json.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct LinearAttrs { + LinearAttrs() = delete; + LinearAttrs(int const &out_channels, + bool const &use_bias, + ::FlexFlow::DataType const &data_type, + ::FlexFlow::Activation const &activation, + std::optional<::FlexFlow::RegularizerAttrs> const ®ularizer); + + bool operator==(LinearAttrs const &) const; + bool operator!=(LinearAttrs const &) const; + bool operator<(LinearAttrs const &) const; + bool operator>(LinearAttrs const &) const; + bool operator<=(LinearAttrs const &) const; + bool operator>=(LinearAttrs const &) const; + int out_channels; + bool use_bias; + ::FlexFlow::DataType data_type; + ::FlexFlow::Activation activation; + std::optional<::FlexFlow::RegularizerAttrs> regularizer; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::LinearAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::LinearAttrs from_json(json const &); + static void to_json(json &, FlexFlow::LinearAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(LinearAttrs const &); +std::ostream &operator<<(std::ostream &, LinearAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LINEAR_ATTRS_H diff --git a/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml new file mode 100644 index 0000000000..1168276890 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml @@ -0,0 +1,37 @@ +namespace = "FlexFlow" +name = "LinearAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/datatype.h", + "op-attrs/activation.h", + "op-attrs/regularizer_attrs.h", + "utils/json.h", +] + +[[fields]] +name = "out_channels" +type = "int" + +[[fields]] +name = "use_bias" +type = "bool" + +[[fields]] +name = "data_type" +type = "::FlexFlow::DataType" + +[[fields]] +name = "activation" +type = "::FlexFlow::Activation" + +[[fields]] +name = "regularizer" +type = "std::optional<::FlexFlow::RegularizerAttrs>" diff --git a/lib/op-attrs/include/op-attrs/ops/noop.h b/lib/op-attrs/include/op-attrs/ops/noop.h index 658e1b7d98..f5d2090201 100644 --- a/lib/op-attrs/include/op-attrs/ops/noop.h +++ b/lib/op-attrs/include/op-attrs/ops/noop.h @@ -2,12 +2,10 @@ #define _FLEXFLOW_OP_ATTRS_OPS_NOOP_H #include "core.h" -#include "utils/visitable.h" +#include "op-attrs/ops/noop_attrs.h" namespace FlexFlow { -struct NoopAttrs {}; -FF_VISITABLE_STRUCT(NoopAttrs); CHECK_VALID_OP_ATTR(NoopAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/noop_attrs.h b/lib/op-attrs/include/op-attrs/ops/noop_attrs.h new file mode 100644 index 0000000000..35473197d9 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/noop_attrs.h @@ -0,0 +1,54 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/noop_attrs.struct.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_NOOP_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_NOOP_ATTRS_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct NoopAttrs { + bool operator==(NoopAttrs const &) const; + bool operator!=(NoopAttrs const &) const; + bool operator<(NoopAttrs const &) const; + bool operator>(NoopAttrs const &) const; + bool operator<=(NoopAttrs const &) const; + bool operator>=(NoopAttrs const &) const; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::NoopAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::NoopAttrs from_json(json const &); + static void to_json(json &, FlexFlow::NoopAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(NoopAttrs const &); +std::ostream &operator<<(std::ostream &, NoopAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_NOOP_ATTRS_H diff --git a/lib/op-attrs/include/op-attrs/ops/noop_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/noop_attrs.struct.toml new file mode 100644 index 0000000000..3d9202093c --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/noop_attrs.struct.toml @@ -0,0 +1,11 @@ +namespace = "FlexFlow" +name = "NoopAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] +fields = [] diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d.h b/lib/op-attrs/include/op-attrs/ops/pool_2d.h index efe29b3b2e..b766edcdaf 100644 --- a/lib/op-attrs/include/op-attrs/ops/pool_2d.h +++ b/lib/op-attrs/include/op-attrs/ops/pool_2d.h @@ -2,57 +2,13 @@ #define _FLEXFLOW_POOL_2D_ATTRS_H #include "core.h" -#include "op-attrs/activation.h" #include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/pool_2d_attrs.h" namespace FlexFlow { -enum class PoolOp { - MAX, - AVG, -}; - -struct Pool2DAttrs { - req kernel_h, kernel_w, stride_h, stride_w, padding_h, padding_w; - req pool_type; - req activation; -}; -FF_VISITABLE_STRUCT(Pool2DAttrs, - kernel_h, - kernel_w, - stride_h, - stride_w, - padding_h, - padding_w, - pool_type, - activation); CHECK_VALID_OP_ATTR(Pool2DAttrs); } // namespace FlexFlow -namespace fmt { - -template <> -struct formatter<::FlexFlow::PoolOp> : formatter { - template - auto format(::FlexFlow::PoolOp o, FormatContext &ctx) const - -> decltype(ctx.out()) { - using namespace FlexFlow; - - string_view name = "unknown"; - switch (o) { - case PoolOp::AVG: - name = "Avg"; - break; - case PoolOp::MAX: - name = "Max"; - break; - } - return formatter::format(name, ctx); - } -}; - -} // namespace fmt - #endif diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.h b/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.h new file mode 100644 index 0000000000..f8bbfdf320 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.h @@ -0,0 +1,66 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_POOL_2D_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_POOL_2D_ATTRS_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/activation.h" +#include "op-attrs/pool_op.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct Pool2DAttrs { + Pool2DAttrs() = delete; + Pool2DAttrs(int const &kernel_h, + int const &kernel_w, + int const &stride_h, + int const &stride_w, + int const &padding_h, + int const &padding_w, + ::FlexFlow::PoolOp const &pool_type, + ::FlexFlow::Activation const &activation); + + bool operator==(Pool2DAttrs const &) const; + bool operator!=(Pool2DAttrs const &) const; + bool operator<(Pool2DAttrs const &) const; + bool operator>(Pool2DAttrs const &) const; + bool operator<=(Pool2DAttrs const &) const; + bool operator>=(Pool2DAttrs const &) const; + int kernel_h; + int kernel_w; + int stride_h; + int stride_w; + int padding_h; + int padding_w; + ::FlexFlow::PoolOp pool_type; + ::FlexFlow::Activation activation; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::Pool2DAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::Pool2DAttrs from_json(json const &); + static void to_json(json &, FlexFlow::Pool2DAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(Pool2DAttrs const &); +std::ostream &operator<<(std::ostream &, Pool2DAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_POOL_2D_ATTRS_H diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml new file mode 100644 index 0000000000..24e9a814de --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml @@ -0,0 +1,47 @@ +namespace = "FlexFlow" +name = "Pool2DAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/pool_op.h", + "op-attrs/activation.h", +] + +[[fields]] +name = "kernel_h" +type = "int" + +[[fields]] +name = "kernel_w" +type = "int" + +[[fields]] +name = "stride_h" +type = "int" + +[[fields]] +name = "stride_w" +type = "int" + +[[fields]] +name = "padding_h" +type = "int" + +[[fields]] +name = "padding_w" +type = "int" + +[[fields]] +name = "pool_type" +type = "::FlexFlow::PoolOp" + +[[fields]] +name = "activation" +type = "::FlexFlow::Activation" diff --git a/lib/op-attrs/include/op-attrs/ops/reduce.h b/lib/op-attrs/include/op-attrs/ops/reduce.h index 193d3b0dc8..9923bda684 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduce.h +++ b/lib/op-attrs/include/op-attrs/ops/reduce.h @@ -2,20 +2,11 @@ #define _FLEXFLOW_OP_META_OPS_REDUCE_ATTRS_H #include "core.h" -#include "op-attrs/ff_dim.h" -#include "op-attrs/op.h" #include "op-attrs/parallel_tensor_shape.h" -#include "utils/stack_vector.h" -#include "utils/visitable.h" +#include "op-attrs/ops/reduce_attrs.h" namespace FlexFlow { -struct ReduceAttrs { - stack_vector axes; - req op_type; - req keepdims; -}; -FF_VISITABLE_STRUCT(ReduceAttrs, axes, op_type, keepdims); CHECK_VALID_OP_ATTR(ReduceAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/reduce_attrs.h b/lib/op-attrs/include/op-attrs/ops/reduce_attrs.h new file mode 100644 index 0000000000..f1a94788a4 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/reduce_attrs.h @@ -0,0 +1,58 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REDUCE_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REDUCE_ATTRS_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/ff_dim.h" +#include "op-attrs/op.h" +#include "utils/stack_vector.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct ReduceAttrs { + ReduceAttrs() = delete; + ReduceAttrs(::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, + MAX_TENSOR_DIM> const &axes, + ::FlexFlow::Op const &op_type, + bool const &keepdims); + + bool operator==(ReduceAttrs const &) const; + bool operator!=(ReduceAttrs const &) const; + bool operator<(ReduceAttrs const &) const; + bool operator>(ReduceAttrs const &) const; + bool operator<=(ReduceAttrs const &) const; + bool operator>=(ReduceAttrs const &) const; + ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM> axes; + ::FlexFlow::Op op_type; + bool keepdims; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ReduceAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ReduceAttrs from_json(json const &); + static void to_json(json &, FlexFlow::ReduceAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ReduceAttrs const &); +std::ostream &operator<<(std::ostream &, ReduceAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REDUCE_ATTRS_H diff --git a/lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml new file mode 100644 index 0000000000..b5c12c80fe --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "ReduceAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/op.h", + "op-attrs/ff_dim.h", + "utils/stack_vector.h", +] + +[[fields]] +name = "axes" +type = "::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>" + +[[fields]] +name = "op_type" +type = "::FlexFlow::Op" + +[[fields]] +name = "keepdims" +type = "bool" diff --git a/lib/op-attrs/include/op-attrs/ops/reduction.h b/lib/op-attrs/include/op-attrs/ops/reduction.h index f848f879fc..8005d3d64f 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduction.h +++ b/lib/op-attrs/include/op-attrs/ops/reduction.h @@ -2,17 +2,11 @@ #define _FLEXFLOW_REDUCTION_ATTRS_H #include "core.h" -#include "op-attrs/ff_dim.h" #include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/reduction_attrs.h" namespace FlexFlow { -struct ReductionAttrs { - ff_dim_t reduction_dim; - req reduction_degree; -}; -FF_VISITABLE_STRUCT(ReductionAttrs, reduction_dim, reduction_degree); CHECK_VALID_OP_ATTR(ReductionAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/reduction_attrs.h b/lib/op-attrs/include/op-attrs/ops/reduction_attrs.h new file mode 100644 index 0000000000..903cb1d004 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/reduction_attrs.h @@ -0,0 +1,53 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/reduction_attrs.struct.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REDUCTION_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REDUCTION_ATTRS_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/ff_dim.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct ReductionAttrs { + ReductionAttrs() = delete; + ReductionAttrs(::FlexFlow::ff_dim_t const &reduction_dim, + int const &reduction_degree); + + bool operator==(ReductionAttrs const &) const; + bool operator!=(ReductionAttrs const &) const; + bool operator<(ReductionAttrs const &) const; + bool operator>(ReductionAttrs const &) const; + bool operator<=(ReductionAttrs const &) const; + bool operator>=(ReductionAttrs const &) const; + ::FlexFlow::ff_dim_t reduction_dim; + int reduction_degree; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ReductionAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ReductionAttrs from_json(json const &); + static void to_json(json &, FlexFlow::ReductionAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ReductionAttrs const &); +std::ostream &operator<<(std::ostream &, ReductionAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REDUCTION_ATTRS_H diff --git a/lib/op-attrs/include/op-attrs/ops/reduction_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/reduction_attrs.struct.toml new file mode 100644 index 0000000000..f0a5f2e08d --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/reduction_attrs.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "ReductionAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ff_dim.h", +] + +[[fields]] +name = "reduction_dim" +type = "::FlexFlow::ff_dim_t" + +[[fields]] +name = "reduction_degree" +type = "int" diff --git a/lib/op-attrs/include/op-attrs/ops/repartition.h b/lib/op-attrs/include/op-attrs/ops/repartition.h index 83c4ae870b..339b494855 100644 --- a/lib/op-attrs/include/op-attrs/ops/repartition.h +++ b/lib/op-attrs/include/op-attrs/ops/repartition.h @@ -2,17 +2,11 @@ #define _FLEXFLOW_PARTITION_ATTRS_H #include "core.h" -#include "op-attrs/ff_dim.h" #include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/repartition_attrs.h" namespace FlexFlow { -struct RepartitionAttrs { - ff_dim_t repartition_dim; - req repartition_degree; -}; -FF_VISITABLE_STRUCT(RepartitionAttrs, repartition_dim, repartition_degree); CHECK_VALID_OP_ATTR(RepartitionAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/repartition_attrs.h b/lib/op-attrs/include/op-attrs/ops/repartition_attrs.h new file mode 100644 index 0000000000..9ffb03122e --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/repartition_attrs.h @@ -0,0 +1,53 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REPARTITION_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REPARTITION_ATTRS_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/ff_dim.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct RepartitionAttrs { + RepartitionAttrs() = delete; + RepartitionAttrs(::FlexFlow::ff_dim_t const &repartition_dim, + int const &repartition_degree); + + bool operator==(RepartitionAttrs const &) const; + bool operator!=(RepartitionAttrs const &) const; + bool operator<(RepartitionAttrs const &) const; + bool operator>(RepartitionAttrs const &) const; + bool operator<=(RepartitionAttrs const &) const; + bool operator>=(RepartitionAttrs const &) const; + ::FlexFlow::ff_dim_t repartition_dim; + int repartition_degree; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::RepartitionAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::RepartitionAttrs from_json(json const &); + static void to_json(json &, FlexFlow::RepartitionAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(RepartitionAttrs const &); +std::ostream &operator<<(std::ostream &, RepartitionAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REPARTITION_ATTRS_H diff --git a/lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml new file mode 100644 index 0000000000..4fca9e8fb4 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "RepartitionAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ff_dim.h", +] + +[[fields]] +name = "repartition_dim" +type = "::FlexFlow::ff_dim_t" + +[[fields]] +name = "repartition_degree" +type = "int" diff --git a/lib/op-attrs/include/op-attrs/ops/replicate.h b/lib/op-attrs/include/op-attrs/ops/replicate.h index 92e64a4120..bc96c87808 100644 --- a/lib/op-attrs/include/op-attrs/ops/replicate.h +++ b/lib/op-attrs/include/op-attrs/ops/replicate.h @@ -2,17 +2,11 @@ #define _FLEXFLOW_REPLICATE_ATTRS_H #include "core.h" -#include "op-attrs/ff_dim.h" #include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/replicate_attrs.h" namespace FlexFlow { -struct ReplicateAttrs { - ff_dim_t replicate_dim; - req replicate_degree; -}; -FF_VISITABLE_STRUCT(ReplicateAttrs, replicate_dim, replicate_degree); CHECK_VALID_OP_ATTR(ReplicateAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/replicate_attrs.h b/lib/op-attrs/include/op-attrs/ops/replicate_attrs.h new file mode 100644 index 0000000000..ba03f05889 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/replicate_attrs.h @@ -0,0 +1,53 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/replicate_attrs.struct.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REPLICATE_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REPLICATE_ATTRS_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/ff_dim.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct ReplicateAttrs { + ReplicateAttrs() = delete; + ReplicateAttrs(::FlexFlow::ff_dim_t const &replicate_dim, + int const &replicate_degree); + + bool operator==(ReplicateAttrs const &) const; + bool operator!=(ReplicateAttrs const &) const; + bool operator<(ReplicateAttrs const &) const; + bool operator>(ReplicateAttrs const &) const; + bool operator<=(ReplicateAttrs const &) const; + bool operator>=(ReplicateAttrs const &) const; + ::FlexFlow::ff_dim_t replicate_dim; + int replicate_degree; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ReplicateAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ReplicateAttrs from_json(json const &); + static void to_json(json &, FlexFlow::ReplicateAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ReplicateAttrs const &); +std::ostream &operator<<(std::ostream &, ReplicateAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REPLICATE_ATTRS_H diff --git a/lib/op-attrs/include/op-attrs/ops/replicate_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/replicate_attrs.struct.toml new file mode 100644 index 0000000000..d892e2033e --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/replicate_attrs.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "ReplicateAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ff_dim.h", +] + +[[fields]] +name = "replicate_dim" +type = "::FlexFlow::ff_dim_t" + +[[fields]] +name = "replicate_degree" +type = "int" diff --git a/lib/op-attrs/include/op-attrs/ops/reshape.h b/lib/op-attrs/include/op-attrs/ops/reshape.h index b118482a2b..dec4a6fde7 100644 --- a/lib/op-attrs/include/op-attrs/ops/reshape.h +++ b/lib/op-attrs/include/op-attrs/ops/reshape.h @@ -2,15 +2,10 @@ #define _FLEXFLOW_RESHAPE_ATTRS_H #include "core.h" -#include "op-attrs/tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/reshape_attrs.h" namespace FlexFlow { -struct ReshapeAttrs { - TensorShape shape; -}; -FF_VISITABLE_STRUCT(ReshapeAttrs, shape); CHECK_VALID_OP_ATTR(ReshapeAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/reshape_attrs.h b/lib/op-attrs/include/op-attrs/ops/reshape_attrs.h new file mode 100644 index 0000000000..3c326d51e9 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/reshape_attrs.h @@ -0,0 +1,51 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/reshape_attrs.struct.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_RESHAPE_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_RESHAPE_ATTRS_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/tensor_shape.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct ReshapeAttrs { + ReshapeAttrs() = delete; + ReshapeAttrs(::FlexFlow::TensorShape const &shape); + + bool operator==(ReshapeAttrs const &) const; + bool operator!=(ReshapeAttrs const &) const; + bool operator<(ReshapeAttrs const &) const; + bool operator>(ReshapeAttrs const &) const; + bool operator<=(ReshapeAttrs const &) const; + bool operator>=(ReshapeAttrs const &) const; + ::FlexFlow::TensorShape shape; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ReshapeAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ReshapeAttrs from_json(json const &); + static void to_json(json &, FlexFlow::ReshapeAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ReshapeAttrs const &); +std::ostream &operator<<(std::ostream &, ReshapeAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_RESHAPE_ATTRS_H diff --git a/lib/op-attrs/include/op-attrs/ops/reshape_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/reshape_attrs.struct.toml new file mode 100644 index 0000000000..9086cbccae --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/reshape_attrs.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "ReshapeAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/tensor_shape.h", +] + +[[fields]] +name = "shape" +type = "::FlexFlow::TensorShape" diff --git a/lib/op-attrs/include/op-attrs/ops/reverse.h b/lib/op-attrs/include/op-attrs/ops/reverse.h index 6030285f14..2f7243ff58 100644 --- a/lib/op-attrs/include/op-attrs/ops/reverse.h +++ b/lib/op-attrs/include/op-attrs/ops/reverse.h @@ -2,15 +2,10 @@ #define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REVERSE_H #include "core.h" -#include "op-attrs/ff_dim.h" -#include "utils/visitable.h" +#include "op-attrs/ops/reverse_attrs.h" namespace FlexFlow { -struct ReverseAttrs { - ff_dim_t axis; -}; -FF_VISITABLE_STRUCT(ReverseAttrs, axis); CHECK_VALID_OP_ATTR(ReverseAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/reverse_attrs.h b/lib/op-attrs/include/op-attrs/ops/reverse_attrs.h new file mode 100644 index 0000000000..d43363063b --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/reverse_attrs.h @@ -0,0 +1,51 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REVERSE_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REVERSE_ATTRS_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/ff_dim.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct ReverseAttrs { + ReverseAttrs() = delete; + ReverseAttrs(::FlexFlow::ff_dim_t const &axis); + + bool operator==(ReverseAttrs const &) const; + bool operator!=(ReverseAttrs const &) const; + bool operator<(ReverseAttrs const &) const; + bool operator>(ReverseAttrs const &) const; + bool operator<=(ReverseAttrs const &) const; + bool operator>=(ReverseAttrs const &) const; + ::FlexFlow::ff_dim_t axis; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ReverseAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ReverseAttrs from_json(json const &); + static void to_json(json &, FlexFlow::ReverseAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ReverseAttrs const &); +std::ostream &operator<<(std::ostream &, ReverseAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REVERSE_ATTRS_H diff --git a/lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml new file mode 100644 index 0000000000..572b33957e --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "ReverseAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ff_dim.h", +] + +[[fields]] +name = "axis" +type = "::FlexFlow::ff_dim_t" diff --git a/lib/op-attrs/include/op-attrs/ops/softmax.h b/lib/op-attrs/include/op-attrs/ops/softmax.h index 9a776737f5..272610a77c 100644 --- a/lib/op-attrs/include/op-attrs/ops/softmax.h +++ b/lib/op-attrs/include/op-attrs/ops/softmax.h @@ -2,16 +2,11 @@ #define _FLEXFLOW_SOFTMAX_ATTRS_H #include "core.h" -#include "op-attrs/ff_dim.h" #include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/softmax_attrs.h" namespace FlexFlow { -struct SoftmaxAttrs { - ff_dim_t dim; -}; -FF_VISITABLE_STRUCT(SoftmaxAttrs, dim); CHECK_VALID_OP_ATTR(SoftmaxAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/softmax_attrs.h b/lib/op-attrs/include/op-attrs/ops/softmax_attrs.h new file mode 100644 index 0000000000..3e467d6cc1 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/softmax_attrs.h @@ -0,0 +1,51 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_SOFTMAX_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_SOFTMAX_ATTRS_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/ff_dim.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct SoftmaxAttrs { + SoftmaxAttrs() = delete; + SoftmaxAttrs(::FlexFlow::ff_dim_t const &dim); + + bool operator==(SoftmaxAttrs const &) const; + bool operator!=(SoftmaxAttrs const &) const; + bool operator<(SoftmaxAttrs const &) const; + bool operator>(SoftmaxAttrs const &) const; + bool operator<=(SoftmaxAttrs const &) const; + bool operator>=(SoftmaxAttrs const &) const; + ::FlexFlow::ff_dim_t dim; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::SoftmaxAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::SoftmaxAttrs from_json(json const &); + static void to_json(json &, FlexFlow::SoftmaxAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(SoftmaxAttrs const &); +std::ostream &operator<<(std::ostream &, SoftmaxAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_SOFTMAX_ATTRS_H diff --git a/lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml new file mode 100644 index 0000000000..380fdc12d3 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "SoftmaxAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ff_dim.h", +] + +[[fields]] +name = "dim" +type = "::FlexFlow::ff_dim_t" diff --git a/lib/op-attrs/include/op-attrs/ops/split.h b/lib/op-attrs/include/op-attrs/ops/split.h index fa66bc46f5..7bd2b0ff1a 100644 --- a/lib/op-attrs/include/op-attrs/ops/split.h +++ b/lib/op-attrs/include/op-attrs/ops/split.h @@ -2,17 +2,11 @@ #define _FLEXFLOW_SPLIT_ATTRS_H #include "core.h" -#include "op-attrs/ff_dim.h" #include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/split_attrs.h" namespace FlexFlow { -struct SplitAttrs { - req> splits; - ff_dim_t axis; -}; -FF_VISITABLE_STRUCT(SplitAttrs, splits, axis); CHECK_VALID_OP_ATTR(SplitAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/split_attrs.h b/lib/op-attrs/include/op-attrs/ops/split_attrs.h new file mode 100644 index 0000000000..6edefae88d --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/split_attrs.h @@ -0,0 +1,54 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_SPLIT_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_SPLIT_ATTRS_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/ff_dim.h" +#include "utils/stack_vector.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct SplitAttrs { + SplitAttrs() = delete; + SplitAttrs(::FlexFlow::stack_vector const &splits, + ::FlexFlow::ff_dim_t const &axis); + + bool operator==(SplitAttrs const &) const; + bool operator!=(SplitAttrs const &) const; + bool operator<(SplitAttrs const &) const; + bool operator>(SplitAttrs const &) const; + bool operator<=(SplitAttrs const &) const; + bool operator>=(SplitAttrs const &) const; + ::FlexFlow::stack_vector splits; + ::FlexFlow::ff_dim_t axis; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::SplitAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::SplitAttrs from_json(json const &); + static void to_json(json &, FlexFlow::SplitAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(SplitAttrs const &); +std::ostream &operator<<(std::ostream &, SplitAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_SPLIT_ATTRS_H diff --git a/lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml new file mode 100644 index 0000000000..ca73d4d2a1 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "SplitAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "utils/stack_vector.h", + "op-attrs/ff_dim.h", +] + +[[fields]] +name = "splits" +type = "::FlexFlow::stack_vector" + +[[fields]] +name = "axis" +type = "::FlexFlow::ff_dim_t" diff --git a/lib/op-attrs/include/op-attrs/ops/topk.h b/lib/op-attrs/include/op-attrs/ops/topk.h index 413855913c..b059e5071f 100644 --- a/lib/op-attrs/include/op-attrs/ops/topk.h +++ b/lib/op-attrs/include/op-attrs/ops/topk.h @@ -3,15 +3,10 @@ #include "core.h" #include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/topk_attrs.h" namespace FlexFlow { -struct TopKAttrs { - req k; - req sorted; -}; -FF_VISITABLE_STRUCT(TopKAttrs, k, sorted); CHECK_VALID_OP_ATTR(TopKAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/topk_attrs.h b/lib/op-attrs/include/op-attrs/ops/topk_attrs.h new file mode 100644 index 0000000000..4d12a3fa41 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/topk_attrs.h @@ -0,0 +1,59 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/topk_attrs.struct.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_TOPK_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_TOPK_ATTRS_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct TopKAttrs { + TopKAttrs() = delete; + TopKAttrs(int const &k, bool const &sorted); + + bool operator==(TopKAttrs const &) const; + bool operator!=(TopKAttrs const &) const; + bool operator<(TopKAttrs const &) const; + bool operator>(TopKAttrs const &) const; + bool operator<=(TopKAttrs const &) const; + bool operator>=(TopKAttrs const &) const; + int k; + bool sorted; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::TopKAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::TopKAttrs from_json(json const &); + static void to_json(json &, FlexFlow::TopKAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(TopKAttrs const &); +std::ostream &operator<<(std::ostream &, TopKAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_TOPK_ATTRS_H diff --git a/lib/op-attrs/include/op-attrs/ops/topk_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/topk_attrs.struct.toml new file mode 100644 index 0000000000..9ecbf1d725 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/topk_attrs.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "TopKAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "k" +type = "int" + +[[fields]] +name = "sorted" +type = "bool" diff --git a/lib/op-attrs/include/op-attrs/ops/transpose.h b/lib/op-attrs/include/op-attrs/ops/transpose.h index 87db435979..847f660f1a 100644 --- a/lib/op-attrs/include/op-attrs/ops/transpose.h +++ b/lib/op-attrs/include/op-attrs/ops/transpose.h @@ -2,16 +2,11 @@ #define _FLEXFLOW_OP_META_OPS_TRANSPOSE_ATTRS_H #include "core.h" -#include "op-attrs/ff_dim.h" #include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/transpose_attrs.h" namespace FlexFlow { -struct TransposeAttrs { - req> perm; -}; -FF_VISITABLE_STRUCT(TransposeAttrs, perm); CHECK_VALID_OP_ATTR(TransposeAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/transpose_attrs.h b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.h new file mode 100644 index 0000000000..0613561a3b --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.h @@ -0,0 +1,53 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_TRANSPOSE_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_TRANSPOSE_ATTRS_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/ff_dim.h" +#include "utils/stack_vector.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct TransposeAttrs { + TransposeAttrs() = delete; + TransposeAttrs(::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, + MAX_TENSOR_DIM> const &perm); + + bool operator==(TransposeAttrs const &) const; + bool operator!=(TransposeAttrs const &) const; + bool operator<(TransposeAttrs const &) const; + bool operator>(TransposeAttrs const &) const; + bool operator<=(TransposeAttrs const &) const; + bool operator>=(TransposeAttrs const &) const; + ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM> perm; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::TransposeAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::TransposeAttrs from_json(json const &); + static void to_json(json &, FlexFlow::TransposeAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(TransposeAttrs const &); +std::ostream &operator<<(std::ostream &, TransposeAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_TRANSPOSE_ATTRS_H diff --git a/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml new file mode 100644 index 0000000000..35646cda40 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "TransposeAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ff_dim.h", + "utils/stack_vector.h", +] + +[[fields]] +name = "perm" +type = "::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>" diff --git a/lib/op-attrs/include/op-attrs/pool_op.h b/lib/op-attrs/include/op-attrs/pool_op.h new file mode 100644 index 0000000000..6006702156 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/pool_op.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_POOL_OP_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_POOL_OP_H + +#include "utils/fmt.h" +#include "nlohmann/json.hpp" + +namespace FlexFlow { + +enum class PoolOp { + MAX, + AVG, +}; + +NLOHMANN_JSON_SERIALIZE_ENUM(PoolOp, + {{PoolOp::MAX, "MAX"}, + {PoolOp::AVG, "AVG"}}); + +std::string format_as(PoolOp); + +} // namespace FlexFlow +#endif diff --git a/lib/op-attrs/include/op-attrs/regularizer_attrs.h b/lib/op-attrs/include/op-attrs/regularizer_attrs.h new file mode 100644 index 0000000000..73c53d4f4b --- /dev/null +++ b/lib/op-attrs/include/op-attrs/regularizer_attrs.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REGULARIZER_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REGULARIZER_ATTRS_H + +#include "op-attrs/l1_regularizer_attrs.h" +#include "op-attrs/l2_regularizer_attrs.h" + +namespace FlexFlow { + +using RegularizerAttrs = std::variant; + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/src/op-attrs/aggregate_op.cc b/lib/op-attrs/src/op-attrs/aggregate_op.cc new file mode 100644 index 0000000000..1d0bfa9f55 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/aggregate_op.cc @@ -0,0 +1,17 @@ +#include "op-attrs/aggregate_op.h" +#include "utils/exception.h" + +namespace FlexFlow { + +std::string format_as(AggregateOp o) { + switch (o) { + case AggregateOp::SUM: + return "SUM"; + case AggregateOp::AVG: + return "AVG"; + default: + throw mk_runtime_error(fmt::format("Unknown aggregate op {}", static_cast(o))); + } +} + +} diff --git a/lib/op-attrs/src/op-attrs/ff_dim.cc b/lib/op-attrs/src/op-attrs/ff_dim.cc new file mode 100644 index 0000000000..7c3c3293df --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ff_dim.cc @@ -0,0 +1,61 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ff_dim.struct.toml + +#include "op-attrs/ff_dim.h" + +namespace FlexFlow { +ff_dim_t::ff_dim_t(int const &value) : value(value) {} +bool ff_dim_t::operator==(ff_dim_t const &other) const { + return std::tie(this->value) == std::tie(other.value); +} +bool ff_dim_t::operator!=(ff_dim_t const &other) const { + return std::tie(this->value) != std::tie(other.value); +} +bool ff_dim_t::operator<(ff_dim_t const &other) const { + return std::tie(this->value) < std::tie(other.value); +} +bool ff_dim_t::operator>(ff_dim_t const &other) const { + return std::tie(this->value) > std::tie(other.value); +} +bool ff_dim_t::operator<=(ff_dim_t const &other) const { + return std::tie(this->value) <= std::tie(other.value); +} +bool ff_dim_t::operator>=(ff_dim_t const &other) const { + return std::tie(this->value) >= std::tie(other.value); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()(FlexFlow::ff_dim_t const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.value) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ff_dim_t + adl_serializer::from_json(json const &j) { + return {j.at("value").template get()}; +} +void adl_serializer::to_json(json &j, + FlexFlow::ff_dim_t const &v) { + j["__type"] = "ff_dim_t"; + j["value"] = v.value; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ff_dim_t const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ff_dim_t const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/l1_regularizer_attrs.cc b/lib/op-attrs/src/op-attrs/l1_regularizer_attrs.cc new file mode 100644 index 0000000000..e5a75b1201 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/l1_regularizer_attrs.cc @@ -0,0 +1,69 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/l1_regularizer_attrs.struct.toml + +#include "op-attrs/l1_regularizer_attrs.h" + +namespace FlexFlow { +L1RegularizerAttrs::L1RegularizerAttrs(float const &lambda) : lambda(lambda) {} +bool L1RegularizerAttrs::operator==(L1RegularizerAttrs const &other) const { + return std::tie(this->lambda) == std::tie(other.lambda); +} +bool L1RegularizerAttrs::operator!=(L1RegularizerAttrs const &other) const { + return std::tie(this->lambda) != std::tie(other.lambda); +} +bool L1RegularizerAttrs::operator<(L1RegularizerAttrs const &other) const { + return std::tie(this->lambda) < std::tie(other.lambda); +} +bool L1RegularizerAttrs::operator>(L1RegularizerAttrs const &other) const { + return std::tie(this->lambda) > std::tie(other.lambda); +} +bool L1RegularizerAttrs::operator<=(L1RegularizerAttrs const &other) const { + return std::tie(this->lambda) <= std::tie(other.lambda); +} +bool L1RegularizerAttrs::operator>=(L1RegularizerAttrs const &other) const { + return std::tie(this->lambda) >= std::tie(other.lambda); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::L1RegularizerAttrs const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.lambda) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::L1RegularizerAttrs + adl_serializer::from_json(json const &j) { + return {j.at("lambda").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::L1RegularizerAttrs const &v) { + j["__type"] = "L1RegularizerAttrs"; + j["lambda"] = v.lambda; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(L1RegularizerAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, L1RegularizerAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/l2_regularizer_attrs.cc b/lib/op-attrs/src/op-attrs/l2_regularizer_attrs.cc new file mode 100644 index 0000000000..45e48044ee --- /dev/null +++ b/lib/op-attrs/src/op-attrs/l2_regularizer_attrs.cc @@ -0,0 +1,69 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/l2_regularizer_attrs.struct.toml + +#include "op-attrs/l2_regularizer_attrs.h" + +namespace FlexFlow { +L2RegularizerAttrs::L2RegularizerAttrs(float const &lambda) : lambda(lambda) {} +bool L2RegularizerAttrs::operator==(L2RegularizerAttrs const &other) const { + return std::tie(this->lambda) == std::tie(other.lambda); +} +bool L2RegularizerAttrs::operator!=(L2RegularizerAttrs const &other) const { + return std::tie(this->lambda) != std::tie(other.lambda); +} +bool L2RegularizerAttrs::operator<(L2RegularizerAttrs const &other) const { + return std::tie(this->lambda) < std::tie(other.lambda); +} +bool L2RegularizerAttrs::operator>(L2RegularizerAttrs const &other) const { + return std::tie(this->lambda) > std::tie(other.lambda); +} +bool L2RegularizerAttrs::operator<=(L2RegularizerAttrs const &other) const { + return std::tie(this->lambda) <= std::tie(other.lambda); +} +bool L2RegularizerAttrs::operator>=(L2RegularizerAttrs const &other) const { + return std::tie(this->lambda) >= std::tie(other.lambda); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::L2RegularizerAttrs const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.lambda) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::L2RegularizerAttrs + adl_serializer::from_json(json const &j) { + return {j.at("lambda").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::L2RegularizerAttrs const &v) { + j["__type"] = "L2RegularizerAttrs"; + j["lambda"] = v.lambda; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(L2RegularizerAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, L2RegularizerAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/attention_attrs.cc b/lib/op-attrs/src/op-attrs/ops/attention_attrs.cc new file mode 100644 index 0000000000..728359ff25 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/attention_attrs.cc @@ -0,0 +1,213 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/attention_attrs.struct.toml + +#include "op-attrs/ops/attention_attrs.h" + +namespace FlexFlow { +MultiHeadAttentionAttrs::MultiHeadAttentionAttrs(int const &embed_dim, + int const &num_heads, + int const &kdim, + int const &vdim, + float const &dropout, + bool const &bias, + bool const &add_bias_kv, + bool const &add_zero_attn) + : embed_dim(embed_dim), num_heads(num_heads), kdim(kdim), vdim(vdim), + dropout(dropout), bias(bias), add_bias_kv(add_bias_kv), + add_zero_attn(add_zero_attn) {} +bool MultiHeadAttentionAttrs::operator==( + MultiHeadAttentionAttrs const &other) const { + return std::tie(this->embed_dim, + this->num_heads, + this->kdim, + this->vdim, + this->dropout, + this->bias, + this->add_bias_kv, + this->add_zero_attn) == std::tie(other.embed_dim, + other.num_heads, + other.kdim, + other.vdim, + other.dropout, + other.bias, + other.add_bias_kv, + other.add_zero_attn); +} +bool MultiHeadAttentionAttrs::operator!=( + MultiHeadAttentionAttrs const &other) const { + return std::tie(this->embed_dim, + this->num_heads, + this->kdim, + this->vdim, + this->dropout, + this->bias, + this->add_bias_kv, + this->add_zero_attn) != std::tie(other.embed_dim, + other.num_heads, + other.kdim, + other.vdim, + other.dropout, + other.bias, + other.add_bias_kv, + other.add_zero_attn); +} +bool MultiHeadAttentionAttrs::operator<( + MultiHeadAttentionAttrs const &other) const { + return std::tie(this->embed_dim, + this->num_heads, + this->kdim, + this->vdim, + this->dropout, + this->bias, + this->add_bias_kv, + this->add_zero_attn) < std::tie(other.embed_dim, + other.num_heads, + other.kdim, + other.vdim, + other.dropout, + other.bias, + other.add_bias_kv, + other.add_zero_attn); +} +bool MultiHeadAttentionAttrs::operator>( + MultiHeadAttentionAttrs const &other) const { + return std::tie(this->embed_dim, + this->num_heads, + this->kdim, + this->vdim, + this->dropout, + this->bias, + this->add_bias_kv, + this->add_zero_attn) > std::tie(other.embed_dim, + other.num_heads, + other.kdim, + other.vdim, + other.dropout, + other.bias, + other.add_bias_kv, + other.add_zero_attn); +} +bool MultiHeadAttentionAttrs::operator<=( + MultiHeadAttentionAttrs const &other) const { + return std::tie(this->embed_dim, + this->num_heads, + this->kdim, + this->vdim, + this->dropout, + this->bias, + this->add_bias_kv, + this->add_zero_attn) <= std::tie(other.embed_dim, + other.num_heads, + other.kdim, + other.vdim, + other.dropout, + other.bias, + other.add_bias_kv, + other.add_zero_attn); +} +bool MultiHeadAttentionAttrs::operator>=( + MultiHeadAttentionAttrs const &other) const { + return std::tie(this->embed_dim, + this->num_heads, + this->kdim, + this->vdim, + this->dropout, + this->bias, + this->add_bias_kv, + this->add_zero_attn) >= std::tie(other.embed_dim, + other.num_heads, + other.kdim, + other.vdim, + other.dropout, + other.bias, + other.add_bias_kv, + other.add_zero_attn); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::MultiHeadAttentionAttrs const &x) const { + size_t result = 0; + result ^= std::hash{}(x.embed_dim) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.num_heads) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= + std::hash{}(x.kdim) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.vdim) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash{}(x.dropout) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= + std::hash{}(x.bias) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash{}(x.add_bias_kv) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.add_zero_attn) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::MultiHeadAttentionAttrs + adl_serializer::from_json( + json const &j) { + return {j.at("embed_dim").template get(), + j.at("num_heads").template get(), + j.at("kdim").template get(), + j.at("vdim").template get(), + j.at("dropout").template get(), + j.at("bias").template get(), + j.at("add_bias_kv").template get(), + j.at("add_zero_attn").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::MultiHeadAttentionAttrs const &v) { + j["__type"] = "MultiHeadAttentionAttrs"; + j["embed_dim"] = v.embed_dim; + j["num_heads"] = v.num_heads; + j["kdim"] = v.kdim; + j["vdim"] = v.vdim; + j["dropout"] = v.dropout; + j["bias"] = v.bias; + j["add_bias_kv"] = v.add_bias_kv; + j["add_zero_attn"] = v.add_zero_attn; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(MultiHeadAttentionAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, MultiHeadAttentionAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc b/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc new file mode 100644 index 0000000000..157654fa53 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc @@ -0,0 +1,83 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/batch_matmul.struct.toml + +#include "op-attrs/ops/batch_matmul.h" + +namespace FlexFlow { +BatchMatmulAttrs::BatchMatmulAttrs(int const &a_seq_length_dim, + int const &b_seq_length_dim) + : a_seq_length_dim(a_seq_length_dim), b_seq_length_dim(b_seq_length_dim) {} +bool BatchMatmulAttrs::operator==(BatchMatmulAttrs const &other) const { + return std::tie(this->a_seq_length_dim, this->b_seq_length_dim) == + std::tie(other.a_seq_length_dim, other.b_seq_length_dim); +} +bool BatchMatmulAttrs::operator!=(BatchMatmulAttrs const &other) const { + return std::tie(this->a_seq_length_dim, this->b_seq_length_dim) != + std::tie(other.a_seq_length_dim, other.b_seq_length_dim); +} +bool BatchMatmulAttrs::operator<(BatchMatmulAttrs const &other) const { + return std::tie(this->a_seq_length_dim, this->b_seq_length_dim) < + std::tie(other.a_seq_length_dim, other.b_seq_length_dim); +} +bool BatchMatmulAttrs::operator>(BatchMatmulAttrs const &other) const { + return std::tie(this->a_seq_length_dim, this->b_seq_length_dim) > + std::tie(other.a_seq_length_dim, other.b_seq_length_dim); +} +bool BatchMatmulAttrs::operator<=(BatchMatmulAttrs const &other) const { + return std::tie(this->a_seq_length_dim, this->b_seq_length_dim) <= + std::tie(other.a_seq_length_dim, other.b_seq_length_dim); +} +bool BatchMatmulAttrs::operator>=(BatchMatmulAttrs const &other) const { + return std::tie(this->a_seq_length_dim, this->b_seq_length_dim) >= + std::tie(other.a_seq_length_dim, other.b_seq_length_dim); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::BatchMatmulAttrs const &x) const { + size_t result = 0; + result ^= std::hash{}(x.a_seq_length_dim) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.b_seq_length_dim) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::BatchMatmulAttrs + adl_serializer::from_json(json const &j) { + return {j.at("a_seq_length_dim").template get(), + j.at("b_seq_length_dim").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::BatchMatmulAttrs const &v) { + j["__type"] = "BatchMatmulAttrs"; + j["a_seq_length_dim"] = v.a_seq_length_dim; + j["b_seq_length_dim"] = v.b_seq_length_dim; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary(), + gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(BatchMatmulAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, BatchMatmulAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/batch_norm_attrs.cc b/lib/op-attrs/src/op-attrs/ops/batch_norm_attrs.cc new file mode 100644 index 0000000000..84ae40115d --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/batch_norm_attrs.cc @@ -0,0 +1,68 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml + +#include "op-attrs/ops/batch_norm_attrs.h" + +namespace FlexFlow { +BatchNormAttrs::BatchNormAttrs(bool const &relu) : relu(relu) {} +bool BatchNormAttrs::operator==(BatchNormAttrs const &other) const { + return std::tie(this->relu) == std::tie(other.relu); +} +bool BatchNormAttrs::operator!=(BatchNormAttrs const &other) const { + return std::tie(this->relu) != std::tie(other.relu); +} +bool BatchNormAttrs::operator<(BatchNormAttrs const &other) const { + return std::tie(this->relu) < std::tie(other.relu); +} +bool BatchNormAttrs::operator>(BatchNormAttrs const &other) const { + return std::tie(this->relu) > std::tie(other.relu); +} +bool BatchNormAttrs::operator<=(BatchNormAttrs const &other) const { + return std::tie(this->relu) <= std::tie(other.relu); +} +bool BatchNormAttrs::operator>=(BatchNormAttrs const &other) const { + return std::tie(this->relu) >= std::tie(other.relu); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::BatchNormAttrs const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.relu) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::BatchNormAttrs + adl_serializer::from_json(json const &j) { + return {j.at("relu").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::BatchNormAttrs const &v) { + j["__type"] = "BatchNormAttrs"; + j["relu"] = v.relu; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(BatchNormAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, BatchNormAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/broadcast.cc b/lib/op-attrs/src/op-attrs/ops/broadcast.cc index 2553ac5801..823df8e01b 100644 --- a/lib/op-attrs/src/op-attrs/ops/broadcast.cc +++ b/lib/op-attrs/src/op-attrs/ops/broadcast.cc @@ -6,7 +6,7 @@ namespace FlexFlow { BroadcastAttrs::BroadcastAttrs( - stack_vector const &target_dims) + ::FlexFlow::stack_vector const &target_dims) : target_dims(target_dims) {} bool BroadcastAttrs::operator==(BroadcastAttrs const &other) const { return std::tie(this->target_dims) == std::tie(other.target_dims); @@ -32,7 +32,8 @@ namespace std { size_t hash::operator()( FlexFlow::BroadcastAttrs const &x) const { size_t result = 0; - result ^= std::hash>{}(x.target_dims) + + result ^= std::hash<::FlexFlow::stack_vector>{}( + x.target_dims) + 0x9e3779b9 + (result << 6) + (result >> 2); return result; } @@ -41,8 +42,8 @@ size_t hash::operator()( namespace nlohmann { FlexFlow::BroadcastAttrs adl_serializer::from_json(json const &j) { - return { - j.at("target_dims").template get>()}; + return {j.at("target_dims") + .template get<::FlexFlow::stack_vector>()}; } void adl_serializer::to_json( json &j, FlexFlow::BroadcastAttrs const &v) { @@ -50,3 +51,16 @@ void adl_serializer::to_json( j["target_dims"] = v.target_dims; } } // namespace nlohmann + +namespace FlexFlow { +std::string format_as(BroadcastAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, BroadcastAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/cast_attrs.cc b/lib/op-attrs/src/op-attrs/ops/cast_attrs.cc new file mode 100644 index 0000000000..8ce883341c --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/cast_attrs.cc @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/cast_attrs.struct.toml + +#include "op-attrs/ops/cast_attrs.h" + +namespace FlexFlow { +CastAttrs::CastAttrs(DataType const &dtype) : dtype(dtype) {} +bool CastAttrs::operator==(CastAttrs const &other) const { + return std::tie(this->dtype) == std::tie(other.dtype); +} +bool CastAttrs::operator!=(CastAttrs const &other) const { + return std::tie(this->dtype) != std::tie(other.dtype); +} +bool CastAttrs::operator<(CastAttrs const &other) const { + return std::tie(this->dtype) < std::tie(other.dtype); +} +bool CastAttrs::operator>(CastAttrs const &other) const { + return std::tie(this->dtype) > std::tie(other.dtype); +} +bool CastAttrs::operator<=(CastAttrs const &other) const { + return std::tie(this->dtype) <= std::tie(other.dtype); +} +bool CastAttrs::operator>=(CastAttrs const &other) const { + return std::tie(this->dtype) >= std::tie(other.dtype); +} +} // namespace FlexFlow + +namespace std { +size_t + hash::operator()(FlexFlow::CastAttrs const &x) const { + size_t result = 0; + result ^= std::hash{}(x.dtype) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::CastAttrs + adl_serializer::from_json(json const &j) { + return {j.at("dtype").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::CastAttrs const &v) { + j["__type"] = "CastAttrs"; + j["dtype"] = v.dtype; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(CastAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, CastAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/combine_attrs.cc b/lib/op-attrs/src/op-attrs/ops/combine_attrs.cc new file mode 100644 index 0000000000..c0cf8e0d08 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/combine_attrs.cc @@ -0,0 +1,75 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml + +#include "op-attrs/ops/combine_attrs.h" + +namespace FlexFlow { +CombineAttrs::CombineAttrs(::FlexFlow::ff_dim_t const &combine_dim, + int const &combine_degree) + : combine_dim(combine_dim), combine_degree(combine_degree) {} +bool CombineAttrs::operator==(CombineAttrs const &other) const { + return std::tie(this->combine_dim, this->combine_degree) == + std::tie(other.combine_dim, other.combine_degree); +} +bool CombineAttrs::operator!=(CombineAttrs const &other) const { + return std::tie(this->combine_dim, this->combine_degree) != + std::tie(other.combine_dim, other.combine_degree); +} +bool CombineAttrs::operator<(CombineAttrs const &other) const { + return std::tie(this->combine_dim, this->combine_degree) < + std::tie(other.combine_dim, other.combine_degree); +} +bool CombineAttrs::operator>(CombineAttrs const &other) const { + return std::tie(this->combine_dim, this->combine_degree) > + std::tie(other.combine_dim, other.combine_degree); +} +bool CombineAttrs::operator<=(CombineAttrs const &other) const { + return std::tie(this->combine_dim, this->combine_degree) <= + std::tie(other.combine_dim, other.combine_degree); +} +bool CombineAttrs::operator>=(CombineAttrs const &other) const { + return std::tie(this->combine_dim, this->combine_degree) >= + std::tie(other.combine_dim, other.combine_degree); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::CombineAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ff_dim_t>{}(x.combine_dim) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash{}(x.combine_degree) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::CombineAttrs + adl_serializer::from_json(json const &j) { + return {j.at("combine_dim").template get<::FlexFlow::ff_dim_t>(), + j.at("combine_degree").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::CombineAttrs const &v) { + j["__type"] = "CombineAttrs"; + j["combine_dim"] = v.combine_dim; + j["combine_degree"] = v.combine_degree; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(CombineAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, CombineAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/concat_attrs.cc b/lib/op-attrs/src/op-attrs/ops/concat_attrs.cc new file mode 100644 index 0000000000..8e16552fa0 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/concat_attrs.cc @@ -0,0 +1,75 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml + +#include "op-attrs/ops/concat_attrs.h" + +namespace FlexFlow { +ConcatAttrs::ConcatAttrs(::FlexFlow::ff_dim_t const &axis, + int const &num_inputs) + : axis(axis), num_inputs(num_inputs) {} +bool ConcatAttrs::operator==(ConcatAttrs const &other) const { + return std::tie(this->axis, this->num_inputs) == + std::tie(other.axis, other.num_inputs); +} +bool ConcatAttrs::operator!=(ConcatAttrs const &other) const { + return std::tie(this->axis, this->num_inputs) != + std::tie(other.axis, other.num_inputs); +} +bool ConcatAttrs::operator<(ConcatAttrs const &other) const { + return std::tie(this->axis, this->num_inputs) < + std::tie(other.axis, other.num_inputs); +} +bool ConcatAttrs::operator>(ConcatAttrs const &other) const { + return std::tie(this->axis, this->num_inputs) > + std::tie(other.axis, other.num_inputs); +} +bool ConcatAttrs::operator<=(ConcatAttrs const &other) const { + return std::tie(this->axis, this->num_inputs) <= + std::tie(other.axis, other.num_inputs); +} +bool ConcatAttrs::operator>=(ConcatAttrs const &other) const { + return std::tie(this->axis, this->num_inputs) >= + std::tie(other.axis, other.num_inputs); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ConcatAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ff_dim_t>{}(x.axis) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash{}(x.num_inputs) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ConcatAttrs + adl_serializer::from_json(json const &j) { + return {j.at("axis").template get<::FlexFlow::ff_dim_t>(), + j.at("num_inputs").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ConcatAttrs const &v) { + j["__type"] = "ConcatAttrs"; + j["axis"] = v.axis; + j["num_inputs"] = v.num_inputs; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ConcatAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ConcatAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d_attrs.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d_attrs.cc new file mode 100644 index 0000000000..5085b3e121 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d_attrs.cc @@ -0,0 +1,230 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml + +#include "op-attrs/ops/conv_2d_attrs.h" + +namespace FlexFlow { +Conv2DAttrs::Conv2DAttrs( + int const &out_channels, + int const &kernel_h, + int const &kernel_w, + int const &stride_h, + int const &stride_w, + int const &padding_h, + int const &padding_w, + int const &groups, + std::optional<::FlexFlow::Activation> const &activation, + bool const &use_bias) + : out_channels(out_channels), kernel_h(kernel_h), kernel_w(kernel_w), + stride_h(stride_h), stride_w(stride_w), padding_h(padding_h), + padding_w(padding_w), groups(groups), activation(activation), + use_bias(use_bias) {} +bool Conv2DAttrs::operator==(Conv2DAttrs const &other) const { + return std::tie(this->out_channels, + this->kernel_h, + this->kernel_w, + this->stride_h, + this->stride_w, + this->padding_h, + this->padding_w, + this->groups, + this->activation, + this->use_bias) == std::tie(other.out_channels, + other.kernel_h, + other.kernel_w, + other.stride_h, + other.stride_w, + other.padding_h, + other.padding_w, + other.groups, + other.activation, + other.use_bias); +} +bool Conv2DAttrs::operator!=(Conv2DAttrs const &other) const { + return std::tie(this->out_channels, + this->kernel_h, + this->kernel_w, + this->stride_h, + this->stride_w, + this->padding_h, + this->padding_w, + this->groups, + this->activation, + this->use_bias) != std::tie(other.out_channels, + other.kernel_h, + other.kernel_w, + other.stride_h, + other.stride_w, + other.padding_h, + other.padding_w, + other.groups, + other.activation, + other.use_bias); +} +bool Conv2DAttrs::operator<(Conv2DAttrs const &other) const { + return std::tie(this->out_channels, + this->kernel_h, + this->kernel_w, + this->stride_h, + this->stride_w, + this->padding_h, + this->padding_w, + this->groups, + this->activation, + this->use_bias) < std::tie(other.out_channels, + other.kernel_h, + other.kernel_w, + other.stride_h, + other.stride_w, + other.padding_h, + other.padding_w, + other.groups, + other.activation, + other.use_bias); +} +bool Conv2DAttrs::operator>(Conv2DAttrs const &other) const { + return std::tie(this->out_channels, + this->kernel_h, + this->kernel_w, + this->stride_h, + this->stride_w, + this->padding_h, + this->padding_w, + this->groups, + this->activation, + this->use_bias) > std::tie(other.out_channels, + other.kernel_h, + other.kernel_w, + other.stride_h, + other.stride_w, + other.padding_h, + other.padding_w, + other.groups, + other.activation, + other.use_bias); +} +bool Conv2DAttrs::operator<=(Conv2DAttrs const &other) const { + return std::tie(this->out_channels, + this->kernel_h, + this->kernel_w, + this->stride_h, + this->stride_w, + this->padding_h, + this->padding_w, + this->groups, + this->activation, + this->use_bias) <= std::tie(other.out_channels, + other.kernel_h, + other.kernel_w, + other.stride_h, + other.stride_w, + other.padding_h, + other.padding_w, + other.groups, + other.activation, + other.use_bias); +} +bool Conv2DAttrs::operator>=(Conv2DAttrs const &other) const { + return std::tie(this->out_channels, + this->kernel_h, + this->kernel_w, + this->stride_h, + this->stride_w, + this->padding_h, + this->padding_w, + this->groups, + this->activation, + this->use_bias) >= std::tie(other.out_channels, + other.kernel_h, + other.kernel_w, + other.stride_h, + other.stride_w, + other.padding_h, + other.padding_w, + other.groups, + other.activation, + other.use_bias); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::Conv2DAttrs const &x) const { + size_t result = 0; + result ^= std::hash{}(x.out_channels) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= + std::hash{}(x.kernel_h) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.kernel_w) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.stride_h) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.stride_w) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash{}(x.padding_h) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.padding_w) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= + std::hash{}(x.groups) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash>{}(x.activation) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash{}(x.use_bias) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::Conv2DAttrs + adl_serializer::from_json(json const &j) { + return { + j.at("out_channels").template get(), + j.at("kernel_h").template get(), + j.at("kernel_w").template get(), + j.at("stride_h").template get(), + j.at("stride_w").template get(), + j.at("padding_h").template get(), + j.at("padding_w").template get(), + j.at("groups").template get(), + j.at("activation").template get>(), + j.at("use_bias").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::Conv2DAttrs const &v) { + j["__type"] = "Conv2DAttrs"; + j["out_channels"] = v.out_channels; + j["kernel_h"] = v.kernel_h; + j["kernel_w"] = v.kernel_w; + j["stride_h"] = v.stride_h; + j["stride_w"] = v.stride_w; + j["padding_h"] = v.padding_h; + j["padding_w"] = v.padding_w; + j["groups"] = v.groups; + j["activation"] = v.activation; + j["use_bias"] = v.use_bias; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(Conv2DAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, Conv2DAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/dropout_attrs.cc b/lib/op-attrs/src/op-attrs/ops/dropout_attrs.cc new file mode 100644 index 0000000000..308725c0b0 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/dropout_attrs.cc @@ -0,0 +1,75 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/dropout_attrs.struct.toml + +#include "op-attrs/ops/dropout_attrs.h" + +namespace FlexFlow { +DropoutAttrs::DropoutAttrs(float const &rate, unsigned long long const &seed) + : rate(rate), seed(seed) {} +bool DropoutAttrs::operator==(DropoutAttrs const &other) const { + return std::tie(this->rate, this->seed) == std::tie(other.rate, other.seed); +} +bool DropoutAttrs::operator!=(DropoutAttrs const &other) const { + return std::tie(this->rate, this->seed) != std::tie(other.rate, other.seed); +} +bool DropoutAttrs::operator<(DropoutAttrs const &other) const { + return std::tie(this->rate, this->seed) < std::tie(other.rate, other.seed); +} +bool DropoutAttrs::operator>(DropoutAttrs const &other) const { + return std::tie(this->rate, this->seed) > std::tie(other.rate, other.seed); +} +bool DropoutAttrs::operator<=(DropoutAttrs const &other) const { + return std::tie(this->rate, this->seed) <= std::tie(other.rate, other.seed); +} +bool DropoutAttrs::operator>=(DropoutAttrs const &other) const { + return std::tie(this->rate, this->seed) >= std::tie(other.rate, other.seed); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::DropoutAttrs const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.rate) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash{}(x.seed) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::DropoutAttrs + adl_serializer::from_json(json const &j) { + return {j.at("rate").template get(), + j.at("seed").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::DropoutAttrs const &v) { + j["__type"] = "DropoutAttrs"; + j["rate"] = v.rate; + j["seed"] = v.seed; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary(), gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(DropoutAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, DropoutAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/element_binary_attrs.cc b/lib/op-attrs/src/op-attrs/ops/element_binary_attrs.cc new file mode 100644 index 0000000000..0c7c523d5b --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/element_binary_attrs.cc @@ -0,0 +1,125 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/element_binary_attrs.struct.toml + +#include "op-attrs/ops/element_binary_attrs.h" + +namespace FlexFlow { +ElementBinaryAttrs::ElementBinaryAttrs(::FlexFlow::Op const &type, + ::FlexFlow::DataType const &compute_type, + bool const &should_broadcast_lhs, + bool const &should_broadcast_rhs) + : type(type), compute_type(compute_type), + should_broadcast_lhs(should_broadcast_lhs), + should_broadcast_rhs(should_broadcast_rhs) {} +bool ElementBinaryAttrs::operator==(ElementBinaryAttrs const &other) const { + return std::tie(this->type, + this->compute_type, + this->should_broadcast_lhs, + this->should_broadcast_rhs) == + std::tie(other.type, + other.compute_type, + other.should_broadcast_lhs, + other.should_broadcast_rhs); +} +bool ElementBinaryAttrs::operator!=(ElementBinaryAttrs const &other) const { + return std::tie(this->type, + this->compute_type, + this->should_broadcast_lhs, + this->should_broadcast_rhs) != + std::tie(other.type, + other.compute_type, + other.should_broadcast_lhs, + other.should_broadcast_rhs); +} +bool ElementBinaryAttrs::operator<(ElementBinaryAttrs const &other) const { + return std::tie(this->type, + this->compute_type, + this->should_broadcast_lhs, + this->should_broadcast_rhs) < + std::tie(other.type, + other.compute_type, + other.should_broadcast_lhs, + other.should_broadcast_rhs); +} +bool ElementBinaryAttrs::operator>(ElementBinaryAttrs const &other) const { + return std::tie(this->type, + this->compute_type, + this->should_broadcast_lhs, + this->should_broadcast_rhs) > + std::tie(other.type, + other.compute_type, + other.should_broadcast_lhs, + other.should_broadcast_rhs); +} +bool ElementBinaryAttrs::operator<=(ElementBinaryAttrs const &other) const { + return std::tie(this->type, + this->compute_type, + this->should_broadcast_lhs, + this->should_broadcast_rhs) <= + std::tie(other.type, + other.compute_type, + other.should_broadcast_lhs, + other.should_broadcast_rhs); +} +bool ElementBinaryAttrs::operator>=(ElementBinaryAttrs const &other) const { + return std::tie(this->type, + this->compute_type, + this->should_broadcast_lhs, + this->should_broadcast_rhs) >= + std::tie(other.type, + other.compute_type, + other.should_broadcast_lhs, + other.should_broadcast_rhs); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ElementBinaryAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::Op>{}(x.type) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash<::FlexFlow::DataType>{}(x.compute_type) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash{}(x.should_broadcast_lhs) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash{}(x.should_broadcast_rhs) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ElementBinaryAttrs + adl_serializer::from_json(json const &j) { + return {j.at("type").template get<::FlexFlow::Op>(), + j.at("compute_type").template get<::FlexFlow::DataType>(), + j.at("should_broadcast_lhs").template get(), + j.at("should_broadcast_rhs").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ElementBinaryAttrs const &v) { + j["__type"] = "ElementBinaryAttrs"; + j["type"] = v.type; + j["compute_type"] = v.compute_type; + j["should_broadcast_lhs"] = v.should_broadcast_lhs; + j["should_broadcast_rhs"] = v.should_broadcast_rhs; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ElementBinaryAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ElementBinaryAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/element_scalar_unary_attrs.cc b/lib/op-attrs/src/op-attrs/ops/element_scalar_unary_attrs.cc new file mode 100644 index 0000000000..8c54874c58 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/element_scalar_unary_attrs.cc @@ -0,0 +1,82 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.struct.toml + +#include "op-attrs/ops/element_scalar_unary_attrs.h" + +namespace FlexFlow { +ElementScalarUnaryAttrs::ElementScalarUnaryAttrs(::FlexFlow::Op const &op_type, + float const &scalar) + : op_type(op_type), scalar(scalar) {} +bool ElementScalarUnaryAttrs::operator==( + ElementScalarUnaryAttrs const &other) const { + return std::tie(this->op_type, this->scalar) == + std::tie(other.op_type, other.scalar); +} +bool ElementScalarUnaryAttrs::operator!=( + ElementScalarUnaryAttrs const &other) const { + return std::tie(this->op_type, this->scalar) != + std::tie(other.op_type, other.scalar); +} +bool ElementScalarUnaryAttrs::operator<( + ElementScalarUnaryAttrs const &other) const { + return std::tie(this->op_type, this->scalar) < + std::tie(other.op_type, other.scalar); +} +bool ElementScalarUnaryAttrs::operator>( + ElementScalarUnaryAttrs const &other) const { + return std::tie(this->op_type, this->scalar) > + std::tie(other.op_type, other.scalar); +} +bool ElementScalarUnaryAttrs::operator<=( + ElementScalarUnaryAttrs const &other) const { + return std::tie(this->op_type, this->scalar) <= + std::tie(other.op_type, other.scalar); +} +bool ElementScalarUnaryAttrs::operator>=( + ElementScalarUnaryAttrs const &other) const { + return std::tie(this->op_type, this->scalar) >= + std::tie(other.op_type, other.scalar); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ElementScalarUnaryAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::Op>{}(x.op_type) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.scalar) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ElementScalarUnaryAttrs + adl_serializer::from_json( + json const &j) { + return {j.at("op_type").template get<::FlexFlow::Op>(), + j.at("scalar").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ElementScalarUnaryAttrs const &v) { + j["__type"] = "ElementScalarUnaryAttrs"; + j["op_type"] = v.op_type; + j["scalar"] = v.scalar; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ElementScalarUnaryAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ElementScalarUnaryAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/element_unary_attrs.cc b/lib/op-attrs/src/op-attrs/ops/element_unary_attrs.cc new file mode 100644 index 0000000000..1123890ac3 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/element_unary_attrs.cc @@ -0,0 +1,63 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml + +#include "op-attrs/ops/element_unary_attrs.h" + +namespace FlexFlow { +ElementUnaryAttrs::ElementUnaryAttrs(::FlexFlow::Op const &op_type) + : op_type(op_type) {} +bool ElementUnaryAttrs::operator==(ElementUnaryAttrs const &other) const { + return std::tie(this->op_type) == std::tie(other.op_type); +} +bool ElementUnaryAttrs::operator!=(ElementUnaryAttrs const &other) const { + return std::tie(this->op_type) != std::tie(other.op_type); +} +bool ElementUnaryAttrs::operator<(ElementUnaryAttrs const &other) const { + return std::tie(this->op_type) < std::tie(other.op_type); +} +bool ElementUnaryAttrs::operator>(ElementUnaryAttrs const &other) const { + return std::tie(this->op_type) > std::tie(other.op_type); +} +bool ElementUnaryAttrs::operator<=(ElementUnaryAttrs const &other) const { + return std::tie(this->op_type) <= std::tie(other.op_type); +} +bool ElementUnaryAttrs::operator>=(ElementUnaryAttrs const &other) const { + return std::tie(this->op_type) >= std::tie(other.op_type); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ElementUnaryAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::Op>{}(x.op_type) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ElementUnaryAttrs + adl_serializer::from_json(json const &j) { + return {j.at("op_type").template get<::FlexFlow::Op>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ElementUnaryAttrs const &v) { + j["__type"] = "ElementUnaryAttrs"; + j["op_type"] = v.op_type; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ElementUnaryAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ElementUnaryAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/embedding_attrs.cc b/lib/op-attrs/src/op-attrs/ops/embedding_attrs.cc new file mode 100644 index 0000000000..6cc49ece0b --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/embedding_attrs.cc @@ -0,0 +1,118 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml + +#include "op-attrs/ops/embedding_attrs.h" + +namespace FlexFlow { +EmbeddingAttrs::EmbeddingAttrs(int const &num_entries, + int const &out_channels, + ::FlexFlow::AggregateOp const &aggr, + ::FlexFlow::DataType const &data_type) + : num_entries(num_entries), out_channels(out_channels), aggr(aggr), + data_type(data_type) {} +bool EmbeddingAttrs::operator==(EmbeddingAttrs const &other) const { + return std::tie(this->num_entries, + this->out_channels, + this->aggr, + this->data_type) == std::tie(other.num_entries, + other.out_channels, + other.aggr, + other.data_type); +} +bool EmbeddingAttrs::operator!=(EmbeddingAttrs const &other) const { + return std::tie(this->num_entries, + this->out_channels, + this->aggr, + this->data_type) != std::tie(other.num_entries, + other.out_channels, + other.aggr, + other.data_type); +} +bool EmbeddingAttrs::operator<(EmbeddingAttrs const &other) const { + return std::tie(this->num_entries, + this->out_channels, + this->aggr, + this->data_type) < std::tie(other.num_entries, + other.out_channels, + other.aggr, + other.data_type); +} +bool EmbeddingAttrs::operator>(EmbeddingAttrs const &other) const { + return std::tie(this->num_entries, + this->out_channels, + this->aggr, + this->data_type) > std::tie(other.num_entries, + other.out_channels, + other.aggr, + other.data_type); +} +bool EmbeddingAttrs::operator<=(EmbeddingAttrs const &other) const { + return std::tie(this->num_entries, + this->out_channels, + this->aggr, + this->data_type) <= std::tie(other.num_entries, + other.out_channels, + other.aggr, + other.data_type); +} +bool EmbeddingAttrs::operator>=(EmbeddingAttrs const &other) const { + return std::tie(this->num_entries, + this->out_channels, + this->aggr, + this->data_type) >= std::tie(other.num_entries, + other.out_channels, + other.aggr, + other.data_type); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::EmbeddingAttrs const &x) const { + size_t result = 0; + result ^= std::hash{}(x.num_entries) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.out_channels) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash<::FlexFlow::AggregateOp>{}(x.aggr) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::DataType>{}(x.data_type) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::EmbeddingAttrs + adl_serializer::from_json(json const &j) { + return {j.at("num_entries").template get(), + j.at("out_channels").template get(), + j.at("aggr").template get<::FlexFlow::AggregateOp>(), + j.at("data_type").template get<::FlexFlow::DataType>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::EmbeddingAttrs const &v) { + j["__type"] = "EmbeddingAttrs"; + j["num_entries"] = v.num_entries; + j["out_channels"] = v.out_channels; + j["aggr"] = v.aggr; + j["data_type"] = v.data_type; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(EmbeddingAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, EmbeddingAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/flat_attrs.cc b/lib/op-attrs/src/op-attrs/ops/flat_attrs.cc new file mode 100644 index 0000000000..8705a18c86 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/flat_attrs.cc @@ -0,0 +1,63 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml + +#include "op-attrs/ops/flat_attrs.h" + +namespace FlexFlow { +bool FlatAttrs::operator==(FlatAttrs const &other) const { + return std::tie() == std::tie(); +} +bool FlatAttrs::operator!=(FlatAttrs const &other) const { + return std::tie() != std::tie(); +} +bool FlatAttrs::operator<(FlatAttrs const &other) const { + return std::tie() < std::tie(); +} +bool FlatAttrs::operator>(FlatAttrs const &other) const { + return std::tie() > std::tie(); +} +bool FlatAttrs::operator<=(FlatAttrs const &other) const { + return std::tie() <= std::tie(); +} +bool FlatAttrs::operator>=(FlatAttrs const &other) const { + return std::tie() >= std::tie(); +} +} // namespace FlexFlow + +namespace std { +size_t + hash::operator()(FlexFlow::FlatAttrs const &x) const { + size_t result = 0; + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::FlatAttrs + adl_serializer::from_json(json const &j) { + return {}; +} +void adl_serializer::to_json( + json &j, FlexFlow::FlatAttrs const &v) { + j["__type"] = "FlatAttrs"; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct(); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(FlatAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, FlatAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/gather_attrs.cc b/lib/op-attrs/src/op-attrs/ops/gather_attrs.cc new file mode 100644 index 0000000000..05d6c674a7 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/gather_attrs.cc @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml + +#include "op-attrs/ops/gather_attrs.h" + +namespace FlexFlow { +GatherAttrs::GatherAttrs(::FlexFlow::ff_dim_t const &dim) : dim(dim) {} +bool GatherAttrs::operator==(GatherAttrs const &other) const { + return std::tie(this->dim) == std::tie(other.dim); +} +bool GatherAttrs::operator!=(GatherAttrs const &other) const { + return std::tie(this->dim) != std::tie(other.dim); +} +bool GatherAttrs::operator<(GatherAttrs const &other) const { + return std::tie(this->dim) < std::tie(other.dim); +} +bool GatherAttrs::operator>(GatherAttrs const &other) const { + return std::tie(this->dim) > std::tie(other.dim); +} +bool GatherAttrs::operator<=(GatherAttrs const &other) const { + return std::tie(this->dim) <= std::tie(other.dim); +} +bool GatherAttrs::operator>=(GatherAttrs const &other) const { + return std::tie(this->dim) >= std::tie(other.dim); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::GatherAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ff_dim_t>{}(x.dim) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::GatherAttrs + adl_serializer::from_json(json const &j) { + return {j.at("dim").template get<::FlexFlow::ff_dim_t>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::GatherAttrs const &v) { + j["__type"] = "GatherAttrs"; + j["dim"] = v.dim; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(GatherAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, GatherAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/input_attrs.cc b/lib/op-attrs/src/op-attrs/ops/input_attrs.cc new file mode 100644 index 0000000000..21ec0a3ba5 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/input_attrs.cc @@ -0,0 +1,63 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/input_attrs.struct.toml + +#include "op-attrs/ops/input_attrs.h" + +namespace FlexFlow { +bool InputAttrs::operator==(InputAttrs const &other) const { + return std::tie() == std::tie(); +} +bool InputAttrs::operator!=(InputAttrs const &other) const { + return std::tie() != std::tie(); +} +bool InputAttrs::operator<(InputAttrs const &other) const { + return std::tie() < std::tie(); +} +bool InputAttrs::operator>(InputAttrs const &other) const { + return std::tie() > std::tie(); +} +bool InputAttrs::operator<=(InputAttrs const &other) const { + return std::tie() <= std::tie(); +} +bool InputAttrs::operator>=(InputAttrs const &other) const { + return std::tie() >= std::tie(); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::InputAttrs const &x) const { + size_t result = 0; + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::InputAttrs + adl_serializer::from_json(json const &j) { + return {}; +} +void adl_serializer::to_json( + json &j, FlexFlow::InputAttrs const &v) { + j["__type"] = "InputAttrs"; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct(); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(InputAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, InputAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/layer_norm_attrs.cc b/lib/op-attrs/src/op-attrs/ops/layer_norm_attrs.cc new file mode 100644 index 0000000000..714d1f0c3c --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/layer_norm_attrs.cc @@ -0,0 +1,88 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml + +#include "op-attrs/ops/layer_norm_attrs.h" + +namespace FlexFlow { +LayerNormAttrs::LayerNormAttrs( + ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM> const &axes, + bool const &elementwise_affine, + float const &eps) + : axes(axes), elementwise_affine(elementwise_affine), eps(eps) {} +bool LayerNormAttrs::operator==(LayerNormAttrs const &other) const { + return std::tie(this->axes, this->elementwise_affine, this->eps) == + std::tie(other.axes, other.elementwise_affine, other.eps); +} +bool LayerNormAttrs::operator!=(LayerNormAttrs const &other) const { + return std::tie(this->axes, this->elementwise_affine, this->eps) != + std::tie(other.axes, other.elementwise_affine, other.eps); +} +bool LayerNormAttrs::operator<(LayerNormAttrs const &other) const { + return std::tie(this->axes, this->elementwise_affine, this->eps) < + std::tie(other.axes, other.elementwise_affine, other.eps); +} +bool LayerNormAttrs::operator>(LayerNormAttrs const &other) const { + return std::tie(this->axes, this->elementwise_affine, this->eps) > + std::tie(other.axes, other.elementwise_affine, other.eps); +} +bool LayerNormAttrs::operator<=(LayerNormAttrs const &other) const { + return std::tie(this->axes, this->elementwise_affine, this->eps) <= + std::tie(other.axes, other.elementwise_affine, other.eps); +} +bool LayerNormAttrs::operator>=(LayerNormAttrs const &other) const { + return std::tie(this->axes, this->elementwise_affine, this->eps) >= + std::tie(other.axes, other.elementwise_affine, other.eps); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::LayerNormAttrs const &x) const { + size_t result = 0; + result ^= + std::hash< + ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>>{}( + x.axes) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash{}(x.elementwise_affine) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.eps) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::LayerNormAttrs + adl_serializer::from_json(json const &j) { + return { + j.at("axes") + .template get< + ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>>(), + j.at("elementwise_affine").template get(), + j.at("eps").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::LayerNormAttrs const &v) { + j["__type"] = "LayerNormAttrs"; + j["axes"] = v.axes; + j["elementwise_affine"] = v.elementwise_affine; + j["eps"] = v.eps; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(LayerNormAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, LayerNormAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/linear_attrs.cc b/lib/op-attrs/src/op-attrs/ops/linear_attrs.cc new file mode 100644 index 0000000000..21d80d1f88 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/linear_attrs.cc @@ -0,0 +1,139 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml + +#include "op-attrs/ops/linear_attrs.h" + +namespace FlexFlow { +LinearAttrs::LinearAttrs( + int const &out_channels, + bool const &use_bias, + ::FlexFlow::DataType const &data_type, + ::FlexFlow::Activation const &activation, + std::optional<::FlexFlow::RegularizerAttrs> const ®ularizer) + : out_channels(out_channels), use_bias(use_bias), data_type(data_type), + activation(activation), regularizer(regularizer) {} +bool LinearAttrs::operator==(LinearAttrs const &other) const { + return std::tie(this->out_channels, + this->use_bias, + this->data_type, + this->activation, + this->regularizer) == std::tie(other.out_channels, + other.use_bias, + other.data_type, + other.activation, + other.regularizer); +} +bool LinearAttrs::operator!=(LinearAttrs const &other) const { + return std::tie(this->out_channels, + this->use_bias, + this->data_type, + this->activation, + this->regularizer) != std::tie(other.out_channels, + other.use_bias, + other.data_type, + other.activation, + other.regularizer); +} +bool LinearAttrs::operator<(LinearAttrs const &other) const { + return std::tie(this->out_channels, + this->use_bias, + this->data_type, + this->activation, + this->regularizer) < std::tie(other.out_channels, + other.use_bias, + other.data_type, + other.activation, + other.regularizer); +} +bool LinearAttrs::operator>(LinearAttrs const &other) const { + return std::tie(this->out_channels, + this->use_bias, + this->data_type, + this->activation, + this->regularizer) > std::tie(other.out_channels, + other.use_bias, + other.data_type, + other.activation, + other.regularizer); +} +bool LinearAttrs::operator<=(LinearAttrs const &other) const { + return std::tie(this->out_channels, + this->use_bias, + this->data_type, + this->activation, + this->regularizer) <= std::tie(other.out_channels, + other.use_bias, + other.data_type, + other.activation, + other.regularizer); +} +bool LinearAttrs::operator>=(LinearAttrs const &other) const { + return std::tie(this->out_channels, + this->use_bias, + this->data_type, + this->activation, + this->regularizer) >= std::tie(other.out_channels, + other.use_bias, + other.data_type, + other.activation, + other.regularizer); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::LinearAttrs const &x) const { + size_t result = 0; + result ^= std::hash{}(x.out_channels) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.use_bias) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash<::FlexFlow::DataType>{}(x.data_type) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::Activation>{}(x.activation) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= + std::hash>{}(x.regularizer) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::LinearAttrs + adl_serializer::from_json(json const &j) { + return {j.at("out_channels").template get(), + j.at("use_bias").template get(), + j.at("data_type").template get<::FlexFlow::DataType>(), + j.at("activation").template get<::FlexFlow::Activation>(), + j.at("regularizer") + .template get>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::LinearAttrs const &v) { + j["__type"] = "LinearAttrs"; + j["out_channels"] = v.out_channels; + j["use_bias"] = v.use_bias; + j["data_type"] = v.data_type; + j["activation"] = v.activation; + j["regularizer"] = v.regularizer; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(LinearAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, LinearAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/noop_attrs.cc b/lib/op-attrs/src/op-attrs/ops/noop_attrs.cc new file mode 100644 index 0000000000..1c42670d4d --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/noop_attrs.cc @@ -0,0 +1,63 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/noop_attrs.struct.toml + +#include "op-attrs/ops/noop_attrs.h" + +namespace FlexFlow { +bool NoopAttrs::operator==(NoopAttrs const &other) const { + return std::tie() == std::tie(); +} +bool NoopAttrs::operator!=(NoopAttrs const &other) const { + return std::tie() != std::tie(); +} +bool NoopAttrs::operator<(NoopAttrs const &other) const { + return std::tie() < std::tie(); +} +bool NoopAttrs::operator>(NoopAttrs const &other) const { + return std::tie() > std::tie(); +} +bool NoopAttrs::operator<=(NoopAttrs const &other) const { + return std::tie() <= std::tie(); +} +bool NoopAttrs::operator>=(NoopAttrs const &other) const { + return std::tie() >= std::tie(); +} +} // namespace FlexFlow + +namespace std { +size_t + hash::operator()(FlexFlow::NoopAttrs const &x) const { + size_t result = 0; + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::NoopAttrs + adl_serializer::from_json(json const &j) { + return {}; +} +void adl_serializer::to_json( + json &j, FlexFlow::NoopAttrs const &v) { + j["__type"] = "NoopAttrs"; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct(); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(NoopAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, NoopAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/pool_2d_attrs.cc b/lib/op-attrs/src/op-attrs/ops/pool_2d_attrs.cc new file mode 100644 index 0000000000..b186433f1f --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/pool_2d_attrs.cc @@ -0,0 +1,191 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml + +#include "op-attrs/ops/pool_2d_attrs.h" + +namespace FlexFlow { +Pool2DAttrs::Pool2DAttrs(int const &kernel_h, + int const &kernel_w, + int const &stride_h, + int const &stride_w, + int const &padding_h, + int const &padding_w, + ::FlexFlow::PoolOp const &pool_type, + ::FlexFlow::Activation const &activation) + : kernel_h(kernel_h), kernel_w(kernel_w), stride_h(stride_h), + stride_w(stride_w), padding_h(padding_h), padding_w(padding_w), + pool_type(pool_type), activation(activation) {} +bool Pool2DAttrs::operator==(Pool2DAttrs const &other) const { + return std::tie(this->kernel_h, + this->kernel_w, + this->stride_h, + this->stride_w, + this->padding_h, + this->padding_w, + this->pool_type, + this->activation) == std::tie(other.kernel_h, + other.kernel_w, + other.stride_h, + other.stride_w, + other.padding_h, + other.padding_w, + other.pool_type, + other.activation); +} +bool Pool2DAttrs::operator!=(Pool2DAttrs const &other) const { + return std::tie(this->kernel_h, + this->kernel_w, + this->stride_h, + this->stride_w, + this->padding_h, + this->padding_w, + this->pool_type, + this->activation) != std::tie(other.kernel_h, + other.kernel_w, + other.stride_h, + other.stride_w, + other.padding_h, + other.padding_w, + other.pool_type, + other.activation); +} +bool Pool2DAttrs::operator<(Pool2DAttrs const &other) const { + return std::tie(this->kernel_h, + this->kernel_w, + this->stride_h, + this->stride_w, + this->padding_h, + this->padding_w, + this->pool_type, + this->activation) < std::tie(other.kernel_h, + other.kernel_w, + other.stride_h, + other.stride_w, + other.padding_h, + other.padding_w, + other.pool_type, + other.activation); +} +bool Pool2DAttrs::operator>(Pool2DAttrs const &other) const { + return std::tie(this->kernel_h, + this->kernel_w, + this->stride_h, + this->stride_w, + this->padding_h, + this->padding_w, + this->pool_type, + this->activation) > std::tie(other.kernel_h, + other.kernel_w, + other.stride_h, + other.stride_w, + other.padding_h, + other.padding_w, + other.pool_type, + other.activation); +} +bool Pool2DAttrs::operator<=(Pool2DAttrs const &other) const { + return std::tie(this->kernel_h, + this->kernel_w, + this->stride_h, + this->stride_w, + this->padding_h, + this->padding_w, + this->pool_type, + this->activation) <= std::tie(other.kernel_h, + other.kernel_w, + other.stride_h, + other.stride_w, + other.padding_h, + other.padding_w, + other.pool_type, + other.activation); +} +bool Pool2DAttrs::operator>=(Pool2DAttrs const &other) const { + return std::tie(this->kernel_h, + this->kernel_w, + this->stride_h, + this->stride_w, + this->padding_h, + this->padding_w, + this->pool_type, + this->activation) >= std::tie(other.kernel_h, + other.kernel_w, + other.stride_h, + other.stride_w, + other.padding_h, + other.padding_w, + other.pool_type, + other.activation); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::Pool2DAttrs const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.kernel_h) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.kernel_w) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.stride_h) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.stride_w) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash{}(x.padding_h) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.padding_w) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash<::FlexFlow::PoolOp>{}(x.pool_type) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::Activation>{}(x.activation) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::Pool2DAttrs + adl_serializer::from_json(json const &j) { + return {j.at("kernel_h").template get(), + j.at("kernel_w").template get(), + j.at("stride_h").template get(), + j.at("stride_w").template get(), + j.at("padding_h").template get(), + j.at("padding_w").template get(), + j.at("pool_type").template get<::FlexFlow::PoolOp>(), + j.at("activation").template get<::FlexFlow::Activation>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::Pool2DAttrs const &v) { + j["__type"] = "Pool2DAttrs"; + j["kernel_h"] = v.kernel_h; + j["kernel_w"] = v.kernel_w; + j["stride_h"] = v.stride_h; + j["stride_w"] = v.stride_w; + j["padding_h"] = v.padding_h; + j["padding_w"] = v.padding_w; + j["pool_type"] = v.pool_type; + j["activation"] = v.activation; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(Pool2DAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, Pool2DAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/reduce_attrs.cc b/lib/op-attrs/src/op-attrs/ops/reduce_attrs.cc new file mode 100644 index 0000000000..b6d71c81b4 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/reduce_attrs.cc @@ -0,0 +1,88 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml + +#include "op-attrs/ops/reduce_attrs.h" + +namespace FlexFlow { +ReduceAttrs::ReduceAttrs( + ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM> const &axes, + ::FlexFlow::Op const &op_type, + bool const &keepdims) + : axes(axes), op_type(op_type), keepdims(keepdims) {} +bool ReduceAttrs::operator==(ReduceAttrs const &other) const { + return std::tie(this->axes, this->op_type, this->keepdims) == + std::tie(other.axes, other.op_type, other.keepdims); +} +bool ReduceAttrs::operator!=(ReduceAttrs const &other) const { + return std::tie(this->axes, this->op_type, this->keepdims) != + std::tie(other.axes, other.op_type, other.keepdims); +} +bool ReduceAttrs::operator<(ReduceAttrs const &other) const { + return std::tie(this->axes, this->op_type, this->keepdims) < + std::tie(other.axes, other.op_type, other.keepdims); +} +bool ReduceAttrs::operator>(ReduceAttrs const &other) const { + return std::tie(this->axes, this->op_type, this->keepdims) > + std::tie(other.axes, other.op_type, other.keepdims); +} +bool ReduceAttrs::operator<=(ReduceAttrs const &other) const { + return std::tie(this->axes, this->op_type, this->keepdims) <= + std::tie(other.axes, other.op_type, other.keepdims); +} +bool ReduceAttrs::operator>=(ReduceAttrs const &other) const { + return std::tie(this->axes, this->op_type, this->keepdims) >= + std::tie(other.axes, other.op_type, other.keepdims); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ReduceAttrs const &x) const { + size_t result = 0; + result ^= + std::hash< + ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>>{}( + x.axes) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::Op>{}(x.op_type) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash{}(x.keepdims) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ReduceAttrs + adl_serializer::from_json(json const &j) { + return { + j.at("axes") + .template get< + ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>>(), + j.at("op_type").template get<::FlexFlow::Op>(), + j.at("keepdims").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ReduceAttrs const &v) { + j["__type"] = "ReduceAttrs"; + j["axes"] = v.axes; + j["op_type"] = v.op_type; + j["keepdims"] = v.keepdims; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ReduceAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ReduceAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/reduction_attrs.cc b/lib/op-attrs/src/op-attrs/ops/reduction_attrs.cc new file mode 100644 index 0000000000..6876cf863d --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/reduction_attrs.cc @@ -0,0 +1,75 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/reduction_attrs.struct.toml + +#include "op-attrs/ops/reduction_attrs.h" + +namespace FlexFlow { +ReductionAttrs::ReductionAttrs(::FlexFlow::ff_dim_t const &reduction_dim, + int const &reduction_degree) + : reduction_dim(reduction_dim), reduction_degree(reduction_degree) {} +bool ReductionAttrs::operator==(ReductionAttrs const &other) const { + return std::tie(this->reduction_dim, this->reduction_degree) == + std::tie(other.reduction_dim, other.reduction_degree); +} +bool ReductionAttrs::operator!=(ReductionAttrs const &other) const { + return std::tie(this->reduction_dim, this->reduction_degree) != + std::tie(other.reduction_dim, other.reduction_degree); +} +bool ReductionAttrs::operator<(ReductionAttrs const &other) const { + return std::tie(this->reduction_dim, this->reduction_degree) < + std::tie(other.reduction_dim, other.reduction_degree); +} +bool ReductionAttrs::operator>(ReductionAttrs const &other) const { + return std::tie(this->reduction_dim, this->reduction_degree) > + std::tie(other.reduction_dim, other.reduction_degree); +} +bool ReductionAttrs::operator<=(ReductionAttrs const &other) const { + return std::tie(this->reduction_dim, this->reduction_degree) <= + std::tie(other.reduction_dim, other.reduction_degree); +} +bool ReductionAttrs::operator>=(ReductionAttrs const &other) const { + return std::tie(this->reduction_dim, this->reduction_degree) >= + std::tie(other.reduction_dim, other.reduction_degree); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ReductionAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ff_dim_t>{}(x.reduction_dim) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash{}(x.reduction_degree) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ReductionAttrs + adl_serializer::from_json(json const &j) { + return {j.at("reduction_dim").template get<::FlexFlow::ff_dim_t>(), + j.at("reduction_degree").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ReductionAttrs const &v) { + j["__type"] = "ReductionAttrs"; + j["reduction_dim"] = v.reduction_dim; + j["reduction_degree"] = v.reduction_degree; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ReductionAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ReductionAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/repartition_attrs.cc b/lib/op-attrs/src/op-attrs/ops/repartition_attrs.cc new file mode 100644 index 0000000000..7abd7d1959 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/repartition_attrs.cc @@ -0,0 +1,76 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml + +#include "op-attrs/ops/repartition_attrs.h" + +namespace FlexFlow { +RepartitionAttrs::RepartitionAttrs(::FlexFlow::ff_dim_t const &repartition_dim, + int const &repartition_degree) + : repartition_dim(repartition_dim), repartition_degree(repartition_degree) { +} +bool RepartitionAttrs::operator==(RepartitionAttrs const &other) const { + return std::tie(this->repartition_dim, this->repartition_degree) == + std::tie(other.repartition_dim, other.repartition_degree); +} +bool RepartitionAttrs::operator!=(RepartitionAttrs const &other) const { + return std::tie(this->repartition_dim, this->repartition_degree) != + std::tie(other.repartition_dim, other.repartition_degree); +} +bool RepartitionAttrs::operator<(RepartitionAttrs const &other) const { + return std::tie(this->repartition_dim, this->repartition_degree) < + std::tie(other.repartition_dim, other.repartition_degree); +} +bool RepartitionAttrs::operator>(RepartitionAttrs const &other) const { + return std::tie(this->repartition_dim, this->repartition_degree) > + std::tie(other.repartition_dim, other.repartition_degree); +} +bool RepartitionAttrs::operator<=(RepartitionAttrs const &other) const { + return std::tie(this->repartition_dim, this->repartition_degree) <= + std::tie(other.repartition_dim, other.repartition_degree); +} +bool RepartitionAttrs::operator>=(RepartitionAttrs const &other) const { + return std::tie(this->repartition_dim, this->repartition_degree) >= + std::tie(other.repartition_dim, other.repartition_degree); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::RepartitionAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ff_dim_t>{}(x.repartition_dim) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash{}(x.repartition_degree) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::RepartitionAttrs + adl_serializer::from_json(json const &j) { + return {j.at("repartition_dim").template get<::FlexFlow::ff_dim_t>(), + j.at("repartition_degree").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::RepartitionAttrs const &v) { + j["__type"] = "RepartitionAttrs"; + j["repartition_dim"] = v.repartition_dim; + j["repartition_degree"] = v.repartition_degree; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(RepartitionAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, RepartitionAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/replicate_attrs.cc b/lib/op-attrs/src/op-attrs/ops/replicate_attrs.cc new file mode 100644 index 0000000000..ea7802a325 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/replicate_attrs.cc @@ -0,0 +1,75 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/replicate_attrs.struct.toml + +#include "op-attrs/ops/replicate_attrs.h" + +namespace FlexFlow { +ReplicateAttrs::ReplicateAttrs(::FlexFlow::ff_dim_t const &replicate_dim, + int const &replicate_degree) + : replicate_dim(replicate_dim), replicate_degree(replicate_degree) {} +bool ReplicateAttrs::operator==(ReplicateAttrs const &other) const { + return std::tie(this->replicate_dim, this->replicate_degree) == + std::tie(other.replicate_dim, other.replicate_degree); +} +bool ReplicateAttrs::operator!=(ReplicateAttrs const &other) const { + return std::tie(this->replicate_dim, this->replicate_degree) != + std::tie(other.replicate_dim, other.replicate_degree); +} +bool ReplicateAttrs::operator<(ReplicateAttrs const &other) const { + return std::tie(this->replicate_dim, this->replicate_degree) < + std::tie(other.replicate_dim, other.replicate_degree); +} +bool ReplicateAttrs::operator>(ReplicateAttrs const &other) const { + return std::tie(this->replicate_dim, this->replicate_degree) > + std::tie(other.replicate_dim, other.replicate_degree); +} +bool ReplicateAttrs::operator<=(ReplicateAttrs const &other) const { + return std::tie(this->replicate_dim, this->replicate_degree) <= + std::tie(other.replicate_dim, other.replicate_degree); +} +bool ReplicateAttrs::operator>=(ReplicateAttrs const &other) const { + return std::tie(this->replicate_dim, this->replicate_degree) >= + std::tie(other.replicate_dim, other.replicate_degree); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ReplicateAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ff_dim_t>{}(x.replicate_dim) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash{}(x.replicate_degree) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ReplicateAttrs + adl_serializer::from_json(json const &j) { + return {j.at("replicate_dim").template get<::FlexFlow::ff_dim_t>(), + j.at("replicate_degree").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ReplicateAttrs const &v) { + j["__type"] = "ReplicateAttrs"; + j["replicate_dim"] = v.replicate_dim; + j["replicate_degree"] = v.replicate_degree; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ReplicateAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ReplicateAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/reshape_attrs.cc b/lib/op-attrs/src/op-attrs/ops/reshape_attrs.cc new file mode 100644 index 0000000000..b9953e33f4 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/reshape_attrs.cc @@ -0,0 +1,63 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/reshape_attrs.struct.toml + +#include "op-attrs/ops/reshape_attrs.h" + +namespace FlexFlow { +ReshapeAttrs::ReshapeAttrs(::FlexFlow::TensorShape const &shape) + : shape(shape) {} +bool ReshapeAttrs::operator==(ReshapeAttrs const &other) const { + return std::tie(this->shape) == std::tie(other.shape); +} +bool ReshapeAttrs::operator!=(ReshapeAttrs const &other) const { + return std::tie(this->shape) != std::tie(other.shape); +} +bool ReshapeAttrs::operator<(ReshapeAttrs const &other) const { + return std::tie(this->shape) < std::tie(other.shape); +} +bool ReshapeAttrs::operator>(ReshapeAttrs const &other) const { + return std::tie(this->shape) > std::tie(other.shape); +} +bool ReshapeAttrs::operator<=(ReshapeAttrs const &other) const { + return std::tie(this->shape) <= std::tie(other.shape); +} +bool ReshapeAttrs::operator>=(ReshapeAttrs const &other) const { + return std::tie(this->shape) >= std::tie(other.shape); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ReshapeAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::TensorShape>{}(x.shape) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ReshapeAttrs + adl_serializer::from_json(json const &j) { + return {j.at("shape").template get<::FlexFlow::TensorShape>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ReshapeAttrs const &v) { + j["__type"] = "ReshapeAttrs"; + j["shape"] = v.shape; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ReshapeAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ReshapeAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/reverse_attrs.cc b/lib/op-attrs/src/op-attrs/ops/reverse_attrs.cc new file mode 100644 index 0000000000..1b8cbd715e --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/reverse_attrs.cc @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml + +#include "op-attrs/ops/reverse_attrs.h" + +namespace FlexFlow { +ReverseAttrs::ReverseAttrs(::FlexFlow::ff_dim_t const &axis) : axis(axis) {} +bool ReverseAttrs::operator==(ReverseAttrs const &other) const { + return std::tie(this->axis) == std::tie(other.axis); +} +bool ReverseAttrs::operator!=(ReverseAttrs const &other) const { + return std::tie(this->axis) != std::tie(other.axis); +} +bool ReverseAttrs::operator<(ReverseAttrs const &other) const { + return std::tie(this->axis) < std::tie(other.axis); +} +bool ReverseAttrs::operator>(ReverseAttrs const &other) const { + return std::tie(this->axis) > std::tie(other.axis); +} +bool ReverseAttrs::operator<=(ReverseAttrs const &other) const { + return std::tie(this->axis) <= std::tie(other.axis); +} +bool ReverseAttrs::operator>=(ReverseAttrs const &other) const { + return std::tie(this->axis) >= std::tie(other.axis); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ReverseAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ff_dim_t>{}(x.axis) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ReverseAttrs + adl_serializer::from_json(json const &j) { + return {j.at("axis").template get<::FlexFlow::ff_dim_t>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ReverseAttrs const &v) { + j["__type"] = "ReverseAttrs"; + j["axis"] = v.axis; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ReverseAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ReverseAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/softmax_attrs.cc b/lib/op-attrs/src/op-attrs/ops/softmax_attrs.cc new file mode 100644 index 0000000000..775a7d46a7 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/softmax_attrs.cc @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml + +#include "op-attrs/ops/softmax_attrs.h" + +namespace FlexFlow { +SoftmaxAttrs::SoftmaxAttrs(::FlexFlow::ff_dim_t const &dim) : dim(dim) {} +bool SoftmaxAttrs::operator==(SoftmaxAttrs const &other) const { + return std::tie(this->dim) == std::tie(other.dim); +} +bool SoftmaxAttrs::operator!=(SoftmaxAttrs const &other) const { + return std::tie(this->dim) != std::tie(other.dim); +} +bool SoftmaxAttrs::operator<(SoftmaxAttrs const &other) const { + return std::tie(this->dim) < std::tie(other.dim); +} +bool SoftmaxAttrs::operator>(SoftmaxAttrs const &other) const { + return std::tie(this->dim) > std::tie(other.dim); +} +bool SoftmaxAttrs::operator<=(SoftmaxAttrs const &other) const { + return std::tie(this->dim) <= std::tie(other.dim); +} +bool SoftmaxAttrs::operator>=(SoftmaxAttrs const &other) const { + return std::tie(this->dim) >= std::tie(other.dim); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::SoftmaxAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ff_dim_t>{}(x.dim) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::SoftmaxAttrs + adl_serializer::from_json(json const &j) { + return {j.at("dim").template get<::FlexFlow::ff_dim_t>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::SoftmaxAttrs const &v) { + j["__type"] = "SoftmaxAttrs"; + j["dim"] = v.dim; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(SoftmaxAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, SoftmaxAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/split_attrs.cc b/lib/op-attrs/src/op-attrs/ops/split_attrs.cc new file mode 100644 index 0000000000..5761b7594f --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/split_attrs.cc @@ -0,0 +1,78 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml + +#include "op-attrs/ops/split_attrs.h" + +namespace FlexFlow { +SplitAttrs::SplitAttrs( + ::FlexFlow::stack_vector const &splits, + ::FlexFlow::ff_dim_t const &axis) + : splits(splits), axis(axis) {} +bool SplitAttrs::operator==(SplitAttrs const &other) const { + return std::tie(this->splits, this->axis) == + std::tie(other.splits, other.axis); +} +bool SplitAttrs::operator!=(SplitAttrs const &other) const { + return std::tie(this->splits, this->axis) != + std::tie(other.splits, other.axis); +} +bool SplitAttrs::operator<(SplitAttrs const &other) const { + return std::tie(this->splits, this->axis) < + std::tie(other.splits, other.axis); +} +bool SplitAttrs::operator>(SplitAttrs const &other) const { + return std::tie(this->splits, this->axis) > + std::tie(other.splits, other.axis); +} +bool SplitAttrs::operator<=(SplitAttrs const &other) const { + return std::tie(this->splits, this->axis) <= + std::tie(other.splits, other.axis); +} +bool SplitAttrs::operator>=(SplitAttrs const &other) const { + return std::tie(this->splits, this->axis) >= + std::tie(other.splits, other.axis); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::SplitAttrs const &x) const { + size_t result = 0; + result ^= + std::hash<::FlexFlow::stack_vector>{}(x.splits) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::ff_dim_t>{}(x.axis) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::SplitAttrs + adl_serializer::from_json(json const &j) { + return {j.at("splits") + .template get<::FlexFlow::stack_vector>(), + j.at("axis").template get<::FlexFlow::ff_dim_t>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::SplitAttrs const &v) { + j["__type"] = "SplitAttrs"; + j["splits"] = v.splits; + j["axis"] = v.axis; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(SplitAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, SplitAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/topk_attrs.cc b/lib/op-attrs/src/op-attrs/ops/topk_attrs.cc new file mode 100644 index 0000000000..2105b6e716 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/topk_attrs.cc @@ -0,0 +1,72 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/topk_attrs.struct.toml + +#include "op-attrs/ops/topk_attrs.h" + +namespace FlexFlow { +TopKAttrs::TopKAttrs(int const &k, bool const &sorted) : k(k), sorted(sorted) {} +bool TopKAttrs::operator==(TopKAttrs const &other) const { + return std::tie(this->k, this->sorted) == std::tie(other.k, other.sorted); +} +bool TopKAttrs::operator!=(TopKAttrs const &other) const { + return std::tie(this->k, this->sorted) != std::tie(other.k, other.sorted); +} +bool TopKAttrs::operator<(TopKAttrs const &other) const { + return std::tie(this->k, this->sorted) < std::tie(other.k, other.sorted); +} +bool TopKAttrs::operator>(TopKAttrs const &other) const { + return std::tie(this->k, this->sorted) > std::tie(other.k, other.sorted); +} +bool TopKAttrs::operator<=(TopKAttrs const &other) const { + return std::tie(this->k, this->sorted) <= std::tie(other.k, other.sorted); +} +bool TopKAttrs::operator>=(TopKAttrs const &other) const { + return std::tie(this->k, this->sorted) >= std::tie(other.k, other.sorted); +} +} // namespace FlexFlow + +namespace std { +size_t + hash::operator()(FlexFlow::TopKAttrs const &x) const { + size_t result = 0; + result ^= std::hash{}(x.k) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.sorted) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::TopKAttrs + adl_serializer::from_json(json const &j) { + return {j.at("k").template get(), j.at("sorted").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::TopKAttrs const &v) { + j["__type"] = "TopKAttrs"; + j["k"] = v.k; + j["sorted"] = v.sorted; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary(), + gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(TopKAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, TopKAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/transpose_attrs.cc b/lib/op-attrs/src/op-attrs/ops/transpose_attrs.cc new file mode 100644 index 0000000000..e046896753 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/transpose_attrs.cc @@ -0,0 +1,69 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml + +#include "op-attrs/ops/transpose_attrs.h" + +namespace FlexFlow { +TransposeAttrs::TransposeAttrs( + ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM> const &perm) + : perm(perm) {} +bool TransposeAttrs::operator==(TransposeAttrs const &other) const { + return std::tie(this->perm) == std::tie(other.perm); +} +bool TransposeAttrs::operator!=(TransposeAttrs const &other) const { + return std::tie(this->perm) != std::tie(other.perm); +} +bool TransposeAttrs::operator<(TransposeAttrs const &other) const { + return std::tie(this->perm) < std::tie(other.perm); +} +bool TransposeAttrs::operator>(TransposeAttrs const &other) const { + return std::tie(this->perm) > std::tie(other.perm); +} +bool TransposeAttrs::operator<=(TransposeAttrs const &other) const { + return std::tie(this->perm) <= std::tie(other.perm); +} +bool TransposeAttrs::operator>=(TransposeAttrs const &other) const { + return std::tie(this->perm) >= std::tie(other.perm); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::TransposeAttrs const &x) const { + size_t result = 0; + result ^= + std::hash< + ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>>{}( + x.perm) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::TransposeAttrs + adl_serializer::from_json(json const &j) { + return {j.at("perm") + .template get<::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, + MAX_TENSOR_DIM>>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::TransposeAttrs const &v) { + j["__type"] = "TransposeAttrs"; + j["perm"] = v.perm; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(TransposeAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, TransposeAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/pool_op.cc b/lib/op-attrs/src/op-attrs/pool_op.cc new file mode 100644 index 0000000000..9e52e0801a --- /dev/null +++ b/lib/op-attrs/src/op-attrs/pool_op.cc @@ -0,0 +1,17 @@ +#include "op-attrs/pool_op.h" +#include "utils/exception.h" + +namespace FlexFlow { + +std::string format_as(PoolOp o) { + switch (o) { + case PoolOp::MAX: + return "MAX"; + case PoolOp::AVG: + return "AVG"; + default: + throw mk_runtime_error(fmt::format("Unknown pool op {}", static_cast(o))); + } +} + +} From 2011b31ebd4ab2ef66618662e0508914ea1b1880 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Sun, 7 Apr 2024 22:59:58 -0700 Subject: [PATCH 07/43] Format --- flake.lock | 6 +++--- lib/op-attrs/include/op-attrs/activation.h | 4 ++-- lib/op-attrs/include/op-attrs/aggregate_op.h | 2 +- lib/op-attrs/include/op-attrs/ops/attention.h | 2 +- lib/op-attrs/include/op-attrs/ops/batch_norm.h | 2 +- lib/op-attrs/include/op-attrs/ops/conv_2d.h | 2 +- lib/op-attrs/include/op-attrs/ops/dropout.h | 2 +- lib/op-attrs/include/op-attrs/ops/element_binary.h | 2 +- lib/op-attrs/include/op-attrs/ops/element_unary.h | 4 ++-- lib/op-attrs/include/op-attrs/ops/embedding.h | 2 +- lib/op-attrs/include/op-attrs/ops/flat.h | 2 +- lib/op-attrs/include/op-attrs/ops/gather.h | 2 +- lib/op-attrs/include/op-attrs/ops/input.h | 2 +- lib/op-attrs/include/op-attrs/ops/layer_norm.h | 2 +- lib/op-attrs/include/op-attrs/ops/linear.h | 2 +- lib/op-attrs/include/op-attrs/ops/pool_2d.h | 2 +- lib/op-attrs/include/op-attrs/ops/reduce.h | 2 +- lib/op-attrs/include/op-attrs/ops/reduction.h | 2 +- lib/op-attrs/include/op-attrs/ops/repartition.h | 2 +- lib/op-attrs/include/op-attrs/ops/replicate.h | 2 +- lib/op-attrs/include/op-attrs/ops/softmax.h | 2 +- lib/op-attrs/include/op-attrs/ops/split.h | 2 +- lib/op-attrs/include/op-attrs/ops/topk.h | 2 +- lib/op-attrs/include/op-attrs/ops/transpose.h | 2 +- lib/op-attrs/include/op-attrs/pool_op.h | 5 ++--- lib/op-attrs/src/op-attrs/aggregate_op.cc | 5 +++-- lib/op-attrs/src/op-attrs/pool_op.cc | 5 +++-- 27 files changed, 36 insertions(+), 35 deletions(-) diff --git a/flake.lock b/flake.lock index b42ff59e74..af7c1f50a7 100644 --- a/flake.lock +++ b/flake.lock @@ -43,11 +43,11 @@ ] }, "locked": { - "lastModified": 1712555489, - "narHash": "sha256-V7Ck7y0BC18HR+CHd8fSp9i3ObGwdJsu22NpxQvTGVs=", + "lastModified": 1712555847, + "narHash": "sha256-a/QUt59McmON5IjSM1+Bu3+NFizdO+o/jjFyx8dlxFQ=", "owner": "lockshaw", "repo": "proj", - "rev": "6b7312e2079178332ffefc24f54c11879fc85e7a", + "rev": "beb036ec235e8c982f840cdd7748855128e80c19", "type": "github" }, "original": { diff --git a/lib/op-attrs/include/op-attrs/activation.h b/lib/op-attrs/include/op-attrs/activation.h index 87729b1206..c3840fd9b3 100644 --- a/lib/op-attrs/include/op-attrs/activation.h +++ b/lib/op-attrs/include/op-attrs/activation.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_ACTIVATION_H #define _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_ACTIVATION_H -#include "utils/fmt.h" #include "nlohmann/json.hpp" +#include "utils/fmt.h" namespace FlexFlow { @@ -14,7 +14,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(Activation, {Activation::TANH, "TANH"}, {Activation::GELU, "GELU"}}); -} +} // namespace FlexFlow namespace fmt { diff --git a/lib/op-attrs/include/op-attrs/aggregate_op.h b/lib/op-attrs/include/op-attrs/aggregate_op.h index b0d1e6cf93..eb8718533c 100644 --- a/lib/op-attrs/include/op-attrs/aggregate_op.h +++ b/lib/op-attrs/include/op-attrs/aggregate_op.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_AGGREGATE_OP_H #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_AGGREGATE_OP_H -#include "utils/fmt.h" #include "nlohmann/json.hpp" +#include "utils/fmt.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/ops/attention.h b/lib/op-attrs/include/op-attrs/ops/attention.h index 4000732ddf..69974c5646 100644 --- a/lib/op-attrs/include/op-attrs/ops/attention.h +++ b/lib/op-attrs/include/op-attrs/ops/attention.h @@ -2,9 +2,9 @@ #define _FLEXFLOW_ATTENTION_ATTRS_H #include "core.h" +#include "op-attrs/ops/attention_attrs.h" #include "op-attrs/parallel_tensor_shape.h" #include "utils/visitable.h" -#include "op-attrs/ops/attention_attrs.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/ops/batch_norm.h b/lib/op-attrs/include/op-attrs/ops/batch_norm.h index 1683d2a30c..8f42c3cc74 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_norm.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_norm.h @@ -2,9 +2,9 @@ #define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_NORM_H #include "core.h" +#include "op-attrs/ops/batch_norm_attrs.h" #include "op-attrs/parallel_tensor_shape.h" #include "utils/visitable.h" -#include "op-attrs/ops/batch_norm_attrs.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d.h b/lib/op-attrs/include/op-attrs/ops/conv_2d.h index d515c76048..47b7149004 100644 --- a/lib/op-attrs/include/op-attrs/ops/conv_2d.h +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d.h @@ -2,9 +2,9 @@ #define _FLEXFLOW_CONV_2D_ATTRS_H #include "core.h" +#include "op-attrs/ops/conv_2d_attrs.h" #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/tensor_shape.h" -#include "op-attrs/ops/conv_2d_attrs.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/ops/dropout.h b/lib/op-attrs/include/op-attrs/ops/dropout.h index 54def8a6c4..0c0a1b746d 100644 --- a/lib/op-attrs/include/op-attrs/ops/dropout.h +++ b/lib/op-attrs/include/op-attrs/ops/dropout.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_DROPOUT_ATTRS_H #include "core.h" -#include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/ops/dropout_attrs.h" +#include "op-attrs/parallel_tensor_shape.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/ops/element_binary.h b/lib/op-attrs/include/op-attrs/ops/element_binary.h index b46c66807d..b6ed0e6210 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_binary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_binary.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_ELEMENT_BINARY_ATTRS_H #include "core.h" -#include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/ops/element_binary_attrs.h" +#include "op-attrs/parallel_tensor_shape.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary.h b/lib/op-attrs/include/op-attrs/ops/element_unary.h index e1c874ed5a..888185042a 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_unary.h @@ -2,9 +2,9 @@ #define _FLEXFLOW_ELEMENTARY_UNARY_ATTRS_H #include "core.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "op-attrs/ops/element_unary_attrs.h" #include "op-attrs/ops/element_scalar_unary_attrs.h" +#include "op-attrs/ops/element_unary_attrs.h" +#include "op-attrs/parallel_tensor_shape.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/ops/embedding.h b/lib/op-attrs/include/op-attrs/ops/embedding.h index 8ad95a7fca..948e189397 100644 --- a/lib/op-attrs/include/op-attrs/ops/embedding.h +++ b/lib/op-attrs/include/op-attrs/ops/embedding.h @@ -2,9 +2,9 @@ #define _FLEXFLOW_EMBEDDING_ATTRS_H #include "core.h" +#include "op-attrs/ops/embedding_attrs.h" #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/tensor_shape.h" -#include "op-attrs/ops/embedding_attrs.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/ops/flat.h b/lib/op-attrs/include/op-attrs/ops/flat.h index dac75e4aa3..6f51d17c98 100644 --- a/lib/op-attrs/include/op-attrs/ops/flat.h +++ b/lib/op-attrs/include/op-attrs/ops/flat.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_FLAT_ATTRS_H #include "core.h" -#include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/ops/flat_attrs.h" +#include "op-attrs/parallel_tensor_shape.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/ops/gather.h b/lib/op-attrs/include/op-attrs/ops/gather.h index 596d266bb4..fd75292fe1 100644 --- a/lib/op-attrs/include/op-attrs/ops/gather.h +++ b/lib/op-attrs/include/op-attrs/ops/gather.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_GATHER_ATTRS_H #include "core.h" -#include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/ops/gather_attrs.h" +#include "op-attrs/parallel_tensor_shape.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/ops/input.h b/lib/op-attrs/include/op-attrs/ops/input.h index 73730c76d3..9f7a8d2de1 100644 --- a/lib/op-attrs/include/op-attrs/ops/input.h +++ b/lib/op-attrs/include/op-attrs/ops/input.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_OP_ATTRS_OPS_OP_ATTRS_INPUT_H #include "core.h" -#include "utils/visitable.h" #include "op-attrs/ops/input_attrs.h" +#include "utils/visitable.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/ops/layer_norm.h b/lib/op-attrs/include/op-attrs/ops/layer_norm.h index 97fd4990d5..d2e394f0a3 100644 --- a/lib/op-attrs/include/op-attrs/ops/layer_norm.h +++ b/lib/op-attrs/include/op-attrs/ops/layer_norm.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_OP_META_OPS_LAYER_NORM_ATTRS_H #include "core.h" -#include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/ops/layer_norm_attrs.h" +#include "op-attrs/parallel_tensor_shape.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/ops/linear.h b/lib/op-attrs/include/op-attrs/ops/linear.h index d9ce4d354e..dc0054faad 100644 --- a/lib/op-attrs/include/op-attrs/ops/linear.h +++ b/lib/op-attrs/include/op-attrs/ops/linear.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_LINEAR_ATTRS_H #include "op-attrs/ops/core.h" -#include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/ops/linear_attrs.h" +#include "op-attrs/parallel_tensor_shape.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d.h b/lib/op-attrs/include/op-attrs/ops/pool_2d.h index b766edcdaf..1e1624c405 100644 --- a/lib/op-attrs/include/op-attrs/ops/pool_2d.h +++ b/lib/op-attrs/include/op-attrs/ops/pool_2d.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_POOL_2D_ATTRS_H #include "core.h" -#include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/ops/pool_2d_attrs.h" +#include "op-attrs/parallel_tensor_shape.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/ops/reduce.h b/lib/op-attrs/include/op-attrs/ops/reduce.h index 9923bda684..10c15b023d 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduce.h +++ b/lib/op-attrs/include/op-attrs/ops/reduce.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_OP_META_OPS_REDUCE_ATTRS_H #include "core.h" -#include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/ops/reduce_attrs.h" +#include "op-attrs/parallel_tensor_shape.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/ops/reduction.h b/lib/op-attrs/include/op-attrs/ops/reduction.h index 8005d3d64f..9c421486e8 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduction.h +++ b/lib/op-attrs/include/op-attrs/ops/reduction.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_REDUCTION_ATTRS_H #include "core.h" -#include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/ops/reduction_attrs.h" +#include "op-attrs/parallel_tensor_shape.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/ops/repartition.h b/lib/op-attrs/include/op-attrs/ops/repartition.h index 339b494855..21a25ccec6 100644 --- a/lib/op-attrs/include/op-attrs/ops/repartition.h +++ b/lib/op-attrs/include/op-attrs/ops/repartition.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_PARTITION_ATTRS_H #include "core.h" -#include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/ops/repartition_attrs.h" +#include "op-attrs/parallel_tensor_shape.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/ops/replicate.h b/lib/op-attrs/include/op-attrs/ops/replicate.h index bc96c87808..50a6be6d76 100644 --- a/lib/op-attrs/include/op-attrs/ops/replicate.h +++ b/lib/op-attrs/include/op-attrs/ops/replicate.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_REPLICATE_ATTRS_H #include "core.h" -#include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/ops/replicate_attrs.h" +#include "op-attrs/parallel_tensor_shape.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/ops/softmax.h b/lib/op-attrs/include/op-attrs/ops/softmax.h index 272610a77c..7b21e3ea38 100644 --- a/lib/op-attrs/include/op-attrs/ops/softmax.h +++ b/lib/op-attrs/include/op-attrs/ops/softmax.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_SOFTMAX_ATTRS_H #include "core.h" -#include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/ops/softmax_attrs.h" +#include "op-attrs/parallel_tensor_shape.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/ops/split.h b/lib/op-attrs/include/op-attrs/ops/split.h index 7bd2b0ff1a..864e8c4c4a 100644 --- a/lib/op-attrs/include/op-attrs/ops/split.h +++ b/lib/op-attrs/include/op-attrs/ops/split.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_SPLIT_ATTRS_H #include "core.h" -#include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/ops/split_attrs.h" +#include "op-attrs/parallel_tensor_shape.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/ops/topk.h b/lib/op-attrs/include/op-attrs/ops/topk.h index b059e5071f..57379518b6 100644 --- a/lib/op-attrs/include/op-attrs/ops/topk.h +++ b/lib/op-attrs/include/op-attrs/ops/topk.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_TOPK_ATTRS_H #include "core.h" -#include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/ops/topk_attrs.h" +#include "op-attrs/parallel_tensor_shape.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/ops/transpose.h b/lib/op-attrs/include/op-attrs/ops/transpose.h index 847f660f1a..f352be8f37 100644 --- a/lib/op-attrs/include/op-attrs/ops/transpose.h +++ b/lib/op-attrs/include/op-attrs/ops/transpose.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_OP_META_OPS_TRANSPOSE_ATTRS_H #include "core.h" -#include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/ops/transpose_attrs.h" +#include "op-attrs/parallel_tensor_shape.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/pool_op.h b/lib/op-attrs/include/op-attrs/pool_op.h index 6006702156..eae02a1a3b 100644 --- a/lib/op-attrs/include/op-attrs/pool_op.h +++ b/lib/op-attrs/include/op-attrs/pool_op.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_POOL_OP_H #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_POOL_OP_H -#include "utils/fmt.h" #include "nlohmann/json.hpp" +#include "utils/fmt.h" namespace FlexFlow { @@ -12,8 +12,7 @@ enum class PoolOp { }; NLOHMANN_JSON_SERIALIZE_ENUM(PoolOp, - {{PoolOp::MAX, "MAX"}, - {PoolOp::AVG, "AVG"}}); + {{PoolOp::MAX, "MAX"}, {PoolOp::AVG, "AVG"}}); std::string format_as(PoolOp); diff --git a/lib/op-attrs/src/op-attrs/aggregate_op.cc b/lib/op-attrs/src/op-attrs/aggregate_op.cc index 1d0bfa9f55..29d143579b 100644 --- a/lib/op-attrs/src/op-attrs/aggregate_op.cc +++ b/lib/op-attrs/src/op-attrs/aggregate_op.cc @@ -10,8 +10,9 @@ std::string format_as(AggregateOp o) { case AggregateOp::AVG: return "AVG"; default: - throw mk_runtime_error(fmt::format("Unknown aggregate op {}", static_cast(o))); + throw mk_runtime_error( + fmt::format("Unknown aggregate op {}", static_cast(o))); } } -} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/pool_op.cc b/lib/op-attrs/src/op-attrs/pool_op.cc index 9e52e0801a..cdb0f8bf4b 100644 --- a/lib/op-attrs/src/op-attrs/pool_op.cc +++ b/lib/op-attrs/src/op-attrs/pool_op.cc @@ -10,8 +10,9 @@ std::string format_as(PoolOp o) { case PoolOp::AVG: return "AVG"; default: - throw mk_runtime_error(fmt::format("Unknown pool op {}", static_cast(o))); + throw mk_runtime_error( + fmt::format("Unknown pool op {}", static_cast(o))); } } -} +} // namespace FlexFlow From 0478c02d9231b6d21b0cf430eb00638f99778cb9 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Wed, 10 Apr 2024 23:03:47 -0700 Subject: [PATCH 08/43] More dtgen'ing --- flake.lock | 6 +- lib/op-attrs/include/op-attrs/datatype.h | 38 +------- .../include/op-attrs/datatype_t.enum.toml | 26 +++++ lib/op-attrs/include/op-attrs/datatype_t.h | 35 +++++++ lib/op-attrs/include/op-attrs/dim_ordered.h | 27 ++++++ .../include/op-attrs/operator_attrs.h | 1 + lib/op-attrs/include/op-attrs/ops/attention.h | 56 +++++------ .../include/op-attrs/ops/attention_inputs.h | 55 +++++++++++ .../op-attrs/ops/attention_inputs.struct.toml | 26 +++++ .../op-attrs/ops/parallel_attention_inputs.h | 58 +++++++++++ .../ops/parallel_attention_inputs.struct.toml | 26 +++++ lib/op-attrs/include/op-attrs/parallel_dim.h | 17 +--- .../include/op-attrs/parallel_dim_t.h | 62 ++++++++++++ .../op-attrs/parallel_dim_t.struct.toml | 22 +++++ .../include/op-attrs/parallel_tensor_dims.h | 53 +++------- .../include/op-attrs/parallel_tensor_dims_t.h | 53 ++++++++++ .../parallel_tensor_dims_t.struct.toml | 19 ++++ .../include/op-attrs/parallel_tensor_shape.h | 38 ++------ .../op-attrs/parallel_tensor_shape_t.h | 54 +++++++++++ .../parallel_tensor_shape_t.struct.toml | 23 +++++ .../include/op-attrs/regularizer_attrs.h | 3 + lib/op-attrs/include/op-attrs/tensor_dims.h | 14 +++ lib/op-attrs/include/op-attrs/tensor_dims_t.h | 51 ++++++++++ .../op-attrs/tensor_dims_t.struct.toml | 17 ++++ lib/op-attrs/include/op-attrs/tensor_shape.h | 26 +---- .../include/op-attrs/tensor_shape_t.h | 54 +++++++++++ .../op-attrs/tensor_shape_t.struct.toml | 23 +++++ lib/op-attrs/src/attention.cc | 46 +++++---- lib/op-attrs/src/op-attrs/datatype_t.cc | 97 +++++++++++++++++++ .../src/op-attrs/ops/attention_inputs.cc | 89 +++++++++++++++++ .../op-attrs/ops/parallel_attention_inputs.cc | 90 +++++++++++++++++ lib/op-attrs/src/op-attrs/parallel_dim_t.cc | 88 +++++++++++++++++ .../src/op-attrs/parallel_tensor_dims.cc | 18 ++++ .../src/op-attrs/parallel_tensor_dims_t.cc | 66 +++++++++++++ .../src/op-attrs/parallel_tensor_shape_t.cc | 76 +++++++++++++++ lib/op-attrs/src/op-attrs/tensor_dims.cc | 13 +++ lib/op-attrs/src/op-attrs/tensor_dims_t.cc | 63 ++++++++++++ lib/op-attrs/src/op-attrs/tensor_shape_t.cc | 75 ++++++++++++++ lib/op-attrs/src/parallel_tensor_shape.cc | 75 +------------- lib/op-attrs/src/tensor_shape.cc | 8 +- lib/substitutions/src/graph_pattern.cc | 4 +- .../test/src/test_substitution.cc | 6 +- lib/utils/include/utils/json.h | 46 ++++++--- 43 files changed, 1449 insertions(+), 294 deletions(-) create mode 100644 lib/op-attrs/include/op-attrs/datatype_t.enum.toml create mode 100644 lib/op-attrs/include/op-attrs/datatype_t.h create mode 100644 lib/op-attrs/include/op-attrs/ops/attention_inputs.h create mode 100644 lib/op-attrs/include/op-attrs/ops/attention_inputs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.h create mode 100644 lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/parallel_dim_t.h create mode 100644 lib/op-attrs/include/op-attrs/parallel_dim_t.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/parallel_tensor_dims_t.h create mode 100644 lib/op-attrs/include/op-attrs/parallel_tensor_dims_t.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/parallel_tensor_shape_t.h create mode 100644 lib/op-attrs/include/op-attrs/parallel_tensor_shape_t.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/tensor_dims.h create mode 100644 lib/op-attrs/include/op-attrs/tensor_dims_t.h create mode 100644 lib/op-attrs/include/op-attrs/tensor_dims_t.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/tensor_shape_t.h create mode 100644 lib/op-attrs/include/op-attrs/tensor_shape_t.struct.toml create mode 100644 lib/op-attrs/src/op-attrs/datatype_t.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/attention_inputs.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/parallel_attention_inputs.cc create mode 100644 lib/op-attrs/src/op-attrs/parallel_dim_t.cc create mode 100644 lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc create mode 100644 lib/op-attrs/src/op-attrs/parallel_tensor_dims_t.cc create mode 100644 lib/op-attrs/src/op-attrs/parallel_tensor_shape_t.cc create mode 100644 lib/op-attrs/src/op-attrs/tensor_dims.cc create mode 100644 lib/op-attrs/src/op-attrs/tensor_dims_t.cc create mode 100644 lib/op-attrs/src/op-attrs/tensor_shape_t.cc diff --git a/flake.lock b/flake.lock index af7c1f50a7..893bb00c6c 100644 --- a/flake.lock +++ b/flake.lock @@ -43,11 +43,11 @@ ] }, "locked": { - "lastModified": 1712555847, - "narHash": "sha256-a/QUt59McmON5IjSM1+Bu3+NFizdO+o/jjFyx8dlxFQ=", + "lastModified": 1712564474, + "narHash": "sha256-4+QCnVuTpCSTxtTcH/NmAfsH0XvU6MLdoMNiSiMCCaE=", "owner": "lockshaw", "repo": "proj", - "rev": "beb036ec235e8c982f840cdd7748855128e80c19", + "rev": "2c9d234aefa756d7800c966589cac8874f6f21a0", "type": "github" }, "original": { diff --git a/lib/op-attrs/include/op-attrs/datatype.h b/lib/op-attrs/include/op-attrs/datatype.h index 4a8de665b4..d4c61e9895 100644 --- a/lib/op-attrs/include/op-attrs/datatype.h +++ b/lib/op-attrs/include/op-attrs/datatype.h @@ -4,11 +4,10 @@ #include "utils/fmt.h" #include "utils/fp16.h" #include +#include "op-attrs/datatype_t.h" namespace FlexFlow { -enum class DataType { BOOL, INT32, INT64, HALF, FLOAT, DOUBLE }; - template struct data_type_enum_to_class; @@ -61,39 +60,4 @@ size_t size_of(DataType); } // namespace FlexFlow -namespace fmt { -template <> -struct formatter<::FlexFlow::DataType> : formatter { - template - auto format(::FlexFlow::DataType dt, FormatContext &ctx) - -> decltype(ctx.out()) { - using namespace FlexFlow; - - string_view name = "unknown"; - switch (dt) { - case DataType::BOOL: - name = "BOOL"; - break; - case DataType::INT32: - name = "INT32"; - break; - case DataType::INT64: - name = "INT64"; - break; - case DataType::HALF: - name = "HALF"; - break; - case DataType::FLOAT: - name = "FLOAT"; - break; - case DataType::DOUBLE: - name = "DOUBLE"; - break; - } - return formatter::format(name, ctx); - } -}; - -} // namespace fmt - #endif diff --git a/lib/op-attrs/include/op-attrs/datatype_t.enum.toml b/lib/op-attrs/include/op-attrs/datatype_t.enum.toml new file mode 100644 index 0000000000..15210cfe29 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/datatype_t.enum.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "DataType" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "BOOL" + +[[values]] +name = "INT32" + +[[values]] +name = "INT64" + +[[values]] +name = "HALF" + +[[values]] +name = "FLOAT" + +[[values]] +name = "DOUBLE" diff --git a/lib/op-attrs/include/op-attrs/datatype_t.h b/lib/op-attrs/include/op-attrs/datatype_t.h new file mode 100644 index 0000000000..67d7278927 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/datatype_t.h @@ -0,0 +1,35 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/datatype_t.enum.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DATATYPE_T_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DATATYPE_T_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +enum class DataType { BOOL, INT32, INT64, HALF, FLOAT, DOUBLE }; +std::string format_as(DataType); +std::ostream &operator<<(std::ostream &, DataType); +void to_json(::nlohmann::json &, DataType); +void from_json(::nlohmann::json const &, DataType &); +} // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::DataType) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DATATYPE_T_H diff --git a/lib/op-attrs/include/op-attrs/dim_ordered.h b/lib/op-attrs/include/op-attrs/dim_ordered.h index b03667466d..ae1507914d 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered.h @@ -3,6 +3,7 @@ #include "op-attrs/ff_dim.h" #include "utils/stack_vector.h" +#include "utils/json.h" namespace FlexFlow { @@ -133,6 +134,8 @@ struct DimOrdered { template using FFOrdered = DimOrdered; +/* CHECK_JSONABLE(FFOrdered); */ + template auto inner_to_outer(FFOrdered const &ff_ordered) -> decltype(reversed_container(ff_ordered)) { @@ -160,6 +163,30 @@ FFOrdered const &outer_to_inner(FFOrdered const &ff_ordered) { } // namespace FlexFlow + +/* template */ +/* void to_json(json &j, DimOrdered const &x) { */ +/* /1* j = std::vector{x.cbegin(), x.cend()}; *1/ */ +/* } */ + +/* template */ +/* void from_json(json const &j, DimOrdered &x) { */ +/* /1* x = DimOrdered{j.template get>()}; *1/ */ +/* } */ + +namespace nlohmann { +template +struct adl_serializer<::FlexFlow::DimOrdered> { + static ::FlexFlow::DimOrdered from_json(json const &j) { + return {j.template get>()}; + } + + static void to_json(json& j, ::FlexFlow::DimOrdered const &x) { + j = std::vector{x.cbegin(), x.cend()}; + } +}; +} + namespace std { template diff --git a/lib/op-attrs/include/op-attrs/operator_attrs.h b/lib/op-attrs/include/op-attrs/operator_attrs.h index b63563cd67..09b290d7ef 100644 --- a/lib/op-attrs/include/op-attrs/operator_attrs.h +++ b/lib/op-attrs/include/op-attrs/operator_attrs.h @@ -33,6 +33,7 @@ #include "ops/transpose.h" #include "utils/variant.h" #include +#include "utils/record_formatter.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/ops/attention.h b/lib/op-attrs/include/op-attrs/ops/attention.h index 69974c5646..84c52895e1 100644 --- a/lib/op-attrs/include/op-attrs/ops/attention.h +++ b/lib/op-attrs/include/op-attrs/ops/attention.h @@ -4,57 +4,49 @@ #include "core.h" #include "op-attrs/ops/attention_attrs.h" #include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/ops/attention_inputs.h" +#include "op-attrs/ops/parallel_attention_inputs.h" #include "utils/visitable.h" namespace FlexFlow { -template -struct MultiHeadAttentionInputs - : public use_visitable_cmp> { -public: - MultiHeadAttentionInputs() = delete; - - MultiHeadAttentionInputs(TensorType const &query, - TensorType const &key, - TensorType const &value) - : query(query), key(key), value(value) {} - - template - MultiHeadAttentionInputs(MultiHeadAttentionInputs const &sub) - : query(sub.query), key(sub.key), value(sub.value) {} - -public: - TensorType query; - TensorType key; - TensorType value; -}; - int get_qProjSize(MultiHeadAttentionAttrs const &); int get_vProjSize(MultiHeadAttentionAttrs const &); int get_kProjSize(MultiHeadAttentionAttrs const &); int get_oProjSize(MultiHeadAttentionAttrs const &); -int get_qSize(MultiHeadAttentionInputs const &); -int get_kSize(MultiHeadAttentionInputs const &); -int get_vSize(MultiHeadAttentionInputs const &); +int get_qSize(ParallelMultiHeadAttentionInputs const &); +int get_qSize(MultiHeadAttentionInputs const &); + +int get_kSize(ParallelMultiHeadAttentionInputs const &); +int get_kSize(MultiHeadAttentionInputs const &); + +int get_vSize(ParallelMultiHeadAttentionInputs const &); +int get_vSize(MultiHeadAttentionInputs const &); + int get_oSize(ParallelTensorShape const &); +int get_oSize(TensorShape const &); + +int get_qoSeqLength(ParallelMultiHeadAttentionInputs const &); +int get_qoSeqLength(MultiHeadAttentionInputs const &); -int get_qoSeqLength(MultiHeadAttentionInputs const &); -int get_kvSeqLength(MultiHeadAttentionInputs const &); +int get_kvSeqLength(ParallelMultiHeadAttentionInputs const &); +int get_kvSeqLength(MultiHeadAttentionInputs const &); -int get_num_samples(MultiHeadAttentionInputs const &); +int get_num_samples(ParallelMultiHeadAttentionInputs const &); +int get_num_samples(MultiHeadAttentionInputs const &); TensorShape get_weights_shape(MultiHeadAttentionAttrs const &, - MultiHeadAttentionInputs const &); + MultiHeadAttentionInputs const &); ParallelTensorShape get_weights_shape(MultiHeadAttentionAttrs const &, - MultiHeadAttentionInputs const &); + ParallelMultiHeadAttentionInputs const &); +TensorShape get_output_shape(MultiHeadAttentionAttrs const &, + MultiHeadAttentionInputs const &); ParallelTensorShape get_output_shape(MultiHeadAttentionAttrs const &, - MultiHeadAttentionInputs const &); -TensorShape get_output_shape(MultiHeadAttentionAttrs const &, - MultiHeadAttentionInputs const &); + ParallelMultiHeadAttentionInputs const &); CHECK_VALID_OP_ATTR(MultiHeadAttentionAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/attention_inputs.h b/lib/op-attrs/include/op-attrs/ops/attention_inputs.h new file mode 100644 index 0000000000..3d2d4d1d74 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/attention_inputs.h @@ -0,0 +1,55 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/attention_inputs.struct.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_INPUTS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_INPUTS_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/tensor_shape.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct MultiHeadAttentionInputs { + MultiHeadAttentionInputs() = delete; + MultiHeadAttentionInputs(::FlexFlow::TensorShape const &query, + ::FlexFlow::TensorShape const &key, + ::FlexFlow::TensorShape const &value); + + bool operator==(MultiHeadAttentionInputs const &) const; + bool operator!=(MultiHeadAttentionInputs const &) const; + bool operator<(MultiHeadAttentionInputs const &) const; + bool operator>(MultiHeadAttentionInputs const &) const; + bool operator<=(MultiHeadAttentionInputs const &) const; + bool operator>=(MultiHeadAttentionInputs const &) const; + ::FlexFlow::TensorShape query; + ::FlexFlow::TensorShape key; + ::FlexFlow::TensorShape value; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::MultiHeadAttentionInputs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::MultiHeadAttentionInputs from_json(json const &); + static void to_json(json &, FlexFlow::MultiHeadAttentionInputs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(MultiHeadAttentionInputs const &); +std::ostream &operator<<(std::ostream &, MultiHeadAttentionInputs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_INPUTS_H diff --git a/lib/op-attrs/include/op-attrs/ops/attention_inputs.struct.toml b/lib/op-attrs/include/op-attrs/ops/attention_inputs.struct.toml new file mode 100644 index 0000000000..224bad4dc8 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/attention_inputs.struct.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "MultiHeadAttentionInputs" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/tensor_shape.h" +] + +[[fields]] +name = "query" +type = "::FlexFlow::TensorShape" + +[[fields]] +name = "key" +type = "::FlexFlow::TensorShape" + +[[fields]] +name = "value" +type = "::FlexFlow::TensorShape" diff --git a/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.h b/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.h new file mode 100644 index 0000000000..55874f2d92 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.h @@ -0,0 +1,58 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.struct.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_PARALLEL_ATTENTION_INPUTS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_PARALLEL_ATTENTION_INPUTS_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/parallel_tensor_shape.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct ParallelMultiHeadAttentionInputs { + ParallelMultiHeadAttentionInputs() = delete; + ParallelMultiHeadAttentionInputs( + ::FlexFlow::ParallelTensorShape const &query, + ::FlexFlow::ParallelTensorShape const &key, + ::FlexFlow::ParallelTensorShape const &value); + + bool operator==(ParallelMultiHeadAttentionInputs const &) const; + bool operator!=(ParallelMultiHeadAttentionInputs const &) const; + bool operator<(ParallelMultiHeadAttentionInputs const &) const; + bool operator>(ParallelMultiHeadAttentionInputs const &) const; + bool operator<=(ParallelMultiHeadAttentionInputs const &) const; + bool operator>=(ParallelMultiHeadAttentionInputs const &) const; + ::FlexFlow::ParallelTensorShape query; + ::FlexFlow::ParallelTensorShape key; + ::FlexFlow::ParallelTensorShape value; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ParallelMultiHeadAttentionInputs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ParallelMultiHeadAttentionInputs from_json(json const &); + static void to_json(json &, + FlexFlow::ParallelMultiHeadAttentionInputs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ParallelMultiHeadAttentionInputs const &); +std::ostream &operator<<(std::ostream &, + ParallelMultiHeadAttentionInputs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_PARALLEL_ATTENTION_INPUTS_H diff --git a/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.struct.toml b/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.struct.toml new file mode 100644 index 0000000000..f7513fee8f --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.struct.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "ParallelMultiHeadAttentionInputs" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/parallel_tensor_shape.h" +] + +[[fields]] +name = "query" +type = "::FlexFlow::ParallelTensorShape" + +[[fields]] +name = "key" +type = "::FlexFlow::ParallelTensorShape" + +[[fields]] +name = "value" +type = "::FlexFlow::ParallelTensorShape" diff --git a/lib/op-attrs/include/op-attrs/parallel_dim.h b/lib/op-attrs/include/op-attrs/parallel_dim.h index 9d407ec469..64c40b9594 100644 --- a/lib/op-attrs/include/op-attrs/parallel_dim.h +++ b/lib/op-attrs/include/op-attrs/parallel_dim.h @@ -1,24 +1,17 @@ #ifndef _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_DIM_H #define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_DIM_H -#include "utils/type_traits.h" -#include "utils/visitable.h" +#include "op-attrs/parallel_dim_t.h" namespace FlexFlow { -struct ParallelDim { - size_t size; - int degree; - req is_replica_dim; -}; -FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(ParallelDim, - size, - degree, - is_replica_dim); - bool is_valid(ParallelDim const &); bool is_replica_dim(ParallelDim const &); +ParallelDim with_size_set_to(ParallelDim const &, size_t); +ParallelDim with_degree_set_to(ParallelDim const &, int); +ParallelDim with_is_replica_set_to(ParallelDim const &, bool); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/parallel_dim_t.h b/lib/op-attrs/include/op-attrs/parallel_dim_t.h new file mode 100644 index 0000000000..730a5c6f5e --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_dim_t.h @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/parallel_dim_t.struct.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_DIM_T_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_DIM_T_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct ParallelDim { + ParallelDim() = delete; + ParallelDim(size_t const &size, + int const °ree, + bool const &is_replica_dim); + + bool operator==(ParallelDim const &) const; + bool operator!=(ParallelDim const &) const; + bool operator<(ParallelDim const &) const; + bool operator>(ParallelDim const &) const; + bool operator<=(ParallelDim const &) const; + bool operator>=(ParallelDim const &) const; + size_t size; + int degree; + bool is_replica_dim; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ParallelDim const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ParallelDim from_json(json const &); + static void to_json(json &, FlexFlow::ParallelDim const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(ParallelDim const &); +std::ostream &operator<<(std::ostream &, ParallelDim const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_DIM_T_H diff --git a/lib/op-attrs/include/op-attrs/parallel_dim_t.struct.toml b/lib/op-attrs/include/op-attrs/parallel_dim_t.struct.toml new file mode 100644 index 0000000000..7ecb6a5b04 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_dim_t.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "ParallelDim" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "size" +type = "size_t" + +[[fields]] +name = "degree" +type = "int" + +[[fields]] +name = "is_replica_dim" +type = "bool" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h index d38ba75232..787938322c 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h @@ -1,45 +1,23 @@ #ifndef _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIMS_H #define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIMS_H -#include "parallel_dim.h" -#include "utils/visitable.h" +#include "op-attrs/parallel_dim.h" +#include "op-attrs/parallel_tensor_dims_t.h" +#include "op-attrs/tensor_dims_t.h" namespace FlexFlow { -struct ParallelTensorDims : public use_visitable_cmp { - explicit ParallelTensorDims(TensorDims const &); - - size_t get_volume() const; - size_t num_dims() const; - - using iterator = typename FFOrdered::iterator; - using const_iterator = typename FFOrdered::const_iterator; - using reverse_iterator = typename FFOrdered::reverse_iterator; - using const_reverse_iterator = - typename FFOrdered::const_reverse_iterator; - using value_type = typename FFOrdered::value_type; - using pointer = typename FFOrdered::pointer; - using const_pointer = typename FFOrdered::const_pointer; - - ParallelDim const &at(ff_dim_t const &) const; - ParallelDim &at(ff_dim_t const &); - - iterator begin(); - const_iterator begin() const; - const_iterator cbegin() const; - iterator end(); - const_iterator end() const; - const_iterator cend() const; - reverse_iterator rbegin(); - const_reverse_iterator rbegin() const; - const_reverse_iterator crbegin() const; - reverse_iterator rend(); - const_reverse_iterator rend() const; - const_reverse_iterator crend() const; - -public: - FFOrdered data; -}; +FFOrdered const &ff_ordered(ParallelTensorDims const &); + +std::vector as_vector(ParallelTensorDims const &); + +int get_num_replica_dims(ParallelTensorDims const &); + +size_t get_volume(ParallelTensorDims const &); +size_t num_dims(ParallelTensorDims const &); + +ParallelDim dim_at_idx(ParallelTensorDims const &, ff_dim_t); +ParallelDim &dim_at_idx(ParallelTensorDims &, ff_dim_t); bool is_valid(ParallelTensorDims const &); TensorDims get_piece_dims(ParallelTensorDims const &); @@ -47,7 +25,4 @@ TensorDims get_tensor_dims_unsafe(ParallelTensorDims const &); } // namespace FlexFlow -VISITABLE_STRUCT(::FlexFlow::ParallelTensorDims, data); -MAKE_VISIT_HASHABLE(::FlexFlow::ParallelTensorDims); - #endif diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims_t.h b/lib/op-attrs/include/op-attrs/parallel_tensor_dims_t.h new file mode 100644 index 0000000000..54fb55b4c0 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims_t.h @@ -0,0 +1,53 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/parallel_tensor_dims_t.struct.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIMS_T_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIMS_T_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/dim_ordered.h" +#include "op-attrs/parallel_dim.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct ParallelTensorDims { + ParallelTensorDims() = delete; + ParallelTensorDims( + ::FlexFlow::FFOrdered<::FlexFlow::ParallelDim> const &unwrapped); + + bool operator==(ParallelTensorDims const &) const; + bool operator!=(ParallelTensorDims const &) const; + bool operator<(ParallelTensorDims const &) const; + bool operator>(ParallelTensorDims const &) const; + bool operator<=(ParallelTensorDims const &) const; + bool operator>=(ParallelTensorDims const &) const; + ::FlexFlow::FFOrdered<::FlexFlow::ParallelDim> unwrapped; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ParallelTensorDims const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ParallelTensorDims from_json(json const &); + static void to_json(json &, FlexFlow::ParallelTensorDims const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ParallelTensorDims const &); +std::ostream &operator<<(std::ostream &, ParallelTensorDims const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIMS_T_H diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims_t.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_dims_t.struct.toml new file mode 100644 index 0000000000..660b6f95f4 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims_t.struct.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "ParallelTensorDims" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/dim_ordered.h", + "op-attrs/parallel_dim.h", +] + +[[fields]] +name = "unwrapped" +type = "::FlexFlow::FFOrdered<::FlexFlow::ParallelDim>" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h index fd560352bb..3000463365 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h @@ -1,41 +1,18 @@ #ifndef _OP_META_PARALLEL_TENSOR_SHAPE_H #define _OP_META_PARALLEL_TENSOR_SHAPE_H -#include "datatype.h" #include "op-attrs/tensor_shape.h" -#include "parallel_tensor_dims.h" -#include "utils/bidict.h" -#include "utils/record_formatter.h" -#include "utils/stack_vector.h" -#include "utils/visitable.h" -#include #include +#include "op-attrs/parallel_tensor_shape_t.h" +#include "op-attrs/tensor_shape.h" namespace FlexFlow { -/** - * @brief Represent the shape of a ParallelTensor. - */ -struct ParallelTensorShape : public use_visitable_cmp { - ParallelTensorShape() = delete; - - template - ParallelTensorShape(Dims const &dims, DataType data_type) - : dims(dims), data_type(data_type) {} - - ParallelTensorShape(TensorShape const &); +int num_dims(ParallelTensorShape const &); +ParallelDim dim_at_idx(ParallelTensorShape const &, ff_dim_t); +ParallelDim &dim_at_idx(ParallelTensorShape &, ff_dim_t); - int num_dims() const; - - ParallelDim const &at(ff_dim_t const &) const; - ParallelDim &at(ff_dim_t const &); - ParallelDim const &operator[](ff_dim_t const &) const; - ParallelDim &operator[](ff_dim_t const &); - -public: - ParallelTensorDims dims; - DataType data_type; -}; +ParallelTensorShape lift_to_parallel(TensorShape const &); TensorShape get_piece_shape(ParallelTensorShape const &); int get_num_replica_dims(ParallelTensorShape const &); @@ -49,7 +26,4 @@ std::vector } // namespace FlexFlow -VISITABLE_STRUCT(::FlexFlow::ParallelTensorShape, data_type, dims); -MAKE_VISIT_HASHABLE(::FlexFlow::ParallelTensorShape); - #endif diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape_t.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape_t.h new file mode 100644 index 0000000000..e1f3333b9b --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape_t.h @@ -0,0 +1,54 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/parallel_tensor_shape_t.struct.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SHAPE_T_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SHAPE_T_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/datatype.h" +#include "op-attrs/parallel_tensor_dims.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct ParallelTensorShape { + ParallelTensorShape() = delete; + ParallelTensorShape(::FlexFlow::ParallelTensorDims const &dims, + ::FlexFlow::DataType const &data_type); + + bool operator==(ParallelTensorShape const &) const; + bool operator!=(ParallelTensorShape const &) const; + bool operator<(ParallelTensorShape const &) const; + bool operator>(ParallelTensorShape const &) const; + bool operator<=(ParallelTensorShape const &) const; + bool operator>=(ParallelTensorShape const &) const; + ::FlexFlow::ParallelTensorDims dims; + ::FlexFlow::DataType data_type; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ParallelTensorShape const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ParallelTensorShape from_json(json const &); + static void to_json(json &, FlexFlow::ParallelTensorShape const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ParallelTensorShape const &); +std::ostream &operator<<(std::ostream &, ParallelTensorShape const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SHAPE_T_H diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape_t.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_shape_t.struct.toml new file mode 100644 index 0000000000..1199b0d816 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape_t.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "ParallelTensorShape" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/parallel_tensor_dims.h", + "op-attrs/datatype.h", +] + +[[fields]] +name = "dims" +type = "::FlexFlow::ParallelTensorDims" + +[[fields]] +name = "data_type" +type = "::FlexFlow::DataType" diff --git a/lib/op-attrs/include/op-attrs/regularizer_attrs.h b/lib/op-attrs/include/op-attrs/regularizer_attrs.h index 73c53d4f4b..22a1c3c0a3 100644 --- a/lib/op-attrs/include/op-attrs/regularizer_attrs.h +++ b/lib/op-attrs/include/op-attrs/regularizer_attrs.h @@ -3,11 +3,14 @@ #include "op-attrs/l1_regularizer_attrs.h" #include "op-attrs/l2_regularizer_attrs.h" +#include "utils/json.h" namespace FlexFlow { using RegularizerAttrs = std::variant; +CHECK_IS_JSONABLE(RegularizerAttrs); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/tensor_dims.h b/lib/op-attrs/include/op-attrs/tensor_dims.h new file mode 100644 index 0000000000..952d300aff --- /dev/null +++ b/lib/op-attrs/include/op-attrs/tensor_dims.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_DIMS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_DIMS_H + +#include "op-attrs/tensor_dims_t.h" + +namespace FlexFlow { + +FFOrdered const &ff_ordered(TensorDims const &); + +size_t dim_at_idx(TensorDims const &, ff_dim_t); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/tensor_dims_t.h b/lib/op-attrs/include/op-attrs/tensor_dims_t.h new file mode 100644 index 0000000000..5ae891ffe3 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/tensor_dims_t.h @@ -0,0 +1,51 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/tensor_dims_t.struct.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_DIMS_T_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_DIMS_T_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/dim_ordered.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct TensorDims { + TensorDims() = delete; + TensorDims(::FlexFlow::FFOrdered const &ff_ordered); + + bool operator==(TensorDims const &) const; + bool operator!=(TensorDims const &) const; + bool operator<(TensorDims const &) const; + bool operator>(TensorDims const &) const; + bool operator<=(TensorDims const &) const; + bool operator>=(TensorDims const &) const; + ::FlexFlow::FFOrdered ff_ordered; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::TensorDims const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::TensorDims from_json(json const &); + static void to_json(json &, FlexFlow::TensorDims const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(TensorDims const &); +std::ostream &operator<<(std::ostream &, TensorDims const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_DIMS_T_H diff --git a/lib/op-attrs/include/op-attrs/tensor_dims_t.struct.toml b/lib/op-attrs/include/op-attrs/tensor_dims_t.struct.toml new file mode 100644 index 0000000000..e3913f60f6 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/tensor_dims_t.struct.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "TensorDims" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] +includes = [ + "op-attrs/dim_ordered.h", +] + +[[fields]] +name = "ff_ordered" +type = "::FlexFlow::FFOrdered" diff --git a/lib/op-attrs/include/op-attrs/tensor_shape.h b/lib/op-attrs/include/op-attrs/tensor_shape.h index fa34860817..c505bcdc5f 100644 --- a/lib/op-attrs/include/op-attrs/tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/tensor_shape.h @@ -1,34 +1,12 @@ #ifndef _FLEXFLOW_OPATTRS_TENSOR_SHAPE_H #define _FLEXFLOW_OPATTRS_TENSOR_SHAPE_H -#include "datatype.h" -#include "op-attrs/dim_ordered.h" -#include "op-attrs/ff_dim.h" -#include "utils/stack_vector.h" -#include "utils/visitable.h" +#include "op-attrs/tensor_shape_t.h" namespace FlexFlow { -using TensorDims = FFOrdered; - -struct TensorShape : public use_visitable_cmp { - TensorShape() = delete; - - template - TensorShape(Dims const &dims, DataType data_type) - : dims(dims), data_type(data_type) {} - - size_t at(ff_dim_t) const; - size_t operator[](ff_dim_t) const; - -public: - TensorDims dims; - DataType data_type; -}; +size_t dim_at_idx(TensorShape const &, ff_dim_t); } // namespace FlexFlow -VISITABLE_STRUCT(::FlexFlow::TensorShape, dims, data_type); -MAKE_VISIT_HASHABLE(::FlexFlow::TensorShape); - #endif diff --git a/lib/op-attrs/include/op-attrs/tensor_shape_t.h b/lib/op-attrs/include/op-attrs/tensor_shape_t.h new file mode 100644 index 0000000000..272307c523 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/tensor_shape_t.h @@ -0,0 +1,54 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/tensor_shape_t.struct.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_SHAPE_T_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_SHAPE_T_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/datatype.h" +#include "op-attrs/tensor_dims.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct TensorShape { + TensorShape() = delete; + TensorShape(::FlexFlow::TensorDims const &dims, + ::FlexFlow::DataType const &data_type); + + bool operator==(TensorShape const &) const; + bool operator!=(TensorShape const &) const; + bool operator<(TensorShape const &) const; + bool operator>(TensorShape const &) const; + bool operator<=(TensorShape const &) const; + bool operator>=(TensorShape const &) const; + ::FlexFlow::TensorDims dims; + ::FlexFlow::DataType data_type; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::TensorShape const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::TensorShape from_json(json const &); + static void to_json(json &, FlexFlow::TensorShape const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(TensorShape const &); +std::ostream &operator<<(std::ostream &, TensorShape const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_SHAPE_T_H diff --git a/lib/op-attrs/include/op-attrs/tensor_shape_t.struct.toml b/lib/op-attrs/include/op-attrs/tensor_shape_t.struct.toml new file mode 100644 index 0000000000..b4d8449a72 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/tensor_shape_t.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "TensorShape" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/tensor_dims.h", + "op-attrs/datatype.h", +] + +[[fields]] +name = "dims" +type = "::FlexFlow::TensorDims" + +[[fields]] +name = "data_type" +type = "::FlexFlow::DataType" diff --git a/lib/op-attrs/src/attention.cc b/lib/op-attrs/src/attention.cc index 2c1500a477..21a98aaf6d 100644 --- a/lib/op-attrs/src/attention.cc +++ b/lib/op-attrs/src/attention.cc @@ -27,40 +27,52 @@ int get_oProjSize(MultiHeadAttentionAttrs const &attrs) { } int get_qSize(TensorShape const &query_shape) { - return query_shape.at(ff_dim_t(0)); + return dim_at_idx(query_shape, ff_dim_t(0)); } int get_kSize(TensorShape const &key_shape) { - return key_shape.at(ff_dim_t(0)); + return dim_at_idx(key_shape, ff_dim_t(0)); } int get_vSize(TensorShape const &value_shape) { - return value_shape.at(ff_dim_t(0)); + return dim_at_idx(value_shape, ff_dim_t(0)); } -int get_qSize(MultiHeadAttentionInputs const &) { +int get_qSize(ParallelMultiHeadAttentionInputs const &) { NOT_IMPLEMENTED(); } -int get_kSize(MultiHeadAttentionInputs const &) { +int get_qSize(MultiHeadAttentionInputs const &) { NOT_IMPLEMENTED(); } -int get_vSize(MultiHeadAttentionInputs const &) { +int get_kSize(ParallelMultiHeadAttentionInputs const &) { + NOT_IMPLEMENTED(); +} + +int get_kSize(MultiHeadAttentionInputs const &) { + NOT_IMPLEMENTED(); +} + +int get_vSize(ParallelMultiHeadAttentionInputs const &) { + NOT_IMPLEMENTED(); +} + +int get_vSize(MultiHeadAttentionInputs const &) { NOT_IMPLEMENTED(); } TensorShape get_weights_shape(MultiHeadAttentionAttrs const &attrs, - MultiHeadAttentionInputs const &inputs) { + MultiHeadAttentionInputs const &inputs) { size_t qParas = get_qProjSize(attrs) * get_qSize(inputs); size_t kParas = get_kProjSize(attrs) * get_kSize(inputs); size_t vParas = get_vProjSize(attrs) * get_vSize(inputs); TensorShape output_shape = get_output_shape(attrs, inputs); size_t oParas = get_oProjSize(attrs) * get_oSize(output_shape); - TensorDims dims = {qParas + kParas + vParas + oParas, - static_cast(attrs.embed_dim)}; + TensorDims dims = {{qParas + kParas + vParas + oParas, + static_cast(attrs.embed_dim)}}; return {dims, DataType::FLOAT}; } @@ -69,14 +81,8 @@ ParallelTensorShape get_output_shape(MultiHeadAttentionAttrs const &attrs, ParallelTensorShape const &query_shape, ParallelTensorShape const &key_shape, ParallelTensorShape const &value_shape) { - /* ParallelDim replica_dim = query_shape.at(ff_dim_t(query_shape.num_dims() - - * 2)); */ - /* replica_dim.size = replica_dim.degree; */ - - /* ParallelDim */ - ParallelTensorShape output_shape = query_shape; - output_shape.at(ff_dim_t(output_shape.num_dims() - 1)).size = attrs.embed_dim; + dim_at_idx(output_shape, ff_dim_t(num_dims(output_shape) - 1)).size = attrs.embed_dim; return output_shape; } @@ -86,13 +92,13 @@ TensorShape get_output_shape(MultiHeadAttentionAttrs const &attrs, TensorShape const &value_shape) { ParallelTensorShape parallel_shape = get_output_shape(attrs, - static_cast(query_shape), - static_cast(key_shape), - static_cast(value_shape)); + lift_to_parallel(query_shape), + lift_to_parallel(key_shape), + lift_to_parallel(value_shape)); return get_tensor_shape_unsafe(parallel_shape); } TensorShape get_output_shape(MultiHeadAttentionAttrs const &, - MultiHeadAttentionInputs const &) { + MultiHeadAttentionInputs const &) { NOT_IMPLEMENTED(); } diff --git a/lib/op-attrs/src/op-attrs/datatype_t.cc b/lib/op-attrs/src/op-attrs/datatype_t.cc new file mode 100644 index 0000000000..ffacdb97ab --- /dev/null +++ b/lib/op-attrs/src/op-attrs/datatype_t.cc @@ -0,0 +1,97 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/datatype_t.enum.toml + +#include "op-attrs/datatype_t.h" + +#include +#include + +namespace std { +size_t hash::operator()(FlexFlow::DataType x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(DataType x) { + switch (x) { + case DataType::BOOL: + return "BOOL"; + case DataType::INT32: + return "INT32"; + case DataType::INT64: + return "INT64"; + case DataType::HALF: + return "HALF"; + case DataType::FLOAT: + return "FLOAT"; + case DataType::DOUBLE: + return "DOUBLE"; + default: + std::ostringstream oss; + oss << "Unknown DataType value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, DataType x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, DataType x) { + switch (x) { + case DataType::BOOL: + j = "BOOL"; + break; + case DataType::INT32: + j = "INT32"; + break; + case DataType::INT64: + j = "INT64"; + break; + case DataType::HALF: + j = "HALF"; + break; + case DataType::FLOAT: + j = "FLOAT"; + break; + case DataType::DOUBLE: + j = "DOUBLE"; + break; + default: + std::ostringstream oss; + oss << "Unknown DataType value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, DataType &x) { + std::string as_str = j.get(); + if (as_str == "BOOL") { + x = DataType::BOOL; + } else if (as_str == "INT32") { + x = DataType::INT32; + } else if (as_str == "INT64") { + x = DataType::INT64; + } else if (as_str == "HALF") { + x = DataType::HALF; + } else if (as_str == "FLOAT") { + x = DataType::FLOAT; + } else if (as_str == "DOUBLE") { + x = DataType::DOUBLE; + } else { + std::ostringstream oss; + oss << "Unknown DataType value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::element(FlexFlow::DataType::BOOL, + FlexFlow::DataType::INT32, + FlexFlow::DataType::INT64, + FlexFlow::DataType::HALF, + FlexFlow::DataType::FLOAT, + FlexFlow::DataType::DOUBLE); +} +} // namespace rc diff --git a/lib/op-attrs/src/op-attrs/ops/attention_inputs.cc b/lib/op-attrs/src/op-attrs/ops/attention_inputs.cc new file mode 100644 index 0000000000..121806c194 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/attention_inputs.cc @@ -0,0 +1,89 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/attention_inputs.struct.toml + +#include "op-attrs/ops/attention_inputs.h" + +namespace FlexFlow { +MultiHeadAttentionInputs::MultiHeadAttentionInputs( + ::FlexFlow::TensorShape const &query, + ::FlexFlow::TensorShape const &key, + ::FlexFlow::TensorShape const &value) + : query(query), key(key), value(value) {} +bool MultiHeadAttentionInputs::operator==( + MultiHeadAttentionInputs const &other) const { + return std::tie(this->query, this->key, this->value) == + std::tie(other.query, other.key, other.value); +} +bool MultiHeadAttentionInputs::operator!=( + MultiHeadAttentionInputs const &other) const { + return std::tie(this->query, this->key, this->value) != + std::tie(other.query, other.key, other.value); +} +bool MultiHeadAttentionInputs::operator<( + MultiHeadAttentionInputs const &other) const { + return std::tie(this->query, this->key, this->value) < + std::tie(other.query, other.key, other.value); +} +bool MultiHeadAttentionInputs::operator>( + MultiHeadAttentionInputs const &other) const { + return std::tie(this->query, this->key, this->value) > + std::tie(other.query, other.key, other.value); +} +bool MultiHeadAttentionInputs::operator<=( + MultiHeadAttentionInputs const &other) const { + return std::tie(this->query, this->key, this->value) <= + std::tie(other.query, other.key, other.value); +} +bool MultiHeadAttentionInputs::operator>=( + MultiHeadAttentionInputs const &other) const { + return std::tie(this->query, this->key, this->value) >= + std::tie(other.query, other.key, other.value); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::MultiHeadAttentionInputs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::TensorShape>{}(x.query) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::TensorShape>{}(x.key) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::TensorShape>{}(x.value) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::MultiHeadAttentionInputs + adl_serializer::from_json( + json const &j) { + return {j.at("query").template get<::FlexFlow::TensorShape>(), + j.at("key").template get<::FlexFlow::TensorShape>(), + j.at("value").template get<::FlexFlow::TensorShape>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::MultiHeadAttentionInputs const &v) { + j["__type"] = "MultiHeadAttentionInputs"; + j["query"] = v.query; + j["key"] = v.key; + j["value"] = v.value; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(MultiHeadAttentionInputs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, MultiHeadAttentionInputs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/parallel_attention_inputs.cc b/lib/op-attrs/src/op-attrs/ops/parallel_attention_inputs.cc new file mode 100644 index 0000000000..a2dd992600 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/parallel_attention_inputs.cc @@ -0,0 +1,90 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.struct.toml + +#include "op-attrs/ops/parallel_attention_inputs.h" + +namespace FlexFlow { +ParallelMultiHeadAttentionInputs::ParallelMultiHeadAttentionInputs( + ::FlexFlow::ParallelTensorShape const &query, + ::FlexFlow::ParallelTensorShape const &key, + ::FlexFlow::ParallelTensorShape const &value) + : query(query), key(key), value(value) {} +bool ParallelMultiHeadAttentionInputs::operator==( + ParallelMultiHeadAttentionInputs const &other) const { + return std::tie(this->query, this->key, this->value) == + std::tie(other.query, other.key, other.value); +} +bool ParallelMultiHeadAttentionInputs::operator!=( + ParallelMultiHeadAttentionInputs const &other) const { + return std::tie(this->query, this->key, this->value) != + std::tie(other.query, other.key, other.value); +} +bool ParallelMultiHeadAttentionInputs::operator<( + ParallelMultiHeadAttentionInputs const &other) const { + return std::tie(this->query, this->key, this->value) < + std::tie(other.query, other.key, other.value); +} +bool ParallelMultiHeadAttentionInputs::operator>( + ParallelMultiHeadAttentionInputs const &other) const { + return std::tie(this->query, this->key, this->value) > + std::tie(other.query, other.key, other.value); +} +bool ParallelMultiHeadAttentionInputs::operator<=( + ParallelMultiHeadAttentionInputs const &other) const { + return std::tie(this->query, this->key, this->value) <= + std::tie(other.query, other.key, other.value); +} +bool ParallelMultiHeadAttentionInputs::operator>=( + ParallelMultiHeadAttentionInputs const &other) const { + return std::tie(this->query, this->key, this->value) >= + std::tie(other.query, other.key, other.value); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ParallelMultiHeadAttentionInputs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ParallelTensorShape>{}(x.query) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::ParallelTensorShape>{}(x.key) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::ParallelTensorShape>{}(x.value) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ParallelMultiHeadAttentionInputs + adl_serializer::from_json( + json const &j) { + return {j.at("query").template get<::FlexFlow::ParallelTensorShape>(), + j.at("key").template get<::FlexFlow::ParallelTensorShape>(), + j.at("value").template get<::FlexFlow::ParallelTensorShape>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ParallelMultiHeadAttentionInputs const &v) { + j["__type"] = "ParallelMultiHeadAttentionInputs"; + j["query"] = v.query; + j["key"] = v.key; + j["value"] = v.value; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ParallelMultiHeadAttentionInputs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + ParallelMultiHeadAttentionInputs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/parallel_dim_t.cc b/lib/op-attrs/src/op-attrs/parallel_dim_t.cc new file mode 100644 index 0000000000..a2fd7f686f --- /dev/null +++ b/lib/op-attrs/src/op-attrs/parallel_dim_t.cc @@ -0,0 +1,88 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/parallel_dim_t.struct.toml + +#include "op-attrs/parallel_dim_t.h" + +namespace FlexFlow { +ParallelDim::ParallelDim(size_t const &size, + int const °ree, + bool const &is_replica_dim) + : size(size), degree(degree), is_replica_dim(is_replica_dim) {} +bool ParallelDim::operator==(ParallelDim const &other) const { + return std::tie(this->size, this->degree, this->is_replica_dim) == + std::tie(other.size, other.degree, other.is_replica_dim); +} +bool ParallelDim::operator!=(ParallelDim const &other) const { + return std::tie(this->size, this->degree, this->is_replica_dim) != + std::tie(other.size, other.degree, other.is_replica_dim); +} +bool ParallelDim::operator<(ParallelDim const &other) const { + return std::tie(this->size, this->degree, this->is_replica_dim) < + std::tie(other.size, other.degree, other.is_replica_dim); +} +bool ParallelDim::operator>(ParallelDim const &other) const { + return std::tie(this->size, this->degree, this->is_replica_dim) > + std::tie(other.size, other.degree, other.is_replica_dim); +} +bool ParallelDim::operator<=(ParallelDim const &other) const { + return std::tie(this->size, this->degree, this->is_replica_dim) <= + std::tie(other.size, other.degree, other.is_replica_dim); +} +bool ParallelDim::operator>=(ParallelDim const &other) const { + return std::tie(this->size, this->degree, this->is_replica_dim) >= + std::tie(other.size, other.degree, other.is_replica_dim); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ParallelDim const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.size) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.degree) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash{}(x.is_replica_dim) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ParallelDim + adl_serializer::from_json(json const &j) { + return {j.at("size").template get(), + j.at("degree").template get(), + j.at("is_replica_dim").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ParallelDim const &v) { + j["__type"] = "ParallelDim"; + j["size"] = v.size; + j["degree"] = v.degree; + j["is_replica_dim"] = v.is_replica_dim; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary(), gen::arbitrary(), gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(ParallelDim const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ParallelDim const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc new file mode 100644 index 0000000000..a0d2a9ba0d --- /dev/null +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc @@ -0,0 +1,18 @@ +#include "op-attrs/parallel_tensor_dims.h" +#include "utils/containers.h" + +namespace FlexFlow { + +std::vector as_vector(ParallelTensorDims const &d) { + return as_vector(d.unwrapped); +} + +int get_num_replica_dims(ParallelTensorDims const &d) { + return count(d.unwrapped, is_replica_dim); +} + +bool is_valid(ParallelTensorDims const &dims) { + return all_of(dims.unwrapped, [](ParallelDim const &d) { return is_valid(d); }); +} + +} diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dims_t.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dims_t.cc new file mode 100644 index 0000000000..14ac950e69 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dims_t.cc @@ -0,0 +1,66 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/parallel_tensor_dims_t.struct.toml + +#include "op-attrs/parallel_tensor_dims_t.h" + +namespace FlexFlow { +ParallelTensorDims::ParallelTensorDims( + ::FlexFlow::FFOrdered<::FlexFlow::ParallelDim> const &unwrapped) + : unwrapped(unwrapped) {} +bool ParallelTensorDims::operator==(ParallelTensorDims const &other) const { + return std::tie(this->unwrapped) == std::tie(other.unwrapped); +} +bool ParallelTensorDims::operator!=(ParallelTensorDims const &other) const { + return std::tie(this->unwrapped) != std::tie(other.unwrapped); +} +bool ParallelTensorDims::operator<(ParallelTensorDims const &other) const { + return std::tie(this->unwrapped) < std::tie(other.unwrapped); +} +bool ParallelTensorDims::operator>(ParallelTensorDims const &other) const { + return std::tie(this->unwrapped) > std::tie(other.unwrapped); +} +bool ParallelTensorDims::operator<=(ParallelTensorDims const &other) const { + return std::tie(this->unwrapped) <= std::tie(other.unwrapped); +} +bool ParallelTensorDims::operator>=(ParallelTensorDims const &other) const { + return std::tie(this->unwrapped) >= std::tie(other.unwrapped); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ParallelTensorDims const &x) const { + size_t result = 0; + result ^= + std::hash<::FlexFlow::FFOrdered<::FlexFlow::ParallelDim>>{}(x.unwrapped) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ParallelTensorDims + adl_serializer::from_json(json const &j) { + return {j.at("unwrapped") + .template get<::FlexFlow::FFOrdered<::FlexFlow::ParallelDim>>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ParallelTensorDims const &v) { + j["__type"] = "ParallelTensorDims"; + j["unwrapped"] = v.unwrapped; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ParallelTensorDims const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ParallelTensorDims const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape_t.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape_t.cc new file mode 100644 index 0000000000..27f14673db --- /dev/null +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape_t.cc @@ -0,0 +1,76 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/parallel_tensor_shape_t.struct.toml + +#include "op-attrs/parallel_tensor_shape_t.h" + +namespace FlexFlow { +ParallelTensorShape::ParallelTensorShape( + ::FlexFlow::ParallelTensorDims const &dims, + ::FlexFlow::DataType const &data_type) + : dims(dims), data_type(data_type) {} +bool ParallelTensorShape::operator==(ParallelTensorShape const &other) const { + return std::tie(this->dims, this->data_type) == + std::tie(other.dims, other.data_type); +} +bool ParallelTensorShape::operator!=(ParallelTensorShape const &other) const { + return std::tie(this->dims, this->data_type) != + std::tie(other.dims, other.data_type); +} +bool ParallelTensorShape::operator<(ParallelTensorShape const &other) const { + return std::tie(this->dims, this->data_type) < + std::tie(other.dims, other.data_type); +} +bool ParallelTensorShape::operator>(ParallelTensorShape const &other) const { + return std::tie(this->dims, this->data_type) > + std::tie(other.dims, other.data_type); +} +bool ParallelTensorShape::operator<=(ParallelTensorShape const &other) const { + return std::tie(this->dims, this->data_type) <= + std::tie(other.dims, other.data_type); +} +bool ParallelTensorShape::operator>=(ParallelTensorShape const &other) const { + return std::tie(this->dims, this->data_type) >= + std::tie(other.dims, other.data_type); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ParallelTensorShape const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ParallelTensorDims>{}(x.dims) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::DataType>{}(x.data_type) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ParallelTensorShape + adl_serializer::from_json(json const &j) { + return {j.at("dims").template get<::FlexFlow::ParallelTensorDims>(), + j.at("data_type").template get<::FlexFlow::DataType>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ParallelTensorShape const &v) { + j["__type"] = "ParallelTensorShape"; + j["dims"] = v.dims; + j["data_type"] = v.data_type; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ParallelTensorShape const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ParallelTensorShape const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/tensor_dims.cc b/lib/op-attrs/src/op-attrs/tensor_dims.cc new file mode 100644 index 0000000000..34d988ed97 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/tensor_dims.cc @@ -0,0 +1,13 @@ +#include "op-attrs/tensor_dims.h" + +namespace FlexFlow { + +FFOrdered const &ff_ordered(TensorDims const &dims) { + return dims.ff_ordered; +} + +size_t dim_at_idx(TensorDims const &dims, ff_dim_t idx) { + return dims.ff_ordered.at(idx); +} + +} diff --git a/lib/op-attrs/src/op-attrs/tensor_dims_t.cc b/lib/op-attrs/src/op-attrs/tensor_dims_t.cc new file mode 100644 index 0000000000..087048b965 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/tensor_dims_t.cc @@ -0,0 +1,63 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/tensor_dims_t.struct.toml + +#include "op-attrs/tensor_dims_t.h" + +namespace FlexFlow { +TensorDims::TensorDims(::FlexFlow::FFOrdered const &ff_ordered) + : ff_ordered(ff_ordered) {} +bool TensorDims::operator==(TensorDims const &other) const { + return std::tie(this->ff_ordered) == std::tie(other.ff_ordered); +} +bool TensorDims::operator!=(TensorDims const &other) const { + return std::tie(this->ff_ordered) != std::tie(other.ff_ordered); +} +bool TensorDims::operator<(TensorDims const &other) const { + return std::tie(this->ff_ordered) < std::tie(other.ff_ordered); +} +bool TensorDims::operator>(TensorDims const &other) const { + return std::tie(this->ff_ordered) > std::tie(other.ff_ordered); +} +bool TensorDims::operator<=(TensorDims const &other) const { + return std::tie(this->ff_ordered) <= std::tie(other.ff_ordered); +} +bool TensorDims::operator>=(TensorDims const &other) const { + return std::tie(this->ff_ordered) >= std::tie(other.ff_ordered); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::TensorDims const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::FFOrdered>{}(x.ff_ordered) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::TensorDims + adl_serializer::from_json(json const &j) { + return {j.at("ff_ordered").template get<::FlexFlow::FFOrdered>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::TensorDims const &v) { + j["__type"] = "TensorDims"; + j["ff_ordered"] = v.ff_ordered; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(TensorDims const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, TensorDims const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/tensor_shape_t.cc b/lib/op-attrs/src/op-attrs/tensor_shape_t.cc new file mode 100644 index 0000000000..2cea614524 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/tensor_shape_t.cc @@ -0,0 +1,75 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/tensor_shape_t.struct.toml + +#include "op-attrs/tensor_shape_t.h" + +namespace FlexFlow { +TensorShape::TensorShape(::FlexFlow::TensorDims const &dims, + ::FlexFlow::DataType const &data_type) + : dims(dims), data_type(data_type) {} +bool TensorShape::operator==(TensorShape const &other) const { + return std::tie(this->dims, this->data_type) == + std::tie(other.dims, other.data_type); +} +bool TensorShape::operator!=(TensorShape const &other) const { + return std::tie(this->dims, this->data_type) != + std::tie(other.dims, other.data_type); +} +bool TensorShape::operator<(TensorShape const &other) const { + return std::tie(this->dims, this->data_type) < + std::tie(other.dims, other.data_type); +} +bool TensorShape::operator>(TensorShape const &other) const { + return std::tie(this->dims, this->data_type) > + std::tie(other.dims, other.data_type); +} +bool TensorShape::operator<=(TensorShape const &other) const { + return std::tie(this->dims, this->data_type) <= + std::tie(other.dims, other.data_type); +} +bool TensorShape::operator>=(TensorShape const &other) const { + return std::tie(this->dims, this->data_type) >= + std::tie(other.dims, other.data_type); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::TensorShape const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::TensorDims>{}(x.dims) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::DataType>{}(x.data_type) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::TensorShape + adl_serializer::from_json(json const &j) { + return {j.at("dims").template get<::FlexFlow::TensorDims>(), + j.at("data_type").template get<::FlexFlow::DataType>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::TensorShape const &v) { + j["__type"] = "TensorShape"; + j["dims"] = v.dims; + j["data_type"] = v.data_type; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(TensorShape const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, TensorShape const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/parallel_tensor_shape.cc b/lib/op-attrs/src/parallel_tensor_shape.cc index e226c38eac..ccdf1d84d0 100644 --- a/lib/op-attrs/src/parallel_tensor_shape.cc +++ b/lib/op-attrs/src/parallel_tensor_shape.cc @@ -4,27 +4,12 @@ namespace FlexFlow { -int ParallelTensorShape::num_dims() const { - return dims.num_dims(); +int num_dims(ParallelTensorShape const &s) { + return num_dims(s.dims); } -static std::vector lift_dims(TensorDims const &dims) { - std::vector lifted_dims; - for (size_t dim_size : dims) { - lifted_dims.push_back({dim_size, 1, false}); - } - lifted_dims.push_back({1, 1, true}); - return lifted_dims; -} - -ParallelTensorDims::ParallelTensorDims(TensorDims const &dims) - : data(lift_dims(dims)) {} - -ParallelTensorShape::ParallelTensorShape(TensorShape const &tensor_shape) - : dims(tensor_shape.dims), data_type(tensor_shape.data_type) {} - int get_num_replica_dims(ParallelTensorShape const &shape) { - return count(shape.dims, is_replica_dim); + return get_num_replica_dims(shape.dims); } int get_num_replicas(ParallelTensorShape const &shape) { @@ -33,62 +18,12 @@ int get_num_replicas(ParallelTensorShape const &shape) { [](ParallelDim const &d) -> int { return d.degree; })); } -bool is_valid(ParallelTensorDims const &dims) { - return all_of(dims, [](ParallelDim const &d) { return is_valid(d); }); -} - bool is_valid(ParallelTensorShape const &shape) { return is_valid(shape.dims); } -ParallelTensorDims::iterator ParallelTensorDims::begin() { - return data.begin(); -} - -ParallelTensorDims::const_iterator ParallelTensorDims::begin() const { - return data.begin(); -} - -ParallelTensorDims::const_iterator ParallelTensorDims::cbegin() const { - return data.cbegin(); -} - -ParallelTensorDims::iterator ParallelTensorDims::end() { - return data.end(); -} - -ParallelTensorDims::const_iterator ParallelTensorDims::end() const { - return data.end(); -} - -ParallelTensorDims::const_iterator ParallelTensorDims::cend() const { - return data.cend(); -} - -ParallelDim const &ParallelTensorDims::at(ff_dim_t const &d) const { - return data.at(d); -} - -ParallelDim &ParallelTensorDims::at(ff_dim_t const &d) { - return data.at(d); -} - -size_t ParallelTensorDims::num_dims() const { - return data.size(); -} - -ParallelDim const &ParallelTensorShape::at(ff_dim_t const &d) const { - return dims.at(d); -} - -ParallelDim &ParallelTensorShape::at(ff_dim_t const &d) { - return dims.at(d); -} -ParallelDim const &ParallelTensorShape::operator[](ff_dim_t const &d) const { - return dims.at(d); -} -ParallelDim &ParallelTensorShape::operator[](ff_dim_t const &d) { - return dims.at(d); +ParallelDim dim_at_idx(ParallelTensorShape const &s, ff_dim_t d) { + return dim_at_idx(s.dims, d); } TensorShape get_tensor_shape_unsafe(ParallelTensorShape const &) { diff --git a/lib/op-attrs/src/tensor_shape.cc b/lib/op-attrs/src/tensor_shape.cc index e456b31e3c..6e41f9175a 100644 --- a/lib/op-attrs/src/tensor_shape.cc +++ b/lib/op-attrs/src/tensor_shape.cc @@ -2,12 +2,8 @@ namespace FlexFlow { -size_t TensorShape::at(ff_dim_t d) const { - return dims.at(d); -} - -size_t TensorShape::operator[](ff_dim_t d) const { - return dims[d]; +size_t dim_at_idx(TensorShape const &s, ff_dim_t idx) { + return dim_at_idx(s.dims, idx); } } // namespace FlexFlow diff --git a/lib/substitutions/src/graph_pattern.cc b/lib/substitutions/src/graph_pattern.cc index 296a975626..73f7b2c62d 100644 --- a/lib/substitutions/src/graph_pattern.cc +++ b/lib/substitutions/src/graph_pattern.cc @@ -103,14 +103,14 @@ struct EvaluateTensorAttributeExpr { switch (key) { case TensorAttributeKey::DIM_SIZES: { std::vector result; - for (ParallelDim const &dim : this->tensor_shape.dims) { + for (ParallelDim const &dim : ff_ordered(this->tensor_shape.dims)) { result.push_back(dim.size); } return result; } case TensorAttributeKey::DIM_DEGREES: { std::vector result; - for (ParallelDim const &dim : this->tensor_shape.dims) { + for (ParallelDim const &dim : ff_ordered(this->tensor_shape.dims)) { result.push_back(dim.degree); } return result; diff --git a/lib/substitutions/test/src/test_substitution.cc b/lib/substitutions/test/src/test_substitution.cc index df22d8a620..75d7e6dbcc 100644 --- a/lib/substitutions/test/src/test_substitution.cc +++ b/lib/substitutions/test/src/test_substitution.cc @@ -100,8 +100,12 @@ TEST_SUITE(FF_TEST_SUITE) { MultiDiEdge e4{n5, p5, n4, p4}; pcg.add_edge(e4); + ParallelDim dim = {2, 1, false}; + ParallelTensorDims dims = { + FFOrdered{dim} + }; pcg.add_label(e4, - ParallelTensor(ParallelTensorDims({2, 1}), + ParallelTensor(dims, DataType::FLOAT, CreateGrad::YES)); diff --git a/lib/utils/include/utils/json.h b/lib/utils/include/utils/json.h index 010943a9f9..1bf86f0cf7 100644 --- a/lib/utils/include/utils/json.h +++ b/lib/utils/include/utils/json.h @@ -150,33 +150,47 @@ struct VariantToJsonFunctor { template void variant_to_json(json &j, std::variant const &v) { - visit(::FlexFlow::VariantToJsonFunctor{j}, v.value); + json jval; + visit(::FlexFlow::VariantToJsonFunctor{jval}, v); + j["value"] = jval; + j["index"] = v.index(); } -template -struct VariantFromJsonFunctor { - VariantFromJsonFunctor(json const &j) : j(j) {} +template +std::optional variant_from_json_impl(json const &j) { + using Type = typename std::variant_alternative::type; - json const &j; - - template - std::optional - operator()(std::integral_constant const &) const { - using Type = typename std::variant_alternative::type; + if (j.at("index").get() == Idx) { + return j.at("value").get(); + } + return std::nullopt; +} - if (visit_struct::get_name()) { - return j.at("value").get(); +template +std::optional variant_from_json_impl(json const &j, + std::index_sequence) { + // If there were no errors when parsing, all but one element of the array + // will be nullopt. This is because each call to variant_from_json_impl will + // have a unique index and exactly one of them will match the index in the + // json object. + std::array, sizeof...(Is)> results{ + variant_from_json_impl(j)...}; + for (std::optional &maybe : results) { + if (maybe) { + return maybe.value(); } } -}; + return std::nullopt; +} template std::variant variant_from_json(json const &j) { - ::FlexFlow::VariantFromJsonFunctor> func(j); - auto result = seq_map(func, seq_enumerate_args_t{}); + using Variant = std::variant; + std::optional result = variant_from_json_impl( + j, std::make_index_sequence()); if (!result.has_value()) { throw ::FlexFlow::mk_runtime_error("Invalid type {} found in json", - j.at("type").get()); + j.at("index").get()); } return result.value(); } From 4873df65537150934df4035afa9988dd57141742 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Thu, 11 Apr 2024 00:25:14 -0700 Subject: [PATCH 09/43] Re-pass tests --- lib/compiler/test/src/test_optimal_cost.cc | 6 +- .../include/kernels/element_binary_kernels.h | 2 +- .../src/cuda/element_binary_kernels.cu | 50 +- lib/kernels/src/device.h | 2 +- .../src/hip/element_binary_kernels.cpp | 40 +- .../include/op-attrs/activation.enum.toml | 20 + lib/op-attrs/include/op-attrs/activation.h | 66 +- .../include/op-attrs/aggregate_op.enum.toml | 14 + lib/op-attrs/include/op-attrs/aggregate_op.h | 39 +- lib/op-attrs/include/op-attrs/op.h | 369 --------- lib/op-attrs/include/op-attrs/operator_type.h | 13 + .../op-attrs/operator_type_t.enum.toml | 95 +++ .../include/op-attrs/operator_type_t.h | 119 +++ .../op-attrs/ops/element_binary_attrs.h | 6 +- .../ops/element_binary_attrs.struct.toml | 4 +- .../op-attrs/ops/element_scalar_unary_attrs.h | 7 +- .../element_scalar_unary_attrs.struct.toml | 4 +- .../op-attrs/ops/element_unary_attrs.h | 6 +- .../ops/element_unary_attrs.struct.toml | 4 +- .../include/op-attrs/ops/reduce_attrs.h | 6 +- .../op-attrs/ops/reduce_attrs.struct.toml | 4 +- .../include/op-attrs/parallel_tensor_dims.h | 2 +- .../include/op-attrs/parallel_tensor_dims_t.h | 4 +- .../parallel_tensor_dims_t.struct.toml | 2 +- .../include/op-attrs/pool_op.enum.toml | 14 + lib/op-attrs/include/op-attrs/pool_op.h | 39 +- lib/op-attrs/include/op-attrs/tensor_dims.h | 3 + lib/op-attrs/src/get_op_type.cc | 54 +- lib/op-attrs/src/op-attrs/activation.cc | 81 ++ lib/op-attrs/src/op-attrs/aggregate_op.cc | 57 +- lib/op-attrs/src/op-attrs/operator_type.cc | 24 + lib/op-attrs/src/op-attrs/operator_type_t.cc | 715 ++++++++++++++++++ .../src/{ => op-attrs/ops}/attention.cc | 5 + .../src/op-attrs/ops/element_binary_attrs.cc | 8 +- .../ops/element_scalar_unary_attrs.cc | 8 +- .../src/op-attrs/ops/element_unary_attrs.cc | 6 +- lib/op-attrs/src/op-attrs/ops/reduce_attrs.cc | 6 +- .../src/op-attrs/parallel_tensor_dims.cc | 22 +- .../src/op-attrs/parallel_tensor_dims_t.cc | 28 +- .../{ => op-attrs}/parallel_tensor_shape.cc | 8 + lib/op-attrs/src/op-attrs/pool_op.cc | 61 +- lib/op-attrs/src/op-attrs/tensor_dims.cc | 8 + lib/op-attrs/src/op.cc | 24 - lib/pcg/src/computation_graph_builder.cc | 42 +- lib/runtime/src/parallel_op_info.h | 2 +- .../include/substitution-generator/json.h | 158 ++-- .../include/substitutions/operator_pattern.h | 2 +- lib/substitutions/src/substitution.cc | 106 +-- .../test/src/test_substitution.cc | 8 +- lib/utils/include/utils/fmt.decl.h | 9 + lib/utils/include/utils/fmt.h | 17 +- lib/utils/include/utils/stack_vector.h | 3 + 52 files changed, 1635 insertions(+), 767 deletions(-) create mode 100644 lib/op-attrs/include/op-attrs/activation.enum.toml create mode 100644 lib/op-attrs/include/op-attrs/aggregate_op.enum.toml delete mode 100644 lib/op-attrs/include/op-attrs/op.h create mode 100644 lib/op-attrs/include/op-attrs/operator_type.h create mode 100644 lib/op-attrs/include/op-attrs/operator_type_t.enum.toml create mode 100644 lib/op-attrs/include/op-attrs/operator_type_t.h create mode 100644 lib/op-attrs/include/op-attrs/pool_op.enum.toml create mode 100644 lib/op-attrs/src/op-attrs/activation.cc create mode 100644 lib/op-attrs/src/op-attrs/operator_type.cc create mode 100644 lib/op-attrs/src/op-attrs/operator_type_t.cc rename lib/op-attrs/src/{ => op-attrs/ops}/attention.cc (99%) rename lib/op-attrs/src/{ => op-attrs}/parallel_tensor_shape.cc (79%) delete mode 100644 lib/op-attrs/src/op.cc diff --git a/lib/compiler/test/src/test_optimal_cost.cc b/lib/compiler/test/src/test_optimal_cost.cc index 91c7a11888..959fa07f25 100644 --- a/lib/compiler/test/src/test_optimal_cost.cc +++ b/lib/compiler/test/src/test_optimal_cost.cc @@ -41,8 +41,12 @@ TEST_SUITE(FF_TEST_SUITE) { MultiDiEdge e{n1, pcg.add_node_port(), n0, pcg.add_node_port()}; pcg.add_edge(e); + ParallelDim dim = {2, 1, false}; + ParallelTensorDims dims = { + FFOrdered{dim} + }; pcg.add_output(e, - ParallelTensor(ParallelTensorDims({2, 1}), + ParallelTensor(dims, DataType::FLOAT, CreateGrad::YES)); diff --git a/lib/kernels/include/kernels/element_binary_kernels.h b/lib/kernels/include/kernels/element_binary_kernels.h index 838c9752c7..24eb8c94d2 100644 --- a/lib/kernels/include/kernels/element_binary_kernels.h +++ b/lib/kernels/include/kernels/element_binary_kernels.h @@ -5,7 +5,7 @@ #include "kernels/array_shape.h" #include "kernels/device.h" #include "op-attrs/datatype.h" -#include "op-attrs/op.h" +#include "op-attrs/operator_type.h" namespace FlexFlow { diff --git a/lib/kernels/src/cuda/element_binary_kernels.cu b/lib/kernels/src/cuda/element_binary_kernels.cu index b164a2e041..be06504197 100644 --- a/lib/kernels/src/cuda/element_binary_kernels.cu +++ b/lib/kernels/src/cuda/element_binary_kernels.cu @@ -18,7 +18,7 @@ #include "kernels/element_binary_kernels.h" #include "kernels/ff_handle.h" #include "op-attrs/datatype.h" -#include "op-attrs/op.h" +#include "op-attrs/operator_type.h" namespace FlexFlow { namespace Kernels { @@ -37,28 +37,28 @@ __global__ void elewise_binary_backward_kernel(size_t volume, float *rhs_grad) { CUDA_KERNEL_LOOP(i, volume) { switch (type) { - case Op::EW_ADD: { + case OperatorType::EW_ADD: { lhs_grad[i] = alpha * out_grad[i] + beta * lhs_grad[i]; rhs_grad[i] = alpha * out_grad[i] + beta * rhs_grad[i]; break; } - case Op::EW_SUB: { + case OperatorType::EW_SUB: { lhs_grad[i] = alpha * out_grad[i] + beta * lhs_grad[i]; rhs_grad[i] = -alpha * out_grad[i] + beta * rhs_grad[i]; break; } - case Op::EW_MUL: { + case OperatorType::EW_MUL: { lhs_grad[i] = alpha * out_grad[i] * rhs[i] + beta * lhs_grad[i]; rhs_grad[i] = alpha * out_grad[i] * lhs[i] + beta * rhs_grad[i]; break; } - case Op::EW_DIV: { + case OperatorType::EW_DIV: { lhs_grad[i] = alpha * out_grad[i] / rhs[i] + beta * lhs_grad[i]; rhs_grad[i] = -alpha * out_grad[i] * lhs[i] / (rhs[i] * rhs[i]) + beta * rhs_grad[i]; break; } - case Op::EW_MAX: { + case OperatorType::EW_MAX: { lhs_grad[i] = (lhs[i] >= rhs[i]) ? alpha * out_grad[i] + beta * lhs_grad[i] : beta * lhs_grad[i]; @@ -67,7 +67,7 @@ __global__ void elewise_binary_backward_kernel(size_t volume, : beta * rhs_grad[i]; break; } - case Op::EW_MIN: { + case OperatorType::EW_MIN: { lhs_grad[i] = (lhs[i] <= rhs[i]) ? alpha * out_grad[i] + beta * lhs_grad[i] : beta * lhs_grad[i]; @@ -103,17 +103,17 @@ ElementBinaryPerDeviceState init_kernel(PerDeviceFFHandle handle, checkCUDNN(cudnnCreateReduceTensorDescriptor(&reduceAddDesc)); switch (op_type) { - case Op::EW_ADD: - case Op::EW_SUB: + case OperatorType::EW_ADD: + case OperatorType::EW_SUB: mode = CUDNN_OP_TENSOR_ADD; break; - case Op::EW_MUL: + case OperatorType::EW_MUL: mode = CUDNN_OP_TENSOR_MUL; break; - case Op::EW_MAX: + case OperatorType::EW_MAX: mode = CUDNN_OP_TENSOR_MAX; break; - case Op::EW_MIN: + case OperatorType::EW_MIN: mode = CUDNN_OP_TENSOR_MIN; break; default: @@ -153,13 +153,13 @@ void forward_kernel(cudaStream_t stream, checkCUDNN(cudnnSetStream(handle.dnn, stream)); float alpha1 = 1.0f, alpha2 = 1.0f, beta = 0.0f; switch (op_type) { - case Op::EW_SUB: + case OperatorType::EW_SUB: alpha2 = -1.0f; break; - case Op::EW_ADD: - case Op::EW_MUL: - case Op::EW_MAX: - case Op::EW_MIN: + case OperatorType::EW_ADD: + case OperatorType::EW_MUL: + case OperatorType::EW_MAX: + case OperatorType::EW_MIN: break; default: assert(false); @@ -168,9 +168,9 @@ void forward_kernel(cudaStream_t stream, // cudnnOpTensor if (broadcast_inputLHS) { // currently only handle add and sub - assert(op_type == Op::EW_SUB || op_type == Op::EW_ADD || - op_type == Op::EW_MUL); - if (op_type == Op::EW_SUB || op_type == Op::EW_ADD) { + assert(op_type == OperatorType::EW_SUB || op_type == OperatorType::EW_ADD || + op_type == OperatorType::EW_MUL); + if (op_type == OperatorType::EW_SUB || op_type == OperatorType::EW_ADD) { // output = (beta*output + alpha1*input1) + beta*output = input1 checkCUDNN(cudnnOpTensor(handle.dnn, m.opDesc, @@ -196,7 +196,7 @@ void forward_kernel(cudaStream_t stream, &alpha1, m.outputTensor, out_ptr)); - } else if (op_type == Op::EW_MUL) { + } else if (op_type == OperatorType::EW_MUL) { checkCUDNN(cudnnSetOpTensorDescriptor(m.opDesc, CUDNN_OP_TENSOR_ADD, CUDNN_DATA_FLOAT, @@ -259,7 +259,7 @@ void backward_kernel(cudaStream_t stream, checkCUDA(cublasSetStream(handle.blas, stream)); checkCUDNN(cudnnSetStream(handle.dnn, stream)); - if (op_type == Op::EW_ADD || op_type == Op::EW_SUB) { + if (op_type == OperatorType::EW_ADD || op_type == OperatorType::EW_SUB) { float alpha = 1.0f, beta = 1.0f; if (lhs_grad_ptr != nullptr) { if (broadcast_inputLHS) { @@ -285,7 +285,7 @@ void backward_kernel(cudaStream_t stream, lhs_grad_ptr)); } } - if (op_type == Op::EW_SUB) { + if (op_type == OperatorType::EW_SUB) { alpha = -1.0f; } if (rhs_grad_ptr != nullptr) { @@ -312,7 +312,7 @@ void backward_kernel(cudaStream_t stream, rhs_grad_ptr)); } } - } else if (op_type == Op::EW_MUL) { + } else if (op_type == OperatorType::EW_MUL) { float alpha1 = 1.0f, alpha2 = 1.0f, beta = 1.0f, zero = 0.0f; if (lhs_grad_ptr != nullptr) { if (broadcast_inputLHS) { @@ -394,7 +394,7 @@ void backward_kernel(cudaStream_t stream, rhs_grad_ptr)); } } - } else if (op_type == Op::EW_MIN || op_type == Op::EW_MAX) { + } else if (op_type == OperatorType::EW_MIN || op_type == OperatorType::EW_MAX) { float alpha = 1.0f, beta = 1.0f; cudnnDataType_t dataType; int n; diff --git a/lib/kernels/src/device.h b/lib/kernels/src/device.h index 00f2888f45..173cd14557 100644 --- a/lib/kernels/src/device.h +++ b/lib/kernels/src/device.h @@ -4,7 +4,7 @@ #include "kernels/array_shape.h" #include "kernels/device.h" #include "op-attrs/datatype.h" -#include "op-attrs/op.h" +#include "op-attrs/operator_type.h" #include #if defined(FF_USE_CUDA) diff --git a/lib/kernels/src/hip/element_binary_kernels.cpp b/lib/kernels/src/hip/element_binary_kernels.cpp index 5d29c27837..c8d746847d 100644 --- a/lib/kernels/src/hip/element_binary_kernels.cpp +++ b/lib/kernels/src/hip/element_binary_kernels.cpp @@ -42,17 +42,17 @@ ElementBinaryPerDeviceState init_kernel(PerDeviceFFHandle handle, checkCUDNN(miopenCreateReduceTensorDescriptor(&reduceAddDesc)); switch (op_type) { - case Op::EW_ADD: - case Op::EW_SUB: + case OperatorType::EW_ADD: + case OperatorType::EW_SUB: mode = miopenTensorOpAdd; break; - case Op::EW_MUL: + case OperatorType::EW_MUL: mode = miopenTensorOpMul; break; - case Op::EW_MAX: + case OperatorType::EW_MAX: mode = miopenTensorOpMax; break; - case Op::EW_MIN: + case OperatorType::EW_MIN: mode = miopenTensorOpMin; break; default: @@ -90,25 +90,25 @@ __global__ void elewise_binary_forward_kernel(coord_t volume, float const *in2, float *out) { switch (type) { - case Op::EW_ADD: { + case OperatorType::EW_ADD: { CUDA_KERNEL_LOOP(i, volume) { out[i] = alpha * (in1[i] + in2[i]) + beta * out[i]; } break; } - case Op::EW_SUB: { + case OperatorType::EW_SUB: { CUDA_KERNEL_LOOP(i, volume) { out[i] = alpha * (in1[i] - in2[i]) + beta * out[i]; } break; } - case Op::EW_MUL: { + case OperatorType::EW_MUL: { CUDA_KERNEL_LOOP(i, volume) { out[i] = alpha * in1[i] * in2[i] + beta * out[i]; } break; } - case Op::EW_DIV: { + case OperatorType::EW_DIV: { CUDA_KERNEL_LOOP(i, volume) { out[i] = alpha * (in1[i] / in2[i]) + beta * out[i]; } @@ -130,22 +130,22 @@ __global__ void elewise_binary_backward_kernel(coord_t volume, float *in2_grad) { CUDA_KERNEL_LOOP(i, volume) { switch (type) { - case Op::EW_ADD: { + case OperatorType::EW_ADD: { in1_grad[i] = alpha * out_grad[i] + beta * in1_grad[i]; in2_grad[i] = alpha * out_grad[i] + beta * in2_grad[i]; break; } - case Op::EW_SUB: { + case OperatorType::EW_SUB: { in1_grad[i] = alpha * out_grad[i] + beta * in1_grad[i]; in2_grad[i] = -alpha * out_grad[i] + beta * in2_grad[i]; break; } - case Op::EW_MUL: { + case OperatorType::EW_MUL: { in1_grad[i] = alpha * out_grad[i] * in2[i] + beta * in1_grad[i]; in2_grad[i] = alpha * out_grad[i] * in1[i] + beta * in2_grad[i]; break; } - case Op::EW_DIV: { + case OperatorType::EW_DIV: { in1_grad[i] = alpha * out_grad[i] / in2[i] + beta * in1_grad[i]; in2_grad[i] = -alpha * out_grad[i] * in1[i] / (in2[i] * in2[i]) + beta * in2_grad[i]; @@ -170,11 +170,11 @@ void forward_kernel(hipStream_t stream, float alpha1 = 1.0f, alpha2 = 1.0f, beta = 0.0f; switch (op_type) { - case Op::EW_SUB: + case OperatorType::EW_SUB: alpha2 = -1.0f; break; - case Op::EW_ADD: - case Op::EW_MUL: + case OperatorType::EW_ADD: + case OperatorType::EW_MUL: break; default: assert(false); @@ -183,7 +183,7 @@ void forward_kernel(hipStream_t stream, // cudnnOpTensor if (broadcast_inputLHS) { // currently only handle add and sub - assert(op_type == Op::EW_SUB || op_type == Op::EW_ADD); + assert(op_type == OperatorType::EW_SUB || op_type == OperatorType::EW_ADD); checkCUDNN(miopenOpTensor(handle.dnn, m.opDesc, &beta, @@ -235,7 +235,7 @@ void backward_kernel(hipStream_t stream, checkCUDA(hipblasSetStream(handle.blas, stream)); checkCUDNN(miopenSetStream(handle.dnn, stream)); - if (m.op_type == Op::EW_ADD || m.op_type == Op::EW_SUB) { + if (m.op_type == OperatorType::EW_ADD || m.op_type == OperatorType::EW_SUB) { float alpha = 1.0f, alpha2 = 0.0f, beta = 1.0f; if (lhs_grad_ptr != nullptr) { if (m.broadcast_input1) { @@ -265,7 +265,7 @@ void backward_kernel(hipStream_t stream, lhs_grad_ptr)); } } - if (m.op_type == Op::EW_SUB) { + if (m.op_type == OperatorType::EW_SUB) { alpha = -1.0f; } if (rhs_grad_ptr != nullptr) { @@ -296,7 +296,7 @@ void backward_kernel(hipStream_t stream, rhs_grad_ptr)); } } - } else if (m.op_type == Op::EW_MUL) { + } else if (m.op_type == OperatorType::EW_MUL) { float alpha1 = 1.0f, alpha2 = 1.0f, beta = 1.0f; if (lhs_grad_ptr != nullptr) { checkCUDNN(miopenOpTensor(handle.dnn, diff --git a/lib/op-attrs/include/op-attrs/activation.enum.toml b/lib/op-attrs/include/op-attrs/activation.enum.toml new file mode 100644 index 0000000000..66119da9b1 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/activation.enum.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "Activation" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "RELU" + +[[values]] +name = "SIGMOID" + +[[values]] +name = "TANH" + +[[values]] +name = "GELU" diff --git a/lib/op-attrs/include/op-attrs/activation.h b/lib/op-attrs/include/op-attrs/activation.h index c3840fd9b3..4f16289652 100644 --- a/lib/op-attrs/include/op-attrs/activation.h +++ b/lib/op-attrs/include/op-attrs/activation.h @@ -1,49 +1,35 @@ -#ifndef _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_ACTIVATION_H -#define _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_ACTIVATION_H +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/activation.enum.toml +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_ACTIVATION_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_ACTIVATION_H + +#include "fmt/format.h" #include "nlohmann/json.hpp" -#include "utils/fmt.h" +#include "rapidcheck.h" +#include +#include +#include namespace FlexFlow { - enum class Activation { RELU, SIGMOID, TANH, GELU }; - -NLOHMANN_JSON_SERIALIZE_ENUM(Activation, - {{Activation::RELU, "RELU"}, - {Activation::SIGMOID, "SIGMOID"}, - {Activation::TANH, "TANH"}, - {Activation::GELU, "GELU"}}); - +std::string format_as(Activation); +std::ostream &operator<<(std::ostream &, Activation); +void to_json(::nlohmann::json &, Activation); +void from_json(::nlohmann::json const &, Activation &); } // namespace FlexFlow - -namespace fmt { - +namespace std { template <> -struct formatter<::FlexFlow::Activation> : formatter { - template - auto format(::FlexFlow::Activation a, FormatContext &ctx) const - -> decltype(ctx.out()) { - using namespace FlexFlow; - - string_view name = "unknown"; - switch (a) { - case Activation::RELU: - name = "ReLU"; - break; - case Activation::SIGMOID: - name = "Sigmoid"; - break; - case Activation::TANH: - name = "Tanh"; - break; - case Activation::GELU: - name = "GeLU"; - break; - } - return formatter::format(name, ctx); - } +struct hash { + size_t operator()(FlexFlow::Activation) const; }; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc -} // namespace fmt - -#endif +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_ACTIVATION_H diff --git a/lib/op-attrs/include/op-attrs/aggregate_op.enum.toml b/lib/op-attrs/include/op-attrs/aggregate_op.enum.toml new file mode 100644 index 0000000000..27aa50f38f --- /dev/null +++ b/lib/op-attrs/include/op-attrs/aggregate_op.enum.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "AggregateOp" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "SUM" + +[[value]] +name = "AVG" diff --git a/lib/op-attrs/include/op-attrs/aggregate_op.h b/lib/op-attrs/include/op-attrs/aggregate_op.h index eb8718533c..e2c7177d9f 100644 --- a/lib/op-attrs/include/op-attrs/aggregate_op.h +++ b/lib/op-attrs/include/op-attrs/aggregate_op.h @@ -1,22 +1,35 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/aggregate_op.enum.toml + #ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_AGGREGATE_OP_H #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_AGGREGATE_OP_H +#include "fmt/format.h" #include "nlohmann/json.hpp" -#include "utils/fmt.h" +#include "rapidcheck.h" +#include +#include +#include namespace FlexFlow { - -enum class AggregateOp { - SUM, - AVG, -}; - -NLOHMANN_JSON_SERIALIZE_ENUM(AggregateOp, - {{AggregateOp::SUM, "SUM"}, - {AggregateOp::AVG, "AVG"}}); - +enum class AggregateOp { SUM }; std::string format_as(AggregateOp); - +std::ostream &operator<<(std::ostream &, AggregateOp); +void to_json(::nlohmann::json &, AggregateOp); +void from_json(::nlohmann::json const &, AggregateOp &); } // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::AggregateOp) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc -#endif +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_AGGREGATE_OP_H diff --git a/lib/op-attrs/include/op-attrs/op.h b/lib/op-attrs/include/op-attrs/op.h deleted file mode 100644 index 9ad83c3641..0000000000 --- a/lib/op-attrs/include/op-attrs/op.h +++ /dev/null @@ -1,369 +0,0 @@ -#ifndef _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_OP_H -#define _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_OP_H - -#include "utils/fmt.h" - -namespace FlexFlow { - -enum class Op { - NOOP, - INPUT, - WEIGHT, - CONV2D, - DROPOUT, - LINEAR, - BATCHMATMUL, - POOL2D, - SCALAR_MULTIPLY, - SCALAR_ADD, - SCALAR_FLOOR_DIV, - SCALAR_TRUE_DIV, - SCALAR_SUB, - RELU, - IDENTITY, - SIGMOID, - TANH, - ELU, - FLAT, - SOFTMAX, - BATCHNORM, - CONCAT, - SPLIT, - EMBEDDING, - CACHE, - // OP_ELEMENTWISE, - RESHAPE, - REVERSE, - TRANSPOSE, - EW_ADD, - EW_MUL, - MATMUL, - MUL, - ENLARGE, - SQUEEZE, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Squeeze - UNSQUEEZE, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Unsqueeze - EW_SUB, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Sub - EW_DIV, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Div - EW_EQUAL, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Equal - EW_GREATER, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Greater - EW_LESS, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Less - EW_MAX, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Max - EW_MIN, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Min - REDUCE_ARGMAX, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#ArgMax - REDUCE_ARGMIN, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#ArgMin - REDUCE_MAX, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#ReduceMax - REDUCE_MEAN, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#ReduceMean - REDUCE_MIN, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#ReduceMin - REDUCE_PROD, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#ReduceProd - REDUCE_SUM, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#ReduceSum - PAD, // https://github.com/dmlc/tvm/blob/master/topi/python/topi/nn/pad.py - SHAPE, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Shape - SIZE, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Size - TOPK, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#TopK - WHERE, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Where - CEIL, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Ceil - CAST, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Cast - EXP, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Exp - ROUND, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Round - LOG, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Log - LOGICAL_NOT, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Not - SQRT, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Sqrt - SIN, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Sin - COS, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Cos - LEAKYRELU, - SLICE, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Slice - RESIZE, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Resize - PRELU, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#PRelu - GELU, - MULTIHEAD_ATTENTION, - FUSED, // Fused operator type for internal fusion optimizations - RSQRT, // https://pytorch.org/docs/stable/generated/torch.rsqrt.html - POW, // https://pytorch.org/docs/stable/generated/torch.pow.html - MEAN, // https://pytorch.org/docs/stable/generated/torch.mean.html - LAYERNORM, - GATHER, // https://pytorch.org/docs/stable/generated/torch.gather.html - BROADCAST, - // Parallel Ops - REPARTITION, - COMBINE, - REPLICATE, - REDUCTION, - BATCH, - PIPELINE, - FUSED_PARALLEL, -}; - -using OperatorType = Op; - -std::string get_operator_type_name(Op op); - -} // namespace FlexFlow - -namespace fmt { - -template <> -struct formatter<::FlexFlow::Op> : formatter { - template - auto format(::FlexFlow::Op ot, FormatContext &ctx) -> decltype(ctx.out()) { - using namespace FlexFlow; - - string_view name = "unknown"; - switch (ot) { - case Op::CONV2D: - name = "Conv2D"; - break; - case Op::DROPOUT: - name = "Dropout"; - break; - case Op::LINEAR: - name = "Dense"; - break; - case Op::BATCHMATMUL: - name = "BatchMatMul"; - break; - case Op::POOL2D: - name = "Pool2D"; - break; - case Op::SCALAR_MULTIPLY: - name = "ScalarMultiply"; - break; - case Op::SCALAR_ADD: - name = "ScalarAdd"; - break; - case Op::SCALAR_FLOOR_DIV: - name = "ScalarFloorDiv"; - break; - case Op::SCALAR_TRUE_DIV: - name = "ScalarTrueDiv"; - break; - case Op::SCALAR_SUB: - name = "ScalarSub"; - break; - case Op::RELU: - name = "ReLU"; - break; - case Op::SIGMOID: - name = "Sigmoid"; - break; - case Op::TANH: - name = "Tanh"; - break; - case Op::ELU: - name = "Elu"; - break; - case Op::FLAT: - name = "Flat"; - break; - case Op::SOFTMAX: - name = "Softmax"; - break; - case Op::BATCHNORM: - name = "BatchNorm"; - break; - case Op::CONCAT: - name = "Concat"; - break; - case Op::SPLIT: - name = "Split"; - break; - case Op::EMBEDDING: - name = "Embedding"; - break; - case Op::GATHER: - name = "Gather"; - break; - case Op::CACHE: - name = "Cache"; - break; - case Op::RESHAPE: - name = "Reshape"; - break; - case Op::REVERSE: - name = "Reverse"; - break; - case Op::TRANSPOSE: - name = "Transpose"; - break; - case Op::EW_ADD: - name = "Add"; - break; - case Op::EW_MUL: - name = "Mul"; - break; - case Op::MATMUL: - name = "Matmul"; - break; - case Op::MUL: - name = "Mul"; - break; - case Op::ENLARGE: - name = "Enlarge"; - break; - case Op::SQUEEZE: - name = "Squeeze"; - break; - case Op::UNSQUEEZE: - name = "Unsqueeze"; - break; - case Op::EW_SUB: - name = "Sub"; - break; - case Op::EW_DIV: - name = "Div"; - break; - case Op::EW_EQUAL: - name = "Equal"; - break; - case Op::EW_GREATER: - name = "Greater"; - break; - case Op::EW_LESS: - name = "Less"; - break; - case Op::EW_MAX: - name = "Max"; - break; - case Op::EW_MIN: - name = "Min"; - break; - case Op::REDUCE_ARGMAX: - name = "ReduceArgMax"; - break; - case Op::REDUCE_ARGMIN: - name = "ReduceArgMin"; - break; - case Op::REDUCE_MAX: - name = "ReduceMax"; - break; - case Op::REDUCE_MEAN: - name = "ReduceMean"; - break; - case Op::REDUCE_MIN: - name = "ReduceMin"; - break; - case Op::REDUCE_PROD: - name = "ReduceProd"; - break; - case Op::REDUCE_SUM: - name = "ReduceSum"; - break; - case Op::PAD: - name = "Pad"; - break; - case Op::SHAPE: - name = "Shape"; - break; - case Op::SIZE: - name = "Size"; - break; - case Op::TOPK: - name = "TopK"; - break; - case Op::WHERE: - name = "Where"; - break; - case Op::CEIL: - name = "Ceil"; - break; - case Op::CAST: - name = "Cast"; - break; - case Op::EXP: - name = "Exp"; - break; - case Op::SIN: - name = "Sin"; - break; - case Op::COS: - name = "Cos"; - break; - case Op::ROUND: - name = "Round"; - break; - case Op::LOG: - name = "Log"; - break; - case Op::LOGICAL_NOT: - name = "LogicalNot"; - break; - case Op::SQRT: - name = "Sqrt"; - break; - case Op::LEAKYRELU: - name = "LeakyReLU"; - break; - case Op::SLICE: - name = "Slice"; - break; - case Op::RESIZE: - name = "Resize"; - break; - case Op::PRELU: - name = "PReLU"; - break; - case Op::MULTIHEAD_ATTENTION: - name = "MultiHeadAttention"; - break; - case Op::INPUT: - name = "Input"; - break; - case Op::WEIGHT: - name = "Weight"; - break; - case Op::NOOP: - name = "NoOp"; - break; - case Op::FUSED: - name = "FusedOp"; - break; - case Op::RSQRT: - name = "Rsqrt"; - break; - case Op::POW: - name = "Pow"; - break; - case Op::MEAN: - name = "Mean"; - break; - case Op::LAYERNORM: - name = "LayerNorm"; - break; - case Op::IDENTITY: - name = "Identity"; - break; - // Parallel Ops - case Op::REPARTITION: - name = "Repartition"; - break; - case Op::COMBINE: - name = "Combine"; - break; - case Op::REPLICATE: - name = "Replicate"; - break; - case Op::REDUCTION: - name = "Reduction"; - break; - case Op::PIPELINE: - name = "Pipeline"; - break; - case Op::FUSED_PARALLEL: - name = "FusedParallelOp"; - break; - case Op::GELU: - name = "GeLU"; - break; - case Op::BROADCAST: - name = "Broadcast"; - break; - case Op::BATCH: - name = "Batch"; - break; - } - return formatter::format(name, ctx); - } -}; - -} // namespace fmt - -#endif diff --git a/lib/op-attrs/include/op-attrs/operator_type.h b/lib/op-attrs/include/op-attrs/operator_type.h new file mode 100644 index 0000000000..ef7172eaa2 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/operator_type.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_TYPE_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_TYPE_H + +#include "op-attrs/operator_type_t.h" + +namespace FlexFlow { + +std::string get_operator_type_name(OperatorType); +bool is_parallel_op(OperatorType); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/operator_type_t.enum.toml b/lib/op-attrs/include/op-attrs/operator_type_t.enum.toml new file mode 100644 index 0000000000..8815d69dda --- /dev/null +++ b/lib/op-attrs/include/op-attrs/operator_type_t.enum.toml @@ -0,0 +1,95 @@ +namespace = "FlexFlow" +name = "OperatorType" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +values = [ + { name = "NOOP" }, + { name = "INPUT" }, + { name = "WEIGHT" }, + { name = "CONV2D" }, + { name = "DROPOUT" }, + { name = "LINEAR" }, + { name = "BATCHMATMUL" }, + { name = "POOL2D" }, + { name = "SCALAR_MULTIPLY" }, + { name = "SCALAR_ADD" }, + { name = "SCALAR_FLOOR_DIV" }, + { name = "SCALAR_TRUE_DIV" }, + { name = "SCALAR_SUB" }, + { name = "RELU" }, + { name = "IDENTITY" }, + { name = "SIGMOID" }, + { name = "TANH" }, + { name = "ELU" }, + { name = "FLAT" }, + { name = "SOFTMAX" }, + { name = "BATCHNORM" }, + { name = "CONCAT" }, + { name = "SPLIT" }, + { name = "EMBEDDING" }, + { name = "CACHE" }, + { name = "RESHAPE" }, + { name = "REVERSE" }, + { name = "TRANSPOSE" }, + { name = "EW_ADD" }, + { name = "EW_MUL" }, + { name = "MATMUL" }, + { name = "MUL" }, + { name = "ENLARGE" }, + { name = "SQUEEZE" }, + { name = "UNSQUEEZE" }, + { name = "EW_SUB" }, + { name = "EW_DIV" }, + { name = "EW_EQUAL" }, + { name = "EW_GREATER" }, + { name = "EW_LESS" }, + { name = "EW_MAX" }, + { name = "EW_MIN" }, + { name = "REDUCE_ARGMAX" }, + { name = "REDUCE_ARGMIN" }, + { name = "REDUCE_MAX" }, + { name = "REDUCE_MEAN" }, + { name = "REDUCE_MIN" }, + { name = "REDUCE_PROD" }, + { name = "REDUCE_SUM" }, + { name = "PAD" }, + { name = "SHAPE" }, + { name = "SIZE" }, + { name = "TOPK" }, + { name = "WHERE" }, + { name = "CEIL" }, + { name = "CAST" }, + { name = "EXP" }, + { name = "ROUND" }, + { name = "LOG" }, + { name = "LOGICAL_NOT" }, + { name = "SQRT" }, + { name = "SIN" }, + { name = "COS" }, + { name = "LEAKYRELU" }, + { name = "SLICE" }, + { name = "RESIZE" }, + { name = "PRELU" }, + { name = "GELU" }, + { name = "MULTIHEAD_ATTENTION" }, + { name = "FUSED" }, + { name = "RSQRT" }, + { name = "POW" }, + { name = "MEAN" }, + { name = "LAYERNORM" }, + { name = "GATHER" }, + { name = "BROADCAST" }, + { name = "REPARTITION" }, + { name = "COMBINE" }, + { name = "REPLICATE" }, + { name = "REDUCTION" }, + { name = "BATCH" }, + { name = "PIPELINE" }, + { name = "FUSED_PARALLEL" }, +] + diff --git a/lib/op-attrs/include/op-attrs/operator_type_t.h b/lib/op-attrs/include/op-attrs/operator_type_t.h new file mode 100644 index 0000000000..170dcc65c4 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/operator_type_t.h @@ -0,0 +1,119 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/operator_type_t.enum.toml + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_TYPE_T_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_TYPE_T_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +enum class OperatorType { + NOOP, + INPUT, + WEIGHT, + CONV2D, + DROPOUT, + LINEAR, + BATCHMATMUL, + POOL2D, + SCALAR_MULTIPLY, + SCALAR_ADD, + SCALAR_FLOOR_DIV, + SCALAR_TRUE_DIV, + SCALAR_SUB, + RELU, + IDENTITY, + SIGMOID, + TANH, + ELU, + FLAT, + SOFTMAX, + BATCHNORM, + CONCAT, + SPLIT, + EMBEDDING, + CACHE, + RESHAPE, + REVERSE, + TRANSPOSE, + EW_ADD, + EW_MUL, + MATMUL, + MUL, + ENLARGE, + SQUEEZE, + UNSQUEEZE, + EW_SUB, + EW_DIV, + EW_EQUAL, + EW_GREATER, + EW_LESS, + EW_MAX, + EW_MIN, + REDUCE_ARGMAX, + REDUCE_ARGMIN, + REDUCE_MAX, + REDUCE_MEAN, + REDUCE_MIN, + REDUCE_PROD, + REDUCE_SUM, + PAD, + SHAPE, + SIZE, + TOPK, + WHERE, + CEIL, + CAST, + EXP, + ROUND, + LOG, + LOGICAL_NOT, + SQRT, + SIN, + COS, + LEAKYRELU, + SLICE, + RESIZE, + PRELU, + GELU, + MULTIHEAD_ATTENTION, + FUSED, + RSQRT, + POW, + MEAN, + LAYERNORM, + GATHER, + BROADCAST, + REPARTITION, + COMBINE, + REPLICATE, + REDUCTION, + BATCH, + PIPELINE, + FUSED_PARALLEL +}; +std::string format_as(OperatorType); +std::ostream &operator<<(std::ostream &, OperatorType); +void to_json(::nlohmann::json &, OperatorType); +void from_json(::nlohmann::json const &, OperatorType &); +} // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::OperatorType) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_TYPE_T_H diff --git a/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.h b/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.h index df25e3d4d8..a3532a41cf 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.h +++ b/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.h @@ -8,7 +8,7 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" #include "op-attrs/datatype.h" -#include "op-attrs/op.h" +#include "op-attrs/operator_type.h" #include #include #include @@ -17,7 +17,7 @@ namespace FlexFlow { struct ElementBinaryAttrs { ElementBinaryAttrs() = delete; - ElementBinaryAttrs(::FlexFlow::Op const &type, + ElementBinaryAttrs(::FlexFlow::OperatorType const &type, ::FlexFlow::DataType const &compute_type, bool const &should_broadcast_lhs, bool const &should_broadcast_rhs); @@ -28,7 +28,7 @@ struct ElementBinaryAttrs { bool operator>(ElementBinaryAttrs const &) const; bool operator<=(ElementBinaryAttrs const &) const; bool operator>=(ElementBinaryAttrs const &) const; - ::FlexFlow::Op type; + ::FlexFlow::OperatorType type; ::FlexFlow::DataType compute_type; bool should_broadcast_lhs; bool should_broadcast_rhs; diff --git a/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.struct.toml index e09f1d551b..9479cb2956 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.struct.toml @@ -11,13 +11,13 @@ features = [ ] includes = [ - "op-attrs/op.h", + "op-attrs/operator_type.h", "op-attrs/datatype.h", ] [[fields]] name = "type" -type = "::FlexFlow::Op" +type = "::FlexFlow::OperatorType" [[fields]] name = "compute_type" diff --git a/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.h b/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.h index 445b2b7849..7eb369111a 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.h +++ b/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.h @@ -7,7 +7,7 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" -#include "op-attrs/op.h" +#include "op-attrs/operator_type.h" #include #include #include @@ -16,7 +16,8 @@ namespace FlexFlow { struct ElementScalarUnaryAttrs { ElementScalarUnaryAttrs() = delete; - ElementScalarUnaryAttrs(::FlexFlow::Op const &op_type, float const &scalar); + ElementScalarUnaryAttrs(::FlexFlow::OperatorType const &op_type, + float const &scalar); bool operator==(ElementScalarUnaryAttrs const &) const; bool operator!=(ElementScalarUnaryAttrs const &) const; @@ -24,7 +25,7 @@ struct ElementScalarUnaryAttrs { bool operator>(ElementScalarUnaryAttrs const &) const; bool operator<=(ElementScalarUnaryAttrs const &) const; bool operator>=(ElementScalarUnaryAttrs const &) const; - ::FlexFlow::Op op_type; + ::FlexFlow::OperatorType op_type; float scalar; }; } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.struct.toml index 3f20ea7a51..2f406a67d5 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.struct.toml @@ -11,12 +11,12 @@ features = [ ] includes = [ - "op-attrs/op.h" + "op-attrs/operator_type.h" ] [[fields]] name = "op_type" -type = "::FlexFlow::Op" +type = "::FlexFlow::OperatorType" [[fields]] name = "scalar" diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.h b/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.h index 163e2824cb..dfa2aa30e8 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.h +++ b/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.h @@ -7,7 +7,7 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" -#include "op-attrs/op.h" +#include "op-attrs/operator_type.h" #include #include #include @@ -16,7 +16,7 @@ namespace FlexFlow { struct ElementUnaryAttrs { ElementUnaryAttrs() = delete; - ElementUnaryAttrs(::FlexFlow::Op const &op_type); + ElementUnaryAttrs(::FlexFlow::OperatorType const &op_type); bool operator==(ElementUnaryAttrs const &) const; bool operator!=(ElementUnaryAttrs const &) const; @@ -24,7 +24,7 @@ struct ElementUnaryAttrs { bool operator>(ElementUnaryAttrs const &) const; bool operator<=(ElementUnaryAttrs const &) const; bool operator>=(ElementUnaryAttrs const &) const; - ::FlexFlow::Op op_type; + ::FlexFlow::OperatorType op_type; }; } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml index 60ec81cd66..fad251d181 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml @@ -11,9 +11,9 @@ features = [ ] includes = [ - "op-attrs/op.h" + "op-attrs/operator_type.h" ] [[fields]] name = "op_type" -type = "::FlexFlow::Op" +type = "::FlexFlow::OperatorType" diff --git a/lib/op-attrs/include/op-attrs/ops/reduce_attrs.h b/lib/op-attrs/include/op-attrs/ops/reduce_attrs.h index f1a94788a4..0bc0fc759a 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduce_attrs.h +++ b/lib/op-attrs/include/op-attrs/ops/reduce_attrs.h @@ -8,7 +8,7 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" #include "op-attrs/ff_dim.h" -#include "op-attrs/op.h" +#include "op-attrs/operator_type.h" #include "utils/stack_vector.h" #include #include @@ -20,7 +20,7 @@ struct ReduceAttrs { ReduceAttrs() = delete; ReduceAttrs(::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM> const &axes, - ::FlexFlow::Op const &op_type, + ::FlexFlow::OperatorType const &op_type, bool const &keepdims); bool operator==(ReduceAttrs const &) const; @@ -30,7 +30,7 @@ struct ReduceAttrs { bool operator<=(ReduceAttrs const &) const; bool operator>=(ReduceAttrs const &) const; ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM> axes; - ::FlexFlow::Op op_type; + ::FlexFlow::OperatorType op_type; bool keepdims; }; } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml index b5c12c80fe..9f4bc6d5aa 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml @@ -10,7 +10,7 @@ features = [ ] includes = [ - "op-attrs/op.h", + "op-attrs/operator_type.h", "op-attrs/ff_dim.h", "utils/stack_vector.h", ] @@ -21,7 +21,7 @@ type = "::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>" [[fields]] name = "op_type" -type = "::FlexFlow::Op" +type = "::FlexFlow::OperatorType" [[fields]] name = "keepdims" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h index 787938322c..92519dc09a 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h @@ -13,7 +13,7 @@ std::vector as_vector(ParallelTensorDims const &); int get_num_replica_dims(ParallelTensorDims const &); -size_t get_volume(ParallelTensorDims const &); +/* size_t get_volume(ParallelTensorDims const &); */ size_t num_dims(ParallelTensorDims const &); ParallelDim dim_at_idx(ParallelTensorDims const &, ff_dim_t); diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims_t.h b/lib/op-attrs/include/op-attrs/parallel_tensor_dims_t.h index 54fb55b4c0..dd5d53d9cf 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dims_t.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims_t.h @@ -18,7 +18,7 @@ namespace FlexFlow { struct ParallelTensorDims { ParallelTensorDims() = delete; ParallelTensorDims( - ::FlexFlow::FFOrdered<::FlexFlow::ParallelDim> const &unwrapped); + ::FlexFlow::FFOrdered<::FlexFlow::ParallelDim> const &ff_ordered); bool operator==(ParallelTensorDims const &) const; bool operator!=(ParallelTensorDims const &) const; @@ -26,7 +26,7 @@ struct ParallelTensorDims { bool operator>(ParallelTensorDims const &) const; bool operator<=(ParallelTensorDims const &) const; bool operator>=(ParallelTensorDims const &) const; - ::FlexFlow::FFOrdered<::FlexFlow::ParallelDim> unwrapped; + ::FlexFlow::FFOrdered<::FlexFlow::ParallelDim> ff_ordered; }; } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims_t.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_dims_t.struct.toml index 660b6f95f4..09c5b5ff4f 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dims_t.struct.toml +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims_t.struct.toml @@ -15,5 +15,5 @@ includes = [ ] [[fields]] -name = "unwrapped" +name = "ff_ordered" type = "::FlexFlow::FFOrdered<::FlexFlow::ParallelDim>" diff --git a/lib/op-attrs/include/op-attrs/pool_op.enum.toml b/lib/op-attrs/include/op-attrs/pool_op.enum.toml new file mode 100644 index 0000000000..88f4dfea19 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/pool_op.enum.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "PoolOp" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "MAX" + +[[values]] +name = "AVG" diff --git a/lib/op-attrs/include/op-attrs/pool_op.h b/lib/op-attrs/include/op-attrs/pool_op.h index eae02a1a3b..00c7852bbf 100644 --- a/lib/op-attrs/include/op-attrs/pool_op.h +++ b/lib/op-attrs/include/op-attrs/pool_op.h @@ -1,20 +1,35 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/pool_op.enum.toml + #ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_POOL_OP_H #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_POOL_OP_H +#include "fmt/format.h" #include "nlohmann/json.hpp" -#include "utils/fmt.h" +#include "rapidcheck.h" +#include +#include +#include namespace FlexFlow { - -enum class PoolOp { - MAX, - AVG, -}; - -NLOHMANN_JSON_SERIALIZE_ENUM(PoolOp, - {{PoolOp::MAX, "MAX"}, {PoolOp::AVG, "AVG"}}); - +enum class PoolOp { MAX, AVG }; std::string format_as(PoolOp); - +std::ostream &operator<<(std::ostream &, PoolOp); +void to_json(::nlohmann::json &, PoolOp); +void from_json(::nlohmann::json const &, PoolOp &); } // namespace FlexFlow -#endif +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::PoolOp) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_POOL_OP_H diff --git a/lib/op-attrs/include/op-attrs/tensor_dims.h b/lib/op-attrs/include/op-attrs/tensor_dims.h index 952d300aff..ec6c208331 100644 --- a/lib/op-attrs/include/op-attrs/tensor_dims.h +++ b/lib/op-attrs/include/op-attrs/tensor_dims.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_DIMS_H #include "op-attrs/tensor_dims_t.h" +#include "op-attrs/parallel_tensor_dims_t.h" namespace FlexFlow { @@ -9,6 +10,8 @@ FFOrdered const &ff_ordered(TensorDims const &); size_t dim_at_idx(TensorDims const &, ff_dim_t); +ParallelTensorDims lift_to_parallel(TensorDims const &); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/src/get_op_type.cc b/lib/op-attrs/src/get_op_type.cc index 3fa401b647..2fb539472f 100644 --- a/lib/op-attrs/src/get_op_type.cc +++ b/lib/op-attrs/src/get_op_type.cc @@ -3,25 +3,25 @@ namespace FlexFlow { OperatorType get_op_type(BatchMatmulAttrs const &) { - return Op::BATCHMATMUL; + return OperatorType::BATCHMATMUL; } OperatorType get_op_type(BatchNormAttrs const &) { - return Op::BATCHNORM; + return OperatorType::BATCHNORM; } OperatorType get_op_type(BroadcastAttrs const &) { - return Op::BROADCAST; + return OperatorType::BROADCAST; } OperatorType get_op_type(CastAttrs const &) { - return Op::CAST; + return OperatorType::CAST; } OperatorType get_op_type(ConcatAttrs const &) { - return Op::CONCAT; + return OperatorType::CONCAT; } OperatorType get_op_type(Conv2DAttrs const &) { - return Op::CONV2D; + return OperatorType::CONV2D; } OperatorType get_op_type(DropoutAttrs const &) { - return Op::DROPOUT; + return OperatorType::DROPOUT; } OperatorType get_op_type(ElementBinaryAttrs const &attrs) { return attrs.type; @@ -33,64 +33,64 @@ OperatorType get_op_type(ElementScalarUnaryAttrs const &attrs) { return attrs.op_type; } OperatorType get_op_type(EmbeddingAttrs const &) { - return Op::EMBEDDING; + return OperatorType::EMBEDDING; } OperatorType get_op_type(FlatAttrs const &) { - return Op::FLAT; + return OperatorType::FLAT; } OperatorType get_op_type(GatherAttrs const &) { - return Op::GATHER; + return OperatorType::GATHER; } OperatorType get_op_type(InputAttrs const &) { - return Op::INPUT; + return OperatorType::INPUT; } OperatorType get_op_type(LayerNormAttrs const &) { - return Op::LAYERNORM; + return OperatorType::LAYERNORM; } OperatorType get_op_type(LinearAttrs const &) { - return Op::LINEAR; + return OperatorType::LINEAR; } OperatorType get_op_type(MultiHeadAttentionAttrs const &) { - return Op::MULTIHEAD_ATTENTION; + return OperatorType::MULTIHEAD_ATTENTION; } OperatorType get_op_type(NoopAttrs const &) { - return Op::NOOP; + return OperatorType::NOOP; } OperatorType get_op_type(Pool2DAttrs const &) { - return Op::POOL2D; + return OperatorType::POOL2D; } OperatorType get_op_type(ReduceAttrs const &) { - return Op::REDUCE_SUM; + return OperatorType::REDUCE_SUM; } OperatorType get_op_type(ReshapeAttrs const &) { - return Op::RESHAPE; + return OperatorType::RESHAPE; } OperatorType get_op_type(SplitAttrs const &) { - return Op::SPLIT; + return OperatorType::SPLIT; } OperatorType get_op_type(SoftmaxAttrs const &) { - return Op::SOFTMAX; + return OperatorType::SOFTMAX; } OperatorType get_op_type(TopKAttrs const &) { - return Op::TOPK; + return OperatorType::TOPK; } OperatorType get_op_type(TransposeAttrs const &) { - return Op::TRANSPOSE; + return OperatorType::TRANSPOSE; } OperatorType get_op_type(CombineAttrs const &) { - return Op::COMBINE; + return OperatorType::COMBINE; } OperatorType get_op_type(ReductionAttrs const &) { - return Op::REDUCTION; + return OperatorType::REDUCTION; } OperatorType get_op_type(RepartitionAttrs const &) { - return Op::REPARTITION; + return OperatorType::REPARTITION; } OperatorType get_op_type(ReplicateAttrs const &) { - return Op::REPLICATE; + return OperatorType::REPLICATE; } OperatorType get_op_type(ReverseAttrs const &attrs) { - return Op::REVERSE; + return OperatorType::REVERSE; } } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/activation.cc b/lib/op-attrs/src/op-attrs/activation.cc new file mode 100644 index 0000000000..de63eb1ec5 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/activation.cc @@ -0,0 +1,81 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/activation.enum.toml + +#include "op-attrs/activation.h" + +#include +#include + +namespace std { +size_t hash::operator()(FlexFlow::Activation x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(Activation x) { + switch (x) { + case Activation::RELU: + return "RELU"; + case Activation::SIGMOID: + return "SIGMOID"; + case Activation::TANH: + return "TANH"; + case Activation::GELU: + return "GELU"; + default: + std::ostringstream oss; + oss << "Unknown Activation value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, Activation x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, Activation x) { + switch (x) { + case Activation::RELU: + j = "RELU"; + break; + case Activation::SIGMOID: + j = "SIGMOID"; + break; + case Activation::TANH: + j = "TANH"; + break; + case Activation::GELU: + j = "GELU"; + break; + default: + std::ostringstream oss; + oss << "Unknown Activation value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, Activation &x) { + std::string as_str = j.get(); + if (as_str == "RELU") { + x = Activation::RELU; + } else if (as_str == "SIGMOID") { + x = Activation::SIGMOID; + } else if (as_str == "TANH") { + x = Activation::TANH; + } else if (as_str == "GELU") { + x = Activation::GELU; + } else { + std::ostringstream oss; + oss << "Unknown Activation value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::element(FlexFlow::Activation::RELU, + FlexFlow::Activation::SIGMOID, + FlexFlow::Activation::TANH, + FlexFlow::Activation::GELU); +} +} // namespace rc diff --git a/lib/op-attrs/src/op-attrs/aggregate_op.cc b/lib/op-attrs/src/op-attrs/aggregate_op.cc index 29d143579b..34b042ac41 100644 --- a/lib/op-attrs/src/op-attrs/aggregate_op.cc +++ b/lib/op-attrs/src/op-attrs/aggregate_op.cc @@ -1,18 +1,57 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/aggregate_op.enum.toml + #include "op-attrs/aggregate_op.h" -#include "utils/exception.h" -namespace FlexFlow { +#include +#include -std::string format_as(AggregateOp o) { - switch (o) { +namespace std { +size_t hash::operator()(FlexFlow::AggregateOp x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(AggregateOp x) { + switch (x) { case AggregateOp::SUM: return "SUM"; - case AggregateOp::AVG: - return "AVG"; default: - throw mk_runtime_error( - fmt::format("Unknown aggregate op {}", static_cast(o))); + std::ostringstream oss; + oss << "Unknown AggregateOp value " << static_cast(x); + throw std::runtime_error(oss.str()); } } - +std::ostream &operator<<(std::ostream &s, AggregateOp x) { + return s << fmt::to_string(x); +} } // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, AggregateOp x) { + switch (x) { + case AggregateOp::SUM: + j = "SUM"; + break; + default: + std::ostringstream oss; + oss << "Unknown AggregateOp value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, AggregateOp &x) { + std::string as_str = j.get(); + if (as_str == "SUM") { + x = AggregateOp::SUM; + } else { + std::ostringstream oss; + oss << "Unknown AggregateOp value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::element(FlexFlow::AggregateOp::SUM); +} +} // namespace rc diff --git a/lib/op-attrs/src/op-attrs/operator_type.cc b/lib/op-attrs/src/op-attrs/operator_type.cc new file mode 100644 index 0000000000..5a516ef122 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/operator_type.cc @@ -0,0 +1,24 @@ +#include "op-attrs/operator_type.h" + +namespace FlexFlow { + +std::string get_operator_type_name(OperatorType op) { + return fmt::to_string(op); +} + +bool is_parallel_op(OperatorType const &t) { + switch (t) { + case OperatorType::REPARTITION: + case OperatorType::COMBINE: + case OperatorType::REPLICATE: + case OperatorType::REDUCTION: + case OperatorType::BATCH: + case OperatorType::PIPELINE: + case OperatorType::FUSED_PARALLEL: + return true; + default: + return false; + } +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/operator_type_t.cc b/lib/op-attrs/src/op-attrs/operator_type_t.cc new file mode 100644 index 0000000000..65c083efd0 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/operator_type_t.cc @@ -0,0 +1,715 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/operator_type_t.enum.toml + +#include "op-attrs/operator_type_t.h" + +#include +#include + +namespace std { +size_t + hash::operator()(FlexFlow::OperatorType x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(OperatorType x) { + switch (x) { + case OperatorType::NOOP: + return "NOOP"; + case OperatorType::INPUT: + return "INPUT"; + case OperatorType::WEIGHT: + return "WEIGHT"; + case OperatorType::CONV2D: + return "CONV2D"; + case OperatorType::DROPOUT: + return "DROPOUT"; + case OperatorType::LINEAR: + return "LINEAR"; + case OperatorType::BATCHMATMUL: + return "BATCHMATMUL"; + case OperatorType::POOL2D: + return "POOL2D"; + case OperatorType::SCALAR_MULTIPLY: + return "SCALAR_MULTIPLY"; + case OperatorType::SCALAR_ADD: + return "SCALAR_ADD"; + case OperatorType::SCALAR_FLOOR_DIV: + return "SCALAR_FLOOR_DIV"; + case OperatorType::SCALAR_TRUE_DIV: + return "SCALAR_TRUE_DIV"; + case OperatorType::SCALAR_SUB: + return "SCALAR_SUB"; + case OperatorType::RELU: + return "RELU"; + case OperatorType::IDENTITY: + return "IDENTITY"; + case OperatorType::SIGMOID: + return "SIGMOID"; + case OperatorType::TANH: + return "TANH"; + case OperatorType::ELU: + return "ELU"; + case OperatorType::FLAT: + return "FLAT"; + case OperatorType::SOFTMAX: + return "SOFTMAX"; + case OperatorType::BATCHNORM: + return "BATCHNORM"; + case OperatorType::CONCAT: + return "CONCAT"; + case OperatorType::SPLIT: + return "SPLIT"; + case OperatorType::EMBEDDING: + return "EMBEDDING"; + case OperatorType::CACHE: + return "CACHE"; + case OperatorType::RESHAPE: + return "RESHAPE"; + case OperatorType::REVERSE: + return "REVERSE"; + case OperatorType::TRANSPOSE: + return "TRANSPOSE"; + case OperatorType::EW_ADD: + return "EW_ADD"; + case OperatorType::EW_MUL: + return "EW_MUL"; + case OperatorType::MATMUL: + return "MATMUL"; + case OperatorType::MUL: + return "MUL"; + case OperatorType::ENLARGE: + return "ENLARGE"; + case OperatorType::SQUEEZE: + return "SQUEEZE"; + case OperatorType::UNSQUEEZE: + return "UNSQUEEZE"; + case OperatorType::EW_SUB: + return "EW_SUB"; + case OperatorType::EW_DIV: + return "EW_DIV"; + case OperatorType::EW_EQUAL: + return "EW_EQUAL"; + case OperatorType::EW_GREATER: + return "EW_GREATER"; + case OperatorType::EW_LESS: + return "EW_LESS"; + case OperatorType::EW_MAX: + return "EW_MAX"; + case OperatorType::EW_MIN: + return "EW_MIN"; + case OperatorType::REDUCE_ARGMAX: + return "REDUCE_ARGMAX"; + case OperatorType::REDUCE_ARGMIN: + return "REDUCE_ARGMIN"; + case OperatorType::REDUCE_MAX: + return "REDUCE_MAX"; + case OperatorType::REDUCE_MEAN: + return "REDUCE_MEAN"; + case OperatorType::REDUCE_MIN: + return "REDUCE_MIN"; + case OperatorType::REDUCE_PROD: + return "REDUCE_PROD"; + case OperatorType::REDUCE_SUM: + return "REDUCE_SUM"; + case OperatorType::PAD: + return "PAD"; + case OperatorType::SHAPE: + return "SHAPE"; + case OperatorType::SIZE: + return "SIZE"; + case OperatorType::TOPK: + return "TOPK"; + case OperatorType::WHERE: + return "WHERE"; + case OperatorType::CEIL: + return "CEIL"; + case OperatorType::CAST: + return "CAST"; + case OperatorType::EXP: + return "EXP"; + case OperatorType::ROUND: + return "ROUND"; + case OperatorType::LOG: + return "LOG"; + case OperatorType::LOGICAL_NOT: + return "LOGICAL_NOT"; + case OperatorType::SQRT: + return "SQRT"; + case OperatorType::SIN: + return "SIN"; + case OperatorType::COS: + return "COS"; + case OperatorType::LEAKYRELU: + return "LEAKYRELU"; + case OperatorType::SLICE: + return "SLICE"; + case OperatorType::RESIZE: + return "RESIZE"; + case OperatorType::PRELU: + return "PRELU"; + case OperatorType::GELU: + return "GELU"; + case OperatorType::MULTIHEAD_ATTENTION: + return "MULTIHEAD_ATTENTION"; + case OperatorType::FUSED: + return "FUSED"; + case OperatorType::RSQRT: + return "RSQRT"; + case OperatorType::POW: + return "POW"; + case OperatorType::MEAN: + return "MEAN"; + case OperatorType::LAYERNORM: + return "LAYERNORM"; + case OperatorType::GATHER: + return "GATHER"; + case OperatorType::BROADCAST: + return "BROADCAST"; + case OperatorType::REPARTITION: + return "REPARTITION"; + case OperatorType::COMBINE: + return "COMBINE"; + case OperatorType::REPLICATE: + return "REPLICATE"; + case OperatorType::REDUCTION: + return "REDUCTION"; + case OperatorType::BATCH: + return "BATCH"; + case OperatorType::PIPELINE: + return "PIPELINE"; + case OperatorType::FUSED_PARALLEL: + return "FUSED_PARALLEL"; + default: + std::ostringstream oss; + oss << "Unknown OperatorType value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, OperatorType x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, OperatorType x) { + switch (x) { + case OperatorType::NOOP: + j = "NOOP"; + break; + case OperatorType::INPUT: + j = "INPUT"; + break; + case OperatorType::WEIGHT: + j = "WEIGHT"; + break; + case OperatorType::CONV2D: + j = "CONV2D"; + break; + case OperatorType::DROPOUT: + j = "DROPOUT"; + break; + case OperatorType::LINEAR: + j = "LINEAR"; + break; + case OperatorType::BATCHMATMUL: + j = "BATCHMATMUL"; + break; + case OperatorType::POOL2D: + j = "POOL2D"; + break; + case OperatorType::SCALAR_MULTIPLY: + j = "SCALAR_MULTIPLY"; + break; + case OperatorType::SCALAR_ADD: + j = "SCALAR_ADD"; + break; + case OperatorType::SCALAR_FLOOR_DIV: + j = "SCALAR_FLOOR_DIV"; + break; + case OperatorType::SCALAR_TRUE_DIV: + j = "SCALAR_TRUE_DIV"; + break; + case OperatorType::SCALAR_SUB: + j = "SCALAR_SUB"; + break; + case OperatorType::RELU: + j = "RELU"; + break; + case OperatorType::IDENTITY: + j = "IDENTITY"; + break; + case OperatorType::SIGMOID: + j = "SIGMOID"; + break; + case OperatorType::TANH: + j = "TANH"; + break; + case OperatorType::ELU: + j = "ELU"; + break; + case OperatorType::FLAT: + j = "FLAT"; + break; + case OperatorType::SOFTMAX: + j = "SOFTMAX"; + break; + case OperatorType::BATCHNORM: + j = "BATCHNORM"; + break; + case OperatorType::CONCAT: + j = "CONCAT"; + break; + case OperatorType::SPLIT: + j = "SPLIT"; + break; + case OperatorType::EMBEDDING: + j = "EMBEDDING"; + break; + case OperatorType::CACHE: + j = "CACHE"; + break; + case OperatorType::RESHAPE: + j = "RESHAPE"; + break; + case OperatorType::REVERSE: + j = "REVERSE"; + break; + case OperatorType::TRANSPOSE: + j = "TRANSPOSE"; + break; + case OperatorType::EW_ADD: + j = "EW_ADD"; + break; + case OperatorType::EW_MUL: + j = "EW_MUL"; + break; + case OperatorType::MATMUL: + j = "MATMUL"; + break; + case OperatorType::MUL: + j = "MUL"; + break; + case OperatorType::ENLARGE: + j = "ENLARGE"; + break; + case OperatorType::SQUEEZE: + j = "SQUEEZE"; + break; + case OperatorType::UNSQUEEZE: + j = "UNSQUEEZE"; + break; + case OperatorType::EW_SUB: + j = "EW_SUB"; + break; + case OperatorType::EW_DIV: + j = "EW_DIV"; + break; + case OperatorType::EW_EQUAL: + j = "EW_EQUAL"; + break; + case OperatorType::EW_GREATER: + j = "EW_GREATER"; + break; + case OperatorType::EW_LESS: + j = "EW_LESS"; + break; + case OperatorType::EW_MAX: + j = "EW_MAX"; + break; + case OperatorType::EW_MIN: + j = "EW_MIN"; + break; + case OperatorType::REDUCE_ARGMAX: + j = "REDUCE_ARGMAX"; + break; + case OperatorType::REDUCE_ARGMIN: + j = "REDUCE_ARGMIN"; + break; + case OperatorType::REDUCE_MAX: + j = "REDUCE_MAX"; + break; + case OperatorType::REDUCE_MEAN: + j = "REDUCE_MEAN"; + break; + case OperatorType::REDUCE_MIN: + j = "REDUCE_MIN"; + break; + case OperatorType::REDUCE_PROD: + j = "REDUCE_PROD"; + break; + case OperatorType::REDUCE_SUM: + j = "REDUCE_SUM"; + break; + case OperatorType::PAD: + j = "PAD"; + break; + case OperatorType::SHAPE: + j = "SHAPE"; + break; + case OperatorType::SIZE: + j = "SIZE"; + break; + case OperatorType::TOPK: + j = "TOPK"; + break; + case OperatorType::WHERE: + j = "WHERE"; + break; + case OperatorType::CEIL: + j = "CEIL"; + break; + case OperatorType::CAST: + j = "CAST"; + break; + case OperatorType::EXP: + j = "EXP"; + break; + case OperatorType::ROUND: + j = "ROUND"; + break; + case OperatorType::LOG: + j = "LOG"; + break; + case OperatorType::LOGICAL_NOT: + j = "LOGICAL_NOT"; + break; + case OperatorType::SQRT: + j = "SQRT"; + break; + case OperatorType::SIN: + j = "SIN"; + break; + case OperatorType::COS: + j = "COS"; + break; + case OperatorType::LEAKYRELU: + j = "LEAKYRELU"; + break; + case OperatorType::SLICE: + j = "SLICE"; + break; + case OperatorType::RESIZE: + j = "RESIZE"; + break; + case OperatorType::PRELU: + j = "PRELU"; + break; + case OperatorType::GELU: + j = "GELU"; + break; + case OperatorType::MULTIHEAD_ATTENTION: + j = "MULTIHEAD_ATTENTION"; + break; + case OperatorType::FUSED: + j = "FUSED"; + break; + case OperatorType::RSQRT: + j = "RSQRT"; + break; + case OperatorType::POW: + j = "POW"; + break; + case OperatorType::MEAN: + j = "MEAN"; + break; + case OperatorType::LAYERNORM: + j = "LAYERNORM"; + break; + case OperatorType::GATHER: + j = "GATHER"; + break; + case OperatorType::BROADCAST: + j = "BROADCAST"; + break; + case OperatorType::REPARTITION: + j = "REPARTITION"; + break; + case OperatorType::COMBINE: + j = "COMBINE"; + break; + case OperatorType::REPLICATE: + j = "REPLICATE"; + break; + case OperatorType::REDUCTION: + j = "REDUCTION"; + break; + case OperatorType::BATCH: + j = "BATCH"; + break; + case OperatorType::PIPELINE: + j = "PIPELINE"; + break; + case OperatorType::FUSED_PARALLEL: + j = "FUSED_PARALLEL"; + break; + default: + std::ostringstream oss; + oss << "Unknown OperatorType value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, OperatorType &x) { + std::string as_str = j.get(); + if (as_str == "NOOP") { + x = OperatorType::NOOP; + } else if (as_str == "INPUT") { + x = OperatorType::INPUT; + } else if (as_str == "WEIGHT") { + x = OperatorType::WEIGHT; + } else if (as_str == "CONV2D") { + x = OperatorType::CONV2D; + } else if (as_str == "DROPOUT") { + x = OperatorType::DROPOUT; + } else if (as_str == "LINEAR") { + x = OperatorType::LINEAR; + } else if (as_str == "BATCHMATMUL") { + x = OperatorType::BATCHMATMUL; + } else if (as_str == "POOL2D") { + x = OperatorType::POOL2D; + } else if (as_str == "SCALAR_MULTIPLY") { + x = OperatorType::SCALAR_MULTIPLY; + } else if (as_str == "SCALAR_ADD") { + x = OperatorType::SCALAR_ADD; + } else if (as_str == "SCALAR_FLOOR_DIV") { + x = OperatorType::SCALAR_FLOOR_DIV; + } else if (as_str == "SCALAR_TRUE_DIV") { + x = OperatorType::SCALAR_TRUE_DIV; + } else if (as_str == "SCALAR_SUB") { + x = OperatorType::SCALAR_SUB; + } else if (as_str == "RELU") { + x = OperatorType::RELU; + } else if (as_str == "IDENTITY") { + x = OperatorType::IDENTITY; + } else if (as_str == "SIGMOID") { + x = OperatorType::SIGMOID; + } else if (as_str == "TANH") { + x = OperatorType::TANH; + } else if (as_str == "ELU") { + x = OperatorType::ELU; + } else if (as_str == "FLAT") { + x = OperatorType::FLAT; + } else if (as_str == "SOFTMAX") { + x = OperatorType::SOFTMAX; + } else if (as_str == "BATCHNORM") { + x = OperatorType::BATCHNORM; + } else if (as_str == "CONCAT") { + x = OperatorType::CONCAT; + } else if (as_str == "SPLIT") { + x = OperatorType::SPLIT; + } else if (as_str == "EMBEDDING") { + x = OperatorType::EMBEDDING; + } else if (as_str == "CACHE") { + x = OperatorType::CACHE; + } else if (as_str == "RESHAPE") { + x = OperatorType::RESHAPE; + } else if (as_str == "REVERSE") { + x = OperatorType::REVERSE; + } else if (as_str == "TRANSPOSE") { + x = OperatorType::TRANSPOSE; + } else if (as_str == "EW_ADD") { + x = OperatorType::EW_ADD; + } else if (as_str == "EW_MUL") { + x = OperatorType::EW_MUL; + } else if (as_str == "MATMUL") { + x = OperatorType::MATMUL; + } else if (as_str == "MUL") { + x = OperatorType::MUL; + } else if (as_str == "ENLARGE") { + x = OperatorType::ENLARGE; + } else if (as_str == "SQUEEZE") { + x = OperatorType::SQUEEZE; + } else if (as_str == "UNSQUEEZE") { + x = OperatorType::UNSQUEEZE; + } else if (as_str == "EW_SUB") { + x = OperatorType::EW_SUB; + } else if (as_str == "EW_DIV") { + x = OperatorType::EW_DIV; + } else if (as_str == "EW_EQUAL") { + x = OperatorType::EW_EQUAL; + } else if (as_str == "EW_GREATER") { + x = OperatorType::EW_GREATER; + } else if (as_str == "EW_LESS") { + x = OperatorType::EW_LESS; + } else if (as_str == "EW_MAX") { + x = OperatorType::EW_MAX; + } else if (as_str == "EW_MIN") { + x = OperatorType::EW_MIN; + } else if (as_str == "REDUCE_ARGMAX") { + x = OperatorType::REDUCE_ARGMAX; + } else if (as_str == "REDUCE_ARGMIN") { + x = OperatorType::REDUCE_ARGMIN; + } else if (as_str == "REDUCE_MAX") { + x = OperatorType::REDUCE_MAX; + } else if (as_str == "REDUCE_MEAN") { + x = OperatorType::REDUCE_MEAN; + } else if (as_str == "REDUCE_MIN") { + x = OperatorType::REDUCE_MIN; + } else if (as_str == "REDUCE_PROD") { + x = OperatorType::REDUCE_PROD; + } else if (as_str == "REDUCE_SUM") { + x = OperatorType::REDUCE_SUM; + } else if (as_str == "PAD") { + x = OperatorType::PAD; + } else if (as_str == "SHAPE") { + x = OperatorType::SHAPE; + } else if (as_str == "SIZE") { + x = OperatorType::SIZE; + } else if (as_str == "TOPK") { + x = OperatorType::TOPK; + } else if (as_str == "WHERE") { + x = OperatorType::WHERE; + } else if (as_str == "CEIL") { + x = OperatorType::CEIL; + } else if (as_str == "CAST") { + x = OperatorType::CAST; + } else if (as_str == "EXP") { + x = OperatorType::EXP; + } else if (as_str == "ROUND") { + x = OperatorType::ROUND; + } else if (as_str == "LOG") { + x = OperatorType::LOG; + } else if (as_str == "LOGICAL_NOT") { + x = OperatorType::LOGICAL_NOT; + } else if (as_str == "SQRT") { + x = OperatorType::SQRT; + } else if (as_str == "SIN") { + x = OperatorType::SIN; + } else if (as_str == "COS") { + x = OperatorType::COS; + } else if (as_str == "LEAKYRELU") { + x = OperatorType::LEAKYRELU; + } else if (as_str == "SLICE") { + x = OperatorType::SLICE; + } else if (as_str == "RESIZE") { + x = OperatorType::RESIZE; + } else if (as_str == "PRELU") { + x = OperatorType::PRELU; + } else if (as_str == "GELU") { + x = OperatorType::GELU; + } else if (as_str == "MULTIHEAD_ATTENTION") { + x = OperatorType::MULTIHEAD_ATTENTION; + } else if (as_str == "FUSED") { + x = OperatorType::FUSED; + } else if (as_str == "RSQRT") { + x = OperatorType::RSQRT; + } else if (as_str == "POW") { + x = OperatorType::POW; + } else if (as_str == "MEAN") { + x = OperatorType::MEAN; + } else if (as_str == "LAYERNORM") { + x = OperatorType::LAYERNORM; + } else if (as_str == "GATHER") { + x = OperatorType::GATHER; + } else if (as_str == "BROADCAST") { + x = OperatorType::BROADCAST; + } else if (as_str == "REPARTITION") { + x = OperatorType::REPARTITION; + } else if (as_str == "COMBINE") { + x = OperatorType::COMBINE; + } else if (as_str == "REPLICATE") { + x = OperatorType::REPLICATE; + } else if (as_str == "REDUCTION") { + x = OperatorType::REDUCTION; + } else if (as_str == "BATCH") { + x = OperatorType::BATCH; + } else if (as_str == "PIPELINE") { + x = OperatorType::PIPELINE; + } else if (as_str == "FUSED_PARALLEL") { + x = OperatorType::FUSED_PARALLEL; + } else { + std::ostringstream oss; + oss << "Unknown OperatorType value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::element( + FlexFlow::OperatorType::NOOP, + FlexFlow::OperatorType::INPUT, + FlexFlow::OperatorType::WEIGHT, + FlexFlow::OperatorType::CONV2D, + FlexFlow::OperatorType::DROPOUT, + FlexFlow::OperatorType::LINEAR, + FlexFlow::OperatorType::BATCHMATMUL, + FlexFlow::OperatorType::POOL2D, + FlexFlow::OperatorType::SCALAR_MULTIPLY, + FlexFlow::OperatorType::SCALAR_ADD, + FlexFlow::OperatorType::SCALAR_FLOOR_DIV, + FlexFlow::OperatorType::SCALAR_TRUE_DIV, + FlexFlow::OperatorType::SCALAR_SUB, + FlexFlow::OperatorType::RELU, + FlexFlow::OperatorType::IDENTITY, + FlexFlow::OperatorType::SIGMOID, + FlexFlow::OperatorType::TANH, + FlexFlow::OperatorType::ELU, + FlexFlow::OperatorType::FLAT, + FlexFlow::OperatorType::SOFTMAX, + FlexFlow::OperatorType::BATCHNORM, + FlexFlow::OperatorType::CONCAT, + FlexFlow::OperatorType::SPLIT, + FlexFlow::OperatorType::EMBEDDING, + FlexFlow::OperatorType::CACHE, + FlexFlow::OperatorType::RESHAPE, + FlexFlow::OperatorType::REVERSE, + FlexFlow::OperatorType::TRANSPOSE, + FlexFlow::OperatorType::EW_ADD, + FlexFlow::OperatorType::EW_MUL, + FlexFlow::OperatorType::MATMUL, + FlexFlow::OperatorType::MUL, + FlexFlow::OperatorType::ENLARGE, + FlexFlow::OperatorType::SQUEEZE, + FlexFlow::OperatorType::UNSQUEEZE, + FlexFlow::OperatorType::EW_SUB, + FlexFlow::OperatorType::EW_DIV, + FlexFlow::OperatorType::EW_EQUAL, + FlexFlow::OperatorType::EW_GREATER, + FlexFlow::OperatorType::EW_LESS, + FlexFlow::OperatorType::EW_MAX, + FlexFlow::OperatorType::EW_MIN, + FlexFlow::OperatorType::REDUCE_ARGMAX, + FlexFlow::OperatorType::REDUCE_ARGMIN, + FlexFlow::OperatorType::REDUCE_MAX, + FlexFlow::OperatorType::REDUCE_MEAN, + FlexFlow::OperatorType::REDUCE_MIN, + FlexFlow::OperatorType::REDUCE_PROD, + FlexFlow::OperatorType::REDUCE_SUM, + FlexFlow::OperatorType::PAD, + FlexFlow::OperatorType::SHAPE, + FlexFlow::OperatorType::SIZE, + FlexFlow::OperatorType::TOPK, + FlexFlow::OperatorType::WHERE, + FlexFlow::OperatorType::CEIL, + FlexFlow::OperatorType::CAST, + FlexFlow::OperatorType::EXP, + FlexFlow::OperatorType::ROUND, + FlexFlow::OperatorType::LOG, + FlexFlow::OperatorType::LOGICAL_NOT, + FlexFlow::OperatorType::SQRT, + FlexFlow::OperatorType::SIN, + FlexFlow::OperatorType::COS, + FlexFlow::OperatorType::LEAKYRELU, + FlexFlow::OperatorType::SLICE, + FlexFlow::OperatorType::RESIZE, + FlexFlow::OperatorType::PRELU, + FlexFlow::OperatorType::GELU, + FlexFlow::OperatorType::MULTIHEAD_ATTENTION, + FlexFlow::OperatorType::FUSED, + FlexFlow::OperatorType::RSQRT, + FlexFlow::OperatorType::POW, + FlexFlow::OperatorType::MEAN, + FlexFlow::OperatorType::LAYERNORM, + FlexFlow::OperatorType::GATHER, + FlexFlow::OperatorType::BROADCAST, + FlexFlow::OperatorType::REPARTITION, + FlexFlow::OperatorType::COMBINE, + FlexFlow::OperatorType::REPLICATE, + FlexFlow::OperatorType::REDUCTION, + FlexFlow::OperatorType::BATCH, + FlexFlow::OperatorType::PIPELINE, + FlexFlow::OperatorType::FUSED_PARALLEL); +} +} // namespace rc diff --git a/lib/op-attrs/src/attention.cc b/lib/op-attrs/src/op-attrs/ops/attention.cc similarity index 99% rename from lib/op-attrs/src/attention.cc rename to lib/op-attrs/src/op-attrs/ops/attention.cc index 21a98aaf6d..9fb884db43 100644 --- a/lib/op-attrs/src/attention.cc +++ b/lib/op-attrs/src/op-attrs/ops/attention.cc @@ -105,6 +105,11 @@ TensorShape get_output_shape(MultiHeadAttentionAttrs const &, int get_oSize(ParallelTensorShape const &) { NOT_IMPLEMENTED(); } + +int get_oSize(TensorShape const &) { + NOT_IMPLEMENTED(); +} + } // namespace FlexFlow // Tensor FFModel::multihead_attention(const Tensor query, diff --git a/lib/op-attrs/src/op-attrs/ops/element_binary_attrs.cc b/lib/op-attrs/src/op-attrs/ops/element_binary_attrs.cc index 0c7c523d5b..b14920f803 100644 --- a/lib/op-attrs/src/op-attrs/ops/element_binary_attrs.cc +++ b/lib/op-attrs/src/op-attrs/ops/element_binary_attrs.cc @@ -5,7 +5,7 @@ #include "op-attrs/ops/element_binary_attrs.h" namespace FlexFlow { -ElementBinaryAttrs::ElementBinaryAttrs(::FlexFlow::Op const &type, +ElementBinaryAttrs::ElementBinaryAttrs(::FlexFlow::OperatorType const &type, ::FlexFlow::DataType const &compute_type, bool const &should_broadcast_lhs, bool const &should_broadcast_rhs) @@ -78,8 +78,8 @@ namespace std { size_t hash::operator()( FlexFlow::ElementBinaryAttrs const &x) const { size_t result = 0; - result ^= std::hash<::FlexFlow::Op>{}(x.type) + 0x9e3779b9 + (result << 6) + - (result >> 2); + result ^= std::hash<::FlexFlow::OperatorType>{}(x.type) + 0x9e3779b9 + + (result << 6) + (result >> 2); result ^= std::hash<::FlexFlow::DataType>{}(x.compute_type) + 0x9e3779b9 + (result << 6) + (result >> 2); result ^= std::hash{}(x.should_broadcast_lhs) + 0x9e3779b9 + @@ -93,7 +93,7 @@ size_t hash::operator()( namespace nlohmann { FlexFlow::ElementBinaryAttrs adl_serializer::from_json(json const &j) { - return {j.at("type").template get<::FlexFlow::Op>(), + return {j.at("type").template get<::FlexFlow::OperatorType>(), j.at("compute_type").template get<::FlexFlow::DataType>(), j.at("should_broadcast_lhs").template get(), j.at("should_broadcast_rhs").template get()}; diff --git a/lib/op-attrs/src/op-attrs/ops/element_scalar_unary_attrs.cc b/lib/op-attrs/src/op-attrs/ops/element_scalar_unary_attrs.cc index 8c54874c58..4c9d59e568 100644 --- a/lib/op-attrs/src/op-attrs/ops/element_scalar_unary_attrs.cc +++ b/lib/op-attrs/src/op-attrs/ops/element_scalar_unary_attrs.cc @@ -5,8 +5,8 @@ #include "op-attrs/ops/element_scalar_unary_attrs.h" namespace FlexFlow { -ElementScalarUnaryAttrs::ElementScalarUnaryAttrs(::FlexFlow::Op const &op_type, - float const &scalar) +ElementScalarUnaryAttrs::ElementScalarUnaryAttrs( + ::FlexFlow::OperatorType const &op_type, float const &scalar) : op_type(op_type), scalar(scalar) {} bool ElementScalarUnaryAttrs::operator==( ElementScalarUnaryAttrs const &other) const { @@ -44,7 +44,7 @@ namespace std { size_t hash::operator()( FlexFlow::ElementScalarUnaryAttrs const &x) const { size_t result = 0; - result ^= std::hash<::FlexFlow::Op>{}(x.op_type) + 0x9e3779b9 + + result ^= std::hash<::FlexFlow::OperatorType>{}(x.op_type) + 0x9e3779b9 + (result << 6) + (result >> 2); result ^= std::hash{}(x.scalar) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -56,7 +56,7 @@ namespace nlohmann { FlexFlow::ElementScalarUnaryAttrs adl_serializer::from_json( json const &j) { - return {j.at("op_type").template get<::FlexFlow::Op>(), + return {j.at("op_type").template get<::FlexFlow::OperatorType>(), j.at("scalar").template get()}; } void adl_serializer::to_json( diff --git a/lib/op-attrs/src/op-attrs/ops/element_unary_attrs.cc b/lib/op-attrs/src/op-attrs/ops/element_unary_attrs.cc index 1123890ac3..fa9a4a1697 100644 --- a/lib/op-attrs/src/op-attrs/ops/element_unary_attrs.cc +++ b/lib/op-attrs/src/op-attrs/ops/element_unary_attrs.cc @@ -5,7 +5,7 @@ #include "op-attrs/ops/element_unary_attrs.h" namespace FlexFlow { -ElementUnaryAttrs::ElementUnaryAttrs(::FlexFlow::Op const &op_type) +ElementUnaryAttrs::ElementUnaryAttrs(::FlexFlow::OperatorType const &op_type) : op_type(op_type) {} bool ElementUnaryAttrs::operator==(ElementUnaryAttrs const &other) const { return std::tie(this->op_type) == std::tie(other.op_type); @@ -31,7 +31,7 @@ namespace std { size_t hash::operator()( FlexFlow::ElementUnaryAttrs const &x) const { size_t result = 0; - result ^= std::hash<::FlexFlow::Op>{}(x.op_type) + 0x9e3779b9 + + result ^= std::hash<::FlexFlow::OperatorType>{}(x.op_type) + 0x9e3779b9 + (result << 6) + (result >> 2); return result; } @@ -40,7 +40,7 @@ size_t hash::operator()( namespace nlohmann { FlexFlow::ElementUnaryAttrs adl_serializer::from_json(json const &j) { - return {j.at("op_type").template get<::FlexFlow::Op>()}; + return {j.at("op_type").template get<::FlexFlow::OperatorType>()}; } void adl_serializer::to_json( json &j, FlexFlow::ElementUnaryAttrs const &v) { diff --git a/lib/op-attrs/src/op-attrs/ops/reduce_attrs.cc b/lib/op-attrs/src/op-attrs/ops/reduce_attrs.cc index b6d71c81b4..3cb71259e1 100644 --- a/lib/op-attrs/src/op-attrs/ops/reduce_attrs.cc +++ b/lib/op-attrs/src/op-attrs/ops/reduce_attrs.cc @@ -7,7 +7,7 @@ namespace FlexFlow { ReduceAttrs::ReduceAttrs( ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM> const &axes, - ::FlexFlow::Op const &op_type, + ::FlexFlow::OperatorType const &op_type, bool const &keepdims) : axes(axes), op_type(op_type), keepdims(keepdims) {} bool ReduceAttrs::operator==(ReduceAttrs const &other) const { @@ -45,7 +45,7 @@ size_t hash::operator()( ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>>{}( x.axes) + 0x9e3779b9 + (result << 6) + (result >> 2); - result ^= std::hash<::FlexFlow::Op>{}(x.op_type) + 0x9e3779b9 + + result ^= std::hash<::FlexFlow::OperatorType>{}(x.op_type) + 0x9e3779b9 + (result << 6) + (result >> 2); result ^= std::hash{}(x.keepdims) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -60,7 +60,7 @@ FlexFlow::ReduceAttrs j.at("axes") .template get< ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>>(), - j.at("op_type").template get<::FlexFlow::Op>(), + j.at("op_type").template get<::FlexFlow::OperatorType>(), j.at("keepdims").template get()}; } void adl_serializer::to_json( diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc index a0d2a9ba0d..a1fe25e0b6 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc @@ -3,16 +3,32 @@ namespace FlexFlow { +FFOrdered const &ff_ordered(ParallelTensorDims const &d) { + return d.ff_ordered; +} + std::vector as_vector(ParallelTensorDims const &d) { - return as_vector(d.unwrapped); + return as_vector(d.ff_ordered); } int get_num_replica_dims(ParallelTensorDims const &d) { - return count(d.unwrapped, is_replica_dim); + return count(d.ff_ordered, is_replica_dim); } bool is_valid(ParallelTensorDims const &dims) { - return all_of(dims.unwrapped, [](ParallelDim const &d) { return is_valid(d); }); + return all_of(dims.ff_ordered, [](ParallelDim const &d) { return is_valid(d); }); +} + +size_t num_dims(ParallelTensorDims const &dims) { + return dims.ff_ordered.size(); +} + +ParallelDim dim_at_idx(ParallelTensorDims const &d, ff_dim_t idx) { + return d.ff_ordered.at(idx); +} + +ParallelDim &dim_at_idx(ParallelTensorDims &d, ff_dim_t idx) { + return d.ff_ordered.at(idx); } } diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dims_t.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dims_t.cc index 14ac950e69..b70a18f4e0 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_dims_t.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dims_t.cc @@ -6,25 +6,25 @@ namespace FlexFlow { ParallelTensorDims::ParallelTensorDims( - ::FlexFlow::FFOrdered<::FlexFlow::ParallelDim> const &unwrapped) - : unwrapped(unwrapped) {} + ::FlexFlow::FFOrdered<::FlexFlow::ParallelDim> const &ff_ordered) + : ff_ordered(ff_ordered) {} bool ParallelTensorDims::operator==(ParallelTensorDims const &other) const { - return std::tie(this->unwrapped) == std::tie(other.unwrapped); + return std::tie(this->ff_ordered) == std::tie(other.ff_ordered); } bool ParallelTensorDims::operator!=(ParallelTensorDims const &other) const { - return std::tie(this->unwrapped) != std::tie(other.unwrapped); + return std::tie(this->ff_ordered) != std::tie(other.ff_ordered); } bool ParallelTensorDims::operator<(ParallelTensorDims const &other) const { - return std::tie(this->unwrapped) < std::tie(other.unwrapped); + return std::tie(this->ff_ordered) < std::tie(other.ff_ordered); } bool ParallelTensorDims::operator>(ParallelTensorDims const &other) const { - return std::tie(this->unwrapped) > std::tie(other.unwrapped); + return std::tie(this->ff_ordered) > std::tie(other.ff_ordered); } bool ParallelTensorDims::operator<=(ParallelTensorDims const &other) const { - return std::tie(this->unwrapped) <= std::tie(other.unwrapped); + return std::tie(this->ff_ordered) <= std::tie(other.ff_ordered); } bool ParallelTensorDims::operator>=(ParallelTensorDims const &other) const { - return std::tie(this->unwrapped) >= std::tie(other.unwrapped); + return std::tie(this->ff_ordered) >= std::tie(other.ff_ordered); } } // namespace FlexFlow @@ -32,9 +32,9 @@ namespace std { size_t hash::operator()( FlexFlow::ParallelTensorDims const &x) const { size_t result = 0; - result ^= - std::hash<::FlexFlow::FFOrdered<::FlexFlow::ParallelDim>>{}(x.unwrapped) + - 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::FFOrdered<::FlexFlow::ParallelDim>>{}( + x.ff_ordered) + + 0x9e3779b9 + (result << 6) + (result >> 2); return result; } } // namespace std @@ -42,13 +42,13 @@ size_t hash::operator()( namespace nlohmann { FlexFlow::ParallelTensorDims adl_serializer::from_json(json const &j) { - return {j.at("unwrapped") + return {j.at("ff_ordered") .template get<::FlexFlow::FFOrdered<::FlexFlow::ParallelDim>>()}; } void adl_serializer::to_json( json &j, FlexFlow::ParallelTensorDims const &v) { j["__type"] = "ParallelTensorDims"; - j["unwrapped"] = v.unwrapped; + j["ff_ordered"] = v.ff_ordered; } } // namespace nlohmann @@ -56,7 +56,7 @@ namespace FlexFlow { std::string format_as(ParallelTensorDims const &x) { std::ostringstream oss; oss << ""; return oss.str(); } diff --git a/lib/op-attrs/src/parallel_tensor_shape.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc similarity index 79% rename from lib/op-attrs/src/parallel_tensor_shape.cc rename to lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc index ccdf1d84d0..8d6125c369 100644 --- a/lib/op-attrs/src/parallel_tensor_shape.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc @@ -26,6 +26,14 @@ ParallelDim dim_at_idx(ParallelTensorShape const &s, ff_dim_t d) { return dim_at_idx(s.dims, d); } +ParallelDim &dim_at_idx(ParallelTensorShape &s, ff_dim_t d) { + return dim_at_idx(s.dims, d); +} + +ParallelTensorShape lift_to_parallel(TensorShape const &s) { + return { lift_to_parallel(s.dims), s.data_type }; +} + TensorShape get_tensor_shape_unsafe(ParallelTensorShape const &) { NOT_IMPLEMENTED(); } diff --git a/lib/op-attrs/src/op-attrs/pool_op.cc b/lib/op-attrs/src/op-attrs/pool_op.cc index cdb0f8bf4b..dd0e1bfd00 100644 --- a/lib/op-attrs/src/op-attrs/pool_op.cc +++ b/lib/op-attrs/src/op-attrs/pool_op.cc @@ -1,18 +1,65 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/pool_op.enum.toml + #include "op-attrs/pool_op.h" -#include "utils/exception.h" -namespace FlexFlow { +#include +#include -std::string format_as(PoolOp o) { - switch (o) { +namespace std { +size_t hash::operator()(FlexFlow::PoolOp x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(PoolOp x) { + switch (x) { case PoolOp::MAX: return "MAX"; case PoolOp::AVG: return "AVG"; default: - throw mk_runtime_error( - fmt::format("Unknown pool op {}", static_cast(o))); + std::ostringstream oss; + oss << "Unknown PoolOp value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, PoolOp x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, PoolOp x) { + switch (x) { + case PoolOp::MAX: + j = "MAX"; + break; + case PoolOp::AVG: + j = "AVG"; + break; + default: + std::ostringstream oss; + oss << "Unknown PoolOp value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, PoolOp &x) { + std::string as_str = j.get(); + if (as_str == "MAX") { + x = PoolOp::MAX; + } else if (as_str == "AVG") { + x = PoolOp::AVG; + } else { + std::ostringstream oss; + oss << "Unknown PoolOp value " << as_str; + throw std::runtime_error(oss.str()); } } - } // namespace FlexFlow +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::element(FlexFlow::PoolOp::MAX, + FlexFlow::PoolOp::AVG); +} +} // namespace rc diff --git a/lib/op-attrs/src/op-attrs/tensor_dims.cc b/lib/op-attrs/src/op-attrs/tensor_dims.cc index 34d988ed97..9f7316e998 100644 --- a/lib/op-attrs/src/op-attrs/tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/tensor_dims.cc @@ -1,4 +1,5 @@ #include "op-attrs/tensor_dims.h" +#include "utils/containers.h" namespace FlexFlow { @@ -10,4 +11,11 @@ size_t dim_at_idx(TensorDims const &dims, ff_dim_t idx) { return dims.ff_ordered.at(idx); } +ParallelTensorDims lift_to_parallel(TensorDims const &dims) { + FFOrdered lifted = { + transform(as_vector(dims.ff_ordered), [](size_t const &size) { return ParallelDim{size, 1, false}; }) + }; + return {lifted}; +} + } diff --git a/lib/op-attrs/src/op.cc b/lib/op-attrs/src/op.cc deleted file mode 100644 index 5bc5498d6e..0000000000 --- a/lib/op-attrs/src/op.cc +++ /dev/null @@ -1,24 +0,0 @@ -#include "op-attrs/op.h" - -namespace FlexFlow { - -std::string get_operator_type_name(Op op) { - return fmt::to_string(op); -} - -bool is_parallel_op(OperatorType const &t) { - switch (t) { - case Op::REPARTITION: - case Op::COMBINE: - case Op::REPLICATE: - case Op::REDUCTION: - case Op::BATCH: - case Op::PIPELINE: - case Op::FUSED_PARALLEL: - return true; - default: - return false; - } -} - -} // namespace FlexFlow diff --git a/lib/pcg/src/computation_graph_builder.cc b/lib/pcg/src/computation_graph_builder.cc index c2e008231e..68aeb45ff2 100644 --- a/lib/pcg/src/computation_graph_builder.cc +++ b/lib/pcg/src/computation_graph_builder.cc @@ -136,71 +136,71 @@ Tensor ComputationGraphBuilder::element_binary( Tensor ComputationGraphBuilder::exp(Tensor const &input, std::optional const &name) { - return this->element_unary(Op::EXP, input, name); + return this->element_unary(OperatorType::EXP, input, name); } Tensor ComputationGraphBuilder::add(Tensor const &lhs, Tensor const &rhs, std::optional const &name) { - return this->element_binary(Op::EW_ADD, lhs, rhs, name); + return this->element_binary(OperatorType::EW_ADD, lhs, rhs, name); } Tensor ComputationGraphBuilder::subtract(Tensor const &lhs, Tensor const &rhs, std::optional const &name) { - return this->element_binary(Op::EW_SUB, lhs, rhs, name); + return this->element_binary(OperatorType::EW_SUB, lhs, rhs, name); } Tensor ComputationGraphBuilder::multiply(Tensor const &lhs, Tensor const &rhs, std::optional const &name) { - return this->element_binary(Op::EW_MUL, lhs, rhs, name); + return this->element_binary(OperatorType::EW_MUL, lhs, rhs, name); } Tensor ComputationGraphBuilder::divide(Tensor const &lhs, Tensor const &rhs, std::optional const &name) { - return this->element_binary(Op::EW_DIV, lhs, rhs, name); + return this->element_binary(OperatorType::EW_DIV, lhs, rhs, name); } Tensor ComputationGraphBuilder::max(Tensor const &lhs, Tensor const &rhs, std::optional const &name) { - return this->element_binary(Op::EW_MAX, lhs, rhs, name); + return this->element_binary(OperatorType::EW_MAX, lhs, rhs, name); } Tensor ComputationGraphBuilder::min(Tensor const &lhs, Tensor const &rhs, std::optional const &name) { - return this->element_binary(Op::EW_MIN, lhs, rhs, name); + return this->element_binary(OperatorType::EW_MIN, lhs, rhs, name); } Tensor ComputationGraphBuilder::rsqrt(Tensor const &input, std::optional const &name) { - return this->element_unary(Op::RSQRT, input, name); + return this->element_unary(OperatorType::RSQRT, input, name); } Tensor ComputationGraphBuilder::pow(Tensor const &input, float exponent, std::optional const &name) { - return this->element_scalar_unary(Op::POW, input, exponent, name); + return this->element_scalar_unary(OperatorType::POW, input, exponent, name); } Tensor ComputationGraphBuilder::scalar_multiply( Tensor const &input, float scalar, std::optional const &name) { - return this->element_scalar_unary(Op::SCALAR_MULTIPLY, input, scalar, name); + return this->element_scalar_unary(OperatorType::SCALAR_MULTIPLY, input, scalar, name); } Tensor ComputationGraphBuilder::scalar_add( Tensor const &input, float scalar, std::optional const &name) { - return this->element_scalar_unary(Op::SCALAR_ADD, input, scalar, name); + return this->element_scalar_unary(OperatorType::SCALAR_ADD, input, scalar, name); } Tensor ComputationGraphBuilder::scalar_sub( Tensor const &lhs, float rhs, std::optional const &name) { - return this->element_scalar_unary(Op::SCALAR_SUB, lhs, rhs, name); + return this->element_scalar_unary(OperatorType::SCALAR_SUB, lhs, rhs, name); } Tensor ComputationGraphBuilder::scalar_truediv( @@ -208,49 +208,49 @@ Tensor ComputationGraphBuilder::scalar_truediv( float denominator, std::optional const &name) { return this->element_scalar_unary( - Op::SCALAR_TRUE_DIV, numerator, denominator, name); + OperatorType::SCALAR_TRUE_DIV, numerator, denominator, name); } Tensor ComputationGraphBuilder::sin(Tensor const &input, std::optional const &name) { - return this->element_unary(Op::SIN, input, name); + return this->element_unary(OperatorType::SIN, input, name); } Tensor ComputationGraphBuilder::cos(Tensor const &input, std::optional const &name) { - return this->element_unary(Op::COS, input, name); + return this->element_unary(OperatorType::COS, input, name); } Tensor ComputationGraphBuilder::relu(Tensor const &input, std::optional const &name) { - return this->element_unary(Op::RELU, input, name); + return this->element_unary(OperatorType::RELU, input, name); } Tensor ComputationGraphBuilder::identity(Tensor const &input, std::optional const &name) { - return this->element_unary(Op::IDENTITY, input, name); + return this->element_unary(OperatorType::IDENTITY, input, name); } Tensor ComputationGraphBuilder::gelu(Tensor const &input, std::optional const &name) { - return this->element_unary(Op::GELU, input, name); + return this->element_unary(OperatorType::GELU, input, name); } Tensor ComputationGraphBuilder::sigmoid(Tensor const &input, std::optional const &name) { - return this->element_unary(Op::SIGMOID, input, name); + return this->element_unary(OperatorType::SIGMOID, input, name); } Tensor ComputationGraphBuilder::tanh(Tensor const &input, std::optional const &name) { - return this->element_unary(Op::TANH, input, name); + return this->element_unary(OperatorType::TANH, input, name); } Tensor ComputationGraphBuilder::elu(Tensor const &input, std::optional const &name) { - return this->element_unary(Op::ELU, input, name); + return this->element_unary(OperatorType::ELU, input, name); } Tensor ComputationGraphBuilder::conv2d( diff --git a/lib/runtime/src/parallel_op_info.h b/lib/runtime/src/parallel_op_info.h index 11e2f03477..ebd44f012b 100644 --- a/lib/runtime/src/parallel_op_info.h +++ b/lib/runtime/src/parallel_op_info.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_PARALLEL_OPS_PARALLEL_OP_INFO_H #include "op-attrs/ff_dim.h" -#include "op-attrs/op.h" +#include "op-attrs/operator_type.h" #include "utils/visitable.h" #include #include diff --git a/lib/substitution-generator/include/substitution-generator/json.h b/lib/substitution-generator/include/substitution-generator/json.h index dbde110f8d..ebffc93a76 100644 --- a/lib/substitution-generator/include/substitution-generator/json.h +++ b/lib/substitution-generator/include/substitution-generator/json.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_SUBSTITUTION_LOADER_H #define _FLEXFLOW_SUBSTITUTION_LOADER_H -#include "op-attrs/op.h" +#include "op-attrs/operator_type.h" #include #include @@ -79,85 +79,85 @@ NLOHMANN_JSON_SERIALIZE_ENUM(PMParameter, {PM_PAD, "PM_PAD"}}) NLOHMANN_JSON_SERIALIZE_ENUM(Op, - {{Op::NOOP, "OP_NOOP"}, - {Op::CONV2D, "OP_CONV2D"}, - {Op::DROPOUT, "OP_DROPOUT"}, - {Op::LINEAR, "OP_LINEAR"}, - {Op::BATCHMATMUL, "OP_BATCHMATMUL"}, - {Op::POOL2D, "OP_POOL2D_MAX"}, - {Op::SCALAR_MULTIPLY, "OP_SCALAR_MULTIPLY"}, - {Op::SCALAR_ADD, "OP_SCALAR_ADD"}, - {Op::SCALAR_FLOOR_DIV, "OP_SCALAR_FLOOR_DIV"}, - {Op::SCALAR_TRUE_DIV, "OP_SCALAR_TRUE_DIV"}, - {Op::SCALAR_SUB, "OP_SCALAR_SUB"}, - {Op::RELU, "OP_RELU"}, - {Op::IDENTITY, "OP_IDENTITY"}, - {Op::SIGMOID, "OP_SIGMOID"}, - {Op::TANH, "OP_TANH"}, - {Op::ELU, "OP_ELU"}, - {Op::FLAT, "OP_FLAT"}, - {Op::SOFTMAX, "OP_SOFTMAX"}, - {Op::BATCHNORM, "OP_BATCHNORM"}, - {Op::CONCAT, "OP_CONCAT"}, - {Op::SPLIT, "OP_SPLIT"}, - {Op::EMBEDDING, "OP_EMBEDDING"}, - {Op::CACHE, "OP_CACHE"}, - {Op::RESHAPE, "OP_RESHAPE"}, - {Op::REVERSE, "OP_REVERSE"}, - {Op::TRANSPOSE, "OP_TRANSPOSE"}, - {Op::EW_ADD, "OP_EW_ADD"}, - {Op::EW_MUL, "OP_EW_MUL"}, - {Op::MATMUL, "OP_MATMUL"}, - {Op::MUL, "OP_MUL"}, - {Op::ENLARGE, "OP_ENLARGE"}, - {Op::SQUEEZE, "OP_SQUEEZE"}, - {Op::UNSQUEEZE, "OP_UNSQUEEZE"}, - {Op::EW_SUB, "OP_EW_SUB"}, - {Op::EW_DIV, "OP_EW_DIV"}, - {Op::EW_EQUAL, "OP_EW_EQUAL"}, - {Op::EW_GREATER, "OP_EW_GREATER"}, - {Op::EW_LESS, "OP_EW_LESS"}, - {Op::EW_MAX, "OP_EW_MAX"}, - {Op::EW_MIN, "OP_EW_MIN"}, - {Op::REDUCE_ARGMAX, "OP_REDUCE_ARGMAX"}, - {Op::REDUCE_ARGMIN, "OP_REDUCE_ARGMIN"}, - {Op::REDUCE_MAX, "OP_REDUCE_MAX"}, - {Op::REDUCE_MEAN, "OP_REDUCE_MEAN"}, - {Op::REDUCE_MIN, "OP_REDUCE_MIN"}, - {Op::REDUCE_PROD, "OP_REDUCE_PROD"}, - {Op::REDUCE_SUM, "OP_REDUCE_SUM"}, - {Op::PAD, "OP_PAD"}, - {Op::SHAPE, "OP_SHAPE"}, - {Op::SIZE, "OP_SIZE"}, - {Op::TOPK, "OP_TOPK"}, - {Op::WHERE, "OP_WHERE"}, - {Op::CEIL, "OP_CEIL"}, - {Op::CAST, "OP_CAST"}, - {Op::EXP, "OP_EXP"}, - {Op::ROUND, "OP_ROUND"}, - {Op::LOG, "OP_LOG"}, - {Op::LOGICAL_NOT, "OP_LOGICAL_NOT"}, - {Op::SQRT, "OP_SQRT"}, - {Op::SIN, "OP_SIN"}, - {Op::COS, "OP_COS"}, - {Op::LEAKYRELU, "OP_LEAKYRELU"}, - {Op::SLICE, "OP_SLICE"}, - {Op::RESIZE, "OP_RESIZE"}, - {Op::PRELU, "OP_PRELU"}, - {Op::GELU, "OP_GELU"}, - {Op::MULTIHEAD_ATTENTION, + {{OperatorType::NOOP, "OP_NOOP"}, + {OperatorType::CONV2D, "OP_CONV2D"}, + {OperatorType::DROPOUT, "OP_DROPOUT"}, + {OperatorType::LINEAR, "OP_LINEAR"}, + {OperatorType::BATCHMATMUL, "OP_BATCHMATMUL"}, + {OperatorType::POOL2D, "OP_POOL2D_MAX"}, + {OperatorType::SCALAR_MULTIPLY, "OP_SCALAR_MULTIPLY"}, + {OperatorType::SCALAR_ADD, "OP_SCALAR_ADD"}, + {OperatorType::SCALAR_FLOOR_DIV, "OP_SCALAR_FLOOR_DIV"}, + {OperatorType::SCALAR_TRUE_DIV, "OP_SCALAR_TRUE_DIV"}, + {OperatorType::SCALAR_SUB, "OP_SCALAR_SUB"}, + {OperatorType::RELU, "OP_RELU"}, + {OperatorType::IDENTITY, "OP_IDENTITY"}, + {OperatorType::SIGMOID, "OP_SIGMOID"}, + {OperatorType::TANH, "OP_TANH"}, + {OperatorType::ELU, "OP_ELU"}, + {OperatorType::FLAT, "OP_FLAT"}, + {OperatorType::SOFTMAX, "OP_SOFTMAX"}, + {OperatorType::BATCHNORM, "OP_BATCHNORM"}, + {OperatorType::CONCAT, "OP_CONCAT"}, + {OperatorType::SPLIT, "OP_SPLIT"}, + {OperatorType::EMBEDDING, "OP_EMBEDDING"}, + {OperatorType::CACHE, "OP_CACHE"}, + {OperatorType::RESHAPE, "OP_RESHAPE"}, + {OperatorType::REVERSE, "OP_REVERSE"}, + {OperatorType::TRANSPOSE, "OP_TRANSPOSE"}, + {OperatorType::EW_ADD, "OP_EW_ADD"}, + {OperatorType::EW_MUL, "OP_EW_MUL"}, + {OperatorType::MATMUL, "OP_MATMUL"}, + {OperatorType::MUL, "OP_MUL"}, + {OperatorType::ENLARGE, "OP_ENLARGE"}, + {OperatorType::SQUEEZE, "OP_SQUEEZE"}, + {OperatorType::UNSQUEEZE, "OP_UNSQUEEZE"}, + {OperatorType::EW_SUB, "OP_EW_SUB"}, + {OperatorType::EW_DIV, "OP_EW_DIV"}, + {OperatorType::EW_EQUAL, "OP_EW_EQUAL"}, + {OperatorType::EW_GREATER, "OP_EW_GREATER"}, + {OperatorType::EW_LESS, "OP_EW_LESS"}, + {OperatorType::EW_MAX, "OP_EW_MAX"}, + {OperatorType::EW_MIN, "OP_EW_MIN"}, + {OperatorType::REDUCE_ARGMAX, "OP_REDUCE_ARGMAX"}, + {OperatorType::REDUCE_ARGMIN, "OP_REDUCE_ARGMIN"}, + {OperatorType::REDUCE_MAX, "OP_REDUCE_MAX"}, + {OperatorType::REDUCE_MEAN, "OP_REDUCE_MEAN"}, + {OperatorType::REDUCE_MIN, "OP_REDUCE_MIN"}, + {OperatorType::REDUCE_PROD, "OP_REDUCE_PROD"}, + {OperatorType::REDUCE_SUM, "OP_REDUCE_SUM"}, + {OperatorType::PAD, "OP_PAD"}, + {OperatorType::SHAPE, "OP_SHAPE"}, + {OperatorType::SIZE, "OP_SIZE"}, + {OperatorType::TOPK, "OP_TOPK"}, + {OperatorType::WHERE, "OP_WHERE"}, + {OperatorType::CEIL, "OP_CEIL"}, + {OperatorType::CAST, "OP_CAST"}, + {OperatorType::EXP, "OP_EXP"}, + {OperatorType::ROUND, "OP_ROUND"}, + {OperatorType::LOG, "OP_LOG"}, + {OperatorType::LOGICAL_NOT, "OP_LOGICAL_NOT"}, + {OperatorType::SQRT, "OP_SQRT"}, + {OperatorType::SIN, "OP_SIN"}, + {OperatorType::COS, "OP_COS"}, + {OperatorType::LEAKYRELU, "OP_LEAKYRELU"}, + {OperatorType::SLICE, "OP_SLICE"}, + {OperatorType::RESIZE, "OP_RESIZE"}, + {OperatorType::PRELU, "OP_PRELU"}, + {OperatorType::GELU, "OP_GELU"}, + {OperatorType::MULTIHEAD_ATTENTION, "OP_MULTIHEAD_ATTENTION"}, - {Op::FUSED, "OP_FUSED"}, - {Op::RSQRT, "OP_RSQRT"}, - {Op::POW, "OP_POW"}, - {Op::MEAN, "OP_MEAN"}, - {Op::LAYERNORM, "OP_LAYERNORM"}, - {Op::REPARTITION, "OP_PARTITION"}, - {Op::COMBINE, "OP_COMBINE"}, - {Op::REPLICATE, "OP_REPLICATE"}, - {Op::REDUCTION, "OP_REDUCE"}, - {Op::PIPELINE, "OP_PIPELINE"}, - {Op::FUSED_PARALLEL, "OP_FUSED_PARALLEL"}}) + {OperatorType::FUSED, "OP_FUSED"}, + {OperatorType::RSQRT, "OP_RSQRT"}, + {OperatorType::POW, "OP_POW"}, + {OperatorType::MEAN, "OP_MEAN"}, + {OperatorType::LAYERNORM, "OP_LAYERNORM"}, + {OperatorType::REPARTITION, "OP_PARTITION"}, + {OperatorType::COMBINE, "OP_COMBINE"}, + {OperatorType::REPLICATE, "OP_REPLICATE"}, + {OperatorType::REDUCTION, "OP_REDUCE"}, + {OperatorType::PIPELINE, "OP_PIPELINE"}, + {OperatorType::FUSED_PARALLEL, "OP_FUSED_PARALLEL"}}) struct Parameter { PMParameter key; diff --git a/lib/substitutions/include/substitutions/operator_pattern.h b/lib/substitutions/include/substitutions/operator_pattern.h index 8fc4ebefc2..5f2be36a09 100644 --- a/lib/substitutions/include/substitutions/operator_pattern.h +++ b/lib/substitutions/include/substitutions/operator_pattern.h @@ -4,7 +4,7 @@ #include "attribute_expr.h" #include "op-attrs/activation.h" #include "op-attrs/datatype.h" -#include "op-attrs/op.h" +#include "op-attrs/operator_type.h" #include "pcg/operator.h" #include #include diff --git a/lib/substitutions/src/substitution.cc b/lib/substitutions/src/substitution.cc index 15816185ee..335d021a2b 100644 --- a/lib/substitutions/src/substitution.cc +++ b/lib/substitutions/src/substitution.cc @@ -118,28 +118,28 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, OperatorType op_type = std::get(assignments.at(OperatorAttributeKey::OP_TYPE)); switch (op_type) { - case Op::BATCHMATMUL: + case OperatorType::BATCHMATMUL: return Operator{ BatchMatmulAttrs{std::get(assignments.at( OperatorAttributeKey::A_SEQ_LENGTH_DIM)), std::get(assignments.at( OperatorAttributeKey::B_SEQ_LENGTH_DIM))}, std::nullopt}; - case Op::BATCHNORM: + case OperatorType::BATCHNORM: return Operator{BatchNormAttrs{std::get( assignments.at(OperatorAttributeKey::RELU))}, std::nullopt}; - case Op::CAST: + case OperatorType::CAST: return Operator{CastAttrs{std::get( assignments.at(OperatorAttributeKey::DATA_TYPE))}, std::nullopt}; - case Op::CONCAT: + case OperatorType::CONCAT: return Operator{ ConcatAttrs{ std::get(assignments.at(OperatorAttributeKey::AXIS)), std::get(assignments.at(OperatorAttributeKey::NUM_INPUTS))}, std::nullopt}; - case Op::CONV2D: + case OperatorType::CONV2D: return Operator{ Conv2DAttrs{ std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), @@ -154,21 +154,21 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, assignments.at(OperatorAttributeKey::ACTIVATION)), std::get(assignments.at(OperatorAttributeKey::USE_BIAS))}, std::nullopt}; - case Op::DROPOUT: + case OperatorType::DROPOUT: return Operator{DropoutAttrs{std::get(assignments.at( OperatorAttributeKey::RATE)), std::get(assignments.at( OperatorAttributeKey::SEED))}, std::nullopt}; - case Op::EW_ADD: - case Op::EW_DIV: - case Op::EW_EQUAL: - case Op::EW_GREATER: - case Op::EW_LESS: - case Op::EW_MAX: - case Op::EW_MIN: - case Op::EW_MUL: - case Op::EW_SUB: + case OperatorType::EW_ADD: + case OperatorType::EW_DIV: + case OperatorType::EW_EQUAL: + case OperatorType::EW_GREATER: + case OperatorType::EW_LESS: + case OperatorType::EW_MAX: + case OperatorType::EW_MIN: + case OperatorType::EW_MUL: + case OperatorType::EW_SUB: return Operator{ ElementBinaryAttrs{op_type, std::get(assignments.at( @@ -178,25 +178,25 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, std::get(assignments.at( OperatorAttributeKey::SHOULD_BROADCAST_RHS))}, std::nullopt}; - case Op::SCALAR_ADD: - case Op::SCALAR_FLOOR_DIV: - case Op::SCALAR_MULTIPLY: - case Op::SCALAR_SUB: - case Op::SCALAR_TRUE_DIV: + case OperatorType::SCALAR_ADD: + case OperatorType::SCALAR_FLOOR_DIV: + case OperatorType::SCALAR_MULTIPLY: + case OperatorType::SCALAR_SUB: + case OperatorType::SCALAR_TRUE_DIV: return Operator{ ElementScalarUnaryAttrs{ op_type, std::get(assignments.at(OperatorAttributeKey::SCALAR))}, std::nullopt}; - case Op::EXP: - case Op::IDENTITY: - case Op::GELU: - case Op::RSQRT: - case Op::POW: - case Op::SIN: - case Op::COS: + case OperatorType::EXP: + case OperatorType::IDENTITY: + case OperatorType::GELU: + case OperatorType::RSQRT: + case OperatorType::POW: + case OperatorType::SIN: + case OperatorType::COS: return Operator{ElementUnaryAttrs{op_type}, std::nullopt}; - case Op::EMBEDDING: + case OperatorType::EMBEDDING: return Operator{ EmbeddingAttrs{ std::get(assignments.at(OperatorAttributeKey::NUM_ENTRIES)), @@ -205,15 +205,15 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, std::get( assignments.at(OperatorAttributeKey::OP_TYPE))}, std::nullopt}; - case Op::FLAT: + case OperatorType::FLAT: return Operator{FlatAttrs{}, std::nullopt}; - case Op::GATHER: + case OperatorType::GATHER: return Operator{GatherAttrs{std::get( assignments.at(OperatorAttributeKey::DIM))}, std::nullopt}; - case Op::INPUT: + case OperatorType::INPUT: return Operator{InputAttrs{}, std::nullopt}; - case Op::LAYERNORM: + case OperatorType::LAYERNORM: return Operator{ LayerNormAttrs{ std::get>( @@ -222,7 +222,7 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, assignments.at(OperatorAttributeKey::ELEMENTWISE_AFFINE)), std::get(assignments.at(OperatorAttributeKey::EPSILON))}, std::nullopt}; - case Op::LINEAR: + case OperatorType::LINEAR: return Operator{ LinearAttrs{ std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), @@ -234,7 +234,7 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, std::get>( assignments.at(OperatorAttributeKey::REGULARIZER))}, std::nullopt}; - case Op::MULTIHEAD_ATTENTION: + case OperatorType::MULTIHEAD_ATTENTION: return Operator{ MultiHeadAttentionAttrs{ std::get(assignments.at(OperatorAttributeKey::EMBED_DIM)), @@ -247,9 +247,9 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, std::get( assignments.at(OperatorAttributeKey::ADD_ZERO_ATTN))}, std::nullopt}; - case Op::NOOP: + case OperatorType::NOOP: return Operator{NoopAttrs{}, std::nullopt}; - case Op::POOL2D: + case OperatorType::POOL2D: return Operator{ Pool2DAttrs{ std::get(assignments.at(OperatorAttributeKey::KERNEL_H)), @@ -262,13 +262,13 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, std::get( assignments.at(OperatorAttributeKey::ACTIVATION))}, std::nullopt}; - case Op::REDUCE_ARGMAX: - case Op::REDUCE_ARGMIN: - case Op::REDUCE_MAX: - case Op::REDUCE_MEAN: - case Op::REDUCE_MIN: - case Op::REDUCE_PROD: - case Op::REDUCE_SUM: + case OperatorType::REDUCE_ARGMAX: + case OperatorType::REDUCE_ARGMIN: + case OperatorType::REDUCE_MAX: + case OperatorType::REDUCE_MEAN: + case OperatorType::REDUCE_MIN: + case OperatorType::REDUCE_PROD: + case OperatorType::REDUCE_SUM: return Operator{ ReduceAttrs{ std::get>( @@ -276,57 +276,57 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, op_type, std::get(assignments.at(OperatorAttributeKey::KEEP_DIMS))}, std::nullopt}; - case Op::REVERSE: + case OperatorType::REVERSE: return Operator{ReverseAttrs{std::get( assignments.at(OperatorAttributeKey::AXIS))}, std::nullopt}; - case Op::RESHAPE: + case OperatorType::RESHAPE: return Operator{ReshapeAttrs{std::get( assignments.at(OperatorAttributeKey::SHAPE))}, std::nullopt}; - case Op::SPLIT: + case OperatorType::SPLIT: return Operator{ SplitAttrs{ std::get>( assignments.at(OperatorAttributeKey::SPLITS)), std::get(assignments.at(OperatorAttributeKey::AXIS))}, std::nullopt}; - case Op::SOFTMAX: + case OperatorType::SOFTMAX: return Operator{SoftmaxAttrs{std::get( assignments.at(OperatorAttributeKey::DIM))}, std::nullopt}; - case Op::TOPK: + case OperatorType::TOPK: return Operator{ TopKAttrs{ std::get(assignments.at(OperatorAttributeKey::K)), std::get(assignments.at(OperatorAttributeKey::SORTED))}, std::nullopt}; - case Op::TRANSPOSE: + case OperatorType::TRANSPOSE: return Operator{ TransposeAttrs{std::get>( assignments.at(OperatorAttributeKey::PERMUTATION))}, std::nullopt}; - case Op::COMBINE: + case OperatorType::COMBINE: return Operator{CombineAttrs{std::get(assignments.at( OperatorAttributeKey::PARALLEL_DIM)), std::get(assignments.at( OperatorAttributeKey::PARALLEL_DEGREE))}, std::nullopt}; - case Op::REDUCTION: + case OperatorType::REDUCTION: return Operator{ ReductionAttrs{std::get(assignments.at( OperatorAttributeKey::PARALLEL_DIM)), std::get(assignments.at( OperatorAttributeKey::PARALLEL_DEGREE))}, std::nullopt}; - case Op::REPARTITION: + case OperatorType::REPARTITION: return Operator{ RepartitionAttrs{std::get(assignments.at( OperatorAttributeKey::PARALLEL_DIM)), std::get(assignments.at( OperatorAttributeKey::PARALLEL_DEGREE))}, std::nullopt}; - case Op::REPLICATE: + case OperatorType::REPLICATE: return Operator{ ReplicateAttrs{std::get(assignments.at( OperatorAttributeKey::PARALLEL_DIM)), diff --git a/lib/substitutions/test/src/test_substitution.cc b/lib/substitutions/test/src/test_substitution.cc index 75d7e6dbcc..32a596e940 100644 --- a/lib/substitutions/test/src/test_substitution.cc +++ b/lib/substitutions/test/src/test_substitution.cc @@ -9,7 +9,7 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("apply_substitution") { OperatorPattern operator_pattern_n0{ std::vector{OperatorAttributeConstraint{ - ConstraintType::EQUAL, OperatorAttributeKey::OP_TYPE, Op::LINEAR}}}; + ConstraintType::EQUAL, OperatorAttributeKey::OP_TYPE, OperatorType::LINEAR}}}; ParallelTensorPattern tensor_pattern_e0{ std::vector{ @@ -38,12 +38,12 @@ TEST_SUITE(FF_TEST_SUITE) { GraphPattern input_graph{ig}; OperatorAttrAssignment op_ass_n1{ - {{OperatorAttributeKey::OP_TYPE, AttrConstant{Op::REPARTITION}}, + {{OperatorAttributeKey::OP_TYPE, AttrConstant{OperatorType::REPARTITION}}, {OperatorAttributeKey::PARALLEL_DIM, AttrConstant{ff_dim_t{0}}}, {OperatorAttributeKey::PARALLEL_DEGREE, AttrConstant{2}}}}; OperatorAttrAssignment op_ass_n2{ - {{OperatorAttributeKey::OP_TYPE, AttrConstant{Op::LINEAR}}, + {{OperatorAttributeKey::OP_TYPE, AttrConstant{OperatorType::LINEAR}}, {OperatorAttributeKey::OUT_CHANNELS, OperatorAttrAccess{n0, OperatorAttributeKey::OUT_CHANNELS}}, {OperatorAttributeKey::USE_BIAS, @@ -56,7 +56,7 @@ TEST_SUITE(FF_TEST_SUITE) { OperatorAttrAccess{n0, OperatorAttributeKey::REGULARIZER}}}}; OperatorAttrAssignment op_ass_n3{ - {{OperatorAttributeKey::OP_TYPE, AttrConstant{Op::REDUCTION}}, + {{OperatorAttributeKey::OP_TYPE, AttrConstant{OperatorType::REDUCTION}}, {OperatorAttributeKey::PARALLEL_DIM, AttrConstant{ff_dim_t{0}}}, {OperatorAttributeKey::PARALLEL_DEGREE, AttrConstant{2}}}}; diff --git a/lib/utils/include/utils/fmt.decl.h b/lib/utils/include/utils/fmt.decl.h index 7adb2052ad..eeebaf5d88 100644 --- a/lib/utils/include/utils/fmt.decl.h +++ b/lib/utils/include/utils/fmt.decl.h @@ -4,6 +4,7 @@ #include "fmt/format.h" #include #include +#include #define CHECK_FMTABLE(...) \ static_assert(::FlexFlow::is_fmtable<__VA_ARGS__>::value, \ @@ -44,6 +45,14 @@ struct formatter<::std::vector> : formatter<::std::string> { -> decltype(ctx.out()); }; +template +struct formatter<::std::variant> : formatter<::std::string> { + template + auto format(::std::variant const &m, FormatContext &ctx) + -> decltype(ctx.out()); +}; + + } // namespace fmt #endif diff --git a/lib/utils/include/utils/fmt.h b/lib/utils/include/utils/fmt.h index 9cb56e4e2b..fe1a2ca979 100644 --- a/lib/utils/include/utils/fmt.h +++ b/lib/utils/include/utils/fmt.h @@ -1,12 +1,11 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_FMT_H #define _FLEXFLOW_UTILS_INCLUDE_FMT_H -#include "utils/containers.decl.h" +#include "utils/containers.h" #include "utils/fmt.decl.h" #include "utils/test_types.h" #include "utils/type_traits_core.h" #include - #include namespace FlexFlow { @@ -40,7 +39,7 @@ auto formatter<::std::unordered_set>::format( -> decltype(ctx.out()) { CHECK_FMTABLE(T); - std::string result = join_strings( + std::string result = ::FlexFlow::join_strings( m.cbegin(), m.cend(), ", ", [](T const &t) { return fmt::to_string(t); }); return formatter::format(result, ctx); } @@ -52,11 +51,21 @@ auto formatter<::std::vector>::format(::std::vector const &m, -> decltype(ctx.out()) { CHECK_FMTABLE(T); - std::string result = join_strings( + std::string result = ::FlexFlow::join_strings( m.cbegin(), m.cend(), ", ", [](T const &t) { return fmt::to_string(t); }); return formatter::format(result, ctx); } +template +template +auto formatter<::std::variant>::format(::std::variant const &m, + FormatContext &ctx) + -> decltype(ctx.out()) { + + std::string result = std::visit([](auto &&x) { return fmt::to_string(x); }, m); + return formatter::format(result, ctx); +} + } // namespace fmt #endif diff --git a/lib/utils/include/utils/stack_vector.h b/lib/utils/include/utils/stack_vector.h index ce371adeba..08248003f3 100644 --- a/lib/utils/include/utils/stack_vector.h +++ b/lib/utils/include/utils/stack_vector.h @@ -311,6 +311,9 @@ struct stack_vector { implies, is_lt_comparable>::value, ""); }; +template +struct delegate_ostream_operator> : std::true_type {}; + // CHECK_FMTABLE(stack_vector); template From 341742c5048309dece41198c15710ab0d9a149f2 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Wed, 24 Apr 2024 15:37:59 -0700 Subject: [PATCH 10/43] Simplify types in substitutions, more dtgen --- .proj.toml | 9 +- flake.lock | 6 +- lib/op-attrs/CMakeLists.txt | 1 + .../{activation.h => activation.dtg.h} | 11 +- .../{aggregate_op.h => aggregate_op.dtg.h} | 11 +- .../op-attrs/{datatype_t.h => datatype.dtg.h} | 13 +- ...atatype_t.enum.toml => datatype.enum.toml} | 0 lib/op-attrs/include/op-attrs/datatype.h | 4 +- lib/op-attrs/include/op-attrs/dim_ordered.h | 2 +- .../op-attrs/{ff_dim.h => ff_dim.dtg.h} | 12 +- .../include/op-attrs/get_output_shapes.h | 7 - ...zer_attrs.h => l1_regularizer_attrs.dtg.h} | 12 +- ...zer_attrs.h => l2_regularizer_attrs.dtg.h} | 12 +- .../include/op-attrs/operator_attrs.h | 21 +- ...{operator_type_t.h => operator_type.dtg.h} | 13 +- ...pe_t.enum.toml => operator_type.enum.toml} | 0 lib/op-attrs/include/op-attrs/operator_type.h | 2 +- lib/op-attrs/include/op-attrs/ops/attention.h | 9 +- ...ttention_attrs.h => attention_attrs.dtg.h} | 12 +- ...ention_inputs.h => attention_inputs.dtg.h} | 12 +- .../include/op-attrs/ops/batch_matmul.dtg.h | 63 ++ .../include/op-attrs/ops/batch_matmul.h | 54 +- .../include/op-attrs/ops/batch_norm.h | 5 +- ...ch_norm_attrs.h => batch_norm_attrs.dtg.h} | 12 +- .../include/op-attrs/ops/broadcast.dtg.h | 56 ++ lib/op-attrs/include/op-attrs/ops/broadcast.h | 47 +- lib/op-attrs/include/op-attrs/ops/cast.h | 2 +- .../ops/{cast_attrs.h => cast_attrs.dtg.h} | 12 +- lib/op-attrs/include/op-attrs/ops/combine.h | 2 +- .../{combine_attrs.h => combine_attrs.dtg.h} | 14 +- .../op-attrs/ops/combine_attrs.struct.toml | 2 +- lib/op-attrs/include/op-attrs/ops/concat.h | 2 +- .../{concat_attrs.h => concat_attrs.dtg.h} | 14 +- .../op-attrs/ops/concat_attrs.struct.toml | 2 +- lib/op-attrs/include/op-attrs/ops/conv_2d.h | 6 +- .../{conv_2d_attrs.h => conv_2d_attrs.dtg.h} | 14 +- .../op-attrs/ops/conv_2d_attrs.struct.toml | 2 +- lib/op-attrs/include/op-attrs/ops/dropout.h | 6 +- .../{dropout_attrs.h => dropout_attrs.dtg.h} | 12 +- .../include/op-attrs/ops/element_binary.h | 5 +- ...ary_attrs.h => element_binary_attrs.dtg.h} | 12 +- ...trs.h => element_scalar_unary_attrs.dtg.h} | 12 +- .../include/op-attrs/ops/element_unary.h | 13 +- ...nary_attrs.h => element_unary_attrs.dtg.h} | 12 +- lib/op-attrs/include/op-attrs/ops/embedding.h | 4 +- ...mbedding_attrs.h => embedding_attrs.dtg.h} | 16 +- .../op-attrs/ops/embedding_attrs.struct.toml | 4 +- lib/op-attrs/include/op-attrs/ops/flat.h | 2 +- .../ops/{flat_attrs.h => flat_attrs.dtg.h} | 12 +- lib/op-attrs/include/op-attrs/ops/gather.h | 2 +- .../{gather_attrs.h => gather_attrs.dtg.h} | 14 +- .../op-attrs/ops/gather_attrs.struct.toml | 2 +- lib/op-attrs/include/op-attrs/ops/input.h | 6 +- .../ops/{input_attrs.h => input_attrs.dtg.h} | 12 +- .../include/op-attrs/ops/layer_norm.h | 4 +- ...er_norm_attrs.h => layer_norm_attrs.dtg.h} | 14 +- .../op-attrs/ops/layer_norm_attrs.struct.toml | 2 +- lib/op-attrs/include/op-attrs/ops/linear.h | 6 +- .../{linear_attrs.h => linear_attrs.dtg.h} | 18 +- .../op-attrs/ops/linear_attrs.struct.toml | 6 +- lib/op-attrs/include/op-attrs/ops/noop.h | 5 +- .../ops/{noop_attrs.h => noop_attrs.dtg.h} | 12 +- ...puts.h => parallel_attention_inputs.dtg.h} | 12 +- lib/op-attrs/include/op-attrs/ops/pool_2d.h | 6 +- .../{pool_2d_attrs.h => pool_2d_attrs.dtg.h} | 16 +- .../op-attrs/ops/pool_2d_attrs.struct.toml | 4 +- lib/op-attrs/include/op-attrs/ops/reduce.h | 6 +- .../{reduce_attrs.h => reduce_attrs.dtg.h} | 16 +- .../op-attrs/ops/reduce_attrs.struct.toml | 4 +- lib/op-attrs/include/op-attrs/ops/reduction.h | 6 +- ...eduction_attrs.h => reduction_attrs.dtg.h} | 14 +- .../op-attrs/ops/reduction_attrs.struct.toml | 2 +- .../include/op-attrs/ops/repartition.h | 6 +- ...tition_attrs.h => repartition_attrs.dtg.h} | 14 +- .../ops/repartition_attrs.struct.toml | 2 +- lib/op-attrs/include/op-attrs/ops/replicate.h | 6 +- ...eplicate_attrs.h => replicate_attrs.dtg.h} | 14 +- .../op-attrs/ops/replicate_attrs.struct.toml | 2 +- lib/op-attrs/include/op-attrs/ops/reshape.h | 5 +- .../{reshape_attrs.h => reshape_attrs.dtg.h} | 14 +- .../op-attrs/ops/reshape_attrs.struct.toml | 2 +- lib/op-attrs/include/op-attrs/ops/reverse.h | 5 +- .../{reverse_attrs.h => reverse_attrs.dtg.h} | 14 +- .../op-attrs/ops/reverse_attrs.struct.toml | 2 +- lib/op-attrs/include/op-attrs/ops/softmax.h | 6 +- .../{softmax_attrs.h => softmax_attrs.dtg.h} | 14 +- .../op-attrs/ops/softmax_attrs.struct.toml | 2 +- lib/op-attrs/include/op-attrs/ops/split.h | 7 +- .../ops/{split_attrs.h => split_attrs.dtg.h} | 14 +- .../op-attrs/ops/split_attrs.struct.toml | 2 +- lib/op-attrs/include/op-attrs/ops/topk.h | 6 +- .../ops/{topk_attrs.h => topk_attrs.dtg.h} | 12 +- lib/op-attrs/include/op-attrs/ops/transpose.h | 6 +- ...ranspose_attrs.h => transpose_attrs.dtg.h} | 14 +- .../op-attrs/ops/transpose_attrs.struct.toml | 2 +- .../{parallel_dim_t.h => parallel_dim.dtg.h} | 14 +- lib/op-attrs/include/op-attrs/parallel_dim.h | 2 +- ...t.struct.toml => parallel_dim.struct.toml} | 0 ...or_dims_t.h => parallel_tensor_dims.dtg.h} | 14 +- .../include/op-attrs/parallel_tensor_dims.h | 4 +- ....toml => parallel_tensor_dims.struct.toml} | 0 ..._shape_t.h => parallel_tensor_shape.dtg.h} | 14 +- .../include/op-attrs/parallel_tensor_shape.h | 2 +- ...toml => parallel_tensor_shape.struct.toml} | 0 .../include/op-attrs/param_sync.dtg.h | 40 + .../include/op-attrs/param_sync.enum.toml | 14 + lib/op-attrs/include/op-attrs/param_sync.h | 28 +- .../include/op-attrs/pcg_operator_attrs.dtg.h | 438 ++++++++++ .../op-attrs/pcg_operator_attrs.variant.toml | 117 +++ .../op-attrs/{pool_op.h => pool_op.dtg.h} | 11 +- .../include/op-attrs/regularizer_attrs.dtg.h | 121 +++ .../include/op-attrs/regularizer_attrs.h | 16 - .../op-attrs/regularizer_attrs.variant.toml | 22 + .../{tensor_dims_t.h => tensor_dims.dtg.h} | 14 +- lib/op-attrs/include/op-attrs/tensor_dims.h | 4 +- ..._t.struct.toml => tensor_dims.struct.toml} | 0 .../{tensor_shape_t.h => tensor_shape.dtg.h} | 14 +- lib/op-attrs/include/op-attrs/tensor_shape.h | 2 +- ...t.struct.toml => tensor_shape.struct.toml} | 0 lib/op-attrs/src/batch_matmul.cc | 26 - .../{activation.cc => activation.dtg.cc} | 7 +- .../{aggregate_op.cc => aggregate_op.dtg.cc} | 7 +- .../{datatype_t.cc => datatype.dtg.cc} | 9 +- .../src/op-attrs/{ff_dim.cc => ff_dim.dtg.cc} | 9 +- ...r_attrs.cc => l1_regularizer_attrs.dtg.cc} | 9 +- ...r_attrs.cc => l2_regularizer_attrs.dtg.cc} | 9 +- ...perator_type_t.cc => operator_type.dtg.cc} | 9 +- ...ention_attrs.cc => attention_attrs.dtg.cc} | 9 +- ...tion_inputs.cc => attention_inputs.dtg.cc} | 10 +- lib/op-attrs/src/op-attrs/ops/batch_matmul.cc | 96 +-- .../src/op-attrs/ops/batch_matmul.dtg.cc | 90 ++ lib/op-attrs/src/op-attrs/ops/batch_norm.cc | 9 + ..._norm_attrs.cc => batch_norm_attrs.dtg.cc} | 9 +- .../ops/{broadcast.cc => broadcast.dtg.cc} | 10 +- .../ops/{cast_attrs.cc => cast_attrs.dtg.cc} | 10 +- ...{combine_attrs.cc => combine_attrs.dtg.cc} | 10 +- .../{concat_attrs.cc => concat_attrs.dtg.cc} | 10 +- ...{conv_2d_attrs.cc => conv_2d_attrs.dtg.cc} | 12 +- lib/op-attrs/src/op-attrs/ops/dropout.cc | 9 + ...{dropout_attrs.cc => dropout_attrs.dtg.cc} | 9 +- ...y_attrs.cc => element_binary_attrs.dtg.cc} | 11 +- ...s.cc => element_scalar_unary_attrs.dtg.cc} | 10 +- ...ry_attrs.cc => element_unary_attrs.dtg.cc} | 10 +- .../src/{ => op-attrs/ops}/embedding.cc | 4 + ...edding_attrs.cc => embedding_attrs.dtg.cc} | 12 +- .../ops/{flat_attrs.cc => flat_attrs.dtg.cc} | 9 +- .../{gather_attrs.cc => gather_attrs.dtg.cc} | 10 +- lib/op-attrs/src/op-attrs/ops/input.cc | 9 + .../{input_attrs.cc => input_attrs.dtg.cc} | 9 +- lib/op-attrs/src/op-attrs/ops/layer_norm.cc | 9 + ..._norm_attrs.cc => layer_norm_attrs.dtg.cc} | 11 +- lib/op-attrs/src/op-attrs/ops/linear.cc | 9 + .../{linear_attrs.cc => linear_attrs.dtg.cc} | 13 +- lib/op-attrs/src/op-attrs/ops/noop.cc | 9 + .../ops/{noop_attrs.cc => noop_attrs.dtg.cc} | 9 +- ...ts.cc => parallel_attention_inputs.dtg.cc} | 10 +- lib/op-attrs/src/op-attrs/ops/pool_2d.cc | 9 + ...{pool_2d_attrs.cc => pool_2d_attrs.dtg.cc} | 11 +- lib/op-attrs/src/op-attrs/ops/reduce.cc | 9 + .../{reduce_attrs.cc => reduce_attrs.dtg.cc} | 12 +- lib/op-attrs/src/op-attrs/ops/reduction.cc | 9 + ...uction_attrs.cc => reduction_attrs.dtg.cc} | 10 +- lib/op-attrs/src/op-attrs/ops/repartition.cc | 9 + ...tion_attrs.cc => repartition_attrs.dtg.cc} | 10 +- lib/op-attrs/src/op-attrs/ops/replicate.cc | 9 + ...licate_attrs.cc => replicate_attrs.dtg.cc} | 10 +- lib/op-attrs/src/op-attrs/ops/reshape.cc | 9 + ...{reshape_attrs.cc => reshape_attrs.dtg.cc} | 10 +- lib/op-attrs/src/op-attrs/ops/reverse.cc | 9 + ...{reverse_attrs.cc => reverse_attrs.dtg.cc} | 10 +- lib/op-attrs/src/op-attrs/ops/softmax.cc | 9 + ...{softmax_attrs.cc => softmax_attrs.dtg.cc} | 10 +- lib/op-attrs/src/op-attrs/ops/split.cc | 9 + .../{split_attrs.cc => split_attrs.dtg.cc} | 11 +- lib/op-attrs/src/op-attrs/ops/topk.cc | 9 + .../ops/{topk_attrs.cc => topk_attrs.dtg.cc} | 9 +- lib/op-attrs/src/op-attrs/ops/transpose.cc | 9 + ...nspose_attrs.cc => transpose_attrs.dtg.cc} | 11 +- ...{parallel_dim_t.cc => parallel_dim.dtg.cc} | 11 +- ..._dims_t.cc => parallel_tensor_dims.dtg.cc} | 13 +- ...hape_t.cc => parallel_tensor_shape.dtg.cc} | 13 +- lib/op-attrs/src/op-attrs/param_sync.dtg.cc | 70 ++ .../src/op-attrs/pcg_operator_attrs.dtg.cc | 476 +++++++++++ .../op-attrs/{pool_op.cc => pool_op.dtg.cc} | 7 +- .../src/op-attrs/regularizer_attrs.dtg.cc | 109 +++ .../{tensor_dims_t.cc => tensor_dims.dtg.cc} | 12 +- ...{tensor_shape_t.cc => tensor_shape.dtg.cc} | 13 +- lib/op-attrs/src/operator_attrs.cc | 2 +- lib/op-attrs/test/CMakeLists.txt | 13 + lib/op-attrs/test/src/test_operator_attrs.cc | 33 + lib/pcg/include/pcg/computation_graph.dtg.h | 31 + lib/pcg/include/pcg/computation_graph.h | 19 +- .../include/pcg/computation_graph.struct.toml | 13 + .../include/pcg/computation_graph_builder.h | 215 +++-- lib/pcg/include/pcg/cpu_id_t.dtg.h | 62 ++ lib/pcg/include/pcg/cpu_id_t.struct.toml | 14 + lib/pcg/include/pcg/create_grad.dtg.h | 40 + lib/pcg/include/pcg/create_grad.enum.toml | 14 + lib/pcg/include/pcg/create_grad.h | 28 +- lib/pcg/include/pcg/device_id.h | 24 +- lib/pcg/include/pcg/device_id_t.dtg.h | 117 +++ lib/pcg/include/pcg/device_id_t.variant.toml | 22 + lib/pcg/include/pcg/device_type.dtg.h | 40 + lib/pcg/include/pcg/device_type.enum.toml | 14 + lib/pcg/include/pcg/device_type.h | 36 - .../v1/{data_type.h => data_type_value.h} | 17 - lib/pcg/include/pcg/file_format/v1/graphs.h | 68 +- .../file_format/v1/graphs/v1_graph_edge.dtg.h | 60 ++ .../v1/graphs/v1_graph_edge.struct.toml | 26 + .../v1/graphs/v1_graph_output.dtg.h | 55 ++ .../v1/graphs/v1_graph_output.struct.toml | 18 + .../v1/graphs/v1_jsonable_graph.dtg.h | 109 +++ .../v1/graphs/v1_jsonable_graph.struct.toml | 38 + .../v1/graphs/v1_multidigraph.dtg.h | 47 + .../file_format/v1/graphs/v1_multidigraph.h | 16 + .../v1/graphs/v1_multidigraph.struct.toml | 29 + .../include/pcg/file_format/v1/initializer.h | 57 -- .../pcg/file_format/v1/operator_attrs.h | 20 - .../pcg/file_format/v1/parallel_tensor.h | 37 - .../include/pcg/file_format/v1/param_sync.h | 16 - lib/pcg/include/pcg/file_format/v1/tensor.h | 36 - lib/pcg/include/pcg/gpu_id_t.dtg.h | 62 ++ lib/pcg/include/pcg/gpu_id_t.struct.toml | 14 + lib/pcg/include/pcg/initializer.h | 50 -- lib/pcg/include/pcg/initializer_attrs.dtg.h | 169 ++++ .../pcg/initializer_attrs.variant.toml | 37 + .../constant_initializer_attrs.dtg.h | 56 ++ .../constant_initializer_attrs.struct.toml | 19 + .../initializers/glorot_uniform_attrs.dtg.h | 62 ++ .../glorot_uniform_attrs.struct.toml | 14 + .../initializers/norm_initializer_attrs.dtg.h | 64 ++ .../norm_initializer_attrs.struct.toml | 22 + .../uniform_initializer_attrs.dtg.h | 58 ++ .../uniform_initializer_attrs.struct.toml | 22 + .../initializers/zero_initializer_attrs.dtg.h | 58 ++ .../zero_initializer_attrs.struct.toml | 11 + lib/pcg/include/pcg/layer.h | 33 - lib/pcg/include/pcg/layer_attrs.dtg.h | 60 ++ lib/pcg/include/pcg/layer_attrs.struct.toml | 26 + .../include/pcg/machine_specification.dtg.h | 62 ++ lib/pcg/include/pcg/machine_specification.h | 23 +- .../pcg/machine_specification.struct.toml | 30 + lib/pcg/include/pcg/machine_view.dtg.h | 58 ++ lib/pcg/include/pcg/machine_view.h | 27 +- lib/pcg/include/pcg/machine_view.struct.toml | 23 + lib/pcg/include/pcg/num_points_t.dtg.h | 62 ++ lib/pcg/include/pcg/num_points_t.struct.toml | 14 + lib/pcg/include/pcg/operator.h | 27 - lib/pcg/include/pcg/operator_guid_t.dtg.h | 46 + lib/pcg/include/pcg/operator_guid_t.h | 18 - .../include/pcg/operator_guid_t.struct.toml | 18 + lib/pcg/include/pcg/optimizer.h | 41 - lib/pcg/include/pcg/optimizer_attrs.h | 14 + .../pcg/optimizers/adam_optimizer_attrs.dtg.h | 74 ++ .../adam_optimizer_attrs.struct.toml | 38 + .../pcg/optimizers/sgd_optimizer_attrs.dtg.h | 68 ++ .../sgd_optimizer_attrs.struct.toml | 26 + .../pcg/parallel_computation_graph.dtg.h | 30 + .../include/pcg/parallel_computation_graph.h | 27 +- .../parallel_computation_graph.struct.toml | 13 + .../include/pcg/parallel_layer_attrs.dtg.h | 60 ++ .../pcg/parallel_layer_attrs.struct.toml | 24 + lib/pcg/include/pcg/parallel_tensor.h | 46 +- .../include/pcg/parallel_tensor_attrs.dtg.h | 66 ++ .../pcg/parallel_tensor_attrs.struct.toml | 34 + lib/pcg/include/pcg/serialization.h | 61 -- lib/pcg/include/pcg/side_size_t.dtg.h | 62 ++ lib/pcg/include/pcg/side_size_t.struct.toml | 14 + lib/pcg/include/pcg/strided_rectangle.dtg.h | 57 ++ lib/pcg/include/pcg/strided_rectangle.h | 57 +- .../include/pcg/strided_rectangle.struct.toml | 19 + .../include/pcg/strided_rectangle_side.dtg.h | 65 ++ lib/pcg/include/pcg/strided_rectangle_side.h | 15 + .../pcg/strided_rectangle_side.struct.toml | 22 + lib/pcg/include/pcg/tensor.h | 38 - lib/pcg/include/pcg/tensor_attrs.dtg.h | 64 ++ lib/pcg/include/pcg/tensor_attrs.struct.toml | 33 + lib/pcg/include/pcg/tensor_guid_t.dtg.h | 46 + lib/pcg/include/pcg/tensor_guid_t.h | 17 - lib/pcg/include/pcg/tensor_guid_t.struct.toml | 18 + lib/pcg/src/device_id.cc | 19 - lib/pcg/src/file_format/v1/graphs.cc | 48 +- lib/pcg/src/file_format/v1/v1.cc | 13 - lib/pcg/src/layer.cc | 9 - lib/pcg/src/machine_view.cc | 29 - lib/pcg/src/operator.cc | 9 - lib/pcg/src/parallel_computation_graph.cc | 40 - lib/pcg/src/parallel_tensor.cc | 17 - lib/pcg/src/pcg/computation_graph.dtg.cc | 22 + .../{ => pcg}/computation_graph_builder.cc | 263 +++--- lib/pcg/src/pcg/cpu_id_t.dtg.cc | 74 ++ lib/pcg/src/pcg/create_grad.dtg.cc | 70 ++ lib/pcg/src/pcg/device_id.cc | 32 + lib/pcg/src/pcg/device_id_t.dtg.cc | 103 +++ lib/pcg/src/pcg/device_type.dtg.cc | 70 ++ .../v1/graphs/v1_graph_edge.dtg.cc | 94 ++ .../v1/graphs/v1_graph_output.dtg.cc | 81 ++ .../v1/graphs/v1_jsonable_graph.dtg.cc | 10 + .../v1/graphs/v1_multidigraph.dtg.cc | 56 ++ lib/pcg/src/pcg/gpu_id_t.dtg.cc | 74 ++ lib/pcg/src/pcg/initializer_attrs.dtg.cc | 158 ++++ .../constant_initializer_attrs.dtg.cc | 80 ++ .../initializers/glorot_uniform_attrs.dtg.cc | 76 ++ .../norm_initializer_attrs.dtg.cc | 96 +++ .../uniform_initializer_attrs.dtg.cc | 95 +++ .../zero_initializer_attrs.dtg.cc | 71 ++ lib/pcg/src/pcg/layer_attrs.dtg.cc | 84 ++ lib/pcg/src/pcg/machine_specification.dtg.cc | 151 ++++ lib/pcg/src/pcg/machine_view.cc | 63 ++ lib/pcg/src/pcg/machine_view.dtg.cc | 78 ++ lib/pcg/src/pcg/num_points_t.dtg.cc | 75 ++ lib/pcg/src/pcg/operator_guid_t.dtg.cc | 59 ++ .../optimizers/adam_optimizer_attrs.dtg.cc | 192 +++++ .../pcg/optimizers/sgd_optimizer_attrs.dtg.cc | 111 +++ .../src/pcg/parallel_computation_graph.dtg.cc | 22 + lib/pcg/src/pcg/parallel_layer_attrs.dtg.cc | 83 ++ lib/pcg/src/pcg/parallel_tensor_attrs.dtg.cc | 134 +++ lib/pcg/src/pcg/side_size_t.dtg.cc | 75 ++ lib/pcg/src/pcg/strided_rectangle.dtg.cc | 77 ++ lib/pcg/src/pcg/strided_rectangle_side.cc | 14 + lib/pcg/src/pcg/strided_rectangle_side.dtg.cc | 91 ++ lib/pcg/src/pcg/tensor_attrs.dtg.cc | 133 +++ lib/pcg/src/pcg/tensor_guid_t.dtg.cc | 59 ++ lib/pcg/src/serialization.cc | 3 - lib/pcg/src/strided_rectangle.cc | 39 +- lib/pcg/src/tensor.cc | 13 - .../include/substitutions/attribute_expr.h | 40 - .../substitutions/constraint_type.dtg.h | 40 + .../substitutions/constraint_type.enum.toml | 11 + .../include/substitutions/graph_pattern.h | 30 +- .../substitutions/graph_pattern_match.h | 42 - .../include/substitutions/operator_pattern.h | 107 --- .../operator_pattern/eval_list_access.h | 15 + .../operator_pattern/eval_list_size.h | 14 + .../{ => operator_pattern}/get_attribute.h | 16 +- .../operator_attribute_constraint.dtg.h | 62 ++ .../operator_attribute_constraint.struct.toml | 28 + .../operator_attribute_expr.dtg.h | 143 ++++ .../operator_attribute_expr.h | 16 + .../operator_attribute_expr.variant.toml | 27 + .../operator_attribute_key.dtg.h | 97 +++ .../operator_attribute_key.enum.toml | 67 ++ .../operator_attribute_list_access.dtg.h | 67 ++ ...operator_attribute_list_access.struct.toml | 22 + .../operator_attribute_list_size.dtg.h | 64 ++ .../operator_attribute_list_size.struct.toml | 19 + .../operator_attribute_pattern.dtg.h | 56 ++ .../operator_attribute_pattern.struct.toml | 20 + .../operator_attribute_value.dtg.h | 264 ++++++ .../operator_attribute_value.variant.toml | 63 ++ .../operator_pattern/satisfies_constraint.h | 13 + .../operator_pattern/satisfies_pattern.h | 13 + .../include/substitutions/output_graph.h | 35 - .../output_graph/attr_constant.dtg.h | 46 + .../output_graph/attr_constant.struct.toml | 16 + .../output_graph/output_graph_expr.dtg.h | 28 + .../output_graph_expr.struct.toml | 12 + .../output_operator_attr_access.dtg.h | 49 ++ .../output_operator_attr_access.struct.toml | 23 + .../output_operator_attribute_expr.dtg.h | 119 +++ ...utput_operator_attribute_expr.variant.toml | 21 + .../output_operator_attrs_assignment.dtg.h | 49 ++ ...tput_operator_attrs_assignment.struct.toml | 21 + .../substitutions/parallel_tensor_pattern.h | 25 - .../include/substitutions/pcg_pattern.dtg.h | 31 + .../substitutions/pcg_pattern.struct.toml | 12 + .../sub_parallel_computation_graph.dtg.h | 31 + .../sub_parallel_computation_graph.h | 17 +- ...sub_parallel_computation_graph.struct.toml | 13 + .../include/substitutions/substitution.dtg.h | 38 + .../include/substitutions/substitution.h | 26 +- .../substitutions/substitution.struct.toml | 24 + .../tensor_pattern/eval_list_access.h | 14 + .../tensor_pattern/eval_list_size.h | 14 + .../tensor_pattern/get_attribute.h | 14 + .../tensor_pattern/satisfies_constraint.h | 13 + .../tensor_pattern/satisfies_pattern.h | 13 + .../tensor_attribute_constraint.dtg.h | 62 ++ .../tensor_attribute_constraint.struct.toml | 28 + .../tensor_attribute_expr.dtg.h | 141 +++ .../tensor_pattern/tensor_attribute_expr.h | 16 + .../tensor_attribute_expr.variant.toml | 27 + .../tensor_pattern/tensor_attribute_key.dtg.h | 40 + .../tensor_attribute_key.enum.toml | 14 + .../tensor_attribute_list_access.dtg.h | 66 ++ .../tensor_attribute_list_access.struct.toml | 22 + .../tensor_attribute_list_size.dtg.h | 63 ++ .../tensor_attribute_list_size.struct.toml | 19 + .../tensor_attribute_pattern.dtg.h | 56 ++ .../tensor_attribute_pattern.struct.toml | 20 + .../tensor_attribute_value.dtg.h | 118 +++ .../tensor_attribute_value.variant.toml | 21 + .../unlabelled/closed_pattern_edge.dtg.h | 39 + .../closed_pattern_edge.struct.toml | 15 + .../downward_open_pattern_edge.dtg.h | 39 + .../unlabelled/downward_open_pattern_edge.h | 12 + .../downward_open_pattern_edge.struct.toml | 15 + .../unlabelled/edge_splits.dtg.h | 36 + .../substitutions/unlabelled/edge_splits.h | 18 + .../unlabelled/edge_splits.struct.toml | 15 + .../unlabelled/find_pattern_matches.h | 18 + .../unlabelled/input_pattern_edge.dtg.h | 39 + .../unlabelled/input_pattern_edge.h | 13 + .../unlabelled/input_pattern_edge.struct.toml | 15 + .../match_additional_criterion.dtg.h | 36 + .../match_additional_criterion.struct.toml | 18 + .../unlabelled/match_split.dtg.h | 29 + .../substitutions/unlabelled/match_split.h | 18 + .../unlabelled/match_split.struct.toml | 18 + .../multidigraph_pattern_match.dtg.h | 36 + .../unlabelled/multidigraph_pattern_match.h | 17 + .../multidigraph_pattern_match.struct.toml | 24 + .../unlabelled/output_pattern_edge.dtg.h | 39 + .../unlabelled/output_pattern_edge.h | 13 + .../output_pattern_edge.struct.toml | 15 + .../unlabelled/pattern_edge.dtg.h | 39 + .../substitutions/unlabelled/pattern_edge.h | 27 + .../unlabelled/pattern_edge.struct.toml | 15 + .../unlabelled/pattern_matching.h | 24 + .../unlabelled/pattern_node.dtg.h | 39 + .../unlabelled/pattern_node.struct.toml | 15 + .../unlabelled/pattern_split.dtg.h | 47 + .../substitutions/unlabelled/pattern_split.h | 21 + .../unlabelled/pattern_split.struct.toml | 22 + .../unlabelled/unlabelled_graph_pattern.dtg.h | 24 + .../unlabelled/unlabelled_graph_pattern.h | 25 + .../unlabelled_graph_pattern.struct.toml | 10 + .../unlabelled/upward_open_pattern_edge.dtg.h | 39 + .../unlabelled/upward_open_pattern_edge.h | 12 + .../upward_open_pattern_edge.struct.toml | 15 + lib/substitutions/src/graph_pattern.cc | 257 ------ lib/substitutions/src/graph_pattern_match.cc | 305 ------- .../src/sub_parallel_computation_graph.cc | 3 - lib/substitutions/src/substitution.cc | 805 ++++++++---------- .../src/substitutions/constraint_type.dtg.cc | 64 ++ .../src/substitutions/graph_pattern.cc | 44 + .../operator_pattern/eval_list_access.cc | 39 + .../operator_pattern/eval_list_size.cc | 28 + .../operator_pattern/get_attribute.cc} | 128 ++- .../operator_attribute_constraint.dtg.cc | 121 +++ .../operator_attribute_expr.cc | 21 + .../operator_attribute_expr.dtg.cc | 137 +++ .../operator_attribute_key.dtg.cc | 505 +++++++++++ .../operator_attribute_list_access.dtg.cc | 101 +++ .../operator_attribute_list_size.dtg.cc | 88 ++ .../operator_attribute_pattern.dtg.cc | 73 ++ .../operator_attribute_value.dtg.cc | 292 +++++++ .../operator_pattern/satisfies_constraint.cc | 21 + .../operator_pattern/satisfies_pattern.cc | 11 + .../output_graph/attr_constant.dtg.cc | 59 ++ .../output_graph/output_graph_expr.dtg.cc | 20 + .../output_operator_attr_access.dtg.cc | 77 ++ .../output_operator_attribute_expr.dtg.cc | 79 ++ .../output_operator_attrs_assignment.dtg.cc | 58 ++ .../src/substitutions/pcg_pattern.dtg.cc | 21 + .../sub_parallel_computation_graph.cc | 17 + .../sub_parallel_computation_graph.dtg.cc | 22 + .../src/substitutions/substitution.cc | 153 ++++ .../src/substitutions/substitution.dtg.cc | 28 + .../tensor_pattern/eval_list_access.cc | 23 + .../tensor_pattern/eval_list_size.cc | 18 + .../tensor_pattern/get_attribute.cc | 25 + .../tensor_pattern/satisfies_constraint.cc | 17 + .../tensor_pattern/satisfies_pattern.cc | 10 + .../tensor_attribute_constraint.dtg.cc | 119 +++ .../tensor_pattern/tensor_attribute_expr.cc | 26 + .../tensor_attribute_expr.dtg.cc | 129 +++ .../tensor_attribute_key.dtg.cc | 73 ++ .../tensor_attribute_list_access.dtg.cc | 99 +++ .../tensor_attribute_list_size.dtg.cc | 87 ++ .../tensor_attribute_pattern.dtg.cc | 71 ++ .../tensor_attribute_value.dtg.cc | 105 +++ .../unlabelled/closed_pattern_edge.dtg.cc | 45 + .../unlabelled/downward_open_pattern_edge.cc | 9 + .../downward_open_pattern_edge.dtg.cc | 52 ++ .../substitutions/unlabelled/edge_splits.cc | 31 + .../unlabelled/edge_splits.dtg.cc | 31 + .../unlabelled/find_pattern_matches.cc | 141 +++ .../unlabelled/input_pattern_edge.cc | 9 + .../unlabelled/input_pattern_edge.dtg.cc | 45 + .../match_additional_criterion.dtg.cc | 25 + .../substitutions/unlabelled/match_split.cc | 70 ++ .../unlabelled/match_split.dtg.cc | 26 + .../unlabelled/multidigraph_pattern_match.cc | 54 ++ .../multidigraph_pattern_match.dtg.cc | 34 + .../unlabelled/output_pattern_edge.cc | 9 + .../unlabelled/output_pattern_edge.dtg.cc | 46 + .../substitutions/unlabelled/pattern_edge.cc | 56 ++ .../unlabelled/pattern_edge.dtg.cc | 45 + .../unlabelled/pattern_matching.cc | 74 ++ .../unlabelled/pattern_node.dtg.cc | 45 + .../substitutions/unlabelled/pattern_split.cc | 39 + .../unlabelled/pattern_split.dtg.cc | 60 ++ .../unlabelled/unlabelled_graph_pattern.cc | 43 + .../unlabelled_graph_pattern.dtg.cc | 18 + .../unlabelled/upward_open_pattern_edge.cc | 9 + .../upward_open_pattern_edge.dtg.cc | 52 ++ lib/utils/include/utils/bidict.h | 30 + lib/utils/include/utils/check_fmtable.h | 15 + lib/utils/include/utils/containers.decl.h | 20 +- lib/utils/include/utils/containers.h | 57 +- lib/utils/include/utils/fmt.decl.h | 27 +- lib/utils/include/utils/fmt.h | 82 +- lib/utils/include/utils/fmt/unordered_map.h | 43 + lib/utils/include/utils/join_strings.h | 44 + lib/utils/include/utils/json.h | 3 +- lib/utils/include/utils/optional.h | 8 +- lib/utils/include/utils/overload.h | 15 + lib/utils/include/utils/stack_string.h | 15 + lib/utils/src/utils/overload.cc | 1 + 510 files changed, 16581 insertions(+), 3263 deletions(-) rename lib/op-attrs/include/op-attrs/{activation.h => activation.dtg.h} (76%) rename lib/op-attrs/include/op-attrs/{aggregate_op.h => aggregate_op.dtg.h} (75%) rename lib/op-attrs/include/op-attrs/{datatype_t.h => datatype.dtg.h} (71%) rename lib/op-attrs/include/op-attrs/{datatype_t.enum.toml => datatype.enum.toml} (100%) rename lib/op-attrs/include/op-attrs/{ff_dim.h => ff_dim.dtg.h} (81%) rename lib/op-attrs/include/op-attrs/{l1_regularizer_attrs.h => l1_regularizer_attrs.dtg.h} (94%) rename lib/op-attrs/include/op-attrs/{l2_regularizer_attrs.h => l2_regularizer_attrs.dtg.h} (94%) rename lib/op-attrs/include/op-attrs/{operator_type_t.h => operator_type.dtg.h} (86%) rename lib/op-attrs/include/op-attrs/{operator_type_t.enum.toml => operator_type.enum.toml} (100%) rename lib/op-attrs/include/op-attrs/ops/{attention_attrs.h => attention_attrs.dtg.h} (95%) rename lib/op-attrs/include/op-attrs/ops/{attention_inputs.h => attention_inputs.dtg.h} (94%) create mode 100644 lib/op-attrs/include/op-attrs/ops/batch_matmul.dtg.h rename lib/op-attrs/include/op-attrs/ops/{batch_norm_attrs.h => batch_norm_attrs.dtg.h} (93%) create mode 100644 lib/op-attrs/include/op-attrs/ops/broadcast.dtg.h rename lib/op-attrs/include/op-attrs/ops/{cast_attrs.h => cast_attrs.dtg.h} (85%) rename lib/op-attrs/include/op-attrs/ops/{combine_attrs.h => combine_attrs.dtg.h} (92%) rename lib/op-attrs/include/op-attrs/ops/{concat_attrs.h => concat_attrs.dtg.h} (91%) rename lib/op-attrs/include/op-attrs/ops/{conv_2d_attrs.h => conv_2d_attrs.dtg.h} (93%) rename lib/op-attrs/include/op-attrs/ops/{dropout_attrs.h => dropout_attrs.dtg.h} (94%) rename lib/op-attrs/include/op-attrs/ops/{element_binary_attrs.h => element_binary_attrs.dtg.h} (94%) rename lib/op-attrs/include/op-attrs/ops/{element_scalar_unary_attrs.h => element_scalar_unary_attrs.dtg.h} (92%) rename lib/op-attrs/include/op-attrs/ops/{element_unary_attrs.h => element_unary_attrs.dtg.h} (93%) rename lib/op-attrs/include/op-attrs/ops/{embedding_attrs.h => embedding_attrs.dtg.h} (90%) rename lib/op-attrs/include/op-attrs/ops/{flat_attrs.h => flat_attrs.dtg.h} (86%) rename lib/op-attrs/include/op-attrs/ops/{gather_attrs.h => gather_attrs.dtg.h} (91%) rename lib/op-attrs/include/op-attrs/ops/{input_attrs.h => input_attrs.dtg.h} (94%) rename lib/op-attrs/include/op-attrs/ops/{layer_norm_attrs.h => layer_norm_attrs.dtg.h} (92%) rename lib/op-attrs/include/op-attrs/ops/{linear_attrs.h => linear_attrs.dtg.h} (89%) rename lib/op-attrs/include/op-attrs/ops/{noop_attrs.h => noop_attrs.dtg.h} (86%) rename lib/op-attrs/include/op-attrs/ops/{parallel_attention_inputs.h => parallel_attention_inputs.dtg.h} (93%) rename lib/op-attrs/include/op-attrs/ops/{pool_2d_attrs.h => pool_2d_attrs.dtg.h} (91%) rename lib/op-attrs/include/op-attrs/ops/{reduce_attrs.h => reduce_attrs.dtg.h} (91%) rename lib/op-attrs/include/op-attrs/ops/{reduction_attrs.h => reduction_attrs.dtg.h} (91%) rename lib/op-attrs/include/op-attrs/ops/{repartition_attrs.h => repartition_attrs.dtg.h} (91%) rename lib/op-attrs/include/op-attrs/ops/{replicate_attrs.h => replicate_attrs.dtg.h} (91%) rename lib/op-attrs/include/op-attrs/ops/{reshape_attrs.h => reshape_attrs.dtg.h} (91%) rename lib/op-attrs/include/op-attrs/ops/{reverse_attrs.h => reverse_attrs.dtg.h} (91%) rename lib/op-attrs/include/op-attrs/ops/{softmax_attrs.h => softmax_attrs.dtg.h} (91%) rename lib/op-attrs/include/op-attrs/ops/{split_attrs.h => split_attrs.dtg.h} (92%) rename lib/op-attrs/include/op-attrs/ops/{topk_attrs.h => topk_attrs.dtg.h} (86%) rename lib/op-attrs/include/op-attrs/ops/{transpose_attrs.h => transpose_attrs.dtg.h} (92%) rename lib/op-attrs/include/op-attrs/{parallel_dim_t.h => parallel_dim.dtg.h} (81%) rename lib/op-attrs/include/op-attrs/{parallel_dim_t.struct.toml => parallel_dim.struct.toml} (100%) rename lib/op-attrs/include/op-attrs/{parallel_tensor_dims_t.h => parallel_tensor_dims.dtg.h} (90%) rename lib/op-attrs/include/op-attrs/{parallel_tensor_dims_t.struct.toml => parallel_tensor_dims.struct.toml} (100%) rename lib/op-attrs/include/op-attrs/{parallel_tensor_shape_t.h => parallel_tensor_shape.dtg.h} (90%) rename lib/op-attrs/include/op-attrs/{parallel_tensor_shape_t.struct.toml => parallel_tensor_shape.struct.toml} (100%) create mode 100644 lib/op-attrs/include/op-attrs/param_sync.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/param_sync.enum.toml create mode 100644 lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml rename lib/op-attrs/include/op-attrs/{pool_op.h => pool_op.dtg.h} (75%) create mode 100644 lib/op-attrs/include/op-attrs/regularizer_attrs.dtg.h delete mode 100644 lib/op-attrs/include/op-attrs/regularizer_attrs.h create mode 100644 lib/op-attrs/include/op-attrs/regularizer_attrs.variant.toml rename lib/op-attrs/include/op-attrs/{tensor_dims_t.h => tensor_dims.dtg.h} (79%) rename lib/op-attrs/include/op-attrs/{tensor_dims_t.struct.toml => tensor_dims.struct.toml} (100%) rename lib/op-attrs/include/op-attrs/{tensor_shape_t.h => tensor_shape.dtg.h} (80%) rename lib/op-attrs/include/op-attrs/{tensor_shape_t.struct.toml => tensor_shape.struct.toml} (100%) delete mode 100644 lib/op-attrs/src/batch_matmul.cc rename lib/op-attrs/src/op-attrs/{activation.cc => activation.dtg.cc} (95%) rename lib/op-attrs/src/op-attrs/{aggregate_op.cc => aggregate_op.dtg.cc} (93%) rename lib/op-attrs/src/op-attrs/{datatype_t.cc => datatype.dtg.cc} (94%) rename lib/op-attrs/src/op-attrs/{ff_dim.cc => ff_dim.dtg.cc} (93%) rename lib/op-attrs/src/op-attrs/{l1_regularizer_attrs.cc => l1_regularizer_attrs.dtg.cc} (94%) rename lib/op-attrs/src/op-attrs/{l2_regularizer_attrs.cc => l2_regularizer_attrs.dtg.cc} (94%) rename lib/op-attrs/src/op-attrs/{operator_type_t.cc => operator_type.dtg.cc} (99%) rename lib/op-attrs/src/op-attrs/ops/{attention_attrs.cc => attention_attrs.dtg.cc} (98%) rename lib/op-attrs/src/op-attrs/ops/{attention_inputs.cc => attention_inputs.dtg.cc} (94%) create mode 100644 lib/op-attrs/src/op-attrs/ops/batch_matmul.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/batch_norm.cc rename lib/op-attrs/src/op-attrs/ops/{batch_norm_attrs.cc => batch_norm_attrs.dtg.cc} (93%) rename lib/op-attrs/src/op-attrs/ops/{broadcast.cc => broadcast.dtg.cc} (93%) rename lib/op-attrs/src/op-attrs/ops/{cast_attrs.cc => cast_attrs.dtg.cc} (91%) rename lib/op-attrs/src/op-attrs/ops/{combine_attrs.cc => combine_attrs.dtg.cc} (94%) rename lib/op-attrs/src/op-attrs/ops/{concat_attrs.cc => concat_attrs.dtg.cc} (93%) rename lib/op-attrs/src/op-attrs/ops/{conv_2d_attrs.cc => conv_2d_attrs.dtg.cc} (97%) create mode 100644 lib/op-attrs/src/op-attrs/ops/dropout.cc rename lib/op-attrs/src/op-attrs/ops/{dropout_attrs.cc => dropout_attrs.dtg.cc} (94%) rename lib/op-attrs/src/op-attrs/ops/{element_binary_attrs.cc => element_binary_attrs.dtg.cc} (95%) rename lib/op-attrs/src/op-attrs/ops/{element_scalar_unary_attrs.cc => element_scalar_unary_attrs.dtg.cc} (93%) rename lib/op-attrs/src/op-attrs/ops/{element_unary_attrs.cc => element_unary_attrs.dtg.cc} (92%) rename lib/op-attrs/src/{ => op-attrs/ops}/embedding.cc (61%) rename lib/op-attrs/src/op-attrs/ops/{embedding_attrs.cc => embedding_attrs.dtg.cc} (95%) rename lib/op-attrs/src/op-attrs/ops/{flat_attrs.cc => flat_attrs.dtg.cc} (92%) rename lib/op-attrs/src/op-attrs/ops/{gather_attrs.cc => gather_attrs.dtg.cc} (91%) create mode 100644 lib/op-attrs/src/op-attrs/ops/input.cc rename lib/op-attrs/src/op-attrs/ops/{input_attrs.cc => input_attrs.dtg.cc} (92%) create mode 100644 lib/op-attrs/src/op-attrs/ops/layer_norm.cc rename lib/op-attrs/src/op-attrs/ops/{layer_norm_attrs.cc => layer_norm_attrs.dtg.cc} (94%) create mode 100644 lib/op-attrs/src/op-attrs/ops/linear.cc rename lib/op-attrs/src/op-attrs/ops/{linear_attrs.cc => linear_attrs.dtg.cc} (95%) create mode 100644 lib/op-attrs/src/op-attrs/ops/noop.cc rename lib/op-attrs/src/op-attrs/ops/{noop_attrs.cc => noop_attrs.dtg.cc} (92%) rename lib/op-attrs/src/op-attrs/ops/{parallel_attention_inputs.cc => parallel_attention_inputs.dtg.cc} (94%) create mode 100644 lib/op-attrs/src/op-attrs/ops/pool_2d.cc rename lib/op-attrs/src/op-attrs/ops/{pool_2d_attrs.cc => pool_2d_attrs.dtg.cc} (97%) create mode 100644 lib/op-attrs/src/op-attrs/ops/reduce.cc rename lib/op-attrs/src/op-attrs/ops/{reduce_attrs.cc => reduce_attrs.dtg.cc} (92%) create mode 100644 lib/op-attrs/src/op-attrs/ops/reduction.cc rename lib/op-attrs/src/op-attrs/ops/{reduction_attrs.cc => reduction_attrs.dtg.cc} (94%) create mode 100644 lib/op-attrs/src/op-attrs/ops/repartition.cc rename lib/op-attrs/src/op-attrs/ops/{repartition_attrs.cc => repartition_attrs.dtg.cc} (94%) create mode 100644 lib/op-attrs/src/op-attrs/ops/replicate.cc rename lib/op-attrs/src/op-attrs/ops/{replicate_attrs.cc => replicate_attrs.dtg.cc} (94%) create mode 100644 lib/op-attrs/src/op-attrs/ops/reshape.cc rename lib/op-attrs/src/op-attrs/ops/{reshape_attrs.cc => reshape_attrs.dtg.cc} (91%) create mode 100644 lib/op-attrs/src/op-attrs/ops/reverse.cc rename lib/op-attrs/src/op-attrs/ops/{reverse_attrs.cc => reverse_attrs.dtg.cc} (92%) create mode 100644 lib/op-attrs/src/op-attrs/ops/softmax.cc rename lib/op-attrs/src/op-attrs/ops/{softmax_attrs.cc => softmax_attrs.dtg.cc} (91%) create mode 100644 lib/op-attrs/src/op-attrs/ops/split.cc rename lib/op-attrs/src/op-attrs/ops/{split_attrs.cc => split_attrs.dtg.cc} (92%) create mode 100644 lib/op-attrs/src/op-attrs/ops/topk.cc rename lib/op-attrs/src/op-attrs/ops/{topk_attrs.cc => topk_attrs.dtg.cc} (94%) create mode 100644 lib/op-attrs/src/op-attrs/ops/transpose.cc rename lib/op-attrs/src/op-attrs/ops/{transpose_attrs.cc => transpose_attrs.dtg.cc} (91%) rename lib/op-attrs/src/op-attrs/{parallel_dim_t.cc => parallel_dim.dtg.cc} (94%) rename lib/op-attrs/src/op-attrs/{parallel_tensor_dims_t.cc => parallel_tensor_dims.dtg.cc} (89%) rename lib/op-attrs/src/op-attrs/{parallel_tensor_shape_t.cc => parallel_tensor_shape.dtg.cc} (90%) create mode 100644 lib/op-attrs/src/op-attrs/param_sync.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/pcg_operator_attrs.dtg.cc rename lib/op-attrs/src/op-attrs/{pool_op.cc => pool_op.dtg.cc} (94%) create mode 100644 lib/op-attrs/src/op-attrs/regularizer_attrs.dtg.cc rename lib/op-attrs/src/op-attrs/{tensor_dims_t.cc => tensor_dims.dtg.cc} (90%) rename lib/op-attrs/src/op-attrs/{tensor_shape_t.cc => tensor_shape.dtg.cc} (90%) create mode 100644 lib/op-attrs/test/CMakeLists.txt create mode 100644 lib/op-attrs/test/src/test_operator_attrs.cc create mode 100644 lib/pcg/include/pcg/computation_graph.dtg.h create mode 100644 lib/pcg/include/pcg/computation_graph.struct.toml create mode 100644 lib/pcg/include/pcg/cpu_id_t.dtg.h create mode 100644 lib/pcg/include/pcg/cpu_id_t.struct.toml create mode 100644 lib/pcg/include/pcg/create_grad.dtg.h create mode 100644 lib/pcg/include/pcg/create_grad.enum.toml create mode 100644 lib/pcg/include/pcg/device_id_t.dtg.h create mode 100644 lib/pcg/include/pcg/device_id_t.variant.toml create mode 100644 lib/pcg/include/pcg/device_type.dtg.h create mode 100644 lib/pcg/include/pcg/device_type.enum.toml delete mode 100644 lib/pcg/include/pcg/device_type.h rename lib/pcg/include/pcg/file_format/v1/{data_type.h => data_type_value.h} (64%) create mode 100644 lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.dtg.h create mode 100644 lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.struct.toml create mode 100644 lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_output.dtg.h create mode 100644 lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_output.struct.toml create mode 100644 lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.h create mode 100644 lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.struct.toml create mode 100644 lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.dtg.h create mode 100644 lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.h create mode 100644 lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.struct.toml delete mode 100644 lib/pcg/include/pcg/file_format/v1/initializer.h delete mode 100644 lib/pcg/include/pcg/file_format/v1/operator_attrs.h delete mode 100644 lib/pcg/include/pcg/file_format/v1/parallel_tensor.h delete mode 100644 lib/pcg/include/pcg/file_format/v1/param_sync.h delete mode 100644 lib/pcg/include/pcg/file_format/v1/tensor.h create mode 100644 lib/pcg/include/pcg/gpu_id_t.dtg.h create mode 100644 lib/pcg/include/pcg/gpu_id_t.struct.toml delete mode 100644 lib/pcg/include/pcg/initializer.h create mode 100644 lib/pcg/include/pcg/initializer_attrs.dtg.h create mode 100644 lib/pcg/include/pcg/initializer_attrs.variant.toml create mode 100644 lib/pcg/include/pcg/initializers/constant_initializer_attrs.dtg.h create mode 100644 lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml create mode 100644 lib/pcg/include/pcg/initializers/glorot_uniform_attrs.dtg.h create mode 100644 lib/pcg/include/pcg/initializers/glorot_uniform_attrs.struct.toml create mode 100644 lib/pcg/include/pcg/initializers/norm_initializer_attrs.dtg.h create mode 100644 lib/pcg/include/pcg/initializers/norm_initializer_attrs.struct.toml create mode 100644 lib/pcg/include/pcg/initializers/uniform_initializer_attrs.dtg.h create mode 100644 lib/pcg/include/pcg/initializers/uniform_initializer_attrs.struct.toml create mode 100644 lib/pcg/include/pcg/initializers/zero_initializer_attrs.dtg.h create mode 100644 lib/pcg/include/pcg/initializers/zero_initializer_attrs.struct.toml delete mode 100644 lib/pcg/include/pcg/layer.h create mode 100644 lib/pcg/include/pcg/layer_attrs.dtg.h create mode 100644 lib/pcg/include/pcg/layer_attrs.struct.toml create mode 100644 lib/pcg/include/pcg/machine_specification.dtg.h create mode 100644 lib/pcg/include/pcg/machine_specification.struct.toml create mode 100644 lib/pcg/include/pcg/machine_view.dtg.h create mode 100644 lib/pcg/include/pcg/machine_view.struct.toml create mode 100644 lib/pcg/include/pcg/num_points_t.dtg.h create mode 100644 lib/pcg/include/pcg/num_points_t.struct.toml delete mode 100644 lib/pcg/include/pcg/operator.h create mode 100644 lib/pcg/include/pcg/operator_guid_t.dtg.h delete mode 100644 lib/pcg/include/pcg/operator_guid_t.h create mode 100644 lib/pcg/include/pcg/operator_guid_t.struct.toml delete mode 100644 lib/pcg/include/pcg/optimizer.h create mode 100644 lib/pcg/include/pcg/optimizer_attrs.h create mode 100644 lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.dtg.h create mode 100644 lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.struct.toml create mode 100644 lib/pcg/include/pcg/optimizers/sgd_optimizer_attrs.dtg.h create mode 100644 lib/pcg/include/pcg/optimizers/sgd_optimizer_attrs.struct.toml create mode 100644 lib/pcg/include/pcg/parallel_computation_graph.dtg.h create mode 100644 lib/pcg/include/pcg/parallel_computation_graph.struct.toml create mode 100644 lib/pcg/include/pcg/parallel_layer_attrs.dtg.h create mode 100644 lib/pcg/include/pcg/parallel_layer_attrs.struct.toml create mode 100644 lib/pcg/include/pcg/parallel_tensor_attrs.dtg.h create mode 100644 lib/pcg/include/pcg/parallel_tensor_attrs.struct.toml delete mode 100644 lib/pcg/include/pcg/serialization.h create mode 100644 lib/pcg/include/pcg/side_size_t.dtg.h create mode 100644 lib/pcg/include/pcg/side_size_t.struct.toml create mode 100644 lib/pcg/include/pcg/strided_rectangle.dtg.h create mode 100644 lib/pcg/include/pcg/strided_rectangle.struct.toml create mode 100644 lib/pcg/include/pcg/strided_rectangle_side.dtg.h create mode 100644 lib/pcg/include/pcg/strided_rectangle_side.h create mode 100644 lib/pcg/include/pcg/strided_rectangle_side.struct.toml delete mode 100644 lib/pcg/include/pcg/tensor.h create mode 100644 lib/pcg/include/pcg/tensor_attrs.dtg.h create mode 100644 lib/pcg/include/pcg/tensor_attrs.struct.toml create mode 100644 lib/pcg/include/pcg/tensor_guid_t.dtg.h delete mode 100644 lib/pcg/include/pcg/tensor_guid_t.h create mode 100644 lib/pcg/include/pcg/tensor_guid_t.struct.toml delete mode 100644 lib/pcg/src/device_id.cc delete mode 100644 lib/pcg/src/file_format/v1/v1.cc delete mode 100644 lib/pcg/src/layer.cc delete mode 100644 lib/pcg/src/machine_view.cc delete mode 100644 lib/pcg/src/operator.cc delete mode 100644 lib/pcg/src/parallel_computation_graph.cc delete mode 100644 lib/pcg/src/parallel_tensor.cc create mode 100644 lib/pcg/src/pcg/computation_graph.dtg.cc rename lib/pcg/src/{ => pcg}/computation_graph_builder.cc (55%) create mode 100644 lib/pcg/src/pcg/cpu_id_t.dtg.cc create mode 100644 lib/pcg/src/pcg/create_grad.dtg.cc create mode 100644 lib/pcg/src/pcg/device_id.cc create mode 100644 lib/pcg/src/pcg/device_id_t.dtg.cc create mode 100644 lib/pcg/src/pcg/device_type.dtg.cc create mode 100644 lib/pcg/src/pcg/file_format/v1/graphs/v1_graph_edge.dtg.cc create mode 100644 lib/pcg/src/pcg/file_format/v1/graphs/v1_graph_output.dtg.cc create mode 100644 lib/pcg/src/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.cc create mode 100644 lib/pcg/src/pcg/file_format/v1/graphs/v1_multidigraph.dtg.cc create mode 100644 lib/pcg/src/pcg/gpu_id_t.dtg.cc create mode 100644 lib/pcg/src/pcg/initializer_attrs.dtg.cc create mode 100644 lib/pcg/src/pcg/initializers/constant_initializer_attrs.dtg.cc create mode 100644 lib/pcg/src/pcg/initializers/glorot_uniform_attrs.dtg.cc create mode 100644 lib/pcg/src/pcg/initializers/norm_initializer_attrs.dtg.cc create mode 100644 lib/pcg/src/pcg/initializers/uniform_initializer_attrs.dtg.cc create mode 100644 lib/pcg/src/pcg/initializers/zero_initializer_attrs.dtg.cc create mode 100644 lib/pcg/src/pcg/layer_attrs.dtg.cc create mode 100644 lib/pcg/src/pcg/machine_specification.dtg.cc create mode 100644 lib/pcg/src/pcg/machine_view.cc create mode 100644 lib/pcg/src/pcg/machine_view.dtg.cc create mode 100644 lib/pcg/src/pcg/num_points_t.dtg.cc create mode 100644 lib/pcg/src/pcg/operator_guid_t.dtg.cc create mode 100644 lib/pcg/src/pcg/optimizers/adam_optimizer_attrs.dtg.cc create mode 100644 lib/pcg/src/pcg/optimizers/sgd_optimizer_attrs.dtg.cc create mode 100644 lib/pcg/src/pcg/parallel_computation_graph.dtg.cc create mode 100644 lib/pcg/src/pcg/parallel_layer_attrs.dtg.cc create mode 100644 lib/pcg/src/pcg/parallel_tensor_attrs.dtg.cc create mode 100644 lib/pcg/src/pcg/side_size_t.dtg.cc create mode 100644 lib/pcg/src/pcg/strided_rectangle.dtg.cc create mode 100644 lib/pcg/src/pcg/strided_rectangle_side.cc create mode 100644 lib/pcg/src/pcg/strided_rectangle_side.dtg.cc create mode 100644 lib/pcg/src/pcg/tensor_attrs.dtg.cc create mode 100644 lib/pcg/src/pcg/tensor_guid_t.dtg.cc delete mode 100644 lib/pcg/src/serialization.cc delete mode 100644 lib/pcg/src/tensor.cc delete mode 100644 lib/substitutions/include/substitutions/attribute_expr.h create mode 100644 lib/substitutions/include/substitutions/constraint_type.dtg.h create mode 100644 lib/substitutions/include/substitutions/constraint_type.enum.toml delete mode 100644 lib/substitutions/include/substitutions/graph_pattern_match.h delete mode 100644 lib/substitutions/include/substitutions/operator_pattern.h create mode 100644 lib/substitutions/include/substitutions/operator_pattern/eval_list_access.h create mode 100644 lib/substitutions/include/substitutions/operator_pattern/eval_list_size.h rename lib/substitutions/include/substitutions/{ => operator_pattern}/get_attribute.h (83%) create mode 100644 lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.dtg.h create mode 100644 lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.struct.toml create mode 100644 lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.dtg.h create mode 100644 lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.h create mode 100644 lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.variant.toml create mode 100644 lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.dtg.h create mode 100644 lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml create mode 100644 lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_access.dtg.h create mode 100644 lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_access.struct.toml create mode 100644 lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_size.dtg.h create mode 100644 lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_size.struct.toml create mode 100644 lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.dtg.h create mode 100644 lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.struct.toml create mode 100644 lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.dtg.h create mode 100644 lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml create mode 100644 lib/substitutions/include/substitutions/operator_pattern/satisfies_constraint.h create mode 100644 lib/substitutions/include/substitutions/operator_pattern/satisfies_pattern.h delete mode 100644 lib/substitutions/include/substitutions/output_graph.h create mode 100644 lib/substitutions/include/substitutions/output_graph/attr_constant.dtg.h create mode 100644 lib/substitutions/include/substitutions/output_graph/attr_constant.struct.toml create mode 100644 lib/substitutions/include/substitutions/output_graph/output_graph_expr.dtg.h create mode 100644 lib/substitutions/include/substitutions/output_graph/output_graph_expr.struct.toml create mode 100644 lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.dtg.h create mode 100644 lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.struct.toml create mode 100644 lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.dtg.h create mode 100644 lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.variant.toml create mode 100644 lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.dtg.h create mode 100644 lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.struct.toml delete mode 100644 lib/substitutions/include/substitutions/parallel_tensor_pattern.h create mode 100644 lib/substitutions/include/substitutions/pcg_pattern.dtg.h create mode 100644 lib/substitutions/include/substitutions/pcg_pattern.struct.toml create mode 100644 lib/substitutions/include/substitutions/sub_parallel_computation_graph.dtg.h create mode 100644 lib/substitutions/include/substitutions/sub_parallel_computation_graph.struct.toml create mode 100644 lib/substitutions/include/substitutions/substitution.dtg.h create mode 100644 lib/substitutions/include/substitutions/substitution.struct.toml create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/eval_list_access.h create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/eval_list_size.h create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/get_attribute.h create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/satisfies_constraint.h create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/satisfies_pattern.h create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_constraint.dtg.h create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_constraint.struct.toml create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.dtg.h create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.h create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.variant.toml create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.dtg.h create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.enum.toml create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.h create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.struct.toml create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_size.dtg.h create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_size.struct.toml create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.struct.toml create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.dtg.h create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.variant.toml create mode 100644 lib/substitutions/include/substitutions/unlabelled/closed_pattern_edge.dtg.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/closed_pattern_edge.struct.toml create mode 100644 lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.dtg.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.struct.toml create mode 100644 lib/substitutions/include/substitutions/unlabelled/edge_splits.dtg.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/edge_splits.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/edge_splits.struct.toml create mode 100644 lib/substitutions/include/substitutions/unlabelled/find_pattern_matches.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.dtg.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.struct.toml create mode 100644 lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.dtg.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.struct.toml create mode 100644 lib/substitutions/include/substitutions/unlabelled/match_split.dtg.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/match_split.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/match_split.struct.toml create mode 100644 lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.dtg.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.struct.toml create mode 100644 lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.dtg.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.struct.toml create mode 100644 lib/substitutions/include/substitutions/unlabelled/pattern_edge.dtg.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/pattern_edge.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/pattern_edge.struct.toml create mode 100644 lib/substitutions/include/substitutions/unlabelled/pattern_matching.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/pattern_node.dtg.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/pattern_node.struct.toml create mode 100644 lib/substitutions/include/substitutions/unlabelled/pattern_split.dtg.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/pattern_split.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/pattern_split.struct.toml create mode 100644 lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.dtg.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.struct.toml create mode 100644 lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.dtg.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.struct.toml delete mode 100644 lib/substitutions/src/graph_pattern.cc delete mode 100644 lib/substitutions/src/graph_pattern_match.cc delete mode 100644 lib/substitutions/src/sub_parallel_computation_graph.cc create mode 100644 lib/substitutions/src/substitutions/constraint_type.dtg.cc create mode 100644 lib/substitutions/src/substitutions/graph_pattern.cc create mode 100644 lib/substitutions/src/substitutions/operator_pattern/eval_list_access.cc create mode 100644 lib/substitutions/src/substitutions/operator_pattern/eval_list_size.cc rename lib/substitutions/src/{operator_attributes.cc => substitutions/operator_pattern/get_attribute.cc} (73%) create mode 100644 lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.dtg.cc create mode 100644 lib/substitutions/src/substitutions/operator_pattern/operator_attribute_expr.cc create mode 100644 lib/substitutions/src/substitutions/operator_pattern/operator_attribute_expr.dtg.cc create mode 100644 lib/substitutions/src/substitutions/operator_pattern/operator_attribute_key.dtg.cc create mode 100644 lib/substitutions/src/substitutions/operator_pattern/operator_attribute_list_access.dtg.cc create mode 100644 lib/substitutions/src/substitutions/operator_pattern/operator_attribute_list_size.dtg.cc create mode 100644 lib/substitutions/src/substitutions/operator_pattern/operator_attribute_pattern.dtg.cc create mode 100644 lib/substitutions/src/substitutions/operator_pattern/operator_attribute_value.dtg.cc create mode 100644 lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc create mode 100644 lib/substitutions/src/substitutions/operator_pattern/satisfies_pattern.cc create mode 100644 lib/substitutions/src/substitutions/output_graph/attr_constant.dtg.cc create mode 100644 lib/substitutions/src/substitutions/output_graph/output_graph_expr.dtg.cc create mode 100644 lib/substitutions/src/substitutions/output_graph/output_operator_attr_access.dtg.cc create mode 100644 lib/substitutions/src/substitutions/output_graph/output_operator_attribute_expr.dtg.cc create mode 100644 lib/substitutions/src/substitutions/output_graph/output_operator_attrs_assignment.dtg.cc create mode 100644 lib/substitutions/src/substitutions/pcg_pattern.dtg.cc create mode 100644 lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc create mode 100644 lib/substitutions/src/substitutions/sub_parallel_computation_graph.dtg.cc create mode 100644 lib/substitutions/src/substitutions/substitution.cc create mode 100644 lib/substitutions/src/substitutions/substitution.dtg.cc create mode 100644 lib/substitutions/src/substitutions/tensor_pattern/eval_list_access.cc create mode 100644 lib/substitutions/src/substitutions/tensor_pattern/eval_list_size.cc create mode 100644 lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc create mode 100644 lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc create mode 100644 lib/substitutions/src/substitutions/tensor_pattern/satisfies_pattern.cc create mode 100644 lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_constraint.dtg.cc create mode 100644 lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_expr.cc create mode 100644 lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_expr.dtg.cc create mode 100644 lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_key.dtg.cc create mode 100644 lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.cc create mode 100644 lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_list_size.dtg.cc create mode 100644 lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_pattern.dtg.cc create mode 100644 lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_value.dtg.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/closed_pattern_edge.dtg.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/downward_open_pattern_edge.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/downward_open_pattern_edge.dtg.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/edge_splits.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/edge_splits.dtg.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.dtg.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/match_additional_criterion.dtg.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/match_split.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/match_split.dtg.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.dtg.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/output_pattern_edge.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/output_pattern_edge.dtg.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/pattern_edge.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/pattern_edge.dtg.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/pattern_node.dtg.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/pattern_split.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/pattern_split.dtg.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.dtg.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/upward_open_pattern_edge.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/upward_open_pattern_edge.dtg.cc create mode 100644 lib/utils/include/utils/check_fmtable.h create mode 100644 lib/utils/include/utils/fmt/unordered_map.h create mode 100644 lib/utils/include/utils/join_strings.h create mode 100644 lib/utils/include/utils/overload.h create mode 100644 lib/utils/src/utils/overload.cc diff --git a/.proj.toml b/.proj.toml index a4592dcccc..3f4fcddaad 100644 --- a/.proj.toml +++ b/.proj.toml @@ -7,13 +7,14 @@ build_targets = [ "utils", "op-attrs", "kernels", - "substitutions", - "compiler", + # "substitutions", + # "compiler", ] test_targets = [ "utils-tests", - "substitutions-tests", - "compiler-tests", + "op-attrs-tests", + # "substitutions-tests", + # "compiler-tests", ] [cmake_flags_extra] diff --git a/flake.lock b/flake.lock index 893bb00c6c..9745539839 100644 --- a/flake.lock +++ b/flake.lock @@ -43,11 +43,11 @@ ] }, "locked": { - "lastModified": 1712564474, - "narHash": "sha256-4+QCnVuTpCSTxtTcH/NmAfsH0XvU6MLdoMNiSiMCCaE=", + "lastModified": 1713942681, + "narHash": "sha256-thpBjg7m0wCqmcLzLZdZqXIW2sfwUpiBrHriimfeoZU=", "owner": "lockshaw", "repo": "proj", - "rev": "2c9d234aefa756d7800c966589cac8874f6f21a0", + "rev": "f9ee9aa7de919734228518f76f0f02d5fcdbb295", "type": "github" }, "original": { diff --git a/lib/op-attrs/CMakeLists.txt b/lib/op-attrs/CMakeLists.txt index 778be53d7c..9a9721ef2d 100644 --- a/lib/op-attrs/CMakeLists.txt +++ b/lib/op-attrs/CMakeLists.txt @@ -12,3 +12,4 @@ ff_add_library( ) add_subdirectory(ffi) +add_subdirectory(test) diff --git a/lib/op-attrs/include/op-attrs/activation.h b/lib/op-attrs/include/op-attrs/activation.dtg.h similarity index 76% rename from lib/op-attrs/include/op-attrs/activation.h rename to lib/op-attrs/include/op-attrs/activation.dtg.h index 4f16289652..a4c0e97882 100644 --- a/lib/op-attrs/include/op-attrs/activation.h +++ b/lib/op-attrs/include/op-attrs/activation.dtg.h @@ -1,9 +1,14 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/activation.enum.toml +/* proj-data +{ + "generated_from": "2b0d2e3e825732838aa5be99f2f0e6df" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_ACTIVATION_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_ACTIVATION_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_ACTIVATION_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_ACTIVATION_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" @@ -32,4 +37,4 @@ struct Arbitrary { }; } // namespace rc -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_ACTIVATION_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_ACTIVATION_DTG_H diff --git a/lib/op-attrs/include/op-attrs/aggregate_op.h b/lib/op-attrs/include/op-attrs/aggregate_op.dtg.h similarity index 75% rename from lib/op-attrs/include/op-attrs/aggregate_op.h rename to lib/op-attrs/include/op-attrs/aggregate_op.dtg.h index e2c7177d9f..3ff3848dca 100644 --- a/lib/op-attrs/include/op-attrs/aggregate_op.h +++ b/lib/op-attrs/include/op-attrs/aggregate_op.dtg.h @@ -1,9 +1,14 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/aggregate_op.enum.toml +/* proj-data +{ + "generated_from": "441fe9b0bb8f2dc2b31f74c58320ef30" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_AGGREGATE_OP_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_AGGREGATE_OP_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_AGGREGATE_OP_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_AGGREGATE_OP_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" @@ -32,4 +37,4 @@ struct Arbitrary { }; } // namespace rc -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_AGGREGATE_OP_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_AGGREGATE_OP_DTG_H diff --git a/lib/op-attrs/include/op-attrs/datatype_t.h b/lib/op-attrs/include/op-attrs/datatype.dtg.h similarity index 71% rename from lib/op-attrs/include/op-attrs/datatype_t.h rename to lib/op-attrs/include/op-attrs/datatype.dtg.h index 67d7278927..7052dba3b3 100644 --- a/lib/op-attrs/include/op-attrs/datatype_t.h +++ b/lib/op-attrs/include/op-attrs/datatype.dtg.h @@ -1,9 +1,14 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify -// lib/op-attrs/include/op-attrs/datatype_t.enum.toml +// lib/op-attrs/include/op-attrs/datatype.enum.toml +/* proj-data +{ + "generated_from": "8315d0aa0a65b00c13aa580e923592ef" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DATATYPE_T_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DATATYPE_T_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DATATYPE_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DATATYPE_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" @@ -32,4 +37,4 @@ struct Arbitrary { }; } // namespace rc -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DATATYPE_T_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DATATYPE_DTG_H diff --git a/lib/op-attrs/include/op-attrs/datatype_t.enum.toml b/lib/op-attrs/include/op-attrs/datatype.enum.toml similarity index 100% rename from lib/op-attrs/include/op-attrs/datatype_t.enum.toml rename to lib/op-attrs/include/op-attrs/datatype.enum.toml diff --git a/lib/op-attrs/include/op-attrs/datatype.h b/lib/op-attrs/include/op-attrs/datatype.h index d4c61e9895..f3f3c4a08e 100644 --- a/lib/op-attrs/include/op-attrs/datatype.h +++ b/lib/op-attrs/include/op-attrs/datatype.h @@ -4,7 +4,7 @@ #include "utils/fmt.h" #include "utils/fp16.h" #include -#include "op-attrs/datatype_t.h" +#include "op-attrs/datatype.dtg.h" namespace FlexFlow { @@ -53,7 +53,7 @@ using DataTypeValue = std::variant, real_type, real_type, real_type, - real_type, + /* real_type, */ real_type>; size_t size_of(DataType); diff --git a/lib/op-attrs/include/op-attrs/dim_ordered.h b/lib/op-attrs/include/op-attrs/dim_ordered.h index ae1507914d..d0e9ef9a4d 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_FF_STACK_VECTOR_H #define _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_FF_STACK_VECTOR_H -#include "op-attrs/ff_dim.h" +#include "op-attrs/ff_dim.dtg.h" #include "utils/stack_vector.h" #include "utils/json.h" diff --git a/lib/op-attrs/include/op-attrs/ff_dim.h b/lib/op-attrs/include/op-attrs/ff_dim.dtg.h similarity index 81% rename from lib/op-attrs/include/op-attrs/ff_dim.h rename to lib/op-attrs/include/op-attrs/ff_dim.dtg.h index d7f590aeac..8363ef0207 100644 --- a/lib/op-attrs/include/op-attrs/ff_dim.h +++ b/lib/op-attrs/include/op-attrs/ff_dim.dtg.h @@ -1,15 +1,19 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ff_dim.struct.toml +/* proj-data +{ + "generated_from": "ffd119eb46e048b0f5a2d8fbef253de3" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" #include #include -#include #include namespace FlexFlow { @@ -47,4 +51,4 @@ std::string format_as(ff_dim_t const &); std::ostream &operator<<(std::ostream &, ff_dim_t const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_DTG_H diff --git a/lib/op-attrs/include/op-attrs/get_output_shapes.h b/lib/op-attrs/include/op-attrs/get_output_shapes.h index 6fb93aac91..5f8732a9d7 100644 --- a/lib/op-attrs/include/op-attrs/get_output_shapes.h +++ b/lib/op-attrs/include/op-attrs/get_output_shapes.h @@ -125,13 +125,6 @@ ParallelTensorShape get_output_shape(Conv2DAttrs const &, ParallelTensorShape const &); ParallelTensorShape get_output_shape(DropoutAttrs const &, ParallelTensorShape const &); -ParallelTensorShape get_output_shape(ElementBinaryAttrs const &, - ParallelTensorShape const &, - ParallelTensorShape const &); -ParallelTensorShape get_output_shape(ElementUnaryAttrs const &, - ParallelTensorShape const &); -ParallelTensorShape get_output_shape(ElementScalarUnaryAttrs const &, - ParallelTensorShape const &); ParallelTensorShape get_output_shape(EmbeddingAttrs const &, ParallelTensorShape const &); ParallelTensorShape get_output_shape(FlatAttrs const &, diff --git a/lib/op-attrs/include/op-attrs/l1_regularizer_attrs.h b/lib/op-attrs/include/op-attrs/l1_regularizer_attrs.dtg.h similarity index 94% rename from lib/op-attrs/include/op-attrs/l1_regularizer_attrs.h rename to lib/op-attrs/include/op-attrs/l1_regularizer_attrs.dtg.h index 18afd8a38b..1d4747db7e 100644 --- a/lib/op-attrs/include/op-attrs/l1_regularizer_attrs.h +++ b/lib/op-attrs/include/op-attrs/l1_regularizer_attrs.dtg.h @@ -1,16 +1,20 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/l1_regularizer_attrs.struct.toml +/* proj-data +{ + "generated_from": "50968fb8a3d43395d0eab7594f4935c0" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_L1_REGULARIZER_ATTRS_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_L1_REGULARIZER_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_L1_REGULARIZER_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_L1_REGULARIZER_ATTRS_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" #include "rapidcheck.h" #include #include -#include #include namespace FlexFlow { @@ -55,4 +59,4 @@ std::string format_as(L1RegularizerAttrs const &); std::ostream &operator<<(std::ostream &, L1RegularizerAttrs const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_L1_REGULARIZER_ATTRS_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_L1_REGULARIZER_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/l2_regularizer_attrs.h b/lib/op-attrs/include/op-attrs/l2_regularizer_attrs.dtg.h similarity index 94% rename from lib/op-attrs/include/op-attrs/l2_regularizer_attrs.h rename to lib/op-attrs/include/op-attrs/l2_regularizer_attrs.dtg.h index 3b403334dc..981d3f4905 100644 --- a/lib/op-attrs/include/op-attrs/l2_regularizer_attrs.h +++ b/lib/op-attrs/include/op-attrs/l2_regularizer_attrs.dtg.h @@ -1,16 +1,20 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/l2_regularizer_attrs.struct.toml +/* proj-data +{ + "generated_from": "c4f182e547ab6f0d5613e7eeb95d438e" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_L2_REGULARIZER_ATTRS_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_L2_REGULARIZER_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_L2_REGULARIZER_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_L2_REGULARIZER_ATTRS_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" #include "rapidcheck.h" #include #include -#include #include namespace FlexFlow { @@ -55,4 +59,4 @@ std::string format_as(L2RegularizerAttrs const &); std::ostream &operator<<(std::ostream &, L2RegularizerAttrs const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_L2_REGULARIZER_ATTRS_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_L2_REGULARIZER_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/operator_attrs.h b/lib/op-attrs/include/op-attrs/operator_attrs.h index 09b290d7ef..1821839e5c 100644 --- a/lib/op-attrs/include/op-attrs/operator_attrs.h +++ b/lib/op-attrs/include/op-attrs/operator_attrs.h @@ -34,6 +34,7 @@ #include "utils/variant.h" #include #include "utils/record_formatter.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" namespace FlexFlow { @@ -94,8 +95,8 @@ using ComputationGraphAttrs = variant_join>; using CompGraphOperatorAttrs = ComputationGraphAttrs; -using PCGOperatorAttrs = - variant_join; +/* using PCGOperatorAttrs = */ +/* variant_join; */ static_assert(is_equal_comparable::value, "ComputationGraphAttrs must support =="); @@ -108,14 +109,14 @@ static_assert(is_lt_comparable::value, static_assert(is_hashable::value, "ComputationGraphAttrs must be hashable"); -static_assert(is_equal_comparable::value, - "PCGOperatorAttrs must support =="); -static_assert(is_neq_comparable::value, - "PCGOperatorAttrs must support !="); -static_assert(is_lt_comparable::value, - "PCGOperatorAttrs must support <"); -static_assert(is_hashable::value, - "PCGOperatorAttrs must be hashable"); +/* static_assert(is_equal_comparable::value, */ +/* "PCGOperatorAttrs must support =="); */ +/* static_assert(is_neq_comparable::value, */ +/* "PCGOperatorAttrs must support !="); */ +/* static_assert(is_lt_comparable::value, */ +/* "PCGOperatorAttrs must support <"); */ +/* static_assert(is_hashable::value, */ +/* "PCGOperatorAttrs must be hashable"); */ /* OperatorType get_op_type(CompGraphOperatorAttrs const &); */ /* OperatorType get_op_type(PCGOperatorAttrs const &); */ diff --git a/lib/op-attrs/include/op-attrs/operator_type_t.h b/lib/op-attrs/include/op-attrs/operator_type.dtg.h similarity index 86% rename from lib/op-attrs/include/op-attrs/operator_type_t.h rename to lib/op-attrs/include/op-attrs/operator_type.dtg.h index 170dcc65c4..3b4bd86552 100644 --- a/lib/op-attrs/include/op-attrs/operator_type_t.h +++ b/lib/op-attrs/include/op-attrs/operator_type.dtg.h @@ -1,9 +1,14 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify -// lib/op-attrs/include/op-attrs/operator_type_t.enum.toml +// lib/op-attrs/include/op-attrs/operator_type.enum.toml +/* proj-data +{ + "generated_from": "c1c4687ef2fbc7dad996e5c25d47124c" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_TYPE_T_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_TYPE_T_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_TYPE_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_TYPE_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" @@ -116,4 +121,4 @@ struct Arbitrary { }; } // namespace rc -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_TYPE_T_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_TYPE_DTG_H diff --git a/lib/op-attrs/include/op-attrs/operator_type_t.enum.toml b/lib/op-attrs/include/op-attrs/operator_type.enum.toml similarity index 100% rename from lib/op-attrs/include/op-attrs/operator_type_t.enum.toml rename to lib/op-attrs/include/op-attrs/operator_type.enum.toml diff --git a/lib/op-attrs/include/op-attrs/operator_type.h b/lib/op-attrs/include/op-attrs/operator_type.h index ef7172eaa2..4750af51ee 100644 --- a/lib/op-attrs/include/op-attrs/operator_type.h +++ b/lib/op-attrs/include/op-attrs/operator_type.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_TYPE_H #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_TYPE_H -#include "op-attrs/operator_type_t.h" +#include "op-attrs/operator_type.dtg.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/ops/attention.h b/lib/op-attrs/include/op-attrs/ops/attention.h index 84c52895e1..ae0e791a4e 100644 --- a/lib/op-attrs/include/op-attrs/ops/attention.h +++ b/lib/op-attrs/include/op-attrs/ops/attention.h @@ -2,11 +2,10 @@ #define _FLEXFLOW_ATTENTION_ATTRS_H #include "core.h" -#include "op-attrs/ops/attention_attrs.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "op-attrs/ops/attention_inputs.h" -#include "op-attrs/ops/parallel_attention_inputs.h" -#include "utils/visitable.h" +#include "op-attrs/ops/attention_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/ops/attention_inputs.dtg.h" +#include "op-attrs/ops/parallel_attention_inputs.dtg.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/ops/attention_attrs.h b/lib/op-attrs/include/op-attrs/ops/attention_attrs.dtg.h similarity index 95% rename from lib/op-attrs/include/op-attrs/ops/attention_attrs.h rename to lib/op-attrs/include/op-attrs/ops/attention_attrs.dtg.h index 029ddc08ac..18b2906759 100644 --- a/lib/op-attrs/include/op-attrs/ops/attention_attrs.h +++ b/lib/op-attrs/include/op-attrs/ops/attention_attrs.dtg.h @@ -1,16 +1,20 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/attention_attrs.struct.toml +/* proj-data +{ + "generated_from": "360324465947562229dc6632a9e9a2f3" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_ATTRS_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_ATTRS_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" #include "rapidcheck.h" #include #include -#include #include namespace FlexFlow { @@ -69,4 +73,4 @@ std::string format_as(MultiHeadAttentionAttrs const &); std::ostream &operator<<(std::ostream &, MultiHeadAttentionAttrs const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_ATTRS_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/attention_inputs.h b/lib/op-attrs/include/op-attrs/ops/attention_inputs.dtg.h similarity index 94% rename from lib/op-attrs/include/op-attrs/ops/attention_inputs.h rename to lib/op-attrs/include/op-attrs/ops/attention_inputs.dtg.h index 3d2d4d1d74..bc1116eb17 100644 --- a/lib/op-attrs/include/op-attrs/ops/attention_inputs.h +++ b/lib/op-attrs/include/op-attrs/ops/attention_inputs.dtg.h @@ -1,16 +1,20 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/attention_inputs.struct.toml +/* proj-data +{ + "generated_from": "700f5fb734284b7feabbdd4cb61f3183" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_INPUTS_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_INPUTS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_INPUTS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_INPUTS_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" #include "op-attrs/tensor_shape.h" #include #include -#include #include namespace FlexFlow { @@ -52,4 +56,4 @@ std::string format_as(MultiHeadAttentionInputs const &); std::ostream &operator<<(std::ostream &, MultiHeadAttentionInputs const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_INPUTS_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_INPUTS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.dtg.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.dtg.h new file mode 100644 index 0000000000..a8ab52d2b3 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.dtg.h @@ -0,0 +1,63 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/batch_matmul.struct.toml +/* proj-data +{ + "generated_from": "c3bbf4c76982ef27107b74e1e6e5d360" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_MATMUL_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_MATMUL_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct BatchMatmulAttrs { + BatchMatmulAttrs() = delete; + BatchMatmulAttrs(int const &a_seq_length_dim, int const &b_seq_length_dim); + + bool operator==(BatchMatmulAttrs const &) const; + bool operator!=(BatchMatmulAttrs const &) const; + bool operator<(BatchMatmulAttrs const &) const; + bool operator>(BatchMatmulAttrs const &) const; + bool operator<=(BatchMatmulAttrs const &) const; + bool operator>=(BatchMatmulAttrs const &) const; + int a_seq_length_dim; + int b_seq_length_dim; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::BatchMatmulAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::BatchMatmulAttrs from_json(json const &); + static void to_json(json &, FlexFlow::BatchMatmulAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(BatchMatmulAttrs const &); +std::ostream &operator<<(std::ostream &, BatchMatmulAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_MATMUL_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h index db781f547b..7860f891e3 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h @@ -1,59 +1,13 @@ -// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! -// If you would like to modify this datatype, instead modify -// lib/op-attrs/include/op-attrs/ops/batch_matmul.struct.toml - #ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_MATMUL_H #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_MATMUL_H -#include "fmt/format.h" -#include "nlohmann/json.hpp" -#include "rapidcheck.h" -#include -#include -#include -#include +#include "op-attrs/ops/batch_matmul.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { -struct BatchMatmulAttrs { - BatchMatmulAttrs() = delete; - BatchMatmulAttrs(int const &a_seq_length_dim, int const &b_seq_length_dim); - - bool operator==(BatchMatmulAttrs const &) const; - bool operator!=(BatchMatmulAttrs const &) const; - bool operator<(BatchMatmulAttrs const &) const; - bool operator>(BatchMatmulAttrs const &) const; - bool operator<=(BatchMatmulAttrs const &) const; - bool operator>=(BatchMatmulAttrs const &) const; - int a_seq_length_dim; - int b_seq_length_dim; -}; -} // namespace FlexFlow -namespace std { -template <> -struct hash { - size_t operator()(FlexFlow::BatchMatmulAttrs const &) const; -}; -} // namespace std +bool is_valid(BatchMatmulAttrs const &, ParallelTensorShape const &, ParallelTensorShape const &); -namespace nlohmann { -template <> -struct adl_serializer { - static FlexFlow::BatchMatmulAttrs from_json(json const &); - static void to_json(json &, FlexFlow::BatchMatmulAttrs const &); -}; -} // namespace nlohmann - -namespace rc { -template <> -struct Arbitrary { - static Gen arbitrary(); -}; -} // namespace rc - -namespace FlexFlow { -std::string format_as(BatchMatmulAttrs const &); -std::ostream &operator<<(std::ostream &, BatchMatmulAttrs const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_MATMUL_H +#endif diff --git a/lib/op-attrs/include/op-attrs/ops/batch_norm.h b/lib/op-attrs/include/op-attrs/ops/batch_norm.h index 8f42c3cc74..3230ab4239 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_norm.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_norm.h @@ -2,13 +2,12 @@ #define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_NORM_H #include "core.h" -#include "op-attrs/ops/batch_norm_attrs.h" +#include "op-attrs/ops/batch_norm_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" namespace FlexFlow { -ParallelTensorShape get_output_shape(BatchNormAttrs const &); +ParallelTensorShape get_output_shape(BatchNormAttrs const &, ParallelTensorShape const &); CHECK_VALID_OP_ATTR(BatchNormAttrs); diff --git a/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.h b/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.dtg.h similarity index 93% rename from lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.h rename to lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.dtg.h index f786d730c8..f153bfde7e 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.dtg.h @@ -1,16 +1,20 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml +/* proj-data +{ + "generated_from": "f8e0219d8a3e008a73c38cf84d25f66e" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_NORM_ATTRS_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_NORM_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_NORM_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_NORM_ATTRS_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" #include "rapidcheck.h" #include #include -#include #include namespace FlexFlow { @@ -55,4 +59,4 @@ std::string format_as(BatchNormAttrs const &); std::ostream &operator<<(std::ostream &, BatchNormAttrs const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_NORM_ATTRS_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_NORM_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/broadcast.dtg.h b/lib/op-attrs/include/op-attrs/ops/broadcast.dtg.h new file mode 100644 index 0000000000..b940ccc2b3 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/broadcast.dtg.h @@ -0,0 +1,56 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/broadcast.struct.toml +/* proj-data +{ + "generated_from": "890d0e63a08a30d925aa170aea6992ba" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BROADCAST_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BROADCAST_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "utils/stack_vector.h" +#include +#include +#include + +namespace FlexFlow { +struct BroadcastAttrs { + BroadcastAttrs() = delete; + BroadcastAttrs( + ::FlexFlow::stack_vector const &target_dims); + + bool operator==(BroadcastAttrs const &) const; + bool operator!=(BroadcastAttrs const &) const; + bool operator<(BroadcastAttrs const &) const; + bool operator>(BroadcastAttrs const &) const; + bool operator<=(BroadcastAttrs const &) const; + bool operator>=(BroadcastAttrs const &) const; + ::FlexFlow::stack_vector target_dims; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::BroadcastAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::BroadcastAttrs from_json(json const &); + static void to_json(json &, FlexFlow::BroadcastAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(BroadcastAttrs const &); +std::ostream &operator<<(std::ostream &, BroadcastAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BROADCAST_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/broadcast.h b/lib/op-attrs/include/op-attrs/ops/broadcast.h index 2d26a5a51d..9ee96458b9 100644 --- a/lib/op-attrs/include/op-attrs/ops/broadcast.h +++ b/lib/op-attrs/include/op-attrs/ops/broadcast.h @@ -1,52 +1,13 @@ -// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! -// If you would like to modify this datatype, instead modify -// lib/op-attrs/include/op-attrs/ops/broadcast.struct.toml - #ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BROADCAST_H #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BROADCAST_H -#include "fmt/format.h" -#include "nlohmann/json.hpp" -#include "utils/stack_vector.h" -#include -#include -#include -#include +#include "op-attrs/ops/broadcast.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { -struct BroadcastAttrs { - BroadcastAttrs() = delete; - BroadcastAttrs( - ::FlexFlow::stack_vector const &target_dims); - - bool operator==(BroadcastAttrs const &) const; - bool operator!=(BroadcastAttrs const &) const; - bool operator<(BroadcastAttrs const &) const; - bool operator>(BroadcastAttrs const &) const; - bool operator<=(BroadcastAttrs const &) const; - bool operator>=(BroadcastAttrs const &) const; - ::FlexFlow::stack_vector target_dims; -}; -} // namespace FlexFlow -namespace std { -template <> -struct hash { - size_t operator()(FlexFlow::BroadcastAttrs const &) const; -}; -} // namespace std +ParallelTensorShape get_output_shape(BroadcastAttrs const &, ParallelTensorShape const &); -namespace nlohmann { -template <> -struct adl_serializer { - static FlexFlow::BroadcastAttrs from_json(json const &); - static void to_json(json &, FlexFlow::BroadcastAttrs const &); -}; -} // namespace nlohmann - -namespace FlexFlow { -std::string format_as(BroadcastAttrs const &); -std::ostream &operator<<(std::ostream &, BroadcastAttrs const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BROADCAST_H +#endif diff --git a/lib/op-attrs/include/op-attrs/ops/cast.h b/lib/op-attrs/include/op-attrs/ops/cast.h index e86f5a1c82..117dcb1e01 100644 --- a/lib/op-attrs/include/op-attrs/ops/cast.h +++ b/lib/op-attrs/include/op-attrs/ops/cast.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_CAST_ATTRS_H #include "core.h" -#include "op-attrs/ops/cast_attrs.h" +#include "op-attrs/ops/cast_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/ops/cast_attrs.h b/lib/op-attrs/include/op-attrs/ops/cast_attrs.dtg.h similarity index 85% rename from lib/op-attrs/include/op-attrs/ops/cast_attrs.h rename to lib/op-attrs/include/op-attrs/ops/cast_attrs.dtg.h index ca70701261..5956b5b14f 100644 --- a/lib/op-attrs/include/op-attrs/ops/cast_attrs.h +++ b/lib/op-attrs/include/op-attrs/ops/cast_attrs.dtg.h @@ -1,16 +1,20 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/cast_attrs.struct.toml +/* proj-data +{ + "generated_from": "62da4845a8aa0ae4ca3bce432a3aa9a3" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CAST_ATTRS_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CAST_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CAST_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CAST_ATTRS_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" #include "op-attrs/datatype.h" #include #include -#include #include namespace FlexFlow { @@ -48,4 +52,4 @@ std::string format_as(CastAttrs const &); std::ostream &operator<<(std::ostream &, CastAttrs const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CAST_ATTRS_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CAST_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/combine.h b/lib/op-attrs/include/op-attrs/ops/combine.h index 3c6b951462..d2d86e2fea 100644 --- a/lib/op-attrs/include/op-attrs/ops/combine.h +++ b/lib/op-attrs/include/op-attrs/ops/combine.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_COMBINE_ATTRS_H #include "core.h" -#include "op-attrs/ops/combine_attrs.h" +#include "op-attrs/ops/combine_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/ops/combine_attrs.h b/lib/op-attrs/include/op-attrs/ops/combine_attrs.dtg.h similarity index 92% rename from lib/op-attrs/include/op-attrs/ops/combine_attrs.h rename to lib/op-attrs/include/op-attrs/ops/combine_attrs.dtg.h index 4663fd495d..e3c8b9ea2a 100644 --- a/lib/op-attrs/include/op-attrs/ops/combine_attrs.h +++ b/lib/op-attrs/include/op-attrs/ops/combine_attrs.dtg.h @@ -1,16 +1,20 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml +/* proj-data +{ + "generated_from": "7caa0f9668b1894f5e446556f1a424c8" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_COMBINE_ATTRS_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_COMBINE_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_COMBINE_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_COMBINE_ATTRS_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" -#include "op-attrs/ff_dim.h" +#include "op-attrs/ff_dim.dtg.h" #include #include -#include #include namespace FlexFlow { @@ -50,4 +54,4 @@ std::string format_as(CombineAttrs const &); std::ostream &operator<<(std::ostream &, CombineAttrs const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_COMBINE_ATTRS_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_COMBINE_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml index dea30897bb..6791d3a110 100644 --- a/lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml @@ -10,7 +10,7 @@ features = [ ] includes = [ - "op-attrs/ff_dim.h", + "op-attrs/ff_dim.dtg.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/ops/concat.h b/lib/op-attrs/include/op-attrs/ops/concat.h index e01164eb5b..8a72708971 100644 --- a/lib/op-attrs/include/op-attrs/ops/concat.h +++ b/lib/op-attrs/include/op-attrs/ops/concat.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_CONCAT_ATTRS_H #include "core.h" -#include "op-attrs/ops/concat_attrs.h" +#include "op-attrs/ops/concat_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/ops/concat_attrs.h b/lib/op-attrs/include/op-attrs/ops/concat_attrs.dtg.h similarity index 91% rename from lib/op-attrs/include/op-attrs/ops/concat_attrs.h rename to lib/op-attrs/include/op-attrs/ops/concat_attrs.dtg.h index 93cafadcb2..3d0b50c688 100644 --- a/lib/op-attrs/include/op-attrs/ops/concat_attrs.h +++ b/lib/op-attrs/include/op-attrs/ops/concat_attrs.dtg.h @@ -1,16 +1,20 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml +/* proj-data +{ + "generated_from": "b72ef29f9f79a917176c63a5c3683ab5" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONCAT_ATTRS_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONCAT_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONCAT_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONCAT_ATTRS_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" -#include "op-attrs/ff_dim.h" +#include "op-attrs/ff_dim.dtg.h" #include #include -#include #include namespace FlexFlow { @@ -49,4 +53,4 @@ std::string format_as(ConcatAttrs const &); std::ostream &operator<<(std::ostream &, ConcatAttrs const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONCAT_ATTRS_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONCAT_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml index 032657b60c..b75839bd9c 100644 --- a/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml @@ -10,7 +10,7 @@ features = [ ] includes = [ - "op-attrs/ff_dim.h" + "op-attrs/ff_dim.dtg.h" ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d.h b/lib/op-attrs/include/op-attrs/ops/conv_2d.h index 47b7149004..b75628cf8a 100644 --- a/lib/op-attrs/include/op-attrs/ops/conv_2d.h +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_CONV_2D_ATTRS_H #include "core.h" -#include "op-attrs/ops/conv_2d_attrs.h" +#include "op-attrs/ops/conv_2d_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/tensor_shape.h" @@ -10,8 +10,8 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(Conv2DAttrs); -TensorShape get_kernel_shape(Conv2DAttrs const &, TensorShape const &); -TensorShape get_bias_shape(Conv2DAttrs const &, TensorShape const &); +TensorShape get_kernel_shape(Conv2DAttrs const &attrs, TensorShape const &input); +TensorShape get_bias_shape(Conv2DAttrs const &attrs, TensorShape const &input); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.h b/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.dtg.h similarity index 93% rename from lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.h rename to lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.dtg.h index d437e4a95a..7eb9bd677c 100644 --- a/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.h +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.dtg.h @@ -1,18 +1,22 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml +/* proj-data +{ + "generated_from": "85f65c1b0e0340ea8e8622c2bf9ca38d" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_ATTRS_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_ATTRS_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" -#include "op-attrs/activation.h" +#include "op-attrs/activation.dtg.h" #include "utils/json.h" #include #include #include -#include #include namespace FlexFlow { @@ -68,4 +72,4 @@ std::string format_as(Conv2DAttrs const &); std::ostream &operator<<(std::ostream &, Conv2DAttrs const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_ATTRS_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml index 1d0b59ce87..b27c2e1899 100644 --- a/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml @@ -11,7 +11,7 @@ features = [ includes = [ "", - "op-attrs/activation.h", + "op-attrs/activation.dtg.h", "utils/json.h", ] diff --git a/lib/op-attrs/include/op-attrs/ops/dropout.h b/lib/op-attrs/include/op-attrs/ops/dropout.h index 0c0a1b746d..54e6fbf279 100644 --- a/lib/op-attrs/include/op-attrs/ops/dropout.h +++ b/lib/op-attrs/include/op-attrs/ops/dropout.h @@ -2,11 +2,13 @@ #define _FLEXFLOW_DROPOUT_ATTRS_H #include "core.h" -#include "op-attrs/ops/dropout_attrs.h" -#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/ops/dropout_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { +ParallelTensorShape get_output_shape(DropoutAttrs const &, ParallelTensorShape const &); + CHECK_VALID_OP_ATTR(DropoutAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/dropout_attrs.h b/lib/op-attrs/include/op-attrs/ops/dropout_attrs.dtg.h similarity index 94% rename from lib/op-attrs/include/op-attrs/ops/dropout_attrs.h rename to lib/op-attrs/include/op-attrs/ops/dropout_attrs.dtg.h index 6d17870138..ef86e49560 100644 --- a/lib/op-attrs/include/op-attrs/ops/dropout_attrs.h +++ b/lib/op-attrs/include/op-attrs/ops/dropout_attrs.dtg.h @@ -1,16 +1,20 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/dropout_attrs.struct.toml +/* proj-data +{ + "generated_from": "4fdbf129ea59b8a7306813cfa4c46021" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_DROPOUT_ATTRS_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_DROPOUT_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_DROPOUT_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_DROPOUT_ATTRS_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" #include "rapidcheck.h" #include #include -#include #include namespace FlexFlow { @@ -56,4 +60,4 @@ std::string format_as(DropoutAttrs const &); std::ostream &operator<<(std::ostream &, DropoutAttrs const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_DROPOUT_ATTRS_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_DROPOUT_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/element_binary.h b/lib/op-attrs/include/op-attrs/ops/element_binary.h index b6ed0e6210..18c4a1eea5 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_binary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_binary.h @@ -2,11 +2,14 @@ #define _FLEXFLOW_ELEMENT_BINARY_ATTRS_H #include "core.h" -#include "op-attrs/ops/element_binary_attrs.h" +#include "op-attrs/ops/element_binary_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.h" namespace FlexFlow { +ParallelTensorShape get_output_shape(ElementBinaryAttrs const &, ParallelTensorShape const &, ParallelTensorShape const &); +TensorShape get_output_shape(ElementBinaryAttrs const &, TensorShape const &, TensorShape const &); + CHECK_VALID_OP_ATTR(ElementBinaryAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.h b/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.dtg.h similarity index 94% rename from lib/op-attrs/include/op-attrs/ops/element_binary_attrs.h rename to lib/op-attrs/include/op-attrs/ops/element_binary_attrs.dtg.h index a3532a41cf..66a0b66304 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.h +++ b/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.dtg.h @@ -1,9 +1,14 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/element_binary_attrs.struct.toml +/* proj-data +{ + "generated_from": "1aae4139632791a4b7638e59fa6b5dc8" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_BINARY_ATTRS_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_BINARY_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_BINARY_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_BINARY_ATTRS_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" @@ -11,7 +16,6 @@ #include "op-attrs/operator_type.h" #include #include -#include #include namespace FlexFlow { @@ -55,4 +59,4 @@ std::string format_as(ElementBinaryAttrs const &); std::ostream &operator<<(std::ostream &, ElementBinaryAttrs const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_BINARY_ATTRS_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_BINARY_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.h b/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.dtg.h similarity index 92% rename from lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.h rename to lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.dtg.h index 7eb369111a..61041b3993 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.h +++ b/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.dtg.h @@ -1,16 +1,20 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.struct.toml +/* proj-data +{ + "generated_from": "09554c353caed6075e362da5008c4bd2" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_SCALAR_UNARY_ATTRS_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_SCALAR_UNARY_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_SCALAR_UNARY_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_SCALAR_UNARY_ATTRS_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" #include "op-attrs/operator_type.h" #include #include -#include #include namespace FlexFlow { @@ -50,4 +54,4 @@ std::string format_as(ElementScalarUnaryAttrs const &); std::ostream &operator<<(std::ostream &, ElementScalarUnaryAttrs const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_SCALAR_UNARY_ATTRS_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_SCALAR_UNARY_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary.h b/lib/op-attrs/include/op-attrs/ops/element_unary.h index 888185042a..808c453d2c 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_unary.h @@ -2,12 +2,19 @@ #define _FLEXFLOW_ELEMENTARY_UNARY_ATTRS_H #include "core.h" -#include "op-attrs/ops/element_scalar_unary_attrs.h" -#include "op-attrs/ops/element_unary_attrs.h" -#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/ops/element_scalar_unary_attrs.dtg.h" +#include "op-attrs/ops/element_unary_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { +ParallelTensorShape get_output_shape(ElementUnaryAttrs const &, ParallelTensorShape const &); +TensorShape get_output_shape(ElementUnaryAttrs const &, TensorShape const &); + +ParallelTensorShape get_output_shape(ElementScalarUnaryAttrs const &, ParallelTensorShape const &); +TensorShape get_output_shape(ElementScalarUnaryAttrs const &, TensorShape const &); + CHECK_VALID_OP_ATTR(ElementUnaryAttrs); CHECK_VALID_OP_ATTR(ElementScalarUnaryAttrs); diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.h b/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.dtg.h similarity index 93% rename from lib/op-attrs/include/op-attrs/ops/element_unary_attrs.h rename to lib/op-attrs/include/op-attrs/ops/element_unary_attrs.dtg.h index dfa2aa30e8..bdf63fda8d 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.h +++ b/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.dtg.h @@ -1,16 +1,20 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml +/* proj-data +{ + "generated_from": "fdb867c04cdd7de320f573f360bcab90" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_UNARY_ATTRS_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_UNARY_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_UNARY_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_UNARY_ATTRS_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" #include "op-attrs/operator_type.h" #include #include -#include #include namespace FlexFlow { @@ -48,4 +52,4 @@ std::string format_as(ElementUnaryAttrs const &); std::ostream &operator<<(std::ostream &, ElementUnaryAttrs const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_UNARY_ATTRS_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_UNARY_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/embedding.h b/lib/op-attrs/include/op-attrs/ops/embedding.h index 948e189397..2a7d8cd7bf 100644 --- a/lib/op-attrs/include/op-attrs/ops/embedding.h +++ b/lib/op-attrs/include/op-attrs/ops/embedding.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_EMBEDDING_ATTRS_H #include "core.h" -#include "op-attrs/ops/embedding_attrs.h" +#include "op-attrs/ops/embedding_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/tensor_shape.h" @@ -12,6 +12,8 @@ CHECK_VALID_OP_ATTR(EmbeddingAttrs); TensorShape get_weights_shape(EmbeddingAttrs const &, TensorShape const &); +ParallelTensorShape get_output_shape(EmbeddingAttrs const &, ParallelTensorShape const &); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/embedding_attrs.h b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.dtg.h similarity index 90% rename from lib/op-attrs/include/op-attrs/ops/embedding_attrs.h rename to lib/op-attrs/include/op-attrs/ops/embedding_attrs.dtg.h index 3c72b4c12f..23df0b7cd2 100644 --- a/lib/op-attrs/include/op-attrs/ops/embedding_attrs.h +++ b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.dtg.h @@ -1,18 +1,22 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml +/* proj-data +{ + "generated_from": "65af6a38dfabebbc05c8ad3f75397b07" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_EMBEDDING_ATTRS_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_EMBEDDING_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_EMBEDDING_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_EMBEDDING_ATTRS_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" -#include "op-attrs/aggregate_op.h" -#include "op-attrs/datatype.h" +#include "op-attrs/aggregate_op.dtg.h" +#include "op-attrs/datatype.dtg.h" #include "utils/stack_vector.h" #include #include -#include #include namespace FlexFlow { @@ -56,4 +60,4 @@ std::string format_as(EmbeddingAttrs const &); std::ostream &operator<<(std::ostream &, EmbeddingAttrs const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_EMBEDDING_ATTRS_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_EMBEDDING_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml index 7985ca694c..1bae4869bd 100644 --- a/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml @@ -11,8 +11,8 @@ features = [ includes = [ "utils/stack_vector.h", - "op-attrs/aggregate_op.h", - "op-attrs/datatype.h", + "op-attrs/aggregate_op.dtg.h", + "op-attrs/datatype.dtg.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/ops/flat.h b/lib/op-attrs/include/op-attrs/ops/flat.h index 6f51d17c98..d5d9069f51 100644 --- a/lib/op-attrs/include/op-attrs/ops/flat.h +++ b/lib/op-attrs/include/op-attrs/ops/flat.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_FLAT_ATTRS_H #include "core.h" -#include "op-attrs/ops/flat_attrs.h" +#include "op-attrs/ops/flat_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/ops/flat_attrs.h b/lib/op-attrs/include/op-attrs/ops/flat_attrs.dtg.h similarity index 86% rename from lib/op-attrs/include/op-attrs/ops/flat_attrs.h rename to lib/op-attrs/include/op-attrs/ops/flat_attrs.dtg.h index bc7d8f4a62..a94c0aeff3 100644 --- a/lib/op-attrs/include/op-attrs/ops/flat_attrs.h +++ b/lib/op-attrs/include/op-attrs/ops/flat_attrs.dtg.h @@ -1,16 +1,20 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml +/* proj-data +{ + "generated_from": "b63924cd671481df30fae314a199c606" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_FLAT_ATTRS_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_FLAT_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_FLAT_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_FLAT_ATTRS_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" #include "rapidcheck.h" #include #include -#include #include namespace FlexFlow { @@ -51,4 +55,4 @@ std::string format_as(FlatAttrs const &); std::ostream &operator<<(std::ostream &, FlatAttrs const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_FLAT_ATTRS_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_FLAT_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/gather.h b/lib/op-attrs/include/op-attrs/ops/gather.h index fd75292fe1..79516a8862 100644 --- a/lib/op-attrs/include/op-attrs/ops/gather.h +++ b/lib/op-attrs/include/op-attrs/ops/gather.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_GATHER_ATTRS_H #include "core.h" -#include "op-attrs/ops/gather_attrs.h" +#include "op-attrs/ops/gather_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/ops/gather_attrs.h b/lib/op-attrs/include/op-attrs/ops/gather_attrs.dtg.h similarity index 91% rename from lib/op-attrs/include/op-attrs/ops/gather_attrs.h rename to lib/op-attrs/include/op-attrs/ops/gather_attrs.dtg.h index a91ad5fb29..6c74d77031 100644 --- a/lib/op-attrs/include/op-attrs/ops/gather_attrs.h +++ b/lib/op-attrs/include/op-attrs/ops/gather_attrs.dtg.h @@ -1,16 +1,20 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml +/* proj-data +{ + "generated_from": "ee735644d3c5f53f790e0a1fa8b8beaf" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_GATHER_ATTRS_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_GATHER_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_GATHER_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_GATHER_ATTRS_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" -#include "op-attrs/ff_dim.h" +#include "op-attrs/ff_dim.dtg.h" #include #include -#include #include namespace FlexFlow { @@ -48,4 +52,4 @@ std::string format_as(GatherAttrs const &); std::ostream &operator<<(std::ostream &, GatherAttrs const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_GATHER_ATTRS_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_GATHER_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml index 141e41bc24..c66f1585fd 100644 --- a/lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml @@ -10,7 +10,7 @@ features = [ ] includes = [ - "op-attrs/ff_dim.h" + "op-attrs/ff_dim.dtg.h" ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/ops/input.h b/lib/op-attrs/include/op-attrs/ops/input.h index 9f7a8d2de1..9fe0ee2c2d 100644 --- a/lib/op-attrs/include/op-attrs/ops/input.h +++ b/lib/op-attrs/include/op-attrs/ops/input.h @@ -2,13 +2,15 @@ #define _FLEXFLOW_OP_ATTRS_OPS_OP_ATTRS_INPUT_H #include "core.h" -#include "op-attrs/ops/input_attrs.h" -#include "utils/visitable.h" +#include "op-attrs/ops/input_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { CHECK_VALID_OP_ATTR(InputAttrs); +ParallelTensorShape get_output_shape(InputAttrs const &); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/input_attrs.h b/lib/op-attrs/include/op-attrs/ops/input_attrs.dtg.h similarity index 94% rename from lib/op-attrs/include/op-attrs/ops/input_attrs.h rename to lib/op-attrs/include/op-attrs/ops/input_attrs.dtg.h index b700799318..aa2ca1e933 100644 --- a/lib/op-attrs/include/op-attrs/ops/input_attrs.h +++ b/lib/op-attrs/include/op-attrs/ops/input_attrs.dtg.h @@ -1,16 +1,20 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/input_attrs.struct.toml +/* proj-data +{ + "generated_from": "139ea46d57a3c8738b31b17a8c59a0aa" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_INPUT_ATTRS_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_INPUT_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_INPUT_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_INPUT_ATTRS_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" #include "rapidcheck.h" #include #include -#include #include namespace FlexFlow { @@ -51,4 +55,4 @@ std::string format_as(InputAttrs const &); std::ostream &operator<<(std::ostream &, InputAttrs const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_INPUT_ATTRS_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_INPUT_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/layer_norm.h b/lib/op-attrs/include/op-attrs/ops/layer_norm.h index d2e394f0a3..3186bbba11 100644 --- a/lib/op-attrs/include/op-attrs/ops/layer_norm.h +++ b/lib/op-attrs/include/op-attrs/ops/layer_norm.h @@ -2,11 +2,13 @@ #define _FLEXFLOW_OP_META_OPS_LAYER_NORM_ATTRS_H #include "core.h" -#include "op-attrs/ops/layer_norm_attrs.h" +#include "op-attrs/ops/layer_norm_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.h" namespace FlexFlow { +ParallelTensorShape get_output_shape(LayerNormAttrs const &, ParallelTensorShape const &); + CHECK_VALID_OP_ATTR(LayerNormAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.h b/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.dtg.h similarity index 92% rename from lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.h rename to lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.dtg.h index b5839df513..af8ace620a 100644 --- a/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.h +++ b/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.dtg.h @@ -1,17 +1,21 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml +/* proj-data +{ + "generated_from": "c03d823a6e889e1254b73a0730a71046" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LAYER_NORM_ATTRS_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LAYER_NORM_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LAYER_NORM_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LAYER_NORM_ATTRS_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" -#include "op-attrs/ff_dim.h" +#include "op-attrs/ff_dim.dtg.h" #include "utils/stack_vector.h" #include #include -#include #include namespace FlexFlow { @@ -54,4 +58,4 @@ std::string format_as(LayerNormAttrs const &); std::ostream &operator<<(std::ostream &, LayerNormAttrs const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LAYER_NORM_ATTRS_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LAYER_NORM_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml index 5be7f82256..a72b903ebe 100644 --- a/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml @@ -10,7 +10,7 @@ features = [ ] includes = [ - "op-attrs/ff_dim.h", + "op-attrs/ff_dim.dtg.h", "utils/stack_vector.h", ] diff --git a/lib/op-attrs/include/op-attrs/ops/linear.h b/lib/op-attrs/include/op-attrs/ops/linear.h index dc0054faad..2b0c5c7cda 100644 --- a/lib/op-attrs/include/op-attrs/ops/linear.h +++ b/lib/op-attrs/include/op-attrs/ops/linear.h @@ -2,13 +2,15 @@ #define _FLEXFLOW_LINEAR_ATTRS_H #include "op-attrs/ops/core.h" -#include "op-attrs/ops/linear_attrs.h" -#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/ops/linear_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { CHECK_VALID_OP_ATTR(LinearAttrs); +ParallelTensorShape get_output_shape(LinearAttrs const &, ParallelTensorShape const &); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/linear_attrs.h b/lib/op-attrs/include/op-attrs/ops/linear_attrs.dtg.h similarity index 89% rename from lib/op-attrs/include/op-attrs/ops/linear_attrs.h rename to lib/op-attrs/include/op-attrs/ops/linear_attrs.dtg.h index 1c1fae52c4..572520031e 100644 --- a/lib/op-attrs/include/op-attrs/ops/linear_attrs.h +++ b/lib/op-attrs/include/op-attrs/ops/linear_attrs.dtg.h @@ -1,19 +1,23 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml +/* proj-data +{ + "generated_from": "dae07c937f6c52d4dc89ec322520e29f" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LINEAR_ATTRS_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LINEAR_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LINEAR_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LINEAR_ATTRS_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" -#include "op-attrs/activation.h" -#include "op-attrs/datatype.h" -#include "op-attrs/regularizer_attrs.h" +#include "op-attrs/activation.dtg.h" +#include "op-attrs/datatype.dtg.h" +#include "op-attrs/regularizer_attrs.dtg.h" #include "utils/json.h" #include #include -#include #include namespace FlexFlow { @@ -59,4 +63,4 @@ std::string format_as(LinearAttrs const &); std::ostream &operator<<(std::ostream &, LinearAttrs const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LINEAR_ATTRS_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LINEAR_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml index 1168276890..8945d47c55 100644 --- a/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml @@ -10,9 +10,9 @@ features = [ ] includes = [ - "op-attrs/datatype.h", - "op-attrs/activation.h", - "op-attrs/regularizer_attrs.h", + "op-attrs/datatype.dtg.h", + "op-attrs/activation.dtg.h", + "op-attrs/regularizer_attrs.dtg.h", "utils/json.h", ] diff --git a/lib/op-attrs/include/op-attrs/ops/noop.h b/lib/op-attrs/include/op-attrs/ops/noop.h index f5d2090201..635fa3d490 100644 --- a/lib/op-attrs/include/op-attrs/ops/noop.h +++ b/lib/op-attrs/include/op-attrs/ops/noop.h @@ -2,12 +2,15 @@ #define _FLEXFLOW_OP_ATTRS_OPS_NOOP_H #include "core.h" -#include "op-attrs/ops/noop_attrs.h" +#include "op-attrs/ops/noop_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { CHECK_VALID_OP_ATTR(NoopAttrs); +ParallelTensorShape get_output_shape(NoopAttrs const &, ParallelTensorShape const &); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/noop_attrs.h b/lib/op-attrs/include/op-attrs/ops/noop_attrs.dtg.h similarity index 86% rename from lib/op-attrs/include/op-attrs/ops/noop_attrs.h rename to lib/op-attrs/include/op-attrs/ops/noop_attrs.dtg.h index 35473197d9..ed0d8c9348 100644 --- a/lib/op-attrs/include/op-attrs/ops/noop_attrs.h +++ b/lib/op-attrs/include/op-attrs/ops/noop_attrs.dtg.h @@ -1,16 +1,20 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/noop_attrs.struct.toml +/* proj-data +{ + "generated_from": "d440077aa598fdad0e5aa95288b63c40" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_NOOP_ATTRS_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_NOOP_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_NOOP_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_NOOP_ATTRS_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" #include "rapidcheck.h" #include #include -#include #include namespace FlexFlow { @@ -51,4 +55,4 @@ std::string format_as(NoopAttrs const &); std::ostream &operator<<(std::ostream &, NoopAttrs const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_NOOP_ATTRS_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_NOOP_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.h b/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.dtg.h similarity index 93% rename from lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.h rename to lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.dtg.h index 55874f2d92..92154711e5 100644 --- a/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.h +++ b/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.dtg.h @@ -1,16 +1,20 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.struct.toml +/* proj-data +{ + "generated_from": "722d92014b31bffcd5ad45eda476d8b3" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_PARALLEL_ATTENTION_INPUTS_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_PARALLEL_ATTENTION_INPUTS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_PARALLEL_ATTENTION_INPUTS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_PARALLEL_ATTENTION_INPUTS_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" #include "op-attrs/parallel_tensor_shape.h" #include #include -#include #include namespace FlexFlow { @@ -55,4 +59,4 @@ std::ostream &operator<<(std::ostream &, ParallelMultiHeadAttentionInputs const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_PARALLEL_ATTENTION_INPUTS_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_PARALLEL_ATTENTION_INPUTS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d.h b/lib/op-attrs/include/op-attrs/ops/pool_2d.h index 1e1624c405..9a9193fd63 100644 --- a/lib/op-attrs/include/op-attrs/ops/pool_2d.h +++ b/lib/op-attrs/include/op-attrs/ops/pool_2d.h @@ -2,13 +2,15 @@ #define _FLEXFLOW_POOL_2D_ATTRS_H #include "core.h" -#include "op-attrs/ops/pool_2d_attrs.h" -#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/ops/pool_2d_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { CHECK_VALID_OP_ATTR(Pool2DAttrs); +ParallelTensorShape get_output_shape(Pool2DAttrs const &, ParallelTensorShape const &); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.h b/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.dtg.h similarity index 91% rename from lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.h rename to lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.dtg.h index f8bbfdf320..c976ca0720 100644 --- a/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.h +++ b/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.dtg.h @@ -1,17 +1,21 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml +/* proj-data +{ + "generated_from": "607be08f56d910bfa340fb180646c126" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_POOL_2D_ATTRS_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_POOL_2D_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_POOL_2D_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_POOL_2D_ATTRS_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" -#include "op-attrs/activation.h" -#include "op-attrs/pool_op.h" +#include "op-attrs/activation.dtg.h" +#include "op-attrs/pool_op.dtg.h" #include #include -#include #include namespace FlexFlow { @@ -63,4 +67,4 @@ std::string format_as(Pool2DAttrs const &); std::ostream &operator<<(std::ostream &, Pool2DAttrs const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_POOL_2D_ATTRS_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_POOL_2D_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml index 24e9a814de..58854d457c 100644 --- a/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml @@ -10,8 +10,8 @@ features = [ ] includes = [ - "op-attrs/pool_op.h", - "op-attrs/activation.h", + "op-attrs/pool_op.dtg.h", + "op-attrs/activation.dtg.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/ops/reduce.h b/lib/op-attrs/include/op-attrs/ops/reduce.h index 10c15b023d..ce5ae7d3fd 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduce.h +++ b/lib/op-attrs/include/op-attrs/ops/reduce.h @@ -2,13 +2,15 @@ #define _FLEXFLOW_OP_META_OPS_REDUCE_ATTRS_H #include "core.h" -#include "op-attrs/ops/reduce_attrs.h" -#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/ops/reduce_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { CHECK_VALID_OP_ATTR(ReduceAttrs); +ParallelTensorShape get_output_shape(ReduceAttrs const &, ParallelTensorShape const &input_shape); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/reduce_attrs.h b/lib/op-attrs/include/op-attrs/ops/reduce_attrs.dtg.h similarity index 91% rename from lib/op-attrs/include/op-attrs/ops/reduce_attrs.h rename to lib/op-attrs/include/op-attrs/ops/reduce_attrs.dtg.h index 0bc0fc759a..f6f78911e3 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduce_attrs.h +++ b/lib/op-attrs/include/op-attrs/ops/reduce_attrs.dtg.h @@ -1,18 +1,22 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml +/* proj-data +{ + "generated_from": "bc6279031650335f4a0b7b6cfe116c85" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REDUCE_ATTRS_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REDUCE_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REDUCE_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REDUCE_ATTRS_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" -#include "op-attrs/ff_dim.h" -#include "op-attrs/operator_type.h" +#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/operator_type.dtg.h" #include "utils/stack_vector.h" #include #include -#include #include namespace FlexFlow { @@ -55,4 +59,4 @@ std::string format_as(ReduceAttrs const &); std::ostream &operator<<(std::ostream &, ReduceAttrs const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REDUCE_ATTRS_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REDUCE_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml index 9f4bc6d5aa..e8a1785d19 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml @@ -10,8 +10,8 @@ features = [ ] includes = [ - "op-attrs/operator_type.h", - "op-attrs/ff_dim.h", + "op-attrs/operator_type.dtg.h", + "op-attrs/ff_dim.dtg.h", "utils/stack_vector.h", ] diff --git a/lib/op-attrs/include/op-attrs/ops/reduction.h b/lib/op-attrs/include/op-attrs/ops/reduction.h index 9c421486e8..a4ce679330 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduction.h +++ b/lib/op-attrs/include/op-attrs/ops/reduction.h @@ -2,13 +2,15 @@ #define _FLEXFLOW_REDUCTION_ATTRS_H #include "core.h" -#include "op-attrs/ops/reduction_attrs.h" -#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/ops/reduction_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { CHECK_VALID_OP_ATTR(ReductionAttrs); +ParallelTensorShape get_output_shape(ReductionAttrs const &attrs, ParallelTensorShape const &input_shape); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/reduction_attrs.h b/lib/op-attrs/include/op-attrs/ops/reduction_attrs.dtg.h similarity index 91% rename from lib/op-attrs/include/op-attrs/ops/reduction_attrs.h rename to lib/op-attrs/include/op-attrs/ops/reduction_attrs.dtg.h index 903cb1d004..942e1870e8 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduction_attrs.h +++ b/lib/op-attrs/include/op-attrs/ops/reduction_attrs.dtg.h @@ -1,16 +1,20 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/reduction_attrs.struct.toml +/* proj-data +{ + "generated_from": "57b8ccb5bc2e1a1a3bcf1bce2d8cad9e" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REDUCTION_ATTRS_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REDUCTION_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REDUCTION_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REDUCTION_ATTRS_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" -#include "op-attrs/ff_dim.h" +#include "op-attrs/ff_dim.dtg.h" #include #include -#include #include namespace FlexFlow { @@ -50,4 +54,4 @@ std::string format_as(ReductionAttrs const &); std::ostream &operator<<(std::ostream &, ReductionAttrs const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REDUCTION_ATTRS_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REDUCTION_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/reduction_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/reduction_attrs.struct.toml index f0a5f2e08d..5baafdfa42 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduction_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/reduction_attrs.struct.toml @@ -10,7 +10,7 @@ features = [ ] includes = [ - "op-attrs/ff_dim.h", + "op-attrs/ff_dim.dtg.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/ops/repartition.h b/lib/op-attrs/include/op-attrs/ops/repartition.h index 21a25ccec6..5dff92e966 100644 --- a/lib/op-attrs/include/op-attrs/ops/repartition.h +++ b/lib/op-attrs/include/op-attrs/ops/repartition.h @@ -2,13 +2,15 @@ #define _FLEXFLOW_PARTITION_ATTRS_H #include "core.h" -#include "op-attrs/ops/repartition_attrs.h" -#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/ops/repartition_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { CHECK_VALID_OP_ATTR(RepartitionAttrs); +ParallelTensorShape get_output_shape(RepartitionAttrs const &, ParallelTensorShape const &input_shape); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/repartition_attrs.h b/lib/op-attrs/include/op-attrs/ops/repartition_attrs.dtg.h similarity index 91% rename from lib/op-attrs/include/op-attrs/ops/repartition_attrs.h rename to lib/op-attrs/include/op-attrs/ops/repartition_attrs.dtg.h index 9ffb03122e..fa888700d0 100644 --- a/lib/op-attrs/include/op-attrs/ops/repartition_attrs.h +++ b/lib/op-attrs/include/op-attrs/ops/repartition_attrs.dtg.h @@ -1,16 +1,20 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml +/* proj-data +{ + "generated_from": "366cb1a14093762f75508260ac6494ca" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REPARTITION_ATTRS_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REPARTITION_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REPARTITION_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REPARTITION_ATTRS_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" -#include "op-attrs/ff_dim.h" +#include "op-attrs/ff_dim.dtg.h" #include #include -#include #include namespace FlexFlow { @@ -50,4 +54,4 @@ std::string format_as(RepartitionAttrs const &); std::ostream &operator<<(std::ostream &, RepartitionAttrs const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REPARTITION_ATTRS_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REPARTITION_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml index 4fca9e8fb4..344691a781 100644 --- a/lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml @@ -10,7 +10,7 @@ features = [ ] includes = [ - "op-attrs/ff_dim.h", + "op-attrs/ff_dim.dtg.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/ops/replicate.h b/lib/op-attrs/include/op-attrs/ops/replicate.h index 50a6be6d76..c6430ddbc5 100644 --- a/lib/op-attrs/include/op-attrs/ops/replicate.h +++ b/lib/op-attrs/include/op-attrs/ops/replicate.h @@ -2,13 +2,15 @@ #define _FLEXFLOW_REPLICATE_ATTRS_H #include "core.h" -#include "op-attrs/ops/replicate_attrs.h" -#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/ops/replicate_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { CHECK_VALID_OP_ATTR(ReplicateAttrs); +ParallelTensorShape get_output_shape(ReplicateAttrs const &attrs, ParallelTensorShape const &input_shape); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/replicate_attrs.h b/lib/op-attrs/include/op-attrs/ops/replicate_attrs.dtg.h similarity index 91% rename from lib/op-attrs/include/op-attrs/ops/replicate_attrs.h rename to lib/op-attrs/include/op-attrs/ops/replicate_attrs.dtg.h index ba03f05889..4249a2c0e7 100644 --- a/lib/op-attrs/include/op-attrs/ops/replicate_attrs.h +++ b/lib/op-attrs/include/op-attrs/ops/replicate_attrs.dtg.h @@ -1,16 +1,20 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/replicate_attrs.struct.toml +/* proj-data +{ + "generated_from": "4224406d468444433d69e4abf61b7cd1" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REPLICATE_ATTRS_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REPLICATE_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REPLICATE_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REPLICATE_ATTRS_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" -#include "op-attrs/ff_dim.h" +#include "op-attrs/ff_dim.dtg.h" #include #include -#include #include namespace FlexFlow { @@ -50,4 +54,4 @@ std::string format_as(ReplicateAttrs const &); std::ostream &operator<<(std::ostream &, ReplicateAttrs const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REPLICATE_ATTRS_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REPLICATE_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/replicate_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/replicate_attrs.struct.toml index d892e2033e..d5f9c22f28 100644 --- a/lib/op-attrs/include/op-attrs/ops/replicate_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/replicate_attrs.struct.toml @@ -10,7 +10,7 @@ features = [ ] includes = [ - "op-attrs/ff_dim.h", + "op-attrs/ff_dim.dtg.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/ops/reshape.h b/lib/op-attrs/include/op-attrs/ops/reshape.h index dec4a6fde7..2cd0287d45 100644 --- a/lib/op-attrs/include/op-attrs/ops/reshape.h +++ b/lib/op-attrs/include/op-attrs/ops/reshape.h @@ -2,12 +2,15 @@ #define _FLEXFLOW_RESHAPE_ATTRS_H #include "core.h" -#include "op-attrs/ops/reshape_attrs.h" +#include "op-attrs/ops/reshape_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { CHECK_VALID_OP_ATTR(ReshapeAttrs); +ParallelTensorShape get_output_shape(ReshapeAttrs const &attrs, ParallelTensorShape const &input_shape); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/reshape_attrs.h b/lib/op-attrs/include/op-attrs/ops/reshape_attrs.dtg.h similarity index 91% rename from lib/op-attrs/include/op-attrs/ops/reshape_attrs.h rename to lib/op-attrs/include/op-attrs/ops/reshape_attrs.dtg.h index 3c326d51e9..860b61f2e8 100644 --- a/lib/op-attrs/include/op-attrs/ops/reshape_attrs.h +++ b/lib/op-attrs/include/op-attrs/ops/reshape_attrs.dtg.h @@ -1,16 +1,20 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/reshape_attrs.struct.toml +/* proj-data +{ + "generated_from": "5a6a9e646a457a6cf959c542fb631512" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_RESHAPE_ATTRS_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_RESHAPE_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_RESHAPE_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_RESHAPE_ATTRS_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" -#include "op-attrs/tensor_shape.h" +#include "op-attrs/tensor_shape.dtg.h" #include #include -#include #include namespace FlexFlow { @@ -48,4 +52,4 @@ std::string format_as(ReshapeAttrs const &); std::ostream &operator<<(std::ostream &, ReshapeAttrs const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_RESHAPE_ATTRS_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_RESHAPE_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/reshape_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/reshape_attrs.struct.toml index 9086cbccae..dc0a96313d 100644 --- a/lib/op-attrs/include/op-attrs/ops/reshape_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/reshape_attrs.struct.toml @@ -10,7 +10,7 @@ features = [ ] includes = [ - "op-attrs/tensor_shape.h", + "op-attrs/tensor_shape.dtg.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/ops/reverse.h b/lib/op-attrs/include/op-attrs/ops/reverse.h index 2f7243ff58..45b05e62ab 100644 --- a/lib/op-attrs/include/op-attrs/ops/reverse.h +++ b/lib/op-attrs/include/op-attrs/ops/reverse.h @@ -2,12 +2,15 @@ #define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REVERSE_H #include "core.h" -#include "op-attrs/ops/reverse_attrs.h" +#include "op-attrs/ops/reverse_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { CHECK_VALID_OP_ATTR(ReverseAttrs); +ParallelTensorShape get_output_shape(ReverseAttrs const &attrs, ParallelTensorShape const &input_shape); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/reverse_attrs.h b/lib/op-attrs/include/op-attrs/ops/reverse_attrs.dtg.h similarity index 91% rename from lib/op-attrs/include/op-attrs/ops/reverse_attrs.h rename to lib/op-attrs/include/op-attrs/ops/reverse_attrs.dtg.h index d43363063b..3ed917d33e 100644 --- a/lib/op-attrs/include/op-attrs/ops/reverse_attrs.h +++ b/lib/op-attrs/include/op-attrs/ops/reverse_attrs.dtg.h @@ -1,16 +1,20 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml +/* proj-data +{ + "generated_from": "7c21c4192854f5981018abf4fbdd9ead" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REVERSE_ATTRS_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REVERSE_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REVERSE_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REVERSE_ATTRS_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" -#include "op-attrs/ff_dim.h" +#include "op-attrs/ff_dim.dtg.h" #include #include -#include #include namespace FlexFlow { @@ -48,4 +52,4 @@ std::string format_as(ReverseAttrs const &); std::ostream &operator<<(std::ostream &, ReverseAttrs const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REVERSE_ATTRS_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REVERSE_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml index 572b33957e..e2058cf3e5 100644 --- a/lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml @@ -10,7 +10,7 @@ features = [ ] includes = [ - "op-attrs/ff_dim.h", + "op-attrs/ff_dim.dtg.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/ops/softmax.h b/lib/op-attrs/include/op-attrs/ops/softmax.h index 7b21e3ea38..7ae5eb7438 100644 --- a/lib/op-attrs/include/op-attrs/ops/softmax.h +++ b/lib/op-attrs/include/op-attrs/ops/softmax.h @@ -2,13 +2,15 @@ #define _FLEXFLOW_SOFTMAX_ATTRS_H #include "core.h" -#include "op-attrs/ops/softmax_attrs.h" -#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/ops/softmax_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { CHECK_VALID_OP_ATTR(SoftmaxAttrs); +ParallelTensorShape get_output_shape(SoftmaxAttrs const &attrs, ParallelTensorShape const &input_shape); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/softmax_attrs.h b/lib/op-attrs/include/op-attrs/ops/softmax_attrs.dtg.h similarity index 91% rename from lib/op-attrs/include/op-attrs/ops/softmax_attrs.h rename to lib/op-attrs/include/op-attrs/ops/softmax_attrs.dtg.h index 3e467d6cc1..a2acbf7300 100644 --- a/lib/op-attrs/include/op-attrs/ops/softmax_attrs.h +++ b/lib/op-attrs/include/op-attrs/ops/softmax_attrs.dtg.h @@ -1,16 +1,20 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml +/* proj-data +{ + "generated_from": "9be043678a4ce7666fc372cded600290" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_SOFTMAX_ATTRS_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_SOFTMAX_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_SOFTMAX_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_SOFTMAX_ATTRS_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" -#include "op-attrs/ff_dim.h" +#include "op-attrs/ff_dim.dtg.h" #include #include -#include #include namespace FlexFlow { @@ -48,4 +52,4 @@ std::string format_as(SoftmaxAttrs const &); std::ostream &operator<<(std::ostream &, SoftmaxAttrs const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_SOFTMAX_ATTRS_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_SOFTMAX_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml index 380fdc12d3..3e4fcbc75a 100644 --- a/lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml @@ -10,7 +10,7 @@ features = [ ] includes = [ - "op-attrs/ff_dim.h", + "op-attrs/ff_dim.dtg.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/ops/split.h b/lib/op-attrs/include/op-attrs/ops/split.h index 864e8c4c4a..08ce826945 100644 --- a/lib/op-attrs/include/op-attrs/ops/split.h +++ b/lib/op-attrs/include/op-attrs/ops/split.h @@ -2,13 +2,16 @@ #define _FLEXFLOW_SPLIT_ATTRS_H #include "core.h" -#include "op-attrs/ops/split_attrs.h" -#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/ops/split_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include namespace FlexFlow { CHECK_VALID_OP_ATTR(SplitAttrs); +std::vector get_output_shapes(SplitAttrs const &attrs, ParallelTensorShape const &input_shape); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/split_attrs.h b/lib/op-attrs/include/op-attrs/ops/split_attrs.dtg.h similarity index 92% rename from lib/op-attrs/include/op-attrs/ops/split_attrs.h rename to lib/op-attrs/include/op-attrs/ops/split_attrs.dtg.h index 6edefae88d..dee08ca1c8 100644 --- a/lib/op-attrs/include/op-attrs/ops/split_attrs.h +++ b/lib/op-attrs/include/op-attrs/ops/split_attrs.dtg.h @@ -1,17 +1,21 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml +/* proj-data +{ + "generated_from": "4112baa96de544b865618e0a999e0807" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_SPLIT_ATTRS_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_SPLIT_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_SPLIT_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_SPLIT_ATTRS_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" -#include "op-attrs/ff_dim.h" +#include "op-attrs/ff_dim.dtg.h" #include "utils/stack_vector.h" #include #include -#include #include namespace FlexFlow { @@ -51,4 +55,4 @@ std::string format_as(SplitAttrs const &); std::ostream &operator<<(std::ostream &, SplitAttrs const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_SPLIT_ATTRS_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_SPLIT_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml index ca73d4d2a1..8205cdbccb 100644 --- a/lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml @@ -11,7 +11,7 @@ features = [ includes = [ "utils/stack_vector.h", - "op-attrs/ff_dim.h", + "op-attrs/ff_dim.dtg.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/ops/topk.h b/lib/op-attrs/include/op-attrs/ops/topk.h index 57379518b6..4fab6584b4 100644 --- a/lib/op-attrs/include/op-attrs/ops/topk.h +++ b/lib/op-attrs/include/op-attrs/ops/topk.h @@ -2,13 +2,15 @@ #define _FLEXFLOW_TOPK_ATTRS_H #include "core.h" -#include "op-attrs/ops/topk_attrs.h" -#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/ops/topk_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { CHECK_VALID_OP_ATTR(TopKAttrs); +ParallelTensorShape get_output_shape(TopKAttrs const &attrs, ParallelTensorShape const &input_shape); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/topk_attrs.h b/lib/op-attrs/include/op-attrs/ops/topk_attrs.dtg.h similarity index 86% rename from lib/op-attrs/include/op-attrs/ops/topk_attrs.h rename to lib/op-attrs/include/op-attrs/ops/topk_attrs.dtg.h index 4d12a3fa41..d1f32f67b7 100644 --- a/lib/op-attrs/include/op-attrs/ops/topk_attrs.h +++ b/lib/op-attrs/include/op-attrs/ops/topk_attrs.dtg.h @@ -1,16 +1,20 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/topk_attrs.struct.toml +/* proj-data +{ + "generated_from": "c1be9dc2acafc58690713e650663cc93" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_TOPK_ATTRS_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_TOPK_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_TOPK_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_TOPK_ATTRS_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" #include "rapidcheck.h" #include #include -#include #include namespace FlexFlow { @@ -56,4 +60,4 @@ std::string format_as(TopKAttrs const &); std::ostream &operator<<(std::ostream &, TopKAttrs const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_TOPK_ATTRS_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_TOPK_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/transpose.h b/lib/op-attrs/include/op-attrs/ops/transpose.h index f352be8f37..4156885610 100644 --- a/lib/op-attrs/include/op-attrs/ops/transpose.h +++ b/lib/op-attrs/include/op-attrs/ops/transpose.h @@ -2,13 +2,15 @@ #define _FLEXFLOW_OP_META_OPS_TRANSPOSE_ATTRS_H #include "core.h" -#include "op-attrs/ops/transpose_attrs.h" -#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/ops/transpose_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { CHECK_VALID_OP_ATTR(TransposeAttrs); +ParallelTensorShape get_output_shape(TransposeAttrs const &op_attrs, ParallelTensorShape const &input_shape); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/transpose_attrs.h b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.dtg.h similarity index 92% rename from lib/op-attrs/include/op-attrs/ops/transpose_attrs.h rename to lib/op-attrs/include/op-attrs/ops/transpose_attrs.dtg.h index 0613561a3b..352aaf6e6a 100644 --- a/lib/op-attrs/include/op-attrs/ops/transpose_attrs.h +++ b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.dtg.h @@ -1,17 +1,21 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml +/* proj-data +{ + "generated_from": "edff0b414040204e895666d81b49db07" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_TRANSPOSE_ATTRS_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_TRANSPOSE_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_TRANSPOSE_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_TRANSPOSE_ATTRS_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" -#include "op-attrs/ff_dim.h" +#include "op-attrs/ff_dim.dtg.h" #include "utils/stack_vector.h" #include #include -#include #include namespace FlexFlow { @@ -50,4 +54,4 @@ std::string format_as(TransposeAttrs const &); std::ostream &operator<<(std::ostream &, TransposeAttrs const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_TRANSPOSE_ATTRS_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_TRANSPOSE_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml index 35646cda40..af13022262 100644 --- a/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml @@ -10,7 +10,7 @@ features = [ ] includes = [ - "op-attrs/ff_dim.h", + "op-attrs/ff_dim.dtg.h", "utils/stack_vector.h", ] diff --git a/lib/op-attrs/include/op-attrs/parallel_dim_t.h b/lib/op-attrs/include/op-attrs/parallel_dim.dtg.h similarity index 81% rename from lib/op-attrs/include/op-attrs/parallel_dim_t.h rename to lib/op-attrs/include/op-attrs/parallel_dim.dtg.h index 730a5c6f5e..3492694685 100644 --- a/lib/op-attrs/include/op-attrs/parallel_dim_t.h +++ b/lib/op-attrs/include/op-attrs/parallel_dim.dtg.h @@ -1,16 +1,20 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify -// lib/op-attrs/include/op-attrs/parallel_dim_t.struct.toml +// lib/op-attrs/include/op-attrs/parallel_dim.struct.toml +/* proj-data +{ + "generated_from": "186bedde7826c7a3d00343ed63ab9971" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_DIM_T_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_DIM_T_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_DIM_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_DIM_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" #include "rapidcheck.h" #include #include -#include #include namespace FlexFlow { @@ -59,4 +63,4 @@ std::string format_as(ParallelDim const &); std::ostream &operator<<(std::ostream &, ParallelDim const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_DIM_T_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_DIM_DTG_H diff --git a/lib/op-attrs/include/op-attrs/parallel_dim.h b/lib/op-attrs/include/op-attrs/parallel_dim.h index 64c40b9594..5397ad7c68 100644 --- a/lib/op-attrs/include/op-attrs/parallel_dim.h +++ b/lib/op-attrs/include/op-attrs/parallel_dim.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_DIM_H #define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_DIM_H -#include "op-attrs/parallel_dim_t.h" +#include "op-attrs/parallel_dim.dtg.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/parallel_dim_t.struct.toml b/lib/op-attrs/include/op-attrs/parallel_dim.struct.toml similarity index 100% rename from lib/op-attrs/include/op-attrs/parallel_dim_t.struct.toml rename to lib/op-attrs/include/op-attrs/parallel_dim.struct.toml diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims_t.h b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.dtg.h similarity index 90% rename from lib/op-attrs/include/op-attrs/parallel_tensor_dims_t.h rename to lib/op-attrs/include/op-attrs/parallel_tensor_dims.dtg.h index dd5d53d9cf..ae49a17657 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dims_t.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.dtg.h @@ -1,9 +1,14 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify -// lib/op-attrs/include/op-attrs/parallel_tensor_dims_t.struct.toml +// lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml +/* proj-data +{ + "generated_from": "b46ffa08758bdcc57a75183255248ca6" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIMS_T_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIMS_T_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIMS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIMS_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" @@ -11,7 +16,6 @@ #include "op-attrs/parallel_dim.h" #include #include -#include #include namespace FlexFlow { @@ -50,4 +54,4 @@ std::string format_as(ParallelTensorDims const &); std::ostream &operator<<(std::ostream &, ParallelTensorDims const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIMS_T_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIMS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h index 92519dc09a..2e7cb57b99 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIMS_H #include "op-attrs/parallel_dim.h" -#include "op-attrs/parallel_tensor_dims_t.h" -#include "op-attrs/tensor_dims_t.h" +#include "op-attrs/parallel_tensor_dims.dtg.h" +#include "op-attrs/tensor_dims.dtg.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims_t.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml similarity index 100% rename from lib/op-attrs/include/op-attrs/parallel_tensor_dims_t.struct.toml rename to lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape_t.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.dtg.h similarity index 90% rename from lib/op-attrs/include/op-attrs/parallel_tensor_shape_t.h rename to lib/op-attrs/include/op-attrs/parallel_tensor_shape.dtg.h index e1f3333b9b..dfad5b1007 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape_t.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.dtg.h @@ -1,9 +1,14 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify -// lib/op-attrs/include/op-attrs/parallel_tensor_shape_t.struct.toml +// lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml +/* proj-data +{ + "generated_from": "b2d36c9212916e66569af4e958c893f4" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SHAPE_T_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SHAPE_T_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SHAPE_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SHAPE_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" @@ -11,7 +16,6 @@ #include "op-attrs/parallel_tensor_dims.h" #include #include -#include #include namespace FlexFlow { @@ -51,4 +55,4 @@ std::string format_as(ParallelTensorShape const &); std::ostream &operator<<(std::ostream &, ParallelTensorShape const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SHAPE_T_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SHAPE_DTG_H diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h index 3000463365..8a60ce0b8d 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h @@ -3,7 +3,7 @@ #include "op-attrs/tensor_shape.h" #include -#include "op-attrs/parallel_tensor_shape_t.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape_t.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml similarity index 100% rename from lib/op-attrs/include/op-attrs/parallel_tensor_shape_t.struct.toml rename to lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml diff --git a/lib/op-attrs/include/op-attrs/param_sync.dtg.h b/lib/op-attrs/include/op-attrs/param_sync.dtg.h new file mode 100644 index 0000000000..785105fbc4 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/param_sync.dtg.h @@ -0,0 +1,40 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/param_sync.enum.toml +/* proj-data +{ + "generated_from": "288c6e9e256cf58ba5dbd0e3791c08df" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARAM_SYNC_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARAM_SYNC_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +enum class ParamSync { PS, NCCL }; +std::string format_as(ParamSync); +std::ostream &operator<<(std::ostream &, ParamSync); +void to_json(::nlohmann::json &, ParamSync); +void from_json(::nlohmann::json const &, ParamSync &); +} // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ParamSync) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARAM_SYNC_DTG_H diff --git a/lib/op-attrs/include/op-attrs/param_sync.enum.toml b/lib/op-attrs/include/op-attrs/param_sync.enum.toml new file mode 100644 index 0000000000..b16a47ab3c --- /dev/null +++ b/lib/op-attrs/include/op-attrs/param_sync.enum.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "ParamSync" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "PS" + +[[values]] +name = "NCCL" diff --git a/lib/op-attrs/include/op-attrs/param_sync.h b/lib/op-attrs/include/op-attrs/param_sync.h index bfae1e712b..55845a931b 100644 --- a/lib/op-attrs/include/op-attrs/param_sync.h +++ b/lib/op-attrs/include/op-attrs/param_sync.h @@ -1,36 +1,10 @@ #ifndef _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_PARAM_SYNC_H #define _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_PARAM_SYNC_H -#include "utils/fmt.h" +#include "param_sync_t.h" namespace FlexFlow { -enum class ParamSync { PS, NCCL }; - } -namespace fmt { - -template <> -struct formatter<::FlexFlow::ParamSync> : formatter { - template - auto format(::FlexFlow::ParamSync ps, FormatContext &ctx) const - -> decltype(ctx.out()) { - using namespace FlexFlow; - - string_view name = "unknown"; - switch (ps) { - case ParamSync::PS: - name = "ParameterServer"; - break; - case ParamSync::NCCL: - name = "NCCL"; - break; - } - return formatter::format(name, ctx); - } -}; - -} // namespace fmt - #endif diff --git a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.h b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.h new file mode 100644 index 0000000000..132a575175 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.h @@ -0,0 +1,438 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml +/* proj-data +{ + "generated_from": "e1b5c307ae023ce6d504f605c7ef8491" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PCG_OPERATOR_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PCG_OPERATOR_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/ops/attention_attrs.dtg.h" +#include "op-attrs/ops/batch_matmul.dtg.h" +#include "op-attrs/ops/batch_norm_attrs.dtg.h" +#include "op-attrs/ops/cast_attrs.dtg.h" +#include "op-attrs/ops/combine_attrs.dtg.h" +#include "op-attrs/ops/concat_attrs.dtg.h" +#include "op-attrs/ops/conv_2d_attrs.dtg.h" +#include "op-attrs/ops/dropout_attrs.dtg.h" +#include "op-attrs/ops/element_binary_attrs.dtg.h" +#include "op-attrs/ops/element_scalar_unary_attrs.dtg.h" +#include "op-attrs/ops/element_unary_attrs.dtg.h" +#include "op-attrs/ops/embedding_attrs.dtg.h" +#include "op-attrs/ops/flat_attrs.dtg.h" +#include "op-attrs/ops/gather_attrs.dtg.h" +#include "op-attrs/ops/input_attrs.dtg.h" +#include "op-attrs/ops/layer_norm_attrs.dtg.h" +#include "op-attrs/ops/linear_attrs.dtg.h" +#include "op-attrs/ops/noop_attrs.dtg.h" +#include "op-attrs/ops/pool_2d_attrs.dtg.h" +#include "op-attrs/ops/reduce_attrs.dtg.h" +#include "op-attrs/ops/reduction_attrs.dtg.h" +#include "op-attrs/ops/repartition_attrs.dtg.h" +#include "op-attrs/ops/replicate_attrs.dtg.h" +#include "op-attrs/ops/reshape_attrs.dtg.h" +#include "op-attrs/ops/reverse_attrs.dtg.h" +#include "op-attrs/ops/softmax_attrs.dtg.h" +#include "op-attrs/ops/split_attrs.dtg.h" +#include "op-attrs/ops/topk_attrs.dtg.h" +#include "op-attrs/ops/transpose_attrs.dtg.h" +#include +#include +#include +#include +#include + +namespace FlexFlow { +struct PCGOperatorAttrs { + PCGOperatorAttrs() = delete; + explicit PCGOperatorAttrs(::FlexFlow::BatchMatmulAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::BatchNormAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::CastAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::ConcatAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::Conv2DAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::DropoutAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::ElementBinaryAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::ElementUnaryAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::ElementScalarUnaryAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::EmbeddingAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::FlatAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::GatherAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::InputAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::LayerNormAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::LinearAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::MultiHeadAttentionAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::NoopAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::Pool2DAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::ReduceAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::ReverseAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::ReshapeAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::SplitAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::SoftmaxAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::TopKAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::TransposeAttrs const &); + template + static constexpr bool IsPartOfPCGOperatorAttrs_v = + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v; + template + ReturnType visit(Visitor &&v) const { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::BatchMatmulAttrs>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::BatchNormAttrs>()); + return result; + } + case 2: { + ReturnType result = v(this->get<::FlexFlow::CastAttrs>()); + return result; + } + case 3: { + ReturnType result = v(this->get<::FlexFlow::ConcatAttrs>()); + return result; + } + case 4: { + ReturnType result = v(this->get<::FlexFlow::Conv2DAttrs>()); + return result; + } + case 5: { + ReturnType result = v(this->get<::FlexFlow::DropoutAttrs>()); + return result; + } + case 6: { + ReturnType result = v(this->get<::FlexFlow::ElementBinaryAttrs>()); + return result; + } + case 7: { + ReturnType result = v(this->get<::FlexFlow::ElementUnaryAttrs>()); + return result; + } + case 8: { + ReturnType result = v(this->get<::FlexFlow::ElementScalarUnaryAttrs>()); + return result; + } + case 9: { + ReturnType result = v(this->get<::FlexFlow::EmbeddingAttrs>()); + return result; + } + case 10: { + ReturnType result = v(this->get<::FlexFlow::FlatAttrs>()); + return result; + } + case 11: { + ReturnType result = v(this->get<::FlexFlow::GatherAttrs>()); + return result; + } + case 12: { + ReturnType result = v(this->get<::FlexFlow::InputAttrs>()); + return result; + } + case 13: { + ReturnType result = v(this->get<::FlexFlow::LayerNormAttrs>()); + return result; + } + case 14: { + ReturnType result = v(this->get<::FlexFlow::LinearAttrs>()); + return result; + } + case 15: { + ReturnType result = v(this->get<::FlexFlow::MultiHeadAttentionAttrs>()); + return result; + } + case 16: { + ReturnType result = v(this->get<::FlexFlow::NoopAttrs>()); + return result; + } + case 17: { + ReturnType result = v(this->get<::FlexFlow::Pool2DAttrs>()); + return result; + } + case 18: { + ReturnType result = v(this->get<::FlexFlow::ReduceAttrs>()); + return result; + } + case 19: { + ReturnType result = v(this->get<::FlexFlow::ReverseAttrs>()); + return result; + } + case 20: { + ReturnType result = v(this->get<::FlexFlow::ReshapeAttrs>()); + return result; + } + case 21: { + ReturnType result = v(this->get<::FlexFlow::SplitAttrs>()); + return result; + } + case 22: { + ReturnType result = v(this->get<::FlexFlow::SoftmaxAttrs>()); + return result; + } + case 23: { + ReturnType result = v(this->get<::FlexFlow::TopKAttrs>()); + return result; + } + case 24: { + ReturnType result = v(this->get<::FlexFlow::TransposeAttrs>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type PCGOperatorAttrs", this->index())); + } + } + } + template + ReturnType visit(Visitor &&v) { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::BatchMatmulAttrs>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::BatchNormAttrs>()); + return result; + } + case 2: { + ReturnType result = v(this->get<::FlexFlow::CastAttrs>()); + return result; + } + case 3: { + ReturnType result = v(this->get<::FlexFlow::ConcatAttrs>()); + return result; + } + case 4: { + ReturnType result = v(this->get<::FlexFlow::Conv2DAttrs>()); + return result; + } + case 5: { + ReturnType result = v(this->get<::FlexFlow::DropoutAttrs>()); + return result; + } + case 6: { + ReturnType result = v(this->get<::FlexFlow::ElementBinaryAttrs>()); + return result; + } + case 7: { + ReturnType result = v(this->get<::FlexFlow::ElementUnaryAttrs>()); + return result; + } + case 8: { + ReturnType result = v(this->get<::FlexFlow::ElementScalarUnaryAttrs>()); + return result; + } + case 9: { + ReturnType result = v(this->get<::FlexFlow::EmbeddingAttrs>()); + return result; + } + case 10: { + ReturnType result = v(this->get<::FlexFlow::FlatAttrs>()); + return result; + } + case 11: { + ReturnType result = v(this->get<::FlexFlow::GatherAttrs>()); + return result; + } + case 12: { + ReturnType result = v(this->get<::FlexFlow::InputAttrs>()); + return result; + } + case 13: { + ReturnType result = v(this->get<::FlexFlow::LayerNormAttrs>()); + return result; + } + case 14: { + ReturnType result = v(this->get<::FlexFlow::LinearAttrs>()); + return result; + } + case 15: { + ReturnType result = v(this->get<::FlexFlow::MultiHeadAttentionAttrs>()); + return result; + } + case 16: { + ReturnType result = v(this->get<::FlexFlow::NoopAttrs>()); + return result; + } + case 17: { + ReturnType result = v(this->get<::FlexFlow::Pool2DAttrs>()); + return result; + } + case 18: { + ReturnType result = v(this->get<::FlexFlow::ReduceAttrs>()); + return result; + } + case 19: { + ReturnType result = v(this->get<::FlexFlow::ReverseAttrs>()); + return result; + } + case 20: { + ReturnType result = v(this->get<::FlexFlow::ReshapeAttrs>()); + return result; + } + case 21: { + ReturnType result = v(this->get<::FlexFlow::SplitAttrs>()); + return result; + } + case 22: { + ReturnType result = v(this->get<::FlexFlow::SoftmaxAttrs>()); + return result; + } + case 23: { + ReturnType result = v(this->get<::FlexFlow::TopKAttrs>()); + return result; + } + case 24: { + ReturnType result = v(this->get<::FlexFlow::TransposeAttrs>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type PCGOperatorAttrs", this->index())); + } + } + } + template + bool has() const { + static_assert( + IsPartOfPCGOperatorAttrs_v, + "PCGOperatorAttrs::has() expected one of " + "[::FlexFlow::BatchMatmulAttrs, ::FlexFlow::BatchNormAttrs, " + "::FlexFlow::CastAttrs, ::FlexFlow::ConcatAttrs, " + "::FlexFlow::Conv2DAttrs, ::FlexFlow::DropoutAttrs, " + "::FlexFlow::ElementBinaryAttrs, ::FlexFlow::ElementUnaryAttrs, " + "::FlexFlow::ElementScalarUnaryAttrs, ::FlexFlow::EmbeddingAttrs, " + "::FlexFlow::FlatAttrs, ::FlexFlow::GatherAttrs, " + "::FlexFlow::InputAttrs, ::FlexFlow::LayerNormAttrs, " + "::FlexFlow::LinearAttrs, ::FlexFlow::MultiHeadAttentionAttrs, " + "::FlexFlow::NoopAttrs, ::FlexFlow::Pool2DAttrs, " + "::FlexFlow::ReduceAttrs, ::FlexFlow::ReverseAttrs, " + "::FlexFlow::ReshapeAttrs, ::FlexFlow::SplitAttrs, " + "::FlexFlow::SoftmaxAttrs, ::FlexFlow::TopKAttrs, " + "::FlexFlow::TransposeAttrs], received T"); + return std::holds_alternative(this->raw_variant); + } + template + T const &get() const { + static_assert( + IsPartOfPCGOperatorAttrs_v, + "PCGOperatorAttrs::get() expected one of " + "[::FlexFlow::BatchMatmulAttrs, ::FlexFlow::BatchNormAttrs, " + "::FlexFlow::CastAttrs, ::FlexFlow::ConcatAttrs, " + "::FlexFlow::Conv2DAttrs, ::FlexFlow::DropoutAttrs, " + "::FlexFlow::ElementBinaryAttrs, ::FlexFlow::ElementUnaryAttrs, " + "::FlexFlow::ElementScalarUnaryAttrs, ::FlexFlow::EmbeddingAttrs, " + "::FlexFlow::FlatAttrs, ::FlexFlow::GatherAttrs, " + "::FlexFlow::InputAttrs, ::FlexFlow::LayerNormAttrs, " + "::FlexFlow::LinearAttrs, ::FlexFlow::MultiHeadAttentionAttrs, " + "::FlexFlow::NoopAttrs, ::FlexFlow::Pool2DAttrs, " + "::FlexFlow::ReduceAttrs, ::FlexFlow::ReverseAttrs, " + "::FlexFlow::ReshapeAttrs, ::FlexFlow::SplitAttrs, " + "::FlexFlow::SoftmaxAttrs, ::FlexFlow::TopKAttrs, " + "::FlexFlow::TransposeAttrs], received T"); + return std::get(this->raw_variant); + } + template + T &get() { + static_assert( + IsPartOfPCGOperatorAttrs_v, + "PCGOperatorAttrs::get() expected one of " + "[::FlexFlow::BatchMatmulAttrs, ::FlexFlow::BatchNormAttrs, " + "::FlexFlow::CastAttrs, ::FlexFlow::ConcatAttrs, " + "::FlexFlow::Conv2DAttrs, ::FlexFlow::DropoutAttrs, " + "::FlexFlow::ElementBinaryAttrs, ::FlexFlow::ElementUnaryAttrs, " + "::FlexFlow::ElementScalarUnaryAttrs, ::FlexFlow::EmbeddingAttrs, " + "::FlexFlow::FlatAttrs, ::FlexFlow::GatherAttrs, " + "::FlexFlow::InputAttrs, ::FlexFlow::LayerNormAttrs, " + "::FlexFlow::LinearAttrs, ::FlexFlow::MultiHeadAttentionAttrs, " + "::FlexFlow::NoopAttrs, ::FlexFlow::Pool2DAttrs, " + "::FlexFlow::ReduceAttrs, ::FlexFlow::ReverseAttrs, " + "::FlexFlow::ReshapeAttrs, ::FlexFlow::SplitAttrs, " + "::FlexFlow::SoftmaxAttrs, ::FlexFlow::TopKAttrs, " + "::FlexFlow::TransposeAttrs], received T"); + return std::get(this->raw_variant); + } + size_t index() const { + return this->raw_variant.index(); + } + bool operator==(PCGOperatorAttrs const &) const; + bool operator!=(PCGOperatorAttrs const &) const; + bool operator<(PCGOperatorAttrs const &) const; + bool operator>(PCGOperatorAttrs const &) const; + bool operator<=(PCGOperatorAttrs const &) const; + bool operator>=(PCGOperatorAttrs const &) const; + std::variant<::FlexFlow::BatchMatmulAttrs, + ::FlexFlow::BatchNormAttrs, + ::FlexFlow::CastAttrs, + ::FlexFlow::ConcatAttrs, + ::FlexFlow::Conv2DAttrs, + ::FlexFlow::DropoutAttrs, + ::FlexFlow::ElementBinaryAttrs, + ::FlexFlow::ElementUnaryAttrs, + ::FlexFlow::ElementScalarUnaryAttrs, + ::FlexFlow::EmbeddingAttrs, + ::FlexFlow::FlatAttrs, + ::FlexFlow::GatherAttrs, + ::FlexFlow::InputAttrs, + ::FlexFlow::LayerNormAttrs, + ::FlexFlow::LinearAttrs, + ::FlexFlow::MultiHeadAttentionAttrs, + ::FlexFlow::NoopAttrs, + ::FlexFlow::Pool2DAttrs, + ::FlexFlow::ReduceAttrs, + ::FlexFlow::ReverseAttrs, + ::FlexFlow::ReshapeAttrs, + ::FlexFlow::SplitAttrs, + ::FlexFlow::SoftmaxAttrs, + ::FlexFlow::TopKAttrs, + ::FlexFlow::TransposeAttrs> + raw_variant; +}; +} // namespace FlexFlow +namespace std { +template <> +struct hash<::FlexFlow::PCGOperatorAttrs> { + size_t operator()(::FlexFlow::PCGOperatorAttrs const &) const; +}; +} // namespace std +namespace nlohmann { +template <> +struct adl_serializer<::FlexFlow::PCGOperatorAttrs> { + static ::FlexFlow::PCGOperatorAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::PCGOperatorAttrs const &); +}; +} // namespace nlohmann +namespace FlexFlow { +std::string format_as(::FlexFlow::PCGOperatorAttrs const &); +std::ostream &operator<<(std::ostream &, ::FlexFlow::PCGOperatorAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PCG_OPERATOR_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml new file mode 100644 index 0000000000..6f15ec417d --- /dev/null +++ b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml @@ -0,0 +1,117 @@ +namespace = "FlexFlow" +name = "PCGOperatorAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "op-attrs/ops/attention_attrs.dtg.h", + "op-attrs/ops/batch_matmul.dtg.h", + "op-attrs/ops/batch_norm_attrs.dtg.h", + "op-attrs/ops/cast_attrs.dtg.h", + "op-attrs/ops/combine_attrs.dtg.h", + "op-attrs/ops/combine_attrs.dtg.h", + "op-attrs/ops/concat_attrs.dtg.h", + "op-attrs/ops/conv_2d_attrs.dtg.h", + "op-attrs/ops/dropout_attrs.dtg.h", + "op-attrs/ops/element_binary_attrs.dtg.h", + "op-attrs/ops/element_scalar_unary_attrs.dtg.h", + "op-attrs/ops/element_unary_attrs.dtg.h", + "op-attrs/ops/embedding_attrs.dtg.h", + "op-attrs/ops/flat_attrs.dtg.h", + "op-attrs/ops/gather_attrs.dtg.h", + "op-attrs/ops/input_attrs.dtg.h", + "op-attrs/ops/layer_norm_attrs.dtg.h", + "op-attrs/ops/linear_attrs.dtg.h", + "op-attrs/ops/noop_attrs.dtg.h", + "op-attrs/ops/pool_2d_attrs.dtg.h", + "op-attrs/ops/reduce_attrs.dtg.h", + "op-attrs/ops/reduction_attrs.dtg.h", + "op-attrs/ops/repartition_attrs.dtg.h", + "op-attrs/ops/replicate_attrs.dtg.h", + "op-attrs/ops/reshape_attrs.dtg.h", + "op-attrs/ops/reverse_attrs.dtg.h", + "op-attrs/ops/softmax_attrs.dtg.h", + "op-attrs/ops/split_attrs.dtg.h", + "op-attrs/ops/topk_attrs.dtg.h", + "op-attrs/ops/transpose_attrs.dtg.h", +] + +[[values]] +type = "::FlexFlow::BatchMatmulAttrs" + +[[values]] +type = "::FlexFlow::BatchNormAttrs" + +[[values]] +type = "::FlexFlow::CastAttrs" + +[[values]] +type = "::FlexFlow::ConcatAttrs" + +[[values]] +type = "::FlexFlow::Conv2DAttrs" + +[[values]] +type = "::FlexFlow::DropoutAttrs" + +[[values]] +type = "::FlexFlow::ElementBinaryAttrs" + +[[values]] +type = "::FlexFlow::ElementUnaryAttrs" + +[[values]] +type = "::FlexFlow::ElementScalarUnaryAttrs" + +[[values]] +type = "::FlexFlow::EmbeddingAttrs" + +[[values]] +type = "::FlexFlow::FlatAttrs" + +[[values]] +type = "::FlexFlow::GatherAttrs" + +[[values]] +type = "::FlexFlow::InputAttrs" + +[[values]] +type = "::FlexFlow::LayerNormAttrs" + +[[values]] +type = "::FlexFlow::LinearAttrs" + +[[values]] +type = "::FlexFlow::MultiHeadAttentionAttrs" + +[[values]] +type = "::FlexFlow::NoopAttrs" + +[[values]] +type = "::FlexFlow::Pool2DAttrs" + +[[values]] +type = "::FlexFlow::ReduceAttrs" + +[[values]] +type = "::FlexFlow::ReverseAttrs" + +[[values]] +type = "::FlexFlow::ReshapeAttrs" + +[[values]] +type = "::FlexFlow::SplitAttrs" + +[[values]] +type = "::FlexFlow::SoftmaxAttrs" + +[[values]] +type = "::FlexFlow::TopKAttrs" + +[[values]] +type = "::FlexFlow::TransposeAttrs" diff --git a/lib/op-attrs/include/op-attrs/pool_op.h b/lib/op-attrs/include/op-attrs/pool_op.dtg.h similarity index 75% rename from lib/op-attrs/include/op-attrs/pool_op.h rename to lib/op-attrs/include/op-attrs/pool_op.dtg.h index 00c7852bbf..3511589b52 100644 --- a/lib/op-attrs/include/op-attrs/pool_op.h +++ b/lib/op-attrs/include/op-attrs/pool_op.dtg.h @@ -1,9 +1,14 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/pool_op.enum.toml +/* proj-data +{ + "generated_from": "ed1d531c6227306c909eb28eb0a66538" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_POOL_OP_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_POOL_OP_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_POOL_OP_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_POOL_OP_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" @@ -32,4 +37,4 @@ struct Arbitrary { }; } // namespace rc -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_POOL_OP_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_POOL_OP_DTG_H diff --git a/lib/op-attrs/include/op-attrs/regularizer_attrs.dtg.h b/lib/op-attrs/include/op-attrs/regularizer_attrs.dtg.h new file mode 100644 index 0000000000..38add9b42b --- /dev/null +++ b/lib/op-attrs/include/op-attrs/regularizer_attrs.dtg.h @@ -0,0 +1,121 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/regularizer_attrs.variant.toml +/* proj-data +{ + "generated_from": "b0cb2d264215faf9759925c631f3d55f" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REGULARIZER_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REGULARIZER_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/l1_regularizer_attrs.dtg.h" +#include "op-attrs/l2_regularizer_attrs.dtg.h" +#include +#include +#include +#include +#include + +namespace FlexFlow { +struct RegularizerAttrs { + RegularizerAttrs() = delete; + explicit RegularizerAttrs(::FlexFlow::L1RegularizerAttrs const &); + explicit RegularizerAttrs(::FlexFlow::L2RegularizerAttrs const &); + template + static constexpr bool IsPartOfRegularizerAttrs_v = + std::is_same_v || + std::is_same_v; + template + ReturnType visit(Visitor &&v) const { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::L1RegularizerAttrs>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::L2RegularizerAttrs>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type RegularizerAttrs", this->index())); + } + } + } + template + ReturnType visit(Visitor &&v) { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::L1RegularizerAttrs>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::L2RegularizerAttrs>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type RegularizerAttrs", this->index())); + } + } + } + template + bool has() const { + static_assert(IsPartOfRegularizerAttrs_v, + "RegularizerAttrs::has() expected one of " + "[::FlexFlow::L1RegularizerAttrs, " + "::FlexFlow::L2RegularizerAttrs], received T"); + return std::holds_alternative(this->raw_variant); + } + template + T const &get() const { + static_assert(IsPartOfRegularizerAttrs_v, + "RegularizerAttrs::get() expected one of " + "[::FlexFlow::L1RegularizerAttrs, " + "::FlexFlow::L2RegularizerAttrs], received T"); + return std::get(this->raw_variant); + } + template + T &get() { + static_assert(IsPartOfRegularizerAttrs_v, + "RegularizerAttrs::get() expected one of " + "[::FlexFlow::L1RegularizerAttrs, " + "::FlexFlow::L2RegularizerAttrs], received T"); + return std::get(this->raw_variant); + } + size_t index() const { + return this->raw_variant.index(); + } + bool operator==(RegularizerAttrs const &) const; + bool operator!=(RegularizerAttrs const &) const; + bool operator<(RegularizerAttrs const &) const; + bool operator>(RegularizerAttrs const &) const; + bool operator<=(RegularizerAttrs const &) const; + bool operator>=(RegularizerAttrs const &) const; + std::variant<::FlexFlow::L1RegularizerAttrs, ::FlexFlow::L2RegularizerAttrs> + raw_variant; +}; +} // namespace FlexFlow +namespace std { +template <> +struct hash<::FlexFlow::RegularizerAttrs> { + size_t operator()(::FlexFlow::RegularizerAttrs const &) const; +}; +} // namespace std +namespace nlohmann { +template <> +struct adl_serializer<::FlexFlow::RegularizerAttrs> { + static ::FlexFlow::RegularizerAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::RegularizerAttrs const &); +}; +} // namespace nlohmann +namespace FlexFlow { +std::string format_as(::FlexFlow::RegularizerAttrs const &); +std::ostream &operator<<(std::ostream &, ::FlexFlow::RegularizerAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REGULARIZER_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/regularizer_attrs.h b/lib/op-attrs/include/op-attrs/regularizer_attrs.h deleted file mode 100644 index 22a1c3c0a3..0000000000 --- a/lib/op-attrs/include/op-attrs/regularizer_attrs.h +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REGULARIZER_ATTRS_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REGULARIZER_ATTRS_H - -#include "op-attrs/l1_regularizer_attrs.h" -#include "op-attrs/l2_regularizer_attrs.h" -#include "utils/json.h" - -namespace FlexFlow { - -using RegularizerAttrs = std::variant; - -CHECK_IS_JSONABLE(RegularizerAttrs); - -} // namespace FlexFlow - -#endif diff --git a/lib/op-attrs/include/op-attrs/regularizer_attrs.variant.toml b/lib/op-attrs/include/op-attrs/regularizer_attrs.variant.toml new file mode 100644 index 0000000000..df974fed91 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/regularizer_attrs.variant.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "RegularizerAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "op-attrs/l1_regularizer_attrs.dtg.h", + "op-attrs/l2_regularizer_attrs.dtg.h", +] + +[[values]] +type = "::FlexFlow::L1RegularizerAttrs" +key = "l1" + +[[values]] +type = "::FlexFlow::L2RegularizerAttrs" +key = "l2" diff --git a/lib/op-attrs/include/op-attrs/tensor_dims_t.h b/lib/op-attrs/include/op-attrs/tensor_dims.dtg.h similarity index 79% rename from lib/op-attrs/include/op-attrs/tensor_dims_t.h rename to lib/op-attrs/include/op-attrs/tensor_dims.dtg.h index 5ae891ffe3..cb67f65c49 100644 --- a/lib/op-attrs/include/op-attrs/tensor_dims_t.h +++ b/lib/op-attrs/include/op-attrs/tensor_dims.dtg.h @@ -1,16 +1,20 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify -// lib/op-attrs/include/op-attrs/tensor_dims_t.struct.toml +// lib/op-attrs/include/op-attrs/tensor_dims.struct.toml +/* proj-data +{ + "generated_from": "f925a4c2343d2404116dc598c301beaf" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_DIMS_T_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_DIMS_T_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_DIMS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_DIMS_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" #include "op-attrs/dim_ordered.h" #include #include -#include #include namespace FlexFlow { @@ -48,4 +52,4 @@ std::string format_as(TensorDims const &); std::ostream &operator<<(std::ostream &, TensorDims const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_DIMS_T_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_DIMS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/tensor_dims.h b/lib/op-attrs/include/op-attrs/tensor_dims.h index ec6c208331..a0c37139d0 100644 --- a/lib/op-attrs/include/op-attrs/tensor_dims.h +++ b/lib/op-attrs/include/op-attrs/tensor_dims.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_DIMS_H #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_DIMS_H -#include "op-attrs/tensor_dims_t.h" -#include "op-attrs/parallel_tensor_dims_t.h" +#include "op-attrs/tensor_dims.dtg.h" +#include "op-attrs/parallel_tensor_dims.dtg.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/tensor_dims_t.struct.toml b/lib/op-attrs/include/op-attrs/tensor_dims.struct.toml similarity index 100% rename from lib/op-attrs/include/op-attrs/tensor_dims_t.struct.toml rename to lib/op-attrs/include/op-attrs/tensor_dims.struct.toml diff --git a/lib/op-attrs/include/op-attrs/tensor_shape_t.h b/lib/op-attrs/include/op-attrs/tensor_shape.dtg.h similarity index 80% rename from lib/op-attrs/include/op-attrs/tensor_shape_t.h rename to lib/op-attrs/include/op-attrs/tensor_shape.dtg.h index 272307c523..2773317607 100644 --- a/lib/op-attrs/include/op-attrs/tensor_shape_t.h +++ b/lib/op-attrs/include/op-attrs/tensor_shape.dtg.h @@ -1,9 +1,14 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify -// lib/op-attrs/include/op-attrs/tensor_shape_t.struct.toml +// lib/op-attrs/include/op-attrs/tensor_shape.struct.toml +/* proj-data +{ + "generated_from": "c02c9d2331d864a25c1443cfe70062d1" +} +*/ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_SHAPE_T_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_SHAPE_T_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_SHAPE_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_SHAPE_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" @@ -11,7 +16,6 @@ #include "op-attrs/tensor_dims.h" #include #include -#include #include namespace FlexFlow { @@ -51,4 +55,4 @@ std::string format_as(TensorShape const &); std::ostream &operator<<(std::ostream &, TensorShape const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_SHAPE_T_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_SHAPE_DTG_H diff --git a/lib/op-attrs/include/op-attrs/tensor_shape.h b/lib/op-attrs/include/op-attrs/tensor_shape.h index c505bcdc5f..75ab2c2a64 100644 --- a/lib/op-attrs/include/op-attrs/tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/tensor_shape.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_OPATTRS_TENSOR_SHAPE_H #define _FLEXFLOW_OPATTRS_TENSOR_SHAPE_H -#include "op-attrs/tensor_shape_t.h" +#include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/tensor_shape_t.struct.toml b/lib/op-attrs/include/op-attrs/tensor_shape.struct.toml similarity index 100% rename from lib/op-attrs/include/op-attrs/tensor_shape_t.struct.toml rename to lib/op-attrs/include/op-attrs/tensor_shape.struct.toml diff --git a/lib/op-attrs/src/batch_matmul.cc b/lib/op-attrs/src/batch_matmul.cc deleted file mode 100644 index 1cc8c5cfda..0000000000 --- a/lib/op-attrs/src/batch_matmul.cc +++ /dev/null @@ -1,26 +0,0 @@ -#include "op-attrs/ops/batch_matmul.h" - -namespace FlexFlow { - -/* bool BatchMatmulAttrs::is_valid( */ -/* ParallelTensorShape const &lhs, ParallelTensorShape const &rhs) const { - */ -/* if (!lhs.is_valid() || !rhs.is_valid()) { */ -/* return false; */ -/* } */ -/* if (lhs.num_dims() != rhs.num_dims()) { */ -/* return false; */ -/* } */ -/* for (int i = lhs.num_dims() - 1; i >= 2; i--) { */ -/* if (lhs.at(i) != rhs.at(i)) { */ -/* return false; */ -/* } */ -/* } */ -/* if (lhs.at(0) != rhs.at(1)) { */ -/* return false; */ -/* } */ - -/* return true; */ -/* } */ - -} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/activation.cc b/lib/op-attrs/src/op-attrs/activation.dtg.cc similarity index 95% rename from lib/op-attrs/src/op-attrs/activation.cc rename to lib/op-attrs/src/op-attrs/activation.dtg.cc index de63eb1ec5..5671b1720f 100644 --- a/lib/op-attrs/src/op-attrs/activation.cc +++ b/lib/op-attrs/src/op-attrs/activation.dtg.cc @@ -1,8 +1,13 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/activation.enum.toml +/* proj-data +{ + "generated_from": "2b0d2e3e825732838aa5be99f2f0e6df" +} +*/ -#include "op-attrs/activation.h" +#include "op-attrs/activation.dtg.h" #include #include diff --git a/lib/op-attrs/src/op-attrs/aggregate_op.cc b/lib/op-attrs/src/op-attrs/aggregate_op.dtg.cc similarity index 93% rename from lib/op-attrs/src/op-attrs/aggregate_op.cc rename to lib/op-attrs/src/op-attrs/aggregate_op.dtg.cc index 34b042ac41..72beeb27c8 100644 --- a/lib/op-attrs/src/op-attrs/aggregate_op.cc +++ b/lib/op-attrs/src/op-attrs/aggregate_op.dtg.cc @@ -1,8 +1,13 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/aggregate_op.enum.toml +/* proj-data +{ + "generated_from": "441fe9b0bb8f2dc2b31f74c58320ef30" +} +*/ -#include "op-attrs/aggregate_op.h" +#include "op-attrs/aggregate_op.dtg.h" #include #include diff --git a/lib/op-attrs/src/op-attrs/datatype_t.cc b/lib/op-attrs/src/op-attrs/datatype.dtg.cc similarity index 94% rename from lib/op-attrs/src/op-attrs/datatype_t.cc rename to lib/op-attrs/src/op-attrs/datatype.dtg.cc index ffacdb97ab..a9c1d54f0e 100644 --- a/lib/op-attrs/src/op-attrs/datatype_t.cc +++ b/lib/op-attrs/src/op-attrs/datatype.dtg.cc @@ -1,8 +1,13 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify -// lib/op-attrs/include/op-attrs/datatype_t.enum.toml +// lib/op-attrs/include/op-attrs/datatype.enum.toml +/* proj-data +{ + "generated_from": "8315d0aa0a65b00c13aa580e923592ef" +} +*/ -#include "op-attrs/datatype_t.h" +#include "op-attrs/datatype.dtg.h" #include #include diff --git a/lib/op-attrs/src/op-attrs/ff_dim.cc b/lib/op-attrs/src/op-attrs/ff_dim.dtg.cc similarity index 93% rename from lib/op-attrs/src/op-attrs/ff_dim.cc rename to lib/op-attrs/src/op-attrs/ff_dim.dtg.cc index 7c3c3293df..f6a1863fff 100644 --- a/lib/op-attrs/src/op-attrs/ff_dim.cc +++ b/lib/op-attrs/src/op-attrs/ff_dim.dtg.cc @@ -1,8 +1,15 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ff_dim.struct.toml +/* proj-data +{ + "generated_from": "ffd119eb46e048b0f5a2d8fbef253de3" +} +*/ + +#include "op-attrs/ff_dim.dtg.h" -#include "op-attrs/ff_dim.h" +#include namespace FlexFlow { ff_dim_t::ff_dim_t(int const &value) : value(value) {} diff --git a/lib/op-attrs/src/op-attrs/l1_regularizer_attrs.cc b/lib/op-attrs/src/op-attrs/l1_regularizer_attrs.dtg.cc similarity index 94% rename from lib/op-attrs/src/op-attrs/l1_regularizer_attrs.cc rename to lib/op-attrs/src/op-attrs/l1_regularizer_attrs.dtg.cc index e5a75b1201..ed06df2c78 100644 --- a/lib/op-attrs/src/op-attrs/l1_regularizer_attrs.cc +++ b/lib/op-attrs/src/op-attrs/l1_regularizer_attrs.dtg.cc @@ -1,8 +1,15 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/l1_regularizer_attrs.struct.toml +/* proj-data +{ + "generated_from": "50968fb8a3d43395d0eab7594f4935c0" +} +*/ + +#include "op-attrs/l1_regularizer_attrs.dtg.h" -#include "op-attrs/l1_regularizer_attrs.h" +#include namespace FlexFlow { L1RegularizerAttrs::L1RegularizerAttrs(float const &lambda) : lambda(lambda) {} diff --git a/lib/op-attrs/src/op-attrs/l2_regularizer_attrs.cc b/lib/op-attrs/src/op-attrs/l2_regularizer_attrs.dtg.cc similarity index 94% rename from lib/op-attrs/src/op-attrs/l2_regularizer_attrs.cc rename to lib/op-attrs/src/op-attrs/l2_regularizer_attrs.dtg.cc index 45e48044ee..f0f3f34ee5 100644 --- a/lib/op-attrs/src/op-attrs/l2_regularizer_attrs.cc +++ b/lib/op-attrs/src/op-attrs/l2_regularizer_attrs.dtg.cc @@ -1,8 +1,15 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/l2_regularizer_attrs.struct.toml +/* proj-data +{ + "generated_from": "c4f182e547ab6f0d5613e7eeb95d438e" +} +*/ + +#include "op-attrs/l2_regularizer_attrs.dtg.h" -#include "op-attrs/l2_regularizer_attrs.h" +#include namespace FlexFlow { L2RegularizerAttrs::L2RegularizerAttrs(float const &lambda) : lambda(lambda) {} diff --git a/lib/op-attrs/src/op-attrs/operator_type_t.cc b/lib/op-attrs/src/op-attrs/operator_type.dtg.cc similarity index 99% rename from lib/op-attrs/src/op-attrs/operator_type_t.cc rename to lib/op-attrs/src/op-attrs/operator_type.dtg.cc index 65c083efd0..07b6396a5a 100644 --- a/lib/op-attrs/src/op-attrs/operator_type_t.cc +++ b/lib/op-attrs/src/op-attrs/operator_type.dtg.cc @@ -1,8 +1,13 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify -// lib/op-attrs/include/op-attrs/operator_type_t.enum.toml +// lib/op-attrs/include/op-attrs/operator_type.enum.toml +/* proj-data +{ + "generated_from": "c1c4687ef2fbc7dad996e5c25d47124c" +} +*/ -#include "op-attrs/operator_type_t.h" +#include "op-attrs/operator_type.dtg.h" #include #include diff --git a/lib/op-attrs/src/op-attrs/ops/attention_attrs.cc b/lib/op-attrs/src/op-attrs/ops/attention_attrs.dtg.cc similarity index 98% rename from lib/op-attrs/src/op-attrs/ops/attention_attrs.cc rename to lib/op-attrs/src/op-attrs/ops/attention_attrs.dtg.cc index 728359ff25..ad0c094969 100644 --- a/lib/op-attrs/src/op-attrs/ops/attention_attrs.cc +++ b/lib/op-attrs/src/op-attrs/ops/attention_attrs.dtg.cc @@ -1,8 +1,15 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/attention_attrs.struct.toml +/* proj-data +{ + "generated_from": "360324465947562229dc6632a9e9a2f3" +} +*/ + +#include "op-attrs/ops/attention_attrs.dtg.h" -#include "op-attrs/ops/attention_attrs.h" +#include namespace FlexFlow { MultiHeadAttentionAttrs::MultiHeadAttentionAttrs(int const &embed_dim, diff --git a/lib/op-attrs/src/op-attrs/ops/attention_inputs.cc b/lib/op-attrs/src/op-attrs/ops/attention_inputs.dtg.cc similarity index 94% rename from lib/op-attrs/src/op-attrs/ops/attention_inputs.cc rename to lib/op-attrs/src/op-attrs/ops/attention_inputs.dtg.cc index 121806c194..d12018acb5 100644 --- a/lib/op-attrs/src/op-attrs/ops/attention_inputs.cc +++ b/lib/op-attrs/src/op-attrs/ops/attention_inputs.dtg.cc @@ -1,8 +1,16 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/attention_inputs.struct.toml +/* proj-data +{ + "generated_from": "700f5fb734284b7feabbdd4cb61f3183" +} +*/ + +#include "op-attrs/ops/attention_inputs.dtg.h" -#include "op-attrs/ops/attention_inputs.h" +#include "op-attrs/tensor_shape.h" +#include namespace FlexFlow { MultiHeadAttentionInputs::MultiHeadAttentionInputs( diff --git a/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc b/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc index 157654fa53..28e1f0af0a 100644 --- a/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc +++ b/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc @@ -1,83 +1,31 @@ -// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! -// If you would like to modify this datatype, instead modify -// lib/op-attrs/include/op-attrs/ops/batch_matmul.struct.toml - #include "op-attrs/ops/batch_matmul.h" namespace FlexFlow { -BatchMatmulAttrs::BatchMatmulAttrs(int const &a_seq_length_dim, - int const &b_seq_length_dim) - : a_seq_length_dim(a_seq_length_dim), b_seq_length_dim(b_seq_length_dim) {} -bool BatchMatmulAttrs::operator==(BatchMatmulAttrs const &other) const { - return std::tie(this->a_seq_length_dim, this->b_seq_length_dim) == - std::tie(other.a_seq_length_dim, other.b_seq_length_dim); -} -bool BatchMatmulAttrs::operator!=(BatchMatmulAttrs const &other) const { - return std::tie(this->a_seq_length_dim, this->b_seq_length_dim) != - std::tie(other.a_seq_length_dim, other.b_seq_length_dim); -} -bool BatchMatmulAttrs::operator<(BatchMatmulAttrs const &other) const { - return std::tie(this->a_seq_length_dim, this->b_seq_length_dim) < - std::tie(other.a_seq_length_dim, other.b_seq_length_dim); -} -bool BatchMatmulAttrs::operator>(BatchMatmulAttrs const &other) const { - return std::tie(this->a_seq_length_dim, this->b_seq_length_dim) > - std::tie(other.a_seq_length_dim, other.b_seq_length_dim); -} -bool BatchMatmulAttrs::operator<=(BatchMatmulAttrs const &other) const { - return std::tie(this->a_seq_length_dim, this->b_seq_length_dim) <= - std::tie(other.a_seq_length_dim, other.b_seq_length_dim); -} -bool BatchMatmulAttrs::operator>=(BatchMatmulAttrs const &other) const { - return std::tie(this->a_seq_length_dim, this->b_seq_length_dim) >= - std::tie(other.a_seq_length_dim, other.b_seq_length_dim); -} -} // namespace FlexFlow -namespace std { -size_t hash::operator()( - FlexFlow::BatchMatmulAttrs const &x) const { - size_t result = 0; - result ^= std::hash{}(x.a_seq_length_dim) + 0x9e3779b9 + (result << 6) + - (result >> 2); - result ^= std::hash{}(x.b_seq_length_dim) + 0x9e3779b9 + (result << 6) + - (result >> 2); - return result; -} -} // namespace std +/* bool BatchMatmulAttrs::is_valid( */ +/* ParallelTensorShape const &lhs, ParallelTensorShape const &rhs) const { + */ +/* if (!lhs.is_valid() || !rhs.is_valid()) { */ +/* return false; */ +/* } */ +/* if (lhs.num_dims() != rhs.num_dims()) { */ +/* return false; */ +/* } */ +/* for (int i = lhs.num_dims() - 1; i >= 2; i--) { */ +/* if (lhs.at(i) != rhs.at(i)) { */ +/* return false; */ +/* } */ +/* } */ +/* if (lhs.at(0) != rhs.at(1)) { */ +/* return false; */ +/* } */ -namespace nlohmann { -FlexFlow::BatchMatmulAttrs - adl_serializer::from_json(json const &j) { - return {j.at("a_seq_length_dim").template get(), - j.at("b_seq_length_dim").template get()}; -} -void adl_serializer::to_json( - json &j, FlexFlow::BatchMatmulAttrs const &v) { - j["__type"] = "BatchMatmulAttrs"; - j["a_seq_length_dim"] = v.a_seq_length_dim; - j["b_seq_length_dim"] = v.b_seq_length_dim; -} -} // namespace nlohmann +/* return true; */ +/* } */ -namespace rc { -Gen - Arbitrary::arbitrary() { - return gen::construct(gen::arbitrary(), - gen::arbitrary()); -} -} // namespace rc -namespace FlexFlow { -std::string format_as(BatchMatmulAttrs const &x) { - std::ostringstream oss; - oss << ""; - return oss.str(); -} -std::ostream &operator<<(std::ostream &s, BatchMatmulAttrs const &x) { - return s << fmt::to_string(x); +bool is_valid(BatchMatmulAttrs const &, ParallelTensorShape const &, ParallelTensorShape const &) { + NOT_IMPLEMENTED(); } + } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/batch_matmul.dtg.cc b/lib/op-attrs/src/op-attrs/ops/batch_matmul.dtg.cc new file mode 100644 index 0000000000..f178d40696 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/batch_matmul.dtg.cc @@ -0,0 +1,90 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/batch_matmul.struct.toml +/* proj-data +{ + "generated_from": "c3bbf4c76982ef27107b74e1e6e5d360" +} +*/ + +#include "op-attrs/ops/batch_matmul.dtg.h" + +#include + +namespace FlexFlow { +BatchMatmulAttrs::BatchMatmulAttrs(int const &a_seq_length_dim, + int const &b_seq_length_dim) + : a_seq_length_dim(a_seq_length_dim), b_seq_length_dim(b_seq_length_dim) {} +bool BatchMatmulAttrs::operator==(BatchMatmulAttrs const &other) const { + return std::tie(this->a_seq_length_dim, this->b_seq_length_dim) == + std::tie(other.a_seq_length_dim, other.b_seq_length_dim); +} +bool BatchMatmulAttrs::operator!=(BatchMatmulAttrs const &other) const { + return std::tie(this->a_seq_length_dim, this->b_seq_length_dim) != + std::tie(other.a_seq_length_dim, other.b_seq_length_dim); +} +bool BatchMatmulAttrs::operator<(BatchMatmulAttrs const &other) const { + return std::tie(this->a_seq_length_dim, this->b_seq_length_dim) < + std::tie(other.a_seq_length_dim, other.b_seq_length_dim); +} +bool BatchMatmulAttrs::operator>(BatchMatmulAttrs const &other) const { + return std::tie(this->a_seq_length_dim, this->b_seq_length_dim) > + std::tie(other.a_seq_length_dim, other.b_seq_length_dim); +} +bool BatchMatmulAttrs::operator<=(BatchMatmulAttrs const &other) const { + return std::tie(this->a_seq_length_dim, this->b_seq_length_dim) <= + std::tie(other.a_seq_length_dim, other.b_seq_length_dim); +} +bool BatchMatmulAttrs::operator>=(BatchMatmulAttrs const &other) const { + return std::tie(this->a_seq_length_dim, this->b_seq_length_dim) >= + std::tie(other.a_seq_length_dim, other.b_seq_length_dim); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::BatchMatmulAttrs const &x) const { + size_t result = 0; + result ^= std::hash{}(x.a_seq_length_dim) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.b_seq_length_dim) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::BatchMatmulAttrs + adl_serializer::from_json(json const &j) { + return {j.at("a_seq_length_dim").template get(), + j.at("b_seq_length_dim").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::BatchMatmulAttrs const &v) { + j["__type"] = "BatchMatmulAttrs"; + j["a_seq_length_dim"] = v.a_seq_length_dim; + j["b_seq_length_dim"] = v.b_seq_length_dim; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary(), + gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(BatchMatmulAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, BatchMatmulAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/batch_norm.cc b/lib/op-attrs/src/op-attrs/ops/batch_norm.cc new file mode 100644 index 0000000000..9152a1306c --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/batch_norm.cc @@ -0,0 +1,9 @@ +#include "op-attrs/ops/batch_norm.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(BatchNormAttrs const &, ParallelTensorShape const &) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/batch_norm_attrs.cc b/lib/op-attrs/src/op-attrs/ops/batch_norm_attrs.dtg.cc similarity index 93% rename from lib/op-attrs/src/op-attrs/ops/batch_norm_attrs.cc rename to lib/op-attrs/src/op-attrs/ops/batch_norm_attrs.dtg.cc index 84ae40115d..cb8dcadae1 100644 --- a/lib/op-attrs/src/op-attrs/ops/batch_norm_attrs.cc +++ b/lib/op-attrs/src/op-attrs/ops/batch_norm_attrs.dtg.cc @@ -1,8 +1,15 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml +/* proj-data +{ + "generated_from": "f8e0219d8a3e008a73c38cf84d25f66e" +} +*/ + +#include "op-attrs/ops/batch_norm_attrs.dtg.h" -#include "op-attrs/ops/batch_norm_attrs.h" +#include namespace FlexFlow { BatchNormAttrs::BatchNormAttrs(bool const &relu) : relu(relu) {} diff --git a/lib/op-attrs/src/op-attrs/ops/broadcast.cc b/lib/op-attrs/src/op-attrs/ops/broadcast.dtg.cc similarity index 93% rename from lib/op-attrs/src/op-attrs/ops/broadcast.cc rename to lib/op-attrs/src/op-attrs/ops/broadcast.dtg.cc index 823df8e01b..dadb8d4cff 100644 --- a/lib/op-attrs/src/op-attrs/ops/broadcast.cc +++ b/lib/op-attrs/src/op-attrs/ops/broadcast.dtg.cc @@ -1,8 +1,16 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/broadcast.struct.toml +/* proj-data +{ + "generated_from": "890d0e63a08a30d925aa170aea6992ba" +} +*/ + +#include "op-attrs/ops/broadcast.dtg.h" -#include "op-attrs/ops/broadcast.h" +#include "utils/stack_vector.h" +#include namespace FlexFlow { BroadcastAttrs::BroadcastAttrs( diff --git a/lib/op-attrs/src/op-attrs/ops/cast_attrs.cc b/lib/op-attrs/src/op-attrs/ops/cast_attrs.dtg.cc similarity index 91% rename from lib/op-attrs/src/op-attrs/ops/cast_attrs.cc rename to lib/op-attrs/src/op-attrs/ops/cast_attrs.dtg.cc index 8ce883341c..63b0cda27b 100644 --- a/lib/op-attrs/src/op-attrs/ops/cast_attrs.cc +++ b/lib/op-attrs/src/op-attrs/ops/cast_attrs.dtg.cc @@ -1,8 +1,16 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/cast_attrs.struct.toml +/* proj-data +{ + "generated_from": "62da4845a8aa0ae4ca3bce432a3aa9a3" +} +*/ + +#include "op-attrs/ops/cast_attrs.dtg.h" -#include "op-attrs/ops/cast_attrs.h" +#include "op-attrs/datatype.h" +#include namespace FlexFlow { CastAttrs::CastAttrs(DataType const &dtype) : dtype(dtype) {} diff --git a/lib/op-attrs/src/op-attrs/ops/combine_attrs.cc b/lib/op-attrs/src/op-attrs/ops/combine_attrs.dtg.cc similarity index 94% rename from lib/op-attrs/src/op-attrs/ops/combine_attrs.cc rename to lib/op-attrs/src/op-attrs/ops/combine_attrs.dtg.cc index c0cf8e0d08..a652537871 100644 --- a/lib/op-attrs/src/op-attrs/ops/combine_attrs.cc +++ b/lib/op-attrs/src/op-attrs/ops/combine_attrs.dtg.cc @@ -1,8 +1,16 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml +/* proj-data +{ + "generated_from": "7caa0f9668b1894f5e446556f1a424c8" +} +*/ + +#include "op-attrs/ops/combine_attrs.dtg.h" -#include "op-attrs/ops/combine_attrs.h" +#include "op-attrs/ff_dim.dtg.h" +#include namespace FlexFlow { CombineAttrs::CombineAttrs(::FlexFlow::ff_dim_t const &combine_dim, diff --git a/lib/op-attrs/src/op-attrs/ops/concat_attrs.cc b/lib/op-attrs/src/op-attrs/ops/concat_attrs.dtg.cc similarity index 93% rename from lib/op-attrs/src/op-attrs/ops/concat_attrs.cc rename to lib/op-attrs/src/op-attrs/ops/concat_attrs.dtg.cc index 8e16552fa0..0494b7069a 100644 --- a/lib/op-attrs/src/op-attrs/ops/concat_attrs.cc +++ b/lib/op-attrs/src/op-attrs/ops/concat_attrs.dtg.cc @@ -1,8 +1,16 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml +/* proj-data +{ + "generated_from": "b72ef29f9f79a917176c63a5c3683ab5" +} +*/ + +#include "op-attrs/ops/concat_attrs.dtg.h" -#include "op-attrs/ops/concat_attrs.h" +#include "op-attrs/ff_dim.dtg.h" +#include namespace FlexFlow { ConcatAttrs::ConcatAttrs(::FlexFlow::ff_dim_t const &axis, diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d_attrs.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d_attrs.dtg.cc similarity index 97% rename from lib/op-attrs/src/op-attrs/ops/conv_2d_attrs.cc rename to lib/op-attrs/src/op-attrs/ops/conv_2d_attrs.dtg.cc index 5085b3e121..e0d1e654aa 100644 --- a/lib/op-attrs/src/op-attrs/ops/conv_2d_attrs.cc +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d_attrs.dtg.cc @@ -1,8 +1,18 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml +/* proj-data +{ + "generated_from": "85f65c1b0e0340ea8e8622c2bf9ca38d" +} +*/ + +#include "op-attrs/ops/conv_2d_attrs.dtg.h" -#include "op-attrs/ops/conv_2d_attrs.h" +#include "op-attrs/activation.dtg.h" +#include "utils/json.h" +#include +#include namespace FlexFlow { Conv2DAttrs::Conv2DAttrs( diff --git a/lib/op-attrs/src/op-attrs/ops/dropout.cc b/lib/op-attrs/src/op-attrs/ops/dropout.cc new file mode 100644 index 0000000000..8dd381f65a --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/dropout.cc @@ -0,0 +1,9 @@ +#include "op-attrs/ops/dropout.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(DropoutAttrs const &, ParallelTensorShape const &) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/dropout_attrs.cc b/lib/op-attrs/src/op-attrs/ops/dropout_attrs.dtg.cc similarity index 94% rename from lib/op-attrs/src/op-attrs/ops/dropout_attrs.cc rename to lib/op-attrs/src/op-attrs/ops/dropout_attrs.dtg.cc index 308725c0b0..284443a0e4 100644 --- a/lib/op-attrs/src/op-attrs/ops/dropout_attrs.cc +++ b/lib/op-attrs/src/op-attrs/ops/dropout_attrs.dtg.cc @@ -1,8 +1,15 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/dropout_attrs.struct.toml +/* proj-data +{ + "generated_from": "4fdbf129ea59b8a7306813cfa4c46021" +} +*/ + +#include "op-attrs/ops/dropout_attrs.dtg.h" -#include "op-attrs/ops/dropout_attrs.h" +#include namespace FlexFlow { DropoutAttrs::DropoutAttrs(float const &rate, unsigned long long const &seed) diff --git a/lib/op-attrs/src/op-attrs/ops/element_binary_attrs.cc b/lib/op-attrs/src/op-attrs/ops/element_binary_attrs.dtg.cc similarity index 95% rename from lib/op-attrs/src/op-attrs/ops/element_binary_attrs.cc rename to lib/op-attrs/src/op-attrs/ops/element_binary_attrs.dtg.cc index b14920f803..bdaef6511f 100644 --- a/lib/op-attrs/src/op-attrs/ops/element_binary_attrs.cc +++ b/lib/op-attrs/src/op-attrs/ops/element_binary_attrs.dtg.cc @@ -1,8 +1,17 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/element_binary_attrs.struct.toml +/* proj-data +{ + "generated_from": "1aae4139632791a4b7638e59fa6b5dc8" +} +*/ + +#include "op-attrs/ops/element_binary_attrs.dtg.h" -#include "op-attrs/ops/element_binary_attrs.h" +#include "op-attrs/datatype.h" +#include "op-attrs/operator_type.h" +#include namespace FlexFlow { ElementBinaryAttrs::ElementBinaryAttrs(::FlexFlow::OperatorType const &type, diff --git a/lib/op-attrs/src/op-attrs/ops/element_scalar_unary_attrs.cc b/lib/op-attrs/src/op-attrs/ops/element_scalar_unary_attrs.dtg.cc similarity index 93% rename from lib/op-attrs/src/op-attrs/ops/element_scalar_unary_attrs.cc rename to lib/op-attrs/src/op-attrs/ops/element_scalar_unary_attrs.dtg.cc index 4c9d59e568..36c26653d4 100644 --- a/lib/op-attrs/src/op-attrs/ops/element_scalar_unary_attrs.cc +++ b/lib/op-attrs/src/op-attrs/ops/element_scalar_unary_attrs.dtg.cc @@ -1,8 +1,16 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.struct.toml +/* proj-data +{ + "generated_from": "09554c353caed6075e362da5008c4bd2" +} +*/ + +#include "op-attrs/ops/element_scalar_unary_attrs.dtg.h" -#include "op-attrs/ops/element_scalar_unary_attrs.h" +#include "op-attrs/operator_type.h" +#include namespace FlexFlow { ElementScalarUnaryAttrs::ElementScalarUnaryAttrs( diff --git a/lib/op-attrs/src/op-attrs/ops/element_unary_attrs.cc b/lib/op-attrs/src/op-attrs/ops/element_unary_attrs.dtg.cc similarity index 92% rename from lib/op-attrs/src/op-attrs/ops/element_unary_attrs.cc rename to lib/op-attrs/src/op-attrs/ops/element_unary_attrs.dtg.cc index fa9a4a1697..b5968ed425 100644 --- a/lib/op-attrs/src/op-attrs/ops/element_unary_attrs.cc +++ b/lib/op-attrs/src/op-attrs/ops/element_unary_attrs.dtg.cc @@ -1,8 +1,16 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml +/* proj-data +{ + "generated_from": "fdb867c04cdd7de320f573f360bcab90" +} +*/ + +#include "op-attrs/ops/element_unary_attrs.dtg.h" -#include "op-attrs/ops/element_unary_attrs.h" +#include "op-attrs/operator_type.h" +#include namespace FlexFlow { ElementUnaryAttrs::ElementUnaryAttrs(::FlexFlow::OperatorType const &op_type) diff --git a/lib/op-attrs/src/embedding.cc b/lib/op-attrs/src/op-attrs/ops/embedding.cc similarity index 61% rename from lib/op-attrs/src/embedding.cc rename to lib/op-attrs/src/op-attrs/ops/embedding.cc index 56014fcc67..2a55266a7f 100644 --- a/lib/op-attrs/src/embedding.cc +++ b/lib/op-attrs/src/op-attrs/ops/embedding.cc @@ -6,4 +6,8 @@ TensorShape get_weights_shape(EmbeddingAttrs const &, TensorShape const &) { NOT_IMPLEMENTED(); } +ParallelTensorShape get_output_shape(EmbeddingAttrs const &, ParallelTensorShape const &) { + NOT_IMPLEMENTED(); +} + } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/embedding_attrs.cc b/lib/op-attrs/src/op-attrs/ops/embedding_attrs.dtg.cc similarity index 95% rename from lib/op-attrs/src/op-attrs/ops/embedding_attrs.cc rename to lib/op-attrs/src/op-attrs/ops/embedding_attrs.dtg.cc index 6cc49ece0b..a9110f16bc 100644 --- a/lib/op-attrs/src/op-attrs/ops/embedding_attrs.cc +++ b/lib/op-attrs/src/op-attrs/ops/embedding_attrs.dtg.cc @@ -1,8 +1,18 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml +/* proj-data +{ + "generated_from": "65af6a38dfabebbc05c8ad3f75397b07" +} +*/ + +#include "op-attrs/ops/embedding_attrs.dtg.h" -#include "op-attrs/ops/embedding_attrs.h" +#include "op-attrs/aggregate_op.dtg.h" +#include "op-attrs/datatype.dtg.h" +#include "utils/stack_vector.h" +#include namespace FlexFlow { EmbeddingAttrs::EmbeddingAttrs(int const &num_entries, diff --git a/lib/op-attrs/src/op-attrs/ops/flat_attrs.cc b/lib/op-attrs/src/op-attrs/ops/flat_attrs.dtg.cc similarity index 92% rename from lib/op-attrs/src/op-attrs/ops/flat_attrs.cc rename to lib/op-attrs/src/op-attrs/ops/flat_attrs.dtg.cc index 8705a18c86..ef34d97a89 100644 --- a/lib/op-attrs/src/op-attrs/ops/flat_attrs.cc +++ b/lib/op-attrs/src/op-attrs/ops/flat_attrs.dtg.cc @@ -1,8 +1,15 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml +/* proj-data +{ + "generated_from": "b63924cd671481df30fae314a199c606" +} +*/ + +#include "op-attrs/ops/flat_attrs.dtg.h" -#include "op-attrs/ops/flat_attrs.h" +#include namespace FlexFlow { bool FlatAttrs::operator==(FlatAttrs const &other) const { diff --git a/lib/op-attrs/src/op-attrs/ops/gather_attrs.cc b/lib/op-attrs/src/op-attrs/ops/gather_attrs.dtg.cc similarity index 91% rename from lib/op-attrs/src/op-attrs/ops/gather_attrs.cc rename to lib/op-attrs/src/op-attrs/ops/gather_attrs.dtg.cc index 05d6c674a7..886794a9b1 100644 --- a/lib/op-attrs/src/op-attrs/ops/gather_attrs.cc +++ b/lib/op-attrs/src/op-attrs/ops/gather_attrs.dtg.cc @@ -1,8 +1,16 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml +/* proj-data +{ + "generated_from": "ee735644d3c5f53f790e0a1fa8b8beaf" +} +*/ + +#include "op-attrs/ops/gather_attrs.dtg.h" -#include "op-attrs/ops/gather_attrs.h" +#include "op-attrs/ff_dim.dtg.h" +#include namespace FlexFlow { GatherAttrs::GatherAttrs(::FlexFlow::ff_dim_t const &dim) : dim(dim) {} diff --git a/lib/op-attrs/src/op-attrs/ops/input.cc b/lib/op-attrs/src/op-attrs/ops/input.cc new file mode 100644 index 0000000000..93606b603a --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/input.cc @@ -0,0 +1,9 @@ +#include "op-attrs/ops/input.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(InputAttrs const &) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/input_attrs.cc b/lib/op-attrs/src/op-attrs/ops/input_attrs.dtg.cc similarity index 92% rename from lib/op-attrs/src/op-attrs/ops/input_attrs.cc rename to lib/op-attrs/src/op-attrs/ops/input_attrs.dtg.cc index 21ec0a3ba5..35544402f7 100644 --- a/lib/op-attrs/src/op-attrs/ops/input_attrs.cc +++ b/lib/op-attrs/src/op-attrs/ops/input_attrs.dtg.cc @@ -1,8 +1,15 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/input_attrs.struct.toml +/* proj-data +{ + "generated_from": "139ea46d57a3c8738b31b17a8c59a0aa" +} +*/ + +#include "op-attrs/ops/input_attrs.dtg.h" -#include "op-attrs/ops/input_attrs.h" +#include namespace FlexFlow { bool InputAttrs::operator==(InputAttrs const &other) const { diff --git a/lib/op-attrs/src/op-attrs/ops/layer_norm.cc b/lib/op-attrs/src/op-attrs/ops/layer_norm.cc new file mode 100644 index 0000000000..d072fb8b17 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/layer_norm.cc @@ -0,0 +1,9 @@ +#include "op-attrs/ops/layer_norm.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(LayerNormAttrs const &, ParallelTensorShape const &) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/layer_norm_attrs.cc b/lib/op-attrs/src/op-attrs/ops/layer_norm_attrs.dtg.cc similarity index 94% rename from lib/op-attrs/src/op-attrs/ops/layer_norm_attrs.cc rename to lib/op-attrs/src/op-attrs/ops/layer_norm_attrs.dtg.cc index 714d1f0c3c..d3c4e0c57e 100644 --- a/lib/op-attrs/src/op-attrs/ops/layer_norm_attrs.cc +++ b/lib/op-attrs/src/op-attrs/ops/layer_norm_attrs.dtg.cc @@ -1,8 +1,17 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml +/* proj-data +{ + "generated_from": "c03d823a6e889e1254b73a0730a71046" +} +*/ + +#include "op-attrs/ops/layer_norm_attrs.dtg.h" -#include "op-attrs/ops/layer_norm_attrs.h" +#include "op-attrs/ff_dim.dtg.h" +#include "utils/stack_vector.h" +#include namespace FlexFlow { LayerNormAttrs::LayerNormAttrs( diff --git a/lib/op-attrs/src/op-attrs/ops/linear.cc b/lib/op-attrs/src/op-attrs/ops/linear.cc new file mode 100644 index 0000000000..0ecd601e00 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/linear.cc @@ -0,0 +1,9 @@ +#include "op-attrs/ops/linear.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(LinearAttrs const &, ParallelTensorShape const &) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/linear_attrs.cc b/lib/op-attrs/src/op-attrs/ops/linear_attrs.dtg.cc similarity index 95% rename from lib/op-attrs/src/op-attrs/ops/linear_attrs.cc rename to lib/op-attrs/src/op-attrs/ops/linear_attrs.dtg.cc index 21d80d1f88..961222843d 100644 --- a/lib/op-attrs/src/op-attrs/ops/linear_attrs.cc +++ b/lib/op-attrs/src/op-attrs/ops/linear_attrs.dtg.cc @@ -1,8 +1,19 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml +/* proj-data +{ + "generated_from": "dae07c937f6c52d4dc89ec322520e29f" +} +*/ + +#include "op-attrs/ops/linear_attrs.dtg.h" -#include "op-attrs/ops/linear_attrs.h" +#include "op-attrs/activation.dtg.h" +#include "op-attrs/datatype.dtg.h" +#include "op-attrs/regularizer_attrs.dtg.h" +#include "utils/json.h" +#include namespace FlexFlow { LinearAttrs::LinearAttrs( diff --git a/lib/op-attrs/src/op-attrs/ops/noop.cc b/lib/op-attrs/src/op-attrs/ops/noop.cc new file mode 100644 index 0000000000..1b243a388a --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/noop.cc @@ -0,0 +1,9 @@ +#include "op-attrs/ops/noop.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(NoopAttrs const &, ParallelTensorShape const &input_shape) { + return input_shape; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/noop_attrs.cc b/lib/op-attrs/src/op-attrs/ops/noop_attrs.dtg.cc similarity index 92% rename from lib/op-attrs/src/op-attrs/ops/noop_attrs.cc rename to lib/op-attrs/src/op-attrs/ops/noop_attrs.dtg.cc index 1c42670d4d..3ef3a0119b 100644 --- a/lib/op-attrs/src/op-attrs/ops/noop_attrs.cc +++ b/lib/op-attrs/src/op-attrs/ops/noop_attrs.dtg.cc @@ -1,8 +1,15 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/noop_attrs.struct.toml +/* proj-data +{ + "generated_from": "d440077aa598fdad0e5aa95288b63c40" +} +*/ + +#include "op-attrs/ops/noop_attrs.dtg.h" -#include "op-attrs/ops/noop_attrs.h" +#include namespace FlexFlow { bool NoopAttrs::operator==(NoopAttrs const &other) const { diff --git a/lib/op-attrs/src/op-attrs/ops/parallel_attention_inputs.cc b/lib/op-attrs/src/op-attrs/ops/parallel_attention_inputs.dtg.cc similarity index 94% rename from lib/op-attrs/src/op-attrs/ops/parallel_attention_inputs.cc rename to lib/op-attrs/src/op-attrs/ops/parallel_attention_inputs.dtg.cc index a2dd992600..e1837a7360 100644 --- a/lib/op-attrs/src/op-attrs/ops/parallel_attention_inputs.cc +++ b/lib/op-attrs/src/op-attrs/ops/parallel_attention_inputs.dtg.cc @@ -1,8 +1,16 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.struct.toml +/* proj-data +{ + "generated_from": "722d92014b31bffcd5ad45eda476d8b3" +} +*/ + +#include "op-attrs/ops/parallel_attention_inputs.dtg.h" -#include "op-attrs/ops/parallel_attention_inputs.h" +#include "op-attrs/parallel_tensor_shape.h" +#include namespace FlexFlow { ParallelMultiHeadAttentionInputs::ParallelMultiHeadAttentionInputs( diff --git a/lib/op-attrs/src/op-attrs/ops/pool_2d.cc b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc new file mode 100644 index 0000000000..a9ca71a060 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc @@ -0,0 +1,9 @@ +#include "op-attrs/ops/pool_2d.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(Pool2DAttrs const &, ParallelTensorShape const &) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/pool_2d_attrs.cc b/lib/op-attrs/src/op-attrs/ops/pool_2d_attrs.dtg.cc similarity index 97% rename from lib/op-attrs/src/op-attrs/ops/pool_2d_attrs.cc rename to lib/op-attrs/src/op-attrs/ops/pool_2d_attrs.dtg.cc index b186433f1f..3316e4c136 100644 --- a/lib/op-attrs/src/op-attrs/ops/pool_2d_attrs.cc +++ b/lib/op-attrs/src/op-attrs/ops/pool_2d_attrs.dtg.cc @@ -1,8 +1,17 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml +/* proj-data +{ + "generated_from": "607be08f56d910bfa340fb180646c126" +} +*/ + +#include "op-attrs/ops/pool_2d_attrs.dtg.h" -#include "op-attrs/ops/pool_2d_attrs.h" +#include "op-attrs/activation.dtg.h" +#include "op-attrs/pool_op.dtg.h" +#include namespace FlexFlow { Pool2DAttrs::Pool2DAttrs(int const &kernel_h, diff --git a/lib/op-attrs/src/op-attrs/ops/reduce.cc b/lib/op-attrs/src/op-attrs/ops/reduce.cc new file mode 100644 index 0000000000..a08fb4128e --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/reduce.cc @@ -0,0 +1,9 @@ +#include "op-attrs/ops/reduce.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(ReduceAttrs const &, ParallelTensorShape const &input_shape) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/reduce_attrs.cc b/lib/op-attrs/src/op-attrs/ops/reduce_attrs.dtg.cc similarity index 92% rename from lib/op-attrs/src/op-attrs/ops/reduce_attrs.cc rename to lib/op-attrs/src/op-attrs/ops/reduce_attrs.dtg.cc index 3cb71259e1..004beb7c64 100644 --- a/lib/op-attrs/src/op-attrs/ops/reduce_attrs.cc +++ b/lib/op-attrs/src/op-attrs/ops/reduce_attrs.dtg.cc @@ -1,8 +1,18 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml +/* proj-data +{ + "generated_from": "bc6279031650335f4a0b7b6cfe116c85" +} +*/ + +#include "op-attrs/ops/reduce_attrs.dtg.h" -#include "op-attrs/ops/reduce_attrs.h" +#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/operator_type.dtg.h" +#include "utils/stack_vector.h" +#include namespace FlexFlow { ReduceAttrs::ReduceAttrs( diff --git a/lib/op-attrs/src/op-attrs/ops/reduction.cc b/lib/op-attrs/src/op-attrs/ops/reduction.cc new file mode 100644 index 0000000000..2396772a94 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/reduction.cc @@ -0,0 +1,9 @@ +#include "op-attrs/ops/reduction.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(ReductionAttrs const &attrs, ParallelTensorShape const &input_shape) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/reduction_attrs.cc b/lib/op-attrs/src/op-attrs/ops/reduction_attrs.dtg.cc similarity index 94% rename from lib/op-attrs/src/op-attrs/ops/reduction_attrs.cc rename to lib/op-attrs/src/op-attrs/ops/reduction_attrs.dtg.cc index 6876cf863d..a7cc019111 100644 --- a/lib/op-attrs/src/op-attrs/ops/reduction_attrs.cc +++ b/lib/op-attrs/src/op-attrs/ops/reduction_attrs.dtg.cc @@ -1,8 +1,16 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/reduction_attrs.struct.toml +/* proj-data +{ + "generated_from": "57b8ccb5bc2e1a1a3bcf1bce2d8cad9e" +} +*/ + +#include "op-attrs/ops/reduction_attrs.dtg.h" -#include "op-attrs/ops/reduction_attrs.h" +#include "op-attrs/ff_dim.dtg.h" +#include namespace FlexFlow { ReductionAttrs::ReductionAttrs(::FlexFlow::ff_dim_t const &reduction_dim, diff --git a/lib/op-attrs/src/op-attrs/ops/repartition.cc b/lib/op-attrs/src/op-attrs/ops/repartition.cc new file mode 100644 index 0000000000..45b54df80b --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/repartition.cc @@ -0,0 +1,9 @@ +#include "op-attrs/ops/repartition.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(RepartitionAttrs const &, ParallelTensorShape const &input_shape) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/repartition_attrs.cc b/lib/op-attrs/src/op-attrs/ops/repartition_attrs.dtg.cc similarity index 94% rename from lib/op-attrs/src/op-attrs/ops/repartition_attrs.cc rename to lib/op-attrs/src/op-attrs/ops/repartition_attrs.dtg.cc index 7abd7d1959..5ff0f44f44 100644 --- a/lib/op-attrs/src/op-attrs/ops/repartition_attrs.cc +++ b/lib/op-attrs/src/op-attrs/ops/repartition_attrs.dtg.cc @@ -1,8 +1,16 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml +/* proj-data +{ + "generated_from": "366cb1a14093762f75508260ac6494ca" +} +*/ + +#include "op-attrs/ops/repartition_attrs.dtg.h" -#include "op-attrs/ops/repartition_attrs.h" +#include "op-attrs/ff_dim.dtg.h" +#include namespace FlexFlow { RepartitionAttrs::RepartitionAttrs(::FlexFlow::ff_dim_t const &repartition_dim, diff --git a/lib/op-attrs/src/op-attrs/ops/replicate.cc b/lib/op-attrs/src/op-attrs/ops/replicate.cc new file mode 100644 index 0000000000..a639a51f15 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/replicate.cc @@ -0,0 +1,9 @@ +#include "op-attrs/ops/replicate.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(ReplicateAttrs const &attrs, ParallelTensorShape const &input_shape) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/replicate_attrs.cc b/lib/op-attrs/src/op-attrs/ops/replicate_attrs.dtg.cc similarity index 94% rename from lib/op-attrs/src/op-attrs/ops/replicate_attrs.cc rename to lib/op-attrs/src/op-attrs/ops/replicate_attrs.dtg.cc index ea7802a325..bf92a0b656 100644 --- a/lib/op-attrs/src/op-attrs/ops/replicate_attrs.cc +++ b/lib/op-attrs/src/op-attrs/ops/replicate_attrs.dtg.cc @@ -1,8 +1,16 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/replicate_attrs.struct.toml +/* proj-data +{ + "generated_from": "4224406d468444433d69e4abf61b7cd1" +} +*/ + +#include "op-attrs/ops/replicate_attrs.dtg.h" -#include "op-attrs/ops/replicate_attrs.h" +#include "op-attrs/ff_dim.dtg.h" +#include namespace FlexFlow { ReplicateAttrs::ReplicateAttrs(::FlexFlow::ff_dim_t const &replicate_dim, diff --git a/lib/op-attrs/src/op-attrs/ops/reshape.cc b/lib/op-attrs/src/op-attrs/ops/reshape.cc new file mode 100644 index 0000000000..49ec940525 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/reshape.cc @@ -0,0 +1,9 @@ +#include "op-attrs/ops/reshape.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(ReshapeAttrs const &attrs, ParallelTensorShape const &input_shape) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/reshape_attrs.cc b/lib/op-attrs/src/op-attrs/ops/reshape_attrs.dtg.cc similarity index 91% rename from lib/op-attrs/src/op-attrs/ops/reshape_attrs.cc rename to lib/op-attrs/src/op-attrs/ops/reshape_attrs.dtg.cc index b9953e33f4..2c5509a655 100644 --- a/lib/op-attrs/src/op-attrs/ops/reshape_attrs.cc +++ b/lib/op-attrs/src/op-attrs/ops/reshape_attrs.dtg.cc @@ -1,8 +1,16 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/reshape_attrs.struct.toml +/* proj-data +{ + "generated_from": "5a6a9e646a457a6cf959c542fb631512" +} +*/ + +#include "op-attrs/ops/reshape_attrs.dtg.h" -#include "op-attrs/ops/reshape_attrs.h" +#include "op-attrs/tensor_shape.dtg.h" +#include namespace FlexFlow { ReshapeAttrs::ReshapeAttrs(::FlexFlow::TensorShape const &shape) diff --git a/lib/op-attrs/src/op-attrs/ops/reverse.cc b/lib/op-attrs/src/op-attrs/ops/reverse.cc new file mode 100644 index 0000000000..5afd3e726e --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/reverse.cc @@ -0,0 +1,9 @@ +#include "op-attrs/ops/reverse.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(ReverseAttrs const &attrs, ParallelTensorShape const &input_shape) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/reverse_attrs.cc b/lib/op-attrs/src/op-attrs/ops/reverse_attrs.dtg.cc similarity index 92% rename from lib/op-attrs/src/op-attrs/ops/reverse_attrs.cc rename to lib/op-attrs/src/op-attrs/ops/reverse_attrs.dtg.cc index 1b8cbd715e..61122313b0 100644 --- a/lib/op-attrs/src/op-attrs/ops/reverse_attrs.cc +++ b/lib/op-attrs/src/op-attrs/ops/reverse_attrs.dtg.cc @@ -1,8 +1,16 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml +/* proj-data +{ + "generated_from": "7c21c4192854f5981018abf4fbdd9ead" +} +*/ + +#include "op-attrs/ops/reverse_attrs.dtg.h" -#include "op-attrs/ops/reverse_attrs.h" +#include "op-attrs/ff_dim.dtg.h" +#include namespace FlexFlow { ReverseAttrs::ReverseAttrs(::FlexFlow::ff_dim_t const &axis) : axis(axis) {} diff --git a/lib/op-attrs/src/op-attrs/ops/softmax.cc b/lib/op-attrs/src/op-attrs/ops/softmax.cc new file mode 100644 index 0000000000..05d6645637 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/softmax.cc @@ -0,0 +1,9 @@ +#include "op-attrs/ops/softmax.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(SoftmaxAttrs const &attrs, ParallelTensorShape const &input_shape) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/softmax_attrs.cc b/lib/op-attrs/src/op-attrs/ops/softmax_attrs.dtg.cc similarity index 91% rename from lib/op-attrs/src/op-attrs/ops/softmax_attrs.cc rename to lib/op-attrs/src/op-attrs/ops/softmax_attrs.dtg.cc index 775a7d46a7..6b685b3de2 100644 --- a/lib/op-attrs/src/op-attrs/ops/softmax_attrs.cc +++ b/lib/op-attrs/src/op-attrs/ops/softmax_attrs.dtg.cc @@ -1,8 +1,16 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml +/* proj-data +{ + "generated_from": "9be043678a4ce7666fc372cded600290" +} +*/ + +#include "op-attrs/ops/softmax_attrs.dtg.h" -#include "op-attrs/ops/softmax_attrs.h" +#include "op-attrs/ff_dim.dtg.h" +#include namespace FlexFlow { SoftmaxAttrs::SoftmaxAttrs(::FlexFlow::ff_dim_t const &dim) : dim(dim) {} diff --git a/lib/op-attrs/src/op-attrs/ops/split.cc b/lib/op-attrs/src/op-attrs/ops/split.cc new file mode 100644 index 0000000000..bb3c35c645 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/split.cc @@ -0,0 +1,9 @@ +#include "op-attrs/ops/split.h" + +namespace FlexFlow { + +std::vector get_output_shapes(SplitAttrs const &attrs, ParallelTensorShape const &input_shape) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/split_attrs.cc b/lib/op-attrs/src/op-attrs/ops/split_attrs.dtg.cc similarity index 92% rename from lib/op-attrs/src/op-attrs/ops/split_attrs.cc rename to lib/op-attrs/src/op-attrs/ops/split_attrs.dtg.cc index 5761b7594f..8ca4518d17 100644 --- a/lib/op-attrs/src/op-attrs/ops/split_attrs.cc +++ b/lib/op-attrs/src/op-attrs/ops/split_attrs.dtg.cc @@ -1,8 +1,17 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml +/* proj-data +{ + "generated_from": "4112baa96de544b865618e0a999e0807" +} +*/ + +#include "op-attrs/ops/split_attrs.dtg.h" -#include "op-attrs/ops/split_attrs.h" +#include "op-attrs/ff_dim.dtg.h" +#include "utils/stack_vector.h" +#include namespace FlexFlow { SplitAttrs::SplitAttrs( diff --git a/lib/op-attrs/src/op-attrs/ops/topk.cc b/lib/op-attrs/src/op-attrs/ops/topk.cc new file mode 100644 index 0000000000..5e3607286d --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/topk.cc @@ -0,0 +1,9 @@ +#include "op-attrs/ops/topk.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(TopKAttrs const &attrs, ParallelTensorShape const &input_shape) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/topk_attrs.cc b/lib/op-attrs/src/op-attrs/ops/topk_attrs.dtg.cc similarity index 94% rename from lib/op-attrs/src/op-attrs/ops/topk_attrs.cc rename to lib/op-attrs/src/op-attrs/ops/topk_attrs.dtg.cc index 2105b6e716..55ead7d858 100644 --- a/lib/op-attrs/src/op-attrs/ops/topk_attrs.cc +++ b/lib/op-attrs/src/op-attrs/ops/topk_attrs.dtg.cc @@ -1,8 +1,15 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/topk_attrs.struct.toml +/* proj-data +{ + "generated_from": "c1be9dc2acafc58690713e650663cc93" +} +*/ + +#include "op-attrs/ops/topk_attrs.dtg.h" -#include "op-attrs/ops/topk_attrs.h" +#include namespace FlexFlow { TopKAttrs::TopKAttrs(int const &k, bool const &sorted) : k(k), sorted(sorted) {} diff --git a/lib/op-attrs/src/op-attrs/ops/transpose.cc b/lib/op-attrs/src/op-attrs/ops/transpose.cc new file mode 100644 index 0000000000..a8ce715f99 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/transpose.cc @@ -0,0 +1,9 @@ +#include "op-attrs/ops/transpose.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(TransposeAttrs const &op_attrs, ParallelTensorShape const &input_shape) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/transpose_attrs.cc b/lib/op-attrs/src/op-attrs/ops/transpose_attrs.dtg.cc similarity index 91% rename from lib/op-attrs/src/op-attrs/ops/transpose_attrs.cc rename to lib/op-attrs/src/op-attrs/ops/transpose_attrs.dtg.cc index e046896753..7463c6b3de 100644 --- a/lib/op-attrs/src/op-attrs/ops/transpose_attrs.cc +++ b/lib/op-attrs/src/op-attrs/ops/transpose_attrs.dtg.cc @@ -1,8 +1,17 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml +/* proj-data +{ + "generated_from": "edff0b414040204e895666d81b49db07" +} +*/ + +#include "op-attrs/ops/transpose_attrs.dtg.h" -#include "op-attrs/ops/transpose_attrs.h" +#include "op-attrs/ff_dim.dtg.h" +#include "utils/stack_vector.h" +#include namespace FlexFlow { TransposeAttrs::TransposeAttrs( diff --git a/lib/op-attrs/src/op-attrs/parallel_dim_t.cc b/lib/op-attrs/src/op-attrs/parallel_dim.dtg.cc similarity index 94% rename from lib/op-attrs/src/op-attrs/parallel_dim_t.cc rename to lib/op-attrs/src/op-attrs/parallel_dim.dtg.cc index a2fd7f686f..df88de73ff 100644 --- a/lib/op-attrs/src/op-attrs/parallel_dim_t.cc +++ b/lib/op-attrs/src/op-attrs/parallel_dim.dtg.cc @@ -1,8 +1,15 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify -// lib/op-attrs/include/op-attrs/parallel_dim_t.struct.toml +// lib/op-attrs/include/op-attrs/parallel_dim.struct.toml +/* proj-data +{ + "generated_from": "186bedde7826c7a3d00343ed63ab9971" +} +*/ + +#include "op-attrs/parallel_dim.dtg.h" -#include "op-attrs/parallel_dim_t.h" +#include namespace FlexFlow { ParallelDim::ParallelDim(size_t const &size, diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dims_t.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.dtg.cc similarity index 89% rename from lib/op-attrs/src/op-attrs/parallel_tensor_dims_t.cc rename to lib/op-attrs/src/op-attrs/parallel_tensor_dims.dtg.cc index b70a18f4e0..e4e8f0106a 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_dims_t.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.dtg.cc @@ -1,8 +1,17 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify -// lib/op-attrs/include/op-attrs/parallel_tensor_dims_t.struct.toml +// lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml +/* proj-data +{ + "generated_from": "b46ffa08758bdcc57a75183255248ca6" +} +*/ + +#include "op-attrs/parallel_tensor_dims.dtg.h" -#include "op-attrs/parallel_tensor_dims_t.h" +#include "op-attrs/dim_ordered.h" +#include "op-attrs/parallel_dim.h" +#include namespace FlexFlow { ParallelTensorDims::ParallelTensorDims( diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape_t.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.dtg.cc similarity index 90% rename from lib/op-attrs/src/op-attrs/parallel_tensor_shape_t.cc rename to lib/op-attrs/src/op-attrs/parallel_tensor_shape.dtg.cc index 27f14673db..037acbf996 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_shape_t.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.dtg.cc @@ -1,8 +1,17 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify -// lib/op-attrs/include/op-attrs/parallel_tensor_shape_t.struct.toml +// lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml +/* proj-data +{ + "generated_from": "b2d36c9212916e66569af4e958c893f4" +} +*/ + +#include "op-attrs/parallel_tensor_shape.dtg.h" -#include "op-attrs/parallel_tensor_shape_t.h" +#include "op-attrs/datatype.h" +#include "op-attrs/parallel_tensor_dims.h" +#include namespace FlexFlow { ParallelTensorShape::ParallelTensorShape( diff --git a/lib/op-attrs/src/op-attrs/param_sync.dtg.cc b/lib/op-attrs/src/op-attrs/param_sync.dtg.cc new file mode 100644 index 0000000000..e0d13fdd2e --- /dev/null +++ b/lib/op-attrs/src/op-attrs/param_sync.dtg.cc @@ -0,0 +1,70 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/param_sync.enum.toml +/* proj-data +{ + "generated_from": "288c6e9e256cf58ba5dbd0e3791c08df" +} +*/ + +#include "op-attrs/param_sync.dtg.h" + +#include +#include + +namespace std { +size_t hash::operator()(FlexFlow::ParamSync x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(ParamSync x) { + switch (x) { + case ParamSync::PS: + return "PS"; + case ParamSync::NCCL: + return "NCCL"; + default: + std::ostringstream oss; + oss << "Unknown ParamSync value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, ParamSync x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, ParamSync x) { + switch (x) { + case ParamSync::PS: + j = "PS"; + break; + case ParamSync::NCCL: + j = "NCCL"; + break; + default: + std::ostringstream oss; + oss << "Unknown ParamSync value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, ParamSync &x) { + std::string as_str = j.get(); + if (as_str == "PS") { + x = ParamSync::PS; + } else if (as_str == "NCCL") { + x = ParamSync::NCCL; + } else { + std::ostringstream oss; + oss << "Unknown ParamSync value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::element(FlexFlow::ParamSync::PS, + FlexFlow::ParamSync::NCCL); +} +} // namespace rc diff --git a/lib/op-attrs/src/op-attrs/pcg_operator_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/pcg_operator_attrs.dtg.cc new file mode 100644 index 0000000000..5d915ab437 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/pcg_operator_attrs.dtg.cc @@ -0,0 +1,476 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml +/* proj-data +{ + "generated_from": "e1b5c307ae023ce6d504f605c7ef8491" +} +*/ + +#include "op-attrs/pcg_operator_attrs.dtg.h" + +#include "fmt/format.h" +#include +#include + +namespace FlexFlow { +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::BatchMatmulAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::BatchNormAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::CastAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::ConcatAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::Conv2DAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::DropoutAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::ElementBinaryAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::ElementUnaryAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::ElementScalarUnaryAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::EmbeddingAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::FlatAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::GatherAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::InputAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::LayerNormAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::LinearAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::MultiHeadAttentionAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::NoopAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::Pool2DAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::ReduceAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::ReverseAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::ReshapeAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::SplitAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::SoftmaxAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::TopKAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::TransposeAttrs const &v) + : raw_variant(v) {} +bool PCGOperatorAttrs::operator==(PCGOperatorAttrs const &other) const { + return this->raw_variant == other.raw_variant; +} +bool PCGOperatorAttrs::operator!=(PCGOperatorAttrs const &other) const { + return this->raw_variant != other.raw_variant; +} +bool PCGOperatorAttrs::operator<(PCGOperatorAttrs const &other) const { + return this->raw_variant < other.raw_variant; +} +bool PCGOperatorAttrs::operator>(PCGOperatorAttrs const &other) const { + return this->raw_variant > other.raw_variant; +} +bool PCGOperatorAttrs::operator<=(PCGOperatorAttrs const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool PCGOperatorAttrs::operator>=(PCGOperatorAttrs const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::PCGOperatorAttrs>::operator()( + ::FlexFlow::PCGOperatorAttrs const &x) const { + return std::hash>{}(x.raw_variant); +} +} // namespace std +namespace nlohmann { +::FlexFlow::PCGOperatorAttrs + adl_serializer<::FlexFlow::PCGOperatorAttrs>::from_json(json const &j) { + std::string key = j.at("type").template get(); + if (key == "::FlexFlow::BatchMatmulAttrs") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::BatchMatmulAttrs>()}; + } else if (key == "::FlexFlow::BatchNormAttrs") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::BatchNormAttrs>()}; + } else if (key == "::FlexFlow::CastAttrs") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::CastAttrs>()}; + } else if (key == "::FlexFlow::ConcatAttrs") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::ConcatAttrs>()}; + } else if (key == "::FlexFlow::Conv2DAttrs") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::Conv2DAttrs>()}; + } else if (key == "::FlexFlow::DropoutAttrs") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::DropoutAttrs>()}; + } else if (key == "::FlexFlow::ElementBinaryAttrs") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::ElementBinaryAttrs>()}; + } else if (key == "::FlexFlow::ElementUnaryAttrs") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::ElementUnaryAttrs>()}; + } else if (key == "::FlexFlow::ElementScalarUnaryAttrs") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::ElementScalarUnaryAttrs>()}; + } else if (key == "::FlexFlow::EmbeddingAttrs") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::EmbeddingAttrs>()}; + } else if (key == "::FlexFlow::FlatAttrs") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::FlatAttrs>()}; + } else if (key == "::FlexFlow::GatherAttrs") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::GatherAttrs>()}; + } else if (key == "::FlexFlow::InputAttrs") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::InputAttrs>()}; + } else if (key == "::FlexFlow::LayerNormAttrs") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::LayerNormAttrs>()}; + } else if (key == "::FlexFlow::LinearAttrs") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::LinearAttrs>()}; + } else if (key == "::FlexFlow::MultiHeadAttentionAttrs") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::MultiHeadAttentionAttrs>()}; + } else if (key == "::FlexFlow::NoopAttrs") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::NoopAttrs>()}; + } else if (key == "::FlexFlow::Pool2DAttrs") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::Pool2DAttrs>()}; + } else if (key == "::FlexFlow::ReduceAttrs") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::ReduceAttrs>()}; + } else if (key == "::FlexFlow::ReverseAttrs") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::ReverseAttrs>()}; + } else if (key == "::FlexFlow::ReshapeAttrs") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::ReshapeAttrs>()}; + } else if (key == "::FlexFlow::SplitAttrs") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::SplitAttrs>()}; + } else if (key == "::FlexFlow::SoftmaxAttrs") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::SoftmaxAttrs>()}; + } else if (key == "::FlexFlow::TopKAttrs") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::TopKAttrs>()}; + } else if (key == "::FlexFlow::TransposeAttrs") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::TransposeAttrs>()}; + } else { + throw std::runtime_error(fmt::format("Unknown type key {}", key)); + } +} +void adl_serializer<::FlexFlow::PCGOperatorAttrs>::to_json( + json &j, ::FlexFlow::PCGOperatorAttrs const &x) { + j["__type"] = "PCGOperatorAttrs"; + switch (x.index()) { + case 0: { + j["type"] = "::FlexFlow::BatchMatmulAttrs"; + j["value"] = x.get<::FlexFlow::BatchMatmulAttrs>(); + break; + } + case 1: { + j["type"] = "::FlexFlow::BatchNormAttrs"; + j["value"] = x.get<::FlexFlow::BatchNormAttrs>(); + break; + } + case 2: { + j["type"] = "::FlexFlow::CastAttrs"; + j["value"] = x.get<::FlexFlow::CastAttrs>(); + break; + } + case 3: { + j["type"] = "::FlexFlow::ConcatAttrs"; + j["value"] = x.get<::FlexFlow::ConcatAttrs>(); + break; + } + case 4: { + j["type"] = "::FlexFlow::Conv2DAttrs"; + j["value"] = x.get<::FlexFlow::Conv2DAttrs>(); + break; + } + case 5: { + j["type"] = "::FlexFlow::DropoutAttrs"; + j["value"] = x.get<::FlexFlow::DropoutAttrs>(); + break; + } + case 6: { + j["type"] = "::FlexFlow::ElementBinaryAttrs"; + j["value"] = x.get<::FlexFlow::ElementBinaryAttrs>(); + break; + } + case 7: { + j["type"] = "::FlexFlow::ElementUnaryAttrs"; + j["value"] = x.get<::FlexFlow::ElementUnaryAttrs>(); + break; + } + case 8: { + j["type"] = "::FlexFlow::ElementScalarUnaryAttrs"; + j["value"] = x.get<::FlexFlow::ElementScalarUnaryAttrs>(); + break; + } + case 9: { + j["type"] = "::FlexFlow::EmbeddingAttrs"; + j["value"] = x.get<::FlexFlow::EmbeddingAttrs>(); + break; + } + case 10: { + j["type"] = "::FlexFlow::FlatAttrs"; + j["value"] = x.get<::FlexFlow::FlatAttrs>(); + break; + } + case 11: { + j["type"] = "::FlexFlow::GatherAttrs"; + j["value"] = x.get<::FlexFlow::GatherAttrs>(); + break; + } + case 12: { + j["type"] = "::FlexFlow::InputAttrs"; + j["value"] = x.get<::FlexFlow::InputAttrs>(); + break; + } + case 13: { + j["type"] = "::FlexFlow::LayerNormAttrs"; + j["value"] = x.get<::FlexFlow::LayerNormAttrs>(); + break; + } + case 14: { + j["type"] = "::FlexFlow::LinearAttrs"; + j["value"] = x.get<::FlexFlow::LinearAttrs>(); + break; + } + case 15: { + j["type"] = "::FlexFlow::MultiHeadAttentionAttrs"; + j["value"] = x.get<::FlexFlow::MultiHeadAttentionAttrs>(); + break; + } + case 16: { + j["type"] = "::FlexFlow::NoopAttrs"; + j["value"] = x.get<::FlexFlow::NoopAttrs>(); + break; + } + case 17: { + j["type"] = "::FlexFlow::Pool2DAttrs"; + j["value"] = x.get<::FlexFlow::Pool2DAttrs>(); + break; + } + case 18: { + j["type"] = "::FlexFlow::ReduceAttrs"; + j["value"] = x.get<::FlexFlow::ReduceAttrs>(); + break; + } + case 19: { + j["type"] = "::FlexFlow::ReverseAttrs"; + j["value"] = x.get<::FlexFlow::ReverseAttrs>(); + break; + } + case 20: { + j["type"] = "::FlexFlow::ReshapeAttrs"; + j["value"] = x.get<::FlexFlow::ReshapeAttrs>(); + break; + } + case 21: { + j["type"] = "::FlexFlow::SplitAttrs"; + j["value"] = x.get<::FlexFlow::SplitAttrs>(); + break; + } + case 22: { + j["type"] = "::FlexFlow::SoftmaxAttrs"; + j["value"] = x.get<::FlexFlow::SoftmaxAttrs>(); + break; + } + case 23: { + j["type"] = "::FlexFlow::TopKAttrs"; + j["value"] = x.get<::FlexFlow::TopKAttrs>(); + break; + } + case 24: { + j["type"] = "::FlexFlow::TransposeAttrs"; + j["value"] = x.get<::FlexFlow::TransposeAttrs>(); + break; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type PCGOperatorAttrs", x.index())); + } + } +} +} // namespace nlohmann +namespace FlexFlow { +std::string format_as(::FlexFlow::PCGOperatorAttrs const &x) { + std::ostringstream oss; + switch (x.index()) { + case 0: { + oss << ""; + break; + } + case 1: { + oss << ""; + break; + } + case 2: { + oss << ""; + break; + } + case 3: { + oss << ""; + break; + } + case 4: { + oss << ""; + break; + } + case 5: { + oss << ""; + break; + } + case 6: { + oss << ""; + break; + } + case 7: { + oss << ""; + break; + } + case 8: { + oss << ""; + break; + } + case 9: { + oss << ""; + break; + } + case 10: { + oss << ""; + break; + } + case 11: { + oss << ""; + break; + } + case 12: { + oss << ""; + break; + } + case 13: { + oss << ""; + break; + } + case 14: { + oss << ""; + break; + } + case 15: { + oss << ""; + break; + } + case 16: { + oss << ""; + break; + } + case 17: { + oss << ""; + break; + } + case 18: { + oss << ""; + break; + } + case 19: { + oss << ""; + break; + } + case 20: { + oss << ""; + break; + } + case 21: { + oss << ""; + break; + } + case 22: { + oss << ""; + break; + } + case 23: { + oss << ""; + break; + } + case 24: { + oss << ""; + break; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type PCGOperatorAttrs", x.index())); + break; + } + } + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::PCGOperatorAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/pool_op.cc b/lib/op-attrs/src/op-attrs/pool_op.dtg.cc similarity index 94% rename from lib/op-attrs/src/op-attrs/pool_op.cc rename to lib/op-attrs/src/op-attrs/pool_op.dtg.cc index dd0e1bfd00..08a6f43943 100644 --- a/lib/op-attrs/src/op-attrs/pool_op.cc +++ b/lib/op-attrs/src/op-attrs/pool_op.dtg.cc @@ -1,8 +1,13 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify // lib/op-attrs/include/op-attrs/pool_op.enum.toml +/* proj-data +{ + "generated_from": "ed1d531c6227306c909eb28eb0a66538" +} +*/ -#include "op-attrs/pool_op.h" +#include "op-attrs/pool_op.dtg.h" #include #include diff --git a/lib/op-attrs/src/op-attrs/regularizer_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/regularizer_attrs.dtg.cc new file mode 100644 index 0000000000..31a06cb19f --- /dev/null +++ b/lib/op-attrs/src/op-attrs/regularizer_attrs.dtg.cc @@ -0,0 +1,109 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/regularizer_attrs.variant.toml +/* proj-data +{ + "generated_from": "b0cb2d264215faf9759925c631f3d55f" +} +*/ + +#include "op-attrs/regularizer_attrs.dtg.h" + +#include "fmt/format.h" +#include +#include + +namespace FlexFlow { +RegularizerAttrs::RegularizerAttrs(::FlexFlow::L1RegularizerAttrs const &v) + : raw_variant(v) {} +RegularizerAttrs::RegularizerAttrs(::FlexFlow::L2RegularizerAttrs const &v) + : raw_variant(v) {} +bool RegularizerAttrs::operator==(RegularizerAttrs const &other) const { + return this->raw_variant == other.raw_variant; +} +bool RegularizerAttrs::operator!=(RegularizerAttrs const &other) const { + return this->raw_variant != other.raw_variant; +} +bool RegularizerAttrs::operator<(RegularizerAttrs const &other) const { + return this->raw_variant < other.raw_variant; +} +bool RegularizerAttrs::operator>(RegularizerAttrs const &other) const { + return this->raw_variant > other.raw_variant; +} +bool RegularizerAttrs::operator<=(RegularizerAttrs const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool RegularizerAttrs::operator>=(RegularizerAttrs const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::RegularizerAttrs>::operator()( + ::FlexFlow::RegularizerAttrs const &x) const { + return std::hash>{}( + x.raw_variant); +} +} // namespace std +namespace nlohmann { +::FlexFlow::RegularizerAttrs + adl_serializer<::FlexFlow::RegularizerAttrs>::from_json(json const &j) { + std::string key = j.at("type").template get(); + if (key == "l1") { + return ::FlexFlow::RegularizerAttrs{ + j.at("value").template get<::FlexFlow::L1RegularizerAttrs>()}; + } else if (key == "l2") { + return ::FlexFlow::RegularizerAttrs{ + j.at("value").template get<::FlexFlow::L2RegularizerAttrs>()}; + } else { + throw std::runtime_error(fmt::format("Unknown type key {}", key)); + } +} +void adl_serializer<::FlexFlow::RegularizerAttrs>::to_json( + json &j, ::FlexFlow::RegularizerAttrs const &x) { + j["__type"] = "RegularizerAttrs"; + switch (x.index()) { + case 0: { + j["type"] = "l1"; + j["value"] = x.get<::FlexFlow::L1RegularizerAttrs>(); + break; + } + case 1: { + j["type"] = "l2"; + j["value"] = x.get<::FlexFlow::L2RegularizerAttrs>(); + break; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type RegularizerAttrs", x.index())); + } + } +} +} // namespace nlohmann +namespace FlexFlow { +std::string format_as(::FlexFlow::RegularizerAttrs const &x) { + std::ostringstream oss; + switch (x.index()) { + case 0: { + oss << ""; + break; + } + case 1: { + oss << ""; + break; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type RegularizerAttrs", x.index())); + break; + } + } + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::RegularizerAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/tensor_dims_t.cc b/lib/op-attrs/src/op-attrs/tensor_dims.dtg.cc similarity index 90% rename from lib/op-attrs/src/op-attrs/tensor_dims_t.cc rename to lib/op-attrs/src/op-attrs/tensor_dims.dtg.cc index 087048b965..f2a7367f1d 100644 --- a/lib/op-attrs/src/op-attrs/tensor_dims_t.cc +++ b/lib/op-attrs/src/op-attrs/tensor_dims.dtg.cc @@ -1,8 +1,16 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify -// lib/op-attrs/include/op-attrs/tensor_dims_t.struct.toml +// lib/op-attrs/include/op-attrs/tensor_dims.struct.toml +/* proj-data +{ + "generated_from": "f925a4c2343d2404116dc598c301beaf" +} +*/ + +#include "op-attrs/tensor_dims.dtg.h" -#include "op-attrs/tensor_dims_t.h" +#include "op-attrs/dim_ordered.h" +#include namespace FlexFlow { TensorDims::TensorDims(::FlexFlow::FFOrdered const &ff_ordered) diff --git a/lib/op-attrs/src/op-attrs/tensor_shape_t.cc b/lib/op-attrs/src/op-attrs/tensor_shape.dtg.cc similarity index 90% rename from lib/op-attrs/src/op-attrs/tensor_shape_t.cc rename to lib/op-attrs/src/op-attrs/tensor_shape.dtg.cc index 2cea614524..1538cc82c1 100644 --- a/lib/op-attrs/src/op-attrs/tensor_shape_t.cc +++ b/lib/op-attrs/src/op-attrs/tensor_shape.dtg.cc @@ -1,8 +1,17 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify -// lib/op-attrs/include/op-attrs/tensor_shape_t.struct.toml +// lib/op-attrs/include/op-attrs/tensor_shape.struct.toml +/* proj-data +{ + "generated_from": "c02c9d2331d864a25c1443cfe70062d1" +} +*/ + +#include "op-attrs/tensor_shape.dtg.h" -#include "op-attrs/tensor_shape_t.h" +#include "op-attrs/datatype.h" +#include "op-attrs/tensor_dims.h" +#include namespace FlexFlow { TensorShape::TensorShape(::FlexFlow::TensorDims const &dims, diff --git a/lib/op-attrs/src/operator_attrs.cc b/lib/op-attrs/src/operator_attrs.cc index a524ab3d14..7a0027fe61 100644 --- a/lib/op-attrs/src/operator_attrs.cc +++ b/lib/op-attrs/src/operator_attrs.cc @@ -193,7 +193,7 @@ struct IsValidFunctor { bool is_valid(PCGOperatorAttrs const &attrs, std::vector const &input_shapes) { - return visit(IsValidFunctor{input_shapes}, attrs); + NOT_IMPLEMENTED(); } /* int num_outputs(OperatorParameters const &o) { */ diff --git a/lib/op-attrs/test/CMakeLists.txt b/lib/op-attrs/test/CMakeLists.txt new file mode 100644 index 0000000000..b6ff72fc00 --- /dev/null +++ b/lib/op-attrs/test/CMakeLists.txt @@ -0,0 +1,13 @@ +ff_add_test_executable( + NAME + op-attrs-tests + SRC_PATTERNS + src/*.cc + PRIVATE_INCLUDE + src/ + DEPS + utils + op-attrs + doctest + utils-test-common +) diff --git a/lib/op-attrs/test/src/test_operator_attrs.cc b/lib/op-attrs/test/src/test_operator_attrs.cc new file mode 100644 index 0000000000..188c9d1607 --- /dev/null +++ b/lib/op-attrs/test/src/test_operator_attrs.cc @@ -0,0 +1,33 @@ +#include "doctest/doctest.h" +#include "op-attrs/operator_attrs.h" +#include "utils/json.h" +#include +#include + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("BatchNormAttrs to/from json") { + BatchNormAttrs correct = BatchNormAttrs{true}; + json j = correct; + auto result = j.get(); + CHECK(result == correct); + } + + TEST_CASE("ComputationGraphAttrs to/from json") { + ComputationGraphAttrs correct = BatchNormAttrs{true}; + json j = correct; + auto result = j.get(); + + CHECK(result == correct); + } + + TEST_CASE("PCGOperatorAttrs to/from json") { + PCGOperatorAttrs correct = RepartitionAttrs{ + /*repartition_dim=*/ff_dim_t{1}, + /*repartition_degree=*/4, + }; + json j = correct; + auto result = j.get(); + + CHECK(result == correct); + } +} diff --git a/lib/pcg/include/pcg/computation_graph.dtg.h b/lib/pcg/include/pcg/computation_graph.dtg.h new file mode 100644 index 0000000000..c5a74b08d5 --- /dev/null +++ b/lib/pcg/include/pcg/computation_graph.dtg.h @@ -0,0 +1,31 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/computation_graph.struct.toml +/* proj-data +{ + "generated_from": "3639f7e8bb97a5ca2c2ef13caff3c84e" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_DTG_H + +#include "pcg/layer_attrs.dtg.h" +#include "pcg/tensor_attrs.dtg.h" +#include "utils/graph.h" + +namespace FlexFlow { +struct ComputationGraph { + ComputationGraph() = delete; + ComputationGraph( + ::FlexFlow::OutputLabelledMultiDiGraph<::FlexFlow::LayerAttrs, + ::FlexFlow::TensorAttrs> const + &raw_graph); + + ::FlexFlow::OutputLabelledMultiDiGraph<::FlexFlow::LayerAttrs, + ::FlexFlow::TensorAttrs> + raw_graph; +}; +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_DTG_H diff --git a/lib/pcg/include/pcg/computation_graph.h b/lib/pcg/include/pcg/computation_graph.h index 11dad70356..a937a9d46e 100644 --- a/lib/pcg/include/pcg/computation_graph.h +++ b/lib/pcg/include/pcg/computation_graph.h @@ -1,25 +1,14 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_H #define _FLEXFLOW_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_H -#include "layer.h" -#include "operator_guid_t.h" -#include "tensor.h" -#include "utils/graph.h" -#include "utils/strong_typedef.h" -#include "visit_struct/visit_struct.hpp" +#include "pcg/computation_graph.dtg.h" +#include "pcg/tensor_guid_t.dtg.h" +#include "pcg/tensor_attrs.dtg.h" namespace FlexFlow { -struct ComputationGraph - : public strong_typedef> { - using strong_typedef::strong_typedef; -}; +TensorAttrs get_tensor_attrs(ComputationGraph const &, tensor_guid_t const &); } // namespace FlexFlow -namespace FlexFlow { -static_assert(is_well_behaved_value_type_no_hash::value, ""); -} - #endif diff --git a/lib/pcg/include/pcg/computation_graph.struct.toml b/lib/pcg/include/pcg/computation_graph.struct.toml new file mode 100644 index 0000000000..30b3487da1 --- /dev/null +++ b/lib/pcg/include/pcg/computation_graph.struct.toml @@ -0,0 +1,13 @@ +namespace = "FlexFlow" +name = "ComputationGraph" +features = [ ] + +includes = [ + "utils/graph.h", + "pcg/layer_attrs.dtg.h", + "pcg/tensor_attrs.dtg.h" +] + +[[fields]] +name = "raw_graph" +type = "::FlexFlow::OutputLabelledMultiDiGraph<::FlexFlow::LayerAttrs, ::FlexFlow::TensorAttrs>" diff --git a/lib/pcg/include/pcg/computation_graph_builder.h b/lib/pcg/include/pcg/computation_graph_builder.h index 035f0cad0b..2721743dcd 100644 --- a/lib/pcg/include/pcg/computation_graph_builder.h +++ b/lib/pcg/include/pcg/computation_graph_builder.h @@ -1,85 +1,86 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_BUILDER_H #define _FLEXFLOW_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_BUILDER_H -#include "computation_graph.h" +#include "pcg/computation_graph.dtg.h" +#include "pcg/tensor_guid_t.dtg.h" +#include "pcg/initializer_attrs.dtg.h" namespace FlexFlow { -struct ComputationGraphBuilder - : public use_visitable_cmp { +struct ComputationGraphBuilder { public: ComputationGraphBuilder(); // C++ APIs for constructing models // Add an exp layer - Tensor exp(Tensor const &, + tensor_guid_t exp(tensor_guid_t const &, std::optional const &name = std::nullopt); // Add an add layer - Tensor add(Tensor const &x, - Tensor const &y, + tensor_guid_t add(tensor_guid_t const &x, + tensor_guid_t const &y, std::optional const &name = std::nullopt); // Add a subtract layer - Tensor subtract(Tensor const &x, - Tensor const &y, + tensor_guid_t subtract(tensor_guid_t const &x, + tensor_guid_t const &y, std::optional const &name = std::nullopt); // Add a multiply layer - Tensor multiply(Tensor const &x, - Tensor const &y, + tensor_guid_t multiply(tensor_guid_t const &x, + tensor_guid_t const &y, std::optional const &name = std::nullopt); // Add a divide layer - Tensor divide(Tensor const &x, - Tensor const &y, + tensor_guid_t divide(tensor_guid_t const &x, + tensor_guid_t const &y, std::optional const &name = std::nullopt); // Add a max layer - Tensor max(Tensor const &x, - Tensor const &y, + tensor_guid_t max(tensor_guid_t const &x, + tensor_guid_t const &y, std::optional const &name = std::nullopt); // Add a min layer - Tensor min(Tensor const &x, - Tensor const &y, + tensor_guid_t min(tensor_guid_t const &x, + tensor_guid_t const &y, std::optional const &name = std::nullopt); // Add a rsqrt layer - Tensor rsqrt(Tensor const &x, + tensor_guid_t rsqrt(tensor_guid_t const &x, std::optional const &name = std::nullopt); // Add a pow layer - Tensor pow(Tensor const &x, + tensor_guid_t pow(tensor_guid_t const &x, float exponent, std::optional const &name = std::nullopt); // Add a scalar multiply layer - Tensor scalar_multiply(Tensor const &x, + tensor_guid_t scalar_multiply(tensor_guid_t const &x, float scalar, std::optional const &name = std::nullopt); - Tensor scalar_add(Tensor const &x, + tensor_guid_t scalar_add(tensor_guid_t const &x, float scalar, std::optional const &name = std::nullopt); - Tensor scalar_sub(Tensor const &lhs, + tensor_guid_t scalar_sub(tensor_guid_t const &lhs, float rhs, std::optional const &name = std::nullopt); - Tensor scalar_truediv(Tensor const &numerator, + tensor_guid_t scalar_truediv(tensor_guid_t const &numerator, float denominator, std::optional const &name = std::nullopt); // Add a sin layer - Tensor sin(Tensor const &x, + tensor_guid_t sin(tensor_guid_t const &x, std::optional const &name = std::nullopt); // Add a cos layer - Tensor cos(Tensor const &x, + tensor_guid_t cos(tensor_guid_t const &x, std::optional const &name = std::nullopt); // Add an activation layer - Tensor relu(Tensor const &x, + tensor_guid_t relu(tensor_guid_t const &x, std::optional const &name = std::nullopt); - Tensor identity(Tensor const &x, + tensor_guid_t identity(tensor_guid_t const &x, std::optional const &name = std::nullopt); - Tensor gelu(Tensor const &x, + tensor_guid_t gelu(tensor_guid_t const &x, std::optional const &name = std::nullopt); - Tensor sigmoid(Tensor const &x, + tensor_guid_t sigmoid(tensor_guid_t const &x, std::optional const &name = std::nullopt); - Tensor tanh(Tensor const &x, + tensor_guid_t tanh(tensor_guid_t const &x, std::optional const &name = std::nullopt); - Tensor elu(Tensor const &x, + tensor_guid_t elu(tensor_guid_t const &x, std::optional const &name = std::nullopt); // Add a 2D convolutional layer - Tensor conv2d( - Tensor const &input, + tensor_guid_t conv2d( + tensor_guid_t const &input, int outChannels, int kernelH, int kernelW, @@ -90,38 +91,38 @@ struct ComputationGraphBuilder std::optional const &activation = std::nullopt, int groups = 1, bool use_bias = true, - std::optional const &kernel_initializer = std::nullopt, - std::optional const &bias_initializer = std::nullopt, + std::optional const &kernel_initializer = std::nullopt, + std::optional const &bias_initializer = std::nullopt, std::optional const &kernel_regularizer = std::nullopt, std::optional const &name = std::nullopt); // Add a dropout layer - Tensor dropout(Tensor const &input, + tensor_guid_t dropout(tensor_guid_t const &input, float rate, unsigned long long seed = 0, std::optional const &name = std::nullopt); // Add an embedding layer - Tensor embedding( - Tensor const &input, + tensor_guid_t embedding( + tensor_guid_t const &input, int num_entries, int outDim, AggregateOp aggr, DataType dtype = DataType::FLOAT, - std::optional const &kernel_initializer = std::nullopt, + std::optional const &kernel_initializer = std::nullopt, std::optional const &name = std::nullopt); // Add a gather layer - std::vector - gather(Tensor const &input, - Tensor const &index, + std::vector + gather(tensor_guid_t const &input, + tensor_guid_t const &index, ff_dim_t dim, std::optional const &name = std::nullopt); // Add a cache layer - Tensor cache(Tensor const &input, + tensor_guid_t cache(tensor_guid_t const &input, int num_batches, std::function score_f = {}, std::optional const &name = std::nullopt); // Add a 2D pooling layer - Tensor pool2d(Tensor const &input, + tensor_guid_t pool2d(tensor_guid_t const &input, int kernelH, int kernelW, int strideH, @@ -131,78 +132,78 @@ struct ComputationGraphBuilder PoolOp type = PoolOp::MAX, std::optional const &activation = std::nullopt, std::optional const &name = std::nullopt); - Tensor layer_norm(Tensor const &input, + tensor_guid_t layer_norm(tensor_guid_t const &input, std::vector const &axes, bool elementwise_affine, float eps, std::optional const &name = std::nullopt); - Tensor batch_norm(Tensor const &input, + tensor_guid_t batch_norm(tensor_guid_t const &input, bool relu = true, std::optional const &name = std::nullopt); - Tensor batch_matmul(Tensor const &A, - Tensor const &B, + tensor_guid_t batch_matmul(tensor_guid_t const &A, + tensor_guid_t const &B, int a_seq_length_dim = -1, int b_seq_length_dim = -1, std::optional const &name = std::nullopt); - Tensor - dense(Tensor const &input, + tensor_guid_t + dense(tensor_guid_t const &input, int outDim, std::optional activation = std::nullopt, bool use_bias = true, DataType data_type = DataType::FLOAT, - std::optional const &kernel_initializer = std::nullopt, - std::optional const &bias_initializer = std::nullopt, + std::optional const &kernel_initializer = std::nullopt, + std::optional const &bias_initializer = std::nullopt, std::optional const &name = std::nullopt); // Add a cast layer - Tensor cast(Tensor const &input, + tensor_guid_t cast(tensor_guid_t const &input, DataType dtype, std::optional const &name = std::nullopt); // Add a concat layer - Tensor concat(int n, - std::vector const &tensors, + tensor_guid_t concat(int n, + std::vector const &tensors, int axis, std::optional const &name = std::nullopt); // Add a mean layer - Tensor mean(Tensor const &input, + tensor_guid_t mean(tensor_guid_t const &input, std::vector const &dims, bool keepdims, char const *name); // Add a split layer - void split(Tensor const &input, - Tensor *outputs, + void split(tensor_guid_t const &input, + tensor_guid_t *outputs, std::vector const &split, int axis, std::optional const &name = std::nullopt); // Add a flat layer - Tensor flat(Tensor const &input, + tensor_guid_t flat(tensor_guid_t const &input, std::optional const &name = std::nullopt); // Add a softmax layer - Tensor softmax(Tensor const &input, + tensor_guid_t softmax(tensor_guid_t const &input, int dim = -1, std::optional const &name = std::nullopt); // Create input tensors and constants - Tensor transpose(Tensor const &input, + tensor_guid_t transpose(tensor_guid_t const &input, std::vector const &perm, std::optional const &name = std::nullopt); - Tensor reduce_sum(Tensor const &input, + tensor_guid_t reduce_sum(tensor_guid_t const &input, std::vector const &axes, bool keepdims = false, std::optional const &name = std::nullopt); - Tensor reshape(Tensor const &input, + tensor_guid_t reshape(tensor_guid_t const &input, std::vector const &shape, std::optional const &name = std::nullopt); - Tensor reverse(Tensor const &input, + tensor_guid_t reverse(tensor_guid_t const &input, int axis, std::optional const &name = std::nullopt); - void top_k(Tensor const &input, - Tensor *outputs, + void top_k(tensor_guid_t const &input, + tensor_guid_t *outputs, int k, bool sorted, std::optional const &name = std::nullopt); - Tensor multihead_attention( - Tensor const &query, - Tensor const &key, - Tensor const &value, + tensor_guid_t multihead_attention( + tensor_guid_t const &query, + tensor_guid_t const &key, + tensor_guid_t const &value, int embed_dim, int num_heads, int kdim = 0, @@ -211,63 +212,66 @@ struct ComputationGraphBuilder bool bias = true, bool add_bias_kv = false, bool add_zero_attn = false, - std::optional initializer = std::nullopt, + std::optional initializer = std::nullopt, std::optional const &name = std::nullopt); - Tensor create_tensor(TensorShape const &, bool create_grad = true); - Parameter create_weight( + tensor_guid_t create_tensor(TensorShape const &, bool create_grad = true); + tensor_guid_t create_weight( TensorShape const &, bool create_grad = true, - std::optional const &initializer = std::nullopt, + std::optional const &initializer = std::nullopt, std::optional sync_type = std::nullopt); - std::vector get_outputs(Layer const &) const; - Tensor get_output(Layer const &, int idx) const; + std::vector get_outputs(LayerAttrs const &) const; + tensor_guid_t get_output(LayerAttrs const &, int idx) const; - Tensor at(MultiDiEdge const &) const; - Layer at(Node const &) const; + tensor_guid_t at(MultiDiEdge const &) const; + LayerAttrs at(Node const &) const; + TensorAttrs get_attrs(tensor_guid_t const &) const; + TensorShape get_shape(tensor_guid_t const &) const; private: - Tensor broadcast(Tensor const &, TensorShape const &); + tensor_guid_t broadcast(tensor_guid_t const &, TensorShape const &); - void add_layer(Layer const &layer, - std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs); - Tensor add_layer( - Layer const &layer, - std::vector const &inputs, - std::vector>> const + void add_layer(LayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + std::vector const &outputs); + tensor_guid_t add_layer( + LayerAttrs const &layer, + std::vector const &inputs, + std::vector>> const &weight_shapes, TensorShape const &output_shape); - std::vector add_layer( - Layer const &layer, - std::vector const &inputs, - std::vector>> const + std::vector add_layer( + LayerAttrs const &layer, + std::vector const &inputs, + std::vector>> const &weight_shapes, std::vector const &output_shapes); - Tensor as_type(Tensor const &, DataType, std::string const &); + tensor_guid_t as_type(tensor_guid_t const &, DataType, std::string const &); + TensorShape get_broadcast_target_shape(std::vector const &); TensorShape get_broadcast_target_shape(std::vector const &); - Tensor element_binary(OperatorType, - Tensor const &lhs, - Tensor const &rhs, + tensor_guid_t element_binary(OperatorType, + tensor_guid_t const &lhs, + tensor_guid_t const &rhs, std::optional const &name = std::nullopt); - Tensor element_unary(OperatorType, - Tensor const &input, + tensor_guid_t element_unary(OperatorType, + tensor_guid_t const &input, std::optional const &name = std::nullopt); - Tensor element_scalar_unary( + tensor_guid_t element_scalar_unary( OperatorType, - Tensor const &input, + tensor_guid_t const &input, float scalar, std::optional const &name = std::nullopt); - Tensor element_unary(ElementUnaryAttrs const &, - Tensor const &input, + tensor_guid_t element_unary(ElementUnaryAttrs const &, + tensor_guid_t const &input, std::optional const &name = std::nullopt); - Tensor element_scalar_unary(ElementScalarUnaryAttrs const &attrs, - Tensor const &x, + tensor_guid_t element_scalar_unary(ElementScalarUnaryAttrs const &attrs, + tensor_guid_t const &x, std::optional const &maybe_name); public: @@ -276,11 +280,4 @@ struct ComputationGraphBuilder } // namespace FlexFlow -VISITABLE_STRUCT(::FlexFlow::ComputationGraphBuilder, computation_graph); - -namespace FlexFlow { -static_assert( - is_well_behaved_value_type_no_hash::value, ""); -} - #endif diff --git a/lib/pcg/include/pcg/cpu_id_t.dtg.h b/lib/pcg/include/pcg/cpu_id_t.dtg.h new file mode 100644 index 0000000000..a6c81e80b0 --- /dev/null +++ b/lib/pcg/include/pcg/cpu_id_t.dtg.h @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/cpu_id_t.struct.toml +/* proj-data +{ + "generated_from": "a0faf78831febfa3a02929169943d9f5" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_CPU_ID_T_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_CPU_ID_T_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct cpu_id_t { + cpu_id_t() = delete; + cpu_id_t(int const &cpu_index); + + bool operator==(cpu_id_t const &) const; + bool operator!=(cpu_id_t const &) const; + bool operator<(cpu_id_t const &) const; + bool operator>(cpu_id_t const &) const; + bool operator<=(cpu_id_t const &) const; + bool operator>=(cpu_id_t const &) const; + int cpu_index; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::cpu_id_t const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::cpu_id_t from_json(json const &); + static void to_json(json &, FlexFlow::cpu_id_t const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(cpu_id_t const &); +std::ostream &operator<<(std::ostream &, cpu_id_t const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_CPU_ID_T_DTG_H diff --git a/lib/pcg/include/pcg/cpu_id_t.struct.toml b/lib/pcg/include/pcg/cpu_id_t.struct.toml new file mode 100644 index 0000000000..0492a937be --- /dev/null +++ b/lib/pcg/include/pcg/cpu_id_t.struct.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "cpu_id_t" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "cpu_index" +type = "int" diff --git a/lib/pcg/include/pcg/create_grad.dtg.h b/lib/pcg/include/pcg/create_grad.dtg.h new file mode 100644 index 0000000000..494ff06b75 --- /dev/null +++ b/lib/pcg/include/pcg/create_grad.dtg.h @@ -0,0 +1,40 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/create_grad.enum.toml +/* proj-data +{ + "generated_from": "9fd617027e850b6d6db476a49b3e0334" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_CREATE_GRAD_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_CREATE_GRAD_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +enum class CreateGrad { YES, NO }; +std::string format_as(CreateGrad); +std::ostream &operator<<(std::ostream &, CreateGrad); +void to_json(::nlohmann::json &, CreateGrad); +void from_json(::nlohmann::json const &, CreateGrad &); +} // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::CreateGrad) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_CREATE_GRAD_DTG_H diff --git a/lib/pcg/include/pcg/create_grad.enum.toml b/lib/pcg/include/pcg/create_grad.enum.toml new file mode 100644 index 0000000000..20febe49fb --- /dev/null +++ b/lib/pcg/include/pcg/create_grad.enum.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "CreateGrad" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "YES" + +[[values]] +name = "NO" diff --git a/lib/pcg/include/pcg/create_grad.h b/lib/pcg/include/pcg/create_grad.h index 7dd843b76d..26ba88f1b2 100644 --- a/lib/pcg/include/pcg/create_grad.h +++ b/lib/pcg/include/pcg/create_grad.h @@ -1,36 +1,10 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_CREATE_GRAD_H #define _FLEXFLOW_PCG_INCLUDE_PCG_CREATE_GRAD_H -#include "utils/fmt.h" +#include "pcg/create_grad_t.h" namespace FlexFlow { -enum class CreateGrad { YES, NO }; - } -namespace fmt { - -template <> -struct formatter<::FlexFlow::CreateGrad> : formatter { - template - auto format(::FlexFlow::CreateGrad ps, FormatContext &ctx) const - -> decltype(ctx.out()) { - using namespace FlexFlow; - - string_view name = "unknown"; - switch (ps) { - case CreateGrad::YES: - name = "yes"; - break; - case CreateGrad::NO: - name = "no"; - break; - } - return formatter::format(name, ctx); - } -}; - -} // namespace fmt - #endif diff --git a/lib/pcg/include/pcg/device_id.h b/lib/pcg/include/pcg/device_id.h index b118d69259..9c38674e82 100644 --- a/lib/pcg/include/pcg/device_id.h +++ b/lib/pcg/include/pcg/device_id.h @@ -1,35 +1,21 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_DEVICE_ID_H #define _FLEXFLOW_PCG_INCLUDE_PCG_DEVICE_ID_H -#include "device_type.h" -#include "utils/strong_typedef.h" -#include +#include "pcg/device_type.dtg.h" +#include "pcg/cpu_id_t.dtg.h" +#include "pcg/gpu_id_t.dtg.h" +#include "pcg/device_id_t.dtg.h" namespace FlexFlow { -struct gpu_id_t : strong_typedef { - using strong_typedef::strong_typedef; -}; - -struct cpu_id_t : strong_typedef { - using strong_typedef::strong_typedef; -}; - -using device_id_t = std::variant; device_id_t operator+(device_id_t, size_t); DeviceType get_device_type(device_id_t); gpu_id_t unwrap_gpu(device_id_t); cpu_id_t unwrap_cpu(device_id_t); -device_id_t from_index(int, DeviceType); +device_id_t device_id_from_index(int, DeviceType); } // namespace FlexFlow -MAKE_TYPEDEF_HASHABLE(::FlexFlow::gpu_id_t); -MAKE_TYPEDEF_PRINTABLE(::FlexFlow::gpu_id_t, "gpu_id"); - -MAKE_TYPEDEF_HASHABLE(::FlexFlow::cpu_id_t); -MAKE_TYPEDEF_PRINTABLE(::FlexFlow::cpu_id_t, "cpu_id"); - #endif diff --git a/lib/pcg/include/pcg/device_id_t.dtg.h b/lib/pcg/include/pcg/device_id_t.dtg.h new file mode 100644 index 0000000000..d46f3dd079 --- /dev/null +++ b/lib/pcg/include/pcg/device_id_t.dtg.h @@ -0,0 +1,117 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/device_id_t.variant.toml +/* proj-data +{ + "generated_from": "85870050c742b0159775399ec2be67e3" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DEVICE_ID_T_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DEVICE_ID_T_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "pcg/cpu_id_t.dtg.h" +#include "pcg/gpu_id_t.dtg.h" +#include +#include +#include +#include +#include + +namespace FlexFlow { +struct device_id_t { + device_id_t() = delete; + explicit device_id_t(::FlexFlow::gpu_id_t const &); + explicit device_id_t(::FlexFlow::cpu_id_t const &); + template + static constexpr bool IsPartOfdevice_id_t_v = + std::is_same_v || + std::is_same_v; + template + ReturnType visit(Visitor &&v) const { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::gpu_id_t>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::cpu_id_t>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type device_id_t", this->index())); + } + } + } + template + ReturnType visit(Visitor &&v) { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::gpu_id_t>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::cpu_id_t>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type device_id_t", this->index())); + } + } + } + template + bool has() const { + static_assert(IsPartOfdevice_id_t_v, + "device_id_t::has() expected one of [::FlexFlow::gpu_id_t, " + "::FlexFlow::cpu_id_t], received T"); + return std::holds_alternative(this->raw_variant); + } + template + T const &get() const { + static_assert(IsPartOfdevice_id_t_v, + "device_id_t::get() expected one of [::FlexFlow::gpu_id_t, " + "::FlexFlow::cpu_id_t], received T"); + return std::get(this->raw_variant); + } + template + T &get() { + static_assert(IsPartOfdevice_id_t_v, + "device_id_t::get() expected one of [::FlexFlow::gpu_id_t, " + "::FlexFlow::cpu_id_t], received T"); + return std::get(this->raw_variant); + } + size_t index() const { + return this->raw_variant.index(); + } + bool operator==(device_id_t const &) const; + bool operator!=(device_id_t const &) const; + bool operator<(device_id_t const &) const; + bool operator>(device_id_t const &) const; + bool operator<=(device_id_t const &) const; + bool operator>=(device_id_t const &) const; + std::variant<::FlexFlow::gpu_id_t, ::FlexFlow::cpu_id_t> raw_variant; +}; +} // namespace FlexFlow +namespace std { +template <> +struct hash<::FlexFlow::device_id_t> { + size_t operator()(::FlexFlow::device_id_t const &) const; +}; +} // namespace std +namespace nlohmann { +template <> +struct adl_serializer<::FlexFlow::device_id_t> { + static ::FlexFlow::device_id_t from_json(json const &); + static void to_json(json &, ::FlexFlow::device_id_t const &); +}; +} // namespace nlohmann +namespace FlexFlow { +std::string format_as(::FlexFlow::device_id_t const &); +std::ostream &operator<<(std::ostream &, ::FlexFlow::device_id_t const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DEVICE_ID_T_DTG_H diff --git a/lib/pcg/include/pcg/device_id_t.variant.toml b/lib/pcg/include/pcg/device_id_t.variant.toml new file mode 100644 index 0000000000..71af18919f --- /dev/null +++ b/lib/pcg/include/pcg/device_id_t.variant.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "device_id_t" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "pcg/cpu_id_t.dtg.h", + "pcg/gpu_id_t.dtg.h", +] + +[[values]] +type = "::FlexFlow::gpu_id_t" +key = "gpu" + +[[values]] +type = "::FlexFlow::cpu_id_t" +key = "cpu" diff --git a/lib/pcg/include/pcg/device_type.dtg.h b/lib/pcg/include/pcg/device_type.dtg.h new file mode 100644 index 0000000000..f5e90dc193 --- /dev/null +++ b/lib/pcg/include/pcg/device_type.dtg.h @@ -0,0 +1,40 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/device_type.enum.toml +/* proj-data +{ + "generated_from": "cfe4bc5e9f7c5796b9b90b420c33935f" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DEVICE_TYPE_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DEVICE_TYPE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +enum class DeviceType { GPU, CPU }; +std::string format_as(DeviceType); +std::ostream &operator<<(std::ostream &, DeviceType); +void to_json(::nlohmann::json &, DeviceType); +void from_json(::nlohmann::json const &, DeviceType &); +} // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::DeviceType) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DEVICE_TYPE_DTG_H diff --git a/lib/pcg/include/pcg/device_type.enum.toml b/lib/pcg/include/pcg/device_type.enum.toml new file mode 100644 index 0000000000..67f89fbc6f --- /dev/null +++ b/lib/pcg/include/pcg/device_type.enum.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "DeviceType" +features = [ + "hash", + "json", + "fmt", + "rapidcheck", +] + +[[values]] +name = "GPU" + +[[values]] +name = "CPU" diff --git a/lib/pcg/include/pcg/device_type.h b/lib/pcg/include/pcg/device_type.h deleted file mode 100644 index 3ae374c5ea..0000000000 --- a/lib/pcg/include/pcg/device_type.h +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_DEVICE_TYPE_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_DEVICE_TYPE_H - -#include "utils/fmt.h" - -namespace FlexFlow { - -enum class DeviceType { GPU, CPU }; - -} - -namespace fmt { - -template <> -struct formatter<::FlexFlow::DeviceType> : formatter { - template - auto format(::FlexFlow::DeviceType d, FormatContext &ctx) const - -> decltype(ctx.out()) { - using ::FlexFlow::DeviceType; - - string_view name = "unknown"; - switch (d) { - case DeviceType::GPU: - name = "GPU"; - break; - case DeviceType::CPU: - name = "CPU"; - break; - } - return formatter::format(name, ctx); - } -}; - -} // namespace fmt - -#endif diff --git a/lib/pcg/include/pcg/file_format/v1/data_type.h b/lib/pcg/include/pcg/file_format/v1/data_type_value.h similarity index 64% rename from lib/pcg/include/pcg/file_format/v1/data_type.h rename to lib/pcg/include/pcg/file_format/v1/data_type_value.h index eab188155f..6e4e5abc54 100644 --- a/lib/pcg/include/pcg/file_format/v1/data_type.h +++ b/lib/pcg/include/pcg/file_format/v1/data_type_value.h @@ -9,23 +9,6 @@ namespace FlexFlow { using V1DataTypeValue = std::variant; -enum class V1DataType { - BOOL, - INT32, - INT64, - HALF, - FLOAT, - DOUBLE, -}; - -NLOHMANN_JSON_SERIALIZE_ENUM(V1DataType, - {{V1DataType::BOOL, "BOOL"}, - {V1DataType::INT32, "INT32"}, - {V1DataType::INT64, "INT64"}, - {V1DataType::HALF, "HALF"}, - {V1DataType::FLOAT, "FLOAT"}, - {V1DataType::DOUBLE, "DOUBLE"}}); - } // namespace FlexFlow namespace nlohmann { diff --git a/lib/pcg/include/pcg/file_format/v1/graphs.h b/lib/pcg/include/pcg/file_format/v1/graphs.h index 6bc852b0f1..6417a549cb 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs.h +++ b/lib/pcg/include/pcg/file_format/v1/graphs.h @@ -1,73 +1,23 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_H #define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_H -#include "operator_attrs.h" -#include "parallel_tensor.h" -#include "pcg/computation_graph.h" -#include "pcg/parallel_computation_graph.h" -#include "tensor.h" +#include "pcg/parallel_tensor_attrs.dtg.h" +#include "pcg/tensor_attrs.dtg.h" +#include "pcg/computation_graph.dtg.h" +#include "pcg/parallel_computation_graph.dtg.h" +#include "pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.h" #include "utils/json.h" -#include "utils/required.h" -#include "utils/visitable.h" +#include "pcg/layer_attrs.dtg.h" +#include "pcg/parallel_layer_attrs.dtg.h" namespace FlexFlow { -struct V1GraphOutput { - req srcNode; - req srcIdx; -}; -FF_VISITABLE_STRUCT(V1GraphOutput, srcNode, srcIdx); -CHECK_IS_JSONABLE(V1GraphOutput); - -struct V1GraphEdge { - req srcNode; - req srcIdx; - req dstNode; - req dstIdx; -}; -FF_VISITABLE_STRUCT(V1GraphEdge, srcNode, srcIdx, dstNode, dstIdx); -CHECK_IS_JSONABLE(V1GraphEdge); - -struct V1MultiDiGraph { - req> nodes; - req> ports; - req> edges; -}; -FF_VISITABLE_STRUCT(V1MultiDiGraph, nodes, ports, edges); -CHECK_IS_JSONABLE(V1MultiDiGraph); -V1MultiDiGraph to_v1(MultiDiGraphView const &); -V1MultiDiGraph to_v1(MultiDiGraphView const &, - std::unordered_map const &, - std::unordered_map const &); - -template -struct V1JsonableGraph { - using node_id = size_t; - using tensor_id = size_t; - - req> node_labels; - req> outputs; - req> output_labels; - V1MultiDiGraph graph; -}; - -struct V1Layer { - V1CompGraphOperatorAttrs attrs; - req> name; -}; -FF_VISITABLE_STRUCT(V1Layer, attrs, name); -V1Layer to_v1(Layer const &); - -using V1ComputationGraph = V1JsonableGraph; -FF_VISITABLE_STRUCT( - V1ComputationGraph, node_labels, outputs, output_labels, graph); +using V1ComputationGraph = V1JsonableGraph; CHECK_IS_JSONABLE(V1ComputationGraph); V1ComputationGraph to_v1(ComputationGraph const &); using V1ParallelComputationGraph = - V1JsonableGraph; -FF_VISITABLE_STRUCT( - V1ParallelComputationGraph, node_labels, outputs, output_labels, graph); + V1JsonableGraph; CHECK_IS_JSONABLE(V1ParallelComputationGraph); V1ParallelComputationGraph to_v1(ParallelComputationGraph const &); diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.dtg.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.dtg.h new file mode 100644 index 0000000000..e9238301d0 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.dtg.h @@ -0,0 +1,60 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.struct.toml +/* proj-data +{ + "generated_from": "865097b569b831af049343e933834329" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_GRAPH_EDGE_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_GRAPH_EDGE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include +#include +#include + +namespace FlexFlow { +struct V1GraphEdge { + V1GraphEdge() = delete; + V1GraphEdge(size_t const &srcNode, + size_t const &srcIdx, + size_t const &dstNode, + size_t const &dstIdx); + + bool operator==(V1GraphEdge const &) const; + bool operator!=(V1GraphEdge const &) const; + bool operator<(V1GraphEdge const &) const; + bool operator>(V1GraphEdge const &) const; + bool operator<=(V1GraphEdge const &) const; + bool operator>=(V1GraphEdge const &) const; + size_t srcNode; + size_t srcIdx; + size_t dstNode; + size_t dstIdx; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::V1GraphEdge const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::V1GraphEdge from_json(json const &); + static void to_json(json &, FlexFlow::V1GraphEdge const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(V1GraphEdge const &); +std::ostream &operator<<(std::ostream &, V1GraphEdge const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_GRAPH_EDGE_DTG_H diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.struct.toml new file mode 100644 index 0000000000..b0d2546977 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.struct.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "V1GraphEdge" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +[[fields]] +name = "srcNode" +type = "size_t" + +[[fields]] +name = "srcIdx" +type = "size_t" + +[[fields]] +name = "dstNode" +type = "size_t" + +[[fields]] +name = "dstIdx" +type = "size_t" diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_output.dtg.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_output.dtg.h new file mode 100644 index 0000000000..730282bdb9 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_output.dtg.h @@ -0,0 +1,55 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_output.struct.toml +/* proj-data +{ + "generated_from": "05ff8401c3d976ea2220899edb8dfe3a" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_GRAPH_OUTPUT_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_GRAPH_OUTPUT_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include +#include +#include + +namespace FlexFlow { +struct V1GraphOutput { + V1GraphOutput() = delete; + V1GraphOutput(size_t const &srcNode, size_t const &srcIdx); + + bool operator==(V1GraphOutput const &) const; + bool operator!=(V1GraphOutput const &) const; + bool operator<(V1GraphOutput const &) const; + bool operator>(V1GraphOutput const &) const; + bool operator<=(V1GraphOutput const &) const; + bool operator>=(V1GraphOutput const &) const; + size_t srcNode; + size_t srcIdx; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::V1GraphOutput const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::V1GraphOutput from_json(json const &); + static void to_json(json &, FlexFlow::V1GraphOutput const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(V1GraphOutput const &); +std::ostream &operator<<(std::ostream &, V1GraphOutput const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_GRAPH_OUTPUT_DTG_H diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_output.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_output.struct.toml new file mode 100644 index 0000000000..ba41f7e43f --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_output.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "V1GraphOutput" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +[[fields]] +name = "srcNode" +type = "size_t" + +[[fields]] +name = "srcIdx" +type = "size_t" diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.h new file mode 100644 index 0000000000..f183a14a9e --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.h @@ -0,0 +1,109 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.struct.toml +/* proj-data +{ + "generated_from": "0595a9f5a6bc19f9a170cb0e42c4202d" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_JSONABLE_GRAPH_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_JSONABLE_GRAPH_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "pcg/file_format/v1/graphs/v1_graph_output.dtg.h" +#include "pcg/file_format/v1/graphs/v1_multidigraph.dtg.h" +#include +#include +#include + +namespace FlexFlow { +template +struct V1JsonableGraph { + V1JsonableGraph() = delete; + V1JsonableGraph( + std::unordered_map const &node_labels, + std::unordered_map const &outputs, + std::unordered_map const &output_labels, + ::FlexFlow::V1MultiDiGraph const &graph); + + std::unordered_map node_labels; + std::unordered_map outputs; + std::unordered_map output_labels; + ::FlexFlow::V1MultiDiGraph graph; +}; +} // namespace FlexFlow + +namespace nlohmann { +template +struct adl_serializer> { + static FlexFlow::V1JsonableGraph from_json(json const &); + static void to_json(json &, + FlexFlow::V1JsonableGraph const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +template +std::string format_as(V1JsonableGraph const &); +template +std::ostream &operator<<(std::ostream &, + V1JsonableGraph const &); +} // namespace FlexFlow + +namespace FlexFlow { +template +V1JsonableGraph::V1JsonableGraph( + std::unordered_map const &node_labels, + std::unordered_map const &outputs, + std::unordered_map const &output_labels, + ::FlexFlow::V1MultiDiGraph const &graph) + : node_labels(node_labels), outputs(outputs), output_labels(output_labels), + graph(graph) {} +} // namespace FlexFlow + +namespace nlohmann { +template +FlexFlow::V1JsonableGraph + adl_serializer>::from_json( + json const &j) { + return { + j.at("node_labels").template get>(), + j.at("outputs") + .template get< + std::unordered_map>(), + j.at("output_labels").template get>(), + j.at("graph").template get<::FlexFlow::V1MultiDiGraph>()}; +} +template +void adl_serializer>::to_json( + json &j, FlexFlow::V1JsonableGraph const &v) { + j["__type"] = "V1JsonableGraph"; + j["node_labels"] = v.node_labels; + j["outputs"] = v.outputs; + j["output_labels"] = v.output_labels; + j["graph"] = v.graph; +} +} // namespace nlohmann + +namespace FlexFlow { +template +std::string format_as(V1JsonableGraph const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +template +std::ostream &operator<<(std::ostream &s, + V1JsonableGraph const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_JSONABLE_GRAPH_DTG_H diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.struct.toml new file mode 100644 index 0000000000..ad9ba21c60 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.struct.toml @@ -0,0 +1,38 @@ +namespace = "FlexFlow" +name = "V1JsonableGraph" +features = [ + # "eq", + # "ord", + # "hash", + "json", + # "rapidcheck", + "fmt", +] + +template_params = [ + "NodeT", + "TensorT", +] + +includes = [ + "", + "pcg/file_format/v1/graphs/v1_multidigraph.dtg.h", + "pcg/file_format/v1/graphs/v1_graph_output.dtg.h", +] + +[[fields]] +name = "node_labels" +type = "std::unordered_map" + +[[fields]] +name = "outputs" +type = "std::unordered_map" + +[[fields]] +name = "output_labels" +type = "std::unordered_map" + +[[fields]] +name = "graph" +type = "::FlexFlow::V1MultiDiGraph" + diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.dtg.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.dtg.h new file mode 100644 index 0000000000..5d7edcf1d8 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.dtg.h @@ -0,0 +1,47 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.struct.toml +/* proj-data +{ + "generated_from": "fb1033385645e54a19c9b44cef0be04b" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_MULTIDIGRAPH_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_MULTIDIGRAPH_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "pcg/file_format/v1/graphs/v1_graph_edge.dtg.h" +#include "utils/fmt.h" +#include +#include +#include + +namespace FlexFlow { +struct V1MultiDiGraph { + V1MultiDiGraph() = delete; + V1MultiDiGraph(std::vector const &nodes, + std::vector const &ports, + std::unordered_set<::FlexFlow::V1GraphEdge> const &edges); + + std::vector nodes; + std::vector ports; + std::unordered_set<::FlexFlow::V1GraphEdge> edges; +}; +} // namespace FlexFlow + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::V1MultiDiGraph from_json(json const &); + static void to_json(json &, FlexFlow::V1MultiDiGraph const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(V1MultiDiGraph const &); +std::ostream &operator<<(std::ostream &, V1MultiDiGraph const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_MULTIDIGRAPH_DTG_H diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.h new file mode 100644 index 0000000000..49ff850a29 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_MULTIDIGRAPH_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_MULTIDIGRAPH_H + +#include "pcg/file_format/v1/graphs/v1_multidigraph.dtg.h" +#include "utils/graph.h" + +namespace FlexFlow { + +V1MultiDiGraph to_v1(MultiDiGraphView const &); +V1MultiDiGraph to_v1(MultiDiGraphView const &, + std::unordered_map const &, + std::unordered_map const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.struct.toml new file mode 100644 index 0000000000..9650f3bd43 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.struct.toml @@ -0,0 +1,29 @@ +namespace = "FlexFlow" +name = "V1MultiDiGraph" +features = [ + # "eq", + # "ord", + # "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "", + "", + "pcg/file_format/v1/graphs/v1_graph_edge.dtg.h", + "utils/fmt.h", +] + +[[fields]] +name = "nodes" +type = "std::vector" + +[[fields]] +name = "ports" +type = "std::vector" + +[[fields]] +name = "edges" +type = "std::unordered_set<::FlexFlow::V1GraphEdge>" diff --git a/lib/pcg/include/pcg/file_format/v1/initializer.h b/lib/pcg/include/pcg/file_format/v1/initializer.h deleted file mode 100644 index 21af7d55e0..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/initializer.h +++ /dev/null @@ -1,57 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_INITIALIZER_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_INITIALIZER_H - -#include "data_type.h" -#include "utils/json.h" -#include "utils/required.h" -#include "utils/variant.h" -#include "utils/visitable.h" -#include "visit_struct/visit_struct_intrusive.hpp" - -namespace FlexFlow { - -struct V1GlorotInitializer { - req seed; -}; -FF_VISITABLE_STRUCT(V1GlorotInitializer, seed); - -struct V1ZeroInitializer {}; -FF_VISITABLE_STRUCT(V1ZeroInitializer); - -struct V1UniformInitializer { - int seed; - float min_val; - req max_val; -}; -FF_VISITABLE_STRUCT(V1UniformInitializer, seed, min_val, max_val); - -struct V1NormInitializer { - int seed; - float mean; - req stddev; -}; -FF_VISITABLE_STRUCT(V1NormInitializer, seed, mean, stddev); - -struct V1ConstantInitializer { - req value; -}; -FF_VISITABLE_STRUCT(V1ConstantInitializer, value); - -using V1Initializer = std::variant; - -} // namespace FlexFlow - -namespace FlexFlow { -CHECK_IS_JSONABLE(V1GlorotInitializer); -CHECK_IS_JSONABLE(V1ZeroInitializer); -CHECK_IS_JSONABLE(V1UniformInitializer); -CHECK_IS_JSONABLE(V1NormInitializer); -CHECK_IS_JSONABLE(V1ConstantInitializer); -CHECK_IS_JSONABLE(V1Initializer); -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/file_format/v1/operator_attrs.h b/lib/pcg/include/pcg/file_format/v1/operator_attrs.h deleted file mode 100644 index 2830fbd301..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/operator_attrs.h +++ /dev/null @@ -1,20 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPERATOR_ATTRS_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPERATOR_ATTRS_H - -#include "utils/json.h" -#include - -namespace FlexFlow { - -struct V1Conv2DAttrs {}; -FF_VISITABLE_STRUCT(V1Conv2DAttrs); - -static_assert( - std::is_same, std::tuple<>>::value, ""); - -using V1CompGraphOperatorAttrs = std::variant; -using V1PCGOperatorAttrs = std::variant; - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/file_format/v1/parallel_tensor.h b/lib/pcg/include/pcg/file_format/v1/parallel_tensor.h deleted file mode 100644 index c215569b21..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/parallel_tensor.h +++ /dev/null @@ -1,37 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_PARALLEL_TENSOR_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_PARALLEL_TENSOR_H - -#include "data_type.h" -#include "initializer.h" -#include "param_sync.h" -#include "utils/json.h" -#include "utils/variant.h" -#include "utils/visitable.h" - -namespace FlexFlow { - -struct V1ParallelDim { - size_t size; - int degree; - req is_replica_dim; -}; -FF_VISITABLE_STRUCT(V1ParallelDim, size, degree, is_replica_dim); - -struct V1ParallelTensorShape { - std::vector dims; - req data_type; -}; -FF_VISITABLE_STRUCT(V1ParallelTensorShape, dims, data_type); - -struct V1ParallelTensor { - V1ParallelTensorShape shape; - std::optional sync_type; - std::optional initializer; - req create_grad; -}; -FF_VISITABLE_STRUCT( - V1ParallelTensor, shape, sync_type, initializer, create_grad); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/file_format/v1/param_sync.h b/lib/pcg/include/pcg/file_format/v1/param_sync.h deleted file mode 100644 index 32769a8d20..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/param_sync.h +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef _FLEXFLOW_PCG_FILE_FORMAT_V1_PARAM_SYNC_H -#define _FLEXFLOW_PCG_FILE_FORMAT_V1_PARAM_SYNC_H - -#include "utils/json.h" - -namespace FlexFlow { - -enum class V1ParamSync { PARAM_SERVER, NCCL }; - -NLOHMANN_JSON_SERIALIZE_ENUM(V1ParamSync, - {{V1ParamSync::PARAM_SERVER, "PARAM_SERVER"}, - {V1ParamSync::NCCL, "NCCL"}}); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/file_format/v1/tensor.h b/lib/pcg/include/pcg/file_format/v1/tensor.h deleted file mode 100644 index c304a41401..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/tensor.h +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_TENSOR_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_TENSOR_H - -#include "data_type.h" -#include "initializer.h" -#include "op-attrs/tensor_shape.h" -#include "param_sync.h" -#include "pcg/tensor.h" -#include "utils/visitable.h" -#include - -namespace FlexFlow { - -struct V1TensorShape { - std::vector dims; - req data_type; -}; -FF_VISITABLE_STRUCT(V1TensorShape, dims, data_type); -CHECK_IS_JSONABLE(V1TensorShape); -V1TensorShape to_v1(TensorShape const &); - -struct V1Tensor { - V1TensorShape shape; - std::optional initializer; - bool create_gradients; - std::optional sync_type; - req> name; -}; -FF_VISITABLE_STRUCT( - V1Tensor, shape, initializer, create_gradients, sync_type, name); -CHECK_IS_JSONABLE(V1Tensor); -V1Tensor to_v1(Tensor const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/gpu_id_t.dtg.h b/lib/pcg/include/pcg/gpu_id_t.dtg.h new file mode 100644 index 0000000000..f0847848ca --- /dev/null +++ b/lib/pcg/include/pcg/gpu_id_t.dtg.h @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/gpu_id_t.struct.toml +/* proj-data +{ + "generated_from": "022355e43f43141d332be50ea3080ee2" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_GPU_ID_T_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_GPU_ID_T_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct gpu_id_t { + gpu_id_t() = delete; + gpu_id_t(int const &gpu_index); + + bool operator==(gpu_id_t const &) const; + bool operator!=(gpu_id_t const &) const; + bool operator<(gpu_id_t const &) const; + bool operator>(gpu_id_t const &) const; + bool operator<=(gpu_id_t const &) const; + bool operator>=(gpu_id_t const &) const; + int gpu_index; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::gpu_id_t const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::gpu_id_t from_json(json const &); + static void to_json(json &, FlexFlow::gpu_id_t const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(gpu_id_t const &); +std::ostream &operator<<(std::ostream &, gpu_id_t const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_GPU_ID_T_DTG_H diff --git a/lib/pcg/include/pcg/gpu_id_t.struct.toml b/lib/pcg/include/pcg/gpu_id_t.struct.toml new file mode 100644 index 0000000000..170dbb96fa --- /dev/null +++ b/lib/pcg/include/pcg/gpu_id_t.struct.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "gpu_id_t" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "gpu_index" +type = "int" diff --git a/lib/pcg/include/pcg/initializer.h b/lib/pcg/include/pcg/initializer.h deleted file mode 100644 index 6913289653..0000000000 --- a/lib/pcg/include/pcg/initializer.h +++ /dev/null @@ -1,50 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_INITIALIZER_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_INITIALIZER_H - -#include "op-attrs/datatype.h" -#include "utils/required.h" -#include "utils/visitable.h" - -namespace FlexFlow { - -struct GlorotUniform { - req seed; - /* float scale; */ - /* DataType data_type; */ -}; -FF_VISITABLE_STRUCT(GlorotUniform, seed); - -struct ZeroInitializer { - ZeroInitializer() = default; -}; -FF_VISITABLE_STRUCT(ZeroInitializer); - -struct UniformInitializer { - int seed; - float min_val; - req max_val; -}; -FF_VISITABLE_STRUCT(UniformInitializer, seed, min_val, max_val); - -struct NormInitializer { - int seed; - float mean; - req stddev; -}; -FF_VISITABLE_STRUCT(NormInitializer, seed, mean, stddev); - -struct ConstantInitializer { - req value; -}; -FF_VISITABLE_STRUCT(ConstantInitializer, value); - -using Initializer = std::variant; -CHECK_WELL_BEHAVED_VALUE_TYPE(Initializer); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/initializer_attrs.dtg.h b/lib/pcg/include/pcg/initializer_attrs.dtg.h new file mode 100644 index 0000000000..7f5a470a90 --- /dev/null +++ b/lib/pcg/include/pcg/initializer_attrs.dtg.h @@ -0,0 +1,169 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/initializer_attrs.variant.toml +/* proj-data +{ + "generated_from": "f66f3a89ea937e96a058d83ab52e2826" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZER_ATTRS_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZER_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "pcg/initializers/constant_initializer_attrs.dtg.h" +#include "pcg/initializers/glorot_uniform_attrs.dtg.h" +#include "pcg/initializers/norm_initializer_attrs.dtg.h" +#include "pcg/initializers/uniform_initializer_attrs.dtg.h" +#include "pcg/initializers/zero_initializer_attrs.dtg.h" +#include +#include +#include +#include +#include + +namespace FlexFlow { +struct InitializerAttrs { + InitializerAttrs() = delete; + explicit InitializerAttrs(::FlexFlow::GlorotUniformAttrs const &); + explicit InitializerAttrs(::FlexFlow::ZeroInitializerAttrs const &); + explicit InitializerAttrs(::FlexFlow::UniformInitializerAttrs const &); + explicit InitializerAttrs(::FlexFlow::NormInitializerAttrs const &); + explicit InitializerAttrs(::FlexFlow::ConstantInitializerAttrs const &); + template + static constexpr bool IsPartOfInitializerAttrs_v = + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v; + template + ReturnType visit(Visitor &&v) const { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::GlorotUniformAttrs>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::ZeroInitializerAttrs>()); + return result; + } + case 2: { + ReturnType result = v(this->get<::FlexFlow::UniformInitializerAttrs>()); + return result; + } + case 3: { + ReturnType result = v(this->get<::FlexFlow::NormInitializerAttrs>()); + return result; + } + case 4: { + ReturnType result = + v(this->get<::FlexFlow::ConstantInitializerAttrs>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type InitializerAttrs", this->index())); + } + } + } + template + ReturnType visit(Visitor &&v) { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::GlorotUniformAttrs>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::ZeroInitializerAttrs>()); + return result; + } + case 2: { + ReturnType result = v(this->get<::FlexFlow::UniformInitializerAttrs>()); + return result; + } + case 3: { + ReturnType result = v(this->get<::FlexFlow::NormInitializerAttrs>()); + return result; + } + case 4: { + ReturnType result = + v(this->get<::FlexFlow::ConstantInitializerAttrs>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type InitializerAttrs", this->index())); + } + } + } + template + bool has() const { + static_assert( + IsPartOfInitializerAttrs_v, + "InitializerAttrs::has() expected one of " + "[::FlexFlow::GlorotUniformAttrs, ::FlexFlow::ZeroInitializerAttrs, " + "::FlexFlow::UniformInitializerAttrs, " + "::FlexFlow::NormInitializerAttrs, " + "::FlexFlow::ConstantInitializerAttrs], received T"); + return std::holds_alternative(this->raw_variant); + } + template + T const &get() const { + static_assert( + IsPartOfInitializerAttrs_v, + "InitializerAttrs::get() expected one of " + "[::FlexFlow::GlorotUniformAttrs, ::FlexFlow::ZeroInitializerAttrs, " + "::FlexFlow::UniformInitializerAttrs, " + "::FlexFlow::NormInitializerAttrs, " + "::FlexFlow::ConstantInitializerAttrs], received T"); + return std::get(this->raw_variant); + } + template + T &get() { + static_assert( + IsPartOfInitializerAttrs_v, + "InitializerAttrs::get() expected one of " + "[::FlexFlow::GlorotUniformAttrs, ::FlexFlow::ZeroInitializerAttrs, " + "::FlexFlow::UniformInitializerAttrs, " + "::FlexFlow::NormInitializerAttrs, " + "::FlexFlow::ConstantInitializerAttrs], received T"); + return std::get(this->raw_variant); + } + size_t index() const { + return this->raw_variant.index(); + } + bool operator==(InitializerAttrs const &) const; + bool operator!=(InitializerAttrs const &) const; + bool operator<(InitializerAttrs const &) const; + bool operator>(InitializerAttrs const &) const; + bool operator<=(InitializerAttrs const &) const; + bool operator>=(InitializerAttrs const &) const; + std::variant<::FlexFlow::GlorotUniformAttrs, + ::FlexFlow::ZeroInitializerAttrs, + ::FlexFlow::UniformInitializerAttrs, + ::FlexFlow::NormInitializerAttrs, + ::FlexFlow::ConstantInitializerAttrs> + raw_variant; +}; +} // namespace FlexFlow +namespace std { +template <> +struct hash<::FlexFlow::InitializerAttrs> { + size_t operator()(::FlexFlow::InitializerAttrs const &) const; +}; +} // namespace std +namespace nlohmann { +template <> +struct adl_serializer<::FlexFlow::InitializerAttrs> { + static ::FlexFlow::InitializerAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::InitializerAttrs const &); +}; +} // namespace nlohmann +namespace FlexFlow { +std::string format_as(::FlexFlow::InitializerAttrs const &); +std::ostream &operator<<(std::ostream &, ::FlexFlow::InitializerAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZER_ATTRS_DTG_H diff --git a/lib/pcg/include/pcg/initializer_attrs.variant.toml b/lib/pcg/include/pcg/initializer_attrs.variant.toml new file mode 100644 index 0000000000..14a5cfdcac --- /dev/null +++ b/lib/pcg/include/pcg/initializer_attrs.variant.toml @@ -0,0 +1,37 @@ +namespace = "FlexFlow" +name = "InitializerAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "pcg/initializers/glorot_uniform_attrs.dtg.h", + "pcg/initializers/zero_initializer_attrs.dtg.h", + "pcg/initializers/uniform_initializer_attrs.dtg.h", + "pcg/initializers/norm_initializer_attrs.dtg.h", + "pcg/initializers/constant_initializer_attrs.dtg.h", +] + +[[values]] +type = "::FlexFlow::GlorotUniformAttrs" +key = "glorot_uniform" + +[[values]] +type = "::FlexFlow::ZeroInitializerAttrs" +key = "zero" + +[[values]] +type = "::FlexFlow::UniformInitializerAttrs" +key = "uniform" + +[[values]] +type = "::FlexFlow::NormInitializerAttrs" +key = "normal" + +[[values]] +type = "::FlexFlow::ConstantInitializerAttrs" +key = "constant" diff --git a/lib/pcg/include/pcg/initializers/constant_initializer_attrs.dtg.h b/lib/pcg/include/pcg/initializers/constant_initializer_attrs.dtg.h new file mode 100644 index 0000000000..1eb9eb8834 --- /dev/null +++ b/lib/pcg/include/pcg/initializers/constant_initializer_attrs.dtg.h @@ -0,0 +1,56 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml +/* proj-data +{ + "generated_from": "0162b9c49fe6cbfc65410c6fa8dec427" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_CONSTANT_INITIALIZER_ATTRS_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_CONSTANT_INITIALIZER_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/datatype.h" +#include "utils/json.h" +#include +#include +#include + +namespace FlexFlow { +struct ConstantInitializerAttrs { + ConstantInitializerAttrs() = delete; + ConstantInitializerAttrs(::FlexFlow::DataTypeValue const &value); + + bool operator==(ConstantInitializerAttrs const &) const; + bool operator!=(ConstantInitializerAttrs const &) const; + bool operator<(ConstantInitializerAttrs const &) const; + bool operator>(ConstantInitializerAttrs const &) const; + bool operator<=(ConstantInitializerAttrs const &) const; + bool operator>=(ConstantInitializerAttrs const &) const; + ::FlexFlow::DataTypeValue value; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ConstantInitializerAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ConstantInitializerAttrs from_json(json const &); + static void to_json(json &, FlexFlow::ConstantInitializerAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ConstantInitializerAttrs const &); +std::ostream &operator<<(std::ostream &, ConstantInitializerAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_CONSTANT_INITIALIZER_ATTRS_DTG_H diff --git a/lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml b/lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml new file mode 100644 index 0000000000..3a80559d7b --- /dev/null +++ b/lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "ConstantInitializerAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/datatype.h", + "utils/json.h", +] + +[[fields]] +name = "value" +type = "::FlexFlow::DataTypeValue" diff --git a/lib/pcg/include/pcg/initializers/glorot_uniform_attrs.dtg.h b/lib/pcg/include/pcg/initializers/glorot_uniform_attrs.dtg.h new file mode 100644 index 0000000000..04851fb333 --- /dev/null +++ b/lib/pcg/include/pcg/initializers/glorot_uniform_attrs.dtg.h @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/initializers/glorot_uniform_attrs.struct.toml +/* proj-data +{ + "generated_from": "a268b411b6d378faa11e60c8517d7be5" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_GLOROT_UNIFORM_ATTRS_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_GLOROT_UNIFORM_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct GlorotUniformAttrs { + GlorotUniformAttrs() = delete; + GlorotUniformAttrs(int const &seed); + + bool operator==(GlorotUniformAttrs const &) const; + bool operator!=(GlorotUniformAttrs const &) const; + bool operator<(GlorotUniformAttrs const &) const; + bool operator>(GlorotUniformAttrs const &) const; + bool operator<=(GlorotUniformAttrs const &) const; + bool operator>=(GlorotUniformAttrs const &) const; + int seed; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::GlorotUniformAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::GlorotUniformAttrs from_json(json const &); + static void to_json(json &, FlexFlow::GlorotUniformAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(GlorotUniformAttrs const &); +std::ostream &operator<<(std::ostream &, GlorotUniformAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_GLOROT_UNIFORM_ATTRS_DTG_H diff --git a/lib/pcg/include/pcg/initializers/glorot_uniform_attrs.struct.toml b/lib/pcg/include/pcg/initializers/glorot_uniform_attrs.struct.toml new file mode 100644 index 0000000000..de7f9141b0 --- /dev/null +++ b/lib/pcg/include/pcg/initializers/glorot_uniform_attrs.struct.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "GlorotUniformAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "seed" +type = "int" diff --git a/lib/pcg/include/pcg/initializers/norm_initializer_attrs.dtg.h b/lib/pcg/include/pcg/initializers/norm_initializer_attrs.dtg.h new file mode 100644 index 0000000000..e1d3e59ed7 --- /dev/null +++ b/lib/pcg/include/pcg/initializers/norm_initializer_attrs.dtg.h @@ -0,0 +1,64 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/initializers/norm_initializer_attrs.struct.toml +/* proj-data +{ + "generated_from": "6843fc9ca02aea2b40e57dbc497f99ac" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_NORM_INITIALIZER_ATTRS_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_NORM_INITIALIZER_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct NormInitializerAttrs { + NormInitializerAttrs() = delete; + NormInitializerAttrs(int const &seed, float const &mean, float const &stddev); + + bool operator==(NormInitializerAttrs const &) const; + bool operator!=(NormInitializerAttrs const &) const; + bool operator<(NormInitializerAttrs const &) const; + bool operator>(NormInitializerAttrs const &) const; + bool operator<=(NormInitializerAttrs const &) const; + bool operator>=(NormInitializerAttrs const &) const; + int seed; + float mean; + float stddev; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::NormInitializerAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::NormInitializerAttrs from_json(json const &); + static void to_json(json &, FlexFlow::NormInitializerAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(NormInitializerAttrs const &); +std::ostream &operator<<(std::ostream &, NormInitializerAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_NORM_INITIALIZER_ATTRS_DTG_H diff --git a/lib/pcg/include/pcg/initializers/norm_initializer_attrs.struct.toml b/lib/pcg/include/pcg/initializers/norm_initializer_attrs.struct.toml new file mode 100644 index 0000000000..ec138de63e --- /dev/null +++ b/lib/pcg/include/pcg/initializers/norm_initializer_attrs.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "NormInitializerAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "seed" +type = "int" + +[[fields]] +name = "mean" +type = "float" + +[[fields]] +name = "stddev" +type = "float" diff --git a/lib/pcg/include/pcg/initializers/uniform_initializer_attrs.dtg.h b/lib/pcg/include/pcg/initializers/uniform_initializer_attrs.dtg.h new file mode 100644 index 0000000000..1f4deada06 --- /dev/null +++ b/lib/pcg/include/pcg/initializers/uniform_initializer_attrs.dtg.h @@ -0,0 +1,58 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/initializers/uniform_initializer_attrs.struct.toml +/* proj-data +{ + "generated_from": "f887e1db5d5dc710793ec5fa99bb7cd4" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_UNIFORM_INITIALIZER_ATTRS_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_UNIFORM_INITIALIZER_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include +#include +#include + +namespace FlexFlow { +struct UniformInitializerAttrs { + UniformInitializerAttrs() = delete; + UniformInitializerAttrs(int const &seed, + float const &min_val, + float const &max_val); + + bool operator==(UniformInitializerAttrs const &) const; + bool operator!=(UniformInitializerAttrs const &) const; + bool operator<(UniformInitializerAttrs const &) const; + bool operator>(UniformInitializerAttrs const &) const; + bool operator<=(UniformInitializerAttrs const &) const; + bool operator>=(UniformInitializerAttrs const &) const; + int seed; + float min_val; + float max_val; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::UniformInitializerAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::UniformInitializerAttrs from_json(json const &); + static void to_json(json &, FlexFlow::UniformInitializerAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(UniformInitializerAttrs const &); +std::ostream &operator<<(std::ostream &, UniformInitializerAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_UNIFORM_INITIALIZER_ATTRS_DTG_H diff --git a/lib/pcg/include/pcg/initializers/uniform_initializer_attrs.struct.toml b/lib/pcg/include/pcg/initializers/uniform_initializer_attrs.struct.toml new file mode 100644 index 0000000000..11a6597c0a --- /dev/null +++ b/lib/pcg/include/pcg/initializers/uniform_initializer_attrs.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "UniformInitializerAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +[[fields]] +name = "seed" +type = "int" + +[[fields]] +name = "min_val" +type = "float" + +[[fields]] +name = "max_val" +type = "float" diff --git a/lib/pcg/include/pcg/initializers/zero_initializer_attrs.dtg.h b/lib/pcg/include/pcg/initializers/zero_initializer_attrs.dtg.h new file mode 100644 index 0000000000..f3086ea087 --- /dev/null +++ b/lib/pcg/include/pcg/initializers/zero_initializer_attrs.dtg.h @@ -0,0 +1,58 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/initializers/zero_initializer_attrs.struct.toml +/* proj-data +{ + "generated_from": "a19d5a2cdc67a2840d6ba55250a10411" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_ZERO_INITIALIZER_ATTRS_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_ZERO_INITIALIZER_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct ZeroInitializerAttrs { + bool operator==(ZeroInitializerAttrs const &) const; + bool operator!=(ZeroInitializerAttrs const &) const; + bool operator<(ZeroInitializerAttrs const &) const; + bool operator>(ZeroInitializerAttrs const &) const; + bool operator<=(ZeroInitializerAttrs const &) const; + bool operator>=(ZeroInitializerAttrs const &) const; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ZeroInitializerAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ZeroInitializerAttrs from_json(json const &); + static void to_json(json &, FlexFlow::ZeroInitializerAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(ZeroInitializerAttrs const &); +std::ostream &operator<<(std::ostream &, ZeroInitializerAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_ZERO_INITIALIZER_ATTRS_DTG_H diff --git a/lib/pcg/include/pcg/initializers/zero_initializer_attrs.struct.toml b/lib/pcg/include/pcg/initializers/zero_initializer_attrs.struct.toml new file mode 100644 index 0000000000..db1b6238d5 --- /dev/null +++ b/lib/pcg/include/pcg/initializers/zero_initializer_attrs.struct.toml @@ -0,0 +1,11 @@ +namespace = "FlexFlow" +name = "ZeroInitializerAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] +fields = [] diff --git a/lib/pcg/include/pcg/layer.h b/lib/pcg/include/pcg/layer.h deleted file mode 100644 index 9749cb9d06..0000000000 --- a/lib/pcg/include/pcg/layer.h +++ /dev/null @@ -1,33 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_LAYER_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_LAYER_H - -#include "op-attrs/operator_attrs.h" -#include "utils/stack_string.h" -#include "utils/visitable.h" - -namespace FlexFlow { - -struct Layer { -public: - Layer() = delete; - Layer(CompGraphOperatorAttrs const &attrs, - std::optional const &name); - -public: - std::optional> name; - CompGraphOperatorAttrs attrs; -}; - -} // namespace FlexFlow - -VISITABLE_STRUCT(::FlexFlow::Layer, attrs, name); -MAKE_VISIT_HASHABLE(::FlexFlow::Layer); - -namespace FlexFlow { - -FF_VISIT_FMTABLE(Layer); -// CHECK_FMTABLE(Layer); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/layer_attrs.dtg.h b/lib/pcg/include/pcg/layer_attrs.dtg.h new file mode 100644 index 0000000000..9c9d277b67 --- /dev/null +++ b/lib/pcg/include/pcg/layer_attrs.dtg.h @@ -0,0 +1,60 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/layer_attrs.struct.toml +/* proj-data +{ + "generated_from": "12b49c15e8defff5118e5607a7823f59" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_LAYER_ATTRS_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_LAYER_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/operator_attrs.h" +#include "utils/json.h" +#include "utils/stack_string.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct LayerAttrs { + LayerAttrs() = delete; + LayerAttrs(::FlexFlow::CompGraphOperatorAttrs const &attrs, + std::optional<::FlexFlow::stack_string> const &name); + + bool operator==(LayerAttrs const &) const; + bool operator!=(LayerAttrs const &) const; + bool operator<(LayerAttrs const &) const; + bool operator>(LayerAttrs const &) const; + bool operator<=(LayerAttrs const &) const; + bool operator>=(LayerAttrs const &) const; + ::FlexFlow::CompGraphOperatorAttrs attrs; + std::optional<::FlexFlow::stack_string> name; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::LayerAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::LayerAttrs from_json(json const &); + static void to_json(json &, FlexFlow::LayerAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(LayerAttrs const &); +std::ostream &operator<<(std::ostream &, LayerAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_LAYER_ATTRS_DTG_H diff --git a/lib/pcg/include/pcg/layer_attrs.struct.toml b/lib/pcg/include/pcg/layer_attrs.struct.toml new file mode 100644 index 0000000000..0dec35a1d8 --- /dev/null +++ b/lib/pcg/include/pcg/layer_attrs.struct.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "LayerAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/operator_attrs.h", + "utils/stack_string.h", + "", + "utils/json.h" +] + +[[fields]] +name = "attrs" +type = "::FlexFlow::CompGraphOperatorAttrs" + +[[fields]] +name = "name" +type = "std::optional<::FlexFlow::stack_string>" + diff --git a/lib/pcg/include/pcg/machine_specification.dtg.h b/lib/pcg/include/pcg/machine_specification.dtg.h new file mode 100644 index 0000000000..cd6ffe6c0f --- /dev/null +++ b/lib/pcg/include/pcg/machine_specification.dtg.h @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/machine_specification.struct.toml +/* proj-data +{ + "generated_from": "72c3ae372af189d0c8bae74c2dbbc531" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_MACHINE_SPECIFICATION_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_MACHINE_SPECIFICATION_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include +#include +#include + +namespace FlexFlow { +struct MachineSpecification { + MachineSpecification() = delete; + MachineSpecification(int const &num_nodes, + int const &num_cpus_per_node, + int const &num_gpus_per_node, + float const &inter_node_bandwidth, + float const &intra_node_bandwidth); + + bool operator==(MachineSpecification const &) const; + bool operator!=(MachineSpecification const &) const; + bool operator<(MachineSpecification const &) const; + bool operator>(MachineSpecification const &) const; + bool operator<=(MachineSpecification const &) const; + bool operator>=(MachineSpecification const &) const; + int num_nodes; + int num_cpus_per_node; + int num_gpus_per_node; + float inter_node_bandwidth; + float intra_node_bandwidth; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::MachineSpecification const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::MachineSpecification from_json(json const &); + static void to_json(json &, FlexFlow::MachineSpecification const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(MachineSpecification const &); +std::ostream &operator<<(std::ostream &, MachineSpecification const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_MACHINE_SPECIFICATION_DTG_H diff --git a/lib/pcg/include/pcg/machine_specification.h b/lib/pcg/include/pcg/machine_specification.h index 1b2a02b070..3886dcfe2e 100644 --- a/lib/pcg/include/pcg/machine_specification.h +++ b/lib/pcg/include/pcg/machine_specification.h @@ -1,31 +1,10 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_MACHINE_SPECIFICATION_H #define _FLEXFLOW_PCG_INCLUDE_PCG_MACHINE_SPECIFICATION_H -#include "machine_view.h" -#include "utils/visitable.h" +#include "machine_specification_t.h" namespace FlexFlow { -struct BandwidthNetworkModelConfig - : public use_visitable_cmp { - int bandwidth; -}; - -struct MachineSpecification { - int num_nodes; - int num_cpus_per_node; - int num_gpus_per_node; - float inter_node_bandwidth; - req intra_node_bandwidth; -}; - -FF_VISITABLE_STRUCT(MachineSpecification, - num_nodes, - num_cpus_per_node, - num_gpus_per_node, - inter_node_bandwidth, - intra_node_bandwidth); - } // namespace FlexFlow #endif diff --git a/lib/pcg/include/pcg/machine_specification.struct.toml b/lib/pcg/include/pcg/machine_specification.struct.toml new file mode 100644 index 0000000000..e75b5018cb --- /dev/null +++ b/lib/pcg/include/pcg/machine_specification.struct.toml @@ -0,0 +1,30 @@ +namespace = "FlexFlow" +name = "MachineSpecification" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +[[fields]] +name = "num_nodes" +type = "int" + +[[fields]] +name = "num_cpus_per_node" +type = "int" + +[[fields]] +name = "num_gpus_per_node" +type = "int" + +[[fields]] +name = "inter_node_bandwidth" +type = "float" + +[[fields]] +name = "intra_node_bandwidth" +type = "float" diff --git a/lib/pcg/include/pcg/machine_view.dtg.h b/lib/pcg/include/pcg/machine_view.dtg.h new file mode 100644 index 0000000000..2eae6e2c8b --- /dev/null +++ b/lib/pcg/include/pcg/machine_view.dtg.h @@ -0,0 +1,58 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/machine_view.struct.toml +/* proj-data +{ + "generated_from": "16c571e6bb82d7ef88e5d2a9146638f4" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_MACHINE_VIEW_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_MACHINE_VIEW_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "pcg/device_id_t.dtg.h" +#include "pcg/strided_rectangle.dtg.h" +#include +#include +#include + +namespace FlexFlow { +struct MachineView { + MachineView() = delete; + MachineView(::FlexFlow::device_id_t const &start, + ::FlexFlow::StridedRectangle const &rect); + + bool operator==(MachineView const &) const; + bool operator!=(MachineView const &) const; + bool operator<(MachineView const &) const; + bool operator>(MachineView const &) const; + bool operator<=(MachineView const &) const; + bool operator>=(MachineView const &) const; + ::FlexFlow::device_id_t start; + ::FlexFlow::StridedRectangle rect; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::MachineView const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::MachineView from_json(json const &); + static void to_json(json &, FlexFlow::MachineView const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(MachineView const &); +std::ostream &operator<<(std::ostream &, MachineView const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_MACHINE_VIEW_DTG_H diff --git a/lib/pcg/include/pcg/machine_view.h b/lib/pcg/include/pcg/machine_view.h index 7521cd209a..60837a2abf 100644 --- a/lib/pcg/include/pcg/machine_view.h +++ b/lib/pcg/include/pcg/machine_view.h @@ -1,30 +1,19 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_MACHINE_VIEW_H #define _FLEXFLOW_PCG_INCLUDE_PCG_MACHINE_VIEW_H -#include "device_id.h" -#include "device_type.h" -#include "strided_rectangle.h" -#include "utils/graph.h" -#include "utils/visitable.h" +#include "pcg/machine_view.dtg.h" +#include "pcg/device_id_t.dtg.h" +#include "pcg/device_type.dtg.h" +#include "pcg/num_points_t.dtg.h" +#include "pcg/cpu_id_t.dtg.h" +#include "pcg/gpu_id_t.dtg.h" +#include "pcg/side_size_t.dtg.h" #include -#include #include namespace FlexFlow { -struct MachineView { - std::vector device_ids() const; - - device_id_t at(FFOrdered const &coord) const; - StridedRectangleSide at(size_t) const; - -public: - device_id_t start; - StridedRectangle rect; -}; - -FF_VISITABLE_STRUCT(MachineView, start, rect); - +std::vector device_ids(MachineView const &); std::size_t num_dims(MachineView const &); std::size_t num_devices(MachineView const &); DeviceType get_device_type(MachineView const &); diff --git a/lib/pcg/include/pcg/machine_view.struct.toml b/lib/pcg/include/pcg/machine_view.struct.toml new file mode 100644 index 0000000000..c97731991f --- /dev/null +++ b/lib/pcg/include/pcg/machine_view.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "MachineView" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "pcg/device_id_t.dtg.h", + "pcg/strided_rectangle.dtg.h", +] + +[[fields]] +name = "start" +type = "::FlexFlow::device_id_t" + +[[fields]] +name = "rect" +type = "::FlexFlow::StridedRectangle" diff --git a/lib/pcg/include/pcg/num_points_t.dtg.h b/lib/pcg/include/pcg/num_points_t.dtg.h new file mode 100644 index 0000000000..3b8e0e0c6c --- /dev/null +++ b/lib/pcg/include/pcg/num_points_t.dtg.h @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/num_points_t.struct.toml +/* proj-data +{ + "generated_from": "2a862b92055eda0508447d2f4df52f71" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_NUM_POINTS_T_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_NUM_POINTS_T_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct num_points_t { + num_points_t() = delete; + num_points_t(int const &unwrapped); + + bool operator==(num_points_t const &) const; + bool operator!=(num_points_t const &) const; + bool operator<(num_points_t const &) const; + bool operator>(num_points_t const &) const; + bool operator<=(num_points_t const &) const; + bool operator>=(num_points_t const &) const; + int unwrapped; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::num_points_t const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::num_points_t from_json(json const &); + static void to_json(json &, FlexFlow::num_points_t const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(num_points_t const &); +std::ostream &operator<<(std::ostream &, num_points_t const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_NUM_POINTS_T_DTG_H diff --git a/lib/pcg/include/pcg/num_points_t.struct.toml b/lib/pcg/include/pcg/num_points_t.struct.toml new file mode 100644 index 0000000000..b389245c63 --- /dev/null +++ b/lib/pcg/include/pcg/num_points_t.struct.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "num_points_t" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "unwrapped" +type = "int" diff --git a/lib/pcg/include/pcg/operator.h b/lib/pcg/include/pcg/operator.h deleted file mode 100644 index bb9a4cf5e4..0000000000 --- a/lib/pcg/include/pcg/operator.h +++ /dev/null @@ -1,27 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_OPERATOR_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_OPERATOR_H - -#include "op-attrs/operator_attrs.h" -#include "utils/stack_string.h" -#include "utils/visitable.h" - -#include - -namespace FlexFlow { - -struct Operator { -public: - operator PCGOperatorAttrs() const; - -public: - PCGOperatorAttrs attrs; - req> name; -}; - -FF_VISITABLE_STRUCT(Operator, attrs, name); - -static_assert(is_well_behaved_value_type::value); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/operator_guid_t.dtg.h b/lib/pcg/include/pcg/operator_guid_t.dtg.h new file mode 100644 index 0000000000..bf08150e5e --- /dev/null +++ b/lib/pcg/include/pcg/operator_guid_t.dtg.h @@ -0,0 +1,46 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/operator_guid_t.struct.toml +/* proj-data +{ + "generated_from": "348b5a610f4ff6f545884564ee9a1e6a" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GUID_T_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GUID_T_DTG_H + +#include "fmt/format.h" +#include "utils/graph.h" +#include +#include +#include + +namespace FlexFlow { +struct operator_guid_t { + operator_guid_t() = delete; + operator_guid_t(::FlexFlow::Node const &raw_graph_node); + + bool operator==(operator_guid_t const &) const; + bool operator!=(operator_guid_t const &) const; + bool operator<(operator_guid_t const &) const; + bool operator>(operator_guid_t const &) const; + bool operator<=(operator_guid_t const &) const; + bool operator>=(operator_guid_t const &) const; + ::FlexFlow::Node raw_graph_node; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::operator_guid_t const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(operator_guid_t const &); +std::ostream &operator<<(std::ostream &, operator_guid_t const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GUID_T_DTG_H diff --git a/lib/pcg/include/pcg/operator_guid_t.h b/lib/pcg/include/pcg/operator_guid_t.h deleted file mode 100644 index 46b640774a..0000000000 --- a/lib/pcg/include/pcg/operator_guid_t.h +++ /dev/null @@ -1,18 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_OPERATOR_GUID_T_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_OPERATOR_GUID_T_H - -#include "utils/graph.h" -#include "utils/strong_typedef.h" - -namespace FlexFlow { - -struct operator_guid_t : strong_typedef { - using strong_typedef::strong_typedef; -}; - -} // namespace FlexFlow - -MAKE_TYPEDEF_PRINTABLE(::FlexFlow::operator_guid_t, "operator_guid"); -MAKE_TYPEDEF_HASHABLE(::FlexFlow::operator_guid_t); - -#endif diff --git a/lib/pcg/include/pcg/operator_guid_t.struct.toml b/lib/pcg/include/pcg/operator_guid_t.struct.toml new file mode 100644 index 0000000000..f89d30137e --- /dev/null +++ b/lib/pcg/include/pcg/operator_guid_t.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "operator_guid_t" +features = [ + "eq", + "ord", + "hash", + # "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "utils/graph.h" +] + +[[fields]] +name = "raw_graph_node" +type = "::FlexFlow::Node" diff --git a/lib/pcg/include/pcg/optimizer.h b/lib/pcg/include/pcg/optimizer.h deleted file mode 100644 index 0bb3fab974..0000000000 --- a/lib/pcg/include/pcg/optimizer.h +++ /dev/null @@ -1,41 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_OPTIMIZER_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_OPTIMIZER_H - -#include "utils/variant.h" -#include "utils/visitable.h" - -namespace FlexFlow { - -struct SGDOptimizer { - double lr; - double momentum; - bool nesterov; - req weight_decay; -}; -FF_VISITABLE_STRUCT(SGDOptimizer, lr, momentum, nesterov, weight_decay); - -struct AdamOptimizer { - double alpha; - double beta1; - double beta2; - double weight_decay; - double epsilon; - double alpha_t; - double beta_t; - req beta2_t; -}; -FF_VISITABLE_STRUCT(AdamOptimizer, - alpha, - beta1, - beta2, - weight_decay, - epsilon, - alpha_t, - beta_t, - beta2_t); - -using Optimizer = std::variant; - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/optimizer_attrs.h b/lib/pcg/include/pcg/optimizer_attrs.h new file mode 100644 index 0000000000..03ba461c45 --- /dev/null +++ b/lib/pcg/include/pcg/optimizer_attrs.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_OPTIMIZER_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_OPTIMIZER_H + +#include "utils/variant.h" +#include "pcg/optimizers/adam_optimizer_attrs.h" +#include "pcg/optimizers/sgd_optimizer_attrs.h" + +namespace FlexFlow { + +using OptimizerAttrs = std::variant; + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.dtg.h b/lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.dtg.h new file mode 100644 index 0000000000..a5a6a5ed0a --- /dev/null +++ b/lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.dtg.h @@ -0,0 +1,74 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.struct.toml +/* proj-data +{ + "generated_from": "f49e1bebcb0ef2bc3c210073e3183d4d" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPTIMIZERS_ADAM_OPTIMIZER_ATTRS_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPTIMIZERS_ADAM_OPTIMIZER_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct AdamOptimizerAttrs { + AdamOptimizerAttrs() = delete; + AdamOptimizerAttrs(double const &alpha, + double const &beta1, + double const &beta2, + double const &weight_decay, + double const &alpha_t, + double const &beta_t, + double const &beta2_t); + + bool operator==(AdamOptimizerAttrs const &) const; + bool operator!=(AdamOptimizerAttrs const &) const; + bool operator<(AdamOptimizerAttrs const &) const; + bool operator>(AdamOptimizerAttrs const &) const; + bool operator<=(AdamOptimizerAttrs const &) const; + bool operator>=(AdamOptimizerAttrs const &) const; + double alpha; + double beta1; + double beta2; + double weight_decay; + double alpha_t; + double beta_t; + double beta2_t; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::AdamOptimizerAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::AdamOptimizerAttrs from_json(json const &); + static void to_json(json &, FlexFlow::AdamOptimizerAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(AdamOptimizerAttrs const &); +std::ostream &operator<<(std::ostream &, AdamOptimizerAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPTIMIZERS_ADAM_OPTIMIZER_ATTRS_DTG_H diff --git a/lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.struct.toml b/lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.struct.toml new file mode 100644 index 0000000000..fd3e83cc4a --- /dev/null +++ b/lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.struct.toml @@ -0,0 +1,38 @@ +namespace = "FlexFlow" +name = "AdamOptimizerAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "alpha" +type = "double" + +[[fields]] +name = "beta1" +type = "double" + +[[fields]] +name = "beta2" +type = "double" + +[[fields]] +name = "weight_decay" +type = "double" + +[[fields]] +name = "alpha_t" +type = "double" + +[[fields]] +name = "beta_t" +type = "double" + +[[fields]] +name = "beta2_t" +type = "double" diff --git a/lib/pcg/include/pcg/optimizers/sgd_optimizer_attrs.dtg.h b/lib/pcg/include/pcg/optimizers/sgd_optimizer_attrs.dtg.h new file mode 100644 index 0000000000..f6a17f2354 --- /dev/null +++ b/lib/pcg/include/pcg/optimizers/sgd_optimizer_attrs.dtg.h @@ -0,0 +1,68 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/optimizers/sgd_optimizer_attrs.struct.toml +/* proj-data +{ + "generated_from": "d18c91cdddc760f1fb3990d2c817ee87" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPTIMIZERS_SGD_OPTIMIZER_ATTRS_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPTIMIZERS_SGD_OPTIMIZER_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct SGDOptimizerAttrs { + SGDOptimizerAttrs() = delete; + SGDOptimizerAttrs(double const &lr, + double const &momentum, + bool const &nesterov, + double const &weight_decay); + + bool operator==(SGDOptimizerAttrs const &) const; + bool operator!=(SGDOptimizerAttrs const &) const; + bool operator<(SGDOptimizerAttrs const &) const; + bool operator>(SGDOptimizerAttrs const &) const; + bool operator<=(SGDOptimizerAttrs const &) const; + bool operator>=(SGDOptimizerAttrs const &) const; + double lr; + double momentum; + bool nesterov; + double weight_decay; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::SGDOptimizerAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::SGDOptimizerAttrs from_json(json const &); + static void to_json(json &, FlexFlow::SGDOptimizerAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(SGDOptimizerAttrs const &); +std::ostream &operator<<(std::ostream &, SGDOptimizerAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPTIMIZERS_SGD_OPTIMIZER_ATTRS_DTG_H diff --git a/lib/pcg/include/pcg/optimizers/sgd_optimizer_attrs.struct.toml b/lib/pcg/include/pcg/optimizers/sgd_optimizer_attrs.struct.toml new file mode 100644 index 0000000000..37affb0e1f --- /dev/null +++ b/lib/pcg/include/pcg/optimizers/sgd_optimizer_attrs.struct.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "SGDOptimizerAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "lr" +type = "double" + +[[fields]] +name = "momentum" +type = "double" + +[[fields]] +name = "nesterov" +type = "bool" + +[[fields]] +name = "weight_decay" +type = "double" diff --git a/lib/pcg/include/pcg/parallel_computation_graph.dtg.h b/lib/pcg/include/pcg/parallel_computation_graph.dtg.h new file mode 100644 index 0000000000..f08e58a8b6 --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph.dtg.h @@ -0,0 +1,30 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/parallel_computation_graph.struct.toml +/* proj-data +{ + "generated_from": "3bb0791e3481298ddea75f4bd134f9e1" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_DTG_H + +#include "pcg/parallel_layer_attrs.dtg.h" +#include "pcg/parallel_tensor_attrs.dtg.h" +#include "utils/graph.h" + +namespace FlexFlow { +struct ParallelComputationGraph { + ParallelComputationGraph() = delete; + ParallelComputationGraph(::FlexFlow::OutputLabelledMultiDiGraph< + ::FlexFlow::ParallelLayerAttrs, + ::FlexFlow::ParallelTensorAttrs> const &raw_graph); + + ::FlexFlow::OutputLabelledMultiDiGraph<::FlexFlow::ParallelLayerAttrs, + ::FlexFlow::ParallelTensorAttrs> + raw_graph; +}; +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_DTG_H diff --git a/lib/pcg/include/pcg/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph.h index 39a69a80ab..4dc2db5de4 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph.h @@ -1,31 +1,8 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_H #define _FLEXFLOW_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_H -#include "operator.h" -#include "parallel_tensor.h" -#include "utils/graph.h" +#include "pcg/parallel_computation_graph_t.h" -namespace FlexFlow { - -struct ParallelComputationGraph - : public strong_typedef< - ParallelComputationGraph, - OutputLabelledMultiDiGraph> { - using strong_typedef::strong_typedef; -}; -CHECK_WELL_BEHAVED_VALUE_TYPE_NO_HASH(ParallelComputationGraph); - -bool operator==(ParallelComputationGraph const &, - ParallelComputationGraph const &); - -} // namespace FlexFlow - -namespace std { - -template <> -struct hash { - size_t operator()(FlexFlow::ParallelComputationGraph const &g) const; -}; -} // namespace std +namespace FlexFlow { } #endif diff --git a/lib/pcg/include/pcg/parallel_computation_graph.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph.struct.toml new file mode 100644 index 0000000000..5e9eaee4ab --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph.struct.toml @@ -0,0 +1,13 @@ +namespace = "FlexFlow" +name = "ParallelComputationGraph" +features = [ ] + +includes = [ + "utils/graph.h", + "pcg/parallel_tensor_attrs.dtg.h", + "pcg/parallel_layer_attrs.dtg.h", +] + +[[fields]] +name = "raw_graph" +type = "::FlexFlow::OutputLabelledMultiDiGraph<::FlexFlow::ParallelLayerAttrs, ::FlexFlow::ParallelTensorAttrs>" diff --git a/lib/pcg/include/pcg/parallel_layer_attrs.dtg.h b/lib/pcg/include/pcg/parallel_layer_attrs.dtg.h new file mode 100644 index 0000000000..4c7fce4038 --- /dev/null +++ b/lib/pcg/include/pcg/parallel_layer_attrs.dtg.h @@ -0,0 +1,60 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/parallel_layer_attrs.struct.toml +/* proj-data +{ + "generated_from": "97fa0b11c59ae892a8a530ffd67e33ad" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_LAYER_ATTRS_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_LAYER_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/operator_attrs.h" +#include "utils/stack_string.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct ParallelLayerAttrs { + ParallelLayerAttrs() = delete; + ParallelLayerAttrs( + ::FlexFlow::PCGOperatorAttrs const &attrs, + std::optional<::FlexFlow::stack_string> const &name); + + bool operator==(ParallelLayerAttrs const &) const; + bool operator!=(ParallelLayerAttrs const &) const; + bool operator<(ParallelLayerAttrs const &) const; + bool operator>(ParallelLayerAttrs const &) const; + bool operator<=(ParallelLayerAttrs const &) const; + bool operator>=(ParallelLayerAttrs const &) const; + ::FlexFlow::PCGOperatorAttrs attrs; + std::optional<::FlexFlow::stack_string> name; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ParallelLayerAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ParallelLayerAttrs from_json(json const &); + static void to_json(json &, FlexFlow::ParallelLayerAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ParallelLayerAttrs const &); +std::ostream &operator<<(std::ostream &, ParallelLayerAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_LAYER_ATTRS_DTG_H diff --git a/lib/pcg/include/pcg/parallel_layer_attrs.struct.toml b/lib/pcg/include/pcg/parallel_layer_attrs.struct.toml new file mode 100644 index 0000000000..9b1f8f47aa --- /dev/null +++ b/lib/pcg/include/pcg/parallel_layer_attrs.struct.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "ParallelLayerAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/operator_attrs.h", + "utils/stack_string.h", + "", +] + +[[fields]] +name = "attrs" +type = "::FlexFlow::PCGOperatorAttrs" + +[[fields]] +name = "name" +type = "std::optional<::FlexFlow::stack_string>" diff --git a/lib/pcg/include/pcg/parallel_tensor.h b/lib/pcg/include/pcg/parallel_tensor.h index 652b408c15..8fd2fc0e17 100644 --- a/lib/pcg/include/pcg/parallel_tensor.h +++ b/lib/pcg/include/pcg/parallel_tensor.h @@ -21,56 +21,14 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_PARALLEL_TENSOR_H #define _FLEXFLOW_PCG_INCLUDE_PCG_PARALLEL_TENSOR_H -#include "create_grad.h" -#include "initializer.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "op-attrs/param_sync.h" +#include "pcg/parallel_tensor_attrs.h" namespace FlexFlow { -/** - * @brief Base structure of the parallel tensor representation. - * - * @details Parallel tensor is the fundamental component to support the - * representation and exploration of parallelization strategies. - */ -struct ParallelTensor : public use_visitable_cmp { - ParallelTensor() = delete; - - ParallelTensor(ParallelTensorShape const &, - CreateGrad create_gradients, - std::optional sync_type = std::nullopt, - std::optional initializer = std::nullopt); - ParallelTensor(ParallelTensorDims const &, - DataType, - CreateGrad create_gradients, - std::optional sync_type = std::nullopt, - std::optional initializer = std::nullopt); - - ParallelTensorShape get_shape() const; - -public: - ParallelTensorDims dims; - DataType data_type; - std::optional sync_type = std::nullopt; - std::optional initializer = std::nullopt; - CreateGrad create_gradients; -}; - -using ParallelParameter = ParallelTensor; - } // namespace FlexFlow -VISITABLE_STRUCT(::FlexFlow::ParallelTensor, - dims, - data_type, - sync_type, - initializer, - create_gradients); -MAKE_VISIT_HASHABLE(::FlexFlow::ParallelTensor); - namespace FlexFlow { -static_assert(is_well_behaved_value_type::value, ""); +static_assert(is_well_behaved_value_type::value, ""); } #endif diff --git a/lib/pcg/include/pcg/parallel_tensor_attrs.dtg.h b/lib/pcg/include/pcg/parallel_tensor_attrs.dtg.h new file mode 100644 index 0000000000..fa6b153b0a --- /dev/null +++ b/lib/pcg/include/pcg/parallel_tensor_attrs.dtg.h @@ -0,0 +1,66 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/parallel_tensor_attrs.struct.toml +/* proj-data +{ + "generated_from": "b3e086b380bbc41d99332e1463a34b28" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_TENSOR_ATTRS_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_TENSOR_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/param_sync.dtg.h" +#include "pcg/create_grad.dtg.h" +#include "pcg/initializer_attrs.dtg.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct ParallelTensorAttrs { + ParallelTensorAttrs() = delete; + ParallelTensorAttrs( + ::FlexFlow::ParallelTensorShape const &shape, + std::optional<::FlexFlow::ParamSync> const &sync_type, + std::optional<::FlexFlow::InitializerAttrs> const &initializer, + ::FlexFlow::CreateGrad const &create_gradients); + + bool operator==(ParallelTensorAttrs const &) const; + bool operator!=(ParallelTensorAttrs const &) const; + bool operator<(ParallelTensorAttrs const &) const; + bool operator>(ParallelTensorAttrs const &) const; + bool operator<=(ParallelTensorAttrs const &) const; + bool operator>=(ParallelTensorAttrs const &) const; + ::FlexFlow::ParallelTensorShape shape; + std::optional<::FlexFlow::ParamSync> sync_type; + std::optional<::FlexFlow::InitializerAttrs> initializer; + ::FlexFlow::CreateGrad create_gradients; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ParallelTensorAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ParallelTensorAttrs from_json(json const &); + static void to_json(json &, FlexFlow::ParallelTensorAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ParallelTensorAttrs const &); +std::ostream &operator<<(std::ostream &, ParallelTensorAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_TENSOR_ATTRS_DTG_H diff --git a/lib/pcg/include/pcg/parallel_tensor_attrs.struct.toml b/lib/pcg/include/pcg/parallel_tensor_attrs.struct.toml new file mode 100644 index 0000000000..1f81b56ec8 --- /dev/null +++ b/lib/pcg/include/pcg/parallel_tensor_attrs.struct.toml @@ -0,0 +1,34 @@ +namespace = "FlexFlow" +name = "ParallelTensorAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/parallel_tensor_shape.dtg.h", + "op-attrs/param_sync.dtg.h", + "pcg/initializer_attrs.dtg.h", + "pcg/create_grad.dtg.h", + "", +] + +[[fields]] +name = "shape" +type = "::FlexFlow::ParallelTensorShape" + +[[fields]] +name = "sync_type" +type = "std::optional<::FlexFlow::ParamSync>" + +[[fields]] +name = "initializer" +type = "std::optional<::FlexFlow::InitializerAttrs>" + +[[fields]] +name = "create_gradients" +type = "::FlexFlow::CreateGrad" diff --git a/lib/pcg/include/pcg/serialization.h b/lib/pcg/include/pcg/serialization.h deleted file mode 100644 index 28e16aeb1e..0000000000 --- a/lib/pcg/include/pcg/serialization.h +++ /dev/null @@ -1,61 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_SERIALIZATION_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_SERIALIZATION_H - -#include "computation_graph.h" -#include "layer.h" -#include "machine_specification.h" -#include "parallel_computation_graph.h" -#include "parallel_tensor.h" -#include "tensor_mapping.h" -#include "utils/json.h" - -namespace FlexFlow { - -void from_json(json const &, ComputationGraph &); -void to_json(json &, ComputationGraph const &); - -void from_json(json const &, ParallelComputationGraph &); -void to_json(json &, ParallelComputationGraph const &); - -void from_json(json const &, Layer &); -void to_json(json &, Layer const &); - -void from_json(json const &, ParallelTensor &); -void to_json(json &, ParallelTensor const &); - -void from_json(json const &, Tensor &); -void to_json(json &, Tensor const &); - -void from_json(json const &, Initializer &); -void to_json(json &, Initializer const &); - -void from_json(json const &, MachineSpecification &); -void to_json(json &, MachineSpecification const &); - -void from_json(json const &, Operator &); -void to_json(json &, Operator const &); - -void from_json(json const &, MachineView &); -void to_json(json &, MachineView const &); - -void from_json(json const &, StridedRectangle &); -void to_json(json &, StridedRectangle const &); - -void from_json(json const &, StridedRectangleSide &); -void to_json(json &, StridedRectangleSide const &); - -void from_json(json const &, ParallelTensorDims &); -void to_json(json &, ParallelTensorDims const &); - -void from_json(json const &, TensorDims &); -void to_json(json &, TensorDims const &); - -void from_json(json const &, TensorMapping &); -void to_json(json &, TensorMapping const &); - -void from_json(json const &, ParallelTensorShape &); -void to_json(json &, ParallelTensorShape const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/side_size_t.dtg.h b/lib/pcg/include/pcg/side_size_t.dtg.h new file mode 100644 index 0000000000..fce31b1c9d --- /dev/null +++ b/lib/pcg/include/pcg/side_size_t.dtg.h @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/side_size_t.struct.toml +/* proj-data +{ + "generated_from": "6a1669890e547dcc7a4ddb90be05be15" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_SIDE_SIZE_T_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_SIDE_SIZE_T_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct side_size_t { + side_size_t() = delete; + side_size_t(int const &unwrapped); + + bool operator==(side_size_t const &) const; + bool operator!=(side_size_t const &) const; + bool operator<(side_size_t const &) const; + bool operator>(side_size_t const &) const; + bool operator<=(side_size_t const &) const; + bool operator>=(side_size_t const &) const; + int unwrapped; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::side_size_t const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::side_size_t from_json(json const &); + static void to_json(json &, FlexFlow::side_size_t const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(side_size_t const &); +std::ostream &operator<<(std::ostream &, side_size_t const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_SIDE_SIZE_T_DTG_H diff --git a/lib/pcg/include/pcg/side_size_t.struct.toml b/lib/pcg/include/pcg/side_size_t.struct.toml new file mode 100644 index 0000000000..dbaad4fedb --- /dev/null +++ b/lib/pcg/include/pcg/side_size_t.struct.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "side_size_t" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "unwrapped" +type = "int" diff --git a/lib/pcg/include/pcg/strided_rectangle.dtg.h b/lib/pcg/include/pcg/strided_rectangle.dtg.h new file mode 100644 index 0000000000..cacc11093d --- /dev/null +++ b/lib/pcg/include/pcg/strided_rectangle.dtg.h @@ -0,0 +1,57 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/strided_rectangle.struct.toml +/* proj-data +{ + "generated_from": "87af84e6a16d5363049cb9a9a75e4f5f" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/dim_ordered.h" +#include "pcg/strided_rectangle_side.dtg.h" +#include +#include +#include + +namespace FlexFlow { +struct StridedRectangle { + StridedRectangle() = delete; + StridedRectangle( + ::FlexFlow::FFOrdered<::FlexFlow::StridedRectangleSide> const &sides); + + bool operator==(StridedRectangle const &) const; + bool operator!=(StridedRectangle const &) const; + bool operator<(StridedRectangle const &) const; + bool operator>(StridedRectangle const &) const; + bool operator<=(StridedRectangle const &) const; + bool operator>=(StridedRectangle const &) const; + ::FlexFlow::FFOrdered<::FlexFlow::StridedRectangleSide> sides; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::StridedRectangle const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::StridedRectangle from_json(json const &); + static void to_json(json &, FlexFlow::StridedRectangle const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(StridedRectangle const &); +std::ostream &operator<<(std::ostream &, StridedRectangle const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_DTG_H diff --git a/lib/pcg/include/pcg/strided_rectangle.h b/lib/pcg/include/pcg/strided_rectangle.h index d123d7c6ac..48bd4e8146 100644 --- a/lib/pcg/include/pcg/strided_rectangle.h +++ b/lib/pcg/include/pcg/strided_rectangle.h @@ -1,62 +1,15 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_H #define _FLEXFLOW_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_H -#include "op-attrs/dim_ordered.h" -#include "op-attrs/ff_dim.h" -#include "utils/stack_vector.h" -#include "utils/strong_typedef.h" -#include "utils/visitable.h" +#include "op-attrs/ff_dim.dtg.h" +#include "pcg/side_size_t.dtg.h" +#include "pcg/strided_rectangle.dtg.h" namespace FlexFlow { -struct num_points_t : public strong_typedef { - using strong_typedef::strong_typedef; -}; - -struct side_size_t : public strong_typedef { - using strong_typedef::strong_typedef; -}; - -struct StridedRectangleSide { -public: - StridedRectangleSide() = delete; - StridedRectangleSide(num_points_t const &, int stride); - StridedRectangleSide(side_size_t const &, int stride); - - num_points_t get_num_points() const; - side_size_t get_size() const; - int get_stride() const; - - side_size_t at(num_points_t) const; - num_points_t at(side_size_t) const; - -public: - num_points_t num_points; - req stride; -}; - -FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(StridedRectangleSide, - num_points, - stride); - -struct StridedRectangle { -public: - size_t at(FFOrdered const &) const; - StridedRectangleSide at(ff_dim_t const &) const; - size_t num_dims() const; - -public: - FFOrdered sides; -}; - -FF_VISITABLE_STRUCT(StridedRectangle, sides); +size_t get_num_dims(StridedRectangle const &); +StridedRectangleSide get_side_at_idx(StridedRectangle const &, ff_dim_t const &); } // namespace FlexFlow -MAKE_TYPEDEF_HASHABLE(::FlexFlow::num_points_t); -MAKE_TYPEDEF_PRINTABLE(::FlexFlow::num_points_t, "num_points"); - -MAKE_TYPEDEF_HASHABLE(::FlexFlow::side_size_t); -MAKE_TYPEDEF_PRINTABLE(::FlexFlow::side_size_t, "side_size"); - #endif diff --git a/lib/pcg/include/pcg/strided_rectangle.struct.toml b/lib/pcg/include/pcg/strided_rectangle.struct.toml new file mode 100644 index 0000000000..ec9eca9ffa --- /dev/null +++ b/lib/pcg/include/pcg/strided_rectangle.struct.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "StridedRectangle" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "pcg/strided_rectangle_side.dtg.h", + "op-attrs/dim_ordered.h", +] + +[[fields]] +name = "sides" +type = "::FlexFlow::FFOrdered<::FlexFlow::StridedRectangleSide>" diff --git a/lib/pcg/include/pcg/strided_rectangle_side.dtg.h b/lib/pcg/include/pcg/strided_rectangle_side.dtg.h new file mode 100644 index 0000000000..3e4365c24d --- /dev/null +++ b/lib/pcg/include/pcg/strided_rectangle_side.dtg.h @@ -0,0 +1,65 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/strided_rectangle_side.struct.toml +/* proj-data +{ + "generated_from": "b14fcf1e28c262d22b92fac691ede3d4" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_SIDE_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_SIDE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "pcg/num_points_t.dtg.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct StridedRectangleSide { + StridedRectangleSide() = delete; + StridedRectangleSide(::FlexFlow::num_points_t const &num_points, + int const &stride); + + bool operator==(StridedRectangleSide const &) const; + bool operator!=(StridedRectangleSide const &) const; + bool operator<(StridedRectangleSide const &) const; + bool operator>(StridedRectangleSide const &) const; + bool operator<=(StridedRectangleSide const &) const; + bool operator>=(StridedRectangleSide const &) const; + ::FlexFlow::num_points_t num_points; + int stride; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::StridedRectangleSide const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::StridedRectangleSide from_json(json const &); + static void to_json(json &, FlexFlow::StridedRectangleSide const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(StridedRectangleSide const &); +std::ostream &operator<<(std::ostream &, StridedRectangleSide const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_SIDE_DTG_H diff --git a/lib/pcg/include/pcg/strided_rectangle_side.h b/lib/pcg/include/pcg/strided_rectangle_side.h new file mode 100644 index 0000000000..540bb76bc8 --- /dev/null +++ b/lib/pcg/include/pcg/strided_rectangle_side.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_SIDE_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_SIDE_H + +#include "pcg/strided_rectangle_side.dtg.h" +#include "pcg/side_size_t.dtg.h" + +namespace FlexFlow { + +StridedRectangleSide strided_side_from_size_and_stride(side_size_t, int stride); + +side_size_t get_side_size(StridedRectangleSide const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/strided_rectangle_side.struct.toml b/lib/pcg/include/pcg/strided_rectangle_side.struct.toml new file mode 100644 index 0000000000..f26adfafd5 --- /dev/null +++ b/lib/pcg/include/pcg/strided_rectangle_side.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "StridedRectangleSide" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "pcg/num_points_t.dtg.h", +] + +[[fields]] +name = "num_points" +type = "::FlexFlow::num_points_t" + +[[fields]] +name = "stride" +type = "int" diff --git a/lib/pcg/include/pcg/tensor.h b/lib/pcg/include/pcg/tensor.h deleted file mode 100644 index 975a69809d..0000000000 --- a/lib/pcg/include/pcg/tensor.h +++ /dev/null @@ -1,38 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_TENSOR_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_TENSOR_H - -#include "create_grad.h" -#include "initializer.h" -#include "op-attrs/param_sync.h" -#include "op-attrs/tensor_shape.h" - -namespace FlexFlow { - -struct Tensor { - /* Tensor() = delete; */ - /* Tensor(TensorShape const &, */ - /* CreateGrad create_gradients, */ - /* optional initializer = nullopt, */ - /* optional sync_type = nullopt); */ - - size_t get_volume() const; - TensorShape get_shape() const; - int num_dims() const; - - operator TensorShape() const; - -public: - TensorDims dims; - DataType data_type; - std::optional initializer; - bool create_gradients; - req> sync_type; -}; -FF_VISITABLE_STRUCT( - Tensor, dims, data_type, initializer, create_gradients, sync_type); - -using Parameter = Tensor; - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/tensor_attrs.dtg.h b/lib/pcg/include/pcg/tensor_attrs.dtg.h new file mode 100644 index 0000000000..8bc9d3ce9d --- /dev/null +++ b/lib/pcg/include/pcg/tensor_attrs.dtg.h @@ -0,0 +1,64 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/tensor_attrs.struct.toml +/* proj-data +{ + "generated_from": "68447a4357476647ef25dd39dfd12578" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_TENSOR_ATTRS_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_TENSOR_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/param_sync.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" +#include "pcg/initializer_attrs.dtg.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct TensorAttrs { + TensorAttrs() = delete; + TensorAttrs(::FlexFlow::TensorShape const &shape, + std::optional<::FlexFlow::InitializerAttrs> const &initializer, + bool const &create_gradients, + std::optional<::FlexFlow::ParamSync> const &sync_type); + + bool operator==(TensorAttrs const &) const; + bool operator!=(TensorAttrs const &) const; + bool operator<(TensorAttrs const &) const; + bool operator>(TensorAttrs const &) const; + bool operator<=(TensorAttrs const &) const; + bool operator>=(TensorAttrs const &) const; + ::FlexFlow::TensorShape shape; + std::optional<::FlexFlow::InitializerAttrs> initializer; + bool create_gradients; + std::optional<::FlexFlow::ParamSync> sync_type; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::TensorAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::TensorAttrs from_json(json const &); + static void to_json(json &, FlexFlow::TensorAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(TensorAttrs const &); +std::ostream &operator<<(std::ostream &, TensorAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_TENSOR_ATTRS_DTG_H diff --git a/lib/pcg/include/pcg/tensor_attrs.struct.toml b/lib/pcg/include/pcg/tensor_attrs.struct.toml new file mode 100644 index 0000000000..eefb6da702 --- /dev/null +++ b/lib/pcg/include/pcg/tensor_attrs.struct.toml @@ -0,0 +1,33 @@ +namespace = "FlexFlow" +name = "TensorAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/tensor_shape.dtg.h", + "pcg/initializer_attrs.dtg.h", + "op-attrs/param_sync.dtg.h", + "", +] + +[[fields]] +name = "shape" +type = "::FlexFlow::TensorShape" + +[[fields]] +name = "initializer" +type = "std::optional<::FlexFlow::InitializerAttrs>" + +[[fields]] +name = "create_gradients" +type = "bool" + +[[fields]] +name = "sync_type" +type = "std::optional<::FlexFlow::ParamSync>" diff --git a/lib/pcg/include/pcg/tensor_guid_t.dtg.h b/lib/pcg/include/pcg/tensor_guid_t.dtg.h new file mode 100644 index 0000000000..c6109c6103 --- /dev/null +++ b/lib/pcg/include/pcg/tensor_guid_t.dtg.h @@ -0,0 +1,46 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/tensor_guid_t.struct.toml +/* proj-data +{ + "generated_from": "dc15fcbb876ec70509dfa8b662963bc3" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_TENSOR_GUID_T_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_TENSOR_GUID_T_DTG_H + +#include "fmt/format.h" +#include "utils/graph.h" +#include +#include +#include + +namespace FlexFlow { +struct tensor_guid_t { + tensor_guid_t() = delete; + tensor_guid_t(::FlexFlow::MultiDiOutput const &raw_graph_output); + + bool operator==(tensor_guid_t const &) const; + bool operator!=(tensor_guid_t const &) const; + bool operator<(tensor_guid_t const &) const; + bool operator>(tensor_guid_t const &) const; + bool operator<=(tensor_guid_t const &) const; + bool operator>=(tensor_guid_t const &) const; + ::FlexFlow::MultiDiOutput raw_graph_output; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::tensor_guid_t const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(tensor_guid_t const &); +std::ostream &operator<<(std::ostream &, tensor_guid_t const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_TENSOR_GUID_T_DTG_H diff --git a/lib/pcg/include/pcg/tensor_guid_t.h b/lib/pcg/include/pcg/tensor_guid_t.h deleted file mode 100644 index 3e4e840a5f..0000000000 --- a/lib/pcg/include/pcg/tensor_guid_t.h +++ /dev/null @@ -1,17 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_TENSOR_GUID_T_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_TENSOR_GUID_T_H - -#include "utils/graph.h" - -namespace FlexFlow { - -struct tensor_guid_t : strong_typedef { - using strong_typedef::strong_typedef; -}; - -} // namespace FlexFlow - -MAKE_TYPEDEF_PRINTABLE(::FlexFlow::tensor_guid_t, "tensor_guid"); -MAKE_TYPEDEF_HASHABLE(::FlexFlow::tensor_guid_t); - -#endif diff --git a/lib/pcg/include/pcg/tensor_guid_t.struct.toml b/lib/pcg/include/pcg/tensor_guid_t.struct.toml new file mode 100644 index 0000000000..aea4fad108 --- /dev/null +++ b/lib/pcg/include/pcg/tensor_guid_t.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "tensor_guid_t" +features = [ + "eq", + "ord", + "hash", + # "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "utils/graph.h" +] + +[[fields]] +name = "raw_graph_output" +type = "::FlexFlow::MultiDiOutput" diff --git a/lib/pcg/src/device_id.cc b/lib/pcg/src/device_id.cc deleted file mode 100644 index 2849df7c3c..0000000000 --- a/lib/pcg/src/device_id.cc +++ /dev/null @@ -1,19 +0,0 @@ -#include "pcg/device_id.h" -#include "utils/exception.h" -#include - -namespace FlexFlow { - -DeviceType get_device_type(device_id_t const &id) { - if (std::holds_alternative(id)) { - return DeviceType::GPU; - } else { - assert(std::holds_alternative(id)); - return DeviceType::CPU; - } -} - -device_id_t operator+(device_id_t, size_t) { - NOT_IMPLEMENTED(); -} -} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/graphs.cc b/lib/pcg/src/file_format/v1/graphs.cc index d00de7b0c1..69fbb4e88e 100644 --- a/lib/pcg/src/file_format/v1/graphs.cc +++ b/lib/pcg/src/file_format/v1/graphs.cc @@ -3,21 +3,15 @@ namespace FlexFlow { -V1MultiDiGraph to_v1(MultiDiGraphView const &g) { - return to_v1(g, - enumerate(get_nodes(g)).reversed(), - enumerate(get_present_node_ports(g)).reversed()); -} - -V1MultiDiGraph to_v1(MultiDiGraphView const &g, - std::unordered_map const &nodes, - std::unordered_map const &node_ports) { +static V1MultiDiGraph to_v1(MultiDiGraphView const &g, + bidict const &nodes, + bidict const &node_ports) { std::unordered_set edges; for (MultiDiEdge const &e : get_edges(g)) { - edges.insert({nodes.at(e.src), - node_ports.at(e.src_idx), - nodes.at(e.dst), - node_ports.at(e.dst_idx)}); + edges.insert({nodes.at_l(e.src), + node_ports.at_l(e.src_idx), + nodes.at_l(e.dst), + node_ports.at_l(e.dst_idx)}); } return V1MultiDiGraph{ @@ -27,32 +21,40 @@ V1MultiDiGraph to_v1(MultiDiGraphView const &g, }; } +static V1MultiDiGraph to_v1(MultiDiGraphView const &g) { + return to_v1(g, + enumerate(get_nodes(g)).reversed(), + enumerate(get_present_node_ports(g)).reversed()); +} + template -V1JsonableGraph())), - decltype(to_v1(std::declval()))> +static V1JsonableGraph to_v1(OutputLabelledMultiDiGraph const &g) { - using V1NodeLabel = decltype(to_v1(std::declval())); - using V1OutputLabel = decltype(to_v1(std::declval())); - bidict nodes = enumerate(get_nodes(g)); bidict node_ports = enumerate(get_present_node_ports(g)); V1MultiDiGraph unlabelled = to_v1(g, nodes.reversed(), node_ports.reversed()); - std::unordered_map node_labels = - map_values(nodes, [&](Node const &n) { return to_v1(g.at(n)); }); + std::unordered_map node_labels = + map_values(nodes, [&](Node const &n) { return g.at(n); }); + bidict outputs_bidict = enumerate(get_outputs(g)); std::unordered_map outputs = map_values(outputs_bidict, [&](MultiDiOutput const &o) { return V1GraphOutput{nodes.at_r(o.src), node_ports.at_r(o.src_idx)}; }); - std::unordered_map output_labels = map_values( - outputs_bidict, [&](MultiDiOutput const &o) { return to_v1(g.at(o)); }); + + std::unordered_map output_labels = map_values( + outputs_bidict, [&](MultiDiOutput const &o) { return g.at(o); }); return {node_labels, outputs, output_labels, unlabelled}; } V1ComputationGraph to_v1(ComputationGraph const &g) { - return to_v1(g.value()); + return to_v1(g.raw_graph); +} + +V1ParallelComputationGraph to_v1(ParallelComputationGraph const &g) { + return to_v1(g.raw_graph); } } // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/v1.cc b/lib/pcg/src/file_format/v1/v1.cc deleted file mode 100644 index 7715985eed..0000000000 --- a/lib/pcg/src/file_format/v1/v1.cc +++ /dev/null @@ -1,13 +0,0 @@ -#include "pcg/file_format/v1/v1.h" - -namespace FlexFlow { - -V1Tensor to_v1(Tensor const &) { - NOT_IMPLEMENTED(); -} - -V1Layer to_v1(Layer const &) { - NOT_IMPLEMENTED(); -} - -} // namespace FlexFlow diff --git a/lib/pcg/src/layer.cc b/lib/pcg/src/layer.cc deleted file mode 100644 index 00fb07a8c5..0000000000 --- a/lib/pcg/src/layer.cc +++ /dev/null @@ -1,9 +0,0 @@ -#include "pcg/layer.h" - -namespace FlexFlow { - -Layer::Layer(CompGraphOperatorAttrs const &_attrs, - std::optional const &_name) - : attrs(_attrs), name(_name) {} - -} // namespace FlexFlow diff --git a/lib/pcg/src/machine_view.cc b/lib/pcg/src/machine_view.cc deleted file mode 100644 index 46f87833f0..0000000000 --- a/lib/pcg/src/machine_view.cc +++ /dev/null @@ -1,29 +0,0 @@ -#include "pcg/machine_view.h" -#include "utils/utils.h" - -namespace FlexFlow { - -static StridedRectangle make_1d_rect(int start, int stop, int stride) { - assert(stop > start); - assert(stride > 0); - StridedRectangleSide side = {side_size_t(stop - start), stride}; - StridedRectangle rect = {{side}}; - return rect; -} - -MachineView make_1d_machine_view(gpu_id_t start, gpu_id_t stop, int stride) { - StridedRectangle rect = make_1d_rect(start.value(), stop.value(), stride); - return {start, rect}; -} - -MachineView make_1d_machine_view(cpu_id_t start, cpu_id_t stop, int stride) { - StridedRectangle rect = make_1d_rect(start.value(), stop.value(), stride); - return {start, rect}; -} - -device_id_t MachineView::at(FFOrdered const &coord) const { - size_t offset = this->rect.at(coord); - return this->start + offset; -} - -} // namespace FlexFlow diff --git a/lib/pcg/src/operator.cc b/lib/pcg/src/operator.cc deleted file mode 100644 index 9d36ae1b25..0000000000 --- a/lib/pcg/src/operator.cc +++ /dev/null @@ -1,9 +0,0 @@ -#include "pcg/operator.h" - -namespace FlexFlow { - -Operator::operator PCGOperatorAttrs() const { - return attrs; -} - -} // namespace FlexFlow diff --git a/lib/pcg/src/parallel_computation_graph.cc b/lib/pcg/src/parallel_computation_graph.cc deleted file mode 100644 index 011c40eb4c..0000000000 --- a/lib/pcg/src/parallel_computation_graph.cc +++ /dev/null @@ -1,40 +0,0 @@ -#include "pcg/parallel_computation_graph.h" -#include "utils/graph/algorithms.h" - -namespace FlexFlow { - -bool operator==(ParallelComputationGraph const &lhs, - ParallelComputationGraph const &rhs) { - return std::hash{}(lhs) == - std::hash{}(rhs); -} - -} // namespace FlexFlow - -namespace std { - -size_t hash::operator()( - FlexFlow::ParallelComputationGraph const &g) const { - using namespace FlexFlow; - - size_t h = 0; - - std::vector ordered_nodes = get_topological_ordering(g.value()); - hash_combine(h, ordered_nodes.size()); - - std::unordered_map node_index; - for (int i = 0; i < ordered_nodes.size(); ++i) { - node_index[ordered_nodes[i]] = i; - hash_combine(h, g.value().at(ordered_nodes[i])); - } - - for (MultiDiEdge const &edge : get_edges(g.value())) { - hash_combine(h, node_index.at(edge.src)); - hash_combine(h, node_index.at(edge.dst)); - hash_combine(h, g.value().at(edge)); - } - - return h; -} - -} // namespace std diff --git a/lib/pcg/src/parallel_tensor.cc b/lib/pcg/src/parallel_tensor.cc deleted file mode 100644 index ff53e456ec..0000000000 --- a/lib/pcg/src/parallel_tensor.cc +++ /dev/null @@ -1,17 +0,0 @@ -#include "pcg/parallel_tensor.h" - -namespace FlexFlow { - -ParallelTensor::ParallelTensor(ParallelTensorDims const &dims, - DataType data_type, - CreateGrad create_gradients, - std::optional sync_type, - std::optional initializer) - : dims(dims), data_type(data_type), sync_type(sync_type), - initializer(initializer), create_gradients(create_gradients) {} - -ParallelTensorShape ParallelTensor::get_shape() const { - return ParallelTensorShape(dims, data_type); -} - -} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/computation_graph.dtg.cc b/lib/pcg/src/pcg/computation_graph.dtg.cc new file mode 100644 index 0000000000..b9b2ae56ee --- /dev/null +++ b/lib/pcg/src/pcg/computation_graph.dtg.cc @@ -0,0 +1,22 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/computation_graph.struct.toml +/* proj-data +{ + "generated_from": "3639f7e8bb97a5ca2c2ef13caff3c84e" +} +*/ + +#include "pcg/computation_graph.dtg.h" + +#include "pcg/layer_attrs.dtg.h" +#include "pcg/tensor_attrs.dtg.h" +#include "utils/graph.h" + +namespace FlexFlow { +ComputationGraph::ComputationGraph( + ::FlexFlow::OutputLabelledMultiDiGraph<::FlexFlow::LayerAttrs, + ::FlexFlow::TensorAttrs> const + &raw_graph) + : raw_graph(raw_graph) {} +} // namespace FlexFlow diff --git a/lib/pcg/src/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc similarity index 55% rename from lib/pcg/src/computation_graph_builder.cc rename to lib/pcg/src/pcg/computation_graph_builder.cc index 68aeb45ff2..f6336f9510 100644 --- a/lib/pcg/src/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -3,49 +3,55 @@ #include "op-attrs/get_output_shapes.h" #include "utils/expected.h" #include "utils/fmt.h" +#include "op-attrs/ops/element_binary.h" +#include "op-attrs/ops/embedding.h" namespace FlexFlow { -void ComputationGraphBuilder::add_layer(Layer const &layer, - std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs) { +void ComputationGraphBuilder::add_layer(LayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + std::vector const &outputs) { NOT_IMPLEMENTED(); } -Tensor ComputationGraphBuilder::add_layer( - Layer const &layer, - std::vector const &inputs, - std::vector>> const + +tensor_guid_t ComputationGraphBuilder::add_layer( + LayerAttrs const &layer, + std::vector const &inputs, + std::vector>> const &weight_shapes, TensorShape const &output_shape) { NOT_IMPLEMENTED(); } -std::vector ComputationGraphBuilder::add_layer( - Layer const &layer, - std::vector const &inputs, - std::vector>> const + +std::vector ComputationGraphBuilder::add_layer( + LayerAttrs const &layer, + std::vector const &inputs, + std::vector>> const &weight_shapes, std::vector const &output_shapes) { NOT_IMPLEMENTED(); } -Tensor ComputationGraphBuilder::broadcast(Tensor const &, TensorShape const &) { +tensor_guid_t ComputationGraphBuilder::broadcast(tensor_guid_t const &, TensorShape const &) { NOT_IMPLEMENTED(); } -Tensor ComputationGraphBuilder::cast(Tensor const &input, + +tensor_guid_t ComputationGraphBuilder::cast(tensor_guid_t const &input, DataType dtype, std::optional const &name){ NOT_IMPLEMENTED()} -Tensor ComputationGraphBuilder::as_type(Tensor const &x, +tensor_guid_t ComputationGraphBuilder::as_type(tensor_guid_t const &x, DataType data_type, std::string const &name) { - if (x.data_type < data_type) { + DataType x_datatype = this->get_shape(x).data_type; + if (x_datatype < data_type) { return this->cast(x, data_type, name); - } else if (x.data_type > data_type) { + } else if (x_datatype > data_type) { throw mk_runtime_error("Could not convert provided tensor data type {} to " "desired data type {}", - x.data_type, + x_datatype, data_type); } return x; @@ -64,197 +70,204 @@ static std::string get_default_name(std::variant const &attrs) { return get_default_name(widen(attrs)); } -Tensor ComputationGraphBuilder::element_unary( +tensor_guid_t ComputationGraphBuilder::element_unary( ElementUnaryAttrs const &attrs, - Tensor const &x, + tensor_guid_t const &x, std::optional const &maybe_name) { std::string name = maybe_name.value_or(get_default_name(attrs)); - Tensor input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); + tensor_guid_t input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); - Layer layer = {attrs, name}; - TensorShape output_shape = get_output_shape(attrs, input); + LayerAttrs layer = {attrs, name}; + TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); return this->add_layer(layer, {input}, {}, output_shape); } -Tensor ComputationGraphBuilder::element_scalar_unary( +tensor_guid_t ComputationGraphBuilder::element_scalar_unary( ElementScalarUnaryAttrs const &attrs, - Tensor const &x, + tensor_guid_t const &x, std::optional const &maybe_name) { std::string name = maybe_name.value_or(get_default_name(attrs)); - Tensor input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); + tensor_guid_t input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); - Layer layer = {attrs, name}; - TensorShape output_shape = get_output_shape(attrs, input); + LayerAttrs layer = {attrs, name}; + TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); return this->add_layer(layer, {input}, {}, output_shape); } -Tensor ComputationGraphBuilder::element_unary( +tensor_guid_t ComputationGraphBuilder::element_unary( OperatorType op_type, - Tensor const &input, + tensor_guid_t const &input, std::optional const &name) { ElementUnaryAttrs attrs = {op_type}; return this->element_unary(attrs, input, name); } -Tensor ComputationGraphBuilder::element_scalar_unary( +tensor_guid_t ComputationGraphBuilder::element_scalar_unary( OperatorType op_type, - Tensor const &input, + tensor_guid_t const &input, float scalar, std::optional const &name) { ElementScalarUnaryAttrs attrs = {op_type, scalar}; return this->element_scalar_unary(attrs, input, name); } -Tensor ComputationGraphBuilder::element_binary( +tensor_guid_t ComputationGraphBuilder::element_binary( OperatorType op_type, - Tensor const &lhs, - Tensor const &rhs, + tensor_guid_t const &lhs, + tensor_guid_t const &rhs, std::optional const &maybe_name) { std::string name = maybe_name.value_or(get_default_name(op_type)); TensorShape compute_shape = this->get_broadcast_target_shape({lhs, rhs}); - DataType compute_type = std::max(lhs.data_type, rhs.data_type); + DataType compute_type = std::max( + this->get_shape(lhs).data_type, + this->get_shape(rhs).data_type + ); - Tensor const lhs_input = this->as_type(this->broadcast(lhs, compute_shape), + tensor_guid_t lhs_input = this->as_type(this->broadcast(lhs, compute_shape), compute_type, name + "_inputl_pre_cast"); - Tensor const rhs_input = this->as_type(this->broadcast(rhs, compute_shape), + tensor_guid_t rhs_input = this->as_type(this->broadcast(rhs, compute_shape), compute_type, name + "_inputr_pre_cast"); ElementBinaryAttrs attrs = {op_type, compute_type, false, false}; - Layer layer = {attrs, name}; - TensorShape output_shape = get_output_shape(attrs, lhs_input, rhs_input); + LayerAttrs layer = {attrs, name}; + TensorShape output_shape = get_output_shape( + attrs, + this->get_shape(lhs_input), + this->get_shape(rhs_input) + ); return this->add_layer(layer, {lhs_input, rhs_input}, {}, output_shape); } -Tensor ComputationGraphBuilder::exp(Tensor const &input, +tensor_guid_t ComputationGraphBuilder::exp(tensor_guid_t const &input, std::optional const &name) { return this->element_unary(OperatorType::EXP, input, name); } -Tensor ComputationGraphBuilder::add(Tensor const &lhs, - Tensor const &rhs, +tensor_guid_t ComputationGraphBuilder::add(tensor_guid_t const &lhs, + tensor_guid_t const &rhs, std::optional const &name) { return this->element_binary(OperatorType::EW_ADD, lhs, rhs, name); } -Tensor - ComputationGraphBuilder::subtract(Tensor const &lhs, - Tensor const &rhs, +tensor_guid_t + ComputationGraphBuilder::subtract(tensor_guid_t const &lhs, + tensor_guid_t const &rhs, std::optional const &name) { return this->element_binary(OperatorType::EW_SUB, lhs, rhs, name); } -Tensor - ComputationGraphBuilder::multiply(Tensor const &lhs, - Tensor const &rhs, +tensor_guid_t + ComputationGraphBuilder::multiply(tensor_guid_t const &lhs, + tensor_guid_t const &rhs, std::optional const &name) { return this->element_binary(OperatorType::EW_MUL, lhs, rhs, name); } -Tensor ComputationGraphBuilder::divide(Tensor const &lhs, - Tensor const &rhs, +tensor_guid_t ComputationGraphBuilder::divide(tensor_guid_t const &lhs, + tensor_guid_t const &rhs, std::optional const &name) { return this->element_binary(OperatorType::EW_DIV, lhs, rhs, name); } -Tensor ComputationGraphBuilder::max(Tensor const &lhs, - Tensor const &rhs, +tensor_guid_t ComputationGraphBuilder::max(tensor_guid_t const &lhs, + tensor_guid_t const &rhs, std::optional const &name) { return this->element_binary(OperatorType::EW_MAX, lhs, rhs, name); } -Tensor ComputationGraphBuilder::min(Tensor const &lhs, - Tensor const &rhs, +tensor_guid_t ComputationGraphBuilder::min(tensor_guid_t const &lhs, + tensor_guid_t const &rhs, std::optional const &name) { return this->element_binary(OperatorType::EW_MIN, lhs, rhs, name); } -Tensor ComputationGraphBuilder::rsqrt(Tensor const &input, +tensor_guid_t ComputationGraphBuilder::rsqrt(tensor_guid_t const &input, std::optional const &name) { return this->element_unary(OperatorType::RSQRT, input, name); } -Tensor ComputationGraphBuilder::pow(Tensor const &input, +tensor_guid_t ComputationGraphBuilder::pow(tensor_guid_t const &input, float exponent, std::optional const &name) { return this->element_scalar_unary(OperatorType::POW, input, exponent, name); } -Tensor ComputationGraphBuilder::scalar_multiply( - Tensor const &input, float scalar, std::optional const &name) { +tensor_guid_t ComputationGraphBuilder::scalar_multiply( + tensor_guid_t const &input, float scalar, std::optional const &name) { return this->element_scalar_unary(OperatorType::SCALAR_MULTIPLY, input, scalar, name); } -Tensor ComputationGraphBuilder::scalar_add( - Tensor const &input, float scalar, std::optional const &name) { +tensor_guid_t ComputationGraphBuilder::scalar_add( + tensor_guid_t const &input, float scalar, std::optional const &name) { return this->element_scalar_unary(OperatorType::SCALAR_ADD, input, scalar, name); } -Tensor ComputationGraphBuilder::scalar_sub( - Tensor const &lhs, float rhs, std::optional const &name) { +tensor_guid_t ComputationGraphBuilder::scalar_sub( + tensor_guid_t const &lhs, float rhs, std::optional const &name) { return this->element_scalar_unary(OperatorType::SCALAR_SUB, lhs, rhs, name); } -Tensor ComputationGraphBuilder::scalar_truediv( - Tensor const &numerator, +tensor_guid_t ComputationGraphBuilder::scalar_truediv( + tensor_guid_t const &numerator, float denominator, std::optional const &name) { return this->element_scalar_unary( OperatorType::SCALAR_TRUE_DIV, numerator, denominator, name); } -Tensor ComputationGraphBuilder::sin(Tensor const &input, +tensor_guid_t ComputationGraphBuilder::sin(tensor_guid_t const &input, std::optional const &name) { return this->element_unary(OperatorType::SIN, input, name); } -Tensor ComputationGraphBuilder::cos(Tensor const &input, +tensor_guid_t ComputationGraphBuilder::cos(tensor_guid_t const &input, std::optional const &name) { return this->element_unary(OperatorType::COS, input, name); } -Tensor ComputationGraphBuilder::relu(Tensor const &input, +tensor_guid_t ComputationGraphBuilder::relu(tensor_guid_t const &input, std::optional const &name) { return this->element_unary(OperatorType::RELU, input, name); } -Tensor - ComputationGraphBuilder::identity(Tensor const &input, +tensor_guid_t + ComputationGraphBuilder::identity(tensor_guid_t const &input, std::optional const &name) { return this->element_unary(OperatorType::IDENTITY, input, name); } -Tensor ComputationGraphBuilder::gelu(Tensor const &input, +tensor_guid_t ComputationGraphBuilder::gelu(tensor_guid_t const &input, std::optional const &name) { return this->element_unary(OperatorType::GELU, input, name); } -Tensor - ComputationGraphBuilder::sigmoid(Tensor const &input, +tensor_guid_t + ComputationGraphBuilder::sigmoid(tensor_guid_t const &input, std::optional const &name) { return this->element_unary(OperatorType::SIGMOID, input, name); } -Tensor ComputationGraphBuilder::tanh(Tensor const &input, +tensor_guid_t ComputationGraphBuilder::tanh(tensor_guid_t const &input, std::optional const &name) { return this->element_unary(OperatorType::TANH, input, name); } -Tensor ComputationGraphBuilder::elu(Tensor const &input, +tensor_guid_t ComputationGraphBuilder::elu(tensor_guid_t const &input, std::optional const &name) { return this->element_unary(OperatorType::ELU, input, name); } -Tensor ComputationGraphBuilder::conv2d( - Tensor const &x, +tensor_guid_t ComputationGraphBuilder::conv2d( + tensor_guid_t const &x, int outChannels, int kernelH, int kernelW, @@ -265,8 +278,8 @@ Tensor ComputationGraphBuilder::conv2d( std::optional const &activation, int groups, bool use_bias, - std::optional const &kernel_initializer, - std::optional const &bias_initializer, + std::optional const &kernel_initializer, + std::optional const &bias_initializer, std::optional const &kernel_regularizer, std::optional const &maybe_name) { Conv2DAttrs attrs = {outChannels, @@ -281,102 +294,103 @@ Tensor ComputationGraphBuilder::conv2d( use_bias}; std::string name = maybe_name.value_or(get_default_name(attrs)); - Tensor input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); + tensor_guid_t input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); - Layer layer = {attrs, name}; - TensorShape output_shape = get_output_shape(attrs, input); + LayerAttrs layer = {attrs, name}; + TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); - std::vector>> weights; + std::vector>> weights; - weights.push_back({get_kernel_shape(attrs, input), kernel_initializer}); + weights.push_back({get_kernel_shape(attrs, this->get_shape(input)), kernel_initializer}); if (use_bias) { - weights.push_back({get_bias_shape(attrs, input), bias_initializer}); + weights.push_back({get_bias_shape(attrs, this->get_shape(input)), bias_initializer}); } return this->add_layer(layer, {input}, weights, output_shape); } -Tensor ComputationGraphBuilder::dropout( - Tensor const &x, +tensor_guid_t ComputationGraphBuilder::dropout( + tensor_guid_t const &x, float rate, unsigned long long seed, std::optional const &maybe_name) { DropoutAttrs attrs = {rate, seed}; std::string name = maybe_name.value_or(get_default_name(attrs)); - Layer layer = {attrs, name}; - Tensor input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); + LayerAttrs layer = {attrs, name}; + tensor_guid_t input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); - TensorShape output_shape = get_output_shape(attrs, input); + TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); return this->add_layer(layer, {input}, {}, output_shape); } -Tensor ComputationGraphBuilder::embedding( - Tensor const &x, +tensor_guid_t ComputationGraphBuilder::embedding( + tensor_guid_t const &x, int num_entries, int outDim, AggregateOp aggr, DataType dtype, - std::optional const &kernel_initializer, + std::optional const &kernel_initializer, std::optional const &maybe_name) { EmbeddingAttrs attrs = {num_entries, outDim, aggr, dtype}; std::string name = maybe_name.value_or(get_default_name(attrs)); - Layer layer = {attrs, name}; - Tensor input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); + LayerAttrs layer = {attrs, name}; + tensor_guid_t input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); - TensorShape output_shape = get_output_shape(attrs, input); - TensorShape weights_shape = get_weights_shape(attrs, input); + TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); + TensorShape weights_shape = get_weights_shape(attrs, this->get_shape(input)); return this->add_layer( layer, {input}, {{weights_shape, kernel_initializer}}, output_shape); } -std::vector ComputationGraphBuilder::gather( - Tensor const &input, - Tensor const &index, +std::vector ComputationGraphBuilder::gather( + tensor_guid_t const &input, + tensor_guid_t const &index, ff_dim_t dim, std::optional const &maybe_name) { GatherAttrs attrs = {dim}; std::string name = maybe_name.value_or(get_default_name(attrs)); - Layer layer = {attrs, name}; - if (index.data_type != DataType::INT32 && - index.data_type != DataType::INT64) { + LayerAttrs layer = {attrs, name}; + if (this->get_shape(index).data_type != DataType::INT32 && + this->get_shape(index).data_type != DataType::INT64) { throw mk_runtime_error("Invalid data type for input tensor 2 for Gather: " "{} (should be {} or {})", - input.data_type, + this->get_shape(input).data_type, DataType::INT32, DataType::INT64); } std::vector output_shapes = - get_output_shapes(attrs, input, index); + get_output_shapes(attrs, this->get_shape(input), this->get_shape(index)); return this->add_layer(layer, {input}, {}, output_shapes); } -TensorShape get_shape(Tensor const &t) { - return t.get_shape(); -} -std::vector get_shape(std::vector const &) { - NOT_IMPLEMENTED(); +TensorShape ComputationGraphBuilder::get_shape(tensor_guid_t const &t) const { + return this->get_attrs(t).shape; } -// Tensor ComputationGraphBuilder::aggregate( -// Tensor const &gate_preds, -// Tensor const &gate_assign, -// Tensor const &true_gate_assign, -// Tensor const &full_gate_gradients, -// std::vector const &exp_preds, +/* std::vector ComputationGraphBuilder::get_shapes(std::vector const &ts) const { */ +/* return transform(ts, [&](tensor_guid_t const &t) { return this->get_shape(t); }); */ +/* } */ + +// tensor_guid_t ComputationGraphBuilder::aggregate( +// tensor_guid_t const &gate_preds, +// tensor_guid_t const &gate_assign, +// tensor_guid_t const &true_gate_assign, +// tensor_guid_t const &full_gate_gradients, +// std::vector const &exp_preds, // int n, // float lambda_bal, // std::optional const &maybe_name) { // AggregateAttrs attrs = {n, lambda_bal}; // std::string name = maybe_name.value_or(get_default_name(attrs)); -// Layer layer = {attrs, name}; +// LayerAttrs layer = {attrs, name}; // TensorShape output_shape = get_output_shape(attrs, // get_shape(gate_preds), // get_shape(gate_assign), @@ -384,26 +398,31 @@ std::vector get_shape(std::vector const &) { // get_shape(full_gate_gradients), // get_shape(exp_preds)); -// std::vector inputs = { +// std::vector inputs = { // gate_preds, gate_assign, true_gate_assign, full_gate_gradients}; // extend(inputs, exp_preds); // return this->add_layer(layer, inputs, {}, output_shape); // } -Tensor ComputationGraphBuilder::batch_norm( - Tensor const &input, +tensor_guid_t ComputationGraphBuilder::batch_norm( + tensor_guid_t const &input, bool relu, std::optional const &maybe_name) { BatchNormAttrs attrs = BatchNormAttrs{relu}; std::string name = maybe_name.value_or(get_default_name(attrs)); - Layer layer = {attrs, name}; + LayerAttrs layer = {attrs, name}; - TensorShape output_shape = get_output_shape(attrs, get_shape(input)); + TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); return this->add_layer(layer, {input}, {}, output_shape); } +TensorShape ComputationGraphBuilder::get_broadcast_target_shape( + std::vector const &) { + NOT_IMPLEMENTED(); +} + TensorShape ComputationGraphBuilder::get_broadcast_target_shape( std::vector const &) { NOT_IMPLEMENTED(); diff --git a/lib/pcg/src/pcg/cpu_id_t.dtg.cc b/lib/pcg/src/pcg/cpu_id_t.dtg.cc new file mode 100644 index 0000000000..f865442eb0 --- /dev/null +++ b/lib/pcg/src/pcg/cpu_id_t.dtg.cc @@ -0,0 +1,74 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/cpu_id_t.struct.toml +/* proj-data +{ + "generated_from": "a0faf78831febfa3a02929169943d9f5" +} +*/ + +#include "pcg/cpu_id_t.dtg.h" + +#include + +namespace FlexFlow { +cpu_id_t::cpu_id_t(int const &cpu_index) : cpu_index(cpu_index) {} +bool cpu_id_t::operator==(cpu_id_t const &other) const { + return std::tie(this->cpu_index) == std::tie(other.cpu_index); +} +bool cpu_id_t::operator!=(cpu_id_t const &other) const { + return std::tie(this->cpu_index) != std::tie(other.cpu_index); +} +bool cpu_id_t::operator<(cpu_id_t const &other) const { + return std::tie(this->cpu_index) < std::tie(other.cpu_index); +} +bool cpu_id_t::operator>(cpu_id_t const &other) const { + return std::tie(this->cpu_index) > std::tie(other.cpu_index); +} +bool cpu_id_t::operator<=(cpu_id_t const &other) const { + return std::tie(this->cpu_index) <= std::tie(other.cpu_index); +} +bool cpu_id_t::operator>=(cpu_id_t const &other) const { + return std::tie(this->cpu_index) >= std::tie(other.cpu_index); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()(FlexFlow::cpu_id_t const &x) const { + size_t result = 0; + result ^= std::hash{}(x.cpu_index) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::cpu_id_t + adl_serializer::from_json(json const &j) { + return {j.at("cpu_index").template get()}; +} +void adl_serializer::to_json(json &j, + FlexFlow::cpu_id_t const &v) { + j["__type"] = "cpu_id_t"; + j["cpu_index"] = v.cpu_index; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(cpu_id_t const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, cpu_id_t const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/create_grad.dtg.cc b/lib/pcg/src/pcg/create_grad.dtg.cc new file mode 100644 index 0000000000..b2b7e3233b --- /dev/null +++ b/lib/pcg/src/pcg/create_grad.dtg.cc @@ -0,0 +1,70 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/create_grad.enum.toml +/* proj-data +{ + "generated_from": "9fd617027e850b6d6db476a49b3e0334" +} +*/ + +#include "pcg/create_grad.dtg.h" + +#include +#include + +namespace std { +size_t hash::operator()(FlexFlow::CreateGrad x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(CreateGrad x) { + switch (x) { + case CreateGrad::YES: + return "YES"; + case CreateGrad::NO: + return "NO"; + default: + std::ostringstream oss; + oss << "Unknown CreateGrad value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, CreateGrad x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, CreateGrad x) { + switch (x) { + case CreateGrad::YES: + j = "YES"; + break; + case CreateGrad::NO: + j = "NO"; + break; + default: + std::ostringstream oss; + oss << "Unknown CreateGrad value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, CreateGrad &x) { + std::string as_str = j.get(); + if (as_str == "YES") { + x = CreateGrad::YES; + } else if (as_str == "NO") { + x = CreateGrad::NO; + } else { + std::ostringstream oss; + oss << "Unknown CreateGrad value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::element(FlexFlow::CreateGrad::YES, + FlexFlow::CreateGrad::NO); +} +} // namespace rc diff --git a/lib/pcg/src/pcg/device_id.cc b/lib/pcg/src/pcg/device_id.cc new file mode 100644 index 0000000000..35b0c9aeda --- /dev/null +++ b/lib/pcg/src/pcg/device_id.cc @@ -0,0 +1,32 @@ +#include "pcg/device_id.h" +#include "utils/exception.h" +#include + +namespace FlexFlow { + +device_id_t operator+(device_id_t, size_t) { + NOT_IMPLEMENTED(); +} + +DeviceType get_device_type(device_id_t const &device_id) { + if (device_id.has()) { + return DeviceType::GPU; + } else { + assert(device_id.has()); + return DeviceType::CPU; + } +} + +gpu_id_t unwrap_gpu(device_id_t device_id) { + return device_id.get(); +} + +cpu_id_t unwrap_cpu(device_id_t device_id) { + return device_id.get(); +} + +device_id_t device_id_from_index(int, DeviceType) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/device_id_t.dtg.cc b/lib/pcg/src/pcg/device_id_t.dtg.cc new file mode 100644 index 0000000000..517c6c198c --- /dev/null +++ b/lib/pcg/src/pcg/device_id_t.dtg.cc @@ -0,0 +1,103 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/device_id_t.variant.toml +/* proj-data +{ + "generated_from": "85870050c742b0159775399ec2be67e3" +} +*/ + +#include "pcg/device_id_t.dtg.h" + +#include "fmt/format.h" +#include +#include + +namespace FlexFlow { +device_id_t::device_id_t(::FlexFlow::gpu_id_t const &v) : raw_variant(v) {} +device_id_t::device_id_t(::FlexFlow::cpu_id_t const &v) : raw_variant(v) {} +bool device_id_t::operator==(device_id_t const &other) const { + return this->raw_variant == other.raw_variant; +} +bool device_id_t::operator!=(device_id_t const &other) const { + return this->raw_variant != other.raw_variant; +} +bool device_id_t::operator<(device_id_t const &other) const { + return this->raw_variant < other.raw_variant; +} +bool device_id_t::operator>(device_id_t const &other) const { + return this->raw_variant > other.raw_variant; +} +bool device_id_t::operator<=(device_id_t const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool device_id_t::operator>=(device_id_t const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::device_id_t>::operator()( + ::FlexFlow::device_id_t const &x) const { + return std::hash>{}( + x.raw_variant); +} +} // namespace std +namespace nlohmann { +::FlexFlow::device_id_t + adl_serializer<::FlexFlow::device_id_t>::from_json(json const &j) { + std::string key = j.at("type").template get(); + if (key == "gpu") { + return ::FlexFlow::device_id_t{ + j.at("value").template get<::FlexFlow::gpu_id_t>()}; + } else if (key == "cpu") { + return ::FlexFlow::device_id_t{ + j.at("value").template get<::FlexFlow::cpu_id_t>()}; + } else { + throw std::runtime_error(fmt::format("Unknown type key {}", key)); + } +} +void adl_serializer<::FlexFlow::device_id_t>::to_json( + json &j, ::FlexFlow::device_id_t const &x) { + j["__type"] = "device_id_t"; + switch (x.index()) { + case 0: { + j["type"] = "gpu"; + j["value"] = x.get<::FlexFlow::gpu_id_t>(); + break; + } + case 1: { + j["type"] = "cpu"; + j["value"] = x.get<::FlexFlow::cpu_id_t>(); + break; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type device_id_t", x.index())); + } + } +} +} // namespace nlohmann +namespace FlexFlow { +std::string format_as(::FlexFlow::device_id_t const &x) { + std::ostringstream oss; + switch (x.index()) { + case 0: { + oss << ""; + break; + } + case 1: { + oss << ""; + break; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type device_id_t", x.index())); + break; + } + } + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ::FlexFlow::device_id_t const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/device_type.dtg.cc b/lib/pcg/src/pcg/device_type.dtg.cc new file mode 100644 index 0000000000..8279cc4c16 --- /dev/null +++ b/lib/pcg/src/pcg/device_type.dtg.cc @@ -0,0 +1,70 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/device_type.enum.toml +/* proj-data +{ + "generated_from": "cfe4bc5e9f7c5796b9b90b420c33935f" +} +*/ + +#include "pcg/device_type.dtg.h" + +#include +#include + +namespace std { +size_t hash::operator()(FlexFlow::DeviceType x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(DeviceType x) { + switch (x) { + case DeviceType::GPU: + return "GPU"; + case DeviceType::CPU: + return "CPU"; + default: + std::ostringstream oss; + oss << "Unknown DeviceType value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, DeviceType x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, DeviceType x) { + switch (x) { + case DeviceType::GPU: + j = "GPU"; + break; + case DeviceType::CPU: + j = "CPU"; + break; + default: + std::ostringstream oss; + oss << "Unknown DeviceType value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, DeviceType &x) { + std::string as_str = j.get(); + if (as_str == "GPU") { + x = DeviceType::GPU; + } else if (as_str == "CPU") { + x = DeviceType::CPU; + } else { + std::ostringstream oss; + oss << "Unknown DeviceType value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::element(FlexFlow::DeviceType::GPU, + FlexFlow::DeviceType::CPU); +} +} // namespace rc diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_graph_edge.dtg.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_graph_edge.dtg.cc new file mode 100644 index 0000000000..713aa941d2 --- /dev/null +++ b/lib/pcg/src/pcg/file_format/v1/graphs/v1_graph_edge.dtg.cc @@ -0,0 +1,94 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.struct.toml +/* proj-data +{ + "generated_from": "865097b569b831af049343e933834329" +} +*/ + +#include "pcg/file_format/v1/graphs/v1_graph_edge.dtg.h" + +#include + +namespace FlexFlow { +V1GraphEdge::V1GraphEdge(size_t const &srcNode, + size_t const &srcIdx, + size_t const &dstNode, + size_t const &dstIdx) + : srcNode(srcNode), srcIdx(srcIdx), dstNode(dstNode), dstIdx(dstIdx) {} +bool V1GraphEdge::operator==(V1GraphEdge const &other) const { + return std::tie(this->srcNode, this->srcIdx, this->dstNode, this->dstIdx) == + std::tie(other.srcNode, other.srcIdx, other.dstNode, other.dstIdx); +} +bool V1GraphEdge::operator!=(V1GraphEdge const &other) const { + return std::tie(this->srcNode, this->srcIdx, this->dstNode, this->dstIdx) != + std::tie(other.srcNode, other.srcIdx, other.dstNode, other.dstIdx); +} +bool V1GraphEdge::operator<(V1GraphEdge const &other) const { + return std::tie(this->srcNode, this->srcIdx, this->dstNode, this->dstIdx) < + std::tie(other.srcNode, other.srcIdx, other.dstNode, other.dstIdx); +} +bool V1GraphEdge::operator>(V1GraphEdge const &other) const { + return std::tie(this->srcNode, this->srcIdx, this->dstNode, this->dstIdx) > + std::tie(other.srcNode, other.srcIdx, other.dstNode, other.dstIdx); +} +bool V1GraphEdge::operator<=(V1GraphEdge const &other) const { + return std::tie(this->srcNode, this->srcIdx, this->dstNode, this->dstIdx) <= + std::tie(other.srcNode, other.srcIdx, other.dstNode, other.dstIdx); +} +bool V1GraphEdge::operator>=(V1GraphEdge const &other) const { + return std::tie(this->srcNode, this->srcIdx, this->dstNode, this->dstIdx) >= + std::tie(other.srcNode, other.srcIdx, other.dstNode, other.dstIdx); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::V1GraphEdge const &x) const { + size_t result = 0; + result ^= std::hash{}(x.srcNode) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.srcIdx) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.dstNode) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.dstIdx) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::V1GraphEdge + adl_serializer::from_json(json const &j) { + return {j.at("srcNode").template get(), + j.at("srcIdx").template get(), + j.at("dstNode").template get(), + j.at("dstIdx").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::V1GraphEdge const &v) { + j["__type"] = "V1GraphEdge"; + j["srcNode"] = v.srcNode; + j["srcIdx"] = v.srcIdx; + j["dstNode"] = v.dstNode; + j["dstIdx"] = v.dstIdx; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(V1GraphEdge const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, V1GraphEdge const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_graph_output.dtg.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_graph_output.dtg.cc new file mode 100644 index 0000000000..fa0b792a37 --- /dev/null +++ b/lib/pcg/src/pcg/file_format/v1/graphs/v1_graph_output.dtg.cc @@ -0,0 +1,81 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_output.struct.toml +/* proj-data +{ + "generated_from": "05ff8401c3d976ea2220899edb8dfe3a" +} +*/ + +#include "pcg/file_format/v1/graphs/v1_graph_output.dtg.h" + +#include + +namespace FlexFlow { +V1GraphOutput::V1GraphOutput(size_t const &srcNode, size_t const &srcIdx) + : srcNode(srcNode), srcIdx(srcIdx) {} +bool V1GraphOutput::operator==(V1GraphOutput const &other) const { + return std::tie(this->srcNode, this->srcIdx) == + std::tie(other.srcNode, other.srcIdx); +} +bool V1GraphOutput::operator!=(V1GraphOutput const &other) const { + return std::tie(this->srcNode, this->srcIdx) != + std::tie(other.srcNode, other.srcIdx); +} +bool V1GraphOutput::operator<(V1GraphOutput const &other) const { + return std::tie(this->srcNode, this->srcIdx) < + std::tie(other.srcNode, other.srcIdx); +} +bool V1GraphOutput::operator>(V1GraphOutput const &other) const { + return std::tie(this->srcNode, this->srcIdx) > + std::tie(other.srcNode, other.srcIdx); +} +bool V1GraphOutput::operator<=(V1GraphOutput const &other) const { + return std::tie(this->srcNode, this->srcIdx) <= + std::tie(other.srcNode, other.srcIdx); +} +bool V1GraphOutput::operator>=(V1GraphOutput const &other) const { + return std::tie(this->srcNode, this->srcIdx) >= + std::tie(other.srcNode, other.srcIdx); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::V1GraphOutput const &x) const { + size_t result = 0; + result ^= std::hash{}(x.srcNode) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.srcIdx) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::V1GraphOutput + adl_serializer::from_json(json const &j) { + return {j.at("srcNode").template get(), + j.at("srcIdx").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::V1GraphOutput const &v) { + j["__type"] = "V1GraphOutput"; + j["srcNode"] = v.srcNode; + j["srcIdx"] = v.srcIdx; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(V1GraphOutput const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, V1GraphOutput const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.cc new file mode 100644 index 0000000000..7f7e670782 --- /dev/null +++ b/lib/pcg/src/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.cc @@ -0,0 +1,10 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.struct.toml +/* proj-data +{ + "generated_from": "0595a9f5a6bc19f9a170cb0e42c4202d" +} +*/ + +#include "pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.h" diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_multidigraph.dtg.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_multidigraph.dtg.cc new file mode 100644 index 0000000000..0f5a83b02f --- /dev/null +++ b/lib/pcg/src/pcg/file_format/v1/graphs/v1_multidigraph.dtg.cc @@ -0,0 +1,56 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.struct.toml +/* proj-data +{ + "generated_from": "fb1033385645e54a19c9b44cef0be04b" +} +*/ + +#include "pcg/file_format/v1/graphs/v1_multidigraph.dtg.h" + +#include "pcg/file_format/v1/graphs/v1_graph_edge.dtg.h" +#include "utils/fmt.h" +#include +#include +#include + +namespace FlexFlow { +V1MultiDiGraph::V1MultiDiGraph( + std::vector const &nodes, + std::vector const &ports, + std::unordered_set<::FlexFlow::V1GraphEdge> const &edges) + : nodes(nodes), ports(ports), edges(edges) {} +} // namespace FlexFlow + +namespace nlohmann { +FlexFlow::V1MultiDiGraph + adl_serializer::from_json(json const &j) { + return {j.at("nodes").template get>(), + j.at("ports").template get>(), + j.at("edges") + .template get>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::V1MultiDiGraph const &v) { + j["__type"] = "V1MultiDiGraph"; + j["nodes"] = v.nodes; + j["ports"] = v.ports; + j["edges"] = v.edges; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(V1MultiDiGraph const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, V1MultiDiGraph const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/gpu_id_t.dtg.cc b/lib/pcg/src/pcg/gpu_id_t.dtg.cc new file mode 100644 index 0000000000..e2385a83ce --- /dev/null +++ b/lib/pcg/src/pcg/gpu_id_t.dtg.cc @@ -0,0 +1,74 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/gpu_id_t.struct.toml +/* proj-data +{ + "generated_from": "022355e43f43141d332be50ea3080ee2" +} +*/ + +#include "pcg/gpu_id_t.dtg.h" + +#include + +namespace FlexFlow { +gpu_id_t::gpu_id_t(int const &gpu_index) : gpu_index(gpu_index) {} +bool gpu_id_t::operator==(gpu_id_t const &other) const { + return std::tie(this->gpu_index) == std::tie(other.gpu_index); +} +bool gpu_id_t::operator!=(gpu_id_t const &other) const { + return std::tie(this->gpu_index) != std::tie(other.gpu_index); +} +bool gpu_id_t::operator<(gpu_id_t const &other) const { + return std::tie(this->gpu_index) < std::tie(other.gpu_index); +} +bool gpu_id_t::operator>(gpu_id_t const &other) const { + return std::tie(this->gpu_index) > std::tie(other.gpu_index); +} +bool gpu_id_t::operator<=(gpu_id_t const &other) const { + return std::tie(this->gpu_index) <= std::tie(other.gpu_index); +} +bool gpu_id_t::operator>=(gpu_id_t const &other) const { + return std::tie(this->gpu_index) >= std::tie(other.gpu_index); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()(FlexFlow::gpu_id_t const &x) const { + size_t result = 0; + result ^= std::hash{}(x.gpu_index) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::gpu_id_t + adl_serializer::from_json(json const &j) { + return {j.at("gpu_index").template get()}; +} +void adl_serializer::to_json(json &j, + FlexFlow::gpu_id_t const &v) { + j["__type"] = "gpu_id_t"; + j["gpu_index"] = v.gpu_index; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(gpu_id_t const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, gpu_id_t const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/initializer_attrs.dtg.cc b/lib/pcg/src/pcg/initializer_attrs.dtg.cc new file mode 100644 index 0000000000..2a4e97db1e --- /dev/null +++ b/lib/pcg/src/pcg/initializer_attrs.dtg.cc @@ -0,0 +1,158 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/initializer_attrs.variant.toml +/* proj-data +{ + "generated_from": "f66f3a89ea937e96a058d83ab52e2826" +} +*/ + +#include "pcg/initializer_attrs.dtg.h" + +#include "fmt/format.h" +#include +#include + +namespace FlexFlow { +InitializerAttrs::InitializerAttrs(::FlexFlow::GlorotUniformAttrs const &v) + : raw_variant(v) {} +InitializerAttrs::InitializerAttrs(::FlexFlow::ZeroInitializerAttrs const &v) + : raw_variant(v) {} +InitializerAttrs::InitializerAttrs(::FlexFlow::UniformInitializerAttrs const &v) + : raw_variant(v) {} +InitializerAttrs::InitializerAttrs(::FlexFlow::NormInitializerAttrs const &v) + : raw_variant(v) {} +InitializerAttrs::InitializerAttrs( + ::FlexFlow::ConstantInitializerAttrs const &v) + : raw_variant(v) {} +bool InitializerAttrs::operator==(InitializerAttrs const &other) const { + return this->raw_variant == other.raw_variant; +} +bool InitializerAttrs::operator!=(InitializerAttrs const &other) const { + return this->raw_variant != other.raw_variant; +} +bool InitializerAttrs::operator<(InitializerAttrs const &other) const { + return this->raw_variant < other.raw_variant; +} +bool InitializerAttrs::operator>(InitializerAttrs const &other) const { + return this->raw_variant > other.raw_variant; +} +bool InitializerAttrs::operator<=(InitializerAttrs const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool InitializerAttrs::operator>=(InitializerAttrs const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::InitializerAttrs>::operator()( + ::FlexFlow::InitializerAttrs const &x) const { + return std::hash>{}( + x.raw_variant); +} +} // namespace std +namespace nlohmann { +::FlexFlow::InitializerAttrs + adl_serializer<::FlexFlow::InitializerAttrs>::from_json(json const &j) { + std::string key = j.at("type").template get(); + if (key == "glorot_uniform") { + return ::FlexFlow::InitializerAttrs{ + j.at("value").template get<::FlexFlow::GlorotUniformAttrs>()}; + } else if (key == "zero") { + return ::FlexFlow::InitializerAttrs{ + j.at("value").template get<::FlexFlow::ZeroInitializerAttrs>()}; + } else if (key == "uniform") { + return ::FlexFlow::InitializerAttrs{ + j.at("value").template get<::FlexFlow::UniformInitializerAttrs>()}; + } else if (key == "normal") { + return ::FlexFlow::InitializerAttrs{ + j.at("value").template get<::FlexFlow::NormInitializerAttrs>()}; + } else if (key == "constant") { + return ::FlexFlow::InitializerAttrs{ + j.at("value").template get<::FlexFlow::ConstantInitializerAttrs>()}; + } else { + throw std::runtime_error(fmt::format("Unknown type key {}", key)); + } +} +void adl_serializer<::FlexFlow::InitializerAttrs>::to_json( + json &j, ::FlexFlow::InitializerAttrs const &x) { + j["__type"] = "InitializerAttrs"; + switch (x.index()) { + case 0: { + j["type"] = "glorot_uniform"; + j["value"] = x.get<::FlexFlow::GlorotUniformAttrs>(); + break; + } + case 1: { + j["type"] = "zero"; + j["value"] = x.get<::FlexFlow::ZeroInitializerAttrs>(); + break; + } + case 2: { + j["type"] = "uniform"; + j["value"] = x.get<::FlexFlow::UniformInitializerAttrs>(); + break; + } + case 3: { + j["type"] = "normal"; + j["value"] = x.get<::FlexFlow::NormInitializerAttrs>(); + break; + } + case 4: { + j["type"] = "constant"; + j["value"] = x.get<::FlexFlow::ConstantInitializerAttrs>(); + break; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type InitializerAttrs", x.index())); + } + } +} +} // namespace nlohmann +namespace FlexFlow { +std::string format_as(::FlexFlow::InitializerAttrs const &x) { + std::ostringstream oss; + switch (x.index()) { + case 0: { + oss << ""; + break; + } + case 1: { + oss << ""; + break; + } + case 2: { + oss << ""; + break; + } + case 3: { + oss << ""; + break; + } + case 4: { + oss << ""; + break; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type InitializerAttrs", x.index())); + break; + } + } + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::InitializerAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/initializers/constant_initializer_attrs.dtg.cc b/lib/pcg/src/pcg/initializers/constant_initializer_attrs.dtg.cc new file mode 100644 index 0000000000..9770c35248 --- /dev/null +++ b/lib/pcg/src/pcg/initializers/constant_initializer_attrs.dtg.cc @@ -0,0 +1,80 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml +/* proj-data +{ + "generated_from": "0162b9c49fe6cbfc65410c6fa8dec427" +} +*/ + +#include "pcg/initializers/constant_initializer_attrs.dtg.h" + +#include "op-attrs/datatype.h" +#include "utils/json.h" +#include + +namespace FlexFlow { +ConstantInitializerAttrs::ConstantInitializerAttrs( + ::FlexFlow::DataTypeValue const &value) + : value(value) {} +bool ConstantInitializerAttrs::operator==( + ConstantInitializerAttrs const &other) const { + return std::tie(this->value) == std::tie(other.value); +} +bool ConstantInitializerAttrs::operator!=( + ConstantInitializerAttrs const &other) const { + return std::tie(this->value) != std::tie(other.value); +} +bool ConstantInitializerAttrs::operator<( + ConstantInitializerAttrs const &other) const { + return std::tie(this->value) < std::tie(other.value); +} +bool ConstantInitializerAttrs::operator>( + ConstantInitializerAttrs const &other) const { + return std::tie(this->value) > std::tie(other.value); +} +bool ConstantInitializerAttrs::operator<=( + ConstantInitializerAttrs const &other) const { + return std::tie(this->value) <= std::tie(other.value); +} +bool ConstantInitializerAttrs::operator>=( + ConstantInitializerAttrs const &other) const { + return std::tie(this->value) >= std::tie(other.value); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ConstantInitializerAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::DataTypeValue>{}(x.value) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ConstantInitializerAttrs + adl_serializer::from_json( + json const &j) { + return {j.at("value").template get<::FlexFlow::DataTypeValue>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ConstantInitializerAttrs const &v) { + j["__type"] = "ConstantInitializerAttrs"; + j["value"] = v.value; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ConstantInitializerAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ConstantInitializerAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/initializers/glorot_uniform_attrs.dtg.cc b/lib/pcg/src/pcg/initializers/glorot_uniform_attrs.dtg.cc new file mode 100644 index 0000000000..0c8ae6e60c --- /dev/null +++ b/lib/pcg/src/pcg/initializers/glorot_uniform_attrs.dtg.cc @@ -0,0 +1,76 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/initializers/glorot_uniform_attrs.struct.toml +/* proj-data +{ + "generated_from": "a268b411b6d378faa11e60c8517d7be5" +} +*/ + +#include "pcg/initializers/glorot_uniform_attrs.dtg.h" + +#include + +namespace FlexFlow { +GlorotUniformAttrs::GlorotUniformAttrs(int const &seed) : seed(seed) {} +bool GlorotUniformAttrs::operator==(GlorotUniformAttrs const &other) const { + return std::tie(this->seed) == std::tie(other.seed); +} +bool GlorotUniformAttrs::operator!=(GlorotUniformAttrs const &other) const { + return std::tie(this->seed) != std::tie(other.seed); +} +bool GlorotUniformAttrs::operator<(GlorotUniformAttrs const &other) const { + return std::tie(this->seed) < std::tie(other.seed); +} +bool GlorotUniformAttrs::operator>(GlorotUniformAttrs const &other) const { + return std::tie(this->seed) > std::tie(other.seed); +} +bool GlorotUniformAttrs::operator<=(GlorotUniformAttrs const &other) const { + return std::tie(this->seed) <= std::tie(other.seed); +} +bool GlorotUniformAttrs::operator>=(GlorotUniformAttrs const &other) const { + return std::tie(this->seed) >= std::tie(other.seed); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::GlorotUniformAttrs const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.seed) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::GlorotUniformAttrs + adl_serializer::from_json(json const &j) { + return {j.at("seed").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::GlorotUniformAttrs const &v) { + j["__type"] = "GlorotUniformAttrs"; + j["seed"] = v.seed; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(GlorotUniformAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, GlorotUniformAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/initializers/norm_initializer_attrs.dtg.cc b/lib/pcg/src/pcg/initializers/norm_initializer_attrs.dtg.cc new file mode 100644 index 0000000000..aceac12212 --- /dev/null +++ b/lib/pcg/src/pcg/initializers/norm_initializer_attrs.dtg.cc @@ -0,0 +1,96 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/initializers/norm_initializer_attrs.struct.toml +/* proj-data +{ + "generated_from": "6843fc9ca02aea2b40e57dbc497f99ac" +} +*/ + +#include "pcg/initializers/norm_initializer_attrs.dtg.h" + +#include + +namespace FlexFlow { +NormInitializerAttrs::NormInitializerAttrs(int const &seed, + float const &mean, + float const &stddev) + : seed(seed), mean(mean), stddev(stddev) {} +bool NormInitializerAttrs::operator==(NormInitializerAttrs const &other) const { + return std::tie(this->seed, this->mean, this->stddev) == + std::tie(other.seed, other.mean, other.stddev); +} +bool NormInitializerAttrs::operator!=(NormInitializerAttrs const &other) const { + return std::tie(this->seed, this->mean, this->stddev) != + std::tie(other.seed, other.mean, other.stddev); +} +bool NormInitializerAttrs::operator<(NormInitializerAttrs const &other) const { + return std::tie(this->seed, this->mean, this->stddev) < + std::tie(other.seed, other.mean, other.stddev); +} +bool NormInitializerAttrs::operator>(NormInitializerAttrs const &other) const { + return std::tie(this->seed, this->mean, this->stddev) > + std::tie(other.seed, other.mean, other.stddev); +} +bool NormInitializerAttrs::operator<=(NormInitializerAttrs const &other) const { + return std::tie(this->seed, this->mean, this->stddev) <= + std::tie(other.seed, other.mean, other.stddev); +} +bool NormInitializerAttrs::operator>=(NormInitializerAttrs const &other) const { + return std::tie(this->seed, this->mean, this->stddev) >= + std::tie(other.seed, other.mean, other.stddev); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::NormInitializerAttrs const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.seed) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.mean) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.stddev) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::NormInitializerAttrs + adl_serializer::from_json(json const &j) { + return {j.at("seed").template get(), + j.at("mean").template get(), + j.at("stddev").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::NormInitializerAttrs const &v) { + j["__type"] = "NormInitializerAttrs"; + j["seed"] = v.seed; + j["mean"] = v.mean; + j["stddev"] = v.stddev; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary(), gen::arbitrary(), gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(NormInitializerAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, NormInitializerAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/initializers/uniform_initializer_attrs.dtg.cc b/lib/pcg/src/pcg/initializers/uniform_initializer_attrs.dtg.cc new file mode 100644 index 0000000000..a9c62675d0 --- /dev/null +++ b/lib/pcg/src/pcg/initializers/uniform_initializer_attrs.dtg.cc @@ -0,0 +1,95 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/initializers/uniform_initializer_attrs.struct.toml +/* proj-data +{ + "generated_from": "f887e1db5d5dc710793ec5fa99bb7cd4" +} +*/ + +#include "pcg/initializers/uniform_initializer_attrs.dtg.h" + +#include + +namespace FlexFlow { +UniformInitializerAttrs::UniformInitializerAttrs(int const &seed, + float const &min_val, + float const &max_val) + : seed(seed), min_val(min_val), max_val(max_val) {} +bool UniformInitializerAttrs::operator==( + UniformInitializerAttrs const &other) const { + return std::tie(this->seed, this->min_val, this->max_val) == + std::tie(other.seed, other.min_val, other.max_val); +} +bool UniformInitializerAttrs::operator!=( + UniformInitializerAttrs const &other) const { + return std::tie(this->seed, this->min_val, this->max_val) != + std::tie(other.seed, other.min_val, other.max_val); +} +bool UniformInitializerAttrs::operator<( + UniformInitializerAttrs const &other) const { + return std::tie(this->seed, this->min_val, this->max_val) < + std::tie(other.seed, other.min_val, other.max_val); +} +bool UniformInitializerAttrs::operator>( + UniformInitializerAttrs const &other) const { + return std::tie(this->seed, this->min_val, this->max_val) > + std::tie(other.seed, other.min_val, other.max_val); +} +bool UniformInitializerAttrs::operator<=( + UniformInitializerAttrs const &other) const { + return std::tie(this->seed, this->min_val, this->max_val) <= + std::tie(other.seed, other.min_val, other.max_val); +} +bool UniformInitializerAttrs::operator>=( + UniformInitializerAttrs const &other) const { + return std::tie(this->seed, this->min_val, this->max_val) >= + std::tie(other.seed, other.min_val, other.max_val); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::UniformInitializerAttrs const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.seed) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash{}(x.min_val) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.max_val) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::UniformInitializerAttrs + adl_serializer::from_json( + json const &j) { + return {j.at("seed").template get(), + j.at("min_val").template get(), + j.at("max_val").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::UniformInitializerAttrs const &v) { + j["__type"] = "UniformInitializerAttrs"; + j["seed"] = v.seed; + j["min_val"] = v.min_val; + j["max_val"] = v.max_val; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(UniformInitializerAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, UniformInitializerAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/initializers/zero_initializer_attrs.dtg.cc b/lib/pcg/src/pcg/initializers/zero_initializer_attrs.dtg.cc new file mode 100644 index 0000000000..933501a734 --- /dev/null +++ b/lib/pcg/src/pcg/initializers/zero_initializer_attrs.dtg.cc @@ -0,0 +1,71 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/initializers/zero_initializer_attrs.struct.toml +/* proj-data +{ + "generated_from": "a19d5a2cdc67a2840d6ba55250a10411" +} +*/ + +#include "pcg/initializers/zero_initializer_attrs.dtg.h" + +#include + +namespace FlexFlow { +bool ZeroInitializerAttrs::operator==(ZeroInitializerAttrs const &other) const { + return std::tie() == std::tie(); +} +bool ZeroInitializerAttrs::operator!=(ZeroInitializerAttrs const &other) const { + return std::tie() != std::tie(); +} +bool ZeroInitializerAttrs::operator<(ZeroInitializerAttrs const &other) const { + return std::tie() < std::tie(); +} +bool ZeroInitializerAttrs::operator>(ZeroInitializerAttrs const &other) const { + return std::tie() > std::tie(); +} +bool ZeroInitializerAttrs::operator<=(ZeroInitializerAttrs const &other) const { + return std::tie() <= std::tie(); +} +bool ZeroInitializerAttrs::operator>=(ZeroInitializerAttrs const &other) const { + return std::tie() >= std::tie(); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ZeroInitializerAttrs const &x) const { + size_t result = 0; + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ZeroInitializerAttrs + adl_serializer::from_json(json const &j) { + return {}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ZeroInitializerAttrs const &v) { + j["__type"] = "ZeroInitializerAttrs"; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct(); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(ZeroInitializerAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ZeroInitializerAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/layer_attrs.dtg.cc b/lib/pcg/src/pcg/layer_attrs.dtg.cc new file mode 100644 index 0000000000..54fe104ce3 --- /dev/null +++ b/lib/pcg/src/pcg/layer_attrs.dtg.cc @@ -0,0 +1,84 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/layer_attrs.struct.toml +/* proj-data +{ + "generated_from": "12b49c15e8defff5118e5607a7823f59" +} +*/ + +#include "pcg/layer_attrs.dtg.h" + +#include "op-attrs/operator_attrs.h" +#include "utils/json.h" +#include "utils/stack_string.h" +#include +#include + +namespace FlexFlow { +LayerAttrs::LayerAttrs( + ::FlexFlow::CompGraphOperatorAttrs const &attrs, + std::optional<::FlexFlow::stack_string> const &name) + : attrs(attrs), name(name) {} +bool LayerAttrs::operator==(LayerAttrs const &other) const { + return std::tie(this->attrs, this->name) == std::tie(other.attrs, other.name); +} +bool LayerAttrs::operator!=(LayerAttrs const &other) const { + return std::tie(this->attrs, this->name) != std::tie(other.attrs, other.name); +} +bool LayerAttrs::operator<(LayerAttrs const &other) const { + return std::tie(this->attrs, this->name) < std::tie(other.attrs, other.name); +} +bool LayerAttrs::operator>(LayerAttrs const &other) const { + return std::tie(this->attrs, this->name) > std::tie(other.attrs, other.name); +} +bool LayerAttrs::operator<=(LayerAttrs const &other) const { + return std::tie(this->attrs, this->name) <= std::tie(other.attrs, other.name); +} +bool LayerAttrs::operator>=(LayerAttrs const &other) const { + return std::tie(this->attrs, this->name) >= std::tie(other.attrs, other.name); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::LayerAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::CompGraphOperatorAttrs>{}(x.attrs) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash>>{}(x.name) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::LayerAttrs + adl_serializer::from_json(json const &j) { + return { + j.at("attrs").template get<::FlexFlow::CompGraphOperatorAttrs>(), + j.at("name") + .template get>>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::LayerAttrs const &v) { + j["__type"] = "LayerAttrs"; + j["attrs"] = v.attrs; + j["name"] = v.name; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(LayerAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, LayerAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/machine_specification.dtg.cc b/lib/pcg/src/pcg/machine_specification.dtg.cc new file mode 100644 index 0000000000..238c61a014 --- /dev/null +++ b/lib/pcg/src/pcg/machine_specification.dtg.cc @@ -0,0 +1,151 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/machine_specification.struct.toml +/* proj-data +{ + "generated_from": "72c3ae372af189d0c8bae74c2dbbc531" +} +*/ + +#include "pcg/machine_specification.dtg.h" + +#include + +namespace FlexFlow { +MachineSpecification::MachineSpecification(int const &num_nodes, + int const &num_cpus_per_node, + int const &num_gpus_per_node, + float const &inter_node_bandwidth, + float const &intra_node_bandwidth) + : num_nodes(num_nodes), num_cpus_per_node(num_cpus_per_node), + num_gpus_per_node(num_gpus_per_node), + inter_node_bandwidth(inter_node_bandwidth), + intra_node_bandwidth(intra_node_bandwidth) {} +bool MachineSpecification::operator==(MachineSpecification const &other) const { + return std::tie(this->num_nodes, + this->num_cpus_per_node, + this->num_gpus_per_node, + this->inter_node_bandwidth, + this->intra_node_bandwidth) == + std::tie(other.num_nodes, + other.num_cpus_per_node, + other.num_gpus_per_node, + other.inter_node_bandwidth, + other.intra_node_bandwidth); +} +bool MachineSpecification::operator!=(MachineSpecification const &other) const { + return std::tie(this->num_nodes, + this->num_cpus_per_node, + this->num_gpus_per_node, + this->inter_node_bandwidth, + this->intra_node_bandwidth) != + std::tie(other.num_nodes, + other.num_cpus_per_node, + other.num_gpus_per_node, + other.inter_node_bandwidth, + other.intra_node_bandwidth); +} +bool MachineSpecification::operator<(MachineSpecification const &other) const { + return std::tie(this->num_nodes, + this->num_cpus_per_node, + this->num_gpus_per_node, + this->inter_node_bandwidth, + this->intra_node_bandwidth) < + std::tie(other.num_nodes, + other.num_cpus_per_node, + other.num_gpus_per_node, + other.inter_node_bandwidth, + other.intra_node_bandwidth); +} +bool MachineSpecification::operator>(MachineSpecification const &other) const { + return std::tie(this->num_nodes, + this->num_cpus_per_node, + this->num_gpus_per_node, + this->inter_node_bandwidth, + this->intra_node_bandwidth) > + std::tie(other.num_nodes, + other.num_cpus_per_node, + other.num_gpus_per_node, + other.inter_node_bandwidth, + other.intra_node_bandwidth); +} +bool MachineSpecification::operator<=(MachineSpecification const &other) const { + return std::tie(this->num_nodes, + this->num_cpus_per_node, + this->num_gpus_per_node, + this->inter_node_bandwidth, + this->intra_node_bandwidth) <= + std::tie(other.num_nodes, + other.num_cpus_per_node, + other.num_gpus_per_node, + other.inter_node_bandwidth, + other.intra_node_bandwidth); +} +bool MachineSpecification::operator>=(MachineSpecification const &other) const { + return std::tie(this->num_nodes, + this->num_cpus_per_node, + this->num_gpus_per_node, + this->inter_node_bandwidth, + this->intra_node_bandwidth) >= + std::tie(other.num_nodes, + other.num_cpus_per_node, + other.num_gpus_per_node, + other.inter_node_bandwidth, + other.intra_node_bandwidth); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::MachineSpecification const &x) const { + size_t result = 0; + result ^= std::hash{}(x.num_nodes) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.num_cpus_per_node) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.num_gpus_per_node) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.inter_node_bandwidth) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash{}(x.intra_node_bandwidth) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::MachineSpecification + adl_serializer::from_json(json const &j) { + return {j.at("num_nodes").template get(), + j.at("num_cpus_per_node").template get(), + j.at("num_gpus_per_node").template get(), + j.at("inter_node_bandwidth").template get(), + j.at("intra_node_bandwidth").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::MachineSpecification const &v) { + j["__type"] = "MachineSpecification"; + j["num_nodes"] = v.num_nodes; + j["num_cpus_per_node"] = v.num_cpus_per_node; + j["num_gpus_per_node"] = v.num_gpus_per_node; + j["inter_node_bandwidth"] = v.inter_node_bandwidth; + j["intra_node_bandwidth"] = v.intra_node_bandwidth; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(MachineSpecification const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, MachineSpecification const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/machine_view.cc b/lib/pcg/src/pcg/machine_view.cc new file mode 100644 index 0000000000..c181a1ebbc --- /dev/null +++ b/lib/pcg/src/pcg/machine_view.cc @@ -0,0 +1,63 @@ +#include "pcg/machine_view.h" +#include "pcg/strided_rectangle_side.h" +#include "pcg/strided_rectangle.dtg.h" + +namespace FlexFlow { + +std::vector device_ids(MachineView const &) { + NOT_IMPLEMENTED(); +} + +std::size_t num_dims(MachineView const &) { + NOT_IMPLEMENTED(); +} + +std::size_t num_devices(MachineView const &) { + NOT_IMPLEMENTED(); +} + +DeviceType get_device_type(MachineView const &) { + NOT_IMPLEMENTED(); +} + +static StridedRectangle make_1d_rect(int start, int stop, int stride) { + assert(stop > start); + assert(stride > 0); + StridedRectangleSide side = strided_side_from_size_and_stride(side_size_t{stop - start}, stride); + StridedRectangle rect = {{side}}; + return rect; +} + +MachineView make_1d_machine_view(gpu_id_t start, gpu_id_t stop, int stride) { + StridedRectangle rect = make_1d_rect(start.gpu_index, stop.gpu_index, stride); + return {device_id_t{start}, rect}; +} + +MachineView make_1d_machine_view(cpu_id_t start, cpu_id_t stop, int stride) { + StridedRectangle rect = make_1d_rect(start.cpu_index, stop.cpu_index, stride); + return {device_id_t{start}, rect}; +} + +MachineView make_1d_machine_view(device_id_t start, + num_points_t num_points, + int stride) { + NOT_IMPLEMENTED(); +} + +MachineView make_1d_machine_view(device_id_t start, + side_size_t interval_size, + int stride) { + NOT_IMPLEMENTED(); +} + +MachineView make_1d_machine_view(device_id_t start, size_t interval_size) { + NOT_IMPLEMENTED(); +} + +/* device_id_t MachineView::at(FFOrdered const &coord) const { */ +/* size_t offset = this->rect.at(coord); */ +/* return this->start + offset; */ +/* } */ + + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/machine_view.dtg.cc b/lib/pcg/src/pcg/machine_view.dtg.cc new file mode 100644 index 0000000000..edab125e3d --- /dev/null +++ b/lib/pcg/src/pcg/machine_view.dtg.cc @@ -0,0 +1,78 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/machine_view.struct.toml +/* proj-data +{ + "generated_from": "16c571e6bb82d7ef88e5d2a9146638f4" +} +*/ + +#include "pcg/machine_view.dtg.h" + +#include "pcg/device_id_t.dtg.h" +#include "pcg/strided_rectangle.dtg.h" +#include + +namespace FlexFlow { +MachineView::MachineView(::FlexFlow::device_id_t const &start, + ::FlexFlow::StridedRectangle const &rect) + : start(start), rect(rect) {} +bool MachineView::operator==(MachineView const &other) const { + return std::tie(this->start, this->rect) == std::tie(other.start, other.rect); +} +bool MachineView::operator!=(MachineView const &other) const { + return std::tie(this->start, this->rect) != std::tie(other.start, other.rect); +} +bool MachineView::operator<(MachineView const &other) const { + return std::tie(this->start, this->rect) < std::tie(other.start, other.rect); +} +bool MachineView::operator>(MachineView const &other) const { + return std::tie(this->start, this->rect) > std::tie(other.start, other.rect); +} +bool MachineView::operator<=(MachineView const &other) const { + return std::tie(this->start, this->rect) <= std::tie(other.start, other.rect); +} +bool MachineView::operator>=(MachineView const &other) const { + return std::tie(this->start, this->rect) >= std::tie(other.start, other.rect); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::MachineView const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::device_id_t>{}(x.start) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::StridedRectangle>{}(x.rect) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::MachineView + adl_serializer::from_json(json const &j) { + return {j.at("start").template get<::FlexFlow::device_id_t>(), + j.at("rect").template get<::FlexFlow::StridedRectangle>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::MachineView const &v) { + j["__type"] = "MachineView"; + j["start"] = v.start; + j["rect"] = v.rect; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(MachineView const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, MachineView const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/num_points_t.dtg.cc b/lib/pcg/src/pcg/num_points_t.dtg.cc new file mode 100644 index 0000000000..7a0a849814 --- /dev/null +++ b/lib/pcg/src/pcg/num_points_t.dtg.cc @@ -0,0 +1,75 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/num_points_t.struct.toml +/* proj-data +{ + "generated_from": "2a862b92055eda0508447d2f4df52f71" +} +*/ + +#include "pcg/num_points_t.dtg.h" + +#include + +namespace FlexFlow { +num_points_t::num_points_t(int const &unwrapped) : unwrapped(unwrapped) {} +bool num_points_t::operator==(num_points_t const &other) const { + return std::tie(this->unwrapped) == std::tie(other.unwrapped); +} +bool num_points_t::operator!=(num_points_t const &other) const { + return std::tie(this->unwrapped) != std::tie(other.unwrapped); +} +bool num_points_t::operator<(num_points_t const &other) const { + return std::tie(this->unwrapped) < std::tie(other.unwrapped); +} +bool num_points_t::operator>(num_points_t const &other) const { + return std::tie(this->unwrapped) > std::tie(other.unwrapped); +} +bool num_points_t::operator<=(num_points_t const &other) const { + return std::tie(this->unwrapped) <= std::tie(other.unwrapped); +} +bool num_points_t::operator>=(num_points_t const &other) const { + return std::tie(this->unwrapped) >= std::tie(other.unwrapped); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::num_points_t const &x) const { + size_t result = 0; + result ^= std::hash{}(x.unwrapped) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::num_points_t + adl_serializer::from_json(json const &j) { + return {j.at("unwrapped").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::num_points_t const &v) { + j["__type"] = "num_points_t"; + j["unwrapped"] = v.unwrapped; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(num_points_t const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, num_points_t const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/operator_guid_t.dtg.cc b/lib/pcg/src/pcg/operator_guid_t.dtg.cc new file mode 100644 index 0000000000..46b031f7e1 --- /dev/null +++ b/lib/pcg/src/pcg/operator_guid_t.dtg.cc @@ -0,0 +1,59 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/operator_guid_t.struct.toml +/* proj-data +{ + "generated_from": "348b5a610f4ff6f545884564ee9a1e6a" +} +*/ + +#include "pcg/operator_guid_t.dtg.h" + +#include "utils/graph.h" +#include + +namespace FlexFlow { +operator_guid_t::operator_guid_t(::FlexFlow::Node const &raw_graph_node) + : raw_graph_node(raw_graph_node) {} +bool operator_guid_t::operator==(operator_guid_t const &other) const { + return std::tie(this->raw_graph_node) == std::tie(other.raw_graph_node); +} +bool operator_guid_t::operator!=(operator_guid_t const &other) const { + return std::tie(this->raw_graph_node) != std::tie(other.raw_graph_node); +} +bool operator_guid_t::operator<(operator_guid_t const &other) const { + return std::tie(this->raw_graph_node) < std::tie(other.raw_graph_node); +} +bool operator_guid_t::operator>(operator_guid_t const &other) const { + return std::tie(this->raw_graph_node) > std::tie(other.raw_graph_node); +} +bool operator_guid_t::operator<=(operator_guid_t const &other) const { + return std::tie(this->raw_graph_node) <= std::tie(other.raw_graph_node); +} +bool operator_guid_t::operator>=(operator_guid_t const &other) const { + return std::tie(this->raw_graph_node) >= std::tie(other.raw_graph_node); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::operator_guid_t const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::Node>{}(x.raw_graph_node) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(operator_guid_t const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, operator_guid_t const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/optimizers/adam_optimizer_attrs.dtg.cc b/lib/pcg/src/pcg/optimizers/adam_optimizer_attrs.dtg.cc new file mode 100644 index 0000000000..d362459cc3 --- /dev/null +++ b/lib/pcg/src/pcg/optimizers/adam_optimizer_attrs.dtg.cc @@ -0,0 +1,192 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.struct.toml +/* proj-data +{ + "generated_from": "f49e1bebcb0ef2bc3c210073e3183d4d" +} +*/ + +#include "pcg/optimizers/adam_optimizer_attrs.dtg.h" + +#include + +namespace FlexFlow { +AdamOptimizerAttrs::AdamOptimizerAttrs(double const &alpha, + double const &beta1, + double const &beta2, + double const &weight_decay, + double const &alpha_t, + double const &beta_t, + double const &beta2_t) + : alpha(alpha), beta1(beta1), beta2(beta2), weight_decay(weight_decay), + alpha_t(alpha_t), beta_t(beta_t), beta2_t(beta2_t) {} +bool AdamOptimizerAttrs::operator==(AdamOptimizerAttrs const &other) const { + return std::tie(this->alpha, + this->beta1, + this->beta2, + this->weight_decay, + this->alpha_t, + this->beta_t, + this->beta2_t) == std::tie(other.alpha, + other.beta1, + other.beta2, + other.weight_decay, + other.alpha_t, + other.beta_t, + other.beta2_t); +} +bool AdamOptimizerAttrs::operator!=(AdamOptimizerAttrs const &other) const { + return std::tie(this->alpha, + this->beta1, + this->beta2, + this->weight_decay, + this->alpha_t, + this->beta_t, + this->beta2_t) != std::tie(other.alpha, + other.beta1, + other.beta2, + other.weight_decay, + other.alpha_t, + other.beta_t, + other.beta2_t); +} +bool AdamOptimizerAttrs::operator<(AdamOptimizerAttrs const &other) const { + return std::tie(this->alpha, + this->beta1, + this->beta2, + this->weight_decay, + this->alpha_t, + this->beta_t, + this->beta2_t) < std::tie(other.alpha, + other.beta1, + other.beta2, + other.weight_decay, + other.alpha_t, + other.beta_t, + other.beta2_t); +} +bool AdamOptimizerAttrs::operator>(AdamOptimizerAttrs const &other) const { + return std::tie(this->alpha, + this->beta1, + this->beta2, + this->weight_decay, + this->alpha_t, + this->beta_t, + this->beta2_t) > std::tie(other.alpha, + other.beta1, + other.beta2, + other.weight_decay, + other.alpha_t, + other.beta_t, + other.beta2_t); +} +bool AdamOptimizerAttrs::operator<=(AdamOptimizerAttrs const &other) const { + return std::tie(this->alpha, + this->beta1, + this->beta2, + this->weight_decay, + this->alpha_t, + this->beta_t, + this->beta2_t) <= std::tie(other.alpha, + other.beta1, + other.beta2, + other.weight_decay, + other.alpha_t, + other.beta_t, + other.beta2_t); +} +bool AdamOptimizerAttrs::operator>=(AdamOptimizerAttrs const &other) const { + return std::tie(this->alpha, + this->beta1, + this->beta2, + this->weight_decay, + this->alpha_t, + this->beta_t, + this->beta2_t) >= std::tie(other.alpha, + other.beta1, + other.beta2, + other.weight_decay, + other.alpha_t, + other.beta_t, + other.beta2_t); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::AdamOptimizerAttrs const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.alpha) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.beta1) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.beta2) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash{}(x.weight_decay) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.alpha_t) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.beta_t) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.beta2_t) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::AdamOptimizerAttrs + adl_serializer::from_json(json const &j) { + return {j.at("alpha").template get(), + j.at("beta1").template get(), + j.at("beta2").template get(), + j.at("weight_decay").template get(), + j.at("alpha_t").template get(), + j.at("beta_t").template get(), + j.at("beta2_t").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::AdamOptimizerAttrs const &v) { + j["__type"] = "AdamOptimizerAttrs"; + j["alpha"] = v.alpha; + j["beta1"] = v.beta1; + j["beta2"] = v.beta2; + j["weight_decay"] = v.weight_decay; + j["alpha_t"] = v.alpha_t; + j["beta_t"] = v.beta_t; + j["beta2_t"] = v.beta2_t; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(AdamOptimizerAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, AdamOptimizerAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/optimizers/sgd_optimizer_attrs.dtg.cc b/lib/pcg/src/pcg/optimizers/sgd_optimizer_attrs.dtg.cc new file mode 100644 index 0000000000..d5e668917b --- /dev/null +++ b/lib/pcg/src/pcg/optimizers/sgd_optimizer_attrs.dtg.cc @@ -0,0 +1,111 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/optimizers/sgd_optimizer_attrs.struct.toml +/* proj-data +{ + "generated_from": "d18c91cdddc760f1fb3990d2c817ee87" +} +*/ + +#include "pcg/optimizers/sgd_optimizer_attrs.dtg.h" + +#include + +namespace FlexFlow { +SGDOptimizerAttrs::SGDOptimizerAttrs(double const &lr, + double const &momentum, + bool const &nesterov, + double const &weight_decay) + : lr(lr), momentum(momentum), nesterov(nesterov), + weight_decay(weight_decay) {} +bool SGDOptimizerAttrs::operator==(SGDOptimizerAttrs const &other) const { + return std::tie( + this->lr, this->momentum, this->nesterov, this->weight_decay) == + std::tie(other.lr, other.momentum, other.nesterov, other.weight_decay); +} +bool SGDOptimizerAttrs::operator!=(SGDOptimizerAttrs const &other) const { + return std::tie( + this->lr, this->momentum, this->nesterov, this->weight_decay) != + std::tie(other.lr, other.momentum, other.nesterov, other.weight_decay); +} +bool SGDOptimizerAttrs::operator<(SGDOptimizerAttrs const &other) const { + return std::tie( + this->lr, this->momentum, this->nesterov, this->weight_decay) < + std::tie(other.lr, other.momentum, other.nesterov, other.weight_decay); +} +bool SGDOptimizerAttrs::operator>(SGDOptimizerAttrs const &other) const { + return std::tie( + this->lr, this->momentum, this->nesterov, this->weight_decay) > + std::tie(other.lr, other.momentum, other.nesterov, other.weight_decay); +} +bool SGDOptimizerAttrs::operator<=(SGDOptimizerAttrs const &other) const { + return std::tie( + this->lr, this->momentum, this->nesterov, this->weight_decay) <= + std::tie(other.lr, other.momentum, other.nesterov, other.weight_decay); +} +bool SGDOptimizerAttrs::operator>=(SGDOptimizerAttrs const &other) const { + return std::tie( + this->lr, this->momentum, this->nesterov, this->weight_decay) >= + std::tie(other.lr, other.momentum, other.nesterov, other.weight_decay); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::SGDOptimizerAttrs const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.lr) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash{}(x.momentum) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.nesterov) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.weight_decay) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::SGDOptimizerAttrs + adl_serializer::from_json(json const &j) { + return {j.at("lr").template get(), + j.at("momentum").template get(), + j.at("nesterov").template get(), + j.at("weight_decay").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::SGDOptimizerAttrs const &v) { + j["__type"] = "SGDOptimizerAttrs"; + j["lr"] = v.lr; + j["momentum"] = v.momentum; + j["nesterov"] = v.nesterov; + j["weight_decay"] = v.weight_decay; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(SGDOptimizerAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, SGDOptimizerAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/parallel_computation_graph.dtg.cc b/lib/pcg/src/pcg/parallel_computation_graph.dtg.cc new file mode 100644 index 0000000000..18549b43a2 --- /dev/null +++ b/lib/pcg/src/pcg/parallel_computation_graph.dtg.cc @@ -0,0 +1,22 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/parallel_computation_graph.struct.toml +/* proj-data +{ + "generated_from": "3bb0791e3481298ddea75f4bd134f9e1" +} +*/ + +#include "pcg/parallel_computation_graph.dtg.h" + +#include "pcg/parallel_layer_attrs.dtg.h" +#include "pcg/parallel_tensor_attrs.dtg.h" +#include "utils/graph.h" + +namespace FlexFlow { +ParallelComputationGraph::ParallelComputationGraph( + ::FlexFlow::OutputLabelledMultiDiGraph< + ::FlexFlow::ParallelLayerAttrs, + ::FlexFlow::ParallelTensorAttrs> const &raw_graph) + : raw_graph(raw_graph) {} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/parallel_layer_attrs.dtg.cc b/lib/pcg/src/pcg/parallel_layer_attrs.dtg.cc new file mode 100644 index 0000000000..455fb22baf --- /dev/null +++ b/lib/pcg/src/pcg/parallel_layer_attrs.dtg.cc @@ -0,0 +1,83 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/parallel_layer_attrs.struct.toml +/* proj-data +{ + "generated_from": "97fa0b11c59ae892a8a530ffd67e33ad" +} +*/ + +#include "pcg/parallel_layer_attrs.dtg.h" + +#include "op-attrs/operator_attrs.h" +#include "utils/stack_string.h" +#include +#include + +namespace FlexFlow { +ParallelLayerAttrs::ParallelLayerAttrs( + ::FlexFlow::PCGOperatorAttrs const &attrs, + std::optional<::FlexFlow::stack_string> const &name) + : attrs(attrs), name(name) {} +bool ParallelLayerAttrs::operator==(ParallelLayerAttrs const &other) const { + return std::tie(this->attrs, this->name) == std::tie(other.attrs, other.name); +} +bool ParallelLayerAttrs::operator!=(ParallelLayerAttrs const &other) const { + return std::tie(this->attrs, this->name) != std::tie(other.attrs, other.name); +} +bool ParallelLayerAttrs::operator<(ParallelLayerAttrs const &other) const { + return std::tie(this->attrs, this->name) < std::tie(other.attrs, other.name); +} +bool ParallelLayerAttrs::operator>(ParallelLayerAttrs const &other) const { + return std::tie(this->attrs, this->name) > std::tie(other.attrs, other.name); +} +bool ParallelLayerAttrs::operator<=(ParallelLayerAttrs const &other) const { + return std::tie(this->attrs, this->name) <= std::tie(other.attrs, other.name); +} +bool ParallelLayerAttrs::operator>=(ParallelLayerAttrs const &other) const { + return std::tie(this->attrs, this->name) >= std::tie(other.attrs, other.name); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ParallelLayerAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::PCGOperatorAttrs>{}(x.attrs) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= + std::hash>>{}(x.name) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ParallelLayerAttrs + adl_serializer::from_json(json const &j) { + return { + j.at("attrs").template get<::FlexFlow::PCGOperatorAttrs>(), + j.at("name") + .template get>>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ParallelLayerAttrs const &v) { + j["__type"] = "ParallelLayerAttrs"; + j["attrs"] = v.attrs; + j["name"] = v.name; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ParallelLayerAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ParallelLayerAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/parallel_tensor_attrs.dtg.cc b/lib/pcg/src/pcg/parallel_tensor_attrs.dtg.cc new file mode 100644 index 0000000000..ae5d618172 --- /dev/null +++ b/lib/pcg/src/pcg/parallel_tensor_attrs.dtg.cc @@ -0,0 +1,134 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/parallel_tensor_attrs.struct.toml +/* proj-data +{ + "generated_from": "b3e086b380bbc41d99332e1463a34b28" +} +*/ + +#include "pcg/parallel_tensor_attrs.dtg.h" + +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/param_sync.dtg.h" +#include "pcg/create_grad.dtg.h" +#include "pcg/initializer_attrs.dtg.h" +#include +#include + +namespace FlexFlow { +ParallelTensorAttrs::ParallelTensorAttrs( + ::FlexFlow::ParallelTensorShape const &shape, + std::optional<::FlexFlow::ParamSync> const &sync_type, + std::optional<::FlexFlow::InitializerAttrs> const &initializer, + ::FlexFlow::CreateGrad const &create_gradients) + : shape(shape), sync_type(sync_type), initializer(initializer), + create_gradients(create_gradients) {} +bool ParallelTensorAttrs::operator==(ParallelTensorAttrs const &other) const { + return std::tie(this->shape, + this->sync_type, + this->initializer, + this->create_gradients) == std::tie(other.shape, + other.sync_type, + other.initializer, + other.create_gradients); +} +bool ParallelTensorAttrs::operator!=(ParallelTensorAttrs const &other) const { + return std::tie(this->shape, + this->sync_type, + this->initializer, + this->create_gradients) != std::tie(other.shape, + other.sync_type, + other.initializer, + other.create_gradients); +} +bool ParallelTensorAttrs::operator<(ParallelTensorAttrs const &other) const { + return std::tie(this->shape, + this->sync_type, + this->initializer, + this->create_gradients) < std::tie(other.shape, + other.sync_type, + other.initializer, + other.create_gradients); +} +bool ParallelTensorAttrs::operator>(ParallelTensorAttrs const &other) const { + return std::tie(this->shape, + this->sync_type, + this->initializer, + this->create_gradients) > std::tie(other.shape, + other.sync_type, + other.initializer, + other.create_gradients); +} +bool ParallelTensorAttrs::operator<=(ParallelTensorAttrs const &other) const { + return std::tie(this->shape, + this->sync_type, + this->initializer, + this->create_gradients) <= std::tie(other.shape, + other.sync_type, + other.initializer, + other.create_gradients); +} +bool ParallelTensorAttrs::operator>=(ParallelTensorAttrs const &other) const { + return std::tie(this->shape, + this->sync_type, + this->initializer, + this->create_gradients) >= std::tie(other.shape, + other.sync_type, + other.initializer, + other.create_gradients); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ParallelTensorAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ParallelTensorShape>{}(x.shape) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash>{}(x.sync_type) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash>{}(x.initializer) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::CreateGrad>{}(x.create_gradients) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ParallelTensorAttrs + adl_serializer::from_json(json const &j) { + return { + j.at("shape").template get<::FlexFlow::ParallelTensorShape>(), + j.at("sync_type").template get>(), + j.at("initializer") + .template get>(), + j.at("create_gradients").template get<::FlexFlow::CreateGrad>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ParallelTensorAttrs const &v) { + j["__type"] = "ParallelTensorAttrs"; + j["shape"] = v.shape; + j["sync_type"] = v.sync_type; + j["initializer"] = v.initializer; + j["create_gradients"] = v.create_gradients; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ParallelTensorAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ParallelTensorAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/side_size_t.dtg.cc b/lib/pcg/src/pcg/side_size_t.dtg.cc new file mode 100644 index 0000000000..54db2974fe --- /dev/null +++ b/lib/pcg/src/pcg/side_size_t.dtg.cc @@ -0,0 +1,75 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/side_size_t.struct.toml +/* proj-data +{ + "generated_from": "6a1669890e547dcc7a4ddb90be05be15" +} +*/ + +#include "pcg/side_size_t.dtg.h" + +#include + +namespace FlexFlow { +side_size_t::side_size_t(int const &unwrapped) : unwrapped(unwrapped) {} +bool side_size_t::operator==(side_size_t const &other) const { + return std::tie(this->unwrapped) == std::tie(other.unwrapped); +} +bool side_size_t::operator!=(side_size_t const &other) const { + return std::tie(this->unwrapped) != std::tie(other.unwrapped); +} +bool side_size_t::operator<(side_size_t const &other) const { + return std::tie(this->unwrapped) < std::tie(other.unwrapped); +} +bool side_size_t::operator>(side_size_t const &other) const { + return std::tie(this->unwrapped) > std::tie(other.unwrapped); +} +bool side_size_t::operator<=(side_size_t const &other) const { + return std::tie(this->unwrapped) <= std::tie(other.unwrapped); +} +bool side_size_t::operator>=(side_size_t const &other) const { + return std::tie(this->unwrapped) >= std::tie(other.unwrapped); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::side_size_t const &x) const { + size_t result = 0; + result ^= std::hash{}(x.unwrapped) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::side_size_t + adl_serializer::from_json(json const &j) { + return {j.at("unwrapped").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::side_size_t const &v) { + j["__type"] = "side_size_t"; + j["unwrapped"] = v.unwrapped; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(side_size_t const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, side_size_t const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/strided_rectangle.dtg.cc b/lib/pcg/src/pcg/strided_rectangle.dtg.cc new file mode 100644 index 0000000000..d9cb72a882 --- /dev/null +++ b/lib/pcg/src/pcg/strided_rectangle.dtg.cc @@ -0,0 +1,77 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/strided_rectangle.struct.toml +/* proj-data +{ + "generated_from": "87af84e6a16d5363049cb9a9a75e4f5f" +} +*/ + +#include "pcg/strided_rectangle.dtg.h" + +#include "op-attrs/dim_ordered.h" +#include "pcg/strided_rectangle_side.dtg.h" +#include + +namespace FlexFlow { +StridedRectangle::StridedRectangle( + ::FlexFlow::FFOrdered<::FlexFlow::StridedRectangleSide> const &sides) + : sides(sides) {} +bool StridedRectangle::operator==(StridedRectangle const &other) const { + return std::tie(this->sides) == std::tie(other.sides); +} +bool StridedRectangle::operator!=(StridedRectangle const &other) const { + return std::tie(this->sides) != std::tie(other.sides); +} +bool StridedRectangle::operator<(StridedRectangle const &other) const { + return std::tie(this->sides) < std::tie(other.sides); +} +bool StridedRectangle::operator>(StridedRectangle const &other) const { + return std::tie(this->sides) > std::tie(other.sides); +} +bool StridedRectangle::operator<=(StridedRectangle const &other) const { + return std::tie(this->sides) <= std::tie(other.sides); +} +bool StridedRectangle::operator>=(StridedRectangle const &other) const { + return std::tie(this->sides) >= std::tie(other.sides); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::StridedRectangle const &x) const { + size_t result = 0; + result ^= + std::hash<::FlexFlow::FFOrdered<::FlexFlow::StridedRectangleSide>>{}( + x.sides) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::StridedRectangle + adl_serializer::from_json(json const &j) { + return {j.at("sides") + .template get< + ::FlexFlow::FFOrdered<::FlexFlow::StridedRectangleSide>>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::StridedRectangle const &v) { + j["__type"] = "StridedRectangle"; + j["sides"] = v.sides; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(StridedRectangle const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, StridedRectangle const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/strided_rectangle_side.cc b/lib/pcg/src/pcg/strided_rectangle_side.cc new file mode 100644 index 0000000000..fad022a65b --- /dev/null +++ b/lib/pcg/src/pcg/strided_rectangle_side.cc @@ -0,0 +1,14 @@ +#include "pcg/strided_rectangle_side.h" +#include "utils/exception.h" + +namespace FlexFlow { + +StridedRectangleSide strided_side_from_size_and_stride(side_size_t, int stride) { + NOT_IMPLEMENTED(); +} + +side_size_t get_side_size(StridedRectangleSide const &s) { + return s.num_points.unwrapped * s.stride; +} + +} diff --git a/lib/pcg/src/pcg/strided_rectangle_side.dtg.cc b/lib/pcg/src/pcg/strided_rectangle_side.dtg.cc new file mode 100644 index 0000000000..0bb31b0496 --- /dev/null +++ b/lib/pcg/src/pcg/strided_rectangle_side.dtg.cc @@ -0,0 +1,91 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/strided_rectangle_side.struct.toml +/* proj-data +{ + "generated_from": "b14fcf1e28c262d22b92fac691ede3d4" +} +*/ + +#include "pcg/strided_rectangle_side.dtg.h" + +#include "pcg/num_points_t.dtg.h" +#include + +namespace FlexFlow { +StridedRectangleSide::StridedRectangleSide( + ::FlexFlow::num_points_t const &num_points, int const &stride) + : num_points(num_points), stride(stride) {} +bool StridedRectangleSide::operator==(StridedRectangleSide const &other) const { + return std::tie(this->num_points, this->stride) == + std::tie(other.num_points, other.stride); +} +bool StridedRectangleSide::operator!=(StridedRectangleSide const &other) const { + return std::tie(this->num_points, this->stride) != + std::tie(other.num_points, other.stride); +} +bool StridedRectangleSide::operator<(StridedRectangleSide const &other) const { + return std::tie(this->num_points, this->stride) < + std::tie(other.num_points, other.stride); +} +bool StridedRectangleSide::operator>(StridedRectangleSide const &other) const { + return std::tie(this->num_points, this->stride) > + std::tie(other.num_points, other.stride); +} +bool StridedRectangleSide::operator<=(StridedRectangleSide const &other) const { + return std::tie(this->num_points, this->stride) <= + std::tie(other.num_points, other.stride); +} +bool StridedRectangleSide::operator>=(StridedRectangleSide const &other) const { + return std::tie(this->num_points, this->stride) >= + std::tie(other.num_points, other.stride); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::StridedRectangleSide const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::num_points_t>{}(x.num_points) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.stride) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::StridedRectangleSide + adl_serializer::from_json(json const &j) { + return {j.at("num_points").template get<::FlexFlow::num_points_t>(), + j.at("stride").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::StridedRectangleSide const &v) { + j["__type"] = "StridedRectangleSide"; + j["num_points"] = v.num_points; + j["stride"] = v.stride; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::num_points_t>(), gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(StridedRectangleSide const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, StridedRectangleSide const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/tensor_attrs.dtg.cc b/lib/pcg/src/pcg/tensor_attrs.dtg.cc new file mode 100644 index 0000000000..46a6fb8d50 --- /dev/null +++ b/lib/pcg/src/pcg/tensor_attrs.dtg.cc @@ -0,0 +1,133 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/tensor_attrs.struct.toml +/* proj-data +{ + "generated_from": "68447a4357476647ef25dd39dfd12578" +} +*/ + +#include "pcg/tensor_attrs.dtg.h" + +#include "op-attrs/param_sync.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" +#include "pcg/initializer_attrs.dtg.h" +#include +#include + +namespace FlexFlow { +TensorAttrs::TensorAttrs( + ::FlexFlow::TensorShape const &shape, + std::optional<::FlexFlow::InitializerAttrs> const &initializer, + bool const &create_gradients, + std::optional<::FlexFlow::ParamSync> const &sync_type) + : shape(shape), initializer(initializer), + create_gradients(create_gradients), sync_type(sync_type) {} +bool TensorAttrs::operator==(TensorAttrs const &other) const { + return std::tie(this->shape, + this->initializer, + this->create_gradients, + this->sync_type) == std::tie(other.shape, + other.initializer, + other.create_gradients, + other.sync_type); +} +bool TensorAttrs::operator!=(TensorAttrs const &other) const { + return std::tie(this->shape, + this->initializer, + this->create_gradients, + this->sync_type) != std::tie(other.shape, + other.initializer, + other.create_gradients, + other.sync_type); +} +bool TensorAttrs::operator<(TensorAttrs const &other) const { + return std::tie(this->shape, + this->initializer, + this->create_gradients, + this->sync_type) < std::tie(other.shape, + other.initializer, + other.create_gradients, + other.sync_type); +} +bool TensorAttrs::operator>(TensorAttrs const &other) const { + return std::tie(this->shape, + this->initializer, + this->create_gradients, + this->sync_type) > std::tie(other.shape, + other.initializer, + other.create_gradients, + other.sync_type); +} +bool TensorAttrs::operator<=(TensorAttrs const &other) const { + return std::tie(this->shape, + this->initializer, + this->create_gradients, + this->sync_type) <= std::tie(other.shape, + other.initializer, + other.create_gradients, + other.sync_type); +} +bool TensorAttrs::operator>=(TensorAttrs const &other) const { + return std::tie(this->shape, + this->initializer, + this->create_gradients, + this->sync_type) >= std::tie(other.shape, + other.initializer, + other.create_gradients, + other.sync_type); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::TensorAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::TensorShape>{}(x.shape) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= + std::hash>{}(x.initializer) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash{}(x.create_gradients) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash>{}(x.sync_type) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::TensorAttrs + adl_serializer::from_json(json const &j) { + return { + j.at("shape").template get<::FlexFlow::TensorShape>(), + j.at("initializer") + .template get>(), + j.at("create_gradients").template get(), + j.at("sync_type").template get>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::TensorAttrs const &v) { + j["__type"] = "TensorAttrs"; + j["shape"] = v.shape; + j["initializer"] = v.initializer; + j["create_gradients"] = v.create_gradients; + j["sync_type"] = v.sync_type; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(TensorAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, TensorAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/tensor_guid_t.dtg.cc b/lib/pcg/src/pcg/tensor_guid_t.dtg.cc new file mode 100644 index 0000000000..9d57291112 --- /dev/null +++ b/lib/pcg/src/pcg/tensor_guid_t.dtg.cc @@ -0,0 +1,59 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/tensor_guid_t.struct.toml +/* proj-data +{ + "generated_from": "dc15fcbb876ec70509dfa8b662963bc3" +} +*/ + +#include "pcg/tensor_guid_t.dtg.h" + +#include "utils/graph.h" +#include + +namespace FlexFlow { +tensor_guid_t::tensor_guid_t(::FlexFlow::MultiDiOutput const &raw_graph_output) + : raw_graph_output(raw_graph_output) {} +bool tensor_guid_t::operator==(tensor_guid_t const &other) const { + return std::tie(this->raw_graph_output) == std::tie(other.raw_graph_output); +} +bool tensor_guid_t::operator!=(tensor_guid_t const &other) const { + return std::tie(this->raw_graph_output) != std::tie(other.raw_graph_output); +} +bool tensor_guid_t::operator<(tensor_guid_t const &other) const { + return std::tie(this->raw_graph_output) < std::tie(other.raw_graph_output); +} +bool tensor_guid_t::operator>(tensor_guid_t const &other) const { + return std::tie(this->raw_graph_output) > std::tie(other.raw_graph_output); +} +bool tensor_guid_t::operator<=(tensor_guid_t const &other) const { + return std::tie(this->raw_graph_output) <= std::tie(other.raw_graph_output); +} +bool tensor_guid_t::operator>=(tensor_guid_t const &other) const { + return std::tie(this->raw_graph_output) >= std::tie(other.raw_graph_output); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::tensor_guid_t const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::MultiDiOutput>{}(x.raw_graph_output) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(tensor_guid_t const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, tensor_guid_t const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/serialization.cc b/lib/pcg/src/serialization.cc deleted file mode 100644 index 439c03916d..0000000000 --- a/lib/pcg/src/serialization.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "pcg/serialization.h" - -namespace FlexFlow {} diff --git a/lib/pcg/src/strided_rectangle.cc b/lib/pcg/src/strided_rectangle.cc index 27ef9a7f5b..9c8ff69b42 100644 --- a/lib/pcg/src/strided_rectangle.cc +++ b/lib/pcg/src/strided_rectangle.cc @@ -3,34 +3,23 @@ namespace FlexFlow { -size_t StridedRectangle::at(FFOrdered const &coord) const { - assert(coord.size() == this->num_dims()); - - size_t _1d_stride = 1; - size_t idx = 0; - for (auto dim : inner_to_outer_idxs(this->sides)) { - idx += this->sides.at(dim).at(coord.at(dim)).value() * _1d_stride; - _1d_stride *= this->sides.at(dim).get_size().value(); - } - return idx; -} - -StridedRectangleSide::StridedRectangleSide(side_size_t const &num, int stride) - : num_points(num.value()), stride(stride) {} - -side_size_t StridedRectangleSide::at(num_points_t) const { - NOT_IMPLEMENTED(); -} - -num_points_t StridedRectangleSide::at(side_size_t) const { - NOT_IMPLEMENTED(); -} - -side_size_t StridedRectangleSide::get_size() const { +/* size_t StridedRectangle::at(FFOrdered const &coord) const { */ +/* assert(coord.size() == this->num_dims()); */ + +/* size_t _1d_stride = 1; */ +/* size_t idx = 0; */ +/* for (auto dim : inner_to_outer_idxs(this->sides)) { */ +/* idx += this->sides.at(dim).at(coord.at(dim)).value() * _1d_stride; */ +/* _1d_stride *= this->sides.at(dim).get_size().value(); */ +/* } */ +/* return idx; */ +/* } */ + +size_t get_num_dims(StridedRectangle const &) { NOT_IMPLEMENTED(); } -size_t StridedRectangle::num_dims() const { +size_t get_side_at_idx(StridedRectangle const &) { NOT_IMPLEMENTED(); } diff --git a/lib/pcg/src/tensor.cc b/lib/pcg/src/tensor.cc deleted file mode 100644 index a5aa4b0d0c..0000000000 --- a/lib/pcg/src/tensor.cc +++ /dev/null @@ -1,13 +0,0 @@ -#include "pcg/tensor.h" - -namespace FlexFlow { - -Tensor::operator TensorShape() const { - return TensorShape{dims, data_type}; -} - -TensorShape Tensor::get_shape() const { - return TensorShape(*this); -} - -} // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/attribute_expr.h b/lib/substitutions/include/substitutions/attribute_expr.h deleted file mode 100644 index 0afd48b431..0000000000 --- a/lib/substitutions/include/substitutions/attribute_expr.h +++ /dev/null @@ -1,40 +0,0 @@ -#ifndef _FLEXFLOW_SUBSTITUTIONS_CONSTRAINT_H -#define _FLEXFLOW_SUBSTITUTIONS_CONSTRAINT_H - -#include "utils/variant.h" - -namespace FlexFlow { - -enum class ConstraintType { EQUAL }; - -template -struct ListIndexAccess { - T attribute_key; - req index; -}; - -template -struct ListSize { - req attribute_key; -}; - -template -using AttributeExpr = std::variant, ListSize>; - -template -struct AttributeConstraint { - ConstraintType constraint_type; - AttributeExpr attribute_expr; - V attribute_value; -}; - -template -struct AttributePattern { - std::vector> attribute_constraints; - // TODO: Revert to unordered_set once we have visitable for templates - // std::unordered_set> attribute_constraints; -}; - -} // namespace FlexFlow - -#endif diff --git a/lib/substitutions/include/substitutions/constraint_type.dtg.h b/lib/substitutions/include/substitutions/constraint_type.dtg.h new file mode 100644 index 0000000000..f99b794e46 --- /dev/null +++ b/lib/substitutions/include/substitutions/constraint_type.dtg.h @@ -0,0 +1,40 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/constraint_type.enum.toml +/* proj-data +{ + "generated_from": "06b029d76658cb434abf08b1fdb86137" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_CONSTRAINT_TYPE_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_CONSTRAINT_TYPE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +enum class ConstraintType { EQUAL }; +std::string format_as(ConstraintType); +std::ostream &operator<<(std::ostream &, ConstraintType); +void to_json(::nlohmann::json &, ConstraintType); +void from_json(::nlohmann::json const &, ConstraintType &); +} // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ConstraintType) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_CONSTRAINT_TYPE_DTG_H diff --git a/lib/substitutions/include/substitutions/constraint_type.enum.toml b/lib/substitutions/include/substitutions/constraint_type.enum.toml new file mode 100644 index 0000000000..8646ba1c83 --- /dev/null +++ b/lib/substitutions/include/substitutions/constraint_type.enum.toml @@ -0,0 +1,11 @@ +namespace = "FlexFlow" +name = "ConstraintType" +features = [ + "json", + "hash", + "rapidcheck", + "fmt", +] + +[[values]] +name = "EQUAL" diff --git a/lib/substitutions/include/substitutions/graph_pattern.h b/lib/substitutions/include/substitutions/graph_pattern.h index 4f4021203b..533109387b 100644 --- a/lib/substitutions/include/substitutions/graph_pattern.h +++ b/lib/substitutions/include/substitutions/graph_pattern.h @@ -1,32 +1,22 @@ #ifndef _FLEXFLOW_SUBSTITUTIONS_SUBSTITUTIONS_H #define _FLEXFLOW_SUBSTITUTIONS_SUBSTITUTIONS_H -#include "graph_pattern_match.h" -#include "operator_pattern.h" -#include "parallel_tensor_pattern.h" -#include "sub_parallel_computation_graph.h" +#include "substitutions/sub_parallel_computation_graph.dtg.h" +#include "substitutions/pcg_pattern.dtg.h" +#include "substitutions/unlabelled/pattern_edge.dtg.h" +#include "substitutions/unlabelled/pattern_node.dtg.h" +#include "substitutions/unlabelled/pattern_matching.h" +#include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" namespace FlexFlow { -struct GraphPattern - : public strong_typedef< - GraphPattern, - OutputLabelledOpenMultiDiGraph> { - using strong_typedef::strong_typedef; -}; +UnlabelledGraphPattern get_unlabelled_pattern(PCGPattern const &); -GraphSplit split_pattern(OpenMultiDiGraphView const &pattern); - -bool is_singleton_pattern(OpenMultiDiGraphView const &); - -bool operator_satisfies(Operator const ¶ms, OperatorPattern const &pattern); - -bool parallel_tensor_satisfies(ParallelTensor const ¶ms, - ParallelTensorPattern const &pattern); +TensorAttributePattern get_tensor_pattern(PCGPattern const &, PatternEdge const &); +OperatorAttributePattern get_operator_pattern(PCGPattern const &, PatternNode const &); bool assignment_satisfies(SubParallelComputationGraph const &, - GraphPattern const &, + PCGPattern const &, MultiDiGraphPatternMatch const &); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/graph_pattern_match.h b/lib/substitutions/include/substitutions/graph_pattern_match.h deleted file mode 100644 index bf6d6b6921..0000000000 --- a/lib/substitutions/include/substitutions/graph_pattern_match.h +++ /dev/null @@ -1,42 +0,0 @@ -#ifndef _FLEXFLOW_SUBSTITUTIONS_GRAPH_PATTERN_MATCH_H -#define _FLEXFLOW_SUBSTITUTIONS_GRAPH_PATTERN_MATCH_H - -#include "utils/graph.h" -#include "utils/visitable.h" - -namespace FlexFlow { - -struct MultiDiGraphPatternMatch { - using PatternNode = Node; - using PCGNode = Node; - using PatternEdge = OpenMultiDiEdge; - using PCGEdge = OpenMultiDiEdge; - - bidict node_assignment; - bidict edge_assignment; -}; - -struct MatchSplit { - MultiDiGraphPatternMatch prefix_submatch; - MultiDiGraphPatternMatch postfix_submatch; -}; - -struct MatchAdditionalCriterion { - std::function node_criterion; - std::function - edge_criterion; -}; - -bool pattern_matches(OpenMultiDiGraphView const &pattern, - OpenMultiDiGraphView const &graph, - MultiDiGraphPatternMatch const &match, - MatchAdditionalCriterion const &additional_criterion); - -std::vector - find_pattern_matches(OpenMultiDiGraphView const &pattern, - OpenMultiDiGraphView const &graph, - MatchAdditionalCriterion const &additional_criterion); - -} // namespace FlexFlow - -#endif diff --git a/lib/substitutions/include/substitutions/operator_pattern.h b/lib/substitutions/include/substitutions/operator_pattern.h deleted file mode 100644 index 5f2be36a09..0000000000 --- a/lib/substitutions/include/substitutions/operator_pattern.h +++ /dev/null @@ -1,107 +0,0 @@ -#ifndef _FLEXFLOW_SUBSTITUTIONS_OPERATOR_PATTERN_H -#define _FLEXFLOW_SUBSTITUTIONS_OPERATOR_PATTERN_H - -#include "attribute_expr.h" -#include "op-attrs/activation.h" -#include "op-attrs/datatype.h" -#include "op-attrs/operator_type.h" -#include "pcg/operator.h" -#include -#include - -namespace FlexFlow { - -enum class OperatorAttributeKey { - OP_TYPE, // AnyOp - USE_BIAS, - GROUPS, - POOL_TYPE, - KERNEL_H, - KERNEL_W, - DATA_TYPE, - SCALAR, - STRIDE_H, - STRIDE_W, - PADDING_H, - PADDING_W, - AGGR, - NUM_ENTRIES, - OUT_CHANNELS, - ACTIVATION, - NUMDIM, - AXIS, - PERMUTATION, - OUTSHUFFLE, - MERGE_GCONV_COUNT, - AXES, - KEEP_DIMS, - EPSILON, - PARALLEL_OP_DIM, - PARALLEL_OP_DEGREE, - SOFTMAX_DIM, - NUM_HEADS, - PARALLEL_DIM, - PARALLEL_DEGREE, - PAD, - EMBED_DIM, - KDIM, - VDIM, - DROPOUT, - BIAS, - ADD_BIAS_KV, - ADD_ZERO_ATTN, - A_SEQ_LENGTH_DIM, - B_SEQ_LENGTH_DIM, - RELU, - TARGET_DIMS, - RATE, - SEED, - SHOULD_BROADCAST_LHS, - SHOULD_BROADCAST_RHS, - DIM, - ELEMENTWISE_AFFINE, - REGULARIZER, - SHAPE, - SPLITS, - K, - SORTED, - COMBINE_DIM, - COMBINE_DEGREE, - NUM_INPUTS -}; - -using OperatorAttributeValue = - std::variant, - stack_vector, - OperatorType, - Activation, - ff_dim_t, - unsigned long long, - AggregateOp, - stack_vector, - std::optional, - PoolOp, - TensorShape, - DataType>; - -FF_VISITABLE_STRUCT(ListIndexAccess, - attribute_key, - index); -FF_VISITABLE_STRUCT(ListSize, attribute_key); - -using OperatorAttributeConstraint = - AttributeConstraint; - -using OperatorPattern = - AttributePattern; - -std::optional - evaluate_attribute_expr(Operator const &attrs, - AttributeExpr const &expr); - -} // namespace FlexFlow - -#endif diff --git a/lib/substitutions/include/substitutions/operator_pattern/eval_list_access.h b/lib/substitutions/include/substitutions/operator_pattern/eval_list_access.h new file mode 100644 index 0000000000..777f38edea --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/eval_list_access.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_LIST_ACCESS_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_LIST_ACCESS_H + +#include "substitutions/operator_pattern/operator_attribute_list_access.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_value.dtg.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" +#include + +namespace FlexFlow { + +std::optional eval_list_access(PCGOperatorAttrs const &attrs, OperatorAttributeListIndexAccess const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/operator_pattern/eval_list_size.h b/lib/substitutions/include/substitutions/operator_pattern/eval_list_size.h new file mode 100644 index 0000000000..337799955b --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/eval_list_size.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_LIST_SIZE_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_LIST_SIZE_H + +#include "substitutions/operator_pattern/operator_attribute_list_size.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_value.dtg.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" + +namespace FlexFlow { + +std::optional eval_list_size(PCGOperatorAttrs const &attrs, OperatorAttributeListSize const &acc); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/get_attribute.h b/lib/substitutions/include/substitutions/operator_pattern/get_attribute.h similarity index 83% rename from lib/substitutions/include/substitutions/get_attribute.h rename to lib/substitutions/include/substitutions/operator_pattern/get_attribute.h index 0e6dd4c69b..93f4a2bc0f 100644 --- a/lib/substitutions/include/substitutions/get_attribute.h +++ b/lib/substitutions/include/substitutions/operator_pattern/get_attribute.h @@ -1,9 +1,10 @@ #ifndef _FLEXFLOW_SUBSTITUTIONS_GET_ATTRIBUTES_H #define _FLEXFLOW_SUBSTITUTIONS_GET_ATTRIBUTES_H -#include "op-attrs/operator_attrs.h" -#include "operator_pattern.h" -#include "utils/optional.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_key.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_value.dtg.h" +#include namespace FlexFlow { @@ -11,6 +12,8 @@ std::optional get_attribute(PCGOperatorAttrs const &, OperatorAttributeKey); std::optional get_attribute(BatchMatmulAttrs const &p, OperatorAttributeKey); +std::optional get_attribute(BatchNormAttrs const &p, + OperatorAttributeKey); std::optional get_attribute(CastAttrs const &p, OperatorAttributeKey); std::optional get_attribute(CombineAttrs const &p, @@ -33,12 +36,17 @@ std::optional get_attribute(FlatAttrs const &p, OperatorAttributeKey); std::optional get_attribute(GatherAttrs const &p, OperatorAttributeKey); +std::optional get_attribute(InputAttrs const &p, + OperatorAttributeKey); std::optional get_attribute(LayerNormAttrs const &p, OperatorAttributeKey); std::optional get_attribute(LinearAttrs const &p, OperatorAttributeKey); std::optional get_attribute(MultiHeadAttentionAttrs const &p, OperatorAttributeKey); + +std::optional get_attribute(NoopAttrs const &p, + OperatorAttributeKey); std::optional get_attribute(Pool2DAttrs const &p, OperatorAttributeKey); std::optional get_attribute(ReduceAttrs const &p, @@ -51,6 +59,8 @@ std::optional get_attribute(ReplicateAttrs const &p, OperatorAttributeKey); std::optional get_attribute(ReshapeAttrs const &p, OperatorAttributeKey); +std::optional get_attribute(ReverseAttrs const &p, + OperatorAttributeKey); std::optional get_attribute(SplitAttrs const &p, OperatorAttributeKey); std::optional get_attribute(SoftmaxAttrs const &p, diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.dtg.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.dtg.h new file mode 100644 index 0000000000..35ec9e499f --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.dtg.h @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.struct.toml +/* proj-data +{ + "generated_from": "7867bd0f403866c13417171bb5ec364c" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_CONSTRAINT_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_CONSTRAINT_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "substitutions/constraint_type.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_expr.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_value.dtg.h" +#include +#include +#include + +namespace FlexFlow { +struct OperatorAttributeConstraint { + OperatorAttributeConstraint() = delete; + OperatorAttributeConstraint( + ::FlexFlow::ConstraintType const &constraint_type, + ::FlexFlow::OperatorAttributeExpr const &attribute_expr, + ::FlexFlow::OperatorAttributeValue const &attribute_value); + + bool operator==(OperatorAttributeConstraint const &) const; + bool operator!=(OperatorAttributeConstraint const &) const; + bool operator<(OperatorAttributeConstraint const &) const; + bool operator>(OperatorAttributeConstraint const &) const; + bool operator<=(OperatorAttributeConstraint const &) const; + bool operator>=(OperatorAttributeConstraint const &) const; + ::FlexFlow::ConstraintType constraint_type; + ::FlexFlow::OperatorAttributeExpr attribute_expr; + ::FlexFlow::OperatorAttributeValue attribute_value; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::OperatorAttributeConstraint const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::OperatorAttributeConstraint from_json(json const &); + static void to_json(json &, FlexFlow::OperatorAttributeConstraint const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(OperatorAttributeConstraint const &); +std::ostream &operator<<(std::ostream &, OperatorAttributeConstraint const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_CONSTRAINT_DTG_H diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.struct.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.struct.toml new file mode 100644 index 0000000000..646faf878e --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.struct.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "OperatorAttributeConstraint" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "substitutions/constraint_type.dtg.h", + "substitutions/operator_pattern/operator_attribute_expr.dtg.h", + "substitutions/operator_pattern/operator_attribute_value.dtg.h", +] + +[[fields]] +name = "constraint_type" +type = "::FlexFlow::ConstraintType" + +[[fields]] +name = "attribute_expr" +type = "::FlexFlow::OperatorAttributeExpr" + +[[fields]] +name = "attribute_value" +type = "::FlexFlow::OperatorAttributeValue" diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.dtg.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.dtg.h new file mode 100644 index 0000000000..a66a035ba8 --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.dtg.h @@ -0,0 +1,143 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.variant.toml +/* proj-data +{ + "generated_from": "15d26dd1f08092ecc82b725aa9411597" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_EXPR_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_EXPR_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "substitutions/operator_pattern/operator_attribute_key.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_list_access.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_list_size.dtg.h" +#include +#include +#include +#include +#include + +namespace FlexFlow { +struct OperatorAttributeExpr { + OperatorAttributeExpr() = delete; + explicit OperatorAttributeExpr(::FlexFlow::OperatorAttributeKey const &); + explicit OperatorAttributeExpr(::FlexFlow::OperatorAttributeListSize const &); + explicit OperatorAttributeExpr( + ::FlexFlow::OperatorAttributeListIndexAccess const &); + template + static constexpr bool IsPartOfOperatorAttributeExpr_v = + std::is_same_v || + std::is_same_v || + std::is_same_v; + template + ReturnType visit(Visitor &&v) const { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::OperatorAttributeKey>()); + return result; + } + case 1: { + ReturnType result = + v(this->get<::FlexFlow::OperatorAttributeListSize>()); + return result; + } + case 2: { + ReturnType result = + v(this->get<::FlexFlow::OperatorAttributeListIndexAccess>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type OperatorAttributeExpr", this->index())); + } + } + } + template + ReturnType visit(Visitor &&v) { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::OperatorAttributeKey>()); + return result; + } + case 1: { + ReturnType result = + v(this->get<::FlexFlow::OperatorAttributeListSize>()); + return result; + } + case 2: { + ReturnType result = + v(this->get<::FlexFlow::OperatorAttributeListIndexAccess>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type OperatorAttributeExpr", this->index())); + } + } + } + template + bool has() const { + static_assert(IsPartOfOperatorAttributeExpr_v, + "OperatorAttributeExpr::has() expected one of " + "[::FlexFlow::OperatorAttributeKey, " + "::FlexFlow::OperatorAttributeListSize, " + "::FlexFlow::OperatorAttributeListIndexAccess], received T"); + return std::holds_alternative(this->raw_variant); + } + template + T const &get() const { + static_assert(IsPartOfOperatorAttributeExpr_v, + "OperatorAttributeExpr::get() expected one of " + "[::FlexFlow::OperatorAttributeKey, " + "::FlexFlow::OperatorAttributeListSize, " + "::FlexFlow::OperatorAttributeListIndexAccess], received T"); + return std::get(this->raw_variant); + } + template + T &get() { + static_assert(IsPartOfOperatorAttributeExpr_v, + "OperatorAttributeExpr::get() expected one of " + "[::FlexFlow::OperatorAttributeKey, " + "::FlexFlow::OperatorAttributeListSize, " + "::FlexFlow::OperatorAttributeListIndexAccess], received T"); + return std::get(this->raw_variant); + } + size_t index() const { + return this->raw_variant.index(); + } + bool operator==(OperatorAttributeExpr const &) const; + bool operator!=(OperatorAttributeExpr const &) const; + bool operator<(OperatorAttributeExpr const &) const; + bool operator>(OperatorAttributeExpr const &) const; + bool operator<=(OperatorAttributeExpr const &) const; + bool operator>=(OperatorAttributeExpr const &) const; + std::variant<::FlexFlow::OperatorAttributeKey, + ::FlexFlow::OperatorAttributeListSize, + ::FlexFlow::OperatorAttributeListIndexAccess> + raw_variant; +}; +} // namespace FlexFlow +namespace std { +template <> +struct hash<::FlexFlow::OperatorAttributeExpr> { + size_t operator()(::FlexFlow::OperatorAttributeExpr const &) const; +}; +} // namespace std +namespace nlohmann { +template <> +struct adl_serializer<::FlexFlow::OperatorAttributeExpr> { + static ::FlexFlow::OperatorAttributeExpr from_json(json const &); + static void to_json(json &, ::FlexFlow::OperatorAttributeExpr const &); +}; +} // namespace nlohmann +namespace FlexFlow { +std::string format_as(::FlexFlow::OperatorAttributeExpr const &); +std::ostream &operator<<(std::ostream &, + ::FlexFlow::OperatorAttributeExpr const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_EXPR_DTG_H diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.h new file mode 100644 index 0000000000..f37ad64df0 --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_EXPR_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_EXPR_H + +#include "substitutions/operator_pattern/operator_attribute_expr.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_value.dtg.h" +#include "pcg/parallel_layer_attrs.dtg.h" +#include + +namespace FlexFlow { + +std::optional + evaluate_attribute_expr(PCGOperatorAttrs const &attrs, + OperatorAttributeExpr const &expr); +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.variant.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.variant.toml new file mode 100644 index 0000000000..ff79ecaaa5 --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.variant.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "OperatorAttributeExpr" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "substitutions/operator_pattern/operator_attribute_key.dtg.h", + "substitutions/operator_pattern/operator_attribute_list_access.dtg.h", + "substitutions/operator_pattern/operator_attribute_list_size.dtg.h", +] + +[[values]] +type = "::FlexFlow::OperatorAttributeKey" +key = "key" + +[[values]] +type = "::FlexFlow::OperatorAttributeListSize" +key = "list_size" + +[[values]] +type = "::FlexFlow::OperatorAttributeListIndexAccess" +key = "list_idx" diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.dtg.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.dtg.h new file mode 100644 index 0000000000..49a5ccbbe6 --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.dtg.h @@ -0,0 +1,97 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml +/* proj-data +{ + "generated_from": "e637388397720b328b1f4b9ba6b14611" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_KEY_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_KEY_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +enum class OperatorAttributeKey { + OP_TYPE, + USE_BIAS, + GROUPS, + POOL_TYPE, + KERNEL_H, + KERNEL_W, + DATA_TYPE, + SCALAR, + STRIDE_H, + STRIDE_W, + PADDING_H, + PADDING_W, + AGGR, + NUM_ENTRIES, + OUT_CHANNELS, + ACTIVATION, + NUMDIM, + AXIS, + PERMUTATION, + OUTSHUFFLE, + MERGE_GCONV_COUNT, + AXES, + KEEP_DIMS, + EPSILON, + PARALLEL_OP_DIM, + PARALLEL_OP_DEGREE, + SOFTMAX_DIM, + NUM_HEADS, + PARALLEL_DIM, + PARALLEL_DEGREE, + PAD, + EMBED_DIM, + KDIM, + VDIM, + DROPOUT, + BIAS, + ADD_BIAS_KV, + ADD_ZERO_ATTN, + A_SEQ_LENGTH_DIM, + B_SEQ_LENGTH_DIM, + RELU, + TARGET_DIMS, + RATE, + SEED, + SHOULD_BROADCAST_LHS, + SHOULD_BROADCAST_RHS, + DIM, + ELEMENTWISE_AFFINE, + REGULARIZER, + SHAPE, + SPLITS, + K, + SORTED, + COMBINE_DIM, + COMBINE_DEGREE, + NUM_INPUTS +}; +std::string format_as(OperatorAttributeKey); +std::ostream &operator<<(std::ostream &, OperatorAttributeKey); +void to_json(::nlohmann::json &, OperatorAttributeKey); +void from_json(::nlohmann::json const &, OperatorAttributeKey &); +} // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::OperatorAttributeKey) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_KEY_DTG_H diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml new file mode 100644 index 0000000000..59e913750e --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml @@ -0,0 +1,67 @@ +namespace = "FlexFlow" +name = "OperatorAttributeKey" +features = [ + "json", + "hash", + "fmt", + "rapidcheck", +] + +values = [ + { name = "OP_TYPE" }, + { name = "USE_BIAS" }, + { name = "GROUPS" }, + { name = "POOL_TYPE" }, + { name = "KERNEL_H" }, + { name = "KERNEL_W" }, + { name = "DATA_TYPE" }, + { name = "SCALAR" }, + { name = "STRIDE_H" }, + { name = "STRIDE_W" }, + { name = "PADDING_H" }, + { name = "PADDING_W" }, + { name = "AGGR" }, + { name = "NUM_ENTRIES" }, + { name = "OUT_CHANNELS" }, + { name = "ACTIVATION" }, + { name = "NUMDIM" }, + { name = "AXIS" }, + { name = "PERMUTATION" }, + { name = "OUTSHUFFLE" }, + { name = "MERGE_GCONV_COUNT" }, + { name = "AXES" }, + { name = "KEEP_DIMS" }, + { name = "EPSILON" }, + { name = "PARALLEL_OP_DIM" }, + { name = "PARALLEL_OP_DEGREE" }, + { name = "SOFTMAX_DIM" }, + { name = "NUM_HEADS" }, + { name = "PARALLEL_DIM" }, + { name = "PARALLEL_DEGREE" }, + { name = "PAD" }, + { name = "EMBED_DIM" }, + { name = "KDIM" }, + { name = "VDIM" }, + { name = "DROPOUT" }, + { name = "BIAS" }, + { name = "ADD_BIAS_KV" }, + { name = "ADD_ZERO_ATTN" }, + { name = "A_SEQ_LENGTH_DIM" }, + { name = "B_SEQ_LENGTH_DIM" }, + { name = "RELU" }, + { name = "TARGET_DIMS" }, + { name = "RATE" }, + { name = "SEED" }, + { name = "SHOULD_BROADCAST_LHS" }, + { name = "SHOULD_BROADCAST_RHS" }, + { name = "DIM" }, + { name = "ELEMENTWISE_AFFINE" }, + { name = "REGULARIZER" }, + { name = "SHAPE" }, + { name = "SPLITS" }, + { name = "K" }, + { name = "SORTED" }, + { name = "COMBINE_DIM" }, + { name = "COMBINE_DEGREE" }, + { name = "NUM_INPUTS" }, +] diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_access.dtg.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_access.dtg.h new file mode 100644 index 0000000000..5a30c40f8d --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_access.dtg.h @@ -0,0 +1,67 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_access.struct.toml +/* proj-data +{ + "generated_from": "1dc90d1e823f05b82c1a5ff433fbf000" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_LIST_ACCESS_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_LIST_ACCESS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include "substitutions/operator_pattern/operator_attribute_key.dtg.h" +#include +#include +#include + +namespace FlexFlow { +struct OperatorAttributeListIndexAccess { + OperatorAttributeListIndexAccess() = delete; + OperatorAttributeListIndexAccess( + ::FlexFlow::OperatorAttributeKey const &attribute_key, int const &index); + + bool operator==(OperatorAttributeListIndexAccess const &) const; + bool operator!=(OperatorAttributeListIndexAccess const &) const; + bool operator<(OperatorAttributeListIndexAccess const &) const; + bool operator>(OperatorAttributeListIndexAccess const &) const; + bool operator<=(OperatorAttributeListIndexAccess const &) const; + bool operator>=(OperatorAttributeListIndexAccess const &) const; + ::FlexFlow::OperatorAttributeKey attribute_key; + int index; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::OperatorAttributeListIndexAccess const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::OperatorAttributeListIndexAccess from_json(json const &); + static void to_json(json &, + FlexFlow::OperatorAttributeListIndexAccess const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(OperatorAttributeListIndexAccess const &); +std::ostream &operator<<(std::ostream &, + OperatorAttributeListIndexAccess const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_LIST_ACCESS_DTG_H diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_access.struct.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_access.struct.toml new file mode 100644 index 0000000000..bceff393d2 --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_access.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "OperatorAttributeListIndexAccess" +features = [ + "eq", + "ord", + "hash", + "rapidcheck", + "json", + "fmt", +] + +includes = [ + "substitutions/operator_pattern/operator_attribute_key.dtg.h" +] + +[[fields]] +name = "attribute_key" +type = "::FlexFlow::OperatorAttributeKey" + +[[fields]] +name = "index" +type = "int" diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_size.dtg.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_size.dtg.h new file mode 100644 index 0000000000..17d76a08f1 --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_size.dtg.h @@ -0,0 +1,64 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_size.struct.toml +/* proj-data +{ + "generated_from": "30999ad6b0603e380bc33d32fa088e45" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_LIST_SIZE_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_LIST_SIZE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include "substitutions/operator_pattern/operator_attribute_key.dtg.h" +#include +#include +#include + +namespace FlexFlow { +struct OperatorAttributeListSize { + OperatorAttributeListSize() = delete; + OperatorAttributeListSize( + ::FlexFlow::OperatorAttributeKey const &attribute_key); + + bool operator==(OperatorAttributeListSize const &) const; + bool operator!=(OperatorAttributeListSize const &) const; + bool operator<(OperatorAttributeListSize const &) const; + bool operator>(OperatorAttributeListSize const &) const; + bool operator<=(OperatorAttributeListSize const &) const; + bool operator>=(OperatorAttributeListSize const &) const; + ::FlexFlow::OperatorAttributeKey attribute_key; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::OperatorAttributeListSize const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::OperatorAttributeListSize from_json(json const &); + static void to_json(json &, FlexFlow::OperatorAttributeListSize const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(OperatorAttributeListSize const &); +std::ostream &operator<<(std::ostream &, OperatorAttributeListSize const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_LIST_SIZE_DTG_H diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_size.struct.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_size.struct.toml new file mode 100644 index 0000000000..271b545fda --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_size.struct.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "OperatorAttributeListSize" +features = [ + "eq", + "ord", + "hash", + "rapidcheck", + "json", + "fmt", +] + +includes = [ + "substitutions/operator_pattern/operator_attribute_key.dtg.h", +] + + +[[fields]] +name = "attribute_key" +type = "::FlexFlow::OperatorAttributeKey" diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.dtg.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.dtg.h new file mode 100644 index 0000000000..7bce198f3d --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.dtg.h @@ -0,0 +1,56 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.struct.toml +/* proj-data +{ + "generated_from": "968d7a3e93303a7fa7482bbcd50246b6" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_PATTERN_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_PATTERN_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "substitutions/operator_pattern/operator_attribute_constraint.dtg.h" +#include "utils/fmt.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct OperatorAttributePattern { + OperatorAttributePattern() = delete; + OperatorAttributePattern( + std::unordered_set<::FlexFlow::OperatorAttributeConstraint> const + &attribute_constraints); + + bool operator==(OperatorAttributePattern const &) const; + bool operator!=(OperatorAttributePattern const &) const; + std::unordered_set<::FlexFlow::OperatorAttributeConstraint> + attribute_constraints; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::OperatorAttributePattern const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::OperatorAttributePattern from_json(json const &); + static void to_json(json &, FlexFlow::OperatorAttributePattern const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(OperatorAttributePattern const &); +std::ostream &operator<<(std::ostream &, OperatorAttributePattern const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_PATTERN_DTG_H diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.struct.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.struct.toml new file mode 100644 index 0000000000..6facf7d3bc --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.struct.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "OperatorAttributePattern" +features = [ + "eq", + # "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "", + "utils/fmt.h", + "substitutions/operator_pattern/operator_attribute_constraint.dtg.h", +] + +[[fields]] +name = "attribute_constraints" +type = "std::unordered_set<::FlexFlow::OperatorAttributeConstraint>" diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.dtg.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.dtg.h new file mode 100644 index 0000000000..080909d147 --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.dtg.h @@ -0,0 +1,264 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml +/* proj-data +{ + "generated_from": "de14592f1f4bcfb52689bc95e9d3b55f" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_VALUE_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_VALUE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/activation.dtg.h" +#include "op-attrs/aggregate_op.dtg.h" +#include "op-attrs/datatype.dtg.h" +#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/operator_type.dtg.h" +#include "op-attrs/pool_op.dtg.h" +#include "op-attrs/regularizer_attrs.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" +#include +#include +#include +#include +#include +#include +#include + +namespace FlexFlow { +struct OperatorAttributeValue { + OperatorAttributeValue() = delete; + OperatorAttributeValue(int const &); + OperatorAttributeValue(bool const &); + OperatorAttributeValue(std::vector const &); + OperatorAttributeValue(std::vector<::FlexFlow::ff_dim_t> const &); + OperatorAttributeValue(::FlexFlow::OperatorType const &); + OperatorAttributeValue(::FlexFlow::Activation const &); + OperatorAttributeValue(::FlexFlow::ff_dim_t const &); + OperatorAttributeValue(size_t const &); + OperatorAttributeValue(::FlexFlow::AggregateOp const &); + OperatorAttributeValue(std::optional<::FlexFlow::RegularizerAttrs> const &); + OperatorAttributeValue(::FlexFlow::PoolOp const &); + OperatorAttributeValue(::FlexFlow::TensorShape const &); + OperatorAttributeValue(::FlexFlow::DataType const &); + template + static constexpr bool IsPartOfOperatorAttributeValue_v = + std::is_same_v || std::is_same_v || + std::is_same_v> || + std::is_same_v> || + std::is_same_v || + std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || + std::is_same_v> || + std::is_same_v || + std::is_same_v || + std::is_same_v; + template + ReturnType visit(Visitor &&v) const { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get()); + return result; + } + case 1: { + ReturnType result = v(this->get()); + return result; + } + case 2: { + ReturnType result = v(this->get>()); + return result; + } + case 3: { + ReturnType result = v(this->get>()); + return result; + } + case 4: { + ReturnType result = v(this->get<::FlexFlow::OperatorType>()); + return result; + } + case 5: { + ReturnType result = v(this->get<::FlexFlow::Activation>()); + return result; + } + case 6: { + ReturnType result = v(this->get<::FlexFlow::ff_dim_t>()); + return result; + } + case 7: { + ReturnType result = v(this->get()); + return result; + } + case 8: { + ReturnType result = v(this->get<::FlexFlow::AggregateOp>()); + return result; + } + case 9: { + ReturnType result = + v(this->get>()); + return result; + } + case 10: { + ReturnType result = v(this->get<::FlexFlow::PoolOp>()); + return result; + } + case 11: { + ReturnType result = v(this->get<::FlexFlow::TensorShape>()); + return result; + } + case 12: { + ReturnType result = v(this->get<::FlexFlow::DataType>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type OperatorAttributeValue", this->index())); + } + } + } + template + ReturnType visit(Visitor &&v) { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get()); + return result; + } + case 1: { + ReturnType result = v(this->get()); + return result; + } + case 2: { + ReturnType result = v(this->get>()); + return result; + } + case 3: { + ReturnType result = v(this->get>()); + return result; + } + case 4: { + ReturnType result = v(this->get<::FlexFlow::OperatorType>()); + return result; + } + case 5: { + ReturnType result = v(this->get<::FlexFlow::Activation>()); + return result; + } + case 6: { + ReturnType result = v(this->get<::FlexFlow::ff_dim_t>()); + return result; + } + case 7: { + ReturnType result = v(this->get()); + return result; + } + case 8: { + ReturnType result = v(this->get<::FlexFlow::AggregateOp>()); + return result; + } + case 9: { + ReturnType result = + v(this->get>()); + return result; + } + case 10: { + ReturnType result = v(this->get<::FlexFlow::PoolOp>()); + return result; + } + case 11: { + ReturnType result = v(this->get<::FlexFlow::TensorShape>()); + return result; + } + case 12: { + ReturnType result = v(this->get<::FlexFlow::DataType>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type OperatorAttributeValue", this->index())); + } + } + } + template + bool has() const { + static_assert( + IsPartOfOperatorAttributeValue_v, + "OperatorAttributeValue::has() expected one of [int, bool, " + "std::vector, std::vector<::FlexFlow::ff_dim_t>, " + "::FlexFlow::OperatorType, ::FlexFlow::Activation, " + "::FlexFlow::ff_dim_t, size_t, ::FlexFlow::AggregateOp, " + "std::optional<::FlexFlow::RegularizerAttrs>, ::FlexFlow::PoolOp, " + "::FlexFlow::TensorShape, ::FlexFlow::DataType], received T"); + return std::holds_alternative(this->raw_variant); + } + template + T const &get() const { + static_assert( + IsPartOfOperatorAttributeValue_v, + "OperatorAttributeValue::get() expected one of [int, bool, " + "std::vector, std::vector<::FlexFlow::ff_dim_t>, " + "::FlexFlow::OperatorType, ::FlexFlow::Activation, " + "::FlexFlow::ff_dim_t, size_t, ::FlexFlow::AggregateOp, " + "std::optional<::FlexFlow::RegularizerAttrs>, ::FlexFlow::PoolOp, " + "::FlexFlow::TensorShape, ::FlexFlow::DataType], received T"); + return std::get(this->raw_variant); + } + template + T &get() { + static_assert( + IsPartOfOperatorAttributeValue_v, + "OperatorAttributeValue::get() expected one of [int, bool, " + "std::vector, std::vector<::FlexFlow::ff_dim_t>, " + "::FlexFlow::OperatorType, ::FlexFlow::Activation, " + "::FlexFlow::ff_dim_t, size_t, ::FlexFlow::AggregateOp, " + "std::optional<::FlexFlow::RegularizerAttrs>, ::FlexFlow::PoolOp, " + "::FlexFlow::TensorShape, ::FlexFlow::DataType], received T"); + return std::get(this->raw_variant); + } + size_t index() const { + return this->raw_variant.index(); + } + bool operator==(OperatorAttributeValue const &) const; + bool operator!=(OperatorAttributeValue const &) const; + bool operator<(OperatorAttributeValue const &) const; + bool operator>(OperatorAttributeValue const &) const; + bool operator<=(OperatorAttributeValue const &) const; + bool operator>=(OperatorAttributeValue const &) const; + std::variant, + std::vector<::FlexFlow::ff_dim_t>, + ::FlexFlow::OperatorType, + ::FlexFlow::Activation, + ::FlexFlow::ff_dim_t, + size_t, + ::FlexFlow::AggregateOp, + std::optional<::FlexFlow::RegularizerAttrs>, + ::FlexFlow::PoolOp, + ::FlexFlow::TensorShape, + ::FlexFlow::DataType> + raw_variant; +}; +} // namespace FlexFlow +namespace std { +template <> +struct hash<::FlexFlow::OperatorAttributeValue> { + size_t operator()(::FlexFlow::OperatorAttributeValue const &) const; +}; +} // namespace std +namespace nlohmann { +template <> +struct adl_serializer<::FlexFlow::OperatorAttributeValue> { + static ::FlexFlow::OperatorAttributeValue from_json(json const &); + static void to_json(json &, ::FlexFlow::OperatorAttributeValue const &); +}; +} // namespace nlohmann +namespace FlexFlow { +std::string format_as(::FlexFlow::OperatorAttributeValue const &); +std::ostream &operator<<(std::ostream &, + ::FlexFlow::OperatorAttributeValue const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_VALUE_DTG_H diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml new file mode 100644 index 0000000000..9ab88e63c2 --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml @@ -0,0 +1,63 @@ +namespace = "FlexFlow" +name = "OperatorAttributeValue" +features = [ + "eq", + "ord", + "hash", + "fmt", + "json", +] +explicit_constructors = false + +includes = [ + "", + "", + "op-attrs/operator_type.dtg.h", + "op-attrs/ff_dim.dtg.h", + "op-attrs/activation.dtg.h", + "op-attrs/aggregate_op.dtg.h", + "op-attrs/regularizer_attrs.dtg.h", + "op-attrs/pool_op.dtg.h", + "op-attrs/tensor_shape.dtg.h", + "op-attrs/datatype.dtg.h", + "", +] + +[[values]] +type = "int" + +[[values]] +type = "bool" + +[[values]] +type = "std::vector" + +[[values]] +type = "std::vector<::FlexFlow::ff_dim_t>" + +[[values]] +type = "::FlexFlow::OperatorType" + +[[values]] +type = "::FlexFlow::Activation" + +[[values]] +type = "::FlexFlow::ff_dim_t" + +[[values]] +type = "size_t" + +[[values]] +type = "::FlexFlow::AggregateOp" + +[[values]] +type = "std::optional<::FlexFlow::RegularizerAttrs>" + +[[values]] +type = "::FlexFlow::PoolOp" + +[[values]] +type = "::FlexFlow::TensorShape" + +[[values]] +type = "::FlexFlow::DataType" diff --git a/lib/substitutions/include/substitutions/operator_pattern/satisfies_constraint.h b/lib/substitutions/include/substitutions/operator_pattern/satisfies_constraint.h new file mode 100644 index 0000000000..2ac45af0be --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/satisfies_constraint.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_SATISFIES_CONSTRAINT_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_SATISFIES_CONSTRAINT_H + +#include "substitutions/operator_pattern/operator_attribute_constraint.dtg.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" + +namespace FlexFlow { + +bool operator_satisfies_constraint(PCGOperatorAttrs const ¶ms, OperatorAttributeConstraint const &constraint); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/operator_pattern/satisfies_pattern.h b/lib/substitutions/include/substitutions/operator_pattern/satisfies_pattern.h new file mode 100644 index 0000000000..f33e027777 --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/satisfies_pattern.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_SATISFIES_PATTERN_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_SATISFIES_PATTERN_H + +#include "op-attrs/pcg_operator_attrs.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_pattern.dtg.h" + +namespace FlexFlow { + +bool operator_satisfies_pattern(PCGOperatorAttrs const &attrs, OperatorAttributePattern const &pattern); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/output_graph.h b/lib/substitutions/include/substitutions/output_graph.h deleted file mode 100644 index 4ed90aed06..0000000000 --- a/lib/substitutions/include/substitutions/output_graph.h +++ /dev/null @@ -1,35 +0,0 @@ -#ifndef _FLEXFLOW_SUBSTITUTIONS_OUTPUT_GRAPH_H -#define _FLEXFLOW_SUBSTITUTIONS_OUTPUT_GRAPH_H - -#include "utils/graph.h" - -namespace FlexFlow { - -// NOTE(@wmdi) I am not sure whether these should be part of attribute expr. -struct OperatorAttrAccess { - Node node; - AttributeExpr attr_expr; -}; - -struct AttrConstant { - OperatorAttributeValue value; -}; - -using OperatorAttributeExpr = std::variant; - -// NOTE(@wmdi): Not sure if it aligns with other design. Or alternatively we can -// define the assignment for each operator type. -struct OperatorAttrAssignment { - std::unordered_map assignments; -}; - -struct OutputGraphExpr - : public strong_typedef< - OutputGraphExpr, - NodeLabelledOpenMultiDiGraph> { - using strong_typedef::strong_typedef; -}; - -} // namespace FlexFlow - -#endif diff --git a/lib/substitutions/include/substitutions/output_graph/attr_constant.dtg.h b/lib/substitutions/include/substitutions/output_graph/attr_constant.dtg.h new file mode 100644 index 0000000000..9dd20bb10e --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/attr_constant.dtg.h @@ -0,0 +1,46 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/output_graph/attr_constant.struct.toml +/* proj-data +{ + "generated_from": "1e5beabcb8e3657d8fe9c9c8b1310cb1" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_ATTR_CONSTANT_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_ATTR_CONSTANT_DTG_H + +#include "fmt/format.h" +#include "substitutions/operator_pattern/operator_attribute_value.dtg.h" +#include +#include +#include + +namespace FlexFlow { +struct AttrConstant { + AttrConstant() = delete; + AttrConstant(::FlexFlow::OperatorAttributeValue const &value); + + bool operator==(AttrConstant const &) const; + bool operator!=(AttrConstant const &) const; + bool operator<(AttrConstant const &) const; + bool operator>(AttrConstant const &) const; + bool operator<=(AttrConstant const &) const; + bool operator>=(AttrConstant const &) const; + ::FlexFlow::OperatorAttributeValue value; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::AttrConstant const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(AttrConstant const &); +std::ostream &operator<<(std::ostream &, AttrConstant const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_ATTR_CONSTANT_DTG_H diff --git a/lib/substitutions/include/substitutions/output_graph/attr_constant.struct.toml b/lib/substitutions/include/substitutions/output_graph/attr_constant.struct.toml new file mode 100644 index 0000000000..68973f9c0c --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/attr_constant.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "AttrConstant" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "substitutions/operator_pattern/operator_attribute_value.dtg.h", +] + +[[fields]] +name = "value" +type = "::FlexFlow::OperatorAttributeValue" diff --git a/lib/substitutions/include/substitutions/output_graph/output_graph_expr.dtg.h b/lib/substitutions/include/substitutions/output_graph/output_graph_expr.dtg.h new file mode 100644 index 0000000000..3d6fb21574 --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/output_graph_expr.dtg.h @@ -0,0 +1,28 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/output_graph/output_graph_expr.struct.toml +/* proj-data +{ + "generated_from": "9084c9afb2724504a6f4db4288a83a0d" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_GRAPH_EXPR_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_GRAPH_EXPR_DTG_H + +#include "substitutions/output_graph/output_operator_attrs_assignment.dtg.h" +#include "utils/graph.h" + +namespace FlexFlow { +struct OutputGraphExpr { + OutputGraphExpr() = delete; + OutputGraphExpr(::FlexFlow::NodeLabelledOpenMultiDiGraph< + ::FlexFlow::OutputOperatorAttrsAssignment> const &raw_graph); + + ::FlexFlow::NodeLabelledOpenMultiDiGraph< + ::FlexFlow::OutputOperatorAttrsAssignment> + raw_graph; +}; +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_GRAPH_EXPR_DTG_H diff --git a/lib/substitutions/include/substitutions/output_graph/output_graph_expr.struct.toml b/lib/substitutions/include/substitutions/output_graph/output_graph_expr.struct.toml new file mode 100644 index 0000000000..37d87f7820 --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/output_graph_expr.struct.toml @@ -0,0 +1,12 @@ +namespace = "FlexFlow" +name = "OutputGraphExpr" +features = [] + +includes = [ + "utils/graph.h", + "substitutions/output_graph/output_operator_attrs_assignment.dtg.h", +] + +[[fields]] +name = "raw_graph" +type = "::FlexFlow::NodeLabelledOpenMultiDiGraph<::FlexFlow::OutputOperatorAttrsAssignment>" diff --git a/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.dtg.h b/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.dtg.h new file mode 100644 index 0000000000..0d585f0aa0 --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.dtg.h @@ -0,0 +1,49 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.struct.toml +/* proj-data +{ + "generated_from": "e3b3a741183fcb38cfa68aacb82e12d1" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTR_ACCESS_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTR_ACCESS_DTG_H + +#include "fmt/format.h" +#include "substitutions/operator_pattern/operator_attribute_expr.dtg.h" +#include "utils/graph.h" +#include +#include +#include + +namespace FlexFlow { +struct OutputOperatorAttrAccess { + OutputOperatorAttrAccess() = delete; + OutputOperatorAttrAccess(::FlexFlow::Node const &node, + ::FlexFlow::OperatorAttributeExpr const &attr_expr); + + bool operator==(OutputOperatorAttrAccess const &) const; + bool operator!=(OutputOperatorAttrAccess const &) const; + bool operator<(OutputOperatorAttrAccess const &) const; + bool operator>(OutputOperatorAttrAccess const &) const; + bool operator<=(OutputOperatorAttrAccess const &) const; + bool operator>=(OutputOperatorAttrAccess const &) const; + ::FlexFlow::Node node; + ::FlexFlow::OperatorAttributeExpr attr_expr; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::OutputOperatorAttrAccess const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(OutputOperatorAttrAccess const &); +std::ostream &operator<<(std::ostream &, OutputOperatorAttrAccess const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTR_ACCESS_DTG_H diff --git a/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.struct.toml b/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.struct.toml new file mode 100644 index 0000000000..51aae54730 --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "OutputOperatorAttrAccess" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph.h", + "substitutions/operator_pattern/operator_attribute_expr.dtg.h", +] + +[[fields]] +name = "node" +type = "::FlexFlow::Node" + +# NOTE(@wmdi) I am not sure whether these should be part of attribute expr. +[[fields]] +name = "attr_expr" +type = "::FlexFlow::OperatorAttributeExpr" + diff --git a/lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.dtg.h b/lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.dtg.h new file mode 100644 index 0000000000..327c230b61 --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.dtg.h @@ -0,0 +1,119 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.variant.toml +/* proj-data +{ + "generated_from": "89ebf777a5b909eef78ab5a5a177e041" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTRIBUTE_EXPR_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTRIBUTE_EXPR_DTG_H + +#include "fmt/format.h" +#include "substitutions/output_graph/attr_constant.dtg.h" +#include "substitutions/output_graph/output_operator_attr_access.dtg.h" +#include +#include +#include +#include +#include + +namespace FlexFlow { +struct OutputOperatorAttributeExpr { + OutputOperatorAttributeExpr() = delete; + explicit OutputOperatorAttributeExpr( + ::FlexFlow::OutputOperatorAttrAccess const &); + explicit OutputOperatorAttributeExpr(::FlexFlow::AttrConstant const &); + template + static constexpr bool IsPartOfOutputOperatorAttributeExpr_v = + std::is_same_v || + std::is_same_v; + template + ReturnType visit(Visitor &&v) const { + switch (this->index()) { + case 0: { + ReturnType result = + v(this->get<::FlexFlow::OutputOperatorAttrAccess>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::AttrConstant>()); + return result; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type OutputOperatorAttributeExpr", + this->index())); + } + } + } + template + ReturnType visit(Visitor &&v) { + switch (this->index()) { + case 0: { + ReturnType result = + v(this->get<::FlexFlow::OutputOperatorAttrAccess>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::AttrConstant>()); + return result; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type OutputOperatorAttributeExpr", + this->index())); + } + } + } + template + bool has() const { + static_assert(IsPartOfOutputOperatorAttributeExpr_v, + "OutputOperatorAttributeExpr::has() expected one of " + "[::FlexFlow::OutputOperatorAttrAccess, " + "::FlexFlow::AttrConstant], received T"); + return std::holds_alternative(this->raw_variant); + } + template + T const &get() const { + static_assert(IsPartOfOutputOperatorAttributeExpr_v, + "OutputOperatorAttributeExpr::get() expected one of " + "[::FlexFlow::OutputOperatorAttrAccess, " + "::FlexFlow::AttrConstant], received T"); + return std::get(this->raw_variant); + } + template + T &get() { + static_assert(IsPartOfOutputOperatorAttributeExpr_v, + "OutputOperatorAttributeExpr::get() expected one of " + "[::FlexFlow::OutputOperatorAttrAccess, " + "::FlexFlow::AttrConstant], received T"); + return std::get(this->raw_variant); + } + size_t index() const { + return this->raw_variant.index(); + } + bool operator==(OutputOperatorAttributeExpr const &) const; + bool operator!=(OutputOperatorAttributeExpr const &) const; + bool operator<(OutputOperatorAttributeExpr const &) const; + bool operator>(OutputOperatorAttributeExpr const &) const; + bool operator<=(OutputOperatorAttributeExpr const &) const; + bool operator>=(OutputOperatorAttributeExpr const &) const; + std::variant<::FlexFlow::OutputOperatorAttrAccess, ::FlexFlow::AttrConstant> + raw_variant; +}; +} // namespace FlexFlow +namespace std { +template <> +struct hash<::FlexFlow::OutputOperatorAttributeExpr> { + size_t operator()(::FlexFlow::OutputOperatorAttributeExpr const &) const; +}; +} // namespace std +namespace FlexFlow { +std::string format_as(::FlexFlow::OutputOperatorAttributeExpr const &); +std::ostream &operator<<(std::ostream &, + ::FlexFlow::OutputOperatorAttributeExpr const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTRIBUTE_EXPR_DTG_H diff --git a/lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.variant.toml b/lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.variant.toml new file mode 100644 index 0000000000..19810a0151 --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.variant.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "OutputOperatorAttributeExpr" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "substitutions/output_graph/attr_constant.dtg.h", + "substitutions/output_graph/output_operator_attr_access.dtg.h", +] + +[[values]] +type = "::FlexFlow::OutputOperatorAttrAccess" +key = "attr_ref" + +[[values]] +type = "::FlexFlow::AttrConstant" +key = "constant" diff --git a/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.dtg.h b/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.dtg.h new file mode 100644 index 0000000000..5586a90a08 --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.dtg.h @@ -0,0 +1,49 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.struct.toml +/* proj-data +{ + "generated_from": "bbfb309c5a39a729da23dace4df4a9de" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTRS_ASSIGNMENT_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTRS_ASSIGNMENT_DTG_H + +#include "fmt/format.h" +#include "substitutions/operator_pattern/operator_attribute_key.dtg.h" +#include "substitutions/output_graph/output_operator_attribute_expr.dtg.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct OutputOperatorAttrsAssignment { + OutputOperatorAttrsAssignment() = delete; + OutputOperatorAttrsAssignment( + std::unordered_map<::FlexFlow::OperatorAttributeKey, + ::FlexFlow::OutputOperatorAttributeExpr> const + &assignments); + + bool operator==(OutputOperatorAttrsAssignment const &) const; + bool operator!=(OutputOperatorAttrsAssignment const &) const; + std::unordered_map<::FlexFlow::OperatorAttributeKey, + ::FlexFlow::OutputOperatorAttributeExpr> + assignments; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::OutputOperatorAttrsAssignment const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(OutputOperatorAttrsAssignment const &); +std::ostream &operator<<(std::ostream &, OutputOperatorAttrsAssignment const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTRS_ASSIGNMENT_DTG_H diff --git a/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.struct.toml b/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.struct.toml new file mode 100644 index 0000000000..9781515803 --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.struct.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "OutputOperatorAttrsAssignment" +features = [ + "eq", + # "ord", + "hash", + # "json", + "fmt", +] + +includes = [ + "substitutions/operator_pattern/operator_attribute_key.dtg.h", + "substitutions/output_graph/output_operator_attribute_expr.dtg.h", + "", +] + +# NOTE(@wmdi): Not sure if it aligns with other design. Or alternatively we can +# define the assignment for each operator type. +[[fields]] +name = "assignments" +type = "std::unordered_map<::FlexFlow::OperatorAttributeKey, ::FlexFlow::OutputOperatorAttributeExpr>" diff --git a/lib/substitutions/include/substitutions/parallel_tensor_pattern.h b/lib/substitutions/include/substitutions/parallel_tensor_pattern.h deleted file mode 100644 index 741554142f..0000000000 --- a/lib/substitutions/include/substitutions/parallel_tensor_pattern.h +++ /dev/null @@ -1,25 +0,0 @@ -#ifndef _FLEXFLOW_SUBSTITUTIONS_TENSOR_PATTERN_H -#define _FLEXFLOW_SUBSTITUTIONS_TENSOR_PATTERN_H - -#include "attribute_expr.h" -#include "pcg/parallel_tensor.h" - -namespace FlexFlow { - -enum class TensorAttributeKey { DIM_SIZES, DIM_DEGREES }; - -using TensorAttributeValue = std::variant>; - -using TensorAttributeConstraint = - AttributeConstraint; - -using ParallelTensorPattern = - AttributePattern; - -std::optional - evaluate_attribute_expr(ParallelTensor const &tensor_shape, - AttributeExpr const &expr); - -} // namespace FlexFlow - -#endif diff --git a/lib/substitutions/include/substitutions/pcg_pattern.dtg.h b/lib/substitutions/include/substitutions/pcg_pattern.dtg.h new file mode 100644 index 0000000000..0c0cc41891 --- /dev/null +++ b/lib/substitutions/include/substitutions/pcg_pattern.dtg.h @@ -0,0 +1,31 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/pcg_pattern.struct.toml +/* proj-data +{ + "generated_from": "f536f846828ba39266dd4a1fbaeec0e6" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_PCG_PATTERN_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_PCG_PATTERN_DTG_H + +#include "substitutions/operator_pattern/operator_attribute_pattern.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h" +#include "utils/graph.h" + +namespace FlexFlow { +struct PCGPattern { + PCGPattern() = delete; + PCGPattern(::FlexFlow::OutputLabelledOpenMultiDiGraph< + ::FlexFlow::OperatorAttributePattern, + ::FlexFlow::TensorAttributePattern> const &raw_graph); + + ::FlexFlow::OutputLabelledOpenMultiDiGraph< + ::FlexFlow::OperatorAttributePattern, + ::FlexFlow::TensorAttributePattern> + raw_graph; +}; +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_PCG_PATTERN_DTG_H diff --git a/lib/substitutions/include/substitutions/pcg_pattern.struct.toml b/lib/substitutions/include/substitutions/pcg_pattern.struct.toml new file mode 100644 index 0000000000..191d66a38c --- /dev/null +++ b/lib/substitutions/include/substitutions/pcg_pattern.struct.toml @@ -0,0 +1,12 @@ +namespace = "FlexFlow" +name = "PCGPattern" +features = [] +includes = [ + "utils/graph.h", + "substitutions/operator_pattern/operator_attribute_pattern.dtg.h", + "substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h", +] + +[[fields]] +name = "raw_graph" +type = "::FlexFlow::OutputLabelledOpenMultiDiGraph<::FlexFlow::OperatorAttributePattern, ::FlexFlow::TensorAttributePattern>" diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.dtg.h b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.dtg.h new file mode 100644 index 0000000000..d31d65d83b --- /dev/null +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.dtg.h @@ -0,0 +1,31 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/sub_parallel_computation_graph.struct.toml +/* proj-data +{ + "generated_from": "0022d1b2c1447667695a120c154a0168" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUB_PARALLEL_COMPUTATION_GRAPH_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUB_PARALLEL_COMPUTATION_GRAPH_DTG_H + +#include "pcg/parallel_layer_attrs.dtg.h" +#include "pcg/parallel_tensor_attrs.dtg.h" +#include "utils/graph.h" + +namespace FlexFlow { +struct SubParallelComputationGraph { + SubParallelComputationGraph() = delete; + SubParallelComputationGraph( + ::FlexFlow::OutputLabelledOpenMultiDiGraph< + ::FlexFlow::ParallelLayerAttrs, + ::FlexFlow::ParallelTensorAttrs> const &raw_graph); + + ::FlexFlow::OutputLabelledOpenMultiDiGraph<::FlexFlow::ParallelLayerAttrs, + ::FlexFlow::ParallelTensorAttrs> + raw_graph; +}; +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUB_PARALLEL_COMPUTATION_GRAPH_DTG_H diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h index 0d6bfe7628..e58502a745 100644 --- a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h @@ -1,18 +1,13 @@ -#ifndef _FLEXFLOW_SUBSTITUTIONS_SUB_PARALLEL_COMPUTATION_GRAPH_H -#define _FLEXFLOW_SUBSTITUTIONS_SUB_PARALLEL_COMPUTATION_GRAPH_H +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUB_PARALLEL_COMPUTATION_GRAPH_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUB_PARALLEL_COMPUTATION_GRAPH_H -#include "op-attrs/operator_attrs.h" -#include "pcg/machine_view.h" -#include "pcg/operator.h" -#include "pcg/parallel_tensor.h" -#include "utils/graph.h" +#include "substitutions/sub_parallel_computation_graph.dtg.h" namespace FlexFlow { -using SubParallelComputationGraph = - OutputLabelledOpenMultiDiGraph; - -CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(SubParallelComputationGraph); +ParallelLayerAttrs get_parallel_layer_attrs(SubParallelComputationGraph const &, Node const &); +PCGOperatorAttrs get_operator_attrs(SubParallelComputationGraph const &, Node const &); +ParallelTensorAttrs get_parallel_tensor_attrs(SubParallelComputationGraph const &, OpenMultiDiEdge const &); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.struct.toml b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.struct.toml new file mode 100644 index 0000000000..1ba04b544c --- /dev/null +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.struct.toml @@ -0,0 +1,13 @@ +namespace = "FlexFlow" +name = "SubParallelComputationGraph" +features = [ ] + +includes = [ + "pcg/parallel_layer_attrs.dtg.h", + "pcg/parallel_tensor_attrs.dtg.h", + "utils/graph.h", +] + +[[fields]] +name = "raw_graph" +type = "::FlexFlow::OutputLabelledOpenMultiDiGraph<::FlexFlow::ParallelLayerAttrs, ::FlexFlow::ParallelTensorAttrs>" diff --git a/lib/substitutions/include/substitutions/substitution.dtg.h b/lib/substitutions/include/substitutions/substitution.dtg.h new file mode 100644 index 0000000000..5f50d9bafc --- /dev/null +++ b/lib/substitutions/include/substitutions/substitution.dtg.h @@ -0,0 +1,38 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/substitution.struct.toml +/* proj-data +{ + "generated_from": "c101f1d63e2d8d80a0ec9c5f5db4fa12" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUBSTITUTION_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUBSTITUTION_DTG_H + +#include "substitutions/output_graph/output_graph_expr.dtg.h" +#include "substitutions/pcg_pattern.dtg.h" + +namespace FlexFlow { +struct Substitution { + Substitution() = delete; + Substitution(::FlexFlow::PCGPattern const &pcg_pattern, + ::FlexFlow::OutputGraphExpr const &output_graph_expr, + ::FlexFlow::bidict<::FlexFlow::InputMultiDiEdge, + ::FlexFlow::InputMultiDiEdge> const + &input_edge_match_to_output, + ::FlexFlow::bidict<::FlexFlow::OutputMultiDiEdge, + ::FlexFlow::OutputMultiDiEdge> const + &output_edge_match_to_output); + + ::FlexFlow::PCGPattern pcg_pattern; + ::FlexFlow::OutputGraphExpr output_graph_expr; + ::FlexFlow::bidict<::FlexFlow::InputMultiDiEdge, ::FlexFlow::InputMultiDiEdge> + input_edge_match_to_output; + ::FlexFlow::bidict<::FlexFlow::OutputMultiDiEdge, + ::FlexFlow::OutputMultiDiEdge> + output_edge_match_to_output; +}; +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUBSTITUTION_DTG_H diff --git a/lib/substitutions/include/substitutions/substitution.h b/lib/substitutions/include/substitutions/substitution.h index 8dbe4e66cf..1aa2b2946b 100644 --- a/lib/substitutions/include/substitutions/substitution.h +++ b/lib/substitutions/include/substitutions/substitution.h @@ -1,24 +1,12 @@ #ifndef _FLEXFLOW_SUBSTITUTIONS_SUBSTITUTION_H #define _FLEXFLOW_SUBSTITUTIONS_SUBSTITUTION_H -#include "graph_pattern.h" -#include "output_graph.h" -#include "sub_parallel_computation_graph.h" +#include "sub_parallel_computation_graph.dtg.h" +#include "substitutions/substitution.dtg.h" +#include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h" namespace FlexFlow { -struct Substitution { - using InputPatternInput = InputMultiDiEdge; - using InputPatternOutput = OutputMultiDiEdge; - using OutputPatternInput = InputMultiDiEdge; - using OutputPatternOutput = OutputMultiDiEdge; - - GraphPattern input_graph; - OutputGraphExpr output_graph_expr; - bidict input_mapping; - bidict output_mapping; -}; - bool is_valid_substitution(Substitution const &); SubParallelComputationGraph @@ -28,12 +16,4 @@ SubParallelComputationGraph } // namespace FlexFlow -namespace std { -template <> -struct hash { - size_t operator()(FlexFlow::Substitution const &) const; -}; - -}; // namespace std - #endif diff --git a/lib/substitutions/include/substitutions/substitution.struct.toml b/lib/substitutions/include/substitutions/substitution.struct.toml new file mode 100644 index 0000000000..eb630e9308 --- /dev/null +++ b/lib/substitutions/include/substitutions/substitution.struct.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "Substitution" +features = [] + +includes = [ + "substitutions/pcg_pattern.dtg.h", + "substitutions/output_graph/output_graph_expr.dtg.h", +] + +[[fields]] +name = "pcg_pattern" +type = "::FlexFlow::PCGPattern" + +[[fields]] +name = "output_graph_expr" +type = "::FlexFlow::OutputGraphExpr" + +[[fields]] +name = "input_edge_match_to_output" +type = "::FlexFlow::bidict<::FlexFlow::InputMultiDiEdge, ::FlexFlow::InputMultiDiEdge>" + +[[fields]] +name = "output_edge_match_to_output" +type = "::FlexFlow::bidict<::FlexFlow::OutputMultiDiEdge, ::FlexFlow::OutputMultiDiEdge>" diff --git a/lib/substitutions/include/substitutions/tensor_pattern/eval_list_access.h b/lib/substitutions/include/substitutions/tensor_pattern/eval_list_access.h new file mode 100644 index 0000000000..78af6c7405 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/eval_list_access.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_EVAL_LIST_ACCESS_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_EVAL_LIST_ACCESS_H + +#include "substitutions/tensor_pattern/tensor_attribute_list_access.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_value.dtg.h" +#include "pcg/parallel_tensor_attrs.dtg.h" + +namespace FlexFlow { + +TensorAttributeValue eval_list_access(ParallelTensorAttrs const &attrs, TensorAttributeListIndexAccess const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/tensor_pattern/eval_list_size.h b/lib/substitutions/include/substitutions/tensor_pattern/eval_list_size.h new file mode 100644 index 0000000000..863cb81239 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/eval_list_size.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_EVAL_LIST_SIZE_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_EVAL_LIST_SIZE_H + +#include "substitutions/tensor_pattern/tensor_attribute_list_size.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_value.dtg.h" +#include "pcg/parallel_tensor_attrs.dtg.h" + +namespace FlexFlow { + +TensorAttributeValue eval_list_size(ParallelTensorAttrs const &attrs, TensorAttributeListSize const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/tensor_pattern/get_attribute.h b/lib/substitutions/include/substitutions/tensor_pattern/get_attribute.h new file mode 100644 index 0000000000..f276fcbd3a --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/get_attribute.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_GET_ATTRIBUTE_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_GET_ATTRIBUTE_H + +#include "substitutions/tensor_pattern/tensor_attribute_value.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_key.dtg.h" +#include "pcg/parallel_tensor_attrs.dtg.h" + +namespace FlexFlow { + +TensorAttributeValue get_attribute(ParallelTensorAttrs const &, TensorAttributeKey); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/tensor_pattern/satisfies_constraint.h b/lib/substitutions/include/substitutions/tensor_pattern/satisfies_constraint.h new file mode 100644 index 0000000000..2e15e604f8 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/satisfies_constraint.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_SATISFIES_CONSTRAINT_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_SATISFIES_CONSTRAINT_H + +#include "substitutions/tensor_pattern/tensor_attribute_constraint.dtg.h" +#include "pcg/parallel_tensor_attrs.dtg.h" + +namespace FlexFlow { + +bool parallel_tensor_satisfies_constraint(ParallelTensorAttrs const ¶ms, TensorAttributeConstraint const &constraint); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/tensor_pattern/satisfies_pattern.h b/lib/substitutions/include/substitutions/tensor_pattern/satisfies_pattern.h new file mode 100644 index 0000000000..8defca7e50 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/satisfies_pattern.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_SATISFIES_PATTERN_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_SATISFIES_PATTERN_H + +#include "substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h" +#include "pcg/parallel_tensor_attrs.dtg.h" + +namespace FlexFlow { + +bool parallel_tensor_satisfies_pattern(ParallelTensorAttrs const &attrs, TensorAttributePattern const &pattern); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_constraint.dtg.h b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_constraint.dtg.h new file mode 100644 index 0000000000..ba705a5d35 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_constraint.dtg.h @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_constraint.struct.toml +/* proj-data +{ + "generated_from": "29dbf81668bc864b06af52261060335e" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_CONSTRAINT_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_CONSTRAINT_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "substitutions/constraint_type.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_expr.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_value.dtg.h" +#include +#include +#include + +namespace FlexFlow { +struct TensorAttributeConstraint { + TensorAttributeConstraint() = delete; + TensorAttributeConstraint( + ::FlexFlow::ConstraintType const &constraint_type, + ::FlexFlow::TensorAttributeExpr const &attribute_expr, + ::FlexFlow::TensorAttributeValue const &attribute_value); + + bool operator==(TensorAttributeConstraint const &) const; + bool operator!=(TensorAttributeConstraint const &) const; + bool operator<(TensorAttributeConstraint const &) const; + bool operator>(TensorAttributeConstraint const &) const; + bool operator<=(TensorAttributeConstraint const &) const; + bool operator>=(TensorAttributeConstraint const &) const; + ::FlexFlow::ConstraintType constraint_type; + ::FlexFlow::TensorAttributeExpr attribute_expr; + ::FlexFlow::TensorAttributeValue attribute_value; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::TensorAttributeConstraint const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::TensorAttributeConstraint from_json(json const &); + static void to_json(json &, FlexFlow::TensorAttributeConstraint const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(TensorAttributeConstraint const &); +std::ostream &operator<<(std::ostream &, TensorAttributeConstraint const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_CONSTRAINT_DTG_H diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_constraint.struct.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_constraint.struct.toml new file mode 100644 index 0000000000..6aba719e08 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_constraint.struct.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "TensorAttributeConstraint" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "substitutions/constraint_type.dtg.h", + "substitutions/tensor_pattern/tensor_attribute_expr.dtg.h", + "substitutions/tensor_pattern/tensor_attribute_value.dtg.h", +] + +[[fields]] +name = "constraint_type" +type = "::FlexFlow::ConstraintType" + +[[fields]] +name = "attribute_expr" +type = "::FlexFlow::TensorAttributeExpr" + +[[fields]] +name = "attribute_value" +type = "::FlexFlow::TensorAttributeValue" diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.dtg.h b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.dtg.h new file mode 100644 index 0000000000..d34be357c5 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.dtg.h @@ -0,0 +1,141 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.variant.toml +/* proj-data +{ + "generated_from": "b91285329f12f1b409805cbf9be575b2" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_EXPR_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_EXPR_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "substitutions/tensor_pattern/tensor_attribute_key.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_list_access.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_list_size.dtg.h" +#include +#include +#include +#include +#include + +namespace FlexFlow { +struct TensorAttributeExpr { + TensorAttributeExpr() = delete; + explicit TensorAttributeExpr(::FlexFlow::TensorAttributeKey const &); + explicit TensorAttributeExpr(::FlexFlow::TensorAttributeListSize const &); + explicit TensorAttributeExpr( + ::FlexFlow::TensorAttributeListIndexAccess const &); + template + static constexpr bool IsPartOfTensorAttributeExpr_v = + std::is_same_v || + std::is_same_v || + std::is_same_v; + template + ReturnType visit(Visitor &&v) const { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::TensorAttributeKey>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::TensorAttributeListSize>()); + return result; + } + case 2: { + ReturnType result = + v(this->get<::FlexFlow::TensorAttributeListIndexAccess>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type TensorAttributeExpr", this->index())); + } + } + } + template + ReturnType visit(Visitor &&v) { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::TensorAttributeKey>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::TensorAttributeListSize>()); + return result; + } + case 2: { + ReturnType result = + v(this->get<::FlexFlow::TensorAttributeListIndexAccess>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type TensorAttributeExpr", this->index())); + } + } + } + template + bool has() const { + static_assert( + IsPartOfTensorAttributeExpr_v, + "TensorAttributeExpr::has() expected one of " + "[::FlexFlow::TensorAttributeKey, ::FlexFlow::TensorAttributeListSize, " + "::FlexFlow::TensorAttributeListIndexAccess], received T"); + return std::holds_alternative(this->raw_variant); + } + template + T const &get() const { + static_assert( + IsPartOfTensorAttributeExpr_v, + "TensorAttributeExpr::get() expected one of " + "[::FlexFlow::TensorAttributeKey, ::FlexFlow::TensorAttributeListSize, " + "::FlexFlow::TensorAttributeListIndexAccess], received T"); + return std::get(this->raw_variant); + } + template + T &get() { + static_assert( + IsPartOfTensorAttributeExpr_v, + "TensorAttributeExpr::get() expected one of " + "[::FlexFlow::TensorAttributeKey, ::FlexFlow::TensorAttributeListSize, " + "::FlexFlow::TensorAttributeListIndexAccess], received T"); + return std::get(this->raw_variant); + } + size_t index() const { + return this->raw_variant.index(); + } + bool operator==(TensorAttributeExpr const &) const; + bool operator!=(TensorAttributeExpr const &) const; + bool operator<(TensorAttributeExpr const &) const; + bool operator>(TensorAttributeExpr const &) const; + bool operator<=(TensorAttributeExpr const &) const; + bool operator>=(TensorAttributeExpr const &) const; + std::variant<::FlexFlow::TensorAttributeKey, + ::FlexFlow::TensorAttributeListSize, + ::FlexFlow::TensorAttributeListIndexAccess> + raw_variant; +}; +} // namespace FlexFlow +namespace std { +template <> +struct hash<::FlexFlow::TensorAttributeExpr> { + size_t operator()(::FlexFlow::TensorAttributeExpr const &) const; +}; +} // namespace std +namespace nlohmann { +template <> +struct adl_serializer<::FlexFlow::TensorAttributeExpr> { + static ::FlexFlow::TensorAttributeExpr from_json(json const &); + static void to_json(json &, ::FlexFlow::TensorAttributeExpr const &); +}; +} // namespace nlohmann +namespace FlexFlow { +std::string format_as(::FlexFlow::TensorAttributeExpr const &); +std::ostream &operator<<(std::ostream &, + ::FlexFlow::TensorAttributeExpr const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_EXPR_DTG_H diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.h b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.h new file mode 100644 index 0000000000..12515d2716 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_EXPR_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_EXPR_H + +#include "substitutions/tensor_pattern/tensor_attribute_value.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_expr.dtg.h" +#include "pcg/parallel_tensor_attrs.dtg.h" + +namespace FlexFlow { + +TensorAttributeValue + evaluate_attribute_expr(ParallelTensorAttrs const &attrs, + TensorAttributeExpr const &expr); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.variant.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.variant.toml new file mode 100644 index 0000000000..03ec0eb624 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.variant.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "TensorAttributeExpr" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "substitutions/tensor_pattern/tensor_attribute_key.dtg.h", + "substitutions/tensor_pattern/tensor_attribute_list_access.dtg.h", + "substitutions/tensor_pattern/tensor_attribute_list_size.dtg.h", +] + +[[values]] +type = "::FlexFlow::TensorAttributeKey" +key = "key" + +[[values]] +type = "::FlexFlow::TensorAttributeListSize" +key = "list_size" + +[[values]] +type = "::FlexFlow::TensorAttributeListIndexAccess" +key = "list_idx" diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.dtg.h b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.dtg.h new file mode 100644 index 0000000000..50a0aa49e8 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.dtg.h @@ -0,0 +1,40 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.enum.toml +/* proj-data +{ + "generated_from": "63a7c40c1e5b582f98b59750a35f0a08" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_KEY_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_KEY_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +enum class TensorAttributeKey { DIM_SIZES, DIM_DEGREES }; +std::string format_as(TensorAttributeKey); +std::ostream &operator<<(std::ostream &, TensorAttributeKey); +void to_json(::nlohmann::json &, TensorAttributeKey); +void from_json(::nlohmann::json const &, TensorAttributeKey &); +} // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::TensorAttributeKey) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_KEY_DTG_H diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.enum.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.enum.toml new file mode 100644 index 0000000000..3df36d13ac --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.enum.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "TensorAttributeKey" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "DIM_SIZES" + +[[values]] +name = "DIM_DEGREES" diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.h b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.h new file mode 100644 index 0000000000..473f4e1698 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.h @@ -0,0 +1,66 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.struct.toml +/* proj-data +{ + "generated_from": "41f5449cd700b6d7ab017f3efa39dc1d" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_LIST_ACCESS_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_LIST_ACCESS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include "substitutions/tensor_pattern/tensor_attribute_key.dtg.h" +#include +#include +#include + +namespace FlexFlow { +struct TensorAttributeListIndexAccess { + TensorAttributeListIndexAccess() = delete; + TensorAttributeListIndexAccess( + ::FlexFlow::TensorAttributeKey const &attribute_key, int const &index); + + bool operator==(TensorAttributeListIndexAccess const &) const; + bool operator!=(TensorAttributeListIndexAccess const &) const; + bool operator<(TensorAttributeListIndexAccess const &) const; + bool operator>(TensorAttributeListIndexAccess const &) const; + bool operator<=(TensorAttributeListIndexAccess const &) const; + bool operator>=(TensorAttributeListIndexAccess const &) const; + ::FlexFlow::TensorAttributeKey attribute_key; + int index; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::TensorAttributeListIndexAccess const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::TensorAttributeListIndexAccess from_json(json const &); + static void to_json(json &, FlexFlow::TensorAttributeListIndexAccess const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(TensorAttributeListIndexAccess const &); +std::ostream &operator<<(std::ostream &, + TensorAttributeListIndexAccess const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_LIST_ACCESS_DTG_H diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.struct.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.struct.toml new file mode 100644 index 0000000000..a57dd25845 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "TensorAttributeListIndexAccess" +features = [ + "eq", + "ord", + "hash", + "rapidcheck", + "json", + "fmt", +] + +includes = [ + "substitutions/tensor_pattern/tensor_attribute_key.dtg.h" +] + +[[fields]] +name = "attribute_key" +type = "::FlexFlow::TensorAttributeKey" + +[[fields]] +name = "index" +type = "int" diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_size.dtg.h b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_size.dtg.h new file mode 100644 index 0000000000..1630014bdf --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_size.dtg.h @@ -0,0 +1,63 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_size.struct.toml +/* proj-data +{ + "generated_from": "ec72cd39de5d1c0f0478696d7b83e4e9" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_LIST_SIZE_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_LIST_SIZE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include "substitutions/tensor_pattern/tensor_attribute_key.dtg.h" +#include +#include +#include + +namespace FlexFlow { +struct TensorAttributeListSize { + TensorAttributeListSize() = delete; + TensorAttributeListSize(::FlexFlow::TensorAttributeKey const &attribute_key); + + bool operator==(TensorAttributeListSize const &) const; + bool operator!=(TensorAttributeListSize const &) const; + bool operator<(TensorAttributeListSize const &) const; + bool operator>(TensorAttributeListSize const &) const; + bool operator<=(TensorAttributeListSize const &) const; + bool operator>=(TensorAttributeListSize const &) const; + ::FlexFlow::TensorAttributeKey attribute_key; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::TensorAttributeListSize const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::TensorAttributeListSize from_json(json const &); + static void to_json(json &, FlexFlow::TensorAttributeListSize const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(TensorAttributeListSize const &); +std::ostream &operator<<(std::ostream &, TensorAttributeListSize const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_LIST_SIZE_DTG_H diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_size.struct.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_size.struct.toml new file mode 100644 index 0000000000..c876696343 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_size.struct.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "TensorAttributeListSize" +features = [ + "eq", + "ord", + "hash", + "rapidcheck", + "json", + "fmt", +] + +includes = [ + "substitutions/tensor_pattern/tensor_attribute_key.dtg.h", +] + + +[[fields]] +name = "attribute_key" +type = "::FlexFlow::TensorAttributeKey" diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h new file mode 100644 index 0000000000..ecc4bc7da0 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h @@ -0,0 +1,56 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.struct.toml +/* proj-data +{ + "generated_from": "42a51afce383f1ddc3d70827aa94a68f" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_PATTERN_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_PATTERN_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "substitutions/tensor_pattern/tensor_attribute_constraint.dtg.h" +#include "utils/hash-utils.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct TensorAttributePattern { + TensorAttributePattern() = delete; + TensorAttributePattern( + std::unordered_set<::FlexFlow::TensorAttributeConstraint> const + &attribute_constraints); + + bool operator==(TensorAttributePattern const &) const; + bool operator!=(TensorAttributePattern const &) const; + std::unordered_set<::FlexFlow::TensorAttributeConstraint> + attribute_constraints; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::TensorAttributePattern const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::TensorAttributePattern from_json(json const &); + static void to_json(json &, FlexFlow::TensorAttributePattern const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(TensorAttributePattern const &); +std::ostream &operator<<(std::ostream &, TensorAttributePattern const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_PATTERN_DTG_H diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.struct.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.struct.toml new file mode 100644 index 0000000000..43f45e95b9 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.struct.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "TensorAttributePattern" +features = [ + "eq", + # "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "", + "substitutions/tensor_pattern/tensor_attribute_constraint.dtg.h", + "utils/hash-utils.h", +] + +[[fields]] +name = "attribute_constraints" +type = "std::unordered_set<::FlexFlow::TensorAttributeConstraint>" diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.dtg.h b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.dtg.h new file mode 100644 index 0000000000..948a7abae6 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.dtg.h @@ -0,0 +1,118 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.variant.toml +/* proj-data +{ + "generated_from": "d80cf2e618d64df284c2647430a12a86" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_VALUE_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_VALUE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "utils/fmt.h" +#include "utils/hash-utils-core.h" +#include +#include +#include +#include +#include +#include + +namespace FlexFlow { +struct TensorAttributeValue { + TensorAttributeValue() = delete; + explicit TensorAttributeValue(size_t const &); + explicit TensorAttributeValue(std::vector const &); + template + static constexpr bool IsPartOfTensorAttributeValue_v = + std::is_same_v || std::is_same_v>; + template + ReturnType visit(Visitor &&v) const { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get()); + return result; + } + case 1: { + ReturnType result = v(this->get>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type TensorAttributeValue", this->index())); + } + } + } + template + ReturnType visit(Visitor &&v) { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get()); + return result; + } + case 1: { + ReturnType result = v(this->get>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type TensorAttributeValue", this->index())); + } + } + } + template + bool has() const { + static_assert(IsPartOfTensorAttributeValue_v, + "TensorAttributeValue::has() expected one of [size_t, " + "std::vector], received T"); + return std::holds_alternative(this->raw_variant); + } + template + T const &get() const { + static_assert(IsPartOfTensorAttributeValue_v, + "TensorAttributeValue::get() expected one of [size_t, " + "std::vector], received T"); + return std::get(this->raw_variant); + } + template + T &get() { + static_assert(IsPartOfTensorAttributeValue_v, + "TensorAttributeValue::get() expected one of [size_t, " + "std::vector], received T"); + return std::get(this->raw_variant); + } + size_t index() const { + return this->raw_variant.index(); + } + bool operator==(TensorAttributeValue const &) const; + bool operator!=(TensorAttributeValue const &) const; + bool operator<(TensorAttributeValue const &) const; + bool operator>(TensorAttributeValue const &) const; + bool operator<=(TensorAttributeValue const &) const; + bool operator>=(TensorAttributeValue const &) const; + std::variant> raw_variant; +}; +} // namespace FlexFlow +namespace std { +template <> +struct hash<::FlexFlow::TensorAttributeValue> { + size_t operator()(::FlexFlow::TensorAttributeValue const &) const; +}; +} // namespace std +namespace nlohmann { +template <> +struct adl_serializer<::FlexFlow::TensorAttributeValue> { + static ::FlexFlow::TensorAttributeValue from_json(json const &); + static void to_json(json &, ::FlexFlow::TensorAttributeValue const &); +}; +} // namespace nlohmann +namespace FlexFlow { +std::string format_as(::FlexFlow::TensorAttributeValue const &); +std::ostream &operator<<(std::ostream &, + ::FlexFlow::TensorAttributeValue const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_VALUE_DTG_H diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.variant.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.variant.toml new file mode 100644 index 0000000000..91313f159b --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.variant.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "TensorAttributeValue" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "", + "utils/hash-utils-core.h", + "utils/fmt.h", +] + +[[values]] +type = "size_t" + +[[values]] +type = "std::vector" diff --git a/lib/substitutions/include/substitutions/unlabelled/closed_pattern_edge.dtg.h b/lib/substitutions/include/substitutions/unlabelled/closed_pattern_edge.dtg.h new file mode 100644 index 0000000000..6bf815791d --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/closed_pattern_edge.dtg.h @@ -0,0 +1,39 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/closed_pattern_edge.struct.toml +/* proj-data +{ + "generated_from": "b4086fd78ca7ec0475ed7abfd034304c" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_CLOSED_PATTERN_EDGE_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_CLOSED_PATTERN_EDGE_DTG_H + +#include "utils/graph.h" +#include +#include + +namespace FlexFlow { +struct ClosedPatternEdge { + ClosedPatternEdge() = delete; + ClosedPatternEdge(::FlexFlow::MultiDiEdge const &raw_edge); + + bool operator==(ClosedPatternEdge const &) const; + bool operator!=(ClosedPatternEdge const &) const; + bool operator<(ClosedPatternEdge const &) const; + bool operator>(ClosedPatternEdge const &) const; + bool operator<=(ClosedPatternEdge const &) const; + bool operator>=(ClosedPatternEdge const &) const; + ::FlexFlow::MultiDiEdge raw_edge; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ClosedPatternEdge const &) const; +}; +} // namespace std + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_CLOSED_PATTERN_EDGE_DTG_H diff --git a/lib/substitutions/include/substitutions/unlabelled/closed_pattern_edge.struct.toml b/lib/substitutions/include/substitutions/unlabelled/closed_pattern_edge.struct.toml new file mode 100644 index 0000000000..d609ca1c27 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/closed_pattern_edge.struct.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "ClosedPatternEdge" +features = [ + "eq", + "ord", + "hash", +] + +includes = [ + "utils/graph.h", +] + +[[fields]] +name = "raw_edge" +type = "::FlexFlow::MultiDiEdge" diff --git a/lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.dtg.h b/lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.dtg.h new file mode 100644 index 0000000000..5ce0e63073 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.dtg.h @@ -0,0 +1,39 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.struct.toml +/* proj-data +{ + "generated_from": "c67ec363a91ce090dc538dcf76fa1f12" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_DOWNWARD_OPEN_PATTERN_EDGE_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_DOWNWARD_OPEN_PATTERN_EDGE_DTG_H + +#include "utils/graph.h" +#include +#include + +namespace FlexFlow { +struct DownwardOpenPatternEdge { + DownwardOpenPatternEdge() = delete; + DownwardOpenPatternEdge(::FlexFlow::DownwardOpenMultiDiEdge const &raw_edge); + + bool operator==(DownwardOpenPatternEdge const &) const; + bool operator!=(DownwardOpenPatternEdge const &) const; + bool operator<(DownwardOpenPatternEdge const &) const; + bool operator>(DownwardOpenPatternEdge const &) const; + bool operator<=(DownwardOpenPatternEdge const &) const; + bool operator>=(DownwardOpenPatternEdge const &) const; + ::FlexFlow::DownwardOpenMultiDiEdge raw_edge; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::DownwardOpenPatternEdge const &) const; +}; +} // namespace std + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_DOWNWARD_OPEN_PATTERN_EDGE_DTG_H diff --git a/lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.h b/lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.h new file mode 100644 index 0000000000..9855d96e46 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_DOWNWARD_OPEN_PATTERN_EDGE_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_DOWNWARD_OPEN_PATTERN_EDGE_H + +#include "substitutions/unlabelled/downward_open_pattern_edge.dtg.h" + +namespace FlexFlow { + +int get_src_idx(DownwardOpenPatternEdge const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.struct.toml b/lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.struct.toml new file mode 100644 index 0000000000..2dda7498f0 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.struct.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "DownwardOpenPatternEdge" +features = [ + "eq", + "ord", + "hash", +] + +includes = [ + "utils/graph.h", +] + +[[fields]] +name = "raw_edge" +type = "::FlexFlow::DownwardOpenMultiDiEdge" diff --git a/lib/substitutions/include/substitutions/unlabelled/edge_splits.dtg.h b/lib/substitutions/include/substitutions/unlabelled/edge_splits.dtg.h new file mode 100644 index 0000000000..e92fe547b1 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/edge_splits.dtg.h @@ -0,0 +1,36 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/edge_splits.struct.toml +/* proj-data +{ + "generated_from": "f172b041a99f4de1d396e5d451a5e64d" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_EDGE_SPLITS_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_EDGE_SPLITS_DTG_H + +#include "utils/bidict.h" +#include "utils/graph.h" +#include +#include + +namespace FlexFlow { +struct UnlabelledPatternEdgeSplits { + UnlabelledPatternEdgeSplits() = delete; + UnlabelledPatternEdgeSplits( + ::FlexFlow::bidict<::FlexFlow::MultiDiEdge, + std::pair<::FlexFlow::OutputMultiDiEdge, + ::FlexFlow::InputMultiDiEdge>> const + &unwrapped); + + bool operator==(UnlabelledPatternEdgeSplits const &) const; + bool operator!=(UnlabelledPatternEdgeSplits const &) const; + ::FlexFlow::bidict< + ::FlexFlow::MultiDiEdge, + std::pair<::FlexFlow::OutputMultiDiEdge, ::FlexFlow::InputMultiDiEdge>> + unwrapped; +}; +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_EDGE_SPLITS_DTG_H diff --git a/lib/substitutions/include/substitutions/unlabelled/edge_splits.h b/lib/substitutions/include/substitutions/unlabelled/edge_splits.h new file mode 100644 index 0000000000..76135f34db --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/edge_splits.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_EDGE_SPLITS_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_EDGE_SPLITS_H + +#include +#include "substitutions/unlabelled/edge_splits.dtg.h" +#include "substitutions/unlabelled/closed_pattern_edge.dtg.h" +#include "substitutions/unlabelled/input_pattern_edge.dtg.h" +#include "substitutions/unlabelled/output_pattern_edge.dtg.h" + +namespace FlexFlow { + +std::pair get_split_edges(UnlabelledPatternEdgeSplits const &, ClosedPatternEdge const &); + +std::vector> as_closed_output_input_tuples(UnlabelledPatternEdgeSplits const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/unlabelled/edge_splits.struct.toml b/lib/substitutions/include/substitutions/unlabelled/edge_splits.struct.toml new file mode 100644 index 0000000000..fa714296c8 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/edge_splits.struct.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "UnlabelledPatternEdgeSplits" +features = [ + "eq", +] + +includes = [ + "utils/bidict.h", + "utils/graph.h", + "", +] + +[[fields]] +name = "unwrapped" +type = "::FlexFlow::bidict<::FlexFlow::MultiDiEdge, std::pair<::FlexFlow::OutputMultiDiEdge, ::FlexFlow::InputMultiDiEdge>>" diff --git a/lib/substitutions/include/substitutions/unlabelled/find_pattern_matches.h b/lib/substitutions/include/substitutions/unlabelled/find_pattern_matches.h new file mode 100644 index 0000000000..95f393109c --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/find_pattern_matches.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_FIND_PATTERN_MATCHES_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_FIND_PATTERN_MATCHES_H + +#include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" +#include "utils/graph.h" +#include "substitutions/unlabelled/match_additional_criterion.dtg.h" +#include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h" + +namespace FlexFlow { + +std::vector + find_pattern_matches(UnlabelledGraphPattern const &pattern, + OpenMultiDiGraphView const &graph, + MatchAdditionalCriterion const &additional_criterion); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.dtg.h b/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.dtg.h new file mode 100644 index 0000000000..f292acba14 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.dtg.h @@ -0,0 +1,39 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.struct.toml +/* proj-data +{ + "generated_from": "d0cc0e65c4e3feb2e9b8435947c99e5f" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_INPUT_PATTERN_EDGE_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_INPUT_PATTERN_EDGE_DTG_H + +#include "utils/graph.h" +#include +#include + +namespace FlexFlow { +struct InputPatternEdge { + InputPatternEdge() = delete; + InputPatternEdge(::FlexFlow::InputMultiDiEdge const &raw_edge); + + bool operator==(InputPatternEdge const &) const; + bool operator!=(InputPatternEdge const &) const; + bool operator<(InputPatternEdge const &) const; + bool operator>(InputPatternEdge const &) const; + bool operator<=(InputPatternEdge const &) const; + bool operator>=(InputPatternEdge const &) const; + ::FlexFlow::InputMultiDiEdge raw_edge; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::InputPatternEdge const &) const; +}; +} // namespace std + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_INPUT_PATTERN_EDGE_DTG_H diff --git a/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.h b/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.h new file mode 100644 index 0000000000..b05fa479db --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_INPUT_PATTERN_EDGE_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_INPUT_PATTERN_EDGE_H + +#include "substitutions/unlabelled/input_pattern_edge.dtg.h" +#include "substitutions/unlabelled/pattern_node.dtg.h" + +namespace FlexFlow { + +PatternNode get_dst_node(InputPatternEdge const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.struct.toml b/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.struct.toml new file mode 100644 index 0000000000..6da52b58aa --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.struct.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "InputPatternEdge" +features = [ + "eq", + "ord", + "hash", +] + +includes = [ + "utils/graph.h" +] + +[[fields]] +name = "raw_edge" +type = "::FlexFlow::InputMultiDiEdge" diff --git a/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.dtg.h b/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.dtg.h new file mode 100644 index 0000000000..e910be21ba --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.dtg.h @@ -0,0 +1,36 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.struct.toml +/* proj-data +{ + "generated_from": "2dff356c85dccda1fce8f714d41c6202" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MATCH_ADDITIONAL_CRITERION_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MATCH_ADDITIONAL_CRITERION_DTG_H + +#include "substitutions/unlabelled/pattern_edge.dtg.h" +#include "substitutions/unlabelled/pattern_node.dtg.h" +#include "utils/graph.h" +#include + +namespace FlexFlow { +struct MatchAdditionalCriterion { + MatchAdditionalCriterion() = delete; + MatchAdditionalCriterion( + std::function const &node_criterion, + std::function const + &edge_criterion); + + std::function + node_criterion; + std::function + edge_criterion; +}; +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MATCH_ADDITIONAL_CRITERION_DTG_H diff --git a/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.struct.toml b/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.struct.toml new file mode 100644 index 0000000000..c0107d84e9 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "MatchAdditionalCriterion" +features = [] + +includes = [ + "", + "utils/graph.h", + "substitutions/unlabelled/pattern_node.dtg.h", + "substitutions/unlabelled/pattern_edge.dtg.h", +] + +[[fields]] +name = "node_criterion" +type = "std::function" + +[[fields]] +name = "edge_criterion" +type = "std::function" diff --git a/lib/substitutions/include/substitutions/unlabelled/match_split.dtg.h b/lib/substitutions/include/substitutions/unlabelled/match_split.dtg.h new file mode 100644 index 0000000000..aa17814c52 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/match_split.dtg.h @@ -0,0 +1,29 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/match_split.struct.toml +/* proj-data +{ + "generated_from": "e44c4347e07263a493cbbd5caccedd22" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MATCH_SPLIT_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MATCH_SPLIT_DTG_H + +#include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h" +#include + +namespace FlexFlow { +struct MatchSplit { + MatchSplit() = delete; + MatchSplit(MultiDiGraphPatternMatch const &prefix_submatch, + MultiDiGraphPatternMatch const &postfix_submatch); + + bool operator==(MatchSplit const &) const; + bool operator!=(MatchSplit const &) const; + MultiDiGraphPatternMatch prefix_submatch; + MultiDiGraphPatternMatch postfix_submatch; +}; +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MATCH_SPLIT_DTG_H diff --git a/lib/substitutions/include/substitutions/unlabelled/match_split.h b/lib/substitutions/include/substitutions/unlabelled/match_split.h new file mode 100644 index 0000000000..221805daa9 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/match_split.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MATCH_SPLIT_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MATCH_SPLIT_H + +#include "substitutions/unlabelled/match_split.dtg.h" +#include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h" +#include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" +#include "substitutions/unlabelled/pattern_split.dtg.h" + +namespace FlexFlow { + +MatchSplit empty_match_split(); +MatchSplit apply_split(UnlabelledGraphPattern const &pattern, + MultiDiGraphPatternMatch const &match, + PatternSplit const &split); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/unlabelled/match_split.struct.toml b/lib/substitutions/include/substitutions/unlabelled/match_split.struct.toml new file mode 100644 index 0000000000..3fd77e7b4a --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/match_split.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "MatchSplit" +features = [ + "eq", + # "ord", +] + +includes = [ + "substitutions/unlabelled/multidigraph_pattern_match.dtg.h" +] + +[[fields]] +name = "prefix_submatch" +type = "MultiDiGraphPatternMatch" + +[[fields]] +name = "postfix_submatch" +type = "MultiDiGraphPatternMatch" diff --git a/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.dtg.h b/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.dtg.h new file mode 100644 index 0000000000..30f81504fe --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.dtg.h @@ -0,0 +1,36 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.struct.toml +/* proj-data +{ + "generated_from": "9842661a5d4e7d717f12d2c27da7df0d" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MULTIDIGRAPH_PATTERN_MATCH_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MULTIDIGRAPH_PATTERN_MATCH_DTG_H + +#include "substitutions/unlabelled/pattern_edge.dtg.h" +#include "substitutions/unlabelled/pattern_node.dtg.h" +#include "utils/bidict.h" +#include "utils/graph.h" +#include + +namespace FlexFlow { +struct MultiDiGraphPatternMatch { + MultiDiGraphPatternMatch() = delete; + MultiDiGraphPatternMatch( + ::FlexFlow::bidict<::FlexFlow::PatternNode, ::FlexFlow::Node> const + &node_assignment, + ::FlexFlow::bidict<::FlexFlow::PatternEdge, + ::FlexFlow::OpenMultiDiEdge> const &edge_assignment); + + bool operator==(MultiDiGraphPatternMatch const &) const; + bool operator!=(MultiDiGraphPatternMatch const &) const; + ::FlexFlow::bidict<::FlexFlow::PatternNode, ::FlexFlow::Node> node_assignment; + ::FlexFlow::bidict<::FlexFlow::PatternEdge, ::FlexFlow::OpenMultiDiEdge> + edge_assignment; +}; +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MULTIDIGRAPH_PATTERN_MATCH_DTG_H diff --git a/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.h b/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.h new file mode 100644 index 0000000000..550d4249f4 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MULTIDIGRAPH_PATTERN_MATCH_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MULTIDIGRAPH_PATTERN_MATCH_H + +#include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h" +#include "substitutions/unlabelled/edge_splits.dtg.h" + +namespace FlexFlow { + +MultiDiGraphPatternMatch empty_multidigraph_pattern_match(); +std::optional unsplit_matches( + MultiDiGraphPatternMatch const &prefix, + MultiDiGraphPatternMatch const &postfix, + UnlabelledPatternEdgeSplits const &edge_splits); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.struct.toml b/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.struct.toml new file mode 100644 index 0000000000..778767ab62 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.struct.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +# TODO(@lockshaw): rename to UnlabelledGraphPatternMatch +name = "MultiDiGraphPatternMatch" +features = [ + "eq", + # "ord", + # "hash", + # "fmt", +] + +includes = [ + "utils/bidict.h", + "utils/graph.h", + "substitutions/unlabelled/pattern_edge.dtg.h", + "substitutions/unlabelled/pattern_node.dtg.h", +] + +[[fields]] +name = "node_assignment" +type = "::FlexFlow::bidict<::FlexFlow::PatternNode, ::FlexFlow::Node>" + +[[fields]] +name = "edge_assignment" +type = "::FlexFlow::bidict<::FlexFlow::PatternEdge, ::FlexFlow::OpenMultiDiEdge>" diff --git a/lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.dtg.h b/lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.dtg.h new file mode 100644 index 0000000000..04ec8c656d --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.dtg.h @@ -0,0 +1,39 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.struct.toml +/* proj-data +{ + "generated_from": "3222696e351c3e203e008714245c737f" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_OUTPUT_PATTERN_EDGE_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_OUTPUT_PATTERN_EDGE_DTG_H + +#include "utils/graph.h" +#include +#include + +namespace FlexFlow { +struct OutputPatternEdge { + OutputPatternEdge() = delete; + OutputPatternEdge(::FlexFlow::OutputMultiDiEdge const &raw_edge); + + bool operator==(OutputPatternEdge const &) const; + bool operator!=(OutputPatternEdge const &) const; + bool operator<(OutputPatternEdge const &) const; + bool operator>(OutputPatternEdge const &) const; + bool operator<=(OutputPatternEdge const &) const; + bool operator>=(OutputPatternEdge const &) const; + ::FlexFlow::OutputMultiDiEdge raw_edge; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::OutputPatternEdge const &) const; +}; +} // namespace std + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_OUTPUT_PATTERN_EDGE_DTG_H diff --git a/lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.h b/lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.h new file mode 100644 index 0000000000..72e8ff02cf --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_OUTPUT_PATTERN_EDGE_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_OUTPUT_PATTERN_EDGE_H + +#include "substitutions/unlabelled/output_pattern_edge.dtg.h" +#include "substitutions/unlabelled/pattern_node.dtg.h" + +namespace FlexFlow { + +PatternNode get_src_node(OutputPatternEdge const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.struct.toml b/lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.struct.toml new file mode 100644 index 0000000000..362cbc3265 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.struct.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "OutputPatternEdge" +features = [ + "eq", + "ord", + "hash", +] + +includes = [ + "utils/graph.h", +] + +[[fields]] +name = "raw_edge" +type = "::FlexFlow::OutputMultiDiEdge" diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_edge.dtg.h b/lib/substitutions/include/substitutions/unlabelled/pattern_edge.dtg.h new file mode 100644 index 0000000000..4883590130 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_edge.dtg.h @@ -0,0 +1,39 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/pattern_edge.struct.toml +/* proj-data +{ + "generated_from": "a3eff166b0c8be2ddf3f7305eec094fd" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_EDGE_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_EDGE_DTG_H + +#include "utils/graph.h" +#include +#include + +namespace FlexFlow { +struct PatternEdge { + PatternEdge() = delete; + PatternEdge(::FlexFlow::OpenMultiDiEdge const &raw_edge); + + bool operator==(PatternEdge const &) const; + bool operator!=(PatternEdge const &) const; + bool operator<(PatternEdge const &) const; + bool operator>(PatternEdge const &) const; + bool operator<=(PatternEdge const &) const; + bool operator>=(PatternEdge const &) const; + ::FlexFlow::OpenMultiDiEdge raw_edge; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::PatternEdge const &) const; +}; +} // namespace std + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_EDGE_DTG_H diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_edge.h b/lib/substitutions/include/substitutions/unlabelled/pattern_edge.h new file mode 100644 index 0000000000..689f46012b --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_edge.h @@ -0,0 +1,27 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_PATTERN_EDGE_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_PATTERN_EDGE_H + +#include "substitutions/unlabelled/pattern_node.dtg.h" +#include "substitutions/unlabelled/pattern_edge.dtg.h" +#include "substitutions/unlabelled/input_pattern_edge.dtg.h" +#include "substitutions/unlabelled/output_pattern_edge.dtg.h" +#include "substitutions/unlabelled/closed_pattern_edge.dtg.h" + +namespace FlexFlow { + +std::unordered_set get_nodes(PatternEdge const &); +bool is_closed_edge(PatternEdge const &); +bool is_input_edge(PatternEdge const &); +bool is_output_edge(PatternEdge const &); + +ClosedPatternEdge require_closed_edge(PatternEdge const &); +InputPatternEdge require_input_edge(PatternEdge const &); +OutputPatternEdge require_output_edge(PatternEdge const &); + +PatternEdge pattern_edge_from_input_edge(InputPatternEdge const &); +PatternEdge pattern_edge_from_output_edge(OutputPatternEdge const &); +PatternEdge pattern_edge_from_closed_edge(ClosedPatternEdge const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_edge.struct.toml b/lib/substitutions/include/substitutions/unlabelled/pattern_edge.struct.toml new file mode 100644 index 0000000000..4abfa1c0db --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_edge.struct.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "PatternEdge" +features = [ + "eq", + "ord", + "hash", +] + +includes = [ + "utils/graph.h", +] + +[[fields]] +name = "raw_edge" +type = "::FlexFlow::OpenMultiDiEdge" diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_matching.h b/lib/substitutions/include/substitutions/unlabelled/pattern_matching.h new file mode 100644 index 0000000000..aee90413d5 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_matching.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_MATCHING_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_MATCHING_H + +#include "utils/graph.h" +#include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h" +#include "substitutions/unlabelled/match_split.dtg.h" +#include "substitutions/unlabelled/match_additional_criterion.dtg.h" +#include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" + +namespace FlexFlow { + +bool unlabelled_pattern_does_match(UnlabelledGraphPattern const &pattern, + OpenMultiDiGraphView const &graph, + MultiDiGraphPatternMatch const &match, + MatchAdditionalCriterion const &additional_criterion); + +std::vector + find_pattern_matches(UnlabelledGraphPattern const &pattern, + OpenMultiDiGraphView const &graph, + MatchAdditionalCriterion const &additional_criterion); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_node.dtg.h b/lib/substitutions/include/substitutions/unlabelled/pattern_node.dtg.h new file mode 100644 index 0000000000..56471c2e08 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_node.dtg.h @@ -0,0 +1,39 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/pattern_node.struct.toml +/* proj-data +{ + "generated_from": "a0e58ade010a9b250d2c1c378fde2639" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_NODE_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_NODE_DTG_H + +#include "utils/graph.h" +#include +#include + +namespace FlexFlow { +struct PatternNode { + PatternNode() = delete; + PatternNode(::FlexFlow::Node const &raw_node); + + bool operator==(PatternNode const &) const; + bool operator!=(PatternNode const &) const; + bool operator<(PatternNode const &) const; + bool operator>(PatternNode const &) const; + bool operator<=(PatternNode const &) const; + bool operator>=(PatternNode const &) const; + ::FlexFlow::Node raw_node; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::PatternNode const &) const; +}; +} // namespace std + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_NODE_DTG_H diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_node.struct.toml b/lib/substitutions/include/substitutions/unlabelled/pattern_node.struct.toml new file mode 100644 index 0000000000..ecd0253516 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_node.struct.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "PatternNode" +features = [ + "eq", + "ord", + "hash", +] + +includes = [ + "utils/graph.h", +] + +[[fields]] +name = "raw_node" +type = "::FlexFlow::Node" diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_split.dtg.h b/lib/substitutions/include/substitutions/unlabelled/pattern_split.dtg.h new file mode 100644 index 0000000000..453c4020a8 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_split.dtg.h @@ -0,0 +1,47 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/pattern_split.struct.toml +/* proj-data +{ + "generated_from": "8604edb5bd1a546ffa94ef496888e46d" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_SPLIT_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_SPLIT_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "substitutions/unlabelled/pattern_node.dtg.h" +#include "utils/graph.h" +#include +#include +#include + +namespace FlexFlow { +struct PatternSplit { + PatternSplit() = delete; + PatternSplit(std::unordered_set<::FlexFlow::PatternNode> const &first, + std::unordered_set<::FlexFlow::PatternNode> const &second); + + bool operator==(PatternSplit const &) const; + bool operator!=(PatternSplit const &) const; + std::unordered_set<::FlexFlow::PatternNode> first; + std::unordered_set<::FlexFlow::PatternNode> second; +}; +} // namespace FlexFlow + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::PatternSplit from_json(json const &); + static void to_json(json &, FlexFlow::PatternSplit const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(PatternSplit const &); +std::ostream &operator<<(std::ostream &, PatternSplit const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_SPLIT_DTG_H diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_split.h b/lib/substitutions/include/substitutions/unlabelled/pattern_split.h new file mode 100644 index 0000000000..50d3d37eb8 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_split.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_SPLIT_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_SPLIT_H + +#include "substitutions/unlabelled/pattern_split.dtg.h" +#include "substitutions/unlabelled/edge_splits.dtg.h" +#include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" + +namespace FlexFlow { + +PatternSplit find_even_split(UnlabelledGraphPattern const &); + +GraphSplit get_raw_split(PatternSplit const &); + +UnlabelledPatternEdgeSplits get_edge_splits(UnlabelledGraphPattern const &pattern, PatternSplit const &split); + +std::pair + apply_split(UnlabelledGraphPattern const &, PatternSplit const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_split.struct.toml b/lib/substitutions/include/substitutions/unlabelled/pattern_split.struct.toml new file mode 100644 index 0000000000..04d1080ff7 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_split.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "PatternSplit" +features = [ + "eq", + # "ord", + "json", + "fmt", +] + +includes = [ + "utils/graph.h", + "", + "substitutions/unlabelled/pattern_node.dtg.h", +] + +[[fields]] +name = "first" +type = "std::unordered_set<::FlexFlow::PatternNode>" + +[[fields]] +name = "second" +type = "std::unordered_set<::FlexFlow::PatternNode>" diff --git a/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.dtg.h b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.dtg.h new file mode 100644 index 0000000000..a2ba6c26d2 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.dtg.h @@ -0,0 +1,24 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.struct.toml +/* proj-data +{ + "generated_from": "f494ed79eb1ba4010155e456b452157f" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UNLABELLED_GRAPH_PATTERN_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UNLABELLED_GRAPH_PATTERN_DTG_H + +#include "utils/graph.h" + +namespace FlexFlow { +struct UnlabelledGraphPattern { + UnlabelledGraphPattern() = delete; + UnlabelledGraphPattern(::FlexFlow::OpenMultiDiGraphView const &raw_graph); + + ::FlexFlow::OpenMultiDiGraphView raw_graph; +}; +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UNLABELLED_GRAPH_PATTERN_DTG_H diff --git a/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.h b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.h new file mode 100644 index 0000000000..822e51588a --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UNLABELLED_GRAPH_PATTERN_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UNLABELLED_GRAPH_PATTERN_H + +#include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" +#include "substitutions/unlabelled/pattern_node.dtg.h" +#include "substitutions/unlabelled/pattern_edge.dtg.h" +#include "substitutions/unlabelled/upward_open_pattern_edge.dtg.h" +#include "substitutions/unlabelled/downward_open_pattern_edge.dtg.h" + +namespace FlexFlow { + +size_t num_nodes(UnlabelledGraphPattern const &); +bool is_singleton_pattern(UnlabelledGraphPattern const &); +std::unordered_set get_nodes(UnlabelledGraphPattern const &); +std::unordered_set get_edges(UnlabelledGraphPattern const &); +std::vector get_topological_ordering(UnlabelledGraphPattern const &); + +std::unordered_set get_incoming_edges(UnlabelledGraphPattern const &, PatternNode const &); +std::unordered_set get_outgoing_edges(UnlabelledGraphPattern const &, PatternNode const &); + +UnlabelledGraphPattern get_subgraph(UnlabelledGraphPattern const &, std::unordered_set const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.struct.toml b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.struct.toml new file mode 100644 index 0000000000..03f4bd5523 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.struct.toml @@ -0,0 +1,10 @@ +namespace = "FlexFlow" +name = "UnlabelledGraphPattern" +features = [] +includes = [ + "utils/graph.h" +] + +[[fields]] +name = "raw_graph" +type = "::FlexFlow::OpenMultiDiGraphView" diff --git a/lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.dtg.h b/lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.dtg.h new file mode 100644 index 0000000000..82440b5820 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.dtg.h @@ -0,0 +1,39 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.struct.toml +/* proj-data +{ + "generated_from": "a1d4c9d1dd94eb456c5e29d80ad579da" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UPWARD_OPEN_PATTERN_EDGE_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UPWARD_OPEN_PATTERN_EDGE_DTG_H + +#include "utils/graph.h" +#include +#include + +namespace FlexFlow { +struct UpwardOpenPatternEdge { + UpwardOpenPatternEdge() = delete; + UpwardOpenPatternEdge(::FlexFlow::UpwardOpenMultiDiEdge const &raw_edge); + + bool operator==(UpwardOpenPatternEdge const &) const; + bool operator!=(UpwardOpenPatternEdge const &) const; + bool operator<(UpwardOpenPatternEdge const &) const; + bool operator>(UpwardOpenPatternEdge const &) const; + bool operator<=(UpwardOpenPatternEdge const &) const; + bool operator>=(UpwardOpenPatternEdge const &) const; + ::FlexFlow::UpwardOpenMultiDiEdge raw_edge; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::UpwardOpenPatternEdge const &) const; +}; +} // namespace std + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UPWARD_OPEN_PATTERN_EDGE_DTG_H diff --git a/lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.h b/lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.h new file mode 100644 index 0000000000..998cf1a519 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UPWARD_OPEN_PATTERN_EDGE_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UPWARD_OPEN_PATTERN_EDGE_H + +#include "substitutions/unlabelled/upward_open_pattern_edge.dtg.h" + +namespace FlexFlow { + +int get_dst_idx(UpwardOpenPatternEdge const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.struct.toml b/lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.struct.toml new file mode 100644 index 0000000000..a4c3bad809 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.struct.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "UpwardOpenPatternEdge" +features = [ + "eq", + "ord", + "hash", +] + +includes = [ + "utils/graph.h", +] + +[[fields]] +name = "raw_edge" +type = "::FlexFlow::UpwardOpenMultiDiEdge" diff --git a/lib/substitutions/src/graph_pattern.cc b/lib/substitutions/src/graph_pattern.cc deleted file mode 100644 index 73f7b2c62d..0000000000 --- a/lib/substitutions/src/graph_pattern.cc +++ /dev/null @@ -1,257 +0,0 @@ -#include "substitutions/graph_pattern.h" -#include "op-attrs/operator_attrs.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "pcg/parallel_computation_graph.h" -#include "substitutions/get_attribute.h" -#include "substitutions/graph_pattern_match.h" -#include "substitutions/operator_pattern.h" -#include "substitutions/parallel_tensor_pattern.h" - -namespace FlexFlow { - -std::optional - evaluate_list_index_access(int index, - std::optional const &v) { - if (!v.has_value() || - !std::holds_alternative>(v.value()) || - !std::holds_alternative>( - v.value())) { - return std::nullopt; - } - - if (index >= MAX_TENSOR_DIM) { - return std::nullopt; - } - - if (std::holds_alternative>(v.value())) { - return get>(v.value()).at(index); - } else { - return get>(v.value()).at(index); - } -} - -std::optional - evaluate_list_index_access(int const &index, - std::optional const &v) { - if (!v.has_value() || !std::holds_alternative>(v.value())) { - return std::nullopt; - } - - auto vec = get>(v.value()); - - if (index >= vec.size()) { - return std::nullopt; - } - - return vec.at(index); -} - -std::optional - evaluate_list_size(std::optional const &v) { - return MAX_TENSOR_DIM; -} - -std::optional - evaluate_list_size(std::optional const &v) { - if (!v.has_value() || !std::holds_alternative>(v.value())) { - return std::nullopt; - } - - return (int)get>(v.value()).size(); -} - -struct EvaluateOperatorAttributeExpr { - EvaluateOperatorAttributeExpr(Operator const &attrs) : attrs(attrs) {} - - std::optional - operator()(OperatorAttributeKey const &key) { - return get_attribute(this->attrs.attrs, key); - } - - std::optional - operator()(ListIndexAccess const &index_access) { - std::optional v = - get_attribute(this->attrs.attrs, index_access.attribute_key); - return evaluate_list_index_access(index_access.index, v); - } - - std::optional - operator()(ListSize const &list_size) { - std::optional v = - get_attribute(this->attrs.attrs, list_size.attribute_key); - return evaluate_list_size(v); - } - -private: - Operator attrs; -}; - -std::optional - evaluate_tensor_attribute_expr(ParallelTensor const &, - AttributeExpr const &); - -struct EvaluateTensorAttributeExpr { - EvaluateTensorAttributeExpr(ParallelTensor const &tensor_shape) - : tensor_shape(tensor_shape) {} - - template - std::optional evaluate(T const &t) { - return this->operator()(t); - } - - std::optional operator()(TensorAttributeKey key) { - switch (key) { - case TensorAttributeKey::DIM_SIZES: { - std::vector result; - for (ParallelDim const &dim : ff_ordered(this->tensor_shape.dims)) { - result.push_back(dim.size); - } - return result; - } - case TensorAttributeKey::DIM_DEGREES: { - std::vector result; - for (ParallelDim const &dim : ff_ordered(this->tensor_shape.dims)) { - result.push_back(dim.degree); - } - return result; - } - default: - throw std::runtime_error("Unknown TensorAttributeKey"); - } - } - - std::optional - operator()(ListIndexAccess const &index_access) { - std::optional v = - this->evaluate(index_access.attribute_key); - return evaluate_list_index_access(index_access.index, v); - } - - std::optional - operator()(ListSize const &list_size) { - return evaluate_list_size(this->evaluate(list_size.attribute_key)); - } - -private: - ParallelTensor tensor_shape; -}; - -std::optional - evaluate_attribute_expr(ParallelTensor const &tensor_shape, - AttributeExpr const &expr) { - return visit(EvaluateTensorAttributeExpr(tensor_shape), expr); -} - -std::optional - evaluate_attribute_expr(Operator const &attrs, - AttributeExpr const &expr) { - return visit(EvaluateOperatorAttributeExpr(attrs), expr); -} - -template -std::optional satisfies(ConstraintType constraint_type, - V const &constraint_value, - std::optional const &maybe_attribute_value) { - if (!maybe_attribute_value.has_value()) { - return std::nullopt; - } - V attr_val = maybe_attribute_value.value(); - - if (attr_val.index() != constraint_value.index()) { - return std::nullopt; - } - - if (constraint_type == ConstraintType::EQUAL) { - return attr_val == constraint_value; - } else { - throw std::runtime_error("Unknown constraint_type"); - } -} - -std::optional satisfies(ParallelTensor const &tensor_shape, - TensorAttributeConstraint const &constraint) { - auto value = evaluate_attribute_expr(tensor_shape, constraint.attribute_expr); - return satisfies( - constraint.constraint_type, constraint.attribute_value, value); -} - -std::optional satisfies(Operator const ¶ms, - OperatorAttributeConstraint const &constraint) { - auto value = evaluate_attribute_expr(params, constraint.attribute_expr); - OperatorAttributeValue v = value.value(); - return satisfies( - constraint.constraint_type, constraint.attribute_value, value); -} - -template -std::optional optional_all_of(Container const &container, - Function const &func) { - for (auto const &element : container) { - std::optional condition = func(element); - if (!condition.has_value()) { - return std::nullopt; - } - - if (!condition.value()) { - return false; - } - } - return true; -} - -std::optional satisfies(Operator const ¶ms, - OperatorPattern const &pattern) { - return optional_all_of(pattern.attribute_constraints, - [&](OperatorAttributeConstraint const &c) { - return satisfies(params, c); - }); -} - -std::optional satisfies(ParallelTensor const ¶ms, - ParallelTensorPattern const &pattern) { - return optional_all_of( - pattern.attribute_constraints, - [&](TensorAttributeConstraint const &c) { return satisfies(params, c); }); -} - -bool operator_satisfies(Operator const ¶ms, - OperatorPattern const &pattern) { - return satisfies(params, pattern).value_or(false); -} - -bool parallel_tensor_satisfies(ParallelTensor const ¶ms, - ParallelTensorPattern const &pattern) { - return satisfies(params, pattern).value_or(false); -} - -bool assignment_satisfies(SubParallelComputationGraph const &pcg, - GraphPattern const &pattern, - MultiDiGraphPatternMatch const &patternMatch) { - bool result = true; - for (auto const &kv : patternMatch.node_assignment) { - Node patternNode = kv.first; - Node pcgNode = kv.second; - std::optional constraintResult = - satisfies(pcg.at(pcgNode), pattern.value().at(patternNode)); - result &= constraintResult.value_or(false); - } - - for (auto const &kv : patternMatch.edge_assignment) { - OpenMultiDiEdge patternEdge = kv.first; - OpenMultiDiEdge pcgEdge = kv.second; - std::optional constraintResult = - satisfies(pcg.at(pcgEdge), pattern.value().at(patternEdge)); - result &= constraintResult.value_or(false); - } - - result &= pattern_matches( - pattern, - pcg, - patternMatch, - MatchAdditionalCriterion{[](Node const &, Node const &) { return true; }, - [](OpenMultiDiEdge const &, - OpenMultiDiEdge const &) { return true; }}); - - return result; -} -} // namespace FlexFlow diff --git a/lib/substitutions/src/graph_pattern_match.cc b/lib/substitutions/src/graph_pattern_match.cc deleted file mode 100644 index f9c6b9a773..0000000000 --- a/lib/substitutions/src/graph_pattern_match.cc +++ /dev/null @@ -1,305 +0,0 @@ -#include "substitutions/graph_pattern.h" -#include "utils/hash-utils.h" -#include - -namespace FlexFlow { - -GraphSplit split_pattern(OpenMultiDiGraphView const &pattern) { - std::vector topological_ordering = get_topological_ordering(pattern); - assert(topological_ordering.size() >= 2); - - int split_point = topological_ordering.size() / 2; - auto split = vector_split(topological_ordering, split_point); - std::unordered_set prefix(split.first.begin(), split.first.end()); - std::unordered_set postfix(split.second.begin(), split.second.end()); - return {prefix, postfix}; -} - -std::pair - apply_split(OpenMultiDiGraphView const &pattern, GraphSplit const &split) { - return {get_subgraph(pattern, split.first), - get_subgraph(pattern, split.second)}; -} - -/* -Given a match and a pattern split, gets the submatches in subpatterns. -*/ -MatchSplit apply_split(OpenMultiDiGraphView const &pattern, - MultiDiGraphPatternMatch const &match, - GraphSplit const &split) { - auto prefix = split.first; - auto postfix = split.second; - - MatchSplit result; - - for (auto const &kv : match.node_assignment) { - Node pattern_node = kv.first; - Node graph_node = kv.second; - if (contains(split.first, pattern_node)) { - result.prefix_submatch.node_assignment.equate(pattern_node, graph_node); - } else { - assert(contains(split.second, pattern_node)); - result.postfix_submatch.node_assignment.equate(pattern_node, graph_node); - } - } - - auto edge_splits = get_edge_splits(pattern, split); - - std::function - handle_edge = [&](OpenMultiDiEdge const &pattern_edge, - OpenMultiDiEdge const &graph_edge) -> void { - auto edge_nodes = get_nodes(pattern_edge); - if (is_subseteq_of(edge_nodes, prefix)) { - result.prefix_submatch.edge_assignment.equate(pattern_edge, graph_edge); - } else if (is_subseteq_of(edge_nodes, postfix)) { - result.postfix_submatch.edge_assignment.equate(pattern_edge, graph_edge); - } else { - assert(is_standard_edge(pattern_edge)); - assert(is_standard_edge(graph_edge)); - auto standard_edge = std::get(pattern_edge); - auto divided = edge_splits.at_l(standard_edge); - auto divided_graph_edge = split_edge(get(graph_edge)); - handle_edge(divided.first, divided_graph_edge.first); - handle_edge(divided.second, divided_graph_edge.second); - } - }; - - for (auto const &kv : match.edge_assignment) { - OpenMultiDiEdge pattern_edge = kv.first; - OpenMultiDiEdge graph_edge = match.edge_assignment.at_l(pattern_edge); - handle_edge(pattern_edge, graph_edge); - } - - return result; -} - -bool is_singleton_pattern(OpenMultiDiGraphView const &pattern) { - return num_nodes(pattern) == 1; -} - -bool pattern_matches(OpenMultiDiGraphView const &pattern, - OpenMultiDiGraphView const &graph, - MultiDiGraphPatternMatch const &match, - MatchAdditionalCriterion const &additional_criterion) { - if (is_singleton_pattern(pattern)) { - Node pattern_node = get_only(get_nodes(pattern)); - Node graph_matched_node = match.node_assignment.at_l(pattern_node); - if (!additional_criterion.node_criterion(pattern_node, - graph_matched_node)) { - return false; - } - for (OpenMultiDiEdge const &e : get_edges(pattern)) { - OpenMultiDiEdge graph_matched_edge = match.edge_assignment.at_l(e); - - assert(is_input_edge(e) || is_output_edge(e)); - if (is_input_edge(e)) { - if (is_output_edge(graph_matched_edge)) { - return false; - } - UpwardOpenMultiDiEdge matched_edge = - narrow(graph_matched_edge).value(); - InputMultiDiEdge input_edge = std::get(e); - if (match.node_assignment.at_l(input_edge.dst) != - get_dst_node(matched_edge)) { - return false; - } - } else { - if (is_input_edge(graph_matched_edge)) { - return false; - } - DownwardOpenMultiDiEdge matched_edge = - narrow(graph_matched_edge).value(); - OutputMultiDiEdge output_edge = std::get(e); - if (match.node_assignment.at_l(output_edge.src) != - get_src_node(matched_edge)) { - return false; - } - } - - if (!additional_criterion.edge_criterion(e, graph_matched_edge)) { - return false; - } - } - - return true; - } - - auto split = split_pattern(pattern); - auto subpatterns = apply_split(pattern, split); - auto submatches = apply_split(pattern, match, split); - - return pattern_matches(subpatterns.first, - graph, - submatches.prefix_submatch, - additional_criterion) && - pattern_matches(subpatterns.second, - graph, - submatches.postfix_submatch, - additional_criterion); -} - -template -bool dst_compare(T const &lhs, T const &rhs) { - return get_dst_idx(lhs) < get_dst_idx(rhs); -} - -template -bool src_compare(T const &lhs, T const &rhs) { - return get_src_idx(lhs) < get_src_idx(rhs); -} - -std::optional - get_candidate_singleton_match(OpenMultiDiGraphView const &pattern, - OpenMultiDiGraphView const &graph, - Node const &graph_node) { - assert(is_singleton_pattern(pattern)); - - Node pattern_node = get_only(get_nodes(pattern)); - - MultiDiGraphPatternMatch match; - match.node_assignment.equate(pattern_node, graph_node); - - std::unordered_set incoming = - get_incoming_edges(graph, graph_node); - std::unordered_set outgoing = - get_outgoing_edges(graph, graph_node); - - std::unordered_set pattern_incoming = - get_incoming_edges(pattern, pattern_node); - std::unordered_set pattern_outgoing = - get_outgoing_edges(pattern, pattern_node); - - if (!pattern_incoming.empty() && pattern_incoming.size() != incoming.size()) { - return std::nullopt; - } - - if (!pattern_outgoing.empty() && pattern_outgoing.size() != outgoing.size()) { - return std::nullopt; - } - - std::vector incoming_ordered = - sorted_by(incoming, dst_compare); - std::vector outgoing_ordered = - sorted_by(outgoing, src_compare); - - std::vector pattern_incoming_ordered = - sorted_by(pattern_incoming, dst_compare); - std::vector pattern_outgoing_ordered = - sorted_by(pattern_outgoing, src_compare); - - if (pattern_incoming.size()) { - std::unordered_map node_port_mapping; - for (int i = 0; i < incoming_ordered.size(); ++i) { - UpwardOpenMultiDiEdge graph_edge = incoming_ordered[i], - pattern_edge = pattern_incoming_ordered[i]; - NodePort graph_port = get_dst_idx(graph_edge), - pattern_port = get_dst_idx(pattern_edge); - if (!contains_key(node_port_mapping, graph_port)) { - node_port_mapping.emplace(graph_port, pattern_port); - } else { - if (pattern_port != node_port_mapping.at(graph_port)) { - return std::nullopt; - } - } - match.edge_assignment.equate(widen(pattern_edge), - widen(graph_edge)); - } - } - - if (pattern_outgoing.size()) { - std::unordered_map node_port_mapping; - for (int i = 0; i < outgoing_ordered.size(); ++i) { - DownwardOpenMultiDiEdge graph_edge = outgoing_ordered[i], - pattern_edge = pattern_outgoing_ordered[i]; - NodePort graph_port = get_src_idx(graph_edge), - pattern_port = get_src_idx(pattern_edge); - if (!contains_key(node_port_mapping, graph_port)) { - node_port_mapping.insert({graph_port, pattern_port}); - } else { - if (pattern_port != node_port_mapping.at(graph_port)) { - return std::nullopt; - } - } - match.edge_assignment.equate(widen(pattern_edge), - widen(graph_edge)); - } - } - - return match; -} - -std::optional unsplit_matches( - MultiDiGraphPatternMatch const &prefix, - MultiDiGraphPatternMatch const &postfix, - bidict> const - &edge_splits) { - MultiDiGraphPatternMatch result; - std::unordered_set handled; - for (auto const &kv : edge_splits) { - MultiDiEdge standard_edge = kv.first; - OutputMultiDiEdge output_edge = kv.second.first; - InputMultiDiEdge input_edge = kv.second.second; - handled.insert(output_edge); - handled.insert(input_edge); - - OpenMultiDiEdge output_graph_edge = - prefix.edge_assignment.at_l(output_edge); - OpenMultiDiEdge input_graph_edge = postfix.edge_assignment.at_l(input_edge); - if (output_graph_edge == input_graph_edge) { - result.edge_assignment.equate(standard_edge, output_graph_edge); - } else { - return std::nullopt; - } - } - - for (auto const &kv : - merge_maps(prefix.edge_assignment, postfix.edge_assignment)) { - if (!contains(handled, kv.first)) { - result.edge_assignment.equate(kv.first, kv.second); - } - } - - result.node_assignment = - merge_maps(prefix.node_assignment, postfix.node_assignment); - - return result; -} - -std::vector - find_pattern_matches(OpenMultiDiGraphView const &pattern, - OpenMultiDiGraphView const &graph, - MatchAdditionalCriterion const &additional_criterion) { - std::vector matches; - if (is_singleton_pattern(pattern)) { - for (Node const &graph_node : get_nodes(graph)) { - std::optional candidate = - get_candidate_singleton_match(pattern, graph, graph_node); - if (candidate.has_value() && - pattern_matches( - pattern, graph, candidate.value(), additional_criterion)) { - matches.push_back(candidate.value()); - } - } - } else { - GraphSplit split = split_pattern(pattern); - auto subpatterns = apply_split(pattern, split); - auto prefix_matches = - find_pattern_matches(subpatterns.first, graph, additional_criterion); - auto postfix_matches = - find_pattern_matches(subpatterns.second, graph, additional_criterion); - auto edge_splits = get_edge_splits(pattern, split); - for (MultiDiGraphPatternMatch const &prefix_match : prefix_matches) { - for (MultiDiGraphPatternMatch const &postfix_match : postfix_matches) { - std::optional unsplit = - unsplit_matches(prefix_match, postfix_match, edge_splits); - if (unsplit.has_value()) { - matches.push_back(unsplit.value()); - } - } - } - } - - return matches; -} - -} // namespace FlexFlow diff --git a/lib/substitutions/src/sub_parallel_computation_graph.cc b/lib/substitutions/src/sub_parallel_computation_graph.cc deleted file mode 100644 index e8cb093222..0000000000 --- a/lib/substitutions/src/sub_parallel_computation_graph.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "substitutions/sub_parallel_computation_graph.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/substitutions/src/substitution.cc b/lib/substitutions/src/substitution.cc index 335d021a2b..20e14c2256 100644 --- a/lib/substitutions/src/substitution.cc +++ b/lib/substitutions/src/substitution.cc @@ -2,475 +2,340 @@ namespace FlexFlow { -struct DeriveValidOperatorAttributeExpr { - template - std::unordered_set> - operator()(T const &t) { - return derive_valid_operator_attribute_expr(t); - } - - std::unordered_set> - derive_valid_operator_attribute_expr(OperatorAttributeKey const &key) { - return {key}; - } - - std::unordered_set> - derive_valid_operator_attribute_expr( - ListIndexAccess const &access) { - return {access, access.attribute_key}; - } - - std::unordered_set> - derive_valid_operator_attribute_expr( - ListSize const &ls) { - return {ls, ls.attribute_key}; - } -}; - -std::unordered_set> - get_valid_operator_attribute_exprs(OperatorPattern const &pattern) { - return set_union(transform( - pattern.attribute_constraints, [](OperatorAttributeConstraint const &t) { - return visit(DeriveValidOperatorAttributeExpr{}, t.attribute_expr); - })); -} - -bool is_valid_operator_attribute_expr( - OperatorPattern const &pattern, - AttributeExpr const &expr) { - return contains(get_valid_operator_attribute_exprs(pattern), expr); -} - -struct IsValidOperatorAttributeExprFunctor { - GraphPattern const &graph_pattern; - - template - bool operator()(T const &t) const { - return is_valid(t); - } - - bool is_valid(OperatorAttrAccess const &t) const { - return is_valid_operator_attribute_expr(graph_pattern.value().at(t.node), - t.attr_expr); - } - - bool is_valid(AttrConstant const &t) const { - return true; - } -}; - -bool is_valid_operator_attribute_expr(GraphPattern const &pattern, - OperatorAttributeExpr const &expr) { - return visit(IsValidOperatorAttributeExprFunctor{pattern}, expr); -} - -bool is_valid_substitution(Substitution const &s) { - for (Node const &node : get_nodes(s.output_graph_expr.value())) { - for (OperatorAttributeExpr expr : - values(s.output_graph_expr.value().at(node).assignments)) { - if (!is_valid_operator_attribute_expr(s.input_graph, expr)) { - return false; - } - } - } - return true; -} - -struct EvaluateOperatorAttributeExpr { - SubParallelComputationGraph const &graph; - MultiDiGraphPatternMatch const &match; - - template - OperatorAttributeValue operator()(T const &t) { - return evaluate(t); - } - - OperatorAttributeValue evaluate(OperatorAttrAccess const &t) { - Node node_in_pattern = t.node; - Node node_in_pcg = match.node_assignment.at_l(node_in_pattern); - return evaluate_attribute_expr(graph.at(node_in_pcg), t.attr_expr).value(); - } - - OperatorAttributeValue evaluate(AttrConstant const &t) { - return t.value; - } -}; - -OperatorAttributeValue - evaluate_graph_attribute_expr(SubParallelComputationGraph const &g, - MultiDiGraphPatternMatch const &match, - OperatorAttributeExpr const &expr) { - return visit(EvaluateOperatorAttributeExpr{g, match}, expr); -} - -Operator get_operator_attrs(SubParallelComputationGraph const &graph, - MultiDiGraphPatternMatch const &match, - OperatorAttrAssignment const &assignment) { - std::unordered_map assignments; - for (auto const &[key, expr] : assignment.assignments) { - OperatorAttributeValue value = - evaluate_graph_attribute_expr(graph, match, expr); - assignments.emplace(key, value); - } - assert(contains_key(assignments, OperatorAttributeKey::OP_TYPE)); - assert(std::holds_alternative( - assignments.at(OperatorAttributeKey::OP_TYPE))); - OperatorType op_type = - std::get(assignments.at(OperatorAttributeKey::OP_TYPE)); - switch (op_type) { - case OperatorType::BATCHMATMUL: - return Operator{ - BatchMatmulAttrs{std::get(assignments.at( - OperatorAttributeKey::A_SEQ_LENGTH_DIM)), - std::get(assignments.at( - OperatorAttributeKey::B_SEQ_LENGTH_DIM))}, - std::nullopt}; - case OperatorType::BATCHNORM: - return Operator{BatchNormAttrs{std::get( - assignments.at(OperatorAttributeKey::RELU))}, - std::nullopt}; - case OperatorType::CAST: - return Operator{CastAttrs{std::get( - assignments.at(OperatorAttributeKey::DATA_TYPE))}, - std::nullopt}; - case OperatorType::CONCAT: - return Operator{ - ConcatAttrs{ - std::get(assignments.at(OperatorAttributeKey::AXIS)), - std::get(assignments.at(OperatorAttributeKey::NUM_INPUTS))}, - std::nullopt}; - case OperatorType::CONV2D: - return Operator{ - Conv2DAttrs{ - std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), - std::get(assignments.at(OperatorAttributeKey::KERNEL_H)), - std::get(assignments.at(OperatorAttributeKey::KERNEL_W)), - std::get(assignments.at(OperatorAttributeKey::STRIDE_H)), - std::get(assignments.at(OperatorAttributeKey::STRIDE_W)), - std::get(assignments.at(OperatorAttributeKey::PADDING_H)), - std::get(assignments.at(OperatorAttributeKey::PADDING_W)), - std::get(assignments.at(OperatorAttributeKey::GROUPS)), - std::get( - assignments.at(OperatorAttributeKey::ACTIVATION)), - std::get(assignments.at(OperatorAttributeKey::USE_BIAS))}, - std::nullopt}; - case OperatorType::DROPOUT: - return Operator{DropoutAttrs{std::get(assignments.at( - OperatorAttributeKey::RATE)), - std::get(assignments.at( - OperatorAttributeKey::SEED))}, - std::nullopt}; - case OperatorType::EW_ADD: - case OperatorType::EW_DIV: - case OperatorType::EW_EQUAL: - case OperatorType::EW_GREATER: - case OperatorType::EW_LESS: - case OperatorType::EW_MAX: - case OperatorType::EW_MIN: - case OperatorType::EW_MUL: - case OperatorType::EW_SUB: - return Operator{ - ElementBinaryAttrs{op_type, - std::get(assignments.at( - OperatorAttributeKey::DATA_TYPE)), - std::get(assignments.at( - OperatorAttributeKey::SHOULD_BROADCAST_LHS)), - std::get(assignments.at( - OperatorAttributeKey::SHOULD_BROADCAST_RHS))}, - std::nullopt}; - case OperatorType::SCALAR_ADD: - case OperatorType::SCALAR_FLOOR_DIV: - case OperatorType::SCALAR_MULTIPLY: - case OperatorType::SCALAR_SUB: - case OperatorType::SCALAR_TRUE_DIV: - return Operator{ - ElementScalarUnaryAttrs{ - op_type, - std::get(assignments.at(OperatorAttributeKey::SCALAR))}, - std::nullopt}; - case OperatorType::EXP: - case OperatorType::IDENTITY: - case OperatorType::GELU: - case OperatorType::RSQRT: - case OperatorType::POW: - case OperatorType::SIN: - case OperatorType::COS: - return Operator{ElementUnaryAttrs{op_type}, std::nullopt}; - case OperatorType::EMBEDDING: - return Operator{ - EmbeddingAttrs{ - std::get(assignments.at(OperatorAttributeKey::NUM_ENTRIES)), - std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), - std::get(assignments.at(OperatorAttributeKey::AGGR)), - std::get( - assignments.at(OperatorAttributeKey::OP_TYPE))}, - std::nullopt}; - case OperatorType::FLAT: - return Operator{FlatAttrs{}, std::nullopt}; - case OperatorType::GATHER: - return Operator{GatherAttrs{std::get( - assignments.at(OperatorAttributeKey::DIM))}, - std::nullopt}; - case OperatorType::INPUT: - return Operator{InputAttrs{}, std::nullopt}; - case OperatorType::LAYERNORM: - return Operator{ - LayerNormAttrs{ - std::get>( - assignments.at(OperatorAttributeKey::AXES)), - std::get( - assignments.at(OperatorAttributeKey::ELEMENTWISE_AFFINE)), - std::get(assignments.at(OperatorAttributeKey::EPSILON))}, - std::nullopt}; - case OperatorType::LINEAR: - return Operator{ - LinearAttrs{ - std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), - std::get(assignments.at(OperatorAttributeKey::USE_BIAS)), - std::get( - assignments.at(OperatorAttributeKey::DATA_TYPE)), - std::get( - assignments.at(OperatorAttributeKey::ACTIVATION)), - std::get>( - assignments.at(OperatorAttributeKey::REGULARIZER))}, - std::nullopt}; - case OperatorType::MULTIHEAD_ATTENTION: - return Operator{ - MultiHeadAttentionAttrs{ - std::get(assignments.at(OperatorAttributeKey::EMBED_DIM)), - std::get(assignments.at(OperatorAttributeKey::NUM_HEADS)), - std::get(assignments.at(OperatorAttributeKey::NUM_HEADS)), - std::get(assignments.at(OperatorAttributeKey::VDIM)), - std::get(assignments.at(OperatorAttributeKey::DROPOUT)), - std::get(assignments.at(OperatorAttributeKey::BIAS)), - std::get(assignments.at(OperatorAttributeKey::ADD_BIAS_KV)), - std::get( - assignments.at(OperatorAttributeKey::ADD_ZERO_ATTN))}, - std::nullopt}; - case OperatorType::NOOP: - return Operator{NoopAttrs{}, std::nullopt}; - case OperatorType::POOL2D: - return Operator{ - Pool2DAttrs{ - std::get(assignments.at(OperatorAttributeKey::KERNEL_H)), - std::get(assignments.at(OperatorAttributeKey::KERNEL_W)), - std::get(assignments.at(OperatorAttributeKey::STRIDE_H)), - std::get(assignments.at(OperatorAttributeKey::STRIDE_W)), - std::get(assignments.at(OperatorAttributeKey::PADDING_H)), - std::get(assignments.at(OperatorAttributeKey::PADDING_W)), - std::get(assignments.at(OperatorAttributeKey::POOL_TYPE)), - std::get( - assignments.at(OperatorAttributeKey::ACTIVATION))}, - std::nullopt}; - case OperatorType::REDUCE_ARGMAX: - case OperatorType::REDUCE_ARGMIN: - case OperatorType::REDUCE_MAX: - case OperatorType::REDUCE_MEAN: - case OperatorType::REDUCE_MIN: - case OperatorType::REDUCE_PROD: - case OperatorType::REDUCE_SUM: - return Operator{ - ReduceAttrs{ - std::get>( - assignments.at(OperatorAttributeKey::AXES)), - op_type, - std::get(assignments.at(OperatorAttributeKey::KEEP_DIMS))}, - std::nullopt}; - case OperatorType::REVERSE: - return Operator{ReverseAttrs{std::get( - assignments.at(OperatorAttributeKey::AXIS))}, - std::nullopt}; - case OperatorType::RESHAPE: - return Operator{ReshapeAttrs{std::get( - assignments.at(OperatorAttributeKey::SHAPE))}, - std::nullopt}; - case OperatorType::SPLIT: - return Operator{ - SplitAttrs{ - std::get>( - assignments.at(OperatorAttributeKey::SPLITS)), - std::get(assignments.at(OperatorAttributeKey::AXIS))}, - std::nullopt}; - case OperatorType::SOFTMAX: - return Operator{SoftmaxAttrs{std::get( - assignments.at(OperatorAttributeKey::DIM))}, - std::nullopt}; - case OperatorType::TOPK: - return Operator{ - TopKAttrs{ - std::get(assignments.at(OperatorAttributeKey::K)), - std::get(assignments.at(OperatorAttributeKey::SORTED))}, - std::nullopt}; - case OperatorType::TRANSPOSE: - return Operator{ - TransposeAttrs{std::get>( - assignments.at(OperatorAttributeKey::PERMUTATION))}, - std::nullopt}; - case OperatorType::COMBINE: - return Operator{CombineAttrs{std::get(assignments.at( - OperatorAttributeKey::PARALLEL_DIM)), - std::get(assignments.at( - OperatorAttributeKey::PARALLEL_DEGREE))}, - std::nullopt}; - case OperatorType::REDUCTION: - return Operator{ - ReductionAttrs{std::get(assignments.at( - OperatorAttributeKey::PARALLEL_DIM)), - std::get(assignments.at( - OperatorAttributeKey::PARALLEL_DEGREE))}, - std::nullopt}; - case OperatorType::REPARTITION: - return Operator{ - RepartitionAttrs{std::get(assignments.at( - OperatorAttributeKey::PARALLEL_DIM)), - std::get(assignments.at( - OperatorAttributeKey::PARALLEL_DEGREE))}, - std::nullopt}; - case OperatorType::REPLICATE: - return Operator{ - ReplicateAttrs{std::get(assignments.at( - OperatorAttributeKey::PARALLEL_DIM)), - std::get(assignments.at( - OperatorAttributeKey::PARALLEL_DEGREE))}, - std::nullopt}; - default: - throw mk_runtime_error("Unknown Operator"); - } -} - -struct AddMappedEdgeFunctor { - bidict const &node_mapping; - SubParallelComputationGraph &new_pcg; - - template - void operator()(T const &t) { - return add_mapped_edge(t); - } - - void add_mapped_edge(InputMultiDiEdge const &e) { - new_pcg.add_edge(InputMultiDiEdge{ - node_mapping.at_l(e.dst), new_pcg.add_node_port(), e.uid}); - } - - void add_mapped_edge(OutputMultiDiEdge const &e) { - new_pcg.add_edge(OutputMultiDiEdge{ - node_mapping.at_l(e.src), new_pcg.add_node_port(), e.uid}); - } - - void add_mapped_edge(MultiDiEdge const &e) { - new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(e.dst), - new_pcg.add_node_port(), - node_mapping.at_l(e.src), - new_pcg.add_node_port()}); - } -}; - -struct AddNewEdgeFunctor { - SubParallelComputationGraph const &old_pcg; - SubParallelComputationGraph &new_pcg; - MultiDiGraphPatternMatch const &match; - bidict node_mapping; - - template - void operator()(TO const &old_edge, TN const &new_edge) { - return add_new_edge(old_edge, new_edge); - } - - void add_new_edge(InputMultiDiEdge const &old_edge, - InputMultiDiEdge const &new_edge) { - new_pcg.add_edge(InputMultiDiEdge{node_mapping.at_l(new_edge.dst), - new_pcg.add_node_port(), - old_edge.uid}); - } - - void add_new_edge(MultiDiEdge const &old_edge, - InputMultiDiEdge const &new_edge) { - new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(new_edge.dst), - new_pcg.add_node_port(), - node_mapping.at_l(old_edge.src), - new_pcg.add_node_port()}); - } - - void add_new_edge(OutputMultiDiEdge const &old_edge, - OutputMultiDiEdge const &new_edge) { - new_pcg.add_edge(OutputMultiDiEdge{node_mapping.at_l(new_edge.src), - new_pcg.add_node_port(), - old_edge.uid}); - } - - void add_new_edge(MultiDiEdge const &old_edge, - OutputMultiDiEdge const &new_edge) { - new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(old_edge.dst), - new_pcg.add_node_port(), - node_mapping.at_l(new_edge.src), - new_pcg.add_node_port()}); - } - - void add_new_edge(InputMultiDiEdge const &, OutputMultiDiEdge const &) { - assert(false); - } - - void add_new_edge(OpenMultiDiEdge const &, MultiDiEdge const &) { - assert(false); - } - - void add_new_edge(OutputMultiDiEdge const &, InputMultiDiEdge const &) { - assert(false); - } -}; - -SubParallelComputationGraph - apply_substitution(SubParallelComputationGraph const &pcg, - Substitution const &substitution, - MultiDiGraphPatternMatch const &match) { - SubParallelComputationGraph new_pcg = - OutputLabelledOpenMultiDiGraph::template create< - UnorderedOutputLabelledOpenMultiDiGraph>(); - bidict node_mapping; // Refactor it with global nodes - for (Node const &node : get_nodes(pcg)) { - if (!contains_r(match.node_assignment, node)) { - node_mapping.equate(node, new_pcg.add_node(pcg.at(node))); - } - } - for (OpenMultiDiEdge const &edge : get_edges(pcg)) { - if (!contains_r(match.edge_assignment, edge)) { - visit(AddMappedEdgeFunctor{node_mapping, new_pcg}, edge); - } - } - for (Node const &output_node : - get_nodes(substitution.output_graph_expr.value())) { - Operator new_op = get_operator_attrs( - pcg, match, substitution.output_graph_expr.value().at(output_node)); - Node new_node = new_pcg.add_node(new_op); - node_mapping.equate(output_node, new_node); - } - for (OpenMultiDiEdge const &output_edge : - get_edges(substitution.output_graph_expr.value())) { - if (std::holds_alternative(output_edge)) { - InputMultiDiEdge e = std::get(output_edge); - OpenMultiDiEdge original_edge = - match.edge_assignment.at_l(substitution.input_mapping.at_r(e)); - visit(AddNewEdgeFunctor{pcg, new_pcg, match, node_mapping}, - original_edge, - output_edge); - } else if (std::holds_alternative(output_edge)) { - OutputMultiDiEdge e = std::get(output_edge); - OpenMultiDiEdge original_edge = - match.edge_assignment.at_l(substitution.output_mapping.at_r(e)); - visit(AddNewEdgeFunctor{pcg, new_pcg, match, node_mapping}, - original_edge, - output_edge); - } else { - assert(std::holds_alternative(output_edge)); - MultiDiEdge e = std::get(output_edge); - new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(e.dst), - new_pcg.add_node_port(), - node_mapping.at_l(e.src), - new_pcg.add_node_port()}); - } - } - - return new_pcg; -} +/* struct DeriveValidOperatorAttributeExpr { */ +/* template */ +/* std::unordered_set> */ +/* operator()(T const &t) { */ +/* return derive_valid_operator_attribute_expr(t); */ +/* } */ + +/* std::unordered_set> */ +/* derive_valid_operator_attribute_expr(OperatorAttributeKey const &key) { */ +/* return {key}; */ +/* } */ + +/* std::unordered_set> */ +/* derive_valid_operator_attribute_expr( */ +/* ListIndexAccess const &access) { */ +/* return {access, access.attribute_key}; */ +/* } */ + +/* std::unordered_set> */ +/* derive_valid_operator_attribute_expr( */ +/* ListSize const &ls) { */ +/* return {ls, ls.attribute_key}; */ +/* } */ +/* }; */ + +/* std::unordered_set> */ +/* get_valid_operator_attribute_exprs(OperatorPattern const &pattern) { */ +/* return set_union(transform( */ +/* pattern.attribute_constraints, [](OperatorAttributeConstraint const &t) { */ +/* return visit(DeriveValidOperatorAttributeExpr{}, t.attribute_expr); */ +/* })); */ +/* } */ + +/* bool is_valid_operator_attribute_expr( */ +/* OperatorPattern const &pattern, */ +/* AttributeExpr const &expr) { */ +/* return contains(get_valid_operator_attribute_exprs(pattern), expr); */ +/* } */ + +/* struct IsValidOperatorAttributeExprFunctor { */ +/* GraphPattern const &graph_pattern; */ + +/* template */ +/* bool operator()(T const &t) const { */ +/* return is_valid(t); */ +/* } */ + +/* bool is_valid(OperatorAttrAccess const &t) const { */ +/* return is_valid_operator_attribute_expr(graph_pattern.value().at(t.node), */ +/* t.attr_expr); */ +/* } */ + +/* bool is_valid(AttrConstant const &t) const { */ +/* return true; */ +/* } */ +/* }; */ + +/* bool is_valid_operator_attribute_expr(GraphPattern const &pattern, */ +/* OperatorAttributeExpr const &expr) { */ +/* return visit(IsValidOperatorAttributeExprFunctor{pattern}, expr); */ +/* } */ + +/* bool is_valid_substitution(Substitution const &s) { */ +/* for (Node const &node : get_nodes(s.output_graph_expr.value())) { */ +/* for (OperatorAttributeExpr expr : */ +/* values(s.output_graph_expr.value().at(node).assignments)) { */ +/* if (!is_valid_operator_attribute_expr(s.input_graph, expr)) { */ +/* return false; */ +/* } */ +/* } */ +/* } */ +/* return true; */ +/* } */ + +/* struct EvaluateOperatorAttributeExpr { */ +/* SubParallelComputationGraph const &graph; */ +/* MultiDiGraphPatternMatch const &match; */ + +/* template */ +/* OperatorAttributeValue operator()(T const &t) { */ +/* return evaluate(t); */ +/* } */ + +/* OperatorAttributeValue evaluate(OperatorAttrAccess const &t) { */ +/* Node node_in_pattern = t.node; */ +/* Node node_in_pcg = match.node_assignment.at_l(node_in_pattern); */ +/* return evaluate_attribute_expr(graph.at(node_in_pcg), t.attr_expr).value(); */ +/* } */ + +/* OperatorAttributeValue evaluate(AttrConstant const &t) { */ +/* return t.value; */ +/* } */ +/* }; */ + +/* OperatorAttributeValue */ +/* evaluate_graph_attribute_expr(SubParallelComputationGraph const &g, */ +/* MultiDiGraphPatternMatch const &match, */ +/* OperatorAttributeExpr const &expr) { */ +/* return visit(EvaluateOperatorAttributeExpr{g, match}, expr); */ +/* } */ + +/* Operator get_operator_attrs(SubParallelComputationGraph const &graph, */ +/* MultiDiGraphPatternMatch const &match, */ +/* OperatorAttrAssignment const &assignment) { */ +/* std::unordered_map assignments; */ +/* for (auto const &[key, expr] : assignment.assignments) { */ +/* OperatorAttributeValue value = */ +/* evaluate_graph_attribute_expr(graph, match, expr); */ +/* assignments.emplace(key, value); */ +/* } */ +/* assert(contains_key(assignments, OperatorAttributeKey::OP_TYPE)); */ +/* assert(std::holds_alternative( */ +/* assignments.at(OperatorAttributeKey::OP_TYPE))); */ +/* OperatorType op_type = */ +/* std::get(assignments.at(OperatorAttributeKey::OP_TYPE)); */ +/* switch (op_type) { */ +/* case OperatorType::BATCHMATMUL: */ +/* return Operator{ */ +/* BatchMatmulAttrs{std::get(assignments.at( */ +/* OperatorAttributeKey::A_SEQ_LENGTH_DIM)), */ +/* std::get(assignments.at( */ +/* OperatorAttributeKey::B_SEQ_LENGTH_DIM))}, */ +/* std::nullopt}; */ +/* case OperatorType::BATCHNORM: */ +/* return Operator{BatchNormAttrs{std::get( */ +/* assignments.at(OperatorAttributeKey::RELU))}, */ +/* std::nullopt}; */ +/* case OperatorType::CAST: */ +/* return Operator{CastAttrs{std::get( */ +/* assignments.at(OperatorAttributeKey::DATA_TYPE))}, */ +/* std::nullopt}; */ +/* case OperatorType::CONCAT: */ +/* return Operator{ */ +/* ConcatAttrs{ */ +/* std::get(assignments.at(OperatorAttributeKey::AXIS)), */ +/* std::get(assignments.at(OperatorAttributeKey::NUM_INPUTS))}, */ +/* std::nullopt}; */ +/* case OperatorType::CONV2D: */ +/* return Operator{ */ +/* Conv2DAttrs{ */ +/* std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), */ +/* std::get(assignments.at(OperatorAttributeKey::KERNEL_H)), */ +/* std::get(assignments.at(OperatorAttributeKey::KERNEL_W)), */ +/* std::get(assignments.at(OperatorAttributeKey::STRIDE_H)), */ +/* std::get(assignments.at(OperatorAttributeKey::STRIDE_W)), */ +/* std::get(assignments.at(OperatorAttributeKey::PADDING_H)), */ +/* std::get(assignments.at(OperatorAttributeKey::PADDING_W)), */ +/* std::get(assignments.at(OperatorAttributeKey::GROUPS)), */ +/* std::get( */ +/* assignments.at(OperatorAttributeKey::ACTIVATION)), */ +/* std::get(assignments.at(OperatorAttributeKey::USE_BIAS))}, */ +/* std::nullopt}; */ +/* case OperatorType::DROPOUT: */ +/* return Operator{DropoutAttrs{std::get(assignments.at( */ +/* OperatorAttributeKey::RATE)), */ +/* std::get(assignments.at( */ +/* OperatorAttributeKey::SEED))}, */ +/* std::nullopt}; */ +/* case OperatorType::EW_ADD: */ +/* case OperatorType::EW_DIV: */ +/* case OperatorType::EW_EQUAL: */ +/* case OperatorType::EW_GREATER: */ +/* case OperatorType::EW_LESS: */ +/* case OperatorType::EW_MAX: */ +/* case OperatorType::EW_MIN: */ +/* case OperatorType::EW_MUL: */ +/* case OperatorType::EW_SUB: */ +/* return Operator{ */ +/* ElementBinaryAttrs{op_type, */ +/* std::get(assignments.at( */ +/* OperatorAttributeKey::DATA_TYPE)), */ +/* std::get(assignments.at( */ +/* OperatorAttributeKey::SHOULD_BROADCAST_LHS)), */ +/* std::get(assignments.at( */ +/* OperatorAttributeKey::SHOULD_BROADCAST_RHS))}, */ +/* std::nullopt}; */ +/* case OperatorType::SCALAR_ADD: */ +/* case OperatorType::SCALAR_FLOOR_DIV: */ +/* case OperatorType::SCALAR_MULTIPLY: */ +/* case OperatorType::SCALAR_SUB: */ +/* case OperatorType::SCALAR_TRUE_DIV: */ +/* return Operator{ */ +/* ElementScalarUnaryAttrs{ */ +/* op_type, */ +/* std::get(assignments.at(OperatorAttributeKey::SCALAR))}, */ +/* std::nullopt}; */ +/* case OperatorType::EXP: */ +/* case OperatorType::IDENTITY: */ +/* case OperatorType::GELU: */ +/* case OperatorType::RSQRT: */ +/* case OperatorType::POW: */ +/* case OperatorType::SIN: */ +/* case OperatorType::COS: */ +/* return Operator{ElementUnaryAttrs{op_type}, std::nullopt}; */ +/* case OperatorType::EMBEDDING: */ +/* return Operator{ */ +/* EmbeddingAttrs{ */ +/* std::get(assignments.at(OperatorAttributeKey::NUM_ENTRIES)), */ +/* std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), */ +/* std::get(assignments.at(OperatorAttributeKey::AGGR)), */ +/* std::get( */ +/* assignments.at(OperatorAttributeKey::OP_TYPE))}, */ +/* std::nullopt}; */ +/* case OperatorType::FLAT: */ +/* return Operator{FlatAttrs{}, std::nullopt}; */ +/* case OperatorType::GATHER: */ +/* return Operator{GatherAttrs{std::get( */ +/* assignments.at(OperatorAttributeKey::DIM))}, */ +/* std::nullopt}; */ +/* case OperatorType::INPUT: */ +/* return Operator{InputAttrs{}, std::nullopt}; */ +/* case OperatorType::LAYERNORM: */ +/* return Operator{ */ +/* LayerNormAttrs{ */ +/* std::get>( */ +/* assignments.at(OperatorAttributeKey::AXES)), */ +/* std::get( */ +/* assignments.at(OperatorAttributeKey::ELEMENTWISE_AFFINE)), */ +/* std::get(assignments.at(OperatorAttributeKey::EPSILON))}, */ +/* std::nullopt}; */ +/* case OperatorType::LINEAR: */ +/* return Operator{ */ +/* LinearAttrs{ */ +/* std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), */ +/* std::get(assignments.at(OperatorAttributeKey::USE_BIAS)), */ +/* std::get( */ +/* assignments.at(OperatorAttributeKey::DATA_TYPE)), */ +/* std::get( */ +/* assignments.at(OperatorAttributeKey::ACTIVATION)), */ +/* std::get>( */ +/* assignments.at(OperatorAttributeKey::REGULARIZER))}, */ +/* std::nullopt}; */ +/* case OperatorType::MULTIHEAD_ATTENTION: */ +/* return Operator{ */ +/* MultiHeadAttentionAttrs{ */ +/* std::get(assignments.at(OperatorAttributeKey::EMBED_DIM)), */ +/* std::get(assignments.at(OperatorAttributeKey::NUM_HEADS)), */ +/* std::get(assignments.at(OperatorAttributeKey::NUM_HEADS)), */ +/* std::get(assignments.at(OperatorAttributeKey::VDIM)), */ +/* std::get(assignments.at(OperatorAttributeKey::DROPOUT)), */ +/* std::get(assignments.at(OperatorAttributeKey::BIAS)), */ +/* std::get(assignments.at(OperatorAttributeKey::ADD_BIAS_KV)), */ +/* std::get( */ +/* assignments.at(OperatorAttributeKey::ADD_ZERO_ATTN))}, */ +/* std::nullopt}; */ +/* case OperatorType::NOOP: */ +/* return Operator{NoopAttrs{}, std::nullopt}; */ +/* case OperatorType::POOL2D: */ +/* return Operator{ */ +/* Pool2DAttrs{ */ +/* std::get(assignments.at(OperatorAttributeKey::KERNEL_H)), */ +/* std::get(assignments.at(OperatorAttributeKey::KERNEL_W)), */ +/* std::get(assignments.at(OperatorAttributeKey::STRIDE_H)), */ +/* std::get(assignments.at(OperatorAttributeKey::STRIDE_W)), */ +/* std::get(assignments.at(OperatorAttributeKey::PADDING_H)), */ +/* std::get(assignments.at(OperatorAttributeKey::PADDING_W)), */ +/* std::get(assignments.at(OperatorAttributeKey::POOL_TYPE)), */ +/* std::get( */ +/* assignments.at(OperatorAttributeKey::ACTIVATION))}, */ +/* std::nullopt}; */ +/* case OperatorType::REDUCE_ARGMAX: */ +/* case OperatorType::REDUCE_ARGMIN: */ +/* case OperatorType::REDUCE_MAX: */ +/* case OperatorType::REDUCE_MEAN: */ +/* case OperatorType::REDUCE_MIN: */ +/* case OperatorType::REDUCE_PROD: */ +/* case OperatorType::REDUCE_SUM: */ +/* return Operator{ */ +/* ReduceAttrs{ */ +/* std::get>( */ +/* assignments.at(OperatorAttributeKey::AXES)), */ +/* op_type, */ +/* std::get(assignments.at(OperatorAttributeKey::KEEP_DIMS))}, */ +/* std::nullopt}; */ +/* case OperatorType::REVERSE: */ +/* return Operator{ReverseAttrs{std::get( */ +/* assignments.at(OperatorAttributeKey::AXIS))}, */ +/* std::nullopt}; */ +/* case OperatorType::RESHAPE: */ +/* return Operator{ReshapeAttrs{std::get( */ +/* assignments.at(OperatorAttributeKey::SHAPE))}, */ +/* std::nullopt}; */ +/* case OperatorType::SPLIT: */ +/* return Operator{ */ +/* SplitAttrs{ */ +/* std::get>( */ +/* assignments.at(OperatorAttributeKey::SPLITS)), */ +/* std::get(assignments.at(OperatorAttributeKey::AXIS))}, */ +/* std::nullopt}; */ +/* case OperatorType::SOFTMAX: */ +/* return Operator{SoftmaxAttrs{std::get( */ +/* assignments.at(OperatorAttributeKey::DIM))}, */ +/* std::nullopt}; */ +/* case OperatorType::TOPK: */ +/* return Operator{ */ +/* TopKAttrs{ */ +/* std::get(assignments.at(OperatorAttributeKey::K)), */ +/* std::get(assignments.at(OperatorAttributeKey::SORTED))}, */ +/* std::nullopt}; */ +/* case OperatorType::TRANSPOSE: */ +/* return Operator{ */ +/* TransposeAttrs{std::get>( */ +/* assignments.at(OperatorAttributeKey::PERMUTATION))}, */ +/* std::nullopt}; */ +/* case OperatorType::COMBINE: */ +/* return Operator{CombineAttrs{std::get(assignments.at( */ +/* OperatorAttributeKey::PARALLEL_DIM)), */ +/* std::get(assignments.at( */ +/* OperatorAttributeKey::PARALLEL_DEGREE))}, */ +/* std::nullopt}; */ +/* case OperatorType::REDUCTION: */ +/* return Operator{ */ +/* ReductionAttrs{std::get(assignments.at( */ +/* OperatorAttributeKey::PARALLEL_DIM)), */ +/* std::get(assignments.at( */ +/* OperatorAttributeKey::PARALLEL_DEGREE))}, */ +/* std::nullopt}; */ +/* case OperatorType::REPARTITION: */ +/* return Operator{ */ +/* RepartitionAttrs{std::get(assignments.at( */ +/* OperatorAttributeKey::PARALLEL_DIM)), */ +/* std::get(assignments.at( */ +/* OperatorAttributeKey::PARALLEL_DEGREE))}, */ +/* std::nullopt}; */ +/* case OperatorType::REPLICATE: */ +/* return Operator{ */ +/* ReplicateAttrs{std::get(assignments.at( */ +/* OperatorAttributeKey::PARALLEL_DIM)), */ +/* std::get(assignments.at( */ +/* OperatorAttributeKey::PARALLEL_DEGREE))}, */ +/* std::nullopt}; */ +/* default: */ +/* throw mk_runtime_error("Unknown Operator"); */ +/* } */ +/* } */ } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/constraint_type.dtg.cc b/lib/substitutions/src/substitutions/constraint_type.dtg.cc new file mode 100644 index 0000000000..aa5c30dbe9 --- /dev/null +++ b/lib/substitutions/src/substitutions/constraint_type.dtg.cc @@ -0,0 +1,64 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/constraint_type.enum.toml +/* proj-data +{ + "generated_from": "06b029d76658cb434abf08b1fdb86137" +} +*/ + +#include "substitutions/constraint_type.dtg.h" + +#include +#include + +namespace std { +size_t hash::operator()( + FlexFlow::ConstraintType x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(ConstraintType x) { + switch (x) { + case ConstraintType::EQUAL: + return "EQUAL"; + default: + std::ostringstream oss; + oss << "Unknown ConstraintType value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, ConstraintType x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, ConstraintType x) { + switch (x) { + case ConstraintType::EQUAL: + j = "EQUAL"; + break; + default: + std::ostringstream oss; + oss << "Unknown ConstraintType value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, ConstraintType &x) { + std::string as_str = j.get(); + if (as_str == "EQUAL") { + x = ConstraintType::EQUAL; + } else { + std::ostringstream oss; + oss << "Unknown ConstraintType value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::element( + FlexFlow::ConstraintType::EQUAL); +} +} // namespace rc diff --git a/lib/substitutions/src/substitutions/graph_pattern.cc b/lib/substitutions/src/substitutions/graph_pattern.cc new file mode 100644 index 0000000000..ac032d37ce --- /dev/null +++ b/lib/substitutions/src/substitutions/graph_pattern.cc @@ -0,0 +1,44 @@ +#include "substitutions/graph_pattern.h" +#include "substitutions/sub_parallel_computation_graph.h" +#include "substitutions/operator_pattern/satisfies_pattern.h" +#include "substitutions/tensor_pattern/satisfies_pattern.h" + +namespace FlexFlow { + +UnlabelledGraphPattern get_unlabelled_pattern(PCGPattern const &p) { + return UnlabelledGraphPattern{p.raw_graph}; +} + +TensorAttributePattern get_tensor_pattern(PCGPattern const &p, PatternEdge const &e) { + return p.raw_graph.at(e.raw_edge); +} + +OperatorAttributePattern get_operator_pattern(PCGPattern const &p, PatternNode const &n) { + return p.raw_graph.at(n.raw_node); +} + +bool assignment_satisfies(SubParallelComputationGraph const &pcg, + PCGPattern const &pattern, + MultiDiGraphPatternMatch const &patternMatch) { + return unlabelled_pattern_does_match( + get_unlabelled_pattern(pattern), + pcg.raw_graph, + patternMatch, + MatchAdditionalCriterion{ + [&](PatternNode const &patternNode, Node const &pcgNode) { + return operator_satisfies_pattern( + get_operator_attrs(pcg, pcgNode), + get_operator_pattern(pattern, patternNode) + ); + }, + [&](PatternEdge const &patternEdge, OpenMultiDiEdge const &pcgEdge) { + return parallel_tensor_satisfies_pattern( + get_parallel_tensor_attrs(pcg, pcgEdge), + get_tensor_pattern(pattern, patternEdge) + ); + } + } + ); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/operator_pattern/eval_list_access.cc b/lib/substitutions/src/substitutions/operator_pattern/eval_list_access.cc new file mode 100644 index 0000000000..f90ddc2e20 --- /dev/null +++ b/lib/substitutions/src/substitutions/operator_pattern/eval_list_access.cc @@ -0,0 +1,39 @@ +#include "substitutions/operator_pattern/eval_list_access.h" +#include "substitutions/operator_pattern/get_attribute.h" +#include "utils/overload.h" + +namespace FlexFlow { + +std::optional eval_list_access(PCGOperatorAttrs const &attrs, OperatorAttributeListIndexAccess const &acc) { + std::optional from_attr = get_attribute(attrs, acc.attribute_key); + + if (!from_attr.has_value()) { + return std::nullopt; + } + + return from_attr.value().visit< + std::optional + >([&](auto const &v) -> std::optional { + using T = std::decay_t; + + if constexpr (std::is_same_v>) { + if (acc.index >= v.size()) { + return std::nullopt; + } else { + int value = v.at(acc.index); + return OperatorAttributeValue{value}; + } + } else if constexpr (std::is_same_v>) { + if (acc.index >= v.size()) { + return std::nullopt; + } else { + ff_dim_t value = v.at(acc.index); + return OperatorAttributeValue{value}; + } + } else { + throw mk_runtime_error("Invalid operand"); + } + }); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/operator_pattern/eval_list_size.cc b/lib/substitutions/src/substitutions/operator_pattern/eval_list_size.cc new file mode 100644 index 0000000000..2c2fd4850d --- /dev/null +++ b/lib/substitutions/src/substitutions/operator_pattern/eval_list_size.cc @@ -0,0 +1,28 @@ +#include "substitutions/operator_pattern/eval_list_size.h" +#include "substitutions/operator_pattern/get_attribute.h" +#include "utils/overload.h" + +namespace FlexFlow { + +std::optional eval_list_size(PCGOperatorAttrs const &attrs, OperatorAttributeListSize const &acc) { + std::optional from_attr = get_attribute(attrs, acc.attribute_key); + + if (!from_attr.has_value()) { + return std::nullopt; + } + + return from_attr.value().visit< + std::optional + >([&](auto const &v) -> std::optional { + using T = std::decay_t; + + if constexpr (std::is_same_v> || std::is_same_v>) { + size_t size = v.size(); + return OperatorAttributeValue{size}; + } else { + throw mk_runtime_error("Invalid operand"); + } + }); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/operator_attributes.cc b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc similarity index 73% rename from lib/substitutions/src/operator_attributes.cc rename to lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc index 8bd8688194..7932a2d26e 100644 --- a/lib/substitutions/src/operator_attributes.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc @@ -1,11 +1,25 @@ +#include "substitutions/operator_pattern/get_attribute.h" #include "op-attrs/get_op_type.h" -#include "substitutions/get_attribute.h" +#include "utils/containers.h" namespace FlexFlow { std::optional get_attribute(BatchMatmulAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); + default: + return std::nullopt; + } +} + +std::optional get_attribute(BatchNormAttrs const &p, OperatorAttributeKey key) { + switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); + case OperatorAttributeKey::RELU: + return p.relu; default: return std::nullopt; } @@ -14,6 +28,8 @@ std::optional get_attribute(BatchMatmulAttrs const &p, std::optional get_attribute(CastAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); case OperatorAttributeKey::DATA_TYPE: return p.dtype; default: @@ -24,6 +40,8 @@ std::optional get_attribute(CastAttrs const &p, std::optional get_attribute(CombineAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); case OperatorAttributeKey::PARALLEL_OP_DIM: return p.combine_dim; case OperatorAttributeKey::PARALLEL_DIM: @@ -36,6 +54,8 @@ std::optional get_attribute(CombineAttrs const &p, std::optional get_attribute(ConcatAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); case OperatorAttributeKey::AXIS: return p.axis; default: @@ -46,6 +66,8 @@ std::optional get_attribute(ConcatAttrs const &p, std::optional get_attribute(Conv2DAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); case OperatorAttributeKey::KERNEL_H: return p.kernel_h; case OperatorAttributeKey::KERNEL_W: @@ -72,6 +94,8 @@ std::optional get_attribute(Conv2DAttrs const &p, std::optional get_attribute(ElementBinaryAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); default: return std::nullopt; } @@ -80,6 +104,8 @@ std::optional get_attribute(ElementBinaryAttrs const &p, std::optional get_attribute(ElementUnaryAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); default: return std::nullopt; } @@ -88,6 +114,8 @@ std::optional get_attribute(ElementUnaryAttrs const &p, std::optional get_attribute(ElementScalarUnaryAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); default: return std::nullopt; } @@ -96,6 +124,8 @@ std::optional std::optional get_attribute(DropoutAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); default: return std::nullopt; } @@ -104,6 +134,8 @@ std::optional get_attribute(DropoutAttrs const &p, std::optional get_attribute(EmbeddingAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); case OperatorAttributeKey::DATA_TYPE: return p.data_type; case OperatorAttributeKey::AGGR: @@ -120,6 +152,8 @@ std::optional get_attribute(EmbeddingAttrs const &p, std::optional get_attribute(FlatAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); default: return std::nullopt; } @@ -128,6 +162,8 @@ std::optional get_attribute(FlatAttrs const &p, std::optional get_attribute(GatherAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); case OperatorAttributeKey::AXIS: return p.dim; default: @@ -135,9 +171,21 @@ std::optional get_attribute(GatherAttrs const &p, } } +std::optional get_attribute(InputAttrs const &p, + OperatorAttributeKey key) { + switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); + default: + return std::nullopt; + } +} + std::optional get_attribute(LayerNormAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); default: return std::nullopt; } @@ -146,6 +194,8 @@ std::optional get_attribute(LayerNormAttrs const &p, std::optional get_attribute(LinearAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); case OperatorAttributeKey::OUT_CHANNELS: return p.out_channels; case OperatorAttributeKey::USE_BIAS: @@ -166,6 +216,8 @@ std::optional get_attribute(LinearAttrs const &p, std::optional get_attribute(MultiHeadAttentionAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); case OperatorAttributeKey::NUM_HEADS: return p.num_heads; case OperatorAttributeKey::USE_BIAS: @@ -175,9 +227,22 @@ std::optional } } +std::optional + get_attribute(NoopAttrs const &p, OperatorAttributeKey key) { + switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); + default: + return std::nullopt; + } +} + + std::optional get_attribute(Pool2DAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); case OperatorAttributeKey::KERNEL_H: return p.kernel_h; case OperatorAttributeKey::KERNEL_W: @@ -202,6 +267,8 @@ std::optional get_attribute(Pool2DAttrs const &p, std::optional get_attribute(ReduceAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); default: return std::nullopt; } @@ -210,6 +277,8 @@ std::optional get_attribute(ReduceAttrs const &p, std::optional get_attribute(ReductionAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); case OperatorAttributeKey::PARALLEL_OP_DIM: return p.reduction_dim; case OperatorAttributeKey::PARALLEL_OP_DEGREE: @@ -222,6 +291,8 @@ std::optional get_attribute(ReductionAttrs const &p, std::optional get_attribute(RepartitionAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); case OperatorAttributeKey::PARALLEL_OP_DIM: return p.repartition_dim; case OperatorAttributeKey::PARALLEL_OP_DEGREE: @@ -234,6 +305,8 @@ std::optional get_attribute(RepartitionAttrs const &p, std::optional get_attribute(ReplicateAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); case OperatorAttributeKey::PARALLEL_OP_DIM: return p.replicate_dim; case OperatorAttributeKey::PARALLEL_OP_DEGREE: @@ -246,14 +319,31 @@ std::optional get_attribute(ReplicateAttrs const &p, std::optional get_attribute(ReshapeAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); default: return std::nullopt; } } +std::optional get_attribute(ReverseAttrs const &p, + OperatorAttributeKey key) { + switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); + case OperatorAttributeKey::AXIS: + return p.axis; + default: + return std::nullopt; + } +} + + std::optional get_attribute(SplitAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); case OperatorAttributeKey::AXIS: return p.axis; default: @@ -264,6 +354,8 @@ std::optional get_attribute(SplitAttrs const &p, std::optional get_attribute(SoftmaxAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); case OperatorAttributeKey::AXIS: return p.dim; default: @@ -274,6 +366,8 @@ std::optional get_attribute(SoftmaxAttrs const &p, std::optional get_attribute(TopKAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); default: return std::nullopt; } @@ -282,38 +376,22 @@ std::optional get_attribute(TopKAttrs const &p, std::optional get_attribute(TransposeAttrs const &p, OperatorAttributeKey key) { switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); case OperatorAttributeKey::PERMUTATION: - return p.perm; + return as_vector(p.perm); default: return std::nullopt; } } -struct GetAttribute { - GetAttribute(OperatorAttributeKey key) : key(key) {} - - template - std::optional operator()(T const &t) { - return get_attribute(t, this->key); - } - -private: - OperatorAttributeKey key; -}; - -struct GetOpType { - template - std::optional operator()(T const &t) { - return get_op_type(t); - } -}; - std::optional get_attribute(PCGOperatorAttrs const &p, OperatorAttributeKey key) { - if (key == OperatorAttributeKey::OP_TYPE) { - return std::visit(GetOpType{}, p); - } - return std::visit(GetAttribute(key), p); + return p.visit< + std::optional + >([&](auto const &attrs) { + return get_attribute(attrs, key); + }); } } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.dtg.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.dtg.cc new file mode 100644 index 0000000000..bc913b7c1a --- /dev/null +++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.dtg.cc @@ -0,0 +1,121 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.struct.toml +/* proj-data +{ + "generated_from": "7867bd0f403866c13417171bb5ec364c" +} +*/ + +#include "substitutions/operator_pattern/operator_attribute_constraint.dtg.h" + +#include "substitutions/constraint_type.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_expr.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_value.dtg.h" +#include + +namespace FlexFlow { +OperatorAttributeConstraint::OperatorAttributeConstraint( + ::FlexFlow::ConstraintType const &constraint_type, + ::FlexFlow::OperatorAttributeExpr const &attribute_expr, + ::FlexFlow::OperatorAttributeValue const &attribute_value) + : constraint_type(constraint_type), attribute_expr(attribute_expr), + attribute_value(attribute_value) {} +bool OperatorAttributeConstraint::operator==( + OperatorAttributeConstraint const &other) const { + return std::tie(this->constraint_type, + this->attribute_expr, + this->attribute_value) == std::tie(other.constraint_type, + other.attribute_expr, + other.attribute_value); +} +bool OperatorAttributeConstraint::operator!=( + OperatorAttributeConstraint const &other) const { + return std::tie(this->constraint_type, + this->attribute_expr, + this->attribute_value) != std::tie(other.constraint_type, + other.attribute_expr, + other.attribute_value); +} +bool OperatorAttributeConstraint::operator<( + OperatorAttributeConstraint const &other) const { + return std::tie(this->constraint_type, + this->attribute_expr, + this->attribute_value) < std::tie(other.constraint_type, + other.attribute_expr, + other.attribute_value); +} +bool OperatorAttributeConstraint::operator>( + OperatorAttributeConstraint const &other) const { + return std::tie(this->constraint_type, + this->attribute_expr, + this->attribute_value) > std::tie(other.constraint_type, + other.attribute_expr, + other.attribute_value); +} +bool OperatorAttributeConstraint::operator<=( + OperatorAttributeConstraint const &other) const { + return std::tie(this->constraint_type, + this->attribute_expr, + this->attribute_value) <= std::tie(other.constraint_type, + other.attribute_expr, + other.attribute_value); +} +bool OperatorAttributeConstraint::operator>=( + OperatorAttributeConstraint const &other) const { + return std::tie(this->constraint_type, + this->attribute_expr, + this->attribute_value) >= std::tie(other.constraint_type, + other.attribute_expr, + other.attribute_value); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::OperatorAttributeConstraint const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ConstraintType>{}(x.constraint_type) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::OperatorAttributeExpr>{}(x.attribute_expr) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::OperatorAttributeValue>{}(x.attribute_value) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::OperatorAttributeConstraint + adl_serializer::from_json( + json const &j) { + return { + j.at("constraint_type").template get<::FlexFlow::ConstraintType>(), + j.at("attribute_expr").template get<::FlexFlow::OperatorAttributeExpr>(), + j.at("attribute_value") + .template get<::FlexFlow::OperatorAttributeValue>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::OperatorAttributeConstraint const &v) { + j["__type"] = "OperatorAttributeConstraint"; + j["constraint_type"] = v.constraint_type; + j["attribute_expr"] = v.attribute_expr; + j["attribute_value"] = v.attribute_value; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(OperatorAttributeConstraint const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + OperatorAttributeConstraint const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_expr.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_expr.cc new file mode 100644 index 0000000000..71f03bd364 --- /dev/null +++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_expr.cc @@ -0,0 +1,21 @@ +#include "substitutions/operator_pattern/operator_attribute_expr.h" +#include "substitutions/operator_pattern/get_attribute.h" +#include "substitutions/operator_pattern/eval_list_access.h" +#include "substitutions/operator_pattern/eval_list_size.h" +#include "utils/overload.h" + +namespace FlexFlow { + +std::optional + evaluate_attribute_expr(PCGOperatorAttrs const &attrs, + OperatorAttributeExpr const &expr) { + return expr.visit< + std::optional + >(overload { + [&](OperatorAttributeKey const &k) { return get_attribute(attrs, k); }, + [&](OperatorAttributeListSize const &k) { return eval_list_size(attrs, k); }, + [&](OperatorAttributeListIndexAccess const &k) { return eval_list_access(attrs, k); }, + }); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_expr.dtg.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_expr.dtg.cc new file mode 100644 index 0000000000..60c77d8d0f --- /dev/null +++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_expr.dtg.cc @@ -0,0 +1,137 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.variant.toml +/* proj-data +{ + "generated_from": "15d26dd1f08092ecc82b725aa9411597" +} +*/ + +#include "substitutions/operator_pattern/operator_attribute_expr.dtg.h" + +#include "fmt/format.h" +#include +#include + +namespace FlexFlow { +OperatorAttributeExpr::OperatorAttributeExpr( + ::FlexFlow::OperatorAttributeKey const &v) + : raw_variant(v) {} +OperatorAttributeExpr::OperatorAttributeExpr( + ::FlexFlow::OperatorAttributeListSize const &v) + : raw_variant(v) {} +OperatorAttributeExpr::OperatorAttributeExpr( + ::FlexFlow::OperatorAttributeListIndexAccess const &v) + : raw_variant(v) {} +bool OperatorAttributeExpr::operator==( + OperatorAttributeExpr const &other) const { + return this->raw_variant == other.raw_variant; +} +bool OperatorAttributeExpr::operator!=( + OperatorAttributeExpr const &other) const { + return this->raw_variant != other.raw_variant; +} +bool OperatorAttributeExpr::operator<( + OperatorAttributeExpr const &other) const { + return this->raw_variant < other.raw_variant; +} +bool OperatorAttributeExpr::operator>( + OperatorAttributeExpr const &other) const { + return this->raw_variant > other.raw_variant; +} +bool OperatorAttributeExpr::operator<=( + OperatorAttributeExpr const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool OperatorAttributeExpr::operator>=( + OperatorAttributeExpr const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::OperatorAttributeExpr>::operator()( + ::FlexFlow::OperatorAttributeExpr const &x) const { + return std::hash< + std::variant<::FlexFlow::OperatorAttributeKey, + ::FlexFlow::OperatorAttributeListSize, + ::FlexFlow::OperatorAttributeListIndexAccess>>{}( + x.raw_variant); +} +} // namespace std +namespace nlohmann { +::FlexFlow::OperatorAttributeExpr + adl_serializer<::FlexFlow::OperatorAttributeExpr>::from_json( + json const &j) { + std::string key = j.at("type").template get(); + if (key == "key") { + return ::FlexFlow::OperatorAttributeExpr{ + j.at("value").template get<::FlexFlow::OperatorAttributeKey>()}; + } else if (key == "list_size") { + return ::FlexFlow::OperatorAttributeExpr{ + j.at("value").template get<::FlexFlow::OperatorAttributeListSize>()}; + } else if (key == "list_idx") { + return ::FlexFlow::OperatorAttributeExpr{ + j.at("value") + .template get<::FlexFlow::OperatorAttributeListIndexAccess>()}; + } else { + throw std::runtime_error(fmt::format("Unknown type key {}", key)); + } +} +void adl_serializer<::FlexFlow::OperatorAttributeExpr>::to_json( + json &j, ::FlexFlow::OperatorAttributeExpr const &x) { + j["__type"] = "OperatorAttributeExpr"; + switch (x.index()) { + case 0: { + j["type"] = "key"; + j["value"] = x.get<::FlexFlow::OperatorAttributeKey>(); + break; + } + case 1: { + j["type"] = "list_size"; + j["value"] = x.get<::FlexFlow::OperatorAttributeListSize>(); + break; + } + case 2: { + j["type"] = "list_idx"; + j["value"] = x.get<::FlexFlow::OperatorAttributeListIndexAccess>(); + break; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type OperatorAttributeExpr", x.index())); + } + } +} +} // namespace nlohmann +namespace FlexFlow { +std::string format_as(::FlexFlow::OperatorAttributeExpr const &x) { + std::ostringstream oss; + switch (x.index()) { + case 0: { + oss << ""; + break; + } + case 1: { + oss << ""; + break; + } + case 2: { + oss << ""; + break; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type OperatorAttributeExpr", x.index())); + break; + } + } + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::OperatorAttributeExpr const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_key.dtg.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_key.dtg.cc new file mode 100644 index 0000000000..a24e1c12e4 --- /dev/null +++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_key.dtg.cc @@ -0,0 +1,505 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml +/* proj-data +{ + "generated_from": "e637388397720b328b1f4b9ba6b14611" +} +*/ + +#include "substitutions/operator_pattern/operator_attribute_key.dtg.h" + +#include +#include + +namespace std { +size_t hash::operator()( + FlexFlow::OperatorAttributeKey x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(OperatorAttributeKey x) { + switch (x) { + case OperatorAttributeKey::OP_TYPE: + return "OP_TYPE"; + case OperatorAttributeKey::USE_BIAS: + return "USE_BIAS"; + case OperatorAttributeKey::GROUPS: + return "GROUPS"; + case OperatorAttributeKey::POOL_TYPE: + return "POOL_TYPE"; + case OperatorAttributeKey::KERNEL_H: + return "KERNEL_H"; + case OperatorAttributeKey::KERNEL_W: + return "KERNEL_W"; + case OperatorAttributeKey::DATA_TYPE: + return "DATA_TYPE"; + case OperatorAttributeKey::SCALAR: + return "SCALAR"; + case OperatorAttributeKey::STRIDE_H: + return "STRIDE_H"; + case OperatorAttributeKey::STRIDE_W: + return "STRIDE_W"; + case OperatorAttributeKey::PADDING_H: + return "PADDING_H"; + case OperatorAttributeKey::PADDING_W: + return "PADDING_W"; + case OperatorAttributeKey::AGGR: + return "AGGR"; + case OperatorAttributeKey::NUM_ENTRIES: + return "NUM_ENTRIES"; + case OperatorAttributeKey::OUT_CHANNELS: + return "OUT_CHANNELS"; + case OperatorAttributeKey::ACTIVATION: + return "ACTIVATION"; + case OperatorAttributeKey::NUMDIM: + return "NUMDIM"; + case OperatorAttributeKey::AXIS: + return "AXIS"; + case OperatorAttributeKey::PERMUTATION: + return "PERMUTATION"; + case OperatorAttributeKey::OUTSHUFFLE: + return "OUTSHUFFLE"; + case OperatorAttributeKey::MERGE_GCONV_COUNT: + return "MERGE_GCONV_COUNT"; + case OperatorAttributeKey::AXES: + return "AXES"; + case OperatorAttributeKey::KEEP_DIMS: + return "KEEP_DIMS"; + case OperatorAttributeKey::EPSILON: + return "EPSILON"; + case OperatorAttributeKey::PARALLEL_OP_DIM: + return "PARALLEL_OP_DIM"; + case OperatorAttributeKey::PARALLEL_OP_DEGREE: + return "PARALLEL_OP_DEGREE"; + case OperatorAttributeKey::SOFTMAX_DIM: + return "SOFTMAX_DIM"; + case OperatorAttributeKey::NUM_HEADS: + return "NUM_HEADS"; + case OperatorAttributeKey::PARALLEL_DIM: + return "PARALLEL_DIM"; + case OperatorAttributeKey::PARALLEL_DEGREE: + return "PARALLEL_DEGREE"; + case OperatorAttributeKey::PAD: + return "PAD"; + case OperatorAttributeKey::EMBED_DIM: + return "EMBED_DIM"; + case OperatorAttributeKey::KDIM: + return "KDIM"; + case OperatorAttributeKey::VDIM: + return "VDIM"; + case OperatorAttributeKey::DROPOUT: + return "DROPOUT"; + case OperatorAttributeKey::BIAS: + return "BIAS"; + case OperatorAttributeKey::ADD_BIAS_KV: + return "ADD_BIAS_KV"; + case OperatorAttributeKey::ADD_ZERO_ATTN: + return "ADD_ZERO_ATTN"; + case OperatorAttributeKey::A_SEQ_LENGTH_DIM: + return "A_SEQ_LENGTH_DIM"; + case OperatorAttributeKey::B_SEQ_LENGTH_DIM: + return "B_SEQ_LENGTH_DIM"; + case OperatorAttributeKey::RELU: + return "RELU"; + case OperatorAttributeKey::TARGET_DIMS: + return "TARGET_DIMS"; + case OperatorAttributeKey::RATE: + return "RATE"; + case OperatorAttributeKey::SEED: + return "SEED"; + case OperatorAttributeKey::SHOULD_BROADCAST_LHS: + return "SHOULD_BROADCAST_LHS"; + case OperatorAttributeKey::SHOULD_BROADCAST_RHS: + return "SHOULD_BROADCAST_RHS"; + case OperatorAttributeKey::DIM: + return "DIM"; + case OperatorAttributeKey::ELEMENTWISE_AFFINE: + return "ELEMENTWISE_AFFINE"; + case OperatorAttributeKey::REGULARIZER: + return "REGULARIZER"; + case OperatorAttributeKey::SHAPE: + return "SHAPE"; + case OperatorAttributeKey::SPLITS: + return "SPLITS"; + case OperatorAttributeKey::K: + return "K"; + case OperatorAttributeKey::SORTED: + return "SORTED"; + case OperatorAttributeKey::COMBINE_DIM: + return "COMBINE_DIM"; + case OperatorAttributeKey::COMBINE_DEGREE: + return "COMBINE_DEGREE"; + case OperatorAttributeKey::NUM_INPUTS: + return "NUM_INPUTS"; + default: + std::ostringstream oss; + oss << "Unknown OperatorAttributeKey value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, OperatorAttributeKey x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, OperatorAttributeKey x) { + switch (x) { + case OperatorAttributeKey::OP_TYPE: + j = "OP_TYPE"; + break; + case OperatorAttributeKey::USE_BIAS: + j = "USE_BIAS"; + break; + case OperatorAttributeKey::GROUPS: + j = "GROUPS"; + break; + case OperatorAttributeKey::POOL_TYPE: + j = "POOL_TYPE"; + break; + case OperatorAttributeKey::KERNEL_H: + j = "KERNEL_H"; + break; + case OperatorAttributeKey::KERNEL_W: + j = "KERNEL_W"; + break; + case OperatorAttributeKey::DATA_TYPE: + j = "DATA_TYPE"; + break; + case OperatorAttributeKey::SCALAR: + j = "SCALAR"; + break; + case OperatorAttributeKey::STRIDE_H: + j = "STRIDE_H"; + break; + case OperatorAttributeKey::STRIDE_W: + j = "STRIDE_W"; + break; + case OperatorAttributeKey::PADDING_H: + j = "PADDING_H"; + break; + case OperatorAttributeKey::PADDING_W: + j = "PADDING_W"; + break; + case OperatorAttributeKey::AGGR: + j = "AGGR"; + break; + case OperatorAttributeKey::NUM_ENTRIES: + j = "NUM_ENTRIES"; + break; + case OperatorAttributeKey::OUT_CHANNELS: + j = "OUT_CHANNELS"; + break; + case OperatorAttributeKey::ACTIVATION: + j = "ACTIVATION"; + break; + case OperatorAttributeKey::NUMDIM: + j = "NUMDIM"; + break; + case OperatorAttributeKey::AXIS: + j = "AXIS"; + break; + case OperatorAttributeKey::PERMUTATION: + j = "PERMUTATION"; + break; + case OperatorAttributeKey::OUTSHUFFLE: + j = "OUTSHUFFLE"; + break; + case OperatorAttributeKey::MERGE_GCONV_COUNT: + j = "MERGE_GCONV_COUNT"; + break; + case OperatorAttributeKey::AXES: + j = "AXES"; + break; + case OperatorAttributeKey::KEEP_DIMS: + j = "KEEP_DIMS"; + break; + case OperatorAttributeKey::EPSILON: + j = "EPSILON"; + break; + case OperatorAttributeKey::PARALLEL_OP_DIM: + j = "PARALLEL_OP_DIM"; + break; + case OperatorAttributeKey::PARALLEL_OP_DEGREE: + j = "PARALLEL_OP_DEGREE"; + break; + case OperatorAttributeKey::SOFTMAX_DIM: + j = "SOFTMAX_DIM"; + break; + case OperatorAttributeKey::NUM_HEADS: + j = "NUM_HEADS"; + break; + case OperatorAttributeKey::PARALLEL_DIM: + j = "PARALLEL_DIM"; + break; + case OperatorAttributeKey::PARALLEL_DEGREE: + j = "PARALLEL_DEGREE"; + break; + case OperatorAttributeKey::PAD: + j = "PAD"; + break; + case OperatorAttributeKey::EMBED_DIM: + j = "EMBED_DIM"; + break; + case OperatorAttributeKey::KDIM: + j = "KDIM"; + break; + case OperatorAttributeKey::VDIM: + j = "VDIM"; + break; + case OperatorAttributeKey::DROPOUT: + j = "DROPOUT"; + break; + case OperatorAttributeKey::BIAS: + j = "BIAS"; + break; + case OperatorAttributeKey::ADD_BIAS_KV: + j = "ADD_BIAS_KV"; + break; + case OperatorAttributeKey::ADD_ZERO_ATTN: + j = "ADD_ZERO_ATTN"; + break; + case OperatorAttributeKey::A_SEQ_LENGTH_DIM: + j = "A_SEQ_LENGTH_DIM"; + break; + case OperatorAttributeKey::B_SEQ_LENGTH_DIM: + j = "B_SEQ_LENGTH_DIM"; + break; + case OperatorAttributeKey::RELU: + j = "RELU"; + break; + case OperatorAttributeKey::TARGET_DIMS: + j = "TARGET_DIMS"; + break; + case OperatorAttributeKey::RATE: + j = "RATE"; + break; + case OperatorAttributeKey::SEED: + j = "SEED"; + break; + case OperatorAttributeKey::SHOULD_BROADCAST_LHS: + j = "SHOULD_BROADCAST_LHS"; + break; + case OperatorAttributeKey::SHOULD_BROADCAST_RHS: + j = "SHOULD_BROADCAST_RHS"; + break; + case OperatorAttributeKey::DIM: + j = "DIM"; + break; + case OperatorAttributeKey::ELEMENTWISE_AFFINE: + j = "ELEMENTWISE_AFFINE"; + break; + case OperatorAttributeKey::REGULARIZER: + j = "REGULARIZER"; + break; + case OperatorAttributeKey::SHAPE: + j = "SHAPE"; + break; + case OperatorAttributeKey::SPLITS: + j = "SPLITS"; + break; + case OperatorAttributeKey::K: + j = "K"; + break; + case OperatorAttributeKey::SORTED: + j = "SORTED"; + break; + case OperatorAttributeKey::COMBINE_DIM: + j = "COMBINE_DIM"; + break; + case OperatorAttributeKey::COMBINE_DEGREE: + j = "COMBINE_DEGREE"; + break; + case OperatorAttributeKey::NUM_INPUTS: + j = "NUM_INPUTS"; + break; + default: + std::ostringstream oss; + oss << "Unknown OperatorAttributeKey value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, OperatorAttributeKey &x) { + std::string as_str = j.get(); + if (as_str == "OP_TYPE") { + x = OperatorAttributeKey::OP_TYPE; + } else if (as_str == "USE_BIAS") { + x = OperatorAttributeKey::USE_BIAS; + } else if (as_str == "GROUPS") { + x = OperatorAttributeKey::GROUPS; + } else if (as_str == "POOL_TYPE") { + x = OperatorAttributeKey::POOL_TYPE; + } else if (as_str == "KERNEL_H") { + x = OperatorAttributeKey::KERNEL_H; + } else if (as_str == "KERNEL_W") { + x = OperatorAttributeKey::KERNEL_W; + } else if (as_str == "DATA_TYPE") { + x = OperatorAttributeKey::DATA_TYPE; + } else if (as_str == "SCALAR") { + x = OperatorAttributeKey::SCALAR; + } else if (as_str == "STRIDE_H") { + x = OperatorAttributeKey::STRIDE_H; + } else if (as_str == "STRIDE_W") { + x = OperatorAttributeKey::STRIDE_W; + } else if (as_str == "PADDING_H") { + x = OperatorAttributeKey::PADDING_H; + } else if (as_str == "PADDING_W") { + x = OperatorAttributeKey::PADDING_W; + } else if (as_str == "AGGR") { + x = OperatorAttributeKey::AGGR; + } else if (as_str == "NUM_ENTRIES") { + x = OperatorAttributeKey::NUM_ENTRIES; + } else if (as_str == "OUT_CHANNELS") { + x = OperatorAttributeKey::OUT_CHANNELS; + } else if (as_str == "ACTIVATION") { + x = OperatorAttributeKey::ACTIVATION; + } else if (as_str == "NUMDIM") { + x = OperatorAttributeKey::NUMDIM; + } else if (as_str == "AXIS") { + x = OperatorAttributeKey::AXIS; + } else if (as_str == "PERMUTATION") { + x = OperatorAttributeKey::PERMUTATION; + } else if (as_str == "OUTSHUFFLE") { + x = OperatorAttributeKey::OUTSHUFFLE; + } else if (as_str == "MERGE_GCONV_COUNT") { + x = OperatorAttributeKey::MERGE_GCONV_COUNT; + } else if (as_str == "AXES") { + x = OperatorAttributeKey::AXES; + } else if (as_str == "KEEP_DIMS") { + x = OperatorAttributeKey::KEEP_DIMS; + } else if (as_str == "EPSILON") { + x = OperatorAttributeKey::EPSILON; + } else if (as_str == "PARALLEL_OP_DIM") { + x = OperatorAttributeKey::PARALLEL_OP_DIM; + } else if (as_str == "PARALLEL_OP_DEGREE") { + x = OperatorAttributeKey::PARALLEL_OP_DEGREE; + } else if (as_str == "SOFTMAX_DIM") { + x = OperatorAttributeKey::SOFTMAX_DIM; + } else if (as_str == "NUM_HEADS") { + x = OperatorAttributeKey::NUM_HEADS; + } else if (as_str == "PARALLEL_DIM") { + x = OperatorAttributeKey::PARALLEL_DIM; + } else if (as_str == "PARALLEL_DEGREE") { + x = OperatorAttributeKey::PARALLEL_DEGREE; + } else if (as_str == "PAD") { + x = OperatorAttributeKey::PAD; + } else if (as_str == "EMBED_DIM") { + x = OperatorAttributeKey::EMBED_DIM; + } else if (as_str == "KDIM") { + x = OperatorAttributeKey::KDIM; + } else if (as_str == "VDIM") { + x = OperatorAttributeKey::VDIM; + } else if (as_str == "DROPOUT") { + x = OperatorAttributeKey::DROPOUT; + } else if (as_str == "BIAS") { + x = OperatorAttributeKey::BIAS; + } else if (as_str == "ADD_BIAS_KV") { + x = OperatorAttributeKey::ADD_BIAS_KV; + } else if (as_str == "ADD_ZERO_ATTN") { + x = OperatorAttributeKey::ADD_ZERO_ATTN; + } else if (as_str == "A_SEQ_LENGTH_DIM") { + x = OperatorAttributeKey::A_SEQ_LENGTH_DIM; + } else if (as_str == "B_SEQ_LENGTH_DIM") { + x = OperatorAttributeKey::B_SEQ_LENGTH_DIM; + } else if (as_str == "RELU") { + x = OperatorAttributeKey::RELU; + } else if (as_str == "TARGET_DIMS") { + x = OperatorAttributeKey::TARGET_DIMS; + } else if (as_str == "RATE") { + x = OperatorAttributeKey::RATE; + } else if (as_str == "SEED") { + x = OperatorAttributeKey::SEED; + } else if (as_str == "SHOULD_BROADCAST_LHS") { + x = OperatorAttributeKey::SHOULD_BROADCAST_LHS; + } else if (as_str == "SHOULD_BROADCAST_RHS") { + x = OperatorAttributeKey::SHOULD_BROADCAST_RHS; + } else if (as_str == "DIM") { + x = OperatorAttributeKey::DIM; + } else if (as_str == "ELEMENTWISE_AFFINE") { + x = OperatorAttributeKey::ELEMENTWISE_AFFINE; + } else if (as_str == "REGULARIZER") { + x = OperatorAttributeKey::REGULARIZER; + } else if (as_str == "SHAPE") { + x = OperatorAttributeKey::SHAPE; + } else if (as_str == "SPLITS") { + x = OperatorAttributeKey::SPLITS; + } else if (as_str == "K") { + x = OperatorAttributeKey::K; + } else if (as_str == "SORTED") { + x = OperatorAttributeKey::SORTED; + } else if (as_str == "COMBINE_DIM") { + x = OperatorAttributeKey::COMBINE_DIM; + } else if (as_str == "COMBINE_DEGREE") { + x = OperatorAttributeKey::COMBINE_DEGREE; + } else if (as_str == "NUM_INPUTS") { + x = OperatorAttributeKey::NUM_INPUTS; + } else { + std::ostringstream oss; + oss << "Unknown OperatorAttributeKey value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::element( + FlexFlow::OperatorAttributeKey::OP_TYPE, + FlexFlow::OperatorAttributeKey::USE_BIAS, + FlexFlow::OperatorAttributeKey::GROUPS, + FlexFlow::OperatorAttributeKey::POOL_TYPE, + FlexFlow::OperatorAttributeKey::KERNEL_H, + FlexFlow::OperatorAttributeKey::KERNEL_W, + FlexFlow::OperatorAttributeKey::DATA_TYPE, + FlexFlow::OperatorAttributeKey::SCALAR, + FlexFlow::OperatorAttributeKey::STRIDE_H, + FlexFlow::OperatorAttributeKey::STRIDE_W, + FlexFlow::OperatorAttributeKey::PADDING_H, + FlexFlow::OperatorAttributeKey::PADDING_W, + FlexFlow::OperatorAttributeKey::AGGR, + FlexFlow::OperatorAttributeKey::NUM_ENTRIES, + FlexFlow::OperatorAttributeKey::OUT_CHANNELS, + FlexFlow::OperatorAttributeKey::ACTIVATION, + FlexFlow::OperatorAttributeKey::NUMDIM, + FlexFlow::OperatorAttributeKey::AXIS, + FlexFlow::OperatorAttributeKey::PERMUTATION, + FlexFlow::OperatorAttributeKey::OUTSHUFFLE, + FlexFlow::OperatorAttributeKey::MERGE_GCONV_COUNT, + FlexFlow::OperatorAttributeKey::AXES, + FlexFlow::OperatorAttributeKey::KEEP_DIMS, + FlexFlow::OperatorAttributeKey::EPSILON, + FlexFlow::OperatorAttributeKey::PARALLEL_OP_DIM, + FlexFlow::OperatorAttributeKey::PARALLEL_OP_DEGREE, + FlexFlow::OperatorAttributeKey::SOFTMAX_DIM, + FlexFlow::OperatorAttributeKey::NUM_HEADS, + FlexFlow::OperatorAttributeKey::PARALLEL_DIM, + FlexFlow::OperatorAttributeKey::PARALLEL_DEGREE, + FlexFlow::OperatorAttributeKey::PAD, + FlexFlow::OperatorAttributeKey::EMBED_DIM, + FlexFlow::OperatorAttributeKey::KDIM, + FlexFlow::OperatorAttributeKey::VDIM, + FlexFlow::OperatorAttributeKey::DROPOUT, + FlexFlow::OperatorAttributeKey::BIAS, + FlexFlow::OperatorAttributeKey::ADD_BIAS_KV, + FlexFlow::OperatorAttributeKey::ADD_ZERO_ATTN, + FlexFlow::OperatorAttributeKey::A_SEQ_LENGTH_DIM, + FlexFlow::OperatorAttributeKey::B_SEQ_LENGTH_DIM, + FlexFlow::OperatorAttributeKey::RELU, + FlexFlow::OperatorAttributeKey::TARGET_DIMS, + FlexFlow::OperatorAttributeKey::RATE, + FlexFlow::OperatorAttributeKey::SEED, + FlexFlow::OperatorAttributeKey::SHOULD_BROADCAST_LHS, + FlexFlow::OperatorAttributeKey::SHOULD_BROADCAST_RHS, + FlexFlow::OperatorAttributeKey::DIM, + FlexFlow::OperatorAttributeKey::ELEMENTWISE_AFFINE, + FlexFlow::OperatorAttributeKey::REGULARIZER, + FlexFlow::OperatorAttributeKey::SHAPE, + FlexFlow::OperatorAttributeKey::SPLITS, + FlexFlow::OperatorAttributeKey::K, + FlexFlow::OperatorAttributeKey::SORTED, + FlexFlow::OperatorAttributeKey::COMBINE_DIM, + FlexFlow::OperatorAttributeKey::COMBINE_DEGREE, + FlexFlow::OperatorAttributeKey::NUM_INPUTS); +} +} // namespace rc diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_list_access.dtg.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_list_access.dtg.cc new file mode 100644 index 0000000000..71b71d4a51 --- /dev/null +++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_list_access.dtg.cc @@ -0,0 +1,101 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_access.struct.toml +/* proj-data +{ + "generated_from": "1dc90d1e823f05b82c1a5ff433fbf000" +} +*/ + +#include "substitutions/operator_pattern/operator_attribute_list_access.dtg.h" + +#include "substitutions/operator_pattern/operator_attribute_key.dtg.h" +#include + +namespace FlexFlow { +OperatorAttributeListIndexAccess::OperatorAttributeListIndexAccess( + ::FlexFlow::OperatorAttributeKey const &attribute_key, int const &index) + : attribute_key(attribute_key), index(index) {} +bool OperatorAttributeListIndexAccess::operator==( + OperatorAttributeListIndexAccess const &other) const { + return std::tie(this->attribute_key, this->index) == + std::tie(other.attribute_key, other.index); +} +bool OperatorAttributeListIndexAccess::operator!=( + OperatorAttributeListIndexAccess const &other) const { + return std::tie(this->attribute_key, this->index) != + std::tie(other.attribute_key, other.index); +} +bool OperatorAttributeListIndexAccess::operator<( + OperatorAttributeListIndexAccess const &other) const { + return std::tie(this->attribute_key, this->index) < + std::tie(other.attribute_key, other.index); +} +bool OperatorAttributeListIndexAccess::operator>( + OperatorAttributeListIndexAccess const &other) const { + return std::tie(this->attribute_key, this->index) > + std::tie(other.attribute_key, other.index); +} +bool OperatorAttributeListIndexAccess::operator<=( + OperatorAttributeListIndexAccess const &other) const { + return std::tie(this->attribute_key, this->index) <= + std::tie(other.attribute_key, other.index); +} +bool OperatorAttributeListIndexAccess::operator>=( + OperatorAttributeListIndexAccess const &other) const { + return std::tie(this->attribute_key, this->index) >= + std::tie(other.attribute_key, other.index); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::OperatorAttributeListIndexAccess const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::OperatorAttributeKey>{}(x.attribute_key) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.index) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::OperatorAttributeListIndexAccess + adl_serializer::from_json( + json const &j) { + return { + j.at("attribute_key").template get<::FlexFlow::OperatorAttributeKey>(), + j.at("index").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::OperatorAttributeListIndexAccess const &v) { + j["__type"] = "OperatorAttributeListIndexAccess"; + j["attribute_key"] = v.attribute_key; + j["index"] = v.index; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::OperatorAttributeKey>(), + gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(OperatorAttributeListIndexAccess const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + OperatorAttributeListIndexAccess const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_list_size.dtg.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_list_size.dtg.cc new file mode 100644 index 0000000000..eb7ae28131 --- /dev/null +++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_list_size.dtg.cc @@ -0,0 +1,88 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_size.struct.toml +/* proj-data +{ + "generated_from": "30999ad6b0603e380bc33d32fa088e45" +} +*/ + +#include "substitutions/operator_pattern/operator_attribute_list_size.dtg.h" + +#include "substitutions/operator_pattern/operator_attribute_key.dtg.h" +#include + +namespace FlexFlow { +OperatorAttributeListSize::OperatorAttributeListSize( + ::FlexFlow::OperatorAttributeKey const &attribute_key) + : attribute_key(attribute_key) {} +bool OperatorAttributeListSize::operator==( + OperatorAttributeListSize const &other) const { + return std::tie(this->attribute_key) == std::tie(other.attribute_key); +} +bool OperatorAttributeListSize::operator!=( + OperatorAttributeListSize const &other) const { + return std::tie(this->attribute_key) != std::tie(other.attribute_key); +} +bool OperatorAttributeListSize::operator<( + OperatorAttributeListSize const &other) const { + return std::tie(this->attribute_key) < std::tie(other.attribute_key); +} +bool OperatorAttributeListSize::operator>( + OperatorAttributeListSize const &other) const { + return std::tie(this->attribute_key) > std::tie(other.attribute_key); +} +bool OperatorAttributeListSize::operator<=( + OperatorAttributeListSize const &other) const { + return std::tie(this->attribute_key) <= std::tie(other.attribute_key); +} +bool OperatorAttributeListSize::operator>=( + OperatorAttributeListSize const &other) const { + return std::tie(this->attribute_key) >= std::tie(other.attribute_key); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::OperatorAttributeListSize const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::OperatorAttributeKey>{}(x.attribute_key) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::OperatorAttributeListSize + adl_serializer::from_json( + json const &j) { + return { + j.at("attribute_key").template get<::FlexFlow::OperatorAttributeKey>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::OperatorAttributeListSize const &v) { + j["__type"] = "OperatorAttributeListSize"; + j["attribute_key"] = v.attribute_key; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::OperatorAttributeKey>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(OperatorAttributeListSize const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, OperatorAttributeListSize const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_pattern.dtg.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_pattern.dtg.cc new file mode 100644 index 0000000000..5eaf54bb5f --- /dev/null +++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_pattern.dtg.cc @@ -0,0 +1,73 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.struct.toml +/* proj-data +{ + "generated_from": "968d7a3e93303a7fa7482bbcd50246b6" +} +*/ + +#include "substitutions/operator_pattern/operator_attribute_pattern.dtg.h" + +#include "substitutions/operator_pattern/operator_attribute_constraint.dtg.h" +#include "utils/fmt.h" +#include +#include + +namespace FlexFlow { +OperatorAttributePattern::OperatorAttributePattern( + std::unordered_set<::FlexFlow::OperatorAttributeConstraint> const + &attribute_constraints) + : attribute_constraints(attribute_constraints) {} +bool OperatorAttributePattern::operator==( + OperatorAttributePattern const &other) const { + return std::tie(this->attribute_constraints) == + std::tie(other.attribute_constraints); +} +bool OperatorAttributePattern::operator!=( + OperatorAttributePattern const &other) const { + return std::tie(this->attribute_constraints) != + std::tie(other.attribute_constraints); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::OperatorAttributePattern const &x) const { + size_t result = 0; + result ^= + std::hash>{}( + x.attribute_constraints) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::OperatorAttributePattern + adl_serializer::from_json( + json const &j) { + return { + j.at("attribute_constraints") + .template get< + std::unordered_set<::FlexFlow::OperatorAttributeConstraint>>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::OperatorAttributePattern const &v) { + j["__type"] = "OperatorAttributePattern"; + j["attribute_constraints"] = v.attribute_constraints; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(OperatorAttributePattern const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, OperatorAttributePattern const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_value.dtg.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_value.dtg.cc new file mode 100644 index 0000000000..376a9c2ce8 --- /dev/null +++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_value.dtg.cc @@ -0,0 +1,292 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml +/* proj-data +{ + "generated_from": "de14592f1f4bcfb52689bc95e9d3b55f" +} +*/ + +#include "substitutions/operator_pattern/operator_attribute_value.dtg.h" + +#include "fmt/format.h" +#include +#include + +namespace FlexFlow { +OperatorAttributeValue::OperatorAttributeValue(int const &v) : raw_variant(v) {} +OperatorAttributeValue::OperatorAttributeValue(bool const &v) + : raw_variant(v) {} +OperatorAttributeValue::OperatorAttributeValue(std::vector const &v) + : raw_variant(v) {} +OperatorAttributeValue::OperatorAttributeValue( + std::vector<::FlexFlow::ff_dim_t> const &v) + : raw_variant(v) {} +OperatorAttributeValue::OperatorAttributeValue( + ::FlexFlow::OperatorType const &v) + : raw_variant(v) {} +OperatorAttributeValue::OperatorAttributeValue(::FlexFlow::Activation const &v) + : raw_variant(v) {} +OperatorAttributeValue::OperatorAttributeValue(::FlexFlow::ff_dim_t const &v) + : raw_variant(v) {} +OperatorAttributeValue::OperatorAttributeValue(size_t const &v) + : raw_variant(v) {} +OperatorAttributeValue::OperatorAttributeValue(::FlexFlow::AggregateOp const &v) + : raw_variant(v) {} +OperatorAttributeValue::OperatorAttributeValue( + std::optional<::FlexFlow::RegularizerAttrs> const &v) + : raw_variant(v) {} +OperatorAttributeValue::OperatorAttributeValue(::FlexFlow::PoolOp const &v) + : raw_variant(v) {} +OperatorAttributeValue::OperatorAttributeValue(::FlexFlow::TensorShape const &v) + : raw_variant(v) {} +OperatorAttributeValue::OperatorAttributeValue(::FlexFlow::DataType const &v) + : raw_variant(v) {} +bool OperatorAttributeValue::operator==( + OperatorAttributeValue const &other) const { + return this->raw_variant == other.raw_variant; +} +bool OperatorAttributeValue::operator!=( + OperatorAttributeValue const &other) const { + return this->raw_variant != other.raw_variant; +} +bool OperatorAttributeValue::operator<( + OperatorAttributeValue const &other) const { + return this->raw_variant < other.raw_variant; +} +bool OperatorAttributeValue::operator>( + OperatorAttributeValue const &other) const { + return this->raw_variant > other.raw_variant; +} +bool OperatorAttributeValue::operator<=( + OperatorAttributeValue const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool OperatorAttributeValue::operator>=( + OperatorAttributeValue const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::OperatorAttributeValue>::operator()( + ::FlexFlow::OperatorAttributeValue const &x) const { + return std::hash, + std::vector<::FlexFlow::ff_dim_t>, + ::FlexFlow::OperatorType, + ::FlexFlow::Activation, + ::FlexFlow::ff_dim_t, + size_t, + ::FlexFlow::AggregateOp, + std::optional<::FlexFlow::RegularizerAttrs>, + ::FlexFlow::PoolOp, + ::FlexFlow::TensorShape, + ::FlexFlow::DataType>>{}(x.raw_variant); +} +} // namespace std +namespace nlohmann { +::FlexFlow::OperatorAttributeValue + adl_serializer<::FlexFlow::OperatorAttributeValue>::from_json( + json const &j) { + std::string key = j.at("type").template get(); + if (key == "int") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value").template get()}; + } else if (key == "bool") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value").template get()}; + } else if (key == "std::vector") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value").template get>()}; + } else if (key == "std::vector<::FlexFlow::ff_dim_t>") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value").template get>()}; + } else if (key == "::FlexFlow::OperatorType") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value").template get<::FlexFlow::OperatorType>()}; + } else if (key == "::FlexFlow::Activation") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value").template get<::FlexFlow::Activation>()}; + } else if (key == "::FlexFlow::ff_dim_t") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value").template get<::FlexFlow::ff_dim_t>()}; + } else if (key == "size_t") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value").template get()}; + } else if (key == "::FlexFlow::AggregateOp") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value").template get<::FlexFlow::AggregateOp>()}; + } else if (key == "std::optional<::FlexFlow::RegularizerAttrs>") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value") + .template get>()}; + } else if (key == "::FlexFlow::PoolOp") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value").template get<::FlexFlow::PoolOp>()}; + } else if (key == "::FlexFlow::TensorShape") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value").template get<::FlexFlow::TensorShape>()}; + } else if (key == "::FlexFlow::DataType") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value").template get<::FlexFlow::DataType>()}; + } else { + throw std::runtime_error(fmt::format("Unknown type key {}", key)); + } +} +void adl_serializer<::FlexFlow::OperatorAttributeValue>::to_json( + json &j, ::FlexFlow::OperatorAttributeValue const &x) { + j["__type"] = "OperatorAttributeValue"; + switch (x.index()) { + case 0: { + j["type"] = "int"; + j["value"] = x.get(); + break; + } + case 1: { + j["type"] = "bool"; + j["value"] = x.get(); + break; + } + case 2: { + j["type"] = "std::vector"; + j["value"] = x.get>(); + break; + } + case 3: { + j["type"] = "std::vector<::FlexFlow::ff_dim_t>"; + j["value"] = x.get>(); + break; + } + case 4: { + j["type"] = "::FlexFlow::OperatorType"; + j["value"] = x.get<::FlexFlow::OperatorType>(); + break; + } + case 5: { + j["type"] = "::FlexFlow::Activation"; + j["value"] = x.get<::FlexFlow::Activation>(); + break; + } + case 6: { + j["type"] = "::FlexFlow::ff_dim_t"; + j["value"] = x.get<::FlexFlow::ff_dim_t>(); + break; + } + case 7: { + j["type"] = "size_t"; + j["value"] = x.get(); + break; + } + case 8: { + j["type"] = "::FlexFlow::AggregateOp"; + j["value"] = x.get<::FlexFlow::AggregateOp>(); + break; + } + case 9: { + j["type"] = "std::optional<::FlexFlow::RegularizerAttrs>"; + j["value"] = x.get>(); + break; + } + case 10: { + j["type"] = "::FlexFlow::PoolOp"; + j["value"] = x.get<::FlexFlow::PoolOp>(); + break; + } + case 11: { + j["type"] = "::FlexFlow::TensorShape"; + j["value"] = x.get<::FlexFlow::TensorShape>(); + break; + } + case 12: { + j["type"] = "::FlexFlow::DataType"; + j["value"] = x.get<::FlexFlow::DataType>(); + break; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type OperatorAttributeValue", x.index())); + } + } +} +} // namespace nlohmann +namespace FlexFlow { +std::string format_as(::FlexFlow::OperatorAttributeValue const &x) { + std::ostringstream oss; + switch (x.index()) { + case 0: { + oss << ""; + break; + } + case 1: { + oss << ""; + break; + } + case 2: { + oss << "=" + << x.get>() << ">"; + break; + } + case 3: { + oss << "=" + << x.get>() << ">"; + break; + } + case 4: { + oss << ""; + break; + } + case 5: { + oss << ""; + break; + } + case 6: { + oss << ""; + break; + } + case 7: { + oss << ""; + break; + } + case 8: { + oss << ""; + break; + } + case 9: { + oss << "=" + << x.get>() << ">"; + break; + } + case 10: { + oss << ""; + break; + } + case 11: { + oss << ""; + break; + } + case 12: { + oss << ""; + break; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type OperatorAttributeValue", x.index())); + break; + } + } + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::OperatorAttributeValue const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc b/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc new file mode 100644 index 0000000000..5455cdced5 --- /dev/null +++ b/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc @@ -0,0 +1,21 @@ +#include "substitutions/operator_pattern/satisfies_constraint.h" +#include "substitutions/operator_pattern/operator_attribute_expr.h" + +namespace FlexFlow { + +bool operator_satisfies_constraint(PCGOperatorAttrs const &attrs, OperatorAttributeConstraint const &constraint) { + std::optional expr_val = evaluate_attribute_expr(attrs, constraint.attribute_expr); + + if (!expr_val.has_value()) { + return false; + } + + switch (constraint.constraint_type) { + case ConstraintType::EQUAL: + return expr_val.value() == constraint.attribute_value; + default: + throw mk_runtime_error(fmt::format("Unknown constraint type {}", static_cast(constraint.constraint_type))); + } +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/operator_pattern/satisfies_pattern.cc b/lib/substitutions/src/substitutions/operator_pattern/satisfies_pattern.cc new file mode 100644 index 0000000000..28d7803a6b --- /dev/null +++ b/lib/substitutions/src/substitutions/operator_pattern/satisfies_pattern.cc @@ -0,0 +1,11 @@ +#include "substitutions/operator_pattern/satisfies_pattern.h" +#include "substitutions/operator_pattern/satisfies_constraint.h" + +namespace FlexFlow { + +bool operator_satisfies_pattern(PCGOperatorAttrs const &attrs, OperatorAttributePattern const &pattern) { + return all_of(pattern.attribute_constraints, + [&](OperatorAttributeConstraint const &c) { return operator_satisfies_constraint(attrs, c); }); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/output_graph/attr_constant.dtg.cc b/lib/substitutions/src/substitutions/output_graph/attr_constant.dtg.cc new file mode 100644 index 0000000000..f20afc1164 --- /dev/null +++ b/lib/substitutions/src/substitutions/output_graph/attr_constant.dtg.cc @@ -0,0 +1,59 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/output_graph/attr_constant.struct.toml +/* proj-data +{ + "generated_from": "1e5beabcb8e3657d8fe9c9c8b1310cb1" +} +*/ + +#include "substitutions/output_graph/attr_constant.dtg.h" + +#include "substitutions/operator_pattern/operator_attribute_value.dtg.h" +#include + +namespace FlexFlow { +AttrConstant::AttrConstant(::FlexFlow::OperatorAttributeValue const &value) + : value(value) {} +bool AttrConstant::operator==(AttrConstant const &other) const { + return std::tie(this->value) == std::tie(other.value); +} +bool AttrConstant::operator!=(AttrConstant const &other) const { + return std::tie(this->value) != std::tie(other.value); +} +bool AttrConstant::operator<(AttrConstant const &other) const { + return std::tie(this->value) < std::tie(other.value); +} +bool AttrConstant::operator>(AttrConstant const &other) const { + return std::tie(this->value) > std::tie(other.value); +} +bool AttrConstant::operator<=(AttrConstant const &other) const { + return std::tie(this->value) <= std::tie(other.value); +} +bool AttrConstant::operator>=(AttrConstant const &other) const { + return std::tie(this->value) >= std::tie(other.value); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::AttrConstant const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::OperatorAttributeValue>{}(x.value) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(AttrConstant const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, AttrConstant const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/output_graph/output_graph_expr.dtg.cc b/lib/substitutions/src/substitutions/output_graph/output_graph_expr.dtg.cc new file mode 100644 index 0000000000..7d07bf9218 --- /dev/null +++ b/lib/substitutions/src/substitutions/output_graph/output_graph_expr.dtg.cc @@ -0,0 +1,20 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/output_graph/output_graph_expr.struct.toml +/* proj-data +{ + "generated_from": "9084c9afb2724504a6f4db4288a83a0d" +} +*/ + +#include "substitutions/output_graph/output_graph_expr.dtg.h" + +#include "substitutions/output_graph/output_operator_attrs_assignment.dtg.h" +#include "utils/graph.h" + +namespace FlexFlow { +OutputGraphExpr::OutputGraphExpr( + ::FlexFlow::NodeLabelledOpenMultiDiGraph< + ::FlexFlow::OutputOperatorAttrsAssignment> const &raw_graph) + : raw_graph(raw_graph) {} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/output_graph/output_operator_attr_access.dtg.cc b/lib/substitutions/src/substitutions/output_graph/output_operator_attr_access.dtg.cc new file mode 100644 index 0000000000..0c6abc925d --- /dev/null +++ b/lib/substitutions/src/substitutions/output_graph/output_operator_attr_access.dtg.cc @@ -0,0 +1,77 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.struct.toml +/* proj-data +{ + "generated_from": "e3b3a741183fcb38cfa68aacb82e12d1" +} +*/ + +#include "substitutions/output_graph/output_operator_attr_access.dtg.h" + +#include "substitutions/operator_pattern/operator_attribute_expr.dtg.h" +#include "utils/graph.h" +#include + +namespace FlexFlow { +OutputOperatorAttrAccess::OutputOperatorAttrAccess( + ::FlexFlow::Node const &node, + ::FlexFlow::OperatorAttributeExpr const &attr_expr) + : node(node), attr_expr(attr_expr) {} +bool OutputOperatorAttrAccess::operator==( + OutputOperatorAttrAccess const &other) const { + return std::tie(this->node, this->attr_expr) == + std::tie(other.node, other.attr_expr); +} +bool OutputOperatorAttrAccess::operator!=( + OutputOperatorAttrAccess const &other) const { + return std::tie(this->node, this->attr_expr) != + std::tie(other.node, other.attr_expr); +} +bool OutputOperatorAttrAccess::operator<( + OutputOperatorAttrAccess const &other) const { + return std::tie(this->node, this->attr_expr) < + std::tie(other.node, other.attr_expr); +} +bool OutputOperatorAttrAccess::operator>( + OutputOperatorAttrAccess const &other) const { + return std::tie(this->node, this->attr_expr) > + std::tie(other.node, other.attr_expr); +} +bool OutputOperatorAttrAccess::operator<=( + OutputOperatorAttrAccess const &other) const { + return std::tie(this->node, this->attr_expr) <= + std::tie(other.node, other.attr_expr); +} +bool OutputOperatorAttrAccess::operator>=( + OutputOperatorAttrAccess const &other) const { + return std::tie(this->node, this->attr_expr) >= + std::tie(other.node, other.attr_expr); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::OutputOperatorAttrAccess const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::Node>{}(x.node) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash<::FlexFlow::OperatorAttributeExpr>{}(x.attr_expr) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(OutputOperatorAttrAccess const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, OutputOperatorAttrAccess const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/output_graph/output_operator_attribute_expr.dtg.cc b/lib/substitutions/src/substitutions/output_graph/output_operator_attribute_expr.dtg.cc new file mode 100644 index 0000000000..bf1b07c825 --- /dev/null +++ b/lib/substitutions/src/substitutions/output_graph/output_operator_attribute_expr.dtg.cc @@ -0,0 +1,79 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.variant.toml +/* proj-data +{ + "generated_from": "89ebf777a5b909eef78ab5a5a177e041" +} +*/ + +#include "substitutions/output_graph/output_operator_attribute_expr.dtg.h" + +#include + +namespace FlexFlow { +OutputOperatorAttributeExpr::OutputOperatorAttributeExpr( + ::FlexFlow::OutputOperatorAttrAccess const &v) + : raw_variant(v) {} +OutputOperatorAttributeExpr::OutputOperatorAttributeExpr( + ::FlexFlow::AttrConstant const &v) + : raw_variant(v) {} +bool OutputOperatorAttributeExpr::operator==( + OutputOperatorAttributeExpr const &other) const { + return this->raw_variant == other.raw_variant; +} +bool OutputOperatorAttributeExpr::operator!=( + OutputOperatorAttributeExpr const &other) const { + return this->raw_variant != other.raw_variant; +} +bool OutputOperatorAttributeExpr::operator<( + OutputOperatorAttributeExpr const &other) const { + return this->raw_variant < other.raw_variant; +} +bool OutputOperatorAttributeExpr::operator>( + OutputOperatorAttributeExpr const &other) const { + return this->raw_variant > other.raw_variant; +} +bool OutputOperatorAttributeExpr::operator<=( + OutputOperatorAttributeExpr const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool OutputOperatorAttributeExpr::operator>=( + OutputOperatorAttributeExpr const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::OutputOperatorAttributeExpr>::operator()( + ::FlexFlow::OutputOperatorAttributeExpr const &x) const { + return std::hash>{}(x.raw_variant); +} +} // namespace std +namespace FlexFlow { +std::string format_as(::FlexFlow::OutputOperatorAttributeExpr const &x) { + std::ostringstream oss; + switch (x.index()) { + case 0: { + oss << ""; + break; + } + case 1: { + oss << ""; + break; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type OutputOperatorAttributeExpr", x.index())); + break; + } + } + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::OutputOperatorAttributeExpr const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/output_graph/output_operator_attrs_assignment.dtg.cc b/lib/substitutions/src/substitutions/output_graph/output_operator_attrs_assignment.dtg.cc new file mode 100644 index 0000000000..7a1950482a --- /dev/null +++ b/lib/substitutions/src/substitutions/output_graph/output_operator_attrs_assignment.dtg.cc @@ -0,0 +1,58 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.struct.toml +/* proj-data +{ + "generated_from": "bbfb309c5a39a729da23dace4df4a9de" +} +*/ + +#include "substitutions/output_graph/output_operator_attrs_assignment.dtg.h" + +#include "substitutions/operator_pattern/operator_attribute_key.dtg.h" +#include "substitutions/output_graph/output_operator_attribute_expr.dtg.h" +#include +#include + +namespace FlexFlow { +OutputOperatorAttrsAssignment::OutputOperatorAttrsAssignment( + std::unordered_map<::FlexFlow::OperatorAttributeKey, + ::FlexFlow::OutputOperatorAttributeExpr> const + &assignments) + : assignments(assignments) {} +bool OutputOperatorAttrsAssignment::operator==( + OutputOperatorAttrsAssignment const &other) const { + return std::tie(this->assignments) == std::tie(other.assignments); +} +bool OutputOperatorAttrsAssignment::operator!=( + OutputOperatorAttrsAssignment const &other) const { + return std::tie(this->assignments) != std::tie(other.assignments); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::OutputOperatorAttrsAssignment const &x) const { + size_t result = 0; + result ^= + std::hash>{}( + x.assignments) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(OutputOperatorAttrsAssignment const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + OutputOperatorAttrsAssignment const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/pcg_pattern.dtg.cc b/lib/substitutions/src/substitutions/pcg_pattern.dtg.cc new file mode 100644 index 0000000000..7133ab42a7 --- /dev/null +++ b/lib/substitutions/src/substitutions/pcg_pattern.dtg.cc @@ -0,0 +1,21 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/pcg_pattern.struct.toml +/* proj-data +{ + "generated_from": "f536f846828ba39266dd4a1fbaeec0e6" +} +*/ + +#include "substitutions/pcg_pattern.dtg.h" + +#include "substitutions/operator_pattern/operator_attribute_pattern.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h" +#include "utils/graph.h" + +namespace FlexFlow { +PCGPattern::PCGPattern(::FlexFlow::OutputLabelledOpenMultiDiGraph< + ::FlexFlow::OperatorAttributePattern, + ::FlexFlow::TensorAttributePattern> const &raw_graph) + : raw_graph(raw_graph) {} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc new file mode 100644 index 0000000000..d4dd87543d --- /dev/null +++ b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc @@ -0,0 +1,17 @@ +#include "substitutions/sub_parallel_computation_graph.h" + +namespace FlexFlow { + +ParallelLayerAttrs get_parallel_layer_attrs(SubParallelComputationGraph const &spcg, Node const &n) { + return spcg.raw_graph.at(n); +} + +PCGOperatorAttrs get_operator_attrs(SubParallelComputationGraph const &spcg, Node const &n) { + return get_parallel_layer_attrs(spcg, n).attrs; +} + +ParallelTensorAttrs get_parallel_tensor_attrs(SubParallelComputationGraph const &spcg, OpenMultiDiEdge const &e) { + return spcg.raw_graph.at(e); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.dtg.cc b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.dtg.cc new file mode 100644 index 0000000000..83baef2cfc --- /dev/null +++ b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.dtg.cc @@ -0,0 +1,22 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/sub_parallel_computation_graph.struct.toml +/* proj-data +{ + "generated_from": "0022d1b2c1447667695a120c154a0168" +} +*/ + +#include "substitutions/sub_parallel_computation_graph.dtg.h" + +#include "pcg/parallel_layer_attrs.dtg.h" +#include "pcg/parallel_tensor_attrs.dtg.h" +#include "utils/graph.h" + +namespace FlexFlow { +SubParallelComputationGraph::SubParallelComputationGraph( + ::FlexFlow::OutputLabelledOpenMultiDiGraph< + ::FlexFlow::ParallelLayerAttrs, + ::FlexFlow::ParallelTensorAttrs> const &raw_graph) + : raw_graph(raw_graph) {} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/substitution.cc b/lib/substitutions/src/substitutions/substitution.cc new file mode 100644 index 0000000000..9d51c7018d --- /dev/null +++ b/lib/substitutions/src/substitutions/substitution.cc @@ -0,0 +1,153 @@ +#include "substitutions/substitution.h" + +namespace FlexFlow { + +/* struct AddMappedEdgeFunctor { */ +/* bidict const &node_mapping; */ +/* SubParallelComputationGraph &new_pcg; */ + +/* template */ +/* void operator()(T const &t) { */ +/* return add_mapped_edge(t); */ +/* } */ + +/* void add_mapped_edge(InputMultiDiEdge const &e) { */ +/* new_pcg.add_edge(InputMultiDiEdge{ */ +/* node_mapping.at_l(e.dst), new_pcg.add_node_port(), e.uid}); */ +/* } */ + +/* void add_mapped_edge(OutputMultiDiEdge const &e) { */ +/* new_pcg.add_edge(OutputMultiDiEdge{ */ +/* node_mapping.at_l(e.src), new_pcg.add_node_port(), e.uid}); */ +/* } */ + +/* void add_mapped_edge(MultiDiEdge const &e) { */ +/* new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(e.dst), */ +/* new_pcg.add_node_port(), */ +/* node_mapping.at_l(e.src), */ +/* new_pcg.add_node_port()}); */ +/* } */ +/* }; */ + + + +/* struct AddNewEdgeFunctor { */ +/* SubParallelComputationGraph const &old_pcg; */ +/* SubParallelComputationGraph &new_pcg; */ +/* MultiDiGraphPatternMatch const &match; */ +/* bidict node_mapping; */ + +/* template */ +/* void operator()(TO const &old_edge, TN const &new_edge) { */ +/* return add_new_edge(old_edge, new_edge); */ +/* } */ + +/* void add_new_edge(InputMultiDiEdge const &old_edge, */ +/* InputMultiDiEdge const &new_edge) { */ +/* new_pcg.add_edge(InputMultiDiEdge{node_mapping.at_l(new_edge.dst), */ +/* new_pcg.add_node_port(), */ +/* old_edge.uid}); */ +/* } */ + +/* void add_new_edge(MultiDiEdge const &old_edge, */ +/* InputMultiDiEdge const &new_edge) { */ +/* new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(new_edge.dst), */ +/* new_pcg.add_node_port(), */ +/* node_mapping.at_l(old_edge.src), */ +/* new_pcg.add_node_port()}); */ +/* } */ + +/* void add_new_edge(OutputMultiDiEdge const &old_edge, */ +/* OutputMultiDiEdge const &new_edge) { */ +/* new_pcg.add_edge(OutputMultiDiEdge{node_mapping.at_l(new_edge.src), */ +/* new_pcg.add_node_port(), */ +/* old_edge.uid}); */ +/* } */ + +/* void add_new_edge(MultiDiEdge const &old_edge, */ +/* OutputMultiDiEdge const &new_edge) { */ +/* new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(old_edge.dst), */ +/* new_pcg.add_node_port(), */ +/* node_mapping.at_l(new_edge.src), */ +/* new_pcg.add_node_port()}); */ +/* } */ + +/* void add_new_edge(InputMultiDiEdge const &, OutputMultiDiEdge const &) { */ +/* assert(false); */ +/* } */ + +/* void add_new_edge(OpenMultiDiEdge const &, MultiDiEdge const &) { */ +/* assert(false); */ +/* } */ + +/* void add_new_edge(OutputMultiDiEdge const &, InputMultiDiEdge const &) { */ +/* assert(false); */ +/* } */ +/* }; */ + +/* SubParallelComputationGraph */ +/* apply_substitution(SubParallelComputationGraph const &pcg, */ +/* Substitution const &substitution, */ +/* MultiDiGraphPatternMatch const &match) { */ +/* SubParallelComputationGraph new_pcg = */ +/* OutputLabelledOpenMultiDiGraph::template create< */ +/* UnorderedOutputLabelledOpenMultiDiGraph>(); */ +/* bidict node_mapping; // Refactor it with global nodes */ +/* for (Node const &node : get_nodes(pcg)) { */ +/* if (!contains_r(match.node_assignment, node)) { */ +/* node_mapping.equate(node, new_pcg.add_node(pcg.at(node))); */ +/* } */ +/* } */ +/* for (OpenMultiDiEdge const &edge : get_edges(pcg)) { */ +/* if (!contains_r(match.edge_assignment, edge)) { */ +/* visit(AddMappedEdgeFunctor{node_mapping, new_pcg}, edge); */ +/* } */ +/* } */ +/* for (Node const &output_node : */ +/* get_nodes(substitution.output_graph_expr.value())) { */ +/* Operator new_op = get_operator_attrs( */ +/* pcg, match, substitution.output_graph_expr.value().at(output_node)); */ +/* Node new_node = new_pcg.add_node(new_op); */ +/* node_mapping.equate(output_node, new_node); */ +/* } */ +/* for (OpenMultiDiEdge const &output_edge : */ +/* get_edges(substitution.output_graph_expr.value())) { */ +/* if (std::holds_alternative(output_edge)) { */ +/* InputMultiDiEdge e = std::get(output_edge); */ +/* OpenMultiDiEdge original_edge = */ +/* match.edge_assignment.at_l(substitution.input_mapping.at_r(e)); */ +/* visit(AddNewEdgeFunctor{pcg, new_pcg, match, node_mapping}, */ +/* original_edge, */ +/* output_edge); */ +/* } else if (std::holds_alternative(output_edge)) { */ +/* OutputMultiDiEdge e = std::get(output_edge); */ +/* OpenMultiDiEdge original_edge = */ +/* match.edge_assignment.at_l(substitution.output_mapping.at_r(e)); */ +/* visit(AddNewEdgeFunctor{pcg, new_pcg, match, node_mapping}, */ +/* original_edge, */ +/* output_edge); */ +/* } else { */ +/* assert(std::holds_alternative(output_edge)); */ +/* MultiDiEdge e = std::get(output_edge); */ +/* new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(e.dst), */ +/* new_pcg.add_node_port(), */ +/* node_mapping.at_l(e.src), */ +/* new_pcg.add_node_port()}); */ +/* } */ +/* } */ + +/* return new_pcg; */ +/* } */ + +bool is_valid_substitution(Substitution const &) { + NOT_IMPLEMENTED(); +} + +SubParallelComputationGraph + apply_substitution(SubParallelComputationGraph const &, + Substitution const &, + MultiDiGraphPatternMatch const &) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/substitution.dtg.cc b/lib/substitutions/src/substitutions/substitution.dtg.cc new file mode 100644 index 0000000000..67d39d6ff7 --- /dev/null +++ b/lib/substitutions/src/substitutions/substitution.dtg.cc @@ -0,0 +1,28 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/substitution.struct.toml +/* proj-data +{ + "generated_from": "c101f1d63e2d8d80a0ec9c5f5db4fa12" +} +*/ + +#include "substitutions/substitution.dtg.h" + +#include "substitutions/output_graph/output_graph_expr.dtg.h" +#include "substitutions/pcg_pattern.dtg.h" + +namespace FlexFlow { +Substitution::Substitution( + ::FlexFlow::PCGPattern const &pcg_pattern, + ::FlexFlow::OutputGraphExpr const &output_graph_expr, + ::FlexFlow::bidict<::FlexFlow::InputMultiDiEdge, + ::FlexFlow::InputMultiDiEdge> const + &input_edge_match_to_output, + ::FlexFlow::bidict<::FlexFlow::OutputMultiDiEdge, + ::FlexFlow::OutputMultiDiEdge> const + &output_edge_match_to_output) + : pcg_pattern(pcg_pattern), output_graph_expr(output_graph_expr), + input_edge_match_to_output(input_edge_match_to_output), + output_edge_match_to_output(output_edge_match_to_output) {} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/tensor_pattern/eval_list_access.cc b/lib/substitutions/src/substitutions/tensor_pattern/eval_list_access.cc new file mode 100644 index 0000000000..0b07725471 --- /dev/null +++ b/lib/substitutions/src/substitutions/tensor_pattern/eval_list_access.cc @@ -0,0 +1,23 @@ +#include "substitutions/tensor_pattern/eval_list_access.h" +#include "substitutions/tensor_pattern/get_attribute.h" +#include "utils/containers.h" +#include "utils/overload.h" + +namespace FlexFlow { + +TensorAttributeValue eval_list_access(ParallelTensorAttrs const &attrs, TensorAttributeListIndexAccess const &acc) { + TensorAttributeValue from_attr = get_attribute(attrs, acc.attribute_key); + + return from_attr.visit( + overload { + [&](std::vector const &v) -> TensorAttributeValue { + return TensorAttributeValue{ + static_cast(at_idx(v, acc.index).value()) + }; + }, + [](auto &&) -> TensorAttributeValue { throw mk_runtime_error("Invalid operand"); }, + } + ); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/tensor_pattern/eval_list_size.cc b/lib/substitutions/src/substitutions/tensor_pattern/eval_list_size.cc new file mode 100644 index 0000000000..b63a35380f --- /dev/null +++ b/lib/substitutions/src/substitutions/tensor_pattern/eval_list_size.cc @@ -0,0 +1,18 @@ +#include "substitutions/tensor_pattern/eval_list_size.h" +#include "substitutions/tensor_pattern/get_attribute.h" +#include "utils/overload.h" + +namespace FlexFlow { + +TensorAttributeValue eval_list_size(ParallelTensorAttrs const &attrs, TensorAttributeListSize const &acc) { + TensorAttributeValue from_attr = get_attribute(attrs, acc.attribute_key); + + return from_attr.visit(overload { + [](std::vector const &v) -> TensorAttributeValue { + return TensorAttributeValue{v.size()}; + }, + [](auto &&) -> TensorAttributeValue { throw mk_runtime_error("Invalid operand"); }, + }); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc b/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc new file mode 100644 index 0000000000..ed7ad120a4 --- /dev/null +++ b/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc @@ -0,0 +1,25 @@ +#include "substitutions/tensor_pattern/get_attribute.h" +#include "utils/containers.h" + +namespace FlexFlow { + +TensorAttributeValue get_attribute(ParallelTensorAttrs const &attrs, TensorAttributeKey key) { + switch (key) { + case TensorAttributeKey::DIM_SIZES: { + std::vector sizes = transform(as_vector(ff_ordered(attrs.shape.dims)), + [](ParallelDim const &d) { return d.size; }); + return TensorAttributeValue{sizes}; + } + case TensorAttributeKey::DIM_DEGREES: { + std::vector degrees = transform(as_vector(ff_ordered(attrs.shape.dims)), + [](ParallelDim const &d) { + return static_cast(d.degree); + }); + return TensorAttributeValue{degrees}; + } + default: + throw std::runtime_error(fmt::format("Unknown TensorAttributeKey {}", static_cast(key))); + } +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc b/lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc new file mode 100644 index 0000000000..e464523eef --- /dev/null +++ b/lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc @@ -0,0 +1,17 @@ +#include "substitutions/tensor_pattern/satisfies_constraint.h" +#include "substitutions/tensor_pattern/tensor_attribute_expr.h" + +namespace FlexFlow { + +bool parallel_tensor_satisfies_constraint(ParallelTensorAttrs const &attrs, TensorAttributeConstraint const &constraint) { + TensorAttributeValue expr_val = evaluate_attribute_expr(attrs, constraint.attribute_expr); + + switch (constraint.constraint_type) { + case ConstraintType::EQUAL: + return expr_val == constraint.attribute_value; + default: + throw mk_runtime_error(fmt::format("Unknown constraint type {}", static_cast(constraint.constraint_type))); + } +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/tensor_pattern/satisfies_pattern.cc b/lib/substitutions/src/substitutions/tensor_pattern/satisfies_pattern.cc new file mode 100644 index 0000000000..1383c46cf8 --- /dev/null +++ b/lib/substitutions/src/substitutions/tensor_pattern/satisfies_pattern.cc @@ -0,0 +1,10 @@ +#include "substitutions/tensor_pattern/satisfies_pattern.h" +#include "substitutions/tensor_pattern/satisfies_constraint.h" + +namespace FlexFlow { + +bool parallel_tensor_satisfies_pattern(ParallelTensorAttrs const &attrs, TensorAttributePattern const &pattern) { + return all_of(pattern.attribute_constraints, + [&](TensorAttributeConstraint const &c) { return parallel_tensor_satisfies_constraint(attrs, c); }); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_constraint.dtg.cc b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_constraint.dtg.cc new file mode 100644 index 0000000000..6f9df90fb2 --- /dev/null +++ b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_constraint.dtg.cc @@ -0,0 +1,119 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_constraint.struct.toml +/* proj-data +{ + "generated_from": "29dbf81668bc864b06af52261060335e" +} +*/ + +#include "substitutions/tensor_pattern/tensor_attribute_constraint.dtg.h" + +#include "substitutions/constraint_type.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_expr.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_value.dtg.h" +#include + +namespace FlexFlow { +TensorAttributeConstraint::TensorAttributeConstraint( + ::FlexFlow::ConstraintType const &constraint_type, + ::FlexFlow::TensorAttributeExpr const &attribute_expr, + ::FlexFlow::TensorAttributeValue const &attribute_value) + : constraint_type(constraint_type), attribute_expr(attribute_expr), + attribute_value(attribute_value) {} +bool TensorAttributeConstraint::operator==( + TensorAttributeConstraint const &other) const { + return std::tie(this->constraint_type, + this->attribute_expr, + this->attribute_value) == std::tie(other.constraint_type, + other.attribute_expr, + other.attribute_value); +} +bool TensorAttributeConstraint::operator!=( + TensorAttributeConstraint const &other) const { + return std::tie(this->constraint_type, + this->attribute_expr, + this->attribute_value) != std::tie(other.constraint_type, + other.attribute_expr, + other.attribute_value); +} +bool TensorAttributeConstraint::operator<( + TensorAttributeConstraint const &other) const { + return std::tie(this->constraint_type, + this->attribute_expr, + this->attribute_value) < std::tie(other.constraint_type, + other.attribute_expr, + other.attribute_value); +} +bool TensorAttributeConstraint::operator>( + TensorAttributeConstraint const &other) const { + return std::tie(this->constraint_type, + this->attribute_expr, + this->attribute_value) > std::tie(other.constraint_type, + other.attribute_expr, + other.attribute_value); +} +bool TensorAttributeConstraint::operator<=( + TensorAttributeConstraint const &other) const { + return std::tie(this->constraint_type, + this->attribute_expr, + this->attribute_value) <= std::tie(other.constraint_type, + other.attribute_expr, + other.attribute_value); +} +bool TensorAttributeConstraint::operator>=( + TensorAttributeConstraint const &other) const { + return std::tie(this->constraint_type, + this->attribute_expr, + this->attribute_value) >= std::tie(other.constraint_type, + other.attribute_expr, + other.attribute_value); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::TensorAttributeConstraint const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ConstraintType>{}(x.constraint_type) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::TensorAttributeExpr>{}(x.attribute_expr) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::TensorAttributeValue>{}(x.attribute_value) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::TensorAttributeConstraint + adl_serializer::from_json( + json const &j) { + return { + j.at("constraint_type").template get<::FlexFlow::ConstraintType>(), + j.at("attribute_expr").template get<::FlexFlow::TensorAttributeExpr>(), + j.at("attribute_value").template get<::FlexFlow::TensorAttributeValue>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::TensorAttributeConstraint const &v) { + j["__type"] = "TensorAttributeConstraint"; + j["constraint_type"] = v.constraint_type; + j["attribute_expr"] = v.attribute_expr; + j["attribute_value"] = v.attribute_value; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(TensorAttributeConstraint const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, TensorAttributeConstraint const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_expr.cc b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_expr.cc new file mode 100644 index 0000000000..068d5d7a69 --- /dev/null +++ b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_expr.cc @@ -0,0 +1,26 @@ +#include "substitutions/tensor_pattern/tensor_attribute_expr.h" +#include "substitutions/tensor_pattern/get_attribute.h" +#include "substitutions/tensor_pattern/eval_list_size.h" +#include "substitutions/tensor_pattern/eval_list_access.h" +#include "utils/overload.h" + +namespace FlexFlow { + +TensorAttributeValue + evaluate_attribute_expr(ParallelTensorAttrs const &attrs, + TensorAttributeExpr const &expr) { + + return expr.visit(overload { + [&](TensorAttributeKey const &key) { + return get_attribute(attrs, key); + }, + [&](TensorAttributeListSize const &s) { + return eval_list_size(attrs, s); + }, + [&](TensorAttributeListIndexAccess const &s) { + return eval_list_access(attrs, s); + } + }); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_expr.dtg.cc b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_expr.dtg.cc new file mode 100644 index 0000000000..a42f18bf26 --- /dev/null +++ b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_expr.dtg.cc @@ -0,0 +1,129 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.variant.toml +/* proj-data +{ + "generated_from": "b91285329f12f1b409805cbf9be575b2" +} +*/ + +#include "substitutions/tensor_pattern/tensor_attribute_expr.dtg.h" + +#include "fmt/format.h" +#include +#include + +namespace FlexFlow { +TensorAttributeExpr::TensorAttributeExpr( + ::FlexFlow::TensorAttributeKey const &v) + : raw_variant(v) {} +TensorAttributeExpr::TensorAttributeExpr( + ::FlexFlow::TensorAttributeListSize const &v) + : raw_variant(v) {} +TensorAttributeExpr::TensorAttributeExpr( + ::FlexFlow::TensorAttributeListIndexAccess const &v) + : raw_variant(v) {} +bool TensorAttributeExpr::operator==(TensorAttributeExpr const &other) const { + return this->raw_variant == other.raw_variant; +} +bool TensorAttributeExpr::operator!=(TensorAttributeExpr const &other) const { + return this->raw_variant != other.raw_variant; +} +bool TensorAttributeExpr::operator<(TensorAttributeExpr const &other) const { + return this->raw_variant < other.raw_variant; +} +bool TensorAttributeExpr::operator>(TensorAttributeExpr const &other) const { + return this->raw_variant > other.raw_variant; +} +bool TensorAttributeExpr::operator<=(TensorAttributeExpr const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool TensorAttributeExpr::operator>=(TensorAttributeExpr const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::TensorAttributeExpr>::operator()( + ::FlexFlow::TensorAttributeExpr const &x) const { + return std::hash>{}( + x.raw_variant); +} +} // namespace std +namespace nlohmann { +::FlexFlow::TensorAttributeExpr + adl_serializer<::FlexFlow::TensorAttributeExpr>::from_json(json const &j) { + std::string key = j.at("type").template get(); + if (key == "key") { + return ::FlexFlow::TensorAttributeExpr{ + j.at("value").template get<::FlexFlow::TensorAttributeKey>()}; + } else if (key == "list_size") { + return ::FlexFlow::TensorAttributeExpr{ + j.at("value").template get<::FlexFlow::TensorAttributeListSize>()}; + } else if (key == "list_idx") { + return ::FlexFlow::TensorAttributeExpr{ + j.at("value") + .template get<::FlexFlow::TensorAttributeListIndexAccess>()}; + } else { + throw std::runtime_error(fmt::format("Unknown type key {}", key)); + } +} +void adl_serializer<::FlexFlow::TensorAttributeExpr>::to_json( + json &j, ::FlexFlow::TensorAttributeExpr const &x) { + j["__type"] = "TensorAttributeExpr"; + switch (x.index()) { + case 0: { + j["type"] = "key"; + j["value"] = x.get<::FlexFlow::TensorAttributeKey>(); + break; + } + case 1: { + j["type"] = "list_size"; + j["value"] = x.get<::FlexFlow::TensorAttributeListSize>(); + break; + } + case 2: { + j["type"] = "list_idx"; + j["value"] = x.get<::FlexFlow::TensorAttributeListIndexAccess>(); + break; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type TensorAttributeExpr", x.index())); + } + } +} +} // namespace nlohmann +namespace FlexFlow { +std::string format_as(::FlexFlow::TensorAttributeExpr const &x) { + std::ostringstream oss; + switch (x.index()) { + case 0: { + oss << ""; + break; + } + case 1: { + oss << ""; + break; + } + case 2: { + oss << ""; + break; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type TensorAttributeExpr", x.index())); + break; + } + } + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::TensorAttributeExpr const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_key.dtg.cc b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_key.dtg.cc new file mode 100644 index 0000000000..fe87c63777 --- /dev/null +++ b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_key.dtg.cc @@ -0,0 +1,73 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.enum.toml +/* proj-data +{ + "generated_from": "63a7c40c1e5b582f98b59750a35f0a08" +} +*/ + +#include "substitutions/tensor_pattern/tensor_attribute_key.dtg.h" + +#include +#include + +namespace std { +size_t hash::operator()( + FlexFlow::TensorAttributeKey x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(TensorAttributeKey x) { + switch (x) { + case TensorAttributeKey::DIM_SIZES: + return "DIM_SIZES"; + case TensorAttributeKey::DIM_DEGREES: + return "DIM_DEGREES"; + default: + std::ostringstream oss; + oss << "Unknown TensorAttributeKey value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, TensorAttributeKey x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, TensorAttributeKey x) { + switch (x) { + case TensorAttributeKey::DIM_SIZES: + j = "DIM_SIZES"; + break; + case TensorAttributeKey::DIM_DEGREES: + j = "DIM_DEGREES"; + break; + default: + std::ostringstream oss; + oss << "Unknown TensorAttributeKey value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, TensorAttributeKey &x) { + std::string as_str = j.get(); + if (as_str == "DIM_SIZES") { + x = TensorAttributeKey::DIM_SIZES; + } else if (as_str == "DIM_DEGREES") { + x = TensorAttributeKey::DIM_DEGREES; + } else { + std::ostringstream oss; + oss << "Unknown TensorAttributeKey value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::element( + FlexFlow::TensorAttributeKey::DIM_SIZES, + FlexFlow::TensorAttributeKey::DIM_DEGREES); +} +} // namespace rc diff --git a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.cc b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.cc new file mode 100644 index 0000000000..4e28de2c28 --- /dev/null +++ b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.cc @@ -0,0 +1,99 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.struct.toml +/* proj-data +{ + "generated_from": "41f5449cd700b6d7ab017f3efa39dc1d" +} +*/ + +#include "substitutions/tensor_pattern/tensor_attribute_list_access.dtg.h" + +#include "substitutions/tensor_pattern/tensor_attribute_key.dtg.h" +#include + +namespace FlexFlow { +TensorAttributeListIndexAccess::TensorAttributeListIndexAccess( + ::FlexFlow::TensorAttributeKey const &attribute_key, int const &index) + : attribute_key(attribute_key), index(index) {} +bool TensorAttributeListIndexAccess::operator==( + TensorAttributeListIndexAccess const &other) const { + return std::tie(this->attribute_key, this->index) == + std::tie(other.attribute_key, other.index); +} +bool TensorAttributeListIndexAccess::operator!=( + TensorAttributeListIndexAccess const &other) const { + return std::tie(this->attribute_key, this->index) != + std::tie(other.attribute_key, other.index); +} +bool TensorAttributeListIndexAccess::operator<( + TensorAttributeListIndexAccess const &other) const { + return std::tie(this->attribute_key, this->index) < + std::tie(other.attribute_key, other.index); +} +bool TensorAttributeListIndexAccess::operator>( + TensorAttributeListIndexAccess const &other) const { + return std::tie(this->attribute_key, this->index) > + std::tie(other.attribute_key, other.index); +} +bool TensorAttributeListIndexAccess::operator<=( + TensorAttributeListIndexAccess const &other) const { + return std::tie(this->attribute_key, this->index) <= + std::tie(other.attribute_key, other.index); +} +bool TensorAttributeListIndexAccess::operator>=( + TensorAttributeListIndexAccess const &other) const { + return std::tie(this->attribute_key, this->index) >= + std::tie(other.attribute_key, other.index); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::TensorAttributeListIndexAccess const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::TensorAttributeKey>{}(x.attribute_key) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.index) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::TensorAttributeListIndexAccess + adl_serializer::from_json( + json const &j) { + return {j.at("attribute_key").template get<::FlexFlow::TensorAttributeKey>(), + j.at("index").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::TensorAttributeListIndexAccess const &v) { + j["__type"] = "TensorAttributeListIndexAccess"; + j["attribute_key"] = v.attribute_key; + j["index"] = v.index; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::TensorAttributeKey>(), gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(TensorAttributeListIndexAccess const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + TensorAttributeListIndexAccess const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_list_size.dtg.cc b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_list_size.dtg.cc new file mode 100644 index 0000000000..24d8b6c025 --- /dev/null +++ b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_list_size.dtg.cc @@ -0,0 +1,87 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_size.struct.toml +/* proj-data +{ + "generated_from": "ec72cd39de5d1c0f0478696d7b83e4e9" +} +*/ + +#include "substitutions/tensor_pattern/tensor_attribute_list_size.dtg.h" + +#include "substitutions/tensor_pattern/tensor_attribute_key.dtg.h" +#include + +namespace FlexFlow { +TensorAttributeListSize::TensorAttributeListSize( + ::FlexFlow::TensorAttributeKey const &attribute_key) + : attribute_key(attribute_key) {} +bool TensorAttributeListSize::operator==( + TensorAttributeListSize const &other) const { + return std::tie(this->attribute_key) == std::tie(other.attribute_key); +} +bool TensorAttributeListSize::operator!=( + TensorAttributeListSize const &other) const { + return std::tie(this->attribute_key) != std::tie(other.attribute_key); +} +bool TensorAttributeListSize::operator<( + TensorAttributeListSize const &other) const { + return std::tie(this->attribute_key) < std::tie(other.attribute_key); +} +bool TensorAttributeListSize::operator>( + TensorAttributeListSize const &other) const { + return std::tie(this->attribute_key) > std::tie(other.attribute_key); +} +bool TensorAttributeListSize::operator<=( + TensorAttributeListSize const &other) const { + return std::tie(this->attribute_key) <= std::tie(other.attribute_key); +} +bool TensorAttributeListSize::operator>=( + TensorAttributeListSize const &other) const { + return std::tie(this->attribute_key) >= std::tie(other.attribute_key); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::TensorAttributeListSize const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::TensorAttributeKey>{}(x.attribute_key) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::TensorAttributeListSize + adl_serializer::from_json( + json const &j) { + return {j.at("attribute_key").template get<::FlexFlow::TensorAttributeKey>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::TensorAttributeListSize const &v) { + j["__type"] = "TensorAttributeListSize"; + j["attribute_key"] = v.attribute_key; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::TensorAttributeKey>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(TensorAttributeListSize const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, TensorAttributeListSize const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_pattern.dtg.cc b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_pattern.dtg.cc new file mode 100644 index 0000000000..121549d4dc --- /dev/null +++ b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_pattern.dtg.cc @@ -0,0 +1,71 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.struct.toml +/* proj-data +{ + "generated_from": "42a51afce383f1ddc3d70827aa94a68f" +} +*/ + +#include "substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h" + +#include "substitutions/tensor_pattern/tensor_attribute_constraint.dtg.h" +#include "utils/hash-utils.h" +#include +#include + +namespace FlexFlow { +TensorAttributePattern::TensorAttributePattern( + std::unordered_set<::FlexFlow::TensorAttributeConstraint> const + &attribute_constraints) + : attribute_constraints(attribute_constraints) {} +bool TensorAttributePattern::operator==( + TensorAttributePattern const &other) const { + return std::tie(this->attribute_constraints) == + std::tie(other.attribute_constraints); +} +bool TensorAttributePattern::operator!=( + TensorAttributePattern const &other) const { + return std::tie(this->attribute_constraints) != + std::tie(other.attribute_constraints); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::TensorAttributePattern const &x) const { + size_t result = 0; + result ^= + std::hash>{}( + x.attribute_constraints) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::TensorAttributePattern + adl_serializer::from_json(json const &j) { + return {j.at("attribute_constraints") + .template get< + std::unordered_set<::FlexFlow::TensorAttributeConstraint>>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::TensorAttributePattern const &v) { + j["__type"] = "TensorAttributePattern"; + j["attribute_constraints"] = v.attribute_constraints; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(TensorAttributePattern const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, TensorAttributePattern const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_value.dtg.cc b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_value.dtg.cc new file mode 100644 index 0000000000..27a82c4ffe --- /dev/null +++ b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_value.dtg.cc @@ -0,0 +1,105 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.variant.toml +/* proj-data +{ + "generated_from": "d80cf2e618d64df284c2647430a12a86" +} +*/ + +#include "substitutions/tensor_pattern/tensor_attribute_value.dtg.h" + +#include "fmt/format.h" +#include +#include + +namespace FlexFlow { +TensorAttributeValue::TensorAttributeValue(size_t const &v) : raw_variant(v) {} +TensorAttributeValue::TensorAttributeValue(std::vector const &v) + : raw_variant(v) {} +bool TensorAttributeValue::operator==(TensorAttributeValue const &other) const { + return this->raw_variant == other.raw_variant; +} +bool TensorAttributeValue::operator!=(TensorAttributeValue const &other) const { + return this->raw_variant != other.raw_variant; +} +bool TensorAttributeValue::operator<(TensorAttributeValue const &other) const { + return this->raw_variant < other.raw_variant; +} +bool TensorAttributeValue::operator>(TensorAttributeValue const &other) const { + return this->raw_variant > other.raw_variant; +} +bool TensorAttributeValue::operator<=(TensorAttributeValue const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool TensorAttributeValue::operator>=(TensorAttributeValue const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::TensorAttributeValue>::operator()( + ::FlexFlow::TensorAttributeValue const &x) const { + return std::hash>>{}(x.raw_variant); +} +} // namespace std +namespace nlohmann { +::FlexFlow::TensorAttributeValue + adl_serializer<::FlexFlow::TensorAttributeValue>::from_json(json const &j) { + std::string key = j.at("type").template get(); + if (key == "size_t") { + return ::FlexFlow::TensorAttributeValue{ + j.at("value").template get()}; + } else if (key == "std::vector") { + return ::FlexFlow::TensorAttributeValue{ + j.at("value").template get>()}; + } else { + throw std::runtime_error(fmt::format("Unknown type key {}", key)); + } +} +void adl_serializer<::FlexFlow::TensorAttributeValue>::to_json( + json &j, ::FlexFlow::TensorAttributeValue const &x) { + j["__type"] = "TensorAttributeValue"; + switch (x.index()) { + case 0: { + j["type"] = "size_t"; + j["value"] = x.get(); + break; + } + case 1: { + j["type"] = "std::vector"; + j["value"] = x.get>(); + break; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type TensorAttributeValue", x.index())); + } + } +} +} // namespace nlohmann +namespace FlexFlow { +std::string format_as(::FlexFlow::TensorAttributeValue const &x) { + std::ostringstream oss; + switch (x.index()) { + case 0: { + oss << ""; + break; + } + case 1: { + oss << "=" + << x.get>() << ">"; + break; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type TensorAttributeValue", x.index())); + break; + } + } + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::TensorAttributeValue const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/closed_pattern_edge.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/closed_pattern_edge.dtg.cc new file mode 100644 index 0000000000..fbefc6f01a --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/closed_pattern_edge.dtg.cc @@ -0,0 +1,45 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/closed_pattern_edge.struct.toml +/* proj-data +{ + "generated_from": "b4086fd78ca7ec0475ed7abfd034304c" +} +*/ + +#include "substitutions/unlabelled/closed_pattern_edge.dtg.h" + +#include "utils/graph.h" + +namespace FlexFlow { +ClosedPatternEdge::ClosedPatternEdge(::FlexFlow::MultiDiEdge const &raw_edge) + : raw_edge(raw_edge) {} +bool ClosedPatternEdge::operator==(ClosedPatternEdge const &other) const { + return std::tie(this->raw_edge) == std::tie(other.raw_edge); +} +bool ClosedPatternEdge::operator!=(ClosedPatternEdge const &other) const { + return std::tie(this->raw_edge) != std::tie(other.raw_edge); +} +bool ClosedPatternEdge::operator<(ClosedPatternEdge const &other) const { + return std::tie(this->raw_edge) < std::tie(other.raw_edge); +} +bool ClosedPatternEdge::operator>(ClosedPatternEdge const &other) const { + return std::tie(this->raw_edge) > std::tie(other.raw_edge); +} +bool ClosedPatternEdge::operator<=(ClosedPatternEdge const &other) const { + return std::tie(this->raw_edge) <= std::tie(other.raw_edge); +} +bool ClosedPatternEdge::operator>=(ClosedPatternEdge const &other) const { + return std::tie(this->raw_edge) >= std::tie(other.raw_edge); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ClosedPatternEdge const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::MultiDiEdge>{}(x.raw_edge) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std diff --git a/lib/substitutions/src/substitutions/unlabelled/downward_open_pattern_edge.cc b/lib/substitutions/src/substitutions/unlabelled/downward_open_pattern_edge.cc new file mode 100644 index 0000000000..704e0aea1a --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/downward_open_pattern_edge.cc @@ -0,0 +1,9 @@ +#include "substitutions/unlabelled/downward_open_pattern_edge.h" + +namespace FlexFlow { + +int get_src_idx(DownwardOpenPatternEdge const &e) { + return get_src_idx(e.raw_edge); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/downward_open_pattern_edge.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/downward_open_pattern_edge.dtg.cc new file mode 100644 index 0000000000..30c52fbbb2 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/downward_open_pattern_edge.dtg.cc @@ -0,0 +1,52 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.struct.toml +/* proj-data +{ + "generated_from": "c67ec363a91ce090dc538dcf76fa1f12" +} +*/ + +#include "substitutions/unlabelled/downward_open_pattern_edge.dtg.h" + +#include "utils/graph.h" + +namespace FlexFlow { +DownwardOpenPatternEdge::DownwardOpenPatternEdge( + ::FlexFlow::DownwardOpenMultiDiEdge const &raw_edge) + : raw_edge(raw_edge) {} +bool DownwardOpenPatternEdge::operator==( + DownwardOpenPatternEdge const &other) const { + return std::tie(this->raw_edge) == std::tie(other.raw_edge); +} +bool DownwardOpenPatternEdge::operator!=( + DownwardOpenPatternEdge const &other) const { + return std::tie(this->raw_edge) != std::tie(other.raw_edge); +} +bool DownwardOpenPatternEdge::operator<( + DownwardOpenPatternEdge const &other) const { + return std::tie(this->raw_edge) < std::tie(other.raw_edge); +} +bool DownwardOpenPatternEdge::operator>( + DownwardOpenPatternEdge const &other) const { + return std::tie(this->raw_edge) > std::tie(other.raw_edge); +} +bool DownwardOpenPatternEdge::operator<=( + DownwardOpenPatternEdge const &other) const { + return std::tie(this->raw_edge) <= std::tie(other.raw_edge); +} +bool DownwardOpenPatternEdge::operator>=( + DownwardOpenPatternEdge const &other) const { + return std::tie(this->raw_edge) >= std::tie(other.raw_edge); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::DownwardOpenPatternEdge const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::DownwardOpenMultiDiEdge>{}(x.raw_edge) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std diff --git a/lib/substitutions/src/substitutions/unlabelled/edge_splits.cc b/lib/substitutions/src/substitutions/unlabelled/edge_splits.cc new file mode 100644 index 0000000000..e1a9fc1fe7 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/edge_splits.cc @@ -0,0 +1,31 @@ +#include "substitutions/unlabelled/edge_splits.h" + +namespace FlexFlow { + +std::pair get_split_edges(UnlabelledPatternEdgeSplits const &splits, ClosedPatternEdge const &e) { + std::pair raw_result = splits.unwrapped.at_l(e.raw_edge); + return { + OutputPatternEdge{raw_result.first}, + InputPatternEdge{raw_result.second}, + }; +} + +std::vector> as_closed_output_input_tuples(UnlabelledPatternEdgeSplits const &s) { + std::vector> result; + + for (auto const &kv : s.unwrapped) { + MultiDiEdge standard_edge = kv.first; + OutputMultiDiEdge output_edge = kv.second.first; + InputMultiDiEdge input_edge = kv.second.second; + + result.push_back({ + ClosedPatternEdge{standard_edge}, + OutputPatternEdge{output_edge}, + InputPatternEdge{input_edge} + }); + } + + return result; +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/edge_splits.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/edge_splits.dtg.cc new file mode 100644 index 0000000000..4da15179da --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/edge_splits.dtg.cc @@ -0,0 +1,31 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/edge_splits.struct.toml +/* proj-data +{ + "generated_from": "f172b041a99f4de1d396e5d451a5e64d" +} +*/ + +#include "substitutions/unlabelled/edge_splits.dtg.h" + +#include "utils/bidict.h" +#include "utils/graph.h" +#include + +namespace FlexFlow { +UnlabelledPatternEdgeSplits::UnlabelledPatternEdgeSplits( + ::FlexFlow::bidict<::FlexFlow::MultiDiEdge, + std::pair<::FlexFlow::OutputMultiDiEdge, + ::FlexFlow::InputMultiDiEdge>> const + &unwrapped) + : unwrapped(unwrapped) {} +bool UnlabelledPatternEdgeSplits::operator==( + UnlabelledPatternEdgeSplits const &other) const { + return std::tie(this->unwrapped) == std::tie(other.unwrapped); +} +bool UnlabelledPatternEdgeSplits::operator!=( + UnlabelledPatternEdgeSplits const &other) const { + return std::tie(this->unwrapped) != std::tie(other.unwrapped); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc b/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc new file mode 100644 index 0000000000..2f2c2d756c --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc @@ -0,0 +1,141 @@ +#include "substitutions/unlabelled/find_pattern_matches.h" +#include "substitutions/unlabelled/unlabelled_graph_pattern.h" +#include "utils/containers.h" +#include "substitutions/unlabelled/upward_open_pattern_edge.h" +#include "substitutions/unlabelled/downward_open_pattern_edge.h" +#include "substitutions/unlabelled/multidigraph_pattern_match.h" + +namespace FlexFlow { + +static std::vector sorted_by_dst_idx(std::unordered_set const &in) { + return sorted_by(in, compare_by([](UpwardOpenPatternEdge const &e) { return get_dst_idx(e); })); +} + +static std::vector sorted_by_src_idx(std::unordered_set const &in) { + return sorted_by(in, compare_by([](DownwardOpenPatternEdge const &e) { return get_src_idx(e); })); +} + +static std::vector sorted_by_dst_idx(std::unordered_set const &in) { + return sorted_by(in, compare_by([](UpwardOpenPatternEdge const &e) { return get_dst_idx(e); })); +} + +static std::vector sorted_by_src_idx(std::unordered_set const &in) { + return sorted_by(in, compare_by([](DownwardOpenMultiDiEdge const &e) { return get_src_idx(e); })); +} + +static std::optional + get_candidate_singleton_match(UnlabelledGraphPattern const &pattern, + OpenMultiDiGraphView const &graph, + Node const &graph_node) { + assert(is_singleton_pattern(pattern)); + + PatternNode pattern_node = get_only(get_nodes(pattern)); + + MultiDiGraphPatternMatch match = empty_multidigraph_pattern_match(); + match.node_assignment.equate(pattern_node, graph_node); + + std::unordered_set incoming = + get_incoming_edges(graph, graph_node); + std::unordered_set outgoing = + get_outgoing_edges(graph, graph_node); + + std::unordered_set pattern_incoming = + get_incoming_edges(pattern, pattern_node); + std::unordered_set pattern_outgoing = + get_outgoing_edges(pattern, pattern_node); + + if (!pattern_incoming.empty() && pattern_incoming.size() != incoming.size()) { + return std::nullopt; + } + + if (!pattern_outgoing.empty() && pattern_outgoing.size() != outgoing.size()) { + return std::nullopt; + } + + std::vector incoming_ordered = sorted_by_dst_idx(incoming); + std::vector outgoing_ordered = sorted_by_src_idx(outgoing); + + std::vector pattern_incoming_ordered = sorted_by_dst_idx(pattern_incoming); + std::vector pattern_outgoing_ordered = sorted_by_src_idx(pattern_outgoing); + + if (pattern_incoming.size() > 0) { + std::unordered_map node_port_mapping; + for (int i = 0; i < incoming_ordered.size(); ++i) { + UpwardOpenMultiDiEdge graph_edge = incoming_ordered[i]; + UpwardOpenPatternEdge pattern_edge = pattern_incoming_ordered[i]; + NodePort graph_port = get_dst_idx(graph_edge), + pattern_port = get_dst_idx(pattern_edge); + if (!contains_key(node_port_mapping, graph_port)) { + node_port_mapping.emplace(graph_port, pattern_port); + } else { + if (pattern_port != node_port_mapping.at(graph_port)) { + return std::nullopt; + } + } + match.edge_assignment.equate(widen(pattern_edge), + widen(graph_edge)); + } + } + + if (pattern_outgoing.size() > 0) { + std::unordered_map node_port_mapping; + for (int i = 0; i < outgoing_ordered.size(); ++i) { + DownwardOpenMultiDiEdge graph_edge = outgoing_ordered[i], + DownwardOpenPatternEdge pattern_edge = pattern_outgoing_ordered[i]; + + NodePort graph_port = get_src_idx(graph_edge), + pattern_port = get_src_idx(pattern_edge); + if (!contains_key(node_port_mapping, graph_port)) { + node_port_mapping.insert({graph_port, pattern_port}); + } else { + if (pattern_port != node_port_mapping.at(graph_port)) { + return std::nullopt; + } + } + match.edge_assignment.equate(widen(pattern_edge), + widen(graph_edge)); + } + } + + return match; +} + + +std::vector + find_pattern_matches(UnlabelledGraphPattern const &pattern, + OpenMultiDiGraphView const &graph, + MatchAdditionalCriterion const &additional_criterion) { + std::vector matches; + if (is_singleton_pattern(pattern)) { + for (Node const &graph_node : get_nodes(graph)) { + std::optional candidate = + get_candidate_singleton_match(pattern, graph, graph_node); + if (candidate.has_value() && + pattern_does_match( + pattern, graph, candidate.value(), additional_criterion)) { + matches.push_back(candidate.value()); + } + } + } else { + GraphSplit split = split_pattern(pattern); + auto subpatterns = apply_split(pattern, split); + auto prefix_matches = + find_pattern_matches(subpatterns.first, graph, additional_criterion); + auto postfix_matches = + find_pattern_matches(subpatterns.second, graph, additional_criterion); + auto edge_splits = get_edge_splits(pattern, split); + for (MultiDiGraphPatternMatch const &prefix_match : prefix_matches) { + for (MultiDiGraphPatternMatch const &postfix_match : postfix_matches) { + std::optional unsplit = + unsplit_matches(prefix_match, postfix_match, edge_splits); + if (unsplit.has_value()) { + matches.push_back(unsplit.value()); + } + } + } + } + + return matches; +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.cc b/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.cc new file mode 100644 index 0000000000..2eff39bb1e --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.cc @@ -0,0 +1,9 @@ +#include "substitutions/unlabelled/input_pattern_edge.h" + +namespace FlexFlow { + +PatternNode get_dst_node(InputPatternEdge const &e) { + return PatternNode{e.raw_edge.dst}; +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.dtg.cc new file mode 100644 index 0000000000..f3f5a8ce45 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.dtg.cc @@ -0,0 +1,45 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.struct.toml +/* proj-data +{ + "generated_from": "d0cc0e65c4e3feb2e9b8435947c99e5f" +} +*/ + +#include "substitutions/unlabelled/input_pattern_edge.dtg.h" + +#include "utils/graph.h" + +namespace FlexFlow { +InputPatternEdge::InputPatternEdge(::FlexFlow::InputMultiDiEdge const &raw_edge) + : raw_edge(raw_edge) {} +bool InputPatternEdge::operator==(InputPatternEdge const &other) const { + return std::tie(this->raw_edge) == std::tie(other.raw_edge); +} +bool InputPatternEdge::operator!=(InputPatternEdge const &other) const { + return std::tie(this->raw_edge) != std::tie(other.raw_edge); +} +bool InputPatternEdge::operator<(InputPatternEdge const &other) const { + return std::tie(this->raw_edge) < std::tie(other.raw_edge); +} +bool InputPatternEdge::operator>(InputPatternEdge const &other) const { + return std::tie(this->raw_edge) > std::tie(other.raw_edge); +} +bool InputPatternEdge::operator<=(InputPatternEdge const &other) const { + return std::tie(this->raw_edge) <= std::tie(other.raw_edge); +} +bool InputPatternEdge::operator>=(InputPatternEdge const &other) const { + return std::tie(this->raw_edge) >= std::tie(other.raw_edge); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::InputPatternEdge const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::InputMultiDiEdge>{}(x.raw_edge) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std diff --git a/lib/substitutions/src/substitutions/unlabelled/match_additional_criterion.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/match_additional_criterion.dtg.cc new file mode 100644 index 0000000000..613159ad83 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/match_additional_criterion.dtg.cc @@ -0,0 +1,25 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.struct.toml +/* proj-data +{ + "generated_from": "2dff356c85dccda1fce8f714d41c6202" +} +*/ + +#include "substitutions/unlabelled/match_additional_criterion.dtg.h" + +#include "substitutions/unlabelled/pattern_edge.dtg.h" +#include "substitutions/unlabelled/pattern_node.dtg.h" +#include "utils/graph.h" +#include + +namespace FlexFlow { +MatchAdditionalCriterion::MatchAdditionalCriterion( + std::function const &node_criterion, + std::function const + &edge_criterion) + : node_criterion(node_criterion), edge_criterion(edge_criterion) {} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/match_split.cc b/lib/substitutions/src/substitutions/unlabelled/match_split.cc new file mode 100644 index 0000000000..10e7a9e975 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/match_split.cc @@ -0,0 +1,70 @@ +#include "substitutions/unlabelled/match_split.h" +#include "substitutions/unlabelled/multidigraph_pattern_match.h" +#include "substitutions/unlabelled/pattern_split.h" +#include "substitutions/unlabelled/pattern_edge.h" +#include "substitutions/unlabelled/edge_splits.h" + +namespace FlexFlow { + +MatchSplit empty_match_split() { + return MatchSplit{ + empty_multidigraph_pattern_match(), + empty_multidigraph_pattern_match() + }; +} + +MatchSplit apply_split(UnlabelledGraphPattern const &pattern, + MultiDiGraphPatternMatch const &match, + PatternSplit const &split) { + std::unordered_set prefix = split.first; + std::unordered_set postfix = split.second; + + MatchSplit result = empty_match_split(); + + for (auto const &[pattern_node, match_node] : match.node_assignment) { + if (contains(split.first, pattern_node)) { + result.prefix_submatch.node_assignment.equate(pattern_node, match_node); + } else { + assert(contains(split.second, pattern_node)); + result.postfix_submatch.node_assignment.equate(pattern_node, match_node); + } + } + + UnlabelledPatternEdgeSplits edge_splits = get_edge_splits(pattern, split); + + std::function + handle_edge = [&](PatternEdge const &pattern_edge, + OpenMultiDiEdge const &graph_edge) -> void { + std::unordered_set edge_nodes = get_nodes(pattern_edge); + + if (is_subseteq_of(edge_nodes, prefix)) { + result.prefix_submatch.edge_assignment.equate(pattern_edge, graph_edge); + } else if (is_subseteq_of(edge_nodes, postfix)) { + result.postfix_submatch.edge_assignment.equate(pattern_edge, graph_edge); + } else { + assert(is_standard_edge(graph_edge)); + + ClosedPatternEdge closed_edge = require_closed_edge(pattern_edge); + + auto split = get_split_edges(edge_splits, closed_edge); + OutputPatternEdge output_edge = split.first; + InputPatternEdge input_edge = split.second; + + auto split_graph_edge = split_edge(std::get(graph_edge)); + OutputMultiDiEdge output_graph_edge = split_graph_edge.first; + InputMultiDiEdge input_graph_edge = split_graph_edge.second; + + handle_edge(pattern_edge_from_input_edge(input_edge), OpenMultiDiEdge{input_graph_edge}); + handle_edge(pattern_edge_from_output_edge(output_edge), OpenMultiDiEdge{output_graph_edge}); + } + }; + + for (auto const &[pattern_edge, match_edge] : match.edge_assignment) { + handle_edge(pattern_edge, match_edge); + } + + return result; +} + + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/match_split.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/match_split.dtg.cc new file mode 100644 index 0000000000..ffbdf96912 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/match_split.dtg.cc @@ -0,0 +1,26 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/match_split.struct.toml +/* proj-data +{ + "generated_from": "e44c4347e07263a493cbbd5caccedd22" +} +*/ + +#include "substitutions/unlabelled/match_split.dtg.h" + +#include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h" + +namespace FlexFlow { +MatchSplit::MatchSplit(MultiDiGraphPatternMatch const &prefix_submatch, + MultiDiGraphPatternMatch const &postfix_submatch) + : prefix_submatch(prefix_submatch), postfix_submatch(postfix_submatch) {} +bool MatchSplit::operator==(MatchSplit const &other) const { + return std::tie(this->prefix_submatch, this->postfix_submatch) == + std::tie(other.prefix_submatch, other.postfix_submatch); +} +bool MatchSplit::operator!=(MatchSplit const &other) const { + return std::tie(this->prefix_submatch, this->postfix_submatch) != + std::tie(other.prefix_submatch, other.postfix_submatch); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.cc b/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.cc new file mode 100644 index 0000000000..a2d18abc07 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.cc @@ -0,0 +1,54 @@ +#include "substitutions/unlabelled/multidigraph_pattern_match.h" +#include "utils/containers.h" +#include "substitutions/unlabelled/edge_splits.h" +#include "substitutions/unlabelled/pattern_edge.h" + +namespace FlexFlow { + +MultiDiGraphPatternMatch empty_multidigraph_pattern_match() { + return MultiDiGraphPatternMatch{ + bidict{}, + bidict{}, + }; +} + +std::optional unsplit_matches( + MultiDiGraphPatternMatch const &prefix, + MultiDiGraphPatternMatch const &postfix, + UnlabelledPatternEdgeSplits const &edge_splits) { + + MultiDiGraphPatternMatch result = empty_multidigraph_pattern_match(); + + std::unordered_set handled; + for (auto const &coi : as_closed_output_input_tuples(edge_splits)) { + ClosedPatternEdge closed_edge = std::get(coi); + OutputPatternEdge output_edge = std::get(coi); + InputPatternEdge input_edge = std::get(coi); + + handled.insert(pattern_edge_from_output_edge(output_edge)); + handled.insert(pattern_edge_from_input_edge(input_edge)); + + OpenMultiDiEdge output_graph_edge = + prefix.edge_assignment.at_l(pattern_edge_from_output_edge(output_edge)); + OpenMultiDiEdge input_graph_edge = postfix.edge_assignment.at_l(pattern_edge_from_input_edge(input_edge)); + if (output_graph_edge == input_graph_edge) { + result.edge_assignment.equate(pattern_edge_from_closed_edge(closed_edge), output_graph_edge); + } else { + return std::nullopt; + } + } + + for (auto const &kv : + merge_maps(prefix.edge_assignment, postfix.edge_assignment)) { + if (!contains(handled, kv.first)) { + result.edge_assignment.equate(kv.first, kv.second); + } + } + + result.node_assignment = + merge_maps(prefix.node_assignment, postfix.node_assignment); + + return result; +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.dtg.cc new file mode 100644 index 0000000000..9fc2169dd7 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.dtg.cc @@ -0,0 +1,34 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.struct.toml +/* proj-data +{ + "generated_from": "9842661a5d4e7d717f12d2c27da7df0d" +} +*/ + +#include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h" + +#include "substitutions/unlabelled/pattern_edge.dtg.h" +#include "substitutions/unlabelled/pattern_node.dtg.h" +#include "utils/bidict.h" +#include "utils/graph.h" + +namespace FlexFlow { +MultiDiGraphPatternMatch::MultiDiGraphPatternMatch( + ::FlexFlow::bidict<::FlexFlow::PatternNode, ::FlexFlow::Node> const + &node_assignment, + ::FlexFlow::bidict<::FlexFlow::PatternEdge, + ::FlexFlow::OpenMultiDiEdge> const &edge_assignment) + : node_assignment(node_assignment), edge_assignment(edge_assignment) {} +bool MultiDiGraphPatternMatch::operator==( + MultiDiGraphPatternMatch const &other) const { + return std::tie(this->node_assignment, this->edge_assignment) == + std::tie(other.node_assignment, other.edge_assignment); +} +bool MultiDiGraphPatternMatch::operator!=( + MultiDiGraphPatternMatch const &other) const { + return std::tie(this->node_assignment, this->edge_assignment) != + std::tie(other.node_assignment, other.edge_assignment); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/output_pattern_edge.cc b/lib/substitutions/src/substitutions/unlabelled/output_pattern_edge.cc new file mode 100644 index 0000000000..6e70fc8df6 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/output_pattern_edge.cc @@ -0,0 +1,9 @@ +#include "substitutions/unlabelled/output_pattern_edge.h" + +namespace FlexFlow { + +PatternNode get_src_node(OutputPatternEdge const &e) { + return PatternNode{e.raw_edge.src}; +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/output_pattern_edge.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/output_pattern_edge.dtg.cc new file mode 100644 index 0000000000..fb9de06135 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/output_pattern_edge.dtg.cc @@ -0,0 +1,46 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.struct.toml +/* proj-data +{ + "generated_from": "3222696e351c3e203e008714245c737f" +} +*/ + +#include "substitutions/unlabelled/output_pattern_edge.dtg.h" + +#include "utils/graph.h" + +namespace FlexFlow { +OutputPatternEdge::OutputPatternEdge( + ::FlexFlow::OutputMultiDiEdge const &raw_edge) + : raw_edge(raw_edge) {} +bool OutputPatternEdge::operator==(OutputPatternEdge const &other) const { + return std::tie(this->raw_edge) == std::tie(other.raw_edge); +} +bool OutputPatternEdge::operator!=(OutputPatternEdge const &other) const { + return std::tie(this->raw_edge) != std::tie(other.raw_edge); +} +bool OutputPatternEdge::operator<(OutputPatternEdge const &other) const { + return std::tie(this->raw_edge) < std::tie(other.raw_edge); +} +bool OutputPatternEdge::operator>(OutputPatternEdge const &other) const { + return std::tie(this->raw_edge) > std::tie(other.raw_edge); +} +bool OutputPatternEdge::operator<=(OutputPatternEdge const &other) const { + return std::tie(this->raw_edge) <= std::tie(other.raw_edge); +} +bool OutputPatternEdge::operator>=(OutputPatternEdge const &other) const { + return std::tie(this->raw_edge) >= std::tie(other.raw_edge); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::OutputPatternEdge const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::OutputMultiDiEdge>{}(x.raw_edge) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_edge.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_edge.cc new file mode 100644 index 0000000000..fb662e0887 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_edge.cc @@ -0,0 +1,56 @@ +#include "substitutions/unlabelled/pattern_edge.h" +#include "utils/containers.h" + +namespace FlexFlow { + +std::unordered_set get_nodes(PatternEdge const &e) { + return transform(get_nodes(e.raw_edge), + [](Node const &n) { return PatternNode{n}; }); +} + +bool is_standard_edge(PatternEdge const &e) { + return is_standard_edge(e.raw_edge); +} + +bool is_input_edge(PatternEdge const &e) { + return is_input_edge(e.raw_edge); +} + +bool is_output_edge(PatternEdge const &e) { + return is_output_edge(e.raw_edge); +} + +ClosedPatternEdge require_closed_edge(PatternEdge const &e) { + assert (is_closed_edge(e)); + return ClosedPatternEdge{ + std::get(e.raw_edge) + }; +} + +InputPatternEdge require_input_edge(PatternEdge const &e) { + assert (is_input_edge(e)); + return InputPatternEdge{ + std::get(e.raw_edge) + }; +} + +OutputPatternEdge require_output_edge(PatternEdge const &e) { + assert (is_output_edge(e)); + return OutputPatternEdge{ + std::get(e.raw_edge) + }; +} + +PatternEdge pattern_edge_from_input_edge(InputPatternEdge const &e) { + return PatternEdge{OpenMultiDiEdge{e.raw_edge}}; +} + +PatternEdge pattern_edge_from_output_edge(OutputPatternEdge const &e) { + return PatternEdge{OpenMultiDiEdge{e.raw_edge}}; +} + +PatternEdge pattern_edge_from_closed_edge(ClosedPatternEdge const &e) { + return PatternEdge{OpenMultiDiEdge{e.raw_edge}}; +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_edge.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_edge.dtg.cc new file mode 100644 index 0000000000..e4d11d0d7e --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_edge.dtg.cc @@ -0,0 +1,45 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/pattern_edge.struct.toml +/* proj-data +{ + "generated_from": "a3eff166b0c8be2ddf3f7305eec094fd" +} +*/ + +#include "substitutions/unlabelled/pattern_edge.dtg.h" + +#include "utils/graph.h" + +namespace FlexFlow { +PatternEdge::PatternEdge(::FlexFlow::OpenMultiDiEdge const &raw_edge) + : raw_edge(raw_edge) {} +bool PatternEdge::operator==(PatternEdge const &other) const { + return std::tie(this->raw_edge) == std::tie(other.raw_edge); +} +bool PatternEdge::operator!=(PatternEdge const &other) const { + return std::tie(this->raw_edge) != std::tie(other.raw_edge); +} +bool PatternEdge::operator<(PatternEdge const &other) const { + return std::tie(this->raw_edge) < std::tie(other.raw_edge); +} +bool PatternEdge::operator>(PatternEdge const &other) const { + return std::tie(this->raw_edge) > std::tie(other.raw_edge); +} +bool PatternEdge::operator<=(PatternEdge const &other) const { + return std::tie(this->raw_edge) <= std::tie(other.raw_edge); +} +bool PatternEdge::operator>=(PatternEdge const &other) const { + return std::tie(this->raw_edge) >= std::tie(other.raw_edge); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::PatternEdge const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::OpenMultiDiEdge>{}(x.raw_edge) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc new file mode 100644 index 0000000000..f6678c0e3c --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc @@ -0,0 +1,74 @@ +#include "substitutions/unlabelled/pattern_matching.h" +#include "substitutions/unlabelled/match_split.h" +#include "substitutions/unlabelled/unlabelled_graph_pattern.h" +#include "substitutions/unlabelled/pattern_edge.h" +#include +#include "substitutions/unlabelled/input_pattern_edge.h" +#include "substitutions/unlabelled/output_pattern_edge.h" +#include "substitutions/unlabelled/pattern_split.h" + +namespace FlexFlow { + +bool unlabelled_pattern_does_match(UnlabelledGraphPattern const &pattern, + OpenMultiDiGraphView const &graph, + MultiDiGraphPatternMatch const &match, + MatchAdditionalCriterion const &additional_criterion) { + if (is_singleton_pattern(pattern)) { + PatternNode pattern_node = get_only(get_nodes(pattern)); + Node matched_node = match.node_assignment.at_l(pattern_node); + if (!additional_criterion.node_criterion(pattern_node, + matched_node)) { + return false; + } + for (PatternEdge const &e : get_edges(pattern)) { + OpenMultiDiEdge matched_edge = match.edge_assignment.at_l(e); + + assert(is_input_edge(e) || is_output_edge(e)); + if (is_input_edge(e)) { + if (is_output_edge(matched_edge)) { + return false; + } + UpwardOpenMultiDiEdge matched_edge = + narrow(matched_edge).value(); + InputPatternEdge input_edge = require_input_edge(e); + if (match.node_assignment.at_l(get_dst_node(input_edge)) != + get_dst_node(matched_edge)) { + return false; + } + } else { + if (is_input_edge(matched_edge)) { + return false; + } + DownwardOpenMultiDiEdge matched_edge = + narrow(matched_edge).value(); + OutputPatternEdge output_edge = require_output_edge(e); + if (match.node_assignment.at_l(get_src_node(output_edge)) != + get_src_node(matched_edge)) { + return false; + } + } + + if (!additional_criterion.edge_criterion(e, matched_edge)) { + return false; + } + } + + return true; + } + + PatternSplit split = find_even_split(pattern); + std::pair subpatterns = apply_split(pattern, split); + auto submatches = apply_split(pattern, match, split); + + return unlabelled_pattern_does_match(subpatterns.first, + graph, + submatches.prefix_submatch, + additional_criterion) && + unlabelled_pattern_does_match(subpatterns.second, + graph, + submatches.postfix_submatch, + additional_criterion); +} + + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_node.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_node.dtg.cc new file mode 100644 index 0000000000..6ea64de69e --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_node.dtg.cc @@ -0,0 +1,45 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/pattern_node.struct.toml +/* proj-data +{ + "generated_from": "a0e58ade010a9b250d2c1c378fde2639" +} +*/ + +#include "substitutions/unlabelled/pattern_node.dtg.h" + +#include "utils/graph.h" + +namespace FlexFlow { +PatternNode::PatternNode(::FlexFlow::Node const &raw_node) + : raw_node(raw_node) {} +bool PatternNode::operator==(PatternNode const &other) const { + return std::tie(this->raw_node) == std::tie(other.raw_node); +} +bool PatternNode::operator!=(PatternNode const &other) const { + return std::tie(this->raw_node) != std::tie(other.raw_node); +} +bool PatternNode::operator<(PatternNode const &other) const { + return std::tie(this->raw_node) < std::tie(other.raw_node); +} +bool PatternNode::operator>(PatternNode const &other) const { + return std::tie(this->raw_node) > std::tie(other.raw_node); +} +bool PatternNode::operator<=(PatternNode const &other) const { + return std::tie(this->raw_node) <= std::tie(other.raw_node); +} +bool PatternNode::operator>=(PatternNode const &other) const { + return std::tie(this->raw_node) >= std::tie(other.raw_node); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::PatternNode const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::Node>{}(x.raw_node) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_split.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_split.cc new file mode 100644 index 0000000000..573b562395 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_split.cc @@ -0,0 +1,39 @@ +#include "substitutions/unlabelled/pattern_split.h" + +namespace FlexFlow { + +PatternSplit find_even_split(UnlabelledGraphPattern const &p) { + std::vector topological_ordering = get_topological_ordering(pattern.raw_graph); + assert(topological_ordering.size() >= 2); + + int split_point = topological_ordering.size() / 2; + auto split = vector_split(topological_ordering, split_point); + std::unordered_set prefix(split.first.begin(), split.first.end()); + std::unordered_set postfix(split.second.begin(), split.second.end()); + return {prefix, postfix}; +} + +GraphSplit get_raw_split(PatternSplit const &s) { + return std::pair{ + transform(s.first, [](PatternNode const &n) { return n.raw_node; }), + transform(s.second, [](PatternNode const &n) { return n.raw_node; }), + }; +} + +UnlabelledPatternEdgeSplits get_edge_splits(UnlabelledGraphPattern const &pattern, PatternSplit const &split) { + bidict> raw_result = get_edge_splits( + pattern.raw_graph, + get_raw_split(split), + ); + return UnlabelledPatternEdgeSplits{raw_result}; +} + +std::pair + apply_split(UnlabelledGraphPattern const &p, PatternSplit const &s) { + return { + get_subgraph(p, s.left); + get_subgraph(p, s.right); + }; +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_split.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_split.dtg.cc new file mode 100644 index 0000000000..bbcd4c3902 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_split.dtg.cc @@ -0,0 +1,60 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/pattern_split.struct.toml +/* proj-data +{ + "generated_from": "8604edb5bd1a546ffa94ef496888e46d" +} +*/ + +#include "substitutions/unlabelled/pattern_split.dtg.h" + +#include "substitutions/unlabelled/pattern_node.dtg.h" +#include "utils/graph.h" +#include +#include + +namespace FlexFlow { +PatternSplit::PatternSplit( + std::unordered_set<::FlexFlow::PatternNode> const &first, + std::unordered_set<::FlexFlow::PatternNode> const &second) + : first(first), second(second) {} +bool PatternSplit::operator==(PatternSplit const &other) const { + return std::tie(this->first, this->second) == + std::tie(other.first, other.second); +} +bool PatternSplit::operator!=(PatternSplit const &other) const { + return std::tie(this->first, this->second) != + std::tie(other.first, other.second); +} +} // namespace FlexFlow + +namespace nlohmann { +FlexFlow::PatternSplit + adl_serializer::from_json(json const &j) { + return { + j.at("first").template get>(), + j.at("second") + .template get>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::PatternSplit const &v) { + j["__type"] = "PatternSplit"; + j["first"] = v.first; + j["second"] = v.second; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(PatternSplit const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, PatternSplit const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc b/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc new file mode 100644 index 0000000000..858b3197a8 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc @@ -0,0 +1,43 @@ +#include "substitutions/unlabelled/unlabelled_graph_pattern.h" +#include "utils/containers.h" + +namespace FlexFlow { + +size_t num_nodes(UnlabelledGraphPattern const &p) { + return num_nodes(p.raw_graph); +} + +bool is_singleton_pattern(UnlabelledGraphPattern const &pattern) { + return num_nodes(pattern) == 1; +} + +std::unordered_set get_nodes(UnlabelledGraphPattern const &p) { + return transform(get_nodes(p.raw_graph), + [](Node const &n) { return PatternNode{n}; }}); +} + +std::unordered_set get_edges(UnlabelledGraphPattern const &p) { + return transform(get_nodes(p.raw_graph), + [](OpenMultiDiEdge const &e) { return PatternEdge{e}; }}); +} + +std::vector get_topological_ordering(UnlabelledGraphPattern const &p) { + return transform(get_topological_ordering(p), + [](Node const &n) { return PatternNode{n}; }}); +} + +UnlabelledGraphPattern get_subgraph(UnlabelledGraphPattern const &p, std::unordered_set const &n) { + return { + get_subgraph(p.raw_graph, transform(n, [](PatternNode const &n) { return n.raw_node; })); + }; +} + +std::unordered_set get_incoming_edges(UnlabelledGraphPattern const &p, PatternNode const &n) { + return transform(get_incoming_edges(p.raw_graph, n.raw_node), [](Node const &n) { return PatternNode{n}; }); +} + +std::unordered_set get_outgoing_edges(UnlabelledGraphPattern const &p, PatternNode const &n) { + return transform(get_outgoing_edges(p.raw_graph, n.raw_node), [](Node const &n) { return PatternNode{n}; }); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.dtg.cc new file mode 100644 index 0000000000..019209ee86 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.dtg.cc @@ -0,0 +1,18 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.struct.toml +/* proj-data +{ + "generated_from": "f494ed79eb1ba4010155e456b452157f" +} +*/ + +#include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" + +#include "utils/graph.h" + +namespace FlexFlow { +UnlabelledGraphPattern::UnlabelledGraphPattern( + ::FlexFlow::OpenMultiDiGraphView const &raw_graph) + : raw_graph(raw_graph) {} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/upward_open_pattern_edge.cc b/lib/substitutions/src/substitutions/unlabelled/upward_open_pattern_edge.cc new file mode 100644 index 0000000000..8664f3c66c --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/upward_open_pattern_edge.cc @@ -0,0 +1,9 @@ +#include "substitutions/unlabelled/upward_open_pattern_edge.h" + +namespace FlexFlow { + +int get_dst_idx(UpwardOpenPatternEdge const &e) { + return get_src_idx(e.raw_edge); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/upward_open_pattern_edge.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/upward_open_pattern_edge.dtg.cc new file mode 100644 index 0000000000..ca8dd6c020 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/upward_open_pattern_edge.dtg.cc @@ -0,0 +1,52 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.struct.toml +/* proj-data +{ + "generated_from": "a1d4c9d1dd94eb456c5e29d80ad579da" +} +*/ + +#include "substitutions/unlabelled/upward_open_pattern_edge.dtg.h" + +#include "utils/graph.h" + +namespace FlexFlow { +UpwardOpenPatternEdge::UpwardOpenPatternEdge( + ::FlexFlow::UpwardOpenMultiDiEdge const &raw_edge) + : raw_edge(raw_edge) {} +bool UpwardOpenPatternEdge::operator==( + UpwardOpenPatternEdge const &other) const { + return std::tie(this->raw_edge) == std::tie(other.raw_edge); +} +bool UpwardOpenPatternEdge::operator!=( + UpwardOpenPatternEdge const &other) const { + return std::tie(this->raw_edge) != std::tie(other.raw_edge); +} +bool UpwardOpenPatternEdge::operator<( + UpwardOpenPatternEdge const &other) const { + return std::tie(this->raw_edge) < std::tie(other.raw_edge); +} +bool UpwardOpenPatternEdge::operator>( + UpwardOpenPatternEdge const &other) const { + return std::tie(this->raw_edge) > std::tie(other.raw_edge); +} +bool UpwardOpenPatternEdge::operator<=( + UpwardOpenPatternEdge const &other) const { + return std::tie(this->raw_edge) <= std::tie(other.raw_edge); +} +bool UpwardOpenPatternEdge::operator>=( + UpwardOpenPatternEdge const &other) const { + return std::tie(this->raw_edge) >= std::tie(other.raw_edge); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::UpwardOpenPatternEdge const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::UpwardOpenMultiDiEdge>{}(x.raw_edge) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std diff --git a/lib/utils/include/utils/bidict.h b/lib/utils/include/utils/bidict.h index 0869b0f9e8..08c286d842 100644 --- a/lib/utils/include/utils/bidict.h +++ b/lib/utils/include/utils/bidict.h @@ -3,6 +3,7 @@ #include #include +#include "utils/fmt/unordered_map.h" namespace FlexFlow { @@ -55,9 +56,22 @@ struct bidict { bwd_map.insert({lr.second, lr.first}); } + bool operator==(bidict const &other) const { + bool result = this->fwd_map == other.fwd_map; + assert (result == (this->bwd_map == other.bwd_map)); + return result; + } + + bool operator!=(bidict const &other) const { + bool result = this->fwd_map != other.fwd_map; + assert (result == (this->bwd_map != other.bwd_map)); + return result; + } + R const &at_l(L const &l) const { return fwd_map.at(l); } + L const &at_r(R const &r) const { return bwd_map.at(r); } @@ -163,6 +177,22 @@ struct bidict { std::unordered_map bwd_map; }; +template +std::unordered_map format_as(bidict const &b) { + return b; +} + } // namespace FlexFlow +namespace std { + +template +struct hash<::FlexFlow::bidict> { + size_t operator()(::FlexFlow::bidict const &b) const { + return hash>{}(b); + } +}; + +} + #endif diff --git a/lib/utils/include/utils/check_fmtable.h b/lib/utils/include/utils/check_fmtable.h new file mode 100644 index 0000000000..3b4e55c459 --- /dev/null +++ b/lib/utils/include/utils/check_fmtable.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CHECK_FMTABLE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CHECK_FMTABLE_H + +#define CHECK_FMTABLE(...) \ + static_assert(::FlexFlow::is_fmtable<__VA_ARGS__>::value, \ + #__VA_ARGS__ " must be fmtable"); + +namespace FlexFlow { + +template +using is_fmtable = ::fmt::is_formattable; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers.decl.h b/lib/utils/include/utils/containers.decl.h index 0332a331b2..37b022106f 100644 --- a/lib/utils/include/utils/containers.decl.h +++ b/lib/utils/include/utils/containers.decl.h @@ -16,19 +16,6 @@ struct get_element_type; template using get_element_type_t = typename get_element_type::type; -template -std::string join_strings(InputIt first, - InputIt last, - std::string const &delimiter, - F const &f); - -template -std::string - join_strings(InputIt first, InputIt last, std::string const &delimiter); - -template -std::string join_strings(Container const &c, std::string const &delimiter); - template typename Container::const_iterator find(Container const &c, typename Container::value_type const &e); @@ -154,6 +141,9 @@ template > bidict generate_bidict(C const &c, F const &f); +template +std::optional at_idx(std::vector const &v, size_t idx); + template std::function lookup_in(std::unordered_map const &m); @@ -208,6 +198,10 @@ void extend(C &lhs, std::optional const &e); template bool all_of(C const &c, F const &f); +template +std::optional optional_all_of(Container const &, + Function const &); + template int count(C const &c, F const &f); diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index 1606eb0605..ed4e4c0a2b 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -21,38 +21,6 @@ namespace FlexFlow { -template -std::string join_strings(InputIt first, - InputIt last, - std::string const &delimiter, - F const &f) { - std::ostringstream oss; - bool first_iter = true; - /* int i = 0; */ - for (; first != last; first++) { - if (!first_iter) { - oss << delimiter; - } - oss << *first; - /* break; */ - first_iter = false; - /* i++; */ - } - return oss.str(); -} - -template -std::string - join_strings(InputIt first, InputIt last, std::string const &delimiter) { - using Ref = typename InputIt::reference; - return join_strings(first, last, delimiter, [](Ref r) { return r; }); -} - -template -std::string join_strings(Container const &c, std::string const &delimiter) { - return join_strings(c.cbegin(), c.cend(), delimiter); -} - template typename Container::const_iterator find(Container const &c, typename Container::value_type const &e) { @@ -346,6 +314,15 @@ bidict generate_bidict(C const &c, F const &f) { return {transformed.cbegin(), transformed.cend()}; } +template +std::optional at_idx(std::vector const &v, size_t idx) { + if (idx >= v.size()) { + return std::nullopt; + } else { + return v.at(idx); + } +} + template std::function lookup_in(std::unordered_map const &m) { return [&m](K const &k) -> V { return m.at(k); }; @@ -468,6 +445,22 @@ bool all_of(C const &c, F const &f) { return true; } +template +std::optional optional_all_of(Container const &container, + Function const &func) { + for (auto const &element : container) { + std::optional condition = func(element); + if (!condition.has_value()) { + return std::nullopt; + } + + if (!condition.value()) { + return false; + } + } + return true; +} + template int count(C const &c, F const &f) { int result = 0; diff --git a/lib/utils/include/utils/fmt.decl.h b/lib/utils/include/utils/fmt.decl.h index eeebaf5d88..d38b36037b 100644 --- a/lib/utils/include/utils/fmt.decl.h +++ b/lib/utils/include/utils/fmt.decl.h @@ -5,10 +5,7 @@ #include #include #include - -#define CHECK_FMTABLE(...) \ - static_assert(::FlexFlow::is_fmtable<__VA_ARGS__>::value, \ - #__VA_ARGS__ " must be fmtable"); +#include "utils/check_fmtable.h" #define DELEGATE_OSTREAM(...) \ template <> \ @@ -16,9 +13,6 @@ namespace FlexFlow { -template -using is_fmtable = ::fmt::is_formattable; - template struct delegate_ostream_operator : std::false_type {}; @@ -31,15 +25,26 @@ typename std::enable_if>::value, namespace fmt { -template -struct formatter<::std::unordered_set> : formatter<::std::string> { +template +struct formatter< + ::std::unordered_set, + Char, + std::enable_if_t>::value> +> : formatter<::std::string, Char> { template auto format(::std::unordered_set const &m, FormatContext &ctx) -> decltype(ctx.out()); }; -template -struct formatter<::std::vector> : formatter<::std::string> { +/* template */ +/* std::string format_as(::std::unordered_set const &); */ + +template +struct formatter< + ::std::vector, + Char, + std::enable_if_t>::value> +> : formatter<::std::string> { template auto format(::std::vector const &m, FormatContext &ctx) -> decltype(ctx.out()); diff --git a/lib/utils/include/utils/fmt.h b/lib/utils/include/utils/fmt.h index fe1a2ca979..c976174ded 100644 --- a/lib/utils/include/utils/fmt.h +++ b/lib/utils/include/utils/fmt.h @@ -7,47 +7,41 @@ #include "utils/type_traits_core.h" #include #include - -namespace FlexFlow { - -template -struct delegate_ostream_operator> : std::true_type {}; - -template -struct delegate_ostream_operator> : std::true_type {}; - -template -struct delegate_ostream_operator> : std::true_type {}; - -template -typename std::enable_if>::value, - std::ostream &>::type - operator<<(std::ostream &s, T t) { - CHECK_FMTABLE(T); - - return s << fmt::to_string(t); -} - -} // namespace FlexFlow +#include +#include namespace fmt { -template +template template -auto formatter<::std::unordered_set>::format( +auto formatter< + ::std::unordered_set, + Char, + std::enable_if_t>::value> +>::format( ::std::unordered_set const &m, FormatContext &ctx) -> decltype(ctx.out()) { - CHECK_FMTABLE(T); + /* CHECK_FMTABLE(T); */ - std::string result = ::FlexFlow::join_strings( - m.cbegin(), m.cend(), ", ", [](T const &t) { return fmt::to_string(t); }); + /* std::string result = ::FlexFlow::join_strings( */ + /* m.cbegin(), m.cend(), ", ", [](T const &t) { return fmt::to_string(t); }); */ + std::string result = ""; return formatter::format(result, ctx); } -template +/* template */ +/* std::string format_as(::std::unordered_set const &m) { */ +/* return::string result = ::FlexFlow::join_strings( */ +/* m.cbegin(), m.cend(), ", ", [](T const &t) { return fmt::to_string(t); }); */ +/* } */ + +template template -auto formatter<::std::vector>::format(::std::vector const &m, - FormatContext &ctx) +auto formatter< + ::std::vector, + Char, + std::enable_if_t>::value> +>::format(::std::vector const &m, FormatContext &ctx) -> decltype(ctx.out()) { CHECK_FMTABLE(T); @@ -68,4 +62,32 @@ auto formatter<::std::variant>::format(::std::variant const &m, } // namespace fmt +namespace FlexFlow { + +template +struct delegate_ostream_operator> : std::true_type {}; + +template +struct delegate_ostream_operator> : std::true_type {}; + +template +struct delegate_ostream_operator> : std::true_type {}; + +template +struct delegate_ostream_operator> : std::true_type {}; + +template +struct delegate_ostream_operator> : std::true_type {}; + +template +typename std::enable_if>::value, + std::ostream &>::type + operator<<(std::ostream &s, T t) { + CHECK_FMTABLE(T); + + return s << fmt::to_string(t); +} + +} // namespace FlexFlow + #endif diff --git a/lib/utils/include/utils/fmt/unordered_map.h b/lib/utils/include/utils/fmt/unordered_map.h new file mode 100644 index 0000000000..8c9125e35a --- /dev/null +++ b/lib/utils/include/utils/fmt/unordered_map.h @@ -0,0 +1,43 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_UNORDERED_MAP_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_UNORDERED_MAP_H + +#include "fmt/format.h" +#include +#include "utils/join_strings.h" +#include "utils/check_fmtable.h" + +namespace fmt { + +template +struct formatter< + ::std::unordered_map, + Char, + std::enable_if_t>::value> +> : formatter<::std::string> { + template + auto format(::std::unordered_map const &m, FormatContext &ctx) + -> decltype(ctx.out()) { + CHECK_FMTABLE(K); + CHECK_FMTABLE(V); + + std::string result = ::FlexFlow::join_strings( + m.cbegin(), m.cend(), ", ", [](std::pair const &p) { return fmt::to_string(p); }); + return formatter::format(result, ctx); + } +}; + +} // namespace fmt + +namespace FlexFlow { + +template +std::ostream &operator<<(std::ostream &s, std::unordered_map const &m) { + CHECK_FMTABLE(K); + CHECK_FMTABLE(V); + + return s << fmt::to_string(m); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/join_strings.h b/lib/utils/include/utils/join_strings.h new file mode 100644 index 0000000000..9c761fc9ac --- /dev/null +++ b/lib/utils/include/utils/join_strings.h @@ -0,0 +1,44 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JOIN_STRINGS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JOIN_STRINGS_H + +#include +#include + +namespace FlexFlow { + +template +std::string join_strings(InputIt first, + InputIt last, + std::string const &delimiter, + F const &f) { + std::ostringstream oss; + bool first_iter = true; + /* int i = 0; */ + for (; first != last; first++) { + if (!first_iter) { + oss << delimiter; + } + oss << *first; + /* break; */ + first_iter = false; + /* i++; */ + } + return oss.str(); +} + +template +std::string + join_strings(InputIt first, InputIt last, std::string const &delimiter) { + using Ref = typename InputIt::reference; + return join_strings(first, last, delimiter, [](Ref r) { return r; }); +} + +template +std::string join_strings(Container const &c, std::string const &delimiter) { + return join_strings(c.cbegin(), c.cend(), delimiter); +} + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/json.h b/lib/utils/include/utils/json.h index 1bf86f0cf7..46176f366e 100644 --- a/lib/utils/include/utils/json.h +++ b/lib/utils/include/utils/json.h @@ -143,8 +143,7 @@ struct VariantToJsonFunctor { void operator()(T const &t) { static_assert(is_jsonable::value, ""); - j["type"] = get_name(t); - j["value"] = t; + j = t; } }; diff --git a/lib/utils/include/utils/optional.h b/lib/utils/include/utils/optional.h index 71b6d9d975..6133a27832 100644 --- a/lib/utils/include/utils/optional.h +++ b/lib/utils/include/utils/optional.h @@ -27,8 +27,12 @@ T const &assert_unwrap(std::optional const &o) { namespace fmt { -template -struct formatter<::std::optional> : formatter { +template +struct formatter< + ::std::optional, + Char, + std::enable_if_t>::value> +> : formatter { template auto format(::std::optional const &q, FormatContext &ctx) -> decltype(ctx.out()) { diff --git a/lib/utils/include/utils/overload.h b/lib/utils/include/utils/overload.h new file mode 100644 index 0000000000..7e0431eba5 --- /dev/null +++ b/lib/utils/include/utils/overload.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_OVERLOAD_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_OVERLOAD_H + +namespace FlexFlow { + +template +struct overload : Ts... { + using Ts::operator()...; +}; +template +overload(Ts...) -> overload; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/stack_string.h b/lib/utils/include/utils/stack_string.h index 71b092d2c1..fccbbb3810 100644 --- a/lib/utils/include/utils/stack_string.h +++ b/lib/utils/include/utils/stack_string.h @@ -7,6 +7,7 @@ #include "utils/type_traits.h" #include #include +#include "utils/json.h" namespace FlexFlow { @@ -64,6 +65,20 @@ struct stack_basic_string { template using stack_string = stack_basic_string; +template +void to_json(json &j, stack_string const &v) { + std::string as_string = v; + j = as_string; +} + +template +void from_json(json const &j, stack_string &v) { + std::string as_string; + j.get_to(as_string); + v = stack_string{as_string}; +} + + } // namespace FlexFlow namespace std { diff --git a/lib/utils/src/utils/overload.cc b/lib/utils/src/utils/overload.cc new file mode 100644 index 0000000000..55bfbdc08d --- /dev/null +++ b/lib/utils/src/utils/overload.cc @@ -0,0 +1 @@ +#include "utils/overload.h" From 55c7cf17a47abe62a5efc93ccb40a3a3218e1d96 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Thu, 25 Apr 2024 03:41:10 -0700 Subject: [PATCH 11/43] Add new reduction dim shape inference for conv2d --- lib/op-attrs/include/op-attrs/as_dot.h | 15 + .../op-attrs/computation_graph_op_attrs.dtg.h | 449 +++++++++++++++ .../computation_graph_op_attrs.variant.toml | 142 +++++ lib/op-attrs/include/op-attrs/dim_ordered.h | 12 +- .../include/op-attrs/operator_attrs.h | 130 +++-- lib/op-attrs/include/op-attrs/ops/conv_2d.h | 5 + .../ops/parallel_attention_inputs.dtg.h | 6 +- .../ops/parallel_attention_inputs.struct.toml | 2 +- .../include/op-attrs/parallel_dim.dtg.h | 111 +++- .../op-attrs/parallel_dim.variant.toml | 22 + .../op-attrs/parallel_tensor_dims.dtg.h | 18 +- .../include/op-attrs/parallel_tensor_dims.h | 19 +- .../op-attrs/parallel_tensor_dims.struct.toml | 16 +- .../op-attrs/parallel_tensor_shape.dtg.h | 6 +- .../include/op-attrs/parallel_tensor_shape.h | 9 +- .../parallel_tensor_shape.struct.toml | 2 +- .../include/op-attrs/pcg_operator_attrs.dtg.h | 190 ++++--- .../include/op-attrs/pcg_operator_attrs.h | 12 + .../op-attrs/pcg_operator_attrs.variant.toml | 42 +- .../op-attrs/replica_parallel_dim.dtg.h | 65 +++ .../include/op-attrs/replica_parallel_dim.h | 12 + .../op-attrs/replica_parallel_dim.struct.toml | 22 + .../op-attrs/replica_parallel_dim_set.dtg.h | 63 +++ .../op-attrs/replica_parallel_dim_set.h | 17 + .../replica_parallel_dim_set.struct.toml | 18 + .../include/op-attrs/replica_type.dtg.h | 40 ++ .../include/op-attrs/replica_type.enum.toml | 14 + .../include/op-attrs/shard_parallel_dim.dtg.h | 63 +++ .../include/op-attrs/shard_parallel_dim.h | 12 + ...ct.toml => shard_parallel_dim.struct.toml} | 6 +- .../include/op-attrs/tensor_shape.dtg.h | 6 +- lib/op-attrs/include/op-attrs/tensor_shape.h | 1 + .../include/op-attrs/tensor_shape.struct.toml | 4 +- lib/op-attrs/src/batch_norm.cc | 3 - lib/op-attrs/src/conv_2d.cc | 115 ---- lib/op-attrs/src/linear.cc | 3 - lib/op-attrs/src/noop.cc | 1 - lib/op-attrs/src/op-attrs/as_dot.cc | 13 + .../computation_graph_op_attrs.dtg.cc | 521 ++++++++++++++++++ lib/op-attrs/src/op-attrs/ops/attention.cc | 7 +- lib/op-attrs/src/{ => op-attrs/ops}/cast.cc | 0 .../src/{ => op-attrs/ops}/combine.cc | 0 lib/op-attrs/src/{ => op-attrs/ops}/concat.cc | 0 lib/op-attrs/src/op-attrs/ops/conv_2d.cc | 203 +++++++ .../op-attrs/ops/conv_2d_input_shape.dtg.cc | 157 ++++++ .../op-attrs/ops/conv_2d_input_shape.dtg.h | 72 +++ .../ops/conv_2d_input_shape.struct.toml | 35 ++ .../ops/conv_2d_parallel_input_shape.dtg.cc | 211 +++++++ .../ops/conv_2d_parallel_input_shape.dtg.h | 76 +++ .../conv_2d_parallel_input_shape.struct.toml | 43 ++ .../src/{ => op-attrs/ops}/element_binary.cc | 0 .../src/{ => op-attrs/ops}/element_unary.cc | 0 lib/op-attrs/src/{ => op-attrs/ops}/flat.cc | 0 lib/op-attrs/src/{ => op-attrs/ops}/gather.cc | 0 .../ops/parallel_attention_inputs.dtg.cc | 22 +- lib/op-attrs/src/op-attrs/ops/pool_2d.cc | 48 ++ lib/op-attrs/src/op-attrs/ops/reduction.cc | 8 + lib/op-attrs/src/op-attrs/ops/repartition.cc | 6 + lib/op-attrs/src/op-attrs/parallel_dim.dtg.cc | 123 +++-- .../src/op-attrs/parallel_tensor_dims.cc | 51 +- .../src/op-attrs/parallel_tensor_dims.dtg.cc | 49 +- .../src/op-attrs/parallel_tensor_shape.cc | 29 +- .../src/op-attrs/parallel_tensor_shape.dtg.cc | 18 +- .../src/op-attrs/pcg_operator_attrs.cc | 12 + .../src/op-attrs/pcg_operator_attrs.dtg.cc | 298 ++++++---- .../src/op-attrs/replica_parallel_dim.cc | 9 + .../src/op-attrs/replica_parallel_dim.dtg.cc | 91 +++ .../src/op-attrs/replica_parallel_dim_set.cc | 32 ++ .../op-attrs/replica_parallel_dim_set.dtg.cc | 96 ++++ lib/op-attrs/src/op-attrs/replica_type.dtg.cc | 70 +++ .../src/op-attrs/shard_parallel_dim.cc | 9 + .../src/op-attrs/shard_parallel_dim.dtg.cc | 89 +++ lib/op-attrs/src/op-attrs/tensor_dims.cc | 10 +- .../src/{ => op-attrs}/tensor_shape.cc | 5 + lib/op-attrs/src/op-attrs/tensor_shape.dtg.cc | 6 +- lib/op-attrs/src/parallel_dim.cc | 12 - lib/op-attrs/src/pool_2d.cc | 47 -- lib/op-attrs/src/reduce.cc | 3 - lib/op-attrs/src/reduction.cc | 13 - lib/op-attrs/src/repartition.cc | 11 - lib/op-attrs/src/replicate.cc | 3 - lib/op-attrs/src/reshape.cc | 3 - lib/op-attrs/src/softmax.cc | 3 - lib/op-attrs/src/split.cc | 3 - lib/op-attrs/src/topk.cc | 3 - lib/op-attrs/src/transpose.cc | 3 - lib/op-attrs/test/src/test_conv_2d.cc | 67 +++ lib/op-attrs/test/src/test_operator_attrs.cc | 17 +- lib/utils/include/utils/exception.decl.h | 6 +- lib/utils/include/utils/fmt.decl.h | 10 + lib/utils/include/utils/fmt.h | 20 +- lib/utils/include/utils/fmt/pair.h | 20 + lib/utils/include/utils/fmt/unordered_map.h | 9 +- lib/utils/src/exception.cc | 4 +- 94 files changed, 3626 insertions(+), 723 deletions(-) create mode 100644 lib/op-attrs/include/op-attrs/as_dot.h create mode 100644 lib/op-attrs/include/op-attrs/computation_graph_op_attrs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/computation_graph_op_attrs.variant.toml create mode 100644 lib/op-attrs/include/op-attrs/parallel_dim.variant.toml create mode 100644 lib/op-attrs/include/op-attrs/pcg_operator_attrs.h create mode 100644 lib/op-attrs/include/op-attrs/replica_parallel_dim.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/replica_parallel_dim.h create mode 100644 lib/op-attrs/include/op-attrs/replica_parallel_dim.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/replica_parallel_dim_set.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/replica_parallel_dim_set.h create mode 100644 lib/op-attrs/include/op-attrs/replica_parallel_dim_set.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/replica_type.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/replica_type.enum.toml create mode 100644 lib/op-attrs/include/op-attrs/shard_parallel_dim.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/shard_parallel_dim.h rename lib/op-attrs/include/op-attrs/{parallel_dim.struct.toml => shard_parallel_dim.struct.toml} (72%) delete mode 100644 lib/op-attrs/src/batch_norm.cc delete mode 100644 lib/op-attrs/src/conv_2d.cc delete mode 100644 lib/op-attrs/src/linear.cc delete mode 100644 lib/op-attrs/src/noop.cc create mode 100644 lib/op-attrs/src/op-attrs/as_dot.cc create mode 100644 lib/op-attrs/src/op-attrs/computation_graph_op_attrs.dtg.cc rename lib/op-attrs/src/{ => op-attrs/ops}/cast.cc (100%) rename lib/op-attrs/src/{ => op-attrs/ops}/combine.cc (100%) rename lib/op-attrs/src/{ => op-attrs/ops}/concat.cc (100%) create mode 100644 lib/op-attrs/src/op-attrs/ops/conv_2d.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/conv_2d_input_shape.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/conv_2d_input_shape.dtg.h create mode 100644 lib/op-attrs/src/op-attrs/ops/conv_2d_input_shape.struct.toml create mode 100644 lib/op-attrs/src/op-attrs/ops/conv_2d_parallel_input_shape.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/conv_2d_parallel_input_shape.dtg.h create mode 100644 lib/op-attrs/src/op-attrs/ops/conv_2d_parallel_input_shape.struct.toml rename lib/op-attrs/src/{ => op-attrs/ops}/element_binary.cc (100%) rename lib/op-attrs/src/{ => op-attrs/ops}/element_unary.cc (100%) rename lib/op-attrs/src/{ => op-attrs/ops}/flat.cc (100%) rename lib/op-attrs/src/{ => op-attrs/ops}/gather.cc (100%) create mode 100644 lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc create mode 100644 lib/op-attrs/src/op-attrs/replica_parallel_dim.cc create mode 100644 lib/op-attrs/src/op-attrs/replica_parallel_dim.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/replica_parallel_dim_set.cc create mode 100644 lib/op-attrs/src/op-attrs/replica_parallel_dim_set.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/replica_type.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/shard_parallel_dim.cc create mode 100644 lib/op-attrs/src/op-attrs/shard_parallel_dim.dtg.cc rename lib/op-attrs/src/{ => op-attrs}/tensor_shape.cc (60%) delete mode 100644 lib/op-attrs/src/parallel_dim.cc delete mode 100644 lib/op-attrs/src/pool_2d.cc delete mode 100644 lib/op-attrs/src/reduce.cc delete mode 100644 lib/op-attrs/src/reduction.cc delete mode 100644 lib/op-attrs/src/repartition.cc delete mode 100644 lib/op-attrs/src/replicate.cc delete mode 100644 lib/op-attrs/src/reshape.cc delete mode 100644 lib/op-attrs/src/softmax.cc delete mode 100644 lib/op-attrs/src/split.cc delete mode 100644 lib/op-attrs/src/topk.cc delete mode 100644 lib/op-attrs/src/transpose.cc create mode 100644 lib/op-attrs/test/src/test_conv_2d.cc create mode 100644 lib/utils/include/utils/fmt/pair.h diff --git a/lib/op-attrs/include/op-attrs/as_dot.h b/lib/op-attrs/include/op-attrs/as_dot.h new file mode 100644 index 0000000000..4a5d9eaeb1 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/as_dot.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_AS_DOT_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_AS_DOT_H + +#include "op-attrs/pcg_operator_attrs.dtg.h" +#include "op-attrs/computation_graph_op_attrs.dtg.h" +#include "utils/record_formatter.h" + +namespace FlexFlow { + +RecordFormatter as_dot(ComputationGraphOpAttrs const &); +RecordFormatter as_dot(PCGOperatorAttrs const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.dtg.h b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.dtg.h new file mode 100644 index 0000000000..412bd7aea0 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.dtg.h @@ -0,0 +1,449 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/computation_graph_op_attrs.variant.toml +/* proj-data +{ + "generated_from": "87653647c900faaf564d3069478569e7" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_COMPUTATION_GRAPH_OP_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_COMPUTATION_GRAPH_OP_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/ops/attention_attrs.dtg.h" +#include "op-attrs/ops/batch_matmul.dtg.h" +#include "op-attrs/ops/batch_norm_attrs.dtg.h" +#include "op-attrs/ops/broadcast.dtg.h" +#include "op-attrs/ops/cast_attrs.dtg.h" +#include "op-attrs/ops/concat_attrs.dtg.h" +#include "op-attrs/ops/conv_2d_attrs.dtg.h" +#include "op-attrs/ops/dropout_attrs.dtg.h" +#include "op-attrs/ops/element_binary_attrs.dtg.h" +#include "op-attrs/ops/element_scalar_unary_attrs.dtg.h" +#include "op-attrs/ops/element_unary_attrs.dtg.h" +#include "op-attrs/ops/embedding_attrs.dtg.h" +#include "op-attrs/ops/flat_attrs.dtg.h" +#include "op-attrs/ops/gather_attrs.dtg.h" +#include "op-attrs/ops/input_attrs.dtg.h" +#include "op-attrs/ops/layer_norm_attrs.dtg.h" +#include "op-attrs/ops/linear_attrs.dtg.h" +#include "op-attrs/ops/noop_attrs.dtg.h" +#include "op-attrs/ops/pool_2d_attrs.dtg.h" +#include "op-attrs/ops/reduce_attrs.dtg.h" +#include "op-attrs/ops/reshape_attrs.dtg.h" +#include "op-attrs/ops/reverse_attrs.dtg.h" +#include "op-attrs/ops/softmax_attrs.dtg.h" +#include "op-attrs/ops/split_attrs.dtg.h" +#include "op-attrs/ops/topk_attrs.dtg.h" +#include "op-attrs/ops/transpose_attrs.dtg.h" +#include +#include +#include +#include +#include + +namespace FlexFlow { +struct ComputationGraphOpAttrs { + ComputationGraphOpAttrs() = delete; + explicit ComputationGraphOpAttrs(::FlexFlow::BatchMatmulAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::BatchNormAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::BroadcastAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::CastAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::ConcatAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::Conv2DAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::DropoutAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::ElementBinaryAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::ElementUnaryAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::ElementScalarUnaryAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::EmbeddingAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::FlatAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::GatherAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::InputAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::LayerNormAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::LinearAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::MultiHeadAttentionAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::NoopAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::Pool2DAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::ReduceAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::ReverseAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::ReshapeAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::SplitAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::SoftmaxAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::TopKAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::TransposeAttrs const &); + template + static constexpr bool IsPartOfComputationGraphOpAttrs_v = + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v; + template + ReturnType visit(Visitor &&v) const { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::BatchMatmulAttrs>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::BatchNormAttrs>()); + return result; + } + case 2: { + ReturnType result = v(this->get<::FlexFlow::BroadcastAttrs>()); + return result; + } + case 3: { + ReturnType result = v(this->get<::FlexFlow::CastAttrs>()); + return result; + } + case 4: { + ReturnType result = v(this->get<::FlexFlow::ConcatAttrs>()); + return result; + } + case 5: { + ReturnType result = v(this->get<::FlexFlow::Conv2DAttrs>()); + return result; + } + case 6: { + ReturnType result = v(this->get<::FlexFlow::DropoutAttrs>()); + return result; + } + case 7: { + ReturnType result = v(this->get<::FlexFlow::ElementBinaryAttrs>()); + return result; + } + case 8: { + ReturnType result = v(this->get<::FlexFlow::ElementUnaryAttrs>()); + return result; + } + case 9: { + ReturnType result = v(this->get<::FlexFlow::ElementScalarUnaryAttrs>()); + return result; + } + case 10: { + ReturnType result = v(this->get<::FlexFlow::EmbeddingAttrs>()); + return result; + } + case 11: { + ReturnType result = v(this->get<::FlexFlow::FlatAttrs>()); + return result; + } + case 12: { + ReturnType result = v(this->get<::FlexFlow::GatherAttrs>()); + return result; + } + case 13: { + ReturnType result = v(this->get<::FlexFlow::InputAttrs>()); + return result; + } + case 14: { + ReturnType result = v(this->get<::FlexFlow::LayerNormAttrs>()); + return result; + } + case 15: { + ReturnType result = v(this->get<::FlexFlow::LinearAttrs>()); + return result; + } + case 16: { + ReturnType result = v(this->get<::FlexFlow::MultiHeadAttentionAttrs>()); + return result; + } + case 17: { + ReturnType result = v(this->get<::FlexFlow::NoopAttrs>()); + return result; + } + case 18: { + ReturnType result = v(this->get<::FlexFlow::Pool2DAttrs>()); + return result; + } + case 19: { + ReturnType result = v(this->get<::FlexFlow::ReduceAttrs>()); + return result; + } + case 20: { + ReturnType result = v(this->get<::FlexFlow::ReverseAttrs>()); + return result; + } + case 21: { + ReturnType result = v(this->get<::FlexFlow::ReshapeAttrs>()); + return result; + } + case 22: { + ReturnType result = v(this->get<::FlexFlow::SplitAttrs>()); + return result; + } + case 23: { + ReturnType result = v(this->get<::FlexFlow::SoftmaxAttrs>()); + return result; + } + case 24: { + ReturnType result = v(this->get<::FlexFlow::TopKAttrs>()); + return result; + } + case 25: { + ReturnType result = v(this->get<::FlexFlow::TransposeAttrs>()); + return result; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type ComputationGraphOpAttrs", + this->index())); + } + } + } + template + ReturnType visit(Visitor &&v) { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::BatchMatmulAttrs>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::BatchNormAttrs>()); + return result; + } + case 2: { + ReturnType result = v(this->get<::FlexFlow::BroadcastAttrs>()); + return result; + } + case 3: { + ReturnType result = v(this->get<::FlexFlow::CastAttrs>()); + return result; + } + case 4: { + ReturnType result = v(this->get<::FlexFlow::ConcatAttrs>()); + return result; + } + case 5: { + ReturnType result = v(this->get<::FlexFlow::Conv2DAttrs>()); + return result; + } + case 6: { + ReturnType result = v(this->get<::FlexFlow::DropoutAttrs>()); + return result; + } + case 7: { + ReturnType result = v(this->get<::FlexFlow::ElementBinaryAttrs>()); + return result; + } + case 8: { + ReturnType result = v(this->get<::FlexFlow::ElementUnaryAttrs>()); + return result; + } + case 9: { + ReturnType result = v(this->get<::FlexFlow::ElementScalarUnaryAttrs>()); + return result; + } + case 10: { + ReturnType result = v(this->get<::FlexFlow::EmbeddingAttrs>()); + return result; + } + case 11: { + ReturnType result = v(this->get<::FlexFlow::FlatAttrs>()); + return result; + } + case 12: { + ReturnType result = v(this->get<::FlexFlow::GatherAttrs>()); + return result; + } + case 13: { + ReturnType result = v(this->get<::FlexFlow::InputAttrs>()); + return result; + } + case 14: { + ReturnType result = v(this->get<::FlexFlow::LayerNormAttrs>()); + return result; + } + case 15: { + ReturnType result = v(this->get<::FlexFlow::LinearAttrs>()); + return result; + } + case 16: { + ReturnType result = v(this->get<::FlexFlow::MultiHeadAttentionAttrs>()); + return result; + } + case 17: { + ReturnType result = v(this->get<::FlexFlow::NoopAttrs>()); + return result; + } + case 18: { + ReturnType result = v(this->get<::FlexFlow::Pool2DAttrs>()); + return result; + } + case 19: { + ReturnType result = v(this->get<::FlexFlow::ReduceAttrs>()); + return result; + } + case 20: { + ReturnType result = v(this->get<::FlexFlow::ReverseAttrs>()); + return result; + } + case 21: { + ReturnType result = v(this->get<::FlexFlow::ReshapeAttrs>()); + return result; + } + case 22: { + ReturnType result = v(this->get<::FlexFlow::SplitAttrs>()); + return result; + } + case 23: { + ReturnType result = v(this->get<::FlexFlow::SoftmaxAttrs>()); + return result; + } + case 24: { + ReturnType result = v(this->get<::FlexFlow::TopKAttrs>()); + return result; + } + case 25: { + ReturnType result = v(this->get<::FlexFlow::TransposeAttrs>()); + return result; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type ComputationGraphOpAttrs", + this->index())); + } + } + } + template + bool has() const { + static_assert( + IsPartOfComputationGraphOpAttrs_v, + "ComputationGraphOpAttrs::has() expected one of " + "[::FlexFlow::BatchMatmulAttrs, ::FlexFlow::BatchNormAttrs, " + "::FlexFlow::BroadcastAttrs, ::FlexFlow::CastAttrs, " + "::FlexFlow::ConcatAttrs, ::FlexFlow::Conv2DAttrs, " + "::FlexFlow::DropoutAttrs, ::FlexFlow::ElementBinaryAttrs, " + "::FlexFlow::ElementUnaryAttrs, ::FlexFlow::ElementScalarUnaryAttrs, " + "::FlexFlow::EmbeddingAttrs, ::FlexFlow::FlatAttrs, " + "::FlexFlow::GatherAttrs, ::FlexFlow::InputAttrs, " + "::FlexFlow::LayerNormAttrs, ::FlexFlow::LinearAttrs, " + "::FlexFlow::MultiHeadAttentionAttrs, ::FlexFlow::NoopAttrs, " + "::FlexFlow::Pool2DAttrs, ::FlexFlow::ReduceAttrs, " + "::FlexFlow::ReverseAttrs, ::FlexFlow::ReshapeAttrs, " + "::FlexFlow::SplitAttrs, ::FlexFlow::SoftmaxAttrs, " + "::FlexFlow::TopKAttrs, ::FlexFlow::TransposeAttrs], received T"); + return std::holds_alternative(this->raw_variant); + } + template + T const &get() const { + static_assert( + IsPartOfComputationGraphOpAttrs_v, + "ComputationGraphOpAttrs::get() expected one of " + "[::FlexFlow::BatchMatmulAttrs, ::FlexFlow::BatchNormAttrs, " + "::FlexFlow::BroadcastAttrs, ::FlexFlow::CastAttrs, " + "::FlexFlow::ConcatAttrs, ::FlexFlow::Conv2DAttrs, " + "::FlexFlow::DropoutAttrs, ::FlexFlow::ElementBinaryAttrs, " + "::FlexFlow::ElementUnaryAttrs, ::FlexFlow::ElementScalarUnaryAttrs, " + "::FlexFlow::EmbeddingAttrs, ::FlexFlow::FlatAttrs, " + "::FlexFlow::GatherAttrs, ::FlexFlow::InputAttrs, " + "::FlexFlow::LayerNormAttrs, ::FlexFlow::LinearAttrs, " + "::FlexFlow::MultiHeadAttentionAttrs, ::FlexFlow::NoopAttrs, " + "::FlexFlow::Pool2DAttrs, ::FlexFlow::ReduceAttrs, " + "::FlexFlow::ReverseAttrs, ::FlexFlow::ReshapeAttrs, " + "::FlexFlow::SplitAttrs, ::FlexFlow::SoftmaxAttrs, " + "::FlexFlow::TopKAttrs, ::FlexFlow::TransposeAttrs], received T"); + return std::get(this->raw_variant); + } + template + T &get() { + static_assert( + IsPartOfComputationGraphOpAttrs_v, + "ComputationGraphOpAttrs::get() expected one of " + "[::FlexFlow::BatchMatmulAttrs, ::FlexFlow::BatchNormAttrs, " + "::FlexFlow::BroadcastAttrs, ::FlexFlow::CastAttrs, " + "::FlexFlow::ConcatAttrs, ::FlexFlow::Conv2DAttrs, " + "::FlexFlow::DropoutAttrs, ::FlexFlow::ElementBinaryAttrs, " + "::FlexFlow::ElementUnaryAttrs, ::FlexFlow::ElementScalarUnaryAttrs, " + "::FlexFlow::EmbeddingAttrs, ::FlexFlow::FlatAttrs, " + "::FlexFlow::GatherAttrs, ::FlexFlow::InputAttrs, " + "::FlexFlow::LayerNormAttrs, ::FlexFlow::LinearAttrs, " + "::FlexFlow::MultiHeadAttentionAttrs, ::FlexFlow::NoopAttrs, " + "::FlexFlow::Pool2DAttrs, ::FlexFlow::ReduceAttrs, " + "::FlexFlow::ReverseAttrs, ::FlexFlow::ReshapeAttrs, " + "::FlexFlow::SplitAttrs, ::FlexFlow::SoftmaxAttrs, " + "::FlexFlow::TopKAttrs, ::FlexFlow::TransposeAttrs], received T"); + return std::get(this->raw_variant); + } + size_t index() const { + return this->raw_variant.index(); + } + bool operator==(ComputationGraphOpAttrs const &) const; + bool operator!=(ComputationGraphOpAttrs const &) const; + bool operator<(ComputationGraphOpAttrs const &) const; + bool operator>(ComputationGraphOpAttrs const &) const; + bool operator<=(ComputationGraphOpAttrs const &) const; + bool operator>=(ComputationGraphOpAttrs const &) const; + std::variant<::FlexFlow::BatchMatmulAttrs, + ::FlexFlow::BatchNormAttrs, + ::FlexFlow::BroadcastAttrs, + ::FlexFlow::CastAttrs, + ::FlexFlow::ConcatAttrs, + ::FlexFlow::Conv2DAttrs, + ::FlexFlow::DropoutAttrs, + ::FlexFlow::ElementBinaryAttrs, + ::FlexFlow::ElementUnaryAttrs, + ::FlexFlow::ElementScalarUnaryAttrs, + ::FlexFlow::EmbeddingAttrs, + ::FlexFlow::FlatAttrs, + ::FlexFlow::GatherAttrs, + ::FlexFlow::InputAttrs, + ::FlexFlow::LayerNormAttrs, + ::FlexFlow::LinearAttrs, + ::FlexFlow::MultiHeadAttentionAttrs, + ::FlexFlow::NoopAttrs, + ::FlexFlow::Pool2DAttrs, + ::FlexFlow::ReduceAttrs, + ::FlexFlow::ReverseAttrs, + ::FlexFlow::ReshapeAttrs, + ::FlexFlow::SplitAttrs, + ::FlexFlow::SoftmaxAttrs, + ::FlexFlow::TopKAttrs, + ::FlexFlow::TransposeAttrs> + raw_variant; +}; +} // namespace FlexFlow +namespace std { +template <> +struct hash<::FlexFlow::ComputationGraphOpAttrs> { + size_t operator()(::FlexFlow::ComputationGraphOpAttrs const &) const; +}; +} // namespace std +namespace nlohmann { +template <> +struct adl_serializer<::FlexFlow::ComputationGraphOpAttrs> { + static ::FlexFlow::ComputationGraphOpAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::ComputationGraphOpAttrs const &); +}; +} // namespace nlohmann +namespace FlexFlow { +std::string format_as(::FlexFlow::ComputationGraphOpAttrs const &); +std::ostream &operator<<(std::ostream &, + ::FlexFlow::ComputationGraphOpAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_COMPUTATION_GRAPH_OP_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.variant.toml b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.variant.toml new file mode 100644 index 0000000000..fdf9702875 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.variant.toml @@ -0,0 +1,142 @@ +namespace = "FlexFlow" +name = "ComputationGraphOpAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "op-attrs/ops/attention_attrs.dtg.h", + "op-attrs/ops/batch_matmul.dtg.h", + "op-attrs/ops/batch_norm_attrs.dtg.h", + "op-attrs/ops/broadcast.dtg.h", + "op-attrs/ops/cast_attrs.dtg.h", + "op-attrs/ops/concat_attrs.dtg.h", + "op-attrs/ops/conv_2d_attrs.dtg.h", + "op-attrs/ops/dropout_attrs.dtg.h", + "op-attrs/ops/element_binary_attrs.dtg.h", + "op-attrs/ops/element_scalar_unary_attrs.dtg.h", + "op-attrs/ops/element_unary_attrs.dtg.h", + "op-attrs/ops/embedding_attrs.dtg.h", + "op-attrs/ops/flat_attrs.dtg.h", + "op-attrs/ops/gather_attrs.dtg.h", + "op-attrs/ops/input_attrs.dtg.h", + "op-attrs/ops/layer_norm_attrs.dtg.h", + "op-attrs/ops/linear_attrs.dtg.h", + "op-attrs/ops/noop_attrs.dtg.h", + "op-attrs/ops/pool_2d_attrs.dtg.h", + "op-attrs/ops/reduce_attrs.dtg.h", + "op-attrs/ops/reshape_attrs.dtg.h", + "op-attrs/ops/reverse_attrs.dtg.h", + "op-attrs/ops/softmax_attrs.dtg.h", + "op-attrs/ops/split_attrs.dtg.h", + "op-attrs/ops/topk_attrs.dtg.h", + "op-attrs/ops/transpose_attrs.dtg.h", +] + +[[values]] +type = "::FlexFlow::BatchMatmulAttrs" +key = "batch_matmul" + +[[values]] +type = "::FlexFlow::BatchNormAttrs" +key = "batch_norm" + +[[values]] +type = "::FlexFlow::BroadcastAttrs" +key = "broadcast" + +[[values]] +type = "::FlexFlow::CastAttrs" +key = "cast" + +[[values]] +type = "::FlexFlow::ConcatAttrs" +key = "concat" + +[[values]] +type = "::FlexFlow::Conv2DAttrs" +key = "conv2d" + +[[values]] +type = "::FlexFlow::DropoutAttrs" +key = "dropout" + +[[values]] +type = "::FlexFlow::ElementBinaryAttrs" +key = "element_binary" + +[[values]] +type = "::FlexFlow::ElementUnaryAttrs" +key = "element_unary" + +[[values]] +type = "::FlexFlow::ElementScalarUnaryAttrs" +key = "element_scalar_unary" + +[[values]] +type = "::FlexFlow::EmbeddingAttrs" +key = "embedding" + +[[values]] +type = "::FlexFlow::FlatAttrs" +key = "flat" + +[[values]] +type = "::FlexFlow::GatherAttrs" +key = "gather" + +[[values]] +type = "::FlexFlow::InputAttrs" +key = "input" + +[[values]] +type = "::FlexFlow::LayerNormAttrs" +key = "layer_norm" + +[[values]] +type = "::FlexFlow::LinearAttrs" +key = "linear" + +[[values]] +type = "::FlexFlow::MultiHeadAttentionAttrs" +key = "multi_head_attention" + +[[values]] +type = "::FlexFlow::NoopAttrs" +key = "noop" + +[[values]] +type = "::FlexFlow::Pool2DAttrs" +key = "pool2d" + +[[values]] +type = "::FlexFlow::ReduceAttrs" +key = "reduce" + +[[values]] +type = "::FlexFlow::ReverseAttrs" +key = "reverse" + +[[values]] +type = "::FlexFlow::ReshapeAttrs" +key = "reshape" + +[[values]] +type = "::FlexFlow::SplitAttrs" +key = "split" + +[[values]] +type = "::FlexFlow::SoftmaxAttrs" +key = "softmax" + +[[values]] +type = "::FlexFlow::TopKAttrs" +key = "topk" + +[[values]] +type = "::FlexFlow::TransposeAttrs" +key = "transpose" diff --git a/lib/op-attrs/include/op-attrs/dim_ordered.h b/lib/op-attrs/include/op-attrs/dim_ordered.h index d0e9ef9a4d..b843b96e70 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered.h @@ -126,7 +126,6 @@ struct DimOrdered { } friend struct ::std::hash; - private: stack_vector contents; }; @@ -134,7 +133,16 @@ struct DimOrdered { template using FFOrdered = DimOrdered; -/* CHECK_JSONABLE(FFOrdered); */ +template +std::string format_as(FFOrdered const &v) { + std::vector as_vec(v.cbegin(), v.cend()); + return fmt::format("", as_vec); +} + +template +std::ostream &operator<<(std::ostream &s, FFOrdered const &v) { + return (s << fmt::to_string(v)); +} template auto inner_to_outer(FFOrdered const &ff_ordered) diff --git a/lib/op-attrs/include/op-attrs/operator_attrs.h b/lib/op-attrs/include/op-attrs/operator_attrs.h index 1821839e5c..b96541d34f 100644 --- a/lib/op-attrs/include/op-attrs/operator_attrs.h +++ b/lib/op-attrs/include/op-attrs/operator_attrs.h @@ -38,76 +38,76 @@ namespace FlexFlow { -using SharedOperatorAttrs = std::variant; +/* using SharedOperatorAttrs = std::variant; */ -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); +/* static_assert(is_valid_opattr::value, ""); */ +/* static_assert(is_valid_opattr::value, ""); */ +/* static_assert(is_valid_opattr::value, ""); */ +/* static_assert(is_valid_opattr::value, ""); */ +/* static_assert(is_valid_opattr::value, ""); */ +/* static_assert(is_valid_opattr::value, ""); */ +/* static_assert(is_valid_opattr::value, ""); */ +/* static_assert(is_valid_opattr::value, ""); */ +/* static_assert(is_valid_opattr::value, ""); */ +/* static_assert(is_valid_opattr::value, ""); */ +/* static_assert(is_valid_opattr::value, ""); */ +/* static_assert(is_valid_opattr::value, ""); */ +/* static_assert(is_valid_opattr::value, ""); */ +/* static_assert(is_valid_opattr::value, ""); */ +/* static_assert(is_valid_opattr::value, ""); */ +/* static_assert(is_valid_opattr::value, ""); */ +/* static_assert(is_valid_opattr::value, ""); */ +/* static_assert(is_valid_opattr::value, ""); */ +/* static_assert(is_valid_opattr::value, ""); */ +/* static_assert(is_valid_opattr::value, ""); */ +/* static_assert(is_valid_opattr::value, ""); */ +/* static_assert(is_valid_opattr::value, ""); */ +/* static_assert(is_valid_opattr::value, ""); */ -using ParallelOperatorAttrs = std:: - variant; +/* using ParallelOperatorAttrs = std:: */ +/* variant; */ -using ComputationGraphAttrs = - variant_join>; -using CompGraphOperatorAttrs = ComputationGraphAttrs; +/* using ComputationGraphAttrs = */ +/* variant_join>; */ +/* using CompGraphOperatorAttrs = ComputationGraphAttrs; */ /* using PCGOperatorAttrs = */ /* variant_join; */ -static_assert(is_equal_comparable::value, - "ComputationGraphAttrs must support =="); -static_assert(elements_satisfy::value, - ""); -static_assert(is_neq_comparable::value, - "ComputationGraphAttrs must support !="); -static_assert(is_lt_comparable::value, - "ComputationGraphAttrs must support <"); -static_assert(is_hashable::value, - "ComputationGraphAttrs must be hashable"); +/* static_assert(is_equal_comparable::value, */ +/* "ComputationGraphAttrs must support =="); */ +/* static_assert(elements_satisfy::value, */ +/* ""); */ +/* static_assert(is_neq_comparable::value, */ +/* "ComputationGraphAttrs must support !="); */ +/* static_assert(is_lt_comparable::value, */ +/* "ComputationGraphAttrs must support <"); */ +/* static_assert(is_hashable::value, */ +/* "ComputationGraphAttrs must be hashable"); */ /* static_assert(is_equal_comparable::value, */ /* "PCGOperatorAttrs must support =="); */ @@ -121,14 +121,10 @@ static_assert(is_hashable::value, /* OperatorType get_op_type(CompGraphOperatorAttrs const &); */ /* OperatorType get_op_type(PCGOperatorAttrs const &); */ -RecordFormatter as_dot(CompGraphOperatorAttrs const &); -RecordFormatter as_dot(PCGOperatorAttrs const &); - std::vector get_output_shapes( PCGOperatorAttrs const &op_params, std::vector const &input_tensor_shapes); -bool is_parallel_op(PCGOperatorAttrs const &); bool is_valid(PCGOperatorAttrs const &, std::vector const &); diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d.h b/lib/op-attrs/include/op-attrs/ops/conv_2d.h index b75628cf8a..5edb680cc8 100644 --- a/lib/op-attrs/include/op-attrs/ops/conv_2d.h +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d.h @@ -12,6 +12,11 @@ CHECK_VALID_OP_ATTR(Conv2DAttrs); TensorShape get_kernel_shape(Conv2DAttrs const &attrs, TensorShape const &input); TensorShape get_bias_shape(Conv2DAttrs const &attrs, TensorShape const &input); +TensorShape get_output_shape(Conv2DAttrs const &attrs, TensorShape const &input); + +ParallelTensorShape get_kernel_shape(Conv2DAttrs const &attrs, ParallelTensorShape const &input_shape); +ParallelTensorShape get_bias_shape(Conv2DAttrs const &attrs, ParallelTensorShape const &input_shape); +ParallelTensorShape get_output_shape(Conv2DAttrs const &attrs, ParallelTensorShape const &input_shape); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.dtg.h b/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.dtg.h index 92154711e5..c3ee7782c4 100644 --- a/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.struct.toml /* proj-data { - "generated_from": "722d92014b31bffcd5ad45eda476d8b3" + "generated_from": "8d1e2a2d3852bfb59d8668d14d52c958" } */ @@ -27,10 +27,6 @@ struct ParallelMultiHeadAttentionInputs { bool operator==(ParallelMultiHeadAttentionInputs const &) const; bool operator!=(ParallelMultiHeadAttentionInputs const &) const; - bool operator<(ParallelMultiHeadAttentionInputs const &) const; - bool operator>(ParallelMultiHeadAttentionInputs const &) const; - bool operator<=(ParallelMultiHeadAttentionInputs const &) const; - bool operator>=(ParallelMultiHeadAttentionInputs const &) const; ::FlexFlow::ParallelTensorShape query; ::FlexFlow::ParallelTensorShape key; ::FlexFlow::ParallelTensorShape value; diff --git a/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.struct.toml b/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.struct.toml index f7513fee8f..22136a948b 100644 --- a/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.struct.toml @@ -2,7 +2,7 @@ namespace = "FlexFlow" name = "ParallelMultiHeadAttentionInputs" features = [ "eq", - "ord", + # "ord", "hash", "json", # "rapidcheck", diff --git a/lib/op-attrs/include/op-attrs/parallel_dim.dtg.h b/lib/op-attrs/include/op-attrs/parallel_dim.dtg.h index 3492694685..a5c4fc0b29 100644 --- a/lib/op-attrs/include/op-attrs/parallel_dim.dtg.h +++ b/lib/op-attrs/include/op-attrs/parallel_dim.dtg.h @@ -1,9 +1,9 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify -// lib/op-attrs/include/op-attrs/parallel_dim.struct.toml +// lib/op-attrs/include/op-attrs/parallel_dim.variant.toml /* proj-data { - "generated_from": "186bedde7826c7a3d00343ed63ab9971" + "generated_from": "5550fc7ad51892b3411ef274c76e7d85" } */ @@ -12,55 +12,110 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" -#include "rapidcheck.h" +#include "op-attrs/replica_parallel_dim.dtg.h" +#include "op-attrs/shard_parallel_dim.dtg.h" +#include #include #include -#include +#include +#include namespace FlexFlow { struct ParallelDim { ParallelDim() = delete; - ParallelDim(size_t const &size, - int const °ree, - bool const &is_replica_dim); - + explicit ParallelDim(::FlexFlow::ShardParallelDim const &); + explicit ParallelDim(::FlexFlow::ReplicaParallelDim const &); + template + static constexpr bool IsPartOfParallelDim_v = + std::is_same_v || + std::is_same_v; + template + ReturnType visit(Visitor &&v) const { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::ShardParallelDim>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::ReplicaParallelDim>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type ParallelDim", this->index())); + } + } + } + template + ReturnType visit(Visitor &&v) { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::ShardParallelDim>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::ReplicaParallelDim>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type ParallelDim", this->index())); + } + } + } + template + bool has() const { + static_assert( + IsPartOfParallelDim_v, + "ParallelDim::has() expected one of [::FlexFlow::ShardParallelDim, " + "::FlexFlow::ReplicaParallelDim], received T"); + return std::holds_alternative(this->raw_variant); + } + template + T const &get() const { + static_assert( + IsPartOfParallelDim_v, + "ParallelDim::get() expected one of [::FlexFlow::ShardParallelDim, " + "::FlexFlow::ReplicaParallelDim], received T"); + return std::get(this->raw_variant); + } + template + T &get() { + static_assert( + IsPartOfParallelDim_v, + "ParallelDim::get() expected one of [::FlexFlow::ShardParallelDim, " + "::FlexFlow::ReplicaParallelDim], received T"); + return std::get(this->raw_variant); + } + size_t index() const { + return this->raw_variant.index(); + } bool operator==(ParallelDim const &) const; bool operator!=(ParallelDim const &) const; bool operator<(ParallelDim const &) const; bool operator>(ParallelDim const &) const; bool operator<=(ParallelDim const &) const; bool operator>=(ParallelDim const &) const; - size_t size; - int degree; - bool is_replica_dim; + std::variant<::FlexFlow::ShardParallelDim, ::FlexFlow::ReplicaParallelDim> + raw_variant; }; } // namespace FlexFlow - namespace std { template <> -struct hash { - size_t operator()(FlexFlow::ParallelDim const &) const; +struct hash<::FlexFlow::ParallelDim> { + size_t operator()(::FlexFlow::ParallelDim const &) const; }; } // namespace std - namespace nlohmann { template <> -struct adl_serializer { - static FlexFlow::ParallelDim from_json(json const &); - static void to_json(json &, FlexFlow::ParallelDim const &); +struct adl_serializer<::FlexFlow::ParallelDim> { + static ::FlexFlow::ParallelDim from_json(json const &); + static void to_json(json &, ::FlexFlow::ParallelDim const &); }; } // namespace nlohmann - -namespace rc { -template <> -struct Arbitrary { - static Gen arbitrary(); -}; -} // namespace rc - namespace FlexFlow { -std::string format_as(ParallelDim const &); -std::ostream &operator<<(std::ostream &, ParallelDim const &); +std::string format_as(::FlexFlow::ParallelDim const &); +std::ostream &operator<<(std::ostream &, ::FlexFlow::ParallelDim const &); } // namespace FlexFlow #endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_DIM_DTG_H diff --git a/lib/op-attrs/include/op-attrs/parallel_dim.variant.toml b/lib/op-attrs/include/op-attrs/parallel_dim.variant.toml new file mode 100644 index 0000000000..eceffd38a3 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_dim.variant.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "ParallelDim" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "op-attrs/shard_parallel_dim.dtg.h", + "op-attrs/replica_parallel_dim.dtg.h", +] + +[[values]] +type = "::FlexFlow::ShardParallelDim" +key = "shard_dim" + +[[values]] +type = "::FlexFlow::ReplicaParallelDim" +key = "replica_dim" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.dtg.h b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.dtg.h index ae49a17657..1090244a1b 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.dtg.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml /* proj-data { - "generated_from": "b46ffa08758bdcc57a75183255248ca6" + "generated_from": "141639bdce009a1594501f33c2f25c9e" } */ @@ -13,24 +13,26 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" #include "op-attrs/dim_ordered.h" -#include "op-attrs/parallel_dim.h" +#include "op-attrs/replica_parallel_dim_set.dtg.h" +#include "op-attrs/shard_parallel_dim.dtg.h" +#include "utils/fmt/pair.h" +#include "utils/fmt/unordered_map.h" #include #include #include +#include namespace FlexFlow { struct ParallelTensorDims { ParallelTensorDims() = delete; ParallelTensorDims( - ::FlexFlow::FFOrdered<::FlexFlow::ParallelDim> const &ff_ordered); + ::FlexFlow::FFOrdered<::FlexFlow::ShardParallelDim> const &shard_dims, + ::FlexFlow::ReplicaParallelDimSet const &replica_dims); bool operator==(ParallelTensorDims const &) const; bool operator!=(ParallelTensorDims const &) const; - bool operator<(ParallelTensorDims const &) const; - bool operator>(ParallelTensorDims const &) const; - bool operator<=(ParallelTensorDims const &) const; - bool operator>=(ParallelTensorDims const &) const; - ::FlexFlow::FFOrdered<::FlexFlow::ParallelDim> ff_ordered; + ::FlexFlow::FFOrdered<::FlexFlow::ShardParallelDim> shard_dims; + ::FlexFlow::ReplicaParallelDimSet replica_dims; }; } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h index 2e7cb57b99..2bb44c919f 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h @@ -7,22 +7,25 @@ namespace FlexFlow { -FFOrdered const &ff_ordered(ParallelTensorDims const &); - -std::vector as_vector(ParallelTensorDims const &); - -int get_num_replica_dims(ParallelTensorDims const &); +FFOrdered ff_ordered_shard_dims(ParallelTensorDims const &); +std::unordered_set replica_dims(ParallelTensorDims const &); /* size_t get_volume(ParallelTensorDims const &); */ -size_t num_dims(ParallelTensorDims const &); +size_t num_shard_dims(ParallelTensorDims const &); -ParallelDim dim_at_idx(ParallelTensorDims const &, ff_dim_t); -ParallelDim &dim_at_idx(ParallelTensorDims &, ff_dim_t); +int total_replica_degree(ParallelTensorDims const &); +int total_shard_degree(ParallelTensorDims const &); +int total_parallel_degree(ParallelTensorDims const &); + +ShardParallelDim shard_dim_at_idx(ParallelTensorDims const &, ff_dim_t); +ShardParallelDim &shard_dim_at_idx(ParallelTensorDims &, ff_dim_t); bool is_valid(ParallelTensorDims const &); TensorDims get_piece_dims(ParallelTensorDims const &); TensorDims get_tensor_dims_unsafe(ParallelTensorDims const &); +TensorDims get_reduced_dims(ParallelTensorDims const &); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml index 09c5b5ff4f..0d07939ff0 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml @@ -2,7 +2,7 @@ namespace = "FlexFlow" name = "ParallelTensorDims" features = [ "eq", - "ord", + # "ord", "hash", "json", # "rapidcheck", @@ -11,9 +11,17 @@ features = [ includes = [ "op-attrs/dim_ordered.h", - "op-attrs/parallel_dim.h", + "op-attrs/shard_parallel_dim.dtg.h", + "op-attrs/replica_parallel_dim_set.dtg.h", + "", + "utils/fmt/unordered_map.h", + "utils/fmt/pair.h", ] [[fields]] -name = "ff_ordered" -type = "::FlexFlow::FFOrdered<::FlexFlow::ParallelDim>" +name = "shard_dims" +type = "::FlexFlow::FFOrdered<::FlexFlow::ShardParallelDim>" + +[[fields]] +name = "replica_dims" +type = "::FlexFlow::ReplicaParallelDimSet" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.dtg.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.dtg.h index dfad5b1007..b253880764 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.dtg.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml /* proj-data { - "generated_from": "b2d36c9212916e66569af4e958c893f4" + "generated_from": "bc7e838003fe037b95d45cd5ab4aa16f" } */ @@ -26,10 +26,6 @@ struct ParallelTensorShape { bool operator==(ParallelTensorShape const &) const; bool operator!=(ParallelTensorShape const &) const; - bool operator<(ParallelTensorShape const &) const; - bool operator>(ParallelTensorShape const &) const; - bool operator<=(ParallelTensorShape const &) const; - bool operator>=(ParallelTensorShape const &) const; ::FlexFlow::ParallelTensorDims dims; ::FlexFlow::DataType data_type; }; diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h index 8a60ce0b8d..6b9bde8283 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h @@ -8,12 +8,13 @@ namespace FlexFlow { -int num_dims(ParallelTensorShape const &); -ParallelDim dim_at_idx(ParallelTensorShape const &, ff_dim_t); -ParallelDim &dim_at_idx(ParallelTensorShape &, ff_dim_t); +int num_shard_dims(ParallelTensorShape const &); +ShardParallelDim shard_dim_at_idx(ParallelTensorShape const &, ff_dim_t); +ShardParallelDim &shard_dim_at_idx(ParallelTensorShape &, ff_dim_t); ParallelTensorShape lift_to_parallel(TensorShape const &); +std::unordered_set replica_dims(ParallelTensorShape const &); TensorShape get_piece_shape(ParallelTensorShape const &); int get_num_replica_dims(ParallelTensorShape const &); int get_num_replicas(ParallelTensorShape const &); @@ -24,6 +25,8 @@ TensorShape get_tensor_shape_unsafe(ParallelTensorShape const &); std::vector get_tensor_shapes_unsafe(std::vector const &); +TensorShape get_reduced_shape(ParallelTensorShape const &); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml index 1199b0d816..411070848d 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml @@ -2,7 +2,7 @@ namespace = "FlexFlow" name = "ParallelTensorShape" features = [ "eq", - "ord", + # "ord", "hash", "json", # "rapidcheck", diff --git a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.h b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.h index 132a575175..75b203ab3a 100644 --- a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml /* proj-data { - "generated_from": "e1b5c307ae023ce6d504f605c7ef8491" + "generated_from": "cf0da4385b7554748a06ec25ccf17f2f" } */ @@ -53,6 +53,7 @@ struct PCGOperatorAttrs { explicit PCGOperatorAttrs(::FlexFlow::BatchMatmulAttrs const &); explicit PCGOperatorAttrs(::FlexFlow::BatchNormAttrs const &); explicit PCGOperatorAttrs(::FlexFlow::CastAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::CombineAttrs const &); explicit PCGOperatorAttrs(::FlexFlow::ConcatAttrs const &); explicit PCGOperatorAttrs(::FlexFlow::Conv2DAttrs const &); explicit PCGOperatorAttrs(::FlexFlow::DropoutAttrs const &); @@ -69,6 +70,9 @@ struct PCGOperatorAttrs { explicit PCGOperatorAttrs(::FlexFlow::NoopAttrs const &); explicit PCGOperatorAttrs(::FlexFlow::Pool2DAttrs const &); explicit PCGOperatorAttrs(::FlexFlow::ReduceAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::ReductionAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::RepartitionAttrs const &); + explicit PCGOperatorAttrs(::FlexFlow::ReplicateAttrs const &); explicit PCGOperatorAttrs(::FlexFlow::ReverseAttrs const &); explicit PCGOperatorAttrs(::FlexFlow::ReshapeAttrs const &); explicit PCGOperatorAttrs(::FlexFlow::SplitAttrs const &); @@ -80,6 +84,7 @@ struct PCGOperatorAttrs { std::is_same_v || std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || @@ -96,6 +101,9 @@ struct PCGOperatorAttrs { std::is_same_v || std::is_same_v || std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || @@ -118,90 +126,106 @@ struct PCGOperatorAttrs { return result; } case 3: { - ReturnType result = v(this->get<::FlexFlow::ConcatAttrs>()); + ReturnType result = v(this->get<::FlexFlow::CombineAttrs>()); return result; } case 4: { - ReturnType result = v(this->get<::FlexFlow::Conv2DAttrs>()); + ReturnType result = v(this->get<::FlexFlow::ConcatAttrs>()); return result; } case 5: { - ReturnType result = v(this->get<::FlexFlow::DropoutAttrs>()); + ReturnType result = v(this->get<::FlexFlow::Conv2DAttrs>()); return result; } case 6: { - ReturnType result = v(this->get<::FlexFlow::ElementBinaryAttrs>()); + ReturnType result = v(this->get<::FlexFlow::DropoutAttrs>()); return result; } case 7: { - ReturnType result = v(this->get<::FlexFlow::ElementUnaryAttrs>()); + ReturnType result = v(this->get<::FlexFlow::ElementBinaryAttrs>()); return result; } case 8: { - ReturnType result = v(this->get<::FlexFlow::ElementScalarUnaryAttrs>()); + ReturnType result = v(this->get<::FlexFlow::ElementUnaryAttrs>()); return result; } case 9: { - ReturnType result = v(this->get<::FlexFlow::EmbeddingAttrs>()); + ReturnType result = v(this->get<::FlexFlow::ElementScalarUnaryAttrs>()); return result; } case 10: { - ReturnType result = v(this->get<::FlexFlow::FlatAttrs>()); + ReturnType result = v(this->get<::FlexFlow::EmbeddingAttrs>()); return result; } case 11: { - ReturnType result = v(this->get<::FlexFlow::GatherAttrs>()); + ReturnType result = v(this->get<::FlexFlow::FlatAttrs>()); return result; } case 12: { - ReturnType result = v(this->get<::FlexFlow::InputAttrs>()); + ReturnType result = v(this->get<::FlexFlow::GatherAttrs>()); return result; } case 13: { - ReturnType result = v(this->get<::FlexFlow::LayerNormAttrs>()); + ReturnType result = v(this->get<::FlexFlow::InputAttrs>()); return result; } case 14: { - ReturnType result = v(this->get<::FlexFlow::LinearAttrs>()); + ReturnType result = v(this->get<::FlexFlow::LayerNormAttrs>()); return result; } case 15: { - ReturnType result = v(this->get<::FlexFlow::MultiHeadAttentionAttrs>()); + ReturnType result = v(this->get<::FlexFlow::LinearAttrs>()); return result; } case 16: { - ReturnType result = v(this->get<::FlexFlow::NoopAttrs>()); + ReturnType result = v(this->get<::FlexFlow::MultiHeadAttentionAttrs>()); return result; } case 17: { - ReturnType result = v(this->get<::FlexFlow::Pool2DAttrs>()); + ReturnType result = v(this->get<::FlexFlow::NoopAttrs>()); return result; } case 18: { - ReturnType result = v(this->get<::FlexFlow::ReduceAttrs>()); + ReturnType result = v(this->get<::FlexFlow::Pool2DAttrs>()); return result; } case 19: { - ReturnType result = v(this->get<::FlexFlow::ReverseAttrs>()); + ReturnType result = v(this->get<::FlexFlow::ReduceAttrs>()); return result; } case 20: { - ReturnType result = v(this->get<::FlexFlow::ReshapeAttrs>()); + ReturnType result = v(this->get<::FlexFlow::ReductionAttrs>()); return result; } case 21: { - ReturnType result = v(this->get<::FlexFlow::SplitAttrs>()); + ReturnType result = v(this->get<::FlexFlow::RepartitionAttrs>()); return result; } case 22: { - ReturnType result = v(this->get<::FlexFlow::SoftmaxAttrs>()); + ReturnType result = v(this->get<::FlexFlow::ReplicateAttrs>()); return result; } case 23: { - ReturnType result = v(this->get<::FlexFlow::TopKAttrs>()); + ReturnType result = v(this->get<::FlexFlow::ReverseAttrs>()); return result; } case 24: { + ReturnType result = v(this->get<::FlexFlow::ReshapeAttrs>()); + return result; + } + case 25: { + ReturnType result = v(this->get<::FlexFlow::SplitAttrs>()); + return result; + } + case 26: { + ReturnType result = v(this->get<::FlexFlow::SoftmaxAttrs>()); + return result; + } + case 27: { + ReturnType result = v(this->get<::FlexFlow::TopKAttrs>()); + return result; + } + case 28: { ReturnType result = v(this->get<::FlexFlow::TransposeAttrs>()); return result; } @@ -227,90 +251,106 @@ struct PCGOperatorAttrs { return result; } case 3: { - ReturnType result = v(this->get<::FlexFlow::ConcatAttrs>()); + ReturnType result = v(this->get<::FlexFlow::CombineAttrs>()); return result; } case 4: { - ReturnType result = v(this->get<::FlexFlow::Conv2DAttrs>()); + ReturnType result = v(this->get<::FlexFlow::ConcatAttrs>()); return result; } case 5: { - ReturnType result = v(this->get<::FlexFlow::DropoutAttrs>()); + ReturnType result = v(this->get<::FlexFlow::Conv2DAttrs>()); return result; } case 6: { - ReturnType result = v(this->get<::FlexFlow::ElementBinaryAttrs>()); + ReturnType result = v(this->get<::FlexFlow::DropoutAttrs>()); return result; } case 7: { - ReturnType result = v(this->get<::FlexFlow::ElementUnaryAttrs>()); + ReturnType result = v(this->get<::FlexFlow::ElementBinaryAttrs>()); return result; } case 8: { - ReturnType result = v(this->get<::FlexFlow::ElementScalarUnaryAttrs>()); + ReturnType result = v(this->get<::FlexFlow::ElementUnaryAttrs>()); return result; } case 9: { - ReturnType result = v(this->get<::FlexFlow::EmbeddingAttrs>()); + ReturnType result = v(this->get<::FlexFlow::ElementScalarUnaryAttrs>()); return result; } case 10: { - ReturnType result = v(this->get<::FlexFlow::FlatAttrs>()); + ReturnType result = v(this->get<::FlexFlow::EmbeddingAttrs>()); return result; } case 11: { - ReturnType result = v(this->get<::FlexFlow::GatherAttrs>()); + ReturnType result = v(this->get<::FlexFlow::FlatAttrs>()); return result; } case 12: { - ReturnType result = v(this->get<::FlexFlow::InputAttrs>()); + ReturnType result = v(this->get<::FlexFlow::GatherAttrs>()); return result; } case 13: { - ReturnType result = v(this->get<::FlexFlow::LayerNormAttrs>()); + ReturnType result = v(this->get<::FlexFlow::InputAttrs>()); return result; } case 14: { - ReturnType result = v(this->get<::FlexFlow::LinearAttrs>()); + ReturnType result = v(this->get<::FlexFlow::LayerNormAttrs>()); return result; } case 15: { - ReturnType result = v(this->get<::FlexFlow::MultiHeadAttentionAttrs>()); + ReturnType result = v(this->get<::FlexFlow::LinearAttrs>()); return result; } case 16: { - ReturnType result = v(this->get<::FlexFlow::NoopAttrs>()); + ReturnType result = v(this->get<::FlexFlow::MultiHeadAttentionAttrs>()); return result; } case 17: { - ReturnType result = v(this->get<::FlexFlow::Pool2DAttrs>()); + ReturnType result = v(this->get<::FlexFlow::NoopAttrs>()); return result; } case 18: { - ReturnType result = v(this->get<::FlexFlow::ReduceAttrs>()); + ReturnType result = v(this->get<::FlexFlow::Pool2DAttrs>()); return result; } case 19: { - ReturnType result = v(this->get<::FlexFlow::ReverseAttrs>()); + ReturnType result = v(this->get<::FlexFlow::ReduceAttrs>()); return result; } case 20: { - ReturnType result = v(this->get<::FlexFlow::ReshapeAttrs>()); + ReturnType result = v(this->get<::FlexFlow::ReductionAttrs>()); return result; } case 21: { - ReturnType result = v(this->get<::FlexFlow::SplitAttrs>()); + ReturnType result = v(this->get<::FlexFlow::RepartitionAttrs>()); return result; } case 22: { - ReturnType result = v(this->get<::FlexFlow::SoftmaxAttrs>()); + ReturnType result = v(this->get<::FlexFlow::ReplicateAttrs>()); return result; } case 23: { - ReturnType result = v(this->get<::FlexFlow::TopKAttrs>()); + ReturnType result = v(this->get<::FlexFlow::ReverseAttrs>()); return result; } case 24: { + ReturnType result = v(this->get<::FlexFlow::ReshapeAttrs>()); + return result; + } + case 25: { + ReturnType result = v(this->get<::FlexFlow::SplitAttrs>()); + return result; + } + case 26: { + ReturnType result = v(this->get<::FlexFlow::SoftmaxAttrs>()); + return result; + } + case 27: { + ReturnType result = v(this->get<::FlexFlow::TopKAttrs>()); + return result; + } + case 28: { ReturnType result = v(this->get<::FlexFlow::TransposeAttrs>()); return result; } @@ -326,15 +366,17 @@ struct PCGOperatorAttrs { IsPartOfPCGOperatorAttrs_v, "PCGOperatorAttrs::has() expected one of " "[::FlexFlow::BatchMatmulAttrs, ::FlexFlow::BatchNormAttrs, " - "::FlexFlow::CastAttrs, ::FlexFlow::ConcatAttrs, " - "::FlexFlow::Conv2DAttrs, ::FlexFlow::DropoutAttrs, " - "::FlexFlow::ElementBinaryAttrs, ::FlexFlow::ElementUnaryAttrs, " - "::FlexFlow::ElementScalarUnaryAttrs, ::FlexFlow::EmbeddingAttrs, " - "::FlexFlow::FlatAttrs, ::FlexFlow::GatherAttrs, " - "::FlexFlow::InputAttrs, ::FlexFlow::LayerNormAttrs, " - "::FlexFlow::LinearAttrs, ::FlexFlow::MultiHeadAttentionAttrs, " - "::FlexFlow::NoopAttrs, ::FlexFlow::Pool2DAttrs, " - "::FlexFlow::ReduceAttrs, ::FlexFlow::ReverseAttrs, " + "::FlexFlow::CastAttrs, ::FlexFlow::CombineAttrs, " + "::FlexFlow::ConcatAttrs, ::FlexFlow::Conv2DAttrs, " + "::FlexFlow::DropoutAttrs, ::FlexFlow::ElementBinaryAttrs, " + "::FlexFlow::ElementUnaryAttrs, ::FlexFlow::ElementScalarUnaryAttrs, " + "::FlexFlow::EmbeddingAttrs, ::FlexFlow::FlatAttrs, " + "::FlexFlow::GatherAttrs, ::FlexFlow::InputAttrs, " + "::FlexFlow::LayerNormAttrs, ::FlexFlow::LinearAttrs, " + "::FlexFlow::MultiHeadAttentionAttrs, ::FlexFlow::NoopAttrs, " + "::FlexFlow::Pool2DAttrs, ::FlexFlow::ReduceAttrs, " + "::FlexFlow::ReductionAttrs, ::FlexFlow::RepartitionAttrs, " + "::FlexFlow::ReplicateAttrs, ::FlexFlow::ReverseAttrs, " "::FlexFlow::ReshapeAttrs, ::FlexFlow::SplitAttrs, " "::FlexFlow::SoftmaxAttrs, ::FlexFlow::TopKAttrs, " "::FlexFlow::TransposeAttrs], received T"); @@ -346,15 +388,17 @@ struct PCGOperatorAttrs { IsPartOfPCGOperatorAttrs_v, "PCGOperatorAttrs::get() expected one of " "[::FlexFlow::BatchMatmulAttrs, ::FlexFlow::BatchNormAttrs, " - "::FlexFlow::CastAttrs, ::FlexFlow::ConcatAttrs, " - "::FlexFlow::Conv2DAttrs, ::FlexFlow::DropoutAttrs, " - "::FlexFlow::ElementBinaryAttrs, ::FlexFlow::ElementUnaryAttrs, " - "::FlexFlow::ElementScalarUnaryAttrs, ::FlexFlow::EmbeddingAttrs, " - "::FlexFlow::FlatAttrs, ::FlexFlow::GatherAttrs, " - "::FlexFlow::InputAttrs, ::FlexFlow::LayerNormAttrs, " - "::FlexFlow::LinearAttrs, ::FlexFlow::MultiHeadAttentionAttrs, " - "::FlexFlow::NoopAttrs, ::FlexFlow::Pool2DAttrs, " - "::FlexFlow::ReduceAttrs, ::FlexFlow::ReverseAttrs, " + "::FlexFlow::CastAttrs, ::FlexFlow::CombineAttrs, " + "::FlexFlow::ConcatAttrs, ::FlexFlow::Conv2DAttrs, " + "::FlexFlow::DropoutAttrs, ::FlexFlow::ElementBinaryAttrs, " + "::FlexFlow::ElementUnaryAttrs, ::FlexFlow::ElementScalarUnaryAttrs, " + "::FlexFlow::EmbeddingAttrs, ::FlexFlow::FlatAttrs, " + "::FlexFlow::GatherAttrs, ::FlexFlow::InputAttrs, " + "::FlexFlow::LayerNormAttrs, ::FlexFlow::LinearAttrs, " + "::FlexFlow::MultiHeadAttentionAttrs, ::FlexFlow::NoopAttrs, " + "::FlexFlow::Pool2DAttrs, ::FlexFlow::ReduceAttrs, " + "::FlexFlow::ReductionAttrs, ::FlexFlow::RepartitionAttrs, " + "::FlexFlow::ReplicateAttrs, ::FlexFlow::ReverseAttrs, " "::FlexFlow::ReshapeAttrs, ::FlexFlow::SplitAttrs, " "::FlexFlow::SoftmaxAttrs, ::FlexFlow::TopKAttrs, " "::FlexFlow::TransposeAttrs], received T"); @@ -366,15 +410,17 @@ struct PCGOperatorAttrs { IsPartOfPCGOperatorAttrs_v, "PCGOperatorAttrs::get() expected one of " "[::FlexFlow::BatchMatmulAttrs, ::FlexFlow::BatchNormAttrs, " - "::FlexFlow::CastAttrs, ::FlexFlow::ConcatAttrs, " - "::FlexFlow::Conv2DAttrs, ::FlexFlow::DropoutAttrs, " - "::FlexFlow::ElementBinaryAttrs, ::FlexFlow::ElementUnaryAttrs, " - "::FlexFlow::ElementScalarUnaryAttrs, ::FlexFlow::EmbeddingAttrs, " - "::FlexFlow::FlatAttrs, ::FlexFlow::GatherAttrs, " - "::FlexFlow::InputAttrs, ::FlexFlow::LayerNormAttrs, " - "::FlexFlow::LinearAttrs, ::FlexFlow::MultiHeadAttentionAttrs, " - "::FlexFlow::NoopAttrs, ::FlexFlow::Pool2DAttrs, " - "::FlexFlow::ReduceAttrs, ::FlexFlow::ReverseAttrs, " + "::FlexFlow::CastAttrs, ::FlexFlow::CombineAttrs, " + "::FlexFlow::ConcatAttrs, ::FlexFlow::Conv2DAttrs, " + "::FlexFlow::DropoutAttrs, ::FlexFlow::ElementBinaryAttrs, " + "::FlexFlow::ElementUnaryAttrs, ::FlexFlow::ElementScalarUnaryAttrs, " + "::FlexFlow::EmbeddingAttrs, ::FlexFlow::FlatAttrs, " + "::FlexFlow::GatherAttrs, ::FlexFlow::InputAttrs, " + "::FlexFlow::LayerNormAttrs, ::FlexFlow::LinearAttrs, " + "::FlexFlow::MultiHeadAttentionAttrs, ::FlexFlow::NoopAttrs, " + "::FlexFlow::Pool2DAttrs, ::FlexFlow::ReduceAttrs, " + "::FlexFlow::ReductionAttrs, ::FlexFlow::RepartitionAttrs, " + "::FlexFlow::ReplicateAttrs, ::FlexFlow::ReverseAttrs, " "::FlexFlow::ReshapeAttrs, ::FlexFlow::SplitAttrs, " "::FlexFlow::SoftmaxAttrs, ::FlexFlow::TopKAttrs, " "::FlexFlow::TransposeAttrs], received T"); @@ -392,6 +438,7 @@ struct PCGOperatorAttrs { std::variant<::FlexFlow::BatchMatmulAttrs, ::FlexFlow::BatchNormAttrs, ::FlexFlow::CastAttrs, + ::FlexFlow::CombineAttrs, ::FlexFlow::ConcatAttrs, ::FlexFlow::Conv2DAttrs, ::FlexFlow::DropoutAttrs, @@ -408,6 +455,9 @@ struct PCGOperatorAttrs { ::FlexFlow::NoopAttrs, ::FlexFlow::Pool2DAttrs, ::FlexFlow::ReduceAttrs, + ::FlexFlow::ReductionAttrs, + ::FlexFlow::RepartitionAttrs, + ::FlexFlow::ReplicateAttrs, ::FlexFlow::ReverseAttrs, ::FlexFlow::ReshapeAttrs, ::FlexFlow::SplitAttrs, diff --git a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h new file mode 100644 index 0000000000..4605a4c114 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PCG_OPERATOR_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PCG_OPERATOR_ATTRS_H + +#include "op-attrs/pcg_operator_attrs.dtg.h" + +namespace FlexFlow { + +bool is_parallel_op(PCGOperatorAttrs const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml index 6f15ec417d..4062f08684 100644 --- a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml +++ b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml @@ -14,7 +14,6 @@ includes = [ "op-attrs/ops/batch_norm_attrs.dtg.h", "op-attrs/ops/cast_attrs.dtg.h", "op-attrs/ops/combine_attrs.dtg.h", - "op-attrs/ops/combine_attrs.dtg.h", "op-attrs/ops/concat_attrs.dtg.h", "op-attrs/ops/conv_2d_attrs.dtg.h", "op-attrs/ops/dropout_attrs.dtg.h", @@ -43,75 +42,116 @@ includes = [ [[values]] type = "::FlexFlow::BatchMatmulAttrs" +key = "batch_matmul" [[values]] type = "::FlexFlow::BatchNormAttrs" +key = "batch_norm" [[values]] type = "::FlexFlow::CastAttrs" +key = "cast" + +[[values]] +type = "::FlexFlow::CombineAttrs" +key = "combine_distributed" [[values]] type = "::FlexFlow::ConcatAttrs" +key = "concat" [[values]] type = "::FlexFlow::Conv2DAttrs" +key = "conv2d" [[values]] type = "::FlexFlow::DropoutAttrs" +key = "dropout" [[values]] type = "::FlexFlow::ElementBinaryAttrs" +key = "element_binary" [[values]] type = "::FlexFlow::ElementUnaryAttrs" +key = "element_unary" [[values]] type = "::FlexFlow::ElementScalarUnaryAttrs" +key = "element_scalar_unary" [[values]] type = "::FlexFlow::EmbeddingAttrs" +key = "embedding" [[values]] type = "::FlexFlow::FlatAttrs" +key = "flat" [[values]] type = "::FlexFlow::GatherAttrs" +key = "gather" [[values]] type = "::FlexFlow::InputAttrs" +key = "input" [[values]] type = "::FlexFlow::LayerNormAttrs" +key = "layer_norm" [[values]] type = "::FlexFlow::LinearAttrs" +key = "linear" [[values]] type = "::FlexFlow::MultiHeadAttentionAttrs" +key = "multi_head_attention" [[values]] type = "::FlexFlow::NoopAttrs" +key = "noop" [[values]] type = "::FlexFlow::Pool2DAttrs" +key = "pool2d" [[values]] type = "::FlexFlow::ReduceAttrs" +key = "reduce" + +[[values]] +type = "::FlexFlow::ReductionAttrs" +key = "reduce_distributed" + +[[values]] +type = "::FlexFlow::RepartitionAttrs" +key = "partition_distributed" + +[[values]] +type = "::FlexFlow::ReplicateAttrs" +key = "replicate_distributed" [[values]] type = "::FlexFlow::ReverseAttrs" +key = "reverse" [[values]] type = "::FlexFlow::ReshapeAttrs" +key = "reshape" [[values]] type = "::FlexFlow::SplitAttrs" +key = "split" [[values]] type = "::FlexFlow::SoftmaxAttrs" +key = "softmax" [[values]] type = "::FlexFlow::TopKAttrs" +key = "topk" [[values]] type = "::FlexFlow::TransposeAttrs" +key = "transpose" diff --git a/lib/op-attrs/include/op-attrs/replica_parallel_dim.dtg.h b/lib/op-attrs/include/op-attrs/replica_parallel_dim.dtg.h new file mode 100644 index 0000000000..250ba29947 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/replica_parallel_dim.dtg.h @@ -0,0 +1,65 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/replica_parallel_dim.struct.toml +/* proj-data +{ + "generated_from": "f501393070c8d55a05c43dd73a81a8d7" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_PARALLEL_DIM_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_PARALLEL_DIM_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/replica_type.dtg.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct ReplicaParallelDim { + ReplicaParallelDim() = delete; + ReplicaParallelDim(int const °ree, + ::FlexFlow::ReplicaType const &replica_type); + + bool operator==(ReplicaParallelDim const &) const; + bool operator!=(ReplicaParallelDim const &) const; + bool operator<(ReplicaParallelDim const &) const; + bool operator>(ReplicaParallelDim const &) const; + bool operator<=(ReplicaParallelDim const &) const; + bool operator>=(ReplicaParallelDim const &) const; + int degree; + ::FlexFlow::ReplicaType replica_type; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ReplicaParallelDim const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ReplicaParallelDim from_json(json const &); + static void to_json(json &, FlexFlow::ReplicaParallelDim const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(ReplicaParallelDim const &); +std::ostream &operator<<(std::ostream &, ReplicaParallelDim const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_PARALLEL_DIM_DTG_H diff --git a/lib/op-attrs/include/op-attrs/replica_parallel_dim.h b/lib/op-attrs/include/op-attrs/replica_parallel_dim.h new file mode 100644 index 0000000000..da3913b426 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/replica_parallel_dim.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_PARALLEL_DIM_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_PARALLEL_DIM_H + +#include "op-attrs/replica_parallel_dim.dtg.h" + +namespace FlexFlow { + +bool is_valid(ReplicaParallelDim const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/replica_parallel_dim.struct.toml b/lib/op-attrs/include/op-attrs/replica_parallel_dim.struct.toml new file mode 100644 index 0000000000..2ad442aa22 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/replica_parallel_dim.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "ReplicaParallelDim" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/replica_type.dtg.h", +] + +[[fields]] +name = "degree" +type = "int" + +[[fields]] +name = "replica_type" +type = "::FlexFlow::ReplicaType" diff --git a/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.dtg.h b/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.dtg.h new file mode 100644 index 0000000000..c6c025f31c --- /dev/null +++ b/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.dtg.h @@ -0,0 +1,63 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/replica_parallel_dim_set.struct.toml +/* proj-data +{ + "generated_from": "20d8004e6f1e710688fe692b92dc2816" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_PARALLEL_DIM_SET_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_PARALLEL_DIM_SET_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct ReplicaParallelDimSet { + ReplicaParallelDimSet() = delete; + ReplicaParallelDimSet(int const &sum_degree, int const &discard_copy_degree); + + bool operator==(ReplicaParallelDimSet const &) const; + bool operator!=(ReplicaParallelDimSet const &) const; + bool operator<(ReplicaParallelDimSet const &) const; + bool operator>(ReplicaParallelDimSet const &) const; + bool operator<=(ReplicaParallelDimSet const &) const; + bool operator>=(ReplicaParallelDimSet const &) const; + int sum_degree; + int discard_copy_degree; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ReplicaParallelDimSet const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ReplicaParallelDimSet from_json(json const &); + static void to_json(json &, FlexFlow::ReplicaParallelDimSet const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(ReplicaParallelDimSet const &); +std::ostream &operator<<(std::ostream &, ReplicaParallelDimSet const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_PARALLEL_DIM_SET_DTG_H diff --git a/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.h b/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.h new file mode 100644 index 0000000000..e8b6a92114 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_PARALLEL_DIM_SET_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_PARALLEL_DIM_SET_H + +#include "op-attrs/replica_parallel_dim.dtg.h" +#include "op-attrs/replica_parallel_dim_set.dtg.h" +#include "op-attrs/replica_type.dtg.h" + +namespace FlexFlow { + +ReplicaParallelDimSet empty_replica_parallel_dim_set(); +int get_degree_of_replica_type(ReplicaParallelDimSet const &, ReplicaType); +std::unordered_set get_replica_dims(ReplicaParallelDimSet const &); +bool is_valid(ReplicaParallelDimSet const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.struct.toml b/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.struct.toml new file mode 100644 index 0000000000..7c05a7809f --- /dev/null +++ b/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "ReplicaParallelDimSet" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "sum_degree" +type = "int" + +[[fields]] +name = "discard_copy_degree" +type = "int" diff --git a/lib/op-attrs/include/op-attrs/replica_type.dtg.h b/lib/op-attrs/include/op-attrs/replica_type.dtg.h new file mode 100644 index 0000000000..3b965d3e77 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/replica_type.dtg.h @@ -0,0 +1,40 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/replica_type.enum.toml +/* proj-data +{ + "generated_from": "6ecba7a6851b8bea93705bba24661149" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_TYPE_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_TYPE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +enum class ReplicaType { SUM, DISCARD_COPY }; +std::string format_as(ReplicaType); +std::ostream &operator<<(std::ostream &, ReplicaType); +void to_json(::nlohmann::json &, ReplicaType); +void from_json(::nlohmann::json const &, ReplicaType &); +} // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ReplicaType) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_TYPE_DTG_H diff --git a/lib/op-attrs/include/op-attrs/replica_type.enum.toml b/lib/op-attrs/include/op-attrs/replica_type.enum.toml new file mode 100644 index 0000000000..0c0eb5e3ab --- /dev/null +++ b/lib/op-attrs/include/op-attrs/replica_type.enum.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "ReplicaType" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "SUM" + +[[values]] +name = "DISCARD_COPY" diff --git a/lib/op-attrs/include/op-attrs/shard_parallel_dim.dtg.h b/lib/op-attrs/include/op-attrs/shard_parallel_dim.dtg.h new file mode 100644 index 0000000000..631852c259 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/shard_parallel_dim.dtg.h @@ -0,0 +1,63 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/shard_parallel_dim.struct.toml +/* proj-data +{ + "generated_from": "18e074f80556d90b9b27d6515bbf9071" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_SHARD_PARALLEL_DIM_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_SHARD_PARALLEL_DIM_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct ShardParallelDim { + ShardParallelDim() = delete; + ShardParallelDim(size_t const &size, int const °ree); + + bool operator==(ShardParallelDim const &) const; + bool operator!=(ShardParallelDim const &) const; + bool operator<(ShardParallelDim const &) const; + bool operator>(ShardParallelDim const &) const; + bool operator<=(ShardParallelDim const &) const; + bool operator>=(ShardParallelDim const &) const; + size_t size; + int degree; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ShardParallelDim const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ShardParallelDim from_json(json const &); + static void to_json(json &, FlexFlow::ShardParallelDim const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(ShardParallelDim const &); +std::ostream &operator<<(std::ostream &, ShardParallelDim const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_SHARD_PARALLEL_DIM_DTG_H diff --git a/lib/op-attrs/include/op-attrs/shard_parallel_dim.h b/lib/op-attrs/include/op-attrs/shard_parallel_dim.h new file mode 100644 index 0000000000..0a6323192d --- /dev/null +++ b/lib/op-attrs/include/op-attrs/shard_parallel_dim.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_SHARD_PARALLEL_DIM_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_SHARD_PARALLEL_DIM_H + +#include "op-attrs/shard_parallel_dim.dtg.h" + +namespace FlexFlow { + +bool is_valid(ShardParallelDim const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/parallel_dim.struct.toml b/lib/op-attrs/include/op-attrs/shard_parallel_dim.struct.toml similarity index 72% rename from lib/op-attrs/include/op-attrs/parallel_dim.struct.toml rename to lib/op-attrs/include/op-attrs/shard_parallel_dim.struct.toml index 7ecb6a5b04..21c81396d1 100644 --- a/lib/op-attrs/include/op-attrs/parallel_dim.struct.toml +++ b/lib/op-attrs/include/op-attrs/shard_parallel_dim.struct.toml @@ -1,5 +1,5 @@ namespace = "FlexFlow" -name = "ParallelDim" +name = "ShardParallelDim" features = [ "eq", "ord", @@ -16,7 +16,3 @@ type = "size_t" [[fields]] name = "degree" type = "int" - -[[fields]] -name = "is_replica_dim" -type = "bool" diff --git a/lib/op-attrs/include/op-attrs/tensor_shape.dtg.h b/lib/op-attrs/include/op-attrs/tensor_shape.dtg.h index 2773317607..8ac6655956 100644 --- a/lib/op-attrs/include/op-attrs/tensor_shape.dtg.h +++ b/lib/op-attrs/include/op-attrs/tensor_shape.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/tensor_shape.struct.toml /* proj-data { - "generated_from": "c02c9d2331d864a25c1443cfe70062d1" + "generated_from": "52968754cf94f415c366d228c87042db" } */ @@ -12,8 +12,8 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" -#include "op-attrs/datatype.h" -#include "op-attrs/tensor_dims.h" +#include "op-attrs/datatype.dtg.h" +#include "op-attrs/tensor_dims.dtg.h" #include #include #include diff --git a/lib/op-attrs/include/op-attrs/tensor_shape.h b/lib/op-attrs/include/op-attrs/tensor_shape.h index 75ab2c2a64..92d360a95d 100644 --- a/lib/op-attrs/include/op-attrs/tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/tensor_shape.h @@ -6,6 +6,7 @@ namespace FlexFlow { size_t dim_at_idx(TensorShape const &, ff_dim_t); +size_t num_dims(TensorShape const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/tensor_shape.struct.toml b/lib/op-attrs/include/op-attrs/tensor_shape.struct.toml index b4d8449a72..24f9ff1b79 100644 --- a/lib/op-attrs/include/op-attrs/tensor_shape.struct.toml +++ b/lib/op-attrs/include/op-attrs/tensor_shape.struct.toml @@ -10,8 +10,8 @@ features = [ ] includes = [ - "op-attrs/tensor_dims.h", - "op-attrs/datatype.h", + "op-attrs/tensor_dims.dtg.h", + "op-attrs/datatype.dtg.h", ] [[fields]] diff --git a/lib/op-attrs/src/batch_norm.cc b/lib/op-attrs/src/batch_norm.cc deleted file mode 100644 index 4e352d5f1c..0000000000 --- a/lib/op-attrs/src/batch_norm.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "op-attrs/ops/batch_norm.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/op-attrs/src/conv_2d.cc b/lib/op-attrs/src/conv_2d.cc deleted file mode 100644 index 40ba3c8b41..0000000000 --- a/lib/op-attrs/src/conv_2d.cc +++ /dev/null @@ -1,115 +0,0 @@ -#include "op-attrs/ops/conv_2d.h" -#include "parallel_dim_mapping_record.h" -#include "parallel_dim_mapping_record_solver.h" -#include "utils/vector.h" - -namespace FlexFlow { - -namespace Input { -constexpr int WIDTH = 0, HEIGHT = 1, CHANNEL = 2, SAMPLE = 3, REPLICA = 4, - NUMDIM = 5; -} - -namespace Output { -constexpr int WIDTH = 0, HEIGHT = 1, CHANNEL = 2, SAMPLE = 3, REPLICA = 4, - NUMDIM = 5; -} - -namespace Kernel { -constexpr int WIDTH = 0, HEIGHT = 1, CHANNEL_IN = 2, CHANNEL_OUT = 3, - REPLICA = 4; -constexpr int WEIGHT_IDX = 0; -} // namespace Kernel - -namespace Bias { -constexpr int CHANNEL = 0, REPLICA_1 = 1, REPLICA_2 = 2, REPLICA_3 = 3, - REPLICA_4 = 4; -constexpr int WEIGHT_IDX = 1; -} // namespace Bias - -static std::vector - construct_output_mappings(ParallelTensorShape const &input_shape) { - return construct_output_parallel_dims( - {{Input::CHANNEL, MappingOperation::REPLICATE, Output::REPLICA}, - {Input::SAMPLE, MappingOperation::PARTITION, Output::SAMPLE}, - {Input::REPLICA, MappingOperation::PARTITION, Output::CHANNEL}, - {Input::HEIGHT, MappingOperation::PARTITION, Output::HEIGHT}, - {Input::WIDTH, MappingOperation::PARTITION, Output::WIDTH}}); -} - -static std::vector - construct_kernel_mappings(ParallelTensorShape const &input_shape) { - return construct_weight_parallel_dims( - { - {Input::REPLICA, MappingOperation::PARTITION, Kernel::CHANNEL_OUT}, - {Input::SAMPLE, MappingOperation::REPLICATE, Kernel::REPLICA}, - {Input::CHANNEL, MappingOperation::PARTITION, Kernel::CHANNEL_IN}, - {Input::HEIGHT, - MappingOperation::REPLICATE, - Kernel::HEIGHT}, // Kernel::{HEIGHT, WEIGHT} would both work - // here - {Input::WIDTH, - MappingOperation::REPLICATE, - Kernel::WIDTH}, // same as above - }, - 0, - Kernel::WEIGHT_IDX); -} - -static std::vector - construct_bias_mappings(ParallelTensorShape const &input_shape) { - return construct_weight_parallel_dims({{Input::REPLICA, Bias::REPLICA_1}, - {Input::SAMPLE, Bias::REPLICA_2}, - {Input::CHANNEL, Bias::CHANNEL}, - {Input::HEIGHT, Bias::REPLICA_3}, - {Input::WIDTH, Bias::REPLICA_4}}, - 0, - Bias::WEIGHT_IDX); -} - -std::vector - construct_mappings(ParallelTensorShape const &input_shape, bool use_bias) { - std::vector mappings = - concat(construct_output_mappings(input_shape), - construct_kernel_mappings(input_shape)); - if (use_bias) { - std::vector bias_mappings = - construct_bias_mappings(input_shape); - mappings.insert(mappings.end(), bias_mappings.begin(), bias_mappings.end()); - } - - return mappings; -} - -TensorShape get_kernel_shape(Conv2DAttrs const &, TensorShape const &) { - NOT_IMPLEMENTED(); -} - -TensorShape get_bias_shape(Conv2DAttrs const &, TensorShape const &) { - NOT_IMPLEMENTED(); -} - -/* bool Conv2DAttrs::is_valid(ParallelTensorShape const &input_shape) const { */ -/* bool is_valid = true; */ -/* is_valid &= input_shape.is_valid(); */ -/* is_valid &= this->calculate_output_shape(input_shape).is_valid(); */ -/* is_valid &= this->calculate_kernel_shape(input_shape).is_valid(); */ -/* if (use_bias) { */ -/* is_valid &= this->calculate_bias_shape(input_shape).is_valid(); */ -/* } */ - -/* // TODO FIXME: Currently disable parallelizing the height and width - * dimension */ -/* if (input_shape.at(0).degree > 1 || input_shape.at(1).degree > 1) { */ -/* return false; */ -/* } */ - -/* return is_valid; */ - -/* } */ - -/* OperatorType Conv2DAttrs::op_type() const { */ -/* return OP_CONV2D; */ -/* } */ - -} // namespace FlexFlow diff --git a/lib/op-attrs/src/linear.cc b/lib/op-attrs/src/linear.cc deleted file mode 100644 index 16a94e7f6c..0000000000 --- a/lib/op-attrs/src/linear.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "op-attrs/ops/linear.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/op-attrs/src/noop.cc b/lib/op-attrs/src/noop.cc deleted file mode 100644 index 387660164f..0000000000 --- a/lib/op-attrs/src/noop.cc +++ /dev/null @@ -1 +0,0 @@ -#include "op-attrs/ops/noop.h" diff --git a/lib/op-attrs/src/op-attrs/as_dot.cc b/lib/op-attrs/src/op-attrs/as_dot.cc new file mode 100644 index 0000000000..f8d05de941 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/as_dot.cc @@ -0,0 +1,13 @@ +#include "op-attrs/as_dot.h" + +namespace FlexFlow { + +RecordFormatter as_dot(ComputationGraphOpAttrs const &attrs) { + NOT_IMPLEMENTED(); +} + +RecordFormatter as_dot(PCGOperatorAttrs const &attrs) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.dtg.cc new file mode 100644 index 0000000000..b92e835ee6 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.dtg.cc @@ -0,0 +1,521 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/computation_graph_op_attrs.variant.toml +/* proj-data +{ + "generated_from": "87653647c900faaf564d3069478569e7" +} +*/ + +#include "op-attrs/computation_graph_op_attrs.dtg.h" + +#include "fmt/format.h" +#include +#include + +namespace FlexFlow { +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::BatchMatmulAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::BatchNormAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::BroadcastAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs(::FlexFlow::CastAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::ConcatAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::Conv2DAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::DropoutAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::ElementBinaryAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::ElementUnaryAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::ElementScalarUnaryAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::EmbeddingAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs(::FlexFlow::FlatAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::GatherAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::InputAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::LayerNormAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::LinearAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::MultiHeadAttentionAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs(::FlexFlow::NoopAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::Pool2DAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::ReduceAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::ReverseAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::ReshapeAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::SplitAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::SoftmaxAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs(::FlexFlow::TopKAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::TransposeAttrs const &v) + : raw_variant(v) {} +bool ComputationGraphOpAttrs::operator==( + ComputationGraphOpAttrs const &other) const { + return this->raw_variant == other.raw_variant; +} +bool ComputationGraphOpAttrs::operator!=( + ComputationGraphOpAttrs const &other) const { + return this->raw_variant != other.raw_variant; +} +bool ComputationGraphOpAttrs::operator<( + ComputationGraphOpAttrs const &other) const { + return this->raw_variant < other.raw_variant; +} +bool ComputationGraphOpAttrs::operator>( + ComputationGraphOpAttrs const &other) const { + return this->raw_variant > other.raw_variant; +} +bool ComputationGraphOpAttrs::operator<=( + ComputationGraphOpAttrs const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool ComputationGraphOpAttrs::operator>=( + ComputationGraphOpAttrs const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::ComputationGraphOpAttrs>::operator()( + ::FlexFlow::ComputationGraphOpAttrs const &x) const { + return std::hash>{}(x.raw_variant); +} +} // namespace std +namespace nlohmann { +::FlexFlow::ComputationGraphOpAttrs + adl_serializer<::FlexFlow::ComputationGraphOpAttrs>::from_json( + json const &j) { + std::string key = j.at("type").template get(); + if (key == "batch_matmul") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::BatchMatmulAttrs>()}; + } else if (key == "batch_norm") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::BatchNormAttrs>()}; + } else if (key == "broadcast") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::BroadcastAttrs>()}; + } else if (key == "cast") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::CastAttrs>()}; + } else if (key == "concat") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::ConcatAttrs>()}; + } else if (key == "conv2d") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::Conv2DAttrs>()}; + } else if (key == "dropout") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::DropoutAttrs>()}; + } else if (key == "element_binary") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::ElementBinaryAttrs>()}; + } else if (key == "element_unary") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::ElementUnaryAttrs>()}; + } else if (key == "element_scalar_unary") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::ElementScalarUnaryAttrs>()}; + } else if (key == "embedding") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::EmbeddingAttrs>()}; + } else if (key == "flat") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::FlatAttrs>()}; + } else if (key == "gather") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::GatherAttrs>()}; + } else if (key == "input") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::InputAttrs>()}; + } else if (key == "layer_norm") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::LayerNormAttrs>()}; + } else if (key == "linear") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::LinearAttrs>()}; + } else if (key == "multi_head_attention") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::MultiHeadAttentionAttrs>()}; + } else if (key == "noop") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::NoopAttrs>()}; + } else if (key == "pool2d") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::Pool2DAttrs>()}; + } else if (key == "reduce") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::ReduceAttrs>()}; + } else if (key == "reverse") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::ReverseAttrs>()}; + } else if (key == "reshape") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::ReshapeAttrs>()}; + } else if (key == "split") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::SplitAttrs>()}; + } else if (key == "softmax") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::SoftmaxAttrs>()}; + } else if (key == "topk") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::TopKAttrs>()}; + } else if (key == "transpose") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::TransposeAttrs>()}; + } else { + throw std::runtime_error(fmt::format("Unknown type key {}", key)); + } +} +void adl_serializer<::FlexFlow::ComputationGraphOpAttrs>::to_json( + json &j, ::FlexFlow::ComputationGraphOpAttrs const &x) { + j["__type"] = "ComputationGraphOpAttrs"; + switch (x.index()) { + case 0: { + j["type"] = "batch_matmul"; + j["value"] = x.get<::FlexFlow::BatchMatmulAttrs>(); + break; + } + case 1: { + j["type"] = "batch_norm"; + j["value"] = x.get<::FlexFlow::BatchNormAttrs>(); + break; + } + case 2: { + j["type"] = "broadcast"; + j["value"] = x.get<::FlexFlow::BroadcastAttrs>(); + break; + } + case 3: { + j["type"] = "cast"; + j["value"] = x.get<::FlexFlow::CastAttrs>(); + break; + } + case 4: { + j["type"] = "concat"; + j["value"] = x.get<::FlexFlow::ConcatAttrs>(); + break; + } + case 5: { + j["type"] = "conv2d"; + j["value"] = x.get<::FlexFlow::Conv2DAttrs>(); + break; + } + case 6: { + j["type"] = "dropout"; + j["value"] = x.get<::FlexFlow::DropoutAttrs>(); + break; + } + case 7: { + j["type"] = "element_binary"; + j["value"] = x.get<::FlexFlow::ElementBinaryAttrs>(); + break; + } + case 8: { + j["type"] = "element_unary"; + j["value"] = x.get<::FlexFlow::ElementUnaryAttrs>(); + break; + } + case 9: { + j["type"] = "element_scalar_unary"; + j["value"] = x.get<::FlexFlow::ElementScalarUnaryAttrs>(); + break; + } + case 10: { + j["type"] = "embedding"; + j["value"] = x.get<::FlexFlow::EmbeddingAttrs>(); + break; + } + case 11: { + j["type"] = "flat"; + j["value"] = x.get<::FlexFlow::FlatAttrs>(); + break; + } + case 12: { + j["type"] = "gather"; + j["value"] = x.get<::FlexFlow::GatherAttrs>(); + break; + } + case 13: { + j["type"] = "input"; + j["value"] = x.get<::FlexFlow::InputAttrs>(); + break; + } + case 14: { + j["type"] = "layer_norm"; + j["value"] = x.get<::FlexFlow::LayerNormAttrs>(); + break; + } + case 15: { + j["type"] = "linear"; + j["value"] = x.get<::FlexFlow::LinearAttrs>(); + break; + } + case 16: { + j["type"] = "multi_head_attention"; + j["value"] = x.get<::FlexFlow::MultiHeadAttentionAttrs>(); + break; + } + case 17: { + j["type"] = "noop"; + j["value"] = x.get<::FlexFlow::NoopAttrs>(); + break; + } + case 18: { + j["type"] = "pool2d"; + j["value"] = x.get<::FlexFlow::Pool2DAttrs>(); + break; + } + case 19: { + j["type"] = "reduce"; + j["value"] = x.get<::FlexFlow::ReduceAttrs>(); + break; + } + case 20: { + j["type"] = "reverse"; + j["value"] = x.get<::FlexFlow::ReverseAttrs>(); + break; + } + case 21: { + j["type"] = "reshape"; + j["value"] = x.get<::FlexFlow::ReshapeAttrs>(); + break; + } + case 22: { + j["type"] = "split"; + j["value"] = x.get<::FlexFlow::SplitAttrs>(); + break; + } + case 23: { + j["type"] = "softmax"; + j["value"] = x.get<::FlexFlow::SoftmaxAttrs>(); + break; + } + case 24: { + j["type"] = "topk"; + j["value"] = x.get<::FlexFlow::TopKAttrs>(); + break; + } + case 25: { + j["type"] = "transpose"; + j["value"] = x.get<::FlexFlow::TransposeAttrs>(); + break; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type ComputationGraphOpAttrs", x.index())); + } + } +} +} // namespace nlohmann +namespace FlexFlow { +std::string format_as(::FlexFlow::ComputationGraphOpAttrs const &x) { + std::ostringstream oss; + switch (x.index()) { + case 0: { + oss << ""; + break; + } + case 1: { + oss << ""; + break; + } + case 2: { + oss << ""; + break; + } + case 3: { + oss << ""; + break; + } + case 4: { + oss << ""; + break; + } + case 5: { + oss << ""; + break; + } + case 6: { + oss << ""; + break; + } + case 7: { + oss << ""; + break; + } + case 8: { + oss << ""; + break; + } + case 9: { + oss << ""; + break; + } + case 10: { + oss << ""; + break; + } + case 11: { + oss << ""; + break; + } + case 12: { + oss << ""; + break; + } + case 13: { + oss << ""; + break; + } + case 14: { + oss << ""; + break; + } + case 15: { + oss << ""; + break; + } + case 16: { + oss << ""; + break; + } + case 17: { + oss << ""; + break; + } + case 18: { + oss << ""; + break; + } + case 19: { + oss << ""; + break; + } + case 20: { + oss << ""; + break; + } + case 21: { + oss << ""; + break; + } + case 22: { + oss << ""; + break; + } + case 23: { + oss << ""; + break; + } + case 24: { + oss << ""; + break; + } + case 25: { + oss << ""; + break; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type ComputationGraphOpAttrs", x.index())); + break; + } + } + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::ComputationGraphOpAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/attention.cc b/lib/op-attrs/src/op-attrs/ops/attention.cc index 9fb884db43..bc7ea3f57c 100644 --- a/lib/op-attrs/src/op-attrs/ops/attention.cc +++ b/lib/op-attrs/src/op-attrs/ops/attention.cc @@ -81,9 +81,10 @@ ParallelTensorShape get_output_shape(MultiHeadAttentionAttrs const &attrs, ParallelTensorShape const &query_shape, ParallelTensorShape const &key_shape, ParallelTensorShape const &value_shape) { - ParallelTensorShape output_shape = query_shape; - dim_at_idx(output_shape, ff_dim_t(num_dims(output_shape) - 1)).size = attrs.embed_dim; - return output_shape; + NOT_IMPLEMENTED(); + /* ParallelTensorShape output_shape = query_shape; */ + /* dim_at_idx(output_shape, ff_dim_t(num_dims(output_shape) - 1)).size = attrs.embed_dim; */ + /* return output_shape; */ } TensorShape get_output_shape(MultiHeadAttentionAttrs const &attrs, diff --git a/lib/op-attrs/src/cast.cc b/lib/op-attrs/src/op-attrs/ops/cast.cc similarity index 100% rename from lib/op-attrs/src/cast.cc rename to lib/op-attrs/src/op-attrs/ops/cast.cc diff --git a/lib/op-attrs/src/combine.cc b/lib/op-attrs/src/op-attrs/ops/combine.cc similarity index 100% rename from lib/op-attrs/src/combine.cc rename to lib/op-attrs/src/op-attrs/ops/combine.cc diff --git a/lib/op-attrs/src/concat.cc b/lib/op-attrs/src/op-attrs/ops/concat.cc similarity index 100% rename from lib/op-attrs/src/concat.cc rename to lib/op-attrs/src/op-attrs/ops/concat.cc diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc new file mode 100644 index 0000000000..c45c77672c --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc @@ -0,0 +1,203 @@ +#include "op-attrs/ops/conv_2d.h" +#include "conv_2d_input_shape.dtg.h" +#include "conv_2d_parallel_input_shape.dtg.h" + +namespace FlexFlow { + +static size_t as_size_t(int x) { + assert (x >= 0); + return static_cast(x); +} + +static Conv2DInputShape parse_input_shape(TensorShape const &input) { + assert(num_dims(input) == 4); + + size_t num_samples = dim_at_idx(input, ff_dim_t{0}); + size_t in_channels = dim_at_idx(input, ff_dim_t{1}); + size_t in_height = dim_at_idx(input, ff_dim_t{2}); + size_t in_width = dim_at_idx(input, ff_dim_t{3}); + + return Conv2DInputShape{ + num_samples, + in_channels, + in_height, + in_width, + input.data_type, + }; +} + +static Conv2DParallelInputShape parse_parallel_input_shape(ParallelTensorShape const &input) { + assert(num_shard_dims(input) == 4); + + ShardParallelDim sample_dim = shard_dim_at_idx(input, ff_dim_t{0}); + ShardParallelDim channel_dim = shard_dim_at_idx(input, ff_dim_t{1}); + ShardParallelDim height_dim = shard_dim_at_idx(input, ff_dim_t{2}); + ShardParallelDim width_dim = shard_dim_at_idx(input, ff_dim_t{3}); + + return Conv2DParallelInputShape{ + sample_dim, + channel_dim, + height_dim, + width_dim, + input.dims.replica_dims.sum_degree, + input.dims.replica_dims.discard_copy_degree, + input.data_type, + }; +} + +TensorShape get_kernel_shape(Conv2DAttrs const &attrs, TensorShape const &raw_input_shape) { + assert (attrs.groups == 1); // TODO(@lockshaw): currently not supported + Conv2DInputShape input = parse_input_shape(raw_input_shape); + + return TensorShape{ + TensorDims{ + FFOrdered{ + as_size_t(attrs.out_channels), + input.num_channels, + as_size_t(attrs.kernel_h), + as_size_t(attrs.kernel_w), + } + }, + input.datatype, + }; + + NOT_IMPLEMENTED(); +} + +TensorShape get_bias_shape(Conv2DAttrs const &attrs, TensorShape const &raw_input_shape) { + assert (attrs.groups == 1); // TODO(@lockshaw): currently not supported + Conv2DInputShape input = parse_input_shape(raw_input_shape); + + return TensorShape{ + TensorDims{ + FFOrdered{ + as_size_t(attrs.out_channels) + }, + }, + input.datatype, + }; +} + +TensorShape get_output_shape(Conv2DAttrs const &attrs, TensorShape const &raw_input_shape) { + assert (attrs.groups == 1); // TODO(@lockshaw): currently not supported + Conv2DInputShape input = parse_input_shape(raw_input_shape); + + size_t out_height = (input.height - (2 * attrs.padding_h) - (attrs.kernel_h - 1)) / attrs.stride_h; + size_t out_width = (input.width - (2 * attrs.padding_w) - (attrs.kernel_w - 1)) / attrs.stride_w; + + assert (attrs.out_channels > 0); + + return TensorShape{ + TensorDims{ + FFOrdered{ + input.num_samples, + as_size_t(attrs.out_channels), + out_height, + out_width, + } + }, + input.datatype + }; +} + +ParallelTensorShape get_kernel_shape(Conv2DAttrs const &attrs, ParallelTensorShape const &raw_input_shape) { + assert (attrs.groups == 1); // TODO(@lockshaw): currently not supported + Conv2DParallelInputShape input = parse_parallel_input_shape(raw_input_shape); + + ShardParallelDim output_channels_dim = {as_size_t(attrs.out_channels), input.discard_copy_reduction_degree}; + ShardParallelDim input_channels_dim = {as_size_t(input.channel_dim.size), input.channel_dim.degree}; + ShardParallelDim kernel_height_dim = {as_size_t(attrs.kernel_h), 1}; + ShardParallelDim kernel_width_dim = {as_size_t(attrs.kernel_w), 1}; + + int sum_degree = 1; + int discard_copy_degree = input.height_dim.degree * input.width_dim.degree * input.sum_reduction_degree; + + ParallelTensorShape result = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + output_channels_dim, + input_channels_dim, + kernel_height_dim, + kernel_width_dim, + }, + ReplicaParallelDimSet{ + sum_degree, + discard_copy_degree, + }, + }, + input.datatype, + }; + + assert (total_parallel_degree(result.dims) == total_parallel_degree(raw_input_shape.dims)); + + return result; +} + +ParallelTensorShape get_bias_shape(Conv2DAttrs const &attrs, ParallelTensorShape const &raw_input_shape) { + assert (attrs.groups == 1); // TODO(@lockshaw): currently not supported + Conv2DParallelInputShape input = parse_parallel_input_shape(raw_input_shape); + + ShardParallelDim output_channels_dim = {attrs.out_channels, input.discard_copy_reduction_degree}; + + int sum_degree = 1; + int discard_copy_degree = input.height_dim.degree * input.width_dim.degree * input.sum_reduction_degree * input.channel_dim.degree; + + ParallelTensorShape result = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + output_channels_dim, + }, + ReplicaParallelDimSet{ + sum_degree, + discard_copy_degree, + }, + }, + input.datatype, + }; + + assert (total_parallel_degree(result.dims) == total_parallel_degree(raw_input_shape.dims)); + + return result; +} + +ParallelTensorShape get_output_shape(Conv2DAttrs const &attrs, ParallelTensorShape const &raw_input_shape) { + assert (attrs.groups == 1); // TODO(@lockshaw): currently not supported + Conv2DParallelInputShape input = parse_parallel_input_shape(raw_input_shape); + + TensorShape unpar_output_shape = get_output_shape(attrs, get_reduced_shape(raw_input_shape)); + + size_t num_samples = dim_at_idx(unpar_output_shape, ff_dim_t{0}); + size_t num_channels = dim_at_idx(unpar_output_shape, ff_dim_t{1}); + size_t height = dim_at_idx(unpar_output_shape, ff_dim_t{2}); + size_t width = dim_at_idx(unpar_output_shape, ff_dim_t{3}); + + ShardParallelDim sample_dim = {num_samples, input.sample_dim.degree}; + ShardParallelDim channel_dim = {num_channels, input.discard_copy_reduction_degree}; + ShardParallelDim height_dim = {height, input.height_dim.degree}; + ShardParallelDim width_dim = {width, input.width_dim.degree}; + + int sum_degree = input.channel_dim.degree * input.sum_reduction_degree; + int discard_copy_degree = 1; + + ParallelTensorShape result = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + sample_dim, + channel_dim, + height_dim, + width_dim, + }, + ReplicaParallelDimSet{ + sum_degree, + discard_copy_degree, + }, + }, + input.datatype, + }; + + assert (total_parallel_degree(result.dims) == total_parallel_degree(raw_input_shape.dims)); + + return result; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d_input_shape.dtg.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d_input_shape.dtg.cc new file mode 100644 index 0000000000..47f86afc53 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d_input_shape.dtg.cc @@ -0,0 +1,157 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/src/op-attrs/ops/conv_2d_input_shape.struct.toml +/* proj-data +{ + "generated_from": "51911f58c134d55b2d0245444acbae53" +} +*/ + +#include "op-attrs/ops/conv_2d_input_shape.dtg.h" + +#include "op-attrs/datatype.dtg.h" +#include +#include + +namespace FlexFlow { +Conv2DInputShape::Conv2DInputShape(size_t const &num_samples, + size_t const &num_channels, + size_t const &height, + size_t const &width, + ::FlexFlow::DataType const &datatype) + : num_samples(num_samples), num_channels(num_channels), height(height), + width(width), datatype(datatype) {} +bool Conv2DInputShape::operator==(Conv2DInputShape const &other) const { + return std::tie(this->num_samples, + this->num_channels, + this->height, + this->width, + this->datatype) == std::tie(other.num_samples, + other.num_channels, + other.height, + other.width, + other.datatype); +} +bool Conv2DInputShape::operator!=(Conv2DInputShape const &other) const { + return std::tie(this->num_samples, + this->num_channels, + this->height, + this->width, + this->datatype) != std::tie(other.num_samples, + other.num_channels, + other.height, + other.width, + other.datatype); +} +bool Conv2DInputShape::operator<(Conv2DInputShape const &other) const { + return std::tie(this->num_samples, + this->num_channels, + this->height, + this->width, + this->datatype) < std::tie(other.num_samples, + other.num_channels, + other.height, + other.width, + other.datatype); +} +bool Conv2DInputShape::operator>(Conv2DInputShape const &other) const { + return std::tie(this->num_samples, + this->num_channels, + this->height, + this->width, + this->datatype) > std::tie(other.num_samples, + other.num_channels, + other.height, + other.width, + other.datatype); +} +bool Conv2DInputShape::operator<=(Conv2DInputShape const &other) const { + return std::tie(this->num_samples, + this->num_channels, + this->height, + this->width, + this->datatype) <= std::tie(other.num_samples, + other.num_channels, + other.height, + other.width, + other.datatype); +} +bool Conv2DInputShape::operator>=(Conv2DInputShape const &other) const { + return std::tie(this->num_samples, + this->num_channels, + this->height, + this->width, + this->datatype) >= std::tie(other.num_samples, + other.num_channels, + other.height, + other.width, + other.datatype); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::Conv2DInputShape const &x) const { + size_t result = 0; + result ^= std::hash{}(x.num_samples) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.num_channels) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.height) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= + std::hash{}(x.width) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::DataType>{}(x.datatype) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::Conv2DInputShape + adl_serializer::from_json(json const &j) { + return {j.at("num_samples").template get(), + j.at("num_channels").template get(), + j.at("height").template get(), + j.at("width").template get(), + j.at("datatype").template get<::FlexFlow::DataType>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::Conv2DInputShape const &v) { + j["__type"] = "Conv2DInputShape"; + j["num_samples"] = v.num_samples; + j["num_channels"] = v.num_channels; + j["height"] = v.height; + j["width"] = v.width; + j["datatype"] = v.datatype; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary<::FlexFlow::DataType>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(Conv2DInputShape const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, Conv2DInputShape const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d_input_shape.dtg.h b/lib/op-attrs/src/op-attrs/ops/conv_2d_input_shape.dtg.h new file mode 100644 index 0000000000..92c6f57e73 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d_input_shape.dtg.h @@ -0,0 +1,72 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/src/op-attrs/ops/conv_2d_input_shape.struct.toml +/* proj-data +{ + "generated_from": "51911f58c134d55b2d0245444acbae53" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_SRC_OP_ATTRS_OPS_CONV_2D_INPUT_SHAPE_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_SRC_OP_ATTRS_OPS_CONV_2D_INPUT_SHAPE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/datatype.dtg.h" +#include "rapidcheck.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct Conv2DInputShape { + Conv2DInputShape() = delete; + Conv2DInputShape(size_t const &num_samples, + size_t const &num_channels, + size_t const &height, + size_t const &width, + ::FlexFlow::DataType const &datatype); + + bool operator==(Conv2DInputShape const &) const; + bool operator!=(Conv2DInputShape const &) const; + bool operator<(Conv2DInputShape const &) const; + bool operator>(Conv2DInputShape const &) const; + bool operator<=(Conv2DInputShape const &) const; + bool operator>=(Conv2DInputShape const &) const; + size_t num_samples; + size_t num_channels; + size_t height; + size_t width; + ::FlexFlow::DataType datatype; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::Conv2DInputShape const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::Conv2DInputShape from_json(json const &); + static void to_json(json &, FlexFlow::Conv2DInputShape const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(Conv2DInputShape const &); +std::ostream &operator<<(std::ostream &, Conv2DInputShape const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_SRC_OP_ATTRS_OPS_CONV_2D_INPUT_SHAPE_DTG_H diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d_input_shape.struct.toml b/lib/op-attrs/src/op-attrs/ops/conv_2d_input_shape.struct.toml new file mode 100644 index 0000000000..77e8c51244 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d_input_shape.struct.toml @@ -0,0 +1,35 @@ +namespace = "FlexFlow" +name = "Conv2DInputShape" +features = [ + "eq", + "ord", + "hash", + "fmt", + "rapidcheck", + "json", +] + +includes = [ + "", + "op-attrs/datatype.dtg.h", +] + +[[fields]] +name = "num_samples" +type = "size_t" + +[[fields]] +name = "num_channels" +type = "size_t" + +[[fields]] +name = "height" +type = "size_t" + +[[fields]] +name = "width" +type = "size_t" + +[[fields]] +name = "datatype" +type = "::FlexFlow::DataType" diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d_parallel_input_shape.dtg.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d_parallel_input_shape.dtg.cc new file mode 100644 index 0000000000..46e8061e7e --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d_parallel_input_shape.dtg.cc @@ -0,0 +1,211 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/src/op-attrs/ops/conv_2d_parallel_input_shape.struct.toml +/* proj-data +{ + "generated_from": "d80394bdc90f843372760310b6d17a22" +} +*/ + +#include "op-attrs/ops/conv_2d_parallel_input_shape.dtg.h" + +#include "op-attrs/datatype.dtg.h" +#include "op-attrs/shard_parallel_dim.dtg.h" +#include + +namespace FlexFlow { +Conv2DParallelInputShape::Conv2DParallelInputShape( + ::FlexFlow::ShardParallelDim const &sample_dim, + ::FlexFlow::ShardParallelDim const &channel_dim, + ::FlexFlow::ShardParallelDim const &height_dim, + ::FlexFlow::ShardParallelDim const &width_dim, + int const &sum_reduction_degree, + int const &discard_copy_reduction_degree, + ::FlexFlow::DataType const &datatype) + : sample_dim(sample_dim), channel_dim(channel_dim), height_dim(height_dim), + width_dim(width_dim), sum_reduction_degree(sum_reduction_degree), + discard_copy_reduction_degree(discard_copy_reduction_degree), + datatype(datatype) {} +bool Conv2DParallelInputShape::operator==( + Conv2DParallelInputShape const &other) const { + return std::tie(this->sample_dim, + this->channel_dim, + this->height_dim, + this->width_dim, + this->sum_reduction_degree, + this->discard_copy_reduction_degree, + this->datatype) == + std::tie(other.sample_dim, + other.channel_dim, + other.height_dim, + other.width_dim, + other.sum_reduction_degree, + other.discard_copy_reduction_degree, + other.datatype); +} +bool Conv2DParallelInputShape::operator!=( + Conv2DParallelInputShape const &other) const { + return std::tie(this->sample_dim, + this->channel_dim, + this->height_dim, + this->width_dim, + this->sum_reduction_degree, + this->discard_copy_reduction_degree, + this->datatype) != + std::tie(other.sample_dim, + other.channel_dim, + other.height_dim, + other.width_dim, + other.sum_reduction_degree, + other.discard_copy_reduction_degree, + other.datatype); +} +bool Conv2DParallelInputShape::operator<( + Conv2DParallelInputShape const &other) const { + return std::tie(this->sample_dim, + this->channel_dim, + this->height_dim, + this->width_dim, + this->sum_reduction_degree, + this->discard_copy_reduction_degree, + this->datatype) < + std::tie(other.sample_dim, + other.channel_dim, + other.height_dim, + other.width_dim, + other.sum_reduction_degree, + other.discard_copy_reduction_degree, + other.datatype); +} +bool Conv2DParallelInputShape::operator>( + Conv2DParallelInputShape const &other) const { + return std::tie(this->sample_dim, + this->channel_dim, + this->height_dim, + this->width_dim, + this->sum_reduction_degree, + this->discard_copy_reduction_degree, + this->datatype) > + std::tie(other.sample_dim, + other.channel_dim, + other.height_dim, + other.width_dim, + other.sum_reduction_degree, + other.discard_copy_reduction_degree, + other.datatype); +} +bool Conv2DParallelInputShape::operator<=( + Conv2DParallelInputShape const &other) const { + return std::tie(this->sample_dim, + this->channel_dim, + this->height_dim, + this->width_dim, + this->sum_reduction_degree, + this->discard_copy_reduction_degree, + this->datatype) <= + std::tie(other.sample_dim, + other.channel_dim, + other.height_dim, + other.width_dim, + other.sum_reduction_degree, + other.discard_copy_reduction_degree, + other.datatype); +} +bool Conv2DParallelInputShape::operator>=( + Conv2DParallelInputShape const &other) const { + return std::tie(this->sample_dim, + this->channel_dim, + this->height_dim, + this->width_dim, + this->sum_reduction_degree, + this->discard_copy_reduction_degree, + this->datatype) >= + std::tie(other.sample_dim, + other.channel_dim, + other.height_dim, + other.width_dim, + other.sum_reduction_degree, + other.discard_copy_reduction_degree, + other.datatype); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::Conv2DParallelInputShape const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ShardParallelDim>{}(x.sample_dim) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::ShardParallelDim>{}(x.channel_dim) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::ShardParallelDim>{}(x.height_dim) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::ShardParallelDim>{}(x.width_dim) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash{}(x.sum_reduction_degree) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash{}(x.discard_copy_reduction_degree) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::DataType>{}(x.datatype) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::Conv2DParallelInputShape + adl_serializer::from_json( + json const &j) { + return {j.at("sample_dim").template get<::FlexFlow::ShardParallelDim>(), + j.at("channel_dim").template get<::FlexFlow::ShardParallelDim>(), + j.at("height_dim").template get<::FlexFlow::ShardParallelDim>(), + j.at("width_dim").template get<::FlexFlow::ShardParallelDim>(), + j.at("sum_reduction_degree").template get(), + j.at("discard_copy_reduction_degree").template get(), + j.at("datatype").template get<::FlexFlow::DataType>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::Conv2DParallelInputShape const &v) { + j["__type"] = "Conv2DParallelInputShape"; + j["sample_dim"] = v.sample_dim; + j["channel_dim"] = v.channel_dim; + j["height_dim"] = v.height_dim; + j["width_dim"] = v.width_dim; + j["sum_reduction_degree"] = v.sum_reduction_degree; + j["discard_copy_reduction_degree"] = v.discard_copy_reduction_degree; + j["datatype"] = v.datatype; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::ShardParallelDim>(), + gen::arbitrary<::FlexFlow::ShardParallelDim>(), + gen::arbitrary<::FlexFlow::ShardParallelDim>(), + gen::arbitrary<::FlexFlow::ShardParallelDim>(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary<::FlexFlow::DataType>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(Conv2DParallelInputShape const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, Conv2DParallelInputShape const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d_parallel_input_shape.dtg.h b/lib/op-attrs/src/op-attrs/ops/conv_2d_parallel_input_shape.dtg.h new file mode 100644 index 0000000000..bae26b378b --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d_parallel_input_shape.dtg.h @@ -0,0 +1,76 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/src/op-attrs/ops/conv_2d_parallel_input_shape.struct.toml +/* proj-data +{ + "generated_from": "d80394bdc90f843372760310b6d17a22" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_SRC_OP_ATTRS_OPS_CONV_2D_PARALLEL_INPUT_SHAPE_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_SRC_OP_ATTRS_OPS_CONV_2D_PARALLEL_INPUT_SHAPE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/datatype.dtg.h" +#include "op-attrs/shard_parallel_dim.dtg.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct Conv2DParallelInputShape { + Conv2DParallelInputShape() = delete; + Conv2DParallelInputShape(::FlexFlow::ShardParallelDim const &sample_dim, + ::FlexFlow::ShardParallelDim const &channel_dim, + ::FlexFlow::ShardParallelDim const &height_dim, + ::FlexFlow::ShardParallelDim const &width_dim, + int const &sum_reduction_degree, + int const &discard_copy_reduction_degree, + ::FlexFlow::DataType const &datatype); + + bool operator==(Conv2DParallelInputShape const &) const; + bool operator!=(Conv2DParallelInputShape const &) const; + bool operator<(Conv2DParallelInputShape const &) const; + bool operator>(Conv2DParallelInputShape const &) const; + bool operator<=(Conv2DParallelInputShape const &) const; + bool operator>=(Conv2DParallelInputShape const &) const; + ::FlexFlow::ShardParallelDim sample_dim; + ::FlexFlow::ShardParallelDim channel_dim; + ::FlexFlow::ShardParallelDim height_dim; + ::FlexFlow::ShardParallelDim width_dim; + int sum_reduction_degree; + int discard_copy_reduction_degree; + ::FlexFlow::DataType datatype; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::Conv2DParallelInputShape const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::Conv2DParallelInputShape from_json(json const &); + static void to_json(json &, FlexFlow::Conv2DParallelInputShape const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(Conv2DParallelInputShape const &); +std::ostream &operator<<(std::ostream &, Conv2DParallelInputShape const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_SRC_OP_ATTRS_OPS_CONV_2D_PARALLEL_INPUT_SHAPE_DTG_H diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d_parallel_input_shape.struct.toml b/lib/op-attrs/src/op-attrs/ops/conv_2d_parallel_input_shape.struct.toml new file mode 100644 index 0000000000..68cbd878d1 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d_parallel_input_shape.struct.toml @@ -0,0 +1,43 @@ +namespace = "FlexFlow" +name = "Conv2DParallelInputShape" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/shard_parallel_dim.dtg.h", + "op-attrs/datatype.dtg.h", +] + +[[fields]] +name = "sample_dim" +type = "::FlexFlow::ShardParallelDim" + +[[fields]] +name = "channel_dim" +type = "::FlexFlow::ShardParallelDim" + +[[fields]] +name = "height_dim" +type = "::FlexFlow::ShardParallelDim" + +[[fields]] +name = "width_dim" +type = "::FlexFlow::ShardParallelDim" + +[[fields]] +name = "sum_reduction_degree" +type = "int" + +[[fields]] +name = "discard_copy_reduction_degree" +type = "int" + +[[fields]] +name = "datatype" +type = "::FlexFlow::DataType" diff --git a/lib/op-attrs/src/element_binary.cc b/lib/op-attrs/src/op-attrs/ops/element_binary.cc similarity index 100% rename from lib/op-attrs/src/element_binary.cc rename to lib/op-attrs/src/op-attrs/ops/element_binary.cc diff --git a/lib/op-attrs/src/element_unary.cc b/lib/op-attrs/src/op-attrs/ops/element_unary.cc similarity index 100% rename from lib/op-attrs/src/element_unary.cc rename to lib/op-attrs/src/op-attrs/ops/element_unary.cc diff --git a/lib/op-attrs/src/flat.cc b/lib/op-attrs/src/op-attrs/ops/flat.cc similarity index 100% rename from lib/op-attrs/src/flat.cc rename to lib/op-attrs/src/op-attrs/ops/flat.cc diff --git a/lib/op-attrs/src/gather.cc b/lib/op-attrs/src/op-attrs/ops/gather.cc similarity index 100% rename from lib/op-attrs/src/gather.cc rename to lib/op-attrs/src/op-attrs/ops/gather.cc diff --git a/lib/op-attrs/src/op-attrs/ops/parallel_attention_inputs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/parallel_attention_inputs.dtg.cc index e1837a7360..06c891d4de 100644 --- a/lib/op-attrs/src/op-attrs/ops/parallel_attention_inputs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/parallel_attention_inputs.dtg.cc @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.struct.toml /* proj-data { - "generated_from": "722d92014b31bffcd5ad45eda476d8b3" + "generated_from": "8d1e2a2d3852bfb59d8668d14d52c958" } */ @@ -28,26 +28,6 @@ bool ParallelMultiHeadAttentionInputs::operator!=( return std::tie(this->query, this->key, this->value) != std::tie(other.query, other.key, other.value); } -bool ParallelMultiHeadAttentionInputs::operator<( - ParallelMultiHeadAttentionInputs const &other) const { - return std::tie(this->query, this->key, this->value) < - std::tie(other.query, other.key, other.value); -} -bool ParallelMultiHeadAttentionInputs::operator>( - ParallelMultiHeadAttentionInputs const &other) const { - return std::tie(this->query, this->key, this->value) > - std::tie(other.query, other.key, other.value); -} -bool ParallelMultiHeadAttentionInputs::operator<=( - ParallelMultiHeadAttentionInputs const &other) const { - return std::tie(this->query, this->key, this->value) <= - std::tie(other.query, other.key, other.value); -} -bool ParallelMultiHeadAttentionInputs::operator>=( - ParallelMultiHeadAttentionInputs const &other) const { - return std::tie(this->query, this->key, this->value) >= - std::tie(other.query, other.key, other.value); -} } // namespace FlexFlow namespace std { diff --git a/lib/op-attrs/src/op-attrs/ops/pool_2d.cc b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc index a9ca71a060..5f54f2f2d6 100644 --- a/lib/op-attrs/src/op-attrs/ops/pool_2d.cc +++ b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc @@ -7,3 +7,51 @@ ParallelTensorShape get_output_shape(Pool2DAttrs const &, ParallelTensorShape co } } // namespace FlexFlow + +/* +#include "op-attrs/ops/pool_2d.h" +#include "parallel_dim_mapping_record.h" +#include "parallel_dim_mapping_record_solver.h" + +namespace FlexFlow { + +namespace Input { +constexpr int NUMDIM = 5, WIDTH = 0, HEIGHT = 1, CHANNEL = 2, SAMPLE = 3, + REPLICA = 4; +}; + +namespace Output { +constexpr int NUMDIM = 5, WIDTH = 0, HEIGHT = 1, CHANNEL = 2, SAMPLE = 3, + REPLICA = 4; +}; + +bool Pool2DAttrs::is_valid(ParallelTensorShape const &input) const { + ParallelTensorShape output_shape = this->calculate_output_shape(input); + + return output_shape.is_valid() && (input.at(Input::REPLICA).degree == 1); +} + +static std::vector + construct_mappings(ParallelTensorShape const &input_shape) { + auto const outputMappings = construct_output_parallel_dims({ + {Input::REPLICA, MappingOperation::PARTITION, Output::REPLICA}, + {Input::SAMPLE, MappingOperation::PARTITION, Output::SAMPLE}, + {Input::CHANNEL, MappingOperation::PARTITION, Output::CHANNEL}, + {Input::HEIGHT, MappingOperation::PARTITION, Output::HEIGHT}, + {Input::WIDTH, MappingOperation::PARTITION, Output::WIDTH}, + }); + + return outputMappings; +} + +static ParallelDimMappingSolution + solve_mappings(ParallelTensorShape const &input) { + return solve_parallel_dim_mappings(construct_mappings(input), {input}, 0, 1); +} + +ParallelTensorShape Pool2DAttrs::calculate_output_shape(ParallelTensorShape const &input) const { + return solve_mappings(input).output_shapes.at(0); +} + +} // namespace FlexFlow +*/ diff --git a/lib/op-attrs/src/op-attrs/ops/reduction.cc b/lib/op-attrs/src/op-attrs/ops/reduction.cc index 2396772a94..37b60dbc60 100644 --- a/lib/op-attrs/src/op-attrs/ops/reduction.cc +++ b/lib/op-attrs/src/op-attrs/ops/reduction.cc @@ -6,4 +6,12 @@ ParallelTensorShape get_output_shape(ReductionAttrs const &attrs, ParallelTensor NOT_IMPLEMENTED(); } +/* ParallelTensorShape ReductionAttrs::output_shape(ParallelTensorShape const + * &input_shape) const { */ +/* ParallelTensorShape output = input_shape; */ +/* output.at(this->reduction_legion_dim).degree /= this->reduction_degree; */ +/* output.at(this->reduction_legion_dim).size /= this->reduction_degree; */ +/* return output; */ +/* } */ + } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/repartition.cc b/lib/op-attrs/src/op-attrs/ops/repartition.cc index 45b54df80b..f35cbb35c7 100644 --- a/lib/op-attrs/src/op-attrs/ops/repartition.cc +++ b/lib/op-attrs/src/op-attrs/ops/repartition.cc @@ -6,4 +6,10 @@ ParallelTensorShape get_output_shape(RepartitionAttrs const &, ParallelTensorSha NOT_IMPLEMENTED(); } +/* bool RepartitionAttrs::is_valid(ParallelTensorShape const &input_shape) const + * { */ +/* ParallelDim dim = input_shape.at(this->repartition_legion_dim); */ +/* return (dim.size % this->repartition_degree * dim.degree == 0); */ +/* } */ + } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/parallel_dim.dtg.cc b/lib/op-attrs/src/op-attrs/parallel_dim.dtg.cc index df88de73ff..c2016c9f8f 100644 --- a/lib/op-attrs/src/op-attrs/parallel_dim.dtg.cc +++ b/lib/op-attrs/src/op-attrs/parallel_dim.dtg.cc @@ -1,95 +1,108 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify -// lib/op-attrs/include/op-attrs/parallel_dim.struct.toml +// lib/op-attrs/include/op-attrs/parallel_dim.variant.toml /* proj-data { - "generated_from": "186bedde7826c7a3d00343ed63ab9971" + "generated_from": "5550fc7ad51892b3411ef274c76e7d85" } */ #include "op-attrs/parallel_dim.dtg.h" +#include "fmt/format.h" #include +#include namespace FlexFlow { -ParallelDim::ParallelDim(size_t const &size, - int const °ree, - bool const &is_replica_dim) - : size(size), degree(degree), is_replica_dim(is_replica_dim) {} +ParallelDim::ParallelDim(::FlexFlow::ShardParallelDim const &v) + : raw_variant(v) {} +ParallelDim::ParallelDim(::FlexFlow::ReplicaParallelDim const &v) + : raw_variant(v) {} bool ParallelDim::operator==(ParallelDim const &other) const { - return std::tie(this->size, this->degree, this->is_replica_dim) == - std::tie(other.size, other.degree, other.is_replica_dim); + return this->raw_variant == other.raw_variant; } bool ParallelDim::operator!=(ParallelDim const &other) const { - return std::tie(this->size, this->degree, this->is_replica_dim) != - std::tie(other.size, other.degree, other.is_replica_dim); + return this->raw_variant != other.raw_variant; } bool ParallelDim::operator<(ParallelDim const &other) const { - return std::tie(this->size, this->degree, this->is_replica_dim) < - std::tie(other.size, other.degree, other.is_replica_dim); + return this->raw_variant < other.raw_variant; } bool ParallelDim::operator>(ParallelDim const &other) const { - return std::tie(this->size, this->degree, this->is_replica_dim) > - std::tie(other.size, other.degree, other.is_replica_dim); + return this->raw_variant > other.raw_variant; } bool ParallelDim::operator<=(ParallelDim const &other) const { - return std::tie(this->size, this->degree, this->is_replica_dim) <= - std::tie(other.size, other.degree, other.is_replica_dim); + return this->raw_variant <= other.raw_variant; } bool ParallelDim::operator>=(ParallelDim const &other) const { - return std::tie(this->size, this->degree, this->is_replica_dim) >= - std::tie(other.size, other.degree, other.is_replica_dim); + return this->raw_variant >= other.raw_variant; } } // namespace FlexFlow - namespace std { -size_t hash::operator()( - FlexFlow::ParallelDim const &x) const { - size_t result = 0; - result ^= - std::hash{}(x.size) + 0x9e3779b9 + (result << 6) + (result >> 2); - result ^= - std::hash{}(x.degree) + 0x9e3779b9 + (result << 6) + (result >> 2); - result ^= std::hash{}(x.is_replica_dim) + 0x9e3779b9 + (result << 6) + - (result >> 2); - return result; +size_t hash<::FlexFlow::ParallelDim>::operator()( + ::FlexFlow::ParallelDim const &x) const { + return std::hash>{}( + x.raw_variant); } } // namespace std - namespace nlohmann { -FlexFlow::ParallelDim - adl_serializer::from_json(json const &j) { - return {j.at("size").template get(), - j.at("degree").template get(), - j.at("is_replica_dim").template get()}; +::FlexFlow::ParallelDim + adl_serializer<::FlexFlow::ParallelDim>::from_json(json const &j) { + std::string key = j.at("type").template get(); + if (key == "shard_dim") { + return ::FlexFlow::ParallelDim{ + j.at("value").template get<::FlexFlow::ShardParallelDim>()}; + } else if (key == "replica_dim") { + return ::FlexFlow::ParallelDim{ + j.at("value").template get<::FlexFlow::ReplicaParallelDim>()}; + } else { + throw std::runtime_error(fmt::format("Unknown type key {}", key)); + } } -void adl_serializer::to_json( - json &j, FlexFlow::ParallelDim const &v) { +void adl_serializer<::FlexFlow::ParallelDim>::to_json( + json &j, ::FlexFlow::ParallelDim const &x) { j["__type"] = "ParallelDim"; - j["size"] = v.size; - j["degree"] = v.degree; - j["is_replica_dim"] = v.is_replica_dim; + switch (x.index()) { + case 0: { + j["type"] = "shard_dim"; + j["value"] = x.get<::FlexFlow::ShardParallelDim>(); + break; + } + case 1: { + j["type"] = "replica_dim"; + j["value"] = x.get<::FlexFlow::ReplicaParallelDim>(); + break; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type ParallelDim", x.index())); + } + } } } // namespace nlohmann - -namespace rc { -Gen Arbitrary::arbitrary() { - return gen::construct( - gen::arbitrary(), gen::arbitrary(), gen::arbitrary()); -} -} // namespace rc - namespace FlexFlow { -std::string format_as(ParallelDim const &x) { +std::string format_as(::FlexFlow::ParallelDim const &x) { std::ostringstream oss; - oss << ""; + switch (x.index()) { + case 0: { + oss << ""; + break; + } + case 1: { + oss << ""; + break; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type ParallelDim", x.index())); + break; + } + } return oss.str(); } -std::ostream &operator<<(std::ostream &s, ParallelDim const &x) { +std::ostream &operator<<(std::ostream &s, ::FlexFlow::ParallelDim const &x) { return s << fmt::to_string(x); } } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc index a1fe25e0b6..531d571309 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc @@ -1,34 +1,59 @@ #include "op-attrs/parallel_tensor_dims.h" +#include "op-attrs/replica_parallel_dim_set.h" #include "utils/containers.h" +#include "op-attrs/shard_parallel_dim.h" +#include "op-attrs/replica_parallel_dim.h" namespace FlexFlow { -FFOrdered const &ff_ordered(ParallelTensorDims const &d) { - return d.ff_ordered; +FFOrdered ff_ordered_shard_dims(ParallelTensorDims const &d) { + return d.shard_dims; } -std::vector as_vector(ParallelTensorDims const &d) { - return as_vector(d.ff_ordered); +std::unordered_set replica_dims(ParallelTensorDims const &d) { + return get_replica_dims(d.replica_dims); } -int get_num_replica_dims(ParallelTensorDims const &d) { - return count(d.ff_ordered, is_replica_dim); +size_t num_shard_dims(ParallelTensorDims const &dims) { + return dims.shard_dims.size(); +} + +int total_replica_degree(ParallelTensorDims const &dims) { + return product(transform(replica_dims(dims), [](ReplicaParallelDim const &d) { return d.degree; })); +} + +int total_shard_degree(ParallelTensorDims const &dims) { + return product(transform(as_vector(dims.shard_dims), [](ShardParallelDim const &d) { return d.degree; })); +} + +int total_parallel_degree(ParallelTensorDims const &dims) { + return total_replica_degree(dims) * total_shard_degree(dims); } bool is_valid(ParallelTensorDims const &dims) { - return all_of(dims.ff_ordered, [](ParallelDim const &d) { return is_valid(d); }); + return all_of(dims.shard_dims, [](ShardParallelDim const &d) { return is_valid(d); }) + && all_of(replica_dims(dims), [](ReplicaParallelDim const &d) { return is_valid(d); }); } -size_t num_dims(ParallelTensorDims const &dims) { - return dims.ff_ordered.size(); +ShardParallelDim shard_dim_at_idx(ParallelTensorDims const &d, ff_dim_t idx) { + return d.shard_dims.at(idx); } -ParallelDim dim_at_idx(ParallelTensorDims const &d, ff_dim_t idx) { - return d.ff_ordered.at(idx); +ShardParallelDim &shard_dim_at_idx(ParallelTensorDims &d, ff_dim_t idx) { + return d.shard_dims.at(idx); } -ParallelDim &dim_at_idx(ParallelTensorDims &d, ff_dim_t idx) { - return d.ff_ordered.at(idx); +TensorDims get_piece_dims(ParallelTensorDims const &) { + NOT_IMPLEMENTED(); +} + +TensorDims get_tensor_dims_unsafe(ParallelTensorDims const &) { + NOT_IMPLEMENTED(); +} + + +TensorDims get_reduced_dims(ParallelTensorDims const &) { + NOT_IMPLEMENTED(); } } diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.dtg.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.dtg.cc index e4e8f0106a..aee6e9ab14 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.dtg.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.dtg.cc @@ -3,37 +3,32 @@ // lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml /* proj-data { - "generated_from": "b46ffa08758bdcc57a75183255248ca6" + "generated_from": "141639bdce009a1594501f33c2f25c9e" } */ #include "op-attrs/parallel_tensor_dims.dtg.h" #include "op-attrs/dim_ordered.h" -#include "op-attrs/parallel_dim.h" +#include "op-attrs/replica_parallel_dim_set.dtg.h" +#include "op-attrs/shard_parallel_dim.dtg.h" +#include "utils/fmt/pair.h" +#include "utils/fmt/unordered_map.h" #include +#include namespace FlexFlow { ParallelTensorDims::ParallelTensorDims( - ::FlexFlow::FFOrdered<::FlexFlow::ParallelDim> const &ff_ordered) - : ff_ordered(ff_ordered) {} + ::FlexFlow::FFOrdered<::FlexFlow::ShardParallelDim> const &shard_dims, + ::FlexFlow::ReplicaParallelDimSet const &replica_dims) + : shard_dims(shard_dims), replica_dims(replica_dims) {} bool ParallelTensorDims::operator==(ParallelTensorDims const &other) const { - return std::tie(this->ff_ordered) == std::tie(other.ff_ordered); + return std::tie(this->shard_dims, this->replica_dims) == + std::tie(other.shard_dims, other.replica_dims); } bool ParallelTensorDims::operator!=(ParallelTensorDims const &other) const { - return std::tie(this->ff_ordered) != std::tie(other.ff_ordered); -} -bool ParallelTensorDims::operator<(ParallelTensorDims const &other) const { - return std::tie(this->ff_ordered) < std::tie(other.ff_ordered); -} -bool ParallelTensorDims::operator>(ParallelTensorDims const &other) const { - return std::tie(this->ff_ordered) > std::tie(other.ff_ordered); -} -bool ParallelTensorDims::operator<=(ParallelTensorDims const &other) const { - return std::tie(this->ff_ordered) <= std::tie(other.ff_ordered); -} -bool ParallelTensorDims::operator>=(ParallelTensorDims const &other) const { - return std::tie(this->ff_ordered) >= std::tie(other.ff_ordered); + return std::tie(this->shard_dims, this->replica_dims) != + std::tie(other.shard_dims, other.replica_dims); } } // namespace FlexFlow @@ -41,8 +36,10 @@ namespace std { size_t hash::operator()( FlexFlow::ParallelTensorDims const &x) const { size_t result = 0; - result ^= std::hash<::FlexFlow::FFOrdered<::FlexFlow::ParallelDim>>{}( - x.ff_ordered) + + result ^= std::hash<::FlexFlow::FFOrdered<::FlexFlow::ShardParallelDim>>{}( + x.shard_dims) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::ReplicaParallelDimSet>{}(x.replica_dims) + 0x9e3779b9 + (result << 6) + (result >> 2); return result; } @@ -51,13 +48,16 @@ size_t hash::operator()( namespace nlohmann { FlexFlow::ParallelTensorDims adl_serializer::from_json(json const &j) { - return {j.at("ff_ordered") - .template get<::FlexFlow::FFOrdered<::FlexFlow::ParallelDim>>()}; + return { + j.at("shard_dims") + .template get<::FlexFlow::FFOrdered<::FlexFlow::ShardParallelDim>>(), + j.at("replica_dims").template get<::FlexFlow::ReplicaParallelDimSet>()}; } void adl_serializer::to_json( json &j, FlexFlow::ParallelTensorDims const &v) { j["__type"] = "ParallelTensorDims"; - j["ff_ordered"] = v.ff_ordered; + j["shard_dims"] = v.shard_dims; + j["replica_dims"] = v.replica_dims; } } // namespace nlohmann @@ -65,7 +65,8 @@ namespace FlexFlow { std::string format_as(ParallelTensorDims const &x) { std::ostringstream oss; oss << ""; return oss.str(); } diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc index 8d6125c369..2c5e556224 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc @@ -1,33 +1,34 @@ #include "op-attrs/parallel_tensor_shape.h" #include "utils/containers.h" #include "utils/hash-utils.h" +#include "op-attrs/tensor_dims.h" namespace FlexFlow { -int num_dims(ParallelTensorShape const &s) { - return num_dims(s.dims); +int num_shard_dims(ParallelTensorShape const &s) { + return num_shard_dims(s.dims); } -int get_num_replica_dims(ParallelTensorShape const &shape) { - return get_num_replica_dims(shape.dims); +std::unordered_set replica_dims(ParallelTensorShape const &s) { + return replica_dims(s.dims); } int get_num_replicas(ParallelTensorShape const &shape) { return product( - transform(filter(as_vector(shape.dims), is_replica_dim), - [](ParallelDim const &d) -> int { return d.degree; })); + transform(replica_dims(shape), + [](ReplicaParallelDim const &d) -> int { return d.degree; })); } bool is_valid(ParallelTensorShape const &shape) { return is_valid(shape.dims); } -ParallelDim dim_at_idx(ParallelTensorShape const &s, ff_dim_t d) { - return dim_at_idx(s.dims, d); +ShardParallelDim shard_dim_at_idx(ParallelTensorShape const &s, ff_dim_t d) { + return shard_dim_at_idx(s.dims, d); } -ParallelDim &dim_at_idx(ParallelTensorShape &s, ff_dim_t d) { - return dim_at_idx(s.dims, d); +ShardParallelDim &shard_dim_at_idx(ParallelTensorShape &s, ff_dim_t d) { + return shard_dim_at_idx(s.dims, d); } ParallelTensorShape lift_to_parallel(TensorShape const &s) { @@ -37,4 +38,12 @@ ParallelTensorShape lift_to_parallel(TensorShape const &s) { TensorShape get_tensor_shape_unsafe(ParallelTensorShape const &) { NOT_IMPLEMENTED(); } + +TensorShape get_reduced_shape(ParallelTensorShape const &s) { + return TensorShape{ + get_reduced_dims(s.dims), + s.data_type, + }; +} + } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.dtg.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.dtg.cc index 037acbf996..f990e21fa4 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.dtg.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.dtg.cc @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml /* proj-data { - "generated_from": "b2d36c9212916e66569af4e958c893f4" + "generated_from": "bc7e838003fe037b95d45cd5ab4aa16f" } */ @@ -26,22 +26,6 @@ bool ParallelTensorShape::operator!=(ParallelTensorShape const &other) const { return std::tie(this->dims, this->data_type) != std::tie(other.dims, other.data_type); } -bool ParallelTensorShape::operator<(ParallelTensorShape const &other) const { - return std::tie(this->dims, this->data_type) < - std::tie(other.dims, other.data_type); -} -bool ParallelTensorShape::operator>(ParallelTensorShape const &other) const { - return std::tie(this->dims, this->data_type) > - std::tie(other.dims, other.data_type); -} -bool ParallelTensorShape::operator<=(ParallelTensorShape const &other) const { - return std::tie(this->dims, this->data_type) <= - std::tie(other.dims, other.data_type); -} -bool ParallelTensorShape::operator>=(ParallelTensorShape const &other) const { - return std::tie(this->dims, this->data_type) >= - std::tie(other.dims, other.data_type); -} } // namespace FlexFlow namespace std { diff --git a/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc b/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc new file mode 100644 index 0000000000..cb54736a7a --- /dev/null +++ b/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc @@ -0,0 +1,12 @@ +#include "op-attrs/pcg_operator_attrs.h" + +namespace FlexFlow { + +bool is_parallel_op(PCGOperatorAttrs const &attrs) { + return ( attrs.has() + || attrs.has() + || attrs.has() + || attrs.has()); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/pcg_operator_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/pcg_operator_attrs.dtg.cc index 5d915ab437..8baa07e537 100644 --- a/lib/op-attrs/src/op-attrs/pcg_operator_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/pcg_operator_attrs.dtg.cc @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml /* proj-data { - "generated_from": "e1b5c307ae023ce6d504f605c7ef8491" + "generated_from": "cf0da4385b7554748a06ec25ccf17f2f" } */ @@ -20,6 +20,8 @@ PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::BatchNormAttrs const &v) : raw_variant(v) {} PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::CastAttrs const &v) : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::CombineAttrs const &v) + : raw_variant(v) {} PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::ConcatAttrs const &v) : raw_variant(v) {} PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::Conv2DAttrs const &v) @@ -52,6 +54,12 @@ PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::Pool2DAttrs const &v) : raw_variant(v) {} PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::ReduceAttrs const &v) : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::ReductionAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::RepartitionAttrs const &v) + : raw_variant(v) {} +PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::ReplicateAttrs const &v) + : raw_variant(v) {} PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::ReverseAttrs const &v) : raw_variant(v) {} PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::ReshapeAttrs const &v) @@ -89,6 +97,7 @@ size_t hash<::FlexFlow::PCGOperatorAttrs>::operator()( return std::hash::operator()( ::FlexFlow::NoopAttrs, ::FlexFlow::Pool2DAttrs, ::FlexFlow::ReduceAttrs, + ::FlexFlow::ReductionAttrs, + ::FlexFlow::RepartitionAttrs, + ::FlexFlow::ReplicateAttrs, ::FlexFlow::ReverseAttrs, ::FlexFlow::ReshapeAttrs, ::FlexFlow::SplitAttrs, @@ -117,79 +129,91 @@ namespace nlohmann { ::FlexFlow::PCGOperatorAttrs adl_serializer<::FlexFlow::PCGOperatorAttrs>::from_json(json const &j) { std::string key = j.at("type").template get(); - if (key == "::FlexFlow::BatchMatmulAttrs") { + if (key == "batch_matmul") { return ::FlexFlow::PCGOperatorAttrs{ j.at("value").template get<::FlexFlow::BatchMatmulAttrs>()}; - } else if (key == "::FlexFlow::BatchNormAttrs") { + } else if (key == "batch_norm") { return ::FlexFlow::PCGOperatorAttrs{ j.at("value").template get<::FlexFlow::BatchNormAttrs>()}; - } else if (key == "::FlexFlow::CastAttrs") { + } else if (key == "cast") { return ::FlexFlow::PCGOperatorAttrs{ j.at("value").template get<::FlexFlow::CastAttrs>()}; - } else if (key == "::FlexFlow::ConcatAttrs") { + } else if (key == "combine_distributed") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::CombineAttrs>()}; + } else if (key == "concat") { return ::FlexFlow::PCGOperatorAttrs{ j.at("value").template get<::FlexFlow::ConcatAttrs>()}; - } else if (key == "::FlexFlow::Conv2DAttrs") { + } else if (key == "conv2d") { return ::FlexFlow::PCGOperatorAttrs{ j.at("value").template get<::FlexFlow::Conv2DAttrs>()}; - } else if (key == "::FlexFlow::DropoutAttrs") { + } else if (key == "dropout") { return ::FlexFlow::PCGOperatorAttrs{ j.at("value").template get<::FlexFlow::DropoutAttrs>()}; - } else if (key == "::FlexFlow::ElementBinaryAttrs") { + } else if (key == "element_binary") { return ::FlexFlow::PCGOperatorAttrs{ j.at("value").template get<::FlexFlow::ElementBinaryAttrs>()}; - } else if (key == "::FlexFlow::ElementUnaryAttrs") { + } else if (key == "element_unary") { return ::FlexFlow::PCGOperatorAttrs{ j.at("value").template get<::FlexFlow::ElementUnaryAttrs>()}; - } else if (key == "::FlexFlow::ElementScalarUnaryAttrs") { + } else if (key == "element_scalar_unary") { return ::FlexFlow::PCGOperatorAttrs{ j.at("value").template get<::FlexFlow::ElementScalarUnaryAttrs>()}; - } else if (key == "::FlexFlow::EmbeddingAttrs") { + } else if (key == "embedding") { return ::FlexFlow::PCGOperatorAttrs{ j.at("value").template get<::FlexFlow::EmbeddingAttrs>()}; - } else if (key == "::FlexFlow::FlatAttrs") { + } else if (key == "flat") { return ::FlexFlow::PCGOperatorAttrs{ j.at("value").template get<::FlexFlow::FlatAttrs>()}; - } else if (key == "::FlexFlow::GatherAttrs") { + } else if (key == "gather") { return ::FlexFlow::PCGOperatorAttrs{ j.at("value").template get<::FlexFlow::GatherAttrs>()}; - } else if (key == "::FlexFlow::InputAttrs") { + } else if (key == "input") { return ::FlexFlow::PCGOperatorAttrs{ j.at("value").template get<::FlexFlow::InputAttrs>()}; - } else if (key == "::FlexFlow::LayerNormAttrs") { + } else if (key == "layer_norm") { return ::FlexFlow::PCGOperatorAttrs{ j.at("value").template get<::FlexFlow::LayerNormAttrs>()}; - } else if (key == "::FlexFlow::LinearAttrs") { + } else if (key == "linear") { return ::FlexFlow::PCGOperatorAttrs{ j.at("value").template get<::FlexFlow::LinearAttrs>()}; - } else if (key == "::FlexFlow::MultiHeadAttentionAttrs") { + } else if (key == "multi_head_attention") { return ::FlexFlow::PCGOperatorAttrs{ j.at("value").template get<::FlexFlow::MultiHeadAttentionAttrs>()}; - } else if (key == "::FlexFlow::NoopAttrs") { + } else if (key == "noop") { return ::FlexFlow::PCGOperatorAttrs{ j.at("value").template get<::FlexFlow::NoopAttrs>()}; - } else if (key == "::FlexFlow::Pool2DAttrs") { + } else if (key == "pool2d") { return ::FlexFlow::PCGOperatorAttrs{ j.at("value").template get<::FlexFlow::Pool2DAttrs>()}; - } else if (key == "::FlexFlow::ReduceAttrs") { + } else if (key == "reduce") { return ::FlexFlow::PCGOperatorAttrs{ j.at("value").template get<::FlexFlow::ReduceAttrs>()}; - } else if (key == "::FlexFlow::ReverseAttrs") { + } else if (key == "reduce_distributed") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::ReductionAttrs>()}; + } else if (key == "partition_distributed") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::RepartitionAttrs>()}; + } else if (key == "replicate_distributed") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::ReplicateAttrs>()}; + } else if (key == "reverse") { return ::FlexFlow::PCGOperatorAttrs{ j.at("value").template get<::FlexFlow::ReverseAttrs>()}; - } else if (key == "::FlexFlow::ReshapeAttrs") { + } else if (key == "reshape") { return ::FlexFlow::PCGOperatorAttrs{ j.at("value").template get<::FlexFlow::ReshapeAttrs>()}; - } else if (key == "::FlexFlow::SplitAttrs") { + } else if (key == "split") { return ::FlexFlow::PCGOperatorAttrs{ j.at("value").template get<::FlexFlow::SplitAttrs>()}; - } else if (key == "::FlexFlow::SoftmaxAttrs") { + } else if (key == "softmax") { return ::FlexFlow::PCGOperatorAttrs{ j.at("value").template get<::FlexFlow::SoftmaxAttrs>()}; - } else if (key == "::FlexFlow::TopKAttrs") { + } else if (key == "topk") { return ::FlexFlow::PCGOperatorAttrs{ j.at("value").template get<::FlexFlow::TopKAttrs>()}; - } else if (key == "::FlexFlow::TransposeAttrs") { + } else if (key == "transpose") { return ::FlexFlow::PCGOperatorAttrs{ j.at("value").template get<::FlexFlow::TransposeAttrs>()}; } else { @@ -201,127 +225,147 @@ void adl_serializer<::FlexFlow::PCGOperatorAttrs>::to_json( j["__type"] = "PCGOperatorAttrs"; switch (x.index()) { case 0: { - j["type"] = "::FlexFlow::BatchMatmulAttrs"; + j["type"] = "batch_matmul"; j["value"] = x.get<::FlexFlow::BatchMatmulAttrs>(); break; } case 1: { - j["type"] = "::FlexFlow::BatchNormAttrs"; + j["type"] = "batch_norm"; j["value"] = x.get<::FlexFlow::BatchNormAttrs>(); break; } case 2: { - j["type"] = "::FlexFlow::CastAttrs"; + j["type"] = "cast"; j["value"] = x.get<::FlexFlow::CastAttrs>(); break; } case 3: { - j["type"] = "::FlexFlow::ConcatAttrs"; - j["value"] = x.get<::FlexFlow::ConcatAttrs>(); + j["type"] = "combine_distributed"; + j["value"] = x.get<::FlexFlow::CombineAttrs>(); break; } case 4: { - j["type"] = "::FlexFlow::Conv2DAttrs"; - j["value"] = x.get<::FlexFlow::Conv2DAttrs>(); + j["type"] = "concat"; + j["value"] = x.get<::FlexFlow::ConcatAttrs>(); break; } case 5: { - j["type"] = "::FlexFlow::DropoutAttrs"; - j["value"] = x.get<::FlexFlow::DropoutAttrs>(); + j["type"] = "conv2d"; + j["value"] = x.get<::FlexFlow::Conv2DAttrs>(); break; } case 6: { - j["type"] = "::FlexFlow::ElementBinaryAttrs"; - j["value"] = x.get<::FlexFlow::ElementBinaryAttrs>(); + j["type"] = "dropout"; + j["value"] = x.get<::FlexFlow::DropoutAttrs>(); break; } case 7: { - j["type"] = "::FlexFlow::ElementUnaryAttrs"; - j["value"] = x.get<::FlexFlow::ElementUnaryAttrs>(); + j["type"] = "element_binary"; + j["value"] = x.get<::FlexFlow::ElementBinaryAttrs>(); break; } case 8: { - j["type"] = "::FlexFlow::ElementScalarUnaryAttrs"; - j["value"] = x.get<::FlexFlow::ElementScalarUnaryAttrs>(); + j["type"] = "element_unary"; + j["value"] = x.get<::FlexFlow::ElementUnaryAttrs>(); break; } case 9: { - j["type"] = "::FlexFlow::EmbeddingAttrs"; - j["value"] = x.get<::FlexFlow::EmbeddingAttrs>(); + j["type"] = "element_scalar_unary"; + j["value"] = x.get<::FlexFlow::ElementScalarUnaryAttrs>(); break; } case 10: { - j["type"] = "::FlexFlow::FlatAttrs"; - j["value"] = x.get<::FlexFlow::FlatAttrs>(); + j["type"] = "embedding"; + j["value"] = x.get<::FlexFlow::EmbeddingAttrs>(); break; } case 11: { - j["type"] = "::FlexFlow::GatherAttrs"; - j["value"] = x.get<::FlexFlow::GatherAttrs>(); + j["type"] = "flat"; + j["value"] = x.get<::FlexFlow::FlatAttrs>(); break; } case 12: { - j["type"] = "::FlexFlow::InputAttrs"; - j["value"] = x.get<::FlexFlow::InputAttrs>(); + j["type"] = "gather"; + j["value"] = x.get<::FlexFlow::GatherAttrs>(); break; } case 13: { - j["type"] = "::FlexFlow::LayerNormAttrs"; - j["value"] = x.get<::FlexFlow::LayerNormAttrs>(); + j["type"] = "input"; + j["value"] = x.get<::FlexFlow::InputAttrs>(); break; } case 14: { - j["type"] = "::FlexFlow::LinearAttrs"; - j["value"] = x.get<::FlexFlow::LinearAttrs>(); + j["type"] = "layer_norm"; + j["value"] = x.get<::FlexFlow::LayerNormAttrs>(); break; } case 15: { - j["type"] = "::FlexFlow::MultiHeadAttentionAttrs"; - j["value"] = x.get<::FlexFlow::MultiHeadAttentionAttrs>(); + j["type"] = "linear"; + j["value"] = x.get<::FlexFlow::LinearAttrs>(); break; } case 16: { - j["type"] = "::FlexFlow::NoopAttrs"; - j["value"] = x.get<::FlexFlow::NoopAttrs>(); + j["type"] = "multi_head_attention"; + j["value"] = x.get<::FlexFlow::MultiHeadAttentionAttrs>(); break; } case 17: { - j["type"] = "::FlexFlow::Pool2DAttrs"; - j["value"] = x.get<::FlexFlow::Pool2DAttrs>(); + j["type"] = "noop"; + j["value"] = x.get<::FlexFlow::NoopAttrs>(); break; } case 18: { - j["type"] = "::FlexFlow::ReduceAttrs"; - j["value"] = x.get<::FlexFlow::ReduceAttrs>(); + j["type"] = "pool2d"; + j["value"] = x.get<::FlexFlow::Pool2DAttrs>(); break; } case 19: { - j["type"] = "::FlexFlow::ReverseAttrs"; - j["value"] = x.get<::FlexFlow::ReverseAttrs>(); + j["type"] = "reduce"; + j["value"] = x.get<::FlexFlow::ReduceAttrs>(); break; } case 20: { - j["type"] = "::FlexFlow::ReshapeAttrs"; - j["value"] = x.get<::FlexFlow::ReshapeAttrs>(); + j["type"] = "reduce_distributed"; + j["value"] = x.get<::FlexFlow::ReductionAttrs>(); break; } case 21: { - j["type"] = "::FlexFlow::SplitAttrs"; - j["value"] = x.get<::FlexFlow::SplitAttrs>(); + j["type"] = "partition_distributed"; + j["value"] = x.get<::FlexFlow::RepartitionAttrs>(); break; } case 22: { - j["type"] = "::FlexFlow::SoftmaxAttrs"; - j["value"] = x.get<::FlexFlow::SoftmaxAttrs>(); + j["type"] = "replicate_distributed"; + j["value"] = x.get<::FlexFlow::ReplicateAttrs>(); break; } case 23: { - j["type"] = "::FlexFlow::TopKAttrs"; - j["value"] = x.get<::FlexFlow::TopKAttrs>(); + j["type"] = "reverse"; + j["value"] = x.get<::FlexFlow::ReverseAttrs>(); break; } case 24: { - j["type"] = "::FlexFlow::TransposeAttrs"; + j["type"] = "reshape"; + j["value"] = x.get<::FlexFlow::ReshapeAttrs>(); + break; + } + case 25: { + j["type"] = "split"; + j["value"] = x.get<::FlexFlow::SplitAttrs>(); + break; + } + case 26: { + j["type"] = "softmax"; + j["value"] = x.get<::FlexFlow::SoftmaxAttrs>(); + break; + } + case 27: { + j["type"] = "topk"; + j["value"] = x.get<::FlexFlow::TopKAttrs>(); + break; + } + case 28: { + j["type"] = "transpose"; j["value"] = x.get<::FlexFlow::TransposeAttrs>(); break; } @@ -337,127 +381,143 @@ std::string format_as(::FlexFlow::PCGOperatorAttrs const &x) { std::ostringstream oss; switch (x.index()) { case 0: { - oss << ""; break; } case 1: { - oss << ""; break; } case 2: { - oss << ""; + oss << ""; break; } case 3: { - oss << ""; + oss << ""; break; } case 4: { - oss << ""; + oss << ""; break; } case 5: { - oss << ""; + oss << ""; break; } case 6: { - oss << ""; + oss << ""; break; } case 7: { - oss << ""; + oss << ""; break; } case 8: { - oss << ""; + oss << ""; break; } case 9: { - oss << ""; + oss << ""; break; } case 10: { - oss << ""; + oss << ""; break; } case 11: { - oss << ""; + oss << ""; break; } case 12: { - oss << ""; + oss << ""; break; } case 13: { - oss << ""; + oss << ""; break; } case 14: { - oss << ""; + oss << ""; break; } case 15: { - oss << ""; + oss << ""; break; } case 16: { - oss << ""; + oss << ""; break; } case 17: { - oss << ""; + oss << ""; break; } case 18: { - oss << ""; + oss << ""; break; } case 19: { - oss << ""; + oss << ""; break; } case 20: { - oss << ""; + oss << ""; break; } case 21: { - oss << ""; + oss << ""; break; } case 22: { - oss << ""; + oss << ""; break; } case 23: { - oss << ""; + oss << ""; break; } case 24: { - oss << ""; + break; + } + case 25: { + oss << ""; + break; + } + case 26: { + oss << ""; + break; + } + case 27: { + oss << ""; + break; + } + case 28: { + oss << ""; break; } diff --git a/lib/op-attrs/src/op-attrs/replica_parallel_dim.cc b/lib/op-attrs/src/op-attrs/replica_parallel_dim.cc new file mode 100644 index 0000000000..44b17c8b44 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/replica_parallel_dim.cc @@ -0,0 +1,9 @@ +#include "op-attrs/replica_parallel_dim.h" + +namespace FlexFlow { + +bool is_valid(ReplicaParallelDim const &d) { + return d.degree > 0; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/replica_parallel_dim.dtg.cc b/lib/op-attrs/src/op-attrs/replica_parallel_dim.dtg.cc new file mode 100644 index 0000000000..a1256ad79a --- /dev/null +++ b/lib/op-attrs/src/op-attrs/replica_parallel_dim.dtg.cc @@ -0,0 +1,91 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/replica_parallel_dim.struct.toml +/* proj-data +{ + "generated_from": "f501393070c8d55a05c43dd73a81a8d7" +} +*/ + +#include "op-attrs/replica_parallel_dim.dtg.h" + +#include "op-attrs/replica_type.dtg.h" +#include + +namespace FlexFlow { +ReplicaParallelDim::ReplicaParallelDim( + int const °ree, ::FlexFlow::ReplicaType const &replica_type) + : degree(degree), replica_type(replica_type) {} +bool ReplicaParallelDim::operator==(ReplicaParallelDim const &other) const { + return std::tie(this->degree, this->replica_type) == + std::tie(other.degree, other.replica_type); +} +bool ReplicaParallelDim::operator!=(ReplicaParallelDim const &other) const { + return std::tie(this->degree, this->replica_type) != + std::tie(other.degree, other.replica_type); +} +bool ReplicaParallelDim::operator<(ReplicaParallelDim const &other) const { + return std::tie(this->degree, this->replica_type) < + std::tie(other.degree, other.replica_type); +} +bool ReplicaParallelDim::operator>(ReplicaParallelDim const &other) const { + return std::tie(this->degree, this->replica_type) > + std::tie(other.degree, other.replica_type); +} +bool ReplicaParallelDim::operator<=(ReplicaParallelDim const &other) const { + return std::tie(this->degree, this->replica_type) <= + std::tie(other.degree, other.replica_type); +} +bool ReplicaParallelDim::operator>=(ReplicaParallelDim const &other) const { + return std::tie(this->degree, this->replica_type) >= + std::tie(other.degree, other.replica_type); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ReplicaParallelDim const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.degree) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::ReplicaType>{}(x.replica_type) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ReplicaParallelDim + adl_serializer::from_json(json const &j) { + return {j.at("degree").template get(), + j.at("replica_type").template get<::FlexFlow::ReplicaType>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ReplicaParallelDim const &v) { + j["__type"] = "ReplicaParallelDim"; + j["degree"] = v.degree; + j["replica_type"] = v.replica_type; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary(), gen::arbitrary<::FlexFlow::ReplicaType>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(ReplicaParallelDim const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ReplicaParallelDim const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.cc b/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.cc new file mode 100644 index 0000000000..16bb1508c4 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.cc @@ -0,0 +1,32 @@ +#include "op-attrs/replica_parallel_dim_set.h" +#include "utils/exception.h" + +namespace FlexFlow { + +ReplicaParallelDimSet empty_replica_parallel_dim_set() { + return ReplicaParallelDimSet{1, 1}; +} + +int get_order_of_replica_type(ReplicaParallelDimSet const &s, ReplicaType replica_type) { + switch (replica_type) { + case ReplicaType::SUM: + return s.sum_degree; + case ReplicaType::DISCARD_COPY: + return s.discard_copy_degree; + default: + throw mk_runtime_error(fmt::format("Unexpected ReplicaType value: {}", static_cast(replica_type))); + } +} + +std::unordered_set get_replica_dims(ReplicaParallelDimSet const &s) { + return std::unordered_set{ + ReplicaParallelDim{s.sum_degree, ReplicaType::SUM}, + ReplicaParallelDim{s.discard_copy_degree, ReplicaType::DISCARD_COPY}, + }; +} + +bool is_valid(ReplicaParallelDimSet const &s) { + return s.sum_degree > 0 && s.discard_copy_degree > 0; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.dtg.cc b/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.dtg.cc new file mode 100644 index 0000000000..3b3ac59a9b --- /dev/null +++ b/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.dtg.cc @@ -0,0 +1,96 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/replica_parallel_dim_set.struct.toml +/* proj-data +{ + "generated_from": "20d8004e6f1e710688fe692b92dc2816" +} +*/ + +#include "op-attrs/replica_parallel_dim_set.dtg.h" + +#include + +namespace FlexFlow { +ReplicaParallelDimSet::ReplicaParallelDimSet(int const &sum_degree, + int const &discard_copy_degree) + : sum_degree(sum_degree), discard_copy_degree(discard_copy_degree) {} +bool ReplicaParallelDimSet::operator==( + ReplicaParallelDimSet const &other) const { + return std::tie(this->sum_degree, this->discard_copy_degree) == + std::tie(other.sum_degree, other.discard_copy_degree); +} +bool ReplicaParallelDimSet::operator!=( + ReplicaParallelDimSet const &other) const { + return std::tie(this->sum_degree, this->discard_copy_degree) != + std::tie(other.sum_degree, other.discard_copy_degree); +} +bool ReplicaParallelDimSet::operator<( + ReplicaParallelDimSet const &other) const { + return std::tie(this->sum_degree, this->discard_copy_degree) < + std::tie(other.sum_degree, other.discard_copy_degree); +} +bool ReplicaParallelDimSet::operator>( + ReplicaParallelDimSet const &other) const { + return std::tie(this->sum_degree, this->discard_copy_degree) > + std::tie(other.sum_degree, other.discard_copy_degree); +} +bool ReplicaParallelDimSet::operator<=( + ReplicaParallelDimSet const &other) const { + return std::tie(this->sum_degree, this->discard_copy_degree) <= + std::tie(other.sum_degree, other.discard_copy_degree); +} +bool ReplicaParallelDimSet::operator>=( + ReplicaParallelDimSet const &other) const { + return std::tie(this->sum_degree, this->discard_copy_degree) >= + std::tie(other.sum_degree, other.discard_copy_degree); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ReplicaParallelDimSet const &x) const { + size_t result = 0; + result ^= std::hash{}(x.sum_degree) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.discard_copy_degree) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ReplicaParallelDimSet + adl_serializer::from_json(json const &j) { + return {j.at("sum_degree").template get(), + j.at("discard_copy_degree").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ReplicaParallelDimSet const &v) { + j["__type"] = "ReplicaParallelDimSet"; + j["sum_degree"] = v.sum_degree; + j["discard_copy_degree"] = v.discard_copy_degree; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary(), + gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(ReplicaParallelDimSet const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ReplicaParallelDimSet const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/replica_type.dtg.cc b/lib/op-attrs/src/op-attrs/replica_type.dtg.cc new file mode 100644 index 0000000000..d0410c49e2 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/replica_type.dtg.cc @@ -0,0 +1,70 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/replica_type.enum.toml +/* proj-data +{ + "generated_from": "6ecba7a6851b8bea93705bba24661149" +} +*/ + +#include "op-attrs/replica_type.dtg.h" + +#include +#include + +namespace std { +size_t hash::operator()(FlexFlow::ReplicaType x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(ReplicaType x) { + switch (x) { + case ReplicaType::SUM: + return "SUM"; + case ReplicaType::DISCARD_COPY: + return "DISCARD_COPY"; + default: + std::ostringstream oss; + oss << "Unknown ReplicaType value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, ReplicaType x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, ReplicaType x) { + switch (x) { + case ReplicaType::SUM: + j = "SUM"; + break; + case ReplicaType::DISCARD_COPY: + j = "DISCARD_COPY"; + break; + default: + std::ostringstream oss; + oss << "Unknown ReplicaType value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, ReplicaType &x) { + std::string as_str = j.get(); + if (as_str == "SUM") { + x = ReplicaType::SUM; + } else if (as_str == "DISCARD_COPY") { + x = ReplicaType::DISCARD_COPY; + } else { + std::ostringstream oss; + oss << "Unknown ReplicaType value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::element( + FlexFlow::ReplicaType::SUM, FlexFlow::ReplicaType::DISCARD_COPY); +} +} // namespace rc diff --git a/lib/op-attrs/src/op-attrs/shard_parallel_dim.cc b/lib/op-attrs/src/op-attrs/shard_parallel_dim.cc new file mode 100644 index 0000000000..d27a857723 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/shard_parallel_dim.cc @@ -0,0 +1,9 @@ +#include "op-attrs/shard_parallel_dim.h" + +namespace FlexFlow { + +bool is_valid(ShardParallelDim const &d) { + return d.degree > 0 && d.size > 0 && (d.size % d.degree) == 0; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/shard_parallel_dim.dtg.cc b/lib/op-attrs/src/op-attrs/shard_parallel_dim.dtg.cc new file mode 100644 index 0000000000..9566eb486b --- /dev/null +++ b/lib/op-attrs/src/op-attrs/shard_parallel_dim.dtg.cc @@ -0,0 +1,89 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/shard_parallel_dim.struct.toml +/* proj-data +{ + "generated_from": "18e074f80556d90b9b27d6515bbf9071" +} +*/ + +#include "op-attrs/shard_parallel_dim.dtg.h" + +#include + +namespace FlexFlow { +ShardParallelDim::ShardParallelDim(size_t const &size, int const °ree) + : size(size), degree(degree) {} +bool ShardParallelDim::operator==(ShardParallelDim const &other) const { + return std::tie(this->size, this->degree) == + std::tie(other.size, other.degree); +} +bool ShardParallelDim::operator!=(ShardParallelDim const &other) const { + return std::tie(this->size, this->degree) != + std::tie(other.size, other.degree); +} +bool ShardParallelDim::operator<(ShardParallelDim const &other) const { + return std::tie(this->size, this->degree) < + std::tie(other.size, other.degree); +} +bool ShardParallelDim::operator>(ShardParallelDim const &other) const { + return std::tie(this->size, this->degree) > + std::tie(other.size, other.degree); +} +bool ShardParallelDim::operator<=(ShardParallelDim const &other) const { + return std::tie(this->size, this->degree) <= + std::tie(other.size, other.degree); +} +bool ShardParallelDim::operator>=(ShardParallelDim const &other) const { + return std::tie(this->size, this->degree) >= + std::tie(other.size, other.degree); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ShardParallelDim const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.size) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.degree) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ShardParallelDim + adl_serializer::from_json(json const &j) { + return {j.at("size").template get(), + j.at("degree").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ShardParallelDim const &v) { + j["__type"] = "ShardParallelDim"; + j["size"] = v.size; + j["degree"] = v.degree; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary(), + gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(ShardParallelDim const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ShardParallelDim const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/tensor_dims.cc b/lib/op-attrs/src/op-attrs/tensor_dims.cc index 9f7316e998..af6f4c9e82 100644 --- a/lib/op-attrs/src/op-attrs/tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/tensor_dims.cc @@ -1,5 +1,7 @@ #include "op-attrs/tensor_dims.h" #include "utils/containers.h" +#include "op-attrs/replica_parallel_dim_set.h" +#include "op-attrs/shard_parallel_dim.dtg.h" namespace FlexFlow { @@ -12,10 +14,12 @@ size_t dim_at_idx(TensorDims const &dims, ff_dim_t idx) { } ParallelTensorDims lift_to_parallel(TensorDims const &dims) { - FFOrdered lifted = { - transform(as_vector(dims.ff_ordered), [](size_t const &size) { return ParallelDim{size, 1, false}; }) + std::vector lifted = transform(as_vector(dims.ff_ordered), [](size_t size) { return ShardParallelDim{size, 1}; }); + + return ParallelTensorDims{ + FFOrdered{lifted}, + empty_replica_parallel_dim_set(), }; - return {lifted}; } } diff --git a/lib/op-attrs/src/tensor_shape.cc b/lib/op-attrs/src/op-attrs/tensor_shape.cc similarity index 60% rename from lib/op-attrs/src/tensor_shape.cc rename to lib/op-attrs/src/op-attrs/tensor_shape.cc index 6e41f9175a..01afbddf1e 100644 --- a/lib/op-attrs/src/tensor_shape.cc +++ b/lib/op-attrs/src/op-attrs/tensor_shape.cc @@ -1,7 +1,12 @@ #include "op-attrs/tensor_shape.h" +#include "op-attrs/tensor_dims.h" namespace FlexFlow { +size_t num_dims(TensorShape const &s) { + return s.dims.ff_ordered.size(); +} + size_t dim_at_idx(TensorShape const &s, ff_dim_t idx) { return dim_at_idx(s.dims, idx); } diff --git a/lib/op-attrs/src/op-attrs/tensor_shape.dtg.cc b/lib/op-attrs/src/op-attrs/tensor_shape.dtg.cc index 1538cc82c1..56856070e9 100644 --- a/lib/op-attrs/src/op-attrs/tensor_shape.dtg.cc +++ b/lib/op-attrs/src/op-attrs/tensor_shape.dtg.cc @@ -3,14 +3,14 @@ // lib/op-attrs/include/op-attrs/tensor_shape.struct.toml /* proj-data { - "generated_from": "c02c9d2331d864a25c1443cfe70062d1" + "generated_from": "52968754cf94f415c366d228c87042db" } */ #include "op-attrs/tensor_shape.dtg.h" -#include "op-attrs/datatype.h" -#include "op-attrs/tensor_dims.h" +#include "op-attrs/datatype.dtg.h" +#include "op-attrs/tensor_dims.dtg.h" #include namespace FlexFlow { diff --git a/lib/op-attrs/src/parallel_dim.cc b/lib/op-attrs/src/parallel_dim.cc deleted file mode 100644 index e103625fab..0000000000 --- a/lib/op-attrs/src/parallel_dim.cc +++ /dev/null @@ -1,12 +0,0 @@ -#include "op-attrs/parallel_dim.h" - -namespace FlexFlow { - -bool is_valid(ParallelDim const &dim) { - return dim.size > 0 && dim.degree >= 1 && dim.size % dim.degree == 0; -} - -bool is_replica_dim(ParallelDim const &dim) { - return dim.is_replica_dim; -} -} // namespace FlexFlow diff --git a/lib/op-attrs/src/pool_2d.cc b/lib/op-attrs/src/pool_2d.cc deleted file mode 100644 index 0867aeb344..0000000000 --- a/lib/op-attrs/src/pool_2d.cc +++ /dev/null @@ -1,47 +0,0 @@ -#include "op-attrs/ops/pool_2d.h" -#include "parallel_dim_mapping_record.h" -#include "parallel_dim_mapping_record_solver.h" - -namespace FlexFlow { - -namespace Input { -constexpr int NUMDIM = 5, WIDTH = 0, HEIGHT = 1, CHANNEL = 2, SAMPLE = 3, - REPLICA = 4; -}; - -namespace Output { -constexpr int NUMDIM = 5, WIDTH = 0, HEIGHT = 1, CHANNEL = 2, SAMPLE = 3, - REPLICA = 4; -}; - -/* bool Pool2DAttrs::is_valid(ParallelTensorShape const &input) const { */ -/* ParallelTensorShape output_shape = this->calculate_output_shape(input); */ - -/* return output_shape.is_valid() && (input.at(Input::REPLICA).degree == 1); - */ -/* } */ - -static std::vector - construct_mappings(ParallelTensorShape const &input_shape) { - auto const outputMappings = construct_output_parallel_dims({ - {Input::REPLICA, MappingOperation::PARTITION, Output::REPLICA}, - {Input::SAMPLE, MappingOperation::PARTITION, Output::SAMPLE}, - {Input::CHANNEL, MappingOperation::PARTITION, Output::CHANNEL}, - {Input::HEIGHT, MappingOperation::PARTITION, Output::HEIGHT}, - {Input::WIDTH, MappingOperation::PARTITION, Output::WIDTH}, - }); - - return outputMappings; -} - -static ParallelDimMappingSolution - solve_mappings(ParallelTensorShape const &input) { - return solve_parallel_dim_mappings(construct_mappings(input), {input}, 0, 1); -} - -/* ParallelTensorShape Pool2DAttrs::calculate_output_shape(ParallelTensorShape - * const &input) const { */ -/* return solve_mappings(input).output_shapes.at(0); */ -/* } */ - -} // namespace FlexFlow diff --git a/lib/op-attrs/src/reduce.cc b/lib/op-attrs/src/reduce.cc deleted file mode 100644 index 9d1770d5be..0000000000 --- a/lib/op-attrs/src/reduce.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "op-attrs/ops/reduce.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/op-attrs/src/reduction.cc b/lib/op-attrs/src/reduction.cc deleted file mode 100644 index 22fc9bab6a..0000000000 --- a/lib/op-attrs/src/reduction.cc +++ /dev/null @@ -1,13 +0,0 @@ -#include "op-attrs/ops/reduction.h" - -namespace FlexFlow { - -/* ParallelTensorShape ReductionAttrs::output_shape(ParallelTensorShape const - * &input_shape) const { */ -/* ParallelTensorShape output = input_shape; */ -/* output.at(this->reduction_legion_dim).degree /= this->reduction_degree; */ -/* output.at(this->reduction_legion_dim).size /= this->reduction_degree; */ -/* return output; */ -/* } */ - -} // namespace FlexFlow diff --git a/lib/op-attrs/src/repartition.cc b/lib/op-attrs/src/repartition.cc deleted file mode 100644 index 672e68b4f6..0000000000 --- a/lib/op-attrs/src/repartition.cc +++ /dev/null @@ -1,11 +0,0 @@ -#include "op-attrs/ops/repartition.h" - -namespace FlexFlow { - -/* bool RepartitionAttrs::is_valid(ParallelTensorShape const &input_shape) const - * { */ -/* ParallelDim dim = input_shape.at(this->repartition_legion_dim); */ -/* return (dim.size % this->repartition_degree * dim.degree == 0); */ -/* } */ - -} // namespace FlexFlow diff --git a/lib/op-attrs/src/replicate.cc b/lib/op-attrs/src/replicate.cc deleted file mode 100644 index 73ad288d8c..0000000000 --- a/lib/op-attrs/src/replicate.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "op-attrs/ops/replicate.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/op-attrs/src/reshape.cc b/lib/op-attrs/src/reshape.cc deleted file mode 100644 index e8349e1f26..0000000000 --- a/lib/op-attrs/src/reshape.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "op-attrs/ops/reshape.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/op-attrs/src/softmax.cc b/lib/op-attrs/src/softmax.cc deleted file mode 100644 index 9f95da4fb7..0000000000 --- a/lib/op-attrs/src/softmax.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "op-attrs/ops/softmax.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/op-attrs/src/split.cc b/lib/op-attrs/src/split.cc deleted file mode 100644 index acda8f3262..0000000000 --- a/lib/op-attrs/src/split.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "op-attrs/ops/split.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/op-attrs/src/topk.cc b/lib/op-attrs/src/topk.cc deleted file mode 100644 index 9d701e4868..0000000000 --- a/lib/op-attrs/src/topk.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "op-attrs/ops/topk.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/op-attrs/src/transpose.cc b/lib/op-attrs/src/transpose.cc deleted file mode 100644 index ad4a84a3d5..0000000000 --- a/lib/op-attrs/src/transpose.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "op-attrs/ops/transpose.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/op-attrs/test/src/test_conv_2d.cc b/lib/op-attrs/test/src/test_conv_2d.cc new file mode 100644 index 0000000000..11fd1633ee --- /dev/null +++ b/lib/op-attrs/test/src/test_conv_2d.cc @@ -0,0 +1,67 @@ +#include "doctest/doctest.h" +#include "op-attrs/ops/conv_2d.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_output_shape(Conv2DAttrs, TensorShape)") { + int out_channels = 4; + int kernel_h = 3; + int kernel_w = 2; + int stride_h = 2; + int stride_w = 2; + int padding_h = 1; + int padding_w = 1; + int groups = 1; + std::optional activation = std::nullopt; + bool use_bias = true; + + Conv2DAttrs attrs = { + /*out_channels=*/out_channels, + /*kernel_h=*/kernel_h, + /*kernel_w=*/kernel_w, + /*stride_h=*/stride_h, + /*stride_w=*/stride_w, + /*padding_h=*/padding_h, + /*padding_w=*/padding_w, + /*groups=*/groups, + /*activation=*/activation, + /*use_bias=*/true, + }; + + size_t num_samples = 7; + size_t input_channels = 6; + size_t input_height = 10; + size_t input_width = 15; + + TensorShape input_shape = { + TensorDims{ + FFOrdered{ + num_samples, + input_channels, + input_height, + input_width, + } + }, + DataType::FLOAT, + }; + + TensorShape result = get_output_shape(attrs, input_shape); + + + size_t correct_output_height = 3; + size_t correct_output_width = 6; + + TensorShape correct_output_shape = { + TensorDims{ + FFOrdered{ + num_samples, + static_cast(out_channels), + correct_output_height, + correct_output_width, + } + }, + DataType::FLOAT, + }; + + CHECK(result == correct_output_shape); + } +} diff --git a/lib/op-attrs/test/src/test_operator_attrs.cc b/lib/op-attrs/test/src/test_operator_attrs.cc index 188c9d1607..1973c89fe6 100644 --- a/lib/op-attrs/test/src/test_operator_attrs.cc +++ b/lib/op-attrs/test/src/test_operator_attrs.cc @@ -1,5 +1,6 @@ #include "doctest/doctest.h" -#include "op-attrs/operator_attrs.h" +#include "op-attrs/computation_graph_op_attrs.dtg.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" #include "utils/json.h" #include #include @@ -13,17 +14,21 @@ TEST_SUITE(FF_TEST_SUITE) { } TEST_CASE("ComputationGraphAttrs to/from json") { - ComputationGraphAttrs correct = BatchNormAttrs{true}; + ComputationGraphOpAttrs correct = ComputationGraphOpAttrs{ + BatchNormAttrs{true} + }; json j = correct; - auto result = j.get(); + auto result = j.get(); CHECK(result == correct); } TEST_CASE("PCGOperatorAttrs to/from json") { - PCGOperatorAttrs correct = RepartitionAttrs{ - /*repartition_dim=*/ff_dim_t{1}, - /*repartition_degree=*/4, + PCGOperatorAttrs correct = PCGOperatorAttrs{ + RepartitionAttrs{ + /*repartition_dim=*/ff_dim_t{1}, + /*repartition_degree=*/4, + } }; json j = correct; auto result = j.get(); diff --git a/lib/utils/include/utils/exception.decl.h b/lib/utils/include/utils/exception.decl.h index d27174f474..a8cb150ec9 100644 --- a/lib/utils/include/utils/exception.decl.h +++ b/lib/utils/include/utils/exception.decl.h @@ -7,14 +7,14 @@ namespace FlexFlow { #ifdef FF_REQUIRE_IMPLEMENTED -#define NOT_IMPLEMENTED() static_assert(false, "Function not yet implemented"); +#define NOT_IMPLEMENTED() static_assert(false, "Function " __FUNC__ " not yet implemented " __FILE__ ":" __LINE__); #else -#define NOT_IMPLEMENTED() throw not_implemented(); +#define NOT_IMPLEMENTED() throw not_implemented(__PRETTY_FUNCTION__, __FILE__, __LINE__); #endif class not_implemented : public std::logic_error { public: - not_implemented(); + not_implemented(std::string const &function_name, std::string const &file_name, int line); }; template diff --git a/lib/utils/include/utils/fmt.decl.h b/lib/utils/include/utils/fmt.decl.h index d38b36037b..71d00e1c5a 100644 --- a/lib/utils/include/utils/fmt.decl.h +++ b/lib/utils/include/utils/fmt.decl.h @@ -6,6 +6,7 @@ #include #include #include "utils/check_fmtable.h" +#include #define DELEGATE_OSTREAM(...) \ template <> \ @@ -57,6 +58,15 @@ struct formatter<::std::variant> : formatter<::std::string> { -> decltype(ctx.out()); }; +/* template */ +/* struct formatter< */ +/* ::std::pair */ +/* Char, */ +/* std::enable_if_t>::value> */ +/* > : formatter<::std::string> { */ +/* template */ +/* auto format(::std::pair const &m, FormatContext &ctx) */ +/* -> decltype(ctx.out()); */ } // namespace fmt diff --git a/lib/utils/include/utils/fmt.h b/lib/utils/include/utils/fmt.h index c976174ded..8ca5b34fc2 100644 --- a/lib/utils/include/utils/fmt.h +++ b/lib/utils/include/utils/fmt.h @@ -47,7 +47,7 @@ auto formatter< std::string result = ::FlexFlow::join_strings( m.cbegin(), m.cend(), ", ", [](T const &t) { return fmt::to_string(t); }); - return formatter::format(result, ctx); + return formatter::format("[" + result + "]", ctx); } template @@ -60,6 +60,24 @@ auto formatter<::std::variant>::format(::std::variant const &m, return formatter::format(result, ctx); } +/* template */ +/* template */ +/* auto formatter< */ +/* ::std::pair, */ +/* Char, */ +/* std::enable_if_t>::value> */ +/* >::format(::std::pair const &m, FormatContext &ctx) */ +/* -> decltype(ctx.out()) { */ +/* /1* CHECK_FMTABLE(T); *1/ */ + +/* /1* std::string result = ::FlexFlow::join_strings( *1/ */ +/* /1* m.cbegin(), m.cend(), ", ", [](T const &t) { return fmt::to_string(t); }); *1/ */ +/* NOT_IMPLEMENTED(); */ +/* std::string result = ""; */ +/* return formatter::format(result, ctx); */ +/* } */ + + } // namespace fmt namespace FlexFlow { diff --git a/lib/utils/include/utils/fmt/pair.h b/lib/utils/include/utils/fmt/pair.h new file mode 100644 index 0000000000..6a680616c6 --- /dev/null +++ b/lib/utils/include/utils/fmt/pair.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_PAIR_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_PAIR_H + +#include +#include "fmt/format.h" +#include "utils/check_fmtable.h" + +namespace FlexFlow { + +template +std::ostream &operator<<(std::ostream &s, std::pair const &m) { + CHECK_FMTABLE(L); + CHECK_FMTABLE(R); + + return s << fmt::to_string(m); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/fmt/unordered_map.h b/lib/utils/include/utils/fmt/unordered_map.h index 8c9125e35a..1287302c7a 100644 --- a/lib/utils/include/utils/fmt/unordered_map.h +++ b/lib/utils/include/utils/fmt/unordered_map.h @@ -17,11 +17,12 @@ struct formatter< template auto format(::std::unordered_map const &m, FormatContext &ctx) -> decltype(ctx.out()) { - CHECK_FMTABLE(K); - CHECK_FMTABLE(V); + /* CHECK_FMTABLE(K); */ + /* CHECK_FMTABLE(V); */ - std::string result = ::FlexFlow::join_strings( - m.cbegin(), m.cend(), ", ", [](std::pair const &p) { return fmt::to_string(p); }); + /* std::string result = ::FlexFlow::join_strings( */ + /* m.cbegin(), m.cend(), ", ", [](std::pair const &p) { return fmt::to_string(p); }); */ + std::string result = ""; return formatter::format(result, ctx); } }; diff --git a/lib/utils/src/exception.cc b/lib/utils/src/exception.cc index 7dccdc3074..1369bea4c9 100644 --- a/lib/utils/src/exception.cc +++ b/lib/utils/src/exception.cc @@ -2,7 +2,7 @@ namespace FlexFlow { -not_implemented::not_implemented() - : std::logic_error("Function not yet implemented"){}; +not_implemented::not_implemented(std::string const &function_name, std::string const &file_name, int line) + : std::logic_error(fmt::format("Function {} not yet implemented at {}:{}", function_name, file_name, line)){}; } From 410c9a2a85fc2db15376d30ccf09202794cd0ebe Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Thu, 25 Apr 2024 04:23:52 -0700 Subject: [PATCH 12/43] Move conv2d input parsing into public headers --- .../ops/conv_2d}/conv_2d_input_shape.dtg.h | 8 ++-- .../ops/conv_2d/conv_2d_input_shape.h | 13 ++++++ .../conv_2d}/conv_2d_input_shape.struct.toml | 0 .../conv_2d_parallel_input_shape.dtg.h | 8 ++-- .../conv_2d/conv_2d_parallel_input_shape.h | 13 ++++++ .../conv_2d_parallel_input_shape.struct.toml | 0 lib/op-attrs/src/op-attrs/ops/conv_2d.cc | 40 +------------------ .../ops/conv_2d/conv_2d_input_shape.cc | 24 +++++++++++ .../{ => conv_2d}/conv_2d_input_shape.dtg.cc | 4 +- .../conv_2d/conv_2d_parallel_input_shape.cc | 26 ++++++++++++ .../conv_2d_parallel_input_shape.dtg.cc | 4 +- 11 files changed, 90 insertions(+), 50 deletions(-) rename lib/op-attrs/{src/op-attrs/ops => include/op-attrs/ops/conv_2d}/conv_2d_input_shape.dtg.h (84%) create mode 100644 lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.h rename lib/op-attrs/{src/op-attrs/ops => include/op-attrs/ops/conv_2d}/conv_2d_input_shape.struct.toml (100%) rename lib/op-attrs/{src/op-attrs/ops => include/op-attrs/ops/conv_2d}/conv_2d_parallel_input_shape.dtg.h (86%) create mode 100644 lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.h rename lib/op-attrs/{src/op-attrs/ops => include/op-attrs/ops/conv_2d}/conv_2d_parallel_input_shape.struct.toml (100%) create mode 100644 lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.cc rename lib/op-attrs/src/op-attrs/ops/{ => conv_2d}/conv_2d_input_shape.dtg.cc (97%) create mode 100644 lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.cc rename lib/op-attrs/src/op-attrs/ops/{ => conv_2d}/conv_2d_parallel_input_shape.dtg.cc (98%) diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d_input_shape.dtg.h b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.h similarity index 84% rename from lib/op-attrs/src/op-attrs/ops/conv_2d_input_shape.dtg.h rename to lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.h index 92c6f57e73..2e7833064c 100644 --- a/lib/op-attrs/src/op-attrs/ops/conv_2d_input_shape.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.h @@ -1,14 +1,14 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify -// lib/op-attrs/src/op-attrs/ops/conv_2d_input_shape.struct.toml +// lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.struct.toml /* proj-data { "generated_from": "51911f58c134d55b2d0245444acbae53" } */ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_SRC_OP_ATTRS_OPS_CONV_2D_INPUT_SHAPE_DTG_H -#define _FLEXFLOW_LIB_OP_ATTRS_SRC_OP_ATTRS_OPS_CONV_2D_INPUT_SHAPE_DTG_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_CONV_2D_INPUT_SHAPE_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_CONV_2D_INPUT_SHAPE_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" @@ -69,4 +69,4 @@ std::string format_as(Conv2DInputShape const &); std::ostream &operator<<(std::ostream &, Conv2DInputShape const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_SRC_OP_ATTRS_OPS_CONV_2D_INPUT_SHAPE_DTG_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_CONV_2D_INPUT_SHAPE_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.h b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.h new file mode 100644 index 0000000000..043f5854ae --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_CONV_2D_INPUT_SHAPE_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_CONV_2D_INPUT_SHAPE_H + +#include "op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" + +namespace FlexFlow { + +Conv2DInputShape parse_input_shape(TensorShape const &input); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d_input_shape.struct.toml b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.struct.toml similarity index 100% rename from lib/op-attrs/src/op-attrs/ops/conv_2d_input_shape.struct.toml rename to lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.struct.toml diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d_parallel_input_shape.dtg.h b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.h similarity index 86% rename from lib/op-attrs/src/op-attrs/ops/conv_2d_parallel_input_shape.dtg.h rename to lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.h index bae26b378b..846c9e413a 100644 --- a/lib/op-attrs/src/op-attrs/ops/conv_2d_parallel_input_shape.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.h @@ -1,14 +1,14 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify -// lib/op-attrs/src/op-attrs/ops/conv_2d_parallel_input_shape.struct.toml +// lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.struct.toml /* proj-data { "generated_from": "d80394bdc90f843372760310b6d17a22" } */ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_SRC_OP_ATTRS_OPS_CONV_2D_PARALLEL_INPUT_SHAPE_DTG_H -#define _FLEXFLOW_LIB_OP_ATTRS_SRC_OP_ATTRS_OPS_CONV_2D_PARALLEL_INPUT_SHAPE_DTG_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_CONV_2D_PARALLEL_INPUT_SHAPE_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_CONV_2D_PARALLEL_INPUT_SHAPE_DTG_H #include "fmt/format.h" #include "nlohmann/json.hpp" @@ -73,4 +73,4 @@ std::string format_as(Conv2DParallelInputShape const &); std::ostream &operator<<(std::ostream &, Conv2DParallelInputShape const &); } // namespace FlexFlow -#endif // _FLEXFLOW_LIB_OP_ATTRS_SRC_OP_ATTRS_OPS_CONV_2D_PARALLEL_INPUT_SHAPE_DTG_H +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_CONV_2D_PARALLEL_INPUT_SHAPE_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.h b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.h new file mode 100644 index 0000000000..9edff21db8 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_CONV_2D_PARALLEL_INPUT_SHAPE_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_CONV_2D_PARALLEL_INPUT_SHAPE_H + +#include "op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" + +namespace FlexFlow { + +Conv2DParallelInputShape parse_parallel_input_shape(ParallelTensorShape const &input); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d_parallel_input_shape.struct.toml b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.struct.toml similarity index 100% rename from lib/op-attrs/src/op-attrs/ops/conv_2d_parallel_input_shape.struct.toml rename to lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.struct.toml diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc index c45c77672c..be0bba0a6d 100644 --- a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc @@ -1,6 +1,6 @@ #include "op-attrs/ops/conv_2d.h" -#include "conv_2d_input_shape.dtg.h" -#include "conv_2d_parallel_input_shape.dtg.h" +#include "op-attrs/ops/conv_2d/conv_2d_input_shape.h" +#include "op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.h" namespace FlexFlow { @@ -9,42 +9,6 @@ static size_t as_size_t(int x) { return static_cast(x); } -static Conv2DInputShape parse_input_shape(TensorShape const &input) { - assert(num_dims(input) == 4); - - size_t num_samples = dim_at_idx(input, ff_dim_t{0}); - size_t in_channels = dim_at_idx(input, ff_dim_t{1}); - size_t in_height = dim_at_idx(input, ff_dim_t{2}); - size_t in_width = dim_at_idx(input, ff_dim_t{3}); - - return Conv2DInputShape{ - num_samples, - in_channels, - in_height, - in_width, - input.data_type, - }; -} - -static Conv2DParallelInputShape parse_parallel_input_shape(ParallelTensorShape const &input) { - assert(num_shard_dims(input) == 4); - - ShardParallelDim sample_dim = shard_dim_at_idx(input, ff_dim_t{0}); - ShardParallelDim channel_dim = shard_dim_at_idx(input, ff_dim_t{1}); - ShardParallelDim height_dim = shard_dim_at_idx(input, ff_dim_t{2}); - ShardParallelDim width_dim = shard_dim_at_idx(input, ff_dim_t{3}); - - return Conv2DParallelInputShape{ - sample_dim, - channel_dim, - height_dim, - width_dim, - input.dims.replica_dims.sum_degree, - input.dims.replica_dims.discard_copy_degree, - input.data_type, - }; -} - TensorShape get_kernel_shape(Conv2DAttrs const &attrs, TensorShape const &raw_input_shape) { assert (attrs.groups == 1); // TODO(@lockshaw): currently not supported Conv2DInputShape input = parse_input_shape(raw_input_shape); diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.cc new file mode 100644 index 0000000000..ed508de131 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.cc @@ -0,0 +1,24 @@ +#include "op-attrs/ops/conv_2d/conv_2d_input_shape.h" +#include "op-attrs/tensor_shape.h" + +namespace FlexFlow { + +Conv2DInputShape parse_input_shape(TensorShape const &input) { + assert(num_dims(input) == 4); + + size_t num_samples = dim_at_idx(input, ff_dim_t{0}); + size_t in_channels = dim_at_idx(input, ff_dim_t{1}); + size_t in_height = dim_at_idx(input, ff_dim_t{2}); + size_t in_width = dim_at_idx(input, ff_dim_t{3}); + + return Conv2DInputShape{ + num_samples, + in_channels, + in_height, + in_width, + input.data_type, + }; +} + + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d_input_shape.dtg.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.cc similarity index 97% rename from lib/op-attrs/src/op-attrs/ops/conv_2d_input_shape.dtg.cc rename to lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.cc index 47f86afc53..74df30e2d7 100644 --- a/lib/op-attrs/src/op-attrs/ops/conv_2d_input_shape.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.cc @@ -1,13 +1,13 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify -// lib/op-attrs/src/op-attrs/ops/conv_2d_input_shape.struct.toml +// lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.struct.toml /* proj-data { "generated_from": "51911f58c134d55b2d0245444acbae53" } */ -#include "op-attrs/ops/conv_2d_input_shape.dtg.h" +#include "op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.h" #include "op-attrs/datatype.dtg.h" #include diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.cc new file mode 100644 index 0000000000..501f42fe0a --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.cc @@ -0,0 +1,26 @@ +#include "op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.h" +#include "op-attrs/parallel_tensor_shape.h" + +namespace FlexFlow { + +Conv2DParallelInputShape parse_parallel_input_shape(ParallelTensorShape const &input) { + assert(num_shard_dims(input) == 4); + + ShardParallelDim sample_dim = shard_dim_at_idx(input, ff_dim_t{0}); + ShardParallelDim channel_dim = shard_dim_at_idx(input, ff_dim_t{1}); + ShardParallelDim height_dim = shard_dim_at_idx(input, ff_dim_t{2}); + ShardParallelDim width_dim = shard_dim_at_idx(input, ff_dim_t{3}); + + return Conv2DParallelInputShape{ + sample_dim, + channel_dim, + height_dim, + width_dim, + input.dims.replica_dims.sum_degree, + input.dims.replica_dims.discard_copy_degree, + input.data_type, + }; +} + + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d_parallel_input_shape.dtg.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.cc similarity index 98% rename from lib/op-attrs/src/op-attrs/ops/conv_2d_parallel_input_shape.dtg.cc rename to lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.cc index 46e8061e7e..df854c2b8f 100644 --- a/lib/op-attrs/src/op-attrs/ops/conv_2d_parallel_input_shape.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.cc @@ -1,13 +1,13 @@ // THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! // If you would like to modify this datatype, instead modify -// lib/op-attrs/src/op-attrs/ops/conv_2d_parallel_input_shape.struct.toml +// lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.struct.toml /* proj-data { "generated_from": "d80394bdc90f843372760310b6d17a22" } */ -#include "op-attrs/ops/conv_2d_parallel_input_shape.dtg.h" +#include "op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.h" #include "op-attrs/datatype.dtg.h" #include "op-attrs/shard_parallel_dim.dtg.h" From d6c5f7dafdabcfd232e07c844be1fcec8a7e336d Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Thu, 25 Apr 2024 04:24:24 -0700 Subject: [PATCH 13/43] Remove incorrect not_implemented --- lib/op-attrs/src/op-attrs/ops/conv_2d.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc index be0bba0a6d..4271a770b6 100644 --- a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc @@ -24,8 +24,6 @@ TensorShape get_kernel_shape(Conv2DAttrs const &attrs, TensorShape const &raw_in }, input.datatype, }; - - NOT_IMPLEMENTED(); } TensorShape get_bias_shape(Conv2DAttrs const &attrs, TensorShape const &raw_input_shape) { From 80d305eeba59a431576c8367a31c82bd27fc1893 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Wed, 1 May 2024 18:03:37 -0700 Subject: [PATCH 14/43] Add pcg tests --- .proj.toml | 1 + flake.lock | 6 +- .../op-attrs/computation_graph_op_attrs.dtg.h | 27 ++- .../op-attrs/computation_graph_op_attrs.h | 12 ++ .../computation_graph_op_attrs.variant.toml | 5 + lib/op-attrs/include/op-attrs/get_op_type.h | 48 ++++-- .../include/op-attrs/ops/weight_attrs.dtg.h | 58 +++++++ .../op-attrs/ops/weight_attrs.struct.toml | 11 ++ .../op-attrs/parallel_tensor_dims.dtg.h | 6 +- .../op-attrs/parallel_tensor_dims.struct.toml | 2 +- .../op-attrs/parallel_tensor_shape.dtg.h | 6 +- .../parallel_tensor_shape.struct.toml | 2 +- .../include/op-attrs/pcg_operator_attrs.h | 1 + lib/op-attrs/include/op-attrs/tensor_shape.h | 2 +- lib/op-attrs/src/get_op_type.cc | 96 ----------- .../op-attrs/computation_graph_op_attrs.cc | 10 ++ .../computation_graph_op_attrs.dtg.cc | 21 ++- lib/op-attrs/src/op-attrs/get_op_type.cc | 37 ++++ lib/op-attrs/src/op-attrs/ops/conv_2d.cc | 2 +- .../src/op-attrs/ops/weight_attrs.dtg.cc | 70 ++++++++ .../src/op-attrs/parallel_tensor_dims.dtg.cc | 18 +- .../src/op-attrs/parallel_tensor_shape.dtg.cc | 18 +- .../src/op-attrs/pcg_operator_attrs.cc | 5 + lib/pcg/include/pcg/computation_graph.dtg.h | 12 +- lib/pcg/include/pcg/computation_graph.h | 3 + .../include/pcg/computation_graph.struct.toml | 6 +- .../layer_added_result.dtg.h | 37 ++++ .../layer_added_result.struct.toml | 19 ++ .../include/pcg/computation_graph_builder.h | 39 ++--- lib/pcg/include/pcg/layer_attrs.dtg.h | 8 +- lib/pcg/include/pcg/layer_attrs.struct.toml | 4 +- lib/pcg/include/pcg/layer_guid_t.dtg.h | 46 +++++ lib/pcg/include/pcg/layer_guid_t.struct.toml | 16 ++ lib/pcg/include/pcg/operator_graph.h | 65 +++++++ lib/pcg/src/pcg/computation_graph.dtg.cc | 9 +- .../layer_added_result.dtg.cc | 43 +++++ lib/pcg/src/pcg/computation_graph_builder.cc | 162 ++++++++++++------ lib/pcg/src/pcg/layer_attrs.dtg.cc | 10 +- lib/pcg/src/pcg/layer_guid_t.dtg.cc | 59 +++++++ lib/pcg/src/pcg/operator_graph.cc | 39 +++++ lib/pcg/test/CMakeLists.txt | 13 ++ .../src/test_computation_graph_builder.cc | 8 + lib/utils/include/utils/containers.h | 2 + .../include/utils/containers/concat_vectors.h | 17 ++ .../utils/containers/enumerate_vector.h | 20 +++ .../include/utils/containers/extend_vector.h | 17 ++ lib/utils/include/utils/optional.h | 9 + lib/utils/src/exception.cc | 2 +- 48 files changed, 900 insertions(+), 229 deletions(-) create mode 100644 lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h create mode 100644 lib/op-attrs/include/op-attrs/ops/weight_attrs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/ops/weight_attrs.struct.toml delete mode 100644 lib/op-attrs/src/get_op_type.cc create mode 100644 lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc create mode 100644 lib/op-attrs/src/op-attrs/get_op_type.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/weight_attrs.dtg.cc create mode 100644 lib/pcg/include/pcg/computation_graph/layer_added_result.dtg.h create mode 100644 lib/pcg/include/pcg/computation_graph/layer_added_result.struct.toml create mode 100644 lib/pcg/include/pcg/layer_guid_t.dtg.h create mode 100644 lib/pcg/include/pcg/layer_guid_t.struct.toml create mode 100644 lib/pcg/include/pcg/operator_graph.h create mode 100644 lib/pcg/src/pcg/computation_graph/layer_added_result.dtg.cc create mode 100644 lib/pcg/src/pcg/layer_guid_t.dtg.cc create mode 100644 lib/pcg/src/pcg/operator_graph.cc create mode 100644 lib/pcg/test/CMakeLists.txt create mode 100644 lib/pcg/test/src/test_computation_graph_builder.cc create mode 100644 lib/utils/include/utils/containers/concat_vectors.h create mode 100644 lib/utils/include/utils/containers/enumerate_vector.h create mode 100644 lib/utils/include/utils/containers/extend_vector.h diff --git a/.proj.toml b/.proj.toml index 3f4fcddaad..f2e7e20f49 100644 --- a/.proj.toml +++ b/.proj.toml @@ -7,6 +7,7 @@ build_targets = [ "utils", "op-attrs", "kernels", + "pcg", # "substitutions", # "compiler", ] diff --git a/flake.lock b/flake.lock index 9745539839..f90110d8aa 100644 --- a/flake.lock +++ b/flake.lock @@ -43,11 +43,11 @@ ] }, "locked": { - "lastModified": 1713942681, - "narHash": "sha256-thpBjg7m0wCqmcLzLZdZqXIW2sfwUpiBrHriimfeoZU=", + "lastModified": 1714185778, + "narHash": "sha256-Rl33HVDHhmcgKnHPYo96XA7Zm85PUmfuCZGeSWseAdw=", "owner": "lockshaw", "repo": "proj", - "rev": "f9ee9aa7de919734228518f76f0f02d5fcdbb295", + "rev": "9b9465925365e76d0db5c11579d52b47fecc4dcd", "type": "github" }, "original": { diff --git a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.dtg.h b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.dtg.h index 412bd7aea0..02e4ce4f27 100644 --- a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/computation_graph_op_attrs.variant.toml /* proj-data { - "generated_from": "87653647c900faaf564d3069478569e7" + "generated_from": "dc1445fed47c2acaed22038975eec627" } */ @@ -38,6 +38,7 @@ #include "op-attrs/ops/split_attrs.dtg.h" #include "op-attrs/ops/topk_attrs.dtg.h" #include "op-attrs/ops/transpose_attrs.dtg.h" +#include "op-attrs/ops/weight_attrs.dtg.h" #include #include #include @@ -73,6 +74,7 @@ struct ComputationGraphOpAttrs { explicit ComputationGraphOpAttrs(::FlexFlow::SoftmaxAttrs const &); explicit ComputationGraphOpAttrs(::FlexFlow::TopKAttrs const &); explicit ComputationGraphOpAttrs(::FlexFlow::TransposeAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::WeightAttrs const &); template static constexpr bool IsPartOfComputationGraphOpAttrs_v = std::is_same_v || @@ -100,7 +102,8 @@ struct ComputationGraphOpAttrs { std::is_same_v || std::is_same_v || std::is_same_v || - std::is_same_v; + std::is_same_v || + std::is_same_v; template ReturnType visit(Visitor &&v) const { switch (this->index()) { @@ -208,6 +211,10 @@ struct ComputationGraphOpAttrs { ReturnType result = v(this->get<::FlexFlow::TransposeAttrs>()); return result; } + case 26: { + ReturnType result = v(this->get<::FlexFlow::WeightAttrs>()); + return result; + } default: { throw std::runtime_error( fmt::format("Unknown index {} for type ComputationGraphOpAttrs", @@ -322,6 +329,10 @@ struct ComputationGraphOpAttrs { ReturnType result = v(this->get<::FlexFlow::TransposeAttrs>()); return result; } + case 26: { + ReturnType result = v(this->get<::FlexFlow::WeightAttrs>()); + return result; + } default: { throw std::runtime_error( fmt::format("Unknown index {} for type ComputationGraphOpAttrs", @@ -346,7 +357,8 @@ struct ComputationGraphOpAttrs { "::FlexFlow::Pool2DAttrs, ::FlexFlow::ReduceAttrs, " "::FlexFlow::ReverseAttrs, ::FlexFlow::ReshapeAttrs, " "::FlexFlow::SplitAttrs, ::FlexFlow::SoftmaxAttrs, " - "::FlexFlow::TopKAttrs, ::FlexFlow::TransposeAttrs], received T"); + "::FlexFlow::TopKAttrs, ::FlexFlow::TransposeAttrs, " + "::FlexFlow::WeightAttrs], received T"); return std::holds_alternative(this->raw_variant); } template @@ -366,7 +378,8 @@ struct ComputationGraphOpAttrs { "::FlexFlow::Pool2DAttrs, ::FlexFlow::ReduceAttrs, " "::FlexFlow::ReverseAttrs, ::FlexFlow::ReshapeAttrs, " "::FlexFlow::SplitAttrs, ::FlexFlow::SoftmaxAttrs, " - "::FlexFlow::TopKAttrs, ::FlexFlow::TransposeAttrs], received T"); + "::FlexFlow::TopKAttrs, ::FlexFlow::TransposeAttrs, " + "::FlexFlow::WeightAttrs], received T"); return std::get(this->raw_variant); } template @@ -386,7 +399,8 @@ struct ComputationGraphOpAttrs { "::FlexFlow::Pool2DAttrs, ::FlexFlow::ReduceAttrs, " "::FlexFlow::ReverseAttrs, ::FlexFlow::ReshapeAttrs, " "::FlexFlow::SplitAttrs, ::FlexFlow::SoftmaxAttrs, " - "::FlexFlow::TopKAttrs, ::FlexFlow::TransposeAttrs], received T"); + "::FlexFlow::TopKAttrs, ::FlexFlow::TransposeAttrs, " + "::FlexFlow::WeightAttrs], received T"); return std::get(this->raw_variant); } size_t index() const { @@ -423,7 +437,8 @@ struct ComputationGraphOpAttrs { ::FlexFlow::SplitAttrs, ::FlexFlow::SoftmaxAttrs, ::FlexFlow::TopKAttrs, - ::FlexFlow::TransposeAttrs> + ::FlexFlow::TransposeAttrs, + ::FlexFlow::WeightAttrs> raw_variant; }; } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h new file mode 100644 index 0000000000..4be17798f7 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_COMPUTATION_GRAPH_OP_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_COMPUTATION_GRAPH_OP_ATTRS_H + +#include "op-attrs/computation_graph_op_attrs.dtg.h" + +namespace FlexFlow { + +OperatorType get_op_type(ComputationGraphOpAttrs const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.variant.toml b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.variant.toml index fdf9702875..cdbf9332da 100644 --- a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.variant.toml +++ b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.variant.toml @@ -35,6 +35,7 @@ includes = [ "op-attrs/ops/split_attrs.dtg.h", "op-attrs/ops/topk_attrs.dtg.h", "op-attrs/ops/transpose_attrs.dtg.h", + "op-attrs/ops/weight_attrs.dtg.h", ] [[values]] @@ -140,3 +141,7 @@ key = "topk" [[values]] type = "::FlexFlow::TransposeAttrs" key = "transpose" + +[[values]] +type = "::FlexFlow::WeightAttrs" +key = "weight" diff --git a/lib/op-attrs/include/op-attrs/get_op_type.h b/lib/op-attrs/include/op-attrs/get_op_type.h index a2db4ab5f0..03b4f92259 100644 --- a/lib/op-attrs/include/op-attrs/get_op_type.h +++ b/lib/op-attrs/include/op-attrs/get_op_type.h @@ -1,8 +1,38 @@ #ifndef _FLEXFLOW_OP_ATTRS_GET_OP_TYPE_H #define _FLEXFLOW_OP_ATTRS_GET_OP_TYPE_H -#include "operator_attrs.h" -#include "utils/variant.h" +#include "op-attrs/ops/batch_matmul.dtg.h" +#include "op-attrs/ops/batch_norm_attrs.dtg.h" +#include "op-attrs/ops/broadcast.dtg.h" +#include "op-attrs/ops/cast_attrs.dtg.h" +#include "op-attrs/ops/concat_attrs.dtg.h" +#include "op-attrs/ops/conv_2d_attrs.dtg.h" +#include "op-attrs/ops/dropout_attrs.dtg.h" +#include "op-attrs/ops/element_binary_attrs.dtg.h" +#include "op-attrs/ops/element_unary_attrs.dtg.h" +#include "op-attrs/ops/element_scalar_unary_attrs.dtg.h" +#include "op-attrs/ops/embedding_attrs.dtg.h" +#include "op-attrs/ops/flat_attrs.dtg.h" +#include "op-attrs/ops/gather_attrs.dtg.h" +#include "op-attrs/ops/input_attrs.dtg.h" +#include "op-attrs/ops/layer_norm_attrs.dtg.h" +#include "op-attrs/ops/linear_attrs.dtg.h" +#include "op-attrs/ops/attention_attrs.dtg.h" +#include "op-attrs/ops/noop_attrs.dtg.h" +#include "op-attrs/ops/pool_2d_attrs.dtg.h" +#include "op-attrs/ops/reduce_attrs.dtg.h" +#include "op-attrs/ops/reshape_attrs.dtg.h" +#include "op-attrs/ops/reverse_attrs.dtg.h" +#include "op-attrs/ops/split_attrs.dtg.h" +#include "op-attrs/ops/softmax_attrs.dtg.h" +#include "op-attrs/ops/topk_attrs.dtg.h" +#include "op-attrs/ops/transpose_attrs.dtg.h" +#include "op-attrs/ops/combine_attrs.dtg.h" +#include "op-attrs/ops/reduction_attrs.dtg.h" +#include "op-attrs/ops/repartition_attrs.dtg.h" +#include "op-attrs/ops/replicate_attrs.dtg.h" +#include "op-attrs/ops/weight_attrs.dtg.h" + namespace FlexFlow { @@ -32,23 +62,13 @@ OperatorType get_op_type(SplitAttrs const &); OperatorType get_op_type(SoftmaxAttrs const &); OperatorType get_op_type(TopKAttrs const &); OperatorType get_op_type(TransposeAttrs const &); +OperatorType get_op_type(WeightAttrs const &); + OperatorType get_op_type(CombineAttrs const &); OperatorType get_op_type(ReductionAttrs const &); OperatorType get_op_type(RepartitionAttrs const &); OperatorType get_op_type(ReplicateAttrs const &); -struct GetOpTypeFunctor { - template - OperatorType operator()(T const &t) { - return get_op_type(t); - } -}; - -template -OperatorType get_op_type(std::variant const &attrs) { - return visit(GetOpTypeFunctor{}, attrs); -} - } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/weight_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/weight_attrs.dtg.h new file mode 100644 index 0000000000..4a19909c25 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/weight_attrs.dtg.h @@ -0,0 +1,58 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/weight_attrs.struct.toml +/* proj-data +{ + "generated_from": "59f49374ffca95b2117b8940af1b6cac" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_WEIGHT_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_WEIGHT_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct WeightAttrs { + bool operator==(WeightAttrs const &) const; + bool operator!=(WeightAttrs const &) const; + bool operator<(WeightAttrs const &) const; + bool operator>(WeightAttrs const &) const; + bool operator<=(WeightAttrs const &) const; + bool operator>=(WeightAttrs const &) const; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::WeightAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::WeightAttrs from_json(json const &); + static void to_json(json &, FlexFlow::WeightAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(WeightAttrs const &); +std::ostream &operator<<(std::ostream &, WeightAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_WEIGHT_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/weight_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/weight_attrs.struct.toml new file mode 100644 index 0000000000..28810a437e --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/weight_attrs.struct.toml @@ -0,0 +1,11 @@ +namespace = "FlexFlow" +name = "WeightAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] +fields = [] diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.dtg.h b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.dtg.h index 1090244a1b..d2d1b13a49 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.dtg.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml /* proj-data { - "generated_from": "141639bdce009a1594501f33c2f25c9e" + "generated_from": "31a9e757f42ec3e468b299cda2cbcd4e" } */ @@ -31,6 +31,10 @@ struct ParallelTensorDims { bool operator==(ParallelTensorDims const &) const; bool operator!=(ParallelTensorDims const &) const; + bool operator<(ParallelTensorDims const &) const; + bool operator>(ParallelTensorDims const &) const; + bool operator<=(ParallelTensorDims const &) const; + bool operator>=(ParallelTensorDims const &) const; ::FlexFlow::FFOrdered<::FlexFlow::ShardParallelDim> shard_dims; ::FlexFlow::ReplicaParallelDimSet replica_dims; }; diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml index 0d07939ff0..37216b160e 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml @@ -2,7 +2,7 @@ namespace = "FlexFlow" name = "ParallelTensorDims" features = [ "eq", - # "ord", + "ord", "hash", "json", # "rapidcheck", diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.dtg.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.dtg.h index b253880764..dfad5b1007 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.dtg.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml /* proj-data { - "generated_from": "bc7e838003fe037b95d45cd5ab4aa16f" + "generated_from": "b2d36c9212916e66569af4e958c893f4" } */ @@ -26,6 +26,10 @@ struct ParallelTensorShape { bool operator==(ParallelTensorShape const &) const; bool operator!=(ParallelTensorShape const &) const; + bool operator<(ParallelTensorShape const &) const; + bool operator>(ParallelTensorShape const &) const; + bool operator<=(ParallelTensorShape const &) const; + bool operator>=(ParallelTensorShape const &) const; ::FlexFlow::ParallelTensorDims dims; ::FlexFlow::DataType data_type; }; diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml index 411070848d..1199b0d816 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml @@ -2,7 +2,7 @@ namespace = "FlexFlow" name = "ParallelTensorShape" features = [ "eq", - # "ord", + "ord", "hash", "json", # "rapidcheck", diff --git a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h index 4605a4c114..0ad7a9f829 100644 --- a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h +++ b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h @@ -6,6 +6,7 @@ namespace FlexFlow { bool is_parallel_op(PCGOperatorAttrs const &); +OperatorType get_op_type(PCGOperatorAttrs const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/tensor_shape.h b/lib/op-attrs/include/op-attrs/tensor_shape.h index 92d360a95d..de27632919 100644 --- a/lib/op-attrs/include/op-attrs/tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/tensor_shape.h @@ -5,8 +5,8 @@ namespace FlexFlow { -size_t dim_at_idx(TensorShape const &, ff_dim_t); size_t num_dims(TensorShape const &); +size_t dim_at_idx(TensorShape const &, ff_dim_t); } // namespace FlexFlow diff --git a/lib/op-attrs/src/get_op_type.cc b/lib/op-attrs/src/get_op_type.cc deleted file mode 100644 index 2fb539472f..0000000000 --- a/lib/op-attrs/src/get_op_type.cc +++ /dev/null @@ -1,96 +0,0 @@ -#include "op-attrs/get_op_type.h" - -namespace FlexFlow { - -OperatorType get_op_type(BatchMatmulAttrs const &) { - return OperatorType::BATCHMATMUL; -} -OperatorType get_op_type(BatchNormAttrs const &) { - return OperatorType::BATCHNORM; -} -OperatorType get_op_type(BroadcastAttrs const &) { - return OperatorType::BROADCAST; -} -OperatorType get_op_type(CastAttrs const &) { - return OperatorType::CAST; -} -OperatorType get_op_type(ConcatAttrs const &) { - return OperatorType::CONCAT; -} -OperatorType get_op_type(Conv2DAttrs const &) { - return OperatorType::CONV2D; -} -OperatorType get_op_type(DropoutAttrs const &) { - return OperatorType::DROPOUT; -} -OperatorType get_op_type(ElementBinaryAttrs const &attrs) { - return attrs.type; -} -OperatorType get_op_type(ElementUnaryAttrs const &attrs) { - return attrs.op_type; -} -OperatorType get_op_type(ElementScalarUnaryAttrs const &attrs) { - return attrs.op_type; -} -OperatorType get_op_type(EmbeddingAttrs const &) { - return OperatorType::EMBEDDING; -} -OperatorType get_op_type(FlatAttrs const &) { - return OperatorType::FLAT; -} -OperatorType get_op_type(GatherAttrs const &) { - return OperatorType::GATHER; -} -OperatorType get_op_type(InputAttrs const &) { - return OperatorType::INPUT; -} -OperatorType get_op_type(LayerNormAttrs const &) { - return OperatorType::LAYERNORM; -} -OperatorType get_op_type(LinearAttrs const &) { - return OperatorType::LINEAR; -} -OperatorType get_op_type(MultiHeadAttentionAttrs const &) { - return OperatorType::MULTIHEAD_ATTENTION; -} -OperatorType get_op_type(NoopAttrs const &) { - return OperatorType::NOOP; -} -OperatorType get_op_type(Pool2DAttrs const &) { - return OperatorType::POOL2D; -} -OperatorType get_op_type(ReduceAttrs const &) { - return OperatorType::REDUCE_SUM; -} -OperatorType get_op_type(ReshapeAttrs const &) { - return OperatorType::RESHAPE; -} -OperatorType get_op_type(SplitAttrs const &) { - return OperatorType::SPLIT; -} -OperatorType get_op_type(SoftmaxAttrs const &) { - return OperatorType::SOFTMAX; -} -OperatorType get_op_type(TopKAttrs const &) { - return OperatorType::TOPK; -} -OperatorType get_op_type(TransposeAttrs const &) { - return OperatorType::TRANSPOSE; -} -OperatorType get_op_type(CombineAttrs const &) { - return OperatorType::COMBINE; -} -OperatorType get_op_type(ReductionAttrs const &) { - return OperatorType::REDUCTION; -} -OperatorType get_op_type(RepartitionAttrs const &) { - return OperatorType::REPARTITION; -} -OperatorType get_op_type(ReplicateAttrs const &) { - return OperatorType::REPLICATE; -} -OperatorType get_op_type(ReverseAttrs const &attrs) { - return OperatorType::REVERSE; -} - -} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc b/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc new file mode 100644 index 0000000000..a7145ca1dd --- /dev/null +++ b/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc @@ -0,0 +1,10 @@ +#include "op-attrs/computation_graph_op_attrs.h" +#include "op-attrs/get_op_type.h" + +namespace FlexFlow { + +OperatorType get_op_type(ComputationGraphOpAttrs const &attrs) { + return attrs.visit([](auto const &x) { return get_op_type(x); }); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.dtg.cc index b92e835ee6..7b0db513bf 100644 --- a/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.dtg.cc @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/computation_graph_op_attrs.variant.toml /* proj-data { - "generated_from": "87653647c900faaf564d3069478569e7" + "generated_from": "dc1445fed47c2acaed22038975eec627" } */ @@ -88,6 +88,9 @@ ComputationGraphOpAttrs::ComputationGraphOpAttrs(::FlexFlow::TopKAttrs const &v) ComputationGraphOpAttrs::ComputationGraphOpAttrs( ::FlexFlow::TransposeAttrs const &v) : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::WeightAttrs const &v) + : raw_variant(v) {} bool ComputationGraphOpAttrs::operator==( ComputationGraphOpAttrs const &other) const { return this->raw_variant == other.raw_variant; @@ -141,7 +144,8 @@ size_t hash<::FlexFlow::ComputationGraphOpAttrs>::operator()( ::FlexFlow::SplitAttrs, ::FlexFlow::SoftmaxAttrs, ::FlexFlow::TopKAttrs, - ::FlexFlow::TransposeAttrs>>{}(x.raw_variant); + ::FlexFlow::TransposeAttrs, + ::FlexFlow::WeightAttrs>>{}(x.raw_variant); } } // namespace std namespace nlohmann { @@ -227,6 +231,9 @@ ::FlexFlow::ComputationGraphOpAttrs } else if (key == "transpose") { return ::FlexFlow::ComputationGraphOpAttrs{ j.at("value").template get<::FlexFlow::TransposeAttrs>()}; + } else if (key == "weight") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::WeightAttrs>()}; } else { throw std::runtime_error(fmt::format("Unknown type key {}", key)); } @@ -365,6 +372,11 @@ void adl_serializer<::FlexFlow::ComputationGraphOpAttrs>::to_json( j["value"] = x.get<::FlexFlow::TransposeAttrs>(); break; } + case 26: { + j["type"] = "weight"; + j["value"] = x.get<::FlexFlow::WeightAttrs>(); + break; + } default: { throw std::runtime_error(fmt::format( "Unknown index {} for type ComputationGraphOpAttrs", x.index())); @@ -506,6 +518,11 @@ std::string format_as(::FlexFlow::ComputationGraphOpAttrs const &x) { << x.get<::FlexFlow::TransposeAttrs>() << ">"; break; } + case 26: { + oss << ""; + break; + } default: { throw std::runtime_error(fmt::format( "Unknown index {} for type ComputationGraphOpAttrs", x.index())); diff --git a/lib/op-attrs/src/op-attrs/get_op_type.cc b/lib/op-attrs/src/op-attrs/get_op_type.cc new file mode 100644 index 0000000000..2c658d9189 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/get_op_type.cc @@ -0,0 +1,37 @@ +#include "op-attrs/get_op_type.h" + +namespace FlexFlow { + +OperatorType get_op_type(BatchMatmulAttrs const &) { return OperatorType::BATCHMATMUL; } +OperatorType get_op_type(BatchNormAttrs const &) { return OperatorType::BATCHNORM; } +OperatorType get_op_type(BroadcastAttrs const &) { return OperatorType::BROADCAST; } +OperatorType get_op_type(CastAttrs const &) { return OperatorType::CAST; } +OperatorType get_op_type(ConcatAttrs const &) { return OperatorType::CONCAT; } +OperatorType get_op_type(Conv2DAttrs const &) { return OperatorType::CONV2D; } +OperatorType get_op_type(DropoutAttrs const &) { return OperatorType::DROPOUT; } +OperatorType get_op_type(ElementBinaryAttrs const &attrs) { return attrs.type; } +OperatorType get_op_type(ElementUnaryAttrs const &attrs) { return attrs.op_type; } +OperatorType get_op_type(ElementScalarUnaryAttrs const &attrs) { return attrs.op_type; } +OperatorType get_op_type(EmbeddingAttrs const &) { return OperatorType::EMBEDDING; } +OperatorType get_op_type(FlatAttrs const &) { return OperatorType::FLAT; } +OperatorType get_op_type(GatherAttrs const &) { return OperatorType::GATHER; } +OperatorType get_op_type(InputAttrs const &) { return OperatorType::INPUT; } +OperatorType get_op_type(LayerNormAttrs const &) { return OperatorType::LAYERNORM; } +OperatorType get_op_type(LinearAttrs const &) { return OperatorType::LINEAR; } +OperatorType get_op_type(MultiHeadAttentionAttrs const &) { return OperatorType::MULTIHEAD_ATTENTION; } +OperatorType get_op_type(NoopAttrs const &) { return OperatorType::NOOP; } +OperatorType get_op_type(Pool2DAttrs const &) { return OperatorType::POOL2D; } +OperatorType get_op_type(ReduceAttrs const &attrs) { return attrs.op_type; } +OperatorType get_op_type(ReshapeAttrs const &) { return OperatorType::RESHAPE; } +OperatorType get_op_type(ReverseAttrs const &) { return OperatorType::REVERSE; } +OperatorType get_op_type(SplitAttrs const &) { return OperatorType::SPLIT; } +OperatorType get_op_type(SoftmaxAttrs const &) { return OperatorType::SOFTMAX; } +OperatorType get_op_type(TopKAttrs const &) { return OperatorType::TOPK; } +OperatorType get_op_type(TransposeAttrs const &) { return OperatorType::TRANSPOSE; } +OperatorType get_op_type(CombineAttrs const &) { return OperatorType::COMBINE; } +OperatorType get_op_type(ReductionAttrs const &) { return OperatorType::REDUCTION; } +OperatorType get_op_type(RepartitionAttrs const &) { return OperatorType::REPARTITION; } +OperatorType get_op_type(ReplicateAttrs const &) { return OperatorType::REPLICATE; } +OperatorType get_op_type(WeightAttrs const &) { return OperatorType::WEIGHT; } + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc index 4271a770b6..e07e398fa6 100644 --- a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc @@ -99,7 +99,7 @@ ParallelTensorShape get_bias_shape(Conv2DAttrs const &attrs, ParallelTensorShape assert (attrs.groups == 1); // TODO(@lockshaw): currently not supported Conv2DParallelInputShape input = parse_parallel_input_shape(raw_input_shape); - ShardParallelDim output_channels_dim = {attrs.out_channels, input.discard_copy_reduction_degree}; + ShardParallelDim output_channels_dim = {as_size_t(attrs.out_channels), input.discard_copy_reduction_degree}; int sum_degree = 1; int discard_copy_degree = input.height_dim.degree * input.width_dim.degree * input.sum_reduction_degree * input.channel_dim.degree; diff --git a/lib/op-attrs/src/op-attrs/ops/weight_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/weight_attrs.dtg.cc new file mode 100644 index 0000000000..a288161da2 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/weight_attrs.dtg.cc @@ -0,0 +1,70 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/weight_attrs.struct.toml +/* proj-data +{ + "generated_from": "59f49374ffca95b2117b8940af1b6cac" +} +*/ + +#include "op-attrs/ops/weight_attrs.dtg.h" + +#include + +namespace FlexFlow { +bool WeightAttrs::operator==(WeightAttrs const &other) const { + return std::tie() == std::tie(); +} +bool WeightAttrs::operator!=(WeightAttrs const &other) const { + return std::tie() != std::tie(); +} +bool WeightAttrs::operator<(WeightAttrs const &other) const { + return std::tie() < std::tie(); +} +bool WeightAttrs::operator>(WeightAttrs const &other) const { + return std::tie() > std::tie(); +} +bool WeightAttrs::operator<=(WeightAttrs const &other) const { + return std::tie() <= std::tie(); +} +bool WeightAttrs::operator>=(WeightAttrs const &other) const { + return std::tie() >= std::tie(); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::WeightAttrs const &x) const { + size_t result = 0; + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::WeightAttrs + adl_serializer::from_json(json const &j) { + return {}; +} +void adl_serializer::to_json( + json &j, FlexFlow::WeightAttrs const &v) { + j["__type"] = "WeightAttrs"; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct(); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(WeightAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, WeightAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.dtg.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.dtg.cc index aee6e9ab14..14372d42cc 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.dtg.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.dtg.cc @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml /* proj-data { - "generated_from": "141639bdce009a1594501f33c2f25c9e" + "generated_from": "31a9e757f42ec3e468b299cda2cbcd4e" } */ @@ -30,6 +30,22 @@ bool ParallelTensorDims::operator!=(ParallelTensorDims const &other) const { return std::tie(this->shard_dims, this->replica_dims) != std::tie(other.shard_dims, other.replica_dims); } +bool ParallelTensorDims::operator<(ParallelTensorDims const &other) const { + return std::tie(this->shard_dims, this->replica_dims) < + std::tie(other.shard_dims, other.replica_dims); +} +bool ParallelTensorDims::operator>(ParallelTensorDims const &other) const { + return std::tie(this->shard_dims, this->replica_dims) > + std::tie(other.shard_dims, other.replica_dims); +} +bool ParallelTensorDims::operator<=(ParallelTensorDims const &other) const { + return std::tie(this->shard_dims, this->replica_dims) <= + std::tie(other.shard_dims, other.replica_dims); +} +bool ParallelTensorDims::operator>=(ParallelTensorDims const &other) const { + return std::tie(this->shard_dims, this->replica_dims) >= + std::tie(other.shard_dims, other.replica_dims); +} } // namespace FlexFlow namespace std { diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.dtg.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.dtg.cc index f990e21fa4..037acbf996 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.dtg.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.dtg.cc @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml /* proj-data { - "generated_from": "bc7e838003fe037b95d45cd5ab4aa16f" + "generated_from": "b2d36c9212916e66569af4e958c893f4" } */ @@ -26,6 +26,22 @@ bool ParallelTensorShape::operator!=(ParallelTensorShape const &other) const { return std::tie(this->dims, this->data_type) != std::tie(other.dims, other.data_type); } +bool ParallelTensorShape::operator<(ParallelTensorShape const &other) const { + return std::tie(this->dims, this->data_type) < + std::tie(other.dims, other.data_type); +} +bool ParallelTensorShape::operator>(ParallelTensorShape const &other) const { + return std::tie(this->dims, this->data_type) > + std::tie(other.dims, other.data_type); +} +bool ParallelTensorShape::operator<=(ParallelTensorShape const &other) const { + return std::tie(this->dims, this->data_type) <= + std::tie(other.dims, other.data_type); +} +bool ParallelTensorShape::operator>=(ParallelTensorShape const &other) const { + return std::tie(this->dims, this->data_type) >= + std::tie(other.dims, other.data_type); +} } // namespace FlexFlow namespace std { diff --git a/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc b/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc index cb54736a7a..60043a82b7 100644 --- a/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc +++ b/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc @@ -1,4 +1,5 @@ #include "op-attrs/pcg_operator_attrs.h" +#include "op-attrs/get_op_type.h" namespace FlexFlow { @@ -9,4 +10,8 @@ bool is_parallel_op(PCGOperatorAttrs const &attrs) { || attrs.has()); } +OperatorType get_op_type(PCGOperatorAttrs const &attrs) { + return attrs.visit([](auto const &x) { return get_op_type(x); }); +} + } // namespace FlexFlow diff --git a/lib/pcg/include/pcg/computation_graph.dtg.h b/lib/pcg/include/pcg/computation_graph.dtg.h index c5a74b08d5..d0cdefcb7a 100644 --- a/lib/pcg/include/pcg/computation_graph.dtg.h +++ b/lib/pcg/include/pcg/computation_graph.dtg.h @@ -3,7 +3,7 @@ // lib/pcg/include/pcg/computation_graph.struct.toml /* proj-data { - "generated_from": "3639f7e8bb97a5ca2c2ef13caff3c84e" + "generated_from": "7d22a6bc44163f331bc33002714721cf" } */ @@ -11,19 +11,19 @@ #define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_DTG_H #include "pcg/layer_attrs.dtg.h" +#include "pcg/operator_graph.h" #include "pcg/tensor_attrs.dtg.h" -#include "utils/graph.h" namespace FlexFlow { struct ComputationGraph { ComputationGraph() = delete; ComputationGraph( - ::FlexFlow::OutputLabelledMultiDiGraph<::FlexFlow::LayerAttrs, - ::FlexFlow::TensorAttrs> const + ::FlexFlow::LabelledOperatorGraph<::FlexFlow::LayerAttrs, + ::FlexFlow::TensorAttrs> const &raw_graph); - ::FlexFlow::OutputLabelledMultiDiGraph<::FlexFlow::LayerAttrs, - ::FlexFlow::TensorAttrs> + ::FlexFlow::LabelledOperatorGraph<::FlexFlow::LayerAttrs, + ::FlexFlow::TensorAttrs> raw_graph; }; } // namespace FlexFlow diff --git a/lib/pcg/include/pcg/computation_graph.h b/lib/pcg/include/pcg/computation_graph.h index a937a9d46e..d68574ad71 100644 --- a/lib/pcg/include/pcg/computation_graph.h +++ b/lib/pcg/include/pcg/computation_graph.h @@ -4,9 +4,12 @@ #include "pcg/computation_graph.dtg.h" #include "pcg/tensor_guid_t.dtg.h" #include "pcg/tensor_attrs.dtg.h" +#include "pcg/layer_guid_t.dtg.h" +#include "pcg/computation_graph/layer_added_result.dtg.h" namespace FlexFlow { +LayerAddedResult add_layer(ComputationGraph &computation_graph, LayerAttrs const &attrs, std::vector const &inputs, std::vector const &outputs); TensorAttrs get_tensor_attrs(ComputationGraph const &, tensor_guid_t const &); } // namespace FlexFlow diff --git a/lib/pcg/include/pcg/computation_graph.struct.toml b/lib/pcg/include/pcg/computation_graph.struct.toml index 30b3487da1..b8b3d8c372 100644 --- a/lib/pcg/include/pcg/computation_graph.struct.toml +++ b/lib/pcg/include/pcg/computation_graph.struct.toml @@ -3,11 +3,11 @@ name = "ComputationGraph" features = [ ] includes = [ - "utils/graph.h", "pcg/layer_attrs.dtg.h", - "pcg/tensor_attrs.dtg.h" + "pcg/tensor_attrs.dtg.h", + "pcg/operator_graph.h", ] [[fields]] name = "raw_graph" -type = "::FlexFlow::OutputLabelledMultiDiGraph<::FlexFlow::LayerAttrs, ::FlexFlow::TensorAttrs>" +type = "::FlexFlow::LabelledOperatorGraph<::FlexFlow::LayerAttrs, ::FlexFlow::TensorAttrs>" diff --git a/lib/pcg/include/pcg/computation_graph/layer_added_result.dtg.h b/lib/pcg/include/pcg/computation_graph/layer_added_result.dtg.h new file mode 100644 index 0000000000..4fd78f2d44 --- /dev/null +++ b/lib/pcg/include/pcg/computation_graph/layer_added_result.dtg.h @@ -0,0 +1,37 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/computation_graph/layer_added_result.struct.toml +/* proj-data +{ + "generated_from": "15bf9d73ef934599c9b11807d86ae5d4" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_LAYER_ADDED_RESULT_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_LAYER_ADDED_RESULT_DTG_H + +#include "fmt/format.h" +#include "pcg/layer_guid_t.dtg.h" +#include "pcg/tensor_guid_t.dtg.h" +#include +#include + +namespace FlexFlow { +struct LayerAddedResult { + LayerAddedResult() = delete; + LayerAddedResult(::FlexFlow::layer_guid_t const &layer, + std::vector<::FlexFlow::tensor_guid_t> const &outputs); + + bool operator==(LayerAddedResult const &) const; + bool operator!=(LayerAddedResult const &) const; + ::FlexFlow::layer_guid_t layer; + std::vector<::FlexFlow::tensor_guid_t> outputs; +}; +} // namespace FlexFlow + +namespace FlexFlow { +std::string format_as(LayerAddedResult const &); +std::ostream &operator<<(std::ostream &, LayerAddedResult const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_LAYER_ADDED_RESULT_DTG_H diff --git a/lib/pcg/include/pcg/computation_graph/layer_added_result.struct.toml b/lib/pcg/include/pcg/computation_graph/layer_added_result.struct.toml new file mode 100644 index 0000000000..b02e992ba1 --- /dev/null +++ b/lib/pcg/include/pcg/computation_graph/layer_added_result.struct.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "LayerAddedResult" +features = [ + "eq", + "fmt", +] + +includes = [ + "pcg/layer_guid_t.dtg.h", + "pcg/tensor_guid_t.dtg.h", +] + +[[fields]] +name = "layer" +type = "::FlexFlow::layer_guid_t" + +[[fields]] +name = "outputs" +type = "std::vector<::FlexFlow::tensor_guid_t>" diff --git a/lib/pcg/include/pcg/computation_graph_builder.h b/lib/pcg/include/pcg/computation_graph_builder.h index 2721743dcd..d2dea1cbaf 100644 --- a/lib/pcg/include/pcg/computation_graph_builder.h +++ b/lib/pcg/include/pcg/computation_graph_builder.h @@ -169,8 +169,7 @@ struct ComputationGraphBuilder { bool keepdims, char const *name); // Add a split layer - void split(tensor_guid_t const &input, - tensor_guid_t *outputs, + std::vector split(tensor_guid_t const &input, std::vector const &split, int axis, std::optional const &name = std::nullopt); @@ -195,8 +194,7 @@ struct ComputationGraphBuilder { tensor_guid_t reverse(tensor_guid_t const &input, int axis, std::optional const &name = std::nullopt); - void top_k(tensor_guid_t const &input, - tensor_guid_t *outputs, + std::vector top_k(tensor_guid_t const &input, int k, bool sorted, std::optional const &name = std::nullopt); @@ -229,28 +227,27 @@ struct ComputationGraphBuilder { TensorAttrs get_attrs(tensor_guid_t const &) const; TensorShape get_shape(tensor_guid_t const &) const; + private: tensor_guid_t broadcast(tensor_guid_t const &, TensorShape const &); - void add_layer(LayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs); - tensor_guid_t add_layer( - LayerAttrs const &layer, - std::vector const &inputs, - std::vector>> const - &weight_shapes, - TensorShape const &output_shape); - std::vector add_layer( - LayerAttrs const &layer, - std::vector const &inputs, - std::vector>> const - &weight_shapes, - std::vector const &output_shapes); - tensor_guid_t as_type(tensor_guid_t const &, DataType, std::string const &); + std::vector add_layer(LayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + std::vector const &outputs); + + tensor_guid_t add_layer(LayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + TensorAttrs const &output); + + tensor_guid_t add_layer(LayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + TensorShape const &output); + TensorShape get_broadcast_target_shape(std::vector const &); TensorShape get_broadcast_target_shape(std::vector const &); diff --git a/lib/pcg/include/pcg/layer_attrs.dtg.h b/lib/pcg/include/pcg/layer_attrs.dtg.h index 9c9d277b67..6afa1757dc 100644 --- a/lib/pcg/include/pcg/layer_attrs.dtg.h +++ b/lib/pcg/include/pcg/layer_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/pcg/include/pcg/layer_attrs.struct.toml /* proj-data { - "generated_from": "12b49c15e8defff5118e5607a7823f59" + "generated_from": "b3e4f0c07a906139b599bd4696cb5e65" } */ @@ -12,7 +12,7 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" -#include "op-attrs/operator_attrs.h" +#include "op-attrs/computation_graph_op_attrs.dtg.h" #include "utils/json.h" #include "utils/stack_string.h" #include @@ -23,7 +23,7 @@ namespace FlexFlow { struct LayerAttrs { LayerAttrs() = delete; - LayerAttrs(::FlexFlow::CompGraphOperatorAttrs const &attrs, + LayerAttrs(::FlexFlow::ComputationGraphOpAttrs const &attrs, std::optional<::FlexFlow::stack_string> const &name); bool operator==(LayerAttrs const &) const; @@ -32,7 +32,7 @@ struct LayerAttrs { bool operator>(LayerAttrs const &) const; bool operator<=(LayerAttrs const &) const; bool operator>=(LayerAttrs const &) const; - ::FlexFlow::CompGraphOperatorAttrs attrs; + ::FlexFlow::ComputationGraphOpAttrs attrs; std::optional<::FlexFlow::stack_string> name; }; } // namespace FlexFlow diff --git a/lib/pcg/include/pcg/layer_attrs.struct.toml b/lib/pcg/include/pcg/layer_attrs.struct.toml index 0dec35a1d8..9f8aaa5ba3 100644 --- a/lib/pcg/include/pcg/layer_attrs.struct.toml +++ b/lib/pcg/include/pcg/layer_attrs.struct.toml @@ -10,7 +10,7 @@ features = [ ] includes = [ - "op-attrs/operator_attrs.h", + "op-attrs/computation_graph_op_attrs.dtg.h", "utils/stack_string.h", "", "utils/json.h" @@ -18,7 +18,7 @@ includes = [ [[fields]] name = "attrs" -type = "::FlexFlow::CompGraphOperatorAttrs" +type = "::FlexFlow::ComputationGraphOpAttrs" [[fields]] name = "name" diff --git a/lib/pcg/include/pcg/layer_guid_t.dtg.h b/lib/pcg/include/pcg/layer_guid_t.dtg.h new file mode 100644 index 0000000000..4bbdd36fed --- /dev/null +++ b/lib/pcg/include/pcg/layer_guid_t.dtg.h @@ -0,0 +1,46 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/layer_guid_t.struct.toml +/* proj-data +{ + "generated_from": "a672ffe470fd1dde8299f91f3038ca7a" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_LAYER_GUID_T_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_LAYER_GUID_T_DTG_H + +#include "fmt/format.h" +#include "utils/graph.h" +#include +#include +#include + +namespace FlexFlow { +struct layer_guid_t { + layer_guid_t() = delete; + layer_guid_t(::FlexFlow::Node const &raw_node); + + bool operator==(layer_guid_t const &) const; + bool operator!=(layer_guid_t const &) const; + bool operator<(layer_guid_t const &) const; + bool operator>(layer_guid_t const &) const; + bool operator<=(layer_guid_t const &) const; + bool operator>=(layer_guid_t const &) const; + ::FlexFlow::Node raw_node; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::layer_guid_t const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(layer_guid_t const &); +std::ostream &operator<<(std::ostream &, layer_guid_t const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_LAYER_GUID_T_DTG_H diff --git a/lib/pcg/include/pcg/layer_guid_t.struct.toml b/lib/pcg/include/pcg/layer_guid_t.struct.toml new file mode 100644 index 0000000000..c6d4073f58 --- /dev/null +++ b/lib/pcg/include/pcg/layer_guid_t.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "layer_guid_t" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph.h", +] + +[[fields]] +name = "raw_node" +type = "::FlexFlow::Node" diff --git a/lib/pcg/include/pcg/operator_graph.h b/lib/pcg/include/pcg/operator_graph.h new file mode 100644 index 0000000000..2ea8feda97 --- /dev/null +++ b/lib/pcg/include/pcg/operator_graph.h @@ -0,0 +1,65 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_H + +#include "utils/graph.h" + +namespace FlexFlow { + +struct OperatorGraphOutput { }; +struct OperatorGraphInput { }; +struct OperatorGraphOutputQuery { }; +struct OperatorGraphEdge { }; + +Node get_node(OperatorGraphOutput const &); +int get_idx(OperatorGraphOutput const &); + +Node get_node(OperatorGraphInput const &); +int get_idx(OperatorGraphInput const &); + +Node get_src_node(OperatorGraphEdge const &); +Node get_dst_node(OperatorGraphEdge const &); +int get_src_idx(OperatorGraphEdge const &); +int get_dst_idx(OperatorGraphEdge const &); + +struct OperatorGraphEdgeQuery; + +struct OperatorGraphView : virtual MultiDiGraphView { +public: + using Edge = OperatorGraphEdge; + using EdgeQuery = OperatorGraphEdgeQuery; + + OperatorGraphView(OperatorGraphView const &) = default; + OperatorGraphView &operator=(OperatorGraphView const &) = default; + + std::unordered_set query_nodes(NodeQuery const &) const; + std::unordered_set query_outputs(OperatorGraphOutputQuery const &) const; + std::unordered_set query_edges(OperatorGraphEdgeQuery const &) const; +}; +CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(MultiDiGraphView); + +std::vector get_outputs(OperatorGraphView const &, Node const &); +std::unordered_set get_uses(OperatorGraphView const &, OperatorGraphOutput const &); + +struct OperatorGraph : virtual OperatorGraphView { +public: + OperatorGraph() = delete; + OperatorGraph(OperatorGraph const &) = default; + OperatorGraph &operator=(OperatorGraph const &) = default; + + Node add_node(std::vector const &inputs, int num_outputs); +}; + +template +struct LabelledOperatorGraphView : virtual OperatorGraphView { + NodeLabel const &at(Node const &) const; + OutputLabel const &at(OperatorGraphOutput const &) const; +}; + +template +struct LabelledOperatorGraph : virtual LabelledOperatorGraphView { + Node add_node(NodeLabel const &, std::vector const &inputs, std::vector const &output_labels); +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/src/pcg/computation_graph.dtg.cc b/lib/pcg/src/pcg/computation_graph.dtg.cc index b9b2ae56ee..5b376cad77 100644 --- a/lib/pcg/src/pcg/computation_graph.dtg.cc +++ b/lib/pcg/src/pcg/computation_graph.dtg.cc @@ -3,20 +3,19 @@ // lib/pcg/include/pcg/computation_graph.struct.toml /* proj-data { - "generated_from": "3639f7e8bb97a5ca2c2ef13caff3c84e" + "generated_from": "7d22a6bc44163f331bc33002714721cf" } */ #include "pcg/computation_graph.dtg.h" #include "pcg/layer_attrs.dtg.h" +#include "pcg/operator_graph.h" #include "pcg/tensor_attrs.dtg.h" -#include "utils/graph.h" namespace FlexFlow { ComputationGraph::ComputationGraph( - ::FlexFlow::OutputLabelledMultiDiGraph<::FlexFlow::LayerAttrs, - ::FlexFlow::TensorAttrs> const - &raw_graph) + ::FlexFlow::LabelledOperatorGraph<::FlexFlow::LayerAttrs, + ::FlexFlow::TensorAttrs> const &raw_graph) : raw_graph(raw_graph) {} } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/computation_graph/layer_added_result.dtg.cc b/lib/pcg/src/pcg/computation_graph/layer_added_result.dtg.cc new file mode 100644 index 0000000000..18b394f6d0 --- /dev/null +++ b/lib/pcg/src/pcg/computation_graph/layer_added_result.dtg.cc @@ -0,0 +1,43 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/computation_graph/layer_added_result.struct.toml +/* proj-data +{ + "generated_from": "15bf9d73ef934599c9b11807d86ae5d4" +} +*/ + +#include "pcg/computation_graph/layer_added_result.dtg.h" + +#include "pcg/layer_guid_t.dtg.h" +#include "pcg/tensor_guid_t.dtg.h" +#include + +namespace FlexFlow { +LayerAddedResult::LayerAddedResult( + ::FlexFlow::layer_guid_t const &layer, + std::vector<::FlexFlow::tensor_guid_t> const &outputs) + : layer(layer), outputs(outputs) {} +bool LayerAddedResult::operator==(LayerAddedResult const &other) const { + return std::tie(this->layer, this->outputs) == + std::tie(other.layer, other.outputs); +} +bool LayerAddedResult::operator!=(LayerAddedResult const &other) const { + return std::tie(this->layer, this->outputs) != + std::tie(other.layer, other.outputs); +} +} // namespace FlexFlow + +namespace FlexFlow { +std::string format_as(LayerAddedResult const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, LayerAddedResult const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index f6336f9510..f6d14c64ab 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -1,38 +1,77 @@ #include "pcg/computation_graph_builder.h" #include "op-attrs/get_op_type.h" #include "op-attrs/get_output_shapes.h" +#include "op-attrs/ops/weight_attrs.dtg.h" #include "utils/expected.h" #include "utils/fmt.h" #include "op-attrs/ops/element_binary.h" #include "op-attrs/ops/embedding.h" - +#include "op-attrs/computation_graph_op_attrs.h" +#include "utils/containers.h" +#include "utils/containers/enumerate_vector.h" +#include "pcg/computation_graph.h" +#include "utils/containers/concat_vectors.h" + namespace FlexFlow { -void ComputationGraphBuilder::add_layer(LayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs) { - NOT_IMPLEMENTED(); +std::vector ComputationGraphBuilder::add_layer(LayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + std::vector const &outputs) { + std::vector weight_tensors; + for (auto const &kv : enumerate_vector(weights)) { + int weight_idx = kv.first; + TensorAttrs weight_tensor_attrs = kv.second; + + std::optional weight_name = transform(layer.name, [&](std::string const &layer_name) { return fmt::format("{}.weights[{}]", layer_name, weight_idx); }); + LayerAttrs weight_layer_attrs = LayerAttrs{ + ComputationGraphOpAttrs{WeightAttrs{}}, + weight_name, + }; + std::vector weight_layer_inputs = {}; + std::vector weight_layer_outputs = {weight_tensor_attrs}; + LayerAddedResult added_weight = ::FlexFlow::add_layer(this->computation_graph, weight_layer_attrs, weight_layer_inputs, weight_layer_outputs); + weight_tensors.push_back(get_only(added_weight.outputs)); + } + + LayerAddedResult added = ::FlexFlow::add_layer(this->computation_graph, layer, concat_vectors(inputs, weight_tensors), outputs); + return added.outputs; } -tensor_guid_t ComputationGraphBuilder::add_layer( - LayerAttrs const &layer, - std::vector const &inputs, - std::vector>> const - &weight_shapes, - TensorShape const &output_shape) { +tensor_guid_t ComputationGraphBuilder::add_layer(LayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + TensorAttrs const &output) { + std::vector outputs = {output}; + return get_only(this->add_layer(layer, inputs, weights, outputs)); +} + +static tensor_guid_t make_weight_tensor(ComputationGraphBuilder &cgb, TensorShape const &shape, std::optional const &initializer_attrs = std::nullopt) { NOT_IMPLEMENTED(); } -std::vector ComputationGraphBuilder::add_layer( - LayerAttrs const &layer, - std::vector const &inputs, - std::vector>> const - &weight_shapes, - std::vector const &output_shapes) { +static tensor_guid_t make_output_tensor(ComputationGraphBuilder &cgb, TensorShape const &shape, std::optional const &initializer_attrs = std::nullopt) { NOT_IMPLEMENTED(); } +static tensor_guid_t cast_to(ComputationGraphBuilder &cgb, + tensor_guid_t const &x, + DataType data_type, + std::string const &name) { + DataType x_datatype = cgb.get_shape(x).data_type; + if (x_datatype < data_type) { + return cgb.cast(x, data_type, name); + } else if (x_datatype > data_type) { + throw mk_runtime_error(fmt::format("Could not convert provided tensor data type {} to " + "desired data type {}", + x_datatype, + data_type)); + } else { + return x; + } +} + + tensor_guid_t ComputationGraphBuilder::broadcast(tensor_guid_t const &, TensorShape const &) { NOT_IMPLEMENTED(); } @@ -42,32 +81,17 @@ tensor_guid_t ComputationGraphBuilder::cast(tensor_guid_t const &input, std::optional const &name){ NOT_IMPLEMENTED()} -tensor_guid_t ComputationGraphBuilder::as_type(tensor_guid_t const &x, - DataType data_type, - std::string const &name) { - DataType x_datatype = this->get_shape(x).data_type; - if (x_datatype < data_type) { - return this->cast(x, data_type, name); - } else if (x_datatype > data_type) { - throw mk_runtime_error("Could not convert provided tensor data type {} to " - "desired data type {}", - x_datatype, - data_type); - } - return x; -} - static std::string get_default_name(OperatorType op_type) { return get_operator_type_name(op_type); } -static std::string get_default_name(ComputationGraphAttrs const &attrs) { +static std::string get_default_name(ComputationGraphOpAttrs const &attrs) { return get_default_name(get_op_type(attrs)); } -template -static std::string get_default_name(std::variant const &attrs) { - return get_default_name(widen(attrs)); +template +static std::string get_default_name(T const &t) { + return get_default_name(t); } tensor_guid_t ComputationGraphBuilder::element_unary( @@ -78,10 +102,19 @@ tensor_guid_t ComputationGraphBuilder::element_unary( tensor_guid_t input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); - LayerAttrs layer = {attrs, name}; + LayerAttrs layer = LayerAttrs{ + ComputationGraphOpAttrs{attrs}, + name + }; + TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); - return this->add_layer(layer, {input}, {}, output_shape); + return this->add_layer( + layer, + {input}, + {}, + output_shape + ); } tensor_guid_t ComputationGraphBuilder::element_scalar_unary( @@ -92,7 +125,11 @@ tensor_guid_t ComputationGraphBuilder::element_scalar_unary( tensor_guid_t input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); - LayerAttrs layer = {attrs, name}; + LayerAttrs layer = { + ComputationGraphOpAttrs{attrs}, + name + }; + TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); return this->add_layer(layer, {input}, {}, output_shape); @@ -137,7 +174,11 @@ tensor_guid_t ComputationGraphBuilder::element_binary( ElementBinaryAttrs attrs = {op_type, compute_type, false, false}; - LayerAttrs layer = {attrs, name}; + LayerAttrs layer = { + ComputationGraphOpAttrs{attrs}, + name + }; + TensorShape output_shape = get_output_shape( attrs, this->get_shape(lhs_input), @@ -266,6 +307,8 @@ tensor_guid_t ComputationGraphBuilder::elu(tensor_guid_t const &input, return this->element_unary(OperatorType::ELU, input, name); } +static TensorAttrs make_weight_attrs(TensorShape const &, std::optional const &) { NOT_IMPLEMENTED(); } + tensor_guid_t ComputationGraphBuilder::conv2d( tensor_guid_t const &x, int outChannels, @@ -292,19 +335,25 @@ tensor_guid_t ComputationGraphBuilder::conv2d( groups, activation, use_bias}; + std::string name = maybe_name.value_or(get_default_name(attrs)); tensor_guid_t input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); - LayerAttrs layer = {attrs, name}; - TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); + LayerAttrs layer = { + ComputationGraphOpAttrs{attrs}, + name + }; + + TensorShape input_shape = this->get_shape(input); + TensorShape output_shape = get_output_shape(attrs, input_shape); - std::vector>> weights; + std::vector weights; - weights.push_back({get_kernel_shape(attrs, this->get_shape(input)), kernel_initializer}); + weights.push_back(make_weight_attrs(get_kernel_shape(attrs, input_shape), kernel_initializer)); if (use_bias) { - weights.push_back({get_bias_shape(attrs, this->get_shape(input)), bias_initializer}); + weights.push_back(make_weight_attrs(get_bias_shape(attrs, input_shape), bias_initializer)); } return this->add_layer(layer, {input}, weights, output_shape); @@ -318,7 +367,7 @@ tensor_guid_t ComputationGraphBuilder::dropout( DropoutAttrs attrs = {rate, seed}; std::string name = maybe_name.value_or(get_default_name(attrs)); - LayerAttrs layer = {attrs, name}; + LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; tensor_guid_t input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); @@ -333,18 +382,25 @@ tensor_guid_t ComputationGraphBuilder::embedding( AggregateOp aggr, DataType dtype, std::optional const &kernel_initializer, - std::optional const &maybe_name) { + std::optional const &maybe_name) +{ EmbeddingAttrs attrs = {num_entries, outDim, aggr, dtype}; std::string name = maybe_name.value_or(get_default_name(attrs)); - LayerAttrs layer = {attrs, name}; + LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; tensor_guid_t input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); + TensorShape input_shape = this->get_shape(input); + + TensorAttrs weight_attrs = make_weight_attrs( + get_weights_shape(attrs, input_shape), + kernel_initializer + ); + TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); - TensorShape weights_shape = get_weights_shape(attrs, this->get_shape(input)); return this->add_layer( - layer, {input}, {{weights_shape, kernel_initializer}}, output_shape); + layer, {input}, {weight_attrs}, output_shape); } std::vector ComputationGraphBuilder::gather( @@ -355,7 +411,7 @@ std::vector ComputationGraphBuilder::gather( GatherAttrs attrs = {dim}; std::string name = maybe_name.value_or(get_default_name(attrs)); - LayerAttrs layer = {attrs, name}; + LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; if (this->get_shape(index).data_type != DataType::INT32 && this->get_shape(index).data_type != DataType::INT64) { throw mk_runtime_error("Invalid data type for input tensor 2 for Gather: " @@ -411,7 +467,7 @@ tensor_guid_t ComputationGraphBuilder::batch_norm( BatchNormAttrs attrs = BatchNormAttrs{relu}; std::string name = maybe_name.value_or(get_default_name(attrs)); - LayerAttrs layer = {attrs, name}; + LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); diff --git a/lib/pcg/src/pcg/layer_attrs.dtg.cc b/lib/pcg/src/pcg/layer_attrs.dtg.cc index 54fe104ce3..21c53ad4e8 100644 --- a/lib/pcg/src/pcg/layer_attrs.dtg.cc +++ b/lib/pcg/src/pcg/layer_attrs.dtg.cc @@ -3,13 +3,13 @@ // lib/pcg/include/pcg/layer_attrs.struct.toml /* proj-data { - "generated_from": "12b49c15e8defff5118e5607a7823f59" + "generated_from": "b3e4f0c07a906139b599bd4696cb5e65" } */ #include "pcg/layer_attrs.dtg.h" -#include "op-attrs/operator_attrs.h" +#include "op-attrs/computation_graph_op_attrs.dtg.h" #include "utils/json.h" #include "utils/stack_string.h" #include @@ -17,7 +17,7 @@ namespace FlexFlow { LayerAttrs::LayerAttrs( - ::FlexFlow::CompGraphOperatorAttrs const &attrs, + ::FlexFlow::ComputationGraphOpAttrs const &attrs, std::optional<::FlexFlow::stack_string> const &name) : attrs(attrs), name(name) {} bool LayerAttrs::operator==(LayerAttrs const &other) const { @@ -44,7 +44,7 @@ namespace std { size_t hash::operator()( FlexFlow::LayerAttrs const &x) const { size_t result = 0; - result ^= std::hash<::FlexFlow::CompGraphOperatorAttrs>{}(x.attrs) + + result ^= std::hash<::FlexFlow::ComputationGraphOpAttrs>{}(x.attrs) + 0x9e3779b9 + (result << 6) + (result >> 2); result ^= std::hash>>{}(x.name) + @@ -57,7 +57,7 @@ namespace nlohmann { FlexFlow::LayerAttrs adl_serializer::from_json(json const &j) { return { - j.at("attrs").template get<::FlexFlow::CompGraphOperatorAttrs>(), + j.at("attrs").template get<::FlexFlow::ComputationGraphOpAttrs>(), j.at("name") .template get>>()}; } diff --git a/lib/pcg/src/pcg/layer_guid_t.dtg.cc b/lib/pcg/src/pcg/layer_guid_t.dtg.cc new file mode 100644 index 0000000000..9d92608569 --- /dev/null +++ b/lib/pcg/src/pcg/layer_guid_t.dtg.cc @@ -0,0 +1,59 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/layer_guid_t.struct.toml +/* proj-data +{ + "generated_from": "a672ffe470fd1dde8299f91f3038ca7a" +} +*/ + +#include "pcg/layer_guid_t.dtg.h" + +#include "utils/graph.h" +#include + +namespace FlexFlow { +layer_guid_t::layer_guid_t(::FlexFlow::Node const &raw_node) + : raw_node(raw_node) {} +bool layer_guid_t::operator==(layer_guid_t const &other) const { + return std::tie(this->raw_node) == std::tie(other.raw_node); +} +bool layer_guid_t::operator!=(layer_guid_t const &other) const { + return std::tie(this->raw_node) != std::tie(other.raw_node); +} +bool layer_guid_t::operator<(layer_guid_t const &other) const { + return std::tie(this->raw_node) < std::tie(other.raw_node); +} +bool layer_guid_t::operator>(layer_guid_t const &other) const { + return std::tie(this->raw_node) > std::tie(other.raw_node); +} +bool layer_guid_t::operator<=(layer_guid_t const &other) const { + return std::tie(this->raw_node) <= std::tie(other.raw_node); +} +bool layer_guid_t::operator>=(layer_guid_t const &other) const { + return std::tie(this->raw_node) >= std::tie(other.raw_node); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::layer_guid_t const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::Node>{}(x.raw_node) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(layer_guid_t const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, layer_guid_t const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/operator_graph.cc b/lib/pcg/src/pcg/operator_graph.cc new file mode 100644 index 0000000000..70066e6b5b --- /dev/null +++ b/lib/pcg/src/pcg/operator_graph.cc @@ -0,0 +1,39 @@ +#include "pcg/operator_graph.h" + +namespace FlexFlow { + +Node get_node(OperatorGraphOutput const &) { + NOT_IMPLEMENTED(); +} + +int get_idx(OperatorGraphOutput const &) { + NOT_IMPLEMENTED(); +} + +Node get_node(OperatorGraphInput const &) { + NOT_IMPLEMENTED(); +} + +int get_idx(OperatorGraphInput const &) { + NOT_IMPLEMENTED(); +} + +Node get_src_node(OperatorGraphEdge const &) { + NOT_IMPLEMENTED(); +} + +Node get_dst_node(OperatorGraphEdge const &) { + NOT_IMPLEMENTED(); +} + +int get_src_idx(OperatorGraphEdge const &) { + NOT_IMPLEMENTED(); +} + +int get_dst_idx(OperatorGraphEdge const &) { + NOT_IMPLEMENTED(); +} + +/* OperatorGraphView::query_nodes */ + +} // namespace FlexFlow diff --git a/lib/pcg/test/CMakeLists.txt b/lib/pcg/test/CMakeLists.txt new file mode 100644 index 0000000000..685d1d8b88 --- /dev/null +++ b/lib/pcg/test/CMakeLists.txt @@ -0,0 +1,13 @@ +ff_add_test_executable( + NAME + pcg-tests + SRC_PATTERNS + src/*.cc + PRIVATE_INCLUDE + src/ + DEPS + utils + pcg + doctest + utils-test-common +) diff --git a/lib/pcg/test/src/test_computation_graph_builder.cc b/lib/pcg/test/src/test_computation_graph_builder.cc new file mode 100644 index 0000000000..8619ce6ad7 --- /dev/null +++ b/lib/pcg/test/src/test_computation_graph_builder.cc @@ -0,0 +1,8 @@ +#include "doctest/doctest.h" +#include "op-attrs/ops/conv_2d.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("ComputationGraphBuilder") { + CHECK_MESSAGE(false, "TODO: ComputationGraphBuilder"); + } +} diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index ed4e4c0a2b..0db08b77e8 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -18,6 +18,7 @@ #include #include #include +#include "utils/containers/extend_vector.h" namespace FlexFlow { @@ -418,6 +419,7 @@ T get_first(std::unordered_set const &s) { template void extend(std::vector &lhs, C const &rhs) { + extend_vector(lhs, rhs); lhs.reserve(lhs.size() + std::distance(rhs.begin(), rhs.end())); lhs.insert(lhs.end(), rhs.begin(), rhs.end()); } diff --git a/lib/utils/include/utils/containers/concat_vectors.h b/lib/utils/include/utils/containers/concat_vectors.h new file mode 100644 index 0000000000..8c6858c84e --- /dev/null +++ b/lib/utils/include/utils/containers/concat_vectors.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_CONCAT_VECTORS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_CONCAT_VECTORS_H + +#include "utils/containers/extend_vector.h" + +namespace FlexFlow { + +template +std::vector concat_vectors(std::vector const &prefix, std::vector const &postfix) { + std::vector result = prefix; + extend_vector(result, postfix); + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/enumerate_vector.h b/lib/utils/include/utils/containers/enumerate_vector.h new file mode 100644 index 0000000000..bf927d8415 --- /dev/null +++ b/lib/utils/include/utils/containers/enumerate_vector.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ENUMERATE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ENUMERATE_H + +#include +#include + +namespace FlexFlow { + +template +std::vector> enumerate_vector(std::vector const &v) { + std::vector> result; + for (int i = 0; i < v.size(); i++) { + result.push_back({i, v.at(i)}); + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/extend_vector.h b/lib/utils/include/utils/containers/extend_vector.h new file mode 100644 index 0000000000..289ead16c0 --- /dev/null +++ b/lib/utils/include/utils/containers/extend_vector.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_EXTEND_VECTOR_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_EXTEND_VECTOR_H + +#include + +namespace FlexFlow { + +template +void extend_vector(std::vector &lhs, C const &rhs) { + lhs.reserve(lhs.size() + std::distance(rhs.begin(), rhs.end())); + lhs.insert(lhs.end(), rhs.begin(), rhs.end()); +} + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/optional.h b/lib/utils/include/utils/optional.h index 6133a27832..d8d667d0d4 100644 --- a/lib/utils/include/utils/optional.h +++ b/lib/utils/include/utils/optional.h @@ -23,6 +23,15 @@ T const &assert_unwrap(std::optional const &o) { return o.value(); } +template +std::optional> transform(std::optional const &o, F &&f) { + if (o.has_value()) { + return std::optional{f(o)}; + } else { + return std::nullopt; + } +} + } // namespace FlexFlow namespace fmt { diff --git a/lib/utils/src/exception.cc b/lib/utils/src/exception.cc index 1369bea4c9..cb7fb0397b 100644 --- a/lib/utils/src/exception.cc +++ b/lib/utils/src/exception.cc @@ -3,6 +3,6 @@ namespace FlexFlow { not_implemented::not_implemented(std::string const &function_name, std::string const &file_name, int line) - : std::logic_error(fmt::format("Function {} not yet implemented at {}:{}", function_name, file_name, line)){}; + : std::logic_error(fmt::format("Function '{}' not yet implemented at {}:{}", function_name, file_name, line)){}; } From 16ca8de43d71598c5afcc2b46697b1facc31a76d Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Tue, 14 May 2024 16:32:12 -0700 Subject: [PATCH 15/43] Add initial test for pcg --- .proj.toml | 1 + lib/op-attrs/src/op-attrs/ops/conv_2d.cc | 26 +++--- .../src/op-attrs/ops/element_binary.cc | 12 ++- .../src/op-attrs/ops/element_unary.cc | 20 ++++- lib/pcg/CMakeLists.txt | 1 + lib/pcg/include/pcg/computation_graph.dtg.h | 12 ++- lib/pcg/include/pcg/computation_graph.h | 4 + .../include/pcg/computation_graph.struct.toml | 4 +- .../include/pcg/computation_graph_builder.h | 13 +-- lib/pcg/include/pcg/dataflow_graph.h | 70 +++++++++++++++ .../v1/graphs/v1_operator_graph.dtg.h | 45 ++++++++++ .../v1/graphs/v1_operator_graph.struct.toml | 25 ++++++ .../pcg/{ => operator_graph}/operator_graph.h | 35 +++++--- .../operator_graph/operator_graph_input.dtg.h | 47 ++++++++++ .../pcg/operator_graph/operator_graph_input.h | 13 +++ .../operator_graph_input.struct.toml | 20 +++++ .../operator_graph_output.dtg.h | 47 ++++++++++ .../operator_graph/operator_graph_output.h | 13 +++ .../operator_graph_output.struct.toml | 20 +++++ .../pcg/parallel_computation_graph.dtg.h | 15 ++-- .../parallel_computation_graph.struct.toml | 4 +- lib/pcg/src/file_format/v1/graphs.cc | 87 +++++++++++++++++-- lib/pcg/src/pcg/computation_graph.cc | 20 +++++ lib/pcg/src/pcg/computation_graph.dtg.cc | 8 +- lib/pcg/src/pcg/computation_graph_builder.cc | 86 +++++++++++------- .../v1/graphs/v1_operator_graph.dtg.cc | 52 +++++++++++ lib/pcg/src/pcg/operator_graph.cc | 39 --------- .../src/pcg/operator_graph/operator_graph.cc | 44 ++++++++++ .../operator_graph/operator_graph_input.cc | 13 +++ .../operator_graph_input.dtg.cc | 63 ++++++++++++++ .../operator_graph/operator_graph_output.cc | 13 +++ .../operator_graph_output.dtg.cc | 63 ++++++++++++++ .../src/pcg/parallel_computation_graph.dtg.cc | 9 +- .../src/test_computation_graph_builder.cc | 19 +++- lib/utils/include/utils/integer_conversions.h | 12 +++ lib/utils/include/utils/optional.h | 4 +- lib/utils/src/utils/integer_conversions.cc | 11 +++ 37 files changed, 849 insertions(+), 141 deletions(-) create mode 100644 lib/pcg/include/pcg/dataflow_graph.h create mode 100644 lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.dtg.h create mode 100644 lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.struct.toml rename lib/pcg/include/pcg/{ => operator_graph}/operator_graph.h (72%) create mode 100644 lib/pcg/include/pcg/operator_graph/operator_graph_input.dtg.h create mode 100644 lib/pcg/include/pcg/operator_graph/operator_graph_input.h create mode 100644 lib/pcg/include/pcg/operator_graph/operator_graph_input.struct.toml create mode 100644 lib/pcg/include/pcg/operator_graph/operator_graph_output.dtg.h create mode 100644 lib/pcg/include/pcg/operator_graph/operator_graph_output.h create mode 100644 lib/pcg/include/pcg/operator_graph/operator_graph_output.struct.toml create mode 100644 lib/pcg/src/pcg/computation_graph.cc create mode 100644 lib/pcg/src/pcg/file_format/v1/graphs/v1_operator_graph.dtg.cc delete mode 100644 lib/pcg/src/pcg/operator_graph.cc create mode 100644 lib/pcg/src/pcg/operator_graph/operator_graph.cc create mode 100644 lib/pcg/src/pcg/operator_graph/operator_graph_input.cc create mode 100644 lib/pcg/src/pcg/operator_graph/operator_graph_input.dtg.cc create mode 100644 lib/pcg/src/pcg/operator_graph/operator_graph_output.cc create mode 100644 lib/pcg/src/pcg/operator_graph/operator_graph_output.dtg.cc create mode 100644 lib/utils/include/utils/integer_conversions.h create mode 100644 lib/utils/src/utils/integer_conversions.cc diff --git a/.proj.toml b/.proj.toml index f2e7e20f49..43f1522186 100644 --- a/.proj.toml +++ b/.proj.toml @@ -14,6 +14,7 @@ build_targets = [ test_targets = [ "utils-tests", "op-attrs-tests", + "pcg-tests", # "substitutions-tests", # "compiler-tests", ] diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc index e07e398fa6..8cf2afe125 100644 --- a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc @@ -1,14 +1,10 @@ #include "op-attrs/ops/conv_2d.h" #include "op-attrs/ops/conv_2d/conv_2d_input_shape.h" #include "op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.h" +#include "utils/integer_conversions.h" namespace FlexFlow { -static size_t as_size_t(int x) { - assert (x >= 0); - return static_cast(x); -} - TensorShape get_kernel_shape(Conv2DAttrs const &attrs, TensorShape const &raw_input_shape) { assert (attrs.groups == 1); // TODO(@lockshaw): currently not supported Conv2DInputShape input = parse_input_shape(raw_input_shape); @@ -16,10 +12,10 @@ TensorShape get_kernel_shape(Conv2DAttrs const &attrs, TensorShape const &raw_in return TensorShape{ TensorDims{ FFOrdered{ - as_size_t(attrs.out_channels), + size_t_from_int(attrs.out_channels), input.num_channels, - as_size_t(attrs.kernel_h), - as_size_t(attrs.kernel_w), + size_t_from_int(attrs.kernel_h), + size_t_from_int(attrs.kernel_w), } }, input.datatype, @@ -33,7 +29,7 @@ TensorShape get_bias_shape(Conv2DAttrs const &attrs, TensorShape const &raw_inpu return TensorShape{ TensorDims{ FFOrdered{ - as_size_t(attrs.out_channels) + size_t_from_int(attrs.out_channels) }, }, input.datatype, @@ -53,7 +49,7 @@ TensorShape get_output_shape(Conv2DAttrs const &attrs, TensorShape const &raw_in TensorDims{ FFOrdered{ input.num_samples, - as_size_t(attrs.out_channels), + size_t_from_int(attrs.out_channels), out_height, out_width, } @@ -66,10 +62,10 @@ ParallelTensorShape get_kernel_shape(Conv2DAttrs const &attrs, ParallelTensorSha assert (attrs.groups == 1); // TODO(@lockshaw): currently not supported Conv2DParallelInputShape input = parse_parallel_input_shape(raw_input_shape); - ShardParallelDim output_channels_dim = {as_size_t(attrs.out_channels), input.discard_copy_reduction_degree}; - ShardParallelDim input_channels_dim = {as_size_t(input.channel_dim.size), input.channel_dim.degree}; - ShardParallelDim kernel_height_dim = {as_size_t(attrs.kernel_h), 1}; - ShardParallelDim kernel_width_dim = {as_size_t(attrs.kernel_w), 1}; + ShardParallelDim output_channels_dim = {size_t_from_int(attrs.out_channels), input.discard_copy_reduction_degree}; + ShardParallelDim input_channels_dim = {size_t_from_int(input.channel_dim.size), input.channel_dim.degree}; + ShardParallelDim kernel_height_dim = {size_t_from_int(attrs.kernel_h), 1}; + ShardParallelDim kernel_width_dim = {size_t_from_int(attrs.kernel_w), 1}; int sum_degree = 1; int discard_copy_degree = input.height_dim.degree * input.width_dim.degree * input.sum_reduction_degree; @@ -99,7 +95,7 @@ ParallelTensorShape get_bias_shape(Conv2DAttrs const &attrs, ParallelTensorShape assert (attrs.groups == 1); // TODO(@lockshaw): currently not supported Conv2DParallelInputShape input = parse_parallel_input_shape(raw_input_shape); - ShardParallelDim output_channels_dim = {as_size_t(attrs.out_channels), input.discard_copy_reduction_degree}; + ShardParallelDim output_channels_dim = {size_t_from_int(attrs.out_channels), input.discard_copy_reduction_degree}; int sum_degree = 1; int discard_copy_degree = input.height_dim.degree * input.width_dim.degree * input.sum_reduction_degree * input.channel_dim.degree; diff --git a/lib/op-attrs/src/op-attrs/ops/element_binary.cc b/lib/op-attrs/src/op-attrs/ops/element_binary.cc index b713c6753f..0e0fce6f1d 100644 --- a/lib/op-attrs/src/op-attrs/ops/element_binary.cc +++ b/lib/op-attrs/src/op-attrs/ops/element_binary.cc @@ -1,3 +1,13 @@ #include "op-attrs/ops/element_binary.h" -namespace FlexFlow {} // namespace FlexFlow +namespace FlexFlow { + +ParallelTensorShape get_output_shape(ElementBinaryAttrs const &, ParallelTensorShape const &, ParallelTensorShape const &) { + NOT_IMPLEMENTED(); +} + +TensorShape get_output_shape(ElementBinaryAttrs const &, TensorShape const &, TensorShape const &) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/element_unary.cc b/lib/op-attrs/src/op-attrs/ops/element_unary.cc index 481151fafb..ab3ee1ba58 100644 --- a/lib/op-attrs/src/op-attrs/ops/element_unary.cc +++ b/lib/op-attrs/src/op-attrs/ops/element_unary.cc @@ -1,3 +1,21 @@ #include "op-attrs/ops/element_unary.h" -namespace FlexFlow {} // namespace FlexFlow +namespace FlexFlow { + +ParallelTensorShape get_output_shape(ElementUnaryAttrs const &, ParallelTensorShape const &) { + NOT_IMPLEMENTED(); +} + +TensorShape get_output_shape(ElementUnaryAttrs const &, TensorShape const &) { + NOT_IMPLEMENTED(); +} + +ParallelTensorShape get_output_shape(ElementScalarUnaryAttrs const &, ParallelTensorShape const &) { + NOT_IMPLEMENTED(); +} + +TensorShape get_output_shape(ElementScalarUnaryAttrs const &, TensorShape const &) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/pcg/CMakeLists.txt b/lib/pcg/CMakeLists.txt index 81009b0f1f..e1875ca694 100644 --- a/lib/pcg/CMakeLists.txt +++ b/lib/pcg/CMakeLists.txt @@ -13,3 +13,4 @@ ff_add_library( ) add_subdirectory(ffi) +add_subdirectory(test) diff --git a/lib/pcg/include/pcg/computation_graph.dtg.h b/lib/pcg/include/pcg/computation_graph.dtg.h index d0cdefcb7a..217b940ce6 100644 --- a/lib/pcg/include/pcg/computation_graph.dtg.h +++ b/lib/pcg/include/pcg/computation_graph.dtg.h @@ -3,27 +3,25 @@ // lib/pcg/include/pcg/computation_graph.struct.toml /* proj-data { - "generated_from": "7d22a6bc44163f331bc33002714721cf" + "generated_from": "8f1f0e13d75065944f7fe307e12fe280" } */ #ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_DTG_H #define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_DTG_H +#include "pcg/dataflow_graph.h" #include "pcg/layer_attrs.dtg.h" -#include "pcg/operator_graph.h" #include "pcg/tensor_attrs.dtg.h" namespace FlexFlow { struct ComputationGraph { ComputationGraph() = delete; ComputationGraph( - ::FlexFlow::LabelledOperatorGraph<::FlexFlow::LayerAttrs, - ::FlexFlow::TensorAttrs> const - &raw_graph); + ::FlexFlow::DataflowGraph<::FlexFlow::LayerAttrs, + ::FlexFlow::TensorAttrs> const &raw_graph); - ::FlexFlow::LabelledOperatorGraph<::FlexFlow::LayerAttrs, - ::FlexFlow::TensorAttrs> + ::FlexFlow::DataflowGraph<::FlexFlow::LayerAttrs, ::FlexFlow::TensorAttrs> raw_graph; }; } // namespace FlexFlow diff --git a/lib/pcg/include/pcg/computation_graph.h b/lib/pcg/include/pcg/computation_graph.h index d68574ad71..7702d1a7f2 100644 --- a/lib/pcg/include/pcg/computation_graph.h +++ b/lib/pcg/include/pcg/computation_graph.h @@ -9,6 +9,10 @@ namespace FlexFlow { +ComputationGraph make_empty_computation_graph(); + +std::unordered_set get_layers(ComputationGraph const &); + LayerAddedResult add_layer(ComputationGraph &computation_graph, LayerAttrs const &attrs, std::vector const &inputs, std::vector const &outputs); TensorAttrs get_tensor_attrs(ComputationGraph const &, tensor_guid_t const &); diff --git a/lib/pcg/include/pcg/computation_graph.struct.toml b/lib/pcg/include/pcg/computation_graph.struct.toml index b8b3d8c372..a270cb8fbe 100644 --- a/lib/pcg/include/pcg/computation_graph.struct.toml +++ b/lib/pcg/include/pcg/computation_graph.struct.toml @@ -5,9 +5,9 @@ features = [ ] includes = [ "pcg/layer_attrs.dtg.h", "pcg/tensor_attrs.dtg.h", - "pcg/operator_graph.h", + "pcg/dataflow_graph.h", ] [[fields]] name = "raw_graph" -type = "::FlexFlow::LabelledOperatorGraph<::FlexFlow::LayerAttrs, ::FlexFlow::TensorAttrs>" +type = "::FlexFlow::DataflowGraph<::FlexFlow::LayerAttrs, ::FlexFlow::TensorAttrs>" diff --git a/lib/pcg/include/pcg/computation_graph_builder.h b/lib/pcg/include/pcg/computation_graph_builder.h index d2dea1cbaf..d7cf1e7a18 100644 --- a/lib/pcg/include/pcg/computation_graph_builder.h +++ b/lib/pcg/include/pcg/computation_graph_builder.h @@ -222,13 +222,11 @@ struct ComputationGraphBuilder { std::vector get_outputs(LayerAttrs const &) const; tensor_guid_t get_output(LayerAttrs const &, int idx) const; - tensor_guid_t at(MultiDiEdge const &) const; - LayerAttrs at(Node const &) const; - - TensorAttrs get_attrs(tensor_guid_t const &) const; +/* tensor_guid_t at(MultiDiEdge const &) const; */ +/* LayerAttrs at(Node const &) const; */ +private: TensorShape get_shape(tensor_guid_t const &) const; -private: tensor_guid_t broadcast(tensor_guid_t const &, TensorShape const &); tensor_guid_t as_type(tensor_guid_t const &, DataType, std::string const &); @@ -243,6 +241,11 @@ struct ComputationGraphBuilder { std::vector const &weights, TensorAttrs const &output); + std::vector add_layer(LayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + std::vector const &outputs); + tensor_guid_t add_layer(LayerAttrs const &layer, std::vector const &inputs, std::vector const &weights, diff --git a/lib/pcg/include/pcg/dataflow_graph.h b/lib/pcg/include/pcg/dataflow_graph.h new file mode 100644 index 0000000000..a7affaab83 --- /dev/null +++ b/lib/pcg/include/pcg/dataflow_graph.h @@ -0,0 +1,70 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_DATAFLOW_GRAPH_H + +#include "utils/containers/enumerate_vector.h" +#include "utils/graph.h" + +namespace FlexFlow { + +template +struct DataflowGraph { +public: + DataflowGraph() + : g(OutputLabelledMultiDiGraph::template create< + UnorderedOutputLabelledMultiDiGraph>()) { } + + std::vector add_operator(NodeLabel const &func, std::vector const &inputs, std::vector const &outputs) { + Node n = this->g.add_node(func); + for (auto const &[idx, input] : enumerate_vector(inputs)) { + this->g.add_edge(MultiDiEdge{input.src, input.src_idx, n, this->make_port_for_idx(idx)}); + } + + std::vector result; + for (auto const &[idx, label] : enumerate_vector(outputs)) { + MultiDiOutput output = MultiDiOutput{n, this->make_port_for_idx(idx)}; + this->g.add_output(output, label); + result.push_back(output); + } + + return result; + } + + NodePort make_port_for_idx(int idx) { + if (!this->port_mapping.contains_l(idx)) { + this->port_mapping.equate(idx, this->g.add_node_port()); + } + return this->port_mapping.at_l(idx); + } + + NodePort port_for_idx(int idx) const { + return this->port_mapping.at_l(idx); + } + + int idx_for_port(NodePort const &p) const { + return this->port_mapping.at_r(p); + } + + OutputLabelledMultiDiGraphView const &get_raw_graph() const { + return this->g; + } + + NodeLabel const &at(Node const &n) const { + return this->g.at(n); + } + + OutputLabel const &at(MultiDiOutput const &o) const { + return this->g.at(o); + } +private: + OutputLabelledMultiDiGraph g; + bidict port_mapping; +}; + +template +std::unordered_set get_nodes(DataflowGraph const &g) { + return get_nodes(g.get_raw_graph()); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.dtg.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.dtg.h new file mode 100644 index 0000000000..7e5554d44a --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.dtg.h @@ -0,0 +1,45 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.struct.toml +/* proj-data +{ + "generated_from": "5bfd7d8755cfd8cd9dbf57d5c367038e" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_OPERATOR_GRAPH_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_OPERATOR_GRAPH_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "pcg/file_format/v1/graphs/v1_graph_edge.dtg.h" +#include "utils/fmt.h" +#include +#include +#include + +namespace FlexFlow { +struct V1OperatorGraph { + V1OperatorGraph() = delete; + V1OperatorGraph(std::vector const &nodes, + std::unordered_set<::FlexFlow::V1GraphEdge> const &edges); + + std::vector nodes; + std::unordered_set<::FlexFlow::V1GraphEdge> edges; +}; +} // namespace FlexFlow + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::V1OperatorGraph from_json(json const &); + static void to_json(json &, FlexFlow::V1OperatorGraph const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(V1OperatorGraph const &); +std::ostream &operator<<(std::ostream &, V1OperatorGraph const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_OPERATOR_GRAPH_DTG_H diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.struct.toml new file mode 100644 index 0000000000..61dc45ae2e --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.struct.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "V1OperatorGraph" +features = [ + # "eq", + # "ord", + # "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "", + "", + "pcg/file_format/v1/graphs/v1_graph_edge.dtg.h", + "utils/fmt.h", +] + +[[fields]] +name = "nodes" +type = "std::vector" + +[[fields]] +name = "edges" +type = "std::unordered_set<::FlexFlow::V1GraphEdge>" diff --git a/lib/pcg/include/pcg/operator_graph.h b/lib/pcg/include/pcg/operator_graph/operator_graph.h similarity index 72% rename from lib/pcg/include/pcg/operator_graph.h rename to lib/pcg/include/pcg/operator_graph/operator_graph.h index 2ea8feda97..2140ff1555 100644 --- a/lib/pcg/include/pcg/operator_graph.h +++ b/lib/pcg/include/pcg/operator_graph/operator_graph.h @@ -2,20 +2,14 @@ #define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_H #include "utils/graph.h" +#include "pcg/operator_graph/operator_graph_output.dtg.h" +#include "pcg/operator_graph/operator_graph_input.dtg.h" namespace FlexFlow { -struct OperatorGraphOutput { }; -struct OperatorGraphInput { }; struct OperatorGraphOutputQuery { }; struct OperatorGraphEdge { }; -Node get_node(OperatorGraphOutput const &); -int get_idx(OperatorGraphOutput const &); - -Node get_node(OperatorGraphInput const &); -int get_idx(OperatorGraphInput const &); - Node get_src_node(OperatorGraphEdge const &); Node get_dst_node(OperatorGraphEdge const &); int get_src_idx(OperatorGraphEdge const &); @@ -23,32 +17,45 @@ int get_dst_idx(OperatorGraphEdge const &); struct OperatorGraphEdgeQuery; -struct OperatorGraphView : virtual MultiDiGraphView { +struct OperatorGraphView { public: using Edge = OperatorGraphEdge; using EdgeQuery = OperatorGraphEdgeQuery; - OperatorGraphView(OperatorGraphView const &) = default; - OperatorGraphView &operator=(OperatorGraphView const &) = default; + OperatorGraphView(OperatorGraphView const &); + OperatorGraphView &operator=(OperatorGraphView const &); + + OperatorGraphView(OperatorGraphView &&); + OperatorGraphView &&operator=(OperatorGraphView &&); std::unordered_set query_nodes(NodeQuery const &) const; std::unordered_set query_outputs(OperatorGraphOutputQuery const &) const; std::unordered_set query_edges(OperatorGraphEdgeQuery const &) const; + + struct Impl; + std::unique_ptr impl; }; -CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(MultiDiGraphView); +CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(OperatorGraphView); +std::unordered_set get_outputs(OperatorGraphView const &); std::vector get_outputs(OperatorGraphView const &, Node const &); std::unordered_set get_uses(OperatorGraphView const &, OperatorGraphOutput const &); -struct OperatorGraph : virtual OperatorGraphView { +struct OperatorGraph { public: - OperatorGraph() = delete; + OperatorGraph(); OperatorGraph(OperatorGraph const &) = default; OperatorGraph &operator=(OperatorGraph const &) = default; Node add_node(std::vector const &inputs, int num_outputs); + +private: + struct Impl; + std::unique_ptr impl; }; +struct value_t; + template struct LabelledOperatorGraphView : virtual OperatorGraphView { NodeLabel const &at(Node const &) const; diff --git a/lib/pcg/include/pcg/operator_graph/operator_graph_input.dtg.h b/lib/pcg/include/pcg/operator_graph/operator_graph_input.dtg.h new file mode 100644 index 0000000000..13904f220d --- /dev/null +++ b/lib/pcg/include/pcg/operator_graph/operator_graph_input.dtg.h @@ -0,0 +1,47 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/operator_graph/operator_graph_input.struct.toml +/* proj-data +{ + "generated_from": "57d9c9afc86f43049c6f035c74477afd" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_OPERATOR_GRAPH_INPUT_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_OPERATOR_GRAPH_INPUT_DTG_H + +#include "fmt/format.h" +#include "utils/graph.h" +#include +#include +#include + +namespace FlexFlow { +struct OperatorGraphInput { + OperatorGraphInput() = delete; + OperatorGraphInput(::FlexFlow::Node const &node, int const &idx); + + bool operator==(OperatorGraphInput const &) const; + bool operator!=(OperatorGraphInput const &) const; + bool operator<(OperatorGraphInput const &) const; + bool operator>(OperatorGraphInput const &) const; + bool operator<=(OperatorGraphInput const &) const; + bool operator>=(OperatorGraphInput const &) const; + ::FlexFlow::Node node; + int idx; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::OperatorGraphInput const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(OperatorGraphInput const &); +std::ostream &operator<<(std::ostream &, OperatorGraphInput const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_OPERATOR_GRAPH_INPUT_DTG_H diff --git a/lib/pcg/include/pcg/operator_graph/operator_graph_input.h b/lib/pcg/include/pcg/operator_graph/operator_graph_input.h new file mode 100644 index 0000000000..18e7710186 --- /dev/null +++ b/lib/pcg/include/pcg/operator_graph/operator_graph_input.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_OPERATOR_GRAPH_INPUT_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_OPERATOR_GRAPH_INPUT_H + +#include "pcg/operator_graph/operator_graph_input.dtg.h" + +namespace FlexFlow { + +Node get_node(OperatorGraphInput const &); +int get_idx(OperatorGraphInput const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/operator_graph/operator_graph_input.struct.toml b/lib/pcg/include/pcg/operator_graph/operator_graph_input.struct.toml new file mode 100644 index 0000000000..a729f75bae --- /dev/null +++ b/lib/pcg/include/pcg/operator_graph/operator_graph_input.struct.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "OperatorGraphInput" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph.h" +] + +[[fields]] +name = "node" +type = "::FlexFlow::Node" + +[[fields]] +name = "idx" +type = "int" diff --git a/lib/pcg/include/pcg/operator_graph/operator_graph_output.dtg.h b/lib/pcg/include/pcg/operator_graph/operator_graph_output.dtg.h new file mode 100644 index 0000000000..40bdc245b8 --- /dev/null +++ b/lib/pcg/include/pcg/operator_graph/operator_graph_output.dtg.h @@ -0,0 +1,47 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/operator_graph/operator_graph_output.struct.toml +/* proj-data +{ + "generated_from": "3931cb388b00e0634495cdb89cb2af54" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_OPERATOR_GRAPH_OUTPUT_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_OPERATOR_GRAPH_OUTPUT_DTG_H + +#include "fmt/format.h" +#include "utils/graph.h" +#include +#include +#include + +namespace FlexFlow { +struct OperatorGraphOutput { + OperatorGraphOutput() = delete; + OperatorGraphOutput(::FlexFlow::Node const &node, int const &idx); + + bool operator==(OperatorGraphOutput const &) const; + bool operator!=(OperatorGraphOutput const &) const; + bool operator<(OperatorGraphOutput const &) const; + bool operator>(OperatorGraphOutput const &) const; + bool operator<=(OperatorGraphOutput const &) const; + bool operator>=(OperatorGraphOutput const &) const; + ::FlexFlow::Node node; + int idx; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::OperatorGraphOutput const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(OperatorGraphOutput const &); +std::ostream &operator<<(std::ostream &, OperatorGraphOutput const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_OPERATOR_GRAPH_OUTPUT_DTG_H diff --git a/lib/pcg/include/pcg/operator_graph/operator_graph_output.h b/lib/pcg/include/pcg/operator_graph/operator_graph_output.h new file mode 100644 index 0000000000..d50b74f496 --- /dev/null +++ b/lib/pcg/include/pcg/operator_graph/operator_graph_output.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_OPERATOR_GRAPH_OUTPUT_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_OPERATOR_GRAPH_OUTPUT_H + +#include "pcg/operator_graph/operator_graph_output.dtg.h" + +namespace FlexFlow { + +Node get_node(OperatorGraphOutput const &); +int get_idx(OperatorGraphOutput const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/operator_graph/operator_graph_output.struct.toml b/lib/pcg/include/pcg/operator_graph/operator_graph_output.struct.toml new file mode 100644 index 0000000000..044d4c8df3 --- /dev/null +++ b/lib/pcg/include/pcg/operator_graph/operator_graph_output.struct.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "OperatorGraphOutput" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph.h" +] + +[[fields]] +name = "node" +type = "::FlexFlow::Node" + +[[fields]] +name = "idx" +type = "int" diff --git a/lib/pcg/include/pcg/parallel_computation_graph.dtg.h b/lib/pcg/include/pcg/parallel_computation_graph.dtg.h index f08e58a8b6..01fbb7d30c 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph.dtg.h +++ b/lib/pcg/include/pcg/parallel_computation_graph.dtg.h @@ -3,26 +3,27 @@ // lib/pcg/include/pcg/parallel_computation_graph.struct.toml /* proj-data { - "generated_from": "3bb0791e3481298ddea75f4bd134f9e1" + "generated_from": "e4db0f603f7b8947dda13e01f96c40fb" } */ #ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_DTG_H #define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_DTG_H +#include "pcg/dataflow_graph.h" #include "pcg/parallel_layer_attrs.dtg.h" #include "pcg/parallel_tensor_attrs.dtg.h" -#include "utils/graph.h" namespace FlexFlow { struct ParallelComputationGraph { ParallelComputationGraph() = delete; - ParallelComputationGraph(::FlexFlow::OutputLabelledMultiDiGraph< - ::FlexFlow::ParallelLayerAttrs, - ::FlexFlow::ParallelTensorAttrs> const &raw_graph); + ParallelComputationGraph( + ::FlexFlow::DataflowGraph<::FlexFlow::ParallelLayerAttrs, + ::FlexFlow::ParallelTensorAttrs> const + &raw_graph); - ::FlexFlow::OutputLabelledMultiDiGraph<::FlexFlow::ParallelLayerAttrs, - ::FlexFlow::ParallelTensorAttrs> + ::FlexFlow::DataflowGraph<::FlexFlow::ParallelLayerAttrs, + ::FlexFlow::ParallelTensorAttrs> raw_graph; }; } // namespace FlexFlow diff --git a/lib/pcg/include/pcg/parallel_computation_graph.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph.struct.toml index 5e9eaee4ab..d4e305abe5 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph.struct.toml +++ b/lib/pcg/include/pcg/parallel_computation_graph.struct.toml @@ -3,11 +3,11 @@ name = "ParallelComputationGraph" features = [ ] includes = [ - "utils/graph.h", + "pcg/dataflow_graph.h", "pcg/parallel_tensor_attrs.dtg.h", "pcg/parallel_layer_attrs.dtg.h", ] [[fields]] name = "raw_graph" -type = "::FlexFlow::OutputLabelledMultiDiGraph<::FlexFlow::ParallelLayerAttrs, ::FlexFlow::ParallelTensorAttrs>" +type = "::FlexFlow::DataflowGraph<::FlexFlow::ParallelLayerAttrs, ::FlexFlow::ParallelTensorAttrs>" diff --git a/lib/pcg/src/file_format/v1/graphs.cc b/lib/pcg/src/file_format/v1/graphs.cc index 69fbb4e88e..0a11842709 100644 --- a/lib/pcg/src/file_format/v1/graphs.cc +++ b/lib/pcg/src/file_format/v1/graphs.cc @@ -1,8 +1,30 @@ #include "pcg/file_format/v1/graphs.h" #include "utils/graph/algorithms.h" +#include "pcg/file_format/v1/graphs/v1_operator_graph.dtg.h" +#include "utils/integer_conversions.h" +#include "pcg/file_format/v1/graphs/v1_multidigraph.h" +#include "pcg/dataflow_graph.h" namespace FlexFlow { +/* static V1OperatorGraph to_v1(OperatorGraphView const &g, bidict const &nodes) { */ +/* std::unordered_set edges; */ +/* for (MultiDiEdge const &e : get_edges(g)) { */ +/* size_t src_node = nodes.at_l(get_src_node(e)); */ +/* size_t dst_node = nodes.at_l(get_dst_node(e)); */ +/* size_t src_idx = size_t_from_int(get_src_idx(e)); */ +/* size_t dst_idx = size_t_from_int(get_dst_idx(e)); */ +/* V1GraphEdge v1_e = {src_node, src_idx, dst_node, dst_idx}; */ +/* edges.insert(v1_e); */ +/* } */ + +/* return V1OperatorGraph{ */ +/* count(nodes.size()), */ +/* edges, */ +/* }; */ +/* } */ + + static V1MultiDiGraph to_v1(MultiDiGraphView const &g, bidict const &nodes, bidict const &node_ports) { @@ -21,15 +43,70 @@ static V1MultiDiGraph to_v1(MultiDiGraphView const &g, }; } -static V1MultiDiGraph to_v1(MultiDiGraphView const &g) { - return to_v1(g, - enumerate(get_nodes(g)).reversed(), - enumerate(get_present_node_ports(g)).reversed()); +/* static V1MultiDiGraph to_v1(MultiDiGraphView const &g) { */ +/* return to_v1(g, */ +/* enumerate(get_nodes(g)).reversed(), */ +/* enumerate(get_present_node_ports(g)).reversed()); */ +/* } */ + +/* template */ +/* static V1JsonableGraph */ +/* to_v1(LabelledOperatorGraphView const &g) { */ + +/* bidict nodes = enumerate(get_nodes(g)); */ + +/* V1OperatorGraph unlabelled = to_v1(g, nodes.reversed()); */ +/* std::unordered_map node_labels = */ +/* map_values(nodes, [&](Node const &n) { return g.at(n); }); */ + +/* bidict outputs_bidict = enumerate(get_outputs(g)); */ +/* std::unordered_map outputs = */ +/* map_values(outputs_bidict, [&](OperatorGraphOutput const &o) { */ +/* return V1GraphOutput{nodes.at_r(get_node(o)), size_t_from_int(get_idx(o))}; */ +/* }); */ + +/* std::unordered_map output_labels = map_values( */ +/* outputs_bidict, [&](OperatorGraphOutput const &o) { return g.at(o); }); */ + +/* return {node_labels, outputs, output_labels, unlabelled}; */ +/* } */ + +template +static bidict get_ports_by_idx(DataflowGraph const &g) { + bidict result; + for (NodePort const &p : get_present_node_ports(g.get_raw_graph())) { + size_t idx = size_t_from_int(g.idx_for_port(p)); + result.equate(idx, p); + } + return result; +} + +template +static V1JsonableGraph + to_v1(DataflowGraph const &g) { + + bidict nodes = enumerate(get_nodes(g.get_raw_graph())); + bidict node_ports = get_ports_by_idx(g); + + V1MultiDiGraph unlabelled = to_v1(g.get_raw_graph(), nodes.reversed(), node_ports.reversed()); + std::unordered_map node_labels = + map_values(nodes, [&](Node const &n) { return g.at(n); }); + + bidict outputs_bidict = enumerate(get_outputs(g.get_raw_graph())); + std::unordered_map outputs = + map_values(outputs_bidict, [&](MultiDiOutput const &o) { + return V1GraphOutput{nodes.at_r(o.src), node_ports.at_r(o.src_idx)}; + }); + + std::unordered_map output_labels = map_values( + outputs_bidict, [&](MultiDiOutput const &o) { return g.at(o); }); + + return {node_labels, outputs, output_labels, unlabelled}; } template static V1JsonableGraph - to_v1(OutputLabelledMultiDiGraph const &g) { + to_v1(OutputLabelledMultiDiGraphView const &g) { bidict nodes = enumerate(get_nodes(g)); bidict node_ports = enumerate(get_present_node_ports(g)); diff --git a/lib/pcg/src/pcg/computation_graph.cc b/lib/pcg/src/pcg/computation_graph.cc new file mode 100644 index 0000000000..3c21f32697 --- /dev/null +++ b/lib/pcg/src/pcg/computation_graph.cc @@ -0,0 +1,20 @@ +#include "pcg/computation_graph.h" +#include "utils/containers.h" + +namespace FlexFlow { + +ComputationGraph make_empty_computation_graph() { + return ComputationGraph{ + DataflowGraph{} + }; +} + +std::unordered_set get_layers(ComputationGraph const &cg) { + return transform(get_nodes(cg.raw_graph), [&](Node const &n) { return layer_guid_t{n}; }); +} + +TensorAttrs get_tensor_attrs(ComputationGraph const &cg, tensor_guid_t const &t) { + return cg.raw_graph.at(t.raw_graph_output); +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/computation_graph.dtg.cc b/lib/pcg/src/pcg/computation_graph.dtg.cc index 5b376cad77..bb6233a910 100644 --- a/lib/pcg/src/pcg/computation_graph.dtg.cc +++ b/lib/pcg/src/pcg/computation_graph.dtg.cc @@ -3,19 +3,19 @@ // lib/pcg/include/pcg/computation_graph.struct.toml /* proj-data { - "generated_from": "7d22a6bc44163f331bc33002714721cf" + "generated_from": "8f1f0e13d75065944f7fe307e12fe280" } */ #include "pcg/computation_graph.dtg.h" +#include "pcg/dataflow_graph.h" #include "pcg/layer_attrs.dtg.h" -#include "pcg/operator_graph.h" #include "pcg/tensor_attrs.dtg.h" namespace FlexFlow { ComputationGraph::ComputationGraph( - ::FlexFlow::LabelledOperatorGraph<::FlexFlow::LayerAttrs, - ::FlexFlow::TensorAttrs> const &raw_graph) + ::FlexFlow::DataflowGraph<::FlexFlow::LayerAttrs, + ::FlexFlow::TensorAttrs> const &raw_graph) : raw_graph(raw_graph) {} } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index f6d14c64ab..bd381c6047 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -14,11 +14,28 @@ namespace FlexFlow { +ComputationGraphBuilder::ComputationGraphBuilder() + : computation_graph(make_empty_computation_graph()) { } + +TensorShape ComputationGraphBuilder::get_shape(tensor_guid_t const &t) const { + return get_tensor_attrs(this->computation_graph, t).shape; +} + +tensor_guid_t ComputationGraphBuilder::create_tensor(TensorShape const &shape, bool create_grad) { + TensorAttrs tensor_attrs = {shape, std::nullopt, create_grad, std::nullopt}; + LayerAttrs layer_attrs = LayerAttrs{ + ComputationGraphOpAttrs{InputAttrs{}}, + std::nullopt, + }; + + return this->add_layer(layer_attrs, {}, {}, tensor_attrs); +} + std::vector ComputationGraphBuilder::add_layer(LayerAttrs const &layer, std::vector const &inputs, std::vector const &weights, std::vector const &outputs) { - std::vector weight_tensors; + std::vector raw_weight_tensors; for (auto const &kv : enumerate_vector(weights)) { int weight_idx = kv.first; TensorAttrs weight_tensor_attrs = kv.second; @@ -28,14 +45,16 @@ std::vector ComputationGraphBuilder::add_layer(LayerAttrs const & ComputationGraphOpAttrs{WeightAttrs{}}, weight_name, }; - std::vector weight_layer_inputs = {}; - std::vector weight_layer_outputs = {weight_tensor_attrs}; - LayerAddedResult added_weight = ::FlexFlow::add_layer(this->computation_graph, weight_layer_attrs, weight_layer_inputs, weight_layer_outputs); - weight_tensors.push_back(get_only(added_weight.outputs)); + std::vector weight_layer_inputs = {}; + std::vector weight_output_attrs = {weight_tensor_attrs}; + raw_weight_tensors.push_back( + get_only(this->computation_graph.raw_graph.add_operator(weight_layer_attrs, weight_layer_inputs, weight_output_attrs)) + ); } - LayerAddedResult added = ::FlexFlow::add_layer(this->computation_graph, layer, concat_vectors(inputs, weight_tensors), outputs); - return added.outputs; + std::vector raw_inputs = transform(inputs, [](tensor_guid_t const &t) { return t.raw_graph_output; }); + std::vector raw_outputs = this->computation_graph.raw_graph.add_operator(layer, concat_vectors(raw_inputs, raw_weight_tensors), outputs); + return transform(raw_outputs, [](MultiDiOutput const &o) { return tensor_guid_t{o}; }); } tensor_guid_t ComputationGraphBuilder::add_layer(LayerAttrs const &layer, @@ -46,21 +65,33 @@ tensor_guid_t ComputationGraphBuilder::add_layer(LayerAttrs const &layer, return get_only(this->add_layer(layer, inputs, weights, outputs)); } -static tensor_guid_t make_weight_tensor(ComputationGraphBuilder &cgb, TensorShape const &shape, std::optional const &initializer_attrs = std::nullopt) { - NOT_IMPLEMENTED(); +std::vector ComputationGraphBuilder::add_layer(LayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + std::vector const &outputs) { + return this->add_layer(layer, + inputs, + weights, + transform(outputs, [](TensorShape const &s) { return TensorAttrs{s, std::nullopt, true, std::nullopt}; })); } -static tensor_guid_t make_output_tensor(ComputationGraphBuilder &cgb, TensorShape const &shape, std::optional const &initializer_attrs = std::nullopt) { - NOT_IMPLEMENTED(); +tensor_guid_t ComputationGraphBuilder::add_layer(LayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + TensorShape const &output) { + return get_only(this->add_layer(layer, + inputs, + weights, + std::vector{output})); } -static tensor_guid_t cast_to(ComputationGraphBuilder &cgb, +tensor_guid_t ComputationGraphBuilder::as_type( tensor_guid_t const &x, DataType data_type, std::string const &name) { - DataType x_datatype = cgb.get_shape(x).data_type; + DataType x_datatype = this->get_shape(x).data_type; if (x_datatype < data_type) { - return cgb.cast(x, data_type, name); + return this->cast(x, data_type, name); } else if (x_datatype > data_type) { throw mk_runtime_error(fmt::format("Could not convert provided tensor data type {} to " "desired data type {}", @@ -89,16 +120,11 @@ static std::string get_default_name(ComputationGraphOpAttrs const &attrs) { return get_default_name(get_op_type(attrs)); } -template -static std::string get_default_name(T const &t) { - return get_default_name(t); -} - tensor_guid_t ComputationGraphBuilder::element_unary( ElementUnaryAttrs const &attrs, tensor_guid_t const &x, std::optional const &maybe_name) { - std::string name = maybe_name.value_or(get_default_name(attrs)); + std::string name = maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); tensor_guid_t input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); @@ -121,7 +147,7 @@ tensor_guid_t ComputationGraphBuilder::element_scalar_unary( ElementScalarUnaryAttrs const &attrs, tensor_guid_t const &x, std::optional const &maybe_name) { - std::string name = maybe_name.value_or(get_default_name(attrs)); + std::string name = maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); tensor_guid_t input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); @@ -307,7 +333,9 @@ tensor_guid_t ComputationGraphBuilder::elu(tensor_guid_t const &input, return this->element_unary(OperatorType::ELU, input, name); } -static TensorAttrs make_weight_attrs(TensorShape const &, std::optional const &) { NOT_IMPLEMENTED(); } +static TensorAttrs make_weight_attrs(TensorShape const &shape, std::optional const &initializer_attrs) { + return TensorAttrs{shape, initializer_attrs, true, std::nullopt}; +} tensor_guid_t ComputationGraphBuilder::conv2d( tensor_guid_t const &x, @@ -336,7 +364,7 @@ tensor_guid_t ComputationGraphBuilder::conv2d( activation, use_bias}; - std::string name = maybe_name.value_or(get_default_name(attrs)); + std::string name = maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); tensor_guid_t input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); @@ -365,7 +393,7 @@ tensor_guid_t ComputationGraphBuilder::dropout( unsigned long long seed, std::optional const &maybe_name) { DropoutAttrs attrs = {rate, seed}; - std::string name = maybe_name.value_or(get_default_name(attrs)); + std::string name = maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; tensor_guid_t input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); @@ -385,7 +413,7 @@ tensor_guid_t ComputationGraphBuilder::embedding( std::optional const &maybe_name) { EmbeddingAttrs attrs = {num_entries, outDim, aggr, dtype}; - std::string name = maybe_name.value_or(get_default_name(attrs)); + std::string name = maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; tensor_guid_t input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); @@ -409,7 +437,7 @@ std::vector ComputationGraphBuilder::gather( ff_dim_t dim, std::optional const &maybe_name) { GatherAttrs attrs = {dim}; - std::string name = maybe_name.value_or(get_default_name(attrs)); + std::string name = maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; if (this->get_shape(index).data_type != DataType::INT32 && @@ -426,10 +454,6 @@ std::vector ComputationGraphBuilder::gather( return this->add_layer(layer, {input}, {}, output_shapes); } -TensorShape ComputationGraphBuilder::get_shape(tensor_guid_t const &t) const { - return this->get_attrs(t).shape; -} - /* std::vector ComputationGraphBuilder::get_shapes(std::vector const &ts) const { */ /* return transform(ts, [&](tensor_guid_t const &t) { return this->get_shape(t); }); */ /* } */ @@ -465,7 +489,7 @@ tensor_guid_t ComputationGraphBuilder::batch_norm( bool relu, std::optional const &maybe_name) { BatchNormAttrs attrs = BatchNormAttrs{relu}; - std::string name = maybe_name.value_or(get_default_name(attrs)); + std::string name = maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_operator_graph.dtg.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_operator_graph.dtg.cc new file mode 100644 index 0000000000..19f1e09d07 --- /dev/null +++ b/lib/pcg/src/pcg/file_format/v1/graphs/v1_operator_graph.dtg.cc @@ -0,0 +1,52 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.struct.toml +/* proj-data +{ + "generated_from": "5bfd7d8755cfd8cd9dbf57d5c367038e" +} +*/ + +#include "pcg/file_format/v1/graphs/v1_operator_graph.dtg.h" + +#include "pcg/file_format/v1/graphs/v1_graph_edge.dtg.h" +#include "utils/fmt.h" +#include +#include +#include + +namespace FlexFlow { +V1OperatorGraph::V1OperatorGraph( + std::vector const &nodes, + std::unordered_set<::FlexFlow::V1GraphEdge> const &edges) + : nodes(nodes), edges(edges) {} +} // namespace FlexFlow + +namespace nlohmann { +FlexFlow::V1OperatorGraph + adl_serializer::from_json(json const &j) { + return {j.at("nodes").template get>(), + j.at("edges") + .template get>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::V1OperatorGraph const &v) { + j["__type"] = "V1OperatorGraph"; + j["nodes"] = v.nodes; + j["edges"] = v.edges; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(V1OperatorGraph const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, V1OperatorGraph const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/operator_graph.cc b/lib/pcg/src/pcg/operator_graph.cc deleted file mode 100644 index 70066e6b5b..0000000000 --- a/lib/pcg/src/pcg/operator_graph.cc +++ /dev/null @@ -1,39 +0,0 @@ -#include "pcg/operator_graph.h" - -namespace FlexFlow { - -Node get_node(OperatorGraphOutput const &) { - NOT_IMPLEMENTED(); -} - -int get_idx(OperatorGraphOutput const &) { - NOT_IMPLEMENTED(); -} - -Node get_node(OperatorGraphInput const &) { - NOT_IMPLEMENTED(); -} - -int get_idx(OperatorGraphInput const &) { - NOT_IMPLEMENTED(); -} - -Node get_src_node(OperatorGraphEdge const &) { - NOT_IMPLEMENTED(); -} - -Node get_dst_node(OperatorGraphEdge const &) { - NOT_IMPLEMENTED(); -} - -int get_src_idx(OperatorGraphEdge const &) { - NOT_IMPLEMENTED(); -} - -int get_dst_idx(OperatorGraphEdge const &) { - NOT_IMPLEMENTED(); -} - -/* OperatorGraphView::query_nodes */ - -} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/operator_graph/operator_graph.cc b/lib/pcg/src/pcg/operator_graph/operator_graph.cc new file mode 100644 index 0000000000..afb3f99e55 --- /dev/null +++ b/lib/pcg/src/pcg/operator_graph/operator_graph.cc @@ -0,0 +1,44 @@ +#include "pcg/operator_graph/operator_graph.h" +#include "utils/graph.h" + +namespace FlexFlow { + +/* struct OperatorGraphView::Impl { */ +/* MultiDiGraphView raw_graph; */ +/* }; */ + +/* struct OperatorGraph::Impl { */ +/* MultiDiGraph raw_graph; */ +/* }; */ + +/* std::unordered_set get_outputs(OperatorGraphView const &g) { */ +/* return transform(get_outputs(g.impl->raw_graph), [](MultiDiOutput const &o) {}); */ +/* } */ + +/* std::vector get_outputs(OperatorGraphView const &, Node const &) { */ +/* NOT_IMPLEMENTED(); */ +/* } */ + +/* std::unordered_set get_uses(OperatorGraphView const &, OperatorGraphOutput const &) { */ +/* NOT_IMPLEMENTED(); */ +/* } */ + +/* Node get_src_node(OperatorGraphEdge const &) { */ +/* NOT_IMPLEMENTED(); */ +/* } */ + +/* Node get_dst_node(OperatorGraphEdge const &) { */ +/* NOT_IMPLEMENTED(); */ +/* } */ + +/* int get_src_idx(OperatorGraphEdge const &) { */ +/* NOT_IMPLEMENTED(); */ +/* } */ + +/* int get_dst_idx(OperatorGraphEdge const &) { */ +/* NOT_IMPLEMENTED(); */ +/* } */ + +/* OperatorGraphView::query_nodes */ + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/operator_graph/operator_graph_input.cc b/lib/pcg/src/pcg/operator_graph/operator_graph_input.cc new file mode 100644 index 0000000000..945034dd73 --- /dev/null +++ b/lib/pcg/src/pcg/operator_graph/operator_graph_input.cc @@ -0,0 +1,13 @@ +#include "pcg/operator_graph/operator_graph_input.h" + +namespace FlexFlow { + +Node get_node(OperatorGraphInput const &i) { + return i.node; +} + +int get_idx(OperatorGraphInput const &i) { + return i.idx; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/operator_graph/operator_graph_input.dtg.cc b/lib/pcg/src/pcg/operator_graph/operator_graph_input.dtg.cc new file mode 100644 index 0000000000..381c948ad0 --- /dev/null +++ b/lib/pcg/src/pcg/operator_graph/operator_graph_input.dtg.cc @@ -0,0 +1,63 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/operator_graph/operator_graph_input.struct.toml +/* proj-data +{ + "generated_from": "57d9c9afc86f43049c6f035c74477afd" +} +*/ + +#include "pcg/operator_graph/operator_graph_input.dtg.h" + +#include "utils/graph.h" +#include + +namespace FlexFlow { +OperatorGraphInput::OperatorGraphInput(::FlexFlow::Node const &node, + int const &idx) + : node(node), idx(idx) {} +bool OperatorGraphInput::operator==(OperatorGraphInput const &other) const { + return std::tie(this->node, this->idx) == std::tie(other.node, other.idx); +} +bool OperatorGraphInput::operator!=(OperatorGraphInput const &other) const { + return std::tie(this->node, this->idx) != std::tie(other.node, other.idx); +} +bool OperatorGraphInput::operator<(OperatorGraphInput const &other) const { + return std::tie(this->node, this->idx) < std::tie(other.node, other.idx); +} +bool OperatorGraphInput::operator>(OperatorGraphInput const &other) const { + return std::tie(this->node, this->idx) > std::tie(other.node, other.idx); +} +bool OperatorGraphInput::operator<=(OperatorGraphInput const &other) const { + return std::tie(this->node, this->idx) <= std::tie(other.node, other.idx); +} +bool OperatorGraphInput::operator>=(OperatorGraphInput const &other) const { + return std::tie(this->node, this->idx) >= std::tie(other.node, other.idx); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::OperatorGraphInput const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::Node>{}(x.node) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= + std::hash{}(x.idx) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(OperatorGraphInput const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, OperatorGraphInput const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/operator_graph/operator_graph_output.cc b/lib/pcg/src/pcg/operator_graph/operator_graph_output.cc new file mode 100644 index 0000000000..bdfe1a9795 --- /dev/null +++ b/lib/pcg/src/pcg/operator_graph/operator_graph_output.cc @@ -0,0 +1,13 @@ +#include "pcg/operator_graph/operator_graph_output.h" + +namespace FlexFlow { + +Node get_node(OperatorGraphOutput const &o) { + return o.node; +} + +int get_idx(OperatorGraphOutput const &o) { + return o.idx; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/operator_graph/operator_graph_output.dtg.cc b/lib/pcg/src/pcg/operator_graph/operator_graph_output.dtg.cc new file mode 100644 index 0000000000..88c23c0c67 --- /dev/null +++ b/lib/pcg/src/pcg/operator_graph/operator_graph_output.dtg.cc @@ -0,0 +1,63 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/operator_graph/operator_graph_output.struct.toml +/* proj-data +{ + "generated_from": "3931cb388b00e0634495cdb89cb2af54" +} +*/ + +#include "pcg/operator_graph/operator_graph_output.dtg.h" + +#include "utils/graph.h" +#include + +namespace FlexFlow { +OperatorGraphOutput::OperatorGraphOutput(::FlexFlow::Node const &node, + int const &idx) + : node(node), idx(idx) {} +bool OperatorGraphOutput::operator==(OperatorGraphOutput const &other) const { + return std::tie(this->node, this->idx) == std::tie(other.node, other.idx); +} +bool OperatorGraphOutput::operator!=(OperatorGraphOutput const &other) const { + return std::tie(this->node, this->idx) != std::tie(other.node, other.idx); +} +bool OperatorGraphOutput::operator<(OperatorGraphOutput const &other) const { + return std::tie(this->node, this->idx) < std::tie(other.node, other.idx); +} +bool OperatorGraphOutput::operator>(OperatorGraphOutput const &other) const { + return std::tie(this->node, this->idx) > std::tie(other.node, other.idx); +} +bool OperatorGraphOutput::operator<=(OperatorGraphOutput const &other) const { + return std::tie(this->node, this->idx) <= std::tie(other.node, other.idx); +} +bool OperatorGraphOutput::operator>=(OperatorGraphOutput const &other) const { + return std::tie(this->node, this->idx) >= std::tie(other.node, other.idx); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::OperatorGraphOutput const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::Node>{}(x.node) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= + std::hash{}(x.idx) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(OperatorGraphOutput const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, OperatorGraphOutput const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/parallel_computation_graph.dtg.cc b/lib/pcg/src/pcg/parallel_computation_graph.dtg.cc index 18549b43a2..e4e1555b4a 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph.dtg.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph.dtg.cc @@ -3,20 +3,19 @@ // lib/pcg/include/pcg/parallel_computation_graph.struct.toml /* proj-data { - "generated_from": "3bb0791e3481298ddea75f4bd134f9e1" + "generated_from": "e4db0f603f7b8947dda13e01f96c40fb" } */ #include "pcg/parallel_computation_graph.dtg.h" +#include "pcg/dataflow_graph.h" #include "pcg/parallel_layer_attrs.dtg.h" #include "pcg/parallel_tensor_attrs.dtg.h" -#include "utils/graph.h" namespace FlexFlow { ParallelComputationGraph::ParallelComputationGraph( - ::FlexFlow::OutputLabelledMultiDiGraph< - ::FlexFlow::ParallelLayerAttrs, - ::FlexFlow::ParallelTensorAttrs> const &raw_graph) + ::FlexFlow::DataflowGraph<::FlexFlow::ParallelLayerAttrs, + ::FlexFlow::ParallelTensorAttrs> const &raw_graph) : raw_graph(raw_graph) {} } // namespace FlexFlow diff --git a/lib/pcg/test/src/test_computation_graph_builder.cc b/lib/pcg/test/src/test_computation_graph_builder.cc index 8619ce6ad7..c241d0f2c5 100644 --- a/lib/pcg/test/src/test_computation_graph_builder.cc +++ b/lib/pcg/test/src/test_computation_graph_builder.cc @@ -1,8 +1,23 @@ #include "doctest/doctest.h" -#include "op-attrs/ops/conv_2d.h" +#include "pcg/computation_graph_builder.h" +#include "pcg/computation_graph.h" TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("ComputationGraphBuilder") { - CHECK_MESSAGE(false, "TODO: ComputationGraphBuilder"); + ComputationGraphBuilder b; + + size_t batch_size = 2; + + TensorShape input_shape = { + TensorDims{ + FFOrdered{batch_size, 3, 10, 10} + }, + DataType::FLOAT, + }; + + tensor_guid_t input = b.create_tensor(input_shape, /*create_grad=*/true); + tensor_guid_t output = b.conv2d(input, /*outChannels=*/5, /*kernelH=*/3, /*kernelW=*/3, /*strideH=*/1, /*strideW=*/1, /*paddingH=*/0, /*paddingW=*/0); + // ComputationGraph cg = b.computation_graph; + // CHECK(get_layers(cg).size() == 1); } } diff --git a/lib/utils/include/utils/integer_conversions.h b/lib/utils/include/utils/integer_conversions.h new file mode 100644 index 0000000000..6d3fb4cfdf --- /dev/null +++ b/lib/utils/include/utils/integer_conversions.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_INTEGER_CONVERSIONS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_INTEGER_CONVERSIONS_H + +#include + +namespace FlexFlow { + +size_t size_t_from_int(int); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/optional.h b/lib/utils/include/utils/optional.h index d8d667d0d4..2546c302c0 100644 --- a/lib/utils/include/utils/optional.h +++ b/lib/utils/include/utils/optional.h @@ -25,8 +25,10 @@ T const &assert_unwrap(std::optional const &o) { template std::optional> transform(std::optional const &o, F &&f) { + using Return = std::invoke_result_t; if (o.has_value()) { - return std::optional{f(o)}; + Return r = f(o.value()); + return std::optional{r}; } else { return std::nullopt; } diff --git a/lib/utils/src/utils/integer_conversions.cc b/lib/utils/src/utils/integer_conversions.cc new file mode 100644 index 0000000000..7156aab896 --- /dev/null +++ b/lib/utils/src/utils/integer_conversions.cc @@ -0,0 +1,11 @@ +#include "utils/integer_conversions.h" +#include + +namespace FlexFlow { + +size_t size_t_from_int(int x) { + assert (x >= 0); + return static_cast(x); +} + +} // namespace FlexFlow From f6535850ddf0e840fe9307d76fa8e6a40d86bbe5 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Tue, 14 May 2024 16:56:57 -0700 Subject: [PATCH 16/43] Partial implementation of shape inference for linear --- lib/op-attrs/include/op-attrs/ops/linear.h | 9 ++- lib/op-attrs/include/op-attrs/tensor_dims.h | 1 + lib/op-attrs/src/op-attrs/ops/linear.cc | 61 ++++++++++++++++++- .../src/op-attrs/parallel_tensor_dims.cc | 3 + lib/op-attrs/src/op-attrs/tensor_dims.cc | 7 +++ lib/op-attrs/src/op-attrs/tensor_shape.cc | 1 + 6 files changed, 80 insertions(+), 2 deletions(-) diff --git a/lib/op-attrs/include/op-attrs/ops/linear.h b/lib/op-attrs/include/op-attrs/ops/linear.h index 2b0c5c7cda..d90d0712db 100644 --- a/lib/op-attrs/include/op-attrs/ops/linear.h +++ b/lib/op-attrs/include/op-attrs/ops/linear.h @@ -4,12 +4,19 @@ #include "op-attrs/ops/core.h" #include "op-attrs/ops/linear_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { CHECK_VALID_OP_ATTR(LinearAttrs); -ParallelTensorShape get_output_shape(LinearAttrs const &, ParallelTensorShape const &); +TensorShape get_kernel_shape(LinearAttrs const &attrs, TensorShape const &input); +TensorShape get_bias_shape(LinearAttrs const &attrs, TensorShape const &input); +TensorShape get_output_shape(LinearAttrs const &attrs, TensorShape const &input); + +ParallelTensorShape get_kernel_shape(LinearAttrs const &attrs, ParallelTensorShape const &input); +ParallelTensorShape get_bias_shape(LinearAttrs const &attrs, ParallelTensorShape const &input); +ParallelTensorShape get_output_shape(LinearAttrs const &attrs, ParallelTensorShape const &input); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/tensor_dims.h b/lib/op-attrs/include/op-attrs/tensor_dims.h index a0c37139d0..05a2c6e263 100644 --- a/lib/op-attrs/include/op-attrs/tensor_dims.h +++ b/lib/op-attrs/include/op-attrs/tensor_dims.h @@ -8,6 +8,7 @@ namespace FlexFlow { FFOrdered const &ff_ordered(TensorDims const &); +size_t num_dims(TensorDims const &); size_t dim_at_idx(TensorDims const &, ff_dim_t); ParallelTensorDims lift_to_parallel(TensorDims const &); diff --git a/lib/op-attrs/src/op-attrs/ops/linear.cc b/lib/op-attrs/src/op-attrs/ops/linear.cc index 0ecd601e00..a5fae353de 100644 --- a/lib/op-attrs/src/op-attrs/ops/linear.cc +++ b/lib/op-attrs/src/op-attrs/ops/linear.cc @@ -1,9 +1,68 @@ #include "op-attrs/ops/linear.h" +#include "op-attrs/tensor_shape.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "utils/integer_conversions.h" namespace FlexFlow { -ParallelTensorShape get_output_shape(LinearAttrs const &, ParallelTensorShape const &) { +TensorShape get_kernel_shape(LinearAttrs const &attrs, TensorShape const &input_shape) { + size_t in_channels = dim_at_idx(input_shape, ff_dim_t{-1}); + + return TensorShape{ + TensorDims{ + FFOrdered{ + in_channels, + size_t_from_int(attrs.out_channels) + }, + }, + input_shape.data_type, + }; +} + +TensorShape get_bias_shape(LinearAttrs const &attrs, TensorShape const &input_shape) { + return TensorShape{ + TensorDims{ + FFOrdered{ + size_t_from_int(attrs.out_channels) + }, + }, + input_shape.data_type, + }; +} + +TensorShape get_output_shape(LinearAttrs const &attrs, TensorShape const &input_shape) { + TensorShape output_shape = input_shape; + output_shape.dims.ff_ordered.at(ff_dim_t{-1}) = size_t_from_int(attrs.out_channels); + + return output_shape; +} + +ParallelTensorShape get_kernel_shape(LinearAttrs const &attrs, ParallelTensorShape const &input_shape) { NOT_IMPLEMENTED(); + /* ShardParallelDim input_sample_dim = shard_dim_at_idx(input_shape, ff_dim_t{-2}); */ + /* ShardParallelDim in_channels_dim = shard_dim_at_idx(input_shape, ff_dim_t{-1}); */ } +ParallelTensorShape get_output_shape(LinearAttrs const &attrs, ParallelTensorShape const &input_shape) { + ShardParallelDim input_sample_dim = shard_dim_at_idx(input_shape, ff_dim_t{-2}); + ShardParallelDim in_channels_dim = shard_dim_at_idx(input_shape, ff_dim_t{-1}); + + ShardParallelDim output_sample_dim = input_sample_dim; + ShardParallelDim output_channels_dim = { size_t_from_int(attrs.out_channels), input_shape.dims.replica_dims.discard_copy_degree }; + + int output_sum_degree = input_shape.dims.replica_dims.sum_degree * in_channels_dim.degree; + int output_discard_copy_degree = 1; + + ParallelTensorShape result = input_shape; + shard_dim_at_idx(result, ff_dim_t{-2}) = output_sample_dim; + shard_dim_at_idx(result, ff_dim_t{-1}) = output_channels_dim; + result.dims.replica_dims.sum_degree = output_sum_degree; + result.dims.replica_dims.discard_copy_degree = output_discard_copy_degree; + + assert (total_parallel_degree(result.dims) == total_parallel_degree(input_shape.dims)); + + return result; +} + + } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc index 531d571309..1c3c42173b 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc @@ -36,6 +36,9 @@ bool is_valid(ParallelTensorDims const &dims) { } ShardParallelDim shard_dim_at_idx(ParallelTensorDims const &d, ff_dim_t idx) { + if (idx.value < 0) { + idx = ff_dim_t{d.shard_dims.size() + idx.value}; + } return d.shard_dims.at(idx); } diff --git a/lib/op-attrs/src/op-attrs/tensor_dims.cc b/lib/op-attrs/src/op-attrs/tensor_dims.cc index af6f4c9e82..9f226f9101 100644 --- a/lib/op-attrs/src/op-attrs/tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/tensor_dims.cc @@ -9,7 +9,14 @@ FFOrdered const &ff_ordered(TensorDims const &dims) { return dims.ff_ordered; } +size_t num_dims(TensorDims const &dims) { + return dims.ff_ordered.size(); +} + size_t dim_at_idx(TensorDims const &dims, ff_dim_t idx) { + if (idx.value < 0) { + idx = ff_dim_t{num_dims(dims) + idx.value}; + } return dims.ff_ordered.at(idx); } diff --git a/lib/op-attrs/src/op-attrs/tensor_shape.cc b/lib/op-attrs/src/op-attrs/tensor_shape.cc index 01afbddf1e..f338b56b59 100644 --- a/lib/op-attrs/src/op-attrs/tensor_shape.cc +++ b/lib/op-attrs/src/op-attrs/tensor_shape.cc @@ -8,6 +8,7 @@ size_t num_dims(TensorShape const &s) { } size_t dim_at_idx(TensorShape const &s, ff_dim_t idx) { + if (idx.value < 0) { return dim_at_idx(s.dims, idx); } From a0c6c5228a8a6b006308379d174feeb3f048a5b2 Mon Sep 17 00:00:00 2001 From: Rae Wong <33883582+yingyee0111@users.noreply.github.com> Date: Sun, 26 May 2024 10:49:54 -0700 Subject: [PATCH 17/43] Fix rapidcheck (#8) * enable rapidchecks for op-attrs * added rc::checks * fix merge * fixed variant toml * revert proj.toml * removed additional import merged * constraint for ff_dim * lock flake --------- Co-authored-by: Rae Wong --- flake.lock | 6 +- .../op-attrs/computation_graph_op_attrs.dtg.h | 9 ++- .../computation_graph_op_attrs.variant.toml | 1 + lib/op-attrs/include/op-attrs/dim_ordered.h | 17 ++++- lib/op-attrs/include/op-attrs/ff_dim.dtg.h | 2 +- lib/op-attrs/include/op-attrs/ff_dim.h | 17 +++++ .../include/op-attrs/ff_dim.struct.toml | 1 - lib/op-attrs/include/op-attrs/ops/attention.h | 10 ++- .../op-attrs/ops/attention_inputs.dtg.h | 10 ++- .../op-attrs/ops/attention_inputs.struct.toml | 2 +- .../include/op-attrs/ops/broadcast.dtg.h | 10 ++- .../op-attrs/ops/broadcast.struct.toml | 2 +- .../include/op-attrs/ops/cast_attrs.dtg.h | 10 ++- .../op-attrs/ops/cast_attrs.struct.toml | 2 +- .../include/op-attrs/ops/combine_attrs.dtg.h | 11 +++- .../op-attrs/ops/combine_attrs.struct.toml | 3 +- .../include/op-attrs/ops/concat_attrs.dtg.h | 11 +++- .../op-attrs/ops/concat_attrs.struct.toml | 3 +- .../include/op-attrs/ops/conv_2d_attrs.dtg.h | 10 ++- .../op-attrs/ops/conv_2d_attrs.struct.toml | 2 +- .../op-attrs/ops/element_binary_attrs.dtg.h | 10 ++- .../ops/element_binary_attrs.struct.toml | 2 +- .../ops/element_scalar_unary_attrs.dtg.h | 10 ++- .../element_scalar_unary_attrs.struct.toml | 2 +- .../op-attrs/ops/element_unary_attrs.dtg.h | 10 ++- .../ops/element_unary_attrs.struct.toml | 2 +- .../op-attrs/ops/embedding_attrs.dtg.h | 10 ++- .../op-attrs/ops/embedding_attrs.struct.toml | 2 +- .../include/op-attrs/ops/gather_attrs.dtg.h | 11 +++- .../op-attrs/ops/gather_attrs.struct.toml | 3 +- .../op-attrs/ops/layer_norm_attrs.dtg.h | 11 +++- .../op-attrs/ops/layer_norm_attrs.struct.toml | 3 +- .../include/op-attrs/ops/linear_attrs.dtg.h | 10 ++- .../op-attrs/ops/linear_attrs.struct.toml | 2 +- .../ops/parallel_attention_inputs.dtg.h | 10 ++- .../ops/parallel_attention_inputs.struct.toml | 2 +- .../include/op-attrs/ops/pool_2d_attrs.dtg.h | 10 ++- .../op-attrs/ops/pool_2d_attrs.struct.toml | 2 +- .../include/op-attrs/ops/reduce_attrs.dtg.h | 11 +++- .../op-attrs/ops/reduce_attrs.struct.toml | 3 +- .../op-attrs/ops/reduction_attrs.dtg.h | 11 +++- .../op-attrs/ops/reduction_attrs.struct.toml | 3 +- .../op-attrs/ops/repartition_attrs.dtg.h | 11 +++- .../ops/repartition_attrs.struct.toml | 3 +- .../op-attrs/ops/replicate_attrs.dtg.h | 11 +++- .../op-attrs/ops/replicate_attrs.struct.toml | 3 +- .../include/op-attrs/ops/reshape_attrs.dtg.h | 10 ++- .../op-attrs/ops/reshape_attrs.struct.toml | 2 +- .../include/op-attrs/ops/reverse_attrs.dtg.h | 11 +++- .../op-attrs/ops/reverse_attrs.struct.toml | 3 +- .../include/op-attrs/ops/softmax_attrs.dtg.h | 11 +++- .../op-attrs/ops/softmax_attrs.struct.toml | 3 +- .../include/op-attrs/ops/split_attrs.dtg.h | 11 +++- .../op-attrs/ops/split_attrs.struct.toml | 3 +- .../op-attrs/ops/transpose_attrs.dtg.h | 11 +++- .../op-attrs/ops/transpose_attrs.struct.toml | 3 +- .../include/op-attrs/parallel_dim.dtg.h | 9 ++- .../op-attrs/parallel_dim.variant.toml | 1 + .../op-attrs/parallel_tensor_dims.dtg.h | 10 ++- .../op-attrs/parallel_tensor_dims.struct.toml | 2 +- .../op-attrs/parallel_tensor_shape.dtg.h | 10 ++- .../parallel_tensor_shape.struct.toml | 2 +- .../include/op-attrs/pcg_operator_attrs.dtg.h | 9 ++- .../op-attrs/pcg_operator_attrs.variant.toml | 1 + .../include/op-attrs/regularizer_attrs.dtg.h | 9 ++- .../op-attrs/regularizer_attrs.variant.toml | 1 + .../include/op-attrs/tensor_dims.dtg.h | 10 ++- .../include/op-attrs/tensor_dims.struct.toml | 2 +- .../include/op-attrs/tensor_shape.dtg.h | 10 ++- .../include/op-attrs/tensor_shape.struct.toml | 2 +- .../computation_graph_op_attrs.dtg.cc | 61 ++++++++++++++++- lib/op-attrs/src/op-attrs/ff_dim.dtg.cc | 2 +- lib/op-attrs/src/op-attrs/ops/attention.cc | 7 +- .../src/op-attrs/ops/attention_inputs.dtg.cc | 12 +++- .../src/op-attrs/ops/broadcast.dtg.cc | 9 ++- .../src/op-attrs/ops/cast_attrs.dtg.cc | 8 ++- .../src/op-attrs/ops/combine_attrs.dtg.cc | 10 ++- .../src/op-attrs/ops/concat_attrs.dtg.cc | 10 ++- .../src/op-attrs/ops/conv_2d_attrs.dtg.cc | 18 ++++- .../op-attrs/ops/element_binary_attrs.dtg.cc | 13 +++- .../ops/element_scalar_unary_attrs.dtg.cc | 10 ++- .../op-attrs/ops/element_unary_attrs.dtg.cc | 10 ++- .../src/op-attrs/ops/embedding_attrs.dtg.cc | 12 +++- .../src/op-attrs/ops/gather_attrs.dtg.cc | 10 ++- .../src/op-attrs/ops/layer_norm_attrs.dtg.cc | 13 +++- .../src/op-attrs/ops/linear_attrs.dtg.cc | 13 +++- .../ops/parallel_attention_inputs.dtg.cc | 12 +++- .../src/op-attrs/ops/pool_2d_attrs.dtg.cc | 16 ++++- .../src/op-attrs/ops/reduce_attrs.dtg.cc | 13 +++- .../src/op-attrs/ops/reduction_attrs.dtg.cc | 10 ++- .../src/op-attrs/ops/repartition_attrs.dtg.cc | 11 +++- .../src/op-attrs/ops/replicate_attrs.dtg.cc | 10 ++- .../src/op-attrs/ops/reshape_attrs.dtg.cc | 9 ++- .../src/op-attrs/ops/reverse_attrs.dtg.cc | 10 ++- .../src/op-attrs/ops/softmax_attrs.dtg.cc | 10 ++- .../src/op-attrs/ops/split_attrs.dtg.cc | 11 +++- .../src/op-attrs/ops/transpose_attrs.dtg.cc | 11 +++- lib/op-attrs/src/op-attrs/parallel_dim.dtg.cc | 10 ++- .../src/op-attrs/parallel_tensor_dims.cc | 2 +- .../src/op-attrs/parallel_tensor_dims.dtg.cc | 11 +++- .../src/op-attrs/parallel_tensor_shape.cc | 2 +- .../src/op-attrs/parallel_tensor_shape.dtg.cc | 11 +++- .../src/op-attrs/pcg_operator_attrs.dtg.cc | 65 ++++++++++++++++++- .../src/op-attrs/regularizer_attrs.dtg.cc | 11 +++- lib/op-attrs/src/op-attrs/tensor_dims.cc | 2 +- lib/op-attrs/src/op-attrs/tensor_dims.dtg.cc | 9 ++- lib/op-attrs/src/op-attrs/tensor_shape.dtg.cc | 10 ++- lib/op-attrs/test/src/test_dim_ordered.cc | 13 ++++ .../test/src/test_regularizer_attrs.cc | 14 ++++ .../include/pcg/strided_rectangle.struct.toml | 2 +- lib/utils/include/utils/optional.h | 15 +++++ lib/utils/include/utils/stack_vector.h | 22 ++++++- lib/utils/include/utils/variant.h | 12 ++++ lib/utils/test/src/test_optional.cc | 10 +++ lib/utils/test/src/test_stack_vector.cc | 8 +++ lib/utils/test/src/test_variant.cc | 7 ++ 116 files changed, 912 insertions(+), 118 deletions(-) create mode 100644 lib/op-attrs/include/op-attrs/ff_dim.h create mode 100644 lib/op-attrs/test/src/test_dim_ordered.cc create mode 100644 lib/op-attrs/test/src/test_regularizer_attrs.cc create mode 100644 lib/utils/test/src/test_optional.cc diff --git a/flake.lock b/flake.lock index f90110d8aa..e27afe3c2d 100644 --- a/flake.lock +++ b/flake.lock @@ -43,11 +43,11 @@ ] }, "locked": { - "lastModified": 1714185778, - "narHash": "sha256-Rl33HVDHhmcgKnHPYo96XA7Zm85PUmfuCZGeSWseAdw=", + "lastModified": 1716611864, + "narHash": "sha256-Nd1Hv4j5Wy70KzuAji3/50Fcr7axCw+o8TBJsh2f5UY=", "owner": "lockshaw", "repo": "proj", - "rev": "9b9465925365e76d0db5c11579d52b47fecc4dcd", + "rev": "b83e54a47c755b241ec5fa2a79aa455cba7dfc18", "type": "github" }, "original": { diff --git a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.dtg.h b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.dtg.h index 02e4ce4f27..cc45628145 100644 --- a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/computation_graph_op_attrs.variant.toml /* proj-data { - "generated_from": "dc1445fed47c2acaed22038975eec627" + "generated_from": "cc0ab49405423594ffa1d8f541235a48" } */ @@ -39,6 +39,7 @@ #include "op-attrs/ops/topk_attrs.dtg.h" #include "op-attrs/ops/transpose_attrs.dtg.h" #include "op-attrs/ops/weight_attrs.dtg.h" +#include "rapidcheck.h" #include #include #include @@ -455,6 +456,12 @@ struct adl_serializer<::FlexFlow::ComputationGraphOpAttrs> { static void to_json(json &, ::FlexFlow::ComputationGraphOpAttrs const &); }; } // namespace nlohmann +namespace rc { +template <> +struct Arbitrary<::FlexFlow::ComputationGraphOpAttrs> { + static Gen<::FlexFlow::ComputationGraphOpAttrs> arbitrary(); +}; +} // namespace rc namespace FlexFlow { std::string format_as(::FlexFlow::ComputationGraphOpAttrs const &); std::ostream &operator<<(std::ostream &, diff --git a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.variant.toml b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.variant.toml index cdbf9332da..bb25514e1d 100644 --- a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.variant.toml +++ b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.variant.toml @@ -5,6 +5,7 @@ features = [ "ord", "hash", "json", + "rapidcheck", "fmt", ] diff --git a/lib/op-attrs/include/op-attrs/dim_ordered.h b/lib/op-attrs/include/op-attrs/dim_ordered.h index b843b96e70..ab62c1d30c 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered.h @@ -171,7 +171,6 @@ FFOrdered const &outer_to_inner(FFOrdered const &ff_ordered) { } // namespace FlexFlow - /* template */ /* void to_json(json &j, DimOrdered const &x) { */ /* /1* j = std::vector{x.cbegin(), x.cend()}; *1/ */ @@ -189,11 +188,11 @@ struct adl_serializer<::FlexFlow::DimOrdered> { return {j.template get>()}; } - static void to_json(json& j, ::FlexFlow::DimOrdered const &x) { + static void to_json(json &j, ::FlexFlow::DimOrdered const &x) { j = std::vector{x.cbegin(), x.cend()}; } }; -} +} // namespace nlohmann namespace std { @@ -209,4 +208,16 @@ struct hash<::FlexFlow::DimOrdered> { } // namespace std +namespace rc { + +template +struct Arbitrary<::FlexFlow::DimOrdered> { + static Gen<::FlexFlow::DimOrdered> arbitrary() { + return gen::construct<::FlexFlow::DimOrdered>( + gen::arbitrary<::FlexFlow::stack_vector>()); + } +}; + +} // namespace rc + #endif diff --git a/lib/op-attrs/include/op-attrs/ff_dim.dtg.h b/lib/op-attrs/include/op-attrs/ff_dim.dtg.h index 8363ef0207..1697f78196 100644 --- a/lib/op-attrs/include/op-attrs/ff_dim.dtg.h +++ b/lib/op-attrs/include/op-attrs/ff_dim.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ff_dim.struct.toml /* proj-data { - "generated_from": "ffd119eb46e048b0f5a2d8fbef253de3" + "generated_from": "a5fa89a024e95c4f2d52681a74cab30f" } */ diff --git a/lib/op-attrs/include/op-attrs/ff_dim.h b/lib/op-attrs/include/op-attrs/ff_dim.h new file mode 100644 index 0000000000..b0559c2f1e --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ff_dim.h @@ -0,0 +1,17 @@ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_H + +#include "rapidcheck.h" +#include "op-attrs/ff_dim.dtg.h" + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(){ + return gen::construct(gen::inRange(0, MAX_TENSOR_DIM)); + } +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_H diff --git a/lib/op-attrs/include/op-attrs/ff_dim.struct.toml b/lib/op-attrs/include/op-attrs/ff_dim.struct.toml index feae1e4b21..441f9826ca 100644 --- a/lib/op-attrs/include/op-attrs/ff_dim.struct.toml +++ b/lib/op-attrs/include/op-attrs/ff_dim.struct.toml @@ -6,7 +6,6 @@ features = [ "ord", "hash", "json", - # "rapidcheck", "fmt", ] diff --git a/lib/op-attrs/include/op-attrs/ops/attention.h b/lib/op-attrs/include/op-attrs/ops/attention.h index ae0e791a4e..de7246dcef 100644 --- a/lib/op-attrs/include/op-attrs/ops/attention.h +++ b/lib/op-attrs/include/op-attrs/ops/attention.h @@ -37,15 +37,13 @@ int get_num_samples(MultiHeadAttentionInputs const &); TensorShape get_weights_shape(MultiHeadAttentionAttrs const &, MultiHeadAttentionInputs const &); -ParallelTensorShape - get_weights_shape(MultiHeadAttentionAttrs const &, - ParallelMultiHeadAttentionInputs const &); +ParallelTensorShape get_weights_shape(MultiHeadAttentionAttrs const &, + ParallelMultiHeadAttentionInputs const &); TensorShape get_output_shape(MultiHeadAttentionAttrs const &, MultiHeadAttentionInputs const &); -ParallelTensorShape - get_output_shape(MultiHeadAttentionAttrs const &, - ParallelMultiHeadAttentionInputs const &); +ParallelTensorShape get_output_shape(MultiHeadAttentionAttrs const &, + ParallelMultiHeadAttentionInputs const &); CHECK_VALID_OP_ATTR(MultiHeadAttentionAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/attention_inputs.dtg.h b/lib/op-attrs/include/op-attrs/ops/attention_inputs.dtg.h index bc1116eb17..809c12c835 100644 --- a/lib/op-attrs/include/op-attrs/ops/attention_inputs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/attention_inputs.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/attention_inputs.struct.toml /* proj-data { - "generated_from": "700f5fb734284b7feabbdd4cb61f3183" + "generated_from": "846dd6d3f4ca1c8135e4b3c8913fb872" } */ @@ -13,6 +13,7 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" #include "op-attrs/tensor_shape.h" +#include "rapidcheck.h" #include #include #include @@ -51,6 +52,13 @@ struct adl_serializer { }; } // namespace nlohmann +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + namespace FlexFlow { std::string format_as(MultiHeadAttentionInputs const &); std::ostream &operator<<(std::ostream &, MultiHeadAttentionInputs const &); diff --git a/lib/op-attrs/include/op-attrs/ops/attention_inputs.struct.toml b/lib/op-attrs/include/op-attrs/ops/attention_inputs.struct.toml index 224bad4dc8..1b04c1de2d 100644 --- a/lib/op-attrs/include/op-attrs/ops/attention_inputs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/attention_inputs.struct.toml @@ -5,7 +5,7 @@ features = [ "ord", "hash", "json", - # "rapidcheck", + "rapidcheck", "fmt", ] diff --git a/lib/op-attrs/include/op-attrs/ops/broadcast.dtg.h b/lib/op-attrs/include/op-attrs/ops/broadcast.dtg.h index b940ccc2b3..e4de3dcc75 100644 --- a/lib/op-attrs/include/op-attrs/ops/broadcast.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/broadcast.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/broadcast.struct.toml /* proj-data { - "generated_from": "890d0e63a08a30d925aa170aea6992ba" + "generated_from": "12715c970e8416eacbd0750f338478e5" } */ @@ -12,6 +12,7 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" +#include "rapidcheck.h" #include "utils/stack_vector.h" #include #include @@ -48,6 +49,13 @@ struct adl_serializer { }; } // namespace nlohmann +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + namespace FlexFlow { std::string format_as(BroadcastAttrs const &); std::ostream &operator<<(std::ostream &, BroadcastAttrs const &); diff --git a/lib/op-attrs/include/op-attrs/ops/broadcast.struct.toml b/lib/op-attrs/include/op-attrs/ops/broadcast.struct.toml index ae5549c9b9..c87afa59b5 100644 --- a/lib/op-attrs/include/op-attrs/ops/broadcast.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/broadcast.struct.toml @@ -5,7 +5,7 @@ features = [ "ord", "hash", "json", - # "rapidcheck", + "rapidcheck", "fmt", ] diff --git a/lib/op-attrs/include/op-attrs/ops/cast_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/cast_attrs.dtg.h index 5956b5b14f..33391eb221 100644 --- a/lib/op-attrs/include/op-attrs/ops/cast_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/cast_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/cast_attrs.struct.toml /* proj-data { - "generated_from": "62da4845a8aa0ae4ca3bce432a3aa9a3" + "generated_from": "c171c87db89b9ec9ea7d52a50c153054" } */ @@ -13,6 +13,7 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" #include "op-attrs/datatype.h" +#include "rapidcheck.h" #include #include #include @@ -47,6 +48,13 @@ struct adl_serializer { }; } // namespace nlohmann +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + namespace FlexFlow { std::string format_as(CastAttrs const &); std::ostream &operator<<(std::ostream &, CastAttrs const &); diff --git a/lib/op-attrs/include/op-attrs/ops/cast_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/cast_attrs.struct.toml index 75231ebc45..6c12680ea1 100644 --- a/lib/op-attrs/include/op-attrs/ops/cast_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/cast_attrs.struct.toml @@ -5,7 +5,7 @@ features = [ "ord", "hash", "json", - # "rapidcheck", + "rapidcheck", "fmt", ] diff --git a/lib/op-attrs/include/op-attrs/ops/combine_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/combine_attrs.dtg.h index e3c8b9ea2a..43db204bc5 100644 --- a/lib/op-attrs/include/op-attrs/ops/combine_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/combine_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml /* proj-data { - "generated_from": "7caa0f9668b1894f5e446556f1a424c8" + "generated_from": "58fc5a388fd1a325ef4142094607e39a" } */ @@ -13,6 +13,8 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" #include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include "rapidcheck.h" #include #include #include @@ -49,6 +51,13 @@ struct adl_serializer { }; } // namespace nlohmann +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + namespace FlexFlow { std::string format_as(CombineAttrs const &); std::ostream &operator<<(std::ostream &, CombineAttrs const &); diff --git a/lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml index 6791d3a110..585295fe1c 100644 --- a/lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml @@ -5,11 +5,12 @@ features = [ "ord", "hash", "json", - # "rapidcheck", + "rapidcheck", "fmt", ] includes = [ + "op-attrs/ff_dim.h", "op-attrs/ff_dim.dtg.h", ] diff --git a/lib/op-attrs/include/op-attrs/ops/concat_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/concat_attrs.dtg.h index 3d0b50c688..3c26473a4e 100644 --- a/lib/op-attrs/include/op-attrs/ops/concat_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/concat_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml /* proj-data { - "generated_from": "b72ef29f9f79a917176c63a5c3683ab5" + "generated_from": "68e0520b143e0579140a2f2cdd390759" } */ @@ -13,6 +13,8 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" #include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include "rapidcheck.h" #include #include #include @@ -48,6 +50,13 @@ struct adl_serializer { }; } // namespace nlohmann +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + namespace FlexFlow { std::string format_as(ConcatAttrs const &); std::ostream &operator<<(std::ostream &, ConcatAttrs const &); diff --git a/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml index b75839bd9c..4faa870bc4 100644 --- a/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml @@ -5,11 +5,12 @@ features = [ "ord", "hash", "json", - # "rapidcheck", + "rapidcheck", "fmt", ] includes = [ + "op-attrs/ff_dim.h", "op-attrs/ff_dim.dtg.h" ] diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.dtg.h index 7eb9bd677c..06827656da 100644 --- a/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml /* proj-data { - "generated_from": "85f65c1b0e0340ea8e8622c2bf9ca38d" + "generated_from": "74f98e1aacb57d847bb450e1d28d3e67" } */ @@ -13,6 +13,7 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" #include "op-attrs/activation.dtg.h" +#include "rapidcheck.h" #include "utils/json.h" #include #include @@ -67,6 +68,13 @@ struct adl_serializer { }; } // namespace nlohmann +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + namespace FlexFlow { std::string format_as(Conv2DAttrs const &); std::ostream &operator<<(std::ostream &, Conv2DAttrs const &); diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml index b27c2e1899..353ef93004 100644 --- a/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml @@ -5,7 +5,7 @@ features = [ "ord", "hash", "json", - # "rapidcheck", + "rapidcheck", "fmt", ] diff --git a/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.dtg.h index 66a0b66304..10d93c87d3 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/element_binary_attrs.struct.toml /* proj-data { - "generated_from": "1aae4139632791a4b7638e59fa6b5dc8" + "generated_from": "2bb947c9cc92e3833ee88c908c539629" } */ @@ -14,6 +14,7 @@ #include "nlohmann/json.hpp" #include "op-attrs/datatype.h" #include "op-attrs/operator_type.h" +#include "rapidcheck.h" #include #include #include @@ -54,6 +55,13 @@ struct adl_serializer { }; } // namespace nlohmann +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + namespace FlexFlow { std::string format_as(ElementBinaryAttrs const &); std::ostream &operator<<(std::ostream &, ElementBinaryAttrs const &); diff --git a/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.struct.toml index 9479cb2956..d167c67aed 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.struct.toml @@ -6,7 +6,7 @@ features = [ "ord", "hash", "json", - # "rapidcheck", + "rapidcheck", "fmt", ] diff --git a/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.dtg.h index 61041b3993..a9fe63ca71 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.struct.toml /* proj-data { - "generated_from": "09554c353caed6075e362da5008c4bd2" + "generated_from": "aa6f98b992d46bdf7ad59158bc143a3f" } */ @@ -13,6 +13,7 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" #include "op-attrs/operator_type.h" +#include "rapidcheck.h" #include #include #include @@ -49,6 +50,13 @@ struct adl_serializer { }; } // namespace nlohmann +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + namespace FlexFlow { std::string format_as(ElementScalarUnaryAttrs const &); std::ostream &operator<<(std::ostream &, ElementScalarUnaryAttrs const &); diff --git a/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.struct.toml index 2f406a67d5..609805ab98 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.struct.toml @@ -6,7 +6,7 @@ features = [ "ord", "hash", "json", - # "rapidcheck", + "rapidcheck", "fmt", ] diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.dtg.h index bdf63fda8d..3220234bd1 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml /* proj-data { - "generated_from": "fdb867c04cdd7de320f573f360bcab90" + "generated_from": "75272cff78d3db866122dbb1001aedbe" } */ @@ -13,6 +13,7 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" #include "op-attrs/operator_type.h" +#include "rapidcheck.h" #include #include #include @@ -47,6 +48,13 @@ struct adl_serializer { }; } // namespace nlohmann +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + namespace FlexFlow { std::string format_as(ElementUnaryAttrs const &); std::ostream &operator<<(std::ostream &, ElementUnaryAttrs const &); diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml index fad251d181..b0e23aa5c7 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml @@ -6,7 +6,7 @@ features = [ "ord", "hash", "json", - # "rapidcheck", + "rapidcheck", "fmt", ] diff --git a/lib/op-attrs/include/op-attrs/ops/embedding_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.dtg.h index 23df0b7cd2..0b5bed5ba7 100644 --- a/lib/op-attrs/include/op-attrs/ops/embedding_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml /* proj-data { - "generated_from": "65af6a38dfabebbc05c8ad3f75397b07" + "generated_from": "a0ac41fc0f56bc06bcb1a8d42fc6191c" } */ @@ -14,6 +14,7 @@ #include "nlohmann/json.hpp" #include "op-attrs/aggregate_op.dtg.h" #include "op-attrs/datatype.dtg.h" +#include "rapidcheck.h" #include "utils/stack_vector.h" #include #include @@ -55,6 +56,13 @@ struct adl_serializer { }; } // namespace nlohmann +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + namespace FlexFlow { std::string format_as(EmbeddingAttrs const &); std::ostream &operator<<(std::ostream &, EmbeddingAttrs const &); diff --git a/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml index 1bae4869bd..39dc71bdb3 100644 --- a/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml @@ -5,7 +5,7 @@ features = [ "ord", "hash", "json", - # "rapidcheck", + "rapidcheck", "fmt", ] diff --git a/lib/op-attrs/include/op-attrs/ops/gather_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/gather_attrs.dtg.h index 6c74d77031..e7a35e5800 100644 --- a/lib/op-attrs/include/op-attrs/ops/gather_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/gather_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml /* proj-data { - "generated_from": "ee735644d3c5f53f790e0a1fa8b8beaf" + "generated_from": "4ba46b6b494a7a52edda437d2a05fcf1" } */ @@ -13,6 +13,8 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" #include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include "rapidcheck.h" #include #include #include @@ -47,6 +49,13 @@ struct adl_serializer { }; } // namespace nlohmann +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + namespace FlexFlow { std::string format_as(GatherAttrs const &); std::ostream &operator<<(std::ostream &, GatherAttrs const &); diff --git a/lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml index c66f1585fd..c8bb88dcc7 100644 --- a/lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml @@ -5,11 +5,12 @@ features = [ "ord", "hash", "json", - # "rapidcheck", + "rapidcheck", "fmt", ] includes = [ + "op-attrs/ff_dim.h", "op-attrs/ff_dim.dtg.h" ] diff --git a/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.dtg.h index af8ace620a..c945206863 100644 --- a/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml /* proj-data { - "generated_from": "c03d823a6e889e1254b73a0730a71046" + "generated_from": "349deae8d9356d3eeacd7e7d069c3155" } */ @@ -13,6 +13,8 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" #include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include "rapidcheck.h" #include "utils/stack_vector.h" #include #include @@ -53,6 +55,13 @@ struct adl_serializer { }; } // namespace nlohmann +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + namespace FlexFlow { std::string format_as(LayerNormAttrs const &); std::ostream &operator<<(std::ostream &, LayerNormAttrs const &); diff --git a/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml index a72b903ebe..ec60d39f7f 100644 --- a/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml @@ -5,11 +5,12 @@ features = [ "ord", "hash", "json", - # "rapidcheck", + "rapidcheck", "fmt", ] includes = [ + "op-attrs/ff_dim.h", "op-attrs/ff_dim.dtg.h", "utils/stack_vector.h", ] diff --git a/lib/op-attrs/include/op-attrs/ops/linear_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/linear_attrs.dtg.h index 572520031e..dfb7579b25 100644 --- a/lib/op-attrs/include/op-attrs/ops/linear_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/linear_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml /* proj-data { - "generated_from": "dae07c937f6c52d4dc89ec322520e29f" + "generated_from": "1369f126a4a6d6eee91642043ab481f6" } */ @@ -15,6 +15,7 @@ #include "op-attrs/activation.dtg.h" #include "op-attrs/datatype.dtg.h" #include "op-attrs/regularizer_attrs.dtg.h" +#include "rapidcheck.h" #include "utils/json.h" #include #include @@ -58,6 +59,13 @@ struct adl_serializer { }; } // namespace nlohmann +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + namespace FlexFlow { std::string format_as(LinearAttrs const &); std::ostream &operator<<(std::ostream &, LinearAttrs const &); diff --git a/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml index 8945d47c55..7fa2d9c584 100644 --- a/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml @@ -5,7 +5,7 @@ features = [ "ord", "hash", "json", - # "rapidcheck", + "rapidcheck", "fmt", ] diff --git a/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.dtg.h b/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.dtg.h index c3ee7782c4..d3903bd3b2 100644 --- a/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.struct.toml /* proj-data { - "generated_from": "8d1e2a2d3852bfb59d8668d14d52c958" + "generated_from": "b76a39763275090d8376e1c27668d2cb" } */ @@ -13,6 +13,7 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" #include "op-attrs/parallel_tensor_shape.h" +#include "rapidcheck.h" #include #include #include @@ -49,6 +50,13 @@ struct adl_serializer { }; } // namespace nlohmann +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + namespace FlexFlow { std::string format_as(ParallelMultiHeadAttentionInputs const &); std::ostream &operator<<(std::ostream &, diff --git a/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.struct.toml b/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.struct.toml index 22136a948b..4809ee998a 100644 --- a/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.struct.toml @@ -5,7 +5,7 @@ features = [ # "ord", "hash", "json", - # "rapidcheck", + "rapidcheck", "fmt", ] diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.dtg.h index c976ca0720..a5c6603302 100644 --- a/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml /* proj-data { - "generated_from": "607be08f56d910bfa340fb180646c126" + "generated_from": "03aeafe335f68ff831e3e73a77f45caf" } */ @@ -14,6 +14,7 @@ #include "nlohmann/json.hpp" #include "op-attrs/activation.dtg.h" #include "op-attrs/pool_op.dtg.h" +#include "rapidcheck.h" #include #include #include @@ -62,6 +63,13 @@ struct adl_serializer { }; } // namespace nlohmann +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + namespace FlexFlow { std::string format_as(Pool2DAttrs const &); std::ostream &operator<<(std::ostream &, Pool2DAttrs const &); diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml index 58854d457c..56bf682f50 100644 --- a/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml @@ -5,7 +5,7 @@ features = [ "ord", "hash", "json", - # "rapidcheck", + "rapidcheck", "fmt", ] diff --git a/lib/op-attrs/include/op-attrs/ops/reduce_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/reduce_attrs.dtg.h index f6f78911e3..af27bf35be 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduce_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/reduce_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml /* proj-data { - "generated_from": "bc6279031650335f4a0b7b6cfe116c85" + "generated_from": "097463446e254f662c7bdf5df4e12d17" } */ @@ -13,7 +13,9 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" #include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" #include "op-attrs/operator_type.dtg.h" +#include "rapidcheck.h" #include "utils/stack_vector.h" #include #include @@ -54,6 +56,13 @@ struct adl_serializer { }; } // namespace nlohmann +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + namespace FlexFlow { std::string format_as(ReduceAttrs const &); std::ostream &operator<<(std::ostream &, ReduceAttrs const &); diff --git a/lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml index e8a1785d19..717e7954e8 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml @@ -5,12 +5,13 @@ features = [ "ord", "hash", "json", - # "rapidcheck", + "rapidcheck", "fmt", ] includes = [ "op-attrs/operator_type.dtg.h", + "op-attrs/ff_dim.h", "op-attrs/ff_dim.dtg.h", "utils/stack_vector.h", ] diff --git a/lib/op-attrs/include/op-attrs/ops/reduction_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/reduction_attrs.dtg.h index 942e1870e8..5ff8a12651 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduction_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/reduction_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/reduction_attrs.struct.toml /* proj-data { - "generated_from": "57b8ccb5bc2e1a1a3bcf1bce2d8cad9e" + "generated_from": "28492e45a5c4f44987e17fe9ea876e11" } */ @@ -13,6 +13,8 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" #include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include "rapidcheck.h" #include #include #include @@ -49,6 +51,13 @@ struct adl_serializer { }; } // namespace nlohmann +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + namespace FlexFlow { std::string format_as(ReductionAttrs const &); std::ostream &operator<<(std::ostream &, ReductionAttrs const &); diff --git a/lib/op-attrs/include/op-attrs/ops/reduction_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/reduction_attrs.struct.toml index 5baafdfa42..ff990ef46c 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduction_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/reduction_attrs.struct.toml @@ -5,11 +5,12 @@ features = [ "ord", "hash", "json", - # "rapidcheck", + "rapidcheck", "fmt", ] includes = [ + "op-attrs/ff_dim.h", "op-attrs/ff_dim.dtg.h", ] diff --git a/lib/op-attrs/include/op-attrs/ops/repartition_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/repartition_attrs.dtg.h index fa888700d0..66c21466f4 100644 --- a/lib/op-attrs/include/op-attrs/ops/repartition_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/repartition_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml /* proj-data { - "generated_from": "366cb1a14093762f75508260ac6494ca" + "generated_from": "0a4d8b435768ce3ee37013fc550c9ebb" } */ @@ -13,6 +13,8 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" #include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include "rapidcheck.h" #include #include #include @@ -49,6 +51,13 @@ struct adl_serializer { }; } // namespace nlohmann +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + namespace FlexFlow { std::string format_as(RepartitionAttrs const &); std::ostream &operator<<(std::ostream &, RepartitionAttrs const &); diff --git a/lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml index 344691a781..25a33c0c15 100644 --- a/lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml @@ -5,11 +5,12 @@ features = [ "ord", "hash", "json", - # "rapidcheck", + "rapidcheck", "fmt", ] includes = [ + "op-attrs/ff_dim.h", "op-attrs/ff_dim.dtg.h", ] diff --git a/lib/op-attrs/include/op-attrs/ops/replicate_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/replicate_attrs.dtg.h index 4249a2c0e7..36b70a0b6d 100644 --- a/lib/op-attrs/include/op-attrs/ops/replicate_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/replicate_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/replicate_attrs.struct.toml /* proj-data { - "generated_from": "4224406d468444433d69e4abf61b7cd1" + "generated_from": "68c1bba349a54c0db219a67d4cc502b3" } */ @@ -13,6 +13,8 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" #include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include "rapidcheck.h" #include #include #include @@ -49,6 +51,13 @@ struct adl_serializer { }; } // namespace nlohmann +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + namespace FlexFlow { std::string format_as(ReplicateAttrs const &); std::ostream &operator<<(std::ostream &, ReplicateAttrs const &); diff --git a/lib/op-attrs/include/op-attrs/ops/replicate_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/replicate_attrs.struct.toml index d5f9c22f28..afcb8f8fa4 100644 --- a/lib/op-attrs/include/op-attrs/ops/replicate_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/replicate_attrs.struct.toml @@ -5,11 +5,12 @@ features = [ "ord", "hash", "json", - # "rapidcheck", + "rapidcheck", "fmt", ] includes = [ + "op-attrs/ff_dim.h", "op-attrs/ff_dim.dtg.h", ] diff --git a/lib/op-attrs/include/op-attrs/ops/reshape_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/reshape_attrs.dtg.h index 860b61f2e8..612874790f 100644 --- a/lib/op-attrs/include/op-attrs/ops/reshape_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/reshape_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/reshape_attrs.struct.toml /* proj-data { - "generated_from": "5a6a9e646a457a6cf959c542fb631512" + "generated_from": "015d04de0ccb982e7eaa013a842880ca" } */ @@ -13,6 +13,7 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" #include "op-attrs/tensor_shape.dtg.h" +#include "rapidcheck.h" #include #include #include @@ -47,6 +48,13 @@ struct adl_serializer { }; } // namespace nlohmann +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + namespace FlexFlow { std::string format_as(ReshapeAttrs const &); std::ostream &operator<<(std::ostream &, ReshapeAttrs const &); diff --git a/lib/op-attrs/include/op-attrs/ops/reshape_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/reshape_attrs.struct.toml index dc0a96313d..69ac761859 100644 --- a/lib/op-attrs/include/op-attrs/ops/reshape_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/reshape_attrs.struct.toml @@ -5,7 +5,7 @@ features = [ "ord", "hash", "json", - # "rapidcheck", + "rapidcheck", "fmt", ] diff --git a/lib/op-attrs/include/op-attrs/ops/reverse_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/reverse_attrs.dtg.h index 3ed917d33e..8c8c8a7a9e 100644 --- a/lib/op-attrs/include/op-attrs/ops/reverse_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/reverse_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml /* proj-data { - "generated_from": "7c21c4192854f5981018abf4fbdd9ead" + "generated_from": "c5a82c8a15ac3ce6f47dc054236ab69b" } */ @@ -13,6 +13,8 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" #include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include "rapidcheck.h" #include #include #include @@ -47,6 +49,13 @@ struct adl_serializer { }; } // namespace nlohmann +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + namespace FlexFlow { std::string format_as(ReverseAttrs const &); std::ostream &operator<<(std::ostream &, ReverseAttrs const &); diff --git a/lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml index e2058cf3e5..198346e5dd 100644 --- a/lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml @@ -5,11 +5,12 @@ features = [ "ord", "hash", "json", - # "rapidcheck", + "rapidcheck", "fmt", ] includes = [ + "op-attrs/ff_dim.h", "op-attrs/ff_dim.dtg.h", ] diff --git a/lib/op-attrs/include/op-attrs/ops/softmax_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/softmax_attrs.dtg.h index a2acbf7300..1c855d90f4 100644 --- a/lib/op-attrs/include/op-attrs/ops/softmax_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/softmax_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml /* proj-data { - "generated_from": "9be043678a4ce7666fc372cded600290" + "generated_from": "2ddf5a8b7daa32a43387f5fd5866bb3b" } */ @@ -13,6 +13,8 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" #include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include "rapidcheck.h" #include #include #include @@ -47,6 +49,13 @@ struct adl_serializer { }; } // namespace nlohmann +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + namespace FlexFlow { std::string format_as(SoftmaxAttrs const &); std::ostream &operator<<(std::ostream &, SoftmaxAttrs const &); diff --git a/lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml index 3e4fcbc75a..8b839c122a 100644 --- a/lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml @@ -5,11 +5,12 @@ features = [ "ord", "hash", "json", - # "rapidcheck", + "rapidcheck", "fmt", ] includes = [ + "op-attrs/ff_dim.h", "op-attrs/ff_dim.dtg.h", ] diff --git a/lib/op-attrs/include/op-attrs/ops/split_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/split_attrs.dtg.h index dee08ca1c8..b602015e2e 100644 --- a/lib/op-attrs/include/op-attrs/ops/split_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/split_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml /* proj-data { - "generated_from": "4112baa96de544b865618e0a999e0807" + "generated_from": "cde6b5caf6739d3b02fe8fce0d8ae8c5" } */ @@ -13,6 +13,8 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" #include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include "rapidcheck.h" #include "utils/stack_vector.h" #include #include @@ -50,6 +52,13 @@ struct adl_serializer { }; } // namespace nlohmann +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + namespace FlexFlow { std::string format_as(SplitAttrs const &); std::ostream &operator<<(std::ostream &, SplitAttrs const &); diff --git a/lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml index 8205cdbccb..8cdf7728af 100644 --- a/lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml @@ -5,12 +5,13 @@ features = [ "ord", "hash", "json", - # "rapidcheck", + "rapidcheck", "fmt", ] includes = [ "utils/stack_vector.h", + "op-attrs/ff_dim.h", "op-attrs/ff_dim.dtg.h", ] diff --git a/lib/op-attrs/include/op-attrs/ops/transpose_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.dtg.h index 352aaf6e6a..355c28fcdc 100644 --- a/lib/op-attrs/include/op-attrs/ops/transpose_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml /* proj-data { - "generated_from": "edff0b414040204e895666d81b49db07" + "generated_from": "87f6e4db4b66d564530994773c0ecef4" } */ @@ -13,6 +13,8 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" #include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include "rapidcheck.h" #include "utils/stack_vector.h" #include #include @@ -49,6 +51,13 @@ struct adl_serializer { }; } // namespace nlohmann +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + namespace FlexFlow { std::string format_as(TransposeAttrs const &); std::ostream &operator<<(std::ostream &, TransposeAttrs const &); diff --git a/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml index af13022262..aab525b7e6 100644 --- a/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml @@ -5,11 +5,12 @@ features = [ "ord", "hash", "json", - # "rapidcheck", + "rapidcheck", "fmt", ] includes = [ + "op-attrs/ff_dim.h", "op-attrs/ff_dim.dtg.h", "utils/stack_vector.h", ] diff --git a/lib/op-attrs/include/op-attrs/parallel_dim.dtg.h b/lib/op-attrs/include/op-attrs/parallel_dim.dtg.h index a5c4fc0b29..4115d4ce1f 100644 --- a/lib/op-attrs/include/op-attrs/parallel_dim.dtg.h +++ b/lib/op-attrs/include/op-attrs/parallel_dim.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/parallel_dim.variant.toml /* proj-data { - "generated_from": "5550fc7ad51892b3411ef274c76e7d85" + "generated_from": "f382ff547aae62777e5091f00d034d84" } */ @@ -14,6 +14,7 @@ #include "nlohmann/json.hpp" #include "op-attrs/replica_parallel_dim.dtg.h" #include "op-attrs/shard_parallel_dim.dtg.h" +#include "rapidcheck.h" #include #include #include @@ -113,6 +114,12 @@ struct adl_serializer<::FlexFlow::ParallelDim> { static void to_json(json &, ::FlexFlow::ParallelDim const &); }; } // namespace nlohmann +namespace rc { +template <> +struct Arbitrary<::FlexFlow::ParallelDim> { + static Gen<::FlexFlow::ParallelDim> arbitrary(); +}; +} // namespace rc namespace FlexFlow { std::string format_as(::FlexFlow::ParallelDim const &); std::ostream &operator<<(std::ostream &, ::FlexFlow::ParallelDim const &); diff --git a/lib/op-attrs/include/op-attrs/parallel_dim.variant.toml b/lib/op-attrs/include/op-attrs/parallel_dim.variant.toml index eceffd38a3..e27e6509fe 100644 --- a/lib/op-attrs/include/op-attrs/parallel_dim.variant.toml +++ b/lib/op-attrs/include/op-attrs/parallel_dim.variant.toml @@ -5,6 +5,7 @@ features = [ "ord", "hash", "json", + "rapidcheck", "fmt", ] diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.dtg.h b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.dtg.h index d2d1b13a49..71ad517095 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.dtg.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml /* proj-data { - "generated_from": "31a9e757f42ec3e468b299cda2cbcd4e" + "generated_from": "aec3b6b66e34be0d5ce3055822479430" } */ @@ -15,6 +15,7 @@ #include "op-attrs/dim_ordered.h" #include "op-attrs/replica_parallel_dim_set.dtg.h" #include "op-attrs/shard_parallel_dim.dtg.h" +#include "rapidcheck.h" #include "utils/fmt/pair.h" #include "utils/fmt/unordered_map.h" #include @@ -55,6 +56,13 @@ struct adl_serializer { }; } // namespace nlohmann +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + namespace FlexFlow { std::string format_as(ParallelTensorDims const &); std::ostream &operator<<(std::ostream &, ParallelTensorDims const &); diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml index 37216b160e..ae6eab1e58 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml @@ -5,7 +5,7 @@ features = [ "ord", "hash", "json", - # "rapidcheck", + "rapidcheck", "fmt", ] diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.dtg.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.dtg.h index dfad5b1007..62d291fa4f 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.dtg.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml /* proj-data { - "generated_from": "b2d36c9212916e66569af4e958c893f4" + "generated_from": "06d657d1e95f34aebf4b721c768cbee8" } */ @@ -14,6 +14,7 @@ #include "nlohmann/json.hpp" #include "op-attrs/datatype.h" #include "op-attrs/parallel_tensor_dims.h" +#include "rapidcheck.h" #include #include #include @@ -50,6 +51,13 @@ struct adl_serializer { }; } // namespace nlohmann +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + namespace FlexFlow { std::string format_as(ParallelTensorShape const &); std::ostream &operator<<(std::ostream &, ParallelTensorShape const &); diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml index 1199b0d816..e6197bcd51 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml @@ -5,7 +5,7 @@ features = [ "ord", "hash", "json", - # "rapidcheck", + "rapidcheck", "fmt", ] diff --git a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.h b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.h index 75b203ab3a..5370773a45 100644 --- a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml /* proj-data { - "generated_from": "cf0da4385b7554748a06ec25ccf17f2f" + "generated_from": "e1d10b0c7c98524c27886bdae0972321" } */ @@ -41,6 +41,7 @@ #include "op-attrs/ops/split_attrs.dtg.h" #include "op-attrs/ops/topk_attrs.dtg.h" #include "op-attrs/ops/transpose_attrs.dtg.h" +#include "rapidcheck.h" #include #include #include @@ -480,6 +481,12 @@ struct adl_serializer<::FlexFlow::PCGOperatorAttrs> { static void to_json(json &, ::FlexFlow::PCGOperatorAttrs const &); }; } // namespace nlohmann +namespace rc { +template <> +struct Arbitrary<::FlexFlow::PCGOperatorAttrs> { + static Gen<::FlexFlow::PCGOperatorAttrs> arbitrary(); +}; +} // namespace rc namespace FlexFlow { std::string format_as(::FlexFlow::PCGOperatorAttrs const &); std::ostream &operator<<(std::ostream &, ::FlexFlow::PCGOperatorAttrs const &); diff --git a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml index 4062f08684..ddb8a109d8 100644 --- a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml +++ b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml @@ -5,6 +5,7 @@ features = [ "ord", "hash", "json", + "rapidcheck", "fmt", ] diff --git a/lib/op-attrs/include/op-attrs/regularizer_attrs.dtg.h b/lib/op-attrs/include/op-attrs/regularizer_attrs.dtg.h index 38add9b42b..2621b4b12c 100644 --- a/lib/op-attrs/include/op-attrs/regularizer_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/regularizer_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/regularizer_attrs.variant.toml /* proj-data { - "generated_from": "b0cb2d264215faf9759925c631f3d55f" + "generated_from": "ea060a8ab344c9772102f084903883ea" } */ @@ -14,6 +14,7 @@ #include "nlohmann/json.hpp" #include "op-attrs/l1_regularizer_attrs.dtg.h" #include "op-attrs/l2_regularizer_attrs.dtg.h" +#include "rapidcheck.h" #include #include #include @@ -113,6 +114,12 @@ struct adl_serializer<::FlexFlow::RegularizerAttrs> { static void to_json(json &, ::FlexFlow::RegularizerAttrs const &); }; } // namespace nlohmann +namespace rc { +template <> +struct Arbitrary<::FlexFlow::RegularizerAttrs> { + static Gen<::FlexFlow::RegularizerAttrs> arbitrary(); +}; +} // namespace rc namespace FlexFlow { std::string format_as(::FlexFlow::RegularizerAttrs const &); std::ostream &operator<<(std::ostream &, ::FlexFlow::RegularizerAttrs const &); diff --git a/lib/op-attrs/include/op-attrs/regularizer_attrs.variant.toml b/lib/op-attrs/include/op-attrs/regularizer_attrs.variant.toml index df974fed91..d650c7f6a9 100644 --- a/lib/op-attrs/include/op-attrs/regularizer_attrs.variant.toml +++ b/lib/op-attrs/include/op-attrs/regularizer_attrs.variant.toml @@ -5,6 +5,7 @@ features = [ "ord", "hash", "json", + "rapidcheck", "fmt", ] diff --git a/lib/op-attrs/include/op-attrs/tensor_dims.dtg.h b/lib/op-attrs/include/op-attrs/tensor_dims.dtg.h index cb67f65c49..a8e46a4626 100644 --- a/lib/op-attrs/include/op-attrs/tensor_dims.dtg.h +++ b/lib/op-attrs/include/op-attrs/tensor_dims.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/tensor_dims.struct.toml /* proj-data { - "generated_from": "f925a4c2343d2404116dc598c301beaf" + "generated_from": "5beb89eeae9eba303f90e726c794375d" } */ @@ -13,6 +13,7 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" #include "op-attrs/dim_ordered.h" +#include "rapidcheck.h" #include #include #include @@ -47,6 +48,13 @@ struct adl_serializer { }; } // namespace nlohmann +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + namespace FlexFlow { std::string format_as(TensorDims const &); std::ostream &operator<<(std::ostream &, TensorDims const &); diff --git a/lib/op-attrs/include/op-attrs/tensor_dims.struct.toml b/lib/op-attrs/include/op-attrs/tensor_dims.struct.toml index e3913f60f6..cff8e08b0f 100644 --- a/lib/op-attrs/include/op-attrs/tensor_dims.struct.toml +++ b/lib/op-attrs/include/op-attrs/tensor_dims.struct.toml @@ -5,7 +5,7 @@ features = [ "ord", "hash", "json", - # "rapidcheck", + "rapidcheck", "fmt", ] includes = [ diff --git a/lib/op-attrs/include/op-attrs/tensor_shape.dtg.h b/lib/op-attrs/include/op-attrs/tensor_shape.dtg.h index 8ac6655956..f36d5d1306 100644 --- a/lib/op-attrs/include/op-attrs/tensor_shape.dtg.h +++ b/lib/op-attrs/include/op-attrs/tensor_shape.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/tensor_shape.struct.toml /* proj-data { - "generated_from": "52968754cf94f415c366d228c87042db" + "generated_from": "ef6fa5088b89d6da4dc8bddf0a6d3294" } */ @@ -14,6 +14,7 @@ #include "nlohmann/json.hpp" #include "op-attrs/datatype.dtg.h" #include "op-attrs/tensor_dims.dtg.h" +#include "rapidcheck.h" #include #include #include @@ -50,6 +51,13 @@ struct adl_serializer { }; } // namespace nlohmann +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + namespace FlexFlow { std::string format_as(TensorShape const &); std::ostream &operator<<(std::ostream &, TensorShape const &); diff --git a/lib/op-attrs/include/op-attrs/tensor_shape.struct.toml b/lib/op-attrs/include/op-attrs/tensor_shape.struct.toml index 24f9ff1b79..901c3b9e60 100644 --- a/lib/op-attrs/include/op-attrs/tensor_shape.struct.toml +++ b/lib/op-attrs/include/op-attrs/tensor_shape.struct.toml @@ -5,7 +5,7 @@ features = [ "ord", "hash", "json", - # "rapidcheck", + "rapidcheck", "fmt", ] diff --git a/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.dtg.cc index 7b0db513bf..9bcde22cd9 100644 --- a/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.dtg.cc @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/computation_graph_op_attrs.variant.toml /* proj-data { - "generated_from": "dc1445fed47c2acaed22038975eec627" + "generated_from": "cc0ab49405423594ffa1d8f541235a48" } */ @@ -384,6 +384,65 @@ void adl_serializer<::FlexFlow::ComputationGraphOpAttrs>::to_json( } } } // namespace nlohmann +namespace rc { +Gen<::FlexFlow::ComputationGraphOpAttrs> + Arbitrary<::FlexFlow::ComputationGraphOpAttrs>::arbitrary() { + return gen::oneOf(gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::BatchMatmulAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::BatchNormAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::BroadcastAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::CastAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::ConcatAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::Conv2DAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::DropoutAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::ElementBinaryAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::ElementUnaryAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::ElementScalarUnaryAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::EmbeddingAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::FlatAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::GatherAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::InputAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::LayerNormAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::LinearAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::MultiHeadAttentionAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::NoopAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::Pool2DAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::ReduceAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::ReverseAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::ReshapeAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::SplitAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::SoftmaxAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::TopKAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::TransposeAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::WeightAttrs>())); +} +} // namespace rc namespace FlexFlow { std::string format_as(::FlexFlow::ComputationGraphOpAttrs const &x) { std::ostringstream oss; diff --git a/lib/op-attrs/src/op-attrs/ff_dim.dtg.cc b/lib/op-attrs/src/op-attrs/ff_dim.dtg.cc index f6a1863fff..8b22dfd18d 100644 --- a/lib/op-attrs/src/op-attrs/ff_dim.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ff_dim.dtg.cc @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ff_dim.struct.toml /* proj-data { - "generated_from": "ffd119eb46e048b0f5a2d8fbef253de3" + "generated_from": "a5fa89a024e95c4f2d52681a74cab30f" } */ diff --git a/lib/op-attrs/src/op-attrs/ops/attention.cc b/lib/op-attrs/src/op-attrs/ops/attention.cc index bc7ea3f57c..105fa7250b 100644 --- a/lib/op-attrs/src/op-attrs/ops/attention.cc +++ b/lib/op-attrs/src/op-attrs/ops/attention.cc @@ -62,9 +62,8 @@ int get_vSize(MultiHeadAttentionInputs const &) { NOT_IMPLEMENTED(); } -TensorShape - get_weights_shape(MultiHeadAttentionAttrs const &attrs, - MultiHeadAttentionInputs const &inputs) { +TensorShape get_weights_shape(MultiHeadAttentionAttrs const &attrs, + MultiHeadAttentionInputs const &inputs) { size_t qParas = get_qProjSize(attrs) * get_qSize(inputs); size_t kParas = get_kProjSize(attrs) * get_kSize(inputs); size_t vParas = get_vProjSize(attrs) * get_vSize(inputs); @@ -72,7 +71,7 @@ TensorShape size_t oParas = get_oProjSize(attrs) * get_oSize(output_shape); TensorDims dims = {{qParas + kParas + vParas + oParas, - static_cast(attrs.embed_dim)}}; + static_cast(attrs.embed_dim)}}; return {dims, DataType::FLOAT}; } diff --git a/lib/op-attrs/src/op-attrs/ops/attention_inputs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/attention_inputs.dtg.cc index d12018acb5..f8ad72ca7a 100644 --- a/lib/op-attrs/src/op-attrs/ops/attention_inputs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/attention_inputs.dtg.cc @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/attention_inputs.struct.toml /* proj-data { - "generated_from": "700f5fb734284b7feabbdd4cb61f3183" + "generated_from": "846dd6d3f4ca1c8135e4b3c8913fb872" } */ @@ -81,6 +81,16 @@ void adl_serializer::to_json( } } // namespace nlohmann +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::TensorShape>(), + gen::arbitrary<::FlexFlow::TensorShape>(), + gen::arbitrary<::FlexFlow::TensorShape>()); +} +} // namespace rc + namespace FlexFlow { std::string format_as(MultiHeadAttentionInputs const &x) { std::ostringstream oss; diff --git a/lib/op-attrs/src/op-attrs/ops/broadcast.dtg.cc b/lib/op-attrs/src/op-attrs/ops/broadcast.dtg.cc index dadb8d4cff..ec08bd6a1d 100644 --- a/lib/op-attrs/src/op-attrs/ops/broadcast.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/broadcast.dtg.cc @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/broadcast.struct.toml /* proj-data { - "generated_from": "890d0e63a08a30d925aa170aea6992ba" + "generated_from": "12715c970e8416eacbd0750f338478e5" } */ @@ -60,6 +60,13 @@ void adl_serializer::to_json( } } // namespace nlohmann +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::stack_vector>()); +} +} // namespace rc + namespace FlexFlow { std::string format_as(BroadcastAttrs const &x) { std::ostringstream oss; diff --git a/lib/op-attrs/src/op-attrs/ops/cast_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/cast_attrs.dtg.cc index 63b0cda27b..28367f3449 100644 --- a/lib/op-attrs/src/op-attrs/ops/cast_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/cast_attrs.dtg.cc @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/cast_attrs.struct.toml /* proj-data { - "generated_from": "62da4845a8aa0ae4ca3bce432a3aa9a3" + "generated_from": "c171c87db89b9ec9ea7d52a50c153054" } */ @@ -56,6 +56,12 @@ void adl_serializer::to_json( } } // namespace nlohmann +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary()); +} +} // namespace rc + namespace FlexFlow { std::string format_as(CastAttrs const &x) { std::ostringstream oss; diff --git a/lib/op-attrs/src/op-attrs/ops/combine_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/combine_attrs.dtg.cc index a652537871..516d3b0318 100644 --- a/lib/op-attrs/src/op-attrs/ops/combine_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/combine_attrs.dtg.cc @@ -3,13 +3,14 @@ // lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml /* proj-data { - "generated_from": "7caa0f9668b1894f5e446556f1a424c8" + "generated_from": "58fc5a388fd1a325ef4142094607e39a" } */ #include "op-attrs/ops/combine_attrs.dtg.h" #include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" #include namespace FlexFlow { @@ -68,6 +69,13 @@ void adl_serializer::to_json( } } // namespace nlohmann +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::ff_dim_t>(), gen::arbitrary()); +} +} // namespace rc + namespace FlexFlow { std::string format_as(CombineAttrs const &x) { std::ostringstream oss; diff --git a/lib/op-attrs/src/op-attrs/ops/concat_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/concat_attrs.dtg.cc index 0494b7069a..20db25d485 100644 --- a/lib/op-attrs/src/op-attrs/ops/concat_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/concat_attrs.dtg.cc @@ -3,13 +3,14 @@ // lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml /* proj-data { - "generated_from": "b72ef29f9f79a917176c63a5c3683ab5" + "generated_from": "68e0520b143e0579140a2f2cdd390759" } */ #include "op-attrs/ops/concat_attrs.dtg.h" #include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" #include namespace FlexFlow { @@ -68,6 +69,13 @@ void adl_serializer::to_json( } } // namespace nlohmann +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::ff_dim_t>(), gen::arbitrary()); +} +} // namespace rc + namespace FlexFlow { std::string format_as(ConcatAttrs const &x) { std::ostringstream oss; diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d_attrs.dtg.cc index e0d1e654aa..238b349cbe 100644 --- a/lib/op-attrs/src/op-attrs/ops/conv_2d_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d_attrs.dtg.cc @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml /* proj-data { - "generated_from": "85f65c1b0e0340ea8e8622c2bf9ca38d" + "generated_from": "74f98e1aacb57d847bb450e1d28d3e67" } */ @@ -217,6 +217,22 @@ void adl_serializer::to_json( } } // namespace nlohmann +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary>(), + gen::arbitrary()); +} +} // namespace rc + namespace FlexFlow { std::string format_as(Conv2DAttrs const &x) { std::ostringstream oss; diff --git a/lib/op-attrs/src/op-attrs/ops/element_binary_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/element_binary_attrs.dtg.cc index bdaef6511f..a0e555cb12 100644 --- a/lib/op-attrs/src/op-attrs/ops/element_binary_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/element_binary_attrs.dtg.cc @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/element_binary_attrs.struct.toml /* proj-data { - "generated_from": "1aae4139632791a4b7638e59fa6b5dc8" + "generated_from": "2bb947c9cc92e3833ee88c908c539629" } */ @@ -117,6 +117,17 @@ void adl_serializer::to_json( } } // namespace nlohmann +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::OperatorType>(), + gen::arbitrary<::FlexFlow::DataType>(), + gen::arbitrary(), + gen::arbitrary()); +} +} // namespace rc + namespace FlexFlow { std::string format_as(ElementBinaryAttrs const &x) { std::ostringstream oss; diff --git a/lib/op-attrs/src/op-attrs/ops/element_scalar_unary_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/element_scalar_unary_attrs.dtg.cc index 36c26653d4..ee85474caf 100644 --- a/lib/op-attrs/src/op-attrs/ops/element_scalar_unary_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/element_scalar_unary_attrs.dtg.cc @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.struct.toml /* proj-data { - "generated_from": "09554c353caed6075e362da5008c4bd2" + "generated_from": "aa6f98b992d46bdf7ad59158bc143a3f" } */ @@ -75,6 +75,14 @@ void adl_serializer::to_json( } } // namespace nlohmann +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::OperatorType>(), gen::arbitrary()); +} +} // namespace rc + namespace FlexFlow { std::string format_as(ElementScalarUnaryAttrs const &x) { std::ostringstream oss; diff --git a/lib/op-attrs/src/op-attrs/ops/element_unary_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/element_unary_attrs.dtg.cc index b5968ed425..bf90a3db7d 100644 --- a/lib/op-attrs/src/op-attrs/ops/element_unary_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/element_unary_attrs.dtg.cc @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml /* proj-data { - "generated_from": "fdb867c04cdd7de320f573f360bcab90" + "generated_from": "75272cff78d3db866122dbb1001aedbe" } */ @@ -57,6 +57,14 @@ void adl_serializer::to_json( } } // namespace nlohmann +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::OperatorType>()); +} +} // namespace rc + namespace FlexFlow { std::string format_as(ElementUnaryAttrs const &x) { std::ostringstream oss; diff --git a/lib/op-attrs/src/op-attrs/ops/embedding_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/embedding_attrs.dtg.cc index a9110f16bc..b5a4028e13 100644 --- a/lib/op-attrs/src/op-attrs/ops/embedding_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/embedding_attrs.dtg.cc @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml /* proj-data { - "generated_from": "65af6a38dfabebbc05c8ad3f75397b07" + "generated_from": "a0ac41fc0f56bc06bcb1a8d42fc6191c" } */ @@ -111,6 +111,16 @@ void adl_serializer::to_json( } } // namespace nlohmann +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary<::FlexFlow::AggregateOp>(), + gen::arbitrary<::FlexFlow::DataType>()); +} +} // namespace rc + namespace FlexFlow { std::string format_as(EmbeddingAttrs const &x) { std::ostringstream oss; diff --git a/lib/op-attrs/src/op-attrs/ops/gather_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/gather_attrs.dtg.cc index 886794a9b1..713c0f391e 100644 --- a/lib/op-attrs/src/op-attrs/ops/gather_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/gather_attrs.dtg.cc @@ -3,13 +3,14 @@ // lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml /* proj-data { - "generated_from": "ee735644d3c5f53f790e0a1fa8b8beaf" + "generated_from": "4ba46b6b494a7a52edda437d2a05fcf1" } */ #include "op-attrs/ops/gather_attrs.dtg.h" #include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" #include namespace FlexFlow { @@ -56,6 +57,13 @@ void adl_serializer::to_json( } } // namespace nlohmann +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::ff_dim_t>()); +} +} // namespace rc + namespace FlexFlow { std::string format_as(GatherAttrs const &x) { std::ostringstream oss; diff --git a/lib/op-attrs/src/op-attrs/ops/layer_norm_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/layer_norm_attrs.dtg.cc index d3c4e0c57e..163f2e2f91 100644 --- a/lib/op-attrs/src/op-attrs/ops/layer_norm_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/layer_norm_attrs.dtg.cc @@ -3,13 +3,14 @@ // lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml /* proj-data { - "generated_from": "c03d823a6e889e1254b73a0730a71046" + "generated_from": "349deae8d9356d3eeacd7e7d069c3155" } */ #include "op-attrs/ops/layer_norm_attrs.dtg.h" #include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" #include "utils/stack_vector.h" #include @@ -81,6 +82,16 @@ void adl_serializer::to_json( } } // namespace nlohmann +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary< + ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>>(), + gen::arbitrary(), + gen::arbitrary()); +} +} // namespace rc + namespace FlexFlow { std::string format_as(LayerNormAttrs const &x) { std::ostringstream oss; diff --git a/lib/op-attrs/src/op-attrs/ops/linear_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/linear_attrs.dtg.cc index 961222843d..a465751991 100644 --- a/lib/op-attrs/src/op-attrs/ops/linear_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/linear_attrs.dtg.cc @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml /* proj-data { - "generated_from": "dae07c937f6c52d4dc89ec322520e29f" + "generated_from": "1369f126a4a6d6eee91642043ab481f6" } */ @@ -132,6 +132,17 @@ void adl_serializer::to_json( } } // namespace nlohmann +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary<::FlexFlow::DataType>(), + gen::arbitrary<::FlexFlow::Activation>(), + gen::arbitrary>()); +} +} // namespace rc + namespace FlexFlow { std::string format_as(LinearAttrs const &x) { std::ostringstream oss; diff --git a/lib/op-attrs/src/op-attrs/ops/parallel_attention_inputs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/parallel_attention_inputs.dtg.cc index 06c891d4de..ac8da6d2d7 100644 --- a/lib/op-attrs/src/op-attrs/ops/parallel_attention_inputs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/parallel_attention_inputs.dtg.cc @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.struct.toml /* proj-data { - "generated_from": "8d1e2a2d3852bfb59d8668d14d52c958" + "generated_from": "b76a39763275090d8376e1c27668d2cb" } */ @@ -61,6 +61,16 @@ void adl_serializer::to_json( } } // namespace nlohmann +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::ParallelTensorShape>(), + gen::arbitrary<::FlexFlow::ParallelTensorShape>(), + gen::arbitrary<::FlexFlow::ParallelTensorShape>()); +} +} // namespace rc + namespace FlexFlow { std::string format_as(ParallelMultiHeadAttentionInputs const &x) { std::ostringstream oss; diff --git a/lib/op-attrs/src/op-attrs/ops/pool_2d_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/pool_2d_attrs.dtg.cc index 3316e4c136..8c445d8b84 100644 --- a/lib/op-attrs/src/op-attrs/ops/pool_2d_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/pool_2d_attrs.dtg.cc @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml /* proj-data { - "generated_from": "607be08f56d910bfa340fb180646c126" + "generated_from": "03aeafe335f68ff831e3e73a77f45caf" } */ @@ -179,6 +179,20 @@ void adl_serializer::to_json( } } // namespace nlohmann +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary<::FlexFlow::PoolOp>(), + gen::arbitrary<::FlexFlow::Activation>()); +} +} // namespace rc + namespace FlexFlow { std::string format_as(Pool2DAttrs const &x) { std::ostringstream oss; diff --git a/lib/op-attrs/src/op-attrs/ops/reduce_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/reduce_attrs.dtg.cc index 004beb7c64..2aa9546956 100644 --- a/lib/op-attrs/src/op-attrs/ops/reduce_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/reduce_attrs.dtg.cc @@ -3,13 +3,14 @@ // lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml /* proj-data { - "generated_from": "bc6279031650335f4a0b7b6cfe116c85" + "generated_from": "097463446e254f662c7bdf5df4e12d17" } */ #include "op-attrs/ops/reduce_attrs.dtg.h" #include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" #include "op-attrs/operator_type.dtg.h" #include "utils/stack_vector.h" #include @@ -82,6 +83,16 @@ void adl_serializer::to_json( } } // namespace nlohmann +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary< + ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>>(), + gen::arbitrary<::FlexFlow::OperatorType>(), + gen::arbitrary()); +} +} // namespace rc + namespace FlexFlow { std::string format_as(ReduceAttrs const &x) { std::ostringstream oss; diff --git a/lib/op-attrs/src/op-attrs/ops/reduction_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/reduction_attrs.dtg.cc index a7cc019111..d4566614df 100644 --- a/lib/op-attrs/src/op-attrs/ops/reduction_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/reduction_attrs.dtg.cc @@ -3,13 +3,14 @@ // lib/op-attrs/include/op-attrs/ops/reduction_attrs.struct.toml /* proj-data { - "generated_from": "57b8ccb5bc2e1a1a3bcf1bce2d8cad9e" + "generated_from": "28492e45a5c4f44987e17fe9ea876e11" } */ #include "op-attrs/ops/reduction_attrs.dtg.h" #include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" #include namespace FlexFlow { @@ -68,6 +69,13 @@ void adl_serializer::to_json( } } // namespace nlohmann +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::ff_dim_t>(), gen::arbitrary()); +} +} // namespace rc + namespace FlexFlow { std::string format_as(ReductionAttrs const &x) { std::ostringstream oss; diff --git a/lib/op-attrs/src/op-attrs/ops/repartition_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/repartition_attrs.dtg.cc index 5ff0f44f44..6270298c87 100644 --- a/lib/op-attrs/src/op-attrs/ops/repartition_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/repartition_attrs.dtg.cc @@ -3,13 +3,14 @@ // lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml /* proj-data { - "generated_from": "366cb1a14093762f75508260ac6494ca" + "generated_from": "0a4d8b435768ce3ee37013fc550c9ebb" } */ #include "op-attrs/ops/repartition_attrs.dtg.h" #include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" #include namespace FlexFlow { @@ -69,6 +70,14 @@ void adl_serializer::to_json( } } // namespace nlohmann +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::ff_dim_t>(), gen::arbitrary()); +} +} // namespace rc + namespace FlexFlow { std::string format_as(RepartitionAttrs const &x) { std::ostringstream oss; diff --git a/lib/op-attrs/src/op-attrs/ops/replicate_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/replicate_attrs.dtg.cc index bf92a0b656..66fe45a1db 100644 --- a/lib/op-attrs/src/op-attrs/ops/replicate_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/replicate_attrs.dtg.cc @@ -3,13 +3,14 @@ // lib/op-attrs/include/op-attrs/ops/replicate_attrs.struct.toml /* proj-data { - "generated_from": "4224406d468444433d69e4abf61b7cd1" + "generated_from": "68c1bba349a54c0db219a67d4cc502b3" } */ #include "op-attrs/ops/replicate_attrs.dtg.h" #include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" #include namespace FlexFlow { @@ -68,6 +69,13 @@ void adl_serializer::to_json( } } // namespace nlohmann +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::ff_dim_t>(), gen::arbitrary()); +} +} // namespace rc + namespace FlexFlow { std::string format_as(ReplicateAttrs const &x) { std::ostringstream oss; diff --git a/lib/op-attrs/src/op-attrs/ops/reshape_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/reshape_attrs.dtg.cc index 2c5509a655..b1fb350b88 100644 --- a/lib/op-attrs/src/op-attrs/ops/reshape_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/reshape_attrs.dtg.cc @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/reshape_attrs.struct.toml /* proj-data { - "generated_from": "5a6a9e646a457a6cf959c542fb631512" + "generated_from": "015d04de0ccb982e7eaa013a842880ca" } */ @@ -57,6 +57,13 @@ void adl_serializer::to_json( } } // namespace nlohmann +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::TensorShape>()); +} +} // namespace rc + namespace FlexFlow { std::string format_as(ReshapeAttrs const &x) { std::ostringstream oss; diff --git a/lib/op-attrs/src/op-attrs/ops/reverse_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/reverse_attrs.dtg.cc index 61122313b0..9ac9abeb82 100644 --- a/lib/op-attrs/src/op-attrs/ops/reverse_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/reverse_attrs.dtg.cc @@ -3,13 +3,14 @@ // lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml /* proj-data { - "generated_from": "7c21c4192854f5981018abf4fbdd9ead" + "generated_from": "c5a82c8a15ac3ce6f47dc054236ab69b" } */ #include "op-attrs/ops/reverse_attrs.dtg.h" #include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" #include namespace FlexFlow { @@ -56,6 +57,13 @@ void adl_serializer::to_json( } } // namespace nlohmann +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::ff_dim_t>()); +} +} // namespace rc + namespace FlexFlow { std::string format_as(ReverseAttrs const &x) { std::ostringstream oss; diff --git a/lib/op-attrs/src/op-attrs/ops/softmax_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/softmax_attrs.dtg.cc index 6b685b3de2..4941b7438a 100644 --- a/lib/op-attrs/src/op-attrs/ops/softmax_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/softmax_attrs.dtg.cc @@ -3,13 +3,14 @@ // lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml /* proj-data { - "generated_from": "9be043678a4ce7666fc372cded600290" + "generated_from": "2ddf5a8b7daa32a43387f5fd5866bb3b" } */ #include "op-attrs/ops/softmax_attrs.dtg.h" #include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" #include namespace FlexFlow { @@ -56,6 +57,13 @@ void adl_serializer::to_json( } } // namespace nlohmann +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::ff_dim_t>()); +} +} // namespace rc + namespace FlexFlow { std::string format_as(SoftmaxAttrs const &x) { std::ostringstream oss; diff --git a/lib/op-attrs/src/op-attrs/ops/split_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/split_attrs.dtg.cc index 8ca4518d17..c6f7e75dbf 100644 --- a/lib/op-attrs/src/op-attrs/ops/split_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/split_attrs.dtg.cc @@ -3,13 +3,14 @@ // lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml /* proj-data { - "generated_from": "4112baa96de544b865618e0a999e0807" + "generated_from": "cde6b5caf6739d3b02fe8fce0d8ae8c5" } */ #include "op-attrs/ops/split_attrs.dtg.h" #include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" #include "utils/stack_vector.h" #include @@ -72,6 +73,14 @@ void adl_serializer::to_json( } } // namespace nlohmann +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::stack_vector>(), + gen::arbitrary<::FlexFlow::ff_dim_t>()); +} +} // namespace rc + namespace FlexFlow { std::string format_as(SplitAttrs const &x) { std::ostringstream oss; diff --git a/lib/op-attrs/src/op-attrs/ops/transpose_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/transpose_attrs.dtg.cc index 7463c6b3de..78f03d0815 100644 --- a/lib/op-attrs/src/op-attrs/ops/transpose_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/transpose_attrs.dtg.cc @@ -3,13 +3,14 @@ // lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml /* proj-data { - "generated_from": "edff0b414040204e895666d81b49db07" + "generated_from": "87f6e4db4b66d564530994773c0ecef4" } */ #include "op-attrs/ops/transpose_attrs.dtg.h" #include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" #include "utils/stack_vector.h" #include @@ -64,6 +65,14 @@ void adl_serializer::to_json( } } // namespace nlohmann +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary< + ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>>()); +} +} // namespace rc + namespace FlexFlow { std::string format_as(TransposeAttrs const &x) { std::ostringstream oss; diff --git a/lib/op-attrs/src/op-attrs/parallel_dim.dtg.cc b/lib/op-attrs/src/op-attrs/parallel_dim.dtg.cc index c2016c9f8f..886893c90a 100644 --- a/lib/op-attrs/src/op-attrs/parallel_dim.dtg.cc +++ b/lib/op-attrs/src/op-attrs/parallel_dim.dtg.cc @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/parallel_dim.variant.toml /* proj-data { - "generated_from": "5550fc7ad51892b3411ef274c76e7d85" + "generated_from": "f382ff547aae62777e5091f00d034d84" } */ @@ -80,6 +80,14 @@ void adl_serializer<::FlexFlow::ParallelDim>::to_json( } } } // namespace nlohmann +namespace rc { +Gen<::FlexFlow::ParallelDim> Arbitrary<::FlexFlow::ParallelDim>::arbitrary() { + return gen::oneOf(gen::construct<::FlexFlow::ParallelDim>( + gen::arbitrary<::FlexFlow::ShardParallelDim>()), + gen::construct<::FlexFlow::ParallelDim>( + gen::arbitrary<::FlexFlow::ReplicaParallelDim>())); +} +} // namespace rc namespace FlexFlow { std::string format_as(::FlexFlow::ParallelDim const &x) { std::ostringstream oss; diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc index 1c3c42173b..16de2347d6 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc @@ -59,4 +59,4 @@ TensorDims get_reduced_dims(ParallelTensorDims const &) { NOT_IMPLEMENTED(); } -} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.dtg.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.dtg.cc index 14372d42cc..40be73cb9f 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.dtg.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.dtg.cc @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml /* proj-data { - "generated_from": "31a9e757f42ec3e468b299cda2cbcd4e" + "generated_from": "aec3b6b66e34be0d5ce3055822479430" } */ @@ -77,6 +77,15 @@ void adl_serializer::to_json( } } // namespace nlohmann +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::FFOrdered<::FlexFlow::ShardParallelDim>>(), + gen::arbitrary<::FlexFlow::ReplicaParallelDimSet>()); +} +} // namespace rc + namespace FlexFlow { std::string format_as(ParallelTensorDims const &x) { std::ostringstream oss; diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc index 2c5e556224..9e0afca357 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc @@ -32,7 +32,7 @@ ShardParallelDim &shard_dim_at_idx(ParallelTensorShape &s, ff_dim_t d) { } ParallelTensorShape lift_to_parallel(TensorShape const &s) { - return { lift_to_parallel(s.dims), s.data_type }; + return {lift_to_parallel(s.dims), s.data_type}; } TensorShape get_tensor_shape_unsafe(ParallelTensorShape const &) { diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.dtg.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.dtg.cc index 037acbf996..1fe82ce108 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.dtg.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.dtg.cc @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml /* proj-data { - "generated_from": "b2d36c9212916e66569af4e958c893f4" + "generated_from": "06d657d1e95f34aebf4b721c768cbee8" } */ @@ -70,6 +70,15 @@ void adl_serializer::to_json( } } // namespace nlohmann +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::ParallelTensorDims>(), + gen::arbitrary<::FlexFlow::DataType>()); +} +} // namespace rc + namespace FlexFlow { std::string format_as(ParallelTensorShape const &x) { std::ostringstream oss; diff --git a/lib/op-attrs/src/op-attrs/pcg_operator_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/pcg_operator_attrs.dtg.cc index 8baa07e537..5334c8a7ab 100644 --- a/lib/op-attrs/src/op-attrs/pcg_operator_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/pcg_operator_attrs.dtg.cc @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml /* proj-data { - "generated_from": "cf0da4385b7554748a06ec25ccf17f2f" + "generated_from": "e1d10b0c7c98524c27886bdae0972321" } */ @@ -376,6 +376,69 @@ void adl_serializer<::FlexFlow::PCGOperatorAttrs>::to_json( } } } // namespace nlohmann +namespace rc { +Gen<::FlexFlow::PCGOperatorAttrs> + Arbitrary<::FlexFlow::PCGOperatorAttrs>::arbitrary() { + return gen::oneOf(gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::BatchMatmulAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::BatchNormAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::CastAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::CombineAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::ConcatAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::Conv2DAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::DropoutAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::ElementBinaryAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::ElementUnaryAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::ElementScalarUnaryAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::EmbeddingAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::FlatAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::GatherAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::InputAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::LayerNormAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::LinearAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::MultiHeadAttentionAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::NoopAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::Pool2DAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::ReduceAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::ReductionAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::RepartitionAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::ReplicateAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::ReverseAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::ReshapeAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::SplitAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::SoftmaxAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::TopKAttrs>()), + gen::construct<::FlexFlow::PCGOperatorAttrs>( + gen::arbitrary<::FlexFlow::TransposeAttrs>())); +} +} // namespace rc namespace FlexFlow { std::string format_as(::FlexFlow::PCGOperatorAttrs const &x) { std::ostringstream oss; diff --git a/lib/op-attrs/src/op-attrs/regularizer_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/regularizer_attrs.dtg.cc index 31a06cb19f..d1f844ab10 100644 --- a/lib/op-attrs/src/op-attrs/regularizer_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/regularizer_attrs.dtg.cc @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/regularizer_attrs.variant.toml /* proj-data { - "generated_from": "b0cb2d264215faf9759925c631f3d55f" + "generated_from": "ea060a8ab344c9772102f084903883ea" } */ @@ -80,6 +80,15 @@ void adl_serializer<::FlexFlow::RegularizerAttrs>::to_json( } } } // namespace nlohmann +namespace rc { +Gen<::FlexFlow::RegularizerAttrs> + Arbitrary<::FlexFlow::RegularizerAttrs>::arbitrary() { + return gen::oneOf(gen::construct<::FlexFlow::RegularizerAttrs>( + gen::arbitrary<::FlexFlow::L1RegularizerAttrs>()), + gen::construct<::FlexFlow::RegularizerAttrs>( + gen::arbitrary<::FlexFlow::L2RegularizerAttrs>())); +} +} // namespace rc namespace FlexFlow { std::string format_as(::FlexFlow::RegularizerAttrs const &x) { std::ostringstream oss; diff --git a/lib/op-attrs/src/op-attrs/tensor_dims.cc b/lib/op-attrs/src/op-attrs/tensor_dims.cc index 9f226f9101..9bad6a3b3d 100644 --- a/lib/op-attrs/src/op-attrs/tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/tensor_dims.cc @@ -29,4 +29,4 @@ ParallelTensorDims lift_to_parallel(TensorDims const &dims) { }; } -} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/tensor_dims.dtg.cc b/lib/op-attrs/src/op-attrs/tensor_dims.dtg.cc index f2a7367f1d..909be323ac 100644 --- a/lib/op-attrs/src/op-attrs/tensor_dims.dtg.cc +++ b/lib/op-attrs/src/op-attrs/tensor_dims.dtg.cc @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/tensor_dims.struct.toml /* proj-data { - "generated_from": "f925a4c2343d2404116dc598c301beaf" + "generated_from": "5beb89eeae9eba303f90e726c794375d" } */ @@ -57,6 +57,13 @@ void adl_serializer::to_json( } } // namespace nlohmann +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::FFOrdered>()); +} +} // namespace rc + namespace FlexFlow { std::string format_as(TensorDims const &x) { std::ostringstream oss; diff --git a/lib/op-attrs/src/op-attrs/tensor_shape.dtg.cc b/lib/op-attrs/src/op-attrs/tensor_shape.dtg.cc index 56856070e9..92b31930fa 100644 --- a/lib/op-attrs/src/op-attrs/tensor_shape.dtg.cc +++ b/lib/op-attrs/src/op-attrs/tensor_shape.dtg.cc @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/tensor_shape.struct.toml /* proj-data { - "generated_from": "52968754cf94f415c366d228c87042db" + "generated_from": "ef6fa5088b89d6da4dc8bddf0a6d3294" } */ @@ -69,6 +69,14 @@ void adl_serializer::to_json( } } // namespace nlohmann +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::TensorDims>(), + gen::arbitrary<::FlexFlow::DataType>()); +} +} // namespace rc + namespace FlexFlow { std::string format_as(TensorShape const &x) { std::ostringstream oss; diff --git a/lib/op-attrs/test/src/test_dim_ordered.cc b/lib/op-attrs/test/src/test_dim_ordered.cc new file mode 100644 index 0000000000..17f4bae05f --- /dev/null +++ b/lib/op-attrs/test/src/test_dim_ordered.cc @@ -0,0 +1,13 @@ +#include "doctest/doctest.h" +#include "op-attrs/dim_ordered.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE_TEMPLATE("RC", T, int, double, char) { + CHECK(rc::check("generate", + [](FFOrdered ff_dim, DimOrdered dim) {})); + } +} diff --git a/lib/op-attrs/test/src/test_regularizer_attrs.cc b/lib/op-attrs/test/src/test_regularizer_attrs.cc new file mode 100644 index 0000000000..7b4139ad53 --- /dev/null +++ b/lib/op-attrs/test/src/test_regularizer_attrs.cc @@ -0,0 +1,14 @@ +#include "doctest/doctest.h" +#include "op-attrs/regularizer_attrs.dtg.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("RC") { + CHECK(rc::check("valid variant", [](RegularizerAttrs reg) { + return reg.has() || reg.has(); + })); + } +} diff --git a/lib/pcg/include/pcg/strided_rectangle.struct.toml b/lib/pcg/include/pcg/strided_rectangle.struct.toml index ec9eca9ffa..3dfd90e296 100644 --- a/lib/pcg/include/pcg/strided_rectangle.struct.toml +++ b/lib/pcg/include/pcg/strided_rectangle.struct.toml @@ -5,7 +5,7 @@ features = [ "ord", "hash", "json", - # "rapidcheck", + "rapidcheck", "fmt", ] diff --git a/lib/utils/include/utils/optional.h b/lib/utils/include/utils/optional.h index 2546c302c0..7abbf8ab17 100644 --- a/lib/utils/include/utils/optional.h +++ b/lib/utils/include/utils/optional.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_UTILS_INCLUDE_UTILS_OPTIONAL_H #include "fmt.h" +#include "rapidcheck.h" #include "utils/exception.h" #include "utils/optional.decl" @@ -59,4 +60,18 @@ struct formatter< } // namespace fmt +namespace rc { + +template +struct Arbitrary> { + static Gen> arbitrary() { + return gen::map( + gen::maybe(std::move(gen::arbitrary())), [](Maybe &&m) { + return m ? std::optional(std::move(*m)) : std::optional(); + }); + } +}; + +} // namespace rc + #endif diff --git a/lib/utils/include/utils/stack_vector.h b/lib/utils/include/utils/stack_vector.h index 08248003f3..d47886b055 100644 --- a/lib/utils/include/utils/stack_vector.h +++ b/lib/utils/include/utils/stack_vector.h @@ -3,6 +3,7 @@ #include "containers.h" #include "hash-utils.h" +#include "rapidcheck.h" #include "utils/fmt.h" #include "utils/json.h" #include "utils/test_types.h" @@ -39,9 +40,10 @@ struct stack_vector { template stack_vector(Iterator start, Iterator end) { - assert(end - start >= 0); - assert(end - start <= MAXSIZE); - for (; start < end; start++) { + size_t elements_added = 0; + for (; start != end; start++) { + elements_added++; + assert(elements_added <= MAXSIZE); this->push_back(static_cast(*start)); } } @@ -346,4 +348,18 @@ struct hash<::FlexFlow::stack_vector> { } // namespace std +namespace rc { + +template +struct Arbitrary<::FlexFlow::stack_vector> { + static Gen<::FlexFlow::stack_vector> arbitrary() { + return gen::mapcat(gen::inRange(0, MAXSIZE), [](size_t size) { + return gen::container<::FlexFlow::stack_vector>( + size, gen::arbitrary()); + }); + } +}; + +} // namespace rc + #endif diff --git a/lib/utils/include/utils/variant.h b/lib/utils/include/utils/variant.h index 272caaffde..bb2286a9cd 100644 --- a/lib/utils/include/utils/variant.h +++ b/lib/utils/include/utils/variant.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_UTILS_VARIANT_H #define _FLEXFLOW_UTILS_VARIANT_H +#include "rapidcheck.h" #include "utils/type_traits.h" #include #include @@ -212,4 +213,15 @@ std::optional cast(VariantIn const &v) { } // namespace FlexFlow +namespace rc { + +template +struct Arbitrary> { + static Gen> arbitrary() { + return gen::oneOf(gen::cast>(gen::arbitrary())...); + } +}; + +} // namespace rc + #endif diff --git a/lib/utils/test/src/test_optional.cc b/lib/utils/test/src/test_optional.cc new file mode 100644 index 0000000000..8ef9e18f18 --- /dev/null +++ b/lib/utils/test/src/test_optional.cc @@ -0,0 +1,10 @@ +#include "test/utils/doctest.h" +#include "utils/optional.h" +#include + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE_TEMPLATE("RC arbitrary", T, int, double, char) { + CHECK(rc::check("generate", [](std::optional o) {})); + } +} diff --git a/lib/utils/test/src/test_stack_vector.cc b/lib/utils/test/src/test_stack_vector.cc index 6c0ecf36f3..141cd30e95 100644 --- a/lib/utils/test/src/test_stack_vector.cc +++ b/lib/utils/test/src/test_stack_vector.cc @@ -1,6 +1,7 @@ #include "test/utils/doctest.h" #include "utils/stack_vector.h" #include +#include using namespace FlexFlow; @@ -76,4 +77,11 @@ TEST_SUITE(FF_TEST_SUITE) { vector.push_back(20); CHECK(vector.back() == 20); } + + TEST_CASE_TEMPLATE("RC arbitrary", T, int, double, char) { + constexpr std::size_t MAXSIZE = 10; + CHECK(rc::check("within bound", [](stack_vector v) { + return v.size() <= MAXSIZE; + })); + } } diff --git a/lib/utils/test/src/test_variant.cc b/lib/utils/test/src/test_variant.cc index 0fef782c0e..7cffe9fbe4 100644 --- a/lib/utils/test/src/test_variant.cc +++ b/lib/utils/test/src/test_variant.cc @@ -1,5 +1,6 @@ #include "test/utils/doctest.h" #include "utils/variant.h" +#include TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("widen and narrow functions") { @@ -69,4 +70,10 @@ TEST_SUITE(FF_TEST_SUITE) { // Check the result CHECK(get(wider_variant) == 42); } + + TEST_CASE("RC arbitrary") { + CHECK(rc::check("valid type", [](std::variant v) { + return std::holds_alternative(v) || std::holds_alternative(v); + })); + } } From 394660e6b616541e6ca0f64f3e12eb3c25264e13 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Sun, 26 May 2024 10:52:12 -0700 Subject: [PATCH 18/43] Attempt to hide dtgen-generated files from github diff --- .gitattributes | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 .gitattributes diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000..28a3d45f12 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +*.dtg.cc linguist-generated=true +*.dtg.hh linguist-generated=true From 5d60878f9bf2ae25102aba84c955ca9e2b6e4d4c Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Sun, 26 May 2024 10:54:54 -0700 Subject: [PATCH 19/43] Fix header file name for dtgen in gitattributes --- .gitattributes | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitattributes b/.gitattributes index 28a3d45f12..efec9cf353 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,2 +1,2 @@ *.dtg.cc linguist-generated=true -*.dtg.hh linguist-generated=true +*.dtg.h linguist-generated=true From 238da2dbef5431c934d3fa3539deb2c250742743 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Sun, 26 May 2024 11:18:31 -0700 Subject: [PATCH 20/43] Update proj and format code --- flake.lock | 6 +- lib/compiler/test/src/test_optimal_cost.cc | 9 +- .../src/cuda/element_binary_kernels.cu | 3 +- lib/op-attrs/include/op-attrs/as_dot.h | 2 +- lib/op-attrs/include/op-attrs/datatype.h | 2 +- lib/op-attrs/include/op-attrs/dim_ordered.h | 3 +- lib/op-attrs/include/op-attrs/ff_dim.h | 7 +- lib/op-attrs/include/op-attrs/get_op_type.h | 17 +- .../include/op-attrs/operator_attrs.h | 10 +- lib/op-attrs/include/op-attrs/ops/attention.h | 2 +- .../include/op-attrs/ops/batch_matmul.h | 4 +- .../include/op-attrs/ops/batch_norm.h | 3 +- lib/op-attrs/include/op-attrs/ops/broadcast.h | 3 +- lib/op-attrs/include/op-attrs/ops/conv_2d.h | 15 +- .../conv_2d/conv_2d_parallel_input_shape.h | 3 +- lib/op-attrs/include/op-attrs/ops/dropout.h | 3 +- .../include/op-attrs/ops/element_binary.h | 8 +- .../include/op-attrs/ops/element_unary.h | 9 +- lib/op-attrs/include/op-attrs/ops/embedding.h | 3 +- .../include/op-attrs/ops/layer_norm.h | 3 +- lib/op-attrs/include/op-attrs/ops/linear.h | 15 +- lib/op-attrs/include/op-attrs/ops/noop.h | 3 +- lib/op-attrs/include/op-attrs/ops/pool_2d.h | 3 +- lib/op-attrs/include/op-attrs/ops/reduce.h | 3 +- lib/op-attrs/include/op-attrs/ops/reduction.h | 3 +- .../include/op-attrs/ops/repartition.h | 3 +- lib/op-attrs/include/op-attrs/ops/replicate.h | 3 +- lib/op-attrs/include/op-attrs/ops/reshape.h | 3 +- lib/op-attrs/include/op-attrs/ops/reverse.h | 3 +- lib/op-attrs/include/op-attrs/ops/softmax.h | 3 +- lib/op-attrs/include/op-attrs/ops/split.h | 4 +- lib/op-attrs/include/op-attrs/ops/topk.h | 3 +- lib/op-attrs/include/op-attrs/ops/transpose.h | 3 +- .../include/op-attrs/parallel_tensor_shape.h | 6 +- lib/op-attrs/include/op-attrs/param_sync.h | 4 +- .../op-attrs/replica_parallel_dim_set.h | 3 +- lib/op-attrs/include/op-attrs/tensor_dims.h | 2 +- .../op-attrs/computation_graph_op_attrs.cc | 3 +- lib/op-attrs/src/op-attrs/get_op_type.cc | 124 +++++-- lib/op-attrs/src/op-attrs/ops/attention.cc | 3 +- lib/op-attrs/src/op-attrs/ops/batch_matmul.cc | 5 +- lib/op-attrs/src/op-attrs/ops/batch_norm.cc | 3 +- lib/op-attrs/src/op-attrs/ops/conv_2d.cc | 187 +++++----- .../ops/conv_2d/conv_2d_input_shape.cc | 11 +- .../conv_2d/conv_2d_parallel_input_shape.cc | 20 +- lib/op-attrs/src/op-attrs/ops/dropout.cc | 3 +- .../src/op-attrs/ops/element_binary.cc | 8 +- .../src/op-attrs/ops/element_unary.cc | 9 +- lib/op-attrs/src/op-attrs/ops/embedding.cc | 3 +- lib/op-attrs/src/op-attrs/ops/layer_norm.cc | 3 +- lib/op-attrs/src/op-attrs/ops/linear.cc | 60 ++-- lib/op-attrs/src/op-attrs/ops/noop.cc | 5 +- lib/op-attrs/src/op-attrs/ops/pool_2d.cc | 7 +- lib/op-attrs/src/op-attrs/ops/reduce.cc | 3 +- lib/op-attrs/src/op-attrs/ops/reduction.cc | 3 +- lib/op-attrs/src/op-attrs/ops/repartition.cc | 3 +- lib/op-attrs/src/op-attrs/ops/replicate.cc | 3 +- lib/op-attrs/src/op-attrs/ops/reshape.cc | 3 +- lib/op-attrs/src/op-attrs/ops/reverse.cc | 3 +- lib/op-attrs/src/op-attrs/ops/softmax.cc | 3 +- lib/op-attrs/src/op-attrs/ops/split.cc | 4 +- lib/op-attrs/src/op-attrs/ops/topk.cc | 3 +- lib/op-attrs/src/op-attrs/ops/transpose.cc | 3 +- .../src/op-attrs/parallel_tensor_dims.cc | 21 +- .../src/op-attrs/parallel_tensor_shape.cc | 9 +- .../src/op-attrs/pcg_operator_attrs.cc | 9 +- .../src/op-attrs/replica_parallel_dim_set.cc | 13 +- lib/op-attrs/src/op-attrs/tensor_dims.cc | 11 +- lib/op-attrs/src/op-attrs/tensor_shape.cc | 4 +- lib/op-attrs/test/src/test_conv_2d.cc | 53 ++- lib/op-attrs/test/src/test_operator_attrs.cc | 13 +- .../test/src/test_regularizer_attrs.cc | 2 +- lib/pcg/include/pcg/computation_graph.h | 11 +- .../include/pcg/computation_graph_builder.h | 267 +++++++------- lib/pcg/include/pcg/create_grad.h | 4 +- lib/pcg/include/pcg/dataflow_graph.h | 27 +- lib/pcg/include/pcg/device_id.h | 4 +- lib/pcg/include/pcg/file_format/v1/graphs.h | 8 +- lib/pcg/include/pcg/machine_specification.h | 4 +- lib/pcg/include/pcg/machine_view.h | 6 +- .../pcg/operator_graph/operator_graph.h | 34 +- lib/pcg/include/pcg/optimizer_attrs.h | 2 +- .../include/pcg/parallel_computation_graph.h | 2 +- lib/pcg/include/pcg/parallel_tensor.h | 4 +- lib/pcg/include/pcg/strided_rectangle.dtg.h | 10 +- lib/pcg/include/pcg/strided_rectangle.h | 3 +- lib/pcg/include/pcg/strided_rectangle_side.h | 2 +- lib/pcg/src/file_format/v1/graphs.cc | 34 +- lib/pcg/src/pcg/computation_graph.cc | 10 +- lib/pcg/src/pcg/computation_graph_builder.cc | 332 ++++++++++-------- lib/pcg/src/pcg/machine_view.cc | 8 +- .../src/pcg/operator_graph/operator_graph.cc | 14 +- lib/pcg/src/pcg/strided_rectangle.dtg.cc | 11 +- lib/pcg/src/pcg/strided_rectangle_side.cc | 5 +- .../src/test_computation_graph_builder.cc | 19 +- .../include/substitution-generator/json.h | 160 ++++----- .../include/substitutions/graph_pattern.h | 10 +- .../operator_pattern/eval_list_access.h | 6 +- .../operator_pattern/eval_list_size.h | 6 +- .../operator_attribute_expr.h | 2 +- .../operator_pattern/satisfies_constraint.h | 6 +- .../operator_pattern/satisfies_pattern.h | 3 +- .../sub_parallel_computation_graph.h | 10 +- .../tensor_pattern/eval_list_access.h | 5 +- .../tensor_pattern/eval_list_size.h | 5 +- .../tensor_pattern/get_attribute.h | 7 +- .../tensor_pattern/satisfies_constraint.h | 6 +- .../tensor_pattern/satisfies_pattern.h | 5 +- .../tensor_pattern/tensor_attribute_expr.h | 9 +- .../substitutions/unlabelled/edge_splits.h | 11 +- .../unlabelled/find_pattern_matches.h | 4 +- .../substitutions/unlabelled/match_split.h | 2 +- .../unlabelled/multidigraph_pattern_match.h | 10 +- .../substitutions/unlabelled/pattern_edge.h | 6 +- .../unlabelled/pattern_matching.h | 15 +- .../substitutions/unlabelled/pattern_split.h | 8 +- .../unlabelled/unlabelled_graph_pattern.h | 18 +- lib/substitutions/src/substitution.cc | 138 +++++--- .../src/substitutions/graph_pattern.cc | 32 +- .../operator_pattern/eval_list_access.cc | 50 +-- .../operator_pattern/eval_list_size.cc | 31 +- .../operator_pattern/get_attribute.cc | 16 +- .../operator_attribute_expr.cc | 16 +- .../operator_pattern/satisfies_constraint.cc | 11 +- .../operator_pattern/satisfies_pattern.cc | 9 +- .../sub_parallel_computation_graph.cc | 11 +- .../src/substitutions/substitution.cc | 11 +- .../tensor_pattern/eval_list_access.cc | 19 +- .../tensor_pattern/eval_list_size.cc | 15 +- .../tensor_pattern/get_attribute.cc | 18 +- .../tensor_pattern/satisfies_constraint.cc | 11 +- .../tensor_pattern/satisfies_pattern.cc | 7 +- .../tensor_pattern/tensor_attribute_expr.cc | 28 +- .../substitutions/unlabelled/edge_splits.cc | 26 +- .../unlabelled/find_pattern_matches.cc | 56 ++- .../substitutions/unlabelled/match_split.cc | 17 +- .../unlabelled/multidigraph_pattern_match.cc | 20 +- .../substitutions/unlabelled/pattern_edge.cc | 18 +- .../unlabelled/pattern_matching.cc | 36 +- .../substitutions/unlabelled/pattern_split.cc | 25 +- .../unlabelled/unlabelled_graph_pattern.cc | 27 +- .../test/src/test_substitution.cc | 18 +- lib/utils/include/utils/bidict.h | 8 +- lib/utils/include/utils/containers.decl.h | 3 +- lib/utils/include/utils/containers.h | 2 +- .../include/utils/containers/concat_vectors.h | 3 +- .../utils/containers/enumerate_vector.h | 2 +- .../include/utils/containers/extend_vector.h | 1 - lib/utils/include/utils/exception.decl.h | 12 +- lib/utils/include/utils/fmt.decl.h | 22 +- lib/utils/include/utils/fmt.h | 38 +- lib/utils/include/utils/fmt/pair.h | 2 +- lib/utils/include/utils/fmt/unordered_map.h | 27 +- lib/utils/include/utils/join_strings.h | 1 - lib/utils/include/utils/json.h | 4 +- lib/utils/include/utils/optional.h | 11 +- lib/utils/include/utils/stack_string.h | 3 +- lib/utils/src/exception.cc | 9 +- lib/utils/src/utils/integer_conversions.cc | 2 +- 159 files changed, 1566 insertions(+), 1187 deletions(-) diff --git a/flake.lock b/flake.lock index e27afe3c2d..afcfa9aa21 100644 --- a/flake.lock +++ b/flake.lock @@ -43,11 +43,11 @@ ] }, "locked": { - "lastModified": 1716611864, - "narHash": "sha256-Nd1Hv4j5Wy70KzuAji3/50Fcr7axCw+o8TBJsh2f5UY=", + "lastModified": 1716747446, + "narHash": "sha256-mn3br/KFBtv4c4ZLHR1ZIqFeM1p93rcfHivselz+Nr4=", "owner": "lockshaw", "repo": "proj", - "rev": "b83e54a47c755b241ec5fa2a79aa455cba7dfc18", + "rev": "62839e7ac51dc16ddd05a5e174e0590ea85afc65", "type": "github" }, "original": { diff --git a/lib/compiler/test/src/test_optimal_cost.cc b/lib/compiler/test/src/test_optimal_cost.cc index 959fa07f25..8c176eb4d2 100644 --- a/lib/compiler/test/src/test_optimal_cost.cc +++ b/lib/compiler/test/src/test_optimal_cost.cc @@ -42,13 +42,8 @@ TEST_SUITE(FF_TEST_SUITE) { MultiDiEdge e{n1, pcg.add_node_port(), n0, pcg.add_node_port()}; pcg.add_edge(e); ParallelDim dim = {2, 1, false}; - ParallelTensorDims dims = { - FFOrdered{dim} - }; - pcg.add_output(e, - ParallelTensor(dims, - DataType::FLOAT, - CreateGrad::YES)); + ParallelTensorDims dims = {FFOrdered{dim}}; + pcg.add_output(e, ParallelTensor(dims, DataType::FLOAT, CreateGrad::YES)); auto test_allowed_machine_views = [](Operator const &, MachineSpecification const &) { diff --git a/lib/kernels/src/cuda/element_binary_kernels.cu b/lib/kernels/src/cuda/element_binary_kernels.cu index be06504197..39a67ee06e 100644 --- a/lib/kernels/src/cuda/element_binary_kernels.cu +++ b/lib/kernels/src/cuda/element_binary_kernels.cu @@ -394,7 +394,8 @@ void backward_kernel(cudaStream_t stream, rhs_grad_ptr)); } } - } else if (op_type == OperatorType::EW_MIN || op_type == OperatorType::EW_MAX) { + } else if (op_type == OperatorType::EW_MIN || + op_type == OperatorType::EW_MAX) { float alpha = 1.0f, beta = 1.0f; cudnnDataType_t dataType; int n; diff --git a/lib/op-attrs/include/op-attrs/as_dot.h b/lib/op-attrs/include/op-attrs/as_dot.h index 4a5d9eaeb1..d92557c2f4 100644 --- a/lib/op-attrs/include/op-attrs/as_dot.h +++ b/lib/op-attrs/include/op-attrs/as_dot.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_AS_DOT_H #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_AS_DOT_H -#include "op-attrs/pcg_operator_attrs.dtg.h" #include "op-attrs/computation_graph_op_attrs.dtg.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" #include "utils/record_formatter.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/datatype.h b/lib/op-attrs/include/op-attrs/datatype.h index f3f3c4a08e..01360fdc95 100644 --- a/lib/op-attrs/include/op-attrs/datatype.h +++ b/lib/op-attrs/include/op-attrs/datatype.h @@ -1,10 +1,10 @@ #ifndef _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_DATATYPE_H #define _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_DATATYPE_H +#include "op-attrs/datatype.dtg.h" #include "utils/fmt.h" #include "utils/fp16.h" #include -#include "op-attrs/datatype.dtg.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/dim_ordered.h b/lib/op-attrs/include/op-attrs/dim_ordered.h index ab62c1d30c..e7c1891a4b 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_FF_STACK_VECTOR_H #include "op-attrs/ff_dim.dtg.h" -#include "utils/stack_vector.h" #include "utils/json.h" +#include "utils/stack_vector.h" namespace FlexFlow { @@ -126,6 +126,7 @@ struct DimOrdered { } friend struct ::std::hash; + private: stack_vector contents; }; diff --git a/lib/op-attrs/include/op-attrs/ff_dim.h b/lib/op-attrs/include/op-attrs/ff_dim.h index b0559c2f1e..e78ce4b51e 100644 --- a/lib/op-attrs/include/op-attrs/ff_dim.h +++ b/lib/op-attrs/include/op-attrs/ff_dim.h @@ -2,14 +2,15 @@ #ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_H #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_H -#include "rapidcheck.h" #include "op-attrs/ff_dim.dtg.h" +#include "rapidcheck.h" namespace rc { template <> struct Arbitrary { - static Gen arbitrary(){ - return gen::construct(gen::inRange(0, MAX_TENSOR_DIM)); + static Gen arbitrary() { + return gen::construct( + gen::inRange(0, MAX_TENSOR_DIM)); } }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/get_op_type.h b/lib/op-attrs/include/op-attrs/get_op_type.h index 03b4f92259..39541aa8d6 100644 --- a/lib/op-attrs/include/op-attrs/get_op_type.h +++ b/lib/op-attrs/include/op-attrs/get_op_type.h @@ -1,39 +1,38 @@ #ifndef _FLEXFLOW_OP_ATTRS_GET_OP_TYPE_H #define _FLEXFLOW_OP_ATTRS_GET_OP_TYPE_H +#include "op-attrs/ops/attention_attrs.dtg.h" #include "op-attrs/ops/batch_matmul.dtg.h" #include "op-attrs/ops/batch_norm_attrs.dtg.h" #include "op-attrs/ops/broadcast.dtg.h" #include "op-attrs/ops/cast_attrs.dtg.h" +#include "op-attrs/ops/combine_attrs.dtg.h" #include "op-attrs/ops/concat_attrs.dtg.h" #include "op-attrs/ops/conv_2d_attrs.dtg.h" #include "op-attrs/ops/dropout_attrs.dtg.h" #include "op-attrs/ops/element_binary_attrs.dtg.h" -#include "op-attrs/ops/element_unary_attrs.dtg.h" -#include "op-attrs/ops/element_scalar_unary_attrs.dtg.h" +#include "op-attrs/ops/element_scalar_unary_attrs.dtg.h" +#include "op-attrs/ops/element_unary_attrs.dtg.h" #include "op-attrs/ops/embedding_attrs.dtg.h" #include "op-attrs/ops/flat_attrs.dtg.h" #include "op-attrs/ops/gather_attrs.dtg.h" #include "op-attrs/ops/input_attrs.dtg.h" #include "op-attrs/ops/layer_norm_attrs.dtg.h" #include "op-attrs/ops/linear_attrs.dtg.h" -#include "op-attrs/ops/attention_attrs.dtg.h" #include "op-attrs/ops/noop_attrs.dtg.h" #include "op-attrs/ops/pool_2d_attrs.dtg.h" #include "op-attrs/ops/reduce_attrs.dtg.h" +#include "op-attrs/ops/reduction_attrs.dtg.h" +#include "op-attrs/ops/repartition_attrs.dtg.h" +#include "op-attrs/ops/replicate_attrs.dtg.h" #include "op-attrs/ops/reshape_attrs.dtg.h" #include "op-attrs/ops/reverse_attrs.dtg.h" -#include "op-attrs/ops/split_attrs.dtg.h" #include "op-attrs/ops/softmax_attrs.dtg.h" +#include "op-attrs/ops/split_attrs.dtg.h" #include "op-attrs/ops/topk_attrs.dtg.h" #include "op-attrs/ops/transpose_attrs.dtg.h" -#include "op-attrs/ops/combine_attrs.dtg.h" -#include "op-attrs/ops/reduction_attrs.dtg.h" -#include "op-attrs/ops/repartition_attrs.dtg.h" -#include "op-attrs/ops/replicate_attrs.dtg.h" #include "op-attrs/ops/weight_attrs.dtg.h" - namespace FlexFlow { OperatorType get_op_type(BatchMatmulAttrs const &); diff --git a/lib/op-attrs/include/op-attrs/operator_attrs.h b/lib/op-attrs/include/op-attrs/operator_attrs.h index b96541d34f..7acd322928 100644 --- a/lib/op-attrs/include/op-attrs/operator_attrs.h +++ b/lib/op-attrs/include/op-attrs/operator_attrs.h @@ -2,6 +2,7 @@ #define _OPERATOR_PARAMS_H #include "op-attrs/ops/core.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" #include "ops/attention.h" #include "ops/batch_matmul.h" #include "ops/batch_norm.h" @@ -31,10 +32,9 @@ #include "ops/split.h" #include "ops/topk.h" #include "ops/transpose.h" +#include "utils/record_formatter.h" #include "utils/variant.h" #include -#include "utils/record_formatter.h" -#include "op-attrs/pcg_operator_attrs.dtg.h" namespace FlexFlow { @@ -89,7 +89,8 @@ namespace FlexFlow { /* static_assert(is_valid_opattr::value, ""); */ /* using ParallelOperatorAttrs = std:: */ -/* variant; */ +/* variant; + */ /* using ComputationGraphAttrs = */ /* variant_join>; */ @@ -100,7 +101,8 @@ namespace FlexFlow { /* static_assert(is_equal_comparable::value, */ /* "ComputationGraphAttrs must support =="); */ -/* static_assert(elements_satisfy::value, */ +/* static_assert(elements_satisfy::value, */ /* ""); */ /* static_assert(is_neq_comparable::value, */ /* "ComputationGraphAttrs must support !="); */ diff --git a/lib/op-attrs/include/op-attrs/ops/attention.h b/lib/op-attrs/include/op-attrs/ops/attention.h index de7246dcef..2ec6585bbe 100644 --- a/lib/op-attrs/include/op-attrs/ops/attention.h +++ b/lib/op-attrs/include/op-attrs/ops/attention.h @@ -3,9 +3,9 @@ #include "core.h" #include "op-attrs/ops/attention_attrs.dtg.h" -#include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/ops/attention_inputs.dtg.h" #include "op-attrs/ops/parallel_attention_inputs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h index 7860f891e3..412d694f69 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h @@ -6,7 +6,9 @@ namespace FlexFlow { -bool is_valid(BatchMatmulAttrs const &, ParallelTensorShape const &, ParallelTensorShape const &); +bool is_valid(BatchMatmulAttrs const &, + ParallelTensorShape const &, + ParallelTensorShape const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/batch_norm.h b/lib/op-attrs/include/op-attrs/ops/batch_norm.h index 3230ab4239..b9a1d87a75 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_norm.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_norm.h @@ -7,7 +7,8 @@ namespace FlexFlow { -ParallelTensorShape get_output_shape(BatchNormAttrs const &, ParallelTensorShape const &); +ParallelTensorShape get_output_shape(BatchNormAttrs const &, + ParallelTensorShape const &); CHECK_VALID_OP_ATTR(BatchNormAttrs); diff --git a/lib/op-attrs/include/op-attrs/ops/broadcast.h b/lib/op-attrs/include/op-attrs/ops/broadcast.h index 9ee96458b9..ad44060400 100644 --- a/lib/op-attrs/include/op-attrs/ops/broadcast.h +++ b/lib/op-attrs/include/op-attrs/ops/broadcast.h @@ -6,7 +6,8 @@ namespace FlexFlow { -ParallelTensorShape get_output_shape(BroadcastAttrs const &, ParallelTensorShape const &); +ParallelTensorShape get_output_shape(BroadcastAttrs const &, + ParallelTensorShape const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d.h b/lib/op-attrs/include/op-attrs/ops/conv_2d.h index 5edb680cc8..7759380088 100644 --- a/lib/op-attrs/include/op-attrs/ops/conv_2d.h +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d.h @@ -10,13 +10,18 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(Conv2DAttrs); -TensorShape get_kernel_shape(Conv2DAttrs const &attrs, TensorShape const &input); +TensorShape get_kernel_shape(Conv2DAttrs const &attrs, + TensorShape const &input); TensorShape get_bias_shape(Conv2DAttrs const &attrs, TensorShape const &input); -TensorShape get_output_shape(Conv2DAttrs const &attrs, TensorShape const &input); +TensorShape get_output_shape(Conv2DAttrs const &attrs, + TensorShape const &input); -ParallelTensorShape get_kernel_shape(Conv2DAttrs const &attrs, ParallelTensorShape const &input_shape); -ParallelTensorShape get_bias_shape(Conv2DAttrs const &attrs, ParallelTensorShape const &input_shape); -ParallelTensorShape get_output_shape(Conv2DAttrs const &attrs, ParallelTensorShape const &input_shape); +ParallelTensorShape get_kernel_shape(Conv2DAttrs const &attrs, + ParallelTensorShape const &input_shape); +ParallelTensorShape get_bias_shape(Conv2DAttrs const &attrs, + ParallelTensorShape const &input_shape); +ParallelTensorShape get_output_shape(Conv2DAttrs const &attrs, + ParallelTensorShape const &input_shape); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.h b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.h index 9edff21db8..accc64e751 100644 --- a/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.h +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.h @@ -6,7 +6,8 @@ namespace FlexFlow { -Conv2DParallelInputShape parse_parallel_input_shape(ParallelTensorShape const &input); +Conv2DParallelInputShape + parse_parallel_input_shape(ParallelTensorShape const &input); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/dropout.h b/lib/op-attrs/include/op-attrs/ops/dropout.h index 54e6fbf279..a0493301c4 100644 --- a/lib/op-attrs/include/op-attrs/ops/dropout.h +++ b/lib/op-attrs/include/op-attrs/ops/dropout.h @@ -7,7 +7,8 @@ namespace FlexFlow { -ParallelTensorShape get_output_shape(DropoutAttrs const &, ParallelTensorShape const &); +ParallelTensorShape get_output_shape(DropoutAttrs const &, + ParallelTensorShape const &); CHECK_VALID_OP_ATTR(DropoutAttrs); diff --git a/lib/op-attrs/include/op-attrs/ops/element_binary.h b/lib/op-attrs/include/op-attrs/ops/element_binary.h index 18c4a1eea5..39ae70ecfe 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_binary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_binary.h @@ -7,8 +7,12 @@ namespace FlexFlow { -ParallelTensorShape get_output_shape(ElementBinaryAttrs const &, ParallelTensorShape const &, ParallelTensorShape const &); -TensorShape get_output_shape(ElementBinaryAttrs const &, TensorShape const &, TensorShape const &); +ParallelTensorShape get_output_shape(ElementBinaryAttrs const &, + ParallelTensorShape const &, + ParallelTensorShape const &); +TensorShape get_output_shape(ElementBinaryAttrs const &, + TensorShape const &, + TensorShape const &); CHECK_VALID_OP_ATTR(ElementBinaryAttrs); diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary.h b/lib/op-attrs/include/op-attrs/ops/element_unary.h index 808c453d2c..cfec033a16 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_unary.h @@ -9,11 +9,14 @@ namespace FlexFlow { -ParallelTensorShape get_output_shape(ElementUnaryAttrs const &, ParallelTensorShape const &); +ParallelTensorShape get_output_shape(ElementUnaryAttrs const &, + ParallelTensorShape const &); TensorShape get_output_shape(ElementUnaryAttrs const &, TensorShape const &); -ParallelTensorShape get_output_shape(ElementScalarUnaryAttrs const &, ParallelTensorShape const &); -TensorShape get_output_shape(ElementScalarUnaryAttrs const &, TensorShape const &); +ParallelTensorShape get_output_shape(ElementScalarUnaryAttrs const &, + ParallelTensorShape const &); +TensorShape get_output_shape(ElementScalarUnaryAttrs const &, + TensorShape const &); CHECK_VALID_OP_ATTR(ElementUnaryAttrs); CHECK_VALID_OP_ATTR(ElementScalarUnaryAttrs); diff --git a/lib/op-attrs/include/op-attrs/ops/embedding.h b/lib/op-attrs/include/op-attrs/ops/embedding.h index 2a7d8cd7bf..f7a2226643 100644 --- a/lib/op-attrs/include/op-attrs/ops/embedding.h +++ b/lib/op-attrs/include/op-attrs/ops/embedding.h @@ -12,7 +12,8 @@ CHECK_VALID_OP_ATTR(EmbeddingAttrs); TensorShape get_weights_shape(EmbeddingAttrs const &, TensorShape const &); -ParallelTensorShape get_output_shape(EmbeddingAttrs const &, ParallelTensorShape const &); +ParallelTensorShape get_output_shape(EmbeddingAttrs const &, + ParallelTensorShape const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/layer_norm.h b/lib/op-attrs/include/op-attrs/ops/layer_norm.h index 3186bbba11..01130139f1 100644 --- a/lib/op-attrs/include/op-attrs/ops/layer_norm.h +++ b/lib/op-attrs/include/op-attrs/ops/layer_norm.h @@ -7,7 +7,8 @@ namespace FlexFlow { -ParallelTensorShape get_output_shape(LayerNormAttrs const &, ParallelTensorShape const &); +ParallelTensorShape get_output_shape(LayerNormAttrs const &, + ParallelTensorShape const &); CHECK_VALID_OP_ATTR(LayerNormAttrs); diff --git a/lib/op-attrs/include/op-attrs/ops/linear.h b/lib/op-attrs/include/op-attrs/ops/linear.h index d90d0712db..566fb3dcf1 100644 --- a/lib/op-attrs/include/op-attrs/ops/linear.h +++ b/lib/op-attrs/include/op-attrs/ops/linear.h @@ -10,13 +10,18 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(LinearAttrs); -TensorShape get_kernel_shape(LinearAttrs const &attrs, TensorShape const &input); +TensorShape get_kernel_shape(LinearAttrs const &attrs, + TensorShape const &input); TensorShape get_bias_shape(LinearAttrs const &attrs, TensorShape const &input); -TensorShape get_output_shape(LinearAttrs const &attrs, TensorShape const &input); +TensorShape get_output_shape(LinearAttrs const &attrs, + TensorShape const &input); -ParallelTensorShape get_kernel_shape(LinearAttrs const &attrs, ParallelTensorShape const &input); -ParallelTensorShape get_bias_shape(LinearAttrs const &attrs, ParallelTensorShape const &input); -ParallelTensorShape get_output_shape(LinearAttrs const &attrs, ParallelTensorShape const &input); +ParallelTensorShape get_kernel_shape(LinearAttrs const &attrs, + ParallelTensorShape const &input); +ParallelTensorShape get_bias_shape(LinearAttrs const &attrs, + ParallelTensorShape const &input); +ParallelTensorShape get_output_shape(LinearAttrs const &attrs, + ParallelTensorShape const &input); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/noop.h b/lib/op-attrs/include/op-attrs/ops/noop.h index 635fa3d490..eb01009259 100644 --- a/lib/op-attrs/include/op-attrs/ops/noop.h +++ b/lib/op-attrs/include/op-attrs/ops/noop.h @@ -9,7 +9,8 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(NoopAttrs); -ParallelTensorShape get_output_shape(NoopAttrs const &, ParallelTensorShape const &); +ParallelTensorShape get_output_shape(NoopAttrs const &, + ParallelTensorShape const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d.h b/lib/op-attrs/include/op-attrs/ops/pool_2d.h index 9a9193fd63..162f9aef05 100644 --- a/lib/op-attrs/include/op-attrs/ops/pool_2d.h +++ b/lib/op-attrs/include/op-attrs/ops/pool_2d.h @@ -9,7 +9,8 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(Pool2DAttrs); -ParallelTensorShape get_output_shape(Pool2DAttrs const &, ParallelTensorShape const &); +ParallelTensorShape get_output_shape(Pool2DAttrs const &, + ParallelTensorShape const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/reduce.h b/lib/op-attrs/include/op-attrs/ops/reduce.h index ce5ae7d3fd..800610fb2b 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduce.h +++ b/lib/op-attrs/include/op-attrs/ops/reduce.h @@ -9,7 +9,8 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(ReduceAttrs); -ParallelTensorShape get_output_shape(ReduceAttrs const &, ParallelTensorShape const &input_shape); +ParallelTensorShape get_output_shape(ReduceAttrs const &, + ParallelTensorShape const &input_shape); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/reduction.h b/lib/op-attrs/include/op-attrs/ops/reduction.h index a4ce679330..0ab9861b67 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduction.h +++ b/lib/op-attrs/include/op-attrs/ops/reduction.h @@ -9,7 +9,8 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(ReductionAttrs); -ParallelTensorShape get_output_shape(ReductionAttrs const &attrs, ParallelTensorShape const &input_shape); +ParallelTensorShape get_output_shape(ReductionAttrs const &attrs, + ParallelTensorShape const &input_shape); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/repartition.h b/lib/op-attrs/include/op-attrs/ops/repartition.h index 5dff92e966..09ab21615a 100644 --- a/lib/op-attrs/include/op-attrs/ops/repartition.h +++ b/lib/op-attrs/include/op-attrs/ops/repartition.h @@ -9,7 +9,8 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(RepartitionAttrs); -ParallelTensorShape get_output_shape(RepartitionAttrs const &, ParallelTensorShape const &input_shape); +ParallelTensorShape get_output_shape(RepartitionAttrs const &, + ParallelTensorShape const &input_shape); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/replicate.h b/lib/op-attrs/include/op-attrs/ops/replicate.h index c6430ddbc5..4c46bf88a9 100644 --- a/lib/op-attrs/include/op-attrs/ops/replicate.h +++ b/lib/op-attrs/include/op-attrs/ops/replicate.h @@ -9,7 +9,8 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(ReplicateAttrs); -ParallelTensorShape get_output_shape(ReplicateAttrs const &attrs, ParallelTensorShape const &input_shape); +ParallelTensorShape get_output_shape(ReplicateAttrs const &attrs, + ParallelTensorShape const &input_shape); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/reshape.h b/lib/op-attrs/include/op-attrs/ops/reshape.h index 2cd0287d45..cd2ca80c3a 100644 --- a/lib/op-attrs/include/op-attrs/ops/reshape.h +++ b/lib/op-attrs/include/op-attrs/ops/reshape.h @@ -9,7 +9,8 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(ReshapeAttrs); -ParallelTensorShape get_output_shape(ReshapeAttrs const &attrs, ParallelTensorShape const &input_shape); +ParallelTensorShape get_output_shape(ReshapeAttrs const &attrs, + ParallelTensorShape const &input_shape); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/reverse.h b/lib/op-attrs/include/op-attrs/ops/reverse.h index 45b05e62ab..adc62dc9ae 100644 --- a/lib/op-attrs/include/op-attrs/ops/reverse.h +++ b/lib/op-attrs/include/op-attrs/ops/reverse.h @@ -9,7 +9,8 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(ReverseAttrs); -ParallelTensorShape get_output_shape(ReverseAttrs const &attrs, ParallelTensorShape const &input_shape); +ParallelTensorShape get_output_shape(ReverseAttrs const &attrs, + ParallelTensorShape const &input_shape); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/softmax.h b/lib/op-attrs/include/op-attrs/ops/softmax.h index 7ae5eb7438..d855716cfb 100644 --- a/lib/op-attrs/include/op-attrs/ops/softmax.h +++ b/lib/op-attrs/include/op-attrs/ops/softmax.h @@ -9,7 +9,8 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(SoftmaxAttrs); -ParallelTensorShape get_output_shape(SoftmaxAttrs const &attrs, ParallelTensorShape const &input_shape); +ParallelTensorShape get_output_shape(SoftmaxAttrs const &attrs, + ParallelTensorShape const &input_shape); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/split.h b/lib/op-attrs/include/op-attrs/ops/split.h index 08ce826945..8fc2257760 100644 --- a/lib/op-attrs/include/op-attrs/ops/split.h +++ b/lib/op-attrs/include/op-attrs/ops/split.h @@ -10,7 +10,9 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(SplitAttrs); -std::vector get_output_shapes(SplitAttrs const &attrs, ParallelTensorShape const &input_shape); +std::vector + get_output_shapes(SplitAttrs const &attrs, + ParallelTensorShape const &input_shape); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/topk.h b/lib/op-attrs/include/op-attrs/ops/topk.h index 4fab6584b4..c6af40dd48 100644 --- a/lib/op-attrs/include/op-attrs/ops/topk.h +++ b/lib/op-attrs/include/op-attrs/ops/topk.h @@ -9,7 +9,8 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(TopKAttrs); -ParallelTensorShape get_output_shape(TopKAttrs const &attrs, ParallelTensorShape const &input_shape); +ParallelTensorShape get_output_shape(TopKAttrs const &attrs, + ParallelTensorShape const &input_shape); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/transpose.h b/lib/op-attrs/include/op-attrs/ops/transpose.h index 4156885610..6e23d91d78 100644 --- a/lib/op-attrs/include/op-attrs/ops/transpose.h +++ b/lib/op-attrs/include/op-attrs/ops/transpose.h @@ -9,7 +9,8 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(TransposeAttrs); -ParallelTensorShape get_output_shape(TransposeAttrs const &op_attrs, ParallelTensorShape const &input_shape); +ParallelTensorShape get_output_shape(TransposeAttrs const &op_attrs, + ParallelTensorShape const &input_shape); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h index 6b9bde8283..0482989e0e 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h @@ -1,10 +1,9 @@ #ifndef _OP_META_PARALLEL_TENSOR_SHAPE_H #define _OP_META_PARALLEL_TENSOR_SHAPE_H -#include "op-attrs/tensor_shape.h" -#include #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.h" +#include namespace FlexFlow { @@ -14,7 +13,8 @@ ShardParallelDim &shard_dim_at_idx(ParallelTensorShape &, ff_dim_t); ParallelTensorShape lift_to_parallel(TensorShape const &); -std::unordered_set replica_dims(ParallelTensorShape const &); +std::unordered_set + replica_dims(ParallelTensorShape const &); TensorShape get_piece_shape(ParallelTensorShape const &); int get_num_replica_dims(ParallelTensorShape const &); int get_num_replicas(ParallelTensorShape const &); diff --git a/lib/op-attrs/include/op-attrs/param_sync.h b/lib/op-attrs/include/op-attrs/param_sync.h index 55845a931b..dd7048ff36 100644 --- a/lib/op-attrs/include/op-attrs/param_sync.h +++ b/lib/op-attrs/include/op-attrs/param_sync.h @@ -3,8 +3,6 @@ #include "param_sync_t.h" -namespace FlexFlow { - -} +namespace FlexFlow {} #endif diff --git a/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.h b/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.h index e8b6a92114..74a8df339b 100644 --- a/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.h +++ b/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.h @@ -9,7 +9,8 @@ namespace FlexFlow { ReplicaParallelDimSet empty_replica_parallel_dim_set(); int get_degree_of_replica_type(ReplicaParallelDimSet const &, ReplicaType); -std::unordered_set get_replica_dims(ReplicaParallelDimSet const &); +std::unordered_set + get_replica_dims(ReplicaParallelDimSet const &); bool is_valid(ReplicaParallelDimSet const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/tensor_dims.h b/lib/op-attrs/include/op-attrs/tensor_dims.h index 05a2c6e263..6302ab1418 100644 --- a/lib/op-attrs/include/op-attrs/tensor_dims.h +++ b/lib/op-attrs/include/op-attrs/tensor_dims.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_DIMS_H #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_DIMS_H -#include "op-attrs/tensor_dims.dtg.h" #include "op-attrs/parallel_tensor_dims.dtg.h" +#include "op-attrs/tensor_dims.dtg.h" namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc b/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc index a7145ca1dd..166416cbad 100644 --- a/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc +++ b/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc @@ -4,7 +4,8 @@ namespace FlexFlow { OperatorType get_op_type(ComputationGraphOpAttrs const &attrs) { - return attrs.visit([](auto const &x) { return get_op_type(x); }); + return attrs.visit( + [](auto const &x) { return get_op_type(x); }); } } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/get_op_type.cc b/lib/op-attrs/src/op-attrs/get_op_type.cc index 2c658d9189..aced8d873c 100644 --- a/lib/op-attrs/src/op-attrs/get_op_type.cc +++ b/lib/op-attrs/src/op-attrs/get_op_type.cc @@ -2,36 +2,98 @@ namespace FlexFlow { -OperatorType get_op_type(BatchMatmulAttrs const &) { return OperatorType::BATCHMATMUL; } -OperatorType get_op_type(BatchNormAttrs const &) { return OperatorType::BATCHNORM; } -OperatorType get_op_type(BroadcastAttrs const &) { return OperatorType::BROADCAST; } -OperatorType get_op_type(CastAttrs const &) { return OperatorType::CAST; } -OperatorType get_op_type(ConcatAttrs const &) { return OperatorType::CONCAT; } -OperatorType get_op_type(Conv2DAttrs const &) { return OperatorType::CONV2D; } -OperatorType get_op_type(DropoutAttrs const &) { return OperatorType::DROPOUT; } -OperatorType get_op_type(ElementBinaryAttrs const &attrs) { return attrs.type; } -OperatorType get_op_type(ElementUnaryAttrs const &attrs) { return attrs.op_type; } -OperatorType get_op_type(ElementScalarUnaryAttrs const &attrs) { return attrs.op_type; } -OperatorType get_op_type(EmbeddingAttrs const &) { return OperatorType::EMBEDDING; } -OperatorType get_op_type(FlatAttrs const &) { return OperatorType::FLAT; } -OperatorType get_op_type(GatherAttrs const &) { return OperatorType::GATHER; } -OperatorType get_op_type(InputAttrs const &) { return OperatorType::INPUT; } -OperatorType get_op_type(LayerNormAttrs const &) { return OperatorType::LAYERNORM; } -OperatorType get_op_type(LinearAttrs const &) { return OperatorType::LINEAR; } -OperatorType get_op_type(MultiHeadAttentionAttrs const &) { return OperatorType::MULTIHEAD_ATTENTION; } -OperatorType get_op_type(NoopAttrs const &) { return OperatorType::NOOP; } -OperatorType get_op_type(Pool2DAttrs const &) { return OperatorType::POOL2D; } -OperatorType get_op_type(ReduceAttrs const &attrs) { return attrs.op_type; } -OperatorType get_op_type(ReshapeAttrs const &) { return OperatorType::RESHAPE; } -OperatorType get_op_type(ReverseAttrs const &) { return OperatorType::REVERSE; } -OperatorType get_op_type(SplitAttrs const &) { return OperatorType::SPLIT; } -OperatorType get_op_type(SoftmaxAttrs const &) { return OperatorType::SOFTMAX; } -OperatorType get_op_type(TopKAttrs const &) { return OperatorType::TOPK; } -OperatorType get_op_type(TransposeAttrs const &) { return OperatorType::TRANSPOSE; } -OperatorType get_op_type(CombineAttrs const &) { return OperatorType::COMBINE; } -OperatorType get_op_type(ReductionAttrs const &) { return OperatorType::REDUCTION; } -OperatorType get_op_type(RepartitionAttrs const &) { return OperatorType::REPARTITION; } -OperatorType get_op_type(ReplicateAttrs const &) { return OperatorType::REPLICATE; } -OperatorType get_op_type(WeightAttrs const &) { return OperatorType::WEIGHT; } +OperatorType get_op_type(BatchMatmulAttrs const &) { + return OperatorType::BATCHMATMUL; +} +OperatorType get_op_type(BatchNormAttrs const &) { + return OperatorType::BATCHNORM; +} +OperatorType get_op_type(BroadcastAttrs const &) { + return OperatorType::BROADCAST; +} +OperatorType get_op_type(CastAttrs const &) { + return OperatorType::CAST; +} +OperatorType get_op_type(ConcatAttrs const &) { + return OperatorType::CONCAT; +} +OperatorType get_op_type(Conv2DAttrs const &) { + return OperatorType::CONV2D; +} +OperatorType get_op_type(DropoutAttrs const &) { + return OperatorType::DROPOUT; +} +OperatorType get_op_type(ElementBinaryAttrs const &attrs) { + return attrs.type; +} +OperatorType get_op_type(ElementUnaryAttrs const &attrs) { + return attrs.op_type; +} +OperatorType get_op_type(ElementScalarUnaryAttrs const &attrs) { + return attrs.op_type; +} +OperatorType get_op_type(EmbeddingAttrs const &) { + return OperatorType::EMBEDDING; +} +OperatorType get_op_type(FlatAttrs const &) { + return OperatorType::FLAT; +} +OperatorType get_op_type(GatherAttrs const &) { + return OperatorType::GATHER; +} +OperatorType get_op_type(InputAttrs const &) { + return OperatorType::INPUT; +} +OperatorType get_op_type(LayerNormAttrs const &) { + return OperatorType::LAYERNORM; +} +OperatorType get_op_type(LinearAttrs const &) { + return OperatorType::LINEAR; +} +OperatorType get_op_type(MultiHeadAttentionAttrs const &) { + return OperatorType::MULTIHEAD_ATTENTION; +} +OperatorType get_op_type(NoopAttrs const &) { + return OperatorType::NOOP; +} +OperatorType get_op_type(Pool2DAttrs const &) { + return OperatorType::POOL2D; +} +OperatorType get_op_type(ReduceAttrs const &attrs) { + return attrs.op_type; +} +OperatorType get_op_type(ReshapeAttrs const &) { + return OperatorType::RESHAPE; +} +OperatorType get_op_type(ReverseAttrs const &) { + return OperatorType::REVERSE; +} +OperatorType get_op_type(SplitAttrs const &) { + return OperatorType::SPLIT; +} +OperatorType get_op_type(SoftmaxAttrs const &) { + return OperatorType::SOFTMAX; +} +OperatorType get_op_type(TopKAttrs const &) { + return OperatorType::TOPK; +} +OperatorType get_op_type(TransposeAttrs const &) { + return OperatorType::TRANSPOSE; +} +OperatorType get_op_type(CombineAttrs const &) { + return OperatorType::COMBINE; +} +OperatorType get_op_type(ReductionAttrs const &) { + return OperatorType::REDUCTION; +} +OperatorType get_op_type(RepartitionAttrs const &) { + return OperatorType::REPARTITION; +} +OperatorType get_op_type(ReplicateAttrs const &) { + return OperatorType::REPLICATE; +} +OperatorType get_op_type(WeightAttrs const &) { + return OperatorType::WEIGHT; +} } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/attention.cc b/lib/op-attrs/src/op-attrs/ops/attention.cc index 105fa7250b..c62e1af48b 100644 --- a/lib/op-attrs/src/op-attrs/ops/attention.cc +++ b/lib/op-attrs/src/op-attrs/ops/attention.cc @@ -82,7 +82,8 @@ ParallelTensorShape get_output_shape(MultiHeadAttentionAttrs const &attrs, ParallelTensorShape const &value_shape) { NOT_IMPLEMENTED(); /* ParallelTensorShape output_shape = query_shape; */ - /* dim_at_idx(output_shape, ff_dim_t(num_dims(output_shape) - 1)).size = attrs.embed_dim; */ + /* dim_at_idx(output_shape, ff_dim_t(num_dims(output_shape) - 1)).size = + * attrs.embed_dim; */ /* return output_shape; */ } diff --git a/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc b/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc index 28e1f0af0a..cd3b198955 100644 --- a/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc +++ b/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc @@ -23,8 +23,9 @@ namespace FlexFlow { /* return true; */ /* } */ - -bool is_valid(BatchMatmulAttrs const &, ParallelTensorShape const &, ParallelTensorShape const &) { +bool is_valid(BatchMatmulAttrs const &, + ParallelTensorShape const &, + ParallelTensorShape const &) { NOT_IMPLEMENTED(); } diff --git a/lib/op-attrs/src/op-attrs/ops/batch_norm.cc b/lib/op-attrs/src/op-attrs/ops/batch_norm.cc index 9152a1306c..7be51efa22 100644 --- a/lib/op-attrs/src/op-attrs/ops/batch_norm.cc +++ b/lib/op-attrs/src/op-attrs/ops/batch_norm.cc @@ -2,7 +2,8 @@ namespace FlexFlow { -ParallelTensorShape get_output_shape(BatchNormAttrs const &, ParallelTensorShape const &) { +ParallelTensorShape get_output_shape(BatchNormAttrs const &, + ParallelTensorShape const &) { NOT_IMPLEMENTED(); } diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc index 8cf2afe125..c9ec467af4 100644 --- a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc @@ -5,132 +5,146 @@ namespace FlexFlow { -TensorShape get_kernel_shape(Conv2DAttrs const &attrs, TensorShape const &raw_input_shape) { - assert (attrs.groups == 1); // TODO(@lockshaw): currently not supported +TensorShape get_kernel_shape(Conv2DAttrs const &attrs, + TensorShape const &raw_input_shape) { + assert(attrs.groups == 1); // TODO(@lockshaw): currently not supported Conv2DInputShape input = parse_input_shape(raw_input_shape); return TensorShape{ - TensorDims{ - FFOrdered{ - size_t_from_int(attrs.out_channels), - input.num_channels, - size_t_from_int(attrs.kernel_h), - size_t_from_int(attrs.kernel_w), - } - }, - input.datatype, + TensorDims{FFOrdered{ + size_t_from_int(attrs.out_channels), + input.num_channels, + size_t_from_int(attrs.kernel_h), + size_t_from_int(attrs.kernel_w), + }}, + input.datatype, }; } -TensorShape get_bias_shape(Conv2DAttrs const &attrs, TensorShape const &raw_input_shape) { - assert (attrs.groups == 1); // TODO(@lockshaw): currently not supported +TensorShape get_bias_shape(Conv2DAttrs const &attrs, + TensorShape const &raw_input_shape) { + assert(attrs.groups == 1); // TODO(@lockshaw): currently not supported Conv2DInputShape input = parse_input_shape(raw_input_shape); return TensorShape{ - TensorDims{ - FFOrdered{ - size_t_from_int(attrs.out_channels) + TensorDims{ + FFOrdered{size_t_from_int(attrs.out_channels)}, }, - }, - input.datatype, + input.datatype, }; } -TensorShape get_output_shape(Conv2DAttrs const &attrs, TensorShape const &raw_input_shape) { - assert (attrs.groups == 1); // TODO(@lockshaw): currently not supported +TensorShape get_output_shape(Conv2DAttrs const &attrs, + TensorShape const &raw_input_shape) { + assert(attrs.groups == 1); // TODO(@lockshaw): currently not supported Conv2DInputShape input = parse_input_shape(raw_input_shape); - size_t out_height = (input.height - (2 * attrs.padding_h) - (attrs.kernel_h - 1)) / attrs.stride_h; - size_t out_width = (input.width - (2 * attrs.padding_w) - (attrs.kernel_w - 1)) / attrs.stride_w; - - assert (attrs.out_channels > 0); - - return TensorShape{ - TensorDims{ - FFOrdered{ - input.num_samples, - size_t_from_int(attrs.out_channels), - out_height, - out_width, - } - }, - input.datatype - }; + size_t out_height = + (input.height - (2 * attrs.padding_h) - (attrs.kernel_h - 1)) / + attrs.stride_h; + size_t out_width = + (input.width - (2 * attrs.padding_w) - (attrs.kernel_w - 1)) / + attrs.stride_w; + + assert(attrs.out_channels > 0); + + return TensorShape{TensorDims{FFOrdered{ + input.num_samples, + size_t_from_int(attrs.out_channels), + out_height, + out_width, + }}, + input.datatype}; } -ParallelTensorShape get_kernel_shape(Conv2DAttrs const &attrs, ParallelTensorShape const &raw_input_shape) { - assert (attrs.groups == 1); // TODO(@lockshaw): currently not supported +ParallelTensorShape + get_kernel_shape(Conv2DAttrs const &attrs, + ParallelTensorShape const &raw_input_shape) { + assert(attrs.groups == 1); // TODO(@lockshaw): currently not supported Conv2DParallelInputShape input = parse_parallel_input_shape(raw_input_shape); - ShardParallelDim output_channels_dim = {size_t_from_int(attrs.out_channels), input.discard_copy_reduction_degree}; - ShardParallelDim input_channels_dim = {size_t_from_int(input.channel_dim.size), input.channel_dim.degree}; + ShardParallelDim output_channels_dim = {size_t_from_int(attrs.out_channels), + input.discard_copy_reduction_degree}; + ShardParallelDim input_channels_dim = { + size_t_from_int(input.channel_dim.size), input.channel_dim.degree}; ShardParallelDim kernel_height_dim = {size_t_from_int(attrs.kernel_h), 1}; ShardParallelDim kernel_width_dim = {size_t_from_int(attrs.kernel_w), 1}; int sum_degree = 1; - int discard_copy_degree = input.height_dim.degree * input.width_dim.degree * input.sum_reduction_degree; + int discard_copy_degree = input.height_dim.degree * input.width_dim.degree * + input.sum_reduction_degree; ParallelTensorShape result = ParallelTensorShape{ - ParallelTensorDims{ - FFOrdered{ - output_channels_dim, - input_channels_dim, - kernel_height_dim, - kernel_width_dim, + ParallelTensorDims{ + FFOrdered{ + output_channels_dim, + input_channels_dim, + kernel_height_dim, + kernel_width_dim, + }, + ReplicaParallelDimSet{ + sum_degree, + discard_copy_degree, + }, }, - ReplicaParallelDimSet{ - sum_degree, - discard_copy_degree, - }, - }, - input.datatype, + input.datatype, }; - assert (total_parallel_degree(result.dims) == total_parallel_degree(raw_input_shape.dims)); + assert(total_parallel_degree(result.dims) == + total_parallel_degree(raw_input_shape.dims)); return result; } -ParallelTensorShape get_bias_shape(Conv2DAttrs const &attrs, ParallelTensorShape const &raw_input_shape) { - assert (attrs.groups == 1); // TODO(@lockshaw): currently not supported +ParallelTensorShape get_bias_shape(Conv2DAttrs const &attrs, + ParallelTensorShape const &raw_input_shape) { + assert(attrs.groups == 1); // TODO(@lockshaw): currently not supported Conv2DParallelInputShape input = parse_parallel_input_shape(raw_input_shape); - ShardParallelDim output_channels_dim = {size_t_from_int(attrs.out_channels), input.discard_copy_reduction_degree}; + ShardParallelDim output_channels_dim = {size_t_from_int(attrs.out_channels), + input.discard_copy_reduction_degree}; int sum_degree = 1; - int discard_copy_degree = input.height_dim.degree * input.width_dim.degree * input.sum_reduction_degree * input.channel_dim.degree; + int discard_copy_degree = input.height_dim.degree * input.width_dim.degree * + input.sum_reduction_degree * + input.channel_dim.degree; ParallelTensorShape result = ParallelTensorShape{ - ParallelTensorDims{ - FFOrdered{ - output_channels_dim, - }, - ReplicaParallelDimSet{ - sum_degree, - discard_copy_degree, + ParallelTensorDims{ + FFOrdered{ + output_channels_dim, + }, + ReplicaParallelDimSet{ + sum_degree, + discard_copy_degree, + }, }, - }, - input.datatype, + input.datatype, }; - assert (total_parallel_degree(result.dims) == total_parallel_degree(raw_input_shape.dims)); + assert(total_parallel_degree(result.dims) == + total_parallel_degree(raw_input_shape.dims)); return result; } -ParallelTensorShape get_output_shape(Conv2DAttrs const &attrs, ParallelTensorShape const &raw_input_shape) { - assert (attrs.groups == 1); // TODO(@lockshaw): currently not supported +ParallelTensorShape + get_output_shape(Conv2DAttrs const &attrs, + ParallelTensorShape const &raw_input_shape) { + assert(attrs.groups == 1); // TODO(@lockshaw): currently not supported Conv2DParallelInputShape input = parse_parallel_input_shape(raw_input_shape); - TensorShape unpar_output_shape = get_output_shape(attrs, get_reduced_shape(raw_input_shape)); + TensorShape unpar_output_shape = + get_output_shape(attrs, get_reduced_shape(raw_input_shape)); size_t num_samples = dim_at_idx(unpar_output_shape, ff_dim_t{0}); size_t num_channels = dim_at_idx(unpar_output_shape, ff_dim_t{1}); size_t height = dim_at_idx(unpar_output_shape, ff_dim_t{2}); size_t width = dim_at_idx(unpar_output_shape, ff_dim_t{3}); - ShardParallelDim sample_dim = {num_samples, input.sample_dim.degree}; - ShardParallelDim channel_dim = {num_channels, input.discard_copy_reduction_degree}; + ShardParallelDim sample_dim = {num_samples, input.sample_dim.degree}; + ShardParallelDim channel_dim = {num_channels, + input.discard_copy_reduction_degree}; ShardParallelDim height_dim = {height, input.height_dim.degree}; ShardParallelDim width_dim = {width, input.width_dim.degree}; @@ -138,22 +152,23 @@ ParallelTensorShape get_output_shape(Conv2DAttrs const &attrs, ParallelTensorSha int discard_copy_degree = 1; ParallelTensorShape result = ParallelTensorShape{ - ParallelTensorDims{ - FFOrdered{ - sample_dim, - channel_dim, - height_dim, - width_dim, - }, - ReplicaParallelDimSet{ - sum_degree, - discard_copy_degree, + ParallelTensorDims{ + FFOrdered{ + sample_dim, + channel_dim, + height_dim, + width_dim, + }, + ReplicaParallelDimSet{ + sum_degree, + discard_copy_degree, + }, }, - }, - input.datatype, + input.datatype, }; - assert (total_parallel_degree(result.dims) == total_parallel_degree(raw_input_shape.dims)); + assert(total_parallel_degree(result.dims) == + total_parallel_degree(raw_input_shape.dims)); return result; } diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.cc index ed508de131..a8a3b10bdf 100644 --- a/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.cc +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.cc @@ -12,13 +12,12 @@ Conv2DInputShape parse_input_shape(TensorShape const &input) { size_t in_width = dim_at_idx(input, ff_dim_t{3}); return Conv2DInputShape{ - num_samples, - in_channels, - in_height, - in_width, - input.data_type, + num_samples, + in_channels, + in_height, + in_width, + input.data_type, }; } - } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.cc index 501f42fe0a..8074a03b4d 100644 --- a/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.cc +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.cc @@ -3,24 +3,24 @@ namespace FlexFlow { -Conv2DParallelInputShape parse_parallel_input_shape(ParallelTensorShape const &input) { +Conv2DParallelInputShape + parse_parallel_input_shape(ParallelTensorShape const &input) { assert(num_shard_dims(input) == 4); - + ShardParallelDim sample_dim = shard_dim_at_idx(input, ff_dim_t{0}); ShardParallelDim channel_dim = shard_dim_at_idx(input, ff_dim_t{1}); ShardParallelDim height_dim = shard_dim_at_idx(input, ff_dim_t{2}); ShardParallelDim width_dim = shard_dim_at_idx(input, ff_dim_t{3}); return Conv2DParallelInputShape{ - sample_dim, - channel_dim, - height_dim, - width_dim, - input.dims.replica_dims.sum_degree, - input.dims.replica_dims.discard_copy_degree, - input.data_type, + sample_dim, + channel_dim, + height_dim, + width_dim, + input.dims.replica_dims.sum_degree, + input.dims.replica_dims.discard_copy_degree, + input.data_type, }; } - } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/dropout.cc b/lib/op-attrs/src/op-attrs/ops/dropout.cc index 8dd381f65a..adbd144f38 100644 --- a/lib/op-attrs/src/op-attrs/ops/dropout.cc +++ b/lib/op-attrs/src/op-attrs/ops/dropout.cc @@ -2,7 +2,8 @@ namespace FlexFlow { -ParallelTensorShape get_output_shape(DropoutAttrs const &, ParallelTensorShape const &) { +ParallelTensorShape get_output_shape(DropoutAttrs const &, + ParallelTensorShape const &) { NOT_IMPLEMENTED(); } diff --git a/lib/op-attrs/src/op-attrs/ops/element_binary.cc b/lib/op-attrs/src/op-attrs/ops/element_binary.cc index 0e0fce6f1d..2bbd3a1e2e 100644 --- a/lib/op-attrs/src/op-attrs/ops/element_binary.cc +++ b/lib/op-attrs/src/op-attrs/ops/element_binary.cc @@ -2,11 +2,15 @@ namespace FlexFlow { -ParallelTensorShape get_output_shape(ElementBinaryAttrs const &, ParallelTensorShape const &, ParallelTensorShape const &) { +ParallelTensorShape get_output_shape(ElementBinaryAttrs const &, + ParallelTensorShape const &, + ParallelTensorShape const &) { NOT_IMPLEMENTED(); } -TensorShape get_output_shape(ElementBinaryAttrs const &, TensorShape const &, TensorShape const &) { +TensorShape get_output_shape(ElementBinaryAttrs const &, + TensorShape const &, + TensorShape const &) { NOT_IMPLEMENTED(); } diff --git a/lib/op-attrs/src/op-attrs/ops/element_unary.cc b/lib/op-attrs/src/op-attrs/ops/element_unary.cc index ab3ee1ba58..800199a51a 100644 --- a/lib/op-attrs/src/op-attrs/ops/element_unary.cc +++ b/lib/op-attrs/src/op-attrs/ops/element_unary.cc @@ -2,7 +2,8 @@ namespace FlexFlow { -ParallelTensorShape get_output_shape(ElementUnaryAttrs const &, ParallelTensorShape const &) { +ParallelTensorShape get_output_shape(ElementUnaryAttrs const &, + ParallelTensorShape const &) { NOT_IMPLEMENTED(); } @@ -10,11 +11,13 @@ TensorShape get_output_shape(ElementUnaryAttrs const &, TensorShape const &) { NOT_IMPLEMENTED(); } -ParallelTensorShape get_output_shape(ElementScalarUnaryAttrs const &, ParallelTensorShape const &) { +ParallelTensorShape get_output_shape(ElementScalarUnaryAttrs const &, + ParallelTensorShape const &) { NOT_IMPLEMENTED(); } -TensorShape get_output_shape(ElementScalarUnaryAttrs const &, TensorShape const &) { +TensorShape get_output_shape(ElementScalarUnaryAttrs const &, + TensorShape const &) { NOT_IMPLEMENTED(); } diff --git a/lib/op-attrs/src/op-attrs/ops/embedding.cc b/lib/op-attrs/src/op-attrs/ops/embedding.cc index 2a55266a7f..c8c63c70ea 100644 --- a/lib/op-attrs/src/op-attrs/ops/embedding.cc +++ b/lib/op-attrs/src/op-attrs/ops/embedding.cc @@ -6,7 +6,8 @@ TensorShape get_weights_shape(EmbeddingAttrs const &, TensorShape const &) { NOT_IMPLEMENTED(); } -ParallelTensorShape get_output_shape(EmbeddingAttrs const &, ParallelTensorShape const &) { +ParallelTensorShape get_output_shape(EmbeddingAttrs const &, + ParallelTensorShape const &) { NOT_IMPLEMENTED(); } diff --git a/lib/op-attrs/src/op-attrs/ops/layer_norm.cc b/lib/op-attrs/src/op-attrs/ops/layer_norm.cc index d072fb8b17..437ba3638a 100644 --- a/lib/op-attrs/src/op-attrs/ops/layer_norm.cc +++ b/lib/op-attrs/src/op-attrs/ops/layer_norm.cc @@ -2,7 +2,8 @@ namespace FlexFlow { -ParallelTensorShape get_output_shape(LayerNormAttrs const &, ParallelTensorShape const &) { +ParallelTensorShape get_output_shape(LayerNormAttrs const &, + ParallelTensorShape const &) { NOT_IMPLEMENTED(); } diff --git a/lib/op-attrs/src/op-attrs/ops/linear.cc b/lib/op-attrs/src/op-attrs/ops/linear.cc index a5fae353de..fbe336c090 100644 --- a/lib/op-attrs/src/op-attrs/ops/linear.cc +++ b/lib/op-attrs/src/op-attrs/ops/linear.cc @@ -1,56 +1,64 @@ #include "op-attrs/ops/linear.h" -#include "op-attrs/tensor_shape.h" #include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_shape.h" #include "utils/integer_conversions.h" namespace FlexFlow { -TensorShape get_kernel_shape(LinearAttrs const &attrs, TensorShape const &input_shape) { +TensorShape get_kernel_shape(LinearAttrs const &attrs, + TensorShape const &input_shape) { size_t in_channels = dim_at_idx(input_shape, ff_dim_t{-1}); return TensorShape{ - TensorDims{ - FFOrdered{ - in_channels, - size_t_from_int(attrs.out_channels) + TensorDims{ + FFOrdered{in_channels, size_t_from_int(attrs.out_channels)}, }, - }, - input_shape.data_type, + input_shape.data_type, }; } -TensorShape get_bias_shape(LinearAttrs const &attrs, TensorShape const &input_shape) { +TensorShape get_bias_shape(LinearAttrs const &attrs, + TensorShape const &input_shape) { return TensorShape{ - TensorDims{ - FFOrdered{ - size_t_from_int(attrs.out_channels) + TensorDims{ + FFOrdered{size_t_from_int(attrs.out_channels)}, }, - }, - input_shape.data_type, + input_shape.data_type, }; } -TensorShape get_output_shape(LinearAttrs const &attrs, TensorShape const &input_shape) { +TensorShape get_output_shape(LinearAttrs const &attrs, + TensorShape const &input_shape) { TensorShape output_shape = input_shape; - output_shape.dims.ff_ordered.at(ff_dim_t{-1}) = size_t_from_int(attrs.out_channels); + output_shape.dims.ff_ordered.at(ff_dim_t{-1}) = + size_t_from_int(attrs.out_channels); return output_shape; } -ParallelTensorShape get_kernel_shape(LinearAttrs const &attrs, ParallelTensorShape const &input_shape) { +ParallelTensorShape get_kernel_shape(LinearAttrs const &attrs, + ParallelTensorShape const &input_shape) { NOT_IMPLEMENTED(); - /* ShardParallelDim input_sample_dim = shard_dim_at_idx(input_shape, ff_dim_t{-2}); */ - /* ShardParallelDim in_channels_dim = shard_dim_at_idx(input_shape, ff_dim_t{-1}); */ + /* ShardParallelDim input_sample_dim = shard_dim_at_idx(input_shape, + * ff_dim_t{-2}); */ + /* ShardParallelDim in_channels_dim = shard_dim_at_idx(input_shape, + * ff_dim_t{-1}); */ } -ParallelTensorShape get_output_shape(LinearAttrs const &attrs, ParallelTensorShape const &input_shape) { - ShardParallelDim input_sample_dim = shard_dim_at_idx(input_shape, ff_dim_t{-2}); - ShardParallelDim in_channels_dim = shard_dim_at_idx(input_shape, ff_dim_t{-1}); +ParallelTensorShape get_output_shape(LinearAttrs const &attrs, + ParallelTensorShape const &input_shape) { + ShardParallelDim input_sample_dim = + shard_dim_at_idx(input_shape, ff_dim_t{-2}); + ShardParallelDim in_channels_dim = + shard_dim_at_idx(input_shape, ff_dim_t{-1}); ShardParallelDim output_sample_dim = input_sample_dim; - ShardParallelDim output_channels_dim = { size_t_from_int(attrs.out_channels), input_shape.dims.replica_dims.discard_copy_degree }; + ShardParallelDim output_channels_dim = { + size_t_from_int(attrs.out_channels), + input_shape.dims.replica_dims.discard_copy_degree}; - int output_sum_degree = input_shape.dims.replica_dims.sum_degree * in_channels_dim.degree; + int output_sum_degree = + input_shape.dims.replica_dims.sum_degree * in_channels_dim.degree; int output_discard_copy_degree = 1; ParallelTensorShape result = input_shape; @@ -59,10 +67,10 @@ ParallelTensorShape get_output_shape(LinearAttrs const &attrs, ParallelTensorSha result.dims.replica_dims.sum_degree = output_sum_degree; result.dims.replica_dims.discard_copy_degree = output_discard_copy_degree; - assert (total_parallel_degree(result.dims) == total_parallel_degree(input_shape.dims)); + assert(total_parallel_degree(result.dims) == + total_parallel_degree(input_shape.dims)); return result; } - } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/noop.cc b/lib/op-attrs/src/op-attrs/ops/noop.cc index 1b243a388a..b2b15d820c 100644 --- a/lib/op-attrs/src/op-attrs/ops/noop.cc +++ b/lib/op-attrs/src/op-attrs/ops/noop.cc @@ -2,8 +2,9 @@ namespace FlexFlow { -ParallelTensorShape get_output_shape(NoopAttrs const &, ParallelTensorShape const &input_shape) { - return input_shape; +ParallelTensorShape get_output_shape(NoopAttrs const &, + ParallelTensorShape const &input_shape) { + return input_shape; } } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/pool_2d.cc b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc index 5f54f2f2d6..cf6ed177d3 100644 --- a/lib/op-attrs/src/op-attrs/ops/pool_2d.cc +++ b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc @@ -2,7 +2,8 @@ namespace FlexFlow { -ParallelTensorShape get_output_shape(Pool2DAttrs const &, ParallelTensorShape const &) { +ParallelTensorShape get_output_shape(Pool2DAttrs const &, + ParallelTensorShape const &) { NOT_IMPLEMENTED(); } @@ -49,8 +50,8 @@ static ParallelDimMappingSolution return solve_parallel_dim_mappings(construct_mappings(input), {input}, 0, 1); } -ParallelTensorShape Pool2DAttrs::calculate_output_shape(ParallelTensorShape const &input) const { - return solve_mappings(input).output_shapes.at(0); +ParallelTensorShape Pool2DAttrs::calculate_output_shape(ParallelTensorShape +const &input) const { return solve_mappings(input).output_shapes.at(0); } } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/reduce.cc b/lib/op-attrs/src/op-attrs/ops/reduce.cc index a08fb4128e..2a8bf06ecf 100644 --- a/lib/op-attrs/src/op-attrs/ops/reduce.cc +++ b/lib/op-attrs/src/op-attrs/ops/reduce.cc @@ -2,7 +2,8 @@ namespace FlexFlow { -ParallelTensorShape get_output_shape(ReduceAttrs const &, ParallelTensorShape const &input_shape) { +ParallelTensorShape get_output_shape(ReduceAttrs const &, + ParallelTensorShape const &input_shape) { NOT_IMPLEMENTED(); } diff --git a/lib/op-attrs/src/op-attrs/ops/reduction.cc b/lib/op-attrs/src/op-attrs/ops/reduction.cc index 37b60dbc60..1a61277076 100644 --- a/lib/op-attrs/src/op-attrs/ops/reduction.cc +++ b/lib/op-attrs/src/op-attrs/ops/reduction.cc @@ -2,7 +2,8 @@ namespace FlexFlow { -ParallelTensorShape get_output_shape(ReductionAttrs const &attrs, ParallelTensorShape const &input_shape) { +ParallelTensorShape get_output_shape(ReductionAttrs const &attrs, + ParallelTensorShape const &input_shape) { NOT_IMPLEMENTED(); } diff --git a/lib/op-attrs/src/op-attrs/ops/repartition.cc b/lib/op-attrs/src/op-attrs/ops/repartition.cc index f35cbb35c7..309247c6e1 100644 --- a/lib/op-attrs/src/op-attrs/ops/repartition.cc +++ b/lib/op-attrs/src/op-attrs/ops/repartition.cc @@ -2,7 +2,8 @@ namespace FlexFlow { -ParallelTensorShape get_output_shape(RepartitionAttrs const &, ParallelTensorShape const &input_shape) { +ParallelTensorShape get_output_shape(RepartitionAttrs const &, + ParallelTensorShape const &input_shape) { NOT_IMPLEMENTED(); } diff --git a/lib/op-attrs/src/op-attrs/ops/replicate.cc b/lib/op-attrs/src/op-attrs/ops/replicate.cc index a639a51f15..261a82464f 100644 --- a/lib/op-attrs/src/op-attrs/ops/replicate.cc +++ b/lib/op-attrs/src/op-attrs/ops/replicate.cc @@ -2,7 +2,8 @@ namespace FlexFlow { -ParallelTensorShape get_output_shape(ReplicateAttrs const &attrs, ParallelTensorShape const &input_shape) { +ParallelTensorShape get_output_shape(ReplicateAttrs const &attrs, + ParallelTensorShape const &input_shape) { NOT_IMPLEMENTED(); } diff --git a/lib/op-attrs/src/op-attrs/ops/reshape.cc b/lib/op-attrs/src/op-attrs/ops/reshape.cc index 49ec940525..7d0600550a 100644 --- a/lib/op-attrs/src/op-attrs/ops/reshape.cc +++ b/lib/op-attrs/src/op-attrs/ops/reshape.cc @@ -2,7 +2,8 @@ namespace FlexFlow { -ParallelTensorShape get_output_shape(ReshapeAttrs const &attrs, ParallelTensorShape const &input_shape) { +ParallelTensorShape get_output_shape(ReshapeAttrs const &attrs, + ParallelTensorShape const &input_shape) { NOT_IMPLEMENTED(); } diff --git a/lib/op-attrs/src/op-attrs/ops/reverse.cc b/lib/op-attrs/src/op-attrs/ops/reverse.cc index 5afd3e726e..79b5bd50fb 100644 --- a/lib/op-attrs/src/op-attrs/ops/reverse.cc +++ b/lib/op-attrs/src/op-attrs/ops/reverse.cc @@ -2,7 +2,8 @@ namespace FlexFlow { -ParallelTensorShape get_output_shape(ReverseAttrs const &attrs, ParallelTensorShape const &input_shape) { +ParallelTensorShape get_output_shape(ReverseAttrs const &attrs, + ParallelTensorShape const &input_shape) { NOT_IMPLEMENTED(); } diff --git a/lib/op-attrs/src/op-attrs/ops/softmax.cc b/lib/op-attrs/src/op-attrs/ops/softmax.cc index 05d6645637..2d870af50e 100644 --- a/lib/op-attrs/src/op-attrs/ops/softmax.cc +++ b/lib/op-attrs/src/op-attrs/ops/softmax.cc @@ -2,7 +2,8 @@ namespace FlexFlow { -ParallelTensorShape get_output_shape(SoftmaxAttrs const &attrs, ParallelTensorShape const &input_shape) { +ParallelTensorShape get_output_shape(SoftmaxAttrs const &attrs, + ParallelTensorShape const &input_shape) { NOT_IMPLEMENTED(); } diff --git a/lib/op-attrs/src/op-attrs/ops/split.cc b/lib/op-attrs/src/op-attrs/ops/split.cc index bb3c35c645..cfb4071833 100644 --- a/lib/op-attrs/src/op-attrs/ops/split.cc +++ b/lib/op-attrs/src/op-attrs/ops/split.cc @@ -2,7 +2,9 @@ namespace FlexFlow { -std::vector get_output_shapes(SplitAttrs const &attrs, ParallelTensorShape const &input_shape) { +std::vector + get_output_shapes(SplitAttrs const &attrs, + ParallelTensorShape const &input_shape) { NOT_IMPLEMENTED(); } diff --git a/lib/op-attrs/src/op-attrs/ops/topk.cc b/lib/op-attrs/src/op-attrs/ops/topk.cc index 5e3607286d..9d2fd35a94 100644 --- a/lib/op-attrs/src/op-attrs/ops/topk.cc +++ b/lib/op-attrs/src/op-attrs/ops/topk.cc @@ -2,7 +2,8 @@ namespace FlexFlow { -ParallelTensorShape get_output_shape(TopKAttrs const &attrs, ParallelTensorShape const &input_shape) { +ParallelTensorShape get_output_shape(TopKAttrs const &attrs, + ParallelTensorShape const &input_shape) { NOT_IMPLEMENTED(); } diff --git a/lib/op-attrs/src/op-attrs/ops/transpose.cc b/lib/op-attrs/src/op-attrs/ops/transpose.cc index a8ce715f99..75f7eb3c18 100644 --- a/lib/op-attrs/src/op-attrs/ops/transpose.cc +++ b/lib/op-attrs/src/op-attrs/ops/transpose.cc @@ -2,7 +2,8 @@ namespace FlexFlow { -ParallelTensorShape get_output_shape(TransposeAttrs const &op_attrs, ParallelTensorShape const &input_shape) { +ParallelTensorShape get_output_shape(TransposeAttrs const &op_attrs, + ParallelTensorShape const &input_shape) { NOT_IMPLEMENTED(); } diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc index 16de2347d6..c3cad6de19 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc @@ -1,8 +1,8 @@ #include "op-attrs/parallel_tensor_dims.h" +#include "op-attrs/replica_parallel_dim.h" #include "op-attrs/replica_parallel_dim_set.h" -#include "utils/containers.h" #include "op-attrs/shard_parallel_dim.h" -#include "op-attrs/replica_parallel_dim.h" +#include "utils/containers.h" namespace FlexFlow { @@ -10,7 +10,8 @@ FFOrdered ff_ordered_shard_dims(ParallelTensorDims const &d) { return d.shard_dims; } -std::unordered_set replica_dims(ParallelTensorDims const &d) { +std::unordered_set + replica_dims(ParallelTensorDims const &d) { return get_replica_dims(d.replica_dims); } @@ -19,11 +20,14 @@ size_t num_shard_dims(ParallelTensorDims const &dims) { } int total_replica_degree(ParallelTensorDims const &dims) { - return product(transform(replica_dims(dims), [](ReplicaParallelDim const &d) { return d.degree; })); + return product(transform(replica_dims(dims), [](ReplicaParallelDim const &d) { + return d.degree; + })); } int total_shard_degree(ParallelTensorDims const &dims) { - return product(transform(as_vector(dims.shard_dims), [](ShardParallelDim const &d) { return d.degree; })); + return product(transform(as_vector(dims.shard_dims), + [](ShardParallelDim const &d) { return d.degree; })); } int total_parallel_degree(ParallelTensorDims const &dims) { @@ -31,8 +35,10 @@ int total_parallel_degree(ParallelTensorDims const &dims) { } bool is_valid(ParallelTensorDims const &dims) { - return all_of(dims.shard_dims, [](ShardParallelDim const &d) { return is_valid(d); }) - && all_of(replica_dims(dims), [](ReplicaParallelDim const &d) { return is_valid(d); }); + return all_of(dims.shard_dims, + [](ShardParallelDim const &d) { return is_valid(d); }) && + all_of(replica_dims(dims), + [](ReplicaParallelDim const &d) { return is_valid(d); }); } ShardParallelDim shard_dim_at_idx(ParallelTensorDims const &d, ff_dim_t idx) { @@ -54,7 +60,6 @@ TensorDims get_tensor_dims_unsafe(ParallelTensorDims const &) { NOT_IMPLEMENTED(); } - TensorDims get_reduced_dims(ParallelTensorDims const &) { NOT_IMPLEMENTED(); } diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc index 9e0afca357..66e99b1e86 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc @@ -1,7 +1,7 @@ #include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_dims.h" #include "utils/containers.h" #include "utils/hash-utils.h" -#include "op-attrs/tensor_dims.h" namespace FlexFlow { @@ -9,7 +9,8 @@ int num_shard_dims(ParallelTensorShape const &s) { return num_shard_dims(s.dims); } -std::unordered_set replica_dims(ParallelTensorShape const &s) { +std::unordered_set + replica_dims(ParallelTensorShape const &s) { return replica_dims(s.dims); } @@ -41,8 +42,8 @@ TensorShape get_tensor_shape_unsafe(ParallelTensorShape const &) { TensorShape get_reduced_shape(ParallelTensorShape const &s) { return TensorShape{ - get_reduced_dims(s.dims), - s.data_type, + get_reduced_dims(s.dims), + s.data_type, }; } diff --git a/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc b/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc index 60043a82b7..76ad48d471 100644 --- a/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc +++ b/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc @@ -4,14 +4,13 @@ namespace FlexFlow { bool is_parallel_op(PCGOperatorAttrs const &attrs) { - return ( attrs.has() - || attrs.has() - || attrs.has() - || attrs.has()); + return (attrs.has() || attrs.has() || + attrs.has() || attrs.has()); } OperatorType get_op_type(PCGOperatorAttrs const &attrs) { - return attrs.visit([](auto const &x) { return get_op_type(x); }); + return attrs.visit( + [](auto const &x) { return get_op_type(x); }); } } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.cc b/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.cc index 16bb1508c4..0ce6a40a3b 100644 --- a/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.cc +++ b/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.cc @@ -7,21 +7,24 @@ ReplicaParallelDimSet empty_replica_parallel_dim_set() { return ReplicaParallelDimSet{1, 1}; } -int get_order_of_replica_type(ReplicaParallelDimSet const &s, ReplicaType replica_type) { +int get_order_of_replica_type(ReplicaParallelDimSet const &s, + ReplicaType replica_type) { switch (replica_type) { case ReplicaType::SUM: return s.sum_degree; case ReplicaType::DISCARD_COPY: return s.discard_copy_degree; default: - throw mk_runtime_error(fmt::format("Unexpected ReplicaType value: {}", static_cast(replica_type))); + throw mk_runtime_error(fmt::format("Unexpected ReplicaType value: {}", + static_cast(replica_type))); } } -std::unordered_set get_replica_dims(ReplicaParallelDimSet const &s) { +std::unordered_set + get_replica_dims(ReplicaParallelDimSet const &s) { return std::unordered_set{ - ReplicaParallelDim{s.sum_degree, ReplicaType::SUM}, - ReplicaParallelDim{s.discard_copy_degree, ReplicaType::DISCARD_COPY}, + ReplicaParallelDim{s.sum_degree, ReplicaType::SUM}, + ReplicaParallelDim{s.discard_copy_degree, ReplicaType::DISCARD_COPY}, }; } diff --git a/lib/op-attrs/src/op-attrs/tensor_dims.cc b/lib/op-attrs/src/op-attrs/tensor_dims.cc index 9bad6a3b3d..ac657dd620 100644 --- a/lib/op-attrs/src/op-attrs/tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/tensor_dims.cc @@ -1,7 +1,7 @@ #include "op-attrs/tensor_dims.h" -#include "utils/containers.h" #include "op-attrs/replica_parallel_dim_set.h" #include "op-attrs/shard_parallel_dim.dtg.h" +#include "utils/containers.h" namespace FlexFlow { @@ -21,11 +21,14 @@ size_t dim_at_idx(TensorDims const &dims, ff_dim_t idx) { } ParallelTensorDims lift_to_parallel(TensorDims const &dims) { - std::vector lifted = transform(as_vector(dims.ff_ordered), [](size_t size) { return ShardParallelDim{size, 1}; }); + std::vector lifted = + transform(as_vector(dims.ff_ordered), [](size_t size) { + return ShardParallelDim{size, 1}; + }); return ParallelTensorDims{ - FFOrdered{lifted}, - empty_replica_parallel_dim_set(), + FFOrdered{lifted}, + empty_replica_parallel_dim_set(), }; } diff --git a/lib/op-attrs/src/op-attrs/tensor_shape.cc b/lib/op-attrs/src/op-attrs/tensor_shape.cc index f338b56b59..5b8d6572d9 100644 --- a/lib/op-attrs/src/op-attrs/tensor_shape.cc +++ b/lib/op-attrs/src/op-attrs/tensor_shape.cc @@ -9,7 +9,7 @@ size_t num_dims(TensorShape const &s) { size_t dim_at_idx(TensorShape const &s, ff_dim_t idx) { if (idx.value < 0) { - return dim_at_idx(s.dims, idx); -} + return dim_at_idx(s.dims, idx); + } } // namespace FlexFlow diff --git a/lib/op-attrs/test/src/test_conv_2d.cc b/lib/op-attrs/test/src/test_conv_2d.cc index 11fd1633ee..b16a26a7b1 100644 --- a/lib/op-attrs/test/src/test_conv_2d.cc +++ b/lib/op-attrs/test/src/test_conv_2d.cc @@ -15,16 +15,16 @@ TEST_SUITE(FF_TEST_SUITE) { bool use_bias = true; Conv2DAttrs attrs = { - /*out_channels=*/out_channels, - /*kernel_h=*/kernel_h, - /*kernel_w=*/kernel_w, - /*stride_h=*/stride_h, - /*stride_w=*/stride_w, - /*padding_h=*/padding_h, - /*padding_w=*/padding_w, - /*groups=*/groups, - /*activation=*/activation, - /*use_bias=*/true, + /*out_channels=*/out_channels, + /*kernel_h=*/kernel_h, + /*kernel_w=*/kernel_w, + /*stride_h=*/stride_h, + /*stride_w=*/stride_w, + /*padding_h=*/padding_h, + /*padding_w=*/padding_w, + /*groups=*/groups, + /*activation=*/activation, + /*use_bias=*/true, }; size_t num_samples = 7; @@ -33,33 +33,28 @@ TEST_SUITE(FF_TEST_SUITE) { size_t input_width = 15; TensorShape input_shape = { - TensorDims{ - FFOrdered{ - num_samples, - input_channels, - input_height, - input_width, - } - }, - DataType::FLOAT, + TensorDims{FFOrdered{ + num_samples, + input_channels, + input_height, + input_width, + }}, + DataType::FLOAT, }; TensorShape result = get_output_shape(attrs, input_shape); - size_t correct_output_height = 3; size_t correct_output_width = 6; TensorShape correct_output_shape = { - TensorDims{ - FFOrdered{ - num_samples, - static_cast(out_channels), - correct_output_height, - correct_output_width, - } - }, - DataType::FLOAT, + TensorDims{FFOrdered{ + num_samples, + static_cast(out_channels), + correct_output_height, + correct_output_width, + }}, + DataType::FLOAT, }; CHECK(result == correct_output_shape); diff --git a/lib/op-attrs/test/src/test_operator_attrs.cc b/lib/op-attrs/test/src/test_operator_attrs.cc index 1973c89fe6..a7724dba69 100644 --- a/lib/op-attrs/test/src/test_operator_attrs.cc +++ b/lib/op-attrs/test/src/test_operator_attrs.cc @@ -2,8 +2,8 @@ #include "op-attrs/computation_graph_op_attrs.dtg.h" #include "op-attrs/pcg_operator_attrs.dtg.h" #include "utils/json.h" -#include #include +#include TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("BatchNormAttrs to/from json") { @@ -14,9 +14,8 @@ TEST_SUITE(FF_TEST_SUITE) { } TEST_CASE("ComputationGraphAttrs to/from json") { - ComputationGraphOpAttrs correct = ComputationGraphOpAttrs{ - BatchNormAttrs{true} - }; + ComputationGraphOpAttrs correct = + ComputationGraphOpAttrs{BatchNormAttrs{true}}; json j = correct; auto result = j.get(); @@ -24,12 +23,10 @@ TEST_SUITE(FF_TEST_SUITE) { } TEST_CASE("PCGOperatorAttrs to/from json") { - PCGOperatorAttrs correct = PCGOperatorAttrs{ - RepartitionAttrs{ + PCGOperatorAttrs correct = PCGOperatorAttrs{RepartitionAttrs{ /*repartition_dim=*/ff_dim_t{1}, /*repartition_degree=*/4, - } - }; + }}; json j = correct; auto result = j.get(); diff --git a/lib/op-attrs/test/src/test_regularizer_attrs.cc b/lib/op-attrs/test/src/test_regularizer_attrs.cc index 7b4139ad53..198c3add38 100644 --- a/lib/op-attrs/test/src/test_regularizer_attrs.cc +++ b/lib/op-attrs/test/src/test_regularizer_attrs.cc @@ -8,7 +8,7 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("RC") { CHECK(rc::check("valid variant", [](RegularizerAttrs reg) { - return reg.has() || reg.has(); + return reg.has() || reg.has(); })); } } diff --git a/lib/pcg/include/pcg/computation_graph.h b/lib/pcg/include/pcg/computation_graph.h index 7702d1a7f2..23003641cf 100644 --- a/lib/pcg/include/pcg/computation_graph.h +++ b/lib/pcg/include/pcg/computation_graph.h @@ -2,10 +2,10 @@ #define _FLEXFLOW_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_H #include "pcg/computation_graph.dtg.h" -#include "pcg/tensor_guid_t.dtg.h" -#include "pcg/tensor_attrs.dtg.h" -#include "pcg/layer_guid_t.dtg.h" #include "pcg/computation_graph/layer_added_result.dtg.h" +#include "pcg/layer_guid_t.dtg.h" +#include "pcg/tensor_attrs.dtg.h" +#include "pcg/tensor_guid_t.dtg.h" namespace FlexFlow { @@ -13,7 +13,10 @@ ComputationGraph make_empty_computation_graph(); std::unordered_set get_layers(ComputationGraph const &); -LayerAddedResult add_layer(ComputationGraph &computation_graph, LayerAttrs const &attrs, std::vector const &inputs, std::vector const &outputs); +LayerAddedResult add_layer(ComputationGraph &computation_graph, + LayerAttrs const &attrs, + std::vector const &inputs, + std::vector const &outputs); TensorAttrs get_tensor_attrs(ComputationGraph const &, tensor_guid_t const &); } // namespace FlexFlow diff --git a/lib/pcg/include/pcg/computation_graph_builder.h b/lib/pcg/include/pcg/computation_graph_builder.h index d7cf1e7a18..36869d63eb 100644 --- a/lib/pcg/include/pcg/computation_graph_builder.h +++ b/lib/pcg/include/pcg/computation_graph_builder.h @@ -2,8 +2,8 @@ #define _FLEXFLOW_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_BUILDER_H #include "pcg/computation_graph.dtg.h" -#include "pcg/tensor_guid_t.dtg.h" #include "pcg/initializer_attrs.dtg.h" +#include "pcg/tensor_guid_t.dtg.h" namespace FlexFlow { @@ -14,70 +14,74 @@ struct ComputationGraphBuilder { // C++ APIs for constructing models // Add an exp layer tensor_guid_t exp(tensor_guid_t const &, - std::optional const &name = std::nullopt); + std::optional const &name = std::nullopt); // Add an add layer tensor_guid_t add(tensor_guid_t const &x, - tensor_guid_t const &y, - std::optional const &name = std::nullopt); + tensor_guid_t const &y, + std::optional const &name = std::nullopt); // Add a subtract layer tensor_guid_t subtract(tensor_guid_t const &x, - tensor_guid_t const &y, - std::optional const &name = std::nullopt); + tensor_guid_t const &y, + std::optional const &name = std::nullopt); // Add a multiply layer tensor_guid_t multiply(tensor_guid_t const &x, - tensor_guid_t const &y, - std::optional const &name = std::nullopt); + tensor_guid_t const &y, + std::optional const &name = std::nullopt); // Add a divide layer tensor_guid_t divide(tensor_guid_t const &x, - tensor_guid_t const &y, - std::optional const &name = std::nullopt); + tensor_guid_t const &y, + std::optional const &name = std::nullopt); // Add a max layer tensor_guid_t max(tensor_guid_t const &x, - tensor_guid_t const &y, - std::optional const &name = std::nullopt); + tensor_guid_t const &y, + std::optional const &name = std::nullopt); // Add a min layer tensor_guid_t min(tensor_guid_t const &x, - tensor_guid_t const &y, - std::optional const &name = std::nullopt); + tensor_guid_t const &y, + std::optional const &name = std::nullopt); // Add a rsqrt layer tensor_guid_t rsqrt(tensor_guid_t const &x, - std::optional const &name = std::nullopt); + std::optional const &name = std::nullopt); // Add a pow layer tensor_guid_t pow(tensor_guid_t const &x, - float exponent, - std::optional const &name = std::nullopt); - // Add a scalar multiply layer - tensor_guid_t scalar_multiply(tensor_guid_t const &x, - float scalar, - std::optional const &name = std::nullopt); - tensor_guid_t scalar_add(tensor_guid_t const &x, - float scalar, + float exponent, std::optional const &name = std::nullopt); - tensor_guid_t scalar_sub(tensor_guid_t const &lhs, - float rhs, - std::optional const &name = std::nullopt); - tensor_guid_t scalar_truediv(tensor_guid_t const &numerator, - float denominator, - std::optional const &name = std::nullopt); + // Add a scalar multiply layer + tensor_guid_t + scalar_multiply(tensor_guid_t const &x, + float scalar, + std::optional const &name = std::nullopt); + tensor_guid_t + scalar_add(tensor_guid_t const &x, + float scalar, + std::optional const &name = std::nullopt); + tensor_guid_t + scalar_sub(tensor_guid_t const &lhs, + float rhs, + std::optional const &name = std::nullopt); + tensor_guid_t + scalar_truediv(tensor_guid_t const &numerator, + float denominator, + std::optional const &name = std::nullopt); // Add a sin layer tensor_guid_t sin(tensor_guid_t const &x, - std::optional const &name = std::nullopt); + std::optional const &name = std::nullopt); // Add a cos layer tensor_guid_t cos(tensor_guid_t const &x, - std::optional const &name = std::nullopt); + std::optional const &name = std::nullopt); // Add an activation layer tensor_guid_t relu(tensor_guid_t const &x, - std::optional const &name = std::nullopt); + std::optional const &name = std::nullopt); tensor_guid_t identity(tensor_guid_t const &x, - std::optional const &name = std::nullopt); + std::optional const &name = std::nullopt); tensor_guid_t gelu(tensor_guid_t const &x, - std::optional const &name = std::nullopt); + std::optional const &name = std::nullopt); tensor_guid_t sigmoid(tensor_guid_t const &x, - std::optional const &name = std::nullopt); + std::optional const &name = std::nullopt); tensor_guid_t tanh(tensor_guid_t const &x, - std::optional const &name = std::nullopt); + std::optional const &name = std::nullopt); tensor_guid_t elu(tensor_guid_t const &x, - std::optional const &name = std::nullopt); + std::optional const &name = std::nullopt); // Add a 2D convolutional layer tensor_guid_t conv2d( tensor_guid_t const &input, @@ -97,9 +101,9 @@ struct ComputationGraphBuilder { std::optional const &name = std::nullopt); // Add a dropout layer tensor_guid_t dropout(tensor_guid_t const &input, - float rate, - unsigned long long seed = 0, - std::optional const &name = std::nullopt); + float rate, + unsigned long long seed = 0, + std::optional const &name = std::nullopt); // Add an embedding layer tensor_guid_t embedding( tensor_guid_t const &input, @@ -116,88 +120,97 @@ struct ComputationGraphBuilder { ff_dim_t dim, std::optional const &name = std::nullopt); // Add a cache layer - tensor_guid_t cache(tensor_guid_t const &input, - int num_batches, - std::function - score_f = {}, - std::optional const &name = std::nullopt); - // Add a 2D pooling layer - tensor_guid_t pool2d(tensor_guid_t const &input, - int kernelH, - int kernelW, - int strideH, - int strideW, - int paddingH, - int paddingW, - PoolOp type = PoolOp::MAX, - std::optional const &activation = std::nullopt, - std::optional const &name = std::nullopt); - tensor_guid_t layer_norm(tensor_guid_t const &input, - std::vector const &axes, - bool elementwise_affine, - float eps, - std::optional const &name = std::nullopt); - tensor_guid_t batch_norm(tensor_guid_t const &input, - bool relu = true, - std::optional const &name = std::nullopt); - tensor_guid_t batch_matmul(tensor_guid_t const &A, - tensor_guid_t const &B, - int a_seq_length_dim = -1, - int b_seq_length_dim = -1, - std::optional const &name = std::nullopt); tensor_guid_t - dense(tensor_guid_t const &input, - int outDim, - std::optional activation = std::nullopt, - bool use_bias = true, - DataType data_type = DataType::FLOAT, - std::optional const &kernel_initializer = std::nullopt, - std::optional const &bias_initializer = std::nullopt, + cache(tensor_guid_t const &input, + int num_batches, + std::function + score_f = {}, std::optional const &name = std::nullopt); + // Add a 2D pooling layer + tensor_guid_t + pool2d(tensor_guid_t const &input, + int kernelH, + int kernelW, + int strideH, + int strideW, + int paddingH, + int paddingW, + PoolOp type = PoolOp::MAX, + std::optional const &activation = std::nullopt, + std::optional const &name = std::nullopt); + tensor_guid_t + layer_norm(tensor_guid_t const &input, + std::vector const &axes, + bool elementwise_affine, + float eps, + std::optional const &name = std::nullopt); + tensor_guid_t + batch_norm(tensor_guid_t const &input, + bool relu = true, + std::optional const &name = std::nullopt); + tensor_guid_t + batch_matmul(tensor_guid_t const &A, + tensor_guid_t const &B, + int a_seq_length_dim = -1, + int b_seq_length_dim = -1, + std::optional const &name = std::nullopt); + tensor_guid_t dense( + tensor_guid_t const &input, + int outDim, + std::optional activation = std::nullopt, + bool use_bias = true, + DataType data_type = DataType::FLOAT, + std::optional const &kernel_initializer = std::nullopt, + std::optional const &bias_initializer = std::nullopt, + std::optional const &name = std::nullopt); // Add a cast layer tensor_guid_t cast(tensor_guid_t const &input, - DataType dtype, - std::optional const &name = std::nullopt); + DataType dtype, + std::optional const &name = std::nullopt); // Add a concat layer tensor_guid_t concat(int n, - std::vector const &tensors, - int axis, - std::optional const &name = std::nullopt); + std::vector const &tensors, + int axis, + std::optional const &name = std::nullopt); // Add a mean layer tensor_guid_t mean(tensor_guid_t const &input, - std::vector const &dims, - bool keepdims, - char const *name); + std::vector const &dims, + bool keepdims, + char const *name); // Add a split layer - std::vector split(tensor_guid_t const &input, - std::vector const &split, - int axis, - std::optional const &name = std::nullopt); + std::vector + split(tensor_guid_t const &input, + std::vector const &split, + int axis, + std::optional const &name = std::nullopt); // Add a flat layer tensor_guid_t flat(tensor_guid_t const &input, - std::optional const &name = std::nullopt); + std::optional const &name = std::nullopt); // Add a softmax layer tensor_guid_t softmax(tensor_guid_t const &input, - int dim = -1, - std::optional const &name = std::nullopt); + int dim = -1, + std::optional const &name = std::nullopt); // Create input tensors and constants - tensor_guid_t transpose(tensor_guid_t const &input, - std::vector const &perm, - std::optional const &name = std::nullopt); - tensor_guid_t reduce_sum(tensor_guid_t const &input, - std::vector const &axes, - bool keepdims = false, - std::optional const &name = std::nullopt); - tensor_guid_t reshape(tensor_guid_t const &input, - std::vector const &shape, + tensor_guid_t + transpose(tensor_guid_t const &input, + std::vector const &perm, + std::optional const &name = std::nullopt); + tensor_guid_t + reduce_sum(tensor_guid_t const &input, + std::vector const &axes, + bool keepdims = false, std::optional const &name = std::nullopt); + tensor_guid_t reshape(tensor_guid_t const &input, + std::vector const &shape, + std::optional const &name = std::nullopt); tensor_guid_t reverse(tensor_guid_t const &input, - int axis, - std::optional const &name = std::nullopt); - std::vector top_k(tensor_guid_t const &input, - int k, - bool sorted, - std::optional const &name = std::nullopt); + int axis, + std::optional const &name = std::nullopt); + std::vector + top_k(tensor_guid_t const &input, + int k, + bool sorted, + std::optional const &name = std::nullopt); tensor_guid_t multihead_attention( tensor_guid_t const &query, tensor_guid_t const &key, @@ -222,8 +235,8 @@ struct ComputationGraphBuilder { std::vector get_outputs(LayerAttrs const &) const; tensor_guid_t get_output(LayerAttrs const &, int idx) const; -/* tensor_guid_t at(MultiDiEdge const &) const; */ -/* LayerAttrs at(Node const &) const; */ + /* tensor_guid_t at(MultiDiEdge const &) const; */ + /* LayerAttrs at(Node const &) const; */ private: TensorShape get_shape(tensor_guid_t const &) const; @@ -232,7 +245,7 @@ struct ComputationGraphBuilder { tensor_guid_t as_type(tensor_guid_t const &, DataType, std::string const &); std::vector add_layer(LayerAttrs const &layer, - std::vector const &inputs, + std::vector const &inputs, std::vector const &weights, std::vector const &outputs); @@ -242,9 +255,9 @@ struct ComputationGraphBuilder { TensorAttrs const &output); std::vector add_layer(LayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs); + std::vector const &inputs, + std::vector const &weights, + std::vector const &outputs); tensor_guid_t add_layer(LayerAttrs const &layer, std::vector const &inputs, @@ -254,25 +267,29 @@ struct ComputationGraphBuilder { TensorShape get_broadcast_target_shape(std::vector const &); TensorShape get_broadcast_target_shape(std::vector const &); - tensor_guid_t element_binary(OperatorType, - tensor_guid_t const &lhs, - tensor_guid_t const &rhs, - std::optional const &name = std::nullopt); + tensor_guid_t + element_binary(OperatorType, + tensor_guid_t const &lhs, + tensor_guid_t const &rhs, + std::optional const &name = std::nullopt); - tensor_guid_t element_unary(OperatorType, - tensor_guid_t const &input, - std::optional const &name = std::nullopt); + tensor_guid_t + element_unary(OperatorType, + tensor_guid_t const &input, + std::optional const &name = std::nullopt); tensor_guid_t element_scalar_unary( OperatorType, tensor_guid_t const &input, float scalar, std::optional const &name = std::nullopt); - tensor_guid_t element_unary(ElementUnaryAttrs const &, - tensor_guid_t const &input, - std::optional const &name = std::nullopt); - tensor_guid_t element_scalar_unary(ElementScalarUnaryAttrs const &attrs, - tensor_guid_t const &x, - std::optional const &maybe_name); + tensor_guid_t + element_unary(ElementUnaryAttrs const &, + tensor_guid_t const &input, + std::optional const &name = std::nullopt); + tensor_guid_t + element_scalar_unary(ElementScalarUnaryAttrs const &attrs, + tensor_guid_t const &x, + std::optional const &maybe_name); public: ComputationGraph computation_graph; diff --git a/lib/pcg/include/pcg/create_grad.h b/lib/pcg/include/pcg/create_grad.h index 26ba88f1b2..5a12d310c2 100644 --- a/lib/pcg/include/pcg/create_grad.h +++ b/lib/pcg/include/pcg/create_grad.h @@ -3,8 +3,6 @@ #include "pcg/create_grad_t.h" -namespace FlexFlow { - -} +namespace FlexFlow {} #endif diff --git a/lib/pcg/include/pcg/dataflow_graph.h b/lib/pcg/include/pcg/dataflow_graph.h index a7affaab83..f649c0444c 100644 --- a/lib/pcg/include/pcg/dataflow_graph.h +++ b/lib/pcg/include/pcg/dataflow_graph.h @@ -8,15 +8,19 @@ namespace FlexFlow { template struct DataflowGraph { -public: - DataflowGraph() - : g(OutputLabelledMultiDiGraph::template create< - UnorderedOutputLabelledMultiDiGraph>()) { } +public: + DataflowGraph() + : g(OutputLabelledMultiDiGraph::template create< + UnorderedOutputLabelledMultiDiGraph>()) {} - std::vector add_operator(NodeLabel const &func, std::vector const &inputs, std::vector const &outputs) { + std::vector + add_operator(NodeLabel const &func, + std::vector const &inputs, + std::vector const &outputs) { Node n = this->g.add_node(func); for (auto const &[idx, input] : enumerate_vector(inputs)) { - this->g.add_edge(MultiDiEdge{input.src, input.src_idx, n, this->make_port_for_idx(idx)}); + this->g.add_edge(MultiDiEdge{ + input.src, input.src_idx, n, this->make_port_for_idx(idx)}); } std::vector result; @@ -32,7 +36,7 @@ struct DataflowGraph { NodePort make_port_for_idx(int idx) { if (!this->port_mapping.contains_l(idx)) { this->port_mapping.equate(idx, this->g.add_node_port()); - } + } return this->port_mapping.at_l(idx); } @@ -41,10 +45,11 @@ struct DataflowGraph { } int idx_for_port(NodePort const &p) const { - return this->port_mapping.at_r(p); + return this->port_mapping.at_r(p); } - OutputLabelledMultiDiGraphView const &get_raw_graph() const { + OutputLabelledMultiDiGraphView const & + get_raw_graph() const { return this->g; } @@ -55,13 +60,15 @@ struct DataflowGraph { OutputLabel const &at(MultiDiOutput const &o) const { return this->g.at(o); } + private: OutputLabelledMultiDiGraph g; bidict port_mapping; }; template -std::unordered_set get_nodes(DataflowGraph const &g) { +std::unordered_set + get_nodes(DataflowGraph const &g) { return get_nodes(g.get_raw_graph()); } diff --git a/lib/pcg/include/pcg/device_id.h b/lib/pcg/include/pcg/device_id.h index 9c38674e82..be92be7081 100644 --- a/lib/pcg/include/pcg/device_id.h +++ b/lib/pcg/include/pcg/device_id.h @@ -1,10 +1,10 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_DEVICE_ID_H #define _FLEXFLOW_PCG_INCLUDE_PCG_DEVICE_ID_H -#include "pcg/device_type.dtg.h" #include "pcg/cpu_id_t.dtg.h" -#include "pcg/gpu_id_t.dtg.h" #include "pcg/device_id_t.dtg.h" +#include "pcg/device_type.dtg.h" +#include "pcg/gpu_id_t.dtg.h" namespace FlexFlow { diff --git a/lib/pcg/include/pcg/file_format/v1/graphs.h b/lib/pcg/include/pcg/file_format/v1/graphs.h index 6417a549cb..dad73ce142 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs.h +++ b/lib/pcg/include/pcg/file_format/v1/graphs.h @@ -1,14 +1,14 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_H #define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_H -#include "pcg/parallel_tensor_attrs.dtg.h" -#include "pcg/tensor_attrs.dtg.h" #include "pcg/computation_graph.dtg.h" -#include "pcg/parallel_computation_graph.dtg.h" #include "pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.h" -#include "utils/json.h" #include "pcg/layer_attrs.dtg.h" +#include "pcg/parallel_computation_graph.dtg.h" #include "pcg/parallel_layer_attrs.dtg.h" +#include "pcg/parallel_tensor_attrs.dtg.h" +#include "pcg/tensor_attrs.dtg.h" +#include "utils/json.h" namespace FlexFlow { diff --git a/lib/pcg/include/pcg/machine_specification.h b/lib/pcg/include/pcg/machine_specification.h index 3886dcfe2e..cf84bf5048 100644 --- a/lib/pcg/include/pcg/machine_specification.h +++ b/lib/pcg/include/pcg/machine_specification.h @@ -3,8 +3,6 @@ #include "machine_specification_t.h" -namespace FlexFlow { - -} // namespace FlexFlow +namespace FlexFlow {} // namespace FlexFlow #endif diff --git a/lib/pcg/include/pcg/machine_view.h b/lib/pcg/include/pcg/machine_view.h index 60837a2abf..625b128d35 100644 --- a/lib/pcg/include/pcg/machine_view.h +++ b/lib/pcg/include/pcg/machine_view.h @@ -1,12 +1,12 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_MACHINE_VIEW_H #define _FLEXFLOW_PCG_INCLUDE_PCG_MACHINE_VIEW_H -#include "pcg/machine_view.dtg.h" +#include "pcg/cpu_id_t.dtg.h" #include "pcg/device_id_t.dtg.h" #include "pcg/device_type.dtg.h" -#include "pcg/num_points_t.dtg.h" -#include "pcg/cpu_id_t.dtg.h" #include "pcg/gpu_id_t.dtg.h" +#include "pcg/machine_view.dtg.h" +#include "pcg/num_points_t.dtg.h" #include "pcg/side_size_t.dtg.h" #include #include diff --git a/lib/pcg/include/pcg/operator_graph/operator_graph.h b/lib/pcg/include/pcg/operator_graph/operator_graph.h index 2140ff1555..5fca50d4c7 100644 --- a/lib/pcg/include/pcg/operator_graph/operator_graph.h +++ b/lib/pcg/include/pcg/operator_graph/operator_graph.h @@ -1,14 +1,14 @@ #ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_H #define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_H -#include "utils/graph.h" -#include "pcg/operator_graph/operator_graph_output.dtg.h" #include "pcg/operator_graph/operator_graph_input.dtg.h" +#include "pcg/operator_graph/operator_graph_output.dtg.h" +#include "utils/graph.h" namespace FlexFlow { -struct OperatorGraphOutputQuery { }; -struct OperatorGraphEdge { }; +struct OperatorGraphOutputQuery {}; +struct OperatorGraphEdge {}; Node get_src_node(OperatorGraphEdge const &); Node get_dst_node(OperatorGraphEdge const &); @@ -29,8 +29,10 @@ struct OperatorGraphView { OperatorGraphView &&operator=(OperatorGraphView &&); std::unordered_set query_nodes(NodeQuery const &) const; - std::unordered_set query_outputs(OperatorGraphOutputQuery const &) const; - std::unordered_set query_edges(OperatorGraphEdgeQuery const &) const; + std::unordered_set + query_outputs(OperatorGraphOutputQuery const &) const; + std::unordered_set + query_edges(OperatorGraphEdgeQuery const &) const; struct Impl; std::unique_ptr impl; @@ -38,8 +40,10 @@ struct OperatorGraphView { CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(OperatorGraphView); std::unordered_set get_outputs(OperatorGraphView const &); -std::vector get_outputs(OperatorGraphView const &, Node const &); -std::unordered_set get_uses(OperatorGraphView const &, OperatorGraphOutput const &); +std::vector get_outputs(OperatorGraphView const &, + Node const &); +std::unordered_set get_uses(OperatorGraphView const &, + OperatorGraphOutput const &); struct OperatorGraph { public: @@ -47,7 +51,8 @@ struct OperatorGraph { OperatorGraph(OperatorGraph const &) = default; OperatorGraph &operator=(OperatorGraph const &) = default; - Node add_node(std::vector const &inputs, int num_outputs); + Node add_node(std::vector const &inputs, + int num_outputs); private: struct Impl; @@ -58,13 +63,16 @@ struct value_t; template struct LabelledOperatorGraphView : virtual OperatorGraphView { - NodeLabel const &at(Node const &) const; - OutputLabel const &at(OperatorGraphOutput const &) const; + NodeLabel const &at(Node const &) const; + OutputLabel const &at(OperatorGraphOutput const &) const; }; template -struct LabelledOperatorGraph : virtual LabelledOperatorGraphView { - Node add_node(NodeLabel const &, std::vector const &inputs, std::vector const &output_labels); +struct LabelledOperatorGraph + : virtual LabelledOperatorGraphView { + Node add_node(NodeLabel const &, + std::vector const &inputs, + std::vector const &output_labels); }; } // namespace FlexFlow diff --git a/lib/pcg/include/pcg/optimizer_attrs.h b/lib/pcg/include/pcg/optimizer_attrs.h index 03ba461c45..4bac74b999 100644 --- a/lib/pcg/include/pcg/optimizer_attrs.h +++ b/lib/pcg/include/pcg/optimizer_attrs.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_OPTIMIZER_H #define _FLEXFLOW_PCG_INCLUDE_PCG_OPTIMIZER_H -#include "utils/variant.h" #include "pcg/optimizers/adam_optimizer_attrs.h" #include "pcg/optimizers/sgd_optimizer_attrs.h" +#include "utils/variant.h" namespace FlexFlow { diff --git a/lib/pcg/include/pcg/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph.h index 4dc2db5de4..9d7103f4fd 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph.h @@ -3,6 +3,6 @@ #include "pcg/parallel_computation_graph_t.h" -namespace FlexFlow { } +namespace FlexFlow {} #endif diff --git a/lib/pcg/include/pcg/parallel_tensor.h b/lib/pcg/include/pcg/parallel_tensor.h index 8fd2fc0e17..de41e0fb21 100644 --- a/lib/pcg/include/pcg/parallel_tensor.h +++ b/lib/pcg/include/pcg/parallel_tensor.h @@ -23,9 +23,7 @@ #include "pcg/parallel_tensor_attrs.h" -namespace FlexFlow { - -} // namespace FlexFlow +namespace FlexFlow {} // namespace FlexFlow namespace FlexFlow { static_assert(is_well_behaved_value_type::value, ""); diff --git a/lib/pcg/include/pcg/strided_rectangle.dtg.h b/lib/pcg/include/pcg/strided_rectangle.dtg.h index cacc11093d..df6a16a0ad 100644 --- a/lib/pcg/include/pcg/strided_rectangle.dtg.h +++ b/lib/pcg/include/pcg/strided_rectangle.dtg.h @@ -3,7 +3,7 @@ // lib/pcg/include/pcg/strided_rectangle.struct.toml /* proj-data { - "generated_from": "87af84e6a16d5363049cb9a9a75e4f5f" + "generated_from": "817bbe017d179aa469822a4032d08836" } */ @@ -14,6 +14,7 @@ #include "nlohmann/json.hpp" #include "op-attrs/dim_ordered.h" #include "pcg/strided_rectangle_side.dtg.h" +#include "rapidcheck.h" #include #include #include @@ -49,6 +50,13 @@ struct adl_serializer { }; } // namespace nlohmann +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + namespace FlexFlow { std::string format_as(StridedRectangle const &); std::ostream &operator<<(std::ostream &, StridedRectangle const &); diff --git a/lib/pcg/include/pcg/strided_rectangle.h b/lib/pcg/include/pcg/strided_rectangle.h index 48bd4e8146..24ae51ac41 100644 --- a/lib/pcg/include/pcg/strided_rectangle.h +++ b/lib/pcg/include/pcg/strided_rectangle.h @@ -8,7 +8,8 @@ namespace FlexFlow { size_t get_num_dims(StridedRectangle const &); -StridedRectangleSide get_side_at_idx(StridedRectangle const &, ff_dim_t const &); +StridedRectangleSide get_side_at_idx(StridedRectangle const &, + ff_dim_t const &); } // namespace FlexFlow diff --git a/lib/pcg/include/pcg/strided_rectangle_side.h b/lib/pcg/include/pcg/strided_rectangle_side.h index 540bb76bc8..1486b73143 100644 --- a/lib/pcg/include/pcg/strided_rectangle_side.h +++ b/lib/pcg/include/pcg/strided_rectangle_side.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_SIDE_H #define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_SIDE_H -#include "pcg/strided_rectangle_side.dtg.h" #include "pcg/side_size_t.dtg.h" +#include "pcg/strided_rectangle_side.dtg.h" namespace FlexFlow { diff --git a/lib/pcg/src/file_format/v1/graphs.cc b/lib/pcg/src/file_format/v1/graphs.cc index 0a11842709..eabd266e25 100644 --- a/lib/pcg/src/file_format/v1/graphs.cc +++ b/lib/pcg/src/file_format/v1/graphs.cc @@ -1,13 +1,14 @@ #include "pcg/file_format/v1/graphs.h" -#include "utils/graph/algorithms.h" +#include "pcg/dataflow_graph.h" +#include "pcg/file_format/v1/graphs/v1_multidigraph.h" #include "pcg/file_format/v1/graphs/v1_operator_graph.dtg.h" +#include "utils/graph/algorithms.h" #include "utils/integer_conversions.h" -#include "pcg/file_format/v1/graphs/v1_multidigraph.h" -#include "pcg/dataflow_graph.h" namespace FlexFlow { -/* static V1OperatorGraph to_v1(OperatorGraphView const &g, bidict const &nodes) { */ +/* static V1OperatorGraph to_v1(OperatorGraphView const &g, bidict + * const &nodes) { */ /* std::unordered_set edges; */ /* for (MultiDiEdge const &e : get_edges(g)) { */ /* size_t src_node = nodes.at_l(get_src_node(e)); */ @@ -24,10 +25,9 @@ namespace FlexFlow { /* }; */ /* } */ - static V1MultiDiGraph to_v1(MultiDiGraphView const &g, - bidict const &nodes, - bidict const &node_ports) { + bidict const &nodes, + bidict const &node_ports) { std::unordered_set edges; for (MultiDiEdge const &e : get_edges(g)) { edges.insert({nodes.at_l(e.src), @@ -50,7 +50,7 @@ static V1MultiDiGraph to_v1(MultiDiGraphView const &g, /* } */ /* template */ -/* static V1JsonableGraph */ +/* static V1JsonableGraph */ /* to_v1(LabelledOperatorGraphView const &g) { */ /* bidict nodes = enumerate(get_nodes(g)); */ @@ -59,20 +59,24 @@ static V1MultiDiGraph to_v1(MultiDiGraphView const &g, /* std::unordered_map node_labels = */ /* map_values(nodes, [&](Node const &n) { return g.at(n); }); */ -/* bidict outputs_bidict = enumerate(get_outputs(g)); */ +/* bidict outputs_bidict = + * enumerate(get_outputs(g)); */ /* std::unordered_map outputs = */ /* map_values(outputs_bidict, [&](OperatorGraphOutput const &o) { */ -/* return V1GraphOutput{nodes.at_r(get_node(o)), size_t_from_int(get_idx(o))}; */ +/* return V1GraphOutput{nodes.at_r(get_node(o)), + * size_t_from_int(get_idx(o))}; */ /* }); */ /* std::unordered_map output_labels = map_values( */ -/* outputs_bidict, [&](OperatorGraphOutput const &o) { return g.at(o); }); */ +/* outputs_bidict, [&](OperatorGraphOutput const &o) { return g.at(o); }); + */ /* return {node_labels, outputs, output_labels, unlabelled}; */ /* } */ template -static bidict get_ports_by_idx(DataflowGraph const &g) { +static bidict + get_ports_by_idx(DataflowGraph const &g) { bidict result; for (NodePort const &p : get_present_node_ports(g.get_raw_graph())) { size_t idx = size_t_from_int(g.idx_for_port(p)); @@ -88,11 +92,13 @@ static V1JsonableGraph bidict nodes = enumerate(get_nodes(g.get_raw_graph())); bidict node_ports = get_ports_by_idx(g); - V1MultiDiGraph unlabelled = to_v1(g.get_raw_graph(), nodes.reversed(), node_ports.reversed()); + V1MultiDiGraph unlabelled = + to_v1(g.get_raw_graph(), nodes.reversed(), node_ports.reversed()); std::unordered_map node_labels = map_values(nodes, [&](Node const &n) { return g.at(n); }); - bidict outputs_bidict = enumerate(get_outputs(g.get_raw_graph())); + bidict outputs_bidict = + enumerate(get_outputs(g.get_raw_graph())); std::unordered_map outputs = map_values(outputs_bidict, [&](MultiDiOutput const &o) { return V1GraphOutput{nodes.at_r(o.src), node_ports.at_r(o.src_idx)}; diff --git a/lib/pcg/src/pcg/computation_graph.cc b/lib/pcg/src/pcg/computation_graph.cc index 3c21f32697..12a72ca837 100644 --- a/lib/pcg/src/pcg/computation_graph.cc +++ b/lib/pcg/src/pcg/computation_graph.cc @@ -4,16 +4,16 @@ namespace FlexFlow { ComputationGraph make_empty_computation_graph() { - return ComputationGraph{ - DataflowGraph{} - }; + return ComputationGraph{DataflowGraph{}}; } std::unordered_set get_layers(ComputationGraph const &cg) { - return transform(get_nodes(cg.raw_graph), [&](Node const &n) { return layer_guid_t{n}; }); + return transform(get_nodes(cg.raw_graph), + [&](Node const &n) { return layer_guid_t{n}; }); } -TensorAttrs get_tensor_attrs(ComputationGraph const &cg, tensor_guid_t const &t) { +TensorAttrs get_tensor_attrs(ComputationGraph const &cg, + tensor_guid_t const &t) { return cg.raw_graph.at(t.raw_graph_output); } diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index bd381c6047..7969b40ac7 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -1,116 +1,128 @@ #include "pcg/computation_graph_builder.h" +#include "op-attrs/computation_graph_op_attrs.h" #include "op-attrs/get_op_type.h" #include "op-attrs/get_output_shapes.h" -#include "op-attrs/ops/weight_attrs.dtg.h" -#include "utils/expected.h" -#include "utils/fmt.h" #include "op-attrs/ops/element_binary.h" #include "op-attrs/ops/embedding.h" -#include "op-attrs/computation_graph_op_attrs.h" -#include "utils/containers.h" -#include "utils/containers/enumerate_vector.h" +#include "op-attrs/ops/weight_attrs.dtg.h" #include "pcg/computation_graph.h" +#include "utils/containers.h" #include "utils/containers/concat_vectors.h" - +#include "utils/containers/enumerate_vector.h" +#include "utils/expected.h" +#include "utils/fmt.h" + namespace FlexFlow { -ComputationGraphBuilder::ComputationGraphBuilder() - : computation_graph(make_empty_computation_graph()) { } +ComputationGraphBuilder::ComputationGraphBuilder() + : computation_graph(make_empty_computation_graph()) {} TensorShape ComputationGraphBuilder::get_shape(tensor_guid_t const &t) const { return get_tensor_attrs(this->computation_graph, t).shape; } -tensor_guid_t ComputationGraphBuilder::create_tensor(TensorShape const &shape, bool create_grad) { +tensor_guid_t ComputationGraphBuilder::create_tensor(TensorShape const &shape, + bool create_grad) { TensorAttrs tensor_attrs = {shape, std::nullopt, create_grad, std::nullopt}; LayerAttrs layer_attrs = LayerAttrs{ - ComputationGraphOpAttrs{InputAttrs{}}, - std::nullopt, + ComputationGraphOpAttrs{InputAttrs{}}, + std::nullopt, }; return this->add_layer(layer_attrs, {}, {}, tensor_attrs); } -std::vector ComputationGraphBuilder::add_layer(LayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs) { +std::vector ComputationGraphBuilder::add_layer( + LayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + std::vector const &outputs) { std::vector raw_weight_tensors; for (auto const &kv : enumerate_vector(weights)) { int weight_idx = kv.first; TensorAttrs weight_tensor_attrs = kv.second; - std::optional weight_name = transform(layer.name, [&](std::string const &layer_name) { return fmt::format("{}.weights[{}]", layer_name, weight_idx); }); + std::optional weight_name = + transform(layer.name, [&](std::string const &layer_name) { + return fmt::format("{}.weights[{}]", layer_name, weight_idx); + }); LayerAttrs weight_layer_attrs = LayerAttrs{ - ComputationGraphOpAttrs{WeightAttrs{}}, - weight_name, + ComputationGraphOpAttrs{WeightAttrs{}}, + weight_name, }; std::vector weight_layer_inputs = {}; std::vector weight_output_attrs = {weight_tensor_attrs}; raw_weight_tensors.push_back( - get_only(this->computation_graph.raw_graph.add_operator(weight_layer_attrs, weight_layer_inputs, weight_output_attrs)) - ); + get_only(this->computation_graph.raw_graph.add_operator( + weight_layer_attrs, weight_layer_inputs, weight_output_attrs))); } - std::vector raw_inputs = transform(inputs, [](tensor_guid_t const &t) { return t.raw_graph_output; }); - std::vector raw_outputs = this->computation_graph.raw_graph.add_operator(layer, concat_vectors(raw_inputs, raw_weight_tensors), outputs); - return transform(raw_outputs, [](MultiDiOutput const &o) { return tensor_guid_t{o}; }); + std::vector raw_inputs = transform( + inputs, [](tensor_guid_t const &t) { return t.raw_graph_output; }); + std::vector raw_outputs = + this->computation_graph.raw_graph.add_operator( + layer, concat_vectors(raw_inputs, raw_weight_tensors), outputs); + return transform(raw_outputs, + [](MultiDiOutput const &o) { return tensor_guid_t{o}; }); } -tensor_guid_t ComputationGraphBuilder::add_layer(LayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - TensorAttrs const &output) { +tensor_guid_t + ComputationGraphBuilder::add_layer(LayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + TensorAttrs const &output) { std::vector outputs = {output}; return get_only(this->add_layer(layer, inputs, weights, outputs)); } -std::vector ComputationGraphBuilder::add_layer(LayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs) { - return this->add_layer(layer, - inputs, - weights, - transform(outputs, [](TensorShape const &s) { return TensorAttrs{s, std::nullopt, true, std::nullopt}; })); +std::vector ComputationGraphBuilder::add_layer( + LayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + std::vector const &outputs) { + return this->add_layer( + layer, inputs, weights, transform(outputs, [](TensorShape const &s) { + return TensorAttrs{s, std::nullopt, true, std::nullopt}; + })); } -tensor_guid_t ComputationGraphBuilder::add_layer(LayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - TensorShape const &output) { - return get_only(this->add_layer(layer, - inputs, - weights, - std::vector{output})); +tensor_guid_t + ComputationGraphBuilder::add_layer(LayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + TensorShape const &output) { + return get_only(this->add_layer( + layer, inputs, weights, std::vector{output})); } -tensor_guid_t ComputationGraphBuilder::as_type( - tensor_guid_t const &x, - DataType data_type, - std::string const &name) { +tensor_guid_t ComputationGraphBuilder::as_type(tensor_guid_t const &x, + DataType data_type, + std::string const &name) { DataType x_datatype = this->get_shape(x).data_type; if (x_datatype < data_type) { return this->cast(x, data_type, name); } else if (x_datatype > data_type) { - throw mk_runtime_error(fmt::format("Could not convert provided tensor data type {} to " - "desired data type {}", - x_datatype, - data_type)); + throw mk_runtime_error( + fmt::format("Could not convert provided tensor data type {} to " + "desired data type {}", + x_datatype, + data_type)); } else { return x; } } - -tensor_guid_t ComputationGraphBuilder::broadcast(tensor_guid_t const &, TensorShape const &) { +tensor_guid_t ComputationGraphBuilder::broadcast(tensor_guid_t const &, + TensorShape const &) { NOT_IMPLEMENTED(); } -tensor_guid_t ComputationGraphBuilder::cast(tensor_guid_t const &input, - DataType dtype, - std::optional const &name){ - NOT_IMPLEMENTED()} +tensor_guid_t + ComputationGraphBuilder::cast(tensor_guid_t const &input, + DataType dtype, + std::optional const &name) { + NOT_IMPLEMENTED() +} static std::string get_default_name(OperatorType op_type) { return get_operator_type_name(op_type); @@ -124,37 +136,30 @@ tensor_guid_t ComputationGraphBuilder::element_unary( ElementUnaryAttrs const &attrs, tensor_guid_t const &x, std::optional const &maybe_name) { - std::string name = maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); - tensor_guid_t input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); + tensor_guid_t input = + this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); - LayerAttrs layer = LayerAttrs{ - ComputationGraphOpAttrs{attrs}, - name - }; + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); - return this->add_layer( - layer, - {input}, - {}, - output_shape - ); + return this->add_layer(layer, {input}, {}, output_shape); } tensor_guid_t ComputationGraphBuilder::element_scalar_unary( ElementScalarUnaryAttrs const &attrs, tensor_guid_t const &x, std::optional const &maybe_name) { - std::string name = maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); - tensor_guid_t input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); + tensor_guid_t input = + this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); - LayerAttrs layer = { - ComputationGraphOpAttrs{attrs}, - name - }; + LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); @@ -186,42 +191,36 @@ tensor_guid_t ComputationGraphBuilder::element_binary( std::string name = maybe_name.value_or(get_default_name(op_type)); TensorShape compute_shape = this->get_broadcast_target_shape({lhs, rhs}); - DataType compute_type = std::max( - this->get_shape(lhs).data_type, - this->get_shape(rhs).data_type - ); + DataType compute_type = + std::max(this->get_shape(lhs).data_type, this->get_shape(rhs).data_type); tensor_guid_t lhs_input = this->as_type(this->broadcast(lhs, compute_shape), - compute_type, - name + "_inputl_pre_cast"); + compute_type, + name + "_inputl_pre_cast"); tensor_guid_t rhs_input = this->as_type(this->broadcast(rhs, compute_shape), - compute_type, - name + "_inputr_pre_cast"); + compute_type, + name + "_inputr_pre_cast"); ElementBinaryAttrs attrs = {op_type, compute_type, false, false}; - LayerAttrs layer = { - ComputationGraphOpAttrs{attrs}, - name - }; + LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; TensorShape output_shape = get_output_shape( - attrs, - this->get_shape(lhs_input), - this->get_shape(rhs_input) - ); + attrs, this->get_shape(lhs_input), this->get_shape(rhs_input)); return this->add_layer(layer, {lhs_input, rhs_input}, {}, output_shape); } -tensor_guid_t ComputationGraphBuilder::exp(tensor_guid_t const &input, - std::optional const &name) { +tensor_guid_t + ComputationGraphBuilder::exp(tensor_guid_t const &input, + std::optional const &name) { return this->element_unary(OperatorType::EXP, input, name); } -tensor_guid_t ComputationGraphBuilder::add(tensor_guid_t const &lhs, - tensor_guid_t const &rhs, - std::optional const &name) { +tensor_guid_t + ComputationGraphBuilder::add(tensor_guid_t const &lhs, + tensor_guid_t const &rhs, + std::optional const &name) { return this->element_binary(OperatorType::EW_ADD, lhs, rhs, name); } @@ -239,47 +238,60 @@ tensor_guid_t return this->element_binary(OperatorType::EW_MUL, lhs, rhs, name); } -tensor_guid_t ComputationGraphBuilder::divide(tensor_guid_t const &lhs, - tensor_guid_t const &rhs, - std::optional const &name) { +tensor_guid_t + ComputationGraphBuilder::divide(tensor_guid_t const &lhs, + tensor_guid_t const &rhs, + std::optional const &name) { return this->element_binary(OperatorType::EW_DIV, lhs, rhs, name); } -tensor_guid_t ComputationGraphBuilder::max(tensor_guid_t const &lhs, - tensor_guid_t const &rhs, - std::optional const &name) { +tensor_guid_t + ComputationGraphBuilder::max(tensor_guid_t const &lhs, + tensor_guid_t const &rhs, + std::optional const &name) { return this->element_binary(OperatorType::EW_MAX, lhs, rhs, name); } -tensor_guid_t ComputationGraphBuilder::min(tensor_guid_t const &lhs, - tensor_guid_t const &rhs, - std::optional const &name) { +tensor_guid_t + ComputationGraphBuilder::min(tensor_guid_t const &lhs, + tensor_guid_t const &rhs, + std::optional const &name) { return this->element_binary(OperatorType::EW_MIN, lhs, rhs, name); } -tensor_guid_t ComputationGraphBuilder::rsqrt(tensor_guid_t const &input, - std::optional const &name) { +tensor_guid_t + ComputationGraphBuilder::rsqrt(tensor_guid_t const &input, + std::optional const &name) { return this->element_unary(OperatorType::RSQRT, input, name); } -tensor_guid_t ComputationGraphBuilder::pow(tensor_guid_t const &input, - float exponent, - std::optional const &name) { +tensor_guid_t + ComputationGraphBuilder::pow(tensor_guid_t const &input, + float exponent, + std::optional const &name) { return this->element_scalar_unary(OperatorType::POW, input, exponent, name); } tensor_guid_t ComputationGraphBuilder::scalar_multiply( - tensor_guid_t const &input, float scalar, std::optional const &name) { - return this->element_scalar_unary(OperatorType::SCALAR_MULTIPLY, input, scalar, name); + tensor_guid_t const &input, + float scalar, + std::optional const &name) { + return this->element_scalar_unary( + OperatorType::SCALAR_MULTIPLY, input, scalar, name); } tensor_guid_t ComputationGraphBuilder::scalar_add( - tensor_guid_t const &input, float scalar, std::optional const &name) { - return this->element_scalar_unary(OperatorType::SCALAR_ADD, input, scalar, name); + tensor_guid_t const &input, + float scalar, + std::optional const &name) { + return this->element_scalar_unary( + OperatorType::SCALAR_ADD, input, scalar, name); } tensor_guid_t ComputationGraphBuilder::scalar_sub( - tensor_guid_t const &lhs, float rhs, std::optional const &name) { + tensor_guid_t const &lhs, + float rhs, + std::optional const &name) { return this->element_scalar_unary(OperatorType::SCALAR_SUB, lhs, rhs, name); } @@ -291,18 +303,21 @@ tensor_guid_t ComputationGraphBuilder::scalar_truediv( OperatorType::SCALAR_TRUE_DIV, numerator, denominator, name); } -tensor_guid_t ComputationGraphBuilder::sin(tensor_guid_t const &input, - std::optional const &name) { +tensor_guid_t + ComputationGraphBuilder::sin(tensor_guid_t const &input, + std::optional const &name) { return this->element_unary(OperatorType::SIN, input, name); } -tensor_guid_t ComputationGraphBuilder::cos(tensor_guid_t const &input, - std::optional const &name) { +tensor_guid_t + ComputationGraphBuilder::cos(tensor_guid_t const &input, + std::optional const &name) { return this->element_unary(OperatorType::COS, input, name); } -tensor_guid_t ComputationGraphBuilder::relu(tensor_guid_t const &input, - std::optional const &name) { +tensor_guid_t + ComputationGraphBuilder::relu(tensor_guid_t const &input, + std::optional const &name) { return this->element_unary(OperatorType::RELU, input, name); } @@ -312,8 +327,9 @@ tensor_guid_t return this->element_unary(OperatorType::IDENTITY, input, name); } -tensor_guid_t ComputationGraphBuilder::gelu(tensor_guid_t const &input, - std::optional const &name) { +tensor_guid_t + ComputationGraphBuilder::gelu(tensor_guid_t const &input, + std::optional const &name) { return this->element_unary(OperatorType::GELU, input, name); } @@ -323,17 +339,21 @@ tensor_guid_t return this->element_unary(OperatorType::SIGMOID, input, name); } -tensor_guid_t ComputationGraphBuilder::tanh(tensor_guid_t const &input, - std::optional const &name) { +tensor_guid_t + ComputationGraphBuilder::tanh(tensor_guid_t const &input, + std::optional const &name) { return this->element_unary(OperatorType::TANH, input, name); } -tensor_guid_t ComputationGraphBuilder::elu(tensor_guid_t const &input, - std::optional const &name) { +tensor_guid_t + ComputationGraphBuilder::elu(tensor_guid_t const &input, + std::optional const &name) { return this->element_unary(OperatorType::ELU, input, name); } -static TensorAttrs make_weight_attrs(TensorShape const &shape, std::optional const &initializer_attrs) { +static TensorAttrs make_weight_attrs( + TensorShape const &shape, + std::optional const &initializer_attrs) { return TensorAttrs{shape, initializer_attrs, true, std::nullopt}; } @@ -364,24 +384,25 @@ tensor_guid_t ComputationGraphBuilder::conv2d( activation, use_bias}; - std::string name = maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); - tensor_guid_t input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); + tensor_guid_t input = + this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); + + LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; - LayerAttrs layer = { - ComputationGraphOpAttrs{attrs}, - name - }; - TensorShape input_shape = this->get_shape(input); TensorShape output_shape = get_output_shape(attrs, input_shape); std::vector weights; - weights.push_back(make_weight_attrs(get_kernel_shape(attrs, input_shape), kernel_initializer)); + weights.push_back(make_weight_attrs(get_kernel_shape(attrs, input_shape), + kernel_initializer)); if (use_bias) { - weights.push_back(make_weight_attrs(get_bias_shape(attrs, input_shape), bias_initializer)); + weights.push_back(make_weight_attrs(get_bias_shape(attrs, input_shape), + bias_initializer)); } return this->add_layer(layer, {input}, weights, output_shape); @@ -393,10 +414,12 @@ tensor_guid_t ComputationGraphBuilder::dropout( unsigned long long seed, std::optional const &maybe_name) { DropoutAttrs attrs = {rate, seed}; - std::string name = maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; - tensor_guid_t input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); + tensor_guid_t input = + this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); @@ -410,25 +433,23 @@ tensor_guid_t ComputationGraphBuilder::embedding( AggregateOp aggr, DataType dtype, std::optional const &kernel_initializer, - std::optional const &maybe_name) -{ + std::optional const &maybe_name) { EmbeddingAttrs attrs = {num_entries, outDim, aggr, dtype}; - std::string name = maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; - tensor_guid_t input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); + tensor_guid_t input = + this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); TensorShape input_shape = this->get_shape(input); TensorAttrs weight_attrs = make_weight_attrs( - get_weights_shape(attrs, input_shape), - kernel_initializer - ); + get_weights_shape(attrs, input_shape), kernel_initializer); TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); - return this->add_layer( - layer, {input}, {weight_attrs}, output_shape); + return this->add_layer(layer, {input}, {weight_attrs}, output_shape); } std::vector ComputationGraphBuilder::gather( @@ -437,7 +458,8 @@ std::vector ComputationGraphBuilder::gather( ff_dim_t dim, std::optional const &maybe_name) { GatherAttrs attrs = {dim}; - std::string name = maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; if (this->get_shape(index).data_type != DataType::INT32 && @@ -454,8 +476,11 @@ std::vector ComputationGraphBuilder::gather( return this->add_layer(layer, {input}, {}, output_shapes); } -/* std::vector ComputationGraphBuilder::get_shapes(std::vector const &ts) const { */ -/* return transform(ts, [&](tensor_guid_t const &t) { return this->get_shape(t); }); */ +/* std::vector + * ComputationGraphBuilder::get_shapes(std::vector const &ts) + * const { */ +/* return transform(ts, [&](tensor_guid_t const &t) { return + * this->get_shape(t); }); */ /* } */ // tensor_guid_t ComputationGraphBuilder::aggregate( @@ -489,7 +514,8 @@ tensor_guid_t ComputationGraphBuilder::batch_norm( bool relu, std::optional const &maybe_name) { BatchNormAttrs attrs = BatchNormAttrs{relu}; - std::string name = maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; diff --git a/lib/pcg/src/pcg/machine_view.cc b/lib/pcg/src/pcg/machine_view.cc index c181a1ebbc..ff1d34852b 100644 --- a/lib/pcg/src/pcg/machine_view.cc +++ b/lib/pcg/src/pcg/machine_view.cc @@ -1,6 +1,6 @@ #include "pcg/machine_view.h" -#include "pcg/strided_rectangle_side.h" #include "pcg/strided_rectangle.dtg.h" +#include "pcg/strided_rectangle_side.h" namespace FlexFlow { @@ -12,7 +12,7 @@ std::size_t num_dims(MachineView const &) { NOT_IMPLEMENTED(); } -std::size_t num_devices(MachineView const &) { +std::size_t num_devices(MachineView const &) { NOT_IMPLEMENTED(); } @@ -23,7 +23,8 @@ DeviceType get_device_type(MachineView const &) { static StridedRectangle make_1d_rect(int start, int stop, int stride) { assert(stop > start); assert(stride > 0); - StridedRectangleSide side = strided_side_from_size_and_stride(side_size_t{stop - start}, stride); + StridedRectangleSide side = + strided_side_from_size_and_stride(side_size_t{stop - start}, stride); StridedRectangle rect = {{side}}; return rect; } @@ -59,5 +60,4 @@ MachineView make_1d_machine_view(device_id_t start, size_t interval_size) { /* return this->start + offset; */ /* } */ - } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/operator_graph/operator_graph.cc b/lib/pcg/src/pcg/operator_graph/operator_graph.cc index afb3f99e55..461fc8027c 100644 --- a/lib/pcg/src/pcg/operator_graph/operator_graph.cc +++ b/lib/pcg/src/pcg/operator_graph/operator_graph.cc @@ -4,22 +4,26 @@ namespace FlexFlow { /* struct OperatorGraphView::Impl { */ -/* MultiDiGraphView raw_graph; */ +/* MultiDiGraphView raw_graph; */ /* }; */ /* struct OperatorGraph::Impl { */ /* MultiDiGraph raw_graph; */ /* }; */ -/* std::unordered_set get_outputs(OperatorGraphView const &g) { */ -/* return transform(get_outputs(g.impl->raw_graph), [](MultiDiOutput const &o) {}); */ +/* std::unordered_set get_outputs(OperatorGraphView const + * &g) { */ +/* return transform(get_outputs(g.impl->raw_graph), [](MultiDiOutput const &o) + * {}); */ /* } */ -/* std::vector get_outputs(OperatorGraphView const &, Node const &) { */ +/* std::vector get_outputs(OperatorGraphView const &, Node + * const &) { */ /* NOT_IMPLEMENTED(); */ /* } */ -/* std::unordered_set get_uses(OperatorGraphView const &, OperatorGraphOutput const &) { */ +/* std::unordered_set get_uses(OperatorGraphView const &, + * OperatorGraphOutput const &) { */ /* NOT_IMPLEMENTED(); */ /* } */ diff --git a/lib/pcg/src/pcg/strided_rectangle.dtg.cc b/lib/pcg/src/pcg/strided_rectangle.dtg.cc index d9cb72a882..e743a2722a 100644 --- a/lib/pcg/src/pcg/strided_rectangle.dtg.cc +++ b/lib/pcg/src/pcg/strided_rectangle.dtg.cc @@ -3,7 +3,7 @@ // lib/pcg/include/pcg/strided_rectangle.struct.toml /* proj-data { - "generated_from": "87af84e6a16d5363049cb9a9a75e4f5f" + "generated_from": "817bbe017d179aa469822a4032d08836" } */ @@ -63,6 +63,15 @@ void adl_serializer::to_json( } } // namespace nlohmann +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary< + ::FlexFlow::FFOrdered<::FlexFlow::StridedRectangleSide>>()); +} +} // namespace rc + namespace FlexFlow { std::string format_as(StridedRectangle const &x) { std::ostringstream oss; diff --git a/lib/pcg/src/pcg/strided_rectangle_side.cc b/lib/pcg/src/pcg/strided_rectangle_side.cc index fad022a65b..80258886d7 100644 --- a/lib/pcg/src/pcg/strided_rectangle_side.cc +++ b/lib/pcg/src/pcg/strided_rectangle_side.cc @@ -3,7 +3,8 @@ namespace FlexFlow { -StridedRectangleSide strided_side_from_size_and_stride(side_size_t, int stride) { +StridedRectangleSide strided_side_from_size_and_stride(side_size_t, + int stride) { NOT_IMPLEMENTED(); } @@ -11,4 +12,4 @@ side_size_t get_side_size(StridedRectangleSide const &s) { return s.num_points.unwrapped * s.stride; } -} +} // namespace FlexFlow diff --git a/lib/pcg/test/src/test_computation_graph_builder.cc b/lib/pcg/test/src/test_computation_graph_builder.cc index c241d0f2c5..e88e231bd0 100644 --- a/lib/pcg/test/src/test_computation_graph_builder.cc +++ b/lib/pcg/test/src/test_computation_graph_builder.cc @@ -1,22 +1,27 @@ #include "doctest/doctest.h" -#include "pcg/computation_graph_builder.h" #include "pcg/computation_graph.h" +#include "pcg/computation_graph_builder.h" TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("ComputationGraphBuilder") { ComputationGraphBuilder b; - + size_t batch_size = 2; TensorShape input_shape = { - TensorDims{ - FFOrdered{batch_size, 3, 10, 10} - }, - DataType::FLOAT, + TensorDims{FFOrdered{batch_size, 3, 10, 10}}, + DataType::FLOAT, }; tensor_guid_t input = b.create_tensor(input_shape, /*create_grad=*/true); - tensor_guid_t output = b.conv2d(input, /*outChannels=*/5, /*kernelH=*/3, /*kernelW=*/3, /*strideH=*/1, /*strideW=*/1, /*paddingH=*/0, /*paddingW=*/0); + tensor_guid_t output = b.conv2d(input, + /*outChannels=*/5, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/1, + /*strideW=*/1, + /*paddingH=*/0, + /*paddingW=*/0); // ComputationGraph cg = b.computation_graph; // CHECK(get_layers(cg).size() == 1); } diff --git a/lib/substitution-generator/include/substitution-generator/json.h b/lib/substitution-generator/include/substitution-generator/json.h index ebffc93a76..54d923a378 100644 --- a/lib/substitution-generator/include/substitution-generator/json.h +++ b/lib/substitution-generator/include/substitution-generator/json.h @@ -78,86 +78,86 @@ NLOHMANN_JSON_SERIALIZE_ENUM(PMParameter, {PM_PARALLEL_DEGREE, "PM_PARALLEL_DEGREE"}, {PM_PAD, "PM_PAD"}}) -NLOHMANN_JSON_SERIALIZE_ENUM(Op, - {{OperatorType::NOOP, "OP_NOOP"}, - {OperatorType::CONV2D, "OP_CONV2D"}, - {OperatorType::DROPOUT, "OP_DROPOUT"}, - {OperatorType::LINEAR, "OP_LINEAR"}, - {OperatorType::BATCHMATMUL, "OP_BATCHMATMUL"}, - {OperatorType::POOL2D, "OP_POOL2D_MAX"}, - {OperatorType::SCALAR_MULTIPLY, "OP_SCALAR_MULTIPLY"}, - {OperatorType::SCALAR_ADD, "OP_SCALAR_ADD"}, - {OperatorType::SCALAR_FLOOR_DIV, "OP_SCALAR_FLOOR_DIV"}, - {OperatorType::SCALAR_TRUE_DIV, "OP_SCALAR_TRUE_DIV"}, - {OperatorType::SCALAR_SUB, "OP_SCALAR_SUB"}, - {OperatorType::RELU, "OP_RELU"}, - {OperatorType::IDENTITY, "OP_IDENTITY"}, - {OperatorType::SIGMOID, "OP_SIGMOID"}, - {OperatorType::TANH, "OP_TANH"}, - {OperatorType::ELU, "OP_ELU"}, - {OperatorType::FLAT, "OP_FLAT"}, - {OperatorType::SOFTMAX, "OP_SOFTMAX"}, - {OperatorType::BATCHNORM, "OP_BATCHNORM"}, - {OperatorType::CONCAT, "OP_CONCAT"}, - {OperatorType::SPLIT, "OP_SPLIT"}, - {OperatorType::EMBEDDING, "OP_EMBEDDING"}, - {OperatorType::CACHE, "OP_CACHE"}, - {OperatorType::RESHAPE, "OP_RESHAPE"}, - {OperatorType::REVERSE, "OP_REVERSE"}, - {OperatorType::TRANSPOSE, "OP_TRANSPOSE"}, - {OperatorType::EW_ADD, "OP_EW_ADD"}, - {OperatorType::EW_MUL, "OP_EW_MUL"}, - {OperatorType::MATMUL, "OP_MATMUL"}, - {OperatorType::MUL, "OP_MUL"}, - {OperatorType::ENLARGE, "OP_ENLARGE"}, - {OperatorType::SQUEEZE, "OP_SQUEEZE"}, - {OperatorType::UNSQUEEZE, "OP_UNSQUEEZE"}, - {OperatorType::EW_SUB, "OP_EW_SUB"}, - {OperatorType::EW_DIV, "OP_EW_DIV"}, - {OperatorType::EW_EQUAL, "OP_EW_EQUAL"}, - {OperatorType::EW_GREATER, "OP_EW_GREATER"}, - {OperatorType::EW_LESS, "OP_EW_LESS"}, - {OperatorType::EW_MAX, "OP_EW_MAX"}, - {OperatorType::EW_MIN, "OP_EW_MIN"}, - {OperatorType::REDUCE_ARGMAX, "OP_REDUCE_ARGMAX"}, - {OperatorType::REDUCE_ARGMIN, "OP_REDUCE_ARGMIN"}, - {OperatorType::REDUCE_MAX, "OP_REDUCE_MAX"}, - {OperatorType::REDUCE_MEAN, "OP_REDUCE_MEAN"}, - {OperatorType::REDUCE_MIN, "OP_REDUCE_MIN"}, - {OperatorType::REDUCE_PROD, "OP_REDUCE_PROD"}, - {OperatorType::REDUCE_SUM, "OP_REDUCE_SUM"}, - {OperatorType::PAD, "OP_PAD"}, - {OperatorType::SHAPE, "OP_SHAPE"}, - {OperatorType::SIZE, "OP_SIZE"}, - {OperatorType::TOPK, "OP_TOPK"}, - {OperatorType::WHERE, "OP_WHERE"}, - {OperatorType::CEIL, "OP_CEIL"}, - {OperatorType::CAST, "OP_CAST"}, - {OperatorType::EXP, "OP_EXP"}, - {OperatorType::ROUND, "OP_ROUND"}, - {OperatorType::LOG, "OP_LOG"}, - {OperatorType::LOGICAL_NOT, "OP_LOGICAL_NOT"}, - {OperatorType::SQRT, "OP_SQRT"}, - {OperatorType::SIN, "OP_SIN"}, - {OperatorType::COS, "OP_COS"}, - {OperatorType::LEAKYRELU, "OP_LEAKYRELU"}, - {OperatorType::SLICE, "OP_SLICE"}, - {OperatorType::RESIZE, "OP_RESIZE"}, - {OperatorType::PRELU, "OP_PRELU"}, - {OperatorType::GELU, "OP_GELU"}, - {OperatorType::MULTIHEAD_ATTENTION, - "OP_MULTIHEAD_ATTENTION"}, - {OperatorType::FUSED, "OP_FUSED"}, - {OperatorType::RSQRT, "OP_RSQRT"}, - {OperatorType::POW, "OP_POW"}, - {OperatorType::MEAN, "OP_MEAN"}, - {OperatorType::LAYERNORM, "OP_LAYERNORM"}, - {OperatorType::REPARTITION, "OP_PARTITION"}, - {OperatorType::COMBINE, "OP_COMBINE"}, - {OperatorType::REPLICATE, "OP_REPLICATE"}, - {OperatorType::REDUCTION, "OP_REDUCE"}, - {OperatorType::PIPELINE, "OP_PIPELINE"}, - {OperatorType::FUSED_PARALLEL, "OP_FUSED_PARALLEL"}}) +NLOHMANN_JSON_SERIALIZE_ENUM( + Op, + {{OperatorType::NOOP, "OP_NOOP"}, + {OperatorType::CONV2D, "OP_CONV2D"}, + {OperatorType::DROPOUT, "OP_DROPOUT"}, + {OperatorType::LINEAR, "OP_LINEAR"}, + {OperatorType::BATCHMATMUL, "OP_BATCHMATMUL"}, + {OperatorType::POOL2D, "OP_POOL2D_MAX"}, + {OperatorType::SCALAR_MULTIPLY, "OP_SCALAR_MULTIPLY"}, + {OperatorType::SCALAR_ADD, "OP_SCALAR_ADD"}, + {OperatorType::SCALAR_FLOOR_DIV, "OP_SCALAR_FLOOR_DIV"}, + {OperatorType::SCALAR_TRUE_DIV, "OP_SCALAR_TRUE_DIV"}, + {OperatorType::SCALAR_SUB, "OP_SCALAR_SUB"}, + {OperatorType::RELU, "OP_RELU"}, + {OperatorType::IDENTITY, "OP_IDENTITY"}, + {OperatorType::SIGMOID, "OP_SIGMOID"}, + {OperatorType::TANH, "OP_TANH"}, + {OperatorType::ELU, "OP_ELU"}, + {OperatorType::FLAT, "OP_FLAT"}, + {OperatorType::SOFTMAX, "OP_SOFTMAX"}, + {OperatorType::BATCHNORM, "OP_BATCHNORM"}, + {OperatorType::CONCAT, "OP_CONCAT"}, + {OperatorType::SPLIT, "OP_SPLIT"}, + {OperatorType::EMBEDDING, "OP_EMBEDDING"}, + {OperatorType::CACHE, "OP_CACHE"}, + {OperatorType::RESHAPE, "OP_RESHAPE"}, + {OperatorType::REVERSE, "OP_REVERSE"}, + {OperatorType::TRANSPOSE, "OP_TRANSPOSE"}, + {OperatorType::EW_ADD, "OP_EW_ADD"}, + {OperatorType::EW_MUL, "OP_EW_MUL"}, + {OperatorType::MATMUL, "OP_MATMUL"}, + {OperatorType::MUL, "OP_MUL"}, + {OperatorType::ENLARGE, "OP_ENLARGE"}, + {OperatorType::SQUEEZE, "OP_SQUEEZE"}, + {OperatorType::UNSQUEEZE, "OP_UNSQUEEZE"}, + {OperatorType::EW_SUB, "OP_EW_SUB"}, + {OperatorType::EW_DIV, "OP_EW_DIV"}, + {OperatorType::EW_EQUAL, "OP_EW_EQUAL"}, + {OperatorType::EW_GREATER, "OP_EW_GREATER"}, + {OperatorType::EW_LESS, "OP_EW_LESS"}, + {OperatorType::EW_MAX, "OP_EW_MAX"}, + {OperatorType::EW_MIN, "OP_EW_MIN"}, + {OperatorType::REDUCE_ARGMAX, "OP_REDUCE_ARGMAX"}, + {OperatorType::REDUCE_ARGMIN, "OP_REDUCE_ARGMIN"}, + {OperatorType::REDUCE_MAX, "OP_REDUCE_MAX"}, + {OperatorType::REDUCE_MEAN, "OP_REDUCE_MEAN"}, + {OperatorType::REDUCE_MIN, "OP_REDUCE_MIN"}, + {OperatorType::REDUCE_PROD, "OP_REDUCE_PROD"}, + {OperatorType::REDUCE_SUM, "OP_REDUCE_SUM"}, + {OperatorType::PAD, "OP_PAD"}, + {OperatorType::SHAPE, "OP_SHAPE"}, + {OperatorType::SIZE, "OP_SIZE"}, + {OperatorType::TOPK, "OP_TOPK"}, + {OperatorType::WHERE, "OP_WHERE"}, + {OperatorType::CEIL, "OP_CEIL"}, + {OperatorType::CAST, "OP_CAST"}, + {OperatorType::EXP, "OP_EXP"}, + {OperatorType::ROUND, "OP_ROUND"}, + {OperatorType::LOG, "OP_LOG"}, + {OperatorType::LOGICAL_NOT, "OP_LOGICAL_NOT"}, + {OperatorType::SQRT, "OP_SQRT"}, + {OperatorType::SIN, "OP_SIN"}, + {OperatorType::COS, "OP_COS"}, + {OperatorType::LEAKYRELU, "OP_LEAKYRELU"}, + {OperatorType::SLICE, "OP_SLICE"}, + {OperatorType::RESIZE, "OP_RESIZE"}, + {OperatorType::PRELU, "OP_PRELU"}, + {OperatorType::GELU, "OP_GELU"}, + {OperatorType::MULTIHEAD_ATTENTION, "OP_MULTIHEAD_ATTENTION"}, + {OperatorType::FUSED, "OP_FUSED"}, + {OperatorType::RSQRT, "OP_RSQRT"}, + {OperatorType::POW, "OP_POW"}, + {OperatorType::MEAN, "OP_MEAN"}, + {OperatorType::LAYERNORM, "OP_LAYERNORM"}, + {OperatorType::REPARTITION, "OP_PARTITION"}, + {OperatorType::COMBINE, "OP_COMBINE"}, + {OperatorType::REPLICATE, "OP_REPLICATE"}, + {OperatorType::REDUCTION, "OP_REDUCE"}, + {OperatorType::PIPELINE, "OP_PIPELINE"}, + {OperatorType::FUSED_PARALLEL, "OP_FUSED_PARALLEL"}}) struct Parameter { PMParameter key; diff --git a/lib/substitutions/include/substitutions/graph_pattern.h b/lib/substitutions/include/substitutions/graph_pattern.h index 533109387b..5f03a6e92e 100644 --- a/lib/substitutions/include/substitutions/graph_pattern.h +++ b/lib/substitutions/include/substitutions/graph_pattern.h @@ -1,19 +1,21 @@ #ifndef _FLEXFLOW_SUBSTITUTIONS_SUBSTITUTIONS_H #define _FLEXFLOW_SUBSTITUTIONS_SUBSTITUTIONS_H -#include "substitutions/sub_parallel_computation_graph.dtg.h" #include "substitutions/pcg_pattern.dtg.h" +#include "substitutions/sub_parallel_computation_graph.dtg.h" #include "substitutions/unlabelled/pattern_edge.dtg.h" -#include "substitutions/unlabelled/pattern_node.dtg.h" #include "substitutions/unlabelled/pattern_matching.h" +#include "substitutions/unlabelled/pattern_node.dtg.h" #include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" namespace FlexFlow { UnlabelledGraphPattern get_unlabelled_pattern(PCGPattern const &); -TensorAttributePattern get_tensor_pattern(PCGPattern const &, PatternEdge const &); -OperatorAttributePattern get_operator_pattern(PCGPattern const &, PatternNode const &); +TensorAttributePattern get_tensor_pattern(PCGPattern const &, + PatternEdge const &); +OperatorAttributePattern get_operator_pattern(PCGPattern const &, + PatternNode const &); bool assignment_satisfies(SubParallelComputationGraph const &, PCGPattern const &, diff --git a/lib/substitutions/include/substitutions/operator_pattern/eval_list_access.h b/lib/substitutions/include/substitutions/operator_pattern/eval_list_access.h index 777f38edea..93d2d56384 100644 --- a/lib/substitutions/include/substitutions/operator_pattern/eval_list_access.h +++ b/lib/substitutions/include/substitutions/operator_pattern/eval_list_access.h @@ -1,14 +1,16 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_LIST_ACCESS_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_LIST_ACCESS_H +#include "op-attrs/pcg_operator_attrs.dtg.h" #include "substitutions/operator_pattern/operator_attribute_list_access.dtg.h" #include "substitutions/operator_pattern/operator_attribute_value.dtg.h" -#include "op-attrs/pcg_operator_attrs.dtg.h" #include namespace FlexFlow { -std::optional eval_list_access(PCGOperatorAttrs const &attrs, OperatorAttributeListIndexAccess const &); +std::optional + eval_list_access(PCGOperatorAttrs const &attrs, + OperatorAttributeListIndexAccess const &); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/operator_pattern/eval_list_size.h b/lib/substitutions/include/substitutions/operator_pattern/eval_list_size.h index 337799955b..236a248945 100644 --- a/lib/substitutions/include/substitutions/operator_pattern/eval_list_size.h +++ b/lib/substitutions/include/substitutions/operator_pattern/eval_list_size.h @@ -1,13 +1,15 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_LIST_SIZE_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_LIST_SIZE_H +#include "op-attrs/pcg_operator_attrs.dtg.h" #include "substitutions/operator_pattern/operator_attribute_list_size.dtg.h" #include "substitutions/operator_pattern/operator_attribute_value.dtg.h" -#include "op-attrs/pcg_operator_attrs.dtg.h" namespace FlexFlow { -std::optional eval_list_size(PCGOperatorAttrs const &attrs, OperatorAttributeListSize const &acc); +std::optional + eval_list_size(PCGOperatorAttrs const &attrs, + OperatorAttributeListSize const &acc); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.h index f37ad64df0..4528847771 100644 --- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.h +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_EXPR_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_EXPR_H +#include "pcg/parallel_layer_attrs.dtg.h" #include "substitutions/operator_pattern/operator_attribute_expr.dtg.h" #include "substitutions/operator_pattern/operator_attribute_value.dtg.h" -#include "pcg/parallel_layer_attrs.dtg.h" #include namespace FlexFlow { diff --git a/lib/substitutions/include/substitutions/operator_pattern/satisfies_constraint.h b/lib/substitutions/include/substitutions/operator_pattern/satisfies_constraint.h index 2ac45af0be..7ddda2219c 100644 --- a/lib/substitutions/include/substitutions/operator_pattern/satisfies_constraint.h +++ b/lib/substitutions/include/substitutions/operator_pattern/satisfies_constraint.h @@ -1,12 +1,14 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_SATISFIES_CONSTRAINT_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_SATISFIES_CONSTRAINT_H -#include "substitutions/operator_pattern/operator_attribute_constraint.dtg.h" #include "op-attrs/pcg_operator_attrs.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_constraint.dtg.h" namespace FlexFlow { -bool operator_satisfies_constraint(PCGOperatorAttrs const ¶ms, OperatorAttributeConstraint const &constraint); +bool operator_satisfies_constraint( + PCGOperatorAttrs const ¶ms, + OperatorAttributeConstraint const &constraint); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/operator_pattern/satisfies_pattern.h b/lib/substitutions/include/substitutions/operator_pattern/satisfies_pattern.h index f33e027777..ca4d5c13fa 100644 --- a/lib/substitutions/include/substitutions/operator_pattern/satisfies_pattern.h +++ b/lib/substitutions/include/substitutions/operator_pattern/satisfies_pattern.h @@ -6,7 +6,8 @@ namespace FlexFlow { -bool operator_satisfies_pattern(PCGOperatorAttrs const &attrs, OperatorAttributePattern const &pattern); +bool operator_satisfies_pattern(PCGOperatorAttrs const &attrs, + OperatorAttributePattern const &pattern); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h index e58502a745..5d40f3f975 100644 --- a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h @@ -5,9 +5,13 @@ namespace FlexFlow { -ParallelLayerAttrs get_parallel_layer_attrs(SubParallelComputationGraph const &, Node const &); -PCGOperatorAttrs get_operator_attrs(SubParallelComputationGraph const &, Node const &); -ParallelTensorAttrs get_parallel_tensor_attrs(SubParallelComputationGraph const &, OpenMultiDiEdge const &); +ParallelLayerAttrs get_parallel_layer_attrs(SubParallelComputationGraph const &, + Node const &); +PCGOperatorAttrs get_operator_attrs(SubParallelComputationGraph const &, + Node const &); +ParallelTensorAttrs + get_parallel_tensor_attrs(SubParallelComputationGraph const &, + OpenMultiDiEdge const &); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/tensor_pattern/eval_list_access.h b/lib/substitutions/include/substitutions/tensor_pattern/eval_list_access.h index 78af6c7405..e245e800b2 100644 --- a/lib/substitutions/include/substitutions/tensor_pattern/eval_list_access.h +++ b/lib/substitutions/include/substitutions/tensor_pattern/eval_list_access.h @@ -1,13 +1,14 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_EVAL_LIST_ACCESS_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_EVAL_LIST_ACCESS_H +#include "pcg/parallel_tensor_attrs.dtg.h" #include "substitutions/tensor_pattern/tensor_attribute_list_access.dtg.h" #include "substitutions/tensor_pattern/tensor_attribute_value.dtg.h" -#include "pcg/parallel_tensor_attrs.dtg.h" namespace FlexFlow { -TensorAttributeValue eval_list_access(ParallelTensorAttrs const &attrs, TensorAttributeListIndexAccess const &); +TensorAttributeValue eval_list_access(ParallelTensorAttrs const &attrs, + TensorAttributeListIndexAccess const &); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/tensor_pattern/eval_list_size.h b/lib/substitutions/include/substitutions/tensor_pattern/eval_list_size.h index 863cb81239..de0d58e14f 100644 --- a/lib/substitutions/include/substitutions/tensor_pattern/eval_list_size.h +++ b/lib/substitutions/include/substitutions/tensor_pattern/eval_list_size.h @@ -1,13 +1,14 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_EVAL_LIST_SIZE_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_EVAL_LIST_SIZE_H +#include "pcg/parallel_tensor_attrs.dtg.h" #include "substitutions/tensor_pattern/tensor_attribute_list_size.dtg.h" #include "substitutions/tensor_pattern/tensor_attribute_value.dtg.h" -#include "pcg/parallel_tensor_attrs.dtg.h" namespace FlexFlow { -TensorAttributeValue eval_list_size(ParallelTensorAttrs const &attrs, TensorAttributeListSize const &); +TensorAttributeValue eval_list_size(ParallelTensorAttrs const &attrs, + TensorAttributeListSize const &); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/tensor_pattern/get_attribute.h b/lib/substitutions/include/substitutions/tensor_pattern/get_attribute.h index f276fcbd3a..eedca2da82 100644 --- a/lib/substitutions/include/substitutions/tensor_pattern/get_attribute.h +++ b/lib/substitutions/include/substitutions/tensor_pattern/get_attribute.h @@ -1,13 +1,14 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_GET_ATTRIBUTE_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_GET_ATTRIBUTE_H -#include "substitutions/tensor_pattern/tensor_attribute_value.dtg.h" -#include "substitutions/tensor_pattern/tensor_attribute_key.dtg.h" #include "pcg/parallel_tensor_attrs.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_key.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_value.dtg.h" namespace FlexFlow { -TensorAttributeValue get_attribute(ParallelTensorAttrs const &, TensorAttributeKey); +TensorAttributeValue get_attribute(ParallelTensorAttrs const &, + TensorAttributeKey); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/tensor_pattern/satisfies_constraint.h b/lib/substitutions/include/substitutions/tensor_pattern/satisfies_constraint.h index 2e15e604f8..6c11b421a8 100644 --- a/lib/substitutions/include/substitutions/tensor_pattern/satisfies_constraint.h +++ b/lib/substitutions/include/substitutions/tensor_pattern/satisfies_constraint.h @@ -1,12 +1,14 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_SATISFIES_CONSTRAINT_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_SATISFIES_CONSTRAINT_H -#include "substitutions/tensor_pattern/tensor_attribute_constraint.dtg.h" #include "pcg/parallel_tensor_attrs.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_constraint.dtg.h" namespace FlexFlow { -bool parallel_tensor_satisfies_constraint(ParallelTensorAttrs const ¶ms, TensorAttributeConstraint const &constraint); +bool parallel_tensor_satisfies_constraint( + ParallelTensorAttrs const ¶ms, + TensorAttributeConstraint const &constraint); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/tensor_pattern/satisfies_pattern.h b/lib/substitutions/include/substitutions/tensor_pattern/satisfies_pattern.h index 8defca7e50..b8b46669c6 100644 --- a/lib/substitutions/include/substitutions/tensor_pattern/satisfies_pattern.h +++ b/lib/substitutions/include/substitutions/tensor_pattern/satisfies_pattern.h @@ -1,12 +1,13 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_SATISFIES_PATTERN_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_SATISFIES_PATTERN_H -#include "substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h" #include "pcg/parallel_tensor_attrs.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h" namespace FlexFlow { -bool parallel_tensor_satisfies_pattern(ParallelTensorAttrs const &attrs, TensorAttributePattern const &pattern); +bool parallel_tensor_satisfies_pattern(ParallelTensorAttrs const &attrs, + TensorAttributePattern const &pattern); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.h b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.h index 12515d2716..98d4394530 100644 --- a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.h +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.h @@ -1,15 +1,14 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_EXPR_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_EXPR_H -#include "substitutions/tensor_pattern/tensor_attribute_value.dtg.h" -#include "substitutions/tensor_pattern/tensor_attribute_expr.dtg.h" #include "pcg/parallel_tensor_attrs.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_expr.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_value.dtg.h" namespace FlexFlow { -TensorAttributeValue - evaluate_attribute_expr(ParallelTensorAttrs const &attrs, - TensorAttributeExpr const &expr); +TensorAttributeValue evaluate_attribute_expr(ParallelTensorAttrs const &attrs, + TensorAttributeExpr const &expr); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/unlabelled/edge_splits.h b/lib/substitutions/include/substitutions/unlabelled/edge_splits.h index 76135f34db..58704500ac 100644 --- a/lib/substitutions/include/substitutions/unlabelled/edge_splits.h +++ b/lib/substitutions/include/substitutions/unlabelled/edge_splits.h @@ -1,17 +1,20 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_EDGE_SPLITS_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_EDGE_SPLITS_H -#include -#include "substitutions/unlabelled/edge_splits.dtg.h" #include "substitutions/unlabelled/closed_pattern_edge.dtg.h" +#include "substitutions/unlabelled/edge_splits.dtg.h" #include "substitutions/unlabelled/input_pattern_edge.dtg.h" #include "substitutions/unlabelled/output_pattern_edge.dtg.h" +#include namespace FlexFlow { -std::pair get_split_edges(UnlabelledPatternEdgeSplits const &, ClosedPatternEdge const &); +std::pair + get_split_edges(UnlabelledPatternEdgeSplits const &, + ClosedPatternEdge const &); -std::vector> as_closed_output_input_tuples(UnlabelledPatternEdgeSplits const &); +std::vector> + as_closed_output_input_tuples(UnlabelledPatternEdgeSplits const &); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/unlabelled/find_pattern_matches.h b/lib/substitutions/include/substitutions/unlabelled/find_pattern_matches.h index 95f393109c..29c5740c0e 100644 --- a/lib/substitutions/include/substitutions/unlabelled/find_pattern_matches.h +++ b/lib/substitutions/include/substitutions/unlabelled/find_pattern_matches.h @@ -1,10 +1,10 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_FIND_PATTERN_MATCHES_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_FIND_PATTERN_MATCHES_H -#include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" -#include "utils/graph.h" #include "substitutions/unlabelled/match_additional_criterion.dtg.h" #include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h" +#include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" +#include "utils/graph.h" namespace FlexFlow { diff --git a/lib/substitutions/include/substitutions/unlabelled/match_split.h b/lib/substitutions/include/substitutions/unlabelled/match_split.h index 221805daa9..a23bc3f89a 100644 --- a/lib/substitutions/include/substitutions/unlabelled/match_split.h +++ b/lib/substitutions/include/substitutions/unlabelled/match_split.h @@ -3,8 +3,8 @@ #include "substitutions/unlabelled/match_split.dtg.h" #include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h" -#include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" #include "substitutions/unlabelled/pattern_split.dtg.h" +#include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" namespace FlexFlow { diff --git a/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.h b/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.h index 550d4249f4..aacae6d42a 100644 --- a/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.h +++ b/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.h @@ -1,16 +1,16 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MULTIDIGRAPH_PATTERN_MATCH_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MULTIDIGRAPH_PATTERN_MATCH_H -#include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h" #include "substitutions/unlabelled/edge_splits.dtg.h" +#include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h" namespace FlexFlow { MultiDiGraphPatternMatch empty_multidigraph_pattern_match(); -std::optional unsplit_matches( - MultiDiGraphPatternMatch const &prefix, - MultiDiGraphPatternMatch const &postfix, - UnlabelledPatternEdgeSplits const &edge_splits); +std::optional + unsplit_matches(MultiDiGraphPatternMatch const &prefix, + MultiDiGraphPatternMatch const &postfix, + UnlabelledPatternEdgeSplits const &edge_splits); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_edge.h b/lib/substitutions/include/substitutions/unlabelled/pattern_edge.h index 689f46012b..79db533d4e 100644 --- a/lib/substitutions/include/substitutions/unlabelled/pattern_edge.h +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_edge.h @@ -1,11 +1,11 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_PATTERN_EDGE_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_PATTERN_EDGE_H -#include "substitutions/unlabelled/pattern_node.dtg.h" -#include "substitutions/unlabelled/pattern_edge.dtg.h" +#include "substitutions/unlabelled/closed_pattern_edge.dtg.h" #include "substitutions/unlabelled/input_pattern_edge.dtg.h" #include "substitutions/unlabelled/output_pattern_edge.dtg.h" -#include "substitutions/unlabelled/closed_pattern_edge.dtg.h" +#include "substitutions/unlabelled/pattern_edge.dtg.h" +#include "substitutions/unlabelled/pattern_node.dtg.h" namespace FlexFlow { diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_matching.h b/lib/substitutions/include/substitutions/unlabelled/pattern_matching.h index aee90413d5..223886b411 100644 --- a/lib/substitutions/include/substitutions/unlabelled/pattern_matching.h +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_matching.h @@ -1,18 +1,19 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_MATCHING_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_MATCHING_H -#include "utils/graph.h" -#include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h" -#include "substitutions/unlabelled/match_split.dtg.h" #include "substitutions/unlabelled/match_additional_criterion.dtg.h" +#include "substitutions/unlabelled/match_split.dtg.h" +#include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h" #include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" +#include "utils/graph.h" namespace FlexFlow { -bool unlabelled_pattern_does_match(UnlabelledGraphPattern const &pattern, - OpenMultiDiGraphView const &graph, - MultiDiGraphPatternMatch const &match, - MatchAdditionalCriterion const &additional_criterion); +bool unlabelled_pattern_does_match( + UnlabelledGraphPattern const &pattern, + OpenMultiDiGraphView const &graph, + MultiDiGraphPatternMatch const &match, + MatchAdditionalCriterion const &additional_criterion); std::vector find_pattern_matches(UnlabelledGraphPattern const &pattern, diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_split.h b/lib/substitutions/include/substitutions/unlabelled/pattern_split.h index 50d3d37eb8..3fcc5cb12f 100644 --- a/lib/substitutions/include/substitutions/unlabelled/pattern_split.h +++ b/lib/substitutions/include/substitutions/unlabelled/pattern_split.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_SPLIT_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_SPLIT_H -#include "substitutions/unlabelled/pattern_split.dtg.h" #include "substitutions/unlabelled/edge_splits.dtg.h" +#include "substitutions/unlabelled/pattern_split.dtg.h" #include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" namespace FlexFlow { @@ -11,10 +11,12 @@ PatternSplit find_even_split(UnlabelledGraphPattern const &); GraphSplit get_raw_split(PatternSplit const &); -UnlabelledPatternEdgeSplits get_edge_splits(UnlabelledGraphPattern const &pattern, PatternSplit const &split); +UnlabelledPatternEdgeSplits + get_edge_splits(UnlabelledGraphPattern const &pattern, + PatternSplit const &split); std::pair - apply_split(UnlabelledGraphPattern const &, PatternSplit const &); + apply_split(UnlabelledGraphPattern const &, PatternSplit const &); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.h b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.h index 822e51588a..9bb63037be 100644 --- a/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.h +++ b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.h @@ -1,11 +1,11 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UNLABELLED_GRAPH_PATTERN_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UNLABELLED_GRAPH_PATTERN_H -#include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" -#include "substitutions/unlabelled/pattern_node.dtg.h" +#include "substitutions/unlabelled/downward_open_pattern_edge.dtg.h" #include "substitutions/unlabelled/pattern_edge.dtg.h" +#include "substitutions/unlabelled/pattern_node.dtg.h" +#include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" #include "substitutions/unlabelled/upward_open_pattern_edge.dtg.h" -#include "substitutions/unlabelled/downward_open_pattern_edge.dtg.h" namespace FlexFlow { @@ -13,12 +13,16 @@ size_t num_nodes(UnlabelledGraphPattern const &); bool is_singleton_pattern(UnlabelledGraphPattern const &); std::unordered_set get_nodes(UnlabelledGraphPattern const &); std::unordered_set get_edges(UnlabelledGraphPattern const &); -std::vector get_topological_ordering(UnlabelledGraphPattern const &); +std::vector + get_topological_ordering(UnlabelledGraphPattern const &); -std::unordered_set get_incoming_edges(UnlabelledGraphPattern const &, PatternNode const &); -std::unordered_set get_outgoing_edges(UnlabelledGraphPattern const &, PatternNode const &); +std::unordered_set + get_incoming_edges(UnlabelledGraphPattern const &, PatternNode const &); +std::unordered_set + get_outgoing_edges(UnlabelledGraphPattern const &, PatternNode const &); -UnlabelledGraphPattern get_subgraph(UnlabelledGraphPattern const &, std::unordered_set const &); +UnlabelledGraphPattern get_subgraph(UnlabelledGraphPattern const &, + std::unordered_set const &); } // namespace FlexFlow diff --git a/lib/substitutions/src/substitution.cc b/lib/substitutions/src/substitution.cc index 20e14c2256..94993f3c90 100644 --- a/lib/substitutions/src/substitution.cc +++ b/lib/substitutions/src/substitution.cc @@ -10,7 +10,8 @@ namespace FlexFlow { /* } */ /* std::unordered_set> */ -/* derive_valid_operator_attribute_expr(OperatorAttributeKey const &key) { */ +/* derive_valid_operator_attribute_expr(OperatorAttributeKey const &key) { + */ /* return {key}; */ /* } */ @@ -30,8 +31,10 @@ namespace FlexFlow { /* std::unordered_set> */ /* get_valid_operator_attribute_exprs(OperatorPattern const &pattern) { */ /* return set_union(transform( */ -/* pattern.attribute_constraints, [](OperatorAttributeConstraint const &t) { */ -/* return visit(DeriveValidOperatorAttributeExpr{}, t.attribute_expr); */ +/* pattern.attribute_constraints, [](OperatorAttributeConstraint const &t) + * { */ +/* return visit(DeriveValidOperatorAttributeExpr{}, t.attribute_expr); + */ /* })); */ /* } */ @@ -50,7 +53,8 @@ namespace FlexFlow { /* } */ /* bool is_valid(OperatorAttrAccess const &t) const { */ -/* return is_valid_operator_attribute_expr(graph_pattern.value().at(t.node), */ +/* return is_valid_operator_attribute_expr(graph_pattern.value().at(t.node), + */ /* t.attr_expr); */ /* } */ @@ -88,7 +92,8 @@ namespace FlexFlow { /* OperatorAttributeValue evaluate(OperatorAttrAccess const &t) { */ /* Node node_in_pattern = t.node; */ /* Node node_in_pcg = match.node_assignment.at_l(node_in_pattern); */ -/* return evaluate_attribute_expr(graph.at(node_in_pcg), t.attr_expr).value(); */ +/* return evaluate_attribute_expr(graph.at(node_in_pcg), + * t.attr_expr).value(); */ /* } */ /* OperatorAttributeValue evaluate(AttrConstant const &t) { */ @@ -106,7 +111,8 @@ namespace FlexFlow { /* Operator get_operator_attrs(SubParallelComputationGraph const &graph, */ /* MultiDiGraphPatternMatch const &match, */ /* OperatorAttrAssignment const &assignment) { */ -/* std::unordered_map assignments; */ +/* std::unordered_map + * assignments; */ /* for (auto const &[key, expr] : assignment.assignments) { */ /* OperatorAttributeValue value = */ /* evaluate_graph_attribute_expr(graph, match, expr); */ @@ -116,7 +122,8 @@ namespace FlexFlow { /* assert(std::holds_alternative( */ /* assignments.at(OperatorAttributeKey::OP_TYPE))); */ /* OperatorType op_type = */ -/* std::get(assignments.at(OperatorAttributeKey::OP_TYPE)); */ +/* std::get(assignments.at(OperatorAttributeKey::OP_TYPE)); + */ /* switch (op_type) { */ /* case OperatorType::BATCHMATMUL: */ /* return Operator{ */ @@ -131,33 +138,45 @@ namespace FlexFlow { /* std::nullopt}; */ /* case OperatorType::CAST: */ /* return Operator{CastAttrs{std::get( */ -/* assignments.at(OperatorAttributeKey::DATA_TYPE))}, */ +/* assignments.at(OperatorAttributeKey::DATA_TYPE))}, + */ /* std::nullopt}; */ /* case OperatorType::CONCAT: */ /* return Operator{ */ /* ConcatAttrs{ */ -/* std::get(assignments.at(OperatorAttributeKey::AXIS)), */ -/* std::get(assignments.at(OperatorAttributeKey::NUM_INPUTS))}, */ +/* std::get(assignments.at(OperatorAttributeKey::AXIS)), + */ +/* std::get(assignments.at(OperatorAttributeKey::NUM_INPUTS))}, + */ /* std::nullopt}; */ /* case OperatorType::CONV2D: */ /* return Operator{ */ /* Conv2DAttrs{ */ -/* std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), */ -/* std::get(assignments.at(OperatorAttributeKey::KERNEL_H)), */ -/* std::get(assignments.at(OperatorAttributeKey::KERNEL_W)), */ -/* std::get(assignments.at(OperatorAttributeKey::STRIDE_H)), */ -/* std::get(assignments.at(OperatorAttributeKey::STRIDE_W)), */ -/* std::get(assignments.at(OperatorAttributeKey::PADDING_H)), */ -/* std::get(assignments.at(OperatorAttributeKey::PADDING_W)), */ +/* std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), + */ +/* std::get(assignments.at(OperatorAttributeKey::KERNEL_H)), + */ +/* std::get(assignments.at(OperatorAttributeKey::KERNEL_W)), + */ +/* std::get(assignments.at(OperatorAttributeKey::STRIDE_H)), + */ +/* std::get(assignments.at(OperatorAttributeKey::STRIDE_W)), + */ +/* std::get(assignments.at(OperatorAttributeKey::PADDING_H)), + */ +/* std::get(assignments.at(OperatorAttributeKey::PADDING_W)), + */ /* std::get(assignments.at(OperatorAttributeKey::GROUPS)), */ /* std::get( */ /* assignments.at(OperatorAttributeKey::ACTIVATION)), */ -/* std::get(assignments.at(OperatorAttributeKey::USE_BIAS))}, */ +/* std::get(assignments.at(OperatorAttributeKey::USE_BIAS))}, + */ /* std::nullopt}; */ /* case OperatorType::DROPOUT: */ /* return Operator{DropoutAttrs{std::get(assignments.at( */ /* OperatorAttributeKey::RATE)), */ -/* std::get(assignments.at( */ +/* std::get(assignments.at( */ /* OperatorAttributeKey::SEED))}, */ /* std::nullopt}; */ /* case OperatorType::EW_ADD: */ @@ -174,9 +193,11 @@ namespace FlexFlow { /* std::get(assignments.at( */ /* OperatorAttributeKey::DATA_TYPE)), */ /* std::get(assignments.at( */ -/* OperatorAttributeKey::SHOULD_BROADCAST_LHS)), */ +/* OperatorAttributeKey::SHOULD_BROADCAST_LHS)), + */ /* std::get(assignments.at( */ -/* OperatorAttributeKey::SHOULD_BROADCAST_RHS))}, */ +/* OperatorAttributeKey::SHOULD_BROADCAST_RHS))}, + */ /* std::nullopt}; */ /* case OperatorType::SCALAR_ADD: */ /* case OperatorType::SCALAR_FLOOR_DIV: */ @@ -186,7 +207,8 @@ namespace FlexFlow { /* return Operator{ */ /* ElementScalarUnaryAttrs{ */ /* op_type, */ -/* std::get(assignments.at(OperatorAttributeKey::SCALAR))}, */ +/* std::get(assignments.at(OperatorAttributeKey::SCALAR))}, + */ /* std::nullopt}; */ /* case OperatorType::EXP: */ /* case OperatorType::IDENTITY: */ @@ -199,9 +221,12 @@ namespace FlexFlow { /* case OperatorType::EMBEDDING: */ /* return Operator{ */ /* EmbeddingAttrs{ */ -/* std::get(assignments.at(OperatorAttributeKey::NUM_ENTRIES)), */ -/* std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), */ -/* std::get(assignments.at(OperatorAttributeKey::AGGR)), */ +/* std::get(assignments.at(OperatorAttributeKey::NUM_ENTRIES)), + */ +/* std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), + */ +/* std::get(assignments.at(OperatorAttributeKey::AGGR)), + */ /* std::get( */ /* assignments.at(OperatorAttributeKey::OP_TYPE))}, */ /* std::nullopt}; */ @@ -219,14 +244,18 @@ namespace FlexFlow { /* std::get>( */ /* assignments.at(OperatorAttributeKey::AXES)), */ /* std::get( */ -/* assignments.at(OperatorAttributeKey::ELEMENTWISE_AFFINE)), */ -/* std::get(assignments.at(OperatorAttributeKey::EPSILON))}, */ +/* assignments.at(OperatorAttributeKey::ELEMENTWISE_AFFINE)), + */ +/* std::get(assignments.at(OperatorAttributeKey::EPSILON))}, + */ /* std::nullopt}; */ /* case OperatorType::LINEAR: */ /* return Operator{ */ /* LinearAttrs{ */ -/* std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), */ -/* std::get(assignments.at(OperatorAttributeKey::USE_BIAS)), */ +/* std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), + */ +/* std::get(assignments.at(OperatorAttributeKey::USE_BIAS)), + */ /* std::get( */ /* assignments.at(OperatorAttributeKey::DATA_TYPE)), */ /* std::get( */ @@ -237,13 +266,18 @@ namespace FlexFlow { /* case OperatorType::MULTIHEAD_ATTENTION: */ /* return Operator{ */ /* MultiHeadAttentionAttrs{ */ -/* std::get(assignments.at(OperatorAttributeKey::EMBED_DIM)), */ -/* std::get(assignments.at(OperatorAttributeKey::NUM_HEADS)), */ -/* std::get(assignments.at(OperatorAttributeKey::NUM_HEADS)), */ +/* std::get(assignments.at(OperatorAttributeKey::EMBED_DIM)), + */ +/* std::get(assignments.at(OperatorAttributeKey::NUM_HEADS)), + */ +/* std::get(assignments.at(OperatorAttributeKey::NUM_HEADS)), + */ /* std::get(assignments.at(OperatorAttributeKey::VDIM)), */ -/* std::get(assignments.at(OperatorAttributeKey::DROPOUT)), */ +/* std::get(assignments.at(OperatorAttributeKey::DROPOUT)), + */ /* std::get(assignments.at(OperatorAttributeKey::BIAS)), */ -/* std::get(assignments.at(OperatorAttributeKey::ADD_BIAS_KV)), */ +/* std::get(assignments.at(OperatorAttributeKey::ADD_BIAS_KV)), + */ /* std::get( */ /* assignments.at(OperatorAttributeKey::ADD_ZERO_ATTN))}, */ /* std::nullopt}; */ @@ -252,13 +286,20 @@ namespace FlexFlow { /* case OperatorType::POOL2D: */ /* return Operator{ */ /* Pool2DAttrs{ */ -/* std::get(assignments.at(OperatorAttributeKey::KERNEL_H)), */ -/* std::get(assignments.at(OperatorAttributeKey::KERNEL_W)), */ -/* std::get(assignments.at(OperatorAttributeKey::STRIDE_H)), */ -/* std::get(assignments.at(OperatorAttributeKey::STRIDE_W)), */ -/* std::get(assignments.at(OperatorAttributeKey::PADDING_H)), */ -/* std::get(assignments.at(OperatorAttributeKey::PADDING_W)), */ -/* std::get(assignments.at(OperatorAttributeKey::POOL_TYPE)), */ +/* std::get(assignments.at(OperatorAttributeKey::KERNEL_H)), + */ +/* std::get(assignments.at(OperatorAttributeKey::KERNEL_W)), + */ +/* std::get(assignments.at(OperatorAttributeKey::STRIDE_H)), + */ +/* std::get(assignments.at(OperatorAttributeKey::STRIDE_W)), + */ +/* std::get(assignments.at(OperatorAttributeKey::PADDING_H)), + */ +/* std::get(assignments.at(OperatorAttributeKey::PADDING_W)), + */ +/* std::get(assignments.at(OperatorAttributeKey::POOL_TYPE)), + */ /* std::get( */ /* assignments.at(OperatorAttributeKey::ACTIVATION))}, */ /* std::nullopt}; */ @@ -274,7 +315,8 @@ namespace FlexFlow { /* std::get>( */ /* assignments.at(OperatorAttributeKey::AXES)), */ /* op_type, */ -/* std::get(assignments.at(OperatorAttributeKey::KEEP_DIMS))}, */ +/* std::get(assignments.at(OperatorAttributeKey::KEEP_DIMS))}, + */ /* std::nullopt}; */ /* case OperatorType::REVERSE: */ /* return Operator{ReverseAttrs{std::get( */ @@ -289,7 +331,8 @@ namespace FlexFlow { /* SplitAttrs{ */ /* std::get>( */ /* assignments.at(OperatorAttributeKey::SPLITS)), */ -/* std::get(assignments.at(OperatorAttributeKey::AXIS))}, */ +/* std::get(assignments.at(OperatorAttributeKey::AXIS))}, + */ /* std::nullopt}; */ /* case OperatorType::SOFTMAX: */ /* return Operator{SoftmaxAttrs{std::get( */ @@ -299,7 +342,8 @@ namespace FlexFlow { /* return Operator{ */ /* TopKAttrs{ */ /* std::get(assignments.at(OperatorAttributeKey::K)), */ -/* std::get(assignments.at(OperatorAttributeKey::SORTED))}, */ +/* std::get(assignments.at(OperatorAttributeKey::SORTED))}, + */ /* std::nullopt}; */ /* case OperatorType::TRANSPOSE: */ /* return Operator{ */ @@ -308,9 +352,11 @@ namespace FlexFlow { /* std::nullopt}; */ /* case OperatorType::COMBINE: */ /* return Operator{CombineAttrs{std::get(assignments.at( */ -/* OperatorAttributeKey::PARALLEL_DIM)), */ +/* OperatorAttributeKey::PARALLEL_DIM)), + */ /* std::get(assignments.at( */ -/* OperatorAttributeKey::PARALLEL_DEGREE))}, */ +/* OperatorAttributeKey::PARALLEL_DEGREE))}, + */ /* std::nullopt}; */ /* case OperatorType::REDUCTION: */ /* return Operator{ */ diff --git a/lib/substitutions/src/substitutions/graph_pattern.cc b/lib/substitutions/src/substitutions/graph_pattern.cc index ac032d37ce..22cf12b4cf 100644 --- a/lib/substitutions/src/substitutions/graph_pattern.cc +++ b/lib/substitutions/src/substitutions/graph_pattern.cc @@ -1,6 +1,6 @@ #include "substitutions/graph_pattern.h" -#include "substitutions/sub_parallel_computation_graph.h" #include "substitutions/operator_pattern/satisfies_pattern.h" +#include "substitutions/sub_parallel_computation_graph.h" #include "substitutions/tensor_pattern/satisfies_pattern.h" namespace FlexFlow { @@ -9,11 +9,13 @@ UnlabelledGraphPattern get_unlabelled_pattern(PCGPattern const &p) { return UnlabelledGraphPattern{p.raw_graph}; } -TensorAttributePattern get_tensor_pattern(PCGPattern const &p, PatternEdge const &e) { +TensorAttributePattern get_tensor_pattern(PCGPattern const &p, + PatternEdge const &e) { return p.raw_graph.at(e.raw_edge); } -OperatorAttributePattern get_operator_pattern(PCGPattern const &p, PatternNode const &n) { +OperatorAttributePattern get_operator_pattern(PCGPattern const &p, + PatternNode const &n) { return p.raw_graph.at(n.raw_node); } @@ -25,20 +27,16 @@ bool assignment_satisfies(SubParallelComputationGraph const &pcg, pcg.raw_graph, patternMatch, MatchAdditionalCriterion{ - [&](PatternNode const &patternNode, Node const &pcgNode) { - return operator_satisfies_pattern( - get_operator_attrs(pcg, pcgNode), - get_operator_pattern(pattern, patternNode) - ); - }, - [&](PatternEdge const &patternEdge, OpenMultiDiEdge const &pcgEdge) { - return parallel_tensor_satisfies_pattern( - get_parallel_tensor_attrs(pcg, pcgEdge), - get_tensor_pattern(pattern, patternEdge) - ); - } - } - ); + [&](PatternNode const &patternNode, Node const &pcgNode) { + return operator_satisfies_pattern( + get_operator_attrs(pcg, pcgNode), + get_operator_pattern(pattern, patternNode)); + }, + [&](PatternEdge const &patternEdge, OpenMultiDiEdge const &pcgEdge) { + return parallel_tensor_satisfies_pattern( + get_parallel_tensor_attrs(pcg, pcgEdge), + get_tensor_pattern(pattern, patternEdge)); + }}); } } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/operator_pattern/eval_list_access.cc b/lib/substitutions/src/substitutions/operator_pattern/eval_list_access.cc index f90ddc2e20..53973dc1cb 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/eval_list_access.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/eval_list_access.cc @@ -4,36 +4,38 @@ namespace FlexFlow { -std::optional eval_list_access(PCGOperatorAttrs const &attrs, OperatorAttributeListIndexAccess const &acc) { - std::optional from_attr = get_attribute(attrs, acc.attribute_key); +std::optional + eval_list_access(PCGOperatorAttrs const &attrs, + OperatorAttributeListIndexAccess const &acc) { + std::optional from_attr = + get_attribute(attrs, acc.attribute_key); if (!from_attr.has_value()) { return std::nullopt; } - return from_attr.value().visit< - std::optional - >([&](auto const &v) -> std::optional { - using T = std::decay_t; + return from_attr.value().visit>( + [&](auto const &v) -> std::optional { + using T = std::decay_t; - if constexpr (std::is_same_v>) { - if (acc.index >= v.size()) { - return std::nullopt; - } else { - int value = v.at(acc.index); - return OperatorAttributeValue{value}; - } - } else if constexpr (std::is_same_v>) { - if (acc.index >= v.size()) { - return std::nullopt; - } else { - ff_dim_t value = v.at(acc.index); - return OperatorAttributeValue{value}; - } - } else { - throw mk_runtime_error("Invalid operand"); - } - }); + if constexpr (std::is_same_v>) { + if (acc.index >= v.size()) { + return std::nullopt; + } else { + int value = v.at(acc.index); + return OperatorAttributeValue{value}; + } + } else if constexpr (std::is_same_v>) { + if (acc.index >= v.size()) { + return std::nullopt; + } else { + ff_dim_t value = v.at(acc.index); + return OperatorAttributeValue{value}; + } + } else { + throw mk_runtime_error("Invalid operand"); + } + }); } } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/operator_pattern/eval_list_size.cc b/lib/substitutions/src/substitutions/operator_pattern/eval_list_size.cc index 2c2fd4850d..a3ae9c84d1 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/eval_list_size.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/eval_list_size.cc @@ -4,25 +4,28 @@ namespace FlexFlow { -std::optional eval_list_size(PCGOperatorAttrs const &attrs, OperatorAttributeListSize const &acc) { - std::optional from_attr = get_attribute(attrs, acc.attribute_key); - +std::optional + eval_list_size(PCGOperatorAttrs const &attrs, + OperatorAttributeListSize const &acc) { + std::optional from_attr = + get_attribute(attrs, acc.attribute_key); + if (!from_attr.has_value()) { return std::nullopt; } - return from_attr.value().visit< - std::optional - >([&](auto const &v) -> std::optional { - using T = std::decay_t; + return from_attr.value().visit>( + [&](auto const &v) -> std::optional { + using T = std::decay_t; - if constexpr (std::is_same_v> || std::is_same_v>) { - size_t size = v.size(); - return OperatorAttributeValue{size}; - } else { - throw mk_runtime_error("Invalid operand"); - } - }); + if constexpr (std::is_same_v> || + std::is_same_v>) { + size_t size = v.size(); + return OperatorAttributeValue{size}; + } else { + throw mk_runtime_error("Invalid operand"); + } + }); } } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc index 7932a2d26e..28b0d2e37f 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc @@ -14,7 +14,8 @@ std::optional get_attribute(BatchMatmulAttrs const &p, } } -std::optional get_attribute(BatchNormAttrs const &p, OperatorAttributeKey key) { +std::optional get_attribute(BatchNormAttrs const &p, + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::OP_TYPE: return get_op_type(p); @@ -227,8 +228,8 @@ std::optional } } -std::optional - get_attribute(NoopAttrs const &p, OperatorAttributeKey key) { +std::optional get_attribute(NoopAttrs const &p, + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::OP_TYPE: return get_op_type(p); @@ -237,7 +238,6 @@ std::optional } } - std::optional get_attribute(Pool2DAttrs const &p, OperatorAttributeKey key) { switch (key) { @@ -338,7 +338,6 @@ std::optional get_attribute(ReverseAttrs const &p, } } - std::optional get_attribute(SplitAttrs const &p, OperatorAttributeKey key) { switch (key) { @@ -387,11 +386,8 @@ std::optional get_attribute(TransposeAttrs const &p, std::optional get_attribute(PCGOperatorAttrs const &p, OperatorAttributeKey key) { - return p.visit< - std::optional - >([&](auto const &attrs) { - return get_attribute(attrs, key); - }); + return p.visit>( + [&](auto const &attrs) { return get_attribute(attrs, key); }); } } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_expr.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_expr.cc index 71f03bd364..4a55fa3de3 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_expr.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_expr.cc @@ -1,7 +1,7 @@ #include "substitutions/operator_pattern/operator_attribute_expr.h" -#include "substitutions/operator_pattern/get_attribute.h" #include "substitutions/operator_pattern/eval_list_access.h" #include "substitutions/operator_pattern/eval_list_size.h" +#include "substitutions/operator_pattern/get_attribute.h" #include "utils/overload.h" namespace FlexFlow { @@ -9,12 +9,14 @@ namespace FlexFlow { std::optional evaluate_attribute_expr(PCGOperatorAttrs const &attrs, OperatorAttributeExpr const &expr) { - return expr.visit< - std::optional - >(overload { - [&](OperatorAttributeKey const &k) { return get_attribute(attrs, k); }, - [&](OperatorAttributeListSize const &k) { return eval_list_size(attrs, k); }, - [&](OperatorAttributeListIndexAccess const &k) { return eval_list_access(attrs, k); }, + return expr.visit>(overload{ + [&](OperatorAttributeKey const &k) { return get_attribute(attrs, k); }, + [&](OperatorAttributeListSize const &k) { + return eval_list_size(attrs, k); + }, + [&](OperatorAttributeListIndexAccess const &k) { + return eval_list_access(attrs, k); + }, }); } diff --git a/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc b/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc index 5455cdced5..ae42515cc8 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc @@ -3,8 +3,11 @@ namespace FlexFlow { -bool operator_satisfies_constraint(PCGOperatorAttrs const &attrs, OperatorAttributeConstraint const &constraint) { - std::optional expr_val = evaluate_attribute_expr(attrs, constraint.attribute_expr); +bool operator_satisfies_constraint( + PCGOperatorAttrs const &attrs, + OperatorAttributeConstraint const &constraint) { + std::optional expr_val = + evaluate_attribute_expr(attrs, constraint.attribute_expr); if (!expr_val.has_value()) { return false; @@ -14,7 +17,9 @@ bool operator_satisfies_constraint(PCGOperatorAttrs const &attrs, OperatorAttrib case ConstraintType::EQUAL: return expr_val.value() == constraint.attribute_value; default: - throw mk_runtime_error(fmt::format("Unknown constraint type {}", static_cast(constraint.constraint_type))); + throw mk_runtime_error( + fmt::format("Unknown constraint type {}", + static_cast(constraint.constraint_type))); } } diff --git a/lib/substitutions/src/substitutions/operator_pattern/satisfies_pattern.cc b/lib/substitutions/src/substitutions/operator_pattern/satisfies_pattern.cc index 28d7803a6b..60ab363cc6 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/satisfies_pattern.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/satisfies_pattern.cc @@ -3,9 +3,12 @@ namespace FlexFlow { -bool operator_satisfies_pattern(PCGOperatorAttrs const &attrs, OperatorAttributePattern const &pattern) { - return all_of(pattern.attribute_constraints, - [&](OperatorAttributeConstraint const &c) { return operator_satisfies_constraint(attrs, c); }); +bool operator_satisfies_pattern(PCGOperatorAttrs const &attrs, + OperatorAttributePattern const &pattern) { + return all_of(pattern.attribute_constraints, + [&](OperatorAttributeConstraint const &c) { + return operator_satisfies_constraint(attrs, c); + }); } } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc index d4dd87543d..7736113819 100644 --- a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc +++ b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc @@ -2,15 +2,20 @@ namespace FlexFlow { -ParallelLayerAttrs get_parallel_layer_attrs(SubParallelComputationGraph const &spcg, Node const &n) { +ParallelLayerAttrs + get_parallel_layer_attrs(SubParallelComputationGraph const &spcg, + Node const &n) { return spcg.raw_graph.at(n); } -PCGOperatorAttrs get_operator_attrs(SubParallelComputationGraph const &spcg, Node const &n) { +PCGOperatorAttrs get_operator_attrs(SubParallelComputationGraph const &spcg, + Node const &n) { return get_parallel_layer_attrs(spcg, n).attrs; } -ParallelTensorAttrs get_parallel_tensor_attrs(SubParallelComputationGraph const &spcg, OpenMultiDiEdge const &e) { +ParallelTensorAttrs + get_parallel_tensor_attrs(SubParallelComputationGraph const &spcg, + OpenMultiDiEdge const &e) { return spcg.raw_graph.at(e); } diff --git a/lib/substitutions/src/substitutions/substitution.cc b/lib/substitutions/src/substitutions/substitution.cc index 9d51c7018d..e900175bc6 100644 --- a/lib/substitutions/src/substitutions/substitution.cc +++ b/lib/substitutions/src/substitutions/substitution.cc @@ -29,8 +29,6 @@ namespace FlexFlow { /* } */ /* }; */ - - /* struct AddNewEdgeFunctor { */ /* SubParallelComputationGraph const &old_pcg; */ /* SubParallelComputationGraph &new_pcg; */ @@ -90,8 +88,10 @@ namespace FlexFlow { /* Substitution const &substitution, */ /* MultiDiGraphPatternMatch const &match) { */ /* SubParallelComputationGraph new_pcg = */ -/* OutputLabelledOpenMultiDiGraph::template create< */ -/* UnorderedOutputLabelledOpenMultiDiGraph>(); */ +/* OutputLabelledOpenMultiDiGraph::template + * create< */ +/* UnorderedOutputLabelledOpenMultiDiGraph>(); */ /* bidict node_mapping; // Refactor it with global nodes */ /* for (Node const &node : get_nodes(pcg)) { */ /* if (!contains_r(match.node_assignment, node)) { */ @@ -106,7 +106,8 @@ namespace FlexFlow { /* for (Node const &output_node : */ /* get_nodes(substitution.output_graph_expr.value())) { */ /* Operator new_op = get_operator_attrs( */ -/* pcg, match, substitution.output_graph_expr.value().at(output_node)); */ +/* pcg, match, substitution.output_graph_expr.value().at(output_node)); + */ /* Node new_node = new_pcg.add_node(new_op); */ /* node_mapping.equate(output_node, new_node); */ /* } */ diff --git a/lib/substitutions/src/substitutions/tensor_pattern/eval_list_access.cc b/lib/substitutions/src/substitutions/tensor_pattern/eval_list_access.cc index 0b07725471..ea4833d36a 100644 --- a/lib/substitutions/src/substitutions/tensor_pattern/eval_list_access.cc +++ b/lib/substitutions/src/substitutions/tensor_pattern/eval_list_access.cc @@ -5,19 +5,20 @@ namespace FlexFlow { -TensorAttributeValue eval_list_access(ParallelTensorAttrs const &attrs, TensorAttributeListIndexAccess const &acc) { +TensorAttributeValue + eval_list_access(ParallelTensorAttrs const &attrs, + TensorAttributeListIndexAccess const &acc) { TensorAttributeValue from_attr = get_attribute(attrs, acc.attribute_key); - return from_attr.visit( - overload { - [&](std::vector const &v) -> TensorAttributeValue { + return from_attr.visit(overload{ + [&](std::vector const &v) -> TensorAttributeValue { return TensorAttributeValue{ - static_cast(at_idx(v, acc.index).value()) - }; + static_cast(at_idx(v, acc.index).value())}; }, - [](auto &&) -> TensorAttributeValue { throw mk_runtime_error("Invalid operand"); }, - } - ); + [](auto &&) -> TensorAttributeValue { + throw mk_runtime_error("Invalid operand"); + }, + }); } } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/tensor_pattern/eval_list_size.cc b/lib/substitutions/src/substitutions/tensor_pattern/eval_list_size.cc index b63a35380f..d1e97adc37 100644 --- a/lib/substitutions/src/substitutions/tensor_pattern/eval_list_size.cc +++ b/lib/substitutions/src/substitutions/tensor_pattern/eval_list_size.cc @@ -4,14 +4,17 @@ namespace FlexFlow { -TensorAttributeValue eval_list_size(ParallelTensorAttrs const &attrs, TensorAttributeListSize const &acc) { +TensorAttributeValue eval_list_size(ParallelTensorAttrs const &attrs, + TensorAttributeListSize const &acc) { TensorAttributeValue from_attr = get_attribute(attrs, acc.attribute_key); - return from_attr.visit(overload { - [](std::vector const &v) -> TensorAttributeValue { - return TensorAttributeValue{v.size()}; - }, - [](auto &&) -> TensorAttributeValue { throw mk_runtime_error("Invalid operand"); }, + return from_attr.visit(overload{ + [](std::vector const &v) -> TensorAttributeValue { + return TensorAttributeValue{v.size()}; + }, + [](auto &&) -> TensorAttributeValue { + throw mk_runtime_error("Invalid operand"); + }, }); } diff --git a/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc b/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc index ed7ad120a4..4a5330d4af 100644 --- a/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc +++ b/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc @@ -3,22 +3,24 @@ namespace FlexFlow { -TensorAttributeValue get_attribute(ParallelTensorAttrs const &attrs, TensorAttributeKey key) { +TensorAttributeValue get_attribute(ParallelTensorAttrs const &attrs, + TensorAttributeKey key) { switch (key) { case TensorAttributeKey::DIM_SIZES: { - std::vector sizes = transform(as_vector(ff_ordered(attrs.shape.dims)), - [](ParallelDim const &d) { return d.size; }); + std::vector sizes = + transform(as_vector(ff_ordered(attrs.shape.dims)), + [](ParallelDim const &d) { return d.size; }); return TensorAttributeValue{sizes}; } case TensorAttributeKey::DIM_DEGREES: { - std::vector degrees = transform(as_vector(ff_ordered(attrs.shape.dims)), - [](ParallelDim const &d) { - return static_cast(d.degree); - }); + std::vector degrees = transform( + as_vector(ff_ordered(attrs.shape.dims)), + [](ParallelDim const &d) { return static_cast(d.degree); }); return TensorAttributeValue{degrees}; } default: - throw std::runtime_error(fmt::format("Unknown TensorAttributeKey {}", static_cast(key))); + throw std::runtime_error( + fmt::format("Unknown TensorAttributeKey {}", static_cast(key))); } } diff --git a/lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc b/lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc index e464523eef..974bfcabc0 100644 --- a/lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc +++ b/lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc @@ -3,14 +3,19 @@ namespace FlexFlow { -bool parallel_tensor_satisfies_constraint(ParallelTensorAttrs const &attrs, TensorAttributeConstraint const &constraint) { - TensorAttributeValue expr_val = evaluate_attribute_expr(attrs, constraint.attribute_expr); +bool parallel_tensor_satisfies_constraint( + ParallelTensorAttrs const &attrs, + TensorAttributeConstraint const &constraint) { + TensorAttributeValue expr_val = + evaluate_attribute_expr(attrs, constraint.attribute_expr); switch (constraint.constraint_type) { case ConstraintType::EQUAL: return expr_val == constraint.attribute_value; default: - throw mk_runtime_error(fmt::format("Unknown constraint type {}", static_cast(constraint.constraint_type))); + throw mk_runtime_error( + fmt::format("Unknown constraint type {}", + static_cast(constraint.constraint_type))); } } diff --git a/lib/substitutions/src/substitutions/tensor_pattern/satisfies_pattern.cc b/lib/substitutions/src/substitutions/tensor_pattern/satisfies_pattern.cc index 1383c46cf8..35fec2dfea 100644 --- a/lib/substitutions/src/substitutions/tensor_pattern/satisfies_pattern.cc +++ b/lib/substitutions/src/substitutions/tensor_pattern/satisfies_pattern.cc @@ -3,8 +3,11 @@ namespace FlexFlow { -bool parallel_tensor_satisfies_pattern(ParallelTensorAttrs const &attrs, TensorAttributePattern const &pattern) { +bool parallel_tensor_satisfies_pattern(ParallelTensorAttrs const &attrs, + TensorAttributePattern const &pattern) { return all_of(pattern.attribute_constraints, - [&](TensorAttributeConstraint const &c) { return parallel_tensor_satisfies_constraint(attrs, c); }); + [&](TensorAttributeConstraint const &c) { + return parallel_tensor_satisfies_constraint(attrs, c); + }); } } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_expr.cc b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_expr.cc index 068d5d7a69..33bcc1a082 100644 --- a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_expr.cc +++ b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_expr.cc @@ -1,26 +1,22 @@ #include "substitutions/tensor_pattern/tensor_attribute_expr.h" -#include "substitutions/tensor_pattern/get_attribute.h" -#include "substitutions/tensor_pattern/eval_list_size.h" #include "substitutions/tensor_pattern/eval_list_access.h" +#include "substitutions/tensor_pattern/eval_list_size.h" +#include "substitutions/tensor_pattern/get_attribute.h" #include "utils/overload.h" namespace FlexFlow { -TensorAttributeValue - evaluate_attribute_expr(ParallelTensorAttrs const &attrs, - TensorAttributeExpr const &expr) { +TensorAttributeValue evaluate_attribute_expr(ParallelTensorAttrs const &attrs, + TensorAttributeExpr const &expr) { - return expr.visit(overload { - [&](TensorAttributeKey const &key) { - return get_attribute(attrs, key); - }, - [&](TensorAttributeListSize const &s) { - return eval_list_size(attrs, s); - }, - [&](TensorAttributeListIndexAccess const &s) { - return eval_list_access(attrs, s); - } - }); + return expr.visit(overload{ + [&](TensorAttributeKey const &key) { return get_attribute(attrs, key); }, + [&](TensorAttributeListSize const &s) { + return eval_list_size(attrs, s); + }, + [&](TensorAttributeListIndexAccess const &s) { + return eval_list_access(attrs, s); + }}); } } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/edge_splits.cc b/lib/substitutions/src/substitutions/unlabelled/edge_splits.cc index e1a9fc1fe7..33ea7dc9f6 100644 --- a/lib/substitutions/src/substitutions/unlabelled/edge_splits.cc +++ b/lib/substitutions/src/substitutions/unlabelled/edge_splits.cc @@ -2,27 +2,31 @@ namespace FlexFlow { -std::pair get_split_edges(UnlabelledPatternEdgeSplits const &splits, ClosedPatternEdge const &e) { - std::pair raw_result = splits.unwrapped.at_l(e.raw_edge); +std::pair + get_split_edges(UnlabelledPatternEdgeSplits const &splits, + ClosedPatternEdge const &e) { + std::pair raw_result = + splits.unwrapped.at_l(e.raw_edge); return { - OutputPatternEdge{raw_result.first}, - InputPatternEdge{raw_result.second}, + OutputPatternEdge{raw_result.first}, + InputPatternEdge{raw_result.second}, }; } -std::vector> as_closed_output_input_tuples(UnlabelledPatternEdgeSplits const &s) { - std::vector> result; +std::vector> + as_closed_output_input_tuples(UnlabelledPatternEdgeSplits const &s) { + std::vector< + std::tuple> + result; for (auto const &kv : s.unwrapped) { MultiDiEdge standard_edge = kv.first; OutputMultiDiEdge output_edge = kv.second.first; InputMultiDiEdge input_edge = kv.second.second; - result.push_back({ - ClosedPatternEdge{standard_edge}, - OutputPatternEdge{output_edge}, - InputPatternEdge{input_edge} - }); + result.push_back({ClosedPatternEdge{standard_edge}, + OutputPatternEdge{output_edge}, + InputPatternEdge{input_edge}}); } return result; diff --git a/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc b/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc index 2f2c2d756c..8c787ca255 100644 --- a/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc +++ b/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc @@ -1,26 +1,42 @@ #include "substitutions/unlabelled/find_pattern_matches.h" -#include "substitutions/unlabelled/unlabelled_graph_pattern.h" -#include "utils/containers.h" -#include "substitutions/unlabelled/upward_open_pattern_edge.h" #include "substitutions/unlabelled/downward_open_pattern_edge.h" #include "substitutions/unlabelled/multidigraph_pattern_match.h" +#include "substitutions/unlabelled/unlabelled_graph_pattern.h" +#include "substitutions/unlabelled/upward_open_pattern_edge.h" +#include "utils/containers.h" namespace FlexFlow { -static std::vector sorted_by_dst_idx(std::unordered_set const &in) { - return sorted_by(in, compare_by([](UpwardOpenPatternEdge const &e) { return get_dst_idx(e); })); +static std::vector + sorted_by_dst_idx(std::unordered_set const &in) { + return sorted_by( + in, compare_by([](UpwardOpenPatternEdge const &e) { + return get_dst_idx(e); + })); } -static std::vector sorted_by_src_idx(std::unordered_set const &in) { - return sorted_by(in, compare_by([](DownwardOpenPatternEdge const &e) { return get_src_idx(e); })); +static std::vector + sorted_by_src_idx(std::unordered_set const &in) { + return sorted_by( + in, + compare_by( + [](DownwardOpenPatternEdge const &e) { return get_src_idx(e); })); } -static std::vector sorted_by_dst_idx(std::unordered_set const &in) { - return sorted_by(in, compare_by([](UpwardOpenPatternEdge const &e) { return get_dst_idx(e); })); +static std::vector + sorted_by_dst_idx(std::unordered_set const &in) { + return sorted_by( + in, compare_by([](UpwardOpenPatternEdge const &e) { + return get_dst_idx(e); + })); } -static std::vector sorted_by_src_idx(std::unordered_set const &in) { - return sorted_by(in, compare_by([](DownwardOpenMultiDiEdge const &e) { return get_src_idx(e); })); +static std::vector + sorted_by_src_idx(std::unordered_set const &in) { + return sorted_by( + in, + compare_by( + [](DownwardOpenMultiDiEdge const &e) { return get_src_idx(e); })); } static std::optional @@ -52,11 +68,15 @@ static std::optional return std::nullopt; } - std::vector incoming_ordered = sorted_by_dst_idx(incoming); - std::vector outgoing_ordered = sorted_by_src_idx(outgoing); + std::vector incoming_ordered = + sorted_by_dst_idx(incoming); + std::vector outgoing_ordered = + sorted_by_src_idx(outgoing); - std::vector pattern_incoming_ordered = sorted_by_dst_idx(pattern_incoming); - std::vector pattern_outgoing_ordered = sorted_by_src_idx(pattern_outgoing); + std::vector pattern_incoming_ordered = + sorted_by_dst_idx(pattern_incoming); + std::vector pattern_outgoing_ordered = + sorted_by_src_idx(pattern_outgoing); if (pattern_incoming.size() > 0) { std::unordered_map node_port_mapping; @@ -73,7 +93,7 @@ static std::optional } } match.edge_assignment.equate(widen(pattern_edge), - widen(graph_edge)); + widen(graph_edge)); } } @@ -81,7 +101,8 @@ static std::optional std::unordered_map node_port_mapping; for (int i = 0; i < outgoing_ordered.size(); ++i) { DownwardOpenMultiDiEdge graph_edge = outgoing_ordered[i], - DownwardOpenPatternEdge pattern_edge = pattern_outgoing_ordered[i]; + DownwardOpenPatternEdge pattern_edge = + pattern_outgoing_ordered[i]; NodePort graph_port = get_src_idx(graph_edge), pattern_port = get_src_idx(pattern_edge); @@ -100,7 +121,6 @@ static std::optional return match; } - std::vector find_pattern_matches(UnlabelledGraphPattern const &pattern, OpenMultiDiGraphView const &graph, diff --git a/lib/substitutions/src/substitutions/unlabelled/match_split.cc b/lib/substitutions/src/substitutions/unlabelled/match_split.cc index 10e7a9e975..ef0397d6a8 100644 --- a/lib/substitutions/src/substitutions/unlabelled/match_split.cc +++ b/lib/substitutions/src/substitutions/unlabelled/match_split.cc @@ -1,16 +1,14 @@ #include "substitutions/unlabelled/match_split.h" +#include "substitutions/unlabelled/edge_splits.h" #include "substitutions/unlabelled/multidigraph_pattern_match.h" -#include "substitutions/unlabelled/pattern_split.h" #include "substitutions/unlabelled/pattern_edge.h" -#include "substitutions/unlabelled/edge_splits.h" +#include "substitutions/unlabelled/pattern_split.h" namespace FlexFlow { MatchSplit empty_match_split() { - return MatchSplit{ - empty_multidigraph_pattern_match(), - empty_multidigraph_pattern_match() - }; + return MatchSplit{empty_multidigraph_pattern_match(), + empty_multidigraph_pattern_match()}; } MatchSplit apply_split(UnlabelledGraphPattern const &pattern, @@ -54,8 +52,10 @@ MatchSplit apply_split(UnlabelledGraphPattern const &pattern, OutputMultiDiEdge output_graph_edge = split_graph_edge.first; InputMultiDiEdge input_graph_edge = split_graph_edge.second; - handle_edge(pattern_edge_from_input_edge(input_edge), OpenMultiDiEdge{input_graph_edge}); - handle_edge(pattern_edge_from_output_edge(output_edge), OpenMultiDiEdge{output_graph_edge}); + handle_edge(pattern_edge_from_input_edge(input_edge), + OpenMultiDiEdge{input_graph_edge}); + handle_edge(pattern_edge_from_output_edge(output_edge), + OpenMultiDiEdge{output_graph_edge}); } }; @@ -66,5 +66,4 @@ MatchSplit apply_split(UnlabelledGraphPattern const &pattern, return result; } - } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.cc b/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.cc index a2d18abc07..8f4fd7f535 100644 --- a/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.cc +++ b/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.cc @@ -1,21 +1,21 @@ #include "substitutions/unlabelled/multidigraph_pattern_match.h" -#include "utils/containers.h" #include "substitutions/unlabelled/edge_splits.h" #include "substitutions/unlabelled/pattern_edge.h" +#include "utils/containers.h" namespace FlexFlow { MultiDiGraphPatternMatch empty_multidigraph_pattern_match() { return MultiDiGraphPatternMatch{ - bidict{}, - bidict{}, + bidict{}, + bidict{}, }; } -std::optional unsplit_matches( - MultiDiGraphPatternMatch const &prefix, - MultiDiGraphPatternMatch const &postfix, - UnlabelledPatternEdgeSplits const &edge_splits) { +std::optional + unsplit_matches(MultiDiGraphPatternMatch const &prefix, + MultiDiGraphPatternMatch const &postfix, + UnlabelledPatternEdgeSplits const &edge_splits) { MultiDiGraphPatternMatch result = empty_multidigraph_pattern_match(); @@ -30,9 +30,11 @@ std::optional unsplit_matches( OpenMultiDiEdge output_graph_edge = prefix.edge_assignment.at_l(pattern_edge_from_output_edge(output_edge)); - OpenMultiDiEdge input_graph_edge = postfix.edge_assignment.at_l(pattern_edge_from_input_edge(input_edge)); + OpenMultiDiEdge input_graph_edge = + postfix.edge_assignment.at_l(pattern_edge_from_input_edge(input_edge)); if (output_graph_edge == input_graph_edge) { - result.edge_assignment.equate(pattern_edge_from_closed_edge(closed_edge), output_graph_edge); + result.edge_assignment.equate(pattern_edge_from_closed_edge(closed_edge), + output_graph_edge); } else { return std::nullopt; } diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_edge.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_edge.cc index fb662e0887..3dd4987705 100644 --- a/lib/substitutions/src/substitutions/unlabelled/pattern_edge.cc +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_edge.cc @@ -21,24 +21,18 @@ bool is_output_edge(PatternEdge const &e) { } ClosedPatternEdge require_closed_edge(PatternEdge const &e) { - assert (is_closed_edge(e)); - return ClosedPatternEdge{ - std::get(e.raw_edge) - }; + assert(is_closed_edge(e)); + return ClosedPatternEdge{std::get(e.raw_edge)}; } InputPatternEdge require_input_edge(PatternEdge const &e) { - assert (is_input_edge(e)); - return InputPatternEdge{ - std::get(e.raw_edge) - }; + assert(is_input_edge(e)); + return InputPatternEdge{std::get(e.raw_edge)}; } OutputPatternEdge require_output_edge(PatternEdge const &e) { - assert (is_output_edge(e)); - return OutputPatternEdge{ - std::get(e.raw_edge) - }; + assert(is_output_edge(e)); + return OutputPatternEdge{std::get(e.raw_edge)}; } PatternEdge pattern_edge_from_input_edge(InputPatternEdge const &e) { diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc index f6678c0e3c..335b9664ea 100644 --- a/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc @@ -1,23 +1,23 @@ #include "substitutions/unlabelled/pattern_matching.h" -#include "substitutions/unlabelled/match_split.h" -#include "substitutions/unlabelled/unlabelled_graph_pattern.h" -#include "substitutions/unlabelled/pattern_edge.h" -#include #include "substitutions/unlabelled/input_pattern_edge.h" +#include "substitutions/unlabelled/match_split.h" #include "substitutions/unlabelled/output_pattern_edge.h" +#include "substitutions/unlabelled/pattern_edge.h" #include "substitutions/unlabelled/pattern_split.h" +#include "substitutions/unlabelled/unlabelled_graph_pattern.h" +#include namespace FlexFlow { -bool unlabelled_pattern_does_match(UnlabelledGraphPattern const &pattern, - OpenMultiDiGraphView const &graph, - MultiDiGraphPatternMatch const &match, - MatchAdditionalCriterion const &additional_criterion) { +bool unlabelled_pattern_does_match( + UnlabelledGraphPattern const &pattern, + OpenMultiDiGraphView const &graph, + MultiDiGraphPatternMatch const &match, + MatchAdditionalCriterion const &additional_criterion) { if (is_singleton_pattern(pattern)) { PatternNode pattern_node = get_only(get_nodes(pattern)); Node matched_node = match.node_assignment.at_l(pattern_node); - if (!additional_criterion.node_criterion(pattern_node, - matched_node)) { + if (!additional_criterion.node_criterion(pattern_node, matched_node)) { return false; } for (PatternEdge const &e : get_edges(pattern)) { @@ -57,18 +57,18 @@ bool unlabelled_pattern_does_match(UnlabelledGraphPattern const &pattern, } PatternSplit split = find_even_split(pattern); - std::pair subpatterns = apply_split(pattern, split); + std::pair subpatterns = + apply_split(pattern, split); auto submatches = apply_split(pattern, match, split); return unlabelled_pattern_does_match(subpatterns.first, - graph, - submatches.prefix_submatch, - additional_criterion) && + graph, + submatches.prefix_submatch, + additional_criterion) && unlabelled_pattern_does_match(subpatterns.second, - graph, - submatches.postfix_submatch, - additional_criterion); + graph, + submatches.postfix_submatch, + additional_criterion); } - } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_split.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_split.cc index 573b562395..e116c062df 100644 --- a/lib/substitutions/src/substitutions/unlabelled/pattern_split.cc +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_split.cc @@ -3,33 +3,36 @@ namespace FlexFlow { PatternSplit find_even_split(UnlabelledGraphPattern const &p) { - std::vector topological_ordering = get_topological_ordering(pattern.raw_graph); + std::vector topological_ordering = + get_topological_ordering(pattern.raw_graph); assert(topological_ordering.size() >= 2); int split_point = topological_ordering.size() / 2; auto split = vector_split(topological_ordering, split_point); - std::unordered_set prefix(split.first.begin(), split.first.end()); - std::unordered_set postfix(split.second.begin(), split.second.end()); + std::unordered_set prefix(split.first.begin(), + split.first.end()); + std::unordered_set postfix(split.second.begin(), + split.second.end()); return {prefix, postfix}; } GraphSplit get_raw_split(PatternSplit const &s) { return std::pair{ - transform(s.first, [](PatternNode const &n) { return n.raw_node; }), - transform(s.second, [](PatternNode const &n) { return n.raw_node; }), + transform(s.first, [](PatternNode const &n) { return n.raw_node; }), + transform(s.second, [](PatternNode const &n) { return n.raw_node; }), }; } -UnlabelledPatternEdgeSplits get_edge_splits(UnlabelledGraphPattern const &pattern, PatternSplit const &split) { - bidict> raw_result = get_edge_splits( - pattern.raw_graph, - get_raw_split(split), - ); +UnlabelledPatternEdgeSplits + get_edge_splits(UnlabelledGraphPattern const &pattern, + PatternSplit const &split) { + bidict> + raw_result = get_edge_splits(pattern.raw_graph, get_raw_split(split), ); return UnlabelledPatternEdgeSplits{raw_result}; } std::pair - apply_split(UnlabelledGraphPattern const &p, PatternSplit const &s) { + apply_split(UnlabelledGraphPattern const &p, PatternSplit const &s) { return { get_subgraph(p, s.left); get_subgraph(p, s.right); diff --git a/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc b/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc index 858b3197a8..df10507a04 100644 --- a/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc +++ b/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc @@ -13,31 +13,40 @@ bool is_singleton_pattern(UnlabelledGraphPattern const &pattern) { std::unordered_set get_nodes(UnlabelledGraphPattern const &p) { return transform(get_nodes(p.raw_graph), - [](Node const &n) { return PatternNode{n}; }}); + [](Node const &n) { + return PatternNode{n}; }}); } std::unordered_set get_edges(UnlabelledGraphPattern const &p) { return transform(get_nodes(p.raw_graph), - [](OpenMultiDiEdge const &e) { return PatternEdge{e}; }}); + [](OpenMultiDiEdge const &e) { + return PatternEdge{e}; }}); } std::vector get_topological_ordering(UnlabelledGraphPattern const &p) { return transform(get_topological_ordering(p), - [](Node const &n) { return PatternNode{n}; }}); + [](Node const &n) { + return PatternNode{n}; }}); } -UnlabelledGraphPattern get_subgraph(UnlabelledGraphPattern const &p, std::unordered_set const &n) { +UnlabelledGraphPattern get_subgraph(UnlabelledGraphPattern const &p, + std::unordered_set const &n) { return { - get_subgraph(p.raw_graph, transform(n, [](PatternNode const &n) { return n.raw_node; })); + get_subgraph(p.raw_graph, + transform(n, [](PatternNode const &n) { return n.raw_node; })); }; } -std::unordered_set get_incoming_edges(UnlabelledGraphPattern const &p, PatternNode const &n) { - return transform(get_incoming_edges(p.raw_graph, n.raw_node), [](Node const &n) { return PatternNode{n}; }); +std::unordered_set + get_incoming_edges(UnlabelledGraphPattern const &p, PatternNode const &n) { + return transform(get_incoming_edges(p.raw_graph, n.raw_node), + [](Node const &n) { return PatternNode{n}; }); } -std::unordered_set get_outgoing_edges(UnlabelledGraphPattern const &p, PatternNode const &n) { - return transform(get_outgoing_edges(p.raw_graph, n.raw_node), [](Node const &n) { return PatternNode{n}; }); +std::unordered_set + get_outgoing_edges(UnlabelledGraphPattern const &p, PatternNode const &n) { + return transform(get_outgoing_edges(p.raw_graph, n.raw_node), + [](Node const &n) { return PatternNode{n}; }); } } // namespace FlexFlow diff --git a/lib/substitutions/test/src/test_substitution.cc b/lib/substitutions/test/src/test_substitution.cc index 32a596e940..2d9320275d 100644 --- a/lib/substitutions/test/src/test_substitution.cc +++ b/lib/substitutions/test/src/test_substitution.cc @@ -8,8 +8,10 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("apply_substitution") { OperatorPattern operator_pattern_n0{ - std::vector{OperatorAttributeConstraint{ - ConstraintType::EQUAL, OperatorAttributeKey::OP_TYPE, OperatorType::LINEAR}}}; + std::vector{ + OperatorAttributeConstraint{ConstraintType::EQUAL, + OperatorAttributeKey::OP_TYPE, + OperatorType::LINEAR}}}; ParallelTensorPattern tensor_pattern_e0{ std::vector{ @@ -38,7 +40,8 @@ TEST_SUITE(FF_TEST_SUITE) { GraphPattern input_graph{ig}; OperatorAttrAssignment op_ass_n1{ - {{OperatorAttributeKey::OP_TYPE, AttrConstant{OperatorType::REPARTITION}}, + {{OperatorAttributeKey::OP_TYPE, + AttrConstant{OperatorType::REPARTITION}}, {OperatorAttributeKey::PARALLEL_DIM, AttrConstant{ff_dim_t{0}}}, {OperatorAttributeKey::PARALLEL_DEGREE, AttrConstant{2}}}}; @@ -101,13 +104,8 @@ TEST_SUITE(FF_TEST_SUITE) { MultiDiEdge e4{n5, p5, n4, p4}; pcg.add_edge(e4); ParallelDim dim = {2, 1, false}; - ParallelTensorDims dims = { - FFOrdered{dim} - }; - pcg.add_label(e4, - ParallelTensor(dims, - DataType::FLOAT, - CreateGrad::YES)); + ParallelTensorDims dims = {FFOrdered{dim}}; + pcg.add_label(e4, ParallelTensor(dims, DataType::FLOAT, CreateGrad::YES)); MatchAdditionalCriterion criterion{ [&](Node const &pattern_node, Node const &graph_node) { diff --git a/lib/utils/include/utils/bidict.h b/lib/utils/include/utils/bidict.h index 08c286d842..6af18c2a4a 100644 --- a/lib/utils/include/utils/bidict.h +++ b/lib/utils/include/utils/bidict.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_UTILS_BIDICT_H #define _FLEXFLOW_UTILS_BIDICT_H +#include "utils/fmt/unordered_map.h" #include #include -#include "utils/fmt/unordered_map.h" namespace FlexFlow { @@ -58,13 +58,13 @@ struct bidict { bool operator==(bidict const &other) const { bool result = this->fwd_map == other.fwd_map; - assert (result == (this->bwd_map == other.bwd_map)); + assert(result == (this->bwd_map == other.bwd_map)); return result; } bool operator!=(bidict const &other) const { bool result = this->fwd_map != other.fwd_map; - assert (result == (this->bwd_map != other.bwd_map)); + assert(result == (this->bwd_map != other.bwd_map)); return result; } @@ -193,6 +193,6 @@ struct hash<::FlexFlow::bidict> { } }; -} +} // namespace std #endif diff --git a/lib/utils/include/utils/containers.decl.h b/lib/utils/include/utils/containers.decl.h index 37b022106f..7b6b9e4697 100644 --- a/lib/utils/include/utils/containers.decl.h +++ b/lib/utils/include/utils/containers.decl.h @@ -199,8 +199,7 @@ template bool all_of(C const &c, F const &f); template -std::optional optional_all_of(Container const &, - Function const &); +std::optional optional_all_of(Container const &, Function const &); template int count(C const &c, F const &f); diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index 0db08b77e8..bc34b00a47 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -5,6 +5,7 @@ #include "containers.decl.h" #include "required_core.h" #include "type_traits_core.h" +#include "utils/containers/extend_vector.h" #include "utils/exception.h" #include "utils/type_traits.h" #include @@ -18,7 +19,6 @@ #include #include #include -#include "utils/containers/extend_vector.h" namespace FlexFlow { diff --git a/lib/utils/include/utils/containers/concat_vectors.h b/lib/utils/include/utils/containers/concat_vectors.h index 8c6858c84e..7940a37510 100644 --- a/lib/utils/include/utils/containers/concat_vectors.h +++ b/lib/utils/include/utils/containers/concat_vectors.h @@ -6,7 +6,8 @@ namespace FlexFlow { template -std::vector concat_vectors(std::vector const &prefix, std::vector const &postfix) { +std::vector concat_vectors(std::vector const &prefix, + std::vector const &postfix) { std::vector result = prefix; extend_vector(result, postfix); return result; diff --git a/lib/utils/include/utils/containers/enumerate_vector.h b/lib/utils/include/utils/containers/enumerate_vector.h index bf927d8415..8d36a5fe3b 100644 --- a/lib/utils/include/utils/containers/enumerate_vector.h +++ b/lib/utils/include/utils/containers/enumerate_vector.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ENUMERATE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ENUMERATE_H -#include #include +#include namespace FlexFlow { diff --git a/lib/utils/include/utils/containers/extend_vector.h b/lib/utils/include/utils/containers/extend_vector.h index 289ead16c0..62ce94e49c 100644 --- a/lib/utils/include/utils/containers/extend_vector.h +++ b/lib/utils/include/utils/containers/extend_vector.h @@ -11,7 +11,6 @@ void extend_vector(std::vector &lhs, C const &rhs) { lhs.insert(lhs.end(), rhs.begin(), rhs.end()); } - } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/exception.decl.h b/lib/utils/include/utils/exception.decl.h index a8cb150ec9..e41dff9b5a 100644 --- a/lib/utils/include/utils/exception.decl.h +++ b/lib/utils/include/utils/exception.decl.h @@ -7,14 +7,20 @@ namespace FlexFlow { #ifdef FF_REQUIRE_IMPLEMENTED -#define NOT_IMPLEMENTED() static_assert(false, "Function " __FUNC__ " not yet implemented " __FILE__ ":" __LINE__); +#define NOT_IMPLEMENTED() \ + static_assert(false, \ + "Function " __FUNC__ " not yet implemented " __FILE__ \ + ":" __LINE__); #else -#define NOT_IMPLEMENTED() throw not_implemented(__PRETTY_FUNCTION__, __FILE__, __LINE__); +#define NOT_IMPLEMENTED() \ + throw not_implemented(__PRETTY_FUNCTION__, __FILE__, __LINE__); #endif class not_implemented : public std::logic_error { public: - not_implemented(std::string const &function_name, std::string const &file_name, int line); + not_implemented(std::string const &function_name, + std::string const &file_name, + int line); }; template diff --git a/lib/utils/include/utils/fmt.decl.h b/lib/utils/include/utils/fmt.decl.h index 71d00e1c5a..93fe6ee8b6 100644 --- a/lib/utils/include/utils/fmt.decl.h +++ b/lib/utils/include/utils/fmt.decl.h @@ -2,11 +2,11 @@ #define _FLEXFLOW_UTILS_INCLUDE_UTILS_FMT_DECL_H #include "fmt/format.h" -#include -#include -#include #include "utils/check_fmtable.h" +#include #include +#include +#include #define DELEGATE_OSTREAM(...) \ template <> \ @@ -28,10 +28,10 @@ namespace fmt { template struct formatter< - ::std::unordered_set, - Char, - std::enable_if_t>::value> -> : formatter<::std::string, Char> { + ::std::unordered_set, + Char, + std::enable_if_t>::value>> + : formatter<::std::string, Char> { template auto format(::std::unordered_set const &m, FormatContext &ctx) -> decltype(ctx.out()); @@ -42,10 +42,10 @@ struct formatter< template struct formatter< - ::std::vector, - Char, - std::enable_if_t>::value> -> : formatter<::std::string> { + ::std::vector, + Char, + std::enable_if_t>::value>> + : formatter<::std::string> { template auto format(::std::vector const &m, FormatContext &ctx) -> decltype(ctx.out()); diff --git a/lib/utils/include/utils/fmt.h b/lib/utils/include/utils/fmt.h index 8ca5b34fc2..9f4d1de500 100644 --- a/lib/utils/include/utils/fmt.h +++ b/lib/utils/include/utils/fmt.h @@ -5,26 +5,26 @@ #include "utils/fmt.decl.h" #include "utils/test_types.h" #include "utils/type_traits_core.h" -#include #include -#include #include +#include +#include namespace fmt { template template auto formatter< - ::std::unordered_set, - Char, - std::enable_if_t>::value> ->::format( - ::std::unordered_set const &m, FormatContext &ctx) - -> decltype(ctx.out()) { + ::std::unordered_set, + Char, + std::enable_if_t>::value>>:: + format(::std::unordered_set const &m, FormatContext &ctx) + -> decltype(ctx.out()) { /* CHECK_FMTABLE(T); */ /* std::string result = ::FlexFlow::join_strings( */ - /* m.cbegin(), m.cend(), ", ", [](T const &t) { return fmt::to_string(t); }); */ + /* m.cbegin(), m.cend(), ", ", [](T const &t) { return fmt::to_string(t); + * }); */ std::string result = ""; return formatter::format(result, ctx); } @@ -32,17 +32,18 @@ auto formatter< /* template */ /* std::string format_as(::std::unordered_set const &m) { */ /* return::string result = ::FlexFlow::join_strings( */ -/* m.cbegin(), m.cend(), ", ", [](T const &t) { return fmt::to_string(t); }); */ +/* m.cbegin(), m.cend(), ", ", [](T const &t) { return fmt::to_string(t); + * }); */ /* } */ template template auto formatter< - ::std::vector, - Char, - std::enable_if_t>::value> ->::format(::std::vector const &m, FormatContext &ctx) - -> decltype(ctx.out()) { + ::std::vector, + Char, + std::enable_if_t>::value>>:: + format(::std::vector const &m, FormatContext &ctx) + -> decltype(ctx.out()) { CHECK_FMTABLE(T); std::string result = ::FlexFlow::join_strings( @@ -56,7 +57,8 @@ auto formatter<::std::variant>::format(::std::variant const &m, FormatContext &ctx) -> decltype(ctx.out()) { - std::string result = std::visit([](auto &&x) { return fmt::to_string(x); }, m); + std::string result = + std::visit([](auto &&x) { return fmt::to_string(x); }, m); return formatter::format(result, ctx); } @@ -71,13 +73,13 @@ auto formatter<::std::variant>::format(::std::variant const &m, /* /1* CHECK_FMTABLE(T); *1/ */ /* /1* std::string result = ::FlexFlow::join_strings( *1/ */ -/* /1* m.cbegin(), m.cend(), ", ", [](T const &t) { return fmt::to_string(t); }); *1/ */ +/* /1* m.cbegin(), m.cend(), ", ", [](T const &t) { return + * fmt::to_string(t); }); *1/ */ /* NOT_IMPLEMENTED(); */ /* std::string result = ""; */ /* return formatter::format(result, ctx); */ /* } */ - } // namespace fmt namespace FlexFlow { diff --git a/lib/utils/include/utils/fmt/pair.h b/lib/utils/include/utils/fmt/pair.h index 6a680616c6..eb1147ae3c 100644 --- a/lib/utils/include/utils/fmt/pair.h +++ b/lib/utils/include/utils/fmt/pair.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_PAIR_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_PAIR_H -#include #include "fmt/format.h" #include "utils/check_fmtable.h" +#include namespace FlexFlow { diff --git a/lib/utils/include/utils/fmt/unordered_map.h b/lib/utils/include/utils/fmt/unordered_map.h index 1287302c7a..19701bfb0c 100644 --- a/lib/utils/include/utils/fmt/unordered_map.h +++ b/lib/utils/include/utils/fmt/unordered_map.h @@ -2,28 +2,29 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_UNORDERED_MAP_H #include "fmt/format.h" -#include -#include "utils/join_strings.h" #include "utils/check_fmtable.h" +#include "utils/join_strings.h" +#include namespace fmt { template struct formatter< - ::std::unordered_map, - Char, - std::enable_if_t>::value> -> : formatter<::std::string> { + ::std::unordered_map, + Char, + std::enable_if_t>::value>> + : formatter<::std::string> { template auto format(::std::unordered_map const &m, FormatContext &ctx) -> decltype(ctx.out()) { - /* CHECK_FMTABLE(K); */ - /* CHECK_FMTABLE(V); */ - - /* std::string result = ::FlexFlow::join_strings( */ - /* m.cbegin(), m.cend(), ", ", [](std::pair const &p) { return fmt::to_string(p); }); */ - std::string result = ""; - return formatter::format(result, ctx); + /* CHECK_FMTABLE(K); */ + /* CHECK_FMTABLE(V); */ + + /* std::string result = ::FlexFlow::join_strings( */ + /* m.cbegin(), m.cend(), ", ", [](std::pair const &p) { return + * fmt::to_string(p); }); */ + std::string result = ""; + return formatter::format(result, ctx); } }; diff --git a/lib/utils/include/utils/join_strings.h b/lib/utils/include/utils/join_strings.h index 9c761fc9ac..db82004317 100644 --- a/lib/utils/include/utils/join_strings.h +++ b/lib/utils/include/utils/join_strings.h @@ -38,7 +38,6 @@ std::string join_strings(Container const &c, std::string const &delimiter) { return join_strings(c.cbegin(), c.cend(), delimiter); } - } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/json.h b/lib/utils/include/utils/json.h index 46176f366e..f56917e329 100644 --- a/lib/utils/include/utils/json.h +++ b/lib/utils/include/utils/json.h @@ -149,7 +149,7 @@ struct VariantToJsonFunctor { template void variant_to_json(json &j, std::variant const &v) { - json jval; + json jval; visit(::FlexFlow::VariantToJsonFunctor{jval}, v); j["value"] = jval; j["index"] = v.index(); @@ -167,7 +167,7 @@ std::optional variant_from_json_impl(json const &j) { template std::optional variant_from_json_impl(json const &j, - std::index_sequence) { + std::index_sequence) { // If there were no errors when parsing, all but one element of the array // will be nullopt. This is because each call to variant_from_json_impl will // have a unique index and exactly one of them will match the index in the diff --git a/lib/utils/include/utils/optional.h b/lib/utils/include/utils/optional.h index 7abbf8ab17..2594a96c8e 100644 --- a/lib/utils/include/utils/optional.h +++ b/lib/utils/include/utils/optional.h @@ -25,7 +25,8 @@ T const &assert_unwrap(std::optional const &o) { } template -std::optional> transform(std::optional const &o, F &&f) { +std::optional> transform(std::optional const &o, + F &&f) { using Return = std::invoke_result_t; if (o.has_value()) { Return r = f(o.value()); @@ -41,10 +42,10 @@ namespace fmt { template struct formatter< - ::std::optional, - Char, - std::enable_if_t>::value> -> : formatter { + ::std::optional, + Char, + std::enable_if_t>::value>> + : formatter { template auto format(::std::optional const &q, FormatContext &ctx) -> decltype(ctx.out()) { diff --git a/lib/utils/include/utils/stack_string.h b/lib/utils/include/utils/stack_string.h index fccbbb3810..0074877768 100644 --- a/lib/utils/include/utils/stack_string.h +++ b/lib/utils/include/utils/stack_string.h @@ -4,10 +4,10 @@ #include "fmt/core.h" #include "stack_vector.h" #include "utils/fmt.h" +#include "utils/json.h" #include "utils/type_traits.h" #include #include -#include "utils/json.h" namespace FlexFlow { @@ -78,7 +78,6 @@ void from_json(json const &j, stack_string &v) { v = stack_string{as_string}; } - } // namespace FlexFlow namespace std { diff --git a/lib/utils/src/exception.cc b/lib/utils/src/exception.cc index cb7fb0397b..5f78491ef2 100644 --- a/lib/utils/src/exception.cc +++ b/lib/utils/src/exception.cc @@ -2,7 +2,12 @@ namespace FlexFlow { -not_implemented::not_implemented(std::string const &function_name, std::string const &file_name, int line) - : std::logic_error(fmt::format("Function '{}' not yet implemented at {}:{}", function_name, file_name, line)){}; +not_implemented::not_implemented(std::string const &function_name, + std::string const &file_name, + int line) + : std::logic_error(fmt::format("Function '{}' not yet implemented at {}:{}", + function_name, + file_name, + line)){}; } diff --git a/lib/utils/src/utils/integer_conversions.cc b/lib/utils/src/utils/integer_conversions.cc index 7156aab896..07ff6106a3 100644 --- a/lib/utils/src/utils/integer_conversions.cc +++ b/lib/utils/src/utils/integer_conversions.cc @@ -4,7 +4,7 @@ namespace FlexFlow { size_t size_t_from_int(int x) { - assert (x >= 0); + assert(x >= 0); return static_cast(x); } From 88709f056892bdd90ce3a6a6c5060d0b91966df6 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Sun, 26 May 2024 22:44:30 -0700 Subject: [PATCH 21/43] Add initial shape inference for BMM --- .proj.toml | 2 +- .../include/op-attrs/get_output_shapes.h | 3 - .../include/op-attrs/ops/batch_matmul.h | 10 + .../include/op-attrs/parallel_tensor_shape.h | 5 + lib/op-attrs/src/op-attrs/ops/batch_matmul.cc | 149 +++++++++-- .../src/op-attrs/parallel_tensor_dims.cc | 4 +- .../src/op-attrs/parallel_tensor_shape.cc | 12 + lib/op-attrs/src/op-attrs/tensor_shape.cc | 5 +- lib/op-attrs/test/src/test_batch_matmul.cc | 234 ++++++++++++++++++ lib/pcg/include/pcg/dataflow_input.dtg.h | 101 ++++++++ .../include/pcg/dataflow_input.variant.toml | 21 ++ lib/pcg/include/pcg/open_dataflow_graph.h | 75 ++++++ lib/pcg/src/pcg/dataflow_input.dtg.cc | 41 +++ .../tensor_pattern/get_attribute.cc | 19 +- lib/utils/include/utils/fmt/expected.h | 34 +++ lib/utils/include/utils/graph/multidiedge.h | 4 + lib/utils/src/utils/graph/multidiedge.cc | 17 ++ .../test/common/include/test/utils/doctest.h | 16 +- 18 files changed, 706 insertions(+), 46 deletions(-) create mode 100644 lib/op-attrs/test/src/test_batch_matmul.cc create mode 100644 lib/pcg/include/pcg/dataflow_input.dtg.h create mode 100644 lib/pcg/include/pcg/dataflow_input.variant.toml create mode 100644 lib/pcg/include/pcg/open_dataflow_graph.h create mode 100644 lib/pcg/src/pcg/dataflow_input.dtg.cc create mode 100644 lib/utils/include/utils/fmt/expected.h create mode 100644 lib/utils/src/utils/graph/multidiedge.cc diff --git a/.proj.toml b/.proj.toml index 43f1522186..6b97ff4a37 100644 --- a/.proj.toml +++ b/.proj.toml @@ -8,7 +8,7 @@ build_targets = [ "op-attrs", "kernels", "pcg", - # "substitutions", + "substitutions", # "compiler", ] test_targets = [ diff --git a/lib/op-attrs/include/op-attrs/get_output_shapes.h b/lib/op-attrs/include/op-attrs/get_output_shapes.h index 5f8732a9d7..c99ab6d901 100644 --- a/lib/op-attrs/include/op-attrs/get_output_shapes.h +++ b/lib/op-attrs/include/op-attrs/get_output_shapes.h @@ -112,9 +112,6 @@ std::vector get_output_shapes(Attrs const &attrs, ParallelTensorShape get_output_shape(MultiHeadAttentionAttrs const &, std::vector const &); -ParallelTensorShape get_output_shape(BatchMatmulAttrs const &, - ParallelTensorShape const &, - ParallelTensorShape const &); ParallelTensorShape get_output_shape(CastAttrs const &, ParallelTensorShape const &); ParallelTensorShape get_output_shape(CombineAttrs const &, diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h index 412d694f69..e08cdef70b 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h @@ -3,6 +3,8 @@ #include "op-attrs/ops/batch_matmul.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" +#include namespace FlexFlow { @@ -10,6 +12,14 @@ bool is_valid(BatchMatmulAttrs const &, ParallelTensorShape const &, ParallelTensorShape const &); + +tl::expected get_output_shape(BatchMatmulAttrs const &attrs, + TensorShape const &input_lhs, + TensorShape const &input_rhs); + +tl::expected get_output_shape(BatchMatmulAttrs const &attrs, + ParallelTensorShape const &input_lhs, + ParallelTensorShape const &input_rhs); } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h index 0482989e0e..5880921c55 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h @@ -19,6 +19,11 @@ TensorShape get_piece_shape(ParallelTensorShape const &); int get_num_replica_dims(ParallelTensorShape const &); int get_num_replicas(ParallelTensorShape const &); +int get_sum_degree(ParallelTensorShape const &); +int get_discard_copy_degree(ParallelTensorShape const &); + +int get_total_parallel_degree(ParallelTensorShape const &); + bool is_valid(ParallelTensorShape const &); TensorShape get_tensor_shape_unsafe(ParallelTensorShape const &); diff --git a/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc b/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc index cd3b198955..69e644441c 100644 --- a/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc +++ b/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc @@ -1,27 +1,27 @@ #include "op-attrs/ops/batch_matmul.h" +#include "op-attrs/parallel_tensor_shape.h" namespace FlexFlow { -/* bool BatchMatmulAttrs::is_valid( */ -/* ParallelTensorShape const &lhs, ParallelTensorShape const &rhs) const { - */ -/* if (!lhs.is_valid() || !rhs.is_valid()) { */ -/* return false; */ -/* } */ -/* if (lhs.num_dims() != rhs.num_dims()) { */ -/* return false; */ -/* } */ -/* for (int i = lhs.num_dims() - 1; i >= 2; i--) { */ -/* if (lhs.at(i) != rhs.at(i)) { */ -/* return false; */ -/* } */ -/* } */ -/* if (lhs.at(0) != rhs.at(1)) { */ -/* return false; */ -/* } */ - -/* return true; */ -/* } */ +// bool BatchMatmulAttrs::is_valid( +// ParallelTensorShape const &lhs, ParallelTensorShape const &rhs) const { +// if (!lhs.is_valid() || !rhs.is_valid()) { +// return false; +// } +// if (lhs.num_dims() != rhs.num_dims()) { +// return false; +// } +// for (int i = lhs.num_dims() - 1; i >= 2; i--) { +// if (lhs.at(i) != rhs.at(i)) { +// return false; +// } +// } +// if (lhs.at(0) != rhs.at(1)) { +// return false; +// } +// +// return true; +// } bool is_valid(BatchMatmulAttrs const &, ParallelTensorShape const &, @@ -29,4 +29,113 @@ bool is_valid(BatchMatmulAttrs const &, NOT_IMPLEMENTED(); } +tl::expected get_output_shape(BatchMatmulAttrs const &attrs, + TensorShape const &input_lhs, + TensorShape const &input_rhs) { + // If input_lhs is a (b×n×m) tensor, + // input_rhs is a (b×m×p) tensor, + // out will be a (b×n×p) tensor. + // https://pytorch.org/docs/stable/generated/torch.bmm.html + + if (num_dims(input_lhs) != 3) { + return tl::unexpected(fmt::format("LHS input has incorrect number of shard dims: {} != {}", num_dims(input_lhs), 3)); + } + if (num_dims(input_rhs) != 3) { + return tl::unexpected(fmt::format("RHS input has incorrect number of shard dims: {} != {}", num_dims(input_rhs), 3)); + } + if (input_lhs.data_type != input_rhs.data_type) { + return tl::unexpected(fmt::format("Input datatypes do not match: {} != {}", input_lhs.data_type, input_rhs.data_type)); + } + + size_t lhs_b = dim_at_idx(input_lhs, ff_dim_t{0}); + size_t n = dim_at_idx(input_lhs, ff_dim_t{1}); + size_t lhs_m = dim_at_idx(input_lhs, ff_dim_t{2}); + + size_t rhs_b = dim_at_idx(input_rhs, ff_dim_t{0}); + size_t rhs_m = dim_at_idx(input_rhs, ff_dim_t{1}); + size_t p = dim_at_idx(input_rhs, ff_dim_t{2}); + + if (lhs_b != rhs_b) { + return tl::unexpected(fmt::format("LHS b dim ({}) != RHS b dim ({})", lhs_b, rhs_b)); + } + if (lhs_m != rhs_m) { + return tl::unexpected(fmt::format("RHS m dim ({}) != RHS m dim ({})", lhs_m, rhs_m)); + } + + return TensorShape{ + TensorDims{ + FFOrdered{ + lhs_b, + n, + p, + }, + }, + input_lhs.data_type, + }; +} + +tl::expected get_output_shape(BatchMatmulAttrs const &attrs, + ParallelTensorShape const &input_lhs, + ParallelTensorShape const &input_rhs) { + if (num_shard_dims(input_lhs) != 3) { + return tl::unexpected(fmt::format("LHS input has incorrect number of shard dims: {} != {}", num_shard_dims(input_lhs), 3)); + } + if (num_shard_dims(input_rhs) != 3) { + return tl::unexpected(fmt::format("RHS input has incorrect number of shard dims: {} != {}", num_shard_dims(input_rhs), 3)); + } + if (input_lhs.data_type != input_rhs.data_type) { + return tl::unexpected(fmt::format("Input datatypes do not match: {} != {}", input_lhs.data_type, input_rhs.data_type)); + } + + assert (get_total_parallel_degree(input_lhs) == get_total_parallel_degree(input_rhs)); + + ShardParallelDim lhs_b = shard_dim_at_idx(input_lhs, ff_dim_t{0}); + ShardParallelDim n = shard_dim_at_idx(input_lhs, ff_dim_t{1}); + ShardParallelDim lhs_m = shard_dim_at_idx(input_lhs, ff_dim_t{2}); + + ShardParallelDim rhs_b = shard_dim_at_idx(input_rhs, ff_dim_t{0}); + ShardParallelDim rhs_m = shard_dim_at_idx(input_rhs, ff_dim_t{1}); + ShardParallelDim p = shard_dim_at_idx(input_rhs, ff_dim_t{2}); + + if (lhs_b != rhs_b) { + return tl::unexpected(fmt::format("LHS b dim ({}) != RHS b dim ({})", lhs_b, rhs_b)); + } + + if (lhs_m != rhs_m) { + return tl::unexpected(fmt::format("LHS m dim ({}) != RHS m dim ({})", lhs_m, rhs_m)); + } + + if (get_discard_copy_degree(input_lhs) != get_sum_degree(input_rhs) * p.degree) { + return tl::unexpected(fmt::format("Unexpected number of replicas in LHS: lhs.= ({}) != rhs.+ ({}) * rhs.p ({})", get_discard_copy_degree(input_lhs), get_sum_degree(input_rhs), p.degree)); + } + + if (get_discard_copy_degree(input_rhs) != get_sum_degree(input_lhs) * n.degree) { + return tl::unexpected(fmt::format("Unexpected number of replicas in RHS: rhs.= ({}) != lhs.+ ({}) * lhs.n ({})", get_discard_copy_degree(input_rhs), get_sum_degree(input_lhs), n.degree)); + } + + ShardParallelDim output_b = lhs_b; + ShardParallelDim output_n = n; + ShardParallelDim output_p = p; + + int output_discard_copy_degree = 1; + int output_sum_degree = get_total_parallel_degree(input_lhs) / (output_b.degree * output_n.degree * output_p.degree); + + ParallelTensorShape result = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + output_b, + output_n, + output_p, + }, + ReplicaParallelDimSet{ + output_sum_degree, + output_discard_copy_degree, + }, + }, + input_lhs.data_type, + }; + + return result; +} + } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc index c3cad6de19..ff4b864378 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc @@ -20,9 +20,7 @@ size_t num_shard_dims(ParallelTensorDims const &dims) { } int total_replica_degree(ParallelTensorDims const &dims) { - return product(transform(replica_dims(dims), [](ReplicaParallelDim const &d) { - return d.degree; - })); + return dims.replica_dims.discard_copy_degree * dims.replica_dims.sum_degree; } int total_shard_degree(ParallelTensorDims const &dims) { diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc index 66e99b1e86..64d48db678 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc @@ -20,6 +20,18 @@ int get_num_replicas(ParallelTensorShape const &shape) { [](ReplicaParallelDim const &d) -> int { return d.degree; })); } +int get_sum_degree(ParallelTensorShape const &shape) { + return shape.dims.replica_dims.sum_degree; +} + +int get_discard_copy_degree(ParallelTensorShape const &shape) { + return shape.dims.replica_dims.discard_copy_degree; +} + +int get_total_parallel_degree(ParallelTensorShape const &s) { + return total_parallel_degree(s.dims); +} + bool is_valid(ParallelTensorShape const &shape) { return is_valid(shape.dims); } diff --git a/lib/op-attrs/src/op-attrs/tensor_shape.cc b/lib/op-attrs/src/op-attrs/tensor_shape.cc index 5b8d6572d9..01afbddf1e 100644 --- a/lib/op-attrs/src/op-attrs/tensor_shape.cc +++ b/lib/op-attrs/src/op-attrs/tensor_shape.cc @@ -8,8 +8,7 @@ size_t num_dims(TensorShape const &s) { } size_t dim_at_idx(TensorShape const &s, ff_dim_t idx) { - if (idx.value < 0) { - return dim_at_idx(s.dims, idx); - } + return dim_at_idx(s.dims, idx); +} } // namespace FlexFlow diff --git a/lib/op-attrs/test/src/test_batch_matmul.cc b/lib/op-attrs/test/src/test_batch_matmul.cc new file mode 100644 index 0000000000..66bd0b9edc --- /dev/null +++ b/lib/op-attrs/test/src/test_batch_matmul.cc @@ -0,0 +1,234 @@ +#include "test/utils/doctest.h" +#include "op-attrs/ops/batch_matmul.h" + + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_output_shape(BatchMatmulAttrs, TensorShape)") { + size_t b = 4; + size_t m = 6; + size_t n = 8; + size_t p = 10; + + BatchMatmulAttrs attrs = { + /*a_seq_length_dim=*/0, // TODO figure out if these arguments are still relevant + /*b_seq_length_dim=*/0, + }; + + TensorShape input_lhs_shape = { + TensorDims{ + FFOrdered{ + b, + n, + m, + }, + }, + DataType::FLOAT, + }; + + SUBCASE("valid") { + TensorShape input_rhs_shape = { + TensorDims{ + FFOrdered{ + b, + m, + p, + }, + }, + DataType::FLOAT, + }; + + tl::expected result = get_output_shape(attrs, input_lhs_shape, input_rhs_shape); + + tl::expected correct_output_shape = TensorShape{ + TensorDims{ + FFOrdered{ + b, + n, + p, + }, + }, + DataType::FLOAT, + }; + + CHECK(result == correct_output_shape); + } + + SUBCASE("mismatched b") { + TensorShape input_rhs_shape = { + TensorDims{ + FFOrdered{ + b + 1, + m, + p, + }, + }, + DataType::FLOAT, + }; + + tl::expected result = get_output_shape(attrs, input_lhs_shape, input_rhs_shape); + + CHECK(!result.has_value()); + } + + SUBCASE("mismatched m") { + TensorShape input_rhs_shape = { + TensorDims{ + FFOrdered{ + b, + m + 1, + p, + }, + }, + DataType::FLOAT, + }; + + tl::expected result = get_output_shape(attrs, input_lhs_shape, input_rhs_shape); + + CHECK(!result.has_value()); + } + } + + TEST_CASE("get_output_shape(BatchMatmulAttrs, ParallelTensorShape)") { + size_t b = 2 * 2; + int o_b = 2; + size_t m = 3 * 3; + int o_m = 3; + size_t n = 5 * 5; + int o_n = 5; + size_t p = 7 * 7; + int o_p = 7; + int o_sum = 11; + + BatchMatmulAttrs attrs = { + /*a_seq_length_dim=*/0, // TODO figure out if these arguments are still relevant + /*b_seq_length_dim=*/0, + }; + + auto make_lhs = [&](int o_sum, int o_eq, int o_b, int o_n, int o_m) { + return ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{b, o_b}, + ShardParallelDim{n, o_n}, + ShardParallelDim{m, o_m}, + }, + ReplicaParallelDimSet{ + o_sum, + o_eq, + }, + }, + DataType::FLOAT, + }; + }; + + auto make_rhs = [&](int o_sum, int o_eq, int o_b, int o_m, int o_p) { + return ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{b, o_b}, + ShardParallelDim{m, o_m}, + ShardParallelDim{p, o_p}, + }, + ReplicaParallelDimSet{ + o_sum, + o_eq, + }, + }, + DataType::FLOAT, + }; + }; + + auto make_output = [&](int o_sum, int o_eq, int o_b, int o_n, int o_p) { + return ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{b, o_b}, + ShardParallelDim{n, o_n}, + ShardParallelDim{p, o_p}, + }, + ReplicaParallelDimSet{ + o_sum, + o_eq, + }, + }, + DataType::FLOAT, + }; + }; + + SUBCASE("data parallel") { + tl::expected result = get_output_shape(attrs, make_lhs(1, 1, o_b, 1, 1), make_rhs(1, 1, o_b, 1, 1)); + tl::expected correct = make_output(1, 1, o_b, 1, 1); + + CHECK(result == correct); + } + + SUBCASE("n parallel") { + tl::expected result = get_output_shape(attrs, make_lhs(1, 1, 1, o_n, 1), make_rhs(1, o_n, 1, 1, 1)); + tl::expected correct = make_output(1, 1, 1, o_n, 1); + + CHECK(result == correct); + } + + SUBCASE("p parallel") { + tl::expected result = get_output_shape(attrs, make_lhs(1, o_p, 1, 1, 1), make_rhs(1, 1, 1, 1, o_p)); + tl::expected correct = make_output(1, 1, 1, 1, o_p); + + CHECK(result == correct); + } + + SUBCASE("reduction parallel") { + tl::expected result = get_output_shape(attrs, make_lhs(1, 1, 1, 1, o_m), make_rhs(1, 1, 1, o_m, 1)); + tl::expected correct = make_output(o_m, 1, 1, 1, 1); + + CHECK(result == correct); + } + + SUBCASE("propagate reduction lhs") { + tl::expected result = get_output_shape(attrs, make_lhs(o_sum, 1, 1, 1, 1), make_rhs(1, o_sum, 1, 1, 1)); + tl::expected correct = make_output(o_sum, 1, 1, 1, 1); + + CHECK(result == correct); + } + + SUBCASE("propagate reduction rhs") { + tl::expected result = get_output_shape(attrs, make_lhs(1, o_sum, 1, 1, 1), make_rhs(o_sum, 1, 1, 1, 1)); + tl::expected correct = make_output(o_sum, 1, 1, 1, 1); + + CHECK(result == correct); + } + + SUBCASE("reduction lhs & reduction rhs") { + tl::expected result = get_output_shape(attrs, make_lhs(o_sum, o_sum, 1, 1, 1), make_rhs(o_sum, o_sum, 1, 1, 1)); + tl::expected correct = make_output(o_sum * o_sum, 1, 1, 1, 1); + + CHECK(result == correct); + } + + SUBCASE("reduction lhs & rhs (invalid)") { + tl::expected result = get_output_shape(attrs, make_lhs(o_sum, 1, 1, 1, 1), make_rhs(o_sum, 1, 1, 1, 1)); + + CHECK_MESSAGE(!result.has_value(), "Unexpected successful value: ", result); + } + + SUBCASE("reduction lhs & n") { + tl::expected result = get_output_shape(attrs, make_lhs(o_sum, 1, 1, o_n, 1), make_rhs(1, o_sum * o_n, 1, 1, 1)); + tl::expected correct = make_output(o_sum, 1, 1, o_n, 1); + + CHECK(result == correct); + } + + SUBCASE("reduction lhs & reduction rhs & n") { + tl::expected result = get_output_shape(attrs, make_lhs(o_sum, o_sum, 1, o_n, 1), make_rhs(o_sum, o_sum * o_n, 1, 1, 1)); + tl::expected correct = make_output(o_sum * o_sum, 1, 1, o_n, 1); + + CHECK(result == correct); + } + + SUBCASE("reduction lhs & reduction rhs & n & m") { + tl::expected result = get_output_shape(attrs, make_lhs(o_sum, o_sum, 1, o_n, o_m), make_rhs(o_sum, o_sum * o_n, 1, o_m, 1)); + tl::expected correct = make_output(o_sum * o_sum * o_m, 1, 1, o_n, 1); + + CHECK(result == correct); + } + } +} diff --git a/lib/pcg/include/pcg/dataflow_input.dtg.h b/lib/pcg/include/pcg/dataflow_input.dtg.h new file mode 100644 index 0000000000..c698c75c25 --- /dev/null +++ b/lib/pcg/include/pcg/dataflow_input.dtg.h @@ -0,0 +1,101 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/dataflow_input.variant.toml +/* proj-data +{ + "generated_from": "d6a7f4570e36e257383529e9bf9390ec" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DATAFLOW_INPUT_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DATAFLOW_INPUT_DTG_H + +#include "utils/graph/multidiedge.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct DataflowInput { + DataflowInput() = delete; + explicit DataflowInput(::FlexFlow::MultiDiOutput const &); + explicit DataflowInput(int const &); + template + static constexpr bool IsPartOfDataflowInput_v = + std::is_same_v || std::is_same_v; + template + ReturnType visit(Visitor &&v) const { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::MultiDiOutput>()); + return result; + } + case 1: { + ReturnType result = v(this->get()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type DataflowInput", this->index())); + } + } + } + template + ReturnType visit(Visitor &&v) { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::MultiDiOutput>()); + return result; + } + case 1: { + ReturnType result = v(this->get()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type DataflowInput", this->index())); + } + } + } + template + bool has() const { + static_assert(IsPartOfDataflowInput_v, + "DataflowInput::has() expected one of " + "[::FlexFlow::MultiDiOutput, int], received T"); + return std::holds_alternative(this->raw_variant); + } + template + T const &get() const { + static_assert(IsPartOfDataflowInput_v, + "DataflowInput::get() expected one of " + "[::FlexFlow::MultiDiOutput, int], received T"); + return std::get(this->raw_variant); + } + template + T &get() { + static_assert(IsPartOfDataflowInput_v, + "DataflowInput::get() expected one of " + "[::FlexFlow::MultiDiOutput, int], received T"); + return std::get(this->raw_variant); + } + size_t index() const { + return this->raw_variant.index(); + } + bool operator==(DataflowInput const &) const; + bool operator!=(DataflowInput const &) const; + bool operator<(DataflowInput const &) const; + bool operator>(DataflowInput const &) const; + bool operator<=(DataflowInput const &) const; + bool operator>=(DataflowInput const &) const; + std::variant<::FlexFlow::MultiDiOutput, int> raw_variant; +}; +} // namespace FlexFlow +namespace std { +template <> +struct hash<::FlexFlow::DataflowInput> { + size_t operator()(::FlexFlow::DataflowInput const &) const; +}; +} // namespace std + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DATAFLOW_INPUT_DTG_H diff --git a/lib/pcg/include/pcg/dataflow_input.variant.toml b/lib/pcg/include/pcg/dataflow_input.variant.toml new file mode 100644 index 0000000000..ac7c3ae5d7 --- /dev/null +++ b/lib/pcg/include/pcg/dataflow_input.variant.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "DataflowInput" +features = [ + "eq", + "ord", + "hash", + # "json", + # "fmt", +] + +includes = [ + "utils/graph/multidiedge.h" , +] + +[[values]] +type = "::FlexFlow::MultiDiOutput" +key = "internal" + +[[values]] +type = "int" +key = "external" diff --git a/lib/pcg/include/pcg/open_dataflow_graph.h b/lib/pcg/include/pcg/open_dataflow_graph.h new file mode 100644 index 0000000000..29454c4414 --- /dev/null +++ b/lib/pcg/include/pcg/open_dataflow_graph.h @@ -0,0 +1,75 @@ +// #ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPEN_DATAFLOW_GRAPH_H +// #define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPEN_DATAFLOW_GRAPH_H +// +// #include "utils/containers/enumerate_vector.h" +// #include "utils/graph.h" +// #include "pcg/dataflow_input.dtg.h" +// +// namespace FlexFlow { +// +// template +// struct OpenDataflowGraph { +// public: +// OpenDataflowGraph() +// : g(OutputLabelledOpenMultiDiGraph::template create< +// UnorderedOutputLabelledOpenMultiDiGraph>()) { } +// +// DataflowInput add_external_input(OutputLabel const &label) { +// /* size_t src_node_idx = edge_uid_ctr; */ +// /* edge_uid_ctr++; */ +// /* size_t src_port_idx = 0; */ +// /* edge_uid_t edge_uid = { src_node_idx, src_port_idx }; */ +// /* return MultiDiOutput{edge_uid}; */ +// } +// +// std::vector add_operator(NodeLabel const &func, std::vector const &inputs, std::vector const &outputs) { +// Node n = this->g.add_node(func); +// for (auto const &[idx, input] : enumerate_vector(inputs)) { +// this->g.add_edge(MultiDiEdge{input.src, input.src_idx, n, this->make_port_for_idx(idx)}); +// } +// +// std::vector result; +// for (auto const &[idx, label] : enumerate_vector(outputs)) { +// MultiDiOutput output = MultiDiOutput{n, this->make_port_for_idx(idx)}; +// this->g.add_output(output, label); +// result.push_back(output); +// } +// +// return result; +// } +// +// NodePort make_port_for_idx(int idx) { +// if (!this->port_mapping.contains_l(idx)) { +// this->port_mapping.equate(idx, this->g.add_node_port()); +// } +// return this->port_mapping.at_l(idx); +// } +// +// NodePort port_for_idx(int idx) const { +// return this->port_mapping.at_l(idx); +// } +// +// int idx_for_port(NodePort const &p) const { +// return this->port_mapping.at_r(p); +// } +// +// OutputLabelledMultiDiGraphView const &get_raw_graph() const { +// return this->g; +// } +// +// NodeLabel const &at(Node const &n) const { +// return this->g.at(n); +// } +// +// OutputLabel const &at(MultiDiOutput const &o) const { +// return this->g.at(o); +// } +// private: +// OutputLabelledOpenMultiDiGraph g; +// bidict port_mapping; +// size_t edge_uid_ctr = 0; +// }; +// +// } // namespace FlexFlow +// +// #endif diff --git a/lib/pcg/src/pcg/dataflow_input.dtg.cc b/lib/pcg/src/pcg/dataflow_input.dtg.cc new file mode 100644 index 0000000000..bd5a43dfa9 --- /dev/null +++ b/lib/pcg/src/pcg/dataflow_input.dtg.cc @@ -0,0 +1,41 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/dataflow_input.variant.toml +/* proj-data +{ + "generated_from": "d6a7f4570e36e257383529e9bf9390ec" +} +*/ + +#include "pcg/dataflow_input.dtg.h" + +namespace FlexFlow { +DataflowInput::DataflowInput(::FlexFlow::MultiDiOutput const &v) + : raw_variant(v) {} +DataflowInput::DataflowInput(int const &v) : raw_variant(v) {} +bool DataflowInput::operator==(DataflowInput const &other) const { + return this->raw_variant == other.raw_variant; +} +bool DataflowInput::operator!=(DataflowInput const &other) const { + return this->raw_variant != other.raw_variant; +} +bool DataflowInput::operator<(DataflowInput const &other) const { + return this->raw_variant < other.raw_variant; +} +bool DataflowInput::operator>(DataflowInput const &other) const { + return this->raw_variant > other.raw_variant; +} +bool DataflowInput::operator<=(DataflowInput const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool DataflowInput::operator>=(DataflowInput const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::DataflowInput>::operator()( + ::FlexFlow::DataflowInput const &x) const { + return std::hash>{}( + x.raw_variant); +} +} // namespace std diff --git a/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc b/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc index 4a5330d4af..4fdbc6a2ff 100644 --- a/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc +++ b/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc @@ -1,26 +1,25 @@ #include "substitutions/tensor_pattern/get_attribute.h" #include "utils/containers.h" +#include "utils/integer_conversions.h" namespace FlexFlow { -TensorAttributeValue get_attribute(ParallelTensorAttrs const &attrs, - TensorAttributeKey key) { +TensorAttributeValue get_attribute(ParallelTensorAttrs const &attrs, TensorAttributeKey key) { switch (key) { case TensorAttributeKey::DIM_SIZES: { - std::vector sizes = - transform(as_vector(ff_ordered(attrs.shape.dims)), - [](ParallelDim const &d) { return d.size; }); + std::vector sizes = transform(as_vector(ff_ordered_shard_dims(attrs.shape.dims)), + [](ShardParallelDim const &d) { return d.size; }); return TensorAttributeValue{sizes}; } case TensorAttributeKey::DIM_DEGREES: { - std::vector degrees = transform( - as_vector(ff_ordered(attrs.shape.dims)), - [](ParallelDim const &d) { return static_cast(d.degree); }); + std::vector degrees = transform(as_vector(ff_ordered_shard_dims(attrs.shape.dims)), + [](ShardParallelDim const &d) { + return size_t_from_int(d.degree); + }); return TensorAttributeValue{degrees}; } default: - throw std::runtime_error( - fmt::format("Unknown TensorAttributeKey {}", static_cast(key))); + throw std::runtime_error(fmt::format("Unknown TensorAttributeKey {}", static_cast(key))); } } diff --git a/lib/utils/include/utils/fmt/expected.h b/lib/utils/include/utils/fmt/expected.h new file mode 100644 index 0000000000..e8d7f5b22d --- /dev/null +++ b/lib/utils/include/utils/fmt/expected.h @@ -0,0 +1,34 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_EXPECTED_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_EXPECTED_H + +#include "fmt/format.h" +#include "utils/check_fmtable.h" +#include +#include + +namespace fmt { + +template +struct formatter< + ::tl::expected, + Char> + /* std::enable_if_t>::value>> */ + : formatter<::std::string> { + template + auto format(::tl::expected const &m, FormatContext &ctx) + -> decltype(ctx.out()) { + + std::string result; + if (m.has_value()) { + result = fmt::format("expected({})", m.value()); + } else { + result = fmt::format("unexpected({})", m.error()); + } + + return formatter::format(result, ctx); + } +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/multidiedge.h b/lib/utils/include/utils/graph/multidiedge.h index 808981afa1..d7c2c1590b 100644 --- a/lib/utils/include/utils/graph/multidiedge.h +++ b/lib/utils/include/utils/graph/multidiedge.h @@ -17,6 +17,10 @@ FF_VISIT_FMTABLE(MultiDiInput); struct MultiDiOutput : DiOutput { NodePort src_idx; + + bool operator>(MultiDiOutput const &) const; + bool operator>=(MultiDiOutput const &) const; + bool operator<=(MultiDiOutput const &) const; }; FF_VISITABLE_STRUCT(MultiDiOutput, src, src_idx); FF_VISIT_FMTABLE(MultiDiOutput); diff --git a/lib/utils/src/utils/graph/multidiedge.cc b/lib/utils/src/utils/graph/multidiedge.cc new file mode 100644 index 0000000000..cd3655c8e6 --- /dev/null +++ b/lib/utils/src/utils/graph/multidiedge.cc @@ -0,0 +1,17 @@ +#include "utils/graph/multidiedge.h" + +namespace FlexFlow { + +bool MultiDiOutput::operator>(MultiDiOutput const &other) const { + return !(*this < other) && !(*this == other); +} + +bool MultiDiOutput::operator>=(MultiDiOutput const &other) const { + return !(*this < other); +} + +bool MultiDiOutput::operator<=(MultiDiOutput const &other) const { + return (*this < other) || (*this == other); +} + +} // namespace FlexFlow diff --git a/lib/utils/test/common/include/test/utils/doctest.h b/lib/utils/test/common/include/test/utils/doctest.h index 47c7ebde6d..e643503de8 100644 --- a/lib/utils/test/common/include/test/utils/doctest.h +++ b/lib/utils/test/common/include/test/utils/doctest.h @@ -5,6 +5,9 @@ #include #include #include +#include +#include "utils/fmt/expected.h" +#include using namespace FlexFlow; @@ -64,10 +67,11 @@ namespace doctest { // } // }; -// template -// struct StringMaker> { -// static String convert(std::vector const &vec) { -// return doctest_print_container(vec, "[ ", ", ", " ]").c_str(); -// } -// }; +template +struct StringMaker> { + static String convert(tl::expected const &m) { + return toString(fmt::to_string(m)); + } +}; + } // namespace doctest From d549d220255e3831a31b45c3f5aff7f722828e80 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Mon, 27 May 2024 13:50:59 -0700 Subject: [PATCH 22/43] Add half of shape inference for Attention --- lib/op-attrs/include/op-attrs/ops/attention.h | 29 +- .../multihead_attention_inputs.dtg.h | 73 ++++++ .../attention/multihead_attention_inputs.h | 16 ++ .../multihead_attention_inputs.struct.toml | 39 +++ .../op-attrs/ops/attention_inputs.struct.toml | 26 -- .../include/op-attrs/parallel_tensor_shape.h | 1 + .../discard_copy_degree.dtg.h | 62 +++++ .../discard_copy_degree.struct.toml | 14 + .../parallel_tensor_shape/sum_degree.dtg.h | 62 +++++ .../sum_degree.struct.toml | 14 + .../op-attrs/replica_parallel_dim_set.dtg.h | 12 +- .../replica_parallel_dim_set.struct.toml | 9 +- lib/op-attrs/include/op-attrs/tensor_dims.h | 1 + lib/op-attrs/src/op-attrs/ops/attention.cc | 64 ++++- .../attention/multihead_attention_inputs.cc | 58 ++++ .../multihead_attention_inputs.dtg.cc | 184 +++++++++++++ .../conv_2d/conv_2d_parallel_input_shape.cc | 4 +- lib/op-attrs/src/op-attrs/ops/linear.cc | 5 +- .../src/op-attrs/parallel_tensor_dims.cc | 5 +- .../src/op-attrs/parallel_tensor_shape.cc | 11 +- .../discard_copy_degree.dtg.cc | 76 ++++++ .../parallel_tensor_shape/sum_degree.dtg.cc | 75 ++++++ .../src/op-attrs/replica_parallel_dim_set.cc | 10 +- .../op-attrs/replica_parallel_dim_set.dtg.cc | 25 +- lib/op-attrs/src/op-attrs/tensor_dims.cc | 25 +- lib/op-attrs/test/src/test_attention.cc | 247 ++++++++++++++++++ .../include/utils/containers/zip_vectors.h | 20 ++ lib/utils/include/utils/integer_conversions.h | 1 + lib/utils/src/utils/integer_conversions.cc | 6 + .../test/common/include/test/utils/doctest.h | 1 - lib/utils/test/common/src/main.cc | 2 + 31 files changed, 1093 insertions(+), 84 deletions(-) create mode 100644 lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.h create mode 100644 lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.struct.toml delete mode 100644 lib/op-attrs/include/op-attrs/ops/attention_inputs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/parallel_tensor_shape/discard_copy_degree.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/parallel_tensor_shape/sum_degree.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/parallel_tensor_shape/sum_degree.struct.toml create mode 100644 lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/parallel_tensor_shape/sum_degree.dtg.cc create mode 100644 lib/op-attrs/test/src/test_attention.cc create mode 100644 lib/utils/include/utils/containers/zip_vectors.h create mode 100644 lib/utils/test/common/src/main.cc diff --git a/lib/op-attrs/include/op-attrs/ops/attention.h b/lib/op-attrs/include/op-attrs/ops/attention.h index 2ec6585bbe..177fa6ba88 100644 --- a/lib/op-attrs/include/op-attrs/ops/attention.h +++ b/lib/op-attrs/include/op-attrs/ops/attention.h @@ -3,9 +3,8 @@ #include "core.h" #include "op-attrs/ops/attention_attrs.dtg.h" -#include "op-attrs/ops/attention_inputs.dtg.h" -#include "op-attrs/ops/parallel_attention_inputs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" +#include namespace FlexFlow { @@ -35,15 +34,23 @@ int get_kvSeqLength(MultiHeadAttentionInputs const &); int get_num_samples(ParallelMultiHeadAttentionInputs const &); int get_num_samples(MultiHeadAttentionInputs const &); -TensorShape get_weights_shape(MultiHeadAttentionAttrs const &, - MultiHeadAttentionInputs const &); -ParallelTensorShape get_weights_shape(MultiHeadAttentionAttrs const &, - ParallelMultiHeadAttentionInputs const &); - -TensorShape get_output_shape(MultiHeadAttentionAttrs const &, - MultiHeadAttentionInputs const &); -ParallelTensorShape get_output_shape(MultiHeadAttentionAttrs const &, - ParallelMultiHeadAttentionInputs const &); +tl::expected get_weights_shape(MultiHeadAttentionAttrs const &, + TensorShape const &input_q, + TensorShape const &input_k, + TensorShape const &input_v); +tl::expected get_weights_shape(MultiHeadAttentionAttrs const &, + ParallelTensorShape const &input_q, + ParallelTensorShape const &input_k, + ParallelTensorShape const &input_v); + +tl::expected get_output_shape(MultiHeadAttentionAttrs const &, + TensorShape const &input_q, + TensorShape const &input_k, + TensorShape const &input_v); +tl::expected get_output_shape(MultiHeadAttentionAttrs const &, + ParallelTensorShape const &input_q, + ParallelTensorShape const &input_k, + ParallelTensorShape const &input_v); CHECK_VALID_OP_ATTR(MultiHeadAttentionAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.dtg.h b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.dtg.h new file mode 100644 index 0000000000..4519d0bd41 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.dtg.h @@ -0,0 +1,73 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.struct.toml +/* proj-data +{ + "generated_from": "655a3e56cf8a50fba6c1c9daf423720f" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_MULTIHEAD_ATTENTION_INPUTS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_MULTIHEAD_ATTENTION_INPUTS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct MultiHeadAttentionInputs { + MultiHeadAttentionInputs() = delete; + MultiHeadAttentionInputs(size_t const &batch_size, + size_t const &sequence_length, + size_t const &query_size, + size_t const &key_size, + size_t const &value_size, + ::FlexFlow::DataType const &datatype); + + bool operator==(MultiHeadAttentionInputs const &) const; + bool operator!=(MultiHeadAttentionInputs const &) const; + bool operator<(MultiHeadAttentionInputs const &) const; + bool operator>(MultiHeadAttentionInputs const &) const; + bool operator<=(MultiHeadAttentionInputs const &) const; + bool operator>=(MultiHeadAttentionInputs const &) const; + size_t batch_size; + size_t sequence_length; + size_t query_size; + size_t key_size; + size_t value_size; + ::FlexFlow::DataType datatype; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::MultiHeadAttentionInputs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::MultiHeadAttentionInputs from_json(json const &); + static void to_json(json &, FlexFlow::MultiHeadAttentionInputs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(MultiHeadAttentionInputs const &); +std::ostream &operator<<(std::ostream &, MultiHeadAttentionInputs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_MULTIHEAD_ATTENTION_INPUTS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.h b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.h new file mode 100644 index 0000000000..044d4cfca9 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_MULTIHEAD_ATTENTION_INPUTS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_MULTIHEAD_ATTENTION_INPUTS_H + +#include "op-attrs/ops/attention/multihead_attention_inputs.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" +#include + +namespace FlexFlow { + +tl::expected parse_attention_input_shape(TensorShape const &input_q, + TensorShape const &input_k, + TensorShape const &input_v); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.struct.toml b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.struct.toml new file mode 100644 index 0000000000..b82b285451 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.struct.toml @@ -0,0 +1,39 @@ +namespace = "FlexFlow" +name = "MultiHeadAttentionInputs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "", + "op-attrs/datatype.dtg.h", +] + +[[fields]] +name = "batch_size" +type = "size_t" + +[[fields]] +name = "sequence_length" +type = "size_t" + +[[fields]] +name = "query_size" +type = "size_t" + +[[fields]] +name = "key_size" +type = "size_t" + +[[fields]] +name = "value_size" +type = "size_t" + +[[fields]] +name = "datatype" +type = "::FlexFlow::DataType" diff --git a/lib/op-attrs/include/op-attrs/ops/attention_inputs.struct.toml b/lib/op-attrs/include/op-attrs/ops/attention_inputs.struct.toml deleted file mode 100644 index 1b04c1de2d..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/attention_inputs.struct.toml +++ /dev/null @@ -1,26 +0,0 @@ -namespace = "FlexFlow" -name = "MultiHeadAttentionInputs" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "op-attrs/tensor_shape.h" -] - -[[fields]] -name = "query" -type = "::FlexFlow::TensorShape" - -[[fields]] -name = "key" -type = "::FlexFlow::TensorShape" - -[[fields]] -name = "value" -type = "::FlexFlow::TensorShape" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h index 5880921c55..bcce38eded 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h @@ -12,6 +12,7 @@ ShardParallelDim shard_dim_at_idx(ParallelTensorShape const &, ff_dim_t); ShardParallelDim &shard_dim_at_idx(ParallelTensorShape &, ff_dim_t); ParallelTensorShape lift_to_parallel(TensorShape const &); +ParallelTensorShape lift_to_parallel_with_degrees(TensorShape const &, SumDegree sum_degree, DiscardCopyDegree discard_copy_degree, FFOrdered const &shard_degrees); std::unordered_set replica_dims(ParallelTensorShape const &); diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h new file mode 100644 index 0000000000..a820bfe81c --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/parallel_tensor_shape/discard_copy_degree.struct.toml +/* proj-data +{ + "generated_from": "e4677d1fb25d3833570ee567f5659914" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SHAPE_DISCARD_COPY_DEGREE_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SHAPE_DISCARD_COPY_DEGREE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct DiscardCopyDegree { + DiscardCopyDegree() = delete; + DiscardCopyDegree(int const &value); + + bool operator==(DiscardCopyDegree const &) const; + bool operator!=(DiscardCopyDegree const &) const; + bool operator<(DiscardCopyDegree const &) const; + bool operator>(DiscardCopyDegree const &) const; + bool operator<=(DiscardCopyDegree const &) const; + bool operator>=(DiscardCopyDegree const &) const; + int value; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::DiscardCopyDegree const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::DiscardCopyDegree from_json(json const &); + static void to_json(json &, FlexFlow::DiscardCopyDegree const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(DiscardCopyDegree const &); +std::ostream &operator<<(std::ostream &, DiscardCopyDegree const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SHAPE_DISCARD_COPY_DEGREE_DTG_H diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape/discard_copy_degree.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_shape/discard_copy_degree.struct.toml new file mode 100644 index 0000000000..b4905fb0ce --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape/discard_copy_degree.struct.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "DiscardCopyDegree" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "value" +type = "int" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape/sum_degree.dtg.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape/sum_degree.dtg.h new file mode 100644 index 0000000000..17388f8d05 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape/sum_degree.dtg.h @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/parallel_tensor_shape/sum_degree.struct.toml +/* proj-data +{ + "generated_from": "e94a05618f2ad92dd7b3328a1d9c6786" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SHAPE_SUM_DEGREE_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SHAPE_SUM_DEGREE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct SumDegree { + SumDegree() = delete; + SumDegree(int const &value); + + bool operator==(SumDegree const &) const; + bool operator!=(SumDegree const &) const; + bool operator<(SumDegree const &) const; + bool operator>(SumDegree const &) const; + bool operator<=(SumDegree const &) const; + bool operator>=(SumDegree const &) const; + int value; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::SumDegree const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::SumDegree from_json(json const &); + static void to_json(json &, FlexFlow::SumDegree const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(SumDegree const &); +std::ostream &operator<<(std::ostream &, SumDegree const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SHAPE_SUM_DEGREE_DTG_H diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape/sum_degree.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_shape/sum_degree.struct.toml new file mode 100644 index 0000000000..d86917211e --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape/sum_degree.struct.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "SumDegree" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "value" +type = "int" diff --git a/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.dtg.h b/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.dtg.h index c6c025f31c..321029347f 100644 --- a/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.dtg.h +++ b/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/replica_parallel_dim_set.struct.toml /* proj-data { - "generated_from": "20d8004e6f1e710688fe692b92dc2816" + "generated_from": "74230e2d18db5c059d3e7be0f25e746e" } */ @@ -12,6 +12,8 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" +#include "op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h" +#include "op-attrs/parallel_tensor_shape/sum_degree.dtg.h" #include "rapidcheck.h" #include #include @@ -20,7 +22,9 @@ namespace FlexFlow { struct ReplicaParallelDimSet { ReplicaParallelDimSet() = delete; - ReplicaParallelDimSet(int const &sum_degree, int const &discard_copy_degree); + ReplicaParallelDimSet( + ::FlexFlow::SumDegree const &sum_degree, + ::FlexFlow::DiscardCopyDegree const &discard_copy_degree); bool operator==(ReplicaParallelDimSet const &) const; bool operator!=(ReplicaParallelDimSet const &) const; @@ -28,8 +32,8 @@ struct ReplicaParallelDimSet { bool operator>(ReplicaParallelDimSet const &) const; bool operator<=(ReplicaParallelDimSet const &) const; bool operator>=(ReplicaParallelDimSet const &) const; - int sum_degree; - int discard_copy_degree; + ::FlexFlow::SumDegree sum_degree; + ::FlexFlow::DiscardCopyDegree discard_copy_degree; }; } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.struct.toml b/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.struct.toml index 7c05a7809f..66f50bee9f 100644 --- a/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.struct.toml +++ b/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.struct.toml @@ -9,10 +9,15 @@ features = [ "fmt", ] +includes = [ + "op-attrs/parallel_tensor_shape/sum_degree.dtg.h", + "op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h", +] + [[fields]] name = "sum_degree" -type = "int" +type = "::FlexFlow::SumDegree" [[fields]] name = "discard_copy_degree" -type = "int" +type = "::FlexFlow::DiscardCopyDegree" diff --git a/lib/op-attrs/include/op-attrs/tensor_dims.h b/lib/op-attrs/include/op-attrs/tensor_dims.h index 6302ab1418..caee5c72ab 100644 --- a/lib/op-attrs/include/op-attrs/tensor_dims.h +++ b/lib/op-attrs/include/op-attrs/tensor_dims.h @@ -12,6 +12,7 @@ size_t num_dims(TensorDims const &); size_t dim_at_idx(TensorDims const &, ff_dim_t); ParallelTensorDims lift_to_parallel(TensorDims const &); +ParallelTensorDims lift_to_parallel_with_degrees(TensorDims const &, SumDegree sum_degree, DiscardCopyDegree discard_copy_degree, FFOrdered const &shard_degrees); } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/attention.cc b/lib/op-attrs/src/op-attrs/ops/attention.cc index c62e1af48b..07717a2b81 100644 --- a/lib/op-attrs/src/op-attrs/ops/attention.cc +++ b/lib/op-attrs/src/op-attrs/ops/attention.cc @@ -1,4 +1,6 @@ #include "op-attrs/ops/attention.h" +#include "op-attrs/ops/attention/multihead_attention_inputs.h" +#include "op-attrs/tensor_shape.h" namespace FlexFlow { @@ -62,12 +64,52 @@ int get_vSize(MultiHeadAttentionInputs const &) { NOT_IMPLEMENTED(); } -TensorShape get_weights_shape(MultiHeadAttentionAttrs const &attrs, - MultiHeadAttentionInputs const &inputs) { +tl::expected +get_weights_shape(MultiHeadAttentionAttrs const &attrs, + TensorShape const &input_q, + TensorShape const &input_k, + TensorShape const &input_v) { + tl::expected parse_result = parse_attention_input_shape(input_q, input_k, input_v); + if (!parse_result.has_value()) { + return tl::unexpected(parse_result.error()); + } + + MultiHeadAttentionInputs parsed = parse_result.value(); + + return TensorShape{ + TensorDims{ + ParallelTensorDims{ + parsed.batch_size, + parsed.sequence_length, + attrs.embed_dim, + } + }, + parsed.datatype, + } +} + +tl::expected +get_weights_shape(MultiHeadAttentionAttrs const &, + ParallelTensorShape const &input_q, + ParallelTensorShape const &input_k, + ParallelTensorShape const &input_v) { + NOT_IMPLEMENTED(); +} + +tl::expected get_weights_shape(MultiHeadAttentionAttrs const &attrs, + TensorShape const &query_shape, + TensorShape const &key_shape, + TensorShape const &value_shape) { + MultiHeadAttentionInputs inputs = { + query_shape, + key_shape, + value_shape, + }; + size_t qParas = get_qProjSize(attrs) * get_qSize(inputs); size_t kParas = get_kProjSize(attrs) * get_kSize(inputs); size_t vParas = get_vProjSize(attrs) * get_vSize(inputs); - TensorShape output_shape = get_output_shape(attrs, inputs); + TensorShape output_shape = get_output_shape(attrs, query_shape, key_shape, value_shape); size_t oParas = get_oProjSize(attrs) * get_oSize(output_shape); TensorDims dims = {{qParas + kParas + vParas + oParas, @@ -76,7 +118,10 @@ TensorShape get_weights_shape(MultiHeadAttentionAttrs const &attrs, return {dims, DataType::FLOAT}; } -ParallelTensorShape get_output_shape(MultiHeadAttentionAttrs const &attrs, + + +tl::expected +get_output_shape(MultiHeadAttentionAttrs const &attrs, ParallelTensorShape const &query_shape, ParallelTensorShape const &key_shape, ParallelTensorShape const &value_shape) { @@ -87,16 +132,13 @@ ParallelTensorShape get_output_shape(MultiHeadAttentionAttrs const &attrs, /* return output_shape; */ } -TensorShape get_output_shape(MultiHeadAttentionAttrs const &attrs, +tl::expected get_output_shape(MultiHeadAttentionAttrs const &attrs, TensorShape const &query_shape, TensorShape const &key_shape, TensorShape const &value_shape) { - ParallelTensorShape parallel_shape = - get_output_shape(attrs, - lift_to_parallel(query_shape), - lift_to_parallel(key_shape), - lift_to_parallel(value_shape)); - return get_tensor_shape_unsafe(parallel_shape); + + + size_t q_batchsize = dim_at_idx(query_shape } TensorShape get_output_shape(MultiHeadAttentionAttrs const &, MultiHeadAttentionInputs const &) { diff --git a/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.cc b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.cc new file mode 100644 index 0000000000..09d48df497 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.cc @@ -0,0 +1,58 @@ +#include "op-attrs/ops/attention/multihead_attention_inputs.h" +#include "op-attrs/tensor_shape.h" + +namespace FlexFlow { + +template +static bool all_same(T const &x, T const &y, T const &z) { + return x == y && y == z; +} + +tl::expected parse_attention_input_shape(TensorShape const &input_q, + TensorShape const &input_k, + TensorShape const &input_v) { + if (num_dims(input_q) != 3) { + return tl::unexpected(fmt::format("Query input has incorrect number of dims: {} != {}", num_dims(input_q), 3)); + } + if (num_dims(input_k) != 3) { + return tl::unexpected(fmt::format("Key input has incorrect number of dims: {} != {}", num_dims(input_k), 3)); + } + if (num_dims(input_v) != 3) { + return tl::unexpected(fmt::format("Value input has incorrect number of dims: {} != {}", num_dims(input_v), 3)); + } + + size_t seq_len_q = dim_at_idx(input_q, ff_dim_t{-2}); + size_t seq_len_k = dim_at_idx(input_k, ff_dim_t{-2}); + size_t seq_len_v = dim_at_idx(input_v, ff_dim_t{-2}); + + if (!all_same(seq_len_q, seq_len_k, seq_len_v)) { + return tl::unexpected(fmt::format("Q, K, V disagree on the sequence length: {} (Q) vs {} (K) vs {} (V)", seq_len_q, seq_len_k, seq_len_v)); + } + + size_t batch_size_q = dim_at_idx(input_q, ff_dim_t{-3}); + size_t batch_size_k = dim_at_idx(input_k, ff_dim_t{-3}); + size_t batch_size_v = dim_at_idx(input_v, ff_dim_t{-3}); + + if (!all_same(batch_size_q, batch_size_k, batch_size_v)) { + return tl::unexpected(fmt::format("Q, K, V disagree on the batch size: {} (Q) vs {} (K) vs {} (V)", batch_size_q, batch_size_k, batch_size_v)); + } + + if (!all_same(input_q.data_type, input_k.data_type, input_v.data_type)) { + return tl::unexpected(fmt::format("Q, K, V disagree on the datatype: {} (Q) vs {} (K) vs {} (V)", input_q.data_type, input_k.data_type, input_v.data_type)); + } + + size_t q_size = dim_at_idx(input_q, ff_dim_t{-1}); + size_t k_size = dim_at_idx(input_k, ff_dim_t{-1}); + size_t v_size = dim_at_idx(input_v, ff_dim_t{-1}); + + return MultiHeadAttentionInputs{ + batch_size_q, + seq_len_q, + q_size, + k_size, + v_size, + input_q.data_type, + }; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.dtg.cc new file mode 100644 index 0000000000..849e4ff4d9 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.dtg.cc @@ -0,0 +1,184 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.struct.toml +/* proj-data +{ + "generated_from": "655a3e56cf8a50fba6c1c9daf423720f" +} +*/ + +#include "op-attrs/ops/attention/multihead_attention_inputs.dtg.h" + +#include +#include + +namespace FlexFlow { +MultiHeadAttentionInputs::MultiHeadAttentionInputs( + size_t const &batch_size, + size_t const &sequence_length, + size_t const &query_size, + size_t const &key_size, + size_t const &value_size, + ::FlexFlow::DataType const &datatype) + : batch_size(batch_size), sequence_length(sequence_length), + query_size(query_size), key_size(key_size), value_size(value_size), + datatype(datatype) {} +bool MultiHeadAttentionInputs::operator==( + MultiHeadAttentionInputs const &other) const { + return std::tie(this->batch_size, + this->sequence_length, + this->query_size, + this->key_size, + this->value_size, + this->datatype) == std::tie(other.batch_size, + other.sequence_length, + other.query_size, + other.key_size, + other.value_size, + other.datatype); +} +bool MultiHeadAttentionInputs::operator!=( + MultiHeadAttentionInputs const &other) const { + return std::tie(this->batch_size, + this->sequence_length, + this->query_size, + this->key_size, + this->value_size, + this->datatype) != std::tie(other.batch_size, + other.sequence_length, + other.query_size, + other.key_size, + other.value_size, + other.datatype); +} +bool MultiHeadAttentionInputs::operator<( + MultiHeadAttentionInputs const &other) const { + return std::tie(this->batch_size, + this->sequence_length, + this->query_size, + this->key_size, + this->value_size, + this->datatype) < std::tie(other.batch_size, + other.sequence_length, + other.query_size, + other.key_size, + other.value_size, + other.datatype); +} +bool MultiHeadAttentionInputs::operator>( + MultiHeadAttentionInputs const &other) const { + return std::tie(this->batch_size, + this->sequence_length, + this->query_size, + this->key_size, + this->value_size, + this->datatype) > std::tie(other.batch_size, + other.sequence_length, + other.query_size, + other.key_size, + other.value_size, + other.datatype); +} +bool MultiHeadAttentionInputs::operator<=( + MultiHeadAttentionInputs const &other) const { + return std::tie(this->batch_size, + this->sequence_length, + this->query_size, + this->key_size, + this->value_size, + this->datatype) <= std::tie(other.batch_size, + other.sequence_length, + other.query_size, + other.key_size, + other.value_size, + other.datatype); +} +bool MultiHeadAttentionInputs::operator>=( + MultiHeadAttentionInputs const &other) const { + return std::tie(this->batch_size, + this->sequence_length, + this->query_size, + this->key_size, + this->value_size, + this->datatype) >= std::tie(other.batch_size, + other.sequence_length, + other.query_size, + other.key_size, + other.value_size, + other.datatype); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::MultiHeadAttentionInputs const &x) const { + size_t result = 0; + result ^= std::hash{}(x.batch_size) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.sequence_length) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash{}(x.query_size) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.key_size) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.value_size) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash<::FlexFlow::DataType>{}(x.datatype) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::MultiHeadAttentionInputs + adl_serializer::from_json( + json const &j) { + return {j.at("batch_size").template get(), + j.at("sequence_length").template get(), + j.at("query_size").template get(), + j.at("key_size").template get(), + j.at("value_size").template get(), + j.at("datatype").template get<::FlexFlow::DataType>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::MultiHeadAttentionInputs const &v) { + j["__type"] = "MultiHeadAttentionInputs"; + j["batch_size"] = v.batch_size; + j["sequence_length"] = v.sequence_length; + j["query_size"] = v.query_size; + j["key_size"] = v.key_size; + j["value_size"] = v.value_size; + j["datatype"] = v.datatype; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary<::FlexFlow::DataType>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(MultiHeadAttentionInputs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, MultiHeadAttentionInputs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.cc index 8074a03b4d..32ac4547f1 100644 --- a/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.cc +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.cc @@ -17,8 +17,8 @@ Conv2DParallelInputShape channel_dim, height_dim, width_dim, - input.dims.replica_dims.sum_degree, - input.dims.replica_dims.discard_copy_degree, + get_sum_degree(input), + get_discard_copy_degree(input), input.data_type, }; } diff --git a/lib/op-attrs/src/op-attrs/ops/linear.cc b/lib/op-attrs/src/op-attrs/ops/linear.cc index fbe336c090..e6c9dd751b 100644 --- a/lib/op-attrs/src/op-attrs/ops/linear.cc +++ b/lib/op-attrs/src/op-attrs/ops/linear.cc @@ -55,10 +55,11 @@ ParallelTensorShape get_output_shape(LinearAttrs const &attrs, ShardParallelDim output_sample_dim = input_sample_dim; ShardParallelDim output_channels_dim = { size_t_from_int(attrs.out_channels), - input_shape.dims.replica_dims.discard_copy_degree}; + get_discard_copy_degree(input_shape), + }; int output_sum_degree = - input_shape.dims.replica_dims.sum_degree * in_channels_dim.degree; + get_sum_degree(input_shape) * in_channels_dim.degree; int output_discard_copy_degree = 1; ParallelTensorShape result = input_shape; diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc index ff4b864378..3942b2c49f 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc @@ -3,6 +3,7 @@ #include "op-attrs/replica_parallel_dim_set.h" #include "op-attrs/shard_parallel_dim.h" #include "utils/containers.h" +#include "utils/integer_conversions.h" namespace FlexFlow { @@ -20,7 +21,7 @@ size_t num_shard_dims(ParallelTensorDims const &dims) { } int total_replica_degree(ParallelTensorDims const &dims) { - return dims.replica_dims.discard_copy_degree * dims.replica_dims.sum_degree; + return dims.replica_dims.discard_copy_degree.value * dims.replica_dims.sum_degree.value; } int total_shard_degree(ParallelTensorDims const &dims) { @@ -41,7 +42,7 @@ bool is_valid(ParallelTensorDims const &dims) { ShardParallelDim shard_dim_at_idx(ParallelTensorDims const &d, ff_dim_t idx) { if (idx.value < 0) { - idx = ff_dim_t{d.shard_dims.size() + idx.value}; + idx = ff_dim_t{int_from_size_t(d.shard_dims.size()) + idx.value}; } return d.shard_dims.at(idx); } diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc index 64d48db678..55d1d4af2b 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc @@ -21,11 +21,11 @@ int get_num_replicas(ParallelTensorShape const &shape) { } int get_sum_degree(ParallelTensorShape const &shape) { - return shape.dims.replica_dims.sum_degree; + return shape.dims.replica_dims.sum_degree.value; } int get_discard_copy_degree(ParallelTensorShape const &shape) { - return shape.dims.replica_dims.discard_copy_degree; + return shape.dims.replica_dims.discard_copy_degree.value; } int get_total_parallel_degree(ParallelTensorShape const &s) { @@ -48,6 +48,13 @@ ParallelTensorShape lift_to_parallel(TensorShape const &s) { return {lift_to_parallel(s.dims), s.data_type}; } +ParallelTensorShape lift_to_parallel_with_degrees(TensorShape const &s, SumDegree sum_degree, DiscardCopyDegree discard_copy_degree, FFOrdered const &shard_degrees) { + return ParallelTensorShape{ + lift_to_parallel_with_degrees(s.dims, sum_degree, discard_copy_degree, shard_degrees), + s.data_type, + }; +} + TensorShape get_tensor_shape_unsafe(ParallelTensorShape const &) { NOT_IMPLEMENTED(); } diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.cc new file mode 100644 index 0000000000..4547a5df9b --- /dev/null +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.cc @@ -0,0 +1,76 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/parallel_tensor_shape/discard_copy_degree.struct.toml +/* proj-data +{ + "generated_from": "e4677d1fb25d3833570ee567f5659914" +} +*/ + +#include "op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h" + +#include + +namespace FlexFlow { +DiscardCopyDegree::DiscardCopyDegree(int const &value) : value(value) {} +bool DiscardCopyDegree::operator==(DiscardCopyDegree const &other) const { + return std::tie(this->value) == std::tie(other.value); +} +bool DiscardCopyDegree::operator!=(DiscardCopyDegree const &other) const { + return std::tie(this->value) != std::tie(other.value); +} +bool DiscardCopyDegree::operator<(DiscardCopyDegree const &other) const { + return std::tie(this->value) < std::tie(other.value); +} +bool DiscardCopyDegree::operator>(DiscardCopyDegree const &other) const { + return std::tie(this->value) > std::tie(other.value); +} +bool DiscardCopyDegree::operator<=(DiscardCopyDegree const &other) const { + return std::tie(this->value) <= std::tie(other.value); +} +bool DiscardCopyDegree::operator>=(DiscardCopyDegree const &other) const { + return std::tie(this->value) >= std::tie(other.value); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::DiscardCopyDegree const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.value) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::DiscardCopyDegree + adl_serializer::from_json(json const &j) { + return {j.at("value").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::DiscardCopyDegree const &v) { + j["__type"] = "DiscardCopyDegree"; + j["value"] = v.value; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(DiscardCopyDegree const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, DiscardCopyDegree const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape/sum_degree.dtg.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape/sum_degree.dtg.cc new file mode 100644 index 0000000000..cf159a1ea7 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape/sum_degree.dtg.cc @@ -0,0 +1,75 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/parallel_tensor_shape/sum_degree.struct.toml +/* proj-data +{ + "generated_from": "e94a05618f2ad92dd7b3328a1d9c6786" +} +*/ + +#include "op-attrs/parallel_tensor_shape/sum_degree.dtg.h" + +#include + +namespace FlexFlow { +SumDegree::SumDegree(int const &value) : value(value) {} +bool SumDegree::operator==(SumDegree const &other) const { + return std::tie(this->value) == std::tie(other.value); +} +bool SumDegree::operator!=(SumDegree const &other) const { + return std::tie(this->value) != std::tie(other.value); +} +bool SumDegree::operator<(SumDegree const &other) const { + return std::tie(this->value) < std::tie(other.value); +} +bool SumDegree::operator>(SumDegree const &other) const { + return std::tie(this->value) > std::tie(other.value); +} +bool SumDegree::operator<=(SumDegree const &other) const { + return std::tie(this->value) <= std::tie(other.value); +} +bool SumDegree::operator>=(SumDegree const &other) const { + return std::tie(this->value) >= std::tie(other.value); +} +} // namespace FlexFlow + +namespace std { +size_t + hash::operator()(FlexFlow::SumDegree const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.value) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::SumDegree + adl_serializer::from_json(json const &j) { + return {j.at("value").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::SumDegree const &v) { + j["__type"] = "SumDegree"; + j["value"] = v.value; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(SumDegree const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, SumDegree const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.cc b/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.cc index 0ce6a40a3b..7ab2e2d2ef 100644 --- a/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.cc +++ b/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.cc @@ -11,9 +11,9 @@ int get_order_of_replica_type(ReplicaParallelDimSet const &s, ReplicaType replica_type) { switch (replica_type) { case ReplicaType::SUM: - return s.sum_degree; + return s.sum_degree.value; case ReplicaType::DISCARD_COPY: - return s.discard_copy_degree; + return s.discard_copy_degree.value; default: throw mk_runtime_error(fmt::format("Unexpected ReplicaType value: {}", static_cast(replica_type))); @@ -23,13 +23,13 @@ int get_order_of_replica_type(ReplicaParallelDimSet const &s, std::unordered_set get_replica_dims(ReplicaParallelDimSet const &s) { return std::unordered_set{ - ReplicaParallelDim{s.sum_degree, ReplicaType::SUM}, - ReplicaParallelDim{s.discard_copy_degree, ReplicaType::DISCARD_COPY}, + ReplicaParallelDim{s.sum_degree.value, ReplicaType::SUM}, + ReplicaParallelDim{s.discard_copy_degree.value, ReplicaType::DISCARD_COPY}, }; } bool is_valid(ReplicaParallelDimSet const &s) { - return s.sum_degree > 0 && s.discard_copy_degree > 0; + return s.sum_degree.value > 0 && s.discard_copy_degree.value > 0; } } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.dtg.cc b/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.dtg.cc index 3b3ac59a9b..f8782be01b 100644 --- a/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.dtg.cc +++ b/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.dtg.cc @@ -3,17 +3,20 @@ // lib/op-attrs/include/op-attrs/replica_parallel_dim_set.struct.toml /* proj-data { - "generated_from": "20d8004e6f1e710688fe692b92dc2816" + "generated_from": "74230e2d18db5c059d3e7be0f25e746e" } */ #include "op-attrs/replica_parallel_dim_set.dtg.h" +#include "op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h" +#include "op-attrs/parallel_tensor_shape/sum_degree.dtg.h" #include namespace FlexFlow { -ReplicaParallelDimSet::ReplicaParallelDimSet(int const &sum_degree, - int const &discard_copy_degree) +ReplicaParallelDimSet::ReplicaParallelDimSet( + ::FlexFlow::SumDegree const &sum_degree, + ::FlexFlow::DiscardCopyDegree const &discard_copy_degree) : sum_degree(sum_degree), discard_copy_degree(discard_copy_degree) {} bool ReplicaParallelDimSet::operator==( ReplicaParallelDimSet const &other) const { @@ -51,10 +54,10 @@ namespace std { size_t hash::operator()( FlexFlow::ReplicaParallelDimSet const &x) const { size_t result = 0; - result ^= std::hash{}(x.sum_degree) + 0x9e3779b9 + (result << 6) + - (result >> 2); - result ^= std::hash{}(x.discard_copy_degree) + 0x9e3779b9 + + result ^= std::hash<::FlexFlow::SumDegree>{}(x.sum_degree) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::DiscardCopyDegree>{}(x.discard_copy_degree) + + 0x9e3779b9 + (result << 6) + (result >> 2); return result; } } // namespace std @@ -62,8 +65,9 @@ size_t hash::operator()( namespace nlohmann { FlexFlow::ReplicaParallelDimSet adl_serializer::from_json(json const &j) { - return {j.at("sum_degree").template get(), - j.at("discard_copy_degree").template get()}; + return {j.at("sum_degree").template get<::FlexFlow::SumDegree>(), + j.at("discard_copy_degree") + .template get<::FlexFlow::DiscardCopyDegree>()}; } void adl_serializer::to_json( json &j, FlexFlow::ReplicaParallelDimSet const &v) { @@ -76,8 +80,9 @@ void adl_serializer::to_json( namespace rc { Gen Arbitrary::arbitrary() { - return gen::construct(gen::arbitrary(), - gen::arbitrary()); + return gen::construct( + gen::arbitrary<::FlexFlow::SumDegree>(), + gen::arbitrary<::FlexFlow::DiscardCopyDegree>()); } } // namespace rc diff --git a/lib/op-attrs/src/op-attrs/tensor_dims.cc b/lib/op-attrs/src/op-attrs/tensor_dims.cc index ac657dd620..74ebedd816 100644 --- a/lib/op-attrs/src/op-attrs/tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/tensor_dims.cc @@ -2,6 +2,8 @@ #include "op-attrs/replica_parallel_dim_set.h" #include "op-attrs/shard_parallel_dim.dtg.h" #include "utils/containers.h" +#include "utils/containers/zip_vectors.h" +#include "utils/integer_conversions.h" namespace FlexFlow { @@ -15,20 +17,31 @@ size_t num_dims(TensorDims const &dims) { size_t dim_at_idx(TensorDims const &dims, ff_dim_t idx) { if (idx.value < 0) { - idx = ff_dim_t{num_dims(dims) + idx.value}; + idx = ff_dim_t{int_from_size_t(num_dims(dims)) + idx.value}; } return dims.ff_ordered.at(idx); } ParallelTensorDims lift_to_parallel(TensorDims const &dims) { + std::vector shard_degrees(num_dims(dims), 1); // 1 repeated num_dims(dims) times + return lift_to_parallel_with_degrees(dims, 1, 1, shard_degrees); +} + +ParallelTensorDims lift_to_parallel_with_degrees(TensorDims const &dims, SumDegree sum_degree, DiscardCopyDegree discard_copy_degree, FFOrdered const &shard_degrees) { std::vector lifted = - transform(as_vector(dims.ff_ordered), [](size_t size) { - return ShardParallelDim{size, 1}; - }); + transform(zip(as_vector(dims.ff_ordered), as_vector(shard_degrees)), + [](std::pair const &p) { + size_t size = p.first; + int degree = p.second; + return ShardParallelDim(size, degree); + }); return ParallelTensorDims{ - FFOrdered{lifted}, - empty_replica_parallel_dim_set(), + FFOrdered{lifted}, + ReplicaParallelDimSet{ + sum_degree, + discard_copy_degree, + } }; } diff --git a/lib/op-attrs/test/src/test_attention.cc b/lib/op-attrs/test/src/test_attention.cc new file mode 100644 index 0000000000..e5310bca7c --- /dev/null +++ b/lib/op-attrs/test/src/test_attention.cc @@ -0,0 +1,247 @@ +#include "test/utils/doctest.h" +#include "op-attrs/ops/attention.h" +#include "utils/integer_conversions.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_output_shape(MultiHeadAttentionAttrs, TensorShape, TensorShape, TensorShape)") { + int embed_dim = 32; + + /* Parameter meanings can be found at + * https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html + */ + MultiHeadAttentionAttrs attrs = { + /*embed_dim=*/embed_dim, + /*num_heads=*/10, + /*kdim=*/embed_dim, + /*vdim=*/embed_dim, + /*dropout=*/0.0, + /*bias=*/true, + /*add_bias_kv=*/false, + /*add_zero_attn=*/false, + }; + + size_t n = 40; + size_t l = 48; + size_t s = 56; + + TensorShape input_q = { + TensorDims{ + FFOrdered{ + n, + l, + size_t_from_int(attrs.embed_dim), + } + }, + DataType::FLOAT, + }; + + TensorShape input_k = { + TensorDims{ + FFOrdered{ + n, + s, + size_t_from_int(attrs.kdim), + }, + }, + DataType::FLOAT, + }; + + TensorShape input_v = { + TensorDims{ + FFOrdered{ + n, + s, + size_t_from_int(attrs.vdim), + }, + }, + DataType::FLOAT, + }; + + tl::expected result = get_output_shape(attrs, input_q, input_k, input_v); + + tl::expected correct = TensorShape{ + TensorDims{ + FFOrdered{ + n, + l, + size_t_from_int(attrs.embed_dim), + } + }, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + + TEST_CASE("get_weights_shape(MultiHeadAttentionAttrs, TensorShape, TensorShape, TensorShape)") { + int embed_dim = 32; + + /* Parameter meanings can be found at + * https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html + */ + MultiHeadAttentionAttrs attrs = { + /*embed_dim=*/embed_dim, + /*num_heads=*/10, + /*kdim=*/embed_dim, + /*vdim=*/embed_dim, + /*dropout=*/0.0, + /*bias=*/true, + /*add_bias_kv=*/false, + /*add_zero_attn=*/false, + }; + + size_t n = 40; + size_t l = 48; + size_t s = 56; + + TensorShape input_q = { + TensorDims{ + FFOrdered{ + n, + l, + size_t_from_int(attrs.embed_dim), + } + }, + DataType::FLOAT, + }; + + TensorShape input_k = { + TensorDims{ + FFOrdered{ + n, + s, + size_t_from_int(attrs.kdim), + }, + }, + DataType::FLOAT, + }; + + TensorShape input_v = { + TensorDims{ + FFOrdered{ + n, + s, + size_t_from_int(attrs.vdim), + }, + }, + DataType::FLOAT, + }; + + tl::expected result = get_weights_shape(attrs, input_q, input_k, input_v); + + int qProjPerHeadWeightSize = attrs.kdim * dim_at_idx(input_q, ff_dim_t{-1}); + int kProjPerHeadWeightSize = attrs.kdim * dim_at_idx(input_k, ff_dim_t{-1}); + int vProjPerHeadWeightSize = attrs.vdim * dim_at_idx(input_v, ff_dim_t{-1}); + int oProjPerHeadWeightSize = attrs.embed_dim * attrs.vdim; + int perHeadWeightSize = qProjPerHeadWeightSize + kProjPerHeadWeightSize + vProjPerHeadWeightSize + oProjPerHeadWeightSize; + + tl::expected correct = TensorShape{ + TensorDims{ + FFOrdered{ + size_t_from_int(perHeadWeightSize), + size_t_from_int(attrs.num_heads), + } + }, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + + TEST_CASE("parallel shape inference for MultiHeadAttentionAttrs") { + int embed_dim = 32; + + /* Parameter meanings can be found at + * https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html + */ + MultiHeadAttentionAttrs attrs = { + /*embed_dim=*/embed_dim, + /*num_heads=*/10, + /*kdim=*/embed_dim, + /*vdim=*/embed_dim, + /*dropout=*/0.0, + /*bias=*/true, + /*add_bias_kv=*/false, + /*add_zero_attn=*/false, + }; + + size_t batchsize = 40; + size_t seq_len = 48; + size_t q_size = 56; + size_t k_size = 64; + size_t v_size = 72; + + TensorShape unpar_q_shape = TensorShape{ + TensorDims{ + FFOrdered{ + batchsize, + seq_len, + q_size, + }, + }, + DataType::FLOAT, + }; + + TensorShape unpar_k_shape = TensorShape{ + TensorDims{ + FFOrdered{ + batchsize, + seq_len, + k_size, + }, + }, + DataType::FLOAT, + }; + + TensorShape unpar_v_shape = TensorShape{ + TensorDims{ + FFOrdered{ + batchsize, + seq_len, + v_size, + }, + }, + DataType::FLOAT, + }; + + TensorShape unpar_o_shape = get_output_shape(attrs, unpar_q_shape, unpar_k_shape, unpar_v_shape); + TensorShape unpar_w_shape = get_output_shape(attrs, unpar_q_shape, unpar_k_shape, unpar_v_shape); + + auto make_q = [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o_batch, int o_seq_len, int o_q) { + return lift_to_parallel_with_degrees(unpar_q_shape, o_sum, o_eq, FFOrdered{o_batch, o_seq_len, o_q}); + }; + + auto make_k = [&](int o_sum, int o_eq, int o_batch, int o_seq_len, int o_k) { + return lift_to_parallel_with_degrees(unpar_k_shape, o_sum, o_eq, FFOrdered{o_batch, o_seq_len, o_k}); + }; + + auto make_v = [&](int o_sum, int o_eq, int o_batch, int o_seq_len, int o_v) { + return lift_to_parallel_with_degrees(unpar_v_shape, o_sum, o_eq, FFOrdered{o_batch, o_seq_len, o_v}); + }; + + auto make_o = [&](int o_sum, int o_eq, int o_batch, int o_seq_len, int o_o) { + return lift_to_parallel_with_degrees(unpar_o_shape, o_sum, o_eq, FFOrdered{o_batch, o_seq_len, o_o}); + }; + + auto make_w = [&](int o_sum, int o_eq, int o_e, int o_h) { + return lift_to_parallel_with_degrees(unpar_w_shape, o_sum, o_eq, FFOrdered{o_e, o_h}); + }; + + SUBCASE("data parallelism") { + int o_b = 4; + ParallelTensorShape q = make_q(1, 1, o_b, 1, 1); + ParallelTensorShape k = make_k(1, 1, o_b, 1, 1); + ParallelTensorShape v = make_v(1, 1, o_b, 1, 1); + + tl::expected result_o = get_output_shape(attrs, q, k, v); + tl::expected correct_o = make_o(1, 1, o_b, 1, 1); + + CHECK(result_o == correct_o); + + tl::expected result_w = get_weights_shape(attrs, q, k, v); + tl::expected correct_w = make_w(1, o_b, 1, 1); + + CHECK(result_w == correct_w); + } + } +} diff --git a/lib/utils/include/utils/containers/zip_vectors.h b/lib/utils/include/utils/containers/zip_vectors.h new file mode 100644 index 0000000000..84664da48f --- /dev/null +++ b/lib/utils/include/utils/containers/zip_vectors.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_VECTORS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_VECTORS_H + +#include +#include + +namespace FlexFlow { + +template +std::vector> zip(std::vector const &l, std::vector const &r) { + std::vector> result; + for (int i = 0; i < std::min(l.size(), r.size()); i++) { + result.push_back(std::make_pair(l.at(i), r.at(i))); + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/integer_conversions.h b/lib/utils/include/utils/integer_conversions.h index 6d3fb4cfdf..154aaa2a67 100644 --- a/lib/utils/include/utils/integer_conversions.h +++ b/lib/utils/include/utils/integer_conversions.h @@ -6,6 +6,7 @@ namespace FlexFlow { size_t size_t_from_int(int); +int int_from_size_t(size_t); } // namespace FlexFlow diff --git a/lib/utils/src/utils/integer_conversions.cc b/lib/utils/src/utils/integer_conversions.cc index 07ff6106a3..e3e9983c45 100644 --- a/lib/utils/src/utils/integer_conversions.cc +++ b/lib/utils/src/utils/integer_conversions.cc @@ -1,5 +1,6 @@ #include "utils/integer_conversions.h" #include +#include namespace FlexFlow { @@ -8,4 +9,9 @@ size_t size_t_from_int(int x) { return static_cast(x); } +int int_from_size_t(size_t x) { + assert (x < std::numeric_limits::max()); + return static_cast(x); +} + } // namespace FlexFlow diff --git a/lib/utils/test/common/include/test/utils/doctest.h b/lib/utils/test/common/include/test/utils/doctest.h index e643503de8..db6d2a3f3b 100644 --- a/lib/utils/test/common/include/test/utils/doctest.h +++ b/lib/utils/test/common/include/test/utils/doctest.h @@ -1,4 +1,3 @@ -#define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN #include "doctest/doctest.h" #include "utils/containers.decl.h" #include diff --git a/lib/utils/test/common/src/main.cc b/lib/utils/test/common/src/main.cc new file mode 100644 index 0000000000..9522fa7fdb --- /dev/null +++ b/lib/utils/test/common/src/main.cc @@ -0,0 +1,2 @@ +#define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN +#include "doctest/doctest.h" From 8405ebff9d70ac67b330061b32e4726366953251 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Mon, 27 May 2024 22:07:52 -0700 Subject: [PATCH 23/43] Finish initial shape inference for Attention --- .../include/op-attrs/dim_ordered/transform.h | 19 ++ lib/op-attrs/include/op-attrs/ops/attention.h | 15 +- .../multihead_attention_inputs.dtg.h | 3 +- .../multihead_attention_parallel_inputs.dtg.h | 82 +++++++ .../multihead_attention_parallel_inputs.h | 16 ++ ...head_attention_parallel_inputs.struct.toml | 46 ++++ .../op-attrs/ops/attention_inputs.dtg.h | 67 ------ lib/op-attrs/src/op-attrs/ops/attention.cc | 134 ++++++----- .../multihead_attention_inputs.dtg.cc | 3 +- .../multihead_attention_parallel_inputs.cc | 90 ++++++++ ...multihead_attention_parallel_inputs.dtg.cc | 209 ++++++++++++++++++ .../src/op-attrs/ops/attention_inputs.dtg.cc | 107 --------- .../src/op-attrs/parallel_tensor_dims.cc | 6 +- lib/op-attrs/test/src/test_attention.cc | 181 +++++++-------- lib/utils/include/utils/containers.decl.h | 5 - lib/utils/include/utils/containers.h | 6 +- .../utils/containers/vector_transform.h | 20 ++ 17 files changed, 667 insertions(+), 342 deletions(-) create mode 100644 lib/op-attrs/include/op-attrs/dim_ordered/transform.h create mode 100644 lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.h create mode 100644 lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.struct.toml delete mode 100644 lib/op-attrs/include/op-attrs/ops/attention_inputs.dtg.h create mode 100644 lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.cc delete mode 100644 lib/op-attrs/src/op-attrs/ops/attention_inputs.dtg.cc create mode 100644 lib/utils/include/utils/containers/vector_transform.h diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/transform.h b/lib/op-attrs/include/op-attrs/dim_ordered/transform.h new file mode 100644 index 0000000000..08ffff43f1 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/dim_ordered/transform.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_TRANSFORM_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_TRANSFORM_H + +#include "op-attrs/dim_ordered.h" +#include "utils/containers.h" +#include "utils/containers/vector_transform.h" + +namespace FlexFlow { + +template +DimOrdered> transform(DimOrdered const &d, F f) { + using Out = std::invoke_result_t; + + return DimOrdered{vector_transform(as_vector(d), f)}; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/ops/attention.h b/lib/op-attrs/include/op-attrs/ops/attention.h index 177fa6ba88..c552f034d0 100644 --- a/lib/op-attrs/include/op-attrs/ops/attention.h +++ b/lib/op-attrs/include/op-attrs/ops/attention.h @@ -4,6 +4,9 @@ #include "core.h" #include "op-attrs/ops/attention_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/ops/attention/multihead_attention_inputs.dtg.h" +#include "op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" #include namespace FlexFlow { @@ -13,25 +16,25 @@ int get_vProjSize(MultiHeadAttentionAttrs const &); int get_kProjSize(MultiHeadAttentionAttrs const &); int get_oProjSize(MultiHeadAttentionAttrs const &); -int get_qSize(ParallelMultiHeadAttentionInputs const &); +int get_qSize(MultiHeadAttentionParallelInputs const &); int get_qSize(MultiHeadAttentionInputs const &); -int get_kSize(ParallelMultiHeadAttentionInputs const &); +int get_kSize(MultiHeadAttentionParallelInputs const &); int get_kSize(MultiHeadAttentionInputs const &); -int get_vSize(ParallelMultiHeadAttentionInputs const &); +int get_vSize(MultiHeadAttentionParallelInputs const &); int get_vSize(MultiHeadAttentionInputs const &); int get_oSize(ParallelTensorShape const &); int get_oSize(TensorShape const &); -int get_qoSeqLength(ParallelMultiHeadAttentionInputs const &); +int get_qoSeqLength(MultiHeadAttentionParallelInputs const &); int get_qoSeqLength(MultiHeadAttentionInputs const &); -int get_kvSeqLength(ParallelMultiHeadAttentionInputs const &); +int get_kvSeqLength(MultiHeadAttentionParallelInputs const &); int get_kvSeqLength(MultiHeadAttentionInputs const &); -int get_num_samples(ParallelMultiHeadAttentionInputs const &); +int get_num_samples(MultiHeadAttentionParallelInputs const &); int get_num_samples(MultiHeadAttentionInputs const &); tl::expected get_weights_shape(MultiHeadAttentionAttrs const &, diff --git a/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.dtg.h b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.dtg.h index 4519d0bd41..7b61305a1a 100644 --- a/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.struct.toml /* proj-data { - "generated_from": "655a3e56cf8a50fba6c1c9daf423720f" + "generated_from": "c57a9d1d2822a726ee9d9369d22e8e72" } */ @@ -12,6 +12,7 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" +#include "op-attrs/datatype.dtg.h" #include "rapidcheck.h" #include #include diff --git a/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.h b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.h new file mode 100644 index 0000000000..297b1f8f1c --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.h @@ -0,0 +1,82 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.struct.toml +/* proj-data +{ + "generated_from": "7c434445707968123a361c038a337da2" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_MULTIHEAD_ATTENTION_PARALLEL_INPUTS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_MULTIHEAD_ATTENTION_PARALLEL_INPUTS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/datatype.dtg.h" +#include "op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h" +#include "op-attrs/parallel_tensor_shape/sum_degree.dtg.h" +#include "op-attrs/shard_parallel_dim.dtg.h" +#include "rapidcheck.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct MultiHeadAttentionParallelInputs { + MultiHeadAttentionParallelInputs() = delete; + MultiHeadAttentionParallelInputs( + ::FlexFlow::ShardParallelDim const &batch_dim, + ::FlexFlow::ShardParallelDim const &sequence_dim, + ::FlexFlow::ShardParallelDim const &query_dim, + ::FlexFlow::ShardParallelDim const &key_dim, + ::FlexFlow::ShardParallelDim const &value_dim, + ::FlexFlow::DiscardCopyDegree const &discard_copy_degree, + ::FlexFlow::DataType const &datatype); + + bool operator==(MultiHeadAttentionParallelInputs const &) const; + bool operator!=(MultiHeadAttentionParallelInputs const &) const; + bool operator<(MultiHeadAttentionParallelInputs const &) const; + bool operator>(MultiHeadAttentionParallelInputs const &) const; + bool operator<=(MultiHeadAttentionParallelInputs const &) const; + bool operator>=(MultiHeadAttentionParallelInputs const &) const; + ::FlexFlow::ShardParallelDim batch_dim; + ::FlexFlow::ShardParallelDim sequence_dim; + ::FlexFlow::ShardParallelDim query_dim; + ::FlexFlow::ShardParallelDim key_dim; + ::FlexFlow::ShardParallelDim value_dim; + ::FlexFlow::DiscardCopyDegree discard_copy_degree; + ::FlexFlow::DataType datatype; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::MultiHeadAttentionParallelInputs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::MultiHeadAttentionParallelInputs from_json(json const &); + static void to_json(json &, + FlexFlow::MultiHeadAttentionParallelInputs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(MultiHeadAttentionParallelInputs const &); +std::ostream &operator<<(std::ostream &, + MultiHeadAttentionParallelInputs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_MULTIHEAD_ATTENTION_PARALLEL_INPUTS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.h b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.h new file mode 100644 index 0000000000..a09a66e531 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_MULTIHEAD_ATTENTION_PARALLEL_INPUTS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_MULTIHEAD_ATTENTION_PARALLEL_INPUTS_H + +#include "op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include + +namespace FlexFlow { + +tl::expected parse_attention_parallel_input_shape(ParallelTensorShape const &input_q, + ParallelTensorShape const &input_k, + ParallelTensorShape const &input_v); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.struct.toml b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.struct.toml new file mode 100644 index 0000000000..b0636db353 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.struct.toml @@ -0,0 +1,46 @@ +namespace = "FlexFlow" +name = "MultiHeadAttentionParallelInputs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "", + "op-attrs/datatype.dtg.h", + "op-attrs/shard_parallel_dim.dtg.h", + "op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h", + "op-attrs/parallel_tensor_shape/sum_degree.dtg.h", +] + +[[fields]] +name = "batch_dim" +type = "::FlexFlow::ShardParallelDim" + +[[fields]] +name = "sequence_dim" +type = "::FlexFlow::ShardParallelDim" + +[[fields]] +name = "query_dim" +type = "::FlexFlow::ShardParallelDim" + +[[fields]] +name = "key_dim" +type = "::FlexFlow::ShardParallelDim" + +[[fields]] +name = "value_dim" +type = "::FlexFlow::ShardParallelDim" + +[[fields]] +name = "discard_copy_degree" +type = "::FlexFlow::DiscardCopyDegree" + +[[fields]] +name = "datatype" +type = "::FlexFlow::DataType" diff --git a/lib/op-attrs/include/op-attrs/ops/attention_inputs.dtg.h b/lib/op-attrs/include/op-attrs/ops/attention_inputs.dtg.h deleted file mode 100644 index 809c12c835..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/attention_inputs.dtg.h +++ /dev/null @@ -1,67 +0,0 @@ -// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! -// If you would like to modify this datatype, instead modify -// lib/op-attrs/include/op-attrs/ops/attention_inputs.struct.toml -/* proj-data -{ - "generated_from": "846dd6d3f4ca1c8135e4b3c8913fb872" -} -*/ - -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_INPUTS_DTG_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_INPUTS_DTG_H - -#include "fmt/format.h" -#include "nlohmann/json.hpp" -#include "op-attrs/tensor_shape.h" -#include "rapidcheck.h" -#include -#include -#include - -namespace FlexFlow { -struct MultiHeadAttentionInputs { - MultiHeadAttentionInputs() = delete; - MultiHeadAttentionInputs(::FlexFlow::TensorShape const &query, - ::FlexFlow::TensorShape const &key, - ::FlexFlow::TensorShape const &value); - - bool operator==(MultiHeadAttentionInputs const &) const; - bool operator!=(MultiHeadAttentionInputs const &) const; - bool operator<(MultiHeadAttentionInputs const &) const; - bool operator>(MultiHeadAttentionInputs const &) const; - bool operator<=(MultiHeadAttentionInputs const &) const; - bool operator>=(MultiHeadAttentionInputs const &) const; - ::FlexFlow::TensorShape query; - ::FlexFlow::TensorShape key; - ::FlexFlow::TensorShape value; -}; -} // namespace FlexFlow - -namespace std { -template <> -struct hash { - size_t operator()(FlexFlow::MultiHeadAttentionInputs const &) const; -}; -} // namespace std - -namespace nlohmann { -template <> -struct adl_serializer { - static FlexFlow::MultiHeadAttentionInputs from_json(json const &); - static void to_json(json &, FlexFlow::MultiHeadAttentionInputs const &); -}; -} // namespace nlohmann - -namespace rc { -template <> -struct Arbitrary { - static Gen arbitrary(); -}; -} // namespace rc - -namespace FlexFlow { -std::string format_as(MultiHeadAttentionInputs const &); -std::ostream &operator<<(std::ostream &, MultiHeadAttentionInputs const &); -} // namespace FlexFlow - -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_INPUTS_DTG_H diff --git a/lib/op-attrs/src/op-attrs/ops/attention.cc b/lib/op-attrs/src/op-attrs/ops/attention.cc index 07717a2b81..1cbebcbdc7 100644 --- a/lib/op-attrs/src/op-attrs/ops/attention.cc +++ b/lib/op-attrs/src/op-attrs/ops/attention.cc @@ -1,6 +1,9 @@ #include "op-attrs/ops/attention.h" #include "op-attrs/ops/attention/multihead_attention_inputs.h" +#include "op-attrs/ops/attention/multihead_attention_parallel_inputs.h" +#include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/tensor_shape.h" +#include "utils/integer_conversions.h" namespace FlexFlow { @@ -40,7 +43,7 @@ int get_vSize(TensorShape const &value_shape) { return dim_at_idx(value_shape, ff_dim_t(0)); } -int get_qSize(ParallelMultiHeadAttentionInputs const &) { +int get_qSize(MultiHeadAttentionParallelInputs const &) { NOT_IMPLEMENTED(); } @@ -48,7 +51,7 @@ int get_qSize(MultiHeadAttentionInputs const &) { NOT_IMPLEMENTED(); } -int get_kSize(ParallelMultiHeadAttentionInputs const &) { +int get_kSize(MultiHeadAttentionParallelInputs const &) { NOT_IMPLEMENTED(); } @@ -56,7 +59,7 @@ int get_kSize(MultiHeadAttentionInputs const &) { NOT_IMPLEMENTED(); } -int get_vSize(ParallelMultiHeadAttentionInputs const &) { +int get_vSize(MultiHeadAttentionParallelInputs const &) { NOT_IMPLEMENTED(); } @@ -64,11 +67,10 @@ int get_vSize(MultiHeadAttentionInputs const &) { NOT_IMPLEMENTED(); } -tl::expected -get_weights_shape(MultiHeadAttentionAttrs const &attrs, - TensorShape const &input_q, - TensorShape const &input_k, - TensorShape const &input_v) { +tl::expected get_output_shape(MultiHeadAttentionAttrs const &attrs, + TensorShape const &input_q, + TensorShape const &input_k, + TensorShape const &input_v) { tl::expected parse_result = parse_attention_input_shape(input_q, input_k, input_v); if (!parse_result.has_value()) { return tl::unexpected(parse_result.error()); @@ -78,73 +80,101 @@ get_weights_shape(MultiHeadAttentionAttrs const &attrs, return TensorShape{ TensorDims{ - ParallelTensorDims{ + FFOrdered{ parsed.batch_size, parsed.sequence_length, - attrs.embed_dim, + size_t_from_int(attrs.embed_dim), } }, parsed.datatype, + }; +} + +tl::expected +get_weights_shape(MultiHeadAttentionAttrs const &attrs, + TensorShape const &input_q, + TensorShape const &input_k, + TensorShape const &input_v) { + tl::expected parse_result = parse_attention_input_shape(input_q, input_k, input_v); + if (!parse_result.has_value()) { + return tl::unexpected(parse_result.error()); } + + MultiHeadAttentionInputs parsed = parse_result.value(); + + // W^Q_i in "Attention Is All You Need" top of page 5 + size_t qProjectWeightSize = parsed.query_size * attrs.kdim; + + // W^K_i in "Attention Is All You Need" top of page 5 (all i's put together) + size_t kProjectWeightSize = parsed.key_size * attrs.kdim; + + // W^V_i in "Attention Is All You Need" top of page 5 (all i's put together) + size_t vProjectWeightSize = parsed.value_size * attrs.vdim; + + // W^O in "Attention Is All You Need" top of page 5, with num_heads factored out + size_t outWeightSize = parsed.value_size * attrs.embed_dim; + + return TensorShape{ + TensorDims{ + FFOrdered{ + (qProjectWeightSize + kProjectWeightSize + vProjectWeightSize + outWeightSize), + size_t_from_int(attrs.num_heads), + } + }, + parsed.datatype, + }; } tl::expected -get_weights_shape(MultiHeadAttentionAttrs const &, +get_weights_shape(MultiHeadAttentionAttrs const &attrs, ParallelTensorShape const &input_q, ParallelTensorShape const &input_k, ParallelTensorShape const &input_v) { - NOT_IMPLEMENTED(); -} - -tl::expected get_weights_shape(MultiHeadAttentionAttrs const &attrs, - TensorShape const &query_shape, - TensorShape const &key_shape, - TensorShape const &value_shape) { - MultiHeadAttentionInputs inputs = { - query_shape, - key_shape, - value_shape, - }; + tl::expected parse_result = parse_attention_parallel_input_shape(input_q, input_k, input_v); + if (!parse_result.has_value()) { + return tl::unexpected(parse_result.error()); + } + MultiHeadAttentionParallelInputs parsed = parse_result.value(); - size_t qParas = get_qProjSize(attrs) * get_qSize(inputs); - size_t kParas = get_kProjSize(attrs) * get_kSize(inputs); - size_t vParas = get_vProjSize(attrs) * get_vSize(inputs); - TensorShape output_shape = get_output_shape(attrs, query_shape, key_shape, value_shape); - size_t oParas = get_oProjSize(attrs) * get_oSize(output_shape); + tl::expected result_unpar_get_shape = get_weights_shape(attrs, get_reduced_shape(input_q), get_reduced_shape(input_k), get_reduced_shape(input_v)); + if (!result_unpar_get_shape.has_value()) { + return tl::unexpected(result_unpar_get_shape.error()); + } + TensorShape unpar_shape = result_unpar_get_shape.value(); - TensorDims dims = {{qParas + kParas + vParas + oParas, - static_cast(attrs.embed_dim)}}; + int joined_dim_degree = 1; + int head_dim_degree = parsed.discard_copy_degree.value; - return {dims, DataType::FLOAT}; + return lift_to_parallel_with_degrees(unpar_shape, SumDegree{1}, DiscardCopyDegree{parsed.batch_dim.degree}, FFOrdered{joined_dim_degree, head_dim_degree}); } - - tl::expected get_output_shape(MultiHeadAttentionAttrs const &attrs, - ParallelTensorShape const &query_shape, - ParallelTensorShape const &key_shape, - ParallelTensorShape const &value_shape) { - NOT_IMPLEMENTED(); - /* ParallelTensorShape output_shape = query_shape; */ - /* dim_at_idx(output_shape, ff_dim_t(num_dims(output_shape) - 1)).size = - * attrs.embed_dim; */ - /* return output_shape; */ -} + ParallelTensorShape const &input_q, + ParallelTensorShape const &input_k, + ParallelTensorShape const &input_v) { + tl::expected parse_result = parse_attention_parallel_input_shape(input_q, input_k, input_v); + if (!parse_result.has_value()) { + return tl::unexpected(parse_result.error()); + } + MultiHeadAttentionParallelInputs parsed = parse_result.value(); -tl::expected get_output_shape(MultiHeadAttentionAttrs const &attrs, - TensorShape const &query_shape, - TensorShape const &key_shape, - TensorShape const &value_shape) { - + tl::expected result_unpar_get_shape = get_output_shape(attrs, get_reduced_shape(input_q), get_reduced_shape(input_k), get_reduced_shape(input_v)); + if (!result_unpar_get_shape.has_value()) { + return tl::unexpected(result_unpar_get_shape.error()); + } + TensorShape unpar_shape = result_unpar_get_shape.value(); - size_t q_batchsize = dim_at_idx(query_shape -} -TensorShape get_output_shape(MultiHeadAttentionAttrs const &, - MultiHeadAttentionInputs const &) { - NOT_IMPLEMENTED(); + int sum_degree = parsed.discard_copy_degree.value; + int discard_copy_degree = 1; + int batch_degree = parsed.batch_dim.degree; + int seq_len_degree = 1; + int out_dim_degree = 1; + + return lift_to_parallel_with_degrees(unpar_shape, SumDegree{sum_degree}, DiscardCopyDegree{discard_copy_degree}, FFOrdered{batch_degree, seq_len_degree, out_dim_degree}); } + int get_oSize(ParallelTensorShape const &) { NOT_IMPLEMENTED(); } diff --git a/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.dtg.cc index 849e4ff4d9..26d3138eb4 100644 --- a/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.dtg.cc @@ -3,12 +3,13 @@ // lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.struct.toml /* proj-data { - "generated_from": "655a3e56cf8a50fba6c1c9daf423720f" + "generated_from": "c57a9d1d2822a726ee9d9369d22e8e72" } */ #include "op-attrs/ops/attention/multihead_attention_inputs.dtg.h" +#include "op-attrs/datatype.dtg.h" #include #include diff --git a/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.cc b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.cc new file mode 100644 index 0000000000..4038d0c7a3 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.cc @@ -0,0 +1,90 @@ +#include "op-attrs/ops/attention/multihead_attention_parallel_inputs.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/ops/attention/multihead_attention_inputs.h" + +namespace FlexFlow { + +template +static bool all_same(T const &x, T const &y, T const &z) { + return x == y && y == z; +} + +tl::expected parse_attention_parallel_input_shape(ParallelTensorShape const &input_q, + ParallelTensorShape const &input_k, + ParallelTensorShape const &input_v) { + tl::expected unpar_parse_result = parse_attention_input_shape( + get_reduced_shape(input_q), get_reduced_shape(input_k), get_reduced_shape(input_v)); + if (!unpar_parse_result.has_value()) { + return tl::unexpected(fmt::format("MHA unparallel input parsing failed with message: \"{}\"", unpar_parse_result.error())); + } + + if (num_shard_dims(input_q) != 3) { + return tl::unexpected(fmt::format("Query input has incorrect number of dims: {} != {}", num_shard_dims(input_q), 3)); + } + if (num_shard_dims(input_k) != 3) { + return tl::unexpected(fmt::format("Key input has incorrect number of dims: {} != {}", num_shard_dims(input_k), 3)); + } + if (num_shard_dims(input_v) != 3) { + return tl::unexpected(fmt::format("Value input has incorrect number of dims: {} != {}", num_shard_dims(input_v), 3)); + } + + ShardParallelDim seq_len_q = shard_dim_at_idx(input_q, ff_dim_t{-2}); + if (seq_len_q.degree != 1) { + return tl::unexpected(fmt::format("Query sequence length parallel degree expected to be 1, but received degree {}", seq_len_q.degree)); + } + + ShardParallelDim seq_len_k = shard_dim_at_idx(input_k, ff_dim_t{-2}); + if (seq_len_k.degree != 1) { + return tl::unexpected(fmt::format("Key sequence length parallel degree expected to be 1, but received degree {}", seq_len_k.degree)); + } + + ShardParallelDim seq_len_v = shard_dim_at_idx(input_v, ff_dim_t{-2}); + if (seq_len_v.degree != 1) { + return tl::unexpected(fmt::format("Value sequence length parallel degree expected to be 1, but received degree {}", seq_len_v.degree)); + } + + ShardParallelDim batch_size_q = shard_dim_at_idx(input_q, ff_dim_t{-3}); + ShardParallelDim batch_size_k = shard_dim_at_idx(input_k, ff_dim_t{-3}); + ShardParallelDim batch_size_v = shard_dim_at_idx(input_v, ff_dim_t{-3}); + + if (!all_same(batch_size_q.degree, batch_size_k.degree, batch_size_v.degree)) { + return tl::unexpected(fmt::format("Q, K, V disagree on the parallel degree of the batch dimension: {} (Q) vs {} (K) vs {} (V)", batch_size_q.degree, batch_size_k.degree, batch_size_v.degree)); + } + + ShardParallelDim query_dim = shard_dim_at_idx(input_q, ff_dim_t{-1}); + if (query_dim.degree > 1) { + return tl::unexpected(fmt::format("Expected query tensor to have query dim parallel degree 1, but received degree {}", query_dim.degree)); + } + + ShardParallelDim key_dim = shard_dim_at_idx(input_k, ff_dim_t{-1}); + if (key_dim.degree > 1) { + return tl::unexpected(fmt::format("Expected key tensor to have key dim parallel degree 1, but received degree {}", key_dim.degree)); + } + + ShardParallelDim value_dim = shard_dim_at_idx(input_v, ff_dim_t{-1}); + if (value_dim.degree > 1) { + return tl::unexpected(fmt::format("Expected value tensor to have value dim parallel degree 1, but received degree {}", value_dim.degree)); + } + + int discard_copy_q = get_discard_copy_degree(input_q); + int discard_copy_k = get_discard_copy_degree(input_k); + int discard_copy_v = get_discard_copy_degree(input_v); + + if (!all_same(discard_copy_q, discard_copy_k, discard_copy_v)) { + return tl::unexpected(fmt::format("Q, K, V disagree on the discard-copy degree: {} (Q) vs {} (K) vs {} (V)", discard_copy_q, discard_copy_k, discard_copy_v)); + } + + return MultiHeadAttentionParallelInputs{ + batch_size_q, + seq_len_q, + query_dim, + key_dim, + value_dim, + discard_copy_q, + input_q.data_type, + }; + + // return; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.cc new file mode 100644 index 0000000000..94784d83cc --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.cc @@ -0,0 +1,209 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.struct.toml +/* proj-data +{ + "generated_from": "7c434445707968123a361c038a337da2" +} +*/ + +#include "op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.h" + +#include "op-attrs/datatype.dtg.h" +#include "op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h" +#include "op-attrs/parallel_tensor_shape/sum_degree.dtg.h" +#include "op-attrs/shard_parallel_dim.dtg.h" +#include +#include + +namespace FlexFlow { +MultiHeadAttentionParallelInputs::MultiHeadAttentionParallelInputs( + ::FlexFlow::ShardParallelDim const &batch_dim, + ::FlexFlow::ShardParallelDim const &sequence_dim, + ::FlexFlow::ShardParallelDim const &query_dim, + ::FlexFlow::ShardParallelDim const &key_dim, + ::FlexFlow::ShardParallelDim const &value_dim, + ::FlexFlow::DiscardCopyDegree const &discard_copy_degree, + ::FlexFlow::DataType const &datatype) + : batch_dim(batch_dim), sequence_dim(sequence_dim), query_dim(query_dim), + key_dim(key_dim), value_dim(value_dim), + discard_copy_degree(discard_copy_degree), datatype(datatype) {} +bool MultiHeadAttentionParallelInputs::operator==( + MultiHeadAttentionParallelInputs const &other) const { + return std::tie(this->batch_dim, + this->sequence_dim, + this->query_dim, + this->key_dim, + this->value_dim, + this->discard_copy_degree, + this->datatype) == std::tie(other.batch_dim, + other.sequence_dim, + other.query_dim, + other.key_dim, + other.value_dim, + other.discard_copy_degree, + other.datatype); +} +bool MultiHeadAttentionParallelInputs::operator!=( + MultiHeadAttentionParallelInputs const &other) const { + return std::tie(this->batch_dim, + this->sequence_dim, + this->query_dim, + this->key_dim, + this->value_dim, + this->discard_copy_degree, + this->datatype) != std::tie(other.batch_dim, + other.sequence_dim, + other.query_dim, + other.key_dim, + other.value_dim, + other.discard_copy_degree, + other.datatype); +} +bool MultiHeadAttentionParallelInputs::operator<( + MultiHeadAttentionParallelInputs const &other) const { + return std::tie(this->batch_dim, + this->sequence_dim, + this->query_dim, + this->key_dim, + this->value_dim, + this->discard_copy_degree, + this->datatype) < std::tie(other.batch_dim, + other.sequence_dim, + other.query_dim, + other.key_dim, + other.value_dim, + other.discard_copy_degree, + other.datatype); +} +bool MultiHeadAttentionParallelInputs::operator>( + MultiHeadAttentionParallelInputs const &other) const { + return std::tie(this->batch_dim, + this->sequence_dim, + this->query_dim, + this->key_dim, + this->value_dim, + this->discard_copy_degree, + this->datatype) > std::tie(other.batch_dim, + other.sequence_dim, + other.query_dim, + other.key_dim, + other.value_dim, + other.discard_copy_degree, + other.datatype); +} +bool MultiHeadAttentionParallelInputs::operator<=( + MultiHeadAttentionParallelInputs const &other) const { + return std::tie(this->batch_dim, + this->sequence_dim, + this->query_dim, + this->key_dim, + this->value_dim, + this->discard_copy_degree, + this->datatype) <= std::tie(other.batch_dim, + other.sequence_dim, + other.query_dim, + other.key_dim, + other.value_dim, + other.discard_copy_degree, + other.datatype); +} +bool MultiHeadAttentionParallelInputs::operator>=( + MultiHeadAttentionParallelInputs const &other) const { + return std::tie(this->batch_dim, + this->sequence_dim, + this->query_dim, + this->key_dim, + this->value_dim, + this->discard_copy_degree, + this->datatype) >= std::tie(other.batch_dim, + other.sequence_dim, + other.query_dim, + other.key_dim, + other.value_dim, + other.discard_copy_degree, + other.datatype); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::MultiHeadAttentionParallelInputs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ShardParallelDim>{}(x.batch_dim) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::ShardParallelDim>{}(x.sequence_dim) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::ShardParallelDim>{}(x.query_dim) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::ShardParallelDim>{}(x.key_dim) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::ShardParallelDim>{}(x.value_dim) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::DiscardCopyDegree>{}(x.discard_copy_degree) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::DataType>{}(x.datatype) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::MultiHeadAttentionParallelInputs + adl_serializer::from_json( + json const &j) { + return { + j.at("batch_dim").template get<::FlexFlow::ShardParallelDim>(), + j.at("sequence_dim").template get<::FlexFlow::ShardParallelDim>(), + j.at("query_dim").template get<::FlexFlow::ShardParallelDim>(), + j.at("key_dim").template get<::FlexFlow::ShardParallelDim>(), + j.at("value_dim").template get<::FlexFlow::ShardParallelDim>(), + j.at("discard_copy_degree").template get<::FlexFlow::DiscardCopyDegree>(), + j.at("datatype").template get<::FlexFlow::DataType>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::MultiHeadAttentionParallelInputs const &v) { + j["__type"] = "MultiHeadAttentionParallelInputs"; + j["batch_dim"] = v.batch_dim; + j["sequence_dim"] = v.sequence_dim; + j["query_dim"] = v.query_dim; + j["key_dim"] = v.key_dim; + j["value_dim"] = v.value_dim; + j["discard_copy_degree"] = v.discard_copy_degree; + j["datatype"] = v.datatype; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::ShardParallelDim>(), + gen::arbitrary<::FlexFlow::ShardParallelDim>(), + gen::arbitrary<::FlexFlow::ShardParallelDim>(), + gen::arbitrary<::FlexFlow::ShardParallelDim>(), + gen::arbitrary<::FlexFlow::ShardParallelDim>(), + gen::arbitrary<::FlexFlow::DiscardCopyDegree>(), + gen::arbitrary<::FlexFlow::DataType>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(MultiHeadAttentionParallelInputs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + MultiHeadAttentionParallelInputs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/attention_inputs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/attention_inputs.dtg.cc deleted file mode 100644 index f8ad72ca7a..0000000000 --- a/lib/op-attrs/src/op-attrs/ops/attention_inputs.dtg.cc +++ /dev/null @@ -1,107 +0,0 @@ -// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! -// If you would like to modify this datatype, instead modify -// lib/op-attrs/include/op-attrs/ops/attention_inputs.struct.toml -/* proj-data -{ - "generated_from": "846dd6d3f4ca1c8135e4b3c8913fb872" -} -*/ - -#include "op-attrs/ops/attention_inputs.dtg.h" - -#include "op-attrs/tensor_shape.h" -#include - -namespace FlexFlow { -MultiHeadAttentionInputs::MultiHeadAttentionInputs( - ::FlexFlow::TensorShape const &query, - ::FlexFlow::TensorShape const &key, - ::FlexFlow::TensorShape const &value) - : query(query), key(key), value(value) {} -bool MultiHeadAttentionInputs::operator==( - MultiHeadAttentionInputs const &other) const { - return std::tie(this->query, this->key, this->value) == - std::tie(other.query, other.key, other.value); -} -bool MultiHeadAttentionInputs::operator!=( - MultiHeadAttentionInputs const &other) const { - return std::tie(this->query, this->key, this->value) != - std::tie(other.query, other.key, other.value); -} -bool MultiHeadAttentionInputs::operator<( - MultiHeadAttentionInputs const &other) const { - return std::tie(this->query, this->key, this->value) < - std::tie(other.query, other.key, other.value); -} -bool MultiHeadAttentionInputs::operator>( - MultiHeadAttentionInputs const &other) const { - return std::tie(this->query, this->key, this->value) > - std::tie(other.query, other.key, other.value); -} -bool MultiHeadAttentionInputs::operator<=( - MultiHeadAttentionInputs const &other) const { - return std::tie(this->query, this->key, this->value) <= - std::tie(other.query, other.key, other.value); -} -bool MultiHeadAttentionInputs::operator>=( - MultiHeadAttentionInputs const &other) const { - return std::tie(this->query, this->key, this->value) >= - std::tie(other.query, other.key, other.value); -} -} // namespace FlexFlow - -namespace std { -size_t hash::operator()( - FlexFlow::MultiHeadAttentionInputs const &x) const { - size_t result = 0; - result ^= std::hash<::FlexFlow::TensorShape>{}(x.query) + 0x9e3779b9 + - (result << 6) + (result >> 2); - result ^= std::hash<::FlexFlow::TensorShape>{}(x.key) + 0x9e3779b9 + - (result << 6) + (result >> 2); - result ^= std::hash<::FlexFlow::TensorShape>{}(x.value) + 0x9e3779b9 + - (result << 6) + (result >> 2); - return result; -} -} // namespace std - -namespace nlohmann { -FlexFlow::MultiHeadAttentionInputs - adl_serializer::from_json( - json const &j) { - return {j.at("query").template get<::FlexFlow::TensorShape>(), - j.at("key").template get<::FlexFlow::TensorShape>(), - j.at("value").template get<::FlexFlow::TensorShape>()}; -} -void adl_serializer::to_json( - json &j, FlexFlow::MultiHeadAttentionInputs const &v) { - j["__type"] = "MultiHeadAttentionInputs"; - j["query"] = v.query; - j["key"] = v.key; - j["value"] = v.value; -} -} // namespace nlohmann - -namespace rc { -Gen - Arbitrary::arbitrary() { - return gen::construct( - gen::arbitrary<::FlexFlow::TensorShape>(), - gen::arbitrary<::FlexFlow::TensorShape>(), - gen::arbitrary<::FlexFlow::TensorShape>()); -} -} // namespace rc - -namespace FlexFlow { -std::string format_as(MultiHeadAttentionInputs const &x) { - std::ostringstream oss; - oss << ""; - return oss.str(); -} -std::ostream &operator<<(std::ostream &s, MultiHeadAttentionInputs const &x) { - return s << fmt::to_string(x); -} -} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc index 3942b2c49f..019a9f9223 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc @@ -4,6 +4,7 @@ #include "op-attrs/shard_parallel_dim.h" #include "utils/containers.h" #include "utils/integer_conversions.h" +#include "op-attrs/dim_ordered/transform.h" namespace FlexFlow { @@ -59,8 +60,9 @@ TensorDims get_tensor_dims_unsafe(ParallelTensorDims const &) { NOT_IMPLEMENTED(); } -TensorDims get_reduced_dims(ParallelTensorDims const &) { - NOT_IMPLEMENTED(); +TensorDims get_reduced_dims(ParallelTensorDims const &dims) { + FFOrdered dim_sizes = transform(dims.shard_dims, [](ShardParallelDim const &d) { return d.size; }); + return TensorDims{dim_sizes}; } } // namespace FlexFlow diff --git a/lib/op-attrs/test/src/test_attention.cc b/lib/op-attrs/test/src/test_attention.cc index e5310bca7c..a28068780a 100644 --- a/lib/op-attrs/test/src/test_attention.cc +++ b/lib/op-attrs/test/src/test_attention.cc @@ -1,12 +1,13 @@ #include "test/utils/doctest.h" #include "op-attrs/ops/attention.h" #include "utils/integer_conversions.h" +#include "op-attrs/parallel_tensor_shape.h" TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_output_shape(MultiHeadAttentionAttrs, TensorShape, TensorShape, TensorShape)") { int embed_dim = 32; - /* Parameter meanings can be found at + /* Parameter meanings match those at * https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html */ MultiHeadAttentionAttrs attrs = { @@ -20,15 +21,14 @@ TEST_SUITE(FF_TEST_SUITE) { /*add_zero_attn=*/false, }; - size_t n = 40; - size_t l = 48; - size_t s = 56; + size_t batch_size = 40; + size_t seq_len = 48; TensorShape input_q = { TensorDims{ FFOrdered{ - n, - l, + batch_size, + seq_len, size_t_from_int(attrs.embed_dim), } }, @@ -38,8 +38,8 @@ TEST_SUITE(FF_TEST_SUITE) { TensorShape input_k = { TensorDims{ FFOrdered{ - n, - s, + batch_size, + seq_len, size_t_from_int(attrs.kdim), }, }, @@ -49,103 +49,52 @@ TEST_SUITE(FF_TEST_SUITE) { TensorShape input_v = { TensorDims{ FFOrdered{ - n, - s, + batch_size, + seq_len, size_t_from_int(attrs.vdim), }, }, DataType::FLOAT, }; - tl::expected result = get_output_shape(attrs, input_q, input_k, input_v); + SUBCASE("get_output_shape") { + tl::expected result = get_output_shape(attrs, input_q, input_k, input_v); - tl::expected correct = TensorShape{ - TensorDims{ - FFOrdered{ - n, - l, - size_t_from_int(attrs.embed_dim), - } - }, - DataType::FLOAT, - }; - - CHECK(result == correct); - } - - TEST_CASE("get_weights_shape(MultiHeadAttentionAttrs, TensorShape, TensorShape, TensorShape)") { - int embed_dim = 32; - - /* Parameter meanings can be found at - * https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html - */ - MultiHeadAttentionAttrs attrs = { - /*embed_dim=*/embed_dim, - /*num_heads=*/10, - /*kdim=*/embed_dim, - /*vdim=*/embed_dim, - /*dropout=*/0.0, - /*bias=*/true, - /*add_bias_kv=*/false, - /*add_zero_attn=*/false, - }; - - size_t n = 40; - size_t l = 48; - size_t s = 56; - - TensorShape input_q = { - TensorDims{ - FFOrdered{ - n, - l, - size_t_from_int(attrs.embed_dim), - } - }, - DataType::FLOAT, - }; - - TensorShape input_k = { - TensorDims{ - FFOrdered{ - n, - s, - size_t_from_int(attrs.kdim), + tl::expected correct = TensorShape{ + TensorDims{ + FFOrdered{ + batch_size, + seq_len, + size_t_from_int(attrs.embed_dim), + } }, - }, - DataType::FLOAT, - }; - - TensorShape input_v = { - TensorDims{ - FFOrdered{ - n, - s, - size_t_from_int(attrs.vdim), - }, - }, - DataType::FLOAT, - }; - - tl::expected result = get_weights_shape(attrs, input_q, input_k, input_v); + DataType::FLOAT, + }; - int qProjPerHeadWeightSize = attrs.kdim * dim_at_idx(input_q, ff_dim_t{-1}); - int kProjPerHeadWeightSize = attrs.kdim * dim_at_idx(input_k, ff_dim_t{-1}); - int vProjPerHeadWeightSize = attrs.vdim * dim_at_idx(input_v, ff_dim_t{-1}); - int oProjPerHeadWeightSize = attrs.embed_dim * attrs.vdim; - int perHeadWeightSize = qProjPerHeadWeightSize + kProjPerHeadWeightSize + vProjPerHeadWeightSize + oProjPerHeadWeightSize; + CHECK(result == correct); + } - tl::expected correct = TensorShape{ - TensorDims{ - FFOrdered{ - size_t_from_int(perHeadWeightSize), - size_t_from_int(attrs.num_heads), - } - }, - DataType::FLOAT, - }; + SUBCASE("get_weights_shape") { + tl::expected result = get_weights_shape(attrs, input_q, input_k, input_v); + + int qProjPerHeadWeightSize = attrs.kdim * dim_at_idx(input_q, ff_dim_t{-1}); + int kProjPerHeadWeightSize = attrs.kdim * dim_at_idx(input_k, ff_dim_t{-1}); + int vProjPerHeadWeightSize = attrs.vdim * dim_at_idx(input_v, ff_dim_t{-1}); + int oProjPerHeadWeightSize = attrs.embed_dim * attrs.vdim; + int perHeadWeightSize = qProjPerHeadWeightSize + kProjPerHeadWeightSize + vProjPerHeadWeightSize + oProjPerHeadWeightSize; + + tl::expected correct = TensorShape{ + TensorDims{ + FFOrdered{ + size_t_from_int(perHeadWeightSize), + size_t_from_int(attrs.num_heads), + } + }, + DataType::FLOAT, + }; - CHECK(result == correct); + CHECK(result == correct); + } } TEST_CASE("parallel shape inference for MultiHeadAttentionAttrs") { @@ -204,8 +153,13 @@ TEST_SUITE(FF_TEST_SUITE) { DataType::FLOAT, }; - TensorShape unpar_o_shape = get_output_shape(attrs, unpar_q_shape, unpar_k_shape, unpar_v_shape); - TensorShape unpar_w_shape = get_output_shape(attrs, unpar_q_shape, unpar_k_shape, unpar_v_shape); + tl::expected result_unpar_o_shape = get_output_shape(attrs, unpar_q_shape, unpar_k_shape, unpar_v_shape); + REQUIRE(result_unpar_o_shape.has_value()); + TensorShape unpar_o_shape = result_unpar_o_shape.value(); + + tl::expected result_unpar_w_shape = get_weights_shape(attrs, unpar_q_shape, unpar_k_shape, unpar_v_shape); + REQUIRE(result_unpar_o_shape.has_value()); + TensorShape unpar_w_shape = result_unpar_w_shape.value(); auto make_q = [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o_batch, int o_seq_len, int o_q) { return lift_to_parallel_with_degrees(unpar_q_shape, o_sum, o_eq, FFOrdered{o_batch, o_seq_len, o_q}); @@ -243,5 +197,40 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result_w == correct_w); } + + SUBCASE("attention head parallelism") { + int o_h = 2; + ParallelTensorShape q = make_q(1, o_h, 1, 1, 1); + ParallelTensorShape k = make_k(1, o_h, 1, 1, 1); + ParallelTensorShape v = make_v(1, o_h, 1, 1, 1); + + tl::expected result_o = get_output_shape(attrs, q, k, v); + tl::expected correct_o = make_o(o_h, 1, 1, 1, 1); + + CHECK(result_o == correct_o); + + tl::expected result_w = get_weights_shape(attrs, q, k, v); + tl::expected correct_w = make_w(1, 1, 1, o_h); + + CHECK(result_w == correct_w); + } + + SUBCASE("combined data & attention head parallelism") { + int o_b = 4; + int o_h = 2; + ParallelTensorShape q = make_q(1, o_h, o_b, 1, 1); + ParallelTensorShape k = make_k(1, o_h, o_b, 1, 1); + ParallelTensorShape v = make_v(1, o_h, o_b, 1, 1); + + tl::expected result_o = get_output_shape(attrs, q, k, v); + tl::expected correct_o = make_o(o_h, 1, o_b, 1, 1); + + CHECK(result_o == correct_o); + + tl::expected result_w = get_weights_shape(attrs, q, k, v); + tl::expected correct_w = make_w(1, o_b, 1, o_h); + + CHECK(result_w == correct_w); + } } } diff --git a/lib/utils/include/utils/containers.decl.h b/lib/utils/include/utils/containers.decl.h index 7b6b9e4697..b02c95bf77 100644 --- a/lib/utils/include/utils/containers.decl.h +++ b/lib/utils/include/utils/containers.decl.h @@ -219,11 +219,6 @@ template auto transform(req const &c, F const &f) -> decltype(transform(std::declval(), std::declval())); -template ()(std::declval()))> -std::vector vector_transform(F const &f, std::vector const &v); - template ()(std::declval()))> diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index bc34b00a47..3efd56726c 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -19,6 +19,7 @@ #include #include #include +#include "utils/containers/vector_transform.h" namespace FlexFlow { @@ -504,11 +505,6 @@ auto transform(req const &c, F const &f) return transform(static_cast(c), f); } -template -std::vector vector_transform(F const &f, std::vector const &v) { - return transform(v, f); -} - template std::unordered_set transform(std::unordered_set const &v, F const &f) { std::unordered_set result; diff --git a/lib/utils/include/utils/containers/vector_transform.h b/lib/utils/include/utils/containers/vector_transform.h new file mode 100644 index 0000000000..13865732aa --- /dev/null +++ b/lib/utils/include/utils/containers/vector_transform.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VECTOR_TRANSFORM_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VECTOR_TRANSFORM_H + +#include +#include + +namespace FlexFlow { + +template +std::vector> vector_transform(std::vector const &v, F const &f) { + using Out = std::invoke_result_t; + + std::vector result; + std::transform(v.cbegin(), v.cend(), std::back_inserter(result), f); + return result; +} + +} // namespace FlexFlow + +#endif From 946acca269b6c629c97e3bd9e566efde07ea9c2d Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Mon, 27 May 2024 22:09:04 -0700 Subject: [PATCH 24/43] Enable op-attrs and pcg tests in CI --- .github/workflows/per-lib-check.yml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/.github/workflows/per-lib-check.yml b/.github/workflows/per-lib-check.yml index a53a6afc11..eedcd2582e 100644 --- a/.github/workflows/per-lib-check.yml +++ b/.github/workflows/per-lib-check.yml @@ -88,6 +88,14 @@ jobs: run: | test_libs.sh utils + - name: Test op-attrs + run: | + test_libs.sh op-attrs + + - name: Test pcg + run: | + test_libs.sh pcg + - name: Test substitutions run: | test_libs.sh substitutions From 1a6d1f8d9de10f1d9159686c61b13e07b71671b6 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Wed, 29 May 2024 18:01:02 -0700 Subject: [PATCH 25/43] Add parallel shape inference for add and relu --- .../include/op-attrs/ops/element_binary.h | 12 +- .../include/op-attrs/ops/element_unary.h | 22 +-- lib/op-attrs/include/op-attrs/tensor_dims.h | 1 + lib/op-attrs/include/op-attrs/tensor_shape.h | 1 + .../src/op-attrs/ops/element_binary.cc | 65 +++++++-- .../src/op-attrs/ops/element_unary.cc | 43 ++++-- lib/op-attrs/src/op-attrs/tensor_dims.cc | 7 + lib/op-attrs/src/op-attrs/tensor_shape.cc | 4 + lib/op-attrs/test/src/test_element_binary.cc | 130 ++++++++++++++++++ lib/op-attrs/test/src/test_element_unary.cc | 60 ++++++++ lib/pcg/src/pcg/computation_graph_builder.cc | 9 +- lib/utils/include/utils/exception.decl.h | 4 + lib/utils/include/utils/exception.h | 10 ++ 13 files changed, 332 insertions(+), 36 deletions(-) create mode 100644 lib/op-attrs/test/src/test_element_binary.cc create mode 100644 lib/op-attrs/test/src/test_element_unary.cc diff --git a/lib/op-attrs/include/op-attrs/ops/element_binary.h b/lib/op-attrs/include/op-attrs/ops/element_binary.h index 39ae70ecfe..a33e48e524 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_binary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_binary.h @@ -1,18 +1,20 @@ #ifndef _FLEXFLOW_ELEMENT_BINARY_ATTRS_H #define _FLEXFLOW_ELEMENT_BINARY_ATTRS_H -#include "core.h" +#include "op-attrs/ops/core.h" #include "op-attrs/ops/element_binary_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.h" +#include namespace FlexFlow { -ParallelTensorShape get_output_shape(ElementBinaryAttrs const &, - ParallelTensorShape const &, - ParallelTensorShape const &); -TensorShape get_output_shape(ElementBinaryAttrs const &, +tl::expected + get_output_shape(ElementBinaryAttrs const &, TensorShape const &, TensorShape const &); +tl::expected get_output_shape(ElementBinaryAttrs const &, + ParallelTensorShape const &, + ParallelTensorShape const &); CHECK_VALID_OP_ATTR(ElementBinaryAttrs); diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary.h b/lib/op-attrs/include/op-attrs/ops/element_unary.h index cfec033a16..3dc16d4ebb 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_unary.h @@ -1,22 +1,28 @@ #ifndef _FLEXFLOW_ELEMENTARY_UNARY_ATTRS_H #define _FLEXFLOW_ELEMENTARY_UNARY_ATTRS_H -#include "core.h" +#include "op-attrs/ops/core.h" #include "op-attrs/ops/element_scalar_unary_attrs.dtg.h" #include "op-attrs/ops/element_unary_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" +#include namespace FlexFlow { -ParallelTensorShape get_output_shape(ElementUnaryAttrs const &, - ParallelTensorShape const &); -TensorShape get_output_shape(ElementUnaryAttrs const &, TensorShape const &); +tl::expected + get_output_shape(ElementUnaryAttrs const &, + TensorShape const &); +tl::expected + get_output_shape(ElementUnaryAttrs const &, + ParallelTensorShape const &); -ParallelTensorShape get_output_shape(ElementScalarUnaryAttrs const &, - ParallelTensorShape const &); -TensorShape get_output_shape(ElementScalarUnaryAttrs const &, - TensorShape const &); +tl::expected + get_output_shape(ElementScalarUnaryAttrs const &, + TensorShape const &); +tl::expected + get_output_shape(ElementScalarUnaryAttrs const &, + ParallelTensorShape const &); CHECK_VALID_OP_ATTR(ElementUnaryAttrs); CHECK_VALID_OP_ATTR(ElementScalarUnaryAttrs); diff --git a/lib/op-attrs/include/op-attrs/tensor_dims.h b/lib/op-attrs/include/op-attrs/tensor_dims.h index caee5c72ab..0f4a793430 100644 --- a/lib/op-attrs/include/op-attrs/tensor_dims.h +++ b/lib/op-attrs/include/op-attrs/tensor_dims.h @@ -10,6 +10,7 @@ FFOrdered const &ff_ordered(TensorDims const &); size_t num_dims(TensorDims const &); size_t dim_at_idx(TensorDims const &, ff_dim_t); +size_t &dim_at_idx(TensorDims &, ff_dim_t); ParallelTensorDims lift_to_parallel(TensorDims const &); ParallelTensorDims lift_to_parallel_with_degrees(TensorDims const &, SumDegree sum_degree, DiscardCopyDegree discard_copy_degree, FFOrdered const &shard_degrees); diff --git a/lib/op-attrs/include/op-attrs/tensor_shape.h b/lib/op-attrs/include/op-attrs/tensor_shape.h index de27632919..ad751461e8 100644 --- a/lib/op-attrs/include/op-attrs/tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/tensor_shape.h @@ -7,6 +7,7 @@ namespace FlexFlow { size_t num_dims(TensorShape const &); size_t dim_at_idx(TensorShape const &, ff_dim_t); +size_t &dim_at_idx(TensorShape &, ff_dim_t); } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/element_binary.cc b/lib/op-attrs/src/op-attrs/ops/element_binary.cc index 2bbd3a1e2e..d998834f0a 100644 --- a/lib/op-attrs/src/op-attrs/ops/element_binary.cc +++ b/lib/op-attrs/src/op-attrs/ops/element_binary.cc @@ -2,16 +2,65 @@ namespace FlexFlow { -ParallelTensorShape get_output_shape(ElementBinaryAttrs const &, - ParallelTensorShape const &, - ParallelTensorShape const &) { - NOT_IMPLEMENTED(); +tl::expected +get_output_shape(ElementBinaryAttrs const &attrs, + TensorShape const &input_lhs, + TensorShape const &input_rhs) { + assert (!(attrs.should_broadcast_lhs && attrs.should_broadcast_rhs)); + + if (attrs.should_broadcast_lhs) { + NOT_IMPLEMENTED(); + } else if (attrs.should_broadcast_rhs) { + NOT_IMPLEMENTED(); + } else { + if (input_lhs != input_rhs) { + return tl::unexpected(fmt::format("Expected input shapes to match, but receieved LHS ({}) != RHS ({})", input_lhs, input_rhs)); + } + + return input_lhs; + } } -TensorShape get_output_shape(ElementBinaryAttrs const &, - TensorShape const &, - TensorShape const &) { - NOT_IMPLEMENTED(); +tl::expected + get_output_shape(ElementBinaryAttrs const &attrs, + ParallelTensorShape const &input_lhs, + ParallelTensorShape const &input_rhs) { + assert (!(attrs.should_broadcast_lhs && attrs.should_broadcast_rhs)); + + if (attrs.should_broadcast_lhs) { + NOT_IMPLEMENTED(); + } else if (attrs.should_broadcast_rhs) { + NOT_IMPLEMENTED(); + } else { + if (input_lhs != input_rhs) { + return tl::unexpected(fmt::format("Expected input shapes to match, but receieved LHS ({}) != RHS ({})", input_lhs, input_rhs)); + } + + switch (attrs.type) { + case OperatorType::EW_ADD: + { + if (get_discard_copy_degree(input_lhs) != 1) { + return tl::unexpected(fmt::format("Elementwise Add expected discard copy degree of inputs to be 1, but receieved {}", get_discard_copy_degree(input_lhs))); + } + + break; + } + case OperatorType::EW_SUB: + NOT_IMPLEMENTED(); + case OperatorType::EW_MUL: + NOT_IMPLEMENTED(); + case OperatorType::EW_DIV: + NOT_IMPLEMENTED(); + case OperatorType::EW_MAX: + NOT_IMPLEMENTED(); + case OperatorType::EW_MIN: + NOT_IMPLEMENTED(); + default: + return tl::unexpected(fmt::format("Unexpected element-wise binary operator {}", attrs.type)); + } + + return input_lhs; + } } } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/element_unary.cc b/lib/op-attrs/src/op-attrs/ops/element_unary.cc index 800199a51a..66feac42e4 100644 --- a/lib/op-attrs/src/op-attrs/ops/element_unary.cc +++ b/lib/op-attrs/src/op-attrs/ops/element_unary.cc @@ -1,24 +1,45 @@ #include "op-attrs/ops/element_unary.h" +#include "op-attrs/parallel_tensor_shape.h" namespace FlexFlow { -ParallelTensorShape get_output_shape(ElementUnaryAttrs const &, - ParallelTensorShape const &) { - NOT_IMPLEMENTED(); +tl::expected + get_output_shape(ElementUnaryAttrs const &attrs, TensorShape const &input_shape) { + return input_shape; } -TensorShape get_output_shape(ElementUnaryAttrs const &, TensorShape const &) { - NOT_IMPLEMENTED(); +tl::expected get_output_shape(ElementUnaryAttrs const &attrs, + ParallelTensorShape const &input_shape) { + if (get_sum_degree(input_shape) != 1) { + return tl::unexpected(fmt::format("Expected sum degree 1, but receieved sum degree {}", get_sum_degree(input_shape))); + } + + if (get_discard_copy_degree(input_shape) != 1) { + return tl::unexpected(fmt::format("Expected discard copy degree 1, but received discartd copy degree {}", get_discard_copy_degree(input_shape))); + } + + return input_shape; } -ParallelTensorShape get_output_shape(ElementScalarUnaryAttrs const &, - ParallelTensorShape const &) { - NOT_IMPLEMENTED(); +tl::expected + get_output_shape(ElementScalarUnaryAttrs const &attrs, + TensorShape const &input_shape) { + return input_shape; } -TensorShape get_output_shape(ElementScalarUnaryAttrs const &, - TensorShape const &) { - NOT_IMPLEMENTED(); +tl::expected + get_output_shape(ElementScalarUnaryAttrs const &attrs, + ParallelTensorShape const &input_shape) { + if (get_sum_degree(input_shape) != 1) { + return tl::unexpected(fmt::format("Expected sum degree 1, but receieved sum degree {}", get_sum_degree(input_shape))); + } + + if (get_discard_copy_degree(input_shape) != 1) { + return tl::unexpected(fmt::format("Expected discard copy degree 1, but received discartd copy degree {}", get_discard_copy_degree(input_shape))); + } + + return input_shape; } + } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/tensor_dims.cc b/lib/op-attrs/src/op-attrs/tensor_dims.cc index 74ebedd816..265241bacc 100644 --- a/lib/op-attrs/src/op-attrs/tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/tensor_dims.cc @@ -22,6 +22,13 @@ size_t dim_at_idx(TensorDims const &dims, ff_dim_t idx) { return dims.ff_ordered.at(idx); } +size_t &dim_at_idx(TensorDims &dims, ff_dim_t idx) { + if (idx.value < 0) { + idx = ff_dim_t{int_from_size_t(num_dims(dims)) + idx.value}; + } + return dims.ff_ordered.at(idx); +} + ParallelTensorDims lift_to_parallel(TensorDims const &dims) { std::vector shard_degrees(num_dims(dims), 1); // 1 repeated num_dims(dims) times return lift_to_parallel_with_degrees(dims, 1, 1, shard_degrees); diff --git a/lib/op-attrs/src/op-attrs/tensor_shape.cc b/lib/op-attrs/src/op-attrs/tensor_shape.cc index 01afbddf1e..850bea6d00 100644 --- a/lib/op-attrs/src/op-attrs/tensor_shape.cc +++ b/lib/op-attrs/src/op-attrs/tensor_shape.cc @@ -11,4 +11,8 @@ size_t dim_at_idx(TensorShape const &s, ff_dim_t idx) { return dim_at_idx(s.dims, idx); } +size_t &dim_at_idx(TensorShape &s, ff_dim_t idx) { + return dim_at_idx(s.dims, idx); +} + } // namespace FlexFlow diff --git a/lib/op-attrs/test/src/test_element_binary.cc b/lib/op-attrs/test/src/test_element_binary.cc new file mode 100644 index 0000000000..fa0841b732 --- /dev/null +++ b/lib/op-attrs/test/src/test_element_binary.cc @@ -0,0 +1,130 @@ +#include "op-attrs/parallel_tensor_shape.h" +#include "test/utils/doctest.h" +#include "op-attrs/ops/element_binary.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("EWAdd shape inference") { + size_t d1 = 16; + size_t d2 = 32; + size_t d3 = 24; + + ElementBinaryAttrs attrs = ElementBinaryAttrs{ + OperatorType::EW_ADD, + DataType::FLOAT, + /*should_broadcast_lhs=*/false, + /*should_broadcast_rhs=*/false, + }; + + TensorShape input_lhs = TensorShape{ + TensorDims{ + FFOrdered{ + d1, + d2, + d3, + }, + }, + DataType::FLOAT, + }; + + TensorShape input_rhs = input_lhs; + + SUBCASE("correct") { + tl::expected result = get_output_shape(attrs, input_lhs, input_rhs); + tl::expected correct = input_lhs; + + CHECK(result == correct); + } + + SUBCASE("mismatched dim size") { + TensorShape incorrect_rhs = input_lhs; + dim_at_idx(incorrect_rhs, ff_dim_t{0}) += 1; + + tl::expected result = get_output_shape(attrs, input_lhs, incorrect_rhs); + + CHECK_MESSAGE(!result.has_value(), "Unexpected successful result: ", result.error()); + } + } + + TEST_CASE("EWAdd parallel shape inference") { + size_t d1 = 16; + size_t d2 = 32; + size_t d3 = 24; + + ElementBinaryAttrs attrs = ElementBinaryAttrs{ + OperatorType::EW_ADD, + DataType::FLOAT, + /*should_broadcast_lhs=*/false, + /*should_broadcast_rhs=*/false, + }; + + TensorShape unpar_lhs = TensorShape{ + TensorDims{ + FFOrdered{ + d1, + d2, + d3, + }, + }, + DataType::FLOAT, + }; + + TensorShape unpar_rhs = unpar_lhs; + tl::expected result_unpar_output = get_output_shape(attrs, unpar_lhs, unpar_rhs); + REQUIRE(result_unpar_output.has_value()); + TensorShape unpar_output = result_unpar_output.value(); + + auto make_lhs = [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o_1, int o_2, int o_3) { + return lift_to_parallel_with_degrees(unpar_lhs, o_sum, o_eq, FFOrdered{o_1, o_2, o_3}); + }; + + auto make_rhs = [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o_1, int o_2, int o_3) { + return lift_to_parallel_with_degrees(unpar_rhs, o_sum, o_eq, FFOrdered{o_1, o_2, o_3}); + }; + + auto make_output = [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o_1, int o_2, int o_3) { + return lift_to_parallel_with_degrees(unpar_output, o_sum, o_eq, FFOrdered{o_1, o_2, o_3}); + }; + + SUBCASE("data parallelism") { + int degree = 4; + + ParallelTensorShape input_lhs = make_lhs(1, 1, degree, 1, 1); + ParallelTensorShape input_rhs = make_rhs(1, 1, degree, 1, 1); + tl::expected result = get_output_shape(attrs, input_lhs, input_rhs); + tl::expected correct = make_output(1, 1, degree, 1, 1); + + CHECK(result == correct); + } + + SUBCASE("reduction parallelism") { + int degree = 4; + + ParallelTensorShape input_lhs = make_lhs(SumDegree{degree}, 1, 1, 1, 1); + ParallelTensorShape input_rhs = make_rhs(SumDegree{degree}, 1, 1, 1, 1); + tl::expected result = get_output_shape(attrs, input_lhs, input_rhs); + tl::expected correct = make_output(SumDegree{degree}, 1, 1, 1, 1); + + CHECK(result == correct); + } + + SUBCASE("invalid discard copy parallelism") { + int degree = 4; + + ParallelTensorShape input_lhs = make_lhs(1, DiscardCopyDegree{degree}, 1, 1, 1); + ParallelTensorShape input_rhs = make_rhs(1, DiscardCopyDegree{degree}, 1, 1, 1); + tl::expected result = get_output_shape(attrs, input_lhs, input_rhs); + + CHECK_MESSAGE(!result.has_value(), "Unexpected successful result: ", result.error()); + } + + SUBCASE("invalid mismatched parallelism degrees") { + int degree = 4; + + ParallelTensorShape input_lhs = make_lhs(1, 1, 1, degree, 1); + ParallelTensorShape input_rhs = make_rhs(1, 1, 1, 1, degree); + tl::expected result = get_output_shape(attrs, input_lhs, input_rhs); + + CHECK_MESSAGE(!result.has_value(), "Unexpected successful result: ", result.error()); + } + } +} diff --git a/lib/op-attrs/test/src/test_element_unary.cc b/lib/op-attrs/test/src/test_element_unary.cc new file mode 100644 index 0000000000..2c7506dc8f --- /dev/null +++ b/lib/op-attrs/test/src/test_element_unary.cc @@ -0,0 +1,60 @@ +#include "op-attrs/parallel_tensor_shape.h" +#include "test/utils/doctest.h" +#include "op-attrs/ops/element_unary.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("ReLU shape inference") { + size_t d1 = 16; + size_t d2 = 32; + size_t d3 = 24; + + ElementUnaryAttrs attrs = ElementUnaryAttrs{OperatorType::RELU}; + + TensorShape input = TensorShape{ + TensorDims{ + FFOrdered{ + d1, + d2, + d3, + }, + }, + DataType::FLOAT, + }; + + tl::expected result = get_output_shape(attrs, input); + tl::expected correct = input; + + CHECK(result == correct); + + auto make_i = [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o_1, int o_2, int o_3) { + return lift_to_parallel_with_degrees(input, o_sum, o_eq, FFOrdered{o_1, o_2, o_3}); + }; + + SUBCASE("partition i.e., sharding parallelism") { + int degree1 = 4; + int degree2 = 8; + ParallelTensorShape par_input = make_i(1, 1, degree1, 1, degree2); + + tl::expected result = get_output_shape(attrs, par_input); + tl::expected correct = par_input; + + CHECK(result == correct); + } + + SUBCASE("sum degree > 1") { + int degree = 2; + + tl::expected result = get_output_shape(attrs, make_i(SumDegree{degree}, 1, 1, 1, 1)); + + CHECK_MESSAGE(!result.has_value(), "Unexpected successful result: ", result.error()); + } + + SUBCASE("discard copy degree > 1") { + int degree = 2; + + tl::expected result = get_output_shape(attrs, make_i(1, DiscardCopyDegree{degree}, 1, 1, 1)); + + CHECK_MESSAGE(!result.has_value(), "Unexpected successful result: ", result.error()); + } + } +} diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index 7969b40ac7..5e15f50966 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -144,7 +144,7 @@ tensor_guid_t ComputationGraphBuilder::element_unary( LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; - TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); + TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); return this->add_layer(layer, {input}, {}, output_shape); } @@ -161,7 +161,7 @@ tensor_guid_t ComputationGraphBuilder::element_scalar_unary( LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; - TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); + TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); return this->add_layer(layer, {input}, {}, output_shape); } @@ -205,8 +205,9 @@ tensor_guid_t ComputationGraphBuilder::element_binary( LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; - TensorShape output_shape = get_output_shape( - attrs, this->get_shape(lhs_input), this->get_shape(rhs_input)); + TensorShape output_shape = throw_if_unexpected(get_output_shape( + attrs, this->get_shape(lhs_input), this->get_shape(rhs_input)) + ); return this->add_layer(layer, {lhs_input, rhs_input}, {}, output_shape); } diff --git a/lib/utils/include/utils/exception.decl.h b/lib/utils/include/utils/exception.decl.h index e41dff9b5a..93c450294b 100644 --- a/lib/utils/include/utils/exception.decl.h +++ b/lib/utils/include/utils/exception.decl.h @@ -3,6 +3,7 @@ #include "utils/fmt.decl.h" #include +#include namespace FlexFlow { @@ -23,6 +24,9 @@ class not_implemented : public std::logic_error { int line); }; +template +T throw_if_unexpected(tl::expected const &r); + template std::runtime_error mk_runtime_error(fmt::format_string fmt_str, T &&...args); diff --git a/lib/utils/include/utils/exception.h b/lib/utils/include/utils/exception.h index fd3a0b7ee0..a00d2dba2b 100644 --- a/lib/utils/include/utils/exception.h +++ b/lib/utils/include/utils/exception.h @@ -4,9 +4,19 @@ #include "utils/exception.decl.h" #include "utils/fmt.h" #include +#include namespace FlexFlow { +template +T throw_if_unexpected(tl::expected const &r) { + if (r.has_value()) { + return r.value(); + } else { + throw std::runtime_error(fmt::to_string(r.error())); + } +} + template std::runtime_error mk_runtime_error(fmt::format_string fmt_str, T &&...args) { From 3e7ceac69eaa6a326ae7e2b097171e0c0cecc9c4 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Fri, 31 May 2024 13:27:52 -0700 Subject: [PATCH 26/43] Add parallel shape inference for embedding --- lib/op-attrs/include/op-attrs/dim_ordered.h | 12 +- .../include/op-attrs/dim_ordered/slice.h | 41 ++++++ .../include/op-attrs/get_output_shapes.h | 2 - lib/op-attrs/include/op-attrs/ops/embedding.h | 10 +- .../op-attrs/ops/embedding_attrs.dtg.h | 6 +- .../op-attrs/ops/embedding_attrs.struct.toml | 2 +- lib/op-attrs/src/op-attrs/ops/embedding.cc | 96 ++++++++++++- .../src/op-attrs/ops/embedding_attrs.dtg.cc | 19 +-- .../src/op-attrs/parallel_tensor_dims.cc | 3 - lib/op-attrs/src/op-attrs/tensor_dims.cc | 6 - lib/op-attrs/test/src/dim_ordered/slice.cc | 17 +++ lib/op-attrs/test/src/test_embedding.cc | 132 ++++++++++++++++++ lib/pcg/src/pcg/computation_graph_builder.cc | 4 +- lib/utils/include/utils/containers.h | 2 +- 14 files changed, 315 insertions(+), 37 deletions(-) create mode 100644 lib/op-attrs/include/op-attrs/dim_ordered/slice.h create mode 100644 lib/op-attrs/test/src/dim_ordered/slice.cc create mode 100644 lib/op-attrs/test/src/test_embedding.cc diff --git a/lib/op-attrs/include/op-attrs/dim_ordered.h b/lib/op-attrs/include/op-attrs/dim_ordered.h index e7c1891a4b..685d60c370 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered.h @@ -29,11 +29,19 @@ struct DimOrdered { : contents(contents.begin(), contents.end()) {} T const &at(Idx idx) const { - return this->contents.at(idx.value); + int raw = idx.value; + if (raw < 0) { + raw = this->contents.size() + raw; + } + return this->contents.at(raw); } T &at(Idx idx) { - return this->contents.at(idx.value); + int raw = idx.value; + if (raw < 0) { + raw = this->contents.size() + raw; + } + return this->contents.at(raw); } T const &operator[](Idx idx) const { diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/slice.h b/lib/op-attrs/include/op-attrs/dim_ordered/slice.h new file mode 100644 index 0000000000..0bc4c7513e --- /dev/null +++ b/lib/op-attrs/include/op-attrs/dim_ordered/slice.h @@ -0,0 +1,41 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_SLICE_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_SLICE_H + +#include "op-attrs/dim_ordered.h" +#include "utils/containers.h" +#include "utils/optional.h" + +namespace FlexFlow { + +template +DimOrdered nonoverloaded_slice(DimOrdered const &d, std::optional const &start, std::optional const &end) { + auto to_raw_idx = [](std::optional const &idx) -> std::optional { + return transform(idx, [](Idx const &i) { return i.value; }); + }; + + return DimOrdered{subvec(as_vector(d), to_raw_idx(start), to_raw_idx(end))}; +} + +template +DimOrdered slice(DimOrdered const &d, std::optional const &start, std::optional const &end) { + return nonoverloaded_slice(d, start, end); +} + +template +DimOrdered slice(DimOrdered const &d, std::nullopt_t const &start, Idx const &end) { + return nonoverloaded_slice(d, std::optional{start}, std::optional{end}); +} + +template +DimOrdered slice(DimOrdered const &d, Idx const &start, std::nullopt_t const &end) { + return nonoverloaded_slice(d, std::optional{start}, std::optional{end}); +} + +template +DimOrdered slice(DimOrdered const &d, Idx const &start, Idx const &end) { + return nonoverloaded_slice(d, std::optional{start}, std::optional{end}); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/get_output_shapes.h b/lib/op-attrs/include/op-attrs/get_output_shapes.h index c99ab6d901..60d66babc4 100644 --- a/lib/op-attrs/include/op-attrs/get_output_shapes.h +++ b/lib/op-attrs/include/op-attrs/get_output_shapes.h @@ -122,8 +122,6 @@ ParallelTensorShape get_output_shape(Conv2DAttrs const &, ParallelTensorShape const &); ParallelTensorShape get_output_shape(DropoutAttrs const &, ParallelTensorShape const &); -ParallelTensorShape get_output_shape(EmbeddingAttrs const &, - ParallelTensorShape const &); ParallelTensorShape get_output_shape(FlatAttrs const &, ParallelTensorShape const &); std::vector get_output_shapes(GatherAttrs const &, diff --git a/lib/op-attrs/include/op-attrs/ops/embedding.h b/lib/op-attrs/include/op-attrs/ops/embedding.h index f7a2226643..52c868edee 100644 --- a/lib/op-attrs/include/op-attrs/ops/embedding.h +++ b/lib/op-attrs/include/op-attrs/ops/embedding.h @@ -1,19 +1,23 @@ #ifndef _FLEXFLOW_EMBEDDING_ATTRS_H #define _FLEXFLOW_EMBEDDING_ATTRS_H -#include "core.h" +#include "op-attrs/ops/core.h" #include "op-attrs/ops/embedding_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/tensor_shape.h" +#include namespace FlexFlow { CHECK_VALID_OP_ATTR(EmbeddingAttrs); -TensorShape get_weights_shape(EmbeddingAttrs const &, TensorShape const &); +tl::expected get_output_shape(EmbeddingAttrs const &, TensorShape const &); +tl::expected get_weights_shape(EmbeddingAttrs const &, TensorShape const &); -ParallelTensorShape get_output_shape(EmbeddingAttrs const &, +tl::expected get_output_shape(EmbeddingAttrs const &, ParallelTensorShape const &); +tl::expected get_weights_shape(EmbeddingAttrs const &, + ParallelTensorShape const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/embedding_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.dtg.h index 0b5bed5ba7..f1cae86460 100644 --- a/lib/op-attrs/include/op-attrs/ops/embedding_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml /* proj-data { - "generated_from": "a0ac41fc0f56bc06bcb1a8d42fc6191c" + "generated_from": "f2bdea52e23dee6f674f598f8691d994" } */ @@ -25,7 +25,7 @@ struct EmbeddingAttrs { EmbeddingAttrs() = delete; EmbeddingAttrs(int const &num_entries, int const &out_channels, - ::FlexFlow::AggregateOp const &aggr, + std::optional<::FlexFlow::AggregateOp> const &aggr, ::FlexFlow::DataType const &data_type); bool operator==(EmbeddingAttrs const &) const; @@ -36,7 +36,7 @@ struct EmbeddingAttrs { bool operator>=(EmbeddingAttrs const &) const; int num_entries; int out_channels; - ::FlexFlow::AggregateOp aggr; + std::optional<::FlexFlow::AggregateOp> aggr; ::FlexFlow::DataType data_type; }; } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml index 39dc71bdb3..f0772c351e 100644 --- a/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml @@ -25,7 +25,7 @@ type = "int" [[fields]] name = "aggr" -type = "::FlexFlow::AggregateOp" +type = "std::optional<::FlexFlow::AggregateOp>" [[fields]] name = "data_type" diff --git a/lib/op-attrs/src/op-attrs/ops/embedding.cc b/lib/op-attrs/src/op-attrs/ops/embedding.cc index c8c63c70ea..b62656ff02 100644 --- a/lib/op-attrs/src/op-attrs/ops/embedding.cc +++ b/lib/op-attrs/src/op-attrs/ops/embedding.cc @@ -1,14 +1,100 @@ #include "op-attrs/ops/embedding.h" +#include "utils/integer_conversions.h" +#include "op-attrs/dim_ordered/transform.h" +#include "op-attrs/dim_ordered/slice.h" +#include "utils/containers.h" namespace FlexFlow { -TensorShape get_weights_shape(EmbeddingAttrs const &, TensorShape const &) { - NOT_IMPLEMENTED(); +static std::optional basic_check(EmbeddingAttrs const &attrs, TensorShape const &input) { + if (input.data_type != DataType::INT32 && input.data_type != DataType::INT64) { + return fmt::format("Embedding expected input tensor to have integer datatype, but receieved tensor of datatype {}", input.data_type); + } + + if (attrs.aggr != AggregateOp::SUM) { + return fmt::format(fmt::format("Currently unsupported aggregation op for embedding: {}", attrs.aggr)); + } + + return std::nullopt; +} + +tl::expected +get_output_shape(EmbeddingAttrs const &attrs, TensorShape const &input) { + { + std::optional err_msg = basic_check(attrs, input); + if (err_msg.has_value()) { + return tl::unexpected(err_msg.value()); + } + } + + TensorShape output = input; + dim_at_idx(output, ff_dim_t{-1}) = attrs.out_channels; + output.data_type = attrs.data_type; + return output; +} + +tl::expected +get_weights_shape(EmbeddingAttrs const &attrs, TensorShape const &input) { + { + std::optional err_msg = basic_check(attrs, input); + if (err_msg.has_value()) { + return tl::unexpected(err_msg.value()); + } + } + + return TensorShape{ + TensorDims{ + FFOrdered{ + size_t_from_int(attrs.num_entries), + size_t_from_int(attrs.out_channels), + }, + }, + attrs.data_type, + }; } -ParallelTensorShape get_output_shape(EmbeddingAttrs const &, - ParallelTensorShape const &) { - NOT_IMPLEMENTED(); +tl::expected +get_output_shape(EmbeddingAttrs const &attrs, ParallelTensorShape const &input) { + + TensorShape unpar = ({ + tl::expected result_unpar = get_output_shape(attrs, get_reduced_shape(input)); + if (!result_unpar.has_value()) { + return tl::unexpected(result_unpar.error()); + } + result_unpar.value(); + }); + + SumDegree sum_degree = shard_dim_at_idx(input, ff_dim_t{-1}).degree; + DiscardCopyDegree discard_copy_degree = 1; + FFOrdered shard_degrees = transform(input.dims.shard_dims, [](ShardParallelDim const &d) { return d.degree; }); + shard_degrees.at(ff_dim_t{-1}) = get_discard_copy_degree(input); + + return lift_to_parallel_with_degrees(unpar, sum_degree, discard_copy_degree, shard_degrees); +} + +tl::expected +get_weights_shape(EmbeddingAttrs const &attrs, ParallelTensorShape const &input) { + TensorShape unpar = ({ + tl::expected result_unpar = get_weights_shape(attrs, get_reduced_shape(input)); + if (!result_unpar.has_value()) { + return tl::unexpected(result_unpar.error()); + } + result_unpar.value(); + }); + + SumDegree sum_degree = 1; + DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{ + product(transform(ff_ordered_shard_dims(input.dims), + [](ShardParallelDim const &d) -> int { return d.degree; })) + }; + int entry_dim_degree = 1; + int out_channel_degree = get_discard_copy_degree(input); + FFOrdered shard_degrees = { + entry_dim_degree, + out_channel_degree, + }; + + return lift_to_parallel_with_degrees(unpar, sum_degree, discard_copy_degree, shard_degrees); } } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/embedding_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/embedding_attrs.dtg.cc index b5a4028e13..b4d4657e08 100644 --- a/lib/op-attrs/src/op-attrs/ops/embedding_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/embedding_attrs.dtg.cc @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml /* proj-data { - "generated_from": "a0ac41fc0f56bc06bcb1a8d42fc6191c" + "generated_from": "f2bdea52e23dee6f674f598f8691d994" } */ @@ -15,10 +15,11 @@ #include namespace FlexFlow { -EmbeddingAttrs::EmbeddingAttrs(int const &num_entries, - int const &out_channels, - ::FlexFlow::AggregateOp const &aggr, - ::FlexFlow::DataType const &data_type) +EmbeddingAttrs::EmbeddingAttrs( + int const &num_entries, + int const &out_channels, + std::optional<::FlexFlow::AggregateOp> const &aggr, + ::FlexFlow::DataType const &data_type) : num_entries(num_entries), out_channels(out_channels), aggr(aggr), data_type(data_type) {} bool EmbeddingAttrs::operator==(EmbeddingAttrs const &other) const { @@ -85,8 +86,8 @@ size_t hash::operator()( (result >> 2); result ^= std::hash{}(x.out_channels) + 0x9e3779b9 + (result << 6) + (result >> 2); - result ^= std::hash<::FlexFlow::AggregateOp>{}(x.aggr) + 0x9e3779b9 + - (result << 6) + (result >> 2); + result ^= std::hash>{}(x.aggr) + + 0x9e3779b9 + (result << 6) + (result >> 2); result ^= std::hash<::FlexFlow::DataType>{}(x.data_type) + 0x9e3779b9 + (result << 6) + (result >> 2); return result; @@ -98,7 +99,7 @@ FlexFlow::EmbeddingAttrs adl_serializer::from_json(json const &j) { return {j.at("num_entries").template get(), j.at("out_channels").template get(), - j.at("aggr").template get<::FlexFlow::AggregateOp>(), + j.at("aggr").template get>(), j.at("data_type").template get<::FlexFlow::DataType>()}; } void adl_serializer::to_json( @@ -116,7 +117,7 @@ Gen Arbitrary::arbitrary() { return gen::construct( gen::arbitrary(), gen::arbitrary(), - gen::arbitrary<::FlexFlow::AggregateOp>(), + gen::arbitrary>(), gen::arbitrary<::FlexFlow::DataType>()); } } // namespace rc diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc index 019a9f9223..89a2934704 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc @@ -42,9 +42,6 @@ bool is_valid(ParallelTensorDims const &dims) { } ShardParallelDim shard_dim_at_idx(ParallelTensorDims const &d, ff_dim_t idx) { - if (idx.value < 0) { - idx = ff_dim_t{int_from_size_t(d.shard_dims.size()) + idx.value}; - } return d.shard_dims.at(idx); } diff --git a/lib/op-attrs/src/op-attrs/tensor_dims.cc b/lib/op-attrs/src/op-attrs/tensor_dims.cc index 265241bacc..512e9f0804 100644 --- a/lib/op-attrs/src/op-attrs/tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/tensor_dims.cc @@ -16,16 +16,10 @@ size_t num_dims(TensorDims const &dims) { } size_t dim_at_idx(TensorDims const &dims, ff_dim_t idx) { - if (idx.value < 0) { - idx = ff_dim_t{int_from_size_t(num_dims(dims)) + idx.value}; - } return dims.ff_ordered.at(idx); } size_t &dim_at_idx(TensorDims &dims, ff_dim_t idx) { - if (idx.value < 0) { - idx = ff_dim_t{int_from_size_t(num_dims(dims)) + idx.value}; - } return dims.ff_ordered.at(idx); } diff --git a/lib/op-attrs/test/src/dim_ordered/slice.cc b/lib/op-attrs/test/src/dim_ordered/slice.cc new file mode 100644 index 0000000000..26e49e630e --- /dev/null +++ b/lib/op-attrs/test/src/dim_ordered/slice.cc @@ -0,0 +1,17 @@ +#include "op-attrs/dim_ordered/slice.h" +#include "test/utils/doctest.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("slice(DimOrdered, std::optional, std::optional)") { + FFOrdered d = FFOrdered{ + 1, 2, 3, 4, + }; + + FFOrdered result = slice(d, std::nullopt, ff_dim_t{-1}); + FFOrdered correct = FFOrdered{ + 1, 2, 3, + }; + + CHECK(result == correct); + } +} diff --git a/lib/op-attrs/test/src/test_embedding.cc b/lib/op-attrs/test/src/test_embedding.cc new file mode 100644 index 0000000000..f03ffdd27f --- /dev/null +++ b/lib/op-attrs/test/src/test_embedding.cc @@ -0,0 +1,132 @@ +#include "op-attrs/parallel_tensor_shape.h" +#include "test/utils/doctest.h" +#include "op-attrs/ops/embedding.h" +#include "utils/integer_conversions.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Sum embedding shape inference") { + int out_channels = 128; + int num_entries = 1024; + EmbeddingAttrs attrs = EmbeddingAttrs{ + /*num_entries=*/num_entries, + /*out_channels=*/out_channels, + /*aggr=*/AggregateOp::SUM, + /*data_type=*/DataType::FLOAT, + }; + + size_t batch_size = 48; + size_t features_dim = 56; + + TensorShape input = { + TensorDims{ + FFOrdered{ + batch_size, + features_dim, + } + }, + DataType::INT32, + }; + + TensorShape output = TensorShape{ + TensorDims{ + FFOrdered{ + batch_size, + size_t_from_int(out_channels), + }, + }, + DataType::FLOAT, + }; + + TensorShape weights = TensorShape{ + TensorDims{ + FFOrdered{ + size_t_from_int(num_entries), + size_t_from_int(out_channels), + }, + }, + DataType::FLOAT, + }; + + // get_output_shape + { + tl::expected output_result = get_output_shape(attrs, input); + tl::expected output_correct = output; + CHECK(output_result == output_correct); + } + + // get_weights_shape + { + tl::expected weight_result = get_weights_shape(attrs, input); + tl::expected weight_correct = weights; + CHECK(weight_result == weight_correct); + } + + auto make_input = [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o_batch, int o_features) { + return lift_to_parallel_with_degrees(input, o_sum, o_eq, FFOrdered{o_batch, o_features}); + }; + + auto make_output = [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o_batch, int o_outchannels) { + return lift_to_parallel_with_degrees(output, o_sum, o_eq, FFOrdered{o_batch, o_outchannels}); + }; + + auto make_weights = [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o_entries, int o_outchannels) { + return lift_to_parallel_with_degrees(weights, o_sum, o_eq, FFOrdered{o_entries, o_outchannels}); + }; + + SUBCASE("data parallelism") { + int degree = 4; + ParallelTensorShape par_input = make_input(SumDegree{1}, DiscardCopyDegree{1}, degree, 1); + + { + tl::expected result = get_output_shape(attrs, par_input); + tl::expected correct = make_output(SumDegree{1}, DiscardCopyDegree{1}, degree, 1); + CHECK(result == correct); + } + + { + tl::expected result = get_weights_shape(attrs, par_input); + tl::expected correct = make_weights(SumDegree{1}, DiscardCopyDegree{degree}, 1, 1); + CHECK(result == correct); + } + } + + SUBCASE("input features parallelism") { + int degree = 4; + ParallelTensorShape input = make_input(SumDegree{1}, DiscardCopyDegree{1}, 1, degree); + + { + tl::expected result = get_output_shape(attrs, input); + tl::expected correct = make_output(SumDegree{degree}, DiscardCopyDegree{1}, 1, 1); + CHECK(result == correct); + } + + { + tl::expected result = get_weights_shape(attrs, input); + tl::expected correct = make_weights(SumDegree{1}, DiscardCopyDegree{degree}, 1, 1); + CHECK(result == correct); + } + } + + + SUBCASE("output channel shard parallelism") { + // NOTE (@lockshaw): in the current (parallel shape inference from just input tensor) representation we have to choose between + // either parallelism in the weight channel dimension or in the weight entry dimension. For now we choose to represent + // parallelism in the channel dimension, but partitioning in the entry dimension is also potentially useful as it produces + // sum parallelism in the output + int degree = 4; + ParallelTensorShape input = make_input(SumDegree{1}, DiscardCopyDegree{degree}, 1, 1); + + { + tl::expected result = get_output_shape(attrs, input); + tl::expected correct = make_output(SumDegree{1}, DiscardCopyDegree{1}, 1, degree); + CHECK(result == correct); + } + + { + tl::expected result = get_weights_shape(attrs, input); + tl::expected correct = make_weights(SumDegree{1}, DiscardCopyDegree{1}, 1, degree); + CHECK(result == correct); + } + } + } +} diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index 5e15f50966..d5ed622a82 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -446,9 +446,9 @@ tensor_guid_t ComputationGraphBuilder::embedding( TensorShape input_shape = this->get_shape(input); TensorAttrs weight_attrs = make_weight_attrs( - get_weights_shape(attrs, input_shape), kernel_initializer); + throw_if_unexpected(get_weights_shape(attrs, input_shape)), kernel_initializer); - TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); + TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); return this->add_layer(layer, {input}, {weight_attrs}, output_shape); } diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index 3efd56726c..2da1878b07 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -684,7 +684,7 @@ std::vector subvec(std::vector const &v, auto resolve_loc = [&](int idx) -> typename std::vector::iterator::difference_type { if (idx < 0) { - return v.size() - idx; + return v.size() + idx; } else { return idx; } From dd008d0016da1537637a7d3996b6430f0fa7f80f Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Fri, 31 May 2024 16:50:53 -0700 Subject: [PATCH 27/43] Add shape inference for repartition, combine, replicate, and reduction --- lib/op-attrs/include/op-attrs/dim_ordered.h | 8 + .../include/op-attrs/get_output_shapes.h | 6 - lib/op-attrs/include/op-attrs/ops/combine.h | 11 +- lib/op-attrs/include/op-attrs/ops/reduce.h | 2 +- lib/op-attrs/include/op-attrs/ops/reduction.h | 7 +- .../op-attrs/ops/reduction_attrs.dtg.h | 8 +- .../op-attrs/ops/reduction_attrs.struct.toml | 9 - .../include/op-attrs/ops/repartition.h | 8 +- lib/op-attrs/include/op-attrs/ops/replicate.h | 2 +- .../op-attrs/ops/replicate_attrs.dtg.h | 8 +- .../op-attrs/ops/replicate_attrs.struct.toml | 9 +- .../include/op-attrs/parallel_tensor_shape.h | 2 + lib/op-attrs/src/op-attrs/ops/combine.cc | 30 +- lib/op-attrs/src/op-attrs/ops/flat.cc | 2 - lib/op-attrs/src/op-attrs/ops/reduction.cc | 20 +- .../src/op-attrs/ops/reduction_attrs.dtg.cc | 37 +- lib/op-attrs/src/op-attrs/ops/repartition.cc | 15 +- lib/op-attrs/src/op-attrs/ops/replicate.cc | 4 +- .../src/op-attrs/ops/replicate_attrs.dtg.cc | 37 +- .../src/op-attrs/parallel_tensor_shape.cc | 8 + .../src/parallel_dim_mapping_record_solver.cc | 362 ------------------ .../src/parallel_dim_mapping_record_solver.h | 106 ----- lib/op-attrs/test/src/ops/combine.cc | 55 +++ lib/op-attrs/test/src/ops/repartition.cc | 40 ++ lib/op-attrs/test/src/ops/replicate.cc | 33 ++ 25 files changed, 230 insertions(+), 599 deletions(-) delete mode 100644 lib/op-attrs/src/parallel_dim_mapping_record_solver.cc delete mode 100644 lib/op-attrs/src/parallel_dim_mapping_record_solver.h create mode 100644 lib/op-attrs/test/src/ops/combine.cc create mode 100644 lib/op-attrs/test/src/ops/repartition.cc create mode 100644 lib/op-attrs/test/src/ops/replicate.cc diff --git a/lib/op-attrs/include/op-attrs/dim_ordered.h b/lib/op-attrs/include/op-attrs/dim_ordered.h index 685d60c370..dbc237a03d 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered.h @@ -52,6 +52,14 @@ struct DimOrdered { return this->at(idx); } + bool idx_is_valid(Idx const &idx) const { + int raw = idx.value; + if (raw < 0) { + raw = this->contents.size() + raw; + } + return (raw >= 0 && raw < this->contents.size()); + } + bool operator==(DimOrdered const &other) const { return this->contents == other.contents; } diff --git a/lib/op-attrs/include/op-attrs/get_output_shapes.h b/lib/op-attrs/include/op-attrs/get_output_shapes.h index 60d66babc4..9796204250 100644 --- a/lib/op-attrs/include/op-attrs/get_output_shapes.h +++ b/lib/op-attrs/include/op-attrs/get_output_shapes.h @@ -114,8 +114,6 @@ ParallelTensorShape get_output_shape(MultiHeadAttentionAttrs const &, std::vector const &); ParallelTensorShape get_output_shape(CastAttrs const &, ParallelTensorShape const &); -ParallelTensorShape get_output_shape(CombineAttrs const &, - ParallelTensorShape const &); ParallelTensorShape get_output_shape(ConcatAttrs const &, std::vector const &); ParallelTensorShape get_output_shape(Conv2DAttrs const &, @@ -135,10 +133,6 @@ ParallelTensorShape get_output_shape(Pool2DAttrs const &, ParallelTensorShape const &); ParallelTensorShape get_output_shape(ReduceAttrs const &, ParallelTensorShape const &); -ParallelTensorShape get_output_shape(ReductionAttrs const &, - ParallelTensorShape const &); -ParallelTensorShape get_output_shape(RepartitionAttrs const &, - ParallelTensorShape const &); ParallelTensorShape get_output_shape(ReplicateAttrs const &, ParallelTensorShape const &); ParallelTensorShape get_output_shape(ReverseAttrs const &, diff --git a/lib/op-attrs/include/op-attrs/ops/combine.h b/lib/op-attrs/include/op-attrs/ops/combine.h index d2d86e2fea..bbd5c12f33 100644 --- a/lib/op-attrs/include/op-attrs/ops/combine.h +++ b/lib/op-attrs/include/op-attrs/ops/combine.h @@ -1,13 +1,18 @@ -#ifndef _FLEXFLOW_COMBINE_ATTRS_H -#define _FLEXFLOW_COMBINE_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_COMBINE_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_COMBINE_H -#include "core.h" +#include "op-attrs/ops/core.h" #include "op-attrs/ops/combine_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include namespace FlexFlow { CHECK_VALID_OP_ATTR(CombineAttrs); +tl::expected + get_output_shape(CombineAttrs const &, ParallelTensorShape const &); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/reduce.h b/lib/op-attrs/include/op-attrs/ops/reduce.h index 800610fb2b..04e44b4161 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduce.h +++ b/lib/op-attrs/include/op-attrs/ops/reduce.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_OP_META_OPS_REDUCE_ATTRS_H #define _FLEXFLOW_OP_META_OPS_REDUCE_ATTRS_H -#include "core.h" +#include "op-attrs/ops/core.h" #include "op-attrs/ops/reduce_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" diff --git a/lib/op-attrs/include/op-attrs/ops/reduction.h b/lib/op-attrs/include/op-attrs/ops/reduction.h index 0ab9861b67..49f99d81fd 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduction.h +++ b/lib/op-attrs/include/op-attrs/ops/reduction.h @@ -1,16 +1,17 @@ #ifndef _FLEXFLOW_REDUCTION_ATTRS_H #define _FLEXFLOW_REDUCTION_ATTRS_H -#include "core.h" +#include "op-attrs/ops/core.h" #include "op-attrs/ops/reduction_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" +#include namespace FlexFlow { CHECK_VALID_OP_ATTR(ReductionAttrs); -ParallelTensorShape get_output_shape(ReductionAttrs const &attrs, - ParallelTensorShape const &input_shape); +tl::expected get_output_shape(ReductionAttrs const &attrs, + ParallelTensorShape const &input_shape); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/reduction_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/reduction_attrs.dtg.h index 5ff8a12651..9de5eb2252 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduction_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/reduction_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/reduction_attrs.struct.toml /* proj-data { - "generated_from": "28492e45a5c4f44987e17fe9ea876e11" + "generated_from": "1d2b5b7cf11ed04a27a6fd8215e4e2a5" } */ @@ -12,8 +12,6 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" -#include "op-attrs/ff_dim.dtg.h" -#include "op-attrs/ff_dim.h" #include "rapidcheck.h" #include #include @@ -22,8 +20,7 @@ namespace FlexFlow { struct ReductionAttrs { ReductionAttrs() = delete; - ReductionAttrs(::FlexFlow::ff_dim_t const &reduction_dim, - int const &reduction_degree); + ReductionAttrs(int const &reduction_degree); bool operator==(ReductionAttrs const &) const; bool operator!=(ReductionAttrs const &) const; @@ -31,7 +28,6 @@ struct ReductionAttrs { bool operator>(ReductionAttrs const &) const; bool operator<=(ReductionAttrs const &) const; bool operator>=(ReductionAttrs const &) const; - ::FlexFlow::ff_dim_t reduction_dim; int reduction_degree; }; } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/reduction_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/reduction_attrs.struct.toml index ff990ef46c..ee0ae54132 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduction_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/reduction_attrs.struct.toml @@ -9,15 +9,6 @@ features = [ "fmt", ] -includes = [ - "op-attrs/ff_dim.h", - "op-attrs/ff_dim.dtg.h", -] - -[[fields]] -name = "reduction_dim" -type = "::FlexFlow::ff_dim_t" - [[fields]] name = "reduction_degree" type = "int" diff --git a/lib/op-attrs/include/op-attrs/ops/repartition.h b/lib/op-attrs/include/op-attrs/ops/repartition.h index 09ab21615a..15b7ed71de 100644 --- a/lib/op-attrs/include/op-attrs/ops/repartition.h +++ b/lib/op-attrs/include/op-attrs/ops/repartition.h @@ -1,16 +1,18 @@ #ifndef _FLEXFLOW_PARTITION_ATTRS_H #define _FLEXFLOW_PARTITION_ATTRS_H -#include "core.h" +#include "op-attrs/ops/core.h" #include "op-attrs/ops/repartition_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" +#include namespace FlexFlow { CHECK_VALID_OP_ATTR(RepartitionAttrs); -ParallelTensorShape get_output_shape(RepartitionAttrs const &, - ParallelTensorShape const &input_shape); +tl::expected + get_output_shape(RepartitionAttrs const &, + ParallelTensorShape const &input_shape); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/replicate.h b/lib/op-attrs/include/op-attrs/ops/replicate.h index 4c46bf88a9..f10938cb68 100644 --- a/lib/op-attrs/include/op-attrs/ops/replicate.h +++ b/lib/op-attrs/include/op-attrs/ops/replicate.h @@ -10,7 +10,7 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(ReplicateAttrs); ParallelTensorShape get_output_shape(ReplicateAttrs const &attrs, - ParallelTensorShape const &input_shape); + ParallelTensorShape const &input_shape); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/replicate_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/replicate_attrs.dtg.h index 36b70a0b6d..ea3f0d46c7 100644 --- a/lib/op-attrs/include/op-attrs/ops/replicate_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/replicate_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/replicate_attrs.struct.toml /* proj-data { - "generated_from": "68c1bba349a54c0db219a67d4cc502b3" + "generated_from": "6d3ad4d10c24dae819ffee4592a72499" } */ @@ -12,8 +12,6 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" -#include "op-attrs/ff_dim.dtg.h" -#include "op-attrs/ff_dim.h" #include "rapidcheck.h" #include #include @@ -22,8 +20,7 @@ namespace FlexFlow { struct ReplicateAttrs { ReplicateAttrs() = delete; - ReplicateAttrs(::FlexFlow::ff_dim_t const &replicate_dim, - int const &replicate_degree); + ReplicateAttrs(int const &replicate_degree); bool operator==(ReplicateAttrs const &) const; bool operator!=(ReplicateAttrs const &) const; @@ -31,7 +28,6 @@ struct ReplicateAttrs { bool operator>(ReplicateAttrs const &) const; bool operator<=(ReplicateAttrs const &) const; bool operator>=(ReplicateAttrs const &) const; - ::FlexFlow::ff_dim_t replicate_dim; int replicate_degree; }; } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/replicate_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/replicate_attrs.struct.toml index afcb8f8fa4..4e43ea747a 100644 --- a/lib/op-attrs/include/op-attrs/ops/replicate_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/replicate_attrs.struct.toml @@ -9,14 +9,7 @@ features = [ "fmt", ] -includes = [ - "op-attrs/ff_dim.h", - "op-attrs/ff_dim.dtg.h", -] - -[[fields]] -name = "replicate_dim" -type = "::FlexFlow::ff_dim_t" +includes = [ ] [[fields]] name = "replicate_degree" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h index bcce38eded..969392ebf9 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h @@ -11,6 +11,8 @@ int num_shard_dims(ParallelTensorShape const &); ShardParallelDim shard_dim_at_idx(ParallelTensorShape const &, ff_dim_t); ShardParallelDim &shard_dim_at_idx(ParallelTensorShape &, ff_dim_t); +std::optional try_get_shard_dim_at_idx(ParallelTensorShape const &, ff_dim_t); + ParallelTensorShape lift_to_parallel(TensorShape const &); ParallelTensorShape lift_to_parallel_with_degrees(TensorShape const &, SumDegree sum_degree, DiscardCopyDegree discard_copy_degree, FFOrdered const &shard_degrees); diff --git a/lib/op-attrs/src/op-attrs/ops/combine.cc b/lib/op-attrs/src/op-attrs/ops/combine.cc index cdca524538..a91fe43452 100644 --- a/lib/op-attrs/src/op-attrs/ops/combine.cc +++ b/lib/op-attrs/src/op-attrs/ops/combine.cc @@ -1,18 +1,26 @@ #include "op-attrs/ops/combine.h" -#include "utils/hash-utils.h" +#include "op-attrs/parallel_tensor_shape.h" namespace FlexFlow { -/* bool CombineAttrs::is_valid(ParallelTensorShape const &input) const { */ -/* return input.at(this->combine_legion_dim).degree % this->combine_degree == - * 0; */ -/* } */ +tl::expected get_output_shape(CombineAttrs const &attrs, ParallelTensorShape const &input) { + ShardParallelDim input_dim = ({ + std::optional result = try_get_shard_dim_at_idx(input, attrs.combine_dim); + if (!result.has_value()) { + return tl::unexpected(fmt::format("Failed to get shard dim at index {} in parallel tensor shape {}", attrs.combine_dim, input)); + } -/* ParallelTensorShape CombineAttrs::output_shape(ParallelTensorShape const - * &input_shape) const { */ -/* ParallelTensorShape output = input_shape; */ -/* output.at(this->combine_legion_dim).degree /= this->combine_degree; */ -/* return output; */ -/* } */ + result.value(); + }); + + if (input_dim.degree % attrs.combine_degree != 0) { + return tl::unexpected(fmt::format("Combine received tensor containing parallel dim {} with degree {}, which is not divisible by combine degree {}", attrs.combine_dim, input_dim.degree, attrs.combine_degree)); + } + + ParallelTensorShape output = input; + shard_dim_at_idx(output, attrs.combine_dim).degree /= attrs.combine_degree; + + return output; +} } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/flat.cc b/lib/op-attrs/src/op-attrs/ops/flat.cc index 75d31beae4..b0683c5f08 100644 --- a/lib/op-attrs/src/op-attrs/ops/flat.cc +++ b/lib/op-attrs/src/op-attrs/ops/flat.cc @@ -1,6 +1,4 @@ #include "op-attrs/ops/flat.h" -#include "parallel_dim_mapping_record.h" -#include "parallel_dim_mapping_record_solver.h" #include namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/reduction.cc b/lib/op-attrs/src/op-attrs/ops/reduction.cc index 1a61277076..1b0a393058 100644 --- a/lib/op-attrs/src/op-attrs/ops/reduction.cc +++ b/lib/op-attrs/src/op-attrs/ops/reduction.cc @@ -1,18 +1,18 @@ #include "op-attrs/ops/reduction.h" +#include "op-attrs/parallel_tensor_shape.h" namespace FlexFlow { -ParallelTensorShape get_output_shape(ReductionAttrs const &attrs, +tl::expected + get_output_shape(ReductionAttrs const &attrs, ParallelTensorShape const &input_shape) { - NOT_IMPLEMENTED(); -} + if (get_sum_degree(input_shape) % attrs.reduction_degree != 0) { + return tl::unexpected(fmt::format("Reduction received tensor with sum degree {}, which is not divisible by reduction degree {}", get_sum_degree(input_shape), attrs.reduction_degree)); + } -/* ParallelTensorShape ReductionAttrs::output_shape(ParallelTensorShape const - * &input_shape) const { */ -/* ParallelTensorShape output = input_shape; */ -/* output.at(this->reduction_legion_dim).degree /= this->reduction_degree; */ -/* output.at(this->reduction_legion_dim).size /= this->reduction_degree; */ -/* return output; */ -/* } */ + ParallelTensorShape output_shape = input_shape; + output_shape.dims.replica_dims.sum_degree.value /= attrs.reduction_degree; + return output_shape; +} } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/reduction_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/reduction_attrs.dtg.cc index d4566614df..2f1550bb66 100644 --- a/lib/op-attrs/src/op-attrs/ops/reduction_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/reduction_attrs.dtg.cc @@ -3,43 +3,34 @@ // lib/op-attrs/include/op-attrs/ops/reduction_attrs.struct.toml /* proj-data { - "generated_from": "28492e45a5c4f44987e17fe9ea876e11" + "generated_from": "1d2b5b7cf11ed04a27a6fd8215e4e2a5" } */ #include "op-attrs/ops/reduction_attrs.dtg.h" -#include "op-attrs/ff_dim.dtg.h" -#include "op-attrs/ff_dim.h" #include namespace FlexFlow { -ReductionAttrs::ReductionAttrs(::FlexFlow::ff_dim_t const &reduction_dim, - int const &reduction_degree) - : reduction_dim(reduction_dim), reduction_degree(reduction_degree) {} +ReductionAttrs::ReductionAttrs(int const &reduction_degree) + : reduction_degree(reduction_degree) {} bool ReductionAttrs::operator==(ReductionAttrs const &other) const { - return std::tie(this->reduction_dim, this->reduction_degree) == - std::tie(other.reduction_dim, other.reduction_degree); + return std::tie(this->reduction_degree) == std::tie(other.reduction_degree); } bool ReductionAttrs::operator!=(ReductionAttrs const &other) const { - return std::tie(this->reduction_dim, this->reduction_degree) != - std::tie(other.reduction_dim, other.reduction_degree); + return std::tie(this->reduction_degree) != std::tie(other.reduction_degree); } bool ReductionAttrs::operator<(ReductionAttrs const &other) const { - return std::tie(this->reduction_dim, this->reduction_degree) < - std::tie(other.reduction_dim, other.reduction_degree); + return std::tie(this->reduction_degree) < std::tie(other.reduction_degree); } bool ReductionAttrs::operator>(ReductionAttrs const &other) const { - return std::tie(this->reduction_dim, this->reduction_degree) > - std::tie(other.reduction_dim, other.reduction_degree); + return std::tie(this->reduction_degree) > std::tie(other.reduction_degree); } bool ReductionAttrs::operator<=(ReductionAttrs const &other) const { - return std::tie(this->reduction_dim, this->reduction_degree) <= - std::tie(other.reduction_dim, other.reduction_degree); + return std::tie(this->reduction_degree) <= std::tie(other.reduction_degree); } bool ReductionAttrs::operator>=(ReductionAttrs const &other) const { - return std::tie(this->reduction_dim, this->reduction_degree) >= - std::tie(other.reduction_dim, other.reduction_degree); + return std::tie(this->reduction_degree) >= std::tie(other.reduction_degree); } } // namespace FlexFlow @@ -47,8 +38,6 @@ namespace std { size_t hash::operator()( FlexFlow::ReductionAttrs const &x) const { size_t result = 0; - result ^= std::hash<::FlexFlow::ff_dim_t>{}(x.reduction_dim) + 0x9e3779b9 + - (result << 6) + (result >> 2); result ^= std::hash{}(x.reduction_degree) + 0x9e3779b9 + (result << 6) + (result >> 2); return result; @@ -58,21 +47,18 @@ size_t hash::operator()( namespace nlohmann { FlexFlow::ReductionAttrs adl_serializer::from_json(json const &j) { - return {j.at("reduction_dim").template get<::FlexFlow::ff_dim_t>(), - j.at("reduction_degree").template get()}; + return {j.at("reduction_degree").template get()}; } void adl_serializer::to_json( json &j, FlexFlow::ReductionAttrs const &v) { j["__type"] = "ReductionAttrs"; - j["reduction_dim"] = v.reduction_dim; j["reduction_degree"] = v.reduction_degree; } } // namespace nlohmann namespace rc { Gen Arbitrary::arbitrary() { - return gen::construct( - gen::arbitrary<::FlexFlow::ff_dim_t>(), gen::arbitrary()); + return gen::construct(gen::arbitrary()); } } // namespace rc @@ -80,7 +66,6 @@ namespace FlexFlow { std::string format_as(ReductionAttrs const &x) { std::ostringstream oss; oss << ""; return oss.str(); diff --git a/lib/op-attrs/src/op-attrs/ops/repartition.cc b/lib/op-attrs/src/op-attrs/ops/repartition.cc index 309247c6e1..e668ccce6c 100644 --- a/lib/op-attrs/src/op-attrs/ops/repartition.cc +++ b/lib/op-attrs/src/op-attrs/ops/repartition.cc @@ -2,15 +2,12 @@ namespace FlexFlow { -ParallelTensorShape get_output_shape(RepartitionAttrs const &, - ParallelTensorShape const &input_shape) { - NOT_IMPLEMENTED(); +tl::expected + get_output_shape(RepartitionAttrs const &attrs, + ParallelTensorShape const &input_shape) { + ParallelTensorShape output_shape = input_shape; + output_shape.dims.shard_dims.at(attrs.repartition_dim).degree *= attrs.repartition_degree; + return output_shape; } -/* bool RepartitionAttrs::is_valid(ParallelTensorShape const &input_shape) const - * { */ -/* ParallelDim dim = input_shape.at(this->repartition_legion_dim); */ -/* return (dim.size % this->repartition_degree * dim.degree == 0); */ -/* } */ - } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/replicate.cc b/lib/op-attrs/src/op-attrs/ops/replicate.cc index 261a82464f..046558b452 100644 --- a/lib/op-attrs/src/op-attrs/ops/replicate.cc +++ b/lib/op-attrs/src/op-attrs/ops/replicate.cc @@ -4,7 +4,9 @@ namespace FlexFlow { ParallelTensorShape get_output_shape(ReplicateAttrs const &attrs, ParallelTensorShape const &input_shape) { - NOT_IMPLEMENTED(); + ParallelTensorShape output_shape = input_shape; + output_shape.dims.replica_dims.discard_copy_degree.value *= attrs.replicate_degree; + return output_shape; } } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/replicate_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/replicate_attrs.dtg.cc index 66fe45a1db..930c5beaf4 100644 --- a/lib/op-attrs/src/op-attrs/ops/replicate_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/replicate_attrs.dtg.cc @@ -3,43 +3,34 @@ // lib/op-attrs/include/op-attrs/ops/replicate_attrs.struct.toml /* proj-data { - "generated_from": "68c1bba349a54c0db219a67d4cc502b3" + "generated_from": "6d3ad4d10c24dae819ffee4592a72499" } */ #include "op-attrs/ops/replicate_attrs.dtg.h" -#include "op-attrs/ff_dim.dtg.h" -#include "op-attrs/ff_dim.h" #include namespace FlexFlow { -ReplicateAttrs::ReplicateAttrs(::FlexFlow::ff_dim_t const &replicate_dim, - int const &replicate_degree) - : replicate_dim(replicate_dim), replicate_degree(replicate_degree) {} +ReplicateAttrs::ReplicateAttrs(int const &replicate_degree) + : replicate_degree(replicate_degree) {} bool ReplicateAttrs::operator==(ReplicateAttrs const &other) const { - return std::tie(this->replicate_dim, this->replicate_degree) == - std::tie(other.replicate_dim, other.replicate_degree); + return std::tie(this->replicate_degree) == std::tie(other.replicate_degree); } bool ReplicateAttrs::operator!=(ReplicateAttrs const &other) const { - return std::tie(this->replicate_dim, this->replicate_degree) != - std::tie(other.replicate_dim, other.replicate_degree); + return std::tie(this->replicate_degree) != std::tie(other.replicate_degree); } bool ReplicateAttrs::operator<(ReplicateAttrs const &other) const { - return std::tie(this->replicate_dim, this->replicate_degree) < - std::tie(other.replicate_dim, other.replicate_degree); + return std::tie(this->replicate_degree) < std::tie(other.replicate_degree); } bool ReplicateAttrs::operator>(ReplicateAttrs const &other) const { - return std::tie(this->replicate_dim, this->replicate_degree) > - std::tie(other.replicate_dim, other.replicate_degree); + return std::tie(this->replicate_degree) > std::tie(other.replicate_degree); } bool ReplicateAttrs::operator<=(ReplicateAttrs const &other) const { - return std::tie(this->replicate_dim, this->replicate_degree) <= - std::tie(other.replicate_dim, other.replicate_degree); + return std::tie(this->replicate_degree) <= std::tie(other.replicate_degree); } bool ReplicateAttrs::operator>=(ReplicateAttrs const &other) const { - return std::tie(this->replicate_dim, this->replicate_degree) >= - std::tie(other.replicate_dim, other.replicate_degree); + return std::tie(this->replicate_degree) >= std::tie(other.replicate_degree); } } // namespace FlexFlow @@ -47,8 +38,6 @@ namespace std { size_t hash::operator()( FlexFlow::ReplicateAttrs const &x) const { size_t result = 0; - result ^= std::hash<::FlexFlow::ff_dim_t>{}(x.replicate_dim) + 0x9e3779b9 + - (result << 6) + (result >> 2); result ^= std::hash{}(x.replicate_degree) + 0x9e3779b9 + (result << 6) + (result >> 2); return result; @@ -58,21 +47,18 @@ size_t hash::operator()( namespace nlohmann { FlexFlow::ReplicateAttrs adl_serializer::from_json(json const &j) { - return {j.at("replicate_dim").template get<::FlexFlow::ff_dim_t>(), - j.at("replicate_degree").template get()}; + return {j.at("replicate_degree").template get()}; } void adl_serializer::to_json( json &j, FlexFlow::ReplicateAttrs const &v) { j["__type"] = "ReplicateAttrs"; - j["replicate_dim"] = v.replicate_dim; j["replicate_degree"] = v.replicate_degree; } } // namespace nlohmann namespace rc { Gen Arbitrary::arbitrary() { - return gen::construct( - gen::arbitrary<::FlexFlow::ff_dim_t>(), gen::arbitrary()); + return gen::construct(gen::arbitrary()); } } // namespace rc @@ -80,7 +66,6 @@ namespace FlexFlow { std::string format_as(ReplicateAttrs const &x) { std::ostringstream oss; oss << ""; return oss.str(); diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc index 55d1d4af2b..e8ffc49269 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc @@ -44,6 +44,14 @@ ShardParallelDim &shard_dim_at_idx(ParallelTensorShape &s, ff_dim_t d) { return shard_dim_at_idx(s.dims, d); } +std::optional try_get_shard_dim_at_idx(ParallelTensorShape const &s, ff_dim_t d) { + if (s.dims.shard_dims.idx_is_valid(d)) { + return s.dims.shard_dims.at(d); + } else { + return std::nullopt; + } +} + ParallelTensorShape lift_to_parallel(TensorShape const &s) { return {lift_to_parallel(s.dims), s.data_type}; } diff --git a/lib/op-attrs/src/parallel_dim_mapping_record_solver.cc b/lib/op-attrs/src/parallel_dim_mapping_record_solver.cc deleted file mode 100644 index c7e70bb906..0000000000 --- a/lib/op-attrs/src/parallel_dim_mapping_record_solver.cc +++ /dev/null @@ -1,362 +0,0 @@ -#include "parallel_dim_mapping_record_solver.h" -#include "op-attrs/parallel_tensor_shape.h" -#include -#include - -namespace FlexFlow { - -std::vector construct_weight_parallel_dims( - std::vector &records, - std::vector> mappings, - int input_idx, - int weight_idx) { - - std::vector output; - std::transform(mappings.cbegin(), - mappings.cend(), - output.begin(), - [&](std::tuple const &mapping) { - return construct_weight_parallel_dims(std::get<0>(mapping), - std::get<2>(mapping), - input_idx, - weight_idx, - std::get<1>(mapping)); - }); - return output; -} - -std::vector construct_output_parallel_dims( - std::vector> mappings, - int input_idx, - int output_idx) { - NOT_IMPLEMENTED(); -} - -std::vector construct_weight_parallel_dims( - std::vector> mappings, - int input_idx, - int weight_idx) { - NOT_IMPLEMENTED(); -} - -ParallelDimMappingRecord - construct_output_parallel_dims(int input_dim, - int output_dim, - int input_idx, - int output_idx, - std::optional operation) { - NOT_IMPLEMENTED(); -} - -ParallelDimMappingRecord - construct_weight_parallel_dims(int input_dim, - int weight_dim, - int input_idx, - int weight_idx, - std::optional operation) { - NOT_IMPLEMENTED(); -} -/* int get_output_to_input_dim_mapping(ParallelTensorShape const &output, */ -/* int output_dim, */ -/* ParallelTensorShape const &input) { */ -/* int output_idx = -1, input_idx = -1; */ -/* for (int i = 0; i < numOutputs; i++) { */ -/* if (output == outputs[i]) { */ -/* output_idx = i; */ -/* } */ -/* } */ -/* for (int i = 0; i < numInputs; i++) { */ -/* if (input == inputs[i]) { */ -/* input_idx = i; */ -/* } */ -/* } */ -/* assert(output_idx != -1); */ -/* assert(input_idx != -1); */ -/* for (size_t i = 0; i < parallel_dims_mapping->size(); i++) { */ -/* if ((*parallel_dims_mapping)[i].output_idx != output_idx) { */ -/* continue; */ -/* } */ -/* if ((*parallel_dims_mapping)[i].output_dim != output_dim) { */ -/* continue; */ -/* } */ -/* if ((*parallel_dims_mapping)[i].input_idx != input_idx) { */ -/* continue; */ -/* } */ -/* // Check validness */ -/* assert((*parallel_dims_mapping)[i].weight_idx = -1); */ -/* assert((*parallel_dims_mapping)[i].weight_dim = -1); */ -/* return (*parallel_dims_mapping)[i].input_dim; */ -/* } */ -/* assert(false); */ -/* return -1; */ -/* } */ - -/* int get_output_to_weight_dim_mapping(const ParallelTensor output, */ -/* int output_dim, */ -/* const ParallelTensor weight) { */ -/* int output_idx = -1, weight_idx = -1; */ -/* for (int i = 0; i < numOutputs; i++) { */ -/* if (output == outputs[i]) { */ -/* output_idx = i; */ -/* } */ -/* } */ -/* for (int i = 0; i < numInputs; i++) { */ -/* if (weight == weights[i]) { */ -/* weight_idx = i; */ -/* } */ -/* } */ -/* assert(output_idx != -1); */ -/* assert(weight_idx != -1); */ -/* for (size_t i = 0; i < parallel_dims_mapping->size(); i++) { */ -/* if ((*parallel_dims_mapping)[i].output_idx != output_idx) { */ -/* continue; */ -/* } */ -/* if ((*parallel_dims_mapping)[i].output_dim != output_dim) { */ -/* continue; */ -/* } */ -/* if ((*parallel_dims_mapping)[i].weight_idx != weight_idx) { */ -/* continue; */ -/* } */ -/* // Check validness */ -/* assert((*parallel_dims_mapping)[i].input_idx = -1); */ -/* assert((*parallel_dims_mapping)[i].input_dim = -1); */ -/* return (*parallel_dims_mapping)[i].weight_dim; */ -/* } */ -/* assert(false); */ -/* return -1; */ -/* } */ - -/* bool check_output_input_weight_parallel_dims(bool allocate_weights) const { - */ -/* // if (!allocate_weights) { */ -/* // assert(this->numWeights == 0); */ -/* // } */ - -/* for (ParallelDimMappingRecord const &record : *parallel_dims_mapping) { */ -/* assert(record.input_idx < this->numInputs); */ -/* assert(record.input_dim < this->inputs[record.input_idx]->num_dims); */ -/* ParallelDim const &input_dim = */ -/* inputs[record.input_idx]->dims[record.input_dim]; */ -/* /1* assert (input_dim.degree != ParallelDim::UNKNOWN_DEGREE); *1/ */ -/* /1* assert (input_dim.parallel_idx != ParallelDim::UNKNOWN_INDEX); *1/ */ - -/* ParallelDim other_dim; */ -/* switch (record.get_type()) { */ -/* case MappingRecordType::INPUT_OUTPUT: */ -/* assert(record.output_idx < this->numOutputs); */ -/* assert(record.output_dim < - * this->outputs[record.output_idx]->num_dims); */ -/* other_dim = outputs[record.output_idx]->dims[record.output_dim]; */ -/* break; */ -/* case MappingRecordType::INPUT_WEIGHT: */ -/* if (!allocate_weights) { */ -/* continue; */ -/* } */ -/* if (record.weight_idx >= this->numWeights) { */ -/* // The case where some weights are not used (e.g., no bias for - * linear) */ -/* continue; */ -/* } */ -/* assert(record.weight_dim < - * this->weights[record.weight_idx]->num_dims); */ -/* other_dim = weights[record.weight_idx]->dims[record.weight_dim]; */ -/* break; */ -/* } */ - -/* assert(other_dim.degree == input_dim.degree); */ -/* assert(other_dim.parallel_idx == input_dim.parallel_idx); */ -/* } */ -/* return true; */ -/* } */ - -/* bool check_output_input_weight_same_machine_view() const { */ -/* assert(numOutputs > 0); */ -/* MachineView machine_view = outputs[0]->machine_view; */ -/* for (int i = 0; i < numOutputs; i++) { */ -/* if (outputs[i]->machine_view != machine_view) { */ -/* return false; */ -/* } */ -/* } */ -/* for (int i = 0; i < numInputs; i++) { */ -/* if (inputs[i]->machine_view != machine_view) { */ -/* return false; */ -/* } */ -/* } */ -/* for (int i = 0; i < numWeights; i++) { */ -/* if (weights[i]->machine_view != machine_view) { */ -/* return false; */ -/* } */ -/* } */ -/* return true; */ -/* } */ - -std::vector construct_weight_parallel_dims( - std::vector> mappings, int input_idx, int weight_idx) { - std::vector output; - std::transform(mappings.cbegin(), - mappings.cend(), - output.begin(), - [&](std::pair const &mapping) { - return construct_weight_parallel_dims( - mapping.first, mapping.second, input_idx, weight_idx); - }); - return output; -} - -void construct_weight_parallel_dims( - std::vector &records, - int input_dim, - int weight_dim, - int input_idx, - int weight_idx, - std::optional operation) { - records.push_back(ParallelDimMappingRecord::input_weight_record( - input_idx, input_dim, weight_idx, weight_dim, operation)); -} - -/* void ParallelDimMappingRecordSolver::register_weight_parallel_dims( */ -/* std::vector> mappings, int input_idx, int weight_idx) - * { */ -/* construct_weight_parallel_dims( */ -/* *this->parallel_dims_mapping, mappings, input_idx, weight_idx); */ -/* } */ - -/* void register_weight_parallel_dims( */ -/* std::vector> mappings, */ -/* int input_idx, */ -/* int weight_idx) { */ -/* construct_weight_parallel_dims( */ -/* *this->parallel_dims_mapping, mappings, input_idx, weight_idx); */ -/* } */ - -/* void register_weight_parallel_dims( */ -/* int input_dim, */ -/* int weight_dim, */ -/* int input_idx, */ -/* int weight_idx, */ -/* tl::optional operation) { */ -/* construct_weight_parallel_dims(*this->parallel_dims_mapping, */ -/* input_dim, */ -/* weight_dim, */ -/* input_idx, */ -/* weight_idx, */ -/* operation); */ -/* } */ - -void construct_output_parallel_dims( - std::vector &records, - std::vector> mappings, - int input_idx, - int output_idx) { - for (std::tuple const &mapping : mappings) { - construct_output_parallel_dims(std::get<0>(mapping), - std::get<2>(mapping), - input_idx, - output_idx, - std::get<1>(mapping)); - } -} - -void construct_output_parallel_dims( - std::vector &records, - std::vector> mappings, - int input_idx, - int output_idx) { - for (std::pair const &mapping : mappings) { - construct_output_parallel_dims( - mapping.first, mapping.second, input_idx, output_idx); - } -} - -void construct_output_parallel_dims( - std::vector &records, - int input_dim, - int output_dim, - int input_idx, - int output_idx, - std::optional operation) { - records.push_back(ParallelDimMappingRecord::input_output_record( - input_idx, input_dim, output_idx, output_dim, operation)); -} - -/* void register_output_parallel_dims( */ -/* std::vector> mappings, int input_idx, int output_idx) - * { */ -/* construct_output_parallel_dims( */ -/* *this->parallel_dims_mapping, mappings, input_idx, output_idx); */ -/* } */ - -/* void register_output_parallel_dims( */ -/* std::vector> mappings, */ -/* int input_idx, */ -/* int output_idx) { */ -/* construct_output_parallel_dims( */ -/* *this->parallel_dims_mapping, mappings, input_idx, output_idx); */ -/* } */ - -/* void register_output_parallel_dims( */ -/* int input_dim, */ -/* int output_dim, */ -/* int input_idx, */ -/* int output_idx, */ -/* tl::optional operation) { */ -/* construct_output_parallel_dims(*this->parallel_dims_mapping, */ -/* input_dim, */ -/* output_dim, */ -/* input_idx, */ -/* output_idx, */ -/* operation); */ -/* } */ - -/* ParallelDimMappingSolution solve_parallel_dim_mappings( */ -/* std::vector const &mappings, */ -/* std::vector const &inputs, */ -/* int numWeights, int numOutputs) { */ - -/* ParallelDimMappingSolution solution = [&]() -> ParallelDimMappingSolution { - */ -/* std::vector weight_shapes(numWeights); */ -/* std::vector output_shapes(numOutputs); */ -/* return { weight_shapes, output_shapes }; */ -/* }(); */ - -/* for (ParallelDimMappingRecord const &record : mappings) { */ -/* ParallelDim const &input_dim = - * inputs.at(record.input_idx).at(record.input_dim); */ - -/* switch (record.get_type()) { */ -/* case MappingRecordType::INPUT_OUTPUT: { */ -/* ParallelDim &output_dim = - * solution.output_shapes.at(record.output_idx).at(record.output_dim); */ -/* output_dim.degree = input_dim.degree; */ -/* output_dim.parallel_idx = input_dim.parallel_idx; */ - -/* if (output_dim.is_replica_dim) { */ -/* output_dim.size = input_dim.degree; */ -/* } */ -/* } break; */ -/* case MappingRecordType::INPUT_WEIGHT: { */ -/* ParallelDim &weight_dim = - * solution.weight_shapes.at(record.weight_idx).at(record.weight_dim); */ -/* weight_dim.degree = input_dim.degree; */ -/* weight_dim.parallel_idx = input_dim.parallel_idx; */ - -/* if (weight_dim.is_replica_dim) { */ -/* weight_dim.size = input_dim.degree; */ -/* } */ -/* } break; */ -/* } */ -/* } */ - -/* return solution; */ -/* } */ - -ParallelDimMappingSolution solve_parallel_dim_mappings( - std::vector const &mappings, - std::vector const &input, - int numWeights, - int numOutputs) { - NOT_IMPLEMENTED(); -} - -} // namespace FlexFlow diff --git a/lib/op-attrs/src/parallel_dim_mapping_record_solver.h b/lib/op-attrs/src/parallel_dim_mapping_record_solver.h deleted file mode 100644 index a46192edeb..0000000000 --- a/lib/op-attrs/src/parallel_dim_mapping_record_solver.h +++ /dev/null @@ -1,106 +0,0 @@ -/** - * @file - * @warning This is legacy code the should be removed - * (partially tracked in - * https://github.com/flexflow/FlexFlow/issues/519). - * @brief Helper functions for computing data dependencies of parallel - * operators. Functions based on an incorrect abstraction that should eventually - * be removed in favor of something like https://doi.org/10.1145/3302424.3303953 - */ - -#ifndef _FLEXFLOW_OP_META_SRC_PARELLEL_DIM_MAPPING_RECORD_SOLVER_H -#define _FLEXFLOW_OP_META_SRC_PARELLEL_DIM_MAPPING_RECORD_SOLVER_H - -#include "op-attrs/parallel_tensor_shape.h" -#include "parallel_dim_mapping_record.h" - -namespace FlexFlow { - -std::vector - construct_weight_parallel_dims(std::vector> mappings, - int input_idx = 0, - int weight_idx = 0); -std::vector construct_weight_parallel_dims( - std::vector> mappings, - int input_idx = 0, - int weight_idx = 0); -ParallelDimMappingRecord construct_weight_parallel_dims( - int input_dim, - int weight_dim, - int input_idx = 0, - int weight_idx = 0, - std::optional operation = std::nullopt); - -std::vector - construct_output_parallel_dims(std::vector> mappings, - int input_idx = 0, - int output_idx = 0); -std::vector construct_output_parallel_dims( - std::vector> mappings, - int input_idx = 0, - int output_idx = 0); -ParallelDimMappingRecord construct_output_parallel_dims( - int input_dim, - int output_dim, - int input_idx = 0, - int output_idx = 0, - std::optional operation = std::nullopt); - -struct ParallelDimMappingSolution { - std::vector weight_shapes; - std::vector output_shapes; -}; - -ParallelDimMappingSolution solve_parallel_dim_mappings( - std::vector const &mappings, - std::vector const &input, - int numWeights, - int numOutputs); - -/* class ParallelDimMappingRecordSolver { */ -/* /1* void register_weight_parallel_dims(std::vector> - * mappings, *1/ */ -/* /1* int input_idx = 0, *1/ */ -/* /1* int weight_idx = 0); *1/ */ - -/* /1* void register_output_parallel_dims(std::vector> - * mappings, *1/ */ -/* /1* int input_idx = 0, *1/ */ -/* /1* int output_idx = 0); *1/ */ - -/* /1* int get_output_to_input_dim_mapping(const ParallelTensor output, *1/ */ -/* /1* int output_dim, *1/ */ -/* /1* const ParallelTensor input); *1/ */ -/* /1* int get_output_to_weight_dim_mapping(const ParallelTensor output, *1/ - */ -/* /1* int output_dim, *1/ */ -/* /1* const ParallelTensor weight); *1/ - */ -/* void register_weight_parallel_dims( */ -/* std::vector> mappings, */ -/* int input_idx = 0, */ -/* int weight_idx = 0); */ -/* void register_weight_parallel_dims( */ -/* int input_dim, */ -/* int weight_dim, */ -/* int input_idx = 0, */ -/* int weight_idx = 0, */ -/* std::optional operation = std::nullopt); */ -/* void register_output_parallel_dims( */ -/* std::vector> mappings, */ -/* int input_idx = 0, */ -/* int output_idx = 0); */ -/* void register_output_parallel_dims( */ -/* int input_dim, */ -/* int output_dim, */ -/* int input_idx = 0, */ -/* int output_idx = 0, */ -/* std::optional operation = std::nullopt); */ - -/* private: */ -/* std::vector *parallel_dims_mapping; */ -/* }; */ - -} // namespace FlexFlow - -#endif diff --git a/lib/op-attrs/test/src/ops/combine.cc b/lib/op-attrs/test/src/ops/combine.cc new file mode 100644 index 0000000000..61f9b1d138 --- /dev/null +++ b/lib/op-attrs/test/src/ops/combine.cc @@ -0,0 +1,55 @@ +#include "test/utils/doctest.h" +#include "op-attrs/ops/combine.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Combine shape inference") { + + ParallelTensorShape input = { + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{12, 2}, + ShardParallelDim{14, 1}, + ShardParallelDim{16, 3}, + ShardParallelDim{18, 2}, + }, + ReplicaParallelDimSet{ + SumDegree{3}, + DiscardCopyDegree{2}, + }, + }, + DataType::FLOAT, + }; + + SUBCASE("valid") { + ff_dim_t dim = 2; + int degree = 3; + CombineAttrs attrs = CombineAttrs{ + /*repartition_dim=*/dim, + /*repartition_degree=*/degree, + }; + + tl::expected result = get_output_shape(attrs, input); + + tl::expected correct = [&] { + ParallelTensorShape output = input; + output.dims.shard_dims.at(dim).degree /= degree; + return output; + }(); + + CHECK(result == correct); + } + + SUBCASE("invalid") { + ff_dim_t dim = 2; + int degree = 4; + CombineAttrs attrs = CombineAttrs{ + /*repartition_dim=*/dim, + /*repartition_degree=*/degree, + }; + + tl::expected result = get_output_shape(attrs, input); + + CHECK_MESSAGE(!result.has_value(), "Unexpected successful result: ", result.error()); + } + } +} diff --git a/lib/op-attrs/test/src/ops/repartition.cc b/lib/op-attrs/test/src/ops/repartition.cc new file mode 100644 index 0000000000..62c6a66799 --- /dev/null +++ b/lib/op-attrs/test/src/ops/repartition.cc @@ -0,0 +1,40 @@ +#include "test/utils/doctest.h" +#include "op-attrs/ops/repartition.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Repartition shape inference") { + ff_dim_t dim = 2; + int degree = 4; + RepartitionAttrs attrs = RepartitionAttrs{ + /*repartition_dim=*/dim, + /*repartition_degree=*/degree, + }; + + ParallelTensorShape input = { + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{12, 2}, + ShardParallelDim{14, 1}, + ShardParallelDim{16, 3}, + ShardParallelDim{18, 2}, + }, + ReplicaParallelDimSet{ + SumDegree{3}, + DiscardCopyDegree{2}, + }, + }, + DataType::FLOAT, + }; + + tl::expected result = get_output_shape(attrs, input); + + + tl::expected correct = [&] { + ParallelTensorShape output = input; + output.dims.shard_dims.at(dim).degree *= degree; + return output; + }(); + + CHECK(result == correct); + } +} diff --git a/lib/op-attrs/test/src/ops/replicate.cc b/lib/op-attrs/test/src/ops/replicate.cc new file mode 100644 index 0000000000..412ec5a2b4 --- /dev/null +++ b/lib/op-attrs/test/src/ops/replicate.cc @@ -0,0 +1,33 @@ +#include "test/utils/doctest.h" +#include "op-attrs/ops/replicate.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Replicate shape inference") { + ReplicateAttrs attrs = ReplicateAttrs{ + /*replicate_degree=*/4, + }; + + ParallelTensorShape input = { + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{10, 2}, + ShardParallelDim{12, 1}, + ShardParallelDim{14, 2}, + ShardParallelDim{16, 2}, + }, + ReplicaParallelDimSet{ + SumDegree{3}, + DiscardCopyDegree{2}, + }, + }, + DataType::FLOAT, + }; + + ParallelTensorShape result = get_output_shape(attrs, input); + + ParallelTensorShape correct_output = input; + correct_output.dims.replica_dims.discard_copy_degree = 8; + + CHECK(result == correct_output); + } +} From cb2c862b77b9707335868a12c90ca6c18fd67863 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Fri, 31 May 2024 16:56:34 -0700 Subject: [PATCH 28/43] Include tests for reduction --- lib/op-attrs/test/src/ops/reduction.cc | 51 ++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 lib/op-attrs/test/src/ops/reduction.cc diff --git a/lib/op-attrs/test/src/ops/reduction.cc b/lib/op-attrs/test/src/ops/reduction.cc new file mode 100644 index 0000000000..dc51479d37 --- /dev/null +++ b/lib/op-attrs/test/src/ops/reduction.cc @@ -0,0 +1,51 @@ +#include "test/utils/doctest.h" +#include "op-attrs/ops/reduction.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Reduction shape inference") { + + ParallelTensorShape input = { + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{12, 2}, + ShardParallelDim{14, 1}, + ShardParallelDim{16, 3}, + ShardParallelDim{18, 2}, + }, + ReplicaParallelDimSet{ + SumDegree{3}, + DiscardCopyDegree{2}, + }, + }, + DataType::FLOAT, + }; + + SUBCASE("valid") { + int degree = 3; + ReductionAttrs attrs = ReductionAttrs{ + /*repartition_degree=*/degree, + }; + + tl::expected result = get_output_shape(attrs, input); + + tl::expected correct = [&] { + ParallelTensorShape output = input; + output.dims.replica_dims.sum_degree.value /= degree; + return output; + }(); + + CHECK(result == correct); + } + + SUBCASE("invalid") { + int degree = 4; + ReductionAttrs attrs = ReductionAttrs{ + /*repartition_degree=*/degree, + }; + + tl::expected result = get_output_shape(attrs, input); + + CHECK_MESSAGE(!result.has_value(), "Unexpected successful result: ", result.error()); + } + } +} From bf7a96543a2d218f3656bba11793c9c8f25c8165 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Sat, 1 Jun 2024 21:59:53 -0700 Subject: [PATCH 29/43] Address wmdi comments --- .../include/op-attrs/operator_attrs.h | 85 ------------------- 1 file changed, 85 deletions(-) diff --git a/lib/op-attrs/include/op-attrs/operator_attrs.h b/lib/op-attrs/include/op-attrs/operator_attrs.h index 7acd322928..268554b5be 100644 --- a/lib/op-attrs/include/op-attrs/operator_attrs.h +++ b/lib/op-attrs/include/op-attrs/operator_attrs.h @@ -38,91 +38,6 @@ namespace FlexFlow { -/* using SharedOperatorAttrs = std::variant; */ - -/* static_assert(is_valid_opattr::value, ""); */ -/* static_assert(is_valid_opattr::value, ""); */ -/* static_assert(is_valid_opattr::value, ""); */ -/* static_assert(is_valid_opattr::value, ""); */ -/* static_assert(is_valid_opattr::value, ""); */ -/* static_assert(is_valid_opattr::value, ""); */ -/* static_assert(is_valid_opattr::value, ""); */ -/* static_assert(is_valid_opattr::value, ""); */ -/* static_assert(is_valid_opattr::value, ""); */ -/* static_assert(is_valid_opattr::value, ""); */ -/* static_assert(is_valid_opattr::value, ""); */ -/* static_assert(is_valid_opattr::value, ""); */ -/* static_assert(is_valid_opattr::value, ""); */ -/* static_assert(is_valid_opattr::value, ""); */ -/* static_assert(is_valid_opattr::value, ""); */ -/* static_assert(is_valid_opattr::value, ""); */ -/* static_assert(is_valid_opattr::value, ""); */ -/* static_assert(is_valid_opattr::value, ""); */ -/* static_assert(is_valid_opattr::value, ""); */ -/* static_assert(is_valid_opattr::value, ""); */ -/* static_assert(is_valid_opattr::value, ""); */ -/* static_assert(is_valid_opattr::value, ""); */ -/* static_assert(is_valid_opattr::value, ""); */ - -/* using ParallelOperatorAttrs = std:: */ -/* variant; - */ - -/* using ComputationGraphAttrs = */ -/* variant_join>; */ -/* using CompGraphOperatorAttrs = ComputationGraphAttrs; */ - -/* using PCGOperatorAttrs = */ -/* variant_join; */ - -/* static_assert(is_equal_comparable::value, */ -/* "ComputationGraphAttrs must support =="); */ -/* static_assert(elements_satisfy::value, */ -/* ""); */ -/* static_assert(is_neq_comparable::value, */ -/* "ComputationGraphAttrs must support !="); */ -/* static_assert(is_lt_comparable::value, */ -/* "ComputationGraphAttrs must support <"); */ -/* static_assert(is_hashable::value, */ -/* "ComputationGraphAttrs must be hashable"); */ - -/* static_assert(is_equal_comparable::value, */ -/* "PCGOperatorAttrs must support =="); */ -/* static_assert(is_neq_comparable::value, */ -/* "PCGOperatorAttrs must support !="); */ -/* static_assert(is_lt_comparable::value, */ -/* "PCGOperatorAttrs must support <"); */ -/* static_assert(is_hashable::value, */ -/* "PCGOperatorAttrs must be hashable"); */ - -/* OperatorType get_op_type(CompGraphOperatorAttrs const &); */ -/* OperatorType get_op_type(PCGOperatorAttrs const &); */ - std::vector get_output_shapes( PCGOperatorAttrs const &op_params, std::vector const &input_tensor_shapes); From 3a05ef2477eec7581a58d78b7d05d1a2304f0542 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Sun, 2 Jun 2024 21:46:27 -0700 Subject: [PATCH 30/43] Fixup linear shape inference, add tests for linear --- .../include/op-attrs/get_output_shapes.h | 2 - lib/op-attrs/include/op-attrs/ops/linear.h | 19 +- .../include/op-attrs/ops/linear_attrs.dtg.h | 6 +- .../op-attrs/ops/linear_attrs.struct.toml | 2 +- .../include/op-attrs/parallel_tensor_dims.h | 1 + .../include/op-attrs/parallel_tensor_shape.h | 2 + lib/op-attrs/src/op-attrs/ops/linear.cc | 94 ++++++---- .../src/op-attrs/ops/linear_attrs.dtg.cc | 23 +-- .../src/op-attrs/parallel_tensor_dims.cc | 4 + .../src/op-attrs/parallel_tensor_shape.cc | 4 + lib/op-attrs/test/src/ops/linear.cc | 175 ++++++++++++++++++ 11 files changed, 274 insertions(+), 58 deletions(-) create mode 100644 lib/op-attrs/test/src/ops/linear.cc diff --git a/lib/op-attrs/include/op-attrs/get_output_shapes.h b/lib/op-attrs/include/op-attrs/get_output_shapes.h index 9796204250..a826e1cb54 100644 --- a/lib/op-attrs/include/op-attrs/get_output_shapes.h +++ b/lib/op-attrs/include/op-attrs/get_output_shapes.h @@ -127,8 +127,6 @@ std::vector get_output_shapes(GatherAttrs const &, ParallelTensorShape const &); ParallelTensorShape get_output_shape(LayerNormAttrs const &, ParallelTensorShape const &); -ParallelTensorShape get_output_shape(LinearAttrs const &, - ParallelTensorShape const &); ParallelTensorShape get_output_shape(Pool2DAttrs const &, ParallelTensorShape const &); ParallelTensorShape get_output_shape(ReduceAttrs const &, diff --git a/lib/op-attrs/include/op-attrs/ops/linear.h b/lib/op-attrs/include/op-attrs/ops/linear.h index 566fb3dcf1..6da4f0c9f3 100644 --- a/lib/op-attrs/include/op-attrs/ops/linear.h +++ b/lib/op-attrs/include/op-attrs/ops/linear.h @@ -5,22 +5,29 @@ #include "op-attrs/ops/linear_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" +#include namespace FlexFlow { CHECK_VALID_OP_ATTR(LinearAttrs); -TensorShape get_kernel_shape(LinearAttrs const &attrs, +tl::expected + get_kernel_shape(LinearAttrs const &attrs, TensorShape const &input); -TensorShape get_bias_shape(LinearAttrs const &attrs, TensorShape const &input); -TensorShape get_output_shape(LinearAttrs const &attrs, +tl::expected + get_bias_shape(LinearAttrs const &attrs, TensorShape const &input); +tl::expected + get_output_shape(LinearAttrs const &attrs, TensorShape const &input); -ParallelTensorShape get_kernel_shape(LinearAttrs const &attrs, +tl::expected + get_kernel_shape(LinearAttrs const &attrs, ParallelTensorShape const &input); -ParallelTensorShape get_bias_shape(LinearAttrs const &attrs, +tl::expected + get_bias_shape(LinearAttrs const &attrs, ParallelTensorShape const &input); -ParallelTensorShape get_output_shape(LinearAttrs const &attrs, +tl::expected + get_output_shape(LinearAttrs const &attrs, ParallelTensorShape const &input); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/linear_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/linear_attrs.dtg.h index dfb7579b25..28cd2a8b33 100644 --- a/lib/op-attrs/include/op-attrs/ops/linear_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/linear_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml /* proj-data { - "generated_from": "1369f126a4a6d6eee91642043ab481f6" + "generated_from": "7e82d282f90e08f1e0db7d5c4ce528b7" } */ @@ -27,7 +27,7 @@ struct LinearAttrs { LinearAttrs(int const &out_channels, bool const &use_bias, ::FlexFlow::DataType const &data_type, - ::FlexFlow::Activation const &activation, + std::optional<::FlexFlow::Activation> const &activation, std::optional<::FlexFlow::RegularizerAttrs> const ®ularizer); bool operator==(LinearAttrs const &) const; @@ -39,7 +39,7 @@ struct LinearAttrs { int out_channels; bool use_bias; ::FlexFlow::DataType data_type; - ::FlexFlow::Activation activation; + std::optional<::FlexFlow::Activation> activation; std::optional<::FlexFlow::RegularizerAttrs> regularizer; }; } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml index 7fa2d9c584..4ac8f83ec9 100644 --- a/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml @@ -30,7 +30,7 @@ type = "::FlexFlow::DataType" [[fields]] name = "activation" -type = "::FlexFlow::Activation" +type = "std::optional<::FlexFlow::Activation>" [[fields]] name = "regularizer" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h index 2bb44c919f..8e02e3607b 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h @@ -8,6 +8,7 @@ namespace FlexFlow { FFOrdered ff_ordered_shard_dims(ParallelTensorDims const &); +FFOrdered ff_ordered_shard_degrees(ParallelTensorDims const &); std::unordered_set replica_dims(ParallelTensorDims const &); /* size_t get_volume(ParallelTensorDims const &); */ diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h index 969392ebf9..c8b0ad236c 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h @@ -11,6 +11,8 @@ int num_shard_dims(ParallelTensorShape const &); ShardParallelDim shard_dim_at_idx(ParallelTensorShape const &, ff_dim_t); ShardParallelDim &shard_dim_at_idx(ParallelTensorShape &, ff_dim_t); +FFOrdered ff_ordered_shard_degrees(ParallelTensorShape const &); + std::optional try_get_shard_dim_at_idx(ParallelTensorShape const &, ff_dim_t); ParallelTensorShape lift_to_parallel(TensorShape const &); diff --git a/lib/op-attrs/src/op-attrs/ops/linear.cc b/lib/op-attrs/src/op-attrs/ops/linear.cc index e6c9dd751b..2bba41fa2a 100644 --- a/lib/op-attrs/src/op-attrs/ops/linear.cc +++ b/lib/op-attrs/src/op-attrs/ops/linear.cc @@ -2,11 +2,14 @@ #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/tensor_shape.h" #include "utils/integer_conversions.h" +#include "op-attrs/dim_ordered/transform.h" +#include "op-attrs/dim_ordered/slice.h" namespace FlexFlow { -TensorShape get_kernel_shape(LinearAttrs const &attrs, - TensorShape const &input_shape) { +tl::expected + get_kernel_shape(LinearAttrs const &attrs, + TensorShape const &input_shape) { size_t in_channels = dim_at_idx(input_shape, ff_dim_t{-1}); return TensorShape{ @@ -17,8 +20,9 @@ TensorShape get_kernel_shape(LinearAttrs const &attrs, }; } -TensorShape get_bias_shape(LinearAttrs const &attrs, - TensorShape const &input_shape) { +tl::expected +get_bias_shape(LinearAttrs const &attrs, + TensorShape const &input_shape) { return TensorShape{ TensorDims{ FFOrdered{size_t_from_int(attrs.out_channels)}, @@ -27,7 +31,7 @@ TensorShape get_bias_shape(LinearAttrs const &attrs, }; } -TensorShape get_output_shape(LinearAttrs const &attrs, +tl::expected get_output_shape(LinearAttrs const &attrs, TensorShape const &input_shape) { TensorShape output_shape = input_shape; output_shape.dims.ff_ordered.at(ff_dim_t{-1}) = @@ -36,42 +40,62 @@ TensorShape get_output_shape(LinearAttrs const &attrs, return output_shape; } -ParallelTensorShape get_kernel_shape(LinearAttrs const &attrs, - ParallelTensorShape const &input_shape) { - NOT_IMPLEMENTED(); - /* ShardParallelDim input_sample_dim = shard_dim_at_idx(input_shape, - * ff_dim_t{-2}); */ - /* ShardParallelDim in_channels_dim = shard_dim_at_idx(input_shape, - * ff_dim_t{-1}); */ -} +tl::expected get_kernel_shape(LinearAttrs const &attrs, + ParallelTensorShape const &input) { + TensorShape unpar = ({ + tl::expected result_unpar = get_kernel_shape(attrs, get_reduced_shape(input)); + if (!result_unpar.has_value()) { + return tl::unexpected(result_unpar.error()); + } + result_unpar.value(); + }); -ParallelTensorShape get_output_shape(LinearAttrs const &attrs, - ParallelTensorShape const &input_shape) { - ShardParallelDim input_sample_dim = - shard_dim_at_idx(input_shape, ff_dim_t{-2}); - ShardParallelDim in_channels_dim = - shard_dim_at_idx(input_shape, ff_dim_t{-1}); - - ShardParallelDim output_sample_dim = input_sample_dim; - ShardParallelDim output_channels_dim = { - size_t_from_int(attrs.out_channels), - get_discard_copy_degree(input_shape), + SumDegree sum_degree = 1; + DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{ + get_sum_degree(input) * product(slice(ff_ordered_shard_degrees(input), std::nullopt, ff_dim_t{-1})) }; + FFOrdered shard_degrees = FFOrdered{ + shard_dim_at_idx(input, ff_dim_t{-1}).degree, + get_discard_copy_degree(input), + }; + + return lift_to_parallel_with_degrees(unpar, sum_degree, discard_copy_degree, shard_degrees); +} + +tl::expected + get_bias_shape(LinearAttrs const &attrs, + ParallelTensorShape const &input) { + TensorShape unpar = ({ + tl::expected result_unpar = get_bias_shape(attrs, get_reduced_shape(input)); + if (!result_unpar.has_value()) { + return tl::unexpected(result_unpar.error()); + } + result_unpar.value(); + }); - int output_sum_degree = - get_sum_degree(input_shape) * in_channels_dim.degree; - int output_discard_copy_degree = 1; + SumDegree sum_degree = get_sum_degree(input) * shard_dim_at_idx(input, ff_dim_t{-1}).degree; + DiscardCopyDegree discard_copy_degree = product(slice(ff_ordered_shard_degrees(input), std::nullopt, ff_dim_t{-1})); + FFOrdered shard_degrees = FFOrdered{get_discard_copy_degree(input)}; + + return lift_to_parallel_with_degrees(unpar, sum_degree, discard_copy_degree, shard_degrees); +} - ParallelTensorShape result = input_shape; - shard_dim_at_idx(result, ff_dim_t{-2}) = output_sample_dim; - shard_dim_at_idx(result, ff_dim_t{-1}) = output_channels_dim; - result.dims.replica_dims.sum_degree = output_sum_degree; - result.dims.replica_dims.discard_copy_degree = output_discard_copy_degree; +tl::expected get_output_shape(LinearAttrs const &attrs, + ParallelTensorShape const &input) { + TensorShape unpar = ({ + tl::expected result_unpar = get_output_shape(attrs, get_reduced_shape(input)); + if (!result_unpar.has_value()) { + return tl::unexpected(result_unpar.error()); + } + result_unpar.value(); + }); - assert(total_parallel_degree(result.dims) == - total_parallel_degree(input_shape.dims)); + SumDegree sum_degree = get_sum_degree(input) * shard_dim_at_idx(input, ff_dim_t{-1}).degree; + DiscardCopyDegree discard_copy_degree = 1; + FFOrdered shard_degrees = ff_ordered_shard_degrees(input); + shard_degrees.at(ff_dim_t{-1}) = get_discard_copy_degree(input); - return result; + return lift_to_parallel_with_degrees(unpar, sum_degree, discard_copy_degree, shard_degrees); } } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/linear_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/linear_attrs.dtg.cc index a465751991..f3359da219 100644 --- a/lib/op-attrs/src/op-attrs/ops/linear_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/linear_attrs.dtg.cc @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml /* proj-data { - "generated_from": "1369f126a4a6d6eee91642043ab481f6" + "generated_from": "7e82d282f90e08f1e0db7d5c4ce528b7" } */ @@ -20,7 +20,7 @@ LinearAttrs::LinearAttrs( int const &out_channels, bool const &use_bias, ::FlexFlow::DataType const &data_type, - ::FlexFlow::Activation const &activation, + std::optional<::FlexFlow::Activation> const &activation, std::optional<::FlexFlow::RegularizerAttrs> const ®ularizer) : out_channels(out_channels), use_bias(use_bias), data_type(data_type), activation(activation), regularizer(regularizer) {} @@ -102,8 +102,8 @@ size_t hash::operator()( (result >> 2); result ^= std::hash<::FlexFlow::DataType>{}(x.data_type) + 0x9e3779b9 + (result << 6) + (result >> 2); - result ^= std::hash<::FlexFlow::Activation>{}(x.activation) + 0x9e3779b9 + - (result << 6) + (result >> 2); + result ^= std::hash>{}(x.activation) + + 0x9e3779b9 + (result << 6) + (result >> 2); result ^= std::hash>{}(x.regularizer) + 0x9e3779b9 + (result << 6) + (result >> 2); @@ -114,12 +114,13 @@ size_t hash::operator()( namespace nlohmann { FlexFlow::LinearAttrs adl_serializer::from_json(json const &j) { - return {j.at("out_channels").template get(), - j.at("use_bias").template get(), - j.at("data_type").template get<::FlexFlow::DataType>(), - j.at("activation").template get<::FlexFlow::Activation>(), - j.at("regularizer") - .template get>()}; + return { + j.at("out_channels").template get(), + j.at("use_bias").template get(), + j.at("data_type").template get<::FlexFlow::DataType>(), + j.at("activation").template get>(), + j.at("regularizer") + .template get>()}; } void adl_serializer::to_json( json &j, FlexFlow::LinearAttrs const &v) { @@ -138,7 +139,7 @@ Gen Arbitrary::arbitrary() { gen::arbitrary(), gen::arbitrary(), gen::arbitrary<::FlexFlow::DataType>(), - gen::arbitrary<::FlexFlow::Activation>(), + gen::arbitrary>(), gen::arbitrary>()); } } // namespace rc diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc index 89a2934704..4ade47840b 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc @@ -12,6 +12,10 @@ FFOrdered ff_ordered_shard_dims(ParallelTensorDims const &d) { return d.shard_dims; } +FFOrdered ff_ordered_shard_degrees(ParallelTensorDims const &d) { + return transform(d.shard_dims, [](ShardParallelDim const &d) { return d.degree; }); +} + std::unordered_set replica_dims(ParallelTensorDims const &d) { return get_replica_dims(d.replica_dims); diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc index e8ffc49269..a0ebb6f9c2 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc @@ -44,6 +44,10 @@ ShardParallelDim &shard_dim_at_idx(ParallelTensorShape &s, ff_dim_t d) { return shard_dim_at_idx(s.dims, d); } +FFOrdered ff_ordered_shard_degrees(ParallelTensorShape const &s) { + return ff_ordered_shard_degrees(s.dims); +} + std::optional try_get_shard_dim_at_idx(ParallelTensorShape const &s, ff_dim_t d) { if (s.dims.shard_dims.idx_is_valid(d)) { return s.dims.shard_dims.at(d); diff --git a/lib/op-attrs/test/src/ops/linear.cc b/lib/op-attrs/test/src/ops/linear.cc new file mode 100644 index 0000000000..2de04a8a03 --- /dev/null +++ b/lib/op-attrs/test/src/ops/linear.cc @@ -0,0 +1,175 @@ +#include "test/utils/doctest.h" +#include "op-attrs/ops/linear.h" +#include "utils/integer_conversions.h" +#include "op-attrs/parallel_tensor_shape.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Linear shape inference") { + int out_channels = 16; + LinearAttrs attrs = LinearAttrs{ + /*out_channels=*/out_channels, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*activation=*/Activation::RELU, + /*regularizer=*/std::nullopt, + }; + + size_t batch_size = 12; + size_t extra_dim = 16; + size_t in_channels = 8; + + TensorShape input = TensorShape{ + TensorDims{ + FFOrdered{ + batch_size, + extra_dim, + in_channels, + }, + }, + DataType::FLOAT, + }; + + TensorShape output = TensorShape{ + TensorDims{ + FFOrdered{ + batch_size, + extra_dim, + size_t_from_int(out_channels), + }, + }, + DataType::FLOAT, + }; + + TensorShape kernel = TensorShape{ + TensorDims{ + FFOrdered{ + in_channels, + size_t_from_int(out_channels), + }, + }, + DataType::FLOAT, + }; + + TensorShape bias = TensorShape{ + TensorDims{ + FFOrdered{ + size_t_from_int(out_channels), + }, + }, + DataType::FLOAT, + }; + + // get_output_shape + { + tl::expected output_result = get_output_shape(attrs, input); + tl::expected output_correct = output; + CHECK(output_result == output_correct); + } + + // get_weight_shape + { + tl::expected kernel_result = get_kernel_shape(attrs, input); + tl::expected kernel_correct = kernel; + CHECK(kernel_result == kernel_correct); + } + + // get_bias_shape + { + tl::expected bias_result = get_bias_shape(attrs, input); + tl::expected bias_correct = bias; + CHECK(bias_result == bias_correct); + } + + auto make_input = [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o_batch, int o_extra_dim, int o_channel) { + return lift_to_parallel_with_degrees(input, o_sum, o_eq, FFOrdered{o_batch, o_extra_dim, o_channel}); + }; + + auto make_output = [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o_batch, int o_extra_dim, int o_channel) { + return lift_to_parallel_with_degrees(output, o_sum, o_eq, FFOrdered{o_batch, o_extra_dim, o_channel}); + }; + + auto make_kernel = [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o_inchannel, int o_outchannel) { + return lift_to_parallel_with_degrees(kernel, o_sum, o_eq, FFOrdered{o_inchannel, o_outchannel}); + }; + + auto make_bias = [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o_outchannel) { + return lift_to_parallel_with_degrees(bias, o_sum, o_eq, FFOrdered{o_outchannel}); + }; + + SUBCASE("data parallelism") { + int input_sum_degree = 2; + int extra_dim_degree = 8; + int degree = 4; + + ParallelTensorShape par_input = make_input(SumDegree{input_sum_degree}, DiscardCopyDegree{1}, degree, extra_dim_degree, 1); + + { + tl::expected result = get_output_shape(attrs, par_input); + tl::expected correct = make_output(SumDegree{input_sum_degree}, DiscardCopyDegree{1}, degree, extra_dim_degree, 1); + CHECK(result == correct); + } + + { + tl::expected result = get_kernel_shape(attrs, par_input); + tl::expected correct = make_kernel(SumDegree{1}, DiscardCopyDegree{input_sum_degree * degree * extra_dim_degree}, 1, 1); + CHECK(result == correct); + } + + { + tl::expected result = get_bias_shape(attrs, par_input); + tl::expected correct = make_bias(SumDegree{input_sum_degree}, DiscardCopyDegree{degree * extra_dim_degree}, 1); + CHECK(result == correct); + } + } + + SUBCASE("reduction parallelism") { + int input_sum_degree = 2; + int degree = 4; + + ParallelTensorShape par_input = make_input(SumDegree{input_sum_degree}, DiscardCopyDegree{1}, 1, 1, degree); + + { + tl::expected result = get_output_shape(attrs, par_input); + tl::expected correct = make_output(SumDegree{input_sum_degree * degree}, DiscardCopyDegree{1}, 1, 1, 1); + CHECK(result == correct); + } + + { + tl::expected result = get_kernel_shape(attrs, par_input); + tl::expected correct = make_kernel(SumDegree{1}, DiscardCopyDegree{input_sum_degree}, degree, 1); + CHECK(result == correct); + } + + { + tl::expected result = get_bias_shape(attrs, par_input); + tl::expected correct = make_bias(SumDegree{input_sum_degree * degree}, DiscardCopyDegree{1}, 1); + CHECK(result == correct); + } + } + + SUBCASE("output channel parallelism") { + int input_sum_degree = 2; + int degree = 4; + + ParallelTensorShape par_input = make_input(SumDegree{input_sum_degree}, DiscardCopyDegree{degree}, 1, 1, 1); + + { + tl::expected result = get_output_shape(attrs, par_input); + tl::expected correct = make_output(SumDegree{input_sum_degree}, DiscardCopyDegree{1}, 1, 1, degree); + CHECK(result == correct); + } + + { + tl::expected result = get_kernel_shape(attrs, par_input); + tl::expected correct = make_kernel(SumDegree{1}, DiscardCopyDegree{input_sum_degree}, 1, degree); + CHECK(result == correct); + } + + { + tl::expected result = get_bias_shape(attrs, par_input); + tl::expected correct = make_bias(SumDegree{input_sum_degree}, DiscardCopyDegree{1}, degree); + CHECK(result == correct); + } + } + } +} From 13cb842cd5b6a0a3cb265c29b65895962fe567d4 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Sun, 2 Jun 2024 22:37:58 -0700 Subject: [PATCH 31/43] Fix tests --- lib/pcg/src/computation_graph.cc | 76 -------------------------------- 1 file changed, 76 deletions(-) delete mode 100644 lib/pcg/src/computation_graph.cc diff --git a/lib/pcg/src/computation_graph.cc b/lib/pcg/src/computation_graph.cc deleted file mode 100644 index 18fded6d3e..0000000000 --- a/lib/pcg/src/computation_graph.cc +++ /dev/null @@ -1,76 +0,0 @@ -#include "pcg/computation_graph.h" - -namespace FlexFlow { - -std::vector - traverse_comp_graph_forward(ComputationGraph const &comp_graph) { - std::vector layers = get_topological_ordering(comp_graph.value()); - return transform(layers, [&](Node const &e) -> operator_guid_t { - return operator_guid_t{e}; - }); -} - -std::vector - traverse_comp_graph_backward(ComputationGraph const &comp_graph) { - std::vector layers = - reversed>(get_topological_ordering(comp_graph.value())); - return transform(layers, [&](Node const &e) -> operator_guid_t { - return operator_guid_t{e}; - }); -} - -std::vector - sort_edge_set(std::unordered_set edges) { - return transform( - sorted_by(edges, compare_by([](MultiDiEdge const &e) { - return e.src_idx; - })), - [&](MultiDiEdge const &e) -> tensor_guid_t { return tensor_guid_t{e}; }); -} - -std::vector - get_outgoing_tensors(ComputationGraph const &comp_graph, - operator_guid_t n) { - return sort_edge_set(get_outgoing_edges(comp_graph.value(), n.value())); -} - -std::vector - get_incoming_tensors(ComputationGraph const &comp_graph, - operator_guid_t n) { - return sort_edge_set(get_incoming_edges(comp_graph.value(), n.value())); -} - -operator_guid_t create_node(ComputationGraph &comp_graph, Layer const &layer) { - Node added_node = comp_graph.value().add_node(layer); - return operator_guid_t{added_node}; -} - -tensor_guid_t create_outgoing_edge(ComputationGraph &comp_graph, - operator_guid_t node, - int idx, - Tensor tensor) { - MultiDiOutput edge = {node.value(), NodePort{idx}}; - comp_graph.value().add_output(edge, tensor); - return tensor_guid_t{edge}; -} - -void connect_incoming_edges(ComputationGraph &comp_graph, - std::vector const &incoming_edges, - operator_guid_t node) { - size_t incoming_edge_dst_port = 0; - for (tensor_guid_t input : incoming_edges) { - MultiDiOutput input_view = input.value(); - MultiDiEdge edge = {node.value(), - NodePort{incoming_edge_dst_port++}, - input_view.src, - input_view.src_idx}; - comp_graph.value().add_edge(edge); - } -} - -CompGraphOperatorAttrs get_layer_attrs(ComputationGraph const &comp_graph, - operator_guid_t const &n) { - return comp_graph.at(n).attrs; -} - -} // namespace FlexFlow From 967cb22acc2920901d75bcca6bf3f0654f1f828c Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Sun, 2 Jun 2024 22:38:30 -0700 Subject: [PATCH 32/43] Format --- .../include/op-attrs/dim_ordered/slice.h | 31 ++- .../include/op-attrs/dim_ordered/transform.h | 5 +- lib/op-attrs/include/op-attrs/ops/attention.h | 42 +-- .../attention/multihead_attention_inputs.h | 7 +- .../multihead_attention_parallel_inputs.h | 7 +- .../include/op-attrs/ops/batch_matmul.h | 15 +- lib/op-attrs/include/op-attrs/ops/combine.h | 6 +- .../include/op-attrs/ops/element_binary.h | 15 +- .../include/op-attrs/ops/element_unary.h | 19 +- lib/op-attrs/include/op-attrs/ops/embedding.h | 14 +- lib/op-attrs/include/op-attrs/ops/linear.h | 21 +- lib/op-attrs/include/op-attrs/ops/reduction.h | 5 +- .../include/op-attrs/ops/repartition.h | 4 +- lib/op-attrs/include/op-attrs/ops/replicate.h | 2 +- .../include/op-attrs/parallel_tensor_shape.h | 9 +- lib/op-attrs/include/op-attrs/tensor_dims.h | 6 +- lib/op-attrs/src/op-attrs/ops/attention.cc | 102 ++++--- .../attention/multihead_attention_inputs.cc | 52 ++-- .../multihead_attention_parallel_inputs.cc | 96 +++++-- lib/op-attrs/src/op-attrs/ops/batch_matmul.cc | 161 ++++++----- lib/op-attrs/src/op-attrs/ops/combine.cc | 19 +- .../src/op-attrs/ops/element_binary.cc | 41 +-- .../src/op-attrs/ops/element_unary.cc | 33 ++- lib/op-attrs/src/op-attrs/ops/embedding.cc | 72 ++--- lib/op-attrs/src/op-attrs/ops/linear.cc | 63 +++-- lib/op-attrs/src/op-attrs/ops/reduction.cc | 10 +- lib/op-attrs/src/op-attrs/ops/repartition.cc | 7 +- lib/op-attrs/src/op-attrs/ops/replicate.cc | 3 +- .../src/op-attrs/parallel_tensor_dims.cc | 11 +- .../src/op-attrs/parallel_tensor_shape.cc | 16 +- .../src/op-attrs/replica_parallel_dim_set.cc | 3 +- lib/op-attrs/src/op-attrs/tensor_dims.cc | 23 +- lib/op-attrs/test/src/dim_ordered/slice.cc | 12 +- lib/op-attrs/test/src/ops/combine.cc | 44 +-- lib/op-attrs/test/src/ops/linear.cc | 190 ++++++++----- lib/op-attrs/test/src/ops/reduction.cc | 40 +-- lib/op-attrs/test/src/ops/repartition.cc | 34 +-- lib/op-attrs/test/src/ops/replicate.cc | 30 +-- lib/op-attrs/test/src/test_attention.cc | 254 ++++++++++-------- lib/op-attrs/test/src/test_batch_matmul.cc | 234 +++++++++------- lib/op-attrs/test/src/test_element_binary.cc | 118 +++++--- lib/op-attrs/test/src/test_element_unary.cc | 47 ++-- lib/op-attrs/test/src/test_embedding.cc | 132 +++++---- lib/pcg/include/pcg/open_dataflow_graph.h | 50 ++-- lib/pcg/src/pcg/computation_graph_builder.cc | 15 +- .../tensor_pattern/get_attribute.cc | 18 +- lib/utils/include/utils/containers.h | 2 +- .../utils/containers/vector_transform.h | 5 +- .../include/utils/containers/zip_vectors.h | 5 +- lib/utils/include/utils/fmt.decl.h | 1 - lib/utils/include/utils/fmt/expected.h | 10 +- lib/utils/src/utils/integer_conversions.cc | 2 +- .../test/common/include/test/utils/doctest.h | 6 +- 53 files changed, 1302 insertions(+), 867 deletions(-) diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/slice.h b/lib/op-attrs/include/op-attrs/dim_ordered/slice.h index 0bc4c7513e..4d6e82b71b 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/slice.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/slice.h @@ -8,32 +8,45 @@ namespace FlexFlow { template -DimOrdered nonoverloaded_slice(DimOrdered const &d, std::optional const &start, std::optional const &end) { +DimOrdered nonoverloaded_slice(DimOrdered const &d, + std::optional const &start, + std::optional const &end) { auto to_raw_idx = [](std::optional const &idx) -> std::optional { return transform(idx, [](Idx const &i) { return i.value; }); }; - return DimOrdered{subvec(as_vector(d), to_raw_idx(start), to_raw_idx(end))}; + return DimOrdered{ + subvec(as_vector(d), to_raw_idx(start), to_raw_idx(end))}; } template -DimOrdered slice(DimOrdered const &d, std::optional const &start, std::optional const &end) { +DimOrdered slice(DimOrdered const &d, + std::optional const &start, + std::optional const &end) { return nonoverloaded_slice(d, start, end); } template -DimOrdered slice(DimOrdered const &d, std::nullopt_t const &start, Idx const &end) { - return nonoverloaded_slice(d, std::optional{start}, std::optional{end}); +DimOrdered slice(DimOrdered const &d, + std::nullopt_t const &start, + Idx const &end) { + return nonoverloaded_slice( + d, std::optional{start}, std::optional{end}); } template -DimOrdered slice(DimOrdered const &d, Idx const &start, std::nullopt_t const &end) { - return nonoverloaded_slice(d, std::optional{start}, std::optional{end}); +DimOrdered slice(DimOrdered const &d, + Idx const &start, + std::nullopt_t const &end) { + return nonoverloaded_slice( + d, std::optional{start}, std::optional{end}); } template -DimOrdered slice(DimOrdered const &d, Idx const &start, Idx const &end) { - return nonoverloaded_slice(d, std::optional{start}, std::optional{end}); +DimOrdered + slice(DimOrdered const &d, Idx const &start, Idx const &end) { + return nonoverloaded_slice( + d, std::optional{start}, std::optional{end}); } } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/transform.h b/lib/op-attrs/include/op-attrs/dim_ordered/transform.h index 08ffff43f1..880f13b4d4 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/transform.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/transform.h @@ -8,12 +8,13 @@ namespace FlexFlow { template -DimOrdered> transform(DimOrdered const &d, F f) { +DimOrdered> + transform(DimOrdered const &d, F f) { using Out = std::invoke_result_t; return DimOrdered{vector_transform(as_vector(d), f)}; } - + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/attention.h b/lib/op-attrs/include/op-attrs/ops/attention.h index c552f034d0..8233775e63 100644 --- a/lib/op-attrs/include/op-attrs/ops/attention.h +++ b/lib/op-attrs/include/op-attrs/ops/attention.h @@ -2,10 +2,10 @@ #define _FLEXFLOW_ATTENTION_ATTRS_H #include "core.h" -#include "op-attrs/ops/attention_attrs.dtg.h" -#include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/ops/attention/multihead_attention_inputs.dtg.h" #include "op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.h" +#include "op-attrs/ops/attention_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" #include @@ -37,23 +37,27 @@ int get_kvSeqLength(MultiHeadAttentionInputs const &); int get_num_samples(MultiHeadAttentionParallelInputs const &); int get_num_samples(MultiHeadAttentionInputs const &); -tl::expected get_weights_shape(MultiHeadAttentionAttrs const &, - TensorShape const &input_q, - TensorShape const &input_k, - TensorShape const &input_v); -tl::expected get_weights_shape(MultiHeadAttentionAttrs const &, - ParallelTensorShape const &input_q, - ParallelTensorShape const &input_k, - ParallelTensorShape const &input_v); - -tl::expected get_output_shape(MultiHeadAttentionAttrs const &, - TensorShape const &input_q, - TensorShape const &input_k, - TensorShape const &input_v); -tl::expected get_output_shape(MultiHeadAttentionAttrs const &, - ParallelTensorShape const &input_q, - ParallelTensorShape const &input_k, - ParallelTensorShape const &input_v); +tl::expected + get_weights_shape(MultiHeadAttentionAttrs const &, + TensorShape const &input_q, + TensorShape const &input_k, + TensorShape const &input_v); +tl::expected + get_weights_shape(MultiHeadAttentionAttrs const &, + ParallelTensorShape const &input_q, + ParallelTensorShape const &input_k, + ParallelTensorShape const &input_v); + +tl::expected + get_output_shape(MultiHeadAttentionAttrs const &, + TensorShape const &input_q, + TensorShape const &input_k, + TensorShape const &input_v); +tl::expected + get_output_shape(MultiHeadAttentionAttrs const &, + ParallelTensorShape const &input_q, + ParallelTensorShape const &input_k, + ParallelTensorShape const &input_v); CHECK_VALID_OP_ATTR(MultiHeadAttentionAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.h b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.h index 044d4cfca9..aed9f577ff 100644 --- a/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.h +++ b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.h @@ -7,9 +7,10 @@ namespace FlexFlow { -tl::expected parse_attention_input_shape(TensorShape const &input_q, - TensorShape const &input_k, - TensorShape const &input_v); +tl::expected + parse_attention_input_shape(TensorShape const &input_q, + TensorShape const &input_k, + TensorShape const &input_v); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.h b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.h index a09a66e531..53cc3167f2 100644 --- a/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.h +++ b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.h @@ -7,9 +7,10 @@ namespace FlexFlow { -tl::expected parse_attention_parallel_input_shape(ParallelTensorShape const &input_q, - ParallelTensorShape const &input_k, - ParallelTensorShape const &input_v); +tl::expected + parse_attention_parallel_input_shape(ParallelTensorShape const &input_q, + ParallelTensorShape const &input_k, + ParallelTensorShape const &input_v); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h index e08cdef70b..57760d1110 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h @@ -12,14 +12,15 @@ bool is_valid(BatchMatmulAttrs const &, ParallelTensorShape const &, ParallelTensorShape const &); +tl::expected + get_output_shape(BatchMatmulAttrs const &attrs, + TensorShape const &input_lhs, + TensorShape const &input_rhs); -tl::expected get_output_shape(BatchMatmulAttrs const &attrs, - TensorShape const &input_lhs, - TensorShape const &input_rhs); - -tl::expected get_output_shape(BatchMatmulAttrs const &attrs, - ParallelTensorShape const &input_lhs, - ParallelTensorShape const &input_rhs); +tl::expected + get_output_shape(BatchMatmulAttrs const &attrs, + ParallelTensorShape const &input_lhs, + ParallelTensorShape const &input_rhs); } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/combine.h b/lib/op-attrs/include/op-attrs/ops/combine.h index bbd5c12f33..d9b20fc2c5 100644 --- a/lib/op-attrs/include/op-attrs/ops/combine.h +++ b/lib/op-attrs/include/op-attrs/ops/combine.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_COMBINE_H #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_COMBINE_H -#include "op-attrs/ops/core.h" #include "op-attrs/ops/combine_attrs.dtg.h" +#include "op-attrs/ops/core.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include @@ -10,8 +10,8 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(CombineAttrs); -tl::expected - get_output_shape(CombineAttrs const &, ParallelTensorShape const &); +tl::expected + get_output_shape(CombineAttrs const &, ParallelTensorShape const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/element_binary.h b/lib/op-attrs/include/op-attrs/ops/element_binary.h index a33e48e524..d51c3a3afa 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_binary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_binary.h @@ -4,17 +4,16 @@ #include "op-attrs/ops/core.h" #include "op-attrs/ops/element_binary_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.h" -#include +#include namespace FlexFlow { -tl::expected - get_output_shape(ElementBinaryAttrs const &, - TensorShape const &, - TensorShape const &); -tl::expected get_output_shape(ElementBinaryAttrs const &, - ParallelTensorShape const &, - ParallelTensorShape const &); +tl::expected get_output_shape( + ElementBinaryAttrs const &, TensorShape const &, TensorShape const &); +tl::expected + get_output_shape(ElementBinaryAttrs const &, + ParallelTensorShape const &, + ParallelTensorShape const &); CHECK_VALID_OP_ATTR(ElementBinaryAttrs); diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary.h b/lib/op-attrs/include/op-attrs/ops/element_unary.h index 6fd37fc80c..471a2a30f5 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_unary.h @@ -11,18 +11,15 @@ namespace FlexFlow { tl::expected - get_output_shape(ElementUnaryAttrs const &, - TensorShape const &); + get_output_shape(ElementUnaryAttrs const &, TensorShape const &); tl::expected - get_output_shape(ElementUnaryAttrs const &, - ParallelTensorShape const &); - -tl::expected - get_output_shape(ElementScalarUnaryAttrs const &, - TensorShape const &); -tl::expected - get_output_shape(ElementScalarUnaryAttrs const &, - ParallelTensorShape const &); + get_output_shape(ElementUnaryAttrs const &, ParallelTensorShape const &); + +tl::expected + get_output_shape(ElementScalarUnaryAttrs const &, TensorShape const &); +tl::expected + get_output_shape(ElementScalarUnaryAttrs const &, + ParallelTensorShape const &); CHECK_VALID_OP_ATTR(ElementUnaryAttrs); CHECK_VALID_OP_ATTR(ElementScalarUnaryAttrs); diff --git a/lib/op-attrs/include/op-attrs/ops/embedding.h b/lib/op-attrs/include/op-attrs/ops/embedding.h index 52c868edee..aa67c6cb04 100644 --- a/lib/op-attrs/include/op-attrs/ops/embedding.h +++ b/lib/op-attrs/include/op-attrs/ops/embedding.h @@ -11,13 +11,15 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(EmbeddingAttrs); -tl::expected get_output_shape(EmbeddingAttrs const &, TensorShape const &); -tl::expected get_weights_shape(EmbeddingAttrs const &, TensorShape const &); +tl::expected get_output_shape(EmbeddingAttrs const &, + TensorShape const &); +tl::expected get_weights_shape(EmbeddingAttrs const &, + TensorShape const &); -tl::expected get_output_shape(EmbeddingAttrs const &, - ParallelTensorShape const &); -tl::expected get_weights_shape(EmbeddingAttrs const &, - ParallelTensorShape const &); +tl::expected + get_output_shape(EmbeddingAttrs const &, ParallelTensorShape const &); +tl::expected + get_weights_shape(EmbeddingAttrs const &, ParallelTensorShape const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/linear.h b/lib/op-attrs/include/op-attrs/ops/linear.h index 6da4f0c9f3..dd6948165e 100644 --- a/lib/op-attrs/include/op-attrs/ops/linear.h +++ b/lib/op-attrs/include/op-attrs/ops/linear.h @@ -12,23 +12,20 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(LinearAttrs); tl::expected - get_kernel_shape(LinearAttrs const &attrs, - TensorShape const &input); + get_kernel_shape(LinearAttrs const &attrs, TensorShape const &input); +tl::expected get_bias_shape(LinearAttrs const &attrs, + TensorShape const &input); tl::expected - get_bias_shape(LinearAttrs const &attrs, TensorShape const &input); -tl::expected - get_output_shape(LinearAttrs const &attrs, - TensorShape const &input); + get_output_shape(LinearAttrs const &attrs, TensorShape const &input); tl::expected - get_kernel_shape(LinearAttrs const &attrs, - ParallelTensorShape const &input); + get_kernel_shape(LinearAttrs const &attrs, + ParallelTensorShape const &input); tl::expected - get_bias_shape(LinearAttrs const &attrs, - ParallelTensorShape const &input); + get_bias_shape(LinearAttrs const &attrs, ParallelTensorShape const &input); tl::expected - get_output_shape(LinearAttrs const &attrs, - ParallelTensorShape const &input); + get_output_shape(LinearAttrs const &attrs, + ParallelTensorShape const &input); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/reduction.h b/lib/op-attrs/include/op-attrs/ops/reduction.h index 49f99d81fd..a6047b38f9 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduction.h +++ b/lib/op-attrs/include/op-attrs/ops/reduction.h @@ -10,8 +10,9 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(ReductionAttrs); -tl::expected get_output_shape(ReductionAttrs const &attrs, - ParallelTensorShape const &input_shape); +tl::expected + get_output_shape(ReductionAttrs const &attrs, + ParallelTensorShape const &input_shape); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/repartition.h b/lib/op-attrs/include/op-attrs/ops/repartition.h index 15b7ed71de..559e7278f5 100644 --- a/lib/op-attrs/include/op-attrs/ops/repartition.h +++ b/lib/op-attrs/include/op-attrs/ops/repartition.h @@ -11,8 +11,8 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(RepartitionAttrs); tl::expected - get_output_shape(RepartitionAttrs const &, - ParallelTensorShape const &input_shape); + get_output_shape(RepartitionAttrs const &, + ParallelTensorShape const &input_shape); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/replicate.h b/lib/op-attrs/include/op-attrs/ops/replicate.h index f10938cb68..4c46bf88a9 100644 --- a/lib/op-attrs/include/op-attrs/ops/replicate.h +++ b/lib/op-attrs/include/op-attrs/ops/replicate.h @@ -10,7 +10,7 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(ReplicateAttrs); ParallelTensorShape get_output_shape(ReplicateAttrs const &attrs, - ParallelTensorShape const &input_shape); + ParallelTensorShape const &input_shape); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h index c8b0ad236c..99be635ffc 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h @@ -13,10 +13,15 @@ ShardParallelDim &shard_dim_at_idx(ParallelTensorShape &, ff_dim_t); FFOrdered ff_ordered_shard_degrees(ParallelTensorShape const &); -std::optional try_get_shard_dim_at_idx(ParallelTensorShape const &, ff_dim_t); +std::optional + try_get_shard_dim_at_idx(ParallelTensorShape const &, ff_dim_t); ParallelTensorShape lift_to_parallel(TensorShape const &); -ParallelTensorShape lift_to_parallel_with_degrees(TensorShape const &, SumDegree sum_degree, DiscardCopyDegree discard_copy_degree, FFOrdered const &shard_degrees); +ParallelTensorShape + lift_to_parallel_with_degrees(TensorShape const &, + SumDegree sum_degree, + DiscardCopyDegree discard_copy_degree, + FFOrdered const &shard_degrees); std::unordered_set replica_dims(ParallelTensorShape const &); diff --git a/lib/op-attrs/include/op-attrs/tensor_dims.h b/lib/op-attrs/include/op-attrs/tensor_dims.h index 0f4a793430..2391197471 100644 --- a/lib/op-attrs/include/op-attrs/tensor_dims.h +++ b/lib/op-attrs/include/op-attrs/tensor_dims.h @@ -13,7 +13,11 @@ size_t dim_at_idx(TensorDims const &, ff_dim_t); size_t &dim_at_idx(TensorDims &, ff_dim_t); ParallelTensorDims lift_to_parallel(TensorDims const &); -ParallelTensorDims lift_to_parallel_with_degrees(TensorDims const &, SumDegree sum_degree, DiscardCopyDegree discard_copy_degree, FFOrdered const &shard_degrees); +ParallelTensorDims + lift_to_parallel_with_degrees(TensorDims const &, + SumDegree sum_degree, + DiscardCopyDegree discard_copy_degree, + FFOrdered const &shard_degrees); } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/attention.cc b/lib/op-attrs/src/op-attrs/ops/attention.cc index 1cbebcbdc7..14ab2b9b00 100644 --- a/lib/op-attrs/src/op-attrs/ops/attention.cc +++ b/lib/op-attrs/src/op-attrs/ops/attention.cc @@ -67,11 +67,13 @@ int get_vSize(MultiHeadAttentionInputs const &) { NOT_IMPLEMENTED(); } -tl::expected get_output_shape(MultiHeadAttentionAttrs const &attrs, - TensorShape const &input_q, - TensorShape const &input_k, - TensorShape const &input_v) { - tl::expected parse_result = parse_attention_input_shape(input_q, input_k, input_v); +tl::expected + get_output_shape(MultiHeadAttentionAttrs const &attrs, + TensorShape const &input_q, + TensorShape const &input_k, + TensorShape const &input_v) { + tl::expected parse_result = + parse_attention_input_shape(input_q, input_k, input_v); if (!parse_result.has_value()) { return tl::unexpected(parse_result.error()); } @@ -79,30 +81,29 @@ tl::expected get_output_shape(MultiHeadAttentionAttrs MultiHeadAttentionInputs parsed = parse_result.value(); return TensorShape{ - TensorDims{ - FFOrdered{ - parsed.batch_size, - parsed.sequence_length, - size_t_from_int(attrs.embed_dim), - } - }, - parsed.datatype, + TensorDims{FFOrdered{ + parsed.batch_size, + parsed.sequence_length, + size_t_from_int(attrs.embed_dim), + }}, + parsed.datatype, }; } tl::expected -get_weights_shape(MultiHeadAttentionAttrs const &attrs, - TensorShape const &input_q, - TensorShape const &input_k, - TensorShape const &input_v) { - tl::expected parse_result = parse_attention_input_shape(input_q, input_k, input_v); + get_weights_shape(MultiHeadAttentionAttrs const &attrs, + TensorShape const &input_q, + TensorShape const &input_k, + TensorShape const &input_v) { + tl::expected parse_result = + parse_attention_input_shape(input_q, input_k, input_v); if (!parse_result.has_value()) { return tl::unexpected(parse_result.error()); } MultiHeadAttentionInputs parsed = parse_result.value(); - // W^Q_i in "Attention Is All You Need" top of page 5 + // W^Q_i in "Attention Is All You Need" top of page 5 size_t qProjectWeightSize = parsed.query_size * attrs.kdim; // W^K_i in "Attention Is All You Need" top of page 5 (all i's put together) @@ -111,32 +112,37 @@ get_weights_shape(MultiHeadAttentionAttrs const &attrs, // W^V_i in "Attention Is All You Need" top of page 5 (all i's put together) size_t vProjectWeightSize = parsed.value_size * attrs.vdim; - // W^O in "Attention Is All You Need" top of page 5, with num_heads factored out + // W^O in "Attention Is All You Need" top of page 5, with num_heads factored + // out size_t outWeightSize = parsed.value_size * attrs.embed_dim; return TensorShape{ - TensorDims{ - FFOrdered{ - (qProjectWeightSize + kProjectWeightSize + vProjectWeightSize + outWeightSize), - size_t_from_int(attrs.num_heads), - } - }, - parsed.datatype, + TensorDims{FFOrdered{ + (qProjectWeightSize + kProjectWeightSize + vProjectWeightSize + + outWeightSize), + size_t_from_int(attrs.num_heads), + }}, + parsed.datatype, }; } tl::expected -get_weights_shape(MultiHeadAttentionAttrs const &attrs, - ParallelTensorShape const &input_q, - ParallelTensorShape const &input_k, - ParallelTensorShape const &input_v) { - tl::expected parse_result = parse_attention_parallel_input_shape(input_q, input_k, input_v); + get_weights_shape(MultiHeadAttentionAttrs const &attrs, + ParallelTensorShape const &input_q, + ParallelTensorShape const &input_k, + ParallelTensorShape const &input_v) { + tl::expected parse_result = + parse_attention_parallel_input_shape(input_q, input_k, input_v); if (!parse_result.has_value()) { return tl::unexpected(parse_result.error()); } MultiHeadAttentionParallelInputs parsed = parse_result.value(); - tl::expected result_unpar_get_shape = get_weights_shape(attrs, get_reduced_shape(input_q), get_reduced_shape(input_k), get_reduced_shape(input_v)); + tl::expected result_unpar_get_shape = + get_weights_shape(attrs, + get_reduced_shape(input_q), + get_reduced_shape(input_k), + get_reduced_shape(input_v)); if (!result_unpar_get_shape.has_value()) { return tl::unexpected(result_unpar_get_shape.error()); } @@ -145,21 +151,30 @@ get_weights_shape(MultiHeadAttentionAttrs const &attrs, int joined_dim_degree = 1; int head_dim_degree = parsed.discard_copy_degree.value; - return lift_to_parallel_with_degrees(unpar_shape, SumDegree{1}, DiscardCopyDegree{parsed.batch_dim.degree}, FFOrdered{joined_dim_degree, head_dim_degree}); + return lift_to_parallel_with_degrees( + unpar_shape, + SumDegree{1}, + DiscardCopyDegree{parsed.batch_dim.degree}, + FFOrdered{joined_dim_degree, head_dim_degree}); } tl::expected -get_output_shape(MultiHeadAttentionAttrs const &attrs, - ParallelTensorShape const &input_q, - ParallelTensorShape const &input_k, - ParallelTensorShape const &input_v) { - tl::expected parse_result = parse_attention_parallel_input_shape(input_q, input_k, input_v); + get_output_shape(MultiHeadAttentionAttrs const &attrs, + ParallelTensorShape const &input_q, + ParallelTensorShape const &input_k, + ParallelTensorShape const &input_v) { + tl::expected parse_result = + parse_attention_parallel_input_shape(input_q, input_k, input_v); if (!parse_result.has_value()) { return tl::unexpected(parse_result.error()); } MultiHeadAttentionParallelInputs parsed = parse_result.value(); - tl::expected result_unpar_get_shape = get_output_shape(attrs, get_reduced_shape(input_q), get_reduced_shape(input_k), get_reduced_shape(input_v)); + tl::expected result_unpar_get_shape = + get_output_shape(attrs, + get_reduced_shape(input_q), + get_reduced_shape(input_k), + get_reduced_shape(input_v)); if (!result_unpar_get_shape.has_value()) { return tl::unexpected(result_unpar_get_shape.error()); } @@ -171,10 +186,13 @@ get_output_shape(MultiHeadAttentionAttrs const &attrs, int seq_len_degree = 1; int out_dim_degree = 1; - return lift_to_parallel_with_degrees(unpar_shape, SumDegree{sum_degree}, DiscardCopyDegree{discard_copy_degree}, FFOrdered{batch_degree, seq_len_degree, out_dim_degree}); + return lift_to_parallel_with_degrees( + unpar_shape, + SumDegree{sum_degree}, + DiscardCopyDegree{discard_copy_degree}, + FFOrdered{batch_degree, seq_len_degree, out_dim_degree}); } - int get_oSize(ParallelTensorShape const &) { NOT_IMPLEMENTED(); } diff --git a/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.cc b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.cc index 09d48df497..65feb642e1 100644 --- a/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.cc +++ b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.cc @@ -8,17 +8,27 @@ static bool all_same(T const &x, T const &y, T const &z) { return x == y && y == z; } -tl::expected parse_attention_input_shape(TensorShape const &input_q, - TensorShape const &input_k, - TensorShape const &input_v) { +tl::expected + parse_attention_input_shape(TensorShape const &input_q, + TensorShape const &input_k, + TensorShape const &input_v) { if (num_dims(input_q) != 3) { - return tl::unexpected(fmt::format("Query input has incorrect number of dims: {} != {}", num_dims(input_q), 3)); + return tl::unexpected( + fmt::format("Query input has incorrect number of dims: {} != {}", + num_dims(input_q), + 3)); } if (num_dims(input_k) != 3) { - return tl::unexpected(fmt::format("Key input has incorrect number of dims: {} != {}", num_dims(input_k), 3)); + return tl::unexpected( + fmt::format("Key input has incorrect number of dims: {} != {}", + num_dims(input_k), + 3)); } if (num_dims(input_v) != 3) { - return tl::unexpected(fmt::format("Value input has incorrect number of dims: {} != {}", num_dims(input_v), 3)); + return tl::unexpected( + fmt::format("Value input has incorrect number of dims: {} != {}", + num_dims(input_v), + 3)); } size_t seq_len_q = dim_at_idx(input_q, ff_dim_t{-2}); @@ -26,7 +36,11 @@ tl::expected parse_attention_input_shape( size_t seq_len_v = dim_at_idx(input_v, ff_dim_t{-2}); if (!all_same(seq_len_q, seq_len_k, seq_len_v)) { - return tl::unexpected(fmt::format("Q, K, V disagree on the sequence length: {} (Q) vs {} (K) vs {} (V)", seq_len_q, seq_len_k, seq_len_v)); + return tl::unexpected(fmt::format( + "Q, K, V disagree on the sequence length: {} (Q) vs {} (K) vs {} (V)", + seq_len_q, + seq_len_k, + seq_len_v)); } size_t batch_size_q = dim_at_idx(input_q, ff_dim_t{-3}); @@ -34,11 +48,19 @@ tl::expected parse_attention_input_shape( size_t batch_size_v = dim_at_idx(input_v, ff_dim_t{-3}); if (!all_same(batch_size_q, batch_size_k, batch_size_v)) { - return tl::unexpected(fmt::format("Q, K, V disagree on the batch size: {} (Q) vs {} (K) vs {} (V)", batch_size_q, batch_size_k, batch_size_v)); + return tl::unexpected(fmt::format( + "Q, K, V disagree on the batch size: {} (Q) vs {} (K) vs {} (V)", + batch_size_q, + batch_size_k, + batch_size_v)); } if (!all_same(input_q.data_type, input_k.data_type, input_v.data_type)) { - return tl::unexpected(fmt::format("Q, K, V disagree on the datatype: {} (Q) vs {} (K) vs {} (V)", input_q.data_type, input_k.data_type, input_v.data_type)); + return tl::unexpected(fmt::format( + "Q, K, V disagree on the datatype: {} (Q) vs {} (K) vs {} (V)", + input_q.data_type, + input_k.data_type, + input_v.data_type)); } size_t q_size = dim_at_idx(input_q, ff_dim_t{-1}); @@ -46,12 +68,12 @@ tl::expected parse_attention_input_shape( size_t v_size = dim_at_idx(input_v, ff_dim_t{-1}); return MultiHeadAttentionInputs{ - batch_size_q, - seq_len_q, - q_size, - k_size, - v_size, - input_q.data_type, + batch_size_q, + seq_len_q, + q_size, + k_size, + v_size, + input_q.data_type, }; } diff --git a/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.cc b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.cc index 4038d0c7a3..2cd5b7ec00 100644 --- a/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.cc +++ b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.cc @@ -1,6 +1,6 @@ #include "op-attrs/ops/attention/multihead_attention_parallel_inputs.h" -#include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/ops/attention/multihead_attention_inputs.h" +#include "op-attrs/parallel_tensor_shape.h" namespace FlexFlow { @@ -9,61 +9,99 @@ static bool all_same(T const &x, T const &y, T const &z) { return x == y && y == z; } -tl::expected parse_attention_parallel_input_shape(ParallelTensorShape const &input_q, - ParallelTensorShape const &input_k, - ParallelTensorShape const &input_v) { - tl::expected unpar_parse_result = parse_attention_input_shape( - get_reduced_shape(input_q), get_reduced_shape(input_k), get_reduced_shape(input_v)); +tl::expected + parse_attention_parallel_input_shape(ParallelTensorShape const &input_q, + ParallelTensorShape const &input_k, + ParallelTensorShape const &input_v) { + tl::expected unpar_parse_result = + parse_attention_input_shape(get_reduced_shape(input_q), + get_reduced_shape(input_k), + get_reduced_shape(input_v)); if (!unpar_parse_result.has_value()) { - return tl::unexpected(fmt::format("MHA unparallel input parsing failed with message: \"{}\"", unpar_parse_result.error())); + return tl::unexpected( + fmt::format("MHA unparallel input parsing failed with message: \"{}\"", + unpar_parse_result.error())); } if (num_shard_dims(input_q) != 3) { - return tl::unexpected(fmt::format("Query input has incorrect number of dims: {} != {}", num_shard_dims(input_q), 3)); + return tl::unexpected( + fmt::format("Query input has incorrect number of dims: {} != {}", + num_shard_dims(input_q), + 3)); } if (num_shard_dims(input_k) != 3) { - return tl::unexpected(fmt::format("Key input has incorrect number of dims: {} != {}", num_shard_dims(input_k), 3)); + return tl::unexpected( + fmt::format("Key input has incorrect number of dims: {} != {}", + num_shard_dims(input_k), + 3)); } if (num_shard_dims(input_v) != 3) { - return tl::unexpected(fmt::format("Value input has incorrect number of dims: {} != {}", num_shard_dims(input_v), 3)); + return tl::unexpected( + fmt::format("Value input has incorrect number of dims: {} != {}", + num_shard_dims(input_v), + 3)); } ShardParallelDim seq_len_q = shard_dim_at_idx(input_q, ff_dim_t{-2}); if (seq_len_q.degree != 1) { - return tl::unexpected(fmt::format("Query sequence length parallel degree expected to be 1, but received degree {}", seq_len_q.degree)); + return tl::unexpected( + fmt::format("Query sequence length parallel degree expected to be 1, " + "but received degree {}", + seq_len_q.degree)); } ShardParallelDim seq_len_k = shard_dim_at_idx(input_k, ff_dim_t{-2}); if (seq_len_k.degree != 1) { - return tl::unexpected(fmt::format("Key sequence length parallel degree expected to be 1, but received degree {}", seq_len_k.degree)); + return tl::unexpected( + fmt::format("Key sequence length parallel degree expected to be 1, but " + "received degree {}", + seq_len_k.degree)); } ShardParallelDim seq_len_v = shard_dim_at_idx(input_v, ff_dim_t{-2}); if (seq_len_v.degree != 1) { - return tl::unexpected(fmt::format("Value sequence length parallel degree expected to be 1, but received degree {}", seq_len_v.degree)); + return tl::unexpected( + fmt::format("Value sequence length parallel degree expected to be 1, " + "but received degree {}", + seq_len_v.degree)); } ShardParallelDim batch_size_q = shard_dim_at_idx(input_q, ff_dim_t{-3}); ShardParallelDim batch_size_k = shard_dim_at_idx(input_k, ff_dim_t{-3}); ShardParallelDim batch_size_v = shard_dim_at_idx(input_v, ff_dim_t{-3}); - if (!all_same(batch_size_q.degree, batch_size_k.degree, batch_size_v.degree)) { - return tl::unexpected(fmt::format("Q, K, V disagree on the parallel degree of the batch dimension: {} (Q) vs {} (K) vs {} (V)", batch_size_q.degree, batch_size_k.degree, batch_size_v.degree)); + if (!all_same( + batch_size_q.degree, batch_size_k.degree, batch_size_v.degree)) { + return tl::unexpected( + fmt::format("Q, K, V disagree on the parallel degree of the batch " + "dimension: {} (Q) vs {} (K) vs {} (V)", + batch_size_q.degree, + batch_size_k.degree, + batch_size_v.degree)); } ShardParallelDim query_dim = shard_dim_at_idx(input_q, ff_dim_t{-1}); if (query_dim.degree > 1) { - return tl::unexpected(fmt::format("Expected query tensor to have query dim parallel degree 1, but received degree {}", query_dim.degree)); + return tl::unexpected( + fmt::format("Expected query tensor to have query dim parallel degree " + "1, but received degree {}", + query_dim.degree)); } ShardParallelDim key_dim = shard_dim_at_idx(input_k, ff_dim_t{-1}); if (key_dim.degree > 1) { - return tl::unexpected(fmt::format("Expected key tensor to have key dim parallel degree 1, but received degree {}", key_dim.degree)); + return tl::unexpected( + fmt::format("Expected key tensor to have key dim parallel degree 1, " + "but received degree {}", + key_dim.degree)); } ShardParallelDim value_dim = shard_dim_at_idx(input_v, ff_dim_t{-1}); if (value_dim.degree > 1) { - return tl::unexpected(fmt::format("Expected value tensor to have value dim parallel degree 1, but received degree {}", value_dim.degree)); + return tl::unexpected( + fmt::format("Expected value tensor to have value dim parallel degree " + "1, but received degree {}", + value_dim.degree)); } int discard_copy_q = get_discard_copy_degree(input_q); @@ -71,19 +109,23 @@ tl::expected parse_attention_para int discard_copy_v = get_discard_copy_degree(input_v); if (!all_same(discard_copy_q, discard_copy_k, discard_copy_v)) { - return tl::unexpected(fmt::format("Q, K, V disagree on the discard-copy degree: {} (Q) vs {} (K) vs {} (V)", discard_copy_q, discard_copy_k, discard_copy_v)); + return tl::unexpected(fmt::format("Q, K, V disagree on the discard-copy " + "degree: {} (Q) vs {} (K) vs {} (V)", + discard_copy_q, + discard_copy_k, + discard_copy_v)); } return MultiHeadAttentionParallelInputs{ - batch_size_q, - seq_len_q, - query_dim, - key_dim, - value_dim, - discard_copy_q, - input_q.data_type, + batch_size_q, + seq_len_q, + query_dim, + key_dim, + value_dim, + discard_copy_q, + input_q.data_type, }; - + // return; } diff --git a/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc b/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc index 69e644441c..cbda4ea533 100644 --- a/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc +++ b/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc @@ -3,25 +3,26 @@ namespace FlexFlow { -// bool BatchMatmulAttrs::is_valid( -// ParallelTensorShape const &lhs, ParallelTensorShape const &rhs) const { -// if (!lhs.is_valid() || !rhs.is_valid()) { -// return false; -// } -// if (lhs.num_dims() != rhs.num_dims()) { -// return false; -// } -// for (int i = lhs.num_dims() - 1; i >= 2; i--) { -// if (lhs.at(i) != rhs.at(i)) { -// return false; -// } -// } -// if (lhs.at(0) != rhs.at(1)) { -// return false; -// } -// -// return true; -// } +// bool BatchMatmulAttrs::is_valid( +// ParallelTensorShape const &lhs, +// ParallelTensorShape const &rhs) const { +// if (!lhs.is_valid() || !rhs.is_valid()) { +// return false; +// } +// if (lhs.num_dims() != rhs.num_dims()) { +// return false; +// } +// for (int i = lhs.num_dims() - 1; i >= 2; i--) { +// if (lhs.at(i) != rhs.at(i)) { +// return false; +// } +// } +// if (lhs.at(0) != rhs.at(1)) { +// return false; +// } +// +// return true; +// } bool is_valid(BatchMatmulAttrs const &, ParallelTensorShape const &, @@ -29,24 +30,33 @@ bool is_valid(BatchMatmulAttrs const &, NOT_IMPLEMENTED(); } -tl::expected get_output_shape(BatchMatmulAttrs const &attrs, - TensorShape const &input_lhs, - TensorShape const &input_rhs) { - // If input_lhs is a (b×n×m) tensor, - // input_rhs is a (b×m×p) tensor, +tl::expected + get_output_shape(BatchMatmulAttrs const &attrs, + TensorShape const &input_lhs, + TensorShape const &input_rhs) { + // If input_lhs is a (b×n×m) tensor, + // input_rhs is a (b×m×p) tensor, // out will be a (b×n×p) tensor. // https://pytorch.org/docs/stable/generated/torch.bmm.html if (num_dims(input_lhs) != 3) { - return tl::unexpected(fmt::format("LHS input has incorrect number of shard dims: {} != {}", num_dims(input_lhs), 3)); + return tl::unexpected( + fmt::format("LHS input has incorrect number of shard dims: {} != {}", + num_dims(input_lhs), + 3)); } if (num_dims(input_rhs) != 3) { - return tl::unexpected(fmt::format("RHS input has incorrect number of shard dims: {} != {}", num_dims(input_rhs), 3)); + return tl::unexpected( + fmt::format("RHS input has incorrect number of shard dims: {} != {}", + num_dims(input_rhs), + 3)); } if (input_lhs.data_type != input_rhs.data_type) { - return tl::unexpected(fmt::format("Input datatypes do not match: {} != {}", input_lhs.data_type, input_rhs.data_type)); + return tl::unexpected(fmt::format("Input datatypes do not match: {} != {}", + input_lhs.data_type, + input_rhs.data_type)); } - + size_t lhs_b = dim_at_idx(input_lhs, ff_dim_t{0}); size_t n = dim_at_idx(input_lhs, ff_dim_t{1}); size_t lhs_m = dim_at_idx(input_lhs, ff_dim_t{2}); @@ -56,38 +66,50 @@ tl::expected get_output_shape(BatchMatmulAttrs const & size_t p = dim_at_idx(input_rhs, ff_dim_t{2}); if (lhs_b != rhs_b) { - return tl::unexpected(fmt::format("LHS b dim ({}) != RHS b dim ({})", lhs_b, rhs_b)); + return tl::unexpected( + fmt::format("LHS b dim ({}) != RHS b dim ({})", lhs_b, rhs_b)); } if (lhs_m != rhs_m) { - return tl::unexpected(fmt::format("RHS m dim ({}) != RHS m dim ({})", lhs_m, rhs_m)); + return tl::unexpected( + fmt::format("RHS m dim ({}) != RHS m dim ({})", lhs_m, rhs_m)); } return TensorShape{ - TensorDims{ - FFOrdered{ - lhs_b, - n, - p, + TensorDims{ + FFOrdered{ + lhs_b, + n, + p, + }, }, - }, - input_lhs.data_type, + input_lhs.data_type, }; } -tl::expected get_output_shape(BatchMatmulAttrs const &attrs, - ParallelTensorShape const &input_lhs, - ParallelTensorShape const &input_rhs) { +tl::expected + get_output_shape(BatchMatmulAttrs const &attrs, + ParallelTensorShape const &input_lhs, + ParallelTensorShape const &input_rhs) { if (num_shard_dims(input_lhs) != 3) { - return tl::unexpected(fmt::format("LHS input has incorrect number of shard dims: {} != {}", num_shard_dims(input_lhs), 3)); + return tl::unexpected( + fmt::format("LHS input has incorrect number of shard dims: {} != {}", + num_shard_dims(input_lhs), + 3)); } if (num_shard_dims(input_rhs) != 3) { - return tl::unexpected(fmt::format("RHS input has incorrect number of shard dims: {} != {}", num_shard_dims(input_rhs), 3)); + return tl::unexpected( + fmt::format("RHS input has incorrect number of shard dims: {} != {}", + num_shard_dims(input_rhs), + 3)); } if (input_lhs.data_type != input_rhs.data_type) { - return tl::unexpected(fmt::format("Input datatypes do not match: {} != {}", input_lhs.data_type, input_rhs.data_type)); + return tl::unexpected(fmt::format("Input datatypes do not match: {} != {}", + input_lhs.data_type, + input_rhs.data_type)); } - - assert (get_total_parallel_degree(input_lhs) == get_total_parallel_degree(input_rhs)); + + assert(get_total_parallel_degree(input_lhs) == + get_total_parallel_degree(input_rhs)); ShardParallelDim lhs_b = shard_dim_at_idx(input_lhs, ff_dim_t{0}); ShardParallelDim n = shard_dim_at_idx(input_lhs, ff_dim_t{1}); @@ -98,19 +120,31 @@ tl::expected get_output_shape(BatchMatmulAttrs ShardParallelDim p = shard_dim_at_idx(input_rhs, ff_dim_t{2}); if (lhs_b != rhs_b) { - return tl::unexpected(fmt::format("LHS b dim ({}) != RHS b dim ({})", lhs_b, rhs_b)); + return tl::unexpected( + fmt::format("LHS b dim ({}) != RHS b dim ({})", lhs_b, rhs_b)); } if (lhs_m != rhs_m) { - return tl::unexpected(fmt::format("LHS m dim ({}) != RHS m dim ({})", lhs_m, rhs_m)); + return tl::unexpected( + fmt::format("LHS m dim ({}) != RHS m dim ({})", lhs_m, rhs_m)); } - if (get_discard_copy_degree(input_lhs) != get_sum_degree(input_rhs) * p.degree) { - return tl::unexpected(fmt::format("Unexpected number of replicas in LHS: lhs.= ({}) != rhs.+ ({}) * rhs.p ({})", get_discard_copy_degree(input_lhs), get_sum_degree(input_rhs), p.degree)); + if (get_discard_copy_degree(input_lhs) != + get_sum_degree(input_rhs) * p.degree) { + return tl::unexpected(fmt::format("Unexpected number of replicas in LHS: " + "lhs.= ({}) != rhs.+ ({}) * rhs.p ({})", + get_discard_copy_degree(input_lhs), + get_sum_degree(input_rhs), + p.degree)); } - if (get_discard_copy_degree(input_rhs) != get_sum_degree(input_lhs) * n.degree) { - return tl::unexpected(fmt::format("Unexpected number of replicas in RHS: rhs.= ({}) != lhs.+ ({}) * lhs.n ({})", get_discard_copy_degree(input_rhs), get_sum_degree(input_lhs), n.degree)); + if (get_discard_copy_degree(input_rhs) != + get_sum_degree(input_lhs) * n.degree) { + return tl::unexpected(fmt::format("Unexpected number of replicas in RHS: " + "rhs.= ({}) != lhs.+ ({}) * lhs.n ({})", + get_discard_copy_degree(input_rhs), + get_sum_degree(input_lhs), + n.degree)); } ShardParallelDim output_b = lhs_b; @@ -118,21 +152,22 @@ tl::expected get_output_shape(BatchMatmulAttrs ShardParallelDim output_p = p; int output_discard_copy_degree = 1; - int output_sum_degree = get_total_parallel_degree(input_lhs) / (output_b.degree * output_n.degree * output_p.degree); + int output_sum_degree = get_total_parallel_degree(input_lhs) / + (output_b.degree * output_n.degree * output_p.degree); ParallelTensorShape result = ParallelTensorShape{ - ParallelTensorDims{ - FFOrdered{ - output_b, - output_n, - output_p, - }, - ReplicaParallelDimSet{ - output_sum_degree, - output_discard_copy_degree, + ParallelTensorDims{ + FFOrdered{ + output_b, + output_n, + output_p, + }, + ReplicaParallelDimSet{ + output_sum_degree, + output_discard_copy_degree, + }, }, - }, - input_lhs.data_type, + input_lhs.data_type, }; return result; diff --git a/lib/op-attrs/src/op-attrs/ops/combine.cc b/lib/op-attrs/src/op-attrs/ops/combine.cc index a91fe43452..e41b78c5af 100644 --- a/lib/op-attrs/src/op-attrs/ops/combine.cc +++ b/lib/op-attrs/src/op-attrs/ops/combine.cc @@ -3,18 +3,29 @@ namespace FlexFlow { -tl::expected get_output_shape(CombineAttrs const &attrs, ParallelTensorShape const &input) { +tl::expected + get_output_shape(CombineAttrs const &attrs, + ParallelTensorShape const &input) { ShardParallelDim input_dim = ({ - std::optional result = try_get_shard_dim_at_idx(input, attrs.combine_dim); + std::optional result = + try_get_shard_dim_at_idx(input, attrs.combine_dim); if (!result.has_value()) { - return tl::unexpected(fmt::format("Failed to get shard dim at index {} in parallel tensor shape {}", attrs.combine_dim, input)); + return tl::unexpected(fmt::format( + "Failed to get shard dim at index {} in parallel tensor shape {}", + attrs.combine_dim, + input)); } result.value(); }); if (input_dim.degree % attrs.combine_degree != 0) { - return tl::unexpected(fmt::format("Combine received tensor containing parallel dim {} with degree {}, which is not divisible by combine degree {}", attrs.combine_dim, input_dim.degree, attrs.combine_degree)); + return tl::unexpected( + fmt::format("Combine received tensor containing parallel dim {} with " + "degree {}, which is not divisible by combine degree {}", + attrs.combine_dim, + input_dim.degree, + attrs.combine_degree)); } ParallelTensorShape output = input; diff --git a/lib/op-attrs/src/op-attrs/ops/element_binary.cc b/lib/op-attrs/src/op-attrs/ops/element_binary.cc index d998834f0a..16957a036c 100644 --- a/lib/op-attrs/src/op-attrs/ops/element_binary.cc +++ b/lib/op-attrs/src/op-attrs/ops/element_binary.cc @@ -3,18 +3,21 @@ namespace FlexFlow { tl::expected -get_output_shape(ElementBinaryAttrs const &attrs, - TensorShape const &input_lhs, - TensorShape const &input_rhs) { - assert (!(attrs.should_broadcast_lhs && attrs.should_broadcast_rhs)); + get_output_shape(ElementBinaryAttrs const &attrs, + TensorShape const &input_lhs, + TensorShape const &input_rhs) { + assert(!(attrs.should_broadcast_lhs && attrs.should_broadcast_rhs)); if (attrs.should_broadcast_lhs) { - NOT_IMPLEMENTED(); + NOT_IMPLEMENTED(); } else if (attrs.should_broadcast_rhs) { NOT_IMPLEMENTED(); } else { if (input_lhs != input_rhs) { - return tl::unexpected(fmt::format("Expected input shapes to match, but receieved LHS ({}) != RHS ({})", input_lhs, input_rhs)); + return tl::unexpected(fmt::format( + "Expected input shapes to match, but receieved LHS ({}) != RHS ({})", + input_lhs, + input_rhs)); } return input_lhs; @@ -22,25 +25,30 @@ get_output_shape(ElementBinaryAttrs const &attrs, } tl::expected - get_output_shape(ElementBinaryAttrs const &attrs, - ParallelTensorShape const &input_lhs, - ParallelTensorShape const &input_rhs) { - assert (!(attrs.should_broadcast_lhs && attrs.should_broadcast_rhs)); + get_output_shape(ElementBinaryAttrs const &attrs, + ParallelTensorShape const &input_lhs, + ParallelTensorShape const &input_rhs) { + assert(!(attrs.should_broadcast_lhs && attrs.should_broadcast_rhs)); if (attrs.should_broadcast_lhs) { - NOT_IMPLEMENTED(); + NOT_IMPLEMENTED(); } else if (attrs.should_broadcast_rhs) { NOT_IMPLEMENTED(); } else { if (input_lhs != input_rhs) { - return tl::unexpected(fmt::format("Expected input shapes to match, but receieved LHS ({}) != RHS ({})", input_lhs, input_rhs)); + return tl::unexpected(fmt::format( + "Expected input shapes to match, but receieved LHS ({}) != RHS ({})", + input_lhs, + input_rhs)); } switch (attrs.type) { - case OperatorType::EW_ADD: - { + case OperatorType::EW_ADD: { if (get_discard_copy_degree(input_lhs) != 1) { - return tl::unexpected(fmt::format("Elementwise Add expected discard copy degree of inputs to be 1, but receieved {}", get_discard_copy_degree(input_lhs))); + return tl::unexpected( + fmt::format("Elementwise Add expected discard copy degree of " + "inputs to be 1, but receieved {}", + get_discard_copy_degree(input_lhs))); } break; @@ -56,7 +64,8 @@ tl::expected case OperatorType::EW_MIN: NOT_IMPLEMENTED(); default: - return tl::unexpected(fmt::format("Unexpected element-wise binary operator {}", attrs.type)); + return tl::unexpected(fmt::format( + "Unexpected element-wise binary operator {}", attrs.type)); } return input_lhs; diff --git a/lib/op-attrs/src/op-attrs/ops/element_unary.cc b/lib/op-attrs/src/op-attrs/ops/element_unary.cc index 66feac42e4..f703799ef3 100644 --- a/lib/op-attrs/src/op-attrs/ops/element_unary.cc +++ b/lib/op-attrs/src/op-attrs/ops/element_unary.cc @@ -4,42 +4,51 @@ namespace FlexFlow { tl::expected - get_output_shape(ElementUnaryAttrs const &attrs, TensorShape const &input_shape) { + get_output_shape(ElementUnaryAttrs const &attrs, + TensorShape const &input_shape) { return input_shape; } -tl::expected get_output_shape(ElementUnaryAttrs const &attrs, - ParallelTensorShape const &input_shape) { +tl::expected + get_output_shape(ElementUnaryAttrs const &attrs, + ParallelTensorShape const &input_shape) { if (get_sum_degree(input_shape) != 1) { - return tl::unexpected(fmt::format("Expected sum degree 1, but receieved sum degree {}", get_sum_degree(input_shape))); + return tl::unexpected( + fmt::format("Expected sum degree 1, but receieved sum degree {}", + get_sum_degree(input_shape))); } if (get_discard_copy_degree(input_shape) != 1) { - return tl::unexpected(fmt::format("Expected discard copy degree 1, but received discartd copy degree {}", get_discard_copy_degree(input_shape))); + return tl::unexpected(fmt::format( + "Expected discard copy degree 1, but received discartd copy degree {}", + get_discard_copy_degree(input_shape))); } return input_shape; } tl::expected - get_output_shape(ElementScalarUnaryAttrs const &attrs, - TensorShape const &input_shape) { + get_output_shape(ElementScalarUnaryAttrs const &attrs, + TensorShape const &input_shape) { return input_shape; } tl::expected - get_output_shape(ElementScalarUnaryAttrs const &attrs, - ParallelTensorShape const &input_shape) { + get_output_shape(ElementScalarUnaryAttrs const &attrs, + ParallelTensorShape const &input_shape) { if (get_sum_degree(input_shape) != 1) { - return tl::unexpected(fmt::format("Expected sum degree 1, but receieved sum degree {}", get_sum_degree(input_shape))); + return tl::unexpected( + fmt::format("Expected sum degree 1, but receieved sum degree {}", + get_sum_degree(input_shape))); } if (get_discard_copy_degree(input_shape) != 1) { - return tl::unexpected(fmt::format("Expected discard copy degree 1, but received discartd copy degree {}", get_discard_copy_degree(input_shape))); + return tl::unexpected(fmt::format( + "Expected discard copy degree 1, but received discartd copy degree {}", + get_discard_copy_degree(input_shape))); } return input_shape; } - } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/embedding.cc b/lib/op-attrs/src/op-attrs/ops/embedding.cc index b62656ff02..9e9ad3a194 100644 --- a/lib/op-attrs/src/op-attrs/ops/embedding.cc +++ b/lib/op-attrs/src/op-attrs/ops/embedding.cc @@ -1,25 +1,30 @@ #include "op-attrs/ops/embedding.h" -#include "utils/integer_conversions.h" -#include "op-attrs/dim_ordered/transform.h" #include "op-attrs/dim_ordered/slice.h" +#include "op-attrs/dim_ordered/transform.h" #include "utils/containers.h" +#include "utils/integer_conversions.h" namespace FlexFlow { -static std::optional basic_check(EmbeddingAttrs const &attrs, TensorShape const &input) { - if (input.data_type != DataType::INT32 && input.data_type != DataType::INT64) { - return fmt::format("Embedding expected input tensor to have integer datatype, but receieved tensor of datatype {}", input.data_type); +static std::optional basic_check(EmbeddingAttrs const &attrs, + TensorShape const &input) { + if (input.data_type != DataType::INT32 && + input.data_type != DataType::INT64) { + return fmt::format("Embedding expected input tensor to have integer " + "datatype, but receieved tensor of datatype {}", + input.data_type); } if (attrs.aggr != AggregateOp::SUM) { - return fmt::format(fmt::format("Currently unsupported aggregation op for embedding: {}", attrs.aggr)); + return fmt::format(fmt::format( + "Currently unsupported aggregation op for embedding: {}", attrs.aggr)); } return std::nullopt; } -tl::expected -get_output_shape(EmbeddingAttrs const &attrs, TensorShape const &input) { +tl::expected + get_output_shape(EmbeddingAttrs const &attrs, TensorShape const &input) { { std::optional err_msg = basic_check(attrs, input); if (err_msg.has_value()) { @@ -33,8 +38,8 @@ get_output_shape(EmbeddingAttrs const &attrs, TensorShape const &input) { return output; } -tl::expected -get_weights_shape(EmbeddingAttrs const &attrs, TensorShape const &input) { +tl::expected + get_weights_shape(EmbeddingAttrs const &attrs, TensorShape const &input) { { std::optional err_msg = basic_check(attrs, input); if (err_msg.has_value()) { @@ -43,21 +48,23 @@ get_weights_shape(EmbeddingAttrs const &attrs, TensorShape const &input) { } return TensorShape{ - TensorDims{ - FFOrdered{ - size_t_from_int(attrs.num_entries), - size_t_from_int(attrs.out_channels), + TensorDims{ + FFOrdered{ + size_t_from_int(attrs.num_entries), + size_t_from_int(attrs.out_channels), + }, }, - }, - attrs.data_type, + attrs.data_type, }; } tl::expected -get_output_shape(EmbeddingAttrs const &attrs, ParallelTensorShape const &input) { + get_output_shape(EmbeddingAttrs const &attrs, + ParallelTensorShape const &input) { TensorShape unpar = ({ - tl::expected result_unpar = get_output_shape(attrs, get_reduced_shape(input)); + tl::expected result_unpar = + get_output_shape(attrs, get_reduced_shape(input)); if (!result_unpar.has_value()) { return tl::unexpected(result_unpar.error()); } @@ -66,16 +73,21 @@ get_output_shape(EmbeddingAttrs const &attrs, ParallelTensorShape const &input) SumDegree sum_degree = shard_dim_at_idx(input, ff_dim_t{-1}).degree; DiscardCopyDegree discard_copy_degree = 1; - FFOrdered shard_degrees = transform(input.dims.shard_dims, [](ShardParallelDim const &d) { return d.degree; }); + FFOrdered shard_degrees = + transform(input.dims.shard_dims, + [](ShardParallelDim const &d) { return d.degree; }); shard_degrees.at(ff_dim_t{-1}) = get_discard_copy_degree(input); - return lift_to_parallel_with_degrees(unpar, sum_degree, discard_copy_degree, shard_degrees); + return lift_to_parallel_with_degrees( + unpar, sum_degree, discard_copy_degree, shard_degrees); } -tl::expected -get_weights_shape(EmbeddingAttrs const &attrs, ParallelTensorShape const &input) { +tl::expected + get_weights_shape(EmbeddingAttrs const &attrs, + ParallelTensorShape const &input) { TensorShape unpar = ({ - tl::expected result_unpar = get_weights_shape(attrs, get_reduced_shape(input)); + tl::expected result_unpar = + get_weights_shape(attrs, get_reduced_shape(input)); if (!result_unpar.has_value()) { return tl::unexpected(result_unpar.error()); } @@ -83,18 +95,18 @@ get_weights_shape(EmbeddingAttrs const &attrs, ParallelTensorShape const &input) }); SumDegree sum_degree = 1; - DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{ - product(transform(ff_ordered_shard_dims(input.dims), - [](ShardParallelDim const &d) -> int { return d.degree; })) - }; + DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{product( + transform(ff_ordered_shard_dims(input.dims), + [](ShardParallelDim const &d) -> int { return d.degree; }))}; int entry_dim_degree = 1; int out_channel_degree = get_discard_copy_degree(input); FFOrdered shard_degrees = { - entry_dim_degree, - out_channel_degree, + entry_dim_degree, + out_channel_degree, }; - return lift_to_parallel_with_degrees(unpar, sum_degree, discard_copy_degree, shard_degrees); + return lift_to_parallel_with_degrees( + unpar, sum_degree, discard_copy_degree, shard_degrees); } } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/linear.cc b/lib/op-attrs/src/op-attrs/ops/linear.cc index 2bba41fa2a..8283673378 100644 --- a/lib/op-attrs/src/op-attrs/ops/linear.cc +++ b/lib/op-attrs/src/op-attrs/ops/linear.cc @@ -1,15 +1,14 @@ #include "op-attrs/ops/linear.h" +#include "op-attrs/dim_ordered/slice.h" +#include "op-attrs/dim_ordered/transform.h" #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/tensor_shape.h" #include "utils/integer_conversions.h" -#include "op-attrs/dim_ordered/transform.h" -#include "op-attrs/dim_ordered/slice.h" namespace FlexFlow { tl::expected - get_kernel_shape(LinearAttrs const &attrs, - TensorShape const &input_shape) { + get_kernel_shape(LinearAttrs const &attrs, TensorShape const &input_shape) { size_t in_channels = dim_at_idx(input_shape, ff_dim_t{-1}); return TensorShape{ @@ -21,8 +20,7 @@ tl::expected } tl::expected -get_bias_shape(LinearAttrs const &attrs, - TensorShape const &input_shape) { + get_bias_shape(LinearAttrs const &attrs, TensorShape const &input_shape) { return TensorShape{ TensorDims{ FFOrdered{size_t_from_int(attrs.out_channels)}, @@ -31,8 +29,8 @@ get_bias_shape(LinearAttrs const &attrs, }; } -tl::expected get_output_shape(LinearAttrs const &attrs, - TensorShape const &input_shape) { +tl::expected + get_output_shape(LinearAttrs const &attrs, TensorShape const &input_shape) { TensorShape output_shape = input_shape; output_shape.dims.ff_ordered.at(ff_dim_t{-1}) = size_t_from_int(attrs.out_channels); @@ -40,10 +38,12 @@ tl::expected get_output_shape(LinearAttrs const &attrs return output_shape; } -tl::expected get_kernel_shape(LinearAttrs const &attrs, - ParallelTensorShape const &input) { +tl::expected + get_kernel_shape(LinearAttrs const &attrs, + ParallelTensorShape const &input) { TensorShape unpar = ({ - tl::expected result_unpar = get_kernel_shape(attrs, get_reduced_shape(input)); + tl::expected result_unpar = + get_kernel_shape(attrs, get_reduced_shape(input)); if (!result_unpar.has_value()) { return tl::unexpected(result_unpar.error()); } @@ -52,50 +52,59 @@ tl::expected get_kernel_shape(LinearAttrs cons SumDegree sum_degree = 1; DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{ - get_sum_degree(input) * product(slice(ff_ordered_shard_degrees(input), std::nullopt, ff_dim_t{-1})) - }; + get_sum_degree(input) * + product( + slice(ff_ordered_shard_degrees(input), std::nullopt, ff_dim_t{-1}))}; FFOrdered shard_degrees = FFOrdered{ - shard_dim_at_idx(input, ff_dim_t{-1}).degree, - get_discard_copy_degree(input), + shard_dim_at_idx(input, ff_dim_t{-1}).degree, + get_discard_copy_degree(input), }; - return lift_to_parallel_with_degrees(unpar, sum_degree, discard_copy_degree, shard_degrees); + return lift_to_parallel_with_degrees( + unpar, sum_degree, discard_copy_degree, shard_degrees); } tl::expected - get_bias_shape(LinearAttrs const &attrs, - ParallelTensorShape const &input) { + get_bias_shape(LinearAttrs const &attrs, ParallelTensorShape const &input) { TensorShape unpar = ({ - tl::expected result_unpar = get_bias_shape(attrs, get_reduced_shape(input)); + tl::expected result_unpar = + get_bias_shape(attrs, get_reduced_shape(input)); if (!result_unpar.has_value()) { return tl::unexpected(result_unpar.error()); } result_unpar.value(); }); - SumDegree sum_degree = get_sum_degree(input) * shard_dim_at_idx(input, ff_dim_t{-1}).degree; - DiscardCopyDegree discard_copy_degree = product(slice(ff_ordered_shard_degrees(input), std::nullopt, ff_dim_t{-1})); + SumDegree sum_degree = + get_sum_degree(input) * shard_dim_at_idx(input, ff_dim_t{-1}).degree; + DiscardCopyDegree discard_copy_degree = product( + slice(ff_ordered_shard_degrees(input), std::nullopt, ff_dim_t{-1})); FFOrdered shard_degrees = FFOrdered{get_discard_copy_degree(input)}; - return lift_to_parallel_with_degrees(unpar, sum_degree, discard_copy_degree, shard_degrees); + return lift_to_parallel_with_degrees( + unpar, sum_degree, discard_copy_degree, shard_degrees); } -tl::expected get_output_shape(LinearAttrs const &attrs, - ParallelTensorShape const &input) { +tl::expected + get_output_shape(LinearAttrs const &attrs, + ParallelTensorShape const &input) { TensorShape unpar = ({ - tl::expected result_unpar = get_output_shape(attrs, get_reduced_shape(input)); + tl::expected result_unpar = + get_output_shape(attrs, get_reduced_shape(input)); if (!result_unpar.has_value()) { return tl::unexpected(result_unpar.error()); } result_unpar.value(); }); - SumDegree sum_degree = get_sum_degree(input) * shard_dim_at_idx(input, ff_dim_t{-1}).degree; + SumDegree sum_degree = + get_sum_degree(input) * shard_dim_at_idx(input, ff_dim_t{-1}).degree; DiscardCopyDegree discard_copy_degree = 1; FFOrdered shard_degrees = ff_ordered_shard_degrees(input); shard_degrees.at(ff_dim_t{-1}) = get_discard_copy_degree(input); - return lift_to_parallel_with_degrees(unpar, sum_degree, discard_copy_degree, shard_degrees); + return lift_to_parallel_with_degrees( + unpar, sum_degree, discard_copy_degree, shard_degrees); } } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/reduction.cc b/lib/op-attrs/src/op-attrs/ops/reduction.cc index 1b0a393058..0fef6f37d6 100644 --- a/lib/op-attrs/src/op-attrs/ops/reduction.cc +++ b/lib/op-attrs/src/op-attrs/ops/reduction.cc @@ -4,10 +4,14 @@ namespace FlexFlow { tl::expected - get_output_shape(ReductionAttrs const &attrs, - ParallelTensorShape const &input_shape) { + get_output_shape(ReductionAttrs const &attrs, + ParallelTensorShape const &input_shape) { if (get_sum_degree(input_shape) % attrs.reduction_degree != 0) { - return tl::unexpected(fmt::format("Reduction received tensor with sum degree {}, which is not divisible by reduction degree {}", get_sum_degree(input_shape), attrs.reduction_degree)); + return tl::unexpected( + fmt::format("Reduction received tensor with sum degree {}, which is " + "not divisible by reduction degree {}", + get_sum_degree(input_shape), + attrs.reduction_degree)); } ParallelTensorShape output_shape = input_shape; diff --git a/lib/op-attrs/src/op-attrs/ops/repartition.cc b/lib/op-attrs/src/op-attrs/ops/repartition.cc index e668ccce6c..37a0b8a168 100644 --- a/lib/op-attrs/src/op-attrs/ops/repartition.cc +++ b/lib/op-attrs/src/op-attrs/ops/repartition.cc @@ -3,10 +3,11 @@ namespace FlexFlow { tl::expected - get_output_shape(RepartitionAttrs const &attrs, - ParallelTensorShape const &input_shape) { + get_output_shape(RepartitionAttrs const &attrs, + ParallelTensorShape const &input_shape) { ParallelTensorShape output_shape = input_shape; - output_shape.dims.shard_dims.at(attrs.repartition_dim).degree *= attrs.repartition_degree; + output_shape.dims.shard_dims.at(attrs.repartition_dim).degree *= + attrs.repartition_degree; return output_shape; } diff --git a/lib/op-attrs/src/op-attrs/ops/replicate.cc b/lib/op-attrs/src/op-attrs/ops/replicate.cc index 046558b452..9e163cb55a 100644 --- a/lib/op-attrs/src/op-attrs/ops/replicate.cc +++ b/lib/op-attrs/src/op-attrs/ops/replicate.cc @@ -5,7 +5,8 @@ namespace FlexFlow { ParallelTensorShape get_output_shape(ReplicateAttrs const &attrs, ParallelTensorShape const &input_shape) { ParallelTensorShape output_shape = input_shape; - output_shape.dims.replica_dims.discard_copy_degree.value *= attrs.replicate_degree; + output_shape.dims.replica_dims.discard_copy_degree.value *= + attrs.replicate_degree; return output_shape; } diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc index 4ade47840b..ff5a8224df 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc @@ -1,10 +1,10 @@ #include "op-attrs/parallel_tensor_dims.h" +#include "op-attrs/dim_ordered/transform.h" #include "op-attrs/replica_parallel_dim.h" #include "op-attrs/replica_parallel_dim_set.h" #include "op-attrs/shard_parallel_dim.h" #include "utils/containers.h" #include "utils/integer_conversions.h" -#include "op-attrs/dim_ordered/transform.h" namespace FlexFlow { @@ -13,7 +13,8 @@ FFOrdered ff_ordered_shard_dims(ParallelTensorDims const &d) { } FFOrdered ff_ordered_shard_degrees(ParallelTensorDims const &d) { - return transform(d.shard_dims, [](ShardParallelDim const &d) { return d.degree; }); + return transform(d.shard_dims, + [](ShardParallelDim const &d) { return d.degree; }); } std::unordered_set @@ -26,7 +27,8 @@ size_t num_shard_dims(ParallelTensorDims const &dims) { } int total_replica_degree(ParallelTensorDims const &dims) { - return dims.replica_dims.discard_copy_degree.value * dims.replica_dims.sum_degree.value; + return dims.replica_dims.discard_copy_degree.value * + dims.replica_dims.sum_degree.value; } int total_shard_degree(ParallelTensorDims const &dims) { @@ -62,7 +64,8 @@ TensorDims get_tensor_dims_unsafe(ParallelTensorDims const &) { } TensorDims get_reduced_dims(ParallelTensorDims const &dims) { - FFOrdered dim_sizes = transform(dims.shard_dims, [](ShardParallelDim const &d) { return d.size; }); + FFOrdered dim_sizes = transform( + dims.shard_dims, [](ShardParallelDim const &d) { return d.size; }); return TensorDims{dim_sizes}; } diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc index a0ebb6f9c2..516cbe191f 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc @@ -29,7 +29,7 @@ int get_discard_copy_degree(ParallelTensorShape const &shape) { } int get_total_parallel_degree(ParallelTensorShape const &s) { - return total_parallel_degree(s.dims); + return total_parallel_degree(s.dims); } bool is_valid(ParallelTensorShape const &shape) { @@ -48,7 +48,8 @@ FFOrdered ff_ordered_shard_degrees(ParallelTensorShape const &s) { return ff_ordered_shard_degrees(s.dims); } -std::optional try_get_shard_dim_at_idx(ParallelTensorShape const &s, ff_dim_t d) { +std::optional + try_get_shard_dim_at_idx(ParallelTensorShape const &s, ff_dim_t d) { if (s.dims.shard_dims.idx_is_valid(d)) { return s.dims.shard_dims.at(d); } else { @@ -60,10 +61,15 @@ ParallelTensorShape lift_to_parallel(TensorShape const &s) { return {lift_to_parallel(s.dims), s.data_type}; } -ParallelTensorShape lift_to_parallel_with_degrees(TensorShape const &s, SumDegree sum_degree, DiscardCopyDegree discard_copy_degree, FFOrdered const &shard_degrees) { +ParallelTensorShape + lift_to_parallel_with_degrees(TensorShape const &s, + SumDegree sum_degree, + DiscardCopyDegree discard_copy_degree, + FFOrdered const &shard_degrees) { return ParallelTensorShape{ - lift_to_parallel_with_degrees(s.dims, sum_degree, discard_copy_degree, shard_degrees), - s.data_type, + lift_to_parallel_with_degrees( + s.dims, sum_degree, discard_copy_degree, shard_degrees), + s.data_type, }; } diff --git a/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.cc b/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.cc index 7ab2e2d2ef..7ef228e97e 100644 --- a/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.cc +++ b/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.cc @@ -24,7 +24,8 @@ std::unordered_set get_replica_dims(ReplicaParallelDimSet const &s) { return std::unordered_set{ ReplicaParallelDim{s.sum_degree.value, ReplicaType::SUM}, - ReplicaParallelDim{s.discard_copy_degree.value, ReplicaType::DISCARD_COPY}, + ReplicaParallelDim{s.discard_copy_degree.value, + ReplicaType::DISCARD_COPY}, }; } diff --git a/lib/op-attrs/src/op-attrs/tensor_dims.cc b/lib/op-attrs/src/op-attrs/tensor_dims.cc index 512e9f0804..ed40f509d9 100644 --- a/lib/op-attrs/src/op-attrs/tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/tensor_dims.cc @@ -24,26 +24,29 @@ size_t &dim_at_idx(TensorDims &dims, ff_dim_t idx) { } ParallelTensorDims lift_to_parallel(TensorDims const &dims) { - std::vector shard_degrees(num_dims(dims), 1); // 1 repeated num_dims(dims) times + std::vector shard_degrees(num_dims(dims), + 1); // 1 repeated num_dims(dims) times return lift_to_parallel_with_degrees(dims, 1, 1, shard_degrees); } -ParallelTensorDims lift_to_parallel_with_degrees(TensorDims const &dims, SumDegree sum_degree, DiscardCopyDegree discard_copy_degree, FFOrdered const &shard_degrees) { +ParallelTensorDims + lift_to_parallel_with_degrees(TensorDims const &dims, + SumDegree sum_degree, + DiscardCopyDegree discard_copy_degree, + FFOrdered const &shard_degrees) { std::vector lifted = - transform(zip(as_vector(dims.ff_ordered), as_vector(shard_degrees)), + transform(zip(as_vector(dims.ff_ordered), as_vector(shard_degrees)), [](std::pair const &p) { size_t size = p.first; int degree = p.second; return ShardParallelDim(size, degree); }); - return ParallelTensorDims{ - FFOrdered{lifted}, - ReplicaParallelDimSet{ - sum_degree, - discard_copy_degree, - } - }; + return ParallelTensorDims{FFOrdered{lifted}, + ReplicaParallelDimSet{ + sum_degree, + discard_copy_degree, + }}; } } // namespace FlexFlow diff --git a/lib/op-attrs/test/src/dim_ordered/slice.cc b/lib/op-attrs/test/src/dim_ordered/slice.cc index 26e49e630e..8640b077dc 100644 --- a/lib/op-attrs/test/src/dim_ordered/slice.cc +++ b/lib/op-attrs/test/src/dim_ordered/slice.cc @@ -2,14 +2,20 @@ #include "test/utils/doctest.h" TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("slice(DimOrdered, std::optional, std::optional)") { + TEST_CASE( + "slice(DimOrdered, std::optional, std::optional)") { FFOrdered d = FFOrdered{ - 1, 2, 3, 4, + 1, + 2, + 3, + 4, }; FFOrdered result = slice(d, std::nullopt, ff_dim_t{-1}); FFOrdered correct = FFOrdered{ - 1, 2, 3, + 1, + 2, + 3, }; CHECK(result == correct); diff --git a/lib/op-attrs/test/src/ops/combine.cc b/lib/op-attrs/test/src/ops/combine.cc index 61f9b1d138..a50b3b01de 100644 --- a/lib/op-attrs/test/src/ops/combine.cc +++ b/lib/op-attrs/test/src/ops/combine.cc @@ -1,34 +1,35 @@ -#include "test/utils/doctest.h" #include "op-attrs/ops/combine.h" +#include "test/utils/doctest.h" TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Combine shape inference") { ParallelTensorShape input = { - ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{12, 2}, - ShardParallelDim{14, 1}, - ShardParallelDim{16, 3}, - ShardParallelDim{18, 2}, - }, - ReplicaParallelDimSet{ - SumDegree{3}, - DiscardCopyDegree{2}, + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{12, 2}, + ShardParallelDim{14, 1}, + ShardParallelDim{16, 3}, + ShardParallelDim{18, 2}, + }, + ReplicaParallelDimSet{ + SumDegree{3}, + DiscardCopyDegree{2}, + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; SUBCASE("valid") { ff_dim_t dim = 2; int degree = 3; CombineAttrs attrs = CombineAttrs{ - /*repartition_dim=*/dim, - /*repartition_degree=*/degree, + /*repartition_dim=*/dim, + /*repartition_degree=*/degree, }; - tl::expected result = get_output_shape(attrs, input); + tl::expected result = + get_output_shape(attrs, input); tl::expected correct = [&] { ParallelTensorShape output = input; @@ -43,13 +44,16 @@ TEST_SUITE(FF_TEST_SUITE) { ff_dim_t dim = 2; int degree = 4; CombineAttrs attrs = CombineAttrs{ - /*repartition_dim=*/dim, - /*repartition_degree=*/degree, + /*repartition_dim=*/dim, + /*repartition_degree=*/degree, }; - tl::expected result = get_output_shape(attrs, input); + tl::expected result = + get_output_shape(attrs, input); - CHECK_MESSAGE(!result.has_value(), "Unexpected successful result: ", result.error()); + CHECK_MESSAGE(!result.has_value(), + "Unexpected successful result: ", + result.error()); } } } diff --git a/lib/op-attrs/test/src/ops/linear.cc b/lib/op-attrs/test/src/ops/linear.cc index 2de04a8a03..0d23dc35df 100644 --- a/lib/op-attrs/test/src/ops/linear.cc +++ b/lib/op-attrs/test/src/ops/linear.cc @@ -1,17 +1,17 @@ -#include "test/utils/doctest.h" #include "op-attrs/ops/linear.h" -#include "utils/integer_conversions.h" #include "op-attrs/parallel_tensor_shape.h" +#include "test/utils/doctest.h" +#include "utils/integer_conversions.h" TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Linear shape inference") { int out_channels = 16; LinearAttrs attrs = LinearAttrs{ - /*out_channels=*/out_channels, - /*use_bias=*/true, - /*data_type=*/DataType::FLOAT, - /*activation=*/Activation::RELU, - /*regularizer=*/std::nullopt, + /*out_channels=*/out_channels, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*activation=*/Activation::RELU, + /*regularizer=*/std::nullopt, }; size_t batch_size = 12; @@ -19,105 +19,143 @@ TEST_SUITE(FF_TEST_SUITE) { size_t in_channels = 8; TensorShape input = TensorShape{ - TensorDims{ - FFOrdered{ - batch_size, - extra_dim, - in_channels, + TensorDims{ + FFOrdered{ + batch_size, + extra_dim, + in_channels, + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; TensorShape output = TensorShape{ - TensorDims{ - FFOrdered{ - batch_size, - extra_dim, - size_t_from_int(out_channels), + TensorDims{ + FFOrdered{ + batch_size, + extra_dim, + size_t_from_int(out_channels), + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; TensorShape kernel = TensorShape{ - TensorDims{ - FFOrdered{ - in_channels, - size_t_from_int(out_channels), + TensorDims{ + FFOrdered{ + in_channels, + size_t_from_int(out_channels), + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; TensorShape bias = TensorShape{ - TensorDims{ - FFOrdered{ - size_t_from_int(out_channels), + TensorDims{ + FFOrdered{ + size_t_from_int(out_channels), + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; // get_output_shape { - tl::expected output_result = get_output_shape(attrs, input); + tl::expected output_result = + get_output_shape(attrs, input); tl::expected output_correct = output; CHECK(output_result == output_correct); } // get_weight_shape { - tl::expected kernel_result = get_kernel_shape(attrs, input); + tl::expected kernel_result = + get_kernel_shape(attrs, input); tl::expected kernel_correct = kernel; CHECK(kernel_result == kernel_correct); } // get_bias_shape { - tl::expected bias_result = get_bias_shape(attrs, input); + tl::expected bias_result = + get_bias_shape(attrs, input); tl::expected bias_correct = bias; CHECK(bias_result == bias_correct); } - auto make_input = [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o_batch, int o_extra_dim, int o_channel) { - return lift_to_parallel_with_degrees(input, o_sum, o_eq, FFOrdered{o_batch, o_extra_dim, o_channel}); + auto make_input = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_batch, + int o_extra_dim, + int o_channel) { + return lift_to_parallel_with_degrees( + input, o_sum, o_eq, FFOrdered{o_batch, o_extra_dim, o_channel}); }; - auto make_output = [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o_batch, int o_extra_dim, int o_channel) { - return lift_to_parallel_with_degrees(output, o_sum, o_eq, FFOrdered{o_batch, o_extra_dim, o_channel}); + auto make_output = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_batch, + int o_extra_dim, + int o_channel) { + return lift_to_parallel_with_degrees( + output, o_sum, o_eq, FFOrdered{o_batch, o_extra_dim, o_channel}); }; - auto make_kernel = [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o_inchannel, int o_outchannel) { - return lift_to_parallel_with_degrees(kernel, o_sum, o_eq, FFOrdered{o_inchannel, o_outchannel}); + auto make_kernel = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_inchannel, + int o_outchannel) { + return lift_to_parallel_with_degrees( + kernel, o_sum, o_eq, FFOrdered{o_inchannel, o_outchannel}); }; - auto make_bias = [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o_outchannel) { - return lift_to_parallel_with_degrees(bias, o_sum, o_eq, FFOrdered{o_outchannel}); - }; + auto make_bias = + [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o_outchannel) { + return lift_to_parallel_with_degrees( + bias, o_sum, o_eq, FFOrdered{o_outchannel}); + }; SUBCASE("data parallelism") { int input_sum_degree = 2; int extra_dim_degree = 8; - int degree = 4; + int degree = 4; + + ParallelTensorShape par_input = make_input(SumDegree{input_sum_degree}, + DiscardCopyDegree{1}, + degree, + extra_dim_degree, + 1); - ParallelTensorShape par_input = make_input(SumDegree{input_sum_degree}, DiscardCopyDegree{1}, degree, extra_dim_degree, 1); - { - tl::expected result = get_output_shape(attrs, par_input); - tl::expected correct = make_output(SumDegree{input_sum_degree}, DiscardCopyDegree{1}, degree, extra_dim_degree, 1); + tl::expected result = + get_output_shape(attrs, par_input); + tl::expected correct = + make_output(SumDegree{input_sum_degree}, + DiscardCopyDegree{1}, + degree, + extra_dim_degree, + 1); CHECK(result == correct); } { - tl::expected result = get_kernel_shape(attrs, par_input); - tl::expected correct = make_kernel(SumDegree{1}, DiscardCopyDegree{input_sum_degree * degree * extra_dim_degree}, 1, 1); + tl::expected result = + get_kernel_shape(attrs, par_input); + tl::expected correct = make_kernel( + SumDegree{1}, + DiscardCopyDegree{input_sum_degree * degree * extra_dim_degree}, + 1, + 1); CHECK(result == correct); } { - tl::expected result = get_bias_shape(attrs, par_input); - tl::expected correct = make_bias(SumDegree{input_sum_degree}, DiscardCopyDegree{degree * extra_dim_degree}, 1); + tl::expected result = + get_bias_shape(attrs, par_input); + tl::expected correct = + make_bias(SumDegree{input_sum_degree}, + DiscardCopyDegree{degree * extra_dim_degree}, + 1); CHECK(result == correct); } } @@ -126,23 +164,34 @@ TEST_SUITE(FF_TEST_SUITE) { int input_sum_degree = 2; int degree = 4; - ParallelTensorShape par_input = make_input(SumDegree{input_sum_degree}, DiscardCopyDegree{1}, 1, 1, degree); + ParallelTensorShape par_input = make_input( + SumDegree{input_sum_degree}, DiscardCopyDegree{1}, 1, 1, degree); { - tl::expected result = get_output_shape(attrs, par_input); - tl::expected correct = make_output(SumDegree{input_sum_degree * degree}, DiscardCopyDegree{1}, 1, 1, 1); + tl::expected result = + get_output_shape(attrs, par_input); + tl::expected correct = + make_output(SumDegree{input_sum_degree * degree}, + DiscardCopyDegree{1}, + 1, + 1, + 1); CHECK(result == correct); } { - tl::expected result = get_kernel_shape(attrs, par_input); - tl::expected correct = make_kernel(SumDegree{1}, DiscardCopyDegree{input_sum_degree}, degree, 1); + tl::expected result = + get_kernel_shape(attrs, par_input); + tl::expected correct = make_kernel( + SumDegree{1}, DiscardCopyDegree{input_sum_degree}, degree, 1); CHECK(result == correct); } { - tl::expected result = get_bias_shape(attrs, par_input); - tl::expected correct = make_bias(SumDegree{input_sum_degree * degree}, DiscardCopyDegree{1}, 1); + tl::expected result = + get_bias_shape(attrs, par_input); + tl::expected correct = make_bias( + SumDegree{input_sum_degree * degree}, DiscardCopyDegree{1}, 1); CHECK(result == correct); } } @@ -151,23 +200,30 @@ TEST_SUITE(FF_TEST_SUITE) { int input_sum_degree = 2; int degree = 4; - ParallelTensorShape par_input = make_input(SumDegree{input_sum_degree}, DiscardCopyDegree{degree}, 1, 1, 1); + ParallelTensorShape par_input = make_input( + SumDegree{input_sum_degree}, DiscardCopyDegree{degree}, 1, 1, 1); { - tl::expected result = get_output_shape(attrs, par_input); - tl::expected correct = make_output(SumDegree{input_sum_degree}, DiscardCopyDegree{1}, 1, 1, degree); + tl::expected result = + get_output_shape(attrs, par_input); + tl::expected correct = make_output( + SumDegree{input_sum_degree}, DiscardCopyDegree{1}, 1, 1, degree); CHECK(result == correct); } { - tl::expected result = get_kernel_shape(attrs, par_input); - tl::expected correct = make_kernel(SumDegree{1}, DiscardCopyDegree{input_sum_degree}, 1, degree); + tl::expected result = + get_kernel_shape(attrs, par_input); + tl::expected correct = make_kernel( + SumDegree{1}, DiscardCopyDegree{input_sum_degree}, 1, degree); CHECK(result == correct); } { - tl::expected result = get_bias_shape(attrs, par_input); - tl::expected correct = make_bias(SumDegree{input_sum_degree}, DiscardCopyDegree{1}, degree); + tl::expected result = + get_bias_shape(attrs, par_input); + tl::expected correct = make_bias( + SumDegree{input_sum_degree}, DiscardCopyDegree{1}, degree); CHECK(result == correct); } } diff --git a/lib/op-attrs/test/src/ops/reduction.cc b/lib/op-attrs/test/src/ops/reduction.cc index dc51479d37..6f73951e00 100644 --- a/lib/op-attrs/test/src/ops/reduction.cc +++ b/lib/op-attrs/test/src/ops/reduction.cc @@ -1,32 +1,33 @@ -#include "test/utils/doctest.h" #include "op-attrs/ops/reduction.h" +#include "test/utils/doctest.h" TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Reduction shape inference") { ParallelTensorShape input = { - ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{12, 2}, - ShardParallelDim{14, 1}, - ShardParallelDim{16, 3}, - ShardParallelDim{18, 2}, - }, - ReplicaParallelDimSet{ - SumDegree{3}, - DiscardCopyDegree{2}, + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{12, 2}, + ShardParallelDim{14, 1}, + ShardParallelDim{16, 3}, + ShardParallelDim{18, 2}, + }, + ReplicaParallelDimSet{ + SumDegree{3}, + DiscardCopyDegree{2}, + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; SUBCASE("valid") { int degree = 3; ReductionAttrs attrs = ReductionAttrs{ - /*repartition_degree=*/degree, + /*repartition_degree=*/degree, }; - tl::expected result = get_output_shape(attrs, input); + tl::expected result = + get_output_shape(attrs, input); tl::expected correct = [&] { ParallelTensorShape output = input; @@ -40,12 +41,15 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("invalid") { int degree = 4; ReductionAttrs attrs = ReductionAttrs{ - /*repartition_degree=*/degree, + /*repartition_degree=*/degree, }; - tl::expected result = get_output_shape(attrs, input); + tl::expected result = + get_output_shape(attrs, input); - CHECK_MESSAGE(!result.has_value(), "Unexpected successful result: ", result.error()); + CHECK_MESSAGE(!result.has_value(), + "Unexpected successful result: ", + result.error()); } } } diff --git a/lib/op-attrs/test/src/ops/repartition.cc b/lib/op-attrs/test/src/ops/repartition.cc index 62c6a66799..3b3ae92b4c 100644 --- a/lib/op-attrs/test/src/ops/repartition.cc +++ b/lib/op-attrs/test/src/ops/repartition.cc @@ -1,33 +1,33 @@ -#include "test/utils/doctest.h" #include "op-attrs/ops/repartition.h" +#include "test/utils/doctest.h" TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Repartition shape inference") { ff_dim_t dim = 2; int degree = 4; RepartitionAttrs attrs = RepartitionAttrs{ - /*repartition_dim=*/dim, - /*repartition_degree=*/degree, + /*repartition_dim=*/dim, + /*repartition_degree=*/degree, }; ParallelTensorShape input = { - ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{12, 2}, - ShardParallelDim{14, 1}, - ShardParallelDim{16, 3}, - ShardParallelDim{18, 2}, + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{12, 2}, + ShardParallelDim{14, 1}, + ShardParallelDim{16, 3}, + ShardParallelDim{18, 2}, + }, + ReplicaParallelDimSet{ + SumDegree{3}, + DiscardCopyDegree{2}, + }, }, - ReplicaParallelDimSet{ - SumDegree{3}, - DiscardCopyDegree{2}, - }, - }, - DataType::FLOAT, + DataType::FLOAT, }; - tl::expected result = get_output_shape(attrs, input); - + tl::expected result = + get_output_shape(attrs, input); tl::expected correct = [&] { ParallelTensorShape output = input; diff --git a/lib/op-attrs/test/src/ops/replicate.cc b/lib/op-attrs/test/src/ops/replicate.cc index 412ec5a2b4..b326038388 100644 --- a/lib/op-attrs/test/src/ops/replicate.cc +++ b/lib/op-attrs/test/src/ops/replicate.cc @@ -1,29 +1,29 @@ -#include "test/utils/doctest.h" #include "op-attrs/ops/replicate.h" +#include "test/utils/doctest.h" TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Replicate shape inference") { ReplicateAttrs attrs = ReplicateAttrs{ - /*replicate_degree=*/4, + /*replicate_degree=*/4, }; ParallelTensorShape input = { - ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{10, 2}, - ShardParallelDim{12, 1}, - ShardParallelDim{14, 2}, - ShardParallelDim{16, 2}, - }, - ReplicaParallelDimSet{ - SumDegree{3}, - DiscardCopyDegree{2}, + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{10, 2}, + ShardParallelDim{12, 1}, + ShardParallelDim{14, 2}, + ShardParallelDim{16, 2}, + }, + ReplicaParallelDimSet{ + SumDegree{3}, + DiscardCopyDegree{2}, + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; - ParallelTensorShape result = get_output_shape(attrs, input); + ParallelTensorShape result = get_output_shape(attrs, input); ParallelTensorShape correct_output = input; correct_output.dims.replica_dims.discard_copy_degree = 8; diff --git a/lib/op-attrs/test/src/test_attention.cc b/lib/op-attrs/test/src/test_attention.cc index a28068780a..74ae4565ca 100644 --- a/lib/op-attrs/test/src/test_attention.cc +++ b/lib/op-attrs/test/src/test_attention.cc @@ -1,96 +1,97 @@ -#include "test/utils/doctest.h" #include "op-attrs/ops/attention.h" -#include "utils/integer_conversions.h" #include "op-attrs/parallel_tensor_shape.h" +#include "test/utils/doctest.h" +#include "utils/integer_conversions.h" TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_output_shape(MultiHeadAttentionAttrs, TensorShape, TensorShape, TensorShape)") { + TEST_CASE("get_output_shape(MultiHeadAttentionAttrs, TensorShape, " + "TensorShape, TensorShape)") { int embed_dim = 32; - /* Parameter meanings match those at + /* Parameter meanings match those at * https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html */ MultiHeadAttentionAttrs attrs = { - /*embed_dim=*/embed_dim, - /*num_heads=*/10, - /*kdim=*/embed_dim, - /*vdim=*/embed_dim, - /*dropout=*/0.0, - /*bias=*/true, - /*add_bias_kv=*/false, - /*add_zero_attn=*/false, + /*embed_dim=*/embed_dim, + /*num_heads=*/10, + /*kdim=*/embed_dim, + /*vdim=*/embed_dim, + /*dropout=*/0.0, + /*bias=*/true, + /*add_bias_kv=*/false, + /*add_zero_attn=*/false, }; size_t batch_size = 40; size_t seq_len = 48; TensorShape input_q = { - TensorDims{ - FFOrdered{ - batch_size, - seq_len, - size_t_from_int(attrs.embed_dim), - } - }, - DataType::FLOAT, + TensorDims{FFOrdered{ + batch_size, + seq_len, + size_t_from_int(attrs.embed_dim), + }}, + DataType::FLOAT, }; TensorShape input_k = { - TensorDims{ - FFOrdered{ - batch_size, - seq_len, - size_t_from_int(attrs.kdim), + TensorDims{ + FFOrdered{ + batch_size, + seq_len, + size_t_from_int(attrs.kdim), + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; TensorShape input_v = { - TensorDims{ - FFOrdered{ - batch_size, - seq_len, - size_t_from_int(attrs.vdim), + TensorDims{ + FFOrdered{ + batch_size, + seq_len, + size_t_from_int(attrs.vdim), + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; SUBCASE("get_output_shape") { - tl::expected result = get_output_shape(attrs, input_q, input_k, input_v); + tl::expected result = + get_output_shape(attrs, input_q, input_k, input_v); tl::expected correct = TensorShape{ - TensorDims{ - FFOrdered{ - batch_size, - seq_len, - size_t_from_int(attrs.embed_dim), - } - }, - DataType::FLOAT, + TensorDims{FFOrdered{ + batch_size, + seq_len, + size_t_from_int(attrs.embed_dim), + }}, + DataType::FLOAT, }; CHECK(result == correct); } SUBCASE("get_weights_shape") { - tl::expected result = get_weights_shape(attrs, input_q, input_k, input_v); - - int qProjPerHeadWeightSize = attrs.kdim * dim_at_idx(input_q, ff_dim_t{-1}); - int kProjPerHeadWeightSize = attrs.kdim * dim_at_idx(input_k, ff_dim_t{-1}); - int vProjPerHeadWeightSize = attrs.vdim * dim_at_idx(input_v, ff_dim_t{-1}); + tl::expected result = + get_weights_shape(attrs, input_q, input_k, input_v); + + int qProjPerHeadWeightSize = + attrs.kdim * dim_at_idx(input_q, ff_dim_t{-1}); + int kProjPerHeadWeightSize = + attrs.kdim * dim_at_idx(input_k, ff_dim_t{-1}); + int vProjPerHeadWeightSize = + attrs.vdim * dim_at_idx(input_v, ff_dim_t{-1}); int oProjPerHeadWeightSize = attrs.embed_dim * attrs.vdim; - int perHeadWeightSize = qProjPerHeadWeightSize + kProjPerHeadWeightSize + vProjPerHeadWeightSize + oProjPerHeadWeightSize; + int perHeadWeightSize = qProjPerHeadWeightSize + kProjPerHeadWeightSize + + vProjPerHeadWeightSize + oProjPerHeadWeightSize; tl::expected correct = TensorShape{ - TensorDims{ - FFOrdered{ - size_t_from_int(perHeadWeightSize), - size_t_from_int(attrs.num_heads), - } - }, - DataType::FLOAT, + TensorDims{FFOrdered{ + size_t_from_int(perHeadWeightSize), + size_t_from_int(attrs.num_heads), + }}, + DataType::FLOAT, }; CHECK(result == correct); @@ -100,18 +101,18 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("parallel shape inference for MultiHeadAttentionAttrs") { int embed_dim = 32; - /* Parameter meanings can be found at + /* Parameter meanings can be found at * https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html */ MultiHeadAttentionAttrs attrs = { - /*embed_dim=*/embed_dim, - /*num_heads=*/10, - /*kdim=*/embed_dim, - /*vdim=*/embed_dim, - /*dropout=*/0.0, - /*bias=*/true, - /*add_bias_kv=*/false, - /*add_zero_attn=*/false, + /*embed_dim=*/embed_dim, + /*num_heads=*/10, + /*kdim=*/embed_dim, + /*vdim=*/embed_dim, + /*dropout=*/0.0, + /*bias=*/true, + /*add_bias_kv=*/false, + /*add_zero_attn=*/false, }; size_t batchsize = 40; @@ -121,64 +122,87 @@ TEST_SUITE(FF_TEST_SUITE) { size_t v_size = 72; TensorShape unpar_q_shape = TensorShape{ - TensorDims{ - FFOrdered{ - batchsize, - seq_len, - q_size, + TensorDims{ + FFOrdered{ + batchsize, + seq_len, + q_size, + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; TensorShape unpar_k_shape = TensorShape{ - TensorDims{ - FFOrdered{ - batchsize, - seq_len, - k_size, + TensorDims{ + FFOrdered{ + batchsize, + seq_len, + k_size, + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; TensorShape unpar_v_shape = TensorShape{ - TensorDims{ - FFOrdered{ - batchsize, - seq_len, - v_size, + TensorDims{ + FFOrdered{ + batchsize, + seq_len, + v_size, + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; - tl::expected result_unpar_o_shape = get_output_shape(attrs, unpar_q_shape, unpar_k_shape, unpar_v_shape); + tl::expected result_unpar_o_shape = + get_output_shape(attrs, unpar_q_shape, unpar_k_shape, unpar_v_shape); REQUIRE(result_unpar_o_shape.has_value()); TensorShape unpar_o_shape = result_unpar_o_shape.value(); - tl::expected result_unpar_w_shape = get_weights_shape(attrs, unpar_q_shape, unpar_k_shape, unpar_v_shape); + tl::expected result_unpar_w_shape = + get_weights_shape(attrs, unpar_q_shape, unpar_k_shape, unpar_v_shape); REQUIRE(result_unpar_o_shape.has_value()); TensorShape unpar_w_shape = result_unpar_w_shape.value(); - auto make_q = [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o_batch, int o_seq_len, int o_q) { - return lift_to_parallel_with_degrees(unpar_q_shape, o_sum, o_eq, FFOrdered{o_batch, o_seq_len, o_q}); + auto make_q = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_batch, + int o_seq_len, + int o_q) { + return lift_to_parallel_with_degrees( + unpar_q_shape, o_sum, o_eq, FFOrdered{o_batch, o_seq_len, o_q}); }; - auto make_k = [&](int o_sum, int o_eq, int o_batch, int o_seq_len, int o_k) { - return lift_to_parallel_with_degrees(unpar_k_shape, o_sum, o_eq, FFOrdered{o_batch, o_seq_len, o_k}); + auto make_k = [&](int o_sum, + int o_eq, + int o_batch, + int o_seq_len, + int o_k) { + return lift_to_parallel_with_degrees( + unpar_k_shape, o_sum, o_eq, FFOrdered{o_batch, o_seq_len, o_k}); }; - auto make_v = [&](int o_sum, int o_eq, int o_batch, int o_seq_len, int o_v) { - return lift_to_parallel_with_degrees(unpar_v_shape, o_sum, o_eq, FFOrdered{o_batch, o_seq_len, o_v}); + auto make_v = [&](int o_sum, + int o_eq, + int o_batch, + int o_seq_len, + int o_v) { + return lift_to_parallel_with_degrees( + unpar_v_shape, o_sum, o_eq, FFOrdered{o_batch, o_seq_len, o_v}); }; - auto make_o = [&](int o_sum, int o_eq, int o_batch, int o_seq_len, int o_o) { - return lift_to_parallel_with_degrees(unpar_o_shape, o_sum, o_eq, FFOrdered{o_batch, o_seq_len, o_o}); + auto make_o = [&](int o_sum, + int o_eq, + int o_batch, + int o_seq_len, + int o_o) { + return lift_to_parallel_with_degrees( + unpar_o_shape, o_sum, o_eq, FFOrdered{o_batch, o_seq_len, o_o}); }; auto make_w = [&](int o_sum, int o_eq, int o_e, int o_h) { - return lift_to_parallel_with_degrees(unpar_w_shape, o_sum, o_eq, FFOrdered{o_e, o_h}); + return lift_to_parallel_with_degrees( + unpar_w_shape, o_sum, o_eq, FFOrdered{o_e, o_h}); }; SUBCASE("data parallelism") { @@ -187,13 +211,17 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelTensorShape k = make_k(1, 1, o_b, 1, 1); ParallelTensorShape v = make_v(1, 1, o_b, 1, 1); - tl::expected result_o = get_output_shape(attrs, q, k, v); - tl::expected correct_o = make_o(1, 1, o_b, 1, 1); + tl::expected result_o = + get_output_shape(attrs, q, k, v); + tl::expected correct_o = + make_o(1, 1, o_b, 1, 1); CHECK(result_o == correct_o); - tl::expected result_w = get_weights_shape(attrs, q, k, v); - tl::expected correct_w = make_w(1, o_b, 1, 1); + tl::expected result_w = + get_weights_shape(attrs, q, k, v); + tl::expected correct_w = + make_w(1, o_b, 1, 1); CHECK(result_w == correct_w); } @@ -204,13 +232,17 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelTensorShape k = make_k(1, o_h, 1, 1, 1); ParallelTensorShape v = make_v(1, o_h, 1, 1, 1); - tl::expected result_o = get_output_shape(attrs, q, k, v); - tl::expected correct_o = make_o(o_h, 1, 1, 1, 1); + tl::expected result_o = + get_output_shape(attrs, q, k, v); + tl::expected correct_o = + make_o(o_h, 1, 1, 1, 1); CHECK(result_o == correct_o); - tl::expected result_w = get_weights_shape(attrs, q, k, v); - tl::expected correct_w = make_w(1, 1, 1, o_h); + tl::expected result_w = + get_weights_shape(attrs, q, k, v); + tl::expected correct_w = + make_w(1, 1, 1, o_h); CHECK(result_w == correct_w); } @@ -222,13 +254,17 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelTensorShape k = make_k(1, o_h, o_b, 1, 1); ParallelTensorShape v = make_v(1, o_h, o_b, 1, 1); - tl::expected result_o = get_output_shape(attrs, q, k, v); - tl::expected correct_o = make_o(o_h, 1, o_b, 1, 1); + tl::expected result_o = + get_output_shape(attrs, q, k, v); + tl::expected correct_o = + make_o(o_h, 1, o_b, 1, 1); CHECK(result_o == correct_o); - tl::expected result_w = get_weights_shape(attrs, q, k, v); - tl::expected correct_w = make_w(1, o_b, 1, o_h); + tl::expected result_w = + get_weights_shape(attrs, q, k, v); + tl::expected correct_w = + make_w(1, o_b, 1, o_h); CHECK(result_w == correct_w); } diff --git a/lib/op-attrs/test/src/test_batch_matmul.cc b/lib/op-attrs/test/src/test_batch_matmul.cc index 66bd0b9edc..f48478be10 100644 --- a/lib/op-attrs/test/src/test_batch_matmul.cc +++ b/lib/op-attrs/test/src/test_batch_matmul.cc @@ -1,6 +1,5 @@ -#include "test/utils/doctest.h" #include "op-attrs/ops/batch_matmul.h" - +#include "test/utils/doctest.h" TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_output_shape(BatchMatmulAttrs, TensorShape)") { @@ -10,44 +9,46 @@ TEST_SUITE(FF_TEST_SUITE) { size_t p = 10; BatchMatmulAttrs attrs = { - /*a_seq_length_dim=*/0, // TODO figure out if these arguments are still relevant - /*b_seq_length_dim=*/0, + /*a_seq_length_dim=*/0, // TODO figure out if these arguments are still + // relevant + /*b_seq_length_dim=*/0, }; TensorShape input_lhs_shape = { - TensorDims{ - FFOrdered{ - b, - n, - m, + TensorDims{ + FFOrdered{ + b, + n, + m, + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; SUBCASE("valid") { TensorShape input_rhs_shape = { - TensorDims{ - FFOrdered{ - b, - m, - p, + TensorDims{ + FFOrdered{ + b, + m, + p, + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; - tl::expected result = get_output_shape(attrs, input_lhs_shape, input_rhs_shape); + tl::expected result = + get_output_shape(attrs, input_lhs_shape, input_rhs_shape); tl::expected correct_output_shape = TensorShape{ - TensorDims{ - FFOrdered{ - b, - n, - p, + TensorDims{ + FFOrdered{ + b, + n, + p, + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; CHECK(result == correct_output_shape); @@ -55,34 +56,36 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("mismatched b") { TensorShape input_rhs_shape = { - TensorDims{ - FFOrdered{ - b + 1, - m, - p, + TensorDims{ + FFOrdered{ + b + 1, + m, + p, + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; - tl::expected result = get_output_shape(attrs, input_lhs_shape, input_rhs_shape); + tl::expected result = + get_output_shape(attrs, input_lhs_shape, input_rhs_shape); CHECK(!result.has_value()); } SUBCASE("mismatched m") { TensorShape input_rhs_shape = { - TensorDims{ - FFOrdered{ - b, - m + 1, - p, + TensorDims{ + FFOrdered{ + b, + m + 1, + p, + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; - tl::expected result = get_output_shape(attrs, input_lhs_shape, input_rhs_shape); + tl::expected result = + get_output_shape(attrs, input_lhs_shape, input_rhs_shape); CHECK(!result.has_value()); } @@ -98,135 +101,166 @@ TEST_SUITE(FF_TEST_SUITE) { size_t p = 7 * 7; int o_p = 7; int o_sum = 11; - + BatchMatmulAttrs attrs = { - /*a_seq_length_dim=*/0, // TODO figure out if these arguments are still relevant - /*b_seq_length_dim=*/0, + /*a_seq_length_dim=*/0, // TODO figure out if these arguments are still + // relevant + /*b_seq_length_dim=*/0, }; auto make_lhs = [&](int o_sum, int o_eq, int o_b, int o_n, int o_m) { return ParallelTensorShape{ - ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{b, o_b}, - ShardParallelDim{n, o_n}, - ShardParallelDim{m, o_m}, + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{b, o_b}, + ShardParallelDim{n, o_n}, + ShardParallelDim{m, o_m}, + }, + ReplicaParallelDimSet{ + o_sum, + o_eq, + }, }, - ReplicaParallelDimSet{ - o_sum, - o_eq, - }, - }, - DataType::FLOAT, + DataType::FLOAT, }; }; auto make_rhs = [&](int o_sum, int o_eq, int o_b, int o_m, int o_p) { return ParallelTensorShape{ - ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{b, o_b}, - ShardParallelDim{m, o_m}, - ShardParallelDim{p, o_p}, + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{b, o_b}, + ShardParallelDim{m, o_m}, + ShardParallelDim{p, o_p}, + }, + ReplicaParallelDimSet{ + o_sum, + o_eq, + }, }, - ReplicaParallelDimSet{ - o_sum, - o_eq, - }, - }, - DataType::FLOAT, + DataType::FLOAT, }; }; auto make_output = [&](int o_sum, int o_eq, int o_b, int o_n, int o_p) { return ParallelTensorShape{ - ParallelTensorDims{ - FFOrdered{ - ShardParallelDim{b, o_b}, - ShardParallelDim{n, o_n}, - ShardParallelDim{p, o_p}, + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{b, o_b}, + ShardParallelDim{n, o_n}, + ShardParallelDim{p, o_p}, + }, + ReplicaParallelDimSet{ + o_sum, + o_eq, + }, }, - ReplicaParallelDimSet{ - o_sum, - o_eq, - }, - }, - DataType::FLOAT, + DataType::FLOAT, }; }; SUBCASE("data parallel") { - tl::expected result = get_output_shape(attrs, make_lhs(1, 1, o_b, 1, 1), make_rhs(1, 1, o_b, 1, 1)); - tl::expected correct = make_output(1, 1, o_b, 1, 1); + tl::expected result = get_output_shape( + attrs, make_lhs(1, 1, o_b, 1, 1), make_rhs(1, 1, o_b, 1, 1)); + tl::expected correct = + make_output(1, 1, o_b, 1, 1); CHECK(result == correct); } SUBCASE("n parallel") { - tl::expected result = get_output_shape(attrs, make_lhs(1, 1, 1, o_n, 1), make_rhs(1, o_n, 1, 1, 1)); - tl::expected correct = make_output(1, 1, 1, o_n, 1); + tl::expected result = get_output_shape( + attrs, make_lhs(1, 1, 1, o_n, 1), make_rhs(1, o_n, 1, 1, 1)); + tl::expected correct = + make_output(1, 1, 1, o_n, 1); CHECK(result == correct); } SUBCASE("p parallel") { - tl::expected result = get_output_shape(attrs, make_lhs(1, o_p, 1, 1, 1), make_rhs(1, 1, 1, 1, o_p)); - tl::expected correct = make_output(1, 1, 1, 1, o_p); + tl::expected result = get_output_shape( + attrs, make_lhs(1, o_p, 1, 1, 1), make_rhs(1, 1, 1, 1, o_p)); + tl::expected correct = + make_output(1, 1, 1, 1, o_p); CHECK(result == correct); } SUBCASE("reduction parallel") { - tl::expected result = get_output_shape(attrs, make_lhs(1, 1, 1, 1, o_m), make_rhs(1, 1, 1, o_m, 1)); - tl::expected correct = make_output(o_m, 1, 1, 1, 1); + tl::expected result = get_output_shape( + attrs, make_lhs(1, 1, 1, 1, o_m), make_rhs(1, 1, 1, o_m, 1)); + tl::expected correct = + make_output(o_m, 1, 1, 1, 1); CHECK(result == correct); } SUBCASE("propagate reduction lhs") { - tl::expected result = get_output_shape(attrs, make_lhs(o_sum, 1, 1, 1, 1), make_rhs(1, o_sum, 1, 1, 1)); - tl::expected correct = make_output(o_sum, 1, 1, 1, 1); + tl::expected result = get_output_shape( + attrs, make_lhs(o_sum, 1, 1, 1, 1), make_rhs(1, o_sum, 1, 1, 1)); + tl::expected correct = + make_output(o_sum, 1, 1, 1, 1); CHECK(result == correct); } SUBCASE("propagate reduction rhs") { - tl::expected result = get_output_shape(attrs, make_lhs(1, o_sum, 1, 1, 1), make_rhs(o_sum, 1, 1, 1, 1)); - tl::expected correct = make_output(o_sum, 1, 1, 1, 1); + tl::expected result = get_output_shape( + attrs, make_lhs(1, o_sum, 1, 1, 1), make_rhs(o_sum, 1, 1, 1, 1)); + tl::expected correct = + make_output(o_sum, 1, 1, 1, 1); CHECK(result == correct); } SUBCASE("reduction lhs & reduction rhs") { - tl::expected result = get_output_shape(attrs, make_lhs(o_sum, o_sum, 1, 1, 1), make_rhs(o_sum, o_sum, 1, 1, 1)); - tl::expected correct = make_output(o_sum * o_sum, 1, 1, 1, 1); + tl::expected result = + get_output_shape(attrs, + make_lhs(o_sum, o_sum, 1, 1, 1), + make_rhs(o_sum, o_sum, 1, 1, 1)); + tl::expected correct = + make_output(o_sum * o_sum, 1, 1, 1, 1); CHECK(result == correct); } SUBCASE("reduction lhs & rhs (invalid)") { - tl::expected result = get_output_shape(attrs, make_lhs(o_sum, 1, 1, 1, 1), make_rhs(o_sum, 1, 1, 1, 1)); + tl::expected result = get_output_shape( + attrs, make_lhs(o_sum, 1, 1, 1, 1), make_rhs(o_sum, 1, 1, 1, 1)); - CHECK_MESSAGE(!result.has_value(), "Unexpected successful value: ", result); + CHECK_MESSAGE( + !result.has_value(), "Unexpected successful value: ", result); } SUBCASE("reduction lhs & n") { - tl::expected result = get_output_shape(attrs, make_lhs(o_sum, 1, 1, o_n, 1), make_rhs(1, o_sum * o_n, 1, 1, 1)); - tl::expected correct = make_output(o_sum, 1, 1, o_n, 1); + tl::expected result = + get_output_shape(attrs, + make_lhs(o_sum, 1, 1, o_n, 1), + make_rhs(1, o_sum * o_n, 1, 1, 1)); + tl::expected correct = + make_output(o_sum, 1, 1, o_n, 1); CHECK(result == correct); } SUBCASE("reduction lhs & reduction rhs & n") { - tl::expected result = get_output_shape(attrs, make_lhs(o_sum, o_sum, 1, o_n, 1), make_rhs(o_sum, o_sum * o_n, 1, 1, 1)); - tl::expected correct = make_output(o_sum * o_sum, 1, 1, o_n, 1); + tl::expected result = + get_output_shape(attrs, + make_lhs(o_sum, o_sum, 1, o_n, 1), + make_rhs(o_sum, o_sum * o_n, 1, 1, 1)); + tl::expected correct = + make_output(o_sum * o_sum, 1, 1, o_n, 1); CHECK(result == correct); } SUBCASE("reduction lhs & reduction rhs & n & m") { - tl::expected result = get_output_shape(attrs, make_lhs(o_sum, o_sum, 1, o_n, o_m), make_rhs(o_sum, o_sum * o_n, 1, o_m, 1)); - tl::expected correct = make_output(o_sum * o_sum * o_m, 1, 1, o_n, 1); + tl::expected result = + get_output_shape(attrs, + make_lhs(o_sum, o_sum, 1, o_n, o_m), + make_rhs(o_sum, o_sum * o_n, 1, o_m, 1)); + tl::expected correct = + make_output(o_sum * o_sum * o_m, 1, 1, o_n, 1); CHECK(result == correct); } diff --git a/lib/op-attrs/test/src/test_element_binary.cc b/lib/op-attrs/test/src/test_element_binary.cc index fa0841b732..b1aedbf6b5 100644 --- a/lib/op-attrs/test/src/test_element_binary.cc +++ b/lib/op-attrs/test/src/test_element_binary.cc @@ -1,6 +1,6 @@ +#include "op-attrs/ops/element_binary.h" #include "op-attrs/parallel_tensor_shape.h" #include "test/utils/doctest.h" -#include "op-attrs/ops/element_binary.h" TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("EWAdd shape inference") { @@ -9,27 +9,28 @@ TEST_SUITE(FF_TEST_SUITE) { size_t d3 = 24; ElementBinaryAttrs attrs = ElementBinaryAttrs{ - OperatorType::EW_ADD, - DataType::FLOAT, - /*should_broadcast_lhs=*/false, - /*should_broadcast_rhs=*/false, + OperatorType::EW_ADD, + DataType::FLOAT, + /*should_broadcast_lhs=*/false, + /*should_broadcast_rhs=*/false, }; TensorShape input_lhs = TensorShape{ - TensorDims{ - FFOrdered{ - d1, - d2, - d3, + TensorDims{ + FFOrdered{ + d1, + d2, + d3, + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; TensorShape input_rhs = input_lhs; SUBCASE("correct") { - tl::expected result = get_output_shape(attrs, input_lhs, input_rhs); + tl::expected result = + get_output_shape(attrs, input_lhs, input_rhs); tl::expected correct = input_lhs; CHECK(result == correct); @@ -39,9 +40,12 @@ TEST_SUITE(FF_TEST_SUITE) { TensorShape incorrect_rhs = input_lhs; dim_at_idx(incorrect_rhs, ff_dim_t{0}) += 1; - tl::expected result = get_output_shape(attrs, input_lhs, incorrect_rhs); + tl::expected result = + get_output_shape(attrs, input_lhs, incorrect_rhs); - CHECK_MESSAGE(!result.has_value(), "Unexpected successful result: ", result.error()); + CHECK_MESSAGE(!result.has_value(), + "Unexpected successful result: ", + result.error()); } } @@ -51,38 +55,54 @@ TEST_SUITE(FF_TEST_SUITE) { size_t d3 = 24; ElementBinaryAttrs attrs = ElementBinaryAttrs{ - OperatorType::EW_ADD, - DataType::FLOAT, - /*should_broadcast_lhs=*/false, - /*should_broadcast_rhs=*/false, + OperatorType::EW_ADD, + DataType::FLOAT, + /*should_broadcast_lhs=*/false, + /*should_broadcast_rhs=*/false, }; TensorShape unpar_lhs = TensorShape{ - TensorDims{ - FFOrdered{ - d1, - d2, - d3, + TensorDims{ + FFOrdered{ + d1, + d2, + d3, + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; TensorShape unpar_rhs = unpar_lhs; - tl::expected result_unpar_output = get_output_shape(attrs, unpar_lhs, unpar_rhs); + tl::expected result_unpar_output = + get_output_shape(attrs, unpar_lhs, unpar_rhs); REQUIRE(result_unpar_output.has_value()); TensorShape unpar_output = result_unpar_output.value(); - auto make_lhs = [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o_1, int o_2, int o_3) { - return lift_to_parallel_with_degrees(unpar_lhs, o_sum, o_eq, FFOrdered{o_1, o_2, o_3}); + auto make_lhs = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_1, + int o_2, + int o_3) { + return lift_to_parallel_with_degrees( + unpar_lhs, o_sum, o_eq, FFOrdered{o_1, o_2, o_3}); }; - auto make_rhs = [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o_1, int o_2, int o_3) { - return lift_to_parallel_with_degrees(unpar_rhs, o_sum, o_eq, FFOrdered{o_1, o_2, o_3}); + auto make_rhs = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_1, + int o_2, + int o_3) { + return lift_to_parallel_with_degrees( + unpar_rhs, o_sum, o_eq, FFOrdered{o_1, o_2, o_3}); }; - auto make_output = [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o_1, int o_2, int o_3) { - return lift_to_parallel_with_degrees(unpar_output, o_sum, o_eq, FFOrdered{o_1, o_2, o_3}); + auto make_output = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_1, + int o_2, + int o_3) { + return lift_to_parallel_with_degrees( + unpar_output, o_sum, o_eq, FFOrdered{o_1, o_2, o_3}); }; SUBCASE("data parallelism") { @@ -90,8 +110,10 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelTensorShape input_lhs = make_lhs(1, 1, degree, 1, 1); ParallelTensorShape input_rhs = make_rhs(1, 1, degree, 1, 1); - tl::expected result = get_output_shape(attrs, input_lhs, input_rhs); - tl::expected correct = make_output(1, 1, degree, 1, 1); + tl::expected result = + get_output_shape(attrs, input_lhs, input_rhs); + tl::expected correct = + make_output(1, 1, degree, 1, 1); CHECK(result == correct); } @@ -101,8 +123,10 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelTensorShape input_lhs = make_lhs(SumDegree{degree}, 1, 1, 1, 1); ParallelTensorShape input_rhs = make_rhs(SumDegree{degree}, 1, 1, 1, 1); - tl::expected result = get_output_shape(attrs, input_lhs, input_rhs); - tl::expected correct = make_output(SumDegree{degree}, 1, 1, 1, 1); + tl::expected result = + get_output_shape(attrs, input_lhs, input_rhs); + tl::expected correct = + make_output(SumDegree{degree}, 1, 1, 1, 1); CHECK(result == correct); } @@ -110,11 +134,16 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("invalid discard copy parallelism") { int degree = 4; - ParallelTensorShape input_lhs = make_lhs(1, DiscardCopyDegree{degree}, 1, 1, 1); - ParallelTensorShape input_rhs = make_rhs(1, DiscardCopyDegree{degree}, 1, 1, 1); - tl::expected result = get_output_shape(attrs, input_lhs, input_rhs); + ParallelTensorShape input_lhs = + make_lhs(1, DiscardCopyDegree{degree}, 1, 1, 1); + ParallelTensorShape input_rhs = + make_rhs(1, DiscardCopyDegree{degree}, 1, 1, 1); + tl::expected result = + get_output_shape(attrs, input_lhs, input_rhs); - CHECK_MESSAGE(!result.has_value(), "Unexpected successful result: ", result.error()); + CHECK_MESSAGE(!result.has_value(), + "Unexpected successful result: ", + result.error()); } SUBCASE("invalid mismatched parallelism degrees") { @@ -122,9 +151,12 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelTensorShape input_lhs = make_lhs(1, 1, 1, degree, 1); ParallelTensorShape input_rhs = make_rhs(1, 1, 1, 1, degree); - tl::expected result = get_output_shape(attrs, input_lhs, input_rhs); + tl::expected result = + get_output_shape(attrs, input_lhs, input_rhs); - CHECK_MESSAGE(!result.has_value(), "Unexpected successful result: ", result.error()); + CHECK_MESSAGE(!result.has_value(), + "Unexpected successful result: ", + result.error()); } } } diff --git a/lib/op-attrs/test/src/test_element_unary.cc b/lib/op-attrs/test/src/test_element_unary.cc index 2c7506dc8f..384dbc1a53 100644 --- a/lib/op-attrs/test/src/test_element_unary.cc +++ b/lib/op-attrs/test/src/test_element_unary.cc @@ -1,6 +1,6 @@ +#include "op-attrs/ops/element_unary.h" #include "op-attrs/parallel_tensor_shape.h" #include "test/utils/doctest.h" -#include "op-attrs/ops/element_unary.h" TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("ReLU shape inference") { @@ -11,23 +11,29 @@ TEST_SUITE(FF_TEST_SUITE) { ElementUnaryAttrs attrs = ElementUnaryAttrs{OperatorType::RELU}; TensorShape input = TensorShape{ - TensorDims{ - FFOrdered{ - d1, - d2, - d3, + TensorDims{ + FFOrdered{ + d1, + d2, + d3, + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; - tl::expected result = get_output_shape(attrs, input); + tl::expected result = + get_output_shape(attrs, input); tl::expected correct = input; CHECK(result == correct); - auto make_i = [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o_1, int o_2, int o_3) { - return lift_to_parallel_with_degrees(input, o_sum, o_eq, FFOrdered{o_1, o_2, o_3}); + auto make_i = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_1, + int o_2, + int o_3) { + return lift_to_parallel_with_degrees( + input, o_sum, o_eq, FFOrdered{o_1, o_2, o_3}); }; SUBCASE("partition i.e., sharding parallelism") { @@ -35,8 +41,9 @@ TEST_SUITE(FF_TEST_SUITE) { int degree2 = 8; ParallelTensorShape par_input = make_i(1, 1, degree1, 1, degree2); - tl::expected result = get_output_shape(attrs, par_input); - tl::expected correct = par_input; + tl::expected result = + get_output_shape(attrs, par_input); + tl::expected correct = par_input; CHECK(result == correct); } @@ -44,17 +51,23 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("sum degree > 1") { int degree = 2; - tl::expected result = get_output_shape(attrs, make_i(SumDegree{degree}, 1, 1, 1, 1)); + tl::expected result = + get_output_shape(attrs, make_i(SumDegree{degree}, 1, 1, 1, 1)); - CHECK_MESSAGE(!result.has_value(), "Unexpected successful result: ", result.error()); + CHECK_MESSAGE(!result.has_value(), + "Unexpected successful result: ", + result.error()); } SUBCASE("discard copy degree > 1") { int degree = 2; - tl::expected result = get_output_shape(attrs, make_i(1, DiscardCopyDegree{degree}, 1, 1, 1)); + tl::expected result = get_output_shape( + attrs, make_i(1, DiscardCopyDegree{degree}, 1, 1, 1)); - CHECK_MESSAGE(!result.has_value(), "Unexpected successful result: ", result.error()); + CHECK_MESSAGE(!result.has_value(), + "Unexpected successful result: ", + result.error()); } } } diff --git a/lib/op-attrs/test/src/test_embedding.cc b/lib/op-attrs/test/src/test_embedding.cc index f03ffdd27f..7bce6bd4d9 100644 --- a/lib/op-attrs/test/src/test_embedding.cc +++ b/lib/op-attrs/test/src/test_embedding.cc @@ -1,6 +1,6 @@ +#include "op-attrs/ops/embedding.h" #include "op-attrs/parallel_tensor_shape.h" #include "test/utils/doctest.h" -#include "op-attrs/ops/embedding.h" #include "utils/integer_conversions.h" TEST_SUITE(FF_TEST_SUITE) { @@ -8,123 +8,151 @@ TEST_SUITE(FF_TEST_SUITE) { int out_channels = 128; int num_entries = 1024; EmbeddingAttrs attrs = EmbeddingAttrs{ - /*num_entries=*/num_entries, - /*out_channels=*/out_channels, - /*aggr=*/AggregateOp::SUM, - /*data_type=*/DataType::FLOAT, + /*num_entries=*/num_entries, + /*out_channels=*/out_channels, + /*aggr=*/AggregateOp::SUM, + /*data_type=*/DataType::FLOAT, }; size_t batch_size = 48; size_t features_dim = 56; TensorShape input = { - TensorDims{ - FFOrdered{ - batch_size, - features_dim, - } - }, - DataType::INT32, + TensorDims{FFOrdered{ + batch_size, + features_dim, + }}, + DataType::INT32, }; TensorShape output = TensorShape{ - TensorDims{ - FFOrdered{ - batch_size, - size_t_from_int(out_channels), + TensorDims{ + FFOrdered{ + batch_size, + size_t_from_int(out_channels), + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; TensorShape weights = TensorShape{ - TensorDims{ - FFOrdered{ - size_t_from_int(num_entries), - size_t_from_int(out_channels), + TensorDims{ + FFOrdered{ + size_t_from_int(num_entries), + size_t_from_int(out_channels), + }, }, - }, - DataType::FLOAT, + DataType::FLOAT, }; // get_output_shape { - tl::expected output_result = get_output_shape(attrs, input); + tl::expected output_result = + get_output_shape(attrs, input); tl::expected output_correct = output; CHECK(output_result == output_correct); } // get_weights_shape { - tl::expected weight_result = get_weights_shape(attrs, input); + tl::expected weight_result = + get_weights_shape(attrs, input); tl::expected weight_correct = weights; CHECK(weight_result == weight_correct); } - auto make_input = [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o_batch, int o_features) { - return lift_to_parallel_with_degrees(input, o_sum, o_eq, FFOrdered{o_batch, o_features}); + auto make_input = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_batch, + int o_features) { + return lift_to_parallel_with_degrees( + input, o_sum, o_eq, FFOrdered{o_batch, o_features}); }; - auto make_output = [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o_batch, int o_outchannels) { - return lift_to_parallel_with_degrees(output, o_sum, o_eq, FFOrdered{o_batch, o_outchannels}); + auto make_output = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_batch, + int o_outchannels) { + return lift_to_parallel_with_degrees( + output, o_sum, o_eq, FFOrdered{o_batch, o_outchannels}); }; - auto make_weights = [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o_entries, int o_outchannels) { - return lift_to_parallel_with_degrees(weights, o_sum, o_eq, FFOrdered{o_entries, o_outchannels}); + auto make_weights = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_entries, + int o_outchannels) { + return lift_to_parallel_with_degrees( + weights, o_sum, o_eq, FFOrdered{o_entries, o_outchannels}); }; SUBCASE("data parallelism") { int degree = 4; - ParallelTensorShape par_input = make_input(SumDegree{1}, DiscardCopyDegree{1}, degree, 1); + ParallelTensorShape par_input = + make_input(SumDegree{1}, DiscardCopyDegree{1}, degree, 1); { - tl::expected result = get_output_shape(attrs, par_input); - tl::expected correct = make_output(SumDegree{1}, DiscardCopyDegree{1}, degree, 1); + tl::expected result = + get_output_shape(attrs, par_input); + tl::expected correct = + make_output(SumDegree{1}, DiscardCopyDegree{1}, degree, 1); CHECK(result == correct); } { - tl::expected result = get_weights_shape(attrs, par_input); - tl::expected correct = make_weights(SumDegree{1}, DiscardCopyDegree{degree}, 1, 1); + tl::expected result = + get_weights_shape(attrs, par_input); + tl::expected correct = + make_weights(SumDegree{1}, DiscardCopyDegree{degree}, 1, 1); CHECK(result == correct); } } SUBCASE("input features parallelism") { int degree = 4; - ParallelTensorShape input = make_input(SumDegree{1}, DiscardCopyDegree{1}, 1, degree); + ParallelTensorShape input = + make_input(SumDegree{1}, DiscardCopyDegree{1}, 1, degree); { - tl::expected result = get_output_shape(attrs, input); - tl::expected correct = make_output(SumDegree{degree}, DiscardCopyDegree{1}, 1, 1); + tl::expected result = + get_output_shape(attrs, input); + tl::expected correct = + make_output(SumDegree{degree}, DiscardCopyDegree{1}, 1, 1); CHECK(result == correct); } { - tl::expected result = get_weights_shape(attrs, input); - tl::expected correct = make_weights(SumDegree{1}, DiscardCopyDegree{degree}, 1, 1); + tl::expected result = + get_weights_shape(attrs, input); + tl::expected correct = + make_weights(SumDegree{1}, DiscardCopyDegree{degree}, 1, 1); CHECK(result == correct); } } - SUBCASE("output channel shard parallelism") { - // NOTE (@lockshaw): in the current (parallel shape inference from just input tensor) representation we have to choose between - // either parallelism in the weight channel dimension or in the weight entry dimension. For now we choose to represent - // parallelism in the channel dimension, but partitioning in the entry dimension is also potentially useful as it produces - // sum parallelism in the output + // NOTE (@lockshaw): in the current (parallel shape inference from just + // input tensor) representation we have to choose between either + // parallelism in the weight channel dimension or in the weight entry + // dimension. For now we choose to represent parallelism in the channel + // dimension, but partitioning in the entry dimension is also potentially + // useful as it produces sum parallelism in the output int degree = 4; - ParallelTensorShape input = make_input(SumDegree{1}, DiscardCopyDegree{degree}, 1, 1); + ParallelTensorShape input = + make_input(SumDegree{1}, DiscardCopyDegree{degree}, 1, 1); { - tl::expected result = get_output_shape(attrs, input); - tl::expected correct = make_output(SumDegree{1}, DiscardCopyDegree{1}, 1, degree); + tl::expected result = + get_output_shape(attrs, input); + tl::expected correct = + make_output(SumDegree{1}, DiscardCopyDegree{1}, 1, degree); CHECK(result == correct); } { - tl::expected result = get_weights_shape(attrs, input); - tl::expected correct = make_weights(SumDegree{1}, DiscardCopyDegree{1}, 1, degree); + tl::expected result = + get_weights_shape(attrs, input); + tl::expected correct = + make_weights(SumDegree{1}, DiscardCopyDegree{1}, 1, degree); CHECK(result == correct); } } diff --git a/lib/pcg/include/pcg/open_dataflow_graph.h b/lib/pcg/include/pcg/open_dataflow_graph.h index 29454c4414..b3367686b3 100644 --- a/lib/pcg/include/pcg/open_dataflow_graph.h +++ b/lib/pcg/include/pcg/open_dataflow_graph.h @@ -1,19 +1,21 @@ // #ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPEN_DATAFLOW_GRAPH_H // #define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPEN_DATAFLOW_GRAPH_H -// +// // #include "utils/containers/enumerate_vector.h" // #include "utils/graph.h" // #include "pcg/dataflow_input.dtg.h" -// +// // namespace FlexFlow { -// +// // template // struct OpenDataflowGraph { // public: // OpenDataflowGraph() -// : g(OutputLabelledOpenMultiDiGraph::template create< -// UnorderedOutputLabelledOpenMultiDiGraph>()) { } -// +// : g(OutputLabelledOpenMultiDiGraph::template +// create< +// UnorderedOutputLabelledOpenMultiDiGraph>()) +// { } +// // DataflowInput add_external_input(OutputLabel const &label) { // /* size_t src_node_idx = edge_uid_ctr; */ // /* edge_uid_ctr++; */ @@ -21,46 +23,50 @@ // /* edge_uid_t edge_uid = { src_node_idx, src_port_idx }; */ // /* return MultiDiOutput{edge_uid}; */ // } -// -// std::vector add_operator(NodeLabel const &func, std::vector const &inputs, std::vector const &outputs) { +// +// std::vector add_operator(NodeLabel const &func, +// std::vector const &inputs, std::vector const +// &outputs) { // Node n = this->g.add_node(func); // for (auto const &[idx, input] : enumerate_vector(inputs)) { -// this->g.add_edge(MultiDiEdge{input.src, input.src_idx, n, this->make_port_for_idx(idx)}); +// this->g.add_edge(MultiDiEdge{input.src, input.src_idx, n, +// this->make_port_for_idx(idx)}); // } -// +// // std::vector result; // for (auto const &[idx, label] : enumerate_vector(outputs)) { // MultiDiOutput output = MultiDiOutput{n, this->make_port_for_idx(idx)}; // this->g.add_output(output, label); // result.push_back(output); // } -// +// // return result; // } -// +// // NodePort make_port_for_idx(int idx) { // if (!this->port_mapping.contains_l(idx)) { // this->port_mapping.equate(idx, this->g.add_node_port()); -// } +// } // return this->port_mapping.at_l(idx); // } -// +// // NodePort port_for_idx(int idx) const { // return this->port_mapping.at_l(idx); // } -// +// // int idx_for_port(NodePort const &p) const { -// return this->port_mapping.at_r(p); +// return this->port_mapping.at_r(p); // } -// -// OutputLabelledMultiDiGraphView const &get_raw_graph() const { +// +// OutputLabelledMultiDiGraphView const +// &get_raw_graph() const { // return this->g; // } -// +// // NodeLabel const &at(Node const &n) const { // return this->g.at(n); // } -// +// // OutputLabel const &at(MultiDiOutput const &o) const { // return this->g.at(o); // } @@ -69,7 +75,7 @@ // bidict port_mapping; // size_t edge_uid_ctr = 0; // }; -// +// // } // namespace FlexFlow -// +// // #endif diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index d5ed622a82..8c69b3a724 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -144,7 +144,8 @@ tensor_guid_t ComputationGraphBuilder::element_unary( LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; - TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); + TensorShape output_shape = + throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); return this->add_layer(layer, {input}, {}, output_shape); } @@ -161,7 +162,8 @@ tensor_guid_t ComputationGraphBuilder::element_scalar_unary( LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; - TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); + TensorShape output_shape = + throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); return this->add_layer(layer, {input}, {}, output_shape); } @@ -206,8 +208,7 @@ tensor_guid_t ComputationGraphBuilder::element_binary( LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; TensorShape output_shape = throw_if_unexpected(get_output_shape( - attrs, this->get_shape(lhs_input), this->get_shape(rhs_input)) - ); + attrs, this->get_shape(lhs_input), this->get_shape(rhs_input))); return this->add_layer(layer, {lhs_input, rhs_input}, {}, output_shape); } @@ -446,9 +447,11 @@ tensor_guid_t ComputationGraphBuilder::embedding( TensorShape input_shape = this->get_shape(input); TensorAttrs weight_attrs = make_weight_attrs( - throw_if_unexpected(get_weights_shape(attrs, input_shape)), kernel_initializer); + throw_if_unexpected(get_weights_shape(attrs, input_shape)), + kernel_initializer); - TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); + TensorShape output_shape = + throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); return this->add_layer(layer, {input}, {weight_attrs}, output_shape); } diff --git a/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc b/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc index 4fdbc6a2ff..7c42bdd904 100644 --- a/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc +++ b/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc @@ -4,22 +4,24 @@ namespace FlexFlow { -TensorAttributeValue get_attribute(ParallelTensorAttrs const &attrs, TensorAttributeKey key) { +TensorAttributeValue get_attribute(ParallelTensorAttrs const &attrs, + TensorAttributeKey key) { switch (key) { case TensorAttributeKey::DIM_SIZES: { - std::vector sizes = transform(as_vector(ff_ordered_shard_dims(attrs.shape.dims)), - [](ShardParallelDim const &d) { return d.size; }); + std::vector sizes = + transform(as_vector(ff_ordered_shard_dims(attrs.shape.dims)), + [](ShardParallelDim const &d) { return d.size; }); return TensorAttributeValue{sizes}; } case TensorAttributeKey::DIM_DEGREES: { - std::vector degrees = transform(as_vector(ff_ordered_shard_dims(attrs.shape.dims)), - [](ShardParallelDim const &d) { - return size_t_from_int(d.degree); - }); + std::vector degrees = transform( + as_vector(ff_ordered_shard_dims(attrs.shape.dims)), + [](ShardParallelDim const &d) { return size_t_from_int(d.degree); }); return TensorAttributeValue{degrees}; } default: - throw std::runtime_error(fmt::format("Unknown TensorAttributeKey {}", static_cast(key))); + throw std::runtime_error( + fmt::format("Unknown TensorAttributeKey {}", static_cast(key))); } } diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index 2da1878b07..fbaf572df1 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -6,6 +6,7 @@ #include "required_core.h" #include "type_traits_core.h" #include "utils/containers/extend_vector.h" +#include "utils/containers/vector_transform.h" #include "utils/exception.h" #include "utils/type_traits.h" #include @@ -19,7 +20,6 @@ #include #include #include -#include "utils/containers/vector_transform.h" namespace FlexFlow { diff --git a/lib/utils/include/utils/containers/vector_transform.h b/lib/utils/include/utils/containers/vector_transform.h index 13865732aa..6d13584775 100644 --- a/lib/utils/include/utils/containers/vector_transform.h +++ b/lib/utils/include/utils/containers/vector_transform.h @@ -1,13 +1,14 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VECTOR_TRANSFORM_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VECTOR_TRANSFORM_H -#include #include +#include namespace FlexFlow { template -std::vector> vector_transform(std::vector const &v, F const &f) { +std::vector> + vector_transform(std::vector const &v, F const &f) { using Out = std::invoke_result_t; std::vector result; diff --git a/lib/utils/include/utils/containers/zip_vectors.h b/lib/utils/include/utils/containers/zip_vectors.h index 84664da48f..d32e539bef 100644 --- a/lib/utils/include/utils/containers/zip_vectors.h +++ b/lib/utils/include/utils/containers/zip_vectors.h @@ -1,13 +1,14 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_VECTORS_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_VECTORS_H -#include #include +#include namespace FlexFlow { template -std::vector> zip(std::vector const &l, std::vector const &r) { +std::vector> zip(std::vector const &l, + std::vector const &r) { std::vector> result; for (int i = 0; i < std::min(l.size(), r.size()); i++) { result.push_back(std::make_pair(l.at(i), r.at(i))); diff --git a/lib/utils/include/utils/fmt.decl.h b/lib/utils/include/utils/fmt.decl.h index 11008ee9c6..04902c8240 100644 --- a/lib/utils/include/utils/fmt.decl.h +++ b/lib/utils/include/utils/fmt.decl.h @@ -58,7 +58,6 @@ struct formatter<::std::variant> : formatter<::std::string> { -> decltype(ctx.out()); }; - } // namespace fmt #endif diff --git a/lib/utils/include/utils/fmt/expected.h b/lib/utils/include/utils/fmt/expected.h index e8d7f5b22d..5edd054ebe 100644 --- a/lib/utils/include/utils/fmt/expected.h +++ b/lib/utils/include/utils/fmt/expected.h @@ -3,15 +3,13 @@ #include "fmt/format.h" #include "utils/check_fmtable.h" -#include #include +#include namespace fmt { template -struct formatter< - ::tl::expected, - Char> +struct formatter<::tl::expected, Char> /* std::enable_if_t>::value>> */ : formatter<::std::string> { template @@ -22,13 +20,13 @@ struct formatter< if (m.has_value()) { result = fmt::format("expected({})", m.value()); } else { - result = fmt::format("unexpected({})", m.error()); + result = fmt::format("unexpected({})", m.error()); } return formatter::format(result, ctx); } }; -} // namespace FlexFlow +} // namespace fmt #endif diff --git a/lib/utils/src/utils/integer_conversions.cc b/lib/utils/src/utils/integer_conversions.cc index e3e9983c45..34ee3109bf 100644 --- a/lib/utils/src/utils/integer_conversions.cc +++ b/lib/utils/src/utils/integer_conversions.cc @@ -10,7 +10,7 @@ size_t size_t_from_int(int x) { } int int_from_size_t(size_t x) { - assert (x < std::numeric_limits::max()); + assert(x < std::numeric_limits::max()); return static_cast(x); } diff --git a/lib/utils/test/common/include/test/utils/doctest.h b/lib/utils/test/common/include/test/utils/doctest.h index db6d2a3f3b..ff7683dbcd 100644 --- a/lib/utils/test/common/include/test/utils/doctest.h +++ b/lib/utils/test/common/include/test/utils/doctest.h @@ -1,12 +1,12 @@ #include "doctest/doctest.h" #include "utils/containers.decl.h" +#include "utils/fmt/expected.h" +#include #include +#include #include #include #include -#include -#include "utils/fmt/expected.h" -#include using namespace FlexFlow; From 7e288ec62d28a6d174da1254e9f7ba3224b4c699 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Mon, 3 Jun 2024 13:43:10 -0700 Subject: [PATCH 33/43] Fix build errors --- .proj.toml | 2 +- flake.lock | 7 +- flake.nix | 2 +- lib/kernels/include/kernels/array_shape.h | 5 +- lib/kernels/include/kernels/cast_kernels.h | 2 +- lib/kernels/include/kernels/conv_2d_kernels.h | 2 +- .../include/kernels/element_unary_kernels.h | 22 ++- lib/kernels/include/kernels/legion_dim.h | 8 +- .../include/kernels/legion_dim_t.dtg.h | 54 +++++++ .../include/kernels/legion_dim_t.struct.toml | 14 ++ lib/kernels/include/kernels/pool_2d_kernels.h | 2 +- lib/kernels/include/kernels/reduce_kernels.h | 2 +- .../include/kernels/transpose_kernels.h | 2 +- lib/kernels/src/cuda/ops/concat_kernels.cu | 11 +- .../src/cuda/ops/element_binary_kernels.cu | 2 - .../src/cuda/ops/element_unary_kernels.cu | 144 +++++++++++------- lib/kernels/src/cuda/ops/gather_kernels.cu | 6 +- lib/kernels/src/cuda/ops/linear_kernels.cu | 5 +- lib/kernels/src/cuda/ops/reduce_kernels.cu | 4 +- lib/kernels/src/cuda/ops/transpose_kernels.cu | 8 +- lib/kernels/src/kernels/legion_dim_t.dtg.cc | 69 +++++++++ .../op-attrs/ops/transpose_attrs.dtg.h | 9 +- .../op-attrs/ops/transpose_attrs.struct.toml | 4 +- .../src/op-attrs/ops/transpose_attrs.dtg.cc | 21 +-- .../test/substitution-generator/json.cc | 2 +- .../operator_pattern/get_attribute.cc | 4 - 26 files changed, 289 insertions(+), 124 deletions(-) create mode 100644 lib/kernels/include/kernels/legion_dim_t.dtg.h create mode 100644 lib/kernels/include/kernels/legion_dim_t.struct.toml create mode 100644 lib/kernels/src/kernels/legion_dim_t.dtg.cc diff --git a/.proj.toml b/.proj.toml index 6b97ff4a37..43f1522186 100644 --- a/.proj.toml +++ b/.proj.toml @@ -8,7 +8,7 @@ build_targets = [ "op-attrs", "kernels", "pcg", - "substitutions", + # "substitutions", # "compiler", ] test_targets = [ diff --git a/flake.lock b/flake.lock index afcfa9aa21..83750e29e7 100644 --- a/flake.lock +++ b/flake.lock @@ -43,16 +43,15 @@ ] }, "locked": { - "lastModified": 1716747446, - "narHash": "sha256-mn3br/KFBtv4c4ZLHR1ZIqFeM1p93rcfHivselz+Nr4=", + "lastModified": 1717446123, + "narHash": "sha256-KHdIUG5LWJn5jLLbaCLifzanffn9GfDrf7HU3VYu0Iw=", "owner": "lockshaw", "repo": "proj", - "rev": "62839e7ac51dc16ddd05a5e174e0590ea85afc65", + "rev": "1dd210d3ed69392f221721664dff23979872eb6f", "type": "github" }, "original": { "owner": "lockshaw", - "ref": "dtgen", "repo": "proj", "type": "github" } diff --git a/flake.nix b/flake.nix index e0c1b11b7a..2dc005b113 100644 --- a/flake.nix +++ b/flake.nix @@ -18,7 +18,7 @@ flake-utils.url = "github:numtide/flake-utils"; proj-repo = { - url = "github:lockshaw/proj/dtgen"; + url = "github:lockshaw/proj"; inputs.nixpkgs.follows = "nixpkgs"; inputs.flake-utils.follows = "flake-utils"; }; diff --git a/lib/kernels/include/kernels/array_shape.h b/lib/kernels/include/kernels/array_shape.h index 15f14f8757..7a37299b24 100644 --- a/lib/kernels/include/kernels/array_shape.h +++ b/lib/kernels/include/kernels/array_shape.h @@ -7,6 +7,7 @@ #include "utils/visitable.h" #include #include +#include namespace FlexFlow { @@ -41,9 +42,9 @@ struct ArrayShape { std::optional at_maybe(std::size_t) const; ArrayShape reversed_dim_order() const; - ArrayShape sub_shape(std::optional start, - std::optional end) const; + ArrayShape sub_shape(std::optional> start, + std::optional> end) const; public: LegionTensorDims dims; }; diff --git a/lib/kernels/include/kernels/cast_kernels.h b/lib/kernels/include/kernels/cast_kernels.h index 4e6878e318..96f9aadd52 100644 --- a/lib/kernels/include/kernels/cast_kernels.h +++ b/lib/kernels/include/kernels/cast_kernels.h @@ -4,7 +4,7 @@ #include "device.h" #include "kernels/accessor.h" #include "kernels/ff_handle.h" -#include "op-attrs/activation.h" +#include "op-attrs/activation.dtg.h" namespace FlexFlow { namespace Kernels { diff --git a/lib/kernels/include/kernels/conv_2d_kernels.h b/lib/kernels/include/kernels/conv_2d_kernels.h index b646c4b7cb..0a93125367 100644 --- a/lib/kernels/include/kernels/conv_2d_kernels.h +++ b/lib/kernels/include/kernels/conv_2d_kernels.h @@ -4,7 +4,7 @@ #include "device.h" #include "kernels/accessor.h" #include "kernels/ff_handle.h" -#include "op-attrs/activation.h" +#include "op-attrs/activation.dtg.h" #include "utils/visitable.h" namespace FlexFlow { diff --git a/lib/kernels/include/kernels/element_unary_kernels.h b/lib/kernels/include/kernels/element_unary_kernels.h index dedfbb01ef..5044b0cdb2 100644 --- a/lib/kernels/include/kernels/element_unary_kernels.h +++ b/lib/kernels/include/kernels/element_unary_kernels.h @@ -24,18 +24,34 @@ namespace ElementUnary { ElementUnaryPerDeviceState init_kernel(ArrayShape const &input_shape, ArrayShape const &output_shape, - ElementUnaryUnifiedAttrs const &attrs); + ElementUnaryAttrs const &attrs); void forward_kernel(ffStream_t stream, ElementUnaryPerDeviceState const &device_state, - ElementUnaryUnifiedAttrs const &attrs, + ElementUnaryAttrs const &attrs, PerDeviceFFHandle &handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output); +void forward_kernel(ffStream_t stream, + ElementUnaryPerDeviceState const &device_state, + ElementScalarUnaryAttrs const &attrs, + PerDeviceFFHandle &handle, + GenericTensorAccessorR const &input, + GenericTensorAccessorW const &output); + +void backward_kernel(ffStream_t stream, + ElementUnaryPerDeviceState const &device_state, + ElementUnaryAttrs const &attrs, + PerDeviceFFHandle &handle, + GenericTensorAccessorR const &input, + GenericTensorAccessorW const &input_grad, + GenericTensorAccessorR const &output, + GenericTensorAccessorR const &output_grad); + void backward_kernel(ffStream_t stream, ElementUnaryPerDeviceState const &device_state, - ElementUnaryUnifiedAttrs const &attrs, + ElementScalarUnaryAttrs const &attrs, PerDeviceFFHandle &handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &input_grad, diff --git a/lib/kernels/include/kernels/legion_dim.h b/lib/kernels/include/kernels/legion_dim.h index f5c1d7ccc9..b283a4c4e4 100644 --- a/lib/kernels/include/kernels/legion_dim.h +++ b/lib/kernels/include/kernels/legion_dim.h @@ -2,13 +2,13 @@ #define _FLEXFLOW_KERNELS_INCLUDE_KERNELS_LEGION_DIM_H #include "op-attrs/dim_ordered.h" -#include "utils/strong_typedef.h" +#include "kernels/legion_dim_t.dtg.h" namespace FlexFlow { -struct legion_dim_t : strong_typedef { - using strong_typedef::strong_typedef; -}; +legion_dim_t add_to_legion_dim(legion_dim_t, int); + +legion_dim_t legion_dim_from_ff_dim(ff_dim_t, int num_dimensions); template using LegionOrdered = DimOrdered; diff --git a/lib/kernels/include/kernels/legion_dim_t.dtg.h b/lib/kernels/include/kernels/legion_dim_t.dtg.h new file mode 100644 index 0000000000..622f9c240a --- /dev/null +++ b/lib/kernels/include/kernels/legion_dim_t.dtg.h @@ -0,0 +1,54 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/kernels/include/kernels/legion_dim_t.struct.toml +/* proj-data +{ + "generated_from": "f67d6e50c53539a21d69e7162cf965f4" +} +*/ + +#ifndef _FLEXFLOW_LIB_KERNELS_INCLUDE_KERNELS_LEGION_DIM_T_DTG_H +#define _FLEXFLOW_LIB_KERNELS_INCLUDE_KERNELS_LEGION_DIM_T_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include +#include +#include + +namespace FlexFlow { +struct legion_dim_t { + legion_dim_t() = delete; + legion_dim_t(int const &value); + + bool operator==(legion_dim_t const &) const; + bool operator!=(legion_dim_t const &) const; + bool operator<(legion_dim_t const &) const; + bool operator>(legion_dim_t const &) const; + bool operator<=(legion_dim_t const &) const; + bool operator>=(legion_dim_t const &) const; + int value; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::legion_dim_t const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::legion_dim_t from_json(json const &); + static void to_json(json &, FlexFlow::legion_dim_t const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(legion_dim_t const &); +std::ostream &operator<<(std::ostream &, legion_dim_t const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_KERNELS_INCLUDE_KERNELS_LEGION_DIM_T_DTG_H diff --git a/lib/kernels/include/kernels/legion_dim_t.struct.toml b/lib/kernels/include/kernels/legion_dim_t.struct.toml new file mode 100644 index 0000000000..d2afb0d73f --- /dev/null +++ b/lib/kernels/include/kernels/legion_dim_t.struct.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "legion_dim_t" + +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +[[fields]] +name = "value" +type = "int" diff --git a/lib/kernels/include/kernels/pool_2d_kernels.h b/lib/kernels/include/kernels/pool_2d_kernels.h index 96bb6eccf9..798c0507f8 100644 --- a/lib/kernels/include/kernels/pool_2d_kernels.h +++ b/lib/kernels/include/kernels/pool_2d_kernels.h @@ -3,7 +3,7 @@ #include "device.h" #include "kernels/ff_handle.h" -#include "op-attrs/activation.h" +#include "op-attrs/activation.dtg.h" #include "op-attrs/ops/pool_2d.h" #include "utils/visitable.h" diff --git a/lib/kernels/include/kernels/reduce_kernels.h b/lib/kernels/include/kernels/reduce_kernels.h index 51730fb0cd..56241b73ce 100644 --- a/lib/kernels/include/kernels/reduce_kernels.h +++ b/lib/kernels/include/kernels/reduce_kernels.h @@ -4,7 +4,7 @@ #include "array_shape.h" #include "device.h" #include "ff_handle.h" -#include "op-attrs/op.h" +#include "op-attrs/operator_type.dtg.h" namespace FlexFlow { diff --git a/lib/kernels/include/kernels/transpose_kernels.h b/lib/kernels/include/kernels/transpose_kernels.h index cb34ff6736..fa087fada3 100644 --- a/lib/kernels/include/kernels/transpose_kernels.h +++ b/lib/kernels/include/kernels/transpose_kernels.h @@ -8,7 +8,7 @@ namespace FlexFlow { struct TransposePerDeviceState { int num_dim; - req> perm; + req> perm; }; FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(TransposePerDeviceState, diff --git a/lib/kernels/src/cuda/ops/concat_kernels.cu b/lib/kernels/src/cuda/ops/concat_kernels.cu index dcf7a41a2f..68004738d2 100644 --- a/lib/kernels/src/cuda/ops/concat_kernels.cu +++ b/lib/kernels/src/cuda/ops/concat_kernels.cu @@ -25,15 +25,8 @@ void calc_blk_size(size_t &num_blocks, size_t &blk_size, ArrayShape const &shape, ff_dim_t axis) { - num_blocks = 1; - blk_size = 1; - for (int d = 0; d < shape.num_dims(); d++) { - if (d <= axis) { - blk_size *= shape[legion_dim_t(d)]; - } else { - num_blocks *= shape[legion_dim_t(d)]; - } - } + blk_size = shape.sub_shape(legion_dim_t{0}, axis).num_elements(); + num_blocks = shape.sub_shape(axis, std::nullopt).num_elements(); } void forward_kernel(cudaStream_t stream, diff --git a/lib/kernels/src/cuda/ops/element_binary_kernels.cu b/lib/kernels/src/cuda/ops/element_binary_kernels.cu index 3d548c43f9..45b4d43006 100644 --- a/lib/kernels/src/cuda/ops/element_binary_kernels.cu +++ b/lib/kernels/src/cuda/ops/element_binary_kernels.cu @@ -23,8 +23,6 @@ namespace FlexFlow { namespace Kernels { namespace ElementBinary { -using OperatorType = Op; - __global__ void elewise_binary_backward_kernel(size_t volume, float const alpha, float const beta, diff --git a/lib/kernels/src/cuda/ops/element_unary_kernels.cu b/lib/kernels/src/cuda/ops/element_unary_kernels.cu index 305e778726..06bde6a9ef 100644 --- a/lib/kernels/src/cuda/ops/element_unary_kernels.cu +++ b/lib/kernels/src/cuda/ops/element_unary_kernels.cu @@ -25,29 +25,19 @@ namespace ElementUnary { static bool use_cudnn(OperatorType op_type) { switch (op_type) { - case Op::RELU: - case Op::SIGMOID: - case Op::TANH: - case Op::ELU: + case OperatorType::RELU: + case OperatorType::SIGMOID: + case OperatorType::TANH: + case OperatorType::ELU: return true; default: return false; } } -template -T get_scalar(ElementUnaryUnifiedAttrs const &attrs) { - if (std::holds_alternative(attrs)) { - return (T)std::get(attrs).scalar; - } else { - T dummy_scalar; - return dummy_scalar; - } -} - -ElementUnaryPerDeviceState init_kernel(ArrayShape const &input_shape, +static ElementUnaryPerDeviceState init_kernel(ArrayShape const &input_shape, ArrayShape const &output_shape, - ElementUnaryUnifiedAttrs const &attrs) { + OperatorType op_type) { ffTensorDescriptor_t inputTensor; ffTensorDescriptor_t outputTensor; @@ -57,21 +47,19 @@ ElementUnaryPerDeviceState init_kernel(ArrayShape const &input_shape, checkCUDNN(cudnnCreateTensorDescriptor(&outputTensor)); checkCUDNN(cudnnCreateActivationDescriptor(&actiDesc)); - Op op_type = get_op_type(attrs); - if (use_cudnn(op_type)) { cudnnActivationMode_t mode; switch (op_type) { - case Op::SIGMOID: + case OperatorType::SIGMOID: mode = CUDNN_ACTIVATION_SIGMOID; break; - case Op::RELU: + case OperatorType::RELU: mode = CUDNN_ACTIVATION_RELU; break; - case Op::TANH: + case OperatorType::TANH: mode = CUDNN_ACTIVATION_TANH; break; - case Op::ELU: + case OperatorType::ELU: mode = CUDNN_ACTIVATION_ELU; break; default: @@ -88,52 +76,65 @@ ElementUnaryPerDeviceState init_kernel(ArrayShape const &input_shape, return {inputTensor, outputTensor, actiDesc}; } +ElementUnaryPerDeviceState init_kernel(ArrayShape const &input_shape, + ArrayShape const &output_shape, + ElementUnaryAttrs const &attrs) { + return init_kernel(input_shape, output_shape, get_op_type(attrs)); +} + +ElementUnaryPerDeviceState init_kernel(ArrayShape const &input_shape, + ArrayShape const &output_shape, + ElementScalarUnaryAttrs const &attrs) { + return init_kernel(input_shape, output_shape, get_op_type(attrs)); +} + + template __global__ void elewise_unary_forward_kernel( - coord_t volume, T const scalar, OperatorType type, T const *in, T *out) { + coord_t volume, T scalar, OperatorType type, T const *in, T *out) { CUDA_KERNEL_LOOP(i, volume) { switch (type) { - case Op::EXP: { + case OperatorType::EXP: { out[i] = (T)exp((float)in[i]); break; } - case Op::IDENTITY: { + case OperatorType::IDENTITY: { out[i] = in[i]; break; } - case Op::SCALAR_MULTIPLY: { + case OperatorType::SCALAR_MULTIPLY: { out[i] = in[i] * scalar; break; } - case Op::SCALAR_ADD: { + case OperatorType::SCALAR_ADD: { out[i] = in[i] + scalar; break; } - case Op::SCALAR_SUB: { + case OperatorType::SCALAR_SUB: { out[i] = in[i] - scalar; break; } - case Op::SCALAR_TRUE_DIV: { + case OperatorType::SCALAR_TRUE_DIV: { out[i] = in[i] / scalar; break; } - case Op::GELU: { + case OperatorType::GELU: { out[i] = (T)(in[i] * 0.5 * erfc(-in[i] * M_SQRT1_2)); break; } - case Op::RSQRT: { + case OperatorType::RSQRT: { out[i] = (T)(1.0f / sqrt((float)in[i])); break; } - case Op::POW: { + case OperatorType::POW: { out[i] = (T)(powf(in[i], scalar)); break; } - case Op::SIN: { + case OperatorType::SIN: { out[i] = (T)sin((float)in[i]); break; } - case Op::COS: { + case OperatorType::COS: { out[i] = (T)cos((float)in[i]); break; } @@ -145,7 +146,7 @@ __global__ void elewise_unary_forward_kernel( template __global__ void elewise_unary_backward_kernel(coord_t volume, - T const scalar, + T scalar, OperatorType type, T const *output, T const *output_grad, @@ -153,53 +154,53 @@ __global__ void elewise_unary_backward_kernel(coord_t volume, T *input_grad) { CUDA_KERNEL_LOOP(i, volume) { switch (type) { - case Op::EXP: { + case OperatorType::EXP: { // TODO: change to use output instead of recomputing input_grad[i] += (T)(output_grad[i] * exp((float)input[i])); break; } - case Op::IDENTITY: { + case OperatorType::IDENTITY: { input_grad[i] += output_grad[i]; break; } - case Op::SCALAR_MULTIPLY: { + case OperatorType::SCALAR_MULTIPLY: { input_grad[i] += output_grad[i] * scalar; break; } - case Op::SCALAR_ADD: { + case OperatorType::SCALAR_ADD: { input_grad[i] += output_grad[i]; break; } - case Op::SCALAR_SUB: { + case OperatorType::SCALAR_SUB: { input_grad[i] += output_grad[i]; break; } - case Op::SCALAR_TRUE_DIV: { + case OperatorType::SCALAR_TRUE_DIV: { input_grad[i] += output_grad[i] / scalar; break; } - case Op::GELU: { + case OperatorType::GELU: { input_grad[i] = (T)(output_grad[i] * (0.5 * erfc(-input[i] * M_SQRT1_2) - 0.5 * M_SQRT1_2 * input[i] * exp(-input[i] * input[i] * 0.5))); break; } - case Op::RSQRT: { + case OperatorType::RSQRT: { input_grad[i] = (T)(-0.5f * output_grad[i] * output[i] * output[i] * output[i]); break; } - case Op::POW: { + case OperatorType::POW: { input_grad[i] = (T)(output_grad[i] * scalar * powf(input[i], scalar - 1)); break; } - case Op::SIN: { + case OperatorType::SIN: { input_grad[i] += (T)(output_grad[i] * cos((float)input[i])); break; } - case Op::COS: { + case OperatorType::COS: { input_grad[i] += (T)(output_grad[i] * -sin((float)input[i])); break; } @@ -213,12 +214,12 @@ template struct ForwardKernel { void operator()(ffStream_t stream, ElementUnaryPerDeviceState const &m, - ElementUnaryUnifiedAttrs const &attrs, + OperatorType op_type, + std::optional scalar, PerDeviceFFHandle const &handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) const { checkCUDNN(cudnnSetStream(handle.dnn, stream)); - Op op_type = get_op_type(attrs); if (use_cudnn(op_type)) { float alpha = 1.0f, beta = 0.0f; checkCUDNN(cudnnActivationForward(handle.dnn, @@ -234,7 +235,7 @@ struct ForwardKernel { elewise_unary_forward_kernel> <<>>( num_elements, - get_scalar>(attrs), + static_cast>(scalar.value()), op_type, input.get(), output.get()); @@ -246,7 +247,8 @@ template struct BackwardKernel { void operator()(ffStream_t stream, ElementUnaryPerDeviceState const &m, - ElementUnaryUnifiedAttrs const &attrs, + OperatorType op_type, + std::optional scalar, PerDeviceFFHandle const &handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &input_grad, @@ -254,7 +256,6 @@ struct BackwardKernel { GenericTensorAccessorR const &output_grad) { checkCUDNN(cudnnSetStream(handle.dnn, stream)); - Op op_type = get_op_type(attrs); if (use_cudnn(op_type)) { float alpha = 1.0f; checkCUDNN(cudnnActivationBackward(handle.dnn, @@ -274,7 +275,7 @@ struct BackwardKernel { elewise_unary_backward_kernel> <<>>( num_elements, - get_scalar>(attrs), + static_cast>(scalar.value()), op_type, output.get(), output_grad.get(), @@ -286,17 +287,47 @@ struct BackwardKernel { void forward_kernel(ffStream_t stream, ElementUnaryPerDeviceState const &device_state, - ElementUnaryUnifiedAttrs const &attrs, + ElementUnaryAttrs const &attrs, + PerDeviceFFHandle const &handle, + GenericTensorAccessorR const &input, + GenericTensorAccessorW const &output) { + DataTypeDispatch1{}( + input.data_type, stream, device_state, get_op_type(attrs), std::nullopt, handle, input, output); +} + +void forward_kernel(ffStream_t stream, + ElementUnaryPerDeviceState const &device_state, + ElementScalarUnaryAttrs const &attrs, PerDeviceFFHandle const &handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) { DataTypeDispatch1{}( - input.data_type, stream, device_state, attrs, handle, input, output); + input.data_type, stream, device_state, get_op_type(attrs), attrs.scalar, handle, input, output); +} + +void backward_kernel(ffStream_t stream, + ElementUnaryPerDeviceState const &device_state, + ElementUnaryAttrs const &attrs, + PerDeviceFFHandle const &handle, + GenericTensorAccessorR const &input, + GenericTensorAccessorW const &input_grad, + GenericTensorAccessorR const &output, + GenericTensorAccessorR const &output_grad) { + DataTypeDispatch1{}(input.data_type, + stream, + device_state, + get_op_type(attrs), + std::nullopt, + handle, + input, + input_grad, + output, + output_grad); } void backward_kernel(ffStream_t stream, ElementUnaryPerDeviceState const &device_state, - ElementUnaryUnifiedAttrs const &attrs, + ElementScalarUnaryAttrs const &attrs, PerDeviceFFHandle const &handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &input_grad, @@ -305,7 +336,8 @@ void backward_kernel(ffStream_t stream, DataTypeDispatch1{}(input.data_type, stream, device_state, - attrs, + get_op_type(attrs), + attrs.scalar, handle, input, input_grad, diff --git a/lib/kernels/src/cuda/ops/gather_kernels.cu b/lib/kernels/src/cuda/ops/gather_kernels.cu index 286acf7376..4696b5022e 100644 --- a/lib/kernels/src/cuda/ops/gather_kernels.cu +++ b/lib/kernels/src/cuda/ops/gather_kernels.cu @@ -127,8 +127,8 @@ void forward_kernel(ffStream_t stream, coord_t stride = output.shape - .sub_shape(std::nullopt, legion_dim_t{m.legion_dim.value() + 1}) - .get_volume(); + .sub_shape(std::nullopt, add_to_legion_dim(m.legion_dim, 1)) + .num_elements(); coord_t output_dim_size = output.shape[m.legion_dim]; coord_t input_dim_size = input.shape[m.legion_dim]; @@ -155,7 +155,7 @@ void backward_kernel(ffStream_t stream, coord_t stride = output_grad.shape - .sub_shape(std::nullopt, legion_dim_t{m.legion_dim.value() + 1}) + .sub_shape(std::nullopt, add_to_legion_dim(m.legion_dim, 1)) .get_volume(); coord_t output_dim_size = output_grad.shape[m.legion_dim]; coord_t input_dim_size = input_grad.shape[m.legion_dim]; diff --git a/lib/kernels/src/cuda/ops/linear_kernels.cu b/lib/kernels/src/cuda/ops/linear_kernels.cu index 81ab34380e..9a36534a1b 100644 --- a/lib/kernels/src/cuda/ops/linear_kernels.cu +++ b/lib/kernels/src/cuda/ops/linear_kernels.cu @@ -253,9 +253,8 @@ void backward_kernel(cudaStream_t stream, // do nothing } else { RegularizerAttrs regularizer_attrs = m.regularizer.value(); - if (std::holds_alternative(regularizer_attrs)) { - L2RegularizerAttrs l2_attrs = - std::get(regularizer_attrs); + if (regularizer_attrs.has()) { + L2RegularizerAttrs l2_attrs = regularizer_attrs.get(); float lambda = l2_attrs.lambda; checkCUDA(cublasSgeam(m.handle.blas, CUBLAS_OP_N, diff --git a/lib/kernels/src/cuda/ops/reduce_kernels.cu b/lib/kernels/src/cuda/ops/reduce_kernels.cu index 8571219648..02a89da807 100644 --- a/lib/kernels/src/cuda/ops/reduce_kernels.cu +++ b/lib/kernels/src/cuda/ops/reduce_kernels.cu @@ -71,10 +71,10 @@ void backward_kernel(cudaStream_t stream, checkCUDNN(cudnnSetStream(m.handle.dnn, stream)); float alpha = 1.0, beta = 1.0f; switch (m.op_type) { - case Op::REDUCE_SUM: + case OperatorType::REDUCE_SUM: alpha = 1.0f; break; - case Op::REDUCE_MEAN: + case OperatorType::REDUCE_MEAN: // When the output is the average of multiple input elements // we need to scale the gradients by 1.0 / reduction_size alpha = 1.0f / m.reduction_size; diff --git a/lib/kernels/src/cuda/ops/transpose_kernels.cu b/lib/kernels/src/cuda/ops/transpose_kernels.cu index 7dac25d0c9..3b3f80944d 100644 --- a/lib/kernels/src/cuda/ops/transpose_kernels.cu +++ b/lib/kernels/src/cuda/ops/transpose_kernels.cu @@ -33,10 +33,10 @@ TransposePerDeviceState init_kernel(int num_dim, std::vector const &perm) { int const length = perm.size(); - std::vector perm_vector; + std::vector perm_vector; assert(length <= MAX_TENSOR_DIM); for (int i = 0; i < length; ++i) { - perm_vector.push_back(perm[i].value()); + perm_vector.push_back(legion_dim_from_ff_dim(perm[i], num_dim)); } return {num_dim, perm_vector}; @@ -77,7 +77,7 @@ void forward_kernel(cudaStream_t stream, info.in_strides[i] = info.in_strides[i - 1] * in_dim_size; info.out_strides[i] = info.out_strides[i - 1] * out_dim_size; } - info.perm[i] = m.perm[i]; + info.perm[i] = m.perm[i].value; } transpose_simple_kernel<< + +namespace FlexFlow { +legion_dim_t::legion_dim_t(int const &value) : value(value) {} +bool legion_dim_t::operator==(legion_dim_t const &other) const { + return std::tie(this->value) == std::tie(other.value); +} +bool legion_dim_t::operator!=(legion_dim_t const &other) const { + return std::tie(this->value) != std::tie(other.value); +} +bool legion_dim_t::operator<(legion_dim_t const &other) const { + return std::tie(this->value) < std::tie(other.value); +} +bool legion_dim_t::operator>(legion_dim_t const &other) const { + return std::tie(this->value) > std::tie(other.value); +} +bool legion_dim_t::operator<=(legion_dim_t const &other) const { + return std::tie(this->value) <= std::tie(other.value); +} +bool legion_dim_t::operator>=(legion_dim_t const &other) const { + return std::tie(this->value) >= std::tie(other.value); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::legion_dim_t const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.value) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::legion_dim_t + adl_serializer::from_json(json const &j) { + return {j.at("value").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::legion_dim_t const &v) { + j["__type"] = "legion_dim_t"; + j["value"] = v.value; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(legion_dim_t const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, legion_dim_t const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/transpose_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.dtg.h index 355c28fcdc..f4d932845f 100644 --- a/lib/op-attrs/include/op-attrs/ops/transpose_attrs.dtg.h +++ b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.dtg.h @@ -3,7 +3,7 @@ // lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml /* proj-data { - "generated_from": "87f6e4db4b66d564530994773c0ecef4" + "generated_from": "de62a505821a59c4b77197c100e204f7" } */ @@ -12,10 +12,10 @@ #include "fmt/format.h" #include "nlohmann/json.hpp" +#include "op-attrs/dim_ordered.h" #include "op-attrs/ff_dim.dtg.h" #include "op-attrs/ff_dim.h" #include "rapidcheck.h" -#include "utils/stack_vector.h" #include #include #include @@ -23,8 +23,7 @@ namespace FlexFlow { struct TransposeAttrs { TransposeAttrs() = delete; - TransposeAttrs(::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, - MAX_TENSOR_DIM> const &perm); + TransposeAttrs(::FlexFlow::FFOrdered<::FlexFlow::ff_dim_t> const &perm); bool operator==(TransposeAttrs const &) const; bool operator!=(TransposeAttrs const &) const; @@ -32,7 +31,7 @@ struct TransposeAttrs { bool operator>(TransposeAttrs const &) const; bool operator<=(TransposeAttrs const &) const; bool operator>=(TransposeAttrs const &) const; - ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM> perm; + ::FlexFlow::FFOrdered<::FlexFlow::ff_dim_t> perm; }; } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml index aab525b7e6..756091f653 100644 --- a/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml @@ -12,9 +12,9 @@ features = [ includes = [ "op-attrs/ff_dim.h", "op-attrs/ff_dim.dtg.h", - "utils/stack_vector.h", + "op-attrs/dim_ordered.h", ] [[fields]] name = "perm" -type = "::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>" +type = "::FlexFlow::FFOrdered<::FlexFlow::ff_dim_t>" diff --git a/lib/op-attrs/src/op-attrs/ops/transpose_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/transpose_attrs.dtg.cc index 78f03d0815..0a774b992e 100644 --- a/lib/op-attrs/src/op-attrs/ops/transpose_attrs.dtg.cc +++ b/lib/op-attrs/src/op-attrs/ops/transpose_attrs.dtg.cc @@ -3,20 +3,20 @@ // lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml /* proj-data { - "generated_from": "87f6e4db4b66d564530994773c0ecef4" + "generated_from": "de62a505821a59c4b77197c100e204f7" } */ #include "op-attrs/ops/transpose_attrs.dtg.h" +#include "op-attrs/dim_ordered.h" #include "op-attrs/ff_dim.dtg.h" #include "op-attrs/ff_dim.h" -#include "utils/stack_vector.h" #include namespace FlexFlow { TransposeAttrs::TransposeAttrs( - ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM> const &perm) + ::FlexFlow::FFOrdered<::FlexFlow::ff_dim_t> const &perm) : perm(perm) {} bool TransposeAttrs::operator==(TransposeAttrs const &other) const { return std::tie(this->perm) == std::tie(other.perm); @@ -42,11 +42,8 @@ namespace std { size_t hash::operator()( FlexFlow::TransposeAttrs const &x) const { size_t result = 0; - result ^= - std::hash< - ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>>{}( - x.perm) + - 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::FFOrdered<::FlexFlow::ff_dim_t>>{}(x.perm) + + 0x9e3779b9 + (result << 6) + (result >> 2); return result; } } // namespace std @@ -54,9 +51,8 @@ size_t hash::operator()( namespace nlohmann { FlexFlow::TransposeAttrs adl_serializer::from_json(json const &j) { - return {j.at("perm") - .template get<::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, - MAX_TENSOR_DIM>>()}; + return { + j.at("perm").template get<::FlexFlow::FFOrdered<::FlexFlow::ff_dim_t>>()}; } void adl_serializer::to_json( json &j, FlexFlow::TransposeAttrs const &v) { @@ -68,8 +64,7 @@ void adl_serializer::to_json( namespace rc { Gen Arbitrary::arbitrary() { return gen::construct( - gen::arbitrary< - ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>>()); + gen::arbitrary<::FlexFlow::FFOrdered<::FlexFlow::ff_dim_t>>()); } } // namespace rc diff --git a/lib/substitution-generator/test/substitution-generator/json.cc b/lib/substitution-generator/test/substitution-generator/json.cc index d12b294a2e..3b177f2bfe 100644 --- a/lib/substitution-generator/test/substitution-generator/json.cc +++ b/lib/substitution-generator/test/substitution-generator/json.cc @@ -18,7 +18,7 @@ TEST_SUITE(FF_TEST_SUITE) { Operator o; from_json(j, o); - CHECK(o.op_type == Op::EW_ADD); + CHECK(o.op_type == OperatorType::EW_ADD); CHECK(o.input.size() == 2); CHECK(o.input[0].opId == -2); CHECK(o.input[0].tsId == 0); diff --git a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc index 28b0d2e37f..e168760c3b 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc @@ -279,8 +279,6 @@ std::optional get_attribute(ReductionAttrs const &p, switch (key) { case OperatorAttributeKey::OP_TYPE: return get_op_type(p); - case OperatorAttributeKey::PARALLEL_OP_DIM: - return p.reduction_dim; case OperatorAttributeKey::PARALLEL_OP_DEGREE: return p.reduction_degree; default: @@ -307,8 +305,6 @@ std::optional get_attribute(ReplicateAttrs const &p, switch (key) { case OperatorAttributeKey::OP_TYPE: return get_op_type(p); - case OperatorAttributeKey::PARALLEL_OP_DIM: - return p.replicate_dim; case OperatorAttributeKey::PARALLEL_OP_DEGREE: return p.replicate_degree; default: From 115dd491d189e5771fd187ef6093b38569dddfea Mon Sep 17 00:00:00 2001 From: Qinghan Chen Date: Mon, 3 Jun 2024 17:13:50 -0400 Subject: [PATCH 34/43] change lcov in ci to rm dtgen coverage --- .github/workflows/per-lib-check.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/per-lib-check.yml b/.github/workflows/per-lib-check.yml index 4141b471fc..70308f990b 100644 --- a/.github/workflows/per-lib-check.yml +++ b/.github/workflows/per-lib-check.yml @@ -111,7 +111,8 @@ jobs: - name: Generate code coverage run: | lcov --capture --directory . --output-file main_coverage.info - lcov --remove main_coverage.info '/nix/store/' --output-file main_coverage.info + lcov --extract main_coverage.info 'lib/*' --output-file main_coverage.info + lcov --remove main_coverage.info 'lib/*.dtg.h' 'lib/*.dtg.cc' --output-file main_coverage.info lcov --list main_coverage.info - name: Upload code coverage From 4462af2469dc24cdd38df10b9c1f4e753bad9545 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Mon, 3 Jun 2024 14:23:53 -0700 Subject: [PATCH 35/43] Remove dtgen from coverage --- flake.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flake.lock b/flake.lock index 83750e29e7..f0fc292a5e 100644 --- a/flake.lock +++ b/flake.lock @@ -43,11 +43,11 @@ ] }, "locked": { - "lastModified": 1717446123, - "narHash": "sha256-KHdIUG5LWJn5jLLbaCLifzanffn9GfDrf7HU3VYu0Iw=", + "lastModified": 1717449667, + "narHash": "sha256-xFGnB44WadxlCa2LnlH82g1c89+7UAomVgytIewSwO0=", "owner": "lockshaw", "repo": "proj", - "rev": "1dd210d3ed69392f221721664dff23979872eb6f", + "rev": "28b37a9bd993d3de3d80695eb3834a0436c805a4", "type": "github" }, "original": { From 112421834ae708e534ce57e08cc4ed5fb06e4e87 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Mon, 3 Jun 2024 15:29:04 -0700 Subject: [PATCH 36/43] Temporarily disable substitutions build in CI --- .github/workflows/per-lib-check.yml | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/.github/workflows/per-lib-check.yml b/.github/workflows/per-lib-check.yml index 70308f990b..687515b4b7 100644 --- a/.github/workflows/per-lib-check.yml +++ b/.github/workflows/per-lib-check.yml @@ -72,13 +72,13 @@ jobs: run: | build_libs.sh kernels - - name: Build substitutions - run: | - build_libs.sh substitutions + # - name: Build substitutions + # run: | + # build_libs.sh substitutions - - name: Build compiler - run: | - build_libs.sh compiler + # - name: Build compiler + # run: | + # build_libs.sh compiler - name: Build substitution-generator run: | @@ -96,13 +96,13 @@ jobs: run: | test_libs.sh pcg - - name: Test substitutions - run: | - test_libs.sh substitutions + # - name: Test substitutions + # run: | + # test_libs.sh substitutions - - name: Test compiler - run: | - test_libs.sh compiler + # - name: Test compiler + # run: | + # test_libs.sh compiler - name: Test substitution-generator run: | From 5109ca9ee82dea67c331964e4d33f99bcfad867e Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Mon, 3 Jun 2024 15:30:13 -0700 Subject: [PATCH 37/43] Format --- lib/kernels/include/kernels/array_shape.h | 8 +++--- lib/kernels/include/kernels/legion_dim.h | 2 +- .../src/cuda/ops/element_unary_kernels.cu | 25 +++++++++++++------ lib/kernels/src/cuda/ops/gather_kernels.cu | 3 +-- 4 files changed, 25 insertions(+), 13 deletions(-) diff --git a/lib/kernels/include/kernels/array_shape.h b/lib/kernels/include/kernels/array_shape.h index 7a37299b24..1cb10e8ce7 100644 --- a/lib/kernels/include/kernels/array_shape.h +++ b/lib/kernels/include/kernels/array_shape.h @@ -6,8 +6,8 @@ #include "utils/stack_vector.h" #include "utils/visitable.h" #include -#include #include +#include namespace FlexFlow { @@ -43,8 +43,10 @@ struct ArrayShape { ArrayShape reversed_dim_order() const; - ArrayShape sub_shape(std::optional> start, - std::optional> end) const; + ArrayShape + sub_shape(std::optional> start, + std::optional> end) const; + public: LegionTensorDims dims; }; diff --git a/lib/kernels/include/kernels/legion_dim.h b/lib/kernels/include/kernels/legion_dim.h index b283a4c4e4..cf6ebfc2d4 100644 --- a/lib/kernels/include/kernels/legion_dim.h +++ b/lib/kernels/include/kernels/legion_dim.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_KERNELS_INCLUDE_KERNELS_LEGION_DIM_H #define _FLEXFLOW_KERNELS_INCLUDE_KERNELS_LEGION_DIM_H -#include "op-attrs/dim_ordered.h" #include "kernels/legion_dim_t.dtg.h" +#include "op-attrs/dim_ordered.h" namespace FlexFlow { diff --git a/lib/kernels/src/cuda/ops/element_unary_kernels.cu b/lib/kernels/src/cuda/ops/element_unary_kernels.cu index 06bde6a9ef..e37d32c325 100644 --- a/lib/kernels/src/cuda/ops/element_unary_kernels.cu +++ b/lib/kernels/src/cuda/ops/element_unary_kernels.cu @@ -36,8 +36,8 @@ static bool use_cudnn(OperatorType op_type) { } static ElementUnaryPerDeviceState init_kernel(ArrayShape const &input_shape, - ArrayShape const &output_shape, - OperatorType op_type) { + ArrayShape const &output_shape, + OperatorType op_type) { ffTensorDescriptor_t inputTensor; ffTensorDescriptor_t outputTensor; @@ -88,7 +88,6 @@ ElementUnaryPerDeviceState init_kernel(ArrayShape const &input_shape, return init_kernel(input_shape, output_shape, get_op_type(attrs)); } - template __global__ void elewise_unary_forward_kernel( coord_t volume, T scalar, OperatorType type, T const *in, T *out) { @@ -291,8 +290,14 @@ void forward_kernel(ffStream_t stream, PerDeviceFFHandle const &handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) { - DataTypeDispatch1{}( - input.data_type, stream, device_state, get_op_type(attrs), std::nullopt, handle, input, output); + DataTypeDispatch1{}(input.data_type, + stream, + device_state, + get_op_type(attrs), + std::nullopt, + handle, + input, + output); } void forward_kernel(ffStream_t stream, @@ -301,8 +306,14 @@ void forward_kernel(ffStream_t stream, PerDeviceFFHandle const &handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) { - DataTypeDispatch1{}( - input.data_type, stream, device_state, get_op_type(attrs), attrs.scalar, handle, input, output); + DataTypeDispatch1{}(input.data_type, + stream, + device_state, + get_op_type(attrs), + attrs.scalar, + handle, + input, + output); } void backward_kernel(ffStream_t stream, diff --git a/lib/kernels/src/cuda/ops/gather_kernels.cu b/lib/kernels/src/cuda/ops/gather_kernels.cu index 4696b5022e..e002cf7e71 100644 --- a/lib/kernels/src/cuda/ops/gather_kernels.cu +++ b/lib/kernels/src/cuda/ops/gather_kernels.cu @@ -126,8 +126,7 @@ void forward_kernel(ffStream_t stream, checkCUDA(get_legion_stream(&stream)); coord_t stride = - output.shape - .sub_shape(std::nullopt, add_to_legion_dim(m.legion_dim, 1)) + output.shape.sub_shape(std::nullopt, add_to_legion_dim(m.legion_dim, 1)) .num_elements(); coord_t output_dim_size = output.shape[m.legion_dim]; coord_t input_dim_size = input.shape[m.legion_dim]; From 0618cfe070f778dcf1772ff25da809480377b482 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Mon, 3 Jun 2024 15:58:41 -0700 Subject: [PATCH 38/43] Fix substitution-generator build and tests --- .proj.toml | 2 + .../include/substitution-generator/json.h | 162 +--- .../legacy_operator_type.dtg.h | 124 +++ .../legacy_operator_type.enum.toml | 95 +++ .../legacy_pm_parameter.dtg.h | 73 ++ .../legacy_pm_parameter.enum.toml | 44 ++ .../src/substitution-generator/json.cc | 25 +- .../legacy_operator_type.dtg.cc | 721 ++++++++++++++++++ .../legacy_pm_parameter.dtg.cc | 313 ++++++++ .../test/substitution-generator/json.cc | 2 +- 10 files changed, 1389 insertions(+), 172 deletions(-) create mode 100644 lib/substitution-generator/include/substitution-generator/legacy_operator_type.dtg.h create mode 100644 lib/substitution-generator/include/substitution-generator/legacy_operator_type.enum.toml create mode 100644 lib/substitution-generator/include/substitution-generator/legacy_pm_parameter.dtg.h create mode 100644 lib/substitution-generator/include/substitution-generator/legacy_pm_parameter.enum.toml create mode 100644 lib/substitution-generator/src/substitution-generator/legacy_operator_type.dtg.cc create mode 100644 lib/substitution-generator/src/substitution-generator/legacy_pm_parameter.dtg.cc diff --git a/.proj.toml b/.proj.toml index 43f1522186..b076671498 100644 --- a/.proj.toml +++ b/.proj.toml @@ -10,6 +10,7 @@ build_targets = [ "pcg", # "substitutions", # "compiler", + "substitution-generator", ] test_targets = [ "utils-tests", @@ -17,6 +18,7 @@ test_targets = [ "pcg-tests", # "substitutions-tests", # "compiler-tests", + "substitution-generator-tests", ] [cmake_flags_extra] diff --git a/lib/substitution-generator/include/substitution-generator/json.h b/lib/substitution-generator/include/substitution-generator/json.h index 54d923a378..930d9c3f3f 100644 --- a/lib/substitution-generator/include/substitution-generator/json.h +++ b/lib/substitution-generator/include/substitution-generator/json.h @@ -1,166 +1,16 @@ #ifndef _FLEXFLOW_SUBSTITUTION_LOADER_H #define _FLEXFLOW_SUBSTITUTION_LOADER_H -#include "op-attrs/operator_type.h" #include #include +#include "substitution-generator/legacy_operator_type.dtg.h" +#include "substitution-generator/legacy_pm_parameter.dtg.h" +#include namespace FlexFlow { -enum PMParameter { - PM_OP_TYPE, // AnyOp - PM_NUM_INPUTS, // AnyOp - PM_NUM_OUTPUTS, // AnyOp - PM_GROUP, // Conv2D - PM_KERNEL_H, // Conv2D, Pool2D - PM_KERNEL_W, // Conv2D, Pool2D - PM_STRIDE_H, // Conv2D, Pool2D - PM_STRIDE_W, // Conv2D, Pool2D - PM_PADDING_H, // Conv2D, Pool2D - PM_PADDING_W, // Conv2D, Pool2D - PM_ACTI, // Conv2D, Pool2D - PM_NUMDIM, // Concat, Transpose - PM_AXIS, // Concat, Split - PM_PERM, // Transpose - PM_OUTSHUFFLE, // Transpose - PM_MERGE_GCONV_COUNT, // MergeGConv - PM_AXES, // Squeeze, Unsqueeze, Reduce* - PM_KEEP_DIMS, // Reduce* - PM_EPSILON, // BatchNorm - PM_REPARTITION_DIM, // Repartition - PM_REPARTITION_DEGREE, // Repartition - PM_REPLICATE_DIM, // Replicate - PM_REPLICATE_DEGREE, // Replicate - PM_COMBINE_DIM, // Combine - PM_COMBINE_DEGREE, // Combine - PM_REDUCTION_DIM, // Reduction - PM_REDUCTION_DEGREE, // Reduction - PM_SOFTMAX_DIM, // Softmax - PM_NUM_HEADS, // MultiHeadAttention - PM_INVALID, - PM_PARALLEL_DIM, - PM_PARALLEL_DEGREE, - PM_PAD, -}; - -NLOHMANN_JSON_SERIALIZE_ENUM(PMParameter, - {{PM_INVALID, nullptr}, - {PM_OP_TYPE, "PM_OP_TYPE"}, - {PM_NUM_INPUTS, "PM_NUM_INPUTS"}, - {PM_NUM_OUTPUTS, "PM_NUM_OUTPUTS"}, - {PM_GROUP, "PM_GROUP"}, - {PM_KERNEL_H, "PM_KERNEL_H"}, - {PM_KERNEL_W, "PM_KERNEL_W"}, - {PM_STRIDE_H, "PM_STRIDE_H"}, - {PM_STRIDE_W, "PM_STRIDE_W"}, - {PM_PADDING_H, "PM_PADDING_H"}, - {PM_PADDING_W, "PM_PADDING_W"}, - {PM_ACTI, "PM_ACTI"}, - {PM_NUMDIM, "PM_NUMDIM"}, - {PM_AXIS, "PM_AXIS"}, - {PM_PERM, "PM_PERM"}, - {PM_OUTSHUFFLE, "PM_OUTSHUFFLE"}, - {PM_MERGE_GCONV_COUNT, "PM_MERGE_GCONV_COUNT"}, - {PM_AXES, "PM_AXES"}, - {PM_KEEP_DIMS, "PM_KEEP_DIMS"}, - {PM_EPSILON, "PM_EPSILON"}, - {PM_REPARTITION_DIM, "PM_REPARTITION_DIM"}, - {PM_REPARTITION_DEGREE, "PM_REPARTITION_DEGREE"}, - {PM_REPLICATE_DIM, "PM_REPLICATE_DIM"}, - {PM_REPLICATE_DEGREE, "PM_REPLICATE_DEGREE"}, - {PM_COMBINE_DIM, "PM_COMBINE_DIM"}, - {PM_COMBINE_DEGREE, "PM_COMBINE_DEGREE"}, - {PM_REDUCTION_DIM, "PM_REDUCTION_DIM"}, - {PM_REDUCTION_DEGREE, "PM_REDUCTION_DEGREE"}, - {PM_SOFTMAX_DIM, "PM_SOFTMAX_DIM"}, - {PM_NUM_HEADS, "PM_NUM_HEADS"}, - {PM_PARALLEL_DIM, "PM_PARALLEL_DIM"}, - {PM_PARALLEL_DEGREE, "PM_PARALLEL_DEGREE"}, - {PM_PAD, "PM_PAD"}}) - -NLOHMANN_JSON_SERIALIZE_ENUM( - Op, - {{OperatorType::NOOP, "OP_NOOP"}, - {OperatorType::CONV2D, "OP_CONV2D"}, - {OperatorType::DROPOUT, "OP_DROPOUT"}, - {OperatorType::LINEAR, "OP_LINEAR"}, - {OperatorType::BATCHMATMUL, "OP_BATCHMATMUL"}, - {OperatorType::POOL2D, "OP_POOL2D_MAX"}, - {OperatorType::SCALAR_MULTIPLY, "OP_SCALAR_MULTIPLY"}, - {OperatorType::SCALAR_ADD, "OP_SCALAR_ADD"}, - {OperatorType::SCALAR_FLOOR_DIV, "OP_SCALAR_FLOOR_DIV"}, - {OperatorType::SCALAR_TRUE_DIV, "OP_SCALAR_TRUE_DIV"}, - {OperatorType::SCALAR_SUB, "OP_SCALAR_SUB"}, - {OperatorType::RELU, "OP_RELU"}, - {OperatorType::IDENTITY, "OP_IDENTITY"}, - {OperatorType::SIGMOID, "OP_SIGMOID"}, - {OperatorType::TANH, "OP_TANH"}, - {OperatorType::ELU, "OP_ELU"}, - {OperatorType::FLAT, "OP_FLAT"}, - {OperatorType::SOFTMAX, "OP_SOFTMAX"}, - {OperatorType::BATCHNORM, "OP_BATCHNORM"}, - {OperatorType::CONCAT, "OP_CONCAT"}, - {OperatorType::SPLIT, "OP_SPLIT"}, - {OperatorType::EMBEDDING, "OP_EMBEDDING"}, - {OperatorType::CACHE, "OP_CACHE"}, - {OperatorType::RESHAPE, "OP_RESHAPE"}, - {OperatorType::REVERSE, "OP_REVERSE"}, - {OperatorType::TRANSPOSE, "OP_TRANSPOSE"}, - {OperatorType::EW_ADD, "OP_EW_ADD"}, - {OperatorType::EW_MUL, "OP_EW_MUL"}, - {OperatorType::MATMUL, "OP_MATMUL"}, - {OperatorType::MUL, "OP_MUL"}, - {OperatorType::ENLARGE, "OP_ENLARGE"}, - {OperatorType::SQUEEZE, "OP_SQUEEZE"}, - {OperatorType::UNSQUEEZE, "OP_UNSQUEEZE"}, - {OperatorType::EW_SUB, "OP_EW_SUB"}, - {OperatorType::EW_DIV, "OP_EW_DIV"}, - {OperatorType::EW_EQUAL, "OP_EW_EQUAL"}, - {OperatorType::EW_GREATER, "OP_EW_GREATER"}, - {OperatorType::EW_LESS, "OP_EW_LESS"}, - {OperatorType::EW_MAX, "OP_EW_MAX"}, - {OperatorType::EW_MIN, "OP_EW_MIN"}, - {OperatorType::REDUCE_ARGMAX, "OP_REDUCE_ARGMAX"}, - {OperatorType::REDUCE_ARGMIN, "OP_REDUCE_ARGMIN"}, - {OperatorType::REDUCE_MAX, "OP_REDUCE_MAX"}, - {OperatorType::REDUCE_MEAN, "OP_REDUCE_MEAN"}, - {OperatorType::REDUCE_MIN, "OP_REDUCE_MIN"}, - {OperatorType::REDUCE_PROD, "OP_REDUCE_PROD"}, - {OperatorType::REDUCE_SUM, "OP_REDUCE_SUM"}, - {OperatorType::PAD, "OP_PAD"}, - {OperatorType::SHAPE, "OP_SHAPE"}, - {OperatorType::SIZE, "OP_SIZE"}, - {OperatorType::TOPK, "OP_TOPK"}, - {OperatorType::WHERE, "OP_WHERE"}, - {OperatorType::CEIL, "OP_CEIL"}, - {OperatorType::CAST, "OP_CAST"}, - {OperatorType::EXP, "OP_EXP"}, - {OperatorType::ROUND, "OP_ROUND"}, - {OperatorType::LOG, "OP_LOG"}, - {OperatorType::LOGICAL_NOT, "OP_LOGICAL_NOT"}, - {OperatorType::SQRT, "OP_SQRT"}, - {OperatorType::SIN, "OP_SIN"}, - {OperatorType::COS, "OP_COS"}, - {OperatorType::LEAKYRELU, "OP_LEAKYRELU"}, - {OperatorType::SLICE, "OP_SLICE"}, - {OperatorType::RESIZE, "OP_RESIZE"}, - {OperatorType::PRELU, "OP_PRELU"}, - {OperatorType::GELU, "OP_GELU"}, - {OperatorType::MULTIHEAD_ATTENTION, "OP_MULTIHEAD_ATTENTION"}, - {OperatorType::FUSED, "OP_FUSED"}, - {OperatorType::RSQRT, "OP_RSQRT"}, - {OperatorType::POW, "OP_POW"}, - {OperatorType::MEAN, "OP_MEAN"}, - {OperatorType::LAYERNORM, "OP_LAYERNORM"}, - {OperatorType::REPARTITION, "OP_PARTITION"}, - {OperatorType::COMBINE, "OP_COMBINE"}, - {OperatorType::REPLICATE, "OP_REPLICATE"}, - {OperatorType::REDUCTION, "OP_REDUCE"}, - {OperatorType::PIPELINE, "OP_PIPELINE"}, - {OperatorType::FUSED_PARALLEL, "OP_FUSED_PARALLEL"}}) - struct Parameter { - PMParameter key; + LegacyPMParameter key; int value; }; void from_json(nlohmann::json const &j, Parameter &p); @@ -172,11 +22,11 @@ struct Tensor { void from_json(nlohmann::json const &j, Tensor &t); struct Operator { - OperatorType op_type; + LegacyOperatorType op_type; std::vector input; std::vector para; - std::optional at(PMParameter key) const; + std::optional at(LegacyPMParameter key) const; }; void from_json(nlohmann::json const &j, Operator &t); diff --git a/lib/substitution-generator/include/substitution-generator/legacy_operator_type.dtg.h b/lib/substitution-generator/include/substitution-generator/legacy_operator_type.dtg.h new file mode 100644 index 0000000000..a74f6ef4f3 --- /dev/null +++ b/lib/substitution-generator/include/substitution-generator/legacy_operator_type.dtg.h @@ -0,0 +1,124 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitution-generator/include/substitution-generator/legacy_operator_type.enum.toml +/* proj-data +{ + "generated_from": "d6ba52e2b0d58b7cb533dae3894b0486" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTION_GENERATOR_INCLUDE_SUBSTITUTION_GENERATOR_LEGACY_OPERATOR_TYPE_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTION_GENERATOR_INCLUDE_SUBSTITUTION_GENERATOR_LEGACY_OPERATOR_TYPE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +enum class LegacyOperatorType { + NOOP, + INPUT, + WEIGHT, + CONV2D, + DROPOUT, + LINEAR, + BATCHMATMUL, + POOL2D, + SCALAR_MULTIPLY, + SCALAR_ADD, + SCALAR_FLOOR_DIV, + SCALAR_TRUE_DIV, + SCALAR_SUB, + RELU, + IDENTITY, + SIGMOID, + TANH, + ELU, + FLAT, + SOFTMAX, + BATCHNORM, + CONCAT, + SPLIT, + EMBEDDING, + CACHE, + RESHAPE, + REVERSE, + TRANSPOSE, + EW_ADD, + EW_MUL, + MATMUL, + MUL, + ENLARGE, + SQUEEZE, + UNSQUEEZE, + EW_SUB, + EW_DIV, + EW_EQUAL, + EW_GREATER, + EW_LESS, + EW_MAX, + EW_MIN, + REDUCE_ARGMAX, + REDUCE_ARGMIN, + REDUCE_MAX, + REDUCE_MEAN, + REDUCE_MIN, + REDUCE_PROD, + REDUCE_SUM, + PAD, + SHAPE, + SIZE, + TOPK, + WHERE, + CEIL, + CAST, + EXP, + ROUND, + LOG, + LOGICAL_NOT, + SQRT, + SIN, + COS, + LEAKYRELU, + SLICE, + RESIZE, + PRELU, + GELU, + MULTIHEAD_ATTENTION, + FUSED, + RSQRT, + POW, + MEAN, + LAYERNORM, + GATHER, + BROADCAST, + REPARTITION, + COMBINE, + REPLICATE, + REDUCTION, + BATCH, + PIPELINE, + FUSED_PARALLEL +}; +std::string format_as(LegacyOperatorType); +std::ostream &operator<<(std::ostream &, LegacyOperatorType); +void to_json(::nlohmann::json &, LegacyOperatorType); +void from_json(::nlohmann::json const &, LegacyOperatorType &); +} // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::LegacyOperatorType) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_SUBSTITUTION_GENERATOR_INCLUDE_SUBSTITUTION_GENERATOR_LEGACY_OPERATOR_TYPE_DTG_H diff --git a/lib/substitution-generator/include/substitution-generator/legacy_operator_type.enum.toml b/lib/substitution-generator/include/substitution-generator/legacy_operator_type.enum.toml new file mode 100644 index 0000000000..3f0bcccf6f --- /dev/null +++ b/lib/substitution-generator/include/substitution-generator/legacy_operator_type.enum.toml @@ -0,0 +1,95 @@ +namespace = "FlexFlow" +name = "LegacyOperatorType" + +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +values = [ + { name = "NOOP", json_key = "OP_NOOP" }, + { name = "INPUT", json_key = "OP_INPUT" }, + { name = "WEIGHT", json_key = "OP_WEIGHT" }, + { name = "CONV2D", json_key = "OP_CONV2D" }, + { name = "DROPOUT", json_key = "OP_DROPOUT" }, + { name = "LINEAR", json_key = "OP_LINEAR" }, + { name = "BATCHMATMUL", json_key = "OP_BATCHMATMUL" }, + { name = "POOL2D", json_key = "OP_POOL2D" }, + { name = "SCALAR_MULTIPLY", json_key = "OP_SCALAR_MULTIPLY" }, + { name = "SCALAR_ADD", json_key = "OP_SCALAR_ADD" }, + { name = "SCALAR_FLOOR_DIV", json_key = "OP_SCALAR_FLOOR_DIV" }, + { name = "SCALAR_TRUE_DIV", json_key = "OP_SCALAR_TRUE_DIV" }, + { name = "SCALAR_SUB", json_key = "OP_SCALAR_SUB" }, + { name = "RELU", json_key = "OP_RELU" }, + { name = "IDENTITY", json_key = "OP_IDENTITY" }, + { name = "SIGMOID", json_key = "OP_SIGMOID" }, + { name = "TANH", json_key = "OP_TANH" }, + { name = "ELU", json_key = "OP_ELU" }, + { name = "FLAT", json_key = "OP_FLAT" }, + { name = "SOFTMAX", json_key = "OP_SOFTMAX" }, + { name = "BATCHNORM", json_key = "OP_BATCHNORM" }, + { name = "CONCAT", json_key = "OP_CONCAT" }, + { name = "SPLIT", json_key = "OP_SPLIT" }, + { name = "EMBEDDING", json_key = "OP_EMBEDDING" }, + { name = "CACHE", json_key = "OP_CACHE" }, + { name = "RESHAPE", json_key = "OP_RESHAPE" }, + { name = "REVERSE", json_key = "OP_REVERSE" }, + { name = "TRANSPOSE", json_key = "OP_TRANSPOSE" }, + { name = "EW_ADD", json_key = "OP_EW_ADD" }, + { name = "EW_MUL", json_key = "OP_EW_MUL" }, + { name = "MATMUL", json_key = "OP_MATMUL" }, + { name = "MUL", json_key = "OP_MUL" }, + { name = "ENLARGE", json_key = "OP_ENLARGE" }, + { name = "SQUEEZE", json_key = "OP_SQUEEZE" }, + { name = "UNSQUEEZE", json_key = "OP_UNSQUEEZE" }, + { name = "EW_SUB", json_key = "OP_EW_SUB" }, + { name = "EW_DIV", json_key = "OP_EW_DIV" }, + { name = "EW_EQUAL", json_key = "OP_EW_EQUAL" }, + { name = "EW_GREATER", json_key = "OP_EW_GREATER" }, + { name = "EW_LESS", json_key = "OP_EW_LESS" }, + { name = "EW_MAX", json_key = "OP_EW_MAX" }, + { name = "EW_MIN", json_key = "OP_EW_MIN" }, + { name = "REDUCE_ARGMAX", json_key = "OP_REDUCE_ARGMAX" }, + { name = "REDUCE_ARGMIN", json_key = "OP_REDUCE_ARGMIN" }, + { name = "REDUCE_MAX", json_key = "OP_REDUCE_MAX" }, + { name = "REDUCE_MEAN", json_key = "OP_REDUCE_MEAN" }, + { name = "REDUCE_MIN", json_key = "OP_REDUCE_MIN" }, + { name = "REDUCE_PROD", json_key = "OP_REDUCE_PROD" }, + { name = "REDUCE_SUM", json_key = "OP_REDUCE_SUM" }, + { name = "PAD", json_key = "OP_PAD" }, + { name = "SHAPE", json_key = "OP_SHAPE" }, + { name = "SIZE", json_key = "OP_SIZE" }, + { name = "TOPK", json_key = "OP_TOPK" }, + { name = "WHERE", json_key = "OP_WHERE" }, + { name = "CEIL", json_key = "OP_CEIL" }, + { name = "CAST", json_key = "OP_CAST" }, + { name = "EXP", json_key = "OP_EXP" }, + { name = "ROUND", json_key = "OP_ROUND" }, + { name = "LOG", json_key = "OP_LOG" }, + { name = "LOGICAL_NOT", json_key = "OP_LOGICAL_NOT" }, + { name = "SQRT", json_key = "OP_SQRT" }, + { name = "SIN", json_key = "OP_SIN" }, + { name = "COS", json_key = "OP_COS" }, + { name = "LEAKYRELU", json_key = "OP_LEAKYRELU" }, + { name = "SLICE", json_key = "OP_SLICE" }, + { name = "RESIZE", json_key = "OP_RESIZE" }, + { name = "PRELU", json_key = "OP_PRELU" }, + { name = "GELU", json_key = "OP_GELU" }, + { name = "MULTIHEAD_ATTENTION", json_key = "OP_MULTIHEAD_ATTENTION" }, + { name = "FUSED", json_key = "OP_FUSED" }, + { name = "RSQRT", json_key = "OP_RSQRT" }, + { name = "POW", json_key = "OP_POW" }, + { name = "MEAN", json_key = "OP_MEAN" }, + { name = "LAYERNORM", json_key = "OP_LAYERNORM" }, + { name = "GATHER", json_key = "OP_GATHER" }, + { name = "BROADCAST", json_key = "OP_BROADCAST" }, + { name = "REPARTITION", json_key = "OP_PARTITION" }, + { name = "COMBINE", json_key = "OP_COMBINE" }, + { name = "REPLICATE", json_key = "OP_REPLICATE" }, + { name = "REDUCTION", json_key = "OP_REDUCE" }, + { name = "BATCH", json_key = "OP_BATCH" }, + { name = "PIPELINE", json_key = "OP_PIPELINE" }, + { name = "FUSED_PARALLEL", json_key = "OP_FUSED_PARALLEL" }, +] diff --git a/lib/substitution-generator/include/substitution-generator/legacy_pm_parameter.dtg.h b/lib/substitution-generator/include/substitution-generator/legacy_pm_parameter.dtg.h new file mode 100644 index 0000000000..2435024b9a --- /dev/null +++ b/lib/substitution-generator/include/substitution-generator/legacy_pm_parameter.dtg.h @@ -0,0 +1,73 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitution-generator/include/substitution-generator/legacy_pm_parameter.enum.toml +/* proj-data +{ + "generated_from": "e8dda0047c91e576878b86df2fec0b6b" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTION_GENERATOR_INCLUDE_SUBSTITUTION_GENERATOR_LEGACY_PM_PARAMETER_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTION_GENERATOR_INCLUDE_SUBSTITUTION_GENERATOR_LEGACY_PM_PARAMETER_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +enum class LegacyPMParameter { + OP_TYPE, + NUM_INPUTS, + NUM_OUTPUTS, + GROUP, + KERNEL_H, + KERNEL_W, + STRIDE_H, + STRIDE_W, + PADDING_H, + PADDING_W, + ACTI, + NUMDIM, + AXIS, + PERM, + OUTSHUFFLE, + MERGE_GCONV_COUNT, + AXES, + KEEP_DIMS, + EPSILON, + REPARTITION_DIM, + REPARTITION_DEGREE, + REPLICATE_DIM, + REPLICATE_DEGREE, + COMBINE_DIM, + COMBINE_DEGREE, + REDUCTION_DIM, + REDUCTION_DEGREE, + SOFTMAX_DIM, + NUM_HEADS, + PARALLEL_DIM, + PARALLEL_DEGREE, + PAD +}; +std::string format_as(LegacyPMParameter); +std::ostream &operator<<(std::ostream &, LegacyPMParameter); +void to_json(::nlohmann::json &, LegacyPMParameter); +void from_json(::nlohmann::json const &, LegacyPMParameter &); +} // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::LegacyPMParameter) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_SUBSTITUTION_GENERATOR_INCLUDE_SUBSTITUTION_GENERATOR_LEGACY_PM_PARAMETER_DTG_H diff --git a/lib/substitution-generator/include/substitution-generator/legacy_pm_parameter.enum.toml b/lib/substitution-generator/include/substitution-generator/legacy_pm_parameter.enum.toml new file mode 100644 index 0000000000..e71a71a5a8 --- /dev/null +++ b/lib/substitution-generator/include/substitution-generator/legacy_pm_parameter.enum.toml @@ -0,0 +1,44 @@ +namespace = "FlexFlow" +name = "LegacyPMParameter" + +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +values = [ + { name = "OP_TYPE", json_key = "PM_OP_TYPE" }, + { name = "NUM_INPUTS", json_key = "PM_NUM_INPUTS" }, + { name = "NUM_OUTPUTS", json_key = "PM_NUM_OUTPUTS" }, + { name = "GROUP", json_key = "PM_GROUP" }, + { name = "KERNEL_H", json_key = "PM_KERNEL_H" }, + { name = "KERNEL_W", json_key = "PM_KERNEL_W" }, + { name = "STRIDE_H", json_key = "PM_STRIDE_H" }, + { name = "STRIDE_W", json_key = "PM_STRIDE_W" }, + { name = "PADDING_H", json_key = "PM_PADDING_H" }, + { name = "PADDING_W", json_key = "PM_PADDING_W" }, + { name = "ACTI", json_key = "PM_ACTI" }, + { name = "NUMDIM", json_key = "PM_NUMDIM" }, + { name = "AXIS", json_key = "PM_AXIS" }, + { name = "PERM", json_key = "PM_PERM" }, + { name = "OUTSHUFFLE", json_key = "PM_OUTSHUFFLE" }, + { name = "MERGE_GCONV_COUNT", json_key = "PM_MERGE_GCONV_COUNT" }, + { name = "AXES", json_key = "PM_AXES" }, + { name = "KEEP_DIMS", json_key = "PM_KEEP_DIMS" }, + { name = "EPSILON", json_key = "PM_EPSILON" }, + { name = "REPARTITION_DIM", json_key = "PM_REPARTITION_DIM" }, + { name = "REPARTITION_DEGREE", json_key = "PM_REPARTITION_DEGREE" }, + { name = "REPLICATE_DIM", json_key = "PM_REPLICATE_DIM" }, + { name = "REPLICATE_DEGREE", json_key = "PM_REPLICATE_DEGREE" }, + { name = "COMBINE_DIM", json_key = "PM_COMBINE_DIM" }, + { name = "COMBINE_DEGREE", json_key = "PM_COMBINE_DEGREE" }, + { name = "REDUCTION_DIM", json_key = "PM_REDUCTION_DIM" }, + { name = "REDUCTION_DEGREE", json_key = "PM_REDUCTION_DEGREE" }, + { name = "SOFTMAX_DIM", json_key = "PM_SOFTMAX_DIM" }, + { name = "NUM_HEADS", json_key = "PM_NUM_HEADS" }, + { name = "PARALLEL_DIM", json_key = "PM_PARALLEL_DIM" }, + { name = "PARALLEL_DEGREE", json_key = "PM_PARALLEL_DEGREE" }, + { name = "PAD", json_key = "PM_PAD" }, +] diff --git a/lib/substitution-generator/src/substitution-generator/json.cc b/lib/substitution-generator/src/substitution-generator/json.cc index 7e6a93b863..940ecb3e36 100644 --- a/lib/substitution-generator/src/substitution-generator/json.cc +++ b/lib/substitution-generator/src/substitution-generator/json.cc @@ -10,11 +10,6 @@ namespace FlexFlow { void from_json(json const &j, Parameter &p) { j.at("key").get_to(p.key); j.at("value").get_to(p.value); - if (p.key == PM_INVALID) { - std::ostringstream oss; - oss << "Attempted to load invalid PMParameter: " << j.at("key"); - throw std::runtime_error(oss.str()); - } } void from_json(json const &j, Tensor &t) { @@ -22,17 +17,17 @@ void from_json(json const &j, Tensor &t) { j.at("tsId").get_to(t.tsId); } -std::optional Operator::at(PMParameter key) const { - std::optional value = std::nullopt; - for (Parameter const &p : this->para) { - if (p.key == key) { - assert(!value.has_value()); - value = p.key; - } - } +/* std::optional Operator::at(LegacyPMParameter key) const { */ +/* std::optional value = std::nullopt; */ +/* for (Parameter const &p : this->para) { */ +/* if (p.key == key) { */ +/* assert(!value.has_value()); */ +/* value = p.key; */ +/* } */ +/* } */ - return value; -} +/* return value; */ +/* } */ void from_json(json const &j, Operator &o) { j.at("type").get_to(o.op_type); diff --git a/lib/substitution-generator/src/substitution-generator/legacy_operator_type.dtg.cc b/lib/substitution-generator/src/substitution-generator/legacy_operator_type.dtg.cc new file mode 100644 index 0000000000..94c65e33fd --- /dev/null +++ b/lib/substitution-generator/src/substitution-generator/legacy_operator_type.dtg.cc @@ -0,0 +1,721 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitution-generator/include/substitution-generator/legacy_operator_type.enum.toml +/* proj-data +{ + "generated_from": "d6ba52e2b0d58b7cb533dae3894b0486" +} +*/ + +#include "substitution-generator/legacy_operator_type.dtg.h" + +#include +#include + +namespace std { +size_t hash::operator()( + FlexFlow::LegacyOperatorType x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(LegacyOperatorType x) { + switch (x) { + case LegacyOperatorType::NOOP: + return "NOOP"; + case LegacyOperatorType::INPUT: + return "INPUT"; + case LegacyOperatorType::WEIGHT: + return "WEIGHT"; + case LegacyOperatorType::CONV2D: + return "CONV2D"; + case LegacyOperatorType::DROPOUT: + return "DROPOUT"; + case LegacyOperatorType::LINEAR: + return "LINEAR"; + case LegacyOperatorType::BATCHMATMUL: + return "BATCHMATMUL"; + case LegacyOperatorType::POOL2D: + return "POOL2D"; + case LegacyOperatorType::SCALAR_MULTIPLY: + return "SCALAR_MULTIPLY"; + case LegacyOperatorType::SCALAR_ADD: + return "SCALAR_ADD"; + case LegacyOperatorType::SCALAR_FLOOR_DIV: + return "SCALAR_FLOOR_DIV"; + case LegacyOperatorType::SCALAR_TRUE_DIV: + return "SCALAR_TRUE_DIV"; + case LegacyOperatorType::SCALAR_SUB: + return "SCALAR_SUB"; + case LegacyOperatorType::RELU: + return "RELU"; + case LegacyOperatorType::IDENTITY: + return "IDENTITY"; + case LegacyOperatorType::SIGMOID: + return "SIGMOID"; + case LegacyOperatorType::TANH: + return "TANH"; + case LegacyOperatorType::ELU: + return "ELU"; + case LegacyOperatorType::FLAT: + return "FLAT"; + case LegacyOperatorType::SOFTMAX: + return "SOFTMAX"; + case LegacyOperatorType::BATCHNORM: + return "BATCHNORM"; + case LegacyOperatorType::CONCAT: + return "CONCAT"; + case LegacyOperatorType::SPLIT: + return "SPLIT"; + case LegacyOperatorType::EMBEDDING: + return "EMBEDDING"; + case LegacyOperatorType::CACHE: + return "CACHE"; + case LegacyOperatorType::RESHAPE: + return "RESHAPE"; + case LegacyOperatorType::REVERSE: + return "REVERSE"; + case LegacyOperatorType::TRANSPOSE: + return "TRANSPOSE"; + case LegacyOperatorType::EW_ADD: + return "EW_ADD"; + case LegacyOperatorType::EW_MUL: + return "EW_MUL"; + case LegacyOperatorType::MATMUL: + return "MATMUL"; + case LegacyOperatorType::MUL: + return "MUL"; + case LegacyOperatorType::ENLARGE: + return "ENLARGE"; + case LegacyOperatorType::SQUEEZE: + return "SQUEEZE"; + case LegacyOperatorType::UNSQUEEZE: + return "UNSQUEEZE"; + case LegacyOperatorType::EW_SUB: + return "EW_SUB"; + case LegacyOperatorType::EW_DIV: + return "EW_DIV"; + case LegacyOperatorType::EW_EQUAL: + return "EW_EQUAL"; + case LegacyOperatorType::EW_GREATER: + return "EW_GREATER"; + case LegacyOperatorType::EW_LESS: + return "EW_LESS"; + case LegacyOperatorType::EW_MAX: + return "EW_MAX"; + case LegacyOperatorType::EW_MIN: + return "EW_MIN"; + case LegacyOperatorType::REDUCE_ARGMAX: + return "REDUCE_ARGMAX"; + case LegacyOperatorType::REDUCE_ARGMIN: + return "REDUCE_ARGMIN"; + case LegacyOperatorType::REDUCE_MAX: + return "REDUCE_MAX"; + case LegacyOperatorType::REDUCE_MEAN: + return "REDUCE_MEAN"; + case LegacyOperatorType::REDUCE_MIN: + return "REDUCE_MIN"; + case LegacyOperatorType::REDUCE_PROD: + return "REDUCE_PROD"; + case LegacyOperatorType::REDUCE_SUM: + return "REDUCE_SUM"; + case LegacyOperatorType::PAD: + return "PAD"; + case LegacyOperatorType::SHAPE: + return "SHAPE"; + case LegacyOperatorType::SIZE: + return "SIZE"; + case LegacyOperatorType::TOPK: + return "TOPK"; + case LegacyOperatorType::WHERE: + return "WHERE"; + case LegacyOperatorType::CEIL: + return "CEIL"; + case LegacyOperatorType::CAST: + return "CAST"; + case LegacyOperatorType::EXP: + return "EXP"; + case LegacyOperatorType::ROUND: + return "ROUND"; + case LegacyOperatorType::LOG: + return "LOG"; + case LegacyOperatorType::LOGICAL_NOT: + return "LOGICAL_NOT"; + case LegacyOperatorType::SQRT: + return "SQRT"; + case LegacyOperatorType::SIN: + return "SIN"; + case LegacyOperatorType::COS: + return "COS"; + case LegacyOperatorType::LEAKYRELU: + return "LEAKYRELU"; + case LegacyOperatorType::SLICE: + return "SLICE"; + case LegacyOperatorType::RESIZE: + return "RESIZE"; + case LegacyOperatorType::PRELU: + return "PRELU"; + case LegacyOperatorType::GELU: + return "GELU"; + case LegacyOperatorType::MULTIHEAD_ATTENTION: + return "MULTIHEAD_ATTENTION"; + case LegacyOperatorType::FUSED: + return "FUSED"; + case LegacyOperatorType::RSQRT: + return "RSQRT"; + case LegacyOperatorType::POW: + return "POW"; + case LegacyOperatorType::MEAN: + return "MEAN"; + case LegacyOperatorType::LAYERNORM: + return "LAYERNORM"; + case LegacyOperatorType::GATHER: + return "GATHER"; + case LegacyOperatorType::BROADCAST: + return "BROADCAST"; + case LegacyOperatorType::REPARTITION: + return "REPARTITION"; + case LegacyOperatorType::COMBINE: + return "COMBINE"; + case LegacyOperatorType::REPLICATE: + return "REPLICATE"; + case LegacyOperatorType::REDUCTION: + return "REDUCTION"; + case LegacyOperatorType::BATCH: + return "BATCH"; + case LegacyOperatorType::PIPELINE: + return "PIPELINE"; + case LegacyOperatorType::FUSED_PARALLEL: + return "FUSED_PARALLEL"; + default: + std::ostringstream oss; + oss << "Unknown LegacyOperatorType value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, LegacyOperatorType x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, LegacyOperatorType x) { + switch (x) { + case LegacyOperatorType::NOOP: + j = "OP_NOOP"; + break; + case LegacyOperatorType::INPUT: + j = "OP_INPUT"; + break; + case LegacyOperatorType::WEIGHT: + j = "OP_WEIGHT"; + break; + case LegacyOperatorType::CONV2D: + j = "OP_CONV2D"; + break; + case LegacyOperatorType::DROPOUT: + j = "OP_DROPOUT"; + break; + case LegacyOperatorType::LINEAR: + j = "OP_LINEAR"; + break; + case LegacyOperatorType::BATCHMATMUL: + j = "OP_BATCHMATMUL"; + break; + case LegacyOperatorType::POOL2D: + j = "OP_POOL2D"; + break; + case LegacyOperatorType::SCALAR_MULTIPLY: + j = "OP_SCALAR_MULTIPLY"; + break; + case LegacyOperatorType::SCALAR_ADD: + j = "OP_SCALAR_ADD"; + break; + case LegacyOperatorType::SCALAR_FLOOR_DIV: + j = "OP_SCALAR_FLOOR_DIV"; + break; + case LegacyOperatorType::SCALAR_TRUE_DIV: + j = "OP_SCALAR_TRUE_DIV"; + break; + case LegacyOperatorType::SCALAR_SUB: + j = "OP_SCALAR_SUB"; + break; + case LegacyOperatorType::RELU: + j = "OP_RELU"; + break; + case LegacyOperatorType::IDENTITY: + j = "OP_IDENTITY"; + break; + case LegacyOperatorType::SIGMOID: + j = "OP_SIGMOID"; + break; + case LegacyOperatorType::TANH: + j = "OP_TANH"; + break; + case LegacyOperatorType::ELU: + j = "OP_ELU"; + break; + case LegacyOperatorType::FLAT: + j = "OP_FLAT"; + break; + case LegacyOperatorType::SOFTMAX: + j = "OP_SOFTMAX"; + break; + case LegacyOperatorType::BATCHNORM: + j = "OP_BATCHNORM"; + break; + case LegacyOperatorType::CONCAT: + j = "OP_CONCAT"; + break; + case LegacyOperatorType::SPLIT: + j = "OP_SPLIT"; + break; + case LegacyOperatorType::EMBEDDING: + j = "OP_EMBEDDING"; + break; + case LegacyOperatorType::CACHE: + j = "OP_CACHE"; + break; + case LegacyOperatorType::RESHAPE: + j = "OP_RESHAPE"; + break; + case LegacyOperatorType::REVERSE: + j = "OP_REVERSE"; + break; + case LegacyOperatorType::TRANSPOSE: + j = "OP_TRANSPOSE"; + break; + case LegacyOperatorType::EW_ADD: + j = "OP_EW_ADD"; + break; + case LegacyOperatorType::EW_MUL: + j = "OP_EW_MUL"; + break; + case LegacyOperatorType::MATMUL: + j = "OP_MATMUL"; + break; + case LegacyOperatorType::MUL: + j = "OP_MUL"; + break; + case LegacyOperatorType::ENLARGE: + j = "OP_ENLARGE"; + break; + case LegacyOperatorType::SQUEEZE: + j = "OP_SQUEEZE"; + break; + case LegacyOperatorType::UNSQUEEZE: + j = "OP_UNSQUEEZE"; + break; + case LegacyOperatorType::EW_SUB: + j = "OP_EW_SUB"; + break; + case LegacyOperatorType::EW_DIV: + j = "OP_EW_DIV"; + break; + case LegacyOperatorType::EW_EQUAL: + j = "OP_EW_EQUAL"; + break; + case LegacyOperatorType::EW_GREATER: + j = "OP_EW_GREATER"; + break; + case LegacyOperatorType::EW_LESS: + j = "OP_EW_LESS"; + break; + case LegacyOperatorType::EW_MAX: + j = "OP_EW_MAX"; + break; + case LegacyOperatorType::EW_MIN: + j = "OP_EW_MIN"; + break; + case LegacyOperatorType::REDUCE_ARGMAX: + j = "OP_REDUCE_ARGMAX"; + break; + case LegacyOperatorType::REDUCE_ARGMIN: + j = "OP_REDUCE_ARGMIN"; + break; + case LegacyOperatorType::REDUCE_MAX: + j = "OP_REDUCE_MAX"; + break; + case LegacyOperatorType::REDUCE_MEAN: + j = "OP_REDUCE_MEAN"; + break; + case LegacyOperatorType::REDUCE_MIN: + j = "OP_REDUCE_MIN"; + break; + case LegacyOperatorType::REDUCE_PROD: + j = "OP_REDUCE_PROD"; + break; + case LegacyOperatorType::REDUCE_SUM: + j = "OP_REDUCE_SUM"; + break; + case LegacyOperatorType::PAD: + j = "OP_PAD"; + break; + case LegacyOperatorType::SHAPE: + j = "OP_SHAPE"; + break; + case LegacyOperatorType::SIZE: + j = "OP_SIZE"; + break; + case LegacyOperatorType::TOPK: + j = "OP_TOPK"; + break; + case LegacyOperatorType::WHERE: + j = "OP_WHERE"; + break; + case LegacyOperatorType::CEIL: + j = "OP_CEIL"; + break; + case LegacyOperatorType::CAST: + j = "OP_CAST"; + break; + case LegacyOperatorType::EXP: + j = "OP_EXP"; + break; + case LegacyOperatorType::ROUND: + j = "OP_ROUND"; + break; + case LegacyOperatorType::LOG: + j = "OP_LOG"; + break; + case LegacyOperatorType::LOGICAL_NOT: + j = "OP_LOGICAL_NOT"; + break; + case LegacyOperatorType::SQRT: + j = "OP_SQRT"; + break; + case LegacyOperatorType::SIN: + j = "OP_SIN"; + break; + case LegacyOperatorType::COS: + j = "OP_COS"; + break; + case LegacyOperatorType::LEAKYRELU: + j = "OP_LEAKYRELU"; + break; + case LegacyOperatorType::SLICE: + j = "OP_SLICE"; + break; + case LegacyOperatorType::RESIZE: + j = "OP_RESIZE"; + break; + case LegacyOperatorType::PRELU: + j = "OP_PRELU"; + break; + case LegacyOperatorType::GELU: + j = "OP_GELU"; + break; + case LegacyOperatorType::MULTIHEAD_ATTENTION: + j = "OP_MULTIHEAD_ATTENTION"; + break; + case LegacyOperatorType::FUSED: + j = "OP_FUSED"; + break; + case LegacyOperatorType::RSQRT: + j = "OP_RSQRT"; + break; + case LegacyOperatorType::POW: + j = "OP_POW"; + break; + case LegacyOperatorType::MEAN: + j = "OP_MEAN"; + break; + case LegacyOperatorType::LAYERNORM: + j = "OP_LAYERNORM"; + break; + case LegacyOperatorType::GATHER: + j = "OP_GATHER"; + break; + case LegacyOperatorType::BROADCAST: + j = "OP_BROADCAST"; + break; + case LegacyOperatorType::REPARTITION: + j = "OP_PARTITION"; + break; + case LegacyOperatorType::COMBINE: + j = "OP_COMBINE"; + break; + case LegacyOperatorType::REPLICATE: + j = "OP_REPLICATE"; + break; + case LegacyOperatorType::REDUCTION: + j = "OP_REDUCE"; + break; + case LegacyOperatorType::BATCH: + j = "OP_BATCH"; + break; + case LegacyOperatorType::PIPELINE: + j = "OP_PIPELINE"; + break; + case LegacyOperatorType::FUSED_PARALLEL: + j = "OP_FUSED_PARALLEL"; + break; + default: + std::ostringstream oss; + oss << "Unknown LegacyOperatorType value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, LegacyOperatorType &x) { + std::string as_str = j.get(); + if (as_str == "OP_NOOP") { + x = LegacyOperatorType::NOOP; + } else if (as_str == "OP_INPUT") { + x = LegacyOperatorType::INPUT; + } else if (as_str == "OP_WEIGHT") { + x = LegacyOperatorType::WEIGHT; + } else if (as_str == "OP_CONV2D") { + x = LegacyOperatorType::CONV2D; + } else if (as_str == "OP_DROPOUT") { + x = LegacyOperatorType::DROPOUT; + } else if (as_str == "OP_LINEAR") { + x = LegacyOperatorType::LINEAR; + } else if (as_str == "OP_BATCHMATMUL") { + x = LegacyOperatorType::BATCHMATMUL; + } else if (as_str == "OP_POOL2D") { + x = LegacyOperatorType::POOL2D; + } else if (as_str == "OP_SCALAR_MULTIPLY") { + x = LegacyOperatorType::SCALAR_MULTIPLY; + } else if (as_str == "OP_SCALAR_ADD") { + x = LegacyOperatorType::SCALAR_ADD; + } else if (as_str == "OP_SCALAR_FLOOR_DIV") { + x = LegacyOperatorType::SCALAR_FLOOR_DIV; + } else if (as_str == "OP_SCALAR_TRUE_DIV") { + x = LegacyOperatorType::SCALAR_TRUE_DIV; + } else if (as_str == "OP_SCALAR_SUB") { + x = LegacyOperatorType::SCALAR_SUB; + } else if (as_str == "OP_RELU") { + x = LegacyOperatorType::RELU; + } else if (as_str == "OP_IDENTITY") { + x = LegacyOperatorType::IDENTITY; + } else if (as_str == "OP_SIGMOID") { + x = LegacyOperatorType::SIGMOID; + } else if (as_str == "OP_TANH") { + x = LegacyOperatorType::TANH; + } else if (as_str == "OP_ELU") { + x = LegacyOperatorType::ELU; + } else if (as_str == "OP_FLAT") { + x = LegacyOperatorType::FLAT; + } else if (as_str == "OP_SOFTMAX") { + x = LegacyOperatorType::SOFTMAX; + } else if (as_str == "OP_BATCHNORM") { + x = LegacyOperatorType::BATCHNORM; + } else if (as_str == "OP_CONCAT") { + x = LegacyOperatorType::CONCAT; + } else if (as_str == "OP_SPLIT") { + x = LegacyOperatorType::SPLIT; + } else if (as_str == "OP_EMBEDDING") { + x = LegacyOperatorType::EMBEDDING; + } else if (as_str == "OP_CACHE") { + x = LegacyOperatorType::CACHE; + } else if (as_str == "OP_RESHAPE") { + x = LegacyOperatorType::RESHAPE; + } else if (as_str == "OP_REVERSE") { + x = LegacyOperatorType::REVERSE; + } else if (as_str == "OP_TRANSPOSE") { + x = LegacyOperatorType::TRANSPOSE; + } else if (as_str == "OP_EW_ADD") { + x = LegacyOperatorType::EW_ADD; + } else if (as_str == "OP_EW_MUL") { + x = LegacyOperatorType::EW_MUL; + } else if (as_str == "OP_MATMUL") { + x = LegacyOperatorType::MATMUL; + } else if (as_str == "OP_MUL") { + x = LegacyOperatorType::MUL; + } else if (as_str == "OP_ENLARGE") { + x = LegacyOperatorType::ENLARGE; + } else if (as_str == "OP_SQUEEZE") { + x = LegacyOperatorType::SQUEEZE; + } else if (as_str == "OP_UNSQUEEZE") { + x = LegacyOperatorType::UNSQUEEZE; + } else if (as_str == "OP_EW_SUB") { + x = LegacyOperatorType::EW_SUB; + } else if (as_str == "OP_EW_DIV") { + x = LegacyOperatorType::EW_DIV; + } else if (as_str == "OP_EW_EQUAL") { + x = LegacyOperatorType::EW_EQUAL; + } else if (as_str == "OP_EW_GREATER") { + x = LegacyOperatorType::EW_GREATER; + } else if (as_str == "OP_EW_LESS") { + x = LegacyOperatorType::EW_LESS; + } else if (as_str == "OP_EW_MAX") { + x = LegacyOperatorType::EW_MAX; + } else if (as_str == "OP_EW_MIN") { + x = LegacyOperatorType::EW_MIN; + } else if (as_str == "OP_REDUCE_ARGMAX") { + x = LegacyOperatorType::REDUCE_ARGMAX; + } else if (as_str == "OP_REDUCE_ARGMIN") { + x = LegacyOperatorType::REDUCE_ARGMIN; + } else if (as_str == "OP_REDUCE_MAX") { + x = LegacyOperatorType::REDUCE_MAX; + } else if (as_str == "OP_REDUCE_MEAN") { + x = LegacyOperatorType::REDUCE_MEAN; + } else if (as_str == "OP_REDUCE_MIN") { + x = LegacyOperatorType::REDUCE_MIN; + } else if (as_str == "OP_REDUCE_PROD") { + x = LegacyOperatorType::REDUCE_PROD; + } else if (as_str == "OP_REDUCE_SUM") { + x = LegacyOperatorType::REDUCE_SUM; + } else if (as_str == "OP_PAD") { + x = LegacyOperatorType::PAD; + } else if (as_str == "OP_SHAPE") { + x = LegacyOperatorType::SHAPE; + } else if (as_str == "OP_SIZE") { + x = LegacyOperatorType::SIZE; + } else if (as_str == "OP_TOPK") { + x = LegacyOperatorType::TOPK; + } else if (as_str == "OP_WHERE") { + x = LegacyOperatorType::WHERE; + } else if (as_str == "OP_CEIL") { + x = LegacyOperatorType::CEIL; + } else if (as_str == "OP_CAST") { + x = LegacyOperatorType::CAST; + } else if (as_str == "OP_EXP") { + x = LegacyOperatorType::EXP; + } else if (as_str == "OP_ROUND") { + x = LegacyOperatorType::ROUND; + } else if (as_str == "OP_LOG") { + x = LegacyOperatorType::LOG; + } else if (as_str == "OP_LOGICAL_NOT") { + x = LegacyOperatorType::LOGICAL_NOT; + } else if (as_str == "OP_SQRT") { + x = LegacyOperatorType::SQRT; + } else if (as_str == "OP_SIN") { + x = LegacyOperatorType::SIN; + } else if (as_str == "OP_COS") { + x = LegacyOperatorType::COS; + } else if (as_str == "OP_LEAKYRELU") { + x = LegacyOperatorType::LEAKYRELU; + } else if (as_str == "OP_SLICE") { + x = LegacyOperatorType::SLICE; + } else if (as_str == "OP_RESIZE") { + x = LegacyOperatorType::RESIZE; + } else if (as_str == "OP_PRELU") { + x = LegacyOperatorType::PRELU; + } else if (as_str == "OP_GELU") { + x = LegacyOperatorType::GELU; + } else if (as_str == "OP_MULTIHEAD_ATTENTION") { + x = LegacyOperatorType::MULTIHEAD_ATTENTION; + } else if (as_str == "OP_FUSED") { + x = LegacyOperatorType::FUSED; + } else if (as_str == "OP_RSQRT") { + x = LegacyOperatorType::RSQRT; + } else if (as_str == "OP_POW") { + x = LegacyOperatorType::POW; + } else if (as_str == "OP_MEAN") { + x = LegacyOperatorType::MEAN; + } else if (as_str == "OP_LAYERNORM") { + x = LegacyOperatorType::LAYERNORM; + } else if (as_str == "OP_GATHER") { + x = LegacyOperatorType::GATHER; + } else if (as_str == "OP_BROADCAST") { + x = LegacyOperatorType::BROADCAST; + } else if (as_str == "OP_PARTITION") { + x = LegacyOperatorType::REPARTITION; + } else if (as_str == "OP_COMBINE") { + x = LegacyOperatorType::COMBINE; + } else if (as_str == "OP_REPLICATE") { + x = LegacyOperatorType::REPLICATE; + } else if (as_str == "OP_REDUCE") { + x = LegacyOperatorType::REDUCTION; + } else if (as_str == "OP_BATCH") { + x = LegacyOperatorType::BATCH; + } else if (as_str == "OP_PIPELINE") { + x = LegacyOperatorType::PIPELINE; + } else if (as_str == "OP_FUSED_PARALLEL") { + x = LegacyOperatorType::FUSED_PARALLEL; + } else { + std::ostringstream oss; + oss << "Unknown LegacyOperatorType value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::element( + FlexFlow::LegacyOperatorType::NOOP, + FlexFlow::LegacyOperatorType::INPUT, + FlexFlow::LegacyOperatorType::WEIGHT, + FlexFlow::LegacyOperatorType::CONV2D, + FlexFlow::LegacyOperatorType::DROPOUT, + FlexFlow::LegacyOperatorType::LINEAR, + FlexFlow::LegacyOperatorType::BATCHMATMUL, + FlexFlow::LegacyOperatorType::POOL2D, + FlexFlow::LegacyOperatorType::SCALAR_MULTIPLY, + FlexFlow::LegacyOperatorType::SCALAR_ADD, + FlexFlow::LegacyOperatorType::SCALAR_FLOOR_DIV, + FlexFlow::LegacyOperatorType::SCALAR_TRUE_DIV, + FlexFlow::LegacyOperatorType::SCALAR_SUB, + FlexFlow::LegacyOperatorType::RELU, + FlexFlow::LegacyOperatorType::IDENTITY, + FlexFlow::LegacyOperatorType::SIGMOID, + FlexFlow::LegacyOperatorType::TANH, + FlexFlow::LegacyOperatorType::ELU, + FlexFlow::LegacyOperatorType::FLAT, + FlexFlow::LegacyOperatorType::SOFTMAX, + FlexFlow::LegacyOperatorType::BATCHNORM, + FlexFlow::LegacyOperatorType::CONCAT, + FlexFlow::LegacyOperatorType::SPLIT, + FlexFlow::LegacyOperatorType::EMBEDDING, + FlexFlow::LegacyOperatorType::CACHE, + FlexFlow::LegacyOperatorType::RESHAPE, + FlexFlow::LegacyOperatorType::REVERSE, + FlexFlow::LegacyOperatorType::TRANSPOSE, + FlexFlow::LegacyOperatorType::EW_ADD, + FlexFlow::LegacyOperatorType::EW_MUL, + FlexFlow::LegacyOperatorType::MATMUL, + FlexFlow::LegacyOperatorType::MUL, + FlexFlow::LegacyOperatorType::ENLARGE, + FlexFlow::LegacyOperatorType::SQUEEZE, + FlexFlow::LegacyOperatorType::UNSQUEEZE, + FlexFlow::LegacyOperatorType::EW_SUB, + FlexFlow::LegacyOperatorType::EW_DIV, + FlexFlow::LegacyOperatorType::EW_EQUAL, + FlexFlow::LegacyOperatorType::EW_GREATER, + FlexFlow::LegacyOperatorType::EW_LESS, + FlexFlow::LegacyOperatorType::EW_MAX, + FlexFlow::LegacyOperatorType::EW_MIN, + FlexFlow::LegacyOperatorType::REDUCE_ARGMAX, + FlexFlow::LegacyOperatorType::REDUCE_ARGMIN, + FlexFlow::LegacyOperatorType::REDUCE_MAX, + FlexFlow::LegacyOperatorType::REDUCE_MEAN, + FlexFlow::LegacyOperatorType::REDUCE_MIN, + FlexFlow::LegacyOperatorType::REDUCE_PROD, + FlexFlow::LegacyOperatorType::REDUCE_SUM, + FlexFlow::LegacyOperatorType::PAD, + FlexFlow::LegacyOperatorType::SHAPE, + FlexFlow::LegacyOperatorType::SIZE, + FlexFlow::LegacyOperatorType::TOPK, + FlexFlow::LegacyOperatorType::WHERE, + FlexFlow::LegacyOperatorType::CEIL, + FlexFlow::LegacyOperatorType::CAST, + FlexFlow::LegacyOperatorType::EXP, + FlexFlow::LegacyOperatorType::ROUND, + FlexFlow::LegacyOperatorType::LOG, + FlexFlow::LegacyOperatorType::LOGICAL_NOT, + FlexFlow::LegacyOperatorType::SQRT, + FlexFlow::LegacyOperatorType::SIN, + FlexFlow::LegacyOperatorType::COS, + FlexFlow::LegacyOperatorType::LEAKYRELU, + FlexFlow::LegacyOperatorType::SLICE, + FlexFlow::LegacyOperatorType::RESIZE, + FlexFlow::LegacyOperatorType::PRELU, + FlexFlow::LegacyOperatorType::GELU, + FlexFlow::LegacyOperatorType::MULTIHEAD_ATTENTION, + FlexFlow::LegacyOperatorType::FUSED, + FlexFlow::LegacyOperatorType::RSQRT, + FlexFlow::LegacyOperatorType::POW, + FlexFlow::LegacyOperatorType::MEAN, + FlexFlow::LegacyOperatorType::LAYERNORM, + FlexFlow::LegacyOperatorType::GATHER, + FlexFlow::LegacyOperatorType::BROADCAST, + FlexFlow::LegacyOperatorType::REPARTITION, + FlexFlow::LegacyOperatorType::COMBINE, + FlexFlow::LegacyOperatorType::REPLICATE, + FlexFlow::LegacyOperatorType::REDUCTION, + FlexFlow::LegacyOperatorType::BATCH, + FlexFlow::LegacyOperatorType::PIPELINE, + FlexFlow::LegacyOperatorType::FUSED_PARALLEL); +} +} // namespace rc diff --git a/lib/substitution-generator/src/substitution-generator/legacy_pm_parameter.dtg.cc b/lib/substitution-generator/src/substitution-generator/legacy_pm_parameter.dtg.cc new file mode 100644 index 0000000000..c8df4ccd7d --- /dev/null +++ b/lib/substitution-generator/src/substitution-generator/legacy_pm_parameter.dtg.cc @@ -0,0 +1,313 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitution-generator/include/substitution-generator/legacy_pm_parameter.enum.toml +/* proj-data +{ + "generated_from": "e8dda0047c91e576878b86df2fec0b6b" +} +*/ + +#include "substitution-generator/legacy_pm_parameter.dtg.h" + +#include +#include + +namespace std { +size_t hash::operator()( + FlexFlow::LegacyPMParameter x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(LegacyPMParameter x) { + switch (x) { + case LegacyPMParameter::OP_TYPE: + return "OP_TYPE"; + case LegacyPMParameter::NUM_INPUTS: + return "NUM_INPUTS"; + case LegacyPMParameter::NUM_OUTPUTS: + return "NUM_OUTPUTS"; + case LegacyPMParameter::GROUP: + return "GROUP"; + case LegacyPMParameter::KERNEL_H: + return "KERNEL_H"; + case LegacyPMParameter::KERNEL_W: + return "KERNEL_W"; + case LegacyPMParameter::STRIDE_H: + return "STRIDE_H"; + case LegacyPMParameter::STRIDE_W: + return "STRIDE_W"; + case LegacyPMParameter::PADDING_H: + return "PADDING_H"; + case LegacyPMParameter::PADDING_W: + return "PADDING_W"; + case LegacyPMParameter::ACTI: + return "ACTI"; + case LegacyPMParameter::NUMDIM: + return "NUMDIM"; + case LegacyPMParameter::AXIS: + return "AXIS"; + case LegacyPMParameter::PERM: + return "PERM"; + case LegacyPMParameter::OUTSHUFFLE: + return "OUTSHUFFLE"; + case LegacyPMParameter::MERGE_GCONV_COUNT: + return "MERGE_GCONV_COUNT"; + case LegacyPMParameter::AXES: + return "AXES"; + case LegacyPMParameter::KEEP_DIMS: + return "KEEP_DIMS"; + case LegacyPMParameter::EPSILON: + return "EPSILON"; + case LegacyPMParameter::REPARTITION_DIM: + return "REPARTITION_DIM"; + case LegacyPMParameter::REPARTITION_DEGREE: + return "REPARTITION_DEGREE"; + case LegacyPMParameter::REPLICATE_DIM: + return "REPLICATE_DIM"; + case LegacyPMParameter::REPLICATE_DEGREE: + return "REPLICATE_DEGREE"; + case LegacyPMParameter::COMBINE_DIM: + return "COMBINE_DIM"; + case LegacyPMParameter::COMBINE_DEGREE: + return "COMBINE_DEGREE"; + case LegacyPMParameter::REDUCTION_DIM: + return "REDUCTION_DIM"; + case LegacyPMParameter::REDUCTION_DEGREE: + return "REDUCTION_DEGREE"; + case LegacyPMParameter::SOFTMAX_DIM: + return "SOFTMAX_DIM"; + case LegacyPMParameter::NUM_HEADS: + return "NUM_HEADS"; + case LegacyPMParameter::PARALLEL_DIM: + return "PARALLEL_DIM"; + case LegacyPMParameter::PARALLEL_DEGREE: + return "PARALLEL_DEGREE"; + case LegacyPMParameter::PAD: + return "PAD"; + default: + std::ostringstream oss; + oss << "Unknown LegacyPMParameter value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, LegacyPMParameter x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, LegacyPMParameter x) { + switch (x) { + case LegacyPMParameter::OP_TYPE: + j = "PM_OP_TYPE"; + break; + case LegacyPMParameter::NUM_INPUTS: + j = "PM_NUM_INPUTS"; + break; + case LegacyPMParameter::NUM_OUTPUTS: + j = "PM_NUM_OUTPUTS"; + break; + case LegacyPMParameter::GROUP: + j = "PM_GROUP"; + break; + case LegacyPMParameter::KERNEL_H: + j = "PM_KERNEL_H"; + break; + case LegacyPMParameter::KERNEL_W: + j = "PM_KERNEL_W"; + break; + case LegacyPMParameter::STRIDE_H: + j = "PM_STRIDE_H"; + break; + case LegacyPMParameter::STRIDE_W: + j = "PM_STRIDE_W"; + break; + case LegacyPMParameter::PADDING_H: + j = "PM_PADDING_H"; + break; + case LegacyPMParameter::PADDING_W: + j = "PM_PADDING_W"; + break; + case LegacyPMParameter::ACTI: + j = "PM_ACTI"; + break; + case LegacyPMParameter::NUMDIM: + j = "PM_NUMDIM"; + break; + case LegacyPMParameter::AXIS: + j = "PM_AXIS"; + break; + case LegacyPMParameter::PERM: + j = "PM_PERM"; + break; + case LegacyPMParameter::OUTSHUFFLE: + j = "PM_OUTSHUFFLE"; + break; + case LegacyPMParameter::MERGE_GCONV_COUNT: + j = "PM_MERGE_GCONV_COUNT"; + break; + case LegacyPMParameter::AXES: + j = "PM_AXES"; + break; + case LegacyPMParameter::KEEP_DIMS: + j = "PM_KEEP_DIMS"; + break; + case LegacyPMParameter::EPSILON: + j = "PM_EPSILON"; + break; + case LegacyPMParameter::REPARTITION_DIM: + j = "PM_REPARTITION_DIM"; + break; + case LegacyPMParameter::REPARTITION_DEGREE: + j = "PM_REPARTITION_DEGREE"; + break; + case LegacyPMParameter::REPLICATE_DIM: + j = "PM_REPLICATE_DIM"; + break; + case LegacyPMParameter::REPLICATE_DEGREE: + j = "PM_REPLICATE_DEGREE"; + break; + case LegacyPMParameter::COMBINE_DIM: + j = "PM_COMBINE_DIM"; + break; + case LegacyPMParameter::COMBINE_DEGREE: + j = "PM_COMBINE_DEGREE"; + break; + case LegacyPMParameter::REDUCTION_DIM: + j = "PM_REDUCTION_DIM"; + break; + case LegacyPMParameter::REDUCTION_DEGREE: + j = "PM_REDUCTION_DEGREE"; + break; + case LegacyPMParameter::SOFTMAX_DIM: + j = "PM_SOFTMAX_DIM"; + break; + case LegacyPMParameter::NUM_HEADS: + j = "PM_NUM_HEADS"; + break; + case LegacyPMParameter::PARALLEL_DIM: + j = "PM_PARALLEL_DIM"; + break; + case LegacyPMParameter::PARALLEL_DEGREE: + j = "PM_PARALLEL_DEGREE"; + break; + case LegacyPMParameter::PAD: + j = "PM_PAD"; + break; + default: + std::ostringstream oss; + oss << "Unknown LegacyPMParameter value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, LegacyPMParameter &x) { + std::string as_str = j.get(); + if (as_str == "PM_OP_TYPE") { + x = LegacyPMParameter::OP_TYPE; + } else if (as_str == "PM_NUM_INPUTS") { + x = LegacyPMParameter::NUM_INPUTS; + } else if (as_str == "PM_NUM_OUTPUTS") { + x = LegacyPMParameter::NUM_OUTPUTS; + } else if (as_str == "PM_GROUP") { + x = LegacyPMParameter::GROUP; + } else if (as_str == "PM_KERNEL_H") { + x = LegacyPMParameter::KERNEL_H; + } else if (as_str == "PM_KERNEL_W") { + x = LegacyPMParameter::KERNEL_W; + } else if (as_str == "PM_STRIDE_H") { + x = LegacyPMParameter::STRIDE_H; + } else if (as_str == "PM_STRIDE_W") { + x = LegacyPMParameter::STRIDE_W; + } else if (as_str == "PM_PADDING_H") { + x = LegacyPMParameter::PADDING_H; + } else if (as_str == "PM_PADDING_W") { + x = LegacyPMParameter::PADDING_W; + } else if (as_str == "PM_ACTI") { + x = LegacyPMParameter::ACTI; + } else if (as_str == "PM_NUMDIM") { + x = LegacyPMParameter::NUMDIM; + } else if (as_str == "PM_AXIS") { + x = LegacyPMParameter::AXIS; + } else if (as_str == "PM_PERM") { + x = LegacyPMParameter::PERM; + } else if (as_str == "PM_OUTSHUFFLE") { + x = LegacyPMParameter::OUTSHUFFLE; + } else if (as_str == "PM_MERGE_GCONV_COUNT") { + x = LegacyPMParameter::MERGE_GCONV_COUNT; + } else if (as_str == "PM_AXES") { + x = LegacyPMParameter::AXES; + } else if (as_str == "PM_KEEP_DIMS") { + x = LegacyPMParameter::KEEP_DIMS; + } else if (as_str == "PM_EPSILON") { + x = LegacyPMParameter::EPSILON; + } else if (as_str == "PM_REPARTITION_DIM") { + x = LegacyPMParameter::REPARTITION_DIM; + } else if (as_str == "PM_REPARTITION_DEGREE") { + x = LegacyPMParameter::REPARTITION_DEGREE; + } else if (as_str == "PM_REPLICATE_DIM") { + x = LegacyPMParameter::REPLICATE_DIM; + } else if (as_str == "PM_REPLICATE_DEGREE") { + x = LegacyPMParameter::REPLICATE_DEGREE; + } else if (as_str == "PM_COMBINE_DIM") { + x = LegacyPMParameter::COMBINE_DIM; + } else if (as_str == "PM_COMBINE_DEGREE") { + x = LegacyPMParameter::COMBINE_DEGREE; + } else if (as_str == "PM_REDUCTION_DIM") { + x = LegacyPMParameter::REDUCTION_DIM; + } else if (as_str == "PM_REDUCTION_DEGREE") { + x = LegacyPMParameter::REDUCTION_DEGREE; + } else if (as_str == "PM_SOFTMAX_DIM") { + x = LegacyPMParameter::SOFTMAX_DIM; + } else if (as_str == "PM_NUM_HEADS") { + x = LegacyPMParameter::NUM_HEADS; + } else if (as_str == "PM_PARALLEL_DIM") { + x = LegacyPMParameter::PARALLEL_DIM; + } else if (as_str == "PM_PARALLEL_DEGREE") { + x = LegacyPMParameter::PARALLEL_DEGREE; + } else if (as_str == "PM_PAD") { + x = LegacyPMParameter::PAD; + } else { + std::ostringstream oss; + oss << "Unknown LegacyPMParameter value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::element( + FlexFlow::LegacyPMParameter::OP_TYPE, + FlexFlow::LegacyPMParameter::NUM_INPUTS, + FlexFlow::LegacyPMParameter::NUM_OUTPUTS, + FlexFlow::LegacyPMParameter::GROUP, + FlexFlow::LegacyPMParameter::KERNEL_H, + FlexFlow::LegacyPMParameter::KERNEL_W, + FlexFlow::LegacyPMParameter::STRIDE_H, + FlexFlow::LegacyPMParameter::STRIDE_W, + FlexFlow::LegacyPMParameter::PADDING_H, + FlexFlow::LegacyPMParameter::PADDING_W, + FlexFlow::LegacyPMParameter::ACTI, + FlexFlow::LegacyPMParameter::NUMDIM, + FlexFlow::LegacyPMParameter::AXIS, + FlexFlow::LegacyPMParameter::PERM, + FlexFlow::LegacyPMParameter::OUTSHUFFLE, + FlexFlow::LegacyPMParameter::MERGE_GCONV_COUNT, + FlexFlow::LegacyPMParameter::AXES, + FlexFlow::LegacyPMParameter::KEEP_DIMS, + FlexFlow::LegacyPMParameter::EPSILON, + FlexFlow::LegacyPMParameter::REPARTITION_DIM, + FlexFlow::LegacyPMParameter::REPARTITION_DEGREE, + FlexFlow::LegacyPMParameter::REPLICATE_DIM, + FlexFlow::LegacyPMParameter::REPLICATE_DEGREE, + FlexFlow::LegacyPMParameter::COMBINE_DIM, + FlexFlow::LegacyPMParameter::COMBINE_DEGREE, + FlexFlow::LegacyPMParameter::REDUCTION_DIM, + FlexFlow::LegacyPMParameter::REDUCTION_DEGREE, + FlexFlow::LegacyPMParameter::SOFTMAX_DIM, + FlexFlow::LegacyPMParameter::NUM_HEADS, + FlexFlow::LegacyPMParameter::PARALLEL_DIM, + FlexFlow::LegacyPMParameter::PARALLEL_DEGREE, + FlexFlow::LegacyPMParameter::PAD); +} +} // namespace rc diff --git a/lib/substitution-generator/test/substitution-generator/json.cc b/lib/substitution-generator/test/substitution-generator/json.cc index 3b177f2bfe..befdaf1308 100644 --- a/lib/substitution-generator/test/substitution-generator/json.cc +++ b/lib/substitution-generator/test/substitution-generator/json.cc @@ -18,7 +18,7 @@ TEST_SUITE(FF_TEST_SUITE) { Operator o; from_json(j, o); - CHECK(o.op_type == OperatorType::EW_ADD); + CHECK(o.op_type == LegacyOperatorType::EW_ADD); CHECK(o.input.size() == 2); CHECK(o.input[0].opId == -2); CHECK(o.input[0].tsId == 0); From 0732e8b701b7dca6f3dd8dadc6148341260a1a63 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Mon, 3 Jun 2024 16:10:53 -0700 Subject: [PATCH 39/43] Format --- .../include/substitution-generator/json.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/substitution-generator/include/substitution-generator/json.h b/lib/substitution-generator/include/substitution-generator/json.h index 930d9c3f3f..5563d8a835 100644 --- a/lib/substitution-generator/include/substitution-generator/json.h +++ b/lib/substitution-generator/include/substitution-generator/json.h @@ -1,10 +1,10 @@ #ifndef _FLEXFLOW_SUBSTITUTION_LOADER_H #define _FLEXFLOW_SUBSTITUTION_LOADER_H -#include -#include #include "substitution-generator/legacy_operator_type.dtg.h" #include "substitution-generator/legacy_pm_parameter.dtg.h" +#include +#include #include namespace FlexFlow { From cdfdd118b5bc6740c084c21bfc49627ced95dcd2 Mon Sep 17 00:00:00 2001 From: Qinghan Chen Date: Mon, 3 Jun 2024 19:36:40 -0400 Subject: [PATCH 40/43] fix ci coverage attempt --- .github/workflows/per-lib-check.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/per-lib-check.yml b/.github/workflows/per-lib-check.yml index 70308f990b..670d6dbfe8 100644 --- a/.github/workflows/per-lib-check.yml +++ b/.github/workflows/per-lib-check.yml @@ -111,8 +111,8 @@ jobs: - name: Generate code coverage run: | lcov --capture --directory . --output-file main_coverage.info - lcov --extract main_coverage.info 'lib/*' --output-file main_coverage.info - lcov --remove main_coverage.info 'lib/*.dtg.h' 'lib/*.dtg.cc' --output-file main_coverage.info + lcov --extract main_coverage.info "$GITHUB_WORKSPACE/lib/*" --output-file main_coverage.info + lcov --remove main_coverage.info "$GITHUB_WORKSPACE/lib/*.dtg.h" "$GITHUB_WORKSPACE/lib/*.dtg.cc" --output-file main_coverage.info lcov --list main_coverage.info - name: Upload code coverage From 137e59c7e63425b6619aa079471f0981f595eb9b Mon Sep 17 00:00:00 2001 From: Qinghan Chen Date: Mon, 3 Jun 2024 22:43:29 -0400 Subject: [PATCH 41/43] second attempt --- .github/workflows/per-lib-check.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/per-lib-check.yml b/.github/workflows/per-lib-check.yml index 69148dc181..0d4bc9cbc7 100644 --- a/.github/workflows/per-lib-check.yml +++ b/.github/workflows/per-lib-check.yml @@ -111,9 +111,9 @@ jobs: - name: Generate code coverage run: | lcov --capture --directory . --output-file main_coverage.info - lcov --extract main_coverage.info "$GITHUB_WORKSPACE/lib/*" --output-file main_coverage.info - lcov --remove main_coverage.info "$GITHUB_WORKSPACE/lib/*.dtg.h" "$GITHUB_WORKSPACE/lib/*.dtg.cc" --output-file main_coverage.info - lcov --list main_coverage.info + lcov --extract main_coverage.info "$GITHUB_WORKSPACE/lib/*" --output-file main_coverage_e.info + lcov --remove main_coverage_e.info "$GITHUB_WORKSPACE/lib/*.dtg.h" "$GITHUB_WORKSPACE/lib/*.dtg.cc" --output-file main_coverage_e_f.info + lcov --list main_coverage_e_f.info - name: Upload code coverage uses: codecov/codecov-action@v4 From 70e1f09d8cfb47cccede61193c986055fcc1915e Mon Sep 17 00:00:00 2001 From: Qinghan Chen Date: Mon, 3 Jun 2024 22:49:31 -0400 Subject: [PATCH 42/43] small fix --- .github/workflows/per-lib-check.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/per-lib-check.yml b/.github/workflows/per-lib-check.yml index 0d4bc9cbc7..ac014e0644 100644 --- a/.github/workflows/per-lib-check.yml +++ b/.github/workflows/per-lib-check.yml @@ -110,6 +110,7 @@ jobs: - name: Generate code coverage run: | + echo "gitwork: $GITHUB_WORKSPACE" lcov --capture --directory . --output-file main_coverage.info lcov --extract main_coverage.info "$GITHUB_WORKSPACE/lib/*" --output-file main_coverage_e.info lcov --remove main_coverage_e.info "$GITHUB_WORKSPACE/lib/*.dtg.h" "$GITHUB_WORKSPACE/lib/*.dtg.cc" --output-file main_coverage_e_f.info @@ -119,7 +120,7 @@ jobs: uses: codecov/codecov-action@v4 with: token: ${{ secrets.CODECOV_TOKEN }} - files: main_coverage.info + files: main_coverage_e_f.info flags: unittests name: codecov-umbrella fail_ci_if_error: false From 6230d992d5227fbd2b94232c4e590b9a2d6ac1b4 Mon Sep 17 00:00:00 2001 From: Qinghan Chen Date: Mon, 3 Jun 2024 23:18:47 -0400 Subject: [PATCH 43/43] small fix --- .github/workflows/per-lib-check.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/per-lib-check.yml b/.github/workflows/per-lib-check.yml index ac014e0644..c554c54dc4 100644 --- a/.github/workflows/per-lib-check.yml +++ b/.github/workflows/per-lib-check.yml @@ -120,8 +120,9 @@ jobs: uses: codecov/codecov-action@v4 with: token: ${{ secrets.CODECOV_TOKEN }} - files: main_coverage_e_f.info + file: main_coverage_e_f.info flags: unittests + plugin: pycoverage #hope this will disable gcov name: codecov-umbrella fail_ci_if_error: false verbose: true