aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-06 20:48:11 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-06 20:52:28 -0700
commit44efcf0db7b9204a77710a7f076c904d0e13e6fa (patch)
tree2114672b05239fcdb9554bcf26354a81b4f8565b
parenta7e3047fea74a43174c063320fd0cb6bb6dcceb1 (diff)
Split out HloDomainInstruction as subclass form HloInstruction.
PiperOrigin-RevId: 211916428
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc42
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h19
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc39
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.h31
4 files changed, 96 insertions, 35 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 563aa695c9..f66a0ae9e7 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -465,6 +465,14 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
proto.dot_dimension_numbers(), precision_config);
break;
}
+ case HloOpcode::kDomain:
+ TF_RET_CHECK(proto.operand_ids_size() == 1)
+ << "Domain instruction should have 1 operands but sees "
+ << proto.operand_ids_size();
+ instruction = absl::make_unique<HloDomainInstruction>(
+ proto.shape(), operands(0), /*operand_side_metadata=*/nullptr,
+ /*user_side_metadata=*/nullptr);
+ break;
default: {
instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape()));
for (const int64 operand_id : proto.operand_ids()) {
@@ -567,7 +575,6 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
case HloOpcode::kCopy:
case HloOpcode::kCos:
case HloOpcode::kClz:
- case HloOpcode::kDomain:
case HloOpcode::kExp:
case HloOpcode::kExpm1:
case HloOpcode::kFloor:
@@ -1137,12 +1144,9 @@ bool HloInstruction::HasSideEffect() const {
const Shape& shape, HloInstruction* operand,
std::unique_ptr<DomainMetadata> operand_side_metadata,
std::unique_ptr<DomainMetadata> user_side_metadata) {
- auto instruction =
- absl::WrapUnique(new HloInstruction(HloOpcode::kDomain, shape));
- instruction->operand_side_metadata_ = std::move(operand_side_metadata);
- instruction->user_side_metadata_ = std::move(user_side_metadata);
- instruction->AppendOperand(operand);
- return instruction;
+ return absl::make_unique<HloDomainInstruction>(
+ shape, operand, std::move(operand_side_metadata),
+ std::move(user_side_metadata));
}
std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
@@ -1199,6 +1203,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kScatter:
case HloOpcode::kIota:
case HloOpcode::kDot:
+ case HloOpcode::kDomain:
clone = CloneWithNewOperandsImpl(shape, new_operands, context);
break;
// Unary ops.
@@ -1295,12 +1300,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
true_computation(), new_operands[2],
false_computation());
break;
- case HloOpcode::kDomain:
- CHECK_EQ(new_operands.size(), 1);
- clone =
- CreateDomain(shape, new_operands[0], operand_side_metadata_->Clone(),
- user_side_metadata_->Clone());
- break;
case HloOpcode::kAfterAll:
if (new_operands.empty()) {
clone = CreateToken();
@@ -1611,10 +1610,6 @@ bool HloInstruction::IdenticalSlowPath(
return false;
}
- case HloOpcode::kDomain:
- return operand_side_metadata().Matches(other.operand_side_metadata()) &&
- user_side_metadata().Matches(other.user_side_metadata());
-
// Ops migrated to subclasses should never come to this line.
// TODO(b/80131774): Remove this switch when migration is complete.
case HloOpcode::kBatchNormTraining:
@@ -1655,6 +1650,7 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kGather:
case HloOpcode::kScatter:
case HloOpcode::kDot:
+ case HloOpcode::kDomain:
LOG(FATAL) << "Base class impl called for opcode with subclass: "
<< opcode();
}
@@ -2114,11 +2110,6 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
}),
"}"));
}
- if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) {
- extra.push_back(StrCat("domain={kind=\"", operand_side_metadata_->Kind(),
- "\", entry=", user_side_metadata_->ToString(),
- ", exit=", operand_side_metadata_->ToString(), "}"));
- }
return extra;
}
@@ -3288,4 +3279,11 @@ const DotDimensionNumbers& HloInstruction::dot_dimension_numbers() const {
return Cast<HloDotInstruction>(this)->dot_dimension_numbers();
}
+const DomainMetadata& HloInstruction::operand_side_metadata() const {
+ return Cast<HloDomainInstruction>(this)->operand_side_metadata();
+}
+
+const DomainMetadata& HloInstruction::user_side_metadata() const {
+ return Cast<HloDomainInstruction>(this)->user_side_metadata();
+}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index de60ddf42d..1619d1a985 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -1079,15 +1079,6 @@ class HloInstruction {
return other->has_sharding() ? sharding() == other->sharding() : false;
}
- // Retrieves the operand side metadata of a kDomain instruction.
- const DomainMetadata& operand_side_metadata() const {
- return *operand_side_metadata_;
- }
- // Retrieves the user side metadata of a kDomain instruction.
- const DomainMetadata& user_side_metadata() const {
- return *user_side_metadata_;
- }
-
// When creating a new instruction which either replaces, or shifts up (kCopy
// insertion case), another instruction, we need to make sure the certain
// properties of the new instruction are copied into the derived one. As of
@@ -1496,6 +1487,12 @@ class HloInstruction {
// Delegates to HloDotInstruction::dot_dimension_numbers().
const DotDimensionNumbers& dot_dimension_numbers() const;
+ // Delegates to HloDomainInstruction::operand_side_metadata().
+ const DomainMetadata& operand_side_metadata() const;
+
+ // Delegates to HloDomainInstruction::user_side_metadata().
+ const DomainMetadata& user_side_metadata() const;
+
// Old methods kept for smooth subclassing transition END.
protected:
@@ -1641,10 +1638,6 @@ class HloInstruction {
// many element tuples.
std::shared_ptr<const HloSharding> sharding_;
- // Fields used by the kDomain instruction.
- std::unique_ptr<DomainMetadata> operand_side_metadata_;
- std::unique_ptr<DomainMetadata> user_side_metadata_;
-
// Computations called by this instruction.
std::vector<HloComputation*> called_computations_;
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 4e3e0c055e..76712d73db 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -2224,4 +2224,43 @@ string HloDotInstruction::DotDimensionNumbersToString() const {
return StrJoin(result, ", ");
}
+
+HloDomainInstruction::HloDomainInstruction(
+ const Shape& shape, HloInstruction* operand,
+ std::unique_ptr<DomainMetadata> operand_side_metadata,
+ std::unique_ptr<DomainMetadata> user_side_metadata)
+ : HloInstruction(HloOpcode::kDomain, shape),
+ operand_side_metadata_(std::move(operand_side_metadata)),
+ user_side_metadata_(std::move(user_side_metadata)) {
+ AppendOperand(operand);
+}
+
+std::vector<string> HloDomainInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) {
+ return {StrCat("domain={kind=\"", operand_side_metadata_->Kind(),
+ "\", entry=", user_side_metadata_->ToString(),
+ ", exit=", operand_side_metadata_->ToString(), "}")};
+ }
+ return {};
+}
+
+bool HloDomainInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloDomainInstruction&>(other);
+ return operand_side_metadata().Matches(
+ casted_other.operand_side_metadata()) &&
+ user_side_metadata().Matches(casted_other.user_side_metadata());
+}
+
+std::unique_ptr<HloInstruction> HloDomainInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 1);
+ return absl::make_unique<HloDomainInstruction>(
+ shape, new_operands[0], operand_side_metadata_->Clone(),
+ user_side_metadata_->Clone());
+}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index e72ddabff9..af46148c70 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -1306,6 +1306,37 @@ class HloDotInstruction : public HloInstruction {
DotDimensionNumbers dot_dimension_numbers_;
};
+class HloDomainInstruction : public HloInstruction {
+ public:
+ explicit HloDomainInstruction(
+ const Shape& shape, HloInstruction* operand,
+ std::unique_ptr<DomainMetadata> operand_side_metadata,
+ std::unique_ptr<DomainMetadata> user_side_metadata);
+
+ // Retrieves the operand side metadata of a kDomain instruction.
+ const DomainMetadata& operand_side_metadata() const {
+ return *operand_side_metadata_;
+ }
+ // Retrieves the user side metadata of a kDomain instruction.
+ const DomainMetadata& user_side_metadata() const {
+ return *user_side_metadata_;
+ }
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
+ HloCloneContext* context) const override;
+
+ std::unique_ptr<DomainMetadata> operand_side_metadata_;
+ std::unique_ptr<DomainMetadata> user_side_metadata_;
+};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_