author     Benoit Steiner <bsteiner@google.com>             2017-12-21 16:27:19 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-12-21 16:31:44 -0800
commit     af0e847c1b3dbc6b31dcf90c63c509f5f2709a48 (patch)
tree       8a54409cdfcdd92dcef825580bf80b8382d0ff7b
parent     4c76bb4dadc1defb56ede40066df07916cfb64c2 (diff)
Extract placement restrictions for a graph
PiperOrigin-RevId: 179872807
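
For orientation, here is a minimal usage sketch of the API this patch adds, pieced together from the cluster.py and cluster_test.py hunks below; GetSupportedDevices() returns a dict keyed by node name whose values are the device names the node can be placed on. The graph built here is illustrative.

# Usage sketch based on the cluster_test.py changes in this patch.
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.grappler import cluster
from tensorflow.python.grappler import item
from tensorflow.python.ops import random_ops

with ops.Graph().as_default() as g:
  a = random_ops.random_uniform(shape=(2, 3))
  b = random_ops.random_uniform(shape=(2, 3))
  c = a + b
  ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(c)
  mg = meta_graph.create_meta_graph_def(graph=g)

grappler_item = item.Item(mg)
gcluster = cluster.Cluster()  # provisions the local (single) machine
# Dict mapping node names to feasible device names, e.g.
# {'add': ['/job:localhost/replica:0/task:0/cpu:0', ...], ...}
supported = gcluster.GetSupportedDevices(grappler_item)
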
-rw-r--r--  tensorflow/core/grappler/clusters/cluster.h          |  3
-rw-r--r--  tensorflow/core/grappler/clusters/single_machine.h   |  2
-rw-r--r--  tensorflow/core/grappler/clusters/virtual_cluster.h  |  2
-rw-r--r--  tensorflow/python/BUILD                              | 14
-rw-r--r--  tensorflow/python/grappler/cluster.i                 | 85
-rw-r--r--  tensorflow/python/grappler/cluster.py                |  3
-rw-r--r--  tensorflow/python/grappler/cluster_test.py           | 46
7 files changed, 148 insertions(+), 7 deletions(-)
diff --git a/tensorflow/core/grappler/clusters/cluster.h b/tensorflow/core/grappler/clusters/cluster.h
index d7af50f7dc..db13595eb3 100644
--- a/tensorflow/core/grappler/clusters/cluster.h
+++ b/tensorflow/core/grappler/clusters/cluster.h
@@ -38,6 +38,9 @@ class Cluster {
explicit Cluster(int timeout_s);
virtual ~Cluster();
+ // Returns a string that represents the type of cluster that was instantiated.
+ virtual string type() const = 0;
+
// Provision the hardware resources needed to run TensorFlow and start a
// TensorFlow session that can take advantage of these resources.
// The actual resources that are leveraged depend on the type of cluster
diff --git a/tensorflow/core/grappler/clusters/single_machine.h b/tensorflow/core/grappler/clusters/single_machine.h
index be005a9509..4d8e75d844 100644
--- a/tensorflow/core/grappler/clusters/single_machine.h
+++ b/tensorflow/core/grappler/clusters/single_machine.h
@@ -32,6 +32,8 @@ class SingleMachine : public Cluster {
SingleMachine(int timeout_s, int num_cpu_cores, int num_gpus);
~SingleMachine() override;
+ string type() const override { return "single_machine"; }
+
Status Provision() override;
Status Shutdown() override;
diff --git a/tensorflow/core/grappler/clusters/virtual_cluster.h b/tensorflow/core/grappler/clusters/virtual_cluster.h
index a74911cb23..1c73dbb240 100644
--- a/tensorflow/core/grappler/clusters/virtual_cluster.h
+++ b/tensorflow/core/grappler/clusters/virtual_cluster.h
@@ -35,6 +35,8 @@ class VirtualCluster : public Cluster {
~VirtualCluster() override;
+ string type() const override { return "virtual"; }
+
Status Provision() override;
Status Initialize(const GrapplerItem& item) override;
Status Run(const GraphDef& item,
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 7a58d045f9..9d6bee8441 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -4424,24 +4424,24 @@ py_library(
],
)
-py_test(
+cuda_py_test(
name = "cluster_test",
size = "small",
srcs = [
"grappler/cluster_test.py",
],
- srcs_version = "PY2AND3",
- tags = [
- "grappler",
- "no_pip", # tf_optimizer is not available in pip.
- ],
- deps = [
+ additional_deps = [
":client_testlib",
":framework_for_generated_wrappers",
":tf_cluster",
":tf_item",
"//tensorflow/core:protos_all_py",
],
+ shard_count = 10,
+ tags = [
+ "grappler",
+ "no_pip", # tf_optimizer is not available in pip.
+ ],
)
py_library(
diff --git a/tensorflow/python/grappler/cluster.i b/tensorflow/python/grappler/cluster.i
index 9981c1d22d..0c8d04ff29 100644
--- a/tensorflow/python/grappler/cluster.i
+++ b/tensorflow/python/grappler/cluster.i
@@ -100,6 +100,7 @@ bool _PyObjAs(PyObject *input, tensorflow::NamedDevice *out) {
#include <memory>
#include <vector>
#include "tensorflow/core/grappler/devices.h"
+#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/clusters/single_machine.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/costs/graph_memory.h"
@@ -107,6 +108,8 @@ bool _PyObjAs(PyObject *input, tensorflow::NamedDevice *out) {
#include "tensorflow/core/grappler/costs/measuring_cost_estimator.h"
#include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/protobuf/device_properties.pb.h"
+#include "tensorflow/core/framework/kernel_def.pb.h"
+#include "tensorflow/core/framework/memory_types.h"
// Provide the implementation of the GCluster struct here.
struct GCluster {
@@ -214,6 +217,87 @@ static std::vector<string> TF_ListAvailableOps() {
return op_names;
}
+static PyObject* TF_GetSupportedDevices(GCluster cluster, GItem item) {
+ if (cluster.is_none() || item.is_none()) {
+ Py_RETURN_NONE;
+ }
+ const std::unordered_map<string, tensorflow::DeviceProperties>& devices = cluster->GetDevices();
+ std::unordered_map<string, std::vector<string>> device_types;
+ for (const auto& dev : devices) {
+ device_types[dev.second.type()].push_back(dev.first);
+ }
+
+ std::unordered_map<string, std::set<string>> supported_device_types;
+ std::unordered_map<string, std::set<string>> device_restrictions;
+
+ for (const auto& node : item->graph.node()) {
+ for (const auto& dev : device_types) {
+ const string& type = dev.first;
+ if (cluster->type() != "single_machine") {
+ // The actual kernel may not be linked in this binary.
+ supported_device_types[node.name()].insert(type);
+ } else {
+ // Check the kernel capabilities
+ const tensorflow::DeviceType dev_type(type);
+ tensorflow::Status s = tensorflow::FindKernelDef(dev_type, node, nullptr, nullptr);
+ if (s.ok()) {
+ supported_device_types[node.name()].insert(type);
+
+ // Check which inputs are restricted to reside on the host.
+ // TODO: extend this to support outputs as well
+ tensorflow::MemoryTypeVector inp_mtypes;
+ tensorflow::MemoryTypeVector out_mtypes;
+ s = tensorflow::MemoryTypesForNode(tensorflow::OpRegistry::Global(), dev_type, node,
+ &inp_mtypes, &out_mtypes);
+ if (s.ok()) {
+ for (int i = 0; i < inp_mtypes.size(); ++i) {
+ if (inp_mtypes[i] == tensorflow::HOST_MEMORY) {
+ device_restrictions[tensorflow::grappler::NodeName(node.input(i))].insert("CPU");
+ break;
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ PyGILState_STATE gstate = PyGILState_Ensure();
+ PyObject* result = PyDict_New();
+
+ for (const auto& supported_dev : supported_device_types) {
+ const string& node = supported_dev.first;
+ std::set<string> feasible;
+ const auto it = device_restrictions.find(node);
+ if (it != device_restrictions.end()) {
+ const std::set<string>& candidates = supported_dev.second;
+ const std::set<string>& valid = it->second;
+ std::set_intersection(candidates.begin(), candidates.end(), valid.begin(), valid.end(),
+ std::inserter(feasible, feasible.begin()));
+ } else {
+ feasible = supported_dev.second;
+ }
+
+ std::vector<string> device_names;
+ for (const string& type : feasible) {
+ auto it = device_types.find(type);
+ CHECK(it != device_types.end());
+ for (const string& name : it->second) {
+ device_names.push_back(name);
+ }
+ }
+
+ PyObject* dev = PyList_New(device_names.size());
+ for (int i = 0; i < device_names.size(); ++i) {
+ PyList_SetItem(dev, i, PyString_FromString(device_names[i].c_str()));
+ }
+ CHECK_EQ(0, PyDict_SetItem(result, PyString_FromString(node.c_str()), dev));
+ }
+ PyGILState_Release(gstate);
+ return result;
+}
+
+
static double TF_EstimatePerformance(const tensorflow::NamedDevice& device) {
tensorflow::grappler::OpLevelCostEstimator estimator;
tensorflow::grappler::OpLevelCostEstimator::DeviceInfo info =
@@ -348,6 +432,7 @@ static GCluster TF_NewVirtualCluster(
static void TF_ShutdownCluster(GCluster cluster);
static PyObject* TF_ListDevices(GCluster cluster);
static std::vector<string> TF_ListAvailableOps();
+static PyObject* TF_GetSupportedDevices(GCluster cluster, GItem item);
static float TF_EstimatePerformance(const tensorflow::NamedDevice& device);
static PyObject* TF_MeasureCosts(
GItem item, GCluster cluster,
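
The core of TF_GetSupportedDevices() above boils down to a per-node feasibility computation: a device type is supported if a kernel for the node exists on that type (assumed true on a virtual cluster, where the kernel may not be linked in), and a node that feeds a host-memory input is pinned to CPU. The following is a simplified Python sketch of that logic under stated assumptions: has_kernel, host_memory_inputs and producer_name are hypothetical callables standing in for FindKernelDef, MemoryTypesForNode and grappler::NodeName, and node objects are assumed to be NodeDef-like (with .name and .input).

def get_supported_devices(nodes, devices_by_type, is_single_machine,
                          has_kernel, host_memory_inputs, producer_name):
  """Maps each node name to the device names it can be placed on.

  devices_by_type: dict of device type ('CPU', 'GPU', ...) -> list of device
    names of that type, mirroring the device_types map in the code above.
  """
  supported = {}   # node name -> set of device types with a usable kernel
  restricted = {}  # node name -> set of types it is pinned to (only 'CPU')
  for node in nodes:
    for dev_type in devices_by_type:
      if not is_single_machine:
        # On a virtual cluster the kernel may not be linked into this binary,
        # so every known device type is assumed to be supported.
        supported.setdefault(node.name, set()).add(dev_type)
      elif has_kernel(node, dev_type):
        supported.setdefault(node.name, set()).add(dev_type)
        # An input that must live in host memory pins its producer to CPU.
        for i in host_memory_inputs(node, dev_type):
          restricted.setdefault(producer_name(node.input[i]), set()).add('CPU')
          break
  result = {}
  for name, types in supported.items():
    feasible = types & restricted[name] if name in restricted else types
    result[name] = [d for t in sorted(feasible) for d in devices_by_type[t]]
  return result

As in the C++ above, only the first host-memory input of a node restricts its producer (hence the break), and outputs are not yet considered, matching the TODO in the patch.
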
diff --git a/tensorflow/python/grappler/cluster.py b/tensorflow/python/grappler/cluster.py
index ba1a734ee0..079d07115b 100644
--- a/tensorflow/python/grappler/cluster.py
+++ b/tensorflow/python/grappler/cluster.py
@@ -84,6 +84,9 @@ class Cluster(object):
"""Returns a list of all the available operations (sorted alphatically)."""
return tf_cluster.TF_ListAvailableOps()
+ def GetSupportedDevices(self, item):
+ return tf_cluster.TF_GetSupportedDevices(self._tf_cluster, item.tf_item)
+
def EstimatePerformance(self, device):
"""Estimate the performance of the specified device."""
serialized = device.SerializeToString()
diff --git a/tensorflow/python/grappler/cluster_test.py b/tensorflow/python/grappler/cluster_test.py
index 26feac0a23..f987d84e4e 100644
--- a/tensorflow/python/grappler/cluster_test.py
+++ b/tensorflow/python/grappler/cluster_test.py
@@ -23,6 +23,8 @@ from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.grappler import cluster
from tensorflow.python.grappler import item
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
@@ -133,6 +135,50 @@ class ClusterTest(test.TestCase):
self.assertTrue(b'MatMul' in op_names)
self.assertEqual(op_names, sorted(op_names))
+ def testSupportDevices(self):
+ with ops.Graph().as_default() as g:
+ a = random_ops.random_uniform(shape=(2, 3))
+ b = random_ops.random_uniform(shape=(2, 3))
+ c = a + b
+ dims = math_ops.range(0, array_ops.rank(c), 1)
+ d = math_ops.reduce_sum(a, axis=dims)
+ train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
+ train_op.append(d)
+ mg = meta_graph.create_meta_graph_def(graph=g)
+ grappler_item = item.Item(mg)
+
+ device_properties = device_properties_pb2.DeviceProperties(
+ type='GPU', frequency=1000, num_cores=60)
+ named_gpu = device_properties_pb2.NamedDevice(
+ properties=device_properties, name='/GPU:0')
+ device_properties = device_properties_pb2.DeviceProperties(
+ type='CPU', frequency=3000, num_cores=6)
+ named_cpu = device_properties_pb2.NamedDevice(
+ properties=device_properties, name='/CPU:0')
+ virtual_cluster = cluster.Cluster(devices=[named_cpu, named_gpu])
+ supported_dev = virtual_cluster.GetSupportedDevices(grappler_item)
+ self.assertEqual(supported_dev['add'], ['/CPU:0', '/GPU:0'])
+ self.assertEqual(supported_dev['Sum'], ['/CPU:0', '/GPU:0'])
+ self.assertEqual(supported_dev['range'], ['/CPU:0', '/GPU:0'])
+
+ real_cluster = cluster.Cluster()
+ supported_dev = real_cluster.GetSupportedDevices(grappler_item)
+ if test.is_gpu_available():
+ self.assertEqual(supported_dev['add'], [
+ '/job:localhost/replica:0/task:0/cpu:0',
+ '/job:localhost/replica:0/task:0/device:GPU:0'
+ ])
+ self.assertEqual(supported_dev['Sum'], [
+ '/job:localhost/replica:0/task:0/cpu:0',
+ '/job:localhost/replica:0/task:0/device:GPU:0'
+ ])
+ # The axis tensor must reside on the host
+ self.assertEqual(supported_dev['range'],
+ ['/job:localhost/replica:0/task:0/cpu:0'])
+ else:
+ self.assertEqual(supported_dev['add'],
+ ['/job:localhost/replica:0/task:0/cpu:0'])
+
if __name__ == '__main__':
test.main()