diff options
Diffstat (limited to 'tensorflow/python/autograph/converters/control_flow.py')
-rw-r--r-- | tensorflow/python/autograph/converters/control_flow.py | 339 |
1 files changed, 339 insertions, 0 deletions
diff --git a/tensorflow/python/autograph/converters/control_flow.py b/tensorflow/python/autograph/converters/control_flow.py new file mode 100644 index 0000000000..416a60d2ee --- /dev/null +++ b/tensorflow/python/autograph/converters/control_flow.py @@ -0,0 +1,339 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Handles control flow statements: while, for, if.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.python.autograph.core import converter +from tensorflow.python.autograph.pyct import anno +from tensorflow.python.autograph.pyct import ast_util +from tensorflow.python.autograph.pyct import parser +from tensorflow.python.autograph.pyct import templates +from tensorflow.python.autograph.pyct.static_analysis import annos + + +class SymbolNamer(object): + """Describes the interface for ControlFlowTransformer's namer.""" + + def new_symbol(self, name_root, reserved_locals): + """Generate a new unique symbol. + + Args: + name_root: String, used as stem in the new name. + reserved_locals: Set(string), additional local symbols that are reserved + and which should not be used. + Returns: + String. + """ + raise NotImplementedError() + + +class ControlFlowTransformer(converter.Base): + """Transforms control flow structures like loops an conditionals.""" + + def _create_cond_branch(self, body_name, aliased_orig_names, + aliased_new_names, body, returns): + if aliased_orig_names: + template = """ + def body_name(): + aliased_new_names, = aliased_orig_names, + body + return (returns,) + """ + return templates.replace( + template, + body_name=body_name, + body=body, + aliased_orig_names=aliased_orig_names, + aliased_new_names=aliased_new_names, + returns=returns) + else: + template = """ + def body_name(): + body + return (returns,) + """ + return templates.replace( + template, body_name=body_name, body=body, returns=returns) + + def _create_cond_expr(self, results, test, body_name, orelse_name): + if results is not None: + template = """ + results = ag__.utils.run_cond(test, body_name, orelse_name) + """ + return templates.replace( + template, + test=test, + results=results, + body_name=body_name, + orelse_name=orelse_name) + else: + template = """ + ag__.utils.run_cond(test, body_name, orelse_name) + """ + return templates.replace( + template, test=test, body_name=body_name, orelse_name=orelse_name) + + def _fmt_symbol_list(self, symbol_set): + if not symbol_set: + return 'no variables' + return ', '.join(map(str, symbol_set)) + + def _validate_no_live_vars_created(self, node): + body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) + live_vars_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT) + live_vars_created_in_body = live_vars_out & body_scope.created + if live_vars_created_in_body: + raise ValueError( + 'The following variables are created inside the loop and used later:' + '\n%s\n' + 'Variables must be declared outside loops because loops may not' + ' necessarily execute.' % self._fmt_symbol_list( + live_vars_created_in_body)) + + def visit_If(self, node): + node = self.generic_visit(node) + + body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) + orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE) + defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN) + live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT) + + modified_in_cond = body_scope.modified | orelse_scope.modified + returned_from_cond = set() + for s in modified_in_cond: + if s in live_out: + returned_from_cond.add(s) + elif s.is_composite(): + # Special treatment for compound objects: if any of their owner entities + # are live, then they are outputs as well. + if any(owner in live_out for owner in s.owner_set): + returned_from_cond.add(s) + + need_alias_in_body = body_scope.modified & defined_in + need_alias_in_orelse = orelse_scope.modified & defined_in + + created_in_body = body_scope.modified & returned_from_cond - defined_in + created_in_orelse = orelse_scope.modified & returned_from_cond - defined_in + + if created_in_body != created_in_orelse: + raise ValueError( + 'if statement may not initialize all variables: the true branch' + ' creates %s, while the false branch creates %s. Make sure all' + ' these variables are initialized either in both' + ' branches or before the if statement.' % + (self._fmt_symbol_list(created_in_body), + self._fmt_symbol_list(created_in_orelse))) + + # Alias the closure variables inside the conditional functions, to allow + # the functions access to the respective variables. + # We will alias variables independently for body and orelse scope, + # because different branches might write different variables. + aliased_body_orig_names = tuple(need_alias_in_body) + aliased_orelse_orig_names = tuple(need_alias_in_orelse) + aliased_body_new_names = tuple( + self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced) + for s in aliased_body_orig_names) + aliased_orelse_new_names = tuple( + self.ctx.namer.new_symbol(s.ssf(), orelse_scope.referenced) + for s in aliased_orelse_orig_names) + + alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names)) + alias_orelse_map = dict( + zip(aliased_orelse_orig_names, aliased_orelse_new_names)) + + node_body = ast_util.rename_symbols(node.body, alias_body_map) + node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map) + + returned_from_cond = tuple(returned_from_cond) + if returned_from_cond: + if len(returned_from_cond) == 1: + # TODO(mdan): Move this quirk into the operator implementation. + cond_results = returned_from_cond[0] + else: + cond_results = gast.Tuple([s.ast() for s in returned_from_cond], None) + + returned_from_body = tuple( + alias_body_map[s] if s in need_alias_in_body else s + for s in returned_from_cond) + returned_from_orelse = tuple( + alias_orelse_map[s] if s in need_alias_in_orelse else s + for s in returned_from_cond) + + else: + # When the cond would return no value, we leave the cond called without + # results. That in turn should trigger the side effect guards. The + # branch functions will return a dummy value that ensures cond + # actually has some return value as well. + cond_results = None + # TODO(mdan): This doesn't belong here; it's specific to the operator. + returned_from_body = templates.replace_as_expression('tf.constant(1)') + returned_from_orelse = templates.replace_as_expression('tf.constant(1)') + + body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced) + orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced) + + body_def = self._create_cond_branch( + body_name, + aliased_orig_names=aliased_body_orig_names, + aliased_new_names=aliased_body_new_names, + body=node_body, + returns=returned_from_body) + orelse_def = self._create_cond_branch( + orelse_name, + aliased_orig_names=aliased_orelse_orig_names, + aliased_new_names=aliased_orelse_new_names, + body=node_orelse, + returns=returned_from_orelse) + cond_expr = self._create_cond_expr(cond_results, node.test, body_name, + orelse_name) + + return body_def + orelse_def + cond_expr + + def visit_While(self, node): + self.generic_visit(node) + + self._validate_no_live_vars_created(node) + + body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) + body_closure = body_scope.modified - body_scope.created + all_referenced = body_scope.referenced + + cond_scope = anno.getanno(node, annos.NodeAnno.COND_SCOPE) + cond_closure = set() + for s in cond_scope.used: + for root in s.support_set: + if root not in body_scope.created: + cond_closure.add(root) + + state = list(body_closure) + if not state: + # TODO(mdan): Implement this properly. + # To complete this statement, we need to check whether any variable + # created inside the body scope is used before being modified outside the + # scope. This should be done during activity analysis, and in general + # should cover the case where variables may not be initialized. + raise ValueError('cannot convert while loop: no outputs') + + state_ssf = [ + self.ctx.namer.new_symbol(s.ssf(), all_referenced) for s in state + ] + ssf_map = { + name: ssf + for name, ssf in zip(state, state_ssf) + if str(name) != ssf + } + + if len(state) == 1: + state = state[0] + state_ssf = state_ssf[0] + state_ast_tuple = state + else: + state_ast_tuple = gast.Tuple([n.ast() for n in state], None) + + node_body = ast_util.rename_symbols(node.body, ssf_map) + test = ast_util.rename_symbols(node.test, ssf_map) + + # TODO(b/113118541) investigate the need-for and correctness-of extra_deps + template = """ + def test_name(state_ssf): + return test + def body_name(state_ssf): + body + return state_ssf, + state_ast_tuple = ag__.while_stmt( + test_name, body_name, (state,), (extra_deps,)) + """ + node = templates.replace( + template, + state=state, + state_ssf=state_ssf, + state_ast_tuple=state_ast_tuple, + test_name=self.ctx.namer.new_symbol('loop_test', body_scope.referenced), + test=test, + body_name=self.ctx.namer.new_symbol('loop_body', body_scope.referenced), + body=node_body, + extra_deps=tuple(s.ast() for s in cond_closure), + ) + + return node + + def visit_For(self, node): + self.generic_visit(node) + + self._validate_no_live_vars_created(node) + + body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) + body_closure = body_scope.modified - body_scope.created + all_referenced = body_scope.referenced + + state = list(body_closure) + + state_ssf = [ + self.ctx.namer.new_symbol(s.ssf(), all_referenced) for s in state + ] + ssf_map = { + name: ssf + for name, ssf in zip(state, state_ssf) + if str(name) != ssf + } + + if len(state) == 1: + state = state[0] + state_ssf = state_ssf[0] + state_ast_tuple = state + else: + state_ast_tuple = gast.Tuple([n.ast() for n in state], None) + + node_body = ast_util.rename_symbols(node.body, ssf_map) + if anno.hasanno(node, 'extra_test'): + extra_test = anno.getanno(node, 'extra_test') + extra_test = ast_util.rename_symbols(extra_test, ssf_map) + else: + extra_test = parser.parse_expression('True') + + template = """ + def extra_test_name(state_ssf): + return extra_test_expr + def body_name(loop_vars, state_ssf): + # Workaround for PEP-3113 + iterate = loop_vars + body + return state_ssf, + state_ast_tuple = ag__.for_stmt( + iter_, extra_test_name, body_name, (state,)) + """ + node = templates.replace( + template, + state=state, + state_ssf=state_ssf, + state_ast_tuple=state_ast_tuple, + iter_=node.iter, + iterate=node.target, + extra_test_name=self.ctx.namer.new_symbol('extra_test', all_referenced), + extra_test_expr=extra_test, + body_name=self.ctx.namer.new_symbol('loop_body', all_referenced), + body=node_body) + + return node + + +def transform(node, ctx): + node = ControlFlowTransformer(ctx).visit(node) + return node |