aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/tfprof/python/tools/tfprof/internal/pywrap_tensorflow_print_model_analysis.i1
-rw-r--r--tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py121
-rw-r--r--tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py30
-rw-r--r--tensorflow/contrib/tfprof/python/tools/tfprof/profiler_test.py16
-rw-r--r--tensorflow/tools/tfprof/BUILD1
-rw-r--r--tensorflow/tools/tfprof/README.md6
-rw-r--r--tensorflow/tools/tfprof/g3doc/advise.md80
-rw-r--r--tensorflow/tools/tfprof/internal/advisor/BUILD9
-rw-r--r--tensorflow/tools/tfprof/internal/advisor/accelerator_utilization_checker.h24
-rw-r--r--tensorflow/tools/tfprof/internal/advisor/checker.h24
-rw-r--r--tensorflow/tools/tfprof/internal/advisor/expensive_operation_checker.h138
-rw-r--r--tensorflow/tools/tfprof/internal/advisor/internal_checker_runner.h5
-rw-r--r--tensorflow/tools/tfprof/internal/advisor/internal_checker_runner_dummy.cc6
-rw-r--r--tensorflow/tools/tfprof/internal/advisor/operation_checker.h19
-rw-r--r--tensorflow/tools/tfprof/internal/advisor/tfprof_advisor.h47
-rw-r--r--tensorflow/tools/tfprof/internal/advisor/tfprof_advisor_test.cc46
-rw-r--r--tensorflow/tools/tfprof/internal/print_model_analysis.cc19
-rw-r--r--tensorflow/tools/tfprof/internal/print_model_analysis.h2
-rw-r--r--tensorflow/tools/tfprof/internal/tfprof_graph.h4
-rw-r--r--tensorflow/tools/tfprof/internal/tfprof_op.cc11
-rw-r--r--tensorflow/tools/tfprof/internal/tfprof_op.h6
-rw-r--r--tensorflow/tools/tfprof/internal/tfprof_options.h4
-rw-r--r--tensorflow/tools/tfprof/internal/tfprof_show.cc14
-rw-r--r--tensorflow/tools/tfprof/internal/tfprof_show.h13
-rw-r--r--tensorflow/tools/tfprof/internal/tfprof_show_multi.cc19
-rw-r--r--tensorflow/tools/tfprof/internal/tfprof_show_multi.h16
-rw-r--r--tensorflow/tools/tfprof/internal/tfprof_show_test.cc1
-rw-r--r--tensorflow/tools/tfprof/internal/tfprof_stats.cc77
-rw-r--r--tensorflow/tools/tfprof/internal/tfprof_stats.h22
-rw-r--r--tensorflow/tools/tfprof/internal/tfprof_stats_test.cc1
-rw-r--r--tensorflow/tools/tfprof/internal/tfprof_tensor_test.cc1
-rw-r--r--tensorflow/tools/tfprof/internal/tfprof_timeline_test.cc1
-rw-r--r--tensorflow/tools/tfprof/tfprof_main.cc25
-rw-r--r--tensorflow/tools/tfprof/tfprof_options.proto10
-rw-r--r--tensorflow/tools/tfprof/tfprof_output.proto10
35 files changed, 646 insertions, 183 deletions
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/internal/pywrap_tensorflow_print_model_analysis.i b/tensorflow/contrib/tfprof/python/tools/tfprof/internal/pywrap_tensorflow_print_model_analysis.i
index 40f29ae8a2..582c36e339 100644
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/internal/pywrap_tensorflow_print_model_analysis.i
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/internal/pywrap_tensorflow_print_model_analysis.i
@@ -43,7 +43,6 @@ using tensorflow::int64;
%unignore tensorflow::tfprof::DeleteProfiler;
%unignore tensorflow::tfprof::AddStep;
%unignore tensorflow::tfprof::Profile;
-%unignore tensorflow::tfprof::Advise;
%include "tensorflow/tools/tfprof/internal/print_model_analysis.h"
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py
index 419beac0b9..c781d2af4e 100644
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py
@@ -20,6 +20,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import six
+
from tensorflow.contrib.tfprof.python.tools.tfprof import tfprof_logger
from tensorflow.contrib.tfprof.python.tools.tfprof.internal import pywrap_tensorflow_print_model_analysis_lib as print_mdl
from tensorflow.python.framework import errors
@@ -108,49 +110,77 @@ PRINT_ALL_TIMING_MEMORY = {
'dump_to_file': ''
}
+# The following options are for 'advise' tfprof_cmd.
+# Show all advice.
+ALL_ADVICE = {
+ 'ExpensiveOperationChecker': {},
+ 'AcceleratorUtilizationChecker': {},
+ 'JobChecker': {}, # Only available internally.
+ 'OperationChecker': {},
+}
+
# pylint: enable=bad-whitespace
# pylint: enable=bad-continuation
-def _build_options(tfprof_options):
+def _build_options(options):
"""Build tfprof.OptionsProto.
Args:
- tfprof_options: A dictionary of options.
+ options: A dictionary of options.
Returns:
tfprof.OptionsProto.
"""
opts = tfprof_options_pb2.OptionsProto()
- opts.max_depth = tfprof_options.get('max_depth', 10)
- opts.min_bytes = tfprof_options.get('min_bytes', 0)
- opts.min_micros = tfprof_options.get('min_micros', 0)
- opts.min_params = tfprof_options.get('min_params', 0)
- opts.min_float_ops = tfprof_options.get('min_float_ops', 0)
- opts.min_occurrence = tfprof_options.get('min_occurrence', 0)
+ opts.max_depth = options.get('max_depth', 10)
+ opts.min_bytes = options.get('min_bytes', 0)
+ opts.min_micros = options.get('min_micros', 0)
+ opts.min_params = options.get('min_params', 0)
+ opts.min_float_ops = options.get('min_float_ops', 0)
+ opts.min_occurrence = options.get('min_occurrence', 0)
- opts.step = tfprof_options.get('step', -1)
+ opts.step = options.get('step', -1)
- opts.order_by = tfprof_options.get('order_by', 'name')
+ opts.order_by = options.get('order_by', 'name')
- for p in tfprof_options.get('account_type_regexes', []):
+ for p in options.get('account_type_regexes', []):
opts.account_type_regexes.append(p)
- for p in tfprof_options.get('start_name_regexes', []):
+ for p in options.get('start_name_regexes', []):
opts.start_name_regexes.append(p)
- for p in tfprof_options.get('trim_name_regexes', []):
+ for p in options.get('trim_name_regexes', []):
opts.trim_name_regexes.append(p)
- for p in tfprof_options.get('show_name_regexes', []):
+ for p in options.get('show_name_regexes', []):
opts.show_name_regexes.append(p)
- for p in tfprof_options.get('hide_name_regexes', []):
+ for p in options.get('hide_name_regexes', []):
opts.hide_name_regexes.append(p)
- opts.account_displayed_op_only = tfprof_options.get(
- 'account_displayed_op_only', False)
+ opts.account_displayed_op_only = options.get('account_displayed_op_only',
+ False)
- for p in tfprof_options.get('select', []):
+ for p in options.get('select', []):
opts.select.append(p)
- opts.output = tfprof_options.get('output', 'stdout')
- opts.dump_to_file = tfprof_options.get('dump_to_file', '')
+ opts.output = options.get('output', 'stdout')
+ opts.dump_to_file = options.get('dump_to_file', '')
+
+ return opts
+
+
+def _build_advisor_options(options):
+ """Build tfprof.AdvisorOptionsProto.
+ Args:
+ options: A dictionary of options. See ALL_ADVICE example.
+ Returns:
+ tfprof.AdvisorOptionsProto.
+ """
+ opts = tfprof_options_pb2.AdvisorOptionsProto()
+ if options is None:
+ return opts
+ for checker, checker_opts in six.iteritems(options):
+ checker_ops_pb = tfprof_options_pb2.AdvisorOptionsProto.CheckerOption()
+ for k, v in six.iteritems(checker_opts):
+ checker_ops_pb[k] = v
+ opts.checkers[checker].MergeFrom(checker_ops_pb)
return opts
@@ -190,7 +220,7 @@ class Profiler(object):
else:
_ = sess.run(...)
# Auto detect problems and generate advice.
- profiler.advise()
+ profiler.advise(model_analyzer.ALL_ADVICE)
"""
def __init__(self, graph, op_log=None):
@@ -288,9 +318,19 @@ class Profiler(object):
print_mdl.Profile('graph'.encode('utf-8'), opts.SerializeToString()))
return tfprof_node
- def advise(self):
- """Automatically detect problems and generate reports."""
- print_mdl.Advise()
+ def advise(self, options=ALL_ADVICE): # pylint: disable=dangerous-default-value
+ """Automatically detect problems and generate reports.
+
+ Args:
+ options: A dict of options.
+ Returns:
+ A Advise proto that conains the reports from all checkers.
+ """
+ advise_pb = tfprof_output_pb2.AdviceProto()
+ opts = _build_advisor_options(options)
+ advise_pb.ParseFromString(
+ print_mdl.Profile('advise'.encode('utf-8'), opts.SerializeToString()))
+ return advise_pb
def print_model_analysis(graph,
@@ -354,3 +394,36 @@ def print_model_analysis(graph,
None, None, 'unknown tfprof_cmd: %s\n' % tfprof_cmd)
return tfprof_node
+
+
+def advise(graph, run_meta=None, tfprof_options=ALL_ADVICE): # pylint: disable=dangerous-default-value
+ """Auto profile and advise.
+
+ Builds profiles and automatically check anormalies of various
+ aspects. See go/tfprof or README for examples and tutorials.
+
+ Args:
+ graph: tf.Graph.
+ run_meta: tensorflow::RunMetadata proto. Allows auto-profile
+ time and memroy.
+ tfprof_options: see ALL_ADVICE example above.
+ Returns:
+ Returns AdviceProto proto
+ """
+ # pylint: disable=protected-access
+ op_log = tfprof_logger._merge_default_with_oplog(
+ graph, None, run_meta, add_trace=True)
+ # pylint: enable=protected-access
+
+ run_meta_str = run_meta.SerializeToString() if run_meta else b''
+
+ opts = _build_advisor_options(tfprof_options)
+ ret = tfprof_output_pb2.AdviceProto()
+ ret.ParseFromString(
+ print_mdl.PrintModelAnalysis(
+ graph.as_graph_def(add_shapes=True).SerializeToString(),
+ run_meta_str,
+ op_log.SerializeToString(),
+ 'advise'.encode('utf-8'),
+ opts.SerializeToString()))
+ return ret
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py
index 1b5041441f..ec7540a657 100644
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py
@@ -126,7 +126,7 @@ class PrintModelAnalysisTest(test.TestCase):
opts['account_displayed_op_only'] = False
opts['select'] = ['params', 'float_ops']
- with session.Session() as sess, ops.device('/cpu:0'):
+ with session.Session() as sess:
x = lib.BuildFullModel()
sess.run(variables.global_variables_initializer())
@@ -176,6 +176,7 @@ class PrintModelAnalysisTest(test.TestCase):
opts['select'] = [
'bytes', 'params', 'float_ops', 'device'
]
+ opts['output'] = 'none'
with session.Session() as sess:
x = lib.BuildSmallModel()
@@ -276,6 +277,33 @@ class PrintModelAnalysisTest(test.TestCase):
self.assertEqual(total_children, 15)
self.assertGreater(input_shapes, 0)
+ def testAdvisor(self):
+ ops.reset_default_graph()
+
+ with session.Session() as sess:
+ x = lib.BuildFullModel()
+
+ sess.run(variables.global_variables_initializer())
+ run_meta = config_pb2.RunMetadata()
+ _ = sess.run(
+ x,
+ options=config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE),
+ run_metadata=run_meta)
+
+ advice_pb = model_analyzer.advise(sess.graph, run_meta)
+ self.assertTrue('AcceleratorUtilizationChecker' in advice_pb.checkers)
+ self.assertTrue('ExpensiveOperationChecker' in advice_pb.checkers)
+ self.assertTrue('OperationChecker' in advice_pb.checkers)
+
+ checker = advice_pb.checkers['AcceleratorUtilizationChecker']
+ if test.is_gpu_available():
+ self.assertGreater(len(checker.reports), 0)
+ else:
+ self.assertEqual(len(checker.reports), 0)
+ checker = advice_pb.checkers['ExpensiveOperationChecker']
+ self.assertGreater(len(checker.reports), 0)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/profiler_test.py b/tensorflow/contrib/tfprof/python/tools/tfprof/profiler_test.py
index 5daaafd7c8..c7113b6a57 100644
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/profiler_test.py
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/profiler_test.py
@@ -129,7 +129,7 @@ class ProfilerTest(test.TestCase):
opts = model_analyzer.PRINT_ALL_TIMING_MEMORY.copy()
opts['account_type_regexes'] = ['.*']
- with session.Session() as sess, ops.device('/cpu:0'):
+ with session.Session() as sess:
r1, r2, r3 = lib.BuildSplitableModel()
sess.run(variables.global_variables_initializer())
@@ -179,8 +179,18 @@ class ProfilerTest(test.TestCase):
self.assertEqual(lib.SearchTFProfNode(pb2, 'add'), None)
self.assertGreater(lib.SearchTFProfNode(pb3, 'add').exec_micros, 0)
- # TODO(xpan): Better test of advisor.
- profiler.advise()
+ advice_pb = profiler.advise(model_analyzer.ALL_ADVICE)
+ self.assertTrue('AcceleratorUtilizationChecker' in advice_pb.checkers)
+ self.assertTrue('ExpensiveOperationChecker' in advice_pb.checkers)
+ self.assertTrue('OperationChecker' in advice_pb.checkers)
+
+ checker = advice_pb.checkers['AcceleratorUtilizationChecker']
+ if test.is_gpu_available():
+ self.assertGreater(len(checker.reports), 0)
+ else:
+ self.assertEqual(len(checker.reports), 0)
+ checker = advice_pb.checkers['ExpensiveOperationChecker']
+ self.assertGreater(len(checker.reports), 0)
if __name__ == '__main__':
diff --git a/tensorflow/tools/tfprof/BUILD b/tensorflow/tools/tfprof/BUILD
index 57cccd8921..541df78e47 100644
--- a/tensorflow/tools/tfprof/BUILD
+++ b/tensorflow/tools/tfprof/BUILD
@@ -33,6 +33,7 @@ cc_binary(
"//tensorflow/tools/tfprof/internal:tfprof_options",
"//tensorflow/tools/tfprof/internal:tfprof_stats",
"//tensorflow/tools/tfprof/internal:tfprof_utils",
+ "//tensorflow/tools/tfprof/internal/advisor:tfprof_advisor",
"@linenoise//:linenoise",
],
)
diff --git a/tensorflow/tools/tfprof/README.md b/tensorflow/tools/tfprof/README.md
index 5927990524..816ad8c07e 100644
--- a/tensorflow/tools/tfprof/README.md
+++ b/tensorflow/tools/tfprof/README.md
@@ -7,7 +7,11 @@
* Profile model performance
* execution time, memory consumption
* Profile multiple steps.
-* Auto detect and advise. (Experimental)
+* Auto profile and advise.
+ * accelerator utilization check
+ * expensive operation check
+ * operation configuration check
+ * distributed runtime check (Not OSS)
###Interfaces
diff --git a/tensorflow/tools/tfprof/g3doc/advise.md b/tensorflow/tools/tfprof/g3doc/advise.md
index 3bce6270ff..e30add6fbf 100644
--- a/tensorflow/tools/tfprof/g3doc/advise.md
+++ b/tensorflow/tools/tfprof/g3doc/advise.md
@@ -3,6 +3,7 @@
tfprof analyzes profiles and generates advises for common issues.
### Run Advise.
+
```python
# First create a profiler. See profiler tutorials for more details.
profiler = model_analyzer.Profiler(sess.graph)
@@ -13,8 +14,63 @@ _ = sess.run(r1,
run_metadata=run_meta)
profiler.add_step(1, run_meta)
-# Start advise.
-profiler.advise()
+# Then Start advise.
+profiler.advise(model_analyzer.ALL_ADVICE)
+
+# For one-shot API
+tf.contrib.tfprof.model_analyzer.advise(
+ sess.graph, run_meta=run_metadata)
+```
+
+```shell
+# Run advisor on CLI
+# See CLI tutorial on generating the files.
+tfprof --graph_path=graph.pbtxt \
+ --run_meta_path=run_metadata \
+ --op_log_path=tfprof_log
+
+tfprof> advise
+AcceleratorUtilizationChecker:
+device: /job:worker/replica:0/task:0/gpu:0 low utilization: 0.03
+device: /job:worker/replica:0/task:0/gpu:1 low utilization: 0.08
+device: /job:worker/replica:0/task:0/gpu:2 low utilization: 0.04
+device: /job:worker/replica:0/task:0/gpu:3 low utilization: 0.21
+
+OperationChecker:
+Found operation using NHWC data_format on GPU. Maybe NCHW is faster.
+
+ExpensiveOperationChecker:
+top 1 operation type: SoftmaxCrossEntropyWithLogits, cpu: 1.37sec, accelerator: 0us, total: 1.37sec (26.68%)
+top 2 operation type: MatMul, cpu: 427.39ms, accelerator: 280.76ms, total: 708.14ms (13.83%)
+top 3 operation type: ConcatV2, cpu: 357.83ms, accelerator: 31.80ms, total: 389.63ms (7.61%)
+seq2seq_attention_model.py:360:build_graph:self._add_seq2seq(), cpu: 3.16sec, accelerator: 214.84ms, total: 3.37sec
+ seq2seq_attention_model.py:293:_add_seq2seq:decoder_outputs, ..., cpu: 2.46sec, accelerator: 3.25ms, total: 2.47sec
+ seq2seq_lib.py:181:sampled_sequence_...:average_across_ti..., cpu: 2.46sec, accelerator: 3.24ms, total: 2.47sec
+ seq2seq_lib.py:147:sequence_loss_by_...:crossent = loss_f..., cpu: 2.46sec, accelerator: 3.06ms, total: 2.46sec
+ seq2seq_attention_model.py:289:sampled_loss_func:num_classes=vsize), cpu: 2.46sec, accelerator: 3.06ms, total: 2.46sec
+ seq2seq_attention_model.py:282:sampled_loss_func:labels = tf.resha..., cpu: 164us, accelerator: 0us, total: 164us
+ seq2seq_lib.py:148:sequence_loss_by_...:log_perp_list.app..., cpu: 1.33ms, accelerator: 120us, total: 1.45ms
+ seq2seq_lib.py:151:sequence_loss_by_...:total_size = tf.a..., cpu: 154us, accelerator: 23us, total: 177us
+ seq2seq_lib.py:184:sampled_sequence_...:return cost / tf...., cpu: 97us, accelerator: 8us, total: 105us
+ math_ops.py:690:cast:return gen_math_o..., cpu: 62us, accelerator: 3us, total: 65us
+ math_ops.py:839:binary_op_wrapper:return func(x, y,..., cpu: 35us, accelerator: 5us, total: 40us
+ seq2seq_attention_model.py:192:_add_seq2seq:sequence_length=a..., cpu: 651.56ms, accelerator: 158.92ms, total: 810.48ms
+ seq2seq_lib.py:104:bidirectional_rnn:sequence_length, ..., cpu: 306.58ms, accelerator: 73.54ms, total: 380.12ms
+ core_rnn.py:195:static_rnn:state_size=cell.s..., cpu: 306.52ms, accelerator: 73.54ms, total: 380.05ms
+ rnn.py:218:_rnn_step:_maybe_copy_some_..., cpu: 303.76ms, accelerator: 73.54ms, total: 377.30ms
+ rnn.py:216:_rnn_step:time >= max_seque..., cpu: 2.75ms, accelerator: 0us, total: 2.75ms
+ core_rnn.py:179:static_rnn:max_sequence_leng..., cpu: 67us, accelerator: 0us, total: 67us
+ seq2seq_lib.py:110:bidirectional_rnn:initial_state_bw,..., cpu: 296.21ms, accelerator: 73.54ms, total: 369.75ms
+ core_rnn.py:195:static_rnn:state_size=cell.s..., cpu: 296.11ms, accelerator: 73.54ms, total: 369.65ms
+ rnn.py:218:_rnn_step:_maybe_copy_some_..., cpu: 292.04ms, accelerator: 73.54ms, total: 365.58ms
+ rnn.py:216:_rnn_step:time >= max_seque..., cpu: 4.07ms, accelerator: 0us, total: 4.07ms
+ core_rnn.py:178:static_rnn:min_sequence_leng..., cpu: 85us, accelerator: 0us, total: 85us
+ core_rnn.py:179:static_rnn:max_sequence_leng..., cpu: 16us, accelerator: 0us, total: 16us
+ seq2seq_lib.py:113:bidirectional_rnn:outputs = [tf.con..., cpu: 46.88ms, accelerator: 3.87ms, total: 50.75ms
+ ...(omitted)
+top 1 graph node: seq2seq/loss/sampled_sequence_loss/sequence_loss_by_example/SoftmaxCrossEntropyWithLogits_11, cpu: 89.92ms, accelerator: 0us, total: 89.92ms
+top 2 graph node: train_step/update_seq2seq/output_projection/w/ApplyAdam, cpu: 84.52ms, accelerator: 0us, total: 84.52ms
+top 3 graph node: seq2seq/loss/sampled_sequence_loss/sequence_loss_by_example/SoftmaxCrossEntropyWithLogits_19, cpu: 73.02ms, accelerator: 0us, total: 73.02ms
```
### Checker
@@ -25,16 +81,24 @@ area with the profile and report issues. A `Checker` is like a plugin.
For example:
-####JobChecker (Not Available OSS)
-* Checking RecvTensor RPC latency and bandwidth.
-* Checking CPU/Memory utilization of the job.
+#### JobChecker (Not Available OSS)
+
+* Checks RecvTensor RPC latency and bandwidth.
+* Checks CPU/Memory utilization of the job.
####AcceleratorUtilization Checker
* Checks what percentage of time the accelerator spends on computation.
-####Operation Checker
-* Check whether the operation runs with optimal options.
-* Checks if there is a better implementation to replace the current operation.
+#### OperationChecker
+
+* Checks whether the operation runs with optimal options.
+* Checks if there is a better implementation to replace the current operation.
+
+#### ExpensiveOperationChecker
+
+* Checks the most expensive operation type.
+* Checks the most expensive graph nodes.
+* Checks the most expensive graph-building Python codes.
####Contribute Your Checker
diff --git a/tensorflow/tools/tfprof/internal/advisor/BUILD b/tensorflow/tools/tfprof/internal/advisor/BUILD
index 30012fa7b1..d28443bba7 100644
--- a/tensorflow/tools/tfprof/internal/advisor/BUILD
+++ b/tensorflow/tools/tfprof/internal/advisor/BUILD
@@ -44,11 +44,20 @@ cc_library(
)
cc_library(
+ name = "expensive_operation_checker",
+ hdrs = ["expensive_operation_checker.h"],
+ deps = [
+ ":checker",
+ ],
+)
+
+cc_library(
name = "tfprof_advisor",
hdrs = ["tfprof_advisor.h"],
deps = [
":accelerator_utilization_checker",
":checker",
+ ":expensive_operation_checker",
":internal_checker_runner_dummy",
":operation_checker",
],
diff --git a/tensorflow/tools/tfprof/internal/advisor/accelerator_utilization_checker.h b/tensorflow/tools/tfprof/internal/advisor/accelerator_utilization_checker.h
index fb7f65d7dc..074b8e57b0 100644
--- a/tensorflow/tools/tfprof/internal/advisor/accelerator_utilization_checker.h
+++ b/tensorflow/tools/tfprof/internal/advisor/accelerator_utilization_checker.h
@@ -33,10 +33,11 @@ struct ExecStats {
class AcceleratorUtilizationChecker : public Checker {
public:
- string name() override { return "AcceleratorUtilizationChecker"; }
+ string name() const override { return kCheckers[0]; }
private:
- std::vector<string> Check(const TFStats* stats) override {
+ AdviceProto::Checker Check(const AdvisorOptionsProto::CheckerOption& options,
+ const TFStats* stats) override {
if (!stats) {
fprintf(stderr, "Missing profiles (e.g. graph, run_meta). Skip %s\n",
name().c_str());
@@ -48,24 +49,21 @@ class AcceleratorUtilizationChecker : public Checker {
return CheckInternal();
}
- std::vector<string> CheckInternal() {
+ AdviceProto::Checker CheckInternal() {
for (const auto& s : accelerator_exec_stats_) {
const ExecStats& stat = s.second;
int64 total_micros = stat.end_micros - stat.start_micros;
if (total_micros <= 0) continue;
double utilization = 1.0 * stat.exec_micros / total_micros;
if (utilization >= 0.5) {
- reports_.push_back(strings::Printf("%s: device: %s utilization: %.2f",
- kLevel[0], s.first.c_str(),
- utilization));
+ reports_.add_reports(strings::Printf("device: %s utilization: %.2f",
+ s.first.c_str(), utilization));
} else if (utilization < 0.5 && utilization > 0.2) {
- reports_.push_back(
- strings::Printf("%s: device: %s low utilization: %.2f", kLevel[1],
- s.first.c_str(), utilization));
+ reports_.add_reports(strings::Printf("device: %s low utilization: %.2f",
+ s.first.c_str(), utilization));
} else if (utilization <= 0.2) {
- reports_.push_back(
- strings::Printf("%s: device: %s low utilization: %.2f", kLevel[2],
- s.first.c_str(), utilization));
+ reports_.add_reports(strings::Printf("device: %s low utilization: %.2f",
+ s.first.c_str(), utilization));
}
}
return reports_;
@@ -102,7 +100,7 @@ class AcceleratorUtilizationChecker : public Checker {
std::map<string, ExecStats> accelerator_exec_stats_;
std::map<string, int64> ps_placement_;
- std::vector<string> reports_;
+ AdviceProto::Checker reports_;
};
} // namespace tfprof
diff --git a/tensorflow/tools/tfprof/internal/advisor/checker.h b/tensorflow/tools/tfprof/internal/advisor/checker.h
index b8b057be5b..3ce80cd8c4 100644
--- a/tensorflow/tools/tfprof/internal/advisor/checker.h
+++ b/tensorflow/tools/tfprof/internal/advisor/checker.h
@@ -18,27 +18,33 @@ limitations under the License.
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/tools/tfprof/internal/tfprof_stats.h"
+#include "tensorflow/tools/tfprof/tfprof_options.pb.h"
namespace tensorflow {
namespace tfprof {
-static const char* const kLevel[] = {
- "NOTE", // Good to know.
- "SUGGEST", // Might get better.
- "WARN", // Please do it for better.
+// Append only.
+static const char* const kCheckers[] = {
+ "AcceleratorUtilizationChecker", "OperationChecker",
+ "ExpensiveOperationChecker",
+ "JobChecker", // Internal checker.
};
class Checker {
public:
- virtual ~Checker(){};
+ virtual ~Checker() {}
- virtual string name() = 0;
+ virtual string name() const = 0;
- std::vector<string> Run(const TFStats* stats) { return Check(stats); }
+ AdviceProto::Checker Run(const AdvisorOptionsProto::CheckerOption& options,
+ const TFStats* stats) {
+ return Check(options, stats);
+ }
protected:
- // Returns a vector of string, each one being an advice.
- virtual std::vector<string> Check(const TFStats* stats) = 0;
+ virtual AdviceProto::Checker Check(
+ const AdvisorOptionsProto::CheckerOption& options,
+ const TFStats* stats) = 0;
};
} // namespace tfprof
} // namespace tensorflow
diff --git a/tensorflow/tools/tfprof/internal/advisor/expensive_operation_checker.h b/tensorflow/tools/tfprof/internal/advisor/expensive_operation_checker.h
new file mode 100644
index 0000000000..14e30edd87
--- /dev/null
+++ b/tensorflow/tools/tfprof/internal/advisor/expensive_operation_checker.h
@@ -0,0 +1,138 @@
+/* Copyright 2016 The TensorFlow Authors All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// This checker checks the most expensive operations.
+#ifndef THIRD_PARTY_TENSORFLOW_TOOLS_TFPROF_INTERNAL_ADVISOR_EXPENSIVE_OPERATION_CHECKER_H_
+#define THIRD_PARTY_TENSORFLOW_TOOLS_TFPROF_INTERNAL_ADVISOR_EXPENSIVE_OPERATION_CHECKER_H_
+
+#include "tensorflow/tools/tfprof/internal/advisor/checker.h"
+
+namespace tensorflow {
+namespace tfprof {
+
+class ExpensiveOperationChecker : public Checker {
+ public:
+ string name() const override { return kCheckers[2]; }
+
+ private:
+ AdviceProto::Checker Check(const AdvisorOptionsProto::CheckerOption& options,
+ const TFStats* stats) override {
+ if (!stats) {
+ fprintf(stderr, "Missing profiles (e.g. graph, run_meta). Skip %s\n",
+ name().c_str());
+ return reports_;
+ }
+ if (stats->steps().empty()) {
+ fprintf(stderr, "Missing RunMetadata info. Skip %s\n", name().c_str());
+ }
+ CheckOpView(stats);
+ CheckCodeView(stats);
+ CheckScopeView(stats);
+ return reports_;
+ }
+
+ void CheckOpView(const TFStats* stats) {
+ if (stats->steps().empty()) {
+ fprintf(stderr, "Missing run_meta for %s\n", name().c_str());
+ return;
+ }
+ Options opts(3, 0, 1, 0, 0, 0, -1, "micros", {".*"}, {".*"}, {}, {".*"}, {},
+ false, {"micros", "occurrence"}, "none", {});
+ const TFMultiGraphNodeProto root = stats->ShowMultiGraphNode("op", opts);
+ if (root.children_size() == 0) {
+ return;
+ }
+ const TFMultiGraphNodeProto* node = &root;
+ std::vector<string> outputs;
+ for (int i = 0; i < 3 && node->children_size() > 0; ++i) {
+ node = &node->children(0);
+ outputs.push_back(strings::Printf(
+ "top %d operation type: %s, "
+ "cpu: %s, accelerator: %s, total: %s (%.2f%%)",
+ i + 1, node->name().c_str(),
+ FormatTime(node->cpu_exec_micros()).c_str(),
+ FormatTime(node->accelerator_exec_micros()).c_str(),
+ FormatTime(node->exec_micros()).c_str(),
+ 100.0 * node->exec_micros() / (root.total_exec_micros() + 1e-10)));
+ }
+ reports_.add_reports(str_util::Join(outputs, "\n"));
+ }
+
+ void CheckCodeView(const TFStats* stats) {
+ if (!stats->has_code_traces()) {
+ fprintf(stderr, "Missing op_log (code traces) for %s\n", name().c_str());
+ return;
+ }
+ Options opts(100, 0, 1, 0, 0, 0, -1, "micros", {".*"}, {".*"}, {}, {".*"},
+ {}, false, {"micros"}, "none", {});
+ const TFMultiGraphNodeProto root = stats->ShowMultiGraphNode("code", opts);
+ const TFMultiGraphNodeProto* node = &root;
+ // A trick here is: Usually, codes in library file are usually referenced
+ // only once, while user's own code are referenced multiple times.
+ while (node->children_size() == 1) {
+ node = &node->children(0);
+ }
+ if (node->children_size() == 0) {
+ return;
+ }
+
+ std::vector<string> outputs;
+ CodeViewHelper(node, 0, &outputs);
+ reports_.add_reports(str_util::Join(outputs, "\n"));
+ }
+
+ void CheckScopeView(const TFStats* stats) {
+ Options opts(100, 0, 100, 0, 0, 0, -1, "micros", {".*"}, {".*"}, {}, {".*"},
+ {}, false, {"micros"}, "none", {});
+ const TFGraphNodeProto root = stats->ShowGraphNode("scope", opts);
+ if (root.children_size() == 0) {
+ return;
+ }
+ std::vector<string> outputs;
+ const TFGraphNodeProto* node = &root;
+ for (int i = 0; i < 3 && i < root.children_size(); ++i) {
+ const TFGraphNodeProto& node = root.children(i);
+ outputs.push_back(strings::Printf(
+ "top %d graph node: %s, cpu: %s, accelerator: %s, total: %s", i + 1,
+ node.name().c_str(), FormatTime(node.cpu_exec_micros()).c_str(),
+ FormatTime(node.accelerator_exec_micros()).c_str(),
+ FormatTime(node.exec_micros()).c_str()));
+ }
+ reports_.add_reports(str_util::Join(outputs, "\n"));
+ }
+
+ void CodeViewHelper(const TFMultiGraphNodeProto* node, int depth,
+ std::vector<string>* outputs) {
+ if (node->children_size() <= 1 || depth > 4) {
+ return;
+ }
+ for (int j = 0; j < 3 && j < node->children_size(); ++j) {
+ const TFMultiGraphNodeProto* c = &node->children(j);
+ outputs->push_back(strings::Printf(
+ "%s%s, cpu: %s, accelerator: %s, total: %s",
+ string(depth * 2, ' ').c_str(), c->name().c_str(),
+ FormatTime(c->total_cpu_exec_micros()).c_str(),
+ FormatTime(c->total_accelerator_exec_micros()).c_str(),
+ FormatTime(c->total_exec_micros()).c_str()));
+ CodeViewHelper(c, depth + 1, outputs);
+ }
+ }
+
+ AdviceProto::Checker reports_;
+};
+
+} // namespace tfprof
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_TOOLS_TFPROF_INTERNAL_ADVISOR_EXPENSIVE_OP_CHECKER_H_
diff --git a/tensorflow/tools/tfprof/internal/advisor/internal_checker_runner.h b/tensorflow/tools/tfprof/internal/advisor/internal_checker_runner.h
index 1238b57f20..ed8ae571b6 100644
--- a/tensorflow/tools/tfprof/internal/advisor/internal_checker_runner.h
+++ b/tensorflow/tools/tfprof/internal/advisor/internal_checker_runner.h
@@ -17,13 +17,16 @@ limitations under the License.
#define THIRD_PARTY_TENSORFLOW_TOOLS_TFPROF_INTERNAL_ADVISOR_INTERNAL_CHECKER_RUNNER_H_
#include "tensorflow/tools/tfprof/internal/tfprof_utils.h"
+#include "tensorflow/tools/tfprof/tfprof_options.pb.h"
+#include "tensorflow/tools/tfprof/tfprof_output.pb.h"
namespace tensorflow {
namespace tfprof {
class TFStats;
-std::map<string, std::vector<string>> RunInternalCheckers(const TFStats* stats);
+AdviceProto RunInternalCheckers(const AdvisorOptionsProto& options,
+ const TFStats* stats);
} // namespace tfprof
} // namespace tensorflow
diff --git a/tensorflow/tools/tfprof/internal/advisor/internal_checker_runner_dummy.cc b/tensorflow/tools/tfprof/internal/advisor/internal_checker_runner_dummy.cc
index 8204d2b04e..67962c8e8b 100644
--- a/tensorflow/tools/tfprof/internal/advisor/internal_checker_runner_dummy.cc
+++ b/tensorflow/tools/tfprof/internal/advisor/internal_checker_runner_dummy.cc
@@ -17,9 +17,9 @@ limitations under the License.
namespace tensorflow {
namespace tfprof {
-std::map<string, std::vector<string>> RunInternalCheckers(
- const TFStats* stats) {
- return std::map<string, std::vector<string>>();
+AdviceProto RunInternalCheckers(const AdvisorOptionsProto& options,
+ const TFStats* stats) {
+ return AdviceProto();
}
} // namespace tfprof
diff --git a/tensorflow/tools/tfprof/internal/advisor/operation_checker.h b/tensorflow/tools/tfprof/internal/advisor/operation_checker.h
index 2a05f9bfd0..4d0d68e3bf 100644
--- a/tensorflow/tools/tfprof/internal/advisor/operation_checker.h
+++ b/tensorflow/tools/tfprof/internal/advisor/operation_checker.h
@@ -24,10 +24,11 @@ namespace tfprof {
class OperationChecker : public Checker {
public:
- string name() override { return "OperationChecker"; }
+ string name() const override { return kCheckers[1]; }
private:
- std::vector<string> Check(const TFStats* stats) override {
+ AdviceProto::Checker Check(const AdvisorOptionsProto::CheckerOption& options,
+ const TFStats* stats) override {
if (!stats) {
fprintf(stderr, "Missing profiles (e.g. graph, run_meta). Skip %s\n",
name().c_str());
@@ -53,22 +54,20 @@ class OperationChecker : public Checker {
}
}
if (use_batch_norm && !use_fused_batch_norm) {
- reports_.push_back(strings::Printf(
- "%s: Maybe use faster FusedBatchNorm instead of BatchNorm",
- kLevel[1]));
+ reports_.add_reports(
+ "Maybe use faster FusedBatchNorm instead of BatchNorm");
}
if (recommend_nchw) {
// TODO(xpan): Maybe print which Op supports NCHW.
- reports_.push_back(strings::Printf(
- "%s: Found operation using NHWC data_format on GPU. Maybe "
- "NCHW is faster.",
- kLevel[1]));
+ reports_.add_reports(
+ "Found operation using NHWC data_format on GPU. Maybe "
+ "NCHW is faster.");
}
return reports_;
}
private:
- std::vector<string> reports_;
+ AdviceProto::Checker reports_;
};
} // namespace tfprof
diff --git a/tensorflow/tools/tfprof/internal/advisor/tfprof_advisor.h b/tensorflow/tools/tfprof/internal/advisor/tfprof_advisor.h
index 856f515459..d2257fb9b5 100644
--- a/tensorflow/tools/tfprof/internal/advisor/tfprof_advisor.h
+++ b/tensorflow/tools/tfprof/internal/advisor/tfprof_advisor.h
@@ -18,8 +18,10 @@ limitations under the License.
#include "tensorflow/tools/tfprof/internal/advisor/accelerator_utilization_checker.h"
#include "tensorflow/tools/tfprof/internal/advisor/checker.h"
+#include "tensorflow/tools/tfprof/internal/advisor/expensive_operation_checker.h"
#include "tensorflow/tools/tfprof/internal/advisor/internal_checker_runner.h"
#include "tensorflow/tools/tfprof/internal/advisor/operation_checker.h"
+#include "tensorflow/tools/tfprof/tfprof_options.pb.h"
namespace tensorflow {
namespace tfprof {
@@ -29,23 +31,44 @@ class Advisor {
public:
Advisor(const TFStats* stats) : stats_(stats) {}
- std::map<string, std::vector<string>> Advise() {
+ static AdvisorOptionsProto DefaultOptions() {
+ AdvisorOptionsProto options;
+ std::vector<string> checkers(
+ kCheckers, kCheckers + sizeof(kCheckers) / sizeof(*kCheckers));
+ for (const string& checker : checkers) {
+ (*options.mutable_checkers())[checker];
+ }
+ return options;
+ }
+
+ AdviceProto Advise(const AdvisorOptionsProto& options) {
// Note: Release a checker's memory ASAP.
- std::map<string, std::vector<string>> reports = RunInternalCheckers(stats_);
- // TODO(xpan): Think of a way to turn off/on specific checkers.
- AcceleratorUtilizationChecker au_checker;
- reports[au_checker.name()] = au_checker.Run(stats_);
- OperationChecker op_checker;
- reports[op_checker.name()] = op_checker.Run(stats_);
-
- for (const auto& checker_r : reports) {
- fprintf(stdout, "%s reports:\n", checker_r.first.c_str());
- for (const auto& r : checker_r.second) {
+ AdviceProto ret = RunInternalCheckers(options, stats_);
+
+ if (options.checkers().find(kCheckers[0]) != options.checkers().end()) {
+ AcceleratorUtilizationChecker au_checker;
+ (*ret.mutable_checkers())[kCheckers[0]].MergeFrom(
+ au_checker.Run(options.checkers().at(kCheckers[0]), stats_));
+ }
+ if (options.checkers().find(kCheckers[1]) != options.checkers().end()) {
+ OperationChecker op_checker;
+ (*ret.mutable_checkers())[kCheckers[1]].MergeFrom(
+ op_checker.Run(options.checkers().at(kCheckers[1]), stats_));
+ }
+ if (options.checkers().find(kCheckers[2]) != options.checkers().end()) {
+ ExpensiveOperationChecker expensive_op_checker;
+ (*ret.mutable_checkers())[kCheckers[2]].MergeFrom(
+ expensive_op_checker.Run(options.checkers().at(kCheckers[2]),
+ stats_));
+ }
+ for (const auto& checker : ret.checkers()) {
+ fprintf(stdout, "\n%s:\n", checker.first.c_str());
+ for (const string& r : checker.second.reports()) {
fprintf(stdout, "%s\n", r.c_str());
}
}
fflush(stdout);
- return reports;
+ return ret;
}
private:
diff --git a/tensorflow/tools/tfprof/internal/advisor/tfprof_advisor_test.cc b/tensorflow/tools/tfprof/internal/advisor/tfprof_advisor_test.cc
index b41d0770dc..3b40253954 100644
--- a/tensorflow/tools/tfprof/internal/advisor/tfprof_advisor_test.cc
+++ b/tensorflow/tools/tfprof/internal/advisor/tfprof_advisor_test.cc
@@ -29,15 +29,16 @@ class TFProfAdvisorTest : public ::testing::Test {
nullptr, nullptr));
stats_->AddNodeForTest(
- "n1", CreateNode("n1", "Conv2D", {{"data_format", "NHWC"}}, 10, 2));
- stats_->AddNodeForTest("n2", CreateNode("n2", "Conv2D", {}, 20, 2));
+ 0, CreateNode("n1", "Conv2D", {{"data_format", "NHWC"}}, 0, 10, 2));
+ stats_->AddNodeForTest(0, CreateNode("n2", "Conv2D", {}, 0, 20, 2));
+ stats_->BuildAllViews();
advisor_.reset(new Advisor(stats_.get()));
}
std::unique_ptr<TFGraphNode> CreateNode(const string& name,
const string& type,
std::map<string, string> attrs,
- int64 start_miros,
+ int64 step, int64 start_miros,
int64 end_rel_micros) {
node_defs_.push_back(std::unique_ptr<NodeDef>(new NodeDef()));
NodeDef* def = node_defs_.back().get();
@@ -52,10 +53,10 @@ class TFProfAdvisorTest : public ::testing::Test {
NodeExecStats node_stat;
node_stat.set_all_start_micros(start_miros);
node_stat.set_op_end_rel_micros(end_rel_micros);
- node->AddStepStat(0, "/job:localhost/replica:0/task:0/gpu:0", node_stat);
- node->AddStepStat(0, "/job:localhost/replica:0/task:0/gpu:0:stream:all",
+ node->AddStepStat(step, "/job:localhost/replica:0/task:0/gpu:0", node_stat);
+ node->AddStepStat(step, "/job:localhost/replica:0/task:0/gpu:0:stream:all",
node_stat);
- node->AddStepStat(0, "/job:localhost/replica:0/task:0/gpu:0:stream:0",
+ node->AddStepStat(step, "/job:localhost/replica:0/task:0/gpu:0:stream:0",
node_stat);
return node;
}
@@ -66,23 +67,38 @@ class TFProfAdvisorTest : public ::testing::Test {
};
TEST_F(TFProfAdvisorTest, Basics) {
- std::map<string, std::vector<string>> reports = advisor_->Advise();
- EXPECT_TRUE(reports.find("AcceleratorUtilizationChecker") != reports.end());
- EXPECT_TRUE(reports.find("OperationChecker") != reports.end());
+ AdvisorOptionsProto options = Advisor::DefaultOptions();
+ AdviceProto advice = advisor_->Advise(options);
+ EXPECT_TRUE(advice.checkers().find(kCheckers[0]) != advice.checkers().end());
+ EXPECT_TRUE(advice.checkers().find(kCheckers[1]) != advice.checkers().end());
+ EXPECT_TRUE(advice.checkers().find(kCheckers[2]) != advice.checkers().end());
}
TEST_F(TFProfAdvisorTest, OperationChecker) {
- std::map<string, std::vector<string>> reports = advisor_->Advise();
- EXPECT_EQ(reports["OperationChecker"].size(), 1);
- EXPECT_TRUE(StringPiece(reports["OperationChecker"][0]).contains("NCHW"));
+ AdvisorOptionsProto options;
+ (*options.mutable_checkers())[kCheckers[1]];
+ AdviceProto advice = advisor_->Advise(options);
+ EXPECT_EQ(advice.checkers().at(kCheckers[1]).reports_size(), 1);
+ EXPECT_TRUE(StringPiece(advice.checkers().at(kCheckers[1]).reports(0))
+ .contains("NCHW"));
}
TEST_F(TFProfAdvisorTest, UtilizationChecker) {
- std::map<string, std::vector<string>> reports = advisor_->Advise();
- EXPECT_EQ(reports["AcceleratorUtilizationChecker"].size(), 1);
- EXPECT_TRUE(StringPiece(reports["AcceleratorUtilizationChecker"][0])
+ AdvisorOptionsProto options;
+ (*options.mutable_checkers())[kCheckers[0]];
+ AdviceProto advice = advisor_->Advise(options);
+ EXPECT_EQ(advice.checkers().at(kCheckers[0]).reports_size(), 1);
+ EXPECT_TRUE(StringPiece(advice.checkers().at(kCheckers[0]).reports(0))
.contains("low utilization"));
}
+TEST_F(TFProfAdvisorTest, ExpensiveOperationChecker) {
+ AdvisorOptionsProto options;
+ (*options.mutable_checkers())[kCheckers[2]];
+ AdviceProto advice = advisor_->Advise(options);
+ EXPECT_TRUE(StringPiece(advice.checkers().at(kCheckers[2]).reports(0))
+ .contains("top 1 operation type: Conv2D"));
+}
+
} // namespace tfprof
} // namespace tensorflow
diff --git a/tensorflow/tools/tfprof/internal/print_model_analysis.cc b/tensorflow/tools/tfprof/internal/print_model_analysis.cc
index 37d01db3a1..5a9c44d8e6 100644
--- a/tensorflow/tools/tfprof/internal/print_model_analysis.cc
+++ b/tensorflow/tools/tfprof/internal/print_model_analysis.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/tools/tfprof/internal/tfprof_options.h"
#include "tensorflow/tools/tfprof/internal/tfprof_stats.h"
#include "tensorflow/tools/tfprof/tfprof_log.pb.h"
+#include "tensorflow/tools/tfprof/tfprof_options.pb.h"
#include "tensorflow/tools/tfprof/tfprof_output.pb.h"
namespace tensorflow {
@@ -36,6 +37,18 @@ TFStats* tf_stat = nullptr;
string RunProfile(const string& command, const string& options,
TFStats* tf_stats) {
+ if (command == kCmds[4]) {
+ AdvisorOptionsProto option_pb;
+ if (!option_pb.ParseFromString(options)) {
+ fprintf(stderr, "Cannot parse AdvisorOptionsProto\n");
+ return "";
+ }
+ tf_stats->BuildAllViews();
+ return Advisor(tf_stats).Advise(option_pb).SerializeAsString();
+ } else {
+ tf_stats->BuildView(command);
+ }
+
Options opts;
tensorflow::Status s = Options::FromProtoStr(options, &opts);
if (!s.ok()) {
@@ -97,14 +110,14 @@ void AddStep(int64 step, const string* run_meta, const string* op_log) {
// TODO(xpan): Better error handling.
std::unique_ptr<RunMetadata> run_meta_ptr(new RunMetadata());
run_meta_ptr->ParseFromString(*run_meta);
- tf_stat->ParseRunMeta(step, std::move(run_meta_ptr));
+ tf_stat->AddRunMeta(step, std::move(run_meta_ptr));
std::unique_ptr<OpLog> op_log_ptr;
if (op_log && !op_log->empty()) {
op_log_ptr.reset(new OpLog());
op_log_ptr->ParseFromString(*op_log);
}
- tf_stat->ParseOpLog(std::move(op_log_ptr));
+ tf_stat->AddOpLog(std::move(op_log_ptr));
}
string Profile(const string* command, const string* options) {
@@ -144,7 +157,5 @@ string PrintModelAnalysis(const string* graph, const string* run_meta,
return RunProfile(*command, *options, &tf_stats);
}
-void Advise() { Advisor(tf_stat).Advise(); }
-
} // namespace tfprof
} // namespace tensorflow
diff --git a/tensorflow/tools/tfprof/internal/print_model_analysis.h b/tensorflow/tools/tfprof/internal/print_model_analysis.h
index 84165e542d..46db63646d 100644
--- a/tensorflow/tools/tfprof/internal/print_model_analysis.h
+++ b/tensorflow/tools/tfprof/internal/print_model_analysis.h
@@ -39,8 +39,6 @@ void AddStep(int64 step, const string* run_meta, const string* op_log);
string Profile(const string* command, const string* options);
-void Advise();
-
// Single-step Profiler.
//
// Interface defined for Python API swig. Calls the tfprof core API.
diff --git a/tensorflow/tools/tfprof/internal/tfprof_graph.h b/tensorflow/tools/tfprof/internal/tfprof_graph.h
index fbeae8673d..194a21f0cc 100644
--- a/tensorflow/tools/tfprof/internal/tfprof_graph.h
+++ b/tensorflow/tools/tfprof/internal/tfprof_graph.h
@@ -54,8 +54,8 @@ class TFGraph : public TFShow {
const ShowNode* ShowInternal(const Options& opts,
Timeline* timeline) override;
- bool ShouldShowIfExtra(ShowNode* node, const Options& opts,
- int depth) override {
+ bool ShouldShowIfExtra(const ShowNode* node, const Options& opts,
+ int depth) const override {
return true;
}
diff --git a/tensorflow/tools/tfprof/internal/tfprof_op.cc b/tensorflow/tools/tfprof/internal/tfprof_op.cc
index ac702320b3..30284f0307 100644
--- a/tensorflow/tools/tfprof/internal/tfprof_op.cc
+++ b/tensorflow/tools/tfprof/internal/tfprof_op.cc
@@ -126,6 +126,7 @@ const ShowMultiNode* TFOp::ShowInternal(const Options& opts,
}
nodes = SortNodes(nodes, opts);
+ // pre keeps track of previous visited node.
OpNode* pre = nullptr;
std::vector<OpNode*> account_nodes;
for (auto it = nodes.rbegin(); it != nodes.rend(); ++it) {
@@ -170,16 +171,20 @@ const ShowMultiNode* TFOp::ShowInternal(const Options& opts,
root_->ResetTotalStats();
if (pre) {
root_->AggregateTotalStats(pre);
- root_->mutable_proto()->add_children()->MergeFrom(pre->proto());
- pre->mutable_proto()->clear_children();
}
}
+ if (pre) {
+ root_->mutable_proto()->add_children()->MergeFrom(pre->proto());
+ pre->mutable_proto()->clear_children();
+ }
if (opts.output_type == kOutput[1] || opts.output_type == kOutput[2]) {
string display_str = FormatLegend(opts);
for (OpNode* node : show_nodes) {
display_str += FormatNode(node, root_.get(), opts);
}
+ // In op view, we don't show root (total). But it will still in proto.
+ // TODO(xpan): Is it the right choice?
root_->formatted_str = display_str;
}
return root_.get();
@@ -201,7 +206,7 @@ int64 TFOp::SearchRoot(const std::vector<OpNode*> nodes,
return i;
}
-string TFOp::FormatNode(OpNode* node, OpNode* root, const Options& opts) {
+string TFOp::FormatNode(OpNode* node, OpNode* root, const Options& opts) const {
std::vector<string> attrs;
if (opts.select.find(kShown[0]) != opts.select.end()) {
diff --git a/tensorflow/tools/tfprof/internal/tfprof_op.h b/tensorflow/tools/tfprof/internal/tfprof_op.h
index 5b16490363..34812f54be 100644
--- a/tensorflow/tools/tfprof/internal/tfprof_op.h
+++ b/tensorflow/tools/tfprof/internal/tfprof_op.h
@@ -56,15 +56,15 @@ class TFOp : public TFMultiShow {
int64 SearchRoot(const std::vector<OpNode*> nodes,
const std::vector<string>& regexes);
- bool ShouldShowIfExtra(ShowMultiNode* node, const Options& opts,
- int depth) override {
+ bool ShouldShowIfExtra(const ShowMultiNode* node, const Options& opts,
+ int depth) const override {
if (opts.min_occurrence > node->node->graph_nodes().size()) {
return false;
}
return true;
}
- string FormatNode(OpNode* node, OpNode* root, const Options& opts);
+ string FormatNode(OpNode* node, OpNode* root, const Options& opts) const;
std::unique_ptr<OpNode> root_;
std::map<string, std::unique_ptr<OpNode>> cnodes_map_;
diff --git a/tensorflow/tools/tfprof/internal/tfprof_options.h b/tensorflow/tools/tfprof/internal/tfprof_options.h
index 6c9db24342..d39333e3fc 100644
--- a/tensorflow/tools/tfprof/internal/tfprof_options.h
+++ b/tensorflow/tools/tfprof/internal/tfprof_options.h
@@ -59,10 +59,10 @@ static const char* const kShown[] = {
"cpu_micros"};
static const char* const kCmds[] = {
- "scope", "graph", "code", "op", "set", "help",
+ "scope", "graph", "code", "op", "advise", "set", "help",
};
-static const char* const kOutput[] = {"timeline", "stdout", "file"};
+static const char* const kOutput[] = {"timeline", "stdout", "file", "none"};
static const char* const kTimelineOpts[] = {
"outfile",
diff --git a/tensorflow/tools/tfprof/internal/tfprof_show.cc b/tensorflow/tools/tfprof/internal/tfprof_show.cc
index eaab5b9608..c456099c86 100644
--- a/tensorflow/tools/tfprof/internal/tfprof_show.cc
+++ b/tensorflow/tools/tfprof/internal/tfprof_show.cc
@@ -26,7 +26,9 @@ namespace tensorflow {
namespace tfprof {
const TFGraphNodeProto& TFShow::Show(const Options& opts) {
- if (opts.output_type == kOutput[0]) {
+ if (opts.output_type == kOutput[3]) {
+ return ShowInternal(opts, nullptr)->proto();
+ } else if (opts.output_type == kOutput[0]) {
Timeline timeline(opts.step, opts.output_options.at(kTimelineOpts[0]));
return ShowInternal(opts, &timeline)->proto();
} else if (opts.output_type == kOutput[2]) {
@@ -64,7 +66,8 @@ bool TFShow::LookUpCheckPoint(const string& name,
return true;
}
-bool TFShow::ShouldShow(ShowNode* node, const Options& opts, int depth) {
+bool TFShow::ShouldShow(const ShowNode* node, const Options& opts,
+ int depth) const {
// Always show kTFProfRoot.
if (node->name() == kTFProfRoot) return true;
@@ -96,7 +99,8 @@ bool TFShow::ShouldShow(ShowNode* node, const Options& opts, int depth) {
return true;
}
-bool TFShow::ShouldTrim(ShowNode* node, const std::vector<string>& regexes) {
+bool TFShow::ShouldTrim(const ShowNode* node,
+ const std::vector<string>& regexes) const {
for (const string& regex : regexes) {
if (RE2::FullMatch(node->name(), regex)) {
return true;
@@ -121,7 +125,7 @@ bool TFShow::ReAccount(ShowNode* node, const Options& opts) {
return false;
}
-string TFShow::FormatNode(ShowNode* node, const Options& opts) {
+string TFShow::FormatNode(ShowNode* node, const Options& opts) const {
std::vector<string> info;
if (opts.select.find(kShown[2]) != opts.select.end()) {
const string shape = FormatShapes(node->node->shape());
@@ -194,7 +198,7 @@ string TFShow::FormatNode(ShowNode* node, const Options& opts) {
str_util::Join(info, ", ").c_str());
}
-string TFShow::FormatLegend(const Options& opts) {
+string TFShow::FormatLegend(const Options& opts) const {
std::vector<string> legends;
if (opts.select.find(kShown[2]) != opts.select.end()) {
legends.push_back("# parameters");
diff --git a/tensorflow/tools/tfprof/internal/tfprof_show.h b/tensorflow/tools/tfprof/internal/tfprof_show.h
index 2c61b4fd73..95513e086f 100644
--- a/tensorflow/tools/tfprof/internal/tfprof_show.h
+++ b/tensorflow/tools/tfprof/internal/tfprof_show.h
@@ -54,20 +54,21 @@ class TFShow {
std::unique_ptr<TFProfTensor>* tensor);
// Overridden by subclass if extra requirements need to be met.
- virtual bool ShouldShowIfExtra(ShowNode* node, const Options& opts,
- int depth) {
+ virtual bool ShouldShowIfExtra(const ShowNode* node, const Options& opts,
+ int depth) const {
return true;
}
- bool ShouldShow(ShowNode* node, const Options& opts, int depth);
+ bool ShouldShow(const ShowNode* node, const Options& opts, int depth) const;
- bool ShouldTrim(ShowNode* node, const std::vector<string>& regexes);
+ bool ShouldTrim(const ShowNode* node,
+ const std::vector<string>& regexes) const;
bool ReAccount(ShowNode* node, const Options& opts);
- string FormatNode(ShowNode* node, const Options& opts);
+ string FormatNode(ShowNode* node, const Options& opts) const;
- string FormatLegend(const Options& opts);
+ string FormatLegend(const Options& opts) const;
template <typename T>
std::vector<T*> SortNodes(const std::vector<T*>& nodes, const Options& opts) {
diff --git a/tensorflow/tools/tfprof/internal/tfprof_show_multi.cc b/tensorflow/tools/tfprof/internal/tfprof_show_multi.cc
index 97f204d25b..389968b750 100644
--- a/tensorflow/tools/tfprof/internal/tfprof_show_multi.cc
+++ b/tensorflow/tools/tfprof/internal/tfprof_show_multi.cc
@@ -28,7 +28,9 @@ namespace tensorflow {
namespace tfprof {
const TFMultiGraphNodeProto& TFMultiShow::Show(const Options& opts) {
- if (opts.output_type == kOutput[0]) {
+ if (opts.output_type == kOutput[3]) {
+ return ShowInternal(opts, nullptr)->proto();
+ } else if (opts.output_type == kOutput[0]) {
Timeline timeline(opts.step, opts.output_options.at(kTimelineOpts[0]));
return ShowInternal(opts, &timeline)->proto();
} else if (opts.output_type == kOutput[2]) {
@@ -48,8 +50,8 @@ const TFMultiGraphNodeProto& TFMultiShow::Show(const Options& opts) {
}
}
-bool TFMultiShow::ShouldShow(ShowMultiNode* node, const Options& opts,
- int depth) {
+bool TFMultiShow::ShouldShow(const ShowMultiNode* node, const Options& opts,
+ int depth) const {
// Always show kTFProfRoot.
if (node->name() == kTFProfRoot) return true;
@@ -88,8 +90,8 @@ bool TFMultiShow::ShouldShow(ShowMultiNode* node, const Options& opts,
return true;
}
-bool TFMultiShow::ShouldTrim(ShowMultiNode* node,
- const std::vector<string>& regexes) {
+bool TFMultiShow::ShouldTrim(const ShowMultiNode* node,
+ const std::vector<string>& regexes) const {
for (const string& regex : regexes) {
if (RE2::FullMatch(node->name(), regex)) {
return true;
@@ -102,7 +104,7 @@ bool TFMultiShow::ReAccount(ShowMultiNode* node, const Options& opts) {
return node->ReInit(opts.step, opts.account_type_regexes);
}
-string TFMultiShow::FormatLegend(const Options& opts) {
+string TFMultiShow::FormatLegend(const Options& opts) const {
std::vector<string> legends;
if (opts.select.find(kShown[0]) != opts.select.end()) {
legends.push_back("output bytes");
@@ -142,7 +144,8 @@ string TFMultiShow::FormatLegend(const Options& opts) {
str_util::Join(legends, " | ").c_str());
}
-string TFMultiShow::FormatInputShapes(const TFMultiGraphNodeProto& proto) {
+string TFMultiShow::FormatInputShapes(
+ const TFMultiGraphNodeProto& proto) const {
std::map<string, int> input_shapes_str;
std::map<string, int> input_time_str;
for (int i = 0; i < proto.graph_nodes_size(); ++i) {
@@ -190,7 +193,7 @@ string TFMultiShow::FormatInputShapes(const TFMultiGraphNodeProto& proto) {
}
std::vector<string> TFMultiShow::FormatTimes(const ShowMultiNode* node,
- const Options& opts) {
+ const Options& opts) const {
std::vector<string> attrs;
if (opts.select.find(kShown[1]) != opts.select.end()) {
attrs.push_back(FormatTotalExecTime(node, opts));
diff --git a/tensorflow/tools/tfprof/internal/tfprof_show_multi.h b/tensorflow/tools/tfprof/internal/tfprof_show_multi.h
index e6faf1231d..ce309816a9 100644
--- a/tensorflow/tools/tfprof/internal/tfprof_show_multi.h
+++ b/tensorflow/tools/tfprof/internal/tfprof_show_multi.h
@@ -55,21 +55,23 @@ class TFMultiShow {
std::unique_ptr<TFProfTensor>* tensor);
// Overridden by subclass if extra requirements need to be met.
- virtual bool ShouldShowIfExtra(ShowMultiNode* node, const Options& opts,
- int depth) {
+ virtual bool ShouldShowIfExtra(const ShowMultiNode* node, const Options& opts,
+ int depth) const {
return true;
}
- bool ShouldShow(ShowMultiNode* node, const Options& opts, int depth);
+ bool ShouldShow(const ShowMultiNode* node, const Options& opts,
+ int depth) const;
- bool ShouldTrim(ShowMultiNode* node, const std::vector<string>& regexes);
+ bool ShouldTrim(const ShowMultiNode* node,
+ const std::vector<string>& regexes) const;
bool ReAccount(ShowMultiNode* node, const Options& opts);
- string FormatLegend(const Options& opts);
- string FormatInputShapes(const TFMultiGraphNodeProto& proto);
+ string FormatLegend(const Options& opts) const;
+ string FormatInputShapes(const TFMultiGraphNodeProto& proto) const;
std::vector<string> FormatTimes(const ShowMultiNode* node,
- const Options& opts);
+ const Options& opts) const;
template <typename T>
std::vector<T*> SortNodes(const std::vector<T*>& nodes, const Options& opts) {
diff --git a/tensorflow/tools/tfprof/internal/tfprof_show_test.cc b/tensorflow/tools/tfprof/internal/tfprof_show_test.cc
index 478e269f87..a58fbdafca 100644
--- a/tensorflow/tools/tfprof/internal/tfprof_show_test.cc
+++ b/tensorflow/tools/tfprof/internal/tfprof_show_test.cc
@@ -65,6 +65,7 @@ class TFProfShowTest : public ::testing::Test {
tf_stats_.reset(new TFStats(std::move(graph_pb), std::move(run_meta_pb),
std::move(op_log_pb), std::move(ckpt_reader)));
+ tf_stats_->BuildAllViews();
}
std::unique_ptr<TFStats> tf_stats_;
diff --git a/tensorflow/tools/tfprof/internal/tfprof_stats.cc b/tensorflow/tools/tfprof/internal/tfprof_stats.cc
index f5b8dad4e2..64da7ae7cf 100644
--- a/tensorflow/tools/tfprof/internal/tfprof_stats.cc
+++ b/tensorflow/tools/tfprof/internal/tfprof_stats.cc
@@ -29,16 +29,17 @@ TFStats::TFStats(std::unique_ptr<GraphDef> graph,
std::unique_ptr<RunMetadata> run_meta,
std::unique_ptr<OpLog> op_log,
std::unique_ptr<checkpoint::CheckpointReader> ckpt_reader)
- : graph_(std::move(graph)),
+ : has_code_traces_(false),
+ graph_(std::move(graph)),
ckpt_reader_(std::move(ckpt_reader)) {
CHECK(graph_) << "Must at least have GraphDef";
printf("Parsing Inputs...\n");
ParseGraph();
if (run_meta && run_meta->has_step_stats()) {
- ParseRunMeta(0, std::move(run_meta));
+ AddRunMeta(0, std::move(run_meta));
}
- ParseOpLog(std::move(op_log));
+ AddOpLog(std::move(op_log));
if (ckpt_reader_) {
for (const auto& v : ckpt_reader_->GetVariableToShapeMap()) {
@@ -48,27 +49,48 @@ TFStats::TFStats(std::unique_ptr<GraphDef> graph,
}
}
}
+}
- printf("Preparing Views...\n");
- scope_view_ = std::unique_ptr<TFScope>(new TFScope(ckpt_reader_.get()));
- graph_view_ = std::unique_ptr<TFGraph>(new TFGraph(ckpt_reader_.get()));
- code_view_ = std::unique_ptr<TFCode>(new TFCode());
- op_view_ = std::unique_ptr<TFOp>(new TFOp());
+void TFStats::BuildView(const string& cmd) {
+ if (cmd == kCmds[0] && !scope_view_) {
+ scope_view_.reset(new TFScope(ckpt_reader_.get()));
+ for (auto it = nodes_map_.begin(); it != nodes_map_.end(); it++) {
+ scope_view_->AddNode(it->second.get());
+ }
+ scope_view_->Build();
+ }
+ if (cmd == kCmds[1] && !graph_view_) {
+ graph_view_.reset(new TFGraph(ckpt_reader_.get()));
+ for (auto it = nodes_map_.begin(); it != nodes_map_.end(); it++) {
+ graph_view_->AddNode(it->second.get());
+ }
+ graph_view_->Build();
+ }
+ if (cmd == kCmds[2] && !code_view_) {
+ code_view_.reset(new TFCode());
+ for (auto it = nodes_map_.begin(); it != nodes_map_.end(); it++) {
+ code_view_->AddNode(it->second.get());
+ }
+ code_view_->Build();
+ }
+ if (cmd == kCmds[3] && !op_view_) {
+ op_view_.reset(new TFOp());
+ for (auto it = nodes_map_.begin(); it != nodes_map_.end(); it++) {
+ op_view_->AddNode(it->second.get());
+ }
+ op_view_->Build();
+ }
+}
- for (auto it = nodes_map_.begin(); it != nodes_map_.end(); it++) {
- scope_view_->AddNode(it->second.get());
- graph_view_->AddNode(it->second.get());
- code_view_->AddNode(it->second.get());
- op_view_->AddNode(it->second.get());
- }
- scope_view_->Build();
- graph_view_->Build();
- code_view_->Build();
- op_view_->Build();
+void TFStats::BuildAllViews() {
+ std::vector<string> cmds_str(kCmds, kCmds + sizeof(kCmds) / sizeof(*kCmds));
+ for (const string& cmd : cmds_str) {
+ BuildView(cmd);
+ }
}
const TFGraphNodeProto& TFStats::ShowGraphNode(const string& cmd,
- const Options& opts) {
+ const Options& opts) const {
if (!Validate(opts)) {
return empty_graph_node_;
}
@@ -82,8 +104,8 @@ const TFGraphNodeProto& TFStats::ShowGraphNode(const string& cmd,
}
}
-const TFMultiGraphNodeProto& TFStats::ShowMultiGraphNode(const string& cmd,
- const Options& opts) {
+const TFMultiGraphNodeProto& TFStats::ShowMultiGraphNode(
+ const string& cmd, const Options& opts) const {
if (!Validate(opts)) {
return empty_multi_graph_node_;
}
@@ -130,7 +152,7 @@ void TFStats::ParseGraph() {
}
}
-void TFStats::ParseOpLog(std::unique_ptr<OpLog> op_log) {
+void TFStats::AddOpLog(std::unique_ptr<OpLog> op_log) {
if (!op_log) {
return;
}
@@ -144,12 +166,13 @@ void TFStats::ParseOpLog(std::unique_ptr<OpLog> op_log) {
node->second->AddFloatOps(entry.float_ops());
}
if (entry.has_code_def()) {
+ has_code_traces_ = true;
node->second->AddCode(entry.code_def());
}
}
}
-void TFStats::ParseRunMeta(int64 step, std::unique_ptr<RunMetadata> run_meta) {
+void TFStats::AddRunMeta(int64 step, std::unique_ptr<RunMetadata> run_meta) {
if (!run_meta || !run_meta->has_step_stats()) {
fprintf(stderr, "Invalid RunMetadata for step %lld\n", step);
return;
@@ -176,7 +199,7 @@ void TFStats::ParseRunMeta(int64 step, std::unique_ptr<RunMetadata> run_meta) {
}
}
-bool TFStats::Validate(const Options& opts) {
+bool TFStats::Validate(const Options& opts) const {
if (opts.step >= 0 && steps_.find(opts.step) == steps_.end()) {
fprintf(stderr, "Options -step=%lld not found\n", opts.step);
return false;
@@ -184,9 +207,9 @@ bool TFStats::Validate(const Options& opts) {
return true;
}
-void TFStats::AddNodeForTest(const string& name,
- std::unique_ptr<TFGraphNode> node) {
- nodes_map_[name] = std::move(node);
+void TFStats::AddNodeForTest(int64 step, std::unique_ptr<TFGraphNode> node) {
+ steps_.insert(step);
+ nodes_map_[node->name()] = std::move(node);
}
} // namespace tfprof
} // namespace tensorflow
diff --git a/tensorflow/tools/tfprof/internal/tfprof_stats.h b/tensorflow/tools/tfprof/internal/tfprof_stats.h
index dfb190e703..b26d274f80 100644
--- a/tensorflow/tools/tfprof/internal/tfprof_stats.h
+++ b/tensorflow/tools/tfprof/internal/tfprof_stats.h
@@ -59,28 +59,38 @@ class TFStats {
const std::map<string, std::unique_ptr<TFGraphNode>>& nodes() const {
return nodes_map_;
}
+ const std::set<int64>& steps() const { return steps_; }
+ bool has_code_traces() const { return has_code_traces_; }
+ void BuildView(const string& cmd);
+ void BuildAllViews();
+
+ // Note: Must first BuildView(view_foo) before ShowXXX(view_foo) methods.
+ //
// Organize the TensorFlow model as different types of views, and generate
// outputs for profiling.
- const TFGraphNodeProto& ShowGraphNode(const string& cmd, const Options& opts);
+ // TODO(xpan): Should it return reference here?
+ const TFGraphNodeProto& ShowGraphNode(const string& cmd,
+ const Options& opts) const;
const TFMultiGraphNodeProto& ShowMultiGraphNode(const string& cmd,
- const Options& opts);
+ const Options& opts) const;
// Add a step of run time meta data.
- void ParseRunMeta(int64 step, std::unique_ptr<RunMetadata> run_meta);
+ void AddRunMeta(int64 step, std::unique_ptr<RunMetadata> run_meta);
// Add tfprof operation meta data, such as customized op type, float_ops,
// and code traces.
- void ParseOpLog(std::unique_ptr<OpLog> op_log);
+ void AddOpLog(std::unique_ptr<OpLog> op_log);
// For test purpose only.
- void AddNodeForTest(const string& name, std::unique_ptr<TFGraphNode> node);
+ void AddNodeForTest(int64 step, std::unique_ptr<TFGraphNode> node);
private:
- bool Validate(const Options& opts);
+ bool Validate(const Options& opts) const;
void ParseGraph();
std::set<int64> steps_;
+ bool has_code_traces_;
std::unique_ptr<GraphDef> graph_;
std::unique_ptr<TFScope> scope_view_;
std::unique_ptr<TFGraph> graph_view_;
diff --git a/tensorflow/tools/tfprof/internal/tfprof_stats_test.cc b/tensorflow/tools/tfprof/internal/tfprof_stats_test.cc
index 948ce49df4..45f1fdd06f 100644
--- a/tensorflow/tools/tfprof/internal/tfprof_stats_test.cc
+++ b/tensorflow/tools/tfprof/internal/tfprof_stats_test.cc
@@ -66,6 +66,7 @@ class TFProfStatsTest : public ::testing::Test {
tf_stats_.reset(new TFStats(std::move(graph_pb), std::move(run_meta_pb),
std::move(op_log_pb), std::move(ckpt_reader)));
+ tf_stats_->BuildAllViews();
}
std::unique_ptr<TFStats> tf_stats_;
diff --git a/tensorflow/tools/tfprof/internal/tfprof_tensor_test.cc b/tensorflow/tools/tfprof/internal/tfprof_tensor_test.cc
index 698738c23c..59955d214c 100644
--- a/tensorflow/tools/tfprof/internal/tfprof_tensor_test.cc
+++ b/tensorflow/tools/tfprof/internal/tfprof_tensor_test.cc
@@ -50,6 +50,7 @@ class TFProfTensorTest : public ::testing::Test {
tf_stats_.reset(new TFStats(std::move(graph_pb), std::move(run_meta_pb),
std::move(op_log_pb), std::move(ckpt_reader)));
+ tf_stats_->BuildAllViews();
}
std::unique_ptr<TFStats> tf_stats_;
diff --git a/tensorflow/tools/tfprof/internal/tfprof_timeline_test.cc b/tensorflow/tools/tfprof/internal/tfprof_timeline_test.cc
index 0e9bb9658c..cad31050a9 100644
--- a/tensorflow/tools/tfprof/internal/tfprof_timeline_test.cc
+++ b/tensorflow/tools/tfprof/internal/tfprof_timeline_test.cc
@@ -52,6 +52,7 @@ class TFProfTimelineTest : public ::testing::Test {
tf_stats_.reset(new TFStats(std::move(graph_pb), std::move(run_meta_pb),
nullptr, nullptr));
+ tf_stats_->BuildAllViews();
}
std::unique_ptr<TFStats> tf_stats_;
diff --git a/tensorflow/tools/tfprof/tfprof_main.cc b/tensorflow/tools/tfprof/tfprof_main.cc
index 7a4e7e85ff..5e70c093cc 100644
--- a/tensorflow/tools/tfprof/tfprof_main.cc
+++ b/tensorflow/tools/tfprof/tfprof_main.cc
@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/util/command_line_flags.h"
+#include "tensorflow/tools/tfprof/internal/advisor/tfprof_advisor.h"
#include "tensorflow/tools/tfprof/internal/tfprof_options.h"
#include "tensorflow/tools/tfprof/internal/tfprof_stats.h"
#include "tensorflow/tools/tfprof/internal/tfprof_utils.h"
@@ -161,12 +162,13 @@ int Run(int argc, char** argv) {
"Profiling everything!\n");
return 0;
} else if (argc > 1) {
- if (string(argv[1]) == kCmds[5]) {
+ if (string(argv[1]) == kCmds[6]) {
PrintHelp();
return 0;
}
if (string(argv[1]) == kCmds[0] || string(argv[1]) == kCmds[1] ||
- string(argv[1]) == kCmds[2] || string(argv[1]) == kCmds[3]) {
+ string(argv[1]) == kCmds[2] || string(argv[1]) == kCmds[3] ||
+ string(argv[1]) == kCmds[4]) {
cmd = argv[1];
}
}
@@ -216,7 +218,13 @@ int Run(int argc, char** argv) {
run_meta_files[i].c_str(), s.ToString().c_str());
return 1;
}
- tf_stat.ParseRunMeta(i, std::move(run_meta));
+ tf_stat.AddRunMeta(i, std::move(run_meta));
+ }
+
+ if (cmd == kCmds[4]) {
+ tf_stat.BuildAllViews();
+ Advisor(&tf_stat).Advise(Advisor::DefaultOptions());
+ return 0;
}
Options opts(FLAGS_max_depth, FLAGS_min_bytes, FLAGS_min_micros,
@@ -227,9 +235,11 @@ int Run(int argc, char** argv) {
output_type, output_options);
if (cmd == kCmds[2] || cmd == kCmds[3]) {
+ tf_stat.BuildView(cmd);
tf_stat.ShowMultiGraphNode(cmd, opts);
return 0;
} else if (cmd == kCmds[0] || cmd == kCmds[1]) {
+ tf_stat.BuildView(cmd);
tf_stat.ShowGraphNode(cmd, opts);
return 0;
}
@@ -254,14 +264,19 @@ int Run(int argc, char** argv) {
fprintf(stderr, "E: %s\n", s.ToString().c_str());
continue;
}
- if (cmd == kCmds[4]) {
+ if (cmd == kCmds[5]) {
opts = new_opts;
- } else if (cmd == kCmds[5]) {
+ } else if (cmd == kCmds[6]) {
PrintHelp();
} else if (cmd == kCmds[2] || cmd == kCmds[3]) {
+ tf_stat.BuildView(cmd);
tf_stat.ShowMultiGraphNode(cmd, new_opts);
} else if (cmd == kCmds[0] || cmd == kCmds[1]) {
+ tf_stat.BuildView(cmd);
tf_stat.ShowGraphNode(cmd, new_opts);
+ } else if (cmd == kCmds[4]) {
+ tf_stat.BuildAllViews();
+ Advisor(&tf_stat).Advise(Advisor::DefaultOptions());
}
}
return 0;
diff --git a/tensorflow/tools/tfprof/tfprof_options.proto b/tensorflow/tools/tfprof/tfprof_options.proto
index 27eafb1ca9..47e1ff33ee 100644
--- a/tensorflow/tools/tfprof/tfprof_options.proto
+++ b/tensorflow/tools/tfprof/tfprof_options.proto
@@ -22,5 +22,13 @@ message OptionsProto {
optional bool account_displayed_op_only = 13;
repeated string select = 14;
optional string output = 15;
- optional string dump_to_file = 16;
+ optional string dump_to_file = 16 [deprecated = true];
+}
+
+message AdvisorOptionsProto {
+ // checker name -> a dict of key-value options.
+ map<string, CheckerOption> checkers = 1;
+ message CheckerOption {
+ map<string, string> options = 1;
+ }
}
diff --git a/tensorflow/tools/tfprof/tfprof_output.proto b/tensorflow/tools/tfprof/tfprof_output.proto
index 1ea956152c..e41fde45f2 100644
--- a/tensorflow/tools/tfprof/tfprof_output.proto
+++ b/tensorflow/tools/tfprof/tfprof_output.proto
@@ -95,4 +95,12 @@ message TFMultiGraphNodeProto {
// Descendants of the node. The actual descendants depend on the data
// structure used.
repeated TFMultiGraphNodeProto children = 11;
-} \ No newline at end of file
+}
+
+message AdviceProto {
+ // checker name -> a list of reports from the checker.
+ map<string, Checker> checkers = 1;
+ message Checker {
+ repeated string reports = 2;
+ }
+}