author     David Majnemer <majnemer@google.com>    2018-09-04 11:17:30 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-09-04 11:41:30 -0700
commit  5d183ab7fc7b82f1dea0b9fa9c6412c39ade15a1 (patch)
tree    79a4f6fcf270617fc56082702b0209240425ae8c
parent  9ae8214229960c634c9f82c00f2c0df287c27a9d (diff)
[XLA] Make kConvolution, kDot HLO attributes mandatory
HLO transformations would forget to propagate the feature depth attribute. Making these attributes mandatory, while slightly less convenient for tests, makes HLO transformations more robust.

PiperOrigin-RevId: 211490160
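The practical effect is visible throughout the patch: callers must now construct both attributes explicitly instead of relying on a defaulted feature_group_count or a post-hoc set_precision_config() call. A minimal sketch of the new call pattern, adapted from the DefaultPrecisionConfig() helper this patch adds to several tests (the shapes, operands, and dimension numbers below are illustrative):

    // Helper added by this patch to several tests: give every operand
    // DEFAULT precision so the proto can be passed at construction time.
    PrecisionConfigProto DefaultPrecisionConfig(int operands) {
      PrecisionConfigProto precision_config;
      precision_config.mutable_operand_precision()->Resize(
          operands, PrecisionConfigProto::DEFAULT);
      return precision_config;
    }

    // feature_group_count (previously a defaulted trailing argument) and the
    // precision config are now mandatory constructor arguments.
    auto conv = HloInstruction::CreateConvolve(
        conv_shape, lhs, rhs, /*feature_group_count=*/1, window, dnums,
        DefaultPrecisionConfig(2));
    auto dot = HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums,
                                         DefaultPrecisionConfig(2));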
-rw-r--r--  tensorflow/compiler/xla/client/xla_builder.cc | 4
-rw-r--r--  tensorflow/compiler/xla/reference_util.cc | 14
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier.cc | 21
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier_test.cc | 50
-rw-r--r--  tensorflow/compiler/xla/service/batch_dot_simplification.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_normalization_test.cc | 5
-rw-r--r--  tensorflow/compiler/xla/service/buffer_assignment_test.cc | 11
-rw-r--r--  tensorflow/compiler/xla/service/convolution_feature_group_converter.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc | 6
-rw-r--r--  tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc | 13
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc | 6
-rw-r--r--  tensorflow/compiler/xla/service/dot_decomposer.cc | 6
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc | 112
-rw-r--r--  tensorflow/compiler/xla/service/graphviz_example.cc | 7
-rw-r--r--  tensorflow/compiler/xla/service/heap_simulator_test.cc | 31
-rw-r--r--  tensorflow/compiler/xla/service/hlo_computation_test.cc | 15
-rw-r--r--  tensorflow/compiler/xla/service/hlo_creation_utils.cc | 25
-rw-r--r--  tensorflow/compiler/xla/service/hlo_creation_utils.h | 11
-rw-r--r--  tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc | 5
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator.cc | 6
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator.h | 3
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator_test.cc | 37
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h | 7
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc | 57
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.h | 7
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction_test.cc | 35
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instructions.cc | 14
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instructions.h | 11
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser.cc | 41
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/indexed_array_analysis.cc | 27
-rw-r--r--  tensorflow/compiler/xla/service/indexed_array_analysis.h | 10
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference.h | 6
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference_test.cc | 16
-rw-r--r--  tensorflow/compiler/xla/service/transpose_folding.cc | 7
-rw-r--r--  tensorflow/compiler/xla/service/transpose_folding_test.cc | 31
-rw-r--r--  tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc | 5
-rw-r--r--  tensorflow/compiler/xla/tests/multioutput_fusion_test.cc | 12
39 files changed, 436 insertions, 254 deletions
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index e639028ccd..7f2125f74c 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -990,8 +990,8 @@ XlaOp XlaBuilder::ConvGeneralDilated(
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
ShapeInference::InferConvolveShape(
- lhs_shape, rhs_shape, instr.window(),
- dimension_numbers, feature_group_count));
+ lhs_shape, rhs_shape, feature_group_count,
+ instr.window(), dimension_numbers));
*instr.mutable_convolution_dimension_numbers() = dimension_numbers;
instr.set_feature_group_count(feature_group_count);
diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc
index a4854f593f..8a05d1b0d7 100644
--- a/tensorflow/compiler/xla/reference_util.cc
+++ b/tensorflow/compiler/xla/reference_util.cc
@@ -564,18 +564,22 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
dim2.set_base_dilation(lhs_dilation.second);
*window.add_dimensions() = dim2;
- const Shape& shape =
- ShapeInference::InferConvolveShape(lhs_literal->shape(),
- rhs_literal->shape(), window, dnums)
- .ConsumeValueOrDie();
+ const Shape& shape = ShapeInference::InferConvolveShape(
+ lhs_literal->shape(), rhs_literal->shape(),
+ /*feature_group_count=*/1, window, dnums)
+ .ConsumeValueOrDie();
HloInstruction* lhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
HloInstruction* rhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ /*new_size=*/2, PrecisionConfigProto::DEFAULT);
b.AddInstruction(HloInstruction::CreateConvolve(
- shape, lhs_instruction, rhs_instruction, window, dnums));
+ shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+ window, dnums, precision_config));
HloModuleConfig config;
HloModule module("ReferenceUtil", config);
auto computation = module.AddEntryComputation(b.Build());
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 7c078f07d7..3d18fe3be2 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -950,9 +950,9 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper(
new_dot_rhs = rhs_slice;
}
- auto* new_dot = computation_->AddInstruction(HloInstruction::CreateDot(
- dot.shape(), new_dot_lhs, new_dot_rhs, new_dot_dnums));
- new_dot->set_precision_config(dot.precision_config());
+ auto* new_dot = computation_->AddInstruction(
+ HloInstruction::CreateDot(dot.shape(), new_dot_lhs, new_dot_rhs,
+ new_dot_dnums, dot.precision_config()));
if (add_result) {
add_result = computation_->AddInstruction(HloInstruction::CreateBinary(
@@ -1053,9 +1053,9 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather(
const int n =
right_operand->shape().dimensions(1 - rhs_contracting_dimension);
auto memoized_shape = ShapeUtil::MakeShape(F32, {m, n});
- auto* memoized_inst = computation_->AddInstruction(HloInstruction::CreateDot(
- memoized_shape, left_operand, right_operand, dnums));
- memoized_inst->set_precision_config(dot->precision_config());
+ auto* memoized_inst = computation_->AddInstruction(
+ HloInstruction::CreateDot(memoized_shape, left_operand, right_operand,
+ dnums, dot->precision_config()));
// Get pair {start, 0} or {0, start}.
HloInstruction* original_start_indices =
lhs_is_dynamic_slice ? lhs->mutable_operand(1) : rhs->mutable_operand(1);
@@ -1151,9 +1151,8 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
dot_dimension_numbers.add_rhs_contracting_dimensions(0);
auto new_dot = computation_->AddInstruction(HloInstruction::CreateDot(
ShapeUtil::PermuteDimensions({1, 0}, dot->shape()),
- rhs->mutable_operand(0), lhs->mutable_operand(0),
- dot_dimension_numbers));
- new_dot->set_precision_config(dot->precision_config());
+ rhs->mutable_operand(0), lhs->mutable_operand(0), dot_dimension_numbers,
+ dot->precision_config()));
return ReplaceWithNewInstruction(
dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0}));
}
@@ -2477,8 +2476,8 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
dot_dimension_numbers.add_lhs_contracting_dimensions(1);
dot_dimension_numbers.add_rhs_contracting_dimensions(0);
auto dot = computation_->AddInstruction(HloInstruction::CreateDot(
- dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers));
- dot->set_precision_config(convolution->precision_config());
+ dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers,
+ convolution->precision_config()));
return ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot));
}
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 43a891e4fa..019840b476 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -1013,6 +1013,13 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) {
1);
}
+PrecisionConfigProto DefaultPrecisionConfig(int operands) {
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ operands, PrecisionConfigProto::DEFAULT);
+ return precision_config;
+}
+
TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) {
auto builder = HloComputation::Builder(TestName());
HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateParameter(
@@ -1044,7 +1051,8 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) {
dim->set_window_reversal(false);
// Create add computation.
builder.AddInstruction(HloInstruction::CreateConvolve(
- ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, window, dnums));
+ ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, /*feature_group_count=*/1,
+ window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(builder.Build());
HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
@@ -2260,9 +2268,11 @@ TEST_P(ConvInputPaddingTest, DoTest) {
.ValueOrDie();
builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeInference::InferConvolveShape(lhs_pad->shape(), filter->shape(),
- window, dnums)
+ /*feature_group_count=*/1, window,
+ dnums)
.ValueOrDie(),
- lhs_pad, filter, window, dnums));
+ lhs_pad, filter, /*feature_group_count=*/1, window, dnums,
+ DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
@@ -2368,9 +2378,11 @@ TEST_P(ConvFilterPaddingTest, DoIt) {
.ValueOrDie();
auto* orig_conv = builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeInference::InferConvolveShape(input->shape(), rhs_pad->shape(),
- window, dnums)
+ /*feature_group_count=*/1, window,
+ dnums)
.ValueOrDie(),
- input, rhs_pad, window, dnums));
+ input, rhs_pad, /*feature_group_count=*/1, window, dnums,
+ DefaultPrecisionConfig(2)));
// Add a PrecisionConfig and check that AlgebraicSimplifier keeps it in place
// after the transformation.
@@ -2522,8 +2534,9 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
HloInstruction* filter =
b.AddInstruction(HloInstruction::CreateParameter(1, f_shape, "filter"));
- b.AddInstruction(HloInstruction::CreateConvolve(out_shape, input, filter,
- window, dnums));
+ b.AddInstruction(HloInstruction::CreateConvolve(
+ out_shape, input, filter,
+ /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
// TODO(b/80488902): verify this module.
auto module = HloTestBase::CreateNewModule();
@@ -2901,7 +2914,8 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums));
+ builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums,
+ DefaultPrecisionConfig(2)));
std::unique_ptr<HloComputation> dot_computation(builder.Build());
HloComputation::Builder call_builder(TestName() + ".Call");
@@ -3253,8 +3267,8 @@ TEST_P(DotStrengthReductionTest, DotStrengthReduction) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- builder.AddInstruction(
- HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums));
+ builder.AddInstruction(HloInstruction::CreateDot(
+ dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
auto computation = module().AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
@@ -3329,8 +3343,8 @@ TEST_P(DotOfConcatSimplificationTest, ConstantLHS) {
dot_dnums.add_rhs_contracting_dimensions(0);
Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n});
- builder.AddInstruction(
- HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums));
+ builder.AddInstruction(HloInstruction::CreateDot(
+ dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
auto computation = module().AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
@@ -3393,8 +3407,8 @@ TEST_P(DotOfConcatSimplificationTest, ConstantRHS) {
dot_dnums.add_rhs_contracting_dimensions(0);
Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n});
- builder.AddInstruction(
- HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums));
+ builder.AddInstruction(HloInstruction::CreateDot(
+ dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
auto computation = module().AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
@@ -3511,8 +3525,8 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) {
int64 dot_row_size = 1;
int64 dot_col_size = spec.n;
Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size});
- builder.AddInstruction(
- HloInstruction::CreateDot(dot_shape, ds, rhs, dot_dnums));
+ builder.AddInstruction(HloInstruction::CreateDot(
+ dot_shape, ds, rhs, dot_dnums, DefaultPrecisionConfig(2)));
auto computation = module().AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
@@ -3581,8 +3595,8 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) {
int64 dot_row_size = spec.m;
int64 dot_col_size = 1;
Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size});
- builder.AddInstruction(
- HloInstruction::CreateDot(dot_shape, lhs, ds, dot_dnums));
+ builder.AddInstruction(HloInstruction::CreateDot(
+ dot_shape, lhs, ds, dot_dnums, DefaultPrecisionConfig(2)));
auto computation = module().AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.cc b/tensorflow/compiler/xla/service/batch_dot_simplification.cc
index a16b85a0a5..eda026ac56 100644
--- a/tensorflow/compiler/xla/service/batch_dot_simplification.cc
+++ b/tensorflow/compiler/xla/service/batch_dot_simplification.cc
@@ -63,8 +63,8 @@ BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot(
new_dim_numbers.rhs_contracting_dimensions(0) - degenerate_dims.size());
TF_ASSIGN_OR_RETURN(HloInstruction * new_dot,
- MakeDotHlo(new_lhs, new_rhs, new_dim_numbers));
- new_dot->set_precision_config(batch_dot->precision_config());
+ MakeDotHlo(new_lhs, new_rhs, new_dim_numbers,
+ batch_dot->precision_config()));
TF_ASSIGN_OR_RETURN(HloInstruction * new_dot_reshaped,
MakeReshapeHlo(batch_dot->shape(), new_dot));
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
index b08705d4c2..d480d72297 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
@@ -308,8 +308,11 @@ TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ 2, PrecisionConfigProto::DEFAULT);
HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateDot(bf16_shape, a, b, dot_dnums));
+ HloInstruction::CreateDot(bf16_shape, a, b, dot_dnums, precision_config));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index 8bd1533972..7398f105a0 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -1490,10 +1490,13 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto dot_ab = builder.AddInstruction(
- HloInstruction::CreateDot(shape_2x4, param_a, param_b, dot_dnums));
- auto dot_bc = builder.AddInstruction(
- HloInstruction::CreateDot(shape_3x4, param_b, param_c, dot_dnums));
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ 2, PrecisionConfigProto::DEFAULT);
+ auto dot_ab = builder.AddInstruction(HloInstruction::CreateDot(
+ shape_2x4, param_a, param_b, dot_dnums, precision_config));
+ auto dot_bc = builder.AddInstruction(HloInstruction::CreateDot(
+ shape_3x4, param_b, param_c, dot_dnums, precision_config));
builder.AddInstruction(
HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 0));
diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc
index 9c81a86bbb..0826380f65 100644
--- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc
+++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc
@@ -223,8 +223,8 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
filter_mask, expanded_filter, zero_filter));
auto new_convolution = HloInstruction::CreateConvolve(
convolution->shape(), convolution->mutable_operand(0), new_filter,
- convolution->window(), dim_numbers, /*feature_group_count=*/1);
- new_convolution->set_precision_config(convolution->precision_config());
+ /*feature_group_count=*/1, convolution->window(), dim_numbers,
+ convolution->precision_config());
TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
convolution, std::move(new_convolution)));
return Status::OK();
diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc
index 098ce17a56..2d9978404c 100644
--- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc
+++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc
@@ -130,9 +130,9 @@ StatusOr<bool> ConvCanonicalization::Run(HloModule* module) {
// change the dimension mapping but not the dimension sizes. For
// example, input height and width are the same as before the reshapes.
HloInstruction* new_conv = module->entry_computation()->AddInstruction(
- HloInstruction::CreateConvolve(new_conv_shape, new_input, new_kernel,
- hlo->window(), new_dnums));
- new_conv->set_precision_config(hlo->precision_config());
+ HloInstruction::CreateConvolve(
+ new_conv_shape, new_input, new_kernel, hlo->feature_group_count(),
+ hlo->window(), new_dnums, hlo->precision_config()));
// Reshape the output back to the shape of the original convolution.
TF_RETURN_IF_ERROR(module->entry_computation()->ReplaceWithNewInstruction(
diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
index 547d4c696d..616c453750 100644
--- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
@@ -56,6 +56,13 @@ class ConvCanonicalizationTest : public HloTestBase {
static constexpr int kOutputFeatureCount = 64;
};
+PrecisionConfigProto DefaultPrecisionConfig(int operands) {
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ operands, PrecisionConfigProto::DEFAULT);
+ return precision_config;
+}
+
TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) {
auto builder = HloComputation::Builder(TestName());
// The input dimensions are in CNHW order.
@@ -84,7 +91,8 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) {
builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(
F32, {kOutputFeatureCount, kBatchSize, output_size, output_size}),
- input, kernel, conv_window_, dnums));
+ input, kernel, /*feature_group_count=*/1, conv_window_, dnums,
+ DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
HloComputation* entry_computation =
@@ -146,7 +154,8 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) {
builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(
F32, {kBatchSize, output_size, output_size, kOutputFeatureCount}),
- input, kernel, conv_window_, dnums));
+ input, kernel, /*feature_group_count=*/1, conv_window_, dnums,
+ DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
index 284929ca07..6bd0a2dd90 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
@@ -38,7 +38,11 @@ std::unique_ptr<HloInstruction> MakeDot(const Shape& shape, HloInstruction* lhs,
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums);
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ 2, PrecisionConfigProto::DEFAULT);
+ return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums,
+ precision_config);
}
TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) {
diff --git a/tensorflow/compiler/xla/service/dot_decomposer.cc b/tensorflow/compiler/xla/service/dot_decomposer.cc
index 09cb10d6ee..b2ba261790 100644
--- a/tensorflow/compiler/xla/service/dot_decomposer.cc
+++ b/tensorflow/compiler/xla/service/dot_decomposer.cc
@@ -134,9 +134,9 @@ Status DecomposeBatchDot(HloInstruction* dot) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto dot_r2 = computation->AddInstruction(HloInstruction::CreateDot(
- dot_shape_r2, lhs_slice_r2, rhs_slice_r2, dot_dnums));
- dot_r2->set_precision_config(dot->precision_config());
+ auto dot_r2 = computation->AddInstruction(
+ HloInstruction::CreateDot(dot_shape_r2, lhs_slice_r2, rhs_slice_r2,
+ dot_dnums, dot->precision_config()));
// Reshape Dot to R3 so we can concat along batch dimension.
auto dot_r3 = computation->AddInstruction(
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc
index 46c23db465..9b46bfc098 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc
@@ -95,6 +95,13 @@ class CudnnConvolutionRewriterTest : public HloVerifiedTestBase {
ConvolutionDimensionNumbers tf_default_dnums_for_backward_input_;
};
+PrecisionConfigProto DefaultPrecisionConfig(int operands) {
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ operands, PrecisionConfigProto::DEFAULT);
+ return precision_config;
+}
+
TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolve) {
HloComputation::Builder builder(TestName());
HloInstruction* activations =
@@ -107,12 +114,12 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolve) {
conv_window.mutable_dimensions(1)->set_size(2);
conv_window.mutable_dimensions(1)->set_window_dilation(2);
builder.AddInstruction(HloInstruction::CreateConvolve(
- ShapeInference::InferConvolveShape(activations->shape(),
- gradients->shape(), conv_window,
- tf_default_dnums_for_backward_filter_)
+ ShapeInference::InferConvolveShape(
+ activations->shape(), gradients->shape(), /*feature_group_count=*/1,
+ conv_window, tf_default_dnums_for_backward_filter_)
.ConsumeValueOrDie(),
- activations, gradients, conv_window,
- tf_default_dnums_for_backward_filter_));
+ activations, gradients, /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
HloComputation* entry_computation =
@@ -135,12 +142,12 @@ TEST_F(CudnnConvolutionRewriterTest,
Window conv_window = default_conv_window_;
conv_window.mutable_dimensions(1)->set_size(3);
builder.AddInstruction(HloInstruction::CreateConvolve(
- ShapeInference::InferConvolveShape(activations->shape(),
- gradients->shape(), conv_window,
- tf_default_dnums_for_backward_filter_)
+ ShapeInference::InferConvolveShape(
+ activations->shape(), gradients->shape(), /*feature_group_count=*/1,
+ conv_window, tf_default_dnums_for_backward_filter_)
.ConsumeValueOrDie(),
- activations, gradients, conv_window,
- tf_default_dnums_for_backward_filter_));
+ activations, gradients, /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
HloComputation* entry_computation =
@@ -170,7 +177,8 @@ TEST_F(CudnnConvolutionRewriterTest,
}
builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients,
- conv_window, tf_default_dnums_for_backward_filter_));
+ /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
HloComputation* entry_computation =
@@ -200,7 +208,8 @@ TEST_F(CudnnConvolutionRewriterTest,
}
builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients,
- conv_window, tf_default_dnums_for_backward_filter_));
+ /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
HloComputation* entry_computation =
@@ -228,7 +237,8 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolveWithUnevenPadding) {
}
builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients,
- conv_window, tf_default_dnums_for_backward_filter_));
+ /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
HloComputation* entry_computation =
@@ -272,13 +282,14 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveEvenPadding) {
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(F32, {4, 3, 16, 16}), /*lhs=*/output,
- /*rhs=*/reverse_kernel, conv_window, conv_dnums));
+ /*rhs=*/reverse_kernel, /*feature_group_count=*/1, conv_window,
+ conv_dnums, DefaultPrecisionConfig(2)));
// Verify the convolution's shape is consistent with ShapeInference.
CHECK(ShapeUtil::Compatible(
- conv->shape(),
- ShapeInference::InferConvolveShape(
- output->shape(), reverse_kernel->shape(), conv_window, conv_dnums)
- .ValueOrDie()));
+ conv->shape(), ShapeInference::InferConvolveShape(
+ output->shape(), reverse_kernel->shape(),
+ /*feature_group_count=*/1, conv_window, conv_dnums)
+ .ValueOrDie()));
auto module = CreateNewModule();
HloComputation* entry_computation =
@@ -319,11 +330,11 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolve1x1Filter) {
builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeInference::InferConvolveShape(output->shape(), kernel->shape(),
- conv_window,
+ /*feature_group_count=*/1, conv_window,
tf_default_dnums_for_backward_input_)
.ConsumeValueOrDie(),
- /*lhs=*/output, /*rhs=*/kernel, conv_window,
- tf_default_dnums_for_backward_input_));
+ /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
HloComputation* entry_computation =
@@ -350,12 +361,13 @@ TEST_F(CudnnConvolutionRewriterTest,
1, ShapeUtil::MakeShape(F32, {1, 1, 1, 1}), "kernel"));
builder.AddInstruction(HloInstruction::CreateConvolve(
- ShapeInference::InferConvolveShape(output->shape(), kernel->shape(),
- default_conv_window_,
- tf_default_dnums_for_backward_input_)
+ ShapeInference::InferConvolveShape(
+ output->shape(), kernel->shape(), /*feature_group_count=*/1,
+ default_conv_window_, tf_default_dnums_for_backward_input_)
.ConsumeValueOrDie(),
- /*lhs=*/output, /*rhs=*/kernel, default_conv_window_,
- tf_default_dnums_for_backward_input_));
+ /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1,
+ default_conv_window_, tf_default_dnums_for_backward_input_,
+ DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
HloComputation* entry_computation =
@@ -402,13 +414,15 @@ TEST_F(CudnnConvolutionRewriterTest,
}
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel,
- conv_window, tf_default_dnums_for_backward_input_));
+ /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
// Verify the convolution's shape is consistent with ShapeInference.
CHECK(ShapeUtil::Compatible(
- conv->shape(), ShapeInference::InferConvolveShape(
- output->shape(), reverse_kernel->shape(), conv_window,
- tf_default_dnums_for_backward_input_)
- .ValueOrDie()));
+ conv->shape(),
+ ShapeInference::InferConvolveShape(
+ output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1,
+ conv_window, tf_default_dnums_for_backward_input_)
+ .ValueOrDie()));
auto module = CreateNewModule();
HloComputation* entry_computation =
@@ -449,13 +463,15 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveLowPaddingTooLarge) {
}
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel,
- conv_window, tf_default_dnums_for_backward_input_));
+ /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
// Verify the convolution's shape is consistent with ShapeInference.
CHECK(ShapeUtil::Compatible(
- conv->shape(), ShapeInference::InferConvolveShape(
- output->shape(), reverse_kernel->shape(), conv_window,
- tf_default_dnums_for_backward_input_)
- .ValueOrDie()));
+ conv->shape(),
+ ShapeInference::InferConvolveShape(
+ output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1,
+ conv_window, tf_default_dnums_for_backward_input_)
+ .ValueOrDie()));
auto module = CreateNewModule();
HloComputation* entry_computation =
@@ -502,13 +518,15 @@ TEST_F(CudnnConvolutionRewriterTest,
forward_conv_col_dim->set_base_dilation(2);
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(F32, {1, 1, 14, 1}), output, reverse_kernel,
- conv_window, tf_default_dnums_for_backward_input_));
+ /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
// Verify the convolution's shape is consistent with ShapeInference.
CHECK(ShapeUtil::Compatible(
- conv->shape(), ShapeInference::InferConvolveShape(
- output->shape(), reverse_kernel->shape(), conv_window,
- tf_default_dnums_for_backward_input_)
- .ValueOrDie()));
+ conv->shape(),
+ ShapeInference::InferConvolveShape(
+ output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1,
+ conv_window, tf_default_dnums_for_backward_input_)
+ .ValueOrDie()));
auto module = CreateNewModule();
const HloComputation* entry_computation =
@@ -554,13 +572,15 @@ TEST_F(CudnnConvolutionRewriterTest,
forward_conv_col_dim->set_padding_high(2);
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(F32, {1, 1, 4, 1}), output, reverse_kernel,
- conv_window, tf_default_dnums_for_backward_input_));
+ /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
// Verify the convolution's shape is consistent with ShapeInference.
CHECK(ShapeUtil::Compatible(
- conv->shape(), ShapeInference::InferConvolveShape(
- output->shape(), reverse_kernel->shape(), conv_window,
- tf_default_dnums_for_backward_input_)
- .ValueOrDie()));
+ conv->shape(),
+ ShapeInference::InferConvolveShape(
+ output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1,
+ conv_window, tf_default_dnums_for_backward_input_)
+ .ValueOrDie()));
auto module = CreateNewModule();
HloComputation* entry_computation =
diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc
index a2be89511b..0a49d85c6d 100644
--- a/tensorflow/compiler/xla/service/graphviz_example.cc
+++ b/tensorflow/compiler/xla/service/graphviz_example.cc
@@ -112,8 +112,11 @@ std::unique_ptr<HloModule> MakeBigGraph() {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto dot = builder.AddInstruction(
- HloInstruction::CreateDot(vshape, clamp, param_v0, dot_dnums));
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ /*new_size=*/2, PrecisionConfigProto::DEFAULT);
+ auto dot = builder.AddInstruction(HloInstruction::CreateDot(
+ vshape, clamp, param_v0, dot_dnums, precision_config));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({dot, param_s, clamp}));
auto scalar = builder.AddInstruction(
diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc
index 5f85f14565..576c5ff7a4 100644
--- a/tensorflow/compiler/xla/service/heap_simulator_test.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc
@@ -353,6 +353,13 @@ TEST_F(HeapSimulatorTest, BufferReusedOnce) {
(neg_buffer == output_buffer_1));
}
+PrecisionConfigProto DefaultPrecisionConfig(int operands) {
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ operands, PrecisionConfigProto::DEFAULT);
+ return precision_config;
+}
+
TEST_F(HeapSimulatorTest, MultiplyDot) {
auto builder = HloComputation::Builder(TestName());
auto paramA = builder.AddInstruction(
@@ -366,8 +373,8 @@ TEST_F(HeapSimulatorTest, MultiplyDot) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto dot = builder.AddInstruction(
- HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums));
+ auto dot = builder.AddInstruction(HloInstruction::CreateDot(
+ f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
// The buffer for dot is the output, and it cannot be shared with the buffer
// for mul, since dot isn't elementwise.
@@ -402,8 +409,8 @@ TEST_F(HeapSimulatorTest, MultiplyDotAdd) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto dot = builder.AddInstruction(
- HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums));
+ auto dot = builder.AddInstruction(HloInstruction::CreateDot(
+ f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, dot, paramA));
@@ -440,10 +447,10 @@ TEST_F(HeapSimulatorTest, MultiplyDotDot) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto dot0 = builder.AddInstruction(
- HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums));
- auto dot1 = builder.AddInstruction(
- HloInstruction::CreateDot(f32vec4_, dot0, paramY, dot_dnums));
+ auto dot0 = builder.AddInstruction(HloInstruction::CreateDot(
+ f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
+ auto dot1 = builder.AddInstruction(HloInstruction::CreateDot(
+ f32vec4_, dot0, paramY, dot_dnums, DefaultPrecisionConfig(2)));
// The buffer for dot1 is the output. No buffers can be shared. The buffer
// for mul is freed before the end, since it's no longer used after dot0
@@ -481,10 +488,10 @@ TEST_F(HeapSimulatorTest, MultiplyDotDotTuple) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto dot0 = builder.AddInstruction(
- HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums));
- auto dot1 = builder.AddInstruction(
- HloInstruction::CreateDot(f32vec4_, dot0, paramY, dot_dnums));
+ auto dot0 = builder.AddInstruction(HloInstruction::CreateDot(
+ f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
+ auto dot1 = builder.AddInstruction(HloInstruction::CreateDot(
+ f32vec4_, dot0, paramY, dot_dnums, DefaultPrecisionConfig(2)));
auto tuple =
builder.AddInstruction(HloInstruction::CreateTuple({dot0, dot1}));
diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc
index f7ed1b0316..a2c1ce34c6 100644
--- a/tensorflow/compiler/xla/service/hlo_computation_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc
@@ -601,8 +601,11 @@ TEST_F(HloComputationTest, Stringification) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ 2, PrecisionConfigProto::DEFAULT);
builder.AddInstruction(
- HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+ HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config));
auto module = CreateNewModule();
auto* computation = module->AddEntryComputation(builder.Build());
@@ -633,8 +636,11 @@ TEST_F(HloComputationTest, StringificationIndent) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ 2, PrecisionConfigProto::DEFAULT);
builder.AddInstruction(
- HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+ HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config));
auto module = CreateNewModule();
auto* computation = module->AddEntryComputation(builder.Build());
@@ -666,8 +672,11 @@ TEST_F(HloComputationTest, StringificationCanonical) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ 2, PrecisionConfigProto::DEFAULT);
builder.AddInstruction(
- HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+ HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config));
auto module = CreateNewModule();
auto* computation = module->AddEntryComputation(builder.Build());
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
index 19ffb465c0..a6ae0337a5 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
@@ -61,15 +61,18 @@ StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
}
StatusOr<HloInstruction*> MakeConvolveHlo(
- HloInstruction* lhs, HloInstruction* rhs, const Window& window,
- const ConvolutionDimensionNumbers& dimension_numbers) {
+ HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count,
+ const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
+ const PrecisionConfigProto& precision_config) {
HloComputation* computation = lhs->parent();
CHECK_EQ(computation, rhs->parent());
- TF_ASSIGN_OR_RETURN(Shape convolve_shape, ShapeInference::InferConvolveShape(
- lhs->shape(), rhs->shape(),
- window, dimension_numbers));
+ TF_ASSIGN_OR_RETURN(Shape convolve_shape,
+ ShapeInference::InferConvolveShape(
+ lhs->shape(), rhs->shape(), feature_group_count,
+ window, dimension_numbers));
return computation->AddInstruction(HloInstruction::CreateConvolve(
- convolve_shape, lhs, rhs, window, dimension_numbers));
+ convolve_shape, lhs, rhs, feature_group_count, window, dimension_numbers,
+ precision_config));
}
StatusOr<HloInstruction*> MakeTransposeHlo(HloInstruction* operand,
@@ -164,15 +167,17 @@ StatusOr<HloInstruction*> MakeConcatHlo(
HloInstruction::CreateConcatenate(concat_shape, operands, dimension));
}
-StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
- const DotDimensionNumbers& dim_numbers) {
+StatusOr<HloInstruction*> MakeDotHlo(
+ HloInstruction* lhs, HloInstruction* rhs,
+ const DotDimensionNumbers& dim_numbers,
+ const PrecisionConfigProto& precision_config) {
HloComputation* computation = lhs->parent();
CHECK_EQ(computation, rhs->parent());
TF_ASSIGN_OR_RETURN(
Shape dot_shape,
ShapeInference::InferDotOpShape(lhs->shape(), rhs->shape(), dim_numbers));
- return computation->AddInstruction(
- HloInstruction::CreateDot(dot_shape, lhs, rhs, dim_numbers));
+ return computation->AddInstruction(HloInstruction::CreateDot(
+ dot_shape, lhs, rhs, dim_numbers, precision_config));
}
StatusOr<HloInstruction*> MakeMapHlo(absl::Span<HloInstruction* const> operands,
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h
index a1c4b374d1..1c82956907 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.h
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h
@@ -48,8 +48,9 @@ StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
// Creates a convolution HLO instruction and adds it to the computation
// containing `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
StatusOr<HloInstruction*> MakeConvolveHlo(
- HloInstruction* lhs, HloInstruction* rhs, const Window& window,
- const ConvolutionDimensionNumbers& dimension_numbers);
+ HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count,
+ const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
+ const PrecisionConfigProto& precision_config);
// Creates a transpose HLO instruction and adds it to the computation containing
// `operand`.
@@ -97,8 +98,10 @@ StatusOr<HloInstruction*> MakeConcatHlo(
// Creates a Dot HLO instruction and adds it to the computation containing `lhs`
// and `rhs` (both must be in the same computation).
-StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
- const DotDimensionNumbers& dim_numbers);
+StatusOr<HloInstruction*> MakeDotHlo(
+ HloInstruction* lhs, HloInstruction* rhs,
+ const DotDimensionNumbers& dim_numbers,
+ const PrecisionConfigProto& precision_config);
// Creates a Map HLO instruction and adds it to the computation containing the
// operands. All operands must be in the same computation.
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index d1a96c10f8..62eea2b06c 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -2334,8 +2334,11 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ 2, PrecisionConfigProto::DEFAULT);
auto dot = builder.AddInstruction(
- HloInstruction::CreateDot(data_shape, a, b, dot_dnums));
+ HloInstruction::CreateDot(data_shape, a, b, dot_dnums, precision_config));
auto one = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 441dcad000..ffb3451164 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -53,7 +53,6 @@ namespace xla {
namespace {
-
template <typename OperandT>
StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
LiteralSlice lhs_literal,
@@ -345,7 +344,8 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseUnaryOp(
}
StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateDotOp(
- const DotDimensionNumbers& dim_numbers, const Literal& lhs,
+ const DotDimensionNumbers& dim_numbers,
+ const PrecisionConfigProto& precision_config, const Literal& lhs,
const Literal& rhs) {
std::unique_ptr<HloInstruction> lhs_instr =
HloInstruction::CreateConstant(lhs.CloneToUnique());
@@ -358,7 +358,7 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateDotOp(
std::unique_ptr<HloInstruction> cloned_instruction =
HloInstruction::CreateDot(dot_shape, lhs_instr.get(), rhs_instr.get(),
- dim_numbers);
+ dim_numbers, precision_config);
return Evaluate(cloned_instruction.get());
}
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index c2d49e56ac..e13af8e999 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -115,7 +115,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
HloOpcode opcode, const Literal& operand);
StatusOr<std::unique_ptr<Literal>> EvaluateDotOp(
- const DotDimensionNumbers& dim_numbers, const Literal& lhs,
+ const DotDimensionNumbers& dim_numbers,
+ const PrecisionConfigProto& precision_config, const Literal& lhs,
const Literal& rhs);
protected:
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index 7e490d7f32..3ab8ef18dd 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -622,6 +622,13 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) {
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
+PrecisionConfigProto DefaultPrecisionConfig(int operands) {
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ operands, PrecisionConfigProto::DEFAULT);
+ return precision_config;
+}
+
TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
HloComputation::Builder b(TestName());
@@ -649,7 +656,8 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction,
- rhs_instruction, dot_dnums));
+ rhs_instruction, dot_dnums,
+ DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result = Evaluate();
@@ -694,7 +702,8 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) {
dot_dnums.add_lhs_contracting_dimensions(0);
dot_dnums.add_rhs_contracting_dimensions(0);
b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction,
- rhs_instruction, dot_dnums));
+ rhs_instruction, dot_dnums,
+ DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result = Evaluate();
@@ -737,7 +746,8 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction,
- rhs_instruction, dot_dnums));
+ rhs_instruction, dot_dnums,
+ DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result = Evaluate();
@@ -790,7 +800,8 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) {
const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 3});
b.AddInstruction(HloInstruction::CreateConvolve(
- shape, lhs_instruction, rhs_instruction, window, dnums));
+ shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+ window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result = Evaluate();
@@ -844,7 +855,8 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
b.AddInstruction(HloInstruction::CreateConvolve(
- shape, lhs_instruction, rhs_instruction, window, dnums));
+ shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+ window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result = Evaluate();
@@ -927,7 +939,8 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) {
const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
b.AddInstruction(HloInstruction::CreateConvolve(
- shape, lhs_instruction, rhs_instruction, window, dnums));
+ shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+ window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result = Evaluate();
@@ -1004,7 +1017,8 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) {
const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
b.AddInstruction(HloInstruction::CreateConvolve(
- shape, lhs_instruction, rhs_instruction, window, dnums));
+ shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+ window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result = Evaluate();
@@ -1063,7 +1077,8 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7});
b.AddInstruction(HloInstruction::CreateConvolve(
- shape, lhs_instruction, rhs_instruction, window, dnums));
+ shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+ window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result = Evaluate();
@@ -1126,7 +1141,8 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8});
b.AddInstruction(HloInstruction::CreateConvolve(
- shape, lhs_instruction, rhs_instruction, window, dnums));
+ shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+ window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result = Evaluate();
@@ -1197,7 +1213,8 @@ TEST_P(HloEvaluatorTest,
const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3});
b.AddInstruction(HloInstruction::CreateConvolve(
- shape, lhs_instruction, rhs_instruction, window, dnums));
+ shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+ window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result = Evaluate();
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index cb27e13e99..dc16a84246 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -1021,9 +1021,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
CHECK_EQ(num_spatial_dims + 2, lhs_rank);
CHECK_EQ(num_spatial_dims + 2, rhs_rank);
- TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
- ShapeInference::InferConvolveShape(lhs_shape, rhs_shape,
- window, dnums));
+ TF_ASSIGN_OR_RETURN(
+ auto inferred_return_shape,
+ ShapeInference::InferConvolveShape(
+ lhs_shape, rhs_shape, conv->feature_group_count(), window, dnums));
CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
<< "return shape set to: " << ShapeUtil::HumanString(result_shape)
<< " but is inferred to be: "
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 6d13f85cbb..f25761ac70 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -341,17 +341,21 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
source_target_pairs);
break;
}
- case HloOpcode::kConvolution:
+ case HloOpcode::kConvolution: {
TF_RET_CHECK(proto.operand_ids_size() == 2)
<< "Convolution instruction should have 2 operands but sees "
<< proto.operand_ids_size();
TF_RET_CHECK(proto.has_window());
TF_RET_CHECK(proto.has_convolution_dimension_numbers());
+ PrecisionConfigProto precision_config = proto.precision_config();
+ precision_config.mutable_operand_precision()->Resize(
+ proto.operand_ids_size(), PrecisionConfigProto::DEFAULT);
instruction = CreateConvolve(
- proto.shape(), operands(0), operands(1), proto.window(),
- proto.convolution_dimension_numbers(),
- std::max(static_cast<int64>(proto.feature_group_count()), 1LL));
+ proto.shape(), operands(0), operands(1),
+ std::max<int64>(proto.feature_group_count(), 1), proto.window(),
+ proto.convolution_dimension_numbers(), precision_config);
break;
+ }
case HloOpcode::kReduceWindow:
TF_RET_CHECK(proto.operand_ids_size() == 2)
<< "ReduceWindow instruction should have 2 operands but sees "
@@ -468,6 +472,20 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
computation_map.at(computation_id));
}
}
+ if (instruction->opcode() == HloOpcode::kDot) {
+ instruction->precision_config_ = proto.precision_config();
+ instruction->precision_config_.mutable_operand_precision()->Resize(
+ instruction->operand_count(), PrecisionConfigProto::DEFAULT);
+ TF_RET_CHECK(proto.has_dot_dimension_numbers());
+ instruction->dot_dimension_numbers_ =
+ absl::make_unique<DotDimensionNumbers>(
+ proto.dot_dimension_numbers());
+ } else {
+ TF_RET_CHECK(!proto.has_precision_config())
+ << instruction->opcode() << proto.DebugString();
+ TF_RET_CHECK(!proto.has_dot_dimension_numbers())
+ << instruction->opcode();
+ }
break;
}
}
@@ -476,12 +494,6 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
instruction->SetAndSanitizeName(proto.name());
instruction->metadata_ = proto.metadata();
instruction->backend_config_ = proto.backend_config();
- instruction->precision_config_ = proto.precision_config();
-
- if (proto.has_dot_dimension_numbers()) {
- instruction->dot_dimension_numbers_ =
- absl::make_unique<DotDimensionNumbers>(proto.dot_dimension_numbers());
- }
if (proto.has_sharding()) {
TF_ASSIGN_OR_RETURN(const auto& sharding,
@@ -643,10 +655,12 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvolve(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
- const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count) {
+ int64 feature_group_count, const Window& window,
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ const PrecisionConfigProto& precision_config) {
return absl::make_unique<HloConvolutionInstruction>(
- shape, lhs, rhs, window, dimension_numbers, feature_group_count);
+ shape, lhs, rhs, feature_group_count, window, dimension_numbers,
+ precision_config);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFft(
@@ -658,13 +672,15 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDot(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
- const DotDimensionNumbers& dimension_numbers) {
+ const DotDimensionNumbers& dimension_numbers,
+ const PrecisionConfigProto& precision_config) {
auto instruction =
absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape));
instruction->AppendOperand(lhs);
instruction->AppendOperand(rhs);
instruction->dot_dimension_numbers_ =
absl::make_unique<DotDimensionNumbers>(dimension_numbers);
+ instruction->set_precision_config(precision_config);
return instruction;
}
@@ -1057,7 +1073,6 @@ void HloInstruction::SetupDerivedInstruction(
derived_instruction->clear_sharding();
}
derived_instruction->set_metadata(metadata_);
- derived_instruction->set_precision_config(precision_config_);
}
bool HloInstruction::HasSideEffectNoRecurse() const {
@@ -1278,7 +1293,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kDot:
CHECK_EQ(new_operands.size(), 2);
clone = CreateDot(shape, new_operands[0], new_operands[1],
- *dot_dimension_numbers_);
+ *dot_dimension_numbers_, precision_config());
break;
case HloOpcode::kReshape:
CHECK_EQ(new_operands.size(), 1);
@@ -2167,7 +2182,9 @@ HloInstructionProto HloInstruction::ToProto() const {
*proto.mutable_metadata() = metadata_;
proto.set_backend_config(backend_config_);
- *proto.mutable_precision_config() = precision_config_;
+ if (opcode() == HloOpcode::kConvolution || opcode() == HloOpcode::kDot) {
+ *proto.mutable_precision_config() = precision_config_;
+ }
if (opcode() != HloOpcode::kFusion) {
for (const HloComputation* computation : called_computations_) {
proto.add_called_computation_ids(computation->unique_id());
@@ -2948,7 +2965,11 @@ StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
}
string HloInstruction::PrecisionConfigToString() const {
- if (precision_config_.operand_precision().empty()) {
+ if (absl::c_all_of(
+ precision_config_.operand_precision(), [](int32 precision) {
+ return static_cast<PrecisionConfigProto::Precision>(precision) ==
+ PrecisionConfigProto::DEFAULT;
+ })) {
return "";
}
return StrCat(
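Note on the two hunks above: a precision config whose entries are all DEFAULT is now treated as absent — ToProto() only serializes the field for kConvolution/kDot, and PrecisionConfigToString() prints nothing. A minimal self-contained mirror of that emptiness check (plain C++ with a stand-in enum; the real code uses absl::c_all_of over the proto's repeated field):

    #include <algorithm>
    #include <vector>

    enum Precision { DEFAULT = 0, HIGH = 1, HIGHEST = 2 };

    // True when every operand is at DEFAULT precision, i.e. nothing to print.
    bool AllDefault(const std::vector<int>& operand_precision) {
      return std::all_of(
          operand_precision.begin(), operand_precision.end(),
          [](int p) { return static_cast<Precision>(p) == DEFAULT; });
    }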
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index cca134e8b4..55d592ff94 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -405,9 +405,9 @@ class HloInstruction {
// and window describes how the filter is applied to lhs.
static std::unique_ptr<HloInstruction> CreateConvolve(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
- const Window& window,
+ int64 feature_group_count, const Window& window,
const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count = 1);
+ const PrecisionConfigProto& precision_config);
// Creates an FFT op, of the type indicated by fft_type.
static std::unique_ptr<HloInstruction> CreateFft(
@@ -418,7 +418,8 @@ class HloInstruction {
// dimensions specified in 'dimension_numbers'.
static std::unique_ptr<HloInstruction> CreateDot(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
- const DotDimensionNumbers& dimension_numbers);
+ const DotDimensionNumbers& dimension_numbers,
+ const PrecisionConfigProto& precision_config);
// Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1
// of the LHS with dimension 0 of the RHS with no batch dimensions. Both LHS
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index 76b0e940a6..b4e302e832 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -1122,6 +1122,13 @@ TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) {
}
}
+PrecisionConfigProto DefaultPrecisionConfig(int operands) {
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ operands, PrecisionConfigProto::DEFAULT);
+ return precision_config;
+}
+
TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) {
// Fused expression:
//
@@ -1147,8 +1154,8 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+ HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
+ sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
auto* computation = module->AddEntryComputation(builder.Build());
@@ -1188,8 +1195,8 @@ TEST_F(HloInstructionTest, NoRedundantFusionOperandsAfterReplacingUse) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateDot(s, x, reshape, dot_dnums));
+ HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
+ s, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
auto* computation = module->AddEntryComputation(builder.Build());
@@ -1239,8 +1246,8 @@ TEST_F(HloInstructionTest, NestedFusionEquality) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto dot = builder.AddInstruction(
- HloInstruction::CreateDot(data_shape, a, b_t, dot_dnums));
+ auto dot = builder.AddInstruction(HloInstruction::CreateDot(
+ data_shape, a, b_t, dot_dnums, DefaultPrecisionConfig(2)));
auto one = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto add_operand = builder.AddInstruction(
@@ -1320,8 +1327,8 @@ TEST_F(HloInstructionTest, Stringification) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+ HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
+ sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));
auto options = HloPrintOptions().set_print_metadata(false);
@@ -1485,8 +1492,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationFusion) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+ HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
+ sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));
auto options = HloPrintOptions().Canonical();
@@ -1527,8 +1534,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationWhile) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+ HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
+ sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
auto* computation = module->AddEntryComputation(builder.Build());
@@ -1583,8 +1590,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationConditional) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+ HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
+ sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
auto* computation = module->AddEntryComputation(builder.Build());
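Note: the DefaultPrecisionConfig helper introduced above serves the convolution factory as well, whose argument order now leads with feature_group_count. Illustrative call, assuming `conv_shape`, `lhs`, `rhs`, `window`, and `dnums` are built as in the transpose-folding tests further down:

    HloInstruction* conv =
        builder.AddInstruction(HloInstruction::CreateConvolve(
            conv_shape, lhs, rhs,
            /*feature_group_count=*/1, window, dnums,
            DefaultPrecisionConfig(2)));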
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index e46afa764f..bed273149b 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -1628,12 +1628,13 @@ std::unique_ptr<HloInstruction> HloOutfeedInstruction::CloneWithNewOperandsImpl(
HloConvolutionInstruction::HloConvolutionInstruction(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
- const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count)
+ int64 feature_group_count, const Window& window,
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ const PrecisionConfigProto& precision_config)
: HloInstruction(HloOpcode::kConvolution, shape),
+ feature_group_count_(feature_group_count),
window_(window),
- convolution_dimension_numbers_(dimension_numbers),
- feature_group_count_(feature_group_count) {
+ convolution_dimension_numbers_(dimension_numbers) {
if (window_util::HasBaseDilation(window)) {
SetAndSanitizeName(StrCat(name(), "-base-dilated"));
}
@@ -1642,6 +1643,7 @@ HloConvolutionInstruction::HloConvolutionInstruction(
}
AppendOperand(lhs);
AppendOperand(rhs);
+ set_precision_config(precision_config);
}
string HloConvolutionInstruction::ToCategory() const {
@@ -1697,8 +1699,8 @@ HloConvolutionInstruction::CloneWithNewOperandsImpl(
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
return absl::make_unique<HloConvolutionInstruction>(
- shape, new_operands[0], new_operands[1], window(),
- convolution_dimension_numbers_, feature_group_count_);
+ shape, new_operands[0], new_operands[1], feature_group_count_, window(),
+ convolution_dimension_numbers_, precision_config());
}
HloReduceWindowInstruction::HloReduceWindowInstruction(
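Note: the field reorder in this constructor (and in the header hunk below) keeps the member-initializer list in declaration order; C++ always initializes members in declaration order regardless of how the initializer list is written, and a mismatch draws -Wreorder warnings. Stand-alone illustration:

    struct Conv {
      int feature_group_count_;  // declared first, so initialized first
      int window_size_;
      // Initializer-list order now matches declaration order.
      Conv(int fgc, int w) : feature_group_count_(fgc), window_size_(w) {}
    };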
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 3230383579..1c85aa4681 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -942,9 +942,9 @@ class HloConvolutionInstruction : public HloInstruction {
public:
explicit HloConvolutionInstruction(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
- const Window& window,
+ int64 feature_group_count, const Window& window,
const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count);
+ const PrecisionConfigProto& precision_config);
const Window& window() const override { return window_; }
void set_window(const Window& window) override { window_ = window; }
const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
@@ -972,12 +972,13 @@ class HloConvolutionInstruction : public HloInstruction {
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
- Window window_;
- // Describes the dimension numbers used for a convolution.
- ConvolutionDimensionNumbers convolution_dimension_numbers_;
// The number of feature groups. Must be a divisor of the input feature
// dimension and output feature dimension.
int64 feature_group_count_;
+ // Describes the window used for a convolution.
+ Window window_;
+ // Describes the dimension numbers used for a convolution.
+ ConvolutionDimensionNumbers convolution_dimension_numbers_;
};
class HloReduceWindowInstruction : public HloInstruction {
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index ea8e6a239a..62f01c4adb 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -530,10 +530,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
attrs["backend_config"] = {/*required=*/false, AttrTy::kString,
&backend_config};
- optional<std::vector<PrecisionConfigProto::Precision>> operand_precision;
- attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
- &operand_precision};
-
HloInstruction* instruction;
switch (opcode) {
case HloOpcode::kParameter: {
@@ -913,6 +909,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
AttrTy::kConvolutionDimensionNumbers, &dnums};
attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
&feature_group_count};
+ optional<std::vector<PrecisionConfigProto::Precision>> operand_precision;
+ attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
+ &operand_precision};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
return false;
@@ -923,9 +922,17 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
if (!feature_group_count) {
feature_group_count = 1;
}
+ PrecisionConfigProto precision_config;
+ if (operand_precision) {
+ *precision_config.mutable_operand_precision() = {
+ operand_precision->begin(), operand_precision->end()};
+ } else {
+ precision_config.mutable_operand_precision()->Resize(
+ operands.size(), PrecisionConfigProto::DEFAULT);
+ }
instruction = builder->AddInstruction(HloInstruction::CreateConvolve(
- shape, /*lhs=*/operands[0], /*rhs=*/operands[1], *window, *dnums,
- feature_group_count.value()));
+ shape, /*lhs=*/operands[0], /*rhs=*/operands[1],
+ feature_group_count.value(), *window, *dnums, precision_config));
break;
}
case HloOpcode::kFft: {
@@ -1272,6 +1279,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
optional<std::vector<tensorflow::int64>> rhs_batch_dims;
attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List,
&rhs_batch_dims};
+ optional<std::vector<PrecisionConfigProto::Precision>> operand_precision;
+ attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
+ &operand_precision};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
@@ -1296,8 +1306,17 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
rhs_batch_dims->end()};
}
- instruction = builder->AddInstruction(
- HloInstruction::CreateDot(shape, operands[0], operands[1], dnum));
+ PrecisionConfigProto precision_config;
+ if (operand_precision) {
+ *precision_config.mutable_operand_precision() = {
+ operand_precision->begin(), operand_precision->end()};
+ } else {
+ precision_config.mutable_operand_precision()->Resize(
+ operands.size(), PrecisionConfigProto::DEFAULT);
+ }
+
+ instruction = builder->AddInstruction(HloInstruction::CreateDot(
+ shape, operands[0], operands[1], dnum, precision_config));
break;
}
case HloOpcode::kGather: {
@@ -1414,12 +1433,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
if (backend_config) {
instruction->set_raw_backend_config_string(std::move(*backend_config));
}
- if (operand_precision) {
- PrecisionConfigProto precision_config;
- *precision_config.mutable_operand_precision() = {operand_precision->begin(),
- operand_precision->end()};
- instruction->set_precision_config(precision_config);
- }
return AddInstruction(name, instruction, name_loc);
} // NOLINT(readability/fn_size)
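Note: with the parser hunks above, operand_precision is accepted only on dot and convolution, and omitting the attribute yields an all-DEFAULT config sized to the operand count. Illustrative HLO text the parser would consume (shapes hypothetical; the value tokens are assumed to follow the parser's precision list, e.g. default/highest):

    %dot = f32[2,4]{1,0} dot(f32[2,3]{1,0} %x, f32[3,4]{1,0} %y),
        lhs_contracting_dims={1}, rhs_contracting_dims={0},
        operand_precision={highest,default}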
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 95516dec74..069586a738 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -86,8 +86,8 @@ Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) {
const Shape expected,
ShapeInference::InferConvolveShape(
convolution->operand(0)->shape(), convolution->operand(1)->shape(),
- convolution->window(), convolution->convolution_dimension_numbers(),
- convolution->feature_group_count()));
+ convolution->feature_group_count(), convolution->window(),
+ convolution->convolution_dimension_numbers()));
return CheckShape(convolution, expected);
}
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
index a4de02a890..4a71ee909b 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
@@ -165,6 +165,7 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayFor(
TF_ASSIGN_OR_RETURN(
computed_array,
ComputeArrayForDot(instr->shape(), instr->dot_dimension_numbers(),
+ instr->precision_config(),
FindOrDie(cache_, instr->operand(0)),
FindOrDie(cache_, instr->operand(1))));
} else {
@@ -1030,6 +1031,7 @@ bool CanFoldDotIntoIndexedArray(
StatusOr<Analysis::Array*>
IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs(
const Shape& shape, const DotDimensionNumbers& dim_numbers,
+ const PrecisionConfigProto& precision_config,
ScalarIndexedConstantArray* lhs, ConstantArray* rhs) {
VLOG(3) << "ComputeArrayForDotWithIndexedLhs(" << ToString(lhs) << " "
<< ToString(rhs);
@@ -1045,9 +1047,10 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs(
new_dim_numbers.set_lhs_contracting_dimensions(
0, lhs->source_dim() == (lhs_rank - 1) ? (lhs_rank - 2) : (lhs_rank - 1));
- TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source,
- TakeOwnership(HloEvaluator{}.EvaluateDotOp(
- new_dim_numbers, lhs->literal(), *rhs->literal())));
+ TF_ASSIGN_OR_RETURN(
+ Literal * literal_for_new_source,
+ TakeOwnership(HloEvaluator{}.EvaluateDotOp(
+ new_dim_numbers, precision_config, lhs->literal(), *rhs->literal())));
// The new source dimension is wherever the non-batch non-contracting LHS
// dimension "went".
@@ -1063,7 +1066,8 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs(
StatusOr<Analysis::Array*>
IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs(
const Shape& shape, const DotDimensionNumbers& dim_numbers,
- ConstantArray* lhs, ScalarIndexedConstantArray* rhs) {
+ const PrecisionConfigProto& precision_config, ConstantArray* lhs,
+ ScalarIndexedConstantArray* rhs) {
VLOG(3) << "ComputeArrayForDotWithIndexedRhs(" << ToString(lhs) << " "
<< ToString(rhs);
if (!CanFoldDotIntoIndexedArray(
@@ -1079,9 +1083,10 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs(
new_dim_numbers.set_rhs_contracting_dimensions(
0, rhs->source_dim() == (rhs_rank - 1) ? (rhs_rank - 2) : (rhs_rank - 1));
- TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source,
- TakeOwnership(HloEvaluator{}.EvaluateDotOp(
- new_dim_numbers, *lhs->literal(), rhs->literal())));
+ TF_ASSIGN_OR_RETURN(
+ Literal * literal_for_new_source,
+ TakeOwnership(HloEvaluator{}.EvaluateDotOp(
+ new_dim_numbers, precision_config, *lhs->literal(), rhs->literal())));
// The new source dimension is wherever the non-batch non-contracting RHS
// dimension "went".
@@ -1095,8 +1100,8 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs(
}
StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForDot(
- const Shape& shape, const DotDimensionNumbers& dim_numbers, Array* lhs,
- Array* rhs) {
+ const Shape& shape, const DotDimensionNumbers& dim_numbers,
+ const PrecisionConfigProto& precision_config, Array* lhs, Array* rhs) {
// Intuitively, if
//
// - The LHS of a dot product is a gathered sequence of rows from a constant
@@ -1119,6 +1124,7 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForDot(
dynamic_cast<ScalarIndexedConstantArray*>(lhs)) {
if (auto* rhs_constant = dynamic_cast<ConstantArray*>(rhs)) {
return ComputeArrayForDotWithIndexedLhs(shape, dim_numbers,
+ precision_config,
lhs_indexed_array, rhs_constant);
}
}
@@ -1126,7 +1132,8 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForDot(
if (auto* rhs_indexed_array =
dynamic_cast<ScalarIndexedConstantArray*>(rhs)) {
if (auto* lhs_constant = dynamic_cast<ConstantArray*>(lhs)) {
- return ComputeArrayForDotWithIndexedRhs(shape, dim_numbers, lhs_constant,
+ return ComputeArrayForDotWithIndexedRhs(shape, dim_numbers,
+ precision_config, lhs_constant,
rhs_indexed_array);
}
}
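Note: HloEvaluator::EvaluateDotOp now takes the precision config between the dimension numbers and the literal operands. A minimal call sketch matching the call sites above (literal arguments as in those hunks):

    // Evaluate a dot between two literals at the given operand precisions.
    auto result_or = HloEvaluator{}.EvaluateDotOp(
        dim_numbers, precision_config, lhs_literal, rhs_literal);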
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h
index dcfb725535..f21e784a4d 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.h
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h
@@ -267,15 +267,17 @@ class IndexedArrayAnalysis {
StatusOr<Array*> ComputeArrayForDotWithIndexedLhs(
const Shape& shape, const DotDimensionNumbers& dim_numbers,
+ const PrecisionConfigProto& precision_config,
ScalarIndexedConstantArray* lhs, ConstantArray* rhs);
StatusOr<Array*> ComputeArrayForDotWithIndexedRhs(
const Shape& shape, const DotDimensionNumbers& dim_numbers,
- ConstantArray* lhs, ScalarIndexedConstantArray* rhs);
+ const PrecisionConfigProto& precision_config, ConstantArray* lhs,
+ ScalarIndexedConstantArray* rhs);
- StatusOr<Array*> ComputeArrayForDot(const Shape& shape,
- const DotDimensionNumbers& dim_numbers,
- Array* lhs, Array* rhs);
+ StatusOr<Array*> ComputeArrayForDot(
+ const Shape& shape, const DotDimensionNumbers& dim_numbers,
+ const PrecisionConfigProto& precision_config, Array* lhs, Array* rhs);
// This tries to fold a ScalarIndexedArray which has another
// ScalarIndexedArray as a source into a ScalarIndexedArray that instead has a
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 2611749862..7758a5dd4d 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -1552,8 +1552,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferConvolveShape(
- const Shape& lhs, const Shape& rhs, const Window& window,
- const ConvolutionDimensionNumbers& dnums, int64 feature_group_count) {
+ const Shape& lhs, const Shape& rhs, int64 feature_group_count,
+ const Window& window, const ConvolutionDimensionNumbers& dnums) {
TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution"));
TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution"));
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index a28345acef..96a0ee165d 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -108,9 +108,9 @@ class ShapeInference {
// Infers the shape produced by applying the given convolutional
// filter (rhs) to lhs in the way specified by the fields on window.
static StatusOr<Shape> InferConvolveShape(
- const Shape& lhs, const Shape& rhs, const Window& window,
- const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count = 1);
+ const Shape& lhs, const Shape& rhs, int64 feature_group_count,
+ const Window& window,
+ const ConvolutionDimensionNumbers& dimension_numbers);
// Infers the shape produced by the given FFT type on the given operand.
static StatusOr<Shape> InferFftShape(const Shape& in, FftType fft_type,
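Note: with the `= 1` default gone, every caller must spell out feature_group_count; a grouped convolution simply passes a larger count. Hedged sketch, assuming shapes whose input feature dimension is divisible by the group count:

    StatusOr<Shape> grouped = ShapeInference::InferConvolveShape(
        lhs_shape, rhs_shape, /*feature_group_count=*/2, window, dnums);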
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index cc92e58ef8..864ed43118 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -419,8 +419,8 @@ TEST_F(ShapeInferenceTest, Convolve) {
dim1->set_padding_high(0);
dim1->set_window_dilation(1);
dim1->set_base_dilation(1);
- auto inferred_status =
- ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums);
+ auto inferred_status = ShapeInference::InferConvolveShape(
+ lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums);
ASSERT_IS_OK(inferred_status.status());
Shape inferred_shape = inferred_status.ValueOrDie();
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}),
@@ -464,8 +464,8 @@ TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) {
dim1->set_padding_high(1);
dim1->set_window_dilation(2);
dim1->set_base_dilation(1);
- auto inferred_status =
- ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums);
+ auto inferred_status = ShapeInference::InferConvolveShape(
+ lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums);
ASSERT_IS_OK(inferred_status.status());
Shape inferred_shape = inferred_status.ValueOrDie();
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 31, 5}),
@@ -509,8 +509,8 @@ TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) {
dim1->set_padding_high(1);
dim1->set_window_dilation(1);
dim1->set_base_dilation(2);
- auto inferred_status =
- ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums);
+ auto inferred_status = ShapeInference::InferConvolveShape(
+ lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums);
ASSERT_IS_OK(inferred_status.status());
Shape inferred_shape = inferred_status.ValueOrDie();
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 4, 9}),
@@ -547,8 +547,8 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) {
dim1->set_stride(2);
dim1->set_padding_low(1);
dim1->set_padding_high(1);
- auto inferred_status =
- ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums);
+ auto inferred_status = ShapeInference::InferConvolveShape(
+ lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums);
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.status().error_message(),
HasSubstr("each dimension exactly once"));
diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc
index 530f40e4b2..7c1f4b5cc6 100644
--- a/tensorflow/compiler/xla/service/transpose_folding.cc
+++ b/tensorflow/compiler/xla/service/transpose_folding.cc
@@ -108,8 +108,7 @@ Status FoldTransposeIntoDot(InstructionOperandsPair pair) {
}
std::unique_ptr<HloInstruction> new_dot = HloInstruction::CreateDot(
- dot->shape(), new_lhs, new_rhs, new_dim_numbers);
- new_dot->set_precision_config(dot->precision_config());
+ dot->shape(), new_lhs, new_rhs, new_dim_numbers, dot->precision_config());
return dot->parent()->ReplaceWithNewInstruction(dot, std::move(new_dot));
}
@@ -178,8 +177,8 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) {
}
auto new_conv = HloInstruction::CreateConvolve(
- convolution.shape(), new_lhs, new_rhs, convolution.window(), new_dnums);
- new_conv->set_precision_config(convolution.precision_config());
+ convolution.shape(), new_lhs, new_rhs, convolution.feature_group_count(),
+ convolution.window(), new_dnums, convolution.precision_config());
TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction(
&convolution, std::move(new_conv)));
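Note: both folds above previously created the new instruction and then copied the precision with a separate set_precision_config call; routing dot->precision_config() and convolution.precision_config() through the factories means a forgotten copy no longer compiles. The pattern in miniature (names from the hunks above):

    // Old: create, then remember to copy the config afterwards.
    //   auto new_dot = HloInstruction::CreateDot(shape, lhs, rhs, dnums);
    //   new_dot->set_precision_config(dot->precision_config());
    // New: the factory refuses to build without it.
    auto new_dot = HloInstruction::CreateDot(
        dot->shape(), new_lhs, new_rhs, new_dim_numbers,
        dot->precision_config());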
diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc
index 58f767e913..e486a00e53 100644
--- a/tensorflow/compiler/xla/service/transpose_folding_test.cc
+++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc
@@ -215,6 +215,13 @@ ENTRY entry_computation {
/*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/1));
}
+PrecisionConfigProto DefaultPrecisionConfig(int operands) {
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ operands, PrecisionConfigProto::DEFAULT);
+ return precision_config;
+}
+
// Test that a two dimension swap of the kernel gets folded into convolution.
TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) {
auto builder = HloComputation::Builder("entry_computation");
@@ -240,10 +247,12 @@ TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) {
transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
}
StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
- x->shape(), transpose_y->shape(), window, dnums);
+ x->shape(), transpose_y->shape(), /*feature_group_count=*/1, window,
+ dnums);
EXPECT_IS_OK(conv_shape);
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
- conv_shape.ValueOrDie(), x, transpose_y, window, dnums));
+ conv_shape.ValueOrDie(), x, transpose_y,
+ /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
auto module = CreateNewModule("test_module");
HloComputation* entry_computation =
@@ -293,10 +302,12 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) {
transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
}
StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
- x->shape(), transpose_y->shape(), window, dnums);
+ x->shape(), transpose_y->shape(), /*feature_group_count=*/1, window,
+ dnums);
EXPECT_IS_OK(conv_shape);
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
- conv_shape.ValueOrDie(), x, transpose_y, window, dnums));
+ conv_shape.ValueOrDie(), x, transpose_y,
+ /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
auto module = CreateNewModule("test_module");
HloComputation* entry_computation =
@@ -351,10 +362,12 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) {
dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
}
StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
- transpose_x->shape(), y->shape(), window, dnums);
+ transpose_x->shape(), y->shape(), /*feature_group_count=*/1, window,
+ dnums);
EXPECT_IS_OK(conv_shape);
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
- conv_shape.ValueOrDie(), transpose_x, y, window, dnums));
+ conv_shape.ValueOrDie(), transpose_x, y,
+ /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
auto module = CreateNewModule("test_module");
HloComputation* entry_computation =
@@ -415,10 +428,12 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) {
dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
}
StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
- transpose_x->shape(), y->shape(), window, dnums);
+ transpose_x->shape(), y->shape(), /*feature_group_count=*/1, window,
+ dnums);
EXPECT_IS_OK(conv_shape);
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
- conv_shape.ValueOrDie(), transpose_x, y, window, dnums));
+ conv_shape.ValueOrDie(), transpose_x, y,
+ /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
auto module = CreateNewModule("test_module");
HloComputation* entry_computation =
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
index a32d1f9026..e3328203a6 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
@@ -1064,8 +1064,11 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ /*new_size=*/2, PrecisionConfigProto::DEFAULT);
auto dot = builder.AddInstruction(
- HloInstruction::CreateDot(data_shape, a, b, dot_dnums));
+ HloInstruction::CreateDot(data_shape, a, b, dot_dnums, precision_config));
auto one = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
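Note: the Resize call above comes from protobuf's RepeatedField; when growing, the new slots are filled with the supplied value, so each of the two operands here gets DEFAULT precision. Self-contained illustration of that semantics:

    #include <google/protobuf/repeated_field.h>

    google::protobuf::RepeatedField<int> field;
    field.Resize(2, /*value=*/0);  // grows to size 2; new elements are 0
    field.Resize(1, /*value=*/0);  // shrinks; the fill value is ignored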
diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
index 05f90ba9fb..53b5e933b6 100644
--- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
@@ -47,6 +47,12 @@ limitations under the License.
namespace xla {
namespace {
+PrecisionConfigProto DefaultPrecisionConfig(int operands) {
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ operands, PrecisionConfigProto::DEFAULT);
+ return precision_config;
+}
class MultiOutputFusionTest : public HloTestBase {
protected:
@@ -90,8 +96,8 @@ class MultiOutputFusionTest : public HloTestBase {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateDot(elem_shape2, sub, add2, dot_dnums));
+ HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
+ elem_shape2, sub, add2, dot_dnums, DefaultPrecisionConfig(2)));
auto computation = hlo_module->AddEntryComputation(builder.Build(dot));
if (manual_fusion) {
@@ -154,7 +160,7 @@ class MultiOutputFusionTest : public HloTestBase {
dot_dnums.add_rhs_contracting_dimensions(0);
HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
ShapeUtil::MakeShapeWithDescendingLayout(F32, {1}), sub, reshape,
- dot_dnums));
+ dot_dnums, DefaultPrecisionConfig(2)));
auto computation = hlo_module->AddEntryComputation(builder.Build(dot));
if (manual_fusion) {