aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Pete Warden <pete@petewarden.com>2016-02-22 16:33:56 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-02-23 09:57:26 -0800
commit5bd9073799436daa9275a2990e5df58a75a1aa78 (patch)
treebadff985a50fe94f3687862d6d95541b4ffe001b
parent52fce2481d21800a7a6483286e83315be4645bba (diff)
Add more control over shapes when calculating graph metrics.
Change: 115284554
-rw-r--r--tensorflow/python/tools/BUILD14
-rw-r--r--tensorflow/python/tools/graph_metrics.py77
-rw-r--r--tensorflow/python/tools/graph_metrics_test.py50
3 files changed, 121 insertions, 20 deletions
diff --git a/tensorflow/python/tools/BUILD b/tensorflow/python/tools/BUILD
index 2ebe65a2b9..7cf0a31f14 100644
--- a/tensorflow/python/tools/BUILD
+++ b/tensorflow/python/tools/BUILD
@@ -62,6 +62,20 @@ py_binary(
],
)
+py_test(
+ name = "graph_metrics_test",
+ size = "small",
+ srcs = [
+ "graph_metrics_test.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":graph_metrics_lib",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/python/tools/graph_metrics.py b/tensorflow/python/tools/graph_metrics.py
index 0b18c75ef8..cf11eef1b4 100644
--- a/tensorflow/python/tools/graph_metrics.py
+++ b/tensorflow/python/tools/graph_metrics.py
@@ -54,46 +54,83 @@ tf.flags.DEFINE_integer("batch_size", 1,
"""The batch size to use for the calculations.""")
tf.flags.DEFINE_string("statistics", "weight_parameters,flops",
"""Which statistic types to examine.""")
+tf.flags.DEFINE_string("input_shape_override", "",
+ """If this is set, the comma-separated values will be"""
+ """ used to set the shape of the input layer.""")
+tf.flags.DEFINE_boolean("print_nodes", False,
+ """Whether to show statistics for each op.""")
+
+
+def print_stat(prefix, statistic_type, value):
+ if value is None:
+ friendly_value = "None"
+ else:
+ friendly_value = locale.format("%d", value, grouping=True)
+ print("%s%s=%s" % (prefix, statistic_type, friendly_value))
def main(unused_args):
if not tf.gfile.Exists(FLAGS.graph):
print("Input graph file '" + FLAGS.graph + "' does not exist!")
return -1
-
graph_def = graph_pb2.GraphDef()
with open(FLAGS.graph, "rb") as f:
if FLAGS.input_binary:
graph_def.ParseFromString(f.read())
else:
text_format.Merge(f.read(), graph_def)
- _ = tf.import_graph_def(graph_def, name="")
-
statistic_types = FLAGS.statistics.split(",")
+ if FLAGS.input_shape_override:
+ input_shape_override = map(int, FLAGS.input_shape_override.split(","))
+ else:
+ input_shape_override = None
+ total_stats, node_stats = calculate_graph_metrics(
+ graph_def, statistic_types, FLAGS.input_layer, input_shape_override,
+ FLAGS.batch_size)
+ if FLAGS.print_nodes:
+ for node in graph_def.node:
+ for statistic_type in statistic_types:
+ current_stats = node_stats[statistic_type][node.name]
+ print_stat(node.name + "(" + node.op + "): ", statistic_type,
+ current_stats.value)
+ for statistic_type in statistic_types:
+ value = total_stats[statistic_type].value
+ print_stat("Total: ", statistic_type, value)
+
+
+def calculate_graph_metrics(graph_def, statistic_types, input_layer,
+ input_shape_override, batch_size):
+ """Looks at the performance statistics of all nodes in the graph."""
+ _ = tf.import_graph_def(graph_def, name="")
total_stats = {}
+ node_stats = {}
for statistic_type in statistic_types:
total_stats[statistic_type] = ops.OpStats(statistic_type)
+ node_stats[statistic_type] = {}
+ # Make sure we get pretty-printed numbers with separators.
+ locale.setlocale(locale.LC_ALL, "")
with tf.Session() as sess:
- input_tensor = sess.graph.get_tensor_by_name(FLAGS.input_layer)
- input_shape = input_tensor.get_shape()
- input_shape = [FLAGS.batch_size, input_shape[1], input_shape[2],
- input_shape[3]]
+ input_tensor = sess.graph.get_tensor_by_name(input_layer)
+ input_shape_tensor = input_tensor.get_shape()
+ if input_shape_tensor:
+ input_shape = input_shape_tensor.as_list()
+ else:
+ input_shape = None
+ if input_shape_override:
+ input_shape = input_shape_override
+ input_shape[0] = batch_size
input_tensor.set_shape(input_shape)
for node in graph_def.node:
+ # Ensure that the updated input shape has been fully-propagated before we
+ # ask for the statistics, since they may depend on the output size.
+ op = sess.graph.get_operation_by_name(node.name)
+ ops.set_shapes_for_outputs(op)
for statistic_type in statistic_types:
- node_stats = ops.get_stats_for_node_def(sess.graph, node,
- statistic_type)
- total_stats[statistic_type] += node_stats
- # Make sure we get pretty-printed numbers with separators.
- locale.setlocale(locale.LC_ALL, "")
- for statistic_type in statistic_types:
- value = total_stats[statistic_type].value
- if value is None:
- friendly_value = "None"
- else:
- friendly_value = locale.format("%d", value, grouping=True)
- print("%s=%s" % (statistic_type, friendly_value))
-
+ current_stats = ops.get_stats_for_node_def(sess.graph, node,
+ statistic_type)
+ node_stats[statistic_type][node.name] = current_stats
+ total_stats[statistic_type] += current_stats
+ return total_stats, node_stats
if __name__ == "__main__":
tf.app.run()
diff --git a/tensorflow/python/tools/graph_metrics_test.py b/tensorflow/python/tools/graph_metrics_test.py
new file mode 100644
index 0000000000..8fc20b3fb8
--- /dev/null
+++ b/tensorflow/python/tools/graph_metrics_test.py
@@ -0,0 +1,50 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests the graph metrics tool."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.python.tools import graph_metrics
+
+
+class GraphMetricsTest(tf.test.TestCase):
+
+ def testGraphMetrics(self):
+ with tf.Graph().as_default():
+ input_node = tf.placeholder(tf.float32, shape=[10, 20], name="input_node")
+ weights_node = tf.constant(0.0,
+ dtype=tf.float32,
+ shape=[20, 5],
+ name="weights_node")
+ tf.matmul(input_node, weights_node, name="matmul_node")
+ sess = tf.Session()
+ graph_def = sess.graph.as_graph_def()
+ statistic_types = ["weight_parameters", "flops"]
+ total_stats, node_stats = graph_metrics.calculate_graph_metrics(
+ graph_def, statistic_types, "input_node:0", None, 10)
+ expected = {"weight_parameters": 100, "flops": 2000}
+ for statistic_type in statistic_types:
+ current_stats = node_stats[statistic_type]["matmul_node"]
+ self.assertEqual(expected[statistic_type], current_stats.value)
+ for statistic_type in statistic_types:
+ current_stats = total_stats[statistic_type]
+ self.assertEqual(expected[statistic_type], current_stats.value)
+
+
+if __name__ == "__main__":
+ tf.test.main()