aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/eager
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-04 10:26:32 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-04 10:30:43 -0700
commit53bb944808e4ead0946ebbbc95e932d5dae4f349 (patch)
treeaea44dba3500cee59b1bcea9deef0fe8d0654381 /tensorflow/contrib/eager
parent102e0de242eccb2ac4664761183a7771b0a7c7af (diff)
This CL adds extra tests for `contrib.eager.metrics` that check eager metrics combined with while loops.
PiperOrigin-RevId: 211479604
Diffstat (limited to 'tensorflow/contrib/eager')
-rw-r--r--tensorflow/contrib/eager/python/metrics_test.py45
1 files changed, 45 insertions, 0 deletions
diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py
index aa99616810..dcc7b71d79 100644
--- a/tensorflow/contrib/eager/python/metrics_test.py
+++ b/tensorflow/contrib/eager/python/metrics_test.py
@@ -25,11 +25,14 @@ from tensorflow.contrib.eager.python import metrics
from tensorflow.contrib.summary import summary_test_util
from tensorflow.python.eager import context
from tensorflow.python.eager import test
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import summary_ops_v2 as summary_ops
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpointable import util as checkpointable_utils
@@ -244,6 +247,48 @@ class MetricsTest(test.TestCase):
value = m.value()
self.assertEqual(self.evaluate(value), 2.5)
+ @test_util.run_in_graph_and_eager_modes
+ def testGraphAndEagerTensorGlobalVariables(self):
+ m = metrics.Mean(use_global_variables=True)
+ inputs = ops.convert_to_tensor([1.0, 2.0])
+ accumulate = m(inputs)
+ result = m.result()
+ self.evaluate(m.init_variables())
+ self.evaluate(accumulate)
+ self.assertEqual(self.evaluate(result), 1.5)
+ # Second init resets all the variables.
+ self.evaluate(m.init_variables())
+ inputs = ops.convert_to_tensor([2.0, 3.0])
+ self.evaluate(m(inputs))
+ value = m.value()
+ self.assertEqual(self.evaluate(value), 2.5)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testGraphAndEagerTensorWhileLoopDoubleCall(self):
+ m = metrics.Mean()
+ init_value = constant_op.constant(1)
+ cond = lambda i: math_ops.less(i, 3)
+ def body(x):
+ with ops.control_dependencies([m(x)]):
+ return math_ops.add(x, 1)
+ accumulate = control_flow_ops.while_loop(cond, body, [init_value])
+
+ result = m.result()
+ self.evaluate(m.init_variables())
+ self.evaluate(accumulate)
+ self.assertEqual(self.evaluate(result), 1.5)
+ # Second init resets all the variables.
+ self.evaluate(m.init_variables())
+ inputs = ops.convert_to_tensor([2.0, 3.0])
+ self.evaluate(m(inputs))
+ if ops.context.executing_eagerly():
+ self.evaluate(control_flow_ops.while_loop(cond, body, [init_value]))
+ else:
+ # Reuse the loop operators in graph mode
+ self.evaluate(accumulate)
+ value = m.value()
+ self.assertEqual(self.evaluate(value), 2.0)
+
def testTwoMeansGraph(self):
# Verify two metrics with the same name in the same graph raises a
# ValueError.