aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/cudnn_rnn
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-19 17:19:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-19 17:19:43 -0700
commitdfa0447740b42edff5ae2d76d5957aa688ae8053 (patch)
treed2a94a6385280f239603d7cf70b32f10e8780673 /tensorflow/contrib/cudnn_rnn
parent1d78936a3989f6ee5a9945746cd329c37e82287c (diff)
parentb127c201cda558db21ce5f48f5899593d73da46b (diff)
Merge pull request #21114 from yongtang:07242018-CudnnRNNParamsSize
PiperOrigin-RevId: 213726710
Diffstat (limited to 'tensorflow/contrib/cudnn_rnn')
-rw-r--r--tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py25
1 files changed, 25 insertions, 0 deletions
diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py
index 5a667485be..c59d3682d4 100644
--- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py
+++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py
@@ -413,6 +413,31 @@ class CudnnRNNTestParamsSize(TensorFlowTestCase):
self._testOneLSTMParamsSize(num_layers, num_units, input_size,
direction)
+ @unittest.skipUnless(test.is_built_with_cuda(),
+ "Test only applicable when running on GPUs")
+ def testLSTMParamsSizeShape(self):
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be rank 0 but is rank 1"):
+ model = _CreateModel(
+ cudnn_rnn_ops.CUDNN_LSTM,
+ constant_op.constant([4]), 200, 200,
+ direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION)
+ params_size = model.params_size()
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be rank 0 but is rank 1"):
+ model = _CreateModel(
+ cudnn_rnn_ops.CUDNN_LSTM,
+ 4, constant_op.constant([200]), 200,
+ direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION)
+ params_size = model.params_size()
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be rank 0 but is rank 1"):
+ model = _CreateModel(
+ cudnn_rnn_ops.CUDNN_LSTM,
+ 4, 200, constant_op.constant([200]),
+ direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION)
+ params_size = model.params_size()
+
class CudnnRNNTestInference(TensorFlowTestCase):