/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// HLO instructions are in DAG form and represent the computations that the
// user has built up via the XLA service interface. They are ultimately lowered
// in a platform-aware way by traversing the HLO DAG and emitting a lowered
// form; e.g. see DfsHloVisitor.

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_

#include <functional>
#include <iosfwd>
#include <list>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

class HloComputation;
class HloModule;

// A bunch of switches that control how the hlo text should be printed.
class HloPrintOptions {
 public:
  enum class PrintSubcomputationMode {
    kOff,         // Do not print anything about subcomputations.
    kNameOnly,    // Only print the name of subcomputations.
    kFullBodies,  // Print the full bodies of subcomputations.
  };

  // Constructs the default print options: don't print large constants, don't
  // compact operands, no indentation.
  HloPrintOptions()
      : print_large_constants_(false),
        print_subcomputation_mode_(PrintSubcomputationMode::kNameOnly),
        print_metadata_(true),
        print_backend_config_(true),
        compact_operands_(false),
        print_operand_shape_(true),
        print_operand_names_(true),
        print_program_shape_(true),
        print_percent_(true),
        print_control_dependencies_(true),
        canonicalize_instruction_names_(false),
        indent_amount_(0),
        is_in_nested_computation_(false) {}

  static HloPrintOptions ShortParsable() {
    return HloPrintOptions()
        .set_print_large_constants(true)
        .set_print_subcomputation_mode(PrintSubcomputationMode::kNameOnly)
        .set_print_metadata(false)
        .set_print_backend_config(false)
        .set_print_operand_shape(false)
        .set_print_program_shape(false)
        .set_print_percent(false)
        .set_print_control_dependencies(false);
  }

  // Options to produce the canonical string representing an isomorphic
  // computation graph.
  static HloPrintOptions Canonical() {
    return HloPrintOptions()
        .set_print_subcomputation_mode(PrintSubcomputationMode::kFullBodies)
        .set_print_metadata(false)
        .set_print_backend_config(false)
        .set_compact_operands(true)
        .set_print_operand_names(false)
        .set_print_operand_shape(true)
        .set_print_program_shape(false)
        .set_print_percent(false)
        .set_print_control_dependencies(false)
        .set_canonicalize_instruction_names(true);
  }

  // If true, large constants will be printed out.
  HloPrintOptions& set_print_large_constants(bool value) {
    print_large_constants_ = value;
    return *this;
  }

  HloPrintOptions& set_print_subcomputation_mode(
      PrintSubcomputationMode value) {
    print_subcomputation_mode_ = value;
    return *this;
  }

  // If true, metadata will be printed.
  HloPrintOptions& set_print_metadata(bool value) {
    print_metadata_ = value;
    return *this;
  }

  // If true, backend_config will be printed.
  HloPrintOptions& set_print_backend_config(bool value) {
    print_backend_config_ = value;
    return *this;
  }

  // If true, operands' shapes will be printed.
  HloPrintOptions& set_print_operand_shape(bool value) {
    print_operand_shape_ = value;
    return *this;
  }

  // If true, the operand names will be printed.
  HloPrintOptions& set_print_operand_names(bool value) {
    print_operand_names_ = value;
    return *this;
  }

  // If true, program shape of hlo computations will be printed.
  HloPrintOptions& set_print_program_shape(bool value) {
    print_program_shape_ = value;
    return *this;
  }

  // If true, names will be printed with prefix '%'.
  HloPrintOptions& set_print_percent(bool value) {
    print_percent_ = value;
    return *this;
  }

  // If true, control dependencies will be printed.
  HloPrintOptions& set_print_control_dependencies(bool value) {
    print_control_dependencies_ = value;
    return *this;
  }

  // If true, only a part of operands will be printed out (note that in this
  // case the text will not be parsable).
  HloPrintOptions& set_compact_operands(bool value) {
    compact_operands_ = value;
    return *this;
  }

  // If true, canonicalizes instructions' names. Instead of using "%foo.1" as
  // the name of an instruction, we use "%tmp_1", "%tmp_2" etc.
  HloPrintOptions& set_canonicalize_instruction_names(bool value) {
    canonicalize_instruction_names_ = value;
    return *this;
  }

  // The indent of the hlo text block.
  HloPrintOptions& set_indent_amount(int value) {
    indent_amount_ = value;
    return *this;
  }

  // If true, indicates the instruction being printed is inside a nested
  // computation.
  HloPrintOptions& set_is_in_nested_computation(bool value) {
    is_in_nested_computation_ = value;
    return *this;
  }

  bool print_large_constants() const { return print_large_constants_; }
  PrintSubcomputationMode print_subcomputation_mode() const {
    return print_subcomputation_mode_;
  }
  bool print_metadata() const { return print_metadata_; }
  bool print_backend_config() const { return print_backend_config_; }
  bool compact_operands() const { return compact_operands_; }
  bool print_operand_shape() const { return print_operand_shape_; }
  bool print_operand_names() const { return print_operand_names_; }
  bool print_program_shape() const { return print_program_shape_; }
  bool print_percent() const { return print_percent_; }
  bool print_control_dependencies() const {
    return print_control_dependencies_;
  }
  bool canonicalize_instruction_names() const {
    return canonicalize_instruction_names_;
  }
  int indent_amount() const { return indent_amount_; }
  bool is_in_nested_computation() const { return is_in_nested_computation_; }

 private:
  bool print_large_constants_;
  PrintSubcomputationMode print_subcomputation_mode_;
  bool print_metadata_;
  bool print_backend_config_;
  bool compact_operands_;
  bool print_operand_shape_;
  bool print_operand_names_;
  bool print_program_shape_;
  bool print_percent_;
  bool print_control_dependencies_;
  bool canonicalize_instruction_names_;
  int indent_amount_;
  bool is_in_nested_computation_;
};

// For canonical string output, we need to have a canonical way to rename
// each instruction and its operands. Each operand is renamed as "tmp_<N>",
// where <N> is an index starting from 0.
class CanonicalNameMap {
 public:
  CanonicalNameMap() : index(0) {}

  string LookupOrInsert(const string& old_name) {
    auto iter = canonical_name_map.find(old_name);
    if (iter != canonical_name_map.end()) {
      return iter->second;
    }

    string new_name = absl::StrCat("tmp_", index++);
    canonical_name_map[old_name] = new_name;
    return new_name;
  }
  void Clear() {
    canonical_name_map.clear();
    index = 0;
  }

 private:
  int64 index;
  absl::flat_hash_map<string, string> canonical_name_map;
};
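// Example: the setter methods above return *this, so print options are
// usually built fluently and handed to HloInstruction::ToString (declared
// further below). A minimal illustrative sketch, assuming `instr` points at
// an existing HloInstruction:
//
//   HloPrintOptions options = HloPrintOptions()
//                                 .set_print_metadata(false)
//                                 .set_print_percent(false);
//   string text = instr->ToString(options);
//
// CanonicalNameMap above backs canonicalize_instruction_names: the first
// name looked up maps to "tmp_0", the next distinct name to "tmp_1", etc.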
// HLO instructions are the atomic unit of the high-level compiler's IR.
//
// HloInstructions live inside of an HloComputation, which is analogous to a
// function in other programming languages. Nodes have no total order within
// their computation. Instead, they have a partial ordering determined by
// their data and control dependencies.
//
// HLO does not have basic blocks or explicit "branch" instructions. Instead,
// certain HloInstructions -- namely, kWhile, kConditional, and kCall -- encode
// control flow. For example, the kConditional HLO executes one of two possible
// computations, depending on the runtime value of a predicate.
//
// HLO is pure (mostly). It has no concept of mutable state. Instead, data
// values are produced by one HLO and flow into consumers across dependency
// edges.
class HloInstruction {
 public:
  // A fusion node computes the same value a call to its fusion computation
  // would compute. However, the choice of fusion kind dictates codegen
  // strategy for the backend.
  //
  // To generate code for a kFusion HloInstruction, most backends do something
  // like the following:
  //
  // 1) Identify the "primary" HloInstruction of the fused computation.
  // 2) Emit code that does the work of the primary node, creating its inputs
  //    and transforming its outputs as specified by the fused computation.
  //
  // In step (2), the code emitted is usually similar to the code that would be
  // emitted for an *unfused* version of the primary node, except that
  //
  //  - when the primary node reads an element of one of its operands, instead
  //    of loading the value from memory, it *computes* the value based on the
  //    contents of the fused computation.
  //  - when the primary node outputs a value, instead of storing it to memory,
  //    it forwards the value to its users, which then perform additional
  //    computations before the value is finally stored to memory at the root
  //    of the fusion node.
  //
  // An HloInstruction's FusionKind helps us find the kFusion instruction's
  // primary node, and can also affect how we generate code in step (2).
  //
  //  - kInput: The primary node is the root of the fused instruction.
  //
  //  - kOutput: The primary node is not the root of the fused instruction.
  //    This fusion kind requires that one operand buffer of the fusion
  //    instruction be able to alias the output buffer. This constraint is
  //    usually enough to let backends find the primary node unambiguously.
  //
  //  - kLoop: The primary node is the root of the fused computation, but,
  //    unlike in input fusion, we prescribe a specific implementation for
  //    codegen. Rather than generating code that looks like the code we'd
  //    emit for an unfused version of the primary/root node, we emit code
  //    that generates one element of the root at a time.
  //
  //  - kCustom: Custom category for backend-specific fusions that don't fit
  //    into the above patterns.
  //
  // Not all backends support all fusion kinds, and given a particular fused
  // computation, it's not in general safe to change its fusion kind. Creation
  // of fusion nodes is always backend-specific.
  //
  // For elementwise ops (e.g. kAdd), most backends would emit a
  // one-element-at-a-time implementation for the unfused version, so loop
  // fusion and input fusion are probably equivalent if the root node is
  // elementwise. They're not necessarily equivalent e.g. for kReduce, where
  // an implementation might emit something more sophisticated for an unfused
  // or input-fusion reduce, but will emit the naive code that reduces one
  // element at a time for loop fusion with a reduce as the root.
  //
  // Another way to think of loop fusion is that it's equivalent to input
  // fusion, but where the root node is an implicit identity node, whose
  // unfused implementation is "read one element, write one element".
  //
  // TODO(b/79869434): This categorization scheme is not great. For one thing,
  // input and loop fusion are basically the same thing: There is no reason
  // for the HLO to encode backend-specific decisions about how e.g. a reduce
  // that's the root of a fusion should be lowered. In addition, this scheme
  // as written doesn't work for multi-output fusion, where the primary node
  // is never actually the root (which is a kTuple instruction that gathers
  // the multiple outputs of the fusion).
  enum class FusionKind {
    kLoop,
    kInput,
    kOutput,
    kCustom,
  };

  virtual ~HloInstruction();

  // Creates an instruction from the given proto. Arguments:
  //
  //   proto: the proto to convert from.
  //   instruction_map: a map from instruction id to HloInstruction*. This map
  //     must contain all operands of the newly constructed instruction.
  //   computation_map: a map from computation id to HloComputation*. This map
  //     must contain all computations which the newly constructed instruction
  //     calls.
  static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
      const HloInstructionProto& proto,
      const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
      const absl::flat_hash_map<int64, HloComputation*>& computation_map);

  // Creates a parameter-retrieving instruction.
  static std::unique_ptr<HloInstruction> CreateParameter(
      int64 parameter_number, const Shape& shape, const string& name);

  // Creates a literal constant instruction.
  static std::unique_ptr<HloInstruction> CreateConstant(Literal literal);

  // Creates an Iota instruction.
  static std::unique_ptr<HloInstruction> CreateIota(const Shape& shape,
                                                    int64 iota_dimension);

  // Creates a get tuple element instruction.
  static std::unique_ptr<HloInstruction> CreateGetTupleElement(
      const Shape& shape, HloInstruction* operand, int64 index);

  // Creates a trace instruction that logs the input operand in the
  // computation.
  static std::unique_ptr<HloInstruction> CreateTrace(const string& tag,
                                                     HloInstruction* operand);

  // Creates a random number generation instruction that fills a shape with
  // random numbers from a given distribution.
  static std::unique_ptr<HloInstruction> CreateRng(
      const Shape& shape, RandomDistribution distribution,
      absl::Span<HloInstruction* const> parameters);

  // Creates a unary instruction (one operand).
  // Precondition: opcode must be a legitimate unary operation.
  static std::unique_ptr<HloInstruction> CreateUnary(const Shape& shape,
                                                     HloOpcode opcode,
                                                     HloInstruction* operand);

  // Creates a binary instruction (two operands).
  // Precondition: opcode must be a legitimate binary operation.
  static std::unique_ptr<HloInstruction> CreateBinary(const Shape& shape,
                                                      HloOpcode opcode,
                                                      HloInstruction* lhs,
                                                      HloInstruction* rhs);

  // Creates a ternary instruction (three operands).
  // Precondition: opcode must be a legitimate ternary operation.
  static std::unique_ptr<HloInstruction> CreateTernary(const Shape& shape,
                                                       HloOpcode opcode,
                                                       HloInstruction* lhs,
                                                       HloInstruction* rhs,
                                                       HloInstruction* ehs);

  // Creates a variadic instruction (variable number of operands).
  // Precondition: opcode must be a legitimate variadic operation.
  static std::unique_ptr<HloInstruction> CreateVariadic(
      const Shape& shape, HloOpcode opcode,
      absl::Span<HloInstruction* const> operands);

  // Creates a map instruction, where the computation (given by the handle) is
  // applied element-wise to every element in operands (across the operands,
  // at a given index).
  static std::unique_ptr<HloInstruction> CreateMap(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      HloComputation* map_computation);

  // Creates a convolution op, where rhs is the convolutional filter
  // and window describes how the filter is applied to lhs.
  static std::unique_ptr<HloInstruction> CreateConvolve(
      const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
      int64 feature_group_count, const Window& window,
      const ConvolutionDimensionNumbers& dimension_numbers,
      const PrecisionConfig& precision_config);

  // Creates an FFT op, of the type indicated by fft_type.
  static std::unique_ptr<HloInstruction> CreateFft(
      const Shape& shape, HloInstruction* operand, FftType fft_type,
      absl::Span<const int64> fft_length);

  // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch
  // dimensions specified in 'dimension_numbers'.
  static std::unique_ptr<HloInstruction> CreateDot(
      const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
      const DotDimensionNumbers& dimension_numbers,
      const PrecisionConfig& precision_config);

  // Creates a reduce-precision op, where operand is the data to reduce in
  // precision, and exponent_bits and mantissa_bits describe the precision to
  // reduce it to.
  static std::unique_ptr<HloInstruction> CreateReducePrecision(
      const Shape& shape, HloInstruction* operand, const int exponent_bits,
      const int mantissa_bits);
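  // Example: instructions are created detached through these factories and
  // then attached to a computation. An illustrative sketch (not part of this
  // header), assuming an HloComputation::Builder `b` and a scalar shape
  // `r0f32` built with ShapeUtil::MakeShape(F32, {}):
  //
  //   HloInstruction* x = b.AddInstruction(
  //       HloInstruction::CreateParameter(0, r0f32, "x"));
  //   HloInstruction* y = b.AddInstruction(
  //       HloInstruction::CreateParameter(1, r0f32, "y"));
  //   HloInstruction* sum = b.AddInstruction(
  //       HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, x, y));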
  // Creates a cross replica reduction op.
  //
  // `reduction_computation`: the reduction function.
  //
  // `replica_groups`: each ReplicaGroup contains a list of replica ids. If
  // empty, all replicas belong to one group in the order of 0 - (n-1).
  // Allreduce will be applied within subgroups.
  // For example, we have 4 replicas, then replica_groups={{0,2},{1,3}} means,
  // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1.
  //
  // `all_reduce_id`: for Allreduce nodes from different modules, if they have
  // the same all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce
  // will not be applied cross modules.
  //
  // TODO(b/79737069): Rename this to AllReduce.
  static std::unique_ptr<HloInstruction> CreateCrossReplicaSum(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      HloComputation* reduce_computation,
      const std::vector<ReplicaGroup>& replica_groups,
      absl::string_view barrier, const absl::optional<int64>& all_reduce_id);

  // This op handles the communication of an Alltoall operation. On each core,
  // the operands are N ops in the same shape, where N is the number of cores
  // participating in the Alltoall. Then the N operands are scattered to N
  // cores, e.g., the ith operand is sent to the ith core. Then each core
  // gathers the received data into a tuple.
  //
  // - `replica_groups`: each ReplicaGroup contains a list of replica ids. If
  //   empty, all replicas belong to one group in the order of 0 - (n-1).
  //   Alltoall will be applied within subgroups in the specified order. For
  //   example, replica groups = {{1,2,3},{4,5,0}} means, an Alltoall will be
  //   applied within replicas 1, 2, 3, and in the gather phase, the received
  //   blocks will be concatenated in the order of 1, 2, 3; another Alltoall
  //   will be applied within replicas 4, 5, 0, and the concatenation order is
  //   4, 5, 0.
  static std::unique_ptr<HloInstruction> CreateAllToAll(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      const std::vector<ReplicaGroup>& replica_groups);

  // Creates a communication instruction that permutes data cross replicas.
  // Data is sent/received according to the (source_replica_id,
  // target_replica_id) pairs in `source_target_pairs`. If a replica id is not
  // a target_replica_id in any pair, the output on that replica is a tensor
  // consisting of 0(s) in `shape`.
  static std::unique_ptr<HloInstruction> CreateCollectivePermute(
      const Shape& shape, HloInstruction* operand,
      const std::vector<std::pair<int64, int64>>& source_target_pairs);

  // Creates a conversion instruction, where operand is the data to convert
  // and shape is the target shape for the conversion.
  static std::unique_ptr<HloInstruction> CreateConvert(
      const Shape& shape, HloInstruction* operand);

  // Creates a bitcast conversion instruction, where operand is the data to
  // convert and shape is the target shape for the conversion.
  static std::unique_ptr<HloInstruction> CreateBitcastConvert(
      const Shape& shape, HloInstruction* operand);

  // Creates an infeed instruction, which reads data of the given shape from
  // the Infeed interface of the device. infeed_shape is the shape of the data
  // received from the infeed *not* the shape of the infeed instruction which
  // is a tuple containing the infeed_shape and the TOKEN.
  static std::unique_ptr<HloInstruction> CreateInfeed(
      const Shape& infeed_shape, HloInstruction* token_operand,
      const string& config);

  // Creates an outfeed instruction, which outputs data. outfeed_shape is the
  // shape of the data being outfed *not* the shape of the outfeed instruction
  // which is a TOKEN.
  static std::unique_ptr<HloInstruction> CreateOutfeed(
      const Shape& outfeed_shape, HloInstruction* operand,
      HloInstruction* token_operand, absl::string_view outfeed_config);

  // Creates an asynchronous send instruction with the given channel id, which
  // initiates sending the operand data to a unique receive instruction in
  // another computation that has the same channel id. If is_host_transfer is
  // true, then this Send operation transfers data to the host.
  static std::unique_ptr<HloInstruction> CreateSend(
      HloInstruction* operand, HloInstruction* token, int64 channel_id,
      bool is_host_transfer = false);

  // Blocks until data transfer for the Send instruction (operand) is
  // complete. The operand must be kSend.
  static std::unique_ptr<HloInstruction> CreateSendDone(
      HloInstruction* operand, bool is_host_transfer = false);

  // Creates an asynchronous receive instruction with the given channel id,
  // which allocates resources to receive data of the given shape from a
  // unique send instruction in another computation that has the same channel
  // id. If is_host_transfer is true, then this Recv operation transfers data
  // from the host.
  static std::unique_ptr<HloInstruction> CreateRecv(
      const Shape& shape, HloInstruction* token, int64 channel_id,
      bool is_host_transfer = false);

  // Blocks until data transfer for the Recv instruction (operand) is complete
  // and returns the receive buffer. The operand must be kRecv.
  static std::unique_ptr<HloInstruction> CreateRecvDone(
      HloInstruction* operand, bool is_host_transfer = false);

  // Creates a slice instruction, where the operand is sliced by the given
  // start/limit indices.
  static std::unique_ptr<HloInstruction> CreateSlice(
      const Shape& shape, HloInstruction* operand,
      absl::Span<const int64> start_indices,
      absl::Span<const int64> limit_indices, absl::Span<const int64> strides);

  // Creates a slice instruction, where the first operand is sliced by
  // start indices specified in the second operand, and by size specified in
  // 'slice_sizes'.
  static std::unique_ptr<HloInstruction> CreateDynamicSlice(
      const Shape& shape, HloInstruction* operand,
      HloInstruction* start_indices, absl::Span<const int64> slice_sizes);

  // Creates a dynamic update slice instruction, which updates a slice
  // of 'operand' with 'update' and 'start_indices'.
  static std::unique_ptr<HloInstruction> CreateDynamicUpdateSlice(
      const Shape& shape, HloInstruction* operand, HloInstruction* update,
      HloInstruction* start_indices);

  // Creates a concatenate instruction, where the operands are concatenated on
  // the provided dimension.
  static std::unique_ptr<HloInstruction> CreateConcatenate(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      int64 dimension);

  // Creates a reduce instruction, where the computation (given by the handle)
  // is applied successively to every element in operand. For example, let f
  // be the function to apply, which takes 2 arguments, an accumulator and the
  // current value. Let init be an initial value (which is normally chosen to
  // be the identity element for f, e.g. 0 if f is addition).
  // Then the reduce HLO will compute:
  // f(f(f(init, value0), value1), ...)
  static std::unique_ptr<HloInstruction> CreateReduce(
      const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
      absl::Span<const int64> dimensions_to_reduce,
      HloComputation* reduce_computation);
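  // Example: reducing a f32[4] operand {1, 2, 3, 4} over dimension 0 with an
  // addition computation and init 0 evaluates f(f(f(f(0, 1), 2), 3), 4) = 10
  // and yields a f32[] scalar. (Illustrative only; the element visitation
  // order is unspecified, so the written result relies on f being commutative
  // and associative.)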
  // A more general, multiple-argument version of the above.
  // The function to apply, f, now takes N arguments:
  // [accumulator0, accumulator1, ..., accumulatorN, value0, value1, ...,
  // valueN], and returns an N-tuple. The performed computation is (for
  // commutative and associative f operators) equivalent to:
  //
  // f_1 = f(init0, ..., initN, input0.value0, ..., inputN.value0)
  // f_2 = f(f_1.tuple_element(0), ..., f_1.tuple_element(N), input0.value1,
  //         ..., inputN.value1)
  // ...
  // TODO(b/112040122): Add support to this in HLO passes and in backends.
  static std::unique_ptr<HloInstruction> CreateReduce(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      absl::Span<HloInstruction* const> init_values,
      absl::Span<const int64> dimensions_to_reduce,
      HloComputation* reduce_computation);

  // Creates a reduce-window instruction, where the computation (given
  // by the handle) is applied window-wise at each valid window
  // position in the operand.
  static std::unique_ptr<HloInstruction> CreateReduceWindow(
      const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
      const Window& window, HloComputation* reduce_computation);

  // Creates a batch-norm-training instruction.
  static std::unique_ptr<HloInstruction> CreateBatchNormTraining(
      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
      HloInstruction* offset, float epsilon, int64 feature_index);

  // Creates a batch-norm-inference instruction.
  static std::unique_ptr<HloInstruction> CreateBatchNormInference(
      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
      HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
      float epsilon, int64 feature_index);

  // Creates a batch-norm-grad instruction.
  static std::unique_ptr<HloInstruction> CreateBatchNormGrad(
      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
      HloInstruction* mean, HloInstruction* variance,
      HloInstruction* grad_output, float epsilon, int64 feature_index);

  // Creates a scatter computation that scatters the `source` array to the
  // selected indices of each window.
  static std::unique_ptr<HloInstruction> CreateSelectAndScatter(
      const Shape& shape, HloInstruction* operand, HloComputation* select,
      const Window& window, HloInstruction* source,
      HloInstruction* init_value, HloComputation* scatter);

  // Creates a broadcast instruction.
  static std::unique_ptr<HloInstruction> CreateBroadcast(
      const Shape& shape, HloInstruction* operand,
      absl::Span<const int64> broadcast_dimensions);

  // Creates a sequence of instructions that performs an explicit broadcast of
  // the operand to the target shape.
  //
  // Interior HLOs are passed to "adder", but the "root" HLO of the sequence
  // is returned as a unique_ptr for API consistency with other factory
  // methods in this interface.
  //
  // TODO(b/72173833) Ideally HloComputations would always be present, and so
  // the adder being passed by the caller would not be necessary.
  static std::unique_ptr<HloInstruction> CreateBroadcastSequence(
      const Shape& output_shape, HloInstruction* operand,
      const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
          adder);

  // Creates a pad instruction, where the operand is padded on the edges and
  // between the elements with the given padding value.
  static std::unique_ptr<HloInstruction> CreatePad(
      const Shape& shape, HloInstruction* operand,
      HloInstruction* padding_value, const PaddingConfig& padding_config);

  // Creates a reshape instruction, where the operand is flattened row-major
  // order and then reshaped to the given result shape.
  static std::unique_ptr<HloInstruction> CreateReshape(
      const Shape& shape, HloInstruction* operand);

  // Creates a transpose instruction which permutes the operand dimensions.
  static std::unique_ptr<HloInstruction> CreateTranspose(
      const Shape& shape, HloInstruction* operand,
      absl::Span<const int64> dimensions);

  // Creates a sort op, with a keys operand, and optional values operands.
  static std::unique_ptr<HloInstruction> CreateSort(
      const Shape& shape, int64 dimension, HloInstruction* keys,
      absl::Span<HloInstruction* const> values = {});

  // Creates a while instruction, given a condition computation, a body
  // computation, and the initial value for the input of the computations.
  // For example, shape: S32, condition: i -> i < 1000, body: i -> i * 2,
  // init: 1 corresponds to the C-like pseudocode below.
  //
  //   int32 i = 1;
  //   while (i < 1000) { i = i * 2; }
  //
  static std::unique_ptr<HloInstruction> CreateWhile(const Shape& shape,
                                                     HloComputation* condition,
                                                     HloComputation* body,
                                                     HloInstruction* init);
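  // Example: building the loop from the comment above. An illustrative
  // sketch (not part of this header), assuming an HloComputation::Builder
  // `b`, a scalar shape `s32`, and prebuilt HloComputation* `condition`
  // (i -> i < 1000) and `body` (i -> i * 2):
  //
  //   HloInstruction* init = b.AddInstruction(
  //       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
  //   HloInstruction* loop = b.AddInstruction(
  //       HloInstruction::CreateWhile(s32, condition, body, init));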
  static std::unique_ptr<HloInstruction> CreateConditional(
      const Shape& shape, HloInstruction* pred,
      HloInstruction* true_computation_arg, HloComputation* true_computation,
      HloInstruction* false_computation_arg,
      HloComputation* false_computation);

  static std::unique_ptr<HloInstruction> CreateGather(
      const Shape& shape, HloInstruction* operand,
      HloInstruction* start_indices,
      const GatherDimensionNumbers& gather_dim_numbers,
      absl::Span<const int64> slice_sizes);

  static std::unique_ptr<HloInstruction> CreateScatter(
      const Shape& shape, HloInstruction* operand,
      HloInstruction* scatter_indices, HloInstruction* updates,
      HloComputation* update_computation,
      const ScatterDimensionNumbers& scatter_dim_numbers);

  // Creates a kDomain instruction which delimits an HLO domain which has
  // the provided user and operand side metadata.
  static std::unique_ptr<HloInstruction> CreateDomain(
      const Shape& shape, HloInstruction* operand,
      std::unique_ptr<DomainMetadata> operand_side_metadata,
      std::unique_ptr<DomainMetadata> user_side_metadata);

  // Creates a fusion instruction. A fusion instruction contains one or more
  // fused instructions forming an expression with a single root
  // "fused_root". Additional instructions can be added to the fusion
  // instruction with the method FuseInstruction.
  static std::unique_ptr<HloInstruction> CreateFusion(
      const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root);

  static std::unique_ptr<HloInstruction> CreateFusion(
      const Shape& shape, FusionKind fusion_kind,
      absl::Span<HloInstruction* const> operands,
      HloComputation* fusion_computation);

  // Creates a call instruction that applies the given computation on the
  // given operands. "shape" is the resultant shape.
  static std::unique_ptr<HloInstruction> CreateCall(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      HloComputation* computation);

  // Creates a custom call instruction that applies the given custom call
  // target to the given operands. "opaque" can be an arbitrary string with a
  // backend-specific interpretation. "shape" is the resultant shape.
  static std::unique_ptr<HloInstruction> CreateCustomCall(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      absl::string_view custom_call_target, absl::string_view opaque = "");

  // Overload which constrains the layouts of the operand and result. 'shape'
  // and 'operand_shapes_with_layout' must have layouts.
  // 'operand_shapes_with_layout' must have a compatible element for each
  // operand.
  static std::unique_ptr<HloInstruction> CreateCustomCall(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      absl::string_view custom_call_target,
      absl::Span<const Shape> operand_shapes_with_layout,
      absl::string_view opaque = "");

  // Creates a tuple instruction with the given elements. This is a
  // convenience wrapper around CreateVariadic.
  static std::unique_ptr<HloInstruction> CreateTuple(
      absl::Span<HloInstruction* const> elements);

  // Creates a reverse instruction, which reverses the order of the elements
  // in the specified dimensions.
  static std::unique_ptr<HloInstruction> CreateReverse(
      const Shape& shape, HloInstruction* operand,
      absl::Span<const int64> dimensions);

  // Creates an AfterAll instruction used for joining or creating new values
  // of token type which thread through side-effecting operations. Operands
  // must all be tokens, and there must be at least one operand.
  static std::unique_ptr<HloInstruction> CreateAfterAll(
      absl::Span<HloInstruction* const> operands);
  // Creates an AfterAll instruction which creates a token type out of thin
  // air (no operands). This is a separate method from CreateAfterAll to
  // facilitate the removal of operand-less AfterAll instructions.
  // TODO(b/110532604): Remove this capability of creating a token from
  // nothing when we plumb a primordial token from the entry computation.
  static std::unique_ptr<HloInstruction> CreateToken();

  // Returns the opcode for this instruction.
  HloOpcode opcode() const { return opcode_; }

  // Returns true if this instruction has a side effect, irrespective of
  // whether any called computations may contain an instruction with side
  // effects.
  bool HasSideEffectNoRecurse() const;

  // Returns true if this instruction has a side effect. An instruction has a
  // side effect if it uses certain opcodes or calls a computation with a side
  // effect.
  bool HasSideEffect() const;

  // Returns the result shape of this instruction.
  const Shape& shape() const;

  // Returns the (mutable) result shape of this instruction.
  Shape* mutable_shape() { return &shape_; }

  // Returns the ith operand to this instruction.
  const HloInstruction* operand(int64 i) const;

  // Returns the ith operand to this instruction.
  HloInstruction* mutable_operand(int64 i);

  // Returns the number of operands to this instruction.
  int64 operand_count() const { return operands_.size(); }

  // Returns the vector of operands of this instruction.
  using InstructionVector = absl::InlinedVector<HloInstruction*, 2>;
  const InstructionVector& operands() const { return operands_; }

  // Returns the vector of unique operands, in the same order they are found
  // within the operand vector.
  InstructionVector unique_operands() const;

  // Returns the index of 'target' in the operands sequence.
  // Precondition: target must be an operand (or a fatal error will occur).
  int64 operand_index(const HloInstruction* target) const;

  // Returns the number of users of this instruction.
  int64 user_count() const { return users_.size(); }

  // Returns the users of this instruction.
  const std::vector<HloInstruction*>& users() const { return users_; }

  // Returns true if this instruction is a user of 'instruction'.
  bool IsUserOf(const HloInstruction* instruction) const {
    return ContainsKey(instruction->user_set_, this);
  }

  // Adds a control dependency from this instruction to the given
  // instruction. This instruction becomes a control predecessor of
  // 'instruction', and 'instruction' becomes a control successor of this
  // instruction. Returns an error status if either of the given instructions
  // does not belong to the same computation.
  //
  // This is used to enforce an additional ordering requirement that is not
  // captured by normal data dependencies, such as ordering among Send or Recv
  // operations to avoid deadlock.
  Status AddControlDependencyTo(HloInstruction* instruction);

  // Removes a previously added control dependency from this instruction to
  // 'instruction'.
  Status RemoveControlDependencyTo(HloInstruction* instruction);

  // Drops all control predecessors and successors from this HLO instruction.
  Status DropAllControlDeps();

  // Copies the control predecessors and successors on this HLO instruction
  // to `inst`. Does not do a deep copy so this makes sense only if `inst`
  // and this HLO are in the same module.
  //
  // Depending on the use cases we see in practice, in the future we may
  // consider folding the logic here into Clone, CloneWithNewOperands and
  // ReplaceAllUsesWith by treating control dependencies like data
  // dependencies.
  Status CopyAllControlDepsFrom(const HloInstruction* inst);
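  // Example: control dependencies order side-effecting instructions without
  // adding data flow. A hedged sketch with illustrative names:
  //
  //   TF_RETURN_IF_ERROR(send_a->AddControlDependencyTo(send_b));
  //   // send_a is now a control predecessor of send_b: every schedule must
  //   // run send_a first, even though no value flows between the two.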
  // Returns the set of control predecessors (successors) of this
  // instruction. Control predecessors (successors) must execute before
  // (after) the current instruction.
  const std::vector<HloInstruction*>& control_predecessors() const {
    return control_predecessors_;
  }
  const std::vector<HloInstruction*>& control_successors() const {
    return control_successors_;
  }

  // Returns true if "other" performs the same computation as this
  // instruction.
  bool Identical(
      const HloInstruction& other,
      const std::function<bool(const HloInstruction*, const HloInstruction*)>&
          eq_operands = std::equal_to<const HloInstruction*>(),
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations = std::equal_to<const HloComputation*>(),
      bool layout_sensitive = true) const {
    // An instruction is always identical to itself.
    if (this == &other) {
      return true;
    }

    // Identical instructions must have the same opcode, shape, and identical
    // operands.
    if (opcode() != other.opcode()) {
      return false;
    }
    if (!(layout_sensitive ? ShapeUtil::Equal(shape(), other.shape())
                           : ShapeUtil::Compatible(shape(), other.shape()))) {
      return false;
    }
    if (operands().size() != other.operands().size()) {
      return false;
    }

    // Use an explicit loop rather than ContainerEquals, because copying
    // around std::functions may be too expensive in some cases.
    for (size_t i = 0; i < operands().size(); ++i) {
      if (!eq_operands(operand(i), other.operand(i))) {
        return false;
      }
    }

    if (backend_config_ != other.backend_config_) {
      return false;
    }

    return IdenticalSlowPath(other, eq_computations);
  }

  // Returns whether the instruction has a constant operand.
  bool HasConstantOperand() const;

  // Replaces the use of this instruction in "user" with "new_producer". Note
  // that there might be multiple uses of this instruction in "user"; all will
  // be replaced.
  //
  // If user is a fusion instruction, this function will remove any duplicated
  // operands of it which could be created due to this replacement.
  Status ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer);

  // Replaces the specified operand with new_operand.
  //
  // This function does NOT remove duplicated operands even if this
  // instruction is a fusion, so that the existing operand numbers do not
  // change.
  Status ReplaceOperandWith(int64 operand_no, HloInstruction* new_operand);

  // Replaces all uses of this instruction with the new producer. If
  // new_producer is a user of this instruction then new_producer remains a
  // use of this instruction to avoid introducing cycles into the graph.
  //
  // If this instruction is the root of its computation, sets the
  // computation's root to new_producer.
  //
  // If a user is a fusion instruction, this function will remove any
  // duplicated operands of it which could be created due to this replacement.
  Status ReplaceAllUsesWith(HloInstruction* new_producer);

  // Performs a postorder DFS visit using this node as the root. If
  // call_finish_visit is true, then DfsHloVisitor::FinishVisit is called when
  // complete. If ignore_control_predecessors is true, instructions only
  // reachable via control dependencies will not be visited, and the postorder
  // will not take control dependencies into account. It is as if the control
  // dependencies didn't exist in the graph at all.
  template <typename HloInstructionPtr>
  Status Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor,
                bool call_finish_visit = true,
                bool ignore_control_predecessors = false);
  Status Accept(ConstDfsHloVisitor* visitor, bool call_finish_visit = true,
                bool ignore_control_predecessors = false) const {
    return const_cast<HloInstruction*>(this)->Accept(
        visitor, call_finish_visit, ignore_control_predecessors);
  }
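  // Example: for simple traversals the std::function-based Accept overload
  // declared just below avoids writing a DfsHloVisitor subclass. A minimal
  // sketch:
  //
  //   TF_RETURN_IF_ERROR(root->Accept([](HloInstruction* hlo) {
  //     VLOG(1) << hlo->name();
  //     return Status::OK();
  //   }));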
  // Same as Accept() above, but the order of operand and control predecessor
  // visitation is determined by the given operand order; if compare(A, B) ==
  // true, A is visited before B.
  using CompareFunction =
      std::function<bool(const HloInstruction*, const HloInstruction*)>;
  Status AcceptWithOperandOrder(DfsHloVisitor* visitor,
                                const CompareFunction& operand_order,
                                bool call_finish_visit = true);

  // Performs a postorder DFS visit using this node as the root. Calls the
  // given visitor function at each instruction.
  Status Accept(const std::function<Status(HloInstruction*)>& visitor_func);
  Status Accept(
      const std::function<Status(const HloInstruction*)>& visitor_func) const;

  // Visits all instructions rooted at this instruction using the given
  // visitor in the given order. 'order' must contain at least the set of
  // instructions rooted at this node (ie, those accessible from a DFS
  // traversal from this instruction). Instructions contained in 'order'
  // which are not in the set of instructions rooted at this node are
  // ignored. 'order' must also be a valid topological sort of these
  // instructions (defs appear before uses) though need not be a DFS
  // post-order.
  Status AcceptOrdered(DfsHloVisitor* visitor,
                       const std::vector<const HloInstruction*>& order);

  // Visit this instruction and only this instruction with the given visitor.
  template <typename HloInstructionPtr>
  Status Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor);

  // Returns the first non-GetTupleElement ancestor instruction of 'hlo'.
  // If the first non-GTE ancestor is tuple-shaped, populates 'index' with the
  // (possibly nested) tuple indices used on the path from ancestor to 'hlo'.
  std::pair<const HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex()
      const;
  std::pair<HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex() {
    auto rv =
        const_cast<const HloInstruction*>(this)->LatestNonGteAncestorAndIndex();
    return {const_cast<HloInstruction*>(rv.first), rv.second};
  }

  // Same as LatestNonGteAncestorAndIndex, but just returns the
  // HloInstruction.
  const HloInstruction* LatestNonGteAncestor() const;
  HloInstruction* LatestNonGteAncestor() {
    return const_cast<HloInstruction*>(
        const_cast<const HloInstruction*>(this)->LatestNonGteAncestor());
  }

  // Gets/sets the to_apply HloComputation for Call, Map, Reduce, etc.
  // The setter should only be called by HloModule or HloComputation methods.
  //
  // Precondition: The instruction has a valid to_apply_ field.
  HloComputation* to_apply() const;
  void set_to_apply(HloComputation* to_apply);

  // Gets/sets the while_condition or while_body HloComputation for While.
  // The setters should only be called by HloModule or HloComputation methods.
  //
  // Precondition: The instruction is a While instruction.
  HloComputation* while_condition() const;
  HloComputation* while_body() const;
  void set_while_condition(HloComputation* while_condition);
  void set_while_body(HloComputation* while_body);

  // Gets/sets the true and false HloComputation for Conditional. The setters
  // should only be called by HloModule or HloComputation methods.
  //
  // Precondition: The instruction is a Conditional instruction.
  HloComputation* true_computation() const;
  HloComputation* false_computation() const;
  void set_true_computation(HloComputation* true_computation);
  void set_false_computation(HloComputation* false_computation);

  // Returns a string for the signature of this instruction if considered as a
  // function, e.g. the signature of an F32 add is (F32, F32) -> F32.
  string SignatureString() const;

  // Returns a debugging string that represents this instruction.
  //
  // (We express the default options using an overload rather than a default
  // param because gdb ignores default params, but does resolve overloads.)
  //
  // TODO(b/73348663): Make ToString() adaptive to the size of the string by
  // default, backing off on providing full information for very large
  // strings, or provide a different name for a ToString-like function that
  // does that.
  string ToString() const { return ToString(HloPrintOptions()); }
  string ToString(const HloPrintOptions& options) const;

  // Components of the ToString() representation:

  // Returns a string representation of the operand list.
  string OperandsToString(const HloPrintOptions& options) const;

  // Returns string representation of op-specific attributes.
  std::vector<string> ExtraAttributesToString(
      const HloPrintOptions& options) const;

  // As ToString, but returns a shorter string.
  string ToShortString() const;

  // Returns a serialized representation of this instruction.
  virtual HloInstructionProto ToProto() const;

  // Returns a category for the HLO. This could be something like
  // "convolution" or "elementwise".
  virtual string ToCategory() const;

  // Returns a logging instruction, if the output of this instruction is
  // logged.
  //
  // Postcondition: retval == nullptr || retval->opcode() == HloOpcode::kTrace
  HloInstruction* tracing() const;
  void set_tracing(HloInstruction* trace_instruction);

  // Returns true if this instruction is fused, i.e., contained within a
  // fusion instruction.
  bool IsFused() const;

  // Returns true if this instruction can be legally fused into a fusion
  // instruction.
  bool IsFusible() const;

  // Returns the sharding applied to this operator.
  // REQUIRES: has_sharding() is true.
  const HloSharding& sharding() const {
    CHECK(has_sharding());
    return *sharding_;
  }
  std::shared_ptr<const HloSharding> sharding_ptr() const { return sharding_; }

  // Returns the sharding applied to this operator, or default_ if none
  // exists.
  const HloSharding& sharding_or_default(const HloSharding& default_) const {
    return sharding_ ? *sharding_ : default_;
  }

  // Returns the sharding unique device, if any.
  absl::optional<int64> sharding_unique_device() const {
    if (sharding_ == nullptr) {
      return absl::optional<int64>();
    }
    return sharding_->UniqueDevice();
  }

  // Sets the sharding of this operator. Should only be called by HloModule
  // or HloComputation methods.
  void set_sharding(const HloSharding& sharding) {
    sharding_ = std::make_shared<HloSharding>(sharding);
  }
  void set_sharding(std::shared_ptr<const HloSharding> sharding) {
    sharding_ = std::move(sharding);
  }
  void set_single_sharding(const HloSharding& sharding);

  // Sets a sharding that assigns the current instruction to device.
  void set_device_sharding(int64 device) {
    set_single_sharding(HloSharding::AssignDevice(device));
  }

  // Remove any sharding from this operator.
  void clear_sharding() { sharding_ = nullptr; }

  // Return true if this operator has a sharding assigned.
  bool has_sharding() const { return sharding_ != nullptr; }

  // Checks whether the instruction has compatible sharding with the other
  // instruction.
  bool has_compatible_sharding(const HloInstruction* other) const {
    if (!has_sharding()) {
      return !other->has_sharding();
    }
    return other->has_sharding() ? sharding() == other->sharding() : false;
  }

  // When creating a new instruction which either replaces, or shifts up
  // (kCopy insertion case), another instruction, we need to make sure that
  // certain properties of the new instruction are copied into the derived
  // one. As of today, the metadata and sharding will be propagated to the
  // derived instruction.
  void SetupDerivedInstruction(HloInstruction* derived_instruction) const;
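  // Example: assigning and querying a single-device sharding with the
  // accessors above. A hedged sketch:
  //
  //   instr->set_device_sharding(0);  // place instr on device 0
  //   if (instr->has_sharding()) {
  //     // Holds 0 here, since the sharding names exactly one device.
  //     absl::optional<int64> device = instr->sharding_unique_device();
  //   }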
  // Clones the HLO instruction. The clone will have the same opcode, shape,
  // and operands. After creation the clone has no uses. "this" (the
  // instruction cloned from) is not changed. Suffix is the string to append
  // to the name of the instruction to form the name of the cloned
  // instruction. Ignores the control predecessors and successors of this HLO
  // instruction.
  std::unique_ptr<HloInstruction> Clone(
      const string& suffix = "clone",
      HloCloneContext* context = nullptr) const;

  // Clones the HLO instruction as above but with new shape and operands.
  std::unique_ptr<HloInstruction> CloneWithNewOperands(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context = nullptr) const;
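  // Example: CloneWithNewOperands is the usual way to rebuild an instruction
  // with substituted inputs during a rewrite. A hedged sketch with
  // illustrative names:
  //
  //   std::unique_ptr<HloInstruction> new_add =
  //       old_add->CloneWithNewOperands(old_add->shape(),
  //                                     {new_lhs, new_rhs});
  //   // The clone starts with no users; the caller is responsible for
  //   // wiring it into a computation (e.g. via
  //   // HloComputation::AddInstruction).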
  // Returns the computations this instruction directly calls (if any).
  const std::vector<HloComputation*>& called_computations() const {
    return called_computations_;
  }

  // Replaces all called computations based on a map function. This is needed
  // when we clone hlo_computations and want the instructions to point to the
  // newly cloned nodes.
  void ReplaceCalledComputations(
      std::function<HloComputation*(HloComputation*)> map_function) {
    for (int64 i = 0; i < called_computations_.size(); ++i) {
      called_computations_[i] = map_function(called_computations_[i]);
    }
  }

  // Clears out the called computations.
  //
  // This is, in particular, necessary when inlining function bodies into
  // their caller. If there were side-effecting operations in the called
  // computations, the call itself is considered side-effecting and thus
  // cannot be removed. By clearing out the computations, we reflect the fact
  // that all side-effecting properties have been reflected in the caller,
  // and make the call HLO removable.
  void ClearCalledComputations() { called_computations_.clear(); }

  // Returns true if this instruction performs an elementwise operation on
  // `operand_idx`-th operand. An instruction is elementwise on an operand
  // iff, after performing necessary implicit broadcast
  // (cs/IrArray::EmitArrayElementAddress), to compute the output at index
  // {i_0,i_1,...,i_n}, the only element required from the operand (if any)
  // is the element at {i_0,i_1,...,i_n}.
  //
  // Note on performance: when this instruction is kFusion, this method, in
  // the worst case, scans all fused instructions. We could speed this up by
  // caching.
  bool IsElementwiseOnOperand(int64 operand_idx) const;

  // Returns true if this instruction is elementwise on all its operands.
  bool IsElementwise() const;

  // Returns true if this is a cross-module all-reduce instruction.
  bool IsCrossModuleAllReduce() const;

  // Returns true if this elementwise instruction implicitly broadcasts
  // operand `operand_idx`.
  //
  // Precondition: this instruction should be an elementwise operation.
  bool ImplicitlyBroadcastsOperand(int64 operand_idx) const;

  // Returns true if this instruction is binary and elementwise.
  bool IsElementwiseBinary() const;

  // Returns whether this instruction may reuse elements of its `i`th operand.
  bool ReusesOperandElements(int64 i) const {
    return OperandElementUse(i) == UseKind::kReuse;
  }

  // Returns the indices that the given operand appear in the operand list of
  // this instruction. Note that an instruction can use the same operand
  // multiple times.
  std::vector<int64> OperandIndices(const HloInstruction* operand) const;

  // Convenience helper for ShapeUtil::InsertedOrDeleted1SizedDimensions. If
  // this reshape merely inserts or deletes 1-sized dimensions, return the
  // input indices of the deleted dimensions and the output indices of the
  // inserted dimensions.
  //
  // Precondition: this op must be a reshape.
  std::tuple<bool, std::vector<int64>, std::vector<int64>>
  ReshapeMerelyInsertsOrDeletes1SizedDimensions() const;

  // Gets the string identifier for this instruction.
  const string& name() const { return name_; }

  // Sets the string identifier for this instruction. Name will be sanitized
  // to match the regexp "[a-zA-Z_][a-zA-Z0-9_.-]*".
  void SetAndSanitizeName(const string& name) {
    name_ = NameUniquer::GetSanitizedName(name);
  }

  // Use the given NameUniquer to select a unique name for the instruction
  // based on the instruction's existing name.
  void UniquifyName(NameUniquer* name_uniquer);

  // Set the unique id for this instruction to "id"
  void SetUniqueId(int id) {
    CHECK_EQ(unique_id_, -1);  // Should not be assigned already
    CHECK_GE(id, 0);
    unique_id_ = id;
  }

  // Return the unique ID assigned to this node via SetUniqueId (or -1
  // if no id has been assigned yet).
  int unique_id() const { return unique_id_; }

  // Returns the backend-specific configuration for how a backend should
  // compile this HLO. The meaning of the field is backend specific. Not for
  // use before or during general HLO optimization, since HLO optimizations
  // do not preserve this field and they cannot interpret it due to its
  // meaning being backend specific.
  //
  // ConfigProto should be a protobuf Message type.
  template <typename ConfigProto>
  StatusOr<ConfigProto> backend_config() const {
    ConfigProto proto;
    TF_RETURN_IF_ERROR(GetBackendConfigInternal(&proto));
    return std::move(proto);
  }
  Status set_backend_config(const tensorflow::protobuf::Message& proto);

  // Getter/setter for raw JSON-encoded backend config. Prefer the functions
  // above that deal in proto Messages where possible.
  const string& raw_backend_config_string() const { return backend_config_; }
  void set_raw_backend_config_string(string config_str) {
    backend_config_ = std::move(config_str);
  }

  // Returns a string representation of a proto in the format used by
  // raw_backend_config_string.
  //
  // This is morally equivalent to:
  //
  //   HloInstruction instr;
  //   TF_RETURN_IF_ERROR(instr.set_backend_config(proto));
  //   return instr.raw_backend_config_string();
  //
  static StatusOr<string> BackendConfigToRawString(
      const tensorflow::protobuf::Message& proto);
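  // Example: round-tripping a backend-specific proto through the
  // backend_config accessors above. A hedged sketch; `MyBackendConfig` and
  // its field are hypothetical:
  //
  //   MyBackendConfig config;
  //   config.set_unroll_factor(4);  // hypothetical field
  //   TF_RETURN_IF_ERROR(instr->set_backend_config(config));
  //   TF_ASSIGN_OR_RETURN(MyBackendConfig readback,
  //                       instr->backend_config<MyBackendConfig>());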
  // Returns the information used to tell the implementation information
  // about what sort of precision is requested. The meaning of the field is
  // backend specific. At the moment, it is only supported for kConvolution
  // and kDot. Transformations on one kDot or kConvolution to another will
  // preserve this information. Transformations to other HLOs will not
  // preserve this information but it is presumed that the alternate lowering
  // is strictly superior.
  // Precondition: opcode must be kConvolution or kDot.
  const PrecisionConfig& precision_config() const;

  // Sets the debug metadata for this instruction.
  void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; }
  const OpMetadata& metadata() const { return metadata_; }

  // Set/get the computation containing this instruction. set_parent should
  // only be called by HloComputation methods which add/remove instructions
  // to computations.
  void set_parent(HloComputation* computation) { parent_ = computation; }
  const HloComputation* parent() const { return parent_; }
  HloComputation* parent() { return parent_; }

  // Returns the module for this instruction.
  HloModule* GetModule() const;

  // Returns whether we could assign input and output layouts to this
  // instruction to make it a bitcast.
  bool CouldBeBitcast() const;

  // Get/Set the number of partitions per outer dimension (in order, starting
  // with outer-most dimension first). Currently used by the parallel cpu
  // backend to partition HLOs into parallel tasks.
  //
  // TODO(b/62783254) Replace these methods with a more general way to
  // annotate HLOs with backend-specific information.
  const std::vector<int64>& outer_dimension_partitions() const {
    return outer_dimension_partitions_;
  }
  void set_outer_dimension_partitions(
      const std::vector<int64>& outer_dimension_partitions);

  // Old methods kept for smooth subclassing transition BEGIN.
  // TODO(b/80131774): Remove this code.

  // Delegates to HloBatchNormInstruction::feature_index.
  int64 feature_index() const;

  // Delegates to HloBatchNormInstruction::epsilon.
  float epsilon() const;

  // Delegates to HloFftInstruction::fft_type.
  FftType fft_type() const;

  // Delegates to HloFftInstruction::fft_length.
  const std::vector<int64>& fft_length() const;

  // Delegates to HloSendRecvInstruction::channel_id.
  int64 channel_id() const;

  // Returns the dimension sizes or numbers associated with this instruction.
  virtual const std::vector<int64>& dimensions() const {
    LOG(FATAL) << "Unimplemented method.";
  }
  virtual int64 dimensions(int64 index) const {
    LOG(FATAL) << "Unimplemented method.";
  }

  // Delegates to HloConcatenateInstruction::concatenate_dimension.
  int64 concatenate_dimension() const;

  // Returns whether this instruction does a rank-2 transposition.
  bool IsRank2Transpose() const;

  // Delegates to HloSliceInstruction::slice_start.
  int64 slice_starts(int64 dimension) const;
  const std::vector<int64>& slice_starts() const;

  // Delegates to HloSliceInstruction::slice_limits.
  int64 slice_limits(int64 dimension) const;
  const std::vector<int64>& slice_limits() const;

  // Delegates to HloSliceInstruction::slice_strides.
  int64 slice_strides(int64 dimension) const;
  const std::vector<int64>& slice_strides() const;

  // Returns the literal associated with this instruction.
  const Literal& literal() const;

  // Returns whether the instruction is a constant.
  bool IsConstant() const;

  // Delegate to HloConstantInstruction::RelayoutConstant.
  void RelayoutConstant(const Layout& new_layout,
                        const ShapeIndex& shape_index = {});

  // Delegates to HloTraceInstruction::TracingTag.
  string TracingTag() const;

  // Delegates to HloFusionInstruction::AddFusionOperand.
  HloInstruction* AddFusionOperand(HloInstruction* new_operand);

  // Delegates to HloFusionInstruction::MergeFusionInstruction.
  void MergeFusionInstruction(HloInstruction* instruction_to_merge);

  // Delegates to HloFusionInstruction::MergeFusionInstructionIntoMultiOutput.
  void MergeFusionInstructionIntoMultiOutput(
      HloInstruction* instruction_to_merge);

  // Delegates to HloFusionInstruction::FuseInstruction.
  HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse);

  // Delegates to HloFusionInstruction::FuseInstructionIntoMultiOutput.
  HloInstruction* FuseInstructionIntoMultiOutput(
      HloInstruction* instruction_to_fuse);

  // Delegates to HloFusionInstruction::fused_instructions_computation.
  HloComputation* fused_instructions_computation() const;

  // Delegates to HloFusionInstruction::fused_expression_root.
  HloInstruction* fused_expression_root() const;

  // Delegates to HloFusionInstruction::fused_instructions.
  const tensorflow::gtl::iterator_range<UnwrappingIterator<
      std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
  fused_instructions() const;
  const tensorflow::gtl::iterator_range<
      UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
  fused_instructions();

  // Delegates to HloFusionInstruction::fused_instruction_count.
  int64 fused_instruction_count() const;

  // Delegates to HloFusionInstruction::fused_parameter.
  HloInstruction* fused_parameter(int64 parameter_number) const;

  // Delegates to HloFusionInstruction::fused_parameters.
  const std::vector<HloInstruction*>& fused_parameters() const;
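  // Example: the fused_instructions() ranges above work with range-based for
  // loops. A minimal sketch, assuming `fusion` is a kFusion instruction:
  //
  //   for (const HloInstruction* fused : fusion->fused_instructions()) {
  //     VLOG(2) << fused->ToShortString();
  //   }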
  // Returns true if this instruction is a fusion instruction that generates
  // multiple outputs.
  bool IsMultiOutputFusion() const;

  // Delegates to HloFusionInstruction::fusion_kind.
  FusionKind fusion_kind() const;

  // Delegates to HloFusionInstruction::set_fusion_kind.
  void set_fusion_kind(FusionKind kind);

  // Delegates to HloRngInstruction::random_distribution.
  RandomDistribution random_distribution() const;

  // Delegates to HloParameterInstruction::parameter_number.
  int64 parameter_number() const;

  // Delegates to HloGetTupleElementInstruction::tuple_index.
  int64 tuple_index() const;

  // Delegates to HloReducePrecisionInstruction::exponent_bits.
  int32 exponent_bits() const;

  // Delegates to HloReducePrecisionInstruction::mantissa_bits.
  int32 mantissa_bits() const;

  // Delegates to HloInfeedInstruction::infeed_config.
  string infeed_config() const;

  // Delegates to HloInfeedInstruction::set_infeed_config.
  void set_infeed_config(const string& config);

  // Returns the config for the Outfeed instruction.
  const string& outfeed_config() const;

  // Returns the shape for the Outfeed instruction.
  const Shape& outfeed_shape() const;

  // Delegates to HloCollectiveInstruction::replica_groups.
  const std::vector<ReplicaGroup>& replica_groups() const;

  // Delegates to HloCollectivePermuteInstruction::source_target_pairs.
  const std::vector<std::pair<int64, int64>>& source_target_pairs() const;

  // Delegates to HloAllReduceInstruction::cross_replica_sum_barrier.
  string cross_replica_sum_barrier() const;
  void set_cross_replica_sum_barrier(const string& barrier);

  // Delegates to HloAllReduceInstruction::all_reduce_id.
  absl::optional<int64> all_reduce_id() const;

  // Returns data on the window in a windowed operation such as
  // convolution.
  virtual const Window& window() const {
    LOG(FATAL) << "Unimplemented method.";
  }

  // Sets the window data in a windowed operation such as convolution.
  virtual void set_window(const Window& window) {
    LOG(FATAL) << "Unimplemented method.";
  }

  // Returns data on the dimension numbers used for a convolution operation,
  // which may be a kConvolution instruction or a kCustomCall that implements
  // a convolution.
  const ConvolutionDimensionNumbers& convolution_dimension_numbers() const;

  // Sets the convolution dimension numbers on this instruction. In general
  // you shouldn't need to call this; instead, specify the convolution
  // dimension numbers when you create the instruction.
  void set_convolution_dimension_numbers(
      const ConvolutionDimensionNumbers& dnums);

  // The number of feature groups. Must be a divisor of the input feature
  // dimension and output feature dimension.
  int64 feature_group_count() const;
  void set_feature_group_count(int64 feature_group_count);

  // Delegates to HloSelectAndScatterInstruction::select.
  HloComputation* select() const;

  // Delegates to HloSelectAndScatterInstruction::scatter.
  HloComputation* scatter() const;

  // Delegates to HloSelectAndScatterInstruction::set_select.
  void set_select(HloComputation* computation);

  // Delegates to HloSelectAndScatterInstruction::set_scatter.
  void set_scatter(HloComputation* computation);

  // Delegates to HloCustomCallInstruction::custom_call_target.
  const string& custom_call_target() const;

  // Delegates to HloPadInstruction::padding_config.
  const PaddingConfig& padding_config() const;

  // Delegates to HloDynamicSliceInstruction::slice_sizes.
  int64 slice_sizes(int64 dimension) const;

  // Delegates to HloDynamicSliceInstruction::dynamic_slice_sizes.
  const std::vector<int64>& dynamic_slice_sizes() const;

  // Delegates to HloGatherInstruction::gather_dimension_numbers.
  // Delegates to HloGatherInstruction::gather_dimension_numbers.
  const GatherDimensionNumbers& gather_dimension_numbers() const;

  // Delegates to HloGatherInstruction::gather_slice_sizes.
  absl::Span<const int64> gather_slice_sizes() const;

  // Delegates to HloScatterInstruction::scatter_dimension_numbers().
  const ScatterDimensionNumbers& scatter_dimension_numbers() const;

  // Delegates to HloDotInstruction::dot_dimension_numbers().
  const DotDimensionNumbers& dot_dimension_numbers() const;

  // Delegates to HloDomainInstruction::operand_side_metadata().
  const DomainMetadata& operand_side_metadata() const;

  // Delegates to HloDomainInstruction::user_side_metadata().
  const DomainMetadata& user_side_metadata() const;

  // Old methods kept for smooth subclassing transition END.

 protected:
  enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse };

  // Helper class for computing OperandElementUse for kFusion.
  class FusionReusesParamElements;

  // Internal constructor for a given opcode/shape; other fields must be
  // filled in by factory methods.
  HloInstruction(HloOpcode opcode, const Shape& shape);

  // Appends operand to the list of operands and adds this instruction as a
  // user of the operand.
  void AppendOperand(HloInstruction* operand);

  void RemoveOperandAt(int index) {
    operands_.erase(operands_.begin() + index);
  }

  // Removes a list of operands with the given indices in ascending order.
  void RemoveOperandsAtAscendingIndices(
      absl::Span<const int> ascending_indices);

  void AppendComputation(HloComputation* computation) {
    called_computations_.push_back(computation);
  }

  void DetachFrom(HloInstruction* usee) { usee->RemoveUser(this); }

  void set_called_computation(int index, HloComputation* computation) {
    called_computations_[index] = computation;
  }

  // Indices of computations in called_computations_ for instructions which
  // call multiple computations.
  enum {
    // kWhile computations.
    kBodyComputationIndex = 0,
    kConditionComputationIndex = 1,

    // kSelectAndScatter computations.
    kSelectComputationIndex = 0,
    kScatterComputationIndex = 1,

    // kConditional computations.
    kTrueComputationIndex = 0,
    kFalseComputationIndex = 1,
  };

 private:
  // Implementation for non-common logic of CloneWithNewOperands.
  virtual std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const {
    // TODO(b/80131774): This should be pure virtual.
    LOG(FATAL) << "Unimplemented method.";
  }

  // Implementation for non-common logic of ExtraAttributesToString.
  virtual std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const {
    return {};
  }

  // Implementation for IsElementwise if operand_idx is nullopt and for
  // IsElementwiseOnOperand otherwise.
  //
  // NOTE: For all instructions other than kFusion, being elementwise on one
  // of the operands is equivalent to being elementwise on all the operands.
  virtual bool IsElementwiseImpl(
      const absl::optional<int64>& operand_idx) const;

  // Prints an instruction to a string.
  //
  // The canonical string representation needs to name operands and
  // instruction names in a consistent way. This is implemented through the
  // canonical_name_map.
  string ToStringWithCanonicalNameMap(
      const HloPrintOptions& options,
      CanonicalNameMap* canonical_name_map) const;

  // Prints an operand to a string.
  virtual string OperandsToStringWithCanonicalNameMap(
      const HloPrintOptions& options,
      CanonicalNameMap* canonical_name_map) const;

  // Allow HloComputation to access the ToStringWithCanonicalNameMap() and
  // OperandsToStringWithCanonicalNameMap() functions.
  friend class HloComputation;
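  // A comment-only sketch of how a subclass typically overrides the virtual
  // clone hook above. HloFooInstruction is hypothetical, used purely for
  // illustration:
  //
  //   std::unique_ptr<HloInstruction>
  //   HloFooInstruction::CloneWithNewOperandsImpl(
  //       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
  //       HloCloneContext* context) const {
  //     // Rebuild the same operation over the replacement operands.
  //     return absl::make_unique<HloFooInstruction>(shape, new_operands[0]);
  //   }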
  // See comments on Identical().
  virtual bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const;

  // Creates an n-ary elementwise operation.
  static std::unique_ptr<HloInstruction> CreateNary(
      const Shape& shape, HloOpcode opcode,
      absl::Span<HloInstruction* const> operands);

  // Adds a user for this instruction.
  void AddUser(HloInstruction* user);

  // Removes a user for this instruction.
  void RemoveUser(HloInstruction* user);

  // Returns how this instruction uses elements of its `i`th operand.
  UseKind OperandElementUse(int64 i) const;

  // Helper for implementing backend_config(). Parses backend_config_ into
  // the given proto.
  Status GetBackendConfigInternal(tensorflow::protobuf::Message* proto) const;

  int unique_id_;  // Unique to this HloInstruction within a HloModule.

  // Opcode for this instruction.
  HloOpcode opcode_;

  // Instruction operands.
  InstructionVector operands_;

  // The set of control predecessors of this instruction.
  // Note that the order of the instructions in the vector influences the
  // order computed in HloComputation::ComputeInstructionPostOrder, which may
  // influence the result of the compilation by changing the scheduling. We
  // are not sure if it matters.
  std::vector<HloInstruction*> control_predecessors_;

  // The users of this instruction. Users are HLOs where this instruction is
  // an operand. The vector users_ and the set user_set_ contain identical
  // members. The set enables fast membership testing and the vector enables
  // fast, stable iteration.
  std::vector<HloInstruction*> users_;
  std::unordered_set<const HloInstruction*> user_set_;

  // The set of control successors of this instruction.
  std::vector<HloInstruction*> control_successors_;

  // The computation in which this instruction is contained.
  HloComputation* parent_ = nullptr;

  // Result shape of this instruction.
  Shape shape_;

  // The sharding, if one exists.
  // Uses std::shared_ptr to allow reuse of the same sharding object between
  // HloInstructions and other components, as HloSharding can be very large
  // for many-element tuples.
  std::shared_ptr<const HloSharding> sharding_;

  // Computations called by this instruction.
  std::vector<HloComputation*> called_computations_;

  // A trace instruction that consumes this instruction.
  //
  // Invariant: if trace_instruction_ != nullptr, trace_instruction_ has this
  // as an operand.
  HloInstruction* trace_instruction_ = nullptr;

  // The backend-specific configuration for how a backend should compile this
  // HLO. See the documentation on backend_config().
  string backend_config_;

  // String identifier for instruction.
  string name_;

  // Metadata for debugging.
  OpMetadata metadata_;

  // The number of partitions per outer dimension (listed in order from
  // outer-most dimension first).
  std::vector<int64> outer_dimension_partitions_;

  TF_DISALLOW_COPY_AND_ASSIGN(HloInstruction);
};

string ToString(HloInstruction::FusionKind kind);
StatusOr<HloInstruction::FusionKind> StringToFusionKind(
    const string& kind_name);

// Custom (de)stringification functions for protos that live inside
// HloInstruction.
string PaddingConfigToString(const PaddingConfig& padding);
string OpMetadataToString(const OpMetadata& metadata);
string RandomDistributionToString(const RandomDistribution& distribution);
string PrecisionToString(const PrecisionConfig::Precision& precision);
string ConvolutionDimensionNumbersToString(
    const ConvolutionDimensionNumbers& dnums);
StatusOr<RandomDistribution> StringToRandomDistribution(const string& name);
StatusOr<PrecisionConfig::Precision> StringToPrecision(const string& name);

std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind);
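// A comment-only round-trip sketch for the stringification helpers above
// (the exact printed spelling, e.g. "kLoop", is the implementation's choice):
//
//   string s = ToString(HloInstruction::FusionKind::kLoop);
//   StatusOr<HloInstruction::FusionKind> kind = StringToFusionKind(s);
//   CHECK(kind.ok() &&
//         kind.ValueOrDie() == HloInstruction::FusionKind::kLoop);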
// Map classes that guarantee a deterministic iteration order when the key is
// an HloInstruction* or a const HloInstruction*.
// To make the iteration order over the map deterministic, the comparator
// should not be using the pointer values, but rather an intrinsic property of
// the hlo. Exception: null pointer values compare less than non-null.
struct HloPtrComparator {
  bool operator()(const HloInstruction* const& lhs,
                  const HloInstruction* const& rhs) const;
};

template <typename ValueT>
using HloInstructionMap = std::map<HloInstruction*, ValueT, HloPtrComparator>;

template <typename ValueT>
using ConstHloInstructionMap =
    std::map<const HloInstruction*, ValueT, HloPtrComparator>;

using HloInstructionSet = std::set<HloInstruction*, HloPtrComparator>;
using ConstHloInstructionSet =
    std::set<const HloInstruction*, HloPtrComparator>;

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
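// A comment-only usage sketch for the deterministic containers defined above:
// because HloPtrComparator orders keys by an intrinsic property of the HLO
// rather than by pointer value, iterating an HloInstructionMap yields the
// same order on every run, keeping passes that walk these maps deterministic.
// `instructions` below is an assumed, caller-provided range of
// HloInstruction pointers.
//
//   xla::HloInstructionMap<int> use_counts;
//   for (xla::HloInstruction* hlo : instructions) ++use_counts[hlo];
//   for (const auto& kv : use_counts) {
//     // Visited in a stable, run-to-run deterministic order.
//   }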