aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
blob: c2eeb0a1f92e3df91cd234599fe1cd5fc97c4625 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_H_

#include "llvm/IR/IRBuilder.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace cpu {

// Returns true if `hlo` is a candidate for being implemented as a call into
// Eigen (through the CPU runtime) instead of as emitted LLVM IR.
// NOTE(review): the exact criteria (shapes, element types, and what is read
// from `target_machine_features`) live in dot_op_emitter.cc — confirm there.
bool PotentiallyImplementedAsEigenDot(
    const HloInstruction& hlo,
    const TargetMachineFeatures& target_machine_features);

// Returns the index for an operand to `hlo` that should ideally be column
// major.  Returns nullopt if there is no such operand or if `hlo` is not a dot
// or a fusion containing a dot.
tensorflow::gtl::optional<int64> ProfitableToMakeDotOperandColumnMajor(
    const HloInstruction& hlo);

// Returns true to indicate that we can (and should) generate a tiled LLVM IR
// implementation for `dot`.
bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot);

// Helper class for emitting LLVM IR to perform the dot operation.
//
// The constructor is private, so from outside the class the only entry point
// is the static EmitDotOperation().  The emitter keeps references and raw
// pointers to its arguments (it neither copies nor deletes them), so every
// object passed in must outlive the emitter.
class DotOpEmitter {
 public:
  // Emit LLVM IR to perform the dot operation on lhs_array and rhs_array and
  // place the result in target_array. IR is emitted at current insert point of
  // the builder. Upon completion of the method, the insert point is set to the
  // end of all instructions emitted for this operation.
  //
  // If `addend_array` is not nullptr then it must be an array of the same
  // dimensions as the result, and the result is computed as `addend_array` +
  // dot(`lhs_array`, `rhs_array`).  A non-null `addend_array` is only supported
  // for Matrix-vector products.
  static Status EmitDotOperation(
      const HloInstruction& dot, const llvm_ir::IrArray& target_array,
      const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
      const llvm_ir::IrArray* addend_array,
      llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
      const HloModuleConfig& hlo_module_config,
      const TargetMachineFeatures& target_machine_features);

 private:
  // Stores references/pointers to all arguments; see the class comment for
  // the resulting lifetime requirement.
  DotOpEmitter(const HloInstruction& dot, const llvm_ir::IrArray& target_array,
               const llvm_ir::IrArray& lhs_array,
               const llvm_ir::IrArray& rhs_array,
               const llvm_ir::IrArray* addend_array,
               llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
               const HloModuleConfig& hlo_module_config,
               const TargetMachineFeatures& target_machine_features);

  // Emits the IR to perform the dot operation.
  Status Emit();

  // Emits instructions to perform a scalar dot product (a multiply of the
  // LHS and RHS) and store the results in the target.
  Status EmitScalarDot();

  // Emit an LLVM IR implementation of the dot operation if we can.  Returns
  // true if an LLVM IR implementation was emitted.
  bool EmitLlvmIrDotIfProfitable();

  // Emits a call to the CPU runtime to perform the matrix multiply.
  Status EmitCallToRuntime();

  // Emits a series of nested loops for iterating over an operand array in the
  // dot operation. Loops are constructed in major to minor dimension layout
  // order. No loop is emitted for the given reduction_dimension. The function
  // returns an IrArray index for the given operand_array containing the indvars
  // of the loops. All dimensions of the index are filled except for the
  // reduction dimension. name_suffix is the string to append to the names of
  // LLVM constructs (eg, basic blocks) constructed by this method.
  llvm_ir::IrArray::Index EmitOperandArrayLoopNest(
      llvm_ir::ForLoopNest* loop_nest, const llvm_ir::IrArray& operand_array,
      int64 reduction_dimension, tensorflow::StringPiece name_suffix);

  // Represents the dimensions of a matrix-matrix multiply operation.
  struct MatMultDims {
    // The number of rows in the LHS.
    int64 m;

    // The number of columns in the LHS, which must also be equal to the
    // number of rows in the RHS.
    int64 k;

    // The number of columns in the RHS.
    int64 n;

    // True if the LHS matrix is column major.
    bool lhs_column_major;

    // True if the LHS contraction dimension is not 1.
    bool lhs_non_canonical;

    // True if the RHS matrix is column major.
    bool rhs_column_major;

    // True if the RHS contraction dimension is not 0.
    bool rhs_non_canonical;

    // True if the result matrix is column major.
    bool target_column_major;
  };

  // Get the MatMultDims instance for the dot product this DotOpEmitter
  // represents.  Precondition: the dot is of rank 2 (and thus its operands are
  // of rank 2 as well).
  MatMultDims GetMatMultDims() const;

  // Emits the experimental GEBP-style matrix-matrix multiply for
  // `mat_mult_dims` when that path is enabled.
  // NOTE(review): presumably returns true iff IR was emitted, mirroring
  // EmitLlvmIrDotIfProfitable() — confirm in dot_op_emitter.cc.
  bool EmitExperimentalGebpDotIfEnabled(const MatMultDims& mat_mult_dims);

  // When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector
  // registers.  Defaults to 8 unless overridden through
  // options::LlvmIrGemvTilingFactor in the module config.
  int64 GetGemvTilingFactor() const {
    constexpr int64 kDefaultTilingFactor = 8;
    return options::LlvmIrGemvTilingFactor(hlo_module_config_)
        .value_or(kDefaultTilingFactor);
  }

  // Returns the tile size used by the tiled LLVM IR GEMM: the value from
  // options::LlvmIrGemmTileSize if set, otherwise (11, 9, 1).
  // NOTE(review): the meaning/order of the three components is defined by
  // the GEMM emitter in dot_op_emitter.cc — confirm there before retuning.
  std::tuple<int64, int64, int64> GetGemmTileSize() const {
    // Tuned for broadwell - Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz
    //
    // TODO(b/80093688): Tune for other architectures and centralize this
    // information in one place.
    const std::tuple<int64, int64, int64> kDefaultTileSize =
        std::tuple<int64, int64, int64>(11, 9, 1);
    return options::LlvmIrGemmTileSize(hlo_module_config_)
        .value_or(kDefaultTileSize);
  }

  // Returns true if we should use an experimental implementation of GEMM
  // (general matrix matrix multiplication) if possible.
  bool EnableExperimentalLlvmIrGemm() const {
    return options::EnableExperimentalLlvmIrGemm(hlo_module_config_);
  }

  // Returns true if we should call into multi-threaded Eigen routines, as
  // controlled by the xla_cpu_multi_thread_eigen debug option.  Only reads the
  // module config, so it is const like the other config accessors above.
  bool ShouldUseMultiThreadedEigen() const {
    return hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
  }

  // All of the following are non-owning: they refer to objects supplied by the
  // caller of EmitDotOperation(), which must outlive this emitter.
  const HloInstruction& dot_;
  const llvm_ir::IrArray& target_array_;
  const llvm_ir::IrArray& lhs_array_;
  const llvm_ir::IrArray& rhs_array_;
  // May be nullptr; see EmitDotOperation() for when it is supported.
  const llvm_ir::IrArray* addend_array_;
  llvm::Value* executable_run_options_value_;
  llvm::IRBuilder<>* b_;
  const HloModuleConfig& hlo_module_config_;
  const TargetMachineFeatures& target_machine_features_;
};

}  // namespace cpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_H_