aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/jit
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-08-29 10:59:57 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-29 11:06:12 -0700
commita2d2ac24a2ca2a50524f0acab1643b7d186e5074 (patch)
tree58eca6d8f9744217c5e8186d3743dad984061f06 /tensorflow/compiler/jit
parent3df3d078172a618f3453069928f76e43785a56a0 (diff)
Fix a race condition in XlaLocalLaunchBase.
XlaLocalLaunchBase was modifying platform_id_ without a lock which is racy because the same OpKernel can be execute concurrently. Fix this by inferring platform_id_ in the kernel constructor. While at it, make use_multiple_streams_ and xla_device_metadata_ member variables also. PiperOrigin-RevId: 210751494
Diffstat (limited to 'tensorflow/compiler/jit')
-rw-r--r--tensorflow/compiler/jit/kernels/xla_launch_op.cc37
-rw-r--r--tensorflow/compiler/jit/kernels/xla_launch_op.h5
-rw-r--r--tensorflow/compiler/jit/xla_device.cc19
-rw-r--r--tensorflow/compiler/jit/xla_device.h7
4 files changed, 38 insertions, 30 deletions
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index fde4135bf7..b6f2f632f7 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -16,7 +16,6 @@ limitations under the License.
#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
#include "tensorflow/compiler/jit/defs.h"
-#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
@@ -57,18 +56,17 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
->stream->parent()
->platform()
->id();
- } else {
- platform_id_ = nullptr;
+ } else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata_).ok()) {
+ use_multiple_streams_ = xla_device_metadata_->UseMultipleStreams();
+ platform_id_ = xla_device_metadata_->platform()->id();
}
}
Status XlaLocalLaunchBase::BuildCompilationCache(OpKernelContext* ctx,
XlaCompilationCache** cache) {
- const XlaDevice::Metadata* metadata;
- Status s = XlaDevice::GetMetadata(ctx, &metadata);
- if (s.ok()) {
- *cache = new XlaCompilationCache(metadata->client(),
- metadata->jit_device_type());
+ if (xla_device_metadata_) {
+ *cache = new XlaCompilationCache(xla_device_metadata_->client(),
+ xla_device_metadata_->jit_device_type());
return Status::OK();
}
@@ -117,18 +115,6 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
// this is more obviously correct.)
core::ScopedUnref cache_ref(cache);
- const XlaDevice::Metadata* metadata = nullptr;
- Status s = XlaDevice::GetMetadata(ctx, &metadata);
- bool allocate_xla_tensors = s.ok();
- bool use_multiple_streams = s.ok() && metadata->UseMultipleStreams();
-
- // Get the platform_id_ for XLA_* devices.
- if (platform_id_ == nullptr) {
- if (s.ok()) {
- platform_id_ = metadata->platform()->id();
- }
- }
-
std::map<int, OptionalTensor> variables =
SnapshotResourceVariables(ctx, resources_);
@@ -146,7 +132,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
// (which local_xla_allocator above uses) as on an XlaDevice, this is a
// dummy allocator that returns XlaTensor objects. The XlaCompiler needs a
// real allocator to allocate real buffers.
- if (allocate_xla_tensors) {
+ if (xla_device_metadata_) {
xla_allocator = client->backend().memory_allocator();
} else {
xla_allocator = &local_xla_allocator;
@@ -163,8 +149,9 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
options.graph_def_version = ctx->function_library()->graph_def_version();
options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId);
options.device_allocator = xla_allocator;
- if (metadata) {
- options.shape_representation_fn = metadata->shape_representation_fn();
+ if (xla_device_metadata_) {
+ options.shape_representation_fn =
+ xla_device_metadata_->shape_representation_fn();
}
const XlaCompiler::CompilationResult* kernel;
@@ -192,7 +179,9 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
VLOG(1) << "Executing XLA Computation...";
XlaComputationLaunchContext launch_context(
- client, xla_allocator, allocate_xla_tensors, use_multiple_streams);
+ client, xla_allocator,
+ /*allocate_xla_tensors=*/xla_device_metadata_ != nullptr,
+ use_multiple_streams_);
launch_context.PopulateInputs(ctx, kernel, variables);
// Execute the computation.
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.h b/tensorflow/compiler/jit/kernels/xla_launch_op.h
index bf1e990668..e0f10e9817 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.h
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_
#include "tensorflow/compiler/jit/xla_compilation_cache.h"
+#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -58,7 +59,9 @@ class XlaLocalLaunchBase : public OpKernel {
DeviceType device_type_;
NameAttrList function_;
- se::Platform::Id platform_id_;
+ se::Platform::Id platform_id_ = nullptr;
+ bool use_multiple_streams_ = false;
+ const XlaDevice::Metadata* xla_device_metadata_ = nullptr;
};
// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index 50c902fdfc..f31879a2bc 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -185,14 +185,13 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const {
return device_type_;
}
-/* static */ Status XlaDevice::GetMetadata(OpKernelContext* ctx,
- const Metadata** metadata) {
+/*static*/ Status XlaDevice::GetMetadataFromDevice(
+ DeviceBase* device, const XlaDevice::Metadata** metadata) {
*metadata = nullptr;
- XlaDevice* xla_device =
- dynamic_cast<XlaDevice*>(ctx->device()->UnderlyingDevice());
+ XlaDevice* xla_device = dynamic_cast<XlaDevice*>(device->UnderlyingDevice());
if (xla_device == nullptr) {
return errors::Internal(
- "Cannot get XLA metadata from non-XLA device \"", ctx->device()->name(),
+ "Cannot get XLA metadata from non-XLA device \"", device->name(),
"\". GetMetadata must only be called on an XLA device. Either an "
"internal bug has been triggered, or an XLA-specific op has been "
"placed on the wrong device.");
@@ -201,6 +200,16 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const {
return Status::OK();
}
+/* static */ Status XlaDevice::GetMetadata(OpKernelContext* ctx,
+ const Metadata** metadata) {
+ return GetMetadataFromDevice(ctx->device(), metadata);
+}
+
+/* static */ Status XlaDevice::GetMetadata(OpKernelConstruction* ctx,
+ const Metadata** metadata) {
+ return GetMetadataFromDevice(ctx->device(), metadata);
+}
+
XlaDevice::XlaDevice(
const SessionOptions& options, const DeviceAttributes& attrs,
int device_ordinal, const DeviceType& jit_device_name,
diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h
index dbf35f349f..92891ffa8c 100644
--- a/tensorflow/compiler/jit/xla_device.h
+++ b/tensorflow/compiler/jit/xla_device.h
@@ -88,6 +88,10 @@ class XlaDevice : public LocalDevice {
// Sets `*metadata` to the XlaDevice Metadata in the XLA device used by `ctx`.
static Status GetMetadata(OpKernelContext* ctx, const Metadata** metadata);
+ // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by `ctx`.
+ static Status GetMetadata(OpKernelConstruction* ctx,
+ const Metadata** metadata);
+
// Factory function. 'platform_name' is the name of the XLA platform.
// 'device_name' is the name of the Tensorflow device to create.
// 'jit_device_name' is the name of the corresponding JIT device.
@@ -158,6 +162,9 @@ class XlaDevice : public LocalDevice {
xla::StatusOr<XlaDeviceContext*> GetDeviceContextLocked()
EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ static Status GetMetadataFromDevice(DeviceBase* device,
+ const XlaDevice::Metadata** metadata);
+
mutex mu_;
// The metadata of this XlaDevice.
const Metadata xla_metadata_;