diff options
Diffstat (limited to 'tensorflow/cc/client/client_session_test.cc')
-rw-r--r-- | tensorflow/cc/client/client_session_test.cc | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/tensorflow/cc/client/client_session_test.cc b/tensorflow/cc/client/client_session_test.cc index ea5cf5a1f1..559ffea7e8 100644 --- a/tensorflow/cc/client/client_session_test.cc +++ b/tensorflow/cc/client/client_session_test.cc @@ -95,5 +95,26 @@ TEST(ClientSessionTest, MultiThreaded) { test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({-1, 2}, {2})); } +TEST(ClientSessionTest, Callable) { + Scope root = Scope::NewRootScope(); + auto a = Placeholder(root, DT_INT32); + auto b = Placeholder(root, DT_INT32); + auto c = Add(root, a, b); + ClientSession session(root); + std::vector<Tensor> outputs; + + CallableOptions options; + options.add_feed(a.node()->name()); + options.add_feed(b.node()->name()); + options.add_fetch(c.node()->name()); + ClientSession::CallableHandle callable; + TF_CHECK_OK(session.MakeCallable(options, &callable)); + TF_EXPECT_OK(session.RunCallable( + callable, {test::AsTensor<int>({1}, {}), test::AsTensor<int>({41}, {})}, + &outputs, nullptr)); + test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({42}, {})); + TF_EXPECT_OK(session.ReleaseCallable(callable)); +} + } // namespace } // namespace tensorflow |