author     Pavithra Vijay <psv@google.com>  2018-06-29 18:02:18 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-06-29 18:05:25 -0700
commit     c290930ec1beacbcac414b43b3367dd44ffbd303 (patch)
tree       b1136e7c32718a6f1f9ebfde3073c88546078de6
parent     a520735d205ca5678fc8a371ea1add00413907fe (diff)
Add `synchronization` and `aggregation` args to get_variable(). These args will be used for distributed variables.
Add Enum `VariableSynchronization` with values for `synchronization`: AUTO, NONE, ON_WRITE, ON_READ.
Add Enum `VariableAggregation` with values for `aggregation`: NONE, SUM, MEAN.
Replace all the aggregation method strings in the distribution strategies with the enum values.
Update MirroredStrategy to use these parameters to decide whether a variable should be Mirrored or TowerLocal.
Update the different distribution strategy value types to use the `VariableAggregation` enum.

PiperOrigin-RevId: 202736077
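For illustration, a minimal sketch of how the new arguments appear at a `tf.get_variable` call site, based on the signature and the exported enums added in this change (the variable names, shapes and initializers here are hypothetical):

import tensorflow as tf  # assumes a build containing this change (1.9-era API)

# A tower-local counter: synced on read, summed across devices.
total = tf.get_variable(
    "total", shape=[],
    initializer=tf.zeros_initializer(),
    trainable=False,
    synchronization=tf.VariableSynchronization.ON_READ,
    aggregation=tf.VariableAggregation.SUM)

# The defaults remain synchronization=AUTO and aggregation=NONE, i.e. the
# previous behavior of get_variable() is unchanged.
weights = tf.get_variable("weights", shape=[10, 10])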
-rw-r--r--  tensorflow/contrib/distribute/python/cross_tower_ops.py                       71
-rw-r--r--  tensorflow/contrib/distribute/python/cross_tower_ops_test.py                  82
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy.py                     47
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py      180
-rw-r--r--  tensorflow/contrib/distribute/python/one_device_strategy.py                   12
-rw-r--r--  tensorflow/contrib/distribute/python/strategy_test_lib.py                      9
-rw-r--r--  tensorflow/contrib/distribute/python/tpu_strategy.py                           5
-rw-r--r--  tensorflow/contrib/distribute/python/values.py                                 35
-rw-r--r--  tensorflow/contrib/distribute/python/values_test.py                            39
-rw-r--r--  tensorflow/contrib/optimizer_v2/optimizer_v2.py                                 9
-rw-r--r--  tensorflow/python/eager/graph_callable.py                                      47
-rw-r--r--  tensorflow/python/keras/layers/normalization.py                                 4
-rw-r--r--  tensorflow/python/kernel_tests/variable_scope_test.py                          42
-rw-r--r--  tensorflow/python/ops/metrics_impl.py                                           3
-rw-r--r--  tensorflow/python/ops/variable_scope.py                                       292
-rw-r--r--  tensorflow/python/training/distribute.py                                       73
-rw-r--r--  tensorflow/python/training/distribute_test.py                                  39
-rw-r--r--  tensorflow/python/training/optimizer.py                                         9
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.-variable-aggregation.pbtxt             16
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.-variable-scope.pbtxt                    2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.-variable-synchronization.pbtxt         20
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.pbtxt                                   12
22 files changed, 770 insertions, 278 deletions
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py
index 0261ce43fa..06555c6760 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py
@@ -28,6 +28,7 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import device_util
@@ -88,7 +89,7 @@ def _simple_broadcast(value, destinations):
def _simple_reduce(per_device_value, reduce_to_device, accumulation_fn,
- method_string):
+ aggregation):
# pylint: disable=g-missing-docstring
all_values = []
count = 0
@@ -112,11 +113,12 @@ def _simple_reduce(per_device_value, reduce_to_device, accumulation_fn,
with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
reduced = cross_tower_utils.aggregate_tensors_or_indexed_slices(
all_values, accumulation_fn)
- if method_string == "mean":
+ if aggregation == vs.VariableAggregation.MEAN:
reduced = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(
reduced, count)
- elif method_string != "sum":
- raise ValueError("`method_string` must be 'sum' or 'mean'")
+ elif aggregation != vs.VariableAggregation.SUM:
+ raise ValueError("`aggregation` must be `sum`(VariableAggregation.SUM) "
+ "or `mean`(VariableAggregation.MEAN).")
return reduced
@@ -126,14 +128,15 @@ class CrossTowerOps(object):
def __init__(self):
pass
- def reduce(self, method_string, per_device_value, destinations=None):
+ def reduce(self, aggregation, per_device_value, destinations=None):
"""Reduce `per_device_value` to `destinations`.
- It runs the reduction operation defined by `method_string` and put the
+ It runs the reduction operation defined by `aggregation` and puts the
result on `destinations`.
Args:
- method_string: either 'sum' or 'mean' specifying the reduction method.
+ aggregation: Indicates how a variable will be aggregated. Accepted values
+ are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}.
per_device_value: a PerDevice object.
destinations: the reduction destinations.
@@ -147,16 +150,17 @@ class CrossTowerOps(object):
raise ValueError("`per_device_value` must be a `PerDevice` object.")
if destinations is not None:
validate_destinations(destinations)
- return self._reduce(method_string, per_device_value, destinations)
+ return self._reduce(aggregation, per_device_value, destinations)
- def batch_reduce(self, method_string, value_destination_pairs):
+ def batch_reduce(self, aggregation, value_destination_pairs):
"""Reduce PerDevice objects in a batch.
Reduce each first element in `value_destination_pairs` to each second
element which indicates the destinations.
Args:
- method_string: either 'sum' or 'mean' specifying the reduction method.
+ aggregation: Indicates how a variable will be aggregated. Accepted values
+ are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}.
value_destination_pairs: a list or a tuple of tuples of PerDevice objects
and destinations. If a destination is None, then the destinations
are set to match the devices of the input PerDevice object.
@@ -175,7 +179,7 @@ class CrossTowerOps(object):
if d is not None:
validate_destinations(d)
- return self._batch_reduce(method_string, value_destination_pairs)
+ return self._batch_reduce(aggregation, value_destination_pairs)
def broadcast(self, tensor, destinations):
"""Broadcast the `tensor` to destinations.
@@ -190,11 +194,11 @@ class CrossTowerOps(object):
validate_destinations(destinations)
return self._broadcast(tensor, destinations)
- def _reduce(self, method_string, per_device_value, destinations):
+ def _reduce(self, aggregation, per_device_value, destinations):
raise NotImplementedError(
"_reduce method must be implemented in descendants.")
- def _batch_reduce(self, method_string, value_destination_pairs):
+ def _batch_reduce(self, aggregation, value_destination_pairs):
raise NotImplementedError(
"_batch_reduce method must be implemented in descendants.")
@@ -220,16 +224,18 @@ class ReductionToOneDeviceCrossTowerOps(CrossTowerOps):
self.accumulation_fn = accumulation_fn
super(ReductionToOneDeviceCrossTowerOps, self).__init__()
- def _reduce(self, method_string, per_device_value, destinations):
+ def _reduce(self, aggregation, per_device_value, destinations):
devices = get_devices_from(destinations or per_device_value)
reduce_to_device = self.reduce_to_device or devices[0]
reduced = _simple_reduce(per_device_value, reduce_to_device,
- self.accumulation_fn, method_string)
+ self.accumulation_fn, aggregation)
return self.broadcast(reduced, devices)
- def _batch_reduce(self, method_string, value_destination_pairs):
- return [self._reduce(method_string, t, destinations=v)
- for t, v in value_destination_pairs]
+ def _batch_reduce(self, aggregation, value_destination_pairs):
+ return [
+ self._reduce(aggregation, t, destinations=v)
+ for t, v in value_destination_pairs
+ ]
def _group_value_by_device(per_device_values):
@@ -260,18 +266,19 @@ def _group_value_by_device(per_device_values):
return grouped
-def _ungroup_and_make_mirrored(grouped_reduced, destinations, method_string):
+def _ungroup_and_make_mirrored(grouped_reduced, destinations, aggregation):
"""Ungroup results from all-reduce and make Mirrored objects.
Each all-reduce result will be divided by the number of destinations before
- Mirrored objects are created if method_string is "mean".
+ Mirrored objects are created if aggregation is "mean".
Args:
grouped_reduced: a list of lists, each sublist has components for each
device, paired with a None. It is the result from
cross_tower_utils.aggregate_gradients_using*.
destinations: a list of device strings for returned Mirrored objects.
- method_string: "mean" or "sum".
+ aggregation: Indicates how a variable will be aggregated. Accepted values
+ are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}.
Returns:
a list of Mirrored objects.
@@ -279,7 +286,7 @@ def _ungroup_and_make_mirrored(grouped_reduced, destinations, method_string):
index = [{} for _ in range(len(grouped_reduced[0]))]
for d, per_device_reduced in enumerate(grouped_reduced):
for i, (v, _) in enumerate(per_device_reduced):
- if method_string == "mean":
+ if aggregation == vs.VariableAggregation.MEAN:
index[i][destinations[d]] = v / len(destinations)
else:
index[i][destinations[d]] = v
@@ -488,13 +495,13 @@ class AllReduceCrossTowerOps(CrossTowerOps):
self._agg_small_grads_max_group = agg_small_grads_max_group
super(AllReduceCrossTowerOps, self).__init__()
- def _reduce(self, method_string, per_device_value, destinations):
+ def _reduce(self, aggregation, per_device_value, destinations):
contains_indexed_slices = cross_tower_utils.contains_indexed_slices(
per_device_value)
if ((destinations is None or _devices_match(per_device_value, destinations))
and not context.executing_eagerly()
and not contains_indexed_slices):
- return self._batch_all_reduce(method_string, [per_device_value])[0]
+ return self._batch_all_reduce(aggregation, [per_device_value])[0]
else:
if contains_indexed_slices:
logging.log_first_n(
@@ -504,16 +511,16 @@ class AllReduceCrossTowerOps(CrossTowerOps):
devices = get_devices_from(destinations or per_device_value)
reduce_to_device = devices[0]
reduced = _simple_reduce(per_device_value, reduce_to_device,
- math_ops.add_n, method_string)
+ math_ops.add_n, aggregation)
return self.broadcast(reduced, devices)
- def _batch_reduce(self, method_string, value_destination_pairs):
+ def _batch_reduce(self, aggregation, value_destination_pairs):
all_devices_match = _all_devices_match(value_destination_pairs)
contains_indexed_slices = cross_tower_utils.contains_indexed_slices(
value_destination_pairs)
if (all_devices_match and not context.executing_eagerly()
and not contains_indexed_slices):
- return self._batch_all_reduce(method_string,
+ return self._batch_all_reduce(aggregation,
[v[0] for v in value_destination_pairs])
else:
if not all_devices_match:
@@ -521,11 +528,11 @@ class AllReduceCrossTowerOps(CrossTowerOps):
"destinations are different.")
return [
- self._reduce(method_string, t, destinations=v)
+ self._reduce(aggregation, t, destinations=v)
for t, v in value_destination_pairs
]
- def _batch_all_reduce(self, method_string, per_device_values):
+ def _batch_all_reduce(self, aggregation, per_device_values):
"""All reduce algorithm in a batch."""
logging.info(
"batch_all_reduce invoked for batches size = %d with "
@@ -556,7 +563,7 @@ class AllReduceCrossTowerOps(CrossTowerOps):
reduced = _unpack_tensors(reduced, tensor_packer)
return _ungroup_and_make_mirrored(reduced, per_device_values[0].devices,
- method_string)
+ aggregation)
AllReduceSpecTuple = collections.namedtuple("AllReduceSpecTuple",
@@ -635,7 +642,7 @@ class MultiWorkerAllReduce(AllReduceCrossTowerOps):
validate_and_complete_spec(spec) for spec in all_reduce_spec
]
- def _batch_all_reduce(self, method_string, per_device_values):
+ def _batch_all_reduce(self, aggregation, per_device_values):
"""All reduce algorithm in a batch."""
logging.info(
"distributed batch_all_reduce invoked for batches size = %d with "
@@ -682,7 +689,7 @@ class MultiWorkerAllReduce(AllReduceCrossTowerOps):
assert not remaining_grads
return _ungroup_and_make_mirrored(aggregated_grads, destinations,
- method_string)
+ aggregation)
_dgx1_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7],
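The enum values now drive the dispatch that the `method_string` strings used to; a standalone sketch of the `_simple_reduce` logic above (plain Python, not the TensorFlow implementation):

import enum

class VariableAggregation(enum.Enum):  # mirrors the enum added in variable_scope.py
  NONE = 0
  SUM = 1
  MEAN = 2

def simple_reduce(values, aggregation):
  # Sum the per-device values, then divide by the count only for MEAN;
  # anything other than SUM or MEAN is rejected, as in the code above.
  total = sum(values)
  if aggregation == VariableAggregation.MEAN:
    return total / len(values)
  elif aggregation != VariableAggregation.SUM:
    raise ValueError("`aggregation` must be VariableAggregation.SUM "
                     "or VariableAggregation.MEAN.")
  return total

print(simple_reduce([1.0, 3.0], VariableAggregation.MEAN))  # 2.0
print(simple_reduce([1.0, 3.0], VariableAggregation.SUM))   # 4.0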
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
index c540ea0d23..6a780ff60f 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
@@ -32,6 +32,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.training import device_util
@@ -129,32 +130,45 @@ class CrossTowerOpsTestBase(test.TestCase, parameterized.TestCase):
# test reduce()
for destinations in all_destinations:
self._assert_values_equal(
- cross_tower_ops.reduce("mean", per_device, destinations=destinations),
+ cross_tower_ops.reduce(
+ vs.VariableAggregation.MEAN,
+ per_device,
+ destinations=destinations),
_fake_mirrored(mean, destinations or per_device))
self._assert_values_equal(
cross_tower_ops.reduce(
- "mean", per_device_2, destinations=destinations),
+ vs.VariableAggregation.MEAN,
+ per_device_2,
+ destinations=destinations),
_fake_mirrored(mean_2, destinations or per_device))
self._assert_values_equal(
- cross_tower_ops.reduce("sum", per_device, destinations=destinations),
+ cross_tower_ops.reduce(
+ vs.VariableAggregation.SUM, per_device,
+ destinations=destinations),
_fake_mirrored(mean * len(devices), destinations or per_device))
self._assert_values_equal(
cross_tower_ops.reduce(
- "sum", per_device_2, destinations=destinations),
+ vs.VariableAggregation.SUM,
+ per_device_2,
+ destinations=destinations),
_fake_mirrored(mean_2 * len(devices), destinations or per_device))
# test batch_reduce()
for d1, d2 in itertools.product(all_destinations, all_destinations):
self._assert_values_equal(
- cross_tower_ops.batch_reduce(
- "mean", [(per_device, d1), (per_device_2, d2)]),
- [_fake_mirrored(mean, d1 or per_device),
- _fake_mirrored(mean_2, d2 or per_device_2)])
+ cross_tower_ops.batch_reduce(vs.VariableAggregation.MEAN,
+ [(per_device, d1), (per_device_2, d2)]),
+ [
+ _fake_mirrored(mean, d1 or per_device),
+ _fake_mirrored(mean_2, d2 or per_device_2)
+ ])
self._assert_values_equal(
- cross_tower_ops.batch_reduce(
- "sum", [(per_device, d1), (per_device_2, d2)]),
- [_fake_mirrored(mean * len(devices), d1 or per_device),
- _fake_mirrored(mean_2 * len(devices), d2 or per_device_2)])
+ cross_tower_ops.batch_reduce(vs.VariableAggregation.SUM,
+ [(per_device, d1), (per_device_2, d2)]),
+ [
+ _fake_mirrored(mean * len(devices), d1 or per_device),
+ _fake_mirrored(mean_2 * len(devices), d2 or per_device_2)
+ ])
# test broadcast()
for destinations in all_destinations:
@@ -255,8 +269,8 @@ class SingleWorkerCrossTowerOpsTest(CrossTowerOpsTestBase):
t0 = _make_indexed_slices([[1., 2.]], [1], [5, 2], devices[0])
t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], [5, 2], devices[1])
per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1})
- result = cross_tower_ops_lib._simple_reduce(per_device, devices[0],
- math_ops.add_n, "sum")
+ result = cross_tower_ops_lib._simple_reduce(
+ per_device, devices[0], math_ops.add_n, vs.VariableAggregation.SUM)
# Test that the result is semantically equal to both the concatenated
# IndexedSlices with and without duplicate indices.
@@ -267,21 +281,22 @@ class SingleWorkerCrossTowerOpsTest(CrossTowerOpsTestBase):
self._assert_indexed_slices_equal(total_with_dups, result)
self._assert_indexed_slices_equal(total_without_dups, result)
- @combinations.generate(combinations.combine(
- cross_tower_ops_instance=[
- combinations.NamedObject(
- "ReductionToOneDeviceCrossTowerOps",
- cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()),
- combinations.NamedObject(
- "AllReduceCrossTowerOps",
- cross_tower_ops_lib.AllReduceCrossTowerOps())
- ],
- method_string=["sum", "mean"],
- batch_reduce=[True, False],
- mode=["graph", "eager"],
- required_gpus=1))
- def testIndexedSlicesAllReduce(self, cross_tower_ops_instance,
- method_string, batch_reduce):
+ @combinations.generate(
+ combinations.combine(
+ cross_tower_ops_instance=[
+ combinations.NamedObject(
+ "ReductionToOneDeviceCrossTowerOps",
+ cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()),
+ combinations.NamedObject(
+ "AllReduceCrossTowerOps",
+ cross_tower_ops_lib.AllReduceCrossTowerOps())
+ ],
+ aggregation=[vs.VariableAggregation.SUM, vs.VariableAggregation.MEAN],
+ batch_reduce=[True, False],
+ mode=["graph", "eager"],
+ required_gpus=1))
+ def testIndexedSlicesAllReduce(self, cross_tower_ops_instance, aggregation,
+ batch_reduce):
devices = ["/cpu:0", "/gpu:0"]
dense_shape = [5, 2]
t0 = _make_indexed_slices([[1., 2.]], [1], dense_shape, devices[0])
@@ -290,20 +305,19 @@ class SingleWorkerCrossTowerOpsTest(CrossTowerOpsTestBase):
per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1})
if batch_reduce:
- result = cross_tower_ops_instance.batch_reduce(method_string,
+ result = cross_tower_ops_instance.batch_reduce(aggregation,
[(per_device, devices)])
else:
- result = cross_tower_ops_instance.reduce(method_string, per_device,
- devices)
+ result = cross_tower_ops_instance.reduce(aggregation, per_device, devices)
total_indices_with_dups = [1, 1, 3]
total_indices_without_dups = [1, 3]
- if method_string == "sum":
+ if aggregation == vs.VariableAggregation.SUM:
total_values_with_dups = [[1., 2.], [3., 4.], [5., 6.]]
total_values_without_dups = [[4., 6.], [5., 6.]]
else:
- assert method_string == "mean"
+ assert aggregation == vs.VariableAggregation.MEAN
total_values_with_dups = [[0.5, 1.], [1.5, 2.], [2.5, 3.]]
total_values_without_dups = [[2., 3.], [2.5, 3.]]
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index d269bed1e5..14c02ab1ad 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -104,9 +104,32 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
colocate_with = kwargs.pop("colocate_with", None)
devices = self._get_devices_from(colocate_with)
- tower_local = kwargs.pop("tower_local_reduce_method", None)
- if tower_local is not None:
+ # Get synchronization value
+ synchronization = kwargs.get(
+ "synchronization", variable_scope.VariableSynchronization.ON_WRITE)
+ if synchronization == variable_scope.VariableSynchronization.NONE:
+ raise ValueError("`NONE` variable synchronization mode is not "
+ "supported with `Mirrored` distribution strategy. Please"
+ " change the `synchronization` for variable: " +
+ kwargs["name"])
+ elif synchronization == variable_scope.VariableSynchronization.ON_READ:
+ # Variables that are to be synced on read are tower local.
+ is_tower_local = True
kwargs["trainable"] = False
+ elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or
+ synchronization == variable_scope.VariableSynchronization.AUTO):
+ # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`.
+ is_tower_local = False
+ else:
+ raise ValueError("Invalid variable synchronization mode: " +
+ synchronization + " for variable: " + kwargs["name"])
+
+ # Get aggregation value
+ aggregation = kwargs.pop("aggregation",
+ variable_scope.VariableAggregation.NONE)
+ if aggregation not in [a for a in variable_scope.VariableAggregation]:
+ raise ValueError("Invalid variable aggregation mode: " + aggregation +
+ " for variable: " + kwargs["name"])
# Ignore user-specified caching device, not needed for mirrored variables.
kwargs.pop("caching_device", None)
@@ -139,11 +162,11 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
assert not isinstance(v, values.DistributedVariable)
index[d] = v
- if tower_local is None:
- result = values.MirroredVariable(index, index[devices[0]])
+ if is_tower_local:
+ result = values.TowerLocalVariable(index, index[devices[0]],
+ aggregation)
else:
- result = values.TowerLocalVariable(
- index, index[devices[0]], tower_local)
+ result = values.MirroredVariable(index, index[devices[0]], aggregation)
if not context.executing_eagerly():
g = ops.get_default_graph()
@@ -308,12 +331,12 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps())
return self._cross_tower_ops
- def _reduce(self, method_string, value, destinations):
+ def _reduce(self, aggregation, value, destinations):
assert not isinstance(value, values.Mirrored)
if not isinstance(value, values.PerDevice):
if value == 0:
return 0
- if method_string == "mean":
+ if aggregation == variable_scope.VariableAggregation.MEAN:
return self._broadcast(value, destinations)
cross_tower_ops_lib.validate_destinations(destinations)
@@ -331,13 +354,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
value_updates[d] = array_ops.identity(value)
return values.Mirrored(value_updates)
raise ValueError("A non PerDevice value cannot be reduced with the given "
- "method_string.")
+ "aggregation.")
return self._get_cross_tower_ops().reduce(
- method_string, value, destinations=destinations)
+ aggregation, value, destinations=destinations)
- def _batch_reduce(self, method_string, value_destination_pairs):
- return self._get_cross_tower_ops().batch_reduce(method_string,
+ def _batch_reduce(self, aggregation, value_destination_pairs):
+ return self._get_cross_tower_ops().batch_reduce(aggregation,
value_destination_pairs)
def _update(self, var, fn, *args, **kwargs):
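The synchronization handling added to `MirroredStrategy._create_variable` reduces to a small decision about which container to build; a standalone sketch of that branching (plain Python; `is_tower_local` is a hypothetical helper name):

import enum

class VariableSynchronization(enum.Enum):  # mirrors the enum added in variable_scope.py
  AUTO = 0
  NONE = 1
  ON_WRITE = 2
  ON_READ = 3

def is_tower_local(synchronization, name):
  # NONE is rejected outright; ON_READ yields a TowerLocalVariable (and the
  # variable is made non-trainable); ON_WRITE and AUTO yield a MirroredVariable.
  if synchronization == VariableSynchronization.NONE:
    raise ValueError("`NONE` variable synchronization mode is not supported "
                     "with `Mirrored` distribution strategy. Please change "
                     "the `synchronization` for variable: " + name)
  return synchronization == VariableSynchronization.ON_READ

print(is_tower_local(VariableSynchronization.ON_READ, "v"))  # True
print(is_tower_local(VariableSynchronization.AUTO, "v"))     # False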
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index 8d474124b7..c02817f461 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -114,7 +114,10 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase):
dist = self._get_distribution_strategy()
with dist.scope():
result = dist.call_for_each_tower(run_fn, dist.worker_device_index)
- reduced = dist.reduce("sum", result, destinations="/device:CPU:0")
+ reduced = dist.reduce(
+ variable_scope.VariableAggregation.SUM,
+ result,
+ destinations="/device:CPU:0")
unwrapped = dist.unwrap(reduced)
self.assertEqual(1, len(unwrapped))
expected = sum(range(len(dist.worker_devices)))
@@ -132,8 +135,10 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase):
dist = mirrored_strategy.MirroredStrategy(devices)
with dist.scope():
- reduced = dist.reduce("sum", 1.0, destinations=["/device:CPU:0",
- "/device:GPU:0"])
+ reduced = dist.reduce(
+ variable_scope.VariableAggregation.SUM,
+ 1.0,
+ destinations=["/device:CPU:0", "/device:GPU:0"])
unwrapped = dist.unwrap(reduced)
self.assertEqual(2, len(unwrapped))
self.assertEqual(1.0, self.evaluate(unwrapped[0]))
@@ -284,18 +289,68 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
self.assertEquals("common/dense" + suffix + "/bias:0", bias.name)
@test_util.run_in_graph_and_eager_modes(config=config)
+ def testWithVariableAndVariableScope(self):
+ self._skip_eager_if_gpus_less_than(1)
+
+ def model_fn():
+ v0 = variable_scope.variable(1.0, name="var0", aggregation=None)
+ with variable_scope.variable_scope("common"):
+ v1 = variable_scope.variable(1.0, name="var1")
+ # This will pause the current thread, and execute the other thread.
+ distribute_lib.get_tower_context().merge_call(lambda _: _)
+ v2 = variable_scope.variable(
+ 1.0,
+ name="var2",
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ aggregation=variable_scope.VariableAggregation.SUM)
+ v3 = variable_scope.variable(
+ 1.0,
+ name="var3",
+ synchronization=variable_scope.VariableSynchronization.ON_WRITE,
+ aggregation=variable_scope.VariableAggregation.MEAN)
+
+ return v0, v1, v2, v3
+
+ devices = ["/device:CPU:0", "/device:GPU:0"]
+ dist = mirrored_strategy.MirroredStrategy(devices)
+ with dist.scope():
+ v = variable_scope.variable(1.0, name="var-main0")
+ self.assertEquals("var-main0:0", v.name)
+
+ result = dist.call_for_each_tower(model_fn, run_concurrently=False)
+ self.assertEquals(4, len(result))
+ v0, v1, v2, v3 = result
+ self.assertIsInstance(v0, values.MirroredVariable)
+ self.assertEquals("var0:0", v0.name)
+ self.assertIsInstance(v1, values.MirroredVariable)
+ self.assertEquals("common/var1:0", v1.name)
+ self.assertIsInstance(v2, values.TowerLocalVariable)
+ self.assertEquals("common/var2:0", v2.name)
+ self.assertEquals(variable_scope.VariableAggregation.SUM, v2.aggregation)
+ self.assertIsInstance(v3, values.MirroredVariable)
+ self.assertEquals("common/var3:0", v3.name)
+ self.assertEquals(variable_scope.VariableAggregation.MEAN, v3.aggregation)
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
def testWithGetVariableAndVariableScope(self):
self._skip_eager_if_gpus_less_than(1)
def model_fn():
- v0 = variable_scope.get_variable("var-thread0", [1])
+ v0 = variable_scope.get_variable("var0", [1])
with variable_scope.variable_scope("common"):
- v1 = variable_scope.get_variable("var-thread1", [1])
+ v1 = variable_scope.get_variable("var1", [1])
# This will pause the current thread, and execute the other thread.
distribute_lib.get_tower_context().merge_call(lambda _: _)
- v2 = variable_scope.get_variable("var-thread2", [1])
+ v2 = variable_scope.get_variable(
+ "var2", [1],
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ aggregation=variable_scope.VariableAggregation.SUM)
+ v3 = variable_scope.get_variable(
+ "var3", [1],
+ synchronization=variable_scope.VariableSynchronization.ON_WRITE,
+ aggregation=variable_scope.VariableAggregation.MEAN)
- return v0, v1, v2
+ return v0, v1, v2, v3
devices = ["/device:CPU:0", "/device:GPU:0"]
dist = mirrored_strategy.MirroredStrategy(devices)
@@ -305,14 +360,78 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
self.assertEquals("main/var-main0:0", v.name)
result = dist.call_for_each_tower(model_fn, run_concurrently=False)
- self.assertEquals(3, len(result))
- v0, v1, v2 = result
+ self.assertEquals(4, len(result))
+ v0, v1, v2, v3 = result
self.assertIsInstance(v0, values.MirroredVariable)
- self.assertEquals("main/var-thread0:0", v0.name)
+ self.assertEquals("main/var0:0", v0.name)
self.assertIsInstance(v1, values.MirroredVariable)
- self.assertEquals("main/common/var-thread1:0", v1.name)
- self.assertIsInstance(v2, values.MirroredVariable)
- self.assertEquals("main/common/var-thread2:0", v2.name)
+ self.assertEquals("main/common/var1:0", v1.name)
+ self.assertIsInstance(v2, values.TowerLocalVariable)
+ self.assertEquals("main/common/var2:0", v2.name)
+ self.assertEquals(variable_scope.VariableAggregation.SUM,
+ v2.aggregation)
+ self.assertIsInstance(v3, values.MirroredVariable)
+ self.assertEquals("main/common/var3:0", v3.name)
+ self.assertEquals(variable_scope.VariableAggregation.MEAN,
+ v3.aggregation)
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testInvalidSynchronizationWithGetVariable(self):
+ self._skip_eager_if_gpus_less_than(1)
+ devices = ["/device:CPU:0", "/device:GPU:0"]
+ dist = mirrored_strategy.MirroredStrategy(devices)
+ with dist.scope():
+ with self.assertRaisesRegexp(
+ ValueError, "`NONE` variable synchronization mode is not "
+ "supported with `Mirrored` distribution strategy. Please change "
+ "the `synchronization` for variable: v"):
+ variable_scope.get_variable(
+ "v", [1],
+ synchronization=variable_scope.VariableSynchronization.NONE)
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testInvalidSynchronizationWithVariable(self):
+ self._skip_eager_if_gpus_less_than(1)
+ devices = ["/device:CPU:0", "/device:GPU:0"]
+ dist = mirrored_strategy.MirroredStrategy(devices)
+ with dist.scope():
+ with self.assertRaisesRegexp(
+ ValueError, "`NONE` variable synchronization mode is not "
+ "supported with `Mirrored` distribution strategy. Please change "
+ "the `synchronization` for variable: v"):
+ variable_scope.variable(
+ 1.0,
+ name="v",
+ synchronization=variable_scope.VariableSynchronization.NONE)
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testInvalidAggregationWithGetVariable(self):
+ self._skip_eager_if_gpus_less_than(1)
+ devices = ["/device:CPU:0", "/device:GPU:0"]
+ dist = mirrored_strategy.MirroredStrategy(devices)
+ with dist.scope():
+ with self.assertRaisesRegexp(
+ ValueError, "Invalid variable aggregation mode: invalid for "
+ "variable: v"):
+ variable_scope.get_variable(
+ "v", [1],
+ synchronization=variable_scope.VariableSynchronization.ON_WRITE,
+ aggregation="invalid")
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testInvalidAggregationWithVariable(self):
+ self._skip_eager_if_gpus_less_than(1)
+ devices = ["/device:CPU:0", "/device:GPU:0"]
+ dist = mirrored_strategy.MirroredStrategy(devices)
+ with dist.scope():
+ with self.assertRaisesRegexp(
+ ValueError, "Invalid variable aggregation mode: invalid for "
+ "variable: v"):
+ variable_scope.variable(
+ 1.0,
+ name="v",
+ synchronization=variable_scope.VariableSynchronization.ON_WRITE,
+ aggregation="invalid")
@test_util.run_in_graph_and_eager_modes(config=config)
def testThreeDevices(self):
@@ -362,9 +481,11 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
def model_fn(device_id):
tower_context = distribute_lib.get_tower_context()
- with tower_context.tower_local_var_scope("sum"):
+ with tower_context.tower_local_var_scope(
+ variable_scope.VariableAggregation.SUM):
v_sum = variable_scope.variable(1.0)
- with tower_context.tower_local_var_scope("mean"):
+ with tower_context.tower_local_var_scope(
+ variable_scope.VariableAggregation.MEAN):
v_mean = variable_scope.variable(4.0)
self.assertTrue(isinstance(v_sum, values.TowerLocalVariable))
self.assertTrue(isinstance(v_mean, values.TowerLocalVariable))
@@ -569,7 +690,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
def model_fn():
tower_context = distribute_lib.get_tower_context()
- with tower_context.tower_local_var_scope("sum"):
+ with tower_context.tower_local_var_scope(
+ variable_scope.VariableAggregation.SUM):
v_sum = variable_scope.variable(1.0)
self.assertTrue(isinstance(v_sum, values.TowerLocalVariable))
return v_sum
@@ -642,7 +764,8 @@ class MirroredVariableUpdateTest(test.TestCase):
# aggregation type.
self._skip_eager_if_gpus_less_than(1)
def var_fn():
- v = variable_scope.variable(1.0, name="foo")
+ v = variable_scope.variable(
+ 1.0, name="foo", aggregation=variable_scope.VariableAggregation.SUM)
return v
dist = mirrored_strategy.MirroredStrategy(
@@ -650,9 +773,6 @@ class MirroredVariableUpdateTest(test.TestCase):
with dist.scope():
mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
- # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the
- # aggregation method.
- mirrored_var._aggregation_method = "sum"
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.evaluate(variables.global_variables_initializer())
@@ -661,7 +781,7 @@ class MirroredVariableUpdateTest(test.TestCase):
with self.assertRaisesRegexp(
ValueError, "A non PerDevice value cannot be reduced with the given "
- "method_string."):
+ "aggregation."):
self.evaluate(dist.unwrap(dist.call_for_each_tower(model_fn)))
@test_util.run_in_graph_and_eager_modes(config=config)
@@ -685,16 +805,14 @@ class MirroredVariableUpdateTest(test.TestCase):
def testAssignMirroredVarTowerContext(self):
self._skip_eager_if_gpus_less_than(1)
def var_fn():
- return variable_scope.variable(1.0, name="foo")
+ return variable_scope.variable(
+ 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)
dist = mirrored_strategy.MirroredStrategy(
["/device:GPU:0", "/device:CPU:0"])
with dist.scope():
mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
- # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the
- # aggregation method.
- mirrored_var._aggregation_method = "mean"
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.evaluate(variables.global_variables_initializer())
self.assertEquals(1.0, self.evaluate(mirrored_var))
@@ -729,16 +847,14 @@ class MirroredVariableUpdateTest(test.TestCase):
def testAssignAddMirroredVarTowerContext(self):
self._skip_eager_if_gpus_less_than(1)
def var_fn():
- return variable_scope.variable(1.0, name="foo")
+ return variable_scope.variable(
+ 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)
dist = mirrored_strategy.MirroredStrategy(
["/device:GPU:0", "/device:CPU:0"])
with dist.scope():
mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
- # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the
- # aggregation method.
- mirrored_var._aggregation_method = "mean"
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.evaluate(variables.global_variables_initializer())
self.assertEquals(1.0, self.evaluate(mirrored_var))
@@ -773,16 +889,14 @@ class MirroredVariableUpdateTest(test.TestCase):
def testAssignSubMirroredVarTowerContext(self):
self._skip_eager_if_gpus_less_than(1)
def var_fn():
- return variable_scope.variable(5.0, name="foo")
+ return variable_scope.variable(
+ 5.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)
dist = mirrored_strategy.MirroredStrategy(
["/device:GPU:0", "/device:CPU:0"])
with dist.scope():
mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
- # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the
- # aggregation method.
- mirrored_var._aggregation_method = "mean"
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.evaluate(variables.global_variables_initializer())
self.assertEquals(5.0, self.evaluate(mirrored_var))
diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py
index a580dac96c..dbd3514aec 100644
--- a/tensorflow/contrib/distribute/python/one_device_strategy.py
+++ b/tensorflow/contrib/distribute/python/one_device_strategy.py
@@ -24,6 +24,7 @@ from tensorflow.contrib.distribute.python import values
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.training import distribute as distribute_lib
@@ -43,11 +44,6 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
self._default_device = device
def _create_variable(self, next_creator, *args, **kwargs):
- # No need to distinguish tower-local variables when not mirroring,
- # we just enforce that they are not trainable.
- if kwargs.pop("tower_local_reduce_method", None) is not None:
- kwargs["trainable"] = False
-
colocate_with = kwargs.pop("colocate_with", None)
if colocate_with is None:
with ops.device(self._device):
@@ -80,15 +76,15 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
with ops.device(self._device):
return values.MapOutput([fn(m, *args, **kwargs) for m in map_over])
- def _reduce(self, method_string, value, destinations):
+ def _reduce(self, aggregation, value, destinations):
if not isinstance(value, values.MapOutput):
return value
l = value.get()
assert l
with ops.device(self._device):
- if method_string == "sum":
+ if aggregation == vs.VariableAggregation.SUM:
return math_ops.add_n(l)
- elif method_string == "mean":
+ elif aggregation == vs.VariableAggregation.MEAN:
return math_ops.add_n(l) / len(l)
else:
assert False
diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py
index d2fe8b3b1e..baed0ebaae 100644
--- a/tensorflow/contrib/distribute/python/strategy_test_lib.py
+++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py
@@ -26,6 +26,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.layers import core
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import optimizer
@@ -110,7 +111,8 @@ class DistributionTestBase(test.TestCase):
before_list.append(fetched)
# control_dependencies irrelevant but harmless in eager execution
with ops.control_dependencies([fetched]):
- g = d.reduce("sum", g, destinations=v)
+ g = d.reduce(
+ variable_scope.VariableAggregation.SUM, g, destinations=v)
with ops.control_dependencies(d.unwrap(d.update(v, update, g))):
after_list.append(d.read_var(v))
return before_list, after_list
@@ -162,7 +164,8 @@ class DistributionTestBase(test.TestCase):
fetched = d.read_var(v)
before_list.append(fetched)
with ops.control_dependencies([fetched]):
- g = d.reduce("sum", g, destinations=v)
+ g = d.reduce(
+ variable_scope.VariableAggregation.SUM, g, destinations=v)
with ops.control_dependencies(d.unwrap(d.update(v, update, g))):
after_list.append(d.read_var(v))
return before_list, after_list
@@ -184,7 +187,7 @@ class DistributionTestBase(test.TestCase):
with d.scope():
map_in = [constant_op.constant(i) for i in range(10)]
map_out = d.map(map_in, lambda x, y: x * y, 2)
- observed = d.reduce("sum", map_out)
+ observed = d.reduce(variable_scope.VariableAggregation.SUM, map_out)
expected = 90 # 2 * (0 + 1 + ... + 9)
self.assertEqual(expected, observed.numpy())
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index 1ae12ae98a..bc53898539 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -29,6 +29,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import nest
@@ -137,9 +138,9 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
def get_finalize_ops(self):
return [tpu.shutdown_system()]
- def _reduce(self, method_string, value, destinations):
+ def _reduce(self, aggregation, value, destinations):
del destinations # TPU is graph mode only. Rely on implicit Send/Recv.
- if method_string == 'mean':
+ if aggregation == vs.VariableAggregation.MEAN:
# TODO(jhseu): Revisit once we support model-parallelism.
value *= (1. / self._num_cores_per_host)
return tpu_ops.cross_replica_sum(value)
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 95390041f4..b36ac563d2 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -34,6 +34,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.training import device_util
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import saver
@@ -290,13 +291,13 @@ class MirroredVariable(DistributedVariable, Mirrored,
checkpointable.CheckpointableBase):
"""Holds a map from device to variables whose values are kept in sync."""
- def __init__(self, index, primary_var, aggregation_method=None):
+ def __init__(self, index, primary_var, aggregation):
# Use a weakref to make it easy to map from the contained values
# to the container without introducing a reference cycle.
for v in six.itervalues(index):
v._mirrored_container = weakref.ref(self) # pylint: disable=protected-access
self._primary_var = primary_var
- self._aggregation_method = aggregation_method
+ self._aggregation = aggregation
super(MirroredVariable, self).__init__(index)
# The arguments to update() are automatically unwrapped so the update()
@@ -325,17 +326,16 @@ class MirroredVariable(DistributedVariable, Mirrored,
# handle the different use cases can be found in the _reduce method.
# We call the function on each of the mirrored variables with the reduced
# value.
- if not self._aggregation_method:
+ if self._aggregation == vs.VariableAggregation.NONE:
raise ValueError("You must specify an aggregation method to update a "
"MirroredVariable in Tower Context.")
def merge_fn(strategy, value):
- return strategy.update(self,
- f,
- strategy.reduce(
- method_string=self._aggregation_method,
- value=value,
- destinations=self))
+ return strategy.update(
+ self, f,
+ strategy.reduce(
+ aggregation=self._aggregation, value=value, destinations=self))
+
return distribute_lib.get_tower_context().merge_call(merge_fn, *args,
**kwargs)
@@ -348,6 +348,10 @@ class MirroredVariable(DistributedVariable, Mirrored,
def assign(self, *args, **kwargs):
return self._assign_func(f=state_ops.assign, *args, **kwargs)
+ @property
+ def aggregation(self):
+ return self._aggregation
+
def _get_cross_tower(self):
device = device_util.canonicalize(device_util.current())
if device in self._index:
@@ -411,7 +415,7 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject):
# To preserve the sum across save and restore, we have to divide the
# total across all devices when restoring a variable that was summed
# when saving.
- if self._tower_local_variable.reduce_method == "sum":
+ if self._tower_local_variable.aggregation == vs.VariableAggregation.SUM:
tensor *= 1. / len(self._tower_local_variable.devices)
return control_flow_ops.group([
_assign_on_device(d, v, tensor)
@@ -428,9 +432,9 @@ class TowerLocalVariable(DistributedVariable, PerDevice,
checkpointable.CheckpointableBase):
"""Holds a map from device to variables whose values are reduced on save."""
- def __init__(self, index, primary_var, reduce_method):
+ def __init__(self, index, primary_var, aggregation):
self._primary_var = primary_var
- self._reduce_method = reduce_method
+ self._aggregation = aggregation
super(TowerLocalVariable, self).__init__(index)
def assign_sub(self, *args, **kwargs):
@@ -446,14 +450,14 @@ class TowerLocalVariable(DistributedVariable, PerDevice,
return self.get().assign(*args, **kwargs)
@property
- def reduce_method(self):
- return self._reduce_method
+ def aggregation(self):
+ return self._aggregation
def _get_cross_tower(self):
all_components = tuple(self._index.values())
# TODO(josh11b): Use a strategy-specific method.
total = math_ops.add_n(all_components)
- if self._reduce_method == "mean":
+ if self._aggregation == vs.VariableAggregation.MEAN:
return total * (1./ len(all_components))
return total
@@ -929,4 +933,3 @@ class MultiStepContext(object):
assert o.dtype == i.dtype, (
"Dtype {} of left {} doesn't match dtype {} of right {}.".
format(o.dtype, o, i.dtype, i))
-
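Reading a `TowerLocalVariable` across towers now sums the per-device components and, for MEAN aggregation, divides by the number of towers; a standalone sketch of `_get_cross_tower` above (plain Python, not the TensorFlow implementation):

import enum

class VariableAggregation(enum.Enum):  # mirrors the enum added in variable_scope.py
  NONE = 0
  SUM = 1
  MEAN = 2

def read_tower_local(components, aggregation):
  # Sum the component values from every tower; divide by the tower count
  # only when the variable's aggregation is MEAN.
  total = sum(components)
  if aggregation == VariableAggregation.MEAN:
    return total * (1.0 / len(components))
  return total

print(read_tower_local([3.0, 4.0], VariableAggregation.SUM))   # 7.0
print(read_tower_local([3.0, 4.0], VariableAggregation.MEAN))  # 3.5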
diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py
index c5b246e804..8e44f2fea1 100644
--- a/tensorflow/contrib/distribute/python/values_test.py
+++ b/tensorflow/contrib/distribute/python/values_test.py
@@ -158,7 +158,8 @@ def _make_mirrored():
v.append(variable_scope.get_variable(
name=n, initializer=init, use_resource=True))
index[d] = v[-1]
- mirrored = values.MirroredVariable(index, v[0])
+ mirrored = values.MirroredVariable(index, v[0],
+ variable_scope.VariableAggregation.SUM)
return v, devices, mirrored
@@ -277,7 +278,8 @@ class RegroupAndSelectDeviceTest(test.TestCase):
v = variable_scope.get_variable(
name="v", initializer=1., use_resource=True)
index = {d: v}
- mirrored = values.MirroredVariable(index, v)
+ mirrored = values.MirroredVariable(index, v,
+ variable_scope.VariableAggregation.SUM)
result = values.regroup(index)
self.assertIs(mirrored, result)
@@ -581,7 +583,8 @@ class MirroredVariableTest(test.TestCase):
v = variable_scope.get_variable(
name="v", initializer=[1.], use_resource=True)
index = {"/job:foo/device:CPU:0": v}
- mirrored = values.MirroredVariable(index, v)
+ mirrored = values.MirroredVariable(index, v,
+ variable_scope.VariableAggregation.MEAN)
self.assertEquals(v.name, mirrored.name)
self.assertEquals(v.dtype, mirrored.dtype)
@@ -716,7 +719,9 @@ class MirroredVariableTest(test.TestCase):
with ops.device("/device:GPU:0"):
v = variable_scope.get_variable(
name="v", initializer=1., use_resource=True)
- mirrored = values.MirroredVariable({"/device:GPU:0": v}, v)
+ mirrored = values.MirroredVariable({
+ "/device:GPU:0": v
+ }, v, variable_scope.VariableAggregation.MEAN)
sess.run(variables_lib.global_variables_initializer())
sess.run({"complicated": mirrored})
@@ -746,24 +751,27 @@ class TowerLocalVariableTest(test.TestCase):
if context.num_gpus() < 1 and context.executing_eagerly():
self.skipTest("A GPU is not available for this test in eager mode.")
- v, tower_local = _make_tower_local("sum")
+ v, tower_local = _make_tower_local(variable_scope.VariableAggregation.SUM)
self.assertEquals(v[0].name, tower_local.name)
self.assertEquals(v[0].dtype, tower_local.dtype)
self.assertEquals(v[0].shape, tower_local.shape)
- self.assertEquals("sum", tower_local.reduce_method)
+ self.assertEquals(variable_scope.VariableAggregation.SUM,
+ tower_local.aggregation)
@test_util.run_in_graph_and_eager_modes(config=config)
def testVariableOnAnotherDevice(self):
v = variable_scope.get_variable(
name="v", initializer=[1.], use_resource=True)
index = {"/job:foo/device:CPU:0": v}
- tower_local = values.TowerLocalVariable(index, v, "mean")
+ tower_local = values.TowerLocalVariable(
+ index, v, variable_scope.VariableAggregation.MEAN)
self.assertEquals(v.name, tower_local.name)
self.assertEquals(v.dtype, tower_local.dtype)
self.assertEquals(v.shape, tower_local.shape)
- self.assertEquals("mean", tower_local.reduce_method)
+ self.assertEquals(variable_scope.VariableAggregation.MEAN,
+ tower_local.aggregation)
def _assign_tower_local(self, devices, v, new):
for d, var, n in zip(devices, v, new):
@@ -789,7 +797,7 @@ class TowerLocalVariableTest(test.TestCase):
self.skipTest("A GPU is not available for this test in eager mode.")
with self.test_session() as sess:
- v, tower_local = _make_tower_local("sum")
+ v, tower_local = _make_tower_local(variable_scope.VariableAggregation.SUM)
# Overwrite the initial values.
self._assign_tower_local(_devices, v, [3., 4.])
@@ -812,7 +820,8 @@ class TowerLocalVariableTest(test.TestCase):
self.skipTest("A GPU is not available for this test in eager mode.")
with self.test_session() as sess:
- v, tower_local = _make_tower_local("mean")
+ v, tower_local = _make_tower_local(
+ variable_scope.VariableAggregation.MEAN)
# Overwrite the initial values.
self._assign_tower_local(_devices, v, [3., 4.])
@@ -831,7 +840,8 @@ class TowerLocalVariableTest(test.TestCase):
def _save_tower_local_mean(self):
"""Save variables with mirroring, returns save_path."""
with self.test_session(graph=ops.Graph()) as sess:
- v, tower_local = _make_tower_local("mean")
+ v, tower_local = _make_tower_local(
+ variable_scope.VariableAggregation.MEAN)
# Overwrite the initial values.
self._assign_tower_local(_devices, v, [3., 4.])
@@ -893,7 +903,8 @@ class TowerLocalVariableTest(test.TestCase):
def _restore_tower_local_mean(self, save_path):
"""Restore to variables with mirroring in a fresh graph."""
with self.test_session(graph=ops.Graph()) as sess:
- v, tower_local = _make_tower_local("mean")
+ v, tower_local = _make_tower_local(
+ variable_scope.VariableAggregation.MEAN)
# Overwrite the initial values.
self._assign_tower_local(_devices, v, [7., 8.])
@@ -907,7 +918,7 @@ class TowerLocalVariableTest(test.TestCase):
def _restore_tower_local_sum(self, save_path):
"""Restore to variables with mirroring in a fresh graph."""
with self.test_session(graph=ops.Graph()) as sess:
- v, tower_local = _make_tower_local("sum")
+ v, tower_local = _make_tower_local(variable_scope.VariableAggregation.SUM)
# Overwrite the initial values.
self._assign_tower_local(_devices, v, [7., 8.])
@@ -968,7 +979,7 @@ class TowerLocalVariableTest(test.TestCase):
def testTensorConversion(self):
with context.graph_mode():
- _, tower_local = _make_tower_local("sum")
+ _, tower_local = _make_tower_local(variable_scope.VariableAggregation.SUM)
converted = ops.internal_convert_to_tensor(tower_local, as_ref=False)
self.assertIsInstance(converted, ops.Tensor)
self.assertEqual(converted.dtype, tower_local.dtype)
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
index c6f3bd6ee1..8c11d8bcfd 100644
--- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
@@ -766,7 +766,8 @@ class OptimizerV2(optimizer_v1.Optimizer):
# *after* loss() is evaluated, so we know what loss reduction it uses.
if scale_loss_by_num_towers is None:
scale_loss_by_num_towers = (
- distribute_lib.get_loss_reduction() == "mean")
+ distribute_lib.get_loss_reduction() ==
+ variable_scope.VariableAggregation.MEAN)
if scale_loss_by_num_towers:
num_towers = distribute_lib.get_distribution_strategy().num_towers
if num_towers > 1:
@@ -784,7 +785,8 @@ class OptimizerV2(optimizer_v1.Optimizer):
# Scale loss for number of towers (non-callable-loss case).
if scale_loss_by_num_towers is None:
scale_loss_by_num_towers = (
- distribute_lib.get_loss_reduction() == "mean")
+ distribute_lib.get_loss_reduction() ==
+ variable_scope.VariableAggregation.MEAN)
if scale_loss_by_num_towers:
num_towers = distribute_lib.get_distribution_strategy().num_towers
if num_towers > 1:
@@ -896,7 +898,8 @@ class OptimizerV2(optimizer_v1.Optimizer):
def _distributed_apply(self, distribution, grads_and_vars, global_step, name):
"""`apply_gradients` for use with a `DistributionStrategy`."""
- reduced_grads = distribution.batch_reduce("sum", grads_and_vars)
+ reduced_grads = distribution.batch_reduce(
+ variable_scope.VariableAggregation.SUM, grads_and_vars)
var_list = [v for _, v in grads_and_vars]
grads_and_vars = zip(reduced_grads, var_list)
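The optimizer now compares the strategy's loss reduction against the enum rather than the string "mean"; when it is MEAN, each tower's loss is pre-divided by the number of towers so that summing the per-tower gradients reproduces the gradient of the mean loss. A standalone sketch of that scaling decision (plain Python; `scale_loss` is a hypothetical helper):

import enum

class VariableAggregation(enum.Enum):  # mirrors the enum added in variable_scope.py
  NONE = 0
  SUM = 1
  MEAN = 2

def scale_loss(loss, loss_reduction, num_towers):
  # Mirrors the check `get_loss_reduction() == VariableAggregation.MEAN`
  # in the optimizer change above.
  if loss_reduction == VariableAggregation.MEAN and num_towers > 1:
    return loss * (1.0 / num_towers)
  return loss

print(scale_loss(6.0, VariableAggregation.MEAN, 3))  # 2.0
print(scale_loss(6.0, VariableAggregation.SUM, 3))   # 6.0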
diff --git a/tensorflow/python/eager/graph_callable.py b/tensorflow/python/eager/graph_callable.py
index 760a148552..848adf4fd3 100644
--- a/tensorflow/python/eager/graph_callable.py
+++ b/tensorflow/python/eager/graph_callable.py
@@ -110,13 +110,25 @@ class _VariableCapturingScope(object):
"""
# TODO(apassos) ignoring the regularizer and partitioner here; figure out
# how to deal with these.
- def _custom_getter(getter=None, name=None, shape=None, dtype=dtypes.float32, # pylint: disable=missing-docstring
- initializer=None, regularizer=None, reuse=None,
- trainable=True, collections=None, caching_device=None, # pylint: disable=redefined-outer-name
- partitioner=None, validate_shape=True,
- use_resource=None):
+ def _custom_getter( # pylint: disable=missing-docstring
+ getter=None,
+ name=None,
+ shape=None,
+ dtype=dtypes.float32,
+ initializer=None,
+ regularizer=None,
+ reuse=None,
+ trainable=True,
+ collections=None,
+ caching_device=None, # pylint: disable=redefined-outer-name
+ partitioner=None,
+ validate_shape=True,
+ use_resource=None,
+ aggregation=variable_scope.VariableAggregation.NONE,
+ synchronization=variable_scope.VariableSynchronization.AUTO):
del getter, regularizer, partitioner, validate_shape, use_resource, dtype
- del collections, initializer, trainable, reuse, caching_device, shape,
+ del collections, initializer, trainable, reuse, caching_device, shape
+ del aggregation, synchronization
assert name in self.variables
v = self.variables[name]
return v.variable
@@ -136,13 +148,24 @@ class _VariableCapturingScope(object):
"""
# TODO(apassos) ignoring the regularizer and partitioner here; figure out
# how to deal with these.
- def _custom_getter(getter=None, name=None, shape=None, dtype=dtypes.float32, # pylint: disable=missing-docstring
- initializer=None, regularizer=None, reuse=None,
- trainable=True, collections=None, caching_device=None, # pylint: disable=redefined-outer-name
- partitioner=None, validate_shape=True,
- use_resource=None):
+ def _custom_getter( # pylint: disable=missing-docstring
+ getter=None,
+ name=None,
+ shape=None,
+ dtype=dtypes.float32,
+ initializer=None,
+ regularizer=None,
+ reuse=None,
+ trainable=True,
+ collections=None,
+ caching_device=None, # pylint: disable=redefined-outer-name
+ partitioner=None,
+ validate_shape=True,
+ use_resource=None,
+ aggregation=variable_scope.VariableAggregation.NONE,
+ synchronization=variable_scope.VariableSynchronization.AUTO):
del getter, regularizer, collections, caching_device, partitioner
- del use_resource, validate_shape
+ del use_resource, validate_shape, aggregation, synchronization
if name in self.tf_variables:
if reuse:
return self.tf_variables[name].initialized_value()
diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index d4c213eedd..8b894ca6b1 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -34,6 +34,7 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.util.tf_export import tf_export
@@ -182,7 +183,8 @@ class BatchNormalization(Layer):
def _add_tower_local_variable(self, *args, **kwargs):
tower_context = distribute_lib.get_tower_context()
- with tower_context.tower_local_var_scope('mean'):
+ with tower_context.tower_local_var_scope(
+ variable_scope.VariableAggregation.MEAN):
return self.add_weight(*args, **kwargs)
def build(self, input_shape):
diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py
index 1e59a8c9bf..054c6f9dd7 100644
--- a/tensorflow/python/kernel_tests/variable_scope_test.py
+++ b/tensorflow/python/kernel_tests/variable_scope_test.py
@@ -1253,6 +1253,31 @@ class VariableScopeWithCustomGetterTest(test.TestCase):
self.assertEqual(v3, v4)
self.assertEqual(3, called[0]) # skipped one in the first new_scope
+ def testSynchronizationAndAggregationWithCustomGetter(self):
+ called = [0]
+ synchronization = variable_scope.VariableSynchronization.AUTO
+ aggregation = variable_scope.VariableAggregation.NONE
+
+ def custom_getter(getter, *args, **kwargs):
+ called[0] += 1
+
+ # Verify synchronization and aggregation kwargs are as expected.
+ self.assertEqual(kwargs["synchronization"], synchronization)
+ self.assertEqual(kwargs["aggregation"], aggregation)
+ return getter(*args, **kwargs)
+
+ with variable_scope.variable_scope("scope", custom_getter=custom_getter):
+ variable_scope.get_variable("v", [1])
+ self.assertEqual(1, called[0])
+
+ with variable_scope.variable_scope("scope", custom_getter=custom_getter):
+ synchronization = variable_scope.VariableSynchronization.ON_READ
+ aggregation = variable_scope.VariableAggregation.MEAN
+ variable_scope.get_variable(
+ "v1", [1], synchronization=synchronization, aggregation=aggregation)
+
+ self.assertEqual(2, called[0])
+
def testCustomGetterWithReuse(self):
# Custom getter can choose to behave differently on reused variables.
def custom_getter(getter, *args, **kwargs):
@@ -1355,6 +1380,23 @@ class VariableScopeWithCustomGetterTest(test.TestCase):
self.assertAllEqual(variable_names, ["forced_name"])
+ called = [False]
+
+ def creater_c(next_creator, **kwargs):
+ called[0] = True
+ self.assertEqual(kwargs["synchronization"],
+ variable_scope.VariableSynchronization.ON_WRITE)
+ self.assertEqual(kwargs["aggregation"],
+ variable_scope.VariableAggregation.MEAN)
+ return next_creator(**kwargs)
+
+ with variable_scope.variable_creator_scope(creater_c):
+ variable_scope.get_variable(
+ "v", [],
+ synchronization=variable_scope.VariableSynchronization.ON_WRITE,
+ aggregation=variable_scope.VariableAggregation.MEAN)
+ self.assertTrue(called[0])
+
class PartitionInfoTest(test.TestCase):
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py
index 5eab12c41d..bfd225b0d8 100644
--- a/tensorflow/python/ops/metrics_impl.py
+++ b/tensorflow/python/ops/metrics_impl.py
@@ -73,7 +73,8 @@ def metric_variable(shape, dtype, validate_shape=True, name=None):
A (non-trainable) variable initialized to zero, or if inside a
`DistributionStrategy` scope a tower-local variable container.
"""
- with distribute_lib.get_tower_context().tower_local_var_scope('sum'):
+ with distribute_lib.get_tower_context().tower_local_var_scope(
+ variable_scope.VariableAggregation.SUM):
# Note that "tower local" implies trainable=False.
return variable_scope.variable(
lambda: array_ops.zeros(shape, dtype),
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 47414c28af..f862b62fad 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -1,4 +1,4 @@
- # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -44,9 +44,11 @@ from tensorflow.python.util import function_utils
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
-__all__ = ["AUTO_REUSE", "VariableScope", "get_variable_scope",
- "get_variable", "get_local_variable", "variable_scope",
- "variable_op_scope", "no_regularizer"]
+__all__ = [
+ "AUTO_REUSE", "VariableScope", "get_variable_scope", "get_variable",
+ "get_local_variable", "variable_scope", "variable_op_scope",
+ "no_regularizer", "VariableSynchronization", "VariableAggregation"
+]
class _PartitionInfo(object):
@@ -188,6 +190,38 @@ class _ReuseMode(enum.Enum):
# REUSE_FALSE = 2
# REUSE_TRUE = 3
+
+@tf_export("VariableSynchronization")
+class VariableSynchronization(enum.Enum):
+ """Indicates when a distributed variable will be synced."""
+
+ # Indicates that the synchronization will be determined by the current
+ # `DistributionStrategy` (e.g. with `MirroredStrategy` this would be
+ # `ON_WRITE`).
+ AUTO = 0
+
+ # Indicates that there will only be one copy of the variable, so there is no
+ # need to sync.
+ NONE = 1
+
+ # Indicates that the variable will be aggregated across devices
+ # every time it is updated.
+ ON_WRITE = 2
+
+ # Indicates that the variable will be aggregated across devices
+ # when it is read (e.g. when checkpointing or when evaluating an op that uses
+ # the variable).
+ ON_READ = 3
+
+
+@tf_export("VariableAggregation")
+class VariableAggregation(enum.Enum):
+ """Indicates how a distributed variable will be aggregated."""
+ NONE = 0
+ SUM = 1
+ MEAN = 2
+
+
AUTO_REUSE = _ReuseMode.AUTO_REUSE
tf_export("AUTO_REUSE").export_constant(__name__, "AUTO_REUSE")
AUTO_REUSE.__doc__ = """
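The two enums above become part of the public `tf.get_variable` surface in this change. A minimal sketch of the intended usage (variable names below are illustrative; `AUTO` and `NONE` remain the defaults, so existing calls are unaffected):

    import tensorflow as tf

    # Defaults: the active DistributionStrategy picks the synchronization
    # (AUTO) and no aggregation is applied (NONE).
    kernel = tf.get_variable("kernel", shape=[128, 10])

    # Explicitly request a write-synced variable whose updates are averaged
    # across towers (e.g. under MirroredStrategy).
    scale = tf.get_variable(
        "scale", shape=[10],
        initializer=tf.ones_initializer(),
        synchronization=tf.VariableSynchronization.ON_WRITE,
        aggregation=tf.VariableAggregation.MEAN)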
@@ -214,11 +248,23 @@ class _VariableStore(object):
self._partitioned_vars = {} # A dict of the stored PartitionedVariables.
self._store_eager_variables = False
- def get_variable(self, name, shape=None, dtype=dtypes.float32,
- initializer=None, regularizer=None, reuse=None,
- trainable=True, collections=None, caching_device=None,
- partitioner=None, validate_shape=True, use_resource=None,
- custom_getter=None, constraint=None):
+ def get_variable(self,
+ name,
+ shape=None,
+ dtype=dtypes.float32,
+ initializer=None,
+ regularizer=None,
+ reuse=None,
+ trainable=True,
+ collections=None,
+ caching_device=None,
+ partitioner=None,
+ validate_shape=True,
+ use_resource=None,
+ custom_getter=None,
+ constraint=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
"""Gets an existing variable with these parameters or create a new one.
If a variable with the given name is already stored, we return the stored
@@ -291,6 +337,14 @@ class _VariableStore(object):
variable and return the Tensor for the projected value
(which must have the same shape). Constraints are not safe to
use when doing asynchronous distributed training.
+ synchronization: Indicates when a distributed variable will be
+ synchronized. Accepted values are constants defined in the class
+ @{tf.VariableSynchronization}. By default the synchronization is set to
+ `AUTO` and the current `DistributionStrategy` chooses
+ when to synchronize.
+ aggregation: Indicates how a distributed variable will be aggregated.
+ Accepted values are constants defined in the class
+ @{tf.VariableAggregation}.
Returns:
The created or existing `Variable` (or `PartitionedVariable`, if a
@@ -343,11 +397,22 @@ class _VariableStore(object):
# it to custom_getter.
# Note: the parameters of _true_getter, and their documentation, match
# *exactly* item-for-item with the docstring of this method.
- def _true_getter(name, shape=None, dtype=dtypes.float32, # pylint: disable=missing-docstring
- initializer=None, regularizer=None, reuse=None,
- trainable=True, collections=None, caching_device=None,
- partitioner=None, validate_shape=True, use_resource=None,
- constraint=None):
+ def _true_getter( # pylint: disable=missing-docstring
+ name,
+ shape=None,
+ dtype=dtypes.float32,
+ initializer=None,
+ regularizer=None,
+ reuse=None,
+ trainable=True,
+ collections=None,
+ caching_device=None,
+ partitioner=None,
+ validate_shape=True,
+ use_resource=None,
+ constraint=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
is_scalar = (shape is not None
and isinstance(shape, collections_lib.Sequence)
and not shape)
@@ -397,11 +462,20 @@ class _VariableStore(object):
"name was already created with partitioning?" % name)
return self._get_single_variable(
- name=name, shape=shape, dtype=dtype,
- initializer=initializer, regularizer=regularizer, reuse=reuse,
- trainable=trainable, collections=collections,
- caching_device=caching_device, validate_shape=validate_shape,
- use_resource=use_resource, constraint=constraint)
+ name=name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ regularizer=regularizer,
+ reuse=reuse,
+ trainable=trainable,
+ collections=collections,
+ caching_device=caching_device,
+ validate_shape=validate_shape,
+ use_resource=use_resource,
+ constraint=constraint,
+ synchronization=synchronization,
+ aggregation=aggregation)
if custom_getter is not None:
# Handle backwards compatibility with getter arguments that were added
@@ -420,6 +494,8 @@ class _VariableStore(object):
"partitioner": partitioner,
"validate_shape": validate_shape,
"use_resource": use_resource,
+ "synchronization": synchronization,
+ "aggregation": aggregation,
}
# `fn_args` can handle functions, `functools.partial`, `lambda`.
if "constraint" in function_utils.fn_args(custom_getter):
@@ -427,12 +503,21 @@ class _VariableStore(object):
return custom_getter(**custom_getter_kwargs)
else:
return _true_getter(
- name, shape=shape, dtype=dtype,
- initializer=initializer, regularizer=regularizer,
- reuse=reuse, trainable=trainable, collections=collections,
- caching_device=caching_device, partitioner=partitioner,
- validate_shape=validate_shape, use_resource=use_resource,
- constraint=constraint)
+ name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ regularizer=regularizer,
+ reuse=reuse,
+ trainable=trainable,
+ collections=collections,
+ caching_device=caching_device,
+ partitioner=partitioner,
+ validate_shape=validate_shape,
+ use_resource=use_resource,
+ constraint=constraint,
+ synchronization=synchronization,
+ aggregation=aggregation)
def _get_partitioned_variable(
self, name, partitioner, shape=None, dtype=dtypes.float32,
@@ -693,7 +778,9 @@ class _VariableStore(object):
caching_device=None,
validate_shape=True,
use_resource=None,
- constraint=None):
+ constraint=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
"""Get or create a single Variable (e.g. a shard or entire variable).
See the documentation of get_variable above (ignore partitioning components)
@@ -713,6 +800,8 @@ class _VariableStore(object):
validate_shape: see get_variable.
use_resource: see get_variable.
constraint: see get_variable.
+ synchronization: see get_variable.
+ aggregation: see get_variable.
Returns:
A Variable. See documentation of get_variable above.
@@ -793,7 +882,9 @@ class _VariableStore(object):
dtype=variable_dtype,
validate_shape=validate_shape,
constraint=constraint,
- use_resource=use_resource)
+ use_resource=use_resource,
+ synchronization=synchronization,
+ aggregation=aggregation)
if context.executing_eagerly() and self._store_eager_variables:
if collections:
ops.add_to_collections(collections, v)
@@ -1052,7 +1143,9 @@ class VariableScope(object):
validate_shape=True,
use_resource=None,
custom_getter=None,
- constraint=None):
+ constraint=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
"""Gets an existing variable with this name or create a new one."""
if regularizer is None:
regularizer = self._regularizer
@@ -1090,12 +1183,22 @@ class VariableScope(object):
if dtype is None:
dtype = self._dtype
return var_store.get_variable(
- full_name, shape=shape, dtype=dtype, initializer=initializer,
- regularizer=regularizer, reuse=reuse, trainable=trainable,
- collections=collections, caching_device=caching_device,
- partitioner=partitioner, validate_shape=validate_shape,
- use_resource=use_resource, custom_getter=custom_getter,
- constraint=constraint)
+ full_name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ regularizer=regularizer,
+ reuse=reuse,
+ trainable=trainable,
+ collections=collections,
+ caching_device=caching_device,
+ partitioner=partitioner,
+ validate_shape=validate_shape,
+ use_resource=use_resource,
+ custom_getter=custom_getter,
+ constraint=constraint,
+ synchronization=synchronization,
+ aggregation=aggregation)
def _get_partitioned_variable(self,
var_store,
@@ -1326,14 +1429,28 @@ def get_variable(name,
validate_shape=True,
use_resource=None,
custom_getter=None,
- constraint=None):
+ constraint=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
return get_variable_scope().get_variable(
- _get_default_variable_store(), name, shape=shape, dtype=dtype,
- initializer=initializer, regularizer=regularizer, trainable=trainable,
- collections=collections, caching_device=caching_device,
- partitioner=partitioner, validate_shape=validate_shape,
- use_resource=use_resource, custom_getter=custom_getter,
- constraint=constraint)
+ _get_default_variable_store(),
+ name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ regularizer=regularizer,
+ trainable=trainable,
+ collections=collections,
+ caching_device=caching_device,
+ partitioner=partitioner,
+ validate_shape=validate_shape,
+ use_resource=use_resource,
+ custom_getter=custom_getter,
+ constraint=constraint,
+ synchronization=synchronization,
+ aggregation=aggregation)
+
+
get_variable_or_local_docstring = (
"""%s
@@ -1430,29 +1547,44 @@ get_variable.__doc__ = get_variable_or_local_docstring % (
# The argument list for get_local_variable must match arguments to get_variable.
# So, if you are updating the arguments, also update arguments to get_variable.
@tf_export("get_local_variable")
-def get_local_variable(name,
- shape=None,
- dtype=None,
- initializer=None,
- regularizer=None,
- trainable=False, # pylint: disable=unused-argument
- collections=None,
- caching_device=None,
- partitioner=None,
- validate_shape=True,
- use_resource=None,
- custom_getter=None,
- constraint=None):
+def get_local_variable( # pylint: disable=missing-docstring
+ name,
+ shape=None,
+ dtype=None,
+ initializer=None,
+ regularizer=None,
+ trainable=False, # pylint: disable=unused-argument
+ collections=None,
+ caching_device=None,
+ partitioner=None,
+ validate_shape=True,
+ use_resource=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE,
+ custom_getter=None,
+ constraint=None):
if collections:
collections += [ops.GraphKeys.LOCAL_VARIABLES]
else:
collections = [ops.GraphKeys.LOCAL_VARIABLES]
return get_variable(
- name, shape=shape, dtype=dtype, initializer=initializer,
- regularizer=regularizer, trainable=False, collections=collections,
- caching_device=caching_device, partitioner=partitioner,
- validate_shape=validate_shape, use_resource=use_resource,
- custom_getter=custom_getter, constraint=constraint)
+ name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ regularizer=regularizer,
+ trainable=False,
+ collections=collections,
+ caching_device=caching_device,
+ partitioner=partitioner,
+ validate_shape=validate_shape,
+ use_resource=use_resource,
+ synchronization=synchronization,
+ aggregation=aggregation,
+ custom_getter=custom_getter,
+ constraint=constraint)
+
+
get_local_variable.__doc__ = get_variable_or_local_docstring % (
"Gets an existing *local* variable or creates a new one.",
"Behavior is the same as in `get_variable`, except that variables are\n"
@@ -2214,6 +2346,12 @@ def default_variable_creator(next_creator=None, **kwargs):
dtype = kwargs.get("dtype", None)
constraint = kwargs.get("constraint", None)
use_resource = kwargs.get("use_resource", None)
+
+ # Enforce that `ON_READ` variables are not trainable.
+ synchronization = kwargs.pop("synchronization", VariableSynchronization.AUTO)
+ if synchronization == VariableSynchronization.ON_READ:
+ trainable = False
+
if use_resource is None:
use_resource = get_variable_scope().use_resource
if use_resource or (use_resource is None and context.executing_eagerly()):
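Because `default_variable_creator` now forces `trainable=False` for `ON_READ` variables, requesting read synchronization is enough to keep a variable out of the trainable collection. A small sketch of that behaviour (run in a fresh default graph; the variable name is illustrative):

    import tensorflow as tf

    v = tf.get_variable(
        "read_synced", shape=[],
        synchronization=tf.VariableSynchronization.ON_READ,
        aggregation=tf.VariableAggregation.MEAN)

    # `trainable` was left at its default, but ON_READ implies trainable=False.
    trainable_names = [x.name for x in tf.trainable_variables()]
    assert "read_synced:0" not in trainable_names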
@@ -2248,18 +2386,28 @@ def variable(initial_value=None,
name=None,
dtype=None,
constraint=None,
- use_resource=None):
+ use_resource=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
for getter in ops.get_default_graph()._variable_creator_stack: # pylint: disable=protected-access
previous_getter = _make_getter(getter, previous_getter)
- return previous_getter(initial_value=initial_value,
- trainable=trainable,
- collections=collections,
- validate_shape=validate_shape,
- caching_device=caching_device,
- name=name, dtype=dtype,
- constraint=constraint,
- use_resource=use_resource)
+
+ # Map an `aggregation` explicitly passed as `None` to the enum value `VariableAggregation.NONE`.
+ if aggregation is None:
+ aggregation = VariableAggregation.NONE
+ return previous_getter(
+ initial_value=initial_value,
+ trainable=trainable,
+ collections=collections,
+ validate_shape=validate_shape,
+ caching_device=caching_device,
+ name=name,
+ dtype=dtype,
+ constraint=constraint,
+ use_resource=use_resource,
+ synchronization=synchronization,
+ aggregation=aggregation)
@tf_contextlib.contextmanager
@@ -2311,6 +2459,14 @@ def variable_creator_scope(variable_creator):
constraint: A constraint function to be applied to the variable after
updates by some algorithms.
use_resource: if True, a ResourceVariable is always created.
+ synchronization: Indicates when a distributed variable will be
+ synchronized. Accepted values are constants defined in the class
+ @{tf.VariableSynchronization}. By default the synchronization is set to
+ `AUTO` and the current `DistributionStrategy` chooses
+ when to synchronize.
+ aggregation: Indicates how a distributed variable will be aggregated.
+ Accepted values are constants defined in the class
+ @{tf.VariableAggregation}.
This set may grow over time, so it's important the signature of creators is as
mentioned above.
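Variable creators installed via `variable_creator_scope` likewise see the two new keyword arguments and can branch on them. A minimal sketch (the creator name and the logging are illustrative, not part of the API):

    import tensorflow as tf
    from tensorflow.python.ops import variable_scope

    def tower_local_logger(next_creator, **kwargs):
      # Creators now receive `synchronization` and `aggregation` in kwargs.
      if (kwargs.get("synchronization") ==
          variable_scope.VariableSynchronization.ON_READ):
        tf.logging.info("creating a tower-local style variable: %s",
                        kwargs.get("name"))
      return next_creator(**kwargs)

    with variable_scope.variable_creator_scope(tower_local_logger):
      tf.get_variable(
          "bias", [4],
          synchronization=tf.VariableSynchronization.ON_READ,
          aggregation=tf.VariableAggregation.SUM)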
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index 6a326b65bb..562ad3bb02 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -221,11 +221,12 @@ def has_distribution_strategy():
def get_loss_reduction():
- """Reduce `method_string` corresponding to the last loss reduction."""
+ """Reduce `aggregation` corresponding to the last loss reduction."""
loss_reduction = ops.get_default_graph()._last_loss_reduction # pylint: disable=protected-access
if loss_reduction == losses_impl.Reduction.SUM:
- return "sum"
- return "mean"
+ return variable_scope.VariableAggregation.SUM
+ return variable_scope.VariableAggregation.MEAN
# ------------------------------------------------------------------------------
@@ -539,8 +540,8 @@ class DistributionStrategy(object):
1. Wrap your input dataset in `d.distribute_dataset()` and create an iterator.
2. Define each tower `d.call_for_each_tower()` up to the point of
getting a list of gradient, variable pairs.
- 3. Call `d.reduce("sum", t, v)` or `d.batch_reduce()` to sum the
- gradients (with locality T) into values with locality V(`v`).
+ 3. Call `d.reduce(VariableAggregation.SUM, t, v)` or `d.batch_reduce()` to sum
+ the gradients (with locality T) into values with locality V(`v`).
4. Call `d.update(v)` for each variable to update its value.
Steps 3 and 4 are done automatically by class `Optimizer` if you call
@@ -614,7 +615,7 @@ class DistributionStrategy(object):
# Note: should support "colocate_with" argument.
raise NotImplementedError("must be implemented in descendants")
- def tower_local_var_scope(self, reduce_method):
+ def tower_local_var_scope(self, aggregation):
"""Inside this scope, new variables will not be mirrored.
There will still be one component variable per tower, but there is
@@ -636,16 +637,21 @@ class DistributionStrategy(object):
random numbers.
Args:
- reduce_method: String used as a `method_string` to `reduce()`
- to get the value to save when checkpointing.
+ aggregation: Indicates how a variable will be aggregated. Accepted values
+ are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}.
Returns:
A context manager.
"""
+ # TODO(psv): Remove this after adding support for synchronization and
+ # aggregation parameters in get_variable() and mirrored strategy.
def create_tower_local_variable(next_creator, *args, **kwargs):
_require_distribution_strategy_scope(self)
kwargs["use_resource"] = True
- kwargs["tower_local_reduce_method"] = reduce_method
+
+ # Set synchronization to be ON_READ for tower local variables.
+ kwargs["synchronization"] = variable_scope.VariableSynchronization.ON_READ
+ kwargs["aggregation"] = aggregation
return next_creator(*args, **kwargs)
_require_distribution_strategy_scope(self)
@@ -816,12 +822,12 @@ class DistributionStrategy(object):
def _call_for_each_tower(self, fn, *args, **kwargs):
raise NotImplementedError("must be implemented in descendants")
- def reduce(self, method_string, value, destinations=None):
+ def reduce(self, aggregation, value, destinations=None):
"""Combine (via e.g. sum or mean) values across towers.
Args:
- method_string: A string indicating how to combine values, either
- "sum" or "mean".
+ aggregation: Indicates how a variable will be aggregated. Accepted values
+ are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}.
value: A per-device value with one value per tower.
destinations: An optional mirrored variable, a device string,
list of device strings. The return value will be copied to all
@@ -836,18 +842,21 @@ class DistributionStrategy(object):
# TODO(josh11b): Return an unwrapped value if colocate_with is a
# single device.
_require_cross_tower_context(self)
- assert method_string in ("sum", "mean")
- return self._reduce(method_string, value, destinations)
+ assert aggregation in [
+ variable_scope.VariableAggregation.SUM,
+ variable_scope.VariableAggregation.MEAN
+ ]
+ return self._reduce(aggregation, value, destinations)
- def _reduce(self, method_string, value, destinations):
+ def _reduce(self, aggregation, value, destinations):
raise NotImplementedError("must be implemented in descendants")
- def batch_reduce(self, method_string, value_destination_pairs):
+ def batch_reduce(self, aggregation, value_destination_pairs):
"""Combine multiple `reduce` calls into one for faster execution.
Args:
- method_string: A string indicating how to combine values, either
- "sum" or "mean".
+ aggregation: Indicates how a variable will be aggregated. Accepted values
+ are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}.
value_destination_pairs: A sequence of (value, destinations)
pairs. See `reduce()` for a description.
@@ -856,12 +865,17 @@ class DistributionStrategy(object):
"""
# TODO(josh11b): More docstring
_require_cross_tower_context(self)
- assert method_string in ("sum", "mean")
- return self._batch_reduce(method_string, value_destination_pairs)
-
- def _batch_reduce(self, method_string, value_destination_pairs):
- return [self.reduce(method_string, t, destinations=v)
- for t, v in value_destination_pairs]
+ assert aggregation in [
+ variable_scope.VariableAggregation.SUM,
+ variable_scope.VariableAggregation.MEAN
+ ]
+ return self._batch_reduce(aggregation, value_destination_pairs)
+
+ def _batch_reduce(self, aggregation, value_destination_pairs):
+ return [
+ self.reduce(aggregation, t, destinations=v)
+ for t, v in value_destination_pairs
+ ]
def update(self, var, fn, *args, **kwargs):
"""Run `fn` to update `var` using inputs mirrored to the same devices.
@@ -1090,9 +1104,9 @@ class TowerContext(object):
finally:
_pop_per_thread_mode()
- def tower_local_var_scope(self, reduce_method):
+ def tower_local_var_scope(self, aggregation):
"""Alias for distribution_strategy.tower_local_var_scope()."""
- return self._distribution_strategy.tower_local_var_scope(reduce_method)
+ return self._distribution_strategy.tower_local_var_scope(aggregation)
@property
def is_single_tower(self):
@@ -1140,13 +1154,12 @@ class _DefaultDistributionStrategy(DistributionStrategy):
def creator(next_creator, *args, **kwargs):
_require_distribution_strategy_scope(self)
- kwargs.pop("tower_local_reduce_method", None)
return next_creator(*args, **kwargs)
return _CurrentDistributionContext(
self, variable_scope.variable_creator_scope(creator))
- def tower_local_var_scope(self, reduce_method):
+ def tower_local_var_scope(self, aggregation):
"""Does not set to resource variables."""
def create_tower_local_variable(next_creator, *args, **kwargs):
_require_distribution_strategy_scope(self)
@@ -1176,9 +1189,9 @@ class _DefaultDistributionStrategy(DistributionStrategy):
with TowerContext(self, tower_id=0):
return fn(*args, **kwargs)
- def _reduce(self, method_string, value, destinations):
+ def _reduce(self, aggregation, value, destinations):
# TODO(josh11b): Use destinations?
- del method_string, destinations
+ del aggregation, destinations
return value
def _update(self, var, fn, *args, **kwargs):
diff --git a/tensorflow/python/training/distribute_test.py b/tensorflow/python/training/distribute_test.py
index 0a4f19c31f..694145ede7 100644
--- a/tensorflow/python/training/distribute_test.py
+++ b/tensorflow/python/training/distribute_test.py
@@ -29,6 +29,14 @@ class _TestTowerContext(distribute.TowerContext):
return kwargs["test_arg"]
+def _get_test_variable(name, synchronization, aggregation):
+ return {
+ "name": name,
+ "synchronization": synchronization,
+ "aggregation": aggregation
+ }
+
+
class _TestStrategy(distribute.DistributionStrategy):
def _call_for_each_tower(self, fn, *args, **kwargs):
@@ -36,7 +44,8 @@ class _TestStrategy(distribute.DistributionStrategy):
return fn(*args, **kwargs)
def _create_variable(self, next_creator, *args, **kwargs):
- return kwargs["name"]
+ return _get_test_variable(kwargs["name"], kwargs["synchronization"],
+ kwargs["aggregation"])
def _assert_in_default_state(t):
@@ -61,7 +70,11 @@ class TestStrategyTest(test.TestCase):
self.assertTrue(distribute.has_distribution_strategy())
self.assertIs(dist, distribute.get_distribution_strategy())
self.assertEqual("foo", tower_context.merge_call(None, test_arg="foo"))
- self.assertEqual("bar", variable_scope.variable(1.0, name="bar"))
+ expected_value = _get_test_variable(
+ "bar", variable_scope.VariableSynchronization.AUTO,
+ variable_scope.VariableAggregation.NONE)
+ self.assertDictEqual(expected_value,
+ variable_scope.variable(1.0, name="bar"))
with self.assertRaises(RuntimeError):
dist.call_for_each_tower(run_fn)
@@ -77,7 +90,27 @@ class TestStrategyTest(test.TestCase):
self.assertIs(dist, distribute.get_cross_tower_context())
self.assertTrue(distribute.has_distribution_strategy())
self.assertIs(dist, distribute.get_distribution_strategy())
- self.assertEqual("baz", variable_scope.variable(1.0, name="baz"))
+ expected_value = _get_test_variable(
+ "baz", variable_scope.VariableSynchronization.AUTO,
+ variable_scope.VariableAggregation.NONE)
+ self.assertDictEqual(expected_value,
+ variable_scope.variable(1.0, name="baz"))
+ _assert_in_default_state(self)
+
+ def testSettingSynchronizationAndAggregation(self):
+ _assert_in_default_state(self)
+ dist = _TestStrategy()
+ with dist.scope():
+ expected_value = _get_test_variable(
+ "baz", variable_scope.VariableSynchronization.ON_WRITE,
+ variable_scope.VariableAggregation.MEAN)
+ self.assertDictEqual(
+ expected_value,
+ variable_scope.variable(
+ 1.0,
+ name="baz",
+ synchronization=variable_scope.VariableSynchronization.ON_WRITE,
+ aggregation=variable_scope.VariableAggregation.MEAN))
_assert_in_default_state(self)
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index fe9ffde11c..784c9ddd1b 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -461,7 +461,8 @@ class Optimizer(
# Have to be careful to call distribute_lib.get_loss_reduction()
# *after* loss() is evaluated, so we know what loss reduction it uses.
# TODO(josh11b): Test that we handle weight decay in a reasonable way.
- if distribute_lib.get_loss_reduction() == "mean":
+ if distribute_lib.get_loss_reduction(
+ ) == variable_scope.VariableAggregation.MEAN:
num_towers = distribute_lib.get_distribution_strategy().num_towers
if num_towers > 1:
loss_value *= (1. / num_towers)
@@ -478,7 +479,8 @@ class Optimizer(
"be a function when eager execution is enabled.")
# Scale loss if using a "mean" loss reduction and multiple towers.
- if distribute_lib.get_loss_reduction() == "mean":
+ if distribute_lib.get_loss_reduction(
+ ) == variable_scope.VariableAggregation.MEAN:
num_towers = distribute_lib.get_distribution_strategy().num_towers
if num_towers > 1:
loss *= (1. / num_towers)
@@ -649,7 +651,8 @@ class Optimizer(
towers. If `global_step` was not None, that operation also
increments `global_step`.
"""
- reduced_grads = distribution.batch_reduce("sum", grads_and_vars)
+ reduced_grads = distribution.batch_reduce(
+ variable_scope.VariableAggregation.SUM, grads_and_vars)
var_list = [v for _, v in grads_and_vars]
grads_and_vars = zip(reduced_grads, var_list)
# Note that this is called in a cross-tower context.
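The loss-scaling rule in the optimizer hunks above can be read as a small helper; a paraphrase of that logic under this change (the helper name is illustrative, not part of the API):

    from tensorflow.python.ops import variable_scope
    from tensorflow.python.training import distribute as distribute_lib

    def _maybe_scale_loss(loss_value):
      # Scale only when the last loss reduction was a mean and more than
      # one tower is in use, mirroring the optimizer code in the hunks above.
      if (distribute_lib.get_loss_reduction() ==
          variable_scope.VariableAggregation.MEAN):
        num_towers = distribute_lib.get_distribution_strategy().num_towers
        if num_towers > 1:
          loss_value *= (1. / num_towers)
      return loss_value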
diff --git a/tensorflow/tools/api/golden/tensorflow.-variable-aggregation.pbtxt b/tensorflow/tools/api/golden/tensorflow.-variable-aggregation.pbtxt
new file mode 100644
index 0000000000..36b534af36
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.-variable-aggregation.pbtxt
@@ -0,0 +1,16 @@
+path: "tensorflow.VariableAggregation"
+tf_class {
+ is_instance: "<enum \'VariableAggregation\'>"
+ member {
+ name: "MEAN"
+ mtype: "<enum \'VariableAggregation\'>"
+ }
+ member {
+ name: "NONE"
+ mtype: "<enum \'VariableAggregation\'>"
+ }
+ member {
+ name: "SUM"
+ mtype: "<enum \'VariableAggregation\'>"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.-variable-scope.pbtxt b/tensorflow/tools/api/golden/tensorflow.-variable-scope.pbtxt
index 8e539069da..ec1f72453f 100644
--- a/tensorflow/tools/api/golden/tensorflow.-variable-scope.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-variable-scope.pbtxt
@@ -56,7 +56,7 @@ tf_class {
}
member_method {
name: "get_variable"
- argspec: "args=[\'self\', \'var_store\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'reuse\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'var_store\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'reuse\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "global_variables"
diff --git a/tensorflow/tools/api/golden/tensorflow.-variable-synchronization.pbtxt b/tensorflow/tools/api/golden/tensorflow.-variable-synchronization.pbtxt
new file mode 100644
index 0000000000..7589bb2888
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.-variable-synchronization.pbtxt
@@ -0,0 +1,20 @@
+path: "tensorflow.VariableSynchronization"
+tf_class {
+ is_instance: "<enum \'VariableSynchronization\'>"
+ member {
+ name: "AUTO"
+ mtype: "<enum \'VariableSynchronization\'>"
+ }
+ member {
+ name: "NONE"
+ mtype: "<enum \'VariableSynchronization\'>"
+ }
+ member {
+ name: "ON_READ"
+ mtype: "<enum \'VariableSynchronization\'>"
+ }
+ member {
+ name: "ON_WRITE"
+ mtype: "<enum \'VariableSynchronization\'>"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt
index 9b38d0e2fe..5470164a5b 100644
--- a/tensorflow/tools/api/golden/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.pbtxt
@@ -261,10 +261,18 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
+ name: "VariableAggregation"
+ mtype: "<class \'enum.EnumMeta\'>"
+ }
+ member {
name: "VariableScope"
mtype: "<type \'type\'>"
}
member {
+ name: "VariableSynchronization"
+ mtype: "<class \'enum.EnumMeta\'>"
+ }
+ member {
name: "WholeFileReader"
mtype: "<type \'type\'>"
}
@@ -1146,7 +1154,7 @@ tf_module {
}
member_method {
name: "get_local_variable"
- argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'synchronization\', \'aggregation\', \'custom_getter\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'None\', \'True\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\', \'None\'], "
}
member_method {
name: "get_seed"
@@ -1162,7 +1170,7 @@ tf_module {
}
member_method {
name: "get_variable"
- argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "get_variable_scope"