// path: root/tensorflow/compiler/tf2xla/kernels/cwise_ops.h
// blob: 516ead4bfe89b4ddeee11dcc6410a838d04f28a9 (plain)
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA-specific base class for coefficient-wise binary Ops.

#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_CWISE_OPS_H_
#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_CWISE_OPS_H_

#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/bcast.h"

namespace tensorflow {

// Base class for XLA kernels that implement coefficient-wise binary
// operations. Every binary Op takes two inputs that can be broadcast to
// the same shape. Subclasses override the pure virtual Computation()
// method, which adds the implementation of the operation to a
// xla::XlaBuilder. For most arithmetic Ops XLA handles the broadcasting
// automatically given the input tensors.
class XlaBinaryOp : public XlaOpKernel {
 public:
  // Reports an InvalidArgument error on 'ctx' unless BaseType() of input 0
  // and BaseType() of input 1 are equal.
  explicit XlaBinaryOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    const DataType lhs_type = BaseType(input_type(0));
    const DataType rhs_type = BaseType(input_type(1));
    OP_REQUIRES(ctx, lhs_type == rhs_type,
                errors::InvalidArgument("Input types of binary op must match"));
  }
  ~XlaBinaryOp() override = default;

  // The (tensor,tensor)->tensor function applied to the two inputs;
  // subclasses must implement it. The desired computation should be added
  // to 'ctx->builder()'.
  //
  //   lhs, rhs             The function's inputs.
  //   lhs_shape, rhs_shape Their respective shapes.
  //   broadcast_helper     Metadata about the shapes of the inputs and the
  //                        dimensions that need to be broadcast; may be
  //                        useful for Ops that can't use standard XLA
  //                        automatic broadcasting.
  //   extend_dimensions    Non-empty if lhs and rhs have different ranks;
  //                        indicates which dimensions of the higher-rank
  //                        input should be matched when broadcasting the
  //                        lower-rank input. See the documentation on
  //                        broadcasting in the XLA documentation.
  virtual xla::XlaOp Computation(
      XlaOpKernelContext* ctx, const xla::XlaOp& lhs,
      const absl::Span<const int64>& lhs_shape, const xla::XlaOp& rhs,
      const absl::Span<const int64>& rhs_shape, const BCast& broadcast_helper,
      const std::vector<int64>& extend_dimensions) = 0;

  // Overrides the inherited Compile(); the definition lives outside this
  // header.
  void Compile(XlaOpKernelContext* ctx) override;

  // Helper that carries out the broadcasting described by
  // 'broadcast_helper' and returns the resulting (lhs, rhs) pair, whose
  // two members have the same shape.
  static std::pair<xla::XlaOp, xla::XlaOp> Broadcast(
      xla::XlaOp lhs, xla::XlaOp rhs, const BCast& broadcast_helper);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_KERNELS_CWISE_OPS_H_