diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/client_library_test_base.h')
-rw-r--r-- | tensorflow/compiler/xla/tests/client_library_test_base.h | 23 |
1 files changed, 4 insertions, 19 deletions
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index b578667735..7cfc276ec1 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -196,16 +196,6 @@ class ClientLibraryTestBase : public ::testing::Test { ComputationBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec abs_error); - // Convenience method for running a built computation and comparing the result - // with the HloEvaluator. - void ComputeAndCompare(ComputationBuilder* builder, - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice<Literal> arguments); - void ComputeAndCompare(ComputationBuilder* builder, - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice<Literal> arguments, - ErrorSpec error); - // Create scalar operations for use in reductions. Computation CreateScalarRelu(); Computation CreateScalarMax(); @@ -308,13 +298,6 @@ class ClientLibraryTestBase : public ::testing::Test { const std::function<void(const Literal& actual, const string& error_message)>& verify_output, const Shape* output_with_layout = nullptr); - - // Executes the computation and calculates the expected reference value using - // the HloEvaluator. Returns two literal in the order of (expected, actual). - StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>> - ComputeValueAndReference(ComputationBuilder* builder, - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice<Literal> arguments); }; template <typename NativeT> @@ -486,7 +469,8 @@ template <typename NativeT> std::vector<NativeT> ClientLibraryTestBase::CreatePseudorandomR1( const int width, NativeT min_value, NativeT max_value, uint32 seed) { std::vector<NativeT> result(width); - PseudorandomGenerator<NativeT> generator(min_value, max_value, seed); + test_utils::PseudorandomGenerator<NativeT> generator(min_value, max_value, + seed); for (int i = 0; i < width; ++i) { result[i] = generator.get(); } @@ -498,7 +482,8 @@ std::unique_ptr<Array2D<NativeT>> ClientLibraryTestBase::CreatePseudorandomR2( const int rows, const int cols, NativeT min_value, NativeT max_value, uint32 seed) { auto result = MakeUnique<Array2D<NativeT>>(rows, cols); - PseudorandomGenerator<NativeT> generator(min_value, max_value, seed); + test_utils::PseudorandomGenerator<NativeT> generator(min_value, max_value, + seed); for (int y = 0; y < rows; ++y) { for (int x = 0; x < cols; ++x) { (*result)(y, x) = generator.get(); |