aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/ops/data_flow_ops.cc31
1 files changed, 30 insertions, 1 deletions
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc
index cc0ea38b9a..9329749473 100644
--- a/tensorflow/core/ops/data_flow_ops.cc
+++ b/tensorflow/core/ops/data_flow_ops.cc
@@ -1356,6 +1356,19 @@ REGISTER_OP("TensorArrayV3")
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
c->set_output(0, c->Vector(2));
c->set_output(1, c->Scalar());
+ bool identical_shapes;
+ TF_RETURN_IF_ERROR(
+ c->GetAttr("identical_element_shapes", &identical_shapes));
+ DataType t;
+ TF_RETURN_IF_ERROR(c->GetAttr("dtype", &t));
+ PartialTensorShape p;
+ TF_RETURN_IF_ERROR(c->GetAttr("element_shape", &p));
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(p, &s));
+ if (c->FullyDefined(s) || identical_shapes) {
+ c->set_output_handle_shapes_and_types(
+ 0, std::vector<shape_inference::ShapeAndType>{{s, t}});
+ }
return Status::OK();
})
.Doc(R"doc(
@@ -1464,6 +1477,15 @@ REGISTER_OP("TensorArrayWriteV3")
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+
+ auto* handle_data = c->input_handle_shapes_and_types(0);
+ if (handle_data != nullptr && !handle_data->empty()) {
+ shape_inference::ShapeAndType shape_and_type = (*handle_data)[0];
+ ShapeHandle value_shape = c->input(2);
+ TF_RETURN_IF_ERROR(
+ c->Merge(shape_and_type.shape, value_shape, &unused));
+ }
+
return shape_inference::ScalarShape(c);
})
.Doc(R"doc(
@@ -1490,7 +1512,14 @@ REGISTER_OP("TensorArrayReadV3")
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
- return shape_inference::UnknownShape(c);
+ auto shapes = c->input_handle_shapes_and_types(0);
+ if (shapes != nullptr && !shapes->empty()) {
+ ShapeHandle tensor_shape = shapes->at(0).shape;
+ c->set_output(0, tensor_shape);
+ return Status::OK();
+ } else {
+ return shape_inference::UnknownShape(c);
+ }
})
.Doc(R"doc(
Read an element from the TensorArray into output `value`.