aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/allocation_tracker.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-05 11:20:28 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-05 11:24:40 -0800
commitd93b843330593375907a554985c1f8ed77dae204 (patch)
tree5e209cb2d6ee0e0ebd5ce3909a77b5d4da364056 /tensorflow/compiler/xla/service/allocation_tracker.h
parent8382cbabf2a15f22d22a291fc47776113e6ec77c (diff)
[XLA] Allocate and track memory in replicas separately.
PiperOrigin-RevId: 187894473
Diffstat (limited to 'tensorflow/compiler/xla/service/allocation_tracker.h')
-rw-r--r--tensorflow/compiler/xla/service/allocation_tracker.h44
1 files changed, 32 insertions, 12 deletions
diff --git a/tensorflow/compiler/xla/service/allocation_tracker.h b/tensorflow/compiler/xla/service/allocation_tracker.h
index 807af86949..038aee8541 100644
--- a/tensorflow/compiler/xla/service/allocation_tracker.h
+++ b/tensorflow/compiler/xla/service/allocation_tracker.h
@@ -43,10 +43,17 @@ class AllocationTracker {
AllocationTracker(Backend* backend) : backend_(backend), next_handle_(1) {}
// Registers a shaped buffer of device memory, and returns a corresponding
- // handle that can be used for talking to XLA clients.
+ // handle that can be used for talking to XLA clients. The given shaped buffer
+ // will be treated as the buffer corresponding to the only replica.
StatusOr<GlobalDataHandle> Register(
std::unique_ptr<ShapedBuffer> shaped_buffer, const string& tag);
+ // Registers a vector of shaped buffers of device memory, one per replica, and
+ // returns a corresponding handle that can be used for talking to XLA clients.
+ StatusOr<GlobalDataHandle> RegisterReplicatedBuffers(
+ std::vector<std::unique_ptr<ShapedBuffer>> replicated_buffers,
+ const string& tag);
+
// Unregister the allocation for the given data handle.
Status Unregister(const GlobalDataHandle& data);
@@ -54,9 +61,17 @@ class AllocationTracker {
StatusOr<std::vector<GlobalDataHandle>> DeconstructTuple(
const GlobalDataHandle& Data);
- // Resolve a handle from an XLA client to a shaped buffer, or provide an error
- // status to say whether it was not found (or found, but found deallocated).
- StatusOr<const ShapedBuffer*> Resolve(const GlobalDataHandle& data);
+ // Resolve a handle from an XLA client to a vector of shaped buffers, one per
+ // replica, or provide an error status to say whether any of those buffers
+ // were not found (or found, but found deallocated).
+ StatusOr<std::vector<const ShapedBuffer*>> Resolve(
+ const GlobalDataHandle& data);
+
+ // Resolves a handle from an XLA client and replica id to a shaped buffer, or
+ // provide an error status to say whether it was not found (or found, but
+ // found deallocated).
+ StatusOr<const ShapedBuffer*> ResolveForReplica(const GlobalDataHandle& data,
+ int replica_id);
private:
// Data structure encapsulating single memory allocation on the device.
@@ -74,13 +89,17 @@ class AllocationTracker {
// Internal helper which resolves the given GlobalDataHandle to a
// ShapedBuffer.
- StatusOr<ShapedBuffer*> ResolveInternal(const GlobalDataHandle& data)
- EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+ StatusOr<std::vector<const ShapedBuffer*>> ResolveInternal(
+ const GlobalDataHandle& data) EXCLUSIVE_LOCKS_REQUIRED(mutex_);
- // Internal helper which registers a shaped buffer.
+ // Internal helper which registers a vector of shaped buffers, one per
+ // replica.
StatusOr<GlobalDataHandle> RegisterInternal(
- std::unique_ptr<ShapedBuffer> shaped_buffer, const string& tag)
- EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+ std::vector<std::unique_ptr<ShapedBuffer>> replicated_buffers,
+ const string& tag) EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+
+ // Resets the shaped buffers corresponding to the given handle.
+ Status Reset(const GlobalDataHandle& data) EXCLUSIVE_LOCKS_REQUIRED(mutex_);
// Adds the given device address to the allocation tracker, or if it already
// exists, then increment it's reference count.
@@ -111,9 +130,10 @@ class AllocationTracker {
tensorflow::gtl::FlatMap<int, AllocationMap> opaque_to_allocation_map_
GUARDED_BY(mutex_);
- // A map from data handle to ShapedBuffer.
- tensorflow::gtl::FlatMap<int64, std::unique_ptr<ShapedBuffer>>
- handle_to_shaped_buffer_ GUARDED_BY(mutex_);
+ // A map from data handle to a vector of shaped buffers that represent the
+ // buffers for different replicas.
+ tensorflow::gtl::FlatMap<int64, std::vector<std::unique_ptr<ShapedBuffer>>>
+ handle_to_shaped_buffers_ GUARDED_BY(mutex_);
TF_DISALLOW_COPY_AND_ASSIGN(AllocationTracker);
};