aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/values.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r--tensorflow/contrib/distribute/python/values.py36
1 files changed, 24 insertions, 12 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index c18faeb67d..18ceba42c2 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -366,18 +366,7 @@ class MirroredVariable(DistributedVariable, Mirrored,
# We are calling assign on the mirrored variable in cross tower context,
# use update to update the variable.
strategy = distribution_strategy_context.get_distribution_strategy()
- updates = strategy.update(self, f, *args, **kwargs)
- grouped = strategy.group(updates)
- if isinstance(updates, DistributedValues) and updates.is_tensor_like:
- # Make sure we run all updates. Without this, something like
- # session.run(mirrored_var.assign*(...)) may only update one tower.
- index = {}
- for d in updates.devices:
- with ops.device(d), ops.control_dependencies([grouped]):
- index[d] = array_ops.identity(updates.get(d))
- return Mirrored(index)
- else:
- return grouped
+ return strategy.update(self, f, *args, **kwargs)
else:
_assert_tower_context()
# We are calling an assign function on the mirrored variable in tower
@@ -1049,6 +1038,29 @@ def select_device_mirrored(device, structured):
return nest.map_structure(_get_mirrored, structured)
+def update_regroup(strategy, updates, should_group):
+ """Regroup for an update, with dependencies to ensure all updates execute."""
+ regrouped = regroup(updates, Mirrored)
+ if not should_group:
+ return nest.map_structure(strategy.unwrap, regrouped)
+ grouped_flat = []
+ for u in nest.flatten(regrouped):
+ if isinstance(u, DistributedValues):
+ g = strategy.group(u)
+ if u.is_tensor_like:
+ # Make sure we run all updates. Without this, something like
+ # session.run(strategy.update(...)) may only update one tower.
+ index = {}
+ for d in u.devices:
+ with ops.device(d), ops.control_dependencies([g]):
+ index[d] = array_ops.identity(u.get(d))
+ g = Mirrored(index)
+ else:
+ g = u
+ grouped_flat.append(g)
+ return nest.pack_sequence_as(regrouped, grouped_flat)
+
+
class PerDeviceDataIterator(object):
"""An iterator (like `tf.data.Iterator`) into a `PerDeviceDataset`."""