diff options
Diffstat (limited to 'tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions_test.py')
-rw-r--r-- | tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions_test.py | 263 |
1 files changed, 263 insertions, 0 deletions
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions_test.py new file mode 100644 index 0000000000..243fe804b2 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions_test.py @@ -0,0 +1,263 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for reaching_definitions module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import cfg +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis import activity +from tensorflow.contrib.autograph.pyct.static_analysis import reaching_definitions +from tensorflow.python.platform import test + + +class DefinitionInfoTest(test.TestCase): + + def _parse_and_analyze(self, test_fn): + node, source = parser.parse_entity(test_fn) + entity_info = transformer.EntityInfo( + source_code=source, + source_file=None, + namespace={}, + arg_values=None, + arg_types=None, + owner_type=None) + node = qual_names.resolve(node) + node = activity.resolve(node, entity_info) + graphs = cfg.build(node) + node = reaching_definitions.resolve(node, entity_info, graphs, + reaching_definitions.Definition) + return node + + def assertHasDefs(self, node, num): + defs = anno.getanno(node, anno.Static.DEFINITIONS) + self.assertEqual(len(defs), num) + for r in defs: + self.assertIsInstance(r, reaching_definitions.Definition) + + def assertHasDefinedIn(self, node, expected): + defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN) + defined_in_str = set(str(v) for v in defined_in) + if not expected: + expected = () + if not isinstance(expected, tuple): + expected = (expected,) + self.assertSetEqual(defined_in_str, set(expected)) + + def assertSameDef(self, first, second): + self.assertHasDefs(first, 1) + self.assertHasDefs(second, 1) + self.assertIs( + anno.getanno(first, anno.Static.DEFINITIONS)[0], + anno.getanno(second, anno.Static.DEFINITIONS)[0]) + + def assertNotSameDef(self, first, second): + self.assertHasDefs(first, 1) + self.assertHasDefs(second, 1) + self.assertIsNot( + anno.getanno(first, anno.Static.DEFINITIONS)[0], + anno.getanno(second, anno.Static.DEFINITIONS)[0]) + + def test_conditional(self): + + def test_fn(a, b): + a = [] + if b: + a = [] + return a + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + self.assertHasDefs(fn_body[0].targets[0], 1) + self.assertHasDefs(fn_body[1].test, 1) + self.assertHasDefs(fn_body[1].body[0].targets[0], 1) + self.assertHasDefs(fn_body[2].value, 2) + + self.assertHasDefinedIn(fn_body[1], ('a', 'b')) + + def test_while(self): + + def test_fn(a): + max(a) + while True: + a = a + a = a + return a + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + self.assertHasDefs(fn_body[0].value.args[0], 1) + self.assertHasDefs(fn_body[1].body[0].targets[0], 1) + self.assertHasDefs(fn_body[1].body[1].targets[0], 1) + self.assertHasDefs(fn_body[1].body[1].value, 1) + # The loop does have an invariant test, but the CFG doesn't know that. + self.assertHasDefs(fn_body[1].body[0].value, 2) + self.assertHasDefs(fn_body[2].value, 2) + + def test_while_else(self): + + def test_fn(x, i): + y = 0 + while x: + x += i + if i: + break + else: + y = 1 + return x, y + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + self.assertHasDefs(fn_body[0].targets[0], 1) + self.assertHasDefs(fn_body[1].test, 2) + self.assertHasDefs(fn_body[1].body[0].target, 1) + self.assertHasDefs(fn_body[1].body[1].test, 1) + self.assertHasDefs(fn_body[1].orelse[0].targets[0], 1) + self.assertHasDefs(fn_body[2].value.elts[0], 2) + self.assertHasDefs(fn_body[2].value.elts[1], 2) + + def test_for_else(self): + + def test_fn(x, i): + y = 0 + for i in x: + x += i + if i: + break + else: + continue + else: + y = 1 + return x, y + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + self.assertHasDefs(fn_body[0].targets[0], 1) + self.assertHasDefs(fn_body[1].target, 1) + self.assertHasDefs(fn_body[1].body[0].target, 1) + self.assertHasDefs(fn_body[1].body[1].test, 1) + self.assertHasDefs(fn_body[1].orelse[0].targets[0], 1) + self.assertHasDefs(fn_body[2].value.elts[0], 2) + self.assertHasDefs(fn_body[2].value.elts[1], 2) + + def test_nested_functions(self): + + def test_fn(a, b): + a = [] + if b: + a = [] + + def foo(): + return a + + foo() + + return a + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + def_of_a_in_if = fn_body[1].body[0].targets[0] + + self.assertHasDefs(fn_body[0].targets[0], 1) + self.assertHasDefs(fn_body[1].test, 1) + self.assertHasDefs(def_of_a_in_if, 1) + self.assertHasDefs(fn_body[2].value, 2) + + inner_fn_body = fn_body[1].body[1].body + self.assertSameDef(inner_fn_body[0].value, def_of_a_in_if) + + def test_nested_functions_isolation(self): + + def test_fn(a): + a = 0 + + def child(): + a = 1 + return a + + child() + return a + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + parent_return = fn_body[3] + child_return = fn_body[1].body[1] + # The assignment `a = 1` makes `a` local to `child`. + self.assertNotSameDef(parent_return.value, child_return.value) + + def test_function_call_in_with(self): + + def foo(_): + pass + + def test_fn(a): + with foo(a): + return a + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + self.assertHasDefs(fn_body[0].items[0].context_expr.func, 0) + self.assertHasDefs(fn_body[0].items[0].context_expr.args[0], 1) + + def test_mutation_subscript(self): + + def test_fn(a): + l = [] + l[0] = a + return l + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + creation = fn_body[0].targets[0] + mutation = fn_body[1].targets[0].value + use = fn_body[2].value + self.assertSameDef(creation, mutation) + self.assertSameDef(creation, use) + + def test_replacement(self): + + def foo(a): + return a + + def test_fn(a): + a = foo(a) + return a + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + param = node.body[0].args.args[0] + source = fn_body[0].value.args[0] + target = fn_body[0].targets[0] + retval = fn_body[1].value + self.assertSameDef(param, source) + self.assertNotSameDef(source, target) + self.assertSameDef(target, retval) + + +if __name__ == '__main__': + test.main() |