diff options
-rw-r--r-- | tensorflow/compiler/tests/binary_ops_test.py | 7 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/binary_ops.cc | 19 |
2 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 900e84ab58..e219cf3d88 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -560,6 +560,13 @@ class BinaryOpsTest(xla_test.XLATestCase): dtype(2), expected=np.array([[5], [2]], dtype=dtype)) + if dtype in [np.float32, np.float64]: + nums = np.arange(-10, 10, .25, dtype=dtype).reshape(80, 1) + divs = np.arange(-3, 3, .25, dtype=dtype).reshape(1, 24) + np_result = np.true_divide(nums, divs) + np_result[:, divs[0] == 0] = 0 + self._testBinary(gen_math_ops.div_no_nan, nums, divs, expected=np_result) + if dtype not in self.complex_types: # floordiv unsupported for complex. self._testBinary( gen_math_ops.floor_div, diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index 0d9a768a6f..66676452d0 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -55,6 +56,24 @@ XLA_MAKE_BINARY(Div, xla::Div(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(Atan2, xla::Atan2(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions)); +// Implementation of DivNoNan. Pseudo-code: +// if (y == 0) { +// return 0 +// } else { +// return x / y; +// } +static xla::XlaOp DivNoNanImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, + xla::XlaOp y, const BCast& broadcast_helper) { + std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); + auto zero = XlaHelpers::Zero(b, dtype); + auto y_equals_0 = xla::Eq(y, zero); + auto zeros = xla::ZerosLike(x); + auto result = xla::Select(y_equals_0, zeros, xla::Div(x, y)); + return result; +} +XLA_MAKE_BINARY(DivNoNan, + DivNoNanImpl(b, input_type(0), lhs, rhs, broadcast_helper)); + // Implementation of FloorDiv. Pseudo-code: // if ((x < 0) != (y < 0)) { // T abs_x = std::abs(x); |