about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
diff options
context:
space:
mode:
author: Pavithra Vijay <psv@google.com> 2018-06-29 18:02:18 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-06-29 18:05:25 -0700
commit: c290930ec1beacbcac414b43b3367dd44ffbd303 (patch)
tree: b1136e7c32718a6f1f9ebfde3073c88546078de6 /tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
parent: a520735d205ca5678fc8a371ea1add00413907fe (diff)
Add `synchronization` and `aggregation` args to get_variable(). These args will be used for distributed variables.
Add Enum `VariableSynchronization` with values for `synchronization`: AUTO, UNREPLICATED, ON_WRITE, ON_READ. Add Enum `VariableAggregation` with values for `aggregation`: NONE, SUM, MEAN. Replace all the aggregation method strings in distribution strategy with the enum values. Update Mirrored strategy to use these parameters to decide whether a variable should be Mirrored or TowerLocal. Update the different distribution strategy value types to use the `VariableAggregation` Enum. PiperOrigin-RevId: 202736077
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py')
-rw-r--r-- tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py 180
1 file changed, 147 insertions, 33 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index 8d474124b7..c02817f461 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -114,7 +114,10 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase):
dist = self._get_distribution_strategy()
with dist.scope():
result = dist.call_for_each_tower(run_fn, dist.worker_device_index)
- reduced = dist.reduce("sum", result, destinations="/device:CPU:0")
+ reduced = dist.reduce(
+ variable_scope.VariableAggregation.SUM,
+ result,
+ destinations="/device:CPU:0")
unwrapped = dist.unwrap(reduced)
self.assertEqual(1, len(unwrapped))
expected = sum(range(len(dist.worker_devices)))
@@ -132,8 +135,10 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase):
dist = mirrored_strategy.MirroredStrategy(devices)
with dist.scope():
- reduced = dist.reduce("sum", 1.0, destinations=["/device:CPU:0",
- "/device:GPU:0"])
+ reduced = dist.reduce(
+ variable_scope.VariableAggregation.SUM,
+ 1.0,
+ destinations=["/device:CPU:0", "/device:GPU:0"])
unwrapped = dist.unwrap(reduced)
self.assertEqual(2, len(unwrapped))
self.assertEqual(1.0, self.evaluate(unwrapped[0]))
@@ -284,18 +289,68 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
self.assertEquals("common/dense" + suffix + "/bias:0", bias.name)
@test_util.run_in_graph_and_eager_modes(config=config)
+ def testWithVariableAndVariableScope(self):
+ self._skip_eager_if_gpus_less_than(1)
+
+ def model_fn():
+ v0 = variable_scope.variable(1.0, name="var0", aggregation=None)
+ with variable_scope.variable_scope("common"):
+ v1 = variable_scope.variable(1.0, name="var1")
+ # This will pause the current thread, and execute the other thread.
+ distribute_lib.get_tower_context().merge_call(lambda _: _)
+ v2 = variable_scope.variable(
+ 1.0,
+ name="var2",
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ aggregation=variable_scope.VariableAggregation.SUM)
+ v3 = variable_scope.variable(
+ 1.0,
+ name="var3",
+ synchronization=variable_scope.VariableSynchronization.ON_WRITE,
+ aggregation=variable_scope.VariableAggregation.MEAN)
+
+ return v0, v1, v2, v3
+
+ devices = ["/device:CPU:0", "/device:GPU:0"]
+ dist = mirrored_strategy.MirroredStrategy(devices)
+ with dist.scope():
+ v = variable_scope.variable(1.0, name="var-main0")
+ self.assertEquals("var-main0:0", v.name)
+
+ result = dist.call_for_each_tower(model_fn, run_concurrently=False)
+ self.assertEquals(4, len(result))
+ v0, v1, v2, v3 = result
+ self.assertIsInstance(v0, values.MirroredVariable)
+ self.assertEquals("var0:0", v0.name)
+ self.assertIsInstance(v1, values.MirroredVariable)
+ self.assertEquals("common/var1:0", v1.name)
+ self.assertIsInstance(v2, values.TowerLocalVariable)
+ self.assertEquals("common/var2:0", v2.name)
+ self.assertEquals(variable_scope.VariableAggregation.SUM, v2.aggregation)
+ self.assertIsInstance(v3, values.MirroredVariable)
+ self.assertEquals("common/var3:0", v3.name)
+ self.assertEquals(variable_scope.VariableAggregation.MEAN, v3.aggregation)
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
def testWithGetVariableAndVariableScope(self):
self._skip_eager_if_gpus_less_than(1)
def model_fn():
- v0 = variable_scope.get_variable("var-thread0", [1])
+ v0 = variable_scope.get_variable("var0", [1])
with variable_scope.variable_scope("common"):
- v1 = variable_scope.get_variable("var-thread1", [1])
+ v1 = variable_scope.get_variable("var1", [1])
# This will pause the current thread, and execute the other thread.
distribute_lib.get_tower_context().merge_call(lambda _: _)
- v2 = variable_scope.get_variable("var-thread2", [1])
+ v2 = variable_scope.get_variable(
+ "var2", [1],
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ aggregation=variable_scope.VariableAggregation.SUM)
+ v3 = variable_scope.get_variable(
+ "var3", [1],
+ synchronization=variable_scope.VariableSynchronization.ON_WRITE,
+ aggregation=variable_scope.VariableAggregation.MEAN)
- return v0, v1, v2
+ return v0, v1, v2, v3
devices = ["/device:CPU:0", "/device:GPU:0"]
dist = mirrored_strategy.MirroredStrategy(devices)
@@ -305,14 +360,78 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
self.assertEquals("main/var-main0:0", v.name)
result = dist.call_for_each_tower(model_fn, run_concurrently=False)
- self.assertEquals(3, len(result))
- v0, v1, v2 = result
+ self.assertEquals(4, len(result))
+ v0, v1, v2, v3 = result
self.assertIsInstance(v0, values.MirroredVariable)
- self.assertEquals("main/var-thread0:0", v0.name)
+ self.assertEquals("main/var0:0", v0.name)
self.assertIsInstance(v1, values.MirroredVariable)
- self.assertEquals("main/common/var-thread1:0", v1.name)
- self.assertIsInstance(v2, values.MirroredVariable)
- self.assertEquals("main/common/var-thread2:0", v2.name)
+ self.assertEquals("main/common/var1:0", v1.name)
+ self.assertIsInstance(v2, values.TowerLocalVariable)
+ self.assertEquals("main/common/var2:0", v2.name)
+ self.assertEquals(variable_scope.VariableAggregation.SUM,
+ v2.aggregation)
+ self.assertIsInstance(v3, values.MirroredVariable)
+ self.assertEquals("main/common/var3:0", v3.name)
+ self.assertEquals(variable_scope.VariableAggregation.MEAN,
+ v3.aggregation)
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testInvalidSynchronizationWithGetVariable(self):
+ self._skip_eager_if_gpus_less_than(1)
+ devices = ["/device:CPU:0", "/device:GPU:0"]
+ dist = mirrored_strategy.MirroredStrategy(devices)
+ with dist.scope():
+ with self.assertRaisesRegexp(
+ ValueError, "`NONE` variable synchronization mode is not "
+ "supported with `Mirrored` distribution strategy. Please change "
+ "the `synchronization` for variable: v"):
+ variable_scope.get_variable(
+ "v", [1],
+ synchronization=variable_scope.VariableSynchronization.NONE)
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testInvalidSynchronizationWithVariable(self):
+ self._skip_eager_if_gpus_less_than(1)
+ devices = ["/device:CPU:0", "/device:GPU:0"]
+ dist = mirrored_strategy.MirroredStrategy(devices)
+ with dist.scope():
+ with self.assertRaisesRegexp(
+ ValueError, "`NONE` variable synchronization mode is not "
+ "supported with `Mirrored` distribution strategy. Please change "
+ "the `synchronization` for variable: v"):
+ variable_scope.variable(
+ 1.0,
+ name="v",
+ synchronization=variable_scope.VariableSynchronization.NONE)
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testInvalidAggregationWithGetVariable(self):
+ self._skip_eager_if_gpus_less_than(1)
+ devices = ["/device:CPU:0", "/device:GPU:0"]
+ dist = mirrored_strategy.MirroredStrategy(devices)
+ with dist.scope():
+ with self.assertRaisesRegexp(
+ ValueError, "Invalid variable aggregation mode: invalid for "
+ "variable: v"):
+ variable_scope.get_variable(
+ "v", [1],
+ synchronization=variable_scope.VariableSynchronization.ON_WRITE,
+ aggregation="invalid")
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testInvalidAggregationWithVariable(self):
+ self._skip_eager_if_gpus_less_than(1)
+ devices = ["/device:CPU:0", "/device:GPU:0"]
+ dist = mirrored_strategy.MirroredStrategy(devices)
+ with dist.scope():
+ with self.assertRaisesRegexp(
+ ValueError, "Invalid variable aggregation mode: invalid for "
+ "variable: v"):
+ variable_scope.variable(
+ 1.0,
+ name="v",
+ synchronization=variable_scope.VariableSynchronization.ON_WRITE,
+ aggregation="invalid")
@test_util.run_in_graph_and_eager_modes(config=config)
def testThreeDevices(self):
@@ -362,9 +481,11 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
def model_fn(device_id):
tower_context = distribute_lib.get_tower_context()
- with tower_context.tower_local_var_scope("sum"):
+ with tower_context.tower_local_var_scope(
+ variable_scope.VariableAggregation.SUM):
v_sum = variable_scope.variable(1.0)
- with tower_context.tower_local_var_scope("mean"):
+ with tower_context.tower_local_var_scope(
+ variable_scope.VariableAggregation.MEAN):
v_mean = variable_scope.variable(4.0)
self.assertTrue(isinstance(v_sum, values.TowerLocalVariable))
self.assertTrue(isinstance(v_mean, values.TowerLocalVariable))
@@ -569,7 +690,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
def model_fn():
tower_context = distribute_lib.get_tower_context()
- with tower_context.tower_local_var_scope("sum"):
+ with tower_context.tower_local_var_scope(
+ variable_scope.VariableAggregation.SUM):
v_sum = variable_scope.variable(1.0)
self.assertTrue(isinstance(v_sum, values.TowerLocalVariable))
return v_sum
@@ -642,7 +764,8 @@ class MirroredVariableUpdateTest(test.TestCase):
# aggregation type.
self._skip_eager_if_gpus_less_than(1)
def var_fn():
- v = variable_scope.variable(1.0, name="foo")
+ v = variable_scope.variable(
+ 1.0, name="foo", aggregation=variable_scope.VariableAggregation.SUM)
return v
dist = mirrored_strategy.MirroredStrategy(
@@ -650,9 +773,6 @@ class MirroredVariableUpdateTest(test.TestCase):
with dist.scope():
mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
- # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the
- # aggregation method.
- mirrored_var._aggregation_method = "sum"
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.evaluate(variables.global_variables_initializer())
@@ -661,7 +781,7 @@ class MirroredVariableUpdateTest(test.TestCase):
with self.assertRaisesRegexp(
ValueError, "A non PerDevice value cannot be reduced with the given "
- "method_string."):
+ "aggregation."):
self.evaluate(dist.unwrap(dist.call_for_each_tower(model_fn)))
@test_util.run_in_graph_and_eager_modes(config=config)
@@ -685,16 +805,14 @@ class MirroredVariableUpdateTest(test.TestCase):
def testAssignMirroredVarTowerContext(self):
self._skip_eager_if_gpus_less_than(1)
def var_fn():
- return variable_scope.variable(1.0, name="foo")
+ return variable_scope.variable(
+ 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)
dist = mirrored_strategy.MirroredStrategy(
["/device:GPU:0", "/device:CPU:0"])
with dist.scope():
mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
- # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the
- # aggregation method.
- mirrored_var._aggregation_method = "mean"
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.evaluate(variables.global_variables_initializer())
self.assertEquals(1.0, self.evaluate(mirrored_var))
@@ -729,16 +847,14 @@ class MirroredVariableUpdateTest(test.TestCase):
def testAssignAddMirroredVarTowerContext(self):
self._skip_eager_if_gpus_less_than(1)
def var_fn():
- return variable_scope.variable(1.0, name="foo")
+ return variable_scope.variable(
+ 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)
dist = mirrored_strategy.MirroredStrategy(
["/device:GPU:0", "/device:CPU:0"])
with dist.scope():
mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
- # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the
- # aggregation method.
- mirrored_var._aggregation_method = "mean"
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.evaluate(variables.global_variables_initializer())
self.assertEquals(1.0, self.evaluate(mirrored_var))
@@ -773,16 +889,14 @@ class MirroredVariableUpdateTest(test.TestCase):
def testAssignSubMirroredVarTowerContext(self):
self._skip_eager_if_gpus_less_than(1)
def var_fn():
- return variable_scope.variable(5.0, name="foo")
+ return variable_scope.variable(
+ 5.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)
dist = mirrored_strategy.MirroredStrategy(
["/device:GPU:0", "/device:CPU:0"])
with dist.scope():
mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
- # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the
- # aggregation method.
- mirrored_var._aggregation_method = "mean"
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.evaluate(variables.global_variables_initializer())
self.assertEquals(5.0, self.evaluate(mirrored_var))