src/operator/tensor/elemwise_binary_scalar_op.h (22 additions, 0 deletions)
@@ -273,6 +273,28 @@ class BinaryScalarOp : public UnaryOp {
     }
   }
 
+  template<typename xpu, typename OP>
+  static void LogicComputeEx(const nnvm::NodeAttrs &attrs,
+                             const OpContext &ctx,
+                             const std::vector<NDArray> &inputs,
+                             const std::vector<OpReqType> &req,
+                             const std::vector<NDArray> &outputs) {
+    DCHECK_EQ(inputs.size(), 1);
+    DCHECK_EQ(outputs.size(), 1);
+    const auto in_stype = inputs[0].storage_type();
+    const auto out_stype = outputs[0].storage_type();
+    if (req[0] == kNullOp) {
+      return;
+    }
+    if ((in_stype == kRowSparseStorage && out_stype == kRowSparseStorage) ||
+        (in_stype == kCSRStorage && out_stype == kCSRStorage)) {
+      // csr -> csr, or rsp -> rsp
+      UnaryOp::MapToFCompute<xpu>(attrs, ctx, inputs, req, outputs, Compute<xpu, OP>);
+    } else {
+      LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
+    }
+  }
+
   template<typename xpu, typename OP>
   static void Backward(const nnvm::NodeAttrs &attrs,
                        const OpContext &ctx,
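Note on the hunk above: when the input and output share a sparse storage type, a scalar comparison never needs to touch the implicit zeros, so LogicComputeEx can map the dense Compute kernel over just the stored values via UnaryOp::MapToFCompute, leaving the index/indptr data untouched. A rough Python sketch of that idea (illustrative only, not the PR's code):

import numpy as np

# For a csr/rsp tensor whose output keeps the same storage type, the scalar
# comparison only runs over the stored (non-zero) values; the auxiliary
# index arrays carry over unchanged.
def logic_compute_ex_sketch(values, scalar, op):
    return op(values, scalar).astype(values.dtype)

vals = np.array([1.0, 3.0, 2.0])  # stored values of a sparse tensor
print(logic_compute_ex_sketch(vals, 2.0, np.greater))  # [0. 1. 0.]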
src/operator/tensor/elemwise_binary_scalar_op_logic.cc (47 additions, 12 deletions)
@@ -29,33 +29,68 @@
 namespace mxnet {
 namespace op {
 
-MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_equal_scalar)
-.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::eq>)
+#define MXNET_OPERATOR_REGISTER_BINARY_SCALAR_LOGIC(__name$, __kernel$)                      \
+  MXNET_OPERATOR_REGISTER_BINARY_SCALAR(__name$)                                             \
+  .set_attr<FInferStorageType>("FInferStorageType", BinaryScalarLogicStorageType<__kernel$>) \
+  .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, __kernel$>)              \
+  .set_attr<FComputeEx>("FComputeEx<cpu>", BinaryScalarOp::LogicComputeEx<cpu, __kernel$>)
+
+template<typename OP>
+static bool BinaryScalarLogicStorageType(const nnvm::NodeAttrs& attrs,
+                                         const int dev_mask,
+                                         DispatchMode* dispatch_mode,
+                                         std::vector<int> *in_attrs,
+                                         std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1);
+  CHECK_EQ(out_attrs->size(), 1);
+  const auto in_stype = in_attrs->at(0);
+  auto &out_stype = out_attrs->at(0);
+  bool dispatched = false;
+  const double alpha = nnvm::get<double>(attrs.parsed);
+  bool is_sparse = OP::Map(static_cast<double>(0), alpha) == 0;
+  if (!dispatched && in_stype == kDefaultStorage) {
+    // dns -> dns
+    dispatched = storage_type_assign(&out_stype, kDefaultStorage,
+                                     dispatch_mode, DispatchMode::kFCompute);
+  }
+  if (!dispatched && in_stype == kRowSparseStorage && is_sparse) {
+    // rsp -> rsp
+    dispatched = storage_type_assign(&out_stype, kRowSparseStorage,
+                                     dispatch_mode, DispatchMode::kFComputeEx);
+  }
+  if (!dispatched && in_stype == kCSRStorage && is_sparse) {
+    // csr -> csr
+    dispatched = storage_type_assign(&out_stype, kCSRStorage,
+                                     dispatch_mode, DispatchMode::kFComputeEx);
+  }
+  if (!dispatched) {
+    dispatched = dispatch_fallback(out_attrs, dispatch_mode);
+  }
+  return dispatched;
+}
+
+
+MXNET_OPERATOR_REGISTER_BINARY_SCALAR_LOGIC(_equal_scalar, mshadow_op::eq)
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 .add_alias("_EqualScalar");
 
-MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_not_equal_scalar)
-.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::ne>)
+MXNET_OPERATOR_REGISTER_BINARY_SCALAR_LOGIC(_not_equal_scalar, mshadow_op::ne)
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 .add_alias("_NotEqualScalar");
 
-MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_greater_scalar)
-.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::gt>)
+MXNET_OPERATOR_REGISTER_BINARY_SCALAR_LOGIC(_greater_scalar, mshadow_op::gt)
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 .add_alias("_GreaterScalar");
 
-MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_greater_equal_scalar)
-.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::ge>)
+MXNET_OPERATOR_REGISTER_BINARY_SCALAR_LOGIC(_greater_equal_scalar, mshadow_op::ge)
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 .add_alias("_GreaterEqualScalar");
 
-MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_lesser_scalar)
-.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::lt>)
+MXNET_OPERATOR_REGISTER_BINARY_SCALAR_LOGIC(_lesser_scalar, mshadow_op::lt)
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 .add_alias("_LesserScalar");
 
-MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_lesser_equal_scalar)
-.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::le>)
+MXNET_OPERATOR_REGISTER_BINARY_SCALAR_LOGIC(_lesser_equal_scalar, mshadow_op::le)
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 .add_alias("_LesserEqualScalar");
 
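The dispatch decision in BinaryScalarLogicStorageType rests on a single predicate: the result can stay sparse only if the operator maps an implicit zero to zero, i.e. OP::Map(0, alpha) == 0; otherwise every implicit zero would become a stored non-zero and the output must be dense. A minimal sketch of the same check (illustrative Python, not part of the PR):

import numpy as np

# Sparsity is preserved iff comparing an implicit zero against the scalar
# yields 0; otherwise all the implicit zeros would turn into ones.
def preserves_sparsity(op, scalar):
    return op(0.0, scalar) == 0

print(preserves_sparsity(np.equal, 1.0))          # True  -> rsp/csr stays sparse
print(preserves_sparsity(np.equal, 0.0))          # False -> falls back to dense
print(preserves_sparsity(np.not_equal, 0.0))      # True  -> stays sparse
print(preserves_sparsity(np.greater_equal, 0.0))  # False -> falls back to dense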
src/operator/tensor/elemwise_binary_scalar_op_logic.cu (12 additions, 6 deletions)
@@ -28,22 +28,28 @@ namespace mxnet {
 namespace op {
 
 NNVM_REGISTER_OP(_equal_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::eq>);
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::eq>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", BinaryScalarOp::LogicComputeEx<gpu, mshadow_op::eq>);
 
 NNVM_REGISTER_OP(_not_equal_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::ne>);
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::ne>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", BinaryScalarOp::LogicComputeEx<gpu, mshadow_op::ne>);
 
 NNVM_REGISTER_OP(_greater_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::gt>);
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::gt>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", BinaryScalarOp::LogicComputeEx<gpu, mshadow_op::gt>);
 
 NNVM_REGISTER_OP(_greater_equal_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::ge>);
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::ge>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", BinaryScalarOp::LogicComputeEx<gpu, mshadow_op::ge>);
 
 NNVM_REGISTER_OP(_lesser_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::lt>);
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::lt>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", BinaryScalarOp::LogicComputeEx<gpu, mshadow_op::lt>);
 
 NNVM_REGISTER_OP(_lesser_equal_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::le>);
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::le>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", BinaryScalarOp::LogicComputeEx<gpu, mshadow_op::le>);
 
 NNVM_REGISTER_OP(_logical_and_scalar)
 .set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::logical_and>);
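The GPU registrations simply mirror the CPU ones; the storage-type inference is shared. A quick usage sketch (assumes a CUDA-enabled MXNet build with at least one GPU):

import mxnet as mx

# Same storage-type rules as on CPU; only the kernel runs on the GPU.
y = mx.nd.ones((4, 4), ctx=mx.gpu(0)).tostype('row_sparse')
print((y == 1).stype)  # 'row_sparse': 0 == 1 is 0, so implicit zeros stay zero
print((y == 0).stype)  # 'default':    0 == 0 is 1, so the result must be dense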
tests/python/unittest/test_sparse_ndarray.py (26 additions, 2 deletions)
@@ -180,8 +180,12 @@ def test_sparse_nd_equal():
         y = sparse_nd_ones(shape, stype)
         z = x == y
         assert (z.asnumpy() == np.zeros(shape)).all()
-        z = 0 == x
+        z = 0 == y
         assert (z.asnumpy() == np.zeros(shape)).all()
+        assert z.stype == 'default'
+        z = 1 == y
+        assert (z.asnumpy() == np.ones(shape)).all()
+        assert z.stype == stype


 @with_seed()

@@ -192,8 +196,12 @@ def test_sparse_nd_not_equal():
         y = sparse_nd_ones(shape, stype)
         z = x != y
         assert (z.asnumpy() == np.ones(shape)).all()
-        z = 0 != x
+        z = 0 != y
         assert (z.asnumpy() == np.ones(shape)).all()
+        assert z.stype == stype
+        z = 1 != y
+        assert (z.asnumpy() == np.zeros(shape)).all()
+        assert z.stype == 'default'
Review comment on the assert z.stype == 'default' line above:

Contributor: Why is this the default storage type?

Author: Assuming most of y's entries are 0's, 0 != 1 gives 1, so the result should be in dense format.


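The exchange above can be checked directly (illustrative snippet, assuming the MXNet 1.x sparse NDArray API):

import mxnet as mx

y = mx.nd.zeros((4, 4))
y[0, 0] = 1.0
y = y.tostype('row_sparse')  # mostly zeros, one stored entry

print((y != 0).stype)  # 'row_sparse': 0 != 0 is 0, implicit zeros stay zero
print((y != 1).stype)  # 'default':    0 != 1 is 1, implicit zeros become ones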

 @with_seed()

@@ -206,8 +214,13 @@ def test_sparse_nd_greater():
         assert (z.asnumpy() == np.zeros(shape)).all()
         z = y > 0
         assert (z.asnumpy() == np.ones(shape)).all()
+        assert z.stype == stype
         z = 0 > y
         assert (z.asnumpy() == np.zeros(shape)).all()
+        assert z.stype == stype
+        z = y > 1
+        assert (z.asnumpy() == np.zeros(shape)).all()
+        assert z.stype == stype


 @with_seed()

@@ -220,10 +233,13 @@ def test_sparse_nd_greater_equal():
         assert (z.asnumpy() == np.zeros(shape)).all()
         z = y >= 0
         assert (z.asnumpy() == np.ones(shape)).all()
+        assert z.stype == 'default'
         z = 0 >= y
         assert (z.asnumpy() == np.zeros(shape)).all()
+        assert z.stype == 'default'
         z = y >= 1
         assert (z.asnumpy() == np.ones(shape)).all()
+        assert z.stype == stype


 @with_seed()

@@ -236,8 +252,13 @@ def test_sparse_nd_lesser():
         assert (z.asnumpy() == np.zeros(shape)).all()
         z = 0 < y
         assert (z.asnumpy() == np.ones(shape)).all()
+        assert z.stype == stype
         z = y < 0
         assert (z.asnumpy() == np.zeros(shape)).all()
+        assert z.stype == stype
+        z = y < 1
+        assert (z.asnumpy() == np.zeros(shape)).all()
+        assert z.stype == 'default'


 @with_seed()

@@ -250,10 +271,13 @@ def test_sparse_nd_lesser_equal():
         assert (z.asnumpy() == np.zeros(shape)).all()
         z = 0 <= y
         assert (z.asnumpy() == np.ones(shape)).all()
+        assert z.stype == 'default'
         z = y <= 0
         assert (z.asnumpy() == np.zeros(shape)).all()
+        assert z.stype == 'default'
         z = 1 <= y
         assert (z.asnumpy() == np.ones(shape)).all()
+        assert z.stype == stype


 @with_seed()