path: root/tensorflow/contrib/distribute/python/mirrored_strategy.py
author    Priya Gupta <priyag@google.com>  2018-08-08 16:04:13 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-08-08 16:08:36 -0700
commit    96a77055bd1f0c86e37708f65d5ac72cc6026c66 (patch)
tree      2efd17e56e0831af8912bf7934011c411d29dcdc /tensorflow/contrib/distribute/python/mirrored_strategy.py
parent    643d809b50127682bf5ef70b8871f929183d5a10 (diff)
Add an API to distribution strategy that allows running N steps. Implement this for MirroredStrategy and OneDeviceStrategy. Implemented in TPUStrategy earlier.
PiperOrigin-RevId: 207961939
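The change implements this by building a single `tf.while_loop` whose body pulls one batch from the iterator and runs the step function, so one `session.run` call executes all N steps. A minimal, self-contained sketch of that pattern in isolation (assuming TF 1.x graph mode; the dataset, `step_fn`, and step count below are invented for illustration and are not part of this commit):

```python
import tensorflow as tf

# Illustrative inputs; the real API receives a distributed iterator and a
# user-supplied step function.
dataset = tf.data.Dataset.range(100).batch(4)
iterator = dataset.make_one_shot_iterator()

def step_fn(batch):
  # Stand-in for the real per-step computation; returns a scalar "loss".
  return tf.reduce_sum(tf.cast(batch, tf.float32))

def body(i, _):
  # get_next() inside the loop body dequeues one batch per loop iteration.
  loss = step_fn(iterator.get_next())
  return [i + 1, loss]

loop_vars = [tf.constant(0), tf.constant(0.0)]
_, last_loss = tf.while_loop(
    lambda i, _: i < 10, body, loop_vars,
    parallel_iterations=1, back_prop=False)

with tf.Session() as sess:
  print(sess.run(last_loss))  # one run call executes all 10 steps
```

The `_run_steps_on_dataset` method added below follows the same shape, but additionally threads a `MultiStepContext` through the loop so per-step outputs can be collected and regrouped across the mirrored devices.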
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy.py | 52
1 file changed, 52 insertions, 0 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index c5d6e978e7..3c1760c03c 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -27,13 +27,17 @@ from tensorflow.contrib.distribute.python import values
 from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.eager import context
 from tensorflow.python.eager import tape
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import device as tf_device
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.training import coordinator
 from tensorflow.python.training import device_util
 from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.util import nest
 
 
 # TODO(josh11b): Replace asserts in this file with if ...: raise ...
@@ -357,6 +361,54 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
         self._call_dataset_fn(dataset_fn), self._devices,
         self._prefetch_on_device)
 
+  # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
+  def _run_steps_on_dataset(self, fn, iterator, iterations,
+                            initial_loop_values=None):
+    if initial_loop_values is None:
+      initial_loop_values = {}
+    initial_loop_values = nest.flatten(initial_loop_values)
+
+    ctx = values.MultiStepContext()
+    def body(i, *args):
+      """A wrapper around `fn` to create the while loop body."""
+      del args
+      fn_result = fn(ctx, iterator.get_next())
+      for (name, output) in ctx.last_step_outputs.items():
+        # Convert all outputs to tensors, potentially from `DistributedValues`.
+        ctx.last_step_outputs[name] = self.unwrap(output)
+      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
+      with ops.control_dependencies([fn_result]):
+        return [i + 1] + flat_last_step_outputs
+
+    cond = lambda i, *args: i < iterations
+    i = constant_op.constant(0)
+    loop_result = control_flow_ops.while_loop(
+        cond, body, [i] + initial_loop_values, name="",
+        parallel_iterations=1, back_prop=False, swap_memory=False,
+        return_same_structure=True)
+
+    ctx.run_op = control_flow_ops.group(loop_result)
+
+    # Convert the last_step_outputs from a list to the original dict structure
+    # of last_step_outputs.
+    last_step_tensor_outputs = loop_result[1:]
+    last_step_tensor_outputs_dict = nest.pack_sequence_as(
+        ctx.last_step_outputs, last_step_tensor_outputs)
+
+    for (name, aggregation) in ctx._last_step_outputs_aggregations.items():  # pylint: disable=protected-access
+      output = last_step_tensor_outputs_dict[name]
+      # For outputs that have already been aggregated, wrap them in a Mirrored
+      # container, else in a PerDevice container.
+      if aggregation is variables_lib.VariableAggregation.NONE:
+        last_step_tensor_outputs_dict[name] = values.regroup(
+            {d: t for d, t in zip(self._devices, output)}, values.PerDevice)
+      else:
+        assert len(output) == 1
+        last_step_tensor_outputs_dict[name] = output[0]
+
+    ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
+    return ctx
+
   def _broadcast(self, tensor, destinations):
     # TODO(josh11b): In eager mode, use one thread per device, or async mode.
     return self._get_cross_tower_ops().broadcast(tensor, destinations or
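One detail worth noting in the loop body above: `ctx.last_step_outputs` is a dict, while `tf.while_loop` carries its loop variables as a flat sequence, so the outputs are flattened with `nest.flatten` on the way into the loop and repacked into the original dict structure with `nest.pack_sequence_as` afterwards. A tiny sketch of that round trip (the dict contents are invented for illustration):

```python
from tensorflow.python.util import nest

last_step_outputs = {"loss": 1.0, "accuracy": 0.5}

# Flatten the dict so its values can travel as while_loop variables...
flat = nest.flatten(last_step_outputs)  # [0.5, 1.0], ordered by sorted key

# ...then rebuild the original structure around the (possibly updated) values.
repacked = nest.pack_sequence_as(last_step_outputs, flat)
assert repacked == last_step_outputs
```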