diff options
author | Suharsh Sivakumar <suharshs@google.com> | 2016-08-31 15:43:07 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-08-31 16:49:05 -0700 |
commit | 62c159ffe847eeb788550a32b8be572e41055022 (patch) | |
tree | b4e005f3b5664472f3984c008b2b27f3cb7b0439 /tensorflow/core/common_runtime/direct_session.h | |
parent | e11b99749d2898df0bce0269c77df316b059e8ea (diff) |
Add Reset implementation for DirectSession.
- Reset clears and closes the specified containers for ALL DirectSession objects.
- Add closed bit to DirectSession to ensure that operations that occur after Close is called fail.
Change: 131889161
Diffstat (limited to 'tensorflow/core/common_runtime/direct_session.h')
-rw-r--r-- | tensorflow/core/common_runtime/direct_session.h | 25 |
1 files changed, 24 insertions, 1 deletions
diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h index dcb2c584c8..8681d8fb7c 100644 --- a/tensorflow/core/common_runtime/direct_session.h +++ b/tensorflow/core/common_runtime/direct_session.h @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" +#include "tensorflow/core/common_runtime/session_factory.h" #include "tensorflow/core/common_runtime/simple_graph_execution_state.h" #include "tensorflow/core/debug/debug_graph_utils.h" #include "tensorflow/core/framework/cancellation.h" @@ -47,11 +48,18 @@ namespace tensorflow { class CostModel; class DebugGateway; class Device; +class DirectSessionFactory; class DirectSession : public Session { public: + typedef std::function<void(Session*)> CloseCallback; + // Takes ownership of 'device_mgr'. - DirectSession(const SessionOptions& options, const DeviceMgr* device_mgr); + // 'factory' is used to unregister the DirectSession with 'factory' when its + // closed. This ensures that Reset requests from the 'factory' don't get sent + // to sessions that are already closed. + DirectSession(const SessionOptions& options, const DeviceMgr* device_mgr, + DirectSessionFactory* factory); ~DirectSession() override; typedef std::vector<std::pair<string, Tensor>> NamedTensorList; @@ -83,6 +91,10 @@ class DirectSession : public Session { const std::vector<string>& output_names, std::vector<Tensor>* outputs) override; + // Reset clears 'containers' from the device_mgr of the DirectSession. + // If 'containers' is empty, then Reset clears the default container. + ::tensorflow::Status Reset(const std::vector<string>& containers); + ::tensorflow::Status Close() override; void ExportCostModels(CostModelManager::CostModelMap* cost_models) { @@ -198,6 +210,12 @@ class DirectSession : public Session { // operation_timeout_in_ms is greater than 0. void WaitForNotification(RunState* run_state, int64 timeout_in_ms); + ::tensorflow::Status CheckNotClosed() { + mutex_lock l(mu_); + if (closed_) return errors::Cancelled("Session has been closed."); + return ::tensorflow::Status::OK(); + } + const SessionOptions options_; // Device structures. @@ -232,10 +250,12 @@ class DirectSession : public Session { // This holds all the tensors that are currently alive in the session. SessionState session_state_; + DirectSessionFactory* const factory_; // not owned CancellationManager* cancellation_manager_; // Saves and restores device placements for stateful nodes. mutex mu_; + // Map of placed stateful nodes, i.e. nodes for which is_stateful() // is true, such as "params" and "queue" nodes. Once placed these // nodes can not be moved to a different device. Maps node names to @@ -251,6 +271,9 @@ class DirectSession : public Session { // library; it copies and modifies the function library. std::unique_ptr<FunctionLibraryDefinition> flib_def_; + // true if the Session has been Closed. + bool closed_ GUARDED_BY(mu_); + // For generating unique names. int64 name_counter_ GUARDED_BY(mu_) = 0; |