Diffstat (limited to 'tensorflow/core/common_runtime')
-rw-r--r--  tensorflow/core/common_runtime/device.cc | 37
-rw-r--r--  tensorflow/core/common_runtime/device.h | 128
-rw-r--r--  tensorflow/core/common_runtime/device_factory.cc | 106
-rw-r--r--  tensorflow/core/common_runtime/device_factory.h | 69
-rw-r--r--  tensorflow/core/common_runtime/device_mgr.cc | 90
-rw-r--r--  tensorflow/core/common_runtime/device_mgr.h | 55
-rw-r--r--  tensorflow/core/common_runtime/device_set.cc | 68
-rw-r--r--  tensorflow/core/common_runtime/device_set.h | 64
-rw-r--r--  tensorflow/core/common_runtime/device_set_test.cc | 65
-rw-r--r--  tensorflow/core/common_runtime/eigen_thread_pool.h | 22
-rw-r--r--  tensorflow/core/common_runtime/executor.cc | 2118
-rw-r--r--  tensorflow/core/common_runtime/executor.h | 209
-rw-r--r--  tensorflow/core/common_runtime/function.cc | 1335
-rw-r--r--  tensorflow/core/common_runtime/function.h | 100
-rw-r--r--  tensorflow/core/common_runtime/gpu/dma_helper.h | 18
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_allocator_retry.cc | 49
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_allocator_retry.h | 36
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_allocator_retry_test.cc | 175
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc | 397
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h | 156
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc | 166
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc | 186
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h | 68
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc | 207
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_device.cc | 651
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_device.h | 94
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_device_factory.cc | 52
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc | 132
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_event_mgr.h | 118
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc | 152
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_init.cc | 147
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_init.h | 19
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_region_allocator.cc | 371
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_region_allocator.h | 146
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_region_allocator_test.cc | 71
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_stream_util.cc | 97
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_stream_util.h | 30
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_stream_util_test.cc | 137
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_util.cc | 345
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_util.h | 89
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc | 24
-rw-r--r--  tensorflow/core/common_runtime/gpu/pool_allocator.cc | 269
-rw-r--r--  tensorflow/core/common_runtime/gpu/pool_allocator.h | 202
-rw-r--r--  tensorflow/core/common_runtime/gpu/pool_allocator_test.cc | 203
-rw-r--r--  tensorflow/core/common_runtime/gpu/process_state.cc | 220
-rw-r--r--  tensorflow/core/common_runtime/gpu/process_state.h | 140
-rw-r--r--  tensorflow/core/common_runtime/gpu/visitable_allocator.h | 30
-rw-r--r--  tensorflow/core/common_runtime/gpu_device_context.h | 45
-rw-r--r--  tensorflow/core/common_runtime/kernel_benchmark_testlib.cc | 160
-rw-r--r--  tensorflow/core/common_runtime/kernel_benchmark_testlib.h | 52
-rw-r--r--  tensorflow/core/common_runtime/local_device.cc | 51
-rw-r--r--  tensorflow/core/common_runtime/local_device.h | 27
-rw-r--r--  tensorflow/core/common_runtime/local_session.cc | 500
-rw-r--r--  tensorflow/core/common_runtime/local_session.h | 109
-rw-r--r--  tensorflow/core/common_runtime/local_session_test.cc | 314
-rw-r--r--  tensorflow/core/common_runtime/rendezvous_mgr.cc | 170
-rw-r--r--  tensorflow/core/common_runtime/rendezvous_mgr.h | 73
-rw-r--r--  tensorflow/core/common_runtime/session.cc | 51
-rw-r--r--  tensorflow/core/common_runtime/session_factory.cc | 41
-rw-r--r--  tensorflow/core/common_runtime/session_factory.h | 25
-rw-r--r--  tensorflow/core/common_runtime/session_options.cc | 9
-rw-r--r--  tensorflow/core/common_runtime/session_test.cc | 17
-rw-r--r--  tensorflow/core/common_runtime/simple_placer.cc | 559
-rw-r--r--  tensorflow/core/common_runtime/simple_placer.h | 81
-rw-r--r--  tensorflow/core/common_runtime/simple_placer_test.cc | 863
-rw-r--r--  tensorflow/core/common_runtime/threadpool_device.cc | 55
-rw-r--r--  tensorflow/core/common_runtime/threadpool_device.h | 31
-rw-r--r--  tensorflow/core/common_runtime/threadpool_device_factory.cc | 31
68 files changed, 12927 insertions, 0 deletions
diff --git a/tensorflow/core/common_runtime/device.cc b/tensorflow/core/common_runtime/device.cc
new file mode 100644
index 0000000000..2e3e7b6597
--- /dev/null
+++ b/tensorflow/core/common_runtime/device.cc
@@ -0,0 +1,37 @@
+#include "tensorflow/core/common_runtime/device.h"
+
+#include "tensorflow/core/framework/op_segment.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+Device::Device(Env* env, const DeviceAttributes& device_attributes,
+ Allocator* device_allocator)
+ : DeviceBase(env), device_attributes_(device_attributes) {
+ CHECK(DeviceNameUtils::ParseFullName(name(), &parsed_name_))
+ << "Invalid device name: " << name();
+ rmgr_ = new ResourceMgr(parsed_name_.job);
+}
+
+Device::~Device() { delete rmgr_; }
+
+// static
+DeviceAttributes Device::BuildDeviceAttributes(
+ const string& name, DeviceType device, Bytes memory_limit,
+ BusAdjacency bus_adjacency, const string& physical_device_desc) {
+ DeviceAttributes da;
+ da.set_name(name);
+ do {
+ da.set_incarnation(random::New64());
+ } while (da.incarnation() == 0); // This proto field must not be zero
+ da.set_device_type(device.type());
+ da.set_memory_limit(memory_limit.value());
+ da.set_bus_adjacency(bus_adjacency);
+ da.set_physical_device_desc(physical_device_desc);
+ return da;
+}
+
+} // namespace tensorflow
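
The helper above guarantees a non-zero incarnation and packs the remaining fields into a DeviceAttributes proto. Below is a minimal sketch (not part of the commit) of how it might be called; the BUS_ANY adjacency value and the Bytes wrapper are assumed to come from device_attributes.pb.h and graph/types.h in this revision, and the names are purely illustrative.

#include "tensorflow/core/common_runtime/device.h"

namespace tensorflow {
namespace example {  // hypothetical namespace, for illustration only

// Builds attributes for a CPU device whose name follows the
// /job:___/replica:___/task:___/(gpu|cpu):___ convention from device.h.
DeviceAttributes CpuAttributesSketch() {
  return Device::BuildDeviceAttributes(
      "/job:localhost/replica:0/task:0/cpu:0", DeviceType(DEVICE_CPU),
      Bytes(256 << 20), BUS_ANY, "example physical CPU");
}

}  // namespace example
}  // namespace tensorflow
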
diff --git a/tensorflow/core/common_runtime/device.h b/tensorflow/core/common_runtime/device.h
new file mode 100644
index 0000000000..ff3404fea4
--- /dev/null
+++ b/tensorflow/core/common_runtime/device.h
@@ -0,0 +1,128 @@
+// A Device is something that can perform computations as part of a
+// model. Devices can be local (runs computation on this machine), or
+// remote (contacts a device local to another machine using an RPC to
+// do the work). Devices are registered in a DeviceSet, which is also
+// responsible for the Device <-> id mapping.
+//
+// Device names
+// * Every Device should have a unique name with the format:
+// /job:___/replica:___/task:___/(gpu|cpu):___
+// An example name would be "/job:train/replica:0/task:3/gpu:2".
+// * Task numbers are within the specified replica, so there are as
+// many "task zeros" as replicas.
+
+#ifndef TENSORFLOW_COMMON_RUNTIME_DEVICE_H_
+#define TENSORFLOW_COMMON_RUNTIME_DEVICE_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/control_flow.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/op_segment.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/types.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+
+class Device : public DeviceBase {
+ public:
+ Device(Env* env, const DeviceAttributes& device_attributes,
+ Allocator* device_allocator);
+ ~Device() override;
+
+ // Full name of this device (see top comment).
+ const string& name() const { return device_attributes_.name(); }
+
+ // Parsed name of this device
+ const DeviceNameUtils::ParsedName parsed_name() const { return parsed_name_; }
+
+ // Describes what kind of device this is. This is intended to be
+ // human-readable and not computer-parsed, except that two devices
+ // with the same device_type() are expected to perform similarly
+ // (both from a computation and communication perspective).
+ const string& device_type() const { return device_attributes_.device_type(); }
+
+ // Returns an aggregation of device attributes.
+ const DeviceAttributes& attributes() const override {
+ return device_attributes_;
+ }
+
+ // Performs the actual compute function.
+ //
+ // Subclasses may override this function if they wish to perform
+ // some initialization before each compute.
+ virtual void Compute(OpKernel* op_kernel, OpKernelContext* context) {
+ op_kernel->Compute(context);
+ }
+
+ // Asynchronous kernel's compute.
+ virtual void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
+ AsyncOpKernel::DoneCallback done) {
+ op_kernel->ComputeAsync(context, done);
+ }
+
+ // Blocks until all operations queued on the device at the time of
+ // the call have completed. Returns any error pending on the device
+ // at completion.
+ virtual Status Sync() = 0;
+
+ // Fill in the context map for the graph. Default behavior is to do
+ // nothing.
+ //
+ // The caller takes ownership over the DeviceContext objects given
+ // by the device.
+ virtual Status FillContextMap(const Graph* graph,
+ DeviceContextMap* device_context_map) {
+ return Status::OK();
+ }
+
+ // Returns the op segment of this device. The caller can reuse op
+ // kernels registered for the same session running on this device.
+ OpSegment* op_segment() { return &op_seg_; }
+
+ // Returns the resource manager associated w/ this device.
+ ResourceMgr* resource_manager() { return rmgr_; }
+
+ // Summarizes the status of this Device, for debugging.
+ string DebugString() const { return device_attributes_.DebugString(); }
+
+ // Assembles the parameter components into a complete DeviceAttributes value.
+ static DeviceAttributes BuildDeviceAttributes(
+ const string& name, DeviceType device, Bytes memory_limit,
+ BusAdjacency bus_adjacency, const string& physical_device_desc);
+
+ static DeviceAttributes BuildDeviceAttributes(const string& name,
+ DeviceType device,
+ Bytes memory_limit,
+ BusAdjacency bus_adjacency) {
+ // Pass in an empty string as physical device name.
+ return BuildDeviceAttributes(name, device, memory_limit, bus_adjacency, "");
+ }
+
+ private:
+ const DeviceAttributes device_attributes_;
+ DeviceNameUtils::ParsedName parsed_name_;
+
+ // op_seg_ maps session handle and op name to OpKernel objects.
+ OpSegment op_seg_;
+
+ // Resources associated w/ this device. E.g., shared variables, etc.
+ ResourceMgr* rmgr_ = nullptr;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Device);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_DEVICE_H_
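
Sync() is the only pure virtual member of the class above, so a stub subclass needs very little. The sketch below (not part of the commit) mirrors the FakeDevice defined later in device_set_test.cc and assumes DeviceBase's GetAllocator(AllocatorAttributes) override point.

#include "tensorflow/core/common_runtime/device.h"

namespace tensorflow {
namespace example {  // hypothetical, for illustration only

// A do-nothing Device: just enough to be added to a DeviceSet or DeviceMgr.
class NullDevice : public Device {
 public:
  NullDevice(Env* env, const DeviceAttributes& attr)
      : Device(env, attr, /*device_allocator=*/nullptr) {}
  Status Sync() override { return Status::OK(); }
  Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
};

}  // namespace example
}  // namespace tensorflow
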
diff --git a/tensorflow/core/common_runtime/device_factory.cc b/tensorflow/core/common_runtime/device_factory.cc
new file mode 100644
index 0000000000..7d391bde1d
--- /dev/null
+++ b/tensorflow/core/common_runtime/device_factory.cc
@@ -0,0 +1,106 @@
+#include "tensorflow/core/common_runtime/device_factory.h"
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+
+namespace {
+
+static mutex* get_device_factory_lock() {
+ static mutex device_factory_lock;
+ return &device_factory_lock;
+}
+
+struct FactoryItem {
+ std::unique_ptr<DeviceFactory> factory;
+ int priority;
+};
+
+std::unordered_map<string, FactoryItem>& device_factories() {
+ static std::unordered_map<string, FactoryItem>* factories =
+ new std::unordered_map<string, FactoryItem>;
+ return *factories;
+}
+} // namespace
+
+void DeviceFactory::Register(const string& device_type, DeviceFactory* factory,
+ int priority) {
+ mutex_lock l(*get_device_factory_lock());
+ std::unique_ptr<DeviceFactory> factory_ptr(factory);
+ std::unordered_map<string, FactoryItem>& factories = device_factories();
+ auto iter = factories.find(device_type);
+ if (iter == factories.end()) {
+ factories[device_type] = {std::move(factory_ptr), priority};
+ } else {
+ if (iter->second.priority < priority) {
+ iter->second = {std::move(factory_ptr), priority};
+ } else if (iter->second.priority == priority) {
+ LOG(FATAL) << "Duplicate registration of device factory for type "
+ << device_type << " with the same priority " << priority;
+ }
+ }
+}
+
+DeviceFactory* DeviceFactory::GetFactory(const string& device_type) {
+ mutex_lock l(*get_device_factory_lock()); // could use reader lock
+ auto it = device_factories().find(device_type);
+ if (it == device_factories().end()) {
+ return nullptr;
+ }
+ return it->second.factory.get();
+}
+
+void DeviceFactory::AddDevices(const SessionOptions& options,
+ const string& name_prefix,
+ std::vector<Device*>* devices) {
+ // CPU first.
+ auto cpu_factory = GetFactory("CPU");
+ if (!cpu_factory) {
+ LOG(FATAL)
+ << "CPU Factory not registered. Did you link in threadpool_device?";
+ }
+ size_t init_size = devices->size();
+ cpu_factory->CreateDevices(options, name_prefix, devices);
+ if (devices->size() == init_size) {
+ LOG(FATAL) << "No CPU devices are available in this process";
+ }
+
+ // Then GPU.
+ auto gpu_factory = GetFactory("GPU");
+ if (gpu_factory) {
+ gpu_factory->CreateDevices(options, name_prefix, devices);
+ }
+
+ // Then the rest.
+ mutex_lock l(*get_device_factory_lock());
+ for (auto& p : device_factories()) {
+ auto factory = p.second.factory.get();
+ if (factory != cpu_factory && factory != gpu_factory) {
+ factory->CreateDevices(options, name_prefix, devices);
+ }
+ }
+}
+
+Device* DeviceFactory::NewDevice(const string& type,
+ const SessionOptions& options,
+ const string& name_prefix) {
+ auto device_factory = GetFactory(type);
+ if (!device_factory) {
+ return nullptr;
+ }
+ SessionOptions opt = options;
+ (*opt.config.mutable_device_count())[type] = 1;
+ std::vector<Device*> devices;
+ device_factory->CreateDevices(opt, name_prefix, &devices);
+ CHECK_EQ(devices.size(), 1);
+ return devices[0];
+}
+
+} // namespace tensorflow
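
NewDevice copies the SessionOptions and forces the per-type device count to one, which is what makes it convenient for tests. A hypothetical caller might look like the sketch below (not part of the commit).

#include <memory>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
namespace example {  // hypothetical, for illustration only

// Returns exactly one CPU device, or nullptr if no "CPU" factory was linked in.
std::unique_ptr<Device> OneCpuForTest() {
  SessionOptions options;
  return std::unique_ptr<Device>(DeviceFactory::NewDevice(
      "CPU", options, "/job:localhost/replica:0/task:0"));
}

}  // namespace example
}  // namespace tensorflow
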
diff --git a/tensorflow/core/common_runtime/device_factory.h b/tensorflow/core/common_runtime/device_factory.h
new file mode 100644
index 0000000000..57b625b3e5
--- /dev/null
+++ b/tensorflow/core/common_runtime/device_factory.h
@@ -0,0 +1,69 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_DEVICE_FACTORY_H_
+#define TENSORFLOW_COMMON_RUNTIME_DEVICE_FACTORY_H_
+
+#include <string>
+#include <vector>
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+class Device;
+struct SessionOptions;
+
+class DeviceFactory {
+ public:
+ virtual ~DeviceFactory() {}
+ static void Register(const string& device_type, DeviceFactory* factory,
+ int priority);
+ static DeviceFactory* GetFactory(const string& device_type);
+
+ // Append to "*devices" all suitable devices, respecting
+ // any device type specific properties/counts listed in "options".
+ //
+ // CPU devices are added first.
+ static void AddDevices(const SessionOptions& options,
+ const string& name_prefix,
+ std::vector<Device*>* devices);
+
+ // Helper for tests. Create a single device of type "type". The
+ // returned device is always numbered zero, so if creating multiple
+ // devices of the same type, supply distinct name_prefix arguments.
+ static Device* NewDevice(const string& type, const SessionOptions& options,
+ const string& name_prefix);
+
+ // Most clients should call AddDevices() instead.
+ virtual void CreateDevices(const SessionOptions& options,
+ const string& name_prefix,
+ std::vector<Device*>* devices) = 0;
+};
+
+namespace dfactory {
+
+template <class Factory>
+class Registrar {
+ public:
+ // Multiple registrations for the same device type with different priorities
+ // are allowed. The registration with the highest priority will be used.
+ explicit Registrar(const string& device_type, int priority = 0) {
+ DeviceFactory::Register(device_type, new Factory(), priority);
+ }
+};
+
+} // namespace dfactory
+
+#define REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, ...) \
+ INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, \
+ __COUNTER__, ##__VA_ARGS__)
+
+#define INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, \
+ ctr, ...) \
+ static ::tensorflow::dfactory::Registrar<device_factory> \
+ INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY_NAME(ctr)(device_type, \
+ ##__VA_ARGS__)
+
+// __COUNTER__ must go through another macro to be properly expanded
+#define INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY_NAME(ctr) ___##ctr##__object_
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_DEVICE_FACTORY_H_
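
A factory typically registers itself at static-initialization time through the macro above; threadpool_device_factory.cc in this commit presumably does something along these lines. The sketch below is hypothetical and only illustrates the priority argument that breaks ties between competing factories.

#include "tensorflow/core/common_runtime/device_factory.h"

namespace tensorflow {
namespace example {  // hypothetical, for illustration only

class ExampleDeviceFactory : public DeviceFactory {
 public:
  void CreateDevices(const SessionOptions& options, const string& name_prefix,
                     std::vector<Device*>* devices) override {
    // Append one Device* per device to expose; callers take ownership.
  }
};

// With two registrations for "EXAMPLE", the higher priority wins; equal
// priorities trigger the LOG(FATAL) in DeviceFactory::Register.
REGISTER_LOCAL_DEVICE_FACTORY("EXAMPLE", ExampleDeviceFactory, 10);

}  // namespace example
}  // namespace tensorflow
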
diff --git a/tensorflow/core/common_runtime/device_mgr.cc b/tensorflow/core/common_runtime/device_mgr.cc
new file mode 100644
index 0000000000..4fa13f6b4b
--- /dev/null
+++ b/tensorflow/core/common_runtime/device_mgr.cc
@@ -0,0 +1,90 @@
+#include "tensorflow/core/common_runtime/device_mgr.h"
+
+#include "tensorflow/core/common_runtime/local_device.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+
+DeviceMgr::DeviceMgr(const std::vector<Device*>& devices) {
+ for (Device* d : devices) {
+ devices_.push_back(d);
+
+ // Register under both the full name and the local name.
+ device_map_[d->name()] = d;
+ device_map_[DeviceNameUtils::LocalName(d->name())] = d;
+ device_type_counts_[d->device_type()]++;
+ }
+}
+
+DeviceMgr::~DeviceMgr() {
+ for (auto p : devices_) delete p;
+}
+
+void DeviceMgr::ListDeviceAttributes(
+ std::vector<DeviceAttributes>* devices) const {
+ devices->reserve(devices_.size());
+ for (Device* dev : devices_) {
+ devices->emplace_back(dev->attributes());
+ }
+}
+
+std::vector<Device*> DeviceMgr::ListDevices() const {
+ return std::vector<Device*>(devices_.begin(), devices_.end());
+}
+
+string DeviceMgr::DebugString() const {
+ string out;
+ for (Device* dev : devices_) {
+ strings::StrAppend(&out, dev->name(), "\n");
+ }
+ return out;
+}
+
+string DeviceMgr::DeviceMappingString() const {
+ string out;
+ for (Device* dev : devices_) {
+ if (!dev->attributes().physical_device_desc().empty()) {
+ strings::StrAppend(&out, dev->name(), " -> ",
+ dev->attributes().physical_device_desc(), "\n");
+ }
+ }
+ return out;
+}
+
+Status DeviceMgr::LookupDevice(const string& name, Device** device) const {
+ Status s;
+ auto iter = device_map_.find(name);
+ if (iter == device_map_.end()) {
+ return errors::InvalidArgument(name, " unknown device.");
+ }
+ *device = iter->second;
+ return Status::OK();
+}
+
+void DeviceMgr::ClearContainers(gtl::ArraySlice<string> containers) const {
+ Status s;
+ for (Device* dev : devices_) {
+ if (containers.empty()) {
+ s.Update(dev->resource_manager()->Cleanup(
+ dev->resource_manager()->default_container()));
+ } else {
+ for (const string& c : containers) {
+ s.Update(dev->resource_manager()->Cleanup(c));
+ }
+ }
+ if (!s.ok()) {
+ LOG(WARNING) << s;
+ }
+ }
+}
+
+int DeviceMgr::NumDeviceType(const string& type) const {
+ auto iter = device_type_counts_.find(type);
+ if (iter != device_type_counts_.end()) return iter->second;
+ return 0;
+}
+
+} // namespace tensorflow
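
Because the constructor indexes each device under both its full name and DeviceNameUtils::LocalName(), lookups by either form resolve to the same pointer. A hypothetical sketch, assuming TF_RETURN_IF_ERROR from lib/core/errors.h in this revision:

#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {
namespace example {  // hypothetical, for illustration only

Status LookupBothForms(const DeviceMgr& mgr) {
  const string full_name = "/job:localhost/replica:0/task:0/cpu:0";
  Device* by_full = nullptr;
  Device* by_local = nullptr;
  TF_RETURN_IF_ERROR(mgr.LookupDevice(full_name, &by_full));
  TF_RETURN_IF_ERROR(
      mgr.LookupDevice(DeviceNameUtils::LocalName(full_name), &by_local));
  CHECK_EQ(by_full, by_local);  // Same Device registered under both keys.
  return Status::OK();
}

}  // namespace example
}  // namespace tensorflow
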
diff --git a/tensorflow/core/common_runtime/device_mgr.h b/tensorflow/core/common_runtime/device_mgr.h
new file mode 100644
index 0000000000..c57d0222aa
--- /dev/null
+++ b/tensorflow/core/common_runtime/device_mgr.h
@@ -0,0 +1,55 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_DEVICE_MGR_H_
+#define TENSORFLOW_COMMON_RUNTIME_DEVICE_MGR_H_
+
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+class DeviceAttributes;
+
+class DeviceMgr {
+ public:
+ // TODO(zhifengc): Other initialization information.
+ explicit DeviceMgr(const std::vector<Device*>& devices);
+ ~DeviceMgr();
+
+ // Returns attributes of all devices.
+ void ListDeviceAttributes(std::vector<DeviceAttributes>* devices) const;
+
+ std::vector<Device*> ListDevices() const;
+
+ // Returns a string listing all devices.
+ string DebugString() const;
+
+ // Returns a string of all the device mapping.
+ string DeviceMappingString() const;
+
+ // Assigns *device with pointer to Device of the given name.
+ // Accepts either a full device name, or just the replica-local suffix.
+ Status LookupDevice(const string& name, Device** device) const;
+
+ // Clears given containers of all devices if 'container' is
+ // non-empty. Otherwise, clears default containers of all devices.
+ void ClearContainers(gtl::ArraySlice<string> containers) const;
+
+ int NumDeviceType(const string& type) const;
+
+ private:
+ typedef gtl::InlinedVector<Device*, 8> DeviceVec;
+ DeviceVec devices_;
+ std::unordered_map<string, Device*> device_map_;
+ std::unordered_map<string, int> device_type_counts_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(DeviceMgr);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_DEVICE_MGR_H_
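
DeviceMgr takes ownership of the pointers handed to its constructor and deletes them in its destructor, so a typical setup hands it the vector produced by DeviceFactory::AddDevices. A hypothetical sketch:

#include <vector>

#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
namespace example {  // hypothetical, for illustration only

DeviceMgr* MakeDeviceMgr() {
  SessionOptions options;
  std::vector<Device*> devices;
  DeviceFactory::AddDevices(options, "/job:localhost/replica:0/task:0",
                            &devices);
  return new DeviceMgr(devices);  // The manager now owns every Device*.
}

}  // namespace example
}  // namespace tensorflow
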
diff --git a/tensorflow/core/common_runtime/device_set.cc b/tensorflow/core/common_runtime/device_set.cc
new file mode 100644
index 0000000000..3b0465d9a6
--- /dev/null
+++ b/tensorflow/core/common_runtime/device_set.cc
@@ -0,0 +1,68 @@
+#include "tensorflow/core/common_runtime/device_set.h"
+
+#include <set>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+
+namespace tensorflow {
+
+DeviceSet::DeviceSet() {}
+
+DeviceSet::~DeviceSet() {}
+
+void DeviceSet::AddDevice(Device* device) {
+ devices_.push_back(device);
+ device_by_name_.insert({device->name(), device});
+}
+
+void DeviceSet::FindMatchingDevices(const DeviceNameUtils::ParsedName& spec,
+ std::vector<Device*>* devices) const {
+ // TODO(jeff): If we are going to repeatedly lookup the set of devices
+ // for the same spec, maybe we should have a cache of some sort
+ devices->clear();
+ for (Device* d : devices_) {
+ if (DeviceNameUtils::IsCompleteSpecification(spec, d->parsed_name())) {
+ devices->push_back(d);
+ }
+ }
+}
+
+Device* DeviceSet::FindDeviceByName(const string& name) const {
+ return gtl::FindPtrOrNull(device_by_name_, name);
+}
+
+// Higher result implies lower priority.
+static int Order(const DeviceType& d) {
+ if (StringPiece(d.type()) == DEVICE_CPU) {
+ return 3;
+ } else if (StringPiece(d.type()) == DEVICE_GPU) {
+ return 2;
+ } else {
+ return 1;
+ }
+}
+
+static bool ByPriority(const DeviceType& a, const DeviceType& b) {
+ // Order by "order number"; break ties lexicographically.
+ return std::make_pair(Order(a), StringPiece(a.type())) <
+ std::make_pair(Order(b), StringPiece(b.type()));
+}
+
+std::vector<DeviceType> DeviceSet::PrioritizedDeviceTypeList() const {
+ std::vector<DeviceType> result;
+ std::set<string> seen;
+ for (Device* d : devices_) {
+ auto t = d->device_type();
+ if (seen.insert(t).second) {
+ result.emplace_back(DeviceType(t));
+ }
+ }
+ std::sort(result.begin(), result.end(), ByPriority);
+ return result;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/device_set.h b/tensorflow/core/common_runtime/device_set.h
new file mode 100644
index 0000000000..130d965891
--- /dev/null
+++ b/tensorflow/core/common_runtime/device_set.h
@@ -0,0 +1,64 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_DEVICE_SET_H_
+#define TENSORFLOW_COMMON_RUNTIME_DEVICE_SET_H_
+
+#include <memory>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+
+// DeviceSet is a container class for managing the various types of
+// devices used by a model.
+class DeviceSet {
+ public:
+ DeviceSet();
+ ~DeviceSet();
+
+ // Does not take ownership of 'device'.
+ void AddDevice(Device* device);
+
+ // Set the device designated as the "client". This device
+ // must also be registered via AddDevice().
+ void set_client_device(Device* device) { client_device_ = device; }
+
+ // Returns a pointer to the device designated as the "client".
+ Device* client_device() const { return client_device_; }
+
+ // Return the list of devices in this set.
+ const std::vector<Device*>& devices() const { return devices_; }
+
+ // Given a DeviceNameUtils::ParsedName (which may have some
+ // wildcards for different components), fills "*devices" with all
+ // devices in "*this" that match "spec".
+ void FindMatchingDevices(const DeviceNameUtils::ParsedName& spec,
+ std::vector<Device*>* devices) const;
+
+ // Finds the device with the given "fullname". Returns nullptr if
+ // not found.
+ Device* FindDeviceByName(const string& fullname) const;
+
+ // Return the list of unique device types in this set, ordered
+ // with more preferable devices earlier.
+ std::vector<DeviceType> PrioritizedDeviceTypeList() const;
+
+ private:
+ // Not owned.
+ std::vector<Device*> devices_;
+
+ // Fullname -> device* for device in devices_.
+ std::unordered_map<string, Device*> device_by_name_;
+
+ // client_device_ points to an element of devices_ that we consider
+ // to be the client device (in this local process).
+ Device* client_device_ = nullptr;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(DeviceSet);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_DEVICE_SET_H_
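
FindMatchingDevices treats unset components of the ParsedName as wildcards, so a spec that names only the job, replica, and task selects every device on that task. A hypothetical sketch, assuming DeviceNameUtils::ParseFullName accepts such a partially specified name:

#include <vector>

#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {
namespace example {  // hypothetical, for illustration only

std::vector<Device*> DevicesOnTaskZero(const DeviceSet& device_set) {
  DeviceNameUtils::ParsedName spec;
  CHECK(DeviceNameUtils::ParseFullName("/job:worker/replica:0/task:0", &spec));
  std::vector<Device*> matches;
  device_set.FindMatchingDevices(spec, &matches);  // Type and id are wildcards.
  return matches;
}

}  // namespace example
}  // namespace tensorflow
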
diff --git a/tensorflow/core/common_runtime/device_set_test.cc b/tensorflow/core/common_runtime/device_set_test.cc
new file mode 100644
index 0000000000..1b80a5b697
--- /dev/null
+++ b/tensorflow/core/common_runtime/device_set_test.cc
@@ -0,0 +1,65 @@
+#include "tensorflow/core/common_runtime/device_set.h"
+
+#include "tensorflow/core/public/status.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+// Return a fake device with the specified type and name.
+static Device* Dev(const char* type, const char* name) {
+ class FakeDevice : public Device {
+ public:
+ explicit FakeDevice(const DeviceAttributes& attr)
+ : Device(nullptr, attr, nullptr) {}
+ Status Sync() override { return Status::OK(); }
+ Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
+ };
+ DeviceAttributes attr;
+ attr.set_name(name);
+ attr.set_device_type(type);
+ return new FakeDevice(attr);
+}
+
+class DeviceSetTest : public testing::Test {
+ public:
+ void AddDevice(const char* type, const char* name) {
+ Device* d = Dev(type, name);
+ owned_.emplace_back(d);
+ devices_.AddDevice(d);
+ }
+
+ std::vector<DeviceType> types() const {
+ return devices_.PrioritizedDeviceTypeList();
+ }
+
+ private:
+ DeviceSet devices_;
+ std::vector<std::unique_ptr<Device>> owned_;
+};
+
+TEST_F(DeviceSetTest, PrioritizedDeviceTypeList) {
+ EXPECT_EQ(std::vector<DeviceType>{}, types());
+
+ AddDevice("CPU", "/job:a/replica:0/task:0/cpu:0");
+ EXPECT_EQ(std::vector<DeviceType>{DeviceType(DEVICE_CPU)}, types());
+
+ AddDevice("CPU", "/job:a/replica:0/task:0/cpu:1");
+ EXPECT_EQ(std::vector<DeviceType>{DeviceType(DEVICE_CPU)}, types());
+
+ AddDevice("GPU", "/job:a/replica:0/task:0/gpu:0");
+ EXPECT_EQ(
+ (std::vector<DeviceType>{DeviceType(DEVICE_GPU), DeviceType(DEVICE_CPU)}),
+ types());
+
+ AddDevice("T1", "/job:a/replica:0/task:0/device:T1:0");
+ AddDevice("T1", "/job:a/replica:0/task:0/device:T1:1");
+ AddDevice("T2", "/job:a/replica:0/task:0/device:T2:0");
+ EXPECT_EQ(
+ (std::vector<DeviceType>{DeviceType("T1"), DeviceType("T2"),
+ DeviceType(DEVICE_GPU), DeviceType(DEVICE_CPU)}),
+ types());
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eigen_thread_pool.h b/tensorflow/core/common_runtime/eigen_thread_pool.h
new file mode 100644
index 0000000000..2554f3521b
--- /dev/null
+++ b/tensorflow/core/common_runtime/eigen_thread_pool.h
@@ -0,0 +1,22 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_EIGEN_THREAD_POOL_H_
+#define TENSORFLOW_COMMON_RUNTIME_EIGEN_THREAD_POOL_H_
+
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
+ public:
+ explicit EigenThreadPoolWrapper(thread::ThreadPool* pool) : pool_(pool) {}
+ ~EigenThreadPoolWrapper() override {}
+
+ void Schedule(std::function<void()> fn) override { pool_->Schedule(fn); }
+
+ private:
+ thread::ThreadPool* pool_ = nullptr;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_EIGEN_THREAD_POOL_H_
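
The wrapper lets a tensorflow::thread::ThreadPool stand in for Eigen's thread pool interface, which is how a CPU device can hand Eigen a device object backed by the runtime's own threads. A hypothetical sketch, assuming Env::Default() from public/env.h and the (Env*, name, num_threads) ThreadPool constructor in this revision:

#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/public/env.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

namespace tensorflow {
namespace example {  // hypothetical, for illustration only

void RunOnEigenDevice() {
  thread::ThreadPool pool(Env::Default(), "example_pool", /*num_threads=*/4);
  EigenThreadPoolWrapper wrapper(&pool);
  // Eigen tensor expressions evaluated against this device run on `pool`.
  Eigen::ThreadPoolDevice eigen_device(&wrapper, /*num_cores=*/4);
  (void)eigen_device;
}

}  // namespace example
}  // namespace tensorflow
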
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
new file mode 100644
index 0000000000..7f2473f93b
--- /dev/null
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -0,0 +1,2118 @@
+#include "tensorflow/core/common_runtime/executor.h"
+
+#include <atomic>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+#include <deque>
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/framework/control_flow.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/op_segment.h"
+#include "tensorflow/core/framework/step_stats.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/edgeset.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/util/tensor_slice_reader_cache.h"
+
+namespace tensorflow {
+
+namespace {
+
+// 1-D, 0 element tensor.
+static const Tensor* const kEmptyTensor = new Tensor;
+
+bool IsInitializationOp(const Node* node) {
+ return node->op_def().allows_uninitialized_input();
+}
+
+// Sets the timeline_label field of *node_stats, using data from *node.
+// Returns true iff the node is a transfer node.
+// TODO(tucker): merge with the DetailText function in session.cc
+// in a common location.
+bool SetTimelineLabel(const Node* node, NodeExecStats* node_stats) {
+ bool is_transfer_node = false;
+ string memory;
+ for (auto& all : node_stats->memory()) {
+ int64 tot = all.total_bytes();
+ if (tot >= 0.1 * 1048576.0) {
+ int64 peak = all.peak_bytes();
+ if (peak > 0) {
+ memory =
+ strings::StrCat(memory, "[", all.allocator_name(),
+ strings::Printf(" %.1fMB %.1fMB] ", tot / 1048576.0,
+ peak / 1048576.0));
+ } else {
+ memory = strings::StrCat(memory, "[", all.allocator_name(),
+ strings::Printf(" %.1fMB] ", tot / 1048576.0));
+ }
+ }
+ }
+ const NodeDef& def = node->def();
+ string text = "";
+ if (IsSend(node)) {
+ string tensor_name;
+ TF_CHECK_OK(GetNodeAttr(def, "tensor_name", &tensor_name));
+ string recv_device;
+ TF_CHECK_OK(GetNodeAttr(def, "recv_device", &recv_device));
+ text = strings::StrCat(memory, def.name(), " = ", def.op(), "(",
+ tensor_name, " @", recv_device);
+ is_transfer_node = true;
+ } else if (IsRecv(node)) {
+ string tensor_name;
+ TF_CHECK_OK(GetNodeAttr(def, "tensor_name", &tensor_name));
+ string send_device;
+ TF_CHECK_OK(GetNodeAttr(def, "send_device", &send_device));
+ text = strings::StrCat(memory, def.name(), " = ", def.op(), "(",
+ tensor_name, " @", send_device);
+ is_transfer_node = true;
+ } else {
+ text = strings::StrCat(
+ memory, def.name(), " = ", def.op(), "(",
+ str_util::Join(
+ std::vector<StringPiece>(def.input().begin(), def.input().end()),
+ ", "),
+ ")");
+ }
+ node_stats->set_timeline_label(text);
+ return is_transfer_node;
+}
+
+// Helper routines for collecting step stats.
+namespace nodestats {
+inline int64 NowInUsec() { return Env::Default()->NowMicros(); }
+
+void SetScheduled(NodeExecStats* nt, int64 t) { nt->set_scheduled_micros(t); }
+
+void SetAllStart(NodeExecStats* nt) { nt->set_all_start_micros(NowInUsec()); }
+
+void SetOpStart(NodeExecStats* nt) {
+ DCHECK_NE(nt->all_start_micros(), 0);
+ nt->set_op_start_rel_micros(NowInUsec() - nt->all_start_micros());
+}
+
+void SetOpEnd(NodeExecStats* nt) {
+ DCHECK_NE(nt->all_start_micros(), 0);
+ nt->set_op_end_rel_micros(NowInUsec() - nt->all_start_micros());
+}
+
+void SetAllEnd(NodeExecStats* nt) {
+ DCHECK_NE(nt->all_start_micros(), 0);
+ nt->set_all_end_rel_micros(NowInUsec() - nt->all_start_micros());
+}
+
+void SetOutput(NodeExecStats* nt, int slot, AllocationType allocation_type,
+ const Tensor* v) {
+ DCHECK(v);
+ NodeOutput* no = nt->add_output();
+ no->set_slot(slot);
+ no->set_allocation_type(allocation_type);
+ v->FillDescription(no->mutable_tensor_description());
+}
+
+void SetMemory(NodeExecStats* nt, OpKernelContext* ctx) {
+ for (const auto& allocator_pair : ctx->wrapped_allocators()) {
+ AllocatorMemoryUsed* memory = nt->add_memory();
+ // retrieving the sizes from the wrapped allocator removes the
+ // executor's reference to it, so allocator_pair.second must not
+ // be dereferenced again after this statement
+ auto sizes = allocator_pair.second->GetSizesAndUnRef();
+ memory->set_allocator_name(allocator_pair.first->Name());
+ int tb = sizes.first;
+ memory->set_total_bytes(tb);
+ if (allocator_pair.first->TracksAllocationSizes()) {
+ memory->set_peak_bytes(sizes.second);
+ }
+ }
+}
+} // namespace nodestats
+
+struct NodeItem {
+ // A graph node.
+ const Node* node = nullptr;
+
+ // The kernel for this node.
+ OpKernel* kernel = nullptr;
+
+ // ExecutorImpl::tensors_[input_start] is the 1st positional input
+ // for this node.
+ int input_start = 0;
+};
+
+// Map from std::pair<node_id, output_index> to attributes.
+struct pairhash {
+ public:
+ template <typename T, typename U>
+ std::size_t operator()(const std::pair<T, U>& x) const {
+ return std::hash<T>()(x.first) ^ std::hash<U>()(x.second);
+ }
+};
+typedef std::unordered_map<std::pair<int, int>, AllocatorAttributes, pairhash>
+ DevAttrMap;
+
+typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
+typedef gtl::InlinedVector<DeviceContext*, 4> DeviceContextVec;
+typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;
+
+class ExecutorImpl : public Executor {
+ public:
+ ExecutorImpl(const LocalExecutorParams& p, const Graph* g)
+ : params_(p), graph_(g) {
+ CHECK(p.create_kernel != nullptr);
+ CHECK(p.delete_kernel != nullptr);
+ }
+
+ ~ExecutorImpl() override {
+ for (NodeItem& item : nodes_) {
+ params_.delete_kernel(item.kernel);
+ }
+ delete graph_;
+ }
+
+ Status Initialize();
+
+ // Infer the memory allocation attributes of node n's output,
+ // based on the node dst that uses it. Note that dst might not be directly
+ // connected to n by a single edge, but might be a downstream
+ // consumer of n's output by reference. *attr is updated with any
+ // necessary attributes.
+ Status InferAllocAttr(const Node* n, const Node* dst,
+ const DeviceNameUtils::ParsedName& local_dev_name,
+ AllocatorAttributes* attr);
+
+ // Process all Nodes in the current graph, attempting to infer the
+ // memory allocation attributes to be used wherever they may allocate
+ // a tensor buffer.
+ Status SetAllocAttrs();
+
+ void RunAsync(const Args& args, DoneCallback done) override;
+
+ private:
+ friend class ExecutorState;
+ friend class SimpleExecutorState;
+
+ // Owned.
+ LocalExecutorParams params_;
+ const Graph* graph_;
+ std::vector<NodeItem> nodes_; // nodes_.size() == graph_->num_node_ids().
+ int total_tensors_ = 0; // total_tensors_ = sum(nodes_[*].num_inputs())
+
+ // The number of inputs for each frame in this graph. This is static
+ // information of the graph.
+ std::unordered_map<string, int> frame_input_count_;
+
+ DevAttrMap alloc_attr_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ExecutorImpl);
+};
+
+Status ExecutorImpl::Initialize() {
+ const int num_nodes = graph_->num_node_ids();
+ nodes_.resize(num_nodes);
+
+ Status s;
+ total_tensors_ = 0;
+
+ // Preprocess every node in the graph to create an instance of op
+ // kernel for each node.
+ for (const Node* n : graph_->nodes()) {
+ const int id = n->id();
+ NodeItem* item = &nodes_[id];
+ item->node = n;
+ item->input_start = total_tensors_;
+ total_tensors_ += n->num_inputs();
+ s = params_.create_kernel(n->def(), &item->kernel);
+ if (!s.ok()) {
+ s = AttachDef(s, n->def());
+ LOG(ERROR) << "Executor failed to create kernel. " << s;
+ break;
+ }
+ CHECK(item->kernel);
+
+ // Initialize static information about the frames in the graph.
+ if (IsEnter(n)) {
+ string frame_name;
+ s = GetNodeAttr(n->def(), "frame_name", &frame_name);
+ if (!s.ok()) return s;
+ ++frame_input_count_[frame_name];
+ }
+ }
+ if (params_.has_control_flow) {
+ VLOG(2) << "Graph has control flow.";
+ }
+ if (!s.ok()) return s;
+ return SetAllocAttrs();
+}
+
+Status ExecutorImpl::SetAllocAttrs() {
+ Status s;
+ Device* device = params_.device;
+ DeviceNameUtils::ParsedName local_dev_name = device->parsed_name();
+
+ for (const Node* n : graph_->nodes()) {
+ // Examine the out edges of each node looking for special use
+ // cases that may affect memory allocation attributes.
+ for (auto e : n->out_edges()) {
+ AllocatorAttributes attr;
+ s = InferAllocAttr(n, e->dst(), local_dev_name, &attr);
+ if (!s.ok()) return s;
+ if (attr.value != 0) {
+ VLOG(2) << "node " << n->name() << " gets attr " << attr.value
+ << " for output " << e->src_output();
+ alloc_attr_[std::make_pair(n->id(), e->src_output())].Merge(attr);
+ } else {
+ VLOG(2) << "default output attr for node " << n->name() << " output "
+ << e->src_output();
+ }
+ }
+ }
+ return s;
+}
+
+Status ExecutorImpl::InferAllocAttr(
+ const Node* n, const Node* dst,
+ const DeviceNameUtils::ParsedName& local_dev_name,
+ AllocatorAttributes* attr) {
+ Status s;
+ if (IsSend(dst)) {
+ string dst_name;
+ s = GetNodeAttr(dst->def(), "recv_device", &dst_name);
+ if (!s.ok()) return s;
+ DeviceNameUtils::ParsedName parsed_dst_name;
+ if (!DeviceNameUtils::ParseFullName(dst_name, &parsed_dst_name)) {
+ s = errors::Internal("Bad recv_device attr '", dst_name, "' in node ",
+ n->name());
+ return s;
+ }
+ if (!DeviceNameUtils::IsSameAddressSpace(parsed_dst_name, local_dev_name)) {
+ // Value is going to be the source of an RPC.
+ attr->set_nic_compatible(true);
+ VLOG(2) << "node " << n->name() << " is the source of an RPC out";
+ } else if (local_dev_name.type == "CPU" && parsed_dst_name.type == "GPU") {
+ // Value is going to be the source of a local DMA from CPU to GPU.
+ attr->set_gpu_compatible(true);
+ VLOG(2) << "node " << n->name() << " is the source of a cpu->gpu copy";
+ } else {
+ VLOG(2) << "default alloc case local type " << local_dev_name.type
+ << " remote type " << parsed_dst_name.type;
+ }
+ } else if (dst->type_string() == "ToFloat") {
+ for (auto e : dst->out_edges()) {
+ s = InferAllocAttr(n, e->dst(), local_dev_name, attr);
+ if (!s.ok()) return s;
+ }
+ }
+ return s;
+}
+
+// The state associated with one invocation of ExecutorImpl::Run.
+// ExecutorState dispatches nodes when they become ready and keeps
+// track of how many predecessors of a node are still pending (pending_).
+class ExecutorState {
+ public:
+ ExecutorState(const Executor::Args& args, ExecutorImpl* impl);
+ ~ExecutorState();
+
+ void RunAsync(Executor::DoneCallback done);
+
+ private:
+ typedef ExecutorState ME;
+
+ // Either a tensor pointer (pass-by-reference) or a tensor (pass-by-value).
+ // TODO(yuanbyu): A better way to do "has_value"?
+ struct Entry {
+ Tensor val = *kEmptyTensor; // A tensor value.
+ Tensor* ref = nullptr; // A tensor reference.
+ mutex* ref_mu = nullptr; // mutex for *ref if ref is not nullptr.
+ bool has_value = false; // Whether the value exists
+
+ // Every entry carries an optional DeviceContext containing
+ // Device-specific information about how the Tensor was produced.
+ DeviceContext* device_context = nullptr;
+
+ // The attributes of the allocator that creates the tensor.
+ AllocatorAttributes alloc_attr;
+ };
+
+ // Contains a map from node id to the DeviceContext object that was
+ // assigned by the device at the beginning of a step.
+ DeviceContextMap device_context_map_;
+
+ struct IterationState {
+ // The state of an iteration.
+
+ // The pending count for each graph node. One copy per iteration.
+ // Iteration i can be garbage collected when it is done.
+ // TODO(yuanbyu): This vector currently has the size of the number of nodes
+ // in this partition. This is not efficient if the subgraph for the frame
+ // is only a small subset of the partition. We should make the vector
+ // only as large as the frame subgraph.
+ std::vector<int>* pending_count;
+
+ // The dead input count for each graph node. One copy per iteration.
+ std::vector<int>* dead_count;
+
+ // One copy per iteration. For iteration k, i-th node's j-th input is in
+ // input_tensors[k][impl_->nodes[i].input_start + j]. An entry is either
+ // a tensor pointer (pass-by-reference) or a tensor (pass-by-value).
+ //
+ // NOTE: No need to protect input_tensors[i] by any locks because it
+ // is resized once. Each element of tensors_ is written once by the
+ // source node of an edge and is cleared by the destination of the same
+ // edge. The latter node is never run concurrently with the former node.
+ std::vector<Entry>* input_tensors;
+
+ // The number of outstanding ops for each iteration.
+ int outstanding_ops;
+
+ // The number of outstanding frames for each iteration.
+ int outstanding_frame_count;
+
+ ~IterationState() {
+ delete pending_count;
+ delete dead_count;
+ delete input_tensors;
+ }
+ };
+
+ struct FrameState {
+ // A new frame is created for each loop. Execution starts at iteration 0.
+ // When a value at iteration 0 passes through a NextIteration node,
+ // iteration 1 is created and starts running. Note that iteration 0 may
+ // still be running so multiple iterations may run in parallel. The
+ // frame maintains the state of iterations in several data structures
+ // such as pending_count and input_tensors. When iteration 0 completes,
+ // we garbage collect the state of iteration 0.
+ //
+ // A frame instance is considered "done" and can be garbage collected
+ // if all its inputs have entered and all its iterations are "done".
+ //
+ // A frame manages the live iterations of an iterative computation.
+ // Iteration i is considered "done" when there are no outstanding ops,
+ // frames at iteration i are done, all recvs for this iteration are
+ // completed, and iteration i-1 is done. For iteration 0, we instead
+ // wait for there to be no more pending inputs of the frame.
+ //
+ // Frames and iterations are garbage collected once they are done.
+ // The state we need to keep around is highly dependent on the
+ // parallelism enabled by the scheduler. We may want to have the
+ // scheduler dynamically control the outstanding number of live
+ // parallel frames and iterations. To reduce the state space, the
+ // scheduler might want to schedule ops in inner frames first and
+ // lower iterations first.
+ //
+ // This frame state is mostly initialized lazily on demand so we
+ // don't introduce unnecessary overhead.
+
+ // The name of this frame, which is the concatenation of its parent
+ // frame name, the iteration of the parent frame when this frame was
+ // created, and the value of the attr 'frame_name'.
+ string frame_name;
+
+ // The unique id for this frame. Generated by fingerprinting
+ // frame_name.
+ uint64 frame_id;
+
+ // The iteration id of its parent frame when this frame is created.
+ // -1 if there is no parent frame. The frame_name/parent_iter pair
+ // uniquely identifies this FrameState.
+ int64 parent_iter = -1;
+
+ // The FrameState of its parent frame.
+ FrameState* parent_frame = nullptr;
+
+ // The highest iteration number we have reached so far in this frame.
+ int64 iteration_count = 0;
+
+ // The number of inputs this frame is still waiting for.
+ int num_pending_inputs = 0;
+
+ // The number of outstanding iterations.
+ int num_outstanding_iterations = 0;
+
+ // The maximum allowed number of parallel iterations.
+ int max_parallel_iterations = 1;
+
+ // The iteration states of this frame.
+ std::vector<IterationState*> iterations;
+
+ // The NextIteration nodes to enter a new iteration. If the number of
+ // outstanding iterations reaches the limit, we will defer the start of
+ // the next iteration until the number of outstanding iterations falls
+ // below the limit.
+ std::vector<std::pair<const Node*, Entry>> next_iter_roots;
+
+ // The values of the loop invariants for this loop. They are added into
+ // this list as they "enter" the frame. When a loop invariant enters,
+ // we make it available to all active iterations. When the frame starts
+ // a new iteration, we make all the current loop invariants available
+ // to the new iteration.
+ std::vector<std::pair<const Node*, Entry>> inv_values;
+
+ // The list of dead exit nodes for the current highest iteration. We
+ // will only "execute" the dead exits of the final iteration.
+ std::vector<const Node*> dead_exits;
+
+ IterationState* GetIteration(int64 iter) {
+ int index = iter % iterations.size();
+ return iterations[index];
+ }
+
+ void SetIteration(int64 iter, IterationState* state) {
+ int index = iter % iterations.size();
+ iterations[index] = state;
+ }
+
+ ~FrameState() {
+ for (size_t i = 0; i < iterations.size(); ++i) {
+ delete iterations[i];
+ iterations[i] = nullptr;
+ }
+ }
+ };
+
+ // A tagged node: <frame*, iter, node*>.
+ struct TaggedNode {
+ const Node* node = nullptr;
+ FrameState* input_frame = nullptr;
+ int64 input_iter = -1;
+ bool is_dead = false;
+
+ TaggedNode(const Node* t_node, FrameState* in_frame, int64 in_iter,
+ bool dead) {
+ node = t_node;
+ input_frame = in_frame;
+ input_iter = in_iter;
+ is_dead = dead;
+ }
+ };
+
+ typedef gtl::InlinedVector<TaggedNode, 8> TaggedNodeSeq;
+ typedef gtl::InlinedVector<Entry, 4> EntryVector;
+
+ // Not owned.
+ Rendezvous* rendezvous_;
+ StepStatsCollector* stats_collector_;
+ // QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper instead of a
+ // pointer? (avoids having to delete).
+ checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache_;
+ FunctionCallFrame* call_frame_;
+ const ExecutorImpl* impl_;
+ CancellationManager* cancellation_manager_;
+ Executor::Args::Runner runner_;
+
+ // Owned.
+
+ // Step-local resource manager.
+ ResourceMgr step_resource_manager_;
+
+ // The root frame in which the execution of this step is started.
+ FrameState* root_frame_;
+
+ // Invoked when the execution finishes.
+ Executor::DoneCallback done_cb_;
+
+ std::atomic_int_fast32_t num_outstanding_ops_;
+
+ mutex mu_;
+ Status status_ GUARDED_BY(mu_);
+
+ // Mapping from frame name to outstanding frames. A new frame is created
+ // at some iteration of an active frame. So the unique key for the new
+ // child frame is composed of the name of the parent frame, the iteration
+ // number at which the parent frame is creating the new frame, and the
+ // name of the new frame from nodedef.
+ std::unordered_map<string, FrameState*> outstanding_frames_ GUARDED_BY(mu_);
+
+ // The unique name of a frame.
+ inline string MakeFrameName(FrameState* frame, int64 iter_id, string name) {
+ return strings::StrCat(frame->frame_name, ";", iter_id, ";", name);
+ }
+
+ // Initialize the pending count for a graph.
+ static void InitializePending(const Graph* graph, std::vector<int>* pending);
+
+ // Find an existing or create a new child frame in the frame 'frame' at
+ // iteration 'iter'.
+ void FindOrCreateChildFrame(FrameState* frame, int64 iter, const Node* node,
+ FrameState** child) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Increments the iteration id. If this is a new iteration, initialize it.
+ void IncrementIteration(FrameState* frame, TaggedNodeSeq* ready)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Returns true if the computation in the frame is completed.
+ bool IsFrameDone(FrameState* frame) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Returns true if the iteration of the frame is completed.
+ bool IsIterationDone(FrameState* frame, int64 iter)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Get the output frame/iter of a node. Create new frame/iteration if
+ // needed. If there are dead roots for the new iteration, we need to
+ // "execute" them so ad them to the ready queue. Returns true if
+ // we need to check for the completion of output frame/iter.
+ bool SetOutputFrameIter(const TaggedNode& tagged_node,
+ const EntryVector& outputs, FrameState** frame,
+ int64* iter, TaggedNodeSeq* ready)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Cleanup frames and iterations
+ void CleanupFramesIterations(FrameState* frame, int64 iter,
+ TaggedNodeSeq* ready)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Activate all the deferred NextIteration nodes in a new iteration.
+ void ActivateNexts(FrameState* frame, int64 iter, TaggedNodeSeq* ready)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Activate all the current loop invariants in a new iteration.
+ void ActivateLoopInvs(FrameState* frame, int64 iter, TaggedNodeSeq* ready)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Add a new loop invariant and make it available to all active iterations.
+ void AddLoopInv(FrameState* frame, const Node* node, const Entry& value,
+ TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Activate the successors of a node.
+ void ActivateNode(const Node* node, const bool is_dead, FrameState* frame,
+ int64 iter, const EntryVector& outputs,
+ TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Process a ready node in current thread.
+ void Process(TaggedNode node, int64 scheduled_usec);
+
+ // Before invoking item->kernel, fills in its "inputs".
+ Status PrepareInputs(const NodeItem& item, Entry* first_input,
+ TensorValueVec* inputs,
+ DeviceContextVec* input_device_contexts,
+ AllocatorAttributeVec* input_alloc_attrs,
+ bool* is_input_dead);
+
+ // After item->kernel computation is done, processes its outputs.
+ Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
+ EntryVector* outputs, NodeExecStats* stats);
+
+ // After processing the outputs, propagates the outputs to their dsts.
+ void PropagateOutputs(const TaggedNode& tagged_node,
+ const EntryVector& outputs, TaggedNodeSeq* ready);
+
+ // "node" just finishes. Takes ownership of "stats". Returns true if
+ // execution has completed.
+ bool NodeDone(const Status& s, const Node* node, const TaggedNodeSeq& ready,
+ NodeExecStats* stats, std::deque<TaggedNode>* inline_ready);
+
+ // Call Process() on all nodes in 'inline_ready'.
+ void ProcessInline(const std::deque<TaggedNode>& inline_ready);
+
+ // Schedule all the expensive nodes in 'ready', and put all the inexpensive
+ // nodes in 'ready' into 'inline_ready'.
+ void ScheduleReady(const TaggedNodeSeq& ready,
+ std::deque<TaggedNode>* inline_ready);
+
+ // One thread of control finishes.
+ void Finish();
+};
+
+ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl)
+ : rendezvous_(args.rendezvous),
+ stats_collector_(args.stats_collector),
+ slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper),
+ call_frame_(args.call_frame),
+ impl_(impl),
+ cancellation_manager_(args.cancellation_manager),
+ runner_(args.runner),
+ num_outstanding_ops_(0) {
+ // We start the entire execution in iteration 0 of the root frame
+ // so let us create the root frame and the state for iteration 0.
+ // Initialize the frame.
+ root_frame_ = new FrameState;
+ root_frame_->frame_name = "_root"; // assumed to be unique
+ root_frame_->frame_id = 0; // must be 0
+ root_frame_->num_pending_inputs = 0;
+ root_frame_->num_outstanding_iterations = 1;
+ root_frame_->max_parallel_iterations = 1; // enough for root frame
+ root_frame_->iterations.resize(root_frame_->max_parallel_iterations);
+
+ VLOG(2) << "Create frame: " << root_frame_->frame_name;
+
+ // Initialize the iteration.
+ IterationState* iter_state = new IterationState;
+ root_frame_->iterations[0] = iter_state;
+ iter_state->outstanding_ops = 0;
+ iter_state->outstanding_frame_count = 0;
+ iter_state->pending_count = new std::vector<int>;
+ iter_state->dead_count = new std::vector<int>(impl->graph_->num_node_ids());
+ iter_state->input_tensors = new std::vector<Entry>(impl_->total_tensors_);
+
+ // Initialize the executor state.
+ outstanding_frames_.insert({root_frame_->frame_name, root_frame_});
+}
+
+ExecutorState::~ExecutorState() {
+ for (auto name_frame : outstanding_frames_) {
+ delete name_frame.second;
+ }
+
+ for (auto it : device_context_map_) {
+ it.second->Unref();
+ }
+
+ delete slice_reader_cache_;
+}
+
+void ExecutorState::InitializePending(const Graph* graph,
+ std::vector<int>* pending) {
+ pending->resize(graph->num_node_ids());
+ for (const Node* n : graph->nodes()) {
+ const int id = n->id();
+ const int num_in_edges = n->in_edges().size();
+ if (IsMerge(n)) {
+ // Merge waits for all control inputs, so we initialize the pending
+ // count to be the number of control edges.
+ int32 num_control_edges = 0;
+ for (const Edge* edge : n->in_edges()) {
+ if (edge->IsControlEdge()) {
+ num_control_edges++;
+ }
+ }
+ // Use bit 0 to indicate if there is a ready live data input.
+ (*pending)[id] = num_control_edges << 1;
+ } else {
+ (*pending)[id] = num_in_edges;
+ }
+ }
+}
+
+void ExecutorState::RunAsync(Executor::DoneCallback done) {
+ const Graph* graph = impl_->graph_;
+ TaggedNodeSeq ready;
+
+ {
+ // Initialize the executor state. We grab the mutex here just to
+ // keep the thread safety analysis happy.
+ mutex_lock l(mu_);
+ std::vector<int>* pending = root_frame_->iterations[0]->pending_count;
+ InitializePending(graph, pending);
+ }
+
+ // Ask the device to fill in the device context map.
+ Device* device = impl_->params_.device;
+ device->FillContextMap(graph, &device_context_map_);
+
+ // Initialize the ready queue.
+ for (const Node* n : graph->nodes()) {
+ const int num_in_edges = n->in_edges().size();
+ if (num_in_edges == 0) {
+ ready.push_back(TaggedNode{n, root_frame_, 0, false});
+ }
+ }
+ if (ready.empty()) {
+ done(Status::OK());
+ } else {
+ num_outstanding_ops_ = ready.size();
+ root_frame_->iterations[0]->outstanding_ops = ready.size();
+ done_cb_ = done;
+ // Schedule to run all the ready ops in thread pool.
+ ScheduleReady(ready, nullptr);
+ }
+}
+
+namespace {
+
+// This function is provided for use by OpKernelContext when allocating
+// the index'th output of node. It provides access to the
+// AllocatorAttributes computed during initialization to determine in
+// which memory region the tensor should be allocated.
+AllocatorAttributes OutputAttributes(const DevAttrMap* attr_map,
+ const Node* node,
+ const OpKernel* op_kernel, int index) {
+ DCHECK_GE(index, 0);
+
+ AllocatorAttributes attr;
+ int nid = node->id();
+ const auto& iter = attr_map->find(std::make_pair(nid, index));
+ if (iter != attr_map->end()) {
+ attr = iter->second;
+ VLOG(2) << "nondefault attr " << attr.value << " for node " << node->name()
+ << " output " << index;
+ } else {
+ VLOG(2) << "default attr for node " << node->name() << " output " << index;
+ }
+
+ DCHECK_LT(index, op_kernel->output_memory_types().size());
+ bool on_host = op_kernel->output_memory_types()[index] == HOST_MEMORY;
+ attr.set_on_host(on_host);
+ return attr;
+}
+
+// Helper to make a copy of 'p', including copies of its input vector,
+// device context vector, and allocator attribute vector.
+//
+// NOTE: We need to make a copy of p.inputs for asynchronous kernels
+// because OpKernelContext methods like input_type(i) need the params
+// to point to a valid input type vector. This is not an issue for sync
+// kernels because the type vector is kept on the stack.
+OpKernelContext::Params* CopyParams(const OpKernelContext::Params& p) {
+ OpKernelContext::Params* ret = new OpKernelContext::Params;
+ *ret = p;
+ ret->inputs = new TensorValueVec(*p.inputs);
+ ret->input_device_contexts = new DeviceContextVec(*p.input_device_contexts);
+ ret->input_alloc_attrs = new AllocatorAttributeVec(*p.input_alloc_attrs);
+ return ret;
+}
+
+// Helper that deletes 'p' and the copies made by CopyParams().
+void DeleteParams(OpKernelContext::Params* p) {
+ delete p->inputs;
+ delete p->input_device_contexts;
+ delete p->input_alloc_attrs;
+ delete p;
+}
+
+} // namespace
+
+void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
+ const std::vector<NodeItem>& nodes = impl_->nodes_;
+ TaggedNodeSeq ready;
+ std::deque<TaggedNode> inline_ready;
+
+ // Parameters passed to OpKernel::Compute.
+ TensorValueVec inputs;
+ DeviceContextVec input_device_contexts;
+ AllocatorAttributeVec input_alloc_attrs;
+
+ OpKernelContext::Params params;
+ Device* device = impl_->params_.device;
+ params.device = device;
+ // track allocations if and only if we are collecting statistics
+ params.track_allocations = (stats_collector_ != nullptr);
+ params.rendezvous = rendezvous_;
+ params.cancellation_manager = cancellation_manager_;
+ params.call_frame = call_frame_;
+ params.function_library = impl_->params_.function_library;
+ params.resource_manager = device->resource_manager();
+ params.step_resource_manager = &step_resource_manager_;
+ params.slice_reader_cache = slice_reader_cache_;
+ params.inputs = &inputs;
+ params.input_device_contexts = &input_device_contexts;
+ params.input_alloc_attrs = &input_alloc_attrs;
+
+ Status s;
+ NodeExecStats* stats = nullptr;
+ EntryVector outputs;
+ bool completed = false;
+ inline_ready.push_back(tagged_node);
+ while (!inline_ready.empty()) {
+ tagged_node = inline_ready.front();
+ inline_ready.pop_front();
+ const Node* node = tagged_node.node;
+ FrameState* input_frame = tagged_node.input_frame;
+ int64 input_iter = tagged_node.input_iter;
+ const int id = node->id();
+ const NodeItem& item = nodes[id];
+
+ // Set the device_context for this node id, if it exists.
+ auto dc_it = device_context_map_.find(id);
+ if (dc_it != device_context_map_.end()) {
+ params.op_device_context = dc_it->second;
+ }
+
+ if (stats_collector_) {
+ stats = new NodeExecStats;
+ stats->set_node_name(node->name());
+ nodestats::SetScheduled(stats, scheduled_usec);
+ nodestats::SetAllStart(stats);
+ }
+
+ VLOG(1) << "Process node: " << id << " " << SummarizeNodeDef(node->def());
+
+ std::vector<Entry>* input_tensors;
+ {
+ // Need the lock because the iterations vector could be resized by
+ // another thread.
+ mutex_lock l(mu_);
+ input_tensors = input_frame->GetIteration(input_iter)->input_tensors;
+ }
+ Entry* first_input = input_tensors->data() + item.input_start;
+ outputs.clear();
+ outputs.resize(node->num_outputs());
+
+ // Only execute this node if it is not dead or it is a send/recv
+ // transfer node. For transfer nodes, we need to propagate the "dead"
+ // bit even when the node is dead.
+ AsyncOpKernel* async = nullptr;
+ if (!tagged_node.is_dead || IsTransferNode(node)) {
+ // Prepares inputs.
+ bool is_input_dead = false;
+ s = PrepareInputs(item, first_input, &inputs, &input_device_contexts,
+ &input_alloc_attrs, &is_input_dead);
+ if (!s.ok()) {
+ // Continue to process the nodes in 'inline_ready'.
+ completed = NodeDone(s, item.node, ready, stats, &inline_ready);
+ continue;
+ }
+
+ // Set up compute params.
+ OpKernel* op_kernel = item.kernel;
+ params.op_kernel = op_kernel;
+ params.frame_iter = FrameAndIter(input_frame->frame_id, input_iter);
+ params.is_input_dead = is_input_dead;
+ params.output_alloc_attr = [this, node, op_kernel](int index) {
+ return OutputAttributes(&impl_->alloc_attr_, node, op_kernel, index);
+ };
+
+ async = op_kernel->AsAsync();
+ if (async) {
+ // Asynchronous computes.
+ auto pcopy = CopyParams(params);
+ auto ctx = new OpKernelContext(*pcopy);
+ auto done = [this, tagged_node, item, first_input, ctx, stats,
+ pcopy]() {
+ VLOG(2) << this << " Async kernel done: "
+ << SummarizeNodeDef(item.node->def());
+ if (stats_collector_) nodestats::SetOpEnd(stats);
+ EntryVector outputs;
+ Status s = ProcessOutputs(item, ctx, &outputs, stats);
+ if (stats_collector_) nodestats::SetMemory(stats, ctx);
+ // Clears inputs.
+ int num_inputs = tagged_node.node->num_inputs();
+ for (int i = 0; i < num_inputs; ++i) {
+ (first_input + i)->val = *kEmptyTensor;
+ }
+ TaggedNodeSeq ready;
+ if (s.ok()) {
+ PropagateOutputs(tagged_node, outputs, &ready);
+ }
+ // Schedule to run all the ready ops in thread pool.
+ bool completed = NodeDone(s, item.node, ready, stats, nullptr);
+ delete ctx;
+ DeleteParams(pcopy);
+ if (completed) Finish();
+ };
+ if (stats_collector_) nodestats::SetOpStart(stats);
+ device->ComputeAsync(async, ctx, done);
+ } else {
+ // Synchronous computes.
+ OpKernelContext ctx(params);
+ if (stats_collector_) nodestats::SetOpStart(stats);
+ device->Compute(CHECK_NOTNULL(op_kernel), &ctx);
+ if (stats_collector_) nodestats::SetOpEnd(stats);
+
+ // Processes outputs.
+ s = ProcessOutputs(item, &ctx, &outputs, stats);
+ if (stats_collector_) nodestats::SetMemory(stats, &ctx);
+ }
+ }
+
+ if (!async) {
+ // Clears inputs.
+ int num_inputs = node->num_inputs();
+ for (int i = 0; i < num_inputs; ++i) {
+ (first_input + i)->val = *kEmptyTensor;
+ }
+ // Propagates outputs.
+ if (s.ok()) {
+ PropagateOutputs(tagged_node, outputs, &ready);
+ }
+ if (stats_collector_) {
+ scheduled_usec = nodestats::NowInUsec();
+ }
+ // Postprocess.
+ completed = NodeDone(s, item.node, ready, stats, &inline_ready);
+ }
+ } // while !inline_ready.empty()
+
+ // This thread of computation is done if completed = true.
+ if (completed) Finish();
+}
+
+Status ExecutorState::PrepareInputs(const NodeItem& item, Entry* first_input,
+ TensorValueVec* inputs,
+ DeviceContextVec* input_device_contexts,
+ AllocatorAttributeVec* input_alloc_attrs,
+ bool* is_input_dead) {
+ const Node* node = item.node;
+
+ inputs->clear();
+ inputs->resize(node->num_inputs());
+ input_device_contexts->clear();
+ input_device_contexts->resize(node->num_inputs());
+ input_alloc_attrs->clear();
+ input_alloc_attrs->resize(node->num_inputs());
+
+ *is_input_dead = false;
+
+ bool is_merge = IsMerge(node);
+ for (int i = 0; i < node->num_inputs(); ++i) {
+ const bool expect_ref = IsRefType(node->input_type(i));
+ Entry* entry = first_input + i;
+ (*input_device_contexts)[i] = entry->device_context;
+ (*input_alloc_attrs)[i] = entry->alloc_attr;
+
+ // i-th input.
+ TensorValue* inp = &(*inputs)[i];
+
+ // Only merge and transfer nodes can have no-value inputs.
+ if (!entry->has_value) {
+ if (!is_merge) {
+ DCHECK(IsTransferNode(node));
+ inp->tensor = &entry->val;
+ *is_input_dead = true;
+ }
+ continue;
+ }
+ if (entry->ref == nullptr) {
+ if (expect_ref) {
+ return AttachDef(
+ errors::InvalidArgument(i, "-th input expects a ref type"),
+ item.kernel->def());
+ }
+ inp->tensor = &entry->val;
+ } else {
+ if (!entry->ref->IsInitialized() && !IsInitializationOp(item.node)) {
+ return AttachDef(
+ errors::FailedPrecondition("Attempting to use uninitialized value ",
+ item.kernel->def().input(i)),
+ item.kernel->def());
+ }
+ if (expect_ref) {
+ inp->mutex_if_ref = entry->ref_mu;
+ inp->tensor = entry->ref;
+ } else {
+ // Automatically deref the tensor ref when the op expects a
+ // tensor but is given a ref to a tensor. Need to deref it
+ // under the mutex.
+ {
+ mutex_lock l(*(entry->ref_mu));
+ entry->val = *entry->ref;
+ }
+ inp->tensor = &entry->val;
+ }
+ }
+ }
+ return Status::OK();
+}
+
+Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
+ EntryVector* outputs,
+ NodeExecStats* stats) {
+ const Node* node = item.node;
+ outputs->clear();
+ outputs->resize(node->num_outputs());
+
+ Status s = ctx->status();
+ if (!s.ok()) {
+ s = AttachDef(s, item.kernel->def());
+ LOG(WARNING) << this << " Compute status: " << s;
+ return s;
+ }
+
+ // Get the device_context for this node id, if it exists.
+ DeviceContext* device_context = nullptr;
+ auto dc_it = device_context_map_.find(node->id());
+ if (dc_it != device_context_map_.end()) {
+ device_context = dc_it->second;
+ }
+
+ for (int i = 0; i < node->num_outputs(); ++i) {
+ TensorValue val = ctx->release_output(i);
+ // Only Switch and Recv nodes can generate new dead outputs
+ if (*ctx->is_output_dead() || val.tensor == nullptr) {
+ DCHECK(IsSwitch(node) || IsRecv(node));
+ } else {
+ Entry* out = &((*outputs)[i]);
+ out->has_value = true;
+
+ // Set the device context of the output entry.
+ out->device_context = device_context;
+
+ // Set the allocator attributes of the output entry.
+ out->alloc_attr = ctx->output_alloc_attr(i);
+
+ // Sanity check of output tensor types.
+ DataType dtype = val->dtype();
+ if (val.is_ref()) dtype = MakeRefType(dtype);
+ if (dtype == node->output_type(i)) {
+ if (val.is_ref()) {
+ out->ref = val.tensor;
+ out->ref_mu = val.mutex_if_ref;
+ } else {
+ out->val = *val.tensor;
+ }
+ if (stats_collector_ && val.tensor->IsInitialized()) {
+ nodestats::SetOutput(stats, i, ctx->output_allocation_type(i),
+ val.tensor);
+ }
+ } else {
+ s.Update(errors::Internal("Output ", i, " of type ",
+ DataTypeString(dtype),
+ " does not match declared output type ",
+ DataTypeString(node->output_type(i)),
+ " for node ", SummarizeNodeDef(node->def())));
+ }
+ }
+ if (!val.is_ref()) {
+ // If OpKernelContext returns outputs via pass-by-value, we
+ // don't need this trouble.
+ delete val.tensor;
+ }
+ }
+ return s;
+}
+
+void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node,
+ const EntryVector& outputs,
+ TaggedNodeSeq* ready) {
+ FrameState* input_frame = tagged_node.input_frame;
+ int64 input_iter = tagged_node.input_iter;
+
+ // Propagates outputs along out edges, and puts newly ready nodes
+ // into the ready queue.
+ ready->clear();
+
+ {
+ FrameState* output_frame = input_frame;
+ int64 output_iter = input_iter;
+
+ mutex_lock l(mu_);
+ // Sets the output_frame and output_iter of node.
+ bool maybe_completed = SetOutputFrameIter(
+ tagged_node, outputs, &output_frame, &output_iter, ready);
+ if (output_frame != nullptr) {
+ // Continue to process the out nodes:
+ ActivateNode(tagged_node.node, tagged_node.is_dead, output_frame,
+ output_iter, outputs, ready);
+ }
+
+ // At this point, this node is completely done.
+ input_frame->GetIteration(input_iter)->outstanding_ops--;
+ CleanupFramesIterations(input_frame, input_iter, ready);
+
+ // The execution of a node such as Enter may cause the completion of
+ // output_frame:output_iter, so perform cleanup if output_frame:output_iter
+ // is indeed completed.
+ if (maybe_completed) {
+ CleanupFramesIterations(output_frame, output_iter, ready);
+ }
+ }
+}
+
+void ExecutorState::ActivateNode(const Node* node, const bool is_dead,
+ FrameState* output_frame, int64 output_iter,
+ const EntryVector& outputs,
+ TaggedNodeSeq* ready) {
+ const std::vector<NodeItem>& nodes = impl_->nodes_;
+ IterationState* output_iter_state = output_frame->GetIteration(output_iter);
+ std::vector<int>* pending = output_iter_state->pending_count;
+ std::vector<int>* dead_count = output_iter_state->dead_count;
+ for (const Edge* e : node->out_edges()) {
+ const Node* dst_node = e->dst();
+ const int dst_id = dst_node->id();
+ const int src_slot = e->src_output();
+
+ bool dst_dead = false;
+ bool dst_ready = false;
+ bool dst_need_input = !e->IsControlEdge();
+ if (IsMerge(dst_node)) {
+ // A merge node is ready if a) all control edges are enabled and a
+ // live data input becomes available, or b) all control edges are
+ // enabled and all data inputs are dead.
+ if (e->IsControlEdge()) {
+ (*pending)[dst_id] -= 2;
+ int count = (*pending)[dst_id];
+ dst_dead = ((*dead_count)[dst_id] == dst_node->num_inputs());
+ dst_ready = (count == 1) || ((count == 0) && dst_dead);
+ } else {
+ if (outputs[src_slot].has_value) {
+ // This is a live data input.
+ int count = (*pending)[dst_id];
+ (*pending)[dst_id] |= 0x1;
+ dst_ready = (count == 0);
+ } else {
+ // This is a dead data input.
+ ++(*dead_count)[dst_id];
+ dst_dead = ((*dead_count)[dst_id] == dst_node->num_inputs());
+ dst_ready = ((*pending)[dst_id] == 0) && dst_dead;
+ }
+ // This input for dst is not needed if !dst_ready. We suppress the
+ // propagation to make the thread safety analysis happy.
+ dst_need_input = dst_ready;
+ }
+ } else {
+ // A non-merge node is ready if all its inputs are ready. We wait
+ // for all inputs to come in even if we know the node is dead. This
+ // ensures that all input tensors get cleaned up.
+ if (is_dead || (!e->IsControlEdge() && !outputs[src_slot].has_value)) {
+ ++(*dead_count)[dst_id];
+ }
+ dst_dead = (*dead_count)[dst_id] > 0;
+ dst_ready = (--(*pending)[dst_id] == 0);
+ }
+
+ if (dst_need_input) {
+ const NodeItem& dst_item = nodes[dst_id];
+ const int dst_slot = e->dst_input();
+ std::vector<Entry>* input_tensors = output_iter_state->input_tensors;
+ int dst_loc = dst_item.input_start + dst_slot;
+ (*input_tensors)[dst_loc] = outputs[src_slot];
+ }
+
+ // Add dst to the ready queue if it's ready
+ if (dst_ready) {
+ dst_dead = dst_dead && !IsControlTrigger(dst_node);
+ ready->push_back(
+ TaggedNode(dst_node, output_frame, output_iter, dst_dead));
+ output_iter_state->outstanding_ops++;
+ }
+ }
+}
+
+void ExecutorState::ActivateNexts(FrameState* frame, int64 iter,
+ TaggedNodeSeq* ready) {
+ // Propagate the deferred NextIteration nodes to the new iteration.
+ for (auto& node_entry : frame->next_iter_roots) {
+ const Node* node = node_entry.first;
+ const Entry& entry = node_entry.second;
+ const bool is_dead = !entry.has_value;
+ ActivateNode(node, is_dead, frame, iter, {entry}, ready);
+ }
+ frame->next_iter_roots.clear();
+}
+
+void ExecutorState::ActivateLoopInvs(FrameState* frame, int64 iter,
+ TaggedNodeSeq* ready) {
+ // Propagate loop invariants to the new iteration.
+ for (auto& node_entry : frame->inv_values) {
+ const Node* node = node_entry.first;
+ const Entry& entry = node_entry.second;
+ const bool is_dead = !entry.has_value;
+ ActivateNode(node, is_dead, frame, iter, {entry}, ready);
+ }
+}
+
+void ExecutorState::AddLoopInv(FrameState* frame, const Node* node,
+ const Entry& entry, TaggedNodeSeq* ready) {
+ // Store this value.
+ frame->inv_values.push_back({node, entry});
+
+ // Make this value available to all iterations.
+ bool is_dead = !entry.has_value;
+ for (int i = 1; i <= frame->iteration_count; ++i) {
+ ActivateNode(node, is_dead, frame, i, {entry}, ready);
+ }
+}
+
+bool ExecutorState::NodeDone(const Status& s, const Node* node,
+ const TaggedNodeSeq& ready, NodeExecStats* stats,
+ std::deque<TaggedNode>* inline_ready) {
+ if (stats_collector_) {
+ nodestats::SetAllEnd(stats);
+ if (!SetTimelineLabel(node, stats)) {
+ // Only record non-transfer nodes.
+ stats_collector_->Save(impl_->params_.device->name(), stats);
+ } else {
+ delete stats;
+ }
+ }
+
+ Rendezvous* captured_rendezvous = nullptr; // Will be set on error.
+ if (!s.ok()) {
+ // Some error happened. This thread of computation is done.
+ mutex_lock l(mu_);
+ if (status_.ok()) {
+ captured_rendezvous = rendezvous_;
+ if (captured_rendezvous) captured_rendezvous->Ref();
+ status_ = s;
+ }
+ }
+ if (captured_rendezvous) {
+ // If we captured the rendezvous_ pointer, we are in an error condition.
+ // Use captured_rendezvous, in case "this" is deleted by another thread.
+ TRACEPRINTF("StartAbort: %s", s.ToString().c_str());
+ captured_rendezvous->StartAbort(s);
+ captured_rendezvous->Unref();
+ }
+
+ bool completed = false;
+ int ready_size = ready.size();
+ if (ready_size == 0 || !s.ok()) {
+ completed = (num_outstanding_ops_.fetch_sub(1) == 1);
+ } else if (ready_size > 1) {
+ num_outstanding_ops_.fetch_add(ready_size - 1, std::memory_order_relaxed);
+ }
+
+ // Schedule the ready nodes in 'ready'.
+ if (s.ok()) {
+ ScheduleReady(ready, inline_ready);
+ }
+ return completed;
+}
+
+void ExecutorState::ProcessInline(const std::deque<TaggedNode>& inline_ready) {
+ if (inline_ready.empty()) return;
+ int64 scheduled_usec = 0;
+ if (stats_collector_) {
+ scheduled_usec = nodestats::NowInUsec();
+ }
+ for (auto& tagged_node : inline_ready) {
+ Process(tagged_node, scheduled_usec);
+ }
+}
+
+void ExecutorState::ScheduleReady(const TaggedNodeSeq& ready,
+ std::deque<TaggedNode>* inline_ready) {
+ if (ready.empty()) return;
+
+ int64 scheduled_usec = 0;
+ if (stats_collector_) {
+ scheduled_usec = nodestats::NowInUsec();
+ }
+ if (inline_ready == nullptr) {
+ // Schedule to run all the ready ops in thread pool.
+ for (auto& tagged_node : ready) {
+ runner_(std::bind(&ME::Process, this, tagged_node, scheduled_usec));
+ }
+ return;
+ }
+ const std::vector<NodeItem>& nodes = impl_->nodes_;
+ const TaggedNode* curr_expensive_node = nullptr;
+ for (auto& tagged_node : ready) {
+ const NodeItem& item = nodes[tagged_node.node->id()];
+ if (tagged_node.is_dead || !item.kernel->IsExpensive()) {
+ // Inline this inexpensive node.
+ inline_ready->push_back(tagged_node);
+ } else {
+ if (curr_expensive_node) {
+ // Dispatch to another thread since there is plenty of work to
+ // do for this thread.
+ runner_(std::bind(&ME::Process, this, *curr_expensive_node,
+ scheduled_usec));
+ }
+ curr_expensive_node = &tagged_node;
+ }
+ }
+ if (curr_expensive_node) {
+ if (inline_ready->empty()) {
+ // Tail recursion optimization
+ inline_ready->push_back(*curr_expensive_node);
+ } else {
+      // There are inline nodes to run already. We dispatch this expensive
+      // node to another thread.
+ runner_(
+ std::bind(&ME::Process, this, *curr_expensive_node, scheduled_usec));
+ }
+ }
+}
+
+void ExecutorState::Finish() {
+ mu_.lock();
+ auto status = status_;
+ auto done_cb = done_cb_;
+ auto runner = runner_;
+ mu_.unlock();
+ delete this;
+ CHECK(done_cb != nullptr);
+ runner([done_cb, status]() { done_cb(status); });
+}
+
+bool ExecutorState::IsFrameDone(FrameState* frame) {
+ return (frame->num_pending_inputs == 0 &&
+ frame->num_outstanding_iterations == 0);
+}
+
+bool ExecutorState::IsIterationDone(FrameState* frame, int64 iter) {
+ IterationState* iter_state = frame->GetIteration(iter);
+ if (iter_state->outstanding_ops == 0 &&
+ iter_state->outstanding_frame_count == 0) {
+ if (iter == 0) {
+ // The enclosing frame has no pending input.
+ return frame->num_pending_inputs == 0;
+ } else {
+ // The preceding iteration is deleted (and therefore done).
+ return (frame->GetIteration(iter - 1) == nullptr);
+ }
+ }
+ return false;
+}
+
+void ExecutorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
+ const Node* node,
+ FrameState** child) {
+ // Get the child frame name.
+ string enter_name;
+ Status s = GetNodeAttr(node->def(), "frame_name", &enter_name);
+ CHECK(s.ok()) << s;
+ const string child_name = MakeFrameName(frame, iter, enter_name);
+
+ auto it = outstanding_frames_.find(child_name);
+ if (it != outstanding_frames_.end()) {
+ *child = it->second;
+ } else {
+ // Need to create a new frame instance.
+ VLOG(2) << "Create frame: " << child_name;
+
+ FrameState* temp = new FrameState;
+ temp->frame_name = child_name;
+ temp->frame_id = Hash64(child_name);
+ temp->parent_frame = frame;
+ temp->parent_iter = iter;
+ s = GetNodeAttr(node->def(), "parallel_iterations",
+ &temp->max_parallel_iterations);
+ CHECK(s.ok()) << s;
+ // 'iterations' is a fixed-length circular buffer.
+ temp->iterations.resize(temp->max_parallel_iterations + 1);
+ IterationState* iter_state = new IterationState;
+ temp->iterations[0] = iter_state;
+
+ iter_state->outstanding_ops = 0;
+ iter_state->outstanding_frame_count = 0;
+ iter_state->pending_count = new std::vector<int>;
+ InitializePending(impl_->graph_, iter_state->pending_count);
+ iter_state->dead_count =
+ new std::vector<int>(impl_->graph_->num_node_ids());
+ iter_state->input_tensors = new std::vector<Entry>(impl_->total_tensors_);
+
+ auto frame_pending = impl_->frame_input_count_.find(enter_name);
+ DCHECK(frame_pending != impl_->frame_input_count_.end());
+ temp->num_pending_inputs = frame_pending->second;
+ temp->num_outstanding_iterations = 1;
+ *child = temp;
+
+ frame->GetIteration(iter)->outstanding_frame_count++;
+ outstanding_frames_[child_name] = temp;
+ }
+}
+
+void ExecutorState::IncrementIteration(FrameState* frame,
+ TaggedNodeSeq* ready) {
+ frame->iteration_count++;
+ int64 next_iter = frame->iteration_count;
+
+ VLOG(2) << "Create iteration: [" << frame->frame_name << ", " << next_iter
+ << "]";
+
+ IterationState* iter_state = new IterationState;
+ frame->SetIteration(next_iter, iter_state);
+ frame->num_outstanding_iterations++;
+ frame->dead_exits.clear();
+
+ iter_state->outstanding_ops = 0;
+ iter_state->outstanding_frame_count = 0;
+ iter_state->pending_count = new std::vector<int>;
+ InitializePending(impl_->graph_, iter_state->pending_count);
+ iter_state->dead_count = new std::vector<int>(impl_->graph_->num_node_ids());
+ iter_state->input_tensors = new std::vector<Entry>(impl_->total_tensors_);
+
+ // Activate the successors of the deferred roots in the new iteration.
+ ActivateNexts(frame, next_iter, ready);
+
+ // Activate the loop invariants in the new iteration.
+ ActivateLoopInvs(frame, next_iter, ready);
+}
+
+bool ExecutorState::SetOutputFrameIter(const TaggedNode& tagged_node,
+ const EntryVector& outputs,
+ FrameState** output_frame,
+ int64* output_iter,
+ TaggedNodeSeq* ready) {
+ const Node* node = tagged_node.node;
+ FrameState* input_frame = tagged_node.input_frame;
+ int64 input_iter = tagged_node.input_iter;
+ bool is_dead = tagged_node.is_dead;
+ bool is_enter = IsEnter(node);
+
+ if (is_enter) {
+ FindOrCreateChildFrame(input_frame, input_iter, node, output_frame);
+ // Propagate if this is a loop invariant.
+ bool is_constant;
+ Status s = GetNodeAttr(node->def(), "is_constant", &is_constant);
+ CHECK(s.ok()) << s;
+ if (is_constant) {
+ AddLoopInv(*output_frame, node, outputs[0], ready);
+ }
+ --(*output_frame)->num_pending_inputs;
+ *output_iter = 0;
+ } else if (IsExit(node)) {
+ if (is_dead) {
+ // Stop and remember this node if it is a dead exit.
+ if (input_iter == input_frame->iteration_count) {
+ input_frame->dead_exits.push_back(node);
+ }
+ *output_frame = nullptr;
+ } else {
+ *output_frame = input_frame->parent_frame;
+ *output_iter = input_frame->parent_iter;
+ }
+ } else if (IsNextIteration(node)) {
+ if (is_dead) {
+ // Stop the deadness propagation
+ *output_frame = nullptr;
+ } else {
+ if (input_iter == input_frame->iteration_count &&
+ input_frame->num_outstanding_iterations ==
+ input_frame->max_parallel_iterations) {
+ // Reached the maximum for parallel iterations.
+ input_frame->next_iter_roots.push_back({node, outputs[0]});
+ *output_frame = nullptr;
+ } else {
+ // If this is a new iteration, start it.
+ if (input_iter == input_frame->iteration_count) {
+ IncrementIteration(input_frame, ready);
+ }
+ *output_iter = input_iter + 1;
+ }
+ }
+ }
+ return is_enter;
+}
+
+void ExecutorState::CleanupFramesIterations(FrameState* frame, int64 iter,
+ TaggedNodeSeq* ready) {
+ int64 curr_iter = iter;
+ while (curr_iter <= frame->iteration_count &&
+ IsIterationDone(frame, curr_iter)) {
+ // Delete the iteration curr_iter
+ VLOG(2) << "Delete iteration [" << frame->frame_name << ", " << curr_iter
+ << "].";
+
+ delete frame->GetIteration(curr_iter);
+ frame->SetIteration(curr_iter, nullptr);
+ --frame->num_outstanding_iterations;
+ ++curr_iter;
+
+ // If there is a deferred iteration, start it.
+ if (frame->next_iter_roots.size() > 0) {
+ IncrementIteration(frame, ready);
+ }
+ }
+
+ if (IsFrameDone(frame)) {
+ FrameState* parent_frame = frame->parent_frame;
+ int64 parent_iter = frame->parent_iter;
+
+ // Propagate all the dead exits to the parent frame.
+ for (const Node* node : frame->dead_exits) {
+ auto parent_iter_state = parent_frame->GetIteration(parent_iter);
+ std::vector<int>* pending = parent_iter_state->pending_count;
+ std::vector<int>* dead_count = parent_iter_state->dead_count;
+ for (const Edge* e : node->out_edges()) {
+ const Node* dst_node = e->dst();
+ const int dst_id = dst_node->id();
+
+ bool dst_dead = true;
+ bool dst_ready = false;
+ // We know this is a dead input to dst
+ if (IsMerge(dst_node)) {
+ if (e->IsControlEdge()) {
+ (*pending)[dst_id] -= 2;
+ int count = (*pending)[dst_id];
+ dst_dead = ((*dead_count)[dst_id] == dst_node->num_inputs());
+ dst_ready = (count == 1) || ((count == 0) && dst_dead);
+ } else {
+ ++(*dead_count)[dst_id];
+ dst_dead = ((*dead_count)[dst_id] == dst_node->num_inputs());
+ dst_ready = ((*pending)[dst_id] == 0) && dst_dead;
+ }
+ } else {
+ ++(*dead_count)[dst_id];
+ dst_ready = (--(*pending)[dst_id] == 0);
+ }
+ if (dst_ready) {
+ ready->push_back(
+ TaggedNode(dst_node, parent_frame, parent_iter, dst_dead));
+ parent_iter_state->outstanding_ops++;
+ }
+ }
+ }
+
+ // Delete the frame
+ const string& frame_name = frame->frame_name;
+ VLOG(2) << "Delete frame " << frame_name;
+ outstanding_frames_.erase(frame_name);
+ delete frame;
+
+ // Cleanup recursively
+ if (parent_frame != nullptr) {
+ parent_frame->GetIteration(parent_iter)->outstanding_frame_count--;
+ CleanupFramesIterations(parent_frame, parent_iter, ready);
+ }
+ }
+}
+
+// When the ExecutorImpl graph has no control flow nodes,
+// SimpleExecutorState is used instead of ExecutorState. It maintains
+// less internal state and is convenient for experimenting with async
+// op kernels.
+class SimpleExecutorState {
+ public:
+ SimpleExecutorState(const Executor::Args& args, ExecutorImpl* impl);
+ ~SimpleExecutorState() {
+ for (auto it : device_context_map_) {
+ it.second->Unref();
+ }
+ delete slice_reader_cache_;
+ }
+ void RunAsync(Executor::DoneCallback done);
+
+ private:
+ typedef SimpleExecutorState ME;
+
+ // Not owned.
+ Rendezvous* rendezvous_;
+ StepStatsCollector* stats_collector_;
+ checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache_;
+ FunctionCallFrame* call_frame_;
+ const ExecutorImpl* impl_;
+ CancellationManager* cancellation_manager_;
+ Executor::Args::Runner runner_;
+
+ // Owned.
+
+  // The i-th node's j-th input is in
+  // input_tensors_[impl_->nodes[i].input_start + j]. Each entry is either
+  // a tensor pointer (pass-by-reference) or a tensor (pass-by-value).
+  //
+  // NOTE: Not protected by mu_ because input_tensors_ is resized once.
+  // Each element of input_tensors_ is written once by the source node of
+  // an edge and is cleared by the destination of the same edge. The
+  // latter node is never run concurrently with the former node.
+ struct Entry {
+ Tensor val = *kEmptyTensor; // A tensor value.
+ Tensor* ref = nullptr; // A tensor reference.
+ mutex* ref_mu = nullptr; // mutex for *ref if ref is not nullptr.
+
+ // Every entry carries an optional DeviceContext containing
+ // Device-specific information about how the Tensor was produced.
+ DeviceContext* device_context = nullptr;
+
+ // The attributes of the allocator that creates the tensor.
+ AllocatorAttributes alloc_attr;
+ };
+
+ // Contains a map from node id to the DeviceContext object that was
+ // assigned by the device at the beginning of a step.
+ DeviceContextMap device_context_map_;
+
+ std::vector<Entry> input_tensors_;
+
+ // Step-local resource manager.
+ ResourceMgr step_resource_manager_;
+
+ // Invoked when the execution finishes.
+ Executor::DoneCallback done_cb_;
+
+ // How many active threads of computation are being used. Same as
+ // the number of pending Process() functions.
+ std::atomic_int_fast32_t num_active_;
+
+ mutex mu_;
+ Status status_ GUARDED_BY(mu_);
+
+ // i-th kernel is still waiting for pending[i] inputs.
+ class CountDown {
+ public:
+ CountDown() : v_(0) {}
+ void Set(int32 v) { v_.store(v); }
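+    // Decrements the count and returns true iff it reaches zero, i.e.
+    // this caller supplied the last pending input. The initial load is a
+    // fast path that skips the atomic read-modify-write when this caller
+    // is already known to be the last one.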
+ bool Dec() {
+ return v_.load(std::memory_order_acquire) == 1 || v_.fetch_sub(1) == 1;
+ }
+
+ private:
+ std::atomic_int_fast32_t v_;
+ };
+ std::vector<CountDown> pending_;
+
+  // Processes the node identified by "id" in the current thread.
+  // "scheduled_usec" indicates when the node became ready and was scheduled.
+ void Process(int id, int64 scheduled_usec);
+
+ // Before invoking item->kernel, fills in its "inputs".
+ Status PrepareInputs(const NodeItem& item, TensorValueVec* inputs,
+ DeviceContextVec* input_device_contexts);
+
+ // After item->kernel computation is done, processes its outputs
+ // and returns nodes that become "ready".
+ typedef gtl::InlinedVector<int, 8> ReadyNodeIds;
+ Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
+ ReadyNodeIds* ready, NodeExecStats* stats);
+
+  // "node" has just finished. Takes ownership of "stats". Returns true if
+  // execution has completed.
+ bool NodeDone(const Status& s, const Node* node, const ReadyNodeIds& ready,
+ NodeExecStats* stats, std::deque<int>* inline_ready);
+
+ // Call Process() on all nodes in 'inline_ready'.
+ void ProcessInline(const std::deque<int>& inline_ready);
+
+ // Schedule all the expensive nodes in 'ready', and put all the inexpensive
+ // nodes in 'ready' into 'inline_ready'.
+ void ScheduleReady(const ReadyNodeIds& ready, std::deque<int>* inline_ready);
+
+ // One thread of control finishes.
+ void Finish();
+
+ TF_DISALLOW_COPY_AND_ASSIGN(SimpleExecutorState);
+};
+
+SimpleExecutorState::SimpleExecutorState(const Executor::Args& args,
+ ExecutorImpl* impl)
+ : rendezvous_(args.rendezvous),
+ stats_collector_(args.stats_collector),
+ slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper),
+ call_frame_(args.call_frame),
+ impl_(impl),
+ cancellation_manager_(args.cancellation_manager),
+ runner_(args.runner),
+ num_active_(0),
+ pending_(impl_->nodes_.size()) {}
+
+void SimpleExecutorState::ProcessInline(const std::deque<int>& inline_ready) {
+ if (inline_ready.empty()) return;
+ int64 scheduled_usec = 0;
+ if (stats_collector_) {
+ scheduled_usec = nodestats::NowInUsec();
+ }
+ for (int id : inline_ready) {
+ Process(id, scheduled_usec);
+ }
+}
+
+void SimpleExecutorState::ScheduleReady(const ReadyNodeIds& ready,
+ std::deque<int>* inline_ready) {
+ if (ready.empty()) return;
+
+ int64 scheduled_usec = 0;
+ if (stats_collector_) {
+ scheduled_usec = nodestats::NowInUsec();
+ }
+ if (inline_ready == nullptr) {
+ // Schedule to run all the ready ops in thread pool.
+ for (auto id : ready) {
+ runner_(std::bind(&ME::Process, this, id, scheduled_usec));
+ }
+ return;
+ }
+ const std::vector<NodeItem>& nodes = impl_->nodes_;
+ int curr_expensive_node = -1;
+ for (auto id : ready) {
+ if (!nodes[id].kernel->IsExpensive()) {
+ // Inline this inexpensive node.
+ inline_ready->push_back(id);
+ } else {
+ if (curr_expensive_node != -1) {
+ // Dispatch to another thread since there is plenty of work to
+ // do for this thread.
+ runner_(
+ std::bind(&ME::Process, this, curr_expensive_node, scheduled_usec));
+ }
+ curr_expensive_node = id;
+ }
+ }
+ if (curr_expensive_node != -1) {
+ if (inline_ready->empty()) {
+ // Tail recursion optimization
+ inline_ready->push_back(curr_expensive_node);
+ } else {
+      // There are inline nodes to run already. We dispatch this expensive
+      // node to another thread.
+ runner_(
+ std::bind(&ME::Process, this, curr_expensive_node, scheduled_usec));
+ }
+ }
+}
+
+void SimpleExecutorState::RunAsync(Executor::DoneCallback done) {
+ const Graph* graph = impl_->graph_;
+ ReadyNodeIds ready;
+
+ // Ask the device to fill in the device context map.
+ Device* device = impl_->params_.device;
+ device->FillContextMap(graph, &device_context_map_);
+
+ for (const Node* n : graph->nodes()) {
+ const int id = n->id();
+ const int num_in_edges = n->in_edges().size();
+ pending_[id].Set(num_in_edges);
+ if (num_in_edges == 0) {
+ ready.push_back(id);
+ }
+ }
+ if (ready.empty()) {
+ done(Status::OK());
+ } else {
+ num_active_ = ready.size();
+ done_cb_ = done;
+ input_tensors_.resize(impl_->total_tensors_);
+ // Schedule to run all the ready ops in thread pool.
+ ScheduleReady(ready, nullptr);
+ }
+}
+
+Status SimpleExecutorState::PrepareInputs(
+ const NodeItem& item, TensorValueVec* inputs,
+ DeviceContextVec* input_device_contexts) {
+ const Node* node = item.node;
+
+ inputs->clear();
+ inputs->resize(node->num_inputs());
+ input_device_contexts->clear();
+ input_device_contexts->resize(node->num_inputs());
+
+ for (int i = 0; i < node->num_inputs(); ++i) {
+ const bool expect_ref = IsRefType(node->input_type(i));
+ Entry* entry = input_tensors_.data() + item.input_start + i;
+ (*input_device_contexts)[i] = entry->device_context;
+
+ // i-th input.
+ TensorValue* inp = &(*inputs)[i];
+
+ if (entry->ref == nullptr) {
+ if (expect_ref) {
+ return AttachDef(
+ errors::InvalidArgument(i, "-th input expects a ref type"),
+ item.kernel->def());
+ }
+ inp->tensor = &entry->val;
+ } else {
+ if (!entry->ref->IsInitialized() && !IsInitializationOp(item.node)) {
+ return AttachDef(
+ errors::FailedPrecondition("Attempting to use uninitialized value ",
+ item.kernel->def().input(i)),
+ item.kernel->def());
+ }
+ if (expect_ref) {
+ inp->mutex_if_ref = entry->ref_mu;
+ inp->tensor = entry->ref;
+ } else {
+ // Automatically deref the tensor ref when the op expects a
+ // tensor but is given a ref to a tensor. Need to deref it
+ // under the mutex.
+ {
+ mutex_lock l(*(entry->ref_mu));
+ entry->val = *entry->ref;
+ }
+ inp->tensor = &entry->val;
+ }
+ }
+ }
+ return Status::OK();
+}
+
+void SimpleExecutorState::Process(int id, int64 scheduled_usec) {
+ const std::vector<NodeItem>& nodes = impl_->nodes_;
+ ReadyNodeIds ready;
+ std::deque<int> inline_ready;
+
+ // Parameters passed to OpKernel::Compute.
+ TensorValueVec inputs;
+ DeviceContextVec input_device_contexts;
+
+ OpKernelContext::Params params;
+ Device* device = impl_->params_.device;
+ params.device = device;
+ // track allocations if and only if we are collecting statistics
+ params.track_allocations = (stats_collector_ != nullptr);
+ params.rendezvous = rendezvous_;
+ params.cancellation_manager = cancellation_manager_;
+ params.call_frame = call_frame_;
+ params.function_library = impl_->params_.function_library;
+ params.resource_manager = device->resource_manager();
+ params.step_resource_manager = &step_resource_manager_;
+ params.slice_reader_cache = slice_reader_cache_;
+ params.inputs = &inputs;
+ params.input_device_contexts = &input_device_contexts;
+ params.frame_iter = FrameAndIter(0, 0);
+
+ Status s;
+ NodeExecStats* stats = nullptr;
+ bool completed = false;
+ inline_ready.push_back(id);
+ while (!inline_ready.empty()) {
+ id = inline_ready.front();
+ inline_ready.pop_front();
+ const NodeItem& item = nodes[id];
+ const Node* node = item.node;
+
+ // Set the device_context for this node id, if it exists.
+ auto dc_it = device_context_map_.find(id);
+ if (dc_it != device_context_map_.end()) {
+ params.op_device_context = dc_it->second;
+ }
+
+ if (stats_collector_) {
+ stats = new NodeExecStats;
+ stats->set_node_name(node->name());
+ nodestats::SetScheduled(stats, scheduled_usec);
+ nodestats::SetAllStart(stats);
+ }
+
+ VLOG(1) << "Process node: " << id << " " << SummarizeNodeDef(node->def());
+
+ // Prepares inputs.
+ s = PrepareInputs(item, &inputs, &input_device_contexts);
+ if (!s.ok()) {
+ // Continue to process the nodes in 'inline_ready'.
+ completed = NodeDone(s, item.node, ready, stats, &inline_ready);
+ continue;
+ }
+
+ OpKernel* op_kernel = item.kernel;
+ params.op_kernel = op_kernel;
+ params.output_alloc_attr = [this, node, op_kernel](int index) {
+ return OutputAttributes(&impl_->alloc_attr_, node, op_kernel, index);
+ };
+
+ // Asynchronous computes.
+ AsyncOpKernel* async = op_kernel->AsAsync();
+ if (async) {
+ auto pcopy = CopyParams(params);
+ auto ctx = new OpKernelContext(*pcopy);
+ auto done = [this, item, ctx, stats, pcopy]() {
+ VLOG(2) << this
+ << " Async kernel done: " << SummarizeNodeDef(item.node->def());
+ if (stats_collector_) nodestats::SetOpEnd(stats);
+ ReadyNodeIds ready;
+ Status s = ProcessOutputs(item, ctx, &ready, stats);
+ if (stats_collector_) nodestats::SetMemory(stats, ctx);
+ // Schedule to run all the ready ops in thread pool.
+ bool completed = NodeDone(s, item.node, ready, stats, nullptr);
+ delete ctx;
+ DeleteParams(pcopy);
+ if (completed) Finish();
+ };
+ if (stats_collector_) nodestats::SetOpStart(stats);
+ device->ComputeAsync(async, ctx, done);
+ } else {
+ // Synchronous computes.
+ OpKernelContext ctx(params);
+ if (stats_collector_) nodestats::SetOpStart(stats);
+ device->Compute(CHECK_NOTNULL(op_kernel), &ctx);
+ if (stats_collector_) nodestats::SetOpEnd(stats);
+
+ s = ProcessOutputs(item, &ctx, &ready, stats);
+ if (stats_collector_) nodestats::SetMemory(stats, &ctx);
+ if (stats_collector_) {
+ scheduled_usec = nodestats::NowInUsec();
+ }
+ completed = NodeDone(s, node, ready, stats, &inline_ready);
+ }
+ } // while !inline_ready.empty()
+
+ // This thread of computation is done if completed = true.
+ if (completed) Finish();
+}
+
+bool SimpleExecutorState::NodeDone(const Status& s, const Node* node,
+ const ReadyNodeIds& ready,
+ NodeExecStats* stats,
+ std::deque<int>* inline_ready) {
+ if (stats_collector_) {
+ nodestats::SetAllEnd(stats);
+ if (!SetTimelineLabel(node, stats)) {
+ // Only record non-transfer nodes.
+ stats_collector_->Save(impl_->params_.device->name(), stats);
+ } else {
+ delete stats;
+ }
+ }
+
+ Rendezvous* captured_rendezvous = nullptr; // Will be set on error.
+ if (!s.ok()) {
+ // Some error happened. This thread of computation is done.
+ mutex_lock l(mu_);
+ if (status_.ok()) {
+ captured_rendezvous = rendezvous_;
+ if (captured_rendezvous) captured_rendezvous->Ref();
+ status_ = s;
+ }
+ }
+ if (captured_rendezvous) {
+ // If we captured the rendezvous_ pointer, we are in an error condition.
+ // Use captured_rendezvous, in case "this" is deleted by another thread.
+ TRACEPRINTF("StartAbort: %s", s.ToString().c_str());
+ captured_rendezvous->StartAbort(s);
+ captured_rendezvous->Unref();
+ }
+
+ bool completed = false;
+ int ready_size = ready.size();
+ if (ready_size == 0 || !s.ok()) {
+ completed = (num_active_.fetch_sub(1) == 1);
+ } else if (ready_size > 1) {
+ num_active_.fetch_add(ready_size - 1, std::memory_order_relaxed);
+ }
+
+ // Schedule the ready nodes in 'ready'.
+ if (s.ok()) {
+ ScheduleReady(ready, inline_ready);
+ }
+ return completed;
+}
+
+void SimpleExecutorState::Finish() {
+ mu_.lock();
+ auto ret = status_;
+ auto done_cb = done_cb_;
+ auto runner = runner_;
+ mu_.unlock();
+ delete this;
+ CHECK(done_cb != nullptr);
+ runner([done_cb, ret]() { done_cb(ret); });
+}
+
+Status SimpleExecutorState::ProcessOutputs(const NodeItem& item,
+ OpKernelContext* ctx,
+ ReadyNodeIds* ready,
+ NodeExecStats* stats) {
+ Status s = ctx->status();
+ if (!s.ok()) {
+ s = AttachDef(s, item.kernel->def());
+ LOG(WARNING) << this << " Compute status: " << s;
+ return s;
+ }
+
+ // Processes outputs.
+ gtl::InlinedVector<Entry, 4> outputs;
+ const Node* node = item.node;
+ outputs.resize(node->num_outputs());
+
+ // Get the device_context for this node id, if it exists.
+ DeviceContext* device_context = nullptr;
+ auto dc_it = device_context_map_.find(node->id());
+ if (dc_it != device_context_map_.end()) {
+ device_context = dc_it->second;
+ }
+
+ for (int i = 0; i < node->num_outputs(); ++i) {
+ TensorValue val = ctx->release_output(i);
+ // Sanity check of output tensor types.
+ DataType dtype = val->dtype();
+ if (val.is_ref()) dtype = MakeRefType(dtype);
+ if (dtype == node->output_type(i)) {
+ Entry* out = &(outputs[i]);
+ if (val.is_ref()) {
+ out->ref = val.tensor;
+ out->ref_mu = val.mutex_if_ref;
+ } else {
+ out->val = *val.tensor;
+ }
+
+ // Set the device context of the output entry.
+ out->device_context = device_context;
+
+ // Set the allocator attributes of the output entry.
+ out->alloc_attr = ctx->output_alloc_attr(i);
+
+ if (stats_collector_ && val.tensor->IsInitialized()) {
+ nodestats::SetOutput(stats, i, ctx->output_allocation_type(i),
+ val.tensor);
+ }
+ } else {
+ s.Update(
+ errors::Internal("Output ", i, " of type ", DataTypeString(dtype),
+ " does not match declared output type ",
+ DataTypeString(node->output_type(i)),
+ " for operation ", SummarizeNodeDef(node->def())));
+ }
+ if (!val.is_ref()) {
+ // If OpKernelContext returns outputs via pass-by-value, we
+ // don't need this trouble.
+ delete val.tensor;
+ }
+ }
+ if (!s.ok()) return s;
+
+ // Clears inputs.
+ for (int i = 0; i < node->num_inputs(); ++i) {
+ input_tensors_[item.input_start + i].val = *kEmptyTensor;
+ }
+
+ // Propagates outputs along out edges.
+ ready->clear();
+ const std::vector<NodeItem>& nodes = impl_->nodes_;
+ for (const Edge* e : node->out_edges()) {
+ const int src_slot = e->src_output();
+ const int dst_id = e->dst()->id();
+ const NodeItem& dst_item = nodes[dst_id];
+ if (!e->IsControlEdge()) {
+ const int dst_slot = e->dst_input();
+ input_tensors_[dst_item.input_start + dst_slot] = outputs[src_slot];
+ }
+ if (pending_[dst_id].Dec()) {
+ ready->push_back(dst_id);
+ }
+ }
+ return Status::OK();
+}
+
+// NOTE(yuanbyu): Use the executor that supports control flow by default.
+const bool use_control_flow_executor = true;
+void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
+ if (params_.has_control_flow || use_control_flow_executor) {
+ (new ExecutorState(args, this))->RunAsync(done);
+ } else {
+ (new SimpleExecutorState(args, this))->RunAsync(done);
+ }
+}
+
+} // end namespace
+
+Status NewLocalExecutor(const LocalExecutorParams& params, const Graph* graph,
+ Executor** executor) {
+ ExecutorImpl* impl = new ExecutorImpl(params, graph);
+ Status s = impl->Initialize();
+ if (s.ok()) {
+ *executor = impl;
+ } else {
+ delete impl;
+ }
+ return s;
+}
+
+Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
+ const NodeDef& ndef, OpKernel** kernel) {
+ auto device_type = DeviceType(device->attributes().device_type());
+ auto allocator = device->GetAllocator(AllocatorAttributes());
+ return CreateOpKernel(device_type, device, allocator, flib, ndef, kernel);
+}
+
+void DeleteNonCachedKernel(OpKernel* kernel) { delete kernel; }
+
+Status CreateCachedKernel(Device* device, const string& session,
+ FunctionLibraryRuntime* flib, const NodeDef& ndef,
+ OpKernel** kernel) {
+ auto op_seg = device->op_segment();
+ auto create_fn = [device, flib, &ndef](OpKernel** kernel) {
+ return CreateNonCachedKernel(device, flib, ndef, kernel);
+ };
+ return op_seg->FindOrCreate(session, ndef.name(), kernel, create_fn);
+}
+
+// Deletes "kernel".
+void DeleteCachedKernel(Device* device, const string& session,
+ OpKernel* kernel) {
+ // Do nothing.
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h
new file mode 100644
index 0000000000..82bcbab836
--- /dev/null
+++ b/tensorflow/core/common_runtime/executor.h
@@ -0,0 +1,209 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_EXECUTOR_H_
+#define TENSORFLOW_COMMON_RUNTIME_EXECUTOR_H_
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+
+class StepStatsCollector;
+
+// Executor runs a graph computation.
+// Example:
+// Graph* graph = ...;
+// ... construct graph ...
+// Executor* executor;
+// TF_CHECK_OK(NewSimpleExecutor(my_device, graph, &executor));
+// Rendezvous* rendezvous = NewNaiveRendezvous();
+// TF_CHECK_OK(rendezvous->Send("input", some_input_tensor));
+// TF_CHECK_OK(executor->Run({ExecutorOpts, rendezvous, nullptr}));
+// TF_CHECK_OK(rendezvous->Recv("input", &output_tensor));
+// ... ...
+//
+// Multiple threads can call Executor::Run concurrently.
+class Executor {
+ public:
+ virtual ~Executor() {}
+
+ // RunAsync() executes the graph computation. "done" is run when the
+ // graph computation completes. If any error happens during the
+ // computation, "done" is run and the error is passed to "done".
+ //
+ // RunAsync() is given a few arguments in Args. The caller must
+ // ensure objects passed in Args (rendezvous, stats_collector, etc.)
+ // are alive at least until done is invoked. All pointers to the
+ // argument objects can be nullptr.
+ //
+ // RunAsync() uses the given "rendezvous", if not null, as the
+ // mechanism to communicate inputs and outputs of the underlying
+ // graph computation.
+ //
+ // RunAsync() calls "stats_collector", if not null, to keep track of
+ // stats. This allows us to collect statistics and traces on demand.
+ //
+  // RunAsync() is given a "call_frame", which, if the executor is used
+  // for executing a function, is used to pass arguments and return
+  // values between the caller and the callee.
+ //
+ // RunAsync() uses "cancellation_manager", if not nullptr, to
+ // register callbacks that should be called if the graph computation
+ // is cancelled. Note that the callbacks merely unblock any
+ // long-running computation, and a cancelled step will terminate by
+ // returning/calling the DoneCallback as usual.
+ //
+  // RunAsync() dispatches closures to "runner". Typically, "runner"
+  // is backed by a bounded threadpool.
+ struct Args {
+ Rendezvous* rendezvous = nullptr;
+ StepStatsCollector* stats_collector = nullptr;
+ FunctionCallFrame* call_frame = nullptr;
+ CancellationManager* cancellation_manager = nullptr;
+
+ typedef std::function<void()> Closure;
+ typedef std::function<void(Closure)> Runner;
+ Runner runner = nullptr;
+ };
+ typedef std::function<void(const Status&)> DoneCallback;
+ virtual void RunAsync(const Args& args, DoneCallback done) = 0;
+
+ // Synchronous wrapper for RunAsync().
+ Status Run(const Args& args) {
+ Status ret;
+ Notification n;
+ RunAsync(args, [&ret, &n](const Status& s) {
+ ret = s;
+ n.Notify();
+ });
+ n.WaitForNotification();
+ return ret;
+ }
+};
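+
+// A minimal sketch of setting up Args for RunAsync (assumes the caller has
+// created a thread::ThreadPool "pool" and a Rendezvous "rendezvous"; both
+// are outside the scope of this header):
+//
+//   Executor::Args args;
+//   args.rendezvous = rendezvous;
+//   args.runner = [&pool](Executor::Args::Closure c) { pool.Schedule(c); };
+//   executor->RunAsync(args, [](const Status& s) { TF_CHECK_OK(s); });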
+
+// Creates an Executor that computes the given "graph".
+//
+// If successful, returns the constructed executor in "*executor". The
+// caller keeps the ownership of "device". The returned executor takes
+// the ownership of "graph". Otherwise, returns an error status.
+//
+// "params" provides a set of context for the executor. We expect that
+// different contexts would provide different implementations.
+struct LocalExecutorParams {
+ Device* device;
+
+ // The library runtime support.
+ FunctionLibraryRuntime* function_library;
+
+ // True iff the computation contains control flow nodes.
+ bool has_control_flow;
+
+ // create_kernel returns an instance of op kernel based on NodeDef.
+ // delete_kernel is called for every kernel used by the executor
+ // when the executor is deleted.
+ std::function<Status(const NodeDef&, OpKernel**)> create_kernel;
+ std::function<void(OpKernel*)> delete_kernel;
+};
+::tensorflow::Status NewLocalExecutor(const LocalExecutorParams& params,
+ const Graph* graph, Executor** executor);
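+
+// Example (an illustrative sketch; "device", "flib" and "graph" are assumed
+// to be provided by the caller):
+//
+//   LocalExecutorParams params;
+//   params.device = device;
+//   params.function_library = flib;
+//   params.has_control_flow = false;
+//   params.create_kernel = [device, flib](const NodeDef& ndef,
+//                                         OpKernel** kernel) {
+//     return CreateNonCachedKernel(device, flib, ndef, kernel);
+//   };
+//   params.delete_kernel = [](OpKernel* kernel) {
+//     DeleteNonCachedKernel(kernel);
+//   };
+//   Executor* executor;
+//   TF_CHECK_OK(NewLocalExecutor(params, graph, &executor));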
+
+// A class to help run multiple executors in parallel and wait until
+// all of them are complete.
+//
+// ExecutorBarrier deletes itself after the function returned by Get()
+// is called.
+class ExecutorBarrier {
+ public:
+ typedef std::function<void(const Status&)> StatusCallback;
+
+ // Create an ExecutorBarrier for 'num' different executors.
+ //
+ // 'r' is the shared Rendezvous object that is used to communicate
+ // state. If any of the executors experiences an error, the
+ // rendezvous object will be aborted exactly once.
+ //
+ // 'done' is called after the last executor completes, and
+ // ExecutorBarrier is deleted.
+ ExecutorBarrier(int num, Rendezvous* r, StatusCallback done)
+ : rendez_(r), done_cb_(done), pending_(num) {}
+
+ ~ExecutorBarrier() {}
+
+ // Returns a closure that Executors must call when they are done
+ // computing, passing the status of their execution as an argument.
+ StatusCallback Get() {
+ return std::bind(&ExecutorBarrier::WhenDone, this, std::placeholders::_1);
+ }
+
+ private:
+ Rendezvous* rendez_ = nullptr;
+ StatusCallback done_cb_ = nullptr;
+
+ mutable mutex mu_;
+ int pending_ GUARDED_BY(mu_) = 0;
+ Status status_ GUARDED_BY(mu_);
+
+ void WhenDone(const Status& s) {
+ bool error = false;
+ StatusCallback done = nullptr;
+ Status status;
+ {
+ mutex_lock l(mu_);
+ // If we are the first error encountered, mark the status
+ // appropriately and later trigger an abort of the Rendezvous
+ // object by this thread only.
+ if (status_.ok() && !s.ok()) {
+ error = true;
+ status_ = s;
+ }
+
+ // If this is the last call to WhenDone, call the final callback
+ // below.
+ if (--pending_ == 0) {
+ CHECK(done_cb_ != nullptr);
+ done = done_cb_;
+ done_cb_ = nullptr;
+ }
+ status = status_;
+ }
+ if (error) {
+ rendez_->StartAbort(status);
+ }
+ if (done != nullptr) {
+ delete this;
+ done(status);
+ }
+ }
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ExecutorBarrier);
+};
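+
+// Example usage (an illustrative sketch; "executors" and "rendezvous" are
+// assumed to be created by the caller):
+//
+//   ExecutorBarrier* barrier = new ExecutorBarrier(
+//       executors.size(), rendezvous,
+//       [](const Status& s) { /* all executors have completed */ });
+//   for (Executor* exec : executors) {
+//     Executor::Args args;
+//     args.rendezvous = rendezvous;
+//     exec->RunAsync(args, barrier->Get());
+//   }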
+
+// A few helpers to facilitate create/delete kernels.
+
+// Creates a kernel based on "ndef" on device "device". The kernel can
+// access the functions in the "flib". The caller takes ownership of
+// returned "*kernel".
+Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
+ const NodeDef& ndef, OpKernel** kernel);
+
+// Deletes "kernel" returned by CreateNonCachedKernel.
+void DeleteNonCachedKernel(OpKernel* kernel);
+
+// Creates a kernel based on "ndef" on device "device". The kernel can
+// access the functions in the "flib". The caller does not take
+// ownership of returned "*kernel". If a kernel has been created for
+// ndef.name(), returns the same kernel instance.
+Status CreateCachedKernel(Device* device, const string& session,
+ FunctionLibraryRuntime* flib, const NodeDef& ndef,
+ OpKernel** kernel);
+
+// Deletes "kernel" returned by CreateCachedKernel.
+void DeleteCachedKernel(Device* device, const string& session,
+ OpKernel* kernel);
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_EXECUTOR_H_
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
new file mode 100644
index 0000000000..2b1a041235
--- /dev/null
+++ b/tensorflow/core/common_runtime/function.cc
@@ -0,0 +1,1335 @@
+#include "tensorflow/core/common_runtime/function.h"
+
+#include <deque>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/optimizer_cse.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+
+namespace tensorflow {
+
+// A few string constants used throughout this module.
+static const char* const kArgOp = "_Arg";
+static const char* const kRetOp = "_Retval";
+static const char* const kGradientOp = "SymbolicGradient";
+static const char* const kNodeLabel = "Func";
+
+// Represents the index-th output of a node.
+struct Endpoint {
+ Node* node;
+ int index;
+
+  // Returns the string name that represents this endpoint.
+ string name() const {
+ if (index == 0) {
+ return node->name();
+ } else {
+ return strings::StrCat(node->name(), ":", index);
+ }
+ }
+
+ DataType dtype() const { return node->output_type(index); }
+};
+
+struct EndpointHash {
+ uint64 operator()(const Endpoint& x) const {
+ return Hash64(reinterpret_cast<const char*>(&x.node), sizeof(Node*),
+ x.index);
+ }
+};
+
+struct EndpointEq {
+ bool operator()(const Endpoint& x, const Endpoint& y) const {
+ return (x.node == y.node) && (x.index == y.index);
+ }
+};
+
+// The following Add* routines are used to add a few graph nodes while
+// functions are transformed.
+static Node* AddNoOp(Graph* g) {
+ NodeDef ndef;
+ ndef.set_name(g->NewName(kNodeLabel));
+ ndef.set_op("NoOp");
+ Status s;
+ Node* ret = g->AddNode(ndef, &s);
+ TF_CHECK_OK(s);
+ return ret;
+}
+
+static Node* AddIdentity(Graph* g, Endpoint input) {
+ DCHECK_LT(0, input.dtype());
+ DCHECK_LT(input.dtype(), DT_FLOAT_REF);
+ NodeDef ndef;
+ ndef.set_name(g->NewName(kNodeLabel));
+ ndef.set_op("Identity");
+ ndef.add_input(input.name());
+ AddNodeAttr("T", input.dtype(), &ndef);
+ Status s;
+ Node* ret = g->AddNode(ndef, &s);
+ TF_CHECK_OK(s);
+ g->AddEdge(input.node, input.index, ret, 0);
+ return ret;
+}
+
+static Node* AddArg(Graph* g, DataType dtype, int index) {
+ DCHECK_LT(0, dtype);
+ DCHECK_LT(dtype, DT_FLOAT_REF);
+ NodeDef ndef;
+ ndef.set_name(g->NewName(kNodeLabel));
+ ndef.set_op(kArgOp);
+ AddNodeAttr("T", dtype, &ndef);
+ AddNodeAttr("index", index, &ndef);
+ Status s;
+ Node* ret = g->AddNode(ndef, &s);
+ TF_CHECK_OK(s);
+ return ret;
+}
+
+static Node* AddRet(Graph* g, Endpoint input, int index) {
+ DCHECK_LT(0, input.dtype());
+ DCHECK_LT(input.dtype(), DT_FLOAT_REF);
+ NodeDef ndef;
+ ndef.set_name(g->NewName(kNodeLabel));
+ ndef.set_op(kRetOp);
+ ndef.add_input(input.name());
+ AddNodeAttr("T", input.dtype(), &ndef);
+ AddNodeAttr("index", index, &ndef);
+ Status s;
+ Node* ret = g->AddNode(ndef, &s);
+ TF_CHECK_OK(s);
+ g->AddEdge(input.node, input.index, ret, 0);
+ return ret;
+}
+
+static Node* AddZerosLike(Graph* g, Endpoint input) {
+ DCHECK_LT(0, input.dtype());
+ DCHECK_LT(input.dtype(), DT_FLOAT_REF);
+ NodeDef ndef;
+ ndef.set_name(g->NewName(kNodeLabel));
+ ndef.set_op("ZerosLike");
+ ndef.add_input(input.name());
+ AddNodeAttr("T", input.dtype(), &ndef);
+ Status s;
+ Node* ret = g->AddNode(ndef, &s);
+ TF_CHECK_OK(s);
+ g->AddEdge(input.node, input.index, ret, 0);
+ return ret;
+}
+
+static Node* AddSymGrad(Graph* g, Node* n, gtl::ArraySlice<Endpoint> grads) {
+ const int num_x = n->num_inputs();
+ const int num_y = n->num_outputs();
+ CHECK_EQ(num_y, grads.size());
+
+ NodeDef ndef;
+ ndef.set_name(g->NewName(kNodeLabel));
+ ndef.set_op(kGradientOp);
+
+ // The gradient node should have num_x + num_y inputs.
+ std::vector<Endpoint> n_inputs(num_x);
+ for (const Edge* e : n->in_edges()) {
+ if (e->IsControlEdge()) continue;
+ n_inputs[e->dst_input()] = {e->src(), e->src_output()};
+ }
+ DataTypeVector in_types;
+ for (const Endpoint& ep : n_inputs) {
+ ndef.add_input(ep.name());
+ in_types.push_back(ep.dtype());
+ }
+ for (const Endpoint& ep : grads) {
+ ndef.add_input(ep.name());
+ in_types.push_back(ep.dtype());
+ }
+ CHECK_EQ(ndef.input_size(), num_x + num_y);
+
+ AddNodeAttr("Tin", in_types, &ndef);
+
+ // The gradient node's outputs have the same types as the node 'n's
+ // inputs.
+ AddNodeAttr("Tout", n->input_types(), &ndef);
+ NameAttrList func;
+ func.set_name(n->type_string());
+ *(func.mutable_attr()) = n->def().attr();
+ AddNodeAttr("f", func, &ndef);
+ Status s;
+ Node* ret = g->AddNode(ndef, &s);
+ TF_CHECK_OK(s);
+ return ret;
+}
+
+class ArgOp : public OpKernel {
+ public:
+ explicit ArgOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ auto frame = ctx->call_frame();
+ OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame"));
+ Tensor val;
+ OP_REQUIRES_OK(ctx, frame->GetArg(index_, &val));
+ OP_REQUIRES(ctx, val.dtype() == dtype_,
+ errors::InvalidArgument(
+ "Type mismatch: actual ", DataTypeString(val.dtype()),
+ " vs. expect ", DataTypeString(dtype_)));
+ ctx->set_output(0, val);
+ }
+
+ private:
+ int index_;
+ DataType dtype_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ArgOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("_Arg").Device(DEVICE_CPU), ArgOp);
+REGISTER_KERNEL_BUILDER(Name("_Arg").Device(DEVICE_GPU), ArgOp);
+
+class RetvalOp : public OpKernel {
+ public:
+ explicit RetvalOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& val = ctx->input(0);
+ OP_REQUIRES(ctx, val.dtype() == dtype_,
+ errors::InvalidArgument(
+ "Type mismatch: actual ", DataTypeString(val.dtype()),
+ " vs. expect ", DataTypeString(dtype_)));
+ auto frame = ctx->call_frame();
+ OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame"));
+ OP_REQUIRES_OK(ctx, frame->SetRetval(index_, val));
+ }
+
+ private:
+ int index_;
+ DataType dtype_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(RetvalOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("_Retval").Device(DEVICE_CPU), RetvalOp);
+REGISTER_KERNEL_BUILDER(Name("_Retval").Device(DEVICE_GPU), RetvalOp);
+
+static const FunctionLibraryRuntime::Handle kInvalidHandle = -1;
+
+class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
+ public:
+ FunctionLibraryRuntimeImpl(Device* device, Runner runner,
+ const FunctionLibraryDefinition* lib_def);
+
+ ~FunctionLibraryRuntimeImpl() override;
+
+ Status Instantiate(const string& function_name,
+ const InstantiateAttrValueMap& attrs,
+ Handle* handle) override;
+
+ const FunctionBody* GetFunctionBody(Handle handle) override;
+
+ Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) override;
+
+ void Run(const Options& opts, Handle handle, gtl::ArraySlice<Tensor> args,
+ std::vector<Tensor>* rets, DoneCallback done) override;
+
+ bool IsDefined(const string& function_name) override;
+
+ private:
+ typedef FunctionLibraryRuntimeImpl ME;
+
+ Device* const device_;
+ Runner runner_ = nullptr;
+ const FunctionLibraryDefinition* const lib_def_;
+ std::function<Status(const string&, const OpDef**)> get_func_sig_;
+ std::function<Status(const NodeDef&, OpKernel**)> create_kernel_;
+
+ mutable mutex mu_;
+
+ // Maps function instantiation to a handle. The key is a
+ // canonicalized representation of the function name and
+ // instantiation attrs. The handle is an index into the items_.
+ std::unordered_map<string, Handle> table_ GUARDED_BY(mu_);
+
+ // func_graphs_ never shrinks or reorders its members.
+ std::vector<FunctionBody*> func_graphs_ GUARDED_BY(mu_);
+
+ // The instantiated and transformed function is encoded as a Graph
+ // object, and an executor is created for the graph.
+ struct Item : public core::RefCounted {
+ Executor* exec = nullptr;
+
+ ~Item() override { delete this->exec; }
+ };
+ std::vector<Item*> items_;
+
+ Status FunctionDefToBody(const FunctionDef& fdef,
+ const InstantiateAttrValueMap& attrs,
+ FunctionBody** fbody);
+ Status CreateItem(Handle handle, Item** item);
+ Status GetOrCreateItem(Handle handle, Item** item);
+ Status InstantiateSymbolicGradient(const InstantiateAttrValueMap& attrs,
+ FunctionBody** g_body);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryRuntimeImpl);
+};
+
+FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
+ Device* device, Runner runner, const FunctionLibraryDefinition* lib_def)
+ : device_(device), runner_(runner), lib_def_(lib_def) {
+ get_func_sig_ = [this](const string& op, const OpDef** sig) {
+ Status s;
+ *sig = lib_def_->LookUp(op, &s);
+ return s;
+ };
+ create_kernel_ = [this](const NodeDef& ndef, OpKernel** kernel) {
+ return CreateKernel(ndef, kernel);
+ };
+}
+
+FunctionLibraryRuntimeImpl::~FunctionLibraryRuntimeImpl() {
+ for (FunctionBody* p : func_graphs_) delete p;
+ for (Item* item : items_)
+ if (item) item->Unref();
+}
+
+// An asynchronous op kernel which executes an instantiated function
+// defined in a library.
+class CallOp : public AsyncOpKernel {
+ public:
+ CallOp(FunctionLibraryRuntime::Handle handle, OpKernelConstruction* ctx)
+ : AsyncOpKernel(ctx), handle_(handle) {}
+
+ ~CallOp() override {}
+
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+ FunctionLibraryRuntime* lib = ctx->function_library();
+ OP_REQUIRES_ASYNC(ctx, lib != nullptr,
+ errors::Internal("No function library is provided."),
+ done);
+ FunctionLibraryRuntime::Options opts;
+ std::vector<Tensor> args;
+ args.reserve(ctx->num_inputs());
+ for (int i = 0; i < ctx->num_inputs(); ++i) {
+ args.push_back(ctx->input(i));
+ }
+ std::vector<Tensor>* rets = new std::vector<Tensor>;
+ lib->Run(opts, handle_, args, rets,
+ [ctx, done, rets](const Status& status) {
+ if (!status.ok()) {
+ ctx->SetStatus(status);
+ } else {
+ CHECK_EQ(rets->size(), ctx->num_outputs());
+ for (size_t i = 0; i < rets->size(); ++i) {
+ ctx->set_output(i, (*rets)[i]);
+ }
+ }
+ delete rets;
+ done();
+ });
+ }
+
+ private:
+ FunctionLibraryRuntime::Handle handle_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(CallOp);
+};
+
+const FunctionBody* FunctionLibraryRuntimeImpl::GetFunctionBody(Handle h) {
+ mutex_lock l(mu_);
+ CHECK_LE(0, h);
+ CHECK_LT(h, func_graphs_.size());
+ return func_graphs_[h];
+}
+
+Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
+ OpKernel** kernel) {
+ if (ndef.op() != kGradientOp && (lib_def_->Find(ndef.op()) == nullptr)) {
+ return CreateNonCachedKernel(device_, this, ndef, kernel);
+ }
+
+  // Try to instantiate this function for the func/attr. Maybe it's
+  // cached already.
+ Handle handle;
+ TF_RETURN_IF_ERROR(Instantiate(ndef.op(), ndef.attr(), &handle));
+
+ const FunctionBody* fbody = GetFunctionBody(handle);
+ CHECK_NOTNULL(fbody);
+
+ // Constructs a CallOp kernel for running the instantiated function.
+ Status s;
+ auto device_type = DeviceType(device_->attributes().device_type());
+ OpKernelConstruction construction(
+ device_type, device_, device_->GetAllocator(AllocatorAttributes()), &ndef,
+ &fbody->fdef.signature(), this, fbody->arg_types, fbody->ret_types, &s);
+ *kernel = new CallOp(handle, &construction);
+ if (!s.ok()) {
+    delete *kernel;
+ }
+ return s;
+}
+
+Status FunctionLibraryRuntimeImpl::FunctionDefToBody(
+ const FunctionDef& fdef, const InstantiateAttrValueMap& attrs,
+ FunctionBody** fbody) {
+ // Instantiates the function template into a graph def.
+ InstantiationResult result;
+ TF_RETURN_IF_ERROR(InstantiateFunction(fdef, attrs, get_func_sig_, &result));
+
+ Graph* graph = new Graph(lib_def_);
+ GraphConstructorOptions opts;
+ opts.allow_internal_ops = true;
+ opts.expect_device_spec = false;
+ Status s = ConvertGraphDefToGraph(opts, result.gdef, graph);
+ if (!s.ok()) {
+ delete graph;
+ } else {
+ *fbody = new FunctionBody(fdef, result.arg_types, result.ret_types, graph);
+ }
+ return s;
+}
+
+Status FunctionLibraryRuntimeImpl::InstantiateSymbolicGradient(
+ const InstantiateAttrValueMap& attrs, FunctionBody** g_body) {
+ const AttrValue* f = gtl::FindOrNull(attrs, "f");
+ if (f == nullptr) {
+ return errors::InvalidArgument("SymbolicGradient is missing attr: f");
+ }
+ const auto& func = f->func();
+ const FunctionDef* fdef = lib_def_->Find(func.name());
+ if (fdef == nullptr) {
+    // f is a primitive op.
+ gradient::Creator creator;
+ TF_RETURN_IF_ERROR(gradient::GetOpGradientCreator(func.name(), &creator));
+ if (creator == nullptr) {
+ return errors::InvalidArgument("No gradient is defined for ",
+ func.name());
+ }
+ FunctionDef grad_fdef;
+ TF_RETURN_IF_ERROR(creator(AttrSlice(&func.attr()), &grad_fdef));
+ TF_RETURN_IF_ERROR(FunctionDefToBody(grad_fdef, func.attr(), g_body));
+ } else {
+ // f is a user-defined function.
+ Handle f_handle;
+ TF_RETURN_IF_ERROR(Instantiate(func.name(), func.attr(), &f_handle));
+ const FunctionBody* f_body = GetFunctionBody(f_handle);
+ CHECK_NOTNULL(f_body);
+ *g_body = SymbolicGradient(*f_body);
+ }
+ return Status::OK();
+}
+
+Status FunctionLibraryRuntimeImpl::Instantiate(
+ const string& function_name, const InstantiateAttrValueMap& attrs,
+ Handle* handle) {
+ const string key = Canonicalize(function_name, attrs);
+ {
+ mutex_lock l(mu_);
+ *handle = gtl::FindWithDefault(table_, key, kInvalidHandle);
+ if (*handle != kInvalidHandle) {
+ return Status::OK();
+ }
+ }
+
+ Status s;
+ FunctionBody* fbody = nullptr;
+ if (function_name == kGradientOp) {
+ TF_RETURN_IF_ERROR(InstantiateSymbolicGradient(attrs, &fbody));
+ } else {
+ const FunctionDef* fdef = lib_def_->Find(function_name);
+ if (fdef == nullptr) {
+ return errors::NotFound("Function ", function_name, " is not defined.");
+ }
+ TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, attrs, &fbody));
+ }
+
+ {
+ mutex_lock l(mu_);
+ *handle = gtl::FindWithDefault(table_, key, kInvalidHandle);
+ if (*handle != kInvalidHandle) {
+ delete fbody;
+ } else {
+ *handle = func_graphs_.size();
+ table_.insert({key, *handle});
+ func_graphs_.push_back(fbody);
+ items_.resize(func_graphs_.size());
+ }
+ }
+ return Status::OK();
+}
+
+static void DumpGraph(const char* label, const Graph* g) {
+ if (VLOG_IS_ON(1)) {
+ LOG(INFO) << label << ": " << std::endl << DebugString(g);
+ }
+}
+
+static void SimplifyGraph(Graph* g) {
+ if (RemoveListArrayConverter(g)) {
+ DumpGraph("RemoveListArrayConverter", g);
+ }
+ bool changed;
+ do {
+ changed = false;
+ if (RemoveDeadNodes(g)) {
+ changed = true;
+ DumpGraph("RemoveDeadNodes", g);
+ }
+ if (RemoveIdentityNodes(g)) {
+ changed = true;
+ DumpGraph("RemoveIdentityNodes", g);
+ }
+ FixupSourceAndSinkEdges(g);
+ OptimizeCSE(g, nullptr);
+ DumpGraph("OptimizeCSE", g);
+ } while (changed);
+}
+
+void OptimizeGraph(FunctionLibraryRuntime* lib, Graph** g) {
+ DumpGraph("Initial", *g);
+ const int kNumInlineRounds = 10;
+ for (int i = 0; i < kNumInlineRounds; ++i) {
+ if (!ExpandInlineFunctions(lib, *g)) break;
+ DumpGraph("ExpandInlineFunctions", *g);
+ SimplifyGraph(*g);
+ }
+
+ // Makes a copy so that we densify node ids.
+ Graph* copy = new Graph((*g)->op_registry());
+ CopyGraph(**g, copy);
+ delete *g;
+ *g = copy;
+ DumpGraph("ReCopy", *g);
+}
+
+Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) {
+ const FunctionBody* fbody = GetFunctionBody(handle);
+ CHECK_NOTNULL(fbody);
+ Graph* g = new Graph(lib_def_);
+ CopyGraph(*fbody->graph, g);
+ OptimizeGraph(this, &g);
+
+  // Creates an executor based on the graph 'g'. This must be done without
+ // holding mu_ because create_kernel_ calls back into the library.
+ LocalExecutorParams params;
+ params.device = device_;
+ params.function_library = this;
+ params.has_control_flow = false;
+ params.create_kernel = create_kernel_;
+ params.delete_kernel = [](OpKernel* kernel) {
+ DeleteNonCachedKernel(kernel);
+ };
+ Executor* exec;
+ TF_RETURN_IF_ERROR(NewLocalExecutor(params, g, &exec));
+
+ *item = new Item;
+ (*item)->exec = exec;
+ return Status::OK();
+}
+
+Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) {
+ {
+ mutex_lock l(mu_);
+ if (handle >= items_.size()) {
+ return errors::NotFound("Function handle ", handle,
+ " is not valid. Likely an internal error.");
+ }
+ *item = items_[handle];
+ if (*item != nullptr) {
+ (*item)->Ref();
+ return Status::OK();
+ }
+ }
+ // NOTE: We need to call CreateItem out of mu_ because creating an
+ // executor needs to call CreateKernel.
+ TF_RETURN_IF_ERROR(CreateItem(handle, item));
+
+ {
+ mutex_lock l(mu_);
+ if (items_[handle] == nullptr) {
+ // Install *item in items_.
+ items_[handle] = *item;
+ (*item)->Ref();
+ }
+ }
+ return Status::OK();
+}
+
+void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
+ gtl::ArraySlice<Tensor> args,
+ std::vector<Tensor>* rets,
+ DoneCallback done) {
+ if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
+ return done(errors::Cancelled(""));
+ }
+ const FunctionBody* fbody = GetFunctionBody(handle);
+ FunctionCallFrame* frame =
+ new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
+ Status s = frame->SetArgs(args);
+ if (!s.ok()) {
+ delete frame;
+ return done(s);
+ }
+ Item* item = nullptr;
+ s = GetOrCreateItem(handle, &item);
+ if (!s.ok()) {
+ delete frame;
+ return done(s);
+ }
+ Executor::Args exec_args;
+ exec_args.call_frame = frame;
+ exec_args.cancellation_manager = opts.cancellation_manager;
+ exec_args.runner = runner_;
+ item->exec->RunAsync(
+ // Executor args
+ exec_args,
+ // Done callback.
+ [item, frame, rets, done](const Status& status) {
+ item->Unref();
+ Status s = status;
+ if (s.ok()) {
+ s = frame->GetRetvals(rets);
+ }
+ delete frame;
+ done(s);
+ });
+}
+
+bool FunctionLibraryRuntimeImpl::IsDefined(const string& function_name) {
+ return lib_def_->Find(function_name) != nullptr;
+}
+
+FunctionLibraryRuntime* NewFunctionLibraryRuntime(
+ Device* device, Runner runner, const FunctionLibraryDefinition* lib_def) {
+ return new FunctionLibraryRuntimeImpl(device, runner, lib_def);
+}
+
+bool RemoveDeadNodes(Graph* g) {
+ std::vector<bool> visited(g->num_node_ids(), false);
+ visited[Graph::kSourceId] = true;
+ visited[Graph::kSinkId] = true;
+ std::deque<Node*> q;
+ for (auto n : g->nodes()) {
+ if (n->op_def().is_stateful()) {
+ visited[n->id()] = true;
+ } else if (n->type_string() == kArgOp) {
+ visited[n->id()] = true;
+ } else if (n->type_string() == kRetOp) {
+ visited[n->id()] = true;
+ q.push_back(n);
+ }
+ }
+ while (!q.empty()) {
+ const Node* n = q.front();
+ q.pop_front();
+ visited[n->id()] = true;
+ for (auto e : n->in_edges()) {
+ q.push_back(e->src());
+ }
+ }
+ bool removed_any = false;
+ for (Node* n : g->nodes()) {
+ if (!visited[n->id()]) {
+ g->RemoveNode(n);
+ removed_any = true;
+ }
+ }
+ return removed_any;
+}
+
+namespace {
+// If 'edges' contains only 1 non-control edge, returns it. Otherwise,
+// returns a nullptr.
+const Edge* GetTheOnlyDataEdge(const EdgeSet& edges) {
+ const Edge* ret = nullptr;
+ for (const Edge* e : edges) {
+ if (e->IsControlEdge() || ret) return nullptr;
+ ret = e;
+ }
+ return ret;
+}
+} // end namespace
+
+bool RemoveIdentityNodes(Graph* g) {
+ bool removed_any = false;
+ gtl::InlinedVector<Node*, 8> matches;
+ for (Node* n : g->nodes()) {
+ if ((n->type_string() == "Identity") && GetTheOnlyDataEdge(n->in_edges())) {
+ matches.push_back(n);
+ }
+ }
+ if (!matches.empty()) {
+ for (Node* n : matches) {
+ const Edge* in = GetTheOnlyDataEdge(n->in_edges());
+ for (const Edge* out : n->out_edges()) {
+ if (out->IsControlEdge()) {
+ g->AddControlEdge(in->src(), out->dst());
+ } else {
+ g->AddEdge(in->src(), in->src_output(), out->dst(), out->dst_input());
+ }
+ }
+ g->RemoveNode(n);
+ removed_any = true;
+ }
+ }
+ return removed_any;
+}
+
+bool RemoveListArrayConverter(Graph* g) {
+ gtl::InlinedVector<Node*, 8> matches;
+ for (Node* n : g->nodes()) {
+ if ((n->type_string() == "_ListToArray") ||
+ (n->type_string() == "_ArrayToList")) {
+ matches.push_back(n);
+ }
+ }
+ bool removed_any = false;
+ if (!matches.empty()) {
+ for (Node* n : matches) {
+ if (n->num_inputs() != n->num_outputs()) {
+ continue; // Not expected. Skip.
+ }
+ gtl::InlinedVector<Node*, 8> identity_nodes(n->num_inputs(), nullptr);
+
+ // Process input edges first.
+ Node* input_control_node = nullptr;
+ for (const Edge* e : n->in_edges()) {
+ if (e->IsControlEdge()) {
+ if (input_control_node == nullptr) {
+          // If node "n" has any control dependencies, add a no-op
+          // node (input_control_node) which the added Identity nodes
+          // depend on; input_control_node in turn depends on node
+          // "n"'s control dependencies.
+ input_control_node = AddNoOp(g);
+ }
+ g->AddControlEdge(e->src(), input_control_node);
+ } else {
+ const int index = e->dst_input();
+ Node** id_node = &identity_nodes[index];
+ if (*id_node != nullptr) {
+ LOG(ERROR)
+ << "RemoveListArrayConverter unexpected duplicated input: "
+ << e->dst_input();
+ return removed_any;
+ }
+ *id_node = AddIdentity(g, {e->src(), e->src_output()});
+ }
+ }
+
+ // If node "n" has any control dependencies, the added identity
+ // nodes should have control dependencies on input_control_node.
+ if (input_control_node != nullptr) {
+ for (Node* id : identity_nodes) {
+ g->AddControlEdge(input_control_node, id);
+ }
+ }
+
+ Node* output_control_node = nullptr;
+ for (const Edge* e : n->out_edges()) {
+ if (e->IsControlEdge()) {
+ if (output_control_node == nullptr) {
+          // If node "n" is control-depended upon by other nodes,
+          // add a no-op node (output_control_node) which those
+          // nodes will depend on; output_control_node in turn
+          // depends on all the added Identity nodes.
+ output_control_node = AddNoOp(g);
+ }
+ g->AddControlEdge(output_control_node, e->dst());
+ } else {
+ Node* id_node = identity_nodes[e->src_output()];
+ if (id_node == nullptr) {
+ LOG(ERROR) << "RemoveListArrayConverter unexpected missing input: "
+ << e->src_output();
+ return removed_any;
+ }
+ CHECK(id_node);
+ g->AddEdge(id_node, 0, e->dst(), e->dst_input());
+ }
+ }
+
+ // If any nodes have control dependencies on node "n", those
+ // nodes should have control dependencies on
+ // output_control_node.
+ if (output_control_node != nullptr) {
+ for (Node* id : identity_nodes) {
+ g->AddControlEdge(id, output_control_node);
+ }
+ }
+
+ g->RemoveNode(n);
+ removed_any = true;
+ }
+ }
+ return removed_any;
+}
+
+// Returns true iff the function '*fbody' can be inlined at 'node'
+// based on the type signature of 'node' and 'fbody'.
+static bool ValidateInlining(const Node* node, const FunctionBody* fbody) {
+ if (static_cast<size_t>(node->num_inputs()) != fbody->arg_types.size()) {
+ return false;
+ }
+ if (static_cast<size_t>(node->num_inputs()) != fbody->arg_nodes.size()) {
+ return false;
+ }
+ if (static_cast<size_t>(node->num_outputs()) != fbody->ret_types.size()) {
+ return false;
+ }
+ if (static_cast<size_t>(node->num_outputs()) != fbody->ret_nodes.size()) {
+ return false;
+ }
+ for (int i = 0; i < node->num_inputs(); ++i) {
+ if (node->input_type(i) != fbody->arg_types[i]) return false;
+ }
+ for (int i = 0; i < node->num_outputs(); ++i) {
+ if (node->output_type(i) != fbody->ret_types[i]) return false;
+ }
+ return true;
+}
+
+// Given a node "caller" in "g", which is a call to the function
+// described by "fbody", replaces "caller" with the contents of
+// fbody->graph and connects the edges properly.
+static void InlineFunctionBody(Graph* g, Node* caller,
+ const FunctionBody* fbody) {
+ if (!ValidateInlining(caller, fbody)) {
+ LOG(WARNING) << "Inlining mismatch: " << caller->DebugString() << " vs. "
+ << DebugString(fbody->graph);
+ return;
+ }
+
+ // Duplicate fbody->graph into 'g'. First, we copy the nodes of
+ // fbody->graph into 'g' except the source and sink nodes. We copy
+ // edges among nodes in 'fbody->graph'.
+ //
+ // If 'x' is a node in fbody->graph and its copy in 'g' is 'y', we
+ // remember 'y' in node_map[x->id()].
+ std::vector<Node*> node_map(fbody->graph->num_node_ids());
+ for (Node* n : fbody->graph->nodes()) {
+ if (n->IsSource() || n->IsSink()) continue;
+ CHECK(n->IsOp());
+ node_map[n->id()] = g->CopyNode(n);
+ }
+ for (const Edge* e : fbody->graph->edges()) {
+ if (e->src()->IsSource() || e->src()->IsSink() || e->dst()->IsSource() ||
+ e->dst()->IsSink()) {
+ continue;
+ }
+ Node* src_copy = node_map[e->src()->id()];
+ Node* dst_copy = node_map[e->dst()->id()];
+ g->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
+ }
+
+ // Connect input edges.
+ //
+  // For data edges coming into "caller", we first compute the
+  // <src>:<src_output> for the i-th input and store it in "inputs". We
+  // create one Identity node for each input. Then, we connect inputs[i]
+  // to the i-th identity node added. The nodes that were previously
+  // connected to the output of the i-th arg node are reconnected to the
+  // i-th identity node.
+ //
+ // If "caller" has any input control dependencies, we add a NoOp
+ // node "input_control_node". This "input_control_node" depends on
+ // what "caller" depends on, and the added identity nodes depend on
+ // "input_control_node".
+ std::vector<Endpoint> inputs(caller->num_inputs());
+ Node* input_control_node = nullptr;
+ for (const Edge* e : caller->in_edges()) {
+ if (e->IsControlEdge()) {
+ if (input_control_node == nullptr) {
+ input_control_node = AddNoOp(g);
+ }
+ g->AddControlEdge(e->src(), input_control_node);
+ } else {
+ inputs[e->dst_input()] = {e->src(), e->src_output()};
+ }
+ }
+ for (std::size_t i = 0; i < fbody->arg_nodes.size(); ++i) {
+ Node* arg = node_map[fbody->arg_nodes[i]->id()];
+ Node* n = AddIdentity(g, inputs[i]);
+ if (input_control_node) {
+ g->AddControlEdge(input_control_node, n);
+ }
+ for (const Edge* e : arg->out_edges()) {
+ if (e->IsControlEdge()) {
+ g->AddControlEdge(n, e->dst());
+ } else {
+ g->AddEdge(n, 0, e->dst(), e->dst_input());
+ }
+ }
+ node_map[fbody->arg_nodes[i]->id()] = n;
+ g->RemoveNode(arg); // 'arg' is disconnected.
+ }
+
+ // Connect output edges.
+ //
+  // For the i-th return node in fbody->graph, we add in "g" an identity
+  // node (outputs[i]). We then reconnect every incoming edge of the
+  // i-th return node to the added identity node.
+  //
+  // For every data edge coming out of "caller"'s i-th output, we
+  // reconnect it to the i-th identity node added above.
+  //
+  // If "caller" is control-depended upon by any other nodes, we add a
+  // NoOp node "output_control_node". "output_control_node" depends on
+  // all identity nodes added above, and nodes that previously depended
+  // on "caller" are changed to depend on "output_control_node".
+  std::vector<Node*> outputs(caller->num_outputs());
+ for (std::size_t i = 0; i < fbody->ret_nodes.size(); ++i) {
+ Node* ret = node_map[fbody->ret_nodes[i]->id()];
+ Endpoint data; // Data input for the ret node.
+ for (const Edge* e : ret->in_edges()) {
+ if (!e->IsControlEdge()) {
+ data = {e->src(), e->src_output()};
+ break;
+ }
+ }
+ CHECK(data.node != nullptr);
+ Node* n = AddIdentity(g, data);
+ outputs[i] = n;
+ for (const Edge* e : ret->in_edges()) {
+ if (e->IsControlEdge()) {
+ g->AddControlEdge(e->src(), n);
+ }
+ }
+ g->RemoveNode(ret); // 'ret' is disconnected.
+ }
+ Node* output_control_node = nullptr;
+ for (const Edge* e : caller->out_edges()) {
+ if (e->IsControlEdge()) {
+ if (output_control_node == nullptr) {
+ output_control_node = AddNoOp(g);
+ for (Node* n : outputs) {
+ g->AddControlEdge(n, output_control_node);
+ }
+ }
+ g->AddControlEdge(output_control_node, e->dst());
+ } else {
+ g->AddEdge(outputs[e->src_output()], 0, e->dst(), e->dst_input());
+ }
+ }
+ g->RemoveNode(caller); // 'caller' is replaced with inlined nodes.
+}
+
+bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph) {
+ std::vector<std::pair<Node*, const FunctionBody*>> candidates;
+ for (Node* node : graph->nodes()) {
+ VLOG(3) << "Expanding " << node->DebugString();
+ FunctionLibraryRuntime::Handle handle;
+ Status s =
+ lib->Instantiate(node->type_string(), node->def().attr(), &handle);
+ if (!s.ok()) {
+ // Either "node" is a primitive op, or the instantiation failed.
+ if (errors::IsNotFound(s)) {
+ VLOG(2) << "ExpandInlineFunctions " << s;
+ } else {
+ LOG(ERROR) << "ExpandInlineFunctions " << s;
+ }
+ continue;
+ }
+ const FunctionBody* fbody = lib->GetFunctionBody(handle);
+ CHECK_NOTNULL(fbody);
+ candidates.push_back({node, fbody});
+ }
+ for (const auto& p : candidates) {
+ InlineFunctionBody(graph, p.first, p.second);
+ }
+ return !candidates.empty();
+}
+
+// TODO(zhifengc): Maybe this should be the default Graph::AsGraphDef,
+// and stash the original NodeDef name as an attr for documentation
+// purposes.
+static void ToGraphDef(const Graph* g, GraphDef* gdef) {
+ // We visit nodes in forward topological sort order, which is a
+ // possible execution order of the graph.
+ std::vector<int> pending(g->num_node_ids());
+ std::deque<const Node*> ready;
+ for (const Node* n : g->nodes()) {
+ pending[n->id()] = n->in_edges().size();
+ if (pending[n->id()] == 0) ready.push_back(n);
+ }
+ gtl::InlinedVector<const Edge*, 4> inputs;
+ gdef->Clear();
+ while (!ready.empty()) {
+ const Node* n = ready.front();
+ ready.pop_front();
+ for (const Edge* e : n->out_edges()) {
+ const Node* next = e->dst();
+ if (--pending[next->id()] == 0) {
+ ready.push_back(next);
+ }
+ }
+ if (!n->IsOp()) continue;
+ NodeDef* ndef = gdef->add_node();
+ ndef->set_name(strings::StrCat("n", n->id()));
+ ndef->set_op(n->type_string());
+ *(ndef->mutable_attr()) = n->def().attr();
+ inputs.clear();
+ inputs.resize(n->num_inputs());
+ for (const Edge* e : n->in_edges()) {
+ if (e->IsControlEdge()) {
+ inputs.push_back(e);
+ } else {
+ if (inputs[e->dst_input()] == nullptr) {
+ inputs[e->dst_input()] = e;
+ } else {
+ LOG(WARNING) << "Malformed graph node. multiple input edges: "
+ << n->DebugString();
+ }
+ }
+ }
+    // node->name() is merely NodeDef::name, which is not guaranteed
+    // to be unique and stable after optimization rewrites. Therefore,
+    // we use "n<node id>" instead.
+ for (const Edge* e : inputs) {
+ if (e == nullptr) {
+ ndef->add_input("unknown");
+ } else if (!e->src()->IsOp()) {
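+        // An edge from the source node: it has no corresponding input string.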
+ } else if (e->IsControlEdge()) {
+ ndef->add_input(strings::StrCat("^n", e->src()->id()));
+ } else if (e->src_output() == 0) {
+ ndef->add_input(strings::StrCat("n", e->src()->id()));
+ } else {
+ ndef->add_input(
+ strings::StrCat("n", e->src()->id(), ":", e->src_output()));
+ }
+ }
+ }
+}
+
+string DebugString(const Graph* g) {
+ GraphDef gdef;
+ ToGraphDef(g, &gdef);
+ return DebugString(gdef);
+}
+
+FunctionBody::FunctionBody(const FunctionDef& f, DataTypeSlice arg_t,
+ DataTypeSlice ret_t, Graph* g)
+ : fdef(f),
+ graph(g),
+ arg_types(arg_t.begin(), arg_t.end()),
+ ret_types(ret_t.begin(), ret_t.end()) {
+ this->arg_nodes.resize(arg_types.size());
+ this->ret_nodes.resize(ret_types.size());
+ for (Node* n : this->graph->nodes()) {
+ gtl::InlinedVector<Node*, 4>* node_vec;
+ if (n->type_string() == kRetOp) {
+ node_vec = &this->ret_nodes;
+ } else if (n->type_string() == kArgOp) {
+ node_vec = &this->arg_nodes;
+ } else {
+ continue;
+ }
+ int index;
+ TF_CHECK_OK(GetNodeAttr(n->def(), "index", &index));
+ CHECK_LE(0, index);
+ CHECK_LT(index, node_vec->size());
+ (*node_vec)[index] = n;
+ }
+}
+
+FunctionBody::~FunctionBody() { delete this->graph; }
+
+class SymbolicGradientHelper {
+ public:
+ explicit SymbolicGradientHelper(const FunctionBody& f) : fbody_(&f) {}
+
+ ~SymbolicGradientHelper() { delete gbody_; }
+
+ FunctionBody* Compute();
+
+ private:
+ const FunctionBody* fbody_;
+ FunctionBody* gbody_ = nullptr;
+
+  // A vector of output endpoints which represents backpropagated
+  // gradients.
+ typedef std::vector<Endpoint> BackpropedGradients;
+
+ // backprops_ is a map from an output endpoint to its accumulated
+ // gradients. When an output endpoint has accumulated all its
+ // gradients, we add a node which sums them up.
+ std::unordered_map<Endpoint, BackpropedGradients, EndpointHash, EndpointEq>
+ backprops_;
+
+  // pending_[i] is a count-down counter for the i-th node's expected
+  // backprops. When pending_[i] becomes zero, we have collected all
+  // backprop gradients for all output endpoints of the i-th node.
+ std::vector<int> pending_;
+
+  // 'ready_' keeps track of nodes that have been completely
+  // backpropped. Initially, for every output y of the function f, we
+  // add dy as an input of the gradient function.
+ std::deque<Node*> ready_;
+
+ // Makes a copy of fbody_ in gbody_.
+ void Copy();
+
+ // Initialize pending_ and ready_.
+ void InitBackprop();
+
+  // In the original function body, there is a forward edge from 'src'
+  // to 'dst'. When the backprop algorithm constructs the node
+  // 'dst_grad' which computes the gradient, we need to propagate it
+  // back to 'src'.
+ void BackpropAlongEdge(const Endpoint& dst_grad, const Endpoint& src);
+ void BackpropZerosAlongEdge(const Endpoint& src);
+
+ Endpoint SumGradients(const Endpoint& src);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientHelper);
+};
+
+void SymbolicGradientHelper::Copy() {
+ const Graph& src = *(fbody_->graph);
+ gbody_->graph = new Graph(src.op_registry());
+ Graph* dst = gbody_->graph;
+
+ std::vector<Node*> node_map(src.num_node_ids());
+
+ // Copy the nodes.
+ node_map[src.source_node()->id()] = dst->source_node();
+ node_map[src.sink_node()->id()] = dst->sink_node();
+ for (Node* n : src.nodes()) {
+ if (n->IsSource() || n->IsSink()) continue;
+ CHECK(n->IsOp());
+ node_map[n->id()] = dst->CopyNode(n);
+ }
+
+ // Copy the edges.
+ for (const Edge* e : src.edges()) {
+ Node* src_copy = node_map[e->src()->id()];
+ Node* dst_copy = node_map[e->dst()->id()];
+ dst->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
+ }
+
+ // Save inputs in copied graph.
+ CHECK_EQ(fbody_->arg_types.size(), fbody_->arg_nodes.size());
+ gbody_->arg_types = fbody_->arg_types;
+ for (std::size_t i = 0; i < fbody_->arg_nodes.size(); ++i) {
+ gbody_->arg_nodes.push_back(node_map[fbody_->arg_nodes[i]->id()]);
+ }
+
+ // Save outputs in copied graph.
+ CHECK_EQ(fbody_->ret_types.size(), fbody_->ret_nodes.size());
+ gbody_->ret_types = fbody_->ret_types;
+ for (std::size_t i = 0; i < fbody_->ret_nodes.size(); ++i) {
+ gbody_->ret_nodes.push_back(node_map[fbody_->ret_nodes[i]->id()]);
+ }
+}
+
+void SymbolicGradientHelper::BackpropAlongEdge(const Endpoint& dst_grad,
+ const Endpoint& src) {
+ CHECK_NOTNULL(src.node);
+ auto iter = backprops_.find(src);
+ if (iter != backprops_.end()) {
+ auto* grads = &iter->second;
+ grads->push_back(dst_grad);
+ if (--pending_[src.node->id()] == 0) {
+ ready_.push_back(src.node);
+ }
+ }
+}
+
+void SymbolicGradientHelper::BackpropZerosAlongEdge(const Endpoint& src) {
+ CHECK_NOTNULL(src.node);
+ auto iter = backprops_.find(src);
+ if (iter != backprops_.end()) {
+ if (--pending_[src.node->id()] == 0) {
+ ready_.push_back(src.node);
+ }
+ }
+}
+
+void SymbolicGradientHelper::InitBackprop() {
+ Graph* g = gbody_->graph;
+ pending_.resize(g->num_node_ids(), 0);
+ {
+ backprops_.clear();
+ std::unordered_set<Node*> visited;
+ std::deque<Node*> queue;
+ for (Node* n : gbody_->arg_nodes) {
+ queue.push_back(n);
+ }
+
+    // Walk forward to figure out which endpoints need to be
+    // backprop-ed. A node's endpoints need to be backprop-ed only if
+    // one of the arg nodes can reach the node via data edges.
+ while (!queue.empty()) {
+ Node* n = queue.front();
+ queue.pop_front();
+ visited.insert(n);
+ for (int i = 0; i < n->num_outputs(); ++i) {
+ backprops_[{n, i}].clear();
+ }
+ int num_expected_backprops = 0;
+ for (const Edge* e : n->out_edges()) {
+ if (e->IsControlEdge()) continue;
+ ++num_expected_backprops;
+ if (visited.find(e->dst()) == visited.end()) {
+ queue.push_back(e->dst());
+ }
+ }
+ pending_[n->id()] = num_expected_backprops;
+ }
+ }
+
+ {
+ const int num_y = gbody_->ret_nodes.size();
+ for (int i = 0; i < num_y; ++i) {
+ Node* y = gbody_->ret_nodes[i];
+ DCHECK_EQ(y->type_string(), kRetOp);
+ const DataType dtype = y->input_type(0);
+ const int index = gbody_->arg_nodes.size();
+ Node* dy = AddArg(g, dtype, index);
+ gbody_->arg_types.push_back(dtype);
+ gbody_->arg_nodes.push_back(dy);
+
+ // What's the input to y?
+ Endpoint y_in{nullptr, 0};
+ for (const Edge* e : y->in_edges()) {
+ if (!e->IsControlEdge()) {
+ y_in = {e->src(), e->src_output()};
+ break;
+ }
+ }
+ CHECK_NOTNULL(y_in.node);
+ BackpropAlongEdge({dy, 0}, y_in);
+ }
+ }
+}
+
+Endpoint SymbolicGradientHelper::SumGradients(const Endpoint& src) {
+ Graph* g = gbody_->graph;
+ const DataType dtype = src.dtype();
+ auto iter = backprops_.find(src);
+ CHECK(iter != backprops_.end());
+ const auto& grads = iter->second;
+ if (grads.empty()) {
+    // Nothing propagated back. The best we can come up with is zeros.
+ Node* zero_like = AddZerosLike(g, src);
+ return {zero_like, 0};
+ }
+ if (grads.size() == 1) {
+ // Just one backprop edge.
+ return grads[0];
+ }
+  // Otherwise, adds a node that sums up the backprop-ed gradients.
+ NodeDef ndef;
+ ndef.set_name(g->NewName(kNodeLabel));
+ ndef.set_op("AddN"); // N-way Add
+ for (const Endpoint& ep : grads) {
+ ndef.add_input(ep.name());
+ }
+ AddNodeAttr("N", static_cast<int64>(grads.size()), &ndef);
+ AddNodeAttr("T", dtype, &ndef);
+ Status s;
+ Node* add = gbody_->graph->AddNode(ndef, &s);
+ TF_CHECK_OK(s);
+ for (size_t i = 0; i < grads.size(); ++i) {
+ const Endpoint& ep = grads[i];
+ g->AddEdge(ep.node, ep.index, add, i);
+ }
+ return {add, 0};
+}
+
+static bool IsPrimitiveOpWithNoGrad(const string& func) {
+ gradient::Creator creator;
+ Status s = gradient::GetOpGradientCreator(func, &creator);
+ return s.ok() && (creator == nullptr);
+}
+
+FunctionBody* SymbolicGradientHelper::Compute() {
+ CHECK(gbody_ == nullptr);
+ gbody_ = new FunctionBody;
+
+ // Copy fbody_ into gbody_.
+ Copy();
+
+ // Initialize backprops.
+ InitBackprop();
+
+ // Backward propagation.
+ gtl::InlinedVector<Endpoint, 8> dy;
+ Graph* g = gbody_->graph;
+ while (!ready_.empty()) {
+ // n has collected all gradients.
+ Node* n = ready_.front();
+ ready_.pop_front();
+
+ if (n->type_string() == kArgOp) {
+ // We'll handle the _Arg node after backprop is done.
+ continue;
+ }
+
+ // "n" has num_x inputs and num_y outputs.
+ const int num_x = n->num_inputs();
+ const int num_y = n->num_outputs();
+
+ // dy[i] is the sum of i-th output's backpropped gradients.
+ dy.clear();
+ dy.resize(num_y, {nullptr, 0});
+ for (int i = 0; i < num_y; ++i) {
+ dy[i] = SumGradients({n, i});
+ }
+
+ if (IsPrimitiveOpWithNoGrad(n->type_string())) {
+ // No grad defined for this op. Backprops zeros along the in
+ // edges.
+ for (const Edge* e : n->in_edges()) {
+ if (e->IsControlEdge()) continue;
+ BackpropZerosAlongEdge({e->src(), e->src_output()});
+ }
+ continue;
+ }
+
+ // Adds a gradient node with num_x + num_y inputs and num_x
+ // outputs.
+ Node* grad = AddSymGrad(g, n, dy);
+ for (const Edge* e : n->in_edges()) {
+ if (e->IsControlEdge()) continue;
+ g->AddEdge(e->src(), e->src_output(), grad, e->dst_input());
+ }
+ for (int i = 0; i < num_y; ++i) {
+ g->AddEdge(dy[i].node, dy[i].index, grad, num_x + i);
+ }
+
+ // Backprops along the in edges.
+ for (const Edge* e : n->in_edges()) {
+ if (e->IsControlEdge()) continue;
+ BackpropAlongEdge({grad, e->dst_input()}, {e->src(), e->src_output()});
+ }
+ }
+
+ // The gradient's retval nodes.
+ for (Node* n : gbody_->ret_nodes) {
+ g->RemoveNode(n);
+ }
+ gbody_->ret_types = fbody_->arg_types;
+ gbody_->ret_nodes.clear();
+ for (size_t i = 0; i < fbody_->arg_types.size(); ++i) {
+ Endpoint grad = SumGradients({gbody_->arg_nodes[i], 0});
+ Node* ret = AddRet(g, grad, i);
+ gbody_->ret_nodes.push_back(ret);
+ }
+
+ auto ret = gbody_;
+ gbody_ = nullptr;
+ return ret;
+}
+
+FunctionBody* SymbolicGradient(const FunctionBody& f) {
+ return SymbolicGradientHelper(f).Compute();
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/common_runtime/function.h b/tensorflow/core/common_runtime/function.h
new file mode 100644
index 0000000000..634b31232a
--- /dev/null
+++ b/tensorflow/core/common_runtime/function.h
@@ -0,0 +1,100 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_FUNCTION_H_
+#define TENSORFLOW_COMMON_RUNTIME_FUNCTION_H_
+
+#include <functional>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+
+// Creates a FunctionLibraryRuntime, which instantiates functions
+// defined in "lib_def" and executes functions on the "device".
+//
+// The returned object does not take ownership of "device" or
+// "lib_def". The caller must ensure "device" and "lib_def" outlive
+// the returned object.
+typedef std::function<void()> Closure;
+typedef std::function<void(Closure)> Runner;
+FunctionLibraryRuntime* NewFunctionLibraryRuntime(
+ Device* device, Runner runner, const FunctionLibraryDefinition* lib_def);
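+//
+// A minimal usage sketch (not from the library itself): "device",
+// "lib_def", the function name "MyFunc", and the input tensor "x" are
+// assumed to exist; the inline runner below is only illustrative.
+//
+//   Runner runner = [](Closure c) { c(); };  // runs closures inline
+//   std::unique_ptr<FunctionLibraryRuntime> flr(
+//       NewFunctionLibraryRuntime(device, runner, lib_def));
+//   FunctionLibraryRuntime::Handle handle;
+//   TF_CHECK_OK(flr->Instantiate("MyFunc", {}, &handle));
+//   std::vector<Tensor> rets;
+//   Notification done;
+//   flr->Run(FunctionLibraryRuntime::Options(), handle, {x}, &rets,
+//            [&done](const Status& s) { TF_CHECK_OK(s); done.Notify(); });
+//   done.WaitForNotification();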
+
+// FunctionLibraryRuntime::GetFunctionBody returns a description of an
+// instantiated function that is represented as a Graph with arg/ret
+// nodes annotated.
+struct FunctionBody {
+ FunctionDef fdef;
+ Graph* graph = nullptr; // owned.
+ DataTypeVector arg_types;
+ DataTypeVector ret_types;
+ gtl::InlinedVector<Node*, 4> arg_nodes;
+ gtl::InlinedVector<Node*, 4> ret_nodes;
+
+ FunctionBody() {}
+ FunctionBody(const FunctionDef& f, DataTypeSlice arg_types,
+ DataTypeSlice ret_types, Graph* g);
+ ~FunctionBody();
+};
+
+// Debugging facility. Returns a debug string for a graph
+// representing an instantiated function.
+string DebugString(const Graph* instantiated_func_graph);
+
+// A few hand-crafted optimizations on the instantiated function body
+// (a Graph*).
+
+// Removes nodes that are
+// 1. not stateful; and
+// 2. not _Arg; and
+// 3. not reachable from _Retval.
+// Returns true iff any node is removed from "g".
+bool RemoveDeadNodes(Graph* g);
+
+// Find a pattern:
+// src -(in)-> node -(out)-> dst, where
+// 1) node is an identity node;
+// 2) in is the only incoming data edge;
+// 3) out is the only outgoing data edge;
+//
+// Rewrites the above pattern with src->dst and relevant data
+// dependencies updated. Repeats the process until no such pattern
+// is left.
+bool RemoveIdentityNodes(Graph* g);
+
+// Rewrites _ListToArray and _ArrayToList to a set of Identity nodes.
+bool RemoveListArrayConverter(Graph* g);
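+//
+// For example, a _ListToArray node with three inputs is replaced by
+// three Identity nodes, one per input/output pair, with any incoming or
+// outgoing control dependencies routed through added NoOp nodes.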
+
+// For each node in "graph", if "lib" indicates that the node is a
+// function call, inline the function body. Returns true if at least
+// one node is inlined.
+//
+// This routine goes through "graph" nodes once and applies the
+// inlining. The caller may decide to apply the inlining on "graph"
+// multiple times by calling ExpandInlineFunctions a few times.
+bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph);
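+//
+// A typical driver (a sketch; OptimizeGraph below does essentially this,
+// with a bounded number of rounds to guard against runaway recursive
+// inlining):
+//
+//   while (ExpandInlineFunctions(lib, graph)) {
+//     // Optionally simplify 'graph' between rounds.
+//   }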
+
+// Applies graph rewrite optimizations such as inlining, dead code
+// removal, etc.
+//
+// **g is a graph constructed based on the runtime library 'lib'.
+// OptimizeGraph mutates **g extensively and replaces '*g' with a
+// complete copy. Therefore, the caller should not keep any references
+// to nodes in *g.
+void OptimizeGraph(FunctionLibraryRuntime* lib, Graph** g);
+
+// Given a numerical function "f", returns another numerical function
+// "g", such that if "f" takes N inputs and produces M outputs, "g"
+// takes N + M inputs and produces N outputs. I.e., if
+// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N),
+// g is a function which is
+// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N,
+// dL/dy1, dL/dy2, ..., dL/dy_M),
+// where L is a scalar-valued function of (...x_i...).
+//
+// TODO(zhifengc): Asks math expert to say the comment again.
+FunctionBody* SymbolicGradient(const FunctionBody& f);
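+//
+// For example (an illustrative sketch, not taken from the library): if
+// f(x, y) = x * y, then g(x, y, dz) returns (dz * y, dz * x), i.e. the
+// gradients of L with respect to x and y given dz = dL/d(x * y).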
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_FUNCTION_H_
diff --git a/tensorflow/core/common_runtime/gpu/dma_helper.h b/tensorflow/core/common_runtime/gpu/dma_helper.h
new file mode 100644
index 0000000000..7b0750f405
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/dma_helper.h
@@ -0,0 +1,18 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_DMA_HELPER_H_
+#define TENSORFLOW_COMMON_RUNTIME_GPU_DMA_HELPER_H_
+
+#include "tensorflow/core/public/tensor.h"
+
+// For internal use only. Visibility should be limited to brain/framework.
+
+namespace tensorflow {
+class DMAHelper {
+ public:
+ static bool CanUseDMA(const Tensor* t) { return t->CanUseDMA(); }
+ static const void* base(const Tensor* t) { return t->base<const void>(); }
+ static void* base(Tensor* t) { return t->base<void>(); }
+ static TensorBuffer* buffer(Tensor* t) { return t->buf_; }
+ static const TensorBuffer* buffer(const Tensor* t) { return t->buf_; }
+};
+} // namespace tensorflow
+#endif // TENSORFLOW_COMMON_RUNTIME_GPU_DMA_HELPER_H_
diff --git a/tensorflow/core/common_runtime/gpu/gpu_allocator_retry.cc b/tensorflow/core/common_runtime/gpu/gpu_allocator_retry.cc
new file mode 100644
index 0000000000..742459c63b
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_allocator_retry.cc
@@ -0,0 +1,49 @@
+#include "tensorflow/core/common_runtime/gpu/gpu_allocator_retry.h"
+#include "tensorflow/core/public/env.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+GPUAllocatorRetry::GPUAllocatorRetry() : env_(Env::Default()) {}
+
+void* GPUAllocatorRetry::AllocateRaw(
+ std::function<void*(size_t alignment, size_t num_bytes,
+ bool verbose_failure)> alloc_func,
+ int max_millis_to_wait, size_t alignment, size_t num_bytes) {
+ if (num_bytes == 0) {
+ LOG(WARNING) << "Request to allocate 0 bytes";
+ return nullptr;
+ }
+ uint64 deadline_micros = env_->NowMicros() + max_millis_to_wait * 1000;
+ void* ptr = nullptr;
+ while (ptr == nullptr) {
+ ptr = alloc_func(alignment, num_bytes, false);
+ if (ptr == nullptr) {
+ uint64 now = env_->NowMicros();
+ if (now < deadline_micros) {
+ mutex_lock l(mu_);
+ WaitForMilliseconds(&l, &memory_returned_,
+ (deadline_micros - now) / 1000);
+ } else {
+ return alloc_func(alignment, num_bytes, true);
+ }
+ }
+ }
+ return ptr;
+}
+
+void GPUAllocatorRetry::DeallocateRaw(std::function<void(void*)> dealloc_func,
+ void* ptr) {
+ if (ptr == nullptr) {
+ LOG(ERROR) << "Request to free nullptr";
+ return;
+ }
+ dealloc_func(ptr);
+ {
+ mutex_lock l(mu_);
+ memory_returned_.notify_all();
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_allocator_retry.h b/tensorflow/core/common_runtime/gpu/gpu_allocator_retry.h
new file mode 100644
index 0000000000..a3298ab222
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_allocator_retry.h
@@ -0,0 +1,36 @@
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ALLOCATOR_RETRY_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ALLOCATOR_RETRY_H_
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/env.h"
+
+namespace tensorflow {
+
+// A retrying wrapper for a memory allocator.
+class GPUAllocatorRetry {
+ public:
+ GPUAllocatorRetry();
+
+ // Call 'alloc_func' to obtain memory. On first call,
+ // 'verbose_failure' will be false. If return value is nullptr,
+ // then wait up to 'max_millis_to_wait' milliseconds, retrying each
+ // time a call to DeallocateRaw() is detected, until either a good
+ // pointer is returned or the deadline is exhausted. If the
+  // deadline is exhausted, try one more time with 'verbose_failure'
+ // set to true. The value returned is either the first good pointer
+ // obtained from 'alloc_func' or nullptr.
+ void* AllocateRaw(std::function<void*(size_t alignment, size_t num_bytes,
+ bool verbose_failure)> alloc_func,
+ int max_millis_to_wait, size_t alignment, size_t bytes);
+
+ // Calls dealloc_func(ptr) and then notifies any threads blocked in
+ // AllocateRaw() that would like to retry.
+ void DeallocateRaw(std::function<void(void* ptr)> dealloc_func, void* ptr);
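+
+  // A minimal wrapping sketch (illustrative only; 'underlying',
+  // 'alignment', and 'num_bytes' stand in for whatever allocator and
+  // request are being wrapped and are not part of this class):
+  //
+  //   GPUAllocatorRetry retry;
+  //   void* p = retry.AllocateRaw(
+  //       [&](size_t a, size_t nb, bool verbose) {
+  //         return underlying->AllocateRaw(a, nb);
+  //       },
+  //       10000 /* max_millis_to_wait */, alignment, num_bytes);
+  //   ...
+  //   retry.DeallocateRaw([&](void* q) { underlying->DeallocateRaw(q); }, p);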
+
+ private:
+ Env* env_;
+ mutex mu_;
+ condition_variable memory_returned_;
+};
+} // namespace tensorflow
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ALLOCATOR_RETRY_H_
diff --git a/tensorflow/core/common_runtime/gpu/gpu_allocator_retry_test.cc b/tensorflow/core/common_runtime/gpu/gpu_allocator_retry_test.cc
new file mode 100644
index 0000000000..db1c58cc65
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_allocator_retry_test.cc
@@ -0,0 +1,175 @@
+#include "tensorflow/core/common_runtime/gpu/gpu_allocator_retry.h"
+
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/public/env.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+class FakeAllocator {
+ public:
+ FakeAllocator(size_t cap, int millis_to_wait)
+ : memory_capacity_(cap), millis_to_wait_(millis_to_wait) {}
+
+ // Allocate just keeps track of the number of outstanding allocations,
+ // not their sizes. Assume a constant size for each.
+ void* AllocateRaw(size_t alignment, size_t num_bytes) {
+ return retry_.AllocateRaw(
+ [this](size_t a, size_t nb, bool v) {
+ mutex_lock l(mu_);
+ if (memory_capacity_ > 0) {
+ --memory_capacity_;
+ return good_ptr_;
+ } else {
+ return static_cast<void*>(nullptr);
+ }
+ },
+ millis_to_wait_, alignment, num_bytes);
+ }
+
+ void DeallocateRaw(void* ptr) {
+ retry_.DeallocateRaw(
+ [this](void* p) {
+ mutex_lock l(mu_);
+ ++memory_capacity_;
+ },
+ ptr);
+ }
+
+ private:
+ GPUAllocatorRetry retry_;
+ void* good_ptr_ = reinterpret_cast<void*>(0xdeadbeef);
+ mutex mu_;
+ size_t memory_capacity_ GUARDED_BY(mu_);
+ int millis_to_wait_;
+};
+
+class GPUAllocatorRetryTest : public ::testing::Test {
+ protected:
+ GPUAllocatorRetryTest() {}
+
+ void LaunchConsumerThreads(int num_consumers, int cap_needed) {
+ consumer_count_.resize(num_consumers, 0);
+ for (int i = 0; i < num_consumers; ++i) {
+ consumers_.push_back(Env::Default()->StartThread(
+ ThreadOptions(), "anon_thread", [this, i, cap_needed]() {
+ do {
+ void* ptr = nullptr;
+ for (int j = 0; j < cap_needed; ++j) {
+ ptr = alloc_->AllocateRaw(16, 1);
+ if (ptr == nullptr) {
+ mutex_lock l(mu_);
+ has_failed_ = true;
+ return;
+ }
+ }
+ ++consumer_count_[i];
+ for (int j = 0; j < cap_needed; ++j) {
+ alloc_->DeallocateRaw(ptr);
+ }
+ } while (!notifier_.HasBeenNotified());
+ }));
+ }
+ }
+
+ // Wait up to wait_micros microseconds for has_failed_ to equal expected,
+ // then terminate all threads.
+ void JoinConsumerThreads(bool expected, int wait_micros) {
+ while (wait_micros > 0) {
+ {
+ mutex_lock l(mu_);
+ if (has_failed_ == expected) break;
+ }
+ int interval_micros = std::min(1000, wait_micros);
+ Env::Default()->SleepForMicroseconds(interval_micros);
+ wait_micros -= interval_micros;
+ }
+ notifier_.Notify();
+ for (auto c : consumers_) {
+ // Blocks until thread terminates.
+ delete c;
+ }
+ }
+
+ std::unique_ptr<FakeAllocator> alloc_;
+ std::vector<Thread*> consumers_;
+ std::vector<int> consumer_count_;
+ Notification notifier_;
+ mutex mu_;
+ bool has_failed_ GUARDED_BY(mu_) = false;
+ int count_ GUARDED_BY(mu_) = 0;
+};
+
+// Verifies correct retrying when memory is slightly overcommitted but
+// we allow retry.
+TEST_F(GPUAllocatorRetryTest, RetrySuccess) {
+  // Support up to 2 allocations simultaneously, waits up to 10 seconds
+  // for a chance to alloc.
+ alloc_.reset(new FakeAllocator(2, 10000));
+ // Launch 3 consumers, each of whom needs 1 unit at a time.
+ LaunchConsumerThreads(3, 1);
+ // This should be enough time for each consumer to be satisfied many times.
+ Env::Default()->SleepForMicroseconds(50000);
+ JoinConsumerThreads(false, 0);
+ for (int i = 0; i < 3; ++i) {
+ LOG(INFO) << "Consumer " << i << " is " << consumer_count_[i];
+ }
+ {
+ mutex_lock l(mu_);
+ EXPECT_FALSE(has_failed_);
+ }
+ EXPECT_GT(consumer_count_[0], 0);
+ EXPECT_GT(consumer_count_[1], 0);
+ EXPECT_GT(consumer_count_[2], 0);
+}
+
+// Verifies OutOfMemory failure when memory is slightly overcommitted
+// and retry is not allowed.
+TEST_F(GPUAllocatorRetryTest, NoRetryFail) {
+ // Support up to 2 allocations simultaneously, waits up to 0 msec for
+ // a chance to alloc.
+ alloc_.reset(new FakeAllocator(2, 0));
+ // Launch 3 consumers, each of whom needs 1 unit at a time.
+ LaunchConsumerThreads(3, 1);
+ Env::Default()->SleepForMicroseconds(50000);
+ // Will wait up to 10 seconds for proper race condition to occur, resulting
+ // in failure.
+ JoinConsumerThreads(true, 10000000);
+ for (int i = 0; i < 3; ++i) {
+ LOG(INFO) << "Consumer " << i << " is " << consumer_count_[i];
+ }
+ {
+ mutex_lock l(mu_);
+ EXPECT_TRUE(has_failed_);
+ }
+}
+
+// Verifies OutOfMemory failure when retry is allowed but memory capacity
+// is too low even for retry.
+TEST_F(GPUAllocatorRetryTest, RetryInsufficientFail) {
+  // Support up to 2 allocations simultaneously, waits up to 10 seconds
+  // for a chance to alloc.
+ alloc_.reset(new FakeAllocator(2, 10000));
+ // Launch 3 consumers, each of whom needs 2 units at a time. We expect
+ // deadlock where 2 consumers each hold 1 unit, and timeout trying to
+ // get the second.
+ LaunchConsumerThreads(3, 2);
+ Env::Default()->SleepForMicroseconds(50000);
+ // Will wait up to 10 seconds for proper race condition to occur, resulting
+ // in failure.
+ JoinConsumerThreads(true, 10000000);
+ for (int i = 0; i < 3; ++i) {
+ LOG(INFO) << "Consumer " << i << " is " << consumer_count_[i];
+ }
+ {
+ mutex_lock l(mu_);
+ EXPECT_TRUE(has_failed_);
+ }
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc
new file mode 100644
index 0000000000..3df833594f
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc
@@ -0,0 +1,397 @@
+#include "tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h"
+
+#include "tensorflow/stream_executor/multi_platform_manager.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_allocator_retry.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
+#include "tensorflow/core/lib/core/bits.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace gpu = ::perftools::gputools;
+
+namespace tensorflow {
+
+GPUBFCAllocator::GPUBFCAllocator(int device_id, size_t total_memory)
+ : device_id_(device_id) {
+ // Get a pointer to the stream_executor for this device
+ stream_exec_ = GPUMachineManager()->ExecutorForDevice(device_id).ValueOrDie();
+
+ // Allocate the requested amount of memory.
+ gpu_memory_size_ = total_memory;
+
+ LOG(INFO) << "Allocating " << strings::HumanReadableNumBytes(gpu_memory_size_)
+ << " bytes.";
+ gpu::DeviceMemory<char> gpu_mem =
+ stream_exec_->AllocateArray<char>(gpu_memory_size_);
+
+ QCHECK(gpu_mem != nullptr)
+ << " Could not allocate GPU device memory for device " << device_id
+ << ". Tried to allocate "
+ << strings::HumanReadableNumBytes(gpu_memory_size_);
+ base_ptr_ = gpu_mem.opaque();
+ LOG(INFO) << "GPU " << device_id << " memory begins at " << base_ptr_
+ << " extends to "
+ << static_cast<void*>(
+ (static_cast<char*>(base_ptr_) + gpu_memory_size_));
+
+ // Create a bunch of bins of various good sizes.
+
+ // Covers allocations of exactly 256 bytes (the minimum size).
+ bins_.insert(std::make_pair(256, new Bin(256)));
+
+  // We create bins for all power-of-two sizes, starting from bins
+  // covering allocations of up to 1024 bytes and going up to a bin
+  // covering allocations up to (and including) the memory limit.
+ for (size_t bin_size = 1024; bin_size < gpu_memory_size_ * 2; bin_size *= 2) {
+ LOG(INFO) << "Creating bin of max chunk size "
+ << strings::HumanReadableNumBytes(bin_size);
+ bins_.insert(std::make_pair(bin_size, new Bin(bin_size)));
+ }
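+  // (For example, with a 4GB gpu_memory_size_ the bin sizes are 256,
+  // 1024, 2048, ..., doubling up to the first power of two that is at
+  // least the memory limit.)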
+
+ // Create one large chunk for the whole memory space that will
+ // be chunked later.
+ GPUBFCAllocator::Chunk* c = new GPUBFCAllocator::Chunk();
+ c->ptr = gpu_mem.opaque();
+ c->size = gpu_memory_size_;
+ c->in_use = false;
+ c->prev = nullptr;
+ c->next = nullptr;
+
+ ptr_to_chunk_map_.insert(std::make_pair(c->ptr, c));
+
+ // Insert the chunk into the right bin.
+ ReassignChunkToBin(c);
+}
+
+GPUBFCAllocator::~GPUBFCAllocator() {
+ // Return memory back.
+ if (base_ptr_) {
+ gpu::DeviceMemoryBase gpu_ptr{base_ptr_};
+ stream_exec_->Deallocate(&gpu_ptr);
+ }
+
+ gtl::STLDeleteValues(&bins_);
+}
+
+void* GPUBFCAllocator::AllocateRaw(size_t unused_alignment, size_t num_bytes) {
+ static const int64 kMaxMillisToWait = 10000; // 10 seconds
+ return retry_helper_.AllocateRaw(
+ [this](size_t a, size_t nb, bool v) {
+ return AllocateRawInternal(a, nb, v);
+ },
+ kMaxMillisToWait, unused_alignment, num_bytes);
+}
+
+void* GPUBFCAllocator::AllocateRawInternal(size_t unused_alignment,
+ size_t num_bytes,
+ bool dump_log_on_failure) {
+ if (num_bytes == 0) {
+ LOG(ERROR) << "tried to allocate 0 bytes";
+ return nullptr;
+ }
+ // First, always allocate memory of at least 256 bytes, and always
+ // allocate multiples of 256 bytes so all memory addresses are
+ // nicely byte aligned.
+ size_t rounded_bytes = (256 * ((num_bytes + 255) / 256));
+ DCHECK_EQ(0, rounded_bytes % 256);
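+  // For example, a request for 300 bytes is rounded up to
+  // 256 * ((300 + 255) / 256) = 256 * 2 = 512 bytes.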
+
+ // The BFC allocator tries to find the best fit first.
+ //
+ // First identify the first bin that could satisfy rounded_bytes.
+ auto it = bins_.lower_bound(rounded_bytes);
+ if (it == bins_.end()) {
+ LOG(ERROR) << " Asked for " << rounded_bytes << " but largest bin was "
+ << bins_.rbegin()->first;
+ return nullptr;
+ }
+
+ mutex_lock l(lock_);
+ for (; it != bins_.end(); ++it) {
+ // Start searching from the first bin for the smallest chunk that fits
+ // rounded_bytes.
+ Bin* b = it->second;
+ for (GPUBFCAllocator::Chunk* chunk : b->chunks) {
+ if (!chunk->in_use && chunk->size > rounded_bytes) {
+ // We found an existing chunk that fits us that wasn't in use.
+ chunk->in_use = true;
+
+ // If we can break the size of the chunk into two reasonably
+ // large pieces, do so.
+ //
+ // TODO(vrv): What should be the criteria when deciding when
+ // to split?
+ if (chunk->size >= rounded_bytes * 2) {
+ SplitChunk(chunk, rounded_bytes);
+ }
+
+ // The requested size of the returned chunk is what the user
+ // has allocated.
+ chunk->requested_size = num_bytes;
+
+ VLOG(4) << "Returning: " << chunk->ptr;
+ return chunk->ptr;
+ }
+ }
+ }
+
+  // We searched all bins for an existing free chunk to use and
+  // couldn't find one. This means we must have run out of memory.
+  // Dump the memory log for analysis.
+ if (dump_log_on_failure) {
+ DumpMemoryLog(rounded_bytes);
+ LOG(WARNING) << "Ran out of memory trying to allocate "
+ << strings::HumanReadableNumBytes(num_bytes)
+ << ". See logs for memory state";
+ }
+ return nullptr;
+}
+
+void GPUBFCAllocator::SplitChunk(GPUBFCAllocator::Chunk* c, size_t num_bytes) {
+ // Create a new chunk starting num_bytes after c
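+  // (For example, splitting a free 2048-byte chunk to satisfy a
+  // 512-byte request leaves 'c' as a 512-byte chunk and creates a new
+  // free 1536-byte chunk immediately after it.)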
+ GPUBFCAllocator::Chunk* new_chunk = new GPUBFCAllocator::Chunk();
+ new_chunk->ptr = static_cast<void*>(static_cast<char*>(c->ptr) + num_bytes);
+ VLOG(6) << "Adding to chunk map: " << new_chunk->ptr;
+ ptr_to_chunk_map_.insert(std::make_pair(new_chunk->ptr, new_chunk));
+
+ // Set the new sizes of the chunks.
+ new_chunk->size = c->size - num_bytes;
+ c->size = num_bytes;
+
+ // The new chunk is not in use.
+ new_chunk->in_use = false;
+
+ // Maintain the pointers.
+ // c <-> c_neighbor becomes
+ // c <-> new_chunk <-> c_neighbor
+ GPUBFCAllocator::Chunk* c_neighbor = c->next;
+ new_chunk->prev = c;
+ new_chunk->next = c_neighbor;
+ c->next = new_chunk;
+ if (c_neighbor) {
+ c_neighbor->prev = new_chunk;
+ }
+
+ // Maintain the bins
+ ReassignChunkToBin(new_chunk);
+ ReassignChunkToBin(c);
+}
+
+void GPUBFCAllocator::DeallocateRaw(void* ptr) {
+ retry_helper_.DeallocateRaw([this](void* p) { DeallocateRawInternal(p); },
+ ptr);
+}
+
+void GPUBFCAllocator::DeallocateRawInternal(void* ptr) {
+ if (ptr == nullptr) {
+ LOG(ERROR) << "tried to deallocate nullptr";
+ return;
+ }
+ mutex_lock l(lock_);
+
+ // Find the chunk from the ptr.
+ auto it = ptr_to_chunk_map_.find(ptr);
+ CHECK(it != ptr_to_chunk_map_.end())
+ << "Asked to deallocate a pointer we never allocated: " << ptr;
+
+ GPUBFCAllocator::Chunk* c = it->second;
+ VLOG(6) << "Chunk at " << c->ptr << " no longer in use";
+ // Mark the chunk as no longer in use
+ c->in_use = false;
+
+ // Consider coalescing it.
+ MaybeCoalesce(c);
+}
+
+// Merges c1 and c2 when c1->next is c2 and c2->prev is c1.
+// We merge c2 into c1.
+void GPUBFCAllocator::Merge(GPUBFCAllocator::Chunk* c1,
+ GPUBFCAllocator::Chunk* c2) {
+ // We can only merge chunks that are not in use.
+ DCHECK(!c1->in_use && !c2->in_use);
+
+ // c1's prev doesn't change, still points to the same ptr, and is
+ // still not in use.
+
+ // Fix up neighbor pointers
+ //
+ // c1 <-> c2 <-> c3 should become
+ // c1 <-> c3
+ GPUBFCAllocator::Chunk* c3 = c2->next;
+ c1->next = c3;
+ CHECK(c2->prev == c1);
+ if (c3 != nullptr) {
+ c3->prev = c1;
+ }
+
+ // Set the new size
+ c1->size += c2->size;
+
+ // Delete c2 and cleanup all state
+ RemoveChunkFromBin(c2);
+}
+
+void GPUBFCAllocator::ReassignChunkToBin(GPUBFCAllocator::Chunk* c) {
+ auto it = bins_.lower_bound(c->size);
+ CHECK(it != bins_.end()) << " Tried to reassign to non-existent bin for size "
+ << c->size;
+
+ Bin* new_bin = it->second;
+
+ // If the bin has not changed, do nothing.
+ Bin* old_bin = c->bin;
+ if (old_bin != nullptr && new_bin == old_bin) {
+ return;
+ }
+
+ // The bin has changed. Add the chunk to the new bin and remove
+ // the chunk from the old bin.
+ new_bin->chunks.insert(c);
+ c->bin = new_bin;
+
+ if (old_bin == nullptr) {
+ return;
+ }
+
+ // Remove chunk from old bin
+ for (auto it = old_bin->chunks.begin(); it != old_bin->chunks.end(); ++it) {
+ if (*it == c) {
+ old_bin->chunks.erase(it);
+ return;
+ }
+ }
+ CHECK(false) << "Could not find chunk in old bin";
+}
+
+void GPUBFCAllocator::RemoveChunkFromBin(GPUBFCAllocator::Chunk* c) {
+ Bin* b = c->bin;
+ for (auto it = b->chunks.begin(); it != b->chunks.end(); ++it) {
+ Chunk* other_c = *it;
+ if (other_c->ptr == c->ptr) {
+ b->chunks.erase(it);
+ VLOG(4) << "Removing: " << c->ptr;
+ ptr_to_chunk_map_.erase(c->ptr);
+ delete c;
+ return;
+ }
+ }
+
+ CHECK(false) << "Could not find chunk in bin";
+}
+
+void GPUBFCAllocator::MaybeCoalesce(GPUBFCAllocator::Chunk* c) {
+ // This chunk is no longer in-use, consider coalescing the chunk
+ // with adjacent chunks.
+ Chunk* chunk_to_reassign = nullptr;
+
+ // If the next chunk is free, coalesce the two, if the result would
+ // fit in an existing bin.
+ if (c->next && !c->next->in_use) {
+ VLOG(8) << "Chunk at " << c->next->ptr << " merging with c " << c->ptr;
+
+ chunk_to_reassign = c;
+
+ // Deletes c->next
+ Merge(c, c->next);
+ }
+
+ // If the previous chunk is free, coalesce the two
+ if (c->prev && !c->prev->in_use) {
+ VLOG(8) << "Chunk at " << c->ptr << " merging into c->prev "
+ << c->prev->ptr;
+
+ chunk_to_reassign = c->prev;
+
+ // Deletes c
+ Merge(c->prev, c);
+ }
+
+ // Reassign the final merged chunk into the right bin.
+ if (chunk_to_reassign) {
+ ReassignChunkToBin(chunk_to_reassign);
+ }
+}
+
+void GPUBFCAllocator::AddAllocVisitor(Visitor visitor) {
+ VLOG(1) << "AddVisitor";
+ mutex_lock l(lock_);
+ region_visitors_.push_back(visitor);
+ visitor(base_ptr_, gpu_memory_size_);
+}
+
+bool GPUBFCAllocator::TracksAllocationSizes() { return true; }
+
+size_t GPUBFCAllocator::RequestedSize(void* ptr) {
+ mutex_lock l(lock_);
+ auto it = ptr_to_chunk_map_.find(ptr);
+ CHECK(it != ptr_to_chunk_map_.end())
+ << "Asked for requested size of pointer we never allocated: " << ptr;
+ GPUBFCAllocator::Chunk* c = it->second;
+ return c->requested_size;
+}
+
+size_t GPUBFCAllocator::AllocatedSize(void* ptr) {
+ mutex_lock l(lock_);
+ auto it = ptr_to_chunk_map_.find(ptr);
+ CHECK(it != ptr_to_chunk_map_.end())
+ << "Asked for allocated size of pointer we never allocated: " << ptr;
+ GPUBFCAllocator::Chunk* c = it->second;
+ return c->size;
+}
+
+void GPUBFCAllocator::DumpMemoryLog(size_t num_bytes) {
+ // For each bin: tally up the total number of chunks and bytes.
+ for (auto bit : bins_) {
+ Bin* b = bit.second;
+
+ size_t total_bytes_in_use = 0;
+ size_t total_bytes_in_bin = 0;
+ size_t total_requested_bytes_in_use = 0;
+ size_t total_requested_bytes_in_bin = 0;
+ size_t total_chunks_in_use = 0;
+ size_t total_chunks_in_bin = 0;
+ for (Chunk* c : b->chunks) {
+ total_bytes_in_bin += c->size;
+ total_requested_bytes_in_bin += c->requested_size;
+ ++total_chunks_in_bin;
+ if (c->in_use) {
+ total_bytes_in_use += c->size;
+ total_requested_bytes_in_use += c->requested_size;
+ ++total_chunks_in_use;
+ }
+ }
+
+ LOG(INFO) << "Bin (" << b->bin_size
+ << "): \tTotal Chunks: " << total_chunks_in_bin
+ << ", Chunks in use: " << total_chunks_in_use << " "
+ << strings::HumanReadableNumBytes(total_bytes_in_bin)
+ << " allocated for chunks. "
+ << strings::HumanReadableNumBytes(total_requested_bytes_in_bin)
+ << " client-requested for chunks. "
+ << strings::HumanReadableNumBytes(total_bytes_in_use)
+ << " in use in bin. "
+ << strings::HumanReadableNumBytes(total_requested_bytes_in_use)
+ << " client-requested in use in bin.";
+ }
+
+ // Find the bin that we would have liked to allocate in, so we
+ // can get some further analysis about fragmentation.
+ auto it = bins_.lower_bound(num_bytes);
+ if (it != bins_.end()) {
+ Bin* b = it->second;
+
+ LOG(INFO) << "Bin for " << strings::HumanReadableNumBytes(num_bytes)
+ << " was " << strings::HumanReadableNumBytes(b->bin_size)
+ << ", Chunk State: ";
+
+ for (Chunk* c : b->chunks) {
+ LOG(INFO) << c->DebugString(true);
+ }
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
new file mode 100644
index 0000000000..3d1601e132
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
@@ -0,0 +1,156 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_BFC_ALLOCATOR_H_
+#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_BFC_ALLOCATOR_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/stream_executor/stream_executor.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_allocator_retry.h"
+#include "tensorflow/core/common_runtime/gpu/visitable_allocator.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+
+namespace tensorflow {
+
+// A GPU memory allocator that implements a 'best-fit with coalescing'
+// algorithm. This is essentially a very simple version of Doug Lea's
+// malloc (dlmalloc).
+//
+// The goal of this allocator is to support defragmentation via
+// coalescing. One assumption we make is that the process using this
+// allocator owns pretty much all of the GPU memory, and that nearly
+// all requests to allocate GPU memory go through this interface.
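+//
+// A minimal usage sketch (the sizes and alignment are illustrative; in
+// practice the allocator is constructed by the GPU device setup code
+// rather than directly by clients):
+//
+//   GPUBFCAllocator a(/*device_id=*/0, /*total_memory=*/1 << 30);
+//   void* p = a.AllocateRaw(/*alignment=*/32, /*num_bytes=*/1024);
+//   ...
+//   a.DeallocateRaw(p);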
+class GPUBFCAllocator : public VisitableAllocator {
+ public:
+ // 'device_id' refers to the StreamExecutor ID of the device within
+ // the process and must reference a valid ID in the process.
+ explicit GPUBFCAllocator(int device_id, size_t total_memory);
+ ~GPUBFCAllocator() override;
+
+ string Name() override { return "gpu_bfc"; }
+ void* AllocateRaw(size_t alignment, size_t num_bytes) override;
+ void DeallocateRaw(void* ptr) override;
+
+ void AddAllocVisitor(Visitor visitor) override;
+
+  // Does nothing, because GPU memory is never freed.
+ void AddFreeVisitor(Visitor visitor) override {}
+
+ bool TracksAllocationSizes() override;
+
+ size_t RequestedSize(void* ptr) override;
+
+ size_t AllocatedSize(void* ptr) override;
+
+ private:
+ struct Bin;
+
+ void* AllocateRawInternal(size_t alignment, size_t num_bytes,
+ bool dump_log_on_failure);
+ void DeallocateRawInternal(void* ptr);
+
+ // Chunks point to GPU memory. Their prev/next pointers form a
+ // doubly-linked list of addresses sorted by GPU base address that
+ // must be contiguous. Chunks contain information about whether
+ // they are in use or whether they are free, and contain a pointer
+ // to the bin they are in.
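+  //
+  // For example (sizes are hypothetical), a region split into three
+  // chunks might look like:
+  //
+  //   [c1: 256KiB, in use][c2: 256KiB, free][c3: 512KiB, free]
+  //
+  // with c1->next == c2, c2->prev == c1, c2->next == c3, and
+  // c2->ptr == static_cast<char*>(c1->ptr) + c1->size.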
+ struct Chunk {
+ size_t size = 0; // Full size of GPU buffer.
+
+ // We sometimes give chunks that are larger than needed to reduce
+ // fragmentation. requested_size keeps track of what the client
+ // actually wanted so we can understand whether our splitting
+ // strategy is efficient.
+ size_t requested_size = 0;
+
+ bool in_use = false;
+ void* ptr = nullptr; // pointer to granted GPU subbuffer.
+
+    // If not null, the memory referred to by 'prev' directly
+    // precedes the memory used by this chunk, i.e. it should start
+    // at 'ptr - prev->size'.
+ Chunk* prev = nullptr;
+
+    // If not null, the memory referred to by 'next' directly
+    // follows the memory used by this chunk, i.e. it should be at
+    // 'ptr + size'.
+ Chunk* next = nullptr;
+
+ // What bin are we in?
+ Bin* bin = nullptr;
+
+ string DebugString(bool recurse) {
+ string dbg;
+ strings::StrAppend(&dbg, " Size: ", strings::HumanReadableNumBytes(size),
+ " | Requested Size: ",
+ strings::HumanReadableNumBytes(requested_size),
+ " | in_use: ", in_use);
+ if (recurse && prev) {
+ strings::StrAppend(&dbg, ", prev: ", prev->DebugString(false));
+ }
+ if (recurse && next) {
+ strings::StrAppend(&dbg, ", next: ", next->DebugString(false));
+ }
+ return dbg;
+ }
+ };
+
+ Chunk* AllocateNewChunk(size_t num_bytes);
+ void SplitChunk(Chunk* c, size_t num_bytes);
+ void Merge(Chunk* c1, Chunk* c2);
+ void MaybeCoalesce(Chunk* c);
+
+ void ReassignChunkToBin(Chunk* c);
+ void RemoveChunkFromBin(Chunk* c);
+
+ void DumpMemoryLog(size_t num_bytes);
+
+ // A Bin is a collection of similar-sized Chunks.
+ struct Bin {
+ // All chunks in this bin have >= bin_size memory.
+ size_t bin_size = 0;
+
+ struct ChunkComparator {
+ bool operator()(Chunk* a, Chunk* b) { return a->size < b->size; }
+ };
+
+ // List of chunks within the bin, sorted by chunk size.
+ std::multiset<Chunk*, ChunkComparator> chunks;
+
+ explicit Bin(size_t bs) : bin_size(bs) {}
+
+ ~Bin() { gtl::STLDeleteElements(&chunks); }
+ };
+
+ GPUAllocatorRetry retry_helper_;
+
+ // Structures immutable after construction
+ const int device_id_;
+ // The base pointer where all the GPU memory begins.
+ void* base_ptr_ = nullptr;
+ size_t gpu_memory_size_ = 0;
+
+ // Map from bin size to Bin
+ // After construction, the bin map is never resized.
+ std::map<size_t, Bin*> bins_;
+
+ perftools::gputools::StreamExecutor* stream_exec_; // Not owned.
+
+ // Structures mutable after construction
+ mutable mutex lock_;
+  // Chunk pointers are not owned by this map.
+ std::unordered_map<void*, Chunk*> ptr_to_chunk_map_;
+
+ // Called once on each region, ASAP.
+ std::vector<Visitor> region_visitors_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(GPUBFCAllocator);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_BFC_ALLOCATOR_H_
diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc
new file mode 100644
index 0000000000..7b5e8aec1d
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc
@@ -0,0 +1,166 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h"
+
+#include <algorithm>
+#include <vector>
+
+#include "tensorflow/stream_executor/stream_executor.h"
+#include <gtest/gtest.h>
+#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace gpu = ::perftools::gputools;
+
+namespace tensorflow {
+namespace {
+
+TEST(GPUBFCAllocatorTest, NoDups) {
+ GPUBFCAllocator a(0, 1 << 30);
+ // Allocate a lot of raw pointers
+ std::vector<void*> ptrs;
+ for (int s = 1; s < 1024; s++) {
+ void* raw = a.AllocateRaw(1, s);
+ ptrs.push_back(raw);
+ }
+
+ std::sort(ptrs.begin(), ptrs.end());
+
+ // Make sure none of them are equal, and that none of them overlap.
+ for (int i = 0; i < ptrs.size(); i++) {
+ if (i > 0) {
+ ASSERT_NE(ptrs[i], ptrs[i - 1]); // No dups
+ size_t req_size = a.RequestedSize(ptrs[i - 1]);
+ ASSERT_GT(req_size, 0);
+ ASSERT_GE(static_cast<char*>(ptrs[i]) - static_cast<char*>(ptrs[i - 1]),
+ req_size);
+ }
+ }
+
+ for (int i = 0; i < ptrs.size(); i++) {
+ a.DeallocateRaw(ptrs[i]);
+ }
+}
+
+TEST(GPUBFCAllocatorTest, AllocationsAndDeallocations) {
+ GPUBFCAllocator a(0, 1 << 30);
+  // Allocate 255 raw pointers of sizes between 100 bytes and about
+  // 1MiB.
+ random::PhiloxRandom philox(123, 17);
+ random::SimplePhilox rand(&philox);
+
+ std::vector<void*> initial_ptrs;
+ for (int s = 1; s < 256; s++) {
+ size_t size = std::min<size_t>(
+ std::max<size_t>(rand.Rand32() % 1048576, 100), 1048576);
+ void* raw = a.AllocateRaw(1, size);
+
+ initial_ptrs.push_back(raw);
+ }
+
+ // Deallocate half of the memory, and keep track of the others.
+ std::vector<void*> existing_ptrs;
+ for (int i = 0; i < initial_ptrs.size(); i++) {
+ if (i % 2 == 1) {
+ a.DeallocateRaw(initial_ptrs[i]);
+ } else {
+ existing_ptrs.push_back(initial_ptrs[i]);
+ }
+ }
+
+ // Allocate a lot of raw pointers
+ for (int s = 1; s < 256; s++) {
+ size_t size = std::min<size_t>(
+ std::max<size_t>(rand.Rand32() % 1048576, 100), 1048576);
+ void* raw = a.AllocateRaw(1, size);
+ existing_ptrs.push_back(raw);
+ }
+
+ std::sort(existing_ptrs.begin(), existing_ptrs.end());
+ // Make sure none of them are equal
+ for (int i = 0; i < existing_ptrs.size(); i++) {
+ if (i > 0) {
+ CHECK_NE(existing_ptrs[i], existing_ptrs[i - 1]); // No dups
+
+ size_t req_size = a.RequestedSize(existing_ptrs[i - 1]);
+ ASSERT_GT(req_size, 0);
+
+ // Check that they don't overlap.
+ ASSERT_GE(static_cast<char*>(existing_ptrs[i]) -
+ static_cast<char*>(existing_ptrs[i - 1]),
+ req_size);
+ }
+ }
+
+ for (int i = 0; i < existing_ptrs.size(); i++) {
+ a.DeallocateRaw(existing_ptrs[i]);
+ }
+}
+
+TEST(GPUBFCAllocatorTest, ExerciseCoalescing) {
+ GPUBFCAllocator a(0, 1 << 30);
+
+ float* first_ptr = a.Allocate<float>(1024);
+ a.Deallocate(first_ptr);
+ for (int i = 0; i < 1024; ++i) {
+ // Allocate several buffers of different sizes, and then clean them
+ // all up. We should be able to repeat this endlessly without
+ // causing fragmentation and growth.
+ float* t1 = a.Allocate<float>(1024);
+
+ int64* t2 = a.Allocate<int64>(1048576);
+ double* t3 = a.Allocate<double>(2048);
+ float* t4 = a.Allocate<float>(10485760);
+
+ a.Deallocate(t1);
+ a.Deallocate(t2);
+ a.Deallocate(t3);
+ a.Deallocate(t4);
+ }
+
+ // At the end, we should have coalesced all memory into one region
+ // starting at the beginning, so validate that allocating a pointer
+ // starts from this region.
+ float* first_ptr_after = a.Allocate<float>(1024);
+ EXPECT_EQ(first_ptr, first_ptr_after);
+ a.Deallocate(first_ptr_after);
+}
+
+TEST(GPUBFCAllocatorTest, AllocateZeroBufSize) {
+ GPUBFCAllocator a(0, 1 << 30);
+ float* ptr = a.Allocate<float>(0);
+ EXPECT_EQ(nullptr, ptr);
+}
+
+TEST(GPUBFCAllocatorTest, TracksSizes) {
+ GPUBFCAllocator a(0, 1 << 30);
+ EXPECT_EQ(true, a.TracksAllocationSizes());
+}
+
+TEST(GPUBFCAllocatorTest, AllocatedVsRequested) {
+ GPUBFCAllocator a(0, 1 << 30);
+ float* t1 = a.Allocate<float>(1);
+ EXPECT_EQ(4, a.RequestedSize(t1));
+ EXPECT_EQ(256, a.AllocatedSize(t1));
+ a.Deallocate(t1);
+}
+
+TEST(GPUBFCAllocatorTest, TestCustomMemoryLimit) {
+  // Configure a 1MiB memory limit.
+ GPUBFCAllocator a(0, 1 << 20);
+
+ float* first_ptr = a.Allocate<float>(1 << 6);
+ float* second_ptr = a.Allocate<float>(1 << 20);
+
+ EXPECT_NE(nullptr, first_ptr);
+ EXPECT_EQ(nullptr, second_ptr);
+ a.Deallocate(first_ptr);
+}
+
+} // namespace
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc
new file mode 100644
index 0000000000..5ec405cd80
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc
@@ -0,0 +1,186 @@
+#include "tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h"
+
+#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
+#include "tensorflow/stream_executor/multi_platform_manager.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+
+namespace gpu = ::perftools::gputools;
+
+namespace tensorflow {
+
+#define MASK_WORDS 2
+#define MASK_BYTES (MASK_WORDS * sizeof(int64))
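+
+// With the masks above, each allocation handed out by GPUDebugAllocator is
+// laid out roughly as follows (16-byte masks, assuming 8-byte int64):
+//
+//   [ before_mask | bytes returned to the client | after_mask ]
+//
+// The pointer handed back by AllocateRaw() points just past before_mask.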
+
+namespace {
+
+static int64* NewMask(int64 word) {
+ int64* m = new int64[MASK_WORDS];
+ for (int i = 0; i < MASK_WORDS; ++i) {
+ m[i] = word;
+ }
+ return m;
+}
+
+static int64* before_mask = NewMask(0xabababababababab);
+static int64* after_mask = NewMask(0xcdcdcdcdcdcdcdcd);
+
+bool CheckMask(perftools::gputools::StreamExecutor* exec, void* ptr,
+ int64* mask) {
+ gpu::DeviceMemory<int64> gpu_ptr{gpu::DeviceMemoryBase{ptr, MASK_BYTES}};
+ int64 tmp[MASK_WORDS];
+
+ if (!exec->SynchronousMemcpy(&tmp, gpu_ptr, MASK_BYTES)) {
+ LOG(FATAL) << "Could not copy debug mask";
+ }
+
+ bool ok = true;
+ for (int i = 0; i < MASK_WORDS; ++i) {
+ ok &= (mask[i] == tmp[i]);
+ if (!ok) {
+ LOG(ERROR) << "i=" << i
+ << " mask=" << reinterpret_cast<const void*>(mask[i])
+ << " field=" << reinterpret_cast<const void*>(tmp[i]);
+ }
+ }
+
+ return ok;
+}
+
+void InitMask(perftools::gputools::StreamExecutor* exec, void* ptr,
+ int64* mask) {
+ gpu::DeviceMemory<int64> gpu_ptr{gpu::DeviceMemoryBase{ptr, MASK_BYTES}};
+ if (!exec->SynchronousMemcpy(&gpu_ptr, mask, MASK_BYTES)) {
+ LOG(FATAL) << "Could not copy debug mask";
+ }
+}
+
+} // namespace
+
+// -----------------------------------------------------------------------------
+// GPUDebugAllocator
+// -----------------------------------------------------------------------------
+GPUDebugAllocator::GPUDebugAllocator(VisitableAllocator* allocator,
+ int device_id)
+ : base_allocator_(allocator) {
+ stream_exec_ = GPUMachineManager()->ExecutorForDevice(device_id).ValueOrDie();
+}
+
+GPUDebugAllocator::~GPUDebugAllocator() { delete base_allocator_; }
+
+void* GPUDebugAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
+ num_bytes += (2 * MASK_BYTES);
+
+ void* allocated_ptr = base_allocator_->AllocateRaw(alignment, num_bytes);
+
+ // Return the pointer after the header
+ void* rv = static_cast<char*>(allocated_ptr) + MASK_BYTES;
+
+ // Write the header at allocated_ptr
+ InitMask(stream_exec_, allocated_ptr, before_mask);
+
+ // Write the footer at the end.
+ size_t req_size = base_allocator_->RequestedSize(allocated_ptr);
+ InitMask(stream_exec_,
+ static_cast<char*>(allocated_ptr) + req_size - MASK_BYTES,
+ after_mask);
+ return rv;
+}
+void GPUDebugAllocator::DeallocateRaw(void* ptr) {
+ CHECK(CheckHeader(ptr)) << "before_mask has been overwritten";
+ CHECK(CheckFooter(ptr)) << "after_mask has been overwritten";
+
+ // Backtrack to the beginning of the header.
+ ptr = static_cast<void*>(static_cast<char*>(ptr) - MASK_BYTES);
+ // Deallocate the memory
+ base_allocator_->DeallocateRaw(ptr);
+}
+
+void GPUDebugAllocator::AddAllocVisitor(Visitor visitor) {
+ return base_allocator_->AddAllocVisitor(visitor);
+}
+
+void GPUDebugAllocator::AddFreeVisitor(Visitor visitor) {
+ return base_allocator_->AddFreeVisitor(visitor);
+}
+
+bool GPUDebugAllocator::TracksAllocationSizes() { return true; }
+
+size_t GPUDebugAllocator::RequestedSize(void* ptr) {
+ auto req_size =
+ base_allocator_->RequestedSize(static_cast<char*>(ptr) - MASK_BYTES);
+ return req_size - 2 * MASK_BYTES;
+}
+
+size_t GPUDebugAllocator::AllocatedSize(void* ptr) {
+ return base_allocator_->AllocatedSize(static_cast<char*>(ptr) - MASK_BYTES);
+}
+
+bool GPUDebugAllocator::CheckHeader(void* ptr) {
+ return CheckMask(stream_exec_, static_cast<char*>(ptr) - MASK_BYTES,
+ before_mask);
+}
+
+bool GPUDebugAllocator::CheckFooter(void* ptr) {
+ char* original_ptr = static_cast<char*>(ptr) - MASK_BYTES;
+ size_t req_size = base_allocator_->RequestedSize(original_ptr);
+ return CheckMask(stream_exec_, original_ptr + req_size - MASK_BYTES,
+ after_mask);
+}
+
+// -----------------------------------------------------------------------------
+// GPUNanResetAllocator
+// -----------------------------------------------------------------------------
+GPUNanResetAllocator::GPUNanResetAllocator(VisitableAllocator* allocator,
+ int device_id)
+ : base_allocator_(allocator) {
+ stream_exec_ = GPUMachineManager()->ExecutorForDevice(device_id).ValueOrDie();
+}
+
+GPUNanResetAllocator::~GPUNanResetAllocator() { delete base_allocator_; }
+
+void* GPUNanResetAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
+ void* allocated_ptr = base_allocator_->AllocateRaw(alignment, num_bytes);
+
+  // Initialize the buffer to NaNs.
+ size_t req_size = base_allocator_->RequestedSize(allocated_ptr);
+ std::vector<float> nans(req_size / sizeof(float), std::nanf(""));
+ gpu::DeviceMemory<float> nan_ptr{
+ gpu::DeviceMemoryBase{static_cast<float*>(allocated_ptr), req_size}};
+
+ if (!stream_exec_->SynchronousMemcpy(&nan_ptr, &nans[0], req_size)) {
+ LOG(ERROR) << "Could not initialize to NaNs";
+ }
+
+ return allocated_ptr;
+}
+void GPUNanResetAllocator::DeallocateRaw(void* ptr) {
+  // Reset the buffer to NaNs.
+ size_t req_size = base_allocator_->RequestedSize(ptr);
+ std::vector<float> nans(req_size / sizeof(float), std::nanf(""));
+ gpu::DeviceMemory<float> nan_ptr{
+ gpu::DeviceMemoryBase{static_cast<float*>(ptr), req_size}};
+ if (!stream_exec_->SynchronousMemcpy(&nan_ptr, &nans[0], req_size)) {
+ LOG(ERROR) << "Could not initialize to NaNs";
+ }
+
+ // Deallocate the memory
+ base_allocator_->DeallocateRaw(ptr);
+}
+
+void GPUNanResetAllocator::AddAllocVisitor(Visitor visitor) {
+ return base_allocator_->AddAllocVisitor(visitor);
+}
+
+void GPUNanResetAllocator::AddFreeVisitor(Visitor visitor) {
+ return base_allocator_->AddFreeVisitor(visitor);
+}
+
+size_t GPUNanResetAllocator::RequestedSize(void* ptr) {
+ return base_allocator_->RequestedSize(ptr);
+}
+
+size_t GPUNanResetAllocator::AllocatedSize(void* ptr) {
+ return base_allocator_->AllocatedSize(ptr);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
new file mode 100644
index 0000000000..c9b564ffc4
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
@@ -0,0 +1,68 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_DEBUG_ALLOCATOR_H_
+#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_DEBUG_ALLOCATOR_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/common_runtime/gpu/visitable_allocator.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+
+namespace tensorflow {
+
+// An allocator that wraps a GPU allocator and adds debugging
+// functionality that verifies that users do not write outside their
+// allocated memory.
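+//
+// Typical construction, as exercised by the tests (the wrapped allocator
+// and device id below are illustrative):
+//
+//   GPUDebugAllocator a(new GPUBFCAllocator(/*device_id=*/0, 1 << 30),
+//                       /*device_id=*/0);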
+class GPUDebugAllocator : public VisitableAllocator {
+ public:
+ explicit GPUDebugAllocator(VisitableAllocator* allocator, int device_id);
+ ~GPUDebugAllocator() override;
+ string Name() override { return "gpu_debug"; }
+ void* AllocateRaw(size_t alignment, size_t num_bytes) override;
+ void DeallocateRaw(void* ptr) override;
+ void AddAllocVisitor(Visitor visitor) override;
+ void AddFreeVisitor(Visitor visitor) override;
+ bool TracksAllocationSizes() override;
+ size_t RequestedSize(void* ptr) override;
+ size_t AllocatedSize(void* ptr) override;
+
+ // For testing.
+ bool CheckHeader(void* ptr);
+ bool CheckFooter(void* ptr);
+
+ private:
+ VisitableAllocator* base_allocator_ = nullptr; // owned
+
+ perftools::gputools::StreamExecutor* stream_exec_; // Not owned.
+
+ TF_DISALLOW_COPY_AND_ASSIGN(GPUDebugAllocator);
+};
+
+// An allocator that wraps a GPU allocator and fills the memory with
+// NaNs on allocation and on deallocation, to help identify cases where
+// the user forgets to initialize the memory.
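+//
+// When combined with GPUDebugAllocator, this allocator must be the
+// outermost wrapper (see the tests), e.g.:
+//
+//   GPUNanResetAllocator a(
+//       new GPUDebugAllocator(new GPUBFCAllocator(/*device_id=*/0, 1 << 30),
+//                             /*device_id=*/0),
+//       /*device_id=*/0);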
+class GPUNanResetAllocator : public VisitableAllocator {
+ public:
+ explicit GPUNanResetAllocator(VisitableAllocator* allocator, int device_id);
+ ~GPUNanResetAllocator() override;
+ string Name() override { return "gpu_nan_reset"; }
+ void* AllocateRaw(size_t alignment, size_t num_bytes) override;
+ void DeallocateRaw(void* ptr) override;
+ void AddAllocVisitor(Visitor visitor) override;
+ void AddFreeVisitor(Visitor visitor) override;
+ size_t RequestedSize(void* ptr) override;
+ size_t AllocatedSize(void* ptr) override;
+
+ private:
+ VisitableAllocator* base_allocator_ = nullptr; // owned
+
+ perftools::gputools::StreamExecutor* stream_exec_; // Not owned.
+
+ TF_DISALLOW_COPY_AND_ASSIGN(GPUNanResetAllocator);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_DEBUG_ALLOCATOR_H_
diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc
new file mode 100644
index 0000000000..5f63906576
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc
@@ -0,0 +1,207 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h"
+
+#include <algorithm>
+#include <vector>
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h"
+#include "tensorflow/stream_executor/multi_platform_manager.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+#include <gtest/gtest.h>
+
+namespace gpu = ::perftools::gputools;
+
+namespace tensorflow {
+
+TEST(GPUDebugAllocatorTest, OverwriteDetection_None) {
+ const int device_id = 0;
+ GPUDebugAllocator a(new GPUBFCAllocator(device_id, 1 << 30), device_id);
+ auto stream_exec =
+ GPUMachineManager()->ExecutorForDevice(device_id).ValueOrDie();
+
+ for (int s : {8}) {
+ std::vector<int64> cpu_array(s);
+ memset(&cpu_array[0], 0, cpu_array.size() * sizeof(int64));
+ int64* gpu_array = a.Allocate<int64>(cpu_array.size());
+ gpu::DeviceMemory<int64> gpu_array_ptr{gpu::DeviceMemoryBase{gpu_array}};
+ ASSERT_TRUE(stream_exec->SynchronousMemcpy(&gpu_array_ptr, &cpu_array[0],
+ s * sizeof(int64)));
+ EXPECT_TRUE(a.CheckHeader(gpu_array));
+ EXPECT_TRUE(a.CheckFooter(gpu_array));
+
+ // Confirm no error on free.
+ a.DeallocateRaw(gpu_array);
+ }
+}
+
+TEST(GPUDebugAllocatorTest, OverwriteDetection_Header) {
+ for (int s : {8, 211}) {
+ EXPECT_DEATH(
+ {
+ const int device_id = 0;
+ GPUDebugAllocator a(new GPUBFCAllocator(device_id, 1 << 30),
+ device_id);
+ auto stream_exec =
+ GPUMachineManager()->ExecutorForDevice(device_id).ValueOrDie();
+
+ std::vector<int64> cpu_array(s);
+ memset(&cpu_array[0], 0, cpu_array.size() * sizeof(int64));
+ int64* gpu_array = a.Allocate<int64>(cpu_array.size());
+
+ gpu::DeviceMemory<int64> gpu_array_ptr{
+ gpu::DeviceMemoryBase{gpu_array}};
+ ASSERT_TRUE(stream_exec->SynchronousMemcpy(
+ &gpu_array_ptr, &cpu_array[0], cpu_array.size() * sizeof(int64)));
+
+ gpu::DeviceMemory<int64> gpu_hdr_ptr{
+ gpu::DeviceMemoryBase{gpu_array - 1}};
+ // Clobber first word of the header.
+ float pi = 3.1417;
+ ASSERT_TRUE(
+ stream_exec->SynchronousMemcpy(&gpu_hdr_ptr, &pi, sizeof(float)));
+
+ // Expect error on free.
+ a.Deallocate(gpu_array);
+ },
+ "");
+ }
+}
+
+TEST(GPUDebugAllocatorTest, OverwriteDetection_Footer) {
+ for (int s : {8, 22}) {
+ EXPECT_DEATH(
+ {
+ const int device_id = 0;
+ GPUDebugAllocator a(new GPUBFCAllocator(device_id, 1 << 30),
+ device_id);
+ auto stream_exec =
+ GPUMachineManager()->ExecutorForDevice(device_id).ValueOrDie();
+
+ std::vector<int64> cpu_array(s);
+ memset(&cpu_array[0], 0, cpu_array.size() * sizeof(int64));
+ int64* gpu_array = a.Allocate<int64>(cpu_array.size());
+
+ gpu::DeviceMemory<int64> gpu_array_ptr{
+ gpu::DeviceMemoryBase{gpu_array}};
+ ASSERT_TRUE(stream_exec->SynchronousMemcpy(
+ &gpu_array_ptr, &cpu_array[0], cpu_array.size() * sizeof(int64)));
+
+ // Clobber word of the footer.
+ gpu::DeviceMemory<int64> gpu_ftr_ptr{
+ gpu::DeviceMemoryBase{gpu_array + s}};
+ float pi = 3.1417;
+ ASSERT_TRUE(
+ stream_exec->SynchronousMemcpy(&gpu_ftr_ptr, &pi, sizeof(float)));
+
+ // Expect error on free.
+ a.Deallocate(gpu_array);
+ },
+ "");
+ }
+}
+
+TEST(GPUDebugAllocatorTest, ResetToNan) {
+ const int device_id = 0;
+ GPUNanResetAllocator a(new GPUBFCAllocator(device_id, 1 << 30), device_id);
+ auto stream_exec =
+ GPUMachineManager()->ExecutorForDevice(device_id).ValueOrDie();
+
+ std::vector<float> cpu_array(1024);
+ std::vector<float> cpu_array_result(1024);
+
+ // Allocate 1024 floats
+ float* gpu_array = a.Allocate<float>(cpu_array.size());
+ gpu::DeviceMemory<float> gpu_array_ptr{gpu::DeviceMemoryBase{gpu_array}};
+ ASSERT_TRUE(stream_exec->SynchronousMemcpy(&cpu_array[0], gpu_array_ptr,
+ cpu_array.size() * sizeof(float)));
+ for (float f : cpu_array) {
+ ASSERT_FALSE(std::isfinite(f));
+ }
+
+ // Set one of the fields to 1.0.
+ cpu_array[0] = 1.0;
+ ASSERT_TRUE(stream_exec->SynchronousMemcpy(&gpu_array_ptr, &cpu_array[0],
+ cpu_array.size() * sizeof(float)));
+ // Copy the data back and verify.
+ ASSERT_TRUE(
+ stream_exec->SynchronousMemcpy(&cpu_array_result[0], gpu_array_ptr,
+ cpu_array_result.size() * sizeof(float)));
+ ASSERT_EQ(1.0, cpu_array_result[0]);
+
+ // Free the array
+ a.Deallocate(gpu_array);
+
+ // All values should be reset to nan.
+ ASSERT_TRUE(
+ stream_exec->SynchronousMemcpy(&cpu_array_result[0], gpu_array_ptr,
+ cpu_array_result.size() * sizeof(float)));
+ for (float f : cpu_array_result) {
+ ASSERT_FALSE(std::isfinite(f));
+ }
+}
+
+TEST(GPUDebugAllocatorTest, ResetToNanWithHeaderFooter) {
+ const int device_id = 0;
+ // NaN reset must be the outer-most allocator.
+ GPUNanResetAllocator a(
+ new GPUDebugAllocator(new GPUBFCAllocator(device_id, 1 << 30), device_id),
+ device_id);
+ auto stream_exec =
+ GPUMachineManager()->ExecutorForDevice(device_id).ValueOrDie();
+
+ std::vector<float> cpu_array(1024);
+ std::vector<float> cpu_array_result(1024);
+
+ // Allocate 1024 floats
+ float* gpu_array = a.Allocate<float>(cpu_array.size());
+ gpu::DeviceMemory<float> gpu_array_ptr{gpu::DeviceMemoryBase{gpu_array}};
+ ASSERT_TRUE(stream_exec->SynchronousMemcpy(&cpu_array[0], gpu_array_ptr,
+ cpu_array.size() * sizeof(float)));
+ for (float f : cpu_array) {
+ ASSERT_FALSE(std::isfinite(f));
+ }
+
+ // Set one of the fields to 1.0.
+ cpu_array[0] = 1.0;
+ ASSERT_TRUE(stream_exec->SynchronousMemcpy(&gpu_array_ptr, &cpu_array[0],
+ cpu_array.size() * sizeof(float)));
+ // Copy the data back and verify.
+ ASSERT_TRUE(
+ stream_exec->SynchronousMemcpy(&cpu_array_result[0], gpu_array_ptr,
+ cpu_array_result.size() * sizeof(float)));
+ ASSERT_EQ(1.0, cpu_array_result[0]);
+
+ // Free the array
+ a.Deallocate(gpu_array);
+
+ // All values should be reset to nan.
+ ASSERT_TRUE(
+ stream_exec->SynchronousMemcpy(&cpu_array_result[0], gpu_array_ptr,
+ cpu_array_result.size() * sizeof(float)));
+ for (float f : cpu_array_result) {
+ ASSERT_FALSE(std::isfinite(f));
+ }
+}
+
+TEST(GPUDebugAllocatorTest, TracksSizes) {
+ GPUDebugAllocator a(new GPUBFCAllocator(0, 1 << 30), 0);
+ EXPECT_EQ(true, a.TracksAllocationSizes());
+}
+
+TEST(GPUDebugAllocatorTest, AllocatedVsRequested) {
+ GPUNanResetAllocator a(
+ new GPUDebugAllocator(new GPUBFCAllocator(0, 1 << 30), 0), 0);
+ float* t1 = a.Allocate<float>(1);
+ EXPECT_EQ(4, a.RequestedSize(t1));
+ EXPECT_EQ(256, a.AllocatedSize(t1));
+ a.Deallocate(t1);
+}
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
new file mode 100644
index 0000000000..26d34645f1
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -0,0 +1,651 @@
+// TODO(opensource): Use a more generic sounding preprocessor name than
+// GOOGLE_CUDA
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/common_runtime/gpu/gpu_device.h"
+
+#include <stdlib.h>
+#include <string.h>
+
+//#include "base/commandlineflags.h"
+#include "tensorflow/stream_executor/cuda/cuda_activation.h"
+#include "tensorflow/stream_executor/multi_platform_manager.h"
+#include "tensorflow/stream_executor/stream.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_stream_util.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
+#include "tensorflow/core/common_runtime/gpu/process_state.h"
+#include "tensorflow/core/common_runtime/gpu_device_context.h"
+#include "tensorflow/core/common_runtime/local_device.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/types.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+#if defined(PLATFORM_GOOGLE)
+DEFINE_bool(brain_gpu_sync_every_op, false,
+ "If true, call GPUUtil::Sync() between every dispatched opkernel.");
+
+DEFINE_int32(brain_gpu_max_streams, 1,
+ "Max number of GPU streams to use for computation.");
+#else
+// TODO(opensource): These should be made options in some options struct,
+// rather than flags.
+bool FLAGS_brain_gpu_sync_every_op = false;
+tensorflow::int32 FLAGS_brain_gpu_max_streams = 1;
+#endif
+
+namespace gpu = ::perftools::gputools;
+
+namespace tensorflow {
+
+// Eigen Ops directly allocate memory only for temporary buffers used
+// during OpKernel::Compute(). The recommended way of allocating such
+// memory is via OpKernelContext::allocate_temp(). However, Eigen Ops
+// don't have access to OpKernelContext, instead they get access to
+// memory directly through the device allocator. As an Open Source
+// project, Eigen assumes allocator semantics similar to those of the
+// CUDA memory allocator, and may not work correctly due to race
+// conditions if used with some other allocator. For safety, we need
+// to delay deallocation calls out of Eigen until all events on the
+// corresponding stream have completed. The following two classes
+// serve this purpose in two different compilation environments.
+
+#if defined(__GCUDACC__) || defined(__GCUDACC_HOST__)
+class EigenAllocator : public ::Eigen::Allocator {
+ public:
+ explicit EigenAllocator(gpu::Stream* stream, ::tensorflow::Allocator* alloc,
+ EventMgr* em)
+ : stream_(stream), allocator_(alloc), em_(em) {}
+
+ void* allocate(size_t num_bytes) const override {
+ void* ret = allocator_->AllocateRaw(32 /* alignment */, num_bytes);
+ // Eigen doesn't typically check the return pointer from allocate,
+ // so we do it here and die with a more helpful error message.
+ if (ret == nullptr) {
+ LOG(FATAL) << "EigenAllocator for GPU ran out of memory when allocating "
+ << num_bytes << ". See error logs for more detailed info.";
+ }
+ return ret;
+ }
+
+ void deallocate(void* buffer) const override {
+ em_->ThenDeleteBuffer(stream_, {allocator_, buffer});
+ }
+
+ private:
+ gpu::Stream* stream_; // Not owned.
+ ::tensorflow::Allocator* allocator_; // Not owned.
+ ::tensorflow::EventMgr* em_; // Not owned.
+
+ TF_DISALLOW_COPY_AND_ASSIGN(EigenAllocator);
+};
+
+#else
+class EigenCudaStreamDevice : public ::Eigen::StreamInterface {
+ public:
+ EigenCudaStreamDevice(const cudaStream_t* cuda_stream, int gpu_id,
+ ::tensorflow::Allocator* alloc)
+ : stream_(cuda_stream), allocator_(alloc) {
+ Eigen::initializeDeviceProp();
+ device_prop_ = &Eigen::m_deviceProperties[gpu_id];
+ }
+
+ const cudaStream_t& stream() const override { return *stream_; }
+ const cudaDeviceProp& deviceProperties() const override {
+ return *device_prop_;
+ }
+
+ void* allocate(size_t num_bytes) const override {
+ void* ret = allocator_->AllocateRaw(32 /* alignment */, num_bytes);
+ if (ret == nullptr) {
+ LOG(FATAL) << "EigenAllocator for GPU ran out of memory when allocating "
+ << num_bytes << ". See error logs for more detailed info.";
+ }
+
+ return ret;
+ }
+ void deallocate(void* buffer) const override {
+ AsyncFreeData* afData = new AsyncFreeData(allocator_, buffer);
+ cudaError_t err = cudaStreamAddCallback(*stream_, asyncFree, afData, 0);
+ CHECK_EQ(err, cudaSuccess);
+ }
+
+ private:
+ struct AsyncFreeData {
+ AsyncFreeData(::tensorflow::Allocator* a, void* p)
+ : allocator_(a), address_(p) {}
+ ::tensorflow::Allocator* allocator_;
+ void* address_;
+ };
+
+ static void CUDART_CB asyncFree(cudaStream_t stream, cudaError_t status,
+ void* userData) {
+ AsyncFreeData* data = static_cast<AsyncFreeData*>(userData);
+ data->allocator_->DeallocateRaw(data->address_);
+ delete data;
+ }
+
+ const cudaStream_t* stream_; // Not owned.
+ const cudaDeviceProp* device_prop_; // Not owned.
+ ::tensorflow::Allocator* allocator_; // Not owned.
+
+ TF_DISALLOW_COPY_AND_ASSIGN(EigenCudaStreamDevice);
+};
+
+#endif
+
+BaseGPUDevice::BaseGPUDevice(const SessionOptions& options, const string& name,
+ Bytes memory_limit, BusAdjacency bus_adjacency,
+ int gpu_id, const string& physical_device_desc,
+ Allocator* gpu_allocator, Allocator* cpu_allocator)
+ : LocalDevice(options, Device::BuildDeviceAttributes(
+ name, DEVICE_GPU, memory_limit, bus_adjacency,
+ physical_device_desc),
+ gpu_allocator),
+ gpu_allocator_(gpu_allocator),
+ cpu_allocator_(cpu_allocator),
+ gpu_id_(gpu_id) {
+ gpu::StreamExecutor* executor =
+ GPUMachineManager()->ExecutorForDevice(gpu_id_).ValueOrDie();
+ if (!executor) {
+ LOG(ERROR) << "Failed to get StreamExecutor for device " << gpu_id_;
+ return;
+ }
+ em_.reset(new EventMgr(executor));
+
+ if (FLAGS_brain_gpu_max_streams < 1) {
+ LOG(FATAL) << "Invalid value for brain_gpu_max_streams.";
+ }
+
+ // Create the specified number of GPU streams
+ for (int i = 0; i < FLAGS_brain_gpu_max_streams; i++) {
+ auto stream = new gpu::Stream(executor);
+ stream->Init();
+ VLOG(2) << "Created stream[" << i << "] = " << stream;
+ streams_.push_back(stream);
+ device_contexts_.push_back(new GPUDeviceContext(i, stream));
+ }
+ gpu_device_info_ = new GpuDeviceInfo;
+ gpu_device_info_->stream = streams_[0];
+ gpu_device_info_->default_context = device_contexts_[0];
+ gpu_device_info_->event_mgr = em_.get();
+ set_tensorflow_gpu_device_info(gpu_device_info_);
+}
+
+BaseGPUDevice::~BaseGPUDevice() {
+ delete gpu_device_info_;
+ for (auto ctx : device_contexts_) ctx->Unref();
+ gtl::STLDeleteElements(&streams_);
+}
+
+Status BaseGPUDevice::FillContextMap(const Graph* graph,
+ DeviceContextMap* device_context_map) {
+ VLOG(2) << "FillContextMap";
+
+ const auto num_streams = streams_.size();
+ // Special case for single stream.
+ if (num_streams == 1) {
+ return Status::OK();
+ }
+ const int64 before = Env::Default()->NowMicros();
+ gpu_stream_util::AssignStreamsOpts opts;
+ opts.max_streams = num_streams;
+ std::unordered_map<int, int> node_to_stream_id;
+ TF_RETURN_IF_ERROR(
+ gpu_stream_util::AssignStreams(graph, opts, &node_to_stream_id));
+ int64 elapsed = Env::Default()->NowMicros() - before;
+ VLOG(3) << "AssignStreams took " << elapsed << "us";
+
+ // Fill in the context map. It is OK for this map to contain
+ // duplicate DeviceContexts so long as we increment the refcount.
+ for (Node* n : graph->nodes()) {
+ auto mapped_stream = node_to_stream_id[n->id()];
+ CHECK_LE(mapped_stream, num_streams);
+ auto ctx = device_contexts_[mapped_stream];
+ VLOG(3) << "Assigned stream " << node_to_stream_id[n->id()]
+ << " ==> stream[" << ctx->stream_id() << "] for node id " << n->id()
+ << " " << n->type_string() << " " << n->name();
+ ctx->Ref();
+ device_context_map->insert(std::make_pair(n->id(), ctx));
+ }
+
+ return Status::OK();
+}
+
+void BaseGPUDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
+ // ScopedActivity is cheap when tracing is not active, but we
+ // can avoid computing the Hash64.
+ // TODO(pbar) This would no longer be needed if Ops have a unique id.
+ const uint64 id = port::Tracing::IsActive() ? Hash64(op_kernel->name()) : 0;
+ port::Tracing::ScopedActivity region(port::Tracing::EventCategory::kCompute,
+ id);
+
+ GPUDeviceContext* gpu_device_context = device_contexts_[0];
+ if (context->op_device_context() != nullptr) {
+ gpu_device_context =
+ static_cast<GPUDeviceContext*>(context->op_device_context());
+ }
+ gpu::Stream* stream = gpu_device_context->stream();
+ const auto stream_id = gpu_device_context->stream_id();
+
+ VLOG(1) << "GpuDevice::Compute " << op_kernel->name() << " op "
+ << op_kernel->def().op() << " on GPU" << gpu_id_ << " stream["
+ << stream_id << "]";
+
+ // NOTE(tucker): We need to discriminate between Eigen GPU
+ // operations and all others. If an operation is Eigen
+ // implemented (or otherwise tries to launch a cuda kernel
+ // directly), we need to establish a stacked-scoped environment
+ // that directs it to execute on the proper device. Otherwise we
+ // expect the Op to use StreamExecutor directly and correctly. The
+ // way we make this discrimination is quite hacky: At the moment
+ // the only non-Eigen GPU Op is the recv-op, which is known to be
+ // asynchronous.
+ if (op_kernel->type_string() == "_Recv") {
+ context->SetStatus(errors::Internal(
+ "Invalid synchronous 'Compute' on GPU for '_Recv' op"));
+ } else {
+ const string label =
+ strings::StrCat(op_kernel->name(), ":", op_kernel->type_string());
+ port::Tracing::ScopedAnnotation annotation(label);
+
+ const auto num_streams = streams_.size();
+ if (num_streams > 1) {
+ // If this op's device context is different from the other contexts,
+ // we must wait on the stream.
+ for (int i = 0; i < context->num_inputs(); ++i) {
+ const GPUDeviceContext* idc =
+ static_cast<GPUDeviceContext*>(context->input_device_context(i));
+ OP_REQUIRES(context, idc != nullptr,
+ errors::Internal("Input device context ", i,
+ " was not set properly."));
+ if (VLOG_IS_ON(2)) {
+ const void* base;
+ size_t len;
+ if (context->has_input(i)) {
+ if (IsRefType(context->input_dtype(i))) {
+ Tensor tensor = context->mutable_input(i, false);
+ base = DMAHelper::base(&tensor);
+ len = tensor.TotalBytes();
+ } else {
+ const Tensor& tensor = context->input(i);
+ base = DMAHelper::base(&tensor);
+ len = tensor.TotalBytes();
+ }
+ VLOG(2) << "Input " << i << " " << base << " " << len;
+ VLOG(2) << " stream[" << stream_id << "].ThenWaitFor(stream["
+ << idc->stream_id() << "])"
+ << ((idc->stream() == stream) ? " not needed" : "");
+ }
+ }
+ if (idc->stream() != stream) stream->ThenWaitFor(idc->stream());
+ }
+ }
+ gpu::cuda::ScopedActivateExecutorContext scoped_activation{
+ stream->parent(), gpu::cuda::MultiOpActivation::kYes};
+ // Keep a copy of the inputs before Compute runs, in case they get
+ // deleted. TODO(misard) this will be fixed when the tracking is
+ // done right.
+ std::vector<Tensor>* tensor_refs = nullptr;
+ if (!FLAGS_brain_gpu_sync_every_op) {
+ tensor_refs = new std::vector<Tensor>;
+ tensor_refs->reserve(context->num_inputs() + context->num_outputs());
+ for (int ii = 0; ii < context->num_inputs(); ++ii) {
+ if (context->has_input(ii)) {
+ if (IsRefType(context->input_dtype(ii))) {
+ Tensor in = context->mutable_input(ii, false);
+ tensor_refs->push_back(in);
+ } else {
+ const Tensor& in = context->input(ii);
+ tensor_refs->push_back(in);
+ }
+ }
+ }
+ }
+ op_kernel->Compute(context);
+ if (context->status().ok()) {
+ if (FLAGS_brain_gpu_sync_every_op) {
+ // Note: GPUUtil::Sync() only syncs the default stream.
+ // We need to either sync the stream used by this op, or
+ // all streams. Given that this flag is typically used for
+ // debugging it makes more sense to sync all GPU activity.
+ context->SetStatus(GPUUtil::SyncAll(this));
+ } else {
+ // The GPU kernel has been queued, but may not complete for some
+ // time. As soon as this function completes, the caller will
+ // discard its refs on the inputs, outputs and any scratch
+ // tensors it created. Create additional refs here that will be
+ // held until the kernel completes.
+ for (int ii = 0; ii < context->num_temps(); ++ii) {
+ Tensor* temp = context->temp(ii);
+ VLOG(2) << "Saving ref to temp Tensor @ " << DMAHelper::base(temp);
+ tensor_refs->push_back(*temp);
+ }
+ for (int ii = 0; ii < context->num_outputs(); ++ii) {
+ Tensor* temp = context->mutable_output(ii);
+ if (nullptr != temp) {
+ tensor_refs->push_back(*temp);
+ }
+ }
+ em_->ThenDeleteTensors(stream, tensor_refs);
+ }
+ } else {
+ if (!FLAGS_brain_gpu_sync_every_op) {
+ delete tensor_refs;
+ }
+ }
+ }
+}
+
+Status BaseGPUDevice::Sync() { return GPUUtil::Sync(this); }
+
+void BaseGPUDevice::ComputeAsync(AsyncOpKernel* op_kernel,
+ OpKernelContext* context,
+ AsyncOpKernel::DoneCallback done) {
+ GPUDeviceContext* gpu_device_context = device_contexts_[0];
+ if (context->op_device_context() != nullptr) {
+ gpu_device_context =
+ static_cast<GPUDeviceContext*>(context->op_device_context());
+ }
+ const auto stream_id = gpu_device_context->stream_id();
+
+ VLOG(1) << "GpuDevice::ComputeAsync " << op_kernel->name() << " op "
+ << op_kernel->def().op() << " on GPU" << gpu_id_ << " stream["
+ << stream_id << "]";
+
+ port::Tracing::TraceMe activity(
+ strings::StrCat(op_kernel->name(), ":", op_kernel->type_string()));
+ op_kernel->ComputeAsync(context, done);
+}
+
+Status BaseGPUDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
+ const AllocatorAttributes alloc_attrs,
+ Tensor* tensor) {
+ AllocatorAttributes attr;
+ attr.set_on_host(true);
+ attr.set_gpu_compatible(true);
+ Allocator* host_alloc = GetAllocator(attr);
+ Tensor parsed(tensor_proto.dtype());
+ if (!parsed.FromProto(host_alloc, tensor_proto)) {
+ return errors::InvalidArgument("Cannot parse tensor from proto: ",
+ tensor_proto.DebugString());
+ }
+ Status status;
+ if (alloc_attrs.on_host()) {
+ *tensor = parsed;
+ } else {
+ if (!DMAHelper::CanUseDMA(&parsed)) {
+ return errors::Internal("GPU copy from non-DMA ",
+ DataTypeString(parsed.dtype()), " tensor");
+ }
+ Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape());
+ port::Tracing::ScopedAnnotation annotation("MakeTensorFromProto");
+ Notification n;
+ device_contexts_[0]->CopyCPUTensorToDevice(&parsed, this, &copy,
+ [&n, &status](const Status& s) {
+ status = s;
+ n.Notify();
+ });
+ n.WaitForNotification();
+ *tensor = copy;
+ }
+ return status;
+}
+
+namespace {
+#if defined(__GCUDACC__) || defined(__GCUDACC_HOST__)
+class ConcretePerOpGpuDevice : public PerOpGpuDevice {
+ public:
+ explicit ConcretePerOpGpuDevice(gpu::Stream* stream,
+ EigenAllocator* allocator)
+ : device_(stream, allocator), allocator_(allocator) {}
+ ~ConcretePerOpGpuDevice() { delete allocator_; }
+
+ const Eigen::GpuDevice& device() const override { return device_; }
+
+ private:
+ Eigen::GpuDevice device_;
+ EigenAllocator* allocator_;
+};
+#else
+class ConcretePerOpGpuDevice : public PerOpGpuDevice {
+ public:
+ explicit ConcretePerOpGpuDevice(EigenCudaStreamDevice* stream_device)
+ : device_(stream_device), stream_device_(stream_device) {}
+ ~ConcretePerOpGpuDevice() { delete stream_device_; }
+
+ const Eigen::GpuDevice& device() const override { return device_; }
+
+ private:
+ Eigen::GpuDevice device_;
+ EigenCudaStreamDevice* stream_device_;
+};
+#endif
+} // namespace
+
+const PerOpGpuDevice* BaseGPUDevice::NewDevice(int stream_id,
+ Allocator* allocator) {
+#if defined(__GCUDACC__) || defined(__GCUDACC_HOST__)
+ auto ea = new EigenAllocator(streams_[stream_id], allocator, em_.get());
+ return new ConcretePerOpGpuDevice(streams_[stream_id], ea);
+#else
+ const cudaStream_t* cuda_stream = reinterpret_cast<const cudaStream_t*>(
+ streams_[stream_id]->implementation()->CudaStreamMemberHack());
+ auto es = new EigenCudaStreamDevice(cuda_stream, gpu_id_, allocator);
+ return new ConcretePerOpGpuDevice(es);
+#endif
+}
+
+const PerOpGpuDevice* BaseGPUDevice::MakeGpuDevice(DeviceContext* dc,
+ Allocator* allocator) {
+ if (dc) {
+ const GPUDeviceContext* gpu_dc = static_cast<GPUDeviceContext*>(dc);
+ const int stream_id = gpu_dc->stream_id();
+ VLOG(1) << " eigen_gpu_device(" << dc << ") => stream[" << stream_id
+ << "]";
+ CHECK_LT(stream_id, streams_.size());
+ return NewDevice(stream_id, allocator);
+ } else {
+ return NewDevice(0, allocator);
+ }
+}
+
+void BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options,
+ const string& name_prefix,
+ std::vector<Device*>* devices) {
+ int n = INT_MAX;
+ auto iter = options.config.device_count().find("GPU");
+ if (iter != options.config.device_count().end()) {
+ n = iter->second;
+ }
+ std::vector<int> valid_gpu_ids;
+ GetValidDeviceIds(&valid_gpu_ids);
+ if (static_cast<size_t>(n) > valid_gpu_ids.size()) {
+ n = valid_gpu_ids.size();
+ }
+ for (int i = 0; i < n; i++) {
+ devices->push_back(CreateGPUDevice(
+ options, strings::StrCat(name_prefix, "/gpu:", i), valid_gpu_ids[i]));
+ }
+}
+
+namespace {
+int64 MinSystemMemory(int64 available_memory) {
+ // We use the following heuristic for now:
+ //
+  // If the available_memory is < 2GiB, we allocate 200MiB to system memory.
+  // Otherwise, allocate the larger of 300MiB and 5% of the available memory
+  // to system memory.
+ //
+ // In the future we could be more sophisticated by using a table of
+ // devices.
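+  //
+  // For example (hypothetical card): with 12GiB of available memory we
+  // reserve max(300MiB, 0.05 * 12GiB) ~= 614MiB for the system.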
+ if (available_memory < (1LL << 31)) {
+ // 200MiB
+ return 209715200LL;
+ } else {
+    // max(300 MiB, 0.05 * available_memory)
+ return std::max(314572800LL, static_cast<int64>(available_memory * 0.05));
+ }
+}
+} // namespace
+
+static string GetShortDeviceDescription(int device_id,
+ const gpu::DeviceDescription& desc) {
+ return strings::StrCat("device: ", device_id, ", name: ", desc.name(),
+ ", pci bus id: ", desc.pci_bus_id());
+}
+
+LocalDevice* BaseGPUDeviceFactory::CreateGPUDevice(
+ const SessionOptions& options, const string& name, int gpu_id) {
+ CHECK_GE(gpu_id, 0);
+
+ // Look up the device, to see its attributes.
+ gpu::Platform* gpu_platform = GPUMachineManager();
+ CHECK_LT(gpu_id, gpu_platform->VisibleDeviceCount());
+ gpu::StreamExecutor* se =
+ gpu_platform->ExecutorForDevice(gpu_id).ValueOrDie();
+ const gpu::DeviceDescription& desc = se->GetDeviceDescription();
+
+ int64 total_memory, available_memory;
+ CHECK(se->DeviceMemoryUsage(&available_memory, &total_memory));
+
+ int64 allocated_memory = available_memory;
+ double config_memory_fraction =
+ options.config.gpu_options().per_process_gpu_memory_fraction();
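+  // For example (hypothetical numbers): with 12GiB available and
+  // per_process_gpu_memory_fraction left at 0, roughly
+  // 12GiB - MinSystemMemory(12GiB) is handed to the GPU allocator; with the
+  // fraction set to 0.4, roughly 4.8GiB is used.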
+ if (config_memory_fraction == 0) {
+ const int64 min_system_memory = MinSystemMemory(available_memory);
+ if (min_system_memory < allocated_memory) {
+ allocated_memory -= min_system_memory;
+ }
+ } else {
+ allocated_memory *= config_memory_fraction;
+ }
+
+ Bytes allocated_bytes = static_cast<Bytes>(allocated_memory);
+
+ // Get GPU BusAdjacency from its reported NUMA affinity.
+ // Because GPUs are virtualized in some environments, we can't just
+ // use the GPU id.
+ BusAdjacency bus_adjacency = BUS_ANY;
+ switch (desc.numa_node()) {
+ case 0:
+ bus_adjacency = BUS_0;
+ break;
+ case 1:
+ bus_adjacency = BUS_1;
+ break;
+ default:
+ bus_adjacency = BUS_ANY;
+ }
+ VLOG(1) << "GPUDevice id " << gpu_id << " on bus " << bus_adjacency
+ << " numa: " << desc.numa_node() << " pci: " << desc.pci_bus_id();
+
+ ProcessState* process_state = ProcessState::singleton();
+ return CreateGPUDevice(
+ options, name, allocated_bytes, bus_adjacency, gpu_id,
+ GetShortDeviceDescription(gpu_id, desc),
+ process_state->GetGPUAllocator(gpu_id, allocated_memory),
+ process_state->GetCPUAllocator(desc.numa_node()));
+}
+
+static int GetMinGPUMultiprocessorCount() {
+ static const int kDefaultMinGPUMultiprocessorCount = 8;
+
+ const char* tf_min_gpu_core_count = getenv("TF_MIN_GPU_MULTIPROCESSOR_COUNT");
+
+ if (tf_min_gpu_core_count == nullptr ||
+ strcmp(tf_min_gpu_core_count, "") == 0) {
+ return kDefaultMinGPUMultiprocessorCount;
+ }
+
+ int min_gpu_core_count = -1;
+ if (strings::safe_strto32(tf_min_gpu_core_count, &min_gpu_core_count)) {
+ if (min_gpu_core_count >= 0) {
+ return min_gpu_core_count;
+ }
+ }
+
+ LOG(ERROR) << "Invalid minimum GPU multiprocessor count: ["
+ << tf_min_gpu_core_count << "]. "
+ << "Using the default value: "
+ << kDefaultMinGPUMultiprocessorCount;
+ return kDefaultMinGPUMultiprocessorCount;
+}
+
+void BaseGPUDeviceFactory::GetValidDeviceIds(std::vector<int>* ids) {
+ auto gpu_manager = GPUMachineManager();
+ int min_gpu_core_count = GetMinGPUMultiprocessorCount();
+ if (gpu_manager) {
+ auto visible_device_count = gpu_manager->VisibleDeviceCount();
+ for (int i = 0; i < gpu_manager->VisibleDeviceCount(); ++i) {
+ auto exec_status = gpu_manager->ExecutorForDevice(i);
+ if (!exec_status.ok()) {
+ continue;
+ }
+ gpu::StreamExecutor* se = exec_status.ValueOrDie();
+ const gpu::DeviceDescription& desc = se->GetDeviceDescription();
+ int major, minor;
+ if (!desc.cuda_compute_capability(&major, &minor)) {
+ continue;
+ }
+ // Only consider GPUs with compute capability >= 3.5 (Kepler or
+ // higher)
+ if (major < 3 || (major == 3 && minor < 5)) {
+ LOG(INFO) << "Ignoring gpu device "
+ << "(" << GetShortDeviceDescription(i, desc) << ") "
+ << "with Cuda compute capability " << major << "." << minor
+ << ". The minimum required Cuda capability is 3.5.";
+ continue;
+ }
+
+ // TensorFlow currently places computation on devices assuming
+ // they have similar capability.
+ //
+ // If there are multiple GPUs available on the machine, only
+ // consider GPUs with 8 or more multiprocessors.
+ //
+ // TODO(vrv): In the medium term: we should only filter out GPUs
+ // that are slow relative to the fastest GPU. In the long term,
+ // TensorFlow should support automatic placement based on
+ // capability.
+ if (visible_device_count > 1) {
+ if (desc.core_count() < min_gpu_core_count) {
+ LOG(INFO) << "Ignoring gpu device "
+ << "(" << GetShortDeviceDescription(i, desc) << ") "
+ << "with Cuda multiprocessor count: " << desc.core_count()
+ << ". The minimum required count is " << min_gpu_core_count
+ << ". You can adjust this requirement with the env var "
+ "TF_MIN_GPU_MULTIPROCESSOR_COUNT.";
+ continue;
+ }
+ }
+
+ int new_id = ids->size();
+ ids->push_back(i);
+
+ LOG(INFO) << "Creating TensorFlow device (/gpu:" << new_id << ") -> "
+ << "(" << GetShortDeviceDescription(i, desc) << ")";
+ }
+ }
+}
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.h b/tensorflow/core/common_runtime/gpu/gpu_device.h
new file mode 100644
index 0000000000..a415224d95
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.h
@@ -0,0 +1,94 @@
+#if !GOOGLE_CUDA
+#error This file must only be included when building with Cuda support
+#endif
+
+#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_DEVICE_H_
+#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_DEVICE_H_
+
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/gpu_device_context.h"
+#include "tensorflow/core/common_runtime/local_device.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
+#include "tensorflow/stream_executor/stream.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+class EigenAllocator;
+
+class BaseGPUDevice : public LocalDevice {
+ public:
+ BaseGPUDevice(const SessionOptions& options, const string& name,
+ Bytes memory_limit, BusAdjacency bus_adjacency, int gpu_id,
+ const string& physical_device_desc, Allocator* gpu_allocator,
+ Allocator* cpu_allocator);
+
+ ~BaseGPUDevice() override;
+
+ // GPU devices require the Op Compute method to save a reference to
+ // any temporary tensors that are allocated until the Op execution
+ // completes.
+ bool SaveTemporaryTensors() const override { return true; }
+
+ Status FillContextMap(const Graph* graph,
+ DeviceContextMap* device_context_map);
+
+ void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
+
+ Status Sync() override;
+
+ void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
+ AsyncOpKernel::DoneCallback done) override;
+
+ Status MakeTensorFromProto(const TensorProto& tensor_proto,
+ const AllocatorAttributes alloc_attrs,
+ Tensor* tensor) override;
+
+ // The caller owns the returned device.
+ const PerOpGpuDevice* MakeGpuDevice(DeviceContext* dc,
+ Allocator* allocator) override;
+
+ protected:
+ Allocator* gpu_allocator_; // not owned
+ Allocator* cpu_allocator_; // not owned
+
+ private:
+ std::vector<gpu::Stream*> streams_;
+ std::vector<GPUDeviceContext*> device_contexts_;
+ GpuDeviceInfo* gpu_device_info_ = nullptr;
+ mutex trace_mu_;
+ int gpu_id_ = -1;
+ std::unique_ptr<EventMgr> em_;
+
+ const PerOpGpuDevice* NewDevice(int stream_id, Allocator* allocator);
+};
+
+class BaseGPUDeviceFactory : public DeviceFactory {
+ public:
+ void CreateDevices(const SessionOptions& options, const string& name_prefix,
+ std::vector<Device*>* devices) override;
+
+ private:
+ LocalDevice* CreateGPUDevice(const SessionOptions& options,
+ const string& name, int gpu_id);
+
+ virtual LocalDevice* CreateGPUDevice(const SessionOptions& options,
+ const string& name, Bytes memory_limit,
+ BusAdjacency bus_adjacency, int gpu_id,
+ const string& physical_device_desc,
+ Allocator* gpu_allocator,
+ Allocator* cpu_allocator) = 0;
+
+ void GetValidDeviceIds(std::vector<int>* ids);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_DEVICE_H_
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc b/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc
new file mode 100644
index 0000000000..240ac47499
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc
@@ -0,0 +1,52 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/common_runtime/gpu/gpu_device.h"
+#include "tensorflow/core/common_runtime/gpu/process_state.h"
+
+namespace tensorflow {
+
+void RequireGPUDevice() {}
+
+class GPUDevice : public BaseGPUDevice {
+ public:
+ GPUDevice(const SessionOptions& options, const string& name,
+ Bytes memory_limit, BusAdjacency bus_adjacency, int gpu_id,
+ const string& physical_device_desc, Allocator* gpu_allocator,
+ Allocator* cpu_allocator)
+ : BaseGPUDevice(options, name, memory_limit, bus_adjacency, gpu_id,
+ physical_device_desc, gpu_allocator, cpu_allocator) {}
+
+ Allocator* GetAllocator(AllocatorAttributes attr) override {
+ if (attr.on_host()) {
+ ProcessState* ps = ProcessState::singleton();
+ if (attr.gpu_compatible()) {
+ return ps->GetCUDAHostAllocator(0);
+ } else {
+ return cpu_allocator_;
+ }
+ } else {
+ return gpu_allocator_;
+ }
+ }
+};
+
+class GPUDeviceFactory : public BaseGPUDeviceFactory {
+ private:
+ LocalDevice* CreateGPUDevice(const SessionOptions& options,
+ const string& name, Bytes memory_limit,
+ BusAdjacency bus_adjacency, int gpu_id,
+ const string& physical_device_desc,
+ Allocator* gpu_allocator,
+ Allocator* cpu_allocator) override {
+ return new GPUDevice(options, name, memory_limit, bus_adjacency, gpu_id,
+ physical_device_desc, gpu_allocator, cpu_allocator);
+ }
+};
+
+REGISTER_LOCAL_DEVICE_FACTORY("GPU", GPUDeviceFactory);
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc
new file mode 100644
index 0000000000..29d6281733
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc
@@ -0,0 +1,132 @@
+#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
+
+#include "tensorflow/stream_executor/event.h"
+#include "tensorflow/stream_executor/stream.h"
+
+namespace gpu = ::perftools::gputools;
+
+namespace tensorflow {
+
+EventMgr::EventMgr(gpu::StreamExecutor* se)
+ : exec_(se),
+ // threadpool_ has 1 thread for the polling loop, and one to execute
+ // event callback functions. Maybe we should have more?
+ threadpool_(Env::Default(), "GPU_Event_Manager", 2) {
+ threadpool_.Schedule([this]() { PollLoop(); });
+}
+
+EventMgr::~EventMgr() {
+ stop_polling_.Notify();
+ // Shut down the backup polling loop.
+ polling_stopped_.WaitForNotification();
+
+ // Events are owned by this object.
+ for (auto& e : free_events_) {
+ delete e;
+ }
+ while (!used_events_.empty()) {
+ delete used_events_[0].event;
+ delete used_events_[0].mem;
+ if (used_events_[0].bufrec.buf) {
+ used_events_[0].bufrec.alloc->DeallocateRaw(used_events_[0].bufrec.buf);
+ }
+ if (used_events_[0].func != nullptr)
+ threadpool_.Schedule(used_events_[0].func);
+ used_events_.pop_front();
+ }
+}
+
+// This polling loop runs at a relatively low frequency. Most calls to
+// PollEvents() should come directly from Compute() via
+// ThenDeleteTensors(). This function's purpose is to ensure that
+// even if no more GPU operations are being requested, we still
+// eventually clear the queue. It seems to prevent some tensorflow
+// programs from stalling for reasons not yet understood.
+void EventMgr::PollLoop() {
+ while (!stop_polling_.HasBeenNotified()) {
+ Env::Default()->SleepForMicroseconds(1 * 1000);
+ {
+ mutex_lock l(mu_);
+ PollEvents(true);
+ }
+ }
+ polling_stopped_.Notify();
+}
+
+void EventMgr::QueueInUse(gpu::Stream* stream, InUse iu) {
+ VLOG(2) << "QueueInUse free_events_ " << free_events_.size()
+ << " used_events_ " << used_events_.size();
+ // Events are created on demand, and repeatedly reused. There is no
+ // limit placed here on the number of allocated Events.
+ if (free_events_.empty()) {
+ free_events_.push_back(new gpu::Event(exec_));
+ free_events_.back()->Init();
+ }
+ gpu::Event* e = free_events_.back();
+ free_events_.pop_back();
+ stream->ThenRecordEvent(e);
+ iu.event = e;
+ used_events_.push_back(iu);
+}
+
+// This function must be called periodically to check whether pending
+// events have recorded, and then retire them. Initial observations
+// suggest that typical behavior in a TensorFlow program is to have
+// 0-3 events pending most of the time, but there are occasionally
+// spikes of up to several hundred outstanding.
+//
+// NOTE: If all events are on the same stream, no later event will
+// complete before an earlier event, except possibly if the earlier
+// event transitions to an error state, so there's no advantage in
+// looking past the first kPending event. However, if we're using
+// multiple streams there may be some gain in looking deeper.
+// As a compromise, PollEvents() calls that are triggered by the queueing
+// of a single event never look past the first kPending event. Calls
+// coming from the dedicated polling thread always sweep the full queue.
+//
+// Note that allowing the queue to grow very long could cause overall
+// GPU memory use to spike needlessly. An alternative strategy would
+// be to throttle new Op execution until the pending event queue
+// clears.
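+//
+// For example, if used_events_ holds records in states
+// [complete, complete, pending, complete]: a call from ThenDeleteTensors()
+// (is_dedicated_poller == false) frees the events of the first two records
+// and returns at the pending one, while the dedicated poller also processes
+// the trailing complete record and then pops the two leading retired records
+// off the front of the queue.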
+void EventMgr::PollEvents(bool is_dedicated_poller) {
+ VLOG(2) << "PollEvents free_events_ " << free_events_.size()
+ << " used_events_ " << used_events_.size();
+ // Sweep the remaining events in order. If this is the dedicated
+  // polling thread, check the entire set. Otherwise, just sweep up to
+  // the first record that is still pending.
+ for (auto& iu : used_events_) {
+ if (iu.event == nullptr) continue;
+ gpu::Event::Status s = iu.event->PollForStatus();
+ switch (s) {
+ case gpu::Event::Status::kUnknown:
+ case gpu::Event::Status::kError:
+ // We don't expect to see these. Someday maybe propagate
+ // a Status error, but for now fail hard.
+ LOG(FATAL) << "Unexpected Event status: " << static_cast<int>(s);
+ break;
+ case gpu::Event::Status::kPending:
+ if (!is_dedicated_poller) return; // quit processing queue
+ break;
+ case gpu::Event::Status::kComplete:
+ delete iu.mem;
+ if (iu.bufrec.buf) iu.bufrec.alloc->DeallocateRaw(iu.bufrec.buf);
+ // The function must be called in another thread, outside of
+ // the mutex held here.
+ if (iu.func != nullptr) threadpool_.Schedule(iu.func);
+ free_events_.push_back(iu.event);
+ // Mark this InUse record as completed.
+ iu.event = nullptr;
+ }
+ }
+ // Then clear any completed InUse records from the front of the queue.
+ while (!used_events_.empty()) {
+ InUse& iu = used_events_.front();
+ if (iu.event == nullptr) {
+ used_events_.pop_front();
+ } else {
+ break;
+ }
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h
new file mode 100644
index 0000000000..f9436566d4
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h
@@ -0,0 +1,118 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_EVENT_MGR_H_
+#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_EVENT_MGR_H_
+
+#include <deque>
+#include <vector>
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace perftools {
+namespace gputools {
+class Event;
+class Stream;
+class StreamExecutor;
+} // namespace gputools
+} // namespace perftools
+
+namespace tensorflow {
+
+// An object to keep track of pending Events in the StreamExecutor streams
+// and associated Tensors that cannot safely be deleted until the associated
+// Events are recorded.
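+//
+// Example usage (a minimal sketch; 'se' and 'stream' are assumed to be the
+// StreamExecutor and Stream already owned by the surrounding GPU device):
+//   EventMgr em(se);
+//   std::vector<Tensor>* refs = new std::vector<Tensor>;
+//   refs->push_back(gpu_tensor);          // keep the buffer alive
+//   stream->ThenMemcpy(&dst, src, bytes); // async work reading the buffer
+//   em.ThenDeleteTensors(stream, refs);   // deleted once the stream drains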
+class EventMgr {
+ public:
+ explicit EventMgr(perftools::gputools::StreamExecutor* se);
+
+ ~EventMgr();
+
+ // Takes ownership of *tensors and deletes it as soon as all events
+ // currently enqueued on *stream have completed.
+ inline void ThenDeleteTensors(perftools::gputools::Stream* stream,
+ std::vector<Tensor>* tensors) {
+ mutex_lock l(mu_);
+ QueueTensors(stream, tensors);
+ PollEvents(false);
+ }
+
+ struct BufRec {
+ Allocator* alloc;
+ void* buf;
+ };
+
+ // Takes ownership of *bufrec.buf and calls bufrec.alloc->DeallocateRaw()
+ // on it as soon as all events currently enqueued on *stream have completed.
+ inline void ThenDeleteBuffer(perftools::gputools::Stream* stream,
+ BufRec bufrec) {
+ mutex_lock l(mu_);
+ QueueBuffer(stream, bufrec);
+ PollEvents(false);
+ }
+
+ inline void ThenExecute(perftools::gputools::Stream* stream,
+ std::function<void()> func) {
+ mutex_lock l(mu_);
+ QueueFunc(stream, func);
+ PollEvents(false);
+ }
+
+ private:
+ friend class TEST_EventMgrHelper;
+ mutex mu_;
+ perftools::gputools::StreamExecutor* exec_;
+
+ struct InUse {
+ perftools::gputools::Event* event;
+ std::vector<Tensor>* mem;
+ BufRec bufrec;
+ std::function<void()> func;
+ };
+
+ // Stream-enqueue an unused Event and save with it a collection of
+ // Tensors and/or a BufRec to be deleted only after the Event
+ // records.
+ void QueueInUse(perftools::gputools::Stream* stream, InUse in_use)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ void QueueTensors(perftools::gputools::Stream* stream,
+ std::vector<Tensor>* tensors)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ QueueInUse(stream, {nullptr, tensors, BufRec(), nullptr});
+ }
+
+ void QueueBuffer(perftools::gputools::Stream* stream, BufRec bufrec)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ QueueInUse(stream, {nullptr, nullptr, bufrec, nullptr});
+ }
+
+ void QueueFunc(perftools::gputools::Stream* stream,
+ std::function<void()> func) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ QueueInUse(stream, {nullptr, nullptr, BufRec(), func});
+ }
+
+ // This function should be called at roughly the same tempo as
+ // QueueTensors() to check whether pending events have recorded,
+ // and then retire them.
+ void PollEvents(bool is_dedicated_poller) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // An internal polling loop that runs at a low frequency to clear
+ // straggler Events.
+ void PollLoop();
+
+ // A stack of unused events
+ std::vector<perftools::gputools::Event*> free_events_ GUARDED_BY(mu_);
+
+ // A FIFO queue of InUse events and associated tensors.
+ std::deque<InUse> used_events_ GUARDED_BY(mu_);
+
+ Notification stop_polling_;
+ Notification polling_stopped_;
+
+ // The main PollLoop for the event manager runs in this threadpool.
+ thread::ThreadPool threadpool_;
+};
+
+} // namespace tensorflow
+#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_EVENT_MGR_H_
diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc b/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc
new file mode 100644
index 0000000000..30ca1ff187
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc
@@ -0,0 +1,152 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
+
+#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
+#include "tensorflow/stream_executor/multi_platform_manager.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+#include <gtest/gtest.h>
+
+namespace gpu = ::perftools::gputools;
+
+namespace tensorflow {
+
+class TEST_EventMgrHelper {
+ public:
+ explicit TEST_EventMgrHelper(EventMgr* em) : em_(em) {}
+
+ int queue_size() {
+ mutex_lock l(em_->mu_);
+ return em_->used_events_.size();
+ }
+
+ int free_size() {
+ mutex_lock l(em_->mu_);
+ return em_->free_events_.size();
+ }
+
+ void QueueTensors(perftools::gputools::Stream* stream,
+ std::vector<Tensor>* tensors) {
+ mutex_lock l(em_->mu_);
+ em_->QueueTensors(stream, tensors);
+ }
+
+ void PollEvents(bool is_dedicated_poller) {
+ mutex_lock l(em_->mu_);
+ em_->PollEvents(is_dedicated_poller);
+ }
+
+ private:
+ EventMgr* em_;
+};
+
+namespace {
+
+TEST(EventMgr, Empty) {
+ auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie();
+ EventMgr em(stream_exec);
+ TEST_EventMgrHelper th(&em);
+ EXPECT_EQ(0, th.queue_size());
+ EXPECT_EQ(0, th.free_size());
+}
+
+// Delaying polling until after several enqueings should grow the
+// total number of allocated events. Once we have enough events for
+// the max simultaneously pending, we should not allocate any more.
+TEST(EventMgr, DelayedPolling) {
+ auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie();
+ EventMgr em(stream_exec);
+ TEST_EventMgrHelper th(&em);
+ EXPECT_EQ(0, th.queue_size());
+ std::vector<Tensor>* v = nullptr;
+ std::unique_ptr<gpu::Stream> stream(new gpu::Stream(stream_exec));
+ CHECK(stream.get());
+ stream->Init();
+ for (int i = 0; i < 5; ++i) {
+ v = new std::vector<Tensor>;
+ th.QueueTensors(stream.get(), v);
+ EXPECT_EQ(i + 1, th.queue_size());
+ EXPECT_EQ(0, th.free_size());
+ }
+ th.PollEvents(false);
+ EXPECT_EQ(0, th.queue_size());
+ EXPECT_EQ(5, th.free_size());
+ for (int j = 0; j < 2; ++j) {
+ for (int i = 0; i < 5; ++i) {
+ v = new std::vector<Tensor>;
+ th.QueueTensors(stream.get(), v);
+ EXPECT_EQ(i + 1, th.queue_size());
+ EXPECT_EQ(4 - i, th.free_size());
+ }
+ th.PollEvents(false);
+ EXPECT_EQ(0, th.queue_size());
+ EXPECT_EQ(5, th.free_size());
+ }
+}
+
+// Immediate polling should require only one event to be allocated.
+TEST(EventMgr, ImmediatePolling) {
+ auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie();
+ EventMgr em(stream_exec);
+ TEST_EventMgrHelper th(&em);
+ EXPECT_EQ(0, th.queue_size());
+ EXPECT_EQ(0, th.free_size());
+ std::vector<Tensor>* v = nullptr;
+ std::unique_ptr<gpu::Stream> stream(new gpu::Stream(stream_exec));
+ CHECK(stream.get());
+ stream->Init();
+ for (int i = 0; i < 5; ++i) {
+ v = new std::vector<Tensor>;
+ em.ThenDeleteTensors(stream.get(), v);
+ EXPECT_EQ(0, th.queue_size());
+ EXPECT_EQ(1, th.free_size());
+ }
+}
+
+// If we delay polling by more than 1 second, the backup polling loop
+// should clear the queue.
+TEST(EventMgr, LongDelayedPolling) {
+ auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie();
+ EventMgr em(stream_exec);
+ TEST_EventMgrHelper th(&em);
+ EXPECT_EQ(0, th.queue_size());
+ EXPECT_EQ(0, th.free_size());
+ std::vector<Tensor>* v = nullptr;
+ std::unique_ptr<gpu::Stream> stream(new gpu::Stream(stream_exec));
+ CHECK(stream.get());
+ stream->Init();
+ for (int i = 0; i < 5; ++i) {
+ v = new std::vector<Tensor>;
+ th.QueueTensors(stream.get(), v);
+ EXPECT_EQ(1 + i, th.queue_size());
+ EXPECT_EQ(0, th.free_size());
+ }
+ sleep(1);
+ EXPECT_EQ(0, th.queue_size());
+ EXPECT_EQ(5, th.free_size());
+}
+
+// Deleting the EventMgr when events are still pending should shut
+// down gracefully.
+TEST(EventMgr, NonEmptyShutdown) {
+ auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie();
+ EventMgr em(stream_exec);
+ TEST_EventMgrHelper th(&em);
+ EXPECT_EQ(0, th.queue_size());
+ EXPECT_EQ(0, th.free_size());
+ std::vector<Tensor>* v = nullptr;
+ std::unique_ptr<gpu::Stream> stream(new gpu::Stream(stream_exec));
+ CHECK(stream.get());
+ stream->Init();
+ for (int i = 0; i < 5; ++i) {
+ v = new std::vector<Tensor>;
+ th.QueueTensors(stream.get(), v);
+ EXPECT_EQ(1 + i, th.queue_size());
+ EXPECT_EQ(0, th.free_size());
+ }
+}
+
+} // namespace
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/common_runtime/gpu/gpu_init.cc b/tensorflow/core/common_runtime/gpu/gpu_init.cc
new file mode 100644
index 0000000000..631a47eb91
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_init.cc
@@ -0,0 +1,147 @@
+#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
+
+#include <string>
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/stream_executor/multi_platform_manager.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace gpu = ::perftools::gputools;
+
+namespace tensorflow {
+
+namespace {
+
+std::unique_ptr<std::map<std::pair<int, int>, bool>> GetPeerAccessMap(
+ gpu::Platform* platform, int device_count) {
+ auto* map = new std::map<std::pair<int, int>, bool>;
+ for (int i = 0; i < device_count; ++i) {
+ for (int j = 0; j < device_count; ++j) {
+ gpu::StreamExecutor* from = platform->ExecutorForDevice(i).ValueOrDie();
+ gpu::StreamExecutor* to = platform->ExecutorForDevice(j).ValueOrDie();
+ (*map)[{i, j}] = from->CanEnablePeerAccessTo(to);
+ }
+ }
+
+ return std::unique_ptr<std::map<std::pair<int, int>, bool>>{map};
+}
+
+Status EnablePeerAccess(gpu::Platform* platform, int device_count) {
+ for (int i = 0; i < device_count; ++i) {
+ for (int j = 0; j < device_count; ++j) {
+ gpu::StreamExecutor* from = platform->ExecutorForDevice(i).ValueOrDie();
+ gpu::StreamExecutor* to = platform->ExecutorForDevice(j).ValueOrDie();
+
+ if (from->CanEnablePeerAccessTo(to)) {
+ auto status = from->EnablePeerAccessTo(to);
+ if (!status.ok()) {
+ return errors::Internal(status.ToString());
+ }
+ } else {
+ LOG(INFO) << "cannot enable peer access from device ordinal " << i
+ << " to device ordinal " << j;
+ }
+ }
+ }
+ return Status::OK();
+}
+
+static void InitGPU() {
+ auto result = gpu::MultiPlatformManager::PlatformWithName("CUDA");
+ if (!result.ok()) {
+ LOG(WARNING)
+ << "Not initializing the GPU, could not create GPU MachineManager. "
+ << "Error: " << result.status();
+ return;
+ }
+
+ gpu::Platform* platform = result.ValueOrDie();
+
+ int dev_count = platform->VisibleDeviceCount();
+
+ if (dev_count == 0) {
+ LOG(INFO) << "No GPU devices available on machine.";
+ return;
+ }
+
+ for (int i = 0; i < dev_count; ++i) {
+ auto stream_exec = platform->ExecutorForDevice(i).ValueOrDie();
+ int64 free_bytes;
+ int64 total_bytes;
+ if (!stream_exec->DeviceMemoryUsage(&free_bytes, &total_bytes)) {
+ // Logs internally on failure.
+ free_bytes = 0;
+ total_bytes = 0;
+ }
+ const auto& description = stream_exec->GetDeviceDescription();
+ int cc_major;
+ int cc_minor;
+ if (!description.cuda_compute_capability(&cc_major, &cc_minor)) {
+ // Logs internally on failure.
+ cc_major = 0;
+ cc_minor = 0;
+ }
+ LOG(INFO) << "Found device " << i << " with properties: "
+ << "\nname: " << description.name() << "\nmajor: " << cc_major
+ << " minor: " << cc_minor << " memoryClockRate (GHz) "
+ << description.clock_rate_ghz() << "\npciBusID "
+ << description.pci_bus_id() << "\nTotal memory: "
+ << strings::HumanReadableNumBytes(total_bytes)
+ << "\nFree memory: "
+ << strings::HumanReadableNumBytes(free_bytes);
+ }
+
+ // Enable peer access
+
+ auto status = EnablePeerAccess(platform, dev_count);
+ if (!status.ok()) {
+ LOG(FATAL) << "could not enable peer access for GPU devices: " << status;
+ }
+
+ // Print out a matrix showing which devices can DMA to one
+ // another.
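+  // On a hypothetical 2-GPU machine the log looks roughly like:
+  //   DMA: 0 1
+  //   0:   Y N
+  //   1:   N Y
+  // where the actual Y/N values depend on the hardware topology.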
+ auto access_map = GetPeerAccessMap(platform, dev_count);
+ string line_buf = "DMA: ";
+ for (int i = 0; i < dev_count; ++i) {
+ strings::StrAppend(&line_buf, i, " ");
+ }
+ LOG(INFO) << line_buf;
+ for (int i = 0; i < dev_count; ++i) {
+ line_buf = strings::StrCat(i, ": ");
+ for (int j = 0; j < dev_count; ++j) {
+ if ((*access_map)[{i, j}]) {
+ line_buf.append("Y ");
+ } else {
+ line_buf.append("N ");
+ }
+ }
+ LOG(INFO) << line_buf;
+ }
+}
+
+static bool InitModule() {
+ InitGPU();
+ return true;
+}
+
+} // namespace
+
+gpu::Platform* GPUMachineManager() {
+ // Create the machine manager singleton and initialize the GPUs only
+ // once.
+ static bool init = InitModule();
+ CHECK(init); // Avoids compiler warning that init is unused.
+
+ auto result = gpu::MultiPlatformManager::PlatformWithName("CUDA");
+ if (!result.ok()) {
+ return nullptr;
+ }
+
+ return result.ValueOrDie();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_init.h b/tensorflow/core/common_runtime/gpu/gpu_init.h
new file mode 100644
index 0000000000..d126a8b1ca
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_init.h
@@ -0,0 +1,19 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_INIT_H_
+#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_INIT_H_
+
+namespace perftools {
+namespace gputools {
+class Platform;
+} // namespace gputools
+} // namespace perftools
+
+namespace tensorflow {
+
+// Returns the GPU machine manager singleton, creating it and
+// initializing the GPUs on the machine if needed the first time it is
+// called.
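+//
+// Example (a sketch; callers should handle the nullptr returned when no
+// CUDA platform is available):
+//   perftools::gputools::Platform* platform = GPUMachineManager();
+//   if (platform != nullptr) {
+//     const int num_gpus = platform->VisibleDeviceCount();
+//   }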
+perftools::gputools::Platform* GPUMachineManager();
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_INIT_H_
diff --git a/tensorflow/core/common_runtime/gpu/gpu_region_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_region_allocator.cc
new file mode 100644
index 0000000000..08ff55e221
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_region_allocator.cc
@@ -0,0 +1,371 @@
+#include "tensorflow/core/common_runtime/gpu/gpu_region_allocator.h"
+
+//#include "base/commandlineflags.h"
+#include "tensorflow/stream_executor/multi_platform_manager.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_allocator_retry.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
+#include "tensorflow/core/lib/core/bits.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+
+#if defined(PLATFORM_GOOGLE)
+DEFINE_bool(brain_gpu_region_allocator_heap_check_on_destruction, true,
+ "If true, the CUDA gpu manager checks that all allocated "
+ "memory through the GPU memory pool implementation has been "
+ "freed.");
+
+DEFINE_int64(brain_gpu_region_allocator_region_size, 0,
+ "If > 0, sets the default chunk-size allocatable from GPU memory. "
+ "Else defaults to entire GPU memory.");
+
+#else
+bool FLAGS_brain_gpu_region_allocator_heap_check_on_destruction = true;
+tensorflow::int64 FLAGS_brain_gpu_region_allocator_region_size = 0;
+#endif
+
+namespace gpu = ::perftools::gputools;
+
+namespace tensorflow {
+
+GPURegionAllocator::GPURegionAllocator(int device_id, size_t total_bytes)
+ : device_id_(device_id), total_bytes_(total_bytes) {
+ // Get a pointer to the stream_executor for this device
+ stream_exec_ = GPUMachineManager()->ExecutorForDevice(device_id).ValueOrDie();
+
+ // Set the region size based on explicit user request, or based on
+ // total GPU capacity.
+ if (FLAGS_brain_gpu_region_allocator_region_size > 0) {
+ region_size_ = FLAGS_brain_gpu_region_allocator_region_size;
+ } else {
+ region_size_ = static_cast<size_t>(total_bytes_);
+ }
+
+ LOG(INFO) << "Setting region size to " << region_size_;
+}
+
+GPURegionAllocator::~GPURegionAllocator() {
+ if (FLAGS_brain_gpu_region_allocator_heap_check_on_destruction) {
+ CheckForMemoryLeaks();
+ }
+
+ gtl::STLDeleteValues(&chunk_map_);
+
+ for (auto r : regions_) {
+ gpu::DeviceMemoryBase gpu_ptr{r->ptr};
+ stream_exec_->Deallocate(&gpu_ptr);
+ delete r;
+ }
+}
+
+void* GPURegionAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
+ static const int64 kMaxMillisToWait = 10000; // 10 seconds
+ return retry_helper_.AllocateRaw(
+ [this](size_t a, size_t nb, bool v) {
+ return AllocateRawInternal(a, nb, v);
+ },
+ kMaxMillisToWait, alignment, num_bytes);
+}
+
+void* GPURegionAllocator::AllocateRawInternal(size_t alignment,
+ size_t num_bytes,
+ bool dump_log_on_failure) {
+ if (num_bytes == 0) {
+ LOG(ERROR) << "tried to allocate 0 bytes";
+ return nullptr;
+ }
+ size_t chunk_size = ChunkSize(num_bytes);
+
+ VLOG(2) << "chunk_size " << chunk_size << " from num_bytes "
+ << strings::HumanReadableNumBytes(num_bytes);
+ mutex_lock l(lock_);
+ Pool* pool = &pools_[chunk_size];
+ if (pool->num_free == 0) {
+ if (!ExpandPool(pool, chunk_size, num_bytes, dump_log_on_failure)) {
+ if (dump_log_on_failure) {
+ LOG(WARNING) << "Out of GPU memory, see memory state dump above";
+ }
+ return nullptr;
+ }
+ }
+ CHECK_LT(0, pool->num_free);
+ CHECK(pool->first);
+ CHECK(pool->last);
+ Chunk* c = pool->first;
+ CHECK(c);
+ CHECK(!c->in_use);
+
+ c->in_use = true;
+ // Move c to the back of the queue.
+ if (c->next != nullptr) {
+ pool->first = c->next;
+ pool->first->prev = nullptr;
+ c->next = nullptr;
+ }
+
+ if (pool->last != c) {
+ pool->last->next = c;
+ c->prev = pool->last;
+ pool->last = c;
+ }
+ pool->num_free--;
+ pool->cumulative_malloced++;
+
+ void* rv = c->ptr;
+ c->bytes_allocated = num_bytes;
+
+ VLOG(2) << "new ptr " << rv;
+ return rv;
+}
+
+void GPURegionAllocator::DeallocateRaw(void* ptr) {
+ retry_helper_.DeallocateRaw([this](void* p) { DeallocateRawInternal(p); },
+ ptr);
+}
+
+void GPURegionAllocator::DeallocateRawInternal(void* ptr) {
+ VLOG(2) << "DeallocateRaw: " << ptr;
+ if (ptr == nullptr) {
+ LOG(ERROR) << "tried to deallocate nullptr";
+ return;
+ }
+
+ mutex_lock l(lock_);
+ ChunkMap::const_iterator iter = chunk_map_.find(ptr);
+ CHECK(iter != chunk_map_.end());
+
+ Chunk* c = iter->second;
+ VLOG(2) << "chunk of size " << c->size << " at " << c;
+
+ Pool* pool = &(pools_[c->size]);
+ // Move chunk to head of queue, and mark free.
+ DCHECK(c->in_use);
+ c->in_use = false;
+ if (c->prev) c->prev->next = c->next;
+ if (c->next) c->next->prev = c->prev;
+ if (pool->first == c) pool->first = c->next;
+ if (pool->last == c) pool->last = c->prev;
+ c->next = pool->first;
+ c->prev = nullptr;
+ if (c->next) c->next->prev = c;
+ pool->first = c;
+ if (pool->last == nullptr) pool->last = c;
+ pool->num_free++;
+ pool->cumulative_freed++;
+}
+
+bool GPURegionAllocator::ExpandPool(Pool* pool, size_t chunk_size,
+ size_t requested_size,
+ bool dump_log_on_failure) {
+ VLOG(1) << "ExpandPool of " << chunk_size << " from " << pool->num_chunks
+ << " current members";
+ DCHECK_NE(0, chunk_size);
+ // If chunk_size is < 4096, double the pool size. Otherwise
+ // just increase by one.
+ int num_chunks = pool->num_chunks;
+ if (num_chunks == 0) {
+ if (chunk_size > 4096) {
+ num_chunks = 1;
+ } else {
+ num_chunks = 4096 / chunk_size;
+ }
+ }
+ // For larger chunks, limit the amount of expansion.
+ size_t aggregate_size = num_chunks * chunk_size;
+ if (aggregate_size > (1 << 20)) {
+ num_chunks = static_cast<int>(
+ std::max(static_cast<size_t>(1), (1 << 20) / chunk_size));
+ }
+ while (num_chunks > 0) {
+ Region* r = (regions_.empty() ? nullptr : regions_.back());
+ if (r == nullptr ||
+ (((r->ptr + r->size) - r->next) < static_cast<int64>(chunk_size))) {
+ // Current region is not large enough to accommodate another chunk.
+ while (r == nullptr || (((r->ptr + r->size) - r->next) <
+ static_cast<int64>(chunk_size))) {
+ // Get another region.
+ size_t this_region_size = std::max(region_size_, chunk_size);
+
+ // Check if we would exceed our limit.
+ if (allocated_memory_ + this_region_size > total_bytes_) {
+ if (dump_log_on_failure) DumpMemoryLog();
+ return false;
+ }
+
+ // Perform the allocation, still checking that the allocator
+ // has not run out of memory.
+ gpu::DeviceMemory<char> gpu_mem =
+ stream_exec_->AllocateArray<char>(this_region_size);
+ if (gpu_mem == nullptr) {
+ if (dump_log_on_failure) DumpMemoryLog();
+ return false;
+ }
+
+ // We never release memory once expanded.
+ allocated_memory_ += this_region_size;
+
+ Region* nr = new Region;
+ nr->ptr = static_cast<char*>(gpu_mem.opaque());
+
+ if (VLOG_IS_ON(2)) {
+ int64 free_bytes;
+ int64 total_bytes;
+ if (stream_exec_->DeviceMemoryUsage(&free_bytes, &total_bytes)) {
+ VLOG(2) << "free " << free_bytes << " total " << total_bytes;
+ } else {
+ // Note: stream_exec call also logs internally on failure.
+ VLOG(2) << "could not retrieve memory usage";
+ }
+ }
+ VLOG(1) << "new Region of size " << this_region_size << " at "
+ << static_cast<void*>(nr->ptr) << " on device " << device_id_;
+ r = nr;
+ r->size = this_region_size;
+ r->next = r->ptr;
+ regions_.push_back(r);
+
+ for (auto visitor : region_visitors_) {
+ visitor(r->ptr, r->size);
+ }
+ }
+ } else {
+ // Allocate a new chunk and push on front of Pool.
+ Chunk* c = new Chunk;
+ c->ptr = r->next;
+ chunk_map_[c->ptr] = c;
+ c->size = chunk_size;
+ r->next += chunk_size;
+ c->next = pool->first;
+ if (c->next != nullptr) c->next->prev = c;
+ pool->first = c;
+ if (pool->last == nullptr) pool->last = c;
+ pool->num_chunks++;
+ pool->num_free++;
+ --num_chunks;
+ }
+ }
+
+ return true;
+}
+
+void GPURegionAllocator::CheckForMemoryLeaks() {
+ std::vector<string> errors;
+ mutex_lock l(lock_); // could use reader lock
+ for (auto pool_map : pools_) {
+ const Pool& p = pool_map.second;
+ Chunk* curr_chunk = p.first;
+ while (curr_chunk != nullptr) {
+ if (curr_chunk->in_use) {
+ errors.push_back(
+ strings::StrCat("Unfreed chunk of size ", curr_chunk->size));
+ }
+ curr_chunk = curr_chunk->next;
+ }
+ }
+ if (!errors.empty()) {
+ LOG(FATAL) << "GPU Memory leaks:\n" << str_util::Join(errors, "\n");
+ }
+}
+
+// Since there's no merging of chunks once allocated, we want to
+// maximize their reusability (which argues for fewer, larger sizes),
+// while minimizing waste (which argues for tight-fitting sizes).
+//
+// The smallest unit of allocation is 256 bytes.
+// NOTE(tucker): akrizhevsky says that nvidia's memory manager always
+// aligns to 256 bytes, and doing so results in significant speedup.
+//
+// Up to 2^16 bytes we only allocate in powers of 2.
+//
+// Above that, we pick a max-waste which is the largest power
+// of 2 <= 1/16 of the requested size, then round up to the nearest
+// multiple of max_waste.
+//
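+// For example (values follow directly from the rules above):
+//   ChunkSize(100)    == 256     // minimum allocation unit
+//   ChunkSize(1000)   == 1024    // next power of 2
+//   ChunkSize(100000) == 106496  // rounded up to a multiple of 8192 (2^17/16)
+//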
+// static
+size_t GPURegionAllocator::ChunkSize(size_t bytes) {
+ if (bytes <= 256) {
+ return 256;
+ } else if (bytes <= (1 << 16)) {
+ return 1uLL << Log2Ceiling64(bytes);
+ } else {
+ // 1/16th of requested size
+ size_t max_waste = 1uLL << (Log2Ceiling64(bytes) - 4);
+ return (bytes + max_waste) & (~(max_waste - 1));
+ }
+}
+
+void GPURegionAllocator::AddAllocVisitor(Visitor visitor) {
+ VLOG(1) << "AddVisitor";
+ mutex_lock l(lock_);
+ region_visitors_.push_back(visitor);
+ for (auto region : regions_) {
+ visitor(region->ptr, region->size);
+ }
+}
+
+void GPURegionAllocator::DumpMemoryLog() {
+ size_t region_bytes = 0;
+ for (auto r : regions_) {
+ region_bytes += r->size;
+ }
+ size_t chunk_bytes = 0;
+ std::vector<size_t> chunk_sizes;
+ for (auto i : pools_) {
+ chunk_sizes.push_back(i.first);
+ }
+ std::sort(chunk_sizes.begin(), chunk_sizes.end());
+ for (auto i : chunk_sizes) {
+ int32 chunks_in_use = 0;
+ const Pool& p = pools_[i];
+ chunk_bytes += i * p.num_chunks;
+
+ if (p.num_chunks > 0) {
+ // Iterate backwards (allocated chunks are last).
+ Chunk* curr_chunk = p.last;
+ while (curr_chunk != nullptr) {
+ if (curr_chunk->in_use) {
+ ++chunks_in_use;
+ }
+ curr_chunk = curr_chunk->prev;
+ if (curr_chunk == p.first) {
+ break;
+ }
+ }
+ }
+
+ LOG(INFO) << "Chunk size: " << i << " ("
+ << strings::HumanReadableNumBytes(i) << ") Pool: " << p.ToString()
+ << "\nNumber of chunks: " << p.num_chunks
+ << ", in_use chunks: " << chunks_in_use;
+ }
+
+ LOG(INFO) << "Aggregate Region Memory: " << region_bytes << " ("
+ << strings::HumanReadableNumBytes(region_bytes) << ")";
+ LOG(INFO) << "Aggregate Chunk Memory: " << chunk_bytes << " ("
+ << strings::HumanReadableNumBytes(chunk_bytes) << ")";
+}
+
+bool GPURegionAllocator::TracksAllocationSizes() { return true; }
+
+size_t GPURegionAllocator::RequestedSize(void* ptr) {
+ mutex_lock l(lock_);
+ auto it = chunk_map_.find(ptr);
+ CHECK(it != chunk_map_.end())
+ << "Asked for requested size of pointer we never allocated: " << ptr;
+ auto c = it->second;
+ return c->bytes_allocated;
+}
+
+size_t GPURegionAllocator::AllocatedSize(void* ptr) {
+ mutex_lock l(lock_);
+ auto it = chunk_map_.find(ptr);
+ CHECK(it != chunk_map_.end())
+ << "Asked for allocated size of pointer we never allocated: " << ptr;
+ auto c = it->second;
+ return c->size;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_region_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_region_allocator.h
new file mode 100644
index 0000000000..1a250b6ede
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_region_allocator.h
@@ -0,0 +1,146 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_REGION_ALLOCATOR_H_
+#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_REGION_ALLOCATOR_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/stream_executor/stream_executor.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_allocator_retry.h"
+#include "tensorflow/core/common_runtime/gpu/visitable_allocator.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+
+namespace tensorflow {
+
+class GPURegionAllocator : public VisitableAllocator {
+ public:
+ // 'device_id' must be a valid device on the machine.
+ //
+ // total_bytes is how many bytes this allocator should allocate up
+ // to. This may be less than the total available.
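+  //
+  // Example (a sketch): an allocator for device 0 capped at 64 MiB:
+  //   GPURegionAllocator a(0, 64 << 20);
+  //   float* p = a.Allocate<float>(1024);
+  //   a.Deallocate(p);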
+ explicit GPURegionAllocator(int device_id, size_t total_bytes);
+ ~GPURegionAllocator() override;
+
+ string Name() override { return "gpu_region"; }
+ void* AllocateRaw(size_t alignment, size_t num_bytes) override;
+ void DeallocateRaw(void* ptr) override;
+ void AddAllocVisitor(Visitor visitor) override;
+ // Does nothing, because regions are never freed.
+ void AddFreeVisitor(Visitor visitor) override {}
+
+ bool TracksAllocationSizes() override;
+ size_t RequestedSize(void* ptr) override;
+ size_t AllocatedSize(void* ptr) override;
+
+ private:
+ // A Chunk is the header on a single piece of memory given back
+ // in response to an AllocateRaw() call.
+ struct Chunk {
+ char* ptr; // pointer to granted GPU buffer.
+ size_t size; // Full size of GPU buffer.
+ size_t bytes_allocated; // Bytes asked for by client.
+ bool in_use;
+ Chunk* prev; // Used for chaining in pool.
+ Chunk* next;
+ Chunk()
+ : ptr(nullptr),
+ size(0),
+ bytes_allocated(0),
+ in_use(false),
+ prev(nullptr),
+ next(nullptr) {}
+ };
+
+ // A Pool is a collection of same-sized Chunks.
+ struct Pool {
+ int num_chunks; // total chunks in this pool
+ int num_free; // total free chunks in this pool
+ int64 cumulative_malloced; // number of chunks malloced so far
+ int64 cumulative_freed; // number of chunks freed so far
+
+    // doubly-linked list of chunks; all free chunks precede all
+    // granted chunks
+ Chunk* first;
+ Chunk* last;
+ Pool()
+ : num_chunks(0),
+ num_free(0),
+ cumulative_malloced(0),
+ cumulative_freed(0),
+ first(nullptr),
+ last(nullptr) {}
+
+ string ToString() const {
+ return strings::StrCat("chunks: ", num_chunks, " free: ", num_free,
+ " cumulative malloc: ", cumulative_malloced,
+ " cumulative freed: ", cumulative_freed);
+ }
+ };
+
+ // A Region is a single area of GPU memory that has been
+ // reserved by this class and carved up into Chunks.
+ struct Region {
+ char* ptr; // base GPU ptr
+ char* next; // frontier of unused part of region
+ size_t size;
+ Region() : ptr(nullptr), size(0) {}
+ };
+
+ // Calculate size of chunk for an allocation of this size.
+  // Min chunk size is 256, for alignment.
+ // For larger sizes, we round up somewhat so there are fewer
+ // size-specific pools.
+ static size_t ChunkSize(size_t bytes);
+
+ void* AllocateRawInternal(size_t alignment, size_t num_bytes,
+ bool dump_log_on_failure);
+ void DeallocateRawInternal(void* ptr);
+
+ bool ExpandPool(Pool* p, size_t chunk_size, size_t requested_size,
+ bool dump_log_on_failure) EXCLUSIVE_LOCKS_REQUIRED(lock_);
+
+ // Inspects region maps and crashes with debug information if there
+ // are any memory leaks as detected by the region allocator.
+ void CheckForMemoryLeaks() LOCKS_EXCLUDED(lock_);
+
+ void DumpMemoryLog() EXCLUSIVE_LOCKS_REQUIRED(lock_);
+
+ perftools::gputools::StreamExecutor* stream_exec_; // Not owned.
+
+ typedef std::unordered_map<size_t, Pool> PoolMap;
+ typedef std::unordered_map<void*, Chunk*> ChunkMap;
+
+ GPUAllocatorRetry retry_helper_;
+ mutable mutex lock_;
+ PoolMap pools_ GUARDED_BY(lock_);
+
+ // Owns regions.
+ std::vector<Region*> regions_ GUARDED_BY(lock_);
+
+ // Maps from GPU ptr to Chunk owning it.
+ //
+ // Owns chunks.
+ ChunkMap chunk_map_ GUARDED_BY(lock_);
+
+ // Called once on each region, ASAP.
+ std::vector<Visitor> region_visitors_ GUARDED_BY(lock_);
+
+ const int device_id_;
+
+ // Total amount of memory (in bytes) available to this Allocator
+ const size_t total_bytes_;
+
+ // Total amount of memory allocated to regions.
+ size_t allocated_memory_ = 0;
+
+ size_t region_size_ = 0;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(GPURegionAllocator);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_REGION_ALLOCATOR_H_
diff --git a/tensorflow/core/common_runtime/gpu/gpu_region_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_region_allocator_test.cc
new file mode 100644
index 0000000000..07b0dd57f6
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_region_allocator_test.cc
@@ -0,0 +1,71 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/common_runtime/gpu/gpu_region_allocator.h"
+
+#include <algorithm>
+#include <vector>
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+#include <gtest/gtest.h>
+
+namespace gpu = ::perftools::gputools;
+
+namespace tensorflow {
+namespace {
+
+TEST(GPURegionAllocatorTest, Simple) {
+ GPURegionAllocator a(0, 1 << 26);
+ std::vector<void*> ptrs;
+ for (int s = 1; s < 1024; s++) {
+ void* raw = a.AllocateRaw(1, s);
+ ptrs.push_back(raw);
+ }
+ std::sort(ptrs.begin(), ptrs.end());
+ for (int i = 0; i < ptrs.size(); i++) {
+ if (i > 0) {
+ CHECK_NE(ptrs[i], ptrs[i - 1]); // No dups
+ }
+ a.DeallocateRaw(ptrs[i]);
+ }
+ float* t1 = a.Allocate<float>(1024);
+ double* t2 = a.Allocate<double>(1048576);
+ a.Deallocate(t1);
+ a.Deallocate(t2);
+}
+
+TEST(GPURegionAllocatorTest, CheckMemLeak) {
+ EXPECT_DEATH(
+ {
+ GPURegionAllocator a(0, 1 << 26);
+ float* t1 = a.Allocate<float>(1024);
+ if (t1) {
+ LOG(INFO) << "Not deallocating";
+ }
+ },
+ "");
+}
+
+TEST(GPURegionAllocatorTest, TracksSizes) {
+ GPURegionAllocator a(0, 1 << 26);
+ EXPECT_EQ(true, a.TracksAllocationSizes());
+}
+
+TEST(GPURegionAllocatorTest, AllocatedVsRequested) {
+ GPURegionAllocator a(0, 1 << 26);
+ float* t1 = a.Allocate<float>(1);
+ EXPECT_EQ(sizeof(float), a.RequestedSize(t1));
+
+  // Minimum allocation size is 256.
+ EXPECT_EQ(256, a.AllocatedSize(t1));
+
+ a.Deallocate(t1);
+}
+
+} // namespace
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/common_runtime/gpu/gpu_stream_util.cc b/tensorflow/core/common_runtime/gpu/gpu_stream_util.cc
new file mode 100644
index 0000000000..ca86c7fa06
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_stream_util.cc
@@ -0,0 +1,97 @@
+#include "tensorflow/core/common_runtime/gpu/gpu_stream_util.h"
+
+#include <set>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+namespace gpu_stream_util {
+
+Status AssignStreams(const Graph* graph, const AssignStreamsOpts& opts,
+ std::unordered_map<int, int>* node_to_stream_id) {
+ VLOG(1) << "AssignStreams";
+ Status status;
+
+ // Sanity check arguments.
+ if (graph == nullptr)
+ status.Update(errors::InvalidArgument("Bad graph argument supplied."));
+ if (node_to_stream_id == nullptr) {
+ status.Update(
+ errors::InvalidArgument("Bad node_to_stream_id argument supplied."));
+ }
+ if ((opts.max_streams < 1) || (opts.send_stream >= opts.max_streams) ||
+ (opts.recv_stream >= opts.max_streams) ||
+ (opts.const_stream >= opts.max_streams) ||
+ (opts.compute_stream >= opts.max_streams)) {
+ status.Update(errors::InvalidArgument("Bad graph argument supplied."));
+ }
+ TF_RETURN_IF_ERROR(status);
+
+ // Topologically sort the nodes.
+ std::vector<Node*> order;
+ GetReversePostOrder(*graph, &order);
+ if (VLOG_IS_ON(2)) {
+ for (Node* n : order) {
+ const int node_id = n->id();
+ VLOG(2) << "Node " << node_id << " " << n->type_string() << " "
+ << n->name() << " " << n->in_edges().size() << " inputs";
+ for (const Edge* e : n->in_edges()) {
+ VLOG(2) << " Edge from " << e->src()->id() << " " << e->src()->name()
+ << " fanout " << e->src()->out_edges().size();
+ }
+ }
+ }
+  // We perform stream assignment assuming a large number of
+ // stream IDs and then map these down to the required number of streams
+ // using simple round-robin.
+ // Stream Assignment strategy:
+  // 1. Nodes with zero inputs are always executed on a
+ // fresh stream.
+ // 2. Try to execute a node on the same stream as one of its
+ // inputs to avoid inter-stream dependencies.
+  // 3. If any input comes from a node with a large fanout, that is
+  //    perhaps an indication that it is shared between parallel
+ // streams of work. We choose a new stream here so that all consumers
+ // of the tensor are likely to run in parallel.
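+  //
+  // For example, with max_streams = 2 and a diamond graph A -> {B, C} -> D
+  // (A has fanout 2), A starts on candidate stream 0, B and C each get a
+  // fresh candidate stream (1 and 2), and D reuses the stream of a fanout-1
+  // input; the modulo below then folds the candidates into streams {0, 1}.
+  // (Illustrative; exact IDs depend on node ordering.)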
+ int highest_stream_id = -1;
+ for (Node* n : order) {
+ VLOG(3) << "Inspecting node " << n->DebugString();
+ const int node_id = n->id();
+ const string& op = n->type_string();
+
+ // Determine a suitable stream to use.
+ int stream_id = highest_stream_id + 1;
+ for (const Edge* e : n->in_edges()) {
+ const int fanout = e->src()->out_edges().size();
+ if (fanout == 1) {
+ stream_id = (*node_to_stream_id)[e->src()->id()];
+ break;
+ }
+ }
+ // Override stream for specific op types.
+ if (op == "_Send") {
+ if (opts.send_stream >= 0) stream_id = opts.send_stream;
+ } else if (op == "_Recv") {
+ if (opts.recv_stream >= 0) stream_id = opts.recv_stream;
+ } else if (op == "Const") {
+ if (opts.const_stream >= 0) stream_id = opts.const_stream;
+ } else {
+ if (opts.compute_stream >= 0) stream_id = opts.compute_stream;
+ }
+
+ (*node_to_stream_id)[node_id] = stream_id % opts.max_streams;
+ highest_stream_id = std::max(stream_id, highest_stream_id);
+ }
+ VLOG(1) << "Identified " << highest_stream_id << " candidate streams for "
+ << order.size() << " nodes.";
+
+ return Status::OK();
+}
+
+} // namespace gpu_stream_util
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_stream_util.h b/tensorflow/core/common_runtime/gpu/gpu_stream_util.h
new file mode 100644
index 0000000000..e1c623382c
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_stream_util.h
@@ -0,0 +1,30 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_STREAM_UTIL_H_
+#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_STREAM_UTIL_H_
+
+#include <unordered_map>
+
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+namespace gpu_stream_util {
+
+struct AssignStreamsOpts {
+ int32 max_streams = 1;
+ // The following options specify a stream to use for specific op
+ // types. The value -1 allows ops to be assigned to any stream.
+ int32 send_stream = -1;
+ int32 recv_stream = -1;
+ int32 const_stream = -1;
+ int32 compute_stream = -1;
+};
+
+// Given the input graph, assigns each node in the graph a stream_id
+// that it should use.
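+//
+// Example (a sketch): pin _Send and _Recv ops to dedicated streams while
+// folding all other ops into at most 4 streams:
+//   gpu_stream_util::AssignStreamsOpts opts;
+//   opts.max_streams = 4;
+//   opts.send_stream = 2;
+//   opts.recv_stream = 3;
+//   std::unordered_map<int, int> node_to_stream_id;
+//   TF_RETURN_IF_ERROR(
+//       gpu_stream_util::AssignStreams(&graph, opts, &node_to_stream_id));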
+Status AssignStreams(const Graph* graph, const AssignStreamsOpts& opts,
+ std::unordered_map<int, int>* node_to_stream_id);
+
+} // namespace gpu_stream_util
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_STREAM_UTIL_H_
diff --git a/tensorflow/core/common_runtime/gpu/gpu_stream_util_test.cc b/tensorflow/core/common_runtime/gpu/gpu_stream_util_test.cc
new file mode 100644
index 0000000000..5c426caaef
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_stream_util_test.cc
@@ -0,0 +1,137 @@
+#include "tensorflow/core/common_runtime/gpu/gpu_stream_util.h"
+
+#include <gtest/gtest.h>
+#include "tensorflow/cc/ops/array_ops.h"
+#include "tensorflow/cc/ops/sendrecv_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace tensorflow {
+namespace {
+
+class GpuStreamUtilTest : public OpsTestBase {
+ protected:
+ void SetUp() override { RequireDefaultOps(); }
+};
+
+TEST_F(GpuStreamUtilTest, BogusOpts) {
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ Graph g(OpRegistry::Global());
+ ASSERT_OK(b.ToGraph(&g));
+ std::unordered_map<int, int> node_to_stream_id;
+ gpu_stream_util::AssignStreamsOpts opts;
+ Status status;
+ status = gpu_stream_util::AssignStreams(nullptr, opts, &node_to_stream_id);
+ EXPECT_FALSE(status.ok());
+ status = gpu_stream_util::AssignStreams(&g, opts, nullptr);
+ EXPECT_FALSE(status.ok());
+ opts.max_streams = 0;
+ status = gpu_stream_util::AssignStreams(&g, opts, &node_to_stream_id);
+ EXPECT_FALSE(status.ok());
+ opts.max_streams = 1;
+ opts.compute_stream = 5;
+ status = gpu_stream_util::AssignStreams(&g, opts, &node_to_stream_id);
+ EXPECT_FALSE(status.ok());
+}
+
+TEST_F(GpuStreamUtilTest, EmptyGraph) {
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ Graph g(OpRegistry::Global());
+ ASSERT_OK(b.ToGraph(&g));
+ std::unordered_map<int, int> node_to_stream_id;
+ gpu_stream_util::AssignStreamsOpts opts;
+ ASSERT_OK(gpu_stream_util::AssignStreams(&g, opts, &node_to_stream_id));
+ EXPECT_EQ(2, node_to_stream_id.size()); // _SOURCE and _SINK
+}
+
+TEST_F(GpuStreamUtilTest, SimpleGraphOneStream) {
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::MatMul(ops::Const(Tensor(DT_FLOAT), b.opts()),
+ ops::Const(Tensor(DT_FLOAT), b.opts()), b.opts());
+ Graph g(OpRegistry::Global());
+ ASSERT_OK(b.ToGraph(&g));
+
+ std::unordered_map<int, int> node_to_stream_id;
+ gpu_stream_util::AssignStreamsOpts opts;
+ ASSERT_OK(gpu_stream_util::AssignStreams(&g, opts, &node_to_stream_id));
+
+ // There should be 5 nodes assigned.
+ EXPECT_EQ(5, node_to_stream_id.size());
+
+ // All of them should have stream 0.
+ for (const auto& it : node_to_stream_id) {
+ EXPECT_EQ(0, it.second);
+ }
+}
+
+TEST_F(GpuStreamUtilTest, SimpleGraphManyStreams) {
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::MatMul(ops::Const(Tensor(DT_FLOAT), b.opts()),
+ ops::Const(Tensor(DT_FLOAT), b.opts()), b.opts());
+ Graph g(OpRegistry::Global());
+ ASSERT_OK(b.ToGraph(&g));
+
+ std::unordered_map<int, int> node_to_stream_id;
+ gpu_stream_util::AssignStreamsOpts opts;
+ opts.max_streams = 3;
+ ASSERT_OK(gpu_stream_util::AssignStreams(&g, opts, &node_to_stream_id));
+
+ // There should be 5 nodes assigned.
+ EXPECT_EQ(5, node_to_stream_id.size());
+
+ // All of them should have a stream in the range [0..max_streams).
+ for (const auto& it : node_to_stream_id) {
+ EXPECT_GE(it.second, 0);
+ EXPECT_LT(it.second, opts.max_streams);
+ }
+}
+
+TEST_F(GpuStreamUtilTest, StreamOverrides) {
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::_Recv(DT_FLOAT, "input", "/cpu:0", 0, "/gpu:0",
+ b.opts().WithName("input"));
+ auto n = ops::MatMul(ops::Const(Tensor(DT_FLOAT), b.opts()),
+ ops::Const(Tensor(DT_FLOAT), b.opts()), b.opts());
+ ops::_Send(n, "output", "/gpu:0", 0, "/cpu:0", b.opts().WithName("output"));
+ Graph g(OpRegistry::Global());
+ ASSERT_OK(b.ToGraph(&g));
+
+ // Perform stream assignment using a large number of streams, but with
+ // op types constrained to specific streams.
+ std::unordered_map<int, int> node_to_stream_id;
+ gpu_stream_util::AssignStreamsOpts opts;
+ opts.max_streams = 100;
+ opts.const_stream = 90;
+ opts.send_stream = 91;
+ opts.recv_stream = 92;
+ opts.compute_stream = 93;
+ ASSERT_OK(gpu_stream_util::AssignStreams(&g, opts, &node_to_stream_id));
+
+ // There should be 7 nodes assigned.
+ EXPECT_EQ(7, node_to_stream_id.size()); // including _SOURCE and _SINK
+
+ // Nodes should be assigned to streams by op type.
+ for (const auto& it : node_to_stream_id) {
+ Node* n = g.FindNodeId(it.first);
+ const string op = n->type_string();
+ const int stream = it.second;
+ if (op == "Const") {
+ EXPECT_EQ(stream, 90);
+ } else if (op == "_Send") {
+ EXPECT_EQ(stream, 91);
+ } else if (op == "_Recv") {
+ EXPECT_EQ(stream, 92);
+ } else { // Compute.
+ EXPECT_EQ(stream, 93);
+ }
+ }
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_util.cc b/tensorflow/core/common_runtime/gpu/gpu_util.cc
new file mode 100644
index 0000000000..a6a3ce01fc
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_util.cc
@@ -0,0 +1,345 @@
+#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
+
+//#include "base/commandlineflags.h"
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/gpu_device_context.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/tensor_coding.h"
+#include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/common_runtime/gpu/dma_helper.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
+#include "tensorflow/core/common_runtime/gpu/process_state.h"
+#include "tensorflow/core/util/util.h"
+#include "tensorflow/stream_executor/stream.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+
+#include "tensorflow/core/platform/stream_executor_util.h"
+
+#if defined(PLATFORM_GOOGLE)
+DEFINE_int64(brain_gpu_util_debug_string_maxlen, 128,
+ "When dumping gpu memory, prints up to this many bytes.");
+
+DECLARE_bool(record_mem_types);
+#else
+tensorflow::int64 FLAGS_brain_gpu_util_debug_string_maxlen = 128;
+bool FLAGS_EXPERIMENTAL_brain_gpu_multi_stream = false;
+extern bool FLAGS_record_mem_types;
+#endif
+
+using perftools::gputools::DeviceMemoryBase;
+using perftools::gputools::DeviceMemory;
+using perftools::gputools::Stream;
+
+namespace tensorflow {
+
+namespace gpu = ::perftools::gputools;
+
+/*static*/
+void GPUUtil::SetProtoFromGPU(const Tensor& tensor, Device* dev,
+ const DeviceContext* device_context,
+ TensorProto* proto, bool is_dead,
+ StatusCallback done) {
+ VLOG(1) << "SetProtoFromGPU device_context " << device_context;
+ // Tensor values need to be copied from GPU to CPU ram so that
+ // we can build the protobuf response for a RecvTensor RPC.
+ // "device context" identifies the stream where the _Send op executed.
+ CHECK(device_context);
+ gpu::Stream* stream =
+ static_cast<const GPUDeviceContext*>(device_context)->stream();
+
+ if (!DMAHelper::CanUseDMA(&tensor)) {
+ done(errors::Internal(strings::StrCat(
+ "GPU copy from non-DMA ", DataTypeString(tensor.dtype()), "tensor")));
+ return;
+ }
+ proto->set_dtype(tensor.dtype());
+ tensor.shape().AsProto(proto->mutable_tensor_shape());
+ // Prepare a Cord with the right data buf size, and DMA the
+ // data over from the GPU buffer. Note that 0-size tensors
+ // do not have a backing buffer.
+ const size_t num_bytes = is_dead ? 0 : tensor.TotalBytes();
+ if (num_bytes > 0) {
+ port::Tracing::ScopedAnnotation annotation("SetProtoFromGPU");
+ Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
+ char* mb = alloc->Allocate<char>(num_bytes);
+ const char* src_ptr =
+ reinterpret_cast<const char*>(DMAHelper::base(&tensor));
+ DeviceMemoryBase gpu_src_ptr(const_cast<char*>(src_ptr), num_bytes);
+ stream->ThenMemcpy(mb, gpu_src_ptr, num_bytes);
+ // Use of tensor may outlive stack scope, so keep a ref.
+ Tensor* tensor_ref = new Tensor(tensor);
+ dev->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
+ stream, [stream, done, proto, mb, num_bytes, alloc, tensor_ref]() {
+ if (!stream->ok()) {
+ done(errors::Internal("SetProtoFromGPU: GPU Memcpy failed"));
+ // TODO(pbar) We currently have no way to recover the
+ // worker from a GPU stream in the error state. Until
+ // there is a way to reset the CUDA driver, it is
+ // preferable to crash the process and restart. Tracked
+ // under b/23717097
+ LOG(FATAL) << "SetProtoFromGPU: GPU Memcpy failed";
+ return;
+ }
+ delete tensor_ref;
+ port::CopyFromArray(proto->mutable_tensor_content(), mb, num_bytes);
+ alloc->Deallocate<char>(mb);
+ done(Status::OK());
+ });
+ } else {
+ done(Status::OK());
+ }
+}
+
+typedef ProcessState::MemDesc PMD;
+
+/*static*/
+void GPUUtil::CopyViaDMA(const string& edge_name,
+ DeviceContext* send_dev_context,
+ DeviceContext* recv_dev_context, Device* src,
+ Device* dst, AllocatorAttributes src_alloc_attr,
+ AllocatorAttributes dst_alloc_attr,
+ const Tensor* input, Tensor* output,
+ StatusCallback done) {
+ port::Tracing::ScopedAnnotation annotation(edge_name);
+ VLOG(1) << "CopyViaDMA " << edge_name;
+ size_t total_bytes = input->TotalBytes();
+ // Note that 0-size tensors have no backing buffer.
+ if (total_bytes > 0) {
+ const void* src_ptr = DMAHelper::base(input);
+ void* dst_ptr = DMAHelper::base(output);
+ VLOG(2) << "src_ptr " << src_ptr << " dst_ptr " << dst_ptr;
+ if (FLAGS_record_mem_types) {
+ ProcessState::MemDesc smd = ProcessState::singleton()->PtrType(src_ptr);
+ ProcessState::MemDesc dmd = ProcessState::singleton()->PtrType(dst_ptr);
+ VLOG(0) << "Src " << smd.DebugString() << " Dst " << dmd.DebugString();
+ if (smd.loc == PMD::CPU && dmd.loc == PMD::GPU && (!smd.gpu_registered)) {
+ LOG(WARNING) << "CPU -> GPU no reg for " << edge_name;
+ }
+ if (dmd.loc == PMD::CPU && smd.loc == PMD::GPU && (!dmd.gpu_registered)) {
+ LOG(WARNING) << "GPU -> CPU no reg for " << edge_name;
+ }
+ }
+
+ auto src_device_type = src->attributes().device_type();
+ auto dst_device_type = dst->attributes().device_type();
+
+ bool non_cpu_src = (!src_alloc_attr.on_host() &&
+ src_device_type != DeviceType(DEVICE_CPU).type());
+ bool non_cpu_dst = (!dst_alloc_attr.on_host() &&
+ dst_device_type != DeviceType(DEVICE_CPU).type());
+ if (non_cpu_src) {
+ gpu::Stream* stream = send_dev_context->stream();
+ if (stream == nullptr) {
+ done(errors::Internal("Failed to find device stream"));
+ return;
+ }
+ auto* src_dev_info = src->tensorflow_gpu_device_info();
+ CHECK(src_dev_info);
+
+ if (non_cpu_dst) {
+ // Device to device copy
+ DeviceMemoryBase gpu_dst_ptr(dst_ptr, total_bytes);
+ stream->ThenMemcpy(
+ &gpu_dst_ptr,
+ DeviceMemoryBase{const_cast<void*>(src_ptr), total_bytes},
+ total_bytes);
+ if (dst_device_type == DeviceType(DEVICE_GPU).type()) {
+ // Use of input may outlive stack scope, so keep a ref.
+ Tensor* input_ref = new Tensor(*input);
+ src_dev_info->event_mgr->ThenExecute(
+ stream, [done, stream, input_ref]() {
+ delete input_ref;
+ if (!stream->ok()) {
+ done(errors::Internal("GPU->GPU Memcpy failed"));
+ } else {
+ done(Status::OK());
+ }
+ });
+ }
+ send_dev_context->MaintainLifetimeOnStream(input, stream);
+ } else {
+ // Device to host copy.
+ return send_dev_context->CopyDeviceTensorToCPU(input, edge_name, src,
+ output, done);
+ }
+ } else if (non_cpu_dst) {
+ // Host to Device copy.
+ // Note that this is already an async copy.
+ recv_dev_context->CopyCPUTensorToDevice(input, dst, output, done);
+ } else {
+ memcpy(dst_ptr, src_ptr, total_bytes);
+ done(Status::OK());
+ }
+ } else {
+ // buffer is empty
+ done(Status::OK());
+ }
+}
+
+void GPUUtil::CopyGPUTensorToCPU(Device* gpu_device,
+ const DeviceContext* device_context,
+ const Tensor* gpu_tensor, Tensor* cpu_tensor,
+ StatusCallback done) {
+ VLOG(1) << "CopyGPUTensorToCPU";
+ size_t total_bytes = gpu_tensor->TotalBytes();
+ // Note that 0-size tensors have no backing buffer.
+ if (total_bytes > 0) {
+ const void* src_ptr = DMAHelper::base(gpu_tensor);
+ void* dst_ptr = DMAHelper::base(cpu_tensor);
+ CHECK(dst_ptr);
+ auto* stream = gpu_device->tensorflow_gpu_device_info()->stream;
+ if (device_context) {
+ stream = static_cast<const GPUDeviceContext*>(device_context)->stream();
+ }
+ stream->ThenMemcpy(
+ dst_ptr, DeviceMemoryBase{const_cast<void*>(src_ptr), total_bytes},
+ total_bytes);
+ stream->BlockHostUntilDone();
+ if (!stream->ok()) {
+ done(errors::Internal("CopyGPUTensorToCPU: GPU->CPU Memcpy failed"));
+ return;
+ }
+ }
+
+ done(Status::OK());
+}
+
+/* static */
+void GPUUtil::CopyCPUTensorToGPU(const Tensor* cpu_tensor,
+ const DeviceContext* device_context,
+ Device* gpu_device, Tensor* gpu_tensor,
+ StatusCallback done) {
+ VLOG(1) << "CopyCPUTensorToGPU";
+ CHECK(DeviceType(gpu_device->attributes().device_type()) ==
+ DeviceType(DEVICE_GPU));
+
+ auto* dev_info = gpu_device->tensorflow_gpu_device_info();
+ if (!dev_info) {
+ done(errors::Internal("Failed to find dest device GPUDeviceInfo"));
+ return;
+ }
+ if (cpu_tensor->TotalBytes() != gpu_tensor->TotalBytes()) {
+ done(errors::Internal(
+ strings::StrCat("Can't copy ", cpu_tensor->TotalBytes(),
+ " bytes of a tensor into another with ",
+ gpu_tensor->TotalBytes(), " bytes buffer.")));
+ return;
+ }
+ const int64 total_bytes = cpu_tensor->TotalBytes();
+ // Note that 0-size tensors have no backing buffer.
+ if (total_bytes > 0) {
+ const void* src_ptr = DMAHelper::base(cpu_tensor);
+ void* dst_ptr = DMAHelper::base(gpu_tensor);
+ DeviceMemoryBase gpu_dst_ptr(dst_ptr, total_bytes);
+
+ CHECK(device_context);
+ auto* stream =
+ static_cast<const GPUDeviceContext*>(device_context)->stream();
+ stream->ThenMemcpy(&gpu_dst_ptr, src_ptr, total_bytes);
+ auto* dev_info = gpu_device->tensorflow_gpu_device_info();
+ // Use of cpu_tensor may outlive stack scope, so keep a ref.
+ Tensor* input_ref = new Tensor(*cpu_tensor);
+ dev_info->event_mgr->ThenExecute(stream, [stream, done, input_ref]() {
+ delete input_ref;
+ if (!stream->ok()) {
+ done(errors::Internal("CopyCPUTensorToGPU: GPU Memcpy failed"));
+ } else {
+ done(Status::OK());
+ }
+ });
+ } else {
+ // empty tensor case
+ done(Status::OK());
+ }
+}
+
+Status GPUUtil::Sync(Device* gpu_device) {
+ VLOG(1) << "GPUUtil::Sync";
+ auto* dev_info = gpu_device->tensorflow_gpu_device_info();
+ if (!dev_info) {
+ return errors::Internal("Failed to find dest device GPUDeviceInfo");
+ }
+ dev_info->stream->BlockHostUntilDone();
+ if (!dev_info->stream->ok()) {
+ LOG(FATAL) << "GPU sync failed";
+ }
+ return Status::OK();
+}
+
+Status GPUUtil::SyncAll(Device* gpu_device) {
+ VLOG(1) << "GPUUtil::SyncAll";
+ auto* dev_info = gpu_device->tensorflow_gpu_device_info();
+ if (!dev_info) {
+ return errors::Internal("Failed to find dest device GPUDeviceInfo");
+ }
+ if (!dev_info->stream->parent()->SynchronizeAllActivity() ||
+ !dev_info->stream->ok()) {
+ LOG(FATAL) << "GPU sync failed";
+ }
+ return Status::OK();
+}
+
+string GPUUtil::MemoryDebugString(const Device* device, Tensor* tensor) {
+ string ret;
+ CHECK(tensor);
+ const int64 num_bytes = std::min<int64>(
+ FLAGS_brain_gpu_util_debug_string_maxlen, tensor->TotalBytes());
+ void* ptr = (num_bytes > 0) ? DMAHelper::base(tensor) : nullptr;
+ strings::Appendf(&ret, "%p:", ptr);
+ if (num_bytes > 0) {
+ auto* dev_info = device->tensorflow_gpu_device_info();
+ if (!dev_info) {
+ strings::StrAppend(
+ &ret, PrintMemory(reinterpret_cast<const char*>(ptr), num_bytes));
+ } else {
+ string buf;
+ buf.resize(num_bytes);
+ DeviceMemoryBase gpu_ptr(ptr, num_bytes);
+ Status s = dev_info->stream->parent()->SynchronousMemcpyD2H(
+ gpu_ptr, num_bytes, gtl::string_as_array(&buf));
+ strings::StrAppend(&ret,
+ PrintMemory(gtl::string_as_array(&buf), num_bytes));
+ }
+ }
+ return ret;
+}
+
+// TODO(pbar) Checksum is called from places without a valid device context.
+uint64 GPUUtil::Checksum(Device* gpu_device,
+ const DeviceContext* device_context,
+ const Tensor& tensor) {
+ Tensor copy(tensor.dtype(), tensor.shape());
+ Status s;
+ Notification n;
+ CopyGPUTensorToCPU(gpu_device, device_context, &tensor, &copy,
+ [&s, &n](Status status) {
+ s.Update(status);
+ n.Notify();
+ });
+ n.WaitForNotification();
+ CHECK(s.ok()) << s;
+ return Checksum(copy);
+}
+
+uint64 GPUUtil::Checksum(const Tensor& tensor) {
+ const float* fptr = reinterpret_cast<const float*>(DMAHelper::base(&tensor));
+ size_t num_bytes = tensor.TotalBytes();
+ size_t num_floats = num_bytes / sizeof(float);
+ for (size_t i = 0; i < num_floats; ++i) {
+ CHECK(!std::isnan(fptr[i])) << " i " << i;
+ }
+ // TODO(tucker): consider using crc32c instead.
+ return Hash64(reinterpret_cast<const char*>(DMAHelper::base(&tensor)),
+ tensor.TotalBytes(), 0);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_util.h b/tensorflow/core/common_runtime/gpu/gpu_util.h
new file mode 100644
index 0000000000..1d8c3a054d
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_util.h
@@ -0,0 +1,89 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_UTIL_H_
+#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_UTIL_H_
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/common_runtime/gpu/dma_helper.h"
+#include "tensorflow/stream_executor/device_memory.h"
+
+#include "tensorflow/stream_executor/stream.h"
+
+namespace tensorflow {
+
+class RecvTensorResponse;
+class TensorProto;
+
+namespace gpu = ::perftools::gputools;
+
+class GPUUtil {
+ public:
+ // "tensor" is GPU-local. "dev" is the hosting GPU.
+ // "device_context" should be the context of the GPU "_Send" op
+ // which provides the Tensor.
+ // Sets all necessary fields of "proto" by transferring value
+ // bytes from GPU to CPU RAM. "is_dead" indicates that the
+ // tensor is dead, with an uninitialized value.
+ static void SetProtoFromGPU(const Tensor& tensor, Device* dev,
+ const DeviceContext* device_context,
+ TensorProto* proto, bool is_dead,
+ StatusCallback done);
+
+ // Copies "input" to "output" between devices accessible to the
+ // local process via some DMA-like method. "edge_name" is the name
+ // of the tensor being copied, for debugging purposes. Depending on
+ // the type of devices and memory in use, the copy may be performed
+ // synchronously or asynchronously. 'done' will be invoked only
+ // after the copy is actually complete.
+ static void CopyViaDMA(const string& edge_name,
+ DeviceContext* send_dev_context,
+ DeviceContext* recv_dev_context, Device* src,
+ Device* dst, const AllocatorAttributes src_alloc_attr,
+ const AllocatorAttributes dst_alloc_attr,
+ const Tensor* input, Tensor* output,
+ StatusCallback done);
+
+ // Copies the data in 'gpu_tensor' into 'cpu_tensor'.
+ // 'gpu_tensor''s backing memory must be on 'gpu_device' and
+ // 'cpu_tensor' must be allocated to be of the same size as
+ // 'gpu_tensor'. Synchronous: may block.
+ static void CopyGPUTensorToCPU(Device* gpu_device,
+ const DeviceContext* device_context,
+ const Tensor* gpu_tensor, Tensor* cpu_tensor,
+ StatusCallback done);
+
+ // Blocks until all operations queued on the stream associated with
+ // "gpu_device" at the time of the call have completed. Returns any
+ // error pending on the stream at completion.
+ static Status Sync(Device* gpu_device);
+
+ // Blocks until all operations queued on all streams associated with the
+ // corresponding GPU device at the time of call have completed.
+ // Returns any error pending on the stream at completion.
+ static Status SyncAll(Device* gpu_device);
+
+ // For debugging purpose, given a "device" and a "tensor" allocated
+ // on the device, return a string printing each byte in the tensor
+ // (up to a limit). "device" can be either a CPU or a GPU device.
+ static string MemoryDebugString(const Device* device, Tensor* tensor);
+
+ static perftools::gputools::DeviceMemory<float> AsGPUFloat(const Tensor& t);
+
+ // Computes a checksum over the contents of "tensor", which is allocated
+ // on "gpu_device".
+ static uint64 Checksum(Device* gpu_device,
+ const DeviceContext* device_context,
+ const Tensor& tensor);
+
+ // Computes a checksum over the contents of "tensor", which is allocated
+ // in local CPU RAM.
+ static uint64 Checksum(const Tensor& tensor);
+
+ static void CopyCPUTensorToGPU(const Tensor* cpu_tensor,
+ const DeviceContext* device_context,
+ Device* gpu_device, Tensor* gpu_tensor,
+ StatusCallback done);
+};
+
+} // namespace tensorflow
+#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_UTIL_H_
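
Because CopyGPUTensorToCPU and CopyCPUTensorToGPU report completion through a StatusCallback, a caller that needs blocking semantics can pair the call with a Notification, which is the same pattern the Checksum helper in gpu_util.cc uses. A minimal sketch, assuming the Device, DeviceContext, and source Tensor come from the caller's own setup; the wrapper name CopyToHostBlocking is ours, not part of this change:

#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#include "tensorflow/core/lib/core/notification.h"

namespace tensorflow {

// Blocks until the GPU->CPU copy completes, then returns its status.
Status CopyToHostBlocking(Device* gpu_device,
                          const DeviceContext* device_context,
                          const Tensor& gpu_tensor, Tensor* cpu_tensor) {
  // Destination must have the same size as the source.
  *cpu_tensor = Tensor(gpu_tensor.dtype(), gpu_tensor.shape());
  Status copy_status;
  Notification done;
  GPUUtil::CopyGPUTensorToCPU(gpu_device, device_context, &gpu_tensor,
                              cpu_tensor, [&copy_status, &done](Status s) {
                                copy_status.Update(s);
                                done.Notify();
                              });
  done.WaitForNotification();
  return copy_status;
}

}  // namespace tensorflow
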
diff --git a/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc b/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc
new file mode 100644
index 0000000000..f1b1174a28
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc
@@ -0,0 +1,24 @@
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/gpu_device_context.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
+#include "tensorflow/stream_executor/stream.h"
+
+namespace tensorflow {
+
+void GPUDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
+ Device* device,
+ Tensor* device_tensor,
+ StatusCallback done) const {
+ GPUUtil::CopyCPUTensorToGPU(cpu_tensor, this, device, device_tensor, done);
+}
+
+void GPUDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
+ const string& tensor_name,
+ Device* device, Tensor* cpu_tensor,
+ StatusCallback done) {
+ GPUUtil::CopyGPUTensorToCPU(device, this, device_tensor, cpu_tensor, done);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/pool_allocator.cc b/tensorflow/core/common_runtime/gpu/pool_allocator.cc
new file mode 100644
index 0000000000..52deb7fce2
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/pool_allocator.cc
@@ -0,0 +1,269 @@
+#include "tensorflow/core/common_runtime/gpu/pool_allocator.h"
+
+#include <errno.h>
+#include <strings.h>
+#include <sys/mman.h> // for munmap
+
+#include <map>
+
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+//#include "prodkernel/api/base/numa.h"
+
+namespace tensorflow {
+
+PoolAllocator::PoolAllocator(size_t pool_size_limit, bool auto_resize,
+ SubAllocator* allocator,
+ RoundUpInterface* size_rounder, string name)
+ : name_(name),
+ has_size_limit_(pool_size_limit > 0),
+ auto_resize_(auto_resize),
+ pool_size_limit_(pool_size_limit),
+ allocator_(allocator),
+ size_rounder_(size_rounder),
+ allocation_begun_(false) {
+ if (auto_resize) {
+ CHECK_LT(0, pool_size_limit)
+ << "size limit must be > 0 if auto_resize is true.";
+ }
+}
+
+PoolAllocator::~PoolAllocator() { Clear(); }
+
+namespace {
+// Pools contain Chunks allocated from the underlying Allocator.
+// Chunk alignment is always on kPoolAlignment boundaries. Each Chunk
+// begins with a descriptor (ChunkPrefix) that gives its size and a
+// pointer to itself. The pointer returned to the user is just past
+// the ChunkPrefix. If the user asks for a larger alignment, we will
+// increase the size of the chunk, then adjust the returned user
+// pointer and also re-write the ChunkPrefix.chunk_ptr value
+// immediately before it. This way the Chunk address and size can be
+// recovered from the returned user pointer, regardless of alignment.
+// Note that this dereferencing of the pointers means that we cannot
+// handle GPU memory, only CPU memory.
+struct ChunkPrefix {
+ size_t num_bytes;
+ void* chunk_ptr;
+};
+// kPoolAlignment cannot be less than the size of ChunkPrefix.
+static const int kPoolAlignment = sizeof(ChunkPrefix);
+
+void* PrepareChunk(void* chunk, size_t alignment, size_t num_bytes) {
+ ChunkPrefix* cp = reinterpret_cast<ChunkPrefix*>(chunk);
+ cp->num_bytes = num_bytes;
+ cp->chunk_ptr = chunk;
+ void* user_ptr = reinterpret_cast<void*>(cp + 1);
+ if (alignment > kPoolAlignment) {
+ // Move user_ptr forward to the first satisfying offset, and write
+ // chunk_ptr just before it.
+ size_t aligned_ptr = reinterpret_cast<size_t>(user_ptr) + alignment;
+ user_ptr = reinterpret_cast<void*>(aligned_ptr & ~(alignment - 1));
+ (reinterpret_cast<ChunkPrefix*>(user_ptr) - 1)->chunk_ptr = chunk;
+ }
+ // Safety check that user_ptr is always past the ChunkPrefix.
+ CHECK_GE(user_ptr, reinterpret_cast<ChunkPrefix*>(chunk) + 1);
+ return user_ptr;
+}
+
+ChunkPrefix* FindPrefix(void* user_ptr) {
+ ChunkPrefix* cp = reinterpret_cast<ChunkPrefix*>(user_ptr) - 1;
+ return reinterpret_cast<ChunkPrefix*>(cp->chunk_ptr);
+}
+} // namespace
+
+void* PoolAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
+ if (!allocation_begun_) allocation_begun_ = true;
+ if (num_bytes == 0) return nullptr;
+
+ // If alignment is larger than kPoolAlignment, increase num_bytes so that we
+ // are guaranteed to be able to return an aligned ptr by advancing user_ptr
+ // without overrunning the end of the chunk.
+ if (alignment > kPoolAlignment) {
+ num_bytes += alignment;
+ }
+ num_bytes += sizeof(ChunkPrefix);
+ num_bytes = size_rounder_->RoundUp(num_bytes);
+ PtrRecord* pr = nullptr;
+ if (has_size_limit_) {
+ {
+ mutex_lock lock(mutex_);
+ auto iter = pool_.find(num_bytes);
+ if (iter == pool_.end()) {
+ allocated_count_++;
+ // Deliberately fall out of lock scope before
+ // calling the allocator. No further modification
+ // to the pool will be performed.
+ } else {
+ get_from_pool_count_++;
+ pr = iter->second;
+ RemoveFromList(pr);
+ pool_.erase(iter);
+ // Fall out of lock scope and prepare the result without the lock held.
+ }
+ }
+ }
+ if (pr != nullptr) {
+ void* r = pr->ptr;
+ delete pr;
+ return PrepareChunk(r, alignment, num_bytes);
+ } else {
+ void* ptr = allocator_->Alloc(kPoolAlignment, num_bytes);
+ for (auto v : alloc_visitors_) {
+ v(ptr, num_bytes);
+ }
+ return PrepareChunk(ptr, alignment, num_bytes);
+ }
+}
+
+void PoolAllocator::DeallocateRaw(void* ptr) {
+ if (ptr == nullptr) return;
+ ChunkPrefix* cp = FindPrefix(ptr);
+ CHECK_LE((void*)cp, (void*)ptr);
+ if (!has_size_limit_ && !auto_resize_) {
+ for (auto v : free_visitors_) {
+ v(cp, cp->num_bytes);
+ }
+ allocator_->Free(cp, cp->num_bytes);
+ } else {
+ mutex_lock lock(mutex_);
+ ++put_count_;
+ while (pool_.size() >= pool_size_limit_) {
+ EvictOne();
+ }
+ PtrRecord* pr = new PtrRecord;
+ pr->num_bytes = cp->num_bytes;
+ pr->ptr = cp;
+ AddToList(pr);
+ pool_.insert(std::make_pair(cp->num_bytes, pr));
+ }
+}
+
+void PoolAllocator::Clear() {
+ if (has_size_limit_) {
+ mutex_lock lock(mutex_);
+ for (auto iter : pool_) {
+ PtrRecord* pr = iter.second;
+ for (auto v : free_visitors_) {
+ v(pr->ptr, pr->num_bytes);
+ }
+ allocator_->Free(pr->ptr, pr->num_bytes);
+ delete pr;
+ }
+ pool_.clear();
+ get_from_pool_count_ = 0;
+ put_count_ = 0;
+ allocated_count_ = 0;
+ evicted_count_ = 0;
+ lru_head_ = nullptr;
+ lru_tail_ = nullptr;
+ }
+}
+
+void PoolAllocator::RemoveFromList(PtrRecord* pr) {
+ if (pr->prev == nullptr) {
+ DCHECK_EQ(lru_head_, pr);
+ lru_head_ = nullptr;
+ } else {
+ pr->prev->next = pr->next;
+ }
+ if (pr->next == nullptr) {
+ DCHECK_EQ(lru_tail_, pr);
+ lru_tail_ = pr->prev;
+ } else {
+ pr->next->prev = pr->prev;
+ if (lru_head_ == nullptr) {
+ lru_head_ = pr->next;
+ }
+ }
+}
+
+void PoolAllocator::AddToList(PtrRecord* pr) {
+ pr->prev = nullptr;
+ if (lru_head_ == nullptr) {
+ CHECK(lru_tail_ == nullptr);
+ lru_tail_ = pr;
+ pr->next = nullptr;
+ } else {
+ pr->next = lru_head_;
+ pr->next->prev = pr;
+ }
+ lru_head_ = pr;
+}
+
+void PoolAllocator::EvictOne() {
+ DCHECK(lru_tail_ != nullptr);
+ PtrRecord* prec = lru_tail_;
+ RemoveFromList(prec);
+ auto iter = pool_.find(prec->num_bytes);
+ while (iter->second != prec) {
+ ++iter;
+ DCHECK(iter != pool_.end());
+ }
+ pool_.erase(iter);
+ for (auto v : free_visitors_) {
+ v(prec->ptr, prec->num_bytes);
+ }
+ allocator_->Free(prec->ptr, prec->num_bytes);
+ delete prec;
+ ++evicted_count_;
+ // Auto-resizing, and warning messages.
+ static const double kTolerable = 2e-3;
+ static const int kCheckInterval = 1000;
+ static const double kIncreaseFactor = 1.1;
+ static const int kMinPoolSize = 100;
+ if (0 == evicted_count_ % kCheckInterval) {
+ const double eviction_rate =
+ evicted_count_ / static_cast<double>(put_count_);
+ const int64 alloc_request_count = allocated_count_ + get_from_pool_count_;
+ const double alloc_rate =
+ allocated_count_ / static_cast<double>(alloc_request_count);
+ static int log_counter = 0;
+ // (The counter increment is not thread-safe, but it is only used for
+ // logging, so we don't care.)
+ bool should_log = ((log_counter++ % 10) == 0);
+ if (should_log) {
+ LOG(WARNING) << "PoolAllocator: After " << alloc_request_count
+ << " get requests, put_count=" << put_count_
+ << " evicted_count=" << evicted_count_
+ << " eviction_rate=" << eviction_rate
+ << " and unsatisfied allocation rate=" << alloc_rate;
+ }
+ if (auto_resize_ && (eviction_rate > kTolerable) &&
+ (alloc_rate > kTolerable)) {
+ size_t new_size_limit = (pool_size_limit_ < kMinPoolSize)
+ ? kMinPoolSize
+ : (kIncreaseFactor * pool_size_limit_);
+ if (should_log) {
+ LOG(INFO) << "Raising pool_size_limit_ from " << pool_size_limit_
+ << " to " << new_size_limit;
+ }
+ pool_size_limit_ = new_size_limit;
+ // Reset all the counters so that ratios are relative to new sizes
+ // at next test interval.
+ put_count_ = 0;
+ allocated_count_ = 0;
+ evicted_count_ = 0;
+ get_from_pool_count_ = 0;
+ }
+ }
+}
+
+void PoolAllocator::AddAllocVisitor(Visitor visitor) {
+ mutex_lock lock(mutex_);
+ CHECK(!allocation_begun_)
+ << "AddAllocVisitor may not be called after pool allocation "
+ << "has begun.";
+ alloc_visitors_.push_back(visitor);
+}
+
+void PoolAllocator::AddFreeVisitor(Visitor visitor) {
+ mutex_lock lock(mutex_);
+ CHECK(!allocation_begun_)
+ << "AddFreeVisitor may not be called after pool allocation "
+ << "has begun.";
+ free_visitors_.push_back(visitor);
+}
+
+} // namespace tensorflow
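
The pointer adjustment PrepareChunk performs for over-aligned requests is easier to follow with concrete numbers. A small, self-contained sketch; the addresses are made up purely for illustration:

#include <cassert>
#include <cstddef>

// Same expression PrepareChunk uses for alignment > kPoolAlignment:
// always move forward by `alignment`, then round down to the requested
// power-of-two boundary, so the result is strictly past `user_ptr`.
inline std::size_t AlignForward(std::size_t user_ptr, std::size_t alignment) {
  return (user_ptr + alignment) & ~(alignment - 1);
}

int main() {
  // Chunk at 0x1000, 16-byte ChunkPrefix, requested alignment 64:
  // initial user_ptr = 0x1010, adjusted user_ptr = 0x1040, and the
  // chunk pointer is re-written at 0x1040 - 16 = 0x1030.
  assert(AlignForward(0x1010, 64) == 0x1040);
  // An already-aligned pointer still moves forward a full alignment unit.
  assert(AlignForward(0x1040, 64) == 0x1080);
  return 0;
}
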
diff --git a/tensorflow/core/common_runtime/gpu/pool_allocator.h b/tensorflow/core/common_runtime/gpu/pool_allocator.h
new file mode 100644
index 0000000000..d10aabe88a
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/pool_allocator.h
@@ -0,0 +1,202 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_POOL_ALLOCATOR_H_
+#define TENSORFLOW_COMMON_RUNTIME_GPU_POOL_ALLOCATOR_H_
+
+// Simple LRU pool allocators for various flavors of CPU RAM that
+// implement the VisitableAllocator interface. GPU memory is managed
+// by GPURegionAllocator.
+
+#include <atomic>
+#include <map>
+#include <memory>
+#include "tensorflow/core/lib/core/bits.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/common_runtime/gpu/visitable_allocator.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+
+namespace tensorflow {
+
+// Interface of an object that does the underlying alloc/free of memory.
+class SubAllocator {
+ public:
+ virtual ~SubAllocator() {}
+ virtual void* Alloc(size_t alignment, size_t num_bytes) = 0;
+ virtual void Free(void* ptr, size_t num_bytes) = 0;
+};
+
+// Interface of an object that rounds up integers.
+class RoundUpInterface {
+ public:
+ virtual ~RoundUpInterface() {}
+ virtual size_t RoundUp(size_t num_bytes) = 0;
+};
+
+// Size-limited pool of memory buffers obtained from a SubAllocator
+// instance. Pool eviction policy is LRU.
+class PoolAllocator : public VisitableAllocator {
+ public:
+ // "pool_size_limit" is the maximum number of returned, re-usable
+ // memory buffers to keep in the pool. If pool_size_limit == 0, the
+ // pool is effectively a thin wrapper around the allocator.
+ // If "auto_resize" is true, then the pool_size_limit will gradually
+ // be raised so that deallocations happen very rarely, if at all.
+ // Transitory start-up objects may deallocate, but the long-term
+ // working-set should not. Auto-resizing can raise pool_size_limit
+ // but will never lower it.
+ // "allocator" is the object that performs the underlying memory
+ // malloc/free operations. This object takes ownership of "allocator"
+ // and "size_rounder".
+ PoolAllocator(size_t pool_size_limit, bool auto_resize,
+ SubAllocator* allocator, RoundUpInterface* size_rounder,
+ string name);
+ ~PoolAllocator() override;
+
+ string Name() override { return name_; }
+
+ void* AllocateRaw(size_t alignment, size_t num_bytes) override;
+
+ void DeallocateRaw(void* ptr) override;
+
+ // REQUIRES: The following functions may only be called prior
+ // to the first Allocate*() call. Once allocation has begun, it is
+ // illegal to register another visitor.
+
+ void AddAllocVisitor(Visitor visitor) override;
+
+ void AddFreeVisitor(Visitor visitor) override;
+
+ // Allocate an unused memory region of size "num_bytes". Fetch from
+ // the pool if available, otherwise call allocator_.
+ void* Get(size_t num_bytes);
+
+ // Return a no-longer needed memory region to the pool. It is an error
+ // to dereference "ptr" after this call. If the pool is full, the least
+ // recently used region will be deallocated.
+ void Put(void* ptr, size_t num_bytes);
+
+ // Reset the pool to empty.
+ void Clear();
+
+ // The following accessors permit monitoring the effectiveness of
+ // the pool at avoiding repeated malloc/frees on the underlying
+ // allocator. Read locks are not taken on the theory that value
+ // consistency with other threads is not important.
+
+ // Number of Get() requests satisfied from pool.
+ int64 get_from_pool_count() const NO_THREAD_SAFETY_ANALYSIS {
+ return get_from_pool_count_;
+ }
+ // Number of Put() requests.
+ int64 put_count() const NO_THREAD_SAFETY_ANALYSIS { return put_count_; }
+ // Number of Get() requests requiring a fresh allocation.
+ int64 allocated_count() const NO_THREAD_SAFETY_ANALYSIS {
+ return allocated_count_;
+ }
+ // Number of pool evictions.
+ int64 evicted_count() const NO_THREAD_SAFETY_ANALYSIS {
+ return evicted_count_;
+ }
+ // Current size limit.
+ size_t size_limit() const NO_THREAD_SAFETY_ANALYSIS {
+ return pool_size_limit_;
+ }
+
+ private:
+ struct PtrRecord {
+ void* ptr;
+ size_t num_bytes;
+ PtrRecord* prev;
+ PtrRecord* next;
+ };
+
+ // Remove "pr" from the double-linked LRU list.
+ void RemoveFromList(PtrRecord* pr) EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+
+ // Add "pr" to the head of the double-linked LRU list.
+ void AddToList(PtrRecord* pr) EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+
+ // Delete the least recently used record.
+ void EvictOne() EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+
+ const string name_;
+ const bool has_size_limit_;
+ const bool auto_resize_;
+ size_t pool_size_limit_;
+ std::unique_ptr<SubAllocator> allocator_;
+ std::unique_ptr<RoundUpInterface> size_rounder_;
+ mutex mutex_;
+ std::multimap<const size_t, PtrRecord*> pool_ GUARDED_BY(mutex_);
+ PtrRecord* lru_head_ GUARDED_BY(mutex_) = nullptr;
+ PtrRecord* lru_tail_ GUARDED_BY(mutex_) = nullptr;
+ int64 get_from_pool_count_ GUARDED_BY(mutex_) = 0;
+ int64 put_count_ GUARDED_BY(mutex_) = 0;
+ int64 allocated_count_ GUARDED_BY(mutex_) = 0;
+ int64 evicted_count_ GUARDED_BY(mutex_) = 0;
+ // Write access to these is guarded by mutex_, but not read
+ // access. They may only be modified prior to the first
+ // allocation. Later attempts to modify will fail.
+ std::vector<Visitor> alloc_visitors_;
+ std::vector<Visitor> free_visitors_;
+ std::atomic<bool> allocation_begun_;
+};
+
+// Do-nothing rounder. Passes through sizes unchanged.
+class NoopRounder : public RoundUpInterface {
+ public:
+ size_t RoundUp(size_t num_bytes) override { return num_bytes; }
+};
+
+// Power of 2 rounder: rounds up to nearest power of 2 size.
+class Pow2Rounder : public RoundUpInterface {
+ public:
+ size_t RoundUp(size_t num_bytes) override {
+ return 1uLL << Log2Ceiling64(num_bytes);
+ }
+};
+
+class BasicCPUAllocator : public SubAllocator {
+ public:
+ ~BasicCPUAllocator() override {}
+
+ void* Alloc(size_t alignment, size_t num_bytes) override {
+ return port::aligned_malloc(num_bytes, alignment);
+ }
+ void Free(void* ptr, size_t num_bytes) override { free(ptr); }
+};
+
+// Allocator for pinned CPU RAM that is made known to CUDA for the
+// purpose of efficient DMA with a GPU.
+class CUDAHostAllocator : public SubAllocator {
+ public:
+ // Note: stream_exec cannot be null.
+ explicit CUDAHostAllocator(perftools::gputools::StreamExecutor* stream_exec)
+ : stream_exec_(stream_exec) {
+ CHECK(stream_exec_ != nullptr);
+ }
+ ~CUDAHostAllocator() override {}
+
+ void* Alloc(size_t alignment, size_t num_bytes) override {
+ void* ptr = nullptr;
+ if (num_bytes > 0) {
+ ptr = stream_exec_->HostMemoryAllocate(num_bytes);
+ if (ptr == nullptr) {
+ LOG(FATAL) << "could not allocate pinned host memory of size: "
+ << num_bytes;
+ }
+ }
+ return ptr;
+ }
+
+ void Free(void* ptr, size_t num_bytes) override {
+ if (ptr != nullptr) {
+ stream_exec_->HostMemoryDeallocate(ptr);
+ }
+ }
+
+ private:
+ perftools::gputools::StreamExecutor* stream_exec_; // not owned, non-null
+
+ TF_DISALLOW_COPY_AND_ASSIGN(CUDAHostAllocator);
+};
+
+} // namespace tensorflow
+#endif // TENSORFLOW_COMMON_RUNTIME_GPU_POOL_ALLOCATOR_H_
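
A minimal usage sketch of the pool, following the same pattern process_state.cc uses for its CPU pool; the size limit, request sizes, and pool name below are illustration values only:

#include "tensorflow/core/common_runtime/gpu/pool_allocator.h"

namespace tensorflow {

void PoolAllocatorExample() {
  PoolAllocator pool(100 /*pool_size_limit*/, true /*auto_resize*/,
                     new BasicCPUAllocator, new NoopRounder, "example_pool");
  // The first request goes to the underlying SubAllocator; on deallocation
  // the buffer is parked in the pool rather than freed.
  void* p = pool.AllocateRaw(64 /*alignment*/, 1024 /*num_bytes*/);
  pool.DeallocateRaw(p);
  // A same-sized request is now satisfied from the pool.
  void* q = pool.AllocateRaw(64, 1024);
  pool.DeallocateRaw(q);
  CHECK_EQ(1, pool.get_from_pool_count());
  CHECK_EQ(1, pool.allocated_count());
}

}  // namespace tensorflow
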
diff --git a/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc b/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc
new file mode 100644
index 0000000000..ca409b2b4c
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc
@@ -0,0 +1,203 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/common_runtime/gpu/pool_allocator.h"
+
+#include "tensorflow/stream_executor/multi_platform_manager.h"
+#include "tensorflow/stream_executor/platform.h"
+#include <gtest/gtest.h>
+
+namespace gpu = ::perftools::gputools;
+
+namespace tensorflow {
+namespace {
+
+TEST(PoolAllocatorTest, ZeroSizeBuffers) {
+ gpu::Platform* platform =
+ gpu::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie();
+ PoolAllocator pool(
+ 2 /*pool_size_limit*/, false /*auto_resize*/,
+ new CUDAHostAllocator(
+ platform->GetExecutor(gpu::StreamExecutorConfig(/*ordinal=*/0))
+ .ValueOrDie()),
+ new NoopRounder, "pool");
+
+ EXPECT_EQ(nullptr, pool.AllocateRaw(4 /*alignment*/, 0 /*num_bytes*/));
+ pool.DeallocateRaw(nullptr); // Should not crash.
+ EXPECT_EQ(0, pool.get_from_pool_count());
+ EXPECT_EQ(0, pool.put_count());
+ EXPECT_EQ(0, pool.allocated_count());
+ EXPECT_EQ(0, pool.evicted_count());
+}
+
+TEST(PoolAllocatorTest, ZeroSizePool) {
+ gpu::Platform* platform =
+ gpu::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie();
+ PoolAllocator pool(
+ 0 /*pool_size_limit*/, false /*auto_resize*/,
+ new CUDAHostAllocator(
+ platform->GetExecutor(gpu::StreamExecutorConfig(/*ordinal=*/0))
+ .ValueOrDie()),
+ new NoopRounder, "pool");
+
+ EXPECT_EQ(0, pool.get_from_pool_count());
+ EXPECT_EQ(0, pool.put_count());
+ EXPECT_EQ(0, pool.allocated_count());
+ EXPECT_EQ(0, pool.evicted_count());
+
+ // All allocations should bypass the pool and return valid pointers.
+ for (int i = 0; i < 3; ++i) {
+ void* p0 = pool.AllocateRaw(4, 0);
+ void* p4 = pool.AllocateRaw(4, 4);
+ void* p12 = pool.AllocateRaw(4, 12);
+ EXPECT_EQ(nullptr, p0);
+ EXPECT_NE(nullptr, p4);
+ EXPECT_NE(nullptr, p12);
+ pool.DeallocateRaw(p0);
+ pool.DeallocateRaw(p4);
+ pool.DeallocateRaw(p12);
+ }
+ EXPECT_EQ(0, pool.get_from_pool_count());
+ EXPECT_EQ(0, pool.put_count());
+ EXPECT_EQ(0, pool.allocated_count());
+ EXPECT_EQ(0, pool.evicted_count());
+}
+
+TEST(PoolAllocatorTest, Alignment) {
+ gpu::Platform* platform =
+ gpu::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie();
+ PoolAllocator pool(
+ 0 /*pool_size_limit*/, false /*auto_resize*/,
+ new CUDAHostAllocator(
+ platform->GetExecutor(gpu::StreamExecutorConfig(/*ordinal=*/0))
+ .ValueOrDie()),
+ new NoopRounder, "pool");
+ for (int i = 0; i < 16; ++i) {
+ size_t alignment = 1 << i;
+ void* p = pool.AllocateRaw(alignment, 111);
+ EXPECT_TRUE(p != nullptr);
+ EXPECT_EQ(0, reinterpret_cast<int64>(p) & (alignment - 1))
+ << "ptr: " << p << " alignment " << alignment;
+ // Intentionally don't deallocate, to test that destruction of
+ // the PoolAllocator frees all pending memory.
+ }
+}
+
+TEST(PoolAllocatorTest, AutoResize) {
+ PoolAllocator pool(2 /*pool_size_limit*/, true /*auto_resize*/,
+ new BasicCPUAllocator, new NoopRounder, "pool");
+
+ // Alloc/dealloc 10 sizes just a few times, confirming pool size
+ // stays at 2.
+ for (int i = 0; i < 10; ++i) {
+ void* p = pool.AllocateRaw(4, 64 << i);
+ pool.DeallocateRaw(p);
+ }
+ EXPECT_EQ(0, pool.get_from_pool_count());
+ EXPECT_EQ(10, pool.allocated_count());
+ EXPECT_EQ(10, pool.put_count());
+ EXPECT_EQ(8, pool.evicted_count());
+ EXPECT_EQ(2, pool.size_limit());
+
+ // Then repeat 1200 times. Pool size limit should jump to 100.
+ for (int j = 0; j < 120; ++j) {
+ for (int i = 0; i < 10; ++i) {
+ void* p = pool.AllocateRaw(4, 64 << i);
+ pool.DeallocateRaw(p);
+ }
+ }
+ EXPECT_EQ(100, pool.size_limit());
+}
+
+TEST(PoolAllocatorTest, CudaHostAllocator) {
+ gpu::Platform* platform =
+ gpu::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie();
+ PoolAllocator pool(
+ 2 /*pool_size_limit*/, false /*auto_resize*/,
+ new CUDAHostAllocator(
+ platform->GetExecutor(gpu::StreamExecutorConfig(/*ordinal=*/0))
+ .ValueOrDie()),
+ new NoopRounder, "pool");
+
+ // Repeatedly Get a 16-byte value, confirming that there's only
+ // one real allocation.
+ void* p1_16 = pool.AllocateRaw(4, 16);
+ EXPECT_EQ(0, pool.get_from_pool_count());
+ EXPECT_EQ(1, pool.allocated_count());
+ EXPECT_NE(nullptr, p1_16);
+ pool.DeallocateRaw(p1_16);
+ // Pool contents {16}
+ EXPECT_EQ(1, pool.put_count());
+ void* p2_16 = pool.AllocateRaw(4, 16); // Get it again.
+ EXPECT_EQ(1, pool.get_from_pool_count());
+ EXPECT_EQ(1, pool.allocated_count());
+ EXPECT_EQ(p1_16, p2_16); // Same pointer value
+ pool.DeallocateRaw(p2_16); // Put it back.
+ // Pool contents {16}
+ EXPECT_EQ(2, pool.put_count());
+
+ // Get two more values of different sizes.
+ void* p3_4 = pool.AllocateRaw(4, 4);
+ EXPECT_EQ(2, pool.allocated_count());
+ EXPECT_NE(p1_16, p3_4); // Different pointer value
+ EXPECT_NE(nullptr, p3_4);
+ pool.DeallocateRaw(p3_4); // Put it back. Pool is now full.
+ // Pool contents {4, 16}
+ EXPECT_EQ(3, pool.put_count());
+ void* p4_2 = pool.AllocateRaw(4, 2); // Get a third size buffer.
+ EXPECT_NE(nullptr, p4_2);
+ EXPECT_EQ(0, pool.evicted_count());
+
+ // The pool is full: when we put back p4_2, the 16-byte buffer
+ // should be evicted since it was least recently inserted.
+ pool.DeallocateRaw(p4_2);
+ // Pool contents {2, 4}
+ EXPECT_EQ(4, pool.put_count());
+ EXPECT_EQ(1, pool.evicted_count());
+
+ // Re-getting and putting size 2 or 4 should not alter pool size or
+ // num-evicted.
+ void* p5_4 = pool.AllocateRaw(4, 4);
+ EXPECT_NE(nullptr, p5_4);
+ pool.DeallocateRaw(p5_4);
+ void* p6_2 = pool.AllocateRaw(4, 2);
+ EXPECT_NE(nullptr, p6_2);
+ pool.DeallocateRaw(p6_2);
+ EXPECT_EQ(3, pool.get_from_pool_count());
+ EXPECT_EQ(6, pool.put_count());
+ EXPECT_EQ(3, pool.allocated_count());
+ EXPECT_EQ(1, pool.evicted_count());
+
+ pool.Clear();
+ EXPECT_EQ(0, pool.get_from_pool_count());
+ EXPECT_EQ(0, pool.put_count());
+ EXPECT_EQ(0, pool.allocated_count());
+ EXPECT_EQ(0, pool.evicted_count());
+}
+
+TEST(PoolAllocatorTest, Pow2Rounder) {
+ Pow2Rounder rounder;
+ EXPECT_EQ(1, rounder.RoundUp(1));
+ EXPECT_EQ(2, rounder.RoundUp(2));
+ EXPECT_EQ(16, rounder.RoundUp(9));
+ EXPECT_EQ(16, rounder.RoundUp(16));
+ EXPECT_EQ(65536, rounder.RoundUp(41234));
+ EXPECT_EQ(65536, rounder.RoundUp(65535));
+ EXPECT_EQ(65536, rounder.RoundUp(65536));
+}
+
+TEST(PoolAllocatorTest, Name) {
+ gpu::Platform* platform =
+ gpu::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie();
+ PoolAllocator pool(
+ 2 /*pool_size_limit*/, false /*auto_resize*/,
+ new CUDAHostAllocator(
+ platform->GetExecutor(gpu::StreamExecutorConfig(/*ordinal=*/0))
+ .ValueOrDie()),
+ new NoopRounder, "pool");
+ EXPECT_EQ("pool", pool.Name());
+}
+
+} // namespace
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/common_runtime/gpu/process_state.cc b/tensorflow/core/common_runtime/gpu/process_state.cc
new file mode 100644
index 0000000000..70ac6130c2
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/process_state.cc
@@ -0,0 +1,220 @@
+#include "tensorflow/core/common_runtime/gpu/process_state.h"
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_region_allocator.h"
+#include "tensorflow/core/common_runtime/gpu/pool_allocator.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/stream_executor/multi_platform_manager.h"
+
+#if defined(PLATFORM_GOOGLE)
+DEFINE_bool(record_mem_types, false,
+ "If true, record attributes of memory allocations and "
+ "dyanmically check for appropriate use of registered memory."
+ "Should only be true for debugging or diagnosis of "
+ "performance issues.");
+DEFINE_bool(brain_mem_reg_cuda_dma, true,
+ "If true, register CPU RAM used to copy to/from GPU RAM "
+ "with the CUDA driver.");
+DEFINE_bool(brain_gpu_use_bfc_allocator, false,
+ "If true, uses the Best-Fit GPU allocator.");
+DEFINE_bool(brain_gpu_region_allocator_debug, false,
+ "If true, checks for memory overwrites by writing "
+ "distinctive patterns on both ends of allocated memory.");
+DEFINE_bool(brain_gpu_region_allocator_reset_to_nan, false,
+ "If true, initializes all new Malloc buffers to NaN, "
+ "and resets the buffer to NaN upon Free.");
+
+#else
+bool FLAGS_record_mem_types = false;
+bool FLAGS_brain_mem_reg_cuda_dma = true;
+bool FLAGS_brain_gpu_region_allocator_debug = false;
+bool FLAGS_brain_gpu_region_allocator_reset_to_nan = false;
+bool FLAGS_brain_gpu_use_bfc_allocator = false;
+#endif
+
+namespace gpu = ::perftools::gputools;
+
+namespace tensorflow {
+
+ProcessState* ProcessState::instance_ = nullptr;
+
+/*static*/ ProcessState* ProcessState::singleton() {
+ if (instance_ == nullptr) {
+ instance_ = new ProcessState;
+ }
+
+ return instance_;
+}
+
+ProcessState::ProcessState() : gpu_count_(0) {
+ CHECK(instance_ == nullptr);
+ instance_ = this;
+}
+
+ProcessState::~ProcessState() {
+ for (auto p : gpu_allocators_) {
+ delete p;
+ }
+ instance_ = nullptr;
+}
+
+string ProcessState::MemDesc::DebugString() {
+ return strings::StrCat((loc == CPU ? "CPU " : "GPU "), dev_index, ", dma: ",
+ gpu_registered, ", nic: ", nic_registered);
+}
+
+ProcessState::MemDesc ProcessState::PtrType(const void* ptr) {
+ if (FLAGS_record_mem_types) {
+ auto iter = mem_desc_map_.find(ptr);
+ if (iter != mem_desc_map_.end()) {
+ return iter->second;
+ }
+ }
+ return MemDesc();
+}
+
+void ProcessState::SetGPUCount(int c) {
+ CHECK(gpu_count_ == 0 || gpu_count_ == c)
+ << "Cannot call SetGPUCount with a non-zero value "
+ << "not equal to prior set value.";
+ gpu_count_ = c;
+}
+
+int ProcessState::GPUCount() const { return gpu_count_; }
+
+Allocator* ProcessState::GetGPUAllocator(int gpu_id, size_t total_bytes) {
+#if GOOGLE_CUDA
+ mutex_lock lock(mu_);
+ gpu::Platform* gpu_platform = GPUMachineManager();
+
+ // Verify that gpu_id is legitimate.
+ CHECK_LT(gpu_id, gpu_platform->VisibleDeviceCount())
+ << "gpu_id is outside discovered device range";
+
+ if (gpu_id >= static_cast<int64>(gpu_allocators_.size())) {
+ gpu_allocators_.resize(gpu_id + 1);
+ if (FLAGS_record_mem_types) gpu_al_.resize(gpu_id + 1);
+ }
+
+ if (gpu_allocators_[gpu_id] == nullptr) {
+ VisitableAllocator* gpu_allocator;
+
+ if (FLAGS_brain_gpu_use_bfc_allocator) {
+ gpu_allocator = new GPUBFCAllocator(gpu_id, total_bytes);
+ } else {
+ gpu_allocator = new GPURegionAllocator(gpu_id, total_bytes);
+ }
+
+ if (FLAGS_brain_gpu_region_allocator_debug) {
+ gpu_allocator = new GPUDebugAllocator(gpu_allocator, gpu_id);
+ }
+ if (FLAGS_brain_gpu_region_allocator_reset_to_nan) {
+ gpu_allocator = new GPUNanResetAllocator(gpu_allocator, gpu_id);
+ }
+
+ gpu_allocators_[gpu_id] = gpu_allocator;
+
+ // If there are any pending AllocVisitors for this bus, add
+ // them now.
+ gpu::StreamExecutor* se =
+ gpu_platform->ExecutorForDevice(gpu_id).ValueOrDie();
+ int bus_id = se->GetDeviceDescription().numa_node();
+ if (bus_id < static_cast<int64>(gpu_visitors_.size())) {
+ for (auto v : gpu_visitors_[bus_id]) {
+ gpu_allocators_[gpu_id]->AddAllocVisitor(v);
+ }
+ }
+ if (FLAGS_record_mem_types) {
+ MemDesc md;
+ md.loc = MemDesc::GPU;
+ md.dev_index = gpu_id;
+ md.gpu_registered = false;
+ md.nic_registered = true;
+ if (static_cast<int64>(gpu_al_.size()) <= gpu_id)
+ gpu_al_.resize(gpu_id + 1);
+ gpu_al_[gpu_id] = new internal::RecordingAllocator(
+ &mem_desc_map_, gpu_allocators_[gpu_id], md, &mu_);
+ }
+ }
+ if (FLAGS_record_mem_types) return gpu_al_[gpu_id];
+ return gpu_allocators_[gpu_id];
+#else
+ LOG(FATAL) << "GPUAllocator unavailable. Not compiled with --config=cuda.";
+ return nullptr;
+#endif // GOOGLE_CUDA
+}
+
+Allocator* ProcessState::GetCPUAllocator(int numa_node) {
+ // Although we're temporarily ignoring numa_node, check for legality.
+ CHECK_GE(numa_node, 0);
+ // TODO(tucker): actually maintain separate CPUAllocators for
+ // different numa_nodes. For now, just one.
+ numa_node = 0;
+ mutex_lock lock(mu_);
+ while (cpu_allocators_.size() <= static_cast<size_t>(numa_node)) {
+ cpu_allocators_.push_back(new PoolAllocator(
+ 100 /*pool_size_limit*/, true /*auto_resize*/, new BasicCPUAllocator(),
+ new NoopRounder, "cpu_pool"));
+ }
+ return cpu_allocators_[0];
+}
+
+Allocator* ProcessState::GetCUDAHostAllocator(int numa_node) {
+ if (gpu_count_ == 0 || !FLAGS_brain_mem_reg_cuda_dma) {
+ return GetCPUAllocator(numa_node);
+ }
+ // Although we're temporarily ignoring numa_node, check for legality.
+ CHECK_GE(numa_node, 0);
+ // TODO(tucker): actually maintain separate CPUAllocators for
+ // different numa_nodes. For now, just one.
+ numa_node = 0;
+ mutex_lock lock(mu_);
+ while (static_cast<int>(cuda_host_allocators_.size()) <= numa_node) {
+ // CUDAHost allocation is the same across all GPUs, so just get the
+ // executor for the first device.
+ gpu::Platform* gpu_platform = GPUMachineManager();
+ gpu::StreamExecutor* se = gpu_platform->ExecutorForDevice(0).ValueOrDie();
+ CHECK(se);
+ cuda_host_allocators_.push_back(new PoolAllocator(
+ 100 /*pool_size_limit*/, true /*auto_resize*/,
+ new CUDAHostAllocator(se), new Pow2Rounder, "cuda_host"));
+ if (FLAGS_record_mem_types) {
+ MemDesc md;
+ md.loc = MemDesc::CPU;
+ md.dev_index = 0;
+ md.gpu_registered = true;
+ md.nic_registered = false;
+ cuda_al_.push_back(new internal::RecordingAllocator(
+ &mem_desc_map_, cuda_host_allocators_.back(), md, &mu_));
+ }
+ }
+ if (FLAGS_record_mem_types) return cuda_al_[0];
+ return cuda_host_allocators_[0];
+}
+
+void ProcessState::AddGPUAllocVisitor(int bus_id, AllocVisitor visitor) {
+#if GOOGLE_CUDA
+ mutex_lock lock(mu_);
+ gpu::Platform* gpu_platform = GPUMachineManager();
+ for (int gpu_id = 0; gpu_id < static_cast<int64>(gpu_allocators_.size());
+ ++gpu_id) {
+ gpu::StreamExecutor* se =
+ gpu_platform->ExecutorForDevice(gpu_id).ValueOrDie();
+ if (gpu_allocators_[gpu_id] &&
+ se->GetDeviceDescription().numa_node() == bus_id) {
+ gpu_allocators_[gpu_id]->AddAllocVisitor(visitor);
+ }
+ }
+ while (bus_id >= static_cast<int64>(gpu_visitors_.size())) {
+ gpu_visitors_.push_back(std::vector<AllocVisitor>());
+ }
+ gpu_visitors_[bus_id].push_back(visitor);
+#endif // GOOGLE_CUDA
+}
+
+} // namespace tensorflow
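
A sketch of typical allocator lookups through the process-wide singleton. The numa_node, gpu_id, and byte budget are illustration values; GetGPUAllocator is only meaningful in a --config=cuda build, and the budget is honored only on the first call for a given gpu_id:

#include "tensorflow/core/common_runtime/gpu/process_state.h"

namespace tensorflow {

void AllocatorLookupExample() {
  ProcessState* ps = ProcessState::singleton();
  Allocator* cpu = ps->GetCPUAllocator(0 /*numa_node*/);
  // Falls back to the plain CPU allocator if no GPUs are present or
  // CUDA DMA registration is disabled.
  Allocator* pinned = ps->GetCUDAHostAllocator(0 /*numa_node*/);
  void* staging = pinned->AllocateRaw(64 /*alignment*/, 4096 /*num_bytes*/);
  pinned->DeallocateRaw(staging);
  Allocator* gpu = ps->GetGPUAllocator(0 /*gpu_id*/, 1 << 30 /*total_bytes*/);
  (void)cpu;
  (void)gpu;
}

}  // namespace tensorflow
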
diff --git a/tensorflow/core/common_runtime/gpu/process_state.h b/tensorflow/core/common_runtime/gpu/process_state.h
new file mode 100644
index 0000000000..527d12c10d
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/process_state.h
@@ -0,0 +1,140 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_PROCESS_STATE_H_
+#define TENSORFLOW_COMMON_RUNTIME_GPU_PROCESS_STATE_H_
+
+#include <functional>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+
+namespace tensorflow {
+
+class Allocator;
+class VisitableAllocator;
+class PoolAllocator;
+
+// Singleton that manages per-process state, e.g. allocation
+// of shared resources.
+class ProcessState {
+ public:
+ static ProcessState* singleton();
+
+ // Descriptor for memory allocation attributes, used by optional
+ // runtime correctness analysis logic.
+ struct MemDesc {
+ enum MemLoc { CPU, GPU };
+ MemLoc loc;
+ int dev_index;
+ bool gpu_registered;
+ bool nic_registered;
+ MemDesc()
+ : loc(CPU),
+ dev_index(0),
+ gpu_registered(false),
+ nic_registered(false) {}
+ string DebugString();
+ };
+
+ // Records the number of GPUs available in the local process.
+ // It is a fatal error to call this with a value different from the
+ // value set in a prior call.
+ void SetGPUCount(int c);
+
+ // Returns number of GPUs available in local process, as set by
+ // SetGPUCount(). Returns 0 if SetGPUCount has not been called.
+ int GPUCount() const;
+
+ // Returns what we know about the memory at ptr.
+ // If we know nothing, it's called CPU 0 with no other attributes.
+ MemDesc PtrType(const void* ptr);
+
+ // Returns the one CPUAllocator used for the given numa_node.
+ // TEMPORARY: ignores numa_node.
+ Allocator* GetCPUAllocator(int numa_node);
+
+ // Returns the one GPU allocator used for the indexed GPU.
+ // Note that this is a system GPU index, not (necessarily) a brain
+ // device index.
+ //
+ // 'total_bytes' is the total number of bytes that should be made
+ // available to the allocator. The first call to this function for
+ // a given gpu_id creates the allocator, so only the total_bytes
+ // used on that first call is used.
+ //
+ // REQUIRES: gpu_id must be a valid ordinal for a GPU available in the
+ // current system environment. Otherwise returns nullptr.
+ Allocator* GetGPUAllocator(int gpu_id, size_t total_bytes);
+
+ Allocator* GetCUDAHostAllocator(int numa_node);
+
+ // Registers a function to be called once on every new Region
+ // allocated by every GPURegionAllocator proximate to the specified
+ // bus. The AllocVisitor is provided with a memory pointer and the
+ // size of the area it identifies. The pointer is not guaranteed to
+ // be valid after the call terminates. The intention is for this
+ // interface to be used for network device memory registration.
+ // "bus_id" is platform-specific. On many platforms it
+ // should be 0. On machines with multiple PCIe buses, it should be
+ // the index of one of the PCIe buses. If the bus_id is invalid,
+ // results are undefined.
+ typedef std::function<void(void*, size_t)> AllocVisitor;
+ void AddGPUAllocVisitor(int bus_id, AllocVisitor visitor);
+
+ typedef std::unordered_map<const void*, MemDesc> MDMap;
+
+ protected:
+ ProcessState();
+
+ static ProcessState* instance_;
+
+ mutex mu_;
+ int gpu_count_;
+
+ std::vector<PoolAllocator*> cpu_allocators_ GUARDED_BY(mu_);
+ std::vector<VisitableAllocator*> gpu_allocators_ GUARDED_BY(mu_);
+ std::vector<std::vector<AllocVisitor>> gpu_visitors_ GUARDED_BY(mu_);
+ std::vector<PoolAllocator*> cuda_host_allocators_ GUARDED_BY(mu_);
+
+ virtual ~ProcessState();
+
+ // Optional RecordingAllocators that wrap the corresponding
+ // Allocators for runtime attribute use analysis.
+ MDMap mem_desc_map_;
+ std::vector<Allocator*> cpu_al_ GUARDED_BY(mu_);
+ std::vector<Allocator*> gpu_al_ GUARDED_BY(mu_);
+ std::vector<Allocator*> cuda_al_ GUARDED_BY(mu_);
+};
+
+namespace internal {
+class RecordingAllocator : public Allocator {
+ public:
+ RecordingAllocator(ProcessState::MDMap* mm, Allocator* a,
+ ProcessState::MemDesc md, mutex* mu)
+ : mm_(mm), a_(a), md_(md), mu_(mu) {}
+
+ string Name() override { return a_->Name(); }
+ void* AllocateRaw(size_t alignment, size_t num_bytes) override {
+ void* p = a_->AllocateRaw(alignment, num_bytes);
+ mutex_lock l(*mu_);
+ (*mm_)[p] = md_;
+ return p;
+ }
+ void DeallocateRaw(void* p) override {
+ mutex_lock l(*mu_);
+ auto iter = mm_->find(p);
+ mm_->erase(iter);
+ a_->DeallocateRaw(p);
+ }
+ bool TracksAllocationSizes() override { return a_->TracksAllocationSizes(); }
+ size_t RequestedSize(void* p) override { return a_->RequestedSize(p); }
+ size_t AllocatedSize(void* p) override { return a_->AllocatedSize(p); }
+ ProcessState::MDMap* mm_; // not owned
+ Allocator* a_; // not owned
+ ProcessState::MemDesc md_;
+ mutex* mu_;
+};
+} // namespace internal
+} // namespace tensorflow
+#endif // TENSORFLOW_COMMON_RUNTIME_GPU_PROCESS_STATE_H_
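
The AddGPUAllocVisitor comment above describes network (e.g. RDMA) memory registration as the intended use. A sketch of that pattern for the single-bus case; RegisterWithNic is a hypothetical hook standing in for whatever registration call the deployment's NIC library provides:

#include "tensorflow/core/common_runtime/gpu/process_state.h"

namespace tensorflow {

// Hypothetical NIC registration hook, assumed to be supplied elsewhere.
void RegisterWithNic(void* ptr, size_t num_bytes);

void InstallNicVisitor() {
  // bus_id 0 is the common single-PCIe-bus case.
  ProcessState::singleton()->AddGPUAllocVisitor(
      0 /*bus_id*/,
      [](void* ptr, size_t num_bytes) { RegisterWithNic(ptr, num_bytes); });
}

}  // namespace tensorflow
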
diff --git a/tensorflow/core/common_runtime/gpu/visitable_allocator.h b/tensorflow/core/common_runtime/gpu/visitable_allocator.h
new file mode 100644
index 0000000000..23feed9aab
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/visitable_allocator.h
@@ -0,0 +1,30 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_VISITABLE_ALLOCATOR_H_
+#define TENSORFLOW_COMMON_RUNTIME_GPU_VISITABLE_ALLOCATOR_H_
+
+#include <functional>
+#include "tensorflow/core/framework/allocator.h"
+
+namespace tensorflow {
+
+// Subclass VisitableAllocator instead of Allocator when a memory
+// allocator needs to enable some kind of registration/deregistration
+// of memory areas.
+class VisitableAllocator : public Allocator {
+ public:
+ // Visitor gets called with a pointer to a memory area and its
+ // size in bytes.
+ typedef std::function<void(void*, size_t)> Visitor;
+
+ // Register a visitor guaranteed to be called exactly once on each
+ // chunk of memory newly allocated from the underlying device.
+ // Typically, chunks will be reused and possibly sub-divided by a
+ // pool manager, so the calls will happen only once per process
+ // execution, not once per tensor (re)allocation.
+ virtual void AddAllocVisitor(Visitor visitor) = 0;
+
+ // Register a visitor guaranteed to be called on each chunk of
+ // memory returned to the underlying device.
+ virtual void AddFreeVisitor(Visitor visitor) = 0;
+};
+} // namespace tensorflow
+#endif // TENSORFLOW_COMMON_RUNTIME_GPU_VISITABLE_ALLOCATOR_H_
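
A short sketch of the visitor interface on a concrete VisitableAllocator (PoolAllocator from this change). The visitor must be registered before the first allocation, which the pool CHECKs; the pool parameters here are arbitrary:

#include "tensorflow/core/common_runtime/gpu/pool_allocator.h"

namespace tensorflow {

void LoggingVisitorExample() {
  PoolAllocator pool(100 /*pool_size_limit*/, true /*auto_resize*/,
                     new BasicCPUAllocator, new NoopRounder, "visited_pool");
  pool.AddAllocVisitor([](void* ptr, size_t num_bytes) {
    VLOG(1) << "new region " << ptr << " of " << num_bytes << " bytes";
  });
  // The visitor fires once here, for the region obtained from the
  // underlying SubAllocator.
  void* p = pool.AllocateRaw(4 /*alignment*/, 256 /*num_bytes*/);
  // Returned to the pool; free visitors run only on eviction or Clear().
  pool.DeallocateRaw(p);
}

}  // namespace tensorflow
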
diff --git a/tensorflow/core/common_runtime/gpu_device_context.h b/tensorflow/core/common_runtime/gpu_device_context.h
new file mode 100644
index 0000000000..03fd9a97c3
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu_device_context.h
@@ -0,0 +1,45 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_DEVICE_CONTEXT_H_
+#define TENSORFLOW_COMMON_RUNTIME_GPU_DEVICE_CONTEXT_H_
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/device_base.h"
+
+namespace perftools {
+namespace gputools {
+class Stream;
+} // namespace gputools
+} // namespace perftools
+
+namespace tensorflow {
+
+namespace gpu = ::perftools::gputools;
+
+class GPUDeviceContext : public DeviceContext {
+ public:
+ GPUDeviceContext(int stream_id, gpu::Stream* stream)
+ : stream_id_(stream_id), stream_(stream) {}
+
+ ~GPUDeviceContext() override {}
+
+ gpu::Stream* stream() const override { return stream_; }
+ int stream_id() const { return stream_id_; }
+
+ void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
+ Tensor* device_tensor,
+ StatusCallback done) const override;
+
+ void CopyDeviceTensorToCPU(const Tensor* device_tensor,
+ const string& edge_name, Device* device,
+ Tensor* cpu_tensor, StatusCallback done) override;
+
+ void MaintainLifetimeOnStream(
+ const Tensor* t, perftools::gputools::Stream* stream) const override {}
+
+ private:
+ int stream_id_;
+ gpu::Stream* stream_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_GPU_DEVICE_CONTEXT_H_
diff --git a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc
new file mode 100644
index 0000000000..28afc95c1b
--- /dev/null
+++ b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc
@@ -0,0 +1,160 @@
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/framework/op_segment.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/util/device_name_utils.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/session_options.h"
+
+#if defined(PLATFORM_GOOGLE)
+DECLARE_bool(brain_gpu_use_bfc_allocator);
+#else
+extern bool FLAGS_brain_gpu_use_bfc_allocator;
+#endif
+
+namespace tensorflow {
+namespace test {
+
+Benchmark::Benchmark(const string& device, Graph* g,
+ const SessionOptions* options, Graph* init) {
+ RequireDefaultOps();
+
+ FLAGS_brain_gpu_use_bfc_allocator = true;
+
+ SessionOptions default_options;
+ if (!options) {
+ options = &default_options;
+ }
+
+ testing::StopTiming();
+ string t = str_util::Uppercase(device);
+ device_ =
+ DeviceFactory::NewDevice(t, *options, "/job:localhost/replica:0/task:0");
+ CHECK(device_) << "Could not create a " << device << " device";
+
+ pool_ = new thread::ThreadPool(options->env, "blocking",
+ port::NumSchedulableCPUs());
+
+ auto runner = [this](std::function<void()> closure) {
+ pool_->Schedule(closure);
+ };
+
+ rendez_ = NewLocalRendezvous();
+
+ if (init) {
+ Executor* init_exec;
+ TF_CHECK_OK(NewLocalExecutor(
+ {
+ device_, nullptr, false,
+ [this](const NodeDef& ndef, OpKernel** kernel) {
+ return CreateNonCachedKernel(device_, nullptr, ndef, kernel);
+ },
+ [](OpKernel* kernel) { DeleteNonCachedKernel(kernel); },
+ },
+ init, &init_exec));
+ Executor::Args args;
+ args.rendezvous = rendez_;
+ args.runner = runner;
+ TF_CHECK_OK(init_exec->Run(args));
+ delete init_exec;
+ }
+
+ TF_CHECK_OK(NewLocalExecutor(
+ {
+ device_,
+ nullptr,
+ false,
+ [this](const NodeDef& ndef, OpKernel** kernel) {
+ return CreateNonCachedKernel(device_, nullptr, ndef, kernel);
+ },
+ [](OpKernel* kernel) { DeleteNonCachedKernel(kernel); },
+ },
+ g, &exec_));
+}
+
+Benchmark::~Benchmark() {
+ if (device_) {
+ rendez_->Unref();
+ delete exec_;
+ delete device_;
+ delete pool_;
+ }
+}
+
+void Benchmark::Run(int iters) { RunWithArgs({}, {}, iters); }
+
+string GetRendezvousKey(const Node* node) {
+ string send_device;
+ TF_CHECK_OK(GetNodeAttr(node->def(), "send_device", &send_device));
+ string recv_device;
+ TF_CHECK_OK(GetNodeAttr(node->def(), "recv_device", &recv_device));
+ string tensor_name;
+ TF_CHECK_OK(GetNodeAttr(node->def(), "tensor_name", &tensor_name));
+ uint64 send_device_incarnation;
+ TF_CHECK_OK(GetNodeAttr(node->def(), "send_device_incarnation",
+ reinterpret_cast<int64*>(&send_device_incarnation)));
+ return Rendezvous::CreateKey(send_device, send_device_incarnation,
+ recv_device, tensor_name, FrameAndIter(0, 0));
+}
+
+void Benchmark::RunWithArgs(
+ const std::vector<std::pair<const Node*, Tensor>>& inputs,
+ const std::vector<const Node*>& outputs, int iters) {
+ if (device_) {
+ // Gets inputs' and outputs' rendezvous keys.
+ std::vector<std::pair<string, Tensor>> in;
+ for (const auto& p : inputs) {
+ in.push_back({GetRendezvousKey(p.first), p.second});
+ }
+ std::vector<string> out;
+ for (const auto& n : outputs) {
+ out.push_back(GetRendezvousKey(n));
+ }
+ Tensor unused; // In the benchmark, we don't care about the return value.
+ bool is_dead;
+
+ // Warm up
+ Executor::Args args;
+ args.rendezvous = rendez_;
+ args.runner = [this](std::function<void()> closure) {
+ pool_->Schedule(closure);
+ };
+ for (int i = 0; i < 3; ++i) {
+ for (const auto& p : in) {
+ rendez_->Send(p.first, Rendezvous::Args(), p.second, false);
+ }
+ TF_CHECK_OK(exec_->Run(args));
+ for (const string& key : out) {
+ rendez_->Recv(key, Rendezvous::Args(), &unused, &is_dead);
+ }
+ }
+ TF_CHECK_OK(device_->Sync());
+
+ testing::StartTiming();
+ while (iters-- > 0) {
+ for (const auto& p : in) {
+ rendez_->Send(p.first, Rendezvous::Args(), p.second, false);
+ }
+ TF_CHECK_OK(exec_->Run(args));
+ for (const string& key : out) {
+ rendez_->Recv(key, Rendezvous::Args(), &unused, &is_dead);
+ }
+ }
+
+ TF_CHECK_OK(device_->Sync());
+ testing::StopTiming();
+ }
+}
+
+} // end namespace test
+} // end namespace tensorflow
diff --git a/tensorflow/core/common_runtime/kernel_benchmark_testlib.h b/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
new file mode 100644
index 0000000000..5ebe13e1d4
--- /dev/null
+++ b/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
@@ -0,0 +1,52 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_KERNEL_BENCHMARK_TESTLIB_H_
+#define TENSORFLOW_COMMON_RUNTIME_KERNEL_BENCHMARK_TESTLIB_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+
+class Device;
+class SessionOptions;
+
+namespace test {
+
+class Benchmark {
+ public:
+ // "device" must be either "cpu" or "gpu". Takes ownership of "g"
+ // and "init".
+ Benchmark(const string& device, Graph* g,
+ const SessionOptions* options = nullptr, Graph* init = nullptr);
+ ~Benchmark();
+
+ // Executes the graph "iters" times.
+ void Run(int iters);
+
+ // If "g" contains send/recv nodes, before each execution, we send
+ // inputs to the corresponding recv nodes in the graph, after each
+ // execution, we recv outputs from the corresponding send nodes in
+ // the graph. In the benchmark, we throw away values returned by the
+ // graph.
+ void RunWithArgs(const std::vector<std::pair<const Node*, Tensor>>& inputs,
+ const std::vector<const Node*>& outputs, int iters);
+
+ private:
+ thread::ThreadPool* pool_ = nullptr;
+ thread::ThreadPool* non_blocking_pool_ = nullptr;
+ Device* device_ = nullptr;
+ Rendezvous* rendez_ = nullptr;
+ Executor* exec_ = nullptr;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Benchmark);
+};
+
+} // end namespace test
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_KERNEL_BENCHMARK_TESTLIB_H_
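
A sketch of how a kernel microbenchmark built on this helper typically looks; the graph builder BuildBenchmarkGraph is a hypothetical stand-in for the test::graph construction utilities, and BENCHMARK is assumed to be the registration macro from the platform test harness:

#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/platform/test_benchmark.h"

namespace tensorflow {

// Hypothetical builder for the graph under test, defined elsewhere.
Graph* BuildBenchmarkGraph();

static void BM_Example(int iters) {
  // Benchmark takes ownership of the graph and times `iters` executions.
  test::Benchmark("cpu", BuildBenchmarkGraph()).Run(iters);
}
BENCHMARK(BM_Example);

}  // namespace tensorflow
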
diff --git a/tensorflow/core/common_runtime/local_device.cc b/tensorflow/core/common_runtime/local_device.cc
new file mode 100644
index 0000000000..6a75346805
--- /dev/null
+++ b/tensorflow/core/common_runtime/local_device.cc
@@ -0,0 +1,51 @@
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
+#include "tensorflow/core/common_runtime/local_device.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/session_options.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+namespace {
+
+DeviceBase::CpuWorkerThreads eigen_worker_threads;
+Eigen::ThreadPoolInterface* eigen_thread_pool = nullptr;
+Eigen::ThreadPoolDevice* eigen_device = nullptr;
+
+static bool InitModule(const SessionOptions& options) {
+ int32 intra_op_parallelism_threads =
+ options.config.intra_op_parallelism_threads();
+ if (intra_op_parallelism_threads == 0) {
+ intra_op_parallelism_threads = port::NumSchedulableCPUs();
+ }
+ LOG(INFO) << "Local device intra op parallelism threads: "
+ << intra_op_parallelism_threads;
+ eigen_worker_threads.num_threads = intra_op_parallelism_threads;
+ eigen_worker_threads.workers = new thread::ThreadPool(
+ options.env, "Eigen", intra_op_parallelism_threads);
+ eigen_thread_pool = new EigenThreadPoolWrapper(eigen_worker_threads.workers);
+ eigen_device = new Eigen::ThreadPoolDevice(eigen_thread_pool,
+ eigen_worker_threads.num_threads);
+ return true;
+}
+} // end namespace
+
+// LocalDevice ----------------------------------------------------------------
+
+LocalDevice::LocalDevice(const SessionOptions& options,
+ const DeviceAttributes& attributes,
+ Allocator* device_allocator)
+ : Device(options.env, attributes, device_allocator) {
+ // All ThreadPoolDevices in the process will use this single fixed
+ // sized threadpool for numerical computations.
+ static bool init = InitModule(options);
+ CHECK(init); // Avoids compiler warning that init is unused.
+ set_tensorflow_cpu_worker_threads(&eigen_worker_threads);
+ set_eigen_cpu_device(eigen_device);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/local_device.h b/tensorflow/core/common_runtime/local_device.h
new file mode 100644
index 0000000000..fc4cfc2dfc
--- /dev/null
+++ b/tensorflow/core/common_runtime/local_device.h
@@ -0,0 +1,27 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_LOCAL_DEVICE_H_
+#define TENSORFLOW_COMMON_RUNTIME_LOCAL_DEVICE_H_
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
+
+namespace tensorflow {
+
+class SessionOptions;
+
+// This class is shared by ThreadPoolDevice and GPUDevice and
+// initializes a shared Eigen compute device used by both. This
+// should eventually be removed once we refactor ThreadPoolDevice and
+// GPUDevice into more 'process-wide' abstractions.
+class LocalDevice : public Device {
+ public:
+ LocalDevice(const SessionOptions& options, const DeviceAttributes& attributes,
+ Allocator* device_allocator);
+ ~LocalDevice() override {}
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(LocalDevice);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_LOCAL_DEVICE_H_
diff --git a/tensorflow/core/common_runtime/local_session.cc b/tensorflow/core/common_runtime/local_session.cc
new file mode 100644
index 0000000000..ab6993b8a2
--- /dev/null
+++ b/tensorflow/core/common_runtime/local_session.cc
@@ -0,0 +1,500 @@
+#include "tensorflow/core/common_runtime/local_session.h"
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
+#include "tensorflow/core/common_runtime/session_factory.h"
+#include "tensorflow/core/common_runtime/simple_placer.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/graph_partition.h"
+#include "tensorflow/core/graph/subgraph.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+
+namespace {
+
+thread::ThreadPool* kernel_thread_pool_ = nullptr;
+static bool InitModule(const SessionOptions& options) {
+ int32 inter_op_parallelism_threads =
+ options.config.inter_op_parallelism_threads();
+ if (inter_op_parallelism_threads == 0) {
+ // Default to using the number of cores available in the process.
+ inter_op_parallelism_threads = port::NumSchedulableCPUs();
+ }
+ LOG(INFO) << "Local session inter op parallelism threads: "
+ << inter_op_parallelism_threads;
+ kernel_thread_pool_ = new thread::ThreadPool(options.env, "Compute",
+ inter_op_parallelism_threads);
+ return true;
+}
+
+// TODO(vrv): Figure out how to unify the many different functions
+// that generate RendezvousKey, since many of them have to be
+// consistent with each other.
+string GetRendezvousKey(const string& tensor_name,
+ const DeviceAttributes& device_info,
+ const FrameAndIter& frame_iter) {
+ return strings::StrCat(device_info.name(), ";",
+ strings::FpToString(device_info.incarnation()), ";",
+ device_info.name(), ";", tensor_name, ";",
+ frame_iter.frame_id, ":", frame_iter.iter_id);
+}
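+
+// [Editorial sketch, not part of the original change.] Purely illustrative:
+// for a hypothetical client device "/job:localhost/replica:0/task:0/cpu:0"
+// feeding the tensor "x:0" in the root frame, GetRendezvousKey() produces a
+// key of the shape returned below, where "0000000000000001" stands in for
+// the device incarnation rendered by strings::FpToString().
+inline string ExampleRendezvousKeyForDocs() {
+ const string device = "/job:localhost/replica:0/task:0/cpu:0";
+ return strings::StrCat(device, ";", "0000000000000001", ";", device, ";",
+ "x:0", ";", "0:0");
+}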
+
+// NOTE: On Android with a single device, there is never
+// a risk of an OpKernel blocking indefinitely:
+//
+// 1) No operations do I/O that depends on other simultaneous kernels,
+//
+// 2) Recv nodes always complete immediately: The inputs are sent into
+// the local rendezvous before we start the executor, so the
+// corresponding recvs will not block.
+//
+// Based on these assumptions, we can use the same thread pool for
+// both "non-blocking" and "blocking" OpKernels on Android.
+//
+// This may change down the road when we add support for multiple
+// devices that run concurrently, in which case we will need to
+// revisit this decision.
+void SchedClosure(std::function<void()> c) {
+// TODO(sanjay): Get rid of __ANDROID__ path
+#ifdef __ANDROID__
+ // On Android, there is no implementation of ThreadPool that takes
+ // std::function, only Closure, which we cannot easily convert.
+ //
+ // Instead, we just run the function in-line, which is currently
+ // safe given the reasoning above.
+ c();
+#else
+ kernel_thread_pool_->Schedule(c);
+#endif // __ANDROID__
+}
+
+} // namespace
+
+LocalSession::LocalSession(const SessionOptions& options,
+ const DeviceMgr* device_mgr)
+ : options_(options),
+ device_mgr_(device_mgr),
+ cancellation_manager_(new CancellationManager()) {
+ static bool init = InitModule(options);
+ CHECK(init); // Avoids compiler warning that init is unused.
+ session_handle_ = strings::FpToString(random::New64());
+ int devices_added = 0;
+ if (options.config.log_device_placement()) {
+ const string mapping_str = device_mgr_->DeviceMappingString();
+ printf("Device mapping:\n%s", mapping_str.c_str());
+ LOG(INFO) << "Device mapping:\n" << mapping_str;
+ }
+ for (auto d : device_mgr_->ListDevices()) {
+ devices_.push_back(d);
+ device_set_.AddDevice(d);
+ d->op_segment()->AddHold(session_handle_);
+
+ // The first device added is special: it is the 'client device' (a
+ // CPU device) from which we feed and fetch Tensors.
+ if (devices_added == 0) {
+ device_set_.set_client_device(d);
+ }
+ ++devices_added;
+ }
+}
+
+LocalSession::~LocalSession() {
+ for (auto d : device_mgr_->ListDevices()) {
+ d->op_segment()->RemoveHold(session_handle_);
+ }
+ for (auto it : executors_) {
+ delete it.second;
+ }
+ delete cancellation_manager_;
+}
+
+Status LocalSession::Create(const GraphDef& graph) {
+ mutex_lock l(graph_def_lock_);
+ if (graph_created_) {
+ return errors::AlreadyExists(
+ "A Graph has already been created for this session.");
+ }
+ return ExtendLocked(graph);
+}
+
+Status LocalSession::Extend(const GraphDef& graph) {
+ mutex_lock l(graph_def_lock_);
+ return ExtendLocked(graph);
+}
+
+Status LocalSession::ExtendLocked(const GraphDef& graph) {
+ graph_created_ = true; // In case this is the first call.
+ graph_def_.MergeFrom(graph);
+ return Status::OK();
+}
+
+Status LocalSession::Run(const std::vector<std::pair<string, Tensor>>& inputs,
+ const std::vector<string>& output_names,
+ const std::vector<string>& target_nodes,
+ std::vector<Tensor>* outputs) {
+ {
+ mutex_lock l(graph_def_lock_);
+ if (!graph_created_) {
+ return errors::InvalidArgument(
+ "Session was not created with a graph before Run()!");
+ }
+ }
+
+ // Extract the input names for this run of the session.
+ std::vector<string> input_tensor_names;
+ input_tensor_names.reserve(inputs.size());
+ for (const auto& it : inputs) {
+ input_tensor_names.push_back(it.first);
+ }
+
+ // Check if we already have an executor for these arguments.
+ ExecutorsAndKeys* executors_and_keys;
+ Status s = GetOrCreateExecutors(input_tensor_names, output_names,
+ target_nodes, &executors_and_keys);
+ if (!s.ok()) {
+ return s;
+ }
+
+ IntraProcessRendezvous* rendez =
+ new IntraProcessRendezvous(device_mgr_.get());
+ core::ScopedUnref rendez_unref(rendez);
+
+ // Insert the input tensors into the local rendezvous by their
+ // rendezvous key.
+ for (const auto& input : inputs) {
+ const string& input_key = executors_and_keys->input_keys[input.first];
+ s = rendez->Send(input_key, Rendezvous::Args(), input.second, false);
+ if (!s.ok()) {
+ rendez->StartAbort(s);
+ return s;
+ }
+ }
+
+ // Start parallel Executors.
+ Notification executors_done;
+ const int num_executors = executors_and_keys->device_executors.size();
+ ExecutorBarrier* barrier = new ExecutorBarrier(
+ num_executors, rendez, [&executors_done, &s](const Status& ret) {
+ s = ret;
+ executors_done.Notify();
+ });
+
+ Executor::Args args;
+ args.rendezvous = rendez;
+ args.cancellation_manager = cancellation_manager_;
+ args.runner = SchedClosure;
+
+ for (auto device_executor : executors_and_keys->device_executors) {
+ Executor* exec = device_executor.second;
+ exec->RunAsync(args, barrier->Get());
+ }
+
+ executors_done.WaitForNotification();
+
+ TF_RETURN_IF_ERROR(s);
+
+ if (!output_names.empty()) {
+ outputs->resize(output_names.size());
+ }
+
+ // Get the outputs from the rendezvous
+ for (size_t output_offset = 0; output_offset < output_names.size();
+ ++output_offset) {
+ const string& output_key =
+ executors_and_keys->output_keys[output_names[output_offset]];
+ Tensor output_tensor;
+ bool is_dead;
+
+ // Fetch data from the Rendezvous.
+ s = rendez->Recv(output_key, Rendezvous::Args(), &output_tensor, &is_dead);
+ if (is_dead) {
+ s = errors::InvalidArgument("The tensor returned for ",
+ output_names[output_offset],
+ " was not valid.");
+ }
+ if (!s.ok()) {
+ rendez->StartAbort(s);
+ outputs->clear();
+ return s;
+ }
+
+ (*outputs)[output_offset] = output_tensor;
+ }
+
+ return s;
+}
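+
+// [Editorial sketch, not part of the original change.] The fan-out/fan-in
+// shape used by Run() above, stripped of executors and rendezvous: schedule
+// 'num_tasks' closures through an arbitrary runner, have the last one to
+// finish fire a Notification, and block the caller until then. All names
+// below are hypothetical.
+namespace docs_sketch {
+inline void RunAllAndWait(int num_tasks,
+ const std::function<void(std::function<void()>)>& runner,
+ const std::function<void(int)>& work) {
+ if (num_tasks <= 0) return;
+ mutex mu;
+ int pending = num_tasks;
+ Notification all_done;
+ for (int i = 0; i < num_tasks; ++i) {
+ runner([i, &mu, &pending, &all_done, &work]() {
+ work(i);
+ bool last = false;
+ {
+ mutex_lock l(mu);
+ last = (--pending == 0);
+ }
+ // The last task to finish wakes up the caller.
+ if (last) all_done.Notify();
+ });
+ }
+ all_done.WaitForNotification();
+}
+} // namespace docs_sketch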
+
+Status LocalSession::GetOrCreateExecutors(
+ gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
+ gtl::ArraySlice<string> target_nodes,
+ ExecutorsAndKeys** executors_and_keys) {
+ // Sort the inputs and outputs, so we don't create separate
+ // executors when a user passes in the same inputs/outputs in
+ // different orders.
+ //
+ // We could consider some other signature instead of sorting that
+ // preserves the same property to avoid the sort in the future.
+ std::vector<string> inputs_sorted(inputs.begin(), inputs.end());
+ std::vector<string> outputs_sorted(outputs.begin(), outputs.end());
+ std::vector<string> tn_sorted(target_nodes.begin(), target_nodes.end());
+ std::sort(inputs_sorted.begin(), inputs_sorted.end());
+ std::sort(outputs_sorted.begin(), outputs_sorted.end());
+ std::sort(tn_sorted.begin(), tn_sorted.end());
+
+ const string key = strings::StrCat(str_util::Join(inputs_sorted, ","), "->",
+ str_util::Join(outputs_sorted, ","), "/",
+ str_util::Join(tn_sorted, ","));
+
+ // See if we already have the executors for this run.
+ {
+ mutex_lock l(executor_lock_); // could use reader lock
+ auto it = executors_.find(key);
+ if (it != executors_.end()) {
+ *executors_and_keys = it->second;
+ return Status::OK();
+ }
+ }
+
+ // The executor_lock_ is intentionally released while the executors
+ // are being created.
+ std::unordered_map<string, Graph*> graphs;
+ Status s = CreateGraphs(inputs, outputs, target_nodes, &graphs);
+ if (!s.ok()) {
+ return s;
+ }
+
+ bool has_control_flow = false;
+ for (const auto& graph : graphs) {
+ for (const Node* n : graph.second->nodes()) {
+ if (IsControlFlow(n)) {
+ has_control_flow = true;
+ break;
+ }
+ }
+ if (has_control_flow) break;
+ }
+
+ std::unique_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
+
+ for (const auto& graph : graphs) {
+ const string& partition_name = graph.first;
+ Graph* partition_graph = graph.second;
+
+ Device* d;
+ s = device_mgr_->LookupDevice(partition_name, &d);
+ if (!s.ok()) {
+ return s;
+ }
+
+ LocalExecutorParams params;
+ params.has_control_flow = has_control_flow;
+ params.device = d;
+ params.create_kernel = [this, d](const NodeDef& ndef, OpKernel** kernel) {
+ return CreateCachedKernel(d, session_handle_, nullptr, ndef, kernel);
+ };
+ params.delete_kernel = [this, d](OpKernel* kernel) {
+ DeleteCachedKernel(d, session_handle_, kernel);
+ };
+
+ Executor* tmp_exec;
+ s = NewLocalExecutor(params, partition_graph, &tmp_exec);
+ if (!s.ok()) {
+ return s;
+ }
+ ek->device_executors.insert(std::make_pair(graph.first, tmp_exec));
+ }
+
+ // Compute the rendezvous keys to avoid recomputing them every time.
+ //
+ // We always use the client device (the first device added) as the
+ // device name portion of the key, even if the fed or fetched tensor
+ // actually lives on another device.
+ for (const string& input : inputs) {
+ ek->input_keys[input] = GetRendezvousKey(
+ input, device_set_.client_device()->attributes(), FrameAndIter(0, 0));
+ }
+ for (const string& output : outputs) {
+ ek->output_keys[output] = GetRendezvousKey(
+ output, device_set_.client_device()->attributes(), FrameAndIter(0, 0));
+ }
+
+ // Reacquire the lock, try to insert into the map.
+ mutex_lock l(executor_lock_);
+ const bool inserted = executors_.insert(std::make_pair(key, ek.get())).second;
+ if (!inserted) {
+ // Another thread created the entry before us, so delete the
+ // one we created and return the already created one.
+ auto it = executors_.find(key);
+ *executors_and_keys = it->second;
+ } else {
+ *executors_and_keys = ek.release();
+ }
+
+ return Status::OK();
+}
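+
+// [Editorial sketch, not part of the original change.] The caching protocol
+// above in isolation: look up under the lock, build the value with the lock
+// released (creation may be slow), then re-acquire the lock and keep
+// whichever entry won the race. 'CachedThing' is a hypothetical stand-in
+// for ExecutorsAndKeys.
+namespace docs_sketch {
+struct CachedThing {};
+
+inline CachedThing* GetOrCreate(
+ mutex* mu, std::unordered_map<string, CachedThing*>* cache,
+ const string& key) {
+ {
+ mutex_lock l(*mu);
+ auto it = cache->find(key);
+ if (it != cache->end()) return it->second;
+ }
+ // Build without holding the lock.
+ std::unique_ptr<CachedThing> created(new CachedThing);
+ mutex_lock l(*mu);
+ auto result = cache->insert(std::make_pair(key, created.get()));
+ if (result.second) created.release(); // The cache now owns the new entry.
+ return result.first->second; // Either ours or the entry that won the race.
+}
+} // namespace docs_sketch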
+
+void LocalSession::SaveStatefulNodes(Graph* graph) {
+ for (Node* n : graph->nodes()) {
+ if (n->op_def().is_stateful()) {
+ VLOG(2) << "Saving " << n->DebugString();
+ stateful_placements_[n->name()] = n->assigned_device_name();
+ }
+ }
+}
+
+void LocalSession::RestoreStatefulNodes(Graph* graph) {
+ for (Node* n : graph->nodes()) {
+ if (n->op_def().is_stateful()) {
+ auto iter = stateful_placements_.find(n->name());
+ if (iter != stateful_placements_.end()) {
+ n->set_assigned_device_name(iter->second);
+ VLOG(2) << "Restored " << n->DebugString();
+ }
+ }
+ }
+}
+
+Status LocalSession::CreateGraphs(gtl::ArraySlice<string> feeds,
+ gtl::ArraySlice<string> fetches,
+ gtl::ArraySlice<string> target_nodes,
+ std::unordered_map<string, Graph*>* outputs) {
+ Graph graph(OpRegistry::Global());
+ GraphConstructorOptions opts;
+
+ {
+ mutex_lock l(graph_def_lock_);
+ TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, graph_def_, &graph));
+ }
+
+ TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
+ &graph, feeds, fetches, target_nodes,
+ device_set_.client_device()->attributes()));
+
+ // Run the simple placer after rewriting the graph.
+ std::unordered_map<string, int32> node_name_to_cost_map;
+ for (Node* n : graph.nodes()) {
+ node_name_to_cost_map[n->name()] = n->cost_id();
+ }
+ SimplePlacer placer(&graph, &device_set_, &node_name_to_cost_map, &options_);
+
+ {
+ mutex_lock l(mu_);
+ // Restore stateful nodes.
+ RestoreStatefulNodes(&graph);
+ TF_RETURN_IF_ERROR(placer.Run());
+ // Save stateful nodes.
+ SaveStatefulNodes(&graph);
+ }
+
+ // Partition the graph across devices.
+ std::unordered_map<string, GraphDef> partitions;
+ PartitionOptions popts;
+ popts.node_to_loc = [](const Node* node) {
+ return node->assigned_device_name();
+ };
+ popts.new_name = [this](const string& prefix) {
+ mutex_lock l(mu_);
+ return strings::StrCat(prefix, "/_", name_counter_++);
+ };
+ popts.get_incarnation = [](const string& name) {
+ // The local session does not have changing incarnation numbers.
+ // Just return '1'.
+ return 1;
+ };
+ popts.control_flow_added = false;
+ TF_RETURN_IF_ERROR(Partition(popts, &graph, &partitions));
+
+ std::vector<string> device_names;
+ for (auto device : devices_) {
+ // Extract the LocalName from the device.
+ device_names.push_back(DeviceNameUtils::LocalName(device->name()));
+ }
+
+ // Check for valid partitions.
+ for (const auto& partition : partitions) {
+ const string& local_partition_name =
+ DeviceNameUtils::LocalName(partition.first);
+ if (std::count(device_names.begin(), device_names.end(),
+ local_partition_name) == 0) {
+ return errors::InvalidArgument(
+ "Creating a partition for ", local_partition_name,
+ " which doesn't exist in the list of available devices. Available "
+ "devices: ",
+ str_util::Join(device_names, ","));
+ }
+ }
+
+ for (const auto& partition : partitions) {
+ const string& partition_name = partition.first;
+
+ const GraphDef& graph_def = partition.second;
+ VLOG(2) << "Created " << graph_def.DebugString() << " for "
+ << partition_name;
+
+ Graph* device_graph = new Graph(OpRegistry::Global());
+ GraphConstructorOptions device_opts;
+ // The partitioned graphs contain internal operations (e.g.,
+ // send/recv), so we must allow them here.
+ device_opts.allow_internal_ops = true;
+ device_opts.expect_device_spec = true;
+ Status s =
+ ConvertGraphDefToGraph(device_opts, graph_def, device_graph);
+ if (!s.ok()) {
+ delete device_graph;
+ // Also delete other graphs created during the loop.
+ gtl::STLDeleteValues(outputs);
+ return s;
+ }
+ outputs->insert(std::make_pair(partition_name, device_graph));
+ }
+
+ return Status::OK();
+}
+
+::tensorflow::Status LocalSession::Close() {
+ cancellation_manager_->StartCancel();
+ return ::tensorflow::Status::OK();
+}
+
+class LocalSessionFactory : public SessionFactory {
+ public:
+ LocalSessionFactory() {}
+
+ Session* NewSession(const SessionOptions& options) override {
+ std::vector<Device*> devices;
+ DeviceFactory::AddDevices(options, "/job:localhost/replica:0/task:0",
+ &devices);
+ return new LocalSession(options, new DeviceMgr(devices));
+ }
+};
+
+class LocalSessionRegistrar {
+ public:
+ LocalSessionRegistrar() {
+ SessionFactory::Register("LOCAL_SESSION", new LocalSessionFactory());
+ }
+};
+static LocalSessionRegistrar registrar;
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/local_session.h b/tensorflow/core/common_runtime/local_session.h
new file mode 100644
index 0000000000..453cfdde47
--- /dev/null
+++ b/tensorflow/core/common_runtime/local_session.h
@@ -0,0 +1,109 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_LOCAL_SESSION_H_
+#define TENSORFLOW_COMMON_RUNTIME_LOCAL_SESSION_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/device_set.h"
+#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+class Device;
+
+class LocalSession : public Session {
+ public:
+ // Takes ownership of 'device_mgr'.
+ LocalSession(const SessionOptions& options, const DeviceMgr* device_mgr);
+ ~LocalSession() override;
+
+ ::tensorflow::Status Create(const GraphDef& graph) override;
+ ::tensorflow::Status Extend(const GraphDef& graph) override;
+ ::tensorflow::Status Run(const std::vector<std::pair<string, Tensor>>& inputs,
+ const std::vector<string>& output_names,
+ const std::vector<string>& target_nodes,
+ std::vector<Tensor>* outputs) override;
+ ::tensorflow::Status Close() override;
+
+ private:
+ struct ExecutorsAndKeys {
+ std::unordered_map<string, Executor*> device_executors;
+ std::unordered_map<string, string> input_keys;
+ std::unordered_map<string, string> output_keys;
+
+ ~ExecutorsAndKeys() {
+ for (auto it : device_executors) {
+ delete it.second;
+ }
+ }
+ };
+
+ // Retrieves an already existing set of executors to run 'inputs' and
+ // 'outputs', or creates and caches them for future use.
+ ::tensorflow::Status GetOrCreateExecutors(
+ gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
+ gtl::ArraySlice<string> target_nodes,
+ ExecutorsAndKeys** executors_and_keys);
+
+ // Creates one graph per device partition from the existing
+ // graph_def_ and the given feeds, fetches, and target nodes.
+ ::tensorflow::Status CreateGraphs(
+ gtl::ArraySlice<string> feeds, gtl::ArraySlice<string> fetches,
+ gtl::ArraySlice<string> target_nodes,
+ std::unordered_map<string, Graph*>* outputs);
+
+ ::tensorflow::Status ExtendLocked(const GraphDef& graph)
+ EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
+
+ const SessionOptions options_;
+
+ // Device structures.
+ const std::unique_ptr<const DeviceMgr> device_mgr_;
+ std::vector<Device*> devices_; // not owned
+ DeviceSet device_set_;
+
+ string session_handle_;
+ bool graph_created_ GUARDED_BY(graph_def_lock_) = false;
+
+ mutex graph_def_lock_;
+ GraphDef graph_def_ GUARDED_BY(graph_def_lock_);
+
+ mutex executor_lock_; // protects executors_
+ // Holds mappings from signature to the executors that process
+ // it. The reason for a level of indirection around mapped_type is
+ // to guarantee address stability.
+ std::unordered_map<string, ExecutorsAndKeys*> executors_
+ GUARDED_BY(executor_lock_);
+
+ CancellationManager* cancellation_manager_;
+
+ // Saves and restores device placements for stateful nodes.
+ mutex mu_;
+ void SaveStatefulNodes(Graph* graph) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ void RestoreStatefulNodes(Graph* graph) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ // Map of placed stateful nodes, i.e. nodes for which is_stateful()
+ // is true, such as "params" and "queue" nodes. Once placed these
+ // nodes cannot be moved to a different device. Maps node names to
+ // device names.
+ std::unordered_map<string, string> stateful_placements_ GUARDED_BY(mu_);
+
+ // For generating unique names.
+ int64 name_counter_ GUARDED_BY(mu_) = 0;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(LocalSession);
+};
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_LOCAL_SESSION_H_
diff --git a/tensorflow/core/common_runtime/local_session_test.cc b/tensorflow/core/common_runtime/local_session_test.cc
new file mode 100644
index 0000000000..9325fe44c3
--- /dev/null
+++ b/tensorflow/core/common_runtime/local_session_test.cc
@@ -0,0 +1,314 @@
+#include "tensorflow/core/common_runtime/local_session.h"
+
+#include <map>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/util/device_name_utils.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+Session* CreateSession() {
+ SessionOptions options;
+ (*options.config.mutable_device_count())["CPU"] = 2;
+ return NewSession(options);
+}
+
+class LocalSessionMinusAXTest : public ::testing::Test {
+ public:
+ void Initialize(std::initializer_list<float> a_values) {
+ RequireDefaultOps();
+ Graph graph(OpRegistry::Global());
+
+ Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
+ test::FillValues<float>(&a_tensor, a_values);
+ Node* a = test::graph::Constant(&graph, a_tensor);
+ a->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
+
+ Tensor x_tensor(DT_FLOAT, TensorShape({2, 1}));
+ test::FillValues<float>(&x_tensor, {1, 1});
+ Node* x = test::graph::Constant(&graph, x_tensor);
+ x->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
+ x_ = x->name();
+
+ // y = A * x
+ Node* y = test::graph::Matmul(&graph, a, x, false, false);
+ y->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
+ y_ = y->name();
+
+ Node* y_neg = test::graph::Unary(&graph, "Neg", y);
+ y_neg_ = y_neg->name();
+ y_neg->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
+
+ test::graph::ToGraphDef(&graph, &def_);
+ }
+
+ string x_;
+ string y_;
+ string y_neg_;
+ GraphDef def_;
+};
+
+TEST_F(LocalSessionMinusAXTest, RunSimpleNetwork) {
+ Initialize({3, 2, -1, 0});
+ std::unique_ptr<Session> session(CreateSession());
+ ASSERT_TRUE(session != nullptr);
+ ASSERT_OK(session->Create(def_));
+ std::vector<std::pair<string, Tensor>> inputs;
+
+ // Request two targets: one fetch output and one non-fetched output.
+ std::vector<string> output_names = {y_ + ":0"};
+ std::vector<string> target_nodes = {y_neg_};
+ std::vector<Tensor> outputs;
+ Status s = session->Run(inputs, output_names, target_nodes, &outputs);
+ ASSERT_OK(s);
+
+ ASSERT_EQ(1, outputs.size());
+ // The first output should be initialized and contain the expected
+ // value.
+ auto mat = outputs[0].matrix<float>();
+ ASSERT_TRUE(outputs[0].IsInitialized());
+ EXPECT_FLOAT_EQ(5.0, mat(0, 0));
+}
+
+TEST_F(LocalSessionMinusAXTest, TestFeed) {
+ Initialize({1, 2, 3, 4});
+ std::unique_ptr<Session> session(CreateSession());
+ ASSERT_TRUE(session != nullptr);
+
+ ASSERT_OK(session->Create(def_));
+
+ // Fill in the input and ask for the output
+ //
+ // Note that the input being fed is on the second device.
+ Tensor t(DT_FLOAT, TensorShape({2, 1}));
+ t.matrix<float>()(0, 0) = 5;
+ t.matrix<float>()(1, 0) = 6;
+ std::vector<std::pair<string, Tensor>> inputs = {{x_, t}};
+ std::vector<string> output_names = {y_ + ":0"};
+ std::vector<Tensor> outputs;
+
+ // Run the graph
+ Status s = session->Run(inputs, output_names, {}, &outputs);
+ ASSERT_OK(s);
+
+ ASSERT_EQ(1, outputs.size());
+ auto mat = outputs[0].matrix<float>();
+
+ // Expect outputs to be: 1*5 + 2*6 = 17 and 3*5 + 4*6 = 39.
+ EXPECT_FLOAT_EQ(17.0, mat(0, 0));
+ EXPECT_FLOAT_EQ(39.0, mat(1, 0));
+}
+
+TEST_F(LocalSessionMinusAXTest, TestConcurrency) {
+ Initialize({1, 2, 3, 4});
+ std::unique_ptr<Session> session(CreateSession());
+ ASSERT_TRUE(session != nullptr);
+ ASSERT_OK(session->Create(def_));
+
+ // Set up a thread pool so the graph can be run concurrently.
+ thread::ThreadPool* tp = new thread::ThreadPool(Env::Default(), "test", 4);
+
+ // Run the graph 1000 times in 4 different threads concurrently.
+ std::vector<string> output_names = {y_ + ":0"};
+ auto fn = [&session, output_names]() {
+ for (int i = 0; i < 1000; ++i) {
+ std::vector<std::pair<string, Tensor>> inputs;
+ std::vector<Tensor> outputs;
+ // Run the graph
+ Status s = session->Run(inputs, output_names, {}, &outputs);
+ ASSERT_TRUE(s.ok());
+ ASSERT_EQ(1, outputs.size());
+ auto mat = outputs[0].matrix<float>();
+ EXPECT_FLOAT_EQ(3.0, mat(0, 0));
+ }
+ };
+
+ for (int i = 0; i < 4; ++i) {
+ tp->Schedule(fn);
+ }
+
+ // Wait for the functions to finish.
+ delete tp;
+}
+
+TEST_F(LocalSessionMinusAXTest, TwoCreateCallsFails) {
+ Initialize({1, 2, 3, 4});
+ std::unique_ptr<Session> session(CreateSession());
+ ASSERT_TRUE(session != nullptr);
+ ASSERT_OK(session->Create(def_));
+
+ // The first Create() above succeeded; a second call must fail.
+ ASSERT_FALSE(session->Create(def_).ok());
+}
+
+TEST_F(LocalSessionMinusAXTest, ForgetToCreate) {
+ Initialize({1, 2, 3, 4});
+ std::unique_ptr<Session> session(CreateSession());
+ ASSERT_TRUE(session != nullptr);
+ std::vector<std::pair<string, Tensor>> inputs;
+ std::vector<Tensor> outputs;
+ ASSERT_FALSE(session->Run(inputs, {y_ + ":0"}, {y_neg_}, &outputs).ok());
+}
+
+TEST_F(LocalSessionMinusAXTest, InvalidDevice) {
+ GraphDef def;
+ Graph graph(OpRegistry::Global());
+
+ Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
+ a_tensor.flat<float>().setRandom();
+ Node* a = test::graph::Constant(&graph, a_tensor);
+ a->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
+ Tensor x_tensor(DT_FLOAT, TensorShape({2, 1}));
+ x_tensor.flat<float>().setRandom();
+ Node* x = test::graph::Constant(&graph, x_tensor);
+ x->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
+ // Assign y to a device that does not exist (only cpu:0 and cpu:1
+ // are created by CreateSession()).
+ Node* y = test::graph::Matmul(&graph, a, x, false, false);
+ y->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:2");
+
+ test::graph::ToGraphDef(&graph, &def);
+
+ std::unique_ptr<Session> session(CreateSession());
+ ASSERT_TRUE(session != nullptr);
+ ASSERT_OK(session->Create(def));
+ std::vector<std::pair<string, Tensor>> inputs;
+ std::vector<string> output_names = {y->name() + ":0"};
+ std::vector<Tensor> outputs;
+
+ // Should return an error.
+ ASSERT_FALSE(session->Run(inputs, output_names, {}, &outputs).ok());
+
+ // Fix placement and run again
+ def.Clear();
+ y->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
+ test::graph::ToGraphDef(&graph, &def);
+ session.reset(CreateSession());
+ ASSERT_OK(session->Create(def));
+ ASSERT_OK(session->Run(inputs, output_names, {}, &outputs));
+}
+
+TEST(LocalSessionTest, KeepsStateAcrossRunsOfSession) {
+ GraphDef def;
+ Graph g(OpRegistry::Global());
+ Node* var = test::graph::Var(&g, DT_FLOAT, TensorShape({10}));
+ var->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
+
+ Tensor twenty(DT_FLOAT, TensorShape({10}));
+ for (int i = 0; i < 10; ++i) {
+ twenty.flat<float>()(i) = 20.0;
+ }
+
+ Node* twenty_node = test::graph::Constant(&g, twenty);
+ twenty_node->set_assigned_device_name(
+ "/job:localhost/replica:0/task:0/cpu:0");
+
+ Node* init = test::graph::Assign(&g, var, twenty_node);
+ init->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
+
+ test::graph::ToGraphDef(&g, &def);
+
+ std::unique_ptr<Session> session(CreateSession());
+ ASSERT_TRUE(session != nullptr);
+ ASSERT_OK(session->Create(def));
+
+ std::vector<std::pair<string, Tensor>> inputs;
+ std::vector<Tensor> outputs;
+
+ // Initialize the variable
+ Status s = session->Run(inputs, {init->name()}, {}, &outputs);
+ ASSERT_OK(s);
+
+ // Get the variable's data
+ s = session->Run(inputs, {var->name() + ":0"}, {}, &outputs);
+ ASSERT_OK(s);
+ ASSERT_EQ(1, outputs.size());
+ ASSERT_TRUE(outputs[0].IsInitialized());
+ EXPECT_EQ(20.0, outputs[0].flat<float>()(0));
+}
+
+TEST(LocalSessionTest, MultipleFeedTest) {
+ GraphDef def;
+ Graph g(OpRegistry::Global());
+ Node* var = test::graph::Var(&g, DT_FLOAT, TensorShape({10}));
+ var->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
+
+ Tensor first_value(DT_FLOAT, TensorShape({}));
+ first_value.scalar<float>()() = 1.0;
+ Node* first_const = test::graph::Constant(&g, first_value);
+ Node* first_identity = test::graph::Identity(&g, first_const);
+
+ Tensor second_value(DT_FLOAT, TensorShape({}));
+ second_value.scalar<float>()() = 2.0;
+ Node* second_const = test::graph::Constant(&g, second_value);
+ Node* second_identity = test::graph::Identity(&g, second_const);
+
+ test::graph::ToGraphDef(&g, &def);
+
+ std::unique_ptr<Session> session(CreateSession());
+ ASSERT_TRUE(session != nullptr);
+ ASSERT_OK(session->Create(def));
+
+ std::vector<Tensor> outputs;
+
+ // Fetch without feeding.
+ Status s = session->Run(
+ {}, {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
+ &outputs);
+ ASSERT_TRUE(s.ok());
+ ASSERT_EQ(2, outputs.size());
+ ASSERT_EQ(1.0, outputs[0].flat<float>()(0));
+ ASSERT_EQ(2.0, outputs[1].flat<float>()(0));
+
+ s = session->Run(
+ {}, {second_identity->name() + ":0", first_identity->name() + ":0"}, {},
+ &outputs);
+ ASSERT_TRUE(s.ok());
+ ASSERT_EQ(2, outputs.size());
+ ASSERT_EQ(2.0, outputs[0].flat<float>()(0));
+ ASSERT_EQ(1.0, outputs[1].flat<float>()(0));
+
+ Tensor value_11(DT_FLOAT, TensorShape({}));
+ value_11.scalar<float>()() = 11.0;
+ Tensor value_22(DT_FLOAT, TensorShape({}));
+ value_22.scalar<float>()() = 22.0;
+
+ // Feed [first_const, second_const]
+ s = session->Run(
+ {{first_const->name(), value_11}, {second_const->name(), value_22}},
+ {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
+ &outputs);
+ ASSERT_TRUE(s.ok());
+ ASSERT_EQ(2, outputs.size());
+ ASSERT_EQ(11.0, outputs[0].flat<float>()(0));
+ ASSERT_EQ(22.0, outputs[1].flat<float>()(0));
+
+ // Feed [second_const, first_const]
+ s = session->Run(
+ {{second_const->name(), value_22}, {first_const->name(), value_11}},
+ {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
+ &outputs);
+ ASSERT_TRUE(s.ok());
+ ASSERT_EQ(2, outputs.size());
+ ASSERT_EQ(11.0, outputs[0].flat<float>()(0));
+ ASSERT_EQ(22.0, outputs[1].flat<float>()(0));
+}
+
+} // namespace
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/rendezvous_mgr.cc b/tensorflow/core/common_runtime/rendezvous_mgr.cc
new file mode 100644
index 0000000000..111dea6d4c
--- /dev/null
+++ b/tensorflow/core/common_runtime/rendezvous_mgr.cc
@@ -0,0 +1,170 @@
+#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
+
+#include <unordered_set>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#if (!defined(PLATFORM_POSIX_ANDROID) && !defined(PLATFORM_GOOGLE_ANDROID)) && \
+ (defined(PLATFORM_GOOGLE) || GOOGLE_CUDA)
+#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
+#endif
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+namespace {
+
+void CopyTensorBetweenDevices(const string& id, DeviceContext* send_dev_context,
+ DeviceContext* recv_dev_context, Device* src,
+ Device* dst,
+ const AllocatorAttributes src_alloc_attr,
+ const AllocatorAttributes dst_alloc_attr,
+ const Tensor* input, Tensor* output,
+ std::function<void(const Status&)> done) {
+ if (src->attributes().device_type() != dst->attributes().device_type()) {
+ done(errors::Unimplemented(
+ "Copy between device types not yet implemented: src=", src->name(),
+ " dst=", dst->name()));
+ return;
+ }
+ if (src->attributes().device_type() != "CPU") {
+ done(errors::Unimplemented(
+ "Copy between non-CPU devices not yet implemented"));
+ return;
+ }
+ // Both tensors live in host memory: share the underlying buffer.
+ *output = *input;
+ done(Status::OK());
+}
+
+#if (!defined(PLATFORM_POSIX_ANDROID) && !defined(PLATFORM_GOOGLE_ANDROID)) && \
+ (defined(PLATFORM_GOOGLE) || GOOGLE_CUDA)
+constexpr auto CopyTensorBetweenDevicesFunc = &GPUUtil::CopyViaDMA;
+#else
+constexpr auto CopyTensorBetweenDevicesFunc = &CopyTensorBetweenDevices;
+#endif
+
+} // end namespace
+
+IntraProcessRendezvous::IntraProcessRendezvous(const DeviceMgr* device_mgr)
+ : device_mgr_(device_mgr), local_(NewLocalRendezvous()) {}
+
+IntraProcessRendezvous::~IntraProcessRendezvous() { local_->Unref(); }
+
+Status IntraProcessRendezvous::Send(const string& key,
+ const Rendezvous::Args& args,
+ const Tensor& val, const bool is_dead) {
+ VLOG(1) << "IntraProcessRendezvous Send " << this << " " << key;
+ {
+ mutex_lock l(mu_);
+ if (!status_.ok()) return status_;
+ }
+ Rendezvous::ParsedKey parsed;
+ TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, &parsed));
+
+ // Buffers "val" and "device_context" in local_.
+ return local_->Send(key, args, val, is_dead);
+}
+
+Status IntraProcessRendezvous::ParseKey(const string& key, bool is_src,
+ Rendezvous::ParsedKey* parsed) {
+ {
+ mutex_lock l(mu_);
+ if (!status_.ok()) return status_;
+ }
+ TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, parsed));
+ return Status::OK();
+}
+
+void IntraProcessRendezvous::SameWorkerRecvDone(
+ const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& send_args,
+ const Rendezvous::Args& recv_args, const Tensor& in, Tensor* out,
+ StatusCallback done) {
+ // Do a quick copy (sharing the underlying buffer) if both tensors
+ // are on host memory.
+ const bool src_host =
+ (send_args.alloc_attrs.on_host() || parsed.src.type == "CPU");
+ const bool dst_host =
+ (recv_args.alloc_attrs.on_host() || parsed.dst.type == "CPU");
+ if (src_host && dst_host) {
+ *out = in;
+ done(Status::OK());
+ return;
+ }
+
+ // This copy must involve a non-CPU device. Hence, "in" must support DMA
+ // (e.g., string tensors do not work on GPU).
+ if (!DataTypeCanUseMemcpy(in.dtype())) {
+ done(errors::InvalidArgument("Non-DMA-safe ", DataTypeString(in.dtype()),
+ " tensor may not be copied from/to a GPU."));
+ return;
+ }
+
+ Device* src_device;
+ Status s = device_mgr_->LookupDevice(parsed.src_device, &src_device);
+ if (!s.ok()) {
+ done(s);
+ return;
+ }
+ Device* dst_device;
+ s = device_mgr_->LookupDevice(parsed.dst_device, &dst_device);
+ if (!s.ok()) {
+ done(s);
+ return;
+ }
+
+ AllocatorAttributes attr = recv_args.alloc_attrs;
+ attr.set_gpu_compatible(send_args.alloc_attrs.gpu_compatible() ||
+ recv_args.alloc_attrs.gpu_compatible());
+ Allocator* out_allocator = dst_device->GetAllocator(attr);
+ Tensor copy(out_allocator, in.dtype(), in.shape());
+ *out = copy;
+
+ CopyTensorBetweenDevicesFunc(parsed.edge_name, send_args.device_context,
+ recv_args.device_context, src_device, dst_device,
+ send_args.alloc_attrs, recv_args.alloc_attrs,
+ &in, out, done);
+}
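+
+// [Editorial sketch, not part of the original change.] The decision made at
+// the top of SameWorkerRecvDone(), reduced to its two inputs: when both
+// sides are host memory the tensor buffer is shared in O(1); otherwise a
+// destination buffer is allocated and a DMA copy is issued. Hypothetical
+// helper, for illustration only.
+namespace docs_sketch {
+enum class CopyKind { kShareBuffer, kDmaCopy };
+
+inline CopyKind ChooseCopyKind(bool src_on_host, bool dst_on_host) {
+ return (src_on_host && dst_on_host) ? CopyKind::kShareBuffer
+ : CopyKind::kDmaCopy;
+}
+} // namespace docs_sketch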
+
+void IntraProcessRendezvous::RecvAsync(const string& key,
+ const Rendezvous::Args& recv_args,
+ DoneCallback done) {
+ VLOG(1) << "IntraProcessRendezvous Recv " << this << " " << key;
+
+ Rendezvous::ParsedKey parsed;
+ Status s = ParseKey(key, false /*!is_src*/, &parsed);
+ if (!s.ok()) {
+ done(s, Args(), recv_args, Tensor(), false);
+ return;
+ }
+
+ // Recv the tensor from local_.
+ local_->RecvAsync(key, recv_args, [this, parsed, done](
+ const Status& status,
+ const Rendezvous::Args& send_args,
+ const Rendezvous::Args& recv_args,
+ const Tensor& in, bool is_dead) {
+ Status s = status;
+ Tensor* out = new Tensor;
+ StatusCallback final_callback = [done, send_args, recv_args, out,
+ is_dead](const Status& s) {
+ done(s, send_args, recv_args, *out, is_dead);
+ delete out;
+ };
+
+ if (s.ok()) {
+ SameWorkerRecvDone(parsed, send_args, recv_args, in, out, final_callback);
+ } else {
+ final_callback(s);
+ }
+ });
+}
+
+void IntraProcessRendezvous::StartAbort(const Status& s) {
+ CHECK(!s.ok());
+ local_->StartAbort(s);
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/common_runtime/rendezvous_mgr.h b/tensorflow/core/common_runtime/rendezvous_mgr.h
new file mode 100644
index 0000000000..eaae65f956
--- /dev/null
+++ b/tensorflow/core/common_runtime/rendezvous_mgr.h
@@ -0,0 +1,73 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_RENDEZVOUS_MGR_H_
+#define TENSORFLOW_COMMON_RUNTIME_RENDEZVOUS_MGR_H_
+
+#include <string>
+#include <unordered_map>
+
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+
+// IntraProcessRendezvous is a Rendezvous which expects all producers
+// and consumers to be devices immediately accessible within the
+// process. That is, it will never be necessary to perform an RPC to
+// communicate with either.
+//
+// Buffering of Tensor values is delegated to a "local" Rendezvous
+// obtained from NewLocalRendezvous(). This class just adds
+// functionality to coordinate multiple process-local devices.
+class IntraProcessRendezvous : public Rendezvous {
+ public:
+ explicit IntraProcessRendezvous(const DeviceMgr* device_mgr);
+
+ // Forwards to local_, where the Tensor "val" will be buffered and
+ // any waiting callback stored.
+ Status Send(const string& key, const Rendezvous::Args& args,
+ const Tensor& val, const bool is_dead) override;
+
+ // This method is called only by the RecvOp. It forwards the
+ // request to local_ and, once the tensor is produced, copies it
+ // to the consumer's device if necessary.
+ void RecvAsync(const string& key, const Rendezvous::Args& args,
+ DoneCallback done) override;
+
+ void StartAbort(const Status& status) override;
+
+ private:
+ const DeviceMgr* device_mgr_;
+ Rendezvous* local_; // Owns a Ref on this object.
+
+ mutable mutex mu_;
+
+ // Status given by StartAbort() if any.
+ Status status_ GUARDED_BY(mu_);
+
+ ~IntraProcessRendezvous() override;
+
+ // Parses "key" into "parsed". If "is_src" is true, checks that the
+ // rendezvous key's source is in this process. If "is_src" is false,
+ // checks that the rendezvous key's destination is in this process.
+ Status ParseKey(const string& key, bool is_src,
+ Rendezvous::ParsedKey* parsed);
+
+ // Callback handling the case when a rendezvous has been
+ // accomplished in local_ and the consumer is local to this process.
+ // Tensor "in" will be copied into "out". The key "parsed" encodes
+ // the src and dst devices.
+ typedef std::function<void(const Status&)> StatusCallback;
+ void SameWorkerRecvDone(const Rendezvous::ParsedKey& parsed,
+ const Rendezvous::Args& send_args,
+ const Rendezvous::Args& recv_args, const Tensor& in,
+ Tensor* out, StatusCallback done);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(IntraProcessRendezvous);
+};
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_RENDEZVOUS_MGR_H_
diff --git a/tensorflow/core/common_runtime/session.cc b/tensorflow/core/common_runtime/session.cc
new file mode 100644
index 0000000000..6d1ab5cea4
--- /dev/null
+++ b/tensorflow/core/common_runtime/session.cc
@@ -0,0 +1,51 @@
+#include <string>
+
+#include "tensorflow/core/common_runtime/session_factory.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/session.h"
+
+namespace tensorflow {
+
+namespace {
+Status GetFactory(const SessionOptions& options, SessionFactory** ret) {
+ string runtime_type = "LOCAL_SESSION";
+ if (!options.target.empty()) {
+ // Use the service-based (remote) session.
+ runtime_type = "REMOTE_SESSION";
+ }
+ *ret = SessionFactory::GetFactory(runtime_type);
+ if (!*ret) {
+ return errors::NotFound("Could not find session factory for ",
+ runtime_type);
+ }
+ return Status::OK();
+}
+} // end namespace
+
+Session* NewSession(const SessionOptions& options) {
+ SessionFactory* factory;
+ Status s = GetFactory(options, &factory);
+ if (!s.ok()) {
+ LOG(ERROR) << s;
+ return nullptr;
+ }
+ return factory->NewSession(options);
+}
+
+Status NewSession(const SessionOptions& options, Session** out_session) {
+ SessionFactory* factory;
+ Status s = GetFactory(options, &factory);
+ if (!s.ok()) {
+ *out_session = nullptr;
+ LOG(ERROR) << s;
+ return s;
+ }
+ *out_session = factory->NewSession(options);
+ if (!*out_session) {
+ return errors::Internal("Failed to create session.");
+ }
+ return Status::OK();
+}
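+
+// [Editorial sketch, not part of the original change.] A typical call site
+// for the status-returning overload: leaving SessionOptions::target empty
+// selects the LOCAL_SESSION factory registered by the common runtime.
+namespace docs_sketch {
+inline Status CreateLocalSessionForDocs(Session** out_session) {
+ SessionOptions options; // Empty 'target' picks the local runtime.
+ return NewSession(options, out_session);
+}
+} // namespace docs_sketch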
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/session_factory.cc b/tensorflow/core/common_runtime/session_factory.cc
new file mode 100644
index 0000000000..666b99812d
--- /dev/null
+++ b/tensorflow/core/common_runtime/session_factory.cc
@@ -0,0 +1,41 @@
+#include "tensorflow/core/common_runtime/session_factory.h"
+
+#include <unordered_map>
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+namespace tensorflow {
+namespace {
+
+static mutex* get_session_factory_lock() {
+ static mutex session_factory_lock;
+ return &session_factory_lock;
+}
+
+typedef std::unordered_map<string, SessionFactory*> SessionFactories;
+SessionFactories* session_factories() {
+ static SessionFactories* factories = new SessionFactories;
+ return factories;
+}
+
+} // namespace
+
+void SessionFactory::Register(const string& runtime_type,
+ SessionFactory* factory) {
+ mutex_lock l(*get_session_factory_lock());
+ if (!session_factories()->insert({runtime_type, factory}).second) {
+ LOG(ERROR) << "Two session factories are being registered "
+ << "under" << runtime_type;
+ }
+}
+
+SessionFactory* SessionFactory::GetFactory(const string& runtime_type) {
+ mutex_lock l(*get_session_factory_lock()); // could use reader lock
+ auto it = session_factories()->find(runtime_type);
+ if (it == session_factories()->end()) {
+ return nullptr;
+ }
+ return it->second;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/session_factory.h b/tensorflow/core/common_runtime/session_factory.h
new file mode 100644
index 0000000000..f770ba93ff
--- /dev/null
+++ b/tensorflow/core/common_runtime/session_factory.h
@@ -0,0 +1,25 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_SESSION_FACTORY_H_
+#define TENSORFLOW_COMMON_RUNTIME_SESSION_FACTORY_H_
+
+#include <string>
+
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+class Session;
+class SessionOptions;
+
+class SessionFactory {
+ public:
+ virtual Session* NewSession(const SessionOptions& options) = 0;
+ virtual ~SessionFactory() {}
+ static void Register(const string& runtime_type, SessionFactory* factory);
+ static SessionFactory* GetFactory(const string& runtime_type);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_SESSION_FACTORY_H_
diff --git a/tensorflow/core/common_runtime/session_options.cc b/tensorflow/core/common_runtime/session_options.cc
new file mode 100644
index 0000000000..ef585efb5c
--- /dev/null
+++ b/tensorflow/core/common_runtime/session_options.cc
@@ -0,0 +1,9 @@
+#include "tensorflow/core/public/session_options.h"
+
+#include "tensorflow/core/public/env.h"
+
+namespace tensorflow {
+
+SessionOptions::SessionOptions() : env(Env::Default()) {}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/session_test.cc b/tensorflow/core/common_runtime/session_test.cc
new file mode 100644
index 0000000000..82b5d7ffb0
--- /dev/null
+++ b/tensorflow/core/common_runtime/session_test.cc
@@ -0,0 +1,17 @@
+#include "tensorflow/core/public/session.h"
+
+#include "tensorflow/core/public/session_options.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+TEST(SessionTest, InvalidTargetReturnsNull) {
+ SessionOptions options;
+ options.target = "invalid target";
+
+ EXPECT_EQ(nullptr, tensorflow::NewSession(options));
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/simple_placer.cc b/tensorflow/core/common_runtime/simple_placer.cc
new file mode 100644
index 0000000000..1cd1db29db
--- /dev/null
+++ b/tensorflow/core/common_runtime/simple_placer.cc
@@ -0,0 +1,559 @@
+#include "tensorflow/core/common_runtime/simple_placer.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+namespace tensorflow {
+
+namespace {
+
+// Returns a list of devices sorted by name from 'devices' whose type is in
+// 'supported_device_types'. This function searches in order of the device
+// types in 'supported_device_types' and returns the *first* subset of devices
+// that match.
+//
+// For example, if supported_device_types contains {GPU, CPU} and
+// 'devices' contains CPU and GPU devices, the returned vector will
+// include *only* GPU devices, since that is higher in the priority
+// order in 'supported_device_types'.
+std::vector<Device*> FilterSupportedDevices(
+ const std::vector<Device*>& devices,
+ const DeviceTypeVector& supported_device_types) {
+ std::vector<Device*> filtered_devices;
+ auto device_sort = [](const Device* a, const Device* b) {
+ return a->name() < b->name();
+ };
+ for (DeviceType d : supported_device_types) {
+ for (Device* device : devices) {
+ if (DeviceType(device->attributes().device_type()) == d) {
+ filtered_devices.emplace_back(device);
+ }
+ }
+
+ // If there are any devices under this device type, return this
+ // subset.
+ if (!filtered_devices.empty()) {
+ std::sort(filtered_devices.begin(), filtered_devices.end(), device_sort);
+ return filtered_devices;
+ }
+ }
+
+ std::sort(filtered_devices.begin(), filtered_devices.end(), device_sort);
+ return filtered_devices;
+}
+
+bool HasColocatedNodeName(const Node& node) {
+ return StringPiece(node.def().device()).starts_with("@");
+}
+
+Status ParseColocatedNodeName(const Node& node,
+ string* out_colocated_node_name) {
+ StringPiece device(node.def().device());
+ if (!device.Consume("@")) {
+ return errors::InvalidArgument("Malformed colocated node name: '", device,
+ "'");
+ }
+ // TODO(mrry): Validate that the node name is a valid node name.
+ *out_colocated_node_name = device.ToString();
+ return Status::OK();
+}
+
+// This class maintains the connected components of a colocation
+// constraint graph, and uses this information to assign a satisfying
+// device placement to the nodes of the graph.
+//
+// The typical usage pattern is:
+//
+// Graph graph = ...;
+// DeviceSet device_set = ...;
+// ColocationGraph colocation_graph(graph, device_set);
+//
+// // Add all the nodes of graph to colocation_graph.
+// for (Node* node : graph.nodes()) {
+// TF_RETURN_IF_ERROR(colocation_graph.AddNode(*node));
+// }
+//
+// // Add one or more colocation constraint.
+// Node node_1 = *graph.FindNodeId(...);
+// Node node_2 = *graph.FindNodeId(...);
+// TF_RETURN_IF_ERROR(colocation_graph.ColocateNodes(node_1, node_2));
+//
+// // Assign devices based on the accumulated constraints.
+// for (Node* node : graph.nodes()) {
+// TF_RETURN_IF_ERROR(colocation_graph.AssignDevice(node));
+// }
+//
+// The implementation uses the union-find algorithm to maintain the
+// connected components efficiently and incrementally as edges
+// (implied by ColocationGraph::ColocateNodes() invocations) are added.
+class ColocationGraph {
+ public:
+ ColocationGraph(Graph* graph, const DeviceSet* device_set,
+ const SessionOptions* options)
+ : device_set_(device_set),
+ device_types_(device_set->PrioritizedDeviceTypeList()),
+ options_(options) {
+ members_.reserve(graph->num_node_ids());
+ }
+
+ // Adds the given node to this ColocationGraph as a singleton.
+ //
+ // NOTE: The implementation assumes that the ids of nodes passed to
+ // this method are dense and zero-based; the memory used will be linear in
+ // the largest node ID.
+ // NOTE: If this method returns an error, *this is left in an undefined
+ // state.
+ Status AddNode(const Node& node) {
+ Member member;
+ TF_RETURN_IF_ERROR(InitializeMember(node, &member));
+ CHECK_GE(member.parent, 0);
+ if (static_cast<int>(members_.size()) <= member.parent) {
+ members_.resize(member.parent + 1);
+ }
+ members_[member.parent] = std::move(member);
+ return Status::OK();
+ }
+
+ // Merge the (possibly disjoint) sets containing nodes "x" and
+ // "y". Returns OK if the all nodes in the union of these sets can
+ // be placed on the same device type.
+ //
+ // NOTE: If this method returns an error, *this is left in an undefined
+ // state.
+ Status ColocateNodes(const Node& x, const Node& y) {
+ int x_root = FindRoot(x.id());
+ int y_root = FindRoot(y.id());
+ if (x_root != y_root) {
+ // Merge the sets by swinging the parent pointer of the smaller
+ // tree to point to the root of the larger tree. Together with
+ // path compression in ColocationGraph::FindRoot, this ensures
+ // that we do not experience pathological performance on graphs
+ // such as chains.
+ int new_root, old_root;
+ if (members_[x_root].rank < members_[y_root].rank) {
+ // The tree rooted at x_root is shallower, so connect it to
+ // y_root. The rank of y_root is unchanged because its new
+ // child has strictly less rank.
+ members_[x_root].parent = y_root;
+ new_root = y_root;
+ old_root = x_root;
+ } else if (members_[x_root].rank > members_[y_root].rank) {
+ // The tree rooted at y_root is shallower, so connect it to
+ // x_root. The rank of x_root is unchanged because its new
+ // child has strictly less rank.
+ members_[y_root].parent = x_root;
+ new_root = x_root;
+ old_root = y_root;
+ } else {
+ // Both trees have the same rank, so break the tie by choosing
+ // x_root as the new root.
+ members_[y_root].parent = x_root;
+ // Increment the rank of the tree rooted at x_root, because it
+ // is now strictly deeper than before.
+ ++members_[x_root].rank;
+ new_root = x_root;
+ old_root = y_root;
+ }
+
+ // Merge the partial device specifications, and ensure that they are
+ // compatible. NULL options_ is treated as allowing soft placement.
+ // TODO(mrry): Consider enriching the error message by pointing
+ // out which nodes have the explicit partial device
+ // specifications that caused this conflict.
+ TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
+ &members_[new_root].device_name, members_[old_root].device_name,
+ options_ == nullptr || options_->config.allow_soft_placement()));
+
+ // Ensure that the common root has at least one supported device
+ // type, by computing the intersection of
+ // members_[new_root].supported_device_types and
+ // members_[old_root].supported_device_types.
+ MergeSupportedDevices(&members_[new_root].supported_device_types,
+ members_[old_root].supported_device_types);
+ if (members_[new_root].supported_device_types.empty()) {
+ return errors::InvalidArgument(
+ "Cannot colocate nodes '", x.name(), "' and '", y.name(),
+ "' because no device type supports both of those nodes and the "
+ "other nodes colocated with them");
+ }
+ }
+ return Status::OK();
+ }
+
+ // For the given node, subject to the constraints previously given
+ // to this ColocationGraph, set its assigned_device_name. Returns OK
+ // if a satisfying device can be found, otherwise an error.
+ Status AssignDevice(Node* node) {
+ int node_root = FindRoot(node->id());
+ if (members_[node_root].assigned_device == nullptr) {
+ // We have not yet assigned a device for the colocated node set containing
+ // n, so we do so now using the constraints on the root node.
+
+ // "devices" will contain the set of feasible placements for the
+ // colocated node set containing n.
+ std::vector<Device*> devices;
+ if (DeviceNameUtils::HasSomeDetails(members_[node_root].device_name)) {
+ // The root node has a (possibly partial) device
+ // specification, so enumerate the physical devices that
+ // conform to it.
+ device_set_->FindMatchingDevices(members_[node_root].device_name,
+ &devices);
+
+ if (!devices.empty()) {
+ // Filter devices into those that are compatible with the root
+ // node (and its children).
+ devices = FilterSupportedDevices(
+ devices, members_[node_root].supported_device_types);
+ }
+
+ // Perform soft placement if allow_soft_placement is set. options_
+ // being NULL is treated as allowing soft placement.
+ if (devices.empty() &&
+ (options_ == nullptr || options_->config.allow_soft_placement())) {
+ // The soft_device_name is the same as the node's device name
+ // without specifying the device type or ID.
+ DeviceNameUtils::ParsedName soft_device_name =
+ members_[node_root].device_name;
+ soft_device_name.type.clear();
+ soft_device_name.has_type = false;
+ soft_device_name.has_id = false;
+ device_set_->FindMatchingDevices(soft_device_name, &devices);
+ if (!devices.empty()) {
+ devices = FilterSupportedDevices(
+ devices, members_[node_root].supported_device_types);
+ }
+ }
+
+ if (devices.empty()) {
+ // Return an error when a physical device that matches an explicit
+ // device specification is not found. This ensures that we don't
+ // assign a node to GPU when the user wanted to force it on CPU.
+ DeviceNameUtils::ParsedName specified_device_name;
+ if (DeviceNameUtils::ParseFullName(node->def().device(),
+ &specified_device_name) &&
+ specified_device_name == members_[node_root].device_name) {
+ // The specified device and merged set device match, and
+ // will appear in the GraphDef (for debugging), so just
+ // print the specified device.
+ return errors::InvalidArgument(
+ "Could not satisfy explicit device specification '",
+ node->def().device(), "'");
+ } else {
+ // The specified device may be a valid device but the
+ // merged set device is different, so print both.
+ return errors::InvalidArgument(
+ "Could not satisfy explicit device specification '",
+ node->def().device(),
+ "' because the node was colocated with a group of nodes that "
+ "required incompatible device '",
+ DeviceNameUtils::ParsedNameToString(
+ members_[node_root].device_name),
+ "'");
+ }
+ }
+ } else {
+ // The device is completely unspecified, so enumerate the devices that
+ // support all of the nodes in the set.
+ if (device_set_->devices().empty()) {
+ return errors::Internal("No devices are registered");
+ }
+ devices = FilterSupportedDevices(
+ device_set_->devices(), members_[node_root].supported_device_types);
+
+ if (devices.empty()) {
+ return errors::InvalidArgument(
+ "Node had no OpKernel registered to support this operation: ",
+ "Operation was ", node->type_string(), " and inputs were ",
+ DataTypeVectorString(node->input_types()));
+ }
+ }
+
+ // Choose the first device in the sorted device list so that
+ // placement is deterministic across runs.
+ members_[node_root].assigned_device = devices[0];
+ }
+ node->set_assigned_device_name(members_[node_root].assigned_device->name());
+
+ // Log placement if log_device_placement is set.
+ if (options_ && options_->config.log_device_placement()) {
+ printf("%s: %s\n", node->name().c_str(),
+ node->assigned_device_name().c_str());
+ LOG(INFO) << node->name() << ": " << node->assigned_device_name();
+ }
+
+ return Status::OK();
+ }
+
+ private:
+ // Represents a node in the disjoint node set forest, and the
+ // accumulated constraints on the device used by that node.
+ struct Member {
+ Member() = default;
+ // The id of the node that is the parent of this one, or its own
+ // id if it is a root. parent < 0 indicates that this member is invalid.
+ int parent = -1;
+ // A proxy for the depth of the tree that is used to prefer
+ // connecting smaller trees to larger trees when merging disjoint
+ // sets.
+ int rank = 0;
+ // The intersection of all device types supported by this node,
+ // and those of all of its children, in priority order
+ // of the preferred device.
+ DeviceTypeVector supported_device_types;
+ // The merged form of the device requested for this node, with
+ // those of all of its children.
+ DeviceNameUtils::ParsedName device_name;
+ // If this node is a root, stores the Device to which this node
+ // and all of its children have been assigned, or nullptr if this
+ // has not yet been computed by GetAssignedDevice().
+ Device* assigned_device = nullptr;
+ };
+
+ Status InitializeMember(const Node& node, Member* member) {
+ const int id = node.id();
+ if (id < 0) {
+ return errors::InvalidArgument("Node id was not positive: ", id);
+ }
+ member->parent = id;
+ TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode(
+ device_types_, node.def(), &member->supported_device_types));
+
+ if (!node.assigned_device_name().empty()) {
+ // This node has already been assigned to a device, so we
+ // respect this placement, after sanity-checking it. The
+ // device_name and supported_device_types for this node reflect
+ // the assigned device, so any nodes colocated with this node
+ // will be assigned to the same device (assuming this is
+ // possible).
+ // NOTE: Since any assignment must have been performed by
+ // the TensorFlow runtime, we consider errors in this branch to
+ // be INTERNAL.
+ if (!DeviceNameUtils::ParseFullName(node.assigned_device_name(),
+ &member->device_name)) {
+ return errors::Internal("Malformed assigned device '",
+ node.assigned_device_name(), "'");
+ }
+ std::vector<Device*> devices;
+ const Device* assigned_device =
+ device_set_->FindDeviceByName(node.assigned_device_name());
+ if (assigned_device == nullptr) {
+ return errors::Internal("Assigned device '",
+ node.assigned_device_name(),
+ "' does not match any device");
+ }
+
+ for (DeviceType d : member->supported_device_types) {
+ if (DeviceType(assigned_device->attributes().device_type()) == d) {
+ return Status::OK();
+ }
+ }
+
+ return errors::Internal("Assigned device '", node.assigned_device_name(),
+ "' does not have registered OpKernel support "
+ "for ",
+ node.def().op());
+ } else {
+ // This node has not yet been assigned to a device, so we
+ // calculate any constraints due to the set of registered
+ // kernels and any (partial) user-provided device specification
+ // in the NodeDef.
+
+ // If no kernels are registered for this op type, fail with an error.
+ if (member->supported_device_types.empty()) {
+ return errors::InvalidArgument(
+ "No OpKernel was registered to support "
+ "Op '",
+ node.def().op(), "' with these attrs");
+ }
+
+      // If the NodeDef contains a device that is *not* a colocated node name
+      // (i.e. it does not begin with '@'), then we interpret it as a (partial)
+      // device specification.
+      if (!node.def().device().empty() && !HasColocatedNodeName(node)) {
+ // The user has specified a device in the NodeDef, try to find a
+ // valid device matching their specification in the set of
+ // devices.
+        // NOTE: The full name may specify a device that is not in
+        // member->supported_device_types, but we check that in
+        // AssignDevice().
+ if (!DeviceNameUtils::ParseFullName(node.def().device(),
+ &member->device_name)) {
+ return errors::InvalidArgument("Malformed device specification '",
+ node.def().device(), "'");
+ }
+ }
+ }
+ return Status::OK();
+ }
+
+  // Updates "*target" to contain the intersection of the device types
+  // in "*target" and "other", preserving the priority order of the
+  // types in "*target".
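+  // For example (illustrative values only): if *target is {GPU, CPU}
+  // and other is {CPU}, *target becomes {CPU}; if other is {CPU, GPU},
+  // *target stays {GPU, CPU}, keeping its original priority order.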
+ static void MergeSupportedDevices(DeviceTypeVector* target,
+ const DeviceTypeVector& other) {
+ DeviceTypeVector temp = *target;
+ target->clear();
+
+ // Iterate in priority order.
+ for (DeviceType device_type : temp) {
+ bool found = false;
+ for (DeviceType other_device_type : other) {
+ if (device_type == other_device_type) {
+ found = true;
+ break;
+ }
+ }
+ if (found) {
+ target->push_back(device_type);
+ }
+ }
+ }
+
+ // Returns the root node of the disjoint tree to which the node with the
+ // given id is connected.
+ int FindRoot(int node_id) {
+ DCHECK_GE(members_[node_id].parent, 0);
+ if (members_[node_id].parent != node_id) {
+ // NOTE: Compress paths from node_id to its root, so that future
+ // calls to FindRoot and ColocateNodes are more efficient.
+ members_[node_id].parent = FindRoot(members_[node_id].parent);
+ }
+ return members_[node_id].parent;
+ }
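+
+  // For reference, a minimal standalone sketch of the disjoint-set
+  // ("union-find") pattern that ColocateNodes and FindRoot implement.
+  // The names below are hypothetical and the code is illustrative only
+  // (it assumes <vector> and <utility>):
+  //
+  //   int Find(std::vector<int>* parent, int i) {
+  //     if ((*parent)[i] != i) {
+  //       (*parent)[i] = Find(parent, (*parent)[i]);  // Path compression.
+  //     }
+  //     return (*parent)[i];
+  //   }
+  //
+  //   void Union(std::vector<int>* parent, std::vector<int>* rank, int a,
+  //              int b) {
+  //     a = Find(parent, a);
+  //     b = Find(parent, b);
+  //     if (a == b) return;
+  //     if ((*rank)[a] < (*rank)[b]) std::swap(a, b);
+  //     (*parent)[b] = a;  // Attach the smaller tree under the larger one.
+  //     if ((*rank)[a] == (*rank)[b]) ++(*rank)[a];
+  //   }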
+
+ std::vector<Member> members_;
+ const DeviceSet* device_set_; // Not owned.
+ const std::vector<DeviceType> device_types_;
+  const SessionOptions* options_;  // Not owned.
+};
+
+} // namespace
+
+SimplePlacer::SimplePlacer(Graph* graph, const DeviceSet* devices,
+ const NodeNameToIdMap* name_to_id_map,
+ const SessionOptions* options)
+ : graph_(graph),
+ devices_(devices),
+ name_to_id_map_(name_to_id_map),
+ options_(options) {}
+
+SimplePlacer::SimplePlacer(Graph* graph, const DeviceSet* devices,
+                           const NodeNameToIdMap* name_to_id_map)
+    : graph_(graph),
+      devices_(devices),
+      name_to_id_map_(name_to_id_map),
+      options_(nullptr) {}
+
+SimplePlacer::~SimplePlacer() {}
+
+Status SimplePlacer::Run() {
+ if (devices_->devices().empty()) {
+ return errors::FailedPrecondition("No devices are registered");
+ }
+
+ ColocationGraph colocation_graph(graph_, devices_, options_);
+ Status status;
+
+  // 1. First add all of the nodes. Note that steps (1) and (2)
+  // require two passes over the nodes because the graph (and hence
+  // the constraints) may not be acyclic.
+ for (Node* node : graph_->nodes()) {
+ // Skip the source and sink nodes.
+ if (!node->IsOp()) {
+ continue;
+ }
+ status = colocation_graph.AddNode(*node);
+ if (!status.ok()) return AttachDef(status, node->def());
+ }
+
+ // 2. Enumerate the constraint edges, and use them to update the disjoint
+ // node set.
+ for (Node* node : graph_->nodes()) {
+ if (!node->IsOp()) {
+ continue;
+ }
+
+    // 2(a). If `node` specifies a colocation constraint as its device
+    // name, add an edge from the colocated node to `node`.
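+    // For example, a node whose requested device is "@var_0" is
+    // constrained to be placed on the same device as the node named
+    // "var_0".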
+ if (HasColocatedNodeName(*node)) {
+ string colocated_node_name;
+ status = ParseColocatedNodeName(*node, &colocated_node_name);
+ if (!status.ok()) {
+ return AttachDef(status, node->def());
+ }
+ Node* colocated_node;
+ status = GetNodeByName(colocated_node_name, &colocated_node);
+ if (!status.ok()) {
+ return AttachDef(
+ errors::InvalidArgument("Colocated node named in device '",
+ colocated_node_name, "' does not exist"),
+ node->def());
+ }
+ status = colocation_graph.ColocateNodes(*colocated_node, *node);
+ if (!status.ok()) {
+ return AttachDef(
+ errors::InvalidArgument(
+ "Cannot satisfy colocation constraint named in device '",
+ colocated_node_name, "': ", status.error_message()),
+ node->def());
+ }
+ }
+
+ // 2(b). If `node` has an input edge with reference type, add an
+ // edge from the source of that edge to `node`.
+ for (const auto& edge : node->in_edges()) {
+ if (!edge->IsControlEdge() &&
+ IsRefType(node->input_type(edge->dst_input()))) {
+ status = colocation_graph.ColocateNodes(*edge->src(), *node);
+ if (!status.ok()) {
+ return AttachDef(
+ errors::InvalidArgument("Cannot satisfy colocation constraint "
+ "implied by reference connection: ",
+ status.error_message()),
+ node->def());
+ }
+ }
+ }
+ }
+
+ // 3. For each node, assign a device based on the constraints in the
+ // disjoint node set.
+ for (Node* node : graph_->nodes()) {
+ // Skip the source and sink nodes.
+ if (!node->IsOp()) {
+ continue;
+ }
+    // Skip nodes that already have an assigned device name.
+ if (!node->assigned_device_name().empty()) {
+ continue;
+ }
+
+ status = colocation_graph.AssignDevice(node);
+ if (!status.ok()) {
+ return AttachDef(
+ errors::InvalidArgument("Cannot assign a device to node '",
+ node->name(), "': ", status.error_message()),
+ node->def());
+ }
+ }
+ return Status::OK();
+}
+
+Status SimplePlacer::GetNodeByName(const string& name, Node** out_node) const {
+ NodeNameToIdMap::const_iterator iter = name_to_id_map_->find(name);
+ if (iter != name_to_id_map_->end()) {
+ *out_node = graph_->FindNodeId(iter->second);
+ if (*out_node) {
+ return Status::OK();
+ }
+ }
+ return errors::NotFound(name);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/simple_placer.h b/tensorflow/core/common_runtime/simple_placer.h
new file mode 100644
index 0000000000..4b3df50c72
--- /dev/null
+++ b/tensorflow/core/common_runtime/simple_placer.h
@@ -0,0 +1,81 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_SIMPLE_PLACER_H_
+#define TENSORFLOW_COMMON_RUNTIME_SIMPLE_PLACER_H_
+
+#include <string>
+#include <unordered_map>
+
+#include "tensorflow/core/common_runtime/device_set.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/util/device_name_utils.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+
+// A placement algorithm that assigns the nodes of the given Graph to
+// devices in the given DeviceSet, respecting the following constraints:
+//
+// 1. Existing device assignments remain unchanged.
+// 2. Requested (partial or complete) device specifications given in
+// the NodeDefs are granted.
+// 3. Nodes connected by edges of a reference type are colocated on
+// the same device.
+// 4. Given nodes "A" and "B", if node "B" has the device specification
+// "@A", nodes "A" and "B" will be colocated on the same device.
+//
+// The implementation builds a constraint graph with the same set of
+// nodes, and edges that represent colocation constraints between
+// nodes. Each connected component in the resulting constraint graph
+// is then assigned to a single device.
+//
+// TODO(mrry): "Soft" constraints, such as "place node 'x' as close as
+// possible to node 'y' while respecting the other constraints"?
+// TODO(mrry): Create a common interface for this and the other
+// placement algorithms so that they may be injected into the graph
+// builder.
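+//
+// A minimal usage sketch (illustrative only; constructing the graph,
+// populating the DeviceSet, and filling the name-to-id map are assumed
+// to happen elsewhere):
+//
+//   Graph graph(OpRegistry::Global());
+//   DeviceSet devices;
+//   SimplePlacer::NodeNameToIdMap name_to_id;
+//   // ... build the graph, add devices, fill name_to_id ...
+//   SimplePlacer placer(&graph, &devices, &name_to_id);
+//   Status s = placer.Run();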
+class SimplePlacer {
+ public:
+ // A map from graph node names to numerical IDs (in a Graph object).
+ typedef std::unordered_map<string, int> NodeNameToIdMap;
+
+ // Creates an instance of the SimplePlacer algorithm for the given
+ // Graph "graph" (nodes in which may or may not be assigned) on the
+ // given DeviceSet "devices". The "name_to_id_map" maps the names of
+  // nodes in "graph" to their numerical ID.
+ //
+ // REQUIRES: for all mappings (k, v) in "name_to_id_map",
+ // graph.FindNodeId(v)->name() == k.
+ //
+ // The "graph", "devices", and "name_to_id_map" pointer arguments
+ // are borrowed by this SimplePlacer, and must outlive it.
+ SimplePlacer(Graph* graph, const DeviceSet* devices,
+ const NodeNameToIdMap* name_to_id_map,
+ const SessionOptions* options);
+
+ SimplePlacer(Graph* graph, const DeviceSet* devices,
+ const NodeNameToIdMap* name_to_id_map);
+
+ ~SimplePlacer();
+
+ // Assigns each node in this SimplePlacer's graph to a device in its
+ // set of devices.
+ //
+ // This method is not thread-safe.
+ // Run() may be invoked at most once.
+ Status Run();
+
+ private:
+ Status GetNodeByName(const string& name, Node** out_node) const;
+
+ Graph* const graph_; // Not owned.
+ const DeviceSet* const devices_; // Not owned.
+ const NodeNameToIdMap* const name_to_id_map_; // Not owned.
+ const SessionOptions* options_; // Not owned.
+
+ TF_DISALLOW_COPY_AND_ASSIGN(SimplePlacer);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_SIMPLE_PLACER_H_
diff --git a/tensorflow/core/common_runtime/simple_placer_test.cc b/tensorflow/core/common_runtime/simple_placer_test.cc
new file mode 100644
index 0000000000..3139962d7e
--- /dev/null
+++ b/tensorflow/core/common_runtime/simple_placer_test.cc
@@ -0,0 +1,863 @@
+#include "tensorflow/core/common_runtime/simple_placer.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_set.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+namespace {
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// Op, kernel, and device registrations to set up the environment.
+//
+// The SimplePlacer uses information about the op (input types),
+// kernel (device constraints), and available devices to make
+// placement decisions. To avoid depending on the full runtime, we
+// define dummy implementations of these, and register them with the
+// runtime.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+// A dummy OpKernel that is used to register ops on different devices.
+class DummyOp : public OpKernel {
+ public:
+ explicit DummyOp(OpKernelConstruction* context) : OpKernel(context) {}
+ void Compute(OpKernelContext* context) override {}
+};
+
+// A fake device that has specific device attributes, used to simulate
+// the presence of a CPU or a GPU (without depending on that part of
+// the runtime).
+class FakeDevice : public Device {
+ private:
+ explicit FakeDevice(const DeviceAttributes& device_attributes)
+ : Device(nullptr, device_attributes, nullptr) {}
+
+ public:
+ Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); }
+
+ Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; }
+
+ static std::unique_ptr<Device> MakeCPU(const string& name) {
+ DeviceAttributes device_attributes;
+ device_attributes.set_name(name);
+ device_attributes.set_device_type(DeviceType(DEVICE_CPU).type());
+ return std::unique_ptr<Device>(new FakeDevice(device_attributes));
+ }
+
+ static std::unique_ptr<Device> MakeGPU(const string& name) {
+ DeviceAttributes device_attributes;
+ device_attributes.set_name(name);
+ device_attributes.set_device_type(DeviceType(DEVICE_GPU).type());
+ return std::unique_ptr<Device>(new FakeDevice(device_attributes));
+ }
+};
+
+// Register the following ops so they can be added to a Graph, and
+// kernels so that they can be placed on particular device types.
+REGISTER_OP("TestVariable").Output("o: Ref(float)");
+REGISTER_KERNEL_BUILDER(Name("TestVariable").Device(DEVICE_CPU), DummyOp);
+REGISTER_KERNEL_BUILDER(Name("TestVariable").Device(DEVICE_GPU), DummyOp);
+
+REGISTER_OP("VariableCPU").Output("o: Ref(float)");
+REGISTER_KERNEL_BUILDER(Name("VariableCPU").Device(DEVICE_CPU), DummyOp);
+
+REGISTER_OP("VariableGPU").Output("o: Ref(float)");
+REGISTER_KERNEL_BUILDER(Name("VariableGPU").Device(DEVICE_GPU), DummyOp);
+
+REGISTER_OP("VariableNoKernels").Output("o: Ref(float)");
+
+REGISTER_OP("TestAdd").Input("a: float").Input("b: float").Output("o: float");
+REGISTER_KERNEL_BUILDER(Name("TestAdd").Device(DEVICE_CPU), DummyOp);
+REGISTER_KERNEL_BUILDER(Name("TestAdd").Device(DEVICE_GPU), DummyOp);
+
+REGISTER_OP("TestRelu").Input("i: float").Output("o: float");
+REGISTER_KERNEL_BUILDER(Name("TestRelu").Device(DEVICE_CPU), DummyOp);
+REGISTER_KERNEL_BUILDER(Name("TestRelu").Device(DEVICE_GPU), DummyOp);
+
+REGISTER_OP("ReluGPU").Input("i: float").Output("o: float");
+REGISTER_KERNEL_BUILDER(Name("ReluGPU").Device(DEVICE_GPU), DummyOp);
+
+REGISTER_OP("TestAssign").Input("i: Ref(float)").Input("v: float");
+REGISTER_KERNEL_BUILDER(Name("TestAssign").Device(DEVICE_CPU), DummyOp);
+REGISTER_KERNEL_BUILDER(Name("TestAssign").Device(DEVICE_GPU), DummyOp);
+
+REGISTER_OP("AssignCPU").Input("i: Ref(float)").Input("v: float");
+REGISTER_KERNEL_BUILDER(Name("AssignCPU").Device(DEVICE_CPU), DummyOp);
+
+REGISTER_OP("AssignGPU").Input("i: Ref(float)").Input("v: float");
+REGISTER_KERNEL_BUILDER(Name("AssignGPU").Device(DEVICE_GPU), DummyOp);
+
+REGISTER_OP("TestInput").Output("a: float").Output("b: float");
+REGISTER_KERNEL_BUILDER(Name("TestInput").Device(DEVICE_CPU), DummyOp);
+
+REGISTER_OP("TestDevice").Output("a: float").Output("b: float");
+REGISTER_KERNEL_BUILDER(Name("TestDevice").Device(DEVICE_GPU), DummyOp);
+
+REGISTER_OP("TestDeviceEnforce").Input("a: Ref(float)").Output("b: float");
+REGISTER_KERNEL_BUILDER(Name("TestDeviceEnforce").Device(DEVICE_CPU), DummyOp);
+REGISTER_KERNEL_BUILDER(Name("TestDeviceEnforce").Device(DEVICE_GPU), DummyOp);
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// A SimplePlacerTest method has three phases:
+//
+// 1. Build a TensorFlow graph, with no (or partial) device assignments.
+// 2. Attempt to compute a placement using the SimplePlacer.
+// 3. EITHER: test that the constraints implied by the graph are respected;
+// or that an appropriate error was reported.
+//
+////////////////////////////////////////////////////////////////////////////////
+class SimplePlacerTest : public ::testing::Test {
+ protected:
+ SimplePlacerTest() {
+ RequireDefaultOps();
+ // Build a set of 10 GPU and 10 CPU devices.
+ // NOTE: this->local_devices_ owns the device objects;
+ // this->devices_ contains borrowed pointers to the device
+ // objects.
+ for (int i = 0; i < 10; ++i) {
+ local_devices_.emplace_back(FakeDevice::MakeCPU(
+ strings::StrCat("/job:a/replica:0/task:0/cpu:", i)));
+ devices_.AddDevice(local_devices_.back().get());
+ // Insert the GPUs in reverse order.
+ local_devices_.emplace_back(FakeDevice::MakeGPU(
+ strings::StrCat("/job:a/replica:0/task:0/gpu:", 9 - i)));
+ devices_.AddDevice(local_devices_.back().get());
+ }
+ }
+
+ // Builds the given graph, and (if successful) indexes the node
+ // names for use in placement, and later lookup.
+ Status BuildGraph(const GraphDefBuilder& builder, Graph* out_graph) {
+ TF_RETURN_IF_ERROR(builder.ToGraph(out_graph));
+ nodes_by_name_.clear();
+ for (Node* node : out_graph->nodes()) {
+ nodes_by_name_[node->name()] = node->id();
+ }
+ return Status::OK();
+ }
+
+ // Invokes the SimplePlacer on "graph". If no DeviceSet is specified, the
+ // placement will use the default DeviceSet (of 10 CPU and 10 GPU devices).
+ //
+ // REQUIRES: "*graph" was produced by the most recent call to BuildGraph.
+ Status Place(Graph* graph, DeviceSet* devices, SessionOptions* options) {
+ SimplePlacer placer(graph, devices, &nodes_by_name_, options);
+ return placer.Run();
+ }
+
+ Status Place(Graph* graph, DeviceSet* devices) {
+ return Place(graph, devices, nullptr);
+ }
+
+ Status Place(Graph* graph, SessionOptions* options) {
+ return Place(graph, &devices_, options);
+ }
+
+ Status Place(Graph* graph) { return Place(graph, &devices_, nullptr); }
+
+ // Returns the node in "graph" with the given name.
+ //
+ // REQUIRES: "graph" was produced by the most recent call to BuildGraph.
+ Node* GetNodeByName(const Graph& graph, const string& name) {
+ const auto search = nodes_by_name_.find(name);
+ CHECK(search != nodes_by_name_.end()) << "Unknown node name: " << name;
+ return graph.FindNodeId(search->second);
+ }
+
+ protected:
+ std::vector<std::unique_ptr<Device>> local_devices_;
+ DeviceSet devices_;
+ SimplePlacer::NodeNameToIdMap nodes_by_name_;
+
+ Status ReferenceTestHelper(const string& variable_op_type,
+ const string& assign_op_type,
+ DeviceType expected_device_type);
+};
+
+#define EXPECT_COLOCATED(g, name_a, name_b) \
+ do { \
+ Graph& g_ = (g); \
+ EXPECT_EQ(GetNodeByName(g_, (name_a))->assigned_device_name(), \
+ GetNodeByName(g_, (name_b))->assigned_device_name()); \
+ } while (0)
+
+#define EXPECT_DEVICE_TYPE(g, name, expected_device_type) \
+ EXPECT_EQ(DeviceType(expected_device_type).type(), \
+ devices_.FindDeviceByName( \
+ GetNodeByName((g), (name))->assigned_device_name()) \
+ ->attributes() \
+ .device_type())
+
+#define EXPECT_DEVICE_CONTAINS(g, name, device_substr) \
+ EXPECT_TRUE(StringPiece(GetNodeByName((g), (name))->assigned_device_name()) \
+ .contains(device_substr))
+
+// Test that a graph with no constraints will successfully assign nodes to the
+// "best available" device (i.e. prefer GPU over CPU).
+TEST_F(SimplePlacerTest, TestNoConstraints) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
+ ops::UnaryOp("TestRelu", ops::NodeOut(input, 0), b.opts().WithName("n1"));
+ ops::UnaryOp("TestRelu", ops::NodeOut(input, 1), b.opts().WithName("n2"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ EXPECT_OK(Place(&g));
+ EXPECT_DEVICE_TYPE(g, "in", DEVICE_CPU);
+ EXPECT_DEVICE_TYPE(g, "n1", DEVICE_GPU);
+ EXPECT_DEVICE_TYPE(g, "n2", DEVICE_GPU);
+}
+
+// Test that a graph with device type and reference constraints on
+// some of the ops will successfully assign nodes to the constrained
+// device, and colocate nodes with reference connections.
+TEST_F(SimplePlacerTest, TestDeviceTypeConstraints) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
+ Node* var_cpu = ops::SourceOp("VariableCPU", b.opts().WithName("var_cpu"));
+ ops::BinaryOp("AssignCPU", var_cpu, input, b.opts().WithName("assign_cpu"));
+ Node* var_gpu = ops::SourceOp("VariableGPU", b.opts().WithName("var_gpu"));
+ ops::BinaryOp("AssignGPU", var_gpu, input, b.opts().WithName("assign_gpu"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ EXPECT_OK(Place(&g));
+ EXPECT_DEVICE_TYPE(g, "in", DEVICE_CPU);
+ EXPECT_DEVICE_TYPE(g, "var_cpu", DEVICE_CPU);
+ EXPECT_DEVICE_TYPE(g, "assign_cpu", DEVICE_CPU);
+ EXPECT_COLOCATED(g, "var_cpu", "assign_cpu");
+ EXPECT_DEVICE_TYPE(g, "var_gpu", DEVICE_GPU);
+ EXPECT_DEVICE_TYPE(g, "assign_gpu", DEVICE_GPU);
+ EXPECT_COLOCATED(g, "var_gpu", "assign_gpu");
+}
+
+// Test that a graph with partial device specifications on the ops
+// will be successfully placed on devices that match those specifications.
+TEST_F(SimplePlacerTest, TestPartialSpec) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("TestInput", b.opts().WithName("in").WithDevice("/job:a"));
+ ops::SourceOp("TestVariable",
+ b.opts().WithName("var").WithDevice("/job:a"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ EXPECT_OK(Place(&g));
+ EXPECT_DEVICE_TYPE(g, "in", DEVICE_CPU);
+ EXPECT_DEVICE_CONTAINS(g, "in", "/job:a");
+ EXPECT_DEVICE_TYPE(g, "var", DEVICE_GPU);
+ EXPECT_DEVICE_CONTAINS(g, "var", "/job:a");
+}
+
+// Test that a node with an assigned device is not relocated.
+TEST_F(SimplePlacerTest, TestAssignedDevicePreserved) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("TestInput", b.opts().WithName("in"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ GetNodeByName(g, "in")
+ ->set_assigned_device_name("/job:a/replica:0/task:0/cpu:7");
+
+ EXPECT_OK(Place(&g));
+ EXPECT_EQ("/job:a/replica:0/task:0/cpu:7",
+ GetNodeByName(g, "in")->assigned_device_name());
+}
+
+// Test that, when soft placement is allowed, an op with a partial GPU
+// device specification but only a CPU kernel is relocated to a CPU device.
+TEST_F(SimplePlacerTest, TestPartialSpecGpuToCpu) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("TestInput", b.opts().WithName("in").WithDevice("/gpu:0"));
+ ops::SourceOp("TestVariable",
+ b.opts().WithName("var").WithDevice("/gpu:0"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ SessionOptions options;
+ options.config.set_allow_soft_placement(true);
+ EXPECT_OK(Place(&g, &options));
+ EXPECT_DEVICE_TYPE(g, "in", DEVICE_CPU);
+ EXPECT_DEVICE_CONTAINS(g, "in", "/cpu");
+ EXPECT_DEVICE_TYPE(g, "var", DEVICE_GPU);
+ EXPECT_DEVICE_CONTAINS(g, "var", "/gpu:0");
+}
+
+// Test that placement fails for a node that has been assigned a GPU
+// device but has no registered GPU OpKernel.
+TEST_F(SimplePlacerTest, TestAssignedGpuDeviceToCpuDevice) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("TestInput", b.opts().WithName("in"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ GetNodeByName(g, "in")
+ ->set_assigned_device_name("/job:a/replica:0/task:0/gpu:0");
+
+ Status s = Place(&g);
+ EXPECT_EQ(error::INTERNAL, s.code());
+ EXPECT_TRUE(
+ StringPiece(s.error_message())
+ .contains("Assigned device '/job:a/replica:0/task:0/gpu:0' "
+ "does not have registered OpKernel support for TestInput"));
+}
+
+// Test that graphs with reference connections are correctly placed.
+
+// Build a graph containing a Variable op of "variable_op_type" and an
+// Assign op of "assign_op_type", and expect all of the ops to be
+// placed on a device of type "expected_device_type".
+Status SimplePlacerTest::ReferenceTestHelper(const string& variable_op_type,
+ const string& assign_op_type,
+ DeviceType expected_device_type) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
+ // Build ten variable-and-assignment pairs.
+ for (int i = 0; i < 10; ++i) {
+ Node* var = ops::SourceOp(variable_op_type,
+ b.opts().WithName(strings::StrCat("var_", i)));
+ ops::BinaryOp(assign_op_type, var, input,
+ b.opts().WithName(strings::StrCat("assign_", i)));
+ }
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ TF_RETURN_IF_ERROR(Place(&g));
+
+ for (int i = 0; i < 10; ++i) {
+ EXPECT_COLOCATED(g, strings::StrCat("var_", i),
+ strings::StrCat("assign_", i));
+ EXPECT_DEVICE_TYPE(g, strings::StrCat("var_", i), expected_device_type);
+ EXPECT_DEVICE_TYPE(g, strings::StrCat("assign_", i), expected_device_type);
+ }
+
+ return Status::OK();
+}
+
+// Test all 3^2 combinations of Variable and Assignment op types
+// (unconstrained, CPU-only, and GPU-only).
+TEST_F(SimplePlacerTest, TestReferenceConnection) {
+ EXPECT_OK(ReferenceTestHelper("TestVariable", "TestAssign", DEVICE_GPU));
+ EXPECT_OK(ReferenceTestHelper("TestVariable", "AssignCPU", DEVICE_CPU));
+ EXPECT_OK(ReferenceTestHelper("TestVariable", "AssignGPU", DEVICE_GPU));
+ EXPECT_OK(ReferenceTestHelper("VariableCPU", "TestAssign", DEVICE_CPU));
+ EXPECT_OK(ReferenceTestHelper("VariableCPU", "AssignCPU", DEVICE_CPU));
+ {
+ Status s = ReferenceTestHelper("VariableCPU", "AssignGPU", DEVICE_CPU);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(StringPiece(s.error_message())
+ .contains("no device type supports both of those nodes"));
+ }
+ EXPECT_OK(ReferenceTestHelper("VariableGPU", "TestAssign", DEVICE_GPU));
+ {
+ Status s = ReferenceTestHelper("VariableGPU", "AssignCPU", DEVICE_CPU);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(StringPiece(s.error_message())
+ .contains("no device type supports both of those nodes"));
+ }
+ EXPECT_OK(ReferenceTestHelper("VariableGPU", "AssignGPU", DEVICE_GPU));
+}
+
+// Test the handling of '@node_name' colocation constraints, when
+// these are arranged in multiple chains.
+TEST_F(SimplePlacerTest, TestColocatedChain) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
+ Node* last_node = input;
+ for (int i = 0; i < 100; ++i) {
+ if (i % 10 == 0) {
+ // Every ten nodes, start a new chain.
+ last_node = ops::UnaryOp("TestRelu", last_node,
+ b.opts().WithName(strings::StrCat("n_", i)));
+ } else {
+ // Chain each successive node to the previous one.
+ last_node =
+ ops::UnaryOp("TestRelu", last_node,
+ b.opts()
+ .WithName(strings::StrCat("n_", i))
+ .WithDevice(strings::StrCat("@n_", i - 1)));
+ }
+ }
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ EXPECT_OK(Place(&g));
+ for (int i = 0; i < 100; ++i) {
+ if (i % 10 != 0) {
+      EXPECT_COLOCATED(g, strings::StrCat("n_", i - (i % 10)),
+ strings::StrCat("n_", i));
+ }
+ }
+}
+
+// Test the handling of '@node_name' colocation constraints, when nodes
+// are added to the chains in an interleaved order.
+TEST_F(SimplePlacerTest, TestColocatedChainWithLongRangeColocations) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
+ Node* last_node = input;
+ for (int i = 0; i < 10; ++i) {
+ // Start ten chains.
+ last_node = ops::UnaryOp("TestRelu", last_node,
+ b.opts().WithName(strings::StrCat("n_", i)));
+ }
+ for (int i = 10; i < 100; ++i) {
+ // Add each node to the (i % 10)^th chain.
+ last_node = ops::UnaryOp("TestRelu", last_node,
+ b.opts()
+ .WithName(strings::StrCat("n_", i))
+ .WithDevice(strings::StrCat("@n_", i % 10)));
+ }
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ EXPECT_OK(Place(&g));
+ for (int i = 10; i < 100; ++i) {
+ EXPECT_COLOCATED(g, strings::StrCat("n_", i % 10),
+ strings::StrCat("n_", i));
+ }
+}
+
+TEST_F(SimplePlacerTest, TestColocationAndReferenceConnections) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
+ for (int i = 0; i < 10; ++i) {
+ // Declare ten variable and assignment pairs.
+ Node* var = ops::SourceOp("TestVariable",
+ b.opts().WithName(strings::StrCat("var_", i)));
+ ops::BinaryOp("TestAssign", var, input,
+ b.opts().WithName(strings::StrCat("assign_", i)));
+ }
+ for (int i = 10; i < 100; ++i) {
+      // Create a variable colocated with some existing variable, and
+      // an assignment colocated with a possibly-different assignment.
+ Node* var = ops::SourceOp(
+ "TestVariable", b.opts()
+ .WithName(strings::StrCat("var_", i))
+ .WithDevice(strings::StrCat("@var_", i % 6)));
+ ops::BinaryOp("TestAssign", var, input,
+ b.opts()
+ .WithName(strings::StrCat("assign_", i))
+ .WithDevice(strings::StrCat("@assign_", i % 3)));
+ }
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ EXPECT_OK(Place(&g));
+ for (int i = 0; i < 10; ++i) {
+ EXPECT_COLOCATED(g, strings::StrCat("var_", i),
+ strings::StrCat("assign_", i));
+ }
+ for (int i = 10; i < 100; ++i) {
+ EXPECT_COLOCATED(g, strings::StrCat("var_", i),
+ strings::StrCat("assign_", i));
+ EXPECT_COLOCATED(g, strings::StrCat("var_", i),
+ strings::StrCat("var_", i % 6));
+ EXPECT_COLOCATED(g, strings::StrCat("assign_", i),
+ strings::StrCat("assign_", i % 3));
+ }
+}
+
+// Test that placement fails when no devices are registered.
+TEST_F(SimplePlacerTest, TestEmptyDeviceSet) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("TestInput", b.opts().WithName("in"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ DeviceSet empty;
+
+ Status s = Place(&g, &empty);
+ EXPECT_TRUE(
+ StringPiece(s.error_message()).contains("No devices are registered"));
+}
+
+// Test that placement fails when the requested device forces an
+// indirect constraint to be violated.
+TEST_F(SimplePlacerTest, TestHeterogeneousDeviceSetFailure) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ Node* in = ops::SourceOp("TestInput", b.opts().WithName("in"));
+ Node* var = ops::SourceOp("VariableGPU", b.opts().WithName("var"));
+ ops::BinaryOp("TestAssign", var, in,
+ b.opts().WithName("assign").WithDevice("/job:b/task:1"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ DeviceSet heterogeneous;
+ std::unique_ptr<Device> gpu(
+ FakeDevice::MakeGPU("/job:b/replica:0/task:0/gpu:0"));
+ heterogeneous.AddDevice(gpu.get());
+ std::unique_ptr<Device> cpu(
+ FakeDevice::MakeCPU("/job:b/replica:0/task:1/cpu:0"));
+ heterogeneous.AddDevice(cpu.get());
+ Status s = Place(&g, &heterogeneous);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(StringPiece(s.error_message())
+ .contains("colocated with a group of nodes that required "
+ "incompatible device"));
+}
+
+// Test that placement fails when an unknown device is requested.
+TEST_F(SimplePlacerTest, TestUnknownDevice) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("TestInput", b.opts().WithName("in").WithDevice("/job:foo"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ Status s = Place(&g);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(
+ StringPiece(s.error_message())
+ .contains(
+ "Could not satisfy explicit device specification '/job:foo'"));
+}
+
+// Test that placement fails when the combination of partial
+// constraints leads to an unknown device.
+TEST_F(SimplePlacerTest, TestUnknownMergedDevice) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("TestInput", b.opts().WithName("in").WithDevice("/job:foo"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ Status s = Place(&g);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(
+ StringPiece(s.error_message())
+ .contains(
+ "Could not satisfy explicit device specification '/job:foo'"));
+}
+
+// Test that placement fails when the previously-assigned device for a
+// node is unknown.
+TEST_F(SimplePlacerTest, TestUnknownAssignedDevice) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("TestInput", b.opts().WithName("in"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ GetNodeByName(g, "in")->set_assigned_device_name("/job:foo");
+
+ Status s = Place(&g);
+ EXPECT_EQ(error::INTERNAL, s.code());
+ EXPECT_TRUE(
+ StringPiece(s.error_message())
+ .contains("Assigned device '/job:foo' does not match any device"));
+}
+
+// Test that placement fails when an op with no registered kernels is
+// requested.
+TEST_F(SimplePlacerTest, TestNoKernelsRegistered) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("VariableNoKernels", b.opts().WithName("var"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ Status s = Place(&g);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(
+ StringPiece(s.error_message())
+ .contains(
+ "No OpKernel was registered to support Op 'VariableNoKernels'"));
+}
+
+// Test that placement fails when a kernel is registered but no known
+// device supports it.
+TEST_F(SimplePlacerTest, TestNoDevicesRegistered) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("VariableGPU", b.opts().WithName("var"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ DeviceSet cpu_only;
+ std::unique_ptr<Device> cpu(
+ FakeDevice::MakeCPU("/job:a/replica:0/task:0/cpu:0"));
+ cpu_only.AddDevice(cpu.get());
+
+ Status s = Place(&g, &cpu_only);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(StringPiece(s.error_message())
+ .contains("No OpKernel was registered to support "
+ "Op 'VariableGPU'"));
+}
+
+// Test that placement fails when a requested device is malformed.
+TEST_F(SimplePlacerTest, TestMalformedDeviceSpecification) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("TestInput", b.opts().WithName("in").WithDevice("/foo:bar"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ Status s = Place(&g);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(StringPiece(s.error_message())
+ .contains("Malformed device specification '/foo:bar'"));
+}
+
+// Test that placement fails when a previously-assigned device is malformed.
+TEST_F(SimplePlacerTest, TestMalformedAssignedDevice) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("TestInput", b.opts().WithName("in"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ GetNodeByName(g, "in")->set_assigned_device_name("/foo:bar");
+
+ Status s = Place(&g);
+ EXPECT_EQ(error::INTERNAL, s.code());
+ EXPECT_TRUE(StringPiece(s.error_message())
+ .contains("Malformed assigned device '/foo:bar'"));
+}
+
+// Test that placement fails when a device was previously assigned to
+// a node, but it does not uniquely identify a particular device.
+TEST_F(SimplePlacerTest, TestNonUniqueAssignedDevice) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("TestInput", b.opts().WithName("in"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ GetNodeByName(g, "in")->set_assigned_device_name("/job:a");
+
+ Status s = Place(&g);
+ EXPECT_EQ(error::INTERNAL, s.code());
+ EXPECT_TRUE(
+ StringPiece(s.error_message())
+ .contains("Assigned device '/job:a' does not match any device"));
+}
+
+// Test that placement fails when a node requests colocation with another
+// node that does not exist.
+TEST_F(SimplePlacerTest, TestUnknownColocatedNode) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("TestInput", b.opts().WithName("in").WithDevice("@foo"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ Status s = Place(&g);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(StringPiece(s.error_message()).contains("'foo' does not exist"));
+}
+
+// Test that placement fails when a node requests colocation with a
+// malformed node name.
+TEST_F(SimplePlacerTest, TestMalformedColocatedNode) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("TestInput", b.opts().WithName("in").WithDevice("@"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ Status s = Place(&g);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(StringPiece(s.error_message())
+ .contains("node named in device '' does not exist"));
+}
+
+// Test that an op requested to be placed on a non-existent device will be
+// relocated to an existing device of the same type if allow_soft_placement
+// is set.
+TEST_F(SimplePlacerTest, TestNonexistentGpuAllowSoftPlacement) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("TestDevice", b.opts().WithName("in").WithDevice("/gpu:11"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ SessionOptions options;
+ options.config.set_allow_soft_placement(true);
+ EXPECT_OK(Place(&g, &options));
+ EXPECT_DEVICE_CONTAINS(g, "in", "/gpu:0");
+}
+
+// Test that an op requested to be placed on a non-existent device will fail
+// if allow_soft_placement is not set.
+TEST_F(SimplePlacerTest, TestNonexistentGpuNoAllowSoftPlacement) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("TestDevice", b.opts().WithName("in").WithDevice("/gpu:11"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ SessionOptions options;
+ Status s = Place(&g, &options);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(
+ StringPiece(s.error_message())
+ .contains(
+ "Could not satisfy explicit device specification '/gpu:11'"));
+}
+
+// Test that placement fails when a node requests an explicit device that is
+// not supported by the registered kernels and allow_soft_placement is not
+// set.
+TEST_F(SimplePlacerTest, TestUnsupportedDeviceNoAllowSoftPlacement) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("VariableGPU", b.opts().WithName("var").WithDevice("/cpu:0"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ SessionOptions options;
+ Status s = Place(&g, &options);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(
+ StringPiece(s.error_message())
+ .contains(
+ "Could not satisfy explicit device specification '/cpu:0'"));
+}
+
+TEST_F(SimplePlacerTest, TestUnsupportedDeviceAllowSoftPlacement) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("VariableGPU", b.opts().WithName("var").WithDevice("/cpu:0"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ SessionOptions options;
+ options.config.set_allow_soft_placement(true);
+ EXPECT_OK(Place(&g, &options));
+}
+
+// Test that, when soft placement is allowed, reference connections take
+// precedence over requested devices: each node is colocated with the
+// variable it references, on a device type that supports both.
+TEST_F(SimplePlacerTest, TestDeviceTypeConstraintsAllowSoftPlacement) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    // var_gpu has a ref output and runs on GPU.
+    // force_gpu takes var_gpu and requests CPU.
+    // Verify that both are placed on GPU.
+ Node* var_gpu = ops::SourceOp("VariableGPU", b.opts().WithName("var_gpu"));
+ ops::UnaryOp("TestDeviceEnforce", var_gpu,
+ b.opts().WithName("force_gpu").WithDevice("/cpu:0"));
+    // var_cpu has a ref output and runs on CPU.
+    // force_cpu takes var_cpu and requests GPU.
+    // Verify that both are placed on CPU.
+ Node* var_cpu = ops::SourceOp("VariableCPU", b.opts().WithName("var_cpu"));
+ ops::UnaryOp("TestDeviceEnforce", var_cpu,
+ b.opts().WithName("force_cpu").WithDevice("/gpu:0"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ SessionOptions options;
+ options.config.set_allow_soft_placement(true);
+ EXPECT_OK(Place(&g, &options));
+ EXPECT_DEVICE_TYPE(g, "var_gpu", DEVICE_GPU);
+ EXPECT_DEVICE_TYPE(g, "force_gpu", DEVICE_GPU);
+ EXPECT_COLOCATED(g, "var_gpu", "force_gpu");
+ EXPECT_DEVICE_TYPE(g, "var_cpu", DEVICE_CPU);
+ EXPECT_DEVICE_TYPE(g, "force_cpu", DEVICE_CPU);
+ EXPECT_COLOCATED(g, "var_cpu", "force_cpu");
+}
+
+// Test that placement fails when two nodes have a reference connection
+// constraint, and each node requires a mutually incompatible device.
+TEST_F(SimplePlacerTest, TestUnsatisfiableConstraintWithReferenceConnections) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ Node* var = ops::SourceOp("VariableGPU", b.opts().WithName("var"));
+ Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
+ ops::BinaryOp("AssignCPU", var, input, b.opts().WithName("assign"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ Status s = Place(&g);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(StringPiece(s.error_message())
+ .contains("Cannot colocate nodes 'var' and 'assign'"));
+}
+
+// Test that placement fails when two nodes have an explicit
+// colocation constraint, and each node requires a mutually
+// incompatible device.
+TEST_F(SimplePlacerTest, TestUnsatisfiableConstraintWithColocatedNodes) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ Node* input = ops::SourceOp("TestInput",
+ b.opts().WithName("in").WithDevice("/gpu:0"));
+ Node* relu_1 = ops::UnaryOp("TestRelu", input,
+ b.opts().WithName("relu_1").WithDevice("@in"));
+ ops::UnaryOp("ReluGPU", relu_1,
+ b.opts().WithName("relu_2").WithDevice("@relu_1"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ Status s = Place(&g);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(StringPiece(s.error_message())
+ .contains("Cannot colocate nodes 'relu_1' and 'relu_2'"));
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/threadpool_device.cc b/tensorflow/core/common_runtime/threadpool_device.cc
new file mode 100644
index 0000000000..4806e69c67
--- /dev/null
+++ b/tensorflow/core/common_runtime/threadpool_device.cc
@@ -0,0 +1,55 @@
+#include "tensorflow/core/common_runtime/threadpool_device.h"
+
+#include "tensorflow/core/common_runtime/local_device.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/types.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+
+ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options,
+ const string& name, Bytes memory_limit,
+ BusAdjacency bus_adjacency,
+ Allocator* allocator)
+ : LocalDevice(options, Device::BuildDeviceAttributes(
+ name, DEVICE_CPU, memory_limit, bus_adjacency),
+ allocator),
+ allocator_(allocator) {}
+
+ThreadPoolDevice::~ThreadPoolDevice() {}
+
+void ThreadPoolDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
+ if (port::Tracing::IsActive()) {
+ // TODO(pbar) We really need a useful identifier of the graph node.
+ const uint64 id = Hash64(op_kernel->name());
+ port::Tracing::ScopedActivity region(port::Tracing::EventCategory::kCompute,
+ id);
+ op_kernel->Compute(context);
+ } else {
+ op_kernel->Compute(context);
+ }
+}
+
+Allocator* ThreadPoolDevice::GetAllocator(AllocatorAttributes attr) {
+ return allocator_;
+}
+
+Status ThreadPoolDevice::MakeTensorFromProto(
+ const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs,
+ Tensor* tensor) {
+ Tensor parsed(tensor_proto.dtype());
+ if (!parsed.FromProto(cpu_allocator(), tensor_proto)) {
+ return errors::InvalidArgument("Cannot parse tensor from proto: ",
+ tensor_proto.DebugString());
+ }
+ *tensor = parsed;
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/threadpool_device.h b/tensorflow/core/common_runtime/threadpool_device.h
new file mode 100644
index 0000000000..5b0347231f
--- /dev/null
+++ b/tensorflow/core/common_runtime/threadpool_device.h
@@ -0,0 +1,31 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_THREADPOOL_DEVICE_H_
+#define TENSORFLOW_COMMON_RUNTIME_THREADPOOL_DEVICE_H_
+
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/local_device.h"
+
+namespace tensorflow {
+
+// CPU device implementation.
+class ThreadPoolDevice : public LocalDevice {
+ public:
+ ThreadPoolDevice(const SessionOptions& options, const string& name,
+ Bytes memory_limit, BusAdjacency bus_adjacency,
+ Allocator* allocator);
+ ~ThreadPoolDevice() override;
+
+ void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
+ Allocator* GetAllocator(AllocatorAttributes attr) override;
+ Status MakeTensorFromProto(const TensorProto& tensor_proto,
+ const AllocatorAttributes alloc_attrs,
+ Tensor* tensor) override;
+
+ Status Sync() override { return Status::OK(); }
+
+ private:
+ Allocator* allocator_; // Not owned
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_THREADPOOL_DEVICE_H_
diff --git a/tensorflow/core/common_runtime/threadpool_device_factory.cc b/tensorflow/core/common_runtime/threadpool_device_factory.cc
new file mode 100644
index 0000000000..ee6319abad
--- /dev/null
+++ b/tensorflow/core/common_runtime/threadpool_device_factory.cc
@@ -0,0 +1,31 @@
+// Register a factory that provides CPU devices.
+#include "tensorflow/core/common_runtime/threadpool_device.h"
+
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+
+// TODO(zhifengc/tucker): Figure out the bytes of available RAM.
+class ThreadPoolDeviceFactory : public DeviceFactory {
+ public:
+ void CreateDevices(const SessionOptions& options, const string& name_prefix,
+ std::vector<Device*>* devices) override {
+ // TODO(zhifengc/tucker): Figure out the number of available CPUs
+ // and/or NUMA configuration.
+ int n = 1;
+ auto iter = options.config.device_count().find("CPU");
+ if (iter != options.config.device_count().end()) {
+ n = iter->second;
+ }
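+    // (Illustrative only: a client could request a specific number of CPU
+    // devices via the session configuration, e.g.
+    //   SessionOptions options;
+    //   (*options.config.mutable_device_count())["CPU"] = 2;
+    // before creating the session.)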
+ for (int i = 0; i < n; i++) {
+ string name = strings::StrCat(name_prefix, "/cpu:", i);
+ devices->push_back(new ThreadPoolDevice(options, name, Bytes(256 << 20),
+ BUS_ANY, cpu_allocator()));
+ }
+ }
+};
+REGISTER_LOCAL_DEVICE_FACTORY("CPU", ThreadPoolDeviceFactory);
+
+} // namespace tensorflow