-rw-r--r-- | tensorflow/contrib/distribute/python/mirrored_strategy.py   | 27
-rw-r--r-- | tensorflow/contrib/distribute/python/one_device_strategy.py |  7
-rw-r--r-- | tensorflow/python/training/distribute.py                    | 49
3 files changed, 10 insertions, 73 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index d8668b398f..98fea76b3d 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -351,33 +351,6 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
     assert isinstance(tower_local_var, values.Mirrored)
     return array_ops.identity(tower_local_var.get())
 
-  def _fetch(self, val, destination, fn):
-    """Return a copy of `val` or `fn(val)` on `destination`."""
-    if isinstance(val, values.TowerLocalVariable):
-      val = self.reduce(val.reduce_method, val, destinations=destination)
-      with ops.device(destination):
-        return fn(self.unwrap(val)[0])
-
-    assert isinstance(val, values.Mirrored), (
-        "val = %s (type %s)" % (val, val.__class__.__name__))
-    if val.on_device(destination):
-      with ops.device(destination):
-        # Use an identity here to make sure we are returning a tensor
-        # instead of e.g. a variable object.
-        return array_ops.identity(fn(val.get(destination)))
-    device = None
-    for d in self._devices:
-      if val.on_device(d):
-        device = d
-        break
-    assert device is not None, (
-        "Could not find destination %s in list of devices %s." %
-        (destination, val.devices))
-    with ops.device(device):
-      v = fn(val.get(device))
-    with ops.device(destination):
-      return array_ops.identity(v)
-
   def _unwrap(self, val):
     if isinstance(val, values.DistributedValues):
       # Return in a deterministic order.
diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py
index 7f4bab9d93..a580dac96c 100644
--- a/tensorflow/contrib/distribute/python/one_device_strategy.py
+++ b/tensorflow/contrib/distribute/python/one_device_strategy.py
@@ -106,13 +106,6 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
     """Read the aggregate value of a tower-local variable."""
     return array_ops.identity(tower_local_var)
 
-  def _fetch(self, val, destination, fn):
-    """Return a copy of `val` or `fn(val)` on `destination`."""
-    with ops.device(self._device):
-      v = fn(val)
-    with ops.device(destination):
-      return array_ops.identity(v)
-
   def _unwrap(self, value):
     return [value]
 
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index caffd042a0..6a326b65bb 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -19,7 +19,6 @@ from __future__ import division
 from __future__ import print_function
 
 import threading
-import six
 
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import ops
@@ -527,9 +526,13 @@ class DistributionStrategy(object):
     V(`v`), output will have locality V(`v`) as well.
   * `d.update_non_slot(d.non_slot_devices(), fn)`: in cross-tower
     context, like `d.update()` except with locality N.
-  * `d.fetch(t)`: Copy `t` with any locality to the client's CPU device.
-    TODO(josh11b): Deprecate `fetch`, switch to `read_var` for
-    reading tower-local variables.
+  * `d.read_var(v)`: Gets the (read-only) value of the variable `v` (on
+    the device determined by the current device scope), aggregating
+    across towers for tower-local variables. Frequently, this will be
+    done automatically when using `v` in an expression or fetching it in
+    a cross-tower context, but this function can be used to force that
+    conversion happens at a particular point in time (for example, to
+    add the result of the conversion to a graph collection).
 
   The standard pattern for updating variables is to:
 
@@ -616,13 +619,13 @@ class DistributionStrategy(object):
 
     There will still be one component variable per tower, but there
     is no requirement that they stay in sync. Instead, when saving them
-    or calling `fetch()/read_var()`, we use the value that
-    results when calling `reduce()` on all the towers' variables.
+    or calling `read_var()`, we use the value that results when
+    calling `reduce()` on all the towers' variables.
 
     Note: tower-local implies not trainable. Instead, it is expected
     that each tower will directly update (using `assign_add()` or
     whatever) its local variable instance but only the aggregated
-    value (accessible using `fetch()`) will be exported from the
+    value (accessible using `read_var()`) will be exported from the
     model. When it is acceptable to only aggregate on export, we
     greatly reduce communication overhead by using tower-local
     variables.
@@ -914,32 +917,6 @@ class DistributionStrategy(object):
   def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
     raise NotImplementedError("must be implemented in descendants")
 
-  def fetch(self, val, destination="/device:CPU:0", fn=lambda x: x):
-    """Return a copy of `val` or `fn(val)` on `destination`.
-
-    This is useful for getting a mirrored value onto a device. It
-    will attempt to avoid a copy by checking if the value is already
-    on the destination device.
-
-    TODO(josh11b): Switch to `read_var`.
-
-    Args:
-      val: Value (which may be mirrored) to copy.
-      destination: A device string to copy the value to.
-      fn: An optional function to apply to the value on the source
-        device, before copying.
-
-    Returns:
-      A `Tensor` on `destination`.
-    """
-    _require_cross_tower_context(self)
-    assert isinstance(destination, six.string_types)
-    destination = device_util.resolve(destination)
-    return self._fetch(val, destination, fn)
-
-  def _fetch(self, val, destination, fn):
-    raise NotImplementedError("must be implemented in descendants")
-
   def unwrap(self, value):
     """Returns the list of all per-device values contained in `value`.
 
@@ -1219,12 +1196,6 @@ class _DefaultDistributionStrategy(DistributionStrategy):
   def read_var(self, tower_local_var):
     return array_ops.identity(tower_local_var)
 
-  def _fetch(self, var, destination, fn):
-    with ops.colocate_with(var):
-      var = fn(var)
-    with ops.device(destination):
-      return array_ops.identity(var)
-
   def _unwrap(self, distributed_value):
     return [distributed_value]
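In plain TF 1.x terms, the removed `_fetch()` and the `read_var()` that replaces it for variable reads boil down to the sketch below. This is only an illustration of the one-device implementations deleted and kept in this diff; the helper names `fetch_like`/`read_var_like` and the `counter` variable are made up for the example and are not part of the change.

import tensorflow as tf

# The removed OneDeviceStrategy._fetch(), paraphrased: apply `fn` on the
# source device, then copy the result to `destination` via an identity op.
def fetch_like(val, destination="/device:CPU:0", fn=lambda x: x,
               source_device="/device:CPU:0"):
  with tf.device(source_device):
    v = fn(val)
  with tf.device(destination):
    return tf.identity(v)

# Its replacement for reading variables: in the one-device / default case,
# read_var() is just an identity read on the current device scope.
def read_var_like(var):
  return tf.identity(var)

counter = tf.get_variable("counter", initializer=0.0, trainable=False)
bump = tf.assign_add(counter, 1.0)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(bump)
  print(sess.run(fetch_like(counter)))     # 1.0, copied to /device:CPU:0
  print(sess.run(read_var_like(counter)))  # 1.0, read where the variable lives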
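The updated docstring also describes the tower-local pattern: each tower updates its own non-trainable copy, and only the aggregated value, read via `read_var()`, is exported. Below is a rough plain-TF sketch of that idea without the DistributionStrategy machinery; the two explicit per-tower variables and the `tf.add_n` aggregation are stand-ins for what the strategy would manage, not the library's actual implementation.

import tensorflow as tf

# Each "tower" keeps its own non-trainable copy and bumps only that copy,
# so there is no per-step cross-device synchronization.
tower_counts = [
    tf.get_variable("count_tower_%d" % i, initializer=0.0, trainable=False)
    for i in range(2)
]
updates = [tf.assign_add(v, 1.0) for v in tower_counts]

# Export / checkpoint time: only the aggregated value is used (here the
# reduce method is a plain sum across towers).
aggregated = tf.add_n(tower_counts)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(updates)
  print(sess.run(aggregated))  # 2.0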