diff options
author | 2017-02-23 16:00:12 -0800 | |
---|---|---|
committer | 2017-02-23 16:09:44 -0800 | |
commit | 7c46b4600a084b64da48ee26b2f974d5f56e8cf1 (patch) | |
tree | e7c3b3f2ff27423481986357e8d11a38c9333b3a /tensorflow/cc/client | |
parent | 609618f209c16b7910432da6e216ca1cebe8dcb0 (diff) |
C++ API: create ClientSession::Impl class to hide private members/methods.
Eventually all public API classes should follow this pattern.
Change: 148403490
Diffstat (limited to 'tensorflow/cc/client')
-rw-r--r-- | tensorflow/cc/client/client_session.cc | 44 | ||||
-rw-r--r-- | tensorflow/cc/client/client_session.h | 20 |
2 files changed, 39 insertions, 25 deletions
diff --git a/tensorflow/cc/client/client_session.cc b/tensorflow/cc/client/client_session.cc index 644409203c..2732f3f501 100644 --- a/tensorflow/cc/client/client_session.cc +++ b/tensorflow/cc/client/client_session.cc @@ -19,29 +19,51 @@ limitations under the License. #include <vector> #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session_options.h" namespace tensorflow { +class ClientSession::Impl { + private: + friend class ClientSession; + + Impl(Session* session, std::shared_ptr<Graph> graph) + : session_(session), graph_(graph) {} + + static SessionOptions MakeDefaultSessionOptions(const string& target); + Status MaybeExtendGraph() const; + + std::unique_ptr<Session> session_; + std::shared_ptr<Graph> graph_; + + mutable mutex mu_; + mutable int last_num_graph_nodes_ GUARDED_BY(mu_) = 0; +}; + ClientSession::ClientSession(const Scope& scope, const string& target) - : ClientSession(scope, MakeDefaultSessionOptions(target)) {} + : ClientSession(scope, Impl::MakeDefaultSessionOptions(target)) {} ClientSession::ClientSession(const Scope& scope) : ClientSession(scope, "") {} ClientSession::ClientSession(const Scope& scope, - const SessionOptions& session_options) - : graph_(scope.graph_as_shared_ptr()) { + const SessionOptions& session_options) { Session* new_session; Status status = NewSession(session_options, &new_session); TF_CHECK_OK(status) << status; - session_.reset(new_session); - CHECK_NOTNULL(session_.get()); + impl_.reset(new Impl(new_session, scope.graph_as_shared_ptr())); + CHECK_NOTNULL(impl()->session_.get()); } -SessionOptions ClientSession::MakeDefaultSessionOptions( - const string& target) const { +// Define destructor here so we can forward declare `Impl` in client_session.h. +// If we define a dtor in the header file or use the default dtor, +// unique_ptr<Impl> needs the complete type. +ClientSession::~ClientSession() {} + +SessionOptions ClientSession::Impl::MakeDefaultSessionOptions( + const string& target) { SessionOptions options; options.env = Env::Default(); options.target = target; @@ -67,7 +89,7 @@ Status ClientSession::Run(const FeedType& inputs, nullptr); } -Status ClientSession::MaybeExtendGraph() const { +Status ClientSession::Impl::MaybeExtendGraph() const { mutex_lock l(mu_); int num_nodes = graph_->num_node_ids(); if (num_nodes > last_num_graph_nodes_) { @@ -97,9 +119,9 @@ Status ClientSession::Run(const RunOptions& run_options, const FeedType& inputs, for (auto const& output : run_outputs) { target_node_names.push_back(output.node()->name()); } - TF_RETURN_IF_ERROR(MaybeExtendGraph()); - return session_->Run(run_options, feeds, output_tensor_names, - target_node_names, outputs, run_metadata); + TF_RETURN_IF_ERROR(impl()->MaybeExtendGraph()); + return impl()->session_->Run(run_options, feeds, output_tensor_names, + target_node_names, outputs, run_metadata); } } // end namespace tensorflow diff --git a/tensorflow/cc/client/client_session.h b/tensorflow/cc/client/client_session.h index a6fe0205a0..5fb4109f7d 100644 --- a/tensorflow/cc/client/client_session.h +++ b/tensorflow/cc/client/client_session.h @@ -23,10 +23,6 @@ limitations under the License. #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/protobuf/config.pb.h" -#include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session_options.h" namespace tensorflow { @@ -67,6 +63,8 @@ class ClientSession { /// Create a new session, configuring it with `session_options`. ClientSession(const Scope& scope, const SessionOptions& session_options); + ~ClientSession(); + /// Evaluate the tensors in `fetch_outputs`. The values are returned as /// `Tensor` objects in `outputs`. The number and order of `outputs` will /// match `fetch_outputs`. @@ -92,16 +90,10 @@ class ClientSession { // TODO(keveman): Add support for partial run. private: - SessionOptions MakeDefaultSessionOptions(const string& target) const; - Status MaybeExtendGraph() const; - - std::unique_ptr<Session> session_; - std::shared_ptr<Graph> graph_; - - mutable mutex mu_; - mutable int last_num_graph_nodes_ GUARDED_BY(mu_) = 0; - - TF_DISALLOW_COPY_AND_ASSIGN(ClientSession); + class Impl; + std::unique_ptr<Impl> impl_; + Impl* impl() { return impl_.get(); } + const Impl* impl() const { return impl_.get(); } }; /// @} |