aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/cc/gradients/math_grad.cc15
-rw-r--r--tensorflow/cc/gradients/math_grad_test.cc8
-rw-r--r--tensorflow/core/api_def/base_api/api_def_DivNoNan.pbtxt9
-rw-r--r--tensorflow/core/api_def/base_api/api_def_UnsafeDiv.pbtxt5
-rw-r--r--tensorflow/core/api_def/python_api/api_def_DivNoNan.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_UnsafeDiv.pbtxt4
-rw-r--r--tensorflow/core/kernels/cwise_op_div.cc3
-rw-r--r--tensorflow/core/kernels/cwise_ops.h8
-rw-r--r--tensorflow/core/ops/compat/ops_history.v1.pbtxt35
-rw-r--r--tensorflow/core/ops/math_grad.cc8
-rw-r--r--tensorflow/core/ops/math_grad_test.cc6
-rw-r--r--tensorflow/core/ops/math_ops.cc7
-rw-r--r--tensorflow/core/ops/math_ops_test.cc2
-rw-r--r--tensorflow/core/ops/ops.pbtxt35
-rw-r--r--tensorflow/python/ops/math_grad.py10
-rw-r--r--tensorflow/python/ops/math_grad_test.py15
-rw-r--r--tensorflow/python/ops/math_ops.py16
-rw-r--r--tensorflow/python/ops/math_ops_test.py17
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.pbtxt4
20 files changed, 80 insertions, 135 deletions
diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc
index 5dcf00857d..1329b568ab 100644
--- a/tensorflow/cc/gradients/math_grad.cc
+++ b/tensorflow/cc/gradients/math_grad.cc
@@ -441,21 +441,20 @@ Status RealDivGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("RealDiv", RealDivGrad);
-Status UnsafeDivGrad(const Scope& scope, const Operation& op,
- const std::vector<Output>& grad_inputs,
- std::vector<Output>* grad_outputs) {
+Status DivNoNanGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
auto x_1 = ConjugateHelper(scope, op.input(0));
auto x_2 = ConjugateHelper(scope, op.input(1));
// y = x_1 / x_2
// dy/dx_1 = 1/x_2
// dy/dx_2 = -x_1/x_2^2
- auto gx_1 = UnsafeDiv(scope, grad_inputs[0], x_2);
- auto gx_2 =
- Mul(scope, grad_inputs[0],
- UnsafeDiv(scope, UnsafeDiv(scope, Neg(scope, x_1), x_2), x_2));
+ auto gx_1 = DivNoNan(scope, grad_inputs[0], x_2);
+ auto gx_2 = Mul(scope, grad_inputs[0],
+ DivNoNan(scope, DivNoNan(scope, Neg(scope, x_1), x_2), x_2));
return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
-REGISTER_GRADIENT_OP("UnsafeDiv", UnsafeDivGrad);
+REGISTER_GRADIENT_OP("DivNoNan", DivNoNanGrad);
Status SquaredDifferenceGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc
index 88aef1fab4..c16938322c 100644
--- a/tensorflow/cc/gradients/math_grad_test.cc
+++ b/tensorflow/cc/gradients/math_grad_test.cc
@@ -33,6 +33,7 @@ using ops::AddN;
using ops::BatchMatMul;
using ops::Const;
using ops::Div;
+using ops::DivNoNan;
using ops::MatMul;
using ops::Max;
using ops::Maximum;
@@ -48,7 +49,6 @@ using ops::SegmentSum;
using ops::SquaredDifference;
using ops::Sub;
using ops::Sum;
-using ops::UnsafeDiv;
// TODO(andydavis) Test gradient function against numeric gradients output.
// TODO(andydavis) As more gradients are added move common test functions
@@ -854,13 +854,13 @@ TEST_F(NaryGradTest, RealDiv) {
RunTest({x}, {x_shape}, {y}, {x_shape});
}
-TEST_F(NaryGradTest, UnsafeDiv) {
+TEST_F(NaryGradTest, DivNoNan) {
{
TensorShape x_shape({3, 2, 5});
const auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
// Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large
// division errors in the numeric estimator used by the gradient checker.
- const auto y = UnsafeDiv(
+ const auto y = DivNoNan(
scope_, x, Add(scope_, Const<float>(scope_, 1), Abs(scope_, x)));
RunTest({x}, {x_shape}, {y}, {x_shape});
}
@@ -868,7 +868,7 @@ TEST_F(NaryGradTest, UnsafeDiv) {
// Return 0 gradient (rather than NaN) for division by zero.
const auto x = Placeholder(scope_, DT_FLOAT);
const auto zero = Const<float>(scope_, 0.0);
- const auto y = UnsafeDiv(scope_, x, zero);
+ const auto y = DivNoNan(scope_, x, zero);
std::vector<Output> grad_outputs;
TF_EXPECT_OK(AddSymbolicGradients(scope_, {y}, {x}, &grad_outputs));
diff --git a/tensorflow/core/api_def/base_api/api_def_DivNoNan.pbtxt b/tensorflow/core/api_def/base_api/api_def_DivNoNan.pbtxt
new file mode 100644
index 0000000000..5604a1a89e
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_DivNoNan.pbtxt
@@ -0,0 +1,9 @@
+op {
+ graph_op_name: "DivNoNan"
+ summary: "Returns 0 if the denominator is zero."
+ description: <<END
+
+*NOTE*: `DivNoNan` supports broadcasting. More about broadcasting
+[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_UnsafeDiv.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsafeDiv.pbtxt
deleted file mode 100644
index 82c913d15e..0000000000
--- a/tensorflow/core/api_def/base_api/api_def_UnsafeDiv.pbtxt
+++ /dev/null
@@ -1,5 +0,0 @@
-op {
- graph_op_name: "UnsafeDiv"
- summary: "Returns 0 if the denominator is zero."
- description: ""
-}
diff --git a/tensorflow/core/api_def/python_api/api_def_DivNoNan.pbtxt b/tensorflow/core/api_def/python_api/api_def_DivNoNan.pbtxt
new file mode 100644
index 0000000000..1bf3fba3c6
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_DivNoNan.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "DivNoNan"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_UnsafeDiv.pbtxt b/tensorflow/core/api_def/python_api/api_def_UnsafeDiv.pbtxt
deleted file mode 100644
index 56caabcf3c..0000000000
--- a/tensorflow/core/api_def/python_api/api_def_UnsafeDiv.pbtxt
+++ /dev/null
@@ -1,4 +0,0 @@
-op {
- graph_op_name: "UnsafeDiv"
- visibility: HIDDEN
-}
diff --git a/tensorflow/core/kernels/cwise_op_div.cc b/tensorflow/core/kernels/cwise_op_div.cc
index d6a2403816..35662e278f 100644
--- a/tensorflow/core/kernels/cwise_op_div.cc
+++ b/tensorflow/core/kernels/cwise_op_div.cc
@@ -24,8 +24,7 @@ REGISTER5(BinaryOp, CPU, "TruncateDiv", functor::safe_div, uint8, uint16, int16,
int32, int64);
REGISTER6(BinaryOp, CPU, "RealDiv", functor::div, float, Eigen::half, double,
bfloat16, complex64, complex128);
-REGISTER5(BinaryOp, CPU, "UnsafeDiv", functor::unsafe_div, float, double, int16,
- int32, int64);
+REGISTER2(BinaryOp, CPU, "DivNoNan", functor::div_no_nan, float, double);
#if GOOGLE_CUDA
REGISTER9(BinaryOp, GPU, "Div", functor::div, float, Eigen::half, double, uint8,
diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h
index 1014519059..de164c1c09 100644
--- a/tensorflow/core/kernels/cwise_ops.h
+++ b/tensorflow/core/kernels/cwise_ops.h
@@ -154,8 +154,8 @@ struct functor_traits<safe_div_or_mod_op<T, DivOrMod>> {
};
template <typename T>
-struct unsafe_div_op {
- EIGEN_EMPTY_STRUCT_CTOR(unsafe_div_op)
+struct div_no_nan_op {
+ EIGEN_EMPTY_STRUCT_CTOR(div_no_nan_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a,
const T& b) const {
if (b != 0) {
@@ -167,7 +167,7 @@ struct unsafe_div_op {
};
template <typename T>
-struct functor_traits<unsafe_div_op<T>> {
+struct functor_traits<div_no_nan_op<T>> {
enum {
Cost = functor_traits<scalar_quotient_op<T>>::Cost + NumTraits<T>::AddCost,
PacketAccess = false,
@@ -742,7 +742,7 @@ struct safe_div : base<T, Eigen::internal::safe_div_or_mod_op<
};
template <typename T>
-struct unsafe_div : base<T, Eigen::internal::unsafe_div_op<T>> {};
+struct div_no_nan : base<T, Eigen::internal::div_no_nan_op<T>> {};
template <typename T>
struct fmod : base<T, Eigen::internal::scalar_fmod_op<T>> {};
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index d0dd01ad99..88feecee66 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -73443,41 +73443,6 @@ op {
}
}
op {
- name: "UnsafeDiv"
- input_arg {
- name: "x"
- type_attr: "T"
- }
- input_arg {
- name: "y"
- type_attr: "T"
- }
- output_arg {
- name: "z"
- type_attr: "T"
- }
- attr {
- name: "T"
- type: "type"
- allowed_values {
- list {
- type: DT_BFLOAT16
- type: DT_HALF
- type: DT_FLOAT
- type: DT_DOUBLE
- type: DT_UINT8
- type: DT_INT8
- type: DT_UINT16
- type: DT_INT16
- type: DT_INT32
- type: DT_INT64
- type: DT_COMPLEX64
- type: DT_COMPLEX128
- }
- }
- }
-}
-op {
name: "UnsortedSegmentMax"
input_arg {
name: "data"
diff --git a/tensorflow/core/ops/math_grad.cc b/tensorflow/core/ops/math_grad.cc
index 57499a6f1d..07f876cb90 100644
--- a/tensorflow/core/ops/math_grad.cc
+++ b/tensorflow/core/ops/math_grad.cc
@@ -495,18 +495,18 @@ Status RealDivGrad(const AttrSlice& attrs, FunctionDef* g) {
}
REGISTER_OP_GRADIENT("RealDiv", RealDivGrad);
-Status UnsafeDivGrad(const AttrSlice& attrs, FunctionDef* g) {
+Status DivNoNanGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
return GradForBinaryCwise(g, {
- {{"gx"}, "UnsafeDiv", {"dz", "y"}},
+ {{"gx"}, "DivNoNan", {"dz", "y"}},
{{"nx"}, "Neg", {"x"}, {}, {"dz"}},
{{"y2"}, "Square", {"y"}, {}, {"dz"}},
- {{"nx_y2"}, "UnsafeDiv", {"nx", "y2"}},
+ {{"nx_y2"}, "DivNoNan", {"nx", "y2"}},
{{"gy"}, "Mul", {"dz", "nx_y2"}}, // dz * (- x / y^2)
});
// clang-format on
}
-REGISTER_OP_GRADIENT("UnsafeDiv", UnsafeDivGrad);
+REGISTER_OP_GRADIENT("DivNoNan", DivNoNanGrad);
Status PowGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
diff --git a/tensorflow/core/ops/math_grad_test.cc b/tensorflow/core/ops/math_grad_test.cc
index b0d1595c31..5ee79809ac 100644
--- a/tensorflow/core/ops/math_grad_test.cc
+++ b/tensorflow/core/ops/math_grad_test.cc
@@ -753,14 +753,14 @@ TEST_F(MathGradTest, Div) {
}
}
-TEST_F(MathGradTest, UnsafeDiv) {
+TEST_F(MathGradTest, DivNoNan) {
auto x = test::AsTensor<float>(
{0.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 0.f}, TensorShape({3, 3}));
auto y = test::AsTensor<float>({-10.f, 0.f, 10.f}, TensorShape({3, 1}));
Tensor dx;
Tensor dy;
{
- SymGrad("UnsafeDiv", x, y, &dx, &dy);
+ SymGrad("DivNoNan", x, y, &dx, &dy);
{
auto g = [](float x, float y) {
if (y == 0.f) {
@@ -792,7 +792,7 @@ TEST_F(MathGradTest, UnsafeDiv) {
}
}
{ // Swap x and y.
- SymGrad("UnsafeDiv", y, x, &dy, &dx);
+ SymGrad("DivNoNan", y, x, &dy, &dx);
{
auto g = [](float x, float y) {
if (y == 0.f) {
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 49646f1f3a..717263a9b0 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -392,8 +392,11 @@ Returns x * y element-wise.
REGISTER_OP("Div").BINARY_MORE().SetShapeFn(
shape_inference::BroadcastBinaryOpShapeFn);
-REGISTER_OP("UnsafeDiv")
- .BINARY_MORE()
+REGISTER_OP("DivNoNan")
+ .Input("x: T")
+ .Input("y: T")
+ .Output("z: T")
+ .Attr("T: {float, double}")
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
REGISTER_OP("FloorDiv")
diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc
index ebeb048157..be4c3ed2b6 100644
--- a/tensorflow/core/ops/math_ops_test.cc
+++ b/tensorflow/core/ops/math_ops_test.cc
@@ -121,7 +121,7 @@ TEST(MathOpsTest, BroadcastBinaryOps_ShapeFn) {
"Mod", "Mul",
"NotEqual", "Pow",
"Sub", "SquaredDifference",
- "UnsafeDiv"}) {
+ "DivNoNan"}) {
ShapeInferenceTestOp op(op_name);
INFER_OK(op, "?;?", "?");
INFER_OK(op, "[1,2];?", "?");
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 40b3dbe564..359bfa904f 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -34986,41 +34986,6 @@ op {
}
}
op {
- name: "UnsafeDiv"
- input_arg {
- name: "x"
- type_attr: "T"
- }
- input_arg {
- name: "y"
- type_attr: "T"
- }
- output_arg {
- name: "z"
- type_attr: "T"
- }
- attr {
- name: "T"
- type: "type"
- allowed_values {
- list {
- type: DT_BFLOAT16
- type: DT_HALF
- type: DT_FLOAT
- type: DT_DOUBLE
- type: DT_UINT8
- type: DT_INT8
- type: DT_UINT16
- type: DT_INT16
- type: DT_INT32
- type: DT_INT64
- type: DT_COMPLEX64
- type: DT_COMPLEX128
- }
- }
- }
-}
-op {
name: "UnsortedSegmentMax"
input_arg {
name: "data"
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 2a7a2fd51f..8e11c4bce1 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -972,9 +972,9 @@ def _RealDivGrad(op, grad):
grad * math_ops.realdiv(math_ops.realdiv(-x, y), y), ry), sy))
-@ops.RegisterGradient("UnsafeDiv")
-def _UnsafeDivGrad(op, grad):
- """UnsafeDiv op gradient."""
+@ops.RegisterGradient("DivNoNan")
+def _DivNoNanGrad(op, grad):
+ """DivNoNan op gradient."""
x = op.inputs[0]
y = op.inputs[1]
sx = array_ops.shape(x)
@@ -983,10 +983,10 @@ def _UnsafeDivGrad(op, grad):
x = math_ops.conj(x)
y = math_ops.conj(y)
return (array_ops.reshape(
- math_ops.reduce_sum(math_ops.unsafe_div(grad, y), rx), sx),
+ math_ops.reduce_sum(math_ops.div_no_nan(grad, y), rx), sx),
array_ops.reshape(
math_ops.reduce_sum(
- grad * math_ops.unsafe_div(math_ops.unsafe_div(-x, y), y),
+ grad * math_ops.div_no_nan(math_ops.div_no_nan(-x, y), y),
ry), sy))
diff --git a/tensorflow/python/ops/math_grad_test.py b/tensorflow/python/ops/math_grad_test.py
index f9bb60e7fe..059c8ebd7e 100644
--- a/tensorflow/python/ops/math_grad_test.py
+++ b/tensorflow/python/ops/math_grad_test.py
@@ -231,11 +231,12 @@ class FloorModGradientTest(test.TestCase):
self.assertLess(error, 1e-4)
-class UnsafeDivGradientTest(test.TestCase):
+class DivNoNanGradientTest(test.TestCase):
def testBasicGradient(self):
- inputs = constant_op.constant(np.arange(-3, 3), dtype=dtypes.float32)
- outputs = math_ops.unsafe_div(inputs, 1 + math_ops.abs(inputs))
+ inputs = constant_op.constant(np.arange(-3, 3),
+ dtype=dtypes.float32)
+ outputs = math_ops.div_no_nan(inputs, 1 + math_ops.abs(inputs))
with self.test_session():
error = gradient_checker.compute_gradient_error(
inputs,
@@ -244,9 +245,11 @@ class UnsafeDivGradientTest(test.TestCase):
self.assertLess(error, 1e-4)
def testGradientWithDenominatorIsZero(self):
- x = constant_op.constant(np.arange(-3, 3), dtype=dtypes.float32)
- y = array_ops.zeros_like(x, dtype=dtypes.float32)
- outputs = math_ops.unsafe_div(x, y)
+ x = constant_op.constant(np.arange(-3, 3),
+ dtype=dtypes.float32)
+ y = array_ops.zeros_like(x,
+ dtype=dtypes.float32)
+ outputs = math_ops.div_no_nan(x, y)
with self.test_session():
dx, dy = gradients.gradients(outputs, [x, y])
self.assertAllClose(dx.eval(), np.zeros(x.shape.as_list()))
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 4033d5f079..67ea534639 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -1038,29 +1038,27 @@ def div(x, y, name=None):
return _div_python2(x, y, name)
-def unsafe_div(x, y, name=None):
+@tf_export("div_no_nan")
+def div_no_nan(x, y, name=None):
"""Computes an unsafe divide which returns 0 if the y is zero.
- Note that the function uses Python 3 division operator semantics.
-
Args:
- x: A `Tensor`. Must be one of the following types:
- `float32`, `float64`, `int16`, `int32`, `int64`.
+ x: A `Tensor`. Must be one of the following types: `float32`, `float64`.
y: A `Tensor` whose dtype is compatible with `x`.
name: A name for the operation (optional).
Returns:
The element-wise value of the x divided by y.
"""
- with ops.name_scope(name, "unsafe_div", [x, y]) as name:
+ with ops.name_scope(name, "div_no_nan", [x, y]) as name:
x = ops.convert_to_tensor(x, name="x")
y = ops.convert_to_tensor(y, name="y", dtype=x.dtype.base_dtype)
x_dtype = x.dtype.base_dtype
y_dtype = y.dtype.base_dtype
if x_dtype != y_dtype:
- raise TypeError(
- "x and y must have the same dtype, got %r != %r" % (x_dtype, y_dtype))
- return gen_math_ops.unsafe_div(x, y, name=name)
+ raise TypeError("x and y must have the same dtype, got %r != %r" %
+ (x_dtype, y_dtype))
+ return gen_math_ops.div_no_nan(x, y, name=name)
# TODO(aselle): This should be removed
diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py
index 5fe7bbca11..5ac7e133d9 100644
--- a/tensorflow/python/ops/math_ops_test.py
+++ b/tensorflow/python/ops/math_ops_test.py
@@ -473,18 +473,19 @@ class DivAndModTest(test_util.TensorFlowTestCase):
self.assertAllEqual(tf_result, expanded_nums)
-class UnsafeDivTest(test_util.TensorFlowTestCase):
+class DivNoNanTest(test_util.TensorFlowTestCase):
def testBasic(self):
- nums = np.arange(-10, 10, .25).reshape(80, 1)
- divs = np.arange(-3, 3, .25).reshape(1, 24)
+ for dtype in [np.float32, np.float64]:
+ nums = np.arange(-10, 10, .25, dtype=dtype).reshape(80, 1)
+ divs = np.arange(-3, 3, .25, dtype=dtype).reshape(1, 24)
- np_result = np.true_divide(nums, divs)
- np_result[:, divs[0] == 0] = 0
+ np_result = np.true_divide(nums, divs)
+ np_result[:, divs[0] == 0] = 0
- with self.test_session():
- tf_result = math_ops.unsafe_div(nums, divs).eval()
- self.assertAllEqual(tf_result, np_result)
+ with self.test_session():
+ tf_result = math_ops.div_no_nan(nums, divs).eval()
+ self.assertAllEqual(tf_result, np_result)
if __name__ == "__main__":
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
index 4de662fe33..8040eae01a 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
@@ -1005,6 +1005,10 @@ tf_module {
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "div_no_nan"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "divide"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
index 4de662fe33..8040eae01a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
@@ -1005,6 +1005,10 @@ tf_module {
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "div_no_nan"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "divide"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}