about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/contrib/distribute
diff options
context:
space:
mode:
author: Anjali Sridhar <anjalisridhar@google.com> 2018-08-06 19:19:14 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-08-06 19:23:02 -0700
commit: 576e0000d8aa227163ff8df9d2b9f2e656191c76 (patch)
tree: e7de2c4f87d1553a7b03417c9b48cdf7bc8baa8c /tensorflow/contrib/distribute
parent: c26c9e6caa3f1e3e5027031db61ebde9bc3ba706 (diff)
Add comments to MirroredStrategy's reduce function. Also add more unit tests.
PiperOrigin-RevId: 207649000
Diffstat (limited to 'tensorflow/contrib/distribute')
-rw-r--r-- tensorflow/contrib/distribute/python/mirrored_strategy.py 11
-rw-r--r-- tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py 69
2 files changed, 80 insertions, 0 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 01b456d0d4..c5d6e978e7 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -186,12 +186,20 @@ def _reduce_non_distributed_value(distribution, aggregation, value,
raise ValueError("You are passing a `DistributedValue` to "
"`_reduce_non_distributed_value`, which is not allowed.")
+ # If the same value is present on all towers then the PerDevice value will
+ # be a single value. We also handle the case when `value` is a single value
+ # and equal to 0.
if value == 0:
return 0
+ # If the aggregation type is MEAN, then this essentially means that the same
+ # value should be on all destinations.
if aggregation == variable_scope.VariableAggregation.MEAN:
return distribution.broadcast(value, destinations)
cross_tower_ops_lib.validate_destinations(destinations)
+ # We do not support an aggregation type of SUM if the value is the same across
+ # all towers. This is called from the assign functions of MirroredVariables,
+ # and summing up identical values across towers is not clearly defined.
if (len(distribution.worker_devices) != 1 or
not cross_tower_ops_lib.check_destinations(destinations)):
raise ValueError("A non-DistributedValues value cannot be reduced with the "
@@ -386,6 +394,9 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
def _reduce(self, aggregation, value, destinations):
assert not isinstance(value, values.Mirrored)
if not isinstance(value, values.DistributedValues):
+ # This function handles reducing values that are not PerDevice or Mirrored
+ # values. For example, the same value could be present on all towers, in
+ # which case `value` would be a single value, or `value` could be 0.
return _reduce_non_distributed_value(self, aggregation, value,
destinations)
return self._get_cross_tower_ops().reduce(
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index e5e291a71f..e064cfe37d 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -842,6 +842,29 @@ class MirroredVariableUpdateTest(test.TestCase):
self.assertEquals(0.5, self.evaluate(mirrored_var))
@test_util.run_in_graph_and_eager_modes(config=config)
+ def testAssignMirroredVarTowerContextWithSingleValue(self):
+ self._skip_eager_if_gpus_less_than(1)
+ def var_fn():
+ return variable_scope.variable(
+ 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:CPU:0"])
+
+ with dist.scope():
+ mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+ self.assertIsInstance(mirrored_var, values.MirroredVariable)
+ self.evaluate(variables.global_variables_initializer())
+ self.assertEquals(1.0, self.evaluate(mirrored_var))
+
+ def model_fn():
+ return mirrored_var.assign(5.0)
+
+ self.evaluate(dist.unwrap(dist.call_for_each_tower(
+ model_fn, run_concurrently=False)))
+ self.assertEquals(5.0, self.evaluate(mirrored_var))
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
def testAssignAddMirroredVarCrossTowerContext(self):
self._skip_eager_if_gpus_less_than(1)
def var_fn():
@@ -884,6 +907,29 @@ class MirroredVariableUpdateTest(test.TestCase):
self.assertEquals(1.5, self.evaluate(mirrored_var))
@test_util.run_in_graph_and_eager_modes(config=config)
+ def testAssignAddMirroredVarTowerContextWithSingleValue(self):
+ self._skip_eager_if_gpus_less_than(1)
+ def var_fn():
+ return variable_scope.variable(
+ 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:CPU:0"])
+
+ with dist.scope():
+ mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+ self.assertIsInstance(mirrored_var, values.MirroredVariable)
+ self.evaluate(variables.global_variables_initializer())
+ self.assertEquals(1.0, self.evaluate(mirrored_var))
+
+ def model_fn():
+ return mirrored_var.assign_add(5.0)
+
+ self.evaluate(dist.unwrap(dist.call_for_each_tower(
+ model_fn, run_concurrently=False)))
+ self.assertEquals(6.0, self.evaluate(mirrored_var))
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
def testAssignSubMirroredVarCrossTowerContext(self):
self._skip_eager_if_gpus_less_than(1)
def var_fn():
@@ -925,6 +971,29 @@ class MirroredVariableUpdateTest(test.TestCase):
model_fn, run_concurrently=False)))
self.assertEquals(4.5, self.evaluate(mirrored_var))
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testAssignSubMirroredVarTowerContextWithSingleValue(self):
+ self._skip_eager_if_gpus_less_than(1)
+ def var_fn():
+ return variable_scope.variable(
+ 5.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:CPU:0"])
+
+ with dist.scope():
+ mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+ self.assertIsInstance(mirrored_var, values.MirroredVariable)
+ self.evaluate(variables.global_variables_initializer())
+ self.assertEquals(5.0, self.evaluate(mirrored_var))
+
+ def model_fn():
+ return mirrored_var.assign_sub(1.0)
+
+ self.evaluate(dist.unwrap(dist.call_for_each_tower(
+ model_fn, run_concurrently=False)))
+ self.assertEquals(4.0, self.evaluate(mirrored_var))
+
class MirroredAndTowerLocalVariableInitializerTest(test.TestCase):
config = config_pb2.ConfigProto()