aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/autograph/core/converter.py
blob: 7b3905fdeed4ba6e0a26029f81a32f4374ef29c0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converter construction support.

This module contains a base class for all converters, as well as supporting
structures. These structures are referred to as contexts.

The class hierarchy is as follows:

    <your converter>
      [extends] converter.Base
        [extends] transformer.Base
            [extends] gast.NodeTransformer
          [uses] transformer.SourceInfo
        [uses] converter.EntityContext
          [uses] converter.ProgramContext
          [uses] transformer.SourceInfo

converter.Base is a specialization of transformer.Base for AutoGraph. It's a
very lightweight subclass that adds a `ctx` attribute holding the corresponding
EntityContext object (see below). Note that converters are not reusable, and
`visit` will raise an error if called more than once.

converter.EntityContext contains mutable state associated with an entity that
the converter processes.

converter.ProgramContext contains mutable state across related entities. For
example, when converting several functions that call one another, the
ProgramContext should be shared across these entities.

Below is the overall flow of a conversion:

    program_ctx = ProgramContext(<entities to convert>, <global settings>, ...)
    while <program_ctx has more entities to convert>:
      entity, source_info = <get next entity from program_ctx>
      entity_ctx = EntityContext(program_ctx, source_info)
      for <each ConverterClass>:
        converter = ConverterClass(entity_ctx)

        # May update entity_ctx and program_ctx
        entity = converter.visit(entity)

      <add entity's dependencies to program_ctx>

Note that pyct contains a small number of transformers used for static analysis.
These implement transformer.Base, rather than converter.Base, to avoid a
dependency on AutoGraph.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
from enum import Enum


from tensorflow.python.autograph.core import config
from tensorflow.python.autograph.core import naming
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import ast_util
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import compiler
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis import live_values
from tensorflow.python.autograph.pyct.static_analysis import liveness
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
from tensorflow.python.autograph.pyct.static_analysis import type_info

# TODO(mdan): These contexts can be refactored into first class objects.
# For example, we could define Program and Entity abstractions that hold on
# to the actual entity and have conversion methods.

# TODO(mdan): Add a test specific to this converter.


