aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/distributed_runtime
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2018-06-26 18:51:41 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-26 18:55:08 -0700
commit373921627bfe9cf4b21bbe0fbd879888797217d9 (patch)
treeb3be78c0b17e9226932d5b9136fe689237bed672 /tensorflow/core/distributed_runtime
parent41731b13598c50a31432e769f4cb9d9fc355cf7a (diff)
[C++]: Ability to feed and fetch tensors while keeping them in device memory
when using Session::RunCallable(). PiperOrigin-RevId: 202234757
Diffstat (limited to 'tensorflow/core/distributed_runtime')
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc33
1 files changed, 33 insertions, 0 deletions
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
index 45b15a54a2..fc601991a2 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
@@ -163,6 +163,39 @@ TEST(GrpcSessionTest, BasicCallable) {
}
}
+TEST(GrpcSessionTest, CallableWithOnDeviceFeedsAndFetches) {
+ // Specifying feeds/fetch devices for remote sessions is not yet defined.
+ // Ensure that the error is graceful.
+ GraphDef graph;
+ string node_names[3];
+ // c = a * b
+ CreateGraphDef(&graph, node_names);
+
+ std::unique_ptr<test::TestCluster> cluster;
+ TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
+
+ std::unique_ptr<Session> session(
+ NewRemote(Options(cluster->targets()[0], 1)));
+ ASSERT_TRUE(session != nullptr);
+
+ TF_CHECK_OK(session->Create(graph));
+
+ std::vector<DeviceAttributes> devices;
+ TF_CHECK_OK(session->ListDevices(&devices));
+ ASSERT_GT(devices.size(), 0);
+ const string device_name = devices.back().name();
+
+ CallableOptions opts;
+ const string fetch = node_names[2] + ":0";
+ opts.add_fetch(fetch);
+ opts.mutable_fetch_devices()->insert({fetch, device_name});
+
+ Session::CallableHandle handle;
+ Status status = session->MakeCallable(opts, &handle);
+ EXPECT_EQ(error::UNIMPLEMENTED, status.code());
+ TF_CHECK_OK(session->Close());
+}
+
TEST(GrpcSessionTest, BasicNonProtoAPIConsistentOrder) {
GraphDef graph;
string node_names[3];