aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/eager
diff options
context:
space:
mode:
authorGravatar Akshay Modi <nareshmodi@google.com>2018-06-11 16:04:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-11 16:07:37 -0700
commita1244d61b1bf9db586ad12fb12b65d2db3913e45 (patch)
tree6a0eba77e77cee6d77a8267e9ef1ccc9a3bd3559 /tensorflow/c/eager
parent49ed096fb3f89855dbdccf183d10d8068324f1c2 (diff)
Allow silent copies during remote execution.
This is required to do anything useful from python. PiperOrigin-RevId: 200129777
Diffstat (limited to 'tensorflow/c/eager')
-rw-r--r--tensorflow/c/eager/c_api_test.cc81
1 files changed, 80 insertions, 1 deletions
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 27ff5f7211..992d1afd5f 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -142,8 +142,10 @@ void TestRemoteExecute(bool async) {
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(),
status);
- TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(1));
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(1));
+ TFE_ContextOptionsSetDevicePlacementPolicy(opts,
+ TFE_DEVICE_PLACEMENT_EXPLICIT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
@@ -205,6 +207,83 @@ void TestRemoteExecute(bool async) {
TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); }
TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); }
+void TestRemoteExecuteSilentCopies(bool async) {
+ tensorflow::ServerDef server_def = GetServerDef(2);
+
+ // This server def has the task index set to 0.
+ string serialized = server_def.SerializeAsString();
+
+ server_def.set_task_index(1);
+
+ std::unique_ptr<tensorflow::eager::EagerGrpcServer> worker_server;
+ ASSERT_TRUE(
+ tensorflow::eager::EagerGrpcServer::Create(server_def, &worker_server)
+ .ok());
+ ASSERT_TRUE(worker_server->Start().ok());
+
+ TF_Status* status = TF_NewStatus();
+ TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(),
+ status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(1));
+ TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
+ TFE_Context* ctx = TFE_NewContext(opts, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteContextOptions(opts);
+
+ TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
+ TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle();
+ const char remote_device_name[] =
+ "/job:localhost/replica:0/task:1/device:CPU:0";
+
+ // Handles are on task0, but op is on remote (task1).
+ TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task0);
+ TFE_OpSetDevice(matmul, remote_device_name, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TFE_TensorHandle* retvals[1];
+ int num_retvals = 1;
+ TFE_Execute(matmul, &retvals[0], &num_retvals, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ auto* retval_task0 = TFE_TensorHandleCopyToDevice(
+ retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteTensorHandle(retval_task0);
+ float product[4] = {0};
+ EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
+ memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
+ TF_DeleteTensor(t);
+ EXPECT_EQ(7, product[0]);
+ EXPECT_EQ(10, product[1]);
+ EXPECT_EQ(15, product[2]);
+ EXPECT_EQ(22, product[3]);
+
+ TFE_DeleteTensorHandle(h0_task0);
+ TFE_DeleteTensorHandle(h1_task0);
+ TFE_DeleteTensorHandle(retvals[0]);
+
+ TFE_DeleteOp(matmul);
+
+ TFE_ContextAsyncWait(ctx, status);
+ TFE_DeleteContext(ctx, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TF_DeleteStatus(status);
+
+ // TODO(nareshmodi): Figure out how to correctly shut the server down.
+ worker_server.release();
+}
+
+TEST(CAPI, RemoteExecuteSilentCopies) { TestRemoteExecuteSilentCopies(false); }
+TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
+ TestRemoteExecuteSilentCopies(true);
+}
+
TEST(CAPI, TensorHandle) {
TFE_TensorHandle* h = TestMatrixTensorHandle();
EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));