diff options
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py | 102 |
1 files changed, 100 insertions, 2 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index a12ff662db..830681a4ce 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -38,6 +38,7 @@ from tensorflow.python.layers import core from tensorflow.python.ops import math_ops from tensorflow.python.ops import rnn from tensorflow.python.ops import rnn_cell_impl +from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import device_util @@ -128,6 +129,25 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): expected = sum(range(len(dist.worker_devices))) self.assertEqual(expected, self.evaluate(unwrapped[0])) + @test_util.run_in_graph_and_eager_modes + def testReduceOnlyFirstTowerUpdates(self): + if not GPU_TEST: + self.skipTest("Not GPU test") + + def run_fn(device_id): + return constant_op.constant(3 + 5 * device_id) + + dist = self._get_distribution_strategy() + with dist.scope(): + result = dist.call_for_each_tower(run_fn, dist.worker_device_index) + reduced = dist.reduce( + variable_scope.VariableAggregation.ONLY_FIRST_TOWER, + result, + destinations="/device:CPU:0") + unwrapped = dist.unwrap(reduced) + self.assertEqual(1, len(unwrapped)) + self.assertEqual(3, self.evaluate(unwrapped[0])) + @test_util.run_in_graph_and_eager_modes() def testReduceToMultipleDestinations(self): if not GPU_TEST: @@ -384,6 +404,84 @@ class MirroredStrategyVariableCreationTest(test.TestCase): v3.aggregation) @test_util.run_in_graph_and_eager_modes(config=config) + def testOnlyFirstTowerUpdatesVariables(self): + self._skip_eager_if_gpus_less_than(1) + + def create_fn(): + aggregation = variable_scope.VariableAggregation.ONLY_FIRST_TOWER + v0 = variable_scope.variable( + 2.0, + name="on_read", + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=aggregation) + v1 = variable_scope.variable( + 3.0, + name="on_write", + synchronization=variable_scope.VariableSynchronization.ON_WRITE, + aggregation=aggregation) + return v0, v1 + + devices = ["/device:GPU:0", "/device:CPU:0"] + dist = mirrored_strategy.MirroredStrategy(devices) + with dist.scope(): + v0, v1 = dist.call_for_each_tower(create_fn, run_concurrently=False) + self.evaluate(v0.initializer) + self.assertEqual(2.0, self.evaluate(v0.get(devices[0]))) + self.assertEqual(2.0, self.evaluate(v0.get(devices[1]))) + self.assertEqual(2.0, self.evaluate(dist.read_var(v0))) + self.evaluate(v1.initializer) + self.assertEqual(3.0, self.evaluate(v1.get(devices[0]))) + self.assertEqual(3.0, self.evaluate(v1.get(devices[1]))) + self.assertEqual(3.0, self.evaluate(dist.read_var(v1))) + + # Update using the assign_add member function. + def update_member_fn(device_id): + update0 = v0.assign_add(5.0 * (device_id + 1)) + update1 = v1.assign_add(7.0 * (device_id + 1)) + return update0, update1 + + update0a, update1a = dist.call_for_each_tower( + update_member_fn, dist.worker_device_index, run_concurrently=False) + + # Update "sync on read" variable. + self.evaluate(dist.group(update0a)) + self.assertEqual(2.0 + 5.0, self.evaluate(v0.get(devices[0]))) + # Writes are not synchronized for "sync on read" variables, + # so device[1] can end up with a different value. + self.assertEqual(2.0 + 2*5.0, self.evaluate(v0.get(devices[1]))) + # Always reads from device 0. + self.assertEqual(2.0 + 5.0, self.evaluate(dist.read_var(v0))) + + # Update "sync on write" variable. + self.evaluate(dist.group(update1a)) + self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[0]))) + # Writes are synchronized for v1, only the argument to assign_add on + # device[0] is used. + self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[1]))) + self.assertEqual(3.0 + 7.0, self.evaluate(dist.read_var(v1))) + + # Update using state_ops.assign_add global function. + def update_state_ops_fn(device_id): + update0 = state_ops.assign_add(v0, 11.0 * (device_id + 1)) + update1 = state_ops.assign_add(v1, 13.0 * (device_id + 1)) + return update0, update1 + + update0b, update1b = dist.call_for_each_tower( + update_state_ops_fn, dist.worker_device_index, run_concurrently=False) + self.evaluate(dist.group(update0b)) + + # Update "sync on read" variable. + self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate(v0.get(devices[0]))) + self.assertEqual(2.0 + 2*5.0 + 2*11.0, self.evaluate(v0.get(devices[1]))) + self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate(dist.read_var(v0))) + + # Update "sync on write" variable. + self.evaluate(dist.group(update1b)) + self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[0]))) + self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[1]))) + self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(dist.read_var(v1))) + + @test_util.run_in_graph_and_eager_modes(config=config) def testNoneSynchronizationWithGetVariable(self): self._skip_eager_if_gpus_less_than(1) devices = ["/device:CPU:0", "/device:GPU:0"] @@ -804,8 +902,8 @@ class MirroredVariableUpdateTest(test.TestCase): return mirrored_var.assign(5.0) with self.assertRaisesRegexp( - ValueError, "A non-DistributedValues value cannot be reduced with " - "the given aggregation."): + ValueError, "A non-DistributedValues value 5.0 cannot be reduced " + "with the given aggregation VariableAggregation.SUM."): self.evaluate(dist.unwrap(dist.call_for_each_tower(model_fn))) @test_util.run_in_graph_and_eager_modes(config=config) |