diff options
author | Sanjoy Das <sanjoy@google.com> | 2018-07-31 17:21:26 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-31 17:26:19 -0700 |
commit | 182a00ee781017932443bacb475af7acc4a56d5a (patch) | |
tree | ec24c5cc7e424642bf6772b4d46f7a08ac31dc88 /tensorflow/compiler/xla/service/cpu/ir_emitter.h | |
parent | 64f191cdc0121bbcb322c3b11b160d638c2f4af9 (diff) |
Automated rollback of commit fba2d773f45f10882aa475ac75cbf9884995d626
PiperOrigin-RevId: 206855848
Diffstat (limited to 'tensorflow/compiler/xla/service/cpu/ir_emitter.h')
-rw-r--r-- | tensorflow/compiler/xla/service/cpu/ir_emitter.h | 97 |
1 files changed, 48 insertions, 49 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 372017441f..03bbb2afb5 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -100,15 +100,14 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm::IRBuilder<>* b() { return &b_; } + // Emits a call to `computation` with scalar arguments `arguments`. + StatusOr<llvm::Value*> EmitScalarCall( + PrimitiveType return_type, HloComputation* computation, + const std::vector<llvm::Value*>& arguments, tensorflow::StringPiece name); + // Emit an LLVM global variable for every constant buffer allocation. Status EmitConstantGlobals(); - // Emit code to map one element according to `map_instr`. - llvm::Value* EmitElementalMap( - const HloMapInstruction& map_instr, - tensorflow::gtl::ArraySlice<llvm::Value*> elemental_operands, - tensorflow::StringPiece name); - protected: // // The following methods implement the DfsHloVisitor interface. @@ -144,6 +143,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleRecvDone(HloInstruction* recv_done) override; Status HandlePad(HloInstruction* pad) override; Status HandleTuple(HloInstruction* tuple) override; + Status HandleMap(HloInstruction* map) override; Status HandleFusion(HloInstruction* fusion) override; Status HandleCall(HloInstruction* call) override; Status HandleCustomCall(HloInstruction* custom_call) override; @@ -218,18 +218,9 @@ class IrEmitter : public DfsHloVisitorWithDefault { // computation function being emitted by this emitter. llvm::Value* GetTempBuffersArgument(); - // Helper for EmitTempBufferPointer. - llvm::Value* EmitGlobalTempBufferPointer(const BufferAllocation::Slice& slice, - const Shape& target_shape); - - // Helper for EmitTempBufferPointer. - llvm::Value* EmitThreadLocalTempBufferPointer( - const BufferAllocation::Slice& slice, const Shape& target_shape); - - // Emits code that computes the address of the given buffer allocation slice. - // - // TODO(sanjoy): This should be renamed to reflect that it no longer provides - // access to just temporaries. + // Emits code that computes the address of the given temporary buffer to the + // function. target_shape is the shape of this temporary buffer. + // The returned Value's type is a pointer to element_type. llvm::Value* EmitTempBufferPointer(const BufferAllocation::Slice& slice, const Shape& target_shape); @@ -241,27 +232,44 @@ class IrEmitter : public DfsHloVisitorWithDefault { tensorflow::StringPiece function_name_suffix); // Used for LLVM IR register names. - // Emits a call to a thread local function (e.g. to the computation nested - // within a reduce or a map). Thread local callees (by definition) only write - // to and read from thread local allocations. - // - // `parameters` holds the *scalar values* that need to be passed to the - // callee. The return value is the scalar returned by the callee. - llvm::Value* EmitThreadLocalCall( - const HloComputation& callee, - tensorflow::gtl::ArraySlice<llvm::Value*> parameters, + // Methods that emit a function call. + // Parameters: + // function - The LLVM function to call. + // return_shape - The return shape of the HLO computation that was used to + // make the function. Not the same as the return type of the function + // in LLVM, since we use output parameters for the return type. + // element_count - number of elements to return (array form only). + // parameter_addresses - pointers to be passed to the function as + // parameters. + // name - used for LLVM IR register names. + + // Emits a function call, returning a scalar, often an element of a larger + // array. Returns a Value for the scalar element returned by the function. + llvm::Value* EmitElementFunctionCall( + llvm::Function* function, const Shape& return_shape, + tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses, tensorflow::StringPiece name); - // Emits a call to a "global" function (e.g. to the computation nested within - // a kWhile or a kCall). Buffer assignment unabiguously assignes buffers to - // the parameters and return values for these computations so there is no need - // to explicitly pass parameters or return results. - void EmitGlobalCall(const HloComputation& callee, - tensorflow::StringPiece name); - - // Returns the buffer to which a global call to `callee` would have written - // its result. - llvm::Value* GetBufferForGlobalCallReturnValue(const HloComputation& callee); + // Array function call emitter. Stores the function's result into a supplied + // buffer. + // Parameters: + // function - The LLVM function to call. + // parameter_addresses - pointers to be passed to the function as + // parameters. + // return_value - pointer to a buffer where the call result is stored. + + void EmitArrayFunctionCallInto( + llvm::Function* function, + tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses, + llvm::Value* return_value_buffer, tensorflow::StringPiece name); + + // Array function call emitter. Returns a Value for the function's return + // value buffer address. The return value buffer is alloca'ed by this + // function. + llvm::Value* EmitArrayFunctionCall( + llvm::Function* function, const Shape& return_shape, int64 element_count, + tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses, + tensorflow::StringPiece name); // Verifies that the element types of all of the given operand instructions // match and are of one of the given supported types. @@ -400,10 +408,11 @@ class IrEmitter : public DfsHloVisitorWithDefault { NameUniquer name_uniquer_; // Map containing all previously emitted computations. - std::map<const HloComputation*, llvm::Function*> emitted_functions_; + std::map<HloComputation*, llvm::Function*> emitted_functions_; // Map containing all previously emitted thread-local temporary buffers. - std::map<std::pair<llvm::Function*, BufferAllocation::Slice>, llvm::Value*> + std::map<std::pair<llvm::Function*, BufferAllocation::Slice>, + llvm::AllocaInst*> thread_local_buffers_; // The following fields track the IR emission state. According to LLVM memory @@ -413,16 +422,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { std::unique_ptr<IrFunction> compute_function_; llvm::IRBuilder<> b_; - // The buffer allocation slice for the root of the computation being compiled. - // Only relevant for thread local computations. - BufferAllocation::Slice computation_root_allocation_; - - // Maps the buffer allocation slices for the parameters to the computation - // being compiled to their parameter numbers. Only relevant for thread local - // computations. - tensorflow::gtl::FlatMap<BufferAllocation::Index, int64> - computation_parameter_allocations_; - // Maps HLO instructions to their index into the profile counter array. const std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx_; |