diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/cpu/ir_emitter.h')
-rw-r--r-- | tensorflow/compiler/xla/service/cpu/ir_emitter.h | 31 |
1 files changed, 22 insertions, 9 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 3c110a320f..4e928ffadc 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" @@ -97,7 +98,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { bool is_top_level_computation, std::vector<const HloInstruction*>* instruction_order); - llvm::IRBuilder<>* ir_builder() { return &ir_builder_; } + llvm::IRBuilder<>* b() { return &b_; } // Emits a call to `computation` with scalar arguments `arguments`. StatusOr<llvm::Value*> EmitScalarCall( @@ -117,6 +118,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleCopy(HloInstruction* copy) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleSelect(HloInstruction* select) override; + Status HandleTupleSelect(HloInstruction* tuple_select) override; Status HandleDot(HloInstruction* dot) override; Status HandleConvolution(HloInstruction* convolution) override; Status HandleFft(HloInstruction* fft) override; @@ -146,6 +148,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleConcatenate(HloInstruction* concatenate) override; Status HandleConditional(HloInstruction* conditional) override; Status HandleAfterAll(HloInstruction* gen_token) override; + Status HandleIota(HloInstruction* iota) override; Status FinishVisit(HloInstruction* root) override; Status Preprocess(HloInstruction* hlo) override; @@ -413,7 +416,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { // creates the encapsulated llvm::Function s.t. it is added to the llvm // module's function list). std::unique_ptr<IrFunction> compute_function_; - llvm::IRBuilder<> ir_builder_; + llvm::IRBuilder<> b_; // Maps HLO instructions to their index into the profile counter array. const std::unordered_map<const HloInstruction*, int64> @@ -449,23 +452,22 @@ class IrEmitter : public DfsHloVisitorWithDefault { : use_rdtscp_(use_rdtscp), prof_counters_(prof_counters) {} // Record the cycle counter before an HLO executes. - void RecordCycleStart(llvm::IRBuilder<>* ir_builder, HloInstruction* hlo); + void RecordCycleStart(llvm::IRBuilder<>* b, HloInstruction* hlo); // Record the number of cycles it took for an HLO to execute. - void RecordCycleDelta(llvm::IRBuilder<>* ir_builder, HloInstruction* hlo, + void RecordCycleDelta(llvm::IRBuilder<>* b, HloInstruction* hlo, llvm::Value* prof_counter); // Record the number of cycles it took for the entire computation to // execute. - void RecordCompleteComputation(llvm::IRBuilder<>* ir_builder, + void RecordCompleteComputation(llvm::IRBuilder<>* b, llvm::Value* prof_counter); // Convenience function to generate a call to an intrinsic which reads the // CPU cycle counter. - llvm::Value* ReadCycleCounter(llvm::IRBuilder<>* ir_builder); + llvm::Value* ReadCycleCounter(llvm::IRBuilder<>* b); // Store the cycle counter delta to the per-HLO profile counter. - void UpdateProfileCounter(llvm::IRBuilder<>* ir_builder, - llvm::Value* prof_counter, llvm::Value* cycle_end, - llvm::Value* cycle_start); + void UpdateProfileCounter(llvm::IRBuilder<>* b, llvm::Value* prof_counter, + llvm::Value* cycle_end, llvm::Value* cycle_start); private: // Should we use the x86-specific rdtscp or the generic readcyclecounter @@ -513,6 +515,17 @@ class IrEmitter : public DfsHloVisitorWithDefault { // Returns the number of bytes within the shape. int64 ByteSizeOf(const Shape& shape) const; + StatusOr<llvm::Value*> EmitTargetElementLoopBodyForMap( + HloMapInstruction* map, const llvm_ir::IrArray::Index& index); + StatusOr<llvm::Value*> EmitTargetElementLoopBodyForReduceWindow( + HloReduceWindowInstruction* reduce_window, + const llvm_ir::IrArray::Index& index); + StatusOr<llvm::Value*> EmitTargetElementLoopBodyForConvolution( + HloConvolutionInstruction* convolution, + const llvm_ir::IrArray::Index& index); + StatusOr<llvm::Value*> EmitTargetElementLoopBodyForReduce( + HloReduceInstruction* reduce, const llvm_ir::IrArray::Index& index); + enum class XfeedKind { kInfeed, kOutfeed, |