aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-21 18:46:35 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-21 18:50:27 -0700
commita2fd40adcc714f18167acd9650e5442d4afd6a01 (patch)
tree5c02772a56e0fdba2ef40573c5919a33e2edee7a /tensorflow/compiler/tf2xla
parent1cb8940078f6be9313899734e1307a69fffc4b6f (diff)
[tf:xla]Implement DivNoNan.
PiperOrigin-RevId: 214076068
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/binary_ops.cc19
1 files changed, 19 insertions, 0 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
index 0d9a768a6f..66676452d0 100644
--- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -55,6 +56,24 @@ XLA_MAKE_BINARY(Div, xla::Div(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Atan2, xla::Atan2(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions));
+// Implementation of DivNoNan. Pseudo-code:
+// if (y == 0) {
+// return 0
+// } else {
+// return x / y;
+// }
+static xla::XlaOp DivNoNanImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
+ xla::XlaOp y, const BCast& broadcast_helper) {
+ std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+ auto zero = XlaHelpers::Zero(b, dtype);
+ auto y_equals_0 = xla::Eq(y, zero);
+ auto zeros = xla::ZerosLike(x);
+ auto result = xla::Select(y_equals_0, zeros, xla::Div(x, y));
+ return result;
+}
+XLA_MAKE_BINARY(DivNoNan,
+ DivNoNanImpl(b, input_type(0), lhs, rhs, broadcast_helper));
+
// Implementation of FloorDiv. Pseudo-code:
// if ((x < 0) != (y < 0)) {
// T abs_x = std::abs(x);