diff options
author | Manjunath Kudlur <keveman@google.com> | 2016-08-19 10:57:15 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-08-19 12:03:49 -0700 |
commit | 767377749cb9f7f4df1f95db342c88ca3c6ee1f2 (patch) | |
tree | 288ef21efa24f7b048bcb5df58c0a61cebbdea95 /tensorflow/cc/client | |
parent | 0b9f0f53ddbf693bb30afb211a6d514a1fce1c22 (diff) |
ClientSession wraps a Session object and provides Run methods that operate on
the C++ API's Output/Operation object instead of strings.
Change: 130776638
Diffstat (limited to 'tensorflow/cc/client')
-rw-r--r-- | tensorflow/cc/client/client_session.cc | 102 | ||||
-rw-r--r-- | tensorflow/cc/client/client_session.h | 109 | ||||
-rw-r--r-- | tensorflow/cc/client/client_session_test.cc | 92 |
3 files changed, 303 insertions, 0 deletions
diff --git a/tensorflow/cc/client/client_session.cc b/tensorflow/cc/client/client_session.cc new file mode 100644 index 0000000000..5a98deb259 --- /dev/null +++ b/tensorflow/cc/client/client_session.cc @@ -0,0 +1,102 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/client/client_session.h" + +#include <unordered_map> +#include <vector> + +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { + +ClientSession::ClientSession(const Scope& scope, const string& target) + : ClientSession(scope, MakeDefaultSessionOptions(target)) {} + +ClientSession::ClientSession(const Scope& scope) : ClientSession(scope, "") {} + +ClientSession::ClientSession(const Scope& scope, + const SessionOptions& session_options) + : session_(NewSession(session_options)), + graph_(scope.graph_as_shared_ptr()) { + CHECK_NOTNULL(session_.get()); +} + +SessionOptions ClientSession::MakeDefaultSessionOptions( + const string& target) const { + SessionOptions options; + options.env = Env::Default(); + options.target = target; + return options; +} + +Status ClientSession::Run(const std::vector<ops::Output>& fetch_outputs, + std::vector<Tensor>* outputs) const { + return Run(FeedType{}, fetch_outputs, {}, outputs); +} + +Status ClientSession::Run(const FeedType& inputs, + const std::vector<ops::Output>& fetch_outputs, + std::vector<Tensor>* outputs) const { + return Run(inputs, fetch_outputs, {}, outputs); +} + +Status ClientSession::Run(const FeedType& inputs, + const std::vector<ops::Output>& fetch_outputs, + const std::vector<ops::Operation>& run_outputs, + std::vector<Tensor>* outputs) const { + return Run(RunOptions(), inputs, fetch_outputs, run_outputs, outputs, + nullptr); +} + +Status ClientSession::MaybeExtendGraph() const { + mutex_lock l(mu_); + int num_nodes = graph_->num_node_ids(); + if (num_nodes > last_num_graph_nodes_) { + GraphDef graph_def; + graph_->ToGraphDefSubRange(&graph_def, last_num_graph_nodes_); + last_num_graph_nodes_ = num_nodes; + return session_->Extend(graph_def); + } + return Status::OK(); +} + +Status ClientSession::Run(const RunOptions& run_options, const FeedType& inputs, + const std::vector<ops::Output>& fetch_outputs, + const std::vector<ops::Operation>& run_outputs, + std::vector<Tensor>* outputs, + RunMetadata* run_metadata) const { + std::vector<std::pair<string, Tensor>> feeds; + for (auto const& feed : inputs) { + TF_RETURN_IF_ERROR(feed.second.status); + feeds.emplace_back(feed.first.name(), feed.second.tensor); + } + std::vector<string> output_tensor_names; + for (auto const& output : fetch_outputs) { + output_tensor_names.push_back(output.name()); + } + std::vector<string> target_node_names; + for (auto const& output : run_outputs) { + target_node_names.push_back(output.node()->name()); + } + TF_RETURN_IF_ERROR(MaybeExtendGraph()); + return session_->Run(run_options, feeds, output_tensor_names, + target_node_names, outputs, run_metadata); +} + +} // end namespace tensorflow diff --git a/tensorflow/cc/client/client_session.h b/tensorflow/cc/client/client_session.h new file mode 100644 index 0000000000..9d480477f6 --- /dev/null +++ b/tensorflow/cc/client/client_session.h @@ -0,0 +1,109 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_CLIENT_CLIENT_SESSION_H_ +#define TENSORFLOW_CC_CLIENT_CLIENT_SESSION_H_ + +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { + +// A `ClientSession` object lets the caller drive the evaluation of the +// TensorFlow graph constructed with the C++ API. +// +// Example: +// +// Scope root = Scope::NewRootScope(); +// auto a = Placeholder(root, DT_INT32); +// auto c = Add(root, a, {41}); +// +// ClientSession session(root); +// std::vector<Tensor> outputs; +// +// Status s = session.Run({{a, {1}}}, {c}, &outputs); +// if (!s.ok()) { /* Handle error */ } +class ClientSession { + public: + // A data type to represent feeds to a Run call. + // This is a map of `Output` objects returned by op-constructors to the value + // to feed them with. See `ops::Input::Initializer` for details on what can be + // used as feed values. + typedef std::unordered_map<ops::Output, ops::Input::Initializer, + ops::OutputHash> + FeedType; + + // Create a new session to evaluate the graph contained in `scope` by + // connecting to the TensorFlow runtime specified by `target`. + ClientSession(const Scope& scope, const string& target); + + // Same as above, but use the empty string ("") as the target specification. + ClientSession(const Scope& scope); + + // Create a new session, configuring it with `session_options`. + ClientSession(const Scope& scope, const SessionOptions& session_options); + + // Evaluate the tensors in `fetch_outputs`. The values are returned as + // `Tensor` objects in `outputs`. The number and order of `outputs` will match + // `fetch_outputs`. + Status Run(const std::vector<ops::Output>& fetch_outputs, + std::vector<Tensor>* outputs) const; + + // Same as above, but use the mapping in `inputs` as feeds. + Status Run(const FeedType& inputs, + const std::vector<ops::Output>& fetch_outputs, + std::vector<Tensor>* outputs) const; + + // Same as above. Additionally runs the operations ins `run_outputs`. + Status Run(const FeedType& inputs, + const std::vector<ops::Output>& fetch_outputs, + const std::vector<ops::Operation>& run_outputs, + std::vector<Tensor>* outputs) const; + + // Use `run_options` to turn on performance profiling. `run_metadata`, if not + // null, is filled in with the profiling results. + Status Run(const RunOptions& run_options, const FeedType& inputs, + const std::vector<ops::Output>& fetch_outputs, + const std::vector<ops::Operation>& run_outputs, + std::vector<Tensor>* outputs, RunMetadata* run_metadata) const; + + // TODO(keveman): Add support for partial run. + + private: + SessionOptions MakeDefaultSessionOptions(const string& target) const; + Status MaybeExtendGraph() const; + + std::unique_ptr<Session> session_; + std::shared_ptr<Graph> graph_; + + mutable mutex mu_; + mutable int last_num_graph_nodes_ GUARDED_BY(mu_) = 0; + + TF_DISALLOW_COPY_AND_ASSIGN(ClientSession); +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_CC_CLIENT_CLIENT_SESSION_H_ diff --git a/tensorflow/cc/client/client_session_test.cc b/tensorflow/cc/client/client_session_test.cc new file mode 100644 index 0000000000..9c0f00f2b1 --- /dev/null +++ b/tensorflow/cc/client/client_session_test.cc @@ -0,0 +1,92 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <vector> + +#include "tensorflow/cc/client/client_session.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +using namespace ops; // NOLINT(build/namespaces) + +TEST(ClientSessionTest, Basic) { + Scope root = Scope::NewRootScope(); + auto c = Const(root, {{1, 1}}); + ClientSession session(root); + std::vector<Tensor> outputs; + + TF_EXPECT_OK(session.Run({c}, &outputs)); + test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({1, 1}, {1, 2})); +} + +TEST(ClientSessionTest, Feed) { + Scope root = Scope::NewRootScope(); + auto a = Placeholder(root, DT_INT32); + auto b = Placeholder(root, DT_INT32); + auto c = Add(root, a, b); + ClientSession session(root); + std::vector<Tensor> outputs; + + TF_EXPECT_OK(session.Run({{a, 1}, {b, 41}}, {c}, &outputs)); + test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({42}, {})); +} + +TEST(ClientSessionTest, Extend) { + Scope root = Scope::NewRootScope(); + auto a = Placeholder(root, DT_INT32); + auto c = Add(root, a, {2, 2}); + ClientSession session(root); + std::vector<Tensor> outputs; + + TF_EXPECT_OK(session.Run({{a, {1, 1}}}, {c}, &outputs)); + test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({3, 3}, {2})); + + auto d = Add(root, c, {39, 39}); + outputs.clear(); + TF_EXPECT_OK(session.Run({{a, {-10, 1}}}, {d}, &outputs)); + test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({31, 42}, {2})); +} + +TEST(ClientSessionTest, MultiThreaded) { + Scope root = Scope::NewRootScope(); + auto a = Add(root, {1, 2}, {3, 4}); + auto b = Mul(root, {1, 2}, {3, 4}); + ClientSession session(root); + { + thread::ThreadPool thread_pool(Env::Default(), "pool", 2); + thread_pool.Schedule([&session, a]() { + std::vector<Tensor> outputs; + TF_EXPECT_OK(session.Run({a}, &outputs)); + test::ExpectTensorEqual<int>(outputs[0], + test::AsTensor<int>({4, 6}, {2})); + }); + thread_pool.Schedule([&session, b]() { + std::vector<Tensor> outputs; + TF_EXPECT_OK(session.Run({b}, &outputs)); + test::ExpectTensorEqual<int>(outputs[0], + test::AsTensor<int>({3, 8}, {2})); + }); + } + auto c = Sub(root, b, a); + std::vector<Tensor> outputs; + TF_EXPECT_OK(session.Run({c}, &outputs)); + test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({-1, 2}, {2})); +} + +} // end namespace tensorflow |