diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/Shape.java b/ndarray/src/main/java/org/tensorflow/ndarray/Shape.java
index a7e2dd0df82..85a905408c7 100644
--- a/ndarray/src/main/java/org/tensorflow/ndarray/Shape.java
+++ b/ndarray/src/main/java/org/tensorflow/ndarray/Shape.java
@@ -275,6 +275,27 @@ public Shape takeLast(int n) {
return Shape.of(newDimensions);
}
+ /**
+ * Return a {@code end - begin} dimensional shape with dimensions matching this Shape from {@code begin} to {@code end}.
+ * @param begin Where to start the sub-shape.
+ * @param end Where to end the sub-shape, exclusive.
+ * @return the sub-shape bounded by begin and end.
+ */
+ public Shape subShape(int begin, int end){
+ if (end > numDimensions()) {
+ throw new ArrayIndexOutOfBoundsException(
+ "End index " + end + " out of bounds: shape only has " + numDimensions() + " dimensions.");
+ }
+ if (begin < 0) {
+ throw new ArrayIndexOutOfBoundsException(
+ "Begin index " + begin + " out of bounds: cannot be less than 0.");
+ }
+
+ long[] newDimensions = new long[end - begin];
+ System.arraycopy(dimensionSizes, begin, newDimensions, 0, end - begin);
+ return Shape.of(newDimensions);
+ }
+
/**
* Returns a new Shape, with a new first dimension added. In order for this call to succeed,
* {@link Shape#isUnknown()} must be {@code false}.
diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java
index 529b0d99c39..b0fb67b5ce1 100644
--- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java
+++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java
@@ -59,6 +59,8 @@
import org.tensorflow.op.core.BatchToSpace;
import org.tensorflow.op.core.BatchToSpaceNd;
import org.tensorflow.op.core.Bitcast;
+import org.tensorflow.op.core.BooleanMask;
+import org.tensorflow.op.core.BooleanMaskUpdate;
import org.tensorflow.op.core.BroadcastDynamicShape;
import org.tensorflow.op.core.BroadcastTo;
import org.tensorflow.op.core.Bucketize;
@@ -347,10 +349,10 @@ public final class Ops {
public final SignalOps signal;
- public final QuantizationOps quantization;
-
public final TrainOps train;
+ public final QuantizationOps quantization;
+
private final Scope scope;
private Ops(Scope scope) {
@@ -372,8 +374,8 @@ private Ops(Scope scope) {
math = new MathOps(this);
audio = new AudioOps(this);
signal = new SignalOps(this);
- quantization = new QuantizationOps(this);
train = new TrainOps(this);
+ quantization = new QuantizationOps(this);
}
/**
@@ -989,6 +991,61 @@ public Bitcast bitcast(Operand extends TType> input, Clas
return Bitcast.create(scope, input, type);
}
+ /**
+ * Apply boolean mask to tensor. Returns the flat array of each element corresponding to a {@code true} in the mask.
+ *
+ * Numpy equivalent is {@code tensor[mask]}.
+ *
+ * In general, {@code 0 < dim(mask) = K <= dim(tensor)}, and {@code mask}'s shape must match
+ * the first K dimensions of {@code tensor}'s shape. We then have:
+ * {@code booleanMask(tensor, mask)[i, j1,...,jd] = tensor[i1,...,iK,j1,...,jd]}
+ * where {@code (i1,...,iK)} is the ith {@code true} entry of {@code mask} (row-major order).
+ *
+ * The {@code axis} could be used with {@code mask} to indicate the axis to mask from (it's 0 by default).
+ * In that case, {@code axis + dim(mask) <= dim(tensor)} and {@code mask}'s shape must match
+ * the first {@code axis + dim(mask)} dimensions of {@code tensor}'s shape.
+ *
+ * @param scope
+ * @param tensor The tensor to mask.
+ * @param mask The mask to apply.
+ * @param options carries optional attributes values
+ * @return The masked tensor.
+ */
+ public Operand booleanMask(Operand tensor, Operand mask,
+ BooleanMask.Options... options) {
+ return BooleanMask.create(scope, tensor, mask, options);
+ }
+
+ /**
+ * Updates a tensor at the masked values, and returns the updated tensor. Does not mutate the input tensors. {@code
+ * updates} will be broadcasted by default
+ *
+ * Numpy equivalent is `tensor[mask] = updates`.
+ *
+ * In general, {@code 0 < dim(mask) = K <= dim(tensor)}, and {@code mask}'s shape must match the first K dimensions of
+ * {@code tensor}'s shape. We then have: {@code booleanMask(tensor, mask)[i, j1,...,jd] =
+ * tensor[i1,...,iK,j1,...,jd]} where {@code (i1,...,iK)} is the ith {@code true} entry of {@code mask} (row-major
+ * order).
+ *
+ * The {@code axis} could be used with {@code mask} to indicate the axis to mask from (it's 0 by default). In that
+ * case, {@code axis + dim(mask) <= dim(tensor)} and {@code mask}'s shape must match the first {@code axis +
+ * dim(mask)} dimensions of {@code tensor}'s shape.
+ *
+ * The shape of {@code updates} should be {@code [n, t_1, t_2, ...]} where {@code n} is the number of true values in
+ * {@code mask} and {@code t_i} is the {@code i}th dimension of {@code tensor} after {@code axis} and {@code mask}.
+ * {@code updates} will be broadcasted to this shape by default, which can be disabled using {@code options}.
+ *
+ * @param tensor The tensor to mask.
+ * @param mask The mask to apply.
+ * @param updates the new values
+ * @param options carries optional attributes values
+ * @return The masked tensor.
+ */
+ public Operand booleanMaskUpdate(Operand tensor, Operand mask,
+ Operand updates, BooleanMaskUpdate.Options... options) {
+ return BooleanMaskUpdate.create(scope, tensor, mask, updates, options);
+ }
+
/**
* Return the shape of s0 op s1 with broadcast.
*
@@ -1834,13 +1891,14 @@ public Constant constant(Shape shape, IntDataBuffer data) {
}
/**
- * Creates a scalar of {@code type}, with the value of {@code number}.
- * {@code number} may be truncated if it does not fit in the target type.
+ * Creates a scalar of {@code type}, with the value of {@code number}. {@code number} may be truncated if it does not
+ * fit in the target type.
*
* @param type the type of tensor to create. Must be concrete (i.e. not {@link org.tensorflow.types.family.TFloating})
* @param number the value of the tensor
* @return a constant of the passed type
- * @throws IllegalArgumentException if the type is abstract (i.e. {@link org.tensorflow.types.family.TFloating}) or unknown.
+ * @throws IllegalArgumentException if the type is abstract (i.e. {@link org.tensorflow.types.family.TFloating}) or
+ * unknown.
*/
public Constant constant(Class type, Number number) {
return Constant.tensorOf(scope, type, number);
@@ -1892,14 +1950,14 @@ public Constant constantOf(T tensor) {
}
/**
- * Creates a scalar of the same type as {@code toMatch}, with the value of {@code number}.
- * {@code number} may be truncated if it does not fit in the target type.
+ * Creates a scalar of the same type as {@code toMatch}, with the value of {@code number}. {@code number} may be
+ * truncated if it does not fit in the target type.
*
* @param toMatch the operand providing the target type
* @param number the value of the tensor
* @return a constant with the same type as {@code toMatch}
- * @see Ops#constant(Class, Number)
* @throws IllegalArgumentException if the type is unknown (which should be impossible).
+ * @see Ops#constant(Class, Number)
*/
public Constant constantOfSameType(Operand toMatch, Number number) {
return Constant.tensorOfSameType(scope, toMatch, number);
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java
index 73fa340a487..85e283d9260 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java
@@ -125,6 +125,25 @@ public Scope withName(String opName) {
return new Scope(env, nameScope.withName(opName), controlDependencies, deviceSpec);
}
+ /**
+ * Returns a new scope where added operations will be prefixed by this scope's op name
+ * (set by {@link #withName(String)}), or the given default if it is unset. This is intended to be used for
+ * composite ops.
+ *
+ * Ops created with this scope will have {@code name/opName/} as the prefix. The actual
+ * name will be unique in the returned scope. All other properties are inherited from the current
+ * scope.
+ *
+ *
The default child scope name must match the regular expression {@code [A-Za-z0-9.][A-Za-z0-9_.\-]*}
+ *
+ * @param defaultName name of the sub scope if this scope's name hasn't been set.
+ * @return a new subscope
+ * @throws IllegalArgumentException if the name is invalid
+ */
+ public Scope withNameAsSubScope(String defaultName){
+ return new Scope(env, nameScope.withSubScope(nameScope.makeOpName(defaultName)), controlDependencies, deviceSpec);
+ }
+
/**
* Return a new scope that uses the provided device specification for an op.
*
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMask.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMask.java
new file mode 100644
index 00000000000..85a41ef485f
--- /dev/null
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMask.java
@@ -0,0 +1,157 @@
+/*
+ Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================
+ */
+package org.tensorflow.op.core;
+
+import java.util.Arrays;
+import java.util.Collections;
+import org.tensorflow.Operand;
+import org.tensorflow.ndarray.Shape;
+import org.tensorflow.ndarray.index.Indices;
+import org.tensorflow.op.Scope;
+import org.tensorflow.op.annotation.Endpoint;
+import org.tensorflow.op.annotation.Operator;
+import org.tensorflow.types.TBool;
+import org.tensorflow.types.TInt32;
+import org.tensorflow.types.TInt64;
+import org.tensorflow.types.family.TType;
+
+@Operator
+public abstract class BooleanMask {
+
+ /**
+ * Apply boolean mask to tensor. Returns the flat array of each element corresponding to a {@code true} in the mask.
+ *
+ * Numpy equivalent is {@code tensor[mask]}.
+ *
+ * In general, {@code 0 < dim(mask) = K <= dim(tensor)}, and {@code mask}'s shape must match
+ * the first K dimensions of {@code tensor}'s shape. We then have:
+ * {@code booleanMask(tensor, mask)[i, j1,...,jd] = tensor[i1,...,iK,j1,...,jd]}
+ * where {@code (i1,...,iK)} is the ith {@code true} entry of {@code mask} (row-major order).
+ *
+ * The {@code axis} could be used with {@code mask} to indicate the axis to mask from (it's 0 by default).
+ * In that case, {@code axis + dim(mask) <= dim(tensor)} and {@code mask}'s shape must match
+ * the first {@code axis + dim(mask)} dimensions of {@code tensor}'s shape.
+ *
+ * @param scope
+ * @param tensor The tensor to mask.
+ * @param mask The mask to apply.
+ * @param options carries optional attributes values
+ * @return The masked tensor.
+ */
+ @Endpoint(name = "booleanMask")
+ public static Operand create(Scope scope, Operand tensor, Operand mask,
+ Options... options) {
+
+ scope = scope.withNameAsSubScope("BooleanMask");
+
+ int axis = 0;
+ if (options != null) {
+ for (Options opts : options) {
+ if (opts.axis != null) {
+ axis = opts.axis;
+ }
+ }
+ }
+
+ if (axis < 0) {
+ axis += tensor.rank();
+ }
+
+ Shape maskShape = mask.shape();
+ Shape tensorShape = tensor.shape();
+
+ if (maskShape.numDimensions() == 0) {
+ throw new IllegalArgumentException("Mask cannot be a scalar.");
+ }
+ if (maskShape.hasUnknownDimension()) {
+ throw new IllegalArgumentException("Mask cannot have unknown number of dimensions");
+ }
+
+ Operand axisTensor = Constant.scalarOf(scope, axis);
+ Shape requiredMaskShape = tensorShape.subShape(axis, axis + maskShape.numDimensions());
+ if (!requiredMaskShape.isCompatibleWith(maskShape)) {
+ throw new IllegalArgumentException(
+ "Mask shape " + maskShape + " is not compatible with the required mask shape: " + requiredMaskShape + ".");
+ }
+
+ org.tensorflow.op.core.Shape liveShape = org.tensorflow.op.core.Shape.create(scope, tensor);
+
+ Operand leadingSize = ReduceProd.create(scope,
+ StridedSliceHelper.stridedSlice(scope,
+ liveShape,
+ Indices.range(axis, axis + maskShape.numDimensions())
+ ),
+ Constant.arrayOf(scope, 0)
+ );
+
+ Operand flattened = Reshape.create(scope, tensor, Concat.create(
+ scope,
+ Arrays.asList(
+ StridedSliceHelper.stridedSlice(scope, liveShape, Indices.sliceTo(axis)),
+ Reshape.create(scope, leadingSize, Constant.arrayOf(scope, 1)),
+ StridedSliceHelper.stridedSlice(scope, liveShape, Indices.sliceFrom(axis + maskShape.numDimensions()))
+ ),
+ Constant.scalarOf(scope, 0)
+ ));
+
+ Operand flatMask = Reshape.create(scope, mask, Constant.arrayOf(scope, -1));
+
+ Operand indices = Squeeze.create(scope, Where.create(scope, flatMask), Squeeze.axis(Collections.singletonList(1L)));
+ return Gather.create(scope, flattened, indices, axisTensor);
+ }
+
+ /**
+ * Used to indicate the axis to mask from.
+ * {@code axis + dim(mask) <= dim(tensor)} and {@code mask}'s shape must match
+ * the first {@code axis + dim(mask)} dimensions of {@code tensor}'s shape.
+ * @param axis the axis to mask from. Uses 0 if null.
+ */
+ public static Options axis(Integer axis){
+ return new Options().axis(axis);
+ }
+
+
+ /**
+ * Used to indicate the axis to mask from.
+ * {@code axis + dim(mask) <= dim(tensor)} and {@code mask}'s shape must match
+ * the first {@code axis + dim(mask)} dimensions of {@code tensor}'s shape.
+ * @param axis the axis to mask from.
+ */
+ public static Options axis(int axis){
+ return new Options().axis(axis);
+ }
+
+ /**
+ * Optional attributes for {@link org.tensorflow.op.core.BooleanMask}
+ */
+ public static class Options {
+
+ /**
+ * @param axis (Optional) The axis to mask from, or 0 if not set.
+ */
+ public Options axis(Integer axis) {
+ this.axis = axis;
+ return this;
+ }
+
+ private Integer axis;
+
+ private Options() {
+ }
+ }
+
+}
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMaskUpdate.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMaskUpdate.java
new file mode 100644
index 00000000000..a40ae7ab017
--- /dev/null
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMaskUpdate.java
@@ -0,0 +1,189 @@
+/*
+ Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================
+ */
+package org.tensorflow.op.core;
+
+import java.util.Arrays;
+import org.tensorflow.Operand;
+import org.tensorflow.ndarray.Shape;
+import org.tensorflow.ndarray.index.Indices;
+import org.tensorflow.op.Scope;
+import org.tensorflow.op.annotation.Endpoint;
+import org.tensorflow.op.annotation.Operator;
+import org.tensorflow.types.TBool;
+import org.tensorflow.types.TInt32;
+import org.tensorflow.types.TInt64;
+import org.tensorflow.types.family.TType;
+
+@Operator
+public abstract class BooleanMaskUpdate {
+
+ /**
+ * Updates a tensor at the masked values, and returns the updated tensor. Does not mutate the input tensors. {@code
+ * updates} will be broadcasted by default
+ *
+ * Numpy equivalent is `tensor[mask] = updates`.
+ *
+ * In general, {@code 0 < dim(mask) = K <= dim(tensor)}, and {@code mask}'s shape must match the first K dimensions of
+ * {@code tensor}'s shape. We then have: {@code booleanMask(tensor, mask)[i, j1,...,jd] =
+ * tensor[i1,...,iK,j1,...,jd]} where {@code (i1,...,iK)} is the ith {@code true} entry of {@code mask} (row-major
+ * order).
+ *
+ * The {@code axis} could be used with {@code mask} to indicate the axis to mask from (it's 0 by default). In that
+ * case, {@code axis + dim(mask) <= dim(tensor)} and {@code mask}'s shape must match the first {@code axis +
+ * dim(mask)} dimensions of {@code tensor}'s shape.
+ *
+ * The shape of {@code updates} should be {@code [n, t_1, t_2, ...]} where {@code n} is the number of true values in
+ * {@code mask} and {@code t_i} is the {@code i}th dimension of {@code tensor} after {@code axis} and {@code mask}.
+ * {@code updates} will be broadcasted to this shape by default, which can be disabled using {@code options}.
+ *
+ * @param tensor The tensor to mask.
+ * @param mask The mask to apply.
+ * @param updates the new values
+ * @param options carries optional attributes values
+ * @return The masked tensor.
+ */
+ @Endpoint(name = "booleanMaskUpdate")
+ public static Operand create(Scope scope, Operand tensor, Operand mask,
+ Operand updates,
+ Options... options) {
+
+ scope = scope.withNameAsSubScope("BooleanMaskUpdate");
+
+ int axis = 0;
+ boolean broadcast = true;
+ if (options != null) {
+ for (Options opts : options) {
+ if (opts.axis != null) {
+ axis = opts.axis;
+ }
+ if (opts.broadcast != null) {
+ broadcast = opts.broadcast;
+ }
+ }
+ }
+
+ if (axis < 0) {
+ axis += tensor.rank();
+ }
+
+ Shape maskShape = mask.shape();
+ Shape tensorShape = tensor.shape();
+
+ if (maskShape.numDimensions() == 0) {
+ throw new IllegalArgumentException("Mask cannot be a scalar.");
+ }
+ if (maskShape.hasUnknownDimension()) {
+ throw new IllegalArgumentException("Mask cannot have unknown number of dimensions");
+ }
+
+ Shape requiredMaskShape = tensorShape.subShape(axis, axis + maskShape.numDimensions());
+ if (!requiredMaskShape.isCompatibleWith(maskShape)) {
+ throw new IllegalArgumentException(
+ "Mask shape " + maskShape + " is not compatible with the required mask shape: " + requiredMaskShape + ".");
+ }
+
+ Operand liveShape = org.tensorflow.op.core.Shape.create(scope, tensor);
+
+ Operand leadingSize = ReduceProd.create(scope,
+ StridedSliceHelper.stridedSlice(scope,
+ liveShape,
+ Indices.sliceTo(axis + maskShape.numDimensions())
+ ),
+ Constant.arrayOf(scope, 0)
+ );
+
+ Operand innerShape = StridedSliceHelper
+ .stridedSlice(scope, liveShape, Indices.sliceFrom(axis + maskShape.numDimensions()));
+
+ Operand reshaped = Reshape.create(scope, tensor, Concat.create(
+ scope,
+ Arrays.asList(
+ Reshape.create(scope, leadingSize, Constant.arrayOf(scope, 1)),
+ innerShape
+ ),
+ Constant.scalarOf(scope, 0)
+ ));
+
+ Operand indices = Where.create(scope, mask);
+
+ if (broadcast) {
+ Operand indicesShape = org.tensorflow.op.core.Shape.create(scope, indices);
+ // this is the number of true values
+ Operand batchShape = StridedSliceHelper.stridedSlice(scope, indicesShape, Indices.sliceTo(-1));
+
+ Operand updateShape = Concat.create(
+ scope,
+ Arrays.asList(
+ batchShape,
+ innerShape
+ ),
+ Constant.scalarOf(scope, 0)
+ );
+
+ updates = BroadcastTo.create(scope, updates, updateShape);
+ }
+
+ Operand newValue = TensorScatterNdUpdate.create(scope, reshaped, indices, updates);
+ return Reshape.create(scope, newValue, liveShape);
+ }
+
+ /**
+ * Used to indicate the axis to mask from. {@code axis + dim(mask) <= dim(tensor)} and {@code mask}'s shape must match
+ * the first {@code axis + dim(mask)} dimensions of {@code tensor}'s shape.
+ *
+ * @param axis the axis to mask from. Uses 0 if null.
+ */
+ public static Options axis(Integer axis) {
+ return new Options().axis(axis);
+ }
+
+ /**
+ * Whether to try broadcasting update. True by default.
+ */
+ public static Options broadcast(Boolean broadcast) {
+ return new Options().broadcast(broadcast);
+ }
+
+ /**
+ * Optional attributes for {@link BooleanMaskUpdate}
+ */
+ public static class Options {
+
+ /**
+ * @param axis (Optional) The axis to mask from, or 0 if not set.
+ */
+ public Options axis(Integer axis) {
+ this.axis = axis;
+ return this;
+ }
+
+ /**
+ * @param broadcast (Optional) Whether to try broadcasting update. True by default.
+ */
+ public Options broadcast(Boolean broadcast) {
+ this.broadcast = broadcast;
+ return this;
+ }
+
+ private Integer axis;
+ private Boolean broadcast;
+
+ private Options() {
+ }
+ }
+
+}
diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskTest.java
new file mode 100644
index 00000000000..a4d9293ccf8
--- /dev/null
+++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskTest.java
@@ -0,0 +1,67 @@
+/*
+ Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================
+ */
+package org.tensorflow.op.core;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import org.junit.Test;
+import org.tensorflow.Graph;
+import org.tensorflow.Operand;
+import org.tensorflow.Session;
+import org.tensorflow.ndarray.Shape;
+import org.tensorflow.op.Scope;
+import org.tensorflow.types.TBool;
+import org.tensorflow.types.TFloat32;
+import org.tensorflow.types.TInt32;
+
+public class BooleanMaskTest {
+ @Test
+ public void testBooleanMask(){
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+
+ Operand input = Constant.arrayOf(scope, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
+ Operand input2 = ExpandDims.create(scope, input, Constant.scalarOf(scope, 0));
+
+ Operand mask = Constant.arrayOf(scope, true, true, false, false, true, true, true, false, false, false);
+
+ Operand output1 = BooleanMask.create(scope, input, mask);
+ Operand output2 = BooleanMask.create(scope, input2, mask, BooleanMask.axis(1));
+
+ try (TFloat32 result = (TFloat32) sess.runner().fetch(output1).run().get(0)) {
+ // expected shape from Python tensorflow
+ assertEquals(Shape.of(5), result.shape());
+ assertEquals(0, result.getFloat(0));
+ assertEquals(1, result.getFloat(1));
+ assertEquals(4, result.getFloat(2));
+ assertEquals(5, result.getFloat(3));
+ assertEquals(6, result.getFloat(4));
+ }
+
+ try (TFloat32 result = (TFloat32) sess.runner().fetch(output2).run().get(0)) {
+ // expected shape from Python tensorflow
+ assertEquals(Shape.of(5), result.shape());
+ assertEquals(0, result.getFloat(0));
+ assertEquals(1, result.getFloat(1));
+ assertEquals(4, result.getFloat(2));
+ assertEquals(5, result.getFloat(3));
+ assertEquals(6, result.getFloat(4));
+ }
+ }
+ }
+}
diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java
new file mode 100644
index 00000000000..ab852bbffb2
--- /dev/null
+++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java
@@ -0,0 +1,146 @@
+/*
+ Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================
+ */
+package org.tensorflow.op.core;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import java.util.List;
+import org.junit.Test;
+import org.tensorflow.Graph;
+import org.tensorflow.Operand;
+import org.tensorflow.Session;
+import org.tensorflow.Tensor;
+import org.tensorflow.ndarray.Shape;
+import org.tensorflow.op.Scope;
+import org.tensorflow.types.TBool;
+import org.tensorflow.types.TInt32;
+
+public class BooleanMaskUpdateTest {
+
+ @Test
+ public void testBooleanMaskUpdateSlice() {
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+
+ Operand input = Constant.tensorOf(scope, new int[][]{{0, 0, 0}, {1, 1, 1}, {2, 2, 2}});
+
+ Operand mask = Constant.arrayOf(scope, true, false, false);
+
+ Operand value = Constant.tensorOf(scope, new int[][]{{-1, -1, -1}});
+
+ Operand output = BooleanMaskUpdate.create(scope, input, mask, value);
+
+ Operand bcastOutput = BooleanMaskUpdate.create(scope, input, mask, Constant.scalarOf(scope, -1));
+
+ List results = sess.runner().fetch(output).fetch(bcastOutput).run();
+ try (TInt32 result = (TInt32) results.get(0);
+ TInt32 bcastResult = (TInt32) results.get(1)) {
+
+ assertEquals(Shape.of(3, 3), result.shape());
+
+ assertEquals(-1, result.getInt(0, 0));
+ assertEquals(-1, result.getInt(0, 1));
+ assertEquals(-1, result.getInt(0, 2));
+ assertEquals(1, result.getInt(1, 0));
+ assertEquals(1, result.getInt(1, 1));
+ assertEquals(1, result.getInt(1, 2));
+ assertEquals(2, result.getInt(2, 0));
+ assertEquals(2, result.getInt(2, 1));
+ assertEquals(2, result.getInt(2, 2));
+
+ assertEquals(result, bcastResult);
+ }
+ }
+ }
+
+ @Test
+ public void testBooleanMaskUpdateSliceWithBroadcast() {
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+
+ Operand input = Constant.tensorOf(scope, new int[][]{{0, 0, 0}, {1, 1, 1}, {2, 2, 2}});
+
+ Operand mask = Constant.arrayOf(scope, true, false, false);
+
+ Operand value = Constant.vectorOf(scope, new int[]{-1, -1, -1});
+
+ Operand output = BooleanMaskUpdate.create(scope, input, mask, value);
+
+ Operand bcastOutput = BooleanMaskUpdate.create(scope, input, mask, Constant.scalarOf(scope, -1));
+
+ List results = sess.runner().fetch(output).fetch(bcastOutput).run();
+ try (TInt32 result = (TInt32) results.get(0);
+ TInt32 bcastResult = (TInt32) results.get(1)) {
+
+ assertEquals(Shape.of(3, 3), result.shape());
+
+ assertEquals(-1, result.getInt(0, 0));
+ assertEquals(-1, result.getInt(0, 1));
+ assertEquals(-1, result.getInt(0, 2));
+ assertEquals(1, result.getInt(1, 0));
+ assertEquals(1, result.getInt(1, 1));
+ assertEquals(1, result.getInt(1, 2));
+ assertEquals(2, result.getInt(2, 0));
+ assertEquals(2, result.getInt(2, 1));
+ assertEquals(2, result.getInt(2, 2));
+
+ assertEquals(result, bcastResult);
+ }
+ }
+ }
+
+ @Test
+ public void testBooleanMaskUpdateAxis() {
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+
+ Operand input = Constant.tensorOf(scope, new int[][][]{{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}}});
+
+ Operand mask = Constant.arrayOf(scope, true, true, false, false, true, true, true, false, false, false);
+
+ Operand value = Constant.arrayOf(scope, -1, -1, -1, -1, -1);
+
+ Operand output = BooleanMaskUpdate.create(scope, input, mask, value, BooleanMaskUpdate.axis(2));
+
+ Operand bcastOutput = BooleanMaskUpdate
+ .create(scope, input, mask, Constant.scalarOf(scope, -1), BooleanMaskUpdate.axis(2));
+
+ List results = sess.runner().fetch(output).fetch(bcastOutput).run();
+ try (TInt32 result = (TInt32) results.get(0);
+ TInt32 bcastResult = (TInt32) results.get(1)) {
+
+ assertEquals(Shape.of(1, 1, 10), result.shape());
+
+ assertEquals(-1, result.getInt(0, 0, 0));
+ assertEquals(-1, result.getInt(0, 0, 1));
+ assertEquals(2, result.getInt(0, 0, 2));
+ assertEquals(3, result.getInt(0, 0, 3));
+ assertEquals(-1, result.getInt(0, 0, 4));
+ assertEquals(-1, result.getInt(0, 0, 5));
+ assertEquals(-1, result.getInt(0, 0, 6));
+ assertEquals(7, result.getInt(0, 0, 7));
+ assertEquals(8, result.getInt(0, 0, 8));
+ assertEquals(9, result.getInt(0, 0, 9));
+
+ assertEquals(result, bcastResult);
+ }
+ }
+ }
+}