This repository was archived by the owner on Nov 17, 2023. It is now read-only.

[Operator] Support ndim > 3 for batch_dot #10386

@sxjscience

Description


In batch_dot, i.e., C = batch_dot(A, B), both inputs A and B must have ndim = 3. However, in some cases our inputs have higher dimensionality. We could support the following two cases:

  • Both A and B have ndim > 3

    We should compute the output as:

    C[..., i, j] = sum_k (A[..., i, k] * B[..., k, j]), for all indices i, j.
    

    For example, if A has shape (batch_size, a1, a2, …, b, c) and B has shape (batch_size, a1, a2, …, c, d), then C will have shape (batch_size, a1, a2, …, b, d).
    This is similar to matmul in TensorFlow.

  • Only one of A and B has ndim > 3

    Assuming A.ndim > 3 and B.ndim = 3, we have

    C[b, ..., i, j] = sum_k (A[b, ..., i, k] * B[b, k, j]), for all indices i, j.
    
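To pin down the intended semantics, here is a minimal reference sketch in NumPy. It is an assumption that NumPy is a faithful stand-in here: np.matmul plays the role of the existing ndim=3 batch_dot, and the helpers batch_dot_nd, batch_dot_via_reshape and batch_dot_mixed are hypothetical names, not MXNet operators. The sketch also shows that both cases can be reduced to a single 3-D batched product using reshapes only.

```python
import numpy as np

# Hypothetical reference helpers (NOT MXNet API). NumPy is used only to
# illustrate the proposed semantics; np.matmul on 3-D inputs stands in
# for the existing ndim=3 batch_dot.

def batch_dot_nd(a, b):
    """Case 1: C[..., i, j] = sum_k A[..., i, k] * B[..., k, j]."""
    if a.shape[:-2] != b.shape[:-2]:
        raise ValueError("leading dimensions of A and B must match")
    if a.shape[-1] != b.shape[-2]:
        raise ValueError("A.shape[-1] must equal B.shape[-2]")
    return np.einsum("...ik,...kj->...ij", a, b)

def batch_dot_via_reshape(a, b):
    """Case 1 via a 3-D batched product: fold all leading axes into one
    batch axis, multiply, then restore the leading axes."""
    lead = a.shape[:-2]
    a3 = a.reshape((-1,) + a.shape[-2:])
    b3 = b.reshape((-1,) + b.shape[-2:])
    c3 = np.matmul(a3, b3)  # stand-in for the ndim=3 batch_dot
    return c3.reshape(lead + c3.shape[-2:])

def batch_dot_mixed(a, b):
    """Case 2 (A.ndim > 3, B.ndim == 3):
    C[n, ..., i, j] = sum_k A[n, ..., i, k] * B[n, k, j]."""
    if b.ndim != 3 or a.shape[0] != b.shape[0] or a.shape[-1] != b.shape[1]:
        raise ValueError("B must be 3-D and match A's batch and k axes")
    n, k = a.shape[0], a.shape[-1]
    # Fold A's middle axes into its row axis: (N, a1, ..., i, k) -> (N, M, k)
    a3 = a.reshape(n, -1, k)
    c3 = np.matmul(a3, b)  # stand-in for the ndim=3 batch_dot: (N, M, j)
    return c3.reshape(a.shape[:-1] + (b.shape[-1],))

rng = np.random.default_rng(0)
A = rng.standard_normal((2, 3, 4, 5, 6))
B = rng.standard_normal((2, 3, 4, 6, 7))
B3 = rng.standard_normal((2, 6, 7))
print(batch_dot_nd(A, B).shape)      # (2, 3, 4, 5, 7)
print(batch_dot_mixed(A, B3).shape)  # (2, 3, 4, 5, 7)
```

Since both cases reduce to the current ndim=3 operator plus reshapes, one possible implementation path is to do this folding inside the operator; whether that or a dedicated kernel is preferable is left open here.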

Supporting these two cases will solve the problems in https://discuss.gluon.ai/t/topic/2534/8 and https://discuss.gluon.ai/t/topic/5618/9.
