aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-11 17:35:20 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-11 17:38:48 -0800
commit6ee404d17929c613b217400406e7e665010ebf18 (patch)
tree5f94c1a130368be770eaec4e1fd52a7d9a34115f
parenta3f47ad628be2e590bcbfadf93a08446d3af6ef2 (diff)
Add support for if conditionals. Fix a bug in the activity analysis.
PiperOrigin-RevId: 181686453
-rw-r--r--tensorflow/contrib/py2tf/convert/control_flow.py92
-rw-r--r--tensorflow/contrib/py2tf/convert/control_flow_test.py46
-rw-r--r--tensorflow/contrib/py2tf/convert/for_canonicalization.py6
-rw-r--r--tensorflow/contrib/py2tf/convert/logical_expressions.py7
-rw-r--r--tensorflow/contrib/py2tf/convert/side_effect_guards.py2
-rw-r--r--tensorflow/contrib/py2tf/pyct/static_analysis/access.py51
-rw-r--r--tensorflow/contrib/py2tf/pyct/static_analysis/access_test.py94
7 files changed, 256 insertions, 42 deletions
diff --git a/tensorflow/contrib/py2tf/convert/control_flow.py b/tensorflow/contrib/py2tf/convert/control_flow.py
index 40b6ba6cbb..8ebd9ad93d 100644
--- a/tensorflow/contrib/py2tf/convert/control_flow.py
+++ b/tensorflow/contrib/py2tf/convert/control_flow.py
@@ -40,6 +40,17 @@ class SymbolNamer(object):
raise NotImplementedError()
+class SymbolRenamer(gast.NodeTransformer):
+
+ def __init__(self, name_map):
+ self.name_map = name_map
+
+ def visit_Name(self, node):
+ if node.id in self.name_map:
+ node.id = self.name_map[node.id]
+ return node
+
+
class ControlFlowTransformer(gast.NodeTransformer):
"""Transforms control flow structures like loops an conditionals."""
@@ -52,11 +63,84 @@ class ControlFlowTransformer(gast.NodeTransformer):
assert False, 'for statement should have been canonicalized at this point'
def visit_If(self, node):
- raise NotImplementedError()
+ self.generic_visit(node)
+
+ body_scope = anno.getanno(node, 'body_scope')
+ orelse_scope = anno.getanno(node, 'orelse_scope')
+
+ if body_scope.created - orelse_scope.created:
+ raise ValueError(
+ 'The if branch creates new symbols that the else branch does not.')
+ if orelse_scope.created - body_scope.created:
+ raise ValueError(
+ 'The else branch creates new symbols that the if branch does not.')
+
+ def template( # pylint:disable=missing-docstring
+ test,
+ body_name,
+ body,
+ orelse_name,
+ orelse,
+ aliased,
+ aliases, # pylint:disable=unused-argument
+ aliased_results,
+ results): # pylint:disable=unused-argument
+
+ def body_name(): # pylint:disable=function-redefined
+ aliases, = aliased, # pylint:disable=unused-variable
+ body # pylint:disable=pointless-statement
+ return (aliased_results,)
+
+ def orelse_name(): # pylint:disable=function-redefined
+ aliases, = aliased, # pylint:disable=unused-variable
+ orelse # pylint:disable=pointless-statement
+ return (aliased_results,)
+
+ results = tf.cond(test, body_name, orelse_name) # pylint:disable=undefined-variable
+
+ all_modified = tuple(body_scope.modified | orelse_scope.modified)
+ all_referenced = body_scope.referenced | orelse_scope.referenced
+
+ # Alias the closure variables inside the conditional functions
+ # to avoid errors caused by the local variables created in the branch
+ # functions.
+ need_alias = (
+ (body_scope.modified | orelse_scope.modified) -
+ (body_scope.created | orelse_scope.created))
+ aliased = tuple(need_alias)
+ aliases = tuple(
+ self.namer.new_symbol(s, all_referenced) for s in aliased)
+ alias_map = dict(zip(aliased, aliases))
+ node_body = node.body
+ node_body = [SymbolRenamer(alias_map).visit(n) for n in node_body]
+ node_orelse = node.orelse
+ node_orelse = [SymbolRenamer(alias_map).visit(n) for n in node_orelse]
+
+ if len(all_modified) == 1:
+ results = gast.Name(all_modified[0], None, None)
+ else:
+ results = gast.Tuple(
+ tuple(gast.Name(s, None, None) for s in all_modified), None)
+
+ return templates.replace(
+ template,
+ test=node.test,
+ body_name=gast.Name(
+ self.namer.new_symbol('if_true', all_referenced), None, None),
+ body=node_body,
+ orelse_name=gast.Name(
+ self.namer.new_symbol('if_false', all_referenced), None, None),
+ orelse=node_orelse,
+ aliased=tuple(gast.Name(s, None, None) for s in aliased),
+ aliases=tuple(gast.Name(s, None, None) for s in aliases),
+ aliased_results=tuple(
+ gast.Name(alias_map[s] if s in aliased else s, None, None)
+ for s in all_modified),
+ results=results)
def visit_While(self, node):
self.generic_visit(node)
- # Scrape out the data flow analysis
+
body_scope = anno.getanno(node, 'body_scope')
body_closure = tuple(body_scope.modified - body_scope.created)
@@ -77,8 +161,8 @@ class ControlFlowTransformer(gast.NodeTransformer):
state_ast_tuple = tf.while_loop(test_name, body_name, [state]) # pylint:disable=undefined-variable
- test_name = self.namer.new_symbol('loop_test', body_scope.used)
- body_name = self.namer.new_symbol('loop_body', body_scope.used)
+ test_name = self.namer.new_symbol('loop_test', body_scope.referenced)
+ body_name = self.namer.new_symbol('loop_body', body_scope.referenced)
if len(body_closure) == 1:
state = gast.Name(body_closure[0], None, None)
state_ast_tuple = state
diff --git a/tensorflow/contrib/py2tf/convert/control_flow_test.py b/tensorflow/contrib/py2tf/convert/control_flow_test.py
index c27a079546..51237a291d 100644
--- a/tensorflow/contrib/py2tf/convert/control_flow_test.py
+++ b/tensorflow/contrib/py2tf/convert/control_flow_test.py
@@ -31,8 +31,13 @@ from tensorflow.python.platform import test
class TestNamer(control_flow.SymbolNamer):
- def new_symbol(self, name_root, _):
- return name_root
+ def new_symbol(self, name_root, used):
+ i = 0
+ while True:
+ name = '%s%d' % (name_root, i)
+ if name not in used:
+ return name
+ i += 1
class ControlFlowTest(test.TestCase):
@@ -78,6 +83,43 @@ class ControlFlowTest(test.TestCase):
with self.test_session() as sess:
self.assertEqual(0, sess.run(result.test_fn(constant_op.constant(5))))
+ def test_simple_if(self):
+
+ def test_fn(n):
+ a = 0
+ b = 0
+ if n > 0:
+ a = -n
+ else:
+ b = 2 * n
+ return a, b
+
+ node = self._parse_and_analyze(test_fn, {})
+ node = control_flow.transform(node, TestNamer())
+ result = compiler.ast_to_object(node)
+ setattr(result, 'tf', control_flow_ops)
+
+ with self.test_session() as sess:
+ self.assertEqual((-1, 0), sess.run(
+ result.test_fn(constant_op.constant(1))))
+ self.assertEqual((0, -2),
+ sess.run(result.test_fn(constant_op.constant(-1))))
+
+ def test_if_single_var(self):
+
+ def test_fn(n):
+ if n > 0:
+ n = -n
+ return n
+
+ node = self._parse_and_analyze(test_fn, {})
+ node = control_flow.transform(node, TestNamer())
+ result = compiler.ast_to_object(node)
+ setattr(result, 'tf', control_flow_ops)
+
+ with self.test_session() as sess:
+ self.assertEqual(-1, sess.run(result.test_fn(constant_op.constant(1))))
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/py2tf/convert/for_canonicalization.py b/tensorflow/contrib/py2tf/convert/for_canonicalization.py
index eb31ac386f..c51a2326a6 100644
--- a/tensorflow/contrib/py2tf/convert/for_canonicalization.py
+++ b/tensorflow/contrib/py2tf/convert/for_canonicalization.py
@@ -55,8 +55,10 @@ class ForLoopCanonicalizationTransformer(gast.NodeTransformer):
loop_iter=node.iter,
target=node.target,
body=node.body,
- i=gast.Name(self.namer.new_symbol('i', body_scope.used), None, None),
- n=gast.Name(self.namer.new_symbol('n', body_scope.used), None, None))
+ i=gast.Name(
+ self.namer.new_symbol('i', body_scope.referenced), None, None),
+ n=gast.Name(
+ self.namer.new_symbol('n', body_scope.referenced), None, None))
def transform(node, namer):
diff --git a/tensorflow/contrib/py2tf/convert/logical_expressions.py b/tensorflow/contrib/py2tf/convert/logical_expressions.py
index c2f27a5f10..cfa23627fa 100644
--- a/tensorflow/contrib/py2tf/convert/logical_expressions.py
+++ b/tensorflow/contrib/py2tf/convert/logical_expressions.py
@@ -34,10 +34,15 @@ class LogicalExpressionTransformer(gast.NodeTransformer):
self.op_mapping = {
gast.And: 'tf.logical_and',
gast.Or: 'tf.logical_or',
+ gast.Not: 'tf.logical_not',
}
def visit_UnaryOp(self, node):
- raise NotImplementedError()
+ if isinstance(node.op, gast.Not):
+ tf_function = parser.parse_str(self.op_mapping[type(
+ node.op)]).body[0].value
+ node = gast.Call(func=tf_function, args=[node.operand], keywords=[])
+ return node
def visit_BoolOp(self, node):
# TODO(mdan): A normalizer may be useful here. Use ANF?
diff --git a/tensorflow/contrib/py2tf/convert/side_effect_guards.py b/tensorflow/contrib/py2tf/convert/side_effect_guards.py
index 61b428fd8f..1f25303fba 100644
--- a/tensorflow/contrib/py2tf/convert/side_effect_guards.py
+++ b/tensorflow/contrib/py2tf/convert/side_effect_guards.py
@@ -112,7 +112,7 @@ class SideEffectGuardTransformer(gast.NodeTransformer):
# tf.py_func(...)
args_scope = anno.getanno(node.value, 'args_scope')
- temp_name = self.namer.new_symbol('temp', args_scope.parent.used)
+ temp_name = self.namer.new_symbol('temp', args_scope.parent.referenced)
# TODO(mdan): Unsafe reference modification!
args_scope.mark_write(temp_name)
diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/access.py b/tensorflow/contrib/py2tf/pyct/static_analysis/access.py
index 25409c63ba..05dc09eb04 100644
--- a/tensorflow/contrib/py2tf/pyct/static_analysis/access.py
+++ b/tensorflow/contrib/py2tf/pyct/static_analysis/access.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import copy
+
import gast
from tensorflow.contrib.py2tf.pyct import anno
@@ -52,10 +54,24 @@ class Scope(object):
self.created = set()
self.used = set()
+ @property
+ def referenced(self):
+ return self.used | self.modified
+
def __repr__(self):
return 'Scope{r=%s, c=%s, w=%s}' % (tuple(self.used), tuple(self.created),
tuple(self.modified))
+ def copy_from(self, other):
+ self.modified = copy.copy(other.modified)
+ self.created = copy.copy(other.created)
+ self.used = copy.copy(other.used)
+
+ def merge_from(self, other):
+ self.modified |= other.modified
+ self.created |= other.created
+ self.used |= other.used
+
def has(self, name):
if name in self.modified:
return True
@@ -131,7 +147,6 @@ class AccessResolver(gast.NodeTransformer):
def _process_block_node(self, node, block, scope_name):
current_scope = self.scope
- anno.setanno(node, '%s_parent_scope' % scope_name, current_scope)
block_scope = Scope(current_scope, isolated=False)
self.scope = block_scope
for n in block:
@@ -140,17 +155,43 @@ class AccessResolver(gast.NodeTransformer):
self.scope = current_scope
return node
+ def _process_parallel_blocks(self, parent, children):
+ # Because the scopes are not isolated, processing any child block
+ # modifies the parent state causing the other child blocks to be
+ # processed incorrectly. So we need to checkpoint the parent scope so that
+ # each child sees the same context.
+ before_parent = Scope(None)
+ before_parent.copy_from(self.scope)
+ after_children = []
+ for child, name in children:
+ self.scope.copy_from(before_parent)
+ parent = self._process_block_node(parent, child, name)
+ after_child = Scope(None)
+ after_child.copy_from(self.scope)
+ after_children.append(after_child)
+ for after_child in after_children:
+ self.scope.merge_from(after_child)
+ for child, name in children:
+ anno.setanno(parent, '%s_parent_scope' % name, self.scope)
+ return parent
+
+ def visit_If(self, node):
+ self.visit(node.test)
+ node = self._process_parallel_blocks(
+ node, ((node.body, 'body'), (node.orelse, 'orelse')))
+ return node
+
def visit_For(self, node):
self.visit(node.target)
self.visit(node.iter)
- node = self._process_block_node(node, node.body, 'body')
- node = self._process_block_node(node, node.orelse, 'orelse')
+ node = self._process_parallel_blocks(
+ node, ((node.body, 'body'), (node.orelse, 'orelse')))
return node
def visit_While(self, node):
self.visit(node.test)
- node = self._process_block_node(node, node.body, 'body')
- node = self._process_block_node(node, node.orelse, 'orelse')
+ node = self._process_parallel_blocks(
+ node, ((node.body, 'body'), (node.orelse, 'orelse')))
return node
diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/access_test.py b/tensorflow/contrib/py2tf/pyct/static_analysis/access_test.py
index 5f1c45c6c3..b16ce7b467 100644
--- a/tensorflow/contrib/py2tf/pyct/static_analysis/access_test.py
+++ b/tensorflow/contrib/py2tf/pyct/static_analysis/access_test.py
@@ -41,6 +41,26 @@ class ScopeTest(test.TestCase):
scope.mark_read('bar')
self.assertFalse(scope.has('bar'))
+ def test_copy(self):
+ scope = access.Scope(None)
+ scope.mark_write('foo')
+
+ other = access.Scope(None)
+ other.copy_from(scope)
+
+ self.assertTrue('foo' in other.created)
+
+ scope.mark_write('bar')
+ scope.copy_from(other)
+
+ self.assertFalse('bar' in scope.created)
+
+ scope.mark_write('bar')
+ scope.merge_from(other)
+
+ self.assertTrue('bar' in scope.created)
+ self.assertFalse('bar' in other.created)
+
def test_nesting(self):
scope = access.Scope(None)
scope.mark_write('foo')
@@ -75,6 +95,11 @@ class AccessResolverTest(test.TestCase):
self.assertTrue(anno.getanno(node.body[0].body[2].value,
'is_local')) # b in return b
+ def assertScopeIs(self, scope, used, modified, created):
+ self.assertItemsEqual(used, scope.used)
+ self.assertItemsEqual(modified, scope.modified)
+ self.assertItemsEqual(created, scope.created)
+
def test_print_statement(self):
def test_fn(a):
@@ -96,12 +121,9 @@ class AccessResolverTest(test.TestCase):
# The call node should be the one being annotated.
print_node = print_node.value
print_args_scope = anno.getanno(print_node, 'args_scope')
-
# We basically need to detect which variables are captured by the call
# arguments.
- self.assertItemsEqual(['a', 'b'], print_args_scope.used)
- self.assertItemsEqual([], print_args_scope.modified)
- self.assertItemsEqual([], print_args_scope.created)
+ self.assertScopeIs(print_args_scope, ('a', 'b'), (), ())
def test_call(self):
@@ -115,13 +137,10 @@ class AccessResolverTest(test.TestCase):
node = access.resolve(node)
call_node = node.body[0].body[2].value
- call_args_scope = anno.getanno(call_node, 'args_scope')
-
# We basically need to detect which variables are captured by the call
# arguments.
- self.assertItemsEqual(['a', 'b'], call_args_scope.used)
- self.assertItemsEqual([], call_args_scope.modified)
- self.assertItemsEqual([], call_args_scope.created)
+ self.assertScopeIs(
+ anno.getanno(call_node, 'args_scope'), ('a', 'b'), (), ())
def test_while(self):
@@ -136,16 +155,11 @@ class AccessResolverTest(test.TestCase):
node = access.resolve(node)
while_node = node.body[0].body[1]
- while_body_scope = anno.getanno(while_node, 'body_scope')
- while_parent_scope = anno.getanno(while_node, 'body_parent_scope')
-
- self.assertItemsEqual(['b'], while_body_scope.used)
- self.assertItemsEqual(['b', 'c'], while_body_scope.modified)
- self.assertItemsEqual(['c'], while_body_scope.created)
-
- self.assertItemsEqual(['a', 'b', 'c'], while_parent_scope.used)
- self.assertItemsEqual(['a', 'b', 'c'], while_parent_scope.modified)
- self.assertItemsEqual(['a', 'b', 'c'], while_parent_scope.created)
+ self.assertScopeIs(
+ anno.getanno(while_node, 'body_scope'), ('b',), ('b', 'c'), ('c',))
+ self.assertScopeIs(
+ anno.getanno(while_node, 'body_parent_scope'), ('a', 'b', 'c'),
+ ('a', 'b', 'c'), ('a', 'b', 'c'))
def test_for(self):
@@ -160,16 +174,42 @@ class AccessResolverTest(test.TestCase):
node = access.resolve(node)
for_node = node.body[0].body[1]
- for_body_scope = anno.getanno(for_node, 'body_scope')
- for_parent_scope = anno.getanno(for_node, 'body_parent_scope')
+ self.assertScopeIs(
+ anno.getanno(for_node, 'body_scope'), ('b',), ('b', 'c'), ('c',))
+ self.assertScopeIs(
+ anno.getanno(for_node, 'body_parent_scope'), ('a', 'b', 'c'),
+ ('a', 'b', 'c', '_'), ('a', 'b', 'c', '_'))
+
+ def test_if(self):
+
+ def test_fn(x):
+ if x > 0:
+ x = -x
+ y = 2 * x
+ z = -y
+ else:
+ x = 2 * x
+ y = -x
+ u = -y
+ return z, u
- self.assertItemsEqual(['b'], for_body_scope.used)
- self.assertItemsEqual(['b', 'c'], for_body_scope.modified)
- self.assertItemsEqual(['c'], for_body_scope.created)
+ node = parser.parse_object(test_fn)
+ node = access.resolve(node)
- self.assertItemsEqual(['a', 'b', 'c'], for_parent_scope.used)
- self.assertItemsEqual(['a', 'b', 'c', '_'], for_parent_scope.modified)
- self.assertItemsEqual(['a', 'b', 'c', '_'], for_parent_scope.created)
+ if_node = node.body[0].body[0]
+ self.assertScopeIs(
+ anno.getanno(if_node, 'body_scope'), ('x', 'y'), ('x', 'y', 'z'),
+ ('y', 'z'))
+ # TODO(mdan): Double check: is it ok to not mark a local symbol as not read?
+ self.assertScopeIs(
+ anno.getanno(if_node, 'body_parent_scope'), ('x', 'z', 'u'),
+ ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u'))
+ self.assertScopeIs(
+ anno.getanno(if_node, 'orelse_scope'), ('x', 'y'), ('x', 'y', 'u'),
+ ('y', 'u'))
+ self.assertScopeIs(
+ anno.getanno(if_node, 'body_parent_scope'), ('x', 'z', 'u'),
+ ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u'))
if __name__ == '__main__':