aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/cpu/ir_emitter.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/cpu/ir_emitter.h')
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.h31
1 files changed, 22 insertions, 9 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 3c110a320f..4e928ffadc 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
@@ -97,7 +98,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
bool is_top_level_computation,
std::vector<const HloInstruction*>* instruction_order);
- llvm::IRBuilder<>* ir_builder() { return &ir_builder_; }
+ llvm::IRBuilder<>* b() { return &b_; }
// Emits a call to `computation` with scalar arguments `arguments`.
StatusOr<llvm::Value*> EmitScalarCall(
@@ -117,6 +118,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status HandleCopy(HloInstruction* copy) override;
Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
Status HandleSelect(HloInstruction* select) override;
+ Status HandleTupleSelect(HloInstruction* tuple_select) override;
Status HandleDot(HloInstruction* dot) override;
Status HandleConvolution(HloInstruction* convolution) override;
Status HandleFft(HloInstruction* fft) override;
@@ -146,6 +148,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status HandleConcatenate(HloInstruction* concatenate) override;
Status HandleConditional(HloInstruction* conditional) override;
Status HandleAfterAll(HloInstruction* gen_token) override;
+ Status HandleIota(HloInstruction* iota) override;
Status FinishVisit(HloInstruction* root) override;
Status Preprocess(HloInstruction* hlo) override;
@@ -413,7 +416,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
// creates the encapsulated llvm::Function s.t. it is added to the llvm
// module's function list).
std::unique_ptr<IrFunction> compute_function_;
- llvm::IRBuilder<> ir_builder_;
+ llvm::IRBuilder<> b_;
// Maps HLO instructions to their index into the profile counter array.
const std::unordered_map<const HloInstruction*, int64>
@@ -449,23 +452,22 @@ class IrEmitter : public DfsHloVisitorWithDefault {
: use_rdtscp_(use_rdtscp), prof_counters_(prof_counters) {}
// Record the cycle counter before an HLO executes.
- void RecordCycleStart(llvm::IRBuilder<>* ir_builder, HloInstruction* hlo);
+ void RecordCycleStart(llvm::IRBuilder<>* b, HloInstruction* hlo);
// Record the number of cycles it took for an HLO to execute.
- void RecordCycleDelta(llvm::IRBuilder<>* ir_builder, HloInstruction* hlo,
+ void RecordCycleDelta(llvm::IRBuilder<>* b, HloInstruction* hlo,
llvm::Value* prof_counter);
// Record the number of cycles it took for the entire computation to
// execute.
- void RecordCompleteComputation(llvm::IRBuilder<>* ir_builder,
+ void RecordCompleteComputation(llvm::IRBuilder<>* b,
llvm::Value* prof_counter);
// Convenience function to generate a call to an intrinsic which reads the
// CPU cycle counter.
- llvm::Value* ReadCycleCounter(llvm::IRBuilder<>* ir_builder);
+ llvm::Value* ReadCycleCounter(llvm::IRBuilder<>* b);
// Store the cycle counter delta to the per-HLO profile counter.
- void UpdateProfileCounter(llvm::IRBuilder<>* ir_builder,
- llvm::Value* prof_counter, llvm::Value* cycle_end,
- llvm::Value* cycle_start);
+ void UpdateProfileCounter(llvm::IRBuilder<>* b, llvm::Value* prof_counter,
+ llvm::Value* cycle_end, llvm::Value* cycle_start);
private:
// Should we use the x86-specific rdtscp or the generic readcyclecounter
@@ -513,6 +515,17 @@ class IrEmitter : public DfsHloVisitorWithDefault {
// Returns the number of bytes within the shape.
int64 ByteSizeOf(const Shape& shape) const;
+ StatusOr<llvm::Value*> EmitTargetElementLoopBodyForMap(
+ HloMapInstruction* map, const llvm_ir::IrArray::Index& index);
+ StatusOr<llvm::Value*> EmitTargetElementLoopBodyForReduceWindow(
+ HloReduceWindowInstruction* reduce_window,
+ const llvm_ir::IrArray::Index& index);
+ StatusOr<llvm::Value*> EmitTargetElementLoopBodyForConvolution(
+ HloConvolutionInstruction* convolution,
+ const llvm_ir::IrArray::Index& index);
+ StatusOr<llvm::Value*> EmitTargetElementLoopBodyForReduce(
+ HloReduceInstruction* reduce, const llvm_ir::IrArray::Index& index);
+
enum class XfeedKind {
kInfeed,
kOutfeed,