diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-06 20:48:11 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-06 20:52:28 -0700 |
commit | 44efcf0db7b9204a77710a7f076c904d0e13e6fa (patch) | |
tree | 2114672b05239fcdb9554bcf26354a81b4f8565b /tensorflow/compiler/xla/service/hlo_instructions.cc | |
parent | a7e3047fea74a43174c063320fd0cb6bb6dcceb1 (diff) |
Split out HloDomainInstruction as subclass form HloInstruction.
PiperOrigin-RevId: 211916428
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instructions.cc | 39 |
1 files changed, 39 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 4e3e0c055e..76712d73db 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -2224,4 +2224,43 @@ string HloDotInstruction::DotDimensionNumbersToString() const { return StrJoin(result, ", "); } + +HloDomainInstruction::HloDomainInstruction( + const Shape& shape, HloInstruction* operand, + std::unique_ptr<DomainMetadata> operand_side_metadata, + std::unique_ptr<DomainMetadata> user_side_metadata) + : HloInstruction(HloOpcode::kDomain, shape), + operand_side_metadata_(std::move(operand_side_metadata)), + user_side_metadata_(std::move(user_side_metadata)) { + AppendOperand(operand); +} + +std::vector<string> HloDomainInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) { + return {StrCat("domain={kind=\"", operand_side_metadata_->Kind(), + "\", entry=", user_side_metadata_->ToString(), + ", exit=", operand_side_metadata_->ToString(), "}")}; + } + return {}; +} + +bool HloDomainInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + eq_computations) const { + const auto& casted_other = static_cast<const HloDomainInstruction&>(other); + return operand_side_metadata().Matches( + casted_other.operand_side_metadata()) && + user_side_metadata().Matches(casted_other.user_side_metadata()); +} + +std::unique_ptr<HloInstruction> HloDomainInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span<HloInstruction* const> new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return absl::make_unique<HloDomainInstruction>( + shape, new_operands[0], operand_side_metadata_->Clone(), + user_side_metadata_->Clone()); +} } // namespace xla |