diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/local_client_test_base.h')
-rw-r--r-- | tensorflow/compiler/xla/tests/local_client_test_base.h | 146 |
1 files changed, 146 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h new file mode 100644 index 0000000000..62916d50e3 --- /dev/null +++ b/tensorflow/compiler/xla/tests/local_client_test_base.h @@ -0,0 +1,146 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TESTS_LOCAL_CLIENT_TEST_BASE_H_ +#define TENSORFLOW_COMPILER_XLA_TESTS_LOCAL_CLIENT_TEST_BASE_H_ + +#include <map> +#include <memory> +#include <vector> + +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/local_service.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/service/transfer_manager.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +class TestAllocator : public StreamExecutorMemoryAllocator { + public: + explicit TestAllocator(perftools::gputools::Platform* platform) + : StreamExecutorMemoryAllocator( + platform, PlatformUtil::GetStreamExecutors(platform).ValueOrDie()) { + } + + StatusOr<perftools::gputools::DeviceMemoryBase> Allocate( + int device_ordinal, uint64 size, bool retry_on_failure) override; + tensorflow::Status Deallocate( + int device_ordinal, perftools::gputools::DeviceMemoryBase* mem) override; + + // Return the number of allocations that have been performed. + int64 allocation_count() const; + int64 allocation_count(int device_ordinal) const; + + // Return the number of deallocations that have been performed. + int64 deallocation_count() const; + int64 deallocation_count(int device_ordinal) const; + + private: + mutable tensorflow::mutex count_mutex_; + + // Global counts of allocations and deallocations. + int64 allocation_count_ GUARDED_BY(count_mutex_) = 0; + int64 deallocation_count_ GUARDED_BY(count_mutex_) = 0; + + // Per-device counts of allocations and deallocations. + std::map<int, int64> device_allocation_count_ GUARDED_BY(count_mutex_); + std::map<int, int64> device_deallocation_count_ GUARDED_BY(count_mutex_); +}; + +// A base class for tests which exercise the LocalClient interface. +class LocalClientTestBase : public ::testing::Test { + protected: + explicit LocalClientTestBase( + perftools::gputools::Platform* platform = nullptr); + + static TestAllocator* GetOrCreateAllocator( + perftools::gputools::Platform* platform); + + // Copy the given literal onto the default device and return a + // ScopedShapedBuffer. + std::unique_ptr<ScopedShapedBuffer> LiteralToScopedShapedBuffer( + const Literal& literal); + // As above, but copy to a specific device. + std::unique_ptr<ScopedShapedBuffer> LiteralToScopedShapedBuffer( + const Literal& literal, int device_ordinal); + + // Construct and return a literal containing the array represented by + // shaped_buffer. + std::unique_ptr<Literal> ShapedBufferToLiteral( + const ShapedBuffer& shaped_buffer); + + // Helper for converting a ShapedBuffer into a literal. + void CopyShapedBufferToLiteral(const ShapedBuffer& shaped_buffer, + ShapeIndex* index, Literal* literal); + + // Execute the given computation on the local client. With and without + // options. + std::unique_ptr<ScopedShapedBuffer> ExecuteLocally( + const Computation& computation, + tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments); + std::unique_ptr<ScopedShapedBuffer> ExecuteLocally( + const Computation& computation, + tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments, + const LocalExecuteOptions& options); + + // Returns a default set of execute options, configured to use allocator_ + // as the allocator. + LocalExecuteOptions DefaultLocalExecuteOptions() const; + + // Overloads which write result into the given buffer. + void ExecuteLocally( + const Computation& computation, + tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments, + ShapedBuffer* result); + void ExecuteLocally( + const Computation& computation, + tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments, + const LocalExecuteOptions& options, ShapedBuffer* result); + + // Convert a ShapedBuffer into a ScopedShaped buffer so that all buffers are + // deallocated when the object is destructed. + std::unique_ptr<ScopedShapedBuffer> ShapedBufferToScopedShapedBuffer( + std::unique_ptr<ShapedBuffer> shaped_buffer, + DeviceMemoryAllocator* allocator); + + string TestName() const { + return ::testing::UnitTest::GetInstance()->current_test_info()->name(); + } + + // The allocator must live as long as the service which lives until the end of + // the process, so make the allocator static. + static TestAllocator* allocator_; + + perftools::gputools::StreamExecutor* stream_executor_; + TransferManager* transfer_manager_; + + LocalClient* local_client_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_TESTS_LOCAL_CLIENT_TEST_BASE_H_ |