/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// HLO instructions are in DAG form and represent the computations that the
// user has built up via the XLA service interface. They are ultimately lowered
// in a platform-aware way by traversing the HLO DAG and emitting a lowered
// form; e.g. see DfsHloVisitor.

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_

#include <functional>
#include <iosfwd>
#include <list>
#include <memory>
#include <string>
#include <tuple>
#include <unordered_set>
#include <vector>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

class HloComputation;
class HloModule;

// HLO instructions are the IR used by the high-level compiler.
class HloInstruction {
 public:
  enum class FusionKind {
    kLoop,                // Fused into a loop.
    kInput,               // Op's input is fused into the op itself.
    kOutput,              // Op's output is fused into the op itself.
                          // REQUIRES: At least one operand buffer must be
                          // able to alias the output buffer.
    kTransposeDot,        // Fused into a dot with transposed operands.
    kConvBackwardFilter,  // Fused into a backward filter convolution.
    kConvBackwardInput,   // Fused into a backward input convolution.
  };

  ~HloInstruction();

  // Creates a parameter-retrieving instruction.
  static std::unique_ptr<HloInstruction> CreateParameter(
      int64 parameter_number, const Shape& shape, const string& name);

  // Creates a literal constant instruction.
  static std::unique_ptr<HloInstruction> CreateConstant(
      std::unique_ptr<Literal> literal);

  // Creates a get tuple element instruction.
  static std::unique_ptr<HloInstruction> CreateGetTupleElement(
      const Shape& shape, HloInstruction* operand, int64 index);

  // Creates a trace instruction that logs the input operand in the
  // computation.
  static std::unique_ptr<HloInstruction> CreateTrace(const string& tag,
                                                     HloInstruction* operand);

  // Creates a random number generation instruction that fills a shape with
  // random numbers from a given distribution.
  static std::unique_ptr<HloInstruction> CreateRng(
      const Shape& shape, RandomDistribution distribution,
      tensorflow::gtl::ArraySlice<HloInstruction*> parameters);
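  // Usage sketch (illustrative only, not part of this API): building a scalar
  // constant and a parameter with the factories above. Assumes a Literal
  // factory such as Literal::CreateR0 and ShapeUtil::MakeShape; exact helper
  // names may vary by version.
  //
  //   std::unique_ptr<HloInstruction> forty_two =
  //       HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f));
  //   std::unique_ptr<HloInstruction> x = HloInstruction::CreateParameter(
  //       /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x");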
  // Creates a unary instruction (one operand).
  // Precondition: opcode must be a legitimate unary operation.
  static std::unique_ptr<HloInstruction> CreateUnary(const Shape& shape,
                                                     HloOpcode opcode,
                                                     HloInstruction* operand);

  // Creates a binary instruction (two operands).
  // Precondition: opcode must be a legitimate binary operation.
  static std::unique_ptr<HloInstruction> CreateBinary(const Shape& shape,
                                                      HloOpcode opcode,
                                                      HloInstruction* lhs,
                                                      HloInstruction* rhs);

  // Creates a ternary instruction (three operands).
  // Precondition: opcode must be a legitimate ternary operation.
  static std::unique_ptr<HloInstruction> CreateTernary(const Shape& shape,
                                                       HloOpcode opcode,
                                                       HloInstruction* lhs,
                                                       HloInstruction* rhs,
                                                       HloInstruction* ehs);

  // Creates a variadic instruction (variable number of operands).
  // Precondition: opcode must be a legitimate variadic operation.
  static std::unique_ptr<HloInstruction> CreateVariadic(
      const Shape& shape, HloOpcode opcode,
      tensorflow::gtl::ArraySlice<HloInstruction*> operands);

  // Creates a map instruction, where the computation (given by the handle) is
  // applied element-wise to every element in operands (across the operands,
  // at a given index) with the same `static_operands`.
  static std::unique_ptr<HloInstruction> CreateMap(
      const Shape& shape,
      tensorflow::gtl::ArraySlice<HloInstruction*> operands,
      HloComputation* map_computation,
      tensorflow::gtl::ArraySlice<HloInstruction*> static_operands = {});

  // Creates a convolution op, where rhs is the convolutional filter
  // and window describes how the filter is applied to lhs.
  static std::unique_ptr<HloInstruction> CreateConvolve(
      const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
      const Window& window,
      const ConvolutionDimensionNumbers& dimension_numbers);

  // Creates a reduce-precision op, where operand is the data to reduce in
  // precision, and exponent_bits and mantissa_bits describe the precision to
  // reduce it to.
  static std::unique_ptr<HloInstruction> CreateReducePrecision(
      const Shape& shape, HloInstruction* operand, const int exponent_bits,
      const int mantissa_bits);

  // Creates a cross replica sum op.
  static std::unique_ptr<HloInstruction> CreateCrossReplicaSum(
      const Shape& shape, HloInstruction* operand);

  // Creates a conversion instruction, where operand is the data to convert
  // and shape is the target shape for the conversion.
  static std::unique_ptr<HloInstruction> CreateConvert(
      const Shape& shape, HloInstruction* operand);

  // Creates an infeed instruction, which reads data of the given shape from
  // the Infeed interface of the device.
  static std::unique_ptr<HloInstruction> CreateInfeed(const Shape& shape,
                                                      const string& config);

  // Creates an outfeed instruction, which outputs data.
  static std::unique_ptr<HloInstruction> CreateOutfeed(
      const Shape& shape, HloInstruction* operand,
      tensorflow::StringPiece outfeed_config);

  // Creates a send instruction with the given channel id, which sends the
  // operand data to a unique receive instruction in another computation that
  // has the same channel id.
  static std::unique_ptr<HloInstruction> CreateSend(HloInstruction* operand,
                                                    int64 channel_id);

  // Creates a receive instruction with the given channel id, which receives
  // data of the given shape from a unique send instruction in another
  // computation that has the same channel id.
  static std::unique_ptr<HloInstruction> CreateRecv(const Shape& shape,
                                                    int64 channel_id);

  // Creates a slice instruction, where the operand is sliced by the given
  // start/limit indices.
  static std::unique_ptr<HloInstruction> CreateSlice(
      const Shape& shape, HloInstruction* operand,
      tensorflow::gtl::ArraySlice<int64> start_indices,
      tensorflow::gtl::ArraySlice<int64> limit_indices);
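  // Usage sketch (illustrative only): composing the n-ary factories above.
  // Assumes `x` and `y` are existing HloInstruction* with matching F32
  // shapes.
  //
  //   std::unique_ptr<HloInstruction> sum = HloInstruction::CreateBinary(
  //       x->shape(), HloOpcode::kAdd, x, y);
  //   std::unique_ptr<HloInstruction> neg = HloInstruction::CreateUnary(
  //       sum->shape(), HloOpcode::kNegate, sum.get());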
  // Creates a slice instruction, where the first operand is sliced by
  // start indices specified in the second operand, and by size specified in
  // 'slice_sizes'.
  static std::unique_ptr<HloInstruction> CreateDynamicSlice(
      const Shape& shape, HloInstruction* operand,
      HloInstruction* start_indices,
      tensorflow::gtl::ArraySlice<int64> slice_sizes);

  // Creates a dynamic update slice instruction, which updates a slice
  // of 'operand' with 'update' and 'start_indices'.
  static std::unique_ptr<HloInstruction> CreateDynamicUpdateSlice(
      const Shape& shape, HloInstruction* operand, HloInstruction* update,
      HloInstruction* start_indices);

  // Creates a concatenate instruction, where the operands are concatenated on
  // the provided dimension.
  static std::unique_ptr<HloInstruction> CreateConcatenate(
      const Shape& shape,
      tensorflow::gtl::ArraySlice<HloInstruction*> operands, int64 dimension);

  // Creates a reduce instruction, where the computation (given by the handle)
  // is applied successively to every element in operand. That is, if f is the
  // function to apply (which either takes 2 [accumulator, value] or 3
  // [accumulator, index, value] arguments) and init is an initial value
  // appropriate for the reduction operator (for example, 0 for addition),
  // then this operation will compute:
  //   f(f(init, [index0], value0), [index1], value1), ...)
  static std::unique_ptr<HloInstruction> CreateReduce(
      const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
      tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
      HloComputation* reduce_computation);

  // Creates a reduce-window instruction, where the computation (given
  // by the handle) is applied window-wise at each valid window
  // position in the operand.
  static std::unique_ptr<HloInstruction> CreateReduceWindow(
      const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
      const Window& window, HloComputation* reduce_computation);

  // Creates a batch-norm-training instruction.
  static std::unique_ptr<HloInstruction> CreateBatchNormTraining(
      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
      HloInstruction* offset, float epsilon, int64 feature_index);

  // Creates a scatter computation that scatters the `source` array to the
  // selected indices of each window.
  static std::unique_ptr<HloInstruction> CreateSelectAndScatter(
      const Shape& shape, HloInstruction* operand, HloComputation* select,
      const Window& window, HloInstruction* source,
      HloInstruction* init_value, HloComputation* scatter);

  // Creates a broadcast instruction.
  static std::unique_ptr<HloInstruction> CreateBroadcast(
      const Shape& shape, HloInstruction* operand,
      tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);

  // Creates a pad instruction, where the operand is padded on the edges and
  // between the elements with the given padding value.
  static std::unique_ptr<HloInstruction> CreatePad(
      const Shape& shape, HloInstruction* operand,
      HloInstruction* padding_value, const PaddingConfig& padding_config);

  // Creates a reshape instruction, where the operand is flattened in
  // row-major order and then reshaped to the given result shape.
  static std::unique_ptr<HloInstruction> CreateReshape(
      const Shape& shape, HloInstruction* operand);

  // Creates a transpose instruction which permutes the operand dimensions.
  static std::unique_ptr<HloInstruction> CreateTranspose(
      const Shape& shape, HloInstruction* operand,
      tensorflow::gtl::ArraySlice<int64> dimensions);

  // Creates a while instruction, given a condition computation, a body
  // computation, and the initial value for the input of the computations. For
  // example, shape: S32, condition: i -> i < 1000, body: i -> i * 2, init: 1
  // corresponds to the C code below.
  //   int32 i = 1; int32 result = while(i < 1000) { i = i * 2 }
  static std::unique_ptr<HloInstruction> CreateWhile(const Shape& shape,
                                                     HloComputation* condition,
                                                     HloComputation* body,
                                                     HloInstruction* init);
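  // Usage sketch (illustrative only): summing a rank-1 F32 operand to a
  // scalar with CreateReduce. Assumes `input` and `zero` (a scalar constant)
  // are existing HloInstruction* and `add_computation` is an existing
  // (F32, F32) -> F32 HloComputation built elsewhere.
  //
  //   std::unique_ptr<HloInstruction> sum = HloInstruction::CreateReduce(
  //       ShapeUtil::MakeShape(F32, {}), input, zero,
  //       /*dimensions_to_reduce=*/{0}, add_computation);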
  // Creates a fusion instruction. A fusion instruction contains one or more
  // fused instructions forming an expression with a single root
  // "fused_root". Additional instructions can be added to the fusion
  // instruction with the method FuseInstruction.
  static std::unique_ptr<HloInstruction> CreateFusion(
      const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root);

  // Creates a fusion instruction that represents backward convolution. This
  // is similar to CreateFusion, but with extra arguments indicating the
  // window and dimension mapping of the backward convolution.
  static std::unique_ptr<HloInstruction> CreateFusionForBackwardConvolution(
      const Shape& shape, FusionKind fusion_kind, const Window& window,
      const ConvolutionDimensionNumbers& conv_dnums,
      HloInstruction* fused_root);

  // Creates a call instruction that applies the given computation on the
  // given operands. "shape" is the resultant shape.
  static std::unique_ptr<HloInstruction> CreateCall(
      const Shape& shape,
      tensorflow::gtl::ArraySlice<HloInstruction*> operands,
      HloComputation* computation);

  // Creates a custom call instruction that applies the given custom call
  // target to the given operands. "shape" is the resultant shape.
  static std::unique_ptr<HloInstruction> CreateCustomCall(
      const Shape& shape,
      tensorflow::gtl::ArraySlice<HloInstruction*> operands,
      tensorflow::StringPiece custom_call_target);

  // Creates a tuple instruction with the given elements. This is a
  // convenience wrapper around CreateVariadic.
  static std::unique_ptr<HloInstruction> CreateTuple(
      tensorflow::gtl::ArraySlice<HloInstruction*> elements);

  // Creates a reverse instruction, which reverses the order of the elements
  // in the specified dimensions.
  static std::unique_ptr<HloInstruction> CreateReverse(
      const Shape& shape, HloInstruction* operand,
      tensorflow::gtl::ArraySlice<int64> dimensions);

  // Returns the opcode for this instruction.
  HloOpcode opcode() const { return opcode_; }

  // Returns the result shape of this instruction.
  const Shape& shape() const;

  // Returns the (mutable) result shape of this instruction.
  Shape* mutable_shape() { return &shape_; }

  // Returns the ith operand to this instruction.
  const HloInstruction* operand(int64 i) const;

  // Returns the ith operand to this instruction.
  HloInstruction* mutable_operand(int64 i);

  // Returns the number of operands to this instruction.
  int64 operand_count() const { return operands_.size(); }

  // Returns the vector of operands of this instruction.
  const std::vector<HloInstruction*>& operands() const { return operands_; }

  // Returns the index of 'target' in the operands sequence.
  // Precondition: target must be an operand (or a fatal error will occur).
  int64 operand_index(const HloInstruction* target) const;

  // Returns the number of users of this instruction.
  int64 user_count() const { return users_.size(); }

  // Returns the users of this instruction.
  const std::vector<HloInstruction*>& users() const { return users_; }

  // Returns true if this instruction is a user of 'instruction'.
  bool IsUserOf(const HloInstruction* instruction) const {
    return ContainsKey(instruction->user_set_, this);
  }
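  // Usage sketch (illustrative only): walking the def-use structure exposed
  // by the accessors above; every operand of `instr` has `instr` as a user.
  //
  //   for (HloInstruction* operand : instr->operands()) {
  //     CHECK(instr->IsUserOf(operand));
  //     CHECK_GE(operand->user_count(), 1);
  //   }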
  // Adds a control dependency from this instruction to the given
  // instruction. This instruction becomes a control predecessor of
  // 'instruction', and 'instruction' becomes a control successor of this
  // instruction. Returns an error status if either of the given instructions
  // does not belong to the same computation.
  //
  // This is used to enforce an additional ordering requirement that is not
  // captured by normal data dependencies, such as ordering among Send or Recv
  // operations to avoid deadlock.
  Status AddControlDependencyTo(HloInstruction* instruction);

  // Removes a previously added control dependency from this instruction to
  // 'instruction'.
  Status RemoveControlDependencyTo(HloInstruction* instruction);

  // Returns the set of control predecessors (successors) of this
  // instruction. Control predecessors (successors) must execute before
  // (after) the current instruction.
  const std::vector<HloInstruction*>& control_predecessors() const {
    return control_predecessors_;
  }
  const std::vector<HloInstruction*>& control_successors() const {
    return control_successors_;
  }

  // Returns true if "other" performs the same computation as this
  // instruction. The layout of the instructions' output array is not
  // considered.
  bool Identical(
      const HloInstruction& other,
      std::function<bool(const HloInstruction*, const HloInstruction*)>
          eq_operands = std::equal_to<const HloInstruction*>(),
      std::function<bool(const HloComputation*, const HloComputation*)>
          eq_computations = std::equal_to<const HloComputation*>()) const;

  // Returns whether the instruction has a constant operand.
  bool HasConstantOperand() const;

  // Returns whether this instruction does a rank-2 transposition.
  bool IsRank2Transpose() const;

  // Replaces the use of this instruction in "user" with "new_producer". Note
  // that there might be multiple uses of this instruction in "user"; all will
  // be replaced.
  Status ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer);

  // Replaces the specified operand with new_operand.
  Status ReplaceOperandWith(int64 operand_no, HloInstruction* new_operand);

  // Replaces all uses of this instruction with the new producer. If
  // new_producer is a user of this instruction then new_producer remains a
  // use of this instruction to avoid introducing cycles into the graph.
  Status ReplaceAllUsesWith(HloInstruction* new_producer);

  // Detaches an instruction from its operands. That is, removes the
  // instruction from each operand's user set. This should only be called
  // prior to deallocating the instruction.
  void DetachFromOperands();

  // Performs a postorder DFS visit using this node as the root. If
  // call_finish_visit is true, then DfsHloVisitor::FinishVisit is called when
  // complete. If ignore_control_predecessors is true, instructions only
  // reachable via control dependencies will not be visited, and the postorder
  // will not take control dependencies into account. It is as if the control
  // dependencies didn't exist in the graph at all.
  Status Accept(DfsHloVisitor* visitor, bool call_finish_visit = true,
                bool ignore_control_predecessors = false);

  // Same as Accept() above, but the order of operand and control predecessor
  // visitation is determined by the given operand order; if compare(A, B) ==
  // true, A is visited before B.
  using CompareFunction =
      std::function<bool(const HloInstruction*, const HloInstruction*)>;
  Status AcceptWithOperandOrder(DfsHloVisitor* visitor,
                                const CompareFunction& operand_order,
                                bool call_finish_visit = true);

  // Performs a postorder DFS visit using this node as the root. Calls the
  // given visitor function at each instruction.
  Status Accept(const FunctionVisitor::VisitorFunction& visitor_func);
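  // Usage sketch (illustrative only): a postorder walk using the
  // function-visitor overload above. Assumes the TF_RETURN_IF_ERROR macro
  // from tensorflow/core/lib/core/errors.h is available at the call site.
  //
  //   TF_RETURN_IF_ERROR(root->Accept([](HloInstruction* hlo) {
  //     VLOG(2) << "visiting: " << hlo->ToShortString();
  //     return Status::OK();
  //   }));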
  // Visits all instructions rooted at this instruction using the given
  // visitor in the given order. 'order' must contain at least the set of
  // instructions rooted at this node (i.e., those accessible from a DFS
  // traversal from this instruction). Instructions contained in 'order' which
  // are not in the set of instructions rooted at this node are ignored.
  // 'order' must also be a valid topological sort of these instructions (defs
  // appear before uses) though need not be a DFS post-order.
  Status AcceptOrdered(DfsHloVisitor* visitor,
                       const std::vector<const HloInstruction*>& order);

  // Visits this instruction and only this instruction with the given visitor.
  Status Visit(DfsHloVisitor* visitor);

  // Returns the literal associated with this instruction.
  //
  // Note: only constant and parameter opcodes have an associated literal.
  const Literal& literal() const;

  // Returns the parameter number associated with this instruction.
  //
  // Note: only parameter opcodes have an associated parameter number.
  int64 parameter_number() const {
    CHECK_EQ(HloOpcode::kParameter, opcode_);
    return parameter_number_;
  }

  const string& parameter_name() const {
    CHECK_EQ(HloOpcode::kParameter, opcode_);
    return parameter_name_;
  }

  void set_parameter_name(const string& str) {
    CHECK_EQ(HloOpcode::kParameter, opcode_);
    parameter_name_ = str;
  }

  // Returns the dimension sizes or numbers associated with this instruction.
  //
  // Precondition: opcode() is one of: concatenate, reduce, broadcast,
  // reshape, and reverse.
  const std::vector<int64>& dimensions() const;
  int64 dimensions(int64 index) const;

  // Accessor for the dimension in which a concatenate HLO should occur.
  // Precondition: opcode() == HloOpcode::kConcatenate
  int64 concatenate_dimension() const;

  // Returns the tuple index associated with this instruction.
  //
  // Precondition: opcode() == HloOpcode::kGetTupleElement
  int64 tuple_index() const;

  // Gets/sets the to_apply HloComputation for Call, Map, Reduce, etc.
  // The setter should only be called by HloModule or HloComputation methods.
  //
  // Precondition: The instruction has a valid to_apply_ field.
  HloComputation* to_apply() const;
  void set_to_apply(HloComputation* to_apply);

  // Returns the custom_call_target for CustomCall.
  // Precondition: opcode() == HloOpcode::kCustomCall
  const string& custom_call_target() const;

  // Returns the config for the Outfeed instruction.
  // Precondition: opcode() == HloOpcode::kOutfeed
  const string& outfeed_config() const;

  // Returns the shape for the Outfeed instruction.
  // Precondition: opcode() == HloOpcode::kOutfeed
  const Shape& outfeed_shape() const;

  // Gets/sets the while_condition or while_body HloComputation for While. The
  // setters should only be called by HloModule or HloComputation methods.
  //
  // Precondition: The instruction is a While instruction.
  HloComputation* while_condition() const;
  HloComputation* while_body() const;
  void set_while_condition(HloComputation* while_condition);
  void set_while_body(HloComputation* while_body);

  // Gets/sets the select or scatter HloComputation for SelectAndScatter. The
  // setters should only be called by HloModule or HloComputation methods.
  //
  // Precondition: opcode() == HloOpcode::kSelectAndScatter.
  HloComputation* select() const;
  HloComputation* scatter() const;
  void set_select(HloComputation* select);
  void set_scatter(HloComputation* scatter);

  // Returns a string for the signature of this instruction if considered as a
  // function, e.g. the signature of an F32 add is (F32, F32) -> F32.
  string SignatureString() const;

  // Returns a debugging string that represents this instruction.
  string ToString(bool compact_operands = false,
                  bool include_metadata = true) const;

  string ToStringNoMetadata() const { return ToString(false, false); }

  // As ToString, but returns a shorter string.
  string ToShortString() const;

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const;
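  // Usage sketch (illustrative only): pairing CreateGetTupleElement with the
  // tuple_index() accessor above. Assumes `tuple` is an existing
  // HloInstruction* with a tuple shape and that
  // ShapeUtil::GetTupleElementShape is available.
  //
  //   std::unique_ptr<HloInstruction> gte =
  //       HloInstruction::CreateGetTupleElement(
  //           ShapeUtil::GetTupleElementShape(tuple->shape(), 0), tuple, 0);
  //   CHECK_EQ(gte->tuple_index(), 0);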
  // Returns a category for the HLO. This could be something like
  // "convolution" or "elementwise".
  string ToCategory() const;

  // Returns the string concatenation of parent name and this instruction's
  // name. This name is guaranteed to be unique among all instructions in the
  // HloModule.
  string FullyQualifiedName() const;

  // Returns a logging instruction, if the output of this instruction is
  // logged.
  //
  // Postcondition: retval == nullptr || retval->opcode() == HloOpcode::kTrace
  HloInstruction* tracing() const;
  void set_tracing(HloInstruction* trace_instruction);

  // Returns the channel id associated with the instruction. The id is
  // shared between each Send/Recv pair and is globally unique to identify
  // each channel.
  //
  // Precondition: opcode() == HloOpcode::kSend or HloOpcode::kRecv
  int64 channel_id() const { return channel_id_; }

  // Returns the feature_index field associated with the instruction. The
  // index represents the index of the feature dimension.
  //
  // Precondition: opcode() == HloOpcode::kBatchNormTraining
  int64 feature_index() const { return feature_index_; }

  // Returns the epsilon value associated with the instruction. This is a
  // small number added to the variance to avoid divide-by-zero error.
  //
  // Precondition: opcode() == HloOpcode::kBatchNormTraining
  float epsilon() const { return epsilon_; }

  // Returns the infeed configuration string. The infeed configuration
  // includes any metadata needed for the backend compiler (e.g., infeed
  // buffer address) and is target-dependent.
  string infeed_config() const { return infeed_config_; }
  void set_infeed_config(const string& config) { infeed_config_ = config; }

  // Returns a tag to be used in tracing.
  //
  // Precondition: opcode() == HloOpcode::kTrace
  string TracingTag() const;

  // Returns whether the instruction is a constant.
  bool IsConstant() const;

  // Returns true if this instruction is fused, i.e., contained within a
  // fusion instruction.
  bool IsFused() const;

  // Returns true if this instruction can be legally fused into a fusion
  // instruction.
  bool IsFusable() const;

  // Returns the fusion instruction that contains this instruction.
  //
  // Note: only valid if this instruction is fused into a fusion instruction.
  HloInstruction* fusion_instruction() const;

  // Returns the root instruction of the fused expression contained within
  // this fusion instruction.
  //
  // Precondition: opcode() == HloOpcode::kFusion
  HloInstruction* fused_expression_root() const;

  // Returns the computation for this fused instruction.
  //
  // Precondition: opcode() == HloOpcode::kFusion
  HloComputation* fused_instructions_computation() const;

  // Returns the vector of fused instructions inside this fusion
  // instruction. The order is a reverse postorder of the fused expression
  // (the root is first in the order).
  //
  // Note: although the list itself is const, the instructions contained in
  // the list returned here are mutable.
  //
  // Precondition: opcode() == HloOpcode::kFusion
  const std::list<std::unique_ptr<HloInstruction>>& fused_instructions() const;

  // Returns the fused parameter instruction in this fusion instruction
  // corresponding to the given parameter number.
  //
  // Precondition: opcode() == HloOpcode::kFusion
  HloInstruction* fused_parameter(int64 parameter_number) const;

  // Returns the vector of fused parameters inside this fusion instruction.
  //
  // Precondition: opcode() == HloOpcode::kFusion
  const std::vector<HloInstruction*>& fused_parameters() const;
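  // Usage sketch (illustrative only): inspecting a fusion instruction with
  // the accessors above.
  //
  //   if (instr->opcode() == HloOpcode::kFusion) {
  //     HloInstruction* root = instr->fused_expression_root();
  //     for (HloInstruction* param : instr->fused_parameters()) {
  //       VLOG(2) << param->name() << " feeds " << root->name();
  //     }
  //   }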
  // Returns true if this instruction is a fusion instruction that generates
  // multiple outputs.
  bool IsMultiOutputFusion() const {
    return (opcode() == HloOpcode::kFusion &&
            fused_expression_root()->opcode() == HloOpcode::kTuple);
  }

  FusionKind fusion_kind() const {
    CHECK_EQ(HloOpcode::kFusion, opcode_);
    return fusion_kind_;
  }

  // Merges the fused instructions from 'instruction_to_merge' into the
  // fused instruction set of 'this', updating operands as necessary.
  //
  // Precondition: opcode() == HloOpcode::kFusion
  // Precondition: 'instruction_to_merge' must be an operand of 'this'.
  void MergeFusionInstruction(HloInstruction* instruction_to_merge);

  // Fuses the given instruction into this fusion instruction.
  // instruction_to_fuse is cloned and the clone is placed in the fusion
  // instruction; instruction_to_fuse is unchanged. The instruction is cloned
  // rather than moved to cleanly handle the case where the instruction has a
  // use outside the fusion instruction. Moving such an instruction into a
  // fusion instruction would violate the single-result invariant of HLO
  // instructions and significantly complicate code generation.
  //
  // Precondition: this->opcode() == HloOpcode::kFusion
  HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse);

  // Returns the start index in the given dimension for a slice node.
  //
  // Precondition: opcode() == HloOpcode::kSlice
  int64 slice_starts(int64 dimension) const {
    CHECK_EQ(HloOpcode::kSlice, opcode_);
    return slice_starts_[dimension];
  }
  const std::vector<int64>& slice_starts() const { return slice_starts_; }

  // Returns the (exclusive) limit index in the given dimension for a slice
  // node.
  //
  // Precondition: opcode() == HloOpcode::kSlice
  int64 slice_limits(int64 dimension) const {
    CHECK_EQ(HloOpcode::kSlice, opcode_);
    return slice_limits_[dimension];
  }
  const std::vector<int64>& slice_limits() const {
    CHECK_EQ(HloOpcode::kSlice, opcode_);
    return slice_limits_;
  }

  // Returns the size of the slice in the given dimension for a dynamic
  // slice node.
  //
  // Precondition: opcode() == HloOpcode::kDynamicSlice
  int64 slice_sizes(int64 dimension) const {
    CHECK_EQ(HloOpcode::kDynamicSlice, opcode_);
    return dynamic_slice_sizes_[dimension];
  }
  const std::vector<int64>& dynamic_slice_sizes() const {
    CHECK_EQ(HloOpcode::kDynamicSlice, opcode_);
    return dynamic_slice_sizes_;
  }

  // Returns the number of exponent bits for a reduce-precision node.
  //
  // Precondition: opcode() == HloOpcode::kReducePrecision
  int32 exponent_bits() const {
    CHECK_EQ(HloOpcode::kReducePrecision, opcode_);
    return exponent_bits_;
  }

  // Returns the number of mantissa bits for a reduce-precision node.
  //
  // Precondition: opcode() == HloOpcode::kReducePrecision
  int32 mantissa_bits() const {
    CHECK_EQ(HloOpcode::kReducePrecision, opcode_);
    return mantissa_bits_;
  }

  // Returns data on the window in a windowed operation such as
  // convolution.
  const Window& window() const {
    CHECK(window_ != nullptr);
    return *window_;
  }

  // Returns the padding configuration for a pad node.
  //
  // Precondition: opcode() == HloOpcode::kPad
  const PaddingConfig& padding_config() const {
    CHECK(padding_config_ != nullptr);
    return *padding_config_;
  }

  // Returns data on the dimension numbers used for a convolution operation.
  const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
    CHECK(convolution_dimension_numbers_ != nullptr);
    return *convolution_dimension_numbers_;
  }

  // Returns the dump string of the convolution dimension numbers.
  string ConvolutionDimensionNumbersToString() const;
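  // Usage sketch (illustrative only): the slice accessors above mirror the
  // arguments given to CreateSlice. Assumes `v` is an existing rank-1 F32
  // HloInstruction* with at least four elements.
  //
  //   std::unique_ptr<HloInstruction> slice = HloInstruction::CreateSlice(
  //       ShapeUtil::MakeShape(F32, {3}), v,
  //       /*start_indices=*/{1}, /*limit_indices=*/{4});
  //   CHECK_EQ(slice->slice_starts(0), 1);
  //   CHECK_EQ(slice->slice_limits(0), 4);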
  // Returns the random distribution for this rng node.
  //
  // Precondition: opcode() == HloOpcode::kRng
  RandomDistribution random_distribution() const;

  // Clones the HLO instruction. The clone will have the same opcode, shape,
  // and operands. After creation the clone has no uses. "this" (the
  // instruction cloned from) is not changed. Suffix is the string to append
  // to the name of the instruction to form the name of the cloned
  // instruction.
  std::unique_ptr<HloInstruction> Clone(const string& suffix = "clone");

  // Clones the HLO instruction as above but with new shape and operands.
  std::unique_ptr<HloInstruction> CloneWithNewOperands(
      const Shape& shape,
      tensorflow::gtl::ArraySlice<HloInstruction*> operands);

  // Returns the computations this instruction calls (if any). This includes
  // computations called by fused instructions inside of a fusion instruction.
  const std::vector<HloComputation*>& called_computations() const {
    return called_computations_;
  }

  // Replaces all called computations based on a map function. This is needed
  // when we clone hlo_computations and want the instructions to point to the
  // newly cloned nodes.
  void ReplaceCalledComputations(
      std::function<HloComputation*(HloComputation*)> map_function) {
    for (int64 i = 0; i < called_computations_.size(); ++i) {
      called_computations_[i] = map_function(called_computations_[i]);
    }
  }

  // Returns true if this instruction performs an elementwise operation on the
  // `operand_idx`-th operand. An instruction is elementwise on an operand
  // iff, after performing necessary implicit broadcast
  // (cs/IrArray::EmitArrayElementAddress), to compute the output at index
  // {i_0,i_1,...,i_n}, the only element required from the operand (if any) is
  // the element at {i_0,i_1,...,i_n}.
  //
  // Note on performance: when this instruction is kFusion, this method, in
  // the worst case, scans all fused instructions. We could speed this up by
  // caching.
  bool IsElementwiseOnOperand(int64 operand_idx) const;

  // Returns true if this instruction is elementwise on all its operands.
  bool IsElementwise() const;

  // Returns whether this instruction may reuse elements of its `i`th operand.
  bool ReusesOperandElements(int64 i) const {
    return OperandElementUse(i) == UseKind::kReuse;
  }

  // Returns the indices at which the given operand appears in the operand
  // list of this instruction. Note that an instruction can use the same
  // operand multiple times.
  std::vector<int64> OperandIndices(const HloInstruction* operand) const;

  // Convenience helper for ShapeUtil::InsertedOrDeleted1SizedDimensions. If
  // this reshape merely inserts or deletes 1-sized dimensions, returns the
  // input indices of the deleted dimensions and the output indices of the
  // inserted dimensions.
  //
  // Precondition: this op must be a reshape.
  std::tuple<bool, std::vector<int64>, std::vector<int64>>
  ReshapeMerelyInsertsOrDeletes1SizedDimensions() const;

  // Returns the opcode string for this instruction. Compared with the
  // HloOpcodeString method, this wrapper dumps additional information
  // such as fusion kind.
  string ExtendedOpcodeStr() const;

  // Returns a string identifier for this instruction. If no string identifier
  // has been explicitly set, then the identifier is the serialized pointer to
  // this instruction.
  const string& name() const { return name_; }

  // Uses the given NameUniquer to select a unique name for the instruction
  // based on the instruction's existing name.
  void UniquifyName(NameUniquer* name_uniquer);

  // Sets the debug metadata for this instruction.
  void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; }
  const OpMetadata& metadata() const { return metadata_; }
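  // Usage sketch (illustrative only): remapping called computations after a
  // clone. Assumes `clone_map` is a hypothetical
  // std::unordered_map<HloComputation*, HloComputation*> from originals to
  // clones, populated elsewhere.
  //
  //   instr->ReplaceCalledComputations([&](HloComputation* computation) {
  //     return clone_map.at(computation);
  //   });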
  // Sets/gets the computation containing this instruction. set_parent should
  // only be called by HloComputation methods which add/remove instructions to
  // computations.
  void set_parent(HloComputation* computation) { parent_ = computation; }
  const HloComputation* parent() const { return parent_; }
  HloComputation* parent() { return parent_; }

  // Returns the module for this instruction.
  HloModule* GetModule() const;

  // Returns whether we could assign input and output layouts to this
  // instruction to make it a bitcast.
  bool CouldBeBitcast() const;

  // Sets the parent fusion instruction for this instruction.
  //
  // Precondition: opcode() == HloOpcode::kFusion
  void SetParentFusion(HloInstruction* fusion_instruction) {
    CHECK_EQ(HloOpcode::kFusion, fusion_instruction->opcode());
    parent_fusion_instruction_ = fusion_instruction;
  }

  // Gets/sets the number of partitions per outer dimension (in order,
  // starting with the outer-most dimension first). Currently used by the
  // parallel cpu backend to partition HLOs into parallel tasks.
  // TODO(b/62783254) Replace these methods with a more general way to
  // annotate HLOs with backend-specific information.
  const std::vector<int64>& outer_dimension_partitions() const {
    return outer_dimension_partitions_;
  }
  void set_outer_dimension_partitions(
      const std::vector<int64>& outer_dimension_partitions);
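  // Usage sketch (illustrative only): asking the parallel cpu backend to
  // split the outer-most dimension of this HLO into four tasks.
  //
  //   instr->set_outer_dimension_partitions({4});
  //   CHECK_EQ(instr->outer_dimension_partitions().size(), 1);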
 private:
  enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse };

  // Helper class for computing OperandElementUse for kFusion.
  class FusionReusesParamElements;

  // Creates an n-ary elementwise operation.
  static std::unique_ptr<HloInstruction> CreateNary(
      const Shape& shape, HloOpcode opcode,
      tensorflow::gtl::ArraySlice<HloInstruction*> operands);

  // Appends operand to the list of operands and adds this instruction as a
  // user of the operand.
  void AppendOperand(HloInstruction* operand);

  // Adds a user for this instruction.
  void AddUser(HloInstruction* user);

  // Removes a user for this instruction.
  void RemoveUser(HloInstruction* user);

  // Internal constructor for a given opcode/shape; other fields must be
  // filled by factory methods.
  HloInstruction(HloOpcode opcode, const Shape& shape);

  // Clones the given instruction_to_fuse and inserts the clone into this
  // fusion instruction.
  //
  // Precondition: opcode() == HloOpcode::kFusion
  HloInstruction* CloneAndFuseInternal(HloInstruction* instruction_to_fuse);

  // Clones a fusion instruction with a new shape and operands.
  std::unique_ptr<HloInstruction> CloneFusionWithNewOperands(
      const Shape& shape,
      tensorflow::gtl::ArraySlice<HloInstruction*> operands);

  // CHECKs various invariants of a fusion instruction.
  void CheckFusionInstruction() const;

  // Returns true if this instruction can legally have the dimensions field
  // set. Used for checking precondition of dimensions field accessors.
  bool CanHaveDimensionsField() const;

  // Returns how this instruction uses elements of its `i`th operand.
  UseKind OperandElementUse(int64 i) const;

  // Shape of outfeed request.
  Shape outfeed_shape_;

  // Result shape of this instruction.
  Shape shape_;

  // Opcode for this instruction.
  HloOpcode opcode_;

  // Literal, only present for kConstant.
  std::unique_ptr<Literal> literal_;

  // Constant index, only present for kGetTupleElement.
  int64 tuple_index_ = 0;

  // Dimensions present for some operations that require reshaping or
  // broadcasting, including Reshape, Reduce, ReduceWindow, and Reverse.
  std::vector<int64> dimensions_;

  // Describes the window in a windowed operation such as convolution.
  std::unique_ptr<Window> window_;

  // Describes the dimension numbers used for a convolution.
  std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_;

  // Describes the [begin, end) index range for a slice.
  std::vector<int64> slice_starts_;
  std::vector<int64> slice_limits_;

  // The bit sizes for a reduce-precision operation.
  int32 exponent_bits_;
  int32 mantissa_bits_;

  // Describes the [start, start + size) range size for a dynamic slice
  // ('start' is specified dynamically in the second operand of the
  // operation).
  std::vector<int64> dynamic_slice_sizes_;

  // The padding configuration that describes the edge padding and interior
  // padding of this pad instruction. Only set for pad instructions.
  std::unique_ptr<PaddingConfig> padding_config_;

  // The computation that stores the instructions fused into this fusion
  // instruction. Only set for fusion instructions.
  std::unique_ptr<HloComputation> fused_instructions_computation_;

  // If this instruction is fused into a fusion instruction, this field points
  // to the fusion instruction.
  HloInstruction* parent_fusion_instruction_ = nullptr;

  // The type of the fusion. Used by kFusion only.
  FusionKind fusion_kind_;

  // For parameter instructions this field holds the parameter number.
  int64 parameter_number_ = 0;
  string parameter_name_;

  // Name of a global symbol to call, only present for kCustomCall.
  string custom_call_target_;

  // Computations called by this instruction.
  std::vector<HloComputation*> called_computations_;

  // Indices of computations in called_computations_ for instructions which
  // call multiple computations.
  enum {
    // kWhile computations.
    kBodyComputationIndex = 0,
    kConditionComputationIndex = 1,

    // kSelectAndScatter computations.
    kSelectComputationIndex = 0,
    kScatterComputationIndex = 1,
  };

  // Outfeed configuration information, only present for kOutfeed.
  string outfeed_config_;

  // Instruction operands.
  std::vector<HloInstruction*> operands_;

  // The users of this instruction. Users are HLOs where this instruction is
  // an operand. The vector users_ and the set user_set_ contain identical
  // members. The set enables fast membership testing and the vector enables
  // fast, stable iteration.
  std::vector<HloInstruction*> users_;
  std::unordered_set<const HloInstruction*> user_set_;

  // The set of control predecessors of this instruction.
  std::vector<HloInstruction*> control_predecessors_;

  // The set of control successors of this instruction.
  std::vector<HloInstruction*> control_successors_;

  // A trace instruction that consumes this instruction.
  //
  // Invariant: if trace_instruction_ != nullptr, trace_instruction_ has this
  // as an operand.
  HloInstruction* trace_instruction_ = nullptr;

  // The distribution requested for random number generation.
  // Only present for kRng.
  RandomDistribution distribution_;

  // A small float number added to the variance to avoid divide-by-zero error.
  // Only present for kBatchNormTraining.
  float epsilon_;

  // An integer value representing the index of the feature dimension.
  // Only present for kBatchNormTraining.
  int64 feature_index_;

  // Represents a unique identifier for each Send/Recv instruction pair.
  // Only present for kSend or kRecv.
  int64 channel_id_ = -1;

  // The string representation of the infeed configuration.
  string infeed_config_;

  // String identifier for instruction.
  string name_;

  // The computation in which this instruction is contained.
  HloComputation* parent_ = nullptr;

  // Metadata for debugging.
  OpMetadata metadata_;

  // The number of partitions per outer dimension (listed in order from the
  // outer-most dimension first).
  std::vector<int64> outer_dimension_partitions_;

  TF_DISALLOW_COPY_AND_ASSIGN(HloInstruction);
};

string ToString(HloInstruction::FusionKind kind);

std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind);

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_