diff options
author | Priya Gupta <priyag@google.com> | 2018-05-15 13:20:01 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-15 13:22:27 -0700 |
commit | f17620153c47370f30a84b99eaba82bef8cd7d8e (patch) | |
tree | f5c30f0e31d4241cb7d0c572a3195819fa1f5a44 /tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py | |
parent | 103638433f16f31dbde3480504c4c0a33273cc64 (diff) |
Handle delayed variable initialization in MirroredStrategy. Test with RNN layer.
Bug reported and solution suggested in #19069
PiperOrigin-RevId: 196718454
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py | 27 |
1 files changed, 27 insertions, 0 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index 3635bd2e34..04d30860c4 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -28,9 +28,12 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.layers import core +from tensorflow.python.ops import rnn +from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import distribute as distribute_lib @@ -436,6 +439,30 @@ class MirroredStrategyVariableCreationTest(test.TestCase): self.assertEquals("foo/" + name + ":0", v0.name) self.assertEquals("tower_1/foo/" + name + ":0", v1.name) + def testDynamicRnnVariables(self): + def model_fn(): + inputs = constant_op.constant(2 * [2 * [[0.0, 1.0, 2.0, 3.0, 4.0]]]) + cell_fw = rnn_cell_impl.LSTMCell(300) + cell_bw = rnn_cell_impl.LSTMCell(300) + (outputs, _) = rnn.bidirectional_dynamic_rnn( + cell_fw, + cell_bw, + inputs, + dtype=dtypes.float32) + return outputs + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with context.graph_mode(), dist.scope(): + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + # Two variables are created by the RNN layer. + self.assertEquals(2, len(result)) + for v in result: + self.assertIsInstance(v, values.DistributedValues) + _, v1 = dist.unwrap(v) + self.assertStartsWith(v1.name, "tower_1/") + if __name__ == "__main__": test.main() |