diff options
Diffstat (limited to 'tensorflow/contrib/util/loader.py')
-rw-r--r-- | tensorflow/contrib/util/loader.py | 27 |
1 files changed, 17 insertions, 10 deletions
diff --git a/tensorflow/contrib/util/loader.py b/tensorflow/contrib/util/loader.py index 95657217a0..c2ae425b56 100644 --- a/tensorflow/contrib/util/loader.py +++ b/tensorflow/contrib/util/loader.py @@ -21,6 +21,7 @@ from __future__ import division from __future__ import print_function import os +import re from tensorflow.python.framework import load_library from tensorflow.python.platform import resource_loader @@ -29,9 +30,9 @@ from tensorflow.python.platform import resource_loader def load_op_library(path): """Loads a contrib op library from the given path. - NOTE(mrry): On Windows, we currently assume that contrib op + NOTE(mrry): On Windows, we currently assume that some contrib op libraries are statically linked into the main TensorFlow Python - extension DLL. + extension DLL - use dynamically linked ops if the .so is present. Args: path: An absolute path to a shared object file. @@ -40,11 +41,17 @@ def load_op_library(path): A Python module containing the Python wrappers for Ops defined in the plugin. """ - if os.name != 'nt': - path = resource_loader.get_path_to_datafile(path) - ret = load_library.load_op_library(path) - assert ret, 'Could not load %s' % path - return ret - else: - # NOTE(mrry): - return None + if os.name == 'nt': + # To avoid makeing every user_ops aware of windows, re-write + # the file extension from .so to .dll. + path = re.sub('\.so$', '.dll', path) + + # TODO: currently we have only some user_ops as .dll's on windows - don't try + # to load them if the dll is not found. Once we have all of them + # this check should be removed. + if not os.path.exists(path): + return None + path = resource_loader.get_path_to_datafile(path) + ret = load_library.load_op_library(path) + assert ret, 'Could not load %s' % path + return ret |