aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/grappler
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2018-03-27 19:24:40 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-27 19:27:05 -0700
commitf656b7f3e07fc3a6a51cb6083d27abebcc6212bb (patch)
tree2ca23cb69bb312a7b3fa8be5260654e65eae80f9 /tensorflow/python/grappler
parent9b5411a13c2d983e6709e6dff4f82b1779389ece (diff)
Fixed the interaction between virtual cluster and measuring cost estimator.
PiperOrigin-RevId: 190712404
Diffstat (limited to 'tensorflow/python/grappler')
-rw-r--r--tensorflow/python/grappler/cluster_test.py16
1 files changed, 10 insertions, 6 deletions
diff --git a/tensorflow/python/grappler/cluster_test.py b/tensorflow/python/grappler/cluster_test.py
index a3c4c2bbeb..26c6f22d34 100644
--- a/tensorflow/python/grappler/cluster_test.py
+++ b/tensorflow/python/grappler/cluster_test.py
@@ -87,9 +87,10 @@ class ClusterTest(test.TestCase):
def testVirtualCluster(self):
with ops.Graph().as_default() as g:
- a = random_ops.random_uniform(shape=())
- b = random_ops.random_uniform(shape=())
- c = a + b
+ with ops.device('/device:GPU:0'):
+ a = random_ops.random_uniform(shape=[1024, 1024])
+ b = random_ops.random_uniform(shape=[1024, 1024])
+ c = a + b
train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
train_op.append(c)
mg = meta_graph.create_meta_graph_def(graph=g)
@@ -102,10 +103,13 @@ class ClusterTest(test.TestCase):
'architecture': '7'
})
named_device = device_properties_pb2.NamedDevice(
- properties=device_properties, name='/GPU:0')
- grappler_cluster = cluster.Cluster(devices=[named_device])
+ properties=device_properties, name='/device:GPU:0')
+ grappler_cluster = cluster.Cluster(
+ disable_detailed_stats=False,
+ disable_timeline=False,
+ devices=[named_device])
op_perfs, run_time, _ = grappler_cluster.MeasureCosts(grappler_item)
- self.assertGreater(run_time, 0)
+ self.assertEqual(run_time, 0.000545)
self.assertEqual(len(op_perfs), 15)
estimated_perf = grappler_cluster.EstimatePerformance(named_device)