aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/autograph/converters/return_statements.py
blob: 496c99e3b5247c174f8a74e9b3f23517ddc649f3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Canonicalizes functions with multiple returns to use just one."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gast

from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import ast_util
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno


# TODO(mdan): Move this logic into transformer_base.
class BodyVisitor(converter.Base):
  """Walks breadth- or depth-first the list-of-nodes bodies of AST nodes.

  For the compound statements handled here (If, For, While, Try, With,
  FunctionDef), each statement-list field is passed to `visit_nodelist`, which
  subclasses override to rewrite whole statement lists at once. The node's
  children are additionally visited once with `generic_visit`: before the
  statement lists are processed when `depth_first` is True, after otherwise.
  """

  def __init__(self, ctx, depth_first=False):
    super(BodyVisitor, self).__init__(ctx)
    self.depth_first = depth_first
    # Subclasses set this to True whenever they modify the AST.
    self.changes_made = False

  def visit_nodelist(self, nodelist):
    """Visits every node (or nested node list) in `nodelist`.

    Visit results are written back into `nodelist` in place, so a replacement
    node returned by a visit is not silently dropped.

    Args:
      nodelist: a list of AST nodes, possibly containing nested lists.

    Returns:
      `nodelist` itself.
    """
    for i, node in enumerate(nodelist):
      if isinstance(node, list):
        nodelist[i] = self.visit_nodelist(node)
      else:
        nodelist[i] = self.generic_visit(node)
    return nodelist

  def _visit_bodies(self, node, field_names):
    """Applies `visit_nodelist` to the named statement-list fields of `node`.

    `generic_visit` runs exactly once: before the lists are processed when
    `depth_first` is set, after them otherwise.
    """
    if self.depth_first:
      node = self.generic_visit(node)
    for name in field_names:
      setattr(node, name, self.visit_nodelist(getattr(node, name)))
    if not self.depth_first:
      node = self.generic_visit(node)
    return node

  def visit_If(self, node):
    return self._visit_bodies(node, ('body', 'orelse'))

  def visit_For(self, node):
    return self._visit_bodies(node, ('body', 'orelse'))

  def visit_While(self, node):
    return self._visit_bodies(node, ('body', 'orelse'))

  def visit_Try(self, node):
    # Handler bodies live on the handler nodes, so Try is spelled out here
    # instead of going through _visit_bodies.
    if self.depth_first:
      node = self.generic_visit(node)
    node.body = self.visit_nodelist(node.body)
    node.orelse = self.visit_nodelist(node.orelse)
    node.finalbody = self.visit_nodelist(node.finalbody)
    for handler in node.handlers:
      handler.body = self.visit_nodelist(handler.body)
    if not self.depth_first:
      node = self.generic_visit(node)
    return node

  def visit_With(self, node):
    return self._visit_bodies(node, ('body',))

  def visit_FunctionDef(self, node):
    # Children are traversed once; an extra unconditional generic_visit would
    # walk the whole function body a second time.
    return self._visit_bodies(node, ('body',))


class FoldElse(BodyVisitor):
  """Moves statements that follow a partially-returning `if` into its branches.

  When exactly one branch of an `if` ends in `return`, every statement after
  the `if` in the same statement list can only execute on the other branch.
  Cleaned copies of those statements are appended to that other branch, and
  the enclosing list is truncated right after the `if`. `changes_made` is set
  whenever statements are actually moved.
  """

  def visit_nodelist(self, nodelist):
    """Folds the trailing statements of `nodelist` into a returning `if`.

    Args:
      nodelist: a list of statement nodes.

    Returns:
      The statement list, truncated after the first `if` that has exactly one
      returning branch; otherwise `nodelist` unchanged.

    Raises:
      ValueError: if statements follow a `return`, or follow an `if` whose
        branches both return.
    """
    for pos, stmt in enumerate(nodelist):
      if isinstance(stmt, gast.Return):
        if nodelist[pos + 1:]:
          raise ValueError(
              'Cannot have statements after a return in the same basic block')
        continue
      if not isinstance(stmt, gast.If):
        continue

      body_returns = isinstance(stmt.body[-1], gast.Return)
      orelse_returns = bool(stmt.orelse) and isinstance(
          stmt.orelse[-1], gast.Return)
      if not (body_returns or orelse_returns):
        continue

      trailing = nodelist[pos + 1:]
      if body_returns and orelse_returns:
        if trailing:
          raise ValueError(
              'Unreachable code after conditional where both branches return.'
          )
        return nodelist

      # Exactly one branch returns, so the trailing code belongs to the other.
      receiver = stmt.orelse if body_returns else stmt.body
      receiver.extend(ast_util.copy_clean(follower) for follower in trailing)
      if trailing:
        self.changes_made = True
      return nodelist[:pos + 1]
    return nodelist


