aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py')
-rw-r--r--tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py121
1 files changed, 97 insertions, 24 deletions
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py
index 419beac0b9..c781d2af4e 100644
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py
@@ -20,6 +20,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import six
+
from tensorflow.contrib.tfprof.python.tools.tfprof import tfprof_logger
from tensorflow.contrib.tfprof.python.tools.tfprof.internal import pywrap_tensorflow_print_model_analysis_lib as print_mdl
from tensorflow.python.framework import errors
@@ -108,49 +110,77 @@ PRINT_ALL_TIMING_MEMORY = {
'dump_to_file': ''
}
+# The following options are for 'advise' tfprof_cmd.
+# Show all advice.
+ALL_ADVICE = {
+ 'ExpensiveOperationChecker': {},
+ 'AcceleratorUtilizationChecker': {},
+ 'JobChecker': {}, # Only available internally.
+ 'OperationChecker': {},
+}
+
# pylint: enable=bad-whitespace
# pylint: enable=bad-continuation
-def _build_options(tfprof_options):
+def _build_options(options):
"""Build tfprof.OptionsProto.
Args:
- tfprof_options: A dictionary of options.
+ options: A dictionary of options.
Returns:
tfprof.OptionsProto.
"""
opts = tfprof_options_pb2.OptionsProto()
- opts.max_depth = tfprof_options.get('max_depth', 10)
- opts.min_bytes = tfprof_options.get('min_bytes', 0)
- opts.min_micros = tfprof_options.get('min_micros', 0)
- opts.min_params = tfprof_options.get('min_params', 0)
- opts.min_float_ops = tfprof_options.get('min_float_ops', 0)
- opts.min_occurrence = tfprof_options.get('min_occurrence', 0)
+ opts.max_depth = options.get('max_depth', 10)
+ opts.min_bytes = options.get('min_bytes', 0)
+ opts.min_micros = options.get('min_micros', 0)
+ opts.min_params = options.get('min_params', 0)
+ opts.min_float_ops = options.get('min_float_ops', 0)
+ opts.min_occurrence = options.get('min_occurrence', 0)
- opts.step = tfprof_options.get('step', -1)
+ opts.step = options.get('step', -1)
- opts.order_by = tfprof_options.get('order_by', 'name')
+ opts.order_by = options.get('order_by', 'name')
- for p in tfprof_options.get('account_type_regexes', []):
+ for p in options.get('account_type_regexes', []):
opts.account_type_regexes.append(p)
- for p in tfprof_options.get('start_name_regexes', []):
+ for p in options.get('start_name_regexes', []):
opts.start_name_regexes.append(p)
- for p in tfprof_options.get('trim_name_regexes', []):
+ for p in options.get('trim_name_regexes', []):
opts.trim_name_regexes.append(p)
- for p in tfprof_options.get('show_name_regexes', []):
+ for p in options.get('show_name_regexes', []):
opts.show_name_regexes.append(p)
- for p in tfprof_options.get('hide_name_regexes', []):
+ for p in options.get('hide_name_regexes', []):
opts.hide_name_regexes.append(p)
- opts.account_displayed_op_only = tfprof_options.get(
- 'account_displayed_op_only', False)
+ opts.account_displayed_op_only = options.get('account_displayed_op_only',
+ False)
- for p in tfprof_options.get('select', []):
+ for p in options.get('select', []):
opts.select.append(p)
- opts.output = tfprof_options.get('output', 'stdout')
- opts.dump_to_file = tfprof_options.get('dump_to_file', '')
+ opts.output = options.get('output', 'stdout')
+ opts.dump_to_file = options.get('dump_to_file', '')
+
+ return opts
+
+
+def _build_advisor_options(options):
+ """Build tfprof.AdvisorOptionsProto.
+ Args:
+ options: A dictionary of options. See ALL_ADVICE example.
+ Returns:
+ tfprof.AdvisorOptionsProto.
+ """
+ opts = tfprof_options_pb2.AdvisorOptionsProto()
+ if options is None:
+ return opts
+ for checker, checker_opts in six.iteritems(options):
+ checker_ops_pb = tfprof_options_pb2.AdvisorOptionsProto.CheckerOption()
+ for k, v in six.iteritems(checker_opts):
+ checker_ops_pb[k] = v
+ opts.checkers[checker].MergeFrom(checker_ops_pb)
return opts
@@ -190,7 +220,7 @@ class Profiler(object):
else:
_ = sess.run(...)
# Auto detect problems and generate advice.
- profiler.advise()
+ profiler.advise(model_analyzer.ALL_ADVICE)
"""
def __init__(self, graph, op_log=None):
@@ -288,9 +318,19 @@ class Profiler(object):
print_mdl.Profile('graph'.encode('utf-8'), opts.SerializeToString()))
return tfprof_node
- def advise(self):
- """Automatically detect problems and generate reports."""
- print_mdl.Advise()
+ def advise(self, options=ALL_ADVICE): # pylint: disable=dangerous-default-value
+ """Automatically detect problems and generate reports.
+
+ Args:
+ options: A dict of options.
+ Returns:
+ A Advise proto that conains the reports from all checkers.
+ """
+ advise_pb = tfprof_output_pb2.AdviceProto()
+ opts = _build_advisor_options(options)
+ advise_pb.ParseFromString(
+ print_mdl.Profile('advise'.encode('utf-8'), opts.SerializeToString()))
+ return advise_pb
def print_model_analysis(graph,
@@ -354,3 +394,36 @@ def print_model_analysis(graph,
None, None, 'unknown tfprof_cmd: %s\n' % tfprof_cmd)
return tfprof_node
+
+
+def advise(graph, run_meta=None, tfprof_options=ALL_ADVICE): # pylint: disable=dangerous-default-value
+ """Auto profile and advise.
+
+ Builds profiles and automatically check anormalies of various
+ aspects. See go/tfprof or README for examples and tutorials.
+
+ Args:
+ graph: tf.Graph.
+ run_meta: tensorflow::RunMetadata proto. Allows auto-profile
+ time and memroy.
+ tfprof_options: see ALL_ADVICE example above.
+ Returns:
+ Returns AdviceProto proto
+ """
+ # pylint: disable=protected-access
+ op_log = tfprof_logger._merge_default_with_oplog(
+ graph, None, run_meta, add_trace=True)
+ # pylint: enable=protected-access
+
+ run_meta_str = run_meta.SerializeToString() if run_meta else b''
+
+ opts = _build_advisor_options(tfprof_options)
+ ret = tfprof_output_pb2.AdviceProto()
+ ret.ParseFromString(
+ print_mdl.PrintModelAnalysis(
+ graph.as_graph_def(add_shapes=True).SerializeToString(),
+ run_meta_str,
+ op_log.SerializeToString(),
+ 'advise'.encode('utf-8'),
+ opts.SerializeToString()))
+ return ret