diff options
author | 2018-08-16 15:50:56 -0700 | |
---|---|---|
committer | 2018-08-16 15:54:47 -0700 | |
commit | 51f2c645d97c9ee1d9151bc3c959f73305daac09 (patch) | |
tree | ef20c4d17237f87e8e96020bd13a1c934e1cbeca /tensorflow/contrib/distribute/python/mirrored_strategy.py | |
parent | 985cbd7ff220795abc4a50839144c177924d469c (diff) |
Make tf.metrics work with TPU Strategy.
PiperOrigin-RevId: 209064406
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/mirrored_strategy.py | 9 |
1 files changed, 9 insertions, 0 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 72d1c6b7dd..edd5c6d17a 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -383,12 +383,21 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): with ops.control_dependencies([fn_result]): return [i + 1] + flat_last_step_outputs + # We capture the control_flow_context at this point, before we run `fn` + # inside a while_loop. This is useful in cases where we might need to exit + # these contexts and get back to the outer context to do some things, for + # e.g. create an op which should be evaluated only once at the end of the + # loop on the host. One such usage is in creating metrics' value op. + self._outer_control_flow_context = ( + ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access + cond = lambda i, *args: i < iterations i = constant_op.constant(0) loop_result = control_flow_ops.while_loop( cond, body, [i] + initial_loop_values, name="", parallel_iterations=1, back_prop=False, swap_memory=False, return_same_structure=True) + del self._outer_control_flow_context ctx.run_op = control_flow_ops.group(loop_result) |