diff options
author | 2018-01-03 20:09:22 -0800 | |
---|---|---|
committer | 2018-01-03 20:13:03 -0800 | |
commit | 9cf8cccba6831eb7df173ddcbf8fd89ec7a06724 (patch) | |
tree | 3e444f739e80dfd972d24f03f26196a3c662b635 /tensorflow/core/ops/data_flow_ops.cc | |
parent | 241ed31c29b1f419b7d076ae04bd2e2bb9b4ddea (diff) |
Made the shape of the elements stored in a TensorArrayV3 available to shape inference.
PiperOrigin-RevId: 180750305
Diffstat (limited to 'tensorflow/core/ops/data_flow_ops.cc')
-rw-r--r-- | tensorflow/core/ops/data_flow_ops.cc | 31 |
1 files changed, 30 insertions, 1 deletions
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc index cc0ea38b9a..9329749473 100644 --- a/tensorflow/core/ops/data_flow_ops.cc +++ b/tensorflow/core/ops/data_flow_ops.cc @@ -1356,6 +1356,19 @@ REGISTER_OP("TensorArrayV3") TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); c->set_output(0, c->Vector(2)); c->set_output(1, c->Scalar()); + bool identical_shapes; + TF_RETURN_IF_ERROR( + c->GetAttr("identical_element_shapes", &identical_shapes)); + DataType t; + TF_RETURN_IF_ERROR(c->GetAttr("dtype", &t)); + PartialTensorShape p; + TF_RETURN_IF_ERROR(c->GetAttr("element_shape", &p)); + ShapeHandle s; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(p, &s)); + if (c->FullyDefined(s) || identical_shapes) { + c->set_output_handle_shapes_and_types( + 0, std::vector<shape_inference::ShapeAndType>{{s, t}}); + } return Status::OK(); }) .Doc(R"doc( @@ -1464,6 +1477,15 @@ REGISTER_OP("TensorArrayWriteV3") ShapeHandle unused; TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + + auto* handle_data = c->input_handle_shapes_and_types(0); + if (handle_data != nullptr && !handle_data->empty()) { + shape_inference::ShapeAndType shape_and_type = (*handle_data)[0]; + ShapeHandle value_shape = c->input(2); + TF_RETURN_IF_ERROR( + c->Merge(shape_and_type.shape, value_shape, &unused)); + } + return shape_inference::ScalarShape(c); }) .Doc(R"doc( @@ -1490,7 +1512,14 @@ REGISTER_OP("TensorArrayReadV3") ShapeHandle unused; TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); - return shape_inference::UnknownShape(c); + auto shapes = c->input_handle_shapes_and_types(0); + if (shapes != nullptr && !shapes->empty()) { + ShapeHandle tensor_shape = shapes->at(0).shape; + c->set_output(0, tensor_shape); + return Status::OK(); + } else { + return shape_inference::UnknownShape(c); + } }) .Doc(R"doc( Read an element from the TensorArray into output `value`. |