author	Anjali Sridhar <anjalisridhar@google.com>	2018-06-26 11:25:21 -0700
committer	TensorFlower Gardener <gardener@tensorflow.org>	2018-06-26 11:28:45 -0700
commitbfda539bef38845809e3b0c5930458dc500d505d (patch)
treeca88e6879357c02b08263cf81e197a8dc47efae1 /tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
parentd10213099df42d7138dd7479264e4c987a3d870f (diff)
Enable assign, assign_add, and assign_sub to be called on MirroredVariables in both cross-tower and tower contexts.
PiperOrigin-RevId: 202162272
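In effect, the change allows updates like the following (a minimal sketch assembled from the test code below; mirrored_strategy, variable_scope, and the private _aggregation_method assignment are taken from the tests themselves, the last being a temporary workaround pending the public API referenced in their TODOs):

    dist = mirrored_strategy.MirroredStrategy(["/device:GPU:0", "/device:CPU:0"])
    with dist.scope():
      mirrored_var = dist.call_for_each_tower(
          lambda: variable_scope.variable(1.0, name="foo"),
          run_concurrently=False)
      # Cross tower context: update all copies directly.
      mirrored_var.assign(6.0)
      # Tower context: an aggregation method must be set first; without one
      # the update raises a ValueError (see the tests below).
      mirrored_var._aggregation_method = "mean"
      dist.call_for_each_tower(
          lambda: mirrored_var.assign_add(1.0), run_concurrently=False)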
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py')
-rw-r--r--	tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py	216
1 file changed, 216 insertions(+), 0 deletions(-)
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index 647cf953d7..8d474124b7 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -32,12 +32,14 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.layers import core
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import distribute as distribute_lib
+
GPU_TEST = "test_gpu" in sys.argv[0]
@@ -118,6 +120,24 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase):
expected = sum(range(len(dist.worker_devices)))
self.assertEqual(expected, self.evaluate(unwrapped[0]))
+ @test_util.run_in_graph_and_eager_modes()
+ def testReduceToMultipleDestinations(self):
+ if not GPU_TEST:
+ self.skipTest("Not GPU test")
+
+ devices = ["/device:GPU:0"]
+ if GPU_TEST:
+ self.assertGreater(context.num_gpus(), 0)
+ print(self.id().split(".")[-1], "devices:", ", ".join(devices))
+
+ dist = mirrored_strategy.MirroredStrategy(devices)
+ with dist.scope():
+ reduced = dist.reduce("sum", 1.0, destinations=["/device:CPU:0",
+ "/device:GPU:0"])
+ unwrapped = dist.unwrap(reduced)
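+    # Reducing to two destinations mirrors the result, so unwrap yields one
+    # value per destination device.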
+ self.assertEqual(2, len(unwrapped))
+ self.assertEqual(1.0, self.evaluate(unwrapped[0]))
+
class MirroredStrategyVariableCreationTest(test.TestCase):
@@ -581,5 +601,201 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
self.assertEquals(10.0, self.evaluate(ret_v_sum))
+class MirroredVariableUpdateTest(test.TestCase):
+  # The following tests check assign, assign_add and assign_sub on
+  # MirroredVariables in tower and cross-tower contexts.
+ config = config_pb2.ConfigProto()
+ config.allow_soft_placement = True
+
+ def _skip_eager_if_gpus_less_than(self, num_gpus):
+ if context.num_gpus() < num_gpus and context.executing_eagerly():
+ self.skipTest("Enough GPUs not available for this test in eager mode.")
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testAssignMirroredVarTowerContextWithoutAggregationType(self):
+ # Test that we always have an aggregation type set on the mirrored variable
+ # if we assign to it in tower mode.
+ self._skip_eager_if_gpus_less_than(1)
+ def var_fn():
+ v = variable_scope.variable(1.0, name="foo")
+ return v
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:CPU:0"])
+
+ with dist.scope():
+ mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+ self.assertIsInstance(mirrored_var, values.MirroredVariable)
+ self.evaluate(variables.global_variables_initializer())
+
+ def model_fn():
+ return mirrored_var.assign(5.0)
+
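+    # Assigning in tower context without an aggregation method is ambiguous
+    # across towers, so a ValueError is expected.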
+ with self.assertRaisesRegexp(
+ ValueError, "You must specify an aggregation method to update a "
+ "MirroredVariable in Tower Context."):
+ self.evaluate(dist.unwrap(dist.call_for_each_tower(model_fn)))
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testAssignMirroredVarTowerContextWithSum(self):
+ # Test that we don't reduce a non-per-device value with the "sum"
+ # aggregation type.
+ self._skip_eager_if_gpus_less_than(1)
+ def var_fn():
+ v = variable_scope.variable(1.0, name="foo")
+ return v
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:CPU:0"])
+
+ with dist.scope():
+ mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+ # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the
+ # aggregation method.
+ mirrored_var._aggregation_method = "sum"
+ self.assertIsInstance(mirrored_var, values.MirroredVariable)
+ self.evaluate(variables.global_variables_initializer())
+
+ def model_fn():
+ return mirrored_var.assign(5.0)
+
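+    # The assigned constant 5.0 is not a PerDevice value, so it cannot be
+    # reduced with the "sum" method.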
+ with self.assertRaisesRegexp(
+ ValueError, "A non PerDevice value cannot be reduced with the given "
+ "method_string."):
+ self.evaluate(dist.unwrap(dist.call_for_each_tower(model_fn)))
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testAssignMirroredVarCrossTowerContext(self):
+ self._skip_eager_if_gpus_less_than(1)
+ def var_fn():
+ return variable_scope.variable(1.0, name="foo")
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:CPU:0"])
+
+ with dist.scope():
+ mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+ self.assertIsInstance(mirrored_var, values.MirroredVariable)
+ self.evaluate(variables.global_variables_initializer())
+ self.assertEquals(1.0, self.evaluate(mirrored_var))
+ mirrored_var_result = self.evaluate(mirrored_var.assign(6.0))
+ self.assertEquals(6.0, mirrored_var_result)
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testAssignMirroredVarTowerContext(self):
+ self._skip_eager_if_gpus_less_than(1)
+ def var_fn():
+ return variable_scope.variable(1.0, name="foo")
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:CPU:0"])
+
+ with dist.scope():
+ mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+ # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the
+ # aggregation method.
+ mirrored_var._aggregation_method = "mean"
+ self.assertIsInstance(mirrored_var, values.MirroredVariable)
+ self.evaluate(variables.global_variables_initializer())
+ self.assertEquals(1.0, self.evaluate(mirrored_var))
+
+ def model_fn():
+ value = math_ops.cast(distribute_lib.get_tower_context().tower_id,
+ mirrored_var.dtype)
+ return mirrored_var.assign(value)
+
+ self.evaluate(dist.unwrap(dist.call_for_each_tower(
+ model_fn, run_concurrently=False)))
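+    # Tower 0 assigns 0.0 and tower 1 assigns 1.0; "mean" aggregation
+    # gives 0.5.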
+ self.assertEquals(0.5, self.evaluate(mirrored_var))
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testAssignAddMirroredVarCrossTowerContext(self):
+ self._skip_eager_if_gpus_less_than(1)
+ def var_fn():
+ return variable_scope.variable(1.0, name="foo")
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:CPU:0"])
+
+ with dist.scope():
+ mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+ self.assertIsInstance(mirrored_var, values.MirroredVariable)
+ self.evaluate(variables.global_variables_initializer())
+ self.assertEquals(1.0, self.evaluate(mirrored_var))
+ mirrored_var_result = self.evaluate(mirrored_var.assign_add(6.0))
+ self.assertEquals(7.0, mirrored_var_result)
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testAssignAddMirroredVarTowerContext(self):
+ self._skip_eager_if_gpus_less_than(1)
+ def var_fn():
+ return variable_scope.variable(1.0, name="foo")
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:CPU:0"])
+
+ with dist.scope():
+ mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+ # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the
+ # aggregation method.
+ mirrored_var._aggregation_method = "mean"
+ self.assertIsInstance(mirrored_var, values.MirroredVariable)
+ self.evaluate(variables.global_variables_initializer())
+ self.assertEquals(1.0, self.evaluate(mirrored_var))
+
+ def model_fn():
+ value = math_ops.cast(distribute_lib.get_tower_context().tower_id,
+ mirrored_var.dtype)
+ return mirrored_var.assign_add(value)
+
+ self.evaluate(dist.unwrap(dist.call_for_each_tower(
+ model_fn, run_concurrently=False)))
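+    # Tower 0 adds 0.0 and tower 1 adds 1.0 to the initial 1.0; the mean of
+    # 1.0 and 2.0 is 1.5.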
+ self.assertEquals(1.5, self.evaluate(mirrored_var))
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testAssignSubMirroredVarCrossTowerContext(self):
+ self._skip_eager_if_gpus_less_than(1)
+ def var_fn():
+ return variable_scope.variable(5.0, name="foo")
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:CPU:0"])
+
+ with dist.scope():
+ mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+ self.assertIsInstance(mirrored_var, values.MirroredVariable)
+ self.evaluate(variables.global_variables_initializer())
+ self.assertEquals(5.0, self.evaluate(mirrored_var))
+ mirrored_var_result = self.evaluate(mirrored_var.assign_sub(2.0))
+ self.assertEquals(3.0, mirrored_var_result)
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testAssignSubMirroredVarTowerContext(self):
+ self._skip_eager_if_gpus_less_than(1)
+ def var_fn():
+ return variable_scope.variable(5.0, name="foo")
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:CPU:0"])
+
+ with dist.scope():
+ mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+ # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the
+ # aggregation method.
+ mirrored_var._aggregation_method = "mean"
+ self.assertIsInstance(mirrored_var, values.MirroredVariable)
+ self.evaluate(variables.global_variables_initializer())
+ self.assertEquals(5.0, self.evaluate(mirrored_var))
+
+ def model_fn():
+ value = math_ops.cast(distribute_lib.get_tower_context().tower_id,
+ mirrored_var.dtype)
+ return mirrored_var.assign_sub(value)
+
+ self.evaluate(dist.unwrap(dist.call_for_each_tower(
+ model_fn, run_concurrently=False)))
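+    # Tower 0 subtracts 0.0 and tower 1 subtracts 1.0 from the initial 5.0;
+    # the mean of 5.0 and 4.0 is 4.5.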
+ self.assertEquals(4.5, self.evaluate(mirrored_var))
+
+
if __name__ == "__main__":
test.main()
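To reproduce locally, the file runs as a standalone test; assuming the conventional Bazel target derived from the file path (the exact target is not shown on this page), an invocation would look like:

    bazel test --config=cuda //tensorflow/contrib/distribute/python:mirrored_strategy_multigpu_test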