aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Russell Power <power@google.com>2018-07-26 08:55:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-26 08:57:55 -0700
commitdeac85da170542596ba4d1a72ef5e63c0a398aba (patch)
tree3181c75e0f3068934029d75a7b86936f190d1ce8
parent8786b41d67241331ce0aa45c3df5d121039d5159 (diff)
Automated rollback of commit b8a9d163d9cbb4b581c044d9c4b1b256c801a9c4
PiperOrigin-RevId: 206166233
-rw-r--r--tensorflow/core/BUILD8
-rw-r--r--tensorflow/core/common_runtime/session_ref.cc170
-rw-r--r--tensorflow/core/common_runtime/session_ref.h86
-rw-r--r--tensorflow/python/BUILD1
-rw-r--r--tensorflow/python/client/session.py2
-rw-r--r--tensorflow/python/client/tf_session.i1
-rw-r--r--tensorflow/python/client/tf_session_helper.cc14
-rw-r--r--tensorflow/python/client/tf_session_helper.h3
8 files changed, 284 insertions, 1 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 84555b60da..35a112e834 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -2925,6 +2925,14 @@ tf_cuda_library(
)
cc_library(
+ name = "session_ref",
+ srcs = ["common_runtime/session_ref.cc"],
+ hdrs = ["common_runtime/session_ref.h"],
+ copts = tf_copts(),
+ deps = [":core_cpu_base"],
+)
+
+cc_library(
name = "gpu_id",
hdrs = [
"common_runtime/gpu/gpu_id.h",
diff --git a/tensorflow/core/common_runtime/session_ref.cc b/tensorflow/core/common_runtime/session_ref.cc
new file mode 100644
index 0000000000..b931ef4229
--- /dev/null
+++ b/tensorflow/core/common_runtime/session_ref.cc
@@ -0,0 +1,170 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/session_ref.h"
+
+#include <utility>
+
+namespace tensorflow {
+
+namespace {
+
+// Scope helper to track active calls and manage session lifetime.
+struct RunCounter {
+ std::shared_ptr<Session> session;
+ uint64* value;
+ mutex* m;
+ condition_variable* cv;
+
+ explicit RunCounter(std::shared_ptr<Session> s, uint64* v, mutex* m,
+ condition_variable* cv)
+ : session(std::move(s)), value(v), m(m), cv(cv) {
+ mutex_lock l(*m);
+ ++*value;
+ }
+
+ ~RunCounter() {
+ mutex_lock l(*m);
+ if (--*value == 0) {
+ cv->notify_all();
+ }
+ }
+};
+
+} // namespace
+
+Status SessionRef::CheckNotClosed() {
+ mutex_lock l(run_lock_);
+ if (session_ == nullptr) return errors::Cancelled("Session has been closed.");
+ return ::tensorflow::Status::OK();
+}
+
+Status SessionRef::Run(const RunOptions& run_options,
+ const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_tensor_names,
+ const std::vector<string>& target_node_names,
+ std::vector<Tensor>* outputs,
+ RunMetadata* run_metadata) {
+ TF_RETURN_IF_ERROR(CheckNotClosed());
+ RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
+ return rc.session->Run(run_options, inputs, output_tensor_names,
+ target_node_names, outputs, run_metadata);
+}
+
+Status SessionRef::Create(const GraphDef& graph) {
+ TF_RETURN_IF_ERROR(CheckNotClosed());
+ RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
+ return rc.session->Create(graph);
+}
+
+Status SessionRef::Create(const RunOptions& run_options,
+ const GraphDef& graph) {
+ TF_RETURN_IF_ERROR(CheckNotClosed());
+ RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
+ return rc.session->Create(run_options, graph);
+}
+
+Status SessionRef::Extend(const RunOptions& run_options,
+ const GraphDef& graph) {
+ TF_RETURN_IF_ERROR(CheckNotClosed());
+ RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
+ return rc.session->Extend(run_options, graph);
+}
+
+Status SessionRef::Extend(const GraphDef& graph) {
+ TF_RETURN_IF_ERROR(CheckNotClosed());
+ RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
+ return rc.session->Extend(graph);
+}
+
+Status SessionRef::Close(const RunOptions& run_options) {
+ TF_RETURN_IF_ERROR(CheckNotClosed());
+ mutex_lock l(run_lock_);
+ Status status = session_->Close(run_options);
+ session_.reset();
+ while (run_count_ > 0) {
+ run_finished_.wait(l);
+ }
+ return status;
+}
+
+Status SessionRef::Close() {
+ TF_RETURN_IF_ERROR(CheckNotClosed());
+ mutex_lock l(run_lock_);
+ Status status = session_->Close();
+ session_.reset();
+ while (run_count_ > 0) {
+ run_finished_.wait(l);
+ }
+ return status;
+}
+
+Status SessionRef::Run(const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_tensor_names,
+ const std::vector<string>& target_node_names,
+ std::vector<Tensor>* outputs) {
+ TF_RETURN_IF_ERROR(CheckNotClosed());
+ RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
+ return rc.session->Run(inputs, output_tensor_names, target_node_names,
+ outputs);
+}
+
+Status SessionRef::ListDevices(std::vector<DeviceAttributes>* response) {
+ TF_RETURN_IF_ERROR(CheckNotClosed());
+ RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
+ return rc.session->ListDevices(response);
+}
+
+Status SessionRef::PRunSetup(const std::vector<string>& input_names,
+ const std::vector<string>& output_names,
+ const std::vector<string>& target_nodes,
+ string* handle) {
+ TF_RETURN_IF_ERROR(CheckNotClosed());
+ RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
+ return rc.session->PRunSetup(input_names, output_names, target_nodes, handle);
+}
+
+Status SessionRef::PRun(const string& handle,
+ const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_names,
+ std::vector<Tensor>* outputs) {
+ TF_RETURN_IF_ERROR(CheckNotClosed());
+ RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
+ return rc.session->PRun(handle, inputs, output_names, outputs);
+}
+
+Status SessionRef::MakeCallable(const CallableOptions& callable_options,
+ CallableHandle* out_handle) {
+ TF_RETURN_IF_ERROR(CheckNotClosed());
+ RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
+ return rc.session->MakeCallable(callable_options, out_handle);
+}
+
+Status SessionRef::RunCallable(CallableHandle handle,
+ const std::vector<Tensor>& feed_tensors,
+ std::vector<Tensor>* fetch_tensors,
+ RunMetadata* run_metadata) {
+ TF_RETURN_IF_ERROR(CheckNotClosed());
+ RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
+ return rc.session->RunCallable(handle, feed_tensors, fetch_tensors,
+ run_metadata);
+}
+
+Status SessionRef::ReleaseCallable(CallableHandle handle) {
+ TF_RETURN_IF_ERROR(CheckNotClosed());
+ RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
+ return rc.session->ReleaseCallable(handle);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/session_ref.h b/tensorflow/core/common_runtime/session_ref.h
new file mode 100644
index 0000000000..6146933326
--- /dev/null
+++ b/tensorflow/core/common_runtime/session_ref.h
@@ -0,0 +1,86 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_REF_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_REF_H_
+
+#include <memory>
+
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/public/session.h"
+
+namespace tensorflow {
+
+// A `SessionRef` manages the lifetime of a wrapped `Session` pointer.
+//
+// SessionRef blocks the return of Close() until all pending operations have
+// been completed or cancelled and underlying session has been freed. Any
+// subsequent operations on the SessionRef object will return errors::Cancelled.
+class SessionRef : public Session {
+ public:
+ SessionRef(Session* session) : session_(session) {}
+ virtual ~SessionRef() {}
+
+ Status Create(const GraphDef& graph) override;
+ Status Extend(const GraphDef& graph) override;
+ Status Create(const RunOptions& run_options, const GraphDef& graph) override;
+ Status Extend(const RunOptions& run_options, const GraphDef& graph) override;
+ Status Run(const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_tensor_names,
+ const std::vector<string>& target_node_names,
+ std::vector<Tensor>* outputs) override;
+
+ Status ListDevices(std::vector<DeviceAttributes>* response) override;
+
+ Status Close() override;
+ Status Close(const RunOptions& run_options) override;
+
+ Status Run(const RunOptions& run_options,
+ const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_tensor_names,
+ const std::vector<string>& target_node_names,
+ std::vector<Tensor>* outputs, RunMetadata* run_metadata) override;
+
+ Status PRunSetup(const std::vector<string>& input_names,
+ const std::vector<string>& output_names,
+ const std::vector<string>& target_nodes,
+ string* handle) override;
+
+ Status PRun(const string& handle,
+ const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_names,
+ std::vector<Tensor>* outputs) override;
+
+ Status MakeCallable(const CallableOptions& callable_options,
+ CallableHandle* out_handle);
+
+ Status RunCallable(CallableHandle handle,
+ const std::vector<Tensor>& feed_tensors,
+ std::vector<Tensor>* fetch_tensors,
+ RunMetadata* run_metadata);
+
+ Status ReleaseCallable(CallableHandle handle);
+
+ private:
+ mutex run_lock_;
+ condition_variable run_finished_;
+ uint64 run_count_ GUARDED_BY(run_lock_) = {0};
+ std::shared_ptr<Session> session_;
+
+ Status CheckNotClosed();
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_REF_H_
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index b5876c3457..d35731d3cd 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3658,6 +3658,7 @@ tf_cuda_library(
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:session_ref",
"//third_party/py/numpy:headers",
"//third_party/python_runtime:headers",
],
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index 180bb74d00..861230e5a0 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -630,7 +630,7 @@ class BaseSession(SessionInterface):
opts = tf_session.TF_NewSessionOptions(target=self._target, config=config)
try:
# pylint: disable=protected-access
- self._session = tf_session.TF_NewSession(self._graph._c_graph, opts)
+ self._session = tf_session.TF_NewSessionRef(self._graph._c_graph, opts)
# pylint: enable=protected-access
finally:
tf_session.TF_DeleteSessionOptions(opts)
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
index 1cdd8e0b6a..39a2922ac0 100644
--- a/tensorflow/python/client/tf_session.i
+++ b/tensorflow/python/client/tf_session.i
@@ -777,6 +777,7 @@ def TF_Reset(target, containers=None, config=None):
$1 = &types_local;
}
+%unignore TF_NewSessionRef;
%unignore SetRequireShapeInferenceFns;
%unignore TF_TryEvaluateConstant_wrapper;
%noexception TF_TryEvaluateConstant_wrapper;
diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc
index b6481e7e29..bcd4af2912 100644
--- a/tensorflow/python/client/tf_session_helper.cc
+++ b/tensorflow/python/client/tf_session_helper.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
+#include "tensorflow/core/common_runtime/session_ref.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
@@ -42,6 +43,19 @@ static const char* kFeedDictErrorMsg =
"feed_dict must be a dictionary mapping strings to NumPy arrays.";
} // end namespace
+TF_Session* TF_NewSessionRef(TF_Graph* graph, const TF_SessionOptions* opts,
+ TF_Status* status) {
+ TF_Session* tf_session = TF_NewSession(graph, opts, status);
+ if (tf_session == nullptr) {
+ return nullptr;
+ }
+
+ Session* session = reinterpret_cast<Session*>(tf_session->session);
+ SessionRef* session_ref = new SessionRef(session);
+ tf_session->session = session_ref;
+ return tf_session;
+}
+
void TF_Run_wrapper_helper(TF_DeprecatedSession* session, const char* handle,
const TF_Buffer* run_options, PyObject* feed_dict,
const NameVector& output_names,
diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h
index cfd27c2bee..dab7e71aac 100644
--- a/tensorflow/python/client/tf_session_helper.h
+++ b/tensorflow/python/client/tf_session_helper.h
@@ -40,6 +40,9 @@ typedef tensorflow::gtl::InlinedVector<PyObject*, 8> PyObjectVector;
// A TF_TensorVector is a vector of borrowed pointers to TF_Tensors.
typedef gtl::InlinedVector<TF_Tensor*, 8> TF_TensorVector;
+TF_Session* TF_NewSessionRef(TF_Graph* graph, const TF_SessionOptions* opts,
+ TF_Status* status);
+
// Run the graph associated with the session starting with the
// supplied inputs[]. Regardless of success or failure, inputs[] are
// stolen by the implementation (i.e. the implementation will