diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/relu_op.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/relu_op.cc | 43 |
1 files changed, 22 insertions, 21 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc index 2dea4032c0..d1b857c22a 100644 --- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc @@ -50,36 +50,37 @@ class Relu6Op : public XlaOpKernel { } }; -// A subclass of a XlaBinaryMapOp must build the lambda computation -// that describes the (scalar,scalar)->scalar function to apply to -// each element of the input. We have to use XlaBinaryMapOp instead of -// XlaBinaryOp here because XLA Select does not do automatic -// broadcasting. -class ReluGradOp : public XlaBinaryMapOp { +class ReluGradOp : public XlaOpKernel { public: - explicit ReluGradOp(OpKernelConstruction* ctx) : XlaBinaryMapOp(ctx) {} + explicit ReluGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} // Return the lhs (incoming gradient) if the rhs (input feature) > 0, // otherwise return 0. - void BuildMapLambda(xla::ComputationBuilder* b, - const xla::ComputationDataHandle& gradient, - const xla::ComputationDataHandle& feature) override { - const auto zero = XlaHelpers::Zero(b, input_type(0)); - b->Select(b->Gt(feature, zero), gradient, zero); + void Compile(XlaOpKernelContext* ctx) { + xla::ComputationBuilder* b = ctx->builder(); + const TensorShape shape = ctx->InputShape(0); + const auto zero = + b->Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes()); + const auto pred = b->Gt(ctx->Input(1), zero); + ctx->SetOutput(0, b->Select(pred, ctx->Input(0), zero)); } }; -class Relu6GradOp : public XlaBinaryMapOp { +class Relu6GradOp : public XlaOpKernel { public: - explicit Relu6GradOp(OpKernelConstruction* ctx) : XlaBinaryMapOp(ctx) {} + explicit Relu6GradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} // Return the lhs (incoming gradient) if the rhs (input feature) > 0, // otherwise return 0. - void BuildMapLambda(xla::ComputationBuilder* b, - const xla::ComputationDataHandle& gradient, - const xla::ComputationDataHandle& feature) override { - const auto zero = XlaHelpers::Zero(b, input_type(0)); - auto six = XlaHelpers::IntegerLiteral(b, input_type(0), 6); - b->Select(b->LogicalAnd(b->Lt(feature, six), b->Gt(feature, zero)), - gradient, zero); + void Compile(XlaOpKernelContext* ctx) { + xla::ComputationBuilder* b = ctx->builder(); + const TensorShape shape = ctx->InputShape(0); + const auto zero = + b->Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes()); + const auto six = b->Broadcast( + XlaHelpers::IntegerLiteral(b, input_type(0), 6), shape.dim_sizes()); + auto out = b->Select( + b->LogicalAnd(b->Lt(ctx->Input(1), six), b->Gt(ctx->Input(1), zero)), + ctx->Input(0), zero); + ctx->SetOutput(0, out); } }; |