diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java
index 3e635b0d957..6c40149f3de 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java
@@ -51,7 +51,7 @@ public class LossesHelper {
* @param tf the TensorFlow Ops
* @param predictions Predicted values, a Operand of arbitrary dimensions.
* @param labels Optional label Operand whose dimensions match prediction
- * .
+ * .
* @param the data type for the labels, predictions and result
* @return LossTuple of prediction, label,sampleWeight will
* be null. Each of them possibly has the last dimension squeezed, sampleWeight
@@ -77,7 +77,7 @@ public static LossTuple squeezeOrExpandDimensions(
* @param tf the TensorFlow Ops
* @param predictions Predicted values, a Operand of arbitrary dimensions.
* @param labels Optional label Operand whose dimensions match prediction
- * .
+ * .
* @param sampleWeights Optional sample weight(s) Operand whose dimensions match
*
* prediction.
@@ -179,7 +179,7 @@ private static Operand maybeExpandWeights(
*
* @param tf the TensorFlowOps
* @param labels Label values, a Tensor whose dimensions match predictions
- * .
+ * .
* @param predictions Predicted values, a Tensor of arbitrary dimensions.
* @param the data type for the labels, predictions and result
* @return labels and predictions, possibly with last dim squeezed.
@@ -194,7 +194,7 @@ public static LossTuple removeSqueezableDimensions(
*
* @param tf the TensorFlowOps
* @param labels Label values, a Operand whose dimensions match predictions
- * .
+ * .
* @param predictions Predicted values, a Tensor of arbitrary dimensions.
* @param expectedRankDiff Expected result of rank(predictions) - rank(labels).
* @param the data type for the labels, predictions and result
@@ -222,11 +222,13 @@ public static LossTuple removeSqueezableDimensions(
// Use dynamic rank.
// TODO: hold for lazy select feature,
- // Operand rankDiff = tf.math.sub(tf.rank(predictions), tf.rank(labels));
+ // Operand rankDiff = tf.math.sub(tf.rank(predictions),
+ // tf.rank(labels));
if (predictionsRank == Shape.UNKNOWN_SIZE && Shape.isCompatible(predictionsShape.size(-1), 1)) {
/*
- * TODO, if we ever get a select that does lazy evaluation, but for now do the tf.squeeze
- * predictions = tf.select( tf.math.equal(tf.constant(expectedRankDiff+1),rankDiff ),
+ * TODO, if we ever get a select that does lazy evaluation, but for now do the
+ * tf.squeeze predictions = tf.select(
+ * tf.math.equal(tf.constant(expectedRankDiff+1),rankDiff ),
* tf.squeeze(predictions, Squeeze.axis(Arrays.asList(-1L))), predictions ); *
*/
predictions = tf.squeeze(predictions, Squeeze.axis(Collections.singletonList(-1L)));
@@ -282,11 +284,12 @@ private static Operand reduceWeightedLoss(
if (reduction == Reduction.NONE) {
loss = weightedLoss;
} else {
- loss =
- tf.reduceSum(weightedLoss, allAxes(tf, weightedLoss), ReduceSum.keepDims(Boolean.FALSE));
if (reduction == Reduction.AUTO || reduction == Reduction.SUM_OVER_BATCH_SIZE) {
- loss = safeMean(tf, loss, weightedLoss.shape().size());
- }
+ loss = safeMean(tf, weightedLoss);
+ } else
+ loss =
+ tf.reduceSum(
+ weightedLoss, allAxes(tf, weightedLoss), ReduceSum.keepDims(Boolean.FALSE));
}
return loss;
}
@@ -301,10 +304,10 @@ private static Operand reduceWeightedLoss(
* @return A scalar representing the mean of losses. If numElements is
* zero, then zero is returned.
*/
- public static Operand safeMean(
- Ops tf, Operand losses, long numElements) {
- Operand totalLoss = tf.reduceSum(losses, allAxes(tf, losses));
- return tf.math.divNoNan(totalLoss, cast(tf, tf.constant(numElements), losses.type()));
+ public static Operand safeMean(Ops tf, Operand losses) {
+ Operand totalLoss =
+ tf.reduceSum(losses, allAxes(tf, losses), ReduceSum.keepDims(Boolean.FALSE));
+ return tf.math.divNoNan(totalLoss, cast(tf, tf.shape.size(tf.shape(losses)), losses.type()));
}
/**
@@ -348,7 +351,8 @@ public static Operand rangeCheck(
tf.math.logicalAnd(
tf.reduceAll(tf.math.greaterEqual(values, minValue), allDims),
tf.reduceAll(tf.math.lessEqual(values, maxValue), allDims));
- // Graph and Eager mode need to be handled differently, control dependencies are not allowed in
+ // Graph and Eager mode need to be handled differently, control dependencies are
+ // not allowed in
// Eager mode
if (tf.scope().env().isGraph()) {
AssertThat assertThat =
@@ -398,7 +402,8 @@ public static Operand valueCheck(
} else return values;
} else { // use dynamic shape
Operand cond = tf.math.equal(tf.shape.size(tf.shape(diff.out())), tf.constant(0));
- // Graph and Eager mode need to be handled differently, control dependencies are not allowed
+ // Graph and Eager mode need to be handled differently, control dependencies are
+ // not allowed
// in Eager mode
if (tf.scope().env().isGraph()) {
AssertThat assertThat =
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java
index c8c1df607c2..c2982e9b0b0 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java
@@ -22,7 +22,7 @@
import org.tensorflow.types.family.TNumber;
/** Interface for metrics */
-interface Metric {
+public interface Metric {
/**
* Creates a List of Operations to update the metric state based on input values.