aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/cpu/ir_emitter.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/cpu/ir_emitter.h')
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.h18
1 files changed, 18 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 4533253680..2fea6846d8 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -191,6 +191,9 @@ class IrEmitter : public DfsHloVisitorWithDefault {
tensorflow::gtl::ArraySlice<HloInstruction*> operands,
tensorflow::StringPiece custom_call_target) override;
Status HandleWhile(HloInstruction* xla_while) override;
+ Status HandleConcatenate(
+ HloInstruction* concatenate,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;
Status FinishVisit(HloInstruction* root) override;
Status Preprocess(HloInstruction* hlo) override;
@@ -407,6 +410,21 @@ class IrEmitter : public DfsHloVisitorWithDefault {
HloInstruction* arg, tensorflow::gtl::ArraySlice<int64> dimensions,
unsigned element_alignment);
+ // Tries to emit a fast concatenate operation using memcpy. Returns true if
+ // successful, and false on failure. On failure, sets "failure_reason" to a
+ // string describing why it could not emit a fast concatenate.
+ StatusOr<bool> EmitFastConcatenate(
+ HloInstruction* concatenate,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ string* failure_reason);
+
+ // Emits LLVM IR to transfer "element_count" elements of type "primitive_type"
+ // from the address "source" to the address "target".
+ void EmitTransferElements(llvm::Value* target, llvm::Value* source,
+ int64 element_count, PrimitiveType primitive_type,
+ const llvm_ir::IrArray& target_array,
+ const llvm_ir::IrArray& source_array);
+
// Name of the computation entry function. This function serves as the
// top-level "main" of the computation and will be invoked by the JIT.
string entry_function_name_;