aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Ayush Dubey <ayushd@google.com>2018-03-22 11:25:58 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-22 11:32:16 -0700
commit282750fee5e2df502436ca9ef6a95283f8adab34 (patch)
tree3c24d054b88ab0757c9b3e392faf0b1dcbddfaaa
parent7c4cdb8bae0e8760ebe4793d49ea5aee68768655 (diff)
Add new Ops for ScopedAllocator and the associated Concat and Split. The
ScopedAllocatorOp allocates a large backing tensor whose slices may be concatenated or splitted with ScopedAllocatorConcatOp and ScopedAllocatorSplitOp respectively. These ops should only be added via Grappler optimizations on the dataflow graph provided by the user. PiperOrigin-RevId: 190097586
-rw-r--r--tensorflow/core/BUILD3
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device.cc11
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device.h15
-rw-r--r--tensorflow/core/common_runtime/scoped_allocator.cc3
-rw-r--r--tensorflow/core/common_runtime/scoped_allocator_mgr.cc25
-rw-r--r--tensorflow/core/common_runtime/scoped_allocator_mgr.h6
-rw-r--r--tensorflow/core/common_runtime/scoped_allocator_mgr_test.cc25
-rw-r--r--tensorflow/core/common_runtime/threadpool_device.cc16
-rw-r--r--tensorflow/core/common_runtime/threadpool_device.h12
-rw-r--r--tensorflow/core/framework/allocator.h16
-rw-r--r--tensorflow/core/framework/device_base.h16
-rw-r--r--tensorflow/core/framework/op_kernel.cc9
-rw-r--r--tensorflow/core/kernels/BUILD37
-rw-r--r--tensorflow/core/kernels/scoped_allocator_ops.cc216
-rw-r--r--tensorflow/core/kernels/scoped_allocator_ops_test.cc296
-rw-r--r--tensorflow/core/ops/scoped_allocator_ops.cc81
16 files changed, 742 insertions, 45 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index a14eeed1a5..15cbba8285 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -634,6 +634,7 @@ tf_gen_op_libs(
"random_ops",
"remote_fused_graph_ops",
"resource_variable_ops",
+ "scoped_allocator_ops",
"sdca_ops",
"set_ops",
"script_ops",
@@ -717,6 +718,7 @@ cc_library(
":random_ops_op_lib",
":remote_fused_graph_ops_op_lib",
":resource_variable_ops_op_lib",
+ ":scoped_allocator_ops_op_lib",
":script_ops_op_lib",
":sdca_ops_op_lib",
":sendrecv_ops_op_lib",
@@ -861,6 +863,7 @@ cc_library(
"//tensorflow/core/kernels:remote_fused_graph_ops",
"//tensorflow/core/kernels:required",
"//tensorflow/core/kernels:resource_variable_ops",
+ "//tensorflow/core/kernels:scoped_allocator_ops",
"//tensorflow/core/kernels:sdca_ops",
"//tensorflow/core/kernels:set_kernels",
"//tensorflow/core/kernels:sparse",
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index 8357cc5a72..52fd20e479 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -840,6 +840,17 @@ void BaseGPUDevice::ReinitializeGpuDevice(OpKernelContext* context,
}
}
+Allocator* BaseGPUDevice::GetScopedAllocator(AllocatorAttributes attr,
+ int64 step_id) {
+ if (attr.scope_id > 0) {
+ return scoped_allocator_mgr_->GetContainer(step_id)->GetInstance(
+ attr.scope_id);
+ }
+ LOG(FATAL) << "Unexpected call to BaseGPUDevice::GetScopedAllocator "
+ << "attr.scope_id = " << attr.scope_id;
+ return gpu_allocator_;
+}
+
const int BaseGPUDeviceFactory::InterconnectMap::kSameDeviceStrength = 1000;
const int BaseGPUDeviceFactory::InterconnectMap::kStreamExecutorStrength = 1;
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.h b/tensorflow/core/common_runtime/gpu/gpu_device.h
index d817c7dd1f..cc5c3881dd 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.h
@@ -17,8 +17,8 @@ limitations under the License.
#error This file must only be included when building with Cuda support
#endif
-#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_DEVICE_H_
-#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_DEVICE_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_DEVICE_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_DEVICE_H_
#include <memory>
#include <string>
@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/gpu/gpu_id_utils.h"
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#include "tensorflow/core/common_runtime/local_device.h"
+#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -95,11 +96,19 @@ class BaseGPUDevice : public LocalDevice {
// corresponds to the cuda context.
gpu::StreamExecutor* executor() const { return executor_; }
+ Allocator* GetScopedAllocator(AllocatorAttributes attr,
+ int64 step_id) override;
+
+ ScopedAllocatorMgr* GetScopedAllocatorMgr() const override {
+ return scoped_allocator_mgr_.get();
+ }
+
protected:
Allocator* gpu_allocator_; // not owned
Allocator* cpu_allocator_; // not owned
gpu::StreamExecutor* executor_; // not owned
+ std::unique_ptr<ScopedAllocatorMgr> scoped_allocator_mgr_;
private:
struct StreamGroup {
@@ -205,4 +214,4 @@ class BaseGPUDeviceFactory : public DeviceFactory {
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_DEVICE_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_DEVICE_H_
diff --git a/tensorflow/core/common_runtime/scoped_allocator.cc b/tensorflow/core/common_runtime/scoped_allocator.cc
index 31e7a5e3e2..a26672b79d 100644
--- a/tensorflow/core/common_runtime/scoped_allocator.cc
+++ b/tensorflow/core/common_runtime/scoped_allocator.cc
@@ -75,7 +75,8 @@ void* ScopedAllocator::AllocateRaw(int32 field_index, size_t num_bytes) {
if (num_bytes != f.bytes) {
LOG(ERROR) << "ScopedAllocator " << name_ << " got request for "
<< num_bytes << " bytes from field " << field_index
- << " which has precalculated size " << f.bytes;
+ << " which has precalculated size " << f.bytes << " and offset "
+ << f.offset;
return nullptr;
}
diff --git a/tensorflow/core/common_runtime/scoped_allocator_mgr.cc b/tensorflow/core/common_runtime/scoped_allocator_mgr.cc
index d0d05c6d1b..e1f70404e3 100644
--- a/tensorflow/core/common_runtime/scoped_allocator_mgr.cc
+++ b/tensorflow/core/common_runtime/scoped_allocator_mgr.cc
@@ -22,7 +22,7 @@ namespace tensorflow {
Status ScopedAllocatorContainer::AddScopedAllocator(
const Tensor& backing_tensor, int32 scope_id, const string& scope_name,
const gtl::ArraySlice<ScopedAllocator::Field>& fields,
- int32 expected_call_count, ScopedAllocator** sa_ptr) {
+ int32 expected_call_count) {
VLOG(1) << "AddScopedAllocator " << mgr_->device_name()
<< " step_id_=" << step_id_ << " scope_id=" << scope_id;
mutex_lock l(mu_);
@@ -41,17 +41,17 @@ Status ScopedAllocatorContainer::AddScopedAllocator(
}
}
VLOG(2) << " container " << this << " step_id " << step_id_;
- *sa_ptr = new ScopedAllocator(backing_tensor, scope_id, scope_name, fields,
- expected_call_count, this);
- allocators_[scope_id] = ScopedAllocatorContainer::SAField(
- ScopedAllocator::kBackingIndex, *sa_ptr);
+ ScopedAllocator* sa = new ScopedAllocator(
+ backing_tensor, scope_id, scope_name, fields, expected_call_count, this);
+ allocators_[scope_id] =
+ ScopedAllocatorContainer::SAField(ScopedAllocator::kBackingIndex, sa);
VLOG(2) << "#fields " << fields.size();
for (int i = 0; i < fields.size(); ++i) {
const ScopedAllocator::Field& f = fields[i];
VLOG(2) << "Adding instance with for " << mgr_->device_name()
<< " scope_id=" << f.scope_id;
allocators_[f.scope_id] = ScopedAllocatorContainer::SAField(
- i, new ScopedAllocatorInstance(*sa_ptr, i));
+ i, new ScopedAllocatorInstance(sa, i));
}
return Status::OK();
}
@@ -154,23 +154,26 @@ Status ScopedAllocatorMgr::AddScopedAllocator(
const Tensor& backing_tensor, int64 step_id, int32 scope_id,
const string& scope_name,
const gtl::ArraySlice<ScopedAllocator::Field>& fields,
- int32 expected_call_count, ScopedAllocator** sa_ptr) {
+ int32 expected_call_count) {
ScopedAllocatorContainer* sac = GetContainer(step_id);
return sac->AddScopedAllocator(backing_tensor, scope_id, scope_name, fields,
- expected_call_count, sa_ptr);
+ expected_call_count);
}
void ScopedAllocatorMgr::PopulateFields(
- int32 scope_id, const gtl::ArraySlice<TensorShape>& shapes, DataType dtype,
- std::vector<ScopedAllocator::Field>* fields) {
+ int32 scope_id, const gtl::ArraySlice<TensorShape>& shapes,
+ const DataType dtype, std::vector<ScopedAllocator::Field>* fields) {
const int32 num_fields = static_cast<int32>(shapes.size());
fields->resize(num_fields);
size_t offset = 0;
for (int32 i = 0; i < num_fields; ++i) {
- size_t bytes = shapes[i].num_elements() * sizeof(dtype);
+ size_t bytes = shapes[i].num_elements() * DataTypeSize(dtype);
(*fields)[i].scope_id = scope_id + 1 + i;
(*fields)[i].bytes = bytes;
(*fields)[i].offset = offset;
+ VLOG(1) << "field=" << i << " scope_id=" << (*fields)[i].scope_id
+ << " bytes=" << (*fields)[i].bytes
+ << " offset=" << (*fields)[i].offset;
offset += bytes;
size_t overshoot = offset % Allocator::kAllocatorAlignment;
if (overshoot > 0) {
diff --git a/tensorflow/core/common_runtime/scoped_allocator_mgr.h b/tensorflow/core/common_runtime/scoped_allocator_mgr.h
index 4d5bc23dd9..effc5f2d77 100644
--- a/tensorflow/core/common_runtime/scoped_allocator_mgr.h
+++ b/tensorflow/core/common_runtime/scoped_allocator_mgr.h
@@ -34,7 +34,7 @@ class ScopedAllocatorContainer : public core::RefCounted {
Status AddScopedAllocator(
const Tensor& backing_tensor, int32 scope_id, const string& scope_name,
const gtl::ArraySlice<ScopedAllocator::Field>& fields,
- int32 expected_call_count, ScopedAllocator** sa_ptr);
+ int32 expected_call_count);
ScopedAllocatorInstance* GetInstance(int32 scope_id);
ScopedAllocator* GetAllocator(int32 scope_id);
@@ -83,7 +83,7 @@ class ScopedAllocatorMgr {
const Tensor& backing_tensor, int64 step_id, int32 scope_id,
const string& scope_name,
const gtl::ArraySlice<ScopedAllocator::Field>& fields,
- int32 expected_call_count, ScopedAllocator** sa_ptr);
+ int32 expected_call_count);
void Cleanup(int64 step_id);
@@ -91,7 +91,7 @@ class ScopedAllocatorMgr {
// consecutive scope_id values following that of the base ScopedAllocator.
static void PopulateFields(int32 scope_id,
const gtl::ArraySlice<TensorShape>& shapes,
- DataType dtype,
+ const DataType dtype,
std::vector<ScopedAllocator::Field>* fields);
const string& device_name() const { return device_name_; }
diff --git a/tensorflow/core/common_runtime/scoped_allocator_mgr_test.cc b/tensorflow/core/common_runtime/scoped_allocator_mgr_test.cc
index 81cb3e7979..38e07e47f2 100644
--- a/tensorflow/core/common_runtime/scoped_allocator_mgr_test.cc
+++ b/tensorflow/core/common_runtime/scoped_allocator_mgr_test.cc
@@ -25,7 +25,7 @@ namespace {
class ScopedAllocatorMgrTest : public ::testing::Test {
public:
- ScopedAllocatorMgrTest() : sam_("CPU0"), sa_(nullptr) {}
+ ScopedAllocatorMgrTest() : sam_("CPU0") {}
void InitTensor() {
backing_tensor_ = Tensor(cpu_allocator(), DT_FLOAT, backing_tensor_shape_);
@@ -42,7 +42,7 @@ class ScopedAllocatorMgrTest : public ::testing::Test {
<< " expected_use_count " << expected_use_count;
return sam_.AddScopedAllocator(backing_tensor_, step_id_, scope_id,
"tensor_shape_599", fields_,
- expected_use_count, &sa_);
+ expected_use_count);
}
Status PrepScopedAllocatorMgr(int expected_use_count) {
@@ -87,7 +87,6 @@ class ScopedAllocatorMgrTest : public ::testing::Test {
std::vector<TensorShape> fields_shapes_;
std::vector<ScopedAllocator::Field> fields_;
ScopedAllocatorMgr sam_;
- ScopedAllocator* sa_;
const int step_id_ = 101;
const int scope_id_ = 599;
std::vector<ScopedAllocatorInstance*> sa_instances_;
@@ -138,9 +137,9 @@ TEST_F(ScopedAllocatorMgrTest, ContainerAddAllocator) {
// Cleanup the instances by invoking allocate and deallocate.
void* ptr0 =
- sa_instances_[0]->AllocateRaw(0 /* alignment */, 512 * sizeof(DT_FLOAT));
+ sa_instances_[0]->AllocateRaw(0 /* alignment */, 512 * sizeof(float));
void* ptr1 =
- sa_instances_[1]->AllocateRaw(0 /* alignment */, 512 * sizeof(DT_FLOAT));
+ sa_instances_[1]->AllocateRaw(0 /* alignment */, 512 * sizeof(float));
sa_instances_[0]->DeallocateRaw(ptr0);
sa_instances_[1]->DeallocateRaw(ptr1);
}
@@ -153,7 +152,6 @@ TEST_F(ScopedAllocatorMgrTest, AllocatorSuccess) {
fields_shapes_ = std::vector<TensorShape>({{512}, {3, 3}, {2, 256}});
Status s = PrepScopedAllocatorMgr(3);
other = sac->GetAllocator(scope_id_);
- EXPECT_EQ(other, sa_);
ScopedAllocatorInstance* inst0 = sac->GetInstance(scope_id_ + 1);
char* ptr0 = static_cast<char*>(inst0->AllocateRaw(0, 512 * sizeof(float)));
@@ -187,8 +185,7 @@ TEST_F(ScopedAllocatorMgrTest, AllocatorInitFail) {
fields_.resize(1);
fields_[0].scope_id = scope_id_ + 1;
fields_[0].offset = 0;
- fields_[0].bytes =
- backing_tensor_shape_.num_elements() * 2 * sizeof(DT_FLOAT);
+ fields_[0].bytes = backing_tensor_shape_.num_elements() * 2 * sizeof(float);
// fields[0].offset + fields[0].bytes is larger than the size of the backing
// tensor, so this check should fail
EXPECT_DEATH(Status s = AddScopedAllocator(1, scope_id_), "");
@@ -208,20 +205,20 @@ TEST_F(ScopedAllocatorMgrTest, AllocatorFail) {
// so we need to explicitly delete the instances to avoid a memleak.
SaveInstances(fields_shapes_.size());
- char* ptr0 = static_cast<char*>(
- sa_instances_[0]->AllocateRaw(0, 512 * sizeof(DT_FLOAT)));
+ char* ptr0 =
+ static_cast<char*>(sa_instances_[0]->AllocateRaw(0, 512 * sizeof(float)));
VLOG(2) << "Should fail because we deallocate ptr="
<< static_cast<void*>(ptr0 + 8) << " which we never allocated.";
EXPECT_DEATH(sa_instances_[0]->DeallocateRaw(ptr0 + 8), "");
VLOG(2) << "Should fail because we allocate smaller than the size of the "
<< "field.";
- EXPECT_EQ(nullptr, sa_instances_[1]->AllocateRaw(0, 256 * sizeof(DT_FLOAT)));
+ EXPECT_EQ(nullptr, sa_instances_[1]->AllocateRaw(0, 256 * sizeof(float)));
VLOG(2) << "Should fail because we allocate larger than the size of the "
<< "field.";
- EXPECT_EQ(nullptr, sa_instances_[1]->AllocateRaw(0, 1024 * sizeof(DT_FLOAT)));
- void* ptr1 = sa_instances_[1]->AllocateRaw(0, 512 * sizeof(DT_FLOAT));
+ EXPECT_EQ(nullptr, sa_instances_[1]->AllocateRaw(0, 1024 * sizeof(float)));
+ void* ptr1 = sa_instances_[1]->AllocateRaw(0, 512 * sizeof(float));
VLOG(2) << "Should fail because we exceed expected_use_count.";
- EXPECT_EQ(nullptr, sa_instances_[0]->AllocateRaw(0, 512 * sizeof(DT_FLOAT)));
+ EXPECT_EQ(nullptr, sa_instances_[0]->AllocateRaw(0, 512 * sizeof(float)));
sa_instances_[0]->DeallocateRaw(ptr0);
sa_instances_[1]->DeallocateRaw(ptr1);
}
diff --git a/tensorflow/core/common_runtime/threadpool_device.cc b/tensorflow/core/common_runtime/threadpool_device.cc
index 5aa01376ab..6d8de6a3c0 100644
--- a/tensorflow/core/common_runtime/threadpool_device.cc
+++ b/tensorflow/core/common_runtime/threadpool_device.cc
@@ -16,6 +16,8 @@ limitations under the License.
#include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/common_runtime/local_device.h"
+#include "tensorflow/core/common_runtime/scoped_allocator.h"
+#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/allocator_registry.h"
#include "tensorflow/core/framework/device_base.h"
@@ -40,7 +42,8 @@ ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options,
Allocator* allocator)
: LocalDevice(options, Device::BuildDeviceAttributes(
name, DEVICE_CPU, memory_limit, locality)),
- allocator_(allocator) {}
+ allocator_(allocator),
+ scoped_allocator_mgr_(new ScopedAllocatorMgr(name)) {}
ThreadPoolDevice::~ThreadPoolDevice() {}
@@ -65,6 +68,17 @@ Allocator* ThreadPoolDevice::GetAllocator(AllocatorAttributes attr) {
return allocator_;
}
+Allocator* ThreadPoolDevice::GetScopedAllocator(AllocatorAttributes attr,
+ int64 step_id) {
+ if (attr.scope_id > 0) {
+ return scoped_allocator_mgr_->GetContainer(step_id)->GetInstance(
+ attr.scope_id);
+ }
+ LOG(FATAL) << "Unexpected call to ThreadPoolDevice::GetScopedAllocator "
+ << "attr.scope_id = " << attr.scope_id;
+ return allocator_;
+}
+
Status ThreadPoolDevice::MakeTensorFromProto(
const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs,
Tensor* tensor) {
diff --git a/tensorflow/core/common_runtime/threadpool_device.h b/tensorflow/core/common_runtime/threadpool_device.h
index 37cb745a0a..afc5d15ebc 100644
--- a/tensorflow/core/common_runtime/threadpool_device.h
+++ b/tensorflow/core/common_runtime/threadpool_device.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_THREADPOOL_DEVICE_H_
-#define TENSORFLOW_COMMON_RUNTIME_THREADPOOL_DEVICE_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_THREADPOOL_DEVICE_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_THREADPOOL_DEVICE_H_
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"
@@ -31,6 +31,11 @@ class ThreadPoolDevice : public LocalDevice {
void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
Allocator* GetAllocator(AllocatorAttributes attr) override;
+ Allocator* GetScopedAllocator(AllocatorAttributes attr,
+ int64 step_id) override;
+ ScopedAllocatorMgr* GetScopedAllocatorMgr() const override {
+ return scoped_allocator_mgr_.get();
+ }
Status MakeTensorFromProto(const TensorProto& tensor_proto,
const AllocatorAttributes alloc_attrs,
Tensor* tensor) override;
@@ -39,8 +44,9 @@ class ThreadPoolDevice : public LocalDevice {
private:
Allocator* allocator_; // Not owned
+ std::unique_ptr<ScopedAllocatorMgr> scoped_allocator_mgr_;
};
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_THREADPOOL_DEVICE_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_THREADPOOL_DEVICE_H_
diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h
index 3ce1b61246..2c87156dca 100644
--- a/tensorflow/core/framework/allocator.h
+++ b/tensorflow/core/framework/allocator.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_ALLOCATOR_H_
-#define TENSORFLOW_FRAMEWORK_ALLOCATOR_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_ALLOCATOR_H_
+#define TENSORFLOW_CORE_FRAMEWORK_ALLOCATOR_H_
#include <stdlib.h>
@@ -359,7 +359,12 @@ struct AllocatorAttributes {
bool nic_compatible() const { return value & (0x1 << 1); }
void set_gpu_compatible(bool v) { value |= (static_cast<int>(v) << 2); }
bool gpu_compatible() const { return value & (0x1 << 2); }
- void Merge(AllocatorAttributes other) { value |= other.value; }
+ void Merge(AllocatorAttributes other) {
+ value |= other.value;
+ scope_id = (scope_id > 0 && other.scope_id == 0)
+ ? scope_id
+ : ((scope_id == 0) ? other.scope_id : 0);
+ }
// Returns true if the fields set in *this is a subset of or equal to
// those set in other.
bool IsEqualOrLessRestrictiveThan(const AllocatorAttributes& other) const {
@@ -371,6 +376,9 @@ struct AllocatorAttributes {
// upper 8 bits in device-specific ways, and ops implemented for those
// devices are responsible for setting those 8 bits appropriately.
uint32 value = 0;
+ // EXPERIMENTAL: If this is greater than zero, then allocation is delegated to
+ // a named special-purpose allocator on the same device.
+ int32 scope_id = 0;
};
// Returns a trivial implementation of Allocator which uses the system
@@ -396,4 +404,4 @@ class SubAllocator {
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_ALLOCATOR_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_ALLOCATOR_H_
diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h
index fb6d5c69e1..52b9077d8c 100644
--- a/tensorflow/core/framework/device_base.h
+++ b/tensorflow/core/framework/device_base.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_DEVICE_BASE_H_
-#define TENSORFLOW_FRAMEWORK_DEVICE_BASE_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_DEVICE_BASE_H_
+#define TENSORFLOW_CORE_FRAMEWORK_DEVICE_BASE_H_
#include <memory>
#include <string>
@@ -48,6 +48,7 @@ class Env;
class EventMgr;
class OpKernelContext;
class ResourceMgr;
+class ScopedAllocatorMgr;
class TensorProto;
namespace thread {
@@ -179,6 +180,15 @@ class DeviceBase {
return GetAllocator(attr);
}
+ // Return an Allocator prepared for use in particular places by graph
+ // optimization
+ virtual Allocator* GetScopedAllocator(AllocatorAttributes attr,
+ int64 step_id) {
+ LOG(FATAL) << "Device does not implement GetScopedAllocator()";
+ }
+
+ virtual ScopedAllocatorMgr* GetScopedAllocatorMgr() const { return nullptr; }
+
virtual const Eigen::ThreadPoolDevice* eigen_cpu_device() {
CHECK(eigen_cpu_device_ != nullptr);
return eigen_cpu_device_;
@@ -243,4 +253,4 @@ class DeviceBase {
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_DEVICE_BASE_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_DEVICE_BASE_H_
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index 8654437059..9ec1c213c3 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -282,8 +282,13 @@ OpKernelContext::~OpKernelContext() {
}
Allocator* OpKernelContext::get_allocator(AllocatorAttributes attr) {
- Allocator* allocator =
- params_->device->GetStepAllocator(attr, resource_manager());
+ Allocator* allocator = nullptr;
+ if (attr.scope_id > 0) {
+ allocator = params_->device->GetScopedAllocator(attr, step_id());
+ CHECK(allocator);
+ } else {
+ allocator = params_->device->GetStepAllocator(attr, resource_manager());
+ }
if (track_allocations()) {
mutex_lock lock(mu_);
for (const auto& wrapped : wrapped_allocators_) {
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 2e39f25fc1..f6137fb860 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1667,6 +1667,43 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "scoped_allocator_ops",
+ prefix = "scoped_allocator_ops",
+ deps = [
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:scoped_allocator_ops_op_lib",
+ ],
+)
+
+tf_cuda_cc_test(
+ name = "scoped_allocator_ops_test",
+ srcs = ["scoped_allocator_ops_test.cc"],
+ linkstatic = tf_kernel_tests_linkstatic(), #Required for benchmarking
+ deps = [
+ ":cwise_op",
+ ":dense_update_ops",
+ ":ops_testutil",
+ ":ops_util",
+ ":scoped_allocator_ops",
+ ":variable_ops",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:math_ops_op_lib",
+ "//tensorflow/core:proto_text",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
+tf_kernel_library(
name = "session_ops",
prefix = "session_ops",
deps = DATA_FLOW_DEPS,
diff --git a/tensorflow/core/kernels/scoped_allocator_ops.cc b/tensorflow/core/kernels/scoped_allocator_ops.cc
new file mode 100644
index 0000000000..d7b25ffad0
--- /dev/null
+++ b/tensorflow/core/kernels/scoped_allocator_ops.cc
@@ -0,0 +1,216 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/scoped_allocator.h"
+#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+class ScopedAllocatorOp : public OpKernel {
+ public:
+ explicit ScopedAllocatorOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("T", &dtype_));
+ OP_REQUIRES_OK(context, context->GetAttr("shapes", &shapes_));
+ OP_REQUIRES_OK(context, context->GetAttr("sa_name", &name_));
+ OP_REQUIRES_OK(context, context->GetAttr("id", &id_));
+ OP_REQUIRES_OK(context, context->GetAttr("expected_call_count",
+ &expected_call_count_));
+ device_ = context->device();
+ // Precalculate the size of the backing tensor and the offsets of
+ // the subtensors to be allocated from it, taking into account
+ // alignment considerations.
+ ScopedAllocatorMgr::PopulateFields(id_, shapes_, dtype_, &fields_);
+ size_t num_bytes = fields_.back().offset + fields_.back().bytes;
+ num_elements_ = num_bytes / DataTypeSize(dtype_);
+ OP_REQUIRES(context, num_bytes % DataTypeSize(dtype_) == 0,
+ errors::InvalidArgument(
+ "Number of bytes ", num_bytes,
+ " must be divisible by size of datatype ", dtype_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ ScopedAllocatorMgr* sam = device_->GetScopedAllocatorMgr();
+ if (!sam) {
+ context->SetStatus(errors::Internal(
+ "ScopedAllocatorMgr not supported on device ", device_->name()));
+ return;
+ }
+ Tensor* backing_tensor = nullptr;
+ AllocatorAttributes attr = context->output_alloc_attr(0);
+ Status s =
+ context->allocate_output(0, {num_elements_}, &backing_tensor, attr);
+ VLOG(1) << "_ScopedAllocatorOp new backing tensor size "
+ << backing_tensor->TotalBytes() << " num_elements_ "
+ << num_elements_ << " buffer " << DMAHelper::buffer(backing_tensor)
+ << " base addr " << DMAHelper::base(backing_tensor);
+ if (s.ok()) {
+ s = sam->AddScopedAllocator(*backing_tensor, context->step_id(), id_,
+ name_, fields_, expected_call_count_);
+ }
+ if (!s.ok()) {
+ context->SetStatus(s);
+ }
+ }
+
+ private:
+ std::vector<TensorShape> shapes_;
+ DataType dtype_;
+ int64 num_elements_;
+ std::vector<ScopedAllocator::Field> fields_;
+ string name_;
+ int32 id_;
+ int32 expected_call_count_;
+ DeviceBase* device_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("_ScopedAllocator").Device(DEVICE_CPU),
+ ScopedAllocatorOp);
+
+REGISTER_KERNEL_BUILDER(Name("_ScopedAllocator").Device(DEVICE_GPU),
+ ScopedAllocatorOp);
+
+class ScopedAllocatorConcatOp : public OpKernel {
+ public:
+ explicit ScopedAllocatorConcatOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
+ OP_REQUIRES_OK(context, context->GetAttr("T", &dtype_));
+ // This stuff is just for debugging
+ OP_REQUIRES_OK(context, context->GetAttr("sa_name", &name_));
+ OP_REQUIRES_OK(context, context->GetAttr("id", &id_));
+ device_ = context->device();
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& backing_tensor = context->input(0);
+ // Check that type matches.
+ OP_REQUIRES(
+ context, backing_tensor.dtype() == dtype_,
+ errors::InvalidArgument("Backing tensor type ", backing_tensor.dtype(),
+ " does not match expected type ", dtype_));
+ // Check that backing tensor is at least as large as the shape of the
+ // output.
+ OP_REQUIRES(context, backing_tensor.NumElements() >= shape_.num_elements(),
+ errors::InvalidArgument("Backing tensor num elements ",
+ backing_tensor.NumElements(),
+ " is not equal to expected ",
+ shape_.num_elements()));
+ VLOG(1) << "_ScopedAllocatorConcatOp outputting backing tensor at "
+ << DMAHelper::base(&backing_tensor);
+ Tensor backing_copy(backing_tensor);
+ context->set_output(0, backing_copy);
+ const TensorBuffer* backing_buf = DMAHelper::buffer(&backing_copy);
+ const void* backing_tensor_lb = backing_buf->data();
+ const void* backing_tensor_ub = static_cast<const void*>(
+ static_cast<const char*>(backing_tensor_lb) + backing_buf->size());
+ // Check that all inputs lie entirely within the backing tensor.
+ for (int i = 1; i < context->num_inputs(); ++i) {
+ const TensorBuffer* input_buf = DMAHelper::buffer(&context->input(i));
+ const void* input_lb = input_buf->data();
+ OP_REQUIRES(
+ context, input_lb >= backing_tensor_lb,
+ errors::InvalidArgument("Lower bound check fail for input ", i,
+ " to node ", context->op_kernel().name()));
+ const void* input_ub = static_cast<const void*>(
+ static_cast<const char*>(input_lb) + input_buf->size());
+ OP_REQUIRES(
+ context, input_ub <= backing_tensor_ub,
+ errors::InvalidArgument("Upper bound check fail for input ", i,
+ " to node ", context->op_kernel().name()));
+ }
+ }
+
+ private:
+ TensorShape shape_;
+ DataType dtype_;
+ string name_;
+ int32 id_;
+ DeviceBase* device_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("_ScopedAllocatorConcat").Device(DEVICE_CPU),
+ ScopedAllocatorConcatOp);
+
+REGISTER_KERNEL_BUILDER(Name("_ScopedAllocatorConcat").Device(DEVICE_GPU),
+ ScopedAllocatorConcatOp);
+
+class ScopedAllocatorSplitOp : public OpKernel {
+ public:
+ explicit ScopedAllocatorSplitOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("T", &dtype_));
+ // This stuff is just for debugging
+ OP_REQUIRES_OK(context, context->GetAttr("sa_name", &name_));
+ OP_REQUIRES_OK(context, context->GetAttr("id", &id_));
+ device_ = context->device();
+ }
+
+ void Compute(OpKernelContext* context) override {
+ Tensor backing_copy(context->input(0));
+ // Check that type matches.
+ OP_REQUIRES(
+ context, backing_copy.dtype() == dtype_,
+ errors::InvalidArgument("Backing tensor type ", backing_copy.dtype(),
+ " does not match expected type ", dtype_));
+ const TensorBuffer* backing_buf = DMAHelper::buffer(&backing_copy);
+ const void* backing_tensor_lb = backing_buf->data();
+ const void* backing_tensor_ub = static_cast<const void*>(
+ static_cast<const char*>(backing_tensor_lb) + backing_buf->size());
+ for (int i = 1; i < context->num_inputs(); ++i) {
+ VLOG(1) << "_ScopedAllocatorSplitOp assigning input " << i
+ << " to output " << i - 1 << " buf addr "
+ << DMAHelper::base(&context->input(i));
+ Tensor copy(context->input(i));
+ OP_REQUIRES(
+ context, copy.dtype() == dtype_,
+ errors::InvalidArgument("Input ", i, " tensor type ", copy.dtype(),
+ " does not match expected type ", dtype_));
+ context->set_output(i - 1, copy);
+ const TensorBuffer* input_buf = DMAHelper::buffer(&copy);
+ const void* input_lb = input_buf->data();
+ OP_REQUIRES(
+ context, input_lb >= backing_tensor_lb,
+ errors::InvalidArgument("Lower bound check fail for input ", i,
+ " to node ", context->op_kernel().name()));
+ const void* input_ub = static_cast<const void*>(
+ static_cast<const char*>(input_lb) + input_buf->size());
+ OP_REQUIRES(
+ context, input_ub <= backing_tensor_ub,
+ errors::InvalidArgument("Upper bound check fail for input ", i,
+ " to node ", context->op_kernel().name()));
+ }
+ }
+
+ private:
+ DataType dtype_;
+ string name_;
+ int32 id_;
+ DeviceBase* device_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("_ScopedAllocatorSplit").Device(DEVICE_CPU),
+ ScopedAllocatorSplitOp);
+
+REGISTER_KERNEL_BUILDER(Name("_ScopedAllocatorSplit").Device(DEVICE_GPU),
+ ScopedAllocatorSplitOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/scoped_allocator_ops_test.cc b/tensorflow/core/kernels/scoped_allocator_ops_test.cc
new file mode 100644
index 0000000000..3d36c8b7d4
--- /dev/null
+++ b/tensorflow/core/kernels/scoped_allocator_ops_test.cc
@@ -0,0 +1,296 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/common_runtime/scoped_allocator.h"
+#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+class ScopedAllocatorOpTest : public OpsTestBase {
+ protected:
+ void MakeOp(const gtl::ArraySlice<TensorShape>& shapes, DataType dtype,
+ const string& name, int32 id, int32 expected_call_count) {
+ TF_EXPECT_OK(NodeDefBuilder("scoped_allocator_op", "_ScopedAllocator")
+ .Attr("T", dtype)
+ .Attr("shapes", shapes)
+ .Attr("sa_name", name)
+ .Attr("id", id)
+ .Attr("expected_call_count", expected_call_count)
+ .Finalize(node_def()));
+ TF_EXPECT_OK(InitOp());
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Allocate and Deallocate the tensors so that memory is not leaked
+ AllocatorAttributes attr;
+ Allocator* allocator;
+ for (size_t i = 0; i < shapes.size(); i++) {
+ attr.scope_id = id + i + 1;
+ allocator = device_->GetScopedAllocator(attr, context_->step_id());
+ Tensor temp(allocator, dtype, shapes[i]);
+ }
+ }
+};
+
+TEST_F(ScopedAllocatorOpTest, Simple) {
+ MakeOp({TensorShape({8})}, DT_FLOAT, "test", 120, 1);
+ MakeOp({TensorShape({32, 32})}, DT_DOUBLE, "test1", 130, 1);
+ MakeOp({TensorShape({64}), TensorShape({3, 3}), TensorShape({5, 5, 5})},
+ DT_HALF, "test2", 140, 3);
+ MakeOp({TensorShape({512}), TensorShape({64, 8})}, DT_UINT32, "test3", 150,
+ 2);
+}
+
+// PrepOp is common to ConcatOp tests and SplitOpTests.
+// It allocates a backing tensor that is large enough to hold all slices defined
+// by fields, creates ScopedAllocatorInstances for each field, allocates the
+// tensors, and assigns them as inputs to the op.
+// We won't use the AddInput* suite of functions from ops_testutil.h because
+// they allocate new tensors for each input. We need to mimic what a
+// ScopedAllocator would do.
+void PrepOp(DataType dtype, int32 id,
+ const std::vector<TensorShape>& fields_shapes,
+ std::vector<ScopedAllocator::Field>* fields,
+ Tensor** backing_tensor, Allocator* allocator,
+ ScopedAllocatorMgr* sam, const string& op_name,
+ std::vector<Tensor>* tensors,
+ gtl::InlinedVector<TensorValue, 4>* inputs,
+ const DataTypeVector& input_types) {
+ ScopedAllocatorMgr::PopulateFields(id, fields_shapes, dtype, fields);
+ // We don't simply allocate a tensor with shape as backing_tensor_shape,
+ // because we need to account for padding in the fields. We actually need a
+ // tensor of size at least (fields[-1].offset + fields[-1].bytes).
+ size_t num_bytes = fields->back().offset + fields->back().bytes;
+ int32_t num_elements = num_bytes / DataTypeSize(dtype);
+ CHECK_EQ(num_bytes % DataTypeSize(dtype), 0);
+
+ *backing_tensor = new Tensor(allocator, dtype, {num_elements});
+ int64 step_id = 10;
+ Status s = sam->AddScopedAllocator(**backing_tensor, step_id, id,
+ "sa_" + op_name + "_test", *fields,
+ fields_shapes.size());
+ TF_ASSERT_OK(s);
+
+ ScopedAllocatorContainer* sac = sam->GetContainer(step_id);
+ std::vector<ScopedAllocatorInstance*> sa_instances(fields_shapes.size(),
+ nullptr);
+ for (size_t i = 0; i < fields_shapes.size(); i++) {
+ sa_instances[i] = sac->GetInstance(id + i + 1);
+ tensors->push_back(Tensor(sa_instances[i], dtype, fields_shapes[i]));
+ }
+ // Now add the tensor as an input to ScopedAllocator<op_name>Op.
+ // Order matters here, so first add the backing tensor, then the slices.
+ inputs->reserve(1 + tensors->size());
+ CHECK_GT(input_types.size(), inputs->size());
+ CHECK_EQ(input_types[inputs->size()], dtype);
+ inputs->push_back({nullptr, *backing_tensor});
+ for (size_t i = 0; i < tensors->size(); i++) {
+ CHECK_EQ(input_types[inputs->size()], dtype);
+ inputs->push_back({nullptr, &((*tensors)[i])});
+ }
+}
+
+class ScopedAllocatorConcatOpTest : public OpsTestBase {
+ protected:
+ void MakeOp(const TensorShape& shape, DataType dtype, const string& name,
+ int32 id, int32 num_tensors) {
+ TF_EXPECT_OK(
+ NodeDefBuilder("scoped_allocator_concat_op", "_ScopedAllocatorConcat")
+ .Attr("shape", shape)
+ .Attr("T", dtype)
+ .Attr("N", num_tensors)
+ .Attr("sa_name", name)
+ .Attr("id", id)
+ .Input(FakeInput(dtype)) // backing tensor
+ .Input(FakeInput(num_tensors, dtype)) // list of tensors
+ .Finalize(node_def()));
+ TF_EXPECT_OK(InitOp());
+ }
+
+ void ExecOp(DataType dtype, int32 id,
+ const std::vector<TensorShape>& fields_shapes) {
+ Tensor* backing_tensor = nullptr;
+ std::vector<Tensor> tensors;
+ std::vector<ScopedAllocator::Field> fields;
+ PrepOp(dtype, id, fields_shapes, &fields, &backing_tensor, allocator(),
+ device_->GetScopedAllocatorMgr(), "split", &tensors, &inputs_,
+ input_types_);
+
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check input and output are same tensor.
+ const Tensor& input = context_->input(0);
+ OpOutputList output_list;
+ Status s = context_->output_list("output", &output_list);
+ TF_ASSERT_OK(s);
+ const Tensor& output = *(output_list[0]);
+ CHECK_EQ(DMAHelper::base(&input), DMAHelper::base(&output));
+ CHECK_EQ(input.dtype(), output.dtype());
+ CHECK_EQ(input.NumElements(), output.NumElements());
+
+ // Free the backing tensor which was allocated in PrepOp.
+ delete backing_tensor;
+ }
+};
+
+TEST_F(ScopedAllocatorConcatOpTest, Success1) {
+ MakeOp({32}, DT_FLOAT, "test", 120, 2);
+ ExecOp(DT_FLOAT, 120, {{16}, {16}});
+}
+
+TEST_F(ScopedAllocatorConcatOpTest, Success2) {
+ MakeOp({2, 2, 2}, DT_DOUBLE, "test", 120, 2);
+ ExecOp(DT_DOUBLE, 120, {{2, 2}, {2, 2}});
+}
+
+TEST_F(ScopedAllocatorConcatOpTest, Success3) {
+ MakeOp({3, 3, 3}, DT_HALF, "test", 120, 3);
+ ExecOp(DT_HALF, 120, {{3, 3}, {3, 3}, {3, 3}});
+}
+
+TEST_F(ScopedAllocatorConcatOpTest, FailDtypeCheck) {
+ MakeOp({8}, DT_FLOAT, "test", 120, 2);
+ EXPECT_DEATH(ExecOp(DT_DOUBLE, 120, {{4}, {4}}), "");
+}
+
+TEST_F(ScopedAllocatorConcatOpTest, FailNumElementsCheck) {
+ MakeOp({32}, DT_FLOAT, "test", 120, 2);
+ AddInputFromArray<float>({8}, {0, 1, 2, 3, 4, 5, 6, 7});
+ AddInputFromArray<float>({4}, {0, 1, 2, 3});
+ AddInputFromArray<float>({4}, {4, 5, 6, 7});
+ Status s = RunOpKernel();
+ EXPECT_EQ(s.code(), error::INVALID_ARGUMENT);
+}
+
+// This test should fail because the backing tensor and the input tensors are
+// unrelated, i.e. the inputs are not slices of the backing tensor.
+TEST_F(ScopedAllocatorConcatOpTest, FailBounds) {
+ MakeOp({8}, DT_DOUBLE, "test", 120, 2);
+ AddInputFromArray<double>({8}, {0, 1, 2, 3, 4, 5, 6, 7});
+ AddInputFromArray<double>({4}, {0, 1, 2, 3});
+ AddInputFromArray<double>({4}, {4, 5, 6, 7});
+ Status s = RunOpKernel();
+ EXPECT_EQ(s.code(), error::INVALID_ARGUMENT);
+}
+
+class ScopedAllocatorSplitOpTest : public OpsTestBase {
+ protected:
+ void BuildNodeDef(const TensorShape& shape, DataType dtype,
+ const string& name, int32 id, int32 num_tensors) {
+ TF_EXPECT_OK(
+ NodeDefBuilder("scoped_allocator_split_op", "_ScopedAllocatorSplit")
+ .Attr("T", dtype)
+ .Attr("N", num_tensors)
+ .Attr("sa_name", name)
+ .Attr("id", id)
+ .Input(FakeInput(dtype)) // backing tensor and input
+ .Input(
+ FakeInput(num_tensors, dtype)) // list of subtensors to forward
+ .Finalize(node_def()));
+ }
+
+ void MakeOp(const TensorShape& shape, DataType dtype, const string& name,
+ int32 id, int32 num_tensors) {
+ BuildNodeDef(shape, dtype, name, id, num_tensors);
+ TF_EXPECT_OK(InitOp());
+ }
+
+ // Similar to ConcatOpTest, we add inputs that are allocated from
+ // ScopedAllocator so that the memory lines up nicely.
+ void ExecOp(DataType dtype, int32 id,
+ const std::vector<TensorShape>& fields_shapes) {
+ Tensor* backing_tensor = nullptr;
+ std::vector<Tensor> tensors;
+ std::vector<ScopedAllocator::Field> fields;
+ PrepOp(dtype, id, fields_shapes, &fields, &backing_tensor, allocator(),
+ device_->GetScopedAllocatorMgr(), "split", &tensors, &inputs_,
+ input_types_);
+
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check that outputs are slices of backing tensor.
+ const Tensor& input = context_->input(0);
+ const void* lower_limit = DMAHelper::base(&input);
+ const char* lower_limit_c =
+ static_cast<const char*>(lower_limit); // for pointer arithmetic
+ OpOutputList output_list;
+ Status s = context_->output_list("output", &output_list);
+ TF_ASSERT_OK(s);
+ for (int i = 0; i < output_list.size(); i++) {
+ const Tensor& output = *(output_list[i]);
+ const void* expected_base =
+ static_cast<const void*>(lower_limit_c + fields[i].offset);
+ CHECK_EQ(output.dtype(), input.dtype());
+ CHECK_EQ(expected_base, DMAHelper::base(&output));
+ CHECK_EQ(output.NumElements(), fields_shapes[i].num_elements());
+ }
+
+ // Free the backing tensor which was allocated in PrepOp.
+ delete backing_tensor;
+ }
+};
+
+TEST_F(ScopedAllocatorSplitOpTest, Success1) {
+ MakeOp({32}, DT_FLOAT, "test", 120, 2);
+ ExecOp(DT_FLOAT, 120, {{16}, {16}});
+}
+
+TEST_F(ScopedAllocatorSplitOpTest, Success2) {
+ MakeOp({2, 2, 2}, DT_DOUBLE, "test", 120, 2);
+ ExecOp(DT_DOUBLE, 120, {{2, 2}, {2, 2}});
+}
+
+TEST_F(ScopedAllocatorSplitOpTest, Success3) {
+ MakeOp({3, 3, 3}, DT_HALF, "test", 120, 3);
+ ExecOp(DT_HALF, 120, {{3, 3}, {3, 3}, {3, 3}});
+}
+
+TEST_F(ScopedAllocatorSplitOpTest, FailNLessThan2) {
+ BuildNodeDef({4, 4}, DT_FLOAT, "test", 120, 1);
+ Status s = InitOp();
+ EXPECT_EQ(s.code(), error::INVALID_ARGUMENT);
+}
+
+TEST_F(ScopedAllocatorSplitOpTest, FailDtypeCheck) {
+ MakeOp({8}, DT_FLOAT, "test", 120, 2);
+ EXPECT_DEATH(ExecOp(DT_HALF, 120, {{4}, {4}}), "");
+}
+
+TEST_F(ScopedAllocatorSplitOpTest, FailBounds) {
+ MakeOp({8}, DT_DOUBLE, "test", 120, 2);
+ AddInputFromArray<double>({8}, {0, 1, 2, 3, 4, 5, 6, 7});
+ AddInputFromArray<double>({4}, {0, 1, 2, 3});
+ AddInputFromArray<double>({4}, {4, 5, 6, 7});
+ Status s = RunOpKernel();
+ EXPECT_EQ(s.code(), error::INVALID_ARGUMENT);
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/ops/scoped_allocator_ops.cc b/tensorflow/core/ops/scoped_allocator_ops.cc
new file mode 100644
index 0000000000..f053a53f4c
--- /dev/null
+++ b/tensorflow/core/ops/scoped_allocator_ops.cc
@@ -0,0 +1,81 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("_ScopedAllocator")
+ .Output("output: T")
+ .Attr("shapes: list(shape)")
+ .Attr("T: type")
+ .Attr("sa_name: string")
+ .Attr("id: int")
+ .Attr("expected_call_count: int")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::ExplicitShape)
+ .Doc(R"doc(
+Allocates a mutable tensor that becomes available to appropriately annotated
+downstream Ops as backing store for their output tensor allocations via the
+ScopedAllocatorMgr.
+Returns a reference to this value.
+
+This is an experimental op for internal use only. It is possible to use this
+op in unsafe ways.
+)doc");
+
+REGISTER_OP("_ScopedAllocatorConcat")
+ .Output("output: T")
+ .Input("backing: T")
+ .Input("inputs: N * T")
+ .Attr("shape: shape")
+ .Attr("T: type")
+ .Attr("sa_name: string")
+ .Attr("id: int")
+ .Attr("N: int >= 2")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::ExplicitShape)
+ .Doc(R"doc(
+Acts like a Concat Op that merges multple tensors into one, however it must
+only be used in conjunction with a ScopedAllocator which is backing the memory
+of all of its input tensors so that actually it just outputs a read-only
+reference to that ScopedAllocator's backing tensor.
+
+This is an experimental op for internal use only. It is possible to use this
+op in unsafe ways.
+)doc");
+
+REGISTER_OP("_ScopedAllocatorSplit")
+ .Output("output: N * T")
+ .Input("concat: T")
+ .Input("split: N * T")
+ .Attr("T: type")
+ .Attr("sa_name: string")
+ .Attr("id: int")
+ .Attr("N: int >= 2")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::ExplicitShape)
+ .Doc(R"doc(
+Acts like a Concat Op that merges multple tensors into one, however it must
+only be used in conjunction with a ScopedAllocator which is backing the memory
+of all of its input tensors so that actually it just outputs a read-only
+reference to that ScopedAllocator's backing tensor.
+
+This is an experimental op for internal use only. It is possible to use this
+op in unsafe ways.
+)doc");
+
+} // end namespace tensorflow