diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-06 20:48:11 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-06 20:52:28 -0700 |
commit | 44efcf0db7b9204a77710a7f076c904d0e13e6fa (patch) | |
tree | 2114672b05239fcdb9554bcf26354a81b4f8565b | |
parent | a7e3047fea74a43174c063320fd0cb6bb6dcceb1 (diff) |
Split out HloDomainInstruction as subclass form HloInstruction.
PiperOrigin-RevId: 211916428
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 42 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.h | 19 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instructions.cc | 39 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instructions.h | 31 |
4 files changed, 96 insertions, 35 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 563aa695c9..f66a0ae9e7 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -465,6 +465,14 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( proto.dot_dimension_numbers(), precision_config); break; } + case HloOpcode::kDomain: + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Domain instruction should have 1 operands but sees " + << proto.operand_ids_size(); + instruction = absl::make_unique<HloDomainInstruction>( + proto.shape(), operands(0), /*operand_side_metadata=*/nullptr, + /*user_side_metadata=*/nullptr); + break; default: { instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape())); for (const int64 operand_id : proto.operand_ids()) { @@ -567,7 +575,6 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kClz: - case HloOpcode::kDomain: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: @@ -1137,12 +1144,9 @@ bool HloInstruction::HasSideEffect() const { const Shape& shape, HloInstruction* operand, std::unique_ptr<DomainMetadata> operand_side_metadata, std::unique_ptr<DomainMetadata> user_side_metadata) { - auto instruction = - absl::WrapUnique(new HloInstruction(HloOpcode::kDomain, shape)); - instruction->operand_side_metadata_ = std::move(operand_side_metadata); - instruction->user_side_metadata_ = std::move(user_side_metadata); - instruction->AppendOperand(operand); - return instruction; + return absl::make_unique<HloDomainInstruction>( + shape, operand, std::move(operand_side_metadata), + std::move(user_side_metadata)); } std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( @@ -1199,6 +1203,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( case HloOpcode::kScatter: case HloOpcode::kIota: case HloOpcode::kDot: + case HloOpcode::kDomain: clone = CloneWithNewOperandsImpl(shape, new_operands, context); break; // Unary ops. @@ -1295,12 +1300,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( true_computation(), new_operands[2], false_computation()); break; - case HloOpcode::kDomain: - CHECK_EQ(new_operands.size(), 1); - clone = - CreateDomain(shape, new_operands[0], operand_side_metadata_->Clone(), - user_side_metadata_->Clone()); - break; case HloOpcode::kAfterAll: if (new_operands.empty()) { clone = CreateToken(); @@ -1611,10 +1610,6 @@ bool HloInstruction::IdenticalSlowPath( return false; } - case HloOpcode::kDomain: - return operand_side_metadata().Matches(other.operand_side_metadata()) && - user_side_metadata().Matches(other.user_side_metadata()); - // Ops migrated to subclasses should never come to this line. // TODO(b/80131774): Remove this switch when migration is complete. case HloOpcode::kBatchNormTraining: @@ -1655,6 +1650,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kGather: case HloOpcode::kScatter: case HloOpcode::kDot: + case HloOpcode::kDomain: LOG(FATAL) << "Base class impl called for opcode with subclass: " << opcode(); } @@ -2114,11 +2110,6 @@ std::vector<string> HloInstruction::ExtraAttributesToString( }), "}")); } - if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) { - extra.push_back(StrCat("domain={kind=\"", operand_side_metadata_->Kind(), - "\", entry=", user_side_metadata_->ToString(), - ", exit=", operand_side_metadata_->ToString(), "}")); - } return extra; } @@ -3288,4 +3279,11 @@ const DotDimensionNumbers& HloInstruction::dot_dimension_numbers() const { return Cast<HloDotInstruction>(this)->dot_dimension_numbers(); } +const DomainMetadata& HloInstruction::operand_side_metadata() const { + return Cast<HloDomainInstruction>(this)->operand_side_metadata(); +} + +const DomainMetadata& HloInstruction::user_side_metadata() const { + return Cast<HloDomainInstruction>(this)->user_side_metadata(); +} } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index de60ddf42d..1619d1a985 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -1079,15 +1079,6 @@ class HloInstruction { return other->has_sharding() ? sharding() == other->sharding() : false; } - // Retrieves the operand side metadata of a kDomain instruction. - const DomainMetadata& operand_side_metadata() const { - return *operand_side_metadata_; - } - // Retrieves the user side metadata of a kDomain instruction. - const DomainMetadata& user_side_metadata() const { - return *user_side_metadata_; - } - // When creating a new instruction which either replaces, or shifts up (kCopy // insertion case), another instruction, we need to make sure the certain // properties of the new instruction are copied into the derived one. As of @@ -1496,6 +1487,12 @@ class HloInstruction { // Delegates to HloDotInstruction::dot_dimension_numbers(). const DotDimensionNumbers& dot_dimension_numbers() const; + // Delegates to HloDomainInstruction::operand_side_metadata(). + const DomainMetadata& operand_side_metadata() const; + + // Delegates to HloDomainInstruction::user_side_metadata(). + const DomainMetadata& user_side_metadata() const; + // Old methods kept for smooth subclassing transition END. protected: @@ -1641,10 +1638,6 @@ class HloInstruction { // many element tuples. std::shared_ptr<const HloSharding> sharding_; - // Fields used by the kDomain instruction. - std::unique_ptr<DomainMetadata> operand_side_metadata_; - std::unique_ptr<DomainMetadata> user_side_metadata_; - // Computations called by this instruction. std::vector<HloComputation*> called_computations_; diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 4e3e0c055e..76712d73db 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -2224,4 +2224,43 @@ string HloDotInstruction::DotDimensionNumbersToString() const { return StrJoin(result, ", "); } + +HloDomainInstruction::HloDomainInstruction( + const Shape& shape, HloInstruction* operand, + std::unique_ptr<DomainMetadata> operand_side_metadata, + std::unique_ptr<DomainMetadata> user_side_metadata) + : HloInstruction(HloOpcode::kDomain, shape), + operand_side_metadata_(std::move(operand_side_metadata)), + user_side_metadata_(std::move(user_side_metadata)) { + AppendOperand(operand); +} + +std::vector<string> HloDomainInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) { + return {StrCat("domain={kind=\"", operand_side_metadata_->Kind(), + "\", entry=", user_side_metadata_->ToString(), + ", exit=", operand_side_metadata_->ToString(), "}")}; + } + return {}; +} + +bool HloDomainInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + eq_computations) const { + const auto& casted_other = static_cast<const HloDomainInstruction&>(other); + return operand_side_metadata().Matches( + casted_other.operand_side_metadata()) && + user_side_metadata().Matches(casted_other.user_side_metadata()); +} + +std::unique_ptr<HloInstruction> HloDomainInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span<HloInstruction* const> new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return absl::make_unique<HloDomainInstruction>( + shape, new_operands[0], operand_side_metadata_->Clone(), + user_side_metadata_->Clone()); +} } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index e72ddabff9..af46148c70 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -1306,6 +1306,37 @@ class HloDotInstruction : public HloInstruction { DotDimensionNumbers dot_dimension_numbers_; }; +class HloDomainInstruction : public HloInstruction { + public: + explicit HloDomainInstruction( + const Shape& shape, HloInstruction* operand, + std::unique_ptr<DomainMetadata> operand_side_metadata, + std::unique_ptr<DomainMetadata> user_side_metadata); + + // Retrieves the operand side metadata of a kDomain instruction. + const DomainMetadata& operand_side_metadata() const { + return *operand_side_metadata_; + } + // Retrieves the user side metadata of a kDomain instruction. + const DomainMetadata& user_side_metadata() const { + return *user_side_metadata_; + } + + private: + std::vector<string> ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( + const Shape& shape, absl::Span<HloInstruction* const> new_operands, + HloCloneContext* context) const override; + + std::unique_ptr<DomainMetadata> operand_side_metadata_; + std::unique_ptr<DomainMetadata> user_side_metadata_; +}; } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ |