aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/grappler/cluster_test.py
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2017-12-21 16:27:19 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-21 16:31:44 -0800
commitaf0e847c1b3dbc6b31dcf90c63c509f5f2709a48 (patch)
tree8a54409cdfcdd92dcef825580bf80b8382d0ff7b /tensorflow/python/grappler/cluster_test.py
parent4c76bb4dadc1defb56ede40066df07916cfb64c2 (diff)
Extract placement restrictions for a graph
PiperOrigin-RevId: 179872807
Diffstat (limited to 'tensorflow/python/grappler/cluster_test.py')
-rw-r--r--tensorflow/python/grappler/cluster_test.py46
1 files changed, 46 insertions, 0 deletions
diff --git a/tensorflow/python/grappler/cluster_test.py b/tensorflow/python/grappler/cluster_test.py
index 26feac0a23..f987d84e4e 100644
--- a/tensorflow/python/grappler/cluster_test.py
+++ b/tensorflow/python/grappler/cluster_test.py
@@ -23,6 +23,8 @@ from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.grappler import cluster
from tensorflow.python.grappler import item
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
@@ -133,6 +135,50 @@ class ClusterTest(test.TestCase):
self.assertTrue(b'MatMul' in op_names)
self.assertEqual(op_names, sorted(op_names))
+ def testSupportDevices(self):
+ with ops.Graph().as_default() as g:
+ a = random_ops.random_uniform(shape=(2, 3))
+ b = random_ops.random_uniform(shape=(2, 3))
+ c = a + b
+ dims = math_ops.range(0, array_ops.rank(c), 1)
+ d = math_ops.reduce_sum(a, axis=dims)
+ train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
+ train_op.append(d)
+ mg = meta_graph.create_meta_graph_def(graph=g)
+ grappler_item = item.Item(mg)
+
+ device_properties = device_properties_pb2.DeviceProperties(
+ type='GPU', frequency=1000, num_cores=60)
+ named_gpu = device_properties_pb2.NamedDevice(
+ properties=device_properties, name='/GPU:0')
+ device_properties = device_properties_pb2.DeviceProperties(
+ type='CPU', frequency=3000, num_cores=6)
+ named_cpu = device_properties_pb2.NamedDevice(
+ properties=device_properties, name='/CPU:0')
+ virtual_cluster = cluster.Cluster(devices=[named_cpu, named_gpu])
+ supported_dev = virtual_cluster.GetSupportedDevices(grappler_item)
+ self.assertEqual(supported_dev['add'], ['/CPU:0', '/GPU:0'])
+ self.assertEqual(supported_dev['Sum'], ['/CPU:0', '/GPU:0'])
+ self.assertEqual(supported_dev['range'], ['/CPU:0', '/GPU:0'])
+
+ real_cluster = cluster.Cluster()
+ supported_dev = real_cluster.GetSupportedDevices(grappler_item)
+ if test.is_gpu_available():
+ self.assertEqual(supported_dev['add'], [
+ '/job:localhost/replica:0/task:0/cpu:0',
+ '/job:localhost/replica:0/task:0/device:GPU:0'
+ ])
+ self.assertEqual(supported_dev['Sum'], [
+ '/job:localhost/replica:0/task:0/cpu:0',
+ '/job:localhost/replica:0/task:0/device:GPU:0'
+ ])
+ # The axis tensor must reside on the host
+ self.assertEqual(supported_dev['range'],
+ ['/job:localhost/replica:0/task:0/cpu:0'])
+ else:
+ self.assertEqual(supported_dev['add'],
+ ['/job:localhost/replica:0/task:0/cpu:0'])
+
if __name__ == '__main__':
test.main()