Implement infer_bwd_binding, task reg #1097

Closed
reyna-abhyankar wants to merge 26 commits into flexflow:repo-refactor from reyna-abhyankar:op-task

Conversation

@reyna-abhyankar
Collaborator

@reyna-abhyankar reyna-abhyankar commented Sep 5, 2023

Description of changes:

Implement infer_bwd_binding, task registration

Related Issues:
#924

Issues closed by this PR:



@reyna-abhyankar reyna-abhyankar changed the title Implement infer_bwd_binding Implement infer_bwd_binding, task reg Sep 6, 2023
@lockshaw lockshaw linked an issue Sep 7, 2023 that may be closed by this pull request
@reyna-abhyankar reyna-abhyankar linked an issue Sep 7, 2023 that may be closed by this pull request
Collaborator

@lockshaw lockshaw left a comment

Test cases for infer_bwd_binding should be added

Reviewed 6 of 6 files at r1, all commit messages.
Reviewable status: all files reviewed, 3 unresolved discussions (waiting on @reyna-abhyankar)


lib/runtime/src/task_spec/op_task_invocation.h line 87 at r1 (raw file):

  void bind_args_from_fwd(OpTaskBinding const &fwd) {
    this->arg_bindings = fwd.get_arg_bindings();

Can you make sure these are either implemented or marked as NOT_IMPLEMENTED?


lib/runtime/src/task_spec/op_task_invocation.cc line 51 at r1 (raw file):

  bwd.bind_args_from_fwd(fwd);
  bwd.bind_tensors_from_fwd(fwd);
  for (auto const &[key, spec] : fwd.get_tensor_bindings()) {

We probably need to handle untrainable parameters here.


lib/runtime/src/task_spec/task_signature.h line 45 at r1 (raw file):

  void add_variadic_arg_slot(slot_id name);

  static std::unordered_map<task_id_t, TaskSignature> task_sig_map;

Do you think it's better to have these maps per type, or would it be better to have a SignatureStore class (or something like it) that stores them? I'd lean toward the second, just so that all of the signature storage and retrieval behavior is in one place.

Collaborator Author

@reyna-abhyankar reyna-abhyankar left a comment

Reviewable status: all files reviewed, 3 unresolved discussions (waiting on @lockshaw)


lib/runtime/src/task_spec/op_task_invocation.h line 87 at r1 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Can you make sure these are either implemented or marked as NOT_IMPLEMENTED?

Yes, get_arg_bindings() and get_tensor_bindings() are already implemented. See op_task_invocation.cc


lib/runtime/src/task_spec/op_task_invocation.cc line 51 at r1 (raw file):

Previously, lockshaw (Colin Unger) wrote…

We probably need to handle untrainable parameters here.

The signature has OpTensorSlotSpec and the binding uses OpTensorSpec. The slot spec is more descriptive and has OpSlotOptions, which is where untrainable is handled. The binding is only aware of the TensorRole (input, weight, or output). Is there a reason that untrainable also needs to be specified in the binding? I think the signature should suffice.


lib/runtime/src/task_spec/task_signature.h line 45 at r1 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Do you think it's better to have these maps per type, or would it be better to have a SignatureStore class (or something like it) that stores them? I'd lean toward the second, just so that all of the signature storage and retrieval behavior is in one place.

At the moment, we only have two: one for task and one for operator task. If SignatureStore is just going to be a wrapper then I don't think it would be more useful.

Collaborator

@lockshaw lockshaw left a comment

Reviewed 8 of 8 files at r2, all commit messages.
Reviewable status: all files reviewed, 10 unresolved discussions (waiting on @reyna-abhyankar)


lib/runtime/src/task_spec/op_task_invocation.h line 136 at r1 (raw file):

OpTaskSignature infer_bwd_signature(OpTaskSignature const &fwd);
OpTaskBinding infer_bwd_binding(OpTaskBinding const &fwd);
OpTaskSignature get_op_signature(task_id_t const &);

Why remove this?


lib/runtime/src/task_spec/op_task_invocation.h line 43 at r2 (raw file):

                          TaskInvocationSpec>;

struct OpArgSpecTypeAccessor {

Can we just move all of these over to having a type tag?


lib/runtime/src/task_spec/op_task_invocation.h line 118 at r2 (raw file):

  }

  bool operator==(const OpTaskBinding& rhs) {

Probably should use visitable instead


lib/runtime/src/task_spec/op_task_invocation.cc line 51 at r1 (raw file):

Previously, reyna-abhyankar (Reyna Abhyankar) wrote…

The signature has OpTensorSlotSpec and the binding uses OpTensorSpec. The slot spec is more descriptive and has OpSlotOptions, which is where untrainable is handled. The binding is only aware of the TensorRole (input, weight, or output). Is there a reason that untrainable also needs to be specified in the binding? I think the signature should suffice.

My point is that if a tensor is marked untrainable in the signature, it probably shouldn't have its gradient bound (as it has no gradient).


lib/runtime/src/task_spec/op_task_signature.h line 74 at r2 (raw file):

  }

  // TODO: what is this function's purpose? OpTaskSignature doesn't need a return value

Why does it not need a return value? For example, init_task for some operators returns a DeviceSpecific<xxxxxPerDeviceState>


lib/runtime/src/task_spec/op_task_signature.h line 92 at r2 (raw file):

private:
  OpTaskType type;

What is the purpose of this field?


lib/runtime/src/task_spec/op_task_signature.cc line 77 at r2 (raw file):

}

bool OpTaskSignature::operator==(OpTaskSignature const & rhs) const {

Make OpTaskSignature a FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION so you get these automatically


lib/runtime/src/task_spec/task_signature.h line 45 at r1 (raw file):

Previously, reyna-abhyankar (Reyna Abhyankar) wrote…

At the moment, we only have two: one for task and one for operator task. If SignatureStore is just going to be a wrapper then I don't think it would be more useful.

It would just be nice to have all of the signature storage in one place in the code rather than in two, but I'll go with your preference
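
For reference, the SignatureStore idea discussed in this thread could look roughly like the sketch below. It is only an illustration under stated assumptions: `task_id_t`, `TaskSignature`, and `OpTaskSignature` are simplified stand-ins for the real FlexFlow types, and the method names are hypothetical.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>

// Simplified stand-ins for the real FlexFlow types (assumptions).
using task_id_t = int;
struct TaskSignature { std::string name; };
struct OpTaskSignature { std::string name; };

// Hypothetical store that keeps both registries next to each other, so
// registration and lookup logic live in one place.
class SignatureStore {
public:
  void register_task(task_id_t id, TaskSignature sig) {
    if (!task_sigs.emplace(id, std::move(sig)).second) {
      throw std::runtime_error("task signature already registered");
    }
  }

  void register_op_task(task_id_t id, OpTaskSignature sig) {
    if (!op_task_sigs.emplace(id, std::move(sig)).second) {
      throw std::runtime_error("op task signature already registered");
    }
  }

  // at() throws std::out_of_range for ids that were never registered.
  TaskSignature const &get_task(task_id_t id) const {
    return task_sigs.at(id);
  }

  OpTaskSignature const &get_op_task(task_id_t id) const {
    return op_task_sigs.at(id);
  }

private:
  std::unordered_map<task_id_t, TaskSignature> task_sigs;
  std::unordered_map<task_id_t, OpTaskSignature> op_task_sigs;
};
```

Compared with one static map per signature type, the trade-off is an extra class in exchange for a single owner of the registries.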


lib/runtime/test/src/test_op_task_invocation.cc line 13 at r2 (raw file):

  // get binding from operator forward
  OpTaskSignature signature = fwd_signature<AGGREGATE_FWD_TASK_ID>();
  AggregateAttrs attrs {n: 12, lambda_bal: 1.0};

I don't think this is valid C++ syntax?


lib/runtime/test/src/test_op_task_invocation.cc line 37 at r2 (raw file):

    assert (arg_types.count(slot_id));
    assert (arg_types.at(slot_id) == op_arg_spec_type);
  }

This should probably be pulled out into a separate function, as every operator's tests will be using it?

Also, tests should use CHECK, not assert

Code quote:

  // check tensors
  auto const& tensor_slots = signature.get_tensor_slots();
  auto const& tensor_bindings = binding.get_tensor_bindings();
  assert (tensor_slots.size() == tensor_bindings.size());
  for (OpTensorSlotSpec const& tensor_slot_spec: tensor_slots) {
    slot_id name = tensor_slot_spec.name;
    assert (tensor_bindings.count({name, IsGrad::NO}));
    assert (!tensor_bindings.count({name, IsGrad::YES}));
    OpTensorSpec const tensor_spec = tensor_bindings.at({name, IsGrad::NO});
    assert (tensor_spec.role == tensor_slot_spec.tensor_role);
  }

  // check arg types
  auto const& arg_types = signature.get_arg_types();
  auto const& arg_bindings = binding.get_arg_bindings();
  assert (arg_types.size() == arg_bindings.size());
  for (auto const &[slot_id, op_arg_spec] : arg_bindings) {
    std::type_index op_arg_spec_type = std::visit(OpArgSpecTypeAccessor(), op_arg_spec);
    assert (arg_types.count(slot_id));
    assert (arg_types.at(slot_id) == op_arg_spec_type);
  }
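
A shared helper along the lines requested here might look like the following sketch. It is written against simplified, hypothetical versions of the types in the quote (a `std::vector` of slot specs and a `std::map` keyed on `(slot_id, IsGrad)`), and it returns a `bool` so that a test could wrap the call in whatever CHECK macro the test framework provides.

```cpp
#include <cassert>
#include <map>
#include <utility>
#include <vector>

// Simplified stand-ins for the real FlexFlow types (assumptions).
using slot_id = int;
enum class IsGrad { YES, NO };
enum class TensorRole { INPUT, WEIGHT, OUTPUT };
struct OpTensorSlotSpec { slot_id name; TensorRole tensor_role; };
struct OpTensorSpec { TensorRole role; };
using TensorBindings = std::map<std::pair<slot_id, IsGrad>, OpTensorSpec>;

// Returns true when a forward binding binds exactly the non-gradient tensor
// for every slot in the signature, with a matching role.
bool fwd_tensor_bindings_match(std::vector<OpTensorSlotSpec> const &slots,
                               TensorBindings const &bindings) {
  if (slots.size() != bindings.size()) {
    return false;
  }
  for (OpTensorSlotSpec const &slot : slots) {
    auto it = bindings.find({slot.name, IsGrad::NO});
    if (it == bindings.end()) {
      return false; // slot is not bound at all
    }
    if (bindings.count({slot.name, IsGrad::YES}) != 0) {
      return false; // a forward binding should not bind gradients
    }
    if (it->second.role != slot.tensor_role) {
      return false; // role disagrees with the signature
    }
  }
  return true;
}
```

Each operator test would then reduce to building the signature and binding and checking the helper's result.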

Collaborator

@wmdi wmdi left a comment

Reviewed 16 of 24 files at r3, all commit messages.
Reviewable status: 19 of 27 files reviewed, 11 unresolved discussions (waiting on @lambda7xx, @lockshaw, and @reyna-abhyankar)


lib/op-attrs/include/op-attrs/tensor_shape.h line 19 at r4 (raw file):

  template <typename Dims>
  TensorShape(Dims const &dims, DataType data_type)
      : dims(this->dims), data_type(this->data_type) {}

Why?

Collaborator

@lockshaw lockshaw left a comment

Reviewed 8 of 24 files at r3, 1 of 1 files at r4, all commit messages.
Reviewable status: all files reviewed, 28 unresolved discussions (waiting on @lambda7xx, @reyna-abhyankar, and @wmdi)


lib/op-attrs/include/op-attrs/tensor_shape.h line 19 at r4 (raw file):

Previously, wmdi (Mengdi Wu) wrote…

Why?

I think this is a bug and should be reverted: this->dims and this->data_type aren't initialized yet
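
To illustrate the bug: in a member initializer list, `dims(this->dims)` initializes the member from itself before it has a value, whereas `dims(dims)` reads the constructor parameter. A minimal sketch of the intended form follows; it uses a non-template constructor and a hypothetical two-value `DataType` enum for brevity, so it is not the actual `TensorShape` definition.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

enum class DataType { FLOAT, INT32 }; // hypothetical subset, for illustration

struct TensorShape {
  // Inside the parentheses of a member initializer, an unqualified name
  // refers to the constructor parameter, so this copies the arguments into
  // the members rather than reading the uninitialized members.
  TensorShape(std::vector<std::size_t> const &dims, DataType data_type)
      : dims(dims), data_type(data_type) {}

  std::vector<std::size_t> dims;
  DataType data_type;
};
```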


lib/runtime/src/task_spec/op_task_invocation.cc line 6 at r4 (raw file):

namespace FlexFlow {

OpTensorSpec input_tensor(int idx, OpSlotOptions option = OpSlotOptions::NECESSARY) {

Shouldn't the slot options be on the slot (i.e., as part of the signature) and not part of the binding?


lib/runtime/src/task_spec/op_task_signature.h line 66 at r4 (raw file):

  bool operator==(OpTaskSignature const &) const;
  bool operator!=(OpTaskSignature const &) const;

Provided by visitable


lib/runtime/src/task_spec/op_task_signature.h line 98 at r4 (raw file):

OpTaskSignature get_op_signature(task_id_t const &);
FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(OpTaskSignature, get_tensor_slots, set_arg_types, get_arg_types);

Suggestion:

FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(OpTaskSignature, op_tensor_slots, set_arg_types, task_arg_types);

lib/runtime/src/task_spec/op_task_signature.cc line 5 at r2 (raw file):

namespace FlexFlow {

OpTaskSignature::OpTaskSignature(OpTaskType t) {

Why was this deleted? I think it's necessary?


lib/runtime/src/task_spec/task_argument_accessor.h line 86 at r4 (raw file):

                                    std::vector<privilege_mode_to_accessor<Permissions::WO>>>;

struct ITaskArgumentAccessor {

Move into its own header/source files


lib/runtime/src/task_spec/task_argument_accessor.h line 89 at r4 (raw file):

  virtual PrivilegeType get_tensor(slot_id slot, Permissions priv) const = 0;

  virtual PrivilegeVariadicType get_variadic_tensor(slot_id slot, Permissions priv) const = 0;

Why was the template parameter removed from these? By requiring it to be static we were able to infer the return type and didn't have to return a variant


lib/runtime/src/task_spec/task_argument_accessor.h line 98 at r4 (raw file):

  virtual size_t get_device_idx() const = 0;
};

Should be used on anything that declares a virtual method (checks https://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines#c67-a-polymorphic-class-should-suppress-public-copymove)

Suggestion:

  virtual size_t get_device_idx() const = 0;
};
CHECK_RC_COPY_VIRTUAL_COMPLIANT(ITaskArgumentAccessor);

lib/runtime/src/task_spec/task_argument_accessor.h line 132 at r4 (raw file):

    return acc.get_variadic_argument<T>(slot);
  }
};

lib/runtime/src/task_spec/task_argument_accessor.h line 134 at r4 (raw file):

};

struct LegionTaskArgumentAccessor : public ITaskArgumentAccessor {

Move into its own header/source files


lib/runtime/src/task_spec/task_argument_accessor.h line 169 at r4 (raw file):

};

struct LocalTaskArgumentAccessor : public ITaskArgumentAccessor {

Move into its own header/source files


lib/runtime/src/task_spec/task_argument_accessor.h line 201 at r4 (raw file):

  void *allocate(size_t size);
  void deallocate(void *ptr);

The operators expect to be able to get an Allocator. Memory usage tracking should instead be moved to a wrapper around Allocator (e.g., TrackedAllocator(Allocator const &)) (which is also an Allocator of course)

Suggestion:

  Allocator get_allocator() const;

lib/runtime/src/task_spec/task_argument_accessor.h line 204 at r4 (raw file):

private:
  std::shared_ptr<SimTaskBinding const> sim_task_binding;

Is there a reason we have this as a shared ptr rather than a value?


lib/runtime/src/task_spec/task_argument_accessor.h line 210 at r4 (raw file):

using TaskArgumentAccessorBackend = variant<LegionTaskArgumentAccessor,
                                        LocalTaskArgumentAccessor>;

lib/runtime/src/task_spec/task_argument_accessor.h line 221 at r4 (raw file):

  optional<T> const &get_optional_argument(slot_id slot) const {
    return std::visit(OptionalArgumentAccessor(), this->ptr, slot);
  }

Also for get_argument and get_variadic_argument

Suggestion:

  template <typename T>
  optional<T> const &get_optional_argument(slot_id slot) const {
    return this->ptr->get_optional_argument(slot);
  }

lib/runtime/src/task_spec/task_argument_accessor.h line 251 at r4 (raw file):

  TaskArgumentAccessor(std::shared_ptr<TaskArgumentAccessorBackend const> &ptr)
      : ptr(ptr) {}
  std::shared_ptr<TaskArgumentAccessorBackend const> ptr;

Suggestion:

  TaskArgumentAccessor(std::shared_ptr<ITaskArgumentAccessorBackend const> ptr)
      : ptr(ptr) {}
  std::shared_ptr<ITaskArgumentAccessorBackend const> ptr;

lib/runtime/src/task_spec/task_argument_accessor.cc line 6 at r4 (raw file):

template <typename T>
T const &LocalTaskArgumentAccessor::get_argument(slot_id slot) const {

This needs to be rewritten entirely to use the signature/binding


lib/runtime/src/task_spec/task_argument_accessor.cc line 8 at r4 (raw file):

T const &LocalTaskArgumentAccessor::get_argument(slot_id slot) const {
  if (slot == PROFILING) {
    return get<ProfilingSettings>(this->arg_bindings.at(slot));

This should be retrieved from the signature/binding--the value of PROFILING can be different in every operator, and it can be assigned an arbitrary type (I could in theory have a slot PROFILING that takes in a Conv2DAttrs)
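
One way to read this comment: the accessor should look up both the expected type and the value for a slot from data supplied by the signature and binding, instead of comparing against a fixed enum value. The sketch below shows that idea with `std::type_index` and `std::any`; the `ArgBackend` struct, its `bind` method, and the `ProfilingSettings` field are hypothetical and much simpler than the real signature/binding classes.

```cpp
#include <any>
#include <cassert>
#include <stdexcept>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include <utility>

using slot_id = int;

struct ProfilingSettings { int warmup_iters; }; // hypothetical field

// Hypothetical backend: types as a signature would declare them, values as a
// binding would supply them. No slot id is special-cased.
struct ArgBackend {
  std::unordered_map<slot_id, std::type_index> arg_types;
  std::unordered_map<slot_id, std::any> arg_values;

  template <typename T>
  void bind(slot_id slot, T value) {
    arg_types.insert_or_assign(slot, std::type_index(typeid(T)));
    arg_values.insert_or_assign(slot, std::any(std::move(value)));
  }

  template <typename T>
  T const &get_argument(slot_id slot) const {
    auto it = arg_types.find(slot);
    if (it == arg_types.end()) {
      throw std::out_of_range("slot is not declared");
    }
    if (it->second != std::type_index(typeid(T))) {
      throw std::runtime_error("requested type does not match declared type");
    }
    return *std::any_cast<T>(&arg_values.at(slot));
  }
};
```

With this shape, a slot that one operator calls PROFILING could hold a different type in another operator without any change to the accessor.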


lib/runtime/src/task_spec/task_argument_accessor.cc line 29 at r4 (raw file):

  free(cpu_ptr);
  return ptr;
}

This whole file needs to be rewritten to use the signature and binding to generate values, not hardcoding in private enum values

Collaborator Author

@reyna-abhyankar reyna-abhyankar left a comment

Reviewable status: all files reviewed, 27 unresolved discussions (waiting on @lambda7xx, @lockshaw, and @wmdi)


lib/op-attrs/include/op-attrs/tensor_shape.h line 19 at r4 (raw file):

Previously, lockshaw (Colin Unger) wrote…

I think this is a bug and should be reverted: this->dims and this->data_type aren't initialized yet

Done.


lib/runtime/src/task_spec/op_task_invocation.h line 136 at r1 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Why remove this?

This is moved to op_task_signature.h


lib/runtime/src/task_spec/op_task_invocation.h line 118 at r2 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Probably should use visitable instead

Done.


lib/runtime/src/task_spec/op_task_invocation.cc line 51 at r1 (raw file):

Previously, lockshaw (Colin Unger) wrote…

My point is that if a tensor is marked untrainable in the signature, it probably shouldn't have its gradient bound (as it has no gradient).

See response to your other comment on OpSlotOptions for binding.


lib/runtime/src/task_spec/op_task_invocation.cc line 6 at r4 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Shouldn't the slot options be on the slot (i.e., as part of the signature) and not part of the binding?

I added it to the binding so we can check untrainable in infer_bwd_binding(), but the other option is to have infer_bwd_binding() also take the signature as an argument.


lib/runtime/src/task_spec/op_task_signature.h line 74 at r2 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Why does it not need a return value? For example, init_task for some operators returns a DeviceSpecific<xxxxxPerDeviceState>

Done. Should it be a vector of return values or is there only one return value?


lib/runtime/src/task_spec/op_task_signature.h line 92 at r2 (raw file):

Previously, lockshaw (Colin Unger) wrote…

What is the purpose of this field?

There is a get_task_type() function that is called in task_spec.cc. I've implemented it as follows:

Code snippet:

OpTaskType get_task_type() const {
  return this->type;
}

lib/runtime/src/task_spec/op_task_signature.h line 66 at r4 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Provided by visitable

Done.


lib/runtime/src/task_spec/op_task_signature.h line 98 at r4 (raw file):

OpTaskSignature get_op_signature(task_id_t const &);
FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(OpTaskSignature, get_tensor_slots, set_arg_types, get_arg_types);

Done. However, I'm now getting "member is inaccessible" errors because the fields are private. Should we make them public?


lib/runtime/src/task_spec/op_task_signature.cc line 5 at r2 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Why was this deleted? I think it's necessary?

Added this as

Code snippet:

OpTaskSignature::OpTaskSignature(OpTaskType t) {
  this->type = t;
}

lib/runtime/src/task_spec/op_task_signature.cc line 77 at r2 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Make OpTaskSignature a FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION so you get these automatically

Done. See the comment above about member accessibility.


lib/runtime/src/task_spec/task_argument_accessor.h line 86 at r4 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Move into its own header/source files

Done.


lib/runtime/src/task_spec/task_argument_accessor.h line 89 at r4 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Why was the template parameter removed from these? By requiring it to be static we were able to infer the return type and didn't have to return a variant

AFAIK C++ does not allow virtual template functions. I was getting errors


lib/runtime/src/task_spec/task_argument_accessor.h line 98 at r4 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Should be used on anything that declares a virtual method (checks https://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines#c67-a-polymorphic-class-should-suppress-public-copymove)

Done. Added delete for operator= and virtual destructor.
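
As a rough illustration of the guideline linked above (C.67), a polymorphic interface can delete its copy operations and declare a virtual destructor, and the result can be checked at compile time with standard type traits. The sketch below does not reproduce what `CHECK_RC_COPY_VIRTUAL_COMPLIANT` actually expands to; the `static_assert`s are a hand-written approximation, and `FixedDeviceAccessor` is a hypothetical implementation added only so the interface can be exercised.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>

struct ITaskArgumentAccessor {
  ITaskArgumentAccessor() = default;
  // Suppress public copy for the polymorphic hierarchy. Declaring these also
  // prevents implicit move operations from being generated.
  ITaskArgumentAccessor(ITaskArgumentAccessor const &) = delete;
  ITaskArgumentAccessor &operator=(ITaskArgumentAccessor const &) = delete;
  virtual ~ITaskArgumentAccessor() = default;

  virtual std::size_t get_device_idx() const = 0;
};

// Hypothetical implementation used only for this example.
struct FixedDeviceAccessor final : ITaskArgumentAccessor {
  explicit FixedDeviceAccessor(std::size_t idx) : idx(idx) {}
  std::size_t get_device_idx() const override { return idx; }
  std::size_t idx;
};

// Hand-written checks in the spirit of the guideline.
static_assert(std::has_virtual_destructor<ITaskArgumentAccessor>::value,
              "interface needs a virtual destructor");
static_assert(!std::is_copy_constructible<FixedDeviceAccessor>::value,
              "copy construction should be suppressed");
static_assert(!std::is_copy_assignable<FixedDeviceAccessor>::value,
              "copy assignment should be suppressed");
```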


lib/runtime/src/task_spec/task_argument_accessor.h line 132 at r4 (raw file):

    return acc.get_variadic_argument<T>(slot);
  }
};

Done.


lib/runtime/src/task_spec/task_argument_accessor.h line 134 at r4 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Move into its own header/source files

Done.


lib/runtime/src/task_spec/task_argument_accessor.h line 169 at r4 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Move into its own header/source files

Done.


lib/runtime/src/task_spec/task_argument_accessor.h line 210 at r4 (raw file):

using TaskArgumentAccessorBackend = variant<LegionTaskArgumentAccessor,
                                        LocalTaskArgumentAccessor>;

Done.


lib/runtime/src/task_spec/task_argument_accessor.h line 221 at r4 (raw file):

Previously, lockshaw (Colin Unger) wrote…

Also for get_argument and get_variadic_argument

Nice. I assumed that because it's a variant, we would have to visit it to access the function. I've made the change.


lib/runtime/src/task_spec/task_argument_accessor.h line 251 at r4 (raw file):

  TaskArgumentAccessor(std::shared_ptr<TaskArgumentAccessorBackend const> &ptr)
      : ptr(ptr) {}
  std::shared_ptr<TaskArgumentAccessorBackend const> ptr;

Done.


lib/runtime/test/src/test_op_task_invocation.cc line 13 at r2 (raw file):

Previously, lockshaw (Colin Unger) wrote…

I don't think this is valid C++ syntax?

This is just an initializer list, right? My compiler is not complaining


lib/runtime/test/src/test_op_task_invocation.cc line 37 at r2 (raw file):

Previously, lockshaw (Colin Unger) wrote…

This should probably be pulled out into a separate function, as every operator's tests will be using it?

Also, tests should use CHECK, not assert

Done.

Collaborator

@lockshaw lockshaw left a comment

Reviewed 7 of 9 files at r5, 7 of 7 files at r6, all commit messages.
Reviewable status: all files reviewed, 33 unresolved discussions (waiting on @lambda7xx, @reyna-abhyankar, and @wmdi)


lib/kernels/include/kernels/allocation.h line 34 at r6 (raw file):

};

struct TrackedAllocator: public Allocator {

Is this better as a subclass of Allocator or of IAllocator?
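
To make the question concrete, here is a sketch of the second option: a tracking wrapper that implements the interface and forwards to another allocator. The `IAllocator` interface shown is an assumption (two virtual methods taken from the `allocate`/`deallocate` declarations quoted earlier in this review), and `MallocAllocator` is a hypothetical backend used only to exercise the wrapper.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <memory>
#include <unordered_map>
#include <utility>

// Assumed shape of the allocator interface.
struct IAllocator {
  virtual void *allocate(std::size_t size) = 0;
  virtual void deallocate(void *ptr) = 0;
  virtual ~IAllocator() = default;
};

// Hypothetical backend for the example.
struct MallocAllocator : IAllocator {
  void *allocate(std::size_t size) override { return std::malloc(size); }
  void deallocate(void *ptr) override { std::free(ptr); }
};

// Wraps any IAllocator and records how many bytes are currently live.
struct TrackedAllocator : IAllocator {
  explicit TrackedAllocator(std::shared_ptr<IAllocator> wrapped)
      : inner(std::move(wrapped)) {}

  void *allocate(std::size_t size) override {
    void *ptr = inner->allocate(size);
    if (ptr != nullptr) {
      sizes[ptr] = size;
      current_usage += size;
    }
    return ptr;
  }

  void deallocate(void *ptr) override {
    auto it = sizes.find(ptr);
    if (it != sizes.end()) {
      current_usage -= it->second;
      sizes.erase(it);
    }
    inner->deallocate(ptr);
  }

  std::size_t get_memory_usage() const { return current_usage; }

private:
  std::shared_ptr<IAllocator> inner;
  std::unordered_map<void *, std::size_t> sizes;
  std::size_t current_usage = 0;
};
```

This keeps memory accounting out of the task argument accessor, which matches the earlier suggestion to expose only `get_allocator()` there.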


lib/runtime/src/task_spec/itask_argument_accessor.h line 52 at r6 (raw file):

FF_VISITABLE_STRUCT(FutureArgumentFormat, type, future_idx);

struct TaskArgumentsFormat {

TaskArgumentsFormat should be in a different file than ITaskArgumentAccessor, as TaskArgumentsFormat is part of compilation to legion tasks and ITaskArgumentAccessor is shared across all backends


lib/runtime/src/task_spec/itask_argument_accessor.h line 92 at r6 (raw file):

struct ITaskArgumentAccessor {
  ITaskArgumentAccessor& operator=(const ITaskArgumentAccessor&) = delete;
  virtual ~ITaskArgumentAccessor() {};

Suggestion:

  virtual ~ITaskArgumentAccessor() = default;

lib/runtime/src/task_spec/legion_task_argument_accessor.h line 20 at r6 (raw file):

namespace FlexFlow {

struct LegionTaskArgumentAccessor : public ITaskArgumentAccessor {

You're missing some overrides


lib/runtime/src/task_spec/legion_task_argument_accessor.cc line 0 at r6 (raw file):
#include "runtime/task_spec/legion_task_argument_accessor.h"


lib/runtime/src/task_spec/local_task_argument_accessor.h line 29 at r6 (raw file):

  PrivilegeType get_tensor(slot_id slot, Permissions priv) const override;

  PrivilegeVariadicType get_variadic_tensor(slot_id slot, Permissions priv) const override;

You're missing some overrides


lib/runtime/src/task_spec/local_task_argument_accessor.h line 53 at r6 (raw file):

  void *allocate(size_t size);
  void deallocate(void *ptr);

Delete

Code quote:

  size_t get_memory_usage() const {
    return memory_usage;
  }

  void *allocate(size_t size);
  void deallocate(void *ptr);

lib/runtime/src/task_spec/local_task_argument_accessor.h line 59 at r6 (raw file):

  SimTaskBinding sim_task_binding;
  Allocator local_allocator;
  size_t memory_usage;

Delete


lib/runtime/src/task_spec/local_task_argument_accessor.cc line 7 at r6 (raw file):

template <typename T>
T const &LocalTaskArgumentAccessor::get_argument(slot_id slot) const {
  assert(contains_key(this->sim_task_binding.arg_bindings.at(slot)));

contains_key takes multiple arguments. You should be able to compile these files to catch errors like this, even if you're not able to build+link the full runtime
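
For illustration, a two-argument `contains_key` and the intended call pattern could look like the sketch below. The exact signature of FlexFlow's own utility is not shown in this thread, so the template here is an assumption, and `get_bound_arg` is a hypothetical helper with `std::string` standing in for the bound argument type.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

// Assumed shape of the utility: the container first, then the key.
template <typename Container>
bool contains_key(Container const &c,
                  typename Container::key_type const &key) {
  return c.find(key) != c.end();
}

using slot_id = int;

// Check that the slot exists before calling at(), instead of passing the
// result of at() to contains_key.
std::string const &
    get_bound_arg(std::unordered_map<slot_id, std::string> const &arg_bindings,
                  slot_id slot) {
  assert(contains_key(arg_bindings, slot));
  return arg_bindings.at(slot);
}
```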


lib/runtime/src/task_spec/local_task_argument_accessor.cc line 12 at r6 (raw file):

template <typename T>
optional<T> LocalTaskArgumentAccessor::get_optional_argument(slot_id slot) const {

Suggestion:

std::optional<T>

lib/runtime/src/task_spec/local_task_argument_accessor.cc line 24 at r6 (raw file):

}

void *LocalTaskArgumentAccessor::allocate(size_t size) {

Delete


lib/runtime/src/task_spec/local_task_argument_accessor.cc line 42 at r6 (raw file):

  if (slot == GATE_PREDS || slot == GATE_ASSIGN) {

Why is there special handling needed for some slots?


lib/runtime/src/task_spec/local_task_argument_accessor.cc line 47 at r6 (raw file):

    DataType data_type = gate_preds.shape.data_type;
    ArrayShape array_shape = {
        gate_preds.shape.dims.get_dims()}; // gate_preds.shape.dims.get_dims()

remove comments


lib/runtime/src/task_spec/local_task_argument_accessor.cc line 57 at r6 (raw file):

    Datatype data_type = output_shape.data_type;
    ArrayShape array_shape = {
        output_shape.dims.get_dims()}; // output_shape.dims.get_dims() return

remove comments


lib/runtime/src/task_spec/local_task_argument_accessor.cc line 64 at r6 (raw file):

  } else {
    throw mk_runtime_error(
        "Unknown Slot ID in LocalTaskArgumentAccessor::get_tensor");

Suggestion:

       "Unknown Slot ID {} in LocalTaskArgumentAccessor::get_tensor", slot);

lib/runtime/src/task_spec/local_task_argument_accessor.cc line 78 at r6 (raw file):

    ArrayShape array_shape = {
        shape.dims
            .get_dims()}; // shape.dims.get_dims() return std::vector<size_t>

remove comment


lib/runtime/src/task_spec/op_task_invocation.h line 118 at r2 (raw file):

Previously, reyna-abhyankar (Reyna Abhyankar) wrote…

Done.

You should actually be using FF_VISITABLE_STRUCT (or in this case, FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION) instead of use_visitable_cmp (which is deprecated except in certain cases when you're working with template classes)


lib/runtime/src/task_spec/op_task_invocation.cc line 6 at r4 (raw file):

Previously, reyna-abhyankar (Reyna Abhyankar) wrote…

I added it to the binding so we can check untrainable in infer_bwd_binding(), but the other option is to have infer_bwd_binding() also take the signature as an argument.

It should have access to the signature
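
A minimal sketch of that option follows: the backward tensor bindings are derived from the forward signature's slots together with the forward binding, and gradients are bound only for slots not marked untrainable. All types are simplified stand-ins; in particular the `UNTRAINABLE` enumerator name, the `slot_option` field, and the function name `infer_bwd_tensor_bindings` are assumptions made for this example.

```cpp
#include <cassert>
#include <map>
#include <utility>
#include <vector>

// Simplified stand-ins for the real FlexFlow types (assumptions).
using slot_id = int;
enum class IsGrad { YES, NO };
enum class TensorRole { INPUT, WEIGHT, OUTPUT };
enum class OpSlotOptions { NECESSARY, UNTRAINABLE };

struct OpTensorSlotSpec {
  slot_id name;
  TensorRole tensor_role;
  OpSlotOptions slot_option;
};
struct OpTensorSpec { TensorRole role; int idx; };
using TensorBindings = std::map<std::pair<slot_id, IsGrad>, OpTensorSpec>;

TensorBindings
    infer_bwd_tensor_bindings(std::vector<OpTensorSlotSpec> const &fwd_slots,
                              TensorBindings const &fwd) {
  // The backward task still sees every forward tensor.
  TensorBindings bwd = fwd;
  for (OpTensorSlotSpec const &slot : fwd_slots) {
    if (slot.slot_option == OpSlotOptions::UNTRAINABLE) {
      continue; // no gradient exists, so none is bound
    }
    auto it = fwd.find({slot.name, IsGrad::NO});
    if (it != fwd.end()) {
      bwd.emplace(std::make_pair(slot.name, IsGrad::YES), it->second);
    }
  }
  return bwd;
}
```

Keeping the slot options on the signature means the binding itself only needs to know each tensor's role.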


lib/runtime/src/task_spec/op_task_signature.h line 74 at r2 (raw file):

Previously, reyna-abhyankar (Reyna Abhyankar) wrote…

Done. Should it be a vector of return values or is there only one return value?

Just one return value


lib/runtime/src/task_spec/op_task_signature.h line 98 at r4 (raw file):

Previously, reyna-abhyankar (Reyna Abhyankar) wrote…

Done. However, I'm now getting "member is inaccessible" errors because the fields are private. Should we make them public?

Sure


lib/runtime/src/task_spec/op_task_signature.h line 44 at r6 (raw file):

  explicit OpTaskSignature(OpTaskType);

  OpTaskType get_task_type() const {

If the fields become public you can/should get rid of this


lib/runtime/src/task_spec/op_task_signature.h line 70 at r6 (raw file):

  }

  // TODO: should this be a single type index? 

Yes, tasks are only allowed to return one value/type


lib/runtime/src/task_spec/op_task_signature.cc line 5 at r6 (raw file):

namespace FlexFlow {

OpTaskSignature::OpTaskSignature(OpTaskType t) {

Suggestion:

OpTaskSignature::OpTaskSignature(OpTaskType t): type(t) {

lib/runtime/src/task_spec/task_argument_accessor.h line 89 at r4 (raw file):

Previously, reyna-abhyankar (Reyna Abhyankar) wrote…

AFAIK C++ does not allow virtual template functions. I was getting errors

Makes sense, then just make sure TaskArgumentAccessor has the template arguments that convert from the variant
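
The arrangement described here can be sketched as follows: the virtual interface returns a variant because member function templates cannot be virtual, and the non-virtual wrapper restores the static return type through a template parameter. The accessor structs, the `privilege_mode_to_accessor_t` helper, and `DummyBackend` are hypothetical placeholders; only the overall layering follows the discussion.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <variant>

using slot_id = int;
enum class Permissions { RO, WO, RW };

// Placeholder accessor types (assumptions).
struct ReadOnlyAccessor { std::string tag; };
struct WriteOnlyAccessor { std::string tag; };
struct ReadWriteAccessor { std::string tag; };

template <Permissions> struct privilege_mode_to_accessor_t;
template <> struct privilege_mode_to_accessor_t<Permissions::RO> { using type = ReadOnlyAccessor; };
template <> struct privilege_mode_to_accessor_t<Permissions::WO> { using type = WriteOnlyAccessor; };
template <> struct privilege_mode_to_accessor_t<Permissions::RW> { using type = ReadWriteAccessor; };
template <Permissions P>
using privilege_mode_to_accessor = typename privilege_mode_to_accessor_t<P>::type;

using PrivilegeType =
    std::variant<ReadOnlyAccessor, WriteOnlyAccessor, ReadWriteAccessor>;

// Virtual interface: runtime permission in, variant out.
struct ITaskArgumentAccessor {
  virtual PrivilegeType get_tensor(slot_id slot, Permissions priv) const = 0;
  virtual ~ITaskArgumentAccessor() = default;
};

// Non-virtual wrapper: compile-time permission in, concrete type out.
struct TaskArgumentAccessor {
  explicit TaskArgumentAccessor(std::shared_ptr<ITaskArgumentAccessor const> ptr)
      : ptr(std::move(ptr)) {}

  template <Permissions PRIV>
  privilege_mode_to_accessor<PRIV> get_tensor(slot_id slot) const {
    // std::get throws std::bad_variant_access if the backend returned an
    // alternative that does not match the requested permission.
    return std::get<privilege_mode_to_accessor<PRIV>>(
        this->ptr->get_tensor(slot, PRIV));
  }

  std::shared_ptr<ITaskArgumentAccessor const> ptr;
};

// Hypothetical backend for the example.
struct DummyBackend : ITaskArgumentAccessor {
  PrivilegeType get_tensor(slot_id, Permissions priv) const override {
    switch (priv) {
      case Permissions::RO: return ReadOnlyAccessor{"ro"};
      case Permissions::WO: return WriteOnlyAccessor{"wo"};
      default: return ReadWriteAccessor{"rw"};
    }
  }
};
```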


lib/runtime/src/task_spec/task_argument_accessor.h line 22 at r6 (raw file):

namespace FlexFlow {

using ITaskArgumentAccessorBackend = variant<LegionTaskArgumentAccessor,

Delete


lib/runtime/src/task_spec/task_argument_accessor.h line 64 at r6 (raw file):

  TaskArgumentAccessor(std::shared_ptr<ITaskArgumentAccessorBackend const> ptr)
      : ptr(ptr) {}
  std::shared_ptr<ITaskArgumentAccessorBackend const> ptr;

Suggestion:

  std::shared_ptr<ITaskArgumentAccessor const> ptr;

lib/runtime/test/src/test_op_task_invocation.cc line 13 at r2 (raw file):

Previously, reyna-abhyankar (Reyna Abhyankar) wrote…

This is just an initializer list, right? My compiler is not complaining

It's valid only in C++20 I think
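
For context: as far as I know, the `name: value` spelling in the quoted test is an older GNU extension rather than standard C++, while the designated-initializer form standardized in C++20 is written `.name = value`. A spelling that avoids both is plain positional aggregate initialization, sketched below. The `AggregateAttrs` definition here (field names taken from the quoted test, types and order assumed) and the `make_test_attrs` helper are hypothetical.

```cpp
#include <cassert>

// Assumed layout; the real AggregateAttrs may differ.
struct AggregateAttrs {
  int n;
  float lambda_bal;
};

AggregateAttrs make_test_attrs() {
  // Positional aggregate initialization works in every C++ standard since
  // C++11. The C++20 designated form would be:
  //   AggregateAttrs{.n = 12, .lambda_bal = 1.0f};
  return AggregateAttrs{12, 1.0f};
}
```

The positional form depends on field order, so the designated form is easier to read once the codebase targets C++20.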

Development

Successfully merging this pull request may close these issues.

Implement task registration Implement infer_bwd_binding

5 participants