aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-05-11 18:17:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-11 18:21:40 -0700
commita5b9ef0ec0e9a4f5c42157755ffd180a90063e2b (patch)
tree2675d15016665f83f9517d54faebb550ac04bf2a /tensorflow
parent14221c1fd4155d80341feea61896a4d8abc7ae78 (diff)
Add timeline support in tfprof
This CL mainly adds timeline support in three views of tfprof. It includes a few other small changes: 1. Handle the case that one Op fires multiple kernels. 2. Remove the requirements for CostGraph for easier user adoption, for now. 3. Some speed improvements in graph view. 4. Consolidate the all kinds of tfprof output into one -output option. PiperOrigin-RevId: 155822542
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/contrib/tfprof/README.md78
-rw-r--r--tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py10
-rw-r--r--tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py38
-rw-r--r--tensorflow/contrib/tfprof/python/tools/tfprof/print_model_analysis_test.py9
-rw-r--r--tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py10
-rw-r--r--tensorflow/tools/tfprof/README.md64
-rw-r--r--tensorflow/tools/tfprof/g3doc/code_timeline.pngbin0 -> 45674 bytes
-rw-r--r--tensorflow/tools/tfprof/g3doc/graph_timeline.pngbin0 -> 168051 bytes
-rw-r--r--tensorflow/tools/tfprof/g3doc/scope_timeline.pngbin0 -> 24944 bytes
-rw-r--r--tensorflow/tools/tfprof/internal/BUILD58
-rw-r--r--tensorflow/tools/tfprof/internal/print_model_analysis.cc11
-rw-r--r--tensorflow/tools/tfprof/internal/tfprof_code.cc17
-rw-r--r--tensorflow/tools/tfprof/internal/tfprof_code.h20
-rw-r--r--tensorflow/tools/tfprof/internal/tfprof_graph.cc23
-rw-r--r--tensorflow/tools/tfprof/internal/tfprof_graph.h29
-rw-r--r--tensorflow/tools/tfprof/internal/tfprof_node.cc26
-rw-r--r--tensorflow/tools/tfprof/internal/tfprof_node.h68
-rw-r--r--tensorflow/tools/tfprof/internal/tfprof_node_show.cc296
-rw-r--r--tensorflow/tools/tfprof/internal/tfprof_node_show.h173
-rw-r--r--tensorflow/tools/tfprof/internal/tfprof_options.cc134
-rw-r--r--tensorflow/tools/tfprof/internal/tfprof_options.h43
-rw-r--r--tensorflow/tools/tfprof/internal/tfprof_scope.cc12
-rw-r--r--tensorflow/tools/tfprof/internal/tfprof_scope.h19
-rw-r--r--tensorflow/tools/tfprof/internal/tfprof_show.cc168
-rw-r--r--tensorflow/tools/tfprof/internal/tfprof_show.h32
-rw-r--r--tensorflow/tools/tfprof/internal/tfprof_show_code.cc161
-rw-r--r--tensorflow/tools/tfprof/internal/tfprof_show_code.h33
-rw-r--r--tensorflow/tools/tfprof/internal/tfprof_show_test.cc2
-rw-r--r--tensorflow/tools/tfprof/internal/tfprof_stats.cc27
-rw-r--r--tensorflow/tools/tfprof/internal/tfprof_stats_test.cc40
-rw-r--r--tensorflow/tools/tfprof/internal/tfprof_tensor_test.cc2
-rw-r--r--tensorflow/tools/tfprof/internal/tfprof_timeline.cc245
-rw-r--r--tensorflow/tools/tfprof/internal/tfprof_timeline.h147
-rw-r--r--tensorflow/tools/tfprof/internal/tfprof_timeline_test.cc92
-rw-r--r--tensorflow/tools/tfprof/internal/tfprof_utils.cc14
-rw-r--r--tensorflow/tools/tfprof/tfprof_main.cc32
-rw-r--r--tensorflow/tools/tfprof/tfprof_options.proto2
-rw-r--r--tensorflow/tools/tfprof/tfprof_output.proto3
38 files changed, 1470 insertions, 668 deletions
diff --git a/tensorflow/contrib/tfprof/README.md b/tensorflow/contrib/tfprof/README.md
index d891ecdc9a..5bfa0247a5 100644
--- a/tensorflow/contrib/tfprof/README.md
+++ b/tensorflow/contrib/tfprof/README.md
@@ -2,81 +2,25 @@
# Full Docment in tensorflow/tools/tfprof/README.md
-Author: Xin Pan (xpan@google.com, github: panyx0718)
+Author: Xin Pan (xpan@google.com, github: panyx0718), Jon Shlens, Yao Zhang
Consultants: Jon Shlens, Pete Warden
###Major Features
1. Measure model parameters, float operations, tensor shapes.
-2. Measure op execution times, requested memory size and device placement.
+2. Profile op execution times, requested memory size and device placement.
3. Inspect checkpoint tensors' shapes and their values.
-4. 3 ways to view and explore TensorFlow model profiles
+4. Selectively group, filter, account and order ops.
- * Organize by Python code call stack.
- * Organize by TensorFlow operation name scope hierarchies.
- * Organize by TensorFlow operation inputs/outputs graph.
+####tfprof supports 3 views to organize TensorFlow model profiles
-5. Selectively grouping/filtering/accounting/ordering ops.
+ * code view: Stats are associated your Python codes and organized as call stacks.
+ * scope view: Stats are organized as name scope hierarchies.
+ * graph view: Stats are organized as Tensorflow Op graph.
-tfprof can be used as Python API, Interactive CLI and One-shot Script.
+####For each view, there are 3 ways to display outputs:
-## Python API Tutorials
-
-tfprof is part of TensorFlow core. Simply ```import tensorflow as tf```.
-
-### Examine the shapes and sizes of all trainiable Variables.
-```python
-# Print trainable variable parameter statistics to stdout.
-param_stats = tf.contrib.tfprof.model_analyzer.print_model_analysis(
- tf.get_default_graph(),
- tfprof_options=tf.contrib.tfprof.model_analyzer.
- TRAINABLE_VARS_PARAMS_STAT_OPTIONS)
-
-# param_stats is tensorflow.tfprof.TFGraphNodeProto proto.
-# It organize the statistics
-# of each graph node in tree scructure. Let's print the root below.
-sys.stdout.write('total_params: %d\n' % param_stats.total_parameters)
-```
-
-### Examine the number of floating point operations
-``` python
-# Print to stdout an analysis of the number of floating point operations in the
-# model broken down by individual operations.
-#
-# Note: Only Ops with RegisterStatistics('flops') defined have flop stats. It
-# also requires complete shape information. It is common that shape is unknown
-# statically. To complete the shape, provide run-time shape information with
-# tf.RunMetadata to the API (See next example on how to provide RunMetadata).
-tf.contrib.tfprof.model_analyzer.print_model_analysis(
- tf.get_default_graph(),
- tfprof_options=tf.contrib.tfprof.model_analyzer.FLOAT_OPS_OPTIONS)
-```
-
-### Examine the timing and memory usage
-You will first need to run the following set up in your model in order to
-compute the memory and timing statistics.
-
-```python
-# Generate the meta information for the model that contains the memory usage
-# and timing information.
-run_metadata = tf.RunMetadata()
-with tf.Session() as sess:
- _ = sess.run(train_op,
- options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
- run_metadata=run_metadata)
-```
-
-Finally, you may run `print_model_analysis` to explore the timing and memory
-demands of the model.
-
-``` python
-# Print to stdout an analysis of the memory usage and the timing information
-# from running the graph broken down by operations.
-tf.contrib.tfprof.model_analyzer.print_model_analysis(
- tf.get_default_graph(),
- run_meta=run_metadata,
- tfprof_options=tf.contrib.tfprof.model_analyzer.PRINT_ALL_TIMING_MEMORY)
-```
-
-Users can change ```tfprof_options``` to fully leverage tfprof's power.
+ * stdout: Results are written to stdout.
+ * timeline: Visualized in chrome browser as time series.
+ * file: Results are dumped to file.
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py
index 13b407d815..17dff69edd 100644
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py
@@ -45,7 +45,7 @@ TRAINABLE_VARS_PARAMS_STAT_OPTIONS = {
'hide_name_regexes': [],
'account_displayed_op_only': True,
'select': ['params'],
- 'viz': False,
+ 'output': 'stdout',
'dump_to_file': ''
}
@@ -65,7 +65,7 @@ FLOAT_OPS_OPTIONS = {
'hide_name_regexes': [],
'account_displayed_op_only': True,
'select': ['float_ops'],
- 'viz': False,
+ 'output': 'stdout',
'dump_to_file': ''
}
@@ -87,7 +87,7 @@ PRINT_PARAMS_ON_DEVICE = {
'hide_name_regexes': [],
'account_displayed_op_only': False,
'select': ['device', 'params'],
- 'viz': False,
+ 'output': 'stdout',
'dump_to_file': ''
}
@@ -107,7 +107,7 @@ PRINT_ALL_TIMING_MEMORY = {
'hide_name_regexes': [],
'account_displayed_op_only': True,
'select': ['micros', 'bytes'],
- 'viz': False,
+ 'output': 'stdout',
'dump_to_file': ''
}
@@ -178,7 +178,7 @@ def print_model_analysis(graph,
opts.account_displayed_op_only = tfprof_options['account_displayed_op_only']
for p in tfprof_options['select']:
opts.select.append(p)
- opts.viz = tfprof_options['viz']
+ opts.output = tfprof_options['output']
opts.dump_to_file = tfprof_options['dump_to_file']
run_meta_str = run_meta.SerializeToString() if run_meta else b''
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py
index 55167576e2..afd8563e78 100644
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py
@@ -35,13 +35,14 @@ class PrintModelAnalysisTest(test.TestCase):
def testDumpToFile(self):
ops.reset_default_graph()
opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS
- opts['dump_to_file'] = os.path.join(test.get_temp_dir(), 'dump')
+ outfile = os.path.join(test.get_temp_dir(), 'dump')
+ opts['output'] = 'file:outfile=' + outfile
with session.Session() as sess, ops.device('/cpu:0'):
_ = lib.BuildSmallModel()
model_analyzer.print_model_analysis(sess.graph, tfprof_options=opts)
- with gfile.Open(opts['dump_to_file'], 'r') as f:
+ with gfile.Open(outfile, 'r') as f:
self.assertEqual(u'_TFProfRoot (--/451 params)\n'
' DW (3x3x3x6, 162/162 params)\n'
' DW2 (2x2x6x12, 288/288 params)\n'
@@ -51,15 +52,14 @@ class PrintModelAnalysisTest(test.TestCase):
def testSelectEverything(self):
ops.reset_default_graph()
opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS
- opts['dump_to_file'] = os.path.join(test.get_temp_dir(), 'dump')
+ outfile = os.path.join(test.get_temp_dir(), 'dump')
+ opts['output'] = 'file:outfile=' + outfile
opts['account_type_regexes'] = ['.*']
opts['select'] = [
'bytes', 'params', 'float_ops', 'num_hidden_ops', 'device', 'op_types'
]
- config = config_pb2.ConfigProto(
- graph_options=config_pb2.GraphOptions(build_cost_model=1))
- with session.Session(config=config) as sess, ops.device('/cpu:0'):
+ with session.Session() as sess, ops.device('/cpu:0'):
x = lib.BuildSmallModel()
sess.run(variables.global_variables_initializer())
@@ -72,17 +72,18 @@ class PrintModelAnalysisTest(test.TestCase):
model_analyzer.print_model_analysis(
sess.graph, run_meta, tfprof_options=opts)
- with gfile.Open(opts['dump_to_file'], 'r') as f:
+ with gfile.Open(outfile, 'r') as f:
# pylint: disable=line-too-long
self.assertEqual(
- '_TFProfRoot (0/451 params, 0/10.44k flops, 0B/5.28KB, _kTFScopeParent)\n Conv2D (0/0 params, 5.83k/5.83k flops, 432B/432B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Conv2D)\n Conv2D_1 (0/0 params, 4.61k/4.61k flops, 384B/384B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Conv2D)\n DW (3x3x3x6, 162/162 params, 0/0 flops, 648B/1.30KB, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|VariableV2|_trainable_variables)\n DW/Assign (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Assign)\n DW/Initializer (0/0 params, 0/0 flops, 0B/0B, _kTFScopeParent)\n DW/Initializer/random_normal (0/0 params, 0/0 flops, 0B/0B, Add)\n DW/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, 0B/0B, RandomStandardNormal)\n DW/Initializer/random_normal/mean (0/0 params, 0/0 flops, 0B/0B, Const)\n DW/Initializer/random_normal/mul (0/0 params, 0/0 flops, 0B/0B, Mul)\n DW/Initializer/random_normal/shape (0/0 params, 0/0 flops, 0B/0B, Const)\n DW/Initializer/random_normal/stddev (0/0 params, 0/0 flops, 0B/0B, Const)\n DW/read (0/0 params, 0/0 flops, 648B/648B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Identity)\n DW2 (2x2x6x12, 288/288 params, 0/0 flops, 1.15KB/2.30KB, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|VariableV2|_trainable_variables)\n DW2/Assign (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Assign)\n DW2/Initializer (0/0 params, 0/0 flops, 0B/0B, _kTFScopeParent)\n DW2/Initializer/random_normal (0/0 params, 0/0 flops, 0B/0B, Add)\n DW2/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, 0B/0B, RandomStandardNormal)\n DW2/Initializer/random_normal/mean (0/0 params, 0/0 flops, 0B/0B, Const)\n DW2/Initializer/random_normal/mul (0/0 params, 0/0 flops, 0B/0B, Mul)\n DW2/Initializer/random_normal/shape (0/0 params, 0/0 flops, 0B/0B, Const)\n DW2/Initializer/random_normal/stddev (0/0 params, 0/0 flops, 0B/0B, Const)\n DW2/read (0/0 params, 0/0 flops, 1.15KB/1.15KB, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Identity)\n ScalarW (1, 1/1 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|VariableV2|_trainable_variables)\n ScalarW/Assign (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Assign)\n ScalarW/Initializer (0/0 params, 0/0 flops, 0B/0B, _kTFScopeParent)\n ScalarW/Initializer/random_normal (0/0 params, 0/0 flops, 0B/0B, Add)\n ScalarW/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, 0B/0B, RandomStandardNormal)\n ScalarW/Initializer/random_normal/mean (0/0 params, 0/0 flops, 0B/0B, Const)\n ScalarW/Initializer/random_normal/mul (0/0 params, 0/0 flops, 0B/0B, Mul)\n ScalarW/Initializer/random_normal/shape (0/0 params, 0/0 flops, 0B/0B, Const)\n ScalarW/Initializer/random_normal/stddev (0/0 params, 0/0 flops, 0B/0B, Const)\n ScalarW/read (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Identity)\n init (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|NoOp)\n zeros (0/0 params, 0/0 flops, 864B/864B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Const)\n',
+ '_TFProfRoot (0/451 params, 0/10.44k flops, 0B/5.28KB, _kTFScopeParent)\n Conv2D (0/0 params, 5.83k/5.83k flops, 432B/432B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Conv2D)\n Conv2D_1 (0/0 params, 4.61k/4.61k flops, 384B/384B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Conv2D)\n DW (3x3x3x6, 162/162 params, 0/0 flops, 648B/1.30KB, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|VariableV2|_trainable_variables)\n DW/Assign (0/0 params, 0/0 flops, 0B/0B, Assign)\n DW/Initializer (0/0 params, 0/0 flops, 0B/0B, _kTFScopeParent)\n DW/Initializer/random_normal (0/0 params, 0/0 flops, 0B/0B, Add)\n DW/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, 0B/0B, RandomStandardNormal)\n DW/Initializer/random_normal/mean (0/0 params, 0/0 flops, 0B/0B, Const)\n DW/Initializer/random_normal/mul (0/0 params, 0/0 flops, 0B/0B, Mul)\n DW/Initializer/random_normal/shape (0/0 params, 0/0 flops, 0B/0B, Const)\n DW/Initializer/random_normal/stddev (0/0 params, 0/0 flops, 0B/0B, Const)\n DW/read (0/0 params, 0/0 flops, 648B/648B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Identity)\n DW2 (2x2x6x12, 288/288 params, 0/0 flops, 1.15KB/2.30KB, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|VariableV2|_trainable_variables)\n DW2/Assign (0/0 params, 0/0 flops, 0B/0B, Assign)\n DW2/Initializer (0/0 params, 0/0 flops, 0B/0B, _kTFScopeParent)\n DW2/Initializer/random_normal (0/0 params, 0/0 flops, 0B/0B, Add)\n DW2/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, 0B/0B, RandomStandardNormal)\n DW2/Initializer/random_normal/mean (0/0 params, 0/0 flops, 0B/0B, Const)\n DW2/Initializer/random_normal/mul (0/0 params, 0/0 flops, 0B/0B, Mul)\n DW2/Initializer/random_normal/shape (0/0 params, 0/0 flops, 0B/0B, Const)\n DW2/Initializer/random_normal/stddev (0/0 params, 0/0 flops, 0B/0B, Const)\n DW2/read (0/0 params, 0/0 flops, 1.15KB/1.15KB, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Identity)\n ScalarW (1, 1/1 params, 0/0 flops, 0B/0B, VariableV2|_trainable_variables)\n ScalarW/Assign (0/0 params, 0/0 flops, 0B/0B, Assign)\n ScalarW/Initializer (0/0 params, 0/0 flops, 0B/0B, _kTFScopeParent)\n ScalarW/Initializer/random_normal (0/0 params, 0/0 flops, 0B/0B, Add)\n ScalarW/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, 0B/0B, RandomStandardNormal)\n ScalarW/Initializer/random_normal/mean (0/0 params, 0/0 flops, 0B/0B, Const)\n ScalarW/Initializer/random_normal/mul (0/0 params, 0/0 flops, 0B/0B, Mul)\n ScalarW/Initializer/random_normal/shape (0/0 params, 0/0 flops, 0B/0B, Const)\n ScalarW/Initializer/random_normal/stddev (0/0 params, 0/0 flops, 0B/0B, Const)\n ScalarW/read (0/0 params, 0/0 flops, 0B/0B, Identity)\n init (0/0 params, 0/0 flops, 0B/0B, NoOp)\n zeros (0/0 params, 0/0 flops, 864B/864B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Const)\n',
f.read())
# pylint: enable=line-too-long
def testSimpleCodeView(self):
ops.reset_default_graph()
opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS.copy()
- opts['dump_to_file'] = os.path.join(test.get_temp_dir(), 'dump')
+ outfile = os.path.join(test.get_temp_dir(), 'dump')
+ opts['output'] = 'file:outfile=' + outfile
opts['account_type_regexes'] = ['.*']
opts['show_name_regexes'] = ['.*model_analyzer_testlib.*']
opts['account_displayed_op_only'] = False
@@ -92,9 +93,7 @@ class PrintModelAnalysisTest(test.TestCase):
'bytes', 'params', 'float_ops', 'num_hidden_ops', 'device',
]
- config = config_pb2.ConfigProto(
- graph_options=config_pb2.GraphOptions(build_cost_model=1))
- with session.Session(config=config) as sess, ops.device('/cpu:0'):
+ with session.Session() as sess, ops.device('/cpu:0'):
x = lib.BuildSmallModel()
sess.run(variables.global_variables_initializer())
@@ -107,7 +106,7 @@ class PrintModelAnalysisTest(test.TestCase):
model_analyzer.print_model_analysis(
sess.graph, run_meta, tfprof_cmd='code', tfprof_options=opts)
- with gfile.Open(opts['dump_to_file'], 'r') as f:
+ with gfile.Open(outfile, 'r') as f:
# pylint: disable=line-too-long
self.assertEqual(
'_TFProfRoot (0/451 params, 0/10.44k flops, 0B/5.28KB)\n model_analyzer_testlib.py:33:BuildSmallModel:image = array_ops... (0/0 params, 0/0 flops, 0B/864B)\n model_analyzer_testlib.py:37:BuildSmallModel:initializer=init_... (0/1 params, 0/0 flops, 0B/0B)\n model_analyzer_testlib.py:41:BuildSmallModel:initializer=init_... (0/162 params, 0/0 flops, 0B/1.30KB)\n model_analyzer_testlib.py:42:BuildSmallModel:x = nn_ops.conv2d... (0/0 params, 0/5.83k flops, 0B/432B)\n model_analyzer_testlib.py:46:BuildSmallModel:initializer=init_... (0/288 params, 0/0 flops, 0B/2.30KB)\n model_analyzer_testlib.py:47:BuildSmallModel:x = nn_ops.conv2d... (0/0 params, 0/4.61k flops, 0B/384B)\n',
@@ -117,15 +116,14 @@ class PrintModelAnalysisTest(test.TestCase):
def testComplexCodeView(self):
ops.reset_default_graph()
opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS.copy()
- opts['dump_to_file'] = os.path.join(test.get_temp_dir(), 'dump')
+ outfile = os.path.join(test.get_temp_dir(), 'dump')
+ opts['output'] = 'file:outfile=' + outfile
opts['account_type_regexes'] = ['.*']
opts['show_name_regexes'] = ['.*model_analyzer_testlib.py.*']
opts['account_displayed_op_only'] = False
opts['select'] = ['params', 'float_ops']
- config = config_pb2.ConfigProto(
- graph_options=config_pb2.GraphOptions(build_cost_model=1))
- with session.Session(config=config) as sess, ops.device('/cpu:0'):
+ with session.Session() as sess, ops.device('/cpu:0'):
x = lib.BuildFullModel()
sess.run(variables.global_variables_initializer())
@@ -139,7 +137,7 @@ class PrintModelAnalysisTest(test.TestCase):
sess.graph, run_meta, tfprof_cmd='code', tfprof_options=opts)
# pylint: disable=line-too-long
- with gfile.Open(opts['dump_to_file'], 'r') as f:
+ with gfile.Open(outfile, 'r') as f:
self.assertEqual(
'_TFProfRoot (0/2.84k params, 0/54.08k flops)\n model_analyzer_testlib.py:56:BuildFullModel:seq.append(array_... (0/1.80k params, 0/41.76k flops)\n model_analyzer_testlib.py:33:BuildSmallModel:image = array_ops... (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:37:BuildSmallModel:initializer=init_... (0/4 params, 0/0 flops)\n model_analyzer_testlib.py:41:BuildSmallModel:initializer=init_... (0/648 params, 0/0 flops)\n model_analyzer_testlib.py:42:BuildSmallModel:x = nn_ops.conv2d... (0/0 params, 0/23.33k flops)\n model_analyzer_testlib.py:46:BuildSmallModel:initializer=init_... (0/1.15k params, 0/0 flops)\n model_analyzer_testlib.py:47:BuildSmallModel:x = nn_ops.conv2d... (0/0 params, 0/18.43k flops)\n model_analyzer_testlib.py:60:BuildFullModel:cell, array_ops.c... (0/1.04k params, 0/4.13k flops)\n model_analyzer_testlib.py:62:BuildFullModel:target = array_op... (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:63:BuildFullModel:loss = nn_ops.l2_... (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:65:BuildFullModel:return sgd_op.min... (0/0 params, 0/8.19k flops)\n',
f.read())
@@ -170,9 +168,7 @@ class PrintModelAnalysisTest(test.TestCase):
'bytes', 'params', 'float_ops', 'num_hidden_ops', 'device'
]
- config = config_pb2.ConfigProto(
- graph_options=config_pb2.GraphOptions(build_cost_model=1))
- with session.Session(config=config) as sess, ops.device('/cpu:0'):
+ with session.Session() as sess, ops.device('/cpu:0'):
x = lib.BuildSmallModel()
sess.run(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/print_model_analysis_test.py b/tensorflow/contrib/tfprof/python/tools/tfprof/print_model_analysis_test.py
index aa133d3142..c3e9fc9cc0 100644
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/print_model_analysis_test.py
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/print_model_analysis_test.py
@@ -51,7 +51,7 @@ TEST_OPTIONS = {
'hide_name_regexes': [],
'account_displayed_op_only': True,
'select': ['params'],
- 'viz': False
+ 'output': 'stdout',
}
# pylint: enable=bad-whitespace
@@ -92,7 +92,7 @@ class PrintModelAnalysisTest(test.TestCase):
opts.account_displayed_op_only = TEST_OPTIONS['account_displayed_op_only']
for p in TEST_OPTIONS['select']:
opts.select.append(p)
- opts.viz = TEST_OPTIONS['viz']
+ opts.output = TEST_OPTIONS['output']
with session.Session() as sess, ops.device('/cpu:0'):
_ = self._BuildSmallModel()
@@ -116,7 +116,6 @@ class PrintModelAnalysisTest(test.TestCase):
total_exec_micros: 0
total_requested_bytes: 0
total_parameters: 0
- device: "/device:CPU:0"
float_ops: 0
total_float_ops: 0
}
@@ -128,7 +127,6 @@ class PrintModelAnalysisTest(test.TestCase):
total_exec_micros: 0
total_requested_bytes: 0
total_parameters: 648
- device: "/device:CPU:0"
children {
name: "DW/Assign"
exec_micros: 0
@@ -136,7 +134,6 @@ class PrintModelAnalysisTest(test.TestCase):
total_exec_micros: 0
total_requested_bytes: 0
total_parameters: 0
- device: "/device:CPU:0"
float_ops: 0
total_float_ops: 0
}
@@ -217,7 +214,6 @@ class PrintModelAnalysisTest(test.TestCase):
total_exec_micros: 0
total_requested_bytes: 0
total_parameters: 0
- device: "/device:CPU:0"
float_ops: 0
total_float_ops: 0
}
@@ -231,7 +227,6 @@ class PrintModelAnalysisTest(test.TestCase):
total_exec_micros: 0
total_requested_bytes: 0
total_parameters: 0
- device: "/device:CPU:0"
float_ops: 0
total_float_ops: 0
}
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py b/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py
index cd3912bbfb..e6d504d516 100644
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py
@@ -62,7 +62,7 @@ def _fill_missing_graph_shape(graph, run_meta):
return graph
-def _get_logged_ops(graph, run_meta=None, add_trace=False):
+def _get_logged_ops(graph, run_meta=None, add_trace=True):
"""Extract trainable model parameters and FLOPs for ops from a Graph.
Args:
@@ -120,9 +120,8 @@ def _get_logged_ops(graph, run_meta=None, add_trace=False):
return logged_ops
-def _merge_default_with_oplog(graph, op_log=None,
- run_meta=None,
- add_trace=False):
+def _merge_default_with_oplog(graph, op_log=None, run_meta=None,
+ add_trace=True):
"""Merge the tfprof default extra info with caller's op_log.
Args:
@@ -154,8 +153,7 @@ def _merge_default_with_oplog(graph, op_log=None,
return tmp_op_log
-def write_op_log(graph, log_dir, op_log=None, run_meta=None,
- add_trace=False):
+def write_op_log(graph, log_dir, op_log=None, run_meta=None, add_trace=True):
"""Log provided 'op_log', and add additional model information below.
The API also assigns ops in tf.trainable_variables() an op type called
diff --git a/tensorflow/tools/tfprof/README.md b/tensorflow/tools/tfprof/README.md
index 52d376d5f5..69f09411a9 100644
--- a/tensorflow/tools/tfprof/README.md
+++ b/tensorflow/tools/tfprof/README.md
@@ -1,6 +1,6 @@
# tfprof: A Profiling Tool for TensorFlow Models
-Author: Xin Pan (xpan@google.com, github: panyx0718)
+Author: Xin Pan (xpan@google.com, github: panyx0718), Jon Shlens, Yao Zhang
Consultants: Jon Shlens, Pete Warden
@@ -8,15 +8,22 @@ Consultants: Jon Shlens, Pete Warden
###Major Features
1. Measure model parameters, float operations, tensor shapes.
-2. Measure op execution times, requested memory size and device placement.
+2. Profile op execution times, requested memory size and device placement.
3. Inspect checkpoint tensors' shapes and their values.
-4. 3 ways to view and explore TensorFlow model profiles
+4. Selectively group, filter, account and order ops.
- * Organize by Python code call stack.
- * Organize by TensorFlow operation name scope hierarchies.
- * Organize by TensorFlow operation inputs/outputs graph.
+####tfprof supports 3 views to organize TensorFlow model profiles
+
+ * code view: Stats are associated your Python codes and organized as call stacks.
+ * scope view: Stats are organized as name scope hierarchies.
+ * graph view: Stats are organized as Tensorflow Op graph.
+
+####For each view, there are 3 ways to display outputs:
+
+ * stdout: Results are written to stdout.
+ * timeline: Visualized in chrome browser as time series.
+ * file: Results are dumped to file.
-5. Selectively grouping/filtering/accounting/ordering ops.
[Python API Tutorials](#python-api-tutorials): It can be called directly from
Python codes. Results are either printed
@@ -83,13 +90,11 @@ compute the memory and timing statistics.
#
# Note: When run on GPU, a kernel is first scheduled (enqueued) and then
# executed asynchronously. tfprof only tracks the execution time.
-# Which is from proto CostGraphDef::Node::compute_cost.
# In addition, a substantial of time might be spent between Python and
# TensorFlow runtime, which is also not tracked by tfprof.
#
-config = tf.ConfigProto(graph_options=tf.GraphOptions(build_cost_model=1))
run_metadata = tf.RunMetadata()
-with tf.Session(config=config) as sess:
+with tf.Session() as sess:
_ = sess.run(train_op,
options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
run_metadata=run_metadata)
@@ -121,6 +126,18 @@ tf.contrib.tfprof.model_analyzer.print_model_analysis(
Users can change ```tfprof_options``` to fully leverage tfprof's power.
+```
+For example set opts['output'] = 'timeline:outfile=<filename>' to
+generate a timeline json file. Open a Chrome Browser, open URL
+chrome://tracing, and load the json file. Below are 2 examples of graph
+view and scope view. See code view example in later examples.
+```
+
+<left>
+![CodeTimeline](g3doc/graph_timeline.png)
+![CodeTimeline](g3doc/scope_timeline.png)
+</left>
+
## CLI Tutorials
@@ -197,8 +214,14 @@ tfprof>
# supported select fileds. Availability depends on --[run_meta|checkpoint|op_log]_path.
# [bytes|micros|params|float_ops|num_hidden_ops|tensor_value|device|op_types]
-select params
--viz false
--dump_to_file
+# format: output_type:key=value,key=value...
+# output_types: stdout (default), timeline, file.
+# key=value pairs:
+# 1. timeline: outfile=<filename>
+# 2. file: outfile=<filename>
+# 3. stdout: None.
+# E.g. timeline:outfile=/tmp/timeline.json
+-output
```
3) I want to see which line of my python codes costs most time!
@@ -222,6 +245,12 @@ _TFProfRoot (0us/22.44ms)
model_analyzer_test.py:134:testComplexCodeView:sess.run(variable... (0us/0us)
```
+Set ```-output timeline:outfile=<filename>``` to generate timeline instead of stdout.
+<left>
+![CodeTimeline](g3doc/code_timeline.png)
+</left>
+
+
4) I want to see the `BatchNorm`'s gamma value in checkpoint.
```shell
@@ -423,10 +452,10 @@ the tool adds all `Variables` inside `tf.trainable_variables()` to
12) Run tfprof in one-shot mode and dump result to file.
```shell
-# Printed to stdout if --dump_to_file is not set.
+# By default output to stdout. Use -output option to change output types.
tfprof scope --graph_path=graph.pbtxt \
--max_depth=3 \
- --dump_to_file="/tmp/dump"
+ --output="file:outfile=/tmp/dump"
Reading Files...
Parsing GraphDef...
Preparing Views...
@@ -538,4 +567,9 @@ as long as they match the `-account_xxx` options.
`-select`: Comma-separated list of metrics to show: [bytes|micros|params|float_ops|num_hidden_ops|tensor_value|device|op_types].
-`-dump_to_file`: Dump the output to a file, instead of terminal.
+`-output`: Output results as stdout, file or timeline.
+The format is ```output_type:key=value,key=value```.
+For example: ```timeline:outfile=<filename>```.
+timeline: key=outfile, value=<filename>.
+stdout: none.
+file: key=outfile, value=<filename>.
diff --git a/tensorflow/tools/tfprof/g3doc/code_timeline.png b/tensorflow/tools/tfprof/g3doc/code_timeline.png
new file mode 100644
index 0000000000..c5ab246f7d
--- /dev/null
+++ b/tensorflow/tools/tfprof/g3doc/code_timeline.png
Binary files differ
diff --git a/tensorflow/tools/tfprof/g3doc/graph_timeline.png b/tensorflow/tools/tfprof/g3doc/graph_timeline.png
new file mode 100644
index 0000000000..255a91fd5f
--- /dev/null
+++ b/tensorflow/tools/tfprof/g3doc/graph_timeline.png
Binary files differ
diff --git a/tensorflow/tools/tfprof/g3doc/scope_timeline.png b/tensorflow/tools/tfprof/g3doc/scope_timeline.png
new file mode 100644
index 0000000000..c6d95af84a
--- /dev/null
+++ b/tensorflow/tools/tfprof/g3doc/scope_timeline.png
Binary files differ
diff --git a/tensorflow/tools/tfprof/internal/BUILD b/tensorflow/tools/tfprof/internal/BUILD
index adace89985..e90f0ec40a 100644
--- a/tensorflow/tools/tfprof/internal/BUILD
+++ b/tensorflow/tools/tfprof/internal/BUILD
@@ -21,6 +21,7 @@ cc_library(
":tfprof_options",
":tfprof_scope",
":tfprof_show",
+ ":tfprof_timeline",
":tfprof_utils",
"//tensorflow/c:checkpoint_reader",
"//tensorflow/core:lib",
@@ -31,6 +32,20 @@ cc_library(
)
cc_library(
+ name = "tfprof_timeline",
+ srcs = ["tfprof_timeline.cc"],
+ hdrs = ["tfprof_timeline.h"],
+ deps = [
+ ":tfprof_node_show",
+ ":tfprof_utils",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/tools/tfprof:protos_all_cc",
+ "@jsoncpp_git//:jsoncpp",
+ ],
+)
+
+cc_library(
name = "tfprof_node",
srcs = ["tfprof_node.cc"],
hdrs = ["tfprof_node.h"],
@@ -71,7 +86,7 @@ cc_library(
":tfprof_node",
":tfprof_options",
":tfprof_show_code",
- ":tfprof_tensor",
+ ":tfprof_timeline",
":tfprof_utils",
"//tensorflow/c:c_api",
"//tensorflow/c:checkpoint_reader",
@@ -93,6 +108,7 @@ cc_library(
":tfprof_options",
":tfprof_show",
":tfprof_tensor",
+ ":tfprof_timeline",
":tfprof_utils",
"//tensorflow/c:checkpoint_reader",
"//tensorflow/core:lib",
@@ -103,14 +119,31 @@ cc_library(
)
cc_library(
+ name = "tfprof_node_show",
+ srcs = ["tfprof_node_show.cc"],
+ hdrs = ["tfprof_node_show.h"],
+ deps = [
+ ":tfprof_constants",
+ ":tfprof_node",
+ ":tfprof_options",
+ ":tfprof_utils",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/tools/tfprof:protos_all_cc",
+ ],
+)
+
+cc_library(
name = "tfprof_show",
srcs = ["tfprof_show.cc"],
hdrs = ["tfprof_show.h"],
deps = [
":tfprof_constants",
":tfprof_node",
+ ":tfprof_node_show",
":tfprof_options",
":tfprof_tensor",
+ ":tfprof_timeline",
":tfprof_utils",
"//tensorflow/c:checkpoint_reader",
"//tensorflow/core:lib",
@@ -127,10 +160,12 @@ cc_library(
deps = [
":tfprof_constants",
":tfprof_node",
+ ":tfprof_node_show",
":tfprof_options",
":tfprof_scope",
":tfprof_show",
":tfprof_tensor",
+ ":tfprof_timeline",
":tfprof_utils",
"//tensorflow/c:checkpoint_reader",
"//tensorflow/core:lib",
@@ -166,6 +201,27 @@ tf_cc_test(
],
)
+tf_cc_test(
+ name = "tfprof_timeline_test",
+ srcs = ["tfprof_timeline_test.cc"],
+ data = [
+ "testdata/graph.pbtxt",
+ "testdata/run_meta",
+ ],
+ deps = [
+ ":tfprof_constants",
+ ":tfprof_options",
+ ":tfprof_stats",
+ ":tfprof_utils",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/tools/tfprof:protos_all_cc",
+ ],
+)
+
cc_library(
name = "tfprof_utils",
srcs = ["tfprof_utils.cc"],
diff --git a/tensorflow/tools/tfprof/internal/print_model_analysis.cc b/tensorflow/tools/tfprof/internal/print_model_analysis.cc
index c816e3209e..f73675e8a7 100644
--- a/tensorflow/tools/tfprof/internal/print_model_analysis.cc
+++ b/tensorflow/tools/tfprof/internal/print_model_analysis.cc
@@ -56,11 +56,14 @@ string PrintModelAnalysis(const string* graph, const string* run_meta,
TFStats tf_stats(std::move(graph_ptr), std::move(run_meta_ptr),
std::move(op_log_ptr), std::move(ckpt_reader));
- Options opts = Options::FromProtoStr(*options);
+ Options opts;
+ tensorflow::Status s = Options::FromProtoStr(*options, &opts);
+ if (!s.ok()) {
+ fprintf(stderr, "%s\n", s.ToString().c_str());
+ return "";
+ }
- // TODO(xpan): We should have dump_to_file/print_stdout/etc to control
- // side-effects independently instead of one controlling the other.
- if (opts.dump_to_file.empty()) {
+ if (opts.output_type == kOutput[1]) {
printf("\n=========================Options=============================\n");
printf("%s", opts.ToString().c_str());
printf("\n==================Model Analysis Report======================\n");
diff --git a/tensorflow/tools/tfprof/internal/tfprof_code.cc b/tensorflow/tools/tfprof/internal/tfprof_code.cc
index 5618a52286..9739db1e0b 100644
--- a/tensorflow/tools/tfprof/internal/tfprof_code.cc
+++ b/tensorflow/tools/tfprof/internal/tfprof_code.cc
@@ -25,7 +25,6 @@ limitations under the License.
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/tools/tfprof/internal/tfprof_constants.h"
-#include "tensorflow/tools/tfprof/internal/tfprof_tensor.h"
namespace tensorflow {
namespace tfprof {
@@ -99,7 +98,8 @@ CodeNode* TFCode::BuildCodeNodes(TFCodeNode* root) {
return code_root_ptr;
}
-const ShowCodeNode* TFCode::ShowInternal(const Options& opts) {
+const ShowCodeNode* TFCode::ShowInternal(const Options& opts,
+ Timeline* timeline) {
// Search from roots recursively to find start node, if start_name_regexes
// is specified.
tfprof_trace_root_.reset(new TFCodeNode(kTFProfRoot));
@@ -117,7 +117,11 @@ const ShowCodeNode* TFCode::ShowInternal(const Options& opts) {
tfprof_code_root_->children.assign(roots.begin(), roots.end());
Account({tfprof_code_root_.get()}, opts);
- return PrintScope({tfprof_code_root_.get()}, opts, 1, 0)[0];
+ CodeNode* root = PrintScope({tfprof_code_root_.get()}, opts, 1, 0)[0];
+ if (timeline) {
+ timeline->GenerateCodeTimeline(root);
+ }
+ return root;
}
std::vector<CodeNode*> TFCode::SearchRoot(std::vector<CodeNode*> roots,
@@ -170,8 +174,13 @@ std::vector<CodeNode*> TFCode::PrintScope(const std::vector<CodeNode*> roots,
show_cnodes = SortNodes(show_cnodes, opts);
string children_str;
for (CodeNode* sc : show_cnodes) {
- children_str += sc->formatted_str;
+ if (opts.output_type == kOutput[1] || opts.output_type == kOutput[2]) {
+ children_str += sc->formatted_str;
+ sc->formatted_str.clear();
+ }
node->mutable_proto()->add_children()->MergeFrom(sc->proto());
+ sc->mutable_proto()->mutable_children()->Clear();
+ node->show_children.push_back(sc);
if (opts.account_displayed_op_only) {
node->AggregateTotalStats(sc);
}
diff --git a/tensorflow/tools/tfprof/internal/tfprof_code.h b/tensorflow/tools/tfprof/internal/tfprof_code.h
index 7c517e6b0e..d7a28624f1 100644
--- a/tensorflow/tools/tfprof/internal/tfprof_code.h
+++ b/tensorflow/tools/tfprof/internal/tfprof_code.h
@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/tools/tfprof/internal/tfprof_node.h"
#include "tensorflow/tools/tfprof/internal/tfprof_options.h"
#include "tensorflow/tools/tfprof/internal/tfprof_show_code.h"
+#include "tensorflow/tools/tfprof/internal/tfprof_timeline.h"
#include "tensorflow/tools/tfprof/internal/tfprof_utils.h"
#include "tensorflow/tools/tfprof/tfprof_log.pb.h"
#include "tensorflow/tools/tfprof/tfprof_output.pb.h"
@@ -37,22 +38,6 @@ limitations under the License.
namespace tensorflow {
namespace tfprof {
-class CodeNode : public ShowCodeNode {
- public:
- explicit CodeNode(const TFCodeNode* node) : ShowCodeNode(node) {}
- ~CodeNode() override {}
-
- void AggregateTotalStats(CodeNode* node) {
- ShowCodeNode::AggregateTotalStats(node);
- }
-
- void AddSelfToTotalStats() { ShowCodeNode::AddSelfToTotalStats(); }
-
- void ResetTotalStats() { ShowCodeNode::ResetTotalStats(); }
-
- std::vector<CodeNode*> children;
-};
-
class TFCode : public TFShowCode {
public:
explicit TFCode() : code_root_(nullptr), trace_root_(nullptr) {}
@@ -65,7 +50,8 @@ class TFCode : public TFShowCode {
private:
CodeNode* BuildCodeNodes(TFCodeNode* root);
- const ShowCodeNode* ShowInternal(const Options& opts) override;
+ const ShowCodeNode* ShowInternal(const Options& opts,
+ Timeline* timeline) override;
std::vector<CodeNode*> SearchRoot(std::vector<CodeNode*> roots,
const std::vector<string>& regexes);
diff --git a/tensorflow/tools/tfprof/internal/tfprof_graph.cc b/tensorflow/tools/tfprof/internal/tfprof_graph.cc
index 1623d9f8c4..23084146c2 100644
--- a/tensorflow/tools/tfprof/internal/tfprof_graph.cc
+++ b/tensorflow/tools/tfprof/internal/tfprof_graph.cc
@@ -66,7 +66,7 @@ void TFGraph::Build() {
}
}
-const ShowNode* TFGraph::ShowInternal(const Options& opts) {
+const ShowNode* TFGraph::ShowInternal(const Options& opts, Timeline* timeline) {
// Search the nodes to start from.
std::vector<GraphNode*> roots = roots_;
if (opts.start_name_regexes.size() != 1 ||
@@ -81,11 +81,13 @@ const ShowNode* TFGraph::ShowInternal(const Options& opts) {
std::map<string, int64> account_visits;
Account({root}, opts, &account_visits);
- if (opts.viz) {
- printf("Visualizing feature disabled...\n");
- }
std::set<string> visits;
- return PrintGraph({root}, opts, 1, 0, 0, &visits)[0];
+ root = PrintGraph({root}, opts, 1, 0, 0, &visits)[0];
+
+ if (timeline) {
+ timeline->GenerateGraphTimeline(root);
+ }
+ return root;
}
std::vector<GraphNode*> TFGraph::SearchRoot(
@@ -155,8 +157,14 @@ std::vector<GraphNode*> TFGraph::PrintGraph(const std::vector<GraphNode*> roots,
show_cnodes = SortNodes(show_cnodes, opts);
string children_str;
for (GraphNode* sc : show_cnodes) {
- children_str += sc->formatted_str;
- node->mutable_proto()->add_children()->MergeFrom(sc->proto());
+ if (opts.output_type == kOutput[1] || opts.output_type == kOutput[2]) {
+ children_str += sc->formatted_str;
+ sc->formatted_str.clear();
+ }
+ // This swap and reinit pattern is critical for performance.
+ node->mutable_proto()->add_children()->Swap(sc->mutable_proto());
+ sc->ReInit();
+ node->show_children.push_back(sc);
if (opts.account_displayed_op_only) {
node->AggregateTotalStats(sc);
}
@@ -181,7 +189,6 @@ std::vector<GraphNode*> TFGraph::PrintGraph(const std::vector<GraphNode*> roots,
node->formatted_str += value_str;
}
}
-
node->formatted_str += children_str;
show_nodes.push_back(node);
} else {
diff --git a/tensorflow/tools/tfprof/internal/tfprof_graph.h b/tensorflow/tools/tfprof/internal/tfprof_graph.h
index 75979d020c..4d4aa8b2b1 100644
--- a/tensorflow/tools/tfprof/internal/tfprof_graph.h
+++ b/tensorflow/tools/tfprof/internal/tfprof_graph.h
@@ -37,32 +37,6 @@ limitations under the License.
namespace tensorflow {
namespace tfprof {
-class GraphNode : public ShowNode {
- public:
- explicit GraphNode(TFGraphNode* node) : ShowNode(node) {
- mutable_proto()->set_inputs(node->inputs().size());
- mutable_proto()->set_total_inputs(0);
- }
-
- void AggregateTotalStats(GraphNode* node) {
- ShowNode::AggregateTotalStats(node);
- mutable_proto()->set_total_inputs(proto().total_inputs() +
- node->proto().total_inputs() + 1);
- }
-
- void AddSelfToTotalStats() {
- ShowNode::AddSelfToTotalStats();
- mutable_proto()->set_total_inputs(proto().total_inputs() +
- proto().inputs());
- }
-
- void ResetTotalStats() {
- ShowNode::ResetTotalStats();
- mutable_proto()->set_total_inputs(0);
- }
-
- std::vector<GraphNode*> children;
-};
// Organize tensorflow ops in a graph structure, pointing from output ops
// to input ops.
@@ -77,7 +51,8 @@ class TFGraph : public TFShow {
void Build() override;
private:
- const ShowNode* ShowInternal(const Options& opts) override;
+ const ShowNode* ShowInternal(const Options& opts,
+ Timeline* timeline) override;
bool ShouldShowIfExtra(ShowNode* node, const Options& opts,
int depth) override {
diff --git a/tensorflow/tools/tfprof/internal/tfprof_node.cc b/tensorflow/tools/tfprof/internal/tfprof_node.cc
index 5f018addb4..74c8fcbe48 100644
--- a/tensorflow/tools/tfprof/internal/tfprof_node.cc
+++ b/tensorflow/tools/tfprof/internal/tfprof_node.cc
@@ -20,20 +20,22 @@ limitations under the License.
namespace tensorflow {
namespace tfprof {
+// Notes about start and end time from the NodeExecStats proto.
+// For GPU, there is no difference between op_end_rel_micros and
+// all_end_rel_micros. All are kernel times.
+// For CPU, op_end_rel is the kernel time, while all_end_rel_micros includes
+// some post-processing.
+// Here, we only consider kernel time for simplicity.
void TFGraphNode::AddStepStat(const string& device,
const NodeExecStats* step_stat) {
- if (!device.empty()) {
- // This might override device from GraphDef.
- device_ = device;
- }
step_stat_ = step_stat;
+ CHECK(step_stat_);
- op_start_micros_ = step_stat_->all_start_micros();
- if (step_stat_->op_end_rel_micros() && step_stat_->op_start_rel_micros()) {
- op_schedule_micros_ =
- step_stat_->op_end_rel_micros() - step_stat_->op_start_rel_micros();
- }
- all_spent_micros_ = step_stat_->all_end_rel_micros();
+ string dev = str_util::Lowercase(device);
+
+ devices_.insert(dev);
+ op_kernel_execs_[dev].push_back(std::make_pair(
+ step_stat_->all_start_micros(), step_stat_->op_end_rel_micros()));
for (const auto& output : step_stat_->output()) {
if (output.has_tensor_description() &&
@@ -44,9 +46,5 @@ void TFGraphNode::AddStepStat(const string& device,
}
}
}
-
-void TFGraphNode::AddNodeStat(const CostGraphDef::Node* cost_node) {
- kernel_compute_micros_ = cost_node->compute_cost();
-}
} // namespace tfprof
} // namespace tensorflow
diff --git a/tensorflow/tools/tfprof/internal/tfprof_node.h b/tensorflow/tools/tfprof/internal/tfprof_node.h
index 235904ea6c..8e57db7ba2 100644
--- a/tensorflow/tools/tfprof/internal/tfprof_node.h
+++ b/tensorflow/tools/tfprof/internal/tfprof_node.h
@@ -23,12 +23,12 @@ limitations under the License.
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
-#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/tools/tfprof/internal/tfprof_options.h"
#include "tensorflow/tools/tfprof/tfprof_log.pb.h"
@@ -41,10 +41,6 @@ class TFGraphNode {
: node_(node),
code_(nullptr),
step_stat_(nullptr),
- op_start_micros_(0),
- op_schedule_micros_(0),
- kernel_compute_micros_(0),
- all_spent_micros_(0),
requested_bytes_(0),
float_ops_(0) {
if (!node) return;
@@ -69,38 +65,51 @@ class TFGraphNode {
update_shape(shape_vec);
}
op_types_.insert(node->op());
- device_ = node->device();
}
TFGraphNode() : TFGraphNode(nullptr) {}
- void AddInput(TFGraphNode* input) { inputs_[input->name()] = input; }
+ void AddInput(TFGraphNode* input, int64 output_idx) {
+ inputs_[input->name()] = input;
+ output_idx_[input->name()] = output_idx;
+ }
void AddOpType(const string& op_type) { op_types_.insert(op_type); }
void AddStepStat(const string& device, const NodeExecStats* step_stat);
- // Add CostGraphDef::Node.
- void AddNodeStat(const CostGraphDef::Node* cost_node);
-
void AddFloatOps(int64 float_ops) { float_ops_ = float_ops; }
void AddCode(const CodeDef* code) { code_ = code; }
const string& name() const { return node_->name(); }
const NodeDef* node_def() { return node_; }
+
+ const NodeExecStats* step_stats() const { return step_stat_; }
+
const std::map<string, TFGraphNode*>& inputs() const { return inputs_; }
- int64 op_start_micros() { return op_start_micros_; }
- // This is time spent in Op::Compute(), which is GPU kernel schedule time.
- // Currently not used.
- int64 op_schedule_micros() { return op_schedule_micros_; }
+ const std::map<string, int64>& output_idx() { return output_idx_; }
+
// This is time spent in kernel execution.
- int64 kernel_compute_micros() const { return kernel_compute_micros_; }
- int64 all_spent_micros() { return all_spent_micros_; }
+ int64 kernel_exec_micros() const {
+ if (!step_stat_) return 0;
+ int64 total = 0;
+ for (const auto& execs : op_kernel_execs_) {
+ for (const auto& exec : execs.second) {
+ total += exec.second;
+ }
+ }
+ return total;
+ }
+ const std::map<string, std::vector<std::pair<int64, int64>>>&
+ op_kernel_execs() const {
+ return op_kernel_execs_;
+ }
+
int64 requested_bytes() const { return requested_bytes_; }
int64 float_ops() const { return float_ops_; }
const CodeDef* code() { return code_; }
- string device() const { return device_; }
+ std::set<string> devices() const { return devices_; }
const std::set<string>& op_types() const { return op_types_; }
const std::vector<int64>& shape() const { return shape_; }
@@ -109,17 +118,19 @@ class TFGraphNode {
void update_shape(const std::vector<int64>& shape) { shape_ = shape; }
std::map<string, TFGraphNode*> inputs_;
+ std::map<string, int64> output_idx_;
+
const NodeDef* node_;
const CodeDef* code_;
const NodeExecStats* step_stat_;
std::vector<int64> shape_;
std::set<string> op_types_;
- string device_;
- int64 op_start_micros_;
- int64 op_schedule_micros_;
- int64 kernel_compute_micros_;
- int64 all_spent_micros_;
+
+ // device -> vector of {op_start_micros, op_kernel_exec_micros} pairs.
+ std::map<string, std::vector<std::pair<int64, int64>>> op_kernel_execs_;
+
+ std::set<string> devices_;
int64 requested_bytes_;
int64 float_ops_;
};
@@ -128,7 +139,7 @@ class TFCodeNode {
public:
TFCodeNode(const string& trace)
: trace_(trace),
- kernel_compute_micros_(0),
+ kernel_exec_micros_(0),
requested_bytes_(0),
float_ops_(0) {}
@@ -138,16 +149,15 @@ class TFCodeNode {
}
nodes_[node->name()] = node;
- kernel_compute_micros_ += node->kernel_compute_micros();
+ kernel_exec_micros_ += node->kernel_exec_micros();
requested_bytes_ += node->requested_bytes();
float_ops_ += node->float_ops();
op_types_.insert(node->op_types().begin(), node->op_types().end());
if (node->shape().size() > 0) {
shapes_.push_back(node->shape());
}
- if (!node->device().empty()) {
- devices_.insert(node->device());
- }
+ std::set<string> devices = node->devices();
+ devices_.insert(devices.begin(), devices.end());
}
const std::map<string, const TFGraphNode*>& graph_nodes() const {
return nodes_;
@@ -165,7 +175,7 @@ class TFCodeNode {
const string& name() const { return trace_; }
- int64 kernel_compute_micros() const { return kernel_compute_micros_; }
+ int64 kernel_exec_micros() const { return kernel_exec_micros_; }
int64 requested_bytes() const { return requested_bytes_; }
@@ -180,7 +190,7 @@ class TFCodeNode {
private:
const string trace_;
std::set<string> op_types_;
- int64 kernel_compute_micros_;
+ int64 kernel_exec_micros_;
int64 requested_bytes_;
int64 float_ops_;
diff --git a/tensorflow/tools/tfprof/internal/tfprof_node_show.cc b/tensorflow/tools/tfprof/internal/tfprof_node_show.cc
new file mode 100644
index 0000000000..2b5390676d
--- /dev/null
+++ b/tensorflow/tools/tfprof/internal/tfprof_node_show.cc
@@ -0,0 +1,296 @@
+/* Copyright 2016 The TensorFlow Authors All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/tools/tfprof/internal/tfprof_node_show.h"
+
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+
+namespace tensorflow {
+namespace tfprof {
+ShowNode::ShowNode(const TFGraphNode* node) : node(node), account(true) {
+ ReInit();
+}
+
+void ShowNode::ReInit() {
+ mutable_proto()->set_name(name());
+ for (const string& device : node->devices()) {
+ *mutable_proto()->mutable_devices()->Add() = device;
+ }
+ mutable_proto()->set_exec_micros(node->kernel_exec_micros());
+ mutable_proto()->set_requested_bytes(node->requested_bytes());
+ mutable_proto()->set_float_ops(node->float_ops());
+
+ if (!node->shape().empty()) {
+ int64 params = 1;
+ bool complete_shape = true;
+ for (int64 d : node->shape()) {
+ // Sometimes parameters could be <0 when a dim is unknown.
+ if (d < 0) {
+ complete_shape = false;
+ break;
+ }
+ params *= d;
+ }
+ if (complete_shape) {
+ mutable_proto()->set_parameters(proto_.parameters() + params);
+ } else {
+ fprintf(stderr, "Incomplete shape.");
+ }
+ }
+}
+
+string ShowNode::Format(const Options& opts) {
+ if (opts.select.empty()) {
+ return name();
+ }
+ return strings::Printf("%s (%s)", name().c_str(), FormatMeta(opts).c_str());
+}
+
+string ShowNode::FormatMeta(const Options& opts) {
+ std::vector<string> info;
+ if (opts.select.find(kShown[2]) != opts.select.end()) {
+ const string shape = FormatShapes(node->shape());
+ if (!shape.empty()) {
+ info.push_back(shape);
+ }
+ string params = FormatNumber(proto().total_parameters()) + " params";
+ if (account) {
+ params = FormatNumber(proto().parameters()) + "/" + params;
+ } else {
+ params = "--/" + params;
+ }
+ info.push_back(params);
+ }
+ if (opts.select.find(kShown[3]) != opts.select.end()) {
+ string fops = FormatNumber(proto().total_float_ops()) + " flops";
+ if (account) {
+ fops = FormatNumber(proto().float_ops()) + "/" + fops;
+ } else {
+ fops = "--/" + fops;
+ }
+ info.push_back(fops);
+ }
+ if (opts.select.find(kShown[0]) != opts.select.end()) {
+ string memory = FormatMemory(proto().total_requested_bytes());
+ if (account) {
+ memory = FormatMemory(proto().requested_bytes()) + "/" + memory;
+
+ } else {
+ memory = "--/" + memory;
+ }
+ info.push_back(memory);
+ }
+ if (opts.select.find(kShown[1]) != opts.select.end()) {
+ string time = FormatTime(proto().total_exec_micros());
+ if (account) {
+ time = FormatTime(proto().exec_micros()) + "/" + time;
+ } else {
+ time = "--/" + time;
+ }
+ info.push_back(time);
+ }
+ if (opts.select.find(kShown[6]) != opts.select.end()) {
+ if (proto().devices_size() > 0) {
+ info.push_back(str_util::Join(proto().devices(), "|"));
+ }
+ }
+ if (opts.select.find(kShown[7]) != opts.select.end()) {
+ std::set<string> op_types = node->op_types();
+ // Device is considered a type.
+ if (proto().devices_size() > 0) {
+ op_types.insert(str_util::Join(proto().devices(), "|"));
+ }
+ info.push_back(str_util::Join(op_types, "|"));
+ }
+ return str_util::Join(info, ", ");
+}
+
+TFGraphNodeProto* ShowNode::mutable_proto() { return &proto_; }
+
+const TFGraphNodeProto& ShowNode::proto() const { return proto_; }
+
+void ShowNode::AggregateTotalStats(ShowNode* node) {
+ TFGraphNodeProto* node_pb = node->mutable_proto();
+ mutable_proto()->set_total_exec_micros(proto().total_exec_micros() +
+ node_pb->total_exec_micros());
+ mutable_proto()->set_total_requested_bytes(proto().total_requested_bytes() +
+ node_pb->total_requested_bytes());
+ mutable_proto()->set_total_parameters(proto().total_parameters() +
+ node_pb->total_parameters());
+ mutable_proto()->set_total_float_ops(proto().total_float_ops() +
+ node_pb->total_float_ops());
+}
+
+void ShowNode::AddSelfToTotalStats() {
+ mutable_proto()->set_total_exec_micros(proto().total_exec_micros() +
+ proto().exec_micros());
+ mutable_proto()->set_total_requested_bytes(proto().total_requested_bytes() +
+ proto().requested_bytes());
+ mutable_proto()->set_total_parameters(proto().total_parameters() +
+ proto().parameters());
+ mutable_proto()->set_total_float_ops(proto().total_float_ops() +
+ proto().float_ops());
+}
+
+void ShowNode::ResetTotalStats() {
+ mutable_proto()->set_total_exec_micros(0);
+ mutable_proto()->set_total_requested_bytes(0);
+ mutable_proto()->set_total_parameters(0);
+ mutable_proto()->set_total_float_ops(0);
+ mutable_proto()->mutable_children()->Clear();
+}
+
+ShowCodeNode::ShowCodeNode(const TFCodeNode* node) : node(node), account(true) {
+ std::vector<ScopeNode> snodes;
+ for (auto it : node->graph_nodes()) {
+ ScopeNode snode(it.second);
+ snodes.push_back(snode);
+ snodes[snodes.size() - 1].AddSelfToTotalStats();
+ *mutable_proto()->mutable_graph_nodes()->Add() =
+ snodes[snodes.size() - 1].proto();
+ }
+
+ mutable_proto()->set_name(name());
+ mutable_proto()->set_exec_micros(node->kernel_exec_micros());
+ mutable_proto()->set_requested_bytes(node->requested_bytes());
+ mutable_proto()->set_float_ops(node->float_ops());
+
+ if (!node->shapes().empty()) {
+ for (const std::vector<int64>& shape : node->shapes()) {
+ int64 params = 1;
+ bool complete_shape = true;
+ for (int64 d : shape) {
+ // Sometimes parameters could be <0 when a dim is unknown.
+ if (d < 0) {
+ complete_shape = false;
+ break;
+ }
+ params *= d;
+ }
+ if (complete_shape) {
+ mutable_proto()->set_parameters(proto().parameters() + params);
+ } else {
+ fprintf(stderr, "Incomplete shape.");
+ }
+ }
+ }
+}
+
+string ShowCodeNode::Format(const Options& opts) {
+ if (opts.select.empty()) {
+ return name();
+ }
+ return strings::Printf("%s (%s)", name().c_str(), FormatMeta(opts).c_str());
+}
+
+string ShowCodeNode::FormatMeta(const Options& opts) {
+ std::vector<string> info;
+ std::vector<string> shapes;
+ if (opts.select.find(kShown[2]) != opts.select.end()) {
+ for (const std::vector<int64>& shape : node->shapes()) {
+ if (!shape.empty()) {
+ shapes.push_back(FormatShapes(shape));
+ }
+ }
+ if (!shapes.empty()) {
+ info.push_back(str_util::Join(shapes, "|"));
+ }
+ string params = FormatNumber(proto().total_parameters()) + " params";
+ if (account) {
+ params = FormatNumber(proto().parameters()) + "/" + params;
+ } else {
+ params = "--/" + params;
+ }
+ info.push_back(params);
+ }
+ if (opts.select.find(kShown[3]) != opts.select.end()) {
+ string fops = FormatNumber(proto().total_float_ops()) + " flops";
+ if (account) {
+ fops = FormatNumber(proto().float_ops()) + "/" + fops;
+ } else {
+ fops = "--/" + fops;
+ }
+ info.push_back(fops);
+ }
+ if (opts.select.find(kShown[0]) != opts.select.end()) {
+ string memory = FormatMemory(proto().total_requested_bytes());
+ if (account) {
+ memory = FormatMemory(proto().requested_bytes()) + "/" + memory;
+
+ } else {
+ memory = "--/" + memory;
+ }
+ info.push_back(memory);
+ }
+ if (opts.select.find(kShown[1]) != opts.select.end()) {
+ string time = FormatTime(proto().total_exec_micros());
+ if (account) {
+ time = FormatTime(proto().exec_micros()) + "/" + time;
+ } else {
+ time = "--/" + time;
+ }
+ info.push_back(time);
+ }
+ if (opts.select.find(kShown[6]) != opts.select.end()) {
+ if (!node->devices().empty()) {
+ info.push_back(str_util::Join(node->devices(), "|"));
+ }
+ }
+ if (opts.select.find(kShown[7]) != opts.select.end()) {
+ std::set<string> op_types = node->op_types();
+ // Device is considered a type.
+ op_types.insert(node->devices().cbegin(), node->devices().cend());
+ info.push_back(str_util::Join(op_types, "|"));
+ }
+ return str_util::Join(info, ", ");
+}
+
+TFCodeNodeProto* ShowCodeNode::mutable_proto() { return &proto_; }
+
+const TFCodeNodeProto& ShowCodeNode::proto() const { return proto_; }
+
+void ShowCodeNode::AggregateTotalStats(ShowCodeNode* node) {
+ TFCodeNodeProto* node_pb = node->mutable_proto();
+ mutable_proto()->set_total_exec_micros(proto().total_exec_micros() +
+ node_pb->total_exec_micros());
+ mutable_proto()->set_total_requested_bytes(proto().total_requested_bytes() +
+ node_pb->total_requested_bytes());
+ mutable_proto()->set_total_parameters(proto().total_parameters() +
+ node_pb->total_parameters());
+ mutable_proto()->set_total_float_ops(proto().total_float_ops() +
+ node_pb->total_float_ops());
+}
+
+void ShowCodeNode::AddSelfToTotalStats() {
+ mutable_proto()->set_total_exec_micros(proto().total_exec_micros() +
+ proto().exec_micros());
+ mutable_proto()->set_total_requested_bytes(proto().total_requested_bytes() +
+ proto().requested_bytes());
+ mutable_proto()->set_total_parameters(proto().total_parameters() +
+ proto().parameters());
+ mutable_proto()->set_total_float_ops(proto().total_float_ops() +
+ proto().float_ops());
+}
+
+void ShowCodeNode::ResetTotalStats() {
+ mutable_proto()->set_total_exec_micros(0);
+ mutable_proto()->set_total_requested_bytes(0);
+ mutable_proto()->set_total_parameters(0);
+ mutable_proto()->set_total_float_ops(0);
+ mutable_proto()->mutable_children()->Clear();
+}
+
+} // namespace tfprof
+} // namespace tensorflow
diff --git a/tensorflow/tools/tfprof/internal/tfprof_node_show.h b/tensorflow/tools/tfprof/internal/tfprof_node_show.h
new file mode 100644
index 0000000000..4ce0f63f9b
--- /dev/null
+++ b/tensorflow/tools/tfprof/internal/tfprof_node_show.h
@@ -0,0 +1,173 @@
+/* Copyright 2016 The TensorFlow Authors All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Nodes used for different views.
+// ScopeNode is for scope view. GraphNode is for graph view and CodeNode
+// is for code view.
+
+#ifndef THIRD_PARTY_TENSORFLOW_TOOLS_TFPROF_INTERNAL_TFPROF_NODE_SHOW_H_
+#define THIRD_PARTY_TENSORFLOW_TOOLS_TFPROF_INTERNAL_TFPROF_NODE_SHOW_H_
+
+#include <algorithm>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/tools/tfprof/internal/tfprof_constants.h"
+#include "tensorflow/tools/tfprof/internal/tfprof_node.h"
+#include "tensorflow/tools/tfprof/internal/tfprof_options.h"
+#include "tensorflow/tools/tfprof/internal/tfprof_utils.h"
+#include "tensorflow/tools/tfprof/tfprof_output.pb.h"
+
+namespace tensorflow {
+namespace tfprof {
+
+class ShowNode {
+ public:
+ explicit ShowNode(const TFGraphNode* node);
+ virtual ~ShowNode() {}
+
+ const string& name() const { return node->name(); }
+ TFGraphNodeProto* mutable_proto();
+ const TFGraphNodeProto& proto() const;
+
+ void ReInit();
+
+ string Format(const Options& opts);
+
+ string FormatMeta(const Options& opts);
+
+ const TFGraphNode* node;
+ bool account;
+ string formatted_str;
+
+ protected:
+ void AggregateTotalStats(ShowNode* node);
+
+ void AddSelfToTotalStats();
+
+ void ResetTotalStats();
+
+ TFGraphNodeProto proto_;
+};
+
+class GraphNode : public ShowNode {
+ public:
+ explicit GraphNode(TFGraphNode* node) : ShowNode(node) {
+ mutable_proto()->set_inputs(node->inputs().size());
+ mutable_proto()->set_total_inputs(0);
+ }
+
+ void ReInit() {
+ ShowNode::ReInit();
+ mutable_proto()->set_inputs(node->inputs().size());
+ mutable_proto()->set_total_inputs(0);
+ }
+
+ void AggregateTotalStats(GraphNode* node) {
+ ShowNode::AggregateTotalStats(node);
+ mutable_proto()->set_total_inputs(proto().total_inputs() +
+ node->proto().total_inputs() + 1);
+ }
+
+ void AddSelfToTotalStats() {
+ ShowNode::AddSelfToTotalStats();
+ mutable_proto()->set_total_inputs(proto().total_inputs() +
+ proto().inputs());
+ }
+
+ void ResetTotalStats() {
+ ShowNode::ResetTotalStats();
+ mutable_proto()->set_total_inputs(0);
+ show_children.clear();
+ }
+
+ std::vector<GraphNode*> children;
+ std::vector<GraphNode*> show_children;
+};
+
+class ScopeNode : public ShowNode {
+ public:
+ explicit ScopeNode(const TFGraphNode* node) : ShowNode(node) {}
+ ~ScopeNode() override {}
+
+ void ReInit() { ShowNode::ReInit(); }
+
+ void AggregateTotalStats(ScopeNode* node) {
+ ShowNode::AggregateTotalStats(node);
+ }
+
+ void AddSelfToTotalStats() { ShowNode::AddSelfToTotalStats(); }
+
+ void ResetTotalStats() {
+ ShowNode::ResetTotalStats();
+ show_children.clear();
+ }
+
+ std::vector<ScopeNode*> children;
+ std::vector<ScopeNode*> show_children;
+};
+
+class ShowCodeNode {
+ public:
+ explicit ShowCodeNode(const TFCodeNode* node);
+ virtual ~ShowCodeNode() {}
+
+ const string& name() const { return node->name(); }
+ TFCodeNodeProto* mutable_proto();
+ const TFCodeNodeProto& proto() const;
+
+ string Format(const Options& opts);
+
+ string FormatMeta(const Options& opts);
+
+ const TFCodeNode* node;
+ bool account;
+ string formatted_str;
+
+ protected:
+ void AggregateTotalStats(ShowCodeNode* node);
+
+ void AddSelfToTotalStats();
+
+ void ResetTotalStats();
+
+ TFCodeNodeProto proto_;
+};
+
+class CodeNode : public ShowCodeNode {
+ public:
+ explicit CodeNode(const TFCodeNode* node) : ShowCodeNode(node) {}
+ ~CodeNode() override {}
+
+ void AggregateTotalStats(CodeNode* node) {
+ ShowCodeNode::AggregateTotalStats(node);
+ }
+
+ void AddSelfToTotalStats() { ShowCodeNode::AddSelfToTotalStats(); }
+
+ void ResetTotalStats() {
+ ShowCodeNode::ResetTotalStats();
+ show_children.clear();
+ }
+
+ std::vector<CodeNode*> children;
+ std::vector<CodeNode*> show_children;
+};
+} // namespace tfprof
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_TOOLS_TFPROF_INTERNAL_TFPROF_NODE_SHOW_H_
diff --git a/tensorflow/tools/tfprof/internal/tfprof_options.cc b/tensorflow/tools/tfprof/internal/tfprof_options.cc
index 03282533ff..f592a4cf8c 100644
--- a/tensorflow/tools/tfprof/internal/tfprof_options.cc
+++ b/tensorflow/tools/tfprof/internal/tfprof_options.cc
@@ -17,16 +17,133 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/tools/tfprof/tfprof_options.pb.h"
namespace tensorflow {
namespace tfprof {
+namespace {
+string KeyValueToStr(const std::map<string, string>& kv_map) {
+ std::vector<string> kv_vec;
+ kv_vec.reserve(kv_map.size());
+ for (const auto& pair : kv_map) {
+ kv_vec.push_back(strings::StrCat(pair.first, "=", pair.second));
+ }
+ return str_util::Join(kv_vec, ",");
+}
+} // namespace
+
+tensorflow::Status ParseOutput(const string& output_opt, string* output_type,
+ std::map<string, string>* output_options) {
+ // The default is to use stdout.
+ if (output_opt.empty()) {
+ *output_type = kOutput[1];
+ return tensorflow::Status::OK();
+ }
+
+ std::set<string> output_types(kOutput,
+ kOutput + sizeof(kOutput) / sizeof(*kOutput));
+ auto opt_split = output_opt.find(":");
+ std::vector<string> kv_split;
+ if (opt_split == output_opt.npos) {
+ if (output_types.find(output_opt) == output_types.end()) {
+ return tensorflow::Status(
+ tensorflow::error::INVALID_ARGUMENT,
+ strings::Printf("E.g. Unknown output type: %s, Valid types: %s\n",
+ output_opt.c_str(),
+ str_util::Join(output_types, ",").c_str()));
+ }
+ *output_type = output_opt;
+ } else {
+ *output_type = output_opt.substr(0, opt_split);
+ if (output_types.find(*output_type) == output_types.end()) {
+ return tensorflow::Status(
+ tensorflow::error::INVALID_ARGUMENT,
+ strings::Printf("E.g. Unknown output type: %s, Valid types: %s\n",
+ output_type->c_str(),
+ str_util::Join(output_types, ",").c_str()));
+ }
+ kv_split = str_util::Split(output_opt.substr(opt_split + 1), ",",
+ str_util::SkipEmpty());
+ }
-Options Options::FromProtoStr(const string& opts_proto_str) {
+ std::set<string> valid_options;
+ std::set<string> required_options;
+ if (*output_type == kOutput[0]) {
+ valid_options.insert(
+ kTimelineOpts,
+ kTimelineOpts + sizeof(kTimelineOpts) / sizeof(*kTimelineOpts));
+ required_options.insert(
+ kTimelineRequiredOpts,
+ kTimelineRequiredOpts +
+ sizeof(kTimelineRequiredOpts) / sizeof(*kTimelineRequiredOpts));
+ } else if (*output_type == kOutput[2]) {
+ valid_options.insert(kFileOpts,
+ kFileOpts + sizeof(kFileOpts) / sizeof(*kFileOpts));
+ required_options.insert(kFileRequiredOpts,
+ kFileRequiredOpts + sizeof(kFileRequiredOpts) /
+ sizeof(*kFileRequiredOpts));
+ }
+
+ for (const string& kv_str : kv_split) {
+ const std::vector<string> kv =
+ str_util::Split(kv_str, "=", str_util::SkipEmpty());
+ if (kv.size() != 2) {
+ return tensorflow::Status(
+ tensorflow::error::INVALID_ARGUMENT,
+ "Visualize format: -output timeline:key=value,key=value,...");
+ }
+ if (valid_options.find(kv[0]) == valid_options.end()) {
+ return tensorflow::Status(
+ tensorflow::error::INVALID_ARGUMENT,
+ strings::Printf("Unrecognized options %s for output_type: %s\n",
+ kv[0].c_str(), output_type->c_str()));
+ }
+ (*output_options)[kv[0]] = kv[1];
+ }
+
+ for (const string& opt : required_options) {
+ if (output_options->find(opt) == output_options->end()) {
+ return tensorflow::Status(
+ tensorflow::error::INVALID_ARGUMENT,
+ strings::Printf("Missing required output_options for %s\n"
+ "E.g. -output %s:%s=...\n",
+ output_type->c_str(), output_type->c_str(),
+ opt.c_str()));
+ }
+ }
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status Options::FromProtoStr(const string& opts_proto_str,
+ Options* opts) {
OptionsProto opts_pb;
- CHECK(opts_pb.ParseFromString(opts_proto_str));
- Options opts(
+ if (!opts_pb.ParseFromString(opts_proto_str)) {
+ return tensorflow::Status(
+ tensorflow::error::INTERNAL,
+ strings::StrCat("Failed to parse option string from Python API: ",
+ opts_proto_str));
+ }
+
+ string output_type;
+ std::map<string, string> output_options;
+ tensorflow::Status s =
+ ParseOutput(opts_pb.output(), &output_type, &output_options);
+ if (!s.ok()) return s;
+
+ if (!opts_pb.dump_to_file().empty()) {
+ fprintf(stderr,
+ "-dump_to_file option is deprecated. "
+ "Please use -output file:outfile=<filename>\n");
+ fprintf(stderr, "-output %s is overwritten with -output file:outfile=%s\n",
+ opts_pb.output().c_str(), opts_pb.dump_to_file().c_str());
+ output_type = kOutput[2];
+ output_options.clear();
+ output_options[kFileOpts[0]] = opts_pb.dump_to_file();
+ }
+
+ *opts = Options(
opts_pb.max_depth(), opts_pb.min_bytes(), opts_pb.min_micros(),
opts_pb.min_params(), opts_pb.min_float_ops(),
std::vector<string>(opts_pb.device_regexes().begin(),
@@ -44,8 +161,8 @@ Options Options::FromProtoStr(const string& opts_proto_str) {
opts_pb.hide_name_regexes().end()),
opts_pb.account_displayed_op_only(),
std::vector<string>(opts_pb.select().begin(), opts_pb.select().end()),
- opts_pb.viz(), opts_pb.dump_to_file());
- return opts;
+ output_type, output_options);
+ return tensorflow::Status::OK();
}
string Options::ToString() const {
@@ -64,8 +181,7 @@ string Options::ToString() const {
"%-28s%s\n"
"%-28s%s\n"
"%-28s%s\n"
- "%-28s%s\n"
- "%-28s%s\n",
+ "%-28s%s:%s\n",
kOptions[0], max_depth, kOptions[1], min_bytes, kOptions[2], min_micros,
kOptions[3], min_params, kOptions[4], min_float_ops, kOptions[5],
str_util::Join(device_regexes, ",").c_str(), kOptions[6],
@@ -76,8 +192,8 @@ string Options::ToString() const {
str_util::Join(show_name_regexes, ",").c_str(), kOptions[11],
str_util::Join(hide_name_regexes, ",").c_str(), kOptions[12],
(account_displayed_op_only ? "true" : "false"), kOptions[13],
- str_util::Join(select, ",").c_str(), kOptions[14],
- (viz ? "true" : "false"), kOptions[15], dump_to_file.c_str());
+ str_util::Join(select, ",").c_str(), kOptions[14], output_type.c_str(),
+ KeyValueToStr(output_options).c_str());
return s;
}
diff --git a/tensorflow/tools/tfprof/internal/tfprof_options.h b/tensorflow/tools/tfprof/internal/tfprof_options.h
index 0a9f2768e0..cf48b4de81 100644
--- a/tensorflow/tools/tfprof/internal/tfprof_options.h
+++ b/tensorflow/tools/tfprof/internal/tfprof_options.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
namespace tfprof {
@@ -40,8 +41,7 @@ static const char* const kOptions[] = {
"-hide_name_regexes",
"-account_displayed_op_only",
"-select",
- "-viz",
- "-dump_to_file",
+ "-output",
};
static const char* const kOrderBy[] = {
@@ -58,11 +58,30 @@ static const char* const kCmds[] = {
"scope", "graph", "code", "set", "help",
};
+static const char* const kOutput[] = {"timeline", "stdout", "file"};
+
+static const char* const kTimelineOpts[] = {
+ "outfile",
+};
+
+static const char* const kTimelineRequiredOpts[] = {"outfile"};
+
+static const char* const kFileOpts[] = {
+ "outfile",
+};
+
+static const char* const kFileRequiredOpts[] = {
+ "outfile",
+};
+
struct Options {
public:
- static Options FromProtoStr(const string& opts_proto_str);
+ static tensorflow::Status FromProtoStr(const string& opts_proto_str,
+ Options* opts);
virtual ~Options() {}
+ Options()
+ : Options(0, 0, 0, 0, 0, {}, "", {}, {}, {}, {}, {}, false, {}, "", {}) {}
Options(int max_depth, tensorflow::int64 min_bytes,
tensorflow::int64 min_micros, tensorflow::int64 min_params,
tensorflow::int64 min_float_ops,
@@ -73,7 +92,8 @@ struct Options {
const std::vector<string>& show_name_regexes,
const std::vector<string>& hide_name_regexes,
bool account_displayed_op_only, const std::vector<string>& select,
- bool viz, const string& dump_to_file = "")
+ const string& output_type,
+ const std::map<string, string>& output_options)
: max_depth(max_depth),
min_bytes(min_bytes),
min_micros(min_micros),
@@ -88,8 +108,8 @@ struct Options {
hide_name_regexes(hide_name_regexes),
account_displayed_op_only(account_displayed_op_only),
select(select.begin(), select.end()),
- viz(viz),
- dump_to_file(dump_to_file) {}
+ output_type(output_type),
+ output_options(output_options) {}
string ToString() const;
@@ -109,10 +129,17 @@ struct Options {
bool account_displayed_op_only;
std::set<string> select;
- bool viz;
- string dump_to_file;
+
+ string output_type;
+ std::map<string, string> output_options;
};
+// Parse the -output option.
+// 'output_opt': User input string with format: output_type:key=value,key=value.
+// 'output_type' and 'output_options' are extracted from 'output_opt'.
+tensorflow::Status ParseOutput(const string& output_opt, string* output_type,
+ std::map<string, string>* output_options);
+
} // namespace tfprof
} // namespace tensorflow
diff --git a/tensorflow/tools/tfprof/internal/tfprof_scope.cc b/tensorflow/tools/tfprof/internal/tfprof_scope.cc
index b4aef717c8..fe525c4bd8 100644
--- a/tensorflow/tools/tfprof/internal/tfprof_scope.cc
+++ b/tensorflow/tools/tfprof/internal/tfprof_scope.cc
@@ -72,7 +72,7 @@ void TFScope::Build() {
}
}
-const ShowNode* TFScope::ShowInternal(const Options& opts) {
+const ShowNode* TFScope::ShowInternal(const Options& opts, Timeline* timeline) {
// Search from roots recursively to find start node, if start_name_regexes
// is specified.
std::vector<ScopeNode*> roots = roots_;
@@ -86,6 +86,9 @@ const ShowNode* TFScope::ShowInternal(const Options& opts) {
Account({root}, opts);
root = PrintScope({root}, opts, 1, 0)[0];
+ if (timeline) {
+ timeline->GenerateScopeTimeline(root);
+ }
return root;
}
@@ -139,8 +142,13 @@ std::vector<ScopeNode*> TFScope::PrintScope(const std::vector<ScopeNode*> roots,
show_cnodes = SortNodes(show_cnodes, opts);
string children_str;
for (ScopeNode* sc : show_cnodes) {
- children_str += sc->formatted_str;
+ if (opts.output_type == kOutput[1] || opts.output_type == kOutput[2]) {
+ children_str += sc->formatted_str;
+ sc->formatted_str.clear();
+ }
node->mutable_proto()->add_children()->MergeFrom(sc->proto());
+ sc->mutable_proto()->mutable_children()->Clear();
+ node->show_children.push_back(sc);
if (opts.account_displayed_op_only) {
node->AggregateTotalStats(sc);
}
diff --git a/tensorflow/tools/tfprof/internal/tfprof_scope.h b/tensorflow/tools/tfprof/internal/tfprof_scope.h
index 2e2e4f5266..7bdcc794cd 100644
--- a/tensorflow/tools/tfprof/internal/tfprof_scope.h
+++ b/tensorflow/tools/tfprof/internal/tfprof_scope.h
@@ -37,22 +37,6 @@ limitations under the License.
namespace tensorflow {
namespace tfprof {
-class ScopeNode : public ShowNode {
- public:
- explicit ScopeNode(const TFGraphNode* node) : ShowNode(node) {}
- ~ScopeNode() override {}
-
- void AggregateTotalStats(ScopeNode* node) {
- ShowNode::AggregateTotalStats(node);
- }
-
- void AddSelfToTotalStats() { ShowNode::AddSelfToTotalStats(); }
-
- void ResetTotalStats() { ShowNode::ResetTotalStats(); }
-
- std::vector<ScopeNode*> children;
-};
-
class TFScope : public TFShow {
public:
explicit TFScope(checkpoint::CheckpointReader* ckpt_reader)
@@ -64,7 +48,8 @@ class TFScope : public TFShow {
void Build() override;
private:
- const ShowNode* ShowInternal(const Options& opts) override;
+ const ShowNode* ShowInternal(const Options& opts,
+ Timeline* timeline) override;
ScopeNode* CreateParentNode(const string& name);
diff --git a/tensorflow/tools/tfprof/internal/tfprof_show.cc b/tensorflow/tools/tfprof/internal/tfprof_show.cc
index 932dfb3893..b96db5468e 100644
--- a/tensorflow/tools/tfprof/internal/tfprof_show.cc
+++ b/tensorflow/tools/tfprof/internal/tfprof_show.cc
@@ -18,155 +18,32 @@ limitations under the License.
#include <memory>
#include <set>
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/regexp.h"
namespace tensorflow {
namespace tfprof {
-ShowNode::ShowNode(const TFGraphNode* node) : node(node), account(true) {
- mutable_proto()->set_name(name());
- if (!node->device().empty()) {
- mutable_proto()->set_device(node->device());
- }
- mutable_proto()->set_exec_micros(node->kernel_compute_micros());
- mutable_proto()->set_requested_bytes(node->requested_bytes());
- mutable_proto()->set_float_ops(node->float_ops());
-
- if (!node->shape().empty()) {
- int64 params = 1;
- bool complete_shape = true;
- for (int64 d : node->shape()) {
- // Sometimes parameters could be <0 when a dim is unknown.
- if (d < 0) {
- complete_shape = false;
- break;
- }
- params *= d;
- }
- if (complete_shape) {
- mutable_proto()->set_parameters(proto_.parameters() + params);
- } else {
- fprintf(stderr, "Incomplete shape.");
- }
- }
-}
-
-string ShowNode::Format(const Options& opts) {
- if (opts.select.empty()) {
- return name();
- }
- return strings::Printf("%s (%s)", name().c_str(), FormatMeta(opts).c_str());
-}
-
-string ShowNode::FormatMeta(const Options& opts) {
- std::vector<string> info;
- if (opts.select.find(kShown[2]) != opts.select.end()) {
- const string shape = FormatShapes(node->shape());
- if (!shape.empty()) {
- info.push_back(shape);
- }
- string params = FormatNumber(proto().total_parameters()) + " params";
- if (account) {
- params = FormatNumber(proto().parameters()) + "/" + params;
- } else {
- params = "--/" + params;
- }
- info.push_back(params);
- }
- if (opts.select.find(kShown[3]) != opts.select.end()) {
- string fops = FormatNumber(proto().total_float_ops()) + " flops";
- if (account) {
- fops = FormatNumber(proto().float_ops()) + "/" + fops;
- } else {
- fops = "--/" + fops;
- }
- info.push_back(fops);
- }
- if (opts.select.find(kShown[0]) != opts.select.end()) {
- string memory = FormatMemory(proto().total_requested_bytes());
- if (account) {
- memory = FormatMemory(proto().requested_bytes()) + "/" + memory;
-
- } else {
- memory = "--/" + memory;
- }
- info.push_back(memory);
- }
- if (opts.select.find(kShown[1]) != opts.select.end()) {
- string time = FormatTime(proto().total_exec_micros());
- if (account) {
- time = FormatTime(proto().exec_micros()) + "/" + time;
- } else {
- time = "--/" + time;
- }
- info.push_back(time);
- }
- if (opts.select.find(kShown[6]) != opts.select.end()) {
- if (!proto().device().empty()) {
- info.push_back(proto().device());
- }
- }
- if (opts.select.find(kShown[7]) != opts.select.end()) {
- std::set<string> op_types = node->op_types();
- // Device is considered a type.
- if (!proto().device().empty()) {
- op_types.insert(proto().device());
- }
- info.push_back(str_util::Join(op_types, "|"));
- }
- return str_util::Join(info, ", ");
-}
-
-TFGraphNodeProto* ShowNode::mutable_proto() { return &proto_; }
-
-const TFGraphNodeProto& ShowNode::proto() const { return proto_; }
-
-void ShowNode::AggregateTotalStats(ShowNode* node) {
- TFGraphNodeProto* node_pb = node->mutable_proto();
- mutable_proto()->set_total_exec_micros(proto().total_exec_micros() +
- node_pb->total_exec_micros());
- mutable_proto()->set_total_requested_bytes(proto().total_requested_bytes() +
- node_pb->total_requested_bytes());
- mutable_proto()->set_total_parameters(proto().total_parameters() +
- node_pb->total_parameters());
- mutable_proto()->set_total_float_ops(proto().total_float_ops() +
- node_pb->total_float_ops());
-}
-
-void ShowNode::AddSelfToTotalStats() {
- mutable_proto()->set_total_exec_micros(proto().total_exec_micros() +
- proto().exec_micros());
- mutable_proto()->set_total_requested_bytes(proto().total_requested_bytes() +
- proto().requested_bytes());
- mutable_proto()->set_total_parameters(proto().total_parameters() +
- proto().parameters());
- mutable_proto()->set_total_float_ops(proto().total_float_ops() +
- proto().float_ops());
-}
-
-void ShowNode::ResetTotalStats() {
- mutable_proto()->set_total_exec_micros(0);
- mutable_proto()->set_total_requested_bytes(0);
- mutable_proto()->set_total_parameters(0);
- mutable_proto()->set_total_float_ops(0);
- mutable_proto()->mutable_children()->Clear();
-}
const TFGraphNodeProto& TFShow::Show(const Options& opts) {
- const ShowNode* root = ShowInternal(opts);
- if (opts.dump_to_file.empty()) {
- printf("%s", root->formatted_str.c_str());
- fflush(stdout);
- } else {
- Status s = WriteStringToFile(Env::Default(), opts.dump_to_file,
- root->formatted_str);
+ if (opts.output_type == kOutput[0]) {
+ Timeline timeline(opts.output_options.at(kTimelineOpts[0]));
+ return ShowInternal(opts, &timeline)->proto();
+ } else if (opts.output_type == kOutput[2]) {
+ const ShowNode* root = ShowInternal(opts, nullptr);
+ Status s =
+ WriteStringToFile(Env::Default(), opts.output_options.at(kFileOpts[0]),
+ root->formatted_str);
if (!s.ok()) {
fprintf(stderr, "%s\n", s.ToString().c_str());
}
+ return root->proto();
+ } else {
+ const ShowNode* root = ShowInternal(opts, nullptr);
+ printf("%s", root->formatted_str.c_str());
+ fflush(stdout);
+ return root->proto();
}
- return root->proto();
}
bool TFShow::LookUpCheckPoint(const string& name,
@@ -206,10 +83,13 @@ bool TFShow::ShouldShow(ShowNode* node, const Options& opts, int depth) {
show = true;
} else {
for (const string& regex : opts.device_regexes) {
- if (RE2::FullMatch(node->proto().device(), regex)) {
- show = true;
- break;
+ for (const string& device : node->proto().devices()) {
+ if (RE2::FullMatch(device, regex)) {
+ show = true;
+ break;
+ }
}
+ if (show) break;
}
}
// Don't show if device_regexes don't cover it.
@@ -255,11 +135,11 @@ bool TFShow::ShouldAccount(ShowNode* node, const Options& opts) {
return true;
}
}
- if (RE2::FullMatch(node->proto().device(), regex)) {
- return true;
- }
+ for (const string& device : node->proto().devices())
+ if (RE2::FullMatch(device, regex)) {
+ return true;
+ }
}
-
return false;
}
diff --git a/tensorflow/tools/tfprof/internal/tfprof_show.h b/tensorflow/tools/tfprof/internal/tfprof_show.h
index 5e85b81a72..803b301044 100644
--- a/tensorflow/tools/tfprof/internal/tfprof_show.h
+++ b/tensorflow/tools/tfprof/internal/tfprof_show.h
@@ -28,40 +28,15 @@ limitations under the License.
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/tools/tfprof/internal/tfprof_constants.h"
#include "tensorflow/tools/tfprof/internal/tfprof_node.h"
+#include "tensorflow/tools/tfprof/internal/tfprof_node_show.h"
#include "tensorflow/tools/tfprof/internal/tfprof_options.h"
#include "tensorflow/tools/tfprof/internal/tfprof_tensor.h"
+#include "tensorflow/tools/tfprof/internal/tfprof_timeline.h"
#include "tensorflow/tools/tfprof/internal/tfprof_utils.h"
#include "tensorflow/tools/tfprof/tfprof_output.pb.h"
namespace tensorflow {
namespace tfprof {
-class ShowNode {
- public:
- explicit ShowNode(const TFGraphNode* node);
- virtual ~ShowNode() {}
-
- const string& name() const { return node->name(); }
- TFGraphNodeProto* mutable_proto();
- const TFGraphNodeProto& proto() const;
-
- string Format(const Options& opts);
-
- string FormatMeta(const Options& opts);
-
- const TFGraphNode* node;
- bool account;
- string formatted_str;
-
- protected:
- void AggregateTotalStats(ShowNode* node);
-
- void AddSelfToTotalStats();
-
- void ResetTotalStats();
-
- TFGraphNodeProto proto_;
-};
-
class TFShow {
public:
explicit TFShow(checkpoint::CheckpointReader* ckpt_reader)
@@ -72,7 +47,8 @@ class TFShow {
const TFGraphNodeProto& Show(const Options& opts);
protected:
- virtual const ShowNode* ShowInternal(const Options& opts) = 0;
+ virtual const ShowNode* ShowInternal(const Options& opts,
+ Timeline* timeline) = 0;
bool LookUpCheckPoint(const string& name,
std::unique_ptr<TFProfTensor>* tensor);
diff --git a/tensorflow/tools/tfprof/internal/tfprof_show_code.cc b/tensorflow/tools/tfprof/internal/tfprof_show_code.cc
index 17c0073c7e..cfec09ad19 100644
--- a/tensorflow/tools/tfprof/internal/tfprof_show_code.cc
+++ b/tensorflow/tools/tfprof/internal/tfprof_show_code.cc
@@ -26,159 +26,26 @@ limitations under the License.
namespace tensorflow {
namespace tfprof {
-ShowCodeNode::ShowCodeNode(const TFCodeNode* node) : node(node), account(true) {
- std::vector<ScopeNode> snodes;
- for (auto it : node->graph_nodes()) {
- ScopeNode snode(it.second);
- snodes.push_back(snode);
- snodes[snodes.size() - 1].AddSelfToTotalStats();
- *mutable_proto()->mutable_graph_nodes()->Add() =
- snodes[snodes.size() - 1].proto();
- }
-
- mutable_proto()->set_name(name());
- mutable_proto()->set_exec_micros(node->kernel_compute_micros());
- mutable_proto()->set_requested_bytes(node->requested_bytes());
- mutable_proto()->set_float_ops(node->float_ops());
-
- if (!node->shapes().empty()) {
- for (const std::vector<int64>& shape : node->shapes()) {
- int64 params = 1;
- bool complete_shape = true;
- for (int64 d : shape) {
- // Sometimes parameters could be <0 when a dim is unknown.
- if (d < 0) {
- complete_shape = false;
- break;
- }
- params *= d;
- }
- if (complete_shape) {
- mutable_proto()->set_parameters(proto().parameters() + params);
- } else {
- fprintf(stderr, "Incomplete shape.");
- }
- }
- }
-}
-
-string ShowCodeNode::Format(const Options& opts) {
- if (opts.select.empty()) {
- return name();
- }
- return strings::Printf("%s (%s)", name().c_str(), FormatMeta(opts).c_str());
-}
-
-string ShowCodeNode::FormatMeta(const Options& opts) {
- std::vector<string> info;
- std::vector<string> shapes;
- if (opts.select.find(kShown[2]) != opts.select.end()) {
- for (const std::vector<int64>& shape : node->shapes()) {
- if (!shape.empty()) {
- shapes.push_back(FormatShapes(shape));
- }
- }
- if (!shapes.empty()) {
- info.push_back(str_util::Join(shapes, "|"));
- }
- string params = FormatNumber(proto().total_parameters()) + " params";
- if (account) {
- params = FormatNumber(proto().parameters()) + "/" + params;
- } else {
- params = "--/" + params;
- }
- info.push_back(params);
- }
- if (opts.select.find(kShown[3]) != opts.select.end()) {
- string fops = FormatNumber(proto().total_float_ops()) + " flops";
- if (account) {
- fops = FormatNumber(proto().float_ops()) + "/" + fops;
- } else {
- fops = "--/" + fops;
- }
- info.push_back(fops);
- }
- if (opts.select.find(kShown[0]) != opts.select.end()) {
- string memory = FormatMemory(proto().total_requested_bytes());
- if (account) {
- memory = FormatMemory(proto().requested_bytes()) + "/" + memory;
-
- } else {
- memory = "--/" + memory;
- }
- info.push_back(memory);
- }
- if (opts.select.find(kShown[1]) != opts.select.end()) {
- string time = FormatTime(proto().total_exec_micros());
- if (account) {
- time = FormatTime(proto().exec_micros()) + "/" + time;
- } else {
- time = "--/" + time;
- }
- info.push_back(time);
- }
- if (opts.select.find(kShown[6]) != opts.select.end()) {
- if (!node->devices().empty()) {
- info.push_back(str_util::Join(node->devices(), "|"));
- }
- }
- if (opts.select.find(kShown[7]) != opts.select.end()) {
- std::set<string> op_types = node->op_types();
- // Device is considered a type.
- op_types.insert(node->devices().cbegin(), node->devices().cend());
- info.push_back(str_util::Join(op_types, "|"));
- }
- return str_util::Join(info, ", ");
-}
-
-TFCodeNodeProto* ShowCodeNode::mutable_proto() { return &proto_; }
-
-const TFCodeNodeProto& ShowCodeNode::proto() const { return proto_; }
-
-void ShowCodeNode::AggregateTotalStats(ShowCodeNode* node) {
- TFCodeNodeProto* node_pb = node->mutable_proto();
- mutable_proto()->set_total_exec_micros(proto().total_exec_micros() +
- node_pb->total_exec_micros());
- mutable_proto()->set_total_requested_bytes(proto().total_requested_bytes() +
- node_pb->total_requested_bytes());
- mutable_proto()->set_total_parameters(proto().total_parameters() +
- node_pb->total_parameters());
- mutable_proto()->set_total_float_ops(proto().total_float_ops() +
- node_pb->total_float_ops());
-}
-
-void ShowCodeNode::AddSelfToTotalStats() {
- mutable_proto()->set_total_exec_micros(proto().total_exec_micros() +
- proto().exec_micros());
- mutable_proto()->set_total_requested_bytes(proto().total_requested_bytes() +
- proto().requested_bytes());
- mutable_proto()->set_total_parameters(proto().total_parameters() +
- proto().parameters());
- mutable_proto()->set_total_float_ops(proto().total_float_ops() +
- proto().float_ops());
-}
-
-void ShowCodeNode::ResetTotalStats() {
- mutable_proto()->set_total_exec_micros(0);
- mutable_proto()->set_total_requested_bytes(0);
- mutable_proto()->set_total_parameters(0);
- mutable_proto()->set_total_float_ops(0);
- mutable_proto()->mutable_children()->Clear();
-}
const TFCodeNodeProto& TFShowCode::Show(const Options& opts) {
- const ShowCodeNode* root = ShowInternal(opts);
- if (opts.dump_to_file.empty()) {
- printf("%s", root->formatted_str.c_str());
- fflush(stdout);
- } else {
- Status s = WriteStringToFile(Env::Default(), opts.dump_to_file,
- root->formatted_str);
+ if (opts.output_type == kOutput[0]) {
+ Timeline timeline(opts.output_options.at(kTimelineOpts[0]));
+ return ShowInternal(opts, &timeline)->proto();
+ } else if (opts.output_type == kOutput[2]) {
+ const ShowCodeNode* root = ShowInternal(opts, nullptr);
+ Status s =
+ WriteStringToFile(Env::Default(), opts.output_options.at(kFileOpts[0]),
+ root->formatted_str);
if (!s.ok()) {
fprintf(stderr, "%s\n", s.ToString().c_str());
}
+ return root->proto();
+ } else {
+ const ShowCodeNode* root = ShowInternal(opts, nullptr);
+ printf("%s", root->formatted_str.c_str());
+ fflush(stdout);
+ return root->proto();
}
- return root->proto();
}
bool TFShowCode::ShouldShow(ShowCodeNode* node, const Options& opts,
diff --git a/tensorflow/tools/tfprof/internal/tfprof_show_code.h b/tensorflow/tools/tfprof/internal/tfprof_show_code.h
index afe0fa4473..cbfd38945f 100644
--- a/tensorflow/tools/tfprof/internal/tfprof_show_code.h
+++ b/tensorflow/tools/tfprof/internal/tfprof_show_code.h
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// Parent class and utilities for tfprof_graph and tfprof_scope.
+// Parent class and utilities for tfprof_code.
#ifndef THIRD_PARTY_TENSORFLOW_TOOLS_TFPROF_INTERNAL_TFPROF_SHOW_CODE_H_
#define THIRD_PARTY_TENSORFLOW_TOOLS_TFPROF_INTERNAL_TFPROF_SHOW_CODE_H_
@@ -28,39 +28,15 @@ limitations under the License.
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/tools/tfprof/internal/tfprof_constants.h"
#include "tensorflow/tools/tfprof/internal/tfprof_node.h"
+#include "tensorflow/tools/tfprof/internal/tfprof_node_show.h"
#include "tensorflow/tools/tfprof/internal/tfprof_options.h"
#include "tensorflow/tools/tfprof/internal/tfprof_tensor.h"
+#include "tensorflow/tools/tfprof/internal/tfprof_timeline.h"
#include "tensorflow/tools/tfprof/internal/tfprof_utils.h"
#include "tensorflow/tools/tfprof/tfprof_output.pb.h"
namespace tensorflow {
namespace tfprof {
-class ShowCodeNode {
- public:
- explicit ShowCodeNode(const TFCodeNode* node);
- virtual ~ShowCodeNode() {}
-
- const string& name() const { return node->name(); }
- TFCodeNodeProto* mutable_proto();
- const TFCodeNodeProto& proto() const;
-
- string Format(const Options& opts);
-
- string FormatMeta(const Options& opts);
-
- const TFCodeNode* node;
- bool account;
- string formatted_str;
-
- protected:
- void AggregateTotalStats(ShowCodeNode* node);
-
- void AddSelfToTotalStats();
-
- void ResetTotalStats();
-
- TFCodeNodeProto proto_;
-};
class TFShowCode {
public:
@@ -71,7 +47,8 @@ class TFShowCode {
const TFCodeNodeProto& Show(const Options& opts);
protected:
- virtual const ShowCodeNode* ShowInternal(const Options& opts) = 0;
+ virtual const ShowCodeNode* ShowInternal(const Options& opts,
+ Timeline* timeline) = 0;
bool LookUpCheckPoint(const string& name,
std::unique_ptr<TFProfTensor>* tensor);
diff --git a/tensorflow/tools/tfprof/internal/tfprof_show_test.cc b/tensorflow/tools/tfprof/internal/tfprof_show_test.cc
index ffaa576639..f0621c9af0 100644
--- a/tensorflow/tools/tfprof/internal/tfprof_show_test.cc
+++ b/tensorflow/tools/tfprof/internal/tfprof_show_test.cc
@@ -75,7 +75,7 @@ TEST_F(TFProfShowTest, DumpScopeMode) {
{"VariableV2"}, // accout_type_regexes
{".*"}, {""}, {".*"}, {""}, false,
{"params", "bytes", "micros", "float_ops", "num_hidden_ops"},
- false, dump_file);
+ "file", {{"outfile", dump_file}});
tf_stats_->PrintGraph("scope", opts);
string dump_str;
diff --git a/tensorflow/tools/tfprof/internal/tfprof_stats.cc b/tensorflow/tools/tfprof/internal/tfprof_stats.cc
index 13ff6e7246..566b4cee44 100644
--- a/tensorflow/tools/tfprof/internal/tfprof_stats.cc
+++ b/tensorflow/tools/tfprof/internal/tfprof_stats.cc
@@ -19,6 +19,8 @@ limitations under the License.
#include <utility>
#include "tensorflow/core/framework/step_stats.pb.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/tools/tfprof/internal/tfprof_timeline.h"
namespace tensorflow {
namespace tfprof {
@@ -92,10 +94,16 @@ void TFStats::ParseGraph() {
for (auto it = nodes_map_.begin(); it != nodes_map_.end(); it++) {
const NodeDef* node_def = it->second.node_def();
for (string node_input : node_def->input()) {
+ int output_idx = 0;
// input name format can be: "^node:src_output"
auto prefix_pos = node_input.find(":");
if (prefix_pos != node_input.npos) {
- node_input.substr(0, prefix_pos);
+ std::vector<string> input_parts = str_util::Split(node_input, ":");
+ CHECK(input_parts.size() == 2)
+ << "Unknown NodeDef.input format: " << node_input;
+ node_input = input_parts[0];
+ CHECK(strings::safe_strto32(input_parts[1], &output_idx))
+ << "Failed to parse integer: " << output_idx;
}
if (node_input.substr(0, 1) == "^") {
node_input = node_input.substr(1);
@@ -104,7 +112,7 @@ void TFStats::ParseGraph() {
if (input_node == nodes_map_.end()) {
continue;
}
- it->second.AddInput(&input_node->second);
+ it->second.AddInput(&input_node->second, output_idx);
}
}
}
@@ -137,21 +145,6 @@ void TFStats::ParseRunMeta() {
node->second.AddStepStat(dev_stat.device(), &node_stat);
}
}
-
- if (!run_meta_->has_cost_graph()) {
- fprintf(stderr,
- "Missing CostGraphDef in RunMetadata.\nMaybe you forget to"
- "set tf.ConfigProto(graph_options=tf.GraphOptions("
- "build_cost_model=1)) to Session()\n");
- } else {
- for (const auto& node_pb : run_meta_->cost_graph().node()) {
- auto node = nodes_map_.find(node_pb.name());
- if (node == nodes_map_.end()) {
- continue;
- }
- node->second.AddNodeStat(&node_pb);
- }
- }
}
} // namespace tfprof
} // namespace tensorflow
diff --git a/tensorflow/tools/tfprof/internal/tfprof_stats_test.cc b/tensorflow/tools/tfprof/internal/tfprof_stats_test.cc
index b913161f6a..eb01425e04 100644
--- a/tensorflow/tools/tfprof/internal/tfprof_stats_test.cc
+++ b/tensorflow/tools/tfprof/internal/tfprof_stats_test.cc
@@ -74,8 +74,8 @@ TEST_F(TFProfStatsTest, CustomOpType) {
Options opts(3, 0, 0, 0, 0, {".*"}, "name",
{kTrainableVarType}, // accout_type_regexes
{".*"}, {""}, {".*"}, {""}, false,
- {"params", "bytes", "micros", "float_ops", "num_hidden_ops"},
- false);
+ {"params", "bytes", "micros", "float_ops", "num_hidden_ops"}, "",
+ {});
const TFGraphNodeProto& root = tf_stats_->PrintGraph("scope", opts);
TFGraphNodeProto expected;
@@ -84,20 +84,20 @@ TEST_F(TFProfStatsTest, CustomOpType) {
"0\ntotal_exec_micros: 5\ntotal_requested_bytes: 1480\ntotal_parameters: "
"370\nchildren {\n name: \"conv2d/bias\"\n exec_micros: 1\n "
"requested_bytes: 20\n parameters: 5\n total_exec_micros: 1\n "
- "total_requested_bytes: 20\n total_parameters: 5\n device: "
+ "total_requested_bytes: 20\n total_parameters: 5\n devices: "
"\"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: 0\n "
"total_float_ops: 0\n}\nchildren {\n name: \"conv2d/kernel\"\n "
"exec_micros: 1\n requested_bytes: 540\n parameters: 135\n "
"total_exec_micros: 1\n total_requested_bytes: 540\n total_parameters: "
- "135\n device: \"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: "
+ "135\n devices: \"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: "
"0\n total_float_ops: 0\n}\nchildren {\n name: \"conv2d_1/bias\"\n "
"exec_micros: 1\n requested_bytes: 20\n parameters: 5\n "
"total_exec_micros: 1\n total_requested_bytes: 20\n total_parameters: "
- "5\n device: \"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: "
+ "5\n devices: \"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: "
"0\n total_float_ops: 0\n}\nchildren {\n name: \"conv2d_1/kernel\"\n "
"exec_micros: 2\n requested_bytes: 900\n parameters: 225\n "
"total_exec_micros: 2\n total_requested_bytes: 900\n total_parameters: "
- "225\n device: \"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: "
+ "225\n devices: \"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: "
"0\n total_float_ops: 0\n}\nfloat_ops: 0\ntotal_float_ops: 0\n",
&expected));
EXPECT_EQ(expected.DebugString(), root.DebugString());
@@ -107,7 +107,7 @@ TEST_F(TFProfStatsTest, CheckPointOpType) {
Options opts(
3, 0, 0, 0, 0, {".*"}, "name", {kCkptVarType}, // accout_type_regexes
{".*"}, {""}, {".*"}, {""}, false,
- {"params", "bytes", "micros", "float_ops", "num_hidden_ops"}, false);
+ {"params", "bytes", "micros", "float_ops", "num_hidden_ops"}, "", {});
const TFGraphNodeProto& root = tf_stats_->PrintGraph("scope", opts);
TFGraphNodeProto expected;
@@ -116,20 +116,20 @@ TEST_F(TFProfStatsTest, CheckPointOpType) {
"0\ntotal_exec_micros: 5\ntotal_requested_bytes: 1480\ntotal_parameters: "
"370\nchildren {\n name: \"conv2d/bias\"\n exec_micros: 1\n "
"requested_bytes: 20\n parameters: 5\n total_exec_micros: 1\n "
- "total_requested_bytes: 20\n total_parameters: 5\n device: "
+ "total_requested_bytes: 20\n total_parameters: 5\n devices: "
"\"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: 0\n "
"total_float_ops: 0\n}\nchildren {\n name: \"conv2d/kernel\"\n "
"exec_micros: 1\n requested_bytes: 540\n parameters: 135\n "
"total_exec_micros: 1\n total_requested_bytes: 540\n total_parameters: "
- "135\n device: \"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: "
+ "135\n devices: \"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: "
"0\n total_float_ops: 0\n}\nchildren {\n name: \"conv2d_1/bias\"\n "
"exec_micros: 1\n requested_bytes: 20\n parameters: 5\n "
"total_exec_micros: 1\n total_requested_bytes: 20\n total_parameters: "
- "5\n device: \"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: "
+ "5\n devices: \"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: "
"0\n total_float_ops: 0\n}\nchildren {\n name: \"conv2d_1/kernel\"\n "
"exec_micros: 2\n requested_bytes: 900\n parameters: 225\n "
"total_exec_micros: 2\n total_requested_bytes: 900\n total_parameters: "
- "225\n device: \"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: "
+ "225\n devices: \"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: "
"0\n total_float_ops: 0\n}\nfloat_ops: 0\ntotal_float_ops: 0\n",
&expected));
EXPECT_EQ(expected.DebugString(), root.DebugString());
@@ -139,8 +139,8 @@ TEST_F(TFProfStatsTest, TestGraph) {
Options opts(100, 0, 10000, 0, 0, {".*"}, "name", {".*"},
{"cost.*"}, // start_name_regexes
{""}, {".*"}, {""}, false,
- {"params", "bytes", "micros", "float_ops", "num_hidden_ops"},
- false);
+ {"params", "bytes", "micros", "float_ops", "num_hidden_ops"}, "",
+ {});
const TFGraphNodeProto& root = tf_stats_->PrintGraph("graph", opts);
TFGraphNodeProto expected;
@@ -154,7 +154,7 @@ TEST_F(TFProfStatsTest, TestGraph) {
TEST_F(TFProfStatsTest, TestFloatOps) {
Options opts(10, 0, 0, 0, 1, {".*"}, "name", {".*"}, {".*"}, {""}, {".*"},
- {""}, false, {"float_ops"}, false);
+ {""}, false, {"float_ops"}, "", {});
const TFGraphNodeProto& root = tf_stats_->PrintGraph("scope", opts);
TFGraphNodeProto expected;
@@ -163,19 +163,19 @@ TEST_F(TFProfStatsTest, TestFloatOps) {
"0\ntotal_exec_micros: 96\ntotal_requested_bytes: "
"8656\ntotal_parameters: 370\nchildren {\n name: \"conv2d/BiasAdd\"\n "
"exec_micros: 12\n requested_bytes: 1440\n total_exec_micros: 12\n "
- "total_requested_bytes: 1440\n total_parameters: 0\n device: "
+ "total_requested_bytes: 1440\n total_parameters: 0\n devices: "
"\"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: 360\n "
"total_float_ops: 360\n}\nchildren {\n name: \"conv2d/convolution\"\n "
"exec_micros: 60\n requested_bytes: 1440\n total_exec_micros: 60\n "
- "total_requested_bytes: 1440\n total_parameters: 0\n device: "
+ "total_requested_bytes: 1440\n total_parameters: 0\n devices: "
"\"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: 19440\n "
"total_float_ops: 19440\n}\nchildren {\n name: \"conv2d_2/BiasAdd\"\n "
"exec_micros: 2\n requested_bytes: 640\n total_exec_micros: 2\n "
- "total_requested_bytes: 640\n total_parameters: 0\n device: "
+ "total_requested_bytes: 640\n total_parameters: 0\n devices: "
"\"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: 160\n "
"total_float_ops: 160\n}\nchildren {\n name: \"conv2d_2/convolution\"\n "
" exec_micros: 13\n requested_bytes: 640\n total_exec_micros: 13\n "
- "total_requested_bytes: 640\n total_parameters: 0\n device: "
+ "total_requested_bytes: 640\n total_parameters: 0\n devices: "
"\"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: 14400\n "
"total_float_ops: 14400\n}\nfloat_ops: 0\ntotal_float_ops: 34360\n",
&expected));
@@ -186,7 +186,7 @@ TEST_F(TFProfStatsTest, TestAccountShownNameOnly) {
Options opts(100, 0, 0, 0, 0, {".*"}, "name", {".*"}, {".*"}, {""},
{"unit_2_1.*DW"}, // show_name_regexes.
{""}, true, // account_displayed_op_only.
- {"params"}, false);
+ {"params"}, "", {});
const TFGraphNodeProto& root = tf_stats_->PrintGraph("scope", opts);
TFGraphNodeProto expected;
@@ -202,7 +202,7 @@ TEST_F(TFProfStatsTest, TestShowTensorValue) {
Options opts(10, 0, 0, 0, 0, {".*"}, "name", {".*"}, {".*"}, {""},
{"unit_1_0.*gamma"}, {""}, false,
{"tensor_value"}, // Show tensor value from checkpoint.
- false);
+ "", {});
const TFGraphNodeProto& root = tf_stats_->PrintGraph("scope", opts);
TFGraphNodeProto expected;
CHECK(protobuf::TextFormat::ParseFromString(
diff --git a/tensorflow/tools/tfprof/internal/tfprof_tensor_test.cc b/tensorflow/tools/tfprof/internal/tfprof_tensor_test.cc
index 95759f9d47..79a781210d 100644
--- a/tensorflow/tools/tfprof/internal/tfprof_tensor_test.cc
+++ b/tensorflow/tools/tfprof/internal/tfprof_tensor_test.cc
@@ -57,7 +57,7 @@ class TFProfTensorTest : public ::testing::Test {
TEST_F(TFProfTensorTest, Basics) {
Options opts(3, 0, 0, 0, 0, {".*"}, "name", {"VariableV2"}, {".*"}, {""},
{".*"}, {""}, false, {"tensor_value"}, // show the tensor value.
- false);
+ "", {});
const TFGraphNodeProto& root = tf_stats_->PrintGraph("scope", opts);
TFGraphNodeProto expected;
diff --git a/tensorflow/tools/tfprof/internal/tfprof_timeline.cc b/tensorflow/tools/tfprof/internal/tfprof_timeline.cc
new file mode 100644
index 0000000000..a5640c0e56
--- /dev/null
+++ b/tensorflow/tools/tfprof/internal/tfprof_timeline.cc
@@ -0,0 +1,245 @@
+/* Copyright 2016 The TensorFlow Authors All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/tools/tfprof/internal/tfprof_timeline.h"
+
+#include <utility>
+
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/tools/tfprof/internal/tfprof_utils.h"
+
+namespace tensorflow {
+namespace tfprof {
+
+Json::Value ChromeTraceFormatter::CreateEvent(const string& ph,
+ const string& category,
+ const string& name, int64 pid,
+ int64 tid, int64 ts) {
+ Json::Value event(Json::objectValue);
+ event["ph"] = Json::Value(ph);
+ event["cat"] = Json::Value(category);
+ event["name"] = Json::Value(name);
+ event["pid"] = Json::Value(pid);
+ event["tid"] = Json::Value(tid);
+ event["ts"] = Json::Value(ts);
+ return event;
+}
+
+void ChromeTraceFormatter::EmitPID(const string& name, int64 pid) {
+ Json::Value event(Json::objectValue);
+ event["name"] = Json::Value("process_name");
+ event["ph"] = Json::Value("M");
+ event["pid"] = Json::Value(pid);
+ Json::Value args(Json::objectValue);
+ args["name"] = Json::Value(name);
+ event["args"] = args;
+ metadata_.push_back(event);
+}
+
+void ChromeTraceFormatter::EmitRegion(int64 ts, int64 duration, int64 pid,
+ int64 tid, const string& category,
+ const string& name, Json::Value args) {
+ Json::Value event = CreateEvent("X", category, name, pid, tid, ts);
+ event["dur"] = Json::Value(duration);
+ event["args"] = std::move(args);
+ metadata_.push_back(event);
+}
+
+void ChromeTraceFormatter::EmitFlowStart(const string& name, int64 ts,
+ int64 pid, int64 tid, int64 flow_id) {
+ Json::Value event = CreateEvent("s", "DataFlow", name, pid, tid, ts);
+ event["id"] = flow_id;
+ events_.push_back(event);
+}
+
+void ChromeTraceFormatter::EmitFlowEnd(const string& name, int64 ts, int64 pid,
+ int64 tid, int64 flow_id) {
+ Json::Value event = CreateEvent("t", "DataFlow", name, pid, tid, ts);
+ event["id"] = flow_id;
+ events_.push_back(event);
+}
+
+string ChromeTraceFormatter::Format() {
+ Json::Value trace;
+ trace["traceEvents"] = Json::Value(Json::arrayValue);
+ for (const Json::Value& v : metadata_) {
+ trace["traceEvents"].append(v);
+ }
+ for (const Json::Value& v : events_) {
+ trace["traceEvents"].append(v);
+ }
+ return trace.toStyledString();
+}
+
+void Timeline::GenerateGraphTimeline(const GraphNode* gnode) {
+ fprintf(stdout, "adding graph nodes.\n");
+ AddGraphNode(gnode);
+ AllocateLanes();
+ fprintf(stdout, "generating trace file.\n");
+ int64 flow_id = 1;
+ for (const auto& process : alloc_nodes_) {
+ for (const auto& lane : process.second) {
+ for (const auto& node : lane.second) {
+ TimeNode* tnode = node.second;
+
+ Json::Value args(Json::objectValue);
+ args["name"] = Json::Value(tnode->name);
+ args["op"] = Json::Value(tnode->name);
+ chrome_formatter_.EmitRegion(node.first, tnode->exec_micros,
+ process.first, lane.first, "Op",
+ tnode->name, args);
+
+ for (TimeNode* next_tnode : node.second->next_tnodes) {
+ chrome_formatter_.EmitFlowStart(
+ tnode->name + "_flow", tnode->start_micros + tnode->exec_micros,
+ process.first, lane.first, flow_id);
+ chrome_formatter_.EmitFlowEnd(
+ tnode->name + "_flow", next_tnode->start_micros,
+ next_tnode->process->pid, next_tnode->tid, flow_id);
+ flow_id += 1;
+ }
+ }
+ }
+ }
+ OutputTimeline();
+}
+
+void Timeline::GenerateScopeTimeline(const ScopeNode* node) {
+ std::set<int64> visited_depth;
+ EmitTreeNode(node, 0, node->proto().total_exec_micros(), 0, &visited_depth);
+ OutputTimeline();
+}
+
+void Timeline::GenerateCodeTimeline(const CodeNode* node) {
+ std::set<int64> visited_depth;
+ EmitTreeNode(node, 0, node->proto().total_exec_micros(), 0, &visited_depth);
+ OutputTimeline();
+}
+
+void Timeline::OutputTimeline() {
+ Status s =
+ WriteStringToFile(Env::Default(), outfile_, chrome_formatter_.Format());
+ if (!s.ok()) {
+ fprintf(stderr, "Failed to write timeline file: %s\nError: %s\n",
+ outfile_.c_str(), s.ToString().c_str());
+ return;
+ }
+ fprintf(stdout, "\n******************************************************\n");
+ fprintf(stdout,
+ "Timeline file is written to %s.\n"
+ "Open a Chrome browser, enter URL chrome://tracing and "
+ "load the timeline file.",
+ outfile_.c_str());
+ fprintf(stdout, "\n******************************************************\n");
+ fflush(stdout);
+}
+
+std::vector<TimeNode*> Timeline::AddGraphNode(const GraphNode* gnode) {
+ std::vector<TimeNode*> tnodes;
+ if (!gnode) return tnodes;
+
+ std::vector<TimeNode*> shown_cinputs;
+ for (GraphNode* schild : gnode->show_children) {
+ std::vector<TimeNode*> inputs = AddGraphNode(schild);
+ shown_cinputs.insert(shown_cinputs.end(), inputs.begin(), inputs.end());
+ }
+ if (!gnode->node->step_stats()) {
+ return shown_cinputs;
+ }
+
+ const TFGraphNode* node = gnode->node;
+ for (const auto& kernel_execs : node->op_kernel_execs()) {
+ const string& device = kernel_execs.first;
+ const std::vector<std::pair<int64, int64>>& execs = kernel_execs.second;
+
+ if (process_.find(device) == process_.end()) {
+ int64 pid = AllocatePID();
+ process_[device].reset(new Process(pid));
+ chrome_formatter_.EmitPID(device, pid);
+ }
+ Process* p = process_[device].get();
+
+ for (const auto& exec : execs) {
+ int64 start_micros = exec.first;
+ int64 exec_micros = exec.second;
+ // TODO(xpan): There might be start time duplication here.
+ if (tnodes_[device].find(start_micros) == tnodes_[device].end()) {
+ // TODO(xpan): Give each kernel call a unique_name.
+ tnodes_[device][start_micros].reset(
+ new TimeNode(p, node->name(), start_micros, exec_micros));
+ }
+ TimeNode* tnode_ptr = tnodes_[device][start_micros].get();
+
+ for (int i = 0; i < shown_cinputs.size(); i++) {
+ shown_cinputs[i]->next_tnodes.push_back(tnode_ptr);
+ }
+ tnodes.push_back(tnode_ptr);
+ }
+ }
+ return tnodes;
+}
+
+void Timeline::AllocateLanes() {
+ for (auto& process : tnodes_) {
+ Process* p = process_[process.first].get();
+ for (auto& tnode : process.second) {
+ int64 start_time = tnode.second->start_micros;
+ int64 end_time = tnode.second->exec_micros - 1;
+
+ int64 l = -1;
+ for (int i = 0; i < p->lanes.size(); ++i) {
+ const auto& lane = p->lanes[i];
+ auto cur_it = lane.lower_bound(start_time);
+ if (cur_it == lane.end()) {
+ --cur_it;
+ }
+ l = i;
+ for (; cur_it != lane.begin(); --cur_it) {
+ if (cur_it->second < start_time) {
+ break;
+ }
+ if (cur_it->first <= end_time) {
+ l = -1;
+ break;
+ }
+ }
+ if (l >= 0) {
+ break;
+ }
+ }
+ if (l < 0) {
+ l = p->lanes.size();
+ std::map<int64, int64> nlane;
+ nlane[start_time] = end_time;
+ p->lanes.push_back(nlane);
+ } else {
+ p->lanes[l][start_time] = end_time;
+ }
+ tnode.second->tid = l;
+ alloc_nodes_[p->pid][l][start_time] = tnode.second.get();
+ }
+ }
+}
+
+int64 Timeline::AllocatePID() {
+ int64 cur_pid = next_pid_;
+ next_pid_ += 1;
+ return cur_pid;
+}
+
+} // namespace tfprof
+} // namespace tensorflow
diff --git a/tensorflow/tools/tfprof/internal/tfprof_timeline.h b/tensorflow/tools/tfprof/internal/tfprof_timeline.h
new file mode 100644
index 0000000000..3d26874abd
--- /dev/null
+++ b/tensorflow/tools/tfprof/internal/tfprof_timeline.h
@@ -0,0 +1,147 @@
+/* Copyright 2016 The TensorFlow Authors All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_TOOLS_TFPROF_INTERNAL_TFPROF_TIMELINE_H_
+#define THIRD_PARTY_TENSORFLOW_TOOLS_TFPROF_INTERNAL_TFPROF_TIMELINE_H_
+
+#include "include/json/json.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/step_stats.pb.h"
+#include "tensorflow/core/protobuf/config.pb.h"
+#include "tensorflow/tools/tfprof/internal/tfprof_node_show.h"
+
+namespace tensorflow {
+namespace tfprof {
+
+typedef std::map<string, string> Event;
+
+class ChromeTraceFormatter {
+ public:
+ ChromeTraceFormatter() {}
+
+ Json::Value CreateEvent(const string& ph, const string& category,
+ const string& name, int64 pid, int64 tid, int64 ts);
+
+ void EmitPID(const string& name, int64 pid);
+
+ void EmitRegion(int64 ts, int64 duration, int64 pid, int64 tid,
+ const string& category, const string& name, Json::Value args);
+
+ void EmitFlowStart(const string& name, int64 ts, int64 pid, int64 tid,
+ int64 flow_id);
+
+ void EmitFlowEnd(const string& name, int64 ts, int64 pid, int64 tid,
+ int64 flow_id);
+
+ string Format();
+
+ private:
+ std::vector<Json::Value> events_;
+ std::vector<Json::Value> metadata_;
+};
+
+class Process {
+ public:
+ Process(int64 pid) : pid(pid) {}
+
+ // Each lane is a map from start_time to end_time.
+ std::vector<std::map<int64, int64>> lanes;
+ int64 pid;
+};
+
+class TimeNode {
+ public:
+ TimeNode(Process* process, const string& name, int64 start_micros,
+ int64 exec_micros)
+ : process(process),
+ name(name),
+ start_micros(start_micros),
+ exec_micros(exec_micros),
+ tid(-1) {}
+ virtual ~TimeNode() {}
+
+ Process* process;
+ string name;
+ int64 start_micros;
+ int64 exec_micros;
+ int64 tid;
+ std::vector<TimeNode*> next_tnodes;
+};
+
+class Timeline {
+ public:
+ Timeline(const string& outfile) : outfile_(outfile) {}
+ ~Timeline() {}
+
+ void GenerateGraphTimeline(const GraphNode* gnode);
+
+ void GenerateScopeTimeline(const ScopeNode* node);
+
+ void GenerateCodeTimeline(const CodeNode* node);
+
+ private:
+ void OutputTimeline();
+
+ template <typename Node>
+ void EmitTreeNode(const Node* node, int64 start_time, int64 duration,
+ int64 depth, std::set<int64>* visited_depth) {
+ if (visited_depth->find(depth) == visited_depth->end()) {
+ chrome_formatter_.EmitPID(strings::StrCat("Scope:", depth), depth);
+ visited_depth->insert(depth);
+ }
+
+ Json::Value args(Json::objectValue);
+ args["name"] = Json::Value(node->name());
+ args["op"] = Json::Value(node->name());
+ chrome_formatter_.EmitRegion(start_time, duration, depth, 0, "Op",
+ node->name(), args);
+
+ int64 total_micros = 0;
+ int64 c_start_time = start_time;
+ for (const Node* child : node->show_children) {
+ int64 total_exec_micros = child->proto().total_exec_micros();
+ if (total_exec_micros <= 0) {
+ continue;
+ }
+ EmitTreeNode(child, c_start_time, total_exec_micros, depth + 1,
+ visited_depth);
+ c_start_time += total_exec_micros;
+ total_micros += total_exec_micros;
+ }
+ CHECK(total_micros <= duration) << node->name() << " parent:" << duration
+ << " children:" << total_micros;
+ }
+
+ std::vector<TimeNode*> AddGraphNode(const GraphNode* gnode);
+
+ void AllocateLanes();
+
+ int64 AllocatePID();
+
+ const string outfile_;
+ int64 next_pid_ = 0;
+ int64 allocator_pid_ = -1;
+ ChromeTraceFormatter chrome_formatter_;
+ std::map<string, int64> device_pids_;
+
+ std::map<string, std::unique_ptr<Process>> process_;
+ std::map<int64, std::map<int64, std::map<int64, TimeNode*>>> alloc_nodes_;
+ std::map<string, std::map<int64, std::unique_ptr<TimeNode>>> tnodes_;
+};
+
+} // namespace tfprof
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_TOOLS_TFPROF_INTERNAL_TFPROF_TIMELINE_H_
diff --git a/tensorflow/tools/tfprof/internal/tfprof_timeline_test.cc b/tensorflow/tools/tfprof/internal/tfprof_timeline_test.cc
new file mode 100644
index 0000000000..2dfe6ab335
--- /dev/null
+++ b/tensorflow/tools/tfprof/internal/tfprof_timeline_test.cc
@@ -0,0 +1,92 @@
+/* Copyright 2016 The TensorFlow Authors All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/tools/tfprof/internal/tfprof_stats.h"
+
+#include <utility>
+
+#include "tensorflow/c/checkpoint_reader.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/protobuf/config.pb.h"
+#include "tensorflow/tools/tfprof/internal/tfprof_constants.h"
+#include "tensorflow/tools/tfprof/internal/tfprof_options.h"
+#include "tensorflow/tools/tfprof/internal/tfprof_utils.h"
+#include "tensorflow/tools/tfprof/tfprof_log.pb.h"
+#include "tensorflow/tools/tfprof/tfprof_output.pb.h"
+
+namespace tensorflow {
+namespace tfprof {
+class TFProfTimelineTest : public ::testing::Test {
+ protected:
+ TFProfTimelineTest() {
+ string graph_path =
+ io::JoinPath(testing::TensorFlowSrcRoot(),
+ "tools/tfprof/internal/testdata/graph.pbtxt");
+ std::unique_ptr<tensorflow::GraphDef> graph_pb(new tensorflow::GraphDef());
+ TF_CHECK_OK(ReadGraphDef(Env::Default(), graph_path, graph_pb.get()));
+
+ std::unique_ptr<tensorflow::RunMetadata> run_meta_pb(
+ new tensorflow::RunMetadata());
+ string run_meta_path =
+ io::JoinPath(testing::TensorFlowSrcRoot(),
+ "tools/tfprof/internal/testdata/run_meta");
+ TF_CHECK_OK(
+ ReadBinaryProto(Env::Default(), run_meta_path, run_meta_pb.get()));
+
+ tf_stats_.reset(new TFStats(std::move(graph_pb), std::move(run_meta_pb),
+ nullptr, nullptr));
+ }
+
+ std::unique_ptr<TFStats> tf_stats_;
+};
+
+// Before adding test, first dump the json file and
+// manually check it's correct
+TEST_F(TFProfTimelineTest, GraphView) {
+ string dump_file = io::JoinPath(testing::TmpDir(), "dump");
+ Options opts(10000, 0, 0, 0, 0, {".*"}, "name",
+ {".*"}, // accout_type_regexes
+ {".*"}, {""}, {".*"}, {""}, false,
+ {"params", "bytes", "micros", "float_ops", "num_hidden_ops"},
+ "timeline", {{"outfile", dump_file}});
+ tf_stats_->PrintGraph("graph", opts);
+
+ string dump_str;
+ TF_CHECK_OK(ReadFileToString(Env::Default(), dump_file, &dump_str));
+ EXPECT_EQ(14171250174278825648ull, Hash64(dump_str));
+}
+
+TEST_F(TFProfTimelineTest, ScopeView) {
+ string dump_file = io::JoinPath(testing::TmpDir(), "dump");
+ Options opts(5, 0, 0, 0, 0, {".*"}, "name", {".*"}, // accout_type_regexes
+ {".*"}, {""}, {".*"}, {""}, false,
+ {"params", "bytes", "micros", "float_ops", "num_hidden_ops"},
+ "timeline", {{"outfile", dump_file}});
+ tf_stats_->PrintGraph("scope", opts);
+
+ string dump_str;
+ TF_CHECK_OK(ReadFileToString(Env::Default(), dump_file, &dump_str));
+ EXPECT_EQ(2355241164346147404ull, Hash64(dump_str));
+}
+
+// TODO(xpan): tfprof_log is too large to include in testdata when adding
+// code traces.
+
+} // namespace tfprof
+} // namespace tensorflow
diff --git a/tensorflow/tools/tfprof/internal/tfprof_utils.cc b/tensorflow/tools/tfprof/internal/tfprof_utils.cc
index eb2be35c6e..8e55e009d3 100644
--- a/tensorflow/tools/tfprof/internal/tfprof_utils.cc
+++ b/tensorflow/tools/tfprof/internal/tfprof_utils.cc
@@ -251,19 +251,13 @@ tensorflow::Status ParseCmdLine(const string& line, string* cmd,
opts->select = requested_set;
++i;
} else if (pieces[i] == tensorflow::tfprof::kOptions[14]) {
- if ((pieces.size() > i + 1 && pieces[i + 1].find("-") == 0) ||
- pieces.size() == i + 1) {
- opts->viz = true;
- } else if (!StringToBool(pieces[i + 1], &opts->viz)) {
- return ReturnError(pieces, i);
- } else {
- ++i;
- }
- } else if (pieces[i] == tensorflow::tfprof::kOptions[15]) {
if (pieces.size() <= i + 1) {
return ReturnError(pieces, i);
}
- opts->dump_to_file = StripQuote(pieces[i + 1]);
+
+ tensorflow::Status s =
+ ParseOutput(pieces[i + 1], &opts->output_type, &opts->output_options);
+ if (!s.ok()) return s;
++i;
} else {
return ReturnError(pieces, i);
diff --git a/tensorflow/tools/tfprof/tfprof_main.cc b/tensorflow/tools/tfprof/tfprof_main.cc
index dd18a4ad3c..cfe239da22 100644
--- a/tensorflow/tools/tfprof/tfprof_main.cc
+++ b/tensorflow/tools/tfprof/tfprof_main.cc
@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/tools/tfprof/internal/tfprof_options.h"
@@ -82,8 +83,7 @@ int main(int argc, char** argv) {
tensorflow::string FLAGS_hide_name_regexes;
bool FLAGS_account_displayed_op_only = false;
tensorflow::string FLAGS_select = "params";
- bool FLAGS_viz = false;
- tensorflow::string FLAGS_dump_to_file = "";
+ tensorflow::string FLAGS_output = "";
for (int i = 0; i < argc; i++) {
fprintf(stderr, "%s\n", argv[i]);
}
@@ -117,7 +117,7 @@ int main(int argc, char** argv) {
&FLAGS_account_displayed_op_only,
"account displayed op only"),
tensorflow::Flag("select", &FLAGS_select, "select"),
- tensorflow::Flag("dump_to_file", &FLAGS_dump_to_file, "dump to file"),
+ tensorflow::Flag("output", &FLAGS_output, "output"),
};
tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
@@ -144,6 +144,12 @@ int main(int argc, char** argv) {
std::vector<tensorflow::string> select =
Split(FLAGS_select, ',', tensorflow::str_util::SkipEmpty());
+ tensorflow::string output_type;
+ std::map<tensorflow::string, tensorflow::string> output_options;
+ tensorflow::Status s = tensorflow::tfprof::ParseOutput(
+ FLAGS_output, &output_type, &output_options);
+ CHECK(s.ok()) << s.ToString();
+
tensorflow::string cmd = "";
if (argc == 1 && FLAGS_graph_path.empty()) {
printf("1) go/tfprof: Tutorial.\n");
@@ -186,10 +192,18 @@ int main(int argc, char** argv) {
std::unique_ptr<tensorflow::tfprof::OpLog> op_log(
new tensorflow::tfprof::OpLog());
- if (!ReadBinaryProto(tensorflow::Env::Default(), FLAGS_op_log_path,
- op_log.get())
- .ok()) {
- op_log.release();
+ if (!FLAGS_op_log_path.empty()) {
+ tensorflow::string op_log_str;
+ s = tensorflow::ReadFileToString(tensorflow::Env::Default(),
+ FLAGS_op_log_path, &op_log_str);
+ if (!s.ok()) {
+ fprintf(stderr, "Failed to read op_log_path: %s\n", s.ToString().c_str());
+ return 1;
+ }
+ if (!tensorflow::ParseProtoUnlimited(op_log.get(), op_log_str)) {
+ fprintf(stderr, "Failed to parse op_log_path\n");
+ return 1;
+ }
}
std::unique_ptr<tensorflow::checkpoint::CheckpointReader> ckpt_reader;
@@ -212,8 +226,8 @@ int main(int argc, char** argv) {
FLAGS_max_depth, FLAGS_min_bytes, FLAGS_min_micros, FLAGS_min_params,
FLAGS_min_float_ops, device_regexes, FLAGS_order_by, account_type_regexes,
start_name_regexes, trim_name_regexes, show_name_regexes,
- hide_name_regexes, FLAGS_account_displayed_op_only, select, FLAGS_viz,
- FLAGS_dump_to_file);
+ hide_name_regexes, FLAGS_account_displayed_op_only, select, output_type,
+ output_options);
if (cmd == tensorflow::tfprof::kCmds[2]) {
tf_stat.PrintCode(opts);
diff --git a/tensorflow/tools/tfprof/tfprof_options.proto b/tensorflow/tools/tfprof/tfprof_options.proto
index 9d269a0995..84a2e14005 100644
--- a/tensorflow/tools/tfprof/tfprof_options.proto
+++ b/tensorflow/tools/tfprof/tfprof_options.proto
@@ -19,6 +19,6 @@ message OptionsProto {
repeated string hide_name_regexes = 12;
optional bool account_displayed_op_only = 13;
repeated string select = 14;
- optional bool viz = 15;
+ optional string output = 15;
optional string dump_to_file = 16;
}
diff --git a/tensorflow/tools/tfprof/tfprof_output.proto b/tensorflow/tools/tfprof/tfprof_output.proto
index 78dd056662..93e6c1233c 100644
--- a/tensorflow/tools/tfprof/tfprof_output.proto
+++ b/tensorflow/tools/tfprof/tfprof_output.proto
@@ -31,7 +31,8 @@ message TFGraphNodeProto {
// Number of inputs to the op.
optional int64 inputs = 5;
// Device the op is assigned to.
- optional string device = 10;
+ // Since an op can fire multiple kernel calls, there can be multiple devices.
+ repeated string devices = 10;
// The following are the aggregated stats from all accounted descendants and
// the op itself. The actual descendants depend on the data structure used