aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/autograph/pyct/static_analysis/type_info_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/autograph/pyct/static_analysis/type_info_test.py')
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/type_info_test.py207
1 files changed, 207 insertions, 0 deletions
diff --git a/tensorflow/python/autograph/pyct/static_analysis/type_info_test.py b/tensorflow/python/autograph/pyct/static_analysis/type_info_test.py
new file mode 100644
index 0000000000..34ba3d2f13
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/static_analysis/type_info_test.py
@@ -0,0 +1,207 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for type_info module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import cfg
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis import activity
+from tensorflow.python.autograph.pyct.static_analysis import live_values
+from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
+from tensorflow.python.autograph.pyct.static_analysis import type_info
+from tensorflow.python.client import session
+from tensorflow.python.platform import test
+from tensorflow.python.training import training
+
+
+class ScopeTest(test.TestCase):
+
+ def test_basic(self):
+ scope = type_info.Scope(None)
+ self.assertFalse(scope.hasval('foo'))
+
+ scope.setval('foo', 'bar')
+ self.assertTrue(scope.hasval('foo'))
+
+ self.assertFalse(scope.hasval('baz'))
+
+ def test_nesting(self):
+ scope = type_info.Scope(None)
+ scope.setval('foo', '')
+
+ child = type_info.Scope(scope)
+ self.assertTrue(child.hasval('foo'))
+ self.assertTrue(scope.hasval('foo'))
+
+ child.setval('bar', '')
+ self.assertTrue(child.hasval('bar'))
+ self.assertFalse(scope.hasval('bar'))
+
+
+class TypeInfoResolverTest(test.TestCase):
+
+ def _parse_and_analyze(self,
+ test_fn,
+ namespace,
+ arg_types=None):
+ node, source = parser.parse_entity(test_fn)
+ entity_info = transformer.EntityInfo(
+ source_code=source,
+ source_file=None,
+ namespace=namespace,
+ arg_values=None,
+ arg_types=arg_types,
+ owner_type=None)
+ node = qual_names.resolve(node)
+ graphs = cfg.build(node)
+ node = activity.resolve(node, entity_info)
+ node = reaching_definitions.resolve(node, entity_info, graphs,
+ reaching_definitions.Definition)
+ node = live_values.resolve(node, entity_info, {})
+ node = type_info.resolve(node, entity_info)
+ node = live_values.resolve(node, entity_info, {})
+ return node
+
+ def test_constructor_detection(self):
+
+ def test_fn():
+ opt = training.GradientDescentOptimizer(0.1)
+ return opt
+
+ node = self._parse_and_analyze(test_fn, {'training': training})
+ call_node = node.body[0].body[0].value
+ self.assertEquals(training.GradientDescentOptimizer,
+ anno.getanno(call_node, 'type'))
+ self.assertEquals((training.__name__, 'GradientDescentOptimizer'),
+ anno.getanno(call_node, 'type_fqn'))
+
+ def test_class_members_of_detected_constructor(self):
+
+ def test_fn():
+ opt = training.GradientDescentOptimizer(0.1)
+ opt.minimize(0)
+
+ node = self._parse_and_analyze(test_fn, {'training': training})
+ method_call = node.body[0].body[1].value.func
+ self.assertEquals(training.GradientDescentOptimizer.minimize,
+ anno.getanno(method_call, 'live_val'))
+
+ def test_class_members_in_with_stmt(self):
+
+ def test_fn(x):
+ with session.Session() as sess:
+ sess.run(x)
+
+ node = self._parse_and_analyze(test_fn, {'session': session})
+ constructor_call = node.body[0].body[0].items[0].context_expr
+ self.assertEquals(session.Session, anno.getanno(constructor_call, 'type'))
+ self.assertEquals((session.__name__, 'Session'),
+ anno.getanno(constructor_call, 'type_fqn'))
+
+ method_call = node.body[0].body[0].body[0].value.func
+ self.assertEquals(session.Session.run, anno.getanno(method_call,
+ 'live_val'))
+
+ def test_constructor_data_dependent(self):
+
+ def test_fn(x):
+ if x > 0:
+ opt = training.GradientDescentOptimizer(0.1)
+ else:
+ opt = training.GradientDescentOptimizer(0.01)
+ opt.minimize(0)
+
+ node = self._parse_and_analyze(test_fn, {'training': training})
+ method_call = node.body[0].body[1].value.func
+ self.assertFalse(anno.hasanno(method_call, 'live_val'))
+
+ def test_parameter_class_members(self):
+
+ def test_fn(opt):
+ opt.minimize(0)
+
+ node = self._parse_and_analyze(test_fn, {})
+ method_call = node.body[0].body[0].value.func
+ self.assertFalse(anno.hasanno(method_call, 'live_val'))
+
+ def test_parameter_class_members_with_value_hints(self):
+
+ def test_fn(opt):
+ opt.minimize(0)
+
+ node = self._parse_and_analyze(
+ test_fn, {},
+ arg_types={
+ 'opt': (training.GradientDescentOptimizer.__name__,
+ training.GradientDescentOptimizer)
+ })
+
+ method_call = node.body[0].body[0].value.func
+ self.assertEquals(training.GradientDescentOptimizer.minimize,
+ anno.getanno(method_call, 'live_val'))
+
+ def test_function_variables(self):
+
+ def bar():
+ pass
+
+ def test_fn():
+ foo = bar
+ foo()
+
+ node = self._parse_and_analyze(test_fn, {'bar': bar})
+ method_call = node.body[0].body[1].value.func
+ self.assertFalse(anno.hasanno(method_call, 'live_val'))
+
+ def test_nested_members(self):
+
+ def test_fn():
+ foo = training.GradientDescentOptimizer(0.1)
+ foo.bar.baz()
+
+ node = self._parse_and_analyze(test_fn, {'training': training})
+ method_call = node.body[0].body[1].value.func
+ self.assertFalse(anno.hasanno(method_call, 'live_val'))
+
+ def test_nested_unpacking(self):
+
+ class Foo(object):
+ pass
+
+ class Bar(object):
+ pass
+
+ def test_fn():
+ a, (b, c) = (Foo(), (Bar(), Foo()))
+ return a, b, c
+
+ node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'Bar': Bar})
+ a, b, c = node.body[0].body[1].value.elts
+ self.assertEquals(anno.getanno(a, 'type'), Foo)
+ self.assertEquals(anno.getanno(b, 'type'), Bar)
+ self.assertEquals(anno.getanno(c, 'type'), Foo)
+ self.assertFalse(anno.hasanno(a, 'live_val'))
+ self.assertFalse(anno.hasanno(b, 'live_val'))
+ self.assertFalse(anno.hasanno(c, 'live_val'))
+
+
+if __name__ == '__main__':
+ test.main()