aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/all_reduce
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-28 11:34:32 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-28 11:36:34 -0700
commitcc9944abe196827bae38975d813ee3e428349dcb (patch)
tree8655d58161ad7885df1c806dc3fe94ab5891fab8 /tensorflow/contrib/all_reduce
parentb8384bbe0325c5b1c20838f9e6fd494e78e299dc (diff)
In contrib/all_reduce raise a ValueError if the input tensors
do not have fully-defined shapes. PiperOrigin-RevId: 190804146
Diffstat (limited to 'tensorflow/contrib/all_reduce')
-rw-r--r--tensorflow/contrib/all_reduce/python/all_reduce.py7
-rw-r--r--tensorflow/contrib/all_reduce/python/all_reduce_test.py6
2 files changed, 9 insertions, 4 deletions
diff --git a/tensorflow/contrib/all_reduce/python/all_reduce.py b/tensorflow/contrib/all_reduce/python/all_reduce.py
index 6658f0d9c1..8add2aacff 100644
--- a/tensorflow/contrib/all_reduce/python/all_reduce.py
+++ b/tensorflow/contrib/all_reduce/python/all_reduce.py
@@ -38,16 +38,15 @@ def _flatten_tensors(tensors):
shape: the original shape of each element of input tensors
Raises:
- ValueError: tensors are empty or non-isomorphic.
+ ValueError: tensors are empty or non-isomorphic or have unknown shape.
"""
if not tensors:
raise ValueError("tensors cannot be empty")
shape = tensors[0].shape
for tensor in tensors:
shape = shape.merge_with(tensor.shape)
- if shape.ndims is None:
- raise ValueError("At least one of the tensors in 'tensors' must have "
- "statically known rank.")
+ if not shape.is_fully_defined():
+ raise ValueError("Tensors must have statically known shape.")
if len(shape) != 1:
reshaped = []
for t in tensors:
diff --git a/tensorflow/contrib/all_reduce/python/all_reduce_test.py b/tensorflow/contrib/all_reduce/python/all_reduce_test.py
index 47bab0a367..b3f5d92259 100644
--- a/tensorflow/contrib/all_reduce/python/all_reduce_test.py
+++ b/tensorflow/contrib/all_reduce/python/all_reduce_test.py
@@ -36,6 +36,12 @@ from tensorflow.python.platform import tf_logging
class AllReduceTest(test_util.TensorFlowTestCase):
+ def testFlattenTensorsShapesDefined(self):
+ x = array_ops.placeholder(types_pb2.DT_FLOAT, [None])
+ with self.assertRaisesRegexp(ValueError,
+ "must have statically known shape"):
+ ar._flatten_tensors([x, x])
+
def testRingPermutations(self):
# 0 devices
pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 0, [])