path: root/tensorflow/contrib/distribute
author     A. Unique TensorFlower <gardener@tensorflow.org>   2018-10-01 16:31:13 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-10-01 16:35:32 -0700
commit     bb1f9e1a57c8bc18325b3c86298be96e6647a0a3 (patch)
tree       66809f6708a49a5b8d408ef0b349fe837d905cda /tensorflow/contrib/distribute
parent     49bbfec04b729960999ef054e3acab719631b101 (diff)
Change semantics of DistributionStrategy.update() to make sure the
output depends on the updates across all mirrors.

Before this change, update() would return a Mirrored value where each
component was an update to a single mirror. This caused a problem: for
reading purposes, other DistributionStrategy methods would consider it
okay to read any single component, so something like
session.run(strategy.update(...)) would perform the update on only one
replica.

The fix is to have the output be a Mirrored value that is actually the
identity operation returning the output on that device, but with a
control dependency making sure that the update actually happens on all
the replicas. This fix was already present in
MirroredVariable._assign_func; this CL moves it into update() and
generalizes it to multiple return values.

To disable this new grouping behavior, you may now pass "grouped=False"
to update(). For example, some callers (like Optimizer) perform a lot of
updates and prefer to group all of them together at once for performance
reasons. In this case, we still want to make sure the caller executes
the update on all replicas, so we return an unwrapped value instead of a
Mirrored value. This has the happy side effect of removing a bunch of
unwrap calls in client code, since unwrapping was the only safe way to
use the Mirrored value we used to return.

PiperOrigin-RevId: 215301909
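For callers, the user-visible change is the new `grouped` keyword argument to update(). Below is a minimal before/after sketch modeled on the test updates in this CL; `d`, `v`, `update`, and `g` are assumed to be a DistributionStrategy, a mirrored variable, an assign-style update function, and an already-reduced gradient, and `ops` is tensorflow.python.framework.ops, as in those tests.

# Before: update() returned a Mirrored value whose components could be read
# individually, so callers had to unwrap() it to depend on every replica.
with ops.control_dependencies(d.unwrap(d.update(v, update, g))):
  after = d.read_var(v)

# After: grouped=False returns the per-replica update ops directly, while the
# default grouped=True returns a Mirrored value whose components carry a
# control dependency on all replica updates, so fetching any single component
# (e.g. session.run(d.update(v, update, g))) runs the update on every replica.
with ops.control_dependencies(d.update(v, update, g, grouped=False)):
  after = d.read_var(v)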
Diffstat (limited to 'tensorflow/contrib/distribute')
-rw-r--r--  tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py   3
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy.py                    12
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py       2
-rw-r--r--  tensorflow/contrib/distribute/python/one_device_strategy.py                  17
-rw-r--r--  tensorflow/contrib/distribute/python/parameter_server_strategy.py            22
-rw-r--r--  tensorflow/contrib/distribute/python/parameter_server_strategy_test.py        3
-rw-r--r--  tensorflow/contrib/distribute/python/strategy_test_lib.py                     6
-rw-r--r--  tensorflow/contrib/distribute/python/tpu_strategy.py                          36
-rw-r--r--  tensorflow/contrib/distribute/python/values.py                                36
9 files changed, 92 insertions, 45 deletions
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
index 33ffbf6abe..6796a23d46 100644
--- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
@@ -128,7 +128,8 @@ class CollectiveAllReduceStrategyTestBase(
# TODO(yuefengz): support non-Mirrored variable as destinations.
g = d.reduce(
variable_scope.VariableAggregation.SUM, g, destinations=v)
- with ops.control_dependencies(d.unwrap(d.update(v, update, g))):
+ with ops.control_dependencies(
+ d.update(v, update, g, grouped=False)):
after_list.append(d.read_var(v))
return before_list, after_list
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 4d7516063c..6bd380a22d 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -627,9 +627,11 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
return self._get_cross_tower_ops().batch_reduce(aggregation,
value_destination_pairs)
- def _update(self, var, fn, *args, **kwargs):
+ def _update(self, var, options, fn, *args, **kwargs):
# TODO(josh11b): In eager mode, use one thread per device.
assert isinstance(var, values.DistributedVariable)
+ should_group = options.pop("grouped")
+ assert not options # Validate that we are processing all of the options.
updates = {}
for d, v in var._index.items(): # pylint: disable=protected-access
name = "update_%d" % self._device_index.get(d)
@@ -638,10 +640,12 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
updates[d] = fn(v,
*values.select_device_mirrored(d, args),
**values.select_device_mirrored(d, kwargs))
- return values.regroup(updates, values.Mirrored)
+ return values.update_regroup(self, updates, should_group)
- def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
+ def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs):
assert isinstance(colocate_with, list)
+ should_group = options.pop("grouped")
+ assert not options # Validate that we are processing all of the options.
# TODO(josh11b): In eager mode, use one thread per device.
updates = {}
for d in colocate_with:
@@ -649,7 +653,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name):
updates[d] = fn(*values.select_device_mirrored(d, args),
**values.select_device_mirrored(d, kwargs))
- return values.regroup(updates, values.Mirrored)
+ return values.update_regroup(self, updates, should_group)
def read_var(self, tower_local_var):
"""Read the aggregate value of a tower-local variable."""
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index f51e543624..eeac528329 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -826,7 +826,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
with dist.scope():
ret_v_sum = dist.call_for_each_tower(model_fn, run_concurrently=False)
- update_ops = dist.unwrap(dist.update(ret_v_sum, update, 5.0))
+ update_ops = dist.update(ret_v_sum, update, 5.0, grouped=False)
# Initialize variables.
self.evaluate(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py
index 23b220f64b..f525919048 100644
--- a/tensorflow/contrib/distribute/python/one_device_strategy.py
+++ b/tensorflow/contrib/distribute/python/one_device_strategy.py
@@ -141,14 +141,21 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
else:
assert False
- def _update(self, var, fn, *args, **kwargs):
- with ops.device(self._device), distribute_lib.UpdateContext(self._device):
- return fn(var, *args, **kwargs)
+ def _update(self, var, options, fn, *args, **kwargs):
+ # The implementations of _update() and _update_non_slot() are identical
+ # except _update() passes `var` as the first argument to `fn()`.
+ return self._update_non_slot(var, options, fn, var, *args, **kwargs)
- def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
+ def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs):
del colocate_with
+ should_group = options.pop("grouped")
+ assert not options # Validate that we are processing all of the options.
with ops.device(self._device), distribute_lib.UpdateContext(self._device):
- return fn(*args, **kwargs)
+ result = fn(*args, **kwargs)
+ if should_group:
+ return result
+ else:
+ return nest.map_structure(self._unwrap, result)
def read_var(self, tower_local_var):
"""Read the aggregate value of a tower-local variable."""
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
index 1125d027f6..6ddd91507b 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
@@ -343,21 +343,33 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
return nest.map_structure(_select_fn, structured)
- def _update(self, var, fn, *args, **kwargs):
+ def _update(self, var, options, fn, *args, **kwargs):
if isinstance(var, values.AggregatingVariable):
var = var.get()
if not isinstance(var, resource_variable_ops.ResourceVariable):
raise ValueError(
"You can not update `var` %r. It must be a Variable." % var)
+ should_group = options.pop("grouped")
+ assert not options # Validate that we are processing all of the options.
with ops.colocate_with(var), distribute_lib.UpdateContext(var.device):
- return fn(var, *self._select_single_value(args),
- **self._select_single_value(kwargs))
+ result = fn(var, *self._select_single_value(args),
+ **self._select_single_value(kwargs))
+ if should_group:
+ return result
+ else:
+ return nest.map_structure(self._unwrap, result)
# TODO(yuefengz): does it need to call _select_single_value?
- def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
+ def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs):
+ should_group = options.pop("grouped")
+ assert not options # Validate that we are processing all of the options.
with ops.device(
colocate_with.device), distribute_lib.UpdateContext(colocate_with):
- return fn(*args, **kwargs)
+ result = fn(*args, **kwargs)
+ if should_group:
+ return result
+ else:
+ return nest.map_structure(self._unwrap, result)
def _unwrap(self, val):
if isinstance(val, values.DistributedValues):
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
index 12789e0bc9..353d11a583 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
@@ -395,7 +395,8 @@ class ParameterServerStrategyTestBase(
# TODO(yuefengz): support non-Mirrored variable as destinations.
g = d.reduce(
variable_scope.VariableAggregation.SUM, g, destinations=v)
- with ops.control_dependencies(d.unwrap(d.update(v, update, g))):
+ with ops.control_dependencies(
+ d.update(v, update, g, grouped=False)):
after_list.append(d.read_var(v))
return before_list, after_list
diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py
index 5d498fb629..fd280f5754 100644
--- a/tensorflow/contrib/distribute/python/strategy_test_lib.py
+++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py
@@ -115,7 +115,8 @@ class DistributionTestBase(test.TestCase):
with ops.control_dependencies([fetched]):
g = d.reduce(
variable_scope.VariableAggregation.SUM, g, destinations=v)
- with ops.control_dependencies(d.unwrap(d.update(v, update, g))):
+ with ops.control_dependencies(d.update(
+ v, update, g, grouped=False)):
after_list.append(d.read_var(v))
return before_list, after_list
@@ -169,7 +170,8 @@ class DistributionTestBase(test.TestCase):
with ops.control_dependencies([fetched]):
g = d.reduce(
variable_scope.VariableAggregation.SUM, g, destinations=v)
- with ops.control_dependencies(d.unwrap(d.update(v, update, g))):
+ with ops.control_dependencies(d.update(
+ v, update, g, grouped=False)):
after_list.append(d.read_var(v))
return before_list, after_list
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index 1b555482d3..c3c7df3cd8 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -297,6 +297,7 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
# For outputs that have already been aggregated, take the first value
# from the list as each value should be the same. Else return the full
# list of values.
+ # TODO(josh11b): If aggregation is NONE, we should return a PerDevice value.
if aggregation is not variables_lib.VariableAggregation.NONE:
# TODO(priyag): Should this return the element or a list with 1 element
last_step_tensor_outputs_dict[name] = output[0]
@@ -398,11 +399,16 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
return output * (1. / len(value))
return output
- def _update(self, var, fn, *args, **kwargs):
- # TODO(jhseu): Consider supporting grouped==False.
+ def _update(self, var, options, fn, *args, **kwargs):
assert isinstance(var, values.TPUMirroredVariable)
+ should_group = options.pop("grouped")
+ assert not options # Validate that we are processing all of the options.
+
if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access
- return fn(var, *args, **kwargs)
+ if should_group:
+ return fn(var, *args, **kwargs)
+ else:
+ return [fn(var, *args, **kwargs)]
# Otherwise, we revert to MirroredStrategy behavior and update each variable
# directly.
@@ -414,23 +420,25 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
updates[d] = fn(v,
*values.select_device_mirrored(d, args),
**values.select_device_mirrored(d, kwargs))
+ return values.update_regroup(self, updates, should_group)
- # Make a single control dependency to keep the variables mirrored. If one
- # assignment is fetched, then run all assignments.
- sorted_keys = sorted(updates.keys())
- update_tuple = control_flow_ops.tuple([updates[d] for d in sorted_keys])
- for i, d in enumerate(sorted_keys):
- updates[d] = update_tuple[i]
- return values.regroup(updates, values.Mirrored)
+ # TODO(josh11b): Need to implement _update_non_slot()!
def read_var(self, var):
assert isinstance(var, values.TPUMirroredVariable)
return var.read_value()
- def _unwrap(self, value):
- if isinstance(value, list):
- return value
- return [value]
+ def _unwrap(self, val):
+ if isinstance(val, values.DistributedValues):
+ # Return in a deterministic order.
+ return [val.get(device=d) for d in sorted(val.devices)]
+ elif isinstance(val, list):
+ # TODO(josh11b): We need to remove this case; per device values should
+ # be represented using a PerDevice wrapper instead of a list with
+ # one entry per device.
+ return val
+ return [val]
+
@property
def num_towers(self):
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index c18faeb67d..18ceba42c2 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -366,18 +366,7 @@ class MirroredVariable(DistributedVariable, Mirrored,
# We are calling assign on the mirrored variable in cross tower context,
# use update to update the variable.
strategy = distribution_strategy_context.get_distribution_strategy()
- updates = strategy.update(self, f, *args, **kwargs)
- grouped = strategy.group(updates)
- if isinstance(updates, DistributedValues) and updates.is_tensor_like:
- # Make sure we run all updates. Without this, something like
- # session.run(mirrored_var.assign*(...)) may only update one tower.
- index = {}
- for d in updates.devices:
- with ops.device(d), ops.control_dependencies([grouped]):
- index[d] = array_ops.identity(updates.get(d))
- return Mirrored(index)
- else:
- return grouped
+ return strategy.update(self, f, *args, **kwargs)
else:
_assert_tower_context()
# We are calling an assign function on the mirrored variable in tower
@@ -1049,6 +1038,29 @@ def select_device_mirrored(device, structured):
return nest.map_structure(_get_mirrored, structured)
+def update_regroup(strategy, updates, should_group):
+ """Regroup for an update, with dependencies to ensure all updates execute."""
+ regrouped = regroup(updates, Mirrored)
+ if not should_group:
+ return nest.map_structure(strategy.unwrap, regrouped)
+ grouped_flat = []
+ for u in nest.flatten(regrouped):
+ if isinstance(u, DistributedValues):
+ g = strategy.group(u)
+ if u.is_tensor_like:
+ # Make sure we run all updates. Without this, something like
+ # session.run(strategy.update(...)) may only update one tower.
+ index = {}
+ for d in u.devices:
+ with ops.device(d), ops.control_dependencies([g]):
+ index[d] = array_ops.identity(u.get(d))
+ g = Mirrored(index)
+ else:
+ g = u
+ grouped_flat.append(g)
+ return nest.pack_sequence_as(regrouped, grouped_flat)
+
+
class PerDeviceDataIterator(object):
"""An iterator (like `tf.data.Iterator`) into a `PerDeviceDataset`."""