diff options
author | Sanjoy Das <sanjoy@google.com> | 2018-01-25 11:17:08 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-25 11:21:05 -0800 |
commit | a8c4e8d96de7c0978851a5f9718bbd6b8056d862 (patch) | |
tree | 2d1a77cbd07b8f578c5ae5e1c749ab30403e5cfa /tensorflow | |
parent | 949dd29d3a8bdc21328c9e94721b344310686eab (diff) |
[XLA] Make xla_hlo_profile_test less flaky
Instead of relying on some operations always taking longer than others (and this
appearing in a specific order in the rendered HLO profile), pick them out by
opcode.
PiperOrigin-RevId: 183268593
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/compiler/xla/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/compiler/xla/map_util.h | 21 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc | 125 |
4 files changed, 112 insertions, 36 deletions
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 438f1443f1..c22fd37129 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -182,6 +182,7 @@ cc_library( deps = [ ":status", ":status_macros", + ":statusor", ":types", ":xla_data_proto", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/map_util.h b/tensorflow/compiler/xla/map_util.h index 50659c1240..0ad0b91330 100644 --- a/tensorflow/compiler/xla/map_util.h +++ b/tensorflow/compiler/xla/map_util.h @@ -16,6 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_MAP_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_MAP_UTIL_H_ +#include <functional> +#include <sstream> + +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -44,6 +49,22 @@ typename Collection::value_type::second_type& FindOrDie( return it->second; } +// Like FindOrDie but returns an error instead of dying if `key` is not in +// `container`. +template <class Collection> +StatusOr< + std::reference_wrapper<const typename Collection::value_type::second_type>> +MaybeFind(const Collection& collection, + const typename Collection::value_type::first_type& key) { + typename Collection::const_iterator it = collection.find(key); + if (it == collection.end()) { + std::ostringstream os; + os << key; + return NotFound("key not found: %s", os.str().c_str()); + } + return {it->second}; +} + // Inserts the key-value pair into the collection. Dies if key was already // present. 
template <class Collection> diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 02ad9d982f..ac11081699 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -351,6 +351,7 @@ xla_test( deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:platform_util", diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index 1d2f436194..9ad2a19853 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -19,12 +19,14 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -32,6 +34,7 @@ limitations under the License. 
namespace xla { namespace { namespace se = ::perftools::gputools; +namespace gtl = ::tensorflow::gtl; class HloProfileTest : public ClientLibraryTestBase {}; @@ -43,39 +46,74 @@ struct ParsedProfileOutputLine { string trops; string bytes_per_sec; string bytes_per_cycle; - string name; + string opcode; }; -StatusOr<ParsedProfileOutputLine> ParseProfileOutputLine(const string& line, - bool expect_flops, - bool expect_trops) { +::testing::AssertionResult HasFlops( + const ParsedProfileOutputLine& parsed_line) { + if (RE2::FullMatch(parsed_line.flops, "[0-9.TGMk]+FLOP/s")) { + return ::testing::AssertionSuccess() + << "'flops' field present in " << parsed_line.opcode << ": '" + << parsed_line.flops << "'"; + } + + return ::testing::AssertionFailure() + << "'flops' field absent in " << parsed_line.opcode << ": '" + << parsed_line.flops << "'"; +} + +::testing::AssertionResult HasTrops( + const ParsedProfileOutputLine& parsed_line) { + if (RE2::FullMatch(parsed_line.trops, "[0-9.TGMk]+TROP/s")) { + return ::testing::AssertionSuccess() + << "'trops' field present in " << parsed_line.opcode << ": '" + << parsed_line.trops << "'"; + } + + return ::testing::AssertionFailure() + << "'trops' field absent in " << parsed_line.opcode << ": '" + << parsed_line.trops << "'"; +} + +Status ParseOneProfileOutputLine( + const string& line, bool expect_hlo, + gtl::FlatMap<string, ParsedProfileOutputLine>* parsed_results) { string separator = "[^:]*:: +"; string match_percentage = "\\d+\\.\\d\\d%"; string match_cycles = "(\\d+) cycles +\\( *(" + match_percentage + ")\\)"; string match_usecs = "([0-9.]+) usec"; - string match_flops = expect_flops ? "([0-9.TGMk]+)FLOP/s" : "(<none>)"; - string match_trops = expect_trops ? 
"([0-9.TGMk]+)TROP/s" : "(<none>)"; + string match_flops = "([^ ]+)"; + string match_trops = "([^ ]+)"; string match_bytes_per_sec = "([0-9.TGMKi]+)B/s"; string match_bytes_per_cycle = "([0-9.TGMKi]+)B/cycle"; + + // The underlined part is what we're trying to match with match_opcode: + // + // %dot33 = f32[256,256]{1,0} dot(...) + // ^^^ + + string match_opcode = + expect_hlo ? "%[^=]+= [^ ]+ ([^(]+)\\(.*" : "(\\[total\\])"; string regexp_pattern = tensorflow::strings::StrCat( " +", match_cycles, separator, match_usecs, separator, match_flops, separator, match_trops, separator, match_bytes_per_sec, separator, - match_bytes_per_cycle, separator, "(.*)"); + match_bytes_per_cycle, separator, match_opcode); - RE2 pattern(regexp_pattern); ParsedProfileOutputLine parsed_line; bool matched = RE2::FullMatch( - line, pattern, &parsed_line.cycles, &parsed_line.cycles_percentage, + line, regexp_pattern, &parsed_line.cycles, &parsed_line.cycles_percentage, &parsed_line.usec, &parsed_line.flops, &parsed_line.trops, &parsed_line.bytes_per_sec, &parsed_line.bytes_per_cycle, - &parsed_line.name); + &parsed_line.opcode); if (!matched) { return tensorflow::errors::InvalidArgument( "Input did not match regexp. Input: ", line, ", Regexp: ", regexp_pattern); } - return parsed_line; + InsertOrDie(parsed_results, parsed_line.opcode, parsed_line); + + return Status::OK(); } // Returns void so that we can ASSERT. 
@@ -148,7 +186,7 @@ XLA_TEST_F(HloProfileTest, DISABLED_ON_CPU_PARALLEL(ProfileSingleComputation)) { ClientLibrary::GetOrCreateLocalClient(platform)); ComputationBuilder builder(client, TestName()); - auto result = builder.Tanh(builder.Dot( + auto result = builder.Tanh(builder.Add( builder.Parameter(0, ShapeUtil::MakeShape(F32, {m, k}), "dot_lhs"), builder.Parameter(1, ShapeUtil::MakeShape(F32, {k, n}), "dot_rhs"))); @@ -161,31 +199,43 @@ XLA_TEST_F(HloProfileTest, DISABLED_ON_CPU_PARALLEL(ProfileSingleComputation)) { std::vector<string> profile_output_lines = tensorflow::str_util::Split(profile_output, '\n'); - TF_ASSERT_OK_AND_ASSIGN( - ParsedProfileOutputLine total_profile, - ParseProfileOutputLine(profile_output_lines[1], /*expect_flops=*/true, - /*expect_trops=*/true)); + gtl::FlatMap<string, ParsedProfileOutputLine> parsed_profile_lines; - TF_ASSERT_OK_AND_ASSIGN( - ParsedProfileOutputLine dot_profile, - ParseProfileOutputLine(profile_output_lines[2], /*expect_flops=*/true, - /*expect_trops=*/false)); + TF_ASSERT_OK(ParseOneProfileOutputLine( + profile_output_lines[1], /*expect_hlo=*/false, &parsed_profile_lines)); - TF_ASSERT_OK_AND_ASSIGN( - ParsedProfileOutputLine tanh_profile, - ParseProfileOutputLine(profile_output_lines[3], /*expect_flops=*/false, - /*expect_trops=*/true)); + TF_ASSERT_OK(ParseOneProfileOutputLine( + profile_output_lines[2], /*expect_hlo=*/true, &parsed_profile_lines)); + + TF_ASSERT_OK(ParseOneProfileOutputLine( + profile_output_lines[3], /*expect_hlo=*/true, &parsed_profile_lines)); + + TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine total_profile, + MaybeFind(parsed_profile_lines, "[total]")); + TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine dot_profile, + MaybeFind(parsed_profile_lines, "add")); + TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine tanh_profile, + MaybeFind(parsed_profile_lines, "tanh")); EXPECT_GT(total_profile.cycles, 0); EXPECT_EQ(total_profile.cycles_percentage, "100.00%"); + EXPECT_TRUE(HasFlops(total_profile)); + 
EXPECT_TRUE(HasTrops(total_profile)); + EXPECT_GT(total_profile.cycles, dot_profile.cycles); EXPECT_NE(dot_profile.cycles_percentage, "0.00%"); EXPECT_NE(dot_profile.cycles_percentage, "100.00%"); + EXPECT_TRUE(HasFlops(dot_profile)); + EXPECT_FALSE(HasTrops(dot_profile)); + EXPECT_GT(total_profile.cycles, tanh_profile.cycles); EXPECT_NE(tanh_profile.cycles_percentage, "0.00%"); EXPECT_NE(tanh_profile.cycles_percentage, "100.00%"); + + EXPECT_FALSE(HasFlops(tanh_profile)); + EXPECT_TRUE(HasTrops(tanh_profile)); } // TODO(b/71364943): This test exposes a bug in the parallel CPU backend. @@ -220,7 +270,7 @@ XLA_TEST_F(HloProfileTest, auto matrix = builder.GetTupleElement(state, 1); auto next_iteration = builder.Add(builder.GetTupleElement(state, 0), builder.ConstantR0<int32>(1)); - builder.Tuple({next_iteration, builder.Dot(matrix, matrix)}); + builder.Tuple({next_iteration, builder.Add(matrix, matrix)}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } @@ -249,20 +299,23 @@ XLA_TEST_F(HloProfileTest, ASSERT_NE(while_body_profile_start, profile_output_lines.end()); - TF_ASSERT_OK_AND_ASSIGN( - ParsedProfileOutputLine total_while_body_profile, - ParseProfileOutputLine(*std::next(while_body_profile_start, 1), - /*expect_flops=*/false, - /*expect_trops=*/false)); + gtl::FlatMap<string, ParsedProfileOutputLine> parsed_profile_lines; - TF_ASSERT_OK_AND_ASSIGN( - ParsedProfileOutputLine dot_profile, - ParseProfileOutputLine(*std::next(while_body_profile_start, 2), - /*expect_flops=*/false, - /*expect_trops=*/false)); + TF_ASSERT_OK( + ParseOneProfileOutputLine(*std::next(while_body_profile_start, 1), + /*expect_hlo=*/false, &parsed_profile_lines)); + + TF_ASSERT_OK( + ParseOneProfileOutputLine(*std::next(while_body_profile_start, 2), + /*expect_hlo=*/true, &parsed_profile_lines)); + + TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine total_while_body_profile, + MaybeFind(parsed_profile_lines, "[total]")); + TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine dot_profile, + 
MaybeFind(parsed_profile_lines, "add")); EXPECT_GT(total_while_body_profile.cycles, 0); - EXPECT_EQ(total_while_body_profile.name, "[total]"); + EXPECT_EQ(total_while_body_profile.opcode, "[total]"); EXPECT_EQ(total_while_body_profile.cycles_percentage, "100.00%"); EXPECT_GT(total_while_body_profile.cycles, dot_profile.cycles); |