diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instructions.h | 92 |
1 files changed, 87 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index ec8a42bd3b..e4031f04d5 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -141,12 +141,15 @@ class HloSendRecvInstruction : public HloInstruction { // channel. int64 channel_id() const { return channel_id_; } + // Returns whether this send/recv instruction sends data to/from the host. + bool is_host_transfer() const { return is_host_transfer_; } + // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; protected: explicit HloSendRecvInstruction(HloOpcode opcode, const Shape& shape, - int64 channel_id); + int64 channel_id, bool is_host_transfer); private: std::vector<string> ExtraAttributesToStringImpl( @@ -157,11 +160,15 @@ class HloSendRecvInstruction : public HloInstruction { eq_computations) const override; // Represents a unique identifier for each Send/Recv instruction pair. int64 channel_id_; + + // Whether this send/recv instruction sends data to/from the host. + bool is_host_transfer_; }; class HloSendInstruction : public HloSendRecvInstruction { public: - explicit HloSendInstruction(HloInstruction* operand, int64 channel_id); + explicit HloSendInstruction(HloInstruction* operand, HloInstruction* token, + int64 channel_id, bool is_host_transfer); private: // Implementation for non-common logic of CloneWithNewOperands. @@ -173,7 +180,8 @@ class HloSendInstruction : public HloSendRecvInstruction { class HloSendDoneInstruction : public HloSendRecvInstruction { public: - explicit HloSendDoneInstruction(HloSendInstruction* operand); + explicit HloSendDoneInstruction(HloSendInstruction* operand, + bool is_host_transfer); private: // Implementation for non-common logic of CloneWithNewOperands. @@ -185,7 +193,8 @@ class HloSendDoneInstruction : public HloSendRecvInstruction { class HloRecvInstruction : public HloSendRecvInstruction { public: - explicit HloRecvInstruction(const Shape& shape, int64 channel_id); + explicit HloRecvInstruction(const Shape& shape, HloInstruction* token, + int64 channel_id, bool is_host_transfer); private: // Implementation for non-common logic of CloneWithNewOperands. @@ -197,7 +206,8 @@ class HloRecvInstruction : public HloSendRecvInstruction { class HloRecvDoneInstruction : public HloSendRecvInstruction { public: - explicit HloRecvDoneInstruction(HloRecvInstruction* operand); + explicit HloRecvDoneInstruction(HloRecvInstruction* operand, + bool is_host_transfer); private: // Implementation for non-common logic of CloneWithNewOperands. @@ -347,6 +357,35 @@ class HloReduceInstruction : public HloInstruction { std::vector<int64> dimensions_; }; +class HloSortInstruction : public HloInstruction { + public: + explicit HloSortInstruction(const Shape& shape, int64 dimension, + HloInstruction* keys, + HloInstruction* values = nullptr); + // Returns the dimension sizes or numbers associated with this instruction. + const std::vector<int64>& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Returns the sort dimension for this instruction + int64 sort_dimension() { return dimensions(0); } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector<string> ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, + HloCloneContext* context) const override; + + std::vector<int64> dimensions_; +}; + class HloTransposeInstruction : public HloInstruction { public: explicit HloTransposeInstruction( @@ -1117,6 +1156,49 @@ class HloDynamicSliceInstruction : public HloInstruction { // ('start' is specified dynamically in the second operand of the operation). std::vector<int64> dynamic_slice_sizes_; }; + +class HloGatherInstruction : public HloInstruction { + public: + explicit HloGatherInstruction( + const Shape& shape, HloInstruction* operand, + HloInstruction* gather_indices, + const GatherDimensionNumbers& gather_dim_numbers, + tensorflow::gtl::ArraySlice<int64> window_bounds); + const GatherDimensionNumbers& gather_dimension_numbers() const { + CHECK(gather_dimension_numbers_ != nullptr); + return *gather_dimension_numbers_; + } + tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const { + return gather_window_bounds_; + } + // Returns the dump string of the gather dimension numbers. + string GatherDimensionNumbersToString() const; + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + // Creates an instance of GatherDimensionNumbers. + static GatherDimensionNumbers MakeGatherDimNumbers( + tensorflow::gtl::ArraySlice<int64> output_window_dims, + tensorflow::gtl::ArraySlice<int64> elided_window_dims, + tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims, + int64 index_vector_dim); + + private: + std::vector<string> ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + eq_computations) const override; + std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, + HloCloneContext* context) const override; + + std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_; + std::vector<int64> gather_window_bounds_; +}; + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ |