aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-01-26 14:32:45 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-26 14:36:45 -0800
commit887509ffe25387efd4c869ffbe46b47ba6049860 (patch)
tree4b9113505b440431688884ec9aded0c48ec6c3e9
parent84d7f94efd1489b939a4672b0f47d6aa66d9eb91 (diff)
tfe.metrics.{Mean,Accuracy} return their inputs.
This makes chaining them easier. Control dependencies to ensure updates happen are implicitly added by the function code. PiperOrigin-RevId: 183446211
-rw-r--r--tensorflow/contrib/eager/python/metrics_impl.py12
-rw-r--r--tensorflow/contrib/eager/python/metrics_test.py13
2 files changed, 25 insertions, 0 deletions
diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py
index bf029ca5f9..ea8dbf2b46 100644
--- a/tensorflow/contrib/eager/python/metrics_impl.py
+++ b/tensorflow/contrib/eager/python/metrics_impl.py
@@ -291,6 +291,9 @@ class Mean(Metric):
Args:
values: Tensor with the per-example value.
weights: Optional weighting of each example. Defaults to 1.
+
+ Returns:
+ The arguments, for easy chaining.
"""
if weights is None:
self.denom.assign_add(
@@ -302,6 +305,9 @@ class Mean(Metric):
self.denom.assign_add(math_ops.reduce_sum(weights))
values = math_ops.cast(values, self.dtype) * weights
self.numer.assign_add(math_ops.reduce_sum(values))
+ if weights is None:
+ return values
+ return values, weights
def result(self):
t = self.numer / self.denom
@@ -329,7 +335,13 @@ class Accuracy(Mean):
per element of the Tensor.
predictions: Tensor with the predicted label for each example.
weights: Optional weighting of each example. Defaults to 1.
+
+ Returns:
+ The arguments, for easy chaining.
"""
matches = math_ops.equal(labels, predictions)
matches = math_ops.cast(matches, dtypes.float64)
super(Accuracy, self).call(matches, weights=weights)
+ if weights is None:
+ return labels, predictions
+ return labels, predictions, weights
diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py
index 9cf34fd9b2..a9ecaa3f8b 100644
--- a/tensorflow/contrib/eager/python/metrics_test.py
+++ b/tensorflow/contrib/eager/python/metrics_test.py
@@ -180,6 +180,19 @@ class MetricsTest(test.TestCase):
m2 = metrics.Mean()
m2(2)
+ def testMetricsChain(self):
+ with context.graph_mode(), self.test_session():
+ m1 = metrics.Mean()
+ m2 = metrics.Mean(name="m2")
+ update_m2 = m2(3.0)
+ update_m2_2 = m2(m1(1.0))
+ m1.init_variables().run()
+ m2.init_variables().run()
+ update_m2.eval()
+ update_m2_2.eval()
+ self.assertAllEqual(m2.result().eval(), 2.0)
+ self.assertAllEqual(m1.result().eval(), 1.0)
+
if __name__ == "__main__":
test.main()