diff options
author | Skye Wanderman-Milne <skyewm@google.com> | 2018-01-25 16:15:50 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-25 16:19:47 -0800 |
commit | 2564d5e90e12b4ec3dbd01b442b42a2a8ac7f8f6 (patch) | |
tree | 77730b2ba03bd7e93e1b7d679a6d310cc88328bf | |
parent | 9581462f8743ee92f39d46263d68fc1283082b44 (diff) |
Make batch_sequences_with_states_test.py work with C API enabled, take 2.
This fixes the original rollback by using placeholders for the
SparseTensor shapes. The flakiness was caused by the nondeterministic
ordering of the sequences dict.
PiperOrigin-RevId: 183308774
-rw-r--r-- | tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py | 30 |
1 files changed, 27 insertions, 3 deletions
diff --git a/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py b/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py index 2a0ef0e6b3..dbdbb08a82 100644 --- a/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py +++ b/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py @@ -53,7 +53,7 @@ class BatchSequencesWithStatesTest(test.TestCase): sp_tensor1 = sparse_tensor.SparseTensor( array_ops.constant(ind1, dtypes.int64), array_ops.constant(val1, dtypes.int64), - array_ops.constant(shape1, dtypes.int64)) + array_ops.placeholder_with_default(shape1, shape=[2])) ind2 = np.array([ [0, 0, 1], [0, 1, 0], @@ -68,7 +68,7 @@ class BatchSequencesWithStatesTest(test.TestCase): sp_tensor2 = sparse_tensor.SparseTensor( array_ops.constant(ind2, dtypes.int64), array_ops.constant(val2, dtypes.int64), - array_ops.constant(shape2, dtypes.int64)) + array_ops.placeholder_with_default(shape2, shape=[3])) sp_tensor3 = sparse_tensor.SparseTensor( array_ops.constant([[1, 9], [2, 2], [2, 10]], dtypes.int64), array_ops.constant([7, 15, 2], dtypes.int64), @@ -320,6 +320,18 @@ class BatchSequencesWithStatesTest(test.TestCase): def testNotAMultiple(self): num_unroll = 3 # Not a divisor of value_length - # so padding would have been necessary. + + # Use placeholder_with_default in sequences to make sure we get runtime + # error instead of shape inference error + sequences = { + "seq1": array_ops.placeholder_with_default(self.sequences["seq1"], + shape=(None, 5)), + "seq2": array_ops.placeholder_with_default(self.sequences["seq2"], + shape=(None, 4, 2)), + "seq3": self.sequences["seq3"], + "seq4": self.sequences["seq4"], + } + with self.test_session() as sess: with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, ".*should be a multiple of: 3, but saw " @@ -330,7 +342,7 @@ class BatchSequencesWithStatesTest(test.TestCase): with coord.stop_on_exception(): next_batch = sqss.batch_sequences_with_states( input_key=self.key, - input_sequences=self.sequences, + input_sequences=sequences, input_context=self.context, input_length=3, initial_states=self.initial_states, @@ -493,6 +505,18 @@ class BatchSequencesWithStatesTest(test.TestCase): expected_seq4_batch2=expected_seq4_batch2) +class BatchSequencesWithStatesTestWithCApi(BatchSequencesWithStatesTest): + + def setUp(self): + self._prev_value = ops._USE_C_API + ops._USE_C_API = True + super(BatchSequencesWithStatesTestWithCApi, self).setUp() + + def tearDown(self): + super(BatchSequencesWithStatesTestWithCApi, self).tearDown() + ops._USE_C_API = self._prev_value + + class PaddingTest(test.TestCase): def testPaddingInvalidLengths(self): |