aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph/converters/control_flow.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/autograph/converters/control_flow.py')
-rw-r--r--tensorflow/contrib/autograph/converters/control_flow.py165
1 files changed, 74 insertions, 91 deletions
diff --git a/tensorflow/contrib/autograph/converters/control_flow.py b/tensorflow/contrib/autograph/converters/control_flow.py
index f4a8710627..5a5a6ad63a 100644
--- a/tensorflow/contrib/autograph/converters/control_flow.py
+++ b/tensorflow/contrib/autograph/converters/control_flow.py
@@ -25,8 +25,7 @@ from tensorflow.contrib.autograph.pyct import anno
from tensorflow.contrib.autograph.pyct import ast_util
from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct.static_analysis import cfg
-from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
+from tensorflow.contrib.autograph.pyct.static_analysis import annos
class SymbolNamer(object):
@@ -47,6 +46,7 @@ class SymbolNamer(object):
class ControlFlowTransformer(converter.Base):
"""Transforms control flow structures like loops an conditionals."""
+
def _create_cond_branch(self, body_name, aliased_orig_names,
aliased_new_names, body, returns):
if aliased_orig_names:
@@ -90,55 +90,51 @@ class ControlFlowTransformer(converter.Base):
return templates.replace(
template, test=test, body_name=body_name, orelse_name=orelse_name)
- def visit_If(self, node):
- self.generic_visit(node)
+ def _fmt_symbol_list(self, symbol_set):
+ if not symbol_set:
+ return 'no variables'
+ return ', '.join(map(str, symbol_set))
- body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
- orelse_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE)
- body_defs = body_scope.created | body_scope.modified
- orelse_defs = orelse_scope.created | orelse_scope.modified
- live = anno.getanno(node, 'live_out')
-
- # We'll need to check if we're closing over variables that are defined
- # elsewhere in the function
- # NOTE: we can only detect syntactic closure in the scope
- # of the code passed in. If the AutoGraph'd function itself closes
- # over other variables, this analysis won't take that into account.
- defined = anno.getanno(node, 'defined_in')
-
- # We only need to return variables that are
- # - modified by one or both branches
- # - live (or has a live parent) at the end of the conditional
- modified = []
- for def_ in body_defs | orelse_defs:
- def_with_parents = set((def_,)) | def_.support_set
- if live & def_with_parents:
- modified.append(def_)
-
- # We need to check if live created variables are balanced
- # in both branches
- created = live & (body_scope.created | orelse_scope.created)
-
- # The if statement is illegal if there are variables that are created,
- # that are also live, but both branches don't create them.
- if created:
- if created != (body_scope.created & live):
- raise ValueError(
- 'The main branch does not create all live symbols that the else '
- 'branch does.')
- if created != (orelse_scope.created & live):
- raise ValueError(
- 'The else branch does not create all live symbols that the main '
- 'branch does.')
-
- # Alias the closure variables inside the conditional functions
- # to avoid errors caused by the local variables created in the branch
- # functions.
+ def visit_If(self, node):
+ node = self.generic_visit(node)
+
+ body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
+ orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
+ defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
+ live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
+
+ modified_in_cond = body_scope.modified | orelse_scope.modified
+ returned_from_cond = set()
+ for s in modified_in_cond:
+ if s in live_out:
+ returned_from_cond.add(s)
+ elif s.is_composite():
+ # Special treatment for compound objects: if any of their owner entities
+ # are live, then they are outputs as well.
+ if any(owner in live_out for owner in s.owner_set):
+ returned_from_cond.add(s)
+
+ need_alias_in_body = body_scope.modified & defined_in
+ need_alias_in_orelse = orelse_scope.modified & defined_in
+
+ created_in_body = body_scope.modified & returned_from_cond - defined_in
+ created_in_orelse = orelse_scope.modified & returned_from_cond - defined_in
+
+ if created_in_body != created_in_orelse:
+ raise ValueError(
+ 'if statement may not initialize all variables: the true branch'
+ ' creates %s, while the false branch creates %s. Make sure all'
+ ' these variables are initialized either in both'
+ ' branches or before the if statement.' %
+ (self._fmt_symbol_list(created_in_body),
+ self._fmt_symbol_list(created_in_orelse)))
+
+ # Alias the closure variables inside the conditional functions, to allow
+ # the functions access to the respective variables.
# We will alias variables independently for body and orelse scope,
# because different branches might write different variables.
- aliased_body_orig_names = tuple(body_scope.modified - body_scope.created)
- aliased_orelse_orig_names = tuple(orelse_scope.modified -
- orelse_scope.created)
+ aliased_body_orig_names = tuple(need_alias_in_body)
+ aliased_orelse_orig_names = tuple(need_alias_in_orelse)
aliased_body_new_names = tuple(
self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced)
for s in aliased_body_orig_names)
@@ -153,58 +149,47 @@ class ControlFlowTransformer(converter.Base):
node_body = ast_util.rename_symbols(node.body, alias_body_map)
node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)
- if not modified:
+ returned_from_cond = tuple(returned_from_cond)
+ if returned_from_cond:
+ if len(returned_from_cond) == 1:
+ # TODO(mdan): Move this quirk into the operator implementation.
+ cond_results = returned_from_cond[0]
+ else:
+ cond_results = gast.Tuple([s.ast() for s in returned_from_cond], None)
+
+ returned_from_body = tuple(
+ alias_body_map[s] if s in need_alias_in_body else s
+ for s in returned_from_cond)
+ returned_from_orelse = tuple(
+ alias_orelse_map[s] if s in need_alias_in_orelse else s
+ for s in returned_from_cond)
+
+ else:
# When the cond would return no value, we leave the cond called without
# results. That in turn should trigger the side effect guards. The
# branch functions will return a dummy value that ensures cond
# actually has some return value as well.
- results = None
- elif len(modified) == 1:
- results = modified[0]
- else:
- results = gast.Tuple([s.ast() for s in modified], None)
+ cond_results = None
+ # TODO(mdan): This doesn't belong here; it's specific to the operator.
+ returned_from_body = templates.replace_as_expression('tf.constant(1)')
+ returned_from_orelse = templates.replace_as_expression('tf.constant(1)')
body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced)
orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced)
- if modified:
-
- def build_returns(aliased_names, alias_map, scope):
- """Builds list of return variables for a branch of a conditional."""
- returns = []
- for s in modified:
- if s in aliased_names:
- returns.append(alias_map[s])
- else:
- if s not in scope.created | defined:
- raise ValueError(
- 'Attempting to return variable "%s" from the true branch of '
- 'a conditional, but it was not closed over, or created in '
- 'this branch.' % str(s))
- else:
- returns.append(s)
- return tuple(returns)
-
- body_returns = build_returns(aliased_body_orig_names, alias_body_map,
- body_scope)
- orelse_returns = build_returns(aliased_orelse_orig_names,
- alias_orelse_map, orelse_scope)
-
- else:
- body_returns = orelse_returns = templates.replace('tf.ones(())')[0].value
body_def = self._create_cond_branch(
body_name,
- aliased_orig_names=tuple(aliased_body_orig_names),
- aliased_new_names=tuple(aliased_body_new_names),
+ aliased_orig_names=aliased_body_orig_names,
+ aliased_new_names=aliased_body_new_names,
body=node_body,
- returns=body_returns)
+ returns=returned_from_body)
orelse_def = self._create_cond_branch(
orelse_name,
- aliased_orig_names=tuple(aliased_orelse_orig_names),
- aliased_new_names=tuple(aliased_orelse_new_names),
+ aliased_orig_names=aliased_orelse_orig_names,
+ aliased_new_names=aliased_orelse_new_names,
body=node_orelse,
- returns=orelse_returns)
- cond_expr = self._create_cond_expr(results, node.test, body_name,
+ returns=returned_from_orelse)
+ cond_expr = self._create_cond_expr(cond_results, node.test, body_name,
orelse_name)
return body_def + orelse_def + cond_expr
@@ -212,11 +197,11 @@ class ControlFlowTransformer(converter.Base):
def visit_While(self, node):
self.generic_visit(node)
- body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
+ body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
body_closure = body_scope.modified - body_scope.created
all_referenced = body_scope.referenced
- cond_scope = anno.getanno(node, NodeAnno.COND_SCOPE)
+ cond_scope = anno.getanno(node, annos.NodeAnno.COND_SCOPE)
cond_closure = set()
for s in cond_scope.referenced:
for root in s.support_set:
@@ -277,7 +262,7 @@ class ControlFlowTransformer(converter.Base):
def visit_For(self, node):
self.generic_visit(node)
- body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
+ body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
body_closure = body_scope.modified - body_scope.created
all_referenced = body_scope.referenced
@@ -331,7 +316,5 @@ class ControlFlowTransformer(converter.Base):
def transform(node, ctx):
- cfg.run_analyses(node, cfg.Liveness(ctx.info))
- cfg.run_analyses(node, cfg.Defined(ctx.info))
node = ControlFlowTransformer(ctx).visit(node)
return node