aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/platform
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2017-10-02 13:38:53 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-02 13:43:29 -0700
commita8444b7c19d971e3f109adf4f1295f37d439af6c (patch)
treed4f3d3318ad6b759ccde4eb5511ffbdd5c967e28 /tensorflow/python/platform
parent7b098f62f983738bbf048873b6ecac3b26d40d68 (diff)
[Windows] Improve import self-check with tests for GPU-related DLLs.
This change incorporates the full logic of the [Windows self-check script](https://gist.github.com/mrry/ee5dbcfdd045fa48a27d56664411d41c) into core TensorFlow. Fixes #9170. PiperOrigin-RevId: 170746452
Diffstat (limited to 'tensorflow/python/platform')
-rw-r--r--tensorflow/python/platform/self_check.py68
1 files changed, 57 insertions, 11 deletions
diff --git a/tensorflow/python/platform/self_check.py b/tensorflow/python/platform/self_check.py
index 0a8fc07901..39d38d7bbc 100644
--- a/tensorflow/python/platform/self_check.py
+++ b/tensorflow/python/platform/self_check.py
@@ -21,6 +21,9 @@ from __future__ import print_function
import os
+from tensorflow.python.platform import build_info
+
+
def preload_check():
"""Raises an exception if the environment is not correctly configured.
@@ -33,17 +36,60 @@ def preload_check():
# we load the Python extension, so that we can raise an actionable error
# message if they are not found.
import ctypes # pylint: disable=g-import-not-at-top
- try:
- ctypes.WinDLL("msvcp140.dll")
- except OSError:
- raise ImportError(
- "Could not find 'msvcp140.dll'. TensorFlow requires that this DLL be "
- "installed in a directory that is named in your %PATH% environment "
- "variable. You may install this DLL by downloading Visual C++ 2015 "
- "Redistributable Update 3 from this URL: "
- "https://www.microsoft.com/en-us/download/details.aspx?id=53587")
- # TODO(mrry): Add specific checks for GPU DLLs if build_info indicates
- # that this is a GPU build.
+ if hasattr(build_info, "msvcp_dll_name"):
+ try:
+ ctypes.WinDLL(build_info.msvcp_dll_name)
+ except OSError:
+ raise ImportError(
+ "Could not find %r. TensorFlow requires that this DLL be "
+ "installed in a directory that is named in your %%PATH%% "
+ "environment variable. You may install this DLL by downloading "
+ "Visual C++ 2015 Redistributable Update 3 from this URL: "
+ "https://www.microsoft.com/en-us/download/details.aspx?id=53587"
+ % build_info.msvcp_dll_name)
+
+ if build_info.is_cuda_build:
+ # Attempt to check that the necessary CUDA DLLs are loadable.
+
+ if hasattr(build_info, "nvcuda_dll_name"):
+ try:
+ ctypes.WinDLL(build_info.nvcuda_dll_name)
+ except OSError:
+ raise ImportError(
+ "Could not find %r. TensorFlow requires that this DLL "
+ "be installed in a directory that is named in your %%PATH%% "
+ "environment variable. Typically it is installed in "
+ "'C:\\Windows\\System32'. If it is not present, ensure that you "
+ "have a CUDA-capable GPU with the correct driver installed."
+ % build_info.nvcuda_dll_name)
+
+ if hasattr(build_info, "cudart_dll_name") and hasattr(
+ build_info, "cuda_version_number"):
+ try:
+ ctypes.WinDLL(build_info.cudart_dll_name)
+ except OSError:
+ raise ImportError(
+ "Could not find %r. TensorFlow requires that this DLL be "
+ "installed in a directory that is named in your %%PATH%% "
+ "environment variable. Download and install CUDA %s from "
+ "this URL: https://developer.nvidia.com/cuda-toolkit"
+ % (build_info.cudart_dll_name, build_info.cuda_version_number))
+
+ if hasattr(build_info, "cudnn_dll_name") and hasattr(
+ build_info, "cudnn_version_number"):
+ try:
+ ctypes.WinDLL(build_info.cudnn_dll_name)
+ except OSError:
+ raise ImportError(
+ "Could not find %r. TensorFlow requires that this DLL be "
+ "installed in a directory that is named in your %%PATH%% "
+ "environment variable. Note that installing cuDNN is a separate "
+ "step from installing CUDA, and this DLL is often found in a "
+ "different directory from the CUDA DLLs. You may install the "
+ "necessary DLL by downloading cuDNN %s from this URL: "
+ "https://developer.nvidia.com/cudnn"
+ % (build_info.cudnn_dll_name, build_info.cudnn_version_number))
+
else:
# TODO(mrry): Consider adding checks for the Linux and Mac OS X builds.
pass