class ProgramContext(object):
  """ProgramContext keeps track of converting function hierarchies.

  This object is mutable, and is updated during conversion. Not thread safe.

  Attributes:
    recursive: bool, whether to recursively convert any functions that the
        decorator function may call.
    autograph_decorators: Tuple[Callable, ...], decorator functions that belong
        to AutoGraph. These require special treatment.
    partial_types: the value passed to the constructor, or an empty tuple when
        that value is falsy. Forwarded to every namer created by `new_namer`.
    dependency_cache: Dict[Any, ast.AST], the original entities mapped to their
        converted AST
    additional_imports: Set[Any], additional entities which for any reason
        cannot be attached after loading and need to be explicitly imported
        in the generated code
    name_map: Dict[str, str], map of original entity name to the name of
        their converted counterparts
    autograph_module: Module, a reference to the autograph module. This
        needs to be specified by the caller to avoid circular dependencies.
    uncompiled_modules: Set[Tuple[str, ...]], with each tuple representing the
        fully qualified name of a package containing functions that will not be
        compiled.
    required_imports: str, containing an import statement on each line. These
        are all the imports necessary for the compiled code to run, in addition
        to the closures of each entity, which are attached dynamically.
  """

  def __init__(
      self,
      recursive,
      autograph_decorators,
      partial_types,
      autograph_module,
      uncompiled_modules,
  ):
    self.recursive = recursive
    self.autograph_decorators = autograph_decorators
    # Normalize any falsy value (e.g. None) to an empty tuple.
    self.partial_types = partial_types if partial_types else ()
    self.autograph_module = autograph_module
    self.uncompiled_modules = uncompiled_modules

    # Required to output dependencies in discovery order, which should match
    # the reverse dependency order.
    self.dependency_cache = collections.OrderedDict()
    self.additional_imports = set()
    self.name_map = {}

  @property
  def required_imports(self):
    """Returns a block containing all imports required by the converted code.

    The statements from config.COMPILED_IMPORT_STATEMENTS come first, followed
    by the entries of `additional_imports` in sorted order, one per line.
    """
    # TODO(mdan): Check that these don't clobber one another.
    # additional_imports is a set, so its iteration order is arbitrary and may
    # change between runs. Sorting keeps the generated code reproducible.
    return '\n'.join(config.COMPILED_IMPORT_STATEMENTS +
                     tuple(sorted(self.additional_imports)))

  def new_namer(self, namespace):
    """Creates a naming.Namer for `namespace` that shares this name map."""
    return naming.Namer(namespace, self.recursive, self.name_map,
                        self.partial_types)

  def update_name_map(self, namer):
    """Updates name_map based on the recent activity from the namer.

    Whenever we convert a new entity, any references to other entities are being
    renamed to match their soon-to-be-converted counterparts. The namer keeps
    track of these renames. When conversion is complete, we copy those renames
    so that when those referenced entities are being converted, their new name
    matches.

    Args:
      namer: naming.Namer

    Raises:
      ValueError: when an entity was renamed twice and to different names.
    """
    # TODO(mdan): Have call_trees do this directly.
    # This is done so indirectly, via the namer, for historic reasons. But
    # now we can have the converter that does the rename record the new name
    # as well and skip this step altogether.
    for o, name in namer.renamed_calls.items():
      if o not in self.name_map:
        self.name_map[o] = name
      elif self.name_map[o] != name:
        raise ValueError(
            'Calls to %s were converted using multiple names (%s). This is '
            'possible when an entity with one of these names already '
            'existed. To fix, avoid using any of these names.' %
            (o, (name, self.name_map[o])))

  def add_to_cache(self, original_entity, converted_ast):
    """Records `converted_ast` as the conversion result of `original_entity`."""
    self.dependency_cache[original_entity] = converted_ast


class EntityContext(object):
  """Holds the state used while converting one entity.

  Instances are mutable and get updated as conversion proceeds. Not thread
  safe.

  Attributes:
    namer: Namer
    info: transformer.EntityInfo
    program: ProgramContext
  """

  def __init__(self, namer, entity_info, program_ctx):
    # Plain attribute holder: each argument is stored as given.
    self.program = program_ctx
    self.info = entity_info
    self.namer = namer


class Base(transformer.Base):
  """All converters should inherit from this class.

  A converter instance is single-use: a second top-level call to `visit`
  raises ValueError.

  Attributes:
    ctx: EntityContext
  """

  def __init__(self, ctx):
    super(Base, self).__init__(ctx.info)
    self.ctx = ctx  # Keeping this short because it's used frequently.

    # Set once the first top-level visit starts; guards against reuse.
    self._used = False
    # Nesting depth of in-flight visit() calls; 0 means no visit is in flight.
    self._ast_depth = 0

  def get_definition_directive(self, node, directive, arg, default):
    """Returns the unique directive for a symbol, or a default if none exist.

    See lang/directives.py for details on directives.

    Several definitions of the symbol may carry the directive. That is fine as
    long as all of them specify matching values for `arg`; in that case the
    first value found is returned.

    Args:
      node: ast.AST
      directive: Callable[..., Any]
      arg: str
      default: Any

    Returns:
      The value recorded for `arg` under `directive` on the definitions of
      `node`, or `default` when no definition carries it.

    Raises:
      ValueError: if conflicting annotations have been found
    """
    defs = anno.getanno(node, anno.Static.ORIG_DEFINITIONS, ())
    if not defs:
      return default

    # TODO(mdan): Simplify this.
    arg_values = []
    for def_ in defs:
      if (directive not in def_.directives or
          arg not in def_.directives[directive]):
        continue
      arg_value = def_.directives[directive][arg]
      # Every new value must agree with all the values seen so far.
      for prev_value in arg_values:
        if not ast_util.matches(arg_value, prev_value):
          qn = anno.getanno(node, anno.Basic.QN)
          raise ValueError('%s has ambiguous annotations for %s(%s): %s, %s' %
                           (qn, directive.__name__, arg,
                            compiler.ast_to_source(arg_value).strip(),
                            compiler.ast_to_source(prev_value).strip()))
      arg_values.append(arg_value)

    if not arg_values:
      return default

    # All collected values were checked pairwise above, so any of them is
    # representative. Unpacking a single element here would fail with
    # "too many values to unpack" when several definitions agree.
    return arg_values[0]

  def visit(self, node):
    """Visits `node`, refusing a second top-level traversal.

    Args:
      node: ast.AST

    Returns:
      Whatever the parent class' `visit` returns for `node`.

    Raises:
      ValueError: if a top-level visit is started on an already used converter.
    """
    # Only top-level calls count as a use; nested calls made while a visit is
    # in flight have a non-zero depth and skip the check.
    if not self._ast_depth:
      if self._used:
        raise ValueError('converter objects cannot be reused')
      self._used = True

    self._ast_depth += 1
    try:
      return super(Base, self).visit(node)
    finally:
      # Restore the depth even when the visit raises.
      self._ast_depth -= 1


