aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-10-03 22:00:51 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-03 22:04:33 -0700
commit54cde61fbf473270ce19f8b40e9511373fbc12c7 (patch)
tree846efbc4e1ddf21b77b73a8d78c87de2cbbb9436 /tensorflow/python
parentd3ced638f0496c70c3a063be82b30b358179e369 (diff)
[tf.data] Fix bug in `tf.data.experimental.unbatch()`.
Previously, if the rank of the input to this transformation was statically unknown, we would erroneously report that the output is a scalar, and violate downstream shape integrity checks. Instead, in that case the output shape should be unknown.

PiperOrigin-RevId: 215683027
Diffstat (limited to 'tensorflow/python')
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/batch_dataset_op_test.py14
1 file changed, 14 insertions(+), 0 deletions(-)
diff --git a/tensorflow/python/data/experimental/kernel_tests/batch_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/batch_dataset_op_test.py
index 8703b2810e..956b4518f6 100644
--- a/tensorflow/python/data/experimental/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/batch_dataset_op_test.py
@@ -131,6 +131,20 @@ class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
"larger than the row shape"):
sess.run(get_next)
+ def testUnbatchWithUnknownRankInput(self):
+ placeholder = array_ops.placeholder(dtypes.int32)
+ dataset = dataset_ops.Dataset.from_tensors(placeholder).apply(
+ batching.unbatch())
+ iterator = dataset.make_initializable_iterator()
+ next_elem = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer, feed_dict={placeholder: [0, 1, 2, 3]})
+ for i in range(4):
+ self.assertEqual(i, sess.run(next_elem))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_elem)
+
def testUnbatchScalarDataset(self):
data = tuple([math_ops.range(10) for _ in range(3)])
data = dataset_ops.Dataset.from_tensor_slices(data)