diff options
author | Justin Lebar <jlebar@google.com> | 2018-04-22 14:48:05 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-22 14:50:48 -0700 |
commit | 56fd856425f1322d22796decb1f0580c8fab5d5a (patch) | |
tree | b2a40e2e9180a4549c451d970585a2836ecaa3a4 /tensorflow/compiler/xla/service/allocation_tracker.h | |
parent | ea0c8a7ed84eb5cdf8ca6a856f9bd05a95597739 (diff) |
[XLA] Make Executable return a ScopedShapedBuffer.
Previously, we returned a plain ShapedBuffer. But this doesn't capture
our semantics: It's up to the callee to free this ShapedBuffer.
PiperOrigin-RevId: 193854051
Diffstat (limited to 'tensorflow/compiler/xla/service/allocation_tracker.h')
-rw-r--r-- | tensorflow/compiler/xla/service/allocation_tracker.h | 32 |
1 file changed, 22 insertions, 10 deletions
diff --git a/tensorflow/compiler/xla/service/allocation_tracker.h b/tensorflow/compiler/xla/service/allocation_tracker.h index 2bfcd53712..1174fa641c 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.h +++ b/tensorflow/compiler/xla/service/allocation_tracker.h @@ -45,13 +45,13 @@ class AllocationTracker { // Registers a shaped buffer of device memory, and returns a corresponding // handle that can be used for talking to XLA clients. The given shaped buffer // will be treated as the buffer corresponding to the only replica. - StatusOr<GlobalDataHandle> Register(ShapedBuffer shaped_buffer, + StatusOr<GlobalDataHandle> Register(ScopedShapedBuffer shaped_buffer, const string& tag); // Registers a vector of shaped buffers of device memory, one per replica, and // returns a corresponding handle that can be used for talking to XLA clients. StatusOr<GlobalDataHandle> RegisterReplicatedBuffers( - std::vector<ShapedBuffer> replicated_buffers, const string& tag); + std::vector<ScopedShapedBuffer> replicated_buffers, const string& tag); // Unregister the allocation for the given data handle. Status Unregister(const GlobalDataHandle& data); @@ -87,21 +87,21 @@ class AllocationTracker { }; // Internal helper which resolves the given GlobalDataHandle to a - // ShapedBuffer. + // list of ScopedShapedBuffers. StatusOr<std::vector<const ShapedBuffer*>> ResolveInternal( const GlobalDataHandle& data) EXCLUSIVE_LOCKS_REQUIRED(mutex_); // Internal helper which registers a vector of shaped buffers, one per - // replica. + // replica. ShapedBufferTy is either ScopedShapedBuffer or ShapedBuffer. If + // it's ShapedBuffer, all of the given buffers must already be tracked by this + // object -- presumably this is a call from DeconstructTuple. 
+ template <typename ShapedBufferTy> StatusOr<GlobalDataHandle> RegisterInternal( - std::vector<ShapedBuffer> replicated_buffers, const string& tag) + std::vector<ShapedBufferTy> replicated_buffers, const string& tag) EXCLUSIVE_LOCKS_REQUIRED(mutex_); - // Resets the shaped buffers corresponding to the given handle. - Status Reset(const GlobalDataHandle& data) EXCLUSIVE_LOCKS_REQUIRED(mutex_); - // Adds the given device address to the allocation tracker, or if it already - // exists, then increment it's reference count. + // exists, then increment its reference count. void AddAllocationOrIncrementRefCount(se::DeviceMemoryBase device_memory, int device_ordinal) EXCLUSIVE_LOCKS_REQUIRED(mutex_); @@ -133,7 +133,19 @@ class AllocationTracker { // buffers for different replicas. // // The ShapedBuffers in this map's vectors need to be unique_ptrs, because our - // public API returns pointers to them. + // public API returns pointers to them. We expect the concrete class to be + // ShapedBuffer and never ScopedShapedBuffer; deallocation of buffers is + // handled by opaque_to_allocation_map_. + // + // The elements of the vectors need to be unique_ptrs because we return + // pointers to them. (In theory we could use std::list or something instead, + // but we also want to be able to null out these elements.) + // + // The reason that the elements can't be unique_ptr<ScopedShapedBuffer>s is + // the existence of DeconstructTuple(). This function allows us to create a + // non-owning "view" into a tuple's sub-buffers. The sub-buffers are then + // free'd when both the view *and* the original tuple are Unregistered. This + // refcounting is managed in opaque_to_allocation_map_. tensorflow::gtl::FlatMap<int64, std::vector<std::unique_ptr<ShapedBuffer>>> handle_to_shaped_buffers_ GUARDED_BY(mutex_); |