aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instructions.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.h191
1 files changed, 183 insertions, 8 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index ec8a42bd3b..9586ad6673 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -141,12 +141,15 @@ class HloSendRecvInstruction : public HloInstruction {
// channel.
int64 channel_id() const { return channel_id_; }
+ // Returns whether this send/recv instruction sends data to/from the host.
+ bool is_host_transfer() const { return is_host_transfer_; }
+
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
protected:
explicit HloSendRecvInstruction(HloOpcode opcode, const Shape& shape,
- int64 channel_id);
+ int64 channel_id, bool is_host_transfer);
private:
std::vector<string> ExtraAttributesToStringImpl(
@@ -157,11 +160,15 @@ class HloSendRecvInstruction : public HloInstruction {
eq_computations) const override;
// Represents a unique identifier for each Send/Recv instruction pair.
int64 channel_id_;
+
+ // Whether this send/recv instruction sends data to/from the host.
+ bool is_host_transfer_;
};
class HloSendInstruction : public HloSendRecvInstruction {
public:
- explicit HloSendInstruction(HloInstruction* operand, int64 channel_id);
+ explicit HloSendInstruction(HloInstruction* operand, HloInstruction* token,
+ int64 channel_id, bool is_host_transfer);
private:
// Implementation for non-common logic of CloneWithNewOperands.
@@ -173,7 +180,8 @@ class HloSendInstruction : public HloSendRecvInstruction {
class HloSendDoneInstruction : public HloSendRecvInstruction {
public:
- explicit HloSendDoneInstruction(HloSendInstruction* operand);
+ explicit HloSendDoneInstruction(HloSendInstruction* operand,
+ bool is_host_transfer);
private:
// Implementation for non-common logic of CloneWithNewOperands.
@@ -185,7 +193,8 @@ class HloSendDoneInstruction : public HloSendRecvInstruction {
class HloRecvInstruction : public HloSendRecvInstruction {
public:
- explicit HloRecvInstruction(const Shape& shape, int64 channel_id);
+ explicit HloRecvInstruction(const Shape& shape, HloInstruction* token,
+ int64 channel_id, bool is_host_transfer);
private:
// Implementation for non-common logic of CloneWithNewOperands.
@@ -197,7 +206,8 @@ class HloRecvInstruction : public HloSendRecvInstruction {
class HloRecvDoneInstruction : public HloSendRecvInstruction {
public:
- explicit HloRecvDoneInstruction(HloRecvInstruction* operand);
+ explicit HloRecvDoneInstruction(HloRecvInstruction* operand,
+ bool is_host_transfer);
private:
// Implementation for non-common logic of CloneWithNewOperands.
@@ -214,8 +224,7 @@ class HloAllReduceInstruction : public HloInstruction {
HloComputation* reduce_computation,
tensorflow::gtl::ArraySlice<int64> replica_group_ids,
tensorflow::StringPiece barrier,
- const tensorflow::gtl::optional<int64>& all_reduce_id =
- tensorflow::gtl::nullopt);
+ const tensorflow::gtl::optional<int64>& all_reduce_id);
// Returns the group ids of each replica for CrossReplicaSum op.
const std::vector<int64>& replica_group_ids() const {
@@ -264,6 +273,47 @@ class HloAllReduceInstruction : public HloInstruction {
tensorflow::gtl::optional<int64> all_reduce_id_;
};
+class HloAllToAllInstruction : public HloInstruction {
+ public:
+ explicit HloAllToAllInstruction(
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operand,
+ const std::vector<ReplicaGroup>& replica_groups,
+ tensorflow::StringPiece barrier);
+
+ const std::vector<ReplicaGroup>& replica_groups() const {
+ return replica_groups_;
+ }
+
+ // TODO(b/110096724): rename this.
+ void set_cross_replica_sum_barrier(string barrier) {
+ cross_replica_sum_barrier_ = barrier;
+ }
+ string cross_replica_sum_barrier() const {
+ return cross_replica_sum_barrier_;
+ }
+
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ std::vector<ReplicaGroup> replica_groups_;
+
+ // The string representation of the barrier config.
+ string cross_replica_sum_barrier_;
+};
+
class HloReverseInstruction : public HloInstruction {
public:
explicit HloReverseInstruction(const Shape& shape, HloInstruction* operand,
@@ -322,7 +372,7 @@ class HloConcatenateInstruction : public HloInstruction {
class HloReduceInstruction : public HloInstruction {
public:
explicit HloReduceInstruction(
- const Shape& shape, HloInstruction* arg, HloInstruction* init_value,
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> args,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
HloComputation* reduce_computation);
// Returns the dimension sizes or numbers associated with this instruction.
@@ -331,6 +381,47 @@ class HloReduceInstruction : public HloInstruction {
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
+ // Returns the input tensors to be reduced.
+ tensorflow::gtl::ArraySlice<HloInstruction*> inputs() const {
+ return tensorflow::gtl::ArraySlice<HloInstruction*>(operands(), 0,
+ operand_count() / 2);
+ }
+
+ // Returns the init values of the reduction.
+ tensorflow::gtl::ArraySlice<HloInstruction*> init_values() const {
+ return tensorflow::gtl::ArraySlice<HloInstruction*>(
+ operands(), operand_count() / 2, operand_count());
+ }
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ std::vector<int64> dimensions_;
+};
+
+class HloSortInstruction : public HloInstruction {
+ public:
+ explicit HloSortInstruction(const Shape& shape, int64 dimension,
+ HloInstruction* keys,
+ HloInstruction* values = nullptr);
+ // Returns the dimension sizes or numbers associated with this instruction.
+ const std::vector<int64>& dimensions() const override { return dimensions_; }
+ int64 dimensions(int64 index) const override { return dimensions()[index]; }
+ // Returns the sort dimension for this instruction
+ int64 sort_dimension() { return dimensions(0); }
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
private:
std::vector<string> ExtraAttributesToStringImpl(
const HloPrintOptions& options) const override;
@@ -496,6 +587,8 @@ class HloConstantInstruction : public HloInstruction {
explicit HloConstantInstruction(const Shape& shape);
// Returns the literal associated with this instruction.
const Literal& literal() const { return *literal_; }
+ // Returns whether there is literal associated with this instruction.
+ bool HasLiteral() const { return literal_ != nullptr; }
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
@@ -1117,6 +1210,88 @@ class HloDynamicSliceInstruction : public HloInstruction {
// ('start' is specified dynamically in the second operand of the operation).
std::vector<int64> dynamic_slice_sizes_;
};
+
+class HloGatherInstruction : public HloInstruction {
+ public:
+ explicit HloGatherInstruction(
+ const Shape& shape, HloInstruction* operand,
+ HloInstruction* gather_indices,
+ const GatherDimensionNumbers& gather_dim_numbers,
+ tensorflow::gtl::ArraySlice<int64> window_bounds);
+ const GatherDimensionNumbers& gather_dimension_numbers() const {
+ CHECK(gather_dimension_numbers_ != nullptr);
+ return *gather_dimension_numbers_;
+ }
+ tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const {
+ return gather_window_bounds_;
+ }
+ // Returns the dump string of the gather dimension numbers.
+ string GatherDimensionNumbersToString() const;
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ // Creates an instance of GatherDimensionNumbers.
+ static GatherDimensionNumbers MakeGatherDimNumbers(
+ tensorflow::gtl::ArraySlice<int64> output_window_dims,
+ tensorflow::gtl::ArraySlice<int64> elided_window_dims,
+ tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims,
+ int64 index_vector_dim);
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_;
+ std::vector<int64> gather_window_bounds_;
+};
+
+class HloScatterInstruction : public HloInstruction {
+ public:
+ explicit HloScatterInstruction(
+ const Shape& shape, HloInstruction* operand,
+ HloInstruction* scatter_indices, HloInstruction* updates,
+ HloComputation* update_computation,
+ const ScatterDimensionNumbers& scatter_dim_numbers);
+ const ScatterDimensionNumbers& scatter_dimension_numbers() const {
+ CHECK(scatter_dimension_numbers_ != nullptr);
+ return *scatter_dimension_numbers_;
+ }
+ // Returns the dump string of the scatter dimension numbers.
+ string ScatterDimensionNumbersToString() const;
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ // Creates an instance of ScatterDimensionNumbers.
+ static ScatterDimensionNumbers MakeScatterDimNumbers(
+ tensorflow::gtl::ArraySlice<int64> update_window_dims,
+ tensorflow::gtl::ArraySlice<int64> inserted_window_dims,
+ tensorflow::gtl::ArraySlice<int64> scatter_dims_to_operand_dims,
+ int64 index_vector_dim);
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ std::unique_ptr<ScatterDimensionNumbers> scatter_dimension_numbers_;
+};
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_