diff options
5 files changed, 97 insertions, 62 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index aa40fba9bb..a0db4563fb 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -2369,20 +2369,20 @@ TEST_P(ConvFilterPaddingTest, DoIt) { rhs_pad->shape().dimensions(3), testcase.orig_conv_window)) .ValueOrDie(); - auto* orig_conv = builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeInference::InferConvolveShape(input->shape(), rhs_pad->shape(), - /*feature_group_count=*/1, window, - dnums) - .ValueOrDie(), - input, rhs_pad, /*feature_group_count=*/1, window, dnums, - DefaultPrecisionConfig(2))); // Add a PrecisionConfig and check that AlgebraicSimplifier keeps it in place // after the transformation. PrecisionConfig precision_config; precision_config.add_operand_precision(PrecisionConfig::HIGH); precision_config.add_operand_precision(PrecisionConfig::HIGHEST); - orig_conv->set_precision_config(precision_config); + + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeInference::InferConvolveShape(input->shape(), rhs_pad->shape(), + /*feature_group_count=*/1, window, + dnums) + .ValueOrDie(), + input, rhs_pad, /*feature_group_count=*/1, window, dnums, + precision_config)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -2401,7 +2401,9 @@ TEST_P(ConvFilterPaddingTest, DoIt) { conv->operand(1)->shape().dimensions(2), conv->operand(1)->shape().dimensions(3), testcase.expected_conv_window)); - EXPECT_THAT(conv->precision_config().operand_precision(), + EXPECT_THAT(Cast<HloConvolutionInstruction>(conv) + ->precision_config() + .operand_precision(), ElementsAre(PrecisionConfig::HIGH, PrecisionConfig::HIGHEST)); } } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index f66a0ae9e7..25ae344ea5 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -2020,11 +2020,6 @@ std::vector<string> HloInstruction::ExtraAttributesToString( const HloPrintOptions& options) const { std::vector<string> extra = ExtraAttributesToStringImpl(options); - string precision_config_string = PrecisionConfigToString(); - if (!precision_config_string.empty()) { - extra.push_back(precision_config_string); - } - if (options.print_subcomputation_mode() == HloPrintOptions::PrintSubcomputationMode::kNameOnly) { if (opcode() == HloOpcode::kWhile) { @@ -2891,27 +2886,6 @@ StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) { return found->second; } -string HloInstruction::PrecisionConfigToString() const { - if (absl::c_all_of( - precision_config_.operand_precision(), [](int32 precision) { - return static_cast<PrecisionConfig::Precision>(precision) == - PrecisionConfig::DEFAULT; - })) { - return ""; - } - return StrCat( - "operand_precision={", - StrJoin( - precision_config_.operand_precision(), ",", - [](string* out, int32 precision) { - CHECK(PrecisionConfig::Precision_IsValid(precision)) << precision; - StrAppend(out, - PrecisionToString( - static_cast<PrecisionConfig::Precision>(precision))); - }), - "}"); -} - StatusOr<PrecisionConfig::Precision> StringToPrecision(const string& name) { static std::unordered_map<string, PrecisionConfig::Precision>* map = [] { static auto* map = @@ -2971,6 +2945,16 @@ Status HloInstruction::set_backend_config( return ret; } +const PrecisionConfig& HloInstruction::precision_config() const { + if (auto* convolution = DynCast<HloConvolutionInstruction>(this)) { + return convolution->precision_config(); + } + if (auto* dot = DynCast<HloDotInstruction>(this)) { + return dot->precision_config(); + } + LOG(FATAL) << "Unimplemented method."; +} + HloModule* HloInstruction::GetModule() const { if (parent_) { return parent_->parent(); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 1619d1a985..5581c17c2d 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -860,11 +860,6 @@ class HloInstruction { return false; } - if (!absl::c_equal(precision_config_.operand_precision(), - other.precision_config_.operand_precision())) { - return false; - } - return IdenticalSlowPath(other, eq_computations); } @@ -1086,9 +1081,6 @@ class HloInstruction { // instruction. void SetupDerivedInstruction(HloInstruction* derived_instruction) const; - // Returns the dump string of the precision configuration. - string PrecisionConfigToString() const; - // Clones the HLO instruction. The clone will have the same opcode, shape, and // operands. After creation the clone has no uses. "this" (the instruction // cloned from) is not changed. Suffix is the string to append to the name of @@ -1238,10 +1230,8 @@ class HloInstruction { // information. Transformations to other HLOs will not preserve this // information but it is presumed that the alternate lowering is strictly // superior. - const PrecisionConfig& precision_config() const { return precision_config_; } - void set_precision_config(const PrecisionConfig& precision_config) { - precision_config_ = precision_config; - } + // Precondition: opcode must be kConvolution or kDot. + const PrecisionConfig& precision_config() const; // Sets the debug metadata for this instruction. void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; } @@ -1651,10 +1641,6 @@ class HloInstruction { // HLO. See the documentation on backend_config(). string backend_config_; - // Information used to communicate to the implementation about the algorithm - // used to produce results. See the documentation on precision_config(). - PrecisionConfig precision_config_; - // String identifier for instruction. string name_; diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 76712d73db..fb7345a2ad 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -47,6 +47,27 @@ bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction, return instruction->IsElementwiseOnOperand(operand_index); }); } + +string PrecisionConfigToString(const PrecisionConfig& precision_config) { + if (absl::c_all_of(precision_config.operand_precision(), [](int32 precision) { + return static_cast<PrecisionConfig::Precision>(precision) == + PrecisionConfig::DEFAULT; + })) { + return ""; + } + + return StrCat( + "operand_precision={", + StrJoin( + precision_config.operand_precision(), ",", + [](string* out, int32 precision) { + CHECK(PrecisionConfig::Precision_IsValid(precision)) << precision; + StrAppend(out, + PrecisionToString( + static_cast<PrecisionConfig::Precision>(precision))); + }), + "}"); +} } // namespace HloBatchNormInstruction::HloBatchNormInstruction( @@ -1634,7 +1655,8 @@ HloConvolutionInstruction::HloConvolutionInstruction( : HloInstruction(HloOpcode::kConvolution, shape), feature_group_count_(feature_group_count), window_(window), - convolution_dimension_numbers_(dimension_numbers) { + convolution_dimension_numbers_(dimension_numbers), + precision_config_(precision_config) { if (window_util::HasBaseDilation(window)) { SetAndSanitizeName(StrCat(name(), "-base-dilated")); } @@ -1643,7 +1665,6 @@ HloConvolutionInstruction::HloConvolutionInstruction( } AppendOperand(lhs); AppendOperand(rhs); - set_precision_config(precision_config); } string HloConvolutionInstruction::ToCategory() const { @@ -1663,7 +1684,7 @@ HloInstructionProto HloConvolutionInstruction::ToProto() const { *proto.mutable_convolution_dimension_numbers() = convolution_dimension_numbers_; proto.set_feature_group_count(feature_group_count_); - *proto.mutable_precision_config() = precision_config(); + *proto.mutable_precision_config() = precision_config_; return proto; } @@ -1678,6 +1699,12 @@ std::vector<string> HloConvolutionInstruction::ExtraAttributesToStringImpl( if (feature_group_count_ != 1) { extra.push_back(StrCat("feature_group_count=", feature_group_count_)); } + + string precision_config_string = PrecisionConfigToString(precision_config_); + if (!precision_config_string.empty()) { + extra.push_back(precision_config_string); + } + return extra; } @@ -1693,7 +1720,9 @@ bool HloConvolutionInstruction::IdenticalSlowPath( return protobuf_util::ProtobufEquals(window(), casted_other.window()) && protobuf_util::ProtobufEquals( convolution_dimension_numbers(), - casted_other.convolution_dimension_numbers()); + casted_other.convolution_dimension_numbers()) && + protobuf_util::ProtobufEquals(precision_config(), + casted_other.precision_config()); } std::unique_ptr<HloInstruction> @@ -1703,7 +1732,7 @@ HloConvolutionInstruction::CloneWithNewOperandsImpl( CHECK_EQ(new_operands.size(), 2); return absl::make_unique<HloConvolutionInstruction>( shape, new_operands[0], new_operands[1], feature_group_count_, window(), - convolution_dimension_numbers_, precision_config()); + convolution_dimension_numbers_, precision_config_); } HloReduceWindowInstruction::HloReduceWindowInstruction( @@ -2167,22 +2196,28 @@ HloDotInstruction::HloDotInstruction( const DotDimensionNumbers& dimension_numbers, const PrecisionConfig& precision_config) : HloInstruction(HloOpcode::kDot, shape), - dot_dimension_numbers_(dimension_numbers) { + dot_dimension_numbers_(dimension_numbers), + precision_config_(precision_config) { AppendOperand(lhs); AppendOperand(rhs); - set_precision_config(precision_config); } HloInstructionProto HloDotInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); *proto.mutable_dot_dimension_numbers() = dot_dimension_numbers_; - *proto.mutable_precision_config() = precision_config(); + *proto.mutable_precision_config() = precision_config_; return proto; } std::vector<string> HloDotInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {DotDimensionNumbersToString()}; + std::vector<string> extra = {DotDimensionNumbersToString()}; + + string precision_config_string = PrecisionConfigToString(precision_config_); + if (!precision_config_string.empty()) { + extra.push_back(precision_config_string); + } + return extra; } bool HloDotInstruction::IdenticalSlowPath( @@ -2191,7 +2226,9 @@ bool HloDotInstruction::IdenticalSlowPath( eq_computations) const { const auto& casted_other = static_cast<const HloDotInstruction&>(other); return protobuf_util::ProtobufEquals(dot_dimension_numbers(), - casted_other.dot_dimension_numbers()); + casted_other.dot_dimension_numbers()) && + protobuf_util::ProtobufEquals(precision_config(), + casted_other.precision_config()); } std::unique_ptr<HloInstruction> HloDotInstruction::CloneWithNewOperandsImpl( @@ -2200,7 +2237,7 @@ std::unique_ptr<HloInstruction> HloDotInstruction::CloneWithNewOperandsImpl( CHECK_EQ(new_operands.size(), 2); return absl::make_unique<HloDotInstruction>( shape, new_operands[0], new_operands[1], dot_dimension_numbers_, - precision_config()); + precision_config_); } string HloDotInstruction::DotDimensionNumbersToString() const { diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index af46148c70..c3a7801164 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -957,6 +957,16 @@ class HloConvolutionInstruction : public HloInstruction { // The number of feature groups. Must be a divisor of the input feature // dimension and output feature dimension. int64 feature_group_count() const { return feature_group_count_; } + + // Returns the information used to tell the implementation information about + // what sort of precision is requested. The meaning of the field is backend + // specific. At the moment, it is only supported for kConvolution and kDot. + // Transformations on one kDot or kConvolution to another will preserve this + // information. Transformations to other HLOs will not preserve this + // information but it is presumed that the alternate lowering is strictly + // superior. + const PrecisionConfig& precision_config() const { return precision_config_; } + string ToCategory() const override; // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -979,6 +989,9 @@ class HloConvolutionInstruction : public HloInstruction { Window window_; // Describes the dimension numbers used for a convolution. ConvolutionDimensionNumbers convolution_dimension_numbers_; + // Information used to communicate to the implementation about the algorithm + // used to produce results. See the documentation on precision_config(). + PrecisionConfig precision_config_; }; class HloReduceWindowInstruction : public HloInstruction { @@ -1285,6 +1298,15 @@ class HloDotInstruction : public HloInstruction { return dot_dimension_numbers_; } + // Returns the information used to tell the implementation information about + // what sort of precision is requested. The meaning of the field is backend + // specific. At the moment, it is only supported for kConvolution and kDot. + // Transformations on one kDot or kConvolution to another will preserve this + // information. Transformations to other HLOs will not preserve this + // information but it is presumed that the alternate lowering is strictly + // superior. + const PrecisionConfig& precision_config() const { return precision_config_; } + // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -1304,6 +1326,10 @@ class HloDotInstruction : public HloInstruction { // Describes the dimension numbers used for a dot. DotDimensionNumbers dot_dimension_numbers_; + + // Information used to communicate to the implementation about the algorithm + // used to produce results. See the documentation on precision_config(). + PrecisionConfig precision_config_; }; class HloDomainInstruction : public HloInstruction { |