diff options
author | Russell Power <power@google.com> | 2018-07-26 08:55:08 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-26 08:57:55 -0700 |
commit | deac85da170542596ba4d1a72ef5e63c0a398aba (patch) | |
tree | 3181c75e0f3068934029d75a7b86936f190d1ce8 /tensorflow/python/client | |
parent | 8786b41d67241331ce0aa45c3df5d121039d5159 (diff) |
Automated rollback of commit b8a9d163d9cbb4b581c044d9c4b1b256c801a9c4
PiperOrigin-RevId: 206166233
Diffstat (limited to 'tensorflow/python/client')
-rw-r--r-- | tensorflow/python/client/session.py | 2 | ||||
-rw-r--r-- | tensorflow/python/client/tf_session.i | 1 | ||||
-rw-r--r-- | tensorflow/python/client/tf_session_helper.cc | 14 | ||||
-rw-r--r-- | tensorflow/python/client/tf_session_helper.h | 3 |
4 files changed, 19 insertions, 1 deletions
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index 180bb74d00..861230e5a0 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -630,7 +630,7 @@ class BaseSession(SessionInterface): opts = tf_session.TF_NewSessionOptions(target=self._target, config=config) try: # pylint: disable=protected-access - self._session = tf_session.TF_NewSession(self._graph._c_graph, opts) + self._session = tf_session.TF_NewSessionRef(self._graph._c_graph, opts) # pylint: enable=protected-access finally: tf_session.TF_DeleteSessionOptions(opts) diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i index 1cdd8e0b6a..39a2922ac0 100644 --- a/tensorflow/python/client/tf_session.i +++ b/tensorflow/python/client/tf_session.i @@ -777,6 +777,7 @@ def TF_Reset(target, containers=None, config=None): $1 = &types_local; } +%unignore TF_NewSessionRef; %unignore SetRequireShapeInferenceFns; %unignore TF_TryEvaluateConstant_wrapper; %noexception TF_TryEvaluateConstant_wrapper; diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc index b6481e7e29..bcd4af2912 100644 --- a/tensorflow/python/client/tf_session_helper.cc +++ b/tensorflow/python/client/tf_session_helper.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/common_runtime/session_ref.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value_util.h" @@ -42,6 +43,19 @@ static const char* kFeedDictErrorMsg = "feed_dict must be a dictionary mapping strings to NumPy arrays."; } // end namespace +TF_Session* TF_NewSessionRef(TF_Graph* graph, const TF_SessionOptions* opts, + TF_Status* status) { + TF_Session* tf_session = TF_NewSession(graph, opts, status); + if (tf_session == nullptr) { + return nullptr; + } + + Session* session = reinterpret_cast<Session*>(tf_session->session); + SessionRef* session_ref = new SessionRef(session); + tf_session->session = session_ref; + return tf_session; +} + void TF_Run_wrapper_helper(TF_DeprecatedSession* session, const char* handle, const TF_Buffer* run_options, PyObject* feed_dict, const NameVector& output_names, diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h index cfd27c2bee..dab7e71aac 100644 --- a/tensorflow/python/client/tf_session_helper.h +++ b/tensorflow/python/client/tf_session_helper.h @@ -40,6 +40,9 @@ typedef tensorflow::gtl::InlinedVector<PyObject*, 8> PyObjectVector; // A TF_TensorVector is a vector of borrowed pointers to TF_Tensors. typedef gtl::InlinedVector<TF_Tensor*, 8> TF_TensorVector; +TF_Session* TF_NewSessionRef(TF_Graph* graph, const TF_SessionOptions* opts, + TF_Status* status); + // Run the graph associated with the session starting with the // supplied inputs[]. Regardless of success or failure, inputs[] are // stolen by the implementation (i.e. the implementation will |