aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/autograph
diff options
context:
space:
mode:
authorGravatar Dan Moldovan <mdan@google.com>2018-09-26 09:13:54 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-26 09:25:39 -0700
commitd7de49e456fc84416fbf3a6de7ad1ed6c12d7a20 (patch)
treecb5adbcb9977bb8b9ae8b125c30452a47b285163 /tensorflow/python/autograph
parentc3203eb8bf0d7ae9dce133f982884622f666c681 (diff)
The return value checker should ignore inner functions.
PiperOrigin-RevId: 214614921
Diffstat (limited to 'tensorflow/python/autograph')
-rw-r--r--tensorflow/python/autograph/converters/return_statements.py14
-rw-r--r--tensorflow/python/autograph/converters/return_statements_test.py12
2 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/python/autograph/converters/return_statements.py b/tensorflow/python/autograph/converters/return_statements.py
index 62da045d6a..496c99e3b5 100644
--- a/tensorflow/python/autograph/converters/return_statements.py
+++ b/tensorflow/python/autograph/converters/return_statements.py
@@ -212,6 +212,7 @@ class DetectReturnInUnsupportedControlFlow(gast.NodeVisitor):
def __init__(self):
self.cant_return = False
+ self.function_level = 0
super(DetectReturnInUnsupportedControlFlow, self).__init__()
def visit_While(self, node):
@@ -229,6 +230,12 @@ class DetectReturnInUnsupportedControlFlow(gast.NodeVisitor):
self.generic_visit(node)
self.cant_return = False
+ def visit_FunctionDef(self, node):
+ if not self.function_level:
+ self.function_level += 1
+ self.generic_visit(node)
+ self.function_level -= 1
+
def visit_Return(self, node):
if self.cant_return:
raise ValueError(
@@ -242,6 +249,7 @@ class DetectReturnInConditional(gast.NodeVisitor):
def __init__(self):
self.cant_return = False
+ self.function_level = 0
super(DetectReturnInConditional, self).__init__()
def visit_If(self, node):
@@ -249,6 +257,12 @@ class DetectReturnInConditional(gast.NodeVisitor):
self.generic_visit(node)
self.cant_return = False
+ def visit_FunctionDef(self, node):
+ if not self.function_level:
+ self.function_level += 1
+ self.generic_visit(node)
+ self.function_level -= 1
+
def visit_Return(self, node):
if self.cant_return:
raise ValueError(
diff --git a/tensorflow/python/autograph/converters/return_statements_test.py b/tensorflow/python/autograph/converters/return_statements_test.py
index 01dd03da0b..762fbc6f60 100644
--- a/tensorflow/python/autograph/converters/return_statements_test.py
+++ b/tensorflow/python/autograph/converters/return_statements_test.py
@@ -151,6 +151,18 @@ class SingleReturnTest(converter_testing.TestCase):
self.assertTransformedEquivalent(test_fn, 2)
self.assertTransformedEquivalent(test_fn, -2)
+ def test_nested_functions_in_control_flow(self):
+
+ def test_fn(x):
+
+ if x:
+ def inner_fn(y):
+ return y
+ inner_fn(x)
+
+ self.assertTransformedEquivalent(test_fn, 2)
+ self.assertTransformedEquivalent(test_fn, -2)
+
def test_loop(self):
def test_fn(x):