aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2018-03-21 08:40:35 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-21 08:43:32 -0700
commit73cea1b095c0211b532663ea5edf0dc50ff5a448 (patch)
treede57f39aec4e2d9e924c5acb12cf68b0b1e4c64f
parent56054e42a474a527f12f4d8d0b1f37eb1efd189d (diff)
More accurate shape inference for TensorArrayGatherV3 and TensorArrayScatterV3
PiperOrigin-RevId: 189912762
-rw-r--r--tensorflow/core/ops/data_flow_ops.cc37
1 files changed, 34 insertions, 3 deletions
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc
index 4f946fb3ca..3112f35da4 100644
--- a/tensorflow/core/ops/data_flow_ops.cc
+++ b/tensorflow/core/ops/data_flow_ops.cc
@@ -668,13 +668,31 @@ REGISTER_OP("TensorArrayGatherV3")
.Attr("dtype: type")
.Attr("element_shape: shape = { unknown_rank: true }")
.SetShapeFn([](InferenceContext* c) {
+ ShapeHandle indices;
ShapeHandle unused;
DimensionHandle unused_dim;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
- TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices));
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
- return shape_inference::UnknownShape(c);
+ auto shapes = c->input_handle_shapes_and_types(0);
+ if (shapes != nullptr && !shapes->empty()) {
+ ShapeHandle tensor_shape = shapes->at(0).shape;
+ ShapeHandle output_shape;
+ TF_RETURN_IF_ERROR(
+ c->Concatenate(indices, tensor_shape, &output_shape));
+ c->set_output(0, output_shape);
+ return Status::OK();
+ } else {
+ PartialTensorShape p;
+ TF_RETURN_IF_ERROR(c->GetAttr("element_shape", &p));
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(p, &s));
+ ShapeHandle output_shape;
+ TF_RETURN_IF_ERROR(c->Concatenate(indices, s, &output_shape));
+ c->set_output(0, output_shape);
+ return Status::OK();
+ }
});
REGISTER_OP("TensorArrayScatterV3")
@@ -685,12 +703,25 @@ REGISTER_OP("TensorArrayScatterV3")
.Output("flow_out: float")
.Attr("T: type")
.SetShapeFn([](InferenceContext* c) {
+ ShapeHandle indices;
ShapeHandle unused;
DimensionHandle unused_dim;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
- TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices));
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim));
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+ ShapeHandle value_shape;
+ // Assert that the length of the indices tensor is equal to the first
+ // dimension of the value tensor.
+ TF_RETURN_IF_ERROR(
+ c->MergePrefix(c->input(2), indices, &value_shape, &indices));
+ auto shapes = c->input_handle_shapes_and_types(0);
+ if (shapes != nullptr && !shapes->empty()) {
+ ShapeHandle tensor_shape = shapes->at(0).shape;
+ ShapeHandle fed_shape;
+ TF_RETURN_IF_ERROR(c->Subshape(value_shape, 1, &fed_shape));
+ TF_RETURN_IF_ERROR(c->Merge(tensor_shape, fed_shape, &fed_shape));
+ }
return shape_inference::ScalarShape(c);
});