aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/math_ops.py
diff options
context:
space:
mode:
authorGravatar Michael Case <mikecase@google.com>2018-02-07 14:36:00 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-07 14:39:49 -0800
commitd90054e7c0f41f4bab81df0548577a73b939a87a (patch)
treea15aea686a9d3f305e316d2a6ada0859ad8170d1 /tensorflow/python/ops/math_ops.py
parent8461760f9f6cde8ed97507484d2a879140141032 (diff)
Merge changes from github.
PiperOrigin-RevId: 184897758
Diffstat (limited to 'tensorflow/python/ops/math_ops.py')
-rw-r--r--tensorflow/python/ops/math_ops.py10
1 files changed, 7 insertions, 3 deletions
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 827e3caa36..9a8ac93de9 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -2826,10 +2826,14 @@ def tensordot(a, b, axes, name=None):
"""Generates two sets of contraction axes for the two tensor arguments."""
a_shape = a.get_shape()
if isinstance(axes, compat.integral_types):
- if axes < 1:
- raise ValueError("'axes' must be at least 1.")
+ if axes < 0:
+ raise ValueError("'axes' must be at least 0.")
if a_shape.ndims is not None:
- return range(a_shape.ndims - axes, a_shape.ndims), range(axes)
+ if axes > a_shape.ndims:
+ raise ValueError("'axes' must not be larger than the number of "
+ "dimensions of tensor %s." % a)
+ return (list(xrange(a_shape.ndims - axes, a_shape.ndims)),
+ list(xrange(axes)))
else:
rank = array_ops.rank(a)
return (range(rank - axes, rank, dtype=dtypes.int32),