diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-09-30 12:20:29 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-09-30 13:34:53 -0700 |
commit | 113093b017c8b2654c052a054e63738174bae649 (patch) | |
tree | 63235478ca56348efaf9be32cee298f13a4eeafb /tensorflow/contrib/tfprof | |
parent | 7cbf9454c04d9d0d7f05a5ca688d174e0e82a2ed (diff) |
1. Port latest tfprof changes
2. Allow users to call tfprof_logger as tf.contrib.tfprof.tfprof_logger
3. Add a test
Change: 134819165
Diffstat (limited to 'tensorflow/contrib/tfprof')
-rw-r--r-- | tensorflow/contrib/tfprof/BUILD | 17 | ||||
-rw-r--r-- | tensorflow/contrib/tfprof/README.md | 12 | ||||
-rw-r--r-- | tensorflow/contrib/tfprof/__init__.py | 21 | ||||
-rw-r--r-- | tensorflow/contrib/tfprof/python/tools/tfprof/BUILD | 19 | ||||
-rw-r--r-- | tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py | 56 | ||||
-rw-r--r-- | tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger_test.py | 72 |
6 files changed, 177 insertions, 20 deletions
diff --git a/tensorflow/contrib/tfprof/BUILD b/tensorflow/contrib/tfprof/BUILD new file mode 100644 index 0000000000..d55bda1bd0 --- /dev/null +++ b/tensorflow/contrib/tfprof/BUILD @@ -0,0 +1,17 @@ +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +package(default_visibility = ["//tensorflow:__subpackages__"]) + +py_library( + name = "tfprof", + srcs = [ + "__init__.py", + ], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:__subpackages__"], + deps = [ + "//tensorflow/contrib/tfprof/python/tools/tfprof:tfprof_logger", + ], +) diff --git a/tensorflow/contrib/tfprof/README.md b/tensorflow/contrib/tfprof/README.md index 0e6420134a..27a08c514c 100644 --- a/tensorflow/contrib/tfprof/README.md +++ b/tensorflow/contrib/tfprof/README.md @@ -1,12 +1,11 @@ # tfprof: A Profiling Tool for TensorFlow Models -go/tfprof +Internal User Please Use: go/tfprof Author: Xin Pan (xpan@google.com, github: panyx0718) -Consultants: Jon Shlens (shlens@google.com), Pete Warden (petewarden@google.com) +Consultants: Jon Shlens, Pete Warden -[TOC] ## Introduction @@ -259,6 +258,7 @@ First, in Python code, create an `OpLog` proto and add op type information to it: ```python + op_log = tfprof_log_pb2.OpLog() entry = op_log.log_entries.add() entry.name = 'pool_logit/DW' @@ -274,7 +274,7 @@ entry.types.append('pool_logit') Second, call write_op_log to write the OpLog proto. ```python -tfprof_logger.write_op_log(sess.graph, /tmp/my_op_log_dir, op_log) +tf.tfprof.tfprof_logger.write_op_log(sess.graph, /tmp/my_op_log_dir, op_log) ``` Third, when starting the tfprof tool, specify @@ -288,8 +288,8 @@ _TFProfRoot (--/650 params) ``` Note that when you call -`tfprof_logger.write_op_log(...)`, the tool adds all `Variables` inside -`tf.trainable_variables()` to `_trainable_variables`. +`tf.tfprof.tfprof_logger.write_op_log(...)`, the tool adds all `Variables` +inside `tf.trainable_variables()` to `_trainable_variables`. 12) Run tfprof in one-shot mode and dump result to file. diff --git a/tensorflow/contrib/tfprof/__init__.py b/tensorflow/contrib/tfprof/__init__.py new file mode 100644 index 0000000000..ce777979b9 --- /dev/null +++ b/tensorflow/contrib/tfprof/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""tfprof is a tool that profile various aspect of TensorFlow model.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.tfprof.python.tools.tfprof import tfprof_logger +from tensorflow.python.util.all_util import make_all diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/BUILD b/tensorflow/contrib/tfprof/python/tools/tfprof/BUILD index d78020bbd8..87a8311486 100644 --- a/tensorflow/contrib/tfprof/python/tools/tfprof/BUILD +++ b/tensorflow/contrib/tfprof/python/tools/tfprof/BUILD @@ -1,20 +1,29 @@ -package( - default_visibility = ["//visibility:public"], -) - licenses(["notice"]) # Apache 2.0 +package(default_visibility = ["//visibility:public"]) + +load("//tensorflow:tensorflow.bzl", "tf_py_test") + py_library( name = "tfprof_logger", srcs = ["tfprof_logger.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow:tensorflow_py", "//tensorflow/contrib/tfprof/tools/tfprof:protos_all_py", "//tensorflow/python:framework_for_generated_wrappers", ], ) +tf_py_test( + name = "tfprof_logger_test", + srcs = ["tfprof_logger_test.py"], + additional_deps = [ + ":tfprof_logger", + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/tfprof/tools/tfprof:protos_all_py", + ], +) + # ----------------------------------------------------------------------------- # Google-internal targets. These must be at the end for syncrepo. diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py b/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py index 4a487461a3..53dd2632b6 100644 --- a/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py +++ b/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py @@ -21,25 +21,57 @@ from __future__ import division from __future__ import print_function import os +import sys import tensorflow as tf -from tensorflow.contrib.tfprof.python.tools.tfprof import tfprof_log_pb2 +from tensorflow.contrib.tfprof.tools.tfprof import tfprof_log_pb2 from tensorflow.python.framework import ops TRAINABLE_VARIABLES = '_trainable_variables' REGISTERED_FLOP_STATS = 'flops' -def _get_logged_ops(graph): +def _fill_missing_graph_shape(graph, run_meta): + """Fill Tensor shapes in 'graph' with run time shape from 'run_meta'.""" + for dev_stat in run_meta.step_stats.dev_stats: + for node_stat in dev_stat.node_stats: + if not node_stat.output: + continue + try: + op = graph.get_operation_by_name(node_stat.node_name) + except KeyError as e: + # Graph doesn't contains the node_stat, usually RecvTensor. + continue + if len(node_stat.output) != len(op.outputs): + # For example, conditional op has only 1 output at run time. + continue + for (i, node_stat_out) in enumerate(node_stat.output): + if op.outputs[i].get_shape().is_fully_defined(): + continue + node_stat_dims = node_stat_out.tensor_description.shape.dim + node_stat_shape = tf.TensorShape([d.size for d in node_stat_dims]) + try: + op.outputs[i].set_shape(op.outputs[i].get_shape().merge_with( + node_stat_shape)) + except ValueError as e: + sys.stderr.write('Node %s incompatible shapes: %s.\n' % + (node_stat.node_name, e)) + return graph + + +def _get_logged_ops(graph, run_meta=None): """Extract trainable model parameters and FLOPs for ops from a Graph. Args: graph: tf.Graph. + run_meta: RunMetadata proto used to complete shape information. Returns: logged_ops: dict mapping from op_name to OpLogEntry. """ - logged_ops = {} + if run_meta: + graph = _fill_missing_graph_shape(graph, run_meta) + logged_ops = {} graph_def = graph.as_graph_def() for node in graph_def.node: try: @@ -67,17 +99,18 @@ def _get_logged_ops(graph): return logged_ops -def _merge_default_with_oplog(graph, op_log=None): +def _merge_default_with_oplog(graph, op_log=None, run_meta=None): """Merge the tfprof default extra info with caller's op_log. Args: graph: tf.Graph. op_log: OpLog proto. + run_meta: RunMetadata proto used to complete shape information. Returns: tmp_op_log: Merged OpLog proto. """ tmp_op_log = tfprof_log_pb2.OpLog() - logged_ops = _get_logged_ops(graph) + logged_ops = _get_logged_ops(graph, run_meta) if not op_log: tmp_op_log.log_entries.extend(logged_ops.values()) else: @@ -95,20 +128,25 @@ def _merge_default_with_oplog(graph, op_log=None): return tmp_op_log -def write_op_log(graph, log_dir, op_log=None): +def write_op_log(graph, log_dir, op_log=None, run_meta=None): """Log provided 'op_log', and add additional model information below. The API also assigns ops in tf.trainable_variables() an op type called '_trainable_variables'. The API also logs 'flops' statistics for ops with op.RegisterStatistics() - defined. + defined. flops calculation depends on Tensor shapes defined in 'graph', + which might not be complete, 'run_meta', if provided, completes the shape + information with best effort. Args: graph: tf.Graph. log_dir: directory to write the log file. - op_log: OpLog proto. + op_log: (Optional) OpLog proto to be written. If not provided, an new + one is created. + run_meta: (Optional) RunMetadata proto that helps flops computation using + run time shape information. """ - op_log = _merge_default_with_oplog(graph, op_log) + op_log = _merge_default_with_oplog(graph, op_log, run_meta) with tf.gfile.Open(os.path.join(log_dir, 'tfprof_log'), 'w') as log: log.write(op_log.SerializeToString()) diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger_test.py b/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger_test.py new file mode 100644 index 0000000000..56a470e5f5 --- /dev/null +++ b/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger_test.py @@ -0,0 +1,72 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + + +class TFProfLoggerTest(tf.test.TestCase): + + def _BuildSmallPlaceholderlModel(self): + a = tf.placeholder(tf.int32, [2, 2]) + b = tf.placeholder(tf.int32, [2, 2]) + y = tf.matmul(a, b) + return a, b, y + + def _BuildSmallModel(self): + a = tf.constant([[1, 2], [3, 4]]) + b = tf.constant([[1, 2], [3, 4]]) + return tf.matmul(a, b) + + def testFillMissingShape(self): + a, b, y = self._BuildSmallPlaceholderlModel() + run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) + run_metadata = tf.RunMetadata() + sess = tf.Session() + sess.run(y, + options=run_options, + run_metadata=run_metadata, + feed_dict={a: [[1, 2], [2, 3]], + b: [[1, 2], [2, 3]]}) + + graph2 = tf.Graph() + # Use copy_op_to_graph to remove shape information. + y2 = tf.contrib.copy_graph.copy_op_to_graph(y, graph2, []) + self.assertEquals('<unknown>', str(y2.get_shape())) + + tf.contrib.tfprof.tfprof_logger._fill_missing_graph_shape(graph2, + run_metadata) + self.assertEquals('(2, 2)', str(y2.get_shape())) + + def testFailedFillMissingShape(self): + y = self._BuildSmallModel() + run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) + run_metadata = tf.RunMetadata() + sess = tf.Session() + sess.run(y, options=run_options, run_metadata=run_metadata) + + graph2 = tf.Graph() + y2 = tf.contrib.copy_graph.copy_op_to_graph(y, graph2, []) + self.assertEquals('<unknown>', str(y2.get_shape())) + # run_metadata has special name for MatMul, hence failed to fill shape. + tf.contrib.tfprof.tfprof_logger._fill_missing_graph_shape(graph2, + run_metadata) + self.assertEquals('<unknown>', str(y2.get_shape())) + + +if __name__ == '__main__': + tf.test.main() |