diff options
Diffstat (limited to 'tensorflow/compiler/xla/test_helpers.h')
-rw-r--r-- | tensorflow/compiler/xla/test_helpers.h | 355 |
1 files changed, 355 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/test_helpers.h b/tensorflow/compiler/xla/test_helpers.h new file mode 100644 index 0000000000..f923d9f36c --- /dev/null +++ b/tensorflow/compiler/xla/test_helpers.h @@ -0,0 +1,355 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TEST_HELPERS_H_ +#define TENSORFLOW_COMPILER_XLA_TEST_HELPERS_H_ + +#include <list> +#include <vector> + +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/regexp.h" +#include "tensorflow/core/platform/test.h" + +// This module contains a minimal subset of gmock functionality just +// sufficient to execute the currently existing tests. +namespace util { +class Status; +} // namespace util + +namespace xla { +template <typename T> +class Array2D; +class Literal; + +namespace testing { + +class AssertionResult { + public: + explicit AssertionResult(bool success) : success_(success) {} + + // Returns true iff the assertion succeeded. + operator bool() const { return success_; } // NOLINT + + // Returns the assertion's negation. Used with EXPECT/ASSERT_FALSE. + AssertionResult operator!() const; + + // Returns the text streamed into this AssertionResult. Test assertions + // use it when they fail (i.e., the predicate's outcome doesn't match the + // assertion's expectation). When nothing has been streamed into the + // object, returns an empty string. + const char* message() const { + return message_ != nullptr ? message_->c_str() : ""; + } + + // Streams a custom failure message into this object. + template <typename T> + AssertionResult& operator<<(const T& value) { + AppendMessage(::testing::Message() << value); + return *this; + } + + // Allows streaming basic output manipulators such as endl or flush into + // this object. + AssertionResult& operator<<( + std::ostream& (*basic_manipulator)(std::ostream& stream)) { + AppendMessage(::testing::Message() << basic_manipulator); + return *this; + } + + // Copy operator. + AssertionResult(const AssertionResult& ar); + + // Assignment operator. + AssertionResult& operator=(const AssertionResult&); + + private: + // Appends the contents of message to message_. + void AppendMessage(const ::testing::Message& a_message) { + if (message_ == nullptr) message_.reset(new std::string); + message_->append(a_message.GetString().c_str()); + } + + bool success_ = false; + + // Stores the message describing the condition in case the + // expectation construct is not satisfied with the predicate's + // outcome. Referenced via a pointer to avoid taking too much stack + // frame space with test assertions. + std::unique_ptr<std::string> message_; +}; + +AssertionResult AssertionFailure(); + +AssertionResult AssertionSuccess(); + +std::function<bool(tensorflow::StringPiece)> ContainsRegex( + const tensorflow::StringPiece regex); + +std::function<bool(tensorflow::StringPiece)> HasSubstr( + const tensorflow::StringPiece part); + +// Matcher for a vector of same-type values for which operator= is +// defined. +template <typename T> +std::function<AssertionResult(const std::vector<T>& actual)> VectorMatcher( + const std::vector<T>& expected) { + return [expected](const std::vector<T>& actual) -> AssertionResult { + int len = expected.size(); + if (actual.size() != len) { + return AssertionFailure() << "Actual values len of " << actual.size() + << " != expected.size " << len; + } + for (int i = 0; i < len; ++i) { + if (actual[i] != expected[i]) { + return AssertionFailure() << "Element " << i << " actual " << actual[i] + << " != " << expected[i]; + } + } + return AssertionSuccess(); + }; +} + +// Approximate matcher for a vector of floats or similar. +template <typename T> +std::function<AssertionResult(const std::vector<T>& actual)> +ApproxVectorMatcher(const std::vector<T>& expected, float abs_diff, + float rel_diff) { + return [abs_diff, rel_diff, + expected](const std::vector<T>& actual) -> AssertionResult { + int len = expected.size(); + if (actual.size() != len) { + AssertionResult ar = AssertionFailure() << "Actual values len of " + << actual.size() + << " != expected.size " << len; + LOG(ERROR) << ar.message(); + return ar; + } + for (int i = 0; i < len; ++i) { + T diff = actual[i] - expected[i]; + if (diff < 0) { + diff *= -1; + } + if (diff > abs_diff) { + T rdiff = (expected[i] != 0 ? diff / expected[i] : 0.0 * expected[i]); + if (rdiff > rel_diff) { + AssertionResult ar = AssertionFailure() + << "Element " << i << " actual " << actual[i] + << " != " << expected[i] + << "( abs_diff = " << diff + << ", rel_diff = " << rdiff << ")"; + LOG(ERROR) << ar.message(); + return ar; + } + } + } + return AssertionSuccess(); + }; +} + +// Matches a vector of same-type values against another, succeeding so +// long as they have the same length and every value in 'actual' +// matches one in 'expected.' Does not verify an exhaustive +// one-to-one mapping between the two. +template <typename T> +std::function<AssertionResult(const std::vector<T>& actual)> +UnorderedElementsAre(const std::vector<T>& expected) { + return [expected](const std::vector<T>& actual) -> AssertionResult { + if (actual.size() != expected.size()) { + return AssertionFailure() << "sizes don't match"; + } + for (auto a : actual) { + bool found = false; + for (auto e : expected) { + if (a == e) { + found = true; + break; + } + } + if (!found) { + return AssertionFailure() << "actual element " << a + << " not in expected"; + } + } + return AssertionSuccess(); + }; +} + +// Overloaded cover functions for UnorderedElementsAre, for the numbers +// of values used in practice. +template <typename T> +std::function<AssertionResult(const std::vector<T>& actual)> UnorderedMatcher( + T a) { + std::vector<T> expected; + expected.push_back(a); + return testing::UnorderedElementsAre<T>(expected); +} + +template <typename T> +std::function<AssertionResult(const std::vector<T>& actual)> UnorderedMatcher( + T a, T b) { + std::vector<T> expected; + expected.push_back(a); + expected.push_back(b); + return testing::UnorderedElementsAre<T>(expected); +} + +template <typename T> +std::function<AssertionResult(const std::vector<T>& actual)> UnorderedMatcher( + T a, T b, T c) { + std::vector<T> expected; + expected.push_back(a); + expected.push_back(b); + expected.push_back(c); + return testing::UnorderedElementsAre<T>(expected); +} + +template <typename T> +std::function<AssertionResult(const std::vector<T>& actual)> UnorderedMatcher( + T a, T b, T c, T d) { + std::vector<T> expected; + expected.push_back(a); + expected.push_back(b); + expected.push_back(c); + expected.push_back(d); + return testing::UnorderedElementsAre<T>(expected); +} + +template <typename T> +std::function<AssertionResult(const std::vector<T>& actual)> UnorderedMatcher( + T a, T b, T c, T d, T e) { + std::vector<T> expected; + expected.push_back(a); + expected.push_back(b); + expected.push_back(c); + expected.push_back(d); + expected.push_back(e); + return testing::UnorderedElementsAre<T>(expected); +} + +template <typename T> +std::function<AssertionResult(const std::vector<T>& actual)> UnorderedMatcher( + T a, T b, T c, T d, T e, T f) { + std::vector<T> expected; + expected.push_back(a); + expected.push_back(b); + expected.push_back(c); + expected.push_back(d); + expected.push_back(e); + expected.push_back(f); + return testing::UnorderedElementsAre<T>(expected); +} + +// Overloaded cover functions for VectorMatcher for the numbers of +// elements used in practice. +template <typename T> +std::function<AssertionResult(const std::vector<T>& actual)> OrderedMatcher( + T a) { + std::vector<T> expected; + expected.push_back(a); + return testing::VectorMatcher<T>(expected); +} + +template <typename T> +std::function<AssertionResult(const std::vector<T>& actual)> OrderedMatcher( + T a, T b) { + std::vector<T> expected; + expected.push_back(a); + expected.push_back(b); + return testing::VectorMatcher<T>(expected); +} + +template <typename T> +std::function<AssertionResult(const std::vector<T>& actual)> OrderedMatcher( + T a, T b, T c) { + std::vector<T> expected; + expected.push_back(a); + expected.push_back(b); + expected.push_back(c); + return testing::VectorMatcher<T>(expected); +} + +template <typename T> +std::function<AssertionResult(const std::vector<T>& actual)> OrderedMatcher( + T a, T b, T c, T d) { + std::vector<T> expected; + expected.push_back(a); + expected.push_back(b); + expected.push_back(c); + expected.push_back(d); + return testing::VectorMatcher<T>(expected); +} + +// Convert a RepeatedField to a flat vector. +template <typename T> +std::vector<T> PBToVec(const tensorflow::protobuf::RepeatedField<T> rf) { + return std::vector<T>(rf.begin(), rf.end()); +} + +// Convert a List to a flat vector. +template <typename T> +std::vector<T> ListToVec(const std::list<T>& l) { + return std::vector<T>(l.begin(), l.end()); +} + +// Convert a Set to a flat vector. +template <typename T> +std::vector<T> SetToVec(const std::set<T>& c) { + return std::vector<T>(c.begin(), c.end()); +} + +// Convert an Array to a flat vector. +template <typename T> +std::vector<T> Array2DToVec(const Array2D<T>& a) { + return std::vector<T>(a.data(), a.data() + a.num_elements()); +} + +namespace internal_status { +inline const ::tensorflow::Status& GetStatus( + const ::tensorflow::Status& status) { + return status; +} + +template <typename T> +inline const ::tensorflow::Status& GetStatus(const StatusOr<T>& status) { + return status.status(); +} +} // namespace internal_status + +} // namespace testing +} // namespace xla + +// The following macros are similar to macros in gmock, but deliberately named +// differently in order to avoid conflicts in files which include both. + +// Macros for testing the results of functions that return tensorflow::Status or +// StatusOr<T> (for any type T). +#define EXPECT_IS_OK(expression) \ + EXPECT_EQ(tensorflow::Status::OK(), \ + xla::testing::internal_status::GetStatus(expression)) +#undef ASSERT_IS_OK +#define ASSERT_IS_OK(expression) \ + ASSERT_EQ(tensorflow::Status::OK(), \ + xla::testing::internal_status::GetStatus(expression)) + +// Macros that apply a Matcher to a Value, returning an +// AssertionResult which gets digested by a standard gunit macro. +#define EXPECT_MATCH(V, M) EXPECT_TRUE((M)((V))) +#define ASSERT_MATCH(V, M) ASSERT_TRUE(M(V)) + +#endif // TENSORFLOW_COMPILER_XLA_TEST_HELPERS_H_ |