aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Reed Wanderman-Milne <reedwm@google.com>2018-07-23 15:01:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-23 15:25:34 -0700
commit09c4c387913c86247121589caa7fb2e85351fa58 (patch)
treecb73c7ca6c0f91906d832161643526fb8e121b60
parentf8e8c0c6f7746d3f2b5820e76c9e382149090034 (diff)
Add check at GPU initialization to see if GPU kernels can be run.
PiperOrigin-RevId: 205730535
-rw-r--r--tensorflow/core/BUILD11
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device.cc51
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device.h11
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device_kernel_check.cu.cc37
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device_kernel_check.h32
5 files changed, 141 insertions, 1 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index b6a990ac7d..13e1b643d1 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -90,6 +90,7 @@ load(
"tf_genrule_cmd_append_to_srcs",
"tf_opts_nortti_if_android",
"tf_features_nomodules_if_android",
+ "tf_gpu_kernel_library",
)
load("//tensorflow:tensorflow.bzl", "tf_cc_test_mkl")
load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu")
@@ -2948,6 +2949,15 @@ cc_library(
],
)
+tf_gpu_kernel_library(
+ name = "gpu_device_kernel_check",
+ srcs = ["common_runtime/gpu/gpu_device_kernel_check.cu.cc"],
+ hdrs = ["common_runtime/gpu/gpu_device_kernel_check.h"],
+ deps = [
+ "//tensorflow/core:stream_executor",
+ ],
+)
+
GPU_RUNTIME_HEADERS = [
"common_runtime/gpu/cuda_host_allocator.h",
"common_runtime/gpu/gpu_bfc_allocator.h",
@@ -2986,6 +2996,7 @@ tf_cuda_library(
":core_cpu_lib",
":framework",
":framework_internal",
+ ":gpu_device_kernel_check",
":gpu_id_impl",
":gpu_init_impl",
":gpu_lib",
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index 3292ef2f62..fbe158c777 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -31,6 +31,7 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_device_kernel_check.h"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
@@ -377,7 +378,7 @@ Status BaseGPUDevice::Init(const SessionOptions& options) {
}
}
- return Status::OK();
+ return CheckGPU();
}
bool BaseGPUDevice::RequiresRecordingAccessedTensors() const {
@@ -894,6 +895,54 @@ Allocator* BaseGPUDevice::GetScopedAllocator(AllocatorAttributes attr,
return gpu_allocator_;
}
+Status BaseGPUDevice::CheckGPU() {
+ se::Stream* stream = tensorflow_gpu_device_info()->stream;
+ TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
+ Tensor device_tensor(gpu_allocator_, DT_FLOAT, {});
+ if (!device_tensor.IsInitialized()) {
+ return errors::ResourceExhausted("Failed to allocate ", sizeof(float),
+ " bytes on the GPU for initialization "
+ "checks");
+ }
+ float* val_dev = device_tensor.scalar<float>().data();
+ const cudaStream_t cu_stream = *reinterpret_cast<const cudaStream_t*>(
+ stream->implementation()->GpuStreamMemberHack());
+ {
+ se::cuda::ScopedActivateExecutorContext scoped_activation{stream->parent()};
+ run_test_kernel(val_dev, cu_stream);
+ // We have to use the CUDA runtime function cudaPeekAtLastError here,
+ // because 'stream' does not provide a way to check if a kernel launch
+ // succeeds. Calling 'stream->BlockHostUntilDone()', which internally calls
+ // 'cuCtxSynchronize()', does not catch all kernel launch errors.
+ cudaError_t cuda_error = cudaPeekAtLastError();
+ if (cuda_error == cudaSuccess) {
+ cuda_error = cudaDeviceSynchronize();
+ }
+ TF_RETURN_IF_ERROR(CudaErrorToStatus(cuda_error, *stream));
+ }
+
+ float val_host = 0.;
+ stream->ThenMemcpy(&val_host, se::DeviceMemoryBase(val_dev, sizeof(float)),
+ sizeof(float));
+ TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
+ if (val_host != 12345.) {
+ return errors::Internal(
+ "GPU kernel for initialization returned wrong value: ", val_host);
+ }
+ return Status::OK();
+}
+
+Status BaseGPUDevice::CudaErrorToStatus(cudaError_t cuda_error,
+ const se::Stream& stream) {
+ if (cuda_error != cudaSuccess) {
+ return errors::Internal(
+ "Failed to run GPU kernel for the initialization check. Received "
+ "error ",
+ cudaGetErrorName(cuda_error), " after running GPU kernel.");
+ }
+ return Status::OK();
+}
+
const int BaseGPUDeviceFactory::InterconnectMap::kSameDeviceStrength = 1000;
const int BaseGPUDeviceFactory::InterconnectMap::kStreamExecutorStrength = 1;
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.h b/tensorflow/core/common_runtime/gpu/gpu_device.h
index 56d03d7a8c..d02901a7ae 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.h
@@ -26,6 +26,7 @@ limitations under the License.
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "cuda/include/cuda_runtime_api.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
@@ -115,6 +116,12 @@ class BaseGPUDevice : public LocalDevice {
se::StreamExecutor* executor_; // not owned
std::unique_ptr<ScopedAllocatorMgr> scoped_allocator_mgr_;
+ // Returns a Status corresponding to a cudaError_t. The CUDA error must have
+ // been obtained from a CUDA kernel launch used to check if the GPU is
+ // initialized properly.
+ virtual Status CudaErrorToStatus(cudaError_t cuda_error,
+ const se::Stream& stream);
+
private:
struct StreamGroup {
se::Stream* compute = nullptr;
@@ -151,6 +158,10 @@ class BaseGPUDevice : public LocalDevice {
Status MaybeCopyTensorToGPU(const AllocatorAttributes& alloc_attrs,
const Tensor& from, Tensor* to,
StatusCallback done);
+
+ // Checks that the GPU is capable of doing work, by running a test kernel on
+ // it.
+ Status CheckGPU();
};
class BaseGPUDeviceFactory : public DeviceFactory {
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_kernel_check.cu.cc b/tensorflow/core/common_runtime/gpu/gpu_device_kernel_check.cu.cc
new file mode 100644
index 0000000000..017565195b
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_device_kernel_check.cu.cc
@@ -0,0 +1,37 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/common_runtime/gpu/gpu_device_kernel_check.h"
+#include "tensorflow/stream_executor/cuda/cuda_activation.h"
+
+namespace {
+__global__ void test_kernel(float* val) {
+ if (blockIdx.x == 0 && threadIdx.x == 0) {
+ (*val) = 12345.;
+ }
+}
+} // namespace
+
+namespace tensorflow {
+
+void run_test_kernel(float* val, cudaStream_t cu_stream) {
+ test_kernel<<<1, 1, 0, cu_stream>>>(val);
+}
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_kernel_check.h b/tensorflow/core/common_runtime/gpu/gpu_device_kernel_check.h
new file mode 100644
index 0000000000..064fb7a49f
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_device_kernel_check.h
@@ -0,0 +1,32 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_DEVICE_KERNEL_CHECK_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_DEVICE_KERNEL_CHECK_H_
+
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/platform/stream_executor.h"
+
+namespace tensorflow {
+
+// Runs a GPU kernel to test that it functions correctly. Sets 'val' to 12345.
+void run_test_kernel(float* val, cudaStream_t cu_stream);
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
+
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_DEVICE_KERNEL_CHECK_H_