diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/binary_ops.cc | 37 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/cwise_ops.h | 6 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/relu_op.cc | 43 |
3 files changed, 38 insertions, 48 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index 39c5567f80..1f9ac029c7 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -127,32 +127,21 @@ XLA_MAKE_BINARY(GreaterEqual, b->Ge(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(Less, b->Lt(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(LessEqual, b->Le(lhs, rhs, extend_dimensions)); -#undef XLA_MAKE_BINARY +// Non-linear ops +XLA_MAKE_BINARY(SigmoidGrad, + b->Mul(b->Mul(rhs, lhs), + b->Sub(XlaHelpers::One(b, input_type(0)), lhs))); -#define XLA_MAKE_BINARY_MAP(Name, HLO) \ - class Name##Op : public XlaBinaryMapOp { \ - public: \ - explicit Name##Op(OpKernelConstruction* ctx) : XlaBinaryMapOp(ctx) {} \ - void BuildMapLambda(xla::ComputationBuilder* b, \ - const xla::ComputationDataHandle& lhs, \ - const xla::ComputationDataHandle& rhs) override { \ - HLO; \ - } \ - }; \ - REGISTER_XLA_OP(#Name, Name##Op) +XLA_MAKE_BINARY(SoftplusGrad, + b->Div(lhs, b->Add(b->Exp(b->Neg(rhs)), + XlaHelpers::One(b, input_type(1))))); + +XLA_MAKE_BINARY(TanhGrad, b->Mul(rhs, b->Sub(XlaHelpers::One(b, input_type(0)), + b->Mul(lhs, lhs)))); -XLA_MAKE_BINARY_MAP(Pow, b->Pow(lhs, rhs)); -XLA_MAKE_BINARY_MAP(SigmoidGrad, - b->Mul(b->Mul(rhs, lhs), - b->Sub(XlaHelpers::One(b, input_type(0)), lhs))); -XLA_MAKE_BINARY_MAP(SoftplusGrad, - b->Div(lhs, b->Add(b->Exp(b->Neg(rhs)), - XlaHelpers::One(b, input_type(1))))); -XLA_MAKE_BINARY_MAP(TanhGrad, - b->Mul(rhs, b->Sub(XlaHelpers::One(b, input_type(0)), - b->Mul(lhs, lhs)))); - -#undef XLA_MAKE_BINARY_MAP +XLA_MAKE_BINARY(Pow, b->Pow(lhs, rhs, extend_dimensions)); + +#undef XLA_MAKE_BINARY } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h index f0687c1d4b..ba38693325 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h @@ -32,9 +32,7 @@ namespace tensorflow { // description of the operation; and Computation adds the // implementation of the operation to a xla::ComputationBuilder. For most // arithmetic Ops XLA handles the broadcasting automatically given the input -// tensors. Ops like ReluGrad that need to map a scalar function over the inputs -// can use the XlaBinaryMapOp subclass below which handles manual -// broadcasting of the inputs. +// tensors. class XlaBinaryOp : public XlaOpKernel { public: explicit XlaBinaryOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { @@ -83,6 +81,8 @@ class XlaBinaryOp : public XlaOpKernel { // virtual methods to override: description is a textual description // of the mapped function; and BuildMapLambda adds the // implementation of the lambda to a xla::ComputationBuilder. +// Operations may have better performance if implemented as graphs of +// element-wise tensor operations. class XlaBinaryMapOp : public XlaBinaryOp { public: explicit XlaBinaryMapOp(OpKernelConstruction* ctx) : XlaBinaryOp(ctx) {} diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc index 2dea4032c0..d1b857c22a 100644 --- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc @@ -50,36 +50,37 @@ class Relu6Op : public XlaOpKernel { } }; -// A subclass of a XlaBinaryMapOp must build the lambda computation -// that describes the (scalar,scalar)->scalar function to apply to -// each element of the input. We have to use XlaBinaryMapOp instead of -// XlaBinaryOp here because XLA Select does not do automatic -// broadcasting. -class ReluGradOp : public XlaBinaryMapOp { +class ReluGradOp : public XlaOpKernel { public: - explicit ReluGradOp(OpKernelConstruction* ctx) : XlaBinaryMapOp(ctx) {} + explicit ReluGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} // Return the lhs (incoming gradient) if the rhs (input feature) > 0, // otherwise return 0. - void BuildMapLambda(xla::ComputationBuilder* b, - const xla::ComputationDataHandle& gradient, - const xla::ComputationDataHandle& feature) override { - const auto zero = XlaHelpers::Zero(b, input_type(0)); - b->Select(b->Gt(feature, zero), gradient, zero); + void Compile(XlaOpKernelContext* ctx) { + xla::ComputationBuilder* b = ctx->builder(); + const TensorShape shape = ctx->InputShape(0); + const auto zero = + b->Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes()); + const auto pred = b->Gt(ctx->Input(1), zero); + ctx->SetOutput(0, b->Select(pred, ctx->Input(0), zero)); } }; -class Relu6GradOp : public XlaBinaryMapOp { +class Relu6GradOp : public XlaOpKernel { public: - explicit Relu6GradOp(OpKernelConstruction* ctx) : XlaBinaryMapOp(ctx) {} + explicit Relu6GradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} // Return the lhs (incoming gradient) if the rhs (input feature) > 0, // otherwise return 0. - void BuildMapLambda(xla::ComputationBuilder* b, - const xla::ComputationDataHandle& gradient, - const xla::ComputationDataHandle& feature) override { - const auto zero = XlaHelpers::Zero(b, input_type(0)); - auto six = XlaHelpers::IntegerLiteral(b, input_type(0), 6); - b->Select(b->LogicalAnd(b->Lt(feature, six), b->Gt(feature, zero)), - gradient, zero); + void Compile(XlaOpKernelContext* ctx) { + xla::ComputationBuilder* b = ctx->builder(); + const TensorShape shape = ctx->InputShape(0); + const auto zero = + b->Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes()); + const auto six = b->Broadcast( + XlaHelpers::IntegerLiteral(b, input_type(0), 6), shape.dim_sizes()); + auto out = b->Select( + b->LogicalAnd(b->Lt(ctx->Input(1), six), b->Gt(ctx->Input(1), zero)), + ctx->Input(0), zero); + ctx->SetOutput(0, out); } }; |