author     A. Unique TensorFlower <gardener@tensorflow.org>  2017-09-06 21:56:33 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-09-06 22:01:54 -0700
commit     9e9ffa33d154b6c332bd475d6da2030746460fed (patch)
tree       64de1c00853926d394f6f4ba82f01f448e990fda /tensorflow/python/profiler
parent     27ce2a4ac956941ba8a0b9aaaa77acc0aa861fef (diff)
Unify all profile files (graph, run_meta, op_log) into one.
Also allow the profiler to serialize/deserialize to/from file.

PiperOrigin-RevId: 167815923
Diffstat (limited to 'tensorflow/python/profiler')
-rw-r--r--  tensorflow/python/profiler/BUILD                                |  15
-rw-r--r--  tensorflow/python/profiler/internal/model_analyzer_testlib.py  |  14
-rw-r--r--  tensorflow/python/profiler/model_analyzer.py                    |  33
-rw-r--r--  tensorflow/python/profiler/model_analyzer_test.py               | 208
-rw-r--r--  tensorflow/python/profiler/profile_context.py                   | 293
-rw-r--r--  tensorflow/python/profiler/profile_context_test.py              |  71
6 files changed, 394 insertions, 240 deletions
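
Before the file-by-file hunks, here is a short sketch of the API this change converges on, following the usage examples in the rewritten ProfileContext docstring further down. It is a minimal illustration only: the toy graph, the `/tmp/train_dir` path, and the step numbers are placeholders, not part of the commit.

```python
# Minimal sketch of the unified ProfileContext API (TF 1.x, circa this commit).
# The graph, directory, and step numbers are illustrative placeholders.
import tensorflow as tf

x = tf.Variable(tf.random_normal([128, 128]))
y = tf.matmul(x, x)

opts = tf.profiler.ProfileOptionBuilder.time_and_memory()

# Trace steps 10-19, auto-profile op stats at steps 15/18/20, and dump a
# single unified profile file (graph + run_meta + op_log) at step 20.
with tf.contrib.tfprof.ProfileContext('/tmp/train_dir',
                                      trace_steps=range(10, 20),
                                      dump_steps=[20]) as pctx:
  pctx.add_auto_profiling('op', opts, [15, 18, 20])
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(25):
      sess.run(y)

# Explicit control when tracing and dumping are driven by the caller instead
# of by step lists: trace and dump on demand, then query the profiler directly.
with tf.contrib.tfprof.ProfileContext('/tmp/train_dir',
                                      trace_steps=[], dump_steps=[]) as pctx:
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    pctx.trace_next_step()   # capture RunMetadata at the next session.run
    pctx.dump_next_step()    # write a unified profile_<step> file after it
    sess.run(y)
    pctx.profiler.profile_operations(options=opts)
```

The key difference from the old API is that tracing, auto-profiling, and dumping are all driven from one ProfileContext, and each dump is a single `profile_<step>` file instead of separate `graph.pbtxt`, `run_metadata`, and `tfprof_log` files.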
diff --git a/tensorflow/python/profiler/BUILD b/tensorflow/python/profiler/BUILD
index 32ecde0243..8dd2413661 100644
--- a/tensorflow/python/profiler/BUILD
+++ b/tensorflow/python/profiler/BUILD
@@ -112,6 +112,21 @@ py_library(
],
)
+cuda_py_test(
+ name = "profile_context_test",
+ srcs = ["profile_context_test.py"],
+ additional_deps = [
+ ":profile_context",
+ "//tensorflow/python:client",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:platform",
+ "//tensorflow/python/profiler/internal:model_analyzer_testlib",
+ "//tensorflow/python:variables",
+ ],
+ tags = ["no_pip"],
+)
+
py_library(
name = "pprof_profiler",
srcs = ["pprof_profiler.py"],
diff --git a/tensorflow/python/profiler/internal/model_analyzer_testlib.py b/tensorflow/python/profiler/internal/model_analyzer_testlib.py
index 42b83fde7c..350a62c0ea 100644
--- a/tensorflow/python/profiler/internal/model_analyzer_testlib.py
+++ b/tensorflow/python/profiler/internal/model_analyzer_testlib.py
@@ -17,6 +17,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import contextlib
+
+from tensorflow.python import pywrap_tensorflow as print_mdl
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
@@ -27,7 +30,9 @@ from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import
from tensorflow.python.ops import variable_scope
+from tensorflow.python.profiler import model_analyzer
from tensorflow.python.training import gradient_descent
+from tensorflow.python.util import compat
def BuildSmallModel():
@@ -95,3 +100,12 @@ def SearchTFProfNode(node, name):
r = SearchTFProfNode(c, name)
if r: return r
return None
+
+
+@contextlib.contextmanager
+def ProfilerFromFile(profile_file):
+ """Initialize a profiler from profile file."""
+ print_mdl.ProfilerFromFile(compat.as_bytes(profile_file))
+ profiler = model_analyzer.Profiler.__new__(model_analyzer.Profiler)
+ yield profiler
+ print_mdl.DeleteProfiler()
diff --git a/tensorflow/python/profiler/model_analyzer.py b/tensorflow/python/profiler/model_analyzer.py
index a1fe47982f..98d3e58f2a 100644
--- a/tensorflow/python/profiler/model_analyzer.py
+++ b/tensorflow/python/profiler/model_analyzer.py
@@ -183,8 +183,11 @@ class Profiler(object):
self._graph, run_meta=run_meta, add_trace=False,
add_trainable_var=False)
# pylint: enable=protected-access
+ # TODO(xpan): P1: Better to find the current graph.
print_mdl.AddStep(
- step, run_meta.SerializeToString(), op_log.SerializeToString())
+ step,
+ self._graph.as_graph_def(add_shapes=True).SerializeToString(),
+ run_meta.SerializeToString(), op_log.SerializeToString())
def profile_python(self, options):
"""Profile the statistics of the Python codes.
@@ -200,8 +203,11 @@ class Profiler(object):
"""
opts = _build_options(options)
tfprof_node = tfprof_output_pb2.MultiGraphNodeProto()
- tfprof_node.ParseFromString(
- print_mdl.Profile('code'.encode('utf-8'), opts.SerializeToString()))
+ try:
+ tfprof_node.ParseFromString(
+ print_mdl.Profile('code'.encode('utf-8'), opts.SerializeToString()))
+ except message.DecodeError as _:
+ pass
return tfprof_node
def profile_operations(self, options):
@@ -214,8 +220,11 @@ class Profiler(object):
"""
opts = _build_options(options)
tfprof_node = tfprof_output_pb2.MultiGraphNodeProto()
- tfprof_node.ParseFromString(
- print_mdl.Profile('op'.encode('utf-8'), opts.SerializeToString()))
+ try:
+ tfprof_node.ParseFromString(
+ print_mdl.Profile('op'.encode('utf-8'), opts.SerializeToString()))
+ except message.DecodeError as _:
+ pass
return tfprof_node
def profile_name_scope(self, options):
@@ -228,8 +237,11 @@ class Profiler(object):
"""
opts = _build_options(options)
tfprof_node = tfprof_output_pb2.GraphNodeProto()
- tfprof_node.ParseFromString(
- print_mdl.Profile('scope'.encode('utf-8'), opts.SerializeToString()))
+ try:
+ tfprof_node.ParseFromString(
+ print_mdl.Profile('scope'.encode('utf-8'), opts.SerializeToString()))
+ except message.DecodeError as _:
+ pass
return tfprof_node
def profile_graph(self, options):
@@ -242,8 +254,11 @@ class Profiler(object):
"""
opts = _build_options(options)
tfprof_node = tfprof_output_pb2.GraphNodeProto()
- tfprof_node.ParseFromString(
- print_mdl.Profile('graph'.encode('utf-8'), opts.SerializeToString()))
+ try:
+ tfprof_node.ParseFromString(
+ print_mdl.Profile('graph'.encode('utf-8'), opts.SerializeToString()))
+ except message.DecodeError as _:
+ pass
return tfprof_node
def advise(self, options):
diff --git a/tensorflow/python/profiler/model_analyzer_test.py b/tensorflow/python/profiler/model_analyzer_test.py
index 3432765b60..dcdda1ffa2 100644
--- a/tensorflow/python/profiler/model_analyzer_test.py
+++ b/tensorflow/python/profiler/model_analyzer_test.py
@@ -69,57 +69,65 @@ class PrintModelAnalysisTest(test.TestCase):
.select(['micros', 'bytes', 'params', 'float_ops', 'occurrence',
'device', 'op_types', 'input_shapes']).build())
- config = config_pb2.ConfigProto()
- with session.Session(config=config) as sess, ops.device(dev):
- x = lib.BuildSmallModel()
-
- sess.run(variables.global_variables_initializer())
- run_meta = config_pb2.RunMetadata()
- _ = sess.run(x,
- options=config_pb2.RunOptions(
- trace_level=config_pb2.RunOptions.FULL_TRACE),
- run_metadata=run_meta)
-
- model_analyzer.profile(
- sess.graph, run_meta, options=opts)
-
- with gfile.Open(outfile, 'r') as f:
- # pylint: disable=line-too-long
- outputs = f.read().split('\n')
-
- self.assertEqual(outputs[0],
- 'node name | # parameters | # float_ops | requested bytes | total execution time | accelerator execution time | cpu execution time | assigned devices | op types | op count (run|defined) | input shapes')
- for o in outputs[1:]:
- if o.find('Conv2D ') > 0:
- metrics = o[o.find('(') +1: o.find(')')].split(',')
- # Make sure time is profiled.
- gap = 1 if test.is_gpu_available() else 2
- for i in range(3, 6, gap):
- mat = re.search('(.*)[um]s/(.*)[um]s', metrics[i])
+ with profile_context.ProfileContext(test.get_temp_dir(),
+ trace_steps=[],
+ dump_steps=[]) as pctx:
+ with session.Session() as sess, ops.device(dev):
+ x = lib.BuildSmallModel()
+
+ sess.run(variables.global_variables_initializer())
+ pctx.trace_next_step()
+ pctx.dump_next_step()
+ _ = sess.run(x)
+
+ pctx.profiler.profile_name_scope(options=opts)
+
+ with gfile.Open(outfile, 'r') as f:
+ # pylint: disable=line-too-long
+ dump_str = f.read()
+ outputs = dump_str.split('\n')
+
+ self.assertEqual(outputs[0],
+ 'node name | # parameters | # float_ops | requested bytes | total execution time | accelerator execution time | cpu execution time | assigned devices | op types | op count (run|defined) | input shapes')
+ for o in outputs[1:]:
+ if o.find('Conv2D ') > 0:
+ metrics = o[o.find('(') +1: o.find(')')].split(',')
+ # Make sure time is profiled.
+ gap = 1 if test.is_gpu_available() else 2
+ for i in range(3, 6, gap):
+ mat = re.search('(.*)[um]s/(.*)[um]s', metrics[i])
+ self.assertGreater(float(mat.group(1)), 0.0)
+ self.assertGreater(float(mat.group(2)), 0.0)
+ # Make sure device is profiled.
+ if test.is_gpu_available():
+ self.assertTrue(metrics[6].find('gpu') > 0)
+ self.assertFalse(metrics[6].find('cpu') > 0)
+ else:
+ self.assertFalse(metrics[6].find('gpu') > 0)
+ self.assertTrue(metrics[6].find('cpu') > 0)
+ # Make sure float_ops is profiled.
+ mat = re.search('(.*)k/(.*)k flops', metrics[1].strip())
self.assertGreater(float(mat.group(1)), 0.0)
self.assertGreater(float(mat.group(2)), 0.0)
- # Make sure device is profiled.
- if test.is_gpu_available():
- self.assertTrue(metrics[6].find('gpu') > 0)
- self.assertFalse(metrics[6].find('cpu') > 0)
- else:
- self.assertFalse(metrics[6].find('gpu') > 0)
- self.assertTrue(metrics[6].find('cpu') > 0)
- # Make sure float_ops is profiled.
- mat = re.search('(.*)k/(.*)k flops', metrics[1].strip())
- self.assertGreater(float(mat.group(1)), 0.0)
- self.assertGreater(float(mat.group(2)), 0.0)
- # Make sure op_count is profiled.
- self.assertEqual(metrics[8].strip(), '1/1|1/1')
- # Make sure input_shapes is profiled.
- self.assertEqual(metrics[9].strip(), '0:2x6x6x3|1:3x3x3x6')
-
- if o.find('DW (3x3x3x6') > 0:
- metrics = o[o.find('(') +1: o.find(')')].split(',')
- mat = re.search('(.*)/(.*) params', metrics[1].strip())
- self.assertGreater(float(mat.group(1)), 0.0)
- self.assertGreater(float(mat.group(2)), 0.0)
- # pylint: enable=line-too-long
+ # Make sure op_count is profiled.
+ self.assertEqual(metrics[8].strip(), '1/1|1/1')
+ # Make sure input_shapes is profiled.
+ self.assertEqual(metrics[9].strip(), '0:2x6x6x3|1:3x3x3x6')
+
+ if o.find('DW (3x3x3x6') > 0:
+ metrics = o[o.find('(') +1: o.find(')')].split(',')
+ mat = re.search('(.*)/(.*) params', metrics[1].strip())
+ self.assertGreater(float(mat.group(1)), 0.0)
+ self.assertGreater(float(mat.group(2)), 0.0)
+ # pylint: enable=line-too-long
+
+ # Test that profiler restored from profile file gives the same result.
+ gfile.Remove(outfile)
+ profile_file = os.path.join(test.get_temp_dir(), 'profile_1')
+ with lib.ProfilerFromFile(profile_file) as profiler:
+ profiler.profile_name_scope(options=opts)
+ with gfile.Open(outfile, 'r') as f:
+ self.assertEqual(dump_str, f.read())
def testSelectEverything(self):
ops.reset_default_graph()
@@ -198,56 +206,54 @@ class PrintModelAnalysisTest(test.TestCase):
.account_displayed_op_only(False)
.select(['params', 'float_ops']).build())
- with session.Session() as sess:
- x = lib.BuildFullModel()
+ with profile_context.ProfileContext(test.get_temp_dir(),
+ trace_steps=[],
+ dump_steps=[]) as pctx:
+ with session.Session() as sess:
+ x = lib.BuildFullModel()
- sess.run(variables.global_variables_initializer())
- run_meta = config_pb2.RunMetadata()
- _ = sess.run(x,
- options=config_pb2.RunOptions(
- trace_level=config_pb2.RunOptions.FULL_TRACE),
- run_metadata=run_meta)
+ sess.run(variables.global_variables_initializer())
+ pctx.trace_next_step()
+ _ = sess.run(x)
+ tfprof_node = pctx.profiler.profile_python(options=opts)
- tfprof_node = model_analyzer.profile(
- sess.graph, run_meta, cmd='code', options=opts)
-
- # pylint: disable=line-too-long
- with gfile.Open(outfile, 'r') as f:
- lines = f.read().split('\n')
- result = '\n'.join([l[:min(len(l), 80)] for l in lines])
- self.assertEqual('node name | # parameters | # float_ops\n_TFProfRoot (--/2.84k params, --/91.04k flops)\n model_analyzer_testlib.py:58:BuildFullModel (0/1.80k params, 0/41.76k flops)\n model_analyzer_testlib.py:35:BuildSmallModel (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:39:BuildSmallModel (0/4 params, 0/0 flops)\n model_analyzer_testlib.py:43:BuildSmallModel (0/648 params, 0/0 flops)\n model_analyzer_testlib.py:44:BuildSmallModel (0/0 params, 0/23.33k flops)\n model_analyzer_testlib.py:48:BuildSmallModel (0/1.15k params, 0/0 flops)\n model_analyzer_testlib.py:49:BuildSmallModel (0/0 params, 0/18.43k flops)\n model_analyzer_testlib.py:58:BuildFullModel (gradient) (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:44:BuildSmallModel (gradient) (0/0 params, 0/0 flo\n model_analyzer_testlib.py:49:BuildSmallModel (gradient) (0/0 params, 0/0 flo\n model_analyzer_testlib.py:62:BuildFullModel (0/1.04k params, 0/16.51k flops)\n model_analyzer_testlib.py:62:BuildFullModel (gradient) (0/0 params, 0/32.77k f\n model_analyzer_testlib.py:64:BuildFullModel (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:65:BuildFullModel (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:65:BuildFullModel (gradient) (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:67:BuildFullModel (0/0 params, 0/0 flops)\n',
- result)
-
- self.assertLess(0, tfprof_node.total_exec_micros)
- self.assertEqual(2844, tfprof_node.total_parameters)
- self.assertEqual(91040, tfprof_node.total_float_ops)
- self.assertEqual(8, len(tfprof_node.children))
- self.assertEqual('_TFProfRoot', tfprof_node.name)
- self.assertEqual(
- 'model_analyzer_testlib.py:58:BuildFullModel',
- tfprof_node.children[0].name)
- self.assertEqual(
- 'model_analyzer_testlib.py:58:BuildFullModel (gradient)',
- tfprof_node.children[1].name)
- self.assertEqual(
- 'model_analyzer_testlib.py:62:BuildFullModel',
- tfprof_node.children[2].name)
- self.assertEqual(
- 'model_analyzer_testlib.py:62:BuildFullModel (gradient)',
- tfprof_node.children[3].name)
- self.assertEqual(
- 'model_analyzer_testlib.py:64:BuildFullModel',
- tfprof_node.children[4].name)
- self.assertEqual(
- 'model_analyzer_testlib.py:65:BuildFullModel',
- tfprof_node.children[5].name)
- self.assertEqual(
- 'model_analyzer_testlib.py:65:BuildFullModel (gradient)',
- tfprof_node.children[6].name)
- self.assertEqual(
- 'model_analyzer_testlib.py:67:BuildFullModel',
- tfprof_node.children[7].name)
- # pylint: enable=line-too-long
+ # pylint: disable=line-too-long
+ with gfile.Open(outfile, 'r') as f:
+ lines = f.read().split('\n')
+ result = '\n'.join([l[:min(len(l), 80)] for l in lines])
+ self.assertEqual('node name | # parameters | # float_ops\n_TFProfRoot (--/2.84k params, --/91.04k flops)\n model_analyzer_testlib.py:63:BuildFullModel (0/1.80k params, 0/41.76k flops)\n model_analyzer_testlib.py:40:BuildSmallModel (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:44:BuildSmallModel (0/4 params, 0/0 flops)\n model_analyzer_testlib.py:48:BuildSmallModel (0/648 params, 0/0 flops)\n model_analyzer_testlib.py:49:BuildSmallModel (0/0 params, 0/23.33k flops)\n model_analyzer_testlib.py:53:BuildSmallModel (0/1.15k params, 0/0 flops)\n model_analyzer_testlib.py:54:BuildSmallModel (0/0 params, 0/18.43k flops)\n model_analyzer_testlib.py:63:BuildFullModel (gradient) (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:49:BuildSmallModel (gradient) (0/0 params, 0/0 flo\n model_analyzer_testlib.py:54:BuildSmallModel (gradient) (0/0 params, 0/0 flo\n model_analyzer_testlib.py:67:BuildFullModel (0/1.04k params, 0/16.51k flops)\n model_analyzer_testlib.py:67:BuildFullModel (gradient) (0/0 params, 0/32.77k f\n model_analyzer_testlib.py:69:BuildFullModel (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:70:BuildFullModel (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:70:BuildFullModel (gradient) (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:72:BuildFullModel (0/0 params, 0/0 flops)\n',
+ result)
+
+ self.assertLess(0, tfprof_node.total_exec_micros)
+ self.assertEqual(2844, tfprof_node.total_parameters)
+ self.assertEqual(91040, tfprof_node.total_float_ops)
+ self.assertEqual(8, len(tfprof_node.children))
+ self.assertEqual('_TFProfRoot', tfprof_node.name)
+ self.assertEqual(
+ 'model_analyzer_testlib.py:63:BuildFullModel',
+ tfprof_node.children[0].name)
+ self.assertEqual(
+ 'model_analyzer_testlib.py:63:BuildFullModel (gradient)',
+ tfprof_node.children[1].name)
+ self.assertEqual(
+ 'model_analyzer_testlib.py:67:BuildFullModel',
+ tfprof_node.children[2].name)
+ self.assertEqual(
+ 'model_analyzer_testlib.py:67:BuildFullModel (gradient)',
+ tfprof_node.children[3].name)
+ self.assertEqual(
+ 'model_analyzer_testlib.py:69:BuildFullModel',
+ tfprof_node.children[4].name)
+ self.assertEqual(
+ 'model_analyzer_testlib.py:70:BuildFullModel',
+ tfprof_node.children[5].name)
+ self.assertEqual(
+ 'model_analyzer_testlib.py:70:BuildFullModel (gradient)',
+ tfprof_node.children[6].name)
+ self.assertEqual(
+ 'model_analyzer_testlib.py:72:BuildFullModel',
+ tfprof_node.children[7].name)
+ # pylint: enable=line-too-long
def testCodeViewLeafGraphNode(self):
ops.reset_default_graph()
@@ -590,8 +596,7 @@ class PrintModelAnalysisTest(test.TestCase):
self.assertEqual(len(gfile.ListDirectory(memory_dir)), 0)
if i in dump_step:
ret = gfile.ListDirectory(profile_dir)
- self.assertAllEqual(sorted(ret),
- ['graph.pbtxt', 'run_metadata', 'tfprof_log'])
+ self.assertAllEqual(ret, ['profile_%d' % i])
_ = [gfile.Remove(os.path.join(profile_dir, x)) for x in ret]
else:
if i < dump_step[0]:
@@ -620,10 +625,11 @@ class PrintModelAnalysisTest(test.TestCase):
dump_steps = [3, 4]
x = lib.BuildSmallModel()
- with profile_context.ProfileContext() as pctx:
+ with profile_context.ProfileContext(profile_dir,
+ trace_steps=[1, 2, 3],
+ dump_steps=[3, 4]) as pctx:
pctx.add_auto_profiling('scope', time_opts, time_steps)
pctx.add_auto_profiling('scope', memory_opts, memory_steps)
- pctx.add_auto_profile_dump(profile_dir, dump_steps)
self._trainLoop(x, 10, time_dir, time_steps,
memory_dir, memory_steps, profile_dir, dump_steps)
diff --git a/tensorflow/python/profiler/profile_context.py b/tensorflow/python/profiler/profile_context.py
index 6438fede2f..07adcb9c3f 100644
--- a/tensorflow/python/profiler/profile_context.py
+++ b/tensorflow/python/profiler/profile_context.py
@@ -18,15 +18,19 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import contextlib
import os
-import threading
from tensorflow.core.protobuf import config_pb2
+from tensorflow.python import pywrap_tensorflow as print_mdl
from tensorflow.python.client import session
from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
from tensorflow.python.platform import gfile
from tensorflow.python.profiler import model_analyzer
-from tensorflow.python.profiler import tfprof_logger
+from tensorflow.python.util import compat
+
+MAX_TRACED_STEPS = 100
def _profiled_init(self, target='', graph=None, config=None):
@@ -42,51 +46,56 @@ def _profiled_run(self,
"""Overwrites the session.run()."""
# pylint: disable=protected-access
# Count the session steps.
- self.profile_context._new_step()
- # Fast path if no need for profiling.
- to_profiles = self.profile_context._profile_candidates()
- to_dumps = self.profile_context._dump_candidates()
- if (not to_profiles and not to_dumps and
- not self.profile_context._is_capture_enforced()):
- return self._profiler_run_internal(
- fetches, feed_dict, options, run_metadata)
-
- # Enable tracing, perform auto profiling or auto dump.
- if not run_metadata:
- run_metadata = config_pb2.RunMetadata()
-
- if not options:
- options = config_pb2.RunOptions(
- trace_level=config_pb2.RunOptions.FULL_TRACE)
- old_trace_level = options.trace_level
- else:
- old_trace_level = options.trace_level
- options.trace_level = config_pb2.RunOptions.FULL_TRACE
-
- ret = self._profiler_run_internal(fetches, feed_dict, options, run_metadata)
-
- if self.profile_context._capture_next_step:
- self.profile_context._add_run_meta(run_metadata)
-
- for to_dump in to_dumps:
- outdir, _ = to_dump
- if not gfile.Exists(outdir):
- gfile.MakeDirs(outdir)
- with gfile.Open(os.path.join(outdir, 'graph.pbtxt'), 'w') as f:
- f.write('%s' % self.graph.as_graph_def(add_shapes=True))
- with gfile.Open(os.path.join(outdir, 'run_metadata'), 'w') as f:
- f.write(run_metadata.SerializeToString())
- tfprof_logger.write_op_log(
- self.graph, outdir, run_meta=run_metadata, add_trace=True)
-
- for to_prof in to_profiles:
- cmd, opts, _ = to_prof
- model_analyzer.profile(
- self.graph, run_meta=run_metadata, cmd=cmd, options=opts)
-
- # Restore to default.
- options.trace_level = old_trace_level
- return ret
+ with self.profile_context._new_step():
+ # Fast path if no need for profiling.
+ if self.profile_context._is_fast_path():
+ return self._profiler_run_internal(
+ fetches, feed_dict, options, run_metadata)
+
+ step = self.profile_context._step
+
+ # Maybe trace this step.
+ if self.profile_context._should_trace():
+ # Enable tracing, perform auto profiling or auto dump.
+ if not run_metadata:
+ run_metadata = config_pb2.RunMetadata()
+
+ if not options:
+ options = config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE)
+ old_trace_level = options.trace_level
+ else:
+ old_trace_level = options.trace_level
+ options.trace_level = config_pb2.RunOptions.FULL_TRACE
+
+ ret = self._profiler_run_internal(
+ fetches, feed_dict, options, run_metadata)
+
+ self.profile_context.profiler._graph = self.graph
+ self.profile_context.profiler.add_step(step, run_metadata)
+ options.trace_level = old_trace_level
+ else:
+ ret = self._profiler_run_internal(fetches, feed_dict, options)
+
+ # Maybe dump profile.
+ self.profile_context._maybe_dump()
+
+ # Maybe profile:
+ to_profiles = self.profile_context._profile_candidates()
+ for to_prof in to_profiles:
+ cmd, opts, _ = to_prof
+ if cmd == 'graph':
+ self.profile_context.profiler.profile_graph(opts)
+ elif cmd == 'scope':
+ self.profile_context.profiler.profile_name_scope(opts)
+ elif cmd == 'op':
+ self.profile_context.profiler.profile_operations(opts)
+ elif cmd == 'code':
+ self.profile_context.profiler.profile_python(opts)
+ else:
+ raise ValueError('Unknown cmd: %s\n' % cmd)
+
+ return ret
# pylint: enable=protected-access
@@ -94,112 +103,135 @@ class ProfileContext(object):
"""A Context that captures RunMetadata and performs profiling.
```python
- # Auto profiling at step 1, 100 and 1000.:
- with tf.contrib.tfprof.ProfileContext() as pctx:
- # Create the profiling options.
+ # Trace steps 10~20, profile at [15, 18, 20] and dump profile at 20.
+ with tf.contrib.tfprof.ProfileContext('/tmp/train_dir',
+ trace_steps=range(10, 20),
+ dump_steps=[20]) as pctx:
opts = tf.profiler.ProfileOptionBuilder.time_and_memory()
- # Run profiling at certain steps. Multiple ones can be added.
- pctx.add_auto_profiling('op', opts, [1, 100, 1000])
- # Or dump the profile files at certain steps.
- pctx.add_auto_profile_dump('/tmp/profiles', [1000])
- # Run train/eval loop.
+ pctx.add_auto_profiling('op', opts, [15, 18, 20])
+ train_loop().
+
+ # Tracing only.
+ with tf.contrib.tfprof.ProfileContext('/tmp/train_dir') as pctx:
+ # Run train/eval loop for at least few hundred steps. Profiles will be
+ # dumped to train_dir. Use web UI or command line to do profiling.
train_loop().
- # Alternatively, enable and capture RunMetadata of next step.
- with tf.contrib.tfprof.ProfileContext() as pctx:
- pctx.capture_next_run_meta()
+ # When session object is available, do explicit trace, profile and dump.
+ with tf.contrib.tfprof.ProfileContext('/tmp/train_dir',
+ trace_steps=[],
+ dump_steps=[]) as pctx:
opts = tf.profiler.ProfileOptionBuilder.time_and_memory()
+ pctx.trace_next_step()
_ = session.run(train_op)
- tf.profiler.profile(session.graph,
- run_meta=pctx.run_meta(),
- cmd='op',
- options=opts)
+ pctx.profiler.profile_operations(options=opts)
```
+
+ Args:
+ profile_dir: Directory to store profiles.
+ trace_steps: A list of session run steps to trace. If None, use
+ pre-defined steps.
+ dump_steps: A list of steps to dump the profile to `profile_dir`. If None,
+ use pre-defined steps.
"""
- def __init__(self):
- self._lock = threading.Lock()
- self._capture_next_step = False
+ def __init__(self,
+ profile_dir,
+ trace_steps=None,
+ dump_steps=None):
+ if not profile_dir:
+ raise ValueError('Must have a directory for profile.\n')
+ self._profiler_dir = profile_dir
+
+ if trace_steps is None:
+ self._trace_steps = set(list(range(10, 50, 3)) +
+ list(range(100, 10000, 1000)))
+ else:
+ if len(trace_steps) > MAX_TRACED_STEPS:
+ raise ValueError('Only support tracing up to 100 steps.\n')
+ self._trace_steps = set(trace_steps[:])
+
+ if dump_steps is None:
+ self._dump_steps = set([50] + list(range(100, 10000, 2000)))
+ else:
+ self._dump_steps = set(dump_steps[:])
+
+ self._slow_path_steps = self._dump_steps | self._trace_steps
+ self._trace_next_step = False
+ self._dump_next_step = False
self._step = 0
+ self._traced_steps = 0
self._auto_profiles = []
- self._auto_dumps = []
- self._run_meta = None
+ self._profiler = None
- def add_auto_profiling(self, cmd, profile_options, profile_steps):
- """Runs profiling at some steps with provided command and options.
+ def add_auto_profiling(self, cmd, options, profile_steps):
+ """Traces and profiles at some session run steps.
Args:
- cmd: The profiling commands.
- profile_options: The profiling options.
+ cmd: The profiling commands. (i.e. scope, op, python, graph)
+ options: The profiling options.
profile_steps: A list/set of integers. The profiling command and options
will be run automatically at these integer steps. Each step is
a session.run.
"""
- with self._lock:
- self._auto_profiles.append((cmd, profile_options, profile_steps))
-
- def add_auto_profile_dump(self, outdir, dump_steps):
- """Dumps profiles at some steps to the directory.
-
- Args:
- outdir: The directory to dump the profile files.
- dump_steps: A list/set of integers. The profile files will be dump at
- these integer steps. Each step is a session.run.
- """
- with self._lock:
- self._auto_dumps.append((outdir, dump_steps))
-
- def capture_next_run_meta(self):
- """Enables tracing and captures RunMetadata at next session.run.
-
- The captured RunMetadata can be retrieved via run_meta(). It
- will be cleared one step later.
- """
- with self._lock:
- self._capture_next_step = True
-
- def run_meta(self):
- """Returns the RunMetadata captured at previous session.run.
-
- Needs to call capture_next_run_meta() before session.run to enable
- capturing.
- """
- with self._lock:
- assert self._run_meta, 'Need to call capture_next_run_meta()'
- return self._run_meta
-
- def _is_capture_enforced(self):
- with self._lock:
- return self._capture_next_step
-
- def _add_run_meta(self, run_meta):
- with self._lock:
- self._run_meta = run_meta
- self._capture_next_step = False
-
+ self._auto_profiles.append((cmd, options, profile_steps[:]))
+ self._slow_path_steps |= set(profile_steps)
+ self._trace_steps |= set(profile_steps)
+
+ @property
+ def profiler(self):
+ """Returns the current profiler object."""
+ if not self._profiler:
+ self._profiler = model_analyzer.Profiler(ops.get_default_graph())
+ return self._profiler
+
+ def trace_next_step(self):
+ """Enables tracing and add traces to profiler at next step."""
+ self._trace_next_step = True
+
+ def dump_next_step(self):
+ """Enable tracing and dump profiles at next step."""
+ self._dump_next_step = True
+
+ def _is_fast_path(self):
+ if (self._step in self._slow_path_steps or
+ self._trace_next_step or
+ self._dump_next_step):
+ return False
+ return True
+
+ def _should_trace(self):
+ if self._traced_steps > MAX_TRACED_STEPS:
+ return False
+ trace = self._step in self._trace_steps or self._trace_next_step
+ if trace:
+ self._traced_steps += 1
+ return trace
+
+ def _maybe_dump(self):
+ if not (self._step in self._dump_steps or self._dump_next_step):
+ return
+ if not gfile.Exists(self._profiler_dir):
+ gfile.MakeDirs(self._profiler_dir)
+ print_mdl.WriteProfile(
+ os.path.join(compat.as_bytes(self._profiler_dir),
+ compat.as_bytes('profile_%d' % self._step)))
+
+ @contextlib.contextmanager
def _new_step(self):
- with self._lock:
- self._run_meta = None
- self._step += 1
+ yield
+ self._step += 1
+ self._trace_next_step = False
+ self._dump_next_step = False
def _profile_candidates(self):
to_profile = []
- with self._lock:
- for auto_prof in self._auto_profiles:
- _, _, prof_steps = auto_prof
- if self._step - 1 in prof_steps:
- to_profile.append(auto_prof)
+ for auto_prof in self._auto_profiles:
+ _, _, prof_steps = auto_prof
+ if self._step in prof_steps:
+ to_profile.append(auto_prof)
return to_profile
- def _dump_candidates(self):
- to_dump = []
- with self._lock:
- for auto_dump in self._auto_dumps:
- _, dump_steps = auto_dump
- if self._step - 1 in dump_steps:
- to_dump.append(auto_dump)
- return to_dump
-
def __enter__(self):
self.old_run = getattr(session.BaseSession, 'run', None)
self.old_init = getattr(session.BaseSession, '__init__', None)
@@ -223,6 +255,7 @@ class ProfileContext(object):
return self
def __exit__(self, exec_type, exec_value, exec_tb):
+ print_mdl.DeleteProfiler()
setattr(session.BaseSession, 'run', self.old_run)
setattr(session.BaseSession, '__init__', self.old_init)
setattr(session.BaseSession, '_profiler_run_internal', None)
diff --git a/tensorflow/python/profiler/profile_context_test.py b/tensorflow/python/profiler/profile_context_test.py
new file mode 100644
index 0000000000..e3ad00da66
--- /dev/null
+++ b/tensorflow/python/profiler/profile_context_test.py
@@ -0,0 +1,71 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+from tensorflow.python.client import session
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import test
+from tensorflow.python.profiler import option_builder
+
+# pylint: disable=g-bad-import-order
+from tensorflow.python.profiler import profile_context
+from tensorflow.python.profiler.internal import model_analyzer_testlib as lib
+
+builder = option_builder.ProfileOptionBuilder
+
+
+class ProfilerContextTest(test.TestCase):
+
+ def testBasics(self):
+ ops.reset_default_graph()
+ outfile = os.path.join(test.get_temp_dir(), "dump")
+ opts = builder(builder.time_and_memory()
+ ).with_file_output(outfile).build()
+
+ x = lib.BuildFullModel()
+
+ profile_str = None
+ profile_step50 = os.path.join(test.get_temp_dir(), "profile_50")
+ with profile_context.ProfileContext(test.get_temp_dir()) as pctx:
+ pctx.add_auto_profiling("op", options=opts, profile_steps=[15, 50, 100])
+ with session.Session() as sess:
+ sess.run(variables.global_variables_initializer())
+ total_steps = 101 if test.is_gpu_available() else 50
+ for i in range(total_steps):
+ sess.run(x)
+ if i == 14 or i == 99:
+ self.assertTrue(gfile.Exists(outfile))
+ gfile.Remove(outfile)
+ if i == 49:
+ self.assertTrue(gfile.Exists(profile_step50))
+ with gfile.Open(outfile, "r") as f:
+ profile_str = f.read()
+ gfile.Remove(outfile)
+
+ with lib.ProfilerFromFile(
+ os.path.join(test.get_temp_dir(), "profile_50")) as profiler:
+ profiler.profile_operations(options=opts)
+ with gfile.Open(outfile, "r") as f:
+ self.assertEqual(profile_str, f.read())
+
+
+if __name__ == "__main__":
+ test.main()
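
For completeness, the serialize/deserialize round trip exercised by the new tests can be sketched as follows. `ProfilerFromFile` is the internal, no_pip test helper added to model_analyzer_testlib.py in this change; the temp directory, the toy graph, and the `profile_1` file name (step 1, since the initializer run is step 0) are illustrative and follow the same convention the tests rely on.

```python
# Sketch of the dump/restore round trip enabled by this change, using the
# internal test helper added above. Paths and the graph are illustrative.
import os
import tensorflow as tf

from tensorflow.python.profiler import option_builder
from tensorflow.python.profiler import profile_context
from tensorflow.python.profiler.internal import model_analyzer_testlib as lib

profile_dir = '/tmp/profiler_round_trip'   # illustrative dump directory
builder = option_builder.ProfileOptionBuilder
opts = builder(builder.time_and_memory()).build()

x = tf.Variable(tf.random_normal([128, 128]))
y = tf.matmul(x, x)

# Trace one step explicitly and dump the unified profile file for it.
with profile_context.ProfileContext(profile_dir,
                                    trace_steps=[], dump_steps=[]) as pctx:
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # session step 0
    pctx.trace_next_step()    # capture RunMetadata for the next run
    pctx.dump_next_step()     # serialize graph/run_meta/op_log afterwards
    sess.run(y)               # session step 1 -> dumps profile_1
    pctx.profiler.profile_name_scope(options=opts)

# Restore a Profiler from the dumped file and query it again; the output
# should match the in-process profile above.
with lib.ProfilerFromFile(os.path.join(profile_dir, 'profile_1')) as profiler:
  profiler.profile_name_scope(options=opts)
```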