diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h | 16 |
1 files changed, 7 insertions, 9 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 084462330e..bd5db72051 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -193,14 +193,12 @@ class IrEmitterUnnested : public IrEmitter { LaunchDimensions EmitHlo021Tile(HloInstruction* hlo, absl::Span<const int64> reduced_output_dims, absl::Span<const int64> tiled_param_ids); - // Generates the IrArray for each output of hlo and returns the number of - // outputs. - int ConstructIrArrayForOutputs(const HloInstruction& hlo, - std::vector<llvm_ir::IrArray>* output_arrays); - // Generates the IrArray for each input of hlo and returns the number of - // inputs. - int ConstructIrArrayForInputs(const HloInstruction& hlo, - std::vector<llvm_ir::IrArray>* param_arrays); + + // Generates the IrArray for each input of an hlo and returns a vector that + // constains such IrArrays. + std::vector<llvm_ir::IrArray> ConstructIrArrayForInputs( + const HloInstruction& hlo); + // For each output of the `hlo` instruction, constructs the reduced shape for // the output with the given `reduced_output_dims` and cast the original // output IrArray element in `output_arrays` to the reduced shape. Returns @@ -244,7 +242,7 @@ class IrEmitterUnnested : public IrEmitter { // Returns a thunk that, given a reduce or select-and-scatter op, initializes // its memory to the appropriate initial value. StatusOr<std::unique_ptr<Thunk>> BuildInitializerThunk( - const HloInstruction* hlo, const ShapeIndex& index = {}); + HloInstruction* hlo, const ShapeIndex& index = {}); // Returns a thunk that calls host-to-device cuMemcpy to implement `inst`. std::unique_ptr<Thunk> BuildHostToDeviceCopyThunk(const HloInstruction* inst); |