author    Priya Gupta <priyag@google.com>    2018-08-16 15:50:56 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-08-16 15:54:47 -0700
commit    51f2c645d97c9ee1d9151bc3c959f73305daac09 (patch)
tree      ef20c4d17237f87e8e96020bd13a1c934e1cbeca /tensorflow/contrib/distribute
parent    985cbd7ff220795abc4a50839144c177924d469c (diff)
Make tf.metrics work with TPU Strategy.
PiperOrigin-RevId: 209064406
Diffstat (limited to 'tensorflow/contrib/distribute')
 tensorflow/contrib/distribute/python/metrics_v1_test.py     |  2 ++
 tensorflow/contrib/distribute/python/mirrored_strategy.py   |  9 +++++++++
 tensorflow/contrib/distribute/python/one_device_strategy.py | 11 ++++++++++-
 tensorflow/contrib/distribute/python/tpu_strategy.py        | 10 ++++++++++
 4 files changed, 31 insertions(+), 1 deletion(-)
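The change makes the usual tf.metrics contract usable inside each strategy's training loop. As a reminder of that contract, here is a minimal, runnable TF 1.x sketch in plain graph mode with no distribution strategy; the `predictions` values mirror the regression dataset in the test below. Each metric returns a `(value_op, update_op)` pair, and only the update op belongs inside the loop.

```python
import tensorflow as tf

predictions = tf.constant([1., .75, .25, 0.])

# Each tf.metrics op returns (value_op, update_op). The update op
# accumulates state and may run on every step, so it can live inside the
# strategy's while_loop; the value op reads the accumulated state and must
# be created once, in the outer graph context, which is what the captured
# _outer_control_flow_context in the diffs below makes possible.
mean_value, mean_update = tf.metrics.mean(predictions)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())  # metric state is local vars
  sess.run(mean_update)
  print(sess.run(mean_value))  # 0.5
```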
diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py
index 2f3d6bdd3f..8163494c8e 100644
--- a/tensorflow/contrib/distribute/python/metrics_v1_test.py
+++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py
@@ -68,6 +68,8 @@ def _regression_dataset_fn():
       "predictions": [1., .75, .25, 0.]}).repeat()
 
 
+# TODO(priyag): Add TPU Strategy to this once metrics aggregate correctly using
+# TowerLocalVariables on TPUs. Submit http://cl/208914352.
 def all_combinations():
   return combinations.combine(
       distribution=[combinations.default_strategy,
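For reference, a hypothetical sketch of what the TODO describes. The list members after `combinations.default_strategy` and the `mode` argument are assumptions about the surrounding test file, not part of this diff:

```python
# Hypothetical: the test combinations once the TODO above is resolved.
def all_combinations():
  return combinations.combine(
      distribution=[combinations.default_strategy,
                    combinations.one_device_strategy,
                    combinations.mirrored_strategy_with_gpu_and_cpu,
                    combinations.tpu_strategy],  # added per the TODO
      mode=["graph"])
```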
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 72d1c6b7dd..edd5c6d17a 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -383,12 +383,21 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
       with ops.control_dependencies([fn_result]):
         return [i + 1] + flat_last_step_outputs
 
+    # We capture the control_flow_context at this point, before we run `fn`
+    # inside a while_loop. This is useful in cases where we might need to exit
+    # these contexts and get back to the outer context to do some things, e.g.
+    # to create an op that should be evaluated only once at the end of the
+    # loop on the host. One such usage is in creating metrics' value op.
+    self._outer_control_flow_context = (
+        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access
+
     cond = lambda i, *args: i < iterations
     i = constant_op.constant(0)
     loop_result = control_flow_ops.while_loop(
         cond, body, [i] + initial_loop_values, name="",
         parallel_iterations=1, back_prop=False, swap_memory=False,
         return_same_structure=True)
+    del self._outer_control_flow_context
 
     ctx.run_op = control_flow_ops.group(loop_result)
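A consumer sketch, assuming only what the diff above provides: `_outer_control_flow_context` holds whatever context was active before the `while_loop` was built (possibly `None`), and it exists only while `fn` is being traced, since the attribute is deleted right after the loop is built. A hypothetical helper like the one below, which is not part of this change, could use it to build an op such as a metric's value op outside the loop. `Graph._set_control_flow_context` is the same private API the control flow machinery itself uses.

```python
from tensorflow.python.framework import ops


def build_in_outer_context(distribution, value_fn):
  """Hypothetical helper: build `value_fn()`'s op outside the while_loop.

  `distribution` is a strategy that captured `_outer_control_flow_context`
  as in the diff above; `value_fn` builds the op that must live in the
  outer context (e.g. a metric's value op).
  """
  graph = ops.get_default_graph()
  saved = graph._get_control_flow_context()  # pylint: disable=protected-access
  # Swap in the captured context (None means the top level of the graph).
  graph._set_control_flow_context(  # pylint: disable=protected-access
      distribution._outer_control_flow_context)  # pylint: disable=protected-access
  try:
    return value_fn()
  finally:
    # Restore the while_loop body context we were building in.
    graph._set_control_flow_context(saved)  # pylint: disable=protected-access
```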
diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py
index 86833ad851..68561b5bbf 100644
--- a/tensorflow/contrib/distribute/python/one_device_strategy.py
+++ b/tensorflow/contrib/distribute/python/one_device_strategy.py
@@ -88,13 +88,22 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
       with ops.control_dependencies([fn_result]):
         return [i + 1] + flat_last_step_outputs
 
+    # We capture the control_flow_context at this point, before we run `fn`
+    # inside a while_loop. This is useful in cases where we might need to exit
+    # these contexts and get back to the outer context to do some things, e.g.
+    # to create an op that should be evaluated only once at the end of the
+    # loop on the host. One such usage is in creating metrics' value op.
+    self._outer_control_flow_context = (
+        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access
+
+    # TODO(priyag): Use max_iterations instead of an explicit counter.
     cond = lambda i, *args: i < iterations
     i = constant_op.constant(0)
-    # TODO(priyag): Use max_iterations instead of an explicit counter.
     loop_result = control_flow_ops.while_loop(
         cond, body, [i] + initial_loop_values, name="",
         parallel_iterations=1, back_prop=False, swap_memory=False,
         return_same_structure=True)
+    del self._outer_control_flow_context
 
     ctx.run_op = control_flow_ops.group(loop_result)
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index 50b5555ba5..77fc56de36 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -159,8 +159,18 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
     def iterate_on_tpu():
       return training_loop.repeat(iterations, run_fn, initial_loop_values)
 
+    # We capture the control_flow_context at this point, before we run `fn`
+    # inside a while_loop and TPU replicate context. This is useful in cases
+    # where we might need to exit these contexts and get back to the outer
+    # context to do some things, e.g. to create an op that should be
+    # evaluated only once at the end of the loop on the host. One such usage
+    # is in creating metrics' value op.
+    self._outer_control_flow_context = (
+        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access
+
     replicate_inputs = [[]] * self.num_towers
     replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)
+    del self._outer_control_flow_context
 
     ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops)
 
     # Filter out any ops from the outputs, typically this would be the case
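All three strategies repeat the same capture/`del` bracket around their loop construction, and the attribute is only meaningful while the loop body is being built. A minimal sketch of a context-manager refactor, assuming nothing beyond what the diffs show (the helper name is hypothetical, not part of this change):

```python
import contextlib

from tensorflow.python.framework import ops


@contextlib.contextmanager
def _capture_outer_context(strategy):
  """Hypothetical refactor of the capture/`del` bracket in the diffs above."""
  strategy._outer_control_flow_context = (  # pylint: disable=protected-access
      ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access
  try:
    yield
  finally:
    del strategy._outer_control_flow_context  # pylint: disable=protected-access
```

Each strategy could then wrap only the loop-building call, e.g. `with _capture_outer_context(self): replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)`, and the `try/finally` would guarantee cleanup even if loop construction raises.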