diff options
-rw-r--r-- | tensorflow/compiler/xla/tests/build_defs.bzl | 6 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/test_macros.cc | 97 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/test_macros.h | 96 |
3 files changed, 189 insertions, 10 deletions
diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl index dae0956f0a..7b707cd360 100644 --- a/tensorflow/compiler/xla/tests/build_defs.bzl +++ b/tensorflow/compiler/xla/tests/build_defs.bzl @@ -230,8 +230,12 @@ def generate_backend_test_macros(backends=[]): native.cc_library( name="test_macros_%s" % backend, testonly = True, + srcs = ["test_macros.cc"], hdrs = ["test_macros.h"], - copts = ["-DXLA_PLATFORM=\\\"%s\\\"" % backend.upper()], + copts = [ + "-DXLA_PLATFORM=\\\"%s\\\"" % backend.upper(), + "-DXLA_DISABLED_MANIFEST=\\\"\\\"" + ], deps = [ "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/tests/test_macros.cc b/tensorflow/compiler/xla/tests/test_macros.cc new file mode 100644 index 0000000000..173fb1b000 --- /dev/null +++ b/tensorflow/compiler/xla/tests/test_macros.cc @@ -0,0 +1,97 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tests/test_macros.h" + +#include <fstream> +#include <streambuf> +#include <string> +#include <unordered_map> + +#include "tensorflow/core/lib/strings/str_util.h" + +namespace xla { +namespace { + +// Mapping from test name; i.e. MyTest.MyTestCase to platforms on which it is +// disabled. +using ManifestT = std::unordered_map<string, std::vector<string>>; + +ManifestT ReadManifest() { + ManifestT manifest; + + string path = XLA_DISABLED_MANIFEST; + if (path.empty()) { + return manifest; + } + + std::ifstream file_stream(path); + // Note: parens are required to disambiguate vs function decl. + string contents((std::istreambuf_iterator<char>(file_stream)), + std::istreambuf_iterator<char>()); + + std::vector<string> lines = tensorflow::str_util::Split(contents, '\n'); + for (string& line : lines) { + auto comment = line.find("//"); + if (comment != string::npos) { + line = line.substr(0, comment); + } + if (line.empty()) { + continue; + } + tensorflow::str_util::StripTrailingWhitespace(&line); + std::vector<string> pieces = tensorflow::str_util::Split(line, ' '); + CHECK_GE(pieces.size(), 1); + auto& platforms = manifest[pieces[0]]; + for (int64 i = 1; i < pieces.size(); ++i) { + platforms.push_back(pieces[i]); + } + } + return manifest; +} + +} // namespace + +string PrependDisabledIfIndicated(const string& test_case_name, + const string& test_name) { + // TODO(leary): this code reads the manifest for every test case instantiated + // in every file. Consider switching to a singleton or using a compile-time + // genrule instead. + ManifestT manifest = ReadManifest(); + + // First try full match: test_case_name.test_name + // If that fails, try to find just the test_case_name; this would disable all + // tests in the test case. + auto it = manifest.find( + tensorflow::strings::StrCat(test_case_name, ".", test_name)); + if (it == manifest.end()) { + it = manifest.find(test_case_name); + if (it == manifest.end()) { + return test_name; + } + } + + const std::vector<string>& disabled_platforms = it->second; + string platform_string = XLA_PLATFORM; + if (std::find(disabled_platforms.begin(), disabled_platforms.end(), + platform_string) != disabled_platforms.end()) { + return "DISABLED_" + test_name; + } + + // We didn't hit in the disabled manifest entries, so don't disable it. + return test_name; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/test_macros.h b/tensorflow/compiler/xla/tests/test_macros.h index 7f987a21ca..3878ac1013 100644 --- a/tensorflow/compiler/xla/tests/test_macros.h +++ b/tensorflow/compiler/xla/tests/test_macros.h @@ -33,13 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/test.h" -// Use this macro instead of directly using TEST_P for parameterized tests, -// otherwise DISABLED_ON_* macros nested in TEST_P will not get expanded since -// TEST_P stringifies its argument. That makes the test disabled for all targets -// when any one of the DISABLED_ON_* macro is used, and the test will just pass. -// TODO(b/29122096): Remove this once TEST_P fixes this problem. -#define XLA_TEST_P(test_case_name, test_name) TEST_P(test_case_name, test_name) - #define DISABLED_ON_CPU(X) X #define DISABLED_ON_CPU_PARALLEL(X) X #define DISABLED_ON_GPU(X) X @@ -71,6 +64,91 @@ limitations under the License. // clang-format on -#define XLA_TEST_F(test_fixture, test_name) TEST_F(test_fixture, test_name) - +namespace xla { + +// Reads a disabled manifest file (and retains it as a singleton) to resolve +// whether test cases should be disabled on a particular platform. +string PrependDisabledIfIndicated(const string& test_case_name, + const string& test_name); + +} // namespace xla + +// This is the internal "gtest" class instantiation -- it is identical to the +// GTEST_TEST_ macro, except that we intercept the test name for potential +// modification by PrependDisabledIfIndicated. That file can use an arbitrary +// heuristic to decide whether the test case should be disabled, and we +// determine whether the test case should be disabled by resolving the (test +// case name, test name) in a manifest file. +#define XLA_GTEST_TEST_(test_case_name, test_name, parent_class, parent_id) \ + class GTEST_TEST_CLASS_NAME_(test_case_name, test_name) \ + : public parent_class { \ + public: \ + GTEST_TEST_CLASS_NAME_(test_case_name, test_name)() {} \ + \ + private: \ + virtual void TestBody(); \ + static ::testing::TestInfo* const test_info_ GTEST_ATTRIBUTE_UNUSED_; \ + GTEST_DISALLOW_COPY_AND_ASSIGN_(GTEST_TEST_CLASS_NAME_(test_case_name, \ + test_name)); \ + }; \ + \ + ::testing::TestInfo* const GTEST_TEST_CLASS_NAME_(test_case_name, \ + test_name)::test_info_ = \ + ::testing::internal::MakeAndRegisterTestInfo( \ + #test_case_name, \ + PrependDisabledIfIndicated(#test_case_name, #test_name).c_str(), \ + nullptr, nullptr, \ + ::testing::internal::CodeLocation(__FILE__, __LINE__), (parent_id), \ + parent_class::SetUpTestCase, parent_class::TearDownTestCase, \ + new ::testing::internal::TestFactoryImpl<GTEST_TEST_CLASS_NAME_( \ + test_case_name, test_name)>); \ + void GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::TestBody() + +// This is identical to the TEST_F macro from "gtest", but it potentially +// disables the test based on an external manifest file, DISABLED_MANIFEST. +// +// Per usual, you can see what tests are available via --gunit_list_tests and +// choose to run tests that have been disabled via the manifest via +// --gunit_also_run_disabled_tests. +#define XLA_TEST_F(test_fixture, test_name) \ + XLA_GTEST_TEST_(test_fixture, test_name, test_fixture, \ + ::testing::internal::GetTypeId<test_fixture>()) + +// Likewise, this is identical to the TEST_P macro from "gtest", but +// potentially disables the test based on the DISABLED_MANIFEST file. +// +// We have to wrap this in an outer layer so that any DISABLED_ON_* macros will +// be properly expanded before the stringification occurs. +#define XLA_TEST_P_IMPL_(test_case_name, test_name) \ + class GTEST_TEST_CLASS_NAME_(test_case_name, test_name) \ + : public test_case_name { \ + public: \ + GTEST_TEST_CLASS_NAME_(test_case_name, test_name)() {} \ + virtual void TestBody(); \ + \ + private: \ + static int AddToRegistry() { \ + ::testing::UnitTest::GetInstance() \ + ->parameterized_test_registry() \ + .GetTestCasePatternHolder<test_case_name>( \ + #test_case_name, \ + ::testing::internal::CodeLocation(__FILE__, __LINE__)) \ + ->AddTestPattern( \ + #test_case_name, \ + PrependDisabledIfIndicated(#test_case_name, #test_name).c_str(), \ + new ::testing::internal::TestMetaFactory<GTEST_TEST_CLASS_NAME_( \ + test_case_name, test_name)>()); \ + return 0; \ + } \ + static int gtest_registering_dummy_ GTEST_ATTRIBUTE_UNUSED_; \ + GTEST_DISALLOW_COPY_AND_ASSIGN_(GTEST_TEST_CLASS_NAME_(test_case_name, \ + test_name)); \ + }; \ + int GTEST_TEST_CLASS_NAME_(test_case_name, \ + test_name)::gtest_registering_dummy_ = \ + GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::AddToRegistry(); \ + void GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::TestBody() + +#define XLA_TEST_P(test_case_name, test_name) \ + XLA_TEST_P_IMPL_(test_case_name, test_name) #endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_ |