aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2017-08-03 08:06:31 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-03 08:09:45 -0700
commitb9b45d21aff32a1252ec1929330ae5ebd1e5571f (patch)
tree09187708bc3eef0ede7995f4b1680efd8e8d71be
parentba7516eed03d839f8cc1800948e884a117a88f2d (diff)
[XLA] Add test blacklist mechanism for XLA C++ unit tests.
PiperOrigin-RevId: 164124423
-rw-r--r--tensorflow/compiler/xla/tests/build_defs.bzl6
-rw-r--r--tensorflow/compiler/xla/tests/test_macros.cc97
-rw-r--r--tensorflow/compiler/xla/tests/test_macros.h96
3 files changed, 189 insertions, 10 deletions
diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl
index dae0956f0a..7b707cd360 100644
--- a/tensorflow/compiler/xla/tests/build_defs.bzl
+++ b/tensorflow/compiler/xla/tests/build_defs.bzl
@@ -230,8 +230,12 @@ def generate_backend_test_macros(backends=[]):
native.cc_library(
name="test_macros_%s" % backend,
testonly = True,
+ srcs = ["test_macros.cc"],
hdrs = ["test_macros.h"],
- copts = ["-DXLA_PLATFORM=\\\"%s\\\"" % backend.upper()],
+ copts = [
+ "-DXLA_PLATFORM=\\\"%s\\\"" % backend.upper(),
+ "-DXLA_DISABLED_MANIFEST=\\\"\\\""
+ ],
deps = [
"//tensorflow/compiler/xla:types",
"//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/tests/test_macros.cc b/tensorflow/compiler/xla/tests/test_macros.cc
new file mode 100644
index 0000000000..173fb1b000
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/test_macros.cc
@@ -0,0 +1,97 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+
+#include <fstream>
+#include <streambuf>
+#include <string>
+#include <unordered_map>
+
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace xla {
+namespace {
+
+// Mapping from test name; i.e. MyTest.MyTestCase to platforms on which it is
+// disabled.
+using ManifestT = std::unordered_map<string, std::vector<string>>;
+
+ManifestT ReadManifest() {
+ ManifestT manifest;
+
+ string path = XLA_DISABLED_MANIFEST;
+ if (path.empty()) {
+ return manifest;
+ }
+
+ std::ifstream file_stream(path);
+ // Note: parens are required to disambiguate vs function decl.
+ string contents((std::istreambuf_iterator<char>(file_stream)),
+ std::istreambuf_iterator<char>());
+
+ std::vector<string> lines = tensorflow::str_util::Split(contents, '\n');
+ for (string& line : lines) {
+ auto comment = line.find("//");
+ if (comment != string::npos) {
+ line = line.substr(0, comment);
+ }
+ if (line.empty()) {
+ continue;
+ }
+ tensorflow::str_util::StripTrailingWhitespace(&line);
+ std::vector<string> pieces = tensorflow::str_util::Split(line, ' ');
+ CHECK_GE(pieces.size(), 1);
+ auto& platforms = manifest[pieces[0]];
+ for (int64 i = 1; i < pieces.size(); ++i) {
+ platforms.push_back(pieces[i]);
+ }
+ }
+ return manifest;
+}
+
+} // namespace
+
+string PrependDisabledIfIndicated(const string& test_case_name,
+ const string& test_name) {
+ // TODO(leary): this code reads the manifest for every test case instantiated
+ // in every file. Consider switching to a singleton or using a compile-time
+ // genrule instead.
+ ManifestT manifest = ReadManifest();
+
+ // First try full match: test_case_name.test_name
+ // If that fails, try to find just the test_case_name; this would disable all
+ // tests in the test case.
+ auto it = manifest.find(
+ tensorflow::strings::StrCat(test_case_name, ".", test_name));
+ if (it == manifest.end()) {
+ it = manifest.find(test_case_name);
+ if (it == manifest.end()) {
+ return test_name;
+ }
+ }
+
+ const std::vector<string>& disabled_platforms = it->second;
+ string platform_string = XLA_PLATFORM;
+ if (std::find(disabled_platforms.begin(), disabled_platforms.end(),
+ platform_string) != disabled_platforms.end()) {
+ return "DISABLED_" + test_name;
+ }
+
+ // We didn't hit in the disabled manifest entries, so don't disable it.
+ return test_name;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/test_macros.h b/tensorflow/compiler/xla/tests/test_macros.h
index 7f987a21ca..3878ac1013 100644
--- a/tensorflow/compiler/xla/tests/test_macros.h
+++ b/tensorflow/compiler/xla/tests/test_macros.h
@@ -33,13 +33,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/test.h"
-// Use this macro instead of directly using TEST_P for parameterized tests,
-// otherwise DISABLED_ON_* macros nested in TEST_P will not get expanded since
-// TEST_P stringifies its argument. That makes the test disabled for all targets
-// when any one of the DISABLED_ON_* macro is used, and the test will just pass.
-// TODO(b/29122096): Remove this once TEST_P fixes this problem.
-#define XLA_TEST_P(test_case_name, test_name) TEST_P(test_case_name, test_name)
-
#define DISABLED_ON_CPU(X) X
#define DISABLED_ON_CPU_PARALLEL(X) X
#define DISABLED_ON_GPU(X) X
@@ -71,6 +64,91 @@ limitations under the License.
// clang-format on
-#define XLA_TEST_F(test_fixture, test_name) TEST_F(test_fixture, test_name)
-
+namespace xla {
+
+// Reads a disabled manifest file (and retains it as a singleton) to resolve
+// whether test cases should be disabled on a particular platform.
+string PrependDisabledIfIndicated(const string& test_case_name,
+ const string& test_name);
+
+} // namespace xla
+
+// This is the internal "gtest" class instantiation -- it is identical to the
+// GTEST_TEST_ macro, except that we intercept the test name for potential
+// modification by PrependDisabledIfIndicated. That file can use an arbitrary
+// heuristic to decide whether the test case should be disabled, and we
+// determine whether the test case should be disabled by resolving the (test
+// case name, test name) in a manifest file.
+#define XLA_GTEST_TEST_(test_case_name, test_name, parent_class, parent_id) \
+ class GTEST_TEST_CLASS_NAME_(test_case_name, test_name) \
+ : public parent_class { \
+ public: \
+ GTEST_TEST_CLASS_NAME_(test_case_name, test_name)() {} \
+ \
+ private: \
+ virtual void TestBody(); \
+ static ::testing::TestInfo* const test_info_ GTEST_ATTRIBUTE_UNUSED_; \
+ GTEST_DISALLOW_COPY_AND_ASSIGN_(GTEST_TEST_CLASS_NAME_(test_case_name, \
+ test_name)); \
+ }; \
+ \
+ ::testing::TestInfo* const GTEST_TEST_CLASS_NAME_(test_case_name, \
+ test_name)::test_info_ = \
+ ::testing::internal::MakeAndRegisterTestInfo( \
+ #test_case_name, \
+ PrependDisabledIfIndicated(#test_case_name, #test_name).c_str(), \
+ nullptr, nullptr, \
+ ::testing::internal::CodeLocation(__FILE__, __LINE__), (parent_id), \
+ parent_class::SetUpTestCase, parent_class::TearDownTestCase, \
+ new ::testing::internal::TestFactoryImpl<GTEST_TEST_CLASS_NAME_( \
+ test_case_name, test_name)>); \
+ void GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::TestBody()
+
+// This is identical to the TEST_F macro from "gtest", but it potentially
+// disables the test based on an external manifest file, DISABLED_MANIFEST.
+//
+// Per usual, you can see what tests are available via --gunit_list_tests and
+// choose to run tests that have been disabled via the manifest via
+// --gunit_also_run_disabled_tests.
+#define XLA_TEST_F(test_fixture, test_name) \
+ XLA_GTEST_TEST_(test_fixture, test_name, test_fixture, \
+ ::testing::internal::GetTypeId<test_fixture>())
+
+// Likewise, this is identical to the TEST_P macro from "gtest", but
+// potentially disables the test based on the DISABLED_MANIFEST file.
+//
+// We have to wrap this in an outer layer so that any DISABLED_ON_* macros will
+// be properly expanded before the stringification occurs.
+#define XLA_TEST_P_IMPL_(test_case_name, test_name) \
+ class GTEST_TEST_CLASS_NAME_(test_case_name, test_name) \
+ : public test_case_name { \
+ public: \
+ GTEST_TEST_CLASS_NAME_(test_case_name, test_name)() {} \
+ virtual void TestBody(); \
+ \
+ private: \
+ static int AddToRegistry() { \
+ ::testing::UnitTest::GetInstance() \
+ ->parameterized_test_registry() \
+ .GetTestCasePatternHolder<test_case_name>( \
+ #test_case_name, \
+ ::testing::internal::CodeLocation(__FILE__, __LINE__)) \
+ ->AddTestPattern( \
+ #test_case_name, \
+ PrependDisabledIfIndicated(#test_case_name, #test_name).c_str(), \
+ new ::testing::internal::TestMetaFactory<GTEST_TEST_CLASS_NAME_( \
+ test_case_name, test_name)>()); \
+ return 0; \
+ } \
+ static int gtest_registering_dummy_ GTEST_ATTRIBUTE_UNUSED_; \
+ GTEST_DISALLOW_COPY_AND_ASSIGN_(GTEST_TEST_CLASS_NAME_(test_case_name, \
+ test_name)); \
+ }; \
+ int GTEST_TEST_CLASS_NAME_(test_case_name, \
+ test_name)::gtest_registering_dummy_ = \
+ GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::AddToRegistry(); \
+ void GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::TestBody()
+
+#define XLA_TEST_P(test_case_name, test_name) \
+ XLA_TEST_P_IMPL_(test_case_name, test_name)
#endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_