diff options
author | Yong Tang <yong.tang.github@outlook.com> | 2018-07-25 04:28:08 +0000 |
---|---|---|
committer | Yong Tang <yong.tang.github@outlook.com> | 2018-07-27 07:20:25 +0000 |
commit | 01387ccddcf5c23d48c5745f4a6a49a670f528aa (patch) | |
tree | 0d78a2e013eca5f396a36c7c2825a5dff224acf6 /tensorflow/contrib/cudnn_rnn | |
parent | 27de8e717c1bec91398f5a6be6c7287b657fc960 (diff) |
Add test cases for shape function of CudnnRNNParamsSize
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
Diffstat (limited to 'tensorflow/contrib/cudnn_rnn')
-rw-r--r-- | tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py | 22 |
1 files changed, 22 insertions, 0 deletions
diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py index 5a667485be..675b7ce185 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py @@ -413,6 +413,28 @@ class CudnnRNNTestParamsSize(TensorFlowTestCase): self._testOneLSTMParamsSize(num_layers, num_units, input_size, direction) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testLSTMParamsSizeShape(self): + with self.assertRaisesRegexp(ValueError, "Shape must be rank 0 but is rank 1"): + model = _CreateModel( + cudnn_rnn_ops.CUDNN_LSTM, + constant_op.constant([4]), 200, 200, + direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) + params_size = model.params_size() + with self.assertRaisesRegexp(ValueError, "Shape must be rank 0 but is rank 1"): + model = _CreateModel( + cudnn_rnn_ops.CUDNN_LSTM, + 4, constant_op.constant([200]), 200, + direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) + params_size = model.params_size() + with self.assertRaisesRegexp(ValueError, "Shape must be rank 0 but is rank 1"): + model = _CreateModel( + cudnn_rnn_ops.CUDNN_LSTM, + 4, 200, constant_op.constant([200]), + direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) + params_size = model.params_size() + class CudnnRNNTestInference(TensorFlowTestCase): |