From f46cdb6c412a73f45f6324f35d438d9d2a2e7e85 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Fri, 8 Mar 2024 17:14:16 -0500 Subject: [PATCH 1/2] Adding casts to the if test so it passes on GPU. --- .../src/test/java/org/tensorflow/op/core/IfTest.java | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IfTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IfTest.java index 57bc0bc9ffb..1fc1b7ed46b 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IfTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IfTest.java @@ -27,6 +27,7 @@ import org.tensorflow.Session; import org.tensorflow.Signature; import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; public class IfTest { @@ -37,7 +38,8 @@ private static Operand basicIf(Ops tf, Operand a, Operand { Operand a1 = ops.placeholder(TInt32.class); Operand b1 = ops.placeholder(TInt32.class); - return Signature.builder().input("a", a1).input("b", b1).output("y", a1).build(); + Operand y = ops.identity(a1); + return Signature.builder().input("a", a1).input("b", b1).output("y", y).build(); }); ConcreteFunction elseBranch = @@ -45,7 +47,10 @@ private static Operand basicIf(Ops tf, Operand a, Operand { Operand a1 = ops.placeholder(TInt32.class); Operand b1 = ops.placeholder(TInt32.class); - Operand y = ops.math.neg(b1); + // Casts around the math.neg operator as it's not implemented correctly for int32 in + // GPUs at some point between TF 2.10 and TF 2.15. + Operand y = + ops.dtypes.cast(ops.math.neg(ops.dtypes.cast(a1, TFloat32.class)), TInt32.class); return Signature.builder().input("a", a1).input("b", b1).output("y", y).build(); }); From 8d48f57b48794626f0da0ed9843bb54be10768d0 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Fri, 8 Mar 2024 17:31:29 -0500 Subject: [PATCH 2/2] Fix a typo in the variable name --- .../src/test/java/org/tensorflow/op/core/IfTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IfTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IfTest.java index 1fc1b7ed46b..16cd17cab8e 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IfTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IfTest.java @@ -50,7 +50,7 @@ private static Operand basicIf(Ops tf, Operand a, Operand y = - ops.dtypes.cast(ops.math.neg(ops.dtypes.cast(a1, TFloat32.class)), TInt32.class); + ops.dtypes.cast(ops.math.neg(ops.dtypes.cast(b1, TFloat32.class)), TInt32.class); return Signature.builder().input("a", a1).input("b", b1).output("y", y).build(); });