aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/client
diff options
context:
space:
mode:
authorGravatar Manjunath Kudlur <keveman@google.com>2016-08-19 10:57:15 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-19 12:03:49 -0700
commit767377749cb9f7f4df1f95db342c88ca3c6ee1f2 (patch)
tree288ef21efa24f7b048bcb5df58c0a61cebbdea95 /tensorflow/cc/client
parent0b9f0f53ddbf693bb30afb211a6d514a1fce1c22 (diff)
ClientSession wraps a Session object and provides Run methods that operate on
the C++ API's Output/Operation object instead of strings. Change: 130776638
Diffstat (limited to 'tensorflow/cc/client')
-rw-r--r--tensorflow/cc/client/client_session.cc102
-rw-r--r--tensorflow/cc/client/client_session.h109
-rw-r--r--tensorflow/cc/client/client_session_test.cc92
3 files changed, 303 insertions, 0 deletions
diff --git a/tensorflow/cc/client/client_session.cc b/tensorflow/cc/client/client_session.cc
new file mode 100644
index 0000000000..5a98deb259
--- /dev/null
+++ b/tensorflow/cc/client/client_session.cc
@@ -0,0 +1,102 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/client/client_session.h"
+
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/protobuf/config.pb.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+
+ClientSession::ClientSession(const Scope& scope, const string& target)
+ : ClientSession(scope, MakeDefaultSessionOptions(target)) {}
+
+ClientSession::ClientSession(const Scope& scope) : ClientSession(scope, "") {}
+
+ClientSession::ClientSession(const Scope& scope,
+ const SessionOptions& session_options)
+ : session_(NewSession(session_options)),
+ graph_(scope.graph_as_shared_ptr()) {
+ CHECK_NOTNULL(session_.get());
+}
+
+SessionOptions ClientSession::MakeDefaultSessionOptions(
+ const string& target) const {
+ SessionOptions options;
+ options.env = Env::Default();
+ options.target = target;
+ return options;
+}
+
+Status ClientSession::Run(const std::vector<ops::Output>& fetch_outputs,
+ std::vector<Tensor>* outputs) const {
+ return Run(FeedType{}, fetch_outputs, {}, outputs);
+}
+
+Status ClientSession::Run(const FeedType& inputs,
+ const std::vector<ops::Output>& fetch_outputs,
+ std::vector<Tensor>* outputs) const {
+ return Run(inputs, fetch_outputs, {}, outputs);
+}
+
+Status ClientSession::Run(const FeedType& inputs,
+ const std::vector<ops::Output>& fetch_outputs,
+ const std::vector<ops::Operation>& run_outputs,
+ std::vector<Tensor>* outputs) const {
+ return Run(RunOptions(), inputs, fetch_outputs, run_outputs, outputs,
+ nullptr);
+}
+
+Status ClientSession::MaybeExtendGraph() const {
+ mutex_lock l(mu_);
+ int num_nodes = graph_->num_node_ids();
+ if (num_nodes > last_num_graph_nodes_) {
+ GraphDef graph_def;
+ graph_->ToGraphDefSubRange(&graph_def, last_num_graph_nodes_);
+ last_num_graph_nodes_ = num_nodes;
+ return session_->Extend(graph_def);
+ }
+ return Status::OK();
+}
+
+Status ClientSession::Run(const RunOptions& run_options, const FeedType& inputs,
+ const std::vector<ops::Output>& fetch_outputs,
+ const std::vector<ops::Operation>& run_outputs,
+ std::vector<Tensor>* outputs,
+ RunMetadata* run_metadata) const {
+ std::vector<std::pair<string, Tensor>> feeds;
+ for (auto const& feed : inputs) {
+ TF_RETURN_IF_ERROR(feed.second.status);
+ feeds.emplace_back(feed.first.name(), feed.second.tensor);
+ }
+ std::vector<string> output_tensor_names;
+ for (auto const& output : fetch_outputs) {
+ output_tensor_names.push_back(output.name());
+ }
+ std::vector<string> target_node_names;
+ for (auto const& output : run_outputs) {
+ target_node_names.push_back(output.node()->name());
+ }
+ TF_RETURN_IF_ERROR(MaybeExtendGraph());
+ return session_->Run(run_options, feeds, output_tensor_names,
+ target_node_names, outputs, run_metadata);
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/cc/client/client_session.h b/tensorflow/cc/client/client_session.h
new file mode 100644
index 0000000000..9d480477f6
--- /dev/null
+++ b/tensorflow/cc/client/client_session.h
@@ -0,0 +1,109 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CC_CLIENT_CLIENT_SESSION_H_
+#define TENSORFLOW_CC_CLIENT_CLIENT_SESSION_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/protobuf/config.pb.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+
+// A `ClientSession` object lets the caller drive the evaluation of the
+// TensorFlow graph constructed with the C++ API.
+//
+// Example:
+//
+// Scope root = Scope::NewRootScope();
+// auto a = Placeholder(root, DT_INT32);
+// auto c = Add(root, a, {41});
+//
+// ClientSession session(root);
+// std::vector<Tensor> outputs;
+//
+// Status s = session.Run({{a, {1}}}, {c}, &outputs);
+// if (!s.ok()) { /* Handle error */ }
+class ClientSession {
+ public:
+ // A data type to represent feeds to a Run call.
+ // This is a map of `Output` objects returned by op-constructors to the value
+ // to feed them with. See `ops::Input::Initializer` for details on what can be
+ // used as feed values.
+ typedef std::unordered_map<ops::Output, ops::Input::Initializer,
+ ops::OutputHash>
+ FeedType;
+
+ // Create a new session to evaluate the graph contained in `scope` by
+ // connecting to the TensorFlow runtime specified by `target`.
+ ClientSession(const Scope& scope, const string& target);
+
+ // Same as above, but use the empty string ("") as the target specification.
+ ClientSession(const Scope& scope);
+
+ // Create a new session, configuring it with `session_options`.
+ ClientSession(const Scope& scope, const SessionOptions& session_options);
+
+ // Evaluate the tensors in `fetch_outputs`. The values are returned as
+ // `Tensor` objects in `outputs`. The number and order of `outputs` will match
+ // `fetch_outputs`.
+ Status Run(const std::vector<ops::Output>& fetch_outputs,
+ std::vector<Tensor>* outputs) const;
+
+ // Same as above, but use the mapping in `inputs` as feeds.
+ Status Run(const FeedType& inputs,
+ const std::vector<ops::Output>& fetch_outputs,
+ std::vector<Tensor>* outputs) const;
+
+ // Same as above. Additionally runs the operations ins `run_outputs`.
+ Status Run(const FeedType& inputs,
+ const std::vector<ops::Output>& fetch_outputs,
+ const std::vector<ops::Operation>& run_outputs,
+ std::vector<Tensor>* outputs) const;
+
+ // Use `run_options` to turn on performance profiling. `run_metadata`, if not
+ // null, is filled in with the profiling results.
+ Status Run(const RunOptions& run_options, const FeedType& inputs,
+ const std::vector<ops::Output>& fetch_outputs,
+ const std::vector<ops::Operation>& run_outputs,
+ std::vector<Tensor>* outputs, RunMetadata* run_metadata) const;
+
+ // TODO(keveman): Add support for partial run.
+
+ private:
+ SessionOptions MakeDefaultSessionOptions(const string& target) const;
+ Status MaybeExtendGraph() const;
+
+ std::unique_ptr<Session> session_;
+ std::shared_ptr<Graph> graph_;
+
+ mutable mutex mu_;
+ mutable int last_num_graph_nodes_ GUARDED_BY(mu_) = 0;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ClientSession);
+};
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CC_CLIENT_CLIENT_SESSION_H_
diff --git a/tensorflow/cc/client/client_session_test.cc b/tensorflow/cc/client/client_session_test.cc
new file mode 100644
index 0000000000..9c0f00f2b1
--- /dev/null
+++ b/tensorflow/cc/client/client_session_test.cc
@@ -0,0 +1,92 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include "tensorflow/cc/client/client_session.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+using namespace ops; // NOLINT(build/namespaces)
+
+TEST(ClientSessionTest, Basic) {
+ Scope root = Scope::NewRootScope();
+ auto c = Const(root, {{1, 1}});
+ ClientSession session(root);
+ std::vector<Tensor> outputs;
+
+ TF_EXPECT_OK(session.Run({c}, &outputs));
+ test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({1, 1}, {1, 2}));
+}
+
+TEST(ClientSessionTest, Feed) {
+ Scope root = Scope::NewRootScope();
+ auto a = Placeholder(root, DT_INT32);
+ auto b = Placeholder(root, DT_INT32);
+ auto c = Add(root, a, b);
+ ClientSession session(root);
+ std::vector<Tensor> outputs;
+
+ TF_EXPECT_OK(session.Run({{a, 1}, {b, 41}}, {c}, &outputs));
+ test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({42}, {}));
+}
+
+TEST(ClientSessionTest, Extend) {
+ Scope root = Scope::NewRootScope();
+ auto a = Placeholder(root, DT_INT32);
+ auto c = Add(root, a, {2, 2});
+ ClientSession session(root);
+ std::vector<Tensor> outputs;
+
+ TF_EXPECT_OK(session.Run({{a, {1, 1}}}, {c}, &outputs));
+ test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({3, 3}, {2}));
+
+ auto d = Add(root, c, {39, 39});
+ outputs.clear();
+ TF_EXPECT_OK(session.Run({{a, {-10, 1}}}, {d}, &outputs));
+ test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({31, 42}, {2}));
+}
+
+TEST(ClientSessionTest, MultiThreaded) {
+ Scope root = Scope::NewRootScope();
+ auto a = Add(root, {1, 2}, {3, 4});
+ auto b = Mul(root, {1, 2}, {3, 4});
+ ClientSession session(root);
+ {
+ thread::ThreadPool thread_pool(Env::Default(), "pool", 2);
+ thread_pool.Schedule([&session, a]() {
+ std::vector<Tensor> outputs;
+ TF_EXPECT_OK(session.Run({a}, &outputs));
+ test::ExpectTensorEqual<int>(outputs[0],
+ test::AsTensor<int>({4, 6}, {2}));
+ });
+ thread_pool.Schedule([&session, b]() {
+ std::vector<Tensor> outputs;
+ TF_EXPECT_OK(session.Run({b}, &outputs));
+ test::ExpectTensorEqual<int>(outputs[0],
+ test::AsTensor<int>({3, 8}, {2}));
+ });
+ }
+ auto c = Sub(root, b, a);
+ std::vector<Tensor> outputs;
+ TF_EXPECT_OK(session.Run({c}, &outputs));
+ test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({-1, 2}, {2}));
+}
+
+} // end namespace tensorflow