author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-03-07 19:39:07 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2017-03-07 19:50:00 -0800
commit    6f00d6a8fff6a32e7cf95449420e340ce8fa8f21 (patch)
tree      02913dce224e268461c1e65c5b979153eafb10ae
parent    a3e975073b309cb2138e7279ba201f2d0f8a8469 (diff)
Add missing unit test case for tensordot and fix typos in implementation.
Change: 149498198
-rw-r--r--  tensorflow/python/kernel_tests/tensordot_op_test.py | 26
-rw-r--r--  tensorflow/python/ops/math_ops.py                   | 87
2 files changed, 67 insertions(+), 46 deletions(-)
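
For context, the semantics the new tests pin down, sketched against NumPy's tensordot (the reference implementation the test file compares against); the arrays here are illustrative, not from this commit:

import numpy as np

a = np.arange(24.0).reshape(2, 3, 4)
b = np.arange(12.0).reshape(4, 3)

# Scalar form: axes=N contracts the last N axes of `a` against the
# first N axes of `b`.
print(np.tensordot(a, b, axes=1).shape)           # (2, 3, 3)

# List form: contract axis 1 of `a` with axis 1 of `b`.
print(np.tensordot(a, b, axes=[[1], [1]]).shape)  # (2, 4, 4)
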
diff --git a/tensorflow/python/kernel_tests/tensordot_op_test.py b/tensorflow/python/kernel_tests/tensordot_op_test.py
index a147fc1dbf..bce7e30b68 100644
--- a/tensorflow/python/kernel_tests/tensordot_op_test.py
+++ b/tensorflow/python/kernel_tests/tensordot_op_test.py
@@ -54,10 +54,10 @@ class TensordotTest(test_lib.TestCase):
b_ph = array_ops.placeholder(dtypes.float32)
axes_ph = array_ops.placeholder(dtypes.int32)
output = math_ops.tensordot(a_ph, b_ph, axes_ph)
- _ = sess.run([output],
- feed_dict={a_ph: a,
- b_ph: b,
- axes_ph: (a_axes, b_axes)})
+ _ = sess.run(
+ [output], feed_dict={a_ph: a,
+ b_ph: b,
+ axes_ph: (a_axes, b_axes)})
def test_invalid_axes(self):
a = [[1, 2], [3, 4]]
@@ -79,10 +79,10 @@ class TensordotTest(test_lib.TestCase):
for axes_value in 1, [1], [0, 1], [[1]], [[0, 1]], [[0], [7]]:
with self.test_session() as sess:
with self.assertRaises(errors_impl.InvalidArgumentError):
- _ = sess.run([output],
- feed_dict={a_ph: a,
- b_ph: b,
- axes_ph: axes_value})
+ _ = sess.run(
+ [output], feed_dict={a_ph: a,
+ b_ph: b,
+ axes_ph: axes_value})
def test_no_partial_shape_inference(self):
# If one of the shapes is only partially defined, the output shape is
@@ -173,8 +173,14 @@ def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_):
all_axes.append(a_np.ndim - 1)
for axes in all_axes:
np_ans = np.tensordot(a_np, b_np, axes=axes)
- with self.test_session(use_gpu=True):
- tf_ans = math_ops.tensordot(a_np, b_np, axes=axes).eval()
+ with self.test_session(use_gpu=True) as sess:
+ if dynamic_shape_:
+ a = array_ops.placeholder(dtype_)
+ b = array_ops.placeholder(dtype_)
+ c = math_ops.tensordot(a, b, axes=axes)
+ tf_ans = sess.run(c, feed_dict={a: a_np, b: b_np})
+ else:
+ tf_ans = math_ops.tensordot(a_np, b_np, axes=axes).eval()
self.assertAllClose(tf_ans, np_ans, rtol=tol, atol=tol)
self.assertAllEqual(tf_ans.shape, np_ans.shape)
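
The dynamic_shape_ branch added above is what exercises the code path fixed in math_ops.py below: placeholders with no static shape force tensordot to derive its contraction axes from the runtime rank. A minimal standalone sketch of that pattern, assuming the TF 1.x graph API this test file is written against:

import numpy as np
import tensorflow as tf  # 1.x-era graph API, matching this test file

a_np = np.random.rand(3, 4).astype(np.float32)
b_np = np.random.rand(4, 5).astype(np.float32)

# No static shape on the placeholders, so tensordot must build its
# axis ranges from tf.rank(...) when the graph runs.
a = tf.placeholder(tf.float32)
b = tf.placeholder(tf.float32)
c = tf.tensordot(a, b, axes=1)

with tf.Session() as sess:
    tf_ans = sess.run(c, feed_dict={a: a_np, b: b_np})

np.testing.assert_allclose(tf_ans, np.tensordot(a_np, b_np, axes=1),
                           rtol=1e-5)
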
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index e1b52a4086..2649cda7a6 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -160,7 +160,6 @@ from tensorflow.python.ops.gen_math_ops import *
from tensorflow.python.util import compat
from tensorflow.python.util.deprecation import deprecated
-
# Aliases for some automatically-generated names.
linspace = gen_math_ops.lin_space
@@ -177,8 +176,9 @@ def argmax(input, axis=None, name=None, dimension=None):
return gen_math_ops.arg_max(input, axis, name)
-argmax.__doc__ = (gen_math_ops.arg_max.__doc__.replace(
- "dimensions", "axes").replace("dimension", "axis"))
+argmax.__doc__ = (gen_math_ops.arg_max.__doc__.replace("dimensions",
+ "axes").replace(
+ "dimension", "axis"))
# TODO(aselle:deprecate arg_min)
@@ -192,8 +192,9 @@ def argmin(input, axis=None, name=None, dimension=None):
return gen_math_ops.arg_min(input, axis, name)
-argmin.__doc__ = (gen_math_ops.arg_min.__doc__.replace(
- "dimensions", "axes").replace("dimension", "axis"))
+argmin.__doc__ = (gen_math_ops.arg_min.__doc__.replace("dimensions",
+ "axes").replace(
+ "dimension", "axis"))
# pylint: enable=redefined-builtin
@@ -232,6 +233,8 @@ def abs(x, name=None):
if x.dtype in (dtypes.complex64, dtypes.complex128):
return gen_math_ops._complex_abs(x, Tout=x.dtype.real_dtype, name=name)
return gen_math_ops._abs(x, name=name)
+
+
# pylint: enable=g-docstring-has-escape
@@ -271,6 +274,8 @@ def divide(x, y, name=None):
def multiply(x, y, name=None):
return gen_math_ops._mul(x, y, name)
+
+
multiply.__doc__ = gen_math_ops._mul.__doc__.replace("Mul", "`tf.multiply`")
@@ -280,12 +285,16 @@ multiply.__doc__ = gen_math_ops._mul.__doc__.replace("Mul", "`tf.multiply`")
"`tf.mul(x, y)` is deprecated, please use `tf.multiply(x, y)` or `x * y`")
def _mul(x, y, name=None):
return gen_math_ops._mul(x, y, name)
-_mul.__doc__ = (gen_math_ops._mul.__doc__
- + ("" if _mul.__doc__ is None else _mul.__doc__))
+
+
+_mul.__doc__ = (gen_math_ops._mul.__doc__ +
+ ("" if _mul.__doc__ is None else _mul.__doc__))
def subtract(x, y, name=None):
return gen_math_ops._sub(x, y, name)
+
+
subtract.__doc__ = gen_math_ops._sub.__doc__.replace("`Sub`", "`tf.subtract`")
@@ -295,8 +304,10 @@ subtract.__doc__ = gen_math_ops._sub.__doc__.replace("`Sub`", "`tf.subtract`")
"`tf.sub(x, y)` is deprecated, please use `tf.subtract(x, y)` or `x - y`")
def _sub(x, y, name=None):
return gen_math_ops._sub(x, y, name)
-_sub.__doc__ = (gen_math_ops._sub.__doc__
- + ("" if _sub.__doc__ is None else _sub.__doc__))
+
+
+_sub.__doc__ = (gen_math_ops._sub.__doc__ +
+ ("" if _sub.__doc__ is None else _sub.__doc__))
# pylint: disable=g-docstring-has-escape
@@ -320,13 +331,14 @@ def negative(x, name=None):
indices=x.indices, values=x_neg, dense_shape=x.dense_shape)
else:
return gen_math_ops._neg(x, name=name)
+
+
# pylint: enable=g-docstring-has-escape
# pylint: disable=g-docstring-has-escape
-@deprecated(
- "2016-12-30",
- "`tf.neg(x)` is deprecated, please use `tf.negative(x)` or `-x`")
+@deprecated("2016-12-30",
+ "`tf.neg(x)` is deprecated, please use `tf.negative(x)` or `-x`")
def _neg(x, name=None):
"""Computes numerical negative value element-wise.
@@ -341,6 +353,8 @@ def _neg(x, name=None):
A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
"""
return negative(x, name)
+
+
# pylint: enable=g-docstring-has-escape
@@ -525,8 +539,7 @@ def complex(real, imag, name=None):
Tout = dtypes.complex64
else:
raise TypeError("real and imag have incorrect types: "
- "{} {}".format(real.dtype.name,
- imag.dtype.name))
+ "{} {}".format(real.dtype.name, imag.dtype.name))
return gen_math_ops._complex(real, imag, Tout=Tout, name=name)
@@ -681,15 +694,15 @@ def saturate_cast(value, dtype, name=None):
value = ops.convert_to_tensor(value, name="value")
dtype = dtypes.as_dtype(dtype).base_dtype
if value.dtype.min < dtype.min:
- value = gen_math_ops.maximum(
- value,
- ops.convert_to_tensor(
- dtype.min, dtype=value.dtype, name="min"))
+ value = gen_math_ops.maximum(value,
+ ops.convert_to_tensor(
+ dtype.min, dtype=value.dtype,
+ name="min"))
if value.dtype.max > dtype.max:
- value = gen_math_ops.minimum(
- value,
- ops.convert_to_tensor(
- dtype.max, dtype=value.dtype, name="max"))
+ value = gen_math_ops.minimum(value,
+ ops.convert_to_tensor(
+ dtype.max, dtype=value.dtype,
+ name="max"))
return cast(value, dtype, name=name)
@@ -802,11 +815,13 @@ def _OverrideBinaryOperatorHelper(func, op_name, clazz_object=ops.Tensor):
def binary_op_wrapper_sparse(sp_x, y):
with ops.name_scope(None, op_name, [sp_x, y]) as name:
y = ops.convert_to_tensor(y, dtype=sp_x.dtype.base_dtype, name="y")
- return sparse_tensor.SparseTensor(
- sp_x.indices,
- func(
- sp_x.indices, sp_x.values, sp_x.dense_shape, y, name=name),
- sp_x.dense_shape)
+ return sparse_tensor.SparseTensor(sp_x.indices,
+ func(
+ sp_x.indices,
+ sp_x.values,
+ sp_x.dense_shape,
+ y,
+ name=name), sp_x.dense_shape)
def r_binary_op_wrapper(y, x):
with ops.name_scope(None, op_name, [x, y]) as name:
@@ -854,8 +869,8 @@ _TRUEDIV_TABLE = {
# to explicitly use the "/" operator to invoke either truediv or div.
def _sparse_dense_truediv(sp_indices, sp_values, sp_shape, y, name=None):
"""Internal helper function for 'sp_t / dense_t'."""
- with ops.name_scope(name, "truediv",
- [sp_indices, sp_values, sp_shape, y]) as name:
+ with ops.name_scope(name, "truediv", [sp_indices, sp_values, sp_shape,
+ y]) as name:
sp_values = ops.convert_to_tensor(sp_values, name="sp_values")
y = ops.convert_to_tensor(y, name="y")
x_dtype = sp_values.dtype.base_dtype
@@ -1129,8 +1144,9 @@ def range(start, limit=None, delta=1, dtype=None, name="range"):
dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64
]
assert all(arg.dtype in dtype_hierarchy for arg in [start, limit, delta])
- inferred_dtype = max([arg.dtype for arg in [start, limit, delta]],
- key=dtype_hierarchy.index)
+ inferred_dtype = max(
+ [arg.dtype for arg in [start, limit, delta]],
+ key=dtype_hierarchy.index)
start = cast(start, inferred_dtype)
limit = cast(limit, inferred_dtype)
@@ -1941,8 +1957,8 @@ def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None):
zeros.set_shape(shape)
ref = state_ops.assign(var, zeros, validate_shape=False)
update_ops = [
- state_ops.assign_add(
- ref, input_tensor, use_locking=True) for input_tensor in inputs
+ state_ops.assign_add(ref, input_tensor, use_locking=True)
+ for input_tensor in inputs
]
with ops.control_dependencies(update_ops):
return gen_state_ops._destroy_temporary_variable(
@@ -2270,9 +2286,8 @@ def tensordot(a, b, axes, name=None):
return range(a_shape.ndims - axes, a_shape.ndims), range(axes)
else:
rank = array_ops.rank(a)
- return (array_ops.range(
- rank - axes, rank, dtype=dtypes.int32), array_ops.range(
- rank, dtype=dtypes.int32))
+ return (range(rank - axes, rank, dtype=dtypes.int32), range(
+ axes, dtype=dtypes.int32))
elif isinstance(axes, (list, tuple)):
if len(axes) != 2:
raise ValueError("'axes' must be an integer or have length 2.")
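
The typo fixed in this last hunk is the substantive change: on the dynamic-shape path with a scalar `axes`, the second element of the returned pair was built from `range(rank, ...)` (every axis of `b`) instead of `range(axes, ...)` (the first `axes` axes of `b`), and the calls referenced `array_ops.range` where math_ops' own dtype-aware `range` was intended. A plain-Python sketch of the intended pairing; the helper name is hypothetical, not from the source:

# Hypothetical helper mirroring the fixed logic: for a scalar `axes`,
# contract the last `axes` dimensions of `a` against the first `axes`
# dimensions of `b`. The pre-fix code returned range(rank) for `b`,
# i.e. every axis, which is only correct when axes == rank.
def scalar_tensordot_axes(rank, axes):
    a_axes = list(range(rank - axes, rank))  # last `axes` dims of `a`
    b_axes = list(range(axes))               # first `axes` dims of `b`
    return a_axes, b_axes

print(scalar_tensordot_axes(3, 2))  # ([1, 2], [0, 1])
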