diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java index 3e635b0d957..6c40149f3de 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java @@ -51,7 +51,7 @@ public class LossesHelper { * @param tf the TensorFlow Ops * @param predictions Predicted values, a Operand of arbitrary dimensions. * @param labels Optional label Operand whose dimensions match prediction - * . + * . * @param the data type for the labels, predictions and result * @return LossTuple of prediction, label,sampleWeight will * be null. Each of them possibly has the last dimension squeezed, sampleWeight @@ -77,7 +77,7 @@ public static LossTuple squeezeOrExpandDimensions( * @param tf the TensorFlow Ops * @param predictions Predicted values, a Operand of arbitrary dimensions. * @param labels Optional label Operand whose dimensions match prediction - * . + * . * @param sampleWeights Optional sample weight(s) Operand whose dimensions match * * prediction. @@ -179,7 +179,7 @@ private static Operand maybeExpandWeights( * * @param tf the TensorFlowOps * @param labels Label values, a Tensor whose dimensions match predictions - * . + * . * @param predictions Predicted values, a Tensor of arbitrary dimensions. * @param the data type for the labels, predictions and result * @return labels and predictions, possibly with last dim squeezed. @@ -194,7 +194,7 @@ public static LossTuple removeSqueezableDimensions( * * @param tf the TensorFlowOps * @param labels Label values, a Operand whose dimensions match predictions - * . + * . * @param predictions Predicted values, a Tensor of arbitrary dimensions. * @param expectedRankDiff Expected result of rank(predictions) - rank(labels). * @param the data type for the labels, predictions and result @@ -222,11 +222,13 @@ public static LossTuple removeSqueezableDimensions( // Use dynamic rank. // TODO: hold for lazy select feature, - // Operand rankDiff = tf.math.sub(tf.rank(predictions), tf.rank(labels)); + // Operand rankDiff = tf.math.sub(tf.rank(predictions), + // tf.rank(labels)); if (predictionsRank == Shape.UNKNOWN_SIZE && Shape.isCompatible(predictionsShape.size(-1), 1)) { /* - * TODO, if we ever get a select that does lazy evaluation, but for now do the tf.squeeze - * predictions = tf.select( tf.math.equal(tf.constant(expectedRankDiff+1),rankDiff ), + * TODO, if we ever get a select that does lazy evaluation, but for now do the + * tf.squeeze predictions = tf.select( + * tf.math.equal(tf.constant(expectedRankDiff+1),rankDiff ), * tf.squeeze(predictions, Squeeze.axis(Arrays.asList(-1L))), predictions ); * */ predictions = tf.squeeze(predictions, Squeeze.axis(Collections.singletonList(-1L))); @@ -282,11 +284,12 @@ private static Operand reduceWeightedLoss( if (reduction == Reduction.NONE) { loss = weightedLoss; } else { - loss = - tf.reduceSum(weightedLoss, allAxes(tf, weightedLoss), ReduceSum.keepDims(Boolean.FALSE)); if (reduction == Reduction.AUTO || reduction == Reduction.SUM_OVER_BATCH_SIZE) { - loss = safeMean(tf, loss, weightedLoss.shape().size()); - } + loss = safeMean(tf, weightedLoss); + } else + loss = + tf.reduceSum( + weightedLoss, allAxes(tf, weightedLoss), ReduceSum.keepDims(Boolean.FALSE)); } return loss; } @@ -301,10 +304,10 @@ private static Operand reduceWeightedLoss( * @return A scalar representing the mean of losses. If numElements is * zero, then zero is returned. */ - public static Operand safeMean( - Ops tf, Operand losses, long numElements) { - Operand totalLoss = tf.reduceSum(losses, allAxes(tf, losses)); - return tf.math.divNoNan(totalLoss, cast(tf, tf.constant(numElements), losses.type())); + public static Operand safeMean(Ops tf, Operand losses) { + Operand totalLoss = + tf.reduceSum(losses, allAxes(tf, losses), ReduceSum.keepDims(Boolean.FALSE)); + return tf.math.divNoNan(totalLoss, cast(tf, tf.shape.size(tf.shape(losses)), losses.type())); } /** @@ -348,7 +351,8 @@ public static Operand rangeCheck( tf.math.logicalAnd( tf.reduceAll(tf.math.greaterEqual(values, minValue), allDims), tf.reduceAll(tf.math.lessEqual(values, maxValue), allDims)); - // Graph and Eager mode need to be handled differently, control dependencies are not allowed in + // Graph and Eager mode need to be handled differently, control dependencies are + // not allowed in // Eager mode if (tf.scope().env().isGraph()) { AssertThat assertThat = @@ -398,7 +402,8 @@ public static Operand valueCheck( } else return values; } else { // use dynamic shape Operand cond = tf.math.equal(tf.shape.size(tf.shape(diff.out())), tf.constant(0)); - // Graph and Eager mode need to be handled differently, control dependencies are not allowed + // Graph and Eager mode need to be handled differently, control dependencies are + // not allowed // in Eager mode if (tf.scope().env().isGraph()) { AssertThat assertThat = diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java index c8c1df607c2..c2982e9b0b0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java @@ -22,7 +22,7 @@ import org.tensorflow.types.family.TNumber; /** Interface for metrics */ -interface Metric { +public interface Metric { /** * Creates a List of Operations to update the metric state based on input values.