diff options
author | 2018-01-11 11:09:48 -0800 | |
---|---|---|
committer | 2018-01-11 11:13:27 -0800 | |
commit | 453f94412f908aadd21561c14feae80dfac1e933 (patch) | |
tree | 163730dfebfaceaa286501a1f067244bc6685a33 | |
parent | 46d6620e4b7f71b420c58df90e7ceb89609ac85a (diff) |
Add support for `for` loops.
Generalize the static analysis across `while` and `for` loops.
Convert the `len` builtin to `tf.shape()[0]`.
Add `for` loop canonicalization and companion tests.
Modify the template behavior for Name nodes to let the template control the target, which allows simplifying the caller.
PiperOrigin-RevId: 181633983
-rw-r--r-- | tensorflow/contrib/py2tf/conversion.py | 16 | ||||
-rw-r--r-- | tensorflow/contrib/py2tf/convert/BUILD | 24 | ||||
-rw-r--r-- | tensorflow/contrib/py2tf/convert/builtin_functions.py | 54 | ||||
-rw-r--r-- | tensorflow/contrib/py2tf/convert/builtin_functions_test.py | 58 | ||||
-rw-r--r-- | tensorflow/contrib/py2tf/convert/control_flow.py | 50 | ||||
-rw-r--r-- | tensorflow/contrib/py2tf/convert/for_canonicalization.py | 65 | ||||
-rw-r--r-- | tensorflow/contrib/py2tf/convert/for_canonicalization_test.py | 61 | ||||
-rw-r--r-- | tensorflow/contrib/py2tf/convert/side_effect_guards.py | 10 | ||||
-rw-r--r-- | tensorflow/contrib/py2tf/pyct/anno.py | 1 | ||||
-rw-r--r-- | tensorflow/contrib/py2tf/pyct/static_analysis/access.py | 30 | ||||
-rw-r--r-- | tensorflow/contrib/py2tf/pyct/static_analysis/access_test.py | 26 | ||||
-rw-r--r-- | tensorflow/contrib/py2tf/pyct/templates.py | 12 |
12 files changed, 352 insertions, 55 deletions
diff --git a/tensorflow/contrib/py2tf/conversion.py b/tensorflow/contrib/py2tf/conversion.py index 12dd70e497..40b9c3369d 100644 --- a/tensorflow/contrib/py2tf/conversion.py +++ b/tensorflow/contrib/py2tf/conversion.py @@ -22,8 +22,10 @@ import six from tensorflow.contrib.py2tf import config from tensorflow.contrib.py2tf import naming +from tensorflow.contrib.py2tf.convert import builtin_functions from tensorflow.contrib.py2tf.convert import call_trees from tensorflow.contrib.py2tf.convert import control_flow +from tensorflow.contrib.py2tf.convert import for_canonicalization from tensorflow.contrib.py2tf.convert import logical_expressions from tensorflow.contrib.py2tf.convert import print_functions from tensorflow.contrib.py2tf.convert import side_effect_guards @@ -151,6 +153,20 @@ def node_to_graph(node, namer, namespace, value_hints): # * keeping track of symbols that have been created # * marking nodes (e.g. py_func wrappers) to suppress further processing + node = for_canonicalization.transform(node, namer) + node = builtin_functions.transform(node) + + # The transformation steps above insert new variables. Although less + # efficient, it is most robust to re-run the analysis. + # We also need to ensure the namespace contains any new references that may + # have been created. 
+ namespace['len'] = len + namespace['print'] = print + + node = access.resolve(node) + node = live_values.resolve(node, namespace, config.PYTHON_LITERALS) + node = type_info.resolve(node, value_hints) + node = print_functions.transform(node) node = call_trees.transform(node, namer, config.DEFAULT_UNCOMPILED_MODULES) node = control_flow.transform(node, namer) diff --git a/tensorflow/contrib/py2tf/convert/BUILD b/tensorflow/contrib/py2tf/convert/BUILD index ddbf336947..ebe720501b 100644 --- a/tensorflow/contrib/py2tf/convert/BUILD +++ b/tensorflow/contrib/py2tf/convert/BUILD @@ -17,8 +17,10 @@ filegroup( py_library( name = "convert", srcs = [ + "builtin_functions.py", "call_trees.py", "control_flow.py", + "for_canonicalization.py", "logical_expressions.py", "print_functions.py", "side_effect_guards.py", @@ -53,6 +55,28 @@ py_test( ) py_test( + name = "builtin_functions_test", + srcs = ["builtin_functions_test.py"], + deps = [ + ":convert", + "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/contrib/py2tf/pyct/static_analysis", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "for_canonicalization_test", + srcs = ["for_canonicalization_test.py"], + deps = [ + ":convert", + "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/contrib/py2tf/pyct/static_analysis", + "//tensorflow/python:client_testlib", + ], +) + +py_test( name = "logical_expressions_test", srcs = ["logical_expressions_test.py"], deps = [ diff --git a/tensorflow/contrib/py2tf/convert/builtin_functions.py b/tensorflow/contrib/py2tf/convert/builtin_functions.py new file mode 100644 index 0000000000..b80c96c97a --- /dev/null +++ b/tensorflow/contrib/py2tf/convert/builtin_functions.py @@ -0,0 +1,54 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Handles builtins and other special functions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.py2tf.pyct import templates + + +class BuiltinFunctionTransformer(gast.NodeTransformer): + """Transforms Print nodes to Call so they can be handled as functions.""" + + # TODO(mdan): Bring print_functions in here. + + def _convert_len(self, node): + + def template(args): + tf.shape(args)[0] # pylint:disable=undefined-variable,expression-not-assigned + + new_call = templates.replace(template, args=node.args)[0].value + return new_call + + # pylint:disable=invalid-name + + def visit_Call(self, node): + self.generic_visit(node) + # TODO(mdan): This won't work if the function was hidden. + if isinstance(node.func, gast.Name) and node.func.id == 'len': + return self._convert_len(node) + return node + + # pylint:enable=invalid-name + + +def transform(node): + transformer = BuiltinFunctionTransformer() + node = transformer.visit(node) + return node diff --git a/tensorflow/contrib/py2tf/convert/builtin_functions_test.py b/tensorflow/contrib/py2tf/convert/builtin_functions_test.py new file mode 100644 index 0000000000..9a6517321c --- /dev/null +++ b/tensorflow/contrib/py2tf/convert/builtin_functions_test.py @@ -0,0 +1,58 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for builtin_functions module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.py2tf.convert import builtin_functions +from tensorflow.contrib.py2tf.pyct import compiler +from tensorflow.contrib.py2tf.pyct import parser +from tensorflow.contrib.py2tf.pyct.static_analysis import access +from tensorflow.contrib.py2tf.pyct.static_analysis import live_values +from tensorflow.contrib.py2tf.pyct.static_analysis import type_info +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class BuiltinFunctionsTest(test.TestCase): + + def _parse_and_analyze(self, test_fn, namespace): + node = parser.parse_object(test_fn) + node = access.resolve(node) + node = live_values.resolve(node, namespace, {}) + node = type_info.resolve(node, None) + return node + + def test_len(self): + + def test_fn(a): + return len(a) + + node = self._parse_and_analyze(test_fn, {'len': len}) + node = builtin_functions.transform(node) + result = compiler.ast_to_object(node) + setattr(result, 'tf', array_ops) + + with self.test_session() as sess: + self.assertEqual(3, + sess.run( + result.test_fn(constant_op.constant([0, 0, 0])))) + + +if __name__ == '__main__': + test.main() diff --git 
a/tensorflow/contrib/py2tf/convert/control_flow.py b/tensorflow/contrib/py2tf/convert/control_flow.py index e15deb938c..40b6ba6cbb 100644 --- a/tensorflow/contrib/py2tf/convert/control_flow.py +++ b/tensorflow/contrib/py2tf/convert/control_flow.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Identity converter. Useful for testing and diagnostic.""" +"""Handles control flow statements: while, if.""" from __future__ import absolute_import from __future__ import division @@ -48,17 +48,8 @@ class ControlFlowTransformer(gast.NodeTransformer): # pylint:disable=invalid-name - def _tuple_or_item(self, elts): - elts = tuple(elts) - if len(elts) == 1: - return elts[0] - return elts - - def _ast_tuple_or_item(self, elts, ctx): - elts = list(elts) - if len(elts) == 1: - return elts[0] - return gast.Tuple(elts, ctx) + def visit_For(self, node): + assert False, 'for statement should have been canonicalized at this point' def visit_If(self, node): raise NotImplementedError() @@ -70,41 +61,38 @@ class ControlFlowTransformer(gast.NodeTransformer): body_closure = tuple(body_scope.modified - body_scope.created) def template( - state_args, # pylint:disable=unused-argument - state_locals, - state_results, # pylint:disable=unused-argument + state, # pylint:disable=unused-argument + state_ast_tuple, # pylint:disable=unused-argument test_name, test, # pylint:disable=unused-argument body_name, - body, - state_init): + body): - def test_name(state_args): # pylint:disable=function-redefined,unused-argument + def test_name(state): # pylint:disable=function-redefined,unused-argument return test - def body_name(state_args): # pylint:disable=function-redefined,unused-argument + def body_name(state): # pylint:disable=function-redefined,unused-argument body # pylint:disable=pointless-statement - return state_locals + return state, - 
state_results = tf.while_loop(test_name, body_name, [state_init]) # pylint:disable=undefined-variable + state_ast_tuple = tf.while_loop(test_name, body_name, [state]) # pylint:disable=undefined-variable test_name = self.namer.new_symbol('loop_test', body_scope.used) body_name = self.namer.new_symbol('loop_body', body_scope.used) + if len(body_closure) == 1: + state = gast.Name(body_closure[0], None, None) + state_ast_tuple = state + else: + state = tuple(gast.Name(n, None, None) for n in body_closure) + state_ast_tuple = gast.Tuple(state, None) node = templates.replace( template, - state_args=self._tuple_or_item( - gast.Name(n, gast.Param(), None) for n in body_closure), - state_locals=self._ast_tuple_or_item( - (gast.Name(n, gast.Load(), None) for n in body_closure), - gast.Load()), - state_results=self._ast_tuple_or_item( - (gast.Name(n, gast.Store(), None) for n in body_closure), - gast.Store()), + state=state, + state_ast_tuple=state_ast_tuple, test_name=gast.Name(test_name, gast.Load(), None), test=node.test, body_name=gast.Name(body_name, gast.Load(), None), - body=node.body, - state_init=[gast.Name(n, gast.Load(), None) for n in body_closure]) + body=node.body) return node diff --git a/tensorflow/contrib/py2tf/convert/for_canonicalization.py b/tensorflow/contrib/py2tf/convert/for_canonicalization.py new file mode 100644 index 0000000000..eb31ac386f --- /dev/null +++ b/tensorflow/contrib/py2tf/convert/for_canonicalization.py @@ -0,0 +1,65 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Canonicalizes for loops into while loops. + +This canonicalizer uses the len function on its argument. That should be +converted to a tf.shape separately. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.py2tf.pyct import anno +from tensorflow.contrib.py2tf.pyct import templates + + +class ForLoopCanonicalizationTransformer(gast.NodeTransformer): + """Canonicalizes for loops (e.g. into while loops).""" + + def __init__(self, namer): + self.namer = namer + + def visit_For(self, node): + self.generic_visit(node) + body_scope = anno.getanno(node, 'body_scope') + + # TODO(mdan): Distinguish between `for i in n` and `for i in range(n)` + # Or maybe we should replace range with tf.range? + + def template(loop_iter, target, body, i, n): # pylint:disable=unused-argument + i = 0 + n = len(loop_iter) # pylint:disable=undefined-variable + while i < n: + # TODO(mdan): Use TensorListFromTensor(loop_iter) here. 
+ target = loop_iter[i] + body # pylint:disable=pointless-statement + i += 1 + + return templates.replace( + template, + loop_iter=node.iter, + target=node.target, + body=node.body, + i=gast.Name(self.namer.new_symbol('i', body_scope.used), None, None), + n=gast.Name(self.namer.new_symbol('n', body_scope.used), None, None)) + + +def transform(node, namer): + transformer = ForLoopCanonicalizationTransformer(namer) + node = transformer.visit(node) + return node diff --git a/tensorflow/contrib/py2tf/convert/for_canonicalization_test.py b/tensorflow/contrib/py2tf/convert/for_canonicalization_test.py new file mode 100644 index 0000000000..8de2d1a0f8 --- /dev/null +++ b/tensorflow/contrib/py2tf/convert/for_canonicalization_test.py @@ -0,0 +1,61 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for for_canonicalization module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.py2tf.convert import control_flow +from tensorflow.contrib.py2tf.convert import for_canonicalization +from tensorflow.contrib.py2tf.pyct import compiler +from tensorflow.contrib.py2tf.pyct import parser +from tensorflow.contrib.py2tf.pyct.static_analysis import access +from tensorflow.python.platform import test + + +class TestNamer(control_flow.SymbolNamer): + + def new_symbol(self, name_root, _): + return name_root + + +class ControlFlowTest(test.TestCase): + + def _parse_and_analyze(self, test_fn, namespace): + node = parser.parse_object(test_fn) + node = access.resolve(node) + return node + + def test_basic_for(self): + + def test_fn(l): + s = 0 + for e in l: + s += e + return s + + node = self._parse_and_analyze(test_fn, {}) + node = for_canonicalization.transform(node, TestNamer()) + result = compiler.ast_to_object(node) + + l = [1, 2, 3] + self.assertEqual(test_fn(l), result.test_fn(l)) + l = [] + self.assertEqual(test_fn(l), result.test_fn(l)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/convert/side_effect_guards.py b/tensorflow/contrib/py2tf/convert/side_effect_guards.py index 25cf422517..61b428fd8f 100644 --- a/tensorflow/contrib/py2tf/convert/side_effect_guards.py +++ b/tensorflow/contrib/py2tf/convert/side_effect_guards.py @@ -95,13 +95,11 @@ class SideEffectGuardTransformer(gast.NodeTransformer): def _gate_symbols(self, guard_statement, guarded_args): - def template(dst_args, src_args): # pylint:disable=unused-argument - (dst_args,) = (tf.identity(a) for a in (src_args,)) # pylint:disable=undefined-variable + def template(args): # pylint:disable=unused-argument + (args,) = (tf.identity(a) for a in (args,)) # pylint:disable=undefined-variable guards 
= templates.replace( - template, - dst_args=tuple(gast.Name(a, gast.Store(), None) for a in guarded_args), - src_args=tuple(gast.Name(a, gast.Load(), None) for a in guarded_args)) + template, args=tuple(gast.Name(a, None, None) for a in guarded_args)) guard_statement.body.extend(guards) return guard_statement @@ -134,7 +132,7 @@ class SideEffectGuardTransformer(gast.NodeTransformer): statements = templates.replace( template, call=node.value, - temp_result=gast.Name(temp_name, gast.Store(), None)) + temp_result=gast.Name(temp_name, None, None)) control_deps_guard = statements[-1] control_deps_guard.body = [] diff --git a/tensorflow/contrib/py2tf/pyct/anno.py b/tensorflow/contrib/py2tf/pyct/anno.py index 567195fb7e..889e4ba4ff 100644 --- a/tensorflow/contrib/py2tf/pyct/anno.py +++ b/tensorflow/contrib/py2tf/pyct/anno.py @@ -33,7 +33,6 @@ def hasanno(node, key, field_name='___pyct_anno'): def setanno(node, key, value, field_name='___pyct_anno'): annotations = getattr(node, field_name, {}) setattr(node, field_name, annotations) - assert not hasanno(node, key, field_name), (node, key) annotations[key] = value # So that the annotations survive gast_to_ast() and ast_to_gast() diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/access.py b/tensorflow/contrib/py2tf/pyct/static_analysis/access.py index 7a27473e44..25409c63ba 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/access.py +++ b/tensorflow/contrib/py2tf/pyct/static_analysis/access.py @@ -129,22 +129,28 @@ class AccessResolver(gast.NodeTransformer): self.visit(node.func) return node + def _process_block_node(self, node, block, scope_name): + current_scope = self.scope + anno.setanno(node, '%s_parent_scope' % scope_name, current_scope) + block_scope = Scope(current_scope, isolated=False) + self.scope = block_scope + for n in block: + self.visit(n) + anno.setanno(node, '%s_scope' % scope_name, block_scope) + self.scope = current_scope + return node + def visit_For(self, node): - raise 
NotImplementedError() + self.visit(node.target) + self.visit(node.iter) + node = self._process_block_node(node, node.body, 'body') + node = self._process_block_node(node, node.orelse, 'orelse') + return node def visit_While(self, node): self.visit(node.test) - current_scope = self.scope - anno.setanno(node, 'parent_scope', current_scope) - body_scope = Scope(current_scope, isolated=False) - self.scope = body_scope - for n in node.body: - self.visit(n) - anno.setanno(node, 'body_scope', body_scope) - if node.orelse: - raise NotImplementedError() - # TODO(mdan): Add support for orelse. - self.scope = current_scope + node = self._process_block_node(node, node.body, 'body') + node = self._process_block_node(node, node.orelse, 'orelse') return node diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/access_test.py b/tensorflow/contrib/py2tf/pyct/static_analysis/access_test.py index 8fcccc84a7..5f1c45c6c3 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/access_test.py +++ b/tensorflow/contrib/py2tf/pyct/static_analysis/access_test.py @@ -137,7 +137,7 @@ class AccessResolverTest(test.TestCase): while_node = node.body[0].body[1] while_body_scope = anno.getanno(while_node, 'body_scope') - while_parent_scope = anno.getanno(while_node, 'parent_scope') + while_parent_scope = anno.getanno(while_node, 'body_parent_scope') self.assertItemsEqual(['b'], while_body_scope.used) self.assertItemsEqual(['b', 'c'], while_body_scope.modified) @@ -147,6 +147,30 @@ class AccessResolverTest(test.TestCase): self.assertItemsEqual(['a', 'b', 'c'], while_parent_scope.modified) self.assertItemsEqual(['a', 'b', 'c'], while_parent_scope.created) + def test_for(self): + + def test_fn(a): + b = a + for _ in a: + c = b + b -= 1 + return b, c + + node = parser.parse_object(test_fn) + node = access.resolve(node) + + for_node = node.body[0].body[1] + for_body_scope = anno.getanno(for_node, 'body_scope') + for_parent_scope = anno.getanno(for_node, 'body_parent_scope') + + 
self.assertItemsEqual(['b'], for_body_scope.used) + self.assertItemsEqual(['b', 'c'], for_body_scope.modified) + self.assertItemsEqual(['c'], for_body_scope.created) + + self.assertItemsEqual(['a', 'b', 'c'], for_parent_scope.used) + self.assertItemsEqual(['a', 'b', 'c', '_'], for_parent_scope.modified) + self.assertItemsEqual(['a', 'b', 'c', '_'], for_parent_scope.created) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/py2tf/pyct/templates.py b/tensorflow/contrib/py2tf/pyct/templates.py index 6acc03bfce..4fadc793e6 100644 --- a/tensorflow/contrib/py2tf/pyct/templates.py +++ b/tensorflow/contrib/py2tf/pyct/templates.py @@ -22,6 +22,7 @@ from __future__ import division from __future__ import print_function import ast +import copy import gast @@ -61,14 +62,17 @@ class ReplaceTransformer(gast.NodeTransformer): return node def visit_Name(self, node): - # Note: The caller is reposnsible with making sure the replacement - # Name nodes have the proper ctx set up. - # TODO(mdan): Is it possible to always infer the proper context here? if node.id in self.replacements: # TODO(mdan): Sanitize the nodes by erasing scope-dependent annotations. - new_nodes = self.replacements[node.id] + new_nodes = copy.copy(self.replacements[node.id]) if isinstance(new_nodes, gast.AST): new_nodes = [new_nodes] + # Preserve the target context. + for n in new_nodes: + if isinstance(n, gast.Tuple): + for e in n.elts: + e.ctx = node.ctx + n.ctx = node.ctx if len(new_nodes) == 1: new_nodes, = new_nodes return new_nodes |