aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/autograph/converters/conditional_expressions.py97
-rw-r--r--tensorflow/python/autograph/operators/__init__.py1
2 files changed, 4 insertions, 94 deletions
diff --git a/tensorflow/python/autograph/converters/conditional_expressions.py b/tensorflow/python/autograph/converters/conditional_expressions.py
index 40728f555d..a4eef7e6a1 100644
--- a/tensorflow/python/autograph/converters/conditional_expressions.py
+++ b/tensorflow/python/autograph/converters/conditional_expressions.py
@@ -19,109 +19,18 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.autograph.core import converter
-from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import templates
-from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
-
-
-class _FunctionDefs(object):
-
- def __init__(self):
- self.nodes = []
-
-
-class _Statement(object):
-
- def __init__(self):
- self.scope = None
class ConditionalExpressionTransformer(converter.Base):
"""Converts contitional expressions to functional form."""
- def _postprocess_statement(self, node):
- """Inserts any separate functions that node may use."""
- replacements = []
- for def_node in self.state[_FunctionDefs].nodes:
- replacements.extend(def_node)
- replacements.append(node)
- node = replacements
- # The corresponding enter is called by self.visit_block (see _process_block)
- self.state[_FunctionDefs].exit()
- return node, None
-
- def _create_branch(self, expr, name_stem):
- scope = self.state[_Statement].scope
- name = self.ctx.namer.new_symbol(name_stem, scope.referenced)
- template = """
- def name():
- return expr,
- """
- node = templates.replace(template, name=name, expr=expr)
- self.state[_FunctionDefs].nodes.append(node)
- return name
-
def visit_IfExp(self, node):
- if anno.hasanno(node.test, anno.Basic.QN):
- name_root = anno.getanno(node.test, anno.Basic.QN).ssf()
- else:
- name_root = 'ifexp'
-
- true_fn_name = self._create_branch(node.body, '%s_true' % name_root)
- false_fn_name = self._create_branch(node.orelse, '%s_false' % name_root)
-
return templates.replace_as_expression(
- 'ag__.utils.run_cond(test, true_fn_name, false_fn_name)',
+ 'ag__.if_stmt(test, lambda: true_expr, lambda: false_expr)',
test=node.test,
- true_fn_name=true_fn_name,
- false_fn_name=false_fn_name)
-
- def _process_block(self, scope, block):
- self.state[_Statement].enter()
- self.state[_Statement].scope = scope
- block = self.visit_block(
- block,
- before_visit=self.state[_FunctionDefs].enter,
- after_visit=self._postprocess_statement)
- self.state[_Statement].exit()
- return block
-
- def visit_FunctionDef(self, node):
- node.args = self.generic_visit(node.args)
- node.decorator_list = self.visit_block(node.decorator_list)
- node.body = self._process_block(
- anno.getanno(node, anno.Static.SCOPE), node.body)
- return node
-
- def visit_For(self, node):
- node.target = self.visit(node.target)
- node.body = self._process_block(
- anno.getanno(node, NodeAnno.BODY_SCOPE), node.body)
- node.orelse = self._process_block(
- anno.getanno(node, NodeAnno.ORELSE_SCOPE), node.orelse)
- return node
-
- def visit_While(self, node):
- node.test = self.visit(node.test)
- node.body = self._process_block(
- anno.getanno(node, NodeAnno.BODY_SCOPE), node.body)
- node.orelse = self._process_block(
- anno.getanno(node, NodeAnno.ORELSE_SCOPE), node.orelse)
- return node
-
- def visit_If(self, node):
- node.test = self.visit(node.test)
- node.body = self._process_block(
- anno.getanno(node, NodeAnno.BODY_SCOPE), node.body)
- node.orelse = self._process_block(
- anno.getanno(node, NodeAnno.ORELSE_SCOPE), node.orelse)
- return node
-
- def visit_With(self, node):
- node.items = self.visit_block(node.items)
- node.body = self._process_block(
- anno.getanno(node, NodeAnno.BODY_SCOPE), node.body)
- return node
+ true_expr=node.body,
+ false_expr=node.orelse)
def transform(node, ctx):
diff --git a/tensorflow/python/autograph/operators/__init__.py b/tensorflow/python/autograph/operators/__init__.py
index 53f4b0ddc8..8ba2558ac2 100644
--- a/tensorflow/python/autograph/operators/__init__.py
+++ b/tensorflow/python/autograph/operators/__init__.py
@@ -38,6 +38,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.autograph.operators.control_flow import for_stmt
+from tensorflow.python.autograph.operators.control_flow import if_stmt
from tensorflow.python.autograph.operators.control_flow import while_stmt
from tensorflow.python.autograph.operators.data_structures import list_append
from tensorflow.python.autograph.operators.data_structures import list_pop