aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/binary_ops.cc')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/binary_ops.cc37
1 files changed, 13 insertions, 24 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
index 39c5567f80..1f9ac029c7 100644
--- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
@@ -127,32 +127,21 @@ XLA_MAKE_BINARY(GreaterEqual, b->Ge(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Less, b->Lt(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(LessEqual, b->Le(lhs, rhs, extend_dimensions));
-#undef XLA_MAKE_BINARY
+// Non-linear ops
+XLA_MAKE_BINARY(SigmoidGrad,
+ b->Mul(b->Mul(rhs, lhs),
+ b->Sub(XlaHelpers::One(b, input_type(0)), lhs)));
-#define XLA_MAKE_BINARY_MAP(Name, HLO) \
- class Name##Op : public XlaBinaryMapOp { \
- public: \
- explicit Name##Op(OpKernelConstruction* ctx) : XlaBinaryMapOp(ctx) {} \
- void BuildMapLambda(xla::ComputationBuilder* b, \
- const xla::ComputationDataHandle& lhs, \
- const xla::ComputationDataHandle& rhs) override { \
- HLO; \
- } \
- }; \
- REGISTER_XLA_OP(#Name, Name##Op)
+XLA_MAKE_BINARY(SoftplusGrad,
+ b->Div(lhs, b->Add(b->Exp(b->Neg(rhs)),
+ XlaHelpers::One(b, input_type(1)))));
+
+XLA_MAKE_BINARY(TanhGrad, b->Mul(rhs, b->Sub(XlaHelpers::One(b, input_type(0)),
+ b->Mul(lhs, lhs))));
-XLA_MAKE_BINARY_MAP(Pow, b->Pow(lhs, rhs));
-XLA_MAKE_BINARY_MAP(SigmoidGrad,
- b->Mul(b->Mul(rhs, lhs),
- b->Sub(XlaHelpers::One(b, input_type(0)), lhs)));
-XLA_MAKE_BINARY_MAP(SoftplusGrad,
- b->Div(lhs, b->Add(b->Exp(b->Neg(rhs)),
- XlaHelpers::One(b, input_type(1)))));
-XLA_MAKE_BINARY_MAP(TanhGrad,
- b->Mul(rhs, b->Sub(XlaHelpers::One(b, input_type(0)),
- b->Mul(lhs, lhs))));
-
-#undef XLA_MAKE_BINARY_MAP
+XLA_MAKE_BINARY(Pow, b->Pow(lhs, rhs, extend_dimensions));
+
+#undef XLA_MAKE_BINARY
} // namespace
} // namespace tensorflow