diff options
author | Skye Wanderman-Milne <skyewm@google.com> | 2018-04-01 22:46:11 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-01 22:48:42 -0700 |
commit | 1fbad5034a8ea531e496b0ecbf9e2c3839b62311 (patch) | |
tree | 90b79408ba9424f6841047b2e06768bca3840c80 /tensorflow/contrib/training | |
parent | 926bd44844d36bbefcbd620eb65ba0019e0a6dde (diff) |
Make batch_sequences_with_states_test.py work with the C API enabled.
The C API improves static shape inference, making more errors caught
at graph construction time instead of runtime.
PiperOrigin-RevId: 191260634
Diffstat (limited to 'tensorflow/contrib/training')
-rw-r--r-- | tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py | 13 |
1 files changed, 9 insertions, 4 deletions
diff --git a/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py b/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py index dbdbb08a82..16c260edb0 100644 --- a/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py +++ b/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py @@ -526,10 +526,15 @@ class PaddingTest(test.TestCase): "key_2": constant_op.constant([1.5, 2.5]) # length 2 } - _, padded_seq = sqss._padding(sequences, 2) - with self.assertRaisesOpError( - ".*All sequence lengths must match, but received lengths.*"): - padded_seq["key_1"].eval() + if ops._USE_C_API: + with self.assertRaisesRegexp( + ValueError, "Fill dimensions must be >= 0"): + _, padded_seq = sqss._padding(sequences, 2) + else: + _, padded_seq = sqss._padding(sequences, 2) + with self.assertRaisesOpError( + ".*All sequence lengths must match, but received lengths.*"): + padded_seq["key_1"].eval() def testPadding(self): with ops.Graph().as_default() as g, self.test_session(graph=g): |