aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core
diff options
context:
space:
mode:
authorGravatar Saurabh Saxena <srbs@google.com>2018-10-02 13:18:27 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-02 13:22:56 -0700
commit8d12c635cc48e896da0bcac1cd568bd6381ca64e (patch)
treed651bbcfdd325e649c230c19424acc62c28de725 /tensorflow/core
parent78e4ce52aeda5a10ddaf5e64ea8958f439a2f9f2 (diff)
Support shape_invariants in while_v2. Note that this arg is temporary and may be replaced by automatic shape inference in TF 2.0 (or before).
Add a output_shapes attr to While op to allow output shapes to be different from the incoming loop_vars. PiperOrigin-RevId: 215446737
Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt7
-rw-r--r--tensorflow/core/ops/functional_ops.cc23
2 files changed, 28 insertions, 2 deletions
diff --git a/tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt b/tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt
index c94ee2f227..0ec95dd684 100644
--- a/tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt
+++ b/tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt
@@ -88,6 +88,13 @@ library {
}
}
}
+ attr {
+ key: "output_shapes"
+ value {
+ list {
+ }
+ }
+ }
}
ret {
key: "while"
diff --git a/tensorflow/core/ops/functional_ops.cc b/tensorflow/core/ops/functional_ops.cc
index bda4a75c5d..fed3fa22ed 100644
--- a/tensorflow/core/ops/functional_ops.cc
+++ b/tensorflow/core/ops/functional_ops.cc
@@ -150,10 +150,29 @@ REGISTER_OP("While")
.Attr("T: list(type) >= 0")
.Attr("cond: func")
.Attr("body: func")
+ .Attr("output_shapes: list(shape) = []")
.SetIsStateful()
.SetShapeFn([](shape_inference::InferenceContext* c) {
- for (int i = 0; i < c->num_outputs(); ++i) {
- c->set_output(i, c->input(i));
+ std::vector<PartialTensorShape> output_shapes;
+ TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
+ // If `output_shapes` attr is set use that as the shapes of the outputs
+ // else use the input shapes.
+ if (!output_shapes.empty()) {
+ if (output_shapes.size() != c->num_outputs()) {
+ return errors::InvalidArgument(
+ "`output_shapes` must be the same length as num outputs (",
+ output_shapes.size(), " vs. ", c->num_outputs());
+ }
+ for (size_t i = 0; i < output_shapes.size(); ++i) {
+ shape_inference::ShapeHandle output_shape_handle;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
+ output_shapes[i], &output_shape_handle));
+ c->set_output(static_cast<int>(i), output_shape_handle);
+ }
+ } else {
+ for (int i = 0; i < c->num_outputs(); ++i) {
+ c->set_output(i, c->input(i));
+ }
}
return Status::OK();
});