aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-08-14 14:08:15 +0800
committerGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-08-14 14:08:15 +0800
commit3c83ef9fbc8dc23ab0878cffa13ecbfd07ac70e5 (patch)
treeb53c9e3df64f7911ec847c7c8e0617b0fd8e1b91 /tensorflow
parenta90fce71faaa356b531157c2e00804046961b39d (diff)
CLN: rename UnsafeDiv => DivNoNan
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/cc/gradients/math_grad.cc15
-rw-r--r--tensorflow/cc/gradients/math_grad_test.cc8
-rw-r--r--tensorflow/core/api_def/base_api/api_def_UnsafeDiv.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_UnsafeDiv.pbtxt2
-rw-r--r--tensorflow/core/kernels/cwise_op_div.cc2
-rw-r--r--tensorflow/core/ops/math_grad.cc8
-rw-r--r--tensorflow/core/ops/math_grad_test.cc6
-rw-r--r--tensorflow/core/ops/math_ops.cc2
-rw-r--r--tensorflow/core/ops/math_ops_test.cc2
-rw-r--r--tensorflow/python/ops/math_grad.py6
-rw-r--r--tensorflow/python/ops/math_grad_test.py2
-rw-r--r--tensorflow/python/ops/math_ops_test.py2
12 files changed, 29 insertions, 30 deletions
diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc
index c6e60689fa..cd215f740d 100644
--- a/tensorflow/cc/gradients/math_grad.cc
+++ b/tensorflow/cc/gradients/math_grad.cc
@@ -441,21 +441,20 @@ Status RealDivGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("RealDiv", RealDivGrad);
-Status UnsafeDivGrad(const Scope& scope, const Operation& op,
- const std::vector<Output>& grad_inputs,
- std::vector<Output>* grad_outputs) {
+Status DivNoNanGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
auto x_1 = ConjugateHelper(scope, op.input(0));
auto x_2 = ConjugateHelper(scope, op.input(1));
// y = x_1 / x_2
// dy/dx_1 = 1/x_2
// dy/dx_2 = -x_1/x_2^2
- auto gx_1 = UnsafeDiv(scope, grad_inputs[0], x_2);
- auto gx_2 =
- Mul(scope, grad_inputs[0],
- UnsafeDiv(scope, UnsafeDiv(scope, Neg(scope, x_1), x_2), x_2));
+ auto gx_1 = DivNoNan(scope, grad_inputs[0], x_2);
+ auto gx_2 = Mul(scope, grad_inputs[0],
+ DivNoNan(scope, DivNoNan(scope, Neg(scope, x_1), x_2), x_2));
return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
-REGISTER_GRADIENT_OP("UnsafeDiv", UnsafeDivGrad);
+REGISTER_GRADIENT_OP("DivNoNan", DivNoNanGrad);
Status SquaredDifferenceGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc
index 12a19bcf28..147428cc39 100644
--- a/tensorflow/cc/gradients/math_grad_test.cc
+++ b/tensorflow/cc/gradients/math_grad_test.cc
@@ -33,6 +33,7 @@ using ops::AddN;
using ops::BatchMatMul;
using ops::Const;
using ops::Div;
+using ops::DivNoNan;
using ops::MatMul;
using ops::Max;
using ops::Maximum;
@@ -47,7 +48,6 @@ using ops::RealDiv;
using ops::SquaredDifference;
using ops::Sub;
using ops::Sum;
-using ops::UnsafeDiv;
using ops::Where3;
// TODO(andydavis) Test gradient function against numeric gradients output.
@@ -854,13 +854,13 @@ TEST_F(NaryGradTest, RealDiv) {
RunTest({x}, {x_shape}, {y}, {x_shape});
}
-TEST_F(NaryGradTest, UnsafeDiv) {
+TEST_F(NaryGradTest, DivNoNan) {
{
TensorShape x_shape({3, 2, 5});
const auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
// Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large
// division errors in the numeric estimator used by the gradient checker.
- const auto y = UnsafeDiv(
+ const auto y = DivNoNan(
scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x)));
RunTest({x}, {x_shape}, {y}, {x_shape});
}
@@ -868,7 +868,7 @@ TEST_F(NaryGradTest, UnsafeDiv) {
// Return 0 gradient (rather than NaN) for division by zero.
const auto x = Placeholder(scope_, DT_FLOAT);
const auto zero = Const<float>(scope_, 0.0);
- const auto y = UnsafeDiv(scope_, x, zero);
+ const auto y = DivNoNan(scope_, x, zero);
std::vector<Output> grad_outputs;
TF_EXPECT_OK(AddSymbolicGradients(scope_, {y}, {x}, &grad_outputs));
diff --git a/tensorflow/core/api_def/base_api/api_def_UnsafeDiv.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsafeDiv.pbtxt
index d8f76c4cf8..5604a1a89e 100644
--- a/tensorflow/core/api_def/base_api/api_def_UnsafeDiv.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_UnsafeDiv.pbtxt
@@ -1,9 +1,9 @@
op {
- graph_op_name: "UnsafeDiv"
+ graph_op_name: "DivNoNan"
summary: "Returns 0 if the denominator is zero."
description: <<END
-*NOTE*: `UnsafeDiv` supports broadcasting. More about broadcasting
+*NOTE*: `DivNoNan` supports broadcasting. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
END
}
diff --git a/tensorflow/core/api_def/python_api/api_def_UnsafeDiv.pbtxt b/tensorflow/core/api_def/python_api/api_def_UnsafeDiv.pbtxt
index 56caabcf3c..1bf3fba3c6 100644
--- a/tensorflow/core/api_def/python_api/api_def_UnsafeDiv.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_UnsafeDiv.pbtxt
@@ -1,4 +1,4 @@
op {
- graph_op_name: "UnsafeDiv"
+ graph_op_name: "DivNoNan"
visibility: HIDDEN
}
diff --git a/tensorflow/core/kernels/cwise_op_div.cc b/tensorflow/core/kernels/cwise_op_div.cc
index d6a2403816..7178a4787e 100644
--- a/tensorflow/core/kernels/cwise_op_div.cc
+++ b/tensorflow/core/kernels/cwise_op_div.cc
@@ -24,7 +24,7 @@ REGISTER5(BinaryOp, CPU, "TruncateDiv", functor::safe_div, uint8, uint16, int16,
int32, int64);
REGISTER6(BinaryOp, CPU, "RealDiv", functor::div, float, Eigen::half, double,
bfloat16, complex64, complex128);
-REGISTER5(BinaryOp, CPU, "UnsafeDiv", functor::unsafe_div, float, double, int16,
+REGISTER5(BinaryOp, CPU, "DivNoNan", functor::unsafe_div, float, double, int16,
int32, int64);
#if GOOGLE_CUDA
diff --git a/tensorflow/core/ops/math_grad.cc b/tensorflow/core/ops/math_grad.cc
index 57499a6f1d..07f876cb90 100644
--- a/tensorflow/core/ops/math_grad.cc
+++ b/tensorflow/core/ops/math_grad.cc
@@ -495,18 +495,18 @@ Status RealDivGrad(const AttrSlice& attrs, FunctionDef* g) {
}
REGISTER_OP_GRADIENT("RealDiv", RealDivGrad);
-Status UnsafeDivGrad(const AttrSlice& attrs, FunctionDef* g) {
+Status DivNoNanGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
return GradForBinaryCwise(g, {
- {{"gx"}, "UnsafeDiv", {"dz", "y"}},
+ {{"gx"}, "DivNoNan", {"dz", "y"}},
{{"nx"}, "Neg", {"x"}, {}, {"dz"}},
{{"y2"}, "Square", {"y"}, {}, {"dz"}},
- {{"nx_y2"}, "UnsafeDiv", {"nx", "y2"}},
+ {{"nx_y2"}, "DivNoNan", {"nx", "y2"}},
{{"gy"}, "Mul", {"dz", "nx_y2"}}, // dz * (- x / y^2)
});
// clang-format on
}
-REGISTER_OP_GRADIENT("UnsafeDiv", UnsafeDivGrad);
+REGISTER_OP_GRADIENT("DivNoNan", DivNoNanGrad);
Status PowGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
diff --git a/tensorflow/core/ops/math_grad_test.cc b/tensorflow/core/ops/math_grad_test.cc
index b0d1595c31..5ee79809ac 100644
--- a/tensorflow/core/ops/math_grad_test.cc
+++ b/tensorflow/core/ops/math_grad_test.cc
@@ -753,14 +753,14 @@ TEST_F(MathGradTest, Div) {
}
}
-TEST_F(MathGradTest, UnsafeDiv) {
+TEST_F(MathGradTest, DivNoNan) {
auto x = test::AsTensor<float>(
{0.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 0.f}, TensorShape({3, 3}));
auto y = test::AsTensor<float>({-10.f, 0.f, 10.f}, TensorShape({3, 1}));
Tensor dx;
Tensor dy;
{
- SymGrad("UnsafeDiv", x, y, &dx, &dy);
+ SymGrad("DivNoNan", x, y, &dx, &dy);
{
auto g = [](float x, float y) {
if (y == 0.f) {
@@ -792,7 +792,7 @@ TEST_F(MathGradTest, UnsafeDiv) {
}
}
{ // Swap x and y.
- SymGrad("UnsafeDiv", y, x, &dy, &dx);
+ SymGrad("DivNoNan", y, x, &dy, &dx);
{
auto g = [](float x, float y) {
if (y == 0.f) {
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 49646f1f3a..4cd472adad 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -392,7 +392,7 @@ Returns x * y element-wise.
REGISTER_OP("Div").BINARY_MORE().SetShapeFn(
shape_inference::BroadcastBinaryOpShapeFn);
-REGISTER_OP("UnsafeDiv")
+REGISTER_OP("DivNoNan")
.BINARY_MORE()
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc
index ebeb048157..be4c3ed2b6 100644
--- a/tensorflow/core/ops/math_ops_test.cc
+++ b/tensorflow/core/ops/math_ops_test.cc
@@ -121,7 +121,7 @@ TEST(MathOpsTest, BroadcastBinaryOps_ShapeFn) {
"Mod", "Mul",
"NotEqual", "Pow",
"Sub", "SquaredDifference",
- "UnsafeDiv"}) {
+ "DivNoNan"}) {
ShapeInferenceTestOp op(op_name);
INFER_OK(op, "?;?", "?");
INFER_OK(op, "[1,2];?", "?");
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 2a7a2fd51f..38c715df10 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -972,9 +972,9 @@ def _RealDivGrad(op, grad):
grad * math_ops.realdiv(math_ops.realdiv(-x, y), y), ry), sy))
-@ops.RegisterGradient("UnsafeDiv")
-def _UnsafeDivGrad(op, grad):
- """UnsafeDiv op gradient."""
+@ops.RegisterGradient("DivNoNan")
+def _DivNoNanGrad(op, grad):
+ """DivNoNan op gradient."""
x = op.inputs[0]
y = op.inputs[1]
sx = array_ops.shape(x)
diff --git a/tensorflow/python/ops/math_grad_test.py b/tensorflow/python/ops/math_grad_test.py
index 9afb034047..8524d9191c 100644
--- a/tensorflow/python/ops/math_grad_test.py
+++ b/tensorflow/python/ops/math_grad_test.py
@@ -231,7 +231,7 @@ class FloorModGradientTest(test.TestCase):
self.assertLess(error, 1e-4)
-class UnsafeDivGradientTest(test.TestCase):
+class DivNoNanGradientTest(test.TestCase):
def testBasicGradient(self):
inputs = constant_op.constant(np.arange(-3, 3),
diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py
index 5fe7bbca11..c9181269e3 100644
--- a/tensorflow/python/ops/math_ops_test.py
+++ b/tensorflow/python/ops/math_ops_test.py
@@ -473,7 +473,7 @@ class DivAndModTest(test_util.TensorFlowTestCase):
self.assertAllEqual(tf_result, expanded_nums)
-class UnsafeDivTest(test_util.TensorFlowTestCase):
+class DivNoNanTest(test_util.TensorFlowTestCase):
def testBasic(self):
nums = np.arange(-10, 10, .25).reshape(80, 1)