aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Martin Wicke <wicke@google.com>2016-11-15 13:35:35 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-15 13:45:09 -0800
commitfb01ebb8c38b2d274f6fe9a7115b2362828a452e (patch)
tree81b598b8c41108b36c9f08331d1adf2415af0051 /tensorflow
parent7632193992bb77e08cfe93f752ccfd7a27cb2618 (diff)
Deprecate tf.inv in favor of tf.reciprocal.
Change: 139240711
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/cc/gradients/math_grad.cc16
-rw-r--r--tensorflow/cc/gradients/math_grad_test.cc4
-rw-r--r--tensorflow/contrib/factorization/python/ops/clustering_ops.py3
-rw-r--r--tensorflow/contrib/labeled_tensor/__init__.py2
-rw-r--r--tensorflow/contrib/labeled_tensor/python/ops/core.py2
-rw-r--r--tensorflow/contrib/labeled_tensor/python/ops/core_test.py2
-rw-r--r--tensorflow/contrib/learn/python/learn/dataframe/transforms/unary_transforms.py2
-rw-r--r--tensorflow/core/ops/math_grad.cc17
-rw-r--r--tensorflow/core/ops/math_grad_test.cc4
-rw-r--r--tensorflow/core/ops/math_ops.cc14
-rw-r--r--tensorflow/core/public/version.h3
-rw-r--r--tensorflow/python/kernel_tests/cwise_ops_test.py14
-rw-r--r--tensorflow/python/ops/hidden_ops.txt1
-rw-r--r--tensorflow/python/ops/linalg_grad.py2
-rw-r--r--tensorflow/python/ops/math_grad.py37
-rw-r--r--tensorflow/python/ops/math_ops.py2
-rw-r--r--tensorflow/python/ops/nn.py4
17 files changed, 80 insertions, 49 deletions
diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc
index d841899bbd..11c6207599 100644
--- a/tensorflow/cc/gradients/math_grad.cc
+++ b/tensorflow/cc/gradients/math_grad.cc
@@ -50,6 +50,7 @@ Status InvGrad(const Scope& scope, const Operation& op,
return scope.status();
}
REGISTER_GRADIENT_OP("Inv", InvGrad);
+REGISTER_GRADIENT_OP("Reciprocal", InvGrad);
Status SquareGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
@@ -68,7 +69,7 @@ Status SqrtGrad(const Scope& scope, const Operation& op,
// y = sqrt(x)
// dy/dx = 0.5 * (1 / sqrt(x)) = 0.5 * (1 / y)
// dx = dy * (0.5 * (1 / y))
- auto y_inv = Inv(scope, op.output(0));
+ auto y_inv = Reciprocal(scope, op.output(0));
auto half = Cast(scope, Const(scope, 0.5), op.input(0).type());
auto dx = Mul(scope, grad_inputs[0], Mul(scope, half, y_inv));
grad_outputs->push_back(dx);
@@ -82,7 +83,7 @@ Status RsqrtGrad(const Scope& scope, const Operation& op,
// y = 1/x^1/2 = x^-1/2
// dy/dx = -1/2 * x^-3/2 = -1/2 * x^-1/2 * x^-1 = -1/2 * y * x^-1
// dx = dy * (-1/2 * y * x^-1)
- auto x_inv = Inv(scope, op.input(0));
+ auto x_inv = Reciprocal(scope, op.input(0));
auto y = op.output(0);
auto neghalf = Cast(scope, Const(scope, -0.5), op.input(0).type());
auto a = Mul(scope, neghalf, x_inv);
@@ -110,7 +111,8 @@ Status LogGrad(const Scope& scope, const Operation& op,
// f(x) = log(x) = y
// df/dx = 1 / x
// dx = dy * (1 / x)
- grad_outputs->push_back(Mul(scope, grad_inputs[0], Inv(scope, op.input(0))));
+ grad_outputs->push_back(
+ Mul(scope, grad_inputs[0], Reciprocal(scope, op.input(0))));
return scope.status();
}
REGISTER_GRADIENT_OP("Log", LogGrad);
@@ -186,7 +188,7 @@ Status AsinGrad(const Scope& scope, const Operation& op,
// dx = dy * (1 / (1 - x * x)^1/2)
auto x2 = Square(scope, op.input(0));
auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
- auto dydx = Inv(scope, Sqrt(scope, Sub(scope, one, x2)));
+ auto dydx = Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2)));
auto dx = Mul(scope, grad_inputs[0], dydx);
grad_outputs->push_back(dx);
return scope.status();
@@ -201,7 +203,7 @@ Status AcosGrad(const Scope& scope, const Operation& op,
// dx = dy * (- 1 / (1 - x * x)^1/2)
auto x2 = Square(scope, op.input(0));
auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
- auto dydx = Neg(scope, Inv(scope, Sqrt(scope, Sub(scope, one, x2))));
+ auto dydx = Neg(scope, Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2))));
auto dx = Mul(scope, grad_inputs[0], dydx);
grad_outputs->push_back(dx);
return scope.status();
@@ -214,7 +216,7 @@ Status TanGrad(const Scope& scope, const Operation& op,
// y = tan(x)
// dy/dx = sec(x)^2 = 1 / cos(x)^2
// dx = dy * (1 / cos(x)^2)
- auto dydx = Square(scope, Inv(scope, Cos(scope, op.input(0))));
+ auto dydx = Square(scope, Reciprocal(scope, Cos(scope, op.input(0))));
auto dx = Mul(scope, grad_inputs[0], dydx);
grad_outputs->push_back(dx);
return scope.status();
@@ -228,7 +230,7 @@ Status AtanGrad(const Scope& scope, const Operation& op,
// dy/dx = 1 / (1 + x^2)
// dx = dy * (1 / (1 + x^2)
auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
- auto dydx = Inv(scope, Add(scope, one, Square(scope, op.input(0))));
+ auto dydx = Reciprocal(scope, Add(scope, one, Square(scope, op.input(0))));
auto dx = Mul(scope, grad_inputs[0], dydx);
grad_outputs->push_back(dx);
return scope.status();
diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc
index 456a036111..8b7fb8d765 100644
--- a/tensorflow/cc/gradients/math_grad_test.cc
+++ b/tensorflow/cc/gradients/math_grad_test.cc
@@ -84,7 +84,7 @@ class CWiseUnaryGradTest : public ::testing::Test {
y = Neg(scope_, x);
break;
case INV:
- y = Inv(scope_, x);
+ y = Reciprocal(scope_, x);
break;
case SQUARE:
y = Square(scope_, x);
@@ -157,7 +157,7 @@ TEST_F(CWiseUnaryGradTest, Neg) {
TestCWiseGrad(NEG, x_fn, dy_fn, dx_fn);
}
-TEST_F(CWiseUnaryGradTest, Inv) {
+TEST_F(CWiseUnaryGradTest, Reciprocal) {
auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); };
auto dy_fn = [this](const float x) { return RV({0, -2, 2, -3, 3, -4, 4}); };
auto dx_fn = [this](const float x, const float dy) {
diff --git a/tensorflow/contrib/factorization/python/ops/clustering_ops.py b/tensorflow/contrib/factorization/python/ops/clustering_ops.py
index b0204c7ccf..79f2f92ebb 100644
--- a/tensorflow/contrib/factorization/python/ops/clustering_ops.py
+++ b/tensorflow/contrib/factorization/python/ops/clustering_ops.py
@@ -358,7 +358,8 @@ class KMeans(object):
cluster_center_updates -= tf.cast(
tf.reshape(count_updates, broadcast_shape),
inp.dtype) * old_cluster_centers
- learning_rate = tf.inv(tf.cast(old_counts + count_updates, inp.dtype))
+ learning_rate = tf.reciprocal(tf.cast(old_counts + count_updates,
+ inp.dtype))
learning_rate = tf.reshape(learning_rate, broadcast_shape)
# scale by 1 / (n + k), see comment above.
cluster_center_updates *= learning_rate
diff --git a/tensorflow/contrib/labeled_tensor/__init__.py b/tensorflow/contrib/labeled_tensor/__init__.py
index 75299a3a0e..71bb7b95f5 100644
--- a/tensorflow/contrib/labeled_tensor/__init__.py
+++ b/tensorflow/contrib/labeled_tensor/__init__.py
@@ -53,7 +53,7 @@ define_reduce_op = _ops.define_reduce_op
abs = _core.abs_function # pylint: disable=redefined-builtin
neg = _core.neg
sign = _core.sign
-inv = _core.inv
+reciprocal = _core.reciprocal
square = _core.square
round = _core.round_function # pylint: disable=redefined-builtin
sqrt = _core.sqrt
diff --git a/tensorflow/contrib/labeled_tensor/python/ops/core.py b/tensorflow/contrib/labeled_tensor/python/ops/core.py
index 69fd06133b..870dbdd383 100644
--- a/tensorflow/contrib/labeled_tensor/python/ops/core.py
+++ b/tensorflow/contrib/labeled_tensor/python/ops/core.py
@@ -1098,7 +1098,7 @@ def define_unary_op(op_name, elementwise_function):
abs_function = define_unary_op('abs', math_ops.abs)
neg = define_unary_op('neg', math_ops.neg)
sign = define_unary_op('sign', math_ops.sign)
-inv = define_unary_op('inv', math_ops.inv)
+reciprocal = define_unary_op('reciprocal', math_ops.reciprocal)
square = define_unary_op('square', math_ops.square)
round_function = define_unary_op('round', math_ops.round)
sqrt = define_unary_op('sqrt', math_ops.sqrt)
diff --git a/tensorflow/contrib/labeled_tensor/python/ops/core_test.py b/tensorflow/contrib/labeled_tensor/python/ops/core_test.py
index 5710dc34e8..f01955d507 100644
--- a/tensorflow/contrib/labeled_tensor/python/ops/core_test.py
+++ b/tensorflow/contrib/labeled_tensor/python/ops/core_test.py
@@ -687,7 +687,7 @@ class CoreUnaryOpsTest(Base, DocStringCheckMixin, UnaryOpsTestsMixin):
# TODO(shoyer): add unary + to core TensorFlow
('pos', None, None, None),
('sign', None, tf.sign, core.sign),
- ('inv', None, tf.inv, core.inv),
+ ('reciprocal', None, tf.reciprocal, core.reciprocal),
('square', None, tf.square, core.square),
('round', None, tf.round, core.round_function),
('sqrt', None, tf.sqrt, core.sqrt),
diff --git a/tensorflow/contrib/learn/python/learn/dataframe/transforms/unary_transforms.py b/tensorflow/contrib/learn/python/learn/dataframe/transforms/unary_transforms.py
index 8734b4669d..fac25bf2b3 100644
--- a/tensorflow/contrib/learn/python/learn/dataframe/transforms/unary_transforms.py
+++ b/tensorflow/contrib/learn/python/learn/dataframe/transforms/unary_transforms.py
@@ -29,7 +29,7 @@ from tensorflow.python.ops import math_ops
# `Series`.registered_name().
UNARY_TRANSFORMS = [("__neg__", math_ops.neg),
("sign", math_ops.sign),
- ("inv", math_ops.inv),
+ ("reciprocal", math_ops.reciprocal),
("square", math_ops.square),
("round", math_ops.round),
("sqrt", math_ops.sqrt),
diff --git a/tensorflow/core/ops/math_grad.cc b/tensorflow/core/ops/math_grad.cc
index 949d8d302d..f22d6493e2 100644
--- a/tensorflow/core/ops/math_grad.cc
+++ b/tensorflow/core/ops/math_grad.cc
@@ -62,7 +62,7 @@ REGISTER_OP_GRADIENT("Neg", NegGrad);
Status InvGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
return GradForUnaryCwise(g, {
- {{"y"}, "Inv", {"x"}},
+ {{"y"}, "Reciprocal", {"x"}},
{{"y2"}, "Square", {"y"}, {}, {"dy"}},
{{"y2_neg"}, "Neg", {"y2"}},
{{"dx"}, "Mul", {"dy", "y2_neg"}}
@@ -70,6 +70,7 @@ Status InvGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format on
}
REGISTER_OP_GRADIENT("Inv", InvGrad);
+REGISTER_OP_GRADIENT("Reciprocal", InvGrad);
Status SquareGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
@@ -87,7 +88,7 @@ Status SqrtGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
return GradForUnaryCwise(g, {
{{"y"}, "Sqrt", {"x"}},
- {{"y_inv"}, "Inv", {"y"}, {}, {"dy"}},
+ {{"y_inv"}, "Reciprocal", {"y"}, {}, {"dy"}},
FDH::Const("const", 0.5f),
{{"half"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
{{"a"}, "Mul", {"half", "y_inv"}}, // .5 * 1/y
@@ -100,7 +101,7 @@ REGISTER_OP_GRADIENT("Sqrt", SqrtGrad);
Status RsqrtGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
return GradForUnaryCwise(g, {
- {{"x_inv"}, "Inv", {"x"}, {}, {"dy"}},
+ {{"x_inv"}, "Reciprocal", {"x"}, {}, {"dy"}},
{{"y"}, "Rsqrt", {"x"}},
FDH::Const("const", -.5f),
{{"neghalf"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
@@ -125,7 +126,7 @@ REGISTER_OP_GRADIENT("Exp", ExpGrad);
Status LogGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
return GradForUnaryCwise(g, {
- {{"x_inv"}, "Inv", {"x"}, {}, {"dy"}},
+ {{"x_inv"}, "Reciprocal", {"x"}, {}, {"dy"}},
{{"dx"}, "Mul", {"dy", "x_inv"}}, // dy * 1/x
});
// clang-format on
@@ -201,7 +202,7 @@ Status AcosGrad(const AttrSlice& attrs, FunctionDef* g) {
{{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
{{"a"}, "Sub", {"one", "x2"}}, // 1 - x^2
{{"b"}, "Sqrt", {"a"}},
- {{"inv"}, "Inv", {"b"}},
+ {{"inv"}, "Reciprocal", {"b"}},
{{"neg"}, "Neg", {"inv"}},
{{"dx"}, "Mul", {"dy", "neg"}}
});
@@ -217,7 +218,7 @@ Status AsinGrad(const AttrSlice& attrs, FunctionDef* g) {
{{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
{{"a"}, "Sub", {"one", "x2"}}, // 1 - x^2
{{"b"}, "Sqrt", {"a"}},
- {{"inv"}, "Inv", {"b"}},
+ {{"inv"}, "Reciprocal", {"b"}},
{{"dx"}, "Mul", {"dy", "inv"}}
});
// clang-format on
@@ -231,7 +232,7 @@ Status AtanGrad(const AttrSlice& attrs, FunctionDef* g) {
FDH::Const("const", 1.0f),
{{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
{{"a"}, "Add", {"one", "x2"}}, // 1 + x^2
- {{"inv"}, "Inv", {"a"}},
+ {{"inv"}, "Reciprocal", {"a"}},
{{"dx"}, "Mul", {"dy", "inv"}}
});
// clang-format on
@@ -242,7 +243,7 @@ Status TanGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
return GradForUnaryCwise(g, {
{{"cosx"}, "Cos", {"x"}},
- {{"secx"}, "Inv", {"cosx"}},
+ {{"secx"}, "Reciprocal", {"cosx"}},
{{"secx2"}, "Square", {"secx"}},
{{"dx"}, "Mul", {"dy", "secx2"}}
});
diff --git a/tensorflow/core/ops/math_grad_test.cc b/tensorflow/core/ops/math_grad_test.cc
index e937fc5ab1..e76e0cf6b3 100644
--- a/tensorflow/core/ops/math_grad_test.cc
+++ b/tensorflow/core/ops/math_grad_test.cc
@@ -417,13 +417,13 @@ TEST_F(MathGradTest, Neg) {
test::ExpectClose(ans, dx);
}
-TEST_F(MathGradTest, Inv) {
+TEST_F(MathGradTest, Reciprocal) {
auto x = test::AsTensor<float>({-3.f, -2.f, -1.f, 1.f, 2.f, 3.f},
TensorShape({2, 3}));
auto g = [](float x) { return -1.f / (x * x); };
auto dx = test::AsTensor<float>(
{g(-3.f), g(-2.f), g(-1.f), g(1.f), g(2.f), g(3.f)}, TensorShape({2, 3}));
- auto ans = SymGrad("Inv", x);
+ auto ans = SymGrad("Reciprocal", x);
test::ExpectClose(ans, dx);
}
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index ec5f5e095f..f0749667aa 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -217,17 +217,23 @@ Computes numerical negative value element-wise.
I.e., \\(y = -x\\).
)doc");
-REGISTER_OP("Inv").UNARY().Doc(R"doc(
+REGISTER_OP("Inv")
+ .UNARY()
+ .Doc(R"doc(
Computes the reciprocal of x element-wise.
I.e., \\(y = 1 / x\\).
-)doc");
+)doc")
+ .Deprecated(17, "Use Reciprocal");
-REGISTER_OP("InvGrad").UNARY_GRADIENT_COMPLEX().Doc(R"doc(
+REGISTER_OP("InvGrad")
+ .UNARY_GRADIENT_COMPLEX()
+ .Doc(R"doc(
Computes the gradient for the inverse of `x` wrt its input.
Specifically, `grad = -dy * y*y`, where `y = 1/x`, and `dy`
is the corresponding input gradient.
-)doc");
+)doc")
+ .Deprecated(17, "Use ReciprocalGrad");
REGISTER_OP("Reciprocal").UNARY().Doc(R"doc(
Computes the reciprocal of x element-wise.
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index 34b5c0ba16..66af1897be 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -69,7 +69,8 @@ limitations under the License.
// 13. Deprecate multiple batch linear algebra ops (9sep2016).
// 14. Deprecate batch_matrix_* ops. (10sep2016).
// 15. Deprecate batch_fft_* ops. (14sep2016).
-// 16. Deprecate tensor_array (v1) ops in favor of v2 (10may2017).
+// 16. Deprecate tensor_array (v1) ops in favor of v2 (10nov2016).
+// 17. Deprecate inv (11nov2016).
// 17. Expose reverse_v2 (10nov2016)
#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py
index cb2489a662..15fd8ef805 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_test.py
@@ -188,7 +188,7 @@ class UnaryOpTest(tf.test.TestCase):
self._compareBoth(x, np.abs, _ABS)
self._compareBoth(x, np.negative, tf.neg)
self._compareBoth(x, np.negative, _NEG)
- self._compareBoth(y, self._inv, tf.inv)
+ self._compareBoth(y, self._inv, tf.reciprocal)
self._compareBoth(x, np.square, tf.square)
self._compareBoth(z, np.sqrt, tf.sqrt)
self._compareBoth(z, self._rsqrt, tf.rsqrt)
@@ -231,7 +231,7 @@ class UnaryOpTest(tf.test.TestCase):
self._compareBoth(x, np.abs, _ABS)
self._compareBoth(x, np.negative, tf.neg)
self._compareBoth(x, np.negative, _NEG)
- self._compareBoth(x, self._inv, tf.inv)
+ self._compareBoth(x, self._inv, tf.reciprocal)
self._compareBoth(x, np.square, tf.square)
self._compareBoth(x, np.sqrt, tf.sqrt)
self._compareBoth(x, self._rsqrt, tf.rsqrt)
@@ -269,7 +269,7 @@ class UnaryOpTest(tf.test.TestCase):
self._compareBoth(x, np.abs, _ABS)
self._compareBoth(x, np.negative, tf.neg)
self._compareBoth(x, np.negative, _NEG)
- self._compareBoth(y, self._inv, tf.inv)
+ self._compareBoth(y, self._inv, tf.reciprocal)
self._compareBoth(x, np.square, tf.square)
self._compareBoth(z, np.sqrt, tf.sqrt)
self._compareBoth(z, self._rsqrt, tf.rsqrt)
@@ -308,7 +308,7 @@ class UnaryOpTest(tf.test.TestCase):
self._compareBoth(x, np.abs, _ABS)
self._compareBoth(x, np.negative, tf.neg)
self._compareBoth(x, np.negative, _NEG)
- self._compareBoth(y, self._inv, tf.inv)
+ self._compareBoth(y, self._inv, tf.reciprocal)
self._compareBoth(x, np.square, tf.square)
self._compareBoth(z, np.sqrt, tf.sqrt)
self._compareBoth(z, self._rsqrt, tf.rsqrt)
@@ -372,7 +372,7 @@ class UnaryOpTest(tf.test.TestCase):
self._compareCpu(x, np.abs, _ABS)
self._compareCpu(x, np.negative, tf.neg)
self._compareCpu(x, np.negative, _NEG)
- self._compareCpu(y, self._inv, tf.inv)
+ self._compareCpu(y, self._inv, tf.reciprocal)
self._compareCpu(x, np.square, tf.square)
self._compareCpu(y, np.sqrt, tf.sqrt)
self._compareCpu(y, self._rsqrt, tf.rsqrt)
@@ -404,7 +404,7 @@ class UnaryOpTest(tf.test.TestCase):
self._compareCpu(x, np.abs, _ABS)
self._compareCpu(x, np.negative, tf.neg)
self._compareCpu(x, np.negative, _NEG)
- self._compareCpu(y, self._inv, tf.inv)
+ self._compareCpu(y, self._inv, tf.reciprocal)
self._compareCpu(x, np.square, tf.square)
self._compareCpu(y, np.sqrt, tf.sqrt)
self._compareCpu(y, self._rsqrt, tf.rsqrt)
@@ -433,7 +433,7 @@ class UnaryOpTest(tf.test.TestCase):
shape = (5,)
dtype_tols = [(np.float32, 5e-4), (np.float64, 1e-6), (np.complex64, 5e-4),
(np.complex128, 1e-6)]
- op_range = [(gen_math_ops._inv_grad, [-2, 2]),
+ op_range = [(gen_math_ops._reciprocal_grad, [-2, 2]),
(gen_math_ops._rsqrt_grad, [0.1, 3]),
(gen_math_ops._sigmoid_grad, [-2, 2]),
(gen_math_ops._sqrt_grad, [0.1, 3]),
diff --git a/tensorflow/python/ops/hidden_ops.txt b/tensorflow/python/ops/hidden_ops.txt
index 17a07157d9..b1b47dadc8 100644
--- a/tensorflow/python/ops/hidden_ops.txt
+++ b/tensorflow/python/ops/hidden_ops.txt
@@ -198,6 +198,7 @@ Tanh
SigmoidGrad
TanhGrad
InvGrad
+ReciprocalGrad
SqrtGrad
RsqrtGrad
diff --git a/tensorflow/python/ops/linalg_grad.py b/tensorflow/python/ops/linalg_grad.py
index c757f5999e..7510698665 100644
--- a/tensorflow/python/ops/linalg_grad.py
+++ b/tensorflow/python/ops/linalg_grad.py
@@ -198,7 +198,7 @@ def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
# degenerate eigenvalues, the corresponding eigenvectors are only defined
# up to arbitrary rotation in a (k-dimensional) subspace.
f = array_ops.matrix_set_diag(
- math_ops.inv(
+ math_ops.reciprocal(
array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
array_ops.zeros_like(e))
grad_a = math_ops.batch_matmul(
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 91e70c0a8b..8d999f0074 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -249,7 +249,15 @@ def _InvGrad(op, grad):
"""Returns -grad * (1 / x^2)."""
y = op.outputs[0] # y = 1 / x
# pylint: disable=protected-access
- return gen_math_ops._inv_grad(y, grad)
+ return gen_math_ops._reciprocal_grad(y, grad)
+
+
+@ops.RegisterGradient("Reciprocal")
+def _ReciprocalGrad(op, grad):
+ """Returns -grad * (1 / x^2)."""
+ y = op.outputs[0] # y = 1 / x
+ # pylint: disable=protected-access
+ return gen_math_ops._reciprocal_grad(y, grad)
@ops.RegisterGradient("InvGrad")
@@ -260,7 +268,18 @@ def _InvGradGrad(op, grad):
ca = math_ops.conj(op.inputs[0])
cg = math_ops.conj(grad)
# pylint: disable=protected-access
- return cg * -2.0 * b * ca, gen_math_ops._inv_grad(ca, grad)
+ return cg * -2.0 * b * ca, gen_math_ops._reciprocal_grad(ca, grad)
+
+
+@ops.RegisterGradient("ReciprocalGrad")
+def _ReciprocalGradGrad(op, grad):
+ b = op.inputs[1]
+ # op.output[0]: y = -b * conj(a)^2
+ with ops.control_dependencies([grad.op]):
+ ca = math_ops.conj(op.inputs[0])
+ cg = math_ops.conj(grad)
+ # pylint: disable=protected-access
+ return cg * -2.0 * b * ca, gen_math_ops._reciprocal_grad(ca, grad)
@ops.RegisterGradient("Square")
@@ -323,7 +342,7 @@ def _LogGrad(op, grad):
x = op.inputs[0]
with ops.control_dependencies([grad.op]):
x = math_ops.conj(x)
- return grad * math_ops.inv(x)
+ return grad * math_ops.reciprocal(x)
@ops.RegisterGradient("Log1p")
@@ -332,7 +351,7 @@ def _Log1pGrad(op, grad):
x = op.inputs[0]
with ops.control_dependencies([grad.op]):
x = math_ops.conj(x)
- return grad * math_ops.inv(1 + x)
+ return grad * math_ops.reciprocal(1 + x)
@ops.RegisterGradient("Tanh")
@@ -505,7 +524,7 @@ def _TanGrad(op, grad):
x = op.inputs[0]
with ops.control_dependencies([grad.op]):
x = math_ops.conj(x)
- secx = math_ops.inv(math_ops.cos(x))
+ secx = math_ops.reciprocal(math_ops.cos(x))
secx2 = math_ops.square(secx)
return grad * secx2
@@ -519,7 +538,7 @@ def _AsinGrad(op, grad):
x2 = math_ops.square(x)
one = constant_op.constant(1, dtype=grad.dtype)
den = math_ops.sqrt(math_ops.sub(one, x2))
- inv = math_ops.inv(den)
+ inv = math_ops.reciprocal(den)
return grad * inv
@@ -532,19 +551,19 @@ def _AcosGrad(op, grad):
x2 = math_ops.square(x)
one = constant_op.constant(1, dtype=grad.dtype)
den = math_ops.sqrt(math_ops.sub(one, x2))
- inv = math_ops.inv(den)
+ inv = math_ops.reciprocal(den)
return -grad * inv
@ops.RegisterGradient("Atan")
def _AtanGrad(op, grad):
- """Returns grad * 1/ (1 + x^2)"""
+ """Returns grad * 1/ (1 + x^2)."""
x = op.inputs[0]
with ops.control_dependencies([grad.op]):
x = math_ops.conj(x)
x2 = math_ops.square(x)
one = constant_op.constant(1, dtype=grad.dtype)
- inv = math_ops.inv(math_ops.add(one, x2))
+ inv = math_ops.reciprocal(math_ops.add(one, x2))
return grad * inv
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 48a9d4eb36..c86451a8db 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -46,7 +46,7 @@ mathematical functions to your graph.
@@abs
@@negative
@@sign
-@@inv
+@@reciprocal
@@square
@@round
@@sqrt
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index 71296ea798..d8a431f1c2 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -795,7 +795,7 @@ def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
Two `Tensor` objects: `mean` and `variance`.
"""
with ops.name_scope(name, "normalize", [counts, mean_ss, variance_ss, shift]):
- divisor = math_ops.inv(counts, name="divisor")
+ divisor = math_ops.reciprocal(counts, name="divisor")
if shift is not None:
shifted_mean = math_ops.mul(mean_ss, divisor, name="shifted_mean")
mean = math_ops.add(shifted_mean, shift, name="mean")
@@ -904,7 +904,7 @@ def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=False):
name="sum_of_weights",
keep_dims=True)
- divisor = math_ops.inv(sum_of_weights, name="inv_weight_sum")
+ divisor = math_ops.reciprocal(sum_of_weights, name="inv_weight_sum")
weighted_mean = math_ops.mul(weighted_input_sum, divisor)