diff options
author | Saurabh Saxena <srbs@google.com> | 2018-10-02 13:18:27 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-02 13:22:56 -0700 |
commit | 8d12c635cc48e896da0bcac1cd568bd6381ca64e (patch) | |
tree | d651bbcfdd325e649c230c19424acc62c28de725 /tensorflow/core | |
parent | 78e4ce52aeda5a10ddaf5e64ea8958f439a2f9f2 (diff) |
Support shape_invariants in while_v2. Note that this arg is temporary and may be replaced by automatic shape inference in TF 2.0 (or before).
Add a output_shapes attr to While op to allow output shapes to be different from the incoming loop_vars.
PiperOrigin-RevId: 215446737
Diffstat (limited to 'tensorflow/core')
-rw-r--r-- | tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt | 7 | ||||
-rw-r--r-- | tensorflow/core/ops/functional_ops.cc | 23 |
2 files changed, 28 insertions, 2 deletions
diff --git a/tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt b/tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt index c94ee2f227..0ec95dd684 100644 --- a/tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt +++ b/tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt @@ -88,6 +88,13 @@ library { } } } + attr { + key: "output_shapes" + value { + list { + } + } + } } ret { key: "while" diff --git a/tensorflow/core/ops/functional_ops.cc b/tensorflow/core/ops/functional_ops.cc index bda4a75c5d..fed3fa22ed 100644 --- a/tensorflow/core/ops/functional_ops.cc +++ b/tensorflow/core/ops/functional_ops.cc @@ -150,10 +150,29 @@ REGISTER_OP("While") .Attr("T: list(type) >= 0") .Attr("cond: func") .Attr("body: func") + .Attr("output_shapes: list(shape) = []") .SetIsStateful() .SetShapeFn([](shape_inference::InferenceContext* c) { - for (int i = 0; i < c->num_outputs(); ++i) { - c->set_output(i, c->input(i)); + std::vector<PartialTensorShape> output_shapes; + TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); + // If `output_shapes` attr is set use that as the shapes of the outputs + // else use the input shapes. + if (!output_shapes.empty()) { + if (output_shapes.size() != c->num_outputs()) { + return errors::InvalidArgument( + "`output_shapes` must be the same length as num outputs (", + output_shapes.size(), " vs. ", c->num_outputs()); + } + for (size_t i = 0; i < output_shapes.size(); ++i) { + shape_inference::ShapeHandle output_shape_handle; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( + output_shapes[i], &output_shape_handle)); + c->set_output(static_cast<int>(i), output_shape_handle); + } + } else { + for (int i = 0; i < c->num_outputs(); ++i) { + c->set_output(i, c->input(i)); + } } return Status::OK(); }); |