aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/profiler
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-11-30 15:20:30 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-30 15:23:54 -0800
commitce4200eae990d7f5efdfb727939d38bf48001ba2 (patch)
tree7c9c3e9cd932273198fa2ef8e4fb24637de1f42f /tensorflow/python/profiler
parent6bfc73a0b3c6810725a5eb0020470457cc5cc23e (diff)
Fix profiler to track some missed persistent bytes.
PiperOrigin-RevId: 177516249
Diffstat (limited to 'tensorflow/python/profiler')
-rw-r--r--tensorflow/python/profiler/model_analyzer_test.py40
1 files changed, 38 insertions, 2 deletions
diff --git a/tensorflow/python/profiler/model_analyzer_test.py b/tensorflow/python/profiler/model_analyzer_test.py
index 698f8906d4..c39d0fa5b1 100644
--- a/tensorflow/python/profiler/model_analyzer_test.py
+++ b/tensorflow/python/profiler/model_analyzer_test.py
@@ -23,12 +23,15 @@ import os
import random
import re
+import numpy as np
+
from tensorflow.core.profiler import profile_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
@@ -346,8 +349,8 @@ class PrintModelAnalysisTest(test.TestCase):
with gfile.Open(outfile, 'r') as f:
# pylint: disable=line-too-long
self.assertEqual(
- 'nodename|requestedbytes|peakbytes|residualbytes|outputbytes|totalexecutiontime|acceleratorexecutiontime|cpuexecutiontime|#parameters|opoccurrence(run|defined)|inputshapes\nConst0B(0',
- f.read().replace('\t', '').replace(' ', '')[0:180])
+ 'nodename|requestedbytes|peakbytes|residualbytes|outputbytes|totalexecutiontime|acceleratorexecutiontime|cpuexecutiontime|#parameters|opoccurrence(run|defined)|inputshapes',
+ f.read().replace('\t', '').replace(' ', '')[0:170])
# pylint: enable=line-too-long
total_children = 0
@@ -694,6 +697,39 @@ class PrintModelAnalysisTest(test.TestCase):
exception_str)
self.assertTrue(mat is None)
+ def testTrackPersistentBytes(self):
+ ops.reset_default_graph()
+ a = array_ops.constant(np.ones((100, 100)))
+ b = array_ops.constant(np.ones((100, 100)))
+ c = a * b
+
+ with session.Session() as sess:
+ run_options = config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE)
+ run_metadata = config_pb2.RunMetadata()
+ sess.run(c, options=run_options, run_metadata=run_metadata)
+
+ options = option_builder.ProfileOptionBuilder.time_and_memory()
+ options['min_bytes'] = 0
+ options['select'] = ('bytes', 'peak_bytes', 'output_bytes',
+ 'residual_bytes')
+ ret = model_analyzer.profile(
+ sess.graph, run_meta=run_metadata, cmd='scope', options=options)
+
+ run_metadata = config_pb2.RunMetadata()
+ sess.run(c, options=run_options, run_metadata=run_metadata)
+ ret2 = model_analyzer.profile(
+ sess.graph, run_meta=run_metadata, cmd='scope', options=options)
+
+ n = lib.SearchTFProfNode(ret, 'mul')
+ n2 = lib.SearchTFProfNode(ret2, 'mul')
+ self.assertGreater(n.peak_bytes, 0)
+ self.assertGreater(n.output_bytes, 0)
+ self.assertGreater(n.residual_bytes, 0)
+ self.assertEqual(n.peak_bytes, n2.peak_bytes)
+ self.assertEqual(n.output_bytes, n2.output_bytes)
+ self.assertEqual(n.residual_bytes, n2.residual_bytes)
+
if __name__ == '__main__':
test.main()