aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/jit/xla_device_context.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/jit/xla_device_context.h')
-rw-r--r--tensorflow/compiler/jit/xla_device_context.h31
1 files changed, 19 insertions, 12 deletions
diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h
index 912f8d779e..2e7445340c 100644
--- a/tensorflow/compiler/jit/xla_device_context.h
+++ b/tensorflow/compiler/jit/xla_device_context.h
@@ -47,10 +47,12 @@ class XlaDeviceAllocator : public Allocator {
class XlaTransferManager {
public:
explicit XlaTransferManager(
- se::Stream* compute_stream, se::Stream* host_to_device_stream,
- se::Stream* device_to_host_stream, xla::LocalClient* client,
- bool transfer_as_literal,
- XlaCompiler::ShapeRepresentationFn shape_representation_fn);
+ std::shared_ptr<se::Stream> compute_stream,
+ std::shared_ptr<se::Stream> host_to_device_stream,
+ std::shared_ptr<se::Stream> device_to_host_stream,
+ xla::LocalClient* client, bool transfer_as_literal,
+ XlaCompiler::ShapeRepresentationFn shape_representation_fn,
+ thread::ThreadPool* thread_pool);
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
Tensor* device_tensor, StatusCallback done) const;
@@ -61,7 +63,7 @@ class XlaTransferManager {
void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor,
const StatusCallback& done);
- se::Stream* stream() const { return stream_; }
+ se::Stream* stream() const { return stream_.get(); }
private:
Status TransferLiteralToDevice(const Tensor& host_tensor,
@@ -73,13 +75,13 @@ class XlaTransferManager {
// The main compute stream of the device, used to synchronize the transfer
// streams if they are set.
- se::Stream* stream_;
+ std::shared_ptr<se::Stream> stream_;
// The stream to use for transferring data from host to device. Can be
// idential to stream_, but must not be nullptr.
- se::Stream* host_to_device_stream_;
+ std::shared_ptr<se::Stream> host_to_device_stream_;
// The stream to use for transferring data from device to host. Can be
// idential to stream_, but must not be nullptr.
- se::Stream* device_to_host_stream_;
+ std::shared_ptr<se::Stream> device_to_host_stream_;
// For the underlying memory allocator and XLA's TransferManager.
xla::LocalClient* client_;
// Transfer manager, for marshalling data to and from the device.
@@ -87,6 +89,9 @@ class XlaTransferManager {
// True if we must use XLA's TransferManager for correct device transfers.
const bool transfer_as_literal_;
XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
+
+ // Thread pool used for running closures
+ thread::ThreadPool* thread_pool_;
};
// DeviceContext for operators assigned to XlaDevice devices. The
@@ -95,10 +100,12 @@ class XlaTransferManager {
class XlaDeviceContext : public DeviceContext {
public:
explicit XlaDeviceContext(
- se::Stream* compute_stream, se::Stream* host_to_device_stream,
- se::Stream* device_to_host_stream, xla::LocalClient* client,
- bool transfer_as_literal,
- XlaCompiler::ShapeRepresentationFn shape_representation_fn);
+ std::shared_ptr<se::Stream> compute_stream,
+ std::shared_ptr<se::Stream> host_to_device_stream,
+ std::shared_ptr<se::Stream> device_to_host_stream,
+ xla::LocalClient* client, bool transfer_as_literal,
+ XlaCompiler::ShapeRepresentationFn shape_representation_fn,
+ thread::ThreadPool* thread_pool);
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
Tensor* device_tensor,