aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime
diff options
context:
space:
mode:
authorGravatar Saurabh Saxena <srbs@google.com>2018-09-27 16:16:26 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-27 16:28:11 -0700
commit5f67bf69d3f53d1cd3bb86ebeeb03ea2bba5911b (patch)
tree8ae6c4649b30874b3b2b5edd867b9f20813d3da9 /tensorflow/core/common_runtime
parentf41573b7956871b4142c97eb85ddf163ad641976 (diff)
Support nested variants in CopyHostToDevice and CopyDeviceToHost.
PiperOrigin-RevId: 214853860
Diffstat (limited to 'tensorflow/core/common_runtime')
-rw-r--r--tensorflow/core/common_runtime/copy_tensor.cc82
1 files changed, 48 insertions, 34 deletions
diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc
index d800a86199..6e2eb66b94 100644
--- a/tensorflow/core/common_runtime/copy_tensor.cc
+++ b/tensorflow/core/common_runtime/copy_tensor.cc
@@ -61,26 +61,33 @@ void CopyHostToDevice(const Tensor* input, Allocator* cpu_allocator,
status_cb->Unref();
};
auto copier = std::bind(
- [dst, recv_dev_context, out_allocator, status_cb](
- StatusCallback wrapped_done_,
- // Begin unbound arguments
- const Tensor& from, Tensor* to) {
- if (!DMAHelper::CanUseDMA(&from)) {
- Status err = errors::InvalidArgument(
- "During Variant Host->Device Copy: "
- "non-DMA-copy attempted of tensor type: ",
- DataTypeString(from.dtype()));
- status_cb->UpdateStatus(err);
- return err;
- }
- if (status_cb->ok()) {
+ [dst, recv_dev_context, out_allocator, status_cb, cpu_allocator,
+ edge_name](StatusCallback wrapped_done_,
+ // Begin unbound arguments
+ const Tensor& from, Tensor* to) {
+ if (from.dtype() == DT_VARIANT) {
status_cb->Ref();
- *to = Tensor(out_allocator, from.dtype(), from.shape());
- recv_dev_context->CopyCPUTensorToDevice(&from, dst, to,
- wrapped_done_);
+ CopyHostToDevice(&from, cpu_allocator, out_allocator, edge_name,
+ dst, to, recv_dev_context, wrapped_done_);
return Status::OK();
} else {
- return status_cb->status();
+ if (!DMAHelper::CanUseDMA(&from)) {
+ Status err = errors::InvalidArgument(
+ "During Variant Host->Device Copy: "
+ "non-DMA-copy attempted of tensor type: ",
+ DataTypeString(from.dtype()));
+ status_cb->UpdateStatus(err);
+ return err;
+ }
+ if (status_cb->ok()) {
+ status_cb->Ref();
+ *to = Tensor(out_allocator, from.dtype(), from.shape());
+ recv_dev_context->CopyCPUTensorToDevice(&from, dst, to,
+ wrapped_done_);
+ return Status::OK();
+ } else {
+ return status_cb->status();
+ }
}
},
std::move(wrapped_done), std::placeholders::_1, std::placeholders::_2);
@@ -119,26 +126,33 @@ void CopyDeviceToHost(const Tensor* input, Allocator* cpu_allocator,
status_cb->Unref();
};
auto copier = std::bind(
- [edge_name, src, send_dev_context, out_allocator, status_cb](
- StatusCallback wrapped_done_,
- // Begin unbound arguments
- const Tensor& from, Tensor* to) {
- if (!DMAHelper::CanUseDMA(&from)) {
- Status err = errors::InvalidArgument(
- "During Variant Device->Host Copy: "
- "non-DMA-copy attempted of tensor type: ",
- DataTypeString(from.dtype()));
- status_cb->UpdateStatus(err);
- return err;
- }
- if (status_cb->ok()) {
+ [edge_name, src, send_dev_context, out_allocator, status_cb,
+ cpu_allocator](StatusCallback wrapped_done_,
+ // Begin unbound arguments
+ const Tensor& from, Tensor* to) {
+ if (from.dtype() == DT_VARIANT) {
status_cb->Ref();
- *to = Tensor(out_allocator, from.dtype(), from.shape());
- send_dev_context->CopyDeviceTensorToCPU(&from, edge_name, src, to,
- wrapped_done_);
+ CopyDeviceToHost(&from, cpu_allocator, out_allocator, edge_name,
+ src, to, send_dev_context, wrapped_done_);
return Status::OK();
} else {
- return status_cb->status();
+ if (!DMAHelper::CanUseDMA(&from)) {
+ Status err = errors::InvalidArgument(
+ "During Variant Device->Host Copy: "
+ "non-DMA-copy attempted of tensor type: ",
+ DataTypeString(from.dtype()));
+ status_cb->UpdateStatus(err);
+ return err;
+ }
+ if (status_cb->ok()) {
+ status_cb->Ref();
+ *to = Tensor(out_allocator, from.dtype(), from.shape());
+ send_dev_context->CopyDeviceTensorToCPU(&from, edge_name, src, to,
+ wrapped_done_);
+ return Status::OK();
+ } else {
+ return status_cb->status();
+ }
}
},
std::move(wrapped_done), std::placeholders::_1, std::placeholders::_2);