diff options
author | 2018-01-24 09:10:53 -0800 | |
---|---|---|
committer | 2018-01-24 09:15:15 -0800 | |
commit | 37bfad3c33c005077630b021ca927608dd70bb3e (patch) | |
tree | 832cc092ba6d8ebea0f16077541e854d9dc6a5c1 /tensorflow/c/eager/c_api_test.cc | |
parent | a346b737046574a86268bf920e4c88f19a455830 (diff) |
Allow setting per-thread device copying policies in TFE.
PiperOrigin-RevId: 183093407
Diffstat (limited to 'tensorflow/c/eager/c_api_test.cc')
-rw-r--r-- | tensorflow/c/eager/c_api_test.cc | 49 |
1 files changed, 49 insertions, 0 deletions
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 423a7e1ff7..18e7a64435 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -321,6 +321,55 @@ TEST(CAPI, TensorHandleSilentCopy) { EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); } +TEST(CAPI, TensorHandleSilentCopyLocal) { + std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( + TF_NewStatus(), TF_DeleteStatus); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, + TFE_DEVICE_PLACEMENT_EXPLICIT); + TFE_Context* ctx = TFE_NewContext(opts, status.get()); + TFE_ContextSetThreadLocalDevicePlacementPolicy(ctx, + TFE_DEVICE_PLACEMENT_SILENT); + TFE_DeleteContextOptions(opts); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + TFE_TensorHandle* hcpu = TestMatrixTensorHandle(); + TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + const int num_devices = TF_DeviceListCount(devices); + + // Disable the test if no GPU is present. + if (num_devices > 1) { + const int device_to_use = 1; + const string name(TF_DeviceListName(devices, device_to_use, status.get())); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + TFE_TensorHandle* hgpu = + TFE_TensorHandleCopyToDevice(hcpu, ctx, name.c_str(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu); + TFE_OpSetDevice(matmul, name.c_str(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(matmul, &retvals[0], &num_retvals, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_DeleteOp(matmul); + TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteTensorHandle(hgpu); + } + + TF_DeleteDeviceList(devices); + TF_DeleteTensor(t); + TFE_DeleteTensorHandle(hcpu); + TFE_DeleteContext(ctx, status.get()); + EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); +} + TEST(CAPI, Execute) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); |