author     Justin Lebar <jlebar@google.com>  2018-04-22 14:48:05 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-04-22 14:50:48 -0700
commit     56fd856425f1322d22796decb1f0580c8fab5d5a
tree       b2a40e2e9180a4549c451d970585a2836ecaa3a4
parent     ea0c8a7ed84eb5cdf8ca6a856f9bd05a95597739
[XLA] Make Executable return a ScopedShapedBuffer.
Previously, we returned a plain ShapedBuffer. But this doesn't capture our semantics: it's up to the caller to free this ShapedBuffer.

PiperOrigin-RevId: 193854051
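The change encodes the ownership contract in the type system: a plain ShapedBuffer is a non-owning description of device memory, while a ScopedShapedBuffer frees its memory when destroyed. Below is a minimal sketch of that contract using simplified stand-in types, not the real XLA classes; DeviceMemory, the release() helper body, and the destructor logic here are illustrative assumptions.

// Simplified stand-ins for the real classes in
// tensorflow/compiler/xla/service/shaped_buffer.h.
struct DeviceMemory { void* ptr = nullptr; };

// A plain ShapedBuffer merely describes device memory; nothing frees it,
// so ownership lives in comments and caller discipline.
struct ShapedBuffer {
  DeviceMemory mem;
};

// A ScopedShapedBuffer owns its memory and deallocates on destruction,
// making "the recipient must free this" explicit in the type.
class ScopedShapedBuffer {
 public:
  explicit ScopedShapedBuffer(ShapedBuffer b) : buffer_(b) {}
  ~ScopedShapedBuffer() { Deallocate(); }

  // Move-only: exactly one owner at any time.
  ScopedShapedBuffer(ScopedShapedBuffer&& other) noexcept
      : buffer_(other.buffer_) {
    other.buffer_.mem.ptr = nullptr;
  }
  ScopedShapedBuffer(const ScopedShapedBuffer&) = delete;
  ScopedShapedBuffer& operator=(const ScopedShapedBuffer&) = delete;

  // Gives up ownership, e.g. so a tracker can take over deallocation.
  ShapedBuffer release() {
    ShapedBuffer b = buffer_;
    buffer_.mem.ptr = nullptr;
    return b;
  }

 private:
  void Deallocate() {
    // A real implementation would call its DeviceMemoryAllocator here.
    buffer_.mem.ptr = nullptr;
  }
  ShapedBuffer buffer_;
};

int main() {
  int dummy = 0;
  ScopedShapedBuffer owned(ShapedBuffer{DeviceMemory{&dummy}});
  // owned's destructor runs at scope exit; with a plain ShapedBuffer the
  // deallocation responsibility would be invisible in the signature.
}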
Diffstat (limited to 'tensorflow/compiler/xla/service/allocation_tracker.h')
-rw-r--r--  tensorflow/compiler/xla/service/allocation_tracker.h | 32
1 file changed, 22 insertions(+), 10 deletions(-)
diff --git a/tensorflow/compiler/xla/service/allocation_tracker.h b/tensorflow/compiler/xla/service/allocation_tracker.h
index 2bfcd53712..1174fa641c 100644
--- a/tensorflow/compiler/xla/service/allocation_tracker.h
+++ b/tensorflow/compiler/xla/service/allocation_tracker.h
@@ -45,13 +45,13 @@ class AllocationTracker {
   // Registers a shaped buffer of device memory, and returns a corresponding
   // handle that can be used for talking to XLA clients. The given shaped buffer
   // will be treated as the buffer corresponding to the only replica.
-  StatusOr<GlobalDataHandle> Register(ShapedBuffer shaped_buffer,
+  StatusOr<GlobalDataHandle> Register(ScopedShapedBuffer shaped_buffer,
                                       const string& tag);
 
   // Registers a vector of shaped buffers of device memory, one per replica, and
   // returns a corresponding handle that can be used for talking to XLA clients.
   StatusOr<GlobalDataHandle> RegisterReplicatedBuffers(
-      std::vector<ShapedBuffer> replicated_buffers, const string& tag);
+      std::vector<ScopedShapedBuffer> replicated_buffers, const string& tag);
 
   // Unregister the allocation for the given data handle.
   Status Unregister(const GlobalDataHandle& data);
@@ -87,21 +87,21 @@
   };
 
   // Internal helper which resolves the given GlobalDataHandle to a
-  // ShapedBuffer.
+  // list of ScopedShapedBuffers.
   StatusOr<std::vector<const ShapedBuffer*>> ResolveInternal(
       const GlobalDataHandle& data) EXCLUSIVE_LOCKS_REQUIRED(mutex_);
 
   // Internal helper which registers a vector of shaped buffers, one per
-  // replica.
+  // replica. ShapedBufferTy is either ScopedShapedBuffer or ShapedBuffer. If
+  // it's ShapedBuffer, all of the given buffers must already be tracked by this
+  // object -- presumably this is a call from DeconstructTuple.
+  template <typename ShapedBufferTy>
   StatusOr<GlobalDataHandle> RegisterInternal(
-      std::vector<ShapedBuffer> replicated_buffers, const string& tag)
+      std::vector<ShapedBufferTy> replicated_buffers, const string& tag)
       EXCLUSIVE_LOCKS_REQUIRED(mutex_);
 
-  // Resets the shaped buffers corresponding to the given handle.
-  Status Reset(const GlobalDataHandle& data) EXCLUSIVE_LOCKS_REQUIRED(mutex_);
-
   // Adds the given device address to the allocation tracker, or if it already
-  // exists, then increment it's reference count.
+  // exists, then increment its reference count.
   void AddAllocationOrIncrementRefCount(se::DeviceMemoryBase device_memory,
                                         int device_ordinal)
       EXCLUSIVE_LOCKS_REQUIRED(mutex_);
@@ -133,7 +133,19 @@
   // buffers for different replicas.
   //
   // The ShapedBuffers in this map's vectors need to be unique_ptrs, because our
-  // public API returns pointers to them.
+  // public API returns pointers to them. We expect the concrete class to be
+  // ShapedBuffer and never ScopedShapedBuffer; deallocation of buffers is
+  // handled by opaque_to_allocation_map_.
+  //
+  // The elements of the vectors need to be unique_ptrs because we return
+  // pointers to them. (In theory we could use std::list or something instead,
+  // but we also want to be able to null out these elements.)
+  //
+  // The reason that the elements can't be unique_ptr<ScopedShapedBuffer>s is
+  // the existence of DeconstructTuple(). This function allows us to create a
+  // non-owning "view" into a tuple's sub-buffers. The sub-buffers are then
+  // free'd when both the view *and* the original tuple are Unregistered. This
+  // refcounting is managed in opaque_to_allocation_map_.
   tensorflow::gtl::FlatMap<int64, std::vector<std::unique_ptr<ShapedBuffer>>>
       handle_to_shaped_buffers_ GUARDED_BY(mutex_);
 
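The refcounting described in the new comment can be sketched as follows. This is a simplified, hypothetical model: a toy map from raw device address to a count, reusing the AddAllocationOrIncrementRefCount name from the header for readability. The real tracker also keys on device ordinal and stores richer Allocation records in opaque_to_allocation_map_.

#include <cassert>
#include <cstdint>
#include <unordered_map>

// Toy model of opaque_to_allocation_map_: each device address carries a
// reference count. Registering a buffer (or a DeconstructTuple "view" of
// its sub-buffers) increments the count; unregistering decrements it, and
// the memory is freed only when the last reference goes away.
class ToyTracker {
 public:
  void AddAllocationOrIncrementRefCount(void* device_memory) {
    ++ref_counts_[device_memory];
  }

  void DecrementRefCount(void* device_memory) {
    auto it = ref_counts_.find(device_memory);
    assert(it != ref_counts_.end() && it->second > 0);
    if (--it->second == 0) {
      // A real implementation would deallocate the device memory here.
      ref_counts_.erase(it);
    }
  }

 private:
  std::unordered_map<void*, int64_t> ref_counts_;
};

int main() {
  ToyTracker tracker;
  int sub_buffer = 0;  // stands in for one sub-buffer of a tuple
  tracker.AddAllocationOrIncrementRefCount(&sub_buffer);  // register the tuple
  tracker.AddAllocationOrIncrementRefCount(&sub_buffer);  // DeconstructTuple view
  tracker.DecrementRefCount(&sub_buffer);  // unregister tuple: view keeps it alive
  tracker.DecrementRefCount(&sub_buffer);  // unregister view: count hits 0, freed
}

This is why handle_to_shaped_buffers_ can hold plain ShapedBuffers: the refcount map, not the buffers themselves, decides when device memory is actually freed.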