aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/client/client_session_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-03 14:19:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-03 14:24:12 -0700
commitb3b5e68b565e48ccd37baebbfa3459bc460156ca (patch)
treea6a5c1d2062f44c8ae8a489c23249dadca689823 /tensorflow/cc/client/client_session_test.cc
parentc9328e51b72f9f906364a523926abdc62095ffe0 (diff)
Add experimental Callable API to ClientSession.
PiperOrigin-RevId: 207323298
Diffstat (limited to 'tensorflow/cc/client/client_session_test.cc')
-rw-r--r--tensorflow/cc/client/client_session_test.cc21
1 files changed, 21 insertions, 0 deletions
diff --git a/tensorflow/cc/client/client_session_test.cc b/tensorflow/cc/client/client_session_test.cc
index ea5cf5a1f1..559ffea7e8 100644
--- a/tensorflow/cc/client/client_session_test.cc
+++ b/tensorflow/cc/client/client_session_test.cc
@@ -95,5 +95,26 @@ TEST(ClientSessionTest, MultiThreaded) {
test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({-1, 2}, {2}));
}
+TEST(ClientSessionTest, Callable) {
+ Scope root = Scope::NewRootScope();
+ auto a = Placeholder(root, DT_INT32);
+ auto b = Placeholder(root, DT_INT32);
+ auto c = Add(root, a, b);
+ ClientSession session(root);
+ std::vector<Tensor> outputs;
+
+ CallableOptions options;
+ options.add_feed(a.node()->name());
+ options.add_feed(b.node()->name());
+ options.add_fetch(c.node()->name());
+ ClientSession::CallableHandle callable;
+ TF_CHECK_OK(session.MakeCallable(options, &callable));
+ TF_EXPECT_OK(session.RunCallable(
+ callable, {test::AsTensor<int>({1}, {}), test::AsTensor<int>({41}, {})},
+ &outputs, nullptr));
+ test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({42}, {}));
+ TF_EXPECT_OK(session.ReleaseCallable(callable));
+}
+
} // namespace
} // namespace tensorflow