diff options
author | 2017-12-21 16:27:19 -0800
---|---
committer | 2017-12-21 16:31:44 -0800
commit | af0e847c1b3dbc6b31dcf90c63c509f5f2709a48 (patch)
tree | 8a54409cdfcdd92dcef825580bf80b8382d0ff7b
parent | 4c76bb4dadc1defb56ede40066df07916cfb64c2 (diff)
Extract placement restrictions for a graph
PiperOrigin-RevId: 179872807
-rw-r--r-- | tensorflow/core/grappler/clusters/cluster.h | 3
-rw-r--r-- | tensorflow/core/grappler/clusters/single_machine.h | 2
-rw-r--r-- | tensorflow/core/grappler/clusters/virtual_cluster.h | 2
-rw-r--r-- | tensorflow/python/BUILD | 14
-rw-r--r-- | tensorflow/python/grappler/cluster.i | 85
-rw-r--r-- | tensorflow/python/grappler/cluster.py | 3
-rw-r--r-- | tensorflow/python/grappler/cluster_test.py | 46
7 files changed, 148 insertions, 7 deletions
diff --git a/tensorflow/core/grappler/clusters/cluster.h b/tensorflow/core/grappler/clusters/cluster.h index d7af50f7dc..db13595eb3 100644 --- a/tensorflow/core/grappler/clusters/cluster.h +++ b/tensorflow/core/grappler/clusters/cluster.h @@ -38,6 +38,9 @@ class Cluster { explicit Cluster(int timeout_s); virtual ~Cluster(); + // Returns a string that represent the type of cluster that was instantiated. + virtual string type() const = 0; + // Provision the hardware resources needed to run TensorFlow and start a // TensorFlow session that can take advantage of these resources. // The actual resources that are leveraged depend on the type of cluster diff --git a/tensorflow/core/grappler/clusters/single_machine.h b/tensorflow/core/grappler/clusters/single_machine.h index be005a9509..4d8e75d844 100644 --- a/tensorflow/core/grappler/clusters/single_machine.h +++ b/tensorflow/core/grappler/clusters/single_machine.h @@ -32,6 +32,8 @@ class SingleMachine : public Cluster { SingleMachine(int timeout_s, int num_cpu_cores, int num_gpus); ~SingleMachine() override; + string type() const override { return "single_machine"; } + Status Provision() override; Status Shutdown() override; diff --git a/tensorflow/core/grappler/clusters/virtual_cluster.h b/tensorflow/core/grappler/clusters/virtual_cluster.h index a74911cb23..1c73dbb240 100644 --- a/tensorflow/core/grappler/clusters/virtual_cluster.h +++ b/tensorflow/core/grappler/clusters/virtual_cluster.h @@ -35,6 +35,8 @@ class VirtualCluster : public Cluster { ~VirtualCluster() override; + string type() const override { return "virtual"; } + Status Provision() override; Status Initialize(const GrapplerItem& item) override; Status Run(const GraphDef& item, diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 7a58d045f9..9d6bee8441 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -4424,24 +4424,24 @@ py_library( ], ) -py_test( +cuda_py_test( name = "cluster_test", size = "small", srcs = [ 
"grappler/cluster_test.py", ], - srcs_version = "PY2AND3", - tags = [ - "grappler", - "no_pip", # tf_optimizer is not available in pip. - ], - deps = [ + additional_deps = [ ":client_testlib", ":framework_for_generated_wrappers", ":tf_cluster", ":tf_item", "//tensorflow/core:protos_all_py", ], + shard_count = 10, + tags = [ + "grappler", + "no_pip", # tf_optimizer is not available in pip. + ], ) py_library( diff --git a/tensorflow/python/grappler/cluster.i b/tensorflow/python/grappler/cluster.i index 9981c1d22d..0c8d04ff29 100644 --- a/tensorflow/python/grappler/cluster.i +++ b/tensorflow/python/grappler/cluster.i @@ -100,6 +100,7 @@ bool _PyObjAs(PyObject *input, tensorflow::NamedDevice *out) { #include <memory> #include <vector> #include "tensorflow/core/grappler/devices.h" +#include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/grappler/clusters/single_machine.h" #include "tensorflow/core/grappler/clusters/virtual_cluster.h" #include "tensorflow/core/grappler/costs/graph_memory.h" @@ -107,6 +108,8 @@ bool _PyObjAs(PyObject *input, tensorflow::NamedDevice *out) { #include "tensorflow/core/grappler/costs/measuring_cost_estimator.h" #include "tensorflow/core/grappler/costs/utils.h" #include "tensorflow/core/protobuf/device_properties.pb.h" +#include "tensorflow/core/framework/kernel_def.pb.h" +#include "tensorflow/core/framework/memory_types.h" // Provide the implementation of the GCluster struct here. 
struct GCluster { @@ -214,6 +217,87 @@ static std::vector<string> TF_ListAvailableOps() { return op_names; } +static PyObject* TF_GetSupportedDevices(GCluster cluster, GItem item) { + if (cluster.is_none() || item.is_none()) { + Py_RETURN_NONE; + } + const std::unordered_map<string, tensorflow::DeviceProperties>& devices = cluster->GetDevices(); + std::unordered_map<string, std::vector<string>> device_types; + for (const auto& dev : devices) { + device_types[dev.second.type()].push_back(dev.first); + } + + std::unordered_map<string, std::set<string>> supported_device_types; + std::unordered_map<string, std::set<string>> device_restrictions; + + for (const auto& node : item->graph.node()) { + for (const auto& dev : device_types) { + const string& type = dev.first; + if (cluster->type() != "single_machine") { + // The actual kernel may not be linked in this binary. + supported_device_types[node.name()].insert(type); + } else { + // Check the kernel capabilities + const tensorflow::DeviceType dev_type(type); + tensorflow::Status s = tensorflow::FindKernelDef(dev_type, node, nullptr, nullptr); + if (s.ok()) { + supported_device_types[node.name()].insert(type); + + // Check which inputs are restricted to reside on the host. 
+ // TODO: extends this to support outputs as well + tensorflow::MemoryTypeVector inp_mtypes; + tensorflow::MemoryTypeVector out_mtypes; + s = tensorflow::MemoryTypesForNode(tensorflow::OpRegistry::Global(), dev_type, node, + &inp_mtypes, &out_mtypes); + if (s.ok()) { + for (int i = 0; i < inp_mtypes.size(); ++i) { + if (inp_mtypes[i] == tensorflow::HOST_MEMORY) { + device_restrictions[tensorflow::grappler::NodeName(node.input(i))].insert("CPU"); + break; + } + } + } + } + } + } + } + + PyGILState_STATE gstate = PyGILState_Ensure(); + PyObject* result = PyDict_New(); + + for (const auto& supported_dev : supported_device_types) { + const string& node = supported_dev.first; + std::set<string> feasible; + const auto it = device_restrictions.find(node); + if (it != device_restrictions.end()) { + const std::set<string>& candidates = supported_dev.second; + const std::set<string>& valid = it->second; + std::set_intersection(candidates.begin(), candidates.end(), valid.begin(), valid.end(), + std::inserter(feasible, feasible.begin())); + } else { + feasible = supported_dev.second; + } + + std::vector<string> device_names; + for (const string& type : feasible) { + auto it = device_types.find(type); + CHECK(it != device_types.end()); + for (const string& name : it->second) { + device_names.push_back(name); + } + } + + PyObject* dev = PyList_New(device_names.size()); + for (int i = 0; i < device_names.size(); ++i) { + PyList_SetItem(dev, i, PyString_FromString(device_names[i].c_str())); + } + CHECK_EQ(0, PyDict_SetItem(result, PyString_FromString(node.c_str()), dev)); + } + PyGILState_Release(gstate); + return result; +} + + static double TF_EstimatePerformance(const tensorflow::NamedDevice& device) { tensorflow::grappler::OpLevelCostEstimator estimator; tensorflow::grappler::OpLevelCostEstimator::DeviceInfo info = @@ -348,6 +432,7 @@ static GCluster TF_NewVirtualCluster( static void TF_ShutdownCluster(GCluster cluster); static PyObject* TF_ListDevices(GCluster cluster); 
static std::vector<string> TF_ListAvailableOps(); +static PyObject* TF_GetSupportedDevices(GCluster cluster, GItem item); static float TF_EstimatePerformance(const tensorflow::NamedDevice& device); static PyObject* TF_MeasureCosts( GItem item, GCluster cluster, diff --git a/tensorflow/python/grappler/cluster.py b/tensorflow/python/grappler/cluster.py index ba1a734ee0..079d07115b 100644 --- a/tensorflow/python/grappler/cluster.py +++ b/tensorflow/python/grappler/cluster.py @@ -84,6 +84,9 @@ class Cluster(object): """Returns a list of all the available operations (sorted alphatically).""" return tf_cluster.TF_ListAvailableOps() + def GetSupportedDevices(self, item): + return tf_cluster.TF_GetSupportedDevices(self._tf_cluster, item.tf_item) + def EstimatePerformance(self, device): """Estimate the performance of the specified device.""" serialized = device.SerializeToString() diff --git a/tensorflow/python/grappler/cluster_test.py b/tensorflow/python/grappler/cluster_test.py index 26feac0a23..f987d84e4e 100644 --- a/tensorflow/python/grappler/cluster_test.py +++ b/tensorflow/python/grappler/cluster_test.py @@ -23,6 +23,8 @@ from tensorflow.python.framework import meta_graph from tensorflow.python.framework import ops from tensorflow.python.grappler import cluster from tensorflow.python.grappler import item +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.platform import test @@ -133,6 +135,50 @@ class ClusterTest(test.TestCase): self.assertTrue(b'MatMul' in op_names) self.assertEqual(op_names, sorted(op_names)) + def testSupportDevices(self): + with ops.Graph().as_default() as g: + a = random_ops.random_uniform(shape=(2, 3)) + b = random_ops.random_uniform(shape=(2, 3)) + c = a + b + dims = math_ops.range(0, array_ops.rank(c), 1) + d = math_ops.reduce_sum(a, axis=dims) + train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) + train_op.append(d) + mg = 
meta_graph.create_meta_graph_def(graph=g) + grappler_item = item.Item(mg) + + device_properties = device_properties_pb2.DeviceProperties( + type='GPU', frequency=1000, num_cores=60) + named_gpu = device_properties_pb2.NamedDevice( + properties=device_properties, name='/GPU:0') + device_properties = device_properties_pb2.DeviceProperties( + type='CPU', frequency=3000, num_cores=6) + named_cpu = device_properties_pb2.NamedDevice( + properties=device_properties, name='/CPU:0') + virtual_cluster = cluster.Cluster(devices=[named_cpu, named_gpu]) + supported_dev = virtual_cluster.GetSupportedDevices(grappler_item) + self.assertEqual(supported_dev['add'], ['/CPU:0', '/GPU:0']) + self.assertEqual(supported_dev['Sum'], ['/CPU:0', '/GPU:0']) + self.assertEqual(supported_dev['range'], ['/CPU:0', '/GPU:0']) + + real_cluster = cluster.Cluster() + supported_dev = real_cluster.GetSupportedDevices(grappler_item) + if test.is_gpu_available(): + self.assertEqual(supported_dev['add'], [ + '/job:localhost/replica:0/task:0/cpu:0', + '/job:localhost/replica:0/task:0/device:GPU:0' + ]) + self.assertEqual(supported_dev['Sum'], [ + '/job:localhost/replica:0/task:0/cpu:0', + '/job:localhost/replica:0/task:0/device:GPU:0' + ]) + # The axis tensor must reside on the host + self.assertEqual(supported_dev['range'], + ['/job:localhost/replica:0/task:0/cpu:0']) + else: + self.assertEqual(supported_dev['add'], + ['/job:localhost/replica:0/task:0/cpu:0']) + if __name__ == '__main__': test.main() |