def contains_return(node):
  """Returns True if `node` or any node nested under it is a `return`.

  Nested function definitions are searched too, since `gast.walk` visits the
  whole subtree.
  """
  return any(isinstance(child, gast.Return) for child in gast.walk(node))


class LiftReturn(converter.Base):
  """Move return statements out of If and With blocks.

  A `return` that ends both branches of an `if`, or ends the body of a
  `with`, is rewritten into an assignment to a per-function variable
  (`common_return_name`). A single `return <that variable>` is then emitted
  immediately after the compound statement. `changes_made` is set whenever
  such a lift happens.
  """

  def __init__(self, ctx):
    super(LiftReturn, self).__init__(ctx)
    self.changes_made = False
    # Variable collecting return values of the function currently being
    # visited; set per function in visit_FunctionDef.
    self.common_return_name = None

  def _assign_return_value(self, return_node):
    """Builds `<common_return_name> = <value>` to replace `return_node`."""
    # NOTE(review): a bare `return` has `value` None; confirm that
    # templates.replace handles b=None as intended.
    return templates.replace(
        'a = b', a=self.common_return_name, b=return_node.value)[0]

  def _lifted_return(self):
    """Builds the `return <common_return_name>` placed after the block."""
    return templates.replace('return a', a=self.common_return_name)[0]

  def visit_If(self, node):
    # Depth-first traversal of if statements
    node = self.generic_visit(node)

    # We check if both branches return, and if so, lift the return out of the
    # conditional. We don't enforce that the true and false branches either
    # both return or both do not, because FoldElse might move a return
    # into a branch after this transform completes. FoldElse and LiftReturn
    # are alternately run until the code reaches a fixed point.
    true_branch_returns = isinstance(node.body[-1], gast.Return)
    false_branch_returns = bool(node.orelse) and isinstance(
        node.orelse[-1], gast.Return)
    if not (true_branch_returns and false_branch_returns):
      return node
    node.body[-1] = self._assign_return_value(node.body[-1])
    node.orelse[-1] = self._assign_return_value(node.orelse[-1])
    self.changes_made = True
    return [node, self._lifted_return()]

  def visit_With(self, node):
    # Depth-first traversal of syntax. The children are visited exactly once;
    # re-visiting after the lift would only walk the same subtree again.
    node = self.generic_visit(node)

    # If the with statement returns, lift the return
    if not isinstance(node.body[-1], gast.Return):
      return node
    node.body[-1] = self._assign_return_value(node.body[-1])
    self.changes_made = True
    return [node, self._lifted_return()]

  def visit_FunctionDef(self, node):
    # Each function gets its own return variable, chosen to avoid names the
    # body already references; the enclosing function's name is restored
    # afterwards so nested definitions don't leak theirs.
    last_return_name = self.common_return_name
    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    referenced_names = body_scope.referenced
    self.common_return_name = self.ctx.namer.new_symbol('return_',
                                                        referenced_names)
    try:
      node = self.generic_visit(node)
    finally:
      self.common_return_name = last_return_name
    return node