class AnnotatedDef(reaching_definitions.Definition):
  """Definition that additionally carries a `directives` dictionary.

  standard_analysis passes this class to reaching_definitions.resolve, and
  Base.get_definition_directive reads `directives[directive][arg]` from the
  resulting definitions.
  """

  def __init__(self):
    super(AnnotatedDef, self).__init__()
    # Starts empty; indexed as directives[directive][arg] by readers.
    self.directives = {}


class AgAnno(Enum):
  """AutoGraph-specific annotation labels. See anno.py.

  Each member's value is a human-readable description of the label.
  """

  DIRECTIVES = 'User directives associated with the annotated statement.'

  def __repr__(self):
    """Returns just the member name, e.g. `DIRECTIVES`."""
    member_name = self.name
    return member_name


def standard_analysis(node, context, is_initial=False):
  """Runs the full set of static analysis passes over the given code.

  The passes run in a fixed order: qualified names, activity, reaching
  definitions, liveness, live values, type info, and live values once more.

  Args:
    node: ast.AST
    context: converter.EntityContext
    is_initial: bool, whether this is the initial analysis done on the input
        source code

  Returns:
    ast.AST, same as node, with the static analysis annotations added
  """
  # TODO(mdan): Clear static analysis here.
  # TODO(mdan): Consider not running all analyses every time.
  # TODO(mdan): Don't return a node because it's modified by reference.
  entity_info = context.info

  # The control flow graphs are built from the node as passed in, before any
  # of the resolvers run.
  graphs = cfg.build(node)

  node = qual_names.resolve(node)
  node = activity.resolve(node, entity_info, None)
  node = reaching_definitions.resolve(node, entity_info, graphs, AnnotatedDef)
  node = liveness.resolve(node, entity_info, graphs)
  node = live_values.resolve(node, entity_info, config.PYTHON_LITERALS)
  node = type_info.resolve(node, entity_info)
  # This second call allows resolving first-order class attributes.
  node = live_values.resolve(node, entity_info, config.PYTHON_LITERALS)

  if not is_initial:
    return node

  # On the initial pass, duplicate the definitions annotation under the
  # ORIG_DEFINITIONS key.
  anno.dup(node, {anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS})
  return node


def apply_(node, context, converter_module):
  """Runs the standard analysis on an AST, then applies a converter to it.

  Args:
    node: ast.AST
    context: converter.EntityContext
    converter_module: object exposing `transform(node, context)`; called with
        the analyzed node and the same context

  Returns:
    ast.AST, the result of applying converter to node
  """
  # Analysis runs first so the converter sees up-to-date annotations.
  analyzed = standard_analysis(node, context)
  return converter_module.transform(analyzed, context)