diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/Shape.java b/ndarray/src/main/java/org/tensorflow/ndarray/Shape.java index a7e2dd0df82..85a905408c7 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/Shape.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/Shape.java @@ -275,6 +275,27 @@ public Shape takeLast(int n) { return Shape.of(newDimensions); } + /** + * Return a {@code end - begin} dimensional shape with dimensions matching this Shape from {@code begin} to {@code end}. + * @param begin Where to start the sub-shape. + * @param end Where to end the sub-shape, exclusive. + * @return the sub-shape bounded by begin and end. + */ + public Shape subShape(int begin, int end){ + if (end > numDimensions()) { + throw new ArrayIndexOutOfBoundsException( + "End index " + end + " out of bounds: shape only has " + numDimensions() + " dimensions."); + } + if (begin < 0) { + throw new ArrayIndexOutOfBoundsException( + "Begin index " + begin + " out of bounds: cannot be less than 0."); + } + + long[] newDimensions = new long[end - begin]; + System.arraycopy(dimensionSizes, begin, newDimensions, 0, end - begin); + return Shape.of(newDimensions); + } + /** * Returns a new Shape, with a new first dimension added. In order for this call to succeed, * {@link Shape#isUnknown()} must be {@code false}. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 529b0d99c39..b0fb67b5ce1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -59,6 +59,8 @@ import org.tensorflow.op.core.BatchToSpace; import org.tensorflow.op.core.BatchToSpaceNd; import org.tensorflow.op.core.Bitcast; +import org.tensorflow.op.core.BooleanMask; +import org.tensorflow.op.core.BooleanMaskUpdate; import org.tensorflow.op.core.BroadcastDynamicShape; import org.tensorflow.op.core.BroadcastTo; import org.tensorflow.op.core.Bucketize; @@ -347,10 +349,10 @@ public final class Ops { public final SignalOps signal; - public final QuantizationOps quantization; - public final TrainOps train; + public final QuantizationOps quantization; + private final Scope scope; private Ops(Scope scope) { @@ -372,8 +374,8 @@ private Ops(Scope scope) { math = new MathOps(this); audio = new AudioOps(this); signal = new SignalOps(this); - quantization = new QuantizationOps(this); train = new TrainOps(this); + quantization = new QuantizationOps(this); } /** @@ -989,6 +991,61 @@ public Bitcast bitcast(Operand input, Clas return Bitcast.create(scope, input, type); } + /** + * Apply boolean mask to tensor. Returns the flat array of each element corresponding to a {@code true} in the mask. + *

+ * Numpy equivalent is {@code tensor[mask]}. + *

+ * In general, {@code 0 < dim(mask) = K <= dim(tensor)}, and {@code mask}'s shape must match + * the first K dimensions of {@code tensor}'s shape. We then have: + * {@code booleanMask(tensor, mask)[i, j1,...,jd] = tensor[i1,...,iK,j1,...,jd]} + * where {@code (i1,...,iK)} is the ith {@code true} entry of {@code mask} (row-major order). + *

+ * The {@code axis} could be used with {@code mask} to indicate the axis to mask from (it's 0 by default). + * In that case, {@code axis + dim(mask) <= dim(tensor)} and {@code mask}'s shape must match + * the first {@code axis + dim(mask)} dimensions of {@code tensor}'s shape. + * + * @param scope + * @param tensor The tensor to mask. + * @param mask The mask to apply. + * @param options carries optional attributes values + * @return The masked tensor. + */ + public Operand booleanMask(Operand tensor, Operand mask, + BooleanMask.Options... options) { + return BooleanMask.create(scope, tensor, mask, options); + } + + /** + * Updates a tensor at the masked values, and returns the updated tensor. Does not mutate the input tensors. {@code + * updates} will be broadcasted by default + *

+ * Numpy equivalent is `tensor[mask] = updates`. + *

+ * In general, {@code 0 < dim(mask) = K <= dim(tensor)}, and {@code mask}'s shape must match the first K dimensions of + * {@code tensor}'s shape. We then have: {@code booleanMask(tensor, mask)[i, j1,...,jd] = + * tensor[i1,...,iK,j1,...,jd]} where {@code (i1,...,iK)} is the ith {@code true} entry of {@code mask} (row-major + * order). + *

+ * The {@code axis} could be used with {@code mask} to indicate the axis to mask from (it's 0 by default). In that + * case, {@code axis + dim(mask) <= dim(tensor)} and {@code mask}'s shape must match the first {@code axis + + * dim(mask)} dimensions of {@code tensor}'s shape. + *

+ * The shape of {@code updates} should be {@code [n, t_1, t_2, ...]} where {@code n} is the number of true values in + * {@code mask} and {@code t_i} is the {@code i}th dimension of {@code tensor} after {@code axis} and {@code mask}. + * {@code updates} will be broadcasted to this shape by default, which can be disabled using {@code options}. + * + * @param tensor The tensor to mask. + * @param mask The mask to apply. + * @param updates the new values + * @param options carries optional attributes values + * @return The masked tensor. + */ + public Operand booleanMaskUpdate(Operand tensor, Operand mask, + Operand updates, BooleanMaskUpdate.Options... options) { + return BooleanMaskUpdate.create(scope, tensor, mask, updates, options); + } + /** * Return the shape of s0 op s1 with broadcast. *

@@ -1834,13 +1891,14 @@ public Constant constant(Shape shape, IntDataBuffer data) { } /** - * Creates a scalar of {@code type}, with the value of {@code number}. - * {@code number} may be truncated if it does not fit in the target type. + * Creates a scalar of {@code type}, with the value of {@code number}. {@code number} may be truncated if it does not + * fit in the target type. * * @param type the type of tensor to create. Must be concrete (i.e. not {@link org.tensorflow.types.family.TFloating}) * @param number the value of the tensor * @return a constant of the passed type - * @throws IllegalArgumentException if the type is abstract (i.e. {@link org.tensorflow.types.family.TFloating}) or unknown. + * @throws IllegalArgumentException if the type is abstract (i.e. {@link org.tensorflow.types.family.TFloating}) or + * unknown. */ public Constant constant(Class type, Number number) { return Constant.tensorOf(scope, type, number); @@ -1892,14 +1950,14 @@ public Constant constantOf(T tensor) { } /** - * Creates a scalar of the same type as {@code toMatch}, with the value of {@code number}. - * {@code number} may be truncated if it does not fit in the target type. + * Creates a scalar of the same type as {@code toMatch}, with the value of {@code number}. {@code number} may be + * truncated if it does not fit in the target type. * * @param toMatch the operand providing the target type * @param number the value of the tensor * @return a constant with the same type as {@code toMatch} - * @see Ops#constant(Class, Number) * @throws IllegalArgumentException if the type is unknown (which should be impossible). + * @see Ops#constant(Class, Number) */ public Constant constantOfSameType(Operand toMatch, Number number) { return Constant.tensorOfSameType(scope, toMatch, number); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java index 73fa340a487..85e283d9260 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Scope.java @@ -125,6 +125,25 @@ public Scope withName(String opName) { return new Scope(env, nameScope.withName(opName), controlDependencies, deviceSpec); } + /** + * Returns a new scope where added operations will be prefixed by this scope's op name + * (set by {@link #withName(String)}), or the given default if it is unset. This is intended to be used for + * composite ops. + * + *

Ops created with this scope will have {@code name/opName/} as the prefix. The actual + * name will be unique in the returned scope. All other properties are inherited from the current + * scope. + * + *

The default child scope name must match the regular expression {@code [A-Za-z0-9.][A-Za-z0-9_.\-]*} + * + * @param defaultName name of the sub scope if this scope's name hasn't been set. + * @return a new subscope + * @throws IllegalArgumentException if the name is invalid + */ + public Scope withNameAsSubScope(String defaultName){ + return new Scope(env, nameScope.withSubScope(nameScope.makeOpName(defaultName)), controlDependencies, deviceSpec); + } + /** * Return a new scope that uses the provided device specification for an op. * diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMask.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMask.java new file mode 100644 index 00000000000..85a41ef485f --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMask.java @@ -0,0 +1,157 @@ +/* + Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow.op.core; + +import java.util.Arrays; +import java.util.Collections; +import org.tensorflow.Operand; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.index.Indices; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TType; + +@Operator +public abstract class BooleanMask { + + /** + * Apply boolean mask to tensor. Returns the flat array of each element corresponding to a {@code true} in the mask. + *

+ * Numpy equivalent is {@code tensor[mask]}. + *

+ * In general, {@code 0 < dim(mask) = K <= dim(tensor)}, and {@code mask}'s shape must match + * the first K dimensions of {@code tensor}'s shape. We then have: + * {@code booleanMask(tensor, mask)[i, j1,...,jd] = tensor[i1,...,iK,j1,...,jd]} + * where {@code (i1,...,iK)} is the ith {@code true} entry of {@code mask} (row-major order). + *

+ * The {@code axis} could be used with {@code mask} to indicate the axis to mask from (it's 0 by default). + * In that case, {@code axis + dim(mask) <= dim(tensor)} and {@code mask}'s shape must match + * the first {@code axis + dim(mask)} dimensions of {@code tensor}'s shape. + * + * @param scope + * @param tensor The tensor to mask. + * @param mask The mask to apply. + * @param options carries optional attributes values + * @return The masked tensor. + */ + @Endpoint(name = "booleanMask") + public static Operand create(Scope scope, Operand tensor, Operand mask, + Options... options) { + + scope = scope.withNameAsSubScope("BooleanMask"); + + int axis = 0; + if (options != null) { + for (Options opts : options) { + if (opts.axis != null) { + axis = opts.axis; + } + } + } + + if (axis < 0) { + axis += tensor.rank(); + } + + Shape maskShape = mask.shape(); + Shape tensorShape = tensor.shape(); + + if (maskShape.numDimensions() == 0) { + throw new IllegalArgumentException("Mask cannot be a scalar."); + } + if (maskShape.hasUnknownDimension()) { + throw new IllegalArgumentException("Mask cannot have unknown number of dimensions"); + } + + Operand axisTensor = Constant.scalarOf(scope, axis); + Shape requiredMaskShape = tensorShape.subShape(axis, axis + maskShape.numDimensions()); + if (!requiredMaskShape.isCompatibleWith(maskShape)) { + throw new IllegalArgumentException( + "Mask shape " + maskShape + " is not compatible with the required mask shape: " + requiredMaskShape + "."); + } + + org.tensorflow.op.core.Shape liveShape = org.tensorflow.op.core.Shape.create(scope, tensor); + + Operand leadingSize = ReduceProd.create(scope, + StridedSliceHelper.stridedSlice(scope, + liveShape, + Indices.range(axis, axis + maskShape.numDimensions()) + ), + Constant.arrayOf(scope, 0) + ); + + Operand flattened = Reshape.create(scope, tensor, Concat.create( + scope, + Arrays.asList( + StridedSliceHelper.stridedSlice(scope, liveShape, Indices.sliceTo(axis)), + Reshape.create(scope, leadingSize, Constant.arrayOf(scope, 1)), + StridedSliceHelper.stridedSlice(scope, liveShape, Indices.sliceFrom(axis + maskShape.numDimensions())) + ), + Constant.scalarOf(scope, 0) + )); + + Operand flatMask = Reshape.create(scope, mask, Constant.arrayOf(scope, -1)); + + Operand indices = Squeeze.create(scope, Where.create(scope, flatMask), Squeeze.axis(Collections.singletonList(1L))); + return Gather.create(scope, flattened, indices, axisTensor); + } + + /** + * Used to indicate the axis to mask from. + * {@code axis + dim(mask) <= dim(tensor)} and {@code mask}'s shape must match + * the first {@code axis + dim(mask)} dimensions of {@code tensor}'s shape. + * @param axis the axis to mask from. Uses 0 if null. + */ + public static Options axis(Integer axis){ + return new Options().axis(axis); + } + + + /** + * Used to indicate the axis to mask from. + * {@code axis + dim(mask) <= dim(tensor)} and {@code mask}'s shape must match + * the first {@code axis + dim(mask)} dimensions of {@code tensor}'s shape. + * @param axis the axis to mask from. + */ + public static Options axis(int axis){ + return new Options().axis(axis); + } + + /** + * Optional attributes for {@link org.tensorflow.op.core.BooleanMask} + */ + public static class Options { + + /** + * @param axis (Optional) The axis to mask from, or 0 if not set. + */ + public Options axis(Integer axis) { + this.axis = axis; + return this; + } + + private Integer axis; + + private Options() { + } + } + +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMaskUpdate.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMaskUpdate.java new file mode 100644 index 00000000000..a40ae7ab017 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMaskUpdate.java @@ -0,0 +1,189 @@ +/* + Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow.op.core; + +import java.util.Arrays; +import org.tensorflow.Operand; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.index.Indices; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TType; + +@Operator +public abstract class BooleanMaskUpdate { + + /** + * Updates a tensor at the masked values, and returns the updated tensor. Does not mutate the input tensors. {@code + * updates} will be broadcasted by default + *

+ * Numpy equivalent is `tensor[mask] = updates`. + *

+ * In general, {@code 0 < dim(mask) = K <= dim(tensor)}, and {@code mask}'s shape must match the first K dimensions of + * {@code tensor}'s shape. We then have: {@code booleanMask(tensor, mask)[i, j1,...,jd] = + * tensor[i1,...,iK,j1,...,jd]} where {@code (i1,...,iK)} is the ith {@code true} entry of {@code mask} (row-major + * order). + *

+ * The {@code axis} could be used with {@code mask} to indicate the axis to mask from (it's 0 by default). In that + * case, {@code axis + dim(mask) <= dim(tensor)} and {@code mask}'s shape must match the first {@code axis + + * dim(mask)} dimensions of {@code tensor}'s shape. + *

+ * The shape of {@code updates} should be {@code [n, t_1, t_2, ...]} where {@code n} is the number of true values in + * {@code mask} and {@code t_i} is the {@code i}th dimension of {@code tensor} after {@code axis} and {@code mask}. + * {@code updates} will be broadcasted to this shape by default, which can be disabled using {@code options}. + * + * @param tensor The tensor to mask. + * @param mask The mask to apply. + * @param updates the new values + * @param options carries optional attributes values + * @return The masked tensor. + */ + @Endpoint(name = "booleanMaskUpdate") + public static Operand create(Scope scope, Operand tensor, Operand mask, + Operand updates, + Options... options) { + + scope = scope.withNameAsSubScope("BooleanMaskUpdate"); + + int axis = 0; + boolean broadcast = true; + if (options != null) { + for (Options opts : options) { + if (opts.axis != null) { + axis = opts.axis; + } + if (opts.broadcast != null) { + broadcast = opts.broadcast; + } + } + } + + if (axis < 0) { + axis += tensor.rank(); + } + + Shape maskShape = mask.shape(); + Shape tensorShape = tensor.shape(); + + if (maskShape.numDimensions() == 0) { + throw new IllegalArgumentException("Mask cannot be a scalar."); + } + if (maskShape.hasUnknownDimension()) { + throw new IllegalArgumentException("Mask cannot have unknown number of dimensions"); + } + + Shape requiredMaskShape = tensorShape.subShape(axis, axis + maskShape.numDimensions()); + if (!requiredMaskShape.isCompatibleWith(maskShape)) { + throw new IllegalArgumentException( + "Mask shape " + maskShape + " is not compatible with the required mask shape: " + requiredMaskShape + "."); + } + + Operand liveShape = org.tensorflow.op.core.Shape.create(scope, tensor); + + Operand leadingSize = ReduceProd.create(scope, + StridedSliceHelper.stridedSlice(scope, + liveShape, + Indices.sliceTo(axis + maskShape.numDimensions()) + ), + Constant.arrayOf(scope, 0) + ); + + Operand innerShape = StridedSliceHelper + .stridedSlice(scope, liveShape, Indices.sliceFrom(axis + maskShape.numDimensions())); + + Operand reshaped = Reshape.create(scope, tensor, Concat.create( + scope, + Arrays.asList( + Reshape.create(scope, leadingSize, Constant.arrayOf(scope, 1)), + innerShape + ), + Constant.scalarOf(scope, 0) + )); + + Operand indices = Where.create(scope, mask); + + if (broadcast) { + Operand indicesShape = org.tensorflow.op.core.Shape.create(scope, indices); + // this is the number of true values + Operand batchShape = StridedSliceHelper.stridedSlice(scope, indicesShape, Indices.sliceTo(-1)); + + Operand updateShape = Concat.create( + scope, + Arrays.asList( + batchShape, + innerShape + ), + Constant.scalarOf(scope, 0) + ); + + updates = BroadcastTo.create(scope, updates, updateShape); + } + + Operand newValue = TensorScatterNdUpdate.create(scope, reshaped, indices, updates); + return Reshape.create(scope, newValue, liveShape); + } + + /** + * Used to indicate the axis to mask from. {@code axis + dim(mask) <= dim(tensor)} and {@code mask}'s shape must match + * the first {@code axis + dim(mask)} dimensions of {@code tensor}'s shape. + * + * @param axis the axis to mask from. Uses 0 if null. + */ + public static Options axis(Integer axis) { + return new Options().axis(axis); + } + + /** + * Whether to try broadcasting update. True by default. + */ + public static Options broadcast(Boolean broadcast) { + return new Options().broadcast(broadcast); + } + + /** + * Optional attributes for {@link BooleanMaskUpdate} + */ + public static class Options { + + /** + * @param axis (Optional) The axis to mask from, or 0 if not set. + */ + public Options axis(Integer axis) { + this.axis = axis; + return this; + } + + /** + * @param broadcast (Optional) Whether to try broadcasting update. True by default. + */ + public Options broadcast(Boolean broadcast) { + this.broadcast = broadcast; + return this; + } + + private Integer axis; + private Boolean broadcast; + + private Options() { + } + } + +} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskTest.java new file mode 100644 index 00000000000..a4d9293ccf8 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskTest.java @@ -0,0 +1,67 @@ +/* + Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow.op.core; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.Test; +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Session; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Scope; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; + +public class BooleanMaskTest { + @Test + public void testBooleanMask(){ + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + + Operand input = Constant.arrayOf(scope, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9); + Operand input2 = ExpandDims.create(scope, input, Constant.scalarOf(scope, 0)); + + Operand mask = Constant.arrayOf(scope, true, true, false, false, true, true, true, false, false, false); + + Operand output1 = BooleanMask.create(scope, input, mask); + Operand output2 = BooleanMask.create(scope, input2, mask, BooleanMask.axis(1)); + + try (TFloat32 result = (TFloat32) sess.runner().fetch(output1).run().get(0)) { + // expected shape from Python tensorflow + assertEquals(Shape.of(5), result.shape()); + assertEquals(0, result.getFloat(0)); + assertEquals(1, result.getFloat(1)); + assertEquals(4, result.getFloat(2)); + assertEquals(5, result.getFloat(3)); + assertEquals(6, result.getFloat(4)); + } + + try (TFloat32 result = (TFloat32) sess.runner().fetch(output2).run().get(0)) { + // expected shape from Python tensorflow + assertEquals(Shape.of(5), result.shape()); + assertEquals(0, result.getFloat(0)); + assertEquals(1, result.getFloat(1)); + assertEquals(4, result.getFloat(2)); + assertEquals(5, result.getFloat(3)); + assertEquals(6, result.getFloat(4)); + } + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java new file mode 100644 index 00000000000..ab852bbffb2 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java @@ -0,0 +1,146 @@ +/* + Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow.op.core; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.List; +import org.junit.Test; +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Session; +import org.tensorflow.Tensor; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Scope; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TInt32; + +public class BooleanMaskUpdateTest { + + @Test + public void testBooleanMaskUpdateSlice() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + + Operand input = Constant.tensorOf(scope, new int[][]{{0, 0, 0}, {1, 1, 1}, {2, 2, 2}}); + + Operand mask = Constant.arrayOf(scope, true, false, false); + + Operand value = Constant.tensorOf(scope, new int[][]{{-1, -1, -1}}); + + Operand output = BooleanMaskUpdate.create(scope, input, mask, value); + + Operand bcastOutput = BooleanMaskUpdate.create(scope, input, mask, Constant.scalarOf(scope, -1)); + + List results = sess.runner().fetch(output).fetch(bcastOutput).run(); + try (TInt32 result = (TInt32) results.get(0); + TInt32 bcastResult = (TInt32) results.get(1)) { + + assertEquals(Shape.of(3, 3), result.shape()); + + assertEquals(-1, result.getInt(0, 0)); + assertEquals(-1, result.getInt(0, 1)); + assertEquals(-1, result.getInt(0, 2)); + assertEquals(1, result.getInt(1, 0)); + assertEquals(1, result.getInt(1, 1)); + assertEquals(1, result.getInt(1, 2)); + assertEquals(2, result.getInt(2, 0)); + assertEquals(2, result.getInt(2, 1)); + assertEquals(2, result.getInt(2, 2)); + + assertEquals(result, bcastResult); + } + } + } + + @Test + public void testBooleanMaskUpdateSliceWithBroadcast() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + + Operand input = Constant.tensorOf(scope, new int[][]{{0, 0, 0}, {1, 1, 1}, {2, 2, 2}}); + + Operand mask = Constant.arrayOf(scope, true, false, false); + + Operand value = Constant.vectorOf(scope, new int[]{-1, -1, -1}); + + Operand output = BooleanMaskUpdate.create(scope, input, mask, value); + + Operand bcastOutput = BooleanMaskUpdate.create(scope, input, mask, Constant.scalarOf(scope, -1)); + + List results = sess.runner().fetch(output).fetch(bcastOutput).run(); + try (TInt32 result = (TInt32) results.get(0); + TInt32 bcastResult = (TInt32) results.get(1)) { + + assertEquals(Shape.of(3, 3), result.shape()); + + assertEquals(-1, result.getInt(0, 0)); + assertEquals(-1, result.getInt(0, 1)); + assertEquals(-1, result.getInt(0, 2)); + assertEquals(1, result.getInt(1, 0)); + assertEquals(1, result.getInt(1, 1)); + assertEquals(1, result.getInt(1, 2)); + assertEquals(2, result.getInt(2, 0)); + assertEquals(2, result.getInt(2, 1)); + assertEquals(2, result.getInt(2, 2)); + + assertEquals(result, bcastResult); + } + } + } + + @Test + public void testBooleanMaskUpdateAxis() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + + Operand input = Constant.tensorOf(scope, new int[][][]{{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}}}); + + Operand mask = Constant.arrayOf(scope, true, true, false, false, true, true, true, false, false, false); + + Operand value = Constant.arrayOf(scope, -1, -1, -1, -1, -1); + + Operand output = BooleanMaskUpdate.create(scope, input, mask, value, BooleanMaskUpdate.axis(2)); + + Operand bcastOutput = BooleanMaskUpdate + .create(scope, input, mask, Constant.scalarOf(scope, -1), BooleanMaskUpdate.axis(2)); + + List results = sess.runner().fetch(output).fetch(bcastOutput).run(); + try (TInt32 result = (TInt32) results.get(0); + TInt32 bcastResult = (TInt32) results.get(1)) { + + assertEquals(Shape.of(1, 1, 10), result.shape()); + + assertEquals(-1, result.getInt(0, 0, 0)); + assertEquals(-1, result.getInt(0, 0, 1)); + assertEquals(2, result.getInt(0, 0, 2)); + assertEquals(3, result.getInt(0, 0, 3)); + assertEquals(-1, result.getInt(0, 0, 4)); + assertEquals(-1, result.getInt(0, 0, 5)); + assertEquals(-1, result.getInt(0, 0, 6)); + assertEquals(7, result.getInt(0, 0, 7)); + assertEquals(8, result.getInt(0, 0, 8)); + assertEquals(9, result.getInt(0, 0, 9)); + + assertEquals(result, bcastResult); + } + } + } +}