diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-24 13:44:52 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-24 13:49:25 -0700 |
commit | c31402273e6d60b4a53b28e372ef6c722a710495 (patch) | |
tree | 3e75a6d372c095248e0bac41d25259eaf161378f /tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py | |
parent | 94d267dfa6ee106dbf57c42a452925749bbe2f1a (diff) |
Make sure all assignments to a mirrored variable happen. Failure mode
being fixed is when you session.run(assignment) and assignment is the
MirroredVariable value returned by ResourceVariable.assign*, only one
of the components of assignment is executed. Now that it is safer,
allow session.run() on Mirrored values (not just MirroredVariables).
PiperOrigin-RevId: 210149461
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py | 14 |
1 files changed, 13 insertions, 1 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index 1dfd80fb49..ac2697958d 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -888,8 +888,18 @@ class MirroredVariableUpdateTest(test.TestCase): self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) self.assertEquals(1.0, self.evaluate(mirrored_var)) - mirrored_var_result = self.evaluate(mirrored_var.assign_add(6.0)) + + # read_value == True + mirrored_var_result = self.evaluate( + mirrored_var.assign_add(6.0, read_value=True)) self.assertEquals(7.0, mirrored_var_result) + self.assertEquals(7.0, self.evaluate(mirrored_var.get("/device:CPU:0"))) + self.assertEquals(7.0, self.evaluate(mirrored_var.get("/device:GPU:0"))) + + # read_value == False + self.evaluate(mirrored_var.assign_add(2.0, read_value=False)) + self.assertEquals(9.0, self.evaluate(mirrored_var.get("/device:CPU:0"))) + self.assertEquals(9.0, self.evaluate(mirrored_var.get("/device:GPU:0"))) @test_util.run_in_graph_and_eager_modes(config=config) def testAssignAddMirroredVarTowerContext(self): @@ -956,6 +966,8 @@ class MirroredVariableUpdateTest(test.TestCase): self.assertEquals(5.0, self.evaluate(mirrored_var)) mirrored_var_result = self.evaluate(mirrored_var.assign_sub(2.0)) self.assertEquals(3.0, mirrored_var_result) + self.assertEquals(3.0, self.evaluate(mirrored_var.get("/device:GPU:0"))) + self.assertEquals(3.0, self.evaluate(mirrored_var.get("/device:CPU:0"))) @test_util.run_in_graph_and_eager_modes(config=config) def testAssignSubMirroredVarTowerContext(self): |