diff options
-rw-r--r-- | tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py | 39 |
1 files changed, 21 insertions, 18 deletions
diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py index 675b7ce185..c59d3682d4 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py @@ -416,24 +416,27 @@ class CudnnRNNTestParamsSize(TensorFlowTestCase): @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def testLSTMParamsSizeShape(self): - with self.assertRaisesRegexp(ValueError, "Shape must be rank 0 but is rank 1"): - model = _CreateModel( - cudnn_rnn_ops.CUDNN_LSTM, - constant_op.constant([4]), 200, 200, - direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) - params_size = model.params_size() - with self.assertRaisesRegexp(ValueError, "Shape must be rank 0 but is rank 1"): - model = _CreateModel( - cudnn_rnn_ops.CUDNN_LSTM, - 4, constant_op.constant([200]), 200, - direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) - params_size = model.params_size() - with self.assertRaisesRegexp(ValueError, "Shape must be rank 0 but is rank 1"): - model = _CreateModel( - cudnn_rnn_ops.CUDNN_LSTM, - 4, 200, constant_op.constant([200]), - direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) - params_size = model.params_size() + with self.assertRaisesRegexp( + ValueError, "Shape must be rank 0 but is rank 1"): + model = _CreateModel( + cudnn_rnn_ops.CUDNN_LSTM, + constant_op.constant([4]), 200, 200, + direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) + params_size = model.params_size() + with self.assertRaisesRegexp( + ValueError, "Shape must be rank 0 but is rank 1"): + model = _CreateModel( + cudnn_rnn_ops.CUDNN_LSTM, + 4, constant_op.constant([200]), 200, + direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) + params_size = model.params_size() + with self.assertRaisesRegexp( + ValueError, "Shape must be rank 0 but is rank 1"): + model = _CreateModel( + cudnn_rnn_ops.CUDNN_LSTM, + 4, 200, constant_op.constant([200]), + direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) + params_size = model.params_size() class CudnnRNNTestInference(TensorFlowTestCase): |