aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/cwise_ops.h
blob: f0687c1d4b5071bb96ba94d2042b05a3447bb108 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA-specific base classes for Unary and Binary Ops.

#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_CWISE_OPS_H_
#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_CWISE_OPS_H_

#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/bcast.h"

namespace tensorflow {

// Coefficient-wise binary operations. Each binary Op expects two
// inputs that can be broadcast to the same shape. Subclasses override the
// pure virtual Computation method, which adds the implementation of the
// operation to an xla::ComputationBuilder. For most arithmetic Ops XLA
// handles the broadcasting automatically given the input tensors. Ops like
// ReluGrad, which need to map a scalar function over the inputs, can use the
// XlaBinaryMapOp subclass below, which handles manual broadcasting of the
// inputs.
class XlaBinaryOp : public XlaOpKernel {
 public:
  // Fails construction with InvalidArgument unless the BaseType() of the two
  // declared input types is identical.
  explicit XlaBinaryOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    const DataType lhs = BaseType(input_type(0));
    const DataType rhs = BaseType(input_type(1));
    OP_REQUIRES(ctx, lhs == rhs,
                errors::InvalidArgument("Input types of binary op must match"));
  }
  ~XlaBinaryOp() override = default;

  // Implements the (tensor,tensor)->tensor function that should be applied to
  // the inputs; the returned handle is the op's result.
  //
  // The desired computation should be added to 'ctx->builder()'.
  //  - '(lhs,rhs)' are the function's inputs.
  //  - '(lhs_shape,rhs_shape)' are their respective shapes.
  //  - 'broadcast_helper' contains metadata about the shapes of the inputs and
  //    the dimensions that need to be broadcast. It may be useful for Ops that
  //    can't use standard XLA automatic broadcasting.
  //  - 'extend_dimensions' is non-empty if lhs and rhs have different ranks.
  //    It indicates which dimensions of the higher-rank input should be
  //    matched when broadcasting the lower-rank input.
  //
  // See the broadcasting semantics page of the XLA documentation.
  virtual xla::ComputationDataHandle Computation(
      XlaOpKernelContext* ctx, const xla::ComputationDataHandle& lhs,
      const gtl::ArraySlice<int64>& lhs_shape,
      const xla::ComputationDataHandle& rhs,
      const gtl::ArraySlice<int64>& rhs_shape, const BCast& broadcast_helper,
      const std::vector<int64>& extend_dimensions) = 0;

  // XlaOpKernel entry point; defined out of line. Presumably gathers the
  // inputs and broadcast metadata and dispatches to Computation() — confirm
  // against the .cc definition.
  void Compile(XlaOpKernelContext* ctx) override;

  // Helper function that performs the broadcasting described by
  // 'broadcast_helper'. It yields arguments 'lhs' and 'rhs' that have the same
  // shape, returned as a (lhs, rhs) pair.
  static std::pair<xla::ComputationDataHandle, xla::ComputationDataHandle>
  Broadcast(xla::ComputationBuilder* builder,
            const xla::ComputationDataHandle& lhs,
            const xla::ComputationDataHandle& rhs,
            const BCast& broadcast_helper);
};

// Coefficient-wise binary operations that map a scalar function. Each
// BinaryMap Op expects two inputs that can be broadcast to the same
// shape and maps a (scalar,scalar)->scalar function across the zipped
// elements of its (broadcast) inputs. Subclasses override the pure virtual
// BuildMapLambda method, which adds the implementation of the lambda to an
// xla::ComputationBuilder.
class XlaBinaryMapOp : public XlaBinaryOp {
 public:
  // All input-type validation is inherited from XlaBinaryOp.
  explicit XlaBinaryMapOp(OpKernelConstruction* ctx) : XlaBinaryOp(ctx) {}
  ~XlaBinaryMapOp() override = default;

  // Subclass hook: emit into 'builder' the body of the (scalar,scalar)->scalar
  // function that is applied to each pair of corresponding input elements.
  // 'scalar_lhs' and 'scalar_rhs' are that function's two scalar parameters.
  virtual void BuildMapLambda(
      xla::ComputationBuilder* builder,
      const xla::ComputationDataHandle& scalar_lhs,
      const xla::ComputationDataHandle& scalar_rhs) = 0;

  // Final-stage XlaBinaryOp::Computation, defined out of line. It maps the
  // scalar function produced by BuildMapLambda across the zipped elements of
  // the broadcast inputs. Subclasses customize BuildMapLambda rather than
  // this method.
  xla::ComputationDataHandle Computation(
      XlaOpKernelContext* ctx,
      const xla::ComputationDataHandle& lhs,
      const gtl::ArraySlice<int64>& lhs_shape,
      const xla::ComputationDataHandle& rhs,
      const gtl::ArraySlice<int64>& rhs_shape,
      const BCast& broadcast_helper,
      const std::vector<int64>& extend_dimensions) override;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_KERNELS_CWISE_OPS_H_