aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/client_library_test_base.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/client_library_test_base.h')
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.h23
1 files changed, 4 insertions, 19 deletions
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index b578667735..7cfc276ec1 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -196,16 +196,6 @@ class ClientLibraryTestBase : public ::testing::Test {
ComputationBuilder* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec abs_error);
- // Convenience method for running a built computation and comparing the result
- // with the HloEvaluator.
- void ComputeAndCompare(ComputationBuilder* builder,
- const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<Literal> arguments);
- void ComputeAndCompare(ComputationBuilder* builder,
- const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<Literal> arguments,
- ErrorSpec error);
-
// Create scalar operations for use in reductions.
Computation CreateScalarRelu();
Computation CreateScalarMax();
@@ -308,13 +298,6 @@ class ClientLibraryTestBase : public ::testing::Test {
const std::function<void(const Literal& actual,
const string& error_message)>& verify_output,
const Shape* output_with_layout = nullptr);
-
- // Executes the computation and calculates the expected reference value using
- // the HloEvaluator. Returns two literal in the order of (expected, actual).
- StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
- ComputeValueAndReference(ComputationBuilder* builder,
- const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<Literal> arguments);
};
template <typename NativeT>
@@ -486,7 +469,8 @@ template <typename NativeT>
std::vector<NativeT> ClientLibraryTestBase::CreatePseudorandomR1(
const int width, NativeT min_value, NativeT max_value, uint32 seed) {
std::vector<NativeT> result(width);
- PseudorandomGenerator<NativeT> generator(min_value, max_value, seed);
+ test_utils::PseudorandomGenerator<NativeT> generator(min_value, max_value,
+ seed);
for (int i = 0; i < width; ++i) {
result[i] = generator.get();
}
@@ -498,7 +482,8 @@ std::unique_ptr<Array2D<NativeT>> ClientLibraryTestBase::CreatePseudorandomR2(
const int rows, const int cols, NativeT min_value, NativeT max_value,
uint32 seed) {
auto result = MakeUnique<Array2D<NativeT>>(rows, cols);
- PseudorandomGenerator<NativeT> generator(min_value, max_value, seed);
+ test_utils::PseudorandomGenerator<NativeT> generator(min_value, max_value,
+ seed);
for (int y = 0; y < rows; ++y) {
for (int x = 0; x < cols; ++x) {
(*result)(y, x) = generator.get();