aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-10-04 10:46:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-04 11:00:00 -0700
commit5e9bd578802fcfff5de9729332eea4ae85c05c9e (patch)
tree175f255c25f1b04711b6baad7d1dbbb0be065070 /tensorflow/core/kernels
parent419fff9de94ea9573f2e368fd6a68fdf54c59bab (diff)
[tf.data] Fix C++ shape inference for `Dataset.concatenate()`.
Previously, we were returning an unknown shape in `Dataset::output_shapes()` for the "most specific compatible shape" between the two inputs. While this does not cause correctness problems (since the unknown shape *is* compatible), we gain the ability to raise errors earlier when more shape information is available. PiperOrigin-RevId: 215764530
Diffstat (limited to 'tensorflow/core/kernels')
-rw-r--r--tensorflow/core/kernels/data/concatenate_dataset_op.cc8
1 files changed, 4 insertions, 4 deletions
diff --git a/tensorflow/core/kernels/data/concatenate_dataset_op.cc b/tensorflow/core/kernels/data/concatenate_dataset_op.cc
index a04f150e71..9607e9444c 100644
--- a/tensorflow/core/kernels/data/concatenate_dataset_op.cc
+++ b/tensorflow/core/kernels/data/concatenate_dataset_op.cc
@@ -171,16 +171,16 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel {
static PartialTensorShape MostSpecificCompatibleShape(
const PartialTensorShape& ts1, const PartialTensorShape& ts2) {
- PartialTensorShape output_tensorshape;
if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank())
- return output_tensorshape;
+ return PartialTensorShape();
+ PartialTensorShape output_tensorshape({});
auto dims1 = ts1.dim_sizes();
auto dims2 = ts2.dim_sizes();
for (int d = 0; d < ts1.dims(); d++) {
if (dims1[d] == dims2[d])
- output_tensorshape.Concatenate(dims1[d]);
+ output_tensorshape.AddDim(dims1[d]);
else
- output_tensorshape.Concatenate(-1);
+ output_tensorshape.AddDim(-1);
}
return output_tensorshape;
}