aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Saurabh Saxena <srbs@google.com>2018-10-02 17:57:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-02 18:01:17 -0700
commit9f7a138640408cea58698a432fd1596cf436b484 (patch)
treed3f66d44d654333c94ebbfec002858e8238ac583
parentb7e9cbab27c893283acc4a6154d7a59dffb23758 (diff)
Set shape for output tensors of cond_v2.
PiperOrigin-RevId: 215492782
-rw-r--r--tensorflow/core/ops/functional_ops.cc21
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py7
-rw-r--r--tensorflow/python/ops/cond_v2_impl.py20
3 files changed, 44 insertions, 4 deletions
diff --git a/tensorflow/core/ops/functional_ops.cc b/tensorflow/core/ops/functional_ops.cc
index fed3fa22ed..22b4b07eff 100644
--- a/tensorflow/core/ops/functional_ops.cc
+++ b/tensorflow/core/ops/functional_ops.cc
@@ -110,8 +110,27 @@ REGISTER_OP("If")
.Attr("Tout: list(type) >= 0")
.Attr("then_branch: func")
.Attr("else_branch: func")
+ .Attr("output_shapes: list(shape) = []")
.SetIsStateful()
- .SetShapeFn(shape_inference::UnknownShape);
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ std::vector<PartialTensorShape> output_shapes;
+ TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
+ // If `output_shapes` attr is set use that as the shapes of the outputs
+ // else return unknown shapes.
+ if (output_shapes.empty()) return shape_inference::UnknownShape(c);
+ if (output_shapes.size() != c->num_outputs()) {
+ return errors::InvalidArgument(
+ "`output_shapes` must be the same length as num outputs (",
+ output_shapes.size(), " vs. ", c->num_outputs());
+ }
+ for (size_t i = 0; i < output_shapes.size(); ++i) {
+ shape_inference::ShapeHandle output_shape_handle;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
+ output_shapes[i], &output_shape_handle));
+ c->set_output(static_cast<int>(i), output_shape_handle);
+ }
+ return Status::OK();
+ });
// TODO(drpng): remove this.
REGISTER_OP("_While")
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 07ec859766..a1be77601c 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -351,6 +351,13 @@ class ControlFlowTest(test.TestCase):
grad = gradients_impl.gradients(y, [v])
self.assertAllEqual([None], grad)
+ def testCondOutputShape(self):
+ x = constant_op.constant(1.0)
+ b = control_flow_ops.cond(
+ constant_op.constant(True), lambda: math_ops.square(x),
+ lambda: math_ops.subtract(x, 1.))
+ self.assertEqual(b.shape, tensor_shape.scalar())
+
def testFetchable(self):
with self.cached_session() as sess:
x = array_ops.placeholder(dtypes.float32)
diff --git a/tensorflow/python/ops/cond_v2_impl.py b/tensorflow/python/ops/cond_v2_impl.py
index f8b1ddb140..195ad11c71 100644
--- a/tensorflow/python/ops/cond_v2_impl.py
+++ b/tensorflow/python/ops/cond_v2_impl.py
@@ -96,9 +96,12 @@ def cond_v2(pred, true_fn, false_fn, name="cond"):
# Create the If op.
tensors = gen_functional_ops._if( # pylint: disable=protected-access
- pred, cond_inputs, [t.dtype for t in true_graph.outputs],
+ pred,
+ cond_inputs, [t.dtype for t in true_graph.outputs],
_create_new_tf_function(true_graph),
_create_new_tf_function(false_graph),
+ output_shapes=_get_output_shapes(true_graph.outputs,
+ false_graph.outputs),
name=scope)
# Set the flag to enable lowering on the `if` op if necessary
@@ -175,9 +178,12 @@ def _IfGrad(op, *grads): # pylint: disable=invalid-name
# Create the gradient If op.
tensors = gen_functional_ops._if(
- op.inputs[0], grad_inputs, [t.dtype for t in true_grad_graph.outputs],
+ op.inputs[0],
+ grad_inputs, [t.dtype for t in true_grad_graph.outputs],
_create_new_tf_function(true_grad_graph),
- _create_new_tf_function(false_grad_graph))
+ _create_new_tf_function(false_grad_graph),
+ output_shapes=_get_output_shapes(true_grad_graph.outputs,
+ false_grad_graph.outputs))
# The predicate has no gradient.
return [None] + tensors[:num_grad_outputs]
@@ -480,6 +486,14 @@ def _check_same_outputs(true_graph, false_graph):
" false_fn: %s" % (true_output_types, false_output_types))
+def _get_output_shapes(true_graph_outputs, false_graph_outputs):
+ output_shapes = [
+ t_out.shape.most_specific_compatible_shape(f_out.shape)
+ for t_out, f_out in zip(true_graph_outputs, false_graph_outputs)
+ ]
+ return output_shapes
+
+
def _is_ancestor(graph, maybe_ancestor):
if maybe_ancestor == graph:
return True