about summary refs log tree commit diff homepage
path: root/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc')
-rw-r--r-- tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc 55
1 files changed, 54 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
index 2c854eea18..d76ce9ecbc 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
@@ -203,6 +203,35 @@ TEST_F(HloCostAnalysisTest, Convolution) {
sizeof(float) * (10 * 20 + 3 * 3 + 8 * 18));
}
+TEST_F(HloCostAnalysisTest, ConvolutionWithFeatureGroup) {
+ XlaBuilder builder("convolution");
+ auto input = Parameter(
+ &builder, 0,
+ ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/120, /*y_dim=*/10,
+ /*x_dim=*/20}),
+ "input");
+ auto kernel = Parameter(
+ &builder, 1,
+ ShapeUtil::MakeShape(F32, {/*p_dim=*/120, /*z_dim=*/1, /*y_dim=*/3,
+ /*x_dim=*/3}),
+ "kernel");
+ Conv(input, kernel, {1, 1}, Padding::kValid, /*feature_group_count=*/120);
+
+ // Run HLO cost analysis.
+ auto hlo_module = BuildHloGraph(&builder);
+ HloCostAnalysis analysis(ShapeSize);
+ ASSERT_IS_OK(
+ hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
+
+ // Output shape is [1x120x8x18] and each output element requires (3x3)
+ // FMAs and one FMA is 2 flops.
+ EXPECT_EQ(analysis.flop_count(), 120 * 8 * 18 * 2 * 3 * 3);
+
+ // Bytes accessed is sum of inputs and output.
+ EXPECT_EQ(analysis.bytes_accessed(),
+ sizeof(float) * (120 * 10 * 20 + 120 * 3 * 3 + 120 * 8 * 18));
+}
+
TEST_F(HloCostAnalysisTest, Reduce) {
XlaBuilder builder("reduce");
auto input =
@@ -415,7 +444,7 @@ TEST_F(FusionCostAnalysis, NoLayout) {
TEST_F(HloCostAnalysisTest, TupleCost) {
HloCostAnalysis analysis(ShapeSize);
{
- XlaBuilder builder("matmul");
+ XlaBuilder builder("tuple");
auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {123}), "x");
auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {42}), "y");
Tuple(&builder, {x, y});
@@ -430,6 +459,30 @@ TEST_F(HloCostAnalysisTest, TupleCost) {
EXPECT_EQ(analysis.bytes_accessed(), kPointerSize * 2);
}
+using DomainCostAnalysis = HloTestBase;
+TEST_F(DomainCostAnalysis, DomainCost) {
+ HloCostAnalysis analysis(ShapeSize);
+
+ HloComputation::Builder builder("domain");
+ auto x = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {123}), "x"));
+ auto y = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {42}), "y"));
+ auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({x, y}));
+ auto domain = builder.AddInstruction(
+ HloInstruction::CreateDomain(tuple->shape(), tuple, nullptr, nullptr));
+
+ auto hlo_module = CreateNewModule();
+ hlo_module->AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(hlo_module->entry_computation()->root_instruction(), domain);
+ ASSERT_IS_OK(domain->Accept(&analysis));
+
+ EXPECT_EQ(analysis.flop_count(*domain), 0);
+ EXPECT_EQ(analysis.transcendental_count(*domain), 0);
+ EXPECT_EQ(analysis.bytes_accessed(*domain), 0);
+}
+
TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) {
XlaBuilder builder("BaseDilatedConvolution");
auto input = Parameter(