aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instruction.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h75
1 files changed, 41 insertions, 34 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 59a383218c..30bff286c2 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -33,7 +33,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/iterator_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
@@ -346,6 +346,9 @@ class HloInstruction {
static std::unique_ptr<HloInstruction> CreateConstant(
std::unique_ptr<Literal> literal);
+ // Creates an Iota instruction.
+ static std::unique_ptr<HloInstruction> CreateIota(const Shape& shape);
+
// Creates a get tuple element instruction.
static std::unique_ptr<HloInstruction> CreateGetTupleElement(
const Shape& shape, HloInstruction* operand, int64 index);
@@ -477,7 +480,7 @@ class HloInstruction {
const Shape& outfeed_shape, HloInstruction* operand,
HloInstruction* token_operand, tensorflow::StringPiece outfeed_config);
// Overload which does not require a token.
- // TODO(b/80000000): Remove this overload when all uses of infeed are
+ // TODO(b/80000000): Remove this overload when all uses of outfeed are
// converted to take tokens.
static std::unique_ptr<HloInstruction> CreateOutfeed(
const Shape& outfeed_shape, HloInstruction* operand,
@@ -485,25 +488,30 @@ class HloInstruction {
// Creates an asynchronous send instruction with the given channel id, which
// initiates sending the operand data to a unique receive instruction in
- // another computation that has the same channel id.
- static std::unique_ptr<HloInstruction> CreateSend(HloInstruction* operand,
- int64 channel_id);
+ // another computation that has the same channel id. If is_host_transfer is
+ // true, then this Send operation transfers data to the host.
+ static std::unique_ptr<HloInstruction> CreateSend(
+ HloInstruction* operand, HloInstruction* token, int64 channel_id,
+ bool is_host_transfer = false);
// Blocks until data transfer for the Send instruction (operand) is complete.
// The operand must be kSend.
static std::unique_ptr<HloInstruction> CreateSendDone(
- HloInstruction* operand);
+ HloInstruction* operand, bool is_host_transfer = false);
// Creates an asynchronous receive instruction with the given channel id,
// which allocates resources to receive data of the given shape from a unique
- // send instruction in another computation that has the same channel id.
- static std::unique_ptr<HloInstruction> CreateRecv(const Shape& shape,
- int64 channel_id);
+ // send instruction in another computation that has the same channel id. If
+ // is_host_transfer is true, then this Send operation transfers data from the
+ // host.
+ static std::unique_ptr<HloInstruction> CreateRecv(
+ const Shape& shape, HloInstruction* token, int64 channel_id,
+ bool is_host_transfer = false);
// Blocks until data transfer for the Recv instruction (operand) is complete
// and returns the receive buffer. The operand must be kRecv.
static std::unique_ptr<HloInstruction> CreateRecvDone(
- HloInstruction* operand);
+ HloInstruction* operand, bool is_host_transfer = false);
// Creates a slice instruction, where the operand is sliced by the given
// start/limit indices.
@@ -611,6 +619,11 @@ class HloInstruction {
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> dimensions);
+ // Creates a sort op, with a keys operand, and an optional values operand.
+ static std::unique_ptr<HloInstruction> CreateSort(
+ const Shape& shape, int64 dimension, HloInstruction* keys,
+ HloInstruction* values = nullptr);
+
// Creates a while instruction, given a condition computation, a body
// computation, and the initial value for the input of the computations. For
// example, shape: S32, condition: i -> i < 1000, body: i -> i * 2, init: 1
@@ -680,17 +693,18 @@ class HloInstruction {
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> dimensions);
- // Creates a token instruction used for joining or creating new values of
- // token type which thread through side-effecting operations.
+ // Creates a Afterall instruction used for joining or creating new values of
+ // token type which thread through side-effecting operations. Operands must
+ // all be tokens, and there must be at least one operand.
static std::unique_ptr<HloInstruction> CreateAfterAll(
tensorflow::gtl::ArraySlice<HloInstruction*> operands);
- // Creates an instance of GatherDimensionNumbers.
- static GatherDimensionNumbers MakeGatherDimNumbers(
- tensorflow::gtl::ArraySlice<int64> output_window_dims,
- tensorflow::gtl::ArraySlice<int64> elided_window_dims,
- tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims,
- int64 index_vector_dim);
+ // Creates an AfterAll instruction which creates a token type out of thin air
+ // (no operands). This is a separate method from CreateAfterAll to facility
+ // the removal of operand-less AfterAll instructions.
+ // TODO(b/110532604): Remove this capability of creating a token from nothing
+ // when we plumb a primordial token from the entry computation.
+ static std::unique_ptr<HloInstruction> CreateToken();
// Returns the opcode for this instruction.
HloOpcode opcode() const { return opcode_; }
@@ -1066,19 +1080,6 @@ class HloInstruction {
// Returns the dump string of the dot dimension numbers.
string DotDimensionNumbersToString() const;
- const GatherDimensionNumbers& gather_dimension_numbers() const {
- CHECK(gather_dimension_numbers_ != nullptr);
- return *gather_dimension_numbers_;
- }
-
- tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const {
- CHECK_EQ(opcode(), HloOpcode::kGather);
- return gather_window_bounds_;
- }
-
- // Returns the dump string of the gather dimension numbers.
- string GatherDimensionNumbersToString() const;
-
// Clones the HLO instruction. The clone will have the same opcode, shape, and
// operands. After creation the clone has no uses. "this" (the instruction
// cloned from) is not changed. Suffix is the string to append to the name of
@@ -1133,6 +1134,9 @@ class HloInstruction {
// Returns true if this instruction is elementwise on all its operands.
bool IsElementwise() const;
+ // Returns true if this is an cross module all-reduce instrucion.
+ bool IsCrossModuleAllReduce() const;
+
// Returns true if this elementwise instruction implicitly broadcasts operand
// `operand_idx`.
//
@@ -1445,6 +1449,12 @@ class HloInstruction {
// Delegates to HloDynamicSliceInstruction::dynamic_slice_sizes.
const std::vector<int64>& dynamic_slice_sizes() const;
+
+ // Delegates to HloGatherInstruction::gather_dimension_numbers.
+ const GatherDimensionNumbers& gather_dimension_numbers() const;
+ // Delegates to HloGatherInstruction::gather_window_bounds.
+ tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const;
+
// Old methods kept for smooth subclassing transition END.
protected:
@@ -1588,9 +1598,6 @@ class HloInstruction {
// Describes the dimension numbers used for a dot.
std::unique_ptr<DotDimensionNumbers> dot_dimension_numbers_;
- std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_;
- std::vector<int64> gather_window_bounds_;
-
// Used to tag kCopy instructions that are eligible for copy elision.
bool copy_elision_allowed_ = true;