diff options
Diffstat (limited to 'tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py')
-rw-r--r-- | tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py | 121 |
1 files changed, 97 insertions, 24 deletions
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py index 419beac0b9..c781d2af4e 100644 --- a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py +++ b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py @@ -20,6 +20,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import six + from tensorflow.contrib.tfprof.python.tools.tfprof import tfprof_logger from tensorflow.contrib.tfprof.python.tools.tfprof.internal import pywrap_tensorflow_print_model_analysis_lib as print_mdl from tensorflow.python.framework import errors @@ -108,49 +110,77 @@ PRINT_ALL_TIMING_MEMORY = { 'dump_to_file': '' } +# The following options are for 'advise' tfprof_cmd. +# Show all advice. +ALL_ADVICE = { + 'ExpensiveOperationChecker': {}, + 'AcceleratorUtilizationChecker': {}, + 'JobChecker': {}, # Only available internally. + 'OperationChecker': {}, +} + # pylint: enable=bad-whitespace # pylint: enable=bad-continuation -def _build_options(tfprof_options): +def _build_options(options): """Build tfprof.OptionsProto. Args: - tfprof_options: A dictionary of options. + options: A dictionary of options. Returns: tfprof.OptionsProto. """ opts = tfprof_options_pb2.OptionsProto() - opts.max_depth = tfprof_options.get('max_depth', 10) - opts.min_bytes = tfprof_options.get('min_bytes', 0) - opts.min_micros = tfprof_options.get('min_micros', 0) - opts.min_params = tfprof_options.get('min_params', 0) - opts.min_float_ops = tfprof_options.get('min_float_ops', 0) - opts.min_occurrence = tfprof_options.get('min_occurrence', 0) + opts.max_depth = options.get('max_depth', 10) + opts.min_bytes = options.get('min_bytes', 0) + opts.min_micros = options.get('min_micros', 0) + opts.min_params = options.get('min_params', 0) + opts.min_float_ops = options.get('min_float_ops', 0) + opts.min_occurrence = options.get('min_occurrence', 0) - opts.step = tfprof_options.get('step', -1) + opts.step = options.get('step', -1) - opts.order_by = tfprof_options.get('order_by', 'name') + opts.order_by = options.get('order_by', 'name') - for p in tfprof_options.get('account_type_regexes', []): + for p in options.get('account_type_regexes', []): opts.account_type_regexes.append(p) - for p in tfprof_options.get('start_name_regexes', []): + for p in options.get('start_name_regexes', []): opts.start_name_regexes.append(p) - for p in tfprof_options.get('trim_name_regexes', []): + for p in options.get('trim_name_regexes', []): opts.trim_name_regexes.append(p) - for p in tfprof_options.get('show_name_regexes', []): + for p in options.get('show_name_regexes', []): opts.show_name_regexes.append(p) - for p in tfprof_options.get('hide_name_regexes', []): + for p in options.get('hide_name_regexes', []): opts.hide_name_regexes.append(p) - opts.account_displayed_op_only = tfprof_options.get( - 'account_displayed_op_only', False) + opts.account_displayed_op_only = options.get('account_displayed_op_only', + False) - for p in tfprof_options.get('select', []): + for p in options.get('select', []): opts.select.append(p) - opts.output = tfprof_options.get('output', 'stdout') - opts.dump_to_file = tfprof_options.get('dump_to_file', '') + opts.output = options.get('output', 'stdout') + opts.dump_to_file = options.get('dump_to_file', '') + + return opts + + +def _build_advisor_options(options): + """Build tfprof.AdvisorOptionsProto. + Args: + options: A dictionary of options. See ALL_ADVICE example. + Returns: + tfprof.AdvisorOptionsProto. + """ + opts = tfprof_options_pb2.AdvisorOptionsProto() + if options is None: + return opts + for checker, checker_opts in six.iteritems(options): + checker_ops_pb = tfprof_options_pb2.AdvisorOptionsProto.CheckerOption() + for k, v in six.iteritems(checker_opts): + checker_ops_pb[k] = v + opts.checkers[checker].MergeFrom(checker_ops_pb) return opts @@ -190,7 +220,7 @@ class Profiler(object): else: _ = sess.run(...) # Auto detect problems and generate advice. - profiler.advise() + profiler.advise(model_analyzer.ALL_ADVICE) """ def __init__(self, graph, op_log=None): @@ -288,9 +318,19 @@ class Profiler(object): print_mdl.Profile('graph'.encode('utf-8'), opts.SerializeToString())) return tfprof_node - def advise(self): - """Automatically detect problems and generate reports.""" - print_mdl.Advise() + def advise(self, options=ALL_ADVICE): # pylint: disable=dangerous-default-value + """Automatically detect problems and generate reports. + + Args: + options: A dict of options. + Returns: + A Advise proto that conains the reports from all checkers. + """ + advise_pb = tfprof_output_pb2.AdviceProto() + opts = _build_advisor_options(options) + advise_pb.ParseFromString( + print_mdl.Profile('advise'.encode('utf-8'), opts.SerializeToString())) + return advise_pb def print_model_analysis(graph, @@ -354,3 +394,36 @@ def print_model_analysis(graph, None, None, 'unknown tfprof_cmd: %s\n' % tfprof_cmd) return tfprof_node + + +def advise(graph, run_meta=None, tfprof_options=ALL_ADVICE): # pylint: disable=dangerous-default-value + """Auto profile and advise. + + Builds profiles and automatically check anormalies of various + aspects. See go/tfprof or README for examples and tutorials. + + Args: + graph: tf.Graph. + run_meta: tensorflow::RunMetadata proto. Allows auto-profile + time and memroy. + tfprof_options: see ALL_ADVICE example above. + Returns: + Returns AdviceProto proto + """ + # pylint: disable=protected-access + op_log = tfprof_logger._merge_default_with_oplog( + graph, None, run_meta, add_trace=True) + # pylint: enable=protected-access + + run_meta_str = run_meta.SerializeToString() if run_meta else b'' + + opts = _build_advisor_options(tfprof_options) + ret = tfprof_output_pb2.AdviceProto() + ret.ParseFromString( + print_mdl.PrintModelAnalysis( + graph.as_graph_def(add_shapes=True).SerializeToString(), + run_meta_str, + op_log.SerializeToString(), + 'advise'.encode('utf-8'), + opts.SerializeToString())) + return ret |