From 98ccdcd2a0a0835a2804a00d66fcd6014df099bc Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Thu, 11 May 2017 16:23:15 -0700 Subject: Automated g4 rollback of changelist 155779520 PiperOrigin-RevId: 155811666 --- tensorflow/contrib/tfprof/README.md | 78 +++++++++++++++++++--- .../tfprof/python/tools/tfprof/model_analyzer.py | 10 +-- .../python/tools/tfprof/model_analyzer_test.py | 38 ++++++----- .../tools/tfprof/print_model_analysis_test.py | 9 ++- .../tfprof/python/tools/tfprof/tfprof_logger.py | 10 +-- 5 files changed, 106 insertions(+), 39 deletions(-) (limited to 'tensorflow/contrib/tfprof') diff --git a/tensorflow/contrib/tfprof/README.md b/tensorflow/contrib/tfprof/README.md index 5bfa0247a5..d891ecdc9a 100644 --- a/tensorflow/contrib/tfprof/README.md +++ b/tensorflow/contrib/tfprof/README.md @@ -2,25 +2,81 @@ # Full Docment in tensorflow/tools/tfprof/README.md -Author: Xin Pan (xpan@google.com, github: panyx0718), Jon Shlens, Yao Zhang +Author: Xin Pan (xpan@google.com, github: panyx0718) Consultants: Jon Shlens, Pete Warden ###Major Features 1. Measure model parameters, float operations, tensor shapes. -2. Profile op execution times, requested memory size and device placement. +2. Measure op execution times, requested memory size and device placement. 3. Inspect checkpoint tensors' shapes and their values. -4. Selectively group, filter, account and order ops. +4. 3 ways to view and explore TensorFlow model profiles -####tfprof supports 3 views to organize TensorFlow model profiles + * Organize by Python code call stack. + * Organize by TensorFlow operation name scope hierarchies. + * Organize by TensorFlow operation inputs/outputs graph. - * code view: Stats are associated your Python codes and organized as call stacks. - * scope view: Stats are organized as name scope hierarchies. - * graph view: Stats are organized as Tensorflow Op graph. +5. Selectively grouping/filtering/accounting/ordering ops. -####For each view, there are 3 ways to display outputs: +tfprof can be used as Python API, Interactive CLI and One-shot Script. - * stdout: Results are written to stdout. - * timeline: Visualized in chrome browser as time series. - * file: Results are dumped to file. +## Python API Tutorials + +tfprof is part of TensorFlow core. Simply ```import tensorflow as tf```. + +### Examine the shapes and sizes of all trainiable Variables. +```python +# Print trainable variable parameter statistics to stdout. +param_stats = tf.contrib.tfprof.model_analyzer.print_model_analysis( + tf.get_default_graph(), + tfprof_options=tf.contrib.tfprof.model_analyzer. + TRAINABLE_VARS_PARAMS_STAT_OPTIONS) + +# param_stats is tensorflow.tfprof.TFGraphNodeProto proto. +# It organize the statistics +# of each graph node in tree scructure. Let's print the root below. +sys.stdout.write('total_params: %d\n' % param_stats.total_parameters) +``` + +### Examine the number of floating point operations +``` python +# Print to stdout an analysis of the number of floating point operations in the +# model broken down by individual operations. +# +# Note: Only Ops with RegisterStatistics('flops') defined have flop stats. It +# also requires complete shape information. It is common that shape is unknown +# statically. To complete the shape, provide run-time shape information with +# tf.RunMetadata to the API (See next example on how to provide RunMetadata). +tf.contrib.tfprof.model_analyzer.print_model_analysis( + tf.get_default_graph(), + tfprof_options=tf.contrib.tfprof.model_analyzer.FLOAT_OPS_OPTIONS) +``` + +### Examine the timing and memory usage +You will first need to run the following set up in your model in order to +compute the memory and timing statistics. + +```python +# Generate the meta information for the model that contains the memory usage +# and timing information. +run_metadata = tf.RunMetadata() +with tf.Session() as sess: + _ = sess.run(train_op, + options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE), + run_metadata=run_metadata) +``` + +Finally, you may run `print_model_analysis` to explore the timing and memory +demands of the model. + +``` python +# Print to stdout an analysis of the memory usage and the timing information +# from running the graph broken down by operations. +tf.contrib.tfprof.model_analyzer.print_model_analysis( + tf.get_default_graph(), + run_meta=run_metadata, + tfprof_options=tf.contrib.tfprof.model_analyzer.PRINT_ALL_TIMING_MEMORY) +``` + +Users can change ```tfprof_options``` to fully leverage tfprof's power. diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py index 17dff69edd..13b407d815 100644 --- a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py +++ b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py @@ -45,7 +45,7 @@ TRAINABLE_VARS_PARAMS_STAT_OPTIONS = { 'hide_name_regexes': [], 'account_displayed_op_only': True, 'select': ['params'], - 'output': 'stdout', + 'viz': False, 'dump_to_file': '' } @@ -65,7 +65,7 @@ FLOAT_OPS_OPTIONS = { 'hide_name_regexes': [], 'account_displayed_op_only': True, 'select': ['float_ops'], - 'output': 'stdout', + 'viz': False, 'dump_to_file': '' } @@ -87,7 +87,7 @@ PRINT_PARAMS_ON_DEVICE = { 'hide_name_regexes': [], 'account_displayed_op_only': False, 'select': ['device', 'params'], - 'output': 'stdout', + 'viz': False, 'dump_to_file': '' } @@ -107,7 +107,7 @@ PRINT_ALL_TIMING_MEMORY = { 'hide_name_regexes': [], 'account_displayed_op_only': True, 'select': ['micros', 'bytes'], - 'output': 'stdout', + 'viz': False, 'dump_to_file': '' } @@ -178,7 +178,7 @@ def print_model_analysis(graph, opts.account_displayed_op_only = tfprof_options['account_displayed_op_only'] for p in tfprof_options['select']: opts.select.append(p) - opts.output = tfprof_options['output'] + opts.viz = tfprof_options['viz'] opts.dump_to_file = tfprof_options['dump_to_file'] run_meta_str = run_meta.SerializeToString() if run_meta else b'' diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py index afd8563e78..55167576e2 100644 --- a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py +++ b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py @@ -35,14 +35,13 @@ class PrintModelAnalysisTest(test.TestCase): def testDumpToFile(self): ops.reset_default_graph() opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS - outfile = os.path.join(test.get_temp_dir(), 'dump') - opts['output'] = 'file:outfile=' + outfile + opts['dump_to_file'] = os.path.join(test.get_temp_dir(), 'dump') with session.Session() as sess, ops.device('/cpu:0'): _ = lib.BuildSmallModel() model_analyzer.print_model_analysis(sess.graph, tfprof_options=opts) - with gfile.Open(outfile, 'r') as f: + with gfile.Open(opts['dump_to_file'], 'r') as f: self.assertEqual(u'_TFProfRoot (--/451 params)\n' ' DW (3x3x3x6, 162/162 params)\n' ' DW2 (2x2x6x12, 288/288 params)\n' @@ -52,14 +51,15 @@ class PrintModelAnalysisTest(test.TestCase): def testSelectEverything(self): ops.reset_default_graph() opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS - outfile = os.path.join(test.get_temp_dir(), 'dump') - opts['output'] = 'file:outfile=' + outfile + opts['dump_to_file'] = os.path.join(test.get_temp_dir(), 'dump') opts['account_type_regexes'] = ['.*'] opts['select'] = [ 'bytes', 'params', 'float_ops', 'num_hidden_ops', 'device', 'op_types' ] - with session.Session() as sess, ops.device('/cpu:0'): + config = config_pb2.ConfigProto( + graph_options=config_pb2.GraphOptions(build_cost_model=1)) + with session.Session(config=config) as sess, ops.device('/cpu:0'): x = lib.BuildSmallModel() sess.run(variables.global_variables_initializer()) @@ -72,18 +72,17 @@ class PrintModelAnalysisTest(test.TestCase): model_analyzer.print_model_analysis( sess.graph, run_meta, tfprof_options=opts) - with gfile.Open(outfile, 'r') as f: + with gfile.Open(opts['dump_to_file'], 'r') as f: # pylint: disable=line-too-long self.assertEqual( - '_TFProfRoot (0/451 params, 0/10.44k flops, 0B/5.28KB, _kTFScopeParent)\n Conv2D (0/0 params, 5.83k/5.83k flops, 432B/432B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Conv2D)\n Conv2D_1 (0/0 params, 4.61k/4.61k flops, 384B/384B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Conv2D)\n DW (3x3x3x6, 162/162 params, 0/0 flops, 648B/1.30KB, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|VariableV2|_trainable_variables)\n DW/Assign (0/0 params, 0/0 flops, 0B/0B, Assign)\n DW/Initializer (0/0 params, 0/0 flops, 0B/0B, _kTFScopeParent)\n DW/Initializer/random_normal (0/0 params, 0/0 flops, 0B/0B, Add)\n DW/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, 0B/0B, RandomStandardNormal)\n DW/Initializer/random_normal/mean (0/0 params, 0/0 flops, 0B/0B, Const)\n DW/Initializer/random_normal/mul (0/0 params, 0/0 flops, 0B/0B, Mul)\n DW/Initializer/random_normal/shape (0/0 params, 0/0 flops, 0B/0B, Const)\n DW/Initializer/random_normal/stddev (0/0 params, 0/0 flops, 0B/0B, Const)\n DW/read (0/0 params, 0/0 flops, 648B/648B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Identity)\n DW2 (2x2x6x12, 288/288 params, 0/0 flops, 1.15KB/2.30KB, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|VariableV2|_trainable_variables)\n DW2/Assign (0/0 params, 0/0 flops, 0B/0B, Assign)\n DW2/Initializer (0/0 params, 0/0 flops, 0B/0B, _kTFScopeParent)\n DW2/Initializer/random_normal (0/0 params, 0/0 flops, 0B/0B, Add)\n DW2/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, 0B/0B, RandomStandardNormal)\n DW2/Initializer/random_normal/mean (0/0 params, 0/0 flops, 0B/0B, Const)\n DW2/Initializer/random_normal/mul (0/0 params, 0/0 flops, 0B/0B, Mul)\n DW2/Initializer/random_normal/shape (0/0 params, 0/0 flops, 0B/0B, Const)\n DW2/Initializer/random_normal/stddev (0/0 params, 0/0 flops, 0B/0B, Const)\n DW2/read (0/0 params, 0/0 flops, 1.15KB/1.15KB, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Identity)\n ScalarW (1, 1/1 params, 0/0 flops, 0B/0B, VariableV2|_trainable_variables)\n ScalarW/Assign (0/0 params, 0/0 flops, 0B/0B, Assign)\n ScalarW/Initializer (0/0 params, 0/0 flops, 0B/0B, _kTFScopeParent)\n ScalarW/Initializer/random_normal (0/0 params, 0/0 flops, 0B/0B, Add)\n ScalarW/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, 0B/0B, RandomStandardNormal)\n ScalarW/Initializer/random_normal/mean (0/0 params, 0/0 flops, 0B/0B, Const)\n ScalarW/Initializer/random_normal/mul (0/0 params, 0/0 flops, 0B/0B, Mul)\n ScalarW/Initializer/random_normal/shape (0/0 params, 0/0 flops, 0B/0B, Const)\n ScalarW/Initializer/random_normal/stddev (0/0 params, 0/0 flops, 0B/0B, Const)\n ScalarW/read (0/0 params, 0/0 flops, 0B/0B, Identity)\n init (0/0 params, 0/0 flops, 0B/0B, NoOp)\n zeros (0/0 params, 0/0 flops, 864B/864B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Const)\n', + '_TFProfRoot (0/451 params, 0/10.44k flops, 0B/5.28KB, _kTFScopeParent)\n Conv2D (0/0 params, 5.83k/5.83k flops, 432B/432B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Conv2D)\n Conv2D_1 (0/0 params, 4.61k/4.61k flops, 384B/384B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Conv2D)\n DW (3x3x3x6, 162/162 params, 0/0 flops, 648B/1.30KB, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|VariableV2|_trainable_variables)\n DW/Assign (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Assign)\n DW/Initializer (0/0 params, 0/0 flops, 0B/0B, _kTFScopeParent)\n DW/Initializer/random_normal (0/0 params, 0/0 flops, 0B/0B, Add)\n DW/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, 0B/0B, RandomStandardNormal)\n DW/Initializer/random_normal/mean (0/0 params, 0/0 flops, 0B/0B, Const)\n DW/Initializer/random_normal/mul (0/0 params, 0/0 flops, 0B/0B, Mul)\n DW/Initializer/random_normal/shape (0/0 params, 0/0 flops, 0B/0B, Const)\n DW/Initializer/random_normal/stddev (0/0 params, 0/0 flops, 0B/0B, Const)\n DW/read (0/0 params, 0/0 flops, 648B/648B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Identity)\n DW2 (2x2x6x12, 288/288 params, 0/0 flops, 1.15KB/2.30KB, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|VariableV2|_trainable_variables)\n DW2/Assign (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Assign)\n DW2/Initializer (0/0 params, 0/0 flops, 0B/0B, _kTFScopeParent)\n DW2/Initializer/random_normal (0/0 params, 0/0 flops, 0B/0B, Add)\n DW2/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, 0B/0B, RandomStandardNormal)\n DW2/Initializer/random_normal/mean (0/0 params, 0/0 flops, 0B/0B, Const)\n DW2/Initializer/random_normal/mul (0/0 params, 0/0 flops, 0B/0B, Mul)\n DW2/Initializer/random_normal/shape (0/0 params, 0/0 flops, 0B/0B, Const)\n DW2/Initializer/random_normal/stddev (0/0 params, 0/0 flops, 0B/0B, Const)\n DW2/read (0/0 params, 0/0 flops, 1.15KB/1.15KB, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Identity)\n ScalarW (1, 1/1 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|VariableV2|_trainable_variables)\n ScalarW/Assign (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Assign)\n ScalarW/Initializer (0/0 params, 0/0 flops, 0B/0B, _kTFScopeParent)\n ScalarW/Initializer/random_normal (0/0 params, 0/0 flops, 0B/0B, Add)\n ScalarW/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, 0B/0B, RandomStandardNormal)\n ScalarW/Initializer/random_normal/mean (0/0 params, 0/0 flops, 0B/0B, Const)\n ScalarW/Initializer/random_normal/mul (0/0 params, 0/0 flops, 0B/0B, Mul)\n ScalarW/Initializer/random_normal/shape (0/0 params, 0/0 flops, 0B/0B, Const)\n ScalarW/Initializer/random_normal/stddev (0/0 params, 0/0 flops, 0B/0B, Const)\n ScalarW/read (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Identity)\n init (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|NoOp)\n zeros (0/0 params, 0/0 flops, 864B/864B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Const)\n', f.read()) # pylint: enable=line-too-long def testSimpleCodeView(self): ops.reset_default_graph() opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS.copy() - outfile = os.path.join(test.get_temp_dir(), 'dump') - opts['output'] = 'file:outfile=' + outfile + opts['dump_to_file'] = os.path.join(test.get_temp_dir(), 'dump') opts['account_type_regexes'] = ['.*'] opts['show_name_regexes'] = ['.*model_analyzer_testlib.*'] opts['account_displayed_op_only'] = False @@ -93,7 +92,9 @@ class PrintModelAnalysisTest(test.TestCase): 'bytes', 'params', 'float_ops', 'num_hidden_ops', 'device', ] - with session.Session() as sess, ops.device('/cpu:0'): + config = config_pb2.ConfigProto( + graph_options=config_pb2.GraphOptions(build_cost_model=1)) + with session.Session(config=config) as sess, ops.device('/cpu:0'): x = lib.BuildSmallModel() sess.run(variables.global_variables_initializer()) @@ -106,7 +107,7 @@ class PrintModelAnalysisTest(test.TestCase): model_analyzer.print_model_analysis( sess.graph, run_meta, tfprof_cmd='code', tfprof_options=opts) - with gfile.Open(outfile, 'r') as f: + with gfile.Open(opts['dump_to_file'], 'r') as f: # pylint: disable=line-too-long self.assertEqual( '_TFProfRoot (0/451 params, 0/10.44k flops, 0B/5.28KB)\n model_analyzer_testlib.py:33:BuildSmallModel:image = array_ops... (0/0 params, 0/0 flops, 0B/864B)\n model_analyzer_testlib.py:37:BuildSmallModel:initializer=init_... (0/1 params, 0/0 flops, 0B/0B)\n model_analyzer_testlib.py:41:BuildSmallModel:initializer=init_... (0/162 params, 0/0 flops, 0B/1.30KB)\n model_analyzer_testlib.py:42:BuildSmallModel:x = nn_ops.conv2d... (0/0 params, 0/5.83k flops, 0B/432B)\n model_analyzer_testlib.py:46:BuildSmallModel:initializer=init_... (0/288 params, 0/0 flops, 0B/2.30KB)\n model_analyzer_testlib.py:47:BuildSmallModel:x = nn_ops.conv2d... (0/0 params, 0/4.61k flops, 0B/384B)\n', @@ -116,14 +117,15 @@ class PrintModelAnalysisTest(test.TestCase): def testComplexCodeView(self): ops.reset_default_graph() opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS.copy() - outfile = os.path.join(test.get_temp_dir(), 'dump') - opts['output'] = 'file:outfile=' + outfile + opts['dump_to_file'] = os.path.join(test.get_temp_dir(), 'dump') opts['account_type_regexes'] = ['.*'] opts['show_name_regexes'] = ['.*model_analyzer_testlib.py.*'] opts['account_displayed_op_only'] = False opts['select'] = ['params', 'float_ops'] - with session.Session() as sess, ops.device('/cpu:0'): + config = config_pb2.ConfigProto( + graph_options=config_pb2.GraphOptions(build_cost_model=1)) + with session.Session(config=config) as sess, ops.device('/cpu:0'): x = lib.BuildFullModel() sess.run(variables.global_variables_initializer()) @@ -137,7 +139,7 @@ class PrintModelAnalysisTest(test.TestCase): sess.graph, run_meta, tfprof_cmd='code', tfprof_options=opts) # pylint: disable=line-too-long - with gfile.Open(outfile, 'r') as f: + with gfile.Open(opts['dump_to_file'], 'r') as f: self.assertEqual( '_TFProfRoot (0/2.84k params, 0/54.08k flops)\n model_analyzer_testlib.py:56:BuildFullModel:seq.append(array_... (0/1.80k params, 0/41.76k flops)\n model_analyzer_testlib.py:33:BuildSmallModel:image = array_ops... (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:37:BuildSmallModel:initializer=init_... (0/4 params, 0/0 flops)\n model_analyzer_testlib.py:41:BuildSmallModel:initializer=init_... (0/648 params, 0/0 flops)\n model_analyzer_testlib.py:42:BuildSmallModel:x = nn_ops.conv2d... (0/0 params, 0/23.33k flops)\n model_analyzer_testlib.py:46:BuildSmallModel:initializer=init_... (0/1.15k params, 0/0 flops)\n model_analyzer_testlib.py:47:BuildSmallModel:x = nn_ops.conv2d... (0/0 params, 0/18.43k flops)\n model_analyzer_testlib.py:60:BuildFullModel:cell, array_ops.c... (0/1.04k params, 0/4.13k flops)\n model_analyzer_testlib.py:62:BuildFullModel:target = array_op... (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:63:BuildFullModel:loss = nn_ops.l2_... (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:65:BuildFullModel:return sgd_op.min... (0/0 params, 0/8.19k flops)\n', f.read()) @@ -168,7 +170,9 @@ class PrintModelAnalysisTest(test.TestCase): 'bytes', 'params', 'float_ops', 'num_hidden_ops', 'device' ] - with session.Session() as sess, ops.device('/cpu:0'): + config = config_pb2.ConfigProto( + graph_options=config_pb2.GraphOptions(build_cost_model=1)) + with session.Session(config=config) as sess, ops.device('/cpu:0'): x = lib.BuildSmallModel() sess.run(variables.global_variables_initializer()) diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/print_model_analysis_test.py b/tensorflow/contrib/tfprof/python/tools/tfprof/print_model_analysis_test.py index c3e9fc9cc0..aa133d3142 100644 --- a/tensorflow/contrib/tfprof/python/tools/tfprof/print_model_analysis_test.py +++ b/tensorflow/contrib/tfprof/python/tools/tfprof/print_model_analysis_test.py @@ -51,7 +51,7 @@ TEST_OPTIONS = { 'hide_name_regexes': [], 'account_displayed_op_only': True, 'select': ['params'], - 'output': 'stdout', + 'viz': False } # pylint: enable=bad-whitespace @@ -92,7 +92,7 @@ class PrintModelAnalysisTest(test.TestCase): opts.account_displayed_op_only = TEST_OPTIONS['account_displayed_op_only'] for p in TEST_OPTIONS['select']: opts.select.append(p) - opts.output = TEST_OPTIONS['output'] + opts.viz = TEST_OPTIONS['viz'] with session.Session() as sess, ops.device('/cpu:0'): _ = self._BuildSmallModel() @@ -116,6 +116,7 @@ class PrintModelAnalysisTest(test.TestCase): total_exec_micros: 0 total_requested_bytes: 0 total_parameters: 0 + device: "/device:CPU:0" float_ops: 0 total_float_ops: 0 } @@ -127,6 +128,7 @@ class PrintModelAnalysisTest(test.TestCase): total_exec_micros: 0 total_requested_bytes: 0 total_parameters: 648 + device: "/device:CPU:0" children { name: "DW/Assign" exec_micros: 0 @@ -134,6 +136,7 @@ class PrintModelAnalysisTest(test.TestCase): total_exec_micros: 0 total_requested_bytes: 0 total_parameters: 0 + device: "/device:CPU:0" float_ops: 0 total_float_ops: 0 } @@ -214,6 +217,7 @@ class PrintModelAnalysisTest(test.TestCase): total_exec_micros: 0 total_requested_bytes: 0 total_parameters: 0 + device: "/device:CPU:0" float_ops: 0 total_float_ops: 0 } @@ -227,6 +231,7 @@ class PrintModelAnalysisTest(test.TestCase): total_exec_micros: 0 total_requested_bytes: 0 total_parameters: 0 + device: "/device:CPU:0" float_ops: 0 total_float_ops: 0 } diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py b/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py index e6d504d516..cd3912bbfb 100644 --- a/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py +++ b/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py @@ -62,7 +62,7 @@ def _fill_missing_graph_shape(graph, run_meta): return graph -def _get_logged_ops(graph, run_meta=None, add_trace=True): +def _get_logged_ops(graph, run_meta=None, add_trace=False): """Extract trainable model parameters and FLOPs for ops from a Graph. Args: @@ -120,8 +120,9 @@ def _get_logged_ops(graph, run_meta=None, add_trace=True): return logged_ops -def _merge_default_with_oplog(graph, op_log=None, run_meta=None, - add_trace=True): +def _merge_default_with_oplog(graph, op_log=None, + run_meta=None, + add_trace=False): """Merge the tfprof default extra info with caller's op_log. Args: @@ -153,7 +154,8 @@ def _merge_default_with_oplog(graph, op_log=None, run_meta=None, return tmp_op_log -def write_op_log(graph, log_dir, op_log=None, run_meta=None, add_trace=True): +def write_op_log(graph, log_dir, op_log=None, run_meta=None, + add_trace=False): """Log provided 'op_log', and add additional model information below. The API also assigns ops in tf.trainable_variables() an op type called -- cgit v1.2.3