diff options
author | 2016-02-22 16:33:56 -0800 | |
---|---|---|
committer | 2016-02-23 09:57:26 -0800 | |
commit | 5bd9073799436daa9275a2990e5df58a75a1aa78 (patch) | |
tree | badff985a50fe94f3687862d6d95541b4ffe001b | |
parent | 52fce2481d21800a7a6483286e83315be4645bba (diff) |
Add more control over shapes when calculating graph metrics.
Change: 115284554
-rw-r--r-- | tensorflow/python/tools/BUILD | 14 | ||||
-rw-r--r-- | tensorflow/python/tools/graph_metrics.py | 77 | ||||
-rw-r--r-- | tensorflow/python/tools/graph_metrics_test.py | 50 |
3 files changed, 121 insertions, 20 deletions
diff --git a/tensorflow/python/tools/BUILD b/tensorflow/python/tools/BUILD index 2ebe65a2b9..7cf0a31f14 100644 --- a/tensorflow/python/tools/BUILD +++ b/tensorflow/python/tools/BUILD @@ -62,6 +62,20 @@ py_binary( ], ) +py_test( + name = "graph_metrics_test", + size = "small", + srcs = [ + "graph_metrics_test.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":graph_metrics_lib", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/python/tools/graph_metrics.py b/tensorflow/python/tools/graph_metrics.py index 0b18c75ef8..cf11eef1b4 100644 --- a/tensorflow/python/tools/graph_metrics.py +++ b/tensorflow/python/tools/graph_metrics.py @@ -54,46 +54,83 @@ tf.flags.DEFINE_integer("batch_size", 1, """The batch size to use for the calculations.""") tf.flags.DEFINE_string("statistics", "weight_parameters,flops", """Which statistic types to examine.""") +tf.flags.DEFINE_string("input_shape_override", "", + """If this is set, the comma-separated values will be""" + """ used to set the shape of the input layer.""") +tf.flags.DEFINE_boolean("print_nodes", False, + """Whether to show statistics for each op.""") + + +def print_stat(prefix, statistic_type, value): + if value is None: + friendly_value = "None" + else: + friendly_value = locale.format("%d", value, grouping=True) + print("%s%s=%s" % (prefix, statistic_type, friendly_value)) def main(unused_args): if not tf.gfile.Exists(FLAGS.graph): print("Input graph file '" + FLAGS.graph + "' does not exist!") return -1 - graph_def = graph_pb2.GraphDef() with open(FLAGS.graph, "rb") as f: if FLAGS.input_binary: graph_def.ParseFromString(f.read()) else: text_format.Merge(f.read(), graph_def) - _ = tf.import_graph_def(graph_def, name="") - statistic_types = FLAGS.statistics.split(",") + if FLAGS.input_shape_override: + input_shape_override = map(int, FLAGS.input_shape_override.split(",")) + else: + input_shape_override = None + total_stats, node_stats = calculate_graph_metrics( + graph_def, statistic_types, FLAGS.input_layer, input_shape_override, + FLAGS.batch_size) + if FLAGS.print_nodes: + for node in graph_def.node: + for statistic_type in statistic_types: + current_stats = node_stats[statistic_type][node.name] + print_stat(node.name + "(" + node.op + "): ", statistic_type, + current_stats.value) + for statistic_type in statistic_types: + value = total_stats[statistic_type].value + print_stat("Total: ", statistic_type, value) + + +def calculate_graph_metrics(graph_def, statistic_types, input_layer, + input_shape_override, batch_size): + """Looks at the performance statistics of all nodes in the graph.""" + _ = tf.import_graph_def(graph_def, name="") total_stats = {} + node_stats = {} for statistic_type in statistic_types: total_stats[statistic_type] = ops.OpStats(statistic_type) + node_stats[statistic_type] = {} + # Make sure we get pretty-printed numbers with separators. + locale.setlocale(locale.LC_ALL, "") with tf.Session() as sess: - input_tensor = sess.graph.get_tensor_by_name(FLAGS.input_layer) - input_shape = input_tensor.get_shape() - input_shape = [FLAGS.batch_size, input_shape[1], input_shape[2], - input_shape[3]] + input_tensor = sess.graph.get_tensor_by_name(input_layer) + input_shape_tensor = input_tensor.get_shape() + if input_shape_tensor: + input_shape = input_shape_tensor.as_list() + else: + input_shape = None + if input_shape_override: + input_shape = input_shape_override + input_shape[0] = batch_size input_tensor.set_shape(input_shape) for node in graph_def.node: + # Ensure that the updated input shape has been fully-propagated before we + # ask for the statistics, since they may depend on the output size. + op = sess.graph.get_operation_by_name(node.name) + ops.set_shapes_for_outputs(op) for statistic_type in statistic_types: - node_stats = ops.get_stats_for_node_def(sess.graph, node, - statistic_type) - total_stats[statistic_type] += node_stats - # Make sure we get pretty-printed numbers with separators. - locale.setlocale(locale.LC_ALL, "") - for statistic_type in statistic_types: - value = total_stats[statistic_type].value - if value is None: - friendly_value = "None" - else: - friendly_value = locale.format("%d", value, grouping=True) - print("%s=%s" % (statistic_type, friendly_value)) - + current_stats = ops.get_stats_for_node_def(sess.graph, node, + statistic_type) + node_stats[statistic_type][node.name] = current_stats + total_stats[statistic_type] += current_stats + return total_stats, node_stats if __name__ == "__main__": tf.app.run() diff --git a/tensorflow/python/tools/graph_metrics_test.py b/tensorflow/python/tools/graph_metrics_test.py new file mode 100644 index 0000000000..8fc20b3fb8 --- /dev/null +++ b/tensorflow/python/tools/graph_metrics_test.py @@ -0,0 +1,50 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests the graph metrics tool.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow.python.tools import graph_metrics + + +class GraphMetricsTest(tf.test.TestCase): + + def testGraphMetrics(self): + with tf.Graph().as_default(): + input_node = tf.placeholder(tf.float32, shape=[10, 20], name="input_node") + weights_node = tf.constant(0.0, + dtype=tf.float32, + shape=[20, 5], + name="weights_node") + tf.matmul(input_node, weights_node, name="matmul_node") + sess = tf.Session() + graph_def = sess.graph.as_graph_def() + statistic_types = ["weight_parameters", "flops"] + total_stats, node_stats = graph_metrics.calculate_graph_metrics( + graph_def, statistic_types, "input_node:0", None, 10) + expected = {"weight_parameters": 100, "flops": 2000} + for statistic_type in statistic_types: + current_stats = node_stats[statistic_type]["matmul_node"] + self.assertEqual(expected[statistic_type], current_stats.value) + for statistic_type in statistic_types: + current_stats = total_stats[statistic_type] + self.assertEqual(expected[statistic_type], current_stats.value) + + +if __name__ == "__main__": + tf.test.main() |