aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/training
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2018-04-03 09:09:13 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-03 09:13:17 -0700
commit4657e5336160e019379c373a369e3a9b199bc680 (patch)
tree4a937a9759fcfaabd07f192018a84e5795b9e3f4 /tensorflow/contrib/training
parent5d1086ae98ccfe691161ff50c93036d432866741 (diff)
Make batch_sequences_with_states_test.py work with the C API enabled, take 2.
It turns out the error can depend on what sequence comes first in the input dict. This change internally sorts the input to make the error predictable (this is useful for this test, as well as any users who may run into this). PiperOrigin-RevId: 191449214
Diffstat (limited to 'tensorflow/contrib/training')
-rw-r--r--tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py15
-rw-r--r--tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py5
2 files changed, 9 insertions, 11 deletions
diff --git a/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py b/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py
index 16c260edb0..f305197c19 100644
--- a/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py
+++ b/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py
@@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
@@ -517,6 +518,7 @@ class BatchSequencesWithStatesTestWithCApi(BatchSequencesWithStatesTest):
ops._USE_C_API = self._prev_value
+@test_util.with_c_api
class PaddingTest(test.TestCase):
def testPaddingInvalidLengths(self):
@@ -526,15 +528,10 @@ class PaddingTest(test.TestCase):
"key_2": constant_op.constant([1.5, 2.5]) # length 2
}
- if ops._USE_C_API:
- with self.assertRaisesRegexp(
- ValueError, "Fill dimensions must be >= 0"):
- _, padded_seq = sqss._padding(sequences, 2)
- else:
- _, padded_seq = sqss._padding(sequences, 2)
- with self.assertRaisesOpError(
- ".*All sequence lengths must match, but received lengths.*"):
- padded_seq["key_1"].eval()
+ _, padded_seq = sqss._padding(sequences, 2)
+ with self.assertRaisesOpError(
+ ".*All sequence lengths must match, but received lengths.*"):
+ padded_seq["key_1"].eval()
def testPadding(self):
with ops.Graph().as_default() as g, self.test_session(graph=g):
diff --git a/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py
index 7223194885..99d486b183 100644
--- a/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py
+++ b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py
@@ -1574,8 +1574,9 @@ def _padding(sequences, num_unroll):
if not sequences:
return 0, {}
- sequences_dict = {}
- for key, value in sequences.items():
+ # Sort 'sequences_dict' so 'length' will have a predictable value below.
+ sequences_dict = collections.OrderedDict()
+ for key, value in sorted(sequences.items()):
if not (isinstance(value, sparse_tensor.SparseTensor) or
isinstance(value, sparse_tensor.SparseTensorValue)):
sequences_dict[key] = ops.convert_to_tensor(value)