aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tfprof
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-09-30 12:20:29 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-09-30 13:34:53 -0700
commit113093b017c8b2654c052a054e63738174bae649 (patch)
tree63235478ca56348efaf9be32cee298f13a4eeafb /tensorflow/contrib/tfprof
parent7cbf9454c04d9d0d7f05a5ca688d174e0e82a2ed (diff)
1. Port latest tfprof changes
2. Allow users to call tfprof_logger as tf.contrib.tfprof.tfprof_logger 3. Add a test Change: 134819165
Diffstat (limited to 'tensorflow/contrib/tfprof')
-rw-r--r--tensorflow/contrib/tfprof/BUILD17
-rw-r--r--tensorflow/contrib/tfprof/README.md12
-rw-r--r--tensorflow/contrib/tfprof/__init__.py21
-rw-r--r--tensorflow/contrib/tfprof/python/tools/tfprof/BUILD19
-rw-r--r--tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py56
-rw-r--r--tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger_test.py72
6 files changed, 177 insertions, 20 deletions
diff --git a/tensorflow/contrib/tfprof/BUILD b/tensorflow/contrib/tfprof/BUILD
new file mode 100644
index 0000000000..d55bda1bd0
--- /dev/null
+++ b/tensorflow/contrib/tfprof/BUILD
@@ -0,0 +1,17 @@
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package(default_visibility = ["//tensorflow:__subpackages__"])
+
+py_library(
+ name = "tfprof",
+ srcs = [
+ "__init__.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:__subpackages__"],
+ deps = [
+ "//tensorflow/contrib/tfprof/python/tools/tfprof:tfprof_logger",
+ ],
+)
diff --git a/tensorflow/contrib/tfprof/README.md b/tensorflow/contrib/tfprof/README.md
index 0e6420134a..27a08c514c 100644
--- a/tensorflow/contrib/tfprof/README.md
+++ b/tensorflow/contrib/tfprof/README.md
@@ -1,12 +1,11 @@
# tfprof: A Profiling Tool for TensorFlow Models
-go/tfprof
+Internal User Please Use: go/tfprof
Author: Xin Pan (xpan@google.com, github: panyx0718)
-Consultants: Jon Shlens (shlens@google.com), Pete Warden (petewarden@google.com)
+Consultants: Jon Shlens, Pete Warden
-[TOC]
## Introduction
@@ -259,6 +258,7 @@ First, in Python code, create an `OpLog` proto and add op type
information to it:
```python
+
op_log = tfprof_log_pb2.OpLog()
entry = op_log.log_entries.add()
entry.name = 'pool_logit/DW'
@@ -274,7 +274,7 @@ entry.types.append('pool_logit')
Second, call write_op_log to write the OpLog proto.
```python
-tfprof_logger.write_op_log(sess.graph, /tmp/my_op_log_dir, op_log)
+tf.tfprof.tfprof_logger.write_op_log(sess.graph, /tmp/my_op_log_dir, op_log)
```
Third, when starting the tfprof tool, specify
@@ -288,8 +288,8 @@ _TFProfRoot (--/650 params)
```
Note that when you call
-`tfprof_logger.write_op_log(...)`, the tool adds all `Variables` inside
-`tf.trainable_variables()` to `_trainable_variables`.
+`tf.tfprof.tfprof_logger.write_op_log(...)`, the tool adds all `Variables`
+inside `tf.trainable_variables()` to `_trainable_variables`.
12) Run tfprof in one-shot mode and dump result to file.
diff --git a/tensorflow/contrib/tfprof/__init__.py b/tensorflow/contrib/tfprof/__init__.py
new file mode 100644
index 0000000000..ce777979b9
--- /dev/null
+++ b/tensorflow/contrib/tfprof/__init__.py
@@ -0,0 +1,21 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""tfprof is a tool that profiles various aspects of a TensorFlow model."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.tfprof.python.tools.tfprof import tfprof_logger
+from tensorflow.python.util.all_util import make_all
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/BUILD b/tensorflow/contrib/tfprof/python/tools/tfprof/BUILD
index d78020bbd8..87a8311486 100644
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/BUILD
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/BUILD
@@ -1,20 +1,29 @@
-package(
- default_visibility = ["//visibility:public"],
-)
-
licenses(["notice"]) # Apache 2.0
+package(default_visibility = ["//visibility:public"])
+
+load("//tensorflow:tensorflow.bzl", "tf_py_test")
+
py_library(
name = "tfprof_logger",
srcs = ["tfprof_logger.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow:tensorflow_py",
"//tensorflow/contrib/tfprof/tools/tfprof:protos_all_py",
"//tensorflow/python:framework_for_generated_wrappers",
],
)
+tf_py_test(
+ name = "tfprof_logger_test",
+ srcs = ["tfprof_logger_test.py"],
+ additional_deps = [
+ ":tfprof_logger",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/contrib/tfprof/tools/tfprof:protos_all_py",
+ ],
+)
+
# -----------------------------------------------------------------------------
# Google-internal targets. These must be at the end for syncrepo.
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py b/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py
index 4a487461a3..53dd2632b6 100644
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py
@@ -21,25 +21,57 @@ from __future__ import division
from __future__ import print_function
import os
+import sys
import tensorflow as tf
-from tensorflow.contrib.tfprof.python.tools.tfprof import tfprof_log_pb2
+from tensorflow.contrib.tfprof.tools.tfprof import tfprof_log_pb2
from tensorflow.python.framework import ops
TRAINABLE_VARIABLES = '_trainable_variables'
REGISTERED_FLOP_STATS = 'flops'
-def _get_logged_ops(graph):
+def _fill_missing_graph_shape(graph, run_meta):
+ """Fill Tensor shapes in 'graph' with run time shape from 'run_meta'."""
+ for dev_stat in run_meta.step_stats.dev_stats:
+ for node_stat in dev_stat.node_stats:
+ if not node_stat.output:
+ continue
+ try:
+ op = graph.get_operation_by_name(node_stat.node_name)
+ except KeyError as e:
+ # Graph doesn't contain the node_stat, usually RecvTensor.
+ continue
+ if len(node_stat.output) != len(op.outputs):
+ # For example, conditional op has only 1 output at run time.
+ continue
+ for (i, node_stat_out) in enumerate(node_stat.output):
+ if op.outputs[i].get_shape().is_fully_defined():
+ continue
+ node_stat_dims = node_stat_out.tensor_description.shape.dim
+ node_stat_shape = tf.TensorShape([d.size for d in node_stat_dims])
+ try:
+ op.outputs[i].set_shape(op.outputs[i].get_shape().merge_with(
+ node_stat_shape))
+ except ValueError as e:
+ sys.stderr.write('Node %s incompatible shapes: %s.\n' %
+ (node_stat.node_name, e))
+ return graph
+
+
+def _get_logged_ops(graph, run_meta=None):
"""Extract trainable model parameters and FLOPs for ops from a Graph.
Args:
graph: tf.Graph.
+ run_meta: RunMetadata proto used to complete shape information.
Returns:
logged_ops: dict mapping from op_name to OpLogEntry.
"""
- logged_ops = {}
+ if run_meta:
+ graph = _fill_missing_graph_shape(graph, run_meta)
+ logged_ops = {}
graph_def = graph.as_graph_def()
for node in graph_def.node:
try:
@@ -67,17 +99,18 @@ def _get_logged_ops(graph):
return logged_ops
-def _merge_default_with_oplog(graph, op_log=None):
+def _merge_default_with_oplog(graph, op_log=None, run_meta=None):
"""Merge the tfprof default extra info with caller's op_log.
Args:
graph: tf.Graph.
op_log: OpLog proto.
+ run_meta: RunMetadata proto used to complete shape information.
Returns:
tmp_op_log: Merged OpLog proto.
"""
tmp_op_log = tfprof_log_pb2.OpLog()
- logged_ops = _get_logged_ops(graph)
+ logged_ops = _get_logged_ops(graph, run_meta)
if not op_log:
tmp_op_log.log_entries.extend(logged_ops.values())
else:
@@ -95,20 +128,25 @@ def _merge_default_with_oplog(graph, op_log=None):
return tmp_op_log
-def write_op_log(graph, log_dir, op_log=None):
+def write_op_log(graph, log_dir, op_log=None, run_meta=None):
"""Log provided 'op_log', and add additional model information below.
The API also assigns ops in tf.trainable_variables() an op type called
'_trainable_variables'.
The API also logs 'flops' statistics for ops with op.RegisterStatistics()
- defined.
+ defined. flops calculation depends on Tensor shapes defined in 'graph',
+ which might not be complete. 'run_meta', if provided, completes the shape
+ information with best effort.
Args:
graph: tf.Graph.
log_dir: directory to write the log file.
- op_log: OpLog proto.
+ op_log: (Optional) OpLog proto to be written. If not provided, a new
+ one is created.
+ run_meta: (Optional) RunMetadata proto that helps flops computation using
+ run time shape information.
"""
- op_log = _merge_default_with_oplog(graph, op_log)
+ op_log = _merge_default_with_oplog(graph, op_log, run_meta)
with tf.gfile.Open(os.path.join(log_dir, 'tfprof_log'), 'w') as log:
log.write(op_log.SerializeToString())
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger_test.py b/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger_test.py
new file mode 100644
index 0000000000..56a470e5f5
--- /dev/null
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger_test.py
@@ -0,0 +1,72 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+class TFProfLoggerTest(tf.test.TestCase):
+
+ def _BuildSmallPlaceholderlModel(self):
+ a = tf.placeholder(tf.int32, [2, 2])
+ b = tf.placeholder(tf.int32, [2, 2])
+ y = tf.matmul(a, b)
+ return a, b, y
+
+ def _BuildSmallModel(self):
+ a = tf.constant([[1, 2], [3, 4]])
+ b = tf.constant([[1, 2], [3, 4]])
+ return tf.matmul(a, b)
+
+ def testFillMissingShape(self):
+ a, b, y = self._BuildSmallPlaceholderlModel()
+ run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
+ run_metadata = tf.RunMetadata()
+ sess = tf.Session()
+ sess.run(y,
+ options=run_options,
+ run_metadata=run_metadata,
+ feed_dict={a: [[1, 2], [2, 3]],
+ b: [[1, 2], [2, 3]]})
+
+ graph2 = tf.Graph()
+ # Use copy_op_to_graph to remove shape information.
+ y2 = tf.contrib.copy_graph.copy_op_to_graph(y, graph2, [])
+ self.assertEquals('<unknown>', str(y2.get_shape()))
+
+ tf.contrib.tfprof.tfprof_logger._fill_missing_graph_shape(graph2,
+ run_metadata)
+ self.assertEquals('(2, 2)', str(y2.get_shape()))
+
+ def testFailedFillMissingShape(self):
+ y = self._BuildSmallModel()
+ run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
+ run_metadata = tf.RunMetadata()
+ sess = tf.Session()
+ sess.run(y, options=run_options, run_metadata=run_metadata)
+
+ graph2 = tf.Graph()
+ y2 = tf.contrib.copy_graph.copy_op_to_graph(y, graph2, [])
+ self.assertEquals('<unknown>', str(y2.get_shape()))
+ # run_metadata has a special name for MatMul, hence it fails to fill the shape.
+ tf.contrib.tfprof.tfprof_logger._fill_missing_graph_shape(graph2,
+ run_metadata)
+ self.assertEquals('<unknown>', str(y2.get_shape()))
+
+
+if __name__ == '__main__':
+ tf.test.main()