aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Dan Moldovan <mdan@google.com>2018-10-09 14:04:23 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-09 14:12:24 -0700
commit5785c0202f4f84c464ef22d0ff180730813f59f3 (patch)
treede62b56fe072aada15f01dc69000173b0a63822a
parent1f556d3a4172c30cf461e7e66334b70ffad2d559 (diff)
Improve the control flow conversion for loops by using dataflow analysis to construct the state. This is part of a larger refactoring which removes the reliance on the deprecated Scope.created field.
PiperOrigin-RevId: 216418556
-rw-r--r--tensorflow/python/autograph/converters/control_flow.py162
-rw-r--r--tensorflow/python/autograph/converters/control_flow_test.py4
-rw-r--r--tensorflow/python/autograph/pyct/qual_names.py3
3 files changed, 93 insertions, 76 deletions
diff --git a/tensorflow/python/autograph/converters/control_flow.py b/tensorflow/python/autograph/converters/control_flow.py
index 416a60d2ee..70879f6c97 100644
--- a/tensorflow/python/autograph/converters/control_flow.py
+++ b/tensorflow/python/autograph/converters/control_flow.py
@@ -90,23 +90,11 @@ class ControlFlowTransformer(converter.Base):
return templates.replace(
template, test=test, body_name=body_name, orelse_name=orelse_name)
- def _fmt_symbol_list(self, symbol_set):
+ def _fmt_symbols(self, symbol_set):
if not symbol_set:
return 'no variables'
return ', '.join(map(str, symbol_set))
- def _validate_no_live_vars_created(self, node):
- body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
- live_vars_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
- live_vars_created_in_body = live_vars_out & body_scope.created
- if live_vars_created_in_body:
- raise ValueError(
- 'The following variables are created inside the loop and used later:'
- '\n%s\n'
- 'Variables must be declared outside loops because loops may not'
- ' necessarily execute.' % self._fmt_symbol_list(
- live_vars_created_in_body))
-
def visit_If(self, node):
node = self.generic_visit(node)
@@ -138,8 +126,8 @@ class ControlFlowTransformer(converter.Base):
' creates %s, while the false branch creates %s. Make sure all'
' these variables are initialized either in both'
' branches or before the if statement.' %
- (self._fmt_symbol_list(created_in_body),
- self._fmt_symbol_list(created_in_orelse)))
+ (self._fmt_symbols(created_in_body),
+ self._fmt_symbols(created_in_orelse)))
# Alias the closure variables inside the conditional functions, to allow
# the functions access to the respective variables.
@@ -206,51 +194,97 @@ class ControlFlowTransformer(converter.Base):
return body_def + orelse_def + cond_expr
- def visit_While(self, node):
- self.generic_visit(node)
-
- self._validate_no_live_vars_created(node)
-
+ def _get_loop_state(self, node):
body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
- body_closure = body_scope.modified - body_scope.created
- all_referenced = body_scope.referenced
-
- cond_scope = anno.getanno(node, annos.NodeAnno.COND_SCOPE)
- cond_closure = set()
- for s in cond_scope.used:
- for root in s.support_set:
- if root not in body_scope.created:
- cond_closure.add(root)
-
- state = list(body_closure)
- if not state:
+ defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
+ live_in = anno.getanno(node, anno.Static.LIVE_VARS_IN)
+ live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
+ reserved_symbols = body_scope.referenced
+
+ # Note that it doesn't matter whether the variables are live after the loop.
+ # If the loop modifies them nonlocally (e.g. the result of an iteration
+ # depends on the previous iteration), then they need to be included in
+ # the loop state, regardless of whether they are later used or not.
+ loop_state = body_scope.modified & live_in
+
+ undefined_lives = loop_state - defined_in
+ # Only simple variables must be defined. The composite ones will be
+ # implicitly checked at runtime.
+ undefined_simple_lives = {v for v in undefined_lives if v.is_simple()}
+ if undefined_simple_lives:
+ raise NameError(
+ 'cannot convert loop: it includes symbols that are undefined'
+ ' when entering the loop: {}'.format(
+ self._fmt_symbols(undefined_simple_lives)))
+
+ live_defs_in_loop = (body_scope.modified - live_in) & live_out
+ if live_defs_in_loop:
+ # TODO(mdan): Include reference to explanation why.
+ raise NotImplementedError(
+ 'cannot convert loop: it includes symbols that are defined'
+ ' inside the loop, but used later: {}. To fix, initialize'
+ ' these symbols before the loop'.format(
+ self._fmt_symbols(live_defs_in_loop)))
+
+ if not loop_state:
# TODO(mdan): Implement this properly.
- # To complete this statement, we need to check whether any variable
- # created inside the body scope is used before being modified outside the
- # scope. This should be done during activity analysis, and in general
- # should cover the case where variables may not be initialized.
- raise ValueError('cannot convert while loop: no outputs')
+ # We need to check whether any variable created inside the body scope
+ # is used before being modified outside the scope. This should be done
+ # during activity analysis, and in general should cover the case where
+ # variables may not be initialized.
+ raise ValueError('cannot convert loop: no outputs')
+
+ return loop_state, reserved_symbols
+ def _state_constructs(self, loop_state, reserved_symbols):
+ loop_state = list(loop_state)
state_ssf = [
- self.ctx.namer.new_symbol(s.ssf(), all_referenced) for s in state
+ self.ctx.namer.new_symbol(s.ssf(), reserved_symbols) for s in loop_state
]
ssf_map = {
name: ssf
- for name, ssf in zip(state, state_ssf)
+ for name, ssf in zip(loop_state, state_ssf)
if str(name) != ssf
}
- if len(state) == 1:
- state = state[0]
+ if len(loop_state) == 1:
+ loop_state = loop_state[0]
state_ssf = state_ssf[0]
- state_ast_tuple = state
+ state_ast_tuple = loop_state
else:
- state_ast_tuple = gast.Tuple([n.ast() for n in state], None)
+ state_ast_tuple = gast.Tuple([n.ast() for n in loop_state], None)
+
+ return loop_state, state_ssf, state_ast_tuple, ssf_map
+
+ def visit_While(self, node):
+ self.generic_visit(node)
+ loop_state, reserved_symbols = self._get_loop_state(node)
+
+ # Note: one might expect we can dispatch based on the loop condition.
+ # But because that is dependent on the state, it cannot be evaluated ahead
+ # of time - doing that would risk duplicating any effects the condition has.
+ # Furthermore, we cannot evaluate slices and attributes, because they might
+ # trigger __getitem__ or __getattribute__.
+ #
+ # A case where this fails includes ops with side effects on a stateful
+ # resource captured in an object:
+ #
+ # while self.v.read() > 0:
+ # self.v.assign(1)
+ #
+ # TODO(mdan): Handle the case above.
+ cond_scope = anno.getanno(node, annos.NodeAnno.COND_SCOPE)
+ cond_closure = set()
+ for s in cond_scope.used:
+ cond_closure.update(s.support_set)
+ cond_closure -= loop_state
+
+ loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
+ loop_state, reserved_symbols)
node_body = ast_util.rename_symbols(node.body, ssf_map)
test = ast_util.rename_symbols(node.test, ssf_map)
- # TODO(b/113118541) investigate the need-for and correctness-of extra_deps
template = """
def test_name(state_ssf):
return test
@@ -262,12 +296,12 @@ class ControlFlowTransformer(converter.Base):
"""
node = templates.replace(
template,
- state=state,
+ state=loop_state,
state_ssf=state_ssf,
state_ast_tuple=state_ast_tuple,
- test_name=self.ctx.namer.new_symbol('loop_test', body_scope.referenced),
+ test_name=self.ctx.namer.new_symbol('loop_test', reserved_symbols),
test=test,
- body_name=self.ctx.namer.new_symbol('loop_body', body_scope.referenced),
+ body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols),
body=node_body,
extra_deps=tuple(s.ast() for s in cond_closure),
)
@@ -277,30 +311,9 @@ class ControlFlowTransformer(converter.Base):
def visit_For(self, node):
self.generic_visit(node)
- self._validate_no_live_vars_created(node)
-
- body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
- body_closure = body_scope.modified - body_scope.created
- all_referenced = body_scope.referenced
-
- state = list(body_closure)
-
- state_ssf = [
- self.ctx.namer.new_symbol(s.ssf(), all_referenced) for s in state
- ]
- ssf_map = {
- name: ssf
- for name, ssf in zip(state, state_ssf)
- if str(name) != ssf
- }
-
- if len(state) == 1:
- state = state[0]
- state_ssf = state_ssf[0]
- state_ast_tuple = state
- else:
- state_ast_tuple = gast.Tuple([n.ast() for n in state], None)
-
+ loop_state, reserved_symbols = self._get_loop_state(node)
+ loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
+ loop_state, reserved_symbols)
node_body = ast_util.rename_symbols(node.body, ssf_map)
if anno.hasanno(node, 'extra_test'):
extra_test = anno.getanno(node, 'extra_test')
@@ -321,14 +334,15 @@ class ControlFlowTransformer(converter.Base):
"""
node = templates.replace(
template,
- state=state,
+ state=loop_state,
state_ssf=state_ssf,
state_ast_tuple=state_ast_tuple,
iter_=node.iter,
iterate=node.target,
- extra_test_name=self.ctx.namer.new_symbol('extra_test', all_referenced),
+ extra_test_name=self.ctx.namer.new_symbol('extra_test',
+ reserved_symbols),
extra_test_expr=extra_test,
- body_name=self.ctx.namer.new_symbol('loop_body', all_referenced),
+ body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols),
body=node_body)
return node
diff --git a/tensorflow/python/autograph/converters/control_flow_test.py b/tensorflow/python/autograph/converters/control_flow_test.py
index cfa0ea920c..03fdfc804e 100644
--- a/tensorflow/python/autograph/converters/control_flow_test.py
+++ b/tensorflow/python/autograph/converters/control_flow_test.py
@@ -83,7 +83,7 @@ class ControlFlowTest(converter_testing.TestCase):
return s
node, ctx = self.prepare(bad_while_loop, {})
- with self.assertRaises(transformer.AutographParseError):
+ with self.assertRaises(NameError):
control_flow.transform(node, ctx)
def test_if_basic(self):
@@ -232,7 +232,7 @@ class ControlFlowTest(converter_testing.TestCase):
return s
node, ctx = self.prepare(bad_for_loop, {})
- with self.assertRaises(transformer.AutographParseError):
+ with self.assertRaises(NameError):
control_flow.transform(node, ctx)
def test_for_tuple_unpacking(self):
diff --git a/tensorflow/python/autograph/pyct/qual_names.py b/tensorflow/python/autograph/pyct/qual_names.py
index 334cbd7d38..6ad6199acf 100644
--- a/tensorflow/python/autograph/pyct/qual_names.py
+++ b/tensorflow/python/autograph/pyct/qual_names.py
@@ -99,6 +99,9 @@ class QN(object):
def is_symbol(self):
return isinstance(self.qn[0], str)
+ def is_simple(self):
+ return len(self.qn) <= 1
+
def is_composite(self):
return len(self.qn) > 1