diff options
author | 2018-09-27 16:16:26 -0700 | |
---|---|---|
committer | 2018-09-27 16:28:11 -0700 | |
commit | 5f67bf69d3f53d1cd3bb86ebeeb03ea2bba5911b (patch) | |
tree | 8ae6c4649b30874b3b2b5edd867b9f20813d3da9 /tensorflow/core/common_runtime | |
parent | f41573b7956871b4142c97eb85ddf163ad641976 (diff) |
Support nested variants in CopyHostToDevice and CopyDeviceToHost.
PiperOrigin-RevId: 214853860
Diffstat (limited to 'tensorflow/core/common_runtime')
-rw-r--r-- | tensorflow/core/common_runtime/copy_tensor.cc | 82 |
1 files changed, 48 insertions, 34 deletions
diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc index d800a86199..6e2eb66b94 100644 --- a/tensorflow/core/common_runtime/copy_tensor.cc +++ b/tensorflow/core/common_runtime/copy_tensor.cc @@ -61,26 +61,33 @@ void CopyHostToDevice(const Tensor* input, Allocator* cpu_allocator, status_cb->Unref(); }; auto copier = std::bind( - [dst, recv_dev_context, out_allocator, status_cb]( - StatusCallback wrapped_done_, - // Begin unbound arguments - const Tensor& from, Tensor* to) { - if (!DMAHelper::CanUseDMA(&from)) { - Status err = errors::InvalidArgument( - "During Variant Host->Device Copy: " - "non-DMA-copy attempted of tensor type: ", - DataTypeString(from.dtype())); - status_cb->UpdateStatus(err); - return err; - } - if (status_cb->ok()) { + [dst, recv_dev_context, out_allocator, status_cb, cpu_allocator, + edge_name](StatusCallback wrapped_done_, + // Begin unbound arguments + const Tensor& from, Tensor* to) { + if (from.dtype() == DT_VARIANT) { status_cb->Ref(); - *to = Tensor(out_allocator, from.dtype(), from.shape()); - recv_dev_context->CopyCPUTensorToDevice(&from, dst, to, - wrapped_done_); + CopyHostToDevice(&from, cpu_allocator, out_allocator, edge_name, + dst, to, recv_dev_context, wrapped_done_); return Status::OK(); } else { - return status_cb->status(); + if (!DMAHelper::CanUseDMA(&from)) { + Status err = errors::InvalidArgument( + "During Variant Host->Device Copy: " + "non-DMA-copy attempted of tensor type: ", + DataTypeString(from.dtype())); + status_cb->UpdateStatus(err); + return err; + } + if (status_cb->ok()) { + status_cb->Ref(); + *to = Tensor(out_allocator, from.dtype(), from.shape()); + recv_dev_context->CopyCPUTensorToDevice(&from, dst, to, + wrapped_done_); + return Status::OK(); + } else { + return status_cb->status(); + } } }, std::move(wrapped_done), std::placeholders::_1, std::placeholders::_2); @@ -119,26 +126,33 @@ void CopyDeviceToHost(const Tensor* input, Allocator* cpu_allocator, status_cb->Unref(); }; auto copier = std::bind( - [edge_name, src, send_dev_context, out_allocator, status_cb]( - StatusCallback wrapped_done_, - // Begin unbound arguments - const Tensor& from, Tensor* to) { - if (!DMAHelper::CanUseDMA(&from)) { - Status err = errors::InvalidArgument( - "During Variant Device->Host Copy: " - "non-DMA-copy attempted of tensor type: ", - DataTypeString(from.dtype())); - status_cb->UpdateStatus(err); - return err; - } - if (status_cb->ok()) { + [edge_name, src, send_dev_context, out_allocator, status_cb, + cpu_allocator](StatusCallback wrapped_done_, + // Begin unbound arguments + const Tensor& from, Tensor* to) { + if (from.dtype() == DT_VARIANT) { status_cb->Ref(); - *to = Tensor(out_allocator, from.dtype(), from.shape()); - send_dev_context->CopyDeviceTensorToCPU(&from, edge_name, src, to, - wrapped_done_); + CopyDeviceToHost(&from, cpu_allocator, out_allocator, edge_name, + src, to, send_dev_context, wrapped_done_); return Status::OK(); } else { - return status_cb->status(); + if (!DMAHelper::CanUseDMA(&from)) { + Status err = errors::InvalidArgument( + "During Variant Device->Host Copy: " + "non-DMA-copy attempted of tensor type: ", + DataTypeString(from.dtype())); + status_cb->UpdateStatus(err); + return err; + } + if (status_cb->ok()) { + status_cb->Ref(); + *to = Tensor(out_allocator, from.dtype(), from.shape()); + send_dev_context->CopyDeviceTensorToCPU(&from, edge_name, src, to, + wrapped_done_); + return Status::OK(); + } else { + return status_cb->status(); + } } }, std::move(wrapped_done), std::placeholders::_1, std::placeholders::_2); |