diff options
Diffstat (limited to 'tensorflow/compiler/jit/xla_device.h')
-rw-r--r-- | tensorflow/compiler/jit/xla_device.h | 73 |
1 files changed, 49 insertions, 24 deletions
diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index 4a5942fbd7..d8906419b0 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -25,6 +25,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ #define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ +#include "tensorflow/compiler/jit/xla_device_context.h" #include "tensorflow/compiler/jit/xla_tensor.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -40,6 +41,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace tensorflow { @@ -117,62 +119,85 @@ class XlaDevice : public LocalDevice { const PaddedShapeFn& padded_shape_fn); ~XlaDevice() override; - Allocator* GetAllocator(AllocatorAttributes attr) override; + Allocator* GetAllocator(AllocatorAttributes attr) override + LOCKS_EXCLUDED(mu_); void Compute(OpKernel* op_kernel, OpKernelContext* context) override; void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, AsyncOpKernel::DoneCallback done) override; Status Sync() override { return Status::OK(); } Status FillContextMap(const Graph* graph, - DeviceContextMap* device_context_map) override; + DeviceContextMap* device_context_map) override + LOCKS_EXCLUDED(mu_); Status MakeTensorFromProto(const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs, - Tensor* tensor) override; + Tensor* tensor) override LOCKS_EXCLUDED(mu_); - xla::LocalClient* client() const; const Metadata& metadata() { return xla_metadata_; } - xla::StatusOr<se::Stream*> GetStream(); - xla::StatusOr<se::Stream*> GetHostToDeviceStream(); - xla::StatusOr<se::Stream*> GetDeviceToHostStream(); - // If not already set, create and set GpuDeviceInfo. - // Not thread-safe - Status CreateAndSetGpuDeviceInfo(); + // Ensures the DeviceContext associated with this XlaDevice is created and + // valid (i.e. all streams are ok). If any state is not valid, a new + // DeviceContext will be created. + // + // TODO(b/111859745): The Eager context needs to call this method to recover + // from failures. + Status EnsureDeviceContextOk() LOCKS_EXCLUDED(mu_); + + // Instructs this XlaDevice to set a GpuDeviceInfo, which holds extra + // information for GPU and TPU devices. + Status UseGpuDeviceInfo() LOCKS_EXCLUDED(mu_); private: + xla::LocalClient* client() const; + Allocator* GetAllocatorLocked(AllocatorAttributes attr) + EXCLUSIVE_LOCKS_REQUIRED(mu_); + Status EnsureStreamOkLocked(xla::Backend* backend, const string& name, + xla::StreamPool::Ptr* stream, + bool* stream_was_changed) + EXCLUSIVE_LOCKS_REQUIRED(mu_); + xla::StatusOr<XlaDeviceContext*> GetDeviceContextLocked() + EXCLUSIVE_LOCKS_REQUIRED(mu_); + + mutex mu_; // The metadata of this XlaDevice. const Metadata xla_metadata_; // Which hardware device in the client's platform this XlaDevice controls. const int device_ordinal_; // The name of the device that is used to compile Ops for this XlaDevice. - DeviceType jit_device_name_; + const DeviceType jit_device_name_; + // The platform for this device. + se::Platform* const platform_; // Not owned. // Memory allocator associated with this device. - Allocator* xla_allocator_; // Not owned. - se::Platform* platform_; // Not owned. + Allocator* xla_allocator_ GUARDED_BY(mu_) = nullptr; // Not owned. // Stream associated with this device. Operations enqueued on this // stream are executed on the device. Operations include data // copying back and forth between CPU and the device, and // computations enqueued by XLA. - xla::StreamPool::Ptr stream_; - // If true, only stream_ is valid and all computation and transfers use - // stream_. If false, computation is performed by stream_ and transfers are + xla::StreamPool::Ptr stream_ GUARDED_BY(mu_); + // If false, only stream_ is valid and all computation and transfers use + // stream_. If true, computation is performed by stream_ and transfers are // performed by host_to_device/device_to_host_stream. - bool use_multiple_streams_; + const bool use_multiple_streams_; // If use_multiple_streams_, host to device transfers are performed using this // stream. - xla::StreamPool::Ptr host_to_device_stream_; + xla::StreamPool::Ptr host_to_device_stream_ GUARDED_BY(mu_); // If use_multiple_streams_, device to host transfers are performed using this // stream. - xla::StreamPool::Ptr device_to_host_stream_; + xla::StreamPool::Ptr device_to_host_stream_ GUARDED_BY(mu_); // Must we use XLA's transfer manager for correct host<->device transfers? if // false, we can use ThenMemcpy() instead. - bool transfer_as_literal_; - XlaCompiler::ShapeRepresentationFn shape_representation_fn_; + const bool transfer_as_literal_; + const XlaCompiler::ShapeRepresentationFn shape_representation_fn_; + + // The device context accessed by all users of the XlaDevice, set by calls to + // EnsureDeviceContextOk. If gpu_device_info_ is non-null, this pointer is + // also filled in to that struct. XlaDeviceContext is a ref-counted object. + XlaDeviceContext* device_context_ GUARDED_BY(mu_) = nullptr; - // If set, holds default device context (that we must Unref) - // and its stream. - std::unique_ptr<GpuDeviceInfo> gpu_device_info_; + // Holds extra information for GPU and TPU devices, e.g. the device context. + bool use_gpu_device_info_ GUARDED_BY(mu_) = false; + std::unique_ptr<GpuDeviceInfo> gpu_device_info_ GUARDED_BY(mu_); }; // Builds OpKernel registrations on 'device' for the JIT operators |