diff options
12 files changed, 32 insertions, 30 deletions
diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager.cc b/tensorflow/contrib/nccl/kernels/nccl_manager.cc index b1cb89391c..99fecf9651 100644 --- a/tensorflow/contrib/nccl/kernels/nccl_manager.cc +++ b/tensorflow/contrib/nccl/kernels/nccl_manager.cc @@ -445,7 +445,7 @@ void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) { se::Stream* comm_stream = nccl_stream->stream.get(); ScopedActivateExecutorContext scoped_context(nccl_stream->executor); const cudaStream_t* cu_stream = reinterpret_cast<const cudaStream_t*>( - comm_stream->implementation()->CudaStreamMemberHack()); + comm_stream->implementation()->GpuStreamMemberHack()); while (true) { // Find collective to run. diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc index 988b35f74f..2de7973750 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc @@ -65,7 +65,7 @@ class IncPluginTRT : public OpKernel { reinterpret_cast<const cudaStream_t*>(context->op_device_context() ->stream() ->implementation() - ->CudaStreamMemberHack())); + ->GpuStreamMemberHack())); IncrementKernel(input_tensor.flat<float>().data(), inc_, output_tensor->flat<float>().data(), input_shape.num_elements(), *stream); diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc index 8a17eb02f1..3daf810a4b 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc @@ -230,7 +230,7 @@ void TRTEngineOp::ExecuteCalibration(tensorflow::OpKernelContext* ctx, reinterpret_cast<const cudaStream_t*>(ctx->op_device_context() ->stream() ->implementation() - ->CudaStreamMemberHack())); + ->GpuStreamMemberHack())); calib_res->calibrator_->setBatch(input_data, *stream); VLOG(2) << "Passed calibration data"; ExecuteNativeSegment(ctx, helper); @@ -380,7 +380,7 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx, reinterpret_cast<const cudaStream_t*>(ctx->op_device_context() ->stream() ->implementation() - ->CudaStreamMemberHack())); + ->GpuStreamMemberHack())); // TODO(jie): trt enqueue does not return error auto& trt_execution_context_ptr = engine_ctx_pair.second; diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc index 3cb51b0dbc..f38ccd0d5b 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc @@ -856,7 +856,7 @@ void BaseGPUDevice::ReinitializeDevice(OpKernelContext* context, static_cast<ConcretePerOpGpuDevice*>(device); DCHECK(concrete_device); const cudaStream_t* cuda_stream = reinterpret_cast<const cudaStream_t*>( - streams_[stream_id]->compute->implementation()->CudaStreamMemberHack()); + streams_[stream_id]->compute->implementation()->GpuStreamMemberHack()); concrete_device->Reinitialize(context, cuda_stream, tf_gpu_id_, allocator, scratch_[stream_id]); } diff --git a/tensorflow/core/kernels/cuda_solvers.cc b/tensorflow/core/kernels/cuda_solvers.cc index a857bd3ce4..a59baaa96f 100644 --- a/tensorflow/core/kernels/cuda_solvers.cc +++ b/tensorflow/core/kernels/cuda_solvers.cc @@ -151,7 +151,7 @@ CudaSolver::CudaSolver(OpKernelContext* context) : context_(context) { reinterpret_cast<const cudaStream_t*>(context->op_device_context() ->stream() ->implementation() - ->CudaStreamMemberHack())); + ->GpuStreamMemberHack())); cuda_stream_ = *cu_stream_ptr; HandleMap* handle_map = CHECK_NOTNULL(GetHandleMapSingleton()); auto it = handle_map->find(cuda_stream_); diff --git a/tensorflow/core/util/cuda_launch_config.h b/tensorflow/core/util/cuda_launch_config.h index 81df7a51d7..d0d95736d3 100644 --- a/tensorflow/core/util/cuda_launch_config.h +++ b/tensorflow/core/util/cuda_launch_config.h @@ -295,7 +295,7 @@ inline const cudaStream_t& GetCudaStream(OpKernelContext* context) { reinterpret_cast<const cudaStream_t*>(context->op_device_context() ->stream() ->implementation() - ->CudaStreamMemberHack())); + ->GpuStreamMemberHack())); return *ptr; } diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc index f11022ef1d..259c813c57 100644 --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc @@ -844,7 +844,7 @@ CUDAExecutor::GetTimerImplementation() { return std::unique_ptr<internal::TimerInterface>(new CUDATimer(this)); } -void *CUDAExecutor::CudaContextHack() { return context_; } +void *CUDAExecutor::GpuContextHack() { return context_; } CudaContext* CUDAExecutor::cuda_context() { return context_; } diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h index 773cbfb8a1..f7c341c857 100644 --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h @@ -210,7 +210,7 @@ class CUDAExecutor : public internal::StreamExecutorInterface { std::unique_ptr<internal::TimerInterface> GetTimerImplementation() override; - void *CudaContextHack() override; + void *GpuContextHack() override; CudaContext* cuda_context(); diff --git a/tensorflow/stream_executor/cuda/cuda_stream.h b/tensorflow/stream_executor/cuda/cuda_stream.h index 02edff6431..bb8bda4755 100644 --- a/tensorflow/stream_executor/cuda/cuda_stream.h +++ b/tensorflow/stream_executor/cuda/cuda_stream.h @@ -40,8 +40,8 @@ class CUDAStream : public internal::StreamInterface { // Note: teardown is handled by a parent's call to DeallocateStream. ~CUDAStream() override {} - void *CudaStreamHack() override { return cuda_stream_; } - void **CudaStreamMemberHack() override { + void *GpuStreamHack() override { return cuda_stream_; } + void **GpuStreamMemberHack() override { return reinterpret_cast<void **>(&cuda_stream_); } diff --git a/tensorflow/stream_executor/host/host_gpu_executor.h b/tensorflow/stream_executor/host/host_gpu_executor.h index e82f57569f..858396ef96 100644 --- a/tensorflow/stream_executor/host/host_gpu_executor.h +++ b/tensorflow/stream_executor/host/host_gpu_executor.h @@ -202,7 +202,7 @@ class HostExecutor : public internal::StreamExecutorInterface { return std::unique_ptr<internal::TimerInterface>(new HostTimer()); } - void *CudaContextHack() override { return nullptr; } + void *GpuContextHack() override { return nullptr; } private: const PluginConfig plugin_config_; diff --git a/tensorflow/stream_executor/host/host_stream.h b/tensorflow/stream_executor/host/host_stream.h index 5d7b8a3782..be88f074cf 100644 --- a/tensorflow/stream_executor/host/host_stream.h +++ b/tensorflow/stream_executor/host/host_stream.h @@ -34,8 +34,8 @@ class HostStream : public internal::StreamInterface { bool EnqueueTask(std::function<void()> task); - void *CudaStreamHack() override { return nullptr; } - void **CudaStreamMemberHack() override { return nullptr; } + void *GpuStreamHack() override { return nullptr; } + void **GpuStreamMemberHack() override { return nullptr; } void BlockUntilDone(); diff --git a/tensorflow/stream_executor/stream_executor_internal.h b/tensorflow/stream_executor/stream_executor_internal.h index 9c989b971d..fb1b92cb84 100644 --- a/tensorflow/stream_executor/stream_executor_internal.h +++ b/tensorflow/stream_executor/stream_executor_internal.h @@ -100,19 +100,20 @@ class StreamInterface { // Default destructor for the abstract interface. virtual ~StreamInterface() {} - // Returns the CUDA stream associated with this platform's stream + // Returns the GPU stream associated with this platform's stream // implementation. // - // WARNING: checks that the underlying platform is, in fact, CUDA, causing a - // fatal error if it is not. This hack is made available solely for use from - // distbelief code, which temporarily has strong ties to CUDA as a platform. - virtual void *CudaStreamHack() { return nullptr; } - - // See the above comment on CudaStreamHack -- this further breaks abstraction - // for Eigen within distbelief, which has strong ties to CUDA as a platform, - // and a historical attachment to a programming model which takes a + // WARNING: checks that the underlying platform is, in fact, CUDA or ROCm, + // causing a fatal error if it is not. This hack is made available solely for + // use from distbelief code, which temporarily has strong ties to CUDA or + // ROCm as a platform. + virtual void *GpuStreamHack() { return nullptr; } + + // See the above comment on GpuStreamHack -- this further breaks abstraction + // for Eigen within distbelief, which has strong ties to CUDA or ROCm as a + // platform, and a historical attachment to a programming model which takes a // stream-slot rather than a stream-value. - virtual void **CudaStreamMemberHack() { return nullptr; } + virtual void **GpuStreamMemberHack() { return nullptr; } private: SE_DISALLOW_COPY_AND_ASSIGN(StreamInterface); @@ -324,13 +325,14 @@ class StreamExecutorInterface { virtual std::unique_ptr<StreamInterface> GetStreamImplementation() = 0; virtual std::unique_ptr<TimerInterface> GetTimerImplementation() = 0; - // Returns the CUDA context associated with this StreamExecutor platform - // implementation. + // Returns the CUDA or ROCm context associated with this StreamExecutor + // platform implementation. // - // WARNING: checks that the underlying platform is, in fact, CUDA, causing a - // fatal error if it is not. This hack is made available solely for use from - // distbelief code, which temporarily has strong ties to CUDA as a platform. - virtual void *CudaContextHack() { return nullptr; } + // WARNING: checks that the underlying platform is, in fact, CUDA or ROCm, + // causing a fatal error if it is not. This hack is made available solely for + // use from distbelief code, which temporarily has strong ties to CUDA or ROCm + // as a platform. + virtual void *GpuContextHack() { return nullptr; } private: SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutorInterface); |