aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/stream_executor/cuda/cuda_activation.cc15
-rw-r--r--tensorflow/stream_executor/cuda/cuda_activation.h5
-rw-r--r--tensorflow/stream_executor/cuda/cuda_driver.cc295
-rw-r--r--tensorflow/stream_executor/cuda/cuda_driver.h132
-rw-r--r--tensorflow/stream_executor/cuda/cuda_gpu_executor.cc6
-rw-r--r--tensorflow/stream_executor/cuda/cuda_gpu_executor.h4
-rw-r--r--tensorflow/stream_executor/cuda/cuda_timer.cc4
7 files changed, 256 insertions, 205 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_activation.cc b/tensorflow/stream_executor/cuda/cuda_activation.cc
index ccf9fd754c..7c92daa76e 100644
--- a/tensorflow/stream_executor/cuda/cuda_activation.cc
+++ b/tensorflow/stream_executor/cuda/cuda_activation.cc
@@ -23,18 +23,23 @@ namespace perftools {
namespace gputools {
namespace cuda {
-CUcontext ExtractCudaContext(CUDAExecutor *cuda_exec);
+CudaContext* ExtractCudaContext(CUDAExecutor *cuda_exec);
CUDAExecutor *ExtractCudaExecutor(StreamExecutor *stream_exec);
ScopedActivateExecutorContext::ScopedActivateExecutorContext(
- CUDAExecutor *cuda_exec, MultiOpActivation moa)
+ CUDAExecutor *cuda_exec)
: cuda_exec_(cuda_exec),
driver_scoped_activate_context_(
- new ScopedActivateContext{ExtractCudaContext(cuda_exec), moa}) {}
+ new ScopedActivateContext{ExtractCudaContext(cuda_exec)}) { }
ScopedActivateExecutorContext::ScopedActivateExecutorContext(
- StreamExecutor *stream_exec, MultiOpActivation moa)
- : ScopedActivateExecutorContext(ExtractCudaExecutor(stream_exec), moa) {}
+ StreamExecutor *stream_exec, MultiOpActivation unused)
+ : ScopedActivateExecutorContext(ExtractCudaExecutor(stream_exec)) {
+ // Note that the second argument is unused. We are migrating to code that
+ // always allows the multi-op activation case; the signature is kept
+ // the same until all of the code is in.
+ // TODO(cwhipkey): remove the extra parameter.
+}
ScopedActivateExecutorContext::~ScopedActivateExecutorContext() {
delete static_cast<ScopedActivateContext *>(driver_scoped_activate_context_);
diff --git a/tensorflow/stream_executor/cuda/cuda_activation.h b/tensorflow/stream_executor/cuda/cuda_activation.h
index c78d03396e..139519498b 100644
--- a/tensorflow/stream_executor/cuda/cuda_activation.h
+++ b/tensorflow/stream_executor/cuda/cuda_activation.h
@@ -40,14 +40,13 @@ class ScopedActivateContext;
class ScopedActivateExecutorContext {
public:
// Form that takes a CUDA executor implementation.
- explicit ScopedActivateExecutorContext(
- CUDAExecutor* cuda_exec, MultiOpActivation moa = MultiOpActivation::kNo);
+ explicit ScopedActivateExecutorContext(CUDAExecutor* cuda_exec);
// Form that takes a pImpl executor and extracts a CUDA implementation --
// fatal failure if it is not CUDA inside.
explicit ScopedActivateExecutorContext(
StreamExecutor* stream_exec,
- MultiOpActivation moa = MultiOpActivation::kNo);
+ MultiOpActivation unused = MultiOpActivation::kNo);
~ScopedActivateExecutorContext();
diff --git a/tensorflow/stream_executor/cuda/cuda_driver.cc b/tensorflow/stream_executor/cuda/cuda_driver.cc
index 7f992c6073..ddc0e04ab4 100644
--- a/tensorflow/stream_executor/cuda/cuda_driver.cc
+++ b/tensorflow/stream_executor/cuda/cuda_driver.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/stream_executor/cuda/cuda_driver.h"
#include <dlfcn.h>
+#include <map>
#include <stdint.h>
#include <stdlib.h>
#include <set>
@@ -42,6 +43,10 @@ bool FLAGS_gpuexec_cuda_driver_inject_init_error = false;
bool FLAGS_gpuexec_cuda_sync_around_driver_calls = false;
bool FLAGS_gpuexec_cuda_device_0_only = false;
+// Debugging: on each push and pop of a cuda context, verify the current context
+// matches the expected one.
+constexpr bool kVerifyCudaContext = false;
+
namespace perftools {
namespace gputools {
namespace cuda {
@@ -137,9 +142,12 @@ PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuStreamWaitEvent);
namespace {
-// Manages the singleton set of contexts that we've created. This is used for
-// checking that no CUDA-runtime-created contexts have been generated
-// accidentally. CUDA-runtime-created contexts are avoided, if triple angle
+// Manages the singleton map of contexts that we've created, mapping
+// from the CUcontext to the CudaContext* that we pass around internally.
+// This also manages assignment of unique ids to CudaContexts, to allow
+// for fast comparison of a context against the current context.
+//
+// CUDA-runtime-created contexts are avoided, if triple angle
// brace launches are required, by using the scoped activations in
// cuda_activation.h.
class CreatedContexts {
@@ -151,31 +159,39 @@ class CreatedContexts {
}
// Adds context to the live set.
- static void Add(CUcontext context) {
+ static CudaContext* Add(CUcontext context) {
CHECK(context != nullptr);
mutex_lock lock{mu_};
- Live()->emplace(context);
+ auto cuda_context = new CudaContext(context, next_id_++);
+ Live()->insert(
+ make_pair(context, std::unique_ptr<CudaContext>(cuda_context)));
+ return cuda_context;
}
// Removes context from the live set.
static void Remove(CUcontext context) {
CHECK(context != nullptr);
mutex_lock lock{mu_};
- Live()->erase(context);
+ auto it = Live()->find(context);
+ CHECK(it != Live()->end()) << context;
+ Live()->erase(it);
}
private:
- // Returns the live set singleton.
- static std::set<CUcontext> *Live() {
- static auto singleton = new std::set<CUcontext>;
+ // Returns the live map singleton.
+ static std::map<CUcontext, std::unique_ptr<CudaContext>> *Live() {
+ static auto singleton =
+ new std::map<CUcontext, std::unique_ptr<CudaContext>>;
return singleton;
}
// Lock that guards access-to/mutation-of the live set.
static mutex mu_;
+ static int64 next_id_;
};
/* static */ mutex CreatedContexts::mu_{LINKER_INITIALIZED};
+/* static */ int64 CreatedContexts::next_id_ = 1; // 0 means "no context"
// Formats CUresult to output prettified values into a log stream.
// Error summaries taken from:
@@ -295,11 +311,7 @@ string ToString(CUresult result) {
// created by StreamExecutor (to ensure that the CUDA runtime didn't create a
// context behind our backs).
CUcontext CurrentContext() {
- CUcontext current = nullptr;
- CUresult result = dynload::cuCtxGetCurrent(&current);
- if (result != CUDA_SUCCESS) {
- LOG(FATAL) << "failed to query current context: " << ToString(result);
- }
+ CUcontext current = CUDADriver::CurrentContextOrDie();
if (current != nullptr && !CreatedContexts::Has(current)) {
LOG(FATAL) << "current context was not created by the StreamExecutor "
"cuda_driver API: "
@@ -310,22 +322,6 @@ CUcontext CurrentContext() {
return current;
}
-// "Pops" the current context, checks that it matches expected, and checks the
-// postcondition that the current context is nullptr.
-//
-// This is not done when we're nested within a MultiOpActivation, as we want to
-// persist the active context until the MultiOpActivation is popped.
-void PopContextAndCheckNowNull(CUcontext expected) {
- CUcontext actual = CurrentContext();
- CHECK_EQ(expected, actual) << "would pop unexpected context";
- CUcontext popped;
- CHECK_EQ(CUDA_SUCCESS, dynload::cuCtxPopCurrent_v2(&popped));
- CHECK_EQ(expected, popped);
- DCHECK(nullptr == CurrentContext());
- VLOG(3) << "popped context " << expected
- << " and current context is now null";
-}
-
// CUDA driver routines may require a large amount of stack (particularly
// cuModuleLoadDataEx, in our experience). To avoid stack overflow when using
// stack-limited threads (such as those spawned by a default-argument
@@ -345,12 +341,6 @@ port::ThreadPool *GetDriverExecutor() {
} // namespace
-
-// Thread-local storage that indicates whether a CUDA context activation is
-// being nested within an outer, MultiOpActivation. In that case, we should not
-// pop the context to nullptr when we are done with the current activation.
-SE_STATIC_THREAD_LOCAL_POD(bool, tls_in_multi_op_activation);
-
string MemorySpaceString(MemorySpace memory_space) {
switch (memory_space) {
case MemorySpace::kHost:
@@ -362,56 +352,74 @@ string MemorySpaceString(MemorySpace memory_space) {
}
}
-// Implementation note: the CUDA context is held, per-thread, in TLS. We avoid
-// setting all the time because it's not clear what side effects might occur for
-// a "set" operation, whereas a "get" operation we can reasonably assume is a
-// TLS read.
-//
-// We cannot race here because CUcontext is associated with a particular thread
-// and stored in TLS; and these interfaces should not be used from signal
-// handlers.
-ScopedActivateContext::ScopedActivateContext(CUcontext context,
- MultiOpActivation moa)
- : context_(CHECK_NOTNULL(context)),
- previously_in_multi_op_activation_(tls_in_multi_op_activation.get()) {
- if (static_cast<bool>(moa)) {
- tls_in_multi_op_activation.get() = true;
- }
-
- CUcontext current = prior_context_ = CurrentContext();
- if (current != context) {
- VLOG(3) << "ScopedActivateContext switching context from " << current
- << " to " << context;
- CHECK_EQ(CUDA_SUCCESS, dynload::cuCtxSetCurrent(context));
- if (FLAGS_gpuexec_cuda_sync_around_driver_calls) {
- auto res = dynload::cuCtxSynchronize();
- if (res != CUDA_SUCCESS) {
- LOG(FATAL) << "gpuexec_cuda_sync_around_driver_calls found "
- << ToString(res)
- << " immediately after establishing the device context "
- << context << " :: " << port::CurrentStackTrace();
- }
+namespace {
+
+// Call cuCtxSynchronize and crash if it doesn't succeed.
+void SynchronizeOrDie() {
+ auto res = dynload::cuCtxSynchronize();
+ if (res != CUDA_SUCCESS) {
+ LOG(FATAL) << "Synchronize found "
+ << ToString(res) << " :: " << port::CurrentStackTrace();
+ }
+}
+
+struct ThreadLocalData {
+ int64 id;
+ CudaContext* context; // Only valid if id == a known good context.
+ int depth;
+};
+
+SE_STATIC_THREAD_LOCAL_POD(ThreadLocalData, tls_data);
+
+} // namespace
+
+ScopedActivateContext::ScopedActivateContext(CudaContext* cuda_context) {
+ if (FLAGS_gpuexec_cuda_sync_around_driver_calls) SynchronizeOrDie();
+
+ auto* tls = &tls_data.get();
+ tls->depth++;
+ if (tls->id == cuda_context->id()) {
+ if (kVerifyCudaContext) {
+ CHECK_EQ(CurrentContext(), cuda_context->context());
}
+ DCHECK_EQ(CurrentContext(), cuda_context->context());
+ return;
}
+
+ VLOG(3) << "ScopedActivateContext switching context from " << tls->id
+ << " to " << cuda_context->id();
+
+ to_restore_ = (tls->depth == 1 ? nullptr : tls->context);
+
+ // Set the context and update thread local.
+ CHECK_EQ(CUDA_SUCCESS, dynload::cuCtxSetCurrent(cuda_context->context()));
+ tls->id = cuda_context->id();
+ tls->context = cuda_context;
}
ScopedActivateContext::~ScopedActivateContext() {
- if (tls_in_multi_op_activation.get()) {
- DCHECK_EQ(context_, CurrentContext());
- if (FLAGS_gpuexec_cuda_sync_around_driver_calls) {
- auto res = dynload::cuCtxSynchronize();
- if (res != CUDA_SUCCESS) {
- LOG(FATAL) << "gpuexec_cuda_sync_around_driver_calls found "
- << ToString(res)
- << " immediately after de-establishing the device context "
- << context_ << " :: " << port::CurrentStackTrace();
- }
- }
- CHECK_EQ(CUDA_SUCCESS, dynload::cuCtxSetCurrent(prior_context_));
- } else {
- PopContextAndCheckNowNull(context_);
+ if (FLAGS_gpuexec_cuda_sync_around_driver_calls) SynchronizeOrDie();
+
+ auto* tls = &tls_data.get();
+
+ if (kVerifyCudaContext) {
+ // Note that if kVerifyCudaContext is used, and contexts are deleted, it's
+ // possible this could fail in the CurrentContext() call.
+ CHECK_EQ(CurrentContext(),
+ tls->context == nullptr ? nullptr : tls->context->context());
+ }
+
+ tls->depth--;
+ DCHECK_GE(tls->depth, 0);
+ if (to_restore_ == nullptr) {
+ // Leave context, tls->id, and tls->context set.
+ return;
}
- tls_in_multi_op_activation.get() = previously_in_multi_op_activation_;
+
+ // Set context and update thread local.
+ CHECK_EQ(CUDA_SUCCESS, dynload::cuCtxSetCurrent(to_restore_->context()));
+ tls->id = to_restore_->id();
+ tls->context = to_restore_;
}
namespace {
@@ -556,7 +564,9 @@ bool DeviceOptionsToContextFlags(DeviceOptions device_options, int *flags) {
}
/* static */ port::Status CUDADriver::CreateContext(
- CUdevice device, DeviceOptions device_options, CUcontext *context) {
+ CUdevice device, DeviceOptions device_options, CudaContext** context) {
+ *context = nullptr;
+
CUcontext former_context = CurrentContext();
if (former_context != nullptr) {
LOG(WARNING) << "creating context when one is currently active; existing: "
@@ -569,15 +579,17 @@ bool DeviceOptionsToContextFlags(DeviceOptions device_options, int *flags) {
}
CUresult res;
+ CUcontext new_context;
{
// TODO(leary) Need to see if NVIDIA can expunge the leakiness in their
// context creation: see http://b/13248943
- res = dynload::cuCtxCreate_v2(context, flags, device);
+ res = dynload::cuCtxCreate_v2(&new_context, flags, device);
}
+ CHECK_EQ(CUDA_SUCCESS, dynload::cuCtxSetCurrent(former_context));
+
if (res == CUDA_SUCCESS) {
- CreatedContexts::Add(*context);
- PopContextAndCheckNowNull(*context);
+ *context = CreatedContexts::Add(new_context);
CHECK(*context != nullptr)
<< "success in this call must entail non-null result";
VLOG(2) << "created context " << context << " for this thread";
@@ -597,17 +609,17 @@ bool DeviceOptionsToContextFlags(DeviceOptions device_options, int *flags) {
return port::Status{port::error::INTERNAL, message};
}
-/* static */ void CUDADriver::DestroyContext(CUcontext context) {
+/* static */ void CUDADriver::DestroyContext(CudaContext* context) {
if (context == nullptr) {
return;
}
- CUresult res = dynload::cuCtxDestroy_v2(context);
+ CUresult res = dynload::cuCtxDestroy_v2(context->context());
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "failed to destroy CUDA context; leaking: " << ToString(res);
}
- CreatedContexts::Remove(context);
+ CreatedContexts::Remove(context->context());
}
/* static */ bool CUDADriver::FuncGetAttribute(CUfunction_attribute attribute,
@@ -635,7 +647,7 @@ bool DeviceOptionsToContextFlags(DeviceOptions device_options, int *flags) {
}
/* static */ port::StatusOr<CUsharedconfig>
-CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
+CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
CUsharedconfig shared_mem_config;
ScopedActivateContext activation{context};
CUresult result = dynload::cuCtxGetSharedMemConfig(&shared_mem_config);
@@ -653,7 +665,7 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
}
/* static */ port::Status CUDADriver::ContextSetSharedMemConfig(
- CUcontext context, CUsharedconfig shared_mem_config) {
+ CudaContext* context, CUsharedconfig shared_mem_config) {
ScopedActivateContext activation{context};
CUresult result = dynload::cuCtxSetSharedMemConfig(shared_mem_config);
if (result != CUDA_SUCCESS) {
@@ -671,7 +683,7 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
}
/* static */ bool CUDADriver::LaunchKernel(
- CUcontext context, CUfunction function, unsigned int grid_dim_x,
+ CudaContext* context, CUfunction function, unsigned int grid_dim_x,
unsigned int grid_dim_y, unsigned int grid_dim_z, unsigned int block_dim_x,
unsigned int block_dim_y, unsigned int block_dim_z,
unsigned int shared_mem_bytes, CUstream stream, void **kernel_params,
@@ -693,7 +705,7 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
return true;
}
-/* static */ port::Status CUDADriver::LoadCubin(CUcontext context,
+/* static */ port::Status CUDADriver::LoadCubin(CudaContext* context,
const char *cubin_bytes,
CUmodule *module) {
ScopedActivateContext activation{context};
@@ -706,7 +718,7 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
return port::Status::OK();
}
-/* static */ bool CUDADriver::LoadPtx(CUcontext context,
+/* static */ bool CUDADriver::LoadPtx(CudaContext* context,
const char *ptx_contents,
CUmodule *module) {
port::Notification notification;
@@ -776,7 +788,7 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
return ret;
}
-/* static */ bool CUDADriver::SynchronousMemsetUint8(CUcontext context,
+/* static */ bool CUDADriver::SynchronousMemsetUint8(CudaContext* context,
CUdeviceptr location,
uint8 value, size_t size) {
ScopedActivateContext activation{context};
@@ -788,7 +800,7 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
return true;
}
-/* static */ bool CUDADriver::SynchronousMemsetUint32(CUcontext context,
+/* static */ bool CUDADriver::SynchronousMemsetUint32(CudaContext* context,
CUdeviceptr location,
uint32 value,
size_t uint32_count) {
@@ -801,7 +813,7 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
return true;
}
-/* static */ bool CUDADriver::AsynchronousMemsetUint8(CUcontext context,
+/* static */ bool CUDADriver::AsynchronousMemsetUint8(CudaContext* context,
CUdeviceptr location,
uint8 value,
size_t uint32_count,
@@ -817,7 +829,7 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
return true;
}
-/* static */ bool CUDADriver::AsynchronousMemsetUint32(CUcontext context,
+/* static */ bool CUDADriver::AsynchronousMemsetUint32(CudaContext* context,
CUdeviceptr location,
uint32 value,
size_t uint32_count,
@@ -833,7 +845,7 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
return true;
}
-/* static */ bool CUDADriver::AddStreamCallback(CUcontext context,
+/* static */ bool CUDADriver::AddStreamCallback(CudaContext* context,
CUstream stream,
StreamCallback callback,
void *data) {
@@ -847,7 +859,7 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
return true;
}
-/* static */ bool CUDADriver::GetModuleFunction(CUcontext context,
+/* static */ bool CUDADriver::GetModuleFunction(CudaContext *context,
CUmodule module,
const char *kernel_name,
CUfunction *function) {
@@ -863,7 +875,7 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
return true;
}
-/* static */ bool CUDADriver::GetModuleSymbol(CUcontext context,
+/* static */ bool CUDADriver::GetModuleSymbol(CudaContext* context,
CUmodule module,
const char *symbol_name,
CUdeviceptr *dptr,
@@ -884,7 +896,8 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
return true;
}
-/* static */ void CUDADriver::UnloadModule(CUcontext context, CUmodule module) {
+/* static */ void CUDADriver::UnloadModule(CudaContext *context,
+ CUmodule module) {
ScopedActivateContext activated{context};
CUresult res = dynload::cuModuleUnload(module);
if (res != CUDA_SUCCESS) {
@@ -894,7 +907,7 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
}
/* static */ port::StatusOr<CUdevice> CUDADriver::DeviceFromContext(
- CUcontext context) {
+ CudaContext* context) {
ScopedActivateContext activated{context};
CUdevice device = -1;
CUresult result = dynload::cuCtxGetDevice(&device);
@@ -907,7 +920,8 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
port::StrCat("failed to get device for context: ", ToString(result))};
}
-/* static */ bool CUDADriver::CreateStream(CUcontext context, CUstream *out) {
+/* static */ bool CUDADriver::CreateStream(CudaContext *context,
+ CUstream *out) {
// TODO(leary) can we switch this to CU_STREAM_NON_BLOCKING or will that mess
// up synchronization with respect to memsets and any other things that have
// to occur on the default stream?
@@ -924,7 +938,7 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
return true;
}
-/* static */ void CUDADriver::DestroyStream(CUcontext context,
+/* static */ void CUDADriver::DestroyStream(CudaContext* context,
CUstream *stream) {
if (*stream == nullptr) {
return;
@@ -942,7 +956,8 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
}
}
-/* static */ void *CUDADriver::DeviceAllocate(CUcontext context, uint64 bytes) {
+/* static */ void *CUDADriver::DeviceAllocate(CudaContext *context,
+ uint64 bytes) {
ScopedActivateContext activated{context};
CUdeviceptr result = 0;
CUresult res = dynload::cuMemAlloc_v2(&result, bytes);
@@ -958,7 +973,7 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
return ptr;
}
-/* static */ void CUDADriver::DeviceDeallocate(CUcontext context,
+/* static */ void CUDADriver::DeviceDeallocate(CudaContext* context,
void *location) {
ScopedActivateContext activation{context};
CUdeviceptr pointer = port::bit_cast<CUdeviceptr>(location);
@@ -971,7 +986,8 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
}
}
-/* static */ void *CUDADriver::HostAllocate(CUcontext context, uint64 bytes) {
+/* static */ void *CUDADriver::HostAllocate(CudaContext *context,
+ uint64 bytes) {
ScopedActivateContext activation{context};
void *host_mem = nullptr;
// "Portable" memory is visible to all CUDA contexts. Safe for our use model.
@@ -984,7 +1000,7 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
return host_mem;
}
-/* static */ void CUDADriver::HostDeallocate(CUcontext context,
+/* static */ void CUDADriver::HostDeallocate(CudaContext* context,
void *location) {
ScopedActivateContext activation{context};
CUresult res = dynload::cuMemFreeHost(location);
@@ -994,7 +1010,7 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
}
}
-/* static */ bool CUDADriver::HostRegister(CUcontext context, void *location,
+/* static */ bool CUDADriver::HostRegister(CudaContext* context, void *location,
uint64 bytes) {
ScopedActivateContext activation{context};
// "Portable" memory is visible to all CUDA contexts. Safe for our use model.
@@ -1008,7 +1024,7 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
return true;
}
-/* static */ bool CUDADriver::HostUnregister(CUcontext context,
+/* static */ bool CUDADriver::HostUnregister(CudaContext* context,
void *location) {
ScopedActivateContext activation{context};
CUresult res = dynload::cuMemHostUnregister(location);
@@ -1020,7 +1036,7 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
return true;
}
-/* static */ port::Status CUDADriver::DestroyEvent(CUcontext context,
+/* static */ port::Status CUDADriver::DestroyEvent(CudaContext* context,
CUevent *event) {
if (*event == nullptr) {
return port::Status{port::error::INVALID_ARGUMENT,
@@ -1048,7 +1064,7 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
}
}
-/* static */ port::Status CUDADriver::RecordEvent(CUcontext context,
+/* static */ port::Status CUDADriver::RecordEvent(CudaContext* context,
CUevent event,
CUstream stream) {
ScopedActivateContext activated{context};
@@ -1070,8 +1086,8 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
}
}
-/* static */ port::StatusOr<CUresult> CUDADriver::QueryEvent(CUcontext context,
- CUevent event) {
+/* static */ port::StatusOr<CUresult> CUDADriver::QueryEvent(
+ CudaContext *context, CUevent event) {
ScopedActivateContext activated{context};
CUresult res = dynload::cuEventQuery(event);
if (res != CUDA_SUCCESS && res != CUDA_ERROR_NOT_READY) {
@@ -1083,7 +1099,7 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
return res;
}
-/* static */ bool CUDADriver::GetEventElapsedTime(CUcontext context,
+/* static */ bool CUDADriver::GetEventElapsedTime(CudaContext* context,
float *elapsed_milliseconds,
CUevent start, CUevent stop) {
ScopedActivateContext activated{context};
@@ -1104,7 +1120,7 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
return true;
}
-/* static */ bool CUDADriver::WaitStreamOnEvent(CUcontext context,
+/* static */ bool CUDADriver::WaitStreamOnEvent(CudaContext* context,
CUstream stream,
CUevent event) {
ScopedActivateContext activation{context};
@@ -1117,7 +1133,7 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
return true;
}
-/* static */ bool CUDADriver::SynchronizeContext(CUcontext context) {
+/* static */ bool CUDADriver::SynchronizeContext(CudaContext* context) {
ScopedActivateContext activation{context};
CUresult res = dynload::cuCtxSynchronize();
if (res != CUDA_SUCCESS) {
@@ -1129,7 +1145,7 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
return true;
}
-/* static */ bool CUDADriver::SynchronizeStream(CUcontext context,
+/* static */ bool CUDADriver::SynchronizeStream(CudaContext* context,
CUstream stream) {
ScopedActivateContext activated{context};
CHECK(stream != nullptr);
@@ -1144,7 +1160,8 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
return true;
}
-/* static */ bool CUDADriver::IsStreamIdle(CUcontext context, CUstream stream) {
+/* static */ bool CUDADriver::IsStreamIdle(CudaContext *context,
+ CUstream stream) {
ScopedActivateContext activated{context};
CHECK(stream != nullptr);
CUresult res = dynload::cuStreamQuery(stream);
@@ -1158,7 +1175,7 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
return false;
}
-/* static */ bool CUDADriver::SynchronousMemcpyD2H(CUcontext context,
+/* static */ bool CUDADriver::SynchronousMemcpyD2H(CudaContext* context,
void *host_dst,
CUdeviceptr gpu_src,
uint64 size) {
@@ -1176,7 +1193,7 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
return true;
}
-/* static */ bool CUDADriver::SynchronousMemcpyH2D(CUcontext context,
+/* static */ bool CUDADriver::SynchronousMemcpyH2D(CudaContext* context,
CUdeviceptr gpu_dst,
const void *host_src,
uint64 size) {
@@ -1193,7 +1210,7 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
return true;
}
-/* static */ bool CUDADriver::SynchronousMemcpyD2D(CUcontext context,
+/* static */ bool CUDADriver::SynchronousMemcpyD2D(CudaContext* context,
CUdeviceptr gpu_dst,
CUdeviceptr gpu_src,
uint64 size) {
@@ -1211,7 +1228,7 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
return true;
}
-/* static */ bool CUDADriver::AsynchronousMemcpyD2H(CUcontext context,
+/* static */ bool CUDADriver::AsynchronousMemcpyD2H(CudaContext* context,
void *host_dst,
CUdeviceptr gpu_src,
uint64 size,
@@ -1231,7 +1248,7 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
return true;
}
-/* static */ bool CUDADriver::AsynchronousMemcpyH2D(CUcontext context,
+/* static */ bool CUDADriver::AsynchronousMemcpyH2D(CudaContext* context,
CUdeviceptr gpu_dst,
const void *host_src,
uint64 size,
@@ -1250,7 +1267,7 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
return true;
}
-/* static */ bool CUDADriver::AsynchronousMemcpyD2D(CUcontext context,
+/* static */ bool CUDADriver::AsynchronousMemcpyD2D(CudaContext* context,
CUdeviceptr gpu_dst,
CUdeviceptr gpu_src,
uint64 size,
@@ -1277,7 +1294,7 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
return true;
}
-/* static */ port::Status CUDADriver::CreateEvent(CUcontext context,
+/* static */ port::Status CUDADriver::CreateEvent(CudaContext* context,
CUevent *result,
EventFlags flags) {
int cuflags;
@@ -1321,9 +1338,9 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
return device_count;
}
-/* static */ port::StatusOr<CUcontext> CUDADriver::GetPointerContext(
+/* static */ port::StatusOr<CudaContext*> CUDADriver::GetPointerContext(
CUdeviceptr pointer) {
- CUcontext context = nullptr;
+ CudaContext* context = nullptr;
CUresult result = dynload::cuPointerGetAttribute(
&context, CU_POINTER_ATTRIBUTE_CONTEXT, pointer);
if (result == CUDA_SUCCESS) {
@@ -1532,7 +1549,7 @@ static port::StatusOr<T> GetSimpleAttribute(CUdevice device,
return true;
}
-/* static */ bool CUDADriver::GetDeviceMemoryInfo(CUcontext context,
+/* static */ bool CUDADriver::GetDeviceMemoryInfo(CudaContext* context,
int64 *free_out,
int64 *total_out) {
ScopedActivateContext activation{context};
@@ -1577,8 +1594,8 @@ static port::StatusOr<T> GetSimpleAttribute(CUdevice device,
return pci_bus_id;
}
-/* static */ bool CUDADriver::CanEnablePeerAccess(CUcontext from,
- CUcontext to) {
+/* static */ bool CUDADriver::CanEnablePeerAccess(CudaContext* from,
+ CudaContext* to) {
if (from == to) {
return true; // A context can always access its own memory.
}
@@ -1606,14 +1623,15 @@ static port::StatusOr<T> GetSimpleAttribute(CUdevice device,
return can_access_peer;
}
-/* static */ port::Status CUDADriver::EnablePeerAccess(CUcontext from,
- CUcontext to) {
+/* static */ port::Status CUDADriver::EnablePeerAccess(CudaContext* from,
+ CudaContext* to) {
if (from == to) {
return port::Status::OK(); // A context can always access its own memory.
}
ScopedActivateContext activated{from};
- CUresult result = dynload::cuCtxEnablePeerAccess(to, 0 /* = flags */);
+ CUresult result =
+ dynload::cuCtxEnablePeerAccess(to->context(), 0 /* = flags */);
if (result != CUDA_SUCCESS &&
result != CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED) {
return port::Status{
@@ -1626,7 +1644,7 @@ static port::StatusOr<T> GetSimpleAttribute(CUdevice device,
}
/* static */ port::StatusOr<int> CUDADriver::GetMaxOccupiedBlocksPerCore(
- CUcontext context, CUfunction kernel, int threads_per_block,
+ CudaContext* context, CUfunction kernel, int threads_per_block,
size_t dynamic_shared_memory_bytes) {
ScopedActivateContext activation{context};
@@ -1643,6 +1661,15 @@ static port::StatusOr<T> GetSimpleAttribute(CUdevice device,
return max_blocks;
}
+/* static */ CUcontext CUDADriver::CurrentContextOrDie() {
+ CUcontext current = nullptr;
+ CUresult result = dynload::cuCtxGetCurrent(&current);
+ if (result != CUDA_SUCCESS) {
+ LOG(FATAL) << "failed to query current context: " << ToString(result);
+ }
+ return current;
+}
+
} // namespace cuda
} // namespace gputools
} // namespace perftools
diff --git a/tensorflow/stream_executor/cuda/cuda_driver.h b/tensorflow/stream_executor/cuda/cuda_driver.h
index b887d048a4..fa227b953d 100644
--- a/tensorflow/stream_executor/cuda/cuda_driver.h
+++ b/tensorflow/stream_executor/cuda/cuda_driver.h
@@ -21,7 +21,6 @@ limitations under the License.
#include <stddef.h>
#include "tensorflow/stream_executor/platform/port.h"
-#include "tensorflow/stream_executor/cuda/multi_op_activation.h"
#include "tensorflow/stream_executor/device_options.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
@@ -39,6 +38,8 @@ enum class MemorySpace { kHost, kDevice };
// Returns a casual string, such as "host" for the provided memory space.
string MemorySpaceString(MemorySpace memory_space);
+class CudaContext;
+
// CUDADriver contains wrappers for calls to the userspace library driver. It's
// useful to isolate these calls and put basic wrappers around them to separate
// userspace library driver behaviors from the rest of the program.
@@ -66,19 +67,19 @@ class CUDADriver {
// Returns the device associated with the given context.
// device is an outparam owned by the caller, must not be null.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g4e84b109eba36cdaaade167f34ae881e
- static port::StatusOr<CUdevice> DeviceFromContext(CUcontext context);
+ static port::StatusOr<CUdevice> DeviceFromContext(CudaContext* context);
// Creates a new CUDA stream associated with the given context via
// cuStreamCreate.
// stream is an outparam owned by the caller, must not be null.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__STREAM.html#group__CUDA__STREAM_1ga581f0c5833e21ded8b5a56594e243f4
- static bool CreateStream(CUcontext context, CUstream *stream);
+ static bool CreateStream(CudaContext* context, CUstream *stream);
// Destroys a CUDA stream associated with the given context.
// stream is owned by the caller, must not be null, and *stream is set to null
// if the stream is successfuly destroyed.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__STREAM.html#group__CUDA__STREAM_1g244c8833de4596bcd31a06cdf21ee758
- static void DestroyStream(CUcontext context, CUstream *stream);
+ static void DestroyStream(CudaContext* context, CUstream *stream);
// CUDA events can explicitly disable event TSC retrieval for some presumed
// performance improvement if timing is unnecessary.
@@ -88,36 +89,36 @@ class CUDADriver {
// Creates a new event associated with the given context.
// result is an outparam owned by the caller and must not be null.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EVENT.html#group__CUDA__EVENT_1g450687e75f3ff992fe01662a43d9d3db
- static port::Status CreateEvent(CUcontext context, CUevent *result,
+ static port::Status CreateEvent(CudaContext* context, CUevent *result,
EventFlags flags);
// Destroys *event and turns it into a nullptr. event may not be null, but
// *event may be, via cuEventDestroy
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EVENT.html#group__CUDA__EVENT_1g593ec73a8ec5a5fc031311d3e4dca1ef
- static port::Status DestroyEvent(CUcontext context, CUevent *event);
+ static port::Status DestroyEvent(CudaContext* context, CUevent *event);
// Allocates a GPU memory space of size bytes associated with the given
// context via cuMemAlloc.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1gb82d2a09844a58dd9e744dc31e8aa467
- static void *DeviceAllocate(CUcontext context, uint64 bytes);
+ static void *DeviceAllocate(CudaContext* context, uint64 bytes);
// Deallocates a GPU memory space of size bytes associated with the given
// context via cuMemFree.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g89b3f154e17cc89b6eea277dbdf5c93a
- static void DeviceDeallocate(CUcontext context, void *location);
+ static void DeviceDeallocate(CudaContext* context, void *location);
// Allocates page-locked and CUDA-registered memory on the host via
// cuMemAllocHost.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1gdd8311286d2c2691605362c689bc64e0
- static void *HostAllocate(CUcontext context, uint64 bytes);
+ static void *HostAllocate(CudaContext* context, uint64 bytes);
// Deallocates a location created by HostAllocate, via cuMemFreeHost.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g62e0fdbe181dab6b1c90fa1a51c7b92c
- static void HostDeallocate(CUcontext context, void *location);
+ static void HostDeallocate(CudaContext* context, void *location);
// Registers a memory region at location of size bytes via cuMemHostRegister.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1gf0a9fe11544326dabd743b7aa6b54223
- static bool HostRegister(CUcontext context, void *location, uint64 bytes);
+ static bool HostRegister(CudaContext* context, void *location, uint64 bytes);
// Unregisters a memory region that was previously registered at location via
// cuMemHostUnregister.
@@ -126,7 +127,7 @@ class CUDADriver {
//
// TODO(leary) verify an error will be returned if the location wasn't
// previously registered.
- static bool HostUnregister(CUcontext context, void *location);
+ static bool HostUnregister(CudaContext* context, void *location);
// Given a device ordinal, returns a device handle into the device outparam,
// which must not be null.
@@ -148,13 +149,13 @@ class CUDADriver {
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g65dc0012348bc84810e2103a40d8e2cf
static port::Status CreateContext(CUdevice device,
DeviceOptions device_options,
- CUcontext *context);
+ CudaContext** context);
// Destroys the provided context via cuCtxDestroy.
// Don't do this while clients could still be using the context, per the docs
// bad things will happen.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g27a365aebb0eb548166309f58a1e8b8e
- static void DestroyContext(CUcontext context);
+ static void DestroyContext(CudaContext* context);
// Queries the runtime for the specified attribute of the specified function.
// cuFuncGetAttribute (the underlying CUDA driver API routine) only operates
@@ -173,19 +174,19 @@ class CUDADriver {
// CONTEXT (not function!), either default or four- or eight-byte bank size.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g17153a1b8b8c756f7ab8505686a4ad74
static port::StatusOr<CUsharedconfig> ContextGetSharedMemConfig(
- CUcontext context);
+ CudaContext* context);
// Sets the preferred shared memory bank configuration for the specified
// CONTEXT (not function!), either default or four- or eight-byte bank size.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g2574235fa643f8f251bf7bc28fac3692
static port::Status ContextSetSharedMemConfig(
- CUcontext context, CUsharedconfig shared_mem_config);
+ CudaContext* context, CUsharedconfig shared_mem_config);
// Launches a CUDA kernel via cuLaunchKernel.
// TODO(leary) describe the structure of kernel_params and extra in a readable
// way.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1gb8f3dc3031b40da29d5f9a7139e52e15
- static bool LaunchKernel(CUcontext context, CUfunction function,
+ static bool LaunchKernel(CudaContext* context, CUfunction function,
unsigned int grid_dim_x, unsigned int grid_dim_y,
unsigned int grid_dim_z, unsigned int block_dim_x,
unsigned int block_dim_y, unsigned int block_dim_z,
@@ -194,25 +195,25 @@ class CUDADriver {
// Loads ptx_contents with the CUDA driver's PTX JIT and stores the resulting
// handle in "module". Any error logs that are produced are logged internally.
- static bool LoadPtx(CUcontext context, const char *ptx_contents,
+ static bool LoadPtx(CudaContext* context, const char *ptx_contents,
CUmodule *module);
// Loads cubin_bytes with the CUDA driver's blob loading interface and stores
// the resulting handle in "module".
- static port::Status LoadCubin(CUcontext context, const char *cubin_bytes,
+ static port::Status LoadCubin(CudaContext* context, const char *cubin_bytes,
CUmodule *module);
// Retrieves a named kernel from a loaded module, and places the resulting
// handle into function (outparam) on success. Neither kernel_name nor
// function may be null. No ownership is taken of kernel_name.
- static bool GetModuleFunction(CUcontext context, CUmodule module,
+ static bool GetModuleFunction(CudaContext* context, CUmodule module,
const char *kernel_name, CUfunction *function);
// Retrieves a named global/constant symbol from a loaded module, and returns
// a device pointer and size of the symbol on success. symbol_name may not be
// null. At least one of dptr or bytes should not be null. No ownership is
// taken of symbol_name.
- static bool GetModuleSymbol(CUcontext context, CUmodule module,
+ static bool GetModuleSymbol(CudaContext* context, CUmodule module,
const char *symbol_name, CUdeviceptr *dptr,
size_t *bytes);
@@ -220,52 +221,53 @@ class CUDADriver {
// TODO(leary) the documentation doesn't say what kind of disasters happen
// if you try to unload a module while its CUfunctions are in use.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MODULE.html#group__CUDA__MODULE_1g8ea3d716524369de3763104ced4ea57b
- static void UnloadModule(CUcontext context, CUmodule module);
+ static void UnloadModule(CudaContext* context, CUmodule module);
// Performs a synchronous memset of the device memory segment via cuMemsetD8.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g6e582bf866e9e2fb014297bfaf354d7b
- static bool SynchronousMemsetUint8(CUcontext context, CUdeviceptr location,
+ static bool SynchronousMemsetUint8(CudaContext* context, CUdeviceptr location,
uint8 value, size_t size);
// Performs a synchronous memset of the device memory segment via cuMemsetD32.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g983e8d8759acd1b64326317481fbf132
- static bool SynchronousMemsetUint32(CUcontext context, CUdeviceptr location,
- uint32 value, size_t uint32_count);
+ static bool SynchronousMemsetUint32(CudaContext* context,
+ CUdeviceptr location, uint32 value,
+ size_t uint32_count);
// Performs an asynchronous memset of the device memory segment via
// cuMemsetD8Async.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1gaef08a7ccd61112f94e82f2b30d43627
- static bool AsynchronousMemsetUint8(CUcontext context, CUdeviceptr location,
+ static bool AsynchronousMemsetUint8(CudaContext* context, CUdeviceptr location,
uint8 value, size_t uint32_count,
CUstream stream);
// Performs an asynchronous memset of the device memory segment via
// cuMemsetD32Async.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g58229da5d30f1c0cdf667b320ec2c0f5
- static bool AsynchronousMemsetUint32(CUcontext context, CUdeviceptr location,
- uint32 value, size_t uint32_count,
- CUstream stream);
+ static bool AsynchronousMemsetUint32(CudaContext* context,
+ CUdeviceptr location, uint32 value,
+ size_t uint32_count, CUstream stream);
// -- Synchronous memcopies.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g4d32266788c440b0220b1a9ba5795169
- static bool SynchronousMemcpyD2H(CUcontext context, void *host_dst,
+ static bool SynchronousMemcpyD2H(CudaContext* context, void *host_dst,
CUdeviceptr gpu_src, uint64 size);
- static bool SynchronousMemcpyH2D(CUcontext context, CUdeviceptr gpu_dst,
+ static bool SynchronousMemcpyH2D(CudaContext* context, CUdeviceptr gpu_dst,
const void *host_src, uint64 size);
- static bool SynchronousMemcpyD2D(CUcontext context, CUdeviceptr gpu_dst,
+ static bool SynchronousMemcpyD2D(CudaContext* context, CUdeviceptr gpu_dst,
CUdeviceptr gpu_src, uint64 size);
// -- Asynchronous memcopies.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g56f30236c7c5247f8e061b59d3268362
- static bool AsynchronousMemcpyD2H(CUcontext context, void *host_dst,
+ static bool AsynchronousMemcpyD2H(CudaContext* context, void *host_dst,
CUdeviceptr gpu_src, uint64 size,
CUstream stream);
- static bool AsynchronousMemcpyH2D(CUcontext context, CUdeviceptr gpu_dst,
+ static bool AsynchronousMemcpyH2D(CudaContext* context, CUdeviceptr gpu_dst,
const void *host_src, uint64 size,
CUstream stream);
- static bool AsynchronousMemcpyD2D(CUcontext context, CUdeviceptr gpu_dst,
+ static bool AsynchronousMemcpyD2D(CudaContext* context, CUdeviceptr gpu_dst,
CUdeviceptr gpu_src, uint64 size,
CUstream stream);
@@ -283,13 +285,13 @@ class CUDADriver {
// Enqueues a callback operation into stream.
// See StreamCallback above and the NVIDIA documentation for additional
// details.
- static bool AddStreamCallback(CUcontext context, CUstream stream,
+ static bool AddStreamCallback(CudaContext* context, CUstream stream,
StreamCallback callback, void *data);
// Causes stream to wait for event to trigger before proceeding via
// cuStreamWaitEvent.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__STREAM.html#axzz334nAXAhM
- static bool WaitStreamOnEvent(CUcontext context, CUstream stream,
+ static bool WaitStreamOnEvent(CudaContext* context, CUstream stream,
CUevent event);
// Blocks the calling thread until the operations enqueued onto stream have
@@ -300,50 +302,51 @@ class CUDADriver {
// amount of time?
//
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__STREAM.html#group__CUDA__STREAM_1g15e49dd91ec15991eb7c0a741beb7dad
- static bool SynchronizeStream(CUcontext context, CUstream stream);
+ static bool SynchronizeStream(CudaContext* context, CUstream stream);
// Blocks the calling thread until the operations associated with the context
// have been completed, via cuCtxSynchronize.
//
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g7a54725f28d34b8c6299f0c6ca579616
- static bool SynchronizeContext(CUcontext context);
+ static bool SynchronizeContext(CudaContext* context);
// Returns true if all stream tasks have completed at time of the call. Note
// the potential for races around this call (if another thread adds work to
// the stream immediately after this returns).
- static bool IsStreamIdle(CUcontext context, CUstream stream);
+ static bool IsStreamIdle(CudaContext* context, CUstream stream);
// Returns whether code in the from context can access memory in the to
// context via cuDeviceCanAccessPeer.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__PEER__ACCESS.html#group__CUDA__PEER__ACCESS_1g496bdaae1f632ebfb695b99d2c40f19e
- static bool CanEnablePeerAccess(CUcontext from, CUcontext to);
+ static bool CanEnablePeerAccess(CudaContext* from, CudaContext* to);
// Enables peer access per CanEnablePeerAccess, via cuCtxEnablePeerAccess.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__PEER__ACCESS.html#group__CUDA__PEER__ACCESS_1g0889ec6728e61c05ed359551d67b3f5a
- static port::Status EnablePeerAccess(CUcontext from, CUcontext to);
+ static port::Status EnablePeerAccess(CudaContext* from, CudaContext* to);
// Returns the elapsed milliseconds between start and stop via
// cuEventElapsedTime.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EVENT.html#group__CUDA__EVENT_1gdfb1178807353bbcaa9e245da497cf97
- static bool GetEventElapsedTime(CUcontext context,
+ static bool GetEventElapsedTime(CudaContext* context,
float *elapsed_milliseconds, CUevent start,
CUevent stop);
// Records that an event occurred when execution reaches the current point in
// the stream via cuEventRecord.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EVENT.html#group__CUDA__EVENT_1g95424d3be52c4eb95d83861b70fb89d1
- static port::Status RecordEvent(CUcontext context, CUevent event,
+ static port::Status RecordEvent(CudaContext* context, CUevent event,
CUstream stream);
// Polls (without blocking) to determine the status of an event - pending or
// complete (or an error status).
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EVENT.html#group__CUDA__EVENT_1g6f0704d755066b0ee705749ae911deef
- static port::StatusOr<CUresult> QueryEvent(CUcontext context, CUevent event);
+ static port::StatusOr<CUresult> QueryEvent(CudaContext* context,
+ CUevent event);
// -- Pointer-specific calls.
// Returns the context in which pointer was allocated or registered.
- static port::StatusOr<CUcontext> GetPointerContext(CUdeviceptr pointer);
+ static port::StatusOr<CudaContext*> GetPointerContext(CUdeviceptr pointer);
// Returns the device associated with the context from GetPointerContext().
static port::StatusOr<CUdevice> GetPointerDevice(CUdeviceptr pointer);
@@ -413,7 +416,8 @@ class CUDADriver {
// Returns the free amount of memory and total amount of memory, as reported
// by cuMemGetInfo.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g808f555540d0143a331cc42aa98835c0
- static bool GetDeviceMemoryInfo(CUcontext context, int64 *free, int64 *total);
+ static bool GetDeviceMemoryInfo(CudaContext* context, int64* free,
+ int64* total);
// Returns a PCI bus id string for the device.
// [domain]:[bus]:[device].[function]
@@ -442,9 +446,13 @@ class CUDADriver {
// specified kernel/CUfunction when launched with the specified parameters.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__OCCUPANCY.html#group__CUDA__OCCUPANCY_1gcc6e1094d05cba2cee17fe33ddd04a98
static port::StatusOr<int> GetMaxOccupiedBlocksPerCore(
- CUcontext context, CUfunction kernel, int threads_per_block,
+ CudaContext* context, CUfunction kernel, int threads_per_block,
size_t dynamic_shared_memory_bytes);
+ // Returns the current context set in CUDA. This is done by calling the CUDA
+ // driver (e.g., this value is not our cached view of the current context).
+ static CUcontext CurrentContextOrDie();
+
// Seam for injecting an error at CUDA initialization time for testing
// purposes.
static bool driver_inject_init_error_;
@@ -457,22 +465,34 @@ class ScopedActivateContext {
// active context (a la cuCtxGetCurrent). Note the alternative push/pop
// mechanism is said by NVIDIA to be relatively slow and deprecated.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1gbe562ee6258b4fcc272ca6478ca2a2f7
- explicit ScopedActivateContext(
- CUcontext context, MultiOpActivation moa = MultiOpActivation::kNo);
+ explicit ScopedActivateContext(CudaContext* context);
// Checks that the context has remained activated for the duration of the
// scope.
~ScopedActivateContext();
private:
- CUcontext context_; // context being activated.
+ CudaContext* to_restore_ = nullptr;
+};
- CUcontext prior_context_; // context that was active when we were activated.
+// CudaContext wraps a cuda CUcontext handle, and includes a unique id. The
+// unique id is positive, and ids are not repeated within the process.
+class CudaContext {
+ public:
+ CudaContext(CUcontext context, int64 id) : context_(context), id_(id) { }
- // Stores whether this was instantiated during a MultiOpActivation, in which
- // case we will not pop the context when we're destroyed (we will leave it to
- // the parent MultiOpActivation that we were nested within).
- bool previously_in_multi_op_activation_;
+ CUcontext context() const { return context_; }
+ int64 id() const { return id_; }
+
+ // Disallow copying and moving.
+ CudaContext(CudaContext&&) = delete;
+ CudaContext(const CudaContext&) = delete;
+ CudaContext& operator=(CudaContext&&) = delete;
+ CudaContext& operator=(const CudaContext&) = delete;
+
+ private:
+ CUcontext const context_;
+ const int64 id_;
};
} // namespace cuda
diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
index 9757ef640d..4e3906da97 100644
--- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
+++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
@@ -114,12 +114,12 @@ static CUdeviceptr AsCudaDevicePtr(DeviceMemoryBase *gpu_mem) {
return AsCudaDevicePtr(*gpu_mem);
}
-static CUcontext GetCudaContext(Stream *stream) {
+static CudaContext* GetCudaContext(Stream *stream) {
return static_cast<CUDAExecutor *>(stream->parent()->implementation())
->cuda_context();
}
-CUcontext ExtractCudaContext(CUDAExecutor *cuda_exec) {
+CudaContext* ExtractCudaContext(CUDAExecutor *cuda_exec) {
CHECK(cuda_exec != nullptr);
return cuda_exec->cuda_context();
}
@@ -878,7 +878,7 @@ CUDAExecutor::GetTimerImplementation() {
void *CUDAExecutor::CudaContextHack() { return context_; }
-CUcontext CUDAExecutor::cuda_context() { return context_; }
+CudaContext* CUDAExecutor::cuda_context() { return context_; }
// Attempts to read the NUMA node corresponding to the GPU device's PCI bus out
// of SysFS. Returns -1 if it cannot.
diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h
index ccbe6f26fd..6997e51a6b 100644
--- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h
+++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h
@@ -217,7 +217,7 @@ class CUDAExecutor : public internal::StreamExecutorInterface {
void *CudaContextHack() override;
- CUcontext cuda_context();
+ CudaContext* cuda_context();
private:
// Attempts to find a more specific version of the file indicated by
@@ -272,7 +272,7 @@ class CUDAExecutor : public internal::StreamExecutorInterface {
CUdevice device_;
// Handle for session with the library/driver. Immutable post-initialization.
- CUcontext context_;
+ CudaContext* context_;
// The device ordinal value that this executor was initialized with; recorded
// for use in getting device metadata. Immutable post-initialization.
diff --git a/tensorflow/stream_executor/cuda/cuda_timer.cc b/tensorflow/stream_executor/cuda/cuda_timer.cc
index eb5be4300a..8d9d1bd01a 100644
--- a/tensorflow/stream_executor/cuda/cuda_timer.cc
+++ b/tensorflow/stream_executor/cuda/cuda_timer.cc
@@ -26,7 +26,7 @@ namespace cuda {
bool CUDATimer::Init() {
CHECK(start_event_ == nullptr && stop_event_ == nullptr);
- CUcontext context = parent_->cuda_context();
+ CudaContext* context = parent_->cuda_context();
if (!CUDADriver::CreateEvent(context, &start_event_,
CUDADriver::EventFlags::kDefault)
.ok()) {
@@ -48,7 +48,7 @@ bool CUDATimer::Init() {
}
void CUDATimer::Destroy() {
- CUcontext context = parent_->cuda_context();
+ CudaContext* context = parent_->cuda_context();
port::Status status = CUDADriver::DestroyEvent(context, &start_event_);
if (!status.ok()) {
LOG(ERROR) << status;