aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instructions.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.h92
1 files changed, 87 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index ec8a42bd3b..e4031f04d5 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -141,12 +141,15 @@ class HloSendRecvInstruction : public HloInstruction {
// channel.
int64 channel_id() const { return channel_id_; }
+ // Returns whether this send/recv instruction sends data to/from the host.
+ bool is_host_transfer() const { return is_host_transfer_; }
+
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
protected:
explicit HloSendRecvInstruction(HloOpcode opcode, const Shape& shape,
- int64 channel_id);
+ int64 channel_id, bool is_host_transfer);
private:
std::vector<string> ExtraAttributesToStringImpl(
@@ -157,11 +160,15 @@ class HloSendRecvInstruction : public HloInstruction {
eq_computations) const override;
// Represents a unique identifier for each Send/Recv instruction pair.
int64 channel_id_;
+
+ // Whether this send/recv instruction sends data to/from the host.
+ bool is_host_transfer_;
};
class HloSendInstruction : public HloSendRecvInstruction {
public:
- explicit HloSendInstruction(HloInstruction* operand, int64 channel_id);
+ explicit HloSendInstruction(HloInstruction* operand, HloInstruction* token,
+ int64 channel_id, bool is_host_transfer);
private:
// Implementation for non-common logic of CloneWithNewOperands.
@@ -173,7 +180,8 @@ class HloSendInstruction : public HloSendRecvInstruction {
class HloSendDoneInstruction : public HloSendRecvInstruction {
public:
- explicit HloSendDoneInstruction(HloSendInstruction* operand);
+ explicit HloSendDoneInstruction(HloSendInstruction* operand,
+ bool is_host_transfer);
private:
// Implementation for non-common logic of CloneWithNewOperands.
@@ -185,7 +193,8 @@ class HloSendDoneInstruction : public HloSendRecvInstruction {
class HloRecvInstruction : public HloSendRecvInstruction {
public:
- explicit HloRecvInstruction(const Shape& shape, int64 channel_id);
+ explicit HloRecvInstruction(const Shape& shape, HloInstruction* token,
+ int64 channel_id, bool is_host_transfer);
private:
// Implementation for non-common logic of CloneWithNewOperands.
@@ -197,7 +206,8 @@ class HloRecvInstruction : public HloSendRecvInstruction {
class HloRecvDoneInstruction : public HloSendRecvInstruction {
public:
- explicit HloRecvDoneInstruction(HloRecvInstruction* operand);
+ explicit HloRecvDoneInstruction(HloRecvInstruction* operand,
+ bool is_host_transfer);
private:
// Implementation for non-common logic of CloneWithNewOperands.
@@ -347,6 +357,35 @@ class HloReduceInstruction : public HloInstruction {
std::vector<int64> dimensions_;
};
+class HloSortInstruction : public HloInstruction {
+ public:
+ explicit HloSortInstruction(const Shape& shape, int64 dimension,
+ HloInstruction* keys,
+ HloInstruction* values = nullptr);
+ // Returns the dimension sizes or numbers associated with this instruction.
+ const std::vector<int64>& dimensions() const override { return dimensions_; }
+ int64 dimensions(int64 index) const override { return dimensions()[index]; }
+ // Returns the sort dimension for this instruction
+ int64 sort_dimension() { return dimensions(0); }
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ std::vector<int64> dimensions_;
+};
+
class HloTransposeInstruction : public HloInstruction {
public:
explicit HloTransposeInstruction(
@@ -1117,6 +1156,49 @@ class HloDynamicSliceInstruction : public HloInstruction {
// ('start' is specified dynamically in the second operand of the operation).
std::vector<int64> dynamic_slice_sizes_;
};
+
+class HloGatherInstruction : public HloInstruction {
+ public:
+ explicit HloGatherInstruction(
+ const Shape& shape, HloInstruction* operand,
+ HloInstruction* gather_indices,
+ const GatherDimensionNumbers& gather_dim_numbers,
+ tensorflow::gtl::ArraySlice<int64> window_bounds);
+ const GatherDimensionNumbers& gather_dimension_numbers() const {
+ CHECK(gather_dimension_numbers_ != nullptr);
+ return *gather_dimension_numbers_;
+ }
+ tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const {
+ return gather_window_bounds_;
+ }
+ // Returns the dump string of the gather dimension numbers.
+ string GatherDimensionNumbersToString() const;
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ // Creates an instance of GatherDimensionNumbers.
+ static GatherDimensionNumbers MakeGatherDimNumbers(
+ tensorflow::gtl::ArraySlice<int64> output_window_dims,
+ tensorflow::gtl::ArraySlice<int64> elided_window_dims,
+ tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims,
+ int64 index_vector_dim);
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_;
+ std::vector<int64> gather_window_bounds_;
+};
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_