aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Andrew Selle <aselle@google.com>2017-01-11 09:35:14 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-11 09:48:10 -0800
commitbe270ecb79c8548a0caddf67908189e6169b1472 (patch)
treec6a1b72479dcac3c3ece17b7cf5ddfdfe04e5b8d /tensorflow
parentd4a9d91bc68ffa9f0148ae6fe344b7d3e3de7221 (diff)
Deprecate tf.neg, tf.mul, tf.sub (and remove math_ops.{neg,mul,sub} usages
tf.negative, tf.multiply, tf.subtract are the new names - Also enabled deprecation warning (to be completely removed by friday) Change: 144215355
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/compiler/tests/binary_ops_test.py12
-rw-r--r--tensorflow/compiler/tests/unary_ops_test.py2
-rw-r--r--tensorflow/contrib/labeled_tensor/python/ops/core.py6
-rw-r--r--tensorflow/contrib/labeled_tensor/python/ops/core_test.py6
-rw-r--r--tensorflow/contrib/learn/python/learn/dataframe/transforms/binary_transforms.py2
-rw-r--r--tensorflow/contrib/learn/python/learn/dataframe/transforms/unary_transforms.py2
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops.py4
-rw-r--r--tensorflow/python/kernel_tests/basic_gpu_test.py18
-rw-r--r--tensorflow/python/kernel_tests/cwise_ops_test.py70
-rw-r--r--tensorflow/python/ops/math_ops.py12
-rw-r--r--tensorflow/python/ops/metrics_impl.py2
11 files changed, 68 insertions, 68 deletions
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index 9d197b0646..852c80db1f 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -159,17 +159,17 @@ class BinaryOpsTest(XLATestCase):
expected=np.array([[8], [9]], dtype=dtype))
self._testBinary(
- math_ops.sub,
+ math_ops.subtract,
np.array([1, 2], dtype=dtype),
np.array([10, 20], dtype=dtype),
expected=np.array([-9, -18], dtype=dtype))
self._testBinary(
- math_ops.sub,
+ math_ops.subtract,
dtype(5),
np.array([1, 2], dtype=dtype),
expected=np.array([4, 3], dtype=dtype))
self._testBinary(
- math_ops.sub,
+ math_ops.subtract,
np.array([[1], [2]], dtype=dtype),
dtype(7),
expected=np.array([[-6], [-5]], dtype=dtype))
@@ -207,17 +207,17 @@ class BinaryOpsTest(XLATestCase):
expected=np.array([[7], [2]], dtype=dtype))
self._testBinary(
- math_ops.mul,
+ math_ops.multiply,
np.array([1, 20], dtype=dtype),
np.array([10, 2], dtype=dtype),
expected=np.array([10, 40], dtype=dtype))
self._testBinary(
- math_ops.mul,
+ math_ops.multiply,
dtype(5),
np.array([1, 20], dtype=dtype),
expected=np.array([5, 100], dtype=dtype))
self._testBinary(
- math_ops.mul,
+ math_ops.multiply,
np.array([[10], [2]], dtype=dtype),
dtype(7),
expected=np.array([[70], [14]], dtype=dtype))
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index 35f4b1071a..ff565a9815 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -216,7 +216,7 @@ class UnaryOpsTest(XLATestCase):
expected=np.array([[2, 1]], dtype=dtype))
self._testUnary(
- math_ops.neg,
+ math_ops.negative,
np.array([[-1, 1]], dtype=dtype),
expected=np.array([[1, -1]], dtype=dtype))
diff --git a/tensorflow/contrib/labeled_tensor/python/ops/core.py b/tensorflow/contrib/labeled_tensor/python/ops/core.py
index cd590f9ffd..214e21aa19 100644
--- a/tensorflow/contrib/labeled_tensor/python/ops/core.py
+++ b/tensorflow/contrib/labeled_tensor/python/ops/core.py
@@ -1096,7 +1096,7 @@ def define_unary_op(op_name, elementwise_function):
abs_function = define_unary_op('abs', math_ops.abs)
-neg = define_unary_op('neg', math_ops.neg)
+neg = define_unary_op('neg', math_ops.negative)
sign = define_unary_op('sign', math_ops.sign)
reciprocal = define_unary_op('reciprocal', math_ops.reciprocal)
square = define_unary_op('square', math_ops.square)
@@ -1171,8 +1171,8 @@ def define_binary_op(op_name, elementwise_function):
add = define_binary_op('add', math_ops.add)
-sub = define_binary_op('sub', math_ops.sub)
-mul = define_binary_op('mul', math_ops.mul)
+sub = define_binary_op('sub', math_ops.subtract)
+mul = define_binary_op('mul', math_ops.multiply)
div = define_binary_op('div', math_ops.div)
mod = define_binary_op('mod', math_ops.mod)
pow_function = define_binary_op('pow', math_ops.pow)
diff --git a/tensorflow/contrib/labeled_tensor/python/ops/core_test.py b/tensorflow/contrib/labeled_tensor/python/ops/core_test.py
index 19cc85cb41..1f4a3ef568 100644
--- a/tensorflow/contrib/labeled_tensor/python/ops/core_test.py
+++ b/tensorflow/contrib/labeled_tensor/python/ops/core_test.py
@@ -703,7 +703,7 @@ class CoreUnaryOpsTest(Base, DocStringCheckMixin, UnaryOpsTestsMixin):
self.ops = [
('abs', operator.abs, math_ops.abs, core.abs_function),
- ('neg', operator.neg, math_ops.neg, core.neg),
+ ('neg', operator.neg, math_ops.negative, core.neg),
# TODO(shoyer): add unary + to core TensorFlow
('pos', None, None, None),
('sign', None, math_ops.sign, core.sign),
@@ -780,8 +780,8 @@ class CoreBinaryOpsTest(Base, DocStringCheckMixin, BinaryOpsTestsMixin):
# elementwise for LabeledTensor, either.
self.ops = [
('add', operator.add, math_ops.add, core.add),
- ('sub', operator.sub, math_ops.sub, core.sub),
- ('mul', operator.mul, math_ops.mul, core.mul),
+ ('sub', operator.sub, math_ops.subtract, core.sub),
+ ('mul', operator.mul, math_ops.multiply, core.mul),
('div', operator.truediv, math_ops.div, core.div),
('mod', operator.mod, math_ops.mod, core.mod),
('pow', operator.pow, math_ops.pow, core.pow_function),
diff --git a/tensorflow/contrib/learn/python/learn/dataframe/transforms/binary_transforms.py b/tensorflow/contrib/learn/python/learn/dataframe/transforms/binary_transforms.py
index 14a82f2313..aa3949c8c0 100644
--- a/tensorflow/contrib/learn/python/learn/dataframe/transforms/binary_transforms.py
+++ b/tensorflow/contrib/learn/python/learn/dataframe/transforms/binary_transforms.py
@@ -31,7 +31,7 @@ BINARY_TRANSFORMS = [("__eq__", math_ops.equal),
("__ge__", math_ops.greater_equal),
("__lt__", math_ops.less),
("__le__", math_ops.less_equal),
- ("__mul__", math_ops.mul),
+ ("__mul__", math_ops.multiply),
("__div__", math_ops.div),
("__truediv__", math_ops.truediv),
("__floordiv__", math_ops.floordiv),
diff --git a/tensorflow/contrib/learn/python/learn/dataframe/transforms/unary_transforms.py b/tensorflow/contrib/learn/python/learn/dataframe/transforms/unary_transforms.py
index a0e444e054..28a8561b61 100644
--- a/tensorflow/contrib/learn/python/learn/dataframe/transforms/unary_transforms.py
+++ b/tensorflow/contrib/learn/python/learn/dataframe/transforms/unary_transforms.py
@@ -27,7 +27,7 @@ from tensorflow.python.ops import math_ops
# Each entry is a mapping from registered_name to operation. Each operation is
# wrapped in a transform and then registered as a member function
# `Series`.registered_name().
-UNARY_TRANSFORMS = [("__neg__", math_ops.neg),
+UNARY_TRANSFORMS = [("__neg__", math_ops.negative),
("sign", math_ops.sign),
("reciprocal", math_ops.reciprocal),
("square", math_ops.square),
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index c455340759..6ceeacbc72 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -2306,8 +2306,8 @@ def streaming_covariance(predictions,
else:
weights = _broadcast_weights(weights, labels)
batch_count = math_ops.reduce_sum(weights) # n_B in eqn
- weighted_predictions = math_ops.mul(predictions, weights)
- weighted_labels = math_ops.mul(labels, weights)
+ weighted_predictions = math_ops.multiply(predictions, weights)
+ weighted_labels = math_ops.multiply(labels, weights)
update_count = state_ops.assign_add(count, batch_count) # n_AB in eqn
prev_count = update_count - batch_count # n_A in update equation
diff --git a/tensorflow/python/kernel_tests/basic_gpu_test.py b/tensorflow/python/kernel_tests/basic_gpu_test.py
index a9bcb3fea9..0438d95bc4 100644
--- a/tensorflow/python/kernel_tests/basic_gpu_test.py
+++ b/tensorflow/python/kernel_tests/basic_gpu_test.py
@@ -52,8 +52,8 @@ class GPUBinaryOpsTest(test.TestCase):
x = np.linspace(-5, 20, 15).reshape(1, 3, 5).astype(np.float32)
y = np.linspace(20, -5, 15).reshape(1, 3, 5).astype(np.float32)
self._compareGPU(x, y, np.add, math_ops.add)
- self._compareGPU(x, y, np.subtract, math_ops.sub)
- self._compareGPU(x, y, np.multiply, math_ops.mul)
+ self._compareGPU(x, y, np.subtract, math_ops.subtract)
+ self._compareGPU(x, y, np.multiply, math_ops.multiply)
self._compareGPU(x, y + 0.1, np.true_divide, math_ops.truediv)
self._compareGPU(x, y + 0.1, np.floor_divide, math_ops.floordiv)
self._compareGPU(x, y, np.power, math_ops.pow)
@@ -62,24 +62,24 @@ class GPUBinaryOpsTest(test.TestCase):
x = np.linspace(-5, 20, 15).reshape(3, 5).astype(np.float32)
y = np.linspace(20, -5, 30).reshape(2, 3, 5).astype(np.float32)
self._compareGPU(x, y, np.add, math_ops.add)
- self._compareGPU(x, y, np.subtract, math_ops.sub)
- self._compareGPU(x, y, np.multiply, math_ops.mul)
+ self._compareGPU(x, y, np.subtract, math_ops.subtract)
+ self._compareGPU(x, y, np.multiply, math_ops.multiply)
self._compareGPU(x, y + 0.1, np.true_divide, math_ops.truediv)
def testDoubleBasic(self):
x = np.linspace(-5, 20, 15).reshape(1, 3, 5).astype(np.float64)
y = np.linspace(20, -5, 15).reshape(1, 3, 5).astype(np.float64)
self._compareGPU(x, y, np.add, math_ops.add)
- self._compareGPU(x, y, np.subtract, math_ops.sub)
- self._compareGPU(x, y, np.multiply, math_ops.mul)
+ self._compareGPU(x, y, np.subtract, math_ops.subtract)
+ self._compareGPU(x, y, np.multiply, math_ops.multiply)
self._compareGPU(x, y + 0.1, np.true_divide, math_ops.truediv)
def testDoubleWithBCast(self):
x = np.linspace(-5, 20, 15).reshape(3, 5).astype(np.float64)
y = np.linspace(20, -5, 30).reshape(2, 3, 5).astype(np.float64)
self._compareGPU(x, y, np.add, math_ops.add)
- self._compareGPU(x, y, np.subtract, math_ops.sub)
- self._compareGPU(x, y, np.multiply, math_ops.mul)
+ self._compareGPU(x, y, np.subtract, math_ops.subtract)
+ self._compareGPU(x, y, np.multiply, math_ops.multiply)
self._compareGPU(x, y + 0.1, np.true_divide, math_ops.truediv)
@@ -111,7 +111,7 @@ class MathBuiltinUnaryTest(test.TestCase):
self._compare(data, np.floor, math_ops.floor, use_gpu)
self._compare(data, np.log, math_ops.log, use_gpu)
self._compare(data, np.log1p, math_ops.log1p, use_gpu)
- self._compare(data, np.negative, math_ops.neg, use_gpu)
+ self._compare(data, np.negative, math_ops.negative, use_gpu)
self._compare(data, self._rsqrt, math_ops.rsqrt, use_gpu)
self._compare(data, np.sin, math_ops.sin, use_gpu)
self._compare(data, np.sqrt, math_ops.sqrt, use_gpu)
diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py
index a828273314..cfe2754b32 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_test.py
@@ -186,7 +186,7 @@ class UnaryOpTest(test.TestCase):
self._compareBoth(x, np.abs, math_ops.abs)
self._compareBoth(x, np.abs, _ABS)
- self._compareBoth(x, np.negative, math_ops.neg)
+ self._compareBoth(x, np.negative, math_ops.negative)
self._compareBoth(x, np.negative, _NEG)
self._compareBoth(y, self._inv, math_ops.reciprocal)
self._compareBoth(x, np.square, math_ops.square)
@@ -213,7 +213,7 @@ class UnaryOpTest(test.TestCase):
self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc)
self._compareBothSparse(x, np.abs, math_ops.abs)
- self._compareBothSparse(x, np.negative, math_ops.neg)
+ self._compareBothSparse(x, np.negative, math_ops.negative)
self._compareBothSparse(x, np.square, math_ops.square)
self._compareBothSparse(z, np.sqrt, math_ops.sqrt, tol=1e-3)
self._compareBothSparse(x, np.tanh, math_ops.tanh)
@@ -230,7 +230,7 @@ class UnaryOpTest(test.TestCase):
x = np.empty((2, 0, 5), dtype=np.float32)
self._compareBoth(x, np.abs, math_ops.abs)
self._compareBoth(x, np.abs, _ABS)
- self._compareBoth(x, np.negative, math_ops.neg)
+ self._compareBoth(x, np.negative, math_ops.negative)
self._compareBoth(x, np.negative, _NEG)
self._compareBoth(x, self._inv, math_ops.reciprocal)
self._compareBoth(x, np.square, math_ops.square)
@@ -255,7 +255,7 @@ class UnaryOpTest(test.TestCase):
self._compareBoth(x, np.arctan, math_ops.atan)
self._compareBothSparse(x, np.abs, math_ops.abs)
- self._compareBothSparse(x, np.negative, math_ops.neg)
+ self._compareBothSparse(x, np.negative, math_ops.negative)
self._compareBothSparse(x, np.square, math_ops.square)
self._compareBothSparse(x, np.sqrt, math_ops.sqrt, tol=1e-3)
self._compareBothSparse(x, np.tanh, math_ops.tanh)
@@ -270,7 +270,7 @@ class UnaryOpTest(test.TestCase):
np.float64) # between -1 and 1
self._compareBoth(x, np.abs, math_ops.abs)
self._compareBoth(x, np.abs, _ABS)
- self._compareBoth(x, np.negative, math_ops.neg)
+ self._compareBoth(x, np.negative, math_ops.negative)
self._compareBoth(x, np.negative, _NEG)
self._compareBoth(y, self._inv, math_ops.reciprocal)
self._compareBoth(x, np.square, math_ops.square)
@@ -297,7 +297,7 @@ class UnaryOpTest(test.TestCase):
self._compareBoth(k, np.tan, math_ops.tan)
self._compareBothSparse(x, np.abs, math_ops.abs)
- self._compareBothSparse(x, np.negative, math_ops.neg)
+ self._compareBothSparse(x, np.negative, math_ops.negative)
self._compareBothSparse(x, np.square, math_ops.square)
self._compareBothSparse(z, np.sqrt, math_ops.sqrt, tol=1e-3)
self._compareBothSparse(x, np.tanh, math_ops.tanh)
@@ -310,7 +310,7 @@ class UnaryOpTest(test.TestCase):
z = (x + 15.5).astype(np.float16) # all positive
self._compareBoth(x, np.abs, math_ops.abs)
self._compareBoth(x, np.abs, _ABS)
- self._compareBoth(x, np.negative, math_ops.neg)
+ self._compareBoth(x, np.negative, math_ops.negative)
self._compareBoth(x, np.negative, _NEG)
self._compareBoth(y, self._inv, math_ops.reciprocal)
self._compareBoth(x, np.square, math_ops.square)
@@ -333,7 +333,7 @@ class UnaryOpTest(test.TestCase):
self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc)
self._compareBothSparse(x, np.abs, math_ops.abs)
- self._compareBothSparse(x, np.negative, math_ops.neg)
+ self._compareBothSparse(x, np.negative, math_ops.negative)
self._compareBothSparse(x, np.square, math_ops.square)
self._compareBothSparse(z, np.sqrt, math_ops.sqrt, tol=1e-3)
self._compareBothSparse(x, np.tanh, math_ops.tanh)
@@ -344,13 +344,13 @@ class UnaryOpTest(test.TestCase):
x = np.arange(-6, 6, 2).reshape(1, 3, 2).astype(np.int32)
self._compareCpu(x, np.abs, math_ops.abs)
self._compareCpu(x, np.abs, _ABS)
- self._compareBoth(x, np.negative, math_ops.neg)
+ self._compareBoth(x, np.negative, math_ops.negative)
self._compareBoth(x, np.negative, _NEG)
self._compareBoth(x, np.square, math_ops.square)
self._compareCpu(x, np.sign, math_ops.sign)
self._compareBothSparse(x, np.abs, math_ops.abs)
- self._compareBothSparse(x, np.negative, math_ops.neg)
+ self._compareBothSparse(x, np.negative, math_ops.negative)
self._compareBothSparse(x, np.square, math_ops.square)
self._compareBothSparse(x, np.sign, math_ops.sign)
@@ -358,13 +358,13 @@ class UnaryOpTest(test.TestCase):
x = np.arange(-6 << 40, 6 << 40, 2 << 40).reshape(1, 3, 2).astype(np.int64)
self._compareCpu(x, np.abs, math_ops.abs)
self._compareCpu(x, np.abs, _ABS)
- self._compareCpu(x, np.negative, math_ops.neg)
+ self._compareCpu(x, np.negative, math_ops.negative)
self._compareCpu(x, np.negative, _NEG)
self._compareCpu(x, np.square, math_ops.square)
self._compareCpu(x, np.sign, math_ops.sign)
self._compareBothSparse(x, np.abs, math_ops.abs)
- self._compareBothSparse(x, np.negative, math_ops.neg)
+ self._compareBothSparse(x, np.negative, math_ops.negative)
self._compareBothSparse(x, np.square, math_ops.square)
self._compareBothSparse(x, np.sign, math_ops.sign)
@@ -374,7 +374,7 @@ class UnaryOpTest(test.TestCase):
y = x + 0.5 # no zeros
self._compareCpu(x, np.abs, math_ops.abs)
self._compareCpu(x, np.abs, _ABS)
- self._compareCpu(x, np.negative, math_ops.neg)
+ self._compareCpu(x, np.negative, math_ops.negative)
self._compareCpu(x, np.negative, _NEG)
self._compareCpu(y, self._inv, math_ops.reciprocal)
self._compareCpu(x, np.square, math_ops.square)
@@ -390,7 +390,7 @@ class UnaryOpTest(test.TestCase):
self._compareCpu(x, np.cos, math_ops.cos)
self._compareBothSparse(x, np.abs, math_ops.abs)
- self._compareBothSparse(x, np.negative, math_ops.neg)
+ self._compareBothSparse(x, np.negative, math_ops.negative)
self._compareBothSparse(x, np.square, math_ops.square)
self._compareBothSparse(x, np.sqrt, math_ops.sqrt, 1e-3)
self._compareBothSparse(x, np.tanh, math_ops.tanh)
@@ -408,7 +408,7 @@ class UnaryOpTest(test.TestCase):
y = x + 0.5 # no zeros
self._compareCpu(x, np.abs, math_ops.abs)
self._compareCpu(x, np.abs, _ABS)
- self._compareCpu(x, np.negative, math_ops.neg)
+ self._compareCpu(x, np.negative, math_ops.negative)
self._compareCpu(x, np.negative, _NEG)
self._compareCpu(y, self._inv, math_ops.reciprocal)
self._compareCpu(x, np.square, math_ops.square)
@@ -424,7 +424,7 @@ class UnaryOpTest(test.TestCase):
self._compareCpu(x, np.cos, math_ops.cos)
self._compareBothSparse(x, np.abs, math_ops.abs)
- self._compareBothSparse(x, np.negative, math_ops.neg)
+ self._compareBothSparse(x, np.negative, math_ops.negative)
self._compareBothSparse(x, np.square, math_ops.square)
self._compareBothSparse(x, np.sqrt, math_ops.sqrt, 1e-3)
self._compareBothSparse(x, np.tanh, math_ops.tanh)
@@ -600,8 +600,8 @@ class BinaryOpTest(test.TestCase):
x = np.linspace(-5, 20, 15).reshape(1, 3, 5).astype(np.float32)
y = np.linspace(20, -5, 15).reshape(1, 3, 5).astype(np.float32)
self._compareBoth(x, y, np.add, math_ops.add, also_compare_variables=True)
- self._compareBoth(x, y, np.subtract, math_ops.sub)
- self._compareBoth(x, y, np.multiply, math_ops.mul)
+ self._compareBoth(x, y, np.subtract, math_ops.subtract)
+ self._compareBoth(x, y, np.multiply, math_ops.multiply)
self._compareBoth(x, y + 0.1, np.true_divide, math_ops.truediv)
self._compareBoth(x, y + 0.1, np.floor_divide, math_ops.floordiv)
self._compareBoth(x, y, np.add, _ADD)
@@ -657,8 +657,8 @@ class BinaryOpTest(test.TestCase):
x = np.linspace(-5, 20, 15).reshape(1, 3, 5).astype(np.float64)
y = np.linspace(20, -5, 15).reshape(1, 3, 5).astype(np.float64)
self._compareBoth(x, y, np.add, math_ops.add)
- self._compareBoth(x, y, np.subtract, math_ops.sub)
- self._compareBoth(x, y, np.multiply, math_ops.mul)
+ self._compareBoth(x, y, np.subtract, math_ops.subtract)
+ self._compareBoth(x, y, np.multiply, math_ops.multiply)
self._compareBoth(x, y + 0.1, np.true_divide, math_ops.truediv)
self._compareBoth(x, y + 0.1, np.floor_divide, math_ops.floordiv)
self._compareBoth(x, y, np.add, _ADD)
@@ -680,19 +680,19 @@ class BinaryOpTest(test.TestCase):
def testInt8Basic(self):
x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.int8)
y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int8)
- self._compareBoth(x, y, np.multiply, math_ops.mul)
+ self._compareBoth(x, y, np.multiply, math_ops.multiply)
self._compareBoth(x, y, np.multiply, _MUL)
def testInt16Basic(self):
x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.int16)
y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int16)
- self._compareBoth(x, y, np.multiply, math_ops.mul)
+ self._compareBoth(x, y, np.multiply, math_ops.multiply)
self._compareBoth(x, y, np.multiply, _MUL)
def testUint16Basic(self):
x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.uint16)
y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.uint16)
- self._compareBoth(x, y, np.multiply, math_ops.mul)
+ self._compareBoth(x, y, np.multiply, math_ops.multiply)
self._compareBoth(x, y, np.multiply, _MUL)
self._compareBoth(x, y, np.true_divide, math_ops.truediv)
self._compareBoth(x, y, np.floor_divide, math_ops.floordiv)
@@ -703,8 +703,8 @@ class BinaryOpTest(test.TestCase):
x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.int32)
y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int32)
self._compareBoth(x, y, np.add, math_ops.add)
- self._compareBoth(x, y, np.subtract, math_ops.sub)
- self._compareBoth(x, y, np.multiply, math_ops.mul)
+ self._compareBoth(x, y, np.subtract, math_ops.subtract)
+ self._compareBoth(x, y, np.multiply, math_ops.multiply)
self._compareBoth(x, y, np.true_divide, math_ops.truediv)
self._compareBoth(x, y, np.floor_divide, math_ops.floordiv)
self._compareBoth(x, y, np.mod, math_ops.mod)
@@ -721,8 +721,8 @@ class BinaryOpTest(test.TestCase):
def testInt64Basic(self):
x = np.arange(1 << 40, 13 << 40, 2 << 40).reshape(1, 3, 2).astype(np.int64)
y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int64)
- self._compareBoth(x, y, np.subtract, math_ops.sub)
- self._compareBoth(x, y, np.multiply, math_ops.mul)
+ self._compareBoth(x, y, np.subtract, math_ops.subtract)
+ self._compareBoth(x, y, np.multiply, math_ops.multiply)
self._compareBoth(x, y, np.true_divide, math_ops.truediv)
self._compareBoth(x, y, np.floor_divide, math_ops.floordiv)
self._compareBoth(x, y, np.mod, math_ops.mod)
@@ -738,8 +738,8 @@ class BinaryOpTest(test.TestCase):
y = np.complex(1, 1) * np.linspace(20, -20, 6).reshape(
1, 3, 2).astype(np.complex64)
self._compareBoth(x, y, np.add, math_ops.add)
- self._compareBoth(x, y, np.subtract, math_ops.sub)
- self._compareBoth(x, y, np.multiply, math_ops.mul)
+ self._compareBoth(x, y, np.subtract, math_ops.subtract)
+ self._compareBoth(x, y, np.multiply, math_ops.multiply)
self._compareBoth(x, y + 0.1, np.true_divide, math_ops.truediv)
self._compareBoth(x, y, np.add, _ADD)
self._compareBoth(x, y, np.subtract, _SUB)
@@ -752,8 +752,8 @@ class BinaryOpTest(test.TestCase):
y = np.complex(1, 1) * np.linspace(20, -20, 6).reshape(
1, 3, 2).astype(np.complex128)
self._compareBoth(x, y, np.add, math_ops.add)
- self._compareBoth(x, y, np.subtract, math_ops.sub)
- self._compareBoth(x, y, np.multiply, math_ops.mul)
+ self._compareBoth(x, y, np.subtract, math_ops.subtract)
+ self._compareBoth(x, y, np.multiply, math_ops.multiply)
self._compareBoth(x, y + 0.1, np.true_divide, math_ops.truediv)
self._compareBoth(x, y, np.add, _ADD)
self._compareBoth(x, y, np.subtract, _SUB)
@@ -839,7 +839,7 @@ class BinaryOpTest(test.TestCase):
def _testBCastB(self, xs, ys):
funcs = [
- (np.subtract, math_ops.sub),
+ (np.subtract, math_ops.subtract),
(np.subtract, _SUB),
(np.power, math_ops.pow),
]
@@ -847,7 +847,7 @@ class BinaryOpTest(test.TestCase):
def _testBCastC(self, xs, ys):
funcs = [
- (np.multiply, math_ops.mul),
+ (np.multiply, math_ops.multiply),
(np.multiply, _MUL),
]
self._testBCastByFunc(funcs, xs, ys)
@@ -1055,8 +1055,8 @@ class BinaryOpTest(test.TestCase):
def testMismatchedDimensions(self):
for func in [
- math_ops.add, math_ops.sub, math_ops.mul, math_ops.div, _ADD, _SUB,
- _MUL, _TRUEDIV, _FLOORDIV
+ math_ops.add, math_ops.subtract, math_ops.multiply, math_ops.div,
+ _ADD, _SUB, _MUL, _TRUEDIV, _FLOORDIV
]:
with self.assertRaisesWithPredicateMatch(
ValueError, lambda e: "Dimensions must" in str(e)):
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index b008742a3f..071f970c58 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -359,9 +359,9 @@ multiply.__doc__ = gen_math_ops._mul.__doc__.replace("Mul", "`tf.multiply`")
# TODO(aselle): put deprecation in after another round of global code changes
-# @deprecated(
-# "2016-12-30",
-# "`tf.mul(x, y)` is deprecated, please use `tf.negative(x, y)` or `x * y`")
+@deprecated(
+ "2016-12-30",
+ "`tf.mul(x, y)` is deprecated, please use `tf.multiply(x, y)` or `x * y`")
def mul(x, y, name=None):
return gen_math_ops._mul(x, y, name)
mul.__doc__ = (gen_math_ops._mul.__doc__
@@ -374,9 +374,9 @@ subtract.__doc__ = gen_math_ops._sub.__doc__.replace("`Sub`", "`tf.subtract`")
# TODO(aselle): put deprecation in after another round of global code changes
-# @deprecated(
-# "2016-12-30",
-# "`tf.mul(x, y)` is deprecated, please use `tf.negative(x, y)` or `x * y`")
+@deprecated(
+ "2016-12-30",
+ "`tf.sub(x, y)` is deprecated, please use `tf.subtract(x, y)` or `x - y`")
def sub(x, y, name=None):
return gen_math_ops._sub(x, y, name)
sub.__doc__ = (gen_math_ops._sub.__doc__
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py
index 3de4e1c12c..2b7f3af54a 100644
--- a/tensorflow/python/ops/metrics_impl.py
+++ b/tensorflow/python/ops/metrics_impl.py
@@ -2381,7 +2381,7 @@ def _sparse_false_positive_at_k(labels,
if weights is not None:
with ops.control_dependencies((_assert_weights_rank(weights, fp),)):
weights = math_ops.to_double(weights)
- fp = math_ops.mul(fp, weights)
+ fp = math_ops.multiply(fp, weights)
return fp