From 5785c0202f4f84c464ef22d0ff180730813f59f3 Mon Sep 17 00:00:00 2001 From: Dan Moldovan Date: Tue, 9 Oct 2018 14:04:23 -0700 Subject: Improve the control flow conversion for loops by using dataflow analysis to construct the state. This is part of a larger refactoring which removes the reliance on the deprecated Scope.created field. PiperOrigin-RevId: 216418556 --- .../python/autograph/converters/control_flow.py | 162 +++++++++++---------- .../autograph/converters/control_flow_test.py | 4 +- tensorflow/python/autograph/pyct/qual_names.py | 3 + 3 files changed, 93 insertions(+), 76 deletions(-) diff --git a/tensorflow/python/autograph/converters/control_flow.py b/tensorflow/python/autograph/converters/control_flow.py index 416a60d2ee..70879f6c97 100644 --- a/tensorflow/python/autograph/converters/control_flow.py +++ b/tensorflow/python/autograph/converters/control_flow.py @@ -90,23 +90,11 @@ class ControlFlowTransformer(converter.Base): return templates.replace( template, test=test, body_name=body_name, orelse_name=orelse_name) - def _fmt_symbol_list(self, symbol_set): + def _fmt_symbols(self, symbol_set): if not symbol_set: return 'no variables' return ', '.join(map(str, symbol_set)) - def _validate_no_live_vars_created(self, node): - body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) - live_vars_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT) - live_vars_created_in_body = live_vars_out & body_scope.created - if live_vars_created_in_body: - raise ValueError( - 'The following variables are created inside the loop and used later:' - '\n%s\n' - 'Variables must be declared outside loops because loops may not' - ' necessarily execute.' % self._fmt_symbol_list( - live_vars_created_in_body)) - def visit_If(self, node): node = self.generic_visit(node) @@ -138,8 +126,8 @@ class ControlFlowTransformer(converter.Base): ' creates %s, while the false branch creates %s. Make sure all' ' these variables are initialized either in both' ' branches or before the if statement.' % - (self._fmt_symbol_list(created_in_body), - self._fmt_symbol_list(created_in_orelse))) + (self._fmt_symbols(created_in_body), + self._fmt_symbols(created_in_orelse))) # Alias the closure variables inside the conditional functions, to allow # the functions access to the respective variables. @@ -206,51 +194,97 @@ class ControlFlowTransformer(converter.Base): return body_def + orelse_def + cond_expr - def visit_While(self, node): - self.generic_visit(node) - - self._validate_no_live_vars_created(node) - + def _get_loop_state(self, node): body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) - body_closure = body_scope.modified - body_scope.created - all_referenced = body_scope.referenced - - cond_scope = anno.getanno(node, annos.NodeAnno.COND_SCOPE) - cond_closure = set() - for s in cond_scope.used: - for root in s.support_set: - if root not in body_scope.created: - cond_closure.add(root) - - state = list(body_closure) - if not state: + defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN) + live_in = anno.getanno(node, anno.Static.LIVE_VARS_IN) + live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT) + reserved_symbols = body_scope.referenced + + # Note that it doesn't matter whether the variables are live after the loop. + # If the loop modifies them nonlocally (e.g. the result of an iteration + # depends on the previous iteration), then they need to be included in + # the loop state, regardless of whether they are later used or not. + loop_state = body_scope.modified & live_in + + undefined_lives = loop_state - defined_in + # Only simple variables must be defined. The composite ones will be + # implicitly checked at runtime. + undefined_simple_lives = {v for v in undefined_lives if v.is_simple()} + if undefined_simple_lives: + raise NameError( + 'cannot convert loop: it includes symbols that are undefined' + ' when entering the loop: {}'.format( + self._fmt_symbols(undefined_simple_lives))) + + live_defs_in_loop = (body_scope.modified - live_in) & live_out + if live_defs_in_loop: + # TODO(mdan): Include reference to explanation why. + raise NotImplementedError( + 'cannot convert loop: it includes symbols that are defined' + ' inside the loop, but used later: {}. To fix, initialize' + ' these symbols before the loop'.format( + self._fmt_symbols(live_defs_in_loop))) + + if not loop_state: # TODO(mdan): Implement this properly. - # To complete this statement, we need to check whether any variable - # created inside the body scope is used before being modified outside the - # scope. This should be done during activity analysis, and in general - # should cover the case where variables may not be initialized. - raise ValueError('cannot convert while loop: no outputs') + # We need to check whether any variable created inside the body scope + # is used before being modified outside the scope. This should be done + # during activity analysis, and in general should cover the case where + # variables may not be initialized. + raise ValueError('cannot convert loop: no outputs') + + return loop_state, reserved_symbols + def _state_constructs(self, loop_state, reserved_symbols): + loop_state = list(loop_state) state_ssf = [ - self.ctx.namer.new_symbol(s.ssf(), all_referenced) for s in state + self.ctx.namer.new_symbol(s.ssf(), reserved_symbols) for s in loop_state ] ssf_map = { name: ssf - for name, ssf in zip(state, state_ssf) + for name, ssf in zip(loop_state, state_ssf) if str(name) != ssf } - if len(state) == 1: - state = state[0] + if len(loop_state) == 1: + loop_state = loop_state[0] state_ssf = state_ssf[0] - state_ast_tuple = state + state_ast_tuple = loop_state else: - state_ast_tuple = gast.Tuple([n.ast() for n in state], None) + state_ast_tuple = gast.Tuple([n.ast() for n in loop_state], None) + + return loop_state, state_ssf, state_ast_tuple, ssf_map + + def visit_While(self, node): + self.generic_visit(node) + loop_state, reserved_symbols = self._get_loop_state(node) + + # Note: one might expect we can dispatch based on the loop condition. + # But because that is dependent on the state, it cannot be evaluated ahead + # of time - doing that would risk duplicating any effects the condition has. + # Furthermore, we cannot evaluate slices and attributes, because they might + # trigger __getitem__ or __getattribute__. + # + # A case where this fails includes ops with side effects on a stateful + # resource captured in an object: + # + # while self.v.read() > 0: + # self.v.assign(1) + # + # TODO(mdan): Handle the case above. + cond_scope = anno.getanno(node, annos.NodeAnno.COND_SCOPE) + cond_closure = set() + for s in cond_scope.used: + cond_closure.update(s.support_set) + cond_closure -= loop_state + + loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs( + loop_state, reserved_symbols) node_body = ast_util.rename_symbols(node.body, ssf_map) test = ast_util.rename_symbols(node.test, ssf_map) - # TODO(b/113118541) investigate the need-for and correctness-of extra_deps template = """ def test_name(state_ssf): return test @@ -262,12 +296,12 @@ class ControlFlowTransformer(converter.Base): """ node = templates.replace( template, - state=state, + state=loop_state, state_ssf=state_ssf, state_ast_tuple=state_ast_tuple, - test_name=self.ctx.namer.new_symbol('loop_test', body_scope.referenced), + test_name=self.ctx.namer.new_symbol('loop_test', reserved_symbols), test=test, - body_name=self.ctx.namer.new_symbol('loop_body', body_scope.referenced), + body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols), body=node_body, extra_deps=tuple(s.ast() for s in cond_closure), ) @@ -277,30 +311,9 @@ class ControlFlowTransformer(converter.Base): def visit_For(self, node): self.generic_visit(node) - self._validate_no_live_vars_created(node) - - body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) - body_closure = body_scope.modified - body_scope.created - all_referenced = body_scope.referenced - - state = list(body_closure) - - state_ssf = [ - self.ctx.namer.new_symbol(s.ssf(), all_referenced) for s in state - ] - ssf_map = { - name: ssf - for name, ssf in zip(state, state_ssf) - if str(name) != ssf - } - - if len(state) == 1: - state = state[0] - state_ssf = state_ssf[0] - state_ast_tuple = state - else: - state_ast_tuple = gast.Tuple([n.ast() for n in state], None) - + loop_state, reserved_symbols = self._get_loop_state(node) + loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs( + loop_state, reserved_symbols) node_body = ast_util.rename_symbols(node.body, ssf_map) if anno.hasanno(node, 'extra_test'): extra_test = anno.getanno(node, 'extra_test') @@ -321,14 +334,15 @@ class ControlFlowTransformer(converter.Base): """ node = templates.replace( template, - state=state, + state=loop_state, state_ssf=state_ssf, state_ast_tuple=state_ast_tuple, iter_=node.iter, iterate=node.target, - extra_test_name=self.ctx.namer.new_symbol('extra_test', all_referenced), + extra_test_name=self.ctx.namer.new_symbol('extra_test', + reserved_symbols), extra_test_expr=extra_test, - body_name=self.ctx.namer.new_symbol('loop_body', all_referenced), + body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols), body=node_body) return node diff --git a/tensorflow/python/autograph/converters/control_flow_test.py b/tensorflow/python/autograph/converters/control_flow_test.py index cfa0ea920c..03fdfc804e 100644 --- a/tensorflow/python/autograph/converters/control_flow_test.py +++ b/tensorflow/python/autograph/converters/control_flow_test.py @@ -83,7 +83,7 @@ class ControlFlowTest(converter_testing.TestCase): return s node, ctx = self.prepare(bad_while_loop, {}) - with self.assertRaises(transformer.AutographParseError): + with self.assertRaises(NameError): control_flow.transform(node, ctx) def test_if_basic(self): @@ -232,7 +232,7 @@ class ControlFlowTest(converter_testing.TestCase): return s node, ctx = self.prepare(bad_for_loop, {}) - with self.assertRaises(transformer.AutographParseError): + with self.assertRaises(NameError): control_flow.transform(node, ctx) def test_for_tuple_unpacking(self): diff --git a/tensorflow/python/autograph/pyct/qual_names.py b/tensorflow/python/autograph/pyct/qual_names.py index 334cbd7d38..6ad6199acf 100644 --- a/tensorflow/python/autograph/pyct/qual_names.py +++ b/tensorflow/python/autograph/pyct/qual_names.py @@ -99,6 +99,9 @@ class QN(object): def is_symbol(self): return isinstance(self.qn[0], str) + def is_simple(self): + return len(self.qn) <= 1 + def is_composite(self): return len(self.qn) > 1 -- cgit v1.2.3