diff options
author | 2018-08-16 15:50:56 -0700 | |
---|---|---|
committer | 2018-08-16 15:54:47 -0700 | |
commit | 51f2c645d97c9ee1d9151bc3c959f73305daac09 (patch) | |
tree | ef20c4d17237f87e8e96020bd13a1c934e1cbeca /tensorflow/contrib/distribute | |
parent | 985cbd7ff220795abc4a50839144c177924d469c (diff) |
Make tf.metrics work with TPU Strategy.
PiperOrigin-RevId: 209064406
Diffstat (limited to 'tensorflow/contrib/distribute')
4 files changed, 31 insertions, 1 deletion
diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py index 2f3d6bdd3f..8163494c8e 100644 --- a/tensorflow/contrib/distribute/python/metrics_v1_test.py +++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py @@ -68,6 +68,8 @@ def _regression_dataset_fn(): "predictions": [1., .75, .25, 0.]}).repeat() +# TODO(priyag): Add TPU Strategy to this once metrics aggregate correctly using +# TowerLocalVariables on TPUs. Submit http://cl/208914352. def all_combinations(): return combinations.combine( distribution=[combinations.default_strategy, diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 72d1c6b7dd..edd5c6d17a 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -383,12 +383,21 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): with ops.control_dependencies([fn_result]): return [i + 1] + flat_last_step_outputs + # We capture the control_flow_context at this point, before we run `fn` + # inside a while_loop. This is useful in cases where we might need to exit + # these contexts and get back to the outer context to do some things, for + # e.g. create an op which should be evaluated only once at the end of the + # loop on the host. One such usage is in creating metrics' value op. 
+ self._outer_control_flow_context = ( + ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access + cond = lambda i, *args: i < iterations i = constant_op.constant(0) loop_result = control_flow_ops.while_loop( cond, body, [i] + initial_loop_values, name="", parallel_iterations=1, back_prop=False, swap_memory=False, return_same_structure=True) + del self._outer_control_flow_context ctx.run_op = control_flow_ops.group(loop_result) diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py index 86833ad851..68561b5bbf 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -88,13 +88,22 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): with ops.control_dependencies([fn_result]): return [i + 1] + flat_last_step_outputs + # We capture the control_flow_context at this point, before we run `fn` + # inside a while_loop. This is useful in cases where we might need to exit + # these contexts and get back to the outer context to do some things, for + # e.g. create an op which should be evaluated only once at the end of the + # loop on the host. One such usage is in creating metrics' value op. + self._outer_control_flow_context = ( + ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access + + # TODO(priyag): Use max_iterations instead of an explicit counter. cond = lambda i, *args: i < iterations i = constant_op.constant(0) - # TODO(priyag): Use max_iterations instead of an explicit counter. 
loop_result = control_flow_ops.while_loop( cond, body, [i] + initial_loop_values, name="", parallel_iterations=1, back_prop=False, swap_memory=False, return_same_structure=True) + del self._outer_control_flow_context ctx.run_op = control_flow_ops.group(loop_result) diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index 50b5555ba5..77fc56de36 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -159,8 +159,18 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): def iterate_on_tpu(): return training_loop.repeat(iterations, run_fn, initial_loop_values) + # We capture the control_flow_context at this point, before we run `fn` + # inside a while_loop and TPU replicate context. This is useful in cases + # where we might need to exit these contexts and get back to the outer + # context to do some things, for e.g. create an op which should be + # evaluated only once at the end of the loop on the host. One such usage + # is in creating metrics' value op. + self._outer_control_flow_context = ( + ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access + replicate_inputs = [[]] * self.num_towers replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs) + del self._outer_control_flow_context ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops) # Filter out any ops from the outputs, typically this would be the case |