aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-11 11:09:48 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-11 11:13:27 -0800
commit453f94412f908aadd21561c14feae80dfac1e933 (patch)
tree163730dfebfaceaa286501a1f067244bc6685a33
parent46d6620e4b7f71b420c58df90e7ceb89609ac85a (diff)
Add support for for loops.
Generalize the static analysis across while and for loops. Convert len builtin to tf.shape()[0]. Add for loop canonicalization and companion tests. Modify the template behavior for Name nodes to let the template control the target, which allows simplifying the caller. PiperOrigin-RevId: 181633983
-rw-r--r--tensorflow/contrib/py2tf/conversion.py16
-rw-r--r--tensorflow/contrib/py2tf/convert/BUILD24
-rw-r--r--tensorflow/contrib/py2tf/convert/builtin_functions.py54
-rw-r--r--tensorflow/contrib/py2tf/convert/builtin_functions_test.py58
-rw-r--r--tensorflow/contrib/py2tf/convert/control_flow.py50
-rw-r--r--tensorflow/contrib/py2tf/convert/for_canonicalization.py65
-rw-r--r--tensorflow/contrib/py2tf/convert/for_canonicalization_test.py61
-rw-r--r--tensorflow/contrib/py2tf/convert/side_effect_guards.py10
-rw-r--r--tensorflow/contrib/py2tf/pyct/anno.py1
-rw-r--r--tensorflow/contrib/py2tf/pyct/static_analysis/access.py30
-rw-r--r--tensorflow/contrib/py2tf/pyct/static_analysis/access_test.py26
-rw-r--r--tensorflow/contrib/py2tf/pyct/templates.py12
12 files changed, 352 insertions, 55 deletions
diff --git a/tensorflow/contrib/py2tf/conversion.py b/tensorflow/contrib/py2tf/conversion.py
index 12dd70e497..40b9c3369d 100644
--- a/tensorflow/contrib/py2tf/conversion.py
+++ b/tensorflow/contrib/py2tf/conversion.py
@@ -22,8 +22,10 @@ import six
from tensorflow.contrib.py2tf import config
from tensorflow.contrib.py2tf import naming
+from tensorflow.contrib.py2tf.convert import builtin_functions
from tensorflow.contrib.py2tf.convert import call_trees
from tensorflow.contrib.py2tf.convert import control_flow
+from tensorflow.contrib.py2tf.convert import for_canonicalization
from tensorflow.contrib.py2tf.convert import logical_expressions
from tensorflow.contrib.py2tf.convert import print_functions
from tensorflow.contrib.py2tf.convert import side_effect_guards
@@ -151,6 +153,20 @@ def node_to_graph(node, namer, namespace, value_hints):
# * keeping track of symbols that have been created
# * marking nodes (e.g. py_func wrappers) to suppress further processing
+ node = for_canonicalization.transform(node, namer)
+ node = builtin_functions.transform(node)
+
+ # The transformation steps above insert new variables. Although less
+ # efficient, it is most robust to re-run the analysis.
+ # We also need to ensure the namespace contains any new references that may
+ # have been created.
+ namespace['len'] = len
+ namespace['print'] = print
+
+ node = access.resolve(node)
+ node = live_values.resolve(node, namespace, config.PYTHON_LITERALS)
+ node = type_info.resolve(node, value_hints)
+
node = print_functions.transform(node)
node = call_trees.transform(node, namer, config.DEFAULT_UNCOMPILED_MODULES)
node = control_flow.transform(node, namer)
diff --git a/tensorflow/contrib/py2tf/convert/BUILD b/tensorflow/contrib/py2tf/convert/BUILD
index ddbf336947..ebe720501b 100644
--- a/tensorflow/contrib/py2tf/convert/BUILD
+++ b/tensorflow/contrib/py2tf/convert/BUILD
@@ -17,8 +17,10 @@ filegroup(
py_library(
name = "convert",
srcs = [
+ "builtin_functions.py",
"call_trees.py",
"control_flow.py",
+ "for_canonicalization.py",
"logical_expressions.py",
"print_functions.py",
"side_effect_guards.py",
@@ -53,6 +55,28 @@ py_test(
)
py_test(
+ name = "builtin_functions_test",
+ srcs = ["builtin_functions_test.py"],
+ deps = [
+ ":convert",
+ "//tensorflow/contrib/py2tf/pyct",
+ "//tensorflow/contrib/py2tf/pyct/static_analysis",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+py_test(
+ name = "for_canonicalization_test",
+ srcs = ["for_canonicalization_test.py"],
+ deps = [
+ ":convert",
+ "//tensorflow/contrib/py2tf/pyct",
+ "//tensorflow/contrib/py2tf/pyct/static_analysis",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+py_test(
name = "logical_expressions_test",
srcs = ["logical_expressions_test.py"],
deps = [
diff --git a/tensorflow/contrib/py2tf/convert/builtin_functions.py b/tensorflow/contrib/py2tf/convert/builtin_functions.py
new file mode 100644
index 0000000000..b80c96c97a
--- /dev/null
+++ b/tensorflow/contrib/py2tf/convert/builtin_functions.py
@@ -0,0 +1,54 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Handles builtins and other special functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.contrib.py2tf.pyct import templates
+
+
+class BuiltinFunctionTransformer(gast.NodeTransformer):
+ """Transforms Print nodes to Call so they can be handled as functions."""
+
+ # TODO(mdan): Bring print_functions in here.
+
+ def _convert_len(self, node):
+
+ def template(args):
+ tf.shape(args)[0] # pylint:disable=undefined-variable,expression-not-assigned
+
+ new_call = templates.replace(template, args=node.args)[0].value
+ return new_call
+
+ # pylint:disable=invalid-name
+
+ def visit_Call(self, node):
+ self.generic_visit(node)
+ # TODO(mdan): This won't work if the function was hidden.
+ if isinstance(node.func, gast.Name) and node.func.id == 'len':
+ return self._convert_len(node)
+ return node
+
+ # pylint:enable=invalid-name
+
+
+def transform(node):
+ transformer = BuiltinFunctionTransformer()
+ node = transformer.visit(node)
+ return node
diff --git a/tensorflow/contrib/py2tf/convert/builtin_functions_test.py b/tensorflow/contrib/py2tf/convert/builtin_functions_test.py
new file mode 100644
index 0000000000..9a6517321c
--- /dev/null
+++ b/tensorflow/contrib/py2tf/convert/builtin_functions_test.py
@@ -0,0 +1,58 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for builtin_functions module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.py2tf.convert import builtin_functions
+from tensorflow.contrib.py2tf.pyct import compiler
+from tensorflow.contrib.py2tf.pyct import parser
+from tensorflow.contrib.py2tf.pyct.static_analysis import access
+from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
+from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class BuiltinFunctionsTest(test.TestCase):
+
+ def _parse_and_analyze(self, test_fn, namespace):
+ node = parser.parse_object(test_fn)
+ node = access.resolve(node)
+ node = live_values.resolve(node, namespace, {})
+ node = type_info.resolve(node, None)
+ return node
+
+ def test_len(self):
+
+ def test_fn(a):
+ return len(a)
+
+ node = self._parse_and_analyze(test_fn, {'len': len})
+ node = builtin_functions.transform(node)
+ result = compiler.ast_to_object(node)
+ setattr(result, 'tf', array_ops)
+
+ with self.test_session() as sess:
+ self.assertEqual(3,
+ sess.run(
+ result.test_fn(constant_op.constant([0, 0, 0]))))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/py2tf/convert/control_flow.py b/tensorflow/contrib/py2tf/convert/control_flow.py
index e15deb938c..40b6ba6cbb 100644
--- a/tensorflow/contrib/py2tf/convert/control_flow.py
+++ b/tensorflow/contrib/py2tf/convert/control_flow.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Identity converter. Useful for testing and diagnostic."""
+"""Handles control flow statements: while, if."""
from __future__ import absolute_import
from __future__ import division
@@ -48,17 +48,8 @@ class ControlFlowTransformer(gast.NodeTransformer):
# pylint:disable=invalid-name
- def _tuple_or_item(self, elts):
- elts = tuple(elts)
- if len(elts) == 1:
- return elts[0]
- return elts
-
- def _ast_tuple_or_item(self, elts, ctx):
- elts = list(elts)
- if len(elts) == 1:
- return elts[0]
- return gast.Tuple(elts, ctx)
+ def visit_For(self, node):
+ assert False, 'for statement should have been canonicalized at this point'
def visit_If(self, node):
raise NotImplementedError()
@@ -70,41 +61,38 @@ class ControlFlowTransformer(gast.NodeTransformer):
body_closure = tuple(body_scope.modified - body_scope.created)
def template(
- state_args, # pylint:disable=unused-argument
- state_locals,
- state_results, # pylint:disable=unused-argument
+ state, # pylint:disable=unused-argument
+ state_ast_tuple, # pylint:disable=unused-argument
test_name,
test, # pylint:disable=unused-argument
body_name,
- body,
- state_init):
+ body):
- def test_name(state_args): # pylint:disable=function-redefined,unused-argument
+ def test_name(state): # pylint:disable=function-redefined,unused-argument
return test
- def body_name(state_args): # pylint:disable=function-redefined,unused-argument
+ def body_name(state): # pylint:disable=function-redefined,unused-argument
body # pylint:disable=pointless-statement
- return state_locals
+ return state,
- state_results = tf.while_loop(test_name, body_name, [state_init]) # pylint:disable=undefined-variable
+ state_ast_tuple = tf.while_loop(test_name, body_name, [state]) # pylint:disable=undefined-variable
test_name = self.namer.new_symbol('loop_test', body_scope.used)
body_name = self.namer.new_symbol('loop_body', body_scope.used)
+ if len(body_closure) == 1:
+ state = gast.Name(body_closure[0], None, None)
+ state_ast_tuple = state
+ else:
+ state = tuple(gast.Name(n, None, None) for n in body_closure)
+ state_ast_tuple = gast.Tuple(state, None)
node = templates.replace(
template,
- state_args=self._tuple_or_item(
- gast.Name(n, gast.Param(), None) for n in body_closure),
- state_locals=self._ast_tuple_or_item(
- (gast.Name(n, gast.Load(), None) for n in body_closure),
- gast.Load()),
- state_results=self._ast_tuple_or_item(
- (gast.Name(n, gast.Store(), None) for n in body_closure),
- gast.Store()),
+ state=state,
+ state_ast_tuple=state_ast_tuple,
test_name=gast.Name(test_name, gast.Load(), None),
test=node.test,
body_name=gast.Name(body_name, gast.Load(), None),
- body=node.body,
- state_init=[gast.Name(n, gast.Load(), None) for n in body_closure])
+ body=node.body)
return node
diff --git a/tensorflow/contrib/py2tf/convert/for_canonicalization.py b/tensorflow/contrib/py2tf/convert/for_canonicalization.py
new file mode 100644
index 0000000000..eb31ac386f
--- /dev/null
+++ b/tensorflow/contrib/py2tf/convert/for_canonicalization.py
@@ -0,0 +1,65 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Canonicalizes for loops into while loops.
+
+This canonicalizer uses the len function on its argument. That should be
+converted to a tf.shape separately.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.contrib.py2tf.pyct import anno
+from tensorflow.contrib.py2tf.pyct import templates
+
+
+class ForLoopCanonicalizationTransformer(gast.NodeTransformer):
+ """Canonicalizes for loops (e.g. into while loops)."""
+
+ def __init__(self, namer):
+ self.namer = namer
+
+ def visit_For(self, node):
+ self.generic_visit(node)
+ body_scope = anno.getanno(node, 'body_scope')
+
+ # TODO(mdan): Distinguish between `for i in n` and `for i in range(n)`
+ # Or maybe we should replace range with tf.range?
+
+ def template(loop_iter, target, body, i, n): # pylint:disable=unused-argument
+ i = 0
+ n = len(loop_iter) # pylint:disable=undefined-variable
+ while i < n:
+ # TODO(mdan): Use TensorListFromTensor(loop_iter) here.
+ target = loop_iter[i]
+ body # pylint:disable=pointless-statement
+ i += 1
+
+ return templates.replace(
+ template,
+ loop_iter=node.iter,
+ target=node.target,
+ body=node.body,
+ i=gast.Name(self.namer.new_symbol('i', body_scope.used), None, None),
+ n=gast.Name(self.namer.new_symbol('n', body_scope.used), None, None))
+
+
+def transform(node, namer):
+ transformer = ForLoopCanonicalizationTransformer(namer)
+ node = transformer.visit(node)
+ return node
diff --git a/tensorflow/contrib/py2tf/convert/for_canonicalization_test.py b/tensorflow/contrib/py2tf/convert/for_canonicalization_test.py
new file mode 100644
index 0000000000..8de2d1a0f8
--- /dev/null
+++ b/tensorflow/contrib/py2tf/convert/for_canonicalization_test.py
@@ -0,0 +1,61 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for for_canonicalization module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.py2tf.convert import control_flow
+from tensorflow.contrib.py2tf.convert import for_canonicalization
+from tensorflow.contrib.py2tf.pyct import compiler
+from tensorflow.contrib.py2tf.pyct import parser
+from tensorflow.contrib.py2tf.pyct.static_analysis import access
+from tensorflow.python.platform import test
+
+
+class TestNamer(control_flow.SymbolNamer):
+
+ def new_symbol(self, name_root, _):
+ return name_root
+
+
+class ControlFlowTest(test.TestCase):
+
+ def _parse_and_analyze(self, test_fn, namespace):
+ node = parser.parse_object(test_fn)
+ node = access.resolve(node)
+ return node
+
+ def test_basic_for(self):
+
+ def test_fn(l):
+ s = 0
+ for e in l:
+ s += e
+ return s
+
+ node = self._parse_and_analyze(test_fn, {})
+ node = for_canonicalization.transform(node, TestNamer())
+ result = compiler.ast_to_object(node)
+
+ l = [1, 2, 3]
+ self.assertEqual(test_fn(l), result.test_fn(l))
+ l = []
+ self.assertEqual(test_fn(l), result.test_fn(l))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/py2tf/convert/side_effect_guards.py b/tensorflow/contrib/py2tf/convert/side_effect_guards.py
index 25cf422517..61b428fd8f 100644
--- a/tensorflow/contrib/py2tf/convert/side_effect_guards.py
+++ b/tensorflow/contrib/py2tf/convert/side_effect_guards.py
@@ -95,13 +95,11 @@ class SideEffectGuardTransformer(gast.NodeTransformer):
def _gate_symbols(self, guard_statement, guarded_args):
- def template(dst_args, src_args): # pylint:disable=unused-argument
- (dst_args,) = (tf.identity(a) for a in (src_args,)) # pylint:disable=undefined-variable
+ def template(args): # pylint:disable=unused-argument
+ (args,) = (tf.identity(a) for a in (args,)) # pylint:disable=undefined-variable
guards = templates.replace(
- template,
- dst_args=tuple(gast.Name(a, gast.Store(), None) for a in guarded_args),
- src_args=tuple(gast.Name(a, gast.Load(), None) for a in guarded_args))
+ template, args=tuple(gast.Name(a, None, None) for a in guarded_args))
guard_statement.body.extend(guards)
return guard_statement
@@ -134,7 +132,7 @@ class SideEffectGuardTransformer(gast.NodeTransformer):
statements = templates.replace(
template,
call=node.value,
- temp_result=gast.Name(temp_name, gast.Store(), None))
+ temp_result=gast.Name(temp_name, None, None))
control_deps_guard = statements[-1]
control_deps_guard.body = []
diff --git a/tensorflow/contrib/py2tf/pyct/anno.py b/tensorflow/contrib/py2tf/pyct/anno.py
index 567195fb7e..889e4ba4ff 100644
--- a/tensorflow/contrib/py2tf/pyct/anno.py
+++ b/tensorflow/contrib/py2tf/pyct/anno.py
@@ -33,7 +33,6 @@ def hasanno(node, key, field_name='___pyct_anno'):
def setanno(node, key, value, field_name='___pyct_anno'):
annotations = getattr(node, field_name, {})
setattr(node, field_name, annotations)
- assert not hasanno(node, key, field_name), (node, key)
annotations[key] = value
# So that the annotations survive gast_to_ast() and ast_to_gast()
diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/access.py b/tensorflow/contrib/py2tf/pyct/static_analysis/access.py
index 7a27473e44..25409c63ba 100644
--- a/tensorflow/contrib/py2tf/pyct/static_analysis/access.py
+++ b/tensorflow/contrib/py2tf/pyct/static_analysis/access.py
@@ -129,22 +129,28 @@ class AccessResolver(gast.NodeTransformer):
self.visit(node.func)
return node
+ def _process_block_node(self, node, block, scope_name):
+ current_scope = self.scope
+ anno.setanno(node, '%s_parent_scope' % scope_name, current_scope)
+ block_scope = Scope(current_scope, isolated=False)
+ self.scope = block_scope
+ for n in block:
+ self.visit(n)
+ anno.setanno(node, '%s_scope' % scope_name, block_scope)
+ self.scope = current_scope
+ return node
+
def visit_For(self, node):
- raise NotImplementedError()
+ self.visit(node.target)
+ self.visit(node.iter)
+ node = self._process_block_node(node, node.body, 'body')
+ node = self._process_block_node(node, node.orelse, 'orelse')
+ return node
def visit_While(self, node):
self.visit(node.test)
- current_scope = self.scope
- anno.setanno(node, 'parent_scope', current_scope)
- body_scope = Scope(current_scope, isolated=False)
- self.scope = body_scope
- for n in node.body:
- self.visit(n)
- anno.setanno(node, 'body_scope', body_scope)
- if node.orelse:
- raise NotImplementedError()
- # TODO(mdan): Add support for orelse.
- self.scope = current_scope
+ node = self._process_block_node(node, node.body, 'body')
+ node = self._process_block_node(node, node.orelse, 'orelse')
return node
diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/access_test.py b/tensorflow/contrib/py2tf/pyct/static_analysis/access_test.py
index 8fcccc84a7..5f1c45c6c3 100644
--- a/tensorflow/contrib/py2tf/pyct/static_analysis/access_test.py
+++ b/tensorflow/contrib/py2tf/pyct/static_analysis/access_test.py
@@ -137,7 +137,7 @@ class AccessResolverTest(test.TestCase):
while_node = node.body[0].body[1]
while_body_scope = anno.getanno(while_node, 'body_scope')
- while_parent_scope = anno.getanno(while_node, 'parent_scope')
+ while_parent_scope = anno.getanno(while_node, 'body_parent_scope')
self.assertItemsEqual(['b'], while_body_scope.used)
self.assertItemsEqual(['b', 'c'], while_body_scope.modified)
@@ -147,6 +147,30 @@ class AccessResolverTest(test.TestCase):
self.assertItemsEqual(['a', 'b', 'c'], while_parent_scope.modified)
self.assertItemsEqual(['a', 'b', 'c'], while_parent_scope.created)
+ def test_for(self):
+
+ def test_fn(a):
+ b = a
+ for _ in a:
+ c = b
+ b -= 1
+ return b, c
+
+ node = parser.parse_object(test_fn)
+ node = access.resolve(node)
+
+ for_node = node.body[0].body[1]
+ for_body_scope = anno.getanno(for_node, 'body_scope')
+ for_parent_scope = anno.getanno(for_node, 'body_parent_scope')
+
+ self.assertItemsEqual(['b'], for_body_scope.used)
+ self.assertItemsEqual(['b', 'c'], for_body_scope.modified)
+ self.assertItemsEqual(['c'], for_body_scope.created)
+
+ self.assertItemsEqual(['a', 'b', 'c'], for_parent_scope.used)
+ self.assertItemsEqual(['a', 'b', 'c', '_'], for_parent_scope.modified)
+ self.assertItemsEqual(['a', 'b', 'c', '_'], for_parent_scope.created)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/py2tf/pyct/templates.py b/tensorflow/contrib/py2tf/pyct/templates.py
index 6acc03bfce..4fadc793e6 100644
--- a/tensorflow/contrib/py2tf/pyct/templates.py
+++ b/tensorflow/contrib/py2tf/pyct/templates.py
@@ -22,6 +22,7 @@ from __future__ import division
from __future__ import print_function
import ast
+import copy
import gast
@@ -61,14 +62,17 @@ class ReplaceTransformer(gast.NodeTransformer):
return node
def visit_Name(self, node):
- # Note: The caller is reposnsible with making sure the replacement
- # Name nodes have the proper ctx set up.
- # TODO(mdan): Is it possible to always infer the proper context here?
if node.id in self.replacements:
# TODO(mdan): Sanitize the nodes by erasing scope-dependent annotations.
- new_nodes = self.replacements[node.id]
+ new_nodes = copy.copy(self.replacements[node.id])
if isinstance(new_nodes, gast.AST):
new_nodes = [new_nodes]
+ # Preserve the target context.
+ for n in new_nodes:
+ if isinstance(n, gast.Tuple):
+ for e in n.elts:
+ e.ctx = node.ctx
+ n.ctx = node.ctx
if len(new_nodes) == 1:
new_nodes, = new_nodes
return new_nodes