diff options
author | 2018-09-24 21:29:42 -0700 | |
---|---|---|
committer | 2018-09-24 21:34:13 -0700 | |
commit | c1644948d23cae271b140d67101c1a386e5495fd (patch) | |
tree | 002efca36c4f95f75b08358343c3701de014880b /tensorflow/python/ops | |
parent | 9875df75c308d7498e601ae9a4b57db6aad47056 (diff) |
Unpack output of cond_v2 if it is a singleton to match behavior of cond.
PiperOrigin-RevId: 214381126
Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r-- | tensorflow/python/ops/cond_v2_impl.py | 6 | ||||
-rw-r--r-- | tensorflow/python/ops/image_ops_impl.py | 2 |
2 files changed, 5 insertions, 3 deletions
diff --git a/tensorflow/python/ops/cond_v2_impl.py b/tensorflow/python/ops/cond_v2_impl.py index c6a6b2a7fa..f8b1ddb140 100644 --- a/tensorflow/python/ops/cond_v2_impl.py +++ b/tensorflow/python/ops/cond_v2_impl.py @@ -119,7 +119,11 @@ def cond_v2(pred, true_fn, false_fn, name="cond"): attr_value_pb2.AttrValue(b=True)) # pylint: enable=protected-access - return tuple(tensors[:num_cond_outputs]) + result = tuple(tensors[:num_cond_outputs]) + if len(result) == 1: + return result[0] + else: + return result @ops.RegisterGradient("If") diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index 208b56e909..1c75aab578 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -329,8 +329,6 @@ def _random_flip(image, flip_index, seed, scope_name): lambda: image, name=scope ) - if isinstance(result, tuple): - result = result[0] # TODO(b/111124878) remove this logic (CondV2). return fix_image_flip_shape(image, result) elif shape.ndims == 4: batch_size = array_ops.shape(image)[0] |