diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-03-05 11:20:28 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-05 11:24:40 -0800 |
commit | d93b843330593375907a554985c1f8ed77dae204 (patch) | |
tree | 5e209cb2d6ee0e0ebd5ce3909a77b5d4da364056 /tensorflow/compiler/xla/service/allocation_tracker.h | |
parent | 8382cbabf2a15f22d22a291fc47776113e6ec77c (diff) |
[XLA] Allocate and track memory in replicas separately.
PiperOrigin-RevId: 187894473
Diffstat (limited to 'tensorflow/compiler/xla/service/allocation_tracker.h')
-rw-r--r-- | tensorflow/compiler/xla/service/allocation_tracker.h | 44 |
1 files changed, 32 insertions, 12 deletions
diff --git a/tensorflow/compiler/xla/service/allocation_tracker.h b/tensorflow/compiler/xla/service/allocation_tracker.h index 807af86949..038aee8541 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.h +++ b/tensorflow/compiler/xla/service/allocation_tracker.h @@ -43,10 +43,17 @@ class AllocationTracker { AllocationTracker(Backend* backend) : backend_(backend), next_handle_(1) {} // Registers a shaped buffer of device memory, and returns a corresponding - // handle that can be used for talking to XLA clients. + // handle that can be used for talking to XLA clients. The given shaped buffer + // will be treated as the buffer corresponding to the only replica. StatusOr<GlobalDataHandle> Register( std::unique_ptr<ShapedBuffer> shaped_buffer, const string& tag); + // Registers a vector of shaped buffers of device memory, one per replica, and + // returns a corresponding handle that can be used for talking to XLA clients. + StatusOr<GlobalDataHandle> RegisterReplicatedBuffers( + std::vector<std::unique_ptr<ShapedBuffer>> replicated_buffers, + const string& tag); + // Unregister the allocation for the given data handle. Status Unregister(const GlobalDataHandle& data); @@ -54,9 +61,17 @@ class AllocationTracker { StatusOr<std::vector<GlobalDataHandle>> DeconstructTuple( const GlobalDataHandle& Data); - // Resolve a handle from an XLA client to a shaped buffer, or provide an error - // status to say whether it was not found (or found, but found deallocated). - StatusOr<const ShapedBuffer*> Resolve(const GlobalDataHandle& data); + // Resolve a handle from an XLA client to a vector of shaped buffers, one per + // replica, or provide an error status to say whether any of those buffers + // were not found (or found, but found deallocated). + StatusOr<std::vector<const ShapedBuffer*>> Resolve( + const GlobalDataHandle& data); + + // Resolves a handle from an XLA client and replica id to a shaped buffer, or + // provide an error status to say whether it was not found (or found, but + // found deallocated). + StatusOr<const ShapedBuffer*> ResolveForReplica(const GlobalDataHandle& data, + int replica_id); private: // Data structure encapsulating single memory allocation on the device. @@ -74,13 +89,17 @@ class AllocationTracker { // Internal helper which resolves the given GlobalDataHandle to a // ShapedBuffer. - StatusOr<ShapedBuffer*> ResolveInternal(const GlobalDataHandle& data) - EXCLUSIVE_LOCKS_REQUIRED(mutex_); + StatusOr<std::vector<const ShapedBuffer*>> ResolveInternal( + const GlobalDataHandle& data) EXCLUSIVE_LOCKS_REQUIRED(mutex_); - // Internal helper which registers a shaped buffer. + // Internal helper which registers a vector of shaped buffers, one per + // replica. StatusOr<GlobalDataHandle> RegisterInternal( - std::unique_ptr<ShapedBuffer> shaped_buffer, const string& tag) - EXCLUSIVE_LOCKS_REQUIRED(mutex_); + std::vector<std::unique_ptr<ShapedBuffer>> replicated_buffers, + const string& tag) EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Resets the shaped buffers corresponding to the given handle. + Status Reset(const GlobalDataHandle& data) EXCLUSIVE_LOCKS_REQUIRED(mutex_); // Adds the given device address to the allocation tracker, or if it already // exists, then increment it's reference count. @@ -111,9 +130,10 @@ class AllocationTracker { tensorflow::gtl::FlatMap<int, AllocationMap> opaque_to_allocation_map_ GUARDED_BY(mutex_); - // A map from data handle to ShapedBuffer. - tensorflow::gtl::FlatMap<int64, std::unique_ptr<ShapedBuffer>> - handle_to_shaped_buffer_ GUARDED_BY(mutex_); + // A map from data handle to a vector of shaped buffers that represent the + // buffers for different replicas. + tensorflow::gtl::FlatMap<int64, std::vector<std::unique_ptr<ShapedBuffer>>> + handle_to_shaped_buffers_ GUARDED_BY(mutex_); TF_DISALLOW_COPY_AND_ASSIGN(AllocationTracker); }; |