aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions.py
blob: 9a84f1231cb71745f778285f30ada151a7c1accd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Reaching definition analysis.

This analysis attaches a set of Definition objects to each symbol, one
for each distinct definition that may reach it. The Definition objects are
mutable and may be used by subsequent analyses to further annotate data like
static type and value information.
The analysis also attaches the set of the symbols defined at the entry of
control flow statements.

Requires activity analysis.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gast

from tensorflow.contrib.autograph.pyct import anno
from tensorflow.contrib.autograph.pyct import cfg
from tensorflow.contrib.autograph.pyct import transformer
from tensorflow.contrib.autograph.pyct.static_analysis import annos


class Definition(object):
  """A single, distinct definition of a variable.

  Instances are plain mutable objects, so that later analyses may attach
  further data to them. To have a subclass used instead, hand an appropriate
  factory function to resolve.

  Attributes:
    param_of: Optional[ast.AST], None on construction.
  """

  def __init__(self):
    self.param_of = None

  def __repr__(self):
    # Identify the instance by its concrete class name and its object id.
    kind = type(self).__name__
    return '%s[%d]' % (kind, id(self))


class _NodeState(object):
  """Abstraction for the state of the CFG walk for reaching definition analysis.

  This is a value type. Only implements the strictly necessary operators.

  Attributes:
    value: Dict[qual_names.QN, Set[Definition, ...]], the defined symbols and
        their possible definitions
  """

  def __init__(self, init_from=None):
    if init_from:
      if isinstance(init_from, _NodeState):
        self.value = {
            s: set(other_infos) for s, other_infos in init_from.value.items()
        }
      elif isinstance(init_from, dict):
        self.value = {s: set((init_from[s],)) for s in init_from}
      else:
        assert False, init_from
    else:
      self.value = {}

  def __eq__(self, other):
    if frozenset(self.value.keys()) != frozenset(other.value.keys()):
      return False
    ret = all(self.value[s] == other.value[s] for s in self.value)
    return ret

  def __ne__(self, other):
    return not self.__eq__(other)

  def __or__(self, other):
    assert isinstance(other, _NodeState)
    result = _NodeState(self)
    for s, other_infos in other.value.items():
      if s in result.value:
        result.value[s].update(other_infos)
      else:
        result.value[s] = set(other_infos)
    return result

  def __sub__(self, other):
    assert isinstance(other, set)
    result = _NodeState(self)
    for s in other:
      result.value.pop(s, None)
    return result

  def __repr__(self):
    return 'NodeState[%s]=%s' % (id(self), repr(self.value))


class Analyzer(cfg.GraphVisitor):
  """CFG visitor computing reaching definitions at statement level.

  Attributes:
    extra_in: dict keyed by AST node. When the AST node of a CFG node has an
        entry here, that entry seeds the node's incoming state. This allows
        communicating extra reaching definitions, e.g. those that a function
        closes over.
    gen_map: dict keyed by CFG node, holding the _NodeState with the
        definitions generated by that node. Filled in lazily by visit_node.
  """

  def __init__(self, graph, definition_factory):
    self._definition_factory = definition_factory
    super(Analyzer, self).__init__(graph)
    self.extra_in = {}
    self.gen_map = {}

  def init_state(self, _):
    """Every node starts out with an empty state."""
    return _NodeState()

  def visit_node(self, node):
    """Recomputes the in/out states of a CFG node.

    Args:
      node: the CFG node to process
    Returns:
      bool, whether the node's out state differs from its previous value
    """
    ast_node = node.ast_node
    previous_out = self.out[node]

    # Incoming state: any externally supplied definitions, merged with
    # whatever flows out of each predecessor.
    incoming = _NodeState(self.extra_in.get(ast_node))
    for pred in node.prev:
      incoming = incoming | self.out[pred]

    if not anno.hasanno(ast_node, anno.Static.SCOPE):
      # Without a scope annotation the node is assumed not to touch any
      # symbols. The Name case covers literal names, e.g. False. Reaching
      # this branch with another node type may mean that activity.py forgot
      # to annotate the node with a scope object.
      unscoped_types = (gast.Name, gast.Break, gast.Continue, gast.Raise)
      assert isinstance(ast_node, unscoped_types), (ast_node, node)
      outgoing = incoming
    else:
      scope = anno.getanno(ast_node, anno.Static.SCOPE)
      # Definition objects are compared by identity, so each node must create
      # its definitions exactly once and reuse them on later visits.
      if node not in self.gen_map:
        generated = {}
        for symbol in scope.modified:
          definition = self._definition_factory()
          if symbol in scope.params:
            definition.param_of = scope.params[symbol]
          generated[symbol] = definition
        self.gen_map[node] = _NodeState(generated)

      # out = gen U (in - kill), where the killed symbols are the ones that
      # the node modifies.
      outgoing = self.gen_map[node] | (incoming - scope.modified)

    self.in_[node] = incoming
    self.out[node] = outgoing

    # TODO(mdan): Move this to the superclass?
    return previous_out != outgoing


class TreeAnnotator(transformer.Base):
  """AST visitor that annotates each symbol name with its reaching definitions.

  Simultaneously, the visitor runs the dataflow analysis on each function node,
  accounting for the effect of closures. For example:

    def foo():
      bar = 1
      def baz():
        # bar = 1 reaches here

  Attributes:
    definition_factory: Callable[[], Definition], passed on to each Analyzer
    graphs: Dict[ast.FunctionDef, cfg.Graph]
    current_analyzer: the Analyzer of the function currently being walked, or
        None when outside any function
    current_cfg_node: the CFG node of the innermost enclosing AST node that
        has one, or None
  """

  def __init__(self, source_info, graphs, definition_factory):
    super(TreeAnnotator, self).__init__(source_info)
    self.definition_factory = definition_factory
    self.graphs = graphs
    self.current_analyzer = None
    self.current_cfg_node = None

  def visit_FunctionDef(self, node):
    """Analyzes a function's CFG, then annotates its arguments and body."""
    parent_analyzer = self.current_analyzer
    subgraph = self.graphs[node]

    # Preorder tree processing:
    #  1. if this is a child function, the parent was already analyzed and it
    #     has the proper state value for the subgraph's entry
    #  2. analyze the current function body
    #  3. recursively walk the subtree; child functions will be processed
    analyzer = Analyzer(subgraph, self.definition_factory)
    if parent_analyzer is not None:
      # Wire the state between the two subgraphs' analyzers.
      parent_out_state = parent_analyzer.out[parent_analyzer.graph.index[node]]
      # Exception: symbols modified in the child function are local to it.
      # The subtraction builds a new state; the parent's own state must not
      # be altered.
      body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
      parent_out_state = parent_out_state - body_scope.modified
      analyzer.extra_in[node.args] = parent_out_state

    # Complete the analysis for the local function and annotate its body.
    analyzer.visit_forward()

    # Recursively process any remaining subfunctions.
    self.current_analyzer = analyzer
    # Note: not visiting name, decorator_list and returns because they don't
    # apply to this analysis.
    # TODO(mdan): Should we still process the function name?
    node.args = self.visit(node.args)
    node.body = self.visit_block(node.body)
    self.current_analyzer = parent_analyzer

    return node

  # NOTE: the visitor methods in this class are named after the node class
  # (visit_Name, visit_If, ...). The lowercase spellings used previously for
  # the two methods below do not match any node class, so - assuming
  # transformer.Base dispatches on the class name like ast.NodeVisitor does -
  # they were never invoked and nonlocal/global statements went unnoticed.
  def visit_Nonlocal(self, node):
    raise NotImplementedError('nonlocal statements are not supported')

  def visit_Global(self, node):
    raise NotImplementedError('global statements are not supported')

  # The previous spellings are kept as aliases for backward compatibility.
  visit_nonlocal = visit_Nonlocal
  visit_global = visit_Global

  def visit_Name(self, node):
    """Annotates a Name node with the definitions of its symbol.

    Names that are read receive the definitions that flow into the enclosing
    statement; all other names receive the definitions that flow out of it.
    """
    if self.current_analyzer is None:
      # Names may appear outside function defs - for example in class
      # definitions.
      return node

    analyzer = self.current_analyzer
    cfg_node = self.current_cfg_node

    assert cfg_node is not None, 'name node outside of any statement?'

    qn = anno.getanno(node, anno.Basic.QN)
    if isinstance(node.ctx, gast.Load):
      state = analyzer.in_[cfg_node]
    else:
      state = analyzer.out[cfg_node]
    anno.setanno(node, anno.Static.DEFINITIONS, tuple(state.value.get(qn, ())))

    return node

  def _aggregate_predecessors_defined_in(self, node):
    """Annotates a statement with the symbols defined by its predecessors.

    The DEFINED_VARS_IN annotation is set to the frozenset of all symbols that
    appear in the out state of any node in graph.stmt_prev[node].
    """
    preds = self.current_analyzer.graph.stmt_prev[node]
    node_defined_in = set()
    for p in preds:
      node_defined_in.update(self.current_analyzer.out[p].value)
    anno.setanno(node, anno.Static.DEFINED_VARS_IN, frozenset(node_defined_in))

  def visit_If(self, node):
    self._aggregate_predecessors_defined_in(node)
    return self.generic_visit(node)

  def visit_For(self, node):
    self._aggregate_predecessors_defined_in(node)

    # Manually accounting for the shortcoming described in
    # cfg.AstToCfg.visit_For: the target is visited as if it belonged to the
    # CFG node of the iterated expression.
    parent = self.current_cfg_node
    self.current_cfg_node = self.current_analyzer.graph.index[node.iter]
    node.target = self.visit(node.target)
    self.current_cfg_node = parent

    node.iter = self.visit(node.iter)
    node.body = self.visit_block(node.body)
    node.orelse = self.visit_block(node.orelse)

    return node

  def visit_While(self, node):
    self._aggregate_predecessors_defined_in(node)
    return self.generic_visit(node)

  def visit(self, node):
    """Visits a node, tracking which CFG node encloses the current position."""
    parent = self.current_cfg_node

    # AST nodes that have their own CFG node become the current CFG node for
    # everything underneath them; the previous one is restored afterwards.
    if (self.current_analyzer is not None and
        node in self.current_analyzer.graph.index):
      self.current_cfg_node = self.current_analyzer.graph.index[node]
    node = super(TreeAnnotator, self).visit(node)

    self.current_cfg_node = parent
    return node


def resolve(node, source_info, graphs, definition_factory):
  """Annotates the symbols under a node with their reaching definitions.

  Args:
    node: ast.AST
    source_info: transformer.SourceInfo
    graphs: Dict[ast.FunctionDef, cfg.Graph]
    definition_factory: Callable[[], Definition]
  Returns:
    ast.AST, the result of running a TreeAnnotator over node
  """
  return TreeAnnotator(source_info, graphs, definition_factory).visit(node)