Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/binary_ops.cc  37
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/cwise_ops.h      6
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/relu_op.cc      43
3 files changed, 38 insertions, 48 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
index 39c5567f80..1f9ac029c7 100644
--- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
@@ -127,32 +127,21 @@ XLA_MAKE_BINARY(GreaterEqual, b->Ge(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Less, b->Lt(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(LessEqual, b->Le(lhs, rhs, extend_dimensions));
-#undef XLA_MAKE_BINARY
+// Non-linear ops
+XLA_MAKE_BINARY(SigmoidGrad,
+ b->Mul(b->Mul(rhs, lhs),
+ b->Sub(XlaHelpers::One(b, input_type(0)), lhs)));
-#define XLA_MAKE_BINARY_MAP(Name, HLO) \
- class Name##Op : public XlaBinaryMapOp { \
- public: \
- explicit Name##Op(OpKernelConstruction* ctx) : XlaBinaryMapOp(ctx) {} \
- void BuildMapLambda(xla::ComputationBuilder* b, \
- const xla::ComputationDataHandle& lhs, \
- const xla::ComputationDataHandle& rhs) override { \
- HLO; \
- } \
- }; \
- REGISTER_XLA_OP(#Name, Name##Op)
+XLA_MAKE_BINARY(SoftplusGrad,
+ b->Div(lhs, b->Add(b->Exp(b->Neg(rhs)),
+ XlaHelpers::One(b, input_type(1)))));
+
+XLA_MAKE_BINARY(TanhGrad, b->Mul(rhs, b->Sub(XlaHelpers::One(b, input_type(0)),
+ b->Mul(lhs, lhs))));
-XLA_MAKE_BINARY_MAP(Pow, b->Pow(lhs, rhs));
-XLA_MAKE_BINARY_MAP(SigmoidGrad,
- b->Mul(b->Mul(rhs, lhs),
- b->Sub(XlaHelpers::One(b, input_type(0)), lhs)));
-XLA_MAKE_BINARY_MAP(SoftplusGrad,
- b->Div(lhs, b->Add(b->Exp(b->Neg(rhs)),
- XlaHelpers::One(b, input_type(1)))));
-XLA_MAKE_BINARY_MAP(TanhGrad,
- b->Mul(rhs, b->Sub(XlaHelpers::One(b, input_type(0)),
- b->Mul(lhs, lhs))));
-
-#undef XLA_MAKE_BINARY_MAP
+XLA_MAKE_BINARY(Pow, b->Pow(lhs, rhs, extend_dimensions));
+
+#undef XLA_MAKE_BINARY
} // namespace
} // namespace tensorflow
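
For reference, the three gradients moved onto the plain XLA_MAKE_BINARY macro above have simple scalar semantics, with lhs/rhs bound the way TensorFlow binds them for these ops (y = forward output, x = forward input, dy = incoming gradient). A minimal reference sketch; the *Ref helper names are illustrative and not part of this patch:

#include <cmath>

// SigmoidGrad(lhs=y, rhs=dy): dy * y * (1 - y)
static float SigmoidGradRef(float y, float dy) { return dy * y * (1.0f - y); }

// SoftplusGrad(lhs=dy, rhs=x): dy / (exp(-x) + 1), i.e. dy * sigmoid(x)
static float SoftplusGradRef(float dy, float x) {
  return dy / (std::exp(-x) + 1.0f);
}

// TanhGrad(lhs=y, rhs=dy): dy * (1 - y^2)
static float TanhGradRef(float y, float dy) { return dy * (1.0f - y * y); }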
diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h
index f0687c1d4b..ba38693325 100644
--- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h
+++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h
@@ -32,9 +32,7 @@ namespace tensorflow {
// description of the operation; and Computation adds the
// implementation of the operation to a xla::ComputationBuilder. For most
// arithmetic Ops XLA handles the broadcasting automatically given the input
-// tensors. Ops like ReluGrad that need to map a scalar function over the inputs
-// can use the XlaBinaryMapOp subclass below which handles manual
-// broadcasting of the inputs.
+// tensors.
class XlaBinaryOp : public XlaOpKernel {
public:
explicit XlaBinaryOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
@@ -83,6 +81,8 @@ class XlaBinaryOp : public XlaOpKernel {
// virtual methods to override: description is a textual description
// of the mapped function; and BuildMapLambda adds the
// implementation of the lambda to a xla::ComputationBuilder.
+// Operations may have better performance if implemented as graphs of
+// element-wise tensor operations.
class XlaBinaryMapOp : public XlaBinaryOp {
public:
explicit XlaBinaryMapOp(OpKernelConstruction* ctx) : XlaBinaryOp(ctx) {}
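
The performance note added in this hunk can be made concrete. XlaBinaryMapOp manually broadcasts its inputs and then maps a scalar lambda over them, while the style binary_ops.cc now uses expresses the same math as a graph of tensor ops that XLA can fuse and optimize across. Roughly, for TanhGrad (a sketch, not code from this patch, assuming the map op lowers through the builder's Map as its name suggests):

// Map style (removed for these ops): apply a scalar lambda element-wise.
//   auto result = b->Map({lhs, rhs}, tanh_grad_scalar_lambda);
//
// Graph style (kept): ordinary tensor ops, visible to XLA's optimizer.
//   auto result = b->Mul(rhs, b->Sub(one, b->Mul(lhs, lhs)));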
diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc
index 2dea4032c0..d1b857c22a 100644
--- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc
@@ -50,36 +50,37 @@ class Relu6Op : public XlaOpKernel {
}
};
-// A subclass of a XlaBinaryMapOp must build the lambda computation
-// that describes the (scalar,scalar)->scalar function to apply to
-// each element of the input. We have to use XlaBinaryMapOp instead of
-// XlaBinaryOp here because XLA Select does not do automatic
-// broadcasting.
-class ReluGradOp : public XlaBinaryMapOp {
+class ReluGradOp : public XlaOpKernel {
public:
- explicit ReluGradOp(OpKernelConstruction* ctx) : XlaBinaryMapOp(ctx) {}
+ explicit ReluGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
// Return the lhs (incoming gradient) if the rhs (input feature) > 0,
// otherwise return 0.
- void BuildMapLambda(xla::ComputationBuilder* b,
- const xla::ComputationDataHandle& gradient,
- const xla::ComputationDataHandle& feature) override {
- const auto zero = XlaHelpers::Zero(b, input_type(0));
- b->Select(b->Gt(feature, zero), gradient, zero);
+ void Compile(XlaOpKernelContext* ctx) {
+ xla::ComputationBuilder* b = ctx->builder();
+ const TensorShape shape = ctx->InputShape(0);
+ const auto zero =
+ b->Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes());
+ const auto pred = b->Gt(ctx->Input(1), zero);
+ ctx->SetOutput(0, b->Select(pred, ctx->Input(0), zero));
}
};
-class Relu6GradOp : public XlaBinaryMapOp {
+class Relu6GradOp : public XlaOpKernel {
public:
- explicit Relu6GradOp(OpKernelConstruction* ctx) : XlaBinaryMapOp(ctx) {}
+ explicit Relu6GradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
// Return the lhs (incoming gradient) if the rhs (input feature) is in
// (0, 6), otherwise return 0.
- void BuildMapLambda(xla::ComputationBuilder* b,
- const xla::ComputationDataHandle& gradient,
- const xla::ComputationDataHandle& feature) override {
- const auto zero = XlaHelpers::Zero(b, input_type(0));
- auto six = XlaHelpers::IntegerLiteral(b, input_type(0), 6);
- b->Select(b->LogicalAnd(b->Lt(feature, six), b->Gt(feature, zero)),
- gradient, zero);
+ void Compile(XlaOpKernelContext* ctx) {
+ xla::ComputationBuilder* b = ctx->builder();
+ const TensorShape shape = ctx->InputShape(0);
+ const auto zero =
+ b->Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes());
+ const auto six = b->Broadcast(
+ XlaHelpers::IntegerLiteral(b, input_type(0), 6), shape.dim_sizes());
+ auto out = b->Select(
+ b->LogicalAnd(b->Lt(ctx->Input(1), six), b->Gt(ctx->Input(1), zero)),
+ ctx->Input(0), zero);
+ ctx->SetOutput(0, out);
}
};
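
The rewritten kernels broadcast the scalar constants to the input shape by hand because, as the removed comment above notes, XLA's Select does not broadcast automatically, so pred, on_true, and on_false must already agree in shape. Element-wise, the two gradients reduce to the following, where g is the incoming gradient and x the forward input (the *Ref names are illustrative, not part of the patch):

// ReluGrad: pass the gradient through wherever the input was positive.
static float ReluGradRef(float g, float x) { return x > 0.0f ? g : 0.0f; }

// Relu6Grad: pass the gradient through only on the linear region (0, 6).
static float Relu6GradRef(float g, float x) {
  return (x > 0.0f && x < 6.0f) ? g : 0.0f;
}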