aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
diff options
context:
space:
mode:
authorGravatar Anjali Sridhar <anjalisridhar@google.com>2018-07-06 13:50:29 -0700
committerGravatar Yifei Feng <yifeif@google.com>2018-07-06 15:17:59 -0700
commit90fc5e3819ed62e93228a9c2c29dede0f0f8cfd6 (patch)
tree0e50e14646a382fbdf5edec988f9818bb93b12c0 /tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
parentd64754c5c768f26b6a95b350cfd8c7ded2590dc9 (diff)
Allow is_initialized and initializer to be called on MirroredVariables and TowerLocalVariables.
PiperOrigin-RevId: 203520287
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py')
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py44
1 files changed, 44 insertions, 0 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index b597bce035..15161b604a 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -922,5 +922,49 @@ class MirroredVariableUpdateTest(test.TestCase):
self.assertEquals(4.5, self.evaluate(mirrored_var))
+class MirroredAndTowerLocalVariableInitializerTest(test.TestCase):
+ config = config_pb2.ConfigProto()
+ config.allow_soft_placement = True
+
+ def testAssignMirroredVarInitializer(self):
+ # This test is not eager compatible since in eager variables are initialized
+ # upon construction instead of once the initialization op is run.
+ with context.graph_mode():
+ def var_fn():
+ v = variable_scope.variable(1.0, name="foo")
+ return v
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:CPU:0"])
+
+ with dist.scope():
+ mirrored_var = dist.call_for_each_tower(var_fn)
+ self.assertIsInstance(mirrored_var, values.MirroredVariable)
+ self.assertFalse(self.evaluate(mirrored_var.is_initialized()))
+ self.evaluate(mirrored_var.initializer)
+ self.assertTrue(self.evaluate(mirrored_var.is_initialized()))
+
+ def testAssignTowerLocalVarInitializer(self):
+ # This test is not eager compatible since in eager variables are initialized
+ # upon construction instead of once the initialization op is run.
+ with context.graph_mode():
+ def model_fn():
+ tower_context = distribute_lib.get_tower_context()
+ with tower_context.tower_local_var_scope(
+ variable_scope.VariableAggregation.SUM):
+ v_sum = variable_scope.variable(1.0)
+ self.assertTrue(isinstance(v_sum, values.TowerLocalVariable))
+ return v_sum
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:CPU:0"])
+
+ with dist.scope():
+ tower_local_var = dist.call_for_each_tower(model_fn)
+ self.assertTrue(isinstance(tower_local_var, values.TowerLocalVariable))
+ self.assertFalse(self.evaluate(tower_local_var.is_initialized()))
+ self.evaluate(tower_local_var.initializer)
+ self.assertTrue(self.evaluate(tower_local_var.is_initialized()))
+
if __name__ == "__main__":
test.main()