aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instructions.cc
diff options
context:
space:
mode:
authorGravatar David Majnemer <majnemer@google.com>2018-09-07 09:11:28 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-07 09:16:11 -0700
commit81110ff2beb38a2cbfbefb69a9b640bf67a8558a (patch)
tree7d3ca9d6292ee6b2163715b20c46d187dc58a89d /tensorflow/compiler/xla/service/hlo_instructions.cc
parent57c9485b24655ec1b640ef10ae3debde280d8d60 (diff)
[XLA] Sink PrecisionConfig into Hlo{Dot,Convolution}Instruction
This field only makes sense on kDot & kConvolution. This should shave a few more bytes off of HloInstruction and remove methods that aren't applicable on many HLOs. PiperOrigin-RevId: 211985502
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc59
1 files changed, 48 insertions, 11 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 76712d73db..fb7345a2ad 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -47,6 +47,27 @@ bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction,
return instruction->IsElementwiseOnOperand(operand_index);
});
}
+
+string PrecisionConfigToString(const PrecisionConfig& precision_config) {
+ if (absl::c_all_of(precision_config.operand_precision(), [](int32 precision) {
+ return static_cast<PrecisionConfig::Precision>(precision) ==
+ PrecisionConfig::DEFAULT;
+ })) {
+ return "";
+ }
+
+ return StrCat(
+ "operand_precision={",
+ StrJoin(
+ precision_config.operand_precision(), ",",
+ [](string* out, int32 precision) {
+ CHECK(PrecisionConfig::Precision_IsValid(precision)) << precision;
+ StrAppend(out,
+ PrecisionToString(
+ static_cast<PrecisionConfig::Precision>(precision)));
+ }),
+ "}");
+}
} // namespace
HloBatchNormInstruction::HloBatchNormInstruction(
@@ -1634,7 +1655,8 @@ HloConvolutionInstruction::HloConvolutionInstruction(
: HloInstruction(HloOpcode::kConvolution, shape),
feature_group_count_(feature_group_count),
window_(window),
- convolution_dimension_numbers_(dimension_numbers) {
+ convolution_dimension_numbers_(dimension_numbers),
+ precision_config_(precision_config) {
if (window_util::HasBaseDilation(window)) {
SetAndSanitizeName(StrCat(name(), "-base-dilated"));
}
@@ -1643,7 +1665,6 @@ HloConvolutionInstruction::HloConvolutionInstruction(
}
AppendOperand(lhs);
AppendOperand(rhs);
- set_precision_config(precision_config);
}
string HloConvolutionInstruction::ToCategory() const {
@@ -1663,7 +1684,7 @@ HloInstructionProto HloConvolutionInstruction::ToProto() const {
*proto.mutable_convolution_dimension_numbers() =
convolution_dimension_numbers_;
proto.set_feature_group_count(feature_group_count_);
- *proto.mutable_precision_config() = precision_config();
+ *proto.mutable_precision_config() = precision_config_;
return proto;
}
@@ -1678,6 +1699,12 @@ std::vector<string> HloConvolutionInstruction::ExtraAttributesToStringImpl(
if (feature_group_count_ != 1) {
extra.push_back(StrCat("feature_group_count=", feature_group_count_));
}
+
+ string precision_config_string = PrecisionConfigToString(precision_config_);
+ if (!precision_config_string.empty()) {
+ extra.push_back(precision_config_string);
+ }
+
return extra;
}
@@ -1693,7 +1720,9 @@ bool HloConvolutionInstruction::IdenticalSlowPath(
return protobuf_util::ProtobufEquals(window(), casted_other.window()) &&
protobuf_util::ProtobufEquals(
convolution_dimension_numbers(),
- casted_other.convolution_dimension_numbers());
+ casted_other.convolution_dimension_numbers()) &&
+ protobuf_util::ProtobufEquals(precision_config(),
+ casted_other.precision_config());
}
std::unique_ptr<HloInstruction>
@@ -1703,7 +1732,7 @@ HloConvolutionInstruction::CloneWithNewOperandsImpl(
CHECK_EQ(new_operands.size(), 2);
return absl::make_unique<HloConvolutionInstruction>(
shape, new_operands[0], new_operands[1], feature_group_count_, window(),
- convolution_dimension_numbers_, precision_config());
+ convolution_dimension_numbers_, precision_config_);
}
HloReduceWindowInstruction::HloReduceWindowInstruction(
@@ -2167,22 +2196,28 @@ HloDotInstruction::HloDotInstruction(
const DotDimensionNumbers& dimension_numbers,
const PrecisionConfig& precision_config)
: HloInstruction(HloOpcode::kDot, shape),
- dot_dimension_numbers_(dimension_numbers) {
+ dot_dimension_numbers_(dimension_numbers),
+ precision_config_(precision_config) {
AppendOperand(lhs);
AppendOperand(rhs);
- set_precision_config(precision_config);
}
HloInstructionProto HloDotInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
*proto.mutable_dot_dimension_numbers() = dot_dimension_numbers_;
- *proto.mutable_precision_config() = precision_config();
+ *proto.mutable_precision_config() = precision_config_;
return proto;
}
std::vector<string> HloDotInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
- return {DotDimensionNumbersToString()};
+ std::vector<string> extra = {DotDimensionNumbersToString()};
+
+ string precision_config_string = PrecisionConfigToString(precision_config_);
+ if (!precision_config_string.empty()) {
+ extra.push_back(precision_config_string);
+ }
+ return extra;
}
bool HloDotInstruction::IdenticalSlowPath(
@@ -2191,7 +2226,9 @@ bool HloDotInstruction::IdenticalSlowPath(
eq_computations) const {
const auto& casted_other = static_cast<const HloDotInstruction&>(other);
return protobuf_util::ProtobufEquals(dot_dimension_numbers(),
- casted_other.dot_dimension_numbers());
+ casted_other.dot_dimension_numbers()) &&
+ protobuf_util::ProtobufEquals(precision_config(),
+ casted_other.precision_config());
}
std::unique_ptr<HloInstruction> HloDotInstruction::CloneWithNewOperandsImpl(
@@ -2200,7 +2237,7 @@ std::unique_ptr<HloInstruction> HloDotInstruction::CloneWithNewOperandsImpl(
CHECK_EQ(new_operands.size(), 2);
return absl::make_unique<HloDotInstruction>(
shape, new_operands[0], new_operands[1], dot_dimension_numbers_,
- precision_config());
+ precision_config_);
}
string HloDotInstruction::DotDimensionNumbersToString() const {