diff options
author | Saurabh Saxena <srbs@google.com> | 2018-10-02 17:57:49 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-02 18:01:17 -0700 |
commit | 9f7a138640408cea58698a432fd1596cf436b484 (patch) | |
tree | d3f66d44d654333c94ebbfec002858e8238ac583 /tensorflow/python/ops | |
parent | b7e9cbab27c893283acc4a6154d7a59dffb23758 (diff) |
Set shape for output tensors of cond_v2.
PiperOrigin-RevId: 215492782
Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r-- | tensorflow/python/ops/cond_v2_impl.py | 20 |
1 files changed, 17 insertions, 3 deletions
diff --git a/tensorflow/python/ops/cond_v2_impl.py b/tensorflow/python/ops/cond_v2_impl.py index f8b1ddb140..195ad11c71 100644 --- a/tensorflow/python/ops/cond_v2_impl.py +++ b/tensorflow/python/ops/cond_v2_impl.py @@ -96,9 +96,12 @@ def cond_v2(pred, true_fn, false_fn, name="cond"): # Create the If op. tensors = gen_functional_ops._if( # pylint: disable=protected-access - pred, cond_inputs, [t.dtype for t in true_graph.outputs], + pred, + cond_inputs, [t.dtype for t in true_graph.outputs], _create_new_tf_function(true_graph), _create_new_tf_function(false_graph), + output_shapes=_get_output_shapes(true_graph.outputs, + false_graph.outputs), name=scope) # Set the flag to enable lowering on the `if` op if necessary @@ -175,9 +178,12 @@ def _IfGrad(op, *grads): # pylint: disable=invalid-name # Create the gradient If op. tensors = gen_functional_ops._if( - op.inputs[0], grad_inputs, [t.dtype for t in true_grad_graph.outputs], + op.inputs[0], + grad_inputs, [t.dtype for t in true_grad_graph.outputs], _create_new_tf_function(true_grad_graph), - _create_new_tf_function(false_grad_graph)) + _create_new_tf_function(false_grad_graph), + output_shapes=_get_output_shapes(true_grad_graph.outputs, + false_grad_graph.outputs)) # The predicate has no gradient. return [None] + tensors[:num_grad_outputs] @@ -480,6 +486,14 @@ def _check_same_outputs(true_graph, false_graph): " false_fn: %s" % (true_output_types, false_output_types)) +def _get_output_shapes(true_graph_outputs, false_graph_outputs): + output_shapes = [ + t_out.shape.most_specific_compatible_shape(f_out.shape) + for t_out, f_out in zip(true_graph_outputs, false_graph_outputs) + ] + return output_shapes + + def _is_ancestor(graph, maybe_ancestor): if maybe_ancestor == graph: return True |