aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-10-03 22:00:51 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-03 22:04:33 -0700
commit54cde61fbf473270ce19f8b40e9511373fbc12c7 (patch)
tree846efbc4e1ddf21b77b73a8d78c87de2cbbb9436 /tensorflow/core
parentd3ced638f0496c70c3a063be82b30b358179e369 (diff)
[tf.data] Fix bug in `tf.data.experimental.unbatch()`.
Previously, if the rank of the input to this transformation was statically unknown, we would erroneously report that the output is a scalar, and violate downstream shape integrity checks. Instead, in that case the output shape should be unknown. PiperOrigin-RevId: 215683027
Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/kernels/data/unbatch_dataset_op.cc13
1 files changed, 9 insertions, 4 deletions
diff --git a/tensorflow/core/kernels/data/unbatch_dataset_op.cc b/tensorflow/core/kernels/data/unbatch_dataset_op.cc
index 81c432b938..74908994b4 100644
--- a/tensorflow/core/kernels/data/unbatch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/unbatch_dataset_op.cc
@@ -41,11 +41,16 @@ class UnbatchDatasetOp : public UnaryDatasetOpKernel {
: DatasetBase(DatasetContext(ctx)), input_(input) {
input_->Ref();
for (const PartialTensorShape& shape : input->output_shapes()) {
- gtl::InlinedVector<int64, 4> partial_dim_sizes;
- for (int i = 1; i < shape.dims(); ++i) {
- partial_dim_sizes.push_back(shape.dim_size(i));
+ if (!shape.unknown_rank()) {
+ gtl::InlinedVector<int64, 4> partial_dim_sizes;
+ for (int i = 1; i < shape.dims(); ++i) {
+ partial_dim_sizes.push_back(shape.dim_size(i));
+ }
+ shapes_.emplace_back(std::move(partial_dim_sizes));
+ } else {
+ // If the input shape is unknown, the output shape will be unknown.
+ shapes_.emplace_back();
}
- shapes_.emplace_back(std::move(partial_dim_sizes));
}
}