aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops
diff options
context:
space:
mode:
authorGravatar Saurabh Saxena <srbs@google.com>2018-10-02 17:57:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-02 18:01:17 -0700
commit9f7a138640408cea58698a432fd1596cf436b484 (patch)
treed3f66d44d654333c94ebbfec002858e8238ac583 /tensorflow/core/ops
parentb7e9cbab27c893283acc4a6154d7a59dffb23758 (diff)
Set shape for output tensors of cond_v2.
PiperOrigin-RevId: 215492782
Diffstat (limited to 'tensorflow/core/ops')
-rw-r--r--tensorflow/core/ops/functional_ops.cc21
1 files changed, 20 insertions, 1 deletions
diff --git a/tensorflow/core/ops/functional_ops.cc b/tensorflow/core/ops/functional_ops.cc
index fed3fa22ed..22b4b07eff 100644
--- a/tensorflow/core/ops/functional_ops.cc
+++ b/tensorflow/core/ops/functional_ops.cc
@@ -110,8 +110,27 @@ REGISTER_OP("If")
.Attr("Tout: list(type) >= 0")
.Attr("then_branch: func")
.Attr("else_branch: func")
+ .Attr("output_shapes: list(shape) = []")
.SetIsStateful()
- .SetShapeFn(shape_inference::UnknownShape);
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ std::vector<PartialTensorShape> output_shapes;
+ TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
+ // If `output_shapes` attr is set use that as the shapes of the outputs
+ // else return unknown shapes.
+ if (output_shapes.empty()) return shape_inference::UnknownShape(c);
+ if (output_shapes.size() != c->num_outputs()) {
+ return errors::InvalidArgument(
+ "`output_shapes` must be the same length as num outputs (",
+ output_shapes.size(), " vs. ", c->num_outputs());
+ }
+ for (size_t i = 0; i < output_shapes.size(); ++i) {
+ shape_inference::ShapeHandle output_shape_handle;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
+ output_shapes[i], &output_shape_handle));
+ c->set_output(static_cast<int>(i), output_shape_handle);
+ }
+ return Status::OK();
+ });
// TODO(drpng): remove this.
REGISTER_OP("_While")