author     Alexandre Passos <apassos@google.com>            2018-01-24 09:10:53 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-01-24 09:15:15 -0800
commit     37bfad3c33c005077630b021ca927608dd70bb3e (patch)
tree       832cc092ba6d8ebea0f16077541e854d9dc6a5c1 /tensorflow/c/eager/c_api_test.cc
parent     a346b737046574a86268bf920e4c88f19a455830 (diff)
Allow setting per-thread device copying policies in TFE.
PiperOrigin-RevId: 183093407
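For context, a minimal sketch of how the new thread-local policy setter is used, assembled only from the calls that appear in the test added below (SketchThreadLocalPolicy is an illustrative name, and the op setup and error checking are elided):

#include "tensorflow/c/eager/c_api.h"

void SketchThreadLocalPolicy() {
  TF_Status* status = TF_NewStatus();

  // Context-wide default: refuse to silently copy tensors between devices.
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_EXPLICIT);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TFE_DeleteContextOptions(opts);

  // New in this change: override the placement policy for the calling thread
  // only, so ops issued from this thread may silently copy across devices.
  TFE_ContextSetThreadLocalDevicePlacementPolicy(ctx, TFE_DEVICE_PLACEMENT_SILENT);

  // ... build and execute ops as in the test below ...

  TFE_DeleteContext(ctx, status);
  TF_DeleteStatus(status);
}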
Diffstat (limited to 'tensorflow/c/eager/c_api_test.cc')
-rw-r--r--  tensorflow/c/eager/c_api_test.cc  49
1 file changed, 49 insertions, 0 deletions
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 423a7e1ff7..18e7a64435 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -321,6 +321,55 @@ TEST(CAPI, TensorHandleSilentCopy) {
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
}
+TEST(CAPI, TensorHandleSilentCopyLocal) {
+ std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+ TF_NewStatus(), TF_DeleteStatus);
+ TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_ContextOptionsSetDevicePlacementPolicy(opts,
+ TFE_DEVICE_PLACEMENT_EXPLICIT);
+ TFE_Context* ctx = TFE_NewContext(opts, status.get());
+ TFE_ContextSetThreadLocalDevicePlacementPolicy(ctx,
+ TFE_DEVICE_PLACEMENT_SILENT);
+ TFE_DeleteContextOptions(opts);
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+ TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
+ TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+ TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ const int num_devices = TF_DeviceListCount(devices);
+
+ // Disable the test if no GPU is present.
+ if (num_devices > 1) {
+ const int device_to_use = 1;
+ const string name(TF_DeviceListName(devices, device_to_use, status.get()));
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ TFE_TensorHandle* hgpu =
+ TFE_TensorHandleCopyToDevice(hcpu, ctx, name.c_str(), status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
+ TFE_OpSetDevice(matmul, name.c_str(), status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+ TFE_TensorHandle* retvals[1];
+ int num_retvals = 1;
+ TFE_Execute(matmul, &retvals[0], &num_retvals, status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+ TFE_DeleteOp(matmul);
+ TFE_DeleteTensorHandle(retvals[0]);
+ TFE_DeleteTensorHandle(hgpu);
+ }
+
+ TF_DeleteDeviceList(devices);
+ TF_DeleteTensor(t);
+ TFE_DeleteTensorHandle(hcpu);
+ TFE_DeleteContext(ctx, status.get());
+ EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+}
+
TEST(CAPI, Execute) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();