diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/cpu/ir_emitter.h')
-rw-r--r-- | tensorflow/compiler/xla/service/cpu/ir_emitter.h | 18 |
1 file changed, 18 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 4533253680..2fea6846d8 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -191,6 +191,9 @@ class IrEmitter : public DfsHloVisitorWithDefault { tensorflow::gtl::ArraySlice<HloInstruction*> operands, tensorflow::StringPiece custom_call_target) override; Status HandleWhile(HloInstruction* xla_while) override; + Status HandleConcatenate( + HloInstruction* concatenate, + tensorflow::gtl::ArraySlice<HloInstruction*> operands) override; Status FinishVisit(HloInstruction* root) override; Status Preprocess(HloInstruction* hlo) override; @@ -407,6 +410,21 @@ class IrEmitter : public DfsHloVisitorWithDefault { HloInstruction* arg, tensorflow::gtl::ArraySlice<int64> dimensions, unsigned element_alignment); + // Tries to emit a fast concatenate operation using memcpy. Returns true if + // successful, and false on failure. On failure, sets "failure_reason" to a + // string describing why it could not emit a fast concatenate. + StatusOr<bool> EmitFastConcatenate( + HloInstruction* concatenate, + tensorflow::gtl::ArraySlice<HloInstruction*> operands, + string* failure_reason); + + // Emits LLVM IR to transfer "element_count" elements of type "primitive_type" + // from the address "source" to the address "target". + void EmitTransferElements(llvm::Value* target, llvm::Value* source, + int64 element_count, PrimitiveType primitive_type, + const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& source_array); + // Name of the computation entry function. This function serves as the // top-level "main" of the computation and will be invoked by the JIT. string entry_function_name_; |