diff options
author | Asim Shankar <ashankar@google.com> | 2018-06-26 18:51:41 -0700 |
---|---|---|
committer | Gunhan Gulsoy <gunan@google.com> | 2018-06-28 21:37:43 -0700 |
commit | a1d6179adb1ca6208281ed955860c319525edf75 (patch) | |
tree | d59762033c0784b638c89304f3b3216a2bb7ce20 /tensorflow/core/distributed_runtime | |
parent | 3336574287a16a0ead083a33b5e80a1c7204fa62 (diff) |
[C++]: Ability to feed and fetch tensors while keeping them in device memory
when using Session::RunCallable().
PiperOrigin-RevId: 202234757
Diffstat (limited to 'tensorflow/core/distributed_runtime')
-rw-r--r-- | tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc | 33 |
1 files changed, 33 insertions, 0 deletions
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc index 45b15a54a2..fc601991a2 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc @@ -163,6 +163,39 @@ TEST(GrpcSessionTest, BasicCallable) { } } +TEST(GrpcSessionTest, CallableWithOnDeviceFeedsAndFetches) { + // Specifying feeds/fetch devices for remote sessions is not yet defined. + // Ensure that the error is graceful. + GraphDef graph; + string node_names[3]; + // c = a * b + CreateGraphDef(&graph, node_names); + + std::unique_ptr<test::TestCluster> cluster; + TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster)); + + std::unique_ptr<Session> session( + NewRemote(Options(cluster->targets()[0], 1))); + ASSERT_TRUE(session != nullptr); + + TF_CHECK_OK(session->Create(graph)); + + std::vector<DeviceAttributes> devices; + TF_CHECK_OK(session->ListDevices(&devices)); + ASSERT_GT(devices.size(), 0); + const string device_name = devices.back().name(); + + CallableOptions opts; + const string fetch = node_names[2] + ":0"; + opts.add_fetch(fetch); + opts.mutable_fetch_devices()->insert({fetch, device_name}); + + Session::CallableHandle handle; + Status status = session->MakeCallable(opts, &handle); + EXPECT_EQ(error::UNIMPLEMENTED, status.code()); + TF_CHECK_OK(session->Close()); +} + TEST(GrpcSessionTest, BasicNonProtoAPIConsistentOrder) { GraphDef graph; string node_names[3]; |