aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph/converters/break_statements.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/autograph/converters/break_statements.py')
-rw-r--r--tensorflow/contrib/autograph/converters/break_statements.py35
1 files changed, 20 insertions, 15 deletions
diff --git a/tensorflow/contrib/autograph/converters/break_statements.py b/tensorflow/contrib/autograph/converters/break_statements.py
index a990e359a2..2a60750bda 100644
--- a/tensorflow/contrib/autograph/converters/break_statements.py
+++ b/tensorflow/contrib/autograph/converters/break_statements.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Canonicalizes break statements by de-sugaring into a control boolean."""
+"""Lowers break statements to conditionals."""
from __future__ import absolute_import
from __future__ import division
@@ -24,17 +24,22 @@ from tensorflow.contrib.autograph.pyct import templates
from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
-# Tags for local state.
-BREAK_USED = 'break_used'
-CONTROL_VAR_NAME = 'control_var_name'
+class _Break(object):
+ def __init__(self):
+ self.used = False
+ self.control_var_name = None
-class BreakStatementTransformer(converter.Base):
+ def __repr__(self):
+ return 'used: %s, var: %s' % (self.used, self.control_var_name)
+
+
+class BreakTransformer(converter.Base):
"""Canonicalizes break statements into additional conditionals."""
def visit_Break(self, node):
- self.set_local(BREAK_USED, True)
- var_name = self.get_local(CONTROL_VAR_NAME)
+ self.state[_Break].used = True
+ var_name = self.state[_Break].control_var_name
# TODO(mdan): This will fail when expanded inside a top-level else block.
template = """
var_name = True
@@ -57,12 +62,12 @@ class BreakStatementTransformer(converter.Base):
block=block)
return node
- def _track_body(self, nodes, break_var):
- self.enter_local_scope()
- self.set_local(CONTROL_VAR_NAME, break_var)
+ def _process_body(self, nodes, break_var):
+ self.state[_Break].enter()
+ self.state[_Break].control_var_name = break_var
nodes = self.visit_block(nodes)
- break_used = self.get_local(BREAK_USED, False)
- self.exit_local_scope()
+ break_used = self.state[_Break].used
+ self.state[_Break].exit()
return nodes, break_used
def visit_While(self, node):
@@ -70,7 +75,7 @@ class BreakStatementTransformer(converter.Base):
break_var = self.ctx.namer.new_symbol('break_', scope.referenced)
node.test = self.visit(node.test)
- node.body, break_used = self._track_body(node.body, break_var)
+ node.body, break_used = self._process_body(node.body, break_var)
# A break in the else clause applies to the containing scope.
node.orelse = self.visit_block(node.orelse)
@@ -101,7 +106,7 @@ class BreakStatementTransformer(converter.Base):
node.target = self.visit(node.target)
node.iter = self.visit(node.iter)
- node.body, break_used = self._track_body(node.body, break_var)
+ node.body, break_used = self._process_body(node.body, break_var)
# A break in the else clause applies to the containing scope.
node.orelse = self.visit_block(node.orelse)
@@ -138,4 +143,4 @@ class BreakStatementTransformer(converter.Base):
def transform(node, ctx):
- return BreakStatementTransformer(ctx).visit(node)
+ return BreakTransformer(ctx).visit(node)