/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SHAPED_BUFFER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_SHAPED_BUFFER_H_

#include <memory>
#include <ostream>
#include <string>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

class ScopedShapedBuffer;

// Class which encapsulates a buffer or set of buffers containing data of a
// particular XLA shape.
class ShapedBuffer {
 public:
  // Construct a ShapedBuffer with null DeviceMemoryBases at each index. The
  // shape of the data on the host and the device may differ because the device
  // may have a different representation for different data types. Therefore,
  // both the on-host and on-device shape are required. The on-device shape
  // determines the number of device allocations (DeviceMemoryBase) held by the
  // ShapedBuffer.
  ShapedBuffer(const Shape& on_host_shape, const Shape& on_device_shape,
               const se::Platform* platform, int device_ordinal);
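
  // Illustrative usage sketch (not part of this header; `platform`, `mem0`,
  // and `mem1` are assumed to be provided by the caller):
  //
  //   Shape shape = ShapeUtil::MakeTupleShape(
  //       {ShapeUtil::MakeShape(F32, {16}), ShapeUtil::MakeShape(F32, {8})});
  //   ShapedBuffer buffer(/*on_host_shape=*/shape, /*on_device_shape=*/shape,
  //                       platform, /*device_ordinal=*/0);
  //   buffer.set_buffer(mem0, /*index=*/{0});  // se::DeviceMemoryBase values
  //   buffer.set_buffer(mem1, /*index=*/{1});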

  // Movable, but not copyable.
  ShapedBuffer(ShapedBuffer&& s);
  ShapedBuffer& operator=(ShapedBuffer&&);
  ShapedBuffer(const ShapedBuffer&) = delete;
  ShapedBuffer& operator=(const ShapedBuffer&) = delete;

  // Prevent (some forms of) accidental object slicing.
  ShapedBuffer(const ScopedShapedBuffer&) = delete;
  ShapedBuffer& operator=(const ScopedShapedBuffer&) = delete;

  virtual ~ShapedBuffer();

  // Returns the shape of the on-host representation of the data held by this
  // ShapedBuffer.
  const Shape& on_host_shape() const { return on_host_shape_; }

  // Returns the shape of the on-device representation of the data held by this
  // ShapedBuffer.
  const Shape& on_device_shape() const { return on_device_shape_; }

  const se::Platform* platform() const { return platform_; }
  int device_ordinal() const { return device_ordinal_; }

  // Returns the root buffer of the shape (shape index {}).
  const se::DeviceMemoryBase& root_buffer() const {
    return buffer(/*index=*/{});
  }

  // Returns the buffer at the given shape index where index is defined as in
  // ShapeUtil::GetSubshape.
  const se::DeviceMemoryBase& buffer(const ShapeIndex& index) const {
    return buffers_.element(index);
  }
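
  // Example (sketch; `sb` is an assumed tuple-shaped ShapedBuffer):
  //
  //   const se::DeviceMemoryBase& elem0 = sb.buffer(/*index=*/{0});
  //   const se::DeviceMemoryBase& elem1 = sb.buffer(/*index=*/{1});
  //   const se::DeviceMemoryBase& root = sb.root_buffer();  // index {}
  //
  // The set of valid indices is determined by the on-device shape.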

  // Sets the device memory buffer at the given index.
  void set_buffer(const se::DeviceMemoryBase& buffer, const ShapeIndex& index) {
    *buffers_.mutable_element(index) = buffer;
  }

  // Sets all buffers.
  //
  // Precondition: buffers.shape == on_device_shape_
  void set_buffers(ShapeTree<se::DeviceMemoryBase> buffers) {
    CHECK(ShapeUtil::Equal(buffers.shape(), on_device_shape_));
    buffers_ = std::move(buffers);
  }
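
  // Sketch of a call (assumes `sb` is a ShapedBuffer and `mem0` is an
  // se::DeviceMemoryBase provided by the caller):
  //
  //   ShapeTree<se::DeviceMemoryBase> tree(sb.on_device_shape());
  //   *tree.mutable_element(/*index=*/{0}) = mem0;
  //   sb.set_buffers(std::move(tree));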

  // Returns the underlying ShapeTree containing all the device addresses in the
  // ShapedBuffer.
  const ShapeTree<se::DeviceMemoryBase>& buffers() const { return buffers_; }
  ShapeTree<se::DeviceMemoryBase>& buffers() { return buffers_; }

  // Sets all device memory pointers in the object to null.
  void clear();

  string ToString() const;

 protected:
  // The shape of the data when represented on the host.
  Shape on_host_shape_;

  // The shape of the data on the device.
  Shape on_device_shape_;

  // The platform the memory is allocated on.
  const se::Platform* platform_;

  // The device the memory is allocated on.
  int device_ordinal_;

  // The tree of device buffers. Its shape is on_device_shape().
  ShapeTree<se::DeviceMemoryBase> buffers_;
};

std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer);

// ShapedBuffer derived class which allocates all internal buffers on
// construction and deallocates the memory when the object is
// destructed.
//
// TODO(timshen): Remove inheritance between ScopedShapedBuffer and
// ShapedBuffer.  There should never be a need to consider a ScopedShapedBuffer
// as a ShapedBuffer, because in that case we should just be able to pass around
// our ShapeTree<DeviceMemoryBase>.  Inheritance only adds complexity.  See
// discussion in cl/192849370.
class ScopedShapedBuffer : public ShapedBuffer {
 public:
  // Creates a ScopedShapedBuffer with null DeviceMemoryBases at each index.
  explicit ScopedShapedBuffer(const Shape& on_host_shape,
                              const Shape& on_device_shape,
                              DeviceMemoryAllocator* allocator,
                              int device_ordinal);

  // Creates a ScopedShapedBuffer by taking over the memory from the incoming
  // ShapedBuffer.
  explicit ScopedShapedBuffer(ShapedBuffer shaped_buffer,
                              DeviceMemoryAllocator* allocator);
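
  // Illustrative sketch (assumes a caller-provided `shaped_buffer` and
  // `allocator`); the adopted buffers are deallocated through `allocator`
  // when the ScopedShapedBuffer is destroyed:
  //
  //   ScopedShapedBuffer scoped(std::move(shaped_buffer), allocator);
  //   // ... use scoped.buffer(...) ...
  //   // All device memory is freed when `scoped` goes out of scope.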

  // Movable, but not copyable.
  ScopedShapedBuffer(ScopedShapedBuffer&& s);
  ScopedShapedBuffer& operator=(ScopedShapedBuffer&&);
  ScopedShapedBuffer(const ScopedShapedBuffer&) = delete;
  ScopedShapedBuffer& operator=(const ScopedShapedBuffer&) = delete;

  // All buffers in the shape are deallocated on destruction.
  ~ScopedShapedBuffer() override;

  // Returns the allocator used to allocate the device memory held in this
  // ScopedShapedBuffer.
  DeviceMemoryAllocator* memory_allocator() const { return allocator_; }

  // Sets the device memory buffer at the given index.
  //
  // If the given buffer's device memory is non-null, its device_ordinal and
  // allocator must match those in `this`.
  void set_buffer(OwningDeviceMemory buffer, const ShapeIndex& index) {
    if (!buffer.is_null()) {
      CHECK_EQ(buffer.device_ordinal(), device_ordinal());
      CHECK_EQ(buffer.allocator(), allocator_);
      *buffers_.mutable_element(index) = buffer.Forget();
    } else {
      *buffers_.mutable_element(index) = se::DeviceMemoryBase();
    }
  }

  // Like unique_ptr::release(), creates and returns a regular ShapedBuffer from
  // this ScopedShapedBuffer, without freeing any of the associated memory.
  //
  // It's the caller's job to ensure that the memory contained therein is freed.
  TF_MUST_USE_RESULT ShapedBuffer release();
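
  // Sketch of intended use (`scoped` is an assumed ScopedShapedBuffer); after
  // release() the caller owns the device memory and must free it, e.g. via the
  // allocator returned by memory_allocator():
  //
  //   ShapedBuffer released = scoped.release();
  //   // `released`'s buffers are no longer freed when `scoped` is destroyed.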

 protected:
  void Deallocate();

  DeviceMemoryAllocator* allocator_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SHAPED_BUFFER_H_