From 5f3281dd4a0d72cb51064599118088167878e0ef Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 12 Jun 2018 22:39:11 -0700 Subject: Split out HloGetTupleElementInstruction and HloReducePrecisionInstruction as subclasses from HloInstruction. PiperOrigin-RevId: 200337508 --- tensorflow/compiler/xla/service/hlo_instruction.cc | 76 ++++++++------------ tensorflow/compiler/xla/service/hlo_instruction.h | 37 +++------- .../compiler/xla/service/hlo_instructions.cc | 81 ++++++++++++++++++++++ tensorflow/compiler/xla/service/hlo_instructions.h | 55 +++++++++++++++ 4 files changed, 174 insertions(+), 75 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index aafb3b9dfd..39662d1735 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -233,6 +233,16 @@ StatusOr> HloInstruction::CreateFromProto( instruction = CreateParameter(proto.parameter_number(), proto.shape(), proto.name()); break; + case HloOpcode::kGetTupleElement: + CHECK_EQ(proto.operand_ids_size(), 1); + instruction = CreateGetTupleElement(proto.shape(), operands(0), + proto.tuple_index()); + break; + case HloOpcode::kReducePrecision: + instruction = + CreateReducePrecision(proto.shape(), operands(0), + proto.exponent_bits(), proto.mantissa_bits()); + break; default: { instruction = WrapUnique(new HloInstruction(opcode, proto.shape())); for (const int64 operand_id : proto.operand_ids()) { @@ -260,11 +270,9 @@ StatusOr> HloInstruction::CreateFromProto( TF_RET_CHECK(!proto.name().empty()); instruction->SetAndSanitizeName(proto.name()); - instruction->metadata_ = proto.metadata(); instruction->backend_config_ = proto.backend_config(); - instruction->tuple_index_ = proto.tuple_index(); if (proto.has_window()) { instruction->window_ = MakeUnique(proto.window()); } @@ -278,8 +286,6 @@ StatusOr> HloInstruction::CreateFromProto( 
MakeUnique(proto.dot_dimension_numbers()); } - instruction->exponent_bits_ = proto.exponent_bits(); - instruction->mantissa_bits_ = proto.mantissa_bits(); for (int64 dynamic_slice_size : proto.dynamic_slice_sizes()) { instruction->dynamic_slice_sizes_.push_back(dynamic_slice_size); } @@ -334,12 +340,7 @@ StatusOr> HloInstruction::CreateFromProto( /* static */ std::unique_ptr HloInstruction::CreateGetTupleElement(const Shape& shape, HloInstruction* operand, int64 index) { - CHECK(ShapeUtil::IsTuple(operand->shape())); - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kGetTupleElement, shape)); - instruction->tuple_index_ = index; - instruction->AppendOperand(operand); - return instruction; + return MakeUnique(shape, operand, index); } /* static */ std::unique_ptr HloInstruction::CreateRng( @@ -520,12 +521,8 @@ HloInstruction::CreateReducePrecision(const Shape& shape, HloInstruction* operand, const int exponent_bits, const int mantissa_bits) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kReducePrecision, shape)); - instruction->AppendOperand(operand); - instruction->exponent_bits_ = exponent_bits; - instruction->mantissa_bits_ = mantissa_bits; - return instruction; + return MakeUnique( + shape, operand, exponent_bits, mantissa_bits); } /* static */ std::unique_ptr @@ -1041,6 +1038,8 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kFusion: case HloOpcode::kRng: case HloOpcode::kParameter: + case HloOpcode::kGetTupleElement: + case HloOpcode::kReducePrecision: clone = CloneWithNewOperandsImpl(shape, new_operands, context); break; // Unary ops. 
@@ -1127,11 +1126,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( CHECK_EQ(new_operands.size(), 1); clone = CreateBitcastConvert(shape, new_operands[0]); break; - case HloOpcode::kReducePrecision: - CHECK_EQ(new_operands.size(), 1); - clone = CreateReducePrecision(shape, new_operands[0], exponent_bits_, - mantissa_bits_); - break; case HloOpcode::kConvolution: CHECK_EQ(new_operands.size(), 2); clone = CreateConvolve(shape, new_operands[0], new_operands[1], *window_, @@ -1147,10 +1141,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( CreateCrossReplicaSum(shape, new_operands, to_apply(), replica_group_ids_, cross_replica_sum_barrier_); break; - case HloOpcode::kGetTupleElement: - CHECK_EQ(new_operands.size(), 1); - clone = CreateGetTupleElement(shape, new_operands[0], tuple_index()); - break; case HloOpcode::kPad: CHECK_EQ(new_operands.size(), 2); clone = @@ -1297,11 +1287,6 @@ const HloInstruction* HloInstruction::LatestNonGteAncestor() const { return hlo; } -int64 HloInstruction::tuple_index() const { - CHECK_EQ(HloOpcode::kGetTupleElement, opcode_); - return tuple_index_; -} - const HloInstruction* HloInstruction::operand(int64 i) const { return operands_[i]; } @@ -1464,11 +1449,6 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kGenerateToken: return false; - // A reduce-precision operation is determined by the bit sizes. - case HloOpcode::kReducePrecision: - return exponent_bits() == other.exponent_bits() && - mantissa_bits() == other.mantissa_bits(); - // Convolution has a window and dimensions. case HloOpcode::kConvolution: return protobuf_util::ProtobufEquals(window(), other.window()) && @@ -1497,8 +1477,6 @@ bool HloInstruction::IdenticalSlowPath( protobuf_util::ProtobufEquals(window(), other.window()); // Remaining instructions with special values. 
- case HloOpcode::kGetTupleElement: - return tuple_index() == other.tuple_index(); case HloOpcode::kPad: return protobuf_util::ProtobufEquals(padding_config(), other.padding_config()); @@ -1555,6 +1533,8 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kFusion: case HloOpcode::kRng: case HloOpcode::kParameter: + case HloOpcode::kGetTupleElement: + case HloOpcode::kReducePrecision: LOG(FATAL) << "Base class impl called for opcode with subclass: " << opcode(); } @@ -2044,9 +2024,6 @@ std::vector HloInstruction::ExtraAttributesToString( } } - if (opcode() == HloOpcode::kGetTupleElement) { - extra.push_back(StrCat("index=", tuple_index())); - } if (has_sharding()) { extra.push_back(StrCat("sharding=", sharding().ToString())); } @@ -2066,10 +2043,6 @@ std::vector HloInstruction::ExtraAttributesToString( extra.push_back( StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\"")); } - if (opcode() == HloOpcode::kReducePrecision) { - extra.push_back(StrCat("exponent_bits=", exponent_bits_)); - extra.push_back(StrCat("mantissa_bits=", mantissa_bits_)); - } if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) { extra.push_back(StrCat("domain={kind=\"", operand_side_metadata_->Kind(), "\", entry=", operand_side_metadata_->ToString(), @@ -2127,7 +2100,6 @@ HloInstructionProto HloInstruction::ToProto() const { } } - proto.set_tuple_index(tuple_index_); if (window_ != nullptr) { *proto.mutable_window() = *window_; } @@ -2147,8 +2119,6 @@ HloInstructionProto HloInstruction::ToProto() const { } } - proto.set_exponent_bits(exponent_bits_); - proto.set_mantissa_bits(mantissa_bits_); for (int64 slice_size : dynamic_slice_sizes_) { proto.add_dynamic_slice_sizes(slice_size); } @@ -3186,4 +3156,16 @@ RandomDistribution HloInstruction::random_distribution() const { int64 HloInstruction::parameter_number() const { return Cast(this)->parameter_number(); } + +int64 HloInstruction::tuple_index() const { + return Cast(this)->tuple_index(); +} + +int32 
HloInstruction::exponent_bits() const { + return Cast(this)->exponent_bits(); +} + +int32 HloInstruction::mantissa_bits() const { + return Cast(this)->mantissa_bits(); +} } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 245c9e56f1..a206cdab27 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -876,11 +876,6 @@ class HloInstruction { template Status Visit(DfsHloVisitorBase* visitor); - // Returns the tuple index associated with this instruction. - // - // Precondition: opcode() == HloOpcode::kGetTupleElement - int64 tuple_index() const; - // Returns the first non-GetTupleElement ancestor instruction of 'hlo'. // If the first non-GTE ancestor is tuple-shaped, populates 'index' with the // (possibly nested) tuple indices used on the path from ancestor to 'hlo'. @@ -1078,22 +1073,6 @@ class HloInstruction { return dynamic_slice_sizes_; } - // Returns the number of exponent bits for a reduce-precision node. - // - // Precondition: opcode() == HloOpcode::kReducePrecision - int32 exponent_bits() const { - CHECK_EQ(HloOpcode::kReducePrecision, opcode_); - return exponent_bits_; - } - - // Returns the number of mantissa bits for a reduce-precision node. - // - // Precondition: opcode() == HloOpcode::kReducePrecision - int32 mantissa_bits() const { - CHECK_EQ(HloOpcode::kReducePrecision, opcode_); - return mantissa_bits_; - } - // Returns data on the window in a windowed operation such as // convolution. const Window& window() const { @@ -1439,6 +1418,15 @@ class HloInstruction { // Delegates to HloParameterInstruction::parameter_number. int64 parameter_number() const; + + // Delegates to HloGetTupleElementInstruction::tuple_index. + int64 tuple_index() const; + + // Returns the number of exponent bits for a reduce-precision node. 
+ int32 exponent_bits() const; + + // Returns the number of mantissa bits for a reduce-precision node. + int32 mantissa_bits() const; // Old methods kept for smooth subclassing transition END. // Returns the group ids of each replica for CrossReplicaSum op. @@ -1573,9 +1561,6 @@ class HloInstruction { // Result shape of this instruction. Shape shape_; - // Constant index, only present for kGetTupleElement. - int64 tuple_index_ = -1; - // Describes the window in a windowed operation such as convolution. std::unique_ptr window_; @@ -1588,10 +1573,6 @@ class HloInstruction { std::unique_ptr gather_dimension_numbers_; std::vector gather_window_bounds_; - // The bit sizes for a reduce-precision operation. - int32 exponent_bits_ = 0; - int32 mantissa_bits_ = 0; - // Describes the [start, start + size) range size for a dynamic slice // ('start' is specified dynamically in the second operand of the operation). std::vector dynamic_slice_sizes_; diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 22c8707e37..d326d5d009 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -1203,4 +1203,85 @@ HloParameterInstruction::CloneWithNewOperandsImpl( HloCloneContext* context) const { return MakeUnique(parameter_number_, shape, name()); } + +HloGetTupleElementInstruction::HloGetTupleElementInstruction( + const Shape& shape, HloInstruction* operand, int64 index) + : HloInstruction(HloOpcode::kGetTupleElement, shape), tuple_index_(index) { + CHECK(ShapeUtil::IsTuple(operand->shape())); + AppendOperand(operand); +} + +HloInstructionProto HloGetTupleElementInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_tuple_index(tuple_index_); + return proto; +} + +std::vector HloGetTupleElementInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("index=", 
tuple_index())}; +} + +bool HloGetTupleElementInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = + static_cast(other); + return tuple_index() == casted_other.tuple_index(); +} + +std::unique_ptr +HloGetTupleElementInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(shape, new_operands[0], + tuple_index()); +} + +HloReducePrecisionInstruction::HloReducePrecisionInstruction( + const Shape& shape, HloInstruction* operand, const int exponent_bits, + const int mantissa_bits) + : HloInstruction(HloOpcode::kReducePrecision, shape), + exponent_bits_(exponent_bits), + mantissa_bits_(mantissa_bits) { + AppendOperand(operand); +} + +HloInstructionProto HloReducePrecisionInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_exponent_bits(exponent_bits_); + proto.set_mantissa_bits(mantissa_bits_); + return proto; +} + +std::vector HloReducePrecisionInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("exponent_bits=", exponent_bits_), + StrCat("mantissa_bits=", mantissa_bits_)}; +} + +bool HloReducePrecisionInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = + static_cast(other); + // A reduce-precision operation is determined by the bit sizes. 
+ return exponent_bits() == casted_other.exponent_bits() && + mantissa_bits() == casted_other.mantissa_bits(); +} + +std::unique_ptr +HloReducePrecisionInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique( + shape, new_operands[0], exponent_bits(), mantissa_bits()); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index bab2a48166..6749d87555 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -667,6 +667,61 @@ class HloParameterInstruction : public HloInstruction { int64 parameter_number_ = 0; }; +class HloGetTupleElementInstruction : public HloInstruction { + public: + explicit HloGetTupleElementInstruction(const Shape& shape, + HloInstruction* operand, int64 index); + // Returns the tuple index associated with this instruction. + int64 tuple_index() const { return tuple_index_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + int64 tuple_index_ = -1; +}; + +class HloReducePrecisionInstruction : public HloInstruction { + public: + explicit HloReducePrecisionInstruction(const Shape& shape, + HloInstruction* operand, + const int exponent_bits, + const int mantissa_bits); + // Returns the number of exponent bits for a reduce-precision node. 
+ int32 exponent_bits() const { return exponent_bits_; } + // Returns the number of mantissa bits for a reduce-precision node. + int32 mantissa_bits() const { return mantissa_bits_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // The bit sizes for a reduce-precision operation. + int32 exponent_bits_ = 0; + int32 mantissa_bits_ = 0; +}; } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ -- cgit v1.2.3