diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-01-11 17:35:20 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-11 17:38:48 -0800 |
commit | 6ee404d17929c613b217400406e7e665010ebf18 (patch) | |
tree | 5f94c1a130368be770eaec4e1fd52a7d9a34115f | |
parent | a3f47ad628be2e590bcbfadf93a08446d3af6ef2 (diff) |
Add support for if conditionals. Fix a bug in the activity analysis.
PiperOrigin-RevId: 181686453
7 files changed, 256 insertions, 42 deletions
diff --git a/tensorflow/contrib/py2tf/convert/control_flow.py b/tensorflow/contrib/py2tf/convert/control_flow.py index 40b6ba6cbb..8ebd9ad93d 100644 --- a/tensorflow/contrib/py2tf/convert/control_flow.py +++ b/tensorflow/contrib/py2tf/convert/control_flow.py @@ -40,6 +40,17 @@ class SymbolNamer(object): raise NotImplementedError() +class SymbolRenamer(gast.NodeTransformer): + + def __init__(self, name_map): + self.name_map = name_map + + def visit_Name(self, node): + if node.id in self.name_map: + node.id = self.name_map[node.id] + return node + + class ControlFlowTransformer(gast.NodeTransformer): """Transforms control flow structures like loops an conditionals.""" @@ -52,11 +63,84 @@ class ControlFlowTransformer(gast.NodeTransformer): assert False, 'for statement should have been canonicalized at this point' def visit_If(self, node): - raise NotImplementedError() + self.generic_visit(node) + + body_scope = anno.getanno(node, 'body_scope') + orelse_scope = anno.getanno(node, 'orelse_scope') + + if body_scope.created - orelse_scope.created: + raise ValueError( + 'The if branch creates new symbols that the else branch does not.') + if orelse_scope.created - body_scope.created: + raise ValueError( + 'The else branch creates new symbols that the if branch does not.') + + def template( # pylint:disable=missing-docstring + test, + body_name, + body, + orelse_name, + orelse, + aliased, + aliases, # pylint:disable=unused-argument + aliased_results, + results): # pylint:disable=unused-argument + + def body_name(): # pylint:disable=function-redefined + aliases, = aliased, # pylint:disable=unused-variable + body # pylint:disable=pointless-statement + return (aliased_results,) + + def orelse_name(): # pylint:disable=function-redefined + aliases, = aliased, # pylint:disable=unused-variable + orelse # pylint:disable=pointless-statement + return (aliased_results,) + + results = tf.cond(test, body_name, orelse_name) # pylint:disable=undefined-variable + + all_modified = tuple(body_scope.modified | orelse_scope.modified) + all_referenced = body_scope.referenced | orelse_scope.referenced + + # Alias the closure variables inside the conditional functions + # to avoid errors caused by the local variables created in the branch + # functions. + need_alias = ( + (body_scope.modified | orelse_scope.modified) - + (body_scope.created | orelse_scope.created)) + aliased = tuple(need_alias) + aliases = tuple( + self.namer.new_symbol(s, all_referenced) for s in aliased) + alias_map = dict(zip(aliased, aliases)) + node_body = node.body + node_body = [SymbolRenamer(alias_map).visit(n) for n in node_body] + node_orelse = node.orelse + node_orelse = [SymbolRenamer(alias_map).visit(n) for n in node_orelse] + + if len(all_modified) == 1: + results = gast.Name(all_modified[0], None, None) + else: + results = gast.Tuple( + tuple(gast.Name(s, None, None) for s in all_modified), None) + + return templates.replace( + template, + test=node.test, + body_name=gast.Name( + self.namer.new_symbol('if_true', all_referenced), None, None), + body=node_body, + orelse_name=gast.Name( + self.namer.new_symbol('if_false', all_referenced), None, None), + orelse=node_orelse, + aliased=tuple(gast.Name(s, None, None) for s in aliased), + aliases=tuple(gast.Name(s, None, None) for s in aliases), + aliased_results=tuple( + gast.Name(alias_map[s] if s in aliased else s, None, None) + for s in all_modified), + results=results) def visit_While(self, node): self.generic_visit(node) - # Scrape out the data flow analysis + body_scope = anno.getanno(node, 'body_scope') body_closure = tuple(body_scope.modified - body_scope.created) @@ -77,8 +161,8 @@ class ControlFlowTransformer(gast.NodeTransformer): state_ast_tuple = tf.while_loop(test_name, body_name, [state]) # pylint:disable=undefined-variable - test_name = self.namer.new_symbol('loop_test', body_scope.used) - body_name = self.namer.new_symbol('loop_body', body_scope.used) + test_name = self.namer.new_symbol('loop_test', body_scope.referenced) + body_name = self.namer.new_symbol('loop_body', body_scope.referenced) if len(body_closure) == 1: state = gast.Name(body_closure[0], None, None) state_ast_tuple = state diff --git a/tensorflow/contrib/py2tf/convert/control_flow_test.py b/tensorflow/contrib/py2tf/convert/control_flow_test.py index c27a079546..51237a291d 100644 --- a/tensorflow/contrib/py2tf/convert/control_flow_test.py +++ b/tensorflow/contrib/py2tf/convert/control_flow_test.py @@ -31,8 +31,13 @@ from tensorflow.python.platform import test class TestNamer(control_flow.SymbolNamer): - def new_symbol(self, name_root, _): - return name_root + def new_symbol(self, name_root, used): + i = 0 + while True: + name = '%s%d' % (name_root, i) + if name not in used: + return name + i += 1 class ControlFlowTest(test.TestCase): @@ -78,6 +83,43 @@ class ControlFlowTest(test.TestCase): with self.test_session() as sess: self.assertEqual(0, sess.run(result.test_fn(constant_op.constant(5)))) + def test_simple_if(self): + + def test_fn(n): + a = 0 + b = 0 + if n > 0: + a = -n + else: + b = 2 * n + return a, b + + node = self._parse_and_analyze(test_fn, {}) + node = control_flow.transform(node, TestNamer()) + result = compiler.ast_to_object(node) + setattr(result, 'tf', control_flow_ops) + + with self.test_session() as sess: + self.assertEqual((-1, 0), sess.run( + result.test_fn(constant_op.constant(1)))) + self.assertEqual((0, -2), + sess.run(result.test_fn(constant_op.constant(-1)))) + + def test_if_single_var(self): + + def test_fn(n): + if n > 0: + n = -n + return n + + node = self._parse_and_analyze(test_fn, {}) + node = control_flow.transform(node, TestNamer()) + result = compiler.ast_to_object(node) + setattr(result, 'tf', control_flow_ops) + + with self.test_session() as sess: + self.assertEqual(-1, sess.run(result.test_fn(constant_op.constant(1)))) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/py2tf/convert/for_canonicalization.py b/tensorflow/contrib/py2tf/convert/for_canonicalization.py index eb31ac386f..c51a2326a6 100644 --- a/tensorflow/contrib/py2tf/convert/for_canonicalization.py +++ b/tensorflow/contrib/py2tf/convert/for_canonicalization.py @@ -55,8 +55,10 @@ class ForLoopCanonicalizationTransformer(gast.NodeTransformer): loop_iter=node.iter, target=node.target, body=node.body, - i=gast.Name(self.namer.new_symbol('i', body_scope.used), None, None), - n=gast.Name(self.namer.new_symbol('n', body_scope.used), None, None)) + i=gast.Name( + self.namer.new_symbol('i', body_scope.referenced), None, None), + n=gast.Name( + self.namer.new_symbol('n', body_scope.referenced), None, None)) def transform(node, namer): diff --git a/tensorflow/contrib/py2tf/convert/logical_expressions.py b/tensorflow/contrib/py2tf/convert/logical_expressions.py index c2f27a5f10..cfa23627fa 100644 --- a/tensorflow/contrib/py2tf/convert/logical_expressions.py +++ b/tensorflow/contrib/py2tf/convert/logical_expressions.py @@ -34,10 +34,15 @@ class LogicalExpressionTransformer(gast.NodeTransformer): self.op_mapping = { gast.And: 'tf.logical_and', gast.Or: 'tf.logical_or', + gast.Not: 'tf.logical_not', } def visit_UnaryOp(self, node): - raise NotImplementedError() + if isinstance(node.op, gast.Not): + tf_function = parser.parse_str(self.op_mapping[type( + node.op)]).body[0].value + node = gast.Call(func=tf_function, args=[node.operand], keywords=[]) + return node def visit_BoolOp(self, node): # TODO(mdan): A normalizer may be useful here. Use ANF? diff --git a/tensorflow/contrib/py2tf/convert/side_effect_guards.py b/tensorflow/contrib/py2tf/convert/side_effect_guards.py index 61b428fd8f..1f25303fba 100644 --- a/tensorflow/contrib/py2tf/convert/side_effect_guards.py +++ b/tensorflow/contrib/py2tf/convert/side_effect_guards.py @@ -112,7 +112,7 @@ class SideEffectGuardTransformer(gast.NodeTransformer): # tf.py_func(...) args_scope = anno.getanno(node.value, 'args_scope') - temp_name = self.namer.new_symbol('temp', args_scope.parent.used) + temp_name = self.namer.new_symbol('temp', args_scope.parent.referenced) # TODO(mdan): Unsafe reference modification! args_scope.mark_write(temp_name) diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/access.py b/tensorflow/contrib/py2tf/pyct/static_analysis/access.py index 25409c63ba..05dc09eb04 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/access.py +++ b/tensorflow/contrib/py2tf/pyct/static_analysis/access.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import copy + import gast from tensorflow.contrib.py2tf.pyct import anno @@ -52,10 +54,24 @@ class Scope(object): self.created = set() self.used = set() + @property + def referenced(self): + return self.used | self.modified + def __repr__(self): return 'Scope{r=%s, c=%s, w=%s}' % (tuple(self.used), tuple(self.created), tuple(self.modified)) + def copy_from(self, other): + self.modified = copy.copy(other.modified) + self.created = copy.copy(other.created) + self.used = copy.copy(other.used) + + def merge_from(self, other): + self.modified |= other.modified + self.created |= other.created + self.used |= other.used + def has(self, name): if name in self.modified: return True @@ -131,7 +147,6 @@ class AccessResolver(gast.NodeTransformer): def _process_block_node(self, node, block, scope_name): current_scope = self.scope - anno.setanno(node, '%s_parent_scope' % scope_name, current_scope) block_scope = Scope(current_scope, isolated=False) self.scope = block_scope for n in block: @@ -140,17 +155,43 @@ class AccessResolver(gast.NodeTransformer): self.scope = current_scope return node + def _process_parallel_blocks(self, parent, children): + # Because the scopes are not isolated, processing any child block + # modifies the parent state causing the other child blocks to be + # processed incorrectly. So we need to checkpoint the parent scope so that + # each child sees the same context. + before_parent = Scope(None) + before_parent.copy_from(self.scope) + after_children = [] + for child, name in children: + self.scope.copy_from(before_parent) + parent = self._process_block_node(parent, child, name) + after_child = Scope(None) + after_child.copy_from(self.scope) + after_children.append(after_child) + for after_child in after_children: + self.scope.merge_from(after_child) + for child, name in children: + anno.setanno(parent, '%s_parent_scope' % name, self.scope) + return parent + + def visit_If(self, node): + self.visit(node.test) + node = self._process_parallel_blocks( + node, ((node.body, 'body'), (node.orelse, 'orelse'))) + return node + def visit_For(self, node): self.visit(node.target) self.visit(node.iter) - node = self._process_block_node(node, node.body, 'body') - node = self._process_block_node(node, node.orelse, 'orelse') + node = self._process_parallel_blocks( + node, ((node.body, 'body'), (node.orelse, 'orelse'))) return node def visit_While(self, node): self.visit(node.test) - node = self._process_block_node(node, node.body, 'body') - node = self._process_block_node(node, node.orelse, 'orelse') + node = self._process_parallel_blocks( + node, ((node.body, 'body'), (node.orelse, 'orelse'))) return node diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/access_test.py b/tensorflow/contrib/py2tf/pyct/static_analysis/access_test.py index 5f1c45c6c3..b16ce7b467 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/access_test.py +++ b/tensorflow/contrib/py2tf/pyct/static_analysis/access_test.py @@ -41,6 +41,26 @@ class ScopeTest(test.TestCase): scope.mark_read('bar') self.assertFalse(scope.has('bar')) + def test_copy(self): + scope = access.Scope(None) + scope.mark_write('foo') + + other = access.Scope(None) + other.copy_from(scope) + + self.assertTrue('foo' in other.created) + + scope.mark_write('bar') + scope.copy_from(other) + + self.assertFalse('bar' in scope.created) + + scope.mark_write('bar') + scope.merge_from(other) + + self.assertTrue('bar' in scope.created) + self.assertFalse('bar' in other.created) + def test_nesting(self): scope = access.Scope(None) scope.mark_write('foo') @@ -75,6 +95,11 @@ class AccessResolverTest(test.TestCase): self.assertTrue(anno.getanno(node.body[0].body[2].value, 'is_local')) # b in return b + def assertScopeIs(self, scope, used, modified, created): + self.assertItemsEqual(used, scope.used) + self.assertItemsEqual(modified, scope.modified) + self.assertItemsEqual(created, scope.created) + def test_print_statement(self): def test_fn(a): @@ -96,12 +121,9 @@ class AccessResolverTest(test.TestCase): # The call node should be the one being annotated. print_node = print_node.value print_args_scope = anno.getanno(print_node, 'args_scope') - # We basically need to detect which variables are captured by the call # arguments. - self.assertItemsEqual(['a', 'b'], print_args_scope.used) - self.assertItemsEqual([], print_args_scope.modified) - self.assertItemsEqual([], print_args_scope.created) + self.assertScopeIs(print_args_scope, ('a', 'b'), (), ()) def test_call(self): @@ -115,13 +137,10 @@ class AccessResolverTest(test.TestCase): node = access.resolve(node) call_node = node.body[0].body[2].value - call_args_scope = anno.getanno(call_node, 'args_scope') - # We basically need to detect which variables are captured by the call # arguments. - self.assertItemsEqual(['a', 'b'], call_args_scope.used) - self.assertItemsEqual([], call_args_scope.modified) - self.assertItemsEqual([], call_args_scope.created) + self.assertScopeIs( + anno.getanno(call_node, 'args_scope'), ('a', 'b'), (), ()) def test_while(self): @@ -136,16 +155,11 @@ class AccessResolverTest(test.TestCase): node = access.resolve(node) while_node = node.body[0].body[1] - while_body_scope = anno.getanno(while_node, 'body_scope') - while_parent_scope = anno.getanno(while_node, 'body_parent_scope') - - self.assertItemsEqual(['b'], while_body_scope.used) - self.assertItemsEqual(['b', 'c'], while_body_scope.modified) - self.assertItemsEqual(['c'], while_body_scope.created) - - self.assertItemsEqual(['a', 'b', 'c'], while_parent_scope.used) - self.assertItemsEqual(['a', 'b', 'c'], while_parent_scope.modified) - self.assertItemsEqual(['a', 'b', 'c'], while_parent_scope.created) + self.assertScopeIs( + anno.getanno(while_node, 'body_scope'), ('b',), ('b', 'c'), ('c',)) + self.assertScopeIs( + anno.getanno(while_node, 'body_parent_scope'), ('a', 'b', 'c'), + ('a', 'b', 'c'), ('a', 'b', 'c')) def test_for(self): @@ -160,16 +174,42 @@ class AccessResolverTest(test.TestCase): node = access.resolve(node) for_node = node.body[0].body[1] - for_body_scope = anno.getanno(for_node, 'body_scope') - for_parent_scope = anno.getanno(for_node, 'body_parent_scope') + self.assertScopeIs( + anno.getanno(for_node, 'body_scope'), ('b',), ('b', 'c'), ('c',)) + self.assertScopeIs( + anno.getanno(for_node, 'body_parent_scope'), ('a', 'b', 'c'), + ('a', 'b', 'c', '_'), ('a', 'b', 'c', '_')) + + def test_if(self): + + def test_fn(x): + if x > 0: + x = -x + y = 2 * x + z = -y + else: + x = 2 * x + y = -x + u = -y + return z, u - self.assertItemsEqual(['b'], for_body_scope.used) - self.assertItemsEqual(['b', 'c'], for_body_scope.modified) - self.assertItemsEqual(['c'], for_body_scope.created) + node = parser.parse_object(test_fn) + node = access.resolve(node) - self.assertItemsEqual(['a', 'b', 'c'], for_parent_scope.used) - self.assertItemsEqual(['a', 'b', 'c', '_'], for_parent_scope.modified) - self.assertItemsEqual(['a', 'b', 'c', '_'], for_parent_scope.created) + if_node = node.body[0].body[0] + self.assertScopeIs( + anno.getanno(if_node, 'body_scope'), ('x', 'y'), ('x', 'y', 'z'), + ('y', 'z')) + # TODO(mdan): Double check: is it ok to not mark a local symbol as not read? + self.assertScopeIs( + anno.getanno(if_node, 'body_parent_scope'), ('x', 'z', 'u'), + ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u')) + self.assertScopeIs( + anno.getanno(if_node, 'orelse_scope'), ('x', 'y'), ('x', 'y', 'u'), + ('y', 'u')) + self.assertScopeIs( + anno.getanno(if_node, 'body_parent_scope'), ('x', 'z', 'u'), + ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u')) if __name__ == '__main__': |