diff options
Diffstat (limited to 'tensorflow/compiler/jit/xla_device_context.h')
-rw-r--r-- | tensorflow/compiler/jit/xla_device_context.h | 31 |
1 file changed, 19 insertions, 12 deletions
diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index 912f8d779e..2e7445340c 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -47,10 +47,12 @@ class XlaDeviceAllocator : public Allocator { class XlaTransferManager { public: explicit XlaTransferManager( - se::Stream* compute_stream, se::Stream* host_to_device_stream, - se::Stream* device_to_host_stream, xla::LocalClient* client, - bool transfer_as_literal, - XlaCompiler::ShapeRepresentationFn shape_representation_fn); + std::shared_ptr<se::Stream> compute_stream, + std::shared_ptr<se::Stream> host_to_device_stream, + std::shared_ptr<se::Stream> device_to_host_stream, + xla::LocalClient* client, bool transfer_as_literal, + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + thread::ThreadPool* thread_pool); void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, StatusCallback done) const; @@ -61,7 +63,7 @@ class XlaTransferManager { void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor, const StatusCallback& done); - se::Stream* stream() const { return stream_; } + se::Stream* stream() const { return stream_.get(); } private: Status TransferLiteralToDevice(const Tensor& host_tensor, @@ -73,13 +75,13 @@ class XlaTransferManager { // The main compute stream of the device, used to synchronize the transfer // streams if they are set. - se::Stream* stream_; + std::shared_ptr<se::Stream> stream_; // The stream to use for transferring data from host to device. Can be // identical to stream_, but must not be nullptr. - se::Stream* host_to_device_stream_; + std::shared_ptr<se::Stream> host_to_device_stream_; // The stream to use for transferring data from device to host. Can be // identical to stream_, but must not be nullptr. 
- se::Stream* device_to_host_stream_; + std::shared_ptr<se::Stream> device_to_host_stream_; // For the underlying memory allocator and XLA's TransferManager. xla::LocalClient* client_; // Transfer manager, for marshalling data to and from the device. @@ -87,6 +89,9 @@ class XlaTransferManager { // True if we must use XLA's TransferManager for correct device transfers. const bool transfer_as_literal_; XlaCompiler::ShapeRepresentationFn shape_representation_fn_; + + // Thread pool used for running closures + thread::ThreadPool* thread_pool_; }; // DeviceContext for operators assigned to XlaDevice devices. The @@ -95,10 +100,12 @@ class XlaTransferManager { class XlaDeviceContext : public DeviceContext { public: explicit XlaDeviceContext( - se::Stream* compute_stream, se::Stream* host_to_device_stream, - se::Stream* device_to_host_stream, xla::LocalClient* client, - bool transfer_as_literal, - XlaCompiler::ShapeRepresentationFn shape_representation_fn); + std::shared_ptr<se::Stream> compute_stream, + std::shared_ptr<se::Stream> host_to_device_stream, + std::shared_ptr<se::Stream> device_to_host_stream, + xla::LocalClient* client, bool transfer_as_literal, + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + thread::ThreadPool* thread_pool); void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, |