aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
diff options
context:
space:
mode:
authorGravatar Priya Gupta <priyag@google.com>2018-05-15 13:20:01 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-15 13:22:27 -0700
commitf17620153c47370f30a84b99eaba82bef8cd7d8e (patch)
treef5c30f0e31d4241cb7d0c572a3195819fa1f5a44 /tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
parent103638433f16f31dbde3480504c4c0a33273cc64 (diff)
Handle delayed variable initialization in MirroredStrategy. Test with RNN layer.
Bug reported and solution suggested in #19069 PiperOrigin-RevId: 196718454
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py')
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py27
1 files changed, 27 insertions, 0 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index 3635bd2e34..04d30860c4 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -28,9 +28,12 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.layers import core
+from tensorflow.python.ops import rnn
+from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import distribute as distribute_lib
@@ -436,6 +439,30 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
self.assertEquals("foo/" + name + ":0", v0.name)
self.assertEquals("tower_1/foo/" + name + ":0", v1.name)
+ def testDynamicRnnVariables(self):
+ def model_fn():
+ inputs = constant_op.constant(2 * [2 * [[0.0, 1.0, 2.0, 3.0, 4.0]]])
+ cell_fw = rnn_cell_impl.LSTMCell(300)
+ cell_bw = rnn_cell_impl.LSTMCell(300)
+ (outputs, _) = rnn.bidirectional_dynamic_rnn(
+ cell_fw,
+ cell_bw,
+ inputs,
+ dtype=dtypes.float32)
+ return outputs
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:CPU:0"])
+
+ with context.graph_mode(), dist.scope():
+ result = dist.call_for_each_tower(model_fn, run_concurrently=False)
+ # Two variables are created by the RNN layer.
+ self.assertEquals(2, len(result))
+ for v in result:
+ self.assertIsInstance(v, values.DistributedValues)
+ _, v1 = dist.unwrap(v)
+ self.assertStartsWith(v1.name, "tower_1/")
+
if __name__ == "__main__":
test.main()