aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Yuan Yu <yuanbyu@google.com>2016-02-23 10:57:31 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-02-23 11:31:21 -0800
commit803f9a4badfe0203550fdaebd651f23c9fc1b216 (patch)
treef775c28c0ac780ec0b796087ca7b2fa37a9e5459
parent2b9f60917a255137d9e97c8e4e767718b07d9152 (diff)
Memory swapping between GPU and CPU for stack. Changed stack push and pop ops to async.
For gradient computation for loops, stacks are used to store the tensors that are computed in the forward but needed in backprop. This CL enables very long sequence training by swapping the stack tensors from GPU to CPU. Change: 115359847
-rw-r--r--tensorflow/core/kernels/stack_ops.cc212
-rw-r--r--tensorflow/core/ops/data_flow_ops.cc4
-rw-r--r--tensorflow/core/ops/ops.pbtxt8
-rw-r--r--tensorflow/python/kernel_tests/stack_ops_test.py45
4 files changed, 209 insertions, 60 deletions
diff --git a/tensorflow/core/kernels/stack_ops.cc b/tensorflow/core/kernels/stack_ops.cc
index 069960299a..39cef2c454 100644
--- a/tensorflow/core/kernels/stack_ops.cc
+++ b/tensorflow/core/kernels/stack_ops.cc
@@ -18,6 +18,8 @@ limitations under the License.
#include <limits.h>
#include <vector>
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
@@ -34,31 +36,52 @@ limitations under the License.
namespace tensorflow {
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
class Stack : public ResourceBase {
public:
+ struct TensorAndAllocation {
+ PersistentTensor tensor;
+ AllocatorAttributes alloc_attrs;
+ bool swapped_to_cpu;
+ };
+
Stack(const DataType& elem_type, const Tensor& handle)
- : elem_type_(elem_type), handle_(handle) {}
+ : elem_type_(elem_type), handle_(handle), closed_(false) {}
- void Push(const PersistentTensor& value) {
+ Status Push(const TensorAndAllocation& value) {
mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(CheckNotClosed());
stack_.push_back(value);
+ return Status::OK();
}
- bool Pop(PersistentTensor* value) {
+ Status Pop(TensorAndAllocation* value) {
mutex_lock l(mu_);
- if (!stack_.empty()) {
- *value = stack_.back();
- stack_.pop_back();
- return true;
+ TF_RETURN_IF_ERROR(CheckNotClosed());
+ if (stack_.empty()) {
+ const string& stack_name = handle_.vec<string>()(1);
+ return errors::InvalidArgument("Stack[", stack_name,
+ "] is empty when calling Pop().");
}
- return false;
+ *value = stack_.back();
+ stack_.pop_back();
+ return Status::OK();
+ }
+
+ void Close() {
+ mutex_lock l(mu_);
+ stack_.clear();
+ closed_ = true;
}
DataType ElemType() { return elem_type_; }
string DebugString() override {
mutex_lock l(mu_);
- return strings::StrCat("#elem:", stack_.size());
+ const string& stack_name = handle_.vec<string>()(1);
+ return strings::StrCat("Stack[", stack_name, "]");
}
private:
@@ -69,9 +92,36 @@ class Stack : public ResourceBase {
mutex mu_;
DataType elem_type_;
Tensor handle_;
- std::vector<PersistentTensor> stack_ GUARDED_BY(mu_);
+ bool closed_ GUARDED_BY(mu_);
+ std::vector<TensorAndAllocation> stack_ GUARDED_BY(mu_);
+
+ Status CheckNotClosed() const EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (closed_) {
+ const string& stack_name = handle_.vec<string>()(1);
+ return errors::InvalidArgument("Stack[", stack_name,
+ "] has already been closed.");
+ }
+ return Status::OK();
+ }
};
+Status GetStack(OpKernelContext* ctx, Stack** stack) {
+ Tensor Tstack_handle = ctx->mutable_input(0, false);
+ if (Tstack_handle.NumElements() != 2) {
+ return errors::InvalidArgument(
+ "Stack handle must have two elements, but had shape: ",
+ Tstack_handle.shape().DebugString());
+ }
+ const string& container = Tstack_handle.flat<string>()(0);
+ const string& stack_name = Tstack_handle.flat<string>()(1);
+ ResourceMgr* rm = ctx->step_resource_manager();
+ if (rm == nullptr) {
+ return errors::Internal("No per-step resource manager.");
+ }
+ TF_RETURN_IF_ERROR(rm->Lookup(container, stack_name, stack));
+ return Status::OK();
+}
+
// A per-run local stack. The stack uses a "per-step" resource manager which
// ensures that correct garbage collection on error or successful completion.
class StackOp : public OpKernel {
@@ -113,41 +163,77 @@ REGISTER_KERNEL_BUILDER(Name("Stack").Device(DEVICE_CPU), StackOp);
REGISTER_KERNEL_BUILDER(Name("Stack").Device(DEVICE_GPU).HostMemory("handle"),
StackOp);
-class StackPushOp : public OpKernel {
+template <typename Device>
+class StackPushOp : public AsyncOpKernel {
public:
- explicit StackPushOp(OpKernelConstruction* context) : OpKernel(context) {}
+ explicit StackPushOp(OpKernelConstruction* context) : AsyncOpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("swap_memory", &swap_memory_));
+ }
- void Compute(OpKernelContext* ctx) override {
- Tensor Tstack_handle = ctx->mutable_input(0, false);
- OP_REQUIRES(ctx, Tstack_handle.NumElements() == 2,
- errors::InvalidArgument(
- "Stack handle must have two elements, but had shape: ",
- Tstack_handle.shape().DebugString()));
- const string& container = Tstack_handle.flat<string>()(0);
- const string& stack_name = Tstack_handle.flat<string>()(1);
- ResourceMgr* rm = ctx->step_resource_manager();
- OP_REQUIRES(ctx, rm != nullptr,
- errors::Internal("No per-step resource manager."));
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
+ // Get the stack from the handle.
Stack* stack = nullptr;
- OP_REQUIRES_OK(ctx, rm->Lookup(container, stack_name, &stack));
+ OP_REQUIRES_OK(ctx, GetStack(ctx, &stack));
OP_REQUIRES(ctx, ctx->input_dtype(1) == stack->ElemType(),
errors::InvalidArgument("Must have type ", stack->ElemType(),
" but got ", ctx->input_dtype(1)));
- stack->Push(PersistentTensor(ctx->input(1)));
- ctx->set_output(0, ctx->input(1));
+
+ // Push the tensor onto the stack. Swap the tensor to CPU if instructed.
+ const Tensor& tensor = ctx->input(1);
+ AllocatorAttributes alloc_attrs = ctx->input_alloc_attr(1);
+ static constexpr int copy_threshold = 2048;
+ if (swap_memory_ && !alloc_attrs.on_host() &&
+ std::is_same<Device, GPUDevice>::value &&
+ tensor.TotalBytes() > copy_threshold) {
+ // Asynchronously copy the tensor from GPU to CPU memory.
+ // TODO(yuanbyu): Swap only when there is mmeory pressure.
+ DeviceContext* device_ctxt = ctx->op_device_context();
+ auto device = static_cast<tensorflow::Device*>(ctx->device());
+ AllocatorAttributes host_alloc_attrs;
+ host_alloc_attrs.set_gpu_compatible(true);
+ host_alloc_attrs.set_on_host(true);
+ Allocator* cpu_allocator = device->GetAllocator(host_alloc_attrs);
+ Tensor* cpu_tensor =
+ new Tensor(cpu_allocator, tensor.dtype(), tensor.shape());
+ device_ctxt->CopyDeviceTensorToCPU(
+ &tensor, "StackPush", device, cpu_tensor,
+ [cpu_tensor, stack, ctx, done](const Status& s) {
+ ctx->SetStatus(s);
+ if (s.ok()) {
+ AllocatorAttributes alloc_attrs = ctx->input_alloc_attr(1);
+ ctx->SetStatus(stack->Push(
+ {PersistentTensor(*cpu_tensor), alloc_attrs, true}));
+ }
+ if (ctx->status().ok()) {
+ ctx->set_output(0, *cpu_tensor);
+ }
+ done();
+ delete cpu_tensor;
+ });
+ } else {
+ // Execute synchronously if not swapped.
+ OP_REQUIRES_OK(
+ ctx, stack->Push({PersistentTensor(tensor), alloc_attrs, false}));
+ ctx->set_output(0, tensor);
+ done();
+ }
}
bool IsExpensive() override { return false; }
+
+ private:
+ bool swap_memory_;
};
-REGISTER_KERNEL_BUILDER(Name("StackPush").Device(DEVICE_CPU), StackPushOp);
+REGISTER_KERNEL_BUILDER(Name("StackPush").Device(DEVICE_CPU),
+ StackPushOp<CPUDevice>);
#define REGISTER_GPU_KERNEL(type) \
REGISTER_KERNEL_BUILDER(Name("StackPush") \
.Device(DEVICE_GPU) \
.HostMemory("handle") \
.TypeConstraint<type>("T"), \
- StackPushOp);
+ StackPushOp<GPUDevice>);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL
@@ -162,35 +248,49 @@ TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
.HostMemory("elem") \
.HostMemory("output") \
.TypeConstraint<type>("T"), \
- StackPushOp)
+ StackPushOp<GPUDevice>)
REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(bool);
#undef REGISTER_GPU_HOST_KERNEL
-class StackPopOp : public OpKernel {
+class StackPopOp : public AsyncOpKernel {
public:
- explicit StackPopOp(OpKernelConstruction* context) : OpKernel(context) {}
+ explicit StackPopOp(OpKernelConstruction* context) : AsyncOpKernel(context) {}
- void Compute(OpKernelContext* ctx) override {
- Tensor Tstack_handle = ctx->mutable_input(0, false);
- OP_REQUIRES(ctx, Tstack_handle.NumElements() == 2,
- errors::InvalidArgument(
- "Stack handle must have two elements, but had shape: ",
- Tstack_handle.shape().DebugString()));
- const string& container = Tstack_handle.flat<string>()(0);
- const string& stack_name = Tstack_handle.flat<string>()(1);
- ResourceMgr* rm = ctx->step_resource_manager();
- OP_REQUIRES(ctx, rm != nullptr,
- errors::Internal("No per-step resource manager."));
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
+ // Get the stack from the handle.
Stack* stack = nullptr;
- OP_REQUIRES_OK(ctx, rm->Lookup(container, stack_name, &stack));
- PersistentTensor value;
- bool has_value = stack->Pop(&value);
- OP_REQUIRES(ctx, has_value, errors::InvalidArgument(
- "Calling Pop() when the stack is empty."));
- ctx->set_output(0, *value.AccessTensor(ctx));
+ OP_REQUIRES_OK(ctx, GetStack(ctx, &stack));
+
+ // Pop the tensor. Transfer the tensor back to device if it was
+ // swapped out to CPU.
+ Stack::TensorAndAllocation value;
+ OP_REQUIRES_OK(ctx, stack->Pop(&value));
+ if (value.swapped_to_cpu) {
+ // Asynchronously copy the tensor back from CPU to GPU memory.
+ DeviceContext* device_ctxt = ctx->op_device_context();
+ Device* device = static_cast<Device*>(ctx->device());
+ Tensor* cpu_tensor = value.tensor.AccessTensor(ctx);
+ Allocator* gpu_allocator = device->GetAllocator(value.alloc_attrs);
+ Tensor* device_tensor =
+ new Tensor(gpu_allocator, cpu_tensor->dtype(), cpu_tensor->shape());
+ device_ctxt->CopyCPUTensorToDevice(
+ cpu_tensor, device, device_tensor,
+ [device_tensor, ctx, done](const Status& s) {
+ ctx->SetStatus(s);
+ if (s.ok()) {
+ ctx->set_output(0, *device_tensor);
+ }
+ done();
+ delete device_tensor;
+ });
+ } else {
+ // Execute synchronously if not swapped.
+ ctx->set_output(0, *value.tensor.AccessTensor(ctx));
+ done();
+ }
}
bool IsExpensive() override { return false; }
@@ -229,18 +329,12 @@ class StackCloseOp : public OpKernel {
explicit StackCloseOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* ctx) override {
- Tensor Tstack_handle = ctx->mutable_input(0, false);
- OP_REQUIRES(ctx, Tstack_handle.NumElements() == 2,
- errors::InvalidArgument(
- "Stack handle must have two elements, but had shape: ",
- Tstack_handle.shape().DebugString()));
- const string& container = Tstack_handle.flat<string>()(0);
- const string& stack_name = Tstack_handle.flat<string>()(1);
- ResourceMgr* rm = ctx->step_resource_manager();
- OP_REQUIRES(ctx, rm != nullptr,
- errors::Internal("No per-step resource manager."));
- OP_REQUIRES_OK(ctx, rm->Delete<Stack>(container, stack_name));
+ Stack* stack = nullptr;
+ OP_REQUIRES_OK(ctx, GetStack(ctx, &stack));
+ stack->Close();
}
+
+ bool IsExpensive() override { return false; }
};
REGISTER_KERNEL_BUILDER(Name("StackClose").Device(DEVICE_CPU), StackCloseOp);
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc
index 816d246060..078753f053 100644
--- a/tensorflow/core/ops/data_flow_ops.cc
+++ b/tensorflow/core/ops/data_flow_ops.cc
@@ -353,12 +353,14 @@ REGISTER_OP("StackPush")
.Input("elem: T")
.Output("output: T")
.Attr("T: type")
+ .Attr("swap_memory: bool = false")
.Doc(R"doc(
Push an element onto the stack.
handle: The handle to a stack.
elem: The tensor to be pushed onto the stack.
output: The same tensor as the input 'elem'.
+swap_memory: Swap `elem` to CPU. Default to false.
)doc");
REGISTER_OP("StackPop")
@@ -369,8 +371,8 @@ REGISTER_OP("StackPop")
Pop the element at the top of the stack.
handle: The handle to a stack.
-elem_type: The type of the elem that is popped.
elem: The tensor that is popped from the top of the stack.
+elem_type: The type of the elem that is popped.
)doc");
REGISTER_OP("StackClose")
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 7133ec0d9e..79fb9e4f92 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -9452,6 +9452,14 @@ op {
name: "T"
type: "type"
}
+ attr {
+ name: "swap_memory"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ description: "Swap `elem` to CPU. Default to false."
+ }
summary: "Push an element onto the stack."
}
op {
diff --git a/tensorflow/python/kernel_tests/stack_ops_test.py b/tensorflow/python/kernel_tests/stack_ops_test.py
index d3ffee89b5..5270a13de2 100644
--- a/tensorflow/python/kernel_tests/stack_ops_test.py
+++ b/tensorflow/python/kernel_tests/stack_ops_test.py
@@ -18,9 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
import tensorflow as tf
from tensorflow.python.framework import errors
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
@@ -38,6 +40,49 @@ class StackOpTest(tf.test.TestCase):
self._testStackPushPop(use_gpu=False)
self._testStackPushPop(use_gpu=True)
+ def _testStackPushPopSwap(self, use_gpu):
+ with self.test_session(use_gpu=use_gpu):
+ a = np.arange(2000)
+ x = tf.constant(a, dtype=tf.float32)
+ h = gen_data_flow_ops._stack(tf.float32, stack_name="foo")
+ c = gen_data_flow_ops._stack_push(h, x, swap_memory=True)
+ with tf.control_dependencies([c]):
+ c1 = gen_data_flow_ops._stack_pop(h, tf.float32)
+ self.assertAllClose(a, c1.eval())
+
+ def testStackPushPopSwap(self):
+ self._testStackPushPopSwap(use_gpu=False)
+ self._testStackPushPopSwap(use_gpu=True)
+
+ def _testStackWhileSwap(self, use_gpu):
+ with self.test_session(use_gpu=use_gpu):
+ n = tf.constant(0)
+ h = gen_data_flow_ops._stack(tf.float32, stack_name="foo")
+
+ def c(x):
+ return tf.less(x, 10)
+ def b(x):
+ with tf.control_dependencies([x]):
+ a = tf.constant(np.ones(2000), dtype=tf.float32)
+ v = gen_data_flow_ops._stack_push(h, a, swap_memory=True)
+ with tf.control_dependencies([v]):
+ return tf.add(x, 1)
+ r = control_flow_ops.While(c, b, [n])
+
+ v = tf.constant(np.zeros(2000), dtype=tf.float32)
+ def c1(x, y):
+ return tf.greater(x, 0)
+ def b1(x, y):
+ nx = tf.sub(x, 1)
+ ny = y + gen_data_flow_ops._stack_pop(h, tf.float32)
+ return [nx, ny]
+ rx, ry = control_flow_ops.While(c1, b1, [r, v])
+ self.assertAllClose(np.ones(2000) * 10.0, ry.eval())
+
+ def testStackWhileSwap(self):
+ self._testStackWhileSwap(use_gpu=False)
+ self._testStackWhileSwap(use_gpu=True)
+
def _testMultiStack(self, use_gpu):
with self.test_session(use_gpu=use_gpu):
h1 = gen_data_flow_ops._stack(tf.float32, stack_name="foo")