aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/autograph/pyct/static_analysis/live_values_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/autograph/pyct/static_analysis/live_values_test.py')
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/live_values_test.py132
1 files changed, 132 insertions, 0 deletions
diff --git a/tensorflow/python/autograph/pyct/static_analysis/live_values_test.py b/tensorflow/python/autograph/pyct/static_analysis/live_values_test.py
new file mode 100644
index 0000000000..882c380b78
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/static_analysis/live_values_test.py
@@ -0,0 +1,132 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for live_values module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import six
+
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import cfg
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis import activity
+from tensorflow.python.autograph.pyct.static_analysis import live_values
+from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
+from tensorflow.python.autograph.pyct.static_analysis import type_info
+from tensorflow.python.framework import constant_op
+from tensorflow.python.platform import test
+
+
+class LiveValuesResolverTest(test.TestCase):
+
+ def _parse_and_analyze(self,
+ test_fn,
+ namespace,
+ literals=None,
+ arg_types=None):
+ literals = literals or {}
+ node, source = parser.parse_entity(test_fn)
+ entity_info = transformer.EntityInfo(
+ source_code=source,
+ source_file=None,
+ namespace=namespace,
+ arg_values=None,
+ arg_types=arg_types,
+ owner_type=None)
+ node = qual_names.resolve(node)
+ graphs = cfg.build(node)
+ node = activity.resolve(node, entity_info)
+ node = reaching_definitions.resolve(node, entity_info, graphs,
+ reaching_definitions.Definition)
+ node = live_values.resolve(node, entity_info, literals)
+ node = type_info.resolve(node, entity_info)
+ node = live_values.resolve(node, entity_info, literals)
+ return node
+
+ def test_literals(self):
+
+ a = None
+
+ def test_fn():
+ return a
+
+ node = self._parse_and_analyze(test_fn, {}, literals={'a': 'bar'})
+ retval_node = node.body[0].body[0].value
+ self.assertEquals('bar', anno.getanno(retval_node, 'live_val'))
+
+ def test_primitive_values(self):
+
+ a = None
+
+ def test_fn():
+ return a
+
+ node = self._parse_and_analyze(test_fn, {'a': True})
+ retval_node = node.body[0].body[0].value
+ if six.PY2:
+ self.assertEqual(
+ anno.getanno(retval_node, 'fqn'), ('__builtin__', 'bool'))
+ else:
+ self.assertEqual(anno.getanno(retval_node, 'fqn'), ('builtins', 'bool'))
+
+ def test_namespace(self):
+
+ def foo():
+ return 'bar'
+
+ def test_fn():
+ return foo()
+
+ node = self._parse_and_analyze(test_fn, {'foo': foo})
+ func_node = node.body[0].body[0].value.func
+ self.assertEquals(foo, anno.getanno(func_node, 'live_val'))
+ self.assertEquals(('foo',), anno.getanno(func_node, 'fqn'))
+
+ def test_attribute_names(self):
+
+ def test_fn():
+ return constant_op.constant(0)
+
+ node = self._parse_and_analyze(test_fn, {'constant_op': constant_op})
+ func_node = node.body[0].body[0].value.func
+ self.assertEquals(constant_op.constant, anno.getanno(func_node, 'live_val'))
+ self.assertEquals((constant_op.__name__, 'constant'),
+ anno.getanno(func_node, 'fqn'))
+
+ def test_attributes_with_type_hints(self):
+
+ class TestClass(object):
+
+ def member(self):
+ pass
+
+ def test_fn(self):
+ return self.member()
+
+ node = self._parse_and_analyze(
+ TestClass.test_fn, {'constant_op': constant_op},
+ arg_types={'self': (TestClass.__name__, TestClass)})
+ func_node = node.body[0].body[0].value.func
+ self.assertEquals(TestClass.member, anno.getanno(func_node, 'live_val'))
+ self.assertEquals(TestClass, anno.getanno(func_node, 'parent_type'))
+ self.assertEquals(('TestClass', 'member'), anno.getanno(func_node, 'fqn'))
+
+
+if __name__ == '__main__':
+ test.main()