class DetectReturnInUnsupportedControlFlow(gast.NodeVisitor):
  """Throws an error if code returns inside loops or try/except."""

  # First, throw an error if we detect a return statement in a loop.
  # TODO(alexbw): we need to learn to handle returns inside a loop,
  # but don't currently have the TF constructs to do so (need something
  # that looks vaguely like a goto).

  def __init__(self):
    # True while visiting the inside of a loop or try statement.
    self.cant_return = False
    # Only the outermost function is inspected; nested definitions are
    # skipped by visit_FunctionDef.
    self.function_level = 0
    super(DetectReturnInUnsupportedControlFlow, self).__init__()

  def _visit_restricted(self, node):
    """Visits the children of `node` with `return` disallowed."""
    # Restore the enclosing value instead of resetting to False. Otherwise,
    # leaving a nested loop would re-allow returns in the rest of the outer
    # loop or try block.
    outer_cant_return = self.cant_return
    self.cant_return = True
    try:
      self.generic_visit(node)
    finally:
      self.cant_return = outer_cant_return

  def visit_While(self, node):
    self._visit_restricted(node)

  def visit_For(self, node):
    self._visit_restricted(node)

  def visit_Try(self, node):
    self._visit_restricted(node)

  def visit_FunctionDef(self, node):
    if not self.function_level:
      self.function_level += 1
      self.generic_visit(node)
      self.function_level -= 1

  def visit_Return(self, node):
    if self.cant_return:
      raise ValueError(
          '`return` statements are not supported in loops. '
          'Try assigning to a variable in the while loop, and returning '
          'outside of the loop')


class DetectReturnInConditional(gast.NodeVisitor):
  """Assert that no return statements are present in conditionals."""

  def __init__(self):
    # True while visiting the inside of an `if` statement.
    self.cant_return = False
    # Only the outermost function is inspected; nested definitions are
    # skipped by visit_FunctionDef.
    self.function_level = 0
    super(DetectReturnInConditional, self).__init__()

  def visit_If(self, node):
    # Restore the enclosing value instead of resetting to False. Otherwise,
    # leaving a nested `if` would re-allow returns in the rest of the outer
    # conditional.
    outer_cant_return = self.cant_return
    self.cant_return = True
    try:
      self.generic_visit(node)
    finally:
      self.cant_return = outer_cant_return

  def visit_FunctionDef(self, node):
    if not self.function_level:
      self.function_level += 1
      self.generic_visit(node)
      self.function_level -= 1

  def visit_Return(self, node):
    if self.cant_return:
      raise ValueError(
          'After transforms, a conditional contained a `return `statement, '
          'which is not allowed. This is a bug, and should not happen.')


class DetectReturnInFunctionDef(gast.NodeVisitor):
  """Raises ValueError for any function definition without a `return`.

  Nested definitions are checked before their enclosing one. The enclosing
  check uses `contains_return`, which also searches nested definitions.
  """

  def visit_FunctionDef(self, node):
    self.generic_visit(node)
    if contains_return(node):
      return
    raise ValueError(
        'Each function definition should contain at least one return.')


def transform(node, ctx):
  """Ensure a function has only a single return.

  Successively rewrites an AST node containing several `return` statements
  so that it ends up with a single return node. Restrictions:
   - No returns are allowed inside loops or try/except blocks. We would have
     to know the type of the return value, and there is currently neither a
     type inference system to discover it nor a mechanism for late type
     binding in TensorFlow.
   - Once all rewrites have converged, a Return node is not allowed inside a
     conditional. If a return could not be moved out, this is an error.
   (Requiring at least one return per function is currently disabled.)

  Args:
    node: ast.AST
    ctx: converter.EntityContext

  Returns:
    new_node: an AST with a single return value

  Raises:
    ValueError: if the AST is structured so that we can't perform the
      transform.
  """
  # TODO(alexbw): the check that every function has at least one return
  # (DetectReturnInFunctionDef) is turned off for now -- it must not be
  # required in e.g. class constructors.

  # Reject returns in locations we cannot handle (loops, try/except).
  DetectReturnInUnsupportedControlFlow().visit(node)

  # Alternate lifting returns out of if/with blocks and folding trailing code
  # into the non-returning branch, until neither pass changes the AST.
  pending = True
  while pending:
    lifter = LiftReturn(ctx)
    node = lifter.visit(node)
    folder = FoldElse(ctx)
    node = folder.visit(node)
    pending = lifter.changes_made or folder.changes_made

  # By now every return must have been moved out of conditionals.
  DetectReturnInConditional().visit(node)
  return node