diff options
author | Saurabh Saxena <srbs@google.com> | 2018-10-02 17:57:49 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-02 18:01:17 -0700 |
commit | 9f7a138640408cea58698a432fd1596cf436b484 (patch) | |
tree | d3f66d44d654333c94ebbfec002858e8238ac583 /tensorflow/core/ops | |
parent | b7e9cbab27c893283acc4a6154d7a59dffb23758 (diff) |
Set shape for output tensors of cond_v2.
PiperOrigin-RevId: 215492782
Diffstat (limited to 'tensorflow/core/ops')
-rw-r--r-- | tensorflow/core/ops/functional_ops.cc | 21 |
1 files changed, 20 insertions, 1 deletions
diff --git a/tensorflow/core/ops/functional_ops.cc b/tensorflow/core/ops/functional_ops.cc index fed3fa22ed..22b4b07eff 100644 --- a/tensorflow/core/ops/functional_ops.cc +++ b/tensorflow/core/ops/functional_ops.cc @@ -110,8 +110,27 @@ REGISTER_OP("If") .Attr("Tout: list(type) >= 0") .Attr("then_branch: func") .Attr("else_branch: func") + .Attr("output_shapes: list(shape) = []") .SetIsStateful() - .SetShapeFn(shape_inference::UnknownShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + std::vector<PartialTensorShape> output_shapes; + TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); + // If `output_shapes` attr is set use that as the shapes of the outputs + // else return unknown shapes. + if (output_shapes.empty()) return shape_inference::UnknownShape(c); + if (output_shapes.size() != c->num_outputs()) { + return errors::InvalidArgument( + "`output_shapes` must be the same length as num outputs (", + output_shapes.size(), " vs. ", c->num_outputs()); + } + for (size_t i = 0; i < output_shapes.size(); ++i) { + shape_inference::ShapeHandle output_shape_handle; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( + output_shapes[i], &output_shape_handle)); + c->set_output(static_cast<int>(i), output_shape_handle); + } + return Status::OK(); + }); // TODO(drpng): remove this. REGISTER_OP("_While") |