aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/BUILD11
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device.cc51
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device.h11
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device_kernel_check.cu.cc37
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device_kernel_check.h32
5 files changed, 1 insertions, 141 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index a960736295..84555b60da 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -90,7 +90,6 @@ load(
"tf_genrule_cmd_append_to_srcs",
"tf_opts_nortti_if_android",
"tf_features_nomodules_if_android",
- "tf_gpu_kernel_library",
)
load("//tensorflow:tensorflow.bzl", "tf_cc_test_mkl")
load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu")
@@ -2950,15 +2949,6 @@ cc_library(
],
)
-tf_gpu_kernel_library(
- name = "gpu_device_kernel_check",
- srcs = ["common_runtime/gpu/gpu_device_kernel_check.cu.cc"],
- hdrs = ["common_runtime/gpu/gpu_device_kernel_check.h"],
- deps = [
- "//tensorflow/core:stream_executor",
- ],
-)
-
GPU_RUNTIME_HEADERS = [
"common_runtime/gpu/cuda_host_allocator.h",
"common_runtime/gpu/gpu_bfc_allocator.h",
@@ -2997,7 +2987,6 @@ tf_cuda_library(
":core_cpu_lib",
":framework",
":framework_internal",
- ":gpu_device_kernel_check",
":gpu_id_impl",
":gpu_init_impl",
":gpu_lib",
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index fbe158c777..3292ef2f62 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -31,7 +31,6 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/common_runtime/device_factory.h"
-#include "tensorflow/core/common_runtime/gpu/gpu_device_kernel_check.h"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
@@ -378,7 +377,7 @@ Status BaseGPUDevice::Init(const SessionOptions& options) {
}
}
- return CheckGPU();
+ return Status::OK();
}
bool BaseGPUDevice::RequiresRecordingAccessedTensors() const {
@@ -895,54 +894,6 @@ Allocator* BaseGPUDevice::GetScopedAllocator(AllocatorAttributes attr,
return gpu_allocator_;
}
-Status BaseGPUDevice::CheckGPU() {
- se::Stream* stream = tensorflow_gpu_device_info()->stream;
- TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
- Tensor device_tensor(gpu_allocator_, DT_FLOAT, {});
- if (!device_tensor.IsInitialized()) {
- return errors::ResourceExhausted("Failed to allocate ", sizeof(float),
- " bytes on the GPU for initialization "
- "checks");
- }
- float* val_dev = device_tensor.scalar<float>().data();
- const cudaStream_t cu_stream = *reinterpret_cast<const cudaStream_t*>(
- stream->implementation()->GpuStreamMemberHack());
- {
- se::cuda::ScopedActivateExecutorContext scoped_activation{stream->parent()};
- run_test_kernel(val_dev, cu_stream);
- // We have to use the CUDA runtime function cudaPeekAtLastError here,
- // because 'stream' does not provide a way to check if a kernel launch
- // succeeds. Calling 'stream->BlockHostUntilDone()', which internally calls
- // 'cuCtxSynchronize()', does not catch all kernel launch errors.
- cudaError_t cuda_error = cudaPeekAtLastError();
- if (cuda_error == cudaSuccess) {
- cuda_error = cudaDeviceSynchronize();
- }
- TF_RETURN_IF_ERROR(CudaErrorToStatus(cuda_error, *stream));
- }
-
- float val_host = 0.;
- stream->ThenMemcpy(&val_host, se::DeviceMemoryBase(val_dev, sizeof(float)),
- sizeof(float));
- TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
- if (val_host != 12345.) {
- return errors::Internal(
- "GPU kernel for initialization returned wrong value: ", val_host);
- }
- return Status::OK();
-}
-
-Status BaseGPUDevice::CudaErrorToStatus(cudaError_t cuda_error,
- const se::Stream& stream) {
- if (cuda_error != cudaSuccess) {
- return errors::Internal(
- "Failed to run GPU kernel for the initialization check. Received "
- "error ",
- cudaGetErrorName(cuda_error), " after running GPU kernel.");
- }
- return Status::OK();
-}
-
const int BaseGPUDeviceFactory::InterconnectMap::kSameDeviceStrength = 1000;
const int BaseGPUDeviceFactory::InterconnectMap::kStreamExecutorStrength = 1;
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.h b/tensorflow/core/common_runtime/gpu/gpu_device.h
index d02901a7ae..56d03d7a8c 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.h
@@ -26,7 +26,6 @@ limitations under the License.
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "cuda/include/cuda_runtime_api.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
@@ -116,12 +115,6 @@ class BaseGPUDevice : public LocalDevice {
se::StreamExecutor* executor_; // not owned
std::unique_ptr<ScopedAllocatorMgr> scoped_allocator_mgr_;
- // Returns a Status corresponding to a cudaError_t. The CUDA error must have
- // been obtained from a CUDA kernel launch used to check if the GPU is
- // initialized properly.
- virtual Status CudaErrorToStatus(cudaError_t cuda_error,
- const se::Stream& stream);
-
private:
struct StreamGroup {
se::Stream* compute = nullptr;
@@ -158,10 +151,6 @@ class BaseGPUDevice : public LocalDevice {
Status MaybeCopyTensorToGPU(const AllocatorAttributes& alloc_attrs,
const Tensor& from, Tensor* to,
StatusCallback done);
-
- // Checks that the GPU is capable of doing work, by running a test kernel on
- // it.
- Status CheckGPU();
};
class BaseGPUDeviceFactory : public DeviceFactory {
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_kernel_check.cu.cc b/tensorflow/core/common_runtime/gpu/gpu_device_kernel_check.cu.cc
deleted file mode 100644
index 017565195b..0000000000
--- a/tensorflow/core/common_runtime/gpu/gpu_device_kernel_check.cu.cc
+++ /dev/null
@@ -1,37 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#if GOOGLE_CUDA
-
-#include "tensorflow/core/common_runtime/gpu/gpu_device_kernel_check.h"
-#include "tensorflow/stream_executor/cuda/cuda_activation.h"
-
-namespace {
-__global__ void test_kernel(float* val) {
- if (blockIdx.x == 0 && threadIdx.x == 0) {
- (*val) = 12345.;
- }
-}
-} // namespace
-
-namespace tensorflow {
-
-void run_test_kernel(float* val, cudaStream_t cu_stream) {
- test_kernel<<<1, 1, 0, cu_stream>>>(val);
-}
-
-} // namespace tensorflow
-
-#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_kernel_check.h b/tensorflow/core/common_runtime/gpu/gpu_device_kernel_check.h
deleted file mode 100644
index 064fb7a49f..0000000000
--- a/tensorflow/core/common_runtime/gpu/gpu_device_kernel_check.h
+++ /dev/null
@@ -1,32 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_DEVICE_KERNEL_CHECK_H_
-#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_DEVICE_KERNEL_CHECK_H_
-
-#if GOOGLE_CUDA
-
-#include "tensorflow/core/platform/stream_executor.h"
-
-namespace tensorflow {
-
-// Runs a GPU kernel to test that it functions correctly. Sets 'val' to 12345.
-void run_test_kernel(float* val, cudaStream_t cu_stream);
-
-} // namespace tensorflow
-
-#endif // GOOGLE_CUDA
-
-#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_DEVICE_KERNEL_CHECK_H_