aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/direct_session.h
diff options
context:
space:
mode:
authorGravatar Suharsh Sivakumar <suharshs@google.com>2016-08-31 15:43:07 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-31 16:49:05 -0700
commit62c159ffe847eeb788550a32b8be572e41055022 (patch)
treeb4e005f3b5664472f3984c008b2b27f3cb7b0439 /tensorflow/core/common_runtime/direct_session.h
parente11b99749d2898df0bce0269c77df316b059e8ea (diff)
Add Reset implementation for DirectSession.
- Reset clears and closes the specified containers for ALL DirectSession objects. - Add closed bit to DirectSession to ensure that operations that occur after Close is called fail. Change: 131889161
Diffstat (limited to 'tensorflow/core/common_runtime/direct_session.h')
-rw-r--r--tensorflow/core/common_runtime/direct_session.h25
1 files changed, 24 insertions, 1 deletions
diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h
index dcb2c584c8..8681d8fb7c 100644
--- a/tensorflow/core/common_runtime/direct_session.h
+++ b/tensorflow/core/common_runtime/direct_session.h
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
+#include "tensorflow/core/common_runtime/session_factory.h"
#include "tensorflow/core/common_runtime/simple_graph_execution_state.h"
#include "tensorflow/core/debug/debug_graph_utils.h"
#include "tensorflow/core/framework/cancellation.h"
@@ -47,11 +48,18 @@ namespace tensorflow {
class CostModel;
class DebugGateway;
class Device;
+class DirectSessionFactory;
class DirectSession : public Session {
public:
+ typedef std::function<void(Session*)> CloseCallback;
+
// Takes ownership of 'device_mgr'.
- DirectSession(const SessionOptions& options, const DeviceMgr* device_mgr);
+ // 'factory' is used to unregister the DirectSession with 'factory' when its
+ // closed. This ensures that Reset requests from the 'factory' don't get sent
+ // to sessions that are already closed.
+ DirectSession(const SessionOptions& options, const DeviceMgr* device_mgr,
+ DirectSessionFactory* factory);
~DirectSession() override;
typedef std::vector<std::pair<string, Tensor>> NamedTensorList;
@@ -83,6 +91,10 @@ class DirectSession : public Session {
const std::vector<string>& output_names,
std::vector<Tensor>* outputs) override;
+ // Reset clears 'containers' from the device_mgr of the DirectSession.
+ // If 'containers' is empty, then Reset clears the default container.
+ ::tensorflow::Status Reset(const std::vector<string>& containers);
+
::tensorflow::Status Close() override;
void ExportCostModels(CostModelManager::CostModelMap* cost_models) {
@@ -198,6 +210,12 @@ class DirectSession : public Session {
// operation_timeout_in_ms is greater than 0.
void WaitForNotification(RunState* run_state, int64 timeout_in_ms);
+ ::tensorflow::Status CheckNotClosed() {
+ mutex_lock l(mu_);
+ if (closed_) return errors::Cancelled("Session has been closed.");
+ return ::tensorflow::Status::OK();
+ }
+
const SessionOptions options_;
// Device structures.
@@ -232,10 +250,12 @@ class DirectSession : public Session {
// This holds all the tensors that are currently alive in the session.
SessionState session_state_;
+ DirectSessionFactory* const factory_; // not owned
CancellationManager* cancellation_manager_;
// Saves and restores device placements for stateful nodes.
mutex mu_;
+
// Map of placed stateful nodes, i.e. nodes for which is_stateful()
// is true, such as "params" and "queue" nodes. Once placed these
// nodes can not be moved to a different device. Maps node names to
@@ -251,6 +271,9 @@ class DirectSession : public Session {
// library; it copies and modifies the function library.
std::unique_ptr<FunctionLibraryDefinition> flib_def_;
+ // true if the Session has been Closed.
+ bool closed_ GUARDED_BY(mu_);
+
// For generating unique names.
int64 name_counter_ GUARDED_BY(mu_) = 0;