aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor
diff options
context:
space:
mode:
authorGravatar Justin Lebar <jlebar@google.com>2018-08-01 16:00:38 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-01 16:06:11 -0700
commitbed00207976a70370cb6e3615c7ad56a5547cf45 (patch)
treeb154beccd7cbabc839cc72ac6291ad082055879b /tensorflow/stream_executor
parent626317cb35524e4deb0851a65ca5dd5ca61d431f (diff)
[SE] Allow context reuse in CreatedContexts::Add.
It's possible for an already-existing context to be returned by cuDevicePrimaryCtxRetain. Previously, this would be handled incorrectly by CreatedContexts::Add, which was assuming that inserts into the map always succeeded. This makes XLA work with TF_CUDA_PLATFORM_GPU_DEVICE_SCHEDULE=blocking_sync, although exactly how that flag is related to this bug is unclear to me. It seems like some sort of race condition, maybe? PiperOrigin-RevId: 207010059
Diffstat (limited to 'tensorflow/stream_executor')
-rw-r--r--tensorflow/stream_executor/cuda/cuda_driver.cc16
1 files changed, 10 insertions, 6 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_driver.cc b/tensorflow/stream_executor/cuda/cuda_driver.cc
index dbece3adf9..f982f34b98 100644
--- a/tensorflow/stream_executor/cuda/cuda_driver.cc
+++ b/tensorflow/stream_executor/cuda/cuda_driver.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/human_readable.h"
#include "tensorflow/stream_executor/lib/inlined_vector.h"
#include "tensorflow/stream_executor/lib/notification.h"
+#include "tensorflow/stream_executor/lib/ptr_util.h"
#include "tensorflow/stream_executor/lib/stacktrace.h"
#include "tensorflow/stream_executor/lib/static_threadlocal.h"
#include "tensorflow/stream_executor/lib/strcat.h"
@@ -66,14 +67,17 @@ class CreatedContexts {
return Live()->find(context) != Live()->end();
}
- // Adds context to the live set.
+ // Adds context to the live set, or returns it if it's already present.
static CudaContext* Add(CUcontext context) {
CHECK(context != nullptr);
mutex_lock lock(mu_);
- auto cuda_context = new CudaContext(context, next_id_++);
- Live()->insert(
- std::make_pair(context, std::unique_ptr<CudaContext>(cuda_context)));
- return cuda_context;
+ auto insert_result = Live()->insert(std::make_pair(context, nullptr));
+ auto it = insert_result.first;
+ if (insert_result.second) {
+ // context was not present in the map. Add it.
+ it->second = MakeUnique<CudaContext>(context, next_id_++);
+ }
+ return it->second.get();
}
// Removes context from the live set.
@@ -427,7 +431,7 @@ bool DeviceOptionsToContextFlags(const DeviceOptions &device_options,
*context = CreatedContexts::Add(new_context);
CHECK(*context != nullptr)
<< "success in this call must entail non-null result";
- VLOG(2) << "created context " << context << " for this thread";
+ VLOG(2) << "created or reused context " << context << " for this thread";
return port::Status::OK();
}