aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph/pyct/common_transformers/anf.py
blob: d77c15915bb1a69aaf7be854c9b433951a359151 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Conversion to A-normal form.

The general idea of A-normal form is that every intermediate value is
explicitly named with a variable.  For more, see
https://en.wikipedia.org/wiki/A-normal_form.

The specific converters used here are based on Python AST semantics as
documented at https://greentreesnakes.readthedocs.io/en/latest/.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gast
import six

from tensorflow.contrib.autograph.pyct import templates
from tensorflow.contrib.autograph.pyct import transformer


class DummyGensym(object):
  """Naive name generator that appends a running counter to a stem.

  The first generated name is `<stem>_1001`, the second `<stem>_1002`, and so
  on.  The counter is shared by all stems requested from one instance.
  """

  def __init__(self, entity_info):
    # `entity_info` is accepted for interface compatibility but is not used.
    # A proper implementation needs to account for:
    #   * entity_info.namespace
    #   * all the symbols defined in the AST
    #   * the symbols generated so far
    del entity_info
    self._idx = 0

  def new_name(self, stem='tmp'):
    """Returns `stem` followed by an underscore and the next counter value."""
    self._idx = self._idx + 1
    suffix = str(1000 + self._idx)
    return stem + '_' + suffix


class AnfTransformer(transformer.Base):
  """Performs the conversion to A-normal form (ANF).

  Nontrivial sub-expressions are replaced by freshly named temporaries, and the
  assignments defining those temporaries are emitted immediately before the
  statement that uses them.  Constructs whose evaluation order or laziness
  cannot be preserved by such hoisting (e.g. comprehensions, nontrivial
  `while` tests, nontrivial boolean operations) raise `ValueError`.
  """

  # The algorithm is a postorder recursive tree walk.  Any given node A may, in
  # general, require creation of a series B of Assign statements, which compute
  # and explicitly name the intermediate values needed to compute the value of
  # A.  If A was already a statement, it can be replaced with the sequence B +
  # [A].  If A was an expression, B needs to be propagated up the tree until a
  # statement is encountered.  Since the `ast.NodeTransformer` framework makes
  # no provision for subtraversals returning side information, this class
  # accumulates the sequence B in an instance variable.

  # The only other subtlety is that some Python statements (like `if`) have both
  # expression fields (`test`) and statement list fields (`body` and `orelse`).
  # Any additional assignments needed to name all the intermediate values in the
  # `test` can be prepended to the `if` node, but assignments produced by
  # processing the `body` and the `orelse` need to be kept together with them,
  # and not accidentally lifted out of the `if`.

  def __init__(self, entity_info, gensym_source=None):
    """Creates an ANF transformer.

    Args:
      entity_info: transformer.EntityInfo
      gensym_source: An optional callable (for example a class such as
        `DummyGensym`) that is invoked as `gensym_source(entity_info)` and
        returns an object whose `new_name()` method generates unique names.
        When omitted, `DummyGensym` is used.
    """
    super(AnfTransformer, self).__init__(entity_info)
    if gensym_source is None:
      self._gensym = DummyGensym(entity_info)
    else:
      self._gensym = gensym_source(entity_info)
    # Assignments created while processing expressions, waiting to be emitted
    # in front of the enclosing statement.
    self._pending_statements = []

  def _consume_pending_statements(self):
    """Returns the accumulated pending statements and resets the buffer."""
    ans = self._pending_statements
    self._pending_statements = []
    return ans

  def _add_pending_statement(self, stmt):
    """Queues `stmt` for emission before the enclosing statement."""
    self._pending_statements.append(stmt)

  _trivial_nodes = (
      # Non-nodes that show up as AST fields
      bool, six.string_types,
      # Leaf nodes that are already in A-normal form
      gast.expr_context, gast.Name, gast.Num, gast.Str, gast.Bytes,
      gast.NameConstant, gast.Ellipsis,
      # Binary operators
      gast.Add, gast.Sub, gast.Mult, gast.Div, gast.Mod, gast.Pow, gast.LShift,
      gast.RShift, gast.BitOr, gast.BitXor, gast.BitAnd, gast.FloorDiv,
      # Unary operators
      gast.Invert, gast.Not, gast.UAdd, gast.USub,
      # Comparison operators
      gast.Eq, gast.NotEq, gast.Lt, gast.LtE, gast.Gt, gast.GtE,
      gast.Is, gast.IsNot, gast.In, gast.NotIn,
  )

  def _is_node_trivial(self, node):
    """Returns True iff `node` needs no temporaries to be in A-normal form.

    Args:
      node: An AST node, a list of AST nodes (the value of a list-valued
        field), a non-node field value such as a string, or None.
    """
    if node is None:
      return True
    elif isinstance(node, self._trivial_nodes):
      return True
    elif isinstance(node, list):
      # List-valued fields (e.g. `BoolOp.values`) are trivial exactly when
      # every element is.  Without this branch any node carrying a list field
      # would be reported as nontrivial regardless of its contents.
      return all(self._is_node_trivial(n) for n in node)
    elif isinstance(node, gast.keyword):
      return self._is_node_trivial(node.value)
    elif isinstance(node, (gast.Starred, gast.withitem, gast.slice)):
      return self._are_children_trivial(node)
    return False

  def _are_children_trivial(self, node):
    """Returns True iff every non-dunder field of `node` is trivial."""
    for field in node._fields:
      if not field.startswith('__'):
        if not self._is_node_trivial(getattr(node, field)):
          return False
    return True

  def _ensure_node_is_trivial(self, node):
    """Returns a trivial stand-in for `node`, queueing assignments as needed.

    Args:
      node: An AST node, a list of AST nodes, a non-node field value, or None.

    Returns:
      `node` itself when it is already trivial or is a structural wrapper
      (keyword, Starred, withitem, slice) whose fields were trivialized in
      place; a new list for list input; otherwise a reference to a fresh
      temporary whose defining assignment was added to the pending statements.

    Raises:
      ValueError: If `node` is of a kind this method does not handle.
    """
    if node is None:
      return node
    elif isinstance(node, self._trivial_nodes):
      return node
    elif isinstance(node, list):
      # If something's field was actually a list, e.g., variadic arguments.
      return [self._ensure_node_is_trivial(n) for n in node]
    elif isinstance(node, gast.keyword):
      node.value = self._ensure_node_is_trivial(node.value)
      return node
    elif isinstance(node, (gast.Starred, gast.withitem, gast.slice)):
      return self._ensure_fields_trivial(node)
    elif isinstance(node, gast.expr):
      temp_name = self._gensym.new_name()
      temp_assign = templates.replace(
          'temp_name = expr', temp_name=temp_name, expr=node)[0]
      self._add_pending_statement(temp_assign)
      answer = templates.replace('temp_name', temp_name=temp_name)[0]
      return answer
    else:
      raise ValueError('Do not know how to treat {}'.format(node))

  def _ensure_fields_trivial(self, node):
    """Trivializes every non-dunder field of `node` in place; returns `node`."""
    for field in node._fields:
      if field.startswith('__'):
        continue
      setattr(node, field, self._ensure_node_is_trivial(getattr(node, field)))
    return node

  def _visit_strict_statement(self, node, trivialize_children=True):
    """Processes a statement whose sub-expressions are all evaluated eagerly.

    Args:
      node: The statement node.
      trivialize_children: Whether the statement's own fields must be trivial
        after the visit.  When False, only the pending statements produced by
        visiting the children are emitted.

    Returns:
      A list holding the generated assignments followed by `node`.
    """
    assert not self._pending_statements
    node = self.generic_visit(node)
    if trivialize_children:
      self._ensure_fields_trivial(node)
    results = self._consume_pending_statements()
    results.append(node)
    return results

  def _visit_strict_expression(self, node):
    """Visits an eagerly evaluated expression and trivializes its fields."""
    node = self.generic_visit(node)
    self._ensure_fields_trivial(node)
    return node

  # Note on code order: These are listed in the same order as the grammar
  # elements on https://github.com/serge-sans-paille/gast

  # FunctionDef, AsyncFunctionDef, and ClassDef should be correct by default.

  def visit_Return(self, node):
    return self._visit_strict_statement(node)

  def visit_Delete(self, node):
    return self._visit_strict_statement(node, trivialize_children=False)

  def visit_Assign(self, node):
    return self._visit_strict_statement(node, trivialize_children=False)

  def visit_AugAssign(self, node):
    return self._visit_strict_statement(node, trivialize_children=False)

  def visit_Print(self, node):
    return self._visit_strict_statement(node)

  def visit_For(self, node):
    assert not self._pending_statements
    # It's important to visit node.iter first, because any statements created
    # thereby need to live outside the body.  The visit result is stored so
    # that a visitor returning a replacement node is honored.
    node.iter = self.visit(node.iter)
    node.iter = self._ensure_node_is_trivial(node.iter)
    iter_stmts = self._consume_pending_statements()
    # This generic_visit will revisit node.iter, but that is both correct and
    # cheap because by this point node.iter is trivial.
    node = self.generic_visit(node)
    assert not self._pending_statements
    iter_stmts.append(node)
    return iter_stmts

  def visit_AsyncFor(self, node):
    if not self._are_children_trivial(node):
      msg = ('Nontrivial AsyncFor nodes not supported yet '
             '(need to think through the semantics).')
      raise ValueError(msg)
    return self.generic_visit(node)

  def visit_While(self, node):
    if not self._is_node_trivial(node.test):
      msg = ('While with nontrivial test not supported yet '
             '(need to avoid precomputing the test).')
      raise ValueError(msg)
    return self.generic_visit(node)

  def visit_If(self, node):
    assert not self._pending_statements
    # It's important to visit node.test first, because any statements created
    # thereby need to live outside the body.  The visit result is stored so
    # that a visitor returning a replacement node is honored.
    node.test = self.visit(node.test)
    node.test = self._ensure_node_is_trivial(node.test)
    condition_stmts = self._consume_pending_statements()
    # This generic_visit will revisit node.test, but that is both correct and
    # cheap because by this point node.test is trivial.
    node = self.generic_visit(node)
    assert not self._pending_statements
    condition_stmts.append(node)
    return condition_stmts

  def visit_With(self, node):
    assert not self._pending_statements
    # It's important to visit node.items first, because any statements created
    # thereby need to live outside the body.  All items are visited before any
    # is trivialized, which fixes the order of the generated assignments.
    node.items = [self.visit(item) for item in node.items]
    node.items = [self._ensure_node_is_trivial(n) for n in node.items]
    contexts_stmts = self._consume_pending_statements()
    # This generic_visit will revisit node.items, but that is both correct and
    # cheap because by this point node.items is trivial.
    node = self.generic_visit(node)
    assert not self._pending_statements
    contexts_stmts.append(node)
    return contexts_stmts

  def visit_AsyncWith(self, node):
    if not self._are_children_trivial(node):
      msg = ('Nontrivial AsyncWith nodes not supported yet '
             '(need to think through the semantics).')
      raise ValueError(msg)
    return self.generic_visit(node)

  def visit_Raise(self, node):
    return self._visit_strict_statement(node)

  # Try should be correct by default.

  def visit_Assert(self, node):
    if not self._are_children_trivial(node):
      msg = ('Nontrivial Assert nodes not supported yet '
             '(need to avoid computing the test when assertions are off, and '
             'avoid computing the irritant when the assertion does not fire).')
      raise ValueError(msg)
    return self.generic_visit(node)

  # Import and ImportFrom should be correct by default.

  def visit_Exec(self, node):
    return self._visit_strict_statement(node)

  # Global and Nonlocal should be correct by default.

  def visit_Expr(self, node):
    return self._visit_strict_statement(node, trivialize_children=False)

  # Pass, Break, and Continue should be correct by default.

  def visit_BoolOp(self, node):
    # Operands may not be hoisted, because that would evaluate them eagerly;
    # a BoolOp whose operands are all trivial needs no hoisting and is kept.
    if not self._are_children_trivial(node):
      msg = ('Nontrivial BoolOp nodes not supported yet '
             '(need to preserve short-circuiting semantics).')
      raise ValueError(msg)
    return self.generic_visit(node)

  def visit_BinOp(self, node):
    return self._visit_strict_expression(node)

  def visit_UnaryOp(self, node):
    return self._visit_strict_expression(node)

  def visit_Lambda(self, node):
    if not self._are_children_trivial(node):
      msg = ('Nontrivial Lambda nodes not supported '
             '(cannot insert statements into lambda bodies).')
      raise ValueError(msg)
    return self.generic_visit(node)

  def visit_IfExp(self, node):
    if not self._are_children_trivial(node):
      msg = ('Nontrivial IfExp nodes not supported yet '
             '(need to convert to If statement, to evaluate branches lazily '
             'and insert statements into them).')
      raise ValueError(msg)
    return self.generic_visit(node)

  def visit_Dict(self, node):
    return self._visit_strict_expression(node)

  def visit_Set(self, node):
    return self._visit_strict_expression(node)

  def visit_ListComp(self, node):
    msg = ('ListComp nodes not supported '
           '(need to convert to a form that tolerates '
           'assignment statements in clause bodies).')
    raise ValueError(msg)

  def visit_SetComp(self, node):
    msg = ('SetComp nodes not supported '
           '(need to convert to a form that tolerates '
           'assignment statements in clause bodies).')
    raise ValueError(msg)

  def visit_DictComp(self, node):
    msg = ('DictComp nodes not supported '
           '(need to convert to a form that tolerates '
           'assignment statements in clause bodies).')
    raise ValueError(msg)

  def visit_GeneratorExp(self, node):
    msg = ('GeneratorExp nodes not supported '
           '(need to convert to a form that tolerates '
           'assignment statements in clause bodies).')
    raise ValueError(msg)

  def visit_Await(self, node):
    if not self._are_children_trivial(node):
      msg = ('Nontrivial Await nodes not supported yet '
             '(need to think through the semantics).')
      raise ValueError(msg)
    return self.generic_visit(node)

  def visit_Yield(self, node):
    return self._visit_strict_expression(node)

  def visit_YieldFrom(self, node):
    if not self._are_children_trivial(node):
      msg = ('Nontrivial YieldFrom nodes not supported yet '
             '(need to unit-test them in Python 2).')
      raise ValueError(msg)
    return self.generic_visit(node)

  def visit_Compare(self, node):
    if len(node.ops) > 1:
      msg = ('Multi-ary compare nodes not supported yet '
             '(need to preserve short-circuiting semantics).')
      raise ValueError(msg)
    return self._visit_strict_expression(node)

  def visit_Call(self, node):
    return self._visit_strict_expression(node)

  def visit_Repr(self, node):
    if not self._are_children_trivial(node):
      msg = ('Nontrivial Repr nodes not supported yet '
             '(need to research their syntax and semantics).')
      raise ValueError(msg)
    return self.generic_visit(node)

  def visit_FormattedValue(self, node):
    if not self._are_children_trivial(node):
      msg = ('Nontrivial FormattedValue nodes not supported yet '
             '(need to unit-test them in Python 2).')
      raise ValueError(msg)
    return self.generic_visit(node)

  def visit_JoinedStr(self, node):
    if not self._are_children_trivial(node):
      msg = ('Nontrivial JoinedStr nodes not supported yet '
             '(need to unit-test them in Python 2).')
      raise ValueError(msg)
    return self.generic_visit(node)

  def visit_Attribute(self, node):
    return self._visit_strict_expression(node)

  def visit_Subscript(self, node):
    return self._visit_strict_expression(node)

  # Starred and Name are correct by default, because the right thing to do is to
  # just recur.

  def visit_List(self, node):
    node = self.generic_visit(node)
    # In a Store context the elements are assignment targets, so they must not
    # be replaced by temporaries.
    if not isinstance(node.ctx, gast.Store):
      self._ensure_fields_trivial(node)
    return node

  def visit_Tuple(self, node):
    node = self.generic_visit(node)
    # In a Store context the elements are assignment targets, so they must not
    # be replaced by temporaries.
    if not isinstance(node.ctx, gast.Store):
      self._ensure_fields_trivial(node)
    return node


def transform(node, entity_info, gensym_source=None):
  """Rewrites the given node into A-normal form (ANF).

  In A-normal form every intermediate value is bound to an explicit variable;
  see https://en.wikipedia.org/wiki/A-normal_form.

  The rewrite rules follow the Python AST semantics documented at
  https://greentreesnakes.readthedocs.io/en/latest/.

  Args:
    node: The node to transform.
    entity_info: transformer.EntityInfo.  TODO(mdan): What information does this
      argument provide?
    gensym_source: Optional; forwarded unchanged to `AnfTransformer`, which
      uses it to obtain the generator of unique names.

  Returns:
    The result of visiting `node` with a newly constructed `AnfTransformer`.
  """
  anf_converter = AnfTransformer(entity_info, gensym_source=gensym_source)
  return anf_converter.visit(node)