aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/tests/binary_ops_test.py7
-rw-r--r--tensorflow/compiler/tf2xla/kernels/binary_ops.cc19
2 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index 900e84ab58..e219cf3d88 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -560,6 +560,13 @@ class BinaryOpsTest(xla_test.XLATestCase):
dtype(2),
expected=np.array([[5], [2]], dtype=dtype))
+ if dtype in [np.float32, np.float64]:
+ nums = np.arange(-10, 10, .25, dtype=dtype).reshape(80, 1)
+ divs = np.arange(-3, 3, .25, dtype=dtype).reshape(1, 24)
+ np_result = np.true_divide(nums, divs)
+ np_result[:, divs[0] == 0] = 0
+ self._testBinary(gen_math_ops.div_no_nan, nums, divs, expected=np_result)
+
if dtype not in self.complex_types: # floordiv unsupported for complex.
self._testBinary(
gen_math_ops.floor_div,
diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
index 0d9a768a6f..66676452d0 100644
--- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -55,6 +56,24 @@ XLA_MAKE_BINARY(Div, xla::Div(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Atan2, xla::Atan2(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions));
+// Implementation of DivNoNan. Pseudo-code:
+// if (y == 0) {
+// return 0
+// } else {
+// return x / y;
+// }
+static xla::XlaOp DivNoNanImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
+ xla::XlaOp y, const BCast& broadcast_helper) {
+ std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+ auto zero = XlaHelpers::Zero(b, dtype);
+ auto y_equals_0 = xla::Eq(y, zero);
+ auto zeros = xla::ZerosLike(x);
+ auto result = xla::Select(y_equals_0, zeros, xla::Div(x, y));
+ return result;
+}
+XLA_MAKE_BINARY(DivNoNan,
+ DivNoNanImpl(b, input_type(0), lhs, rhs, broadcast_helper));
+
// Implementation of FloorDiv. Pseudo-code:
// if ((x < 0) != (y < 0)) {
// T abs_x = std::abs(x);