aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-23 13:25:09 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-23 13:38:47 -0700
commit6fe361f80d4277ea879b3182e1d7148a65a8ca21 (patch)
tree506b5bec3d843f3b90568a2e0990b030fe1d48fe /tensorflow/contrib/autograph
parent0a427ca13e52bc121f2d42f21c65e6f03a520a1a (diff)
Allow Autograph tuple unpacking in for loops.
PiperOrigin-RevId: 209988449
Diffstat (limited to 'tensorflow/contrib/autograph')
-rw-r--r--tensorflow/contrib/autograph/converters/control_flow.py4
-rw-r--r--tensorflow/contrib/autograph/converters/control_flow_test.py8
2 files changed, 11 insertions, 1 deletions
diff --git a/tensorflow/contrib/autograph/converters/control_flow.py b/tensorflow/contrib/autograph/converters/control_flow.py
index f7dd3183b0..8d314250a0 100644
--- a/tensorflow/contrib/autograph/converters/control_flow.py
+++ b/tensorflow/contrib/autograph/converters/control_flow.py
@@ -310,7 +310,9 @@ class ControlFlowTransformer(converter.Base):
template = """
def extra_test_name(state_ssf):
return extra_test_expr
- def body_name(iterate, state_ssf):
+ def body_name(loop_vars, state_ssf):
+ # Workaround for PEP-3113
+ iterate = loop_vars
body
return state_ssf,
state_ast_tuple = ag__.for_stmt(
diff --git a/tensorflow/contrib/autograph/converters/control_flow_test.py b/tensorflow/contrib/autograph/converters/control_flow_test.py
index 02bc00dbc8..2a6f3cb395 100644
--- a/tensorflow/contrib/autograph/converters/control_flow_test.py
+++ b/tensorflow/contrib/autograph/converters/control_flow_test.py
@@ -217,5 +217,13 @@ class ControlFlowTest(converter_testing.TestCase):
with self.assertRaises(transformer.AutographParseError):
control_flow.transform(node, ctx)
+ def test_for_tuple_unpacking(self):
+ def test_fn(x_list):
+ z = tf.constant(0) # pylint:disable=undefined-variable
+ for i, x in enumerate(x_list):
+ z = z + x + i
+ return z
+
+ self.assertTransformedResult(test_fn, [3, 3], 7)
if __name__ == '__main__':
test.main()