aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/rnn
diff options
context:
space:
mode:
authorGravatar Patrick Nguyen <drpng@google.com>2018-05-01 14:28:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-01 14:33:20 -0700
commit325d0ef21a48bea1cc618a2bd24a9776de417ce5 (patch)
treed41cf6304071e95bebd5747ca87dfca571e98634 /tensorflow/contrib/rnn
parent46bf1e8934b3bc8edeff3f218a50b0ee5806e96b (diff)
Merge changes from github.
PiperOrigin-RevId: 194997009
Diffstat (limited to 'tensorflow/contrib/rnn')
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py15
1 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
index de5df91292..ba4933ddf7 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
@@ -307,6 +307,21 @@ class LSTMTest(test.TestCase):
self._seed = 23489
np.random.seed(self._seed)
+ def testDType(self):
+ # Test case for GitHub issue 16228
+ # Not passing dtype in constructor results in default float32
+ lstm = rnn_cell.LSTMCell(10)
+ input_tensor = array_ops.ones([10, 50])
+ lstm.build(input_tensor.get_shape())
+ self.assertEqual(lstm._bias.dtype, dtypes.float32_ref)
+
+ # Explicitly pass dtype in constructor
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ lstm = rnn_cell.LSTMCell(10, dtype=dtype)
+ input_tensor = array_ops.ones([10, 50])
+ lstm.build(input_tensor.get_shape())
+ self.assertEqual(lstm._bias.dtype, dtype._as_ref)
+
def testNoProjNoSharding(self):
num_units = 3
input_size = 5