diff options
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/values.py | 36 |
1 files changed, 24 insertions, 12 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index c18faeb67d..18ceba42c2 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -366,18 +366,7 @@ class MirroredVariable(DistributedVariable, Mirrored, # We are calling assign on the mirrored variable in cross tower context, # use update to update the variable. strategy = distribution_strategy_context.get_distribution_strategy() - updates = strategy.update(self, f, *args, **kwargs) - grouped = strategy.group(updates) - if isinstance(updates, DistributedValues) and updates.is_tensor_like: - # Make sure we run all updates. Without this, something like - # session.run(mirrored_var.assign*(...)) may only update one tower. - index = {} - for d in updates.devices: - with ops.device(d), ops.control_dependencies([grouped]): - index[d] = array_ops.identity(updates.get(d)) - return Mirrored(index) - else: - return grouped + return strategy.update(self, f, *args, **kwargs) else: _assert_tower_context() # We are calling an assign function on the mirrored variable in tower @@ -1049,6 +1038,29 @@ def select_device_mirrored(device, structured): return nest.map_structure(_get_mirrored, structured) +def update_regroup(strategy, updates, should_group): + """Regroup for an update, with dependencies to ensure all updates execute.""" + regrouped = regroup(updates, Mirrored) + if not should_group: + return nest.map_structure(strategy.unwrap, regrouped) + grouped_flat = [] + for u in nest.flatten(regrouped): + if isinstance(u, DistributedValues): + g = strategy.group(u) + if u.is_tensor_like: + # Make sure we run all updates. Without this, something like + # session.run(strategy.update(...)) may only update one tower. + index = {} + for d in u.devices: + with ops.device(d), ops.control_dependencies([g]): + index[d] = array_ops.identity(u.get(d)) + g = Mirrored(index) + else: + g = u + grouped_flat.append(g) + return nest.pack_sequence_as(regrouped, grouped_flat) + + class PerDeviceDataIterator(object): """An iterator (like `tf.data.Iterator`) into a `PerDeviceDataset`.""" |