aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph/converters/return_statements.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/autograph/converters/return_statements.py')
-rw-r--r--tensorflow/contrib/autograph/converters/return_statements.py317
1 files changed, 317 insertions, 0 deletions
diff --git a/tensorflow/contrib/autograph/converters/return_statements.py b/tensorflow/contrib/autograph/converters/return_statements.py
new file mode 100644
index 0000000000..a351cd81b8
--- /dev/null
+++ b/tensorflow/contrib/autograph/converters/return_statements.py
@@ -0,0 +1,317 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Canonicalizes functions with multiple returns to use just one."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.contrib.autograph.core import converter
+from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.contrib.autograph.pyct import ast_util
+from tensorflow.contrib.autograph.pyct import templates
+from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
+
+
+# TODO(mdan): Move this logic into transformer_base.
+class BodyVisitor(converter.Base):
+ """Walks breadth- or depth-first the list-of-nodes bodies of AST nodes."""
+
+ def __init__(self, ctx, depth_first=False):
+ super(BodyVisitor, self).__init__(ctx)
+ self.depth_first = depth_first
+ self.changes_made = False
+
+ def visit_nodelist(self, nodelist):
+ for node in nodelist:
+ if isinstance(node, list):
+ node = self.visit_nodelist(node)
+ else:
+ node = self.generic_visit(node)
+ return nodelist
+
+ def visit_If(self, node):
+ if self.depth_first:
+ node = self.generic_visit(node)
+ node.body = self.visit_nodelist(node.body)
+ node.orelse = self.visit_nodelist(node.orelse)
+ if not self.depth_first:
+ node = self.generic_visit(node)
+ return node
+
+ def visit_For(self, node):
+ if self.depth_first:
+ node = self.generic_visit(node)
+ node.body = self.visit_nodelist(node.body)
+ node.orelse = self.visit_nodelist(node.orelse)
+ if not self.depth_first:
+ node = self.generic_visit(node)
+ return node
+
+ def visit_While(self, node):
+ if self.depth_first:
+ node = self.generic_visit(node)
+ node.body = self.visit_nodelist(node.body)
+ node.orelse = self.visit_nodelist(node.orelse)
+ if not self.depth_first:
+ node = self.generic_visit(node)
+ return node
+
+ def visit_Try(self, node):
+ if self.depth_first:
+ node = self.generic_visit(node)
+ node.body = self.visit_nodelist(node.body)
+ node.orelse = self.visit_nodelist(node.orelse)
+ node.finalbody = self.visit_nodelist(node.finalbody)
+ for i in range(len(node.handlers)):
+ node.handlers[i].body = self.visit_nodelist(node.handlers[i].body)
+ if not self.depth_first:
+ node = self.generic_visit(node)
+ return node
+
+ def visit_With(self, node):
+ if self.depth_first:
+ node = self.generic_visit(node)
+ node.body = self.visit_nodelist(node.body)
+ if not self.depth_first:
+ node = self.generic_visit(node)
+ return node
+
+ def visit_FunctionDef(self, node):
+ if self.depth_first:
+ node = self.generic_visit(node)
+ node.body = self.visit_nodelist(node.body)
+ self.generic_visit(node)
+ if not self.depth_first:
+ node = self.generic_visit(node)
+ return node
+
+
+class FoldElse(BodyVisitor):
+
+ def visit_nodelist(self, nodelist):
+ for i in range(len(nodelist)):
+ node = nodelist[i]
+ if isinstance(node, gast.If):
+ true_branch_returns = isinstance(node.body[-1], gast.Return)
+ false_branch_returns = len(node.orelse) and isinstance(
+ node.orelse[-1], gast.Return)
+ # If the last node in the if body is a return,
+ # then every line after this if statement effectively
+ # belongs in the else.
+ if true_branch_returns and not false_branch_returns:
+ for j in range(i + 1, len(nodelist)):
+ nodelist[i].orelse.append(ast_util.copy_clean(nodelist[j]))
+ if nodelist[i + 1:]:
+ self.changes_made = True
+ return nodelist[:i + 1]
+ elif not true_branch_returns and false_branch_returns:
+ for j in range(i + 1, len(nodelist)):
+ nodelist[i].body.append(ast_util.copy_clean(nodelist[j]))
+ if nodelist[i + 1:]:
+ self.changes_made = True
+ return nodelist[:i + 1]
+ elif true_branch_returns and false_branch_returns:
+ if nodelist[i + 1:]:
+ raise ValueError(
+ 'Unreachable code after conditional where both branches return.'
+ )
+ return nodelist
+ elif isinstance(node, gast.Return) and nodelist[i + 1:]:
+ raise ValueError(
+ 'Cannot have statements after a return in the same basic block')
+ return nodelist
+
+
+def contains_return(node):
+ for n in gast.walk(node):
+ if isinstance(n, gast.Return):
+ return True
+ return False
+
+
+class LiftReturn(converter.Base):
+ """Move return statements out of If and With blocks."""
+
+ def __init__(self, ctx):
+ super(LiftReturn, self).__init__(ctx)
+ self.changes_made = False
+ self.common_return_name = None
+
+ def visit_If(self, node):
+ # Depth-first traversal of if statements
+ node = self.generic_visit(node)
+
+ # We check if both branches return, and if so, lift the return out of the
+ # conditional. We don't enforce that the true and false branches either
+ # both return or both do not, because FoldElse might move a return
+ # into a branch after this transform completes. FoldElse and LiftReturn
+ # are alternately run until the code reaches a fixed point.
+ true_branch_returns = isinstance(node.body[-1], gast.Return)
+ false_branch_returns = len(node.orelse) and isinstance(
+ node.orelse[-1], gast.Return)
+ if true_branch_returns and false_branch_returns:
+ node.body[-1] = templates.replace(
+ 'a = b', a=self.common_return_name, b=node.body[-1].value)[0]
+ node.orelse[-1] = templates.replace(
+ 'a = b', a=self.common_return_name, b=node.orelse[-1].value)[0]
+ return_node = templates.replace('return a', a=self.common_return_name)[0]
+ self.changes_made = True
+ return [node, return_node]
+ else:
+ return node
+
+ def visit_With(self, node):
+ # Depth-first traversal of syntax
+ node = self.generic_visit(node)
+
+ # If the with statement returns, lift the return
+ if isinstance(node.body[-1], gast.Return):
+ node.body[-1] = templates.replace(
+ 'a = b', a=self.common_return_name, b=node.body[-1].value)[0]
+ return_node = templates.replace('return a', a=self.common_return_name)[0]
+ node = self.generic_visit(node)
+ self.changes_made = True
+ return [node, return_node]
+ else:
+ return node
+
+ def visit_FunctionDef(self, node):
+ # Ensure we're doing depth-first traversal
+ last_return_name = self.common_return_name
+ body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
+ referenced_names = body_scope.referenced
+ self.common_return_name = self.ctx.namer.new_symbol('return_',
+ referenced_names)
+ node = self.generic_visit(node)
+ self.common_return_name = last_return_name
+ return node
+
+
+class DetectReturnInUnsupportedControlFlow(gast.NodeVisitor):
+ """Throws an error if code returns inside loops or try/except."""
+
+ # First, throw an error if we detect a return statement in a loop.
+ # TODO(alexbw): we need to learn to handle returns inside a loop,
+ # but don't currently have the TF constructs to do so (need something
+ # that looks vaguely like a goto).
+
+ def __init__(self):
+ self.cant_return = False
+ super(DetectReturnInUnsupportedControlFlow, self).__init__()
+
+ def visit_While(self, node):
+ self.cant_return = True
+ self.generic_visit(node)
+ self.cant_return = False
+
+ def visit_For(self, node):
+ self.cant_return = True
+ self.generic_visit(node)
+ self.cant_return = False
+
+ def visit_Try(self, node):
+ self.cant_return = True
+ self.generic_visit(node)
+ self.cant_return = False
+
+ def visit_Return(self, node):
+ if self.cant_return:
+ raise ValueError(
+ '`return` statements are not supported in loops. '
+ 'Try assigning to a variable in the while loop, and returning '
+ 'outside of the loop')
+
+
+class DetectReturnInConditional(gast.NodeVisitor):
+ """Assert that no return statements are present in conditionals."""
+
+ def __init__(self):
+ self.cant_return = False
+ super(DetectReturnInConditional, self).__init__()
+
+ def visit_If(self, node):
+ self.cant_return = True
+ self.generic_visit(node)
+ self.cant_return = False
+
+ def visit_Return(self, node):
+ if self.cant_return:
+ raise ValueError(
+ 'After transforms, a conditional contained a `return `statement, '
+ 'which is not allowed. This is a bug, and should not happen.')
+
+
+class DetectReturnInFunctionDef(gast.NodeVisitor):
+
+ def visit_FunctionDef(self, node):
+ self.generic_visit(node)
+ if not contains_return(node):
+ raise ValueError(
+ 'Each function definition should contain at least one return.')
+
+
+def transform(node, ctx):
+ """Ensure a function has only a single return.
+
+ This transforms an AST node with multiple returns successively into containing
+ only a single return node.
+ There are a few restrictions on what we can handle:
+ - An AST being transformed must contain at least one return.
+ - No returns allowed in loops. We have to know the type of the return value,
+ and we currently don't have either a type inference system to discover it,
+ nor do we have a mechanism for late type binding in TensorFlow.
+ - After all transformations are finished, a Return node is not allowed inside
+ control flow. If we were unable to move a return outside of control flow,
+ this is an error.
+
+ Args:
+ node: ast.AST
+ ctx: converter.EntityContext
+
+ Returns:
+ new_node: an AST with a single return value
+
+ Raises:
+ ValueError: if the AST is structured so that we can't perform the
+ transform.
+ """
+ # Make sure that the function has at least one return statement
+ # TODO(alexbw): turning off this assertion for now --
+ # we need to not require this in e.g. class constructors.
+ # DetectReturnInFunctionDef().visit(node)
+
+ # Make sure there's no returns in unsupported locations (loops, try/except)
+ DetectReturnInUnsupportedControlFlow().visit(node)
+
+ while True:
+
+ # Try to lift all returns out of if statements and with blocks
+ lr = LiftReturn(ctx)
+ node = lr.visit(node)
+ changes_made = lr.changes_made
+ fe = FoldElse(ctx)
+ node = fe.visit(node)
+ changes_made = changes_made or fe.changes_made
+
+ if not changes_made:
+ break
+
+ # Make sure we've scrubbed all returns from conditionals
+ DetectReturnInConditional().visit(node)
+
+ return node