diff options
Diffstat (limited to 'tensorflow/contrib/autograph/converters/break_statements.py')
-rw-r--r-- | tensorflow/contrib/autograph/converters/break_statements.py | 35 |
1 files changed, 20 insertions, 15 deletions
diff --git a/tensorflow/contrib/autograph/converters/break_statements.py b/tensorflow/contrib/autograph/converters/break_statements.py index a990e359a2..2a60750bda 100644 --- a/tensorflow/contrib/autograph/converters/break_statements.py +++ b/tensorflow/contrib/autograph/converters/break_statements.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Canonicalizes break statements by de-sugaring into a control boolean.""" +"""Lowers break statements to conditionals.""" from __future__ import absolute_import from __future__ import division @@ -24,17 +24,22 @@ from tensorflow.contrib.autograph.pyct import templates from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno -# Tags for local state. -BREAK_USED = 'break_used' -CONTROL_VAR_NAME = 'control_var_name' +class _Break(object): + def __init__(self): + self.used = False + self.control_var_name = None -class BreakStatementTransformer(converter.Base): + def __repr__(self): + return 'used: %s, var: %s' % (self.used, self.control_var_name) + + +class BreakTransformer(converter.Base): """Canonicalizes break statements into additional conditionals.""" def visit_Break(self, node): - self.set_local(BREAK_USED, True) - var_name = self.get_local(CONTROL_VAR_NAME) + self.state[_Break].used = True + var_name = self.state[_Break].control_var_name # TODO(mdan): This will fail when expanded inside a top-level else block. template = """ var_name = True @@ -57,12 +62,12 @@ class BreakStatementTransformer(converter.Base): block=block) return node - def _track_body(self, nodes, break_var): - self.enter_local_scope() - self.set_local(CONTROL_VAR_NAME, break_var) + def _process_body(self, nodes, break_var): + self.state[_Break].enter() + self.state[_Break].control_var_name = break_var nodes = self.visit_block(nodes) - break_used = self.get_local(BREAK_USED, False) - self.exit_local_scope() + break_used = self.state[_Break].used + self.state[_Break].exit() return nodes, break_used def visit_While(self, node): @@ -70,7 +75,7 @@ class BreakStatementTransformer(converter.Base): break_var = self.ctx.namer.new_symbol('break_', scope.referenced) node.test = self.visit(node.test) - node.body, break_used = self._track_body(node.body, break_var) + node.body, break_used = self._process_body(node.body, break_var) # A break in the else clause applies to the containing scope. node.orelse = self.visit_block(node.orelse) @@ -101,7 +106,7 @@ class BreakStatementTransformer(converter.Base): node.target = self.visit(node.target) node.iter = self.visit(node.iter) - node.body, break_used = self._track_body(node.body, break_var) + node.body, break_used = self._process_body(node.body, break_var) # A break in the else clause applies to the containing scope. node.orelse = self.visit_block(node.orelse) @@ -138,4 +143,4 @@ class BreakStatementTransformer(converter.Base): def transform(node, ctx): - return BreakStatementTransformer(ctx).visit(node) + return BreakTransformer(ctx).visit(node) |