diff options
Diffstat (limited to 'tensorflow/python/autograph/pyct/static_analysis/type_info_test.py')
-rw-r--r-- | tensorflow/python/autograph/pyct/static_analysis/type_info_test.py | 207 |
1 files changed, 207 insertions, 0 deletions
diff --git a/tensorflow/python/autograph/pyct/static_analysis/type_info_test.py b/tensorflow/python/autograph/pyct/static_analysis/type_info_test.py new file mode 100644 index 0000000000..34ba3d2f13 --- /dev/null +++ b/tensorflow/python/autograph/pyct/static_analysis/type_info_test.py @@ -0,0 +1,207 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for type_info module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.autograph.pyct import anno +from tensorflow.python.autograph.pyct import cfg +from tensorflow.python.autograph.pyct import parser +from tensorflow.python.autograph.pyct import qual_names +from tensorflow.python.autograph.pyct import transformer +from tensorflow.python.autograph.pyct.static_analysis import activity +from tensorflow.python.autograph.pyct.static_analysis import live_values +from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions +from tensorflow.python.autograph.pyct.static_analysis import type_info +from tensorflow.python.client import session +from tensorflow.python.platform import test +from tensorflow.python.training import training + + +class ScopeTest(test.TestCase): + + def test_basic(self): + scope = type_info.Scope(None) + self.assertFalse(scope.hasval('foo')) + + scope.setval('foo', 'bar') + self.assertTrue(scope.hasval('foo')) + + self.assertFalse(scope.hasval('baz')) + + def test_nesting(self): + scope = type_info.Scope(None) + scope.setval('foo', '') + + child = type_info.Scope(scope) + self.assertTrue(child.hasval('foo')) + self.assertTrue(scope.hasval('foo')) + + child.setval('bar', '') + self.assertTrue(child.hasval('bar')) + self.assertFalse(scope.hasval('bar')) + + +class TypeInfoResolverTest(test.TestCase): + + def _parse_and_analyze(self, + test_fn, + namespace, + arg_types=None): + node, source = parser.parse_entity(test_fn) + entity_info = transformer.EntityInfo( + source_code=source, + source_file=None, + namespace=namespace, + arg_values=None, + arg_types=arg_types, + owner_type=None) + node = qual_names.resolve(node) + graphs = cfg.build(node) + node = activity.resolve(node, entity_info) + node = reaching_definitions.resolve(node, entity_info, graphs, + reaching_definitions.Definition) + node = live_values.resolve(node, entity_info, {}) + node = type_info.resolve(node, entity_info) + node = live_values.resolve(node, entity_info, {}) + return node + + def test_constructor_detection(self): + + def test_fn(): + opt = training.GradientDescentOptimizer(0.1) + return opt + + node = self._parse_and_analyze(test_fn, {'training': training}) + call_node = node.body[0].body[0].value + self.assertEquals(training.GradientDescentOptimizer, + anno.getanno(call_node, 'type')) + self.assertEquals((training.__name__, 'GradientDescentOptimizer'), + anno.getanno(call_node, 'type_fqn')) + + def test_class_members_of_detected_constructor(self): + + def test_fn(): + opt = training.GradientDescentOptimizer(0.1) + opt.minimize(0) + + node = self._parse_and_analyze(test_fn, {'training': training}) + method_call = node.body[0].body[1].value.func + self.assertEquals(training.GradientDescentOptimizer.minimize, + anno.getanno(method_call, 'live_val')) + + def test_class_members_in_with_stmt(self): + + def test_fn(x): + with session.Session() as sess: + sess.run(x) + + node = self._parse_and_analyze(test_fn, {'session': session}) + constructor_call = node.body[0].body[0].items[0].context_expr + self.assertEquals(session.Session, anno.getanno(constructor_call, 'type')) + self.assertEquals((session.__name__, 'Session'), + anno.getanno(constructor_call, 'type_fqn')) + + method_call = node.body[0].body[0].body[0].value.func + self.assertEquals(session.Session.run, anno.getanno(method_call, + 'live_val')) + + def test_constructor_data_dependent(self): + + def test_fn(x): + if x > 0: + opt = training.GradientDescentOptimizer(0.1) + else: + opt = training.GradientDescentOptimizer(0.01) + opt.minimize(0) + + node = self._parse_and_analyze(test_fn, {'training': training}) + method_call = node.body[0].body[1].value.func + self.assertFalse(anno.hasanno(method_call, 'live_val')) + + def test_parameter_class_members(self): + + def test_fn(opt): + opt.minimize(0) + + node = self._parse_and_analyze(test_fn, {}) + method_call = node.body[0].body[0].value.func + self.assertFalse(anno.hasanno(method_call, 'live_val')) + + def test_parameter_class_members_with_value_hints(self): + + def test_fn(opt): + opt.minimize(0) + + node = self._parse_and_analyze( + test_fn, {}, + arg_types={ + 'opt': (training.GradientDescentOptimizer.__name__, + training.GradientDescentOptimizer) + }) + + method_call = node.body[0].body[0].value.func + self.assertEquals(training.GradientDescentOptimizer.minimize, + anno.getanno(method_call, 'live_val')) + + def test_function_variables(self): + + def bar(): + pass + + def test_fn(): + foo = bar + foo() + + node = self._parse_and_analyze(test_fn, {'bar': bar}) + method_call = node.body[0].body[1].value.func + self.assertFalse(anno.hasanno(method_call, 'live_val')) + + def test_nested_members(self): + + def test_fn(): + foo = training.GradientDescentOptimizer(0.1) + foo.bar.baz() + + node = self._parse_and_analyze(test_fn, {'training': training}) + method_call = node.body[0].body[1].value.func + self.assertFalse(anno.hasanno(method_call, 'live_val')) + + def test_nested_unpacking(self): + + class Foo(object): + pass + + class Bar(object): + pass + + def test_fn(): + a, (b, c) = (Foo(), (Bar(), Foo())) + return a, b, c + + node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'Bar': Bar}) + a, b, c = node.body[0].body[1].value.elts + self.assertEquals(anno.getanno(a, 'type'), Foo) + self.assertEquals(anno.getanno(b, 'type'), Bar) + self.assertEquals(anno.getanno(c, 'type'), Foo) + self.assertFalse(anno.hasanno(a, 'live_val')) + self.assertFalse(anno.hasanno(b, 'live_val')) + self.assertFalse(anno.hasanno(c, 'live_val')) + + +if __name__ == '__main__': + test.main() |