diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.h | 75 |
1 files changed, 41 insertions, 34 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 59a383218c..30bff286c2 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -33,7 +33,7 @@ limitations under the License. #include <vector> #include "tensorflow/compiler/xla/iterator_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" @@ -346,6 +346,9 @@ class HloInstruction { static std::unique_ptr<HloInstruction> CreateConstant( std::unique_ptr<Literal> literal); + // Creates an Iota instruction. + static std::unique_ptr<HloInstruction> CreateIota(const Shape& shape); + // Creates a get tuple element instruction. static std::unique_ptr<HloInstruction> CreateGetTupleElement( const Shape& shape, HloInstruction* operand, int64 index); @@ -477,7 +480,7 @@ class HloInstruction { const Shape& outfeed_shape, HloInstruction* operand, HloInstruction* token_operand, tensorflow::StringPiece outfeed_config); // Overload which does not require a token. - // TODO(b/80000000): Remove this overload when all uses of infeed are + // TODO(b/80000000): Remove this overload when all uses of outfeed are // converted to take tokens. static std::unique_ptr<HloInstruction> CreateOutfeed( const Shape& outfeed_shape, HloInstruction* operand, @@ -485,25 +488,30 @@ class HloInstruction { // Creates an asynchronous send instruction with the given channel id, which // initiates sending the operand data to a unique receive instruction in - // another computation that has the same channel id. - static std::unique_ptr<HloInstruction> CreateSend(HloInstruction* operand, - int64 channel_id); + // another computation that has the same channel id. 
If is_host_transfer is + // true, then this Send operation transfers data to the host. + static std::unique_ptr<HloInstruction> CreateSend( + HloInstruction* operand, HloInstruction* token, int64 channel_id, + bool is_host_transfer = false); // Blocks until data transfer for the Send instruction (operand) is complete. // The operand must be kSend. static std::unique_ptr<HloInstruction> CreateSendDone( - HloInstruction* operand); + HloInstruction* operand, bool is_host_transfer = false); // Creates an asynchronous receive instruction with the given channel id, // which allocates resources to receive data of the given shape from a unique - // send instruction in another computation that has the same channel id. - static std::unique_ptr<HloInstruction> CreateRecv(const Shape& shape, - int64 channel_id); + // send instruction in another computation that has the same channel id. If + // is_host_transfer is true, then this Recv operation transfers data from the + // host. + static std::unique_ptr<HloInstruction> CreateRecv( + const Shape& shape, HloInstruction* token, int64 channel_id, + bool is_host_transfer = false); // Blocks until data transfer for the Recv instruction (operand) is complete // and returns the receive buffer. The operand must be kRecv. static std::unique_ptr<HloInstruction> CreateRecvDone( - HloInstruction* operand); + HloInstruction* operand, bool is_host_transfer = false); // Creates a slice instruction, where the operand is sliced by the given // start/limit indices. @@ -611,6 +619,11 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> dimensions); + // Creates a sort op, with a keys operand, and an optional values operand. 
+ static std::unique_ptr<HloInstruction> CreateSort( + const Shape& shape, int64 dimension, HloInstruction* keys, + HloInstruction* values = nullptr); + // Creates a while instruction, given a condition computation, a body // computation, and the initial value for the input of the computations. For // example, shape: S32, condition: i -> i < 1000, body: i -> i * 2, init: 1 @@ -680,17 +693,18 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> dimensions); - // Creates a token instruction used for joining or creating new values of - // token type which thread through side-effecting operations. + // Creates an AfterAll instruction used for joining or creating new values of + // token type which thread through side-effecting operations. Operands must + // all be tokens, and there must be at least one operand. static std::unique_ptr<HloInstruction> CreateAfterAll( tensorflow::gtl::ArraySlice<HloInstruction*> operands); - // Creates an instance of GatherDimensionNumbers. - static GatherDimensionNumbers MakeGatherDimNumbers( - tensorflow::gtl::ArraySlice<int64> output_window_dims, - tensorflow::gtl::ArraySlice<int64> elided_window_dims, - tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims, - int64 index_vector_dim); + // Creates an AfterAll instruction which creates a token type out of thin air + // (no operands). This is a separate method from CreateAfterAll to facilitate + // the removal of operand-less AfterAll instructions. + // TODO(b/110532604): Remove this capability of creating a token from nothing + // when we plumb a primordial token from the entry computation. + static std::unique_ptr<HloInstruction> CreateToken(); // Returns the opcode for this instruction. HloOpcode opcode() const { return opcode_; } @@ -1066,19 +1080,6 @@ class HloInstruction { // Returns the dump string of the dot dimension numbers. 
string DotDimensionNumbersToString() const; - const GatherDimensionNumbers& gather_dimension_numbers() const { - CHECK(gather_dimension_numbers_ != nullptr); - return *gather_dimension_numbers_; - } - - tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const { - CHECK_EQ(opcode(), HloOpcode::kGather); - return gather_window_bounds_; - } - - // Returns the dump string of the gather dimension numbers. - string GatherDimensionNumbersToString() const; - // Clones the HLO instruction. The clone will have the same opcode, shape, and // operands. After creation the clone has no uses. "this" (the instruction // cloned from) is not changed. Suffix is the string to append to the name of @@ -1133,6 +1134,9 @@ class HloInstruction { // Returns true if this instruction is elementwise on all its operands. bool IsElementwise() const; + // Returns true if this is a cross-module all-reduce instruction. + bool IsCrossModuleAllReduce() const; + // Returns true if this elementwise instruction implicitly broadcasts operand // `operand_idx`. // @@ -1445,6 +1449,12 @@ class HloInstruction { // Delegates to HloDynamicSliceInstruction::dynamic_slice_sizes. const std::vector<int64>& dynamic_slice_sizes() const; + + // Delegates to HloGatherInstruction::gather_dimension_numbers. + const GatherDimensionNumbers& gather_dimension_numbers() const; + // Delegates to HloGatherInstruction::gather_window_bounds. + tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const; + // Old methods kept for smooth subclassing transition END. protected: @@ -1588,9 +1598,6 @@ class HloInstruction { // Describes the dimension numbers used for a dot. std::unique_ptr<DotDimensionNumbers> dot_dimension_numbers_; - std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_; - std::vector<int64> gather_window_bounds_; - // Used to tag kCopy instructions that are eligible for copy elision. bool copy_elision_allowed_ = true; |