aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/client
diff options
context:
space:
mode:
authorGravatar Russell Power <power@google.com>2018-07-26 08:55:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-26 08:57:55 -0700
commitdeac85da170542596ba4d1a72ef5e63c0a398aba (patch)
tree3181c75e0f3068934029d75a7b86936f190d1ce8 /tensorflow/python/client
parent8786b41d67241331ce0aa45c3df5d121039d5159 (diff)
Automated rollback of commit b8a9d163d9cbb4b581c044d9c4b1b256c801a9c4
PiperOrigin-RevId: 206166233
Diffstat (limited to 'tensorflow/python/client')
-rw-r--r--tensorflow/python/client/session.py2
-rw-r--r--tensorflow/python/client/tf_session.i1
-rw-r--r--tensorflow/python/client/tf_session_helper.cc14
-rw-r--r--tensorflow/python/client/tf_session_helper.h3
4 files changed, 19 insertions, 1 deletions
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index 180bb74d00..861230e5a0 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -630,7 +630,7 @@ class BaseSession(SessionInterface):
opts = tf_session.TF_NewSessionOptions(target=self._target, config=config)
try:
# pylint: disable=protected-access
- self._session = tf_session.TF_NewSession(self._graph._c_graph, opts)
+ self._session = tf_session.TF_NewSessionRef(self._graph._c_graph, opts)
# pylint: enable=protected-access
finally:
tf_session.TF_DeleteSessionOptions(opts)
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
index 1cdd8e0b6a..39a2922ac0 100644
--- a/tensorflow/python/client/tf_session.i
+++ b/tensorflow/python/client/tf_session.i
@@ -777,6 +777,7 @@ def TF_Reset(target, containers=None, config=None):
$1 = &types_local;
}
+%unignore TF_NewSessionRef;
%unignore SetRequireShapeInferenceFns;
%unignore TF_TryEvaluateConstant_wrapper;
%noexception TF_TryEvaluateConstant_wrapper;
diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc
index b6481e7e29..bcd4af2912 100644
--- a/tensorflow/python/client/tf_session_helper.cc
+++ b/tensorflow/python/client/tf_session_helper.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
+#include "tensorflow/core/common_runtime/session_ref.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
@@ -42,6 +43,19 @@ static const char* kFeedDictErrorMsg =
"feed_dict must be a dictionary mapping strings to NumPy arrays.";
} // end namespace
+TF_Session* TF_NewSessionRef(TF_Graph* graph, const TF_SessionOptions* opts,
+ TF_Status* status) {
+ TF_Session* tf_session = TF_NewSession(graph, opts, status);
+ if (tf_session == nullptr) {
+ return nullptr;
+ }
+
+ Session* session = reinterpret_cast<Session*>(tf_session->session);
+ SessionRef* session_ref = new SessionRef(session);
+ tf_session->session = session_ref;
+ return tf_session;
+}
+
void TF_Run_wrapper_helper(TF_DeprecatedSession* session, const char* handle,
const TF_Buffer* run_options, PyObject* feed_dict,
const NameVector& output_names,
diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h
index cfd27c2bee..dab7e71aac 100644
--- a/tensorflow/python/client/tf_session_helper.h
+++ b/tensorflow/python/client/tf_session_helper.h
@@ -40,6 +40,9 @@ typedef tensorflow::gtl::InlinedVector<PyObject*, 8> PyObjectVector;
// A TF_TensorVector is a vector of borrowed pointers to TF_Tensors.
typedef gtl::InlinedVector<TF_Tensor*, 8> TF_TensorVector;
+TF_Session* TF_NewSessionRef(TF_Graph* graph, const TF_SessionOptions* opts,
+ TF_Status* status);
+
// Run the graph associated with the session starting with the
// supplied inputs[]. Regardless of success or failure, inputs[] are
// stolen by the implementation (i.e. the implementation will