aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/pywrap_tensorflow.py
diff options
context:
space:
mode:
authorGravatar Jonathan Hseu <jhseu@google.com>2017-02-28 18:36:23 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-28 18:51:06 -0800
commit718812c9e4df55b8b3275aa4db7bb6833ed03111 (patch)
treee8bc57fe5bfaea125d4d4c4535e29a7ed2306057 /tensorflow/python/pywrap_tensorflow.py
parent4e63540076921d2c08d03aa9efb76fd483920593 (diff)
Fix the dlopen contrib test hack by making a pywrap_tensorflow module that imports
pywrap_tensorflow_internal with RTLD_GLOBAL. Fixes #6568 Change: 148843302
Diffstat (limited to 'tensorflow/python/pywrap_tensorflow.py')
-rw-r--r--tensorflow/python/pywrap_tensorflow.py54
1 files changed, 54 insertions, 0 deletions
diff --git a/tensorflow/python/pywrap_tensorflow.py b/tensorflow/python/pywrap_tensorflow.py
new file mode 100644
index 0000000000..f116081752
--- /dev/null
+++ b/tensorflow/python/pywrap_tensorflow.py
@@ -0,0 +1,54 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""pywrap_tensorflow wrapper that exports all symbols with RTLD_GLOBAL."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ctypes
+import sys
+import traceback
+
+# pylint: disable=wildcard-import,g-import-not-at-top,unused-import,line-too-long
+
+# On UNIX-based platforms, pywrap_tensorflow is a SWIG-generated
+# python library that dynamically loads _pywrap_tensorflow.so. The
+# default mode for loading keeps all the symbol private and not
+# visible to other libraries that may be loaded. Setting the mode to
+# RTLD_GLOBAL to make the symbols visible, so that custom op libraries
+# imported using `tf.load_op_library()` can access symbols defined in
+# _pywrap_tensorflow.so.
+try:
+ # TODO(keveman,mrry): Support dynamic op loading on platforms that do not
+ # use `dlopen()` for dynamic loading.
+ _use_rtld_global = hasattr(sys, 'getdlopenflags') and hasattr(sys, 'setdlopenflags')
+ if _use_rtld_global:
+ _default_dlopen_flags = sys.getdlopenflags()
+ sys.setdlopenflags(_default_dlopen_flags | ctypes.RTLD_GLOBAL)
+ from tensorflow.python.pywrap_tensorflow_internal import *
+ from tensorflow.python.pywrap_tensorflow_internal import __version__
+ from tensorflow.python.pywrap_tensorflow_internal import __git_version__
+ from tensorflow.python.pywrap_tensorflow_internal import __compiler_version__
+ if _use_rtld_global:
+ sys.setdlopenflags(_default_dlopen_flags)
+except ImportError:
+ msg = """%s\n\nFailed to load the native TensorFlow runtime.\n
+See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/g3doc/get_started/os_setup.md#import_error\n
+for some common reasons and solutions. Include the entire stack trace
+above this error message when asking for help.""" % traceback.format_exc()
+ raise ImportError(msg)
+
+# pylint: enable=wildcard-import,g-import-not-at-top,unused-import,line-too-long