aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework
diff options
context:
space:
mode:
authorGravatar Gunhan Gulsoy <gunan@google.com>2018-09-18 17:34:53 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-18 17:38:08 -0700
commit38d8f893e0ab8376cf97c40fde78002f31776c92 (patch)
tree0cb850776e2f7564ad7fb53dd2e5816136ee0ba3 /tensorflow/python/framework
parent867449616aa43f9306247cebdd1edac85b70852a (diff)
Add a new function to load kernel libraries and library folders.
PiperOrigin-RevId: 213549838
Diffstat (limited to 'tensorflow/python/framework')
-rw-r--r--tensorflow/python/framework/load_library.py65
1 files changed, 65 insertions, 0 deletions
diff --git a/tensorflow/python/framework/load_library.py b/tensorflow/python/framework/load_library.py
index 535c6017f5..908a5f521e 100644
--- a/tensorflow/python/framework/load_library.py
+++ b/tensorflow/python/framework/load_library.py
@@ -18,14 +18,18 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import errno
import hashlib
import imp
+import os
+import platform
import sys
import threading # pylint: disable=unused-import
from tensorflow.core.framework import op_def_pb2
from tensorflow.core.lib.core import error_codes_pb2 # pylint: disable=unused-import
from tensorflow.python import pywrap_tensorflow as py_tf
+from tensorflow.python.lib.io import file_io
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export
@@ -98,3 +102,64 @@ def load_file_system_library(library_filename):
RuntimeError: when unable to load the library.
"""
py_tf.TF_LoadLibrary(library_filename)
+
+
+def _is_shared_object(filename):
+ """Check the file to see if it is a shared object, only using extension."""
+ if platform.system() == 'Linux':
+ if filename.endswith('.so'):
+ return True
+ else:
+ index = filename.rfind('.so.')
+ if index == -1:
+ return False
+ else:
+ # A shared object with the API version in filename
+ return filename[index + 4].isdecimal()
+ elif platform.system() == 'Darwin':
+ return filename.endswith('.dylib')
+ elif platform.system() == 'Windows':
+ return filename.endswith('.dll')
+ else:
+ return False
+
+
+@tf_export('load_library')
+def load_library(library_location):
+ """Loads a TensorFlow plugin.
+
+ "library_location" can be a path to a specific shared object, or a folder.
+ If it is a folder, all sahred objects that are named "libtfkernel*" will be
+ loaded. When the library is loaded, kernels registered in the library via the
+ `REGISTER_*` macros are made available in the TensorFlow process.
+
+ Args:
+ library_location: Path to the plugin or the folder of plugins.
+ Relative or absolute filesystem path to a dynamic library file or folder.
+
+ Returns:
+ None
+
+ Raises:
+ OSError: When the file to be loaded is not found.
+ RuntimeError: when unable to load the library.
+ """
+ if file_io.file_exists(library_location):
+ if file_io.is_directory(library_location):
+ directory_contents = file_io.list_directory(library_location)
+
+ kernel_libraries = [
+ os.path.join(library_location, f) for f in directory_contents
+ if _is_shared_object(f)]
+ else:
+ kernel_libraries = [library_location]
+
+ for lib in kernel_libraries:
+ py_tf.TF_LoadLibrary(lib)
+
+ else:
+ raise OSError(
+ errno.ENOENT,
+ 'The file or folder to load kernel libraries from does not exist.',
+ library_location)
+