path: root/tensorflow/contrib/distribute/python/mirrored_strategy.py
author    Akshay Agrawal <akshayka@google.com>  2018-08-03 11:12:50 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-08-03 11:17:05 -0700
commit    5220933e536bcfa925232b0cb5fdb338297a218f (patch)
tree      d7749a7dd740c637b5c32e1507463942a6eb9aea /tensorflow/contrib/distribute/python/mirrored_strategy.py
parent    8653ae60ec3f9ac9b6f9913830c38f656b1b6a1f (diff)
Automated rollback of commit 493d7588172bcf476309b3954db342839ca37872
PiperOrigin-RevId: 207294037
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy.py  138
1 file changed, 65 insertions(+), 73 deletions(-)
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 0c26ae8dbc..eb2d102012 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -209,75 +209,6 @@ def _reduce_non_distributed_value(distribution, aggregation, value,
   return values.Mirrored(value_updates)
 
 
-def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs):  # pylint: disable=g-missing-docstring
-  # Figure out what collections this variable should be added to.
-  # We'll add the MirroredVariable to those collections instead.
-  collections = kwargs.pop("collections", None)
-  if collections is None:
-    collections = [ops.GraphKeys.GLOBAL_VARIABLES]
-  kwargs["collections"] = []
-
-  # Get synchronization value
-  synchronization = kwargs.get("synchronization",
-                               variable_scope.VariableSynchronization.ON_WRITE)
-  if synchronization == variable_scope.VariableSynchronization.NONE:
-    raise ValueError("`NONE` variable synchronization mode is not "
-                     "supported with `Mirrored` distribution strategy. Please"
-                     " change the `synchronization` for variable: " +
-                     kwargs["name"])
-  elif synchronization == variable_scope.VariableSynchronization.ON_READ:
-    # Variables that are to be synced on read are tower local.
-    is_tower_local = True
-    kwargs["trainable"] = False
-  elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or
-        synchronization == variable_scope.VariableSynchronization.AUTO):
-    # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`.
-    is_tower_local = False
-  else:
-    raise ValueError("Invalid variable synchronization mode: " +
-                     synchronization + " for variable: " + kwargs["name"])
-
-  # Get aggregation value
-  aggregation = kwargs.pop("aggregation",
-                           variable_scope.VariableAggregation.NONE)
-  if aggregation not in [
-      variable_scope.VariableAggregation.NONE,
-      variable_scope.VariableAggregation.SUM,
-      variable_scope.VariableAggregation.MEAN
-  ]:
-    raise ValueError("Invalid variable aggregation mode: " + aggregation +
-                     " for variable: " + kwargs["name"])
-
-  # Ignore user-specified caching device, not needed for mirrored variables.
-  kwargs.pop("caching_device", None)
-
-  # TODO(josh11b,apassos): It would be better if variable initialization
-  # was never recorded on the tape instead of having to do this manually
-  # here.
-  with tape.stop_recording():
-    index = real_mirrored_creator(devices, *args, **kwargs)
-
-    if is_tower_local:
-      result = values.TowerLocalVariable(index, index[devices[0]], aggregation)
-    else:
-      result = values.MirroredVariable(index, index[devices[0]], aggregation)
-
-  if not context.executing_eagerly():
-    g = ops.get_default_graph()
-    # If "trainable" is True, next_creator() will add the member variables
-    # to the TRAINABLE_VARIABLES collection, so we manually remove
-    # them and replace with the MirroredVariable. We can't set
-    # "trainable" to False for next_creator() since that causes functions
-    # like implicit_gradients to skip those variables.
-    if kwargs.get("trainable", True):
-      collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
-      l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
-      for v in index.values():
-        l.remove(v)
-    g.add_to_collections(collections, result)
-  return result
-
-
 class MirroredStrategy(distribute_lib.DistributionStrategy):
   """Mirrors vars to distribute across multiple devices on a single machine.
@@ -312,10 +243,54 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
 
   def _create_variable(self, next_creator, *args, **kwargs):
     """Create a mirrored variable. See `DistributionStrategy.scope`."""
+    # Figure out what collections this variable should be added to.
+    # We'll add the MirroredVariable to those collections instead.
+    collections = kwargs.pop("collections", None)
+    if collections is None:
+      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
+    kwargs["collections"] = []
+
     colocate_with = kwargs.pop("colocate_with", None)
     devices = self._get_devices_from(colocate_with)
 
-    def _real_mirrored_creator(devices, *args, **kwargs):  # pylint: disable=g-missing-docstring
+    # Get synchronization value
+    synchronization = kwargs.get(
+        "synchronization", variable_scope.VariableSynchronization.ON_WRITE)
+    if synchronization == variable_scope.VariableSynchronization.NONE:
+      raise ValueError("`NONE` variable synchronization mode is not "
+                       "supported with `Mirrored` distribution strategy. Please"
+                       " change the `synchronization` for variable: " +
+                       kwargs["name"])
+    elif synchronization == variable_scope.VariableSynchronization.ON_READ:
+      # Variables that are to be synced on read are tower local.
+      is_tower_local = True
+      kwargs["trainable"] = False
+    elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or
+          synchronization == variable_scope.VariableSynchronization.AUTO):
+      # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`.
+      is_tower_local = False
+    else:
+      raise ValueError("Invalid variable synchronization mode: " +
+                       synchronization + " for variable: " + kwargs["name"])
+
+    # Get aggregation value
+    aggregation = kwargs.pop("aggregation",
+                             variable_scope.VariableAggregation.NONE)
+    if aggregation not in [
+        variable_scope.VariableAggregation.NONE,
+        variable_scope.VariableAggregation.SUM,
+        variable_scope.VariableAggregation.MEAN
+    ]:
+      raise ValueError("Invalid variable aggregation mode: " + aggregation +
+                       " for variable: " + kwargs["name"])
+
+    # Ignore user-specified caching device, not needed for mirrored variables.
+    kwargs.pop("caching_device", None)
+
+    # TODO(josh11b,apassos): It would be better if variable initialization
+    # was never recorded on the tape instead of having to do this manually
+    # here.
+    with tape.stop_recording():
       index = {}
       for i, d in enumerate(devices):
         with ops.device(d):
@@ -339,10 +314,27 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
             v = next_creator(*args, **kwargs)
           assert not isinstance(v, values.DistributedVariable)
           index[d] = v
-      return index
 
-    return _create_mirrored_variable(devices, _real_mirrored_creator, *args,
-                                     **kwargs)
+    if is_tower_local:
+      result = values.TowerLocalVariable(index, index[devices[0]],
+                                         aggregation)
+    else:
+      result = values.MirroredVariable(index, index[devices[0]], aggregation)
+
+    if not context.executing_eagerly():
+      g = ops.get_default_graph()
+      # If "trainable" is True, next_creator() will add the member variables
+      # to the TRAINABLE_VARIABLES collection, so we manually remove
+      # them and replace with the MirroredVariable. We can't set
+      # "trainable" to False for next_creator() since that causes functions
+      # like implicit_gradients to skip those variables.
+      if kwargs.get("trainable", True):
+        collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
+        l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
+        for v in index.values():
+          l.remove(v)
+      g.add_to_collections(collections, result)
+    return result
 
   def distribute_dataset(self, dataset_fn):
     return values.PerDeviceDataset(
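
For context when reading the hunks above: the rollback does not change behavior, it only moves the helper's logic back inline into `_create_variable`. The snippet below is a minimal usage sketch of that code path, not part of the commit; it assumes a contrib-era TF 1.x build from around this commit, where `tf.contrib.distribute.MirroredStrategy` exists and `tf.get_variable` accepts the `synchronization` and `aggregation` arguments that `_create_variable` inspects. The device strings and variable names are illustrative.

import tensorflow as tf

# Two-device mirrored strategy; variable creation inside its scope is
# intercepted by _create_variable above.
strategy = tf.contrib.distribute.MirroredStrategy(devices=["/gpu:0", "/gpu:1"])

with strategy.scope():
  # AUTO/ON_WRITE synchronization (the default) yields a MirroredVariable:
  # one copy per device, kept in sync on writes.
  w = tf.get_variable("w", shape=[10, 10])

  # ON_READ yields a TowerLocalVariable: each tower updates its own copy,
  # and reads aggregate across towers (here by summing). The creator code
  # above also forces trainable=False for such variables.
  counter = tf.get_variable(
      "counter", shape=[], initializer=tf.zeros_initializer(),
      synchronization=tf.VariableSynchronization.ON_READ,
      aggregation=tf.VariableAggregation.SUM)

Requesting synchronization=NONE, or an aggregation outside NONE/SUM/MEAN, raises the ValueError paths shown in the diff.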