diff options
Diffstat (limited to 'tensorflow/core/common_runtime/direct_session.h')
-rw-r--r-- | tensorflow/core/common_runtime/direct_session.h | 25 |
1 files changed, 24 insertions, 1 deletions
diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h index dcb2c584c8..8681d8fb7c 100644 --- a/tensorflow/core/common_runtime/direct_session.h +++ b/tensorflow/core/common_runtime/direct_session.h @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" +#include "tensorflow/core/common_runtime/session_factory.h" #include "tensorflow/core/common_runtime/simple_graph_execution_state.h" #include "tensorflow/core/debug/debug_graph_utils.h" #include "tensorflow/core/framework/cancellation.h" @@ -47,11 +48,18 @@ namespace tensorflow { class CostModel; class DebugGateway; class Device; +class DirectSessionFactory; class DirectSession : public Session { public: + typedef std::function<void(Session*)> CloseCallback; + // Takes ownership of 'device_mgr'. - DirectSession(const SessionOptions& options, const DeviceMgr* device_mgr); + // 'factory' is used to unregister the DirectSession with 'factory' when its + // closed. This ensures that Reset requests from the 'factory' don't get sent + // to sessions that are already closed. + DirectSession(const SessionOptions& options, const DeviceMgr* device_mgr, + DirectSessionFactory* factory); ~DirectSession() override; typedef std::vector<std::pair<string, Tensor>> NamedTensorList; @@ -83,6 +91,10 @@ class DirectSession : public Session { const std::vector<string>& output_names, std::vector<Tensor>* outputs) override; + // Reset clears 'containers' from the device_mgr of the DirectSession. + // If 'containers' is empty, then Reset clears the default container. + ::tensorflow::Status Reset(const std::vector<string>& containers); + ::tensorflow::Status Close() override; void ExportCostModels(CostModelManager::CostModelMap* cost_models) { @@ -198,6 +210,12 @@ class DirectSession : public Session { // operation_timeout_in_ms is greater than 0. void WaitForNotification(RunState* run_state, int64 timeout_in_ms); + ::tensorflow::Status CheckNotClosed() { + mutex_lock l(mu_); + if (closed_) return errors::Cancelled("Session has been closed."); + return ::tensorflow::Status::OK(); + } + const SessionOptions options_; // Device structures. @@ -232,10 +250,12 @@ class DirectSession : public Session { // This holds all the tensors that are currently alive in the session. SessionState session_state_; + DirectSessionFactory* const factory_; // not owned CancellationManager* cancellation_manager_; // Saves and restores device placements for stateful nodes. mutex mu_; + // Map of placed stateful nodes, i.e. nodes for which is_stateful() // is true, such as "params" and "queue" nodes. Once placed these // nodes can not be moved to a different device. Maps node names to @@ -251,6 +271,9 @@ class DirectSession : public Session { // library; it copies and modifies the function library. std::unique_ptr<FunctionLibraryDefinition> flib_def_; + // true if the Session has been Closed. + bool closed_ GUARDED_BY(mu_); + // For generating unique names. int64 name_counter_ GUARDED_BY(mu_) = 0; |