diff options
author | 2018-08-03 14:19:25 -0700 | |
---|---|---|
committer | 2018-08-03 14:24:12 -0700 | |
commit | b3b5e68b565e48ccd37baebbfa3459bc460156ca (patch) | |
tree | a6a5c1d2062f44c8ae8a489c23249dadca689823 /tensorflow/cc/client/client_session_test.cc | |
parent | c9328e51b72f9f906364a523926abdc62095ffe0 (diff) |
Add experimental Callable API to ClientSession.
PiperOrigin-RevId: 207323298
Diffstat (limited to 'tensorflow/cc/client/client_session_test.cc')
-rw-r--r-- | tensorflow/cc/client/client_session_test.cc | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/tensorflow/cc/client/client_session_test.cc b/tensorflow/cc/client/client_session_test.cc index ea5cf5a1f1..559ffea7e8 100644 --- a/tensorflow/cc/client/client_session_test.cc +++ b/tensorflow/cc/client/client_session_test.cc @@ -95,5 +95,26 @@ TEST(ClientSessionTest, MultiThreaded) { test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({-1, 2}, {2})); } +TEST(ClientSessionTest, Callable) { + Scope root = Scope::NewRootScope(); + auto a = Placeholder(root, DT_INT32); + auto b = Placeholder(root, DT_INT32); + auto c = Add(root, a, b); + ClientSession session(root); + std::vector<Tensor> outputs; + + CallableOptions options; + options.add_feed(a.node()->name()); + options.add_feed(b.node()->name()); + options.add_fetch(c.node()->name()); + ClientSession::CallableHandle callable; + TF_CHECK_OK(session.MakeCallable(options, &callable)); + TF_EXPECT_OK(session.RunCallable( + callable, {test::AsTensor<int>({1}, {}), test::AsTensor<int>({41}, {})}, + &outputs, nullptr)); + test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({42}, {})); + TF_EXPECT_OK(session.ReleaseCallable(callable)); +} + } // namespace } // namespace tensorflow |