author    | 2018-08-29 10:59:57 -0700
committer | 2018-08-29 11:06:12 -0700
commit    | a2d2ac24a2ca2a50524f0acab1643b7d186e5074 (patch)
tree      | 58eca6d8f9744217c5e8186d3743dad984061f06 /tensorflow/compiler/jit
parent    | 3df3d078172a618f3453069928f76e43785a56a0 (diff)
Fix a race condition in XlaLocalLaunchBase.
XlaLocalLaunchBase was modifying platform_id_ without a lock, which is racy
because the same OpKernel can be executed concurrently. Fix this by inferring
platform_id_ in the kernel constructor.

While at it, promote use_multiple_streams_ and xla_device_metadata_ to member
variables as well.
PiperOrigin-RevId: 210751494
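The race, in miniature: a single OpKernel instance is shared across concurrent
executions, so a lazily-initialized member written in Compute() is a data race,
while a member initialized in the constructor is written exactly once before
any Compute() runs. The sketch below is illustrative only, not the actual
TensorFlow code; LookUpPlatformId is a hypothetical stand-in for the real
StreamExecutor platform lookup, and se::Platform::Id comes from the
StreamExecutor headers already used by xla_launch_op.h.

#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

se::Platform::Id LookUpPlatformId(OpKernelConstruction* ctx);  // hypothetical
se::Platform::Id LookUpPlatformId(OpKernelContext* ctx);       // hypothetical

// Before: the same kernel instance may run Compute() on several threads at
// once, so this check-then-write on a plain member is a data race.
class RacyKernel : public OpKernel {
 public:
  explicit RacyKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  void Compute(OpKernelContext* ctx) override {
    if (platform_id_ == nullptr) {           // unsynchronized read...
      platform_id_ = LookUpPlatformId(ctx);  // ...and unsynchronized write
    }
    // ... use platform_id_ ...
  }

 private:
  se::Platform::Id platform_id_ = nullptr;
};

// After: the constructor runs exactly once, before any Compute() call, so
// the value can be derived there and never mutated again.
class FixedKernel : public OpKernel {
 public:
  explicit FixedKernel(OpKernelConstruction* ctx)
      : OpKernel(ctx), platform_id_(LookUpPlatformId(ctx)) {}
  void Compute(OpKernelContext* ctx) override {
    // platform_id_ is read-only here; concurrent executions are safe.
  }

 private:
  const se::Platform::Id platform_id_;
};

}  // namespace tensorflow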
Diffstat (limited to 'tensorflow/compiler/jit')
-rw-r--r-- | tensorflow/compiler/jit/kernels/xla_launch_op.cc | 37
-rw-r--r-- | tensorflow/compiler/jit/kernels/xla_launch_op.h  |  5
-rw-r--r-- | tensorflow/compiler/jit/xla_device.cc            | 19
-rw-r--r-- | tensorflow/compiler/jit/xla_device.h             |  7
4 files changed, 38 insertions, 30 deletions
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index fde4135bf7..b6f2f632f7 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -16,7 +16,6 @@ limitations under the License.
 #include "tensorflow/compiler/jit/kernels/xla_launch_op.h"

 #include "tensorflow/compiler/jit/defs.h"
-#include "tensorflow/compiler/jit/xla_device.h"
 #include "tensorflow/compiler/jit/xla_launch_util.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
@@ -57,18 +56,17 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
                        ->stream->parent()
                        ->platform()
                        ->id();
-  } else {
-    platform_id_ = nullptr;
+  } else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata_).ok()) {
+    use_multiple_streams_ = xla_device_metadata_->UseMultipleStreams();
+    platform_id_ = xla_device_metadata_->platform()->id();
   }
 }

 Status XlaLocalLaunchBase::BuildCompilationCache(OpKernelContext* ctx,
                                                  XlaCompilationCache** cache) {
-  const XlaDevice::Metadata* metadata;
-  Status s = XlaDevice::GetMetadata(ctx, &metadata);
-  if (s.ok()) {
-    *cache = new XlaCompilationCache(metadata->client(),
-                                     metadata->jit_device_type());
+  if (xla_device_metadata_) {
+    *cache = new XlaCompilationCache(xla_device_metadata_->client(),
+                                     xla_device_metadata_->jit_device_type());
     return Status::OK();
   }

@@ -117,18 +115,6 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
   // this is more obviously correct.)
   core::ScopedUnref cache_ref(cache);

-  const XlaDevice::Metadata* metadata = nullptr;
-  Status s = XlaDevice::GetMetadata(ctx, &metadata);
-  bool allocate_xla_tensors = s.ok();
-  bool use_multiple_streams = s.ok() && metadata->UseMultipleStreams();
-
-  // Get the platform_id_ for XLA_* devices.
-  if (platform_id_ == nullptr) {
-    if (s.ok()) {
-      platform_id_ = metadata->platform()->id();
-    }
-  }
-
   std::map<int, OptionalTensor> variables =
       SnapshotResourceVariables(ctx, resources_);

@@ -146,7 +132,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
   // (which local_xla_allocator above uses) as on an XlaDevice, this is a
   // dummy allocator that returns XlaTensor objects. The XlaCompiler needs a
   // real allocator to allocate real buffers.
-  if (allocate_xla_tensors) {
+  if (xla_device_metadata_) {
     xla_allocator = client->backend().memory_allocator();
   } else {
     xla_allocator = &local_xla_allocator;
@@ -163,8 +149,9 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
   options.graph_def_version = ctx->function_library()->graph_def_version();
   options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId);
   options.device_allocator = xla_allocator;
-  if (metadata) {
-    options.shape_representation_fn = metadata->shape_representation_fn();
+  if (xla_device_metadata_) {
+    options.shape_representation_fn =
+        xla_device_metadata_->shape_representation_fn();
   }

   const XlaCompiler::CompilationResult* kernel;
@@ -192,7 +179,9 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
   VLOG(1) << "Executing XLA Computation...";

   XlaComputationLaunchContext launch_context(
-      client, xla_allocator, allocate_xla_tensors, use_multiple_streams);
+      client, xla_allocator,
+      /*allocate_xla_tensors=*/xla_device_metadata_ != nullptr,
+      use_multiple_streams_);
   launch_context.PopulateInputs(ctx, kernel, variables);

   // Execute the computation.
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.h b/tensorflow/compiler/jit/kernels/xla_launch_op.h
index bf1e990668..e0f10e9817 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.h
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.h
@@ -17,6 +17,7 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_

 #include "tensorflow/compiler/jit/xla_compilation_cache.h"
+#include "tensorflow/compiler/jit/xla_device.h"
 #include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -58,7 +59,9 @@ class XlaLocalLaunchBase : public OpKernel {

   DeviceType device_type_;
   NameAttrList function_;
-  se::Platform::Id platform_id_;
+  se::Platform::Id platform_id_ = nullptr;
+  bool use_multiple_streams_ = false;
+  const XlaDevice::Metadata* xla_device_metadata_ = nullptr;
 };

 // XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index 50c902fdfc..f31879a2bc 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -185,14 +185,13 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const {
   return device_type_;
 }

-/* static */ Status XlaDevice::GetMetadata(OpKernelContext* ctx,
-                                           const Metadata** metadata) {
+/*static*/ Status XlaDevice::GetMetadataFromDevice(
+    DeviceBase* device, const XlaDevice::Metadata** metadata) {
   *metadata = nullptr;
-  XlaDevice* xla_device =
-      dynamic_cast<XlaDevice*>(ctx->device()->UnderlyingDevice());
+  XlaDevice* xla_device = dynamic_cast<XlaDevice*>(device->UnderlyingDevice());
   if (xla_device == nullptr) {
     return errors::Internal(
-        "Cannot get XLA metadata from non-XLA device \"", ctx->device()->name(),
+        "Cannot get XLA metadata from non-XLA device \"", device->name(),
         "\". GetMetadata must only be called on an XLA device. Either an "
         "internal bug has been triggered, or an XLA-specific op has been "
         "placed on the wrong device.");
@@ -201,6 +200,16 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const {
   return Status::OK();
 }

+/* static */ Status XlaDevice::GetMetadata(OpKernelContext* ctx,
+                                           const Metadata** metadata) {
+  return GetMetadataFromDevice(ctx->device(), metadata);
+}
+
+/* static */ Status XlaDevice::GetMetadata(OpKernelConstruction* ctx,
+                                           const Metadata** metadata) {
+  return GetMetadataFromDevice(ctx->device(), metadata);
+}
+
 XlaDevice::XlaDevice(
     const SessionOptions& options, const DeviceAttributes& attrs,
     int device_ordinal, const DeviceType& jit_device_name,
diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h
index dbf35f349f..92891ffa8c 100644
--- a/tensorflow/compiler/jit/xla_device.h
+++ b/tensorflow/compiler/jit/xla_device.h
@@ -88,6 +88,10 @@ class XlaDevice : public LocalDevice {
   // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by `ctx`.
   static Status GetMetadata(OpKernelContext* ctx, const Metadata** metadata);

+  // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by `ctx`.
+  static Status GetMetadata(OpKernelConstruction* ctx,
+                            const Metadata** metadata);
+
   // Factory function. 'platform_name' is the name of the XLA platform.
   // 'device_name' is the name of the Tensorflow device to create.
   // 'jit_device_name' is the name of the corresponding JIT device.
@@ -158,6 +162,9 @@ class XlaDevice : public LocalDevice {
   xla::StatusOr<XlaDeviceContext*> GetDeviceContextLocked()
       EXCLUSIVE_LOCKS_REQUIRED(mu_);

+  static Status GetMetadataFromDevice(DeviceBase* device,
+                                      const XlaDevice::Metadata** metadata);
+
   mutex mu_;
   // The metadata of this XlaDevice.
   const Metadata xla_metadata_;
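Design note on the header changes above: because the fix moves the metadata
lookup to kernel construction time, GetMetadata gains an OpKernelConstruction
overload alongside the existing OpKernelContext one, with both delegating to
the device-based GetMetadataFromDevice helper. A kernel using the new overload
might look like the following sketch (MyXlaAwareKernel and its members are
illustrative, not part of this change):

#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

class MyXlaAwareKernel : public OpKernel {
 public:
  explicit MyXlaAwareKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {
    // On a non-XLA device GetMetadata returns an error and leaves the
    // pointer null, so the member doubles as an "is this an XLA device?"
    // flag that Compute() can read without synchronization.
    if (XlaDevice::GetMetadata(ctx, &xla_device_metadata_).ok()) {
      use_multiple_streams_ = xla_device_metadata_->UseMultipleStreams();
    }
  }

  void Compute(OpKernelContext* ctx) override {
    // Members are initialized once in the constructor and only read here.
  }

 private:
  const XlaDevice::Metadata* xla_device_metadata_ = nullptr;
  bool use_multiple_streams_ = false;
};

}  // namespace tensorflow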