-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy.py    27
-rw-r--r--  tensorflow/contrib/distribute/python/one_device_strategy.py   7
-rw-r--r--  tensorflow/python/training/distribute.py                      49
3 files changed, 10 insertions, 73 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index d8668b398f..98fea76b3d 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -351,33 +351,6 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
assert isinstance(tower_local_var, values.Mirrored)
return array_ops.identity(tower_local_var.get())
- def _fetch(self, val, destination, fn):
- """Return a copy of `val` or `fn(val)` on `destination`."""
- if isinstance(val, values.TowerLocalVariable):
- val = self.reduce(val.reduce_method, val, destinations=destination)
- with ops.device(destination):
- return fn(self.unwrap(val)[0])
-
- assert isinstance(val, values.Mirrored), (
- "val = %s (type %s)" % (val, val.__class__.__name__))
- if val.on_device(destination):
- with ops.device(destination):
- # Use an identity here to make sure we are returning a tensor
- # instead of e.g. a variable object.
- return array_ops.identity(fn(val.get(destination)))
- device = None
- for d in self._devices:
- if val.on_device(d):
- device = d
- break
- assert device is not None, (
- "Could not find destination %s in list of devices %s." %
- (destination, val.devices))
- with ops.device(device):
- v = fn(val.get(device))
- with ops.device(destination):
- return array_ops.identity(v)
-
def _unwrap(self, val):
if isinstance(val, values.DistributedValues):
# Return in a deterministic order.
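
For reference, below is a minimal standalone sketch of the copy-avoidance logic that the removed `MirroredStrategy._fetch` implemented, written against plain TensorFlow ops. The `fetch_from_mirrored` helper and the `per_device` dict are illustrative stand-ins for the internal `values.Mirrored` wrapper, not part of the library API.

import tensorflow as tf

def fetch_from_mirrored(per_device, destination="/device:CPU:0", fn=lambda x: x):
  """Illustrative stand-in for the removed _fetch: copy a mirrored value."""
  if destination in per_device:
    # Value already has a component on the destination: skip the copy,
    # but still return a tensor (not e.g. a variable object) via identity.
    with tf.device(destination):
      return tf.identity(fn(per_device[destination]))
  # Otherwise apply `fn` on some device that holds the value, then copy.
  source_device, source_val = next(iter(per_device.items()))
  with tf.device(source_device):
    v = fn(source_val)
  with tf.device(destination):
    return tf.identity(v)
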
diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py
index 7f4bab9d93..a580dac96c 100644
--- a/tensorflow/contrib/distribute/python/one_device_strategy.py
+++ b/tensorflow/contrib/distribute/python/one_device_strategy.py
@@ -106,13 +106,6 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
"""Read the aggregate value of a tower-local variable."""
return array_ops.identity(tower_local_var)
- def _fetch(self, val, destination, fn):
- """Return a copy of `val` or `fn(val)` on `destination`."""
- with ops.device(self._device):
- v = fn(val)
- with ops.device(destination):
- return array_ops.identity(v)
-
def _unwrap(self, value):
return [value]
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index caffd042a0..6a326b65bb 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
import threading
-import six
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import ops
@@ -527,9 +526,13 @@ class DistributionStrategy(object):
V(`v`), output will have locality V(`v`) as well.
* `d.update_non_slot(d.non_slot_devices(), fn)`: in cross-tower
context, like `d.update()` except with locality N.
- * `d.fetch(t)`: Copy `t` with any locality to the client's CPU device.
- TODO(josh11b): Deprecate `fetch`, switch to `read_var` for
- reading tower-local variables.
+ * `d.read_var(v)`: Gets the (read-only) value of the variable `v` (on
+ the device determined by the current device scope), aggregating
+ across towers for tower-local variables. Frequently, this will be
+ done automatically when using `v` in an expression or fetching it in
+ a cross-tower context, but this function can be used to force the
+ conversion to happen at a particular point in time (for example, to
+ add the result of the conversion to a graph collection).
The standard pattern for updating variables is to:
@@ -616,13 +619,13 @@ class DistributionStrategy(object):
There will still be one component variable per tower, but there is
no requirement that they stay in sync. Instead, when saving them
- or calling `fetch()/read_var()`, we use the value that
- results when calling `reduce()` on all the towers' variables.
+ or calling `read_var()`, we use the value that results when
+ calling `reduce()` on all the towers' variables.
Note: tower-local implies not trainable. Instead, it is expected
that each tower will directly update (using `assign_add()` or
whatever) its local variable instance but only the aggregated
- value (accessible using `fetch()`) will be exported from the
+ value (accessible using `read_var()`) will be exported from the
model. When it is acceptable to only aggregate on export, we
greatly reduce communication overhead by using tower-local
variables.
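
A minimal sketch of these tower-local semantics in plain TensorFlow, assuming two towers and a sum as the reduce method; the variables here are ordinary `tf.Variable`s for illustration, whereas real code would create tower-local variables through the DistributionStrategy APIs.

import tensorflow as tf

# One independent copy per tower; they are never kept in sync (illustrative).
tower_counts = [tf.Variable(0, name="count_tower_%d" % i) for i in range(2)]

# Each tower directly updates only its own copy, e.g. with assign_add().
updates = [v.assign_add(1) for v in tower_counts]

# The exported / read_var()-style value is the reduce() over all towers
# (a sum here); no cross-tower aggregation happens before this point.
aggregated = tf.add_n([tf.identity(v) for v in tower_counts])
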
@@ -914,32 +917,6 @@ class DistributionStrategy(object):
def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
raise NotImplementedError("must be implemented in descendants")
- def fetch(self, val, destination="/device:CPU:0", fn=lambda x: x):
- """Return a copy of `val` or `fn(val)` on `destination`.
-
- This is useful for getting a mirrored value onto a device. It
- will attempt to avoid a copy by checking if the value is already
- on the destination device.
-
- TODO(josh11b): Switch to `read_var`.
-
- Args:
- val: Value (which may be mirrored) to copy.
- destination: A device string to copy the value to.
- fn: An optional function to apply to the value on the source
- device, before copying.
-
- Returns:
- A `Tensor` on `destination`.
- """
- _require_cross_tower_context(self)
- assert isinstance(destination, six.string_types)
- destination = device_util.resolve(destination)
- return self._fetch(val, destination, fn)
-
- def _fetch(self, val, destination, fn):
- raise NotImplementedError("must be implemented in descendants")
-
def unwrap(self, value):
"""Returns the list of all per-device values contained in `value`.
@@ -1219,12 +1196,6 @@ class _DefaultDistributionStrategy(DistributionStrategy):
def read_var(self, tower_local_var):
return array_ops.identity(tower_local_var)
- def _fetch(self, var, destination, fn):
- with ops.colocate_with(var):
- var = fn(var)
- with ops.device(destination):
- return array_ops.identity(var)
-
def _unwrap(self, distributed_value):
return [distributed_value]
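
In the default single-tower strategy, the removed `_fetch` amounted to "apply `fn` next to the value, then copy the result to the destination". A standalone sketch of that pattern with TF 1.x ops follows; the `copy_to_device` name is made up for illustration.

import tensorflow as tf

def copy_to_device(val, destination="/device:CPU:0", fn=lambda x: x):
  # Run `fn` colocated with `val`, then materialize the result on
  # `destination` as a tensor via identity (mirrors the deleted code).
  with tf.colocate_with(val):
    v = fn(val)
  with tf.device(destination):
    return tf.identity(v)
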