diff options
author | 2017-05-05 14:43:21 -0800 | |
---|---|---|
committer | 2017-05-05 16:09:40 -0700 | |
commit | 63f7fae6a44686e1153f23d70a912cc0a83c7354 (patch) | |
tree | f8f4650fcc2551bea9f857754084d522697262c7 /tensorflow/contrib/tfprof | |
parent | 37e3b71b49495af3873e7916a2ff28e598931b89 (diff) |
Add a profiler that dumps data in pprof format.
Change: 155249446
Diffstat (limited to 'tensorflow/contrib/tfprof')
3 files changed, 632 insertions, 0 deletions
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/BUILD b/tensorflow/contrib/tfprof/python/tools/tfprof/BUILD index 9c3b10b22c..818c2d2cbf 100644 --- a/tensorflow/contrib/tfprof/python/tools/tfprof/BUILD +++ b/tensorflow/contrib/tfprof/python/tools/tfprof/BUILD @@ -97,6 +97,29 @@ py_test( ], ) +py_library( + name = "pprof_profiler", + srcs = ["pprof_profiler.py"], + srcs_version = "PY2AND3", + deps = ["@pprof_profile_proto//:pprof_proto_py"], +) + +py_test( + name = "pprof_profiler_test", + srcs = ["pprof_profiler_test.py"], + main = "pprof_profiler_test.py", + srcs_version = "PY2AND3", + tags = ["no_pip"], # TODO(annarev): get it working with pip. + deps = [ + ":pprof_profiler", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + "@pprof_profile_proto//:pprof_proto_py", + ], +) + # ----------------------------------------------------------------------------- # Google-internal targets. These must be at the end for syncrepo. diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler.py b/tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler.py new file mode 100644 index 0000000000..c3fea915a3 --- /dev/null +++ b/tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler.py @@ -0,0 +1,445 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# ==============================================================================
"""Profiler for TensorFlow models that outputs data in pprof format.

See https://github.com/google/pprof/blob/master/proto/profile.proto for pprof
profile format.
The following needs to be set for profiler to work:
  * trace_level needs to be set to FULL_TRACE
  * run_metadata object should be passed in to session.run call

Sample usage:
  options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
  run_metadata = tf.RunMetadata()

  with tf.Session() as sess:
    ...
    sess.run(computation, run_metadata=run_metadata, options=options)
  pprof_profiler.profile(sess.graph, run_metadata, output_dir)


  The code above would output a pprof profile to a separate
  output_dir/.*.pb.gz file for each device. These files can be passed to pprof
  for formatting. For example:
     pprof -png --nodecount=100 --sample_index=1 output_dir/profile_output.pb.gz
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import defaultdict
from collections import namedtuple
import gzip
import os
import string
import sys
import time

from pprof import profile_pb2


# `string.maketrans` only exists on Python 2; Python 3 exposes it on `str`.
if sys.version_info < (3,):
  maketrans = string.maketrans
else:
  maketrans = str.maketrans


# One executed node: its `NodeExecStats`, the op type recorded for the node in
# the graph, and the op's traceback (with function start lines).
ProfileDatum = namedtuple('ProfileDatum', [
    'node_exec_stats', 'op_type', 'traceback'])


class StringTable(object):
  """Keeps track of strings to add to string_table in pprof proto."""

  def __init__(self):
    # Pprof requires first entry in string_table to be ''.
    self._string_table = ['']
    self._string_to_index = {'': 0}

  def index_of(self, value_str):
    """Get index of value_str in the string table.

    If value_str is not in the string table, we will add it at the end
    and then return the new index.

    Args:
      value_str: (string) Value to lookup/add in/to the string table. `None`
        is treated as the empty string.

    Returns:
      Index of value_str in the string table.
    """
    if value_str is None:
      value_str = ''
    if value_str in self._string_to_index:
      return self._string_to_index[value_str]
    index = len(self._string_table)
    self._string_table.append(value_str)
    self._string_to_index[value_str] = index
    return index

  def next_index(self):
    """Gets index that would be assigned to the next added string.

    Returns:
      Index of the next string if it was added.
    """
    return len(self._string_table)

  def string_table(self):
    """Returns a list of strings to store in pprof's string_table."""
    return self._string_table


class Functions(object):
  """Keeps track of `Function` protos for pprof profile."""

  def __init__(self, string_table):
    """Constructor.

    Args:
      string_table: A `StringTable` object.
    """
    self._string_table = string_table
    # Maps tuples in the form (file_path, function_name, start_line_number)
    # to `Function` protos.
    self._function_key_to_function = {}

  def index_of(self, file_path, function_name, function_start_line):
    """Returns index of the function, adding the function if needed.

    Args:
      file_path: (string) Path to file where the function is defined.
      function_name: (string) Function name.
      function_start_line: (integer) Start line number of function definition.

    Returns:
      Function index.
    """
    function_key = (file_path, function_name, function_start_line)
    if function_key in self._function_key_to_function:
      return self._function_key_to_function[function_key].id

    # Function indexes should start from 1
    function_index = len(self._function_key_to_function) + 1
    function = profile_pb2.Function()
    function.id = function_index
    function.name = self._string_table.index_of(function_name)
    function.filename = self._string_table.index_of(file_path)
    function.start_line = function_start_line
    self._function_key_to_function[function_key] = function
    return function_index

  def function_protos(self):
    """Returns list of `profile_pb2.Function` protos."""
    # Materialize so the documented "list" also holds on Python 3, where
    # dict.values() is a live view.
    return list(self._function_key_to_function.values())


class Locations(object):
  """Keeps track of `Location` protos for pprof profile.

  `Locations` store information about function call locations.
  """

  def __init__(self, functions):
    """Constructor.

    Args:
      functions: A `Functions` object.
    """
    self._functions = functions
    # Maps tuples in the form (file_path, called_function_name, line_number)
    # to `Location` protos.
    self._location_key_to_location = {}

  def index_of(
      self, file_path, line_number, called_function_name, called_file_path,
      called_function_start_line):
    """Returns index of the location, adding the location if needed.

    Args:
      file_path: (string) Path to file that makes the call.
      line_number: (integer) Call line number.
      called_function_name: (string) Function name of the function called at
        `file_path` and `line_number`.
      called_file_path: (string) Path to file where the called function is
        defined.
      called_function_start_line: (integer) Start line number of called
        function definition in `called_file_path` file.

    Returns:
      Index of location.
    """
    location_key = (file_path, called_function_name, line_number)
    if location_key in self._location_key_to_location:
      return self._location_key_to_location[location_key].id

    # Location indexes should start from 1
    location_index = len(self._location_key_to_location) + 1
    location = profile_pb2.Location()
    location.id = location_index
    self._location_key_to_location[location_key] = location

    line = location.line.add()
    line.function_id = self._functions.index_of(
        called_file_path, called_function_name, called_function_start_line)
    line.line = line_number
    return location_index

  def location_protos(self):
    """Returns list of `profile_pb2.Location` protos."""
    return list(self._location_key_to_location.values())


class Samples(object):
  """Keeps track of `Sample` protos for pprof profile.

  Samples store the following statistics in order:
  count, all_time, op_time
  """

  def __init__(self, string_table):
    """Constructor.

    Args:
      string_table: A `StringTable` object.
    """
    self._string_table = string_table
    # TODO(annarev): figure out if location is unique for each node name.
    # If not, also key this dictionary based on location ids.
    self._node_name_to_sample = {}

  def add(self, datum, location_ids):
    """Adds a sample data point.

    All data points with the same node name are aggregated into one `Sample`:
    the count is incremented and both time values are summed.

    Args:
      datum: `ProfileDatum` to add a sample for.
      location_ids: List of numeric location ids for this sample.
    """
    stats = datum.node_exec_stats
    node_name = stats.node_name
    sample = self._node_name_to_sample.get(node_name)
    if sample is None:
      sample = profile_pb2.Sample()
      # Sample stores 3 values: count, all_time, op_time
      sample.value.extend([0, 0, 0])
      # Record the call stack once, when the sample is created. Previously the
      # stack was skipped for the first data point and appended again for
      # every later one, so a node executed once had no stack and a node
      # executed N times carried N - 1 concatenated copies.
      sample.location_id.extend(location_ids)

      label = sample.label.add()
      label.key = self._string_table.index_of('node_name')
      label.str = self._string_table.index_of(node_name)
      label = sample.label.add()
      label.key = self._string_table.index_of('op_type')
      label.str = self._string_table.index_of(datum.op_type)
      self._node_name_to_sample[node_name] = sample

    sample.value[0] += 1
    sample.value[1] += stats.all_end_rel_micros
    sample.value[2] += stats.op_end_rel_micros - stats.op_start_rel_micros

  def get_sample_protos(self):
    """Returns list of `Sample` protos for pprof profile."""
    return list(self._node_name_to_sample.values())


class PprofProfiler(object):
  """Creates profiles in pprof format."""

  def __init__(self, graph, run_metadata):
    """Constructor.

    Args:
      graph: A `Graph` instance.
      run_metadata: A `RunMetadata` proto; its `step_stats.dev_stats` entries
        are profiled.
    """
    self._graph = graph
    self._run_metadata = run_metadata
    self._string_table = StringTable()
    self._functions = Functions(self._string_table)
    self._locations = Locations(self._functions)

  def profile(self):
    """Generates pprof profiles.

    Devices for which no sample could be created are reported on stdout and
    left out of the result.

    Returns:
      Dictionary mapping from device name to proto in `profile_pb2.Profile`
      format.
    """
    profiles = {}
    data_generator_func = self._get_profile_data_generator()
    all_device_stats = self._run_metadata.step_stats.dev_stats
    device_count = len(all_device_stats)
    for device_index, device_stats in enumerate(all_device_stats):
      # Create profile
      pprof_proto = self._get_pprof_proto(data_generator_func(device_stats))
      if not pprof_proto.sample:
        print(
            'Not enough data to create profile for device %s. Did you pass '
            'RunMetadata to session.run call?' % device_stats.device)
        continue
      # Add device name comment. The description is appended only to this
      # proto's string table (not to the shared `StringTable`), so its index
      # is the size of the shared table at this point.
      device_description = (
          'Device %d of %d: %s' %
          (device_index + 1, device_count, device_stats.device))
      device_description_str_index = self._string_table.next_index()
      pprof_proto.string_table.append(device_description)
      pprof_proto.comment.append(device_description_str_index)
      profiles[device_stats.device] = pprof_proto
    return profiles

  def _get_location_ids(self, traceback):
    """Returns location ids for the frames of `traceback` after `apply_op`.

    Args:
      traceback: Non-empty sequence of stack frames. Index 0 of a frame is
        read as the file path, 1 as the line number, 2 as the function name
        and 4 as the start line of the function.

    Returns:
      List of location ids, starting from the frames at the end of
      `traceback`.
    """
    location_ids = []
    after_apply_op = False
    # We add locations from stack trace in bottom-up order: the call at frame
    # `index` calls the function of frame `index + 1`.
    for index in reversed(range(len(traceback) - 1)):
      called_frame = traceback[index + 1]
      calling_frame = traceback[index]
      called_function = called_frame[2]

      # Skip all calls up to apply_op since they are the same for all ops.
      if not after_apply_op:
        if called_function == 'apply_op':
          after_apply_op = True
        continue

      location_ids.append(self._locations.index_of(
          calling_frame[0], calling_frame[1],
          called_function, called_frame[0], called_frame[4]))
    return location_ids

  def _add_sample_type(self, pprof_profile, type_name, unit):
    """Appends a sample type with the given name and unit to the profile."""
    sample_type = pprof_profile.sample_type.add()
    sample_type.type = self._string_table.index_of(type_name)
    sample_type.unit = self._string_table.index_of(unit)

  def _get_pprof_proto(self, profile_datum_generator):
    """Returns profile data in pprof proto format.

    Args:
      profile_datum_generator: Generator outputting `ProfileDatum` objects.

    Returns:
      A proto in pprof format.
    """
    pprof_profile = profile_pb2.Profile()
    samples = Samples(self._string_table)

    for datum in profile_datum_generator:
      # Nodes without a traceback cannot be attributed to any location.
      if not datum.traceback:
        continue
      samples.add(datum, self._get_location_ids(datum.traceback))

    self._add_sample_type(pprof_profile, 'count', 'count')
    # NOTE(review): the time values are accumulated from `*_rel_micros`
    # fields, yet the unit is declared as 'nanoseconds'. The string is part of
    # the emitted profile, so it is left unchanged here; confirm the intended
    # unit.
    self._add_sample_type(pprof_profile, 'all_time', 'nanoseconds')
    self._add_sample_type(pprof_profile, 'op_time', 'nanoseconds')

    pprof_profile.string_table.extend(self._string_table.string_table())
    pprof_profile.sample.extend(samples.get_sample_protos())
    pprof_profile.function.extend(self._functions.function_protos())
    pprof_profile.location.extend(self._locations.location_protos())
    return pprof_profile

  def _get_profile_data_generator(self):
    """Get function that generates `ProfileDatum` objects.

    Returns:
      A function that generates `ProfileDatum` objects.
    """
    # Nodes that are absent from the graph fall back to an empty traceback
    # and an empty op type.
    node_to_traceback = defaultdict(list)
    node_to_op_type = defaultdict(str)
    for op in self._graph.get_operations():
      node_to_traceback[op.name] = op.traceback_with_start_lines
      node_to_op_type[op.name] = op.type

    def profile_data_generator(device_step_stats):
      for node_stats in device_step_stats.node_stats:
        if node_stats.node_name in ('_SOURCE', '_SINK'):
          continue
        yield ProfileDatum(
            node_stats,
            node_to_op_type[node_stats.node_name],
            node_to_traceback[node_stats.node_name])

    return profile_data_generator


def get_profiles(graph, run_metadata):
  """Generate profiles in pprof format.

  See https://github.com/google/pprof/blob/master/proto/profile.proto
  for pprof proto format.

  Args:
    graph: A `Graph` object.
    run_metadata: A `RunMetadata` proto.

  Returns:
    A dictionary mapping from device name to pprof proto for that device.
  """
  return PprofProfiler(graph, run_metadata).profile()


def profile(graph, run_metadata, output_dir=None):
  """Generate profiles in pprof format.

  See https://github.com/google/pprof/blob/master/proto/profile.proto
  for pprof proto format.

  Args:
    graph: A `Graph` object.
    run_metadata: A `RunMetadata` proto.
    output_dir: (string) Directory to output pprof profile to.
      Profile files for each device will be stored in compressed
      serialized proto format. If output_dir is None, profile protos
      will be printed to stdout instead.

  Returns:
    List of output files created by this profile call.
    (Note: this list will be empty if output_dir is None)
  """
  profiles = get_profiles(graph, run_metadata)
  output_file_template = None
  if output_dir:
    if not os.path.isdir(output_dir):
      os.makedirs(output_dir)
    time_suffix = time.strftime('%Y%m%d%H%M%S')
    output_file_template = os.path.join(
        output_dir, '%s_' + time_suffix + '.pb.gz')

  profile_files = []
  for device, pprof_proto in profiles.items():
    if output_file_template is None:
      print('No output directory specified, printing to stdout instead.')
      print(pprof_proto)
      continue

    # Turn the device name into a file-name friendly string.
    device_name = str(device).strip('/').translate(maketrans('/:', '__'))
    profile_file = output_file_template % device_name
    profile_files.append(profile_file)
    # The serialized proto is bytes, so open the gzip stream in binary mode.
    with gzip.open(profile_file, 'wb') as output_file:
      print('Writing profile to %s...' % profile_file)
      output_file.write(pprof_proto.SerializeToString())
  return profile_files


# diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler_test.py b/tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler_test.py
# new file mode 100644
# index 0000000000..13d3fb41ac
# --- /dev/null
# +++ b/tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler_test.py
# @@ -0,0 +1,164 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for pprof_profiler."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gzip

from pprof import profile_pb2
from tensorflow.contrib.tfprof.python.tools.tfprof import pprof_profiler
from tensorflow.core.framework import step_stats_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


class PprofProfilerTest(test.TestCase):
  """Checks the protos and files produced by `pprof_profiler`."""

  def testDataEmpty(self):
    """An empty graph and empty `RunMetadata` produce no profiles."""
    output_dir = test.get_temp_dir()
    run_metadata = config_pb2.RunMetadata()
    graph = test.mock.MagicMock()
    graph.get_operations.return_value = []

    profiles = pprof_profiler.get_profiles(graph, run_metadata)
    self.assertEqual(0, len(profiles))
    profile_files = pprof_profiler.profile(
        graph, run_metadata, output_dir)
    self.assertEqual(0, len(profile_files))

  def testRunMetadataEmpty(self):
    """A graph with ops but empty `RunMetadata` produces no profiles."""
    output_dir = test.get_temp_dir()
    run_metadata = config_pb2.RunMetadata()
    graph = test.mock.MagicMock()
    op1 = test.mock.MagicMock()
    op1.name = 'Add/123'
    # The profiler reads `traceback_with_start_lines`, whose frames carry the
    # function start line as a fifth element.
    op1.traceback_with_start_lines = [('a/b/file1', 10, 'some_var', '', 0)]
    op1.type = 'add'
    graph.get_operations.return_value = [op1]

    profiles = pprof_profiler.get_profiles(graph, run_metadata)
    self.assertEqual(0, len(profiles))
    profile_files = pprof_profiler.profile(
        graph, run_metadata, output_dir)
    self.assertEqual(0, len(profile_files))

  def testValidProfile(self):
    """One executed node yields one profile, as a proto and as a file."""
    output_dir = test.get_temp_dir()

    node1 = step_stats_pb2.NodeExecStats(
        node_name='Add/123',
        op_start_rel_micros=3,
        op_end_rel_micros=5,
        all_end_rel_micros=4)

    run_metadata = config_pb2.RunMetadata()
    device1 = run_metadata.step_stats.dev_stats.add()
    device1.device = 'deviceA'
    device1.node_stats.extend([node1])

    graph = test.mock.MagicMock()
    op1 = test.mock.MagicMock()
    op1.name = 'Add/123'
    # With this frame order no function named 'apply_op' is ever the callee
    # while walking from the last frame, so the expected profile below
    # contains no locations or functions.
    op1.traceback_with_start_lines = [
        ('a/b/file1', 10, 'apply_op', 'abc', 0),
        ('a/c/file2', 12, 'my_op', 'def', 0)]
    op1.type = 'add'
    graph.get_operations.return_value = [op1]

    expected_proto = """sample_type {
  type: 5
  unit: 5
}
sample_type {
  type: 6
  unit: 7
}
sample_type {
  type: 8
  unit: 7
}
sample {
  value: 1
  value: 4
  value: 2
  label {
    key: 1
    str: 2
  }
  label {
    key: 3
    str: 4
  }
}
string_table: ""
string_table: "node_name"
string_table: "Add/123"
string_table: "op_type"
string_table: "add"
string_table: "count"
string_table: "all_time"
string_table: "nanoseconds"
string_table: "op_time"
string_table: "Device 1 of 1: deviceA"
comment: 9
"""
    # Test with protos
    profiles = pprof_profiler.get_profiles(graph, run_metadata)
    self.assertEqual(1, len(profiles))
    self.assertIn('deviceA', profiles)
    self.assertEqual(expected_proto, str(profiles['deviceA']))
    # Test with files
    profile_files = pprof_profiler.profile(
        graph, run_metadata, output_dir)
    self.assertEqual(1, len(profile_files))
    with gzip.open(profile_files[0]) as profile_file:
      profile_contents = profile_file.read()
      profile = profile_pb2.Profile()
      profile.ParseFromString(profile_contents)
      self.assertEqual(expected_proto, str(profile))

  def testProfileWithWhileLoop(self):
    """Repeated executions of a loop body node are aggregated in one sample."""
    options = config_pb2.RunOptions()
    options.trace_level = config_pb2.RunOptions.FULL_TRACE
    run_metadata = config_pb2.RunMetadata()

    num_iters = 5
    with self.test_session() as sess:
      i = constant_op.constant(0)
      c = lambda i: math_ops.less(i, num_iters)
      b = lambda i: math_ops.add(i, 1)
      r = control_flow_ops.while_loop(c, b, [i])
      sess.run(r, options=options, run_metadata=run_metadata)
      profiles = pprof_profiler.get_profiles(sess.graph, run_metadata)
      self.assertEqual(1, len(profiles))
      profile = next(iter(profiles.values()))
      # Samples for the while/Add node; the first label of a sample holds the
      # node name.
      add_samples = [
          sample for sample in profile.sample
          if profile.string_table[sample.label[0].str] == 'while/Add']
      # Values for same nodes are aggregated.
      self.assertEqual(1, len(add_samples))
      # Value of "count" should be equal to number of iterations.
      self.assertEqual(num_iters, add_samples[0].value[0])


if __name__ == '__main__':
  test.main()