diff options
field | value | |
---|---|---|
author | 2018-09-27 16:16:26 -0700 | |
committer | 2018-09-27 16:28:11 -0700 | |
commit | 5f67bf69d3f53d1cd3bb86ebeeb03ea2bba5911b (patch) | |
tree | 8ae6c4649b30874b3b2b5edd867b9f20813d3da9 | |
parent | f41573b7956871b4142c97eb85ddf163ad641976 (diff) | |
Support nested variants in CopyHostToDevice and CopyDeviceToHost.
PiperOrigin-RevId: 214853860
-rw-r--r-- | tensorflow/core/common_runtime/copy_tensor.cc | 82 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/BUILD | 4 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/list_ops_test.py | 26 |
3 files changed, 75 insertions, 37 deletions
diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc index d800a86199..6e2eb66b94 100644 --- a/tensorflow/core/common_runtime/copy_tensor.cc +++ b/tensorflow/core/common_runtime/copy_tensor.cc @@ -61,26 +61,33 @@ void CopyHostToDevice(const Tensor* input, Allocator* cpu_allocator, status_cb->Unref(); }; auto copier = std::bind( - [dst, recv_dev_context, out_allocator, status_cb]( - StatusCallback wrapped_done_, - // Begin unbound arguments - const Tensor& from, Tensor* to) { - if (!DMAHelper::CanUseDMA(&from)) { - Status err = errors::InvalidArgument( - "During Variant Host->Device Copy: " - "non-DMA-copy attempted of tensor type: ", - DataTypeString(from.dtype())); - status_cb->UpdateStatus(err); - return err; - } - if (status_cb->ok()) { + [dst, recv_dev_context, out_allocator, status_cb, cpu_allocator, + edge_name](StatusCallback wrapped_done_, + // Begin unbound arguments + const Tensor& from, Tensor* to) { + if (from.dtype() == DT_VARIANT) { status_cb->Ref(); - *to = Tensor(out_allocator, from.dtype(), from.shape()); - recv_dev_context->CopyCPUTensorToDevice(&from, dst, to, - wrapped_done_); + CopyHostToDevice(&from, cpu_allocator, out_allocator, edge_name, + dst, to, recv_dev_context, wrapped_done_); return Status::OK(); } else { - return status_cb->status(); + if (!DMAHelper::CanUseDMA(&from)) { + Status err = errors::InvalidArgument( + "During Variant Host->Device Copy: " + "non-DMA-copy attempted of tensor type: ", + DataTypeString(from.dtype())); + status_cb->UpdateStatus(err); + return err; + } + if (status_cb->ok()) { + status_cb->Ref(); + *to = Tensor(out_allocator, from.dtype(), from.shape()); + recv_dev_context->CopyCPUTensorToDevice(&from, dst, to, + wrapped_done_); + return Status::OK(); + } else { + return status_cb->status(); + } } }, std::move(wrapped_done), std::placeholders::_1, std::placeholders::_2); @@ -119,26 +126,33 @@ void CopyDeviceToHost(const Tensor* input, Allocator* 
cpu_allocator, status_cb->Unref(); }; auto copier = std::bind( - [edge_name, src, send_dev_context, out_allocator, status_cb]( - StatusCallback wrapped_done_, - // Begin unbound arguments - const Tensor& from, Tensor* to) { - if (!DMAHelper::CanUseDMA(&from)) { - Status err = errors::InvalidArgument( - "During Variant Device->Host Copy: " - "non-DMA-copy attempted of tensor type: ", - DataTypeString(from.dtype())); - status_cb->UpdateStatus(err); - return err; - } - if (status_cb->ok()) { + [edge_name, src, send_dev_context, out_allocator, status_cb, + cpu_allocator](StatusCallback wrapped_done_, + // Begin unbound arguments + const Tensor& from, Tensor* to) { + if (from.dtype() == DT_VARIANT) { status_cb->Ref(); - *to = Tensor(out_allocator, from.dtype(), from.shape()); - send_dev_context->CopyDeviceTensorToCPU(&from, edge_name, src, to, - wrapped_done_); + CopyDeviceToHost(&from, cpu_allocator, out_allocator, edge_name, + src, to, send_dev_context, wrapped_done_); return Status::OK(); } else { - return status_cb->status(); + if (!DMAHelper::CanUseDMA(&from)) { + Status err = errors::InvalidArgument( + "During Variant Device->Host Copy: " + "non-DMA-copy attempted of tensor type: ", + DataTypeString(from.dtype())); + status_cb->UpdateStatus(err); + return err; + } + if (status_cb->ok()) { + status_cb->Ref(); + *to = Tensor(out_allocator, from.dtype(), from.shape()); + send_dev_context->CopyDeviceTensorToCPU(&from, edge_name, src, to, + wrapped_done_); + return Status::OK(); + } else { + return status_cb->status(); + } } }, std::move(wrapped_done), std::placeholders::_1, std::placeholders::_2); diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index c2e36e5e19..280c18ec00 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -3257,8 +3257,7 @@ tf_py_test( tags = ["no_gpu"], # TODO(b/111656070) ) -# TODO(b/116053459): Replace with cuda_py_test. 
-tf_py_test( +cuda_py_test( name = "while_v2_test", size = "medium", srcs = ["while_v2_test.py"], @@ -3278,5 +3277,4 @@ tf_py_test( "//tensorflow/python:while_v2", ], grpc_enabled = True, - tags = ["no_gpu"], # TODO(b/116053459) ) diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py index 0f5607712b..ae413edaec 100644 --- a/tensorflow/python/kernel_tests/list_ops_test.py +++ b/tensorflow/python/kernel_tests/list_ops_test.py @@ -170,6 +170,32 @@ class ListOpsTest(test_util.TensorFlowTestCase): list_ops.tensor_list_pop_back( l_cpu, element_dtype=dtypes.float32)[1]), 2.0) + @test_util.run_in_graph_and_eager_modes + def testCPUGPUCopyNested(self): + if not context.num_gpus(): + return + t = constant_op.constant([1.0, 2.0]) + child_l = list_ops.tensor_list_from_tensor(t, element_shape=scalar_shape()) + l = list_ops.empty_tensor_list( + element_shape=constant_op.constant([], dtype=dtypes.int32), + element_dtype=dtypes.variant) + l = list_ops.tensor_list_push_back(l, child_l) + with context.device("gpu:0"): + l_gpu = array_ops.identity(l) + _, child_l_gpu = list_ops.tensor_list_pop_back( + l_gpu, element_dtype=dtypes.variant) + self.assertAllEqual( + self.evaluate( + list_ops.tensor_list_pop_back( + child_l_gpu, element_dtype=dtypes.float32)[1]), 2.0) + l_cpu = array_ops.identity(l_gpu) + _, child_l_cpu = list_ops.tensor_list_pop_back( + l_cpu, element_dtype=dtypes.variant) + self.assertAllEqual( + self.evaluate( + list_ops.tensor_list_pop_back( + child_l_cpu, element_dtype=dtypes.float32)[1]), 2.0) + def testGraphStack(self): with self.cached_session(): tl = list_ops.empty_tensor_list( |