aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-30 11:45:46 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-30 11:48:41 -0700
commit3328243cdca5d08f56fc64c582ce2f3b80630259 (patch)
treea0342cb65909f8aaf91c7eee4d92898c6af71b13 /tensorflow/contrib/autograph
parentf48928939d1e882e8a03f15be67e59405a1ddbc9 (diff)
Add an A-normal form transformer for Python code to pyct.
The purpose of A-normal form is to assign every intermediate value to an explicit variable, so that downstream transformations have those variables to associate information with. https://en.wikipedia.org/wiki/A-normal_form This transformer is mostly complete, but there are a few corner cases with room for improvement (notably constructs that only appear in Python 3). PiperOrigin-RevId: 206619935
Diffstat (limited to 'tensorflow/contrib/autograph')
-rw-r--r--tensorflow/contrib/autograph/pyct/common_transformers/BUILD1
-rw-r--r--tensorflow/contrib/autograph/pyct/common_transformers/anf.py381
-rw-r--r--tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py356
3 files changed, 725 insertions, 13 deletions
diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/BUILD b/tensorflow/contrib/autograph/pyct/common_transformers/BUILD
index ca1441cf6f..a0938b3e5f 100644
--- a/tensorflow/contrib/autograph/pyct/common_transformers/BUILD
+++ b/tensorflow/contrib/autograph/pyct/common_transformers/BUILD
@@ -24,6 +24,7 @@ py_library(
deps = [
"//tensorflow/contrib/autograph/pyct",
"@gast_archive//:gast",
+ "@six_archive//:six",
],
)
diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf.py b/tensorflow/contrib/autograph/pyct/common_transformers/anf.py
index cc039986c2..e42f679cfe 100644
--- a/tensorflow/contrib/autograph/pyct/common_transformers/anf.py
+++ b/tensorflow/contrib/autograph/pyct/common_transformers/anf.py
@@ -12,12 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Conversion to A-normal form."""
+"""Conversion to A-normal form.
+
+The general idea of A-normal form is that every intermediate value is
+explicitly named with a variable. For more, see
+https://en.wikipedia.org/wiki/A-normal_form.
+
+The specific converters used here are based on Python AST semantics as
+documented at https://greentreesnakes.readthedocs.io/en/latest/.
+"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import gast
+import six
+
+from tensorflow.contrib.autograph.pyct import templates
from tensorflow.contrib.autograph.pyct import transformer
@@ -32,26 +44,375 @@ class DummyGensym(object):
# * the symbols generated so far
self._idx = 0
- def new_name(self, stem):
+ def new_name(self, stem='tmp'):
self._idx += 1
return stem + '_' + str(1000 + self._idx)
class AnfTransformer(transformer.Base):
- """Performs the actual conversion."""
+ """Performs the conversion to A-normal form (ANF)."""
- # TODO(mdan): Link to a reference.
- # TODO(mdan): Implement.
+ # The algorithm is a postorder recursive tree walk. Any given node A may, in
+ # general, require creation of a series B of Assign statements, which compute
+ # and explicitly name the intermediate values needed to compute the value of
+ # A. If A was already a statement, it can be replaced with the sequence B +
+ # [A]. If A was an expression, B needs to be propagated up the tree until a
+ # statement is encountered. Since the `ast.NodeTransformer` framework makes
+ # no provision for subtraversals returning side information, this class
+ # accumulates the sequence B in an instance variable.
- def __init__(self, entity_info):
- """Creates a transformer.
+ # The only other subtlety is that some Python statements (like `if`) have both
+ # expression fields (`test`) and statement list fields (`body` and `orelse`).
+ # Any additional assignments needed to name all the intermediate values in the
+ # `test` can be prepended to the `if` node, but assignments produced by
+ # processing the `body` and the `orelse` need to be kept together with them,
+ # and not accidentally lifted out of the `if`.
+
+ def __init__(self, entity_info, gensym_source=None):
+ """Creates an ANF transformer.
Args:
entity_info: transformer.EntityInfo
+ gensym_source: An optional object with the same interface as `DummyGensym`
+ for generating unique names
"""
super(AnfTransformer, self).__init__(entity_info)
- self._gensym = DummyGensym(entity_info)
+ if gensym_source is None:
+ self._gensym = DummyGensym(entity_info)
+ else:
+ self._gensym = gensym_source(entity_info)
+ self._pending_statements = []
+
+ def _consume_pending_statements(self):
+ ans = self._pending_statements
+ self._pending_statements = []
+ return ans
+
+ def _add_pending_statement(self, stmt):
+ self._pending_statements.append(stmt)
+
+ _trivial_nodes = (
+ # Non-nodes that show up as AST fields
+ bool, six.string_types,
+ # Leaf nodes that are already in A-normal form
+ gast.expr_context, gast.Name, gast.Num, gast.Str, gast.Bytes,
+ gast.NameConstant, gast.Ellipsis,
+ # Binary operators
+ gast.Add, gast.Sub, gast.Mult, gast.Div, gast.Mod, gast.Pow, gast.LShift,
+ gast.RShift, gast.BitOr, gast.BitXor, gast.BitAnd, gast.FloorDiv,
+ # Unary operators
+ gast.Invert, gast.Not, gast.UAdd, gast.USub,
+ # Comparison operators
+ gast.Eq, gast.NotEq, gast.Lt, gast.LtE, gast.Gt, gast.GtE,
+ gast.Is, gast.IsNot, gast.In, gast.NotIn,
+ )
+
+ def _is_node_trivial(self, node):
+ if node is None:
+ return True
+ elif isinstance(node, self._trivial_nodes):
+ return True
+ elif isinstance(node, gast.keyword):
+ return self._is_node_trivial(node.value)
+ elif isinstance(node, (gast.Starred, gast.withitem, gast.slice)):
+ return self._are_children_trivial(node)
+ return False
+
+ def _are_children_trivial(self, node):
+ for field in node._fields:
+ if not field.startswith('__'):
+ if not self._is_node_trivial(getattr(node, field)):
+ return False
+ return True
+
+ def _ensure_node_is_trivial(self, node):
+ if node is None:
+ return node
+ elif isinstance(node, self._trivial_nodes):
+ return node
+ elif isinstance(node, list):
+ # If something's field was actually a list, e.g., variadic arguments.
+ return [self._ensure_node_is_trivial(n) for n in node]
+ elif isinstance(node, gast.keyword):
+ node.value = self._ensure_node_is_trivial(node.value)
+ return node
+ elif isinstance(node, (gast.Starred, gast.withitem, gast.slice)):
+ return self._ensure_fields_trivial(node)
+ elif isinstance(node, gast.expr):
+ temp_name = self._gensym.new_name()
+ temp_assign = templates.replace(
+ 'temp_name = expr', temp_name=temp_name, expr=node)[0]
+ self._add_pending_statement(temp_assign)
+ answer = templates.replace('temp_name', temp_name=temp_name)[0]
+ return answer
+ else:
+ raise ValueError('Do not know how to treat {}'.format(node))
+
+ def _ensure_fields_trivial(self, node):
+ for field in node._fields:
+ if field.startswith('__'):
+ continue
+ setattr(node, field, self._ensure_node_is_trivial(getattr(node, field)))
+ return node
+
+ def _visit_strict_statement(self, node, trivialize_children=True):
+ assert not self._pending_statements
+ node = self.generic_visit(node)
+ if trivialize_children:
+ self._ensure_fields_trivial(node)
+ results = self._consume_pending_statements()
+ results.append(node)
+ return results
+
+ def _visit_strict_expression(self, node):
+ node = self.generic_visit(node)
+ self._ensure_fields_trivial(node)
+ return node
+
+ # Note on code order: These are listed in the same order as the grammar
+ # elements on https://github.com/serge-sans-paille/gast
+
+ # FunctionDef, AsyncFunctionDef, and ClassDef should be correct by default.
+
+ def visit_Return(self, node):
+ return self._visit_strict_statement(node)
+
+ def visit_Delete(self, node):
+ return self._visit_strict_statement(node, trivialize_children=False)
+
+ def visit_Assign(self, node):
+ return self._visit_strict_statement(node, trivialize_children=False)
+
+ def visit_AugAssign(self, node):
+ return self._visit_strict_statement(node, trivialize_children=False)
+
+ def visit_Print(self, node):
+ return self._visit_strict_statement(node)
+
+ def visit_For(self, node):
+ assert not self._pending_statements
+ # It's important to visit node.iter first, because any statements created
+ # thereby need to live outside the body.
+ self.visit(node.iter)
+ node.iter = self._ensure_node_is_trivial(node.iter)
+ iter_stmts = self._consume_pending_statements()
+ # This generic_visit will revisit node.iter, but that is both correct and
+ # cheap because by this point node.iter is trivial.
+ node = self.generic_visit(node)
+ assert not self._pending_statements
+ iter_stmts.append(node)
+ return iter_stmts
+
+ def visit_AsyncFor(self, node):
+ if not self._are_children_trivial(node):
+ msg = ('Nontrivial AsyncFor nodes not supported yet '
+ '(need to think through the semantics).')
+ raise ValueError(msg)
+ return self.generic_visit(node)
+
+ def visit_While(self, node):
+ if not self._is_node_trivial(node.test):
+ msg = ('While with nontrivial test not supported yet '
+ '(need to avoid precomputing the test).')
+ raise ValueError(msg)
+ return self.generic_visit(node)
+
+ def visit_If(self, node):
+ assert not self._pending_statements
+ # It's important to visit node.test first, because any statements created
+ # thereby need to live outside the body.
+ self.visit(node.test)
+ node.test = self._ensure_node_is_trivial(node.test)
+ condition_stmts = self._consume_pending_statements()
+ # This generic_visit will revisit node.test, but that is both correct and
+ # cheap because by this point node.test is trivial.
+ node = self.generic_visit(node)
+ assert not self._pending_statements
+ condition_stmts.append(node)
+ return condition_stmts
+
+ def visit_With(self, node):
+ assert not self._pending_statements
+ # It's important to visit node.items first, because any statements created
+ # thereby need to live outside the body.
+ for item in node.items:
+ self.visit(item)
+ node.items = [self._ensure_node_is_trivial(n) for n in node.items]
+ contexts_stmts = self._consume_pending_statements()
+ # This generic_visit will revisit node.items, but that is both correct and
+ # cheap because by this point node.items is trivial.
+ node = self.generic_visit(node)
+ assert not self._pending_statements
+ contexts_stmts.append(node)
+ return contexts_stmts
+
+ def visit_AsyncWith(self, node):
+ if not self._are_children_trivial(node):
+ msg = ('Nontrivial AsyncWith nodes not supported yet '
+ '(need to think through the semantics).')
+ raise ValueError(msg)
+ return self.generic_visit(node)
+
+ def visit_Raise(self, node):
+ return self._visit_strict_statement(node)
+
+ # Try should be correct by default.
+
+ def visit_Assert(self, node):
+ if not self._are_children_trivial(node):
+ msg = ('Nontrivial Assert nodes not supported yet '
+ '(need to avoid computing the test when assertions are off, and '
+ 'avoid computing the irritant when the assertion does not fire).')
+ raise ValueError(msg)
+ return self.generic_visit(node)
+
+ # Import and ImportFrom should be correct by default.
+
+ def visit_Exec(self, node):
+ return self._visit_strict_statement(node)
+
+ # Global and Nonlocal should be correct by default.
+
+ def visit_Expr(self, node):
+ return self._visit_strict_statement(node, trivialize_children=False)
+
+ # Pass, Break, and Continue should be correct by default.
+
+ def visit_BoolOp(self, node):
+ if not self._are_children_trivial(node):
+ msg = ('Nontrivial BoolOp nodes not supported yet '
+ '(need to preserve short-circuiting semantics).')
+ raise ValueError(msg)
+ return self.generic_visit(node)
+
+ def visit_BinOp(self, node):
+ return self._visit_strict_expression(node)
+
+ def visit_UnaryOp(self, node):
+ return self._visit_strict_expression(node)
+
+ def visit_Lambda(self, node):
+ if not self._are_children_trivial(node):
+ msg = ('Nontrivial Lambda nodes not supported '
+ '(cannot insert statements into lambda bodies).')
+ raise ValueError(msg)
+ return self.generic_visit(node)
+
+ def visit_IfExp(self, node):
+ if not self._are_children_trivial(node):
+ msg = ('Nontrivial IfExp nodes not supported yet '
+ '(need to convert to If statement, to evaluate branches lazily '
+ 'and insert statements into them).')
+ raise ValueError(msg)
+ return self.generic_visit(node)
+
+ def visit_Dict(self, node):
+ return self._visit_strict_expression(node)
+
+ def visit_Set(self, node):
+ return self._visit_strict_expression(node)
+
+ def visit_ListComp(self, node):
+ msg = ('ListComp nodes not supported '
+ '(need to convert to a form that tolerates '
+ 'assignment statements in clause bodies).')
+ raise ValueError(msg)
+
+ def visit_SetComp(self, node):
+ msg = ('SetComp nodes not supported '
+ '(need to convert to a form that tolerates '
+ 'assignment statements in clause bodies).')
+ raise ValueError(msg)
+
+ def visit_DictComp(self, node):
+ msg = ('DictComp nodes not supported '
+ '(need to convert to a form that tolerates '
+ 'assignment statements in clause bodies).')
+ raise ValueError(msg)
+
+ def visit_GeneratorExp(self, node):
+ msg = ('GeneratorExp nodes not supported '
+ '(need to convert to a form that tolerates '
+ 'assignment statements in clause bodies).')
+ raise ValueError(msg)
+
+ def visit_Await(self, node):
+ if not self._are_children_trivial(node):
+ msg = ('Nontrivial Await nodes not supported yet '
+ '(need to think through the semantics).')
+ raise ValueError(msg)
+ return self.generic_visit(node)
+
+ def visit_Yield(self, node):
+ return self._visit_strict_expression(node)
+
+ def visit_YieldFrom(self, node):
+ if not self._are_children_trivial(node):
+ msg = ('Nontrivial YieldFrom nodes not supported yet '
+ '(need to unit-test them in Python 2).')
+ raise ValueError(msg)
+ return self.generic_visit(node)
+
+ def visit_Compare(self, node):
+ if len(node.ops) > 1:
+ msg = ('Multi-ary compare nodes not supported yet '
+ '(need to preserve short-circuiting semantics).')
+ raise ValueError(msg)
+ return self._visit_strict_expression(node)
+
+ def visit_Call(self, node):
+ return self._visit_strict_expression(node)
+
+ def visit_Repr(self, node):
+ if not self._are_children_trivial(node):
+ msg = ('Nontrivial Repr nodes not supported yet '
+ '(need to research their syntax and semantics).')
+ raise ValueError(msg)
+ return self.generic_visit(node)
+
+ def visit_FormattedValue(self, node):
+ if not self._are_children_trivial(node):
+ msg = ('Nontrivial FormattedValue nodes not supported yet '
+ '(need to unit-test them in Python 2).')
+ raise ValueError(msg)
+ return self.generic_visit(node)
+
+ def visit_JoinedStr(self, node):
+ if not self._are_children_trivial(node):
+ msg = ('Nontrivial JoinedStr nodes not supported yet '
+ '(need to unit-test them in Python 2).')
+ raise ValueError(msg)
+ return self.generic_visit(node)
+
+ def visit_Attribute(self, node):
+ return self._visit_strict_expression(node)
+
+ def visit_Subscript(self, node):
+ return self._visit_strict_expression(node)
+
+ # Starred and Name are correct by default, because the right thing to do is to
+ # just recur.
+
+ def visit_List(self, node):
+ return self._visit_strict_expression(node)
+
+ def visit_Tuple(self, node):
+ return self._visit_strict_expression(node)
+
+
+def transform(node, entity_info, gensym_source=None):
+ """Converts the given node to A-normal form (ANF).
+
+ The general idea of A-normal form: https://en.wikipedia.org/wiki/A-normal_form
+ The specific converters used here are based on Python AST semantics as
+ documented at https://greentreesnakes.readthedocs.io/en/latest/.
-def transform(node, entity_info):
- return AnfTransformer(entity_info).visit(node)
+ Args:
+ node: The node to transform.
+ entity_info: transformer.EntityInfo. TODO(mdan): What information does this
+ argument provide?
+ gensym_source: An optional object with the same interface as `DummyGensym`
+ for generating unique names.
+ """
+ return AnfTransformer(entity_info, gensym_source=gensym_source).visit(node)
diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py b/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py
index aefbc69d8c..951974820c 100644
--- a/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py
+++ b/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import textwrap
+
from tensorflow.contrib.autograph.pyct import compiler
from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.pyct import transformer
@@ -25,6 +27,22 @@ from tensorflow.contrib.autograph.pyct.common_transformers import anf
from tensorflow.python.platform import test
+class DummyGensym(object):
+ """A dumb gensym that suffixes a stem by sequential numbers from 1000."""
+
+ def __init__(self, entity_info):
+ del entity_info
+ # A proper implementation needs to account for:
+ # * entity_info.namespace
+ # * all the symbols defined in the AST
+ # * the symbols generated so far
+ self._idx = 0
+
+ def new_name(self, stem='tmp'):
+ self._idx += 1
+ return stem + '_' + str(1000 + self._idx)
+
+
class AnfTransformerTest(test.TestCase):
def _simple_source_info(self):
@@ -37,17 +55,349 @@ class AnfTransformerTest(test.TestCase):
owner_type=None)
def test_basic(self):
-
def test_function():
a = 0
return a
-
node, _ = parser.parse_entity(test_function)
node = anf.transform(node.body[0], self._simple_source_info())
result, _ = compiler.ast_to_object(node)
-
self.assertEqual(test_function(), result.test_function())
+ def assert_same_ast(self, expected_node, node, msg=None):
+ expected_source = compiler.ast_to_source(expected_node, indentation=' ')
+ expected_str = textwrap.dedent(expected_source).strip()
+ got_source = compiler.ast_to_source(node, indentation=' ')
+ got_str = textwrap.dedent(got_source).strip()
+ self.assertEqual(expected_str, got_str, msg=msg)
+
+ def assert_body_anfs_as_expected(self, expected_fn, test_fn):
+ # Testing the code bodies only. Wrapping them in functions so the
+ # syntax highlights nicely, but Python doesn't try to execute the
+ # statements.
+ exp_node, _ = parser.parse_entity(expected_fn)
+ node, _ = parser.parse_entity(test_fn)
+ node = anf.transform(
+ node, self._simple_source_info(), gensym_source=DummyGensym)
+ exp_name = exp_node.body[0].name
+ # Ignoring the function names in the result because they can't be
+ # the same (because both functions have to exist in the same scope
+ # at the same time).
+ node.body[0].name = exp_name
+ self.assert_same_ast(exp_node, node)
+ # Check that ANF is idempotent
+ node_repeated = anf.transform(
+ node, self._simple_source_info(), gensym_source=DummyGensym)
+ self.assert_same_ast(node_repeated, node)
+
+ def test_binop_basic(self):
+
+ def test_function(x, y, z):
+ a = x + y + z
+ return a
+
+ def expected_result(x, y, z):
+ tmp_1001 = x + y
+ a = tmp_1001 + z
+ return a
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ def test_if_basic(self):
+
+ def test_function(a, b, c, e, f, g):
+ if a + b + c:
+ d = e + f + g
+ return d
+
+ def expected_result(a, b, c, e, f, g):
+ tmp_1001 = a + b
+ tmp_1002 = tmp_1001 + c
+ if tmp_1002:
+ tmp_1003 = e + f
+ d = tmp_1003 + g
+ return d
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ def test_nested_binop_and_return(self):
+
+ def test_function(b, c, d, e):
+ return (2 * b + c) + (d + e)
+
+ def expected_result(b, c, d, e):
+ tmp_1001 = 2 * b
+ tmp_1002 = tmp_1001 + c
+ tmp_1003 = d + e
+ tmp_1004 = tmp_1002 + tmp_1003
+ return tmp_1004
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ def test_function_call_and_expr(self):
+
+ def test_function(call_something, a, b, y, z, c, d, e, f, g, h, i):
+ call_something(a + b, y * z, kwarg=c + d, *(e + f), **(g + h + i))
+
+ def expected_result(call_something, a, b, y, z, c, d, e, f, g, h, i):
+ tmp_1001 = g + h
+ tmp_1002 = a + b
+ tmp_1003 = y * z
+ tmp_1004 = e + f
+ tmp_1005 = c + d
+ tmp_1006 = tmp_1001 + i
+ call_something(tmp_1002, tmp_1003, kwarg=tmp_1005, *tmp_1004, **tmp_1006)
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ def test_with_and_print(self):
+
+ def test_function(a, b, c):
+ with a + b + c as d:
+ print(2 * d + 1)
+
+ def expected_result(a, b, c):
+ tmp_1001 = a + b
+ tmp_1002 = tmp_1001 + c
+ with tmp_1002 as d:
+ tmp_1003 = 2 * d
+ tmp_1004 = tmp_1003 + 1
+ print(tmp_1004)
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ def test_local_definition_and_binary_compare(self):
+
+ def test_function():
+ def foo(a, b):
+ return 2 * a < b
+ return foo
+
+ def expected_result():
+ def foo(a, b):
+ tmp_1001 = 2 * a
+ tmp_1002 = tmp_1001 < b
+ return tmp_1002
+ return foo
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ def test_list_literal(self):
+
+ def test_function(a, b, c, d, e, f):
+ return [a + b, c + d, e + f]
+
+ def expected_result(a, b, c, d, e, f):
+ tmp_1001 = a + b
+ tmp_1002 = c + d
+ tmp_1003 = e + f
+ tmp_1004 = [tmp_1001, tmp_1002, tmp_1003]
+ return tmp_1004
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ def test_tuple_literal_and_unary(self):
+
+ def test_function(a, b, c, d, e, f):
+ return (a + b, -(c + d), e + f)
+
+ def expected_result(a, b, c, d, e, f):
+ tmp_1001 = c + d
+ tmp_1002 = a + b
+ tmp_1003 = -tmp_1001
+ tmp_1004 = e + f
+ tmp_1005 = (tmp_1002, tmp_1003, tmp_1004)
+ return tmp_1005
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ def test_set_literal(self):
+
+ def test_function(a, b, c, d, e, f):
+ return set(a + b, c + d, e + f)
+
+ def expected_result(a, b, c, d, e, f):
+ tmp_1001 = a + b
+ tmp_1002 = c + d
+ tmp_1003 = e + f
+ tmp_1004 = set(tmp_1001, tmp_1002, tmp_1003)
+ return tmp_1004
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ def test_dict_literal_and_repr(self):
+
+ def test_function(foo, bar, baz):
+ return repr({foo + bar + baz: 7 | 8})
+
+ def expected_result(foo, bar, baz):
+ tmp_1001 = foo + bar
+ tmp_1002 = tmp_1001 + baz
+ tmp_1003 = 7 | 8
+ tmp_1004 = {tmp_1002: tmp_1003}
+ tmp_1005 = repr(tmp_1004)
+ return tmp_1005
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ def test_field_read_and_write(self):
+
+ def test_function(a, d):
+ a.b.c = d.e.f + 3
+
+ def expected_result(a, d):
+ tmp_1001 = a.b
+ tmp_1002 = d.e
+ tmp_1003 = tmp_1002.f
+ tmp_1001.c = tmp_1003 + 3
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ def test_subscript_read_and_write(self):
+
+ def test_function(a, b, c, d, e, f):
+ a[b][c] = d[e][f] + 3
+
+ def expected_result(a, b, c, d, e, f):
+ tmp_1001 = a[b]
+ tmp_1002 = d[e]
+ tmp_1003 = tmp_1002[f]
+ tmp_1001[c] = tmp_1003 + 3
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ def test_augassign_and_delete(self):
+
+ def test_function(a, x, y, z):
+ a += x + y + z
+ del a
+ del z[y][x]
+
+ def expected_result(a, x, y, z):
+ tmp_1001 = x + y
+ a += tmp_1001 + z
+ del a
+ tmp_1002 = z[y]
+ del tmp_1002[x]
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ def test_raise_yield_and_raise(self):
+
+ def test_function(a, c, some_computed, exception):
+ yield a ** c
+ raise some_computed('complicated' + exception)
+
+ def expected_result(a, c, some_computed, exception):
+ tmp_1001 = a ** c
+ yield tmp_1001
+ tmp_1002 = 'complicated' + exception
+ tmp_1003 = some_computed(tmp_1002)
+ raise tmp_1003
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ def test_with_and_if_with_expressions(self):
+
+ def test_function(foo, bar, function, quux, quozzle, w, x, y, z):
+ with foo + bar:
+ function(x + y)
+ if quux + quozzle:
+ function(z / w)
+
+ def expected_result(foo, bar, function, quux, quozzle, w, x, y, z):
+ tmp_1001 = foo + bar
+ with tmp_1001:
+ tmp_1002 = x + y
+ function(tmp_1002)
+ tmp_1003 = quux + quozzle
+ if tmp_1003:
+ tmp_1004 = z / w
+ function(tmp_1004)
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ def test_exec(self):
+
+ def test_function():
+ # The point is to test A-normal form conversion of exec
+ # pylint: disable=exec-used
+ exec('computed' + 5 + 'stuff', globals(), locals())
+
+ def expected_result():
+ # pylint: disable=exec-used
+ tmp_1001 = 'computed' + 5
+ tmp_1002 = tmp_1001 + 'stuff'
+ tmp_1003 = globals()
+ tmp_1004 = locals()
+ exec(tmp_1002, tmp_1003, tmp_1004)
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ def test_simple_while_and_assert(self):
+
+ def test_function(foo, quux):
+ while foo:
+ assert quux
+ foo = foo + 1 * 3
+
+ def expected_result(foo, quux):
+ while foo:
+ assert quux
+ tmp_1001 = 1 * 3
+ foo = foo + tmp_1001
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ def test_for(self):
+
+ def test_function(compute, something, complicated, foo):
+ for foo in compute(something + complicated):
+ bar = foo + 1 * 3
+ return bar
+
+ def expected_result(compute, something, complicated, foo):
+ tmp_1001 = something + complicated
+ tmp_1002 = compute(tmp_1001)
+ for foo in tmp_1002:
+ tmp_1003 = 1 * 3
+ bar = foo + tmp_1003
+ return bar
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ # This test collects several examples where the definition of A-normal form
+ # implemented by this transformer is questionable. Mostly it's here to spell
+ # out what the definition is in these cases.
+ def test_controversial(self):
+
+ def test_function(b, c, d, f):
+ a = c + d
+ a.b = c + d
+ a[b] = c + d
+ a += c + d
+ a, b = c
+ a, b = c, d
+ a = f(c)
+ a = f(c + d)
+ a[b + d] = f.e(c + d)
+
+ def expected_result(b, c, d, f):
+ a = c + d
+ a.b = c + d # Should be a.b = tmp? (Definitely not tmp = c + d)
+ a[b] = c + d # Should be a[b] = tmp? (Definitely not tmp = c + d)
+ a += c + d # Should be a += tmp? (Definitely not tmp = c + d)
+ a, b = c # Should be a = c[0], b = c[1]? Or not?
+ a, b = c, d # Should be a = c, b = d? Or not?
+ a = f(c)
+ tmp_1001 = c + d
+ a = f(tmp_1001)
+ tmp_1002 = b + d
+ tmp_1003 = f.e
+ tmp_1004 = c + d
+ a[tmp_1002] = tmp_1003(tmp_1004) # Or should be a[tmp1] = tmp2?
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
if __name__ == '__main__':
test.main()