aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tfprof
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tfprof')
-rw-r--r--tensorflow/contrib/tfprof/README.md78
-rw-r--r--tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py10
-rw-r--r--tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py38
-rw-r--r--tensorflow/contrib/tfprof/python/tools/tfprof/print_model_analysis_test.py9
-rw-r--r--tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py10
5 files changed, 39 insertions, 106 deletions
diff --git a/tensorflow/contrib/tfprof/README.md b/tensorflow/contrib/tfprof/README.md
index d891ecdc9a..5bfa0247a5 100644
--- a/tensorflow/contrib/tfprof/README.md
+++ b/tensorflow/contrib/tfprof/README.md
@@ -2,81 +2,25 @@
# Full Document in tensorflow/tools/tfprof/README.md
-Author: Xin Pan (xpan@google.com, github: panyx0718)
+Author: Xin Pan (xpan@google.com, github: panyx0718), Jon Shlens, Yao Zhang
Consultants: Jon Shlens, Pete Warden
###Major Features
1. Measure model parameters, float operations, tensor shapes.
-2. Measure op execution times, requested memory size and device placement.
+2. Profile op execution times, requested memory size and device placement.
3. Inspect checkpoint tensors' shapes and their values.
-4. 3 ways to view and explore TensorFlow model profiles
+4. Selectively group, filter, account and order ops.
- * Organize by Python code call stack.
- * Organize by TensorFlow operation name scope hierarchies.
- * Organize by TensorFlow operation inputs/outputs graph.
+####tfprof supports 3 views to organize TensorFlow model profiles
-5. Selectively grouping/filtering/accounting/ordering ops.
+ * code view: Stats are associated with your Python code and organized as call stacks.
+ * scope view: Stats are organized as name scope hierarchies.
+ * graph view: Stats are organized as the TensorFlow Op graph.
-tfprof can be used as Python API, Interactive CLI and One-shot Script.
+####For each view, there are 3 ways to display outputs:
-## Python API Tutorials
-
-tfprof is part of TensorFlow core. Simply ```import tensorflow as tf```.
-
-### Examine the shapes and sizes of all trainiable Variables.
-```python
-# Print trainable variable parameter statistics to stdout.
-param_stats = tf.contrib.tfprof.model_analyzer.print_model_analysis(
- tf.get_default_graph(),
- tfprof_options=tf.contrib.tfprof.model_analyzer.
- TRAINABLE_VARS_PARAMS_STAT_OPTIONS)
-
-# param_stats is tensorflow.tfprof.TFGraphNodeProto proto.
-# It organize the statistics
-# of each graph node in tree scructure. Let's print the root below.
-sys.stdout.write('total_params: %d\n' % param_stats.total_parameters)
-```
-
-### Examine the number of floating point operations
-``` python
-# Print to stdout an analysis of the number of floating point operations in the
-# model broken down by individual operations.
-#
-# Note: Only Ops with RegisterStatistics('flops') defined have flop stats. It
-# also requires complete shape information. It is common that shape is unknown
-# statically. To complete the shape, provide run-time shape information with
-# tf.RunMetadata to the API (See next example on how to provide RunMetadata).
-tf.contrib.tfprof.model_analyzer.print_model_analysis(
- tf.get_default_graph(),
- tfprof_options=tf.contrib.tfprof.model_analyzer.FLOAT_OPS_OPTIONS)
-```
-
-### Examine the timing and memory usage
-You will first need to run the following set up in your model in order to
-compute the memory and timing statistics.
-
-```python
-# Generate the meta information for the model that contains the memory usage
-# and timing information.
-run_metadata = tf.RunMetadata()
-with tf.Session() as sess:
- _ = sess.run(train_op,
- options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
- run_metadata=run_metadata)
-```
-
-Finally, you may run `print_model_analysis` to explore the timing and memory
-demands of the model.
-
-``` python
-# Print to stdout an analysis of the memory usage and the timing information
-# from running the graph broken down by operations.
-tf.contrib.tfprof.model_analyzer.print_model_analysis(
- tf.get_default_graph(),
- run_meta=run_metadata,
- tfprof_options=tf.contrib.tfprof.model_analyzer.PRINT_ALL_TIMING_MEMORY)
-```
-
-Users can change ```tfprof_options``` to fully leverage tfprof's power.
+ * stdout: Results are written to stdout.
+ * timeline: Results are visualized in the Chrome browser as a time series.
+ * file: Results are dumped to file.
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py
index 13b407d815..17dff69edd 100644
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer.py
@@ -45,7 +45,7 @@ TRAINABLE_VARS_PARAMS_STAT_OPTIONS = {
'hide_name_regexes': [],
'account_displayed_op_only': True,
'select': ['params'],
- 'viz': False,
+ 'output': 'stdout',
'dump_to_file': ''
}
@@ -65,7 +65,7 @@ FLOAT_OPS_OPTIONS = {
'hide_name_regexes': [],
'account_displayed_op_only': True,
'select': ['float_ops'],
- 'viz': False,
+ 'output': 'stdout',
'dump_to_file': ''
}
@@ -87,7 +87,7 @@ PRINT_PARAMS_ON_DEVICE = {
'hide_name_regexes': [],
'account_displayed_op_only': False,
'select': ['device', 'params'],
- 'viz': False,
+ 'output': 'stdout',
'dump_to_file': ''
}
@@ -107,7 +107,7 @@ PRINT_ALL_TIMING_MEMORY = {
'hide_name_regexes': [],
'account_displayed_op_only': True,
'select': ['micros', 'bytes'],
- 'viz': False,
+ 'output': 'stdout',
'dump_to_file': ''
}
@@ -178,7 +178,7 @@ def print_model_analysis(graph,
opts.account_displayed_op_only = tfprof_options['account_displayed_op_only']
for p in tfprof_options['select']:
opts.select.append(p)
- opts.viz = tfprof_options['viz']
+ opts.output = tfprof_options['output']
opts.dump_to_file = tfprof_options['dump_to_file']
run_meta_str = run_meta.SerializeToString() if run_meta else b''
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py
index 55167576e2..afd8563e78 100644
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py
@@ -35,13 +35,14 @@ class PrintModelAnalysisTest(test.TestCase):
def testDumpToFile(self):
ops.reset_default_graph()
opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS
- opts['dump_to_file'] = os.path.join(test.get_temp_dir(), 'dump')
+ outfile = os.path.join(test.get_temp_dir(), 'dump')
+ opts['output'] = 'file:outfile=' + outfile
with session.Session() as sess, ops.device('/cpu:0'):
_ = lib.BuildSmallModel()
model_analyzer.print_model_analysis(sess.graph, tfprof_options=opts)
- with gfile.Open(opts['dump_to_file'], 'r') as f:
+ with gfile.Open(outfile, 'r') as f:
self.assertEqual(u'_TFProfRoot (--/451 params)\n'
' DW (3x3x3x6, 162/162 params)\n'
' DW2 (2x2x6x12, 288/288 params)\n'
@@ -51,15 +52,14 @@ class PrintModelAnalysisTest(test.TestCase):
def testSelectEverything(self):
ops.reset_default_graph()
opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS
- opts['dump_to_file'] = os.path.join(test.get_temp_dir(), 'dump')
+ outfile = os.path.join(test.get_temp_dir(), 'dump')
+ opts['output'] = 'file:outfile=' + outfile
opts['account_type_regexes'] = ['.*']
opts['select'] = [
'bytes', 'params', 'float_ops', 'num_hidden_ops', 'device', 'op_types'
]
- config = config_pb2.ConfigProto(
- graph_options=config_pb2.GraphOptions(build_cost_model=1))
- with session.Session(config=config) as sess, ops.device('/cpu:0'):
+ with session.Session() as sess, ops.device('/cpu:0'):
x = lib.BuildSmallModel()
sess.run(variables.global_variables_initializer())
@@ -72,17 +72,18 @@ class PrintModelAnalysisTest(test.TestCase):
model_analyzer.print_model_analysis(
sess.graph, run_meta, tfprof_options=opts)
- with gfile.Open(opts['dump_to_file'], 'r') as f:
+ with gfile.Open(outfile, 'r') as f:
# pylint: disable=line-too-long
self.assertEqual(
- '_TFProfRoot (0/451 params, 0/10.44k flops, 0B/5.28KB, _kTFScopeParent)\n Conv2D (0/0 params, 5.83k/5.83k flops, 432B/432B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Conv2D)\n Conv2D_1 (0/0 params, 4.61k/4.61k flops, 384B/384B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Conv2D)\n DW (3x3x3x6, 162/162 params, 0/0 flops, 648B/1.30KB, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|VariableV2|_trainable_variables)\n DW/Assign (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Assign)\n DW/Initializer (0/0 params, 0/0 flops, 0B/0B, _kTFScopeParent)\n DW/Initializer/random_normal (0/0 params, 0/0 flops, 0B/0B, Add)\n DW/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, 0B/0B, RandomStandardNormal)\n DW/Initializer/random_normal/mean (0/0 params, 0/0 flops, 0B/0B, Const)\n DW/Initializer/random_normal/mul (0/0 params, 0/0 flops, 0B/0B, Mul)\n DW/Initializer/random_normal/shape (0/0 params, 0/0 flops, 0B/0B, Const)\n DW/Initializer/random_normal/stddev (0/0 params, 0/0 flops, 0B/0B, Const)\n DW/read (0/0 params, 0/0 flops, 648B/648B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Identity)\n DW2 (2x2x6x12, 288/288 params, 0/0 flops, 1.15KB/2.30KB, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|VariableV2|_trainable_variables)\n DW2/Assign (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Assign)\n DW2/Initializer (0/0 params, 0/0 flops, 0B/0B, _kTFScopeParent)\n DW2/Initializer/random_normal (0/0 params, 0/0 flops, 0B/0B, Add)\n DW2/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, 0B/0B, RandomStandardNormal)\n DW2/Initializer/random_normal/mean (0/0 params, 0/0 flops, 0B/0B, Const)\n DW2/Initializer/random_normal/mul (0/0 params, 0/0 flops, 0B/0B, Mul)\n DW2/Initializer/random_normal/shape (0/0 params, 0/0 flops, 0B/0B, Const)\n 
DW2/Initializer/random_normal/stddev (0/0 params, 0/0 flops, 0B/0B, Const)\n DW2/read (0/0 params, 0/0 flops, 1.15KB/1.15KB, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Identity)\n ScalarW (1, 1/1 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|VariableV2|_trainable_variables)\n ScalarW/Assign (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Assign)\n ScalarW/Initializer (0/0 params, 0/0 flops, 0B/0B, _kTFScopeParent)\n ScalarW/Initializer/random_normal (0/0 params, 0/0 flops, 0B/0B, Add)\n ScalarW/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, 0B/0B, RandomStandardNormal)\n ScalarW/Initializer/random_normal/mean (0/0 params, 0/0 flops, 0B/0B, Const)\n ScalarW/Initializer/random_normal/mul (0/0 params, 0/0 flops, 0B/0B, Mul)\n ScalarW/Initializer/random_normal/shape (0/0 params, 0/0 flops, 0B/0B, Const)\n ScalarW/Initializer/random_normal/stddev (0/0 params, 0/0 flops, 0B/0B, Const)\n ScalarW/read (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Identity)\n init (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|NoOp)\n zeros (0/0 params, 0/0 flops, 864B/864B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Const)\n',
+ '_TFProfRoot (0/451 params, 0/10.44k flops, 0B/5.28KB, _kTFScopeParent)\n Conv2D (0/0 params, 5.83k/5.83k flops, 432B/432B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Conv2D)\n Conv2D_1 (0/0 params, 4.61k/4.61k flops, 384B/384B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Conv2D)\n DW (3x3x3x6, 162/162 params, 0/0 flops, 648B/1.30KB, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|VariableV2|_trainable_variables)\n DW/Assign (0/0 params, 0/0 flops, 0B/0B, Assign)\n DW/Initializer (0/0 params, 0/0 flops, 0B/0B, _kTFScopeParent)\n DW/Initializer/random_normal (0/0 params, 0/0 flops, 0B/0B, Add)\n DW/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, 0B/0B, RandomStandardNormal)\n DW/Initializer/random_normal/mean (0/0 params, 0/0 flops, 0B/0B, Const)\n DW/Initializer/random_normal/mul (0/0 params, 0/0 flops, 0B/0B, Mul)\n DW/Initializer/random_normal/shape (0/0 params, 0/0 flops, 0B/0B, Const)\n DW/Initializer/random_normal/stddev (0/0 params, 0/0 flops, 0B/0B, Const)\n DW/read (0/0 params, 0/0 flops, 648B/648B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Identity)\n DW2 (2x2x6x12, 288/288 params, 0/0 flops, 1.15KB/2.30KB, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|VariableV2|_trainable_variables)\n DW2/Assign (0/0 params, 0/0 flops, 0B/0B, Assign)\n DW2/Initializer (0/0 params, 0/0 flops, 0B/0B, _kTFScopeParent)\n DW2/Initializer/random_normal (0/0 params, 0/0 flops, 0B/0B, Add)\n DW2/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, 0B/0B, RandomStandardNormal)\n DW2/Initializer/random_normal/mean (0/0 params, 0/0 flops, 0B/0B, Const)\n DW2/Initializer/random_normal/mul (0/0 params, 0/0 flops, 0B/0B, Mul)\n DW2/Initializer/random_normal/shape (0/0 params, 0/0 flops, 0B/0B, Const)\n DW2/Initializer/random_normal/stddev (0/0 params, 0/0 flops, 0B/0B, Const)\n 
DW2/read (0/0 params, 0/0 flops, 1.15KB/1.15KB, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Identity)\n ScalarW (1, 1/1 params, 0/0 flops, 0B/0B, VariableV2|_trainable_variables)\n ScalarW/Assign (0/0 params, 0/0 flops, 0B/0B, Assign)\n ScalarW/Initializer (0/0 params, 0/0 flops, 0B/0B, _kTFScopeParent)\n ScalarW/Initializer/random_normal (0/0 params, 0/0 flops, 0B/0B, Add)\n ScalarW/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, 0B/0B, RandomStandardNormal)\n ScalarW/Initializer/random_normal/mean (0/0 params, 0/0 flops, 0B/0B, Const)\n ScalarW/Initializer/random_normal/mul (0/0 params, 0/0 flops, 0B/0B, Mul)\n ScalarW/Initializer/random_normal/shape (0/0 params, 0/0 flops, 0B/0B, Const)\n ScalarW/Initializer/random_normal/stddev (0/0 params, 0/0 flops, 0B/0B, Const)\n ScalarW/read (0/0 params, 0/0 flops, 0B/0B, Identity)\n init (0/0 params, 0/0 flops, 0B/0B, NoOp)\n zeros (0/0 params, 0/0 flops, 864B/864B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Const)\n',
f.read())
# pylint: enable=line-too-long
def testSimpleCodeView(self):
ops.reset_default_graph()
opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS.copy()
- opts['dump_to_file'] = os.path.join(test.get_temp_dir(), 'dump')
+ outfile = os.path.join(test.get_temp_dir(), 'dump')
+ opts['output'] = 'file:outfile=' + outfile
opts['account_type_regexes'] = ['.*']
opts['show_name_regexes'] = ['.*model_analyzer_testlib.*']
opts['account_displayed_op_only'] = False
@@ -92,9 +93,7 @@ class PrintModelAnalysisTest(test.TestCase):
'bytes', 'params', 'float_ops', 'num_hidden_ops', 'device',
]
- config = config_pb2.ConfigProto(
- graph_options=config_pb2.GraphOptions(build_cost_model=1))
- with session.Session(config=config) as sess, ops.device('/cpu:0'):
+ with session.Session() as sess, ops.device('/cpu:0'):
x = lib.BuildSmallModel()
sess.run(variables.global_variables_initializer())
@@ -107,7 +106,7 @@ class PrintModelAnalysisTest(test.TestCase):
model_analyzer.print_model_analysis(
sess.graph, run_meta, tfprof_cmd='code', tfprof_options=opts)
- with gfile.Open(opts['dump_to_file'], 'r') as f:
+ with gfile.Open(outfile, 'r') as f:
# pylint: disable=line-too-long
self.assertEqual(
'_TFProfRoot (0/451 params, 0/10.44k flops, 0B/5.28KB)\n model_analyzer_testlib.py:33:BuildSmallModel:image = array_ops... (0/0 params, 0/0 flops, 0B/864B)\n model_analyzer_testlib.py:37:BuildSmallModel:initializer=init_... (0/1 params, 0/0 flops, 0B/0B)\n model_analyzer_testlib.py:41:BuildSmallModel:initializer=init_... (0/162 params, 0/0 flops, 0B/1.30KB)\n model_analyzer_testlib.py:42:BuildSmallModel:x = nn_ops.conv2d... (0/0 params, 0/5.83k flops, 0B/432B)\n model_analyzer_testlib.py:46:BuildSmallModel:initializer=init_... (0/288 params, 0/0 flops, 0B/2.30KB)\n model_analyzer_testlib.py:47:BuildSmallModel:x = nn_ops.conv2d... (0/0 params, 0/4.61k flops, 0B/384B)\n',
@@ -117,15 +116,14 @@ class PrintModelAnalysisTest(test.TestCase):
def testComplexCodeView(self):
ops.reset_default_graph()
opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS.copy()
- opts['dump_to_file'] = os.path.join(test.get_temp_dir(), 'dump')
+ outfile = os.path.join(test.get_temp_dir(), 'dump')
+ opts['output'] = 'file:outfile=' + outfile
opts['account_type_regexes'] = ['.*']
opts['show_name_regexes'] = ['.*model_analyzer_testlib.py.*']
opts['account_displayed_op_only'] = False
opts['select'] = ['params', 'float_ops']
- config = config_pb2.ConfigProto(
- graph_options=config_pb2.GraphOptions(build_cost_model=1))
- with session.Session(config=config) as sess, ops.device('/cpu:0'):
+ with session.Session() as sess, ops.device('/cpu:0'):
x = lib.BuildFullModel()
sess.run(variables.global_variables_initializer())
@@ -139,7 +137,7 @@ class PrintModelAnalysisTest(test.TestCase):
sess.graph, run_meta, tfprof_cmd='code', tfprof_options=opts)
# pylint: disable=line-too-long
- with gfile.Open(opts['dump_to_file'], 'r') as f:
+ with gfile.Open(outfile, 'r') as f:
self.assertEqual(
'_TFProfRoot (0/2.84k params, 0/54.08k flops)\n model_analyzer_testlib.py:56:BuildFullModel:seq.append(array_... (0/1.80k params, 0/41.76k flops)\n model_analyzer_testlib.py:33:BuildSmallModel:image = array_ops... (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:37:BuildSmallModel:initializer=init_... (0/4 params, 0/0 flops)\n model_analyzer_testlib.py:41:BuildSmallModel:initializer=init_... (0/648 params, 0/0 flops)\n model_analyzer_testlib.py:42:BuildSmallModel:x = nn_ops.conv2d... (0/0 params, 0/23.33k flops)\n model_analyzer_testlib.py:46:BuildSmallModel:initializer=init_... (0/1.15k params, 0/0 flops)\n model_analyzer_testlib.py:47:BuildSmallModel:x = nn_ops.conv2d... (0/0 params, 0/18.43k flops)\n model_analyzer_testlib.py:60:BuildFullModel:cell, array_ops.c... (0/1.04k params, 0/4.13k flops)\n model_analyzer_testlib.py:62:BuildFullModel:target = array_op... (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:63:BuildFullModel:loss = nn_ops.l2_... (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:65:BuildFullModel:return sgd_op.min... (0/0 params, 0/8.19k flops)\n',
f.read())
@@ -170,9 +168,7 @@ class PrintModelAnalysisTest(test.TestCase):
'bytes', 'params', 'float_ops', 'num_hidden_ops', 'device'
]
- config = config_pb2.ConfigProto(
- graph_options=config_pb2.GraphOptions(build_cost_model=1))
- with session.Session(config=config) as sess, ops.device('/cpu:0'):
+ with session.Session() as sess, ops.device('/cpu:0'):
x = lib.BuildSmallModel()
sess.run(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/print_model_analysis_test.py b/tensorflow/contrib/tfprof/python/tools/tfprof/print_model_analysis_test.py
index aa133d3142..c3e9fc9cc0 100644
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/print_model_analysis_test.py
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/print_model_analysis_test.py
@@ -51,7 +51,7 @@ TEST_OPTIONS = {
'hide_name_regexes': [],
'account_displayed_op_only': True,
'select': ['params'],
- 'viz': False
+ 'output': 'stdout',
}
# pylint: enable=bad-whitespace
@@ -92,7 +92,7 @@ class PrintModelAnalysisTest(test.TestCase):
opts.account_displayed_op_only = TEST_OPTIONS['account_displayed_op_only']
for p in TEST_OPTIONS['select']:
opts.select.append(p)
- opts.viz = TEST_OPTIONS['viz']
+ opts.output = TEST_OPTIONS['output']
with session.Session() as sess, ops.device('/cpu:0'):
_ = self._BuildSmallModel()
@@ -116,7 +116,6 @@ class PrintModelAnalysisTest(test.TestCase):
total_exec_micros: 0
total_requested_bytes: 0
total_parameters: 0
- device: "/device:CPU:0"
float_ops: 0
total_float_ops: 0
}
@@ -128,7 +127,6 @@ class PrintModelAnalysisTest(test.TestCase):
total_exec_micros: 0
total_requested_bytes: 0
total_parameters: 648
- device: "/device:CPU:0"
children {
name: "DW/Assign"
exec_micros: 0
@@ -136,7 +134,6 @@ class PrintModelAnalysisTest(test.TestCase):
total_exec_micros: 0
total_requested_bytes: 0
total_parameters: 0
- device: "/device:CPU:0"
float_ops: 0
total_float_ops: 0
}
@@ -217,7 +214,6 @@ class PrintModelAnalysisTest(test.TestCase):
total_exec_micros: 0
total_requested_bytes: 0
total_parameters: 0
- device: "/device:CPU:0"
float_ops: 0
total_float_ops: 0
}
@@ -231,7 +227,6 @@ class PrintModelAnalysisTest(test.TestCase):
total_exec_micros: 0
total_requested_bytes: 0
total_parameters: 0
- device: "/device:CPU:0"
float_ops: 0
total_float_ops: 0
}
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py b/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py
index cd3912bbfb..e6d504d516 100644
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py
@@ -62,7 +62,7 @@ def _fill_missing_graph_shape(graph, run_meta):
return graph
-def _get_logged_ops(graph, run_meta=None, add_trace=False):
+def _get_logged_ops(graph, run_meta=None, add_trace=True):
"""Extract trainable model parameters and FLOPs for ops from a Graph.
Args:
@@ -120,9 +120,8 @@ def _get_logged_ops(graph, run_meta=None, add_trace=False):
return logged_ops
-def _merge_default_with_oplog(graph, op_log=None,
- run_meta=None,
- add_trace=False):
+def _merge_default_with_oplog(graph, op_log=None, run_meta=None,
+ add_trace=True):
"""Merge the tfprof default extra info with caller's op_log.
Args:
@@ -154,8 +153,7 @@ def _merge_default_with_oplog(graph, op_log=None,
return tmp_op_log
-def write_op_log(graph, log_dir, op_log=None, run_meta=None,
- add_trace=False):
+def write_op_log(graph, log_dir, op_log=None, run_meta=None, add_trace=True):
"""Log provided 'op_log', and add additional model information below.
The API also assigns ops in tf.trainable_variables() an op type called