diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-21 18:46:35 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-21 18:50:27 -0700 |
commit | a2fd40adcc714f18167acd9650e5442d4afd6a01 (patch) | |
tree | 5c02772a56e0fdba2ef40573c5919a33e2edee7a /tensorflow/compiler/tf2xla | |
parent | 1cb8940078f6be9313899734e1307a69fffc4b6f (diff) |
[tf:xla]Implement DivNoNan.
PiperOrigin-RevId: 214076068
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/binary_ops.cc | 19 |
1 files changed, 19 insertions, 0 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index 0d9a768a6f..66676452d0 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -55,6 +56,24 @@ XLA_MAKE_BINARY(Div, xla::Div(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(Atan2, xla::Atan2(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions)); +// Implementation of DivNoNan. Pseudo-code: +// if (y == 0) { +// return 0 +// } else { +// return x / y; +// } +static xla::XlaOp DivNoNanImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, + xla::XlaOp y, const BCast& broadcast_helper) { + std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); + auto zero = XlaHelpers::Zero(b, dtype); + auto y_equals_0 = xla::Eq(y, zero); + auto zeros = xla::ZerosLike(x); + auto result = xla::Select(y_equals_0, zeros, xla::Div(x, y)); + return result; +} +XLA_MAKE_BINARY(DivNoNan, + DivNoNanImpl(b, input_type(0), lhs, rhs, broadcast_helper)); + // Implementation of FloorDiv. Pseudo-code: // if ((x < 0) != (y < 0)) { // T abs_x = std::abs(x); |