From 6f00d6a8fff6a32e7cf95449420e340ce8fa8f21 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 7 Mar 2017 19:39:07 -0800 Subject: Add missing unit test case for tensordot and fix typos in implementation. Change: 149498198 --- .../python/kernel_tests/tensordot_op_test.py | 26 ++++--- tensorflow/python/ops/math_ops.py | 87 +++++++++++++--------- 2 files changed, 67 insertions(+), 46 deletions(-) diff --git a/tensorflow/python/kernel_tests/tensordot_op_test.py b/tensorflow/python/kernel_tests/tensordot_op_test.py index a147fc1dbf..bce7e30b68 100644 --- a/tensorflow/python/kernel_tests/tensordot_op_test.py +++ b/tensorflow/python/kernel_tests/tensordot_op_test.py @@ -54,10 +54,10 @@ class TensordotTest(test_lib.TestCase): b_ph = array_ops.placeholder(dtypes.float32) axes_ph = array_ops.placeholder(dtypes.int32) output = math_ops.tensordot(a_ph, b_ph, axes_ph) - _ = sess.run([output], - feed_dict={a_ph: a, - b_ph: b, - axes_ph: (a_axes, b_axes)}) + _ = sess.run( + [output], feed_dict={a_ph: a, + b_ph: b, + axes_ph: (a_axes, b_axes)}) def test_invalid_axes(self): a = [[1, 2], [3, 4]] @@ -79,10 +79,10 @@ class TensordotTest(test_lib.TestCase): for axes_value in 1, [1], [0, 1], [[1]], [[0, 1]], [[0], [7]]: with self.test_session() as sess: with self.assertRaises(errors_impl.InvalidArgumentError): - _ = sess.run([output], - feed_dict={a_ph: a, - b_ph: b, - axes_ph: axes_value}) + _ = sess.run( + [output], feed_dict={a_ph: a, + b_ph: b, + axes_ph: axes_value}) def test_no_partial_shape_inference(self): # If one of the shapes is only partially defined, the output shape is @@ -173,8 +173,14 @@ def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_): all_axes.append(a_np.ndim - 1) for axes in all_axes: np_ans = np.tensordot(a_np, b_np, axes=axes) - with self.test_session(use_gpu=True): - tf_ans = math_ops.tensordot(a_np, b_np, axes=axes).eval() + with self.test_session(use_gpu=True) as sess: + if dynamic_shape_: + a = array_ops.placeholder(dtype_) + b = array_ops.placeholder(dtype_) + c = math_ops.tensordot(a, b, axes=axes) + tf_ans = sess.run(c, feed_dict={a: a_np, b: b_np}) + else: + tf_ans = math_ops.tensordot(a_np, b_np, axes=axes).eval() self.assertAllClose(tf_ans, np_ans, rtol=tol, atol=tol) self.assertAllEqual(tf_ans.shape, np_ans.shape) diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index e1b52a4086..2649cda7a6 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -160,7 +160,6 @@ from tensorflow.python.ops.gen_math_ops import * from tensorflow.python.util import compat from tensorflow.python.util.deprecation import deprecated - # Aliases for some automatically-generated names. linspace = gen_math_ops.lin_space @@ -177,8 +176,9 @@ def argmax(input, axis=None, name=None, dimension=None): return gen_math_ops.arg_max(input, axis, name) -argmax.__doc__ = (gen_math_ops.arg_max.__doc__.replace( - "dimensions", "axes").replace("dimension", "axis")) +argmax.__doc__ = (gen_math_ops.arg_max.__doc__.replace("dimensions", + "axes").replace( + "dimension", "axis")) # TODO(aselle:deprecate arg_min) @@ -192,8 +192,9 @@ def argmin(input, axis=None, name=None, dimension=None): return gen_math_ops.arg_min(input, axis, name) -argmin.__doc__ = (gen_math_ops.arg_min.__doc__.replace( - "dimensions", "axes").replace("dimension", "axis")) +argmin.__doc__ = (gen_math_ops.arg_min.__doc__.replace("dimensions", + "axes").replace( + "dimension", "axis")) # pylint: enable=redefined-builtin @@ -232,6 +233,8 @@ def abs(x, name=None): if x.dtype in (dtypes.complex64, dtypes.complex128): return gen_math_ops._complex_abs(x, Tout=x.dtype.real_dtype, name=name) return gen_math_ops._abs(x, name=name) + + # pylint: enable=g-docstring-has-escape @@ -271,6 +274,8 @@ def divide(x, y, name=None): def multiply(x, y, name=None): return gen_math_ops._mul(x, y, name) + + multiply.__doc__ = gen_math_ops._mul.__doc__.replace("Mul", "`tf.multiply`") @@ -280,12 +285,16 @@ multiply.__doc__ = gen_math_ops._mul.__doc__.replace("Mul", "`tf.multiply`") "`tf.mul(x, y)` is deprecated, please use `tf.multiply(x, y)` or `x * y`") def _mul(x, y, name=None): return gen_math_ops._mul(x, y, name) -_mul.__doc__ = (gen_math_ops._mul.__doc__ - + ("" if _mul.__doc__ is None else _mul.__doc__)) + + +_mul.__doc__ = (gen_math_ops._mul.__doc__ + + ("" if _mul.__doc__ is None else _mul.__doc__)) def subtract(x, y, name=None): return gen_math_ops._sub(x, y, name) + + subtract.__doc__ = gen_math_ops._sub.__doc__.replace("`Sub`", "`tf.subtract`") @@ -295,8 +304,10 @@ subtract.__doc__ = gen_math_ops._sub.__doc__.replace("`Sub`", "`tf.subtract`") "`tf.sub(x, y)` is deprecated, please use `tf.subtract(x, y)` or `x - y`") def _sub(x, y, name=None): return gen_math_ops._sub(x, y, name) -_sub.__doc__ = (gen_math_ops._sub.__doc__ - + ("" if _sub.__doc__ is None else _sub.__doc__)) + + +_sub.__doc__ = (gen_math_ops._sub.__doc__ + + ("" if _sub.__doc__ is None else _sub.__doc__)) # pylint: disable=g-docstring-has-escape @@ -320,13 +331,14 @@ def negative(x, name=None): indices=x.indices, values=x_neg, dense_shape=x.dense_shape) else: return gen_math_ops._neg(x, name=name) + + # pylint: enable=g-docstring-has-escape # pylint: disable=g-docstring-has-escape -@deprecated( - "2016-12-30", - "`tf.neg(x)` is deprecated, please use `tf.negative(x)` or `-x`") +@deprecated("2016-12-30", + "`tf.neg(x)` is deprecated, please use `tf.negative(x)` or `-x`") def _neg(x, name=None): """Computes numerical negative value element-wise. @@ -341,6 +353,8 @@ def _neg(x, name=None): A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`. """ return negative(x, name) + + # pylint: enable=g-docstring-has-escape @@ -525,8 +539,7 @@ def complex(real, imag, name=None): Tout = dtypes.complex64 else: raise TypeError("real and imag have incorrect types: " - "{} {}".format(real.dtype.name, - imag.dtype.name)) + "{} {}".format(real.dtype.name, imag.dtype.name)) return gen_math_ops._complex(real, imag, Tout=Tout, name=name) @@ -681,15 +694,15 @@ def saturate_cast(value, dtype, name=None): value = ops.convert_to_tensor(value, name="value") dtype = dtypes.as_dtype(dtype).base_dtype if value.dtype.min < dtype.min: - value = gen_math_ops.maximum( - value, - ops.convert_to_tensor( - dtype.min, dtype=value.dtype, name="min")) + value = gen_math_ops.maximum(value, + ops.convert_to_tensor( + dtype.min, dtype=value.dtype, + name="min")) if value.dtype.max > dtype.max: - value = gen_math_ops.minimum( - value, - ops.convert_to_tensor( - dtype.max, dtype=value.dtype, name="max")) + value = gen_math_ops.minimum(value, + ops.convert_to_tensor( + dtype.max, dtype=value.dtype, + name="max")) return cast(value, dtype, name=name) @@ -802,11 +815,13 @@ def _OverrideBinaryOperatorHelper(func, op_name, clazz_object=ops.Tensor): def binary_op_wrapper_sparse(sp_x, y): with ops.name_scope(None, op_name, [sp_x, y]) as name: y = ops.convert_to_tensor(y, dtype=sp_x.dtype.base_dtype, name="y") - return sparse_tensor.SparseTensor( - sp_x.indices, - func( - sp_x.indices, sp_x.values, sp_x.dense_shape, y, name=name), - sp_x.dense_shape) + return sparse_tensor.SparseTensor(sp_x.indices, + func( + sp_x.indices, + sp_x.values, + sp_x.dense_shape, + y, + name=name), sp_x.dense_shape) def r_binary_op_wrapper(y, x): with ops.name_scope(None, op_name, [x, y]) as name: @@ -854,8 +869,8 @@ _TRUEDIV_TABLE = { # to explicitly use the "/" operator to invoke either truediv or div. def _sparse_dense_truediv(sp_indices, sp_values, sp_shape, y, name=None): """Internal helper function for 'sp_t / dense_t'.""" - with ops.name_scope(name, "truediv", - [sp_indices, sp_values, sp_shape, y]) as name: + with ops.name_scope(name, "truediv", [sp_indices, sp_values, sp_shape, + y]) as name: sp_values = ops.convert_to_tensor(sp_values, name="sp_values") y = ops.convert_to_tensor(y, name="y") x_dtype = sp_values.dtype.base_dtype @@ -1129,8 +1144,9 @@ def range(start, limit=None, delta=1, dtype=None, name="range"): dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64 ] assert all(arg.dtype in dtype_hierarchy for arg in [start, limit, delta]) - inferred_dtype = max([arg.dtype for arg in [start, limit, delta]], - key=dtype_hierarchy.index) + inferred_dtype = max( + [arg.dtype for arg in [start, limit, delta]], + key=dtype_hierarchy.index) start = cast(start, inferred_dtype) limit = cast(limit, inferred_dtype) @@ -1941,8 +1957,8 @@ def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None): zeros.set_shape(shape) ref = state_ops.assign(var, zeros, validate_shape=False) update_ops = [ - state_ops.assign_add( - ref, input_tensor, use_locking=True) for input_tensor in inputs + state_ops.assign_add(ref, input_tensor, use_locking=True) + for input_tensor in inputs ] with ops.control_dependencies(update_ops): return gen_state_ops._destroy_temporary_variable( @@ -2270,9 +2286,8 @@ def tensordot(a, b, axes, name=None): return range(a_shape.ndims - axes, a_shape.ndims), range(axes) else: rank = array_ops.rank(a) - return (array_ops.range( - rank - axes, rank, dtype=dtypes.int32), array_ops.range( - rank, dtype=dtypes.int32)) + return (range(rank - axes, rank, dtype=dtypes.int32), range( + axes, dtype=dtypes.int32)) elif isinstance(axes, (list, tuple)): if len(axes) != 2: raise ValueError("'axes' must be an integer or have length 2.") -- cgit v1.2.3