-rw-r--r--  tensorflow/core/BUILD                                        |   2
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_managed_allocator.cc |  42
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_managed_allocator.h  |  36
-rw-r--r--  tensorflow/core/kernels/ops_testutil.h                      | 126
4 files changed, 151 insertions(+), 55 deletions(-)
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index b18b3cb123..2577184329 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -2109,6 +2109,7 @@ GPU_RUNTIME_HEADERS = [
"common_runtime/gpu/gpu_debug_allocator.h",
"common_runtime/gpu/gpu_device.h",
"common_runtime/gpu/gpu_init.h",
+ "common_runtime/gpu/gpu_managed_allocator.h",
"common_runtime/gpu/gpu_stream_util.h",
"common_runtime/gpu/gpu_util.h",
"common_runtime/gpu/pool_allocator.h",
@@ -2123,6 +2124,7 @@ tf_cuda_library(
"common_runtime/gpu/gpu_debug_allocator.cc",
"common_runtime/gpu/gpu_device.cc",
"common_runtime/gpu/gpu_device_factory.cc",
+ "common_runtime/gpu/gpu_managed_allocator.cc",
"common_runtime/gpu/gpu_stream_util.cc",
"common_runtime/gpu/gpu_util.cc",
"common_runtime/gpu/gpu_util_platform_specific.cc",
diff --git a/tensorflow/core/common_runtime/gpu/gpu_managed_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_managed_allocator.cc
new file mode 100644
index 0000000000..41f7249d89
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_managed_allocator.cc
@@ -0,0 +1,42 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/gpu/gpu_managed_allocator.h"
+
+#ifdef GOOGLE_CUDA
+#include "cuda/include/cuda.h"
+#endif // GOOGLE_CUDA
+
+namespace tensorflow {
+
+void* GpuManagedAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
+#ifdef GOOGLE_CUDA
+ CUdeviceptr ptr = 0;
+ CHECK_EQ(cuMemAllocManaged(&ptr, num_bytes, CU_MEM_ATTACH_GLOBAL),
+ CUDA_SUCCESS);
+ CHECK(!(ptr & (alignment - 1)));
+ return reinterpret_cast<void*>(ptr);
+#else
+ return nullptr;
+#endif
+}
+
+void GpuManagedAllocator::DeallocateRaw(void* ptr) {
+#ifdef GOOGLE_CUDA
+ CHECK_EQ(cuMemFree(reinterpret_cast<CUdeviceptr>(ptr)), CUDA_SUCCESS);
+#endif
+}
+
+} // namespace tensorflow
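The allocator body above is a thin wrapper around two CUDA driver calls, cuMemAllocManaged and cuMemFree. The standalone sketch below is an illustration only, not part of this commit: it assumes the CUDA driver library is linked directly, so it retains and binds the primary context itself, something TensorFlow's GPU device has already done by the time this allocator runs.

// Hypothetical standalone sketch of the driver-API calls used by
// GpuManagedAllocator; not part of this commit.
#include <cstdio>
#include <cstring>
#include <cuda.h>  // in-tree TensorFlow builds use "cuda/include/cuda.h"

int main() {
  CUdevice dev;
  CUcontext ctx;
  if (cuInit(0) != CUDA_SUCCESS || cuDeviceGet(&dev, 0) != CUDA_SUCCESS ||
      cuDevicePrimaryCtxRetain(&ctx, dev) != CUDA_SUCCESS ||
      cuCtxSetCurrent(ctx) != CUDA_SUCCESS) {
    std::fprintf(stderr, "no usable CUDA device\n");
    return 1;
  }
  // Unified (managed) memory: one pointer valid on both host and device.
  CUdeviceptr ptr = 0;
  if (cuMemAllocManaged(&ptr, 1 << 20, CU_MEM_ATTACH_GLOBAL) != CUDA_SUCCESS) {
    std::fprintf(stderr, "cuMemAllocManaged failed\n");
    return 1;
  }
  // The host can touch the allocation directly; no explicit copy is needed.
  std::memset(reinterpret_cast<void*>(ptr), 0, 1 << 20);
  cuMemFree(ptr);
  cuDevicePrimaryCtxRelease(dev);
  return 0;
}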
diff --git a/tensorflow/core/common_runtime/gpu/gpu_managed_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_managed_allocator.h
new file mode 100644
index 0000000000..006b2ca448
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_managed_allocator.h
@@ -0,0 +1,36 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_MANAGED_ALLOCATOR_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_MANAGED_ALLOCATOR_H_
+
+#include "tensorflow/core/framework/allocator.h"
+
+namespace tensorflow {
+
+// An allocator for CUDA unified memory. Memory allocated with this allocator
+// can be accessed from both host and device. CUDA transparently migrates dirty
+// pages, which can be slow. Therefore, this allocator is intended for
+// convenience in functional tests only.
+class GpuManagedAllocator : public Allocator {
+ public:
+ string Name() override { return "GpuManagedAllocator"; }
+ void* AllocateRaw(size_t alignment, size_t num_bytes) override;
+ void DeallocateRaw(void* ptr) override;
+};
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_MANAGED_ALLOCATOR_H_
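Because GpuManagedAllocator implements the standard Allocator interface, it can back a Tensor directly. The sketch below is a hypothetical usage example, not part of this commit; it assumes a GOOGLE_CUDA build in which a CUDA context is already current (for example, after the GPU device has been created).

// Hypothetical sketch: back a Tensor with unified memory so the host can
// fill and inspect data that a GPU kernel also reads and writes.
#include "tensorflow/core/common_runtime/gpu/gpu_managed_allocator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"

namespace tensorflow {

Tensor MakeManagedTensor(GpuManagedAllocator* allocator) {
  // The Tensor(Allocator*, DataType, TensorShape) constructor routes the
  // buffer allocation through AllocateRaw above.
  Tensor t(allocator, DT_FLOAT, TensorShape({4}));
  auto flat = t.flat<float>();
  // Host-side writes; the pages migrate to the device on first GPU access.
  for (int i = 0; i < flat.size(); ++i) flat(i) = static_cast<float>(i);
  return t;
}

}  // namespace tensorflow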
diff --git a/tensorflow/core/kernels/ops_testutil.h b/tensorflow/core/kernels/ops_testutil.h
index 2a6b9e00bf..cfd8b11f25 100644
--- a/tensorflow/core/kernels/ops_testutil.h
+++ b/tensorflow/core/kernels/ops_testutil.h
@@ -46,8 +46,12 @@ limitations under the License.
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h"
-namespace tensorflow {
+#ifdef GOOGLE_CUDA
+#include "cuda/include/cuda.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_managed_allocator.h"
+#endif
+namespace tensorflow {
namespace test {
inline void SetOutputAttrs(OpKernelContext::Params* params,
@@ -71,14 +75,16 @@ inline void SetOutputAttrs(OpKernelContext::Params* params,
// to use the BrainClient interface.
class OpsTestBase : public ::testing::Test {
public:
- OpsTestBase() : device_type_(DEVICE_CPU) {
- device_.reset(
- DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
+ OpsTestBase()
+ : device_(DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")),
+ device_type_(DEVICE_CPU) {
CHECK(device_.get()) << "Could not create CPU device";
+ allocator_ = device_->GetAllocator(AllocatorAttributes());
}
~OpsTestBase() override {
gtl::STLDeleteElements(&tensors_);
+ gtl::STLDeleteElements(&managed_outputs_);
context_.reset(nullptr);
params_.reset(nullptr);
}
@@ -89,6 +95,18 @@ class OpsTestBase : public ::testing::Test {
CHECK(device_.get()) << "No device provided";
device_type_ = device_type;
device_ = std::move(device);
+#ifdef GOOGLE_CUDA
+ if (device_type == DEVICE_GPU) {
+ managed_allocator_.reset(new GpuManagedAllocator());
+ allocator_ = managed_allocator_.get();
+ } else {
+ managed_allocator_.reset();
+ allocator_ = device_->GetAllocator(AllocatorAttributes());
+ }
+#else
+ CHECK_NE(device_type, DEVICE_GPU)
+ << "Requesting GPU on binary compiled without GOOGLE_CUDA.";
+#endif
}
void set_node_def(const NodeDef& node_def) { node_def_.CopyFrom(node_def); }
@@ -118,42 +136,14 @@ class OpsTestBase : public ::testing::Test {
// TODO(vrv): Replace with something like a BrainClient Feed.
template <typename T>
void AddInput(const TensorShape& shape, std::function<T(int)> input_mapping) {
- CHECK_GT(input_types_.size(), inputs_.size())
- << "Adding more inputs than types; perhaps you need to call MakeOp";
- bool is_ref = IsRefType(input_types_[inputs_.size()]);
- Tensor* input = new Tensor(device_->GetAllocator(AllocatorAttributes()),
- DataTypeToEnum<T>::v(), shape);
- test::FillFn(input, input_mapping);
- tensors_.push_back(input);
- if (is_ref) {
- CHECK_EQ(RemoveRefType(input_types_[inputs_.size()]),
- DataTypeToEnum<T>::v());
- inputs_.push_back({&lock_for_refs_, input});
- } else {
- CHECK_EQ(input_types_[inputs_.size()], DataTypeToEnum<T>::v());
- inputs_.push_back({nullptr, input});
- }
+ test::FillFn(AddInput(DataTypeToEnum<T>::v(), shape), input_mapping);
}
// Like AddInput but takes in an explicit arrayslice of data.
template <typename T>
void AddInputFromArray(const TensorShape& shape,
const gtl::ArraySlice<T>& data) {
- CHECK_GT(input_types_.size(), inputs_.size())
- << "Adding more inputs than types; perhaps you need to call MakeOp";
- bool is_ref = IsRefType(input_types_[inputs_.size()]);
- Tensor* input = new Tensor(device_->GetAllocator(AllocatorAttributes()),
- DataTypeToEnum<T>::v(), shape);
- test::FillValues<T>(input, data);
- tensors_.push_back(input);
- if (is_ref) {
- CHECK_EQ(RemoveRefType(input_types_[inputs_.size()]),
- DataTypeToEnum<T>::v());
- inputs_.push_back({&lock_for_refs_, input});
- } else {
- CHECK_EQ(input_types_[inputs_.size()], DataTypeToEnum<T>::v());
- inputs_.push_back({nullptr, input});
- }
+ test::FillValues<T>(AddInput(DataTypeToEnum<T>::v(), shape), data);
}
// Convenience function to add an input and populate it with the elements from
@@ -161,21 +151,7 @@ class OpsTestBase : public ::testing::Test {
template <typename T, typename SrcType>
void AddInputFromList(const TensorShape& shape,
std::initializer_list<SrcType> data) {
- CHECK_GT(input_types_.size(), inputs_.size())
- << "Adding more inputs than types; perhaps you need to call MakeOp";
- bool is_ref = IsRefType(input_types_[inputs_.size()]);
- Tensor* input = new Tensor(device_->GetAllocator(AllocatorAttributes()),
- DataTypeToEnum<T>::v(), shape);
- test::FillValues<T>(input, data);
- tensors_.push_back(input);
- if (is_ref) {
- CHECK_EQ(RemoveRefType(input_types_[inputs_.size()]),
- DataTypeToEnum<T>::v());
- inputs_.push_back({&lock_for_refs_, input});
- } else {
- CHECK_EQ(input_types_[inputs_.size()], DataTypeToEnum<T>::v());
- inputs_.push_back({nullptr, input});
- }
+ test::FillValues<T>(AddInput(DataTypeToEnum<T>::v(), shape), data);
}
// Adds a Resource type as input. If <container> is empty, uses the default
@@ -197,8 +173,7 @@ class OpsTestBase : public ::testing::Test {
handle.set_name(name);
handle.set_hash_code(type_index.hash_code());
handle.set_maybe_type_name(type_index.name());
- Tensor* input = new Tensor(device_->GetAllocator(AllocatorAttributes()),
- DT_RESOURCE, TensorShape({}));
+ Tensor* input = new Tensor(allocator(), DT_RESOURCE, TensorShape({}));
input->scalar<ResourceHandle>()() = handle;
tensors_.push_back(input);
inputs_.push_back({nullptr, input});
@@ -248,17 +223,54 @@ class OpsTestBase : public ::testing::Test {
// REQUIRES: 0 <= output_index < context_->num_outputs()
Tensor* GetOutput(int output_index) {
CHECK_LT(output_index, context_->num_outputs());
- return context_->mutable_output(output_index);
+ Tensor* output = context_->mutable_output(output_index);
+#ifdef GOOGLE_CUDA
+ if (device_type_ == DEVICE_GPU) {
+ managed_outputs_.resize(context_->num_outputs());
+ // Copy the output tensor to managed memory if we haven't done so.
+ if (!managed_outputs_[output_index]) {
+ Tensor* managed_output =
+ new Tensor(allocator(), output->dtype(), output->shape());
+ auto src = output->tensor_data();
+ auto dst = managed_output->tensor_data();
+ CHECK_EQ(
+ cuMemcpyDtoD(reinterpret_cast<CUdeviceptr>(dst.data()),
+ reinterpret_cast<CUdeviceptr>(src.data()), src.size()),
+ CUDA_SUCCESS);
+ CHECK_EQ(cuCtxSynchronize(), CUDA_SUCCESS);
+ managed_outputs_[output_index] = managed_output;
+ }
+ output = managed_outputs_[output_index];
+ }
+#endif
+ return output;
}
- Allocator* allocator() {
- return device_->GetAllocator(AllocatorAttributes());
- }
+ Allocator* allocator() { return allocator_; }
const DataTypeVector& output_types() const { return kernel_->output_types(); }
+ private:
+ Tensor* AddInput(DataType dtype, const TensorShape& shape) {
+ CHECK_GT(input_types_.size(), inputs_.size())
+ << "Adding more inputs than types; perhaps you need to call MakeOp";
+ bool is_ref = IsRefType(input_types_[inputs_.size()]);
+ Tensor* input = new Tensor(allocator(), dtype, shape);
+ tensors_.push_back(input);
+ if (is_ref) {
+ CHECK_EQ(RemoveRefType(input_types_[inputs_.size()]), dtype);
+ inputs_.push_back({&lock_for_refs_, input});
+ } else {
+ CHECK_EQ(input_types_[inputs_.size()], dtype);
+ inputs_.push_back({nullptr, input});
+ }
+ return input;
+ }
+
protected:
std::unique_ptr<Device> device_;
+ // The device allocator, or the managed_allocator_ below if running on GPU.
+ Allocator* allocator_;
std::unique_ptr<OpKernel> kernel_;
std::unique_ptr<ScopedStepContainer> step_container_;
@@ -271,9 +283,13 @@ class OpsTestBase : public ::testing::Test {
gtl::InlinedVector<TensorValue, 4> inputs_;
// Owns Tensors.
std::vector<Tensor*> tensors_;
+ // Copies of the outputs in unified memory (host and device accessible).
+ std::vector<Tensor*> managed_outputs_;
std::unique_ptr<OpKernelContext::Params> params_;
std::unique_ptr<OpKernelContext> context_;
+ // Unified memory allocator, only used when running on GPU.
+ std::unique_ptr<Allocator> managed_allocator_;
private:
TF_DISALLOW_COPY_AND_ASSIGN(OpsTestBase);
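Taken together, these hooks let a kernel test run on the GPU while still comparing results on the host, because GetOutput() hands back the unified-memory copy of each output. The test below is a hypothetical sketch, not part of this commit: the fixture name, test body, and the choice of the Relu kernel are assumptions, and it only builds and passes in a GOOGLE_CUDA build with a GPU present.

// Hypothetical GPU kernel test built on the updated OpsTestBase.
#include <memory>

#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace tensorflow {

class ReluGpuTest : public OpsTestBase {};

TEST_F(ReluGpuTest, ForwardsPositiveValues) {
  // Switch to the GPU device; OpsTestBase then allocates inputs (and output
  // copies) with GpuManagedAllocator.
  SetDevice(DEVICE_GPU,
            std::unique_ptr<Device>(DeviceFactory::NewDevice(
                "GPU", {}, "/job:a/replica:0/task:0")));
  TF_ASSERT_OK(NodeDefBuilder("relu", "Relu")
                   .Input(FakeInput(DT_FLOAT))
                   .Finalize(node_def()));
  TF_ASSERT_OK(InitOp());
  AddInputFromArray<float>(TensorShape({3}), {-1.0f, 0.0f, 2.0f});
  TF_ASSERT_OK(RunOpKernel());
  // GetOutput(0) returns the unified-memory copy, so ExpectTensorEqual can
  // read it directly on the host.
  Tensor expected(allocator(), DT_FLOAT, TensorShape({3}));
  test::FillValues<float>(&expected, {0.0f, 0.0f, 2.0f});
  test::ExpectTensorEqual<float>(expected, *GetOutput(0));
}

}  // namespace tensorflow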