From c290930ec1beacbcac414b43b3367dd44ffbd303 Mon Sep 17 00:00:00 2001 From: Pavithra Vijay Date: Fri, 29 Jun 2018 18:02:18 -0700 Subject: Add `synchronization` and `aggregation` args to get_variable(). These args will be used for distributed variables. Add Enum `VariableSynchronization` with values for `synchronization`: AUTO, UNREPLICATED, ON_WRITE, ON_READ Add Enum `VariableAggregation` with values for `aggregation`: NONE, SUM, MEAN. Replace all the aggregation methods strings in distribution strategy to the enum values. Update Mirrored strategy to use these parameters to decide on whether a variable should be Mirrored or TowerLocal. Update different distribution strategy value types to use the `VariableAggregation` Enum PiperOrigin-RevId: 202736077 --- .../python/mirrored_strategy_multigpu_test.py | 180 +++++++++++++++++---- 1 file changed, 147 insertions(+), 33 deletions(-) (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py') diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index 8d474124b7..c02817f461 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -114,7 +114,10 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): dist = self._get_distribution_strategy() with dist.scope(): result = dist.call_for_each_tower(run_fn, dist.worker_device_index) - reduced = dist.reduce("sum", result, destinations="/device:CPU:0") + reduced = dist.reduce( + variable_scope.VariableAggregation.SUM, + result, + destinations="/device:CPU:0") unwrapped = dist.unwrap(reduced) self.assertEqual(1, len(unwrapped)) expected = sum(range(len(dist.worker_devices))) @@ -132,8 +135,10 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): dist = mirrored_strategy.MirroredStrategy(devices) with dist.scope(): - reduced = dist.reduce("sum", 1.0, destinations=["/device:CPU:0", - "/device:GPU:0"]) + reduced = dist.reduce( + variable_scope.VariableAggregation.SUM, + 1.0, + destinations=["/device:CPU:0", "/device:GPU:0"]) unwrapped = dist.unwrap(reduced) self.assertEqual(2, len(unwrapped)) self.assertEqual(1.0, self.evaluate(unwrapped[0])) @@ -283,19 +288,69 @@ class MirroredStrategyVariableCreationTest(test.TestCase): self.assertIsInstance(bias, values.MirroredVariable) self.assertEquals("common/dense" + suffix + "/bias:0", bias.name) + @test_util.run_in_graph_and_eager_modes(config=config) + def testWithVariableAndVariableScope(self): + self._skip_eager_if_gpus_less_than(1) + + def model_fn(): + v0 = variable_scope.variable(1.0, name="var0", aggregation=None) + with variable_scope.variable_scope("common"): + v1 = variable_scope.variable(1.0, name="var1") + # This will pause the current thread, and execute the other thread. + distribute_lib.get_tower_context().merge_call(lambda _: _) + v2 = variable_scope.variable( + 1.0, + name="var2", + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=variable_scope.VariableAggregation.SUM) + v3 = variable_scope.variable( + 1.0, + name="var3", + synchronization=variable_scope.VariableSynchronization.ON_WRITE, + aggregation=variable_scope.VariableAggregation.MEAN) + + return v0, v1, v2, v3 + + devices = ["/device:CPU:0", "/device:GPU:0"] + dist = mirrored_strategy.MirroredStrategy(devices) + with dist.scope(): + v = variable_scope.variable(1.0, name="var-main0") + self.assertEquals("var-main0:0", v.name) + + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + self.assertEquals(4, len(result)) + v0, v1, v2, v3 = result + self.assertIsInstance(v0, values.MirroredVariable) + self.assertEquals("var0:0", v0.name) + self.assertIsInstance(v1, values.MirroredVariable) + self.assertEquals("common/var1:0", v1.name) + self.assertIsInstance(v2, values.TowerLocalVariable) + self.assertEquals("common/var2:0", v2.name) + self.assertEquals(variable_scope.VariableAggregation.SUM, v2.aggregation) + self.assertIsInstance(v3, values.MirroredVariable) + self.assertEquals("common/var3:0", v3.name) + self.assertEquals(variable_scope.VariableAggregation.MEAN, v3.aggregation) + @test_util.run_in_graph_and_eager_modes(config=config) def testWithGetVariableAndVariableScope(self): self._skip_eager_if_gpus_less_than(1) def model_fn(): - v0 = variable_scope.get_variable("var-thread0", [1]) + v0 = variable_scope.get_variable("var0", [1]) with variable_scope.variable_scope("common"): - v1 = variable_scope.get_variable("var-thread1", [1]) + v1 = variable_scope.get_variable("var1", [1]) # This will pause the current thread, and execute the other thread. distribute_lib.get_tower_context().merge_call(lambda _: _) - v2 = variable_scope.get_variable("var-thread2", [1]) + v2 = variable_scope.get_variable( + "var2", [1], + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=variable_scope.VariableAggregation.SUM) + v3 = variable_scope.get_variable( + "var3", [1], + synchronization=variable_scope.VariableSynchronization.ON_WRITE, + aggregation=variable_scope.VariableAggregation.MEAN) - return v0, v1, v2 + return v0, v1, v2, v3 devices = ["/device:CPU:0", "/device:GPU:0"] dist = mirrored_strategy.MirroredStrategy(devices) @@ -305,14 +360,78 @@ class MirroredStrategyVariableCreationTest(test.TestCase): self.assertEquals("main/var-main0:0", v.name) result = dist.call_for_each_tower(model_fn, run_concurrently=False) - self.assertEquals(3, len(result)) - v0, v1, v2 = result + self.assertEquals(4, len(result)) + v0, v1, v2, v3 = result self.assertIsInstance(v0, values.MirroredVariable) - self.assertEquals("main/var-thread0:0", v0.name) + self.assertEquals("main/var0:0", v0.name) self.assertIsInstance(v1, values.MirroredVariable) - self.assertEquals("main/common/var-thread1:0", v1.name) - self.assertIsInstance(v2, values.MirroredVariable) - self.assertEquals("main/common/var-thread2:0", v2.name) + self.assertEquals("main/common/var1:0", v1.name) + self.assertIsInstance(v2, values.TowerLocalVariable) + self.assertEquals("main/common/var2:0", v2.name) + self.assertEquals(variable_scope.VariableAggregation.SUM, + v2.aggregation) + self.assertIsInstance(v3, values.MirroredVariable) + self.assertEquals("main/common/var3:0", v3.name) + self.assertEquals(variable_scope.VariableAggregation.MEAN, + v3.aggregation) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testInvalidSynchronizationWithGetVariable(self): + self._skip_eager_if_gpus_less_than(1) + devices = ["/device:CPU:0", "/device:GPU:0"] + dist = mirrored_strategy.MirroredStrategy(devices) + with dist.scope(): + with self.assertRaisesRegexp( + ValueError, "`NONE` variable synchronization mode is not " + "supported with `Mirrored` distribution strategy. Please change " + "the `synchronization` for variable: v"): + variable_scope.get_variable( + "v", [1], + synchronization=variable_scope.VariableSynchronization.NONE) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testInvalidSynchronizationWithVariable(self): + self._skip_eager_if_gpus_less_than(1) + devices = ["/device:CPU:0", "/device:GPU:0"] + dist = mirrored_strategy.MirroredStrategy(devices) + with dist.scope(): + with self.assertRaisesRegexp( + ValueError, "`NONE` variable synchronization mode is not " + "supported with `Mirrored` distribution strategy. Please change " + "the `synchronization` for variable: v"): + variable_scope.variable( + 1.0, + name="v", + synchronization=variable_scope.VariableSynchronization.NONE) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testInvalidAggregationWithGetVariable(self): + self._skip_eager_if_gpus_less_than(1) + devices = ["/device:CPU:0", "/device:GPU:0"] + dist = mirrored_strategy.MirroredStrategy(devices) + with dist.scope(): + with self.assertRaisesRegexp( + ValueError, "Invalid variable aggregation mode: invalid for " + "variable: v"): + variable_scope.get_variable( + "v", [1], + synchronization=variable_scope.VariableSynchronization.ON_WRITE, + aggregation="invalid") + + @test_util.run_in_graph_and_eager_modes(config=config) + def testInvalidAggregationWithVariable(self): + self._skip_eager_if_gpus_less_than(1) + devices = ["/device:CPU:0", "/device:GPU:0"] + dist = mirrored_strategy.MirroredStrategy(devices) + with dist.scope(): + with self.assertRaisesRegexp( + ValueError, "Invalid variable aggregation mode: invalid for " + "variable: v"): + variable_scope.variable( + 1.0, + name="v", + synchronization=variable_scope.VariableSynchronization.ON_WRITE, + aggregation="invalid") @test_util.run_in_graph_and_eager_modes(config=config) def testThreeDevices(self): @@ -362,9 +481,11 @@ class MirroredStrategyVariableCreationTest(test.TestCase): def model_fn(device_id): tower_context = distribute_lib.get_tower_context() - with tower_context.tower_local_var_scope("sum"): + with tower_context.tower_local_var_scope( + variable_scope.VariableAggregation.SUM): v_sum = variable_scope.variable(1.0) - with tower_context.tower_local_var_scope("mean"): + with tower_context.tower_local_var_scope( + variable_scope.VariableAggregation.MEAN): v_mean = variable_scope.variable(4.0) self.assertTrue(isinstance(v_sum, values.TowerLocalVariable)) self.assertTrue(isinstance(v_mean, values.TowerLocalVariable)) @@ -569,7 +690,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): def model_fn(): tower_context = distribute_lib.get_tower_context() - with tower_context.tower_local_var_scope("sum"): + with tower_context.tower_local_var_scope( + variable_scope.VariableAggregation.SUM): v_sum = variable_scope.variable(1.0) self.assertTrue(isinstance(v_sum, values.TowerLocalVariable)) return v_sum @@ -642,7 +764,8 @@ class MirroredVariableUpdateTest(test.TestCase): # aggregation type. self._skip_eager_if_gpus_less_than(1) def var_fn(): - v = variable_scope.variable(1.0, name="foo") + v = variable_scope.variable( + 1.0, name="foo", aggregation=variable_scope.VariableAggregation.SUM) return v dist = mirrored_strategy.MirroredStrategy( @@ -650,9 +773,6 @@ class MirroredVariableUpdateTest(test.TestCase): with dist.scope(): mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) - # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the - # aggregation method. - mirrored_var._aggregation_method = "sum" self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) @@ -661,7 +781,7 @@ class MirroredVariableUpdateTest(test.TestCase): with self.assertRaisesRegexp( ValueError, "A non PerDevice value cannot be reduced with the given " - "method_string."): + "aggregation."): self.evaluate(dist.unwrap(dist.call_for_each_tower(model_fn))) @test_util.run_in_graph_and_eager_modes(config=config) @@ -685,16 +805,14 @@ class MirroredVariableUpdateTest(test.TestCase): def testAssignMirroredVarTowerContext(self): self._skip_eager_if_gpus_less_than(1) def var_fn(): - return variable_scope.variable(1.0, name="foo") + return variable_scope.variable( + 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) dist = mirrored_strategy.MirroredStrategy( ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) - # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the - # aggregation method. - mirrored_var._aggregation_method = "mean" self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) self.assertEquals(1.0, self.evaluate(mirrored_var)) @@ -729,16 +847,14 @@ class MirroredVariableUpdateTest(test.TestCase): def testAssignAddMirroredVarTowerContext(self): self._skip_eager_if_gpus_less_than(1) def var_fn(): - return variable_scope.variable(1.0, name="foo") + return variable_scope.variable( + 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) dist = mirrored_strategy.MirroredStrategy( ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) - # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the - # aggregation method. - mirrored_var._aggregation_method = "mean" self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) self.assertEquals(1.0, self.evaluate(mirrored_var)) @@ -773,16 +889,14 @@ class MirroredVariableUpdateTest(test.TestCase): def testAssignSubMirroredVarTowerContext(self): self._skip_eager_if_gpus_less_than(1) def var_fn(): - return variable_scope.variable(5.0, name="foo") + return variable_scope.variable( + 5.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) dist = mirrored_strategy.MirroredStrategy( ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) - # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the - # aggregation method. - mirrored_var._aggregation_method = "mean" self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) self.assertEquals(5.0, self.evaluate(mirrored_var)) -- cgit v1.2.3