aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py39
1 files changed, 21 insertions, 18 deletions
diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py
index 675b7ce185..c59d3682d4 100644
--- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py
+++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py
@@ -416,24 +416,27 @@ class CudnnRNNTestParamsSize(TensorFlowTestCase):
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
def testLSTMParamsSizeShape(self):
- with self.assertRaisesRegexp(ValueError, "Shape must be rank 0 but is rank 1"):
- model = _CreateModel(
- cudnn_rnn_ops.CUDNN_LSTM,
- constant_op.constant([4]), 200, 200,
- direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION)
- params_size = model.params_size()
- with self.assertRaisesRegexp(ValueError, "Shape must be rank 0 but is rank 1"):
- model = _CreateModel(
- cudnn_rnn_ops.CUDNN_LSTM,
- 4, constant_op.constant([200]), 200,
- direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION)
- params_size = model.params_size()
- with self.assertRaisesRegexp(ValueError, "Shape must be rank 0 but is rank 1"):
- model = _CreateModel(
- cudnn_rnn_ops.CUDNN_LSTM,
- 4, 200, constant_op.constant([200]),
- direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION)
- params_size = model.params_size()
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be rank 0 but is rank 1"):
+ model = _CreateModel(
+ cudnn_rnn_ops.CUDNN_LSTM,
+ constant_op.constant([4]), 200, 200,
+ direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION)
+ params_size = model.params_size()
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be rank 0 but is rank 1"):
+ model = _CreateModel(
+ cudnn_rnn_ops.CUDNN_LSTM,
+ 4, constant_op.constant([200]), 200,
+ direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION)
+ params_size = model.params_size()
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be rank 0 but is rank 1"):
+ model = _CreateModel(
+ cudnn_rnn_ops.CUDNN_LSTM,
+ 4, 200, constant_op.constant([200]),
+ direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION)
+ params_size = model.params_size()
class CudnnRNNTestInference(TensorFlowTestCase):