diff options
author | Patrick Nguyen <drpng@google.com> | 2018-05-01 14:28:36 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-01 14:33:20 -0700 |
commit | 325d0ef21a48bea1cc618a2bd24a9776de417ce5 (patch) | |
tree | d41cf6304071e95bebd5747ca87dfca571e98634 /tensorflow/contrib/rnn | |
parent | 46bf1e8934b3bc8edeff3f218a50b0ee5806e96b (diff) |
Merge changes from github.
PiperOrigin-RevId: 194997009
Diffstat (limited to 'tensorflow/contrib/rnn')
-rw-r--r-- | tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py index de5df91292..ba4933ddf7 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py @@ -307,6 +307,21 @@ class LSTMTest(test.TestCase): self._seed = 23489 np.random.seed(self._seed) + def testDType(self): + # Test case for GitHub issue 16228 + # Not passing dtype in constructor results in default float32 + lstm = rnn_cell.LSTMCell(10) + input_tensor = array_ops.ones([10, 50]) + lstm.build(input_tensor.get_shape()) + self.assertEqual(lstm._bias.dtype, dtypes.float32_ref) + + # Explicitly pass dtype in constructor + for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]: + lstm = rnn_cell.LSTMCell(10, dtype=dtype) + input_tensor = array_ops.ones([10, 50]) + lstm.build(input_tensor.get_shape()) + self.assertEqual(lstm._bias.dtype, dtype._as_ref) + def testNoProjNoSharding(self): num_units = 3 input_size = 5 |