aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc')
-rw-r--r-- tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc | 67
1 files changed, 43 insertions, 24 deletions
diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
index c0616809f9..7a75e5102c 100644
--- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
+++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -79,7 +79,9 @@ struct ParsedProfileOutputLine {
Status ParseOneProfileOutputLine(
const string& line, bool expect_hlo,
- gtl::FlatMap<string, ParsedProfileOutputLine>* parsed_results) {
+ gtl::FlatMap<string, ParsedProfileOutputLine>* parsed_results,
+ tensorflow::gtl::ArraySlice<tensorflow::StringPiece> opcodes_to_ignore =
+ {}) {
string separator = "[^:]*:: +";
string match_percentage = "\\d+\\.\\d\\d%";
string match_cycles = "(\\d+) cycles +\\( *(" + match_percentage + ")\\)";
@@ -113,7 +115,9 @@ Status ParseOneProfileOutputLine(
", Regexp: ", regexp_pattern);
}
- InsertOrDie(parsed_results, parsed_line.opcode, parsed_line);
+ if (!c_linear_search(opcodes_to_ignore, parsed_line.opcode)) {
+ InsertOrDie(parsed_results, parsed_line.opcode, parsed_line);
+ }
return Status::OK();
}
@@ -240,9 +244,7 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) {
EXPECT_TRUE(HasTrops(tanh_profile));
}
-// TODO(b/71544591): The GPU backend does not record cycles spent in on Hlo
-// instructions "interior" to while nodes.
-XLA_TEST_F(HloProfileTest, DISABLED_ON_GPU(ProfileWhileComputation)) {
+XLA_TEST_F(HloProfileTest, ProfileWhileComputation) {
const int64 size = 256;
Shape matrix_shape = ShapeUtil::MakeShape(F32, {size, size});
Shape while_result_shape =
@@ -269,7 +271,7 @@ XLA_TEST_F(HloProfileTest, DISABLED_ON_GPU(ProfileWhileComputation)) {
auto matrix = GetTupleElement(state, 1);
auto next_iteration =
Add(GetTupleElement(state, 0), ConstantR0<int32>(&builder, 1));
- Tuple(&builder, {next_iteration, Add(matrix, matrix)});
+ Tuple(&builder, {next_iteration, Mul(matrix, matrix)});
TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
}
@@ -291,36 +293,50 @@ XLA_TEST_F(HloProfileTest, DISABLED_ON_GPU(ProfileWhileComputation)) {
tensorflow::str_util::Split(profile_output, '\n');
auto while_body_profile_start =
- std::find_if(profile_output_lines.begin(), profile_output_lines.end(),
+ c_find_if(profile_output_lines, [](tensorflow::StringPiece s) {
+ return tensorflow::str_util::StartsWith(s,
+ "Execution profile for body");
+ });
+
+ ASSERT_NE(while_body_profile_start, profile_output_lines.cend());
+
+ auto while_body_profile_end =
+ std::find_if(while_body_profile_start, profile_output_lines.end(),
[](tensorflow::StringPiece s) {
return tensorflow::str_util::StartsWith(
- s, "Execution profile for body");
+ s, "********** microseconds report **********");
});
- ASSERT_NE(while_body_profile_start, profile_output_lines.end());
+ // Verify the "microseconds report" marker was found *before* stepping the
+ // iterator back: once decremented it can no longer compare equal to end(),
+ // so checking afterwards would never fail.
+ ASSERT_NE(while_body_profile_end, profile_output_lines.end());
+ // We emit a blank line before the "********** microseconds report **********"
+ // line.
+ while_body_profile_end--;
- gtl::FlatMap<string, ParsedProfileOutputLine> parsed_profile_lines;
- TF_ASSERT_OK(
- ParseOneProfileOutputLine(*std::next(while_body_profile_start, 1),
- /*expect_hlo=*/false, &parsed_profile_lines));
+ gtl::FlatMap<string, ParsedProfileOutputLine> parsed_profile_lines;
- TF_ASSERT_OK(
- ParseOneProfileOutputLine(*std::next(while_body_profile_start, 2),
- /*expect_hlo=*/true, &parsed_profile_lines));
+ for (auto while_body_profile_i = while_body_profile_start + 1;
+ while_body_profile_i != while_body_profile_end; while_body_profile_i++) {
+ // There are multiple "get-tuple-element" instructions in the while body so
+ // we ignore them -- we don't want parsed_profile_lines to be a multi-map.
+ TF_ASSERT_OK(ParseOneProfileOutputLine(
+ *while_body_profile_i,
+ /*expect_hlo=*/while_body_profile_i != (while_body_profile_start + 1),
+ &parsed_profile_lines, {"get-tuple-element"}));
+ }
TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine total_while_body_profile,
MaybeFind(parsed_profile_lines, "[total]"));
- TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine dot_profile,
- MaybeFind(parsed_profile_lines, "add"));
+ TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine multiply_profile,
+ MaybeFind(parsed_profile_lines, "multiply"));
EXPECT_GT(total_while_body_profile.cycles, 0);
EXPECT_EQ(total_while_body_profile.opcode, "[total]");
EXPECT_EQ(total_while_body_profile.cycles_percentage, "100.00%");
- EXPECT_GT(total_while_body_profile.cycles, dot_profile.cycles);
- EXPECT_NE(dot_profile.cycles_percentage, "0.00%");
- EXPECT_NE(dot_profile.cycles_percentage, "100.00%");
+ EXPECT_GT(total_while_body_profile.cycles, multiply_profile.cycles);
+ EXPECT_NE(multiply_profile.cycles_percentage, "0.00%");
+ EXPECT_NE(multiply_profile.cycles_percentage, "100.00%");
}
} // namespace
} // namespace xla
@@ -337,8 +353,11 @@ static std::pair<int, char**> AddXlaHloProfileFlag(int argc, char** argv) {
new_argv[argc] = strdup("--xla_hlo_profile");
// Fusion can change the Hlo instructions that show up in the final Hlo
- // executable, so block it here.
- new_argv[argc + 1] = strdup("--xla_disable_hlo_passes=fusion");
+ // executable, so block it here. Also block the WhileLoopInvariantCodeMotion
+ // pass, otherwise a while loop is transformed and we could not match the
+ // original name in the ProfileWhileComputation test.
+ new_argv[argc + 1] = strdup(
+ "--xla_disable_hlo_passes=fusion,while-loop-invariant-code-motion");
return {argc + 2, new_argv};
}