diff options
author | Skye Wanderman-Milne <skyewm@google.com> | 2018-04-03 09:09:13 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-03 09:13:17 -0700 |
commit | 4657e5336160e019379c373a369e3a9b199bc680 (patch) | |
tree | 4a937a9759fcfaabd07f192018a84e5795b9e3f4 /tensorflow/contrib/training | |
parent | 5d1086ae98ccfe691161ff50c93036d432866741 (diff) |
Make batch_sequences_with_states_test.py work with the C API enabled, take 2.
It turns out the error can depend on what sequence comes first in the
input dict. This change internally sorts the input to make the error
predictable (this is useful for this test, as well as any users who
may run into this).
PiperOrigin-RevId: 191449214
Diffstat (limited to 'tensorflow/contrib/training')
-rw-r--r-- | tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py | 15 | ||||
-rw-r--r-- | tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py | 5 |
2 files changed, 9 insertions, 11 deletions
diff --git a/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py b/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py index 16c260edb0..f305197c19 100644 --- a/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py +++ b/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops @@ -517,6 +518,7 @@ class BatchSequencesWithStatesTestWithCApi(BatchSequencesWithStatesTest): ops._USE_C_API = self._prev_value +@test_util.with_c_api class PaddingTest(test.TestCase): def testPaddingInvalidLengths(self): @@ -526,15 +528,10 @@ class PaddingTest(test.TestCase): "key_2": constant_op.constant([1.5, 2.5]) # length 2 } - if ops._USE_C_API: - with self.assertRaisesRegexp( - ValueError, "Fill dimensions must be >= 0"): - _, padded_seq = sqss._padding(sequences, 2) - else: - _, padded_seq = sqss._padding(sequences, 2) - with self.assertRaisesOpError( - ".*All sequence lengths must match, but received lengths.*"): - padded_seq["key_1"].eval() + _, padded_seq = sqss._padding(sequences, 2) + with self.assertRaisesOpError( + ".*All sequence lengths must match, but received lengths.*"): + padded_seq["key_1"].eval() def testPadding(self): with ops.Graph().as_default() as g, self.test_session(graph=g): diff --git a/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py index 7223194885..99d486b183 100644 --- a/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py +++ b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py @@ -1574,8 +1574,9 @@ def _padding(sequences, num_unroll): if not sequences: return 0, {} - sequences_dict = {} - for key, value in sequences.items(): + # Sort 'sequences_dict' so 'length' will have a predictable value below. + sequences_dict = collections.OrderedDict() + for key, value in sorted(sequences.items()): if not (isinstance(value, sparse_tensor.SparseTensor) or isinstance(value, sparse_tensor.SparseTensorValue)): sequences_dict[key] = ops.convert_to_tensor(value) |