From 73cea1b095c0211b532663ea5edf0dc50ff5a448 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Wed, 21 Mar 2018 08:40:35 -0700 Subject: More accurate shape inference for TensorArrayGatherV3 and TensorArrayScatterV3 PiperOrigin-RevId: 189912762 --- tensorflow/core/ops/data_flow_ops.cc | 37 +++++++++++++++++++++++++++++++++--- 1 file changed, 34 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc index 4f946fb3ca..3112f35da4 100644 --- a/tensorflow/core/ops/data_flow_ops.cc +++ b/tensorflow/core/ops/data_flow_ops.cc @@ -668,13 +668,31 @@ REGISTER_OP("TensorArrayGatherV3") .Attr("dtype: type") .Attr("element_shape: shape = { unknown_rank: true }") .SetShapeFn([](InferenceContext* c) { + ShapeHandle indices; ShapeHandle unused; DimensionHandle unused_dim; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices)); TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim)); TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); - return shape_inference::UnknownShape(c); + auto shapes = c->input_handle_shapes_and_types(0); + if (shapes != nullptr && !shapes->empty()) { + ShapeHandle tensor_shape = shapes->at(0).shape; + ShapeHandle output_shape; + TF_RETURN_IF_ERROR( + c->Concatenate(indices, tensor_shape, &output_shape)); + c->set_output(0, output_shape); + return Status::OK(); + } else { + PartialTensorShape p; + TF_RETURN_IF_ERROR(c->GetAttr("element_shape", &p)); + ShapeHandle s; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(p, &s)); + ShapeHandle output_shape; + TF_RETURN_IF_ERROR(c->Concatenate(indices, s, &output_shape)); + c->set_output(0, output_shape); + return Status::OK(); + } }); REGISTER_OP("TensorArrayScatterV3") @@ -685,12 +703,25 @@ REGISTER_OP("TensorArrayScatterV3") .Output("flow_out: float") .Attr("T: type") .SetShapeFn([](InferenceContext* c) { + ShapeHandle indices; ShapeHandle unused; DimensionHandle unused_dim; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices)); TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim)); TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + ShapeHandle value_shape; + // Assert that the length of the indices tensor is equal to the first + // dimension of the value tensor. + TF_RETURN_IF_ERROR( + c->MergePrefix(c->input(2), indices, &value_shape, &indices)); + auto shapes = c->input_handle_shapes_and_types(0); + if (shapes != nullptr && !shapes->empty()) { + ShapeHandle tensor_shape = shapes->at(0).shape; + ShapeHandle fed_shape; + TF_RETURN_IF_ERROR(c->Subshape(value_shape, 1, &fed_shape)); + TF_RETURN_IF_ERROR(c->Merge(tensor_shape, fed_shape, &fed_shape)); + } return shape_inference::ScalarShape(c); }); -- cgit v1.2.3