aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/local_client_test_base.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/local_client_test_base.h')
-rw-r--r--tensorflow/compiler/xla/tests/local_client_test_base.h146
1 files changed, 146 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h
new file mode 100644
index 0000000000..62916d50e3
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/local_client_test_base.h
@@ -0,0 +1,146 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_TESTS_LOCAL_CLIENT_TEST_BASE_H_
+#define TENSORFLOW_COMPILER_XLA_TESTS_LOCAL_CLIENT_TEST_BASE_H_
+
+#include <map>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+#include "tensorflow/compiler/xla/service/local_service.h"
+#include "tensorflow/compiler/xla/service/platform_util.h"
+#include "tensorflow/compiler/xla/service/shaped_buffer.h"
+#include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+class TestAllocator : public StreamExecutorMemoryAllocator {
+ public:
+ explicit TestAllocator(perftools::gputools::Platform* platform)
+ : StreamExecutorMemoryAllocator(
+ platform, PlatformUtil::GetStreamExecutors(platform).ValueOrDie()) {
+ }
+
+ StatusOr<perftools::gputools::DeviceMemoryBase> Allocate(
+ int device_ordinal, uint64 size, bool retry_on_failure) override;
+ tensorflow::Status Deallocate(
+ int device_ordinal, perftools::gputools::DeviceMemoryBase* mem) override;
+
+ // Return the number of allocations that have been performed.
+ int64 allocation_count() const;
+ int64 allocation_count(int device_ordinal) const;
+
+ // Return the number of deallocations that have been performed.
+ int64 deallocation_count() const;
+ int64 deallocation_count(int device_ordinal) const;
+
+ private:
+ mutable tensorflow::mutex count_mutex_;
+
+ // Global counts of allocations and deallocations.
+ int64 allocation_count_ GUARDED_BY(count_mutex_) = 0;
+ int64 deallocation_count_ GUARDED_BY(count_mutex_) = 0;
+
+ // Per-device counts of allocations and deallocations.
+ std::map<int, int64> device_allocation_count_ GUARDED_BY(count_mutex_);
+ std::map<int, int64> device_deallocation_count_ GUARDED_BY(count_mutex_);
+};
+
+// A base class for tests which exercise the LocalClient interface.
+class LocalClientTestBase : public ::testing::Test {
+ protected:
+ explicit LocalClientTestBase(
+ perftools::gputools::Platform* platform = nullptr);
+
+ static TestAllocator* GetOrCreateAllocator(
+ perftools::gputools::Platform* platform);
+
+ // Copy the given literal onto the default device and return a
+ // ScopedShapedBuffer.
+ std::unique_ptr<ScopedShapedBuffer> LiteralToScopedShapedBuffer(
+ const Literal& literal);
+ // As above, but copy to a specific device.
+ std::unique_ptr<ScopedShapedBuffer> LiteralToScopedShapedBuffer(
+ const Literal& literal, int device_ordinal);
+
+ // Construct and return a literal containing the array represented by
+ // shaped_buffer.
+ std::unique_ptr<Literal> ShapedBufferToLiteral(
+ const ShapedBuffer& shaped_buffer);
+
+ // Helper for converting a ShapedBuffer into a literal.
+ void CopyShapedBufferToLiteral(const ShapedBuffer& shaped_buffer,
+ ShapeIndex* index, Literal* literal);
+
+ // Execute the given computation on the local client. With and without
+ // options.
+ std::unique_ptr<ScopedShapedBuffer> ExecuteLocally(
+ const Computation& computation,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments);
+ std::unique_ptr<ScopedShapedBuffer> ExecuteLocally(
+ const Computation& computation,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const LocalExecuteOptions& options);
+
+ // Returns a default set of execute options, configured to use allocator_
+ // as the allocator.
+ LocalExecuteOptions DefaultLocalExecuteOptions() const;
+
+ // Overloads which write result into the given buffer.
+ void ExecuteLocally(
+ const Computation& computation,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ ShapedBuffer* result);
+ void ExecuteLocally(
+ const Computation& computation,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const LocalExecuteOptions& options, ShapedBuffer* result);
+
+ // Convert a ShapedBuffer into a ScopedShaped buffer so that all buffers are
+ // deallocated when the object is destructed.
+ std::unique_ptr<ScopedShapedBuffer> ShapedBufferToScopedShapedBuffer(
+ std::unique_ptr<ShapedBuffer> shaped_buffer,
+ DeviceMemoryAllocator* allocator);
+
+ string TestName() const {
+ return ::testing::UnitTest::GetInstance()->current_test_info()->name();
+ }
+
+ // The allocator must live as long as the service which lives until the end of
+ // the process, so make the allocator static.
+ static TestAllocator* allocator_;
+
+ perftools::gputools::StreamExecutor* stream_executor_;
+ TransferManager* transfer_manager_;
+
+ LocalClient* local_client_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_TESTS_LOCAL_CLIENT_TEST_BASE_H_