aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/allocation_tracker.h
blob: 98d1a302a9f66f4a00e05d62837a79133e222687 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ALLOCATION_TRACKER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_ALLOCATION_TRACKER_H_

#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// Tracks allocations for the XLA service; allocations can be registered
// with shape/device/tag and resolved from a handle for later use.
//
// All mutable tracker state below is annotated GUARDED_BY(mutex_); the private
// helpers require that mutex to be held by the caller.
class AllocationTracker {
 public:
  // Constructs a tracker bound to `backend`. The backend supplies the memory
  // allocator used to deallocate memory when allocations are deregistered, so
  // all registered allocations must have the same platform as the backend's
  // allocator. `backend` is stored as a raw pointer and is not deleted by this
  // class; it presumably must outlive the tracker (TODO confirm with owner).
  //
  // Marked explicit so a Backend* is never silently converted into a tracker.
  explicit AllocationTracker(Backend* backend)
      : backend_(backend), next_handle_(1) {}

  // Registers a shaped buffer of device memory, and returns a corresponding
  // handle that can be used for talking to XLA clients. The given shaped buffer
  // will be treated as the buffer corresponding to the only replica.
  StatusOr<GlobalDataHandle> Register(ScopedShapedBuffer shaped_buffer,
                                      const string& tag);

  // Registers a vector of shaped buffers of device memory, one per replica, and
  // returns a corresponding handle that can be used for talking to XLA clients.
  StatusOr<GlobalDataHandle> RegisterReplicatedBuffers(
      std::vector<ScopedShapedBuffer> replicated_buffers, const string& tag);

  // Unregisters the allocation for the given data handle.
  Status Unregister(const GlobalDataHandle& data);

  // Returns a vector of global data handles that point to the tuple elements
  // of `data`. The returned handles are non-owning views into the tuple's
  // sub-buffers; those sub-buffers are freed only once both the views and the
  // original tuple have been unregistered.
  StatusOr<std::vector<GlobalDataHandle>> DeconstructTuple(
      const GlobalDataHandle& data);

  // Resolves a handle from an XLA client to a vector of shaped buffers, one per
  // replica, or returns an error status indicating that one of those buffers
  // was not found (or was found, but already deallocated).
  StatusOr<std::vector<const ShapedBuffer*>> Resolve(
      const GlobalDataHandle& data) const;

  // Resolves a handle from an XLA client and a replica id to a shaped buffer,
  // or returns an error status indicating that it was not found (or was found,
  // but already deallocated).
  StatusOr<const ShapedBuffer*> ResolveForReplica(const GlobalDataHandle& data,
                                                  int replica_id) const;

 private:
  // Data structure encapsulating a single memory allocation on the device.
  struct Allocation {
    // The owned device memory backing this allocation.
    OwningDeviceMemory device_memory;

    // The number of registered data handles that refer to this memory
    // allocation.
    int ref_count;
  };

  // Internal helper which resolves the given GlobalDataHandle to a list of
  // ShapedBuffer pointers, one per replica.
  StatusOr<std::vector<const ShapedBuffer*>> ResolveInternal(
      const GlobalDataHandle& data) const EXCLUSIVE_LOCKS_REQUIRED(mutex_);

  // Internal helper which registers a vector of shaped buffers, one per
  // replica.  ShapedBufferTy is either ScopedShapedBuffer or ShapedBuffer.  If
  // it's ShapedBuffer, all of the given buffers must already be tracked by this
  // object -- presumably this is a call from DeconstructTuple.
  template <typename ShapedBufferTy>
  StatusOr<GlobalDataHandle> RegisterInternal(
      std::vector<ShapedBufferTy> replicated_buffers, const string& tag)
      EXCLUSIVE_LOCKS_REQUIRED(mutex_);

  // Adds the given device address to the allocation tracker or, if it is
  // already tracked, increments its reference count.
  void AddAllocationOrIncrementRefCount(se::DeviceMemoryBase device_memory,
                                        int device_ordinal)
      EXCLUSIVE_LOCKS_REQUIRED(mutex_);

  // Decrements the reference count of the given device memory. Then, if it
  // reaches zero, deallocates the memory.
  Status DecrementRefCount(se::DeviceMemoryBase device_memory,
                           int device_ordinal) EXCLUSIVE_LOCKS_REQUIRED(mutex_);

  // A map from device memory opaque value to allocation. One such map is
  // maintained per device ordinal.
  using AllocationMap = absl::flat_hash_map<const void*, Allocation>;

  // Guards the tracker's mutable state. Mutable so that const lookups
  // (Resolve/ResolveForReplica) can lock it.
  mutable tensorflow::mutex mutex_;

  // Backend to use with this tracker. The backend supplies the memory allocator
  // to use when deallocating memory.
  Backend* backend_;

  // The next handle to assign to an allocation, guarded by the same mutex as
  // the mapping as they'll be mutated at the same time.
  int64 next_handle_ GUARDED_BY(mutex_);

  // A map from device ordinal to AllocationMap.
  absl::flat_hash_map<int, AllocationMap> opaque_to_allocation_map_
      GUARDED_BY(mutex_);

  // A map from data handle to a vector of shaped buffers that represent the
  // buffers for different replicas.
  //
  // The ShapedBuffers in this map's vectors need to be unique_ptrs, because our
  // public API returns pointers to them.  We expect the concrete class to be
  // ShapedBuffer and never ScopedShapedBuffer; deallocation of buffers is
  // handled by opaque_to_allocation_map_.
  //
  // The elements of the vectors need to be unique_ptrs because we return
  // pointers to them.  (In theory we could use std::list or something instead,
  // but we also want to be able to null out these elements.)
  //
  // The reason that the elements can't be unique_ptr<ScopedShapedBuffer>s is
  // the existence of DeconstructTuple().  This function allows us to create a
  // non-owning "view" into a tuple's sub-buffers.  The sub-buffers are then
  // freed when both the view *and* the original tuple are Unregistered.  This
  // refcounting is managed in opaque_to_allocation_map_.
  absl::flat_hash_map<int64, std::vector<std::unique_ptr<ShapedBuffer>>>
      handle_to_shaped_buffers_ GUARDED_BY(mutex_);

  TF_DISALLOW_COPY_AND_ASSIGN(AllocationTracker);
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_ALLOCATION_TRACKER_H_