aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/training
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2018-04-01 22:46:11 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-01 22:48:42 -0700
commit1fbad5034a8ea531e496b0ecbf9e2c3839b62311 (patch)
tree90b79408ba9424f6841047b2e06768bca3840c80 /tensorflow/contrib/training
parent926bd44844d36bbefcbd620eb65ba0019e0a6dde (diff)
Make batch_sequences_with_states_test.py work with the C API enabled.
The C API improves static shape inference, making more errors caught at graph construction time instead of runtime. PiperOrigin-RevId: 191260634
Diffstat (limited to 'tensorflow/contrib/training')
-rw-r--r--tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py13
1 files changed, 9 insertions, 4 deletions
diff --git a/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py b/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py
index dbdbb08a82..16c260edb0 100644
--- a/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py
+++ b/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py
@@ -526,10 +526,15 @@ class PaddingTest(test.TestCase):
"key_2": constant_op.constant([1.5, 2.5]) # length 2
}
- _, padded_seq = sqss._padding(sequences, 2)
- with self.assertRaisesOpError(
- ".*All sequence lengths must match, but received lengths.*"):
- padded_seq["key_1"].eval()
+ if ops._USE_C_API:
+ with self.assertRaisesRegexp(
+ ValueError, "Fill dimensions must be >= 0"):
+ _, padded_seq = sqss._padding(sequences, 2)
+ else:
+ _, padded_seq = sqss._padding(sequences, 2)
+ with self.assertRaisesOpError(
+ ".*All sequence lengths must match, but received lengths.*"):
+ padded_seq["key_1"].eval()
def testPadding(self):
with ops.Graph().as_default() as g, self.test_session(graph=g):