diff options
Diffstat (limited to 'tensorflow/python/ops/math_ops.py')
-rw-r--r-- | tensorflow/python/ops/math_ops.py | 10 |
1 files changed, 7 insertions, 3 deletions
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 827e3caa36..9a8ac93de9 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -2826,10 +2826,14 @@ def tensordot(a, b, axes, name=None): """Generates two sets of contraction axes for the two tensor arguments.""" a_shape = a.get_shape() if isinstance(axes, compat.integral_types): - if axes < 1: - raise ValueError("'axes' must be at least 1.") + if axes < 0: + raise ValueError("'axes' must be at least 0.") if a_shape.ndims is not None: - return range(a_shape.ndims - axes, a_shape.ndims), range(axes) + if axes > a_shape.ndims: + raise ValueError("'axes' must not be larger than the number of " + "dimensions of tensor %s." % a) + return (list(xrange(a_shape.ndims - axes, a_shape.ndims)), + list(xrange(axes))) else: rank = array_ops.rank(a) return (range(rank - axes, rank, dtype=dtypes.int32), |