aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops
diff options
context:
space:
mode:
authorGravatar Saurabh Saxena <srbs@google.com>2018-10-02 17:57:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-02 18:01:17 -0700
commit9f7a138640408cea58698a432fd1596cf436b484 (patch)
treed3f66d44d654333c94ebbfec002858e8238ac583 /tensorflow/python/ops
parentb7e9cbab27c893283acc4a6154d7a59dffb23758 (diff)
Set shape for output tensors of cond_v2.
PiperOrigin-RevId: 215492782
Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r--tensorflow/python/ops/cond_v2_impl.py20
1 files changed, 17 insertions, 3 deletions
diff --git a/tensorflow/python/ops/cond_v2_impl.py b/tensorflow/python/ops/cond_v2_impl.py
index f8b1ddb140..195ad11c71 100644
--- a/tensorflow/python/ops/cond_v2_impl.py
+++ b/tensorflow/python/ops/cond_v2_impl.py
@@ -96,9 +96,12 @@ def cond_v2(pred, true_fn, false_fn, name="cond"):
# Create the If op.
tensors = gen_functional_ops._if( # pylint: disable=protected-access
- pred, cond_inputs, [t.dtype for t in true_graph.outputs],
+ pred,
+ cond_inputs, [t.dtype for t in true_graph.outputs],
_create_new_tf_function(true_graph),
_create_new_tf_function(false_graph),
+ output_shapes=_get_output_shapes(true_graph.outputs,
+ false_graph.outputs),
name=scope)
# Set the flag to enable lowering on the `if` op if necessary
@@ -175,9 +178,12 @@ def _IfGrad(op, *grads): # pylint: disable=invalid-name
# Create the gradient If op.
tensors = gen_functional_ops._if(
- op.inputs[0], grad_inputs, [t.dtype for t in true_grad_graph.outputs],
+ op.inputs[0],
+ grad_inputs, [t.dtype for t in true_grad_graph.outputs],
_create_new_tf_function(true_grad_graph),
- _create_new_tf_function(false_grad_graph))
+ _create_new_tf_function(false_grad_graph),
+ output_shapes=_get_output_shapes(true_grad_graph.outputs,
+ false_grad_graph.outputs))
# The predicate has no gradient.
return [None] + tensors[:num_grad_outputs]
@@ -480,6 +486,14 @@ def _check_same_outputs(true_graph, false_graph):
" false_fn: %s" % (true_output_types, false_output_types))
+def _get_output_shapes(true_graph_outputs, false_graph_outputs):
+ output_shapes = [
+ t_out.shape.most_specific_compatible_shape(f_out.shape)
+ for t_out, f_out in zip(true_graph_outputs, false_graph_outputs)
+ ]
+ return output_shapes
+
+
def _is_ancestor(graph, maybe_ancestor):
if maybe_ancestor == graph:
return True