author    Anna R <annarev@google.com>    2017-05-05 14:43:21 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>    2017-05-05 16:09:40 -0700
commit    63f7fae6a44686e1153f23d70a912cc0a83c7354 (patch)
tree      f8f4650fcc2551bea9f857754084d522697262c7 /tensorflow/contrib/tfprof
parent    37e3b71b49495af3873e7916a2ff28e598931b89 (diff)
Add a profiler that dumps data in pprof format.
Change: 155249446
Diffstat (limited to 'tensorflow/contrib/tfprof')
-rw-r--r--    tensorflow/contrib/tfprof/python/tools/tfprof/BUILD                     23
-rw-r--r--    tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler.py        445
-rw-r--r--    tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler_test.py   164
3 files changed, 632 insertions, 0 deletions
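
The change adds a pprof_profiler library plus its test and wires them into the tfprof BUILD file. A minimal usage sketch, assuming a standard TF 1.x session (the computation and output directory below are illustrative placeholders, not part of this change):

  # Sketch of intended usage, following the module docstring of pprof_profiler.py.
  # `my_computation` and '/tmp/pprof_out' are hypothetical placeholders.
  import tensorflow as tf
  from tensorflow.contrib.tfprof.python.tools.tfprof import pprof_profiler

  options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
  run_metadata = tf.RunMetadata()

  with tf.Session() as sess:
    result = sess.run(my_computation, options=options, run_metadata=run_metadata)
    # Writes one gzipped pprof proto per device and returns the file paths.
    profile_files = pprof_profiler.profile(sess.graph, run_metadata, '/tmp/pprof_out')
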
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/BUILD b/tensorflow/contrib/tfprof/python/tools/tfprof/BUILD
index 9c3b10b22c..818c2d2cbf 100644
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/BUILD
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/BUILD
@@ -97,6 +97,29 @@ py_test(
],
)
+py_library(
+ name = "pprof_profiler",
+ srcs = ["pprof_profiler.py"],
+ srcs_version = "PY2AND3",
+ deps = ["@pprof_profile_proto//:pprof_proto_py"],
+)
+
+py_test(
+ name = "pprof_profiler_test",
+ srcs = ["pprof_profiler_test.py"],
+ main = "pprof_profiler_test.py",
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"], # TODO(annarev): get it working with pip.
+ deps = [
+ ":pprof_profiler",
+ "//tensorflow/python:client",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ "@pprof_profile_proto//:pprof_proto_py",
+ ],
+)
+
# -----------------------------------------------------------------------------
# Google-internal targets. These must be at the end for syncrepo.
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler.py b/tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler.py
new file mode 100644
index 0000000000..c3fea915a3
--- /dev/null
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler.py
@@ -0,0 +1,445 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Profiler for TensorFlow models that outputs data in pprof format.
+
+See https://github.com/google/pprof/blob/master/proto/profile.proto for pprof
+profile format.
+The following needs to be set for the profiler to work:
+ * trace_level needs to be set to FULL_TRACE
+ * a run_metadata object should be passed to the session.run call
+
+Sample usage:
+ options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
+ run_metadata = tf.RunMetadata()
+
+ with tf.Session() as sess:
+ ...
+ sess.run(computation, run_metadata=run_metadata, options=options)
+ pprof_profiler.profile(sess.graph, run_metadata, output_dir)
+
+
+ The code above writes a separate pprof profile to an output_dir/.*.pb.gz
+ file for each device. These files can be passed to pprof for formatting,
+ for example:
+ pprof -png --nodecount=100 --sample_index=1 output_dir/profile_output.pb.gz
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from collections import defaultdict
+from collections import namedtuple
+import gzip
+import os
+import string
+import sys
+import time
+
+from pprof import profile_pb2
+
+
+if sys.version_info < (3,):
+ maketrans = string.maketrans
+else:
+ maketrans = str.maketrans
+
+
+ProfileDatum = namedtuple('ProfileDatum', [
+ 'node_exec_stats', 'op_type', 'traceback'])
+
+
+class StringTable(object):
+ """Keeps track of strings to add to string_table in pprof proto."""
+
+ def __init__(self):
+ # Pprof requires first entry in string_table to be ''.
+ self._string_table = ['']
+ self._string_to_index = {'': 0}
+
+ def index_of(self, value_str):
+ """Get index of value_str in the string table.
+
+ If value_str is not in the string table, we will add it at the end
+ and then return the new index.
+ Args:
+ value_str: (string) Value to lookup/add in/to the string table.
+
+ Returns:
+ Index of value_str in the string table.
+ """
+ if value_str is None:
+ value_str = ''
+ if value_str in self._string_to_index:
+ return self._string_to_index[value_str]
+ index = len(self._string_table)
+ self._string_table.append(value_str)
+ self._string_to_index[value_str] = index
+ return index
+
+ def next_index(self):
+ """Gets index that would be assigned to the next added string.
+
+ Returns:
+ Index of the next string if it was added.
+ """
+ return len(self._string_table)
+
+ def string_table(self):
+ """Returns a list of strings to store in pprof's string_table."""
+ return self._string_table
+
+
+class Functions(object):
+ """Keeps track of `Function` protos for pprof profile."""
+
+ def __init__(self, string_table):
+ """Constructor.
+
+ Args:
+ string_table: A `StringTable` object.
+ """
+ self._string_table = string_table
+ # Maps tuples in the form (file_path, function_name, start_line_number)
+ # to `Function` protos.
+ self._function_key_to_function = {}
+
+ def index_of(self, file_path, function_name, function_start_line):
+ """Returns index of the function, adding the function if needed.
+
+ Args:
+ file_path: (string) Path to file where the function is defined.
+ function_name: (string) Function name.
+ function_start_line: (integer) Start line number of function definition.
+
+ Returns:
+ Function index.
+ """
+ function_key = (file_path, function_name, function_start_line)
+ if function_key in self._function_key_to_function:
+ return self._function_key_to_function[function_key].id
+ else:
+ # Function indexes should start from 1
+ function_index = len(self._function_key_to_function) + 1
+ function = profile_pb2.Function()
+ function.id = function_index
+ function.name = self._string_table.index_of(function_name)
+ function.filename = self._string_table.index_of(file_path)
+ function.start_line = function_start_line
+ self._function_key_to_function[function_key] = function
+ return function_index
+
+ def function_protos(self):
+ """Returns list of `profile_pb2.Function` protos."""
+ return self._function_key_to_function.values()
+
+
+class Locations(object):
+ """Keeps track of `Location` protos for pprof profile.
+
+ `Locations` store information about function call locations.
+ """
+
+ def __init__(self, functions):
+ """Constructor.
+
+ Args:
+ functions: A `Functions` object.
+ """
+ self._functions = functions
+ # Maps tuples in the form (file_path, called_function_name, line_number)
+ # to `Location` protos.
+ self._location_key_to_location = {}
+
+ def index_of(
+ self, file_path, line_number, called_function_name, called_file_path,
+ called_function_start_line):
+ """Returns index of the location, adding the location if needed.
+
+ Args:
+ file_path: (string) Path to file that makes the call.
+ line_number: (integer) Call line number.
+ called_function_name: (string) Function name of the function called at
+ `file_path` and `line_number`.
+ called_file_path: (string) Path to file where the called function is
+ defined.
+ called_function_start_line: (integer) Start line number of called
+ function definition in `called_file_path` file.
+
+ Returns:
+ Index of location.
+ """
+ location_key = (file_path, called_function_name, line_number)
+ if location_key in self._location_key_to_location:
+ location = self._location_key_to_location[location_key]
+ return location.id
+ else:
+ # Location indexes should start from 1
+ location_index = len(self._location_key_to_location) + 1
+ location = profile_pb2.Location()
+ location.id = location_index
+ self._location_key_to_location[location_key] = location
+
+ line = location.line.add()
+ line.function_id = self._functions.index_of(
+ called_file_path, called_function_name, called_function_start_line)
+ line.line = line_number
+ return location_index
+
+ def location_protos(self):
+ """Returns list of `profile_pb2.Location` protos."""
+ return self._location_key_to_location.values()
+
+
+class Samples(object):
+ """Keeps track of `Sample` protos for pprof profile.
+
+ Samples store the following statistics in order:
+ count, all_time, op_time
+ """
+
+ def __init__(self, string_table):
+ """Constructor.
+
+ Args:
+ string_table: A `StringTable` object.
+ """
+ self._string_table = string_table
+ # TODO(annarev): figure out if location is unique for each node name.
+ # If not, also key this dictionary based on location ids.
+ self._node_name_to_sample = {}
+
+ def add(self, datum, location_ids):
+ """Adds a sample data point.
+
+ Args:
+ datum: `ProfileDatum` to add a sample for.
+ location_ids: List of numeric location ids for this
+ sample.
+ """
+ node_name = datum.node_exec_stats.node_name
+ if node_name in self._node_name_to_sample:
+ sample = self._node_name_to_sample[node_name]
+ sample.location_id.extend(location_ids)
+ else:
+ sample = profile_pb2.Sample()
+ # Sample stores 3 values: count, all_time, op_time
+ sample.value.extend([0, 0, 0])
+
+ label = sample.label.add()
+ label.key = self._string_table.index_of('node_name')
+ label.str = self._string_table.index_of(node_name)
+ label = sample.label.add()
+ label.key = self._string_table.index_of('op_type')
+ label.str = self._string_table.index_of(datum.op_type)
+ self._node_name_to_sample[node_name] = sample
+ sample.value[0] += 1
+ sample.value[1] += datum.node_exec_stats.all_end_rel_micros
+ sample.value[2] += (
+ datum.node_exec_stats.op_end_rel_micros -
+ datum.node_exec_stats.op_start_rel_micros)
+
+ def get_sample_protos(self):
+ """Returns list of `Sample` protos for pprof profile."""
+ return self._node_name_to_sample.values()
+
+
+class PprofProfiler(object):
+ """Creates profiles in pprof format."""
+
+ def __init__(self, graph, run_metadata):
+ """Constructor.
+
+ Args:
+ graph: A `Graph` instance.
+ run_metadata: A `RunMetadata` proto.
+ """
+ self._graph = graph
+ self._run_metadata = run_metadata
+ self._string_table = StringTable()
+ self._functions = Functions(self._string_table)
+ self._locations = Locations(self._functions)
+
+ def profile(self):
+ """Generates pprof profiles.
+
+ Returns:
+ Dictionary mapping from device name to proto in `profile_pb2.Profile`
+ format.
+ """
+ profiles = {}
+ data_generator_func = self._get_profile_data_generator()
+ for device_index, device_stats in enumerate(
+ self._run_metadata.step_stats.dev_stats):
+ # Create profile
+ pprof_proto = self._get_pprof_proto(data_generator_func(device_stats))
+ if not pprof_proto.sample:
+ print(
+ 'Not enough data to create profile for device %s. Did you pass '
+ 'RunMetadata to the session.run call?' % device_stats.device)
+ continue
+ # Add device name comment
+ device_count = len(self._run_metadata.step_stats.dev_stats)
+ device_description = (
+ 'Device %d of %d: %s' %
+ (device_index + 1, device_count, device_stats.device))
+ device_description_str_index = self._string_table.next_index()
+ pprof_proto.string_table.append(device_description)
+ pprof_proto.comment.append(device_description_str_index)
+ profiles[device_stats.device] = pprof_proto
+ return profiles
+
+ def _get_pprof_proto(self, profile_datum_generator):
+ """Returns profile data in pprof proto format.
+
+ Args:
+ profile_datum_generator: Generator outputting `ProfileDatum` objects.
+
+ Returns:
+ A proto in pprof format.
+ """
+ pprof_profile = profile_pb2.Profile()
+ samples = Samples(self._string_table)
+
+ for datum in profile_datum_generator:
+ if not datum.traceback:
+ continue
+
+ stack_frame = datum.traceback[-1]
+ after_apply_op = False
+ location_ids = []
+
+ # We add locations from stack trace in bottom-up order.
+ for stack_frame_index in reversed(range(len(datum.traceback) - 1)):
+ prev_stack_frame = stack_frame
+ stack_frame = datum.traceback[stack_frame_index]
+
+ # Call at current frame calls function at previous frame.
+ prev_file_path = prev_stack_frame[0]
+ prev_function = prev_stack_frame[2]
+ prev_function_start_line = prev_stack_frame[4]
+ curr_file_path = stack_frame[0]
+ curr_line_number = stack_frame[1]
+
+ # Skip all calls up to apply_op since they are the same for all ops.
+ if not after_apply_op:
+ if prev_function == 'apply_op':
+ after_apply_op = True
+ continue
+ location_index = self._locations.index_of(
+ curr_file_path, curr_line_number,
+ prev_function, prev_file_path, prev_function_start_line)
+ location_ids.append(location_index)
+ samples.add(datum, location_ids)
+
+ sample_type_description = 'count'
+ sample_type = pprof_profile.sample_type.add()
+ sample_type.type = self._string_table.index_of(sample_type_description)
+ sample_type.unit = self._string_table.index_of('count')
+ sample_type_description = 'all_time'
+ sample_type = pprof_profile.sample_type.add()
+ sample_type.type = self._string_table.index_of(sample_type_description)
+ sample_type.unit = self._string_table.index_of('nanoseconds')
+ sample_type_description = 'op_time'
+ sample_type = pprof_profile.sample_type.add()
+ sample_type.type = self._string_table.index_of(sample_type_description)
+ sample_type.unit = self._string_table.index_of('nanoseconds')
+
+ pprof_profile.string_table.extend(self._string_table.string_table())
+ pprof_profile.sample.extend(samples.get_sample_protos())
+ pprof_profile.function.extend(self._functions.function_protos())
+ pprof_profile.location.extend(self._locations.location_protos())
+ return pprof_profile
+
+ def _get_profile_data_generator(self):
+ """Get function that generates `ProfileDatum` objects.
+
+ Returns:
+ A function that generates `ProfileDatum` objects.
+ """
+ node_to_traceback = defaultdict(list)
+ node_to_op_type = defaultdict(str)
+ for op in self._graph.get_operations():
+ node_to_traceback[op.name] = op.traceback_with_start_lines
+ node_to_op_type[op.name] = op.type
+
+ def profile_data_generator(device_step_stats):
+ for node_stats in device_step_stats.node_stats:
+ if node_stats.node_name == '_SOURCE' or node_stats.node_name == '_SINK':
+ continue
+ yield ProfileDatum(
+ node_stats,
+ node_to_op_type[node_stats.node_name],
+ node_to_traceback[node_stats.node_name])
+
+ return profile_data_generator
+
+
+def get_profiles(graph, run_metadata):
+ """Generate profiles in pprof format.
+
+ See https://github.com/google/pprof/blob/master/proto/profile.proto
+ for pprof proto format.
+
+ Args:
+ graph: A `Graph` object.
+ run_metadata: A `RunMetadata` proto.
+
+ Returns:
+ A dictionary mapping from device name to pprof proto for that device.
+ """
+ return PprofProfiler(graph, run_metadata).profile()
+
+
+def profile(graph, run_metadata, output_dir=None):
+ """Generate profiles in pprof format.
+
+ See https://github.com/google/pprof/blob/master/proto/profile.proto
+ for pprof proto format.
+
+ Args:
+ graph: A `Graph` object.
+ run_metadata: A `RunMetadata` proto.
+ output_dir: (string) Directory to output pprof profile to.
+ Profile files for each device will be stored in compressed
+ serialized proto format. If output_dir is None, profile protos
+ will be printed to stdout instead.
+
+ Returns:
+ List of output files created by this profile call.
+ (Note: this list will be empty if output_dir is None)
+ """
+ profiles = get_profiles(graph, run_metadata)
+ output_file_template = None
+ if output_dir:
+ if not os.path.isdir(output_dir):
+ os.makedirs(output_dir)
+ time_suffix = time.strftime('%Y%m%d%H%M%S')
+ output_file_template = os.path.join(
+ output_dir, '%s_' + time_suffix + '.pb.gz')
+
+ profile_files = []
+ for device, pprof_proto in profiles.items():
+ if output_file_template is None:
+ print('No output directory specified, printing to stdout instead.')
+ print(pprof_proto)
+ else:
+ device_name = str(device).strip('/').translate(
+ maketrans('/:', '__'))
+ profile_file = output_file_template % device_name
+ profile_files.append(profile_file)
+ with gzip.open(profile_file, 'w') as output_file:
+ print('Writing profile to %s...' % profile_file)
+ output_file.write(pprof_proto.SerializeToString())
+ return profile_files
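
profile() above writes one gzipped, serialized profile_pb2.Profile per device. A short sketch of reading such a file back, mirroring what pprof_profiler_test.py below does (the file name is a hypothetical placeholder following the device-plus-timestamp pattern produced above):

  import gzip
  from pprof import profile_pb2

  # Hypothetical output file name; real names depend on device and timestamp.
  with gzip.open('/tmp/pprof_out/device_CPU_0_20170505143021.pb.gz', 'rb') as f:
    profile = profile_pb2.Profile()
    profile.ParseFromString(f.read())
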
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler_test.py b/tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler_test.py
new file mode 100644
index 0000000000..13d3fb41ac
--- /dev/null
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler_test.py
@@ -0,0 +1,164 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for pprof_profiler."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gzip
+
+from pprof import profile_pb2
+from tensorflow.contrib.tfprof.python.tools.tfprof import pprof_profiler
+from tensorflow.core.framework import step_stats_pb2
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class PprofProfilerTest(test.TestCase):
+
+ def testDataEmpty(self):
+ output_dir = test.get_temp_dir()
+ run_metadata = config_pb2.RunMetadata()
+ graph = test.mock.MagicMock()
+ graph.get_operations.return_value = []
+
+ profiles = pprof_profiler.get_profiles(graph, run_metadata)
+ self.assertEquals(0, len(profiles))
+ profile_files = pprof_profiler.profile(
+ graph, run_metadata, output_dir)
+ self.assertEquals(0, len(profile_files))
+
+ def testRunMetadataEmpty(self):
+ output_dir = test.get_temp_dir()
+ run_metadata = config_pb2.RunMetadata()
+ graph = test.mock.MagicMock()
+ op1 = test.mock.MagicMock()
+ op1.name = 'Add/123'
+ op1.traceback = [('a/b/file1', 10, 'some_var')]
+ op1.type = 'add'
+ graph.get_operations.return_value = [op1]
+
+ profiles = pprof_profiler.get_profiles(graph, run_metadata)
+ self.assertEquals(0, len(profiles))
+ profile_files = pprof_profiler.profile(
+ graph, run_metadata, output_dir)
+ self.assertEquals(0, len(profile_files))
+
+ def testValidProfile(self):
+ output_dir = test.get_temp_dir()
+ run_metadata = config_pb2.RunMetadata()
+
+ node1 = step_stats_pb2.NodeExecStats(
+ node_name='Add/123',
+ op_start_rel_micros=3,
+ op_end_rel_micros=5,
+ all_end_rel_micros=4)
+
+ run_metadata = config_pb2.RunMetadata()
+ device1 = run_metadata.step_stats.dev_stats.add()
+ device1.device = 'deviceA'
+ device1.node_stats.extend([node1])
+
+ graph = test.mock.MagicMock()
+ op1 = test.mock.MagicMock()
+ op1.name = 'Add/123'
+ op1.traceback = [
+ ('a/b/file1', 10, 'apply_op', 'abc'), ('a/c/file2', 12, 'my_op', 'def')]
+ op1.type = 'add'
+ graph.get_operations.return_value = [op1]
+
+ expected_proto = """sample_type {
+ type: 5
+ unit: 5
+}
+sample_type {
+ type: 6
+ unit: 7
+}
+sample_type {
+ type: 8
+ unit: 7
+}
+sample {
+ value: 1
+ value: 4
+ value: 2
+ label {
+ key: 1
+ str: 2
+ }
+ label {
+ key: 3
+ str: 4
+ }
+}
+string_table: ""
+string_table: "node_name"
+string_table: "Add/123"
+string_table: "op_type"
+string_table: "add"
+string_table: "count"
+string_table: "all_time"
+string_table: "nanoseconds"
+string_table: "op_time"
+string_table: "Device 1 of 1: deviceA"
+comment: 9
+"""
+ # Test with protos
+ profiles = pprof_profiler.get_profiles(graph, run_metadata)
+ self.assertEquals(1, len(profiles))
+ self.assertTrue('deviceA' in profiles)
+ self.assertEquals(expected_proto, str(profiles['deviceA']))
+ # Test with files
+ profile_files = pprof_profiler.profile(
+ graph, run_metadata, output_dir)
+ self.assertEquals(1, len(profile_files))
+ with gzip.open(profile_files[0]) as profile_file:
+ profile_contents = profile_file.read()
+ profile = profile_pb2.Profile()
+ profile.ParseFromString(profile_contents)
+ self.assertEquals(expected_proto, str(profile))
+
+ def testProfileWithWhileLoop(self):
+ options = config_pb2.RunOptions()
+ options.trace_level = config_pb2.RunOptions.FULL_TRACE
+ run_metadata = config_pb2.RunMetadata()
+
+ num_iters = 5
+ with self.test_session() as sess:
+ i = constant_op.constant(0)
+ c = lambda i: math_ops.less(i, num_iters)
+ b = lambda i: math_ops.add(i, 1)
+ r = control_flow_ops.while_loop(c, b, [i])
+ sess.run(r, options=options, run_metadata=run_metadata)
+ profiles = pprof_profiler.get_profiles(sess.graph, run_metadata)
+ self.assertEquals(1, len(profiles))
+ profile = next(iter(profiles.values()))
+ add_samples = [] # Samples for the while/Add node
+ for sample in profile.sample:
+ if profile.string_table[sample.label[0].str] == 'while/Add':
+ add_samples.append(sample)
+ # Values for same nodes are aggregated.
+ self.assertEquals(1, len(add_samples))
+ # Value of "count" should be equal to number of iterations.
+ self.assertEquals(num_iters, add_samples[0].value[0])
+
+
+if __name__ == '__main__':
+ test.main()
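
For direct inspection of the protos returned by get_profiles(), a small sketch that resolves sample labels through the string table and prints values in the (count, all_time, op_time) order defined by the Samples class; `profile` is assumed to be one of the returned protos:

  def summarize(profile):
    """Prints per-node sample values from a profile_pb2.Profile."""
    for sample in profile.sample:
      # label.key and label.str are indices into the profile's string table.
      labels = {profile.string_table[l.key]: profile.string_table[l.str]
                for l in sample.label}
      count, all_time, op_time = sample.value[0], sample.value[1], sample.value[2]
      print('%s (%s): count=%d, all_time=%d, op_time=%d' %
            (labels.get('node_name', ''), labels.get('op_type', ''),
             count, all_time, op_time))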