diff options
-rw-r--r-- | tensorflow/python/autograph/converters/conditional_expressions.py | 97 | ||||
-rw-r--r-- | tensorflow/python/autograph/operators/__init__.py | 1 |
2 files changed, 4 insertions, 94 deletions
diff --git a/tensorflow/python/autograph/converters/conditional_expressions.py b/tensorflow/python/autograph/converters/conditional_expressions.py index 40728f555d..a4eef7e6a1 100644 --- a/tensorflow/python/autograph/converters/conditional_expressions.py +++ b/tensorflow/python/autograph/converters/conditional_expressions.py @@ -19,109 +19,18 @@ from __future__ import division from __future__ import print_function from tensorflow.python.autograph.core import converter -from tensorflow.python.autograph.pyct import anno from tensorflow.python.autograph.pyct import templates -from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno - - -class _FunctionDefs(object): - - def __init__(self): - self.nodes = [] - - -class _Statement(object): - - def __init__(self): - self.scope = None class ConditionalExpressionTransformer(converter.Base): """Converts contitional expressions to functional form.""" - def _postprocess_statement(self, node): - """Inserts any separate functions that node may use.""" - replacements = [] - for def_node in self.state[_FunctionDefs].nodes: - replacements.extend(def_node) - replacements.append(node) - node = replacements - # The corresponding enter is called by self.visit_block (see _process_block) - self.state[_FunctionDefs].exit() - return node, None - - def _create_branch(self, expr, name_stem): - scope = self.state[_Statement].scope - name = self.ctx.namer.new_symbol(name_stem, scope.referenced) - template = """ - def name(): - return expr, - """ - node = templates.replace(template, name=name, expr=expr) - self.state[_FunctionDefs].nodes.append(node) - return name - def visit_IfExp(self, node): - if anno.hasanno(node.test, anno.Basic.QN): - name_root = anno.getanno(node.test, anno.Basic.QN).ssf() - else: - name_root = 'ifexp' - - true_fn_name = self._create_branch(node.body, '%s_true' % name_root) - false_fn_name = self._create_branch(node.orelse, '%s_false' % name_root) - return templates.replace_as_expression( - 'ag__.utils.run_cond(test, true_fn_name, false_fn_name)', + 'ag__.if_stmt(test, lambda: true_expr, lambda: false_expr)', test=node.test, - true_fn_name=true_fn_name, - false_fn_name=false_fn_name) - - def _process_block(self, scope, block): - self.state[_Statement].enter() - self.state[_Statement].scope = scope - block = self.visit_block( - block, - before_visit=self.state[_FunctionDefs].enter, - after_visit=self._postprocess_statement) - self.state[_Statement].exit() - return block - - def visit_FunctionDef(self, node): - node.args = self.generic_visit(node.args) - node.decorator_list = self.visit_block(node.decorator_list) - node.body = self._process_block( - anno.getanno(node, anno.Static.SCOPE), node.body) - return node - - def visit_For(self, node): - node.target = self.visit(node.target) - node.body = self._process_block( - anno.getanno(node, NodeAnno.BODY_SCOPE), node.body) - node.orelse = self._process_block( - anno.getanno(node, NodeAnno.ORELSE_SCOPE), node.orelse) - return node - - def visit_While(self, node): - node.test = self.visit(node.test) - node.body = self._process_block( - anno.getanno(node, NodeAnno.BODY_SCOPE), node.body) - node.orelse = self._process_block( - anno.getanno(node, NodeAnno.ORELSE_SCOPE), node.orelse) - return node - - def visit_If(self, node): - node.test = self.visit(node.test) - node.body = self._process_block( - anno.getanno(node, NodeAnno.BODY_SCOPE), node.body) - node.orelse = self._process_block( - anno.getanno(node, NodeAnno.ORELSE_SCOPE), node.orelse) - return node - - def visit_With(self, node): - node.items = self.visit_block(node.items) - node.body = self._process_block( - anno.getanno(node, NodeAnno.BODY_SCOPE), node.body) - return node + true_expr=node.body, + false_expr=node.orelse) def transform(node, ctx): diff --git a/tensorflow/python/autograph/operators/__init__.py b/tensorflow/python/autograph/operators/__init__.py index 53f4b0ddc8..8ba2558ac2 100644 --- a/tensorflow/python/autograph/operators/__init__.py +++ b/tensorflow/python/autograph/operators/__init__.py @@ -38,6 +38,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.autograph.operators.control_flow import for_stmt +from tensorflow.python.autograph.operators.control_flow import if_stmt from tensorflow.python.autograph.operators.control_flow import while_stmt from tensorflow.python.autograph.operators.data_structures import list_append from tensorflow.python.autograph.operators.data_structures import list_pop |