# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Logging tensorflow::tfprof::OpLogProto.

OpLogProto is used to add extra model information for offline analysis.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys

import six
from tensorflow.core.profiler import tfprof_log_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import gfile
from tensorflow.python.profiler.internal import flops_registry  # pylint: disable=unused-import
from tensorflow.python.util.tf_export import tf_export

TRAINABLE_VARIABLES = '_trainable_variables'
REGISTERED_FLOP_STATS = 'flops'


def _fill_missing_graph_shape(graph, run_meta):
  """Fill Tensor shapes in 'graph' with run time shape from 'run_meta'."""
  for dev_stat in run_meta.step_stats.dev_stats:
    for node_stat in dev_stat.node_stats:
      if not node_stat.output:
        continue
      try:
        op = graph.get_operation_by_name(node_stat.node_name)
      except KeyError:
        # The graph doesn't contain this node_stat, usually a RecvTensor op.
        continue
      if len(node_stat.output) != len(op.outputs):
        # For example, a conditional op has only one output at run time.
        continue
      for (i, node_stat_out) in enumerate(node_stat.output):
        if op.outputs[i].get_shape().is_fully_defined():
          continue
        node_stat_dims = node_stat_out.tensor_description.shape.dim
        node_stat_shape = tensor_shape.TensorShape(
            [d.size for d in node_stat_dims])
        try:
          op.outputs[i].set_shape(op.outputs[i].get_shape().merge_with(
              node_stat_shape))
        except ValueError as e:
          sys.stderr.write('Node %s has incompatible shapes: %s.\n' %
                           (node_stat.node_name, e))
  return graph
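
# A hedged usage sketch (illustrative, not executed here): 'run_meta' above is
# typically a RunMetadata proto collected from a traced Session.run() call in
# graph mode; that is what supplies the run-time output shapes merged in by
# _fill_missing_graph_shape(). The fetch below ('train_op') is an assumption.
#
#   import tensorflow as tf
#
#   run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
#   run_metadata = tf.RunMetadata()
#   with tf.Session() as sess:
#     sess.run(train_op, options=run_options, run_metadata=run_metadata)
#   # run_metadata.step_stats now carries per-node output tensor shapes.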


def _str_id(s, str_to_id):
  """Maps string to id."""
  num = str_to_id.get(s, None)
  if num is None:
    num = len(str_to_id)
    str_to_id[s] = num
  return num
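
# Illustrative behavior of the interning helper above (inputs are made up):
#
#   str_to_id = {'none': 0}
#   _str_id('model.py', str_to_id)   # -> 1 (new string gets the next id)
#   _str_id('model.py', str_to_id)   # -> 1 (repeated string reuses its id)
#   _str_id('loss_fn', str_to_id)    # -> 2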


def _get_logged_ops(graph, run_meta=None, add_trace=True,
                    add_trainable_var=True):
  """Extract trainable model parameters and FLOPs for ops from a Graph.

  Args:
    graph: tf.Graph.
    run_meta: RunMetadata proto used to complete shape information.
    add_trace: Whether to add op trace information.
    add_trainable_var: Whether to assign ops in tf.trainable_variables() the
      op type '_trainable_variables'.
  Returns:
    logged_ops: dict mapping from op_name to OpLogEntry.
    string_to_id: dict mapping from string to id.
  """
  if run_meta:
    graph = _fill_missing_graph_shape(graph, run_meta)

  op_missing_shape = 0
  logged_ops = {}
  string_to_id = dict()
  string_to_id['none'] = len(string_to_id)
  # TODO(xpan): Work with Profiler more efficiently.
  for op in graph.get_operations():
    try:
      stats = ops.get_stats_for_node_def(
          graph, op.node_def, REGISTERED_FLOP_STATS)
    except ValueError:
      # The shape is incomplete; no flops stats can be computed. Skip this op.
      op_missing_shape += 1
      stats = None

    entry = tfprof_log_pb2.OpLogEntry()
    entry.name = op.name
    add_entry = False
    if stats and stats.value:
      entry.float_ops = int(stats.value)
      add_entry = True

    if add_trace:
      for tb in op.traceback_with_start_lines:
        trace = entry.code_def.traces.add()
        trace.file_id = _str_id(tb[0], string_to_id) if tb[0] else 0
        trace.lineno = tb[1] if tb[1] else -1
        trace.function_id = _str_id(tb[2], string_to_id) if tb[2] else 0
        trace.line_id = _str_id(tb[3], string_to_id) if tb[3] else 0
        trace.func_start_line = tb[4] if tb[4] else -1
      add_entry = True

    if add_entry:
      logged_ops[entry.name] = entry

  if add_trainable_var:
    for v in graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES):
      if v.op.name not in logged_ops:
        entry = tfprof_log_pb2.OpLogEntry()
        entry.name = v.op.name
        entry.types.append(TRAINABLE_VARIABLES)
        logged_ops[entry.name] = entry
      else:
        logged_ops[v.op.name].types.append(TRAINABLE_VARIABLES)

  if op_missing_shape > 0 and not run_meta:
    sys.stderr.write('%d ops have no flops stats due to incomplete shapes.\n' %
                     op_missing_shape)
  return logged_ops, string_to_id
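
# A rough sketch of the return values above (op names are made up):
#
#   logged_ops, string_to_id = _get_logged_ops(graph, run_meta)
#   # logged_ops maps op names to OpLogEntry protos, e.g.:
#   #   logged_ops['dense/MatMul'].float_ops  # estimated flops, if registered
#   #   logged_ops['dense/kernel'].types      # ['_trainable_variables']
#   # string_to_id interns the file and function names referenced by id in
#   # each OpLogEntry.code_def.traces entry.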


def merge_default_with_oplog(graph, op_log=None, run_meta=None,
                             add_trace=True, add_trainable_var=True):
  """Merge the tfprof default extra info with caller's op_log.

  Args:
    graph: tf.Graph. If None and eager execution is not enabled, use
        default graph.
    op_log: OpLogProto proto.
    run_meta: RunMetadata proto used to complete shape information.
    add_trace: Whether to add op trace information.
    add_trainable_var: Whether to assign ops in tf.trainable_variables() the
      op type '_trainable_variables'.
  Returns:
    tmp_op_log: Merged OpLogProto proto.
  """
  if not graph and not context.executing_eagerly():
    graph = ops.get_default_graph()

  tmp_op_log = tfprof_log_pb2.OpLogProto()
  if not graph:
    return tmp_op_log

  logged_ops, string_to_id = _get_logged_ops(
      graph, run_meta, add_trace=add_trace, add_trainable_var=add_trainable_var)

  if not op_log:
    tmp_op_log.log_entries.extend(logged_ops.values())
  else:
    all_ops = dict()
    for entry in op_log.log_entries:
      all_ops[entry.name] = entry
    for op_name, entry in six.iteritems(logged_ops):
      if op_name in all_ops:
        all_ops[op_name].types.extend(entry.types)
        if entry.float_ops > 0 and all_ops[op_name].float_ops == 0:
          all_ops[op_name].float_ops = entry.float_ops
        if entry.code_def.traces and not all_ops[op_name].code_def.traces:
          all_ops[op_name].code_def.MergeFrom(entry.code_def)
      else:
        all_ops[op_name] = entry
    tmp_op_log.log_entries.extend(all_ops.values())

  for s, i in six.iteritems(string_to_id):
    tmp_op_log.id_to_string[i] = s
  return tmp_op_log
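
# A hedged sketch of supplying a caller-provided op log (names below are
# illustrative). Custom OpLogEntry.types let tfprof group ops under
# user-defined categories alongside the defaults added above.
#
#   op_log = tfprof_log_pb2.OpLogProto()
#   entry = op_log.log_entries.add()
#   entry.name = 'my_model/conv1/Conv2D'   # hypothetical op name
#   entry.types.append('my_conv_layers')   # custom grouping type
#   merged = merge_default_with_oplog(graph, op_log=op_log, run_meta=run_meta)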


@tf_export('profiler.write_op_log')
def write_op_log(graph, log_dir, op_log=None, run_meta=None, add_trace=True):
  """Log provided 'op_log', and add additional model information below.

    The API also assigns ops in tf.trainable_variables() an op type called
    '_trainable_variables'.
    The API also logs 'flops' statistics for ops with op.RegisterStatistics()
    defined. flops calculation depends on Tensor shapes defined in 'graph',
    which might not be complete. 'run_meta', if provided, completes the shape
    information with best effort.

  Args:
    graph: tf.Graph. If None and eager execution is not enabled, use
        default graph.
    log_dir: directory to write the log file.
    op_log: (Optional) OpLogProto proto to be written. If not provided, a new
        one is created.
    run_meta: (Optional) RunMetadata proto that helps flops computation using
        run time shape information.
    add_trace: Whether to add python code trace information.
        Used to support "code" view.
  """
  if not graph and not context.executing_eagerly():
    graph = ops.get_default_graph()
  op_log = merge_default_with_oplog(graph, op_log, run_meta, add_trace)

  with gfile.Open(os.path.join(log_dir, 'tfprof_log'), 'w') as log:
    log.write(op_log.SerializeToString())
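
# A minimal usage sketch for the exported API above, assuming graph-mode
# execution and a writable, pre-existing 'log_dir'; tf.profiler.write_op_log
# is the public name declared by the @tf_export decorator.
#
#   import tensorflow as tf
#
#   with tf.Session() as sess:
#     ...  # build and run the model
#     tf.profiler.write_op_log(sess.graph, log_dir='/tmp/profile')
#   # The call writes '<log_dir>/tfprof_log', which the tfprof tooling can
#   # read alongside the GraphDef to support the 'scope' and 'code' views.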