aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2018-01-25 16:15:50 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-25 16:19:47 -0800
commit2564d5e90e12b4ec3dbd01b442b42a2a8ac7f8f6 (patch)
tree77730b2ba03bd7e93e1b7d679a6d310cc88328bf
parent9581462f8743ee92f39d46263d68fc1283082b44 (diff)
Make batch_sequences_with_states_test.py work with C API enabled, take 2.
This fixes the original rollback by using placeholders for the SparseTensor shapes. The flakiness was caused by the nondeterministic ordering of the sequences dict. PiperOrigin-RevId: 183308774
-rw-r--r--tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py30
1 files changed, 27 insertions, 3 deletions
diff --git a/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py b/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py
index 2a0ef0e6b3..dbdbb08a82 100644
--- a/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py
+++ b/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py
@@ -53,7 +53,7 @@ class BatchSequencesWithStatesTest(test.TestCase):
sp_tensor1 = sparse_tensor.SparseTensor(
array_ops.constant(ind1, dtypes.int64),
array_ops.constant(val1, dtypes.int64),
- array_ops.constant(shape1, dtypes.int64))
+ array_ops.placeholder_with_default(shape1, shape=[2]))
ind2 = np.array([
[0, 0, 1],
[0, 1, 0],
@@ -68,7 +68,7 @@ class BatchSequencesWithStatesTest(test.TestCase):
sp_tensor2 = sparse_tensor.SparseTensor(
array_ops.constant(ind2, dtypes.int64),
array_ops.constant(val2, dtypes.int64),
- array_ops.constant(shape2, dtypes.int64))
+ array_ops.placeholder_with_default(shape2, shape=[3]))
sp_tensor3 = sparse_tensor.SparseTensor(
array_ops.constant([[1, 9], [2, 2], [2, 10]], dtypes.int64),
array_ops.constant([7, 15, 2], dtypes.int64),
@@ -320,6 +320,18 @@ class BatchSequencesWithStatesTest(test.TestCase):
def testNotAMultiple(self):
num_unroll = 3 # Not a divisor of value_length -
# so padding would have been necessary.
+
+ # Use placeholder_with_default in sequences to make sure we get runtime
+ # error instead of shape inference error
+ sequences = {
+ "seq1": array_ops.placeholder_with_default(self.sequences["seq1"],
+ shape=(None, 5)),
+ "seq2": array_ops.placeholder_with_default(self.sequences["seq2"],
+ shape=(None, 4, 2)),
+ "seq3": self.sequences["seq3"],
+ "seq4": self.sequences["seq4"],
+ }
+
with self.test_session() as sess:
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
".*should be a multiple of: 3, but saw "
@@ -330,7 +342,7 @@ class BatchSequencesWithStatesTest(test.TestCase):
with coord.stop_on_exception():
next_batch = sqss.batch_sequences_with_states(
input_key=self.key,
- input_sequences=self.sequences,
+ input_sequences=sequences,
input_context=self.context,
input_length=3,
initial_states=self.initial_states,
@@ -493,6 +505,18 @@ class BatchSequencesWithStatesTest(test.TestCase):
expected_seq4_batch2=expected_seq4_batch2)
+class BatchSequencesWithStatesTestWithCApi(BatchSequencesWithStatesTest):
+
+ def setUp(self):
+ self._prev_value = ops._USE_C_API
+ ops._USE_C_API = True
+ super(BatchSequencesWithStatesTestWithCApi, self).setUp()
+
+ def tearDown(self):
+ super(BatchSequencesWithStatesTestWithCApi, self).tearDown()
+ ops._USE_C_API = self._prev_value
+
+
class PaddingTest(test.TestCase):
def testPaddingInvalidLengths(self):