diff options
Diffstat (limited to 'tensorflow/python/autograph/pyct/static_analysis/live_values_test.py')
-rw-r--r-- | tensorflow/python/autograph/pyct/static_analysis/live_values_test.py | 132 |
1 files changed, 132 insertions, 0 deletions
diff --git a/tensorflow/python/autograph/pyct/static_analysis/live_values_test.py b/tensorflow/python/autograph/pyct/static_analysis/live_values_test.py new file mode 100644 index 0000000000..882c380b78 --- /dev/null +++ b/tensorflow/python/autograph/pyct/static_analysis/live_values_test.py @@ -0,0 +1,132 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for live_values module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + +from tensorflow.python.autograph.pyct import anno +from tensorflow.python.autograph.pyct import cfg +from tensorflow.python.autograph.pyct import parser +from tensorflow.python.autograph.pyct import qual_names +from tensorflow.python.autograph.pyct import transformer +from tensorflow.python.autograph.pyct.static_analysis import activity +from tensorflow.python.autograph.pyct.static_analysis import live_values +from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions +from tensorflow.python.autograph.pyct.static_analysis import type_info +from tensorflow.python.framework import constant_op +from tensorflow.python.platform import test + + +class LiveValuesResolverTest(test.TestCase): + + def _parse_and_analyze(self, + test_fn, + namespace, + literals=None, + arg_types=None): + literals = literals or {} + node, source = parser.parse_entity(test_fn) + entity_info = transformer.EntityInfo( + source_code=source, + source_file=None, + namespace=namespace, + arg_values=None, + arg_types=arg_types, + owner_type=None) + node = qual_names.resolve(node) + graphs = cfg.build(node) + node = activity.resolve(node, entity_info) + node = reaching_definitions.resolve(node, entity_info, graphs, + reaching_definitions.Definition) + node = live_values.resolve(node, entity_info, literals) + node = type_info.resolve(node, entity_info) + node = live_values.resolve(node, entity_info, literals) + return node + + def test_literals(self): + + a = None + + def test_fn(): + return a + + node = self._parse_and_analyze(test_fn, {}, literals={'a': 'bar'}) + retval_node = node.body[0].body[0].value + self.assertEquals('bar', anno.getanno(retval_node, 'live_val')) + + def test_primitive_values(self): + + a = None + + def test_fn(): + return a + + node = self._parse_and_analyze(test_fn, {'a': True}) + retval_node = node.body[0].body[0].value + if six.PY2: + self.assertEqual( + anno.getanno(retval_node, 'fqn'), ('__builtin__', 'bool')) + else: + self.assertEqual(anno.getanno(retval_node, 'fqn'), ('builtins', 'bool')) + + def test_namespace(self): + + def foo(): + return 'bar' + + def test_fn(): + return foo() + + node = self._parse_and_analyze(test_fn, {'foo': foo}) + func_node = node.body[0].body[0].value.func + self.assertEquals(foo, anno.getanno(func_node, 'live_val')) + self.assertEquals(('foo',), anno.getanno(func_node, 'fqn')) + + def test_attribute_names(self): + + def test_fn(): + return constant_op.constant(0) + + node = self._parse_and_analyze(test_fn, {'constant_op': constant_op}) + func_node = node.body[0].body[0].value.func + self.assertEquals(constant_op.constant, anno.getanno(func_node, 'live_val')) + self.assertEquals((constant_op.__name__, 'constant'), + anno.getanno(func_node, 'fqn')) + + def test_attributes_with_type_hints(self): + + class TestClass(object): + + def member(self): + pass + + def test_fn(self): + return self.member() + + node = self._parse_and_analyze( + TestClass.test_fn, {'constant_op': constant_op}, + arg_types={'self': (TestClass.__name__, TestClass)}) + func_node = node.body[0].body[0].value.func + self.assertEquals(TestClass.member, anno.getanno(func_node, 'live_val')) + self.assertEquals(TestClass, anno.getanno(func_node, 'parent_type')) + self.assertEquals(('TestClass', 'member'), anno.getanno(func_node, 'fqn')) + + +if __name__ == '__main__': + test.main() |