aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/distributed_runtime
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2018-06-26 18:51:41 -0700
committerGravatar Gunhan Gulsoy <gunan@google.com>2018-06-28 21:37:43 -0700
commita1d6179adb1ca6208281ed955860c319525edf75 (patch)
treed59762033c0784b638c89304f3b3216a2bb7ce20 /tensorflow/core/distributed_runtime
parent3336574287a16a0ead083a33b5e80a1c7204fa62 (diff)
[C++]: Ability to feed and fetch tensors while keeping them in device memory
when using Session::RunCallable(). PiperOrigin-RevId: 202234757
Diffstat (limited to 'tensorflow/core/distributed_runtime')
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc33
1 files changed, 33 insertions, 0 deletions
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
index 45b15a54a2..fc601991a2 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
@@ -163,6 +163,39 @@ TEST(GrpcSessionTest, BasicCallable) {
}
}
+TEST(GrpcSessionTest, CallableWithOnDeviceFeedsAndFetches) {
+ // Specifying feeds/fetch devices for remote sessions is not yet defined.
+ // Ensure that the error is graceful.
+ GraphDef graph;
+ string node_names[3];
+ // c = a * b
+ CreateGraphDef(&graph, node_names);
+
+ std::unique_ptr<test::TestCluster> cluster;
+ TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
+
+ std::unique_ptr<Session> session(
+ NewRemote(Options(cluster->targets()[0], 1)));
+ ASSERT_TRUE(session != nullptr);
+
+ TF_CHECK_OK(session->Create(graph));
+
+ std::vector<DeviceAttributes> devices;
+ TF_CHECK_OK(session->ListDevices(&devices));
+ ASSERT_GT(devices.size(), 0);
+ const string device_name = devices.back().name();
+
+ CallableOptions opts;
+ const string fetch = node_names[2] + ":0";
+ opts.add_fetch(fetch);
+ opts.mutable_fetch_devices()->insert({fetch, device_name});
+
+ Session::CallableHandle handle;
+ Status status = session->MakeCallable(opts, &handle);
+ EXPECT_EQ(error::UNIMPLEMENTED, status.code());
+ TF_CHECK_OK(session->Close());
+}
+
TEST(GrpcSessionTest, BasicNonProtoAPIConsistentOrder) {
GraphDef graph;
string node_names[3];