diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/binary_ops.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/binary_ops.cc | 37 |
1 files changed, 13 insertions, 24 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index 39c5567f80..1f9ac029c7 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -127,32 +127,21 @@ XLA_MAKE_BINARY(GreaterEqual, b->Ge(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(Less, b->Lt(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(LessEqual, b->Le(lhs, rhs, extend_dimensions)); -#undef XLA_MAKE_BINARY +// Non-linear ops +XLA_MAKE_BINARY(SigmoidGrad, + b->Mul(b->Mul(rhs, lhs), + b->Sub(XlaHelpers::One(b, input_type(0)), lhs))); -#define XLA_MAKE_BINARY_MAP(Name, HLO) \ - class Name##Op : public XlaBinaryMapOp { \ - public: \ - explicit Name##Op(OpKernelConstruction* ctx) : XlaBinaryMapOp(ctx) {} \ - void BuildMapLambda(xla::ComputationBuilder* b, \ - const xla::ComputationDataHandle& lhs, \ - const xla::ComputationDataHandle& rhs) override { \ - HLO; \ - } \ - }; \ - REGISTER_XLA_OP(#Name, Name##Op) +XLA_MAKE_BINARY(SoftplusGrad, + b->Div(lhs, b->Add(b->Exp(b->Neg(rhs)), + XlaHelpers::One(b, input_type(1))))); + +XLA_MAKE_BINARY(TanhGrad, b->Mul(rhs, b->Sub(XlaHelpers::One(b, input_type(0)), + b->Mul(lhs, lhs)))); -XLA_MAKE_BINARY_MAP(Pow, b->Pow(lhs, rhs)); -XLA_MAKE_BINARY_MAP(SigmoidGrad, - b->Mul(b->Mul(rhs, lhs), - b->Sub(XlaHelpers::One(b, input_type(0)), lhs))); -XLA_MAKE_BINARY_MAP(SoftplusGrad, - b->Div(lhs, b->Add(b->Exp(b->Neg(rhs)), - XlaHelpers::One(b, input_type(1))))); -XLA_MAKE_BINARY_MAP(TanhGrad, - b->Mul(rhs, b->Sub(XlaHelpers::One(b, input_type(0)), - b->Mul(lhs, lhs)))); - -#undef XLA_MAKE_BINARY_MAP +XLA_MAKE_BINARY(Pow, b->Pow(lhs, rhs, extend_dimensions)); + +#undef XLA_MAKE_BINARY } // namespace } // namespace tensorflow |