author     Anjali Sridhar <anjalisridhar@google.com>        2018-08-06 19:19:14 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-08-06 19:23:02 -0700
commit     576e0000d8aa227163ff8df9d2b9f2e656191c76 (patch)
tree       e7de2c4f87d1553a7b03417c9b48cdf7bc8baa8c /tensorflow/contrib/distribute
parent     c26c9e6caa3f1e3e5027031db61ebde9bc3ba706 (diff)
Add comments to MirroredStrategy's reduce function. Also add more unit tests.
PiperOrigin-RevId: 207649000
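
The behavior the new comments document amounts to a small decision procedure. The following is a minimal standalone sketch of that logic, not the TensorFlow implementation: `reduce_non_distributed_value`, the local `VariableAggregation` enum, and the `broadcast` callback are simplified stand-ins for `_reduce_non_distributed_value`, `variable_scope.VariableAggregation`, and `distribution.broadcast(value, destinations)` in the diff below.

```python
from enum import Enum


class VariableAggregation(Enum):
  NONE = 0
  SUM = 1
  MEAN = 2


def reduce_non_distributed_value(aggregation, value, num_worker_devices,
                                 broadcast):
  """Reduces a plain value that is neither PerDevice nor Mirrored."""
  # A bare 0 (e.g. the gradient of a variable unused on some tower) needs
  # no aggregation at all.
  if value == 0:
    return 0
  # MEAN implies the same value already sits on every tower, so the "mean"
  # is just that value, broadcast to the destinations.
  if aggregation == VariableAggregation.MEAN:
    return broadcast(value)
  # Summing identical per-tower values is not clearly defined, so any other
  # aggregation of a single value is rejected when multiple devices exist.
  if num_worker_devices != 1:
    raise ValueError("A non-DistributedValues value cannot be reduced with "
                     "the given aggregation.")
  return broadcast(value)


# With MEAN, the single value 3.0 is simply broadcast unchanged:
assert reduce_non_distributed_value(
    VariableAggregation.MEAN, 3.0, 2, lambda v: v) == 3.0
```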
Diffstat (limited to 'tensorflow/contrib/distribute')

 tensorflow/contrib/distribute/python/mirrored_strategy.py               | 11 +
 tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py | 69 +
 2 files changed, 80 insertions(+), 0 deletions(-)
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 01b456d0d4..c5d6e978e7 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -186,12 +186,20 @@ def _reduce_non_distributed_value(distribution, aggregation, value,
     raise ValueError("You are passing a `DistributedValue` to "
                      "`_reduce_non_distributed_value`, which is not allowed.")
 
+  # If the same value is present on all towers then the PerDevice value will
+  # be a single value. We also handle the case when `value` is a single value
+  # and equal to 0.
   if value == 0:
     return 0
+  # If the aggregation type is MEAN, then this essentially means that the same
+  # value should be on all destinations.
   if aggregation == variable_scope.VariableAggregation.MEAN:
     return distribution.broadcast(value, destinations)
 
   cross_tower_ops_lib.validate_destinations(destinations)
+  # We do not support an aggregation type of SUM if the value is the same across
+  # all towers. We call this as part of assign functions for MirroredVariables
+  # and summing up identical values across towers is not clearly defined.
   if (len(distribution.worker_devices) != 1 or
       not cross_tower_ops_lib.check_destinations(destinations)):
     raise ValueError("A non-DistributedValues value cannot be reduced with the "
@@ -386,6 +394,9 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
   def _reduce(self, aggregation, value, destinations):
     assert not isinstance(value, values.Mirrored)
     if not isinstance(value, values.DistributedValues):
+      # This function handles reducing values that are not PerDevice or Mirrored
+      # values. For example, the same value could be present on all towers in
+      # which case `value` would be a single value or value could be 0.
       return _reduce_non_distributed_value(self, aggregation, value,
                                            destinations)
     return self._get_cross_tower_ops().reduce(
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index e5e291a71f..e064cfe37d 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -842,6 +842,29 @@ class MirroredVariableUpdateTest(test.TestCase):
       self.assertEquals(0.5, self.evaluate(mirrored_var))
 
   @test_util.run_in_graph_and_eager_modes(config=config)
+  def testAssignMirroredVarTowerContextWithSingleValue(self):
+    self._skip_eager_if_gpus_less_than(1)
+    def var_fn():
+      return variable_scope.variable(
+          1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)
+
+    dist = mirrored_strategy.MirroredStrategy(
+        ["/device:GPU:0", "/device:CPU:0"])
+
+    with dist.scope():
+      mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+      self.assertIsInstance(mirrored_var, values.MirroredVariable)
+      self.evaluate(variables.global_variables_initializer())
+      self.assertEquals(1.0, self.evaluate(mirrored_var))
+
+      def model_fn():
+        return mirrored_var.assign(5.0)
+
+      self.evaluate(dist.unwrap(dist.call_for_each_tower(
+          model_fn, run_concurrently=False)))
+      self.assertEquals(5.0, self.evaluate(mirrored_var))
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
   def testAssignAddMirroredVarCrossTowerContext(self):
     self._skip_eager_if_gpus_less_than(1)
     def var_fn():
@@ -884,6 +907,29 @@ class MirroredVariableUpdateTest(test.TestCase):
       self.assertEquals(1.5, self.evaluate(mirrored_var))
 
   @test_util.run_in_graph_and_eager_modes(config=config)
+  def testAssignAddMirroredVarTowerContextWithSingleValue(self):
+    self._skip_eager_if_gpus_less_than(1)
+    def var_fn():
+      return variable_scope.variable(
+          1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)
+
+    dist = mirrored_strategy.MirroredStrategy(
+        ["/device:GPU:0", "/device:CPU:0"])
+
+    with dist.scope():
+      mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+      self.assertIsInstance(mirrored_var, values.MirroredVariable)
+      self.evaluate(variables.global_variables_initializer())
+      self.assertEquals(1.0, self.evaluate(mirrored_var))
+
+      def model_fn():
+        return mirrored_var.assign_add(5.0)
+
+      self.evaluate(dist.unwrap(dist.call_for_each_tower(
+          model_fn, run_concurrently=False)))
+      self.assertEquals(6.0, self.evaluate(mirrored_var))
+
+  @test_util.run_in_graph_and_eager_modes(config=config)
   def testAssignSubMirroredVarCrossTowerContext(self):
     self._skip_eager_if_gpus_less_than(1)
     def var_fn():
@@ -925,6 +971,29 @@ class MirroredVariableUpdateTest(test.TestCase):
           model_fn, run_concurrently=False)))
       self.assertEquals(4.5, self.evaluate(mirrored_var))
 
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testAssignSubMirroredVarTowerContextWithSingleValue(self):
+    self._skip_eager_if_gpus_less_than(1)
+    def var_fn():
+      return variable_scope.variable(
+          5.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)
+
+    dist = mirrored_strategy.MirroredStrategy(
+        ["/device:GPU:0", "/device:CPU:0"])
+
+    with dist.scope():
+      mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+      self.assertIsInstance(mirrored_var, values.MirroredVariable)
+      self.evaluate(variables.global_variables_initializer())
+      self.assertEquals(5.0, self.evaluate(mirrored_var))
+
+      def model_fn():
+        return mirrored_var.assign_sub(1.0)
+
+      self.evaluate(dist.unwrap(dist.call_for_each_tower(
+          model_fn, run_concurrently=False)))
+      self.assertEquals(4.0, self.evaluate(mirrored_var))
+
 
 class MirroredAndTowerLocalVariableInitializerTest(test.TestCase):
 
   config = config_pb2.ConfigProto()
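
Seen from the caller's side, the new tests all follow one pattern: assign a single Python scalar to a MirroredVariable from inside a tower context, which routes the value through `_reduce_non_distributed_value`. Below is that pattern condensed from the tests, assuming the 2018-era `tf.contrib.distribute` API this diff targets; the import paths are an assumption mirroring the test file's aliases.

```python
# Condensed from the new tests above (2018-era tf.contrib.distribute API;
# import paths are assumed to match the test file's aliases).
from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.python.ops import variable_scope

dist = mirrored_strategy.MirroredStrategy(["/device:GPU:0", "/device:CPU:0"])


def var_fn():
  # MEAN aggregation is what makes reducing an identical per-tower value
  # well defined: the "mean" is simply that value, broadcast everywhere.
  return variable_scope.variable(
      1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)


with dist.scope():
  mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
  # (In graph mode, run variables.global_variables_initializer() first.)

  def model_fn():
    # 5.0 is neither PerDevice nor Mirrored, so the assignment routes
    # through _reduce_non_distributed_value before updating each copy.
    return mirrored_var.assign(5.0)

  dist.unwrap(dist.call_for_each_tower(model_fn, run_concurrently=False))
```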
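
The new comment also flags the case this path rejects: under SUM, "summing up identical values across towers is not clearly defined." A hypothetical snippet of that rejected case follows (not a test from this commit; same assumed API as the sketch above):

```python
# Hypothetical rejected case (not from this commit): with SUM aggregation
# and more than one device, reducing a single non-distributed value is
# ambiguous, so _reduce_non_distributed_value raises ValueError.
from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.python.ops import variable_scope

dist = mirrored_strategy.MirroredStrategy(["/device:GPU:0", "/device:CPU:0"])


def sum_var_fn():
  return variable_scope.variable(
      1.0, name="bar", aggregation=variable_scope.VariableAggregation.SUM)


with dist.scope():
  sum_var = dist.call_for_each_tower(sum_var_fn, run_concurrently=False)

  def bad_model_fn():
    # There is no well-defined way to "sum" the same 5.0 across towers.
    return sum_var.assign(5.0)

  # Expected to raise ValueError:
  # dist.call_for_each_tower(bad_model_fn, run_concurrently=False)
```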