diff options
author | 2017-01-04 10:03:37 -0800 | |
---|---|---|
committer | 2017-01-04 10:25:41 -0800 | |
commit | 4982c62ae45a5e1549f2c202fb4834065ab9a1ac (patch) | |
tree | 8e08a41270f34acca7feeab71d37a18babf60f59 /tensorflow | |
parent | bd970230fb9a7547ca53d115f88a4ae652f888c6 (diff) |
Add deprecation warnings to tf.neg and prepare for deprecation warnings of
tf.mul, tf.sub per go/tf-numpy-parity-plan
- Created wrappers for mul, multiply, sub and subtract. Made wrapper for
negative
- Propagate docstrings from ops in wrappers
- Add Mul and Neg to hidden_ops.txt
- Change gen_math_ops.sub to gen_math_ops._sub
- Change gen_math_ops.mul to gen_math_ops._mul
Change: 143565494
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/g3doc/get_started/basic_usage.md | 6 | ||||
-rw-r--r-- | tensorflow/python/ops/hidden_ops.txt | 2 | ||||
-rw-r--r-- | tensorflow/python/ops/math_ops.py | 64 |
3 files changed, 62 insertions, 10 deletions
diff --git a/tensorflow/g3doc/get_started/basic_usage.md b/tensorflow/g3doc/get_started/basic_usage.md index d1c008940a..c39f365ed8 100644 --- a/tensorflow/g3doc/get_started/basic_usage.md +++ b/tensorflow/g3doc/get_started/basic_usage.md @@ -201,7 +201,7 @@ a = tf.constant([3.0, 3.0]) x.initializer.run() # Add an op to subtract 'a' from 'x'. Run it and print the result -sub = tf.sub(x, a) +sub = tf.subtract(x, a) print(sub.eval()) # ==> [-2. -1.] @@ -278,7 +278,7 @@ input1 = tf.constant([3.0]) input2 = tf.constant([2.0]) input3 = tf.constant([5.0]) intermed = tf.add(input2, input3) -mul = tf.mul(input1, intermed) +mul = tf.multiply(input1, intermed) with tf.Session() as sess: result = sess.run([mul, intermed]) @@ -307,7 +307,7 @@ tf.placeholder() to create them: input1 = tf.placeholder(tf.float32) input2 = tf.placeholder(tf.float32) -output = tf.mul(input1, input2) +output = input1 * input2 with tf.Session() as sess: print(sess.run([output], feed_dict={input1:[7.], input2:[2.]})) diff --git a/tensorflow/python/ops/hidden_ops.txt b/tensorflow/python/ops/hidden_ops.txt index 54a4a62567..72ba4031ae 100644 --- a/tensorflow/python/ops/hidden_ops.txt +++ b/tensorflow/python/ops/hidden_ops.txt @@ -208,12 +208,14 @@ FloorMod Max Mean Min +Mul Pow Prod Range RealDiv Select SparseMatMul +Sub Sum MatMul Sigmoid diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 84cf27e1f5..5ea0258ff5 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -246,6 +246,7 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops.gen_math_ops import * # pylint: enable=wildcard-import from tensorflow.python.util import compat +from tensorflow.python.util.deprecation import deprecated # Aliases for some automatically-generated names. @@ -282,6 +283,7 @@ argmin.__doc__ = (gen_math_ops.arg_min.__doc__.replace( # pylint: disable=anomalous-backslash-in-string,protected-access +# pylint: disable=g-docstring-has-escape def abs(x, name=None): """Computes the absolute value of a tensor. @@ -314,6 +316,7 @@ def abs(x, name=None): if x.dtype in (dtypes.complex64, dtypes.complex128): return gen_math_ops._complex_abs(x, Tout=x.dtype.real_dtype, name=name) return gen_math_ops._abs(x, name=name) +# pylint: enable=g-docstring-has-escape def divide(x, y, name=None): @@ -322,13 +325,38 @@ def divide(x, y, name=None): return x / y -# Make Python Aliases -multiply = gen_math_ops.mul -subtract = gen_math_ops.sub -negative = gen_math_ops.neg +def multiply(x, y, name=None): + return gen_math_ops._mul(x, y, name) +multiply.__doc__ = gen_math_ops._mul.__doc__.replace("Mul", "`tf.multiply`") -def neg(x, name=None): +# TODO(aselle): put deprecation in after another round of global code changes +# @deprecated( +# "2016-12-30", +# "`tf.mul(x, y)` is deprecated, please use `tf.negative(x, y)` or `x * y`") +def mul(x, y, name=None): + return gen_math_ops._mul(x, y, name) +mul.__doc__ = (gen_math_ops._mul.__doc__ + + ("" if mul.__doc__ is None else mul.__doc__)) + + +def subtract(x, y, name=None): + return gen_math_ops._sub(x, y, name) +subtract.__doc__ = gen_math_ops._sub.__doc__.replace("`Sub`", "`tf.subtract`") + + +# TODO(aselle): put deprecation in after another round of global code changes +# @deprecated( +# "2016-12-30", +# "`tf.mul(x, y)` is deprecated, please use `tf.negative(x, y)` or `x * y`") +def sub(x, y, name=None): + return gen_math_ops._sub(x, y, name) +sub.__doc__ = (gen_math_ops._sub.__doc__ + + ("" if sub.__doc__ is None else sub.__doc__)) + + +# pylint: disable=g-docstring-has-escape +def negative(x, name=None): """Computes numerical negative value element-wise. I.e., \\(y = -x\\). @@ -348,6 +376,28 @@ def neg(x, name=None): indices=x.indices, values=x_neg, dense_shape=x.dense_shape) else: return gen_math_ops.neg(x, name=name) +# pylint: enable=g-docstring-has-escape + + +# pylint: disable=g-docstring-has-escape +@deprecated( + "2016-12-30", + "`tf.neg(x)` is deprecated, please use `tf.negative(x)` or `-x`") +def neg(x, name=None): + """Computes numerical negative value element-wise. + + I.e., \\(y = -x\\). + + Args: + x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`, + `float32`, `float64`, `int32`, `int64`, `complex64`, `complex128`. + name: A name for the operation (optional). + + Returns: + A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`. + """ + return negative(x, name) +# pylint: enable=g-docstring-has-escape def sign(x, name=None): @@ -1024,7 +1074,7 @@ def _mul_dispatch(x, y, name=None): """Dispatches cwise mul for "Dense*Dense" and "Dense*Sparse".""" is_tensor_y = isinstance(y, ops.Tensor) if is_tensor_y: - return gen_math_ops.mul(x, y, name=name) + return gen_math_ops._mul(x, y, name=name) else: assert isinstance(y, sparse_tensor.SparseTensor) # Case: Dense * Sparse. new_vals = gen_sparse_ops.sparse_dense_cwise_mul(y.indices, y.values, @@ -1043,7 +1093,7 @@ _OverrideBinaryOperatorHelper(gen_sparse_ops.sparse_dense_cwise_mul, "mul", sparse_tensor.SparseTensor) _OverrideBinaryOperatorHelper(gen_math_ops.add, "add") -_OverrideBinaryOperatorHelper(gen_math_ops.sub, "sub") +_OverrideBinaryOperatorHelper(gen_math_ops._sub, "sub") _OverrideBinaryOperatorHelper(_mul_dispatch, "mul") _OverrideBinaryOperatorHelper(_div_python2, "div") _OverrideBinaryOperatorHelper(_truediv_python3, "truediv") |