Implement infer_bwd_binding, task reg #1097
reyna-abhyankar wants to merge 26 commits into flexflow:repo-refactor
lockshaw
left a comment
Test cases for infer_bwd_binding should be added
Reviewed 6 of 6 files at r1, all commit messages.
Reviewable status: all files reviewed, 3 unresolved discussions (waiting on @reyna-abhyankar)
lib/runtime/src/task_spec/op_task_invocation.h line 87 at r1 (raw file):
void bind_args_from_fwd(OpTaskBinding const &fwd) { this->arg_bindings = fwd.get_arg_bindings();
Can you make sure these are either implemented or marked as NOT_IMPLEMENTED?
lib/runtime/src/task_spec/op_task_invocation.cc line 51 at r1 (raw file):
bwd.bind_args_from_fwd(fwd); bwd.bind_tensors_from_fwd(fwd); for (auto const &[key, spec] : fwd.get_tensor_bindings()) {
We probably need to handle untrainable parameters here.
lib/runtime/src/task_spec/task_signature.h line 45 at r1 (raw file):
void add_variadic_arg_slot(slot_id name); static std::unordered_map<task_id_t, TaskSignature> task_sig_map;
Do you think it's better to have these maps per type, or would it be better to have a SignatureStore class (or something like it) that stores them? I'd lean toward the second, just so that all of the signature storing and retrieval behavior is close by.
reyna-abhyankar
left a comment
Reviewable status: all files reviewed, 3 unresolved discussions (waiting on @lockshaw)
lib/runtime/src/task_spec/op_task_invocation.h line 87 at r1 (raw file):
Previously, lockshaw (Colin Unger) wrote…
Can you make sure these are either implemented or marked as NOT_IMPLEMENTED?
Yes, get_arg_bindings() and get_tensor_bindings() are already implemented. See op_task_invocation.cc
lib/runtime/src/task_spec/op_task_invocation.cc line 51 at r1 (raw file):
Previously, lockshaw (Colin Unger) wrote…
We probably need to handle untrainable parameters here.
The signature has OpTensorSlotSpec and the binding uses OpTensorSpec. The slot spec is more descriptive and has OpSlotOptions, which is where untrainable is handled. The binding is only aware of the TensorRole (input, weight, or output). Is there a reason that untrainable also needs to be specified in the binding? I think the signature should suffice.
lib/runtime/src/task_spec/task_signature.h line 45 at r1 (raw file):
Previously, lockshaw (Colin Unger) wrote…
Do you think it's better to have these maps per type, or would it be better to have a SignatureStore class (or something like it) that stores them? I'd lean toward the second, just so that all of the signature storing and retrieval behavior is close by.
At the moment, we only have two: one for task and one for operator task. If SignatureStore is just going to be a wrapper then I don't think it would be more useful.
lockshaw
left a comment
Reviewed 8 of 8 files at r2, all commit messages.
Reviewable status: all files reviewed, 10 unresolved discussions (waiting on @reyna-abhyankar)
lib/runtime/src/task_spec/op_task_invocation.h line 136 at r1 (raw file):
OpTaskSignature infer_bwd_signature(OpTaskSignature const &fwd); OpTaskBinding infer_bwd_binding(OpTaskBinding const &fwd); OpTaskSignature get_op_signature(task_id_t const &);
Why remove this?
lib/runtime/src/task_spec/op_task_invocation.h line 43 at r2 (raw file):
TaskInvocationSpec>; struct OpArgSpecTypeAccessor {
Can we just move all of these over to having a type tag?
lib/runtime/src/task_spec/op_task_invocation.h line 118 at r2 (raw file):
} bool operator==(const OpTaskBinding& rhs) {
Probably should use visitable instead
lib/runtime/src/task_spec/op_task_invocation.cc line 51 at r1 (raw file):
Previously, reyna-abhyankar (Reyna Abhyankar) wrote…
The signature has OpTensorSlotSpec and the binding uses OpTensorSpec. The slot spec is more descriptive and has OpSlotOptions, which is where untrainable is handled. The binding is only aware of the TensorRole (input, weight, or output). Is there a reason that untrainable also needs to be specified in the binding? I think the signature should suffice.
My point is that if a tensor is marked untrainable in the signature, it probably shouldn't have its gradient bound (as it has no gradient)
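To make the point above concrete, here is a minimal sketch of a backward-binding inference that consults the signature and skips gradient bindings for untrainable slots. All of the types here (`SketchSignature`, `TensorBindings`, the `UNTRAINABLE` enumerator, and the function name `infer_bwd_tensor_bindings`) are simplified, hypothetical stand-ins, not the actual FlexFlow definitions.

```cpp
#include <map>
#include <utility>

// Hypothetical, simplified stand-ins for the types discussed in this thread.
using slot_id = int;
enum class IsGrad { NO, YES };
enum class TensorRole { INPUT, WEIGHT, OUTPUT };
enum class OpSlotOptions { NECESSARY, UNTRAINABLE };

struct OpTensorSpec {
  TensorRole role;
  int idx;
};

using TensorBindings = std::map<std::pair<slot_id, IsGrad>, OpTensorSpec>;

// Signature side (assumed shape): per-slot options, including untrainable.
struct SketchSignature {
  std::map<slot_id, OpSlotOptions> slot_options;
};

// Copy every forward binding, then add a gradient binding only for slots
// that the signature does not mark as untrainable.
TensorBindings infer_bwd_tensor_bindings(TensorBindings const &fwd,
                                         SketchSignature const &sig) {
  TensorBindings bwd = fwd;
  for (auto const &[key, spec] : fwd) {
    slot_id slot = key.first;
    auto it = sig.slot_options.find(slot);
    bool untrainable = it != sig.slot_options.end() &&
                       it->second == OpSlotOptions::UNTRAINABLE;
    if (!untrainable) {
      bwd.emplace(std::make_pair(slot, IsGrad::YES), spec);
    }
  }
  return bwd;
}
```

With this shape the binding does not need to carry slot options at all; the untrainable check lives entirely on the signature side.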
lib/runtime/src/task_spec/op_task_signature.h line 74 at r2 (raw file):
} // TODO: what is this function's purpose? OpTaskSignature doesn't need a return value
Why does it not need a return value? For example, init_task for some operators returns a DeviceSpecific<xxxxxPerDeviceState>
lib/runtime/src/task_spec/op_task_signature.h line 92 at r2 (raw file):
private: OpTaskType type;
What is the purpose of this field?
lib/runtime/src/task_spec/op_task_signature.cc line 77 at r2 (raw file):
} bool OpTaskSignature::operator==(OpTaskSignature const & rhs) const {
Make OpTaskSignature a FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION so you get these automatically
lib/runtime/src/task_spec/task_signature.h line 45 at r1 (raw file):
Previously, reyna-abhyankar (Reyna Abhyankar) wrote…
At the moment, we only have two: one for task and one for operator task. If SignatureStore is just going to be a wrapper then I don't think it would be more useful.
It would just be nice to have all of the signature storage in one place in the code rather than in two, but I'll go with your preference
lib/runtime/test/src/test_op_task_invocation.cc line 13 at r2 (raw file):
// get binding from operator forward OpTaskSignature signature = fwd_signature<AGGREGATE_FWD_TASK_ID>(); AggregateAttrs attrs {n: 12, lambda_bal: 1.0};
I don't think this is valid C++ syntax?
lib/runtime/test/src/test_op_task_invocation.cc line 37 at r2 (raw file):
assert (arg_types.count(slot_id)); assert (arg_types.at(slot_id) == op_arg_spec_type); }
This should probably be pulled out into a separate function as every function's operator tests will be using it?
Also, tests should use CHECK, not assert
Code quote:
// check tensors
auto const& tensor_slots = signature.get_tensor_slots();
auto const& tensor_bindings = binding.get_tensor_bindings();
assert (tensor_slots.size() == tensor_bindings.size());
for (OpTensorSlotSpec const& tensor_slot_spec: tensor_slots) {
slot_id name = tensor_slot_spec.name;
assert (tensor_bindings.count({name, IsGrad::NO}));
assert (!tensor_bindings.count({name, IsGrad::YES}));
OpTensorSpec const tensor_spec = tensor_bindings.at({name, IsGrad::NO});
assert (tensor_spec.role == tensor_slot_spec.tensor_role);
}
// check arg types
auto const& arg_types = signature.get_arg_types();
auto const& arg_bindings = binding.get_arg_bindings();
assert (arg_types.size() == arg_bindings.size());
for (auto const &[slot_id, op_arg_spec] : arg_bindings) {
std::type_index op_arg_spec_type = std::visit(OpArgSpecTypeAccessor(), op_arg_spec);
assert (arg_types.count(slot_id));
assert (arg_types.at(slot_id) == op_arg_spec_type);
}
wmdi
left a comment
Reviewed 16 of 24 files at r3, all commit messages.
Reviewable status: 19 of 27 files reviewed, 11 unresolved discussions (waiting on @lambda7xx, @lockshaw, and @reyna-abhyankar)
lib/op-attrs/include/op-attrs/tensor_shape.h line 19 at r4 (raw file):
template <typename Dims> TensorShape(Dims const &dims, DataType data_type) : dims(this->dims), data_type(this->data_type) {}
Why?
lockshaw
left a comment
Reviewed 8 of 24 files at r3, 1 of 1 files at r4, all commit messages.
Reviewable status: all files reviewed, 28 unresolved discussions (waiting on @lambda7xx, @reyna-abhyankar, and @wmdi)
lib/op-attrs/include/op-attrs/tensor_shape.h line 19 at r4 (raw file):
Previously, wmdi (Mengdi Wu) wrote…
Why?
I think this is a bug and should be reverted: this->dims and this->data_type aren't initialized yet
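For illustration, a minimal sketch of the constructor with the members initialized from the parameters rather than from themselves. `TensorShapeSketch` and its member types are simplified assumptions, not the real `TensorShape`.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical, simplified stand-ins; the real TensorShape uses richer types.
enum class DataType { FLOAT, INT32 };

struct TensorShapeSketch {
  // Inside the member initializer list, the unqualified names on the right
  // refer to the constructor parameters. Writing dims(this->dims) would
  // instead copy the member from itself before it has been initialized.
  template <typename Dims>
  TensorShapeSketch(Dims const &dims, DataType data_type)
      : dims(dims.begin(), dims.end()), data_type(data_type) {}

  std::vector<std::size_t> dims;
  DataType data_type;
};
```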
lib/runtime/src/task_spec/op_task_invocation.cc line 6 at r4 (raw file):
namespace FlexFlow { OpTensorSpec input_tensor(int idx, OpSlotOptions option = OpSlotOptions::NECESSARY) {
Shouldn't the slot options be on the slot (i.e., as part of the signature) and not part of the binding?
lib/runtime/src/task_spec/op_task_signature.h line 66 at r4 (raw file):
bool operator==(OpTaskSignature const &) const; bool operator!=(OpTaskSignature const &) const;
Provided by visitable
lib/runtime/src/task_spec/op_task_signature.h line 98 at r4 (raw file):
OpTaskSignature get_op_signature(task_id_t const &); FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(OpTaskSignature, get_tensor_slots, set_arg_types, get_arg_types);
Suggestion:
FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(OpTaskSignature, op_tensor_slots, set_arg_types, task_arg_types);
lib/runtime/src/task_spec/op_task_signature.cc line 5 at r2 (raw file):
namespace FlexFlow { OpTaskSignature::OpTaskSignature(OpTaskType t) {
Why was this deleted? I think it's necessary?
lib/runtime/src/task_spec/task_argument_accessor.h line 86 at r4 (raw file):
std::vector<privilege_mode_to_accessor<Permissions::WO>>>; struct ITaskArgumentAccessor {
Move into its own header/source files
lib/runtime/src/task_spec/task_argument_accessor.h line 89 at r4 (raw file):
virtual PrivilegeType get_tensor(slot_id slot, Permissions priv) const = 0; virtual PrivilegeVariadicType get_variadic_tensor(slot_id slot, Permissions priv) const = 0;
Why was the template parameter removed from these? By requiring it to be static we were able to infer the return type and didn't have to return a variant
lib/runtime/src/task_spec/task_argument_accessor.h line 98 at r4 (raw file):
virtual size_t get_device_idx() const = 0; };
Should be used on anything that declares a virtual method (checks https://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines#c67-a-polymorphic-class-should-suppress-public-copymove)
Suggestion:
virtual size_t get_device_idx() const = 0;
};
CHECK_RC_COPY_VIRTUAL_COMPLIANT(ITaskArgumentAccessor);
lib/runtime/src/task_spec/task_argument_accessor.h line 132 at r4 (raw file):
return acc.get_variadic_argument<T>(slot); } };
lib/runtime/src/task_spec/task_argument_accessor.h line 134 at r4 (raw file):
}; struct LegionTaskArgumentAccessor : public ITaskArgumentAccessor {
Move into its own header/source files
lib/runtime/src/task_spec/task_argument_accessor.h line 169 at r4 (raw file):
}; struct LocalTaskArgumentAccessor : public ITaskArgumentAccessor {
Move into its own header/source files
lib/runtime/src/task_spec/task_argument_accessor.h line 201 at r4 (raw file):
void *allocate(size_t size); void deallocate(void *ptr);
The operators expect to be able to get an Allocator. Memory usage tracking should instead be moved to a wrapper around Allocator (e.g., TrackedAllocator(Allocator const &)) (which is also an Allocator of course)
Suggestion:
Allocator get_allocator() const;
lib/runtime/src/task_spec/task_argument_accessor.h line 204 at r4 (raw file):
private: std::shared_ptr<SimTaskBinding const> sim_task_binding;
Is there a reason we have this as a shared ptr rather than a value?
lib/runtime/src/task_spec/task_argument_accessor.h line 210 at r4 (raw file):
using TaskArgumentAccessorBackend = variant<LegionTaskArgumentAccessor, LocalTaskArgumentAccessor>;
lib/runtime/src/task_spec/task_argument_accessor.h line 221 at r4 (raw file):
optional<T> const &get_optional_argument(slot_id slot) const { return std::visit(OptionalArgumentAccessor(), this->ptr, slot); }
Also for get_argument and get_variadic_argument
Suggestion:
template <typename T>
optional<T> const &get_optional_argument(slot_id slot) const {
return this->ptr->get_optional_argument(slot);
}
lib/runtime/src/task_spec/task_argument_accessor.h line 251 at r4 (raw file):
TaskArgumentAccessor(std::shared_ptr<TaskArgumentAccessorBackend const> &ptr) : ptr(ptr) {} std::shared_ptr<TaskArgumentAccessorBackend const> ptr;
Suggestion:
TaskArgumentAccessor(std::shared_ptr<ITaskArgumentAccessorBackend const> ptr)
: ptr(ptr) {}
std::shared_ptr<ITaskArgumentAccessorBackend const> ptr;
lib/runtime/src/task_spec/task_argument_accessor.cc line 6 at r4 (raw file):
template <typename T> T const &LocalTaskArgumentAccessor::get_argument(slot_id slot) const {
This needs to be rewritten entirely to use the signature/binding
lib/runtime/src/task_spec/task_argument_accessor.cc line 8 at r4 (raw file):
T const &LocalTaskArgumentAccessor::get_argument(slot_id slot) const { if (slot == PROFILING) { return get<ProfilingSettings>(this->arg_bindings.at(slot));
This should be retrieved from the signature/binding--the value of PROFILING can be different in every operator, and it can be assigned an arbitrary type (I could in theory have a slot PROFILING that takes in a Conv2DAttrs)
lib/runtime/src/task_spec/task_argument_accessor.cc line 29 at r4 (raw file):
free(cpu_ptr); return ptr; }
This whole file needs to be rewritten to use the signature and binding to generate values, not hardcoding in private enum values
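As a rough sketch of what "retrieved from the signature/binding" could look like: the signature declares the type of each argument slot, the binding holds the value, and the accessor checks one against the other instead of comparing against hardcoded slot enums. `ArgSignatureSketch`, `ArgBindingSketch`, and the use of `std::any` are assumptions made for this example only.

```cpp
#include <any>
#include <map>
#include <stdexcept>
#include <typeindex>
#include <typeinfo>

using slot_id = int; // hypothetical stand-in

// Assumed shape: the signature records the declared type of each slot.
struct ArgSignatureSketch {
  std::map<slot_id, std::type_index> arg_types;
};

// Assumed shape: the binding stores one type-erased value per slot.
struct ArgBindingSketch {
  std::map<slot_id, std::any> arg_values;
};

// Look up a slot generically: no slot id has a fixed meaning or type here,
// so a slot named PROFILING could legitimately hold any declared type.
template <typename T>
T const &get_argument(ArgSignatureSketch const &sig,
                      ArgBindingSketch const &binding,
                      slot_id slot) {
  auto type_it = sig.arg_types.find(slot);
  if (type_it == sig.arg_types.end()) {
    throw std::out_of_range("slot is not declared in the signature");
  }
  if (type_it->second != std::type_index(typeid(T))) {
    throw std::invalid_argument("requested type does not match the signature");
  }
  auto value_it = binding.arg_values.find(slot);
  if (value_it == binding.arg_values.end()) {
    throw std::out_of_range("slot is not bound");
  }
  return std::any_cast<T const &>(value_it->second);
}
```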
reyna-abhyankar
left a comment
Reviewable status: all files reviewed, 27 unresolved discussions (waiting on @lambda7xx, @lockshaw, and @wmdi)
lib/op-attrs/include/op-attrs/tensor_shape.h line 19 at r4 (raw file):
Previously, lockshaw (Colin Unger) wrote…
I think this is a bug and should be reverted: this->dims and this->data_type aren't initialized yet
Done.
lib/runtime/src/task_spec/op_task_invocation.h line 136 at r1 (raw file):
Previously, lockshaw (Colin Unger) wrote…
Why remove this?
This is moved to op_task_signature.h
lib/runtime/src/task_spec/op_task_invocation.h line 118 at r2 (raw file):
Previously, lockshaw (Colin Unger) wrote…
Probably should use visitable instead
Done.
lib/runtime/src/task_spec/op_task_invocation.cc line 51 at r1 (raw file):
Previously, lockshaw (Colin Unger) wrote…
My point is that if a tensor is marked untrainable in the signature, it probably shouldn't have its gradient bound (as it has no gradient)
See response to your other comment on OpSlotOptions for binding.
lib/runtime/src/task_spec/op_task_invocation.cc line 6 at r4 (raw file):
Previously, lockshaw (Colin Unger) wrote…
Shouldn't the slot options be on the slot (i.e., as part of the signature) and not part of the binding?
I added it to the binding so we can check untrainable in infer_bwd_binding(), but the other option is to have infer_bwd_binding() also take the signature as an argument.
lib/runtime/src/task_spec/op_task_signature.h line 74 at r2 (raw file):
Previously, lockshaw (Colin Unger) wrote…
Why does it not need a return value? For example, init_task for some operators returns a DeviceSpecific<xxxxxPerDeviceState>
Done. Should it be a vector of return values or is there only one return value?
lib/runtime/src/task_spec/op_task_signature.h line 92 at r2 (raw file):
Previously, lockshaw (Colin Unger) wrote…
What is the purpose of this field?
There is a get_task_type() function that is called in task_spec.cc. I've implemented it as follows:
Code snippet:
OpTaskType get_task_type() const {
return this->type;
}
lib/runtime/src/task_spec/op_task_signature.h line 66 at r4 (raw file):
Previously, lockshaw (Colin Unger) wrote…
Provided by visitable
Done.
lib/runtime/src/task_spec/op_task_signature.h line 98 at r4 (raw file):
OpTaskSignature get_op_signature(task_id_t const &); FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(OpTaskSignature, get_tensor_slots, set_arg_types, get_arg_types);
Done. But I'm also getting a "member is inaccessible" error because they're private. Should we make them public?
lib/runtime/src/task_spec/op_task_signature.cc line 5 at r2 (raw file):
Previously, lockshaw (Colin Unger) wrote…
Why was this deleted? I think it's necessary?
Added this as
Code snippet:
OpTaskSignature::OpTaskSignature(OpTaskType t) {
this->type = t;
}
lib/runtime/src/task_spec/op_task_signature.cc line 77 at r2 (raw file):
Previously, lockshaw (Colin Unger) wrote…
Make OpTaskSignature a FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION so you get these automatically
Done. See the comment above about member accessibility.
lib/runtime/src/task_spec/task_argument_accessor.h line 86 at r4 (raw file):
Previously, lockshaw (Colin Unger) wrote…
Move into its own header/source files
Done.
lib/runtime/src/task_spec/task_argument_accessor.h line 89 at r4 (raw file):
Previously, lockshaw (Colin Unger) wrote…
Why was the template parameter removed from these? By requiring it to be static we were able to infer the return type and didn't have to return a variant
AFAIK C++ does not allow virtual template functions. I was getting errors
lib/runtime/src/task_spec/task_argument_accessor.h line 98 at r4 (raw file):
Previously, lockshaw (Colin Unger) wrote…
Should be used on anything that declares a virtual method (checks https://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines#c67-a-polymorphic-class-should-suppress-public-copymove)
Done. Added delete for operator= and virtual destructor.
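A minimal sketch of an interface that follows the linked guideline (C.67): copy operations are deleted and the destructor is virtual. The type names here are hypothetical, and this does not show what the project's CHECK_RC_COPY_VIRTUAL_COMPLIANT macro actually verifies; it only illustrates the guideline itself.

```cpp
#include <cstddef>
#include <memory>
#include <type_traits>

// Hypothetical interface following C.67: a polymorphic class suppresses
// public copy/move so objects cannot be sliced through the base type.
struct IAccessorSketch {
  IAccessorSketch() = default;
  IAccessorSketch(IAccessorSketch const &) = delete;
  IAccessorSketch &operator=(IAccessorSketch const &) = delete;
  virtual ~IAccessorSketch() = default;

  virtual std::size_t get_device_idx() const = 0;
};

// Hypothetical concrete backend; it inherits the deleted copy operations.
struct LocalAccessorSketch : public IAccessorSketch {
  explicit LocalAccessorSketch(std::size_t idx) : idx(idx) {}
  std::size_t get_device_idx() const override { return idx; }

private:
  std::size_t idx;
};

static_assert(!std::is_copy_constructible<LocalAccessorSketch>::value,
              "copy construction should be suppressed");
static_assert(!std::is_copy_assignable<LocalAccessorSketch>::value,
              "copy assignment should be suppressed");
static_assert(std::has_virtual_destructor<IAccessorSketch>::value,
              "the interface needs a virtual destructor");
```

Because declaring the copy operations suppresses the implicit move operations, instances are then handled through pointers such as `std::unique_ptr<IAccessorSketch>` or `std::shared_ptr<IAccessorSketch const>`.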
lib/runtime/src/task_spec/task_argument_accessor.h line 132 at r4 (raw file):
return acc.get_variadic_argument<T>(slot); } };
Done.
lib/runtime/src/task_spec/task_argument_accessor.h line 134 at r4 (raw file):
Previously, lockshaw (Colin Unger) wrote…
Move into its own header/source files
Done.
lib/runtime/src/task_spec/task_argument_accessor.h line 169 at r4 (raw file):
Previously, lockshaw (Colin Unger) wrote…
Move into its own header/source files
Done.
lib/runtime/src/task_spec/task_argument_accessor.h line 210 at r4 (raw file):
using TaskArgumentAccessorBackend = variant<LegionTaskArgumentAccessor, LocalTaskArgumentAccessor>;
Done.
lib/runtime/src/task_spec/task_argument_accessor.h line 221 at r4 (raw file):
Previously, lockshaw (Colin Unger) wrote…
Also for get_argument and get_variadic_argument
Nice. I assumed that because it's a variant, we had to visit it to access the function. I've made the change.
lib/runtime/src/task_spec/task_argument_accessor.h line 251 at r4 (raw file):
TaskArgumentAccessor(std::shared_ptr<TaskArgumentAccessorBackend const> &ptr) : ptr(ptr) {} std::shared_ptr<TaskArgumentAccessorBackend const> ptr;
Done.
lib/runtime/test/src/test_op_task_invocation.cc line 13 at r2 (raw file):
Previously, lockshaw (Colin Unger) wrote…
I don't think this is valid C++ syntax?
This is just an initializer list, right? My compiler is not complaining
lib/runtime/test/src/test_op_task_invocation.cc line 37 at r2 (raw file):
Previously, lockshaw (Colin Unger) wrote…
This should probably be pulled out into a separate function as every function's operator tests will be using it?
Also, tests should use CHECK, not assert
Done.
lockshaw
left a comment
Reviewed 7 of 9 files at r5, 7 of 7 files at r6, all commit messages.
Reviewable status: all files reviewed, 33 unresolved discussions (waiting on @lambda7xx, @reyna-abhyankar, and @wmdi)
lib/kernels/include/kernels/allocation.h line 34 at r6 (raw file):
}; struct TrackedAllocator: public Allocator {
Is this better as a subclass of Allocator or of IAllocator?
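To make one of the two options concrete, here is a sketch of a tracking wrapper that implements the allocator interface and forwards to an inner allocator. `IAllocatorSketch` and `MallocAllocatorSketch` are hypothetical stand-ins; the real `Allocator` and `IAllocator` in kernels/allocation.h may be shaped differently.

```cpp
#include <cstddef>
#include <cstdlib>
#include <map>
#include <memory>
#include <utility>

// Hypothetical allocator interface, assumed for this sketch only.
struct IAllocatorSketch {
  virtual void *allocate(std::size_t size) = 0;
  virtual void deallocate(void *ptr) = 0;
  virtual ~IAllocatorSketch() = default;
};

// Trivial backend so the example is self-contained.
struct MallocAllocatorSketch : public IAllocatorSketch {
  void *allocate(std::size_t size) override { return std::malloc(size); }
  void deallocate(void *ptr) override { std::free(ptr); }
};

// Wrapper that is itself an allocator and records live allocation sizes.
struct TrackedAllocatorSketch : public IAllocatorSketch {
  explicit TrackedAllocatorSketch(std::shared_ptr<IAllocatorSketch> inner)
      : inner(std::move(inner)) {}

  void *allocate(std::size_t size) override {
    void *ptr = inner->allocate(size);
    if (ptr != nullptr) {
      sizes[ptr] = size;
      current_usage += size;
    }
    return ptr;
  }

  void deallocate(void *ptr) override {
    auto it = sizes.find(ptr);
    if (it != sizes.end()) {
      current_usage -= it->second;
      sizes.erase(it);
    }
    inner->deallocate(ptr);
  }

  std::size_t get_memory_usage() const { return current_usage; }

private:
  std::shared_ptr<IAllocatorSketch> inner;
  std::map<void *, std::size_t> sizes;
  std::size_t current_usage = 0;
};
```

In this variant the accessor would not need its own allocate, deallocate, or memory_usage members; callers that want usage numbers keep a handle to the tracked wrapper.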
lib/runtime/src/task_spec/itask_argument_accessor.h line 52 at r6 (raw file):
FF_VISITABLE_STRUCT(FutureArgumentFormat, type, future_idx); struct TaskArgumentsFormat {
TaskArgumentsFormat should be in a different file than ITaskArgumentAccessor, as TaskArgumentsFormat is part of compilation to legion tasks and ITaskArgumentAccessor is shared across all backends
lib/runtime/src/task_spec/itask_argument_accessor.h line 92 at r6 (raw file):
struct ITaskArgumentAccessor { ITaskArgumentAccessor& operator=(const ITaskArgumentAccessor&) = delete; virtual ~ITaskArgumentAccessor() {};
Suggestion:
virtual ~ITaskArgumentAccessor() = default;
lib/runtime/src/task_spec/legion_task_argument_accessor.h line 20 at r6 (raw file):
namespace FlexFlow { struct LegionTaskArgumentAccessor : public ITaskArgumentAccessor {
You're missing some overrides
lib/runtime/src/task_spec/legion_task_argument_accessor.cc line 0 at r6 (raw file):
#include "runtime/task_spec/legion_task_argument_accessor.h"
lib/runtime/src/task_spec/local_task_argument_accessor.h line 29 at r6 (raw file):
PrivilegeType get_tensor(slot_id slot, Permissions priv) const override; PrivilegeVariadicType get_variadic_tensor(slot_id slot, Permissions priv) const override;
You're missing some overrides
lib/runtime/src/task_spec/local_task_argument_accessor.h line 53 at r6 (raw file):
void *allocate(size_t size); void deallocate(void *ptr);
Delete
Code quote:
size_t get_memory_usage() const {
return memory_usage;
}
void *allocate(size_t size);
void deallocate(void *ptr);
lib/runtime/src/task_spec/local_task_argument_accessor.h line 59 at r6 (raw file):
SimTaskBinding sim_task_binding; Allocator local_allocator; size_t memory_usage;
Delete
lib/runtime/src/task_spec/local_task_argument_accessor.cc line 7 at r6 (raw file):
template <typename T> T const &LocalTaskArgumentAccessor::get_argument(slot_id slot) const { assert(contains_key(this->sim_task_binding.arg_bindings.at(slot)));
contains_key takes multiple arguments. You should be able to compile these files to catch errors like this, even if you're not able to build+link the full runtime
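For reference, a sketch of the intended check, assuming the helper takes the container and the key as two separate arguments. `contains_key_sketch` is a local stand-in written for this example (the project's own `contains_key` may have a different signature), and the binding map type is hypothetical.

```cpp
#include <map>
#include <stdexcept>
#include <string>

using slot_id = int; // hypothetical stand-in

// Local stand-in: takes the map and the key, returns whether the key exists.
template <typename Map>
bool contains_key_sketch(Map const &m, typename Map::key_type const &k) {
  return m.find(k) != m.end();
}

// Check for the slot first, then index; calling .at(slot) inside the check
// would already throw for a missing slot and pass a value, not a map.
inline std::string const &
    get_bound_argument(std::map<slot_id, std::string> const &arg_bindings,
                       slot_id slot) {
  if (!contains_key_sketch(arg_bindings, slot)) {
    throw std::out_of_range("slot is not bound");
  }
  return arg_bindings.at(slot);
}
```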
lib/runtime/src/task_spec/local_task_argument_accessor.cc line 12 at r6 (raw file):
template <typename T> optional<T> LocalTaskArgumentAccessor::get_optional_argument(slot_id slot) const {
Suggestion:
std::optional<T>
lib/runtime/src/task_spec/local_task_argument_accessor.cc line 24 at r6 (raw file):
} void *LocalTaskArgumentAccessor::allocate(size_t size) {
Delete
lib/runtime/src/task_spec/local_task_argument_accessor.cc line 42 at r6 (raw file):
if (slot == GATE_PREDS || slot == GATE_ASSIGN) {
Why is there special handling needed for some slots?
lib/runtime/src/task_spec/local_task_argument_accessor.cc line 47 at r6 (raw file):
DataType data_type = gate_preds.shape.data_type; ArrayShape array_shape = { gate_preds.shape.dims.get_dims()}; // gate_preds.shape.dims.get_dims()
remove comments
lib/runtime/src/task_spec/local_task_argument_accessor.cc line 57 at r6 (raw file):
Datatype data_type = output_shape.data_type; ArrayShape array_shape = { output_shape.dims.get_dims()}; // output_shape.dims.get_dims() return
remove comments
lib/runtime/src/task_spec/local_task_argument_accessor.cc line 64 at r6 (raw file):
} else { throw mk_runtime_error( "Unknown Slot ID in LocalTaskArgumentAccessor::get_tensor");
Suggestion:
"Unknown Slot ID {} in LocalTaskArgumentAccessor::get_tensor", slot);
lib/runtime/src/task_spec/local_task_argument_accessor.cc line 78 at r6 (raw file):
ArrayShape array_shape = { shape.dims .get_dims()}; // shape.dims.get_dims() return std::vector<size_t>
remove comment
lib/runtime/src/task_spec/op_task_invocation.h line 118 at r2 (raw file):
Previously, reyna-abhyankar (Reyna Abhyankar) wrote…
Done.
You should actually be using FF_VISITABLE_STRUCT (or in this case, FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION) instead of use_visitable_cmp (which is deprecated except in certain cases when you're working with template classes)
lib/runtime/src/task_spec/op_task_invocation.cc line 6 at r4 (raw file):
Previously, reyna-abhyankar (Reyna Abhyankar) wrote…
I added it to the binding so we can check untrainable in infer_bwd_binding(), but the other option is to have infer_bwd_binding() also take the signature as an argument.
It should have access to the signature
lib/runtime/src/task_spec/op_task_signature.h line 74 at r2 (raw file):
Previously, reyna-abhyankar (Reyna Abhyankar) wrote…
Done. Should it be a vector of return values or is there only one return value?
Just one return value
lib/runtime/src/task_spec/op_task_signature.h line 98 at r4 (raw file):
Previously, reyna-abhyankar (Reyna Abhyankar) wrote…
Done. But I'm also getting a "member is inaccessible" error because they're private. Should we make them public?
Sure
lib/runtime/src/task_spec/op_task_signature.h line 44 at r6 (raw file):
explicit OpTaskSignature(OpTaskType); OpTaskType get_task_type() const {
If the fields become public you can/should get rid of this
lib/runtime/src/task_spec/op_task_signature.h line 70 at r6 (raw file):
} // TODO: should this be a single type index?
Yes, tasks are only allowed to return one value/type
lib/runtime/src/task_spec/op_task_signature.cc line 5 at r6 (raw file):
namespace FlexFlow { OpTaskSignature::OpTaskSignature(OpTaskType t) {
Suggestion:
OpTaskSignature::OpTaskSignature(OpTaskType t): type(t) {
lib/runtime/src/task_spec/task_argument_accessor.h line 89 at r4 (raw file):
Previously, reyna-abhyankar (Reyna Abhyankar) wrote…
AFAIK C++ does not allow virtual template functions. I was getting errors
Makes sense, then just make sure TaskArgumentAccessor has the template arguments that convert from the variant
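One way this could look, as a sketch: the interface keeps a non-template virtual method that returns a variant, and the wrapper exposes a template method that picks the statically known alternative. The tensor types, the `privilege_to_tensor` trait, and the class names are hypothetical simplifications of the real accessor types.

```cpp
#include <memory>
#include <utility>
#include <variant>

using slot_id = int; // hypothetical stand-in
enum class Permissions { RO, WO };

// Hypothetical accessor results, one per permission.
struct ReadOnlyTensor { int id; };
struct WriteOnlyTensor { int id; };
using PrivilegeTypeSketch = std::variant<ReadOnlyTensor, WriteOnlyTensor>;

// Compile-time map from a permission to the matching variant alternative.
template <Permissions P> struct privilege_to_tensor;
template <> struct privilege_to_tensor<Permissions::RO> {
  using type = ReadOnlyTensor;
};
template <> struct privilege_to_tensor<Permissions::WO> {
  using type = WriteOnlyTensor;
};

// Virtual methods cannot be templates, so the backend returns the variant.
struct IBackendSketch {
  virtual PrivilegeTypeSketch get_tensor(slot_id slot,
                                         Permissions priv) const = 0;
  virtual ~IBackendSketch() = default;
};

struct LocalBackendSketch : public IBackendSketch {
  PrivilegeTypeSketch get_tensor(slot_id slot,
                                 Permissions priv) const override {
    if (priv == Permissions::RO) {
      return ReadOnlyTensor{slot};
    }
    return WriteOnlyTensor{slot};
  }
};

// The non-virtual wrapper restores the static return type for callers.
struct AccessorSketch {
  explicit AccessorSketch(std::shared_ptr<IBackendSketch const> ptr)
      : ptr(std::move(ptr)) {}

  template <Permissions P>
  typename privilege_to_tensor<P>::type get_tensor(slot_id slot) const {
    using T = typename privilege_to_tensor<P>::type;
    // std::get throws std::bad_variant_access if the backend returned a
    // different alternative than the permission implies.
    return std::get<T>(ptr->get_tensor(slot, P));
  }

private:
  std::shared_ptr<IBackendSketch const> ptr;
};
```

Operator code would then call something like `acc.get_tensor<Permissions::RO>(slot)` and never see the variant.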
lib/runtime/src/task_spec/task_argument_accessor.h line 22 at r6 (raw file):
namespace FlexFlow { using ITaskArgumentAccessorBackend = variant<LegionTaskArgumentAccessor,
Delete
lib/runtime/src/task_spec/task_argument_accessor.h line 64 at r6 (raw file):
TaskArgumentAccessor(std::shared_ptr<ITaskArgumentAccessorBackend const> ptr) : ptr(ptr) {} std::shared_ptr<ITaskArgumentAccessorBackend const> ptr;
Suggestion:
std::shared_ptr<ITaskArgumentAccessor const> ptr;
lib/runtime/test/src/test_op_task_invocation.cc line 13 at r2 (raw file):
Previously, reyna-abhyankar (Reyna Abhyankar) wrote…
This is just an initializer list, right? My compiler is not complaining
It's valid only in C++20 I think
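A small sketch of the portable alternatives. As far as I know, the `n: 12` spelling is an older GNU-style designator that GCC and Clang accept as an extension, which would explain why the compiler does not complain, while the form standardized in C++20 is `.n = 12`. `AggregateAttrsSketch` is a hypothetical stand-in with only the two fields named in the test, and their types are assumed.

```cpp
// Hypothetical stand-in; the real AggregateAttrs may have other fields.
struct AggregateAttrsSketch {
  int n;
  float lambda_bal;
};

inline AggregateAttrsSketch make_attrs() {
  // Positional aggregate initialization is valid in every C++ standard
  // since C++11.
  AggregateAttrsSketch attrs{12, 1.0f};
  // C++20 designated initializers (fields must follow declaration order):
  //   AggregateAttrsSketch attrs{.n = 12, .lambda_bal = 1.0f};
  return attrs;
}
```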
Description of changes:
Implement infer_bwd_binding, task registration
Related Issues:
#924
Issues closed by this PR:
infer_bwd_binding #1043, Implement task registration #1044