diff options
author | 2018-10-03 22:00:51 -0700 | |
---|---|---|
committer | 2018-10-03 22:04:33 -0700 | |
commit | 54cde61fbf473270ce19f8b40e9511373fbc12c7 (patch) | |
tree | 846efbc4e1ddf21b77b73a8d78c87de2cbbb9436 /tensorflow/core | |
parent | d3ced638f0496c70c3a063be82b30b358179e369 (diff) |
[tf.data] Fix bug in `tf.data.experimental.unbatch()`.
Previously, if the rank of the input to this transformation was
statically unknown, we would erroneously report that the output is a
scalar, and violate downstream shape integrity checks. Instead, in
that case the output shape should be unknown.
PiperOrigin-RevId: 215683027
Diffstat (limited to 'tensorflow/core')
-rw-r--r-- | tensorflow/core/kernels/data/unbatch_dataset_op.cc | 13 |
1 files changed, 9 insertions, 4 deletions
diff --git a/tensorflow/core/kernels/data/unbatch_dataset_op.cc b/tensorflow/core/kernels/data/unbatch_dataset_op.cc index 81c432b938..74908994b4 100644 --- a/tensorflow/core/kernels/data/unbatch_dataset_op.cc +++ b/tensorflow/core/kernels/data/unbatch_dataset_op.cc @@ -41,11 +41,16 @@ class UnbatchDatasetOp : public UnaryDatasetOpKernel { : DatasetBase(DatasetContext(ctx)), input_(input) { input_->Ref(); for (const PartialTensorShape& shape : input->output_shapes()) { - gtl::InlinedVector<int64, 4> partial_dim_sizes; - for (int i = 1; i < shape.dims(); ++i) { - partial_dim_sizes.push_back(shape.dim_size(i)); + if (!shape.unknown_rank()) { + gtl::InlinedVector<int64, 4> partial_dim_sizes; + for (int i = 1; i < shape.dims(); ++i) { + partial_dim_sizes.push_back(shape.dim_size(i)); + } + shapes_.emplace_back(std::move(partial_dim_sizes)); + } else { + // If the input shape is unknown, the output shape will be unknown. + shapes_.emplace_back(); } - shapes_.emplace_back(std::move(partial_dim_sizes)); } } |