diff options
author | 2017-12-21 16:27:19 -0800 | |
---|---|---|
committer | 2017-12-21 16:31:44 -0800 | |
commit | af0e847c1b3dbc6b31dcf90c63c509f5f2709a48 (patch) | |
tree | 8a54409cdfcdd92dcef825580bf80b8382d0ff7b /tensorflow/python/grappler/cluster_test.py | |
parent | 4c76bb4dadc1defb56ede40066df07916cfb64c2 (diff) |
Extract placement restrictions for a graph
PiperOrigin-RevId: 179872807
Diffstat (limited to 'tensorflow/python/grappler/cluster_test.py')
-rw-r--r-- | tensorflow/python/grappler/cluster_test.py | 46 |
1 files changed, 46 insertions, 0 deletions
diff --git a/tensorflow/python/grappler/cluster_test.py b/tensorflow/python/grappler/cluster_test.py index 26feac0a23..f987d84e4e 100644 --- a/tensorflow/python/grappler/cluster_test.py +++ b/tensorflow/python/grappler/cluster_test.py @@ -23,6 +23,8 @@ from tensorflow.python.framework import meta_graph from tensorflow.python.framework import ops from tensorflow.python.grappler import cluster from tensorflow.python.grappler import item +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.platform import test @@ -133,6 +135,50 @@ class ClusterTest(test.TestCase): self.assertTrue(b'MatMul' in op_names) self.assertEqual(op_names, sorted(op_names)) + def testSupportDevices(self): + with ops.Graph().as_default() as g: + a = random_ops.random_uniform(shape=(2, 3)) + b = random_ops.random_uniform(shape=(2, 3)) + c = a + b + dims = math_ops.range(0, array_ops.rank(c), 1) + d = math_ops.reduce_sum(a, axis=dims) + train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) + train_op.append(d) + mg = meta_graph.create_meta_graph_def(graph=g) + grappler_item = item.Item(mg) + + device_properties = device_properties_pb2.DeviceProperties( + type='GPU', frequency=1000, num_cores=60) + named_gpu = device_properties_pb2.NamedDevice( + properties=device_properties, name='/GPU:0') + device_properties = device_properties_pb2.DeviceProperties( + type='CPU', frequency=3000, num_cores=6) + named_cpu = device_properties_pb2.NamedDevice( + properties=device_properties, name='/CPU:0') + virtual_cluster = cluster.Cluster(devices=[named_cpu, named_gpu]) + supported_dev = virtual_cluster.GetSupportedDevices(grappler_item) + self.assertEqual(supported_dev['add'], ['/CPU:0', '/GPU:0']) + self.assertEqual(supported_dev['Sum'], ['/CPU:0', '/GPU:0']) + self.assertEqual(supported_dev['range'], ['/CPU:0', '/GPU:0']) + + real_cluster = cluster.Cluster() + supported_dev = real_cluster.GetSupportedDevices(grappler_item) + if test.is_gpu_available(): + self.assertEqual(supported_dev['add'], [ + '/job:localhost/replica:0/task:0/cpu:0', + '/job:localhost/replica:0/task:0/device:GPU:0' + ]) + self.assertEqual(supported_dev['Sum'], [ + '/job:localhost/replica:0/task:0/cpu:0', + '/job:localhost/replica:0/task:0/device:GPU:0' + ]) + # The axis tensor must reside on the host + self.assertEqual(supported_dev['range'], + ['/job:localhost/replica:0/task:0/cpu:0']) + else: + self.assertEqual(supported_dev['add'], + ['/job:localhost/replica:0/task:0/cpu:0']) + if __name__ == '__main__': test.main() |