about summary refs log tree commit diff homepage
path: root/tensorflow/compiler/jit/xla_device.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/jit/xla_device.h')
-rw-r--r-- tensorflow/compiler/jit/xla_device.h 73
1 files changed, 49 insertions, 24 deletions
diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h
index 4a5942fbd7..d8906419b0 100644
--- a/tensorflow/compiler/jit/xla_device.h
+++ b/tensorflow/compiler/jit/xla_device.h
@@ -25,6 +25,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
+#include "tensorflow/compiler/jit/xla_device_context.h"
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -40,6 +41,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace tensorflow {
@@ -117,62 +119,85 @@ class XlaDevice : public LocalDevice {
const PaddedShapeFn& padded_shape_fn);
~XlaDevice() override;
- Allocator* GetAllocator(AllocatorAttributes attr) override;
+ Allocator* GetAllocator(AllocatorAttributes attr) override
+ LOCKS_EXCLUDED(mu_);
void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
AsyncOpKernel::DoneCallback done) override;
Status Sync() override { return Status::OK(); }
Status FillContextMap(const Graph* graph,
- DeviceContextMap* device_context_map) override;
+ DeviceContextMap* device_context_map) override
+ LOCKS_EXCLUDED(mu_);
Status MakeTensorFromProto(const TensorProto& tensor_proto,
const AllocatorAttributes alloc_attrs,
- Tensor* tensor) override;
+ Tensor* tensor) override LOCKS_EXCLUDED(mu_);
- xla::LocalClient* client() const;
const Metadata& metadata() { return xla_metadata_; }
- xla::StatusOr<se::Stream*> GetStream();
- xla::StatusOr<se::Stream*> GetHostToDeviceStream();
- xla::StatusOr<se::Stream*> GetDeviceToHostStream();
- // If not already set, create and set GpuDeviceInfo.
- // Not thread-safe
- Status CreateAndSetGpuDeviceInfo();
+ // Ensures the DeviceContext associated with this XlaDevice is created and
+ // valid (i.e. all streams are ok). If any state is not valid, a new
+ // DeviceContext will be created.
+ //
+ // TODO(b/111859745): The Eager context needs to call this method to recover
+ // from failures.
+ Status EnsureDeviceContextOk() LOCKS_EXCLUDED(mu_);
+
+ // Instructs this XlaDevice to set a GpuDeviceInfo, which holds extra
+ // information for GPU and TPU devices.
+ Status UseGpuDeviceInfo() LOCKS_EXCLUDED(mu_);
private:
+ xla::LocalClient* client() const;
+ Allocator* GetAllocatorLocked(AllocatorAttributes attr)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ Status EnsureStreamOkLocked(xla::Backend* backend, const string& name,
+ xla::StreamPool::Ptr* stream,
+ bool* stream_was_changed)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ xla::StatusOr<XlaDeviceContext*> GetDeviceContextLocked()
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ mutex mu_;
// The metadata of this XlaDevice.
const Metadata xla_metadata_;
// Which hardware device in the client's platform this XlaDevice controls.
const int device_ordinal_;
// The name of the device that is used to compile Ops for this XlaDevice.
- DeviceType jit_device_name_;
+ const DeviceType jit_device_name_;
+ // The platform for this device.
+ se::Platform* const platform_; // Not owned.
// Memory allocator associated with this device.
- Allocator* xla_allocator_; // Not owned.
- se::Platform* platform_; // Not owned.
+ Allocator* xla_allocator_ GUARDED_BY(mu_) = nullptr; // Not owned.
// Stream associated with this device. Operations enqueued on this
// stream are executed on the device. Operations include data
// copying back and forth between CPU and the device, and
// computations enqueued by XLA.
- xla::StreamPool::Ptr stream_;
- // If true, only stream_ is valid and all computation and transfers use
- // stream_. If false, computation is performed by stream_ and transfers are
+ xla::StreamPool::Ptr stream_ GUARDED_BY(mu_);
+ // If false, only stream_ is valid and all computation and transfers use
+ // stream_. If true, computation is performed by stream_ and transfers are
// performed by host_to_device/device_to_host_stream.
- bool use_multiple_streams_;
+ const bool use_multiple_streams_;
// If use_multiple_streams_, host to device transfers are performed using this
// stream.
- xla::StreamPool::Ptr host_to_device_stream_;
+ xla::StreamPool::Ptr host_to_device_stream_ GUARDED_BY(mu_);
// If use_multiple_streams_, device to host transfers are performed using this
// stream.
- xla::StreamPool::Ptr device_to_host_stream_;
+ xla::StreamPool::Ptr device_to_host_stream_ GUARDED_BY(mu_);
// Must we use XLA's transfer manager for correct host<->device transfers? if
// false, we can use ThenMemcpy() instead.
- bool transfer_as_literal_;
- XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
+ const bool transfer_as_literal_;
+ const XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
+
+ // The device context accessed by all users of the XlaDevice, set by calls to
+ // EnsureDeviceContextOk. If gpu_device_info_ is non-null, this pointer is
+ // also filled in to that struct. XlaDeviceContext is a ref-counted object.
+ XlaDeviceContext* device_context_ GUARDED_BY(mu_) = nullptr;
- // If set, holds default device context (that we must Unref)
- // and its stream.
- std::unique_ptr<GpuDeviceInfo> gpu_device_info_;
+ // Holds extra information for GPU and TPU devices, e.g. the device context.
+ bool use_gpu_device_info_ GUARDED_BY(mu_) = false;
+ std::unique_ptr<GpuDeviceInfo> gpu_device_info_ GUARDED_BY(mu_);
};
// Builds OpKernel registrations on 'device' for the JIT operators