author    Benoit Steiner <bsteiner@google.com>    2018-02-26 11:57:30 -0800
committer Gunhan Gulsoy <gunan@google.com>    2018-02-27 14:33:33 -0800
commit    0898ee302cb20d9fce50dae4f484816a2dc2d0e2 (patch)
tree      4ef39daed6c3b01f120a5e98d116a0f479b6a70d /tensorflow/core/grappler/optimizers/gpu_swapping_kernels.cc
parent    5caeb37e5d4314b702cf660db35b93a3bfc29819 (diff)
Use optimized ops to handle GPU memory swapping: this avoids the need for two pairs of extra _send/_recv nodes, which speeds things up a bit. It also ensures that performance doesn't depend on the recv scheduling built into TF, which isn't always optimal.

PiperOrigin-RevId: 187057831
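
For reference, a minimal sketch of what the op definitions backing these kernels could look like. The actual registrations for _CopyFromGpuToHost and _CopyFromHostToGpu are not part of this diff, so the attr name "T" and the choice of shape function below are assumptions, not the patch's own code:

// Hedged sketch of op definitions matching the kernels in this patch.
// The attr and shape-function details are assumptions.
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"

namespace tensorflow {

REGISTER_OP("_CopyFromGpuToHost")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("_CopyFromHostToGpu")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape);

}  // namespace tensorflow

Both ops simply forward a tensor of any dtype, so an unchanged-shape inference function is the natural fit if the registrations follow this pattern.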
Diffstat (limited to 'tensorflow/core/grappler/optimizers/gpu_swapping_kernels.cc')
-rw-r--r--    tensorflow/core/grappler/optimizers/gpu_swapping_kernels.cc    88
1 file changed, 88 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/gpu_swapping_kernels.cc b/tensorflow/core/grappler/optimizers/gpu_swapping_kernels.cc
new file mode 100644
index 0000000000..1820af6844
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/gpu_swapping_kernels.cc
@@ -0,0 +1,88 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Op kernels used to swap data in and out of GPU memory.
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace {
+
+class CopyFromGpuToHostKernel : public AsyncOpKernel {
+ public:
+ explicit CopyFromGpuToHostKernel(OpKernelConstruction* context)
+ : AsyncOpKernel(context) {}
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+ const Tensor& input = ctx->input(0);
+ OP_REQUIRES_ASYNC(
+ ctx, !ctx->input_alloc_attr(0).on_host(),
+ errors::Internal("The input tensor to the _CopyFromGpuToHost kernel "
+ "must reside on the device."),
+ done);
+
+ AllocatorAttributes alloc_attrs;
+ alloc_attrs.set_gpu_compatible(true);
+ alloc_attrs.set_on_host(true);
+ Tensor* output;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, ctx->allocate_output(0, input.shape(), &output, alloc_attrs),
+ done);
+
+ ctx->op_device_context()->CopyDeviceTensorToCPU(
+ &input, "CopyFromGpuToHost", static_cast<Device*>(ctx->device()),
+ output, [ctx, done](const Status& s) {
+ ctx->SetStatus(s);
+ done();
+ });
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("_CopyFromGpuToHost").Device(DEVICE_GPU).HostMemory("output"),
+ CopyFromGpuToHostKernel);
+
+class CopyFromHostToGpuKernel : public AsyncOpKernel {
+ public:
+ explicit CopyFromHostToGpuKernel(OpKernelConstruction* context)
+ : AsyncOpKernel(context) {}
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+ const Tensor& input = ctx->input(0);
+ OP_REQUIRES_ASYNC(
+ ctx, ctx->input_alloc_attr(0).on_host(),
+ errors::Internal("The input tensor to the _CopyFromHostToGpu kernel "
+ "must reside on the host."),
+ done);
+
+ Tensor* output;
+ OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(0, input.shape(), &output),
+ done);
+
+ ctx->op_device_context()->CopyCPUTensorToDevice(
+ &input, static_cast<Device*>(ctx->device()), output,
+ [ctx, done](const Status& s) {
+ ctx->SetStatus(s);
+ done();
+ });
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("_CopyFromHostToGpu").Device(DEVICE_GPU).HostMemory("input"),
+ CopyFromHostToGpuKernel);
+
+} // namespace
+} // namespace tensorflow
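
To illustrate how a graph rewrite might use these kernels, here is a hedged sketch that inserts a swap-out/swap-in pair in front of a consumer node. The AddSwapPair helper, the node names, and the wiring are hypothetical illustrations, not the Grappler memory optimizer's actual code:

#include <string>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"

namespace {

// Hypothetical helper: copy `tensor_to_swap` out to pinned host memory and
// back onto the GPU before `consumer` reads it at `input_index`.
void AddSwapPair(const std::string& tensor_to_swap, const std::string& device,
                 tensorflow::NodeDef* consumer, int input_index,
                 tensorflow::GraphDef* graph) {
  // Swap-out node: runs on the GPU device, produces its output in host memory.
  tensorflow::NodeDef* swap_out = graph->add_node();
  swap_out->set_name("swap_out_" + consumer->name());
  swap_out->set_op("_CopyFromGpuToHost");
  swap_out->set_device(device);
  swap_out->add_input(tensor_to_swap);

  // Swap-in node: copies the host tensor back into GPU memory.
  tensorflow::NodeDef* swap_in = graph->add_node();
  swap_in->set_name("swap_in_" + consumer->name());
  swap_in->set_op("_CopyFromHostToGpu");
  swap_in->set_device(device);
  swap_in->add_input(swap_out->name());

  // Rewire the consumer to read the swapped-in copy instead of the original.
  consumer->set_input(input_index, swap_in->name());
}

}  // namespace

A real rewrite would also have to propagate the tensor's dtype attr onto the new nodes and add a control dependency that triggers the swap-in shortly before the consumer runs; both details are omitted from this sketch.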