diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc | 67 |
1 files changed, 43 insertions, 24 deletions
diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index c0616809f9..7a75e5102c 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -79,7 +79,9 @@ struct ParsedProfileOutputLine { Status ParseOneProfileOutputLine( const string& line, bool expect_hlo, - gtl::FlatMap<string, ParsedProfileOutputLine>* parsed_results) { + gtl::FlatMap<string, ParsedProfileOutputLine>* parsed_results, + tensorflow::gtl::ArraySlice<tensorflow::StringPiece> opcodes_to_ignore = + {}) { string separator = "[^:]*:: +"; string match_percentage = "\\d+\\.\\d\\d%"; string match_cycles = "(\\d+) cycles +\\( *(" + match_percentage + ")\\)"; @@ -113,7 +115,9 @@ Status ParseOneProfileOutputLine( ", Regexp: ", regexp_pattern); } - InsertOrDie(parsed_results, parsed_line.opcode, parsed_line); + if (!c_linear_search(opcodes_to_ignore, parsed_line.opcode)) { + InsertOrDie(parsed_results, parsed_line.opcode, parsed_line); + } return Status::OK(); } @@ -240,9 +244,7 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) { EXPECT_TRUE(HasTrops(tanh_profile)); } -// TODO(b/71544591): The GPU backend does not record cycles spent in on Hlo -// instructions "interior" to while nodes. -XLA_TEST_F(HloProfileTest, DISABLED_ON_GPU(ProfileWhileComputation)) { +XLA_TEST_F(HloProfileTest, ProfileWhileComputation) { const int64 size = 256; Shape matrix_shape = ShapeUtil::MakeShape(F32, {size, size}); Shape while_result_shape = @@ -269,7 +271,7 @@ XLA_TEST_F(HloProfileTest, DISABLED_ON_GPU(ProfileWhileComputation)) { auto matrix = GetTupleElement(state, 1); auto next_iteration = Add(GetTupleElement(state, 0), ConstantR0<int32>(&builder, 1)); - Tuple(&builder, {next_iteration, Add(matrix, matrix)}); + Tuple(&builder, {next_iteration, Mul(matrix, matrix)}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } @@ -291,36 +293,50 @@ XLA_TEST_F(HloProfileTest, DISABLED_ON_GPU(ProfileWhileComputation)) { tensorflow::str_util::Split(profile_output, '\n'); auto while_body_profile_start = - std::find_if(profile_output_lines.begin(), profile_output_lines.end(), + c_find_if(profile_output_lines, [](tensorflow::StringPiece s) { + return tensorflow::str_util::StartsWith(s, + "Execution profile for body"); + }); + + ASSERT_NE(while_body_profile_start, profile_output_lines.cend()); + + auto while_body_profile_end = + std::find_if(while_body_profile_start, profile_output_lines.end(), [](tensorflow::StringPiece s) { return tensorflow::str_util::StartsWith( - s, "Execution profile for body"); + s, "********** microseconds report **********"); }); - ASSERT_NE(while_body_profile_start, profile_output_lines.end()); + // We emit a blank line before the "********** microseconds report **********" + // line. + while_body_profile_end--; - gtl::FlatMap<string, ParsedProfileOutputLine> parsed_profile_lines; + ASSERT_NE(while_body_profile_end, profile_output_lines.end()); - TF_ASSERT_OK( - ParseOneProfileOutputLine(*std::next(while_body_profile_start, 1), - /*expect_hlo=*/false, &parsed_profile_lines)); + gtl::FlatMap<string, ParsedProfileOutputLine> parsed_profile_lines; - TF_ASSERT_OK( - ParseOneProfileOutputLine(*std::next(while_body_profile_start, 2), - /*expect_hlo=*/true, &parsed_profile_lines)); + for (auto while_body_profile_i = while_body_profile_start + 1; + while_body_profile_i != while_body_profile_end; while_body_profile_i++) { + // There are multiple "get-tuple-element" instructions in the while body so + // we ignore them -- we don't want parsed_profile_lines to be a multi-map. + TF_ASSERT_OK(ParseOneProfileOutputLine( + *while_body_profile_i, + /*expect_hlo=*/while_body_profile_i != (while_body_profile_start + 1), + &parsed_profile_lines, {"get-tuple-element"})); + } TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine total_while_body_profile, MaybeFind(parsed_profile_lines, "[total]")); - TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine dot_profile, - MaybeFind(parsed_profile_lines, "add")); + TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine multiply_profile, + MaybeFind(parsed_profile_lines, "multiply")); EXPECT_GT(total_while_body_profile.cycles, 0); EXPECT_EQ(total_while_body_profile.opcode, "[total]"); EXPECT_EQ(total_while_body_profile.cycles_percentage, "100.00%"); - EXPECT_GT(total_while_body_profile.cycles, dot_profile.cycles); - EXPECT_NE(dot_profile.cycles_percentage, "0.00%"); - EXPECT_NE(dot_profile.cycles_percentage, "100.00%"); + EXPECT_GT(total_while_body_profile.cycles, multiply_profile.cycles); + EXPECT_NE(multiply_profile.cycles_percentage, "0.00%"); + EXPECT_NE(multiply_profile.cycles_percentage, "100.00%"); } } // namespace } // namespace xla @@ -337,8 +353,11 @@ static std::pair<int, char**> AddXlaHloProfileFlag(int argc, char** argv) { new_argv[argc] = strdup("--xla_hlo_profile"); // Fusion can change the Hlo instructions that show up in the final Hlo - // executable, so block it here. - new_argv[argc + 1] = strdup("--xla_disable_hlo_passes=fusion"); + // executable, so block it here. Also block the WhileLoopInvariantCodeMotion + // pass, otherwise a while loop is transformed and we could not match the + // original name in the ProfileWhileComputation test. + new_argv[argc + 1] = strdup( + "--xla_disable_hlo_passes=fusion,while-loop-invariant-code-motion"); return {argc + 2, new_argv}; } |