Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/relu_op.cc')
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/relu_op.cc  43
1 file changed, 22 insertions(+), 21 deletions(-)
diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc
index 2dea4032c0..d1b857c22a 100644
--- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc
@@ -50,36 +50,37 @@ class Relu6Op : public XlaOpKernel {
}
};
-// A subclass of a XlaBinaryMapOp must build the lambda computation
-// that describes the (scalar,scalar)->scalar function to apply to
-// each element of the input. We have to use XlaBinaryMapOp instead of
-// XlaBinaryOp here because XLA Select does not do automatic
-// broadcasting.
-class ReluGradOp : public XlaBinaryMapOp {
+class ReluGradOp : public XlaOpKernel {
public:
- explicit ReluGradOp(OpKernelConstruction* ctx) : XlaBinaryMapOp(ctx) {}
+ explicit ReluGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
// Return the lhs (incoming gradient) if the rhs (input feature) > 0,
// otherwise return 0.
- void BuildMapLambda(xla::ComputationBuilder* b,
- const xla::ComputationDataHandle& gradient,
- const xla::ComputationDataHandle& feature) override {
- const auto zero = XlaHelpers::Zero(b, input_type(0));
- b->Select(b->Gt(feature, zero), gradient, zero);
+  void Compile(XlaOpKernelContext* ctx) override {
+ xla::ComputationBuilder* b = ctx->builder();
+ const TensorShape shape = ctx->InputShape(0);
+ const auto zero =
+ b->Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes());
+ const auto pred = b->Gt(ctx->Input(1), zero);
+ ctx->SetOutput(0, b->Select(pred, ctx->Input(0), zero));
}
};
-class Relu6GradOp : public XlaBinaryMapOp {
+class Relu6GradOp : public XlaOpKernel {
public:
- explicit Relu6GradOp(OpKernelConstruction* ctx) : XlaBinaryMapOp(ctx) {}
+ explicit Relu6GradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
  // Return the lhs (incoming gradient) if the rhs (input feature) is
  // greater than 0 and less than 6, otherwise return 0.
- void BuildMapLambda(xla::ComputationBuilder* b,
- const xla::ComputationDataHandle& gradient,
- const xla::ComputationDataHandle& feature) override {
- const auto zero = XlaHelpers::Zero(b, input_type(0));
- auto six = XlaHelpers::IntegerLiteral(b, input_type(0), 6);
- b->Select(b->LogicalAnd(b->Lt(feature, six), b->Gt(feature, zero)),
- gradient, zero);
+  void Compile(XlaOpKernelContext* ctx) override {
+ xla::ComputationBuilder* b = ctx->builder();
+ const TensorShape shape = ctx->InputShape(0);
+ const auto zero =
+ b->Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes());
+ const auto six = b->Broadcast(
+ XlaHelpers::IntegerLiteral(b, input_type(0), 6), shape.dim_sizes());
+ auto out = b->Select(
+ b->LogicalAnd(b->Lt(ctx->Input(1), six), b->Gt(ctx->Input(1), zero)),
+ ctx->Input(0), zero);
+ ctx->SetOutput(0, out);
}
};
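
The patch above replaces the (scalar,scalar)->scalar map lambdas with whole-tensor graph building: each Compile() broadcasts the scalar constants to the input shape and then applies Select, because XLA's Select does not broadcast its operands automatically (the reason the deleted comment gave for using XlaBinaryMapOp in the first place). As a reference, the sketch below spells out the per-element semantics the two new Compile() bodies encode. It is plain host-side C++, not XLA, and the function names are hypothetical illustrations rather than anything in the patch.

#include <cstddef>

// ReluGrad: pass the incoming gradient through wherever the input
// feature is positive; emit 0 elsewhere. Mirrors
// Select(Gt(feature, 0), gradient, 0) in the patched ReluGradOp.
void ReluGradReference(const float* gradient, const float* feature,
                       float* out, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = feature[i] > 0.0f ? gradient[i] : 0.0f;
  }
}

// Relu6Grad: pass the gradient through only where 0 < feature < 6,
// matching Select(LogicalAnd(Lt(feature, 6), Gt(feature, 0)), ...)
// in the patched Relu6GradOp.
void Relu6GradReference(const float* gradient, const float* feature,
                        float* out, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = (feature[i] > 0.0f && feature[i] < 6.0f) ? gradient[i] : 0.0f;
  }
}

Note that in the XLA versions both zero and six are materialized with Broadcast to the full input shape before the comparison, since Select (unlike the removed per-element map) requires its predicate and operands to have matching shapes.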