aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/client
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2017-02-23 16:00:12 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-23 16:09:44 -0800
commit7c46b4600a084b64da48ee26b2f974d5f56e8cf1 (patch)
treee7c3b3f2ff27423481986357e8d11a38c9333b3a /tensorflow/cc/client
parent609618f209c16b7910432da6e216ca1cebe8dcb0 (diff)
C++ API: create ClientSession::Impl class to hide private members/methods.
Eventually all public API classes should follow this pattern. Change: 148403490
Diffstat (limited to 'tensorflow/cc/client')
-rw-r--r--tensorflow/cc/client/client_session.cc44
-rw-r--r--tensorflow/cc/client/client_session.h20
2 files changed, 39 insertions, 25 deletions
diff --git a/tensorflow/cc/client/client_session.cc b/tensorflow/cc/client/client_session.cc
index 644409203c..2732f3f501 100644
--- a/tensorflow/cc/client/client_session.cc
+++ b/tensorflow/cc/client/client_session.cc
@@ -19,29 +19,51 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
+class ClientSession::Impl {
+ private:
+ friend class ClientSession;
+
+ Impl(Session* session, std::shared_ptr<Graph> graph)
+ : session_(session), graph_(graph) {}
+
+ static SessionOptions MakeDefaultSessionOptions(const string& target);
+ Status MaybeExtendGraph() const;
+
+ std::unique_ptr<Session> session_;
+ std::shared_ptr<Graph> graph_;
+
+ mutable mutex mu_;
+ mutable int last_num_graph_nodes_ GUARDED_BY(mu_) = 0;
+};
+
ClientSession::ClientSession(const Scope& scope, const string& target)
- : ClientSession(scope, MakeDefaultSessionOptions(target)) {}
+ : ClientSession(scope, Impl::MakeDefaultSessionOptions(target)) {}
ClientSession::ClientSession(const Scope& scope) : ClientSession(scope, "") {}
ClientSession::ClientSession(const Scope& scope,
- const SessionOptions& session_options)
- : graph_(scope.graph_as_shared_ptr()) {
+ const SessionOptions& session_options) {
Session* new_session;
Status status = NewSession(session_options, &new_session);
TF_CHECK_OK(status) << status;
- session_.reset(new_session);
- CHECK_NOTNULL(session_.get());
+ impl_.reset(new Impl(new_session, scope.graph_as_shared_ptr()));
+ CHECK_NOTNULL(impl()->session_.get());
}
-SessionOptions ClientSession::MakeDefaultSessionOptions(
- const string& target) const {
+// Define destructor here so we can forward declare `Impl` in client_session.h.
+// If we define a dtor in the header file or use the default dtor,
+// unique_ptr<Impl> needs the complete type.
+ClientSession::~ClientSession() {}
+
+SessionOptions ClientSession::Impl::MakeDefaultSessionOptions(
+ const string& target) {
SessionOptions options;
options.env = Env::Default();
options.target = target;
@@ -67,7 +89,7 @@ Status ClientSession::Run(const FeedType& inputs,
nullptr);
}
-Status ClientSession::MaybeExtendGraph() const {
+Status ClientSession::Impl::MaybeExtendGraph() const {
mutex_lock l(mu_);
int num_nodes = graph_->num_node_ids();
if (num_nodes > last_num_graph_nodes_) {
@@ -97,9 +119,9 @@ Status ClientSession::Run(const RunOptions& run_options, const FeedType& inputs,
for (auto const& output : run_outputs) {
target_node_names.push_back(output.node()->name());
}
- TF_RETURN_IF_ERROR(MaybeExtendGraph());
- return session_->Run(run_options, feeds, output_tensor_names,
- target_node_names, outputs, run_metadata);
+ TF_RETURN_IF_ERROR(impl()->MaybeExtendGraph());
+ return impl()->session_->Run(run_options, feeds, output_tensor_names,
+ target_node_names, outputs, run_metadata);
}
} // end namespace tensorflow
diff --git a/tensorflow/cc/client/client_session.h b/tensorflow/cc/client/client_session.h
index a6fe0205a0..5fb4109f7d 100644
--- a/tensorflow/cc/client/client_session.h
+++ b/tensorflow/cc/client/client_session.h
@@ -23,10 +23,6 @@ limitations under the License.
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/protobuf/config.pb.h"
-#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
@@ -67,6 +63,8 @@ class ClientSession {
/// Create a new session, configuring it with `session_options`.
ClientSession(const Scope& scope, const SessionOptions& session_options);
+ ~ClientSession();
+
/// Evaluate the tensors in `fetch_outputs`. The values are returned as
/// `Tensor` objects in `outputs`. The number and order of `outputs` will
/// match `fetch_outputs`.
@@ -92,16 +90,10 @@ class ClientSession {
// TODO(keveman): Add support for partial run.
private:
- SessionOptions MakeDefaultSessionOptions(const string& target) const;
- Status MaybeExtendGraph() const;
-
- std::unique_ptr<Session> session_;
- std::shared_ptr<Graph> graph_;
-
- mutable mutex mu_;
- mutable int last_num_graph_nodes_ GUARDED_BY(mu_) = 0;
-
- TF_DISALLOW_COPY_AND_ASSIGN(ClientSession);
+ class Impl;
+ std::unique_ptr<Impl> impl_;
+ Impl* impl() { return impl_.get(); }
+ const Impl* impl() const { return impl_.get(); }
};
/// @}