From cc9944abe196827bae38975d813ee3e428349dcb Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 28 Mar 2018 11:34:32 -0700 Subject: In contrib/all_reduce raise a ValueError if the input tensors do not have fully-defined shapes. PiperOrigin-RevId: 190804146 --- tensorflow/contrib/all_reduce/python/all_reduce.py | 7 +++---- tensorflow/contrib/all_reduce/python/all_reduce_test.py | 6 ++++++ 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/tensorflow/contrib/all_reduce/python/all_reduce.py b/tensorflow/contrib/all_reduce/python/all_reduce.py index 6658f0d9c1..8add2aacff 100644 --- a/tensorflow/contrib/all_reduce/python/all_reduce.py +++ b/tensorflow/contrib/all_reduce/python/all_reduce.py @@ -38,16 +38,15 @@ def _flatten_tensors(tensors): shape: the original shape of each element of input tensors Raises: - ValueError: tensors are empty or non-isomorphic. + ValueError: tensors are empty or non-isomorphic or have unknown shape. """ if not tensors: raise ValueError("tensors cannot be empty") shape = tensors[0].shape for tensor in tensors: shape = shape.merge_with(tensor.shape) - if shape.ndims is None: - raise ValueError("At least one of the tensors in 'tensors' must have " - "statically known rank.") + if not shape.is_fully_defined(): + raise ValueError("Tensors must have statically known shape.") if len(shape) != 1: reshaped = [] for t in tensors: diff --git a/tensorflow/contrib/all_reduce/python/all_reduce_test.py b/tensorflow/contrib/all_reduce/python/all_reduce_test.py index 47bab0a367..b3f5d92259 100644 --- a/tensorflow/contrib/all_reduce/python/all_reduce_test.py +++ b/tensorflow/contrib/all_reduce/python/all_reduce_test.py @@ -36,6 +36,12 @@ from tensorflow.python.platform import tf_logging class AllReduceTest(test_util.TensorFlowTestCase): + def testFlattenTensorsShapesDefined(self): + x = array_ops.placeholder(types_pb2.DT_FLOAT, [None]) + with self.assertRaisesRegexp(ValueError, + "must have statically known shape"): + ar._flatten_tensors([x, x]) + def testRingPermutations(self): # 0 devices pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 0, []) -- cgit v1.2.3