aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/mirrored_strategy.py
diff options
context:
space:
mode:
authorGravatar Priya Gupta <priyag@google.com>2018-08-16 15:50:56 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-16 15:54:47 -0700
commit51f2c645d97c9ee1d9151bc3c959f73305daac09 (patch)
treeef20c4d17237f87e8e96020bd13a1c934e1cbeca /tensorflow/contrib/distribute/python/mirrored_strategy.py
parent985cbd7ff220795abc4a50839144c177924d469c (diff)
Make tf.metrics work with TPU Strategy.
PiperOrigin-RevId: 209064406
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py9
1 files changed, 9 insertions, 0 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 72d1c6b7dd..edd5c6d17a 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -383,12 +383,21 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
with ops.control_dependencies([fn_result]):
return [i + 1] + flat_last_step_outputs
+ # We capture the control_flow_context at this point, before we run `fn`
+ # inside a while_loop. This is useful in cases where we might need to exit
+ # these contexts and get back to the outer context to do some things, for
+ # e.g. create an op which should be evaluated only once at the end of the
+ # loop on the host. One such usage is in creating metrics' value op.
+ self._outer_control_flow_context = (
+ ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access
+
cond = lambda i, *args: i < iterations
i = constant_op.constant(0)
loop_result = control_flow_ops.while_loop(
cond, body, [i] + initial_loop_values, name="",
parallel_iterations=1, back_prop=False, swap_memory=False,
return_same_structure=True)
+ del self._outer_control_flow_context
ctx.run_op = control_flow_ops.group(loop_result)