about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
-rw-r--r-- tensorflow/cc/BUILD 2
-rw-r--r-- tensorflow/cc/gradients/math_grad.cc 16
-rw-r--r-- tensorflow/cc/gradients/math_grad_test.cc 33
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_UnsafeDiv.pbtxt 5
-rw-r--r-- tensorflow/core/api_def/python_api/api_def_UnsafeDiv.pbtxt 4
-rw-r--r-- tensorflow/core/kernels/cwise_op_div.cc 3
-rw-r--r-- tensorflow/core/kernels/cwise_ops.h 24
-rw-r--r-- tensorflow/core/ops/math_grad.cc 13
-rw-r--r-- tensorflow/core/ops/math_grad_test.cc 72
-rw-r--r-- tensorflow/core/ops/math_ops.cc 4
-rw-r--r-- tensorflow/core/ops/math_ops_test.cc 3
-rw-r--r-- tensorflow/python/ops/math_grad.py 18
-rw-r--r-- tensorflow/python/ops/math_grad_test.py 23
-rw-r--r-- tensorflow/python/ops/math_ops.py 25
-rw-r--r-- tensorflow/python/ops/math_ops_test.py 14
15 files changed, 258 insertions, 1 deletions
diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD
index 588a45ea43..f4be60a183 100644
--- a/tensorflow/cc/BUILD
+++ b/tensorflow/cc/BUILD
@@ -379,9 +379,11 @@ tf_cc_test(
srcs = ["gradients/math_grad_test.cc"],
deps = [
":cc_ops",
+ ":client_session",
":grad_op_registry",
":grad_testutil",
":gradient_checker",
+ ":gradients",
":math_grad",
":testutil",
"//tensorflow/core:lib_internal",
diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc
index d95dd879b4..5dcf00857d 100644
--- a/tensorflow/cc/gradients/math_grad.cc
+++ b/tensorflow/cc/gradients/math_grad.cc
@@ -441,6 +441,22 @@ Status RealDivGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("RealDiv", RealDivGrad);
+// Gradient function for the "UnsafeDiv" op, where y = x_1 / x_2:
+//   dy/dx_1 = 1 / x_2
+//   dy/dx_2 = -x_1 / x_2^2
+// Every quotient below is itself built with UnsafeDiv. The two per-input
+// gradients are passed, together with grad_outputs, to BinaryGradCommon.
+Status UnsafeDivGrad(const Scope& scope, const Operation& op,
+                     const std::vector<Output>& grad_inputs,
+                     std::vector<Output>* grad_outputs) {
+  const Output& upstream = grad_inputs[0];
+  auto numerator = ConjugateHelper(scope, op.input(0));
+  auto denominator = ConjugateHelper(scope, op.input(1));
+  // Gradient w.r.t. x_1: upstream / x_2.
+  auto grad_numerator = UnsafeDiv(scope, upstream, denominator);
+  // Gradient w.r.t. x_2: upstream * ((-x_1 / x_2) / x_2).
+  auto negated = Neg(scope, numerator);
+  auto quotient = UnsafeDiv(scope, negated, denominator);
+  auto scaled = UnsafeDiv(scope, quotient, denominator);
+  auto grad_denominator = Mul(scope, upstream, scaled);
+  return BinaryGradCommon(scope, op, grad_outputs, grad_numerator,
+                          grad_denominator);
+}
+REGISTER_GRADIENT_OP("UnsafeDiv", UnsafeDivGrad);
+
Status SquaredDifferenceGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc
index c6c9262786..88aef1fab4 100644
--- a/tensorflow/cc/gradients/math_grad_test.cc
+++ b/tensorflow/cc/gradients/math_grad_test.cc
@@ -13,8 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradient_checker.h"
+#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/framework/testutil.h"
#include "tensorflow/cc/gradients/grad_testutil.h"
#include "tensorflow/cc/ops/standard_ops.h"
@@ -46,6 +48,7 @@ using ops::SegmentSum;
using ops::SquaredDifference;
using ops::Sub;
using ops::Sum;
+using ops::UnsafeDiv;
// TODO(andydavis) Test gradient function against numeric gradients output.
// TODO(andydavis) As more gradients are added move common test functions
@@ -851,6 +854,36 @@ TEST_F(NaryGradTest, RealDiv) {
RunTest({x}, {x_shape}, {y}, {x_shape});
}
+// Exercises the C++ UnsafeDiv gradient in two ways: a gradient-checker run on
+// a well-conditioned expression, and an explicit evaluation with a zero
+// denominator.
+TEST_F(NaryGradTest, UnsafeDiv) {
+  {
+    TensorShape x_shape({3, 2, 5});
+    const auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
+    // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large
+    // division errors in the numeric estimator used by the gradient checker.
+    const auto y = UnsafeDiv(
+        scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x)));
+    RunTest({x}, {x_shape}, {y}, {x_shape});
+  }
+  {
+    // Return 0 gradient (rather than NaN) for division by zero.
+    const auto x = Placeholder(scope_, DT_FLOAT);
+    const auto zero = Const<float>(scope_, 0.0);
+    const auto y = UnsafeDiv(scope_, x, zero);
+
+    std::vector<Output> grad_outputs;
+    // Use the ASSERT variants for everything the element checks below depend
+    // on: continuing after a failure here would index an empty result.
+    TF_ASSERT_OK(AddSymbolicGradients(scope_, {y}, {x}, &grad_outputs));
+    ClientSession session(scope_);
+    std::vector<Tensor> grad_result;
+    TF_ASSERT_OK(
+        session.Run({{x, {-3.0f, 0.0f, 3.0f}}}, grad_outputs, &grad_result));
+    ASSERT_EQ(grad_result.size(), 1u);
+    ASSERT_EQ(grad_result[0].NumElements(), 3);
+    const auto grad_values = grad_result[0].flat<float>();
+    EXPECT_EQ(grad_values(0), 0.0f);
+    EXPECT_EQ(grad_values(1), 0.0f);
+    EXPECT_EQ(grad_values(2), 0.0f);
+  }
+}
+
TEST_F(NaryGradTest, SquaredDifference) {
TensorShape x1_shape({3, 2, 5});
TensorShape x2_shape({2, 5});
diff --git a/tensorflow/core/api_def/base_api/api_def_UnsafeDiv.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsafeDiv.pbtxt
new file mode 100644
index 0000000000..82c913d15e
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_UnsafeDiv.pbtxt
@@ -0,0 +1,5 @@
+# Base ApiDef entry for the "UnsafeDiv" graph op: supplies the one-line
+# summary and leaves the long-form description empty.
+op {
+ graph_op_name: "UnsafeDiv"
+ summary: "Returns 0 if the denominator is zero."
+ description: ""
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_UnsafeDiv.pbtxt b/tensorflow/core/api_def/python_api/api_def_UnsafeDiv.pbtxt
new file mode 100644
index 0000000000..56caabcf3c
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_UnsafeDiv.pbtxt
@@ -0,0 +1,4 @@
+# Python ApiDef entry for the "UnsafeDiv" graph op: sets its visibility to
+# HIDDEN.
+# NOTE(review): presumably this keeps the generated wrapper out of the public
+# Python namespace so math_ops.unsafe_div is the entry point -- confirm.
+op {
+ graph_op_name: "UnsafeDiv"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/kernels/cwise_op_div.cc b/tensorflow/core/kernels/cwise_op_div.cc
index b12652f7fb..d6a2403816 100644
--- a/tensorflow/core/kernels/cwise_op_div.cc
+++ b/tensorflow/core/kernels/cwise_op_div.cc
@@ -24,6 +24,9 @@ REGISTER5(BinaryOp, CPU, "TruncateDiv", functor::safe_div, uint8, uint16, int16,
int32, int64);
REGISTER6(BinaryOp, CPU, "RealDiv", functor::div, float, Eigen::half, double,
bfloat16, complex64, complex128);
+// Registers the CPU BinaryOp kernel for "UnsafeDiv", backed by
+// functor::unsafe_div, for float, double, int16, int32 and int64.
+REGISTER5(BinaryOp, CPU, "UnsafeDiv", functor::unsafe_div, float, double, int16,
+ int32, int64);
+
#if GOOGLE_CUDA
REGISTER9(BinaryOp, GPU, "Div", functor::div, float, Eigen::half, double, uint8,
uint16, int16, int64, complex64, complex128);
diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h
index 1b1a704d42..1014519059 100644
--- a/tensorflow/core/kernels/cwise_ops.h
+++ b/tensorflow/core/kernels/cwise_ops.h
@@ -153,6 +153,27 @@ struct functor_traits<safe_div_or_mod_op<T, DivOrMod>> {
};
};
+// Binary functor that computes dividend / divisor via scalar_quotient_op,
+// except that it yields 0 whenever the divisor compares equal to 0.
+template <typename T>
+struct unsafe_div_op {
+  EIGEN_EMPTY_STRUCT_CTOR(unsafe_div_op)
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(
+      const T& dividend, const T& divisor) const {
+    // Zero divisor: skip the division entirely and report 0.
+    if (divisor == 0) {
+      return 0;
+    }
+    return scalar_quotient_op<T>()(dividend, divisor);
+  }
+};
+
+// Eigen functor traits for unsafe_div_op: the cost is that of a plain
+// scalar_quotient_op plus one AddCost, and PacketAccess is set to false.
+// NOTE(review): the extra AddCost presumably accounts for the zero check in
+// unsafe_div_op::operator() -- confirm.
+template <typename T>
+struct functor_traits<unsafe_div_op<T>> {
+ enum {
+ Cost = functor_traits<scalar_quotient_op<T>>::Cost + NumTraits<T>::AddCost,
+ PacketAccess = false,
+ };
+};
+
// scalar_left and scalar_right are template helpers to partially
// apply a binary function.
//
@@ -721,6 +742,9 @@ struct safe_div : base<T, Eigen::internal::safe_div_or_mod_op<
};
+// Functor wrapper for UnsafeDiv: instantiates base<> with
+// Eigen::internal::unsafe_div_op<T> and adds nothing of its own.
template <typename T>
+struct unsafe_div : base<T, Eigen::internal::unsafe_div_op<T>> {};
+
+template <typename T>
struct fmod : base<T, Eigen::internal::scalar_fmod_op<T>> {};
template <typename T>
diff --git a/tensorflow/core/ops/math_grad.cc b/tensorflow/core/ops/math_grad.cc
index 783d292389..57499a6f1d 100644
--- a/tensorflow/core/ops/math_grad.cc
+++ b/tensorflow/core/ops/math_grad.cc
@@ -495,6 +495,19 @@ Status RealDivGrad(const AttrSlice& attrs, FunctionDef* g) {
}
REGISTER_OP_GRADIENT("RealDiv", RealDivGrad);
+// Symbolic gradient for "UnsafeDiv" (z = x / y), expressed as a list of
+// function nodes passed to GradForBinaryCwise:
+//   gx = UnsafeDiv(dz, y)
+//   gy = dz * UnsafeDiv(-x, y^2)
+// NOTE(review): the trailing {"dz"} on the Neg and Square nodes looks like an
+// ordering dependency on dz -- confirm against the FDH::Node field layout.
+Status UnsafeDivGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForBinaryCwise(g, {
+ {{"gx"}, "UnsafeDiv", {"dz", "y"}},  // gx = dz / y
+ {{"nx"}, "Neg", {"x"}, {}, {"dz"}},  // nx = -x
+ {{"y2"}, "Square", {"y"}, {}, {"dz"}},  // y2 = y^2
+ {{"nx_y2"}, "UnsafeDiv", {"nx", "y2"}},  // nx_y2 = -x / y^2
+ {{"gy"}, "Mul", {"dz", "nx_y2"}}, // dz * (- x / y^2)
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("UnsafeDiv", UnsafeDivGrad);
+
Status PowGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
std::vector<FDH::Node> nodes = {
diff --git a/tensorflow/core/ops/math_grad_test.cc b/tensorflow/core/ops/math_grad_test.cc
index 2a27ef3ddb..b0d1595c31 100644
--- a/tensorflow/core/ops/math_grad_test.cc
+++ b/tensorflow/core/ops/math_grad_test.cc
@@ -753,6 +753,78 @@ TEST_F(MathGradTest, Div) {
}
}
+// Compares the symbolic UnsafeDiv gradients with hand-computed expectations
+// for a {3, 3} operand against a {3, 1} operand, in both operand orders. Both
+// operands contain zeros, so the zero-denominator path is covered.
+TEST_F(MathGradTest, UnsafeDiv) {
+  auto x = test::AsTensor<float>(
+      {0.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 0.f}, TensorShape({3, 3}));
+  auto y = test::AsTensor<float>({-10.f, 0.f, 10.f}, TensorShape({3, 1}));
+  Tensor dx;
+  Tensor dy;
+
+  // Expected d(a / b)/da: 1 / b, or 0 where b == 0.
+  auto wrt_num = [](float a, float b) { return b == 0.f ? 0.f : 1.f / b; };
+  // Expected d(a / b)/db: -a / b^2, or 0 where b == 0.
+  auto wrt_den = [](float a, float b) {
+    return b == 0.f ? 0.f : -a / (b * b);
+  };
+
+  // x / y: dx keeps the {3, 3} shape, dy is summed over each row.
+  SymGrad("UnsafeDiv", x, y, &dx, &dy);
+  test::ExpectClose(
+      dx, test::AsTensor<float>(
+              {wrt_num(0.f, -10.f), wrt_num(-3.f, -10.f), wrt_num(-2.f, -10.f),
+               wrt_num(-1.f, 0.f), wrt_num(0.f, 0.f), wrt_num(1.f, 0.f),
+               wrt_num(2.f, 10.f), wrt_num(3.f, 10.f), wrt_num(0.f, 10.f)},
+              TensorShape({3, 3})));
+  test::ExpectClose(
+      dy,
+      test::AsTensor<float>(
+          {wrt_den(0.f, -10.f) + wrt_den(-3.f, -10.f) + wrt_den(-2.f, -10.f),
+           wrt_den(-1.f, 0.f) + wrt_den(0.f, 0.f) + wrt_den(1.f, 0.f),
+           wrt_den(2.f, 10.f) + wrt_den(3.f, 10.f) + wrt_den(0.f, 10.f)},
+          TensorShape({3, 1})));
+
+  // Swap x and y: y / x, so dy is now the numerator gradient.
+  SymGrad("UnsafeDiv", y, x, &dy, &dx);
+  test::ExpectClose(
+      dy,
+      test::AsTensor<float>(
+          {wrt_num(-10.f, 0.f) + wrt_num(-10.f, -3.f) + wrt_num(-10.f, -2.f),
+           wrt_num(0.f, -1.f) + wrt_num(0.f, 0.f) + wrt_num(0.f, 1.f),
+           wrt_num(10.f, 2.f) + wrt_num(10.f, 3.f) + wrt_num(10.f, 0.f)},
+          TensorShape({3, 1})));
+  test::ExpectClose(
+      dx, test::AsTensor<float>(
+              {wrt_den(-10.f, 0.f), wrt_den(-10.f, -3.f), wrt_den(-10.f, -2.f),
+               wrt_den(0.f, -1.f), wrt_den(0.f, 0.f), wrt_den(0.f, 1.f),
+               wrt_den(10.f, 2.f), wrt_den(10.f, 3.f), wrt_den(10.f, 0.f)},
+              TensorShape({3, 3})));
+}
+
TEST_F(MathGradTest, Pow) {
auto x = test::AsTensor<float>({0.f, 1.f, 2.f, 3.f, 4.f, 5.f},
TensorShape({2, 3}));
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 1667c398f4..49646f1f3a 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -392,6 +392,10 @@ Returns x * y element-wise.
REGISTER_OP("Div").BINARY_MORE().SetShapeFn(
shape_inference::BroadcastBinaryOpShapeFn);
+// Registers the "UnsafeDiv" op using the BINARY_MORE() signature macro and
+// shape_inference::BroadcastBinaryOpShapeFn as its shape function.
+REGISTER_OP("UnsafeDiv")
+ .BINARY_MORE()
+ .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
+
REGISTER_OP("FloorDiv")
.BINARY_MORE()
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc
index 23f1538912..ebeb048157 100644
--- a/tensorflow/core/ops/math_ops_test.cc
+++ b/tensorflow/core/ops/math_ops_test.cc
@@ -120,7 +120,8 @@ TEST(MathOpsTest, BroadcastBinaryOps_ShapeFn) {
"Maximum", "Minimum",
"Mod", "Mul",
"NotEqual", "Pow",
- "Sub", "SquaredDifference"}) {
+ "Sub", "SquaredDifference",
+ "UnsafeDiv"}) {
ShapeInferenceTestOp op(op_name);
INFER_OK(op, "?;?", "?");
INFER_OK(op, "[1,2];?", "?");
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index f0c6bd532f..2a7a2fd51f 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -972,6 +972,24 @@ def _RealDivGrad(op, grad):
grad * math_ops.realdiv(math_ops.realdiv(-x, y), y), ry), sy))
+@ops.RegisterGradient("UnsafeDiv")
+def _UnsafeDivGrad(op, grad):
+ """UnsafeDiv op gradient."""
+ x = op.inputs[0]
+ y = op.inputs[1]
+ sx = array_ops.shape(x)
+ sy = array_ops.shape(y)
+ rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
+ x = math_ops.conj(x)
+ y = math_ops.conj(y)
+ return (array_ops.reshape(
+ math_ops.reduce_sum(math_ops.unsafe_div(grad, y), rx), sx),
+ array_ops.reshape(
+ math_ops.reduce_sum(
+ grad * math_ops.unsafe_div(math_ops.unsafe_div(-x, y), y),
+ ry), sy))
+
+
@ops.RegisterGradient("Pow")
def _PowGrad(op, grad):
"""Returns grad * (y*x^(y-1), z*log(x))."""
diff --git a/tensorflow/python/ops/math_grad_test.py b/tensorflow/python/ops/math_grad_test.py
index fa47b8f9b8..f9bb60e7fe 100644
--- a/tensorflow/python/ops/math_grad_test.py
+++ b/tensorflow/python/ops/math_grad_test.py
@@ -25,6 +25,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@@ -230,5 +231,27 @@ class FloorModGradientTest(test.TestCase):
self.assertLess(error, 1e-4)
+class UnsafeDivGradientTest(test.TestCase):
+
+ def testBasicGradient(self):
+ inputs = constant_op.constant(np.arange(-3, 3), dtype=dtypes.float32)
+ outputs = math_ops.unsafe_div(inputs, 1 + math_ops.abs(inputs))
+ with self.test_session():
+ error = gradient_checker.compute_gradient_error(
+ inputs,
+ inputs.get_shape().as_list(), outputs,
+ outputs.get_shape().as_list())
+ self.assertLess(error, 1e-4)
+
+ def testGradientWithDenominatorIsZero(self):
+ x = constant_op.constant(np.arange(-3, 3), dtype=dtypes.float32)
+ y = array_ops.zeros_like(x, dtype=dtypes.float32)
+ outputs = math_ops.unsafe_div(x, y)
+ with self.test_session():
+ dx, dy = gradients.gradients(outputs, [x, y])
+ self.assertAllClose(dx.eval(), np.zeros(x.shape.as_list()))
+ self.assertAllClose(dy.eval(), np.zeros(y.shape.as_list()))
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index fbe6b62302..81499bee56 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -1038,6 +1038,31 @@ def div(x, y, name=None):
return _div_python2(x, y, name)
+def unsafe_div(x, y, name=None):
+ """Computes an unsafe divide which returns 0 if the y is zero.
+
+ Note that the function uses Python 3 division operator semantics.
+
+ Args:
+ x: A `Tensor`. Must be one of the following types:
+ `float32`, `float64`, `int16`, `int32`, `int64`.
+ y: A `Tensor` whose dtype is compatible with `x`.
+ name: A name for the operation (optional).
+ Returns:
+ The element-wise value of the x divided by y.
+ """
+
+ with ops.name_scope(name, "unsafe_div", [x, y]) as name:
+ x = ops.convert_to_tensor(x, name="x")
+ y = ops.convert_to_tensor(y, name="y", dtype=x.dtype.base_dtype)
+ x_dtype = x.dtype.base_dtype
+ y_dtype = y.dtype.base_dtype
+ if x_dtype != y_dtype:
+ raise TypeError(
+ "x and y must have the same dtype, got %r != %r" % (x_dtype, y_dtype))
+ return gen_math_ops.unsafe_div(x, y, name=name)
+
+
# TODO(aselle): This should be removed
mod = gen_math_ops.floor_mod
diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py
index 6b709e5e7f..5fe7bbca11 100644
--- a/tensorflow/python/ops/math_ops_test.py
+++ b/tensorflow/python/ops/math_ops_test.py
@@ -473,5 +473,19 @@ class DivAndModTest(test_util.TensorFlowTestCase):
self.assertAllEqual(tf_result, expanded_nums)
+class UnsafeDivTest(test_util.TensorFlowTestCase):
+
+ def testBasic(self):
+ nums = np.arange(-10, 10, .25).reshape(80, 1)
+ divs = np.arange(-3, 3, .25).reshape(1, 24)
+
+ np_result = np.true_divide(nums, divs)
+ np_result[:, divs[0] == 0] = 0
+
+ with self.test_session():
+ tf_result = math_ops.unsafe_div(nums, divs).eval()
+ self.assertAllEqual(tf_result, np_result)
+
+
if __name__ == "__main__":
googletest.main()