diff options
author    | (name lost in page extraction) | 2018-10-02 17:57:49 -0700
committer | (name lost in page extraction) | 2018-10-02 18:01:17 -0700
commit | 9f7a138640408cea58698a432fd1596cf436b484 (patch) | |
tree | d3f66d44d654333c94ebbfec002858e8238ac583 | |
parent | b7e9cbab27c893283acc4a6154d7a59dffb23758 (diff) |
Set shape for output tensors of cond_v2.
PiperOrigin-RevId: 215492782
-rw-r--r-- | tensorflow/core/ops/functional_ops.cc                      | 21
-rw-r--r-- | tensorflow/python/kernel_tests/control_flow_ops_py_test.py |  7
-rw-r--r-- | tensorflow/python/ops/cond_v2_impl.py                      | 20
3 files changed, 44 insertions, 4 deletions
diff --git a/tensorflow/core/ops/functional_ops.cc b/tensorflow/core/ops/functional_ops.cc index fed3fa22ed..22b4b07eff 100644 --- a/tensorflow/core/ops/functional_ops.cc +++ b/tensorflow/core/ops/functional_ops.cc @@ -110,8 +110,27 @@ REGISTER_OP("If") .Attr("Tout: list(type) >= 0") .Attr("then_branch: func") .Attr("else_branch: func") + .Attr("output_shapes: list(shape) = []") .SetIsStateful() - .SetShapeFn(shape_inference::UnknownShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + std::vector<PartialTensorShape> output_shapes; + TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); + // If `output_shapes` attr is set use that as the shapes of the outputs + // else return unknown shapes. + if (output_shapes.empty()) return shape_inference::UnknownShape(c); + if (output_shapes.size() != c->num_outputs()) { + return errors::InvalidArgument( + "`output_shapes` must be the same length as num outputs (", + output_shapes.size(), " vs. ", c->num_outputs()); + } + for (size_t i = 0; i < output_shapes.size(); ++i) { + shape_inference::ShapeHandle output_shape_handle; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( + output_shapes[i], &output_shape_handle)); + c->set_output(static_cast<int>(i), output_shape_handle); + } + return Status::OK(); + }); // TODO(drpng): remove this. 
REGISTER_OP("_While") diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 07ec859766..a1be77601c 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -351,6 +351,13 @@ class ControlFlowTest(test.TestCase): grad = gradients_impl.gradients(y, [v]) self.assertAllEqual([None], grad) + def testCondOutputShape(self): + x = constant_op.constant(1.0) + b = control_flow_ops.cond( + constant_op.constant(True), lambda: math_ops.square(x), + lambda: math_ops.subtract(x, 1.)) + self.assertEqual(b.shape, tensor_shape.scalar()) + def testFetchable(self): with self.cached_session() as sess: x = array_ops.placeholder(dtypes.float32) diff --git a/tensorflow/python/ops/cond_v2_impl.py b/tensorflow/python/ops/cond_v2_impl.py index f8b1ddb140..195ad11c71 100644 --- a/tensorflow/python/ops/cond_v2_impl.py +++ b/tensorflow/python/ops/cond_v2_impl.py @@ -96,9 +96,12 @@ def cond_v2(pred, true_fn, false_fn, name="cond"): # Create the If op. tensors = gen_functional_ops._if( # pylint: disable=protected-access - pred, cond_inputs, [t.dtype for t in true_graph.outputs], + pred, + cond_inputs, [t.dtype for t in true_graph.outputs], _create_new_tf_function(true_graph), _create_new_tf_function(false_graph), + output_shapes=_get_output_shapes(true_graph.outputs, + false_graph.outputs), name=scope) # Set the flag to enable lowering on the `if` op if necessary @@ -175,9 +178,12 @@ def _IfGrad(op, *grads): # pylint: disable=invalid-name # Create the gradient If op. 
tensors = gen_functional_ops._if( - op.inputs[0], grad_inputs, [t.dtype for t in true_grad_graph.outputs], + op.inputs[0], + grad_inputs, [t.dtype for t in true_grad_graph.outputs], _create_new_tf_function(true_grad_graph), - _create_new_tf_function(false_grad_graph)) + _create_new_tf_function(false_grad_graph), + output_shapes=_get_output_shapes(true_grad_graph.outputs, + false_grad_graph.outputs)) # The predicate has no gradient. return [None] + tensors[:num_grad_outputs] @@ -480,6 +486,14 @@ def _check_same_outputs(true_graph, false_graph): " false_fn: %s" % (true_output_types, false_output_types)) +def _get_output_shapes(true_graph_outputs, false_graph_outputs): + output_shapes = [ + t_out.shape.most_specific_compatible_shape(f_out.shape) + for t_out, f_out in zip(true_graph_outputs, false_graph_outputs) + ] + return output_shapes + + def _is_ancestor(graph, maybe_ancestor): if maybe_ancestor == graph: return True |