aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/local_client_test_base.h
blob: 90095c5d410f1561a1303a0f62f44d22ed5340f9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_TESTS_LOCAL_CLIENT_TEST_BASE_H_
#define TENSORFLOW_COMPILER_XLA_TESTS_LOCAL_CLIENT_TEST_BASE_H_

#include <map>
#include <memory>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// A StreamExecutorMemoryAllocator used by tests that additionally records how
// many allocations and deallocations it has seen, both in total and per device
// ordinal (see the private counters below).
class TestAllocator : public StreamExecutorMemoryAllocator {
 public:
  // Builds an allocator over every StreamExecutor that
  // PlatformUtil::GetStreamExecutors() reports for `platform`. Because the
  // lookup result is unwrapped with ValueOrDie(), construction terminates the
  // process if the executors cannot be obtained.
  explicit TestAllocator(se::Platform* platform)
      : StreamExecutorMemoryAllocator(
            platform, PlatformUtil::GetStreamExecutors(platform).ValueOrDie()) {
  }

  // Overrides of the base allocator's entry points. Presumably they bump the
  // counters below and delegate to StreamExecutorMemoryAllocator — TODO confirm
  // against the out-of-line definitions.
  StatusOr<OwningDeviceMemory> Allocate(int device_ordinal, uint64 size,
                                        bool retry_on_failure) override;
  Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override;

  // Return the number of allocations that have been performed.
  // The overload taking `device_ordinal` presumably reports only that device's
  // count (backed by device_allocation_count_) — TODO confirm.
  int64 allocation_count() const;
  int64 allocation_count(int device_ordinal) const;

  // Return the number of deallocations that have been performed.
  // The overload taking `device_ordinal` presumably reports only that device's
  // count (backed by device_deallocation_count_) — TODO confirm.
  int64 deallocation_count() const;
  int64 deallocation_count(int device_ordinal) const;

 private:
  // Protects all counters below (see the GUARDED_BY annotations). Declared
  // `mutable` presumably so the const count accessors can acquire it.
  mutable tensorflow::mutex count_mutex_;

  // Global counts of allocations and deallocations.
  int64 allocation_count_ GUARDED_BY(count_mutex_) = 0;
  int64 deallocation_count_ GUARDED_BY(count_mutex_) = 0;

  // Per-device counts of allocations and deallocations, keyed by device
  // ordinal.
  std::map<int, int64> device_allocation_count_ GUARDED_BY(count_mutex_);
  std::map<int, int64> device_deallocation_count_ GUARDED_BY(count_mutex_);
};

// A base class for tests which exercise the LocalClient interface.
//
// Fixtures get access to a LocalClient (`local_client_`) together with a
// StreamExecutor and a TransferManager, plus helpers to move literals to and
// from device buffers and to execute computations. The members are presumably
// initialized by the out-of-line constructor — TODO confirm.
class LocalClientTestBase : public ::testing::Test {
 protected:
  // Wrapper around an Eigen thread pool (per its name). Only forward-declared
  // here; its definition lives out of line.
  struct EigenThreadPoolWrapper;

  // `platform` selects the device platform to test against. A nullptr value
  // presumably means "use the default platform" — TODO confirm against the
  // constructor definition.
  explicit LocalClientTestBase(se::Platform* platform = nullptr);

  // Defined out of line so that the destructor of `thread_pool_wrapper_`
  // (whose pointee type is incomplete in this header) is generated where
  // EigenThreadPoolWrapper is a complete type.
  virtual ~LocalClientTestBase();

  // Returns the process-wide TestAllocator for `platform`. By its name it
  // creates the allocator on first use; presumably it is the same object that
  // is stored in `allocator_` — TODO confirm.
  static TestAllocator* GetOrCreateAllocator(se::Platform* platform);

  // Copy the given literal onto the default device and return a
  // ScopedShapedBuffer. Convenience wrapper around
  // LocalClient::LiteralToShapedBuffer.
  ScopedShapedBuffer LiteralToShapedBuffer(const Literal& literal);

  // Construct and return a literal containing the array represented by
  // shaped_buffer.
  std::unique_ptr<Literal> ShapedBufferToLiteral(
      const ShapedBuffer& shaped_buffer);

  // Execute `computation` on the local client with the given `arguments`,
  // returning the result buffer or an error Status. The overload without
  // explicit options presumably uses DefaultExecutableBuildOptions() and
  // DefaultExecutableRunOptions() — TODO confirm.
  StatusOr<ScopedShapedBuffer> ExecuteLocally(
      const XlaComputation& computation,
      absl::Span<const ShapedBuffer* const> arguments);
  StatusOr<ScopedShapedBuffer> ExecuteLocally(
      const XlaComputation& computation,
      absl::Span<const ShapedBuffer* const> arguments,
      const ExecutableBuildOptions& build_options,
      const ExecutableRunOptions& run_options);

  // Like ExecuteLocally, but returns the result buffer directly. By its name it
  // presumably aborts on failure instead of returning a Status — TODO confirm.
  ScopedShapedBuffer ExecuteLocallyOrDie(
      const XlaComputation& computation,
      absl::Span<const ShapedBuffer* const> arguments);
  ScopedShapedBuffer ExecuteLocallyOrDie(
      const XlaComputation& computation,
      absl::Span<const ShapedBuffer* const> arguments,
      const ExecutableBuildOptions& build_options,
      const ExecutableRunOptions& run_options);

  // Returns a default set of build options.
  ExecutableBuildOptions DefaultExecutableBuildOptions() const;

  // Returns a default set of run options, configured to use allocator_
  // as the allocator.
  ExecutableRunOptions DefaultExecutableRunOptions() const;

  // Returns the name of the gtest test that is currently running.
  string TestName() const {
    return ::testing::UnitTest::GetInstance()->current_test_info()->name();
  }

  // The allocator must live as long as the service, which lives until the end
  // of the process. So make the allocator static.
  static TestAllocator* allocator_;

  // Raw pointers set up for the fixture (presumably non-owning — TODO confirm).
  // Default-initialized to nullptr so they are never read while indeterminate.
  se::StreamExecutor* stream_executor_ = nullptr;
  TransferManager* transfer_manager_ = nullptr;

  LocalClient* local_client_ = nullptr;

  std::unique_ptr<EigenThreadPoolWrapper> thread_pool_wrapper_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_TESTS_LOCAL_CLIENT_TEST_BASE_H_