aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tfprof
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-05-05 16:14:24 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-05 17:30:49 -0700
commit1e59f00c4803ef242500454b6e704a142db33222 (patch)
tree05b39aab2a4a2fee8a288eb156125d5c65a935a8 /tensorflow/contrib/tfprof
parentfdb4eba5b1cd0f2a2b10f83042a7e0eec1a41548 (diff)
Extend tfprof to associate op stats with Python codes.
It's backward compatible. Stats of a source code line are aggregated from all ops created by that line. A example. _TFProfRoot (0us/22.44ms) model_analyzer_test.py:149:run_filename_as_m...:none (0us/22.44ms) model_analyzer_test.py:33:_run_code_in_main:none (0us/22.44ms) model_analyzer_test.py:208:<module>:test.main() (0us/22.44ms) model_analyzer_test.py:132:testComplexCodeView:x = lib.BuildFull... (0us/22.44ms) model_analyzer_testlib.py:63:BuildFullModel:return sgd_op.min... (0us/21.83ms) model_analyzer_testlib.py:54:BuildFullModel:seq.append(array_... (0us/254us) model_analyzer_testlib.py:42:BuildSmallModel:x = nn_ops.conv2d... (0us/134us) ... model_analyzer_testlib.py:61:BuildFullModel:loss = nn_ops.l2_... (0us/28us) model_analyzer_test.py:134:testComplexCodeView:sess.run(variable... (0us/0us) Change: 155258346
Diffstat (limited to 'tensorflow/contrib/tfprof')
-rw-r--r--tensorflow/contrib/tfprof/README.md10
-rw-r--r--tensorflow/contrib/tfprof/python/tools/tfprof/BUILD21
-rw-r--r--tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py41
-rw-r--r--tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py150
-rw-r--r--tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_testlib.py67
-rw-r--r--tensorflow/contrib/tfprof/python/tools/tfprof/print_model_analysis_test.py9
-rw-r--r--tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py45
7 files changed, 283 insertions, 60 deletions
diff --git a/tensorflow/contrib/tfprof/README.md b/tensorflow/contrib/tfprof/README.md
index c7ff4a2921..d891ecdc9a 100644
--- a/tensorflow/contrib/tfprof/README.md
+++ b/tensorflow/contrib/tfprof/README.md
@@ -11,7 +11,12 @@ Consultants: Jon Shlens, Pete Warden
1. Measure model parameters, float operations, tensor shapes.
2. Measure op execution times, requested memory size and device placement.
3. Inspect checkpoint tensors' shapes and their values.
-4. Explore model based on name scope or graph structure.
+4. 3 ways to view and explore TensorFlow model profiles
+
+ * Organize by Python code call stack.
+ * Organize by TensorFlow operation name scope hierarchies.
+ * Organize by TensorFlow operation inputs/outputs graph.
+
5. Selectively grouping/filtering/accounting/ordering ops.
tfprof can be used as Python API, Interactive CLI and One-shot Script.
@@ -28,7 +33,8 @@ param_stats = tf.contrib.tfprof.model_analyzer.print_model_analysis(
tfprof_options=tf.contrib.tfprof.model_analyzer.
TRAINABLE_VARS_PARAMS_STAT_OPTIONS)
-# param_stats is tensorflow.tfprof.TFProfNode proto. It organize the statistics
+# param_stats is tensorflow.tfprof.TFGraphNodeProto proto.
+# It organize the statistics
# of each graph node in tree scructure. Let's print the root below.
sys.stdout.write('total_params: %d\n' % param_stats.total_parameters)
```
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/BUILD b/tensorflow/contrib/tfprof/python/tools/tfprof/BUILD
index 818c2d2cbf..22bca93c87 100644
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/BUILD
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/BUILD
@@ -23,14 +23,31 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":model_analyzer",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:array_ops",
+ ":model_analyzer_testlib",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:variables",
+ ],
+)
+
+py_library(
+ name = "model_analyzer_testlib",
+ srcs = ["model_analyzer_testlib.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":model_analyzer",
+ "//tensorflow/contrib/rnn:rnn_py",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:init_ops",
+ "//tensorflow/python:math_ops",
"//tensorflow/python:nn_ops",
"//tensorflow/python:platform",
+ "//tensorflow/python:rnn",
+ "//tensorflow/python:training",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
],
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py
index cc94fd65b5..13b407d815 100644
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py
@@ -123,7 +123,7 @@ def print_model_analysis(graph,
"""Print model statistics.
Prints the model statistics to stdout. Also returns the results
- in a TFProfNode proto. See go/tfprof or run tfprof tool:
+ in a TFGraphNodeProto proto. See go/tfprof or run tfprof tool:
'bazel run third_party/tensorflow/tools/tfprof help'
Examples:
@@ -142,15 +142,19 @@ def print_model_analysis(graph,
'micros' and 'bytes'.
op_log: tensorflow::tfprof::OpLog proto. users can use this proto to
group together ops and use a op_type to select the group.
- tfprof_cmd: string. Either 'scope' or 'graph'. 'scope' view organize
- ops using their name scopes. 'graph' view organize ops using
- their graph inputs.
+ tfprof_cmd: string. Either 'scope', 'graph', 'code'.
+ 'scope' view organize outputs using ops' name scope.
+ 'graph' view organize outputs using op's inputs/outputs.
+ 'code' view organize outputs using Python call stack.
tfprof_options: See 'tfprof help' for details.
Returns:
- TFProfNode proto. Side effect: a formatted output to stdout.
+ If tfprof_cmd is 'scope' or 'graph', returns TFGraphNodeProto proto.
+ If tfprof_cmd is 'code', returns TFCodeNodeProto proto.
+ Side effect: a formatted output to stdout.
"""
# pylint: disable=protected-access
- op_log = tfprof_logger._merge_default_with_oplog(graph, op_log, run_meta)
+ op_log = tfprof_logger._merge_default_with_oplog(
+ graph, op_log, run_meta, add_trace=tfprof_cmd == 'code')
# pylint: enable=protected-access
opts = tfprof_options_pb2.OptionsProto()
opts.max_depth = tfprof_options['max_depth']
@@ -178,11 +182,24 @@ def print_model_analysis(graph,
opts.dump_to_file = tfprof_options['dump_to_file']
run_meta_str = run_meta.SerializeToString() if run_meta else b''
- op_log_str = op_log.SerializeToString() if op_log else b''
- tfprof_node = tfprof_output_pb2.TFProfNode()
- tfprof_node.ParseFromString(
- print_mdl.PrintModelAnalysis(
- graph.as_graph_def().SerializeToString(), run_meta_str, op_log_str,
- tfprof_cmd.encode('utf-8'), opts.SerializeToString()))
+ if tfprof_cmd == 'code':
+ tfprof_node = tfprof_output_pb2.TFCodeNodeProto()
+ tfprof_node.ParseFromString(
+ print_mdl.PrintModelAnalysis(
+ graph.as_graph_def().SerializeToString(),
+ run_meta_str,
+ op_log.SerializeToString(),
+ tfprof_cmd.encode('utf-8'),
+ opts.SerializeToString()))
+ else:
+ tfprof_node = tfprof_output_pb2.TFGraphNodeProto()
+ tfprof_node.ParseFromString(
+ print_mdl.PrintModelAnalysis(
+ graph.as_graph_def().SerializeToString(),
+ run_meta_str,
+ op_log.SerializeToString(),
+ tfprof_cmd.encode('utf-8'),
+ opts.SerializeToString()))
+
return tfprof_node
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py
index 66b9267cbe..ac0d46d4ae 100644
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py
@@ -18,49 +18,27 @@ from __future__ import division
from __future__ import print_function
import os
-
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import nn_ops
-from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
# XXX: this depends on pywrap_tensorflow and must come later
from tensorflow.contrib.tfprof.python.tools.tfprof import model_analyzer
+from tensorflow.contrib.tfprof.python.tools.tfprof import model_analyzer_testlib as lib
class PrintModelAnalysisTest(test.TestCase):
- def _BuildSmallModel(self):
- image = array_ops.zeros([2, 6, 6, 3])
- _ = variable_scope.get_variable(
- 'ScalarW', [],
- dtypes.float32,
- initializer=init_ops.random_normal_initializer(stddev=0.001))
- kernel = variable_scope.get_variable(
- 'DW', [3, 3, 3, 6],
- dtypes.float32,
- initializer=init_ops.random_normal_initializer(stddev=0.001))
- x = nn_ops.conv2d(image, kernel, [1, 2, 2, 1], padding='SAME')
- kernel = variable_scope.get_variable(
- 'DW2', [2, 2, 6, 12],
- dtypes.float32,
- initializer=init_ops.random_normal_initializer(stddev=0.001))
- x = nn_ops.conv2d(x, kernel, [1, 2, 2, 1], padding='SAME')
- return x
-
def testDumpToFile(self):
+ ops.reset_default_graph()
opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS
opts['dump_to_file'] = os.path.join(test.get_temp_dir(), 'dump')
with session.Session() as sess, ops.device('/cpu:0'):
- _ = self._BuildSmallModel()
+ _ = lib.BuildSmallModel()
model_analyzer.print_model_analysis(sess.graph, tfprof_options=opts)
with gfile.Open(opts['dump_to_file'], 'r') as f:
@@ -71,6 +49,7 @@ class PrintModelAnalysisTest(test.TestCase):
f.read())
def testSelectEverything(self):
+ ops.reset_default_graph()
opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS
opts['dump_to_file'] = os.path.join(test.get_temp_dir(), 'dump')
opts['account_type_regexes'] = ['.*']
@@ -78,8 +57,10 @@ class PrintModelAnalysisTest(test.TestCase):
'bytes', 'params', 'float_ops', 'num_hidden_ops', 'device', 'op_types'
]
- with session.Session() as sess, ops.device('/cpu:0'):
- x = self._BuildSmallModel()
+ config = config_pb2.ConfigProto(
+ graph_options=config_pb2.GraphOptions(build_cost_model=1))
+ with session.Session(config=config) as sess, ops.device('/cpu:0'):
+ x = lib.BuildSmallModel()
sess.run(variables.global_variables_initializer())
run_meta = config_pb2.RunMetadata()
@@ -98,6 +79,121 @@ class PrintModelAnalysisTest(test.TestCase):
f.read())
# pylint: enable=line-too-long
+ def testSimpleCodeView(self):
+ ops.reset_default_graph()
+ opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS.copy()
+ opts['dump_to_file'] = os.path.join(test.get_temp_dir(), 'dump')
+ opts['account_type_regexes'] = ['.*']
+ opts['show_name_regexes'] = ['.*model_analyzer_testlib.*']
+ opts['account_displayed_op_only'] = False
+ # TODO(xpan): Test 'micros'. Since the execution time changes each run,
+ # it's a bit difficult to test it now.
+ opts['select'] = [
+ 'bytes', 'params', 'float_ops', 'num_hidden_ops', 'device',
+ ]
+
+ config = config_pb2.ConfigProto(
+ graph_options=config_pb2.GraphOptions(build_cost_model=1))
+ with session.Session(config=config) as sess, ops.device('/cpu:0'):
+ x = lib.BuildSmallModel()
+
+ sess.run(variables.global_variables_initializer())
+ run_meta = config_pb2.RunMetadata()
+ _ = sess.run(x,
+ options=config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE),
+ run_metadata=run_meta)
+
+ model_analyzer.print_model_analysis(
+ sess.graph, run_meta, tfprof_cmd='code', tfprof_options=opts)
+
+ with gfile.Open(opts['dump_to_file'], 'r') as f:
+ # pylint: disable=line-too-long
+ self.assertEqual(
+ '_TFProfRoot (0/451 params, 0/10.44k flops, 0B/5.28KB)\n model_analyzer_testlib.py:33:BuildSmallModel:image = array_ops... (0/0 params, 0/0 flops, 0B/864B)\n model_analyzer_testlib.py:37:BuildSmallModel:initializer=init_... (0/1 params, 0/0 flops, 0B/0B)\n model_analyzer_testlib.py:41:BuildSmallModel:initializer=init_... (0/162 params, 0/0 flops, 0B/1.30KB)\n model_analyzer_testlib.py:42:BuildSmallModel:x = nn_ops.conv2d... (0/0 params, 0/5.83k flops, 0B/432B)\n model_analyzer_testlib.py:46:BuildSmallModel:initializer=init_... (0/288 params, 0/0 flops, 0B/2.30KB)\n model_analyzer_testlib.py:47:BuildSmallModel:x = nn_ops.conv2d... (0/0 params, 0/4.61k flops, 0B/384B)\n',
+ f.read())
+ # pylint: enable=line-too-long
+
+ def testComplexCodeView(self):
+ ops.reset_default_graph()
+ opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS.copy()
+ opts['dump_to_file'] = os.path.join(test.get_temp_dir(), 'dump')
+ opts['account_type_regexes'] = ['.*']
+ opts['show_name_regexes'] = ['.*model_analyzer_testlib.py.*']
+ opts['account_displayed_op_only'] = False
+ opts['select'] = [
+ 'bytes', 'params', 'float_ops', 'num_hidden_ops', 'device',
+ ]
+
+ config = config_pb2.ConfigProto(
+ graph_options=config_pb2.GraphOptions(build_cost_model=1))
+ with session.Session(config=config) as sess, ops.device('/cpu:0'):
+ x = lib.BuildFullModel()
+
+ sess.run(variables.global_variables_initializer())
+ run_meta = config_pb2.RunMetadata()
+ _ = sess.run(x,
+ options=config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE),
+ run_metadata=run_meta)
+
+ tfprof_node = model_analyzer.print_model_analysis(
+ sess.graph, run_meta, tfprof_cmd='code', tfprof_options=opts)
+
+ # pylint: disable=line-too-long
+ with gfile.Open(opts['dump_to_file'], 'r') as f:
+ self.assertEqual(
+ '_TFProfRoot (0/2.84k params, 0/54.08k flops, 0B/241.58KB)\n model_analyzer_testlib.py:56:BuildFullModel:seq.append(array_... (0/1.80k params, 0/41.76k flops, 0B/20.08KB)\n model_analyzer_testlib.py:33:BuildSmallModel:image = array_ops... (0/0 params, 0/0 flops, 0B/864B)\n model_analyzer_testlib.py:37:BuildSmallModel:initializer=init_... (0/4 params, 0/0 flops, 0B/0B)\n model_analyzer_testlib.py:41:BuildSmallModel:initializer=init_... (0/648 params, 0/0 flops, 0B/5.18KB)\n model_analyzer_testlib.py:42:BuildSmallModel:x = nn_ops.conv2d... (0/0 params, 0/23.33k flops, 0B/1.73KB)\n model_analyzer_testlib.py:46:BuildSmallModel:initializer=init_... (0/1.15k params, 0/0 flops, 0B/9.22KB)\n model_analyzer_testlib.py:47:BuildSmallModel:x = nn_ops.conv2d... (0/0 params, 0/18.43k flops, 0B/1.54KB)\n model_analyzer_testlib.py:60:BuildFullModel:cell, array_ops.c... (0/1.04k params, 0/4.13k flops, 0B/24.86KB)\n model_analyzer_testlib.py:62:BuildFullModel:target = array_op... (0/0 params, 0/0 flops, 0B/0B)\n model_analyzer_testlib.py:63:BuildFullModel:loss = nn_ops.l2_... (0/0 params, 0/0 flops, 0B/528B)\n model_analyzer_testlib.py:65:BuildFullModel:return sgd_op.min... (0/0 params, 0/8.19k flops, 0B/196.12KB)\n',
+ f.read())
+
+ self.assertEqual(241584, tfprof_node.total_requested_bytes)
+ self.assertLess(0, tfprof_node.total_exec_micros)
+ self.assertEqual(2844, tfprof_node.total_parameters)
+ self.assertEqual(54080, tfprof_node.total_float_ops)
+ self.assertEqual(5, len(tfprof_node.children))
+ self.assertEqual('_TFProfRoot', tfprof_node.name)
+ self.assertEqual('model_analyzer_testlib.py:56:BuildFullModel:seq.append(array_...',
+ tfprof_node.children[0].name)
+ self.assertEqual('model_analyzer_testlib.py:60:BuildFullModel:cell, array_ops.c...',
+ tfprof_node.children[1].name)
+ self.assertEqual('model_analyzer_testlib.py:62:BuildFullModel:target = array_op...',
+ tfprof_node.children[2].name)
+ self.assertEqual('model_analyzer_testlib.py:63:BuildFullModel:loss = nn_ops.l2_...',
+ tfprof_node.children[3].name)
+ self.assertEqual('model_analyzer_testlib.py:65:BuildFullModel:return sgd_op.min...',
+ tfprof_node.children[4].name)
+ # pylint: enable=line-too-long
+
+ def testCodeViewLeafGraphNode(self):
+ ops.reset_default_graph()
+ opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS.copy()
+ opts['account_type_regexes'] = ['.*']
+ opts['account_displayed_op_only'] = False
+ opts['select'] = [
+ 'bytes', 'params', 'float_ops', 'num_hidden_ops', 'device',
+ ]
+
+ config = config_pb2.ConfigProto(
+ graph_options=config_pb2.GraphOptions(build_cost_model=1))
+ with session.Session(config=config) as sess, ops.device('/cpu:0'):
+ x = lib.BuildSmallModel()
+
+ sess.run(variables.global_variables_initializer())
+ run_meta = config_pb2.RunMetadata()
+ _ = sess.run(x,
+ options=config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE),
+ run_metadata=run_meta)
+
+ tfprof_node = model_analyzer.print_model_analysis(
+ sess.graph, run_meta, tfprof_cmd='code', tfprof_options=opts)
+
+ leaf = tfprof_node
+ while leaf.children:
+ self.assertEqual(0, len(leaf.graph_nodes))
+ leaf = leaf.children[0]
+ self.assertEqual(1, len(leaf.graph_nodes))
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_testlib.py b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_testlib.py
new file mode 100644
index 0000000000..81bac84b8c
--- /dev/null
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_testlib.py
@@ -0,0 +1,67 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A test lib that defines some models."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.rnn.python.ops.core_rnn_cell import BasicRNNCell
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import rnn
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.training import gradient_descent
+
+
+def BuildSmallModel():
+ """Build a small forward conv model."""
+ image = array_ops.zeros([2, 6, 6, 3])
+ _ = variable_scope.get_variable(
+ 'ScalarW', [],
+ dtypes.float32,
+ initializer=init_ops.random_normal_initializer(stddev=0.001))
+ kernel = variable_scope.get_variable(
+ 'DW', [3, 3, 3, 6],
+ dtypes.float32,
+ initializer=init_ops.random_normal_initializer(stddev=0.001))
+ x = nn_ops.conv2d(image, kernel, [1, 2, 2, 1], padding='SAME')
+ kernel = variable_scope.get_variable(
+ 'DW2', [2, 2, 6, 12],
+ dtypes.float32,
+ initializer=init_ops.random_normal_initializer(stddev=0.001))
+ x = nn_ops.conv2d(x, kernel, [1, 2, 2, 1], padding='SAME')
+ return x
+
+
+def BuildFullModel():
+ """Build the full model with conv,rnn,opt."""
+ seq = []
+ for i in xrange(4):
+ with variable_scope.variable_scope('inp_%d' % i):
+ seq.append(array_ops.reshape(BuildSmallModel(), [2, 1, -1]))
+
+ cell = BasicRNNCell(16, 48)
+ out = rnn.dynamic_rnn(
+ cell, array_ops.concat(seq, axis=1), dtype=dtypes.float32)[0]
+
+ target = array_ops.ones_like(out)
+ loss = nn_ops.l2_loss(math_ops.reduce_mean(target - out))
+ sgd_op = gradient_descent.GradientDescentOptimizer(1e-2)
+ return sgd_op.minimize(loss)
+
+
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/print_model_analysis_test.py b/tensorflow/contrib/tfprof/python/tools/tfprof/print_model_analysis_test.py
index f0ac36c66a..aa133d3142 100644
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/print_model_analysis_test.py
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/print_model_analysis_test.py
@@ -96,12 +96,13 @@ class PrintModelAnalysisTest(test.TestCase):
with session.Session() as sess, ops.device('/cpu:0'):
_ = self._BuildSmallModel()
- tfprof_pb = tfprof_output_pb2.TFProfNode()
+ tfprof_pb = tfprof_output_pb2.TFGraphNodeProto()
tfprof_pb.ParseFromString(
- print_mdl.PrintModelAnalysis(sess.graph.as_graph_def(
- ).SerializeToString(), b'', b'', b'scope', opts.SerializeToString()))
+ print_mdl.PrintModelAnalysis(
+ sess.graph.as_graph_def().SerializeToString(),
+ b'', b'', b'scope', opts.SerializeToString()))
- expected_pb = tfprof_output_pb2.TFProfNode()
+ expected_pb = tfprof_output_pb2.TFGraphNodeProto()
text_format.Merge(r"""name: "_TFProfRoot"
exec_micros: 0
requested_bytes: 0
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py b/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py
index e8cf84b6c7..cd3912bbfb 100644
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py
@@ -62,12 +62,13 @@ def _fill_missing_graph_shape(graph, run_meta):
return graph
-def _get_logged_ops(graph, run_meta=None):
+def _get_logged_ops(graph, run_meta=None, add_trace=False):
"""Extract trainable model parameters and FLOPs for ops from a Graph.
Args:
graph: tf.Graph.
run_meta: RunMetadata proto used to complete shape information.
+ add_trace: Whether to add op trace information.
Returns:
logged_ops: dict mapping from op_name to OpLogEntry.
"""
@@ -76,21 +77,32 @@ def _get_logged_ops(graph, run_meta=None):
op_missing_shape = 0
logged_ops = {}
- graph_def = graph.as_graph_def()
- for node in graph_def.node:
+ for op in graph.get_operations():
try:
- stats = ops.get_stats_for_node_def(graph, node, REGISTERED_FLOP_STATS)
+ stats = ops.get_stats_for_node_def(
+ graph, op.node_def, REGISTERED_FLOP_STATS)
except ValueError:
# Catch Exception When shape is incomplete. Skip it.
op_missing_shape += 1
stats = None
- if not stats or not stats.value:
- continue
- if node.name not in logged_ops:
- entry = tfprof_log_pb2.OpLogEntry()
- entry.name = node.name
+ entry = tfprof_log_pb2.OpLogEntry()
+ entry.name = op.name
+ add_entry = False
+ if stats and stats.value:
entry.float_ops = int(stats.value)
+ add_entry = True
+
+ if add_trace:
+ for tb in op.traceback:
+ trace = entry.code_def.traces.add()
+ trace.file = tb[0] if tb[0] else 'none'
+ trace.lineno = tb[1] if tb[1] else -1
+ trace.function = tb[2] if tb[2] else 'none'
+ trace.line = tb[3] if tb[3] else 'none'
+ add_entry = True
+
+ if add_entry:
logged_ops[entry.name] = entry
for v in graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES):
@@ -108,18 +120,21 @@ def _get_logged_ops(graph, run_meta=None):
return logged_ops
-def _merge_default_with_oplog(graph, op_log=None, run_meta=None):
+def _merge_default_with_oplog(graph, op_log=None,
+ run_meta=None,
+ add_trace=False):
"""Merge the tfprof default extra info with caller's op_log.
Args:
graph: tf.Graph.
op_log: OpLog proto.
run_meta: RunMetadata proto used to complete shape information.
+ add_trace: Whether to add op trace information.
Returns:
tmp_op_log: Merged OpLog proto.
"""
tmp_op_log = tfprof_log_pb2.OpLog()
- logged_ops = _get_logged_ops(graph, run_meta)
+ logged_ops = _get_logged_ops(graph, run_meta, add_trace=add_trace)
if not op_log:
tmp_op_log.log_entries.extend(logged_ops.values())
else:
@@ -131,13 +146,16 @@ def _merge_default_with_oplog(graph, op_log=None, run_meta=None):
all_ops[op_name].types.extend(entry.types)
if entry.float_ops > 0 and all_ops[op_name].float_ops == 0:
all_ops[op_name].float_ops = entry.float_ops
+ if entry.code_def.traces and not all_ops[op_name].code_def.traces:
+ all_ops[op_name].code_def.MergeFrom(entry.code_def)
else:
all_ops[op_name] = entry
tmp_op_log.log_entries.extend(all_ops.values())
return tmp_op_log
-def write_op_log(graph, log_dir, op_log=None, run_meta=None):
+def write_op_log(graph, log_dir, op_log=None, run_meta=None,
+ add_trace=False):
"""Log provided 'op_log', and add additional model information below.
The API also assigns ops in tf.trainable_variables() an op type called
@@ -154,8 +172,9 @@ def write_op_log(graph, log_dir, op_log=None, run_meta=None):
one is created.
run_meta: (Optional) RunMetadata proto that helps flops computation using
run time shape information.
+ add_trace: Whether to add op trace information. Used to support "code" view.
"""
- op_log = _merge_default_with_oplog(graph, op_log, run_meta)
+ op_log = _merge_default_with_oplog(graph, op_log, run_meta, add_trace)
with gfile.Open(os.path.join(log_dir, 'tfprof_log'), 'w') as log:
log.write(op_log.SerializeToString())