diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-04 10:26:32 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-04 10:30:43 -0700 |
commit | 53bb944808e4ead0946ebbbc95e932d5dae4f349 (patch) | |
tree | aea44dba3500cee59b1bcea9deef0fe8d0654381 /tensorflow/contrib/eager | |
parent | 102e0de242eccb2ac4664761183a7771b0a7c7af (diff) |
This CL adds extra tests for `contrib.eager.metrics` that check eager metrics combined with while loops.
PiperOrigin-RevId: 211479604
Diffstat (limited to 'tensorflow/contrib/eager')
-rw-r--r-- | tensorflow/contrib/eager/python/metrics_test.py | 45 |
1 files changed, 45 insertions, 0 deletions
diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index aa99616810..dcc7b71d79 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -25,11 +25,14 @@ from tensorflow.contrib.eager.python import metrics from tensorflow.contrib.summary import summary_test_util from tensorflow.python.eager import context from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import summary_ops_v2 as summary_ops from tensorflow.python.training import training_util from tensorflow.python.training.checkpointable import util as checkpointable_utils @@ -244,6 +247,48 @@ class MetricsTest(test.TestCase): value = m.value() self.assertEqual(self.evaluate(value), 2.5) + @test_util.run_in_graph_and_eager_modes + def testGraphAndEagerTensorGlobalVariables(self): + m = metrics.Mean(use_global_variables=True) + inputs = ops.convert_to_tensor([1.0, 2.0]) + accumulate = m(inputs) + result = m.result() + self.evaluate(m.init_variables()) + self.evaluate(accumulate) + self.assertEqual(self.evaluate(result), 1.5) + # Second init resets all the variables. + self.evaluate(m.init_variables()) + inputs = ops.convert_to_tensor([2.0, 3.0]) + self.evaluate(m(inputs)) + value = m.value() + self.assertEqual(self.evaluate(value), 2.5) + + @test_util.run_in_graph_and_eager_modes + def testGraphAndEagerTensorWhileLoopDoubleCall(self): + m = metrics.Mean() + init_value = constant_op.constant(1) + cond = lambda i: math_ops.less(i, 3) + def body(x): + with ops.control_dependencies([m(x)]): + return math_ops.add(x, 1) + accumulate = control_flow_ops.while_loop(cond, body, [init_value]) + + result = m.result() + self.evaluate(m.init_variables()) + self.evaluate(accumulate) + self.assertEqual(self.evaluate(result), 1.5) + # Second init resets all the variables. + self.evaluate(m.init_variables()) + inputs = ops.convert_to_tensor([2.0, 3.0]) + self.evaluate(m(inputs)) + if ops.context.executing_eagerly(): + self.evaluate(control_flow_ops.while_loop(cond, body, [init_value])) + else: + # Reuse the loop operators in graph mode + self.evaluate(accumulate) + value = m.value() + self.assertEqual(self.evaluate(value), 2.0) + def testTwoMeansGraph(self): # Verify two metrics with the same name in the same graph raises a # ValueError. |