aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/jit/xla_device.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/jit/xla_device.h')
-rw-r--r--tensorflow/compiler/jit/xla_device.h14
1 files changed, 8 insertions, 6 deletions
diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h
index d8906419b0..dbf35f349f 100644
--- a/tensorflow/compiler/jit/xla_device.h
+++ b/tensorflow/compiler/jit/xla_device.h
@@ -30,7 +30,6 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/allocator.h"
@@ -124,7 +123,7 @@ class XlaDevice : public LocalDevice {
void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
AsyncOpKernel::DoneCallback done) override;
- Status Sync() override { return Status::OK(); }
+ Status Sync() override;
Status FillContextMap(const Graph* graph,
DeviceContextMap* device_context_map) override
@@ -153,7 +152,7 @@ class XlaDevice : public LocalDevice {
Allocator* GetAllocatorLocked(AllocatorAttributes attr)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
Status EnsureStreamOkLocked(xla::Backend* backend, const string& name,
- xla::StreamPool::Ptr* stream,
+ std::shared_ptr<se::Stream>* stream,
bool* stream_was_changed)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
xla::StatusOr<XlaDeviceContext*> GetDeviceContextLocked()
@@ -174,17 +173,17 @@ class XlaDevice : public LocalDevice {
// stream are executed on the device. Operations include data
// copying back and forth between CPU and the device, and
// computations enqueued by XLA.
- xla::StreamPool::Ptr stream_ GUARDED_BY(mu_);
+ std::shared_ptr<se::Stream> stream_ GUARDED_BY(mu_);
// If false, only stream_ is valid and all computation and transfers use
// stream_. If true, computation is performed by stream_ and transfers are
// performed by host_to_device/device_to_host_stream.
const bool use_multiple_streams_;
// If use_multiple_streams_, host to device transfers are performed using this
// stream.
- xla::StreamPool::Ptr host_to_device_stream_ GUARDED_BY(mu_);
+ std::shared_ptr<se::Stream> host_to_device_stream_ GUARDED_BY(mu_);
// If use_multiple_streams_, device to host transfers are performed using this
// stream.
- xla::StreamPool::Ptr device_to_host_stream_ GUARDED_BY(mu_);
+ std::shared_ptr<se::Stream> device_to_host_stream_ GUARDED_BY(mu_);
// Must we use XLA's transfer manager for correct host<->device transfers? if
// false, we can use ThenMemcpy() instead.
const bool transfer_as_literal_;
@@ -198,6 +197,9 @@ class XlaDevice : public LocalDevice {
// Holds extra information for GPU and TPU devices, e.g. the device context.
bool use_gpu_device_info_ GUARDED_BY(mu_) = false;
std::unique_ptr<GpuDeviceInfo> gpu_device_info_ GUARDED_BY(mu_);
+
+ // Thread pool used for running closures
+ std::unique_ptr<thread::ThreadPool> thread_pool_;
};
// Builds OpKernel registrations on 'device' for the JIT operators