aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/data_flow_ops.cc
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2018-01-03 20:09:22 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-03 20:13:03 -0800
commit9cf8cccba6831eb7df173ddcbf8fd89ec7a06724 (patch)
tree3e444f739e80dfd972d24f03f26196a3c662b635 /tensorflow/core/ops/data_flow_ops.cc
parent241ed31c29b1f419b7d076ae04bd2e2bb9b4ddea (diff)
Made the shape of the elements stored in a TensorArrayV3 available to shape inference.
PiperOrigin-RevId: 180750305
Diffstat (limited to 'tensorflow/core/ops/data_flow_ops.cc')
-rw-r--r--tensorflow/core/ops/data_flow_ops.cc31
1 files changed, 30 insertions, 1 deletions
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc
index cc0ea38b9a..9329749473 100644
--- a/tensorflow/core/ops/data_flow_ops.cc
+++ b/tensorflow/core/ops/data_flow_ops.cc
@@ -1356,6 +1356,19 @@ REGISTER_OP("TensorArrayV3")
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
c->set_output(0, c->Vector(2));
c->set_output(1, c->Scalar());
+ bool identical_shapes;
+ TF_RETURN_IF_ERROR(
+ c->GetAttr("identical_element_shapes", &identical_shapes));
+ DataType t;
+ TF_RETURN_IF_ERROR(c->GetAttr("dtype", &t));
+ PartialTensorShape p;
+ TF_RETURN_IF_ERROR(c->GetAttr("element_shape", &p));
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(p, &s));
+ if (c->FullyDefined(s) || identical_shapes) {
+ c->set_output_handle_shapes_and_types(
+ 0, std::vector<shape_inference::ShapeAndType>{{s, t}});
+ }
return Status::OK();
})
.Doc(R"doc(
@@ -1464,6 +1477,15 @@ REGISTER_OP("TensorArrayWriteV3")
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+
+ auto* handle_data = c->input_handle_shapes_and_types(0);
+ if (handle_data != nullptr && !handle_data->empty()) {
+ shape_inference::ShapeAndType shape_and_type = (*handle_data)[0];
+ ShapeHandle value_shape = c->input(2);
+ TF_RETURN_IF_ERROR(
+ c->Merge(shape_and_type.shape, value_shape, &unused));
+ }
+
return shape_inference::ScalarShape(c);
})
.Doc(R"doc(
@@ -1490,7 +1512,14 @@ REGISTER_OP("TensorArrayReadV3")
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
- return shape_inference::UnknownShape(c);
+ auto shapes = c->input_handle_shapes_and_types(0);
+ if (shapes != nullptr && !shapes->empty()) {
+ ShapeHandle tensor_shape = shapes->at(0).shape;
+ c->set_output(0, tensor_shape);
+ return Status::OK();
+ } else {
+ return shape_inference::UnknownShape(c);
+ }
})
.Doc(R"doc(
Read an element from the TensorArray into output `value`.