author     A. Unique TensorFlower <gardener@tensorflow.org>  2017-05-30 22:21:48 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-05-30 22:25:03 -0700
commit     a7fff05e05693f806ae367eb59efb055844380e5 (patch)
tree       c530203c71d5ad8e14ad439ef5961b91e2ce8cee /tensorflow/contrib/tfprof
parent     7aac2395ce92ad9f280a30c0ce110b9fcf494668 (diff)
tfprof multi-step profiling.
This allows users to fill in RunMetadata across different steps.

1. It is useful for RL models which run a subset of the graph each step.
2. It also computes averages of multi-step stats.

PiperOrigin-RevId: 157552388
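
For context, a minimal sketch of the workflow this change enables, adapted from the Profiler docstring added in model_analyzer.py below. The toy graph, step counts, and option tweaks are illustrative placeholders, not part of the change:

    # Sketch of multi-step profiling with the new Profiler API (assumes TF 1.x
    # with contrib.tfprof). The counter graph here is a stand-in for a real model.
    import tensorflow as tf
    from tensorflow.contrib.tfprof.python.tools.tfprof import model_analyzer

    x = tf.Variable(1.0)
    train_op = tf.assign_add(x, 1.0)

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())

      # Only one profiler may be created per process (per the docstring below).
      profiler = model_analyzer.Profiler(sess.graph)

      for step in range(1000):
        if step % 100 == 0:
          run_meta = tf.RunMetadata()
          sess.run(train_op,
                   options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
                   run_metadata=run_meta)
          # Fill in this step's RunMetadata; stats accumulate across steps.
          profiler.add_step(step, run_meta)
        else:
          sess.run(train_op)

      # Report timing/memory aggregated over all added steps.
      opts = model_analyzer.PRINT_ALL_TIMING_MEMORY.copy()
      opts['order_by'] = 'micros'
      opts['select'] = ['micros', 'occurrence']
      profiler.profile_operations(options=opts)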
Diffstat (limited to 'tensorflow/contrib/tfprof')
-rw-r--r--  tensorflow/contrib/tfprof/python/tools/tfprof/BUILD                                     |  16
-rw-r--r--  tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py                         | 202
-rw-r--r--  tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py                    |   1
-rw-r--r--  tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_testlib.py                 |  30
-rw-r--r--  tensorflow/contrib/tfprof/python/tools/tfprof/profiler_test.py                          | 184
-rw-r--r--  tensorflow/contrib/tfprof/python/tools/tfprof/pywrap_tensorflow_print_model_analysis.i  |   6
-rw-r--r--  tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py                          |  35
7 files changed, 435 insertions, 39 deletions
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/BUILD b/tensorflow/contrib/tfprof/python/tools/tfprof/BUILD
index c96f6719e7..93d0cb4099 100644
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/BUILD
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/BUILD
@@ -33,6 +33,22 @@ py_test(
],
)
+py_test(
+ name = "profiler_test",
+ srcs = ["profiler_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":model_analyzer",
+ ":model_analyzer_testlib",
+ "//tensorflow/python:client",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:variables",
+ ],
+)
+
py_library(
name = "model_analyzer_testlib",
srcs = ["model_analyzer_testlib.py"],
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py
index fcce3cd45b..b640fa7593 100644
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py
@@ -112,6 +112,181 @@ PRINT_ALL_TIMING_MEMORY = {
# pylint: enable=bad-continuation
+def _build_options(tfprof_options):
+ """Build tfprof.OptionsProto.
+
+ Args:
+ tfprof_options: A dictionary of options.
+ Returns:
+ tfprof.OptionsProto.
+ """
+ opts = tfprof_options_pb2.OptionsProto()
+ opts.max_depth = tfprof_options.get('max_depth', 10)
+ opts.min_bytes = tfprof_options.get('min_bytes', 0)
+ opts.min_micros = tfprof_options.get('min_micros', 0)
+ opts.min_params = tfprof_options.get('min_params', 0)
+ opts.min_float_ops = tfprof_options.get('min_float_ops', 0)
+ opts.min_occurrence = tfprof_options.get('min_occurrence', 0)
+
+ opts.step = tfprof_options.get('step', -1)
+
+ opts.order_by = tfprof_options.get('order_by', 'name')
+
+ for p in tfprof_options.get('account_type_regexes', []):
+ opts.account_type_regexes.append(p)
+ for p in tfprof_options.get('start_name_regexes', []):
+ opts.start_name_regexes.append(p)
+ for p in tfprof_options.get('trim_name_regexes', []):
+ opts.trim_name_regexes.append(p)
+ for p in tfprof_options.get('show_name_regexes', []):
+ opts.show_name_regexes.append(p)
+ for p in tfprof_options.get('hide_name_regexes', []):
+ opts.hide_name_regexes.append(p)
+ opts.account_displayed_op_only = tfprof_options.get(
+ 'account_displayed_op_only', False)
+
+ for p in tfprof_options.get('select', []):
+ opts.select.append(p)
+
+ opts.output = tfprof_options.get('output', 'stdout')
+ opts.dump_to_file = tfprof_options.get('dump_to_file', '')
+
+ return opts
+
+
+class Profiler(object):
+ """TensorFlow multi-step profiler.
+
+ See go/tfprof or README for details.
+
+ Typical use case:
+ # Currently we are only allowed to create 1 profiler per process.
+ profiler = Profiler(sess.graph)
+
+ for i in xrange(total_steps):
+ if i % 10000 == 0:
+ run_meta = tf.RunMetadata()
+ _ = sess.run(...,
+ options=tf.RunOptions(
+ trace_level=tf.RunOptions.FULL_TRACE),
+ run_metadata=run_meta)
+ profiler.add_step(i, run_meta)
+
+ # Profile the parameters of your model.
+ profiler.profile_name_scope(options=TRAINABLE_VARS_PARAMS_STAT_OPTIONS)
+
+ # Or profile the timing of your model operations.
+ opts = PRINT_ALL_TIMING_MEMORY.copy()
+ opts['order_by'] = 'micros'
+ opts['select'] = ['micros', 'occurrence']
+ opts['max_depth'] = 20
+ profiler.profile_operations(options=opts)
+
+ # Or you can generate a timeline:
+ opts = PRINT_ALL_TIMING_MEMORY.copy()
+ opts['output'] = 'timeline:outfile=' + filename
+ opts['step'] = i
+ profiler.profile_graph(options=opts)
+ else:
+ _ = sess.run(...)
+ """
+
+ def __init__(self, graph, op_log=None):
+ """Constructor.
+
+ Args:
+ graph: tf.Graph.
+ op_log: optional. tensorflow::tfprof::OpLog proto. Used to define
+ extra op types.
+ """
+ self._graph = graph
+ # pylint: disable=protected-access
+ op_log = tfprof_logger._merge_default_with_oplog(
+ self._graph, op_log=op_log)
+ # pylint: enable=protected-access
+
+ print_mdl.NewProfiler(
+ self._graph.as_graph_def().SerializeToString(),
+ op_log.SerializeToString())
+
+ def __del__(self):
+ print_mdl.DeleteProfiler()
+
+ def add_step(self, step, run_meta):
+ """Add statistics of a step.
+
+ Args:
+ step: A uint64 value used to identify the RunMetadata. Must be different
+ across different add_step() calls.
+ run_meta: RunMetadata proto that contains statistics of a session run.
+ """
+ # pylint: disable=protected-access
+ op_log = tfprof_logger._merge_default_with_oplog(
+ self._graph, run_meta=run_meta, add_trace=False,
+ add_trainable_var=False)
+ # pylint: enable=protected-access
+ print_mdl.AddStep(
+ step, run_meta.SerializeToString(), op_log.SerializeToString())
+
+ def profile_python_codes(self, options):
+ """Profile the statistics of the Python codes.
+
+ Hint: set options['show_name_regexes'] = ['.*my_code.py.*']
+
+ Args:
+ options: A dict of profiler options.
+ Returns:
+ a TFMultiGraphNodeProto that records the results.
+ """
+ opts = _build_options(options)
+ tfprof_node = tfprof_output_pb2.TFMultiGraphNodeProto()
+ tfprof_node.ParseFromString(
+ print_mdl.Profile('code'.encode('utf-8'), opts.SerializeToString()))
+ return tfprof_node
+
+ def profile_operations(self, options):
+ """Profile the statistics of the Operation types (e.g. MatMul, Conv2D).
+
+ Args:
+ options: A dict of profiler options.
+ Returns:
+ a TFMultiGraphNodeProto that records the results.
+ """
+ opts = _build_options(options)
+ tfprof_node = tfprof_output_pb2.TFMultiGraphNodeProto()
+ tfprof_node.ParseFromString(
+ print_mdl.Profile('op'.encode('utf-8'), opts.SerializeToString()))
+ return tfprof_node
+
+ def profile_name_scope(self, options):
+ """Profile the statistics of graph nodes, organized by name scope.
+
+ Args:
+ options: A dict of profiler options.
+ Returns:
+ a TFGraphNodeProto that records the results.
+ """
+ opts = _build_options(options)
+ tfprof_node = tfprof_output_pb2.TFGraphNodeProto()
+ tfprof_node.ParseFromString(
+ print_mdl.Profile('scope'.encode('utf-8'), opts.SerializeToString()))
+ return tfprof_node
+
+ def profile_graph(self, options):
+ """Profile the statistics of graph nodes, organized by dataflow graph.
+
+ Args:
+ options: A dict of profiler options.
+ Returns:
+ a TFGraphNodeProto that records the results.
+ """
+ opts = _build_options(options)
+ tfprof_node = tfprof_output_pb2.TFGraphNodeProto()
+ tfprof_node.ParseFromString(
+ print_mdl.Profile('graph'.encode('utf-8'), opts.SerializeToString()))
+ return tfprof_node
+
+
def print_model_analysis(graph,
run_meta=None,
op_log=None,
@@ -145,33 +320,8 @@ def print_model_analysis(graph,
op_log = tfprof_logger._merge_default_with_oplog(
graph, op_log, run_meta, add_trace=tfprof_cmd == 'code')
# pylint: enable=protected-access
- opts = tfprof_options_pb2.OptionsProto()
- opts.max_depth = tfprof_options['max_depth']
- opts.min_bytes = tfprof_options['min_bytes']
- opts.min_micros = tfprof_options['min_micros']
- opts.min_params = tfprof_options['min_params']
- opts.min_float_ops = tfprof_options['min_float_ops']
- if 'min_occurrence' in tfprof_options:
- opts.min_occurrence = tfprof_options['min_occurrence']
- else:
- opts.min_occurrence = 0
- opts.order_by = tfprof_options['order_by']
- for p in tfprof_options['account_type_regexes']:
- opts.account_type_regexes.append(p)
- for p in tfprof_options['start_name_regexes']:
- opts.start_name_regexes.append(p)
- for p in tfprof_options['trim_name_regexes']:
- opts.trim_name_regexes.append(p)
- for p in tfprof_options['show_name_regexes']:
- opts.show_name_regexes.append(p)
- for p in tfprof_options['hide_name_regexes']:
- opts.hide_name_regexes.append(p)
- opts.account_displayed_op_only = tfprof_options['account_displayed_op_only']
- for p in tfprof_options['select']:
- opts.select.append(p)
- opts.output = tfprof_options['output']
- opts.dump_to_file = tfprof_options['dump_to_file']
+ opts = _build_options(tfprof_options)
run_meta_str = run_meta.SerializeToString() if run_meta else b''
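
The refactoring above also introduces a 'step' option, which _build_options fills with -1 when the key is absent (the test change below pins it to 0 for the timeline case). A minimal sketch of using it to restrict a report to one recorded step, reusing `profiler` and the imports from the sketch near the top of this page; the output path is an arbitrary example:

    # Sketch: generate a timeline for a single recorded step via the new
    # 'step' option. Assumes a step was previously registered with
    # profiler.add_step(100, run_meta).
    opts = model_analyzer.PRINT_ALL_TIMING_MEMORY.copy()
    opts['account_type_regexes'] = ['.*']
    opts['output'] = 'timeline:outfile=/tmp/tfprof_timeline'  # example path
    opts['step'] = 100  # restrict the report to the step added as 100
    profiler.profile_graph(options=opts)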
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py
index 14f8fcff9d..2c55924e83 100644
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py
@@ -199,6 +199,7 @@ class PrintModelAnalysisTest(test.TestCase):
opts['output'] = 'timeline:outfile=' + outfile
opts['account_type_regexes'] = ['.*']
opts['max_depth'] = 100000
+ opts['step'] = 0
with session.Session() as sess, ops.device('/cpu:0'):
x = lib.BuildFullModel()
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_testlib.py b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_testlib.py
index fea08b2db8..42b83fde7c 100644
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_testlib.py
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_testlib.py
@@ -65,3 +65,33 @@ def BuildFullModel():
loss = nn_ops.l2_loss(math_ops.reduce_mean(target - out))
sgd_op = gradient_descent.GradientDescentOptimizer(1e-2)
return sgd_op.minimize(loss)
+
+
+def BuildSplitableModel():
+ """Build a small model that can be run partially in each step."""
+ image = array_ops.zeros([2, 6, 6, 3])
+
+ kernel1 = variable_scope.get_variable(
+ 'DW', [3, 3, 3, 6],
+ dtypes.float32,
+ initializer=init_ops.random_normal_initializer(stddev=0.001))
+ r1 = nn_ops.conv2d(image, kernel1, [1, 2, 2, 1], padding='SAME')
+
+ kernel2 = variable_scope.get_variable(
+ 'DW2', [2, 3, 3, 6],
+ dtypes.float32,
+ initializer=init_ops.random_normal_initializer(stddev=0.001))
+ r2 = nn_ops.conv2d(image, kernel2, [1, 2, 2, 1], padding='SAME')
+
+ r3 = r1 + r2
+ return r1, r2, r3
+
+
+def SearchTFProfNode(node, name):
+ """Search a node in the tree."""
+ if node.name == name:
+ return node
+ for c in node.children:
+ r = SearchTFProfNode(c, name)
+ if r: return r
+ return None
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/profiler_test.py b/tensorflow/contrib/tfprof/python/tools/tfprof/profiler_test.py
new file mode 100644
index 0000000000..d705a9b725
--- /dev/null
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/profiler_test.py
@@ -0,0 +1,184 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import test
+
+# pylint: disable=g-bad-import-order
+from tensorflow.contrib.tfprof.python.tools.tfprof import model_analyzer
+from tensorflow.contrib.tfprof.python.tools.tfprof import model_analyzer_testlib as lib
+
+
+class ProfilerTest(test.TestCase):
+
+ def testProfileBasic(self):
+ ops.reset_default_graph()
+ opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS.copy()
+ opts['account_type_regexes'] = ['.*']
+ opts['select'] = ['params', 'float_ops', 'micros', 'bytes',
+ 'device', 'op_types', 'occurrence']
+ outfile = os.path.join(test.get_temp_dir(), 'dump')
+ opts['output'] = 'file:outfile=' + outfile
+
+ # Test the output without run_meta.
+ sess = session.Session()
+ r = lib.BuildFullModel()
+ sess.run(variables.global_variables_initializer())
+
+ profiler = model_analyzer.Profiler(sess.graph)
+ profiler.profile_name_scope(opts)
+ with gfile.Open(outfile, 'r') as f:
+ profiler_str = f.read()
+
+ model_analyzer.print_model_analysis(
+ sess.graph, tfprof_cmd='scope', tfprof_options=opts)
+ with gfile.Open(outfile, 'r') as f:
+ pma_str = f.read()
+ self.assertEqual(pma_str, profiler_str)
+
+ # Test the output with run_meta.
+ run_meta = config_pb2.RunMetadata()
+ _ = sess.run(r,
+ options=config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE),
+ run_metadata=run_meta)
+
+ profiler.add_step(1, run_meta)
+ profiler.profile_graph(opts)
+ with gfile.Open(outfile, 'r') as f:
+ profiler_str = f.read()
+
+ model_analyzer.print_model_analysis(
+ sess.graph, tfprof_cmd='graph', run_meta=run_meta, tfprof_options=opts)
+ with gfile.Open(outfile, 'r') as f:
+ pma_str = f.read()
+ self.assertEqual(pma_str, profiler_str)
+
+ profiler.profile_python_codes(opts)
+ with gfile.Open(outfile, 'r') as f:
+ profiler_str = f.read()
+
+ model_analyzer.print_model_analysis(
+ sess.graph, tfprof_cmd='code', run_meta=run_meta, tfprof_options=opts)
+ with gfile.Open(outfile, 'r') as f:
+ pma_str = f.read()
+ self.assertEqual(pma_str, profiler_str)
+
+ profiler.profile_operations(opts)
+ with gfile.Open(outfile, 'r') as f:
+ profiler_str = f.read()
+
+ model_analyzer.print_model_analysis(
+ sess.graph, tfprof_cmd='op', run_meta=run_meta, tfprof_options=opts)
+ with gfile.Open(outfile, 'r') as f:
+ pma_str = f.read()
+ self.assertEqual(pma_str, profiler_str)
+
+ # Test the output difference between multi-step profile and 1-step profile.
+ _ = sess.run(r,
+ options=config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE),
+ run_metadata=run_meta)
+
+ profiler.add_step(2, run_meta)
+ profiler.profile_name_scope(opts)
+ with gfile.Open(outfile, 'r') as f:
+ profiler_str = f.read()
+
+ model_analyzer.print_model_analysis(
+ sess.graph, tfprof_cmd='scope', run_meta=run_meta, tfprof_options=opts)
+ with gfile.Open(outfile, 'r') as f:
+ pma_str = f.read()
+ self.assertNotEqual(pma_str, profiler_str)
+
+ opts2 = opts.copy()
+ opts2['select'] = ['params', 'float_ops']
+ profiler.profile_name_scope(opts2)
+ with gfile.Open(outfile, 'r') as f:
+ profiler_str = f.read()
+
+ model_analyzer.print_model_analysis(
+ sess.graph, tfprof_cmd='scope', run_meta=run_meta, tfprof_options=opts2)
+ with gfile.Open(outfile, 'r') as f:
+ pma_str = f.read()
+ self.assertEqual(pma_str, profiler_str)
+
+ def testMultiStepProfile(self):
+ ops.reset_default_graph()
+ opts = model_analyzer.PRINT_ALL_TIMING_MEMORY.copy()
+ opts['account_type_regexes'] = ['.*']
+
+ with session.Session() as sess, ops.device('/cpu:0'):
+ r1, r2, r3 = lib.BuildSplitableModel()
+ sess.run(variables.global_variables_initializer())
+
+ profiler = model_analyzer.Profiler(sess.graph)
+ pb0 = profiler.profile_name_scope(opts)
+
+ run_meta = config_pb2.RunMetadata()
+ _ = sess.run(r1,
+ options=config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE),
+ run_metadata=run_meta)
+ profiler.add_step(1, run_meta)
+ pb1 = profiler.profile_name_scope(opts)
+
+ self.assertNotEqual(lib.SearchTFProfNode(pb1, 'DW'), None)
+ self.assertEqual(lib.SearchTFProfNode(pb1, 'DW2'), None)
+ self.assertEqual(lib.SearchTFProfNode(pb1, 'add'), None)
+
+ run_meta2 = config_pb2.RunMetadata()
+ _ = sess.run(r2,
+ options=config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE),
+ run_metadata=run_meta2)
+ profiler.add_step(2, run_meta2)
+ pb2 = profiler.profile_name_scope(opts)
+
+ self.assertNotEqual(lib.SearchTFProfNode(pb2, 'DW'), None)
+ self.assertNotEqual(lib.SearchTFProfNode(pb2, 'DW2'), None)
+ self.assertEqual(lib.SearchTFProfNode(pb2, 'add'), None)
+
+ run_meta3 = config_pb2.RunMetadata()
+ _ = sess.run(r3,
+ options=config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE),
+ run_metadata=run_meta3)
+ profiler.add_step(3, run_meta3)
+ pb3 = profiler.profile_name_scope(opts)
+
+ self.assertNotEqual(lib.SearchTFProfNode(pb3, 'DW'), None)
+ self.assertNotEqual(lib.SearchTFProfNode(pb3, 'DW2'), None)
+ self.assertNotEqual(lib.SearchTFProfNode(pb3, 'add'), None)
+
+ self.assertEqual(lib.SearchTFProfNode(pb0, 'Conv2D'), None)
+ self.assertGreater(lib.SearchTFProfNode(pb1, 'Conv2D').exec_micros, 0)
+ self.assertEqual(lib.SearchTFProfNode(pb1, 'Conv2D_1'), None)
+ self.assertGreater(lib.SearchTFProfNode(pb2, 'Conv2D_1').exec_micros, 0)
+ self.assertEqual(lib.SearchTFProfNode(pb2, 'add'), None)
+ self.assertGreater(lib.SearchTFProfNode(pb3, 'add').exec_micros, 0)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/pywrap_tensorflow_print_model_analysis.i b/tensorflow/contrib/tfprof/python/tools/tfprof/pywrap_tensorflow_print_model_analysis.i
index 05b734a699..582c36e339 100644
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/pywrap_tensorflow_print_model_analysis.i
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/pywrap_tensorflow_print_model_analysis.i
@@ -19,6 +19,8 @@ limitations under the License.
%{
#include "tensorflow/tools/tfprof/internal/print_model_analysis.h"
#include "tensorflow/core/framework/types.h"
+
+using tensorflow::int64;
%}
%typemap(typecheck) const string & = char *;
@@ -37,6 +39,10 @@ limitations under the License.
%unignore tensorflow;
%unignore tensorflow::tfprof;
%unignore tensorflow::tfprof::PrintModelAnalysis;
+%unignore tensorflow::tfprof::NewProfiler;
+%unignore tensorflow::tfprof::DeleteProfiler;
+%unignore tensorflow::tfprof::AddStep;
+%unignore tensorflow::tfprof::Profile;
%include "tensorflow/tools/tfprof/internal/print_model_analysis.h"
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py b/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py
index e6d504d516..52febef26c 100644
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py
@@ -62,13 +62,16 @@ def _fill_missing_graph_shape(graph, run_meta):
return graph
-def _get_logged_ops(graph, run_meta=None, add_trace=True):
+def _get_logged_ops(graph, run_meta=None, add_trace=True,
+ add_trainable_var=True):
"""Extract trainable model parameters and FLOPs for ops from a Graph.
Args:
graph: tf.Graph.
run_meta: RunMetadata proto used to complete shape information.
add_trace: Whether to add op trace information.
+ add_trainable_var: Whether to assign the op type '_trainable_variables'
+ to ops from tf.trainable_variables().
Returns:
logged_ops: dict mapping from op_name to OpLogEntry.
"""
@@ -77,6 +80,7 @@ def _get_logged_ops(graph, run_meta=None, add_trace=True):
op_missing_shape = 0
logged_ops = {}
+ # TODO(xpan): Work with Profiler more efficiently.
for op in graph.get_operations():
try:
stats = ops.get_stats_for_node_def(
@@ -105,23 +109,24 @@ def _get_logged_ops(graph, run_meta=None, add_trace=True):
if add_entry:
logged_ops[entry.name] = entry
- for v in graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES):
- if v.op.name not in logged_ops:
- entry = tfprof_log_pb2.OpLogEntry()
- entry.name = v.op.name
- entry.types.append(TRAINABLE_VARIABLES)
- logged_ops[entry.name] = entry
- else:
- logged_ops[v.op.name].types.append(TRAINABLE_VARIABLES)
+ if add_trainable_var:
+ for v in graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES):
+ if v.op.name not in logged_ops:
+ entry = tfprof_log_pb2.OpLogEntry()
+ entry.name = v.op.name
+ entry.types.append(TRAINABLE_VARIABLES)
+ logged_ops[entry.name] = entry
+ else:
+ logged_ops[v.op.name].types.append(TRAINABLE_VARIABLES)
+
if op_missing_shape > 0 and not run_meta:
- sys.stderr.write('%d ops no flops stats due to incomplete shapes. '
- 'Consider passing run_meta to use run_time shapes.\n' %
+ sys.stderr.write('%d ops no flops stats due to incomplete shapes.\n' %
op_missing_shape)
return logged_ops
def _merge_default_with_oplog(graph, op_log=None, run_meta=None,
- add_trace=True):
+ add_trace=True, add_trainable_var=True):
"""Merge the tfprof default extra info with caller's op_log.
Args:
@@ -129,11 +134,15 @@ def _merge_default_with_oplog(graph, op_log=None, run_meta=None,
op_log: OpLog proto.
run_meta: RunMetadata proto used to complete shape information.
add_trace: Whether to add op trace information.
+ add_trainable_var: Whether to assign the op type '_trainable_variables'
+ to ops from tf.trainable_variables().
Returns:
tmp_op_log: Merged OpLog proto.
"""
tmp_op_log = tfprof_log_pb2.OpLog()
- logged_ops = _get_logged_ops(graph, run_meta, add_trace=add_trace)
+ logged_ops = _get_logged_ops(
+ graph, run_meta, add_trace=add_trace, add_trainable_var=add_trainable_var)
+
if not op_log:
tmp_op_log.log_entries.extend(logged_ops.values())
else: