aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-26 17:42:47 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-26 17:46:19 -0700
commit7b88cabfec45c9e04ab3d9cf1c2411c6dce4c694 (patch)
tree9bdc598fa33808d8689299438a50ad7445ebdec5 /tensorflow/compiler
parentbfda65cc70526c919c57ef8321dd282e463ed8a3 (diff)
Add xlogy and xdivy op.
PiperOrigin-RevId: 214700693
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/binary_ops.cc18
1 files changed, 18 insertions, 0 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
index 66676452d0..a988d3c33e 100644
--- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
@@ -103,6 +103,24 @@ static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
XLA_MAKE_BINARY(FloorDiv,
FloorDivImpl(b, input_type(0), lhs, rhs, broadcast_helper));
+static xla::XlaOp XlogyImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
+ xla::XlaOp y, const BCast& broadcast_helper) {
+ std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+ auto zero = XlaHelpers::Zero(b, dtype);
+ auto is_zero = xla::Eq(x, zero);
+ return xla::Select(is_zero, zero, xla::Mul(x, xla::Log(y)));
+}
+XLA_MAKE_BINARY(Xlogy, XlogyImpl(b, input_type(0), lhs, rhs, broadcast_helper));
+
+static xla::XlaOp XdivyImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
+ xla::XlaOp y, const BCast& broadcast_helper) {
+ std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+ auto zero = XlaHelpers::Zero(b, dtype);
+ auto is_zero = xla::Eq(x, zero);
+ return xla::Select(is_zero, zero, xla::Div(x, y));
+}
+XLA_MAKE_BINARY(Xdivy, XdivyImpl(b, input_type(0), lhs, rhs, broadcast_helper));
+
// Implementation of FloorMod. Pseudo-code:
// T trunc_mod = std::fmod(x, y);
// return (x < T(0)) == (y < T(0)) ? trunc_mod : std::fmod(trunc_mod + y, y);