Diffstat (limited to 'tensorflow/compiler/xla/service')
-rw-r--r--  tensorflow/compiler/xla/service/BUILD | 2
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier.cc | 21
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier_test.cc | 50
-rw-r--r--  tensorflow/compiler/xla/service/batch_dot_simplification.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_normalization_test.cc | 5
-rw-r--r--  tensorflow/compiler/xla/service/buffer_assignment_test.cc | 11
-rw-r--r--  tensorflow/compiler/xla/service/convolution_feature_group_converter.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc | 6
-rw-r--r--  tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc | 13
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc | 6
-rw-r--r--  tensorflow/compiler/xla/service/dot_decomposer.cc | 6
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc | 112
-rw-r--r--  tensorflow/compiler/xla/service/graphviz_example.cc | 7
-rw-r--r--  tensorflow/compiler/xla/service/heap_simulator_test.cc | 31
-rw-r--r--  tensorflow/compiler/xla/service/hlo_computation_test.cc | 15
-rw-r--r--  tensorflow/compiler/xla/service/hlo_creation_utils.cc | 25
-rw-r--r--  tensorflow/compiler/xla/service/hlo_creation_utils.h | 11
-rw-r--r--  tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc | 5
-rw-r--r--  tensorflow/compiler/xla/service/hlo_domain_map.cc | 41
-rw-r--r--  tensorflow/compiler/xla/service/hlo_domain_map.h | 10
-rw-r--r--  tensorflow/compiler/xla/service/hlo_domain_metadata.h | 3
-rw-r--r--  tensorflow/compiler/xla/service/hlo_domain_test.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator.cc | 6
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator.h | 3
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator_test.cc | 112
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h | 43
-rw-r--r--  tensorflow/compiler/xla/service/hlo_graph_dumper.cc | 41
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc | 57
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.h | 7
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction_test.cc | 35
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instructions.cc | 18
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instructions.h | 11
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser.cc | 41
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser_test.cc | 21
-rw-r--r--  tensorflow/compiler/xla/service/hlo_sharding_metadata.cc | 7
-rw-r--r--  tensorflow/compiler/xla/service/hlo_sharding_metadata.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/indexed_array_analysis.cc | 27
-rw-r--r--  tensorflow/compiler/xla/service/indexed_array_analysis.h | 10
-rw-r--r--  tensorflow/compiler/xla/service/layout_assignment_test.cc | 76
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference.cc | 14
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference.h | 6
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference_test.cc | 16
-rw-r--r--  tensorflow/compiler/xla/service/transpose_folding.cc | 7
-rw-r--r--  tensorflow/compiler/xla/service/transpose_folding_test.cc | 31
-rw-r--r--  tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc | 5
46 files changed, 687 insertions(+), 303 deletions(-)
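
Taken together, these changes thread two pieces of state through instruction construction: HloInstruction::CreateConvolve now takes an explicit feature_group_count, and both CreateDot and CreateConvolve take the PrecisionConfigProto at creation time instead of relying on a later set_precision_config() call. A minimal sketch of the new calling convention, assuming a HloComputation::Builder plus shapes, operands, window, and dimension numbers already set up as in the tests below:

    // Helper mirroring the DefaultPrecisionConfig() this diff adds to several
    // test files: DEFAULT precision for each of `operands` operands.
    PrecisionConfigProto DefaultPrecisionConfig(int operands) {
      PrecisionConfigProto precision_config;
      precision_config.mutable_operand_precision()->Resize(
          operands, PrecisionConfigProto::DEFAULT);
      return precision_config;
    }

    // Precision (and, for convolutions, the feature group count) are now
    // constructor arguments rather than post-construction mutations.
    builder.AddInstruction(HloInstruction::CreateDot(
        dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
    builder.AddInstruction(HloInstruction::CreateConvolve(
        conv_shape, input, kernel, /*feature_group_count=*/1, window, dnums,
        DefaultPrecisionConfig(2)));
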
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 26b48cf419..f6cfac6537 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -3289,6 +3289,8 @@ tf_cc_test(
size = "small",
srcs = ["hlo_parser_test.cc"],
deps = [
+ ":hlo",
+ ":hlo_casting_utils",
":hlo_matchers",
":hlo_parser",
"//tensorflow/compiler/xla:window_util",
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 7c078f07d7..3d18fe3be2 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -950,9 +950,9 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper(
new_dot_rhs = rhs_slice;
}
- auto* new_dot = computation_->AddInstruction(HloInstruction::CreateDot(
- dot.shape(), new_dot_lhs, new_dot_rhs, new_dot_dnums));
- new_dot->set_precision_config(dot.precision_config());
+ auto* new_dot = computation_->AddInstruction(
+ HloInstruction::CreateDot(dot.shape(), new_dot_lhs, new_dot_rhs,
+ new_dot_dnums, dot.precision_config()));
if (add_result) {
add_result = computation_->AddInstruction(HloInstruction::CreateBinary(
@@ -1053,9 +1053,9 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather(
const int n =
right_operand->shape().dimensions(1 - rhs_contracting_dimension);
auto memoized_shape = ShapeUtil::MakeShape(F32, {m, n});
- auto* memoized_inst = computation_->AddInstruction(HloInstruction::CreateDot(
- memoized_shape, left_operand, right_operand, dnums));
- memoized_inst->set_precision_config(dot->precision_config());
+ auto* memoized_inst = computation_->AddInstruction(
+ HloInstruction::CreateDot(memoized_shape, left_operand, right_operand,
+ dnums, dot->precision_config()));
// Get pair {start, 0} or {0, start}.
HloInstruction* original_start_indices =
lhs_is_dynamic_slice ? lhs->mutable_operand(1) : rhs->mutable_operand(1);
@@ -1151,9 +1151,8 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
dot_dimension_numbers.add_rhs_contracting_dimensions(0);
auto new_dot = computation_->AddInstruction(HloInstruction::CreateDot(
ShapeUtil::PermuteDimensions({1, 0}, dot->shape()),
- rhs->mutable_operand(0), lhs->mutable_operand(0),
- dot_dimension_numbers));
- new_dot->set_precision_config(dot->precision_config());
+ rhs->mutable_operand(0), lhs->mutable_operand(0), dot_dimension_numbers,
+ dot->precision_config()));
return ReplaceWithNewInstruction(
dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0}));
}
@@ -2477,8 +2476,8 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
dot_dimension_numbers.add_lhs_contracting_dimensions(1);
dot_dimension_numbers.add_rhs_contracting_dimensions(0);
auto dot = computation_->AddInstruction(HloInstruction::CreateDot(
- dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers));
- dot->set_precision_config(convolution->precision_config());
+ dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers,
+ convolution->precision_config()));
return ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot));
}
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 43a891e4fa..019840b476 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -1013,6 +1013,13 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) {
1);
}
+PrecisionConfigProto DefaultPrecisionConfig(int operands) {
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ operands, PrecisionConfigProto::DEFAULT);
+ return precision_config;
+}
+
TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) {
auto builder = HloComputation::Builder(TestName());
HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateParameter(
@@ -1044,7 +1051,8 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) {
dim->set_window_reversal(false);
// Create add computation.
builder.AddInstruction(HloInstruction::CreateConvolve(
- ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, window, dnums));
+ ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, /*feature_group_count=*/1,
+ window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(builder.Build());
HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
@@ -2260,9 +2268,11 @@ TEST_P(ConvInputPaddingTest, DoTest) {
.ValueOrDie();
builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeInference::InferConvolveShape(lhs_pad->shape(), filter->shape(),
- window, dnums)
+ /*feature_group_count=*/1, window,
+ dnums)
.ValueOrDie(),
- lhs_pad, filter, window, dnums));
+ lhs_pad, filter, /*feature_group_count=*/1, window, dnums,
+ DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
@@ -2368,9 +2378,11 @@ TEST_P(ConvFilterPaddingTest, DoIt) {
.ValueOrDie();
auto* orig_conv = builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeInference::InferConvolveShape(input->shape(), rhs_pad->shape(),
- window, dnums)
+ /*feature_group_count=*/1, window,
+ dnums)
.ValueOrDie(),
- input, rhs_pad, window, dnums));
+ input, rhs_pad, /*feature_group_count=*/1, window, dnums,
+ DefaultPrecisionConfig(2)));
// Add a PrecisionConfig and check that AlgebraicSimplifier keeps it in place
// after the transformation.
@@ -2522,8 +2534,9 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
HloInstruction* filter =
b.AddInstruction(HloInstruction::CreateParameter(1, f_shape, "filter"));
- b.AddInstruction(HloInstruction::CreateConvolve(out_shape, input, filter,
- window, dnums));
+ b.AddInstruction(HloInstruction::CreateConvolve(
+ out_shape, input, filter,
+ /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
// TODO(b/80488902): verify this module.
auto module = HloTestBase::CreateNewModule();
@@ -2901,7 +2914,8 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums));
+ builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums,
+ DefaultPrecisionConfig(2)));
std::unique_ptr<HloComputation> dot_computation(builder.Build());
HloComputation::Builder call_builder(TestName() + ".Call");
@@ -3253,8 +3267,8 @@ TEST_P(DotStrengthReductionTest, DotStrengthReduction) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- builder.AddInstruction(
- HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums));
+ builder.AddInstruction(HloInstruction::CreateDot(
+ dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
auto computation = module().AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
@@ -3329,8 +3343,8 @@ TEST_P(DotOfConcatSimplificationTest, ConstantLHS) {
dot_dnums.add_rhs_contracting_dimensions(0);
Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n});
- builder.AddInstruction(
- HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums));
+ builder.AddInstruction(HloInstruction::CreateDot(
+ dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
auto computation = module().AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
@@ -3393,8 +3407,8 @@ TEST_P(DotOfConcatSimplificationTest, ConstantRHS) {
dot_dnums.add_rhs_contracting_dimensions(0);
Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n});
- builder.AddInstruction(
- HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums));
+ builder.AddInstruction(HloInstruction::CreateDot(
+ dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
auto computation = module().AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
@@ -3511,8 +3525,8 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) {
int64 dot_row_size = 1;
int64 dot_col_size = spec.n;
Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size});
- builder.AddInstruction(
- HloInstruction::CreateDot(dot_shape, ds, rhs, dot_dnums));
+ builder.AddInstruction(HloInstruction::CreateDot(
+ dot_shape, ds, rhs, dot_dnums, DefaultPrecisionConfig(2)));
auto computation = module().AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
@@ -3581,8 +3595,8 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) {
int64 dot_row_size = spec.m;
int64 dot_col_size = 1;
Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size});
- builder.AddInstruction(
- HloInstruction::CreateDot(dot_shape, lhs, ds, dot_dnums));
+ builder.AddInstruction(HloInstruction::CreateDot(
+ dot_shape, lhs, ds, dot_dnums, DefaultPrecisionConfig(2)));
auto computation = module().AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.cc b/tensorflow/compiler/xla/service/batch_dot_simplification.cc
index a16b85a0a5..eda026ac56 100644
--- a/tensorflow/compiler/xla/service/batch_dot_simplification.cc
+++ b/tensorflow/compiler/xla/service/batch_dot_simplification.cc
@@ -63,8 +63,8 @@ BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot(
new_dim_numbers.rhs_contracting_dimensions(0) - degenerate_dims.size());
TF_ASSIGN_OR_RETURN(HloInstruction * new_dot,
- MakeDotHlo(new_lhs, new_rhs, new_dim_numbers));
- new_dot->set_precision_config(batch_dot->precision_config());
+ MakeDotHlo(new_lhs, new_rhs, new_dim_numbers,
+ batch_dot->precision_config()));
TF_ASSIGN_OR_RETURN(HloInstruction * new_dot_reshaped,
MakeReshapeHlo(batch_dot->shape(), new_dot));
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
index b08705d4c2..d480d72297 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
@@ -308,8 +308,11 @@ TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ 2, PrecisionConfigProto::DEFAULT);
HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateDot(bf16_shape, a, b, dot_dnums));
+ HloInstruction::CreateDot(bf16_shape, a, b, dot_dnums, precision_config));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index 8bd1533972..7398f105a0 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -1490,10 +1490,13 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto dot_ab = builder.AddInstruction(
- HloInstruction::CreateDot(shape_2x4, param_a, param_b, dot_dnums));
- auto dot_bc = builder.AddInstruction(
- HloInstruction::CreateDot(shape_3x4, param_b, param_c, dot_dnums));
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ 2, PrecisionConfigProto::DEFAULT);
+ auto dot_ab = builder.AddInstruction(HloInstruction::CreateDot(
+ shape_2x4, param_a, param_b, dot_dnums, precision_config));
+ auto dot_bc = builder.AddInstruction(HloInstruction::CreateDot(
+ shape_3x4, param_b, param_c, dot_dnums, precision_config));
builder.AddInstruction(
HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 0));
diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc
index 9c81a86bbb..0826380f65 100644
--- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc
+++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc
@@ -223,8 +223,8 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
filter_mask, expanded_filter, zero_filter));
auto new_convolution = HloInstruction::CreateConvolve(
convolution->shape(), convolution->mutable_operand(0), new_filter,
- convolution->window(), dim_numbers, /*feature_group_count=*/1);
- new_convolution->set_precision_config(convolution->precision_config());
+ /*feature_group_count=*/1, convolution->window(), dim_numbers,
+ convolution->precision_config());
TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
convolution, std::move(new_convolution)));
return Status::OK();
diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc
index 098ce17a56..2d9978404c 100644
--- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc
+++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc
@@ -130,9 +130,9 @@ StatusOr<bool> ConvCanonicalization::Run(HloModule* module) {
// change the dimension mapping but not the dimension sizes. For
// example, input height and width are the same as before the reshapes.
HloInstruction* new_conv = module->entry_computation()->AddInstruction(
- HloInstruction::CreateConvolve(new_conv_shape, new_input, new_kernel,
- hlo->window(), new_dnums));
- new_conv->set_precision_config(hlo->precision_config());
+ HloInstruction::CreateConvolve(
+ new_conv_shape, new_input, new_kernel, hlo->feature_group_count(),
+ hlo->window(), new_dnums, hlo->precision_config()));
// Reshape the output back to the shape of the original convolution.
TF_RETURN_IF_ERROR(module->entry_computation()->ReplaceWithNewInstruction(
diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
index 547d4c696d..616c453750 100644
--- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
@@ -56,6 +56,13 @@ class ConvCanonicalizationTest : public HloTestBase {
static constexpr int kOutputFeatureCount = 64;
};
+PrecisionConfigProto DefaultPrecisionConfig(int operands) {
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ operands, PrecisionConfigProto::DEFAULT);
+ return precision_config;
+}
+
TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) {
auto builder = HloComputation::Builder(TestName());
// The input dimensions are in CNHW order.
@@ -84,7 +91,8 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) {
builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(
F32, {kOutputFeatureCount, kBatchSize, output_size, output_size}),
- input, kernel, conv_window_, dnums));
+ input, kernel, /*feature_group_count=*/1, conv_window_, dnums,
+ DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
HloComputation* entry_computation =
@@ -146,7 +154,8 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) {
builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(
F32, {kBatchSize, output_size, output_size, kOutputFeatureCount}),
- input, kernel, conv_window_, dnums));
+ input, kernel, /*feature_group_count=*/1, conv_window_, dnums,
+ DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
index 284929ca07..6bd0a2dd90 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
@@ -38,7 +38,11 @@ std::unique_ptr<HloInstruction> MakeDot(const Shape& shape, HloInstruction* lhs,
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums);
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ 2, PrecisionConfigProto::DEFAULT);
+ return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums,
+ precision_config);
}
TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) {
diff --git a/tensorflow/compiler/xla/service/dot_decomposer.cc b/tensorflow/compiler/xla/service/dot_decomposer.cc
index 09cb10d6ee..b2ba261790 100644
--- a/tensorflow/compiler/xla/service/dot_decomposer.cc
+++ b/tensorflow/compiler/xla/service/dot_decomposer.cc
@@ -134,9 +134,9 @@ Status DecomposeBatchDot(HloInstruction* dot) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto dot_r2 = computation->AddInstruction(HloInstruction::CreateDot(
- dot_shape_r2, lhs_slice_r2, rhs_slice_r2, dot_dnums));
- dot_r2->set_precision_config(dot->precision_config());
+ auto dot_r2 = computation->AddInstruction(
+ HloInstruction::CreateDot(dot_shape_r2, lhs_slice_r2, rhs_slice_r2,
+ dot_dnums, dot->precision_config()));
// Reshape Dot to R3 so we can concat along batch dimension.
auto dot_r3 = computation->AddInstruction(
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc
index 46c23db465..9b46bfc098 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc
@@ -95,6 +95,13 @@ class CudnnConvolutionRewriterTest : public HloVerifiedTestBase {
ConvolutionDimensionNumbers tf_default_dnums_for_backward_input_;
};
+PrecisionConfigProto DefaultPrecisionConfig(int operands) {
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ operands, PrecisionConfigProto::DEFAULT);
+ return precision_config;
+}
+
TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolve) {
HloComputation::Builder builder(TestName());
HloInstruction* activations =
@@ -107,12 +114,12 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolve) {
conv_window.mutable_dimensions(1)->set_size(2);
conv_window.mutable_dimensions(1)->set_window_dilation(2);
builder.AddInstruction(HloInstruction::CreateConvolve(
- ShapeInference::InferConvolveShape(activations->shape(),
- gradients->shape(), conv_window,
- tf_default_dnums_for_backward_filter_)
+ ShapeInference::InferConvolveShape(
+ activations->shape(), gradients->shape(), /*feature_group_count=*/1,
+ conv_window, tf_default_dnums_for_backward_filter_)
.ConsumeValueOrDie(),
- activations, gradients, conv_window,
- tf_default_dnums_for_backward_filter_));
+ activations, gradients, /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
HloComputation* entry_computation =
@@ -135,12 +142,12 @@ TEST_F(CudnnConvolutionRewriterTest,
Window conv_window = default_conv_window_;
conv_window.mutable_dimensions(1)->set_size(3);
builder.AddInstruction(HloInstruction::CreateConvolve(
- ShapeInference::InferConvolveShape(activations->shape(),
- gradients->shape(), conv_window,
- tf_default_dnums_for_backward_filter_)
+ ShapeInference::InferConvolveShape(
+ activations->shape(), gradients->shape(), /*feature_group_count=*/1,
+ conv_window, tf_default_dnums_for_backward_filter_)
.ConsumeValueOrDie(),
- activations, gradients, conv_window,
- tf_default_dnums_for_backward_filter_));
+ activations, gradients, /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
HloComputation* entry_computation =
@@ -170,7 +177,8 @@ TEST_F(CudnnConvolutionRewriterTest,
}
builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients,
- conv_window, tf_default_dnums_for_backward_filter_));
+ /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
HloComputation* entry_computation =
@@ -200,7 +208,8 @@ TEST_F(CudnnConvolutionRewriterTest,
}
builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients,
- conv_window, tf_default_dnums_for_backward_filter_));
+ /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
HloComputation* entry_computation =
@@ -228,7 +237,8 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolveWithUnevenPadding) {
}
builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients,
- conv_window, tf_default_dnums_for_backward_filter_));
+ /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
HloComputation* entry_computation =
@@ -272,13 +282,14 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveEvenPadding) {
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(F32, {4, 3, 16, 16}), /*lhs=*/output,
- /*rhs=*/reverse_kernel, conv_window, conv_dnums));
+ /*rhs=*/reverse_kernel, /*feature_group_count=*/1, conv_window,
+ conv_dnums, DefaultPrecisionConfig(2)));
// Verify the convolution's shape is consistent with ShapeInference.
CHECK(ShapeUtil::Compatible(
- conv->shape(),
- ShapeInference::InferConvolveShape(
- output->shape(), reverse_kernel->shape(), conv_window, conv_dnums)
- .ValueOrDie()));
+ conv->shape(), ShapeInference::InferConvolveShape(
+ output->shape(), reverse_kernel->shape(),
+ /*feature_group_count=*/1, conv_window, conv_dnums)
+ .ValueOrDie()));
auto module = CreateNewModule();
HloComputation* entry_computation =
@@ -319,11 +330,11 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolve1x1Filter) {
builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeInference::InferConvolveShape(output->shape(), kernel->shape(),
- conv_window,
+ /*feature_group_count=*/1, conv_window,
tf_default_dnums_for_backward_input_)
.ConsumeValueOrDie(),
- /*lhs=*/output, /*rhs=*/kernel, conv_window,
- tf_default_dnums_for_backward_input_));
+ /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
HloComputation* entry_computation =
@@ -350,12 +361,13 @@ TEST_F(CudnnConvolutionRewriterTest,
1, ShapeUtil::MakeShape(F32, {1, 1, 1, 1}), "kernel"));
builder.AddInstruction(HloInstruction::CreateConvolve(
- ShapeInference::InferConvolveShape(output->shape(), kernel->shape(),
- default_conv_window_,
- tf_default_dnums_for_backward_input_)
+ ShapeInference::InferConvolveShape(
+ output->shape(), kernel->shape(), /*feature_group_count=*/1,
+ default_conv_window_, tf_default_dnums_for_backward_input_)
.ConsumeValueOrDie(),
- /*lhs=*/output, /*rhs=*/kernel, default_conv_window_,
- tf_default_dnums_for_backward_input_));
+ /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1,
+ default_conv_window_, tf_default_dnums_for_backward_input_,
+ DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
HloComputation* entry_computation =
@@ -402,13 +414,15 @@ TEST_F(CudnnConvolutionRewriterTest,
}
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel,
- conv_window, tf_default_dnums_for_backward_input_));
+ /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
// Verify the convolution's shape is consistent with ShapeInference.
CHECK(ShapeUtil::Compatible(
- conv->shape(), ShapeInference::InferConvolveShape(
- output->shape(), reverse_kernel->shape(), conv_window,
- tf_default_dnums_for_backward_input_)
- .ValueOrDie()));
+ conv->shape(),
+ ShapeInference::InferConvolveShape(
+ output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1,
+ conv_window, tf_default_dnums_for_backward_input_)
+ .ValueOrDie()));
auto module = CreateNewModule();
HloComputation* entry_computation =
@@ -449,13 +463,15 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveLowPaddingTooLarge) {
}
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel,
- conv_window, tf_default_dnums_for_backward_input_));
+ /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
// Verify the convolution's shape is consistent with ShapeInference.
CHECK(ShapeUtil::Compatible(
- conv->shape(), ShapeInference::InferConvolveShape(
- output->shape(), reverse_kernel->shape(), conv_window,
- tf_default_dnums_for_backward_input_)
- .ValueOrDie()));
+ conv->shape(),
+ ShapeInference::InferConvolveShape(
+ output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1,
+ conv_window, tf_default_dnums_for_backward_input_)
+ .ValueOrDie()));
auto module = CreateNewModule();
HloComputation* entry_computation =
@@ -502,13 +518,15 @@ TEST_F(CudnnConvolutionRewriterTest,
forward_conv_col_dim->set_base_dilation(2);
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(F32, {1, 1, 14, 1}), output, reverse_kernel,
- conv_window, tf_default_dnums_for_backward_input_));
+ /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
// Verify the convolution's shape is consistent with ShapeInference.
CHECK(ShapeUtil::Compatible(
- conv->shape(), ShapeInference::InferConvolveShape(
- output->shape(), reverse_kernel->shape(), conv_window,
- tf_default_dnums_for_backward_input_)
- .ValueOrDie()));
+ conv->shape(),
+ ShapeInference::InferConvolveShape(
+ output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1,
+ conv_window, tf_default_dnums_for_backward_input_)
+ .ValueOrDie()));
auto module = CreateNewModule();
const HloComputation* entry_computation =
@@ -554,13 +572,15 @@ TEST_F(CudnnConvolutionRewriterTest,
forward_conv_col_dim->set_padding_high(2);
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(F32, {1, 1, 4, 1}), output, reverse_kernel,
- conv_window, tf_default_dnums_for_backward_input_));
+ /*feature_group_count=*/1, conv_window,
+ tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
// Verify the convolution's shape is consistent with ShapeInference.
CHECK(ShapeUtil::Compatible(
- conv->shape(), ShapeInference::InferConvolveShape(
- output->shape(), reverse_kernel->shape(), conv_window,
- tf_default_dnums_for_backward_input_)
- .ValueOrDie()));
+ conv->shape(),
+ ShapeInference::InferConvolveShape(
+ output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1,
+ conv_window, tf_default_dnums_for_backward_input_)
+ .ValueOrDie()));
auto module = CreateNewModule();
HloComputation* entry_computation =
diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc
index a2be89511b..0a49d85c6d 100644
--- a/tensorflow/compiler/xla/service/graphviz_example.cc
+++ b/tensorflow/compiler/xla/service/graphviz_example.cc
@@ -112,8 +112,11 @@ std::unique_ptr<HloModule> MakeBigGraph() {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto dot = builder.AddInstruction(
- HloInstruction::CreateDot(vshape, clamp, param_v0, dot_dnums));
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ /*new_size=*/2, PrecisionConfigProto::DEFAULT);
+ auto dot = builder.AddInstruction(HloInstruction::CreateDot(
+ vshape, clamp, param_v0, dot_dnums, precision_config));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({dot, param_s, clamp}));
auto scalar = builder.AddInstruction(
diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc
index 5f85f14565..576c5ff7a4 100644
--- a/tensorflow/compiler/xla/service/heap_simulator_test.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc
@@ -353,6 +353,13 @@ TEST_F(HeapSimulatorTest, BufferReusedOnce) {
(neg_buffer == output_buffer_1));
}
+PrecisionConfigProto DefaultPrecisionConfig(int operands) {
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ operands, PrecisionConfigProto::DEFAULT);
+ return precision_config;
+}
+
TEST_F(HeapSimulatorTest, MultiplyDot) {
auto builder = HloComputation::Builder(TestName());
auto paramA = builder.AddInstruction(
@@ -366,8 +373,8 @@ TEST_F(HeapSimulatorTest, MultiplyDot) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto dot = builder.AddInstruction(
- HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums));
+ auto dot = builder.AddInstruction(HloInstruction::CreateDot(
+ f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
// The buffer for dot is the output, and it cannot be shared with the buffer
// for mul, since dot isn't elementwise.
@@ -402,8 +409,8 @@ TEST_F(HeapSimulatorTest, MultiplyDotAdd) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto dot = builder.AddInstruction(
- HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums));
+ auto dot = builder.AddInstruction(HloInstruction::CreateDot(
+ f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, dot, paramA));
@@ -440,10 +447,10 @@ TEST_F(HeapSimulatorTest, MultiplyDotDot) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto dot0 = builder.AddInstruction(
- HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums));
- auto dot1 = builder.AddInstruction(
- HloInstruction::CreateDot(f32vec4_, dot0, paramY, dot_dnums));
+ auto dot0 = builder.AddInstruction(HloInstruction::CreateDot(
+ f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
+ auto dot1 = builder.AddInstruction(HloInstruction::CreateDot(
+ f32vec4_, dot0, paramY, dot_dnums, DefaultPrecisionConfig(2)));
// The buffer for dot1 is the output. No buffers can be shared. The buffer
// for mul is freed before the end, since it's no longer used after dot0
@@ -481,10 +488,10 @@ TEST_F(HeapSimulatorTest, MultiplyDotDotTuple) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto dot0 = builder.AddInstruction(
- HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums));
- auto dot1 = builder.AddInstruction(
- HloInstruction::CreateDot(f32vec4_, dot0, paramY, dot_dnums));
+ auto dot0 = builder.AddInstruction(HloInstruction::CreateDot(
+ f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
+ auto dot1 = builder.AddInstruction(HloInstruction::CreateDot(
+ f32vec4_, dot0, paramY, dot_dnums, DefaultPrecisionConfig(2)));
auto tuple =
builder.AddInstruction(HloInstruction::CreateTuple({dot0, dot1}));
diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc
index f7ed1b0316..a2c1ce34c6 100644
--- a/tensorflow/compiler/xla/service/hlo_computation_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc
@@ -601,8 +601,11 @@ TEST_F(HloComputationTest, Stringification) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ 2, PrecisionConfigProto::DEFAULT);
builder.AddInstruction(
- HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+ HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config));
auto module = CreateNewModule();
auto* computation = module->AddEntryComputation(builder.Build());
@@ -633,8 +636,11 @@ TEST_F(HloComputationTest, StringificationIndent) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ 2, PrecisionConfigProto::DEFAULT);
builder.AddInstruction(
- HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+ HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config));
auto module = CreateNewModule();
auto* computation = module->AddEntryComputation(builder.Build());
@@ -666,8 +672,11 @@ TEST_F(HloComputationTest, StringificationCanonical) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ 2, PrecisionConfigProto::DEFAULT);
builder.AddInstruction(
- HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+ HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config));
auto module = CreateNewModule();
auto* computation = module->AddEntryComputation(builder.Build());
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
index 19ffb465c0..a6ae0337a5 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
@@ -61,15 +61,18 @@ StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
}
StatusOr<HloInstruction*> MakeConvolveHlo(
- HloInstruction* lhs, HloInstruction* rhs, const Window& window,
- const ConvolutionDimensionNumbers& dimension_numbers) {
+ HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count,
+ const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
+ const PrecisionConfigProto& precision_config) {
HloComputation* computation = lhs->parent();
CHECK_EQ(computation, rhs->parent());
- TF_ASSIGN_OR_RETURN(Shape convolve_shape, ShapeInference::InferConvolveShape(
- lhs->shape(), rhs->shape(),
- window, dimension_numbers));
+ TF_ASSIGN_OR_RETURN(Shape convolve_shape,
+ ShapeInference::InferConvolveShape(
+ lhs->shape(), rhs->shape(), feature_group_count,
+ window, dimension_numbers));
return computation->AddInstruction(HloInstruction::CreateConvolve(
- convolve_shape, lhs, rhs, window, dimension_numbers));
+ convolve_shape, lhs, rhs, feature_group_count, window, dimension_numbers,
+ precision_config));
}
StatusOr<HloInstruction*> MakeTransposeHlo(HloInstruction* operand,
@@ -164,15 +167,17 @@ StatusOr<HloInstruction*> MakeConcatHlo(
HloInstruction::CreateConcatenate(concat_shape, operands, dimension));
}
-StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
- const DotDimensionNumbers& dim_numbers) {
+StatusOr<HloInstruction*> MakeDotHlo(
+ HloInstruction* lhs, HloInstruction* rhs,
+ const DotDimensionNumbers& dim_numbers,
+ const PrecisionConfigProto& precision_config) {
HloComputation* computation = lhs->parent();
CHECK_EQ(computation, rhs->parent());
TF_ASSIGN_OR_RETURN(
Shape dot_shape,
ShapeInference::InferDotOpShape(lhs->shape(), rhs->shape(), dim_numbers));
- return computation->AddInstruction(
- HloInstruction::CreateDot(dot_shape, lhs, rhs, dim_numbers));
+ return computation->AddInstruction(HloInstruction::CreateDot(
+ dot_shape, lhs, rhs, dim_numbers, precision_config));
}
StatusOr<HloInstruction*> MakeMapHlo(absl::Span<HloInstruction* const> operands,
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h
index a1c4b374d1..1c82956907 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.h
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h
@@ -48,8 +48,9 @@ StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
// Creates a convolution HLO instruction and adds it to the computation
// containing `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
StatusOr<HloInstruction*> MakeConvolveHlo(
- HloInstruction* lhs, HloInstruction* rhs, const Window& window,
- const ConvolutionDimensionNumbers& dimension_numbers);
+ HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count,
+ const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
+ const PrecisionConfigProto& precision_config);
// Creates a transpose HLO instruction and adds it to the computation containing
// `operand`.
@@ -97,8 +98,10 @@ StatusOr<HloInstruction*> MakeConcatHlo(
// Creates a Dot HLO instruction and adds it to the computation containing `lhs`
// and `rhs` (both must be in the same computation).
-StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
- const DotDimensionNumbers& dim_numbers);
+StatusOr<HloInstruction*> MakeDotHlo(
+ HloInstruction* lhs, HloInstruction* rhs,
+ const DotDimensionNumbers& dim_numbers,
+ const PrecisionConfigProto& precision_config);
// Creates a Map HLO instruction and adds it to the computation containing the
// operands. All operands must be in the same computation.
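
The creation-utility wrappers change in lockstep with the constructors, so passes migrate from create-then-mutate to a single call. A sketch under the new signatures above, assuming lhs and rhs live in the same computation, window/dimension numbers valid for them, and the DefaultPrecisionConfig() helper from the earlier sketch:

    // MakeConvolveHlo now infers the result shape using the feature group
    // count and forwards the precision config to CreateConvolve.
    TF_ASSIGN_OR_RETURN(
        HloInstruction * conv,
        MakeConvolveHlo(lhs, rhs, /*feature_group_count=*/1, window,
                        dimension_numbers, DefaultPrecisionConfig(2)));

    // Likewise for dots; compare batch_dot_simplification.cc above, which
    // forwards batch_dot->precision_config() instead of a default config.
    TF_ASSIGN_OR_RETURN(HloInstruction * dot,
                        MakeDotHlo(lhs, rhs, dim_numbers,
                                   DefaultPrecisionConfig(2)));
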
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index d1a96c10f8..62eea2b06c 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -2334,8 +2334,11 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ 2, PrecisionConfigProto::DEFAULT);
auto dot = builder.AddInstruction(
- HloInstruction::CreateDot(data_shape, a, b, dot_dnums));
+ HloInstruction::CreateDot(data_shape, a, b, dot_dnums, precision_config));
auto one = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc
index 8b2846e0c2..113fd18eae 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_map.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc
@@ -51,6 +51,10 @@ int64 HloDomainMap::GetDomainId(HloInstruction* instruction) const {
return FindOrDefault(instruction_to_domain_, instruction, -1);
}
+int64 HloDomainMap::GetDomainMetadataId(HloInstruction* instruction) const {
+ return FindOrDie(domain_metadata_id_, instruction);
+}
+
Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) {
TF_RET_CHECK(instruction->opcode() == HloOpcode::kDomain);
// We only check operands, so we are sure to not process the empty domain from
@@ -93,6 +97,43 @@ Status HloDomainMap::Populate(HloComputation* computation) {
CreateDomain(instruction, instructions_post_order));
TF_RETURN_IF_ERROR(InsertDomain(std::move(domain)));
}
+ TF_RETURN_IF_ERROR(PopulateDomainMetadataMap());
+ return Status::OK();
+}
+
+Status HloDomainMap::PopulateDomainMetadataMap() {
+ auto hash = [](const DomainMetadata* m) { return m->Hash(); };
+ auto equal = [](const DomainMetadata* a, const DomainMetadata* b) {
+ return a->Matches(*b);
+ };
+ tensorflow::gtl::FlatMap<const DomainMetadata*, int64, decltype(hash),
+ decltype(equal)>
+ domain_metadata(1024, hash, equal);
+
+ for (auto& domain : instruction_domains_) {
+ int64 domain_metadata_id = -1;
+ if (!domain->enter_domains.empty()) {
+ const HloInstruction* domain_instruction = *domain->enter_domains.begin();
+ domain_metadata_id =
+ domain_metadata
+ .insert({&domain_instruction->user_side_metadata(),
+ domain_metadata.size() + 1})
+ .first->second;
+ } else if (!domain->exit_domains.empty()) {
+ const HloInstruction* domain_instruction = *domain->exit_domains.begin();
+ domain_metadata_id =
+ domain_metadata
+ .insert({&domain_instruction->operand_side_metadata(),
+ domain_metadata.size() + 1})
+ .first->second;
+ } else {
+ domain_metadata_id = 0;
+ }
+ TF_RET_CHECK(domain_metadata_id >= 0);
+ for (HloInstruction* instruction : domain->instructions) {
+ domain_metadata_id_[instruction] = domain_metadata_id;
+ }
+ }
return Status::OK();
}
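
PopulateDomainMetadataMap assigns one id per equivalence class of metadata rather than per object: the FlatMap is keyed on DomainMetadata pointers but hashes and compares through the new virtual Hash()/Matches() interface, and id 0 is reserved for domains with neither enter nor exit instructions. A self-contained sketch of that idiom with std::unordered_map and a hypothetical Metadata type standing in for DomainMetadata:

    #include <cstdint>
    #include <string>
    #include <unordered_map>

    // Hypothetical stand-in for DomainMetadata: two objects with equal
    // `kind` should receive the same id.
    struct Metadata {
      std::string kind;
      size_t Hash() const { return std::hash<std::string>()(kind); }
      bool Matches(const Metadata& other) const { return kind == other.kind; }
    };

    int main() {
      auto hash = [](const Metadata* m) { return m->Hash(); };
      auto equal = [](const Metadata* a, const Metadata* b) {
        return a->Matches(*b);
      };
      std::unordered_map<const Metadata*, int64_t, decltype(hash),
                         decltype(equal)>
          ids(/*bucket_count=*/16, hash, equal);

      Metadata a{"sharding"}, b{"sharding"}, c{"opname"};
      ids.insert({&a, ids.size() + 1});  // new class -> id 1
      ids.insert({&b, ids.size() + 1});  // Matches(a): insert is a no-op
      ids.insert({&c, ids.size() + 1});  // new class -> id 2
      // ids[&a] == ids[&b] == 1 and ids[&c] == 2.
      return 0;
    }
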
diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.h b/tensorflow/compiler/xla/service/hlo_domain_map.h
index 633109249a..56b557d7ce 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_map.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_map.h
@@ -69,6 +69,11 @@ class HloDomainMap {
// instruction is not found within any domain.
int64 GetDomainId(HloInstruction* instruction) const;
+ // Returns the unique id of the domain metadata for the domain the given
+ // instruction belongs to. The given instruction must not be a kDomain
+ // instruction since each domain instruction is associated with 2 domains.
+ int64 GetDomainMetadataId(HloInstruction* instruction) const;
+
private:
// Map used for representing instruction ordering, i.e.
// order_map[a] < order_map[b] means a must be ordered before b.
@@ -109,9 +114,14 @@ class HloDomainMap {
const tensorflow::gtl::FlatSet<HloInstruction*>& instruction_set,
const InstructionOrderMap& instructions_order);
+ // Populates domain_metadata_id_ that maps each HloInstruction to the unique
+ // ID of its associated domain metadata.
+ Status PopulateDomainMetadataMap();
+
string domain_kind_;
std::vector<std::unique_ptr<DomainMetadata::Domain>> instruction_domains_;
tensorflow::gtl::FlatMap<HloInstruction*, int64> instruction_to_domain_;
+ tensorflow::gtl::FlatMap<HloInstruction*, int64> domain_metadata_id_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_domain_metadata.h b/tensorflow/compiler/xla/service/hlo_domain_metadata.h
index 6c142ee474..302807f816 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_metadata.h
@@ -72,6 +72,9 @@ class DomainMetadata {
// two matches.
virtual bool Matches(const DomainMetadata& other) const = 0;
+ // Returns the hash value of the metadata.
+ virtual size_t Hash() const = 0;
+
// Returns a string representation of the metadata.
virtual string ToString() const = 0;
};
diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc
index 974ab94467..43e74d2f6f 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc
@@ -99,6 +99,8 @@ class OpNameMetadata : public DomainMetadata {
static absl::string_view KindName() { return "opname"; }
+ size_t Hash() const override { return std::hash<string>()(opname_); }
+
private:
string opname_;
};
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 441dcad000..ffb3451164 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -53,7 +53,6 @@ namespace xla {
namespace {
-
template <typename OperandT>
StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
LiteralSlice lhs_literal,
@@ -345,7 +344,8 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseUnaryOp(
}
StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateDotOp(
- const DotDimensionNumbers& dim_numbers, const Literal& lhs,
+ const DotDimensionNumbers& dim_numbers,
+ const PrecisionConfigProto& precision_config, const Literal& lhs,
const Literal& rhs) {
std::unique_ptr<HloInstruction> lhs_instr =
HloInstruction::CreateConstant(lhs.CloneToUnique());
@@ -358,7 +358,7 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateDotOp(
std::unique_ptr<HloInstruction> cloned_instruction =
HloInstruction::CreateDot(dot_shape, lhs_instr.get(), rhs_instr.get(),
- dim_numbers);
+ dim_numbers, precision_config);
return Evaluate(cloned_instruction.get());
}
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index c2d49e56ac..e13af8e999 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -115,7 +115,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
HloOpcode opcode, const Literal& operand);
StatusOr<std::unique_ptr<Literal>> EvaluateDotOp(
- const DotDimensionNumbers& dim_numbers, const Literal& lhs,
+ const DotDimensionNumbers& dim_numbers,
+ const PrecisionConfigProto& precision_config, const Literal& lhs,
const Literal& rhs);
protected:
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index 7e490d7f32..f586f253da 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -622,6 +622,13 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) {
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
+PrecisionConfigProto DefaultPrecisionConfig(int operands) {
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ operands, PrecisionConfigProto::DEFAULT);
+ return precision_config;
+}
+
TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
HloComputation::Builder b(TestName());
@@ -649,7 +656,8 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction,
- rhs_instruction, dot_dnums));
+ rhs_instruction, dot_dnums,
+ DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result = Evaluate();
@@ -694,7 +702,8 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) {
dot_dnums.add_lhs_contracting_dimensions(0);
dot_dnums.add_rhs_contracting_dimensions(0);
b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction,
- rhs_instruction, dot_dnums));
+ rhs_instruction, dot_dnums,
+ DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result = Evaluate();
@@ -737,7 +746,8 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction,
- rhs_instruction, dot_dnums));
+ rhs_instruction, dot_dnums,
+ DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result = Evaluate();
@@ -788,9 +798,10 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) {
dnums.set_kernel_input_feature_dimension(1);
dnums.add_kernel_spatial_dimensions(2);
- const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 3});
+ Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 3});
b.AddInstruction(HloInstruction::CreateConvolve(
- shape, lhs_instruction, rhs_instruction, window, dnums));
+ shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+ window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result = Evaluate();
@@ -842,9 +853,10 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
ConvolutionDimensionNumbers dnums =
XlaBuilder::CreateDefaultConvDimensionNumbers(2);
- const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
+ Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
b.AddInstruction(HloInstruction::CreateConvolve(
- shape, lhs_instruction, rhs_instruction, window, dnums));
+ shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+ window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result = Evaluate();
@@ -925,9 +937,10 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) {
dnums.add_kernel_spatial_dimensions(3);
dnums.add_kernel_spatial_dimensions(1);
- const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
+ Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
b.AddInstruction(HloInstruction::CreateConvolve(
- shape, lhs_instruction, rhs_instruction, window, dnums));
+ shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+ window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result = Evaluate();
@@ -1002,9 +1015,10 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) {
dnums.add_kernel_spatial_dimensions(3);
dnums.add_kernel_spatial_dimensions(1);
- const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
+ Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
b.AddInstruction(HloInstruction::CreateConvolve(
- shape, lhs_instruction, rhs_instruction, window, dnums));
+ shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+ window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result = Evaluate();
@@ -1061,9 +1075,10 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
ConvolutionDimensionNumbers dnums =
XlaBuilder::CreateDefaultConvDimensionNumbers(2);
- const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7});
+ Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7});
b.AddInstruction(HloInstruction::CreateConvolve(
- shape, lhs_instruction, rhs_instruction, window, dnums));
+ shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+ window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result = Evaluate();
@@ -1124,9 +1139,10 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
ConvolutionDimensionNumbers dnums =
XlaBuilder::CreateDefaultConvDimensionNumbers(2);
- const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8});
+ Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8});
b.AddInstruction(HloInstruction::CreateConvolve(
- shape, lhs_instruction, rhs_instruction, window, dnums));
+ shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+ window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result = Evaluate();
@@ -1195,9 +1211,10 @@ TEST_P(HloEvaluatorTest,
ConvolutionDimensionNumbers dnums =
XlaBuilder::CreateDefaultConvDimensionNumbers(2);
- const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3});
+ Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3});
b.AddInstruction(HloInstruction::CreateConvolve(
- shape, lhs_instruction, rhs_instruction, window, dnums));
+ shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+ window, dnums, DefaultPrecisionConfig(2)));
module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result = Evaluate();
@@ -1219,6 +1236,67 @@ TEST_P(HloEvaluatorTest,
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
+TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) {
+ HloComputation::Builder b(TestName());
+ std::vector<int64> input_dims = {1, 2, 2, 4};
+ std::vector<int64> filter_dims = {2, 2, 2, 8};
+ Shape input_shape = ShapeUtil::MakeShapeWithType<float>(input_dims);
+ Shape filter_shape = ShapeUtil::MakeShapeWithType<float>(filter_dims);
+ // Tensorflow dimension numbers for 2D convolution.
+ ConvolutionDimensionNumbers dnums;
+ dnums.set_input_batch_dimension(0);
+ dnums.set_output_batch_dimension(0);
+ dnums.add_input_spatial_dimensions(1);
+ dnums.add_output_spatial_dimensions(1);
+ dnums.add_input_spatial_dimensions(2);
+ dnums.add_output_spatial_dimensions(2);
+ dnums.set_input_feature_dimension(3);
+ dnums.set_output_feature_dimension(3);
+ dnums.add_kernel_spatial_dimensions(0);
+ dnums.add_kernel_spatial_dimensions(1);
+ dnums.set_kernel_input_feature_dimension(2);
+ dnums.set_kernel_output_feature_dimension(3);
+
+ Window window;
+ WindowDimension dim;
+ dim.set_size(2);
+ dim.set_stride(1);
+ dim.set_padding_low(0);
+ dim.set_padding_high(0);
+ dim.set_window_dilation(1);
+ dim.set_base_dilation(1);
+ *window.add_dimensions() = dim;
+ *window.add_dimensions() = dim;
+
+ std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape));
+ std::iota(input_elems.begin(), input_elems.end(), -7);
+ auto input_r1 = LiteralUtil::CreateR1<float>(input_elems);
+ auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+ HloInstruction* lhs_instruction =
+ b.AddInstruction(HloInstruction::CreateConstant(std::move(input_r4)));
+
+ std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape));
+ std::iota(filter_elems.begin(), filter_elems.end(), -31);
+ auto filter_r1 = LiteralUtil::CreateR1<float>(filter_elems);
+ auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+ HloInstruction* rhs_instruction =
+ b.AddInstruction(HloInstruction::CreateConstant(std::move(filter_r4)));
+
+ Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 8});
+ b.AddInstruction(HloInstruction::CreateConvolve(
+ shape, lhs_instruction, rhs_instruction,
+ /*feature_group_count=*/2, window, dnums, DefaultPrecisionConfig(2)));
+ module().AddEntryComputation(b.Build());
+
+ std::unique_ptr<Literal> result = Evaluate();
+
+ Array4D<float> expected_array(1, 1, 1, 8);
+ expected_array.FillWithYX(
+ Array2D<float>({{668, 664, 660, 656, 668, 680, 692, 704}}));
+ auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+}
+
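For the record, the f32[1,1,1,8] result shape above follows directly from the window: size 2, stride 1, no padding or dilation over a spatial extent of 2 gives (2 - 2) / 1 + 1 = 1 output position per spatial dimension, and the kernel's 8 output features pass through unchanged. feature_group_count=2 then splits the 4 input features into two groups of 2, each convolved with its own half of the kernel to produce 4 of the 8 output features.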
class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {};
// Tests that Reduce doesn't lose precision when adding many numbers (because
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index cb27e13e99..6a09bb08f4 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -1021,9 +1021,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
CHECK_EQ(num_spatial_dims + 2, lhs_rank);
CHECK_EQ(num_spatial_dims + 2, rhs_rank);
- TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
- ShapeInference::InferConvolveShape(lhs_shape, rhs_shape,
- window, dnums));
+ TF_ASSIGN_OR_RETURN(
+ auto inferred_return_shape,
+ ShapeInference::InferConvolveShape(
+ lhs_shape, rhs_shape, conv->feature_group_count(), window, dnums));
CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
<< "return shape set to: " << ShapeUtil::HumanString(result_shape)
<< " but is inferred to be: "
@@ -1046,9 +1047,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
auto lhs_literal_data = lhs_literal.data<ReturnT>();
auto rhs_literal_data = rhs_literal.data<ReturnT>();
+ int64 feature_group_count = conv->feature_group_count();
+
auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window,
&lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data,
- rhs_literal_data](absl::Span<const int64> out_index) {
+ rhs_literal_data,
+ feature_group_count](absl::Span<const int64> out_index) {
// Dimension number applicable for input (lhs).
const int64 input_batch_dim = dnums.input_batch_dimension();
const int64 input_z_dim = dnums.input_feature_dimension();
@@ -1060,6 +1064,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const int64 output_z_dim = dnums.output_feature_dimension();
const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim);
+ const int64 output_z_size =
+ ShapeUtil::GetDimension(rhs_shape, kernel_output_z_dim);
ElementwiseT result_val = static_cast<ElementwiseT>(0);
DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(),
@@ -1068,6 +1074,33 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Convolve input feature with kernel.
do {
for (int64 iz = 0; iz < z_size; ++iz) {
+ int64 rhs_iz = iz;
+ // Handle grouped convolutions.
+ if (feature_group_count > 1) {
+ // The size of a feature group.
+ int64 feature_group_size = z_size / feature_group_count;
+ rhs_iz = iz % feature_group_size;
+
+ // The output feature dimension is a concatenation of convolution
+ // results from the different groups.
+ int64 output_feature_group_size =
+ output_z_size / feature_group_count;
+
+ // Calculate the group index to which the current input feature
+ // index belongs.
+ int64 input_group_index = iz / feature_group_size;
+
+ // Calculate the group index to which the current output index
+ // belongs.
+ int64 output_group_index =
+ out_index[output_z_dim] / output_feature_group_size;
+ if (input_group_index != output_group_index) {
+ // If the current output index does not belong to the current
+ // feature group, skip it.
+ continue;
+ }
+ }
+
int64 lhs_linear_index = 0;
lhs_linear_index += out_index[output_batch_dim] *
lhs_dim_multipliers[input_batch_dim];
@@ -1076,7 +1109,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
int64 rhs_linear_index = 0;
rhs_linear_index += out_index[output_z_dim] *
rhs_dim_multipliers[kernel_output_z_dim];
- rhs_linear_index += iz * rhs_dim_multipliers[kernel_input_z_dim];
+ rhs_linear_index += rhs_iz * rhs_dim_multipliers[kernel_input_z_dim];
// Find corresponding spatial dimension index for input (lhs).
for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) {
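To make the grouped-convolution bookkeeping above concrete, here is a minimal standalone sketch (plain C++, using int64_t rather than XLA's int64, with the numbers from the Conv2DGroupedConvolution test: z_size = 4, output_z_size = 8, feature_group_count = 2). It enumerates which (input feature, output feature) pairs contribute and which kernel input index each pair reads:

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t z_size = 4;               // input feature count
  const int64_t output_z_size = 8;        // kernel output feature count
  const int64_t feature_group_count = 2;
  const int64_t feature_group_size = z_size / feature_group_count;                // 2
  const int64_t output_feature_group_size = output_z_size / feature_group_count;  // 4
  for (int64_t oz = 0; oz < output_z_size; ++oz) {
    for (int64_t iz = 0; iz < z_size; ++iz) {
      if (iz / feature_group_size != oz / output_feature_group_size) {
        continue;  // input and output feature fall in different groups: skip
      }
      const int64_t rhs_iz = iz % feature_group_size;  // index within the kernel's input features
      std::printf("output feature %lld reads input feature %lld via kernel index %lld\n",
                  static_cast<long long>(oz), static_cast<long long>(iz),
                  static_cast<long long>(rhs_iz));
    }
  }
  return 0;
}

Output features 0..3 read only input features 0..1, and output features 4..7 read only input features 2..3, which is exactly the skip condition implemented in the visitor.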
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 3041d94fa9..0345a2a5f8 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -120,12 +120,19 @@ class NodeFilter {
std::function<NodeFilterResult(const HloInstruction* instr)> filter_;
};
+// We arbitrarily set this as the boundary between "large" and "small"
+// instructions.
+bool IsSmall(const HloInstruction* instr) {
+ return ShapeUtil::ElementsInRecursive(instr->shape()) < 4096;
+}
+
// Node color schemes, used by NodeColorAttributes.
enum ColorScheme {
kBlue,
kBrown,
kDarkBlue,
kDarkGreen,
+ kDarkOrange,
kDarkRed,
kGray,
kGreen,
@@ -158,6 +165,10 @@ NodeColors NodeColorsForScheme(ColorScheme color) {
return NodeColors{"filled", "#1565c0", "#003c8f", "white"};
case kDarkGreen:
return NodeColors{"filled", "#2e7d32", "#005005", "white"};
+ case kDarkOrange:
+ // A medium orange, chosen to read as a slightly darker weight of kOrange;
+ // an even darker shade could be added if more contrast is needed.
+ return NodeColors{"filled", "#ffb74d", "#c88719", "black"};
case kDarkRed:
return NodeColors{"filled", "#b71c1c", "#7f0000", "white"};
case kGray:
@@ -893,7 +904,10 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
sharding_colors_.emplace(instr->sharding(), color);
return color;
}
- const auto kParameterColor = kOrange;
+
+ // Choose different weights of orange for small vs large parameters. This
+ // distinction is often important, especially in fusion nodes.
+ auto parameter_color = IsSmall(instr) ? kOrange : kDarkOrange;
// Special case: If this instruction has a parameter merged into it, paint it
// the same color as a parameter. Unless the merged-in parameter is a
@@ -905,7 +919,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
ShouldMergeIntoUsers(operand) &&
TryGetFusionParameterConstant(operand) == nullptr;
})) {
- return kParameterColor;
+ return parameter_color;
}
// Pick different colors or shapes for instructions which are particularly
@@ -1015,7 +1029,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kReducePrecision:
return kRed;
case HloOpcode::kParameter:
- return kParameterColor;
+ return parameter_color;
case HloOpcode::kBatchNormGrad:
case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormTraining:
@@ -1160,20 +1174,6 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
return StrJoin(lines, "<br/>");
}
-// Gets the total number of array elements in the given shape. For tuples, this
-// is the sum of all the sizes of all of the array elements recursively in the
-// tuple.
-static int64 TotalElementsInShape(const Shape& shape) {
- int64 elems = 0;
- ShapeUtil::ForEachSubshape(
- shape, [&](const Shape& subshape, const ShapeIndex& /*index*/) {
- if (ShapeUtil::IsArray(subshape)) {
- elems += ShapeUtil::ElementsIn(subshape);
- }
- });
- return elems;
-}
-
void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) {
auto add_edge = [&](const HloInstruction* from, const HloInstruction* to,
int64 operand_num, bool control_edge = false) {
@@ -1196,14 +1196,11 @@ void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) {
}
// We print "small" arrays using a hollow arrowhead and "large" arrays using
- // a filled arrowhead. For now, we use an arbitrary cutoff for what "big"
- // means.
- bool is_big_array = TotalElementsInShape(from->shape()) >= 4096;
-
+ // a filled arrowhead.
constexpr char kEdgeFmt[] =
R"(%s -> %s [arrowhead=%s tooltip="%s -> %s" %s];)";
edges_.push_back(StrFormat(kEdgeFmt, InstructionId(from), InstructionId(to),
- (is_big_array ? "normal" : "empty"),
+ (IsSmall(from) ? "empty" : "normal"),
from->name(), to->name(), edge_label));
};
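For reference, a self-contained analog of the element counting that IsSmall() now delegates to; this assumes ShapeUtil::ElementsInRecursive computes the same recursive sum as the TotalElementsInShape helper deleted above (arrays contribute their element count, tuples sum over their sub-shapes), so the 4096 cutoff is unchanged and merely shared between node coloring and edge arrowheads. ToyShape is a hypothetical stand-in, not an XLA type:

#include <cstdint>
#include <vector>

// Hypothetical stand-in for an XLA shape: a leaf array with dimensions, or a
// tuple of sub-shapes.
struct ToyShape {
  std::vector<int64_t> dims;        // used when subshapes is empty (array)
  std::vector<ToyShape> subshapes;  // used otherwise (tuple)
};

int64_t ElementsInRecursive(const ToyShape& shape) {
  if (shape.subshapes.empty()) {
    int64_t elems = 1;
    for (int64_t d : shape.dims) elems *= d;
    return elems;
  }
  int64_t total = 0;
  for (const ToyShape& sub : shape.subshapes) total += ElementsInRecursive(sub);
  return total;
}

bool IsSmall(const ToyShape& shape) {
  return ElementsInRecursive(shape) < 4096;
}

int main() {
  ToyShape scalar{{}, {}};               // 1 element
  ToyShape matrix{{64, 64}, {}};         // 4096 elements: already "large"
  ToyShape tuple{{}, {scalar, matrix}};  // 4097 elements in total
  return IsSmall(scalar) && !IsSmall(matrix) && !IsSmall(tuple) ? 0 : 1;
}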
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 6d13f85cbb..f25761ac70 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -341,17 +341,21 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
source_target_pairs);
break;
}
- case HloOpcode::kConvolution:
+ case HloOpcode::kConvolution: {
TF_RET_CHECK(proto.operand_ids_size() == 2)
<< "Convolution instruction should have 2 operands but sees "
<< proto.operand_ids_size();
TF_RET_CHECK(proto.has_window());
TF_RET_CHECK(proto.has_convolution_dimension_numbers());
+ PrecisionConfigProto precision_config = proto.precision_config();
+ precision_config.mutable_operand_precision()->Resize(
+ proto.operand_ids_size(), PrecisionConfigProto::DEFAULT);
instruction = CreateConvolve(
- proto.shape(), operands(0), operands(1), proto.window(),
- proto.convolution_dimension_numbers(),
- std::max(static_cast<int64>(proto.feature_group_count()), 1LL));
+ proto.shape(), operands(0), operands(1),
+ std::max<int64>(proto.feature_group_count(), 1), proto.window(),
+ proto.convolution_dimension_numbers(), precision_config);
break;
+ }
case HloOpcode::kReduceWindow:
TF_RET_CHECK(proto.operand_ids_size() == 2)
<< "ReduceWindow instruction should have 2 operands but sees "
@@ -468,6 +472,20 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
computation_map.at(computation_id));
}
}
+ if (instruction->opcode() == HloOpcode::kDot) {
+ instruction->precision_config_ = proto.precision_config();
+ instruction->precision_config_.mutable_operand_precision()->Resize(
+ instruction->operand_count(), PrecisionConfigProto::DEFAULT);
+ TF_RET_CHECK(proto.has_dot_dimension_numbers());
+ instruction->dot_dimension_numbers_ =
+ absl::make_unique<DotDimensionNumbers>(
+ proto.dot_dimension_numbers());
+ } else {
+ TF_RET_CHECK(!proto.has_precision_config())
+ << instruction->opcode() << " " << proto.DebugString();
+ TF_RET_CHECK(!proto.has_dot_dimension_numbers())
+ << instruction->opcode();
+ }
break;
}
}
@@ -476,12 +494,6 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
instruction->SetAndSanitizeName(proto.name());
instruction->metadata_ = proto.metadata();
instruction->backend_config_ = proto.backend_config();
- instruction->precision_config_ = proto.precision_config();
-
- if (proto.has_dot_dimension_numbers()) {
- instruction->dot_dimension_numbers_ =
- absl::make_unique<DotDimensionNumbers>(proto.dot_dimension_numbers());
- }
if (proto.has_sharding()) {
TF_ASSIGN_OR_RETURN(const auto& sharding,
@@ -643,10 +655,12 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvolve(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
- const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count) {
+ int64 feature_group_count, const Window& window,
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ const PrecisionConfigProto& precision_config) {
return absl::make_unique<HloConvolutionInstruction>(
- shape, lhs, rhs, window, dimension_numbers, feature_group_count);
+ shape, lhs, rhs, feature_group_count, window, dimension_numbers,
+ precision_config);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFft(
@@ -658,13 +672,15 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDot(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
- const DotDimensionNumbers& dimension_numbers) {
+ const DotDimensionNumbers& dimension_numbers,
+ const PrecisionConfigProto& precision_config) {
auto instruction =
absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape));
instruction->AppendOperand(lhs);
instruction->AppendOperand(rhs);
instruction->dot_dimension_numbers_ =
absl::make_unique<DotDimensionNumbers>(dimension_numbers);
+ instruction->set_precision_config(precision_config);
return instruction;
}
@@ -1057,7 +1073,6 @@ void HloInstruction::SetupDerivedInstruction(
derived_instruction->clear_sharding();
}
derived_instruction->set_metadata(metadata_);
- derived_instruction->set_precision_config(precision_config_);
}
bool HloInstruction::HasSideEffectNoRecurse() const {
@@ -1278,7 +1293,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kDot:
CHECK_EQ(new_operands.size(), 2);
clone = CreateDot(shape, new_operands[0], new_operands[1],
- *dot_dimension_numbers_);
+ *dot_dimension_numbers_, precision_config());
break;
case HloOpcode::kReshape:
CHECK_EQ(new_operands.size(), 1);
@@ -2167,7 +2182,9 @@ HloInstructionProto HloInstruction::ToProto() const {
*proto.mutable_metadata() = metadata_;
proto.set_backend_config(backend_config_);
- *proto.mutable_precision_config() = precision_config_;
+ if (opcode() == HloOpcode::kConvolution || opcode() == HloOpcode::kDot) {
+ *proto.mutable_precision_config() = precision_config_;
+ }
if (opcode() != HloOpcode::kFusion) {
for (const HloComputation* computation : called_computations_) {
proto.add_called_computation_ids(computation->unique_id());
@@ -2948,7 +2965,11 @@ StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
}
string HloInstruction::PrecisionConfigToString() const {
- if (precision_config_.operand_precision().empty()) {
+ if (absl::c_all_of(
+ precision_config_.operand_precision(), [](int32 precision) {
+ return static_cast<PrecisionConfigProto::Precision>(precision) ==
+ PrecisionConfigProto::DEFAULT;
+ })) {
return "";
}
return StrCat(
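Taken together, these hunks make precision_config the exclusive property of kConvolution and kDot: ToProto() emits it only for those opcodes, CreateFromProto() rebuilds it there and pads operand_precision with DEFAULT, and PrecisionConfigToString() suppresses a list that is all DEFAULT. A minimal sketch of the suppression rule, with a plain enum standing in for PrecisionConfigProto:

#include <algorithm>
#include <string>
#include <vector>

enum Precision { DEFAULT = 0, HIGH = 1, HIGHEST = 2 };

std::string PrecisionToString(Precision p) {
  switch (p) {
    case DEFAULT: return "default";
    case HIGH: return "high";
    case HIGHEST: return "highest";
  }
  return "?";
}

// An all-DEFAULT list prints as the empty string, so round-tripped HLO text
// carries operand_precision only when it says something non-trivial.
std::string PrecisionConfigToString(const std::vector<Precision>& operand_precision) {
  if (std::all_of(operand_precision.begin(), operand_precision.end(),
                  [](Precision p) { return p == DEFAULT; })) {
    return "";
  }
  std::string out = "operand_precision={";
  for (size_t i = 0; i < operand_precision.size(); ++i) {
    if (i > 0) out += ",";
    out += PrecisionToString(operand_precision[i]);
  }
  return out + "}";
}

int main() {
  bool ok = PrecisionConfigToString({DEFAULT, DEFAULT}).empty() &&
            PrecisionConfigToString({HIGH, DEFAULT}) ==
                "operand_precision={high,default}";
  return ok ? 0 : 1;
}

This is also why the round-trip tests in hlo_parser_test.cc below keep operand_precision={high,default} but no longer expect the trivial feature_group_count=1, which a parallel change in hlo_instructions.cc stops printing.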
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index cca134e8b4..55d592ff94 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -405,9 +405,9 @@ class HloInstruction {
// and window describes how the filter is applied to lhs.
static std::unique_ptr<HloInstruction> CreateConvolve(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
- const Window& window,
+ int64 feature_group_count, const Window& window,
const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count = 1);
+ const PrecisionConfigProto& precision_config);
// Creates an FFT op, of the type indicated by fft_type.
static std::unique_ptr<HloInstruction> CreateFft(
@@ -418,7 +418,8 @@ class HloInstruction {
// dimensions specified in 'dimension_numbers'.
static std::unique_ptr<HloInstruction> CreateDot(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
- const DotDimensionNumbers& dimension_numbers);
+ const DotDimensionNumbers& dimension_numbers,
+ const PrecisionConfigProto& precision_config);
// Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1
// of the LHS with dimension 0 of the RHS with no batch dimensions. Both LHS
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index 76b0e940a6..b4e302e832 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -1122,6 +1122,13 @@ TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) {
}
}
+PrecisionConfigProto DefaultPrecisionConfig(int operands) {
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ operands, PrecisionConfigProto::DEFAULT);
+ return precision_config;
+}
+
TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) {
// Fused expression:
//
@@ -1147,8 +1154,8 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+ HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
+ sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
auto* computation = module->AddEntryComputation(builder.Build());
@@ -1188,8 +1195,8 @@ TEST_F(HloInstructionTest, NoRedundantFusionOperandsAfterReplacingUse) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateDot(s, x, reshape, dot_dnums));
+ HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
+ s, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
auto* computation = module->AddEntryComputation(builder.Build());
@@ -1239,8 +1246,8 @@ TEST_F(HloInstructionTest, NestedFusionEquality) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto dot = builder.AddInstruction(
- HloInstruction::CreateDot(data_shape, a, b_t, dot_dnums));
+ auto dot = builder.AddInstruction(HloInstruction::CreateDot(
+ data_shape, a, b_t, dot_dnums, DefaultPrecisionConfig(2)));
auto one = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto add_operand = builder.AddInstruction(
@@ -1320,8 +1327,8 @@ TEST_F(HloInstructionTest, Stringification) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+ HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
+ sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));
auto options = HloPrintOptions().set_print_metadata(false);
@@ -1485,8 +1492,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationFusion) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+ HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
+ sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));
auto options = HloPrintOptions().Canonical();
@@ -1527,8 +1534,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationWhile) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+ HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
+ sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
auto* computation = module->AddEntryComputation(builder.Build());
@@ -1583,8 +1590,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationConditional) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+ HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
+ sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));
auto module = CreateNewModule();
auto* computation = module->AddEntryComputation(builder.Build());
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index e46afa764f..e3683aaec9 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -1628,12 +1628,13 @@ std::unique_ptr<HloInstruction> HloOutfeedInstruction::CloneWithNewOperandsImpl(
HloConvolutionInstruction::HloConvolutionInstruction(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
- const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count)
+ int64 feature_group_count, const Window& window,
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ const PrecisionConfigProto& precision_config)
: HloInstruction(HloOpcode::kConvolution, shape),
+ feature_group_count_(feature_group_count),
window_(window),
- convolution_dimension_numbers_(dimension_numbers),
- feature_group_count_(feature_group_count) {
+ convolution_dimension_numbers_(dimension_numbers) {
if (window_util::HasBaseDilation(window)) {
SetAndSanitizeName(StrCat(name(), "-base-dilated"));
}
@@ -1642,6 +1643,7 @@ HloConvolutionInstruction::HloConvolutionInstruction(
}
AppendOperand(lhs);
AppendOperand(rhs);
+ set_precision_config(precision_config);
}
string HloConvolutionInstruction::ToCategory() const {
@@ -1672,7 +1674,9 @@ std::vector<string> HloConvolutionInstruction::ExtraAttributesToStringImpl(
}
extra.push_back(StrCat("dim_labels=", ConvolutionDimensionNumbersToString(
convolution_dimension_numbers_)));
- extra.push_back(StrCat("feature_group_count=", feature_group_count_));
+ if (feature_group_count_ != 1) {
+ extra.push_back(StrCat("feature_group_count=", feature_group_count_));
+ }
return extra;
}
@@ -1697,8 +1701,8 @@ HloConvolutionInstruction::CloneWithNewOperandsImpl(
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
return absl::make_unique<HloConvolutionInstruction>(
- shape, new_operands[0], new_operands[1], window(),
- convolution_dimension_numbers_, feature_group_count_);
+ shape, new_operands[0], new_operands[1], feature_group_count_, window(),
+ convolution_dimension_numbers_, precision_config());
}
HloReduceWindowInstruction::HloReduceWindowInstruction(
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 3230383579..1c85aa4681 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -942,9 +942,9 @@ class HloConvolutionInstruction : public HloInstruction {
public:
explicit HloConvolutionInstruction(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
- const Window& window,
+ int64 feature_group_count, const Window& window,
const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count);
+ const PrecisionConfigProto& precision_config);
const Window& window() const override { return window_; }
void set_window(const Window& window) override { window_ = window; }
const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
@@ -972,12 +972,13 @@ class HloConvolutionInstruction : public HloInstruction {
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
- Window window_;
- // Describes the dimension numbers used for a convolution.
- ConvolutionDimensionNumbers convolution_dimension_numbers_;
// The number of feature groups. Must be a divisor of the input feature
// dimension and output feature dimension.
int64 feature_group_count_;
+ // Describes the window used for a convolution.
+ Window window_;
+ // Describes the dimension numbers used for a convolution.
+ ConvolutionDimensionNumbers convolution_dimension_numbers_;
};
class HloReduceWindowInstruction : public HloInstruction {
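A note on the field reorder just above: C++ initializes non-static data members in declaration order, not in the order they appear in the mem-initializer list, so the declarations are reordered to match the new (feature_group_count, window, dimension_numbers) constructor parameter order. In miniature:

// Declaration order governs initialization order; keeping it aligned with the
// initializer list avoids -Wreorder warnings and initialization surprises.
struct MiniConv {
  long feature_group_count_;
  int window_;
  MiniConv(long fgc, int w) : feature_group_count_(fgc), window_(w) {}
};

int main() {
  MiniConv conv(/*fgc=*/2, /*w=*/3);
  return conv.feature_group_count_ == 2 && conv.window_ == 3 ? 0 : 1;
}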
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index ea8e6a239a..62f01c4adb 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -530,10 +530,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
attrs["backend_config"] = {/*required=*/false, AttrTy::kString,
&backend_config};
- optional<std::vector<PrecisionConfigProto::Precision>> operand_precision;
- attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
- &operand_precision};
-
HloInstruction* instruction;
switch (opcode) {
case HloOpcode::kParameter: {
@@ -913,6 +909,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
AttrTy::kConvolutionDimensionNumbers, &dnums};
attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
&feature_group_count};
+ optional<std::vector<PrecisionConfigProto::Precision>> operand_precision;
+ attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
+ &operand_precision};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
return false;
@@ -923,9 +922,17 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
if (!feature_group_count) {
feature_group_count = 1;
}
+ PrecisionConfigProto precision_config;
+ if (operand_precision) {
+ *precision_config.mutable_operand_precision() = {
+ operand_precision->begin(), operand_precision->end()};
+ } else {
+ precision_config.mutable_operand_precision()->Resize(
+ operands.size(), PrecisionConfigProto::DEFAULT);
+ }
instruction = builder->AddInstruction(HloInstruction::CreateConvolve(
- shape, /*lhs=*/operands[0], /*rhs=*/operands[1], *window, *dnums,
- feature_group_count.value()));
+ shape, /*lhs=*/operands[0], /*rhs=*/operands[1],
+ feature_group_count.value(), *window, *dnums, precision_config));
break;
}
case HloOpcode::kFft: {
@@ -1272,6 +1279,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
optional<std::vector<tensorflow::int64>> rhs_batch_dims;
attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List,
&rhs_batch_dims};
+ optional<std::vector<PrecisionConfigProto::Precision>> operand_precision;
+ attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
+ &operand_precision};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
@@ -1296,8 +1306,17 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
rhs_batch_dims->end()};
}
- instruction = builder->AddInstruction(
- HloInstruction::CreateDot(shape, operands[0], operands[1], dnum));
+ PrecisionConfigProto precision_config;
+ if (operand_precision) {
+ *precision_config.mutable_operand_precision() = {
+ operand_precision->begin(), operand_precision->end()};
+ } else {
+ precision_config.mutable_operand_precision()->Resize(
+ operands.size(), PrecisionConfigProto::DEFAULT);
+ }
+
+ instruction = builder->AddInstruction(HloInstruction::CreateDot(
+ shape, operands[0], operands[1], dnum, precision_config));
break;
}
case HloOpcode::kGather: {
@@ -1414,12 +1433,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
if (backend_config) {
instruction->set_raw_backend_config_string(std::move(*backend_config));
}
- if (operand_precision) {
- PrecisionConfigProto precision_config;
- *precision_config.mutable_operand_precision() = {operand_precision->begin(),
- operand_precision->end()};
- instruction->set_precision_config(precision_config);
- }
return AddInstruction(name, instruction, name_loc);
} // NOLINT(readability/fn_size)
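With operand_precision moved out of the shared attribute table, only the convolution and dot branches accept it, and when the attribute is absent the config is now filled with one DEFAULT entry per operand instead of being left empty, so downstream code can index it unconditionally. A hypothetical dot in HLO text that exercises the new path (names and shapes are illustrative only):

  %dot = f32[2,2]{1,0} dot(f32[2,3]{1,0} %a, f32[3,2]{1,0} %b), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={high,default}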
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 759789437c..0dfc0a4d1c 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -19,6 +19,8 @@ limitations under the License.
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -382,7 +384,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2
%input = f32[1,2,1]{2,1,0} parameter(0)
%copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
%filter = f32[1,1,1]{2,1,0} parameter(1)
- ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, feature_group_count=1, operand_precision={high,default}
+ ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, operand_precision={high,default}
}
)"
@@ -395,7 +397,7 @@ R"(HloModule ConvolveR2_module
ENTRY %ConvolveR2.v3 (input: f32[1,2], filter: f32[1,1]) -> f32[1,2] {
%input = f32[1,2]{1,0} parameter(0)
%filter = f32[1,1]{1,0} parameter(1)
- ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), dim_labels=bf_io->bf, feature_group_count=1
+ ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), dim_labels=bf_io->bf
}
)"
@@ -408,7 +410,7 @@ R"(HloModule ConvolveBackward_module
ENTRY %ConvolveBackward (input: f32[128,7,7,512], filter: f32[3,3,512,512]) -> f32[128,14,14,512] {
%input = f32[128,7,7,512]{0,3,2,1} parameter(0)
%filter = f32[3,3,512,512]{3,2,1,0} parameter(1)
- ROOT %convolution-base-dilated = f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f, feature_group_count=1
+ ROOT %convolution-base-dilated = f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f
}
)"
@@ -1775,5 +1777,18 @@ TEST(HloParserSingleOpTest, SingleOpNoShapesProducesError) {
::testing::HasSubstr("Operand broadcast had no shape in HLO text"));
}
+TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) {
+ const string text =
+ R"(%convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f)";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text));
+ const HloComputation* computation = module->entry_computation();
+ ASSERT_NE(computation, nullptr);
+ EXPECT_THAT(computation->root_instruction(),
+ op::Convolution(op::Parameter(0), op::Parameter(1)));
+ auto* convolution =
+ Cast<HloConvolutionInstruction>(computation->root_instruction());
+ EXPECT_EQ(convolution->feature_group_count(), 1);
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
index 34cba6136f..e3f4a9852a 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
@@ -422,6 +422,13 @@ bool ShardingMetadata::Matches(const DomainMetadata& other) const {
: false;
}
+size_t ShardingMetadata::Hash() const {
+ if (sharding_ != nullptr) {
+ return sharding_->Hash();
+ }
+ return static_cast<size_t>(0x297814aaad196e6dULL);
+}
+
string ShardingMetadata::ToString() const {
return sharding_ != nullptr ? sharding_->ToString() : "{}";
}
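The 0x297814aaad196e6d value is an arbitrary fixed sentinel: every ShardingMetadata without a sharding hashes to the same constant, presumably so that hashing stays consistent with equality for unsharded metadata. The pattern in general form (a sketch, not XLA code):

#include <cstddef>
#include <cstdint>

// Hash-or-sentinel: null objects must hash identically, and the sentinel is a
// fixed constant chosen at random so accidental collisions are unlikely.
template <typename T>
size_t HashOrSentinel(const T* ptr) {
  return ptr != nullptr ? ptr->Hash()
                        : static_cast<size_t>(0x297814aaad196e6dULL);
}

struct Sharded {
  size_t Hash() const { return 42; }
};

int main() {
  Sharded s;
  return HashOrSentinel(&s) == 42 ? 0 : 1;
}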
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h
index cba5db927a..e3ae82a070 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h
@@ -36,6 +36,8 @@ class ShardingMetadata : public DomainMetadata {
bool Matches(const DomainMetadata& other) const override;
+ size_t Hash() const override;
+
string ToString() const override;
const HloSharding* sharding() const { return sharding_.get(); }
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 95516dec74..069586a738 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -86,8 +86,8 @@ Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) {
const Shape expected,
ShapeInference::InferConvolveShape(
convolution->operand(0)->shape(), convolution->operand(1)->shape(),
- convolution->window(), convolution->convolution_dimension_numbers(),
- convolution->feature_group_count()));
+ convolution->feature_group_count(), convolution->window(),
+ convolution->convolution_dimension_numbers()));
return CheckShape(convolution, expected);
}
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
index a4de02a890..4a71ee909b 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
@@ -165,6 +165,7 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayFor(
TF_ASSIGN_OR_RETURN(
computed_array,
ComputeArrayForDot(instr->shape(), instr->dot_dimension_numbers(),
+ instr->precision_config(),
FindOrDie(cache_, instr->operand(0)),
FindOrDie(cache_, instr->operand(1))));
} else {
@@ -1030,6 +1031,7 @@ bool CanFoldDotIntoIndexedArray(
StatusOr<Analysis::Array*>
IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs(
const Shape& shape, const DotDimensionNumbers& dim_numbers,
+ const PrecisionConfigProto& precision_config,
ScalarIndexedConstantArray* lhs, ConstantArray* rhs) {
VLOG(3) << "ComputeArrayForDotWithIndexedLhs(" << ToString(lhs) << " "
<< ToString(rhs);
@@ -1045,9 +1047,10 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs(
new_dim_numbers.set_lhs_contracting_dimensions(
0, lhs->source_dim() == (lhs_rank - 1) ? (lhs_rank - 2) : (lhs_rank - 1));
- TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source,
- TakeOwnership(HloEvaluator{}.EvaluateDotOp(
- new_dim_numbers, lhs->literal(), *rhs->literal())));
+ TF_ASSIGN_OR_RETURN(
+ Literal * literal_for_new_source,
+ TakeOwnership(HloEvaluator{}.EvaluateDotOp(
+ new_dim_numbers, precision_config, lhs->literal(), *rhs->literal())));
// The new source dimension is wherever the non-batch non-contracting LHS
// dimension "went".
@@ -1063,7 +1066,8 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs(
StatusOr<Analysis::Array*>
IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs(
const Shape& shape, const DotDimensionNumbers& dim_numbers,
- ConstantArray* lhs, ScalarIndexedConstantArray* rhs) {
+ const PrecisionConfigProto& precision_config, ConstantArray* lhs,
+ ScalarIndexedConstantArray* rhs) {
VLOG(3) << "ComputeArrayForDotWithIndexedRhs(" << ToString(lhs) << " "
<< ToString(rhs);
if (!CanFoldDotIntoIndexedArray(
@@ -1079,9 +1083,10 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs(
new_dim_numbers.set_rhs_contracting_dimensions(
0, rhs->source_dim() == (rhs_rank - 1) ? (rhs_rank - 2) : (rhs_rank - 1));
- TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source,
- TakeOwnership(HloEvaluator{}.EvaluateDotOp(
- new_dim_numbers, *lhs->literal(), rhs->literal())));
+ TF_ASSIGN_OR_RETURN(
+ Literal * literal_for_new_source,
+ TakeOwnership(HloEvaluator{}.EvaluateDotOp(
+ new_dim_numbers, precision_config, *lhs->literal(), rhs->literal())));
// The new source dimension is wherever the non-batch non-contracting RHS
// dimension "went".
@@ -1095,8 +1100,8 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs(
}
StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForDot(
- const Shape& shape, const DotDimensionNumbers& dim_numbers, Array* lhs,
- Array* rhs) {
+ const Shape& shape, const DotDimensionNumbers& dim_numbers,
+ const PrecisionConfigProto& precision_config, Array* lhs, Array* rhs) {
// Intuitively, if
//
// - The LHS of a dot product is a gathered sequence of rows from a constant
@@ -1119,6 +1124,7 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForDot(
dynamic_cast<ScalarIndexedConstantArray*>(lhs)) {
if (auto* rhs_constant = dynamic_cast<ConstantArray*>(rhs)) {
return ComputeArrayForDotWithIndexedLhs(shape, dim_numbers,
+ precision_config,
lhs_indexed_array, rhs_constant);
}
}
@@ -1126,7 +1132,8 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForDot(
if (auto* rhs_indexed_array =
dynamic_cast<ScalarIndexedConstantArray*>(rhs)) {
if (auto* lhs_constant = dynamic_cast<ConstantArray*>(lhs)) {
- return ComputeArrayForDotWithIndexedRhs(shape, dim_numbers, lhs_constant,
+ return ComputeArrayForDotWithIndexedRhs(shape, dim_numbers,
+ precision_config, lhs_constant,
rhs_indexed_array);
}
}
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h
index dcfb725535..f21e784a4d 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.h
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h
@@ -267,15 +267,17 @@ class IndexedArrayAnalysis {
StatusOr<Array*> ComputeArrayForDotWithIndexedLhs(
const Shape& shape, const DotDimensionNumbers& dim_numbers,
+ const PrecisionConfigProto& precision_config,
ScalarIndexedConstantArray* lhs, ConstantArray* rhs);
StatusOr<Array*> ComputeArrayForDotWithIndexedRhs(
const Shape& shape, const DotDimensionNumbers& dim_numbers,
- ConstantArray* lhs, ScalarIndexedConstantArray* rhs);
+ const PrecisionConfigProto& precision_config, ConstantArray* lhs,
+ ScalarIndexedConstantArray* rhs);
- StatusOr<Array*> ComputeArrayForDot(const Shape& shape,
- const DotDimensionNumbers& dim_numbers,
- Array* lhs, Array* rhs);
+ StatusOr<Array*> ComputeArrayForDot(
+ const Shape& shape, const DotDimensionNumbers& dim_numbers,
+ const PrecisionConfigProto& precision_config, Array* lhs, Array* rhs);
// This tries to fold a ScalarIndexedArray which has another
// ScalarIndexedArray as a source into a ScalarIndexedArray that instead has a
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index 021fe630ff..69c7e42601 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -874,18 +874,18 @@ TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) {
)";
auto module = ParseHloString(module_str).ValueOrDie();
- module =
+ auto compiled_module =
backend()
.compiler()
->RunHloPasses(std::move(module), backend().default_stream_executor(),
/*device_allocator=*/nullptr)
.ConsumeValueOrDie();
-
- auto copy = FindInstruction(module.get(), "copy.1");
- auto slice = FindInstruction(module.get(), "slice0");
- EXPECT_EQ(slice->operand(0), copy);
- EXPECT_TRUE(
- LayoutUtil::Equal(slice->shape().layout(), copy->shape().layout()));
+ HloInstruction* root =
+ compiled_module->entry_computation()->root_instruction();
+ Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0});
+ EXPECT_THAT(root, op::Add(op::Parameter(),
+ op::Slice(AllOf(op::Copy(op::Parameter(1)),
+ op::ShapeWithLayout(shape_copy)))));
}
TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) {
@@ -902,18 +902,20 @@ TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) {
)";
auto module = ParseHloString(module_str).ValueOrDie();
- module =
+ auto compiled_module =
backend()
.compiler()
->RunHloPasses(std::move(module), backend().default_stream_executor(),
/*device_allocator=*/nullptr)
.ConsumeValueOrDie();
-
- auto copy = FindInstruction(module.get(), "copy.1");
- auto dslice = FindInstruction(module.get(), "dslice0");
- EXPECT_EQ(dslice->operand(0), copy);
- EXPECT_TRUE(
- LayoutUtil::Equal(dslice->shape().layout(), copy->shape().layout()));
+ HloInstruction* root =
+ compiled_module->entry_computation()->root_instruction();
+ Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0});
+ EXPECT_THAT(root,
+ op::Add(op::Parameter(),
+ op::DynamicSlice(AllOf(op::Copy(op::Parameter(1)),
+ op::ShapeWithLayout(shape_copy)),
+ op::Parameter(2))));
}
TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) {
@@ -931,18 +933,20 @@ TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) {
)";
auto module = ParseHloString(module_str).ValueOrDie();
- module =
+ auto compiled_module =
backend()
.compiler()
->RunHloPasses(std::move(module), backend().default_stream_executor(),
/*device_allocator=*/nullptr)
.ConsumeValueOrDie();
-
- auto copy = FindInstruction(module.get(), "copy.1");
- auto concat = FindInstruction(module.get(), "concat0");
- EXPECT_EQ(concat->operand(0), copy);
- EXPECT_TRUE(
- LayoutUtil::Equal(concat->shape().layout(), copy->shape().layout()));
+ HloInstruction* root =
+ compiled_module->entry_computation()->root_instruction();
+ Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {3, 5}, {1, 0});
+ EXPECT_THAT(root,
+ op::Add(op::Parameter(),
+ op::Concatenate(AllOf(op::Copy(op::Parameter(1)),
+ op::ShapeWithLayout(shape_copy)),
+ op::Parameter(2))));
}
TEST_F(LayoutAssignmentTest,
@@ -960,15 +964,39 @@ TEST_F(LayoutAssignmentTest,
)";
auto module = ParseHloString(module_str).ValueOrDie();
- module =
+ auto compiled_module =
backend()
.compiler()
->RunHloPasses(std::move(module), backend().default_stream_executor(),
/*device_allocator=*/nullptr)
.ConsumeValueOrDie();
+ HloInstruction* root =
+ compiled_module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Convolution(op::Parameter(0), op::Parameter(1)));
+}
+
+TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) {
+ const char* module_str = R"(
+ HloModule PropagatingLayoutFromResultToOperand
+
+ ENTRY PropagatingLayoutFromResultToOperand {
+ par0 = f32[4,5]{1,0} parameter(0)
+ ROOT slice0 = f32[3,4]{0,1} slice(par0), slice={[1:4],[1:5]}
+ }
+ )";
- auto copy = FindInstruction(module.get(), "copy.1");
- EXPECT_EQ(copy, nullptr);
+ auto module = ParseHloString(module_str).ValueOrDie();
+ auto compiled_module =
+ backend()
+ .compiler()
+ ->RunHloPasses(std::move(module), backend().default_stream_executor(),
+ /*device_allocator=*/nullptr)
+ .ConsumeValueOrDie();
+ HloInstruction* root =
+ compiled_module->entry_computation()->root_instruction();
+ Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {0, 1});
+ EXPECT_THAT(root, op::Slice(AllOf(op::Copy(op::Parameter(0)),
+ op::ShapeWithLayout(shape_copy))));
}
} // namespace
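A quick reminder on the layout notation used in these expectations: the braces list dimensions minor-to-major, so f32[4,5]{1,0} is row-major (dimension 1 varies fastest) while {0,1} is column-major. The new test therefore forces a physical copy: the parameter arrives row-major, but the root slice's {0,1} layout propagates to its operand. A linearization sketch for the 4x5 case:

#include <cstdint>

// Physical offset of element (r, c) of a [4,5] array under each layout.
int64_t LinearIndex(int64_t r, int64_t c, bool layout_10) {
  return layout_10 ? r * 5 + c   // {1,0}: dimension 1 is minor (row-major)
                   : c * 4 + r;  // {0,1}: dimension 0 is minor (column-major)
}

int main() {
  // The same logical element lands at different physical offsets, which is
  // why layout assignment must insert a copy rather than relabel in place.
  return LinearIndex(1, 2, true) == 7 && LinearIndex(1, 2, false) == 9 ? 0 : 1;
}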
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 2611749862..74bdf2a2e3 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -1552,8 +1552,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferConvolveShape(
- const Shape& lhs, const Shape& rhs, const Window& window,
- const ConvolutionDimensionNumbers& dnums, int64 feature_group_count) {
+ const Shape& lhs, const Shape& rhs, int64 feature_group_count,
+ const Window& window, const ConvolutionDimensionNumbers& dnums) {
TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution"));
TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution"));
@@ -1672,6 +1672,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
dnums.DebugString());
}
+ if (kernel_output_features % feature_group_count > 0) {
+ return InvalidArgument(
+ "Expected output feature dimension (value %d) to be divisible by "
+ "feature_group_count (value %d); "
+ "got <conv>(%s, %s)\n"
+ "Dimension numbers: {%s}.",
+ kernel_output_features, feature_group_count,
+ ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
+ dnums.DebugString());
+ }
std::vector<int64> window_dims(num_spatial_dims);
for (int i = 0; i < num_spatial_dims; ++i) {
window_dims[i] = window.dimensions(i).size();
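Worked against the grouped test in hlo_evaluator_test.cc: kernel_output_features = 8 and feature_group_count = 2, so 8 % 2 == 0 and the new check passes; a feature_group_count of 3 would now be rejected with the InvalidArgument above instead of silently producing a lopsided output feature split. The constraint in one line:

#include <cassert>

int main() {
  const long kernel_output_features = 8;
  const long feature_group_count = 2;
  // Each feature group must produce an equal share of the output features.
  assert(kernel_output_features % feature_group_count == 0);
  return 0;
}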
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index a28345acef..96a0ee165d 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -108,9 +108,9 @@ class ShapeInference {
// Infers the shape produced by applying the given convolutional
// filter (rhs) to lhs in the way specified by the fields on window.
static StatusOr<Shape> InferConvolveShape(
- const Shape& lhs, const Shape& rhs, const Window& window,
- const ConvolutionDimensionNumbers& dimension_numbers,
- int64 feature_group_count = 1);
+ const Shape& lhs, const Shape& rhs, int64 feature_group_count,
+ const Window& window,
+ const ConvolutionDimensionNumbers& dimension_numbers);
// Infers the shape produced by the given FFT type on the given operand.
static StatusOr<Shape> InferFftShape(const Shape& in, FftType fft_type,
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index cc92e58ef8..864ed43118 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -419,8 +419,8 @@ TEST_F(ShapeInferenceTest, Convolve) {
dim1->set_padding_high(0);
dim1->set_window_dilation(1);
dim1->set_base_dilation(1);
- auto inferred_status =
- ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums);
+ auto inferred_status = ShapeInference::InferConvolveShape(
+ lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums);
ASSERT_IS_OK(inferred_status.status());
Shape inferred_shape = inferred_status.ValueOrDie();
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}),
@@ -464,8 +464,8 @@ TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) {
dim1->set_padding_high(1);
dim1->set_window_dilation(2);
dim1->set_base_dilation(1);
- auto inferred_status =
- ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums);
+ auto inferred_status = ShapeInference::InferConvolveShape(
+ lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums);
ASSERT_IS_OK(inferred_status.status());
Shape inferred_shape = inferred_status.ValueOrDie();
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 31, 5}),
@@ -509,8 +509,8 @@ TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) {
dim1->set_padding_high(1);
dim1->set_window_dilation(1);
dim1->set_base_dilation(2);
- auto inferred_status =
- ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums);
+ auto inferred_status = ShapeInference::InferConvolveShape(
+ lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums);
ASSERT_IS_OK(inferred_status.status());
Shape inferred_shape = inferred_status.ValueOrDie();
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 4, 9}),
@@ -547,8 +547,8 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) {
dim1->set_stride(2);
dim1->set_padding_low(1);
dim1->set_padding_high(1);
- auto inferred_status =
- ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums);
+ auto inferred_status = ShapeInference::InferConvolveShape(
+ lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums);
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.status().error_message(),
HasSubstr("each dimension exactly once"));
diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc
index 530f40e4b2..7c1f4b5cc6 100644
--- a/tensorflow/compiler/xla/service/transpose_folding.cc
+++ b/tensorflow/compiler/xla/service/transpose_folding.cc
@@ -108,8 +108,7 @@ Status FoldTransposeIntoDot(InstructionOperandsPair pair) {
}
std::unique_ptr<HloInstruction> new_dot = HloInstruction::CreateDot(
- dot->shape(), new_lhs, new_rhs, new_dim_numbers);
- new_dot->set_precision_config(dot->precision_config());
+ dot->shape(), new_lhs, new_rhs, new_dim_numbers, dot->precision_config());
return dot->parent()->ReplaceWithNewInstruction(dot, std::move(new_dot));
}
@@ -178,8 +177,8 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) {
}
auto new_conv = HloInstruction::CreateConvolve(
- convolution.shape(), new_lhs, new_rhs, convolution.window(), new_dnums);
- new_conv->set_precision_config(convolution.precision_config());
+ convolution.shape(), new_lhs, new_rhs, convolution.feature_group_count(),
+ convolution.window(), new_dnums, convolution.precision_config());
TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction(
&convolution, std::move(new_conv)));
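Threading precision_config through the constructor here is not just tidier: the hlo_instruction.cc hunk above removed the copy of precision_config from SetupDerivedInstruction(), so a construct-then-set pattern that relied on later propagation would silently drop the folded instruction's precision. Passing it at creation, as both call sites now do, keeps the config attached from the start.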
diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc
index 58f767e913..e486a00e53 100644
--- a/tensorflow/compiler/xla/service/transpose_folding_test.cc
+++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc
@@ -215,6 +215,13 @@ ENTRY entry_computation {
/*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/1));
}
+PrecisionConfigProto DefaultPrecisionConfig(int operands) {
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ operands, PrecisionConfigProto::DEFAULT);
+ return precision_config;
+}
+
// Test that a two dimension swap of the kernel gets folded into convolution.
TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) {
auto builder = HloComputation::Builder("entry_computation");
@@ -240,10 +247,12 @@ TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) {
transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
}
StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
- x->shape(), transpose_y->shape(), window, dnums);
+ x->shape(), transpose_y->shape(), /*feature_group_count=*/1, window,
+ dnums);
EXPECT_IS_OK(conv_shape);
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
- conv_shape.ValueOrDie(), x, transpose_y, window, dnums));
+ conv_shape.ValueOrDie(), x, transpose_y,
+ /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
auto module = CreateNewModule("test_module");
HloComputation* entry_computation =
@@ -293,10 +302,12 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) {
transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
}
StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
- x->shape(), transpose_y->shape(), window, dnums);
+ x->shape(), transpose_y->shape(), /*feature_group_count=*/1, window,
+ dnums);
EXPECT_IS_OK(conv_shape);
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
- conv_shape.ValueOrDie(), x, transpose_y, window, dnums));
+ conv_shape.ValueOrDie(), x, transpose_y,
+ /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
auto module = CreateNewModule("test_module");
HloComputation* entry_computation =
@@ -351,10 +362,12 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) {
dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
}
StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
- transpose_x->shape(), y->shape(), window, dnums);
+ transpose_x->shape(), y->shape(), /*feature_group_count=*/1, window,
+ dnums);
EXPECT_IS_OK(conv_shape);
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
- conv_shape.ValueOrDie(), transpose_x, y, window, dnums));
+ conv_shape.ValueOrDie(), transpose_x, y,
+ /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
auto module = CreateNewModule("test_module");
HloComputation* entry_computation =
@@ -415,10 +428,12 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) {
dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
}
StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
- transpose_x->shape(), y->shape(), window, dnums);
+ transpose_x->shape(), y->shape(), /*feature_group_count=*/1, window,
+ dnums);
EXPECT_IS_OK(conv_shape);
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
- conv_shape.ValueOrDie(), transpose_x, y, window, dnums));
+ conv_shape.ValueOrDie(), transpose_x, y,
+ /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
auto module = CreateNewModule("test_module");
HloComputation* entry_computation =
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
index a32d1f9026..e3328203a6 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
@@ -1064,8 +1064,11 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
+ PrecisionConfigProto precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ /*new_size=*/2, PrecisionConfigProto::DEFAULT);
auto dot = builder.AddInstruction(
- HloInstruction::CreateDot(data_shape, a, b, dot_dnums));
+ HloInstruction::CreateDot(data_shape, a, b, dot_dnums, precision_config));
auto one = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));