diff options
author | Derek Murray <mrry@google.com> | 2018-10-03 22:00:51 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-03 22:04:33 -0700 |
commit | 54cde61fbf473270ce19f8b40e9511373fbc12c7 (patch) | |
tree | 846efbc4e1ddf21b77b73a8d78c87de2cbbb9436 /tensorflow/python | |
parent | d3ced638f0496c70c3a063be82b30b358179e369 (diff) |
[tf.data] Fix bug in `tf.data.experimental.unbatch()`.
Previously, if the rank of the input to this transformation was
statically unknown, we would erroneously report that the output is a
scalar, and violate downstream shape integrity checks. Instead, in
that case the output shape should be unknown.
PiperOrigin-RevId: 215683027
Diffstat (limited to 'tensorflow/python')
-rw-r--r-- | tensorflow/python/data/experimental/kernel_tests/batch_dataset_op_test.py | 14 |
1 files changed, 14 insertions, 0 deletions
diff --git a/tensorflow/python/data/experimental/kernel_tests/batch_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/batch_dataset_op_test.py index 8703b2810e..956b4518f6 100644 --- a/tensorflow/python/data/experimental/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/batch_dataset_op_test.py @@ -131,6 +131,20 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): "larger than the row shape"): sess.run(get_next) + def testUnbatchWithUnknownRankInput(self): + placeholder = array_ops.placeholder(dtypes.int32) + dataset = dataset_ops.Dataset.from_tensors(placeholder).apply( + batching.unbatch()) + iterator = dataset.make_initializable_iterator() + next_elem = iterator.get_next() + + with self.cached_session() as sess: + sess.run(iterator.initializer, feed_dict={placeholder: [0, 1, 2, 3]}) + for i in range(4): + self.assertEqual(i, sess.run(next_elem)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_elem) + def testUnbatchScalarDataset(self): data = tuple([math_ops.range(10) for _ in range(3)]) data = dataset_ops.Dataset.from_tensor_slices(data) |