aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/cpu/ir_emitter.h
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-07-31 17:21:26 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-31 17:26:19 -0700
commit182a00ee781017932443bacb475af7acc4a56d5a (patch)
treeec24c5cc7e424642bf6772b4d46f7a08ac31dc88 /tensorflow/compiler/xla/service/cpu/ir_emitter.h
parent64f191cdc0121bbcb322c3b11b160d638c2f4af9 (diff)
Automated rollback of commit fba2d773f45f10882aa475ac75cbf9884995d626
PiperOrigin-RevId: 206855848
Diffstat (limited to 'tensorflow/compiler/xla/service/cpu/ir_emitter.h')
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.h97
1 files changed, 48 insertions, 49 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 372017441f..03bbb2afb5 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -100,15 +100,14 @@ class IrEmitter : public DfsHloVisitorWithDefault {
llvm::IRBuilder<>* b() { return &b_; }
+ // Emits a call to `computation` with scalar arguments `arguments`.
+ StatusOr<llvm::Value*> EmitScalarCall(
+ PrimitiveType return_type, HloComputation* computation,
+ const std::vector<llvm::Value*>& arguments, tensorflow::StringPiece name);
+
// Emit an LLVM global variable for every constant buffer allocation.
Status EmitConstantGlobals();
- // Emit code to map one element according to `map_instr`.
- llvm::Value* EmitElementalMap(
- const HloMapInstruction& map_instr,
- tensorflow::gtl::ArraySlice<llvm::Value*> elemental_operands,
- tensorflow::StringPiece name);
-
protected:
//
// The following methods implement the DfsHloVisitor interface.
@@ -144,6 +143,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status HandleRecvDone(HloInstruction* recv_done) override;
Status HandlePad(HloInstruction* pad) override;
Status HandleTuple(HloInstruction* tuple) override;
+ Status HandleMap(HloInstruction* map) override;
Status HandleFusion(HloInstruction* fusion) override;
Status HandleCall(HloInstruction* call) override;
Status HandleCustomCall(HloInstruction* custom_call) override;
@@ -218,18 +218,9 @@ class IrEmitter : public DfsHloVisitorWithDefault {
// computation function being emitted by this emitter.
llvm::Value* GetTempBuffersArgument();
- // Helper for EmitTempBufferPointer.
- llvm::Value* EmitGlobalTempBufferPointer(const BufferAllocation::Slice& slice,
- const Shape& target_shape);
-
- // Helper for EmitTempBufferPointer.
- llvm::Value* EmitThreadLocalTempBufferPointer(
- const BufferAllocation::Slice& slice, const Shape& target_shape);
-
- // Emits code that computes the address of the given buffer allocation slice.
- //
- // TODO(sanjoy): This should be renamed to reflect that it no longer provides
- // access to just temporaries.
+ // Emits code that computes the address of the given temporary buffer to the
+ // function. target_shape is the shape of this temporary buffer.
+ // The returned Value's type is a pointer to element_type.
llvm::Value* EmitTempBufferPointer(const BufferAllocation::Slice& slice,
const Shape& target_shape);
@@ -241,27 +232,44 @@ class IrEmitter : public DfsHloVisitorWithDefault {
tensorflow::StringPiece
function_name_suffix); // Used for LLVM IR register names.
- // Emits a call to a thread local function (e.g. to the computation nested
- // within a reduce or a map). Thread local callees (by definition) only write
- // to and read from thread local allocations.
- //
- // `parameters` holds the *scalar values* that need to be passed to the
- // callee. The return value is the scalar returned by the callee.
- llvm::Value* EmitThreadLocalCall(
- const HloComputation& callee,
- tensorflow::gtl::ArraySlice<llvm::Value*> parameters,
+ // Methods that emit a function call.
+ // Parameters:
+ // function - The LLVM function to call.
+ // return_shape - The return shape of the HLO computation that was used to
+ // make the function. Not the same as the return type of the function
+ // in LLVM, since we use output parameters for the return type.
+ // element_count - number of elements to return (array form only).
+ // parameter_addresses - pointers to be passed to the function as
+ // parameters.
+ // name - used for LLVM IR register names.
+
+ // Emits a function call, returning a scalar, often an element of a larger
+ // array. Returns a Value for the scalar element returned by the function.
+ llvm::Value* EmitElementFunctionCall(
+ llvm::Function* function, const Shape& return_shape,
+ tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
tensorflow::StringPiece name);
- // Emits a call to a "global" function (e.g. to the computation nested within
- // a kWhile or a kCall). Buffer assignment unabiguously assignes buffers to
- // the parameters and return values for these computations so there is no need
- // to explicitly pass parameters or return results.
- void EmitGlobalCall(const HloComputation& callee,
- tensorflow::StringPiece name);
-
- // Returns the buffer to which a global call to `callee` would have written
- // its result.
- llvm::Value* GetBufferForGlobalCallReturnValue(const HloComputation& callee);
+ // Array function call emitter. Stores the function's result into a supplied
+ // buffer.
+ // Parameters:
+ // function - The LLVM function to call.
+ // parameter_addresses - pointers to be passed to the function as
+ // parameters.
+ // return_value - pointer to a buffer where the call result is stored.
+
+ void EmitArrayFunctionCallInto(
+ llvm::Function* function,
+ tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
+ llvm::Value* return_value_buffer, tensorflow::StringPiece name);
+
+ // Array function call emitter. Returns a Value for the function's return
+ // value buffer address. The return value buffer is alloca'ed by this
+ // function.
+ llvm::Value* EmitArrayFunctionCall(
+ llvm::Function* function, const Shape& return_shape, int64 element_count,
+ tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
+ tensorflow::StringPiece name);
// Verifies that the element types of all of the given operand instructions
// match and are of one of the given supported types.
@@ -400,10 +408,11 @@ class IrEmitter : public DfsHloVisitorWithDefault {
NameUniquer name_uniquer_;
// Map containing all previously emitted computations.
- std::map<const HloComputation*, llvm::Function*> emitted_functions_;
+ std::map<HloComputation*, llvm::Function*> emitted_functions_;
// Map containing all previously emitted thread-local temporary buffers.
- std::map<std::pair<llvm::Function*, BufferAllocation::Slice>, llvm::Value*>
+ std::map<std::pair<llvm::Function*, BufferAllocation::Slice>,
+ llvm::AllocaInst*>
thread_local_buffers_;
// The following fields track the IR emission state. According to LLVM memory
@@ -413,16 +422,6 @@ class IrEmitter : public DfsHloVisitorWithDefault {
std::unique_ptr<IrFunction> compute_function_;
llvm::IRBuilder<> b_;
- // The buffer allocation slice for the root of the computation being compiled.
- // Only relevant for thread local computations.
- BufferAllocation::Slice computation_root_allocation_;
-
- // Maps the buffer allocation slices for the parameters to the computation
- // being compiled to their parameter numbers. Only relevant for thread local
- // computations.
- tensorflow::gtl::FlatMap<BufferAllocation::Index, int64>
- computation_parameter_allocations_;
-
// Maps HLO instructions to their index into the profile counter array.
const std::unordered_map<const HloInstruction*, int64>
instruction_to_profile_idx_;