diff options
author | David Majnemer <majnemer@google.com> | 2018-09-07 09:11:28 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-07 09:16:11 -0700 |
commit | 81110ff2beb38a2cbfbefb69a9b640bf67a8558a (patch) | |
tree | 7d3ca9d6292ee6b2163715b20c46d187dc58a89d /tensorflow/compiler/xla/service/hlo_instructions.cc | |
parent | 57c9485b24655ec1b640ef10ae3debde280d8d60 (diff) |
[XLA] Sink PrecisionConfig into Hlo{Dot,Convolution}Instruction
This field only makes sense on kDot & kConvolution. This should shave a few
more bytes off of HloInstruction and remove methods that aren't applicable on
many HLOs.
PiperOrigin-RevId: 211985502
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instructions.cc | 59 |
1 files changed, 48 insertions, 11 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 76712d73db..fb7345a2ad 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -47,6 +47,27 @@ bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction, return instruction->IsElementwiseOnOperand(operand_index); }); } + +string PrecisionConfigToString(const PrecisionConfig& precision_config) { + if (absl::c_all_of(precision_config.operand_precision(), [](int32 precision) { + return static_cast<PrecisionConfig::Precision>(precision) == + PrecisionConfig::DEFAULT; + })) { + return ""; + } + + return StrCat( + "operand_precision={", + StrJoin( + precision_config.operand_precision(), ",", + [](string* out, int32 precision) { + CHECK(PrecisionConfig::Precision_IsValid(precision)) << precision; + StrAppend(out, + PrecisionToString( + static_cast<PrecisionConfig::Precision>(precision))); + }), + "}"); +} } // namespace HloBatchNormInstruction::HloBatchNormInstruction( @@ -1634,7 +1655,8 @@ HloConvolutionInstruction::HloConvolutionInstruction( : HloInstruction(HloOpcode::kConvolution, shape), feature_group_count_(feature_group_count), window_(window), - convolution_dimension_numbers_(dimension_numbers) { + convolution_dimension_numbers_(dimension_numbers), + precision_config_(precision_config) { if (window_util::HasBaseDilation(window)) { SetAndSanitizeName(StrCat(name(), "-base-dilated")); } @@ -1643,7 +1665,6 @@ HloConvolutionInstruction::HloConvolutionInstruction( } AppendOperand(lhs); AppendOperand(rhs); - set_precision_config(precision_config); } string HloConvolutionInstruction::ToCategory() const { @@ -1663,7 +1684,7 @@ HloInstructionProto HloConvolutionInstruction::ToProto() const { *proto.mutable_convolution_dimension_numbers() = convolution_dimension_numbers_; proto.set_feature_group_count(feature_group_count_); - *proto.mutable_precision_config() = precision_config(); + *proto.mutable_precision_config() = precision_config_; return proto; } @@ -1678,6 +1699,12 @@ std::vector<string> HloConvolutionInstruction::ExtraAttributesToStringImpl( if (feature_group_count_ != 1) { extra.push_back(StrCat("feature_group_count=", feature_group_count_)); } + + string precision_config_string = PrecisionConfigToString(precision_config_); + if (!precision_config_string.empty()) { + extra.push_back(precision_config_string); + } + return extra; } @@ -1693,7 +1720,9 @@ bool HloConvolutionInstruction::IdenticalSlowPath( return protobuf_util::ProtobufEquals(window(), casted_other.window()) && protobuf_util::ProtobufEquals( convolution_dimension_numbers(), - casted_other.convolution_dimension_numbers()); + casted_other.convolution_dimension_numbers()) && + protobuf_util::ProtobufEquals(precision_config(), + casted_other.precision_config()); } std::unique_ptr<HloInstruction> @@ -1703,7 +1732,7 @@ HloConvolutionInstruction::CloneWithNewOperandsImpl( CHECK_EQ(new_operands.size(), 2); return absl::make_unique<HloConvolutionInstruction>( shape, new_operands[0], new_operands[1], feature_group_count_, window(), - convolution_dimension_numbers_, precision_config()); + convolution_dimension_numbers_, precision_config_); } HloReduceWindowInstruction::HloReduceWindowInstruction( @@ -2167,22 +2196,28 @@ HloDotInstruction::HloDotInstruction( const DotDimensionNumbers& dimension_numbers, const PrecisionConfig& precision_config) : HloInstruction(HloOpcode::kDot, shape), - dot_dimension_numbers_(dimension_numbers) { + dot_dimension_numbers_(dimension_numbers), + precision_config_(precision_config) { AppendOperand(lhs); AppendOperand(rhs); - set_precision_config(precision_config); } HloInstructionProto HloDotInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); *proto.mutable_dot_dimension_numbers() = dot_dimension_numbers_; - *proto.mutable_precision_config() = precision_config(); + *proto.mutable_precision_config() = precision_config_; return proto; } std::vector<string> HloDotInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {DotDimensionNumbersToString()}; + std::vector<string> extra = {DotDimensionNumbersToString()}; + + string precision_config_string = PrecisionConfigToString(precision_config_); + if (!precision_config_string.empty()) { + extra.push_back(precision_config_string); + } + return extra; } bool HloDotInstruction::IdenticalSlowPath( @@ -2191,7 +2226,9 @@ bool HloDotInstruction::IdenticalSlowPath( eq_computations) const { const auto& casted_other = static_cast<const HloDotInstruction&>(other); return protobuf_util::ProtobufEquals(dot_dimension_numbers(), - casted_other.dot_dimension_numbers()); + casted_other.dot_dimension_numbers()) && + protobuf_util::ProtobufEquals(precision_config(), + casted_other.precision_config()); } std::unique_ptr<HloInstruction> HloDotInstruction::CloneWithNewOperandsImpl( @@ -2200,7 +2237,7 @@ std::unique_ptr<HloInstruction> HloDotInstruction::CloneWithNewOperandsImpl( CHECK_EQ(new_operands.size(), 2); return absl::make_unique<HloDotInstruction>( shape, new_operands[0], new_operands[1], dot_dimension_numbers_, - precision_config()); + precision_config_); } string HloDotInstruction::DotDimensionNumbersToString() const { |