diff options
author | 2018-01-26 14:32:45 -0800 | |
---|---|---|
committer | 2018-01-26 14:36:45 -0800 | |
commit | 887509ffe25387efd4c869ffbe46b47ba6049860 (patch) | |
tree | 4b9113505b440431688884ec9aded0c48ec6c3e9 | |
parent | 84d7f94efd1489b939a4672b0f47d6aa66d9eb91 (diff) |
tfe.metrics.{Mean,Accuracy} return their inputs.
This makes chaining them easier. Control dependencies to ensure updates
happen are implicitly added by the function code.
PiperOrigin-RevId: 183446211
-rw-r--r-- | tensorflow/contrib/eager/python/metrics_impl.py | 12 | ||||
-rw-r--r-- | tensorflow/contrib/eager/python/metrics_test.py | 13 |
2 files changed, 25 insertions, 0 deletions
diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py index bf029ca5f9..ea8dbf2b46 100644 --- a/tensorflow/contrib/eager/python/metrics_impl.py +++ b/tensorflow/contrib/eager/python/metrics_impl.py @@ -291,6 +291,9 @@ class Mean(Metric): Args: values: Tensor with the per-example value. weights: Optional weighting of each example. Defaults to 1. + + Returns: + The arguments, for easy chaining. """ if weights is None: self.denom.assign_add( @@ -302,6 +305,9 @@ class Mean(Metric): self.denom.assign_add(math_ops.reduce_sum(weights)) values = math_ops.cast(values, self.dtype) * weights self.numer.assign_add(math_ops.reduce_sum(values)) + if weights is None: + return values + return values, weights def result(self): t = self.numer / self.denom @@ -329,7 +335,13 @@ class Accuracy(Mean): per element of the Tensor. predictions: Tensor with the predicted label for each example. weights: Optional weighting of each example. Defaults to 1. + + Returns: + The arguments, for easy chaining. """ matches = math_ops.equal(labels, predictions) matches = math_ops.cast(matches, dtypes.float64) super(Accuracy, self).call(matches, weights=weights) + if weights is None: + return labels, predictions + return labels, predictions, weights diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index 9cf34fd9b2..a9ecaa3f8b 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -180,6 +180,19 @@ class MetricsTest(test.TestCase): m2 = metrics.Mean() m2(2) + def testMetricsChain(self): + with context.graph_mode(), self.test_session(): + m1 = metrics.Mean() + m2 = metrics.Mean(name="m2") + update_m2 = m2(3.0) + update_m2_2 = m2(m1(1.0)) + m1.init_variables().run() + m2.init_variables().run() + update_m2.eval() + update_m2_2.eval() + self.assertAllEqual(m2.result().eval(), 2.0) + self.assertAllEqual(m1.result().eval(), 1.0) + if __name__ == "__main__": test.main() |