Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction.cc')
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc  238
1 file changed, 126 insertions, 112 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 5aaeec802f..8b9bdd2f46 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include <utility>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
@@ -112,29 +112,30 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
break;
}
case HloOpcode::kSend:
- TF_RET_CHECK(proto.operand_ids_size() == 1)
- << "Send instruction should have 1 operand but sees "
+ TF_RET_CHECK(proto.operand_ids_size() == 2)
+ << "Send instruction should have 2 operand but sees "
<< proto.operand_ids_size();
- instruction = CreateSend(operands(0), proto.channel_id());
+ instruction = CreateSend(operands(0), operands(1), proto.channel_id(),
+ proto.is_host_transfer());
break;
case HloOpcode::kSendDone:
TF_RET_CHECK(proto.operand_ids_size() == 1)
<< "SendDone instruction should have 1 operand but sees "
<< proto.operand_ids_size();
- instruction = CreateSendDone(operands(0));
+ instruction = CreateSendDone(operands(0), proto.is_host_transfer());
break;
case HloOpcode::kRecv:
- TF_RET_CHECK(proto.operand_ids_size() == 0)
- << "Recv instruction should have 0 operand but sees "
+ TF_RET_CHECK(proto.operand_ids_size() == 1)
+ << "Recv instruction should have 1 operand but sees "
<< proto.operand_ids_size();
- instruction =
- CreateRecv(proto.shape().tuple_shapes(0), proto.channel_id());
+ instruction = CreateRecv(proto.shape().tuple_shapes(0), operands(0),
+ proto.channel_id(), proto.is_host_transfer());
break;
case HloOpcode::kRecvDone:
TF_RET_CHECK(proto.operand_ids_size() == 1)
<< "RecvDone instruction should have 1 operand but sees "
<< proto.operand_ids_size();
- instruction = CreateRecvDone(operands(0));
+ instruction = CreateRecvDone(operands(0), proto.is_host_transfer());
break;
case HloOpcode::kReverse:
TF_RET_CHECK(proto.operand_ids_size() == 1)
@@ -163,6 +164,20 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
proto.dimensions().end()),
computations(0));
break;
+ case HloOpcode::kSort: {
+ TF_RET_CHECK(proto.operand_ids_size() == 1 ||
+ proto.operand_ids_size() == 2)
+ << "Sort instruction should have 1 or 2 operands but has "
+ << proto.operand_ids_size();
+ TF_RET_CHECK(proto.dimensions().size() == 1)
+ << "Sort instruction should have 1 dimension";
+ HloInstruction* keys = operands(0);
+ HloInstruction* values =
+ proto.operand_ids_size() == 2 ? operands(1) : nullptr;
+ instruction =
+ CreateSort(proto.shape(), proto.dimensions(0), keys, values);
+ break;
+ }
case HloOpcode::kTranspose:
TF_RET_CHECK(proto.operand_ids_size() == 1)
<< "Transpose instruction should have 1 operand but sees "
@@ -271,7 +286,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
// converted to take tokens.
instruction = CreateInfeed(data_shape, proto.infeed_config());
} else {
- CHECK_EQ(proto.operand_ids_size(), 2);
+ CHECK_EQ(proto.operand_ids_size(), 1);
instruction =
CreateInfeed(data_shape, operands(0), proto.infeed_config());
}
@@ -372,6 +387,23 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
slice_sizes);
break;
}
+ case HloOpcode::kGather: {
+ TF_RET_CHECK(proto.operand_ids_size() == 2)
+ << "Gather instruction should have 2 operands but sees "
+ << proto.operand_ids_size();
+ TF_RET_CHECK(proto.has_gather_dimension_numbers())
+ << "Gather instruction should have GatherDimensionNumbers set.";
+ std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers =
+ MakeUnique<GatherDimensionNumbers>(proto.gather_dimension_numbers());
+ std::vector<int64> gather_window_bounds;
+ for (int64 bound : proto.gather_window_bounds()) {
+ gather_window_bounds.push_back(bound);
+ }
+ instruction =
+ CreateGather(proto.shape(), operands(0), operands(1),
+ *gather_dimension_numbers, gather_window_bounds);
+ break;
+ }
default: {
instruction = WrapUnique(new HloInstruction(opcode, proto.shape()));
for (const int64 operand_id : proto.operand_ids()) {
@@ -413,13 +445,6 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
instruction->set_sharding(sharding);
}
- if (proto.has_gather_dimension_numbers()) {
- instruction->gather_dimension_numbers_ =
- MakeUnique<GatherDimensionNumbers>(proto.gather_dimension_numbers());
- }
- for (int64 bound : proto.gather_window_bounds()) {
- instruction->gather_window_bounds_.push_back(bound);
- }
return std::move(instruction);
}
@@ -438,6 +463,11 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
return MakeUnique<HloConstantInstruction>(std::move(literal));
}
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateIota(
+ const Shape& shape) {
+ return WrapUnique(new HloInstruction(HloOpcode::kIota, shape));
+}
+
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateGetTupleElement(const Shape& shape,
HloInstruction* operand, int64 index) {
@@ -489,7 +519,6 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
case HloOpcode::kReal:
case HloOpcode::kSign:
case HloOpcode::kSin:
- case HloOpcode::kSort:
case HloOpcode::kTanh:
break;
default:
@@ -542,8 +571,9 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
// Only certain opcodes are supported with CreateTernary: opcodes of ternary
// instructions with no auxiliary fields.
switch (opcode) {
- case (HloOpcode::kClamp):
- case (HloOpcode::kSelect):
+ case HloOpcode::kClamp:
+ case HloOpcode::kSelect:
+ case HloOpcode::kTupleSelect:
break;
default:
LOG(FATAL) << "Invalid ternary instruction opcode "
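
(Usage sketch, for illustration only: with kTupleSelect accepted above, a tuple select can now be built through the generic ternary factory. shape, pred, on_true, and on_false are assumed placeholders, not names from this commit.)

  auto tuple_select = HloInstruction::CreateTernary(
      shape, HloOpcode::kTupleSelect, pred, on_true, on_false);
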
@@ -651,29 +681,33 @@ HloInstruction::CreateCrossReplicaSum(
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend(
- HloInstruction* operand, int64 channel_id) {
- return MakeUnique<HloSendInstruction>(operand, channel_id);
+ HloInstruction* operand, HloInstruction* token, int64 channel_id,
+ bool is_host_transfer) {
+ return MakeUnique<HloSendInstruction>(operand, token, channel_id,
+ is_host_transfer);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSendDone(
- HloInstruction* operand) {
+ HloInstruction* operand, bool is_host_transfer) {
auto send_operand = DynCast<HloSendInstruction>(operand);
CHECK(send_operand != nullptr)
<< "SendDone must take the context operand from Send";
- return MakeUnique<HloSendDoneInstruction>(send_operand);
+ return MakeUnique<HloSendDoneInstruction>(send_operand, is_host_transfer);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecv(
- const Shape& shape, int64 channel_id) {
- return MakeUnique<HloRecvInstruction>(shape, channel_id);
+ const Shape& shape, HloInstruction* token, int64 channel_id,
+ bool is_host_transfer) {
+ return MakeUnique<HloRecvInstruction>(shape, token, channel_id,
+ is_host_transfer);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecvDone(
- HloInstruction* operand) {
+ HloInstruction* operand, bool is_host_transfer) {
auto recv_operand = DynCast<HloRecvInstruction>(operand);
CHECK(recv_operand != nullptr)
<< "RecvDone must take the context operand from Recv";
- return MakeUnique<HloRecvDoneInstruction>(recv_operand);
+ return MakeUnique<HloRecvDoneInstruction>(recv_operand, is_host_transfer);
}
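
(Usage sketch of the token-threaded signatures above; data is an assumed HloInstruction* producing the value to send, the channel id is arbitrary, and CreateToken() is the factory added further down in this diff.)

  auto token = HloInstruction::CreateToken();
  auto send = HloInstruction::CreateSend(data, token.get(), /*channel_id=*/1,
                                         /*is_host_transfer=*/false);
  // CreateSendDone CHECKs that its operand is the matching HloSendInstruction.
  auto send_done =
      HloInstruction::CreateSendDone(send.get(), /*is_host_transfer=*/false);
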
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReverse(
@@ -684,6 +718,7 @@ HloInstruction::CreateCrossReplicaSum(
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAfterAll(
tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
+ CHECK(!operands.empty());
auto instruction = WrapUnique(
new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
for (auto operand : operands) {
@@ -692,6 +727,11 @@ HloInstruction::CreateCrossReplicaSum(
return instruction;
}
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateToken() {
+ return WrapUnique(
+ new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
+}
+
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateWhile(
const Shape& shape, HloComputation* condition, HloComputation* body,
HloInstruction* init) {
@@ -908,6 +948,12 @@ HloInstruction::CreateBroadcastSequence(
return MakeUnique<HloTransposeInstruction>(shape, operand, dimensions);
}
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSort(
+ const Shape& shape, int64 dimension, HloInstruction* keys,
+ HloInstruction* values) {
+ return MakeUnique<HloSortInstruction>(shape, dimension, keys, values);
+}
+
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) {
return MakeUnique<HloFusionInstruction>(shape, fusion_kind, fused_root);
@@ -952,6 +998,8 @@ bool HloInstruction::HasSideEffectNoRecurse() const {
case HloOpcode::kTrace:
case HloOpcode::kHostCompute:
return true;
+ case HloOpcode::kCrossReplicaSum:
+ return all_reduce_id().has_value();
default:
return false;
}
@@ -1010,34 +1058,8 @@ bool HloInstruction::HasSideEffect() const {
const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices,
const GatherDimensionNumbers& gather_dim_numbers,
tensorflow::gtl::ArraySlice<int64> window_bounds) {
- std::unique_ptr<HloInstruction> instruction =
- WrapUnique(new HloInstruction(HloOpcode::kGather, shape));
- instruction->AppendOperand(operand);
- instruction->AppendOperand(gather_indices);
- instruction->gather_dimension_numbers_ =
- MakeUnique<GatherDimensionNumbers>(gather_dim_numbers);
- c_copy(window_bounds, std::back_inserter(instruction->gather_window_bounds_));
- return instruction;
-}
-
-/* static */ GatherDimensionNumbers HloInstruction::MakeGatherDimNumbers(
- tensorflow::gtl::ArraySlice<int64> output_window_dims,
- tensorflow::gtl::ArraySlice<int64> elided_window_dims,
- tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims,
- int64 index_vector_dim) {
- GatherDimensionNumbers gather_dim_numbers;
- for (int64 output_window_dim : output_window_dims) {
- gather_dim_numbers.add_output_window_dims(output_window_dim);
- }
- for (int64 elided_window_dim : elided_window_dims) {
- gather_dim_numbers.add_elided_window_dims(elided_window_dim);
- }
- for (int64 gather_dim_to_input_dim : gather_dims_to_operand_dims) {
- gather_dim_numbers.add_gather_dims_to_operand_dims(gather_dim_to_input_dim);
- }
-
- gather_dim_numbers.set_index_vector_dim(index_vector_dim);
- return gather_dim_numbers;
+ return MakeUnique<HloGatherInstruction>(shape, operand, gather_indices,
+ gather_dim_numbers, window_bounds);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDomain(
@@ -1100,6 +1122,9 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kHostCompute:
case HloOpcode::kPad:
case HloOpcode::kDynamicSlice:
+ case HloOpcode::kSort:
+ case HloOpcode::kGather:
+ case HloOpcode::kIota:
clone = CloneWithNewOperandsImpl(shape, new_operands, context);
break;
// Unary ops.
@@ -1122,7 +1147,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kReal:
case HloOpcode::kSign:
case HloOpcode::kSin:
- case HloOpcode::kSort:
case HloOpcode::kTanh:
CHECK_EQ(new_operands.size(), 1);
clone = CreateUnary(shape, opcode_, new_operands[0]);
@@ -1156,6 +1180,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
// Ternary ops.
case HloOpcode::kClamp:
case HloOpcode::kSelect:
+ case HloOpcode::kTupleSelect:
CHECK_EQ(new_operands.size(), 3);
clone = CreateTernary(shape, opcode_, new_operands[0], new_operands[1],
new_operands[2]);
@@ -1201,11 +1226,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
true_computation(), new_operands[2],
false_computation());
break;
- case HloOpcode::kGather:
- CHECK_EQ(new_operands.size(), 2);
- clone = CreateGather(shape, new_operands[0], new_operands[1],
- *gather_dimension_numbers_, gather_window_bounds_);
- break;
case HloOpcode::kDomain:
CHECK_EQ(new_operands.size(), 1);
clone =
@@ -1213,7 +1233,11 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
user_side_metadata_->Clone());
break;
case HloOpcode::kAfterAll:
- clone = CreateAfterAll(new_operands);
+ if (new_operands.empty()) {
+ clone = CreateToken();
+ } else {
+ clone = CreateAfterAll(new_operands);
+ }
break;
}
SetupDerivedInstruction(clone.get());
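
(Sketch of why the empty-operand branch above is needed: CreateAfterAll now CHECK-fails on an empty operand list, so cloning a zero-operand token must route through CreateToken(). a and b are assumed HloInstruction* placeholders.)

  auto token = HloInstruction::CreateToken();            // zero operands
  auto joined = HloInstruction::CreateAfterAll({a, b});  // requires >= 1 operand
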
@@ -1495,11 +1519,10 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kSubtract:
case HloOpcode::kTanh:
case HloOpcode::kTuple:
+ case HloOpcode::kTupleSelect:
return true;
- // These opcodes have complex or special behavior so just return false.
- case HloOpcode::kDomain:
- case HloOpcode::kWhile:
+ // This opcode has complex or special behavior so just return false.
case HloOpcode::kAfterAll:
return false;
@@ -1508,11 +1531,6 @@ bool HloInstruction::IdenticalSlowPath(
return protobuf_util::ProtobufEquals(dot_dimension_numbers(),
other.dot_dimension_numbers());
- case HloOpcode::kGather:
- return protobuf_util::ProtobufEquals(gather_dimension_numbers(),
- other.gather_dimension_numbers()) &&
- gather_window_bounds() == other.gather_window_bounds();
-
// Remaining instructions with special values.
case HloOpcode::kCall:
return eq_computations(to_apply(), other.to_apply());
@@ -1520,9 +1538,17 @@ bool HloInstruction::IdenticalSlowPath(
return eq_computations(true_computation(), other.true_computation()) &&
eq_computations(false_computation(), other.false_computation());
- // These opcodes are not yet supported.
- case HloOpcode::kSort:
+ case HloOpcode::kWhile: {
+ if (eq_computations(while_body(), other.while_body()) &&
+ eq_computations(while_condition(), other.while_condition())) {
+ return true;
+ }
return false;
+ }
+
+ case HloOpcode::kDomain:
+ return operand_side_metadata().Matches(other.operand_side_metadata()) &&
+ user_side_metadata().Matches(other.user_side_metadata());
// Ops migrated to subclasses should never come to this line.
// TODO(b/80131774): Remove this switch when migration is complete.
@@ -1537,11 +1563,13 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kReverse:
case HloOpcode::kConcatenate:
case HloOpcode::kReduce:
+ case HloOpcode::kSort:
case HloOpcode::kTranspose:
case HloOpcode::kBroadcast:
case HloOpcode::kMap:
case HloOpcode::kSlice:
case HloOpcode::kConstant:
+ case HloOpcode::kIota:
case HloOpcode::kTrace:
case HloOpcode::kFusion:
case HloOpcode::kRng:
@@ -1558,9 +1586,11 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kHostCompute:
case HloOpcode::kPad:
case HloOpcode::kDynamicSlice:
+ case HloOpcode::kGather:
LOG(FATAL) << "Base class impl called for opcode with subclass: "
<< opcode();
}
+ return false;
}
void HloInstruction::RemoveUser(HloInstruction* user) {
@@ -1610,8 +1640,8 @@ Status HloInstruction::ReplaceOperandWith(int64 operand_num,
TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(),
new_operand->shape()))
- << old_operand->shape().ShortDebugString() << " is not compatible with "
- << new_operand->shape().ShortDebugString();
+ << old_operand->shape() << " is not compatible with "
+ << new_operand->shape();
operands_[operand_num] = new_operand;
VLOG(3) << "Replacing operand " << operand_num << " of " << name() << " with "
@@ -1820,7 +1850,6 @@ bool HloInstruction::IsElementwiseImpl(
// Ternary elementwise operations.
case HloOpcode::kSelect:
- return !ShapeUtil::IsTuple(shape_);
case HloOpcode::kClamp:
return true;
@@ -1832,6 +1861,10 @@ bool HloInstruction::IsElementwiseImpl(
}
}
+bool HloInstruction::IsCrossModuleAllReduce() const {
+ return opcode() == HloOpcode::kCrossReplicaSum && all_reduce_id();
+}
+
string HloInstruction::ToStringWithCanonicalNameMap(
const HloPrintOptions& options,
CanonicalNameMap* canonical_name_map) const {
@@ -1924,11 +1957,6 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
if (dot_dimension_numbers_ != nullptr) {
extra.push_back(DotDimensionNumbersToString());
}
- if (gather_dimension_numbers_ != nullptr) {
- extra.push_back(GatherDimensionNumbersToString());
- extra.push_back(
- StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}"));
- }
if (options.print_subcomputation_mode() ==
HloPrintOptions::PrintSubcomputationMode::kNameOnly) {
@@ -2015,8 +2043,8 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
}
if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) {
extra.push_back(StrCat("domain={kind=\"", operand_side_metadata_->Kind(),
- "\", entry=", operand_side_metadata_->ToString(),
- ", exit=", user_side_metadata_->ToString(), "}"));
+ "\", entry=", user_side_metadata_->ToString(),
+ ", exit=", operand_side_metadata_->ToString(), "}"));
}
return extra;
@@ -2058,14 +2086,6 @@ HloInstructionProto HloInstruction::ToProto() const {
if (dot_dimension_numbers_ != nullptr) {
*proto.mutable_dot_dimension_numbers() = *dot_dimension_numbers_;
}
- if (gather_dimension_numbers_ != nullptr) {
- *proto.mutable_gather_dimension_numbers() = *gather_dimension_numbers_;
- }
- if (opcode() == HloOpcode::kGather) {
- for (int64 bound : gather_window_bounds()) {
- proto.add_gather_window_bounds(bound);
- }
- }
if (has_sharding()) {
*proto.mutable_sharding() = sharding().ToProto();
@@ -2191,6 +2211,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleRemainder(this);
case HloOpcode::kSelect:
return visitor->HandleSelect(this);
+ case HloOpcode::kTupleSelect:
+ return visitor->HandleTupleSelect(this);
case HloOpcode::kConvolution:
return visitor->HandleConvolution(this);
case HloOpcode::kFft:
@@ -2293,6 +2315,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleDomain(this);
case HloOpcode::kAfterAll:
return visitor->HandleAfterAll(this);
+ case HloOpcode::kIota:
+ return visitor->HandleIota(this);
// These opcodes are not handled here.
case HloOpcode::kTrace:
@@ -2824,26 +2848,6 @@ std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) {
return os << ToString(kind);
}
-string HloInstruction::GatherDimensionNumbersToString() const {
- CHECK_NE(gather_dimension_numbers_.get(), nullptr);
- string output_window_dims =
- StrCat("output_window_dims={",
- Join(gather_dimension_numbers_->output_window_dims(), ","), "}");
- string elided_window_dims =
- StrCat("elided_window_dims={",
- Join(gather_dimension_numbers_->elided_window_dims(), ","), "}");
- string gather_dims_to_operand_dims = StrCat(
- "gather_dims_to_operand_dims={",
- Join(gather_dimension_numbers_->gather_dims_to_operand_dims(), ","), "}");
- string index_vector_dim = StrCat(
- "index_vector_dim=", gather_dimension_numbers_->index_vector_dim());
-
- return Join<std::initializer_list<string>>(
- {output_window_dims, elided_window_dims, gather_dims_to_operand_dims,
- index_vector_dim},
- ", ");
-}
-
bool HloInstruction::CouldBeBitcast() const {
switch (opcode_) {
case HloOpcode::kTranspose:
@@ -3157,4 +3161,14 @@ int64 HloInstruction::slice_sizes(int64 dimension) const {
const std::vector<int64>& HloInstruction::dynamic_slice_sizes() const {
return Cast<HloDynamicSliceInstruction>(this)->dynamic_slice_sizes();
}
+
+const GatherDimensionNumbers& HloInstruction::gather_dimension_numbers() const {
+ return Cast<HloGatherInstruction>(this)->gather_dimension_numbers();
+}
+
+tensorflow::gtl::ArraySlice<int64> HloInstruction::gather_window_bounds()
+ const {
+ return Cast<HloGatherInstruction>(this)->gather_window_bounds();
+}
+
} // namespace xla
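
(Usage sketch of the gather path after this change: CreateGather now returns an HloGatherInstruction, and the two accessors above delegate to that subclass. result_shape, operand, and gather_indices are assumed placeholders, and the dimension-number and window-bound values are arbitrary illustrative choices.)

  GatherDimensionNumbers dnums;
  dnums.add_output_window_dims(1);
  dnums.add_elided_window_dims(0);
  dnums.add_gather_dims_to_operand_dims(0);
  dnums.set_index_vector_dim(1);
  auto gather = HloInstruction::CreateGather(
      result_shape, operand, gather_indices, dnums, /*window_bounds=*/{1, 8});
  // Accessor dispatches through Cast<HloGatherInstruction>, per the last hunk.
  const GatherDimensionNumbers& dn = gather->gather_dimension_numbers();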