aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-01-25 11:17:08 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-25 11:21:05 -0800
commita8c4e8d96de7c0978851a5f9718bbd6b8056d862 (patch)
tree2d1a77cbd07b8f578c5ae5e1c749ab30403e5cfa /tensorflow
parent949dd29d3a8bdc21328c9e94721b344310686eab (diff)
[XLA] Make xla_hlo_profile_test less flaky
Instead of relying on some operations always taking longer than others (and thus appearing in a specific order in the rendered HLO profile), pick them out by opcode. PiperOrigin-RevId: 183268593
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/compiler/xla/BUILD1
-rw-r--r--tensorflow/compiler/xla/map_util.h21
-rw-r--r--tensorflow/compiler/xla/tests/BUILD1
-rw-r--r--tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc125
4 files changed, 112 insertions, 36 deletions
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index 438f1443f1..c22fd37129 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -182,6 +182,7 @@ cc_library(
deps = [
":status",
":status_macros",
+ ":statusor",
":types",
":xla_data_proto",
"//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/map_util.h b/tensorflow/compiler/xla/map_util.h
index 50659c1240..0ad0b91330 100644
--- a/tensorflow/compiler/xla/map_util.h
+++ b/tensorflow/compiler/xla/map_util.h
@@ -16,6 +16,11 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_MAP_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_MAP_UTIL_H_
+#include <functional>
+#include <sstream>
+
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -44,6 +49,22 @@ typename Collection::value_type::second_type& FindOrDie(
return it->second;
}
+// Like FindOrDie but returns a NotFound error instead of dying if `key` is not
+// in `collection`.
+template <class Collection>
+StatusOr<
+ std::reference_wrapper<const typename Collection::value_type::second_type>>
+MaybeFind(const Collection& collection,
+ const typename Collection::value_type::first_type& key) {
+ typename Collection::const_iterator it = collection.find(key);
+ if (it == collection.end()) {
+ std::ostringstream os;
+ os << key;
+ return NotFound("key not found: %s", os.str().c_str());
+ }
+ return {it->second};
+}
+
// Inserts the key-value pair into the collection. Dies if key was already
// present.
template <class Collection>
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 02ad9d982f..ac11081699 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -351,6 +351,7 @@ xla_test(
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:platform_util",
diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
index 1d2f436194..9ad2a19853 100644
--- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
+++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
@@ -19,12 +19,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -32,6 +34,7 @@ limitations under the License.
namespace xla {
namespace {
namespace se = ::perftools::gputools;
+namespace gtl = ::tensorflow::gtl;
class HloProfileTest : public ClientLibraryTestBase {};
@@ -43,39 +46,74 @@ struct ParsedProfileOutputLine {
string trops;
string bytes_per_sec;
string bytes_per_cycle;
- string name;
+ string opcode;
};
-StatusOr<ParsedProfileOutputLine> ParseProfileOutputLine(const string& line,
- bool expect_flops,
- bool expect_trops) {
+::testing::AssertionResult HasFlops(
+ const ParsedProfileOutputLine& parsed_line) {
+ if (RE2::FullMatch(parsed_line.flops, "[0-9.TGMk]+FLOP/s")) {
+ return ::testing::AssertionSuccess()
+ << "'flops' field present in " << parsed_line.opcode << ": '"
+ << parsed_line.flops << "'";
+ }
+
+ return ::testing::AssertionFailure()
+ << "'flops' field absent in " << parsed_line.opcode << ": '"
+ << parsed_line.flops << "'";
+}
+
+::testing::AssertionResult HasTrops(
+ const ParsedProfileOutputLine& parsed_line) {
+ if (RE2::FullMatch(parsed_line.trops, "[0-9.TGMk]+TROP/s")) {
+ return ::testing::AssertionSuccess()
+ << "'trops' field present in " << parsed_line.opcode << ": '"
+ << parsed_line.trops << "'";
+ }
+
+ return ::testing::AssertionFailure()
+ << "'trops' field absent in " << parsed_line.opcode << ": '"
+ << parsed_line.trops << "'";
+}
+
+Status ParseOneProfileOutputLine(
+ const string& line, bool expect_hlo,
+ gtl::FlatMap<string, ParsedProfileOutputLine>* parsed_results) {
string separator = "[^:]*:: +";
string match_percentage = "\\d+\\.\\d\\d%";
string match_cycles = "(\\d+) cycles +\\( *(" + match_percentage + ")\\)";
string match_usecs = "([0-9.]+) usec";
- string match_flops = expect_flops ? "([0-9.TGMk]+)FLOP/s" : "(<none>)";
- string match_trops = expect_trops ? "([0-9.TGMk]+)TROP/s" : "(<none>)";
+ string match_flops = "([^ ]+)";
+ string match_trops = "([^ ]+)";
string match_bytes_per_sec = "([0-9.TGMKi]+)B/s";
string match_bytes_per_cycle = "([0-9.TGMKi]+)B/cycle";
+
+ // The underlined part is what we're trying to match with match_opcode:
+ //
+ // %dot33 = f32[256,256]{1,0} dot(...)
+ // ^^^
+
+ string match_opcode =
+ expect_hlo ? "%[^=]+= [^ ]+ ([^(]+)\\(.*" : "(\\[total\\])";
string regexp_pattern = tensorflow::strings::StrCat(
" +", match_cycles, separator, match_usecs, separator, match_flops,
separator, match_trops, separator, match_bytes_per_sec, separator,
- match_bytes_per_cycle, separator, "(.*)");
+ match_bytes_per_cycle, separator, match_opcode);
- RE2 pattern(regexp_pattern);
ParsedProfileOutputLine parsed_line;
bool matched = RE2::FullMatch(
- line, pattern, &parsed_line.cycles, &parsed_line.cycles_percentage,
+ line, regexp_pattern, &parsed_line.cycles, &parsed_line.cycles_percentage,
&parsed_line.usec, &parsed_line.flops, &parsed_line.trops,
&parsed_line.bytes_per_sec, &parsed_line.bytes_per_cycle,
- &parsed_line.name);
+ &parsed_line.opcode);
if (!matched) {
return tensorflow::errors::InvalidArgument(
"Input did not match regexp. Input: ", line,
", Regexp: ", regexp_pattern);
}
- return parsed_line;
+ InsertOrDie(parsed_results, parsed_line.opcode, parsed_line);
+
+ return Status::OK();
}
// Returns void so that we can ASSERT.
@@ -148,7 +186,7 @@ XLA_TEST_F(HloProfileTest, DISABLED_ON_CPU_PARALLEL(ProfileSingleComputation)) {
ClientLibrary::GetOrCreateLocalClient(platform));
ComputationBuilder builder(client, TestName());
- auto result = builder.Tanh(builder.Dot(
+ auto result = builder.Tanh(builder.Add(
builder.Parameter(0, ShapeUtil::MakeShape(F32, {m, k}), "dot_lhs"),
builder.Parameter(1, ShapeUtil::MakeShape(F32, {k, n}), "dot_rhs")));
@@ -161,31 +199,43 @@ XLA_TEST_F(HloProfileTest, DISABLED_ON_CPU_PARALLEL(ProfileSingleComputation)) {
std::vector<string> profile_output_lines =
tensorflow::str_util::Split(profile_output, '\n');
- TF_ASSERT_OK_AND_ASSIGN(
- ParsedProfileOutputLine total_profile,
- ParseProfileOutputLine(profile_output_lines[1], /*expect_flops=*/true,
- /*expect_trops=*/true));
+ gtl::FlatMap<string, ParsedProfileOutputLine> parsed_profile_lines;
- TF_ASSERT_OK_AND_ASSIGN(
- ParsedProfileOutputLine dot_profile,
- ParseProfileOutputLine(profile_output_lines[2], /*expect_flops=*/true,
- /*expect_trops=*/false));
+ TF_ASSERT_OK(ParseOneProfileOutputLine(
+ profile_output_lines[1], /*expect_hlo=*/false, &parsed_profile_lines));
- TF_ASSERT_OK_AND_ASSIGN(
- ParsedProfileOutputLine tanh_profile,
- ParseProfileOutputLine(profile_output_lines[3], /*expect_flops=*/false,
- /*expect_trops=*/true));
+ TF_ASSERT_OK(ParseOneProfileOutputLine(
+ profile_output_lines[2], /*expect_hlo=*/true, &parsed_profile_lines));
+
+ TF_ASSERT_OK(ParseOneProfileOutputLine(
+ profile_output_lines[3], /*expect_hlo=*/true, &parsed_profile_lines));
+
+ TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine total_profile,
+ MaybeFind(parsed_profile_lines, "[total]"));
+ TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine dot_profile,
+ MaybeFind(parsed_profile_lines, "add"));
+ TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine tanh_profile,
+ MaybeFind(parsed_profile_lines, "tanh"));
EXPECT_GT(total_profile.cycles, 0);
EXPECT_EQ(total_profile.cycles_percentage, "100.00%");
+ EXPECT_TRUE(HasFlops(total_profile));
+ EXPECT_TRUE(HasTrops(total_profile));
+
EXPECT_GT(total_profile.cycles, dot_profile.cycles);
EXPECT_NE(dot_profile.cycles_percentage, "0.00%");
EXPECT_NE(dot_profile.cycles_percentage, "100.00%");
+ EXPECT_TRUE(HasFlops(dot_profile));
+ EXPECT_FALSE(HasTrops(dot_profile));
+
EXPECT_GT(total_profile.cycles, tanh_profile.cycles);
EXPECT_NE(tanh_profile.cycles_percentage, "0.00%");
EXPECT_NE(tanh_profile.cycles_percentage, "100.00%");
+
+ EXPECT_FALSE(HasFlops(tanh_profile));
+ EXPECT_TRUE(HasTrops(tanh_profile));
}
// TODO(b/71364943): This test exposes a bug in the parallel CPU backend.
@@ -220,7 +270,7 @@ XLA_TEST_F(HloProfileTest,
auto matrix = builder.GetTupleElement(state, 1);
auto next_iteration = builder.Add(builder.GetTupleElement(state, 0),
builder.ConstantR0<int32>(1));
- builder.Tuple({next_iteration, builder.Dot(matrix, matrix)});
+ builder.Tuple({next_iteration, builder.Add(matrix, matrix)});
TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
}
@@ -249,20 +299,23 @@ XLA_TEST_F(HloProfileTest,
ASSERT_NE(while_body_profile_start, profile_output_lines.end());
- TF_ASSERT_OK_AND_ASSIGN(
- ParsedProfileOutputLine total_while_body_profile,
- ParseProfileOutputLine(*std::next(while_body_profile_start, 1),
- /*expect_flops=*/false,
- /*expect_trops=*/false));
+ gtl::FlatMap<string, ParsedProfileOutputLine> parsed_profile_lines;
- TF_ASSERT_OK_AND_ASSIGN(
- ParsedProfileOutputLine dot_profile,
- ParseProfileOutputLine(*std::next(while_body_profile_start, 2),
- /*expect_flops=*/false,
- /*expect_trops=*/false));
+ TF_ASSERT_OK(
+ ParseOneProfileOutputLine(*std::next(while_body_profile_start, 1),
+ /*expect_hlo=*/false, &parsed_profile_lines));
+
+ TF_ASSERT_OK(
+ ParseOneProfileOutputLine(*std::next(while_body_profile_start, 2),
+ /*expect_hlo=*/true, &parsed_profile_lines));
+
+ TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine total_while_body_profile,
+ MaybeFind(parsed_profile_lines, "[total]"));
+ TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine dot_profile,
+ MaybeFind(parsed_profile_lines, "add"));
EXPECT_GT(total_while_body_profile.cycles, 0);
- EXPECT_EQ(total_while_body_profile.name, "[total]");
+ EXPECT_EQ(total_while_body_profile.opcode, "[total]");
EXPECT_EQ(total_while_body_profile.cycles_percentage, "100.00%");
EXPECT_GT(total_while_body_profile.cycles, dot_profile.cycles);