diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc | 55 |
1 files changed, 54 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index 2c854eea18..d76ce9ecbc 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -203,6 +203,35 @@ TEST_F(HloCostAnalysisTest, Convolution) { sizeof(float) * (10 * 20 + 3 * 3 + 8 * 18)); } +TEST_F(HloCostAnalysisTest, ConvolutionWithFeatureGroup) { + XlaBuilder builder("convolution"); + auto input = Parameter( + &builder, 0, + ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/120, /*y_dim=*/10, + /*x_dim=*/20}), + "input"); + auto kernel = Parameter( + &builder, 1, + ShapeUtil::MakeShape(F32, {/*p_dim=*/120, /*z_dim=*/1, /*y_dim=*/3, + /*x_dim=*/3}), + "kernel"); + Conv(input, kernel, {1, 1}, Padding::kValid, /*feature_group_count=*/120); + + // Run HLO cost analysis. + auto hlo_module = BuildHloGraph(&builder); + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + // Output shape is [1x120x8x18] and each output element requires (3x3) + // FMAs and one FMA is 2 flops. + EXPECT_EQ(analysis.flop_count(), 120 * 8 * 18 * 2 * 3 * 3); + + // Bytes accessed is sum of inputs and output. + EXPECT_EQ(analysis.bytes_accessed(), + sizeof(float) * (120 * 10 * 20 + 120 * 3 * 3 + 120 * 8 * 18)); +} + TEST_F(HloCostAnalysisTest, Reduce) { XlaBuilder builder("reduce"); auto input = @@ -415,7 +444,7 @@ TEST_F(FusionCostAnalysis, NoLayout) { TEST_F(HloCostAnalysisTest, TupleCost) { HloCostAnalysis analysis(ShapeSize); { - XlaBuilder builder("matmul"); + XlaBuilder builder("tuple"); auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {123}), "x"); auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {42}), "y"); Tuple(&builder, {x, y}); @@ -430,6 +459,30 @@ TEST_F(HloCostAnalysisTest, TupleCost) { EXPECT_EQ(analysis.bytes_accessed(), kPointerSize * 2); } +using DomainCostAnalysis = HloTestBase; +TEST_F(DomainCostAnalysis, DomainCost) { + HloCostAnalysis analysis(ShapeSize); + + HloComputation::Builder builder("domain"); + auto x = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {123}), "x")); + auto y = builder.AddInstruction( + HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {42}), "y")); + auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({x, y})); + auto domain = builder.AddInstruction( + HloInstruction::CreateDomain(tuple->shape(), tuple, nullptr, nullptr)); + + auto hlo_module = CreateNewModule(); + hlo_module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(hlo_module->entry_computation()->root_instruction(), domain); + ASSERT_IS_OK(domain->Accept(&analysis)); + + EXPECT_EQ(analysis.flop_count(*domain), 0); + EXPECT_EQ(analysis.transcendental_count(*domain), 0); + EXPECT_EQ(analysis.bytes_accessed(*domain), 0); +} + TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) { XlaBuilder builder("BaseDilatedConvolution"); auto input = Parameter( |