From 51f2c645d97c9ee1d9151bc3c959f73305daac09 Mon Sep 17 00:00:00 2001 From: Priya Gupta Date: Thu, 16 Aug 2018 15:50:56 -0700 Subject: Make tf.metrics work with TPU Strategy. PiperOrigin-RevId: 209064406 --- tensorflow/contrib/distribute/python/mirrored_strategy.py | 9 +++++++++ 1 file changed, 9 insertions(+) (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py') diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 72d1c6b7dd..edd5c6d17a 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -383,12 +383,21 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): with ops.control_dependencies([fn_result]): return [i + 1] + flat_last_step_outputs + # We capture the control_flow_context at this point, before we run `fn` + # inside a while_loop. This is useful in cases where we might need to exit + # these contexts and get back to the outer context to do some things, for + # e.g. create an op which should be evaluated only once at the end of the + # loop on the host. One such usage is in creating metrics' value op. + self._outer_control_flow_context = ( + ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access + cond = lambda i, *args: i < iterations i = constant_op.constant(0) loop_result = control_flow_ops.while_loop( cond, body, [i] + initial_loop_values, name="", parallel_iterations=1, back_prop=False, swap_memory=False, return_same_structure=True) + del self._outer_control_flow_context ctx.run_op = control_flow_ops.group(loop_result) -- cgit v1.2.3