From b3b5e68b565e48ccd37baebbfa3459bc460156ca Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 3 Aug 2018 14:19:25 -0700 Subject: Add experimental Callable API to ClientSession. PiperOrigin-RevId: 207323298 --- tensorflow/cc/client/client_session.cc | 18 ++++++++++++++++++ tensorflow/cc/client/client_session.h | 28 +++++++++++++++++++++++++++- tensorflow/cc/client/client_session_test.cc | 21 +++++++++++++++++++++ 3 files changed, 66 insertions(+), 1 deletion(-) (limited to 'tensorflow/cc') diff --git a/tensorflow/cc/client/client_session.cc b/tensorflow/cc/client/client_session.cc index ba056a8f3a..0e61089a59 100644 --- a/tensorflow/cc/client/client_session.cc +++ b/tensorflow/cc/client/client_session.cc @@ -127,4 +127,22 @@ Status ClientSession::Run(const RunOptions& run_options, const FeedType& inputs, target_node_names, outputs, run_metadata); } +Status ClientSession::MakeCallable(const CallableOptions& callable_options, + CallableHandle* out_handle) { + TF_RETURN_IF_ERROR(impl()->MaybeExtendGraph()); + return impl()->session_->MakeCallable(callable_options, out_handle); +} + +Status ClientSession::RunCallable(CallableHandle handle, + const std::vector& feed_tensors, + std::vector* fetch_tensors, + RunMetadata* run_metadata) { + return impl()->session_->RunCallable(handle, feed_tensors, fetch_tensors, + run_metadata); +} + +Status ClientSession::ReleaseCallable(CallableHandle handle) { + return impl()->session_->ReleaseCallable(handle); +} + } // end namespace tensorflow diff --git a/tensorflow/cc/client/client_session.h b/tensorflow/cc/client/client_session.h index 5fb4109f7d..7dd653eec4 100644 --- a/tensorflow/cc/client/client_session.h +++ b/tensorflow/cc/client/client_session.h @@ -87,7 +87,33 @@ class ClientSession { const std::vector& run_outputs, std::vector* outputs, RunMetadata* run_metadata) const; - // TODO(keveman): Add support for partial run. + /// \brief A handle to a subgraph, created with + /// `ClientSession::MakeCallable()`. + typedef int64 CallableHandle; + + /// \brief Creates a `handle` for invoking the subgraph defined by + /// `callable_options`. + /// NOTE: This API is still experimental and may change. + Status MakeCallable(const CallableOptions& callable_options, + CallableHandle* out_handle); + + /// \brief Invokes the subgraph named by `handle` with the given options and + /// input tensors. + /// + /// The order of tensors in `feed_tensors` must match the order of names in + /// `CallableOptions::feed()` and the order of tensors in `fetch_tensors` will + /// match the order of names in `CallableOptions::fetch()` when this subgraph + /// was created. + /// NOTE: This API is still experimental and may change. + Status RunCallable(CallableHandle handle, + const std::vector& feed_tensors, + std::vector* fetch_tensors, + RunMetadata* run_metadata); + + /// \brief Releases resources associated with the given `handle` in this + /// session. + /// NOTE: This API is still experimental and may change. + Status ReleaseCallable(CallableHandle handle); private: class Impl; diff --git a/tensorflow/cc/client/client_session_test.cc b/tensorflow/cc/client/client_session_test.cc index ea5cf5a1f1..559ffea7e8 100644 --- a/tensorflow/cc/client/client_session_test.cc +++ b/tensorflow/cc/client/client_session_test.cc @@ -95,5 +95,26 @@ TEST(ClientSessionTest, MultiThreaded) { test::ExpectTensorEqual(outputs[0], test::AsTensor({-1, 2}, {2})); } +TEST(ClientSessionTest, Callable) { + Scope root = Scope::NewRootScope(); + auto a = Placeholder(root, DT_INT32); + auto b = Placeholder(root, DT_INT32); + auto c = Add(root, a, b); + ClientSession session(root); + std::vector outputs; + + CallableOptions options; + options.add_feed(a.node()->name()); + options.add_feed(b.node()->name()); + options.add_fetch(c.node()->name()); + ClientSession::CallableHandle callable; + TF_CHECK_OK(session.MakeCallable(options, &callable)); + TF_EXPECT_OK(session.RunCallable( + callable, {test::AsTensor({1}, {}), test::AsTensor({41}, {})}, + &outputs, nullptr)); + test::ExpectTensorEqual(outputs[0], test::AsTensor({42}, {})); + TF_EXPECT_OK(session.ReleaseCallable(callable)); +} + } // namespace } // namespace tensorflow -- cgit v1.2.3