diff options
-rw-r--r-- | tensorflow/core/ops/data_flow_ops.cc | 31 |
1 files changed, 30 insertions, 1 deletions
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc index cc0ea38b9a..9329749473 100644 --- a/tensorflow/core/ops/data_flow_ops.cc +++ b/tensorflow/core/ops/data_flow_ops.cc @@ -1356,6 +1356,19 @@ REGISTER_OP("TensorArrayV3") TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); c->set_output(0, c->Vector(2)); c->set_output(1, c->Scalar()); + bool identical_shapes; + TF_RETURN_IF_ERROR( + c->GetAttr("identical_element_shapes", &identical_shapes)); + DataType t; + TF_RETURN_IF_ERROR(c->GetAttr("dtype", &t)); + PartialTensorShape p; + TF_RETURN_IF_ERROR(c->GetAttr("element_shape", &p)); + ShapeHandle s; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(p, &s)); + if (c->FullyDefined(s) || identical_shapes) { + c->set_output_handle_shapes_and_types( + 0, std::vector<shape_inference::ShapeAndType>{{s, t}}); + } return Status::OK(); }) .Doc(R"doc( @@ -1464,6 +1477,15 @@ REGISTER_OP("TensorArrayWriteV3") ShapeHandle unused; TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + + auto* handle_data = c->input_handle_shapes_and_types(0); + if (handle_data != nullptr && !handle_data->empty()) { + shape_inference::ShapeAndType shape_and_type = (*handle_data)[0]; + ShapeHandle value_shape = c->input(2); + TF_RETURN_IF_ERROR( + c->Merge(shape_and_type.shape, value_shape, &unused)); + } + return shape_inference::ScalarShape(c); }) .Doc(R"doc( @@ -1490,7 +1512,14 @@ REGISTER_OP("TensorArrayReadV3") ShapeHandle unused; TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); - return shape_inference::UnknownShape(c); + auto shapes = c->input_handle_shapes_and_types(0); + if (shapes != nullptr && !shapes->empty()) { + ShapeHandle tensor_shape = shapes->at(0).shape; + c->set_output(0, tensor_shape); + return Status::OK(); + } else { + return shape_inference::UnknownShape(c); + } }) .Doc(R"doc( Read an element from the TensorArray into output `value`. |