 tensorflow/compiler/xla/service/algebraic_simplifier_test.cc | 20
 tensorflow/compiler/xla/service/hlo_instruction.cc           | 36
 tensorflow/compiler/xla/service/hlo_instruction.h            | 18
 tensorflow/compiler/xla/service/hlo_instructions.cc          | 59
 tensorflow/compiler/xla/service/hlo_instructions.h           | 26
 5 files changed, 97 insertions(+), 62 deletions(-)
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index aa40fba9bb..a0db4563fb 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -2369,20 +2369,20 @@ TEST_P(ConvFilterPaddingTest, DoIt) {
rhs_pad->shape().dimensions(3),
testcase.orig_conv_window))
.ValueOrDie();
- auto* orig_conv = builder.AddInstruction(HloInstruction::CreateConvolve(
- ShapeInference::InferConvolveShape(input->shape(), rhs_pad->shape(),
- /*feature_group_count=*/1, window,
- dnums)
- .ValueOrDie(),
- input, rhs_pad, /*feature_group_count=*/1, window, dnums,
- DefaultPrecisionConfig(2)));
// Add a PrecisionConfig and check that AlgebraicSimplifier keeps it in place
// after the transformation.
PrecisionConfig precision_config;
precision_config.add_operand_precision(PrecisionConfig::HIGH);
precision_config.add_operand_precision(PrecisionConfig::HIGHEST);
- orig_conv->set_precision_config(precision_config);
+
+ builder.AddInstruction(HloInstruction::CreateConvolve(
+ ShapeInference::InferConvolveShape(input->shape(), rhs_pad->shape(),
+ /*feature_group_count=*/1, window,
+ dnums)
+ .ValueOrDie(),
+ input, rhs_pad, /*feature_group_count=*/1, window, dnums,
+ precision_config));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
@@ -2401,7 +2401,9 @@ TEST_P(ConvFilterPaddingTest, DoIt) {
conv->operand(1)->shape().dimensions(2),
conv->operand(1)->shape().dimensions(3),
testcase.expected_conv_window));
- EXPECT_THAT(conv->precision_config().operand_precision(),
+ EXPECT_THAT(Cast<HloConvolutionInstruction>(conv)
+ ->precision_config()
+ .operand_precision(),
ElementsAre(PrecisionConfig::HIGH, PrecisionConfig::HIGHEST));
}
}
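With the base-class setter gone, the test above has to build the PrecisionConfig before creating the convolution and can only read it back through the concrete subclass. A minimal sketch of that pattern, using only calls that appear in this hunk (builder, input, rhs_pad, window, dnums and the inferred shape conv_shape are assumed to be set up as in the test):

  // Build the config up front; there is no setter to call afterwards.
  PrecisionConfig precision_config;
  precision_config.add_operand_precision(PrecisionConfig::HIGH);     // operand 0
  precision_config.add_operand_precision(PrecisionConfig::HIGHEST);  // operand 1

  auto* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
      conv_shape, input, rhs_pad, /*feature_group_count=*/1, window, dnums,
      precision_config));

  // Reading it back goes through the subclass; Cast CHECK-fails if the
  // instruction is not actually a kConvolution.
  const PrecisionConfig& read_back =
      Cast<HloConvolutionInstruction>(conv)->precision_config();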
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index f66a0ae9e7..25ae344ea5 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -2020,11 +2020,6 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
const HloPrintOptions& options) const {
std::vector<string> extra = ExtraAttributesToStringImpl(options);
- string precision_config_string = PrecisionConfigToString();
- if (!precision_config_string.empty()) {
- extra.push_back(precision_config_string);
- }
-
if (options.print_subcomputation_mode() ==
HloPrintOptions::PrintSubcomputationMode::kNameOnly) {
if (opcode() == HloOpcode::kWhile) {
@@ -2891,27 +2886,6 @@ StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
return found->second;
}
-string HloInstruction::PrecisionConfigToString() const {
- if (absl::c_all_of(
- precision_config_.operand_precision(), [](int32 precision) {
- return static_cast<PrecisionConfig::Precision>(precision) ==
- PrecisionConfig::DEFAULT;
- })) {
- return "";
- }
- return StrCat(
- "operand_precision={",
- StrJoin(
- precision_config_.operand_precision(), ",",
- [](string* out, int32 precision) {
- CHECK(PrecisionConfig::Precision_IsValid(precision)) << precision;
- StrAppend(out,
- PrecisionToString(
- static_cast<PrecisionConfig::Precision>(precision)));
- }),
- "}");
-}
-
StatusOr<PrecisionConfig::Precision> StringToPrecision(const string& name) {
static std::unordered_map<string, PrecisionConfig::Precision>* map = [] {
static auto* map =
@@ -2971,6 +2945,16 @@ Status HloInstruction::set_backend_config(
return ret;
}
+const PrecisionConfig& HloInstruction::precision_config() const {
+ if (auto* convolution = DynCast<HloConvolutionInstruction>(this)) {
+ return convolution->precision_config();
+ }
+ if (auto* dot = DynCast<HloDotInstruction>(this)) {
+ return dot->precision_config();
+ }
+ LOG(FATAL) << "Unimplemented method.";
+}
+
HloModule* HloInstruction::GetModule() const {
if (parent_) {
return parent_->parent();
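The new base-class accessor above LOG(FATAL)s for any opcode other than kConvolution or kDot, so generic passes should check the opcode (or DynCast themselves) before calling it. A sketch of such a caller; the helper name is hypothetical and only the header touched by this diff is assumed:

  #include "tensorflow/compiler/xla/service/hlo_instruction.h"

  namespace xla {

  // Hypothetical helper: fills *config only for the two opcodes that still
  // carry a PrecisionConfig after this change, and returns whether it did.
  bool TryGetPrecisionConfig(const HloInstruction* instr,
                             PrecisionConfig* config) {
    if (instr->opcode() != HloOpcode::kConvolution &&
        instr->opcode() != HloOpcode::kDot) {
      return false;
    }
    *config = instr->precision_config();  // the dispatch above succeeds here
    return true;
  }

  }  // namespace xla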
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 1619d1a985..5581c17c2d 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -860,11 +860,6 @@ class HloInstruction {
return false;
}
- if (!absl::c_equal(precision_config_.operand_precision(),
- other.precision_config_.operand_precision())) {
- return false;
- }
-
return IdenticalSlowPath(other, eq_computations);
}
@@ -1086,9 +1081,6 @@ class HloInstruction {
// instruction.
void SetupDerivedInstruction(HloInstruction* derived_instruction) const;
- // Returns the dump string of the precision configuration.
- string PrecisionConfigToString() const;
-
// Clones the HLO instruction. The clone will have the same opcode, shape, and
// operands. After creation the clone has no uses. "this" (the instruction
// cloned from) is not changed. Suffix is the string to append to the name of
@@ -1238,10 +1230,8 @@ class HloInstruction {
// information. Transformations to other HLOs will not preserve this
// information but it is presumed that the alternate lowering is strictly
// superior.
- const PrecisionConfig& precision_config() const { return precision_config_; }
- void set_precision_config(const PrecisionConfig& precision_config) {
- precision_config_ = precision_config;
- }
+ // Precondition: opcode must be kConvolution or kDot.
+ const PrecisionConfig& precision_config() const;
// Sets the debug metadata for this instruction.
void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; }
@@ -1651,10 +1641,6 @@ class HloInstruction {
// HLO. See the documentation on backend_config().
string backend_config_;
- // Information used to communicate to the implementation about the algorithm
- // used to produce results. See the documentation on precision_config().
- PrecisionConfig precision_config_;
-
// String identifier for instruction.
string name_;
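Because the setter and the stored PrecisionConfig have been removed from the base class, call sites can no longer attach a config after the fact; it must be supplied to the factory function. A sketch under the assumption that CreateDot forwards the config to the HloDotInstruction constructor changed in the next file (dot_shape, lhs, rhs and dot_dnums are placeholders for values already in scope):

  PrecisionConfig precision_config;
  precision_config.add_operand_precision(PrecisionConfig::HIGHEST);
  precision_config.add_operand_precision(PrecisionConfig::DEFAULT);

  // The config is fixed at creation time; set_precision_config() no longer exists.
  std::unique_ptr<HloInstruction> dot = HloInstruction::CreateDot(
      dot_shape, lhs, rhs, dot_dnums, precision_config);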
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 76712d73db..fb7345a2ad 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -47,6 +47,27 @@ bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction,
return instruction->IsElementwiseOnOperand(operand_index);
});
}
+
+string PrecisionConfigToString(const PrecisionConfig& precision_config) {
+ if (absl::c_all_of(precision_config.operand_precision(), [](int32 precision) {
+ return static_cast<PrecisionConfig::Precision>(precision) ==
+ PrecisionConfig::DEFAULT;
+ })) {
+ return "";
+ }
+
+ return StrCat(
+ "operand_precision={",
+ StrJoin(
+ precision_config.operand_precision(), ",",
+ [](string* out, int32 precision) {
+ CHECK(PrecisionConfig::Precision_IsValid(precision)) << precision;
+ StrAppend(out,
+ PrecisionToString(
+ static_cast<PrecisionConfig::Precision>(precision)));
+ }),
+ "}");
+}
} // namespace
HloBatchNormInstruction::HloBatchNormInstruction(
@@ -1634,7 +1655,8 @@ HloConvolutionInstruction::HloConvolutionInstruction(
: HloInstruction(HloOpcode::kConvolution, shape),
feature_group_count_(feature_group_count),
window_(window),
- convolution_dimension_numbers_(dimension_numbers) {
+ convolution_dimension_numbers_(dimension_numbers),
+ precision_config_(precision_config) {
if (window_util::HasBaseDilation(window)) {
SetAndSanitizeName(StrCat(name(), "-base-dilated"));
}
@@ -1643,7 +1665,6 @@ HloConvolutionInstruction::HloConvolutionInstruction(
}
AppendOperand(lhs);
AppendOperand(rhs);
- set_precision_config(precision_config);
}
string HloConvolutionInstruction::ToCategory() const {
@@ -1663,7 +1684,7 @@ HloInstructionProto HloConvolutionInstruction::ToProto() const {
*proto.mutable_convolution_dimension_numbers() =
convolution_dimension_numbers_;
proto.set_feature_group_count(feature_group_count_);
- *proto.mutable_precision_config() = precision_config();
+ *proto.mutable_precision_config() = precision_config_;
return proto;
}
@@ -1678,6 +1699,12 @@ std::vector<string> HloConvolutionInstruction::ExtraAttributesToStringImpl(
if (feature_group_count_ != 1) {
extra.push_back(StrCat("feature_group_count=", feature_group_count_));
}
+
+ string precision_config_string = PrecisionConfigToString(precision_config_);
+ if (!precision_config_string.empty()) {
+ extra.push_back(precision_config_string);
+ }
+
return extra;
}
@@ -1693,7 +1720,9 @@ bool HloConvolutionInstruction::IdenticalSlowPath(
return protobuf_util::ProtobufEquals(window(), casted_other.window()) &&
protobuf_util::ProtobufEquals(
convolution_dimension_numbers(),
- casted_other.convolution_dimension_numbers());
+ casted_other.convolution_dimension_numbers()) &&
+ protobuf_util::ProtobufEquals(precision_config(),
+ casted_other.precision_config());
}
std::unique_ptr<HloInstruction>
@@ -1703,7 +1732,7 @@ HloConvolutionInstruction::CloneWithNewOperandsImpl(
CHECK_EQ(new_operands.size(), 2);
return absl::make_unique<HloConvolutionInstruction>(
shape, new_operands[0], new_operands[1], feature_group_count_, window(),
- convolution_dimension_numbers_, precision_config());
+ convolution_dimension_numbers_, precision_config_);
}
HloReduceWindowInstruction::HloReduceWindowInstruction(
@@ -2167,22 +2196,28 @@ HloDotInstruction::HloDotInstruction(
const DotDimensionNumbers& dimension_numbers,
const PrecisionConfig& precision_config)
: HloInstruction(HloOpcode::kDot, shape),
- dot_dimension_numbers_(dimension_numbers) {
+ dot_dimension_numbers_(dimension_numbers),
+ precision_config_(precision_config) {
AppendOperand(lhs);
AppendOperand(rhs);
- set_precision_config(precision_config);
}
HloInstructionProto HloDotInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
*proto.mutable_dot_dimension_numbers() = dot_dimension_numbers_;
- *proto.mutable_precision_config() = precision_config();
+ *proto.mutable_precision_config() = precision_config_;
return proto;
}
std::vector<string> HloDotInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
- return {DotDimensionNumbersToString()};
+ std::vector<string> extra = {DotDimensionNumbersToString()};
+
+ string precision_config_string = PrecisionConfigToString(precision_config_);
+ if (!precision_config_string.empty()) {
+ extra.push_back(precision_config_string);
+ }
+ return extra;
}
bool HloDotInstruction::IdenticalSlowPath(
@@ -2191,7 +2226,9 @@ bool HloDotInstruction::IdenticalSlowPath(
eq_computations) const {
const auto& casted_other = static_cast<const HloDotInstruction&>(other);
return protobuf_util::ProtobufEquals(dot_dimension_numbers(),
- casted_other.dot_dimension_numbers());
+ casted_other.dot_dimension_numbers()) &&
+ protobuf_util::ProtobufEquals(precision_config(),
+ casted_other.precision_config());
}
std::unique_ptr<HloInstruction> HloDotInstruction::CloneWithNewOperandsImpl(
@@ -2200,7 +2237,7 @@ std::unique_ptr<HloInstruction> HloDotInstruction::CloneWithNewOperandsImpl(
CHECK_EQ(new_operands.size(), 2);
return absl::make_unique<HloDotInstruction>(
shape, new_operands[0], new_operands[1], dot_dimension_numbers_,
- precision_config());
+ precision_config_);
}
string HloDotInstruction::DotDimensionNumbersToString() const {
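One behavioral consequence of the IdenticalSlowPath changes above is that two dots or convolutions that differ only in operand precision no longer compare equal; the comparison goes through protobuf_util::ProtobufEquals, which is already used for the dimension numbers. A small sketch of that proto-level comparison (values chosen for illustration):

  PrecisionConfig a, b;
  a.add_operand_precision(PrecisionConfig::HIGH);
  b.add_operand_precision(PrecisionConfig::DEFAULT);

  // false: the repeated operand_precision fields differ, so instructions
  // carrying these configs are no longer considered identical.
  bool equal = protobuf_util::ProtobufEquals(a, b);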
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index af46148c70..c3a7801164 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -957,6 +957,16 @@ class HloConvolutionInstruction : public HloInstruction {
// The number of feature groups. Must be a divisor of the input feature
// dimension and output feature dimension.
int64 feature_group_count() const { return feature_group_count_; }
+
+ // Returns the information used to tell the implementation information about
+ // what sort of precision is requested. The meaning of the field is backend
+ // specific. At the moment, it is only supported for kConvolution and kDot.
+ // Transformations on one kDot or kConvolution to another will preserve this
+ // information. Transformations to other HLOs will not preserve this
+ // information but it is presumed that the alternate lowering is strictly
+ // superior.
+ const PrecisionConfig& precision_config() const { return precision_config_; }
+
string ToCategory() const override;
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
@@ -979,6 +989,9 @@ class HloConvolutionInstruction : public HloInstruction {
Window window_;
// Describes the dimension numbers used for a convolution.
ConvolutionDimensionNumbers convolution_dimension_numbers_;
+ // Information used to communicate to the implementation about the algorithm
+ // used to produce results. See the documentation on precision_config().
+ PrecisionConfig precision_config_;
};
class HloReduceWindowInstruction : public HloInstruction {
@@ -1285,6 +1298,15 @@ class HloDotInstruction : public HloInstruction {
return dot_dimension_numbers_;
}
+ // Returns the information used to tell the implementation information about
+ // what sort of precision is requested. The meaning of the field is backend
+ // specific. At the moment, it is only supported for kConvolution and kDot.
+ // Transformations on one kDot or kConvolution to another will preserve this
+ // information. Transformations to other HLOs will not preserve this
+ // information but it is presumed that the alternate lowering is strictly
+ // superior.
+ const PrecisionConfig& precision_config() const { return precision_config_; }
+
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
@@ -1304,6 +1326,10 @@ class HloDotInstruction : public HloInstruction {
// Describes the dimension numbers used for a dot.
DotDimensionNumbers dot_dimension_numbers_;
+
+ // Information used to communicate to the implementation about the algorithm
+ // used to produce results. See the documentation on precision_config().
+ PrecisionConfig precision_config_;
};
class HloDomainInstruction : public HloInstruction {
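Both files rely on the casting helpers from hlo_casting_utils.h: DynCast, used by the base-class accessor, returns nullptr on a type mismatch, while Cast, used by the updated test, CHECK-fails. A sketch of the two access paths, with instr standing in for some HloInstruction* already in scope:

  // Tolerant path: probe for the subclass and skip quietly otherwise.
  if (auto* dot = DynCast<HloDotInstruction>(instr)) {
    const PrecisionConfig& dot_config = dot->precision_config();
    (void)dot_config;
  }

  // Checked path: only valid when the opcode is already known to be
  // kConvolution; otherwise Cast will CHECK-fail.
  const PrecisionConfig& conv_config =
      Cast<HloConvolutionInstruction>(instr)->precision_config();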