about summary refs log tree commit diff homepage
path: root/tensorflow/compiler/xla/service/hlo_instructions.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.cc')
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instructions.cc  223
1 file changed, 192 insertions(+), 31 deletions(-)
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index e2f43f5810..df26a2c744 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <deque>
+#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -180,8 +181,11 @@ std::unique_ptr<HloInstruction> HloFftInstruction::CloneWithNewOperandsImpl(
HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode,
const Shape& shape,
- int64 channel_id)
- : HloInstruction(opcode, shape), channel_id_(channel_id) {}
+ int64 channel_id,
+ bool is_host_transfer)
+ : HloInstruction(opcode, shape),
+ channel_id_(channel_id),
+ is_host_transfer_(is_host_transfer) {}
HloInstructionProto HloSendRecvInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
@@ -191,7 +195,12 @@ HloInstructionProto HloSendRecvInstruction::ToProto() const {
std::vector<string> HloSendRecvInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
- return {StrCat("channel_id=", channel_id_)};
+ std::vector<string> attrs;
+ attrs.push_back(StrCat("channel_id=", channel_id_));
+ if (is_host_transfer()) {
+ attrs.push_back("is_host_transfer=true");
+ }
+ return attrs;
}
bool HloSendRecvInstruction::IdenticalSlowPath(
@@ -204,26 +213,32 @@ bool HloSendRecvInstruction::IdenticalSlowPath(
// Send instruction produces a tuple of {aliased operand, U32 context}.
HloSendInstruction::HloSendInstruction(HloInstruction* operand,
- int64 channel_id)
+ HloInstruction* token, int64 channel_id,
+ bool is_host_transfer)
: HloSendRecvInstruction(
HloOpcode::kSend,
- ShapeUtil::MakeTupleShape(
- {CHECK_NOTNULL(operand)->shape(), ShapeUtil::MakeShape(U32, {})}),
- channel_id) {
+ ShapeUtil::MakeTupleShape({CHECK_NOTNULL(operand)->shape(),
+ ShapeUtil::MakeShape(U32, {}),
+ ShapeUtil::MakeTokenShape()}),
+ channel_id, is_host_transfer) {
AppendOperand(operand);
+ AppendOperand(token);
}
std::unique_ptr<HloInstruction> HloSendInstruction::CloneWithNewOperandsImpl(
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
- CHECK_EQ(new_operands.size(), 1);
- return MakeUnique<HloSendInstruction>(new_operands[0], channel_id());
+ CHECK_EQ(new_operands.size(), 2);
+ return MakeUnique<HloSendInstruction>(new_operands[0], new_operands[1],
+ channel_id(), is_host_transfer());
}
-HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand)
- : HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeNil(),
- CHECK_NOTNULL(operand)->channel_id()) {
+HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand,
+ bool is_host_transfer)
+ : HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeTokenShape(),
+ CHECK_NOTNULL(operand)->channel_id(),
+ is_host_transfer) {
AppendOperand(operand);
}
@@ -234,30 +249,39 @@ HloSendDoneInstruction::CloneWithNewOperandsImpl(
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
return MakeUnique<HloSendDoneInstruction>(
- Cast<HloSendInstruction>(new_operands[0]));
+ Cast<HloSendInstruction>(new_operands[0]), is_host_transfer());
}
// Recv instruction produces a tuple of {receive buffer, U32 context}.
-HloRecvInstruction::HloRecvInstruction(const Shape& shape, int64 channel_id)
+HloRecvInstruction::HloRecvInstruction(const Shape& shape,
+ HloInstruction* token, int64 channel_id,
+ bool is_host_transfer)
: HloSendRecvInstruction(
HloOpcode::kRecv,
- ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}),
- channel_id) {}
+ ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {}),
+ ShapeUtil::MakeTokenShape()}),
+ channel_id, is_host_transfer) {
+ AppendOperand(token);
+}
std::unique_ptr<HloInstruction> HloRecvInstruction::CloneWithNewOperandsImpl(
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
- CHECK_EQ(new_operands.size(), 0);
+ CHECK_EQ(new_operands.size(), 1);
return MakeUnique<HloRecvInstruction>(
- ShapeUtil::GetTupleElementShape(shape, 0), channel_id());
+ ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], channel_id(),
+ is_host_transfer());
}
-HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand)
+HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand,
+ bool is_host_transfer)
: HloSendRecvInstruction(
HloOpcode::kRecvDone,
- ShapeUtil::GetTupleElementShape(operand->shape(), 0),
- CHECK_NOTNULL(operand)->channel_id()) {
+ ShapeUtil::MakeTupleShape(
+ {ShapeUtil::GetTupleElementShape(operand->shape(), 0),
+ ShapeUtil::MakeTokenShape()}),
+ CHECK_NOTNULL(operand)->channel_id(), is_host_transfer) {
AppendOperand(operand);
}
@@ -268,7 +292,7 @@ HloRecvDoneInstruction::CloneWithNewOperandsImpl(
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
return MakeUnique<HloRecvDoneInstruction>(
- Cast<HloRecvInstruction>(new_operands[0]));
+ Cast<HloRecvInstruction>(new_operands[0]), is_host_transfer());
}
HloAllReduceInstruction::HloAllReduceInstruction(
@@ -281,8 +305,6 @@ HloAllReduceInstruction::HloAllReduceInstruction(
replica_group_ids_(replica_group_ids.begin(), replica_group_ids.end()),
cross_replica_sum_barrier_(barrier.begin(), barrier.end()),
all_reduce_id_(all_reduce_id) {
- // TODO(b/79737069): Remove the CHECK when supported.
- CHECK(!all_reduce_id_);
for (auto operand : operands) {
AppendOperand(operand);
}
@@ -459,6 +481,46 @@ std::unique_ptr<HloInstruction> HloReduceInstruction::CloneWithNewOperandsImpl(
shape, new_operands[0], new_operands[1], dimensions(), to_apply());
}
+HloSortInstruction::HloSortInstruction(const Shape& shape, int64 dimension,
+ HloInstruction* keys,
+ HloInstruction* values)
+ : HloInstruction(HloOpcode::kSort, shape), dimensions_({dimension}) {
+ AppendOperand(keys);
+ if (values) {
+ AppendOperand(values);
+ }
+}
+
+HloInstructionProto HloSortInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ for (int64 dimension : dimensions_) {
+ proto.add_dimensions(dimension);
+ }
+ return proto;
+}
+
+std::vector<string> HloSortInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+}
+
+bool HloSortInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloSortInstruction&>(other);
+ return dimensions() == casted_other.dimensions();
+}
+
+std::unique_ptr<HloInstruction> HloSortInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ HloInstruction* keys = new_operands[0];
+ HloInstruction* values = new_operands.size() == 2 ? new_operands[1] : nullptr;
+ return MakeUnique<HloSortInstruction>(shape, dimensions(0), keys, values);
+}
+
HloTransposeInstruction::HloTransposeInstruction(
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> dimensions)
@@ -757,7 +819,7 @@ string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
HloTraceInstruction::HloTraceInstruction(const string& tag,
HloInstruction* operand)
: HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()),
- literal_(Literal::CreateR1U8(tag)) {
+ literal_(LiteralUtil::CreateR1U8(tag)) {
AppendOperand(operand);
operand->set_tracing(this);
}
@@ -1043,8 +1105,6 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal(
CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build()));
clone = fused_expression_root();
} else {
- clone = fused_instructions_computation()->AddInstruction(
- instruction_to_fuse->Clone(/*suffix=*/""));
// When add_output is false, instruction_to_fuse is necessarily an operand
// of the fusion instruction. After fusion this will no longer be the
// case. Remove the operand from the operand list and remove its
@@ -1054,6 +1114,16 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal(
bool in_operand_list = std::find(operands().begin(), operands().end(),
instruction_to_fuse) != operands().end();
CHECK(add_output || in_operand_list);
+ if (instruction_to_fuse->opcode() == HloOpcode::kTuple) {
+ // We assume all uses of a kTuple operation are GTE ops, not another
+ // fusion node. In this case, we don't need to clone
+ // 'instruction_to_fuse'.
+ CHECK(!in_operand_list);
+ clone = instruction_to_fuse;
+ } else {
+ clone = fused_instructions_computation()->AddInstruction(
+ instruction_to_fuse->Clone(/*suffix=*/""));
+ }
const std::vector<HloInstruction*>& fused_parameters =
fused_instructions_computation()->parameter_instructions();
for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
@@ -1150,9 +1220,10 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal(
}
int64 index = tuple_elements.size();
if (instruction_to_fuse->opcode() == HloOpcode::kTuple) {
- index -= instruction_to_fuse->operand_count();
+ CHECK_EQ(clone, instruction_to_fuse);
+ index -= clone->operand_count();
std::vector<HloInstruction*> to_be_removed;
- for (auto old_gte : instruction_to_fuse->users()) {
+ for (auto old_gte : clone->users()) {
CHECK_EQ(old_gte->opcode(), HloOpcode::kGetTupleElement);
int64 old_tuple_index = old_gte->tuple_index();
HloInstruction* new_gte =
@@ -1164,7 +1235,6 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal(
for (auto old_gte : to_be_removed) {
TF_CHECK_OK(parent()->RemoveInstruction(old_gte));
}
- TF_CHECK_OK(fused_instructions_computation()->RemoveInstruction(clone));
} else {
HloInstruction* new_gte =
parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
@@ -1173,7 +1243,9 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal(
}
}
- VLOG(2) << "New clone:\n" << clone->ToString();
+ if (clone != instruction_to_fuse) {
+ VLOG(2) << "New clone:\n" << clone->ToString();
+ }
return clone;
}
@@ -1854,4 +1926,93 @@ HloDynamicSliceInstruction::CloneWithNewOperandsImpl(
return MakeUnique<HloDynamicSliceInstruction>(
shape, new_operands[0], new_operands[1], dynamic_slice_sizes_);
}
+
+HloGatherInstruction::HloGatherInstruction(
+ const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices,
+ const GatherDimensionNumbers& gather_dim_numbers,
+ tensorflow::gtl::ArraySlice<int64> window_bounds)
+ : HloInstruction(HloOpcode::kGather, shape) {
+ AppendOperand(operand);
+ AppendOperand(gather_indices);
+ gather_dimension_numbers_ =
+ MakeUnique<GatherDimensionNumbers>(gather_dim_numbers);
+ c_copy(window_bounds, std::back_inserter(gather_window_bounds_));
+}
+
+string HloGatherInstruction::GatherDimensionNumbersToString() const {
+ CHECK(gather_dimension_numbers_ != nullptr);
+ string output_window_dims =
+ StrCat("output_window_dims={",
+ Join(gather_dimension_numbers_->output_window_dims(), ","), "}");
+ string elided_window_dims =
+ StrCat("elided_window_dims={",
+ Join(gather_dimension_numbers_->elided_window_dims(), ","), "}");
+ string gather_dims_to_operand_dims = StrCat(
+ "gather_dims_to_operand_dims={",
+ Join(gather_dimension_numbers_->gather_dims_to_operand_dims(), ","), "}");
+ string index_vector_dim = StrCat(
+ "index_vector_dim=", gather_dimension_numbers_->index_vector_dim());
+
+ return Join<std::initializer_list<string>>(
+ {output_window_dims, elided_window_dims, gather_dims_to_operand_dims,
+ index_vector_dim},
+ ", ");
+}
+
+/* static */ GatherDimensionNumbers HloGatherInstruction::MakeGatherDimNumbers(
+ tensorflow::gtl::ArraySlice<int64> output_window_dims,
+ tensorflow::gtl::ArraySlice<int64> elided_window_dims,
+ tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims,
+ int64 index_vector_dim) {
+ GatherDimensionNumbers gather_dim_numbers;
+ for (int64 output_window_dim : output_window_dims) {
+ gather_dim_numbers.add_output_window_dims(output_window_dim);
+ }
+ for (int64 elided_window_dim : elided_window_dims) {
+ gather_dim_numbers.add_elided_window_dims(elided_window_dim);
+ }
+ for (int64 gather_dim_to_input_dim : gather_dims_to_operand_dims) {
+ gather_dim_numbers.add_gather_dims_to_operand_dims(gather_dim_to_input_dim);
+ }
+
+ gather_dim_numbers.set_index_vector_dim(index_vector_dim);
+ return gather_dim_numbers;
+}
+
+HloInstructionProto HloGatherInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ *proto.mutable_gather_dimension_numbers() = gather_dimension_numbers();
+ for (int64 bound : gather_window_bounds()) {
+ proto.add_gather_window_bounds(bound);
+ }
+ return proto;
+}
+
+std::vector<string> HloGatherInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {GatherDimensionNumbersToString(),
+ StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}")};
+}
+
+bool HloGatherInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloGatherInstruction&>(other);
+ return protobuf_util::ProtobufEquals(
+ gather_dimension_numbers(),
+ casted_other.gather_dimension_numbers()) &&
+ gather_window_bounds() == casted_other.gather_window_bounds();
+}
+
+std::unique_ptr<HloInstruction> HloGatherInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 2);
+ return MakeUnique<HloGatherInstruction>(
+ shape, new_operands[0], new_operands[1], gather_dimension_numbers(),
+ gather_window_bounds());
+}
+
} // namespace xla