diff options
author | 2018-07-26 08:55:08 -0700 | |
---|---|---|
committer | 2018-07-26 08:57:55 -0700 | |
commit | deac85da170542596ba4d1a72ef5e63c0a398aba (patch) | |
tree | 3181c75e0f3068934029d75a7b86936f190d1ce8 | |
parent | 8786b41d67241331ce0aa45c3df5d121039d5159 (diff) |
Automated rollback of commit b8a9d163d9cbb4b581c044d9c4b1b256c801a9c4
PiperOrigin-RevId: 206166233
-rw-r--r-- | tensorflow/core/BUILD | 8 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/session_ref.cc | 170 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/session_ref.h | 86 | ||||
-rw-r--r-- | tensorflow/python/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/python/client/session.py | 2 | ||||
-rw-r--r-- | tensorflow/python/client/tf_session.i | 1 | ||||
-rw-r--r-- | tensorflow/python/client/tf_session_helper.cc | 14 | ||||
-rw-r--r-- | tensorflow/python/client/tf_session_helper.h | 3 |
8 files changed, 284 insertions, 1 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 84555b60da..35a112e834 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -2925,6 +2925,14 @@ tf_cuda_library( ) cc_library( + name = "session_ref", + srcs = ["common_runtime/session_ref.cc"], + hdrs = ["common_runtime/session_ref.h"], + copts = tf_copts(), + deps = [":core_cpu_base"], +) + +cc_library( name = "gpu_id", hdrs = [ "common_runtime/gpu/gpu_id.h", diff --git a/tensorflow/core/common_runtime/session_ref.cc b/tensorflow/core/common_runtime/session_ref.cc new file mode 100644 index 0000000000..b931ef4229 --- /dev/null +++ b/tensorflow/core/common_runtime/session_ref.cc @@ -0,0 +1,170 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/common_runtime/session_ref.h" + +#include <utility> + +namespace tensorflow { + +namespace { + +// Scope helper to track active calls and manage session lifetime. +struct RunCounter { + std::shared_ptr<Session> session; + uint64* value; + mutex* m; + condition_variable* cv; + + explicit RunCounter(std::shared_ptr<Session> s, uint64* v, mutex* m, + condition_variable* cv) + : session(std::move(s)), value(v), m(m), cv(cv) { + mutex_lock l(*m); + ++*value; + } + + ~RunCounter() { + mutex_lock l(*m); + if (--*value == 0) { + cv->notify_all(); + } + } +}; + +} // namespace + +Status SessionRef::CheckNotClosed() { + mutex_lock l(run_lock_); + if (session_ == nullptr) return errors::Cancelled("Session has been closed."); + return ::tensorflow::Status::OK(); +} + +Status SessionRef::Run(const RunOptions& run_options, + const std::vector<std::pair<string, Tensor> >& inputs, + const std::vector<string>& output_tensor_names, + const std::vector<string>& target_node_names, + std::vector<Tensor>* outputs, + RunMetadata* run_metadata) { + TF_RETURN_IF_ERROR(CheckNotClosed()); + RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); + return rc.session->Run(run_options, inputs, output_tensor_names, + target_node_names, outputs, run_metadata); +} + +Status SessionRef::Create(const GraphDef& graph) { + TF_RETURN_IF_ERROR(CheckNotClosed()); + RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); + return rc.session->Create(graph); +} + +Status SessionRef::Create(const RunOptions& run_options, + const GraphDef& graph) { + TF_RETURN_IF_ERROR(CheckNotClosed()); + RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); + return rc.session->Create(run_options, graph); +} + +Status SessionRef::Extend(const RunOptions& run_options, + const GraphDef& graph) { + TF_RETURN_IF_ERROR(CheckNotClosed()); + RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); + return rc.session->Extend(run_options, graph); +} + +Status SessionRef::Extend(const GraphDef& graph) { + TF_RETURN_IF_ERROR(CheckNotClosed()); + RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); + return rc.session->Extend(graph); +} + +Status SessionRef::Close(const RunOptions& run_options) { + TF_RETURN_IF_ERROR(CheckNotClosed()); + mutex_lock l(run_lock_); + Status status = session_->Close(run_options); + session_.reset(); + while (run_count_ > 0) { + run_finished_.wait(l); + } + return status; +} + +Status SessionRef::Close() { + TF_RETURN_IF_ERROR(CheckNotClosed()); + mutex_lock l(run_lock_); + Status status = session_->Close(); + session_.reset(); + while (run_count_ > 0) { + run_finished_.wait(l); + } + return status; +} + +Status SessionRef::Run(const std::vector<std::pair<string, Tensor> >& inputs, + const std::vector<string>& output_tensor_names, + const std::vector<string>& target_node_names, + std::vector<Tensor>* outputs) { + TF_RETURN_IF_ERROR(CheckNotClosed()); + RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); + return rc.session->Run(inputs, output_tensor_names, target_node_names, + outputs); +} + +Status SessionRef::ListDevices(std::vector<DeviceAttributes>* response) { + TF_RETURN_IF_ERROR(CheckNotClosed()); + RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); + return rc.session->ListDevices(response); +} + +Status SessionRef::PRunSetup(const std::vector<string>& input_names, + const std::vector<string>& output_names, + const std::vector<string>& target_nodes, + string* handle) { + TF_RETURN_IF_ERROR(CheckNotClosed()); + RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); + return rc.session->PRunSetup(input_names, output_names, target_nodes, handle); +} + +Status SessionRef::PRun(const string& handle, + const std::vector<std::pair<string, Tensor> >& inputs, + const std::vector<string>& output_names, + std::vector<Tensor>* outputs) { + TF_RETURN_IF_ERROR(CheckNotClosed()); + RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); + return rc.session->PRun(handle, inputs, output_names, outputs); +} + +Status SessionRef::MakeCallable(const CallableOptions& callable_options, + CallableHandle* out_handle) { + TF_RETURN_IF_ERROR(CheckNotClosed()); + RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); + return rc.session->MakeCallable(callable_options, out_handle); +} + +Status SessionRef::RunCallable(CallableHandle handle, + const std::vector<Tensor>& feed_tensors, + std::vector<Tensor>* fetch_tensors, + RunMetadata* run_metadata) { + TF_RETURN_IF_ERROR(CheckNotClosed()); + RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); + return rc.session->RunCallable(handle, feed_tensors, fetch_tensors, + run_metadata); +} + +Status SessionRef::ReleaseCallable(CallableHandle handle) { + TF_RETURN_IF_ERROR(CheckNotClosed()); + RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); + return rc.session->ReleaseCallable(handle); +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/session_ref.h b/tensorflow/core/common_runtime/session_ref.h new file mode 100644 index 0000000000..6146933326 --- /dev/null +++ b/tensorflow/core/common_runtime/session_ref.h @@ -0,0 +1,86 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_REF_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_REF_H_ + +#include <memory> + +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { + +// A `SessionRef` manages the lifetime of a wrapped `Session` pointer. +// +// SessionRef blocks the return of Close() until all pending operations have +// been completed or cancelled and underlying session has been freed. Any +// subsequent operations on the SessionRef object will return errors::Cancelled. +class SessionRef : public Session { + public: + SessionRef(Session* session) : session_(session) {} + virtual ~SessionRef() {} + + Status Create(const GraphDef& graph) override; + Status Extend(const GraphDef& graph) override; + Status Create(const RunOptions& run_options, const GraphDef& graph) override; + Status Extend(const RunOptions& run_options, const GraphDef& graph) override; + Status Run(const std::vector<std::pair<string, Tensor> >& inputs, + const std::vector<string>& output_tensor_names, + const std::vector<string>& target_node_names, + std::vector<Tensor>* outputs) override; + + Status ListDevices(std::vector<DeviceAttributes>* response) override; + + Status Close() override; + Status Close(const RunOptions& run_options) override; + + Status Run(const RunOptions& run_options, + const std::vector<std::pair<string, Tensor> >& inputs, + const std::vector<string>& output_tensor_names, + const std::vector<string>& target_node_names, + std::vector<Tensor>* outputs, RunMetadata* run_metadata) override; + + Status PRunSetup(const std::vector<string>& input_names, + const std::vector<string>& output_names, + const std::vector<string>& target_nodes, + string* handle) override; + + Status PRun(const string& handle, + const std::vector<std::pair<string, Tensor> >& inputs, + const std::vector<string>& output_names, + std::vector<Tensor>* outputs) override; + + Status MakeCallable(const CallableOptions& callable_options, + CallableHandle* out_handle); + + Status RunCallable(CallableHandle handle, + const std::vector<Tensor>& feed_tensors, + std::vector<Tensor>* fetch_tensors, + RunMetadata* run_metadata); + + Status ReleaseCallable(CallableHandle handle); + + private: + mutex run_lock_; + condition_variable run_finished_; + uint64 run_count_ GUARDED_BY(run_lock_) = {0}; + std::shared_ptr<Session> session_; + + Status CheckNotClosed(); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_REF_H_ diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index b5876c3457..d35731d3cd 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -3658,6 +3658,7 @@ tf_cuda_library( "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:session_ref", "//third_party/py/numpy:headers", "//third_party/python_runtime:headers", ], diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index 180bb74d00..861230e5a0 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -630,7 +630,7 @@ class BaseSession(SessionInterface): opts = tf_session.TF_NewSessionOptions(target=self._target, config=config) try: # pylint: disable=protected-access - self._session = tf_session.TF_NewSession(self._graph._c_graph, opts) + self._session = tf_session.TF_NewSessionRef(self._graph._c_graph, opts) # pylint: enable=protected-access finally: tf_session.TF_DeleteSessionOptions(opts) diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i index 1cdd8e0b6a..39a2922ac0 100644 --- a/tensorflow/python/client/tf_session.i +++ b/tensorflow/python/client/tf_session.i @@ -777,6 +777,7 @@ def TF_Reset(target, containers=None, config=None): $1 = &types_local; } +%unignore TF_NewSessionRef; %unignore SetRequireShapeInferenceFns; %unignore TF_TryEvaluateConstant_wrapper; %noexception TF_TryEvaluateConstant_wrapper; diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc index b6481e7e29..bcd4af2912 100644 --- a/tensorflow/python/client/tf_session_helper.cc +++ b/tensorflow/python/client/tf_session_helper.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/common_runtime/session_ref.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value_util.h" @@ -42,6 +43,19 @@ static const char* kFeedDictErrorMsg = "feed_dict must be a dictionary mapping strings to NumPy arrays."; } // end namespace +TF_Session* TF_NewSessionRef(TF_Graph* graph, const TF_SessionOptions* opts, + TF_Status* status) { + TF_Session* tf_session = TF_NewSession(graph, opts, status); + if (tf_session == nullptr) { + return nullptr; + } + + Session* session = reinterpret_cast<Session*>(tf_session->session); + SessionRef* session_ref = new SessionRef(session); + tf_session->session = session_ref; + return tf_session; +} + void TF_Run_wrapper_helper(TF_DeprecatedSession* session, const char* handle, const TF_Buffer* run_options, PyObject* feed_dict, const NameVector& output_names, diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h index cfd27c2bee..dab7e71aac 100644 --- a/tensorflow/python/client/tf_session_helper.h +++ b/tensorflow/python/client/tf_session_helper.h @@ -40,6 +40,9 @@ typedef tensorflow::gtl::InlinedVector<PyObject*, 8> PyObjectVector; // A TF_TensorVector is a vector of borrowed pointers to TF_Tensors. typedef gtl::InlinedVector<TF_Tensor*, 8> TF_TensorVector; +TF_Session* TF_NewSessionRef(TF_Graph* graph, const TF_SessionOptions* opts, + TF_Status* status); + // Run the graph associated with the session starting with the // supplied inputs[]. Regardless of success or failure, inputs[] are // stolen by the implementation (i.e. the implementation will |