// path: root/tensorflow/stream_executor/temporary_device_memory.h
// blob: 4e7c63056b02acd3dec8d142468e513857f5c4d8
// Temporary memories are used to allocate scratch space required by an
// operation about to be enqueued onto a stream.
//
//    std::unique_ptr<TemporaryDeviceMemory<float>> temporary_memory =
//        stream.AllocateTemporaryArray<float>(1024).ConsumeValueOrDie();
//    // ... enqueue stuff onto the stream using the temporary memory ...
//    // Note that the memory is accessible via
//    // temporary_memory->device_memory() and similar.
//
//    // Finalize the temporary memory. The underlying device memory may
//    // be released any time after this program point, as another thread may
//    // call Stream::BlockHostUntilDone, causing synchronization. This
//    // finalization also happens automatically for the user if the unique_ptr
//    // goes out of scope.
//    temporary_memory->Finalize();
//
// WARNING: do NOT hold onto the device memory associated with temporary_memory
// after finalization. If temporary_memory->device_memory() is used after the
// temporary memory is finalized, it will cause a DCHECK failure.
//
// Note that standard usage takes advantage of the type-safe wrapper,
// TemporaryDeviceMemory<T>, defined below.
//
// Also see tests for executable sample usage.

#ifndef TENSORFLOW_STREAM_EXECUTOR_TEMPORARY_DEVICE_MEMORY_H_
#define TENSORFLOW_STREAM_EXECUTOR_TEMPORARY_DEVICE_MEMORY_H_

#include "tensorflow/stream_executor/device_memory.h"

namespace perftools {
namespace gputools {

// Forward declarations: this header refers to Stream only through pointers and
// to TemporaryMemoryManager only in a friend declaration, so neither needs its
// full definition here.
class Stream;
namespace internal {
class TemporaryMemoryManager;
}  // namespace internal

// Untyped base class (analogous to a void*) for temporary device memory
// allocations associated with a stream.
//
// The constructor is private, so instances can only be created by the friends
// declared below. The destructor is non-virtual; the typed wrapper
// TemporaryDeviceMemory<T> in this file statically asserts that it adds no
// data members on top of this class.
class TemporaryDeviceMemoryBase {
 public:
  // Marks the temporary memory as finalized if it is not already marked as
  // such.
  ~TemporaryDeviceMemoryBase();

  // Returns a mutable pointer to the untyped device memory.
  //
  // Precondition: !IsFinalized()
  DeviceMemoryBase* mutable_device_memory();

  // Returns a read-only reference to the untyped device memory.
  //
  // Precondition: !IsFinalized()
  const DeviceMemoryBase& device_memory() const;

  // "Finalizes" this temporary memory, making it acceptable to release at the
  // next stream synchronization point -- the device memory can be reclaimed at
  // any time after the temporary memory is marked as finalized (e.g. if a
  // separate thread calls Stream::BlockHostUntilDone). This may only be
  // called once -- see the precondition below.
  //
  // Precondition: !IsFinalized()
  void Finalize();

  // Returns true iff the temporary memory is finalized (that is, the user is
  // done referring to the temporary device memory, and thus it can be released
  // at the next stream synchronization point).
  bool IsFinalized() const;

  // Returns true iff the temporary memory is still allocated.
  //
  // Note: this is a polling call, no guarantee is made that the temporary
  // memory is still allocated after the call has completed.
  bool IsAllocated() const;

 private:
  // Friends are the only code able to reach the private constructor below.
  // NOTE(review): TemporaryDeviceMemoryTest is presumably a test fixture --
  // confirm against the tests.
  friend class internal::TemporaryMemoryManager;
  friend class TemporaryDeviceMemoryTest;

  // Note: construction DCHECKs that the memory is known-allocated in the
  // stream's temporary-allocation-manager.
  //
  // parent: the stream this temporary memory is allocated for.
  // device_memory: the device memory region backing this temporary.
  // allocation_generation: generation counter value for the corresponding
  //   record in the temporary memory manager.
  TemporaryDeviceMemoryBase(Stream* parent, DeviceMemoryBase device_memory,
                            uint64 allocation_generation);

  // The device memory region that has been allocated.
  DeviceMemoryBase device_memory_;

  // The generation counter value for the temporary memory record in the
  // temporary memory manager.
  uint64 allocation_generation_;

  // The stream that this temporary memory was allocated for.
  Stream* parent_;
};

// Typed view over TemporaryDeviceMemoryBase (which plays the role of a void*).
// Each accessor forwards to the untyped base accessor and reinterprets the
// result as DeviceMemory<T>. The wrapper must stay free of data members; this
// is checked at compile time whenever an accessor is instantiated.
template <typename T>
class TemporaryDeviceMemory : public TemporaryDeviceMemoryBase {
 public:
  // Typed counterpart of TemporaryDeviceMemoryBase::mutable_device_memory.
  DeviceMemory<T>* mutable_device_memory() {
    CheckNoExtraMembers();
    auto* untyped = TemporaryDeviceMemoryBase::mutable_device_memory();
    return reinterpret_cast<DeviceMemory<T>*>(untyped);
  }

  // Typed counterpart of TemporaryDeviceMemoryBase::device_memory.
  const DeviceMemory<T>& device_memory() const {
    CheckNoExtraMembers();
    const auto& untyped = TemporaryDeviceMemoryBase::device_memory();
    return reinterpret_cast<const DeviceMemory<T>&>(untyped);
  }

 private:
  // Compile-time guard only; the body has no runtime effect. It lives in a
  // member function so that the class type is complete where sizeof is
  // evaluated.
  static void CheckNoExtraMembers() {
    static_assert(
        sizeof(TemporaryDeviceMemory<T>) == sizeof(TemporaryDeviceMemoryBase),
        "derived class is simply a wrapper, no members may be added due to "
        "slicing");
  }
};

}  // namespace gputools
}  // namespace perftools

#endif  // TENSORFLOW_STREAM_EXECUTOR_TEMPORARY_DEVICE_MEMORY_H_