diff options
author | 2018-09-18 17:34:53 -0700 | |
---|---|---|
committer | 2018-09-18 17:38:08 -0700 | |
commit | 38d8f893e0ab8376cf97c40fde78002f31776c92 (patch) | |
tree | 0cb850776e2f7564ad7fb53dd2e5816136ee0ba3 /tensorflow/python/framework | |
parent | 867449616aa43f9306247cebdd1edac85b70852a (diff) |
Add a new function to load kernel libraries and library folders.
PiperOrigin-RevId: 213549838
Diffstat (limited to 'tensorflow/python/framework')
-rw-r--r-- | tensorflow/python/framework/load_library.py | 65 |
1 files changed, 65 insertions, 0 deletions
diff --git a/tensorflow/python/framework/load_library.py b/tensorflow/python/framework/load_library.py index 535c6017f5..908a5f521e 100644 --- a/tensorflow/python/framework/load_library.py +++ b/tensorflow/python/framework/load_library.py @@ -18,14 +18,18 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import errno import hashlib import imp +import os +import platform import sys import threading # pylint: disable=unused-import from tensorflow.core.framework import op_def_pb2 from tensorflow.core.lib.core import error_codes_pb2 # pylint: disable=unused-import from tensorflow.python import pywrap_tensorflow as py_tf +from tensorflow.python.lib.io import file_io from tensorflow.python.util import compat from tensorflow.python.util.tf_export import tf_export @@ -98,3 +102,64 @@ def load_file_system_library(library_filename): RuntimeError: when unable to load the library. """ py_tf.TF_LoadLibrary(library_filename) + + +def _is_shared_object(filename): + """Check the file to see if it is a shared object, only using extension.""" + if platform.system() == 'Linux': + if filename.endswith('.so'): + return True + else: + index = filename.rfind('.so.') + if index == -1: + return False + else: + # A shared object with the API version in filename + return filename[index + 4].isdecimal() + elif platform.system() == 'Darwin': + return filename.endswith('.dylib') + elif platform.system() == 'Windows': + return filename.endswith('.dll') + else: + return False + + +@tf_export('load_library') +def load_library(library_location): + """Loads a TensorFlow plugin. + + "library_location" can be a path to a specific shared object, or a folder. + If it is a folder, all sahred objects that are named "libtfkernel*" will be + loaded. When the library is loaded, kernels registered in the library via the + `REGISTER_*` macros are made available in the TensorFlow process. + + Args: + library_location: Path to the plugin or the folder of plugins. + Relative or absolute filesystem path to a dynamic library file or folder. + + Returns: + None + + Raises: + OSError: When the file to be loaded is not found. + RuntimeError: when unable to load the library. + """ + if file_io.file_exists(library_location): + if file_io.is_directory(library_location): + directory_contents = file_io.list_directory(library_location) + + kernel_libraries = [ + os.path.join(library_location, f) for f in directory_contents + if _is_shared_object(f)] + else: + kernel_libraries = [library_location] + + for lib in kernel_libraries: + py_tf.TF_LoadLibrary(lib) + + else: + raise OSError( + errno.ENOENT, + 'The file or folder to load kernel libraries from does not exist.', + library_location) + |