author    Saurabh Saxena <srbs@google.com>    2018-09-27 16:16:26 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-09-27 16:28:11 -0700
commit    5f67bf69d3f53d1cd3bb86ebeeb03ea2bba5911b (patch)
tree      8ae6c4649b30874b3b2b5edd867b9f20813d3da9 /tensorflow
parent    f41573b7956871b4142c97eb85ddf163ad641976 (diff)
Support nested variants in CopyHostToDevice and CopyDeviceToHost.
PiperOrigin-RevId: 214853860
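Before this change, the copier lambdas in CopyHostToDevice and CopyDeviceToHost assumed every element decoded from a variant tensor was DMA-copyable, so a variant nested inside another variant (for example, a TensorList stored as an element of another TensorList) failed with an InvalidArgument error. The fix checks each decoded element's dtype and, when it is DT_VARIANT, recurses into the same copy routine instead of attempting a flat DMA copy; a sketch of this pattern follows the first diff below.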
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/core/common_runtime/copy_tensor.cc    | 82
-rw-r--r--  tensorflow/python/kernel_tests/BUILD             |  4
-rw-r--r--  tensorflow/python/kernel_tests/list_ops_test.py  | 26
3 files changed, 75 insertions(+), 37 deletions(-)
diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc
index d800a86199..6e2eb66b94 100644
--- a/tensorflow/core/common_runtime/copy_tensor.cc
+++ b/tensorflow/core/common_runtime/copy_tensor.cc
@@ -61,26 +61,33 @@ void CopyHostToDevice(const Tensor* input, Allocator* cpu_allocator,
status_cb->Unref();
};
auto copier = std::bind(
- [dst, recv_dev_context, out_allocator, status_cb](
- StatusCallback wrapped_done_,
- // Begin unbound arguments
- const Tensor& from, Tensor* to) {
- if (!DMAHelper::CanUseDMA(&from)) {
- Status err = errors::InvalidArgument(
- "During Variant Host->Device Copy: "
- "non-DMA-copy attempted of tensor type: ",
- DataTypeString(from.dtype()));
- status_cb->UpdateStatus(err);
- return err;
- }
- if (status_cb->ok()) {
+ [dst, recv_dev_context, out_allocator, status_cb, cpu_allocator,
+ edge_name](StatusCallback wrapped_done_,
+ // Begin unbound arguments
+ const Tensor& from, Tensor* to) {
+ if (from.dtype() == DT_VARIANT) {
status_cb->Ref();
- *to = Tensor(out_allocator, from.dtype(), from.shape());
- recv_dev_context->CopyCPUTensorToDevice(&from, dst, to,
- wrapped_done_);
+ CopyHostToDevice(&from, cpu_allocator, out_allocator, edge_name,
+ dst, to, recv_dev_context, wrapped_done_);
return Status::OK();
} else {
- return status_cb->status();
+ if (!DMAHelper::CanUseDMA(&from)) {
+ Status err = errors::InvalidArgument(
+ "During Variant Host->Device Copy: "
+ "non-DMA-copy attempted of tensor type: ",
+ DataTypeString(from.dtype()));
+ status_cb->UpdateStatus(err);
+ return err;
+ }
+ if (status_cb->ok()) {
+ status_cb->Ref();
+ *to = Tensor(out_allocator, from.dtype(), from.shape());
+ recv_dev_context->CopyCPUTensorToDevice(&from, dst, to,
+ wrapped_done_);
+ return Status::OK();
+ } else {
+ return status_cb->status();
+ }
}
},
std::move(wrapped_done), std::placeholders::_1, std::placeholders::_2);
@@ -119,26 +126,33 @@ void CopyDeviceToHost(const Tensor* input, Allocator* cpu_allocator,
status_cb->Unref();
};
auto copier = std::bind(
- [edge_name, src, send_dev_context, out_allocator, status_cb](
- StatusCallback wrapped_done_,
- // Begin unbound arguments
- const Tensor& from, Tensor* to) {
- if (!DMAHelper::CanUseDMA(&from)) {
- Status err = errors::InvalidArgument(
- "During Variant Device->Host Copy: "
- "non-DMA-copy attempted of tensor type: ",
- DataTypeString(from.dtype()));
- status_cb->UpdateStatus(err);
- return err;
- }
- if (status_cb->ok()) {
+ [edge_name, src, send_dev_context, out_allocator, status_cb,
+ cpu_allocator](StatusCallback wrapped_done_,
+ // Begin unbound arguments
+ const Tensor& from, Tensor* to) {
+ if (from.dtype() == DT_VARIANT) {
status_cb->Ref();
- *to = Tensor(out_allocator, from.dtype(), from.shape());
- send_dev_context->CopyDeviceTensorToCPU(&from, edge_name, src, to,
- wrapped_done_);
+ CopyDeviceToHost(&from, cpu_allocator, out_allocator, edge_name,
+ src, to, send_dev_context, wrapped_done_);
return Status::OK();
} else {
- return status_cb->status();
+ if (!DMAHelper::CanUseDMA(&from)) {
+ Status err = errors::InvalidArgument(
+ "During Variant Device->Host Copy: "
+ "non-DMA-copy attempted of tensor type: ",
+ DataTypeString(from.dtype()));
+ status_cb->UpdateStatus(err);
+ return err;
+ }
+ if (status_cb->ok()) {
+ status_cb->Ref();
+ *to = Tensor(out_allocator, from.dtype(), from.shape());
+ send_dev_context->CopyDeviceTensorToCPU(&from, edge_name, src, to,
+ wrapped_done_);
+ return Status::OK();
+ } else {
+ return status_cb->status();
+ }
}
},
std::move(wrapped_done), std::placeholders::_1, std::placeholders::_2);
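Both hunks apply the same recursive dispatch. Below is a minimal, self-contained C++ sketch of that pattern, not TensorFlow code: Tensor, DType, and CopyToDevice here are simplified stand-ins for the real classes, and the real implementation additionally threads allocators, device contexts, and a reference-counted status callback through each call.

#include <cstddef>
#include <vector>

// Simplified stand-ins for the real TensorFlow types (hypothetical).
enum class DType { kFloat, kVariant };

struct Tensor {
  DType dtype = DType::kFloat;
  float value = 0.0f;            // payload of a dense tensor
  std::vector<Tensor> elements;  // payload when dtype == kVariant
};

// Recursive copy: mirrors the DT_VARIANT branch added in copy_tensor.cc.
void CopyToDevice(const Tensor& from, Tensor* to) {
  if (from.dtype == DType::kVariant) {
    // A variant can wrap further tensors (e.g. a nested TensorList),
    // so recurse per element instead of attempting one flat DMA copy.
    to->dtype = DType::kVariant;
    to->elements.resize(from.elements.size());
    for (std::size_t i = 0; i < from.elements.size(); ++i) {
      CopyToDevice(from.elements[i], &to->elements[i]);
    }
    return;
  }
  // Base case: a dense, DMA-copyable tensor is copied directly.
  *to = from;
}

In the actual change, the base case is unchanged from the old code (a CanUseDMA check followed by CopyCPUTensorToDevice or CopyDeviceTensorToCPU), and the recursion bottoms out when the decoded elements are ordinary dense tensors.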
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index c2e36e5e19..280c18ec00 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -3257,8 +3257,7 @@ tf_py_test(
tags = ["no_gpu"], # TODO(b/111656070)
)

-# TODO(b/116053459): Replace with cuda_py_test.
-tf_py_test(
+cuda_py_test(
name = "while_v2_test",
size = "medium",
srcs = ["while_v2_test.py"],
@@ -3278,5 +3277,4 @@ tf_py_test(
"//tensorflow/python:while_v2",
],
grpc_enabled = True,
- tags = ["no_gpu"], # TODO(b/116053459)
)
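The BUILD change converts while_v2_test from tf_py_test to cuda_py_test and drops its no_gpu tag, so the while_v2 test suite now also runs against a GPU, addressing the TODO referencing b/116053459.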
diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py
index 0f5607712b..ae413edaec 100644
--- a/tensorflow/python/kernel_tests/list_ops_test.py
+++ b/tensorflow/python/kernel_tests/list_ops_test.py
@@ -170,6 +170,32 @@ class ListOpsTest(test_util.TensorFlowTestCase):
list_ops.tensor_list_pop_back(
l_cpu, element_dtype=dtypes.float32)[1]), 2.0)

+ @test_util.run_in_graph_and_eager_modes
+ def testCPUGPUCopyNested(self):
+ if not context.num_gpus():
+ return
+ t = constant_op.constant([1.0, 2.0])
+ child_l = list_ops.tensor_list_from_tensor(t, element_shape=scalar_shape())
+ l = list_ops.empty_tensor_list(
+ element_shape=constant_op.constant([], dtype=dtypes.int32),
+ element_dtype=dtypes.variant)
+ l = list_ops.tensor_list_push_back(l, child_l)
+ with context.device("gpu:0"):
+ l_gpu = array_ops.identity(l)
+ _, child_l_gpu = list_ops.tensor_list_pop_back(
+ l_gpu, element_dtype=dtypes.variant)
+ self.assertAllEqual(
+ self.evaluate(
+ list_ops.tensor_list_pop_back(
+ child_l_gpu, element_dtype=dtypes.float32)[1]), 2.0)
+ l_cpu = array_ops.identity(l_gpu)
+ _, child_l_cpu = list_ops.tensor_list_pop_back(
+ l_cpu, element_dtype=dtypes.variant)
+ self.assertAllEqual(
+ self.evaluate(
+ list_ops.tensor_list_pop_back(
+ child_l_cpu, element_dtype=dtypes.float32)[1]), 2.0)
+
def testGraphStack(self):
with self.cached_session():
tl = list_ops.empty_tensor_list(
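The new testCPUGPUCopyNested test exercises the code path above end to end: it builds a TensorList whose single element is itself a TensorList (a variant nested in a variant), copies it CPU to GPU and back with array_ops.identity, and pops elements at each step to verify the nested values survive both transfers. The test returns early when no GPU is available.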