diff options
author | Dan Moldovan <mdan@google.com> | 2018-09-26 09:13:54 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-26 09:25:39 -0700 |
commit | d7de49e456fc84416fbf3a6de7ad1ed6c12d7a20 (patch) | |
tree | cb5adbcb9977bb8b9ae8b125c30452a47b285163 /tensorflow/python/autograph | |
parent | c3203eb8bf0d7ae9dce133f982884622f666c681 (diff) |
The return value checker should ignore inner functions.
PiperOrigin-RevId: 214614921
Diffstat (limited to 'tensorflow/python/autograph')
-rw-r--r-- | tensorflow/python/autograph/converters/return_statements.py | 14 | ||||
-rw-r--r-- | tensorflow/python/autograph/converters/return_statements_test.py | 12 |
2 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/python/autograph/converters/return_statements.py b/tensorflow/python/autograph/converters/return_statements.py index 62da045d6a..496c99e3b5 100644 --- a/tensorflow/python/autograph/converters/return_statements.py +++ b/tensorflow/python/autograph/converters/return_statements.py @@ -212,6 +212,7 @@ class DetectReturnInUnsupportedControlFlow(gast.NodeVisitor): def __init__(self): self.cant_return = False + self.function_level = 0 super(DetectReturnInUnsupportedControlFlow, self).__init__() def visit_While(self, node): @@ -229,6 +230,12 @@ class DetectReturnInUnsupportedControlFlow(gast.NodeVisitor): self.generic_visit(node) self.cant_return = False + def visit_FunctionDef(self, node): + if not self.function_level: + self.function_level += 1 + self.generic_visit(node) + self.function_level -= 1 + def visit_Return(self, node): if self.cant_return: raise ValueError( @@ -242,6 +249,7 @@ class DetectReturnInConditional(gast.NodeVisitor): def __init__(self): self.cant_return = False + self.function_level = 0 super(DetectReturnInConditional, self).__init__() def visit_If(self, node): @@ -249,6 +257,12 @@ class DetectReturnInConditional(gast.NodeVisitor): self.generic_visit(node) self.cant_return = False + def visit_FunctionDef(self, node): + if not self.function_level: + self.function_level += 1 + self.generic_visit(node) + self.function_level -= 1 + def visit_Return(self, node): if self.cant_return: raise ValueError( diff --git a/tensorflow/python/autograph/converters/return_statements_test.py b/tensorflow/python/autograph/converters/return_statements_test.py index 01dd03da0b..762fbc6f60 100644 --- a/tensorflow/python/autograph/converters/return_statements_test.py +++ b/tensorflow/python/autograph/converters/return_statements_test.py @@ -151,6 +151,18 @@ class SingleReturnTest(converter_testing.TestCase): self.assertTransformedEquivalent(test_fn, 2) self.assertTransformedEquivalent(test_fn, -2) + def test_nested_functions_in_control_flow(self): + + def test_fn(x): + + if x: + def inner_fn(y): + return y + inner_fn(x) + + self.assertTransformedEquivalent(test_fn, 2) + self.assertTransformedEquivalent(test_fn, -2) + def test_loop(self): def test_fn(x): |