aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-06 19:57:12 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-06 20:01:13 -0800
commit4380d6eff899ca2f5e14d4d92f7fcf770b36b099 (patch)
tree56abf421c7591e3723c76a4ec1a1cae81752edaf
parentecbb8b1ccac295537827dfe1ca25ddb03ca5f22b (diff)
Add basic support for explicit type annotations. This is done by inserting a no-op function call. Note that this is meant as a fallback; we prefer the following alternatives (in this order) for inferring the type:
1. Automatic inference from context, e.g. the type of a list based on the elements added to it (WIP)
2. Type annotations (Python 3.6+ only)
PiperOrigin-RevId: 188120527
-rw-r--r--tensorflow/contrib/py2tf/impl/conversion.py37
-rw-r--r--tensorflow/contrib/py2tf/pyct/context.py6
-rw-r--r--tensorflow/contrib/py2tf/pyct/static_analysis/BUILD1
-rw-r--r--tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py63
-rw-r--r--tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py25
-rw-r--r--tensorflow/contrib/py2tf/utils/BUILD1
-rw-r--r--tensorflow/contrib/py2tf/utils/__init__.py1
-rw-r--r--tensorflow/contrib/py2tf/utils/type_hints.py41
8 files changed, 153 insertions, 22 deletions
diff --git a/tensorflow/contrib/py2tf/impl/conversion.py b/tensorflow/contrib/py2tf/impl/conversion.py
index c6f4988375..97ee4ca435 100644
--- a/tensorflow/contrib/py2tf/impl/conversion.py
+++ b/tensorflow/contrib/py2tf/impl/conversion.py
@@ -41,6 +41,7 @@ from tensorflow.contrib.py2tf.pyct import qual_names
from tensorflow.contrib.py2tf.pyct.static_analysis import activity
from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
+from tensorflow.contrib.py2tf.utils import type_hints
from tensorflow.python.util import tf_inspect
@@ -48,7 +49,9 @@ from tensorflow.python.util import tf_inspect
class ConversionMap(object):
- """ConversionMaps keep track of converting function hierarchies.
+ """ConversionMap keeps track of converting function hierarchies.
+
+ This object is mutable, and is updated as functions are converted.
Attributes:
recursive: Whether to recursively convert any functions that the decorator
@@ -154,14 +157,20 @@ def entity_to_graph(o, conversion_map, arg_values, arg_types):
conversion_map.add_to_cache(o, node)
if conversion_map.recursive:
- for obj in conversion_map.name_map.keys():
- if obj not in conversion_map.dependency_cache:
- if (hasattr(obj, 'im_class') and
- getattr(obj, 'im_class') not in conversion_map.partial_types):
- # Class members are converted with their objects, unless they're
- # only converted partially.
- continue
- entity_to_graph(obj, conversion_map, {}, {})
+ while True:
+ candidate = None
+ for obj in conversion_map.name_map.keys():
+ if obj not in conversion_map.dependency_cache:
+ candidate = obj
+ break
+ if candidate is None:
+ break
+ if (hasattr(candidate, 'im_class') and
+ getattr(candidate, 'im_class') not in conversion_map.partial_types):
+ # Class members are converted with their objects, unless they're
+ # only converted partially.
+ continue
+ entity_to_graph(candidate, conversion_map, {}, {})
return node, new_name
@@ -169,9 +178,10 @@ def entity_to_graph(o, conversion_map, arg_values, arg_types):
def class_to_graph(c, conversion_map):
"""Specialization of `entity_to_graph` for classes."""
converted_members = {}
- members = tf_inspect.getmembers(c, predicate=tf_inspect.ismethod)
+ method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m)
+ members = tf_inspect.getmembers(c, predicate=method_filter)
if not members:
- raise ValueError('Cannot convert %s: it has no member methods.')
+ raise ValueError('Cannot convert %s: it has no member methods.' % c)
class_namespace = None
for _, m in members:
@@ -191,7 +201,7 @@ def class_to_graph(c, conversion_map):
class_name,
bases=[],
keywords=[],
- body=converted_members.values(),
+ body=list(converted_members.values()),
decorator_list=[])
return node, class_name
@@ -233,7 +243,8 @@ def function_to_graph(f, conversion_map, arg_values, arg_types,
arg_values=arg_values,
arg_types=arg_types,
owner_type=owner_type,
- recursive=conversion_map.recursive)
+ recursive=conversion_map.recursive,
+ type_annotation_func=type_hints.set_element_type)
node, deps = node_to_graph(node, ctx, conversion_map.nocompile_decorators)
# TODO(mdan): This somewhat duplicates the call rename logic in call_treest.py
diff --git a/tensorflow/contrib/py2tf/pyct/context.py b/tensorflow/contrib/py2tf/pyct/context.py
index 4fcf2a687d..b34015cfd2 100644
--- a/tensorflow/contrib/py2tf/pyct/context.py
+++ b/tensorflow/contrib/py2tf/pyct/context.py
@@ -22,6 +22,8 @@ from __future__ import print_function
class EntityContext(object):
"""Contains information about an entity, like source code.
+ In general, objects of this class should be considered immutable.
+
Attributes:
namer: Namer that matches the contract of all converters.
source_code: The entity's source code.
@@ -33,8 +35,9 @@ class EntityContext(object):
owner_type: The surrounding class type of the function, if present.
"""
+ # TODO(mdan): Remove the default and update tests.
def __init__(self, namer, source_code, source_file, namespace, arg_values,
- arg_types, owner_type, recursive):
+ arg_types, owner_type, recursive, type_annotation_func=None):
self.namer = namer
self.source_code = source_code
self.source_file = source_file
@@ -43,3 +46,4 @@ class EntityContext(object):
self.arg_types = {} if arg_types is None else arg_types
self.owner_type = owner_type
self.recursive = recursive
+ self.type_annotation_func = type_annotation_func
diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/BUILD b/tensorflow/contrib/py2tf/pyct/static_analysis/BUILD
index fbfce18c60..2799b56a00 100644
--- a/tensorflow/contrib/py2tf/pyct/static_analysis/BUILD
+++ b/tensorflow/contrib/py2tf/pyct/static_analysis/BUILD
@@ -60,6 +60,7 @@ py_test(
deps = [
":static_analysis",
"//tensorflow/contrib/py2tf/pyct",
+ "//tensorflow/contrib/py2tf/utils",
"//tensorflow/python:client_testlib",
],
)
diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py
index 8203bda0f9..5556a58c02 100644
--- a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py
+++ b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py
@@ -14,9 +14,29 @@
# ==============================================================================
"""Type resolution.
+This analyzer uses known live values to further infer object types. This
+may include for instance constructed objects and object member functions.
+
+In addition, the analyzer also processes TF (staged) type annotations that
+were set explicitly through the type annotation function.
+
Requires annotations generated by LiveValuesResolver.
"""
+# TODO(mdan): This would be more robust with a CFG.
+# Situations with multiple reaching modifications (e.g. modified inside and
+# outside a control flow statement) should be more robustly detected and
+# analyzed.
+
+# TODO(mdan): Look into using Python AST's type annotation fields instead.
+# It would be desirable to use that mechanism if we can.
+# Some caveats to consider: We may need to annotate other nodes like
+# Attribute. It may also not be feasible for us to faithfully replicate
+# PY3's type annotations where it isn't available. It would also require us
+# to design rigorous type definitions that can accommodate Python types
+# as well as TensorFlow dtypes and shapes.
+
+
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -29,7 +49,7 @@ from tensorflow.python.util import tf_inspect
class Scope(object):
- """Encloses symbol value references.
+ """Tracks symbol value references.
Attributes:
values: A dict mapping string to gast.Node, containing the value that was
@@ -138,11 +158,14 @@ class TypeInfoResolver(transformer.Base):
elif isinstance(node.ctx, gast.Load) and self.scope.hasval(qn):
# E.g. if we had
# a = b
- # then for future references to `a` we should have traced_source = `b`
- traced_source = self.scope.getval(qn)
- if anno.hasanno(traced_source, 'type'):
- anno.setanno(node, 'type', anno.getanno(traced_source, 'type'))
- anno.setanno(node, 'type_fqn', anno.getanno(traced_source, 'type_fqn'))
+ # then for future references to `a` we should have definition = `b`
+ definition = self.scope.getval(qn)
+ if anno.hasanno(definition, 'type'):
+ anno.setanno(node, 'type', anno.getanno(definition, 'type'))
+ anno.setanno(node, 'type_fqn', anno.getanno(definition, 'type_fqn'))
+ if anno.hasanno(definition, 'element_type'):
+ anno.setanno(node, 'element_type',
+ anno.getanno(definition, 'element_type'))
return node
def _process_variable_assignment(self, source, targets):
@@ -181,6 +204,34 @@ class TypeInfoResolver(transformer.Base):
self._process_variable_assignment(node.value, node.targets)
return node
+ def visit_Call(self, node):
+ if anno.hasanno(node.func, 'live_val'):
+ # Symbols targeted by the type annotation marker function (the context's
+ # type_annotation_func) are annotated with the element type it specifies.
+ if (anno.getanno(node.func, 'live_val') is
+ self.context.type_annotation_func):
+ # Expecting the actual type to be the second argument.
+ if len(node.args) != 2:
+ raise ValueError('"%s" must have exactly two parameters'
+ % self.context.type_annotation_func)
+ if not anno.hasanno(node.args[0], anno.Basic.QN):
+ raise ValueError('the first argument of "%s" must by a symbol'
+ % self.context.type_annotation_func)
+ if not anno.hasanno(node.args[1], 'live_val'):
+ raise ValueError(
+ 'the second argument of "%s" must be statically resolvable' %
+ self.context.type_annotation_func)
+ target_symbol = anno.getanno(node.args[0], anno.Basic.QN)
+ element_type = anno.getanno(node.args[1], 'live_val')
+ # Find the definition of this symbol and annotate it with the given
+ # data type. That in turn will cause future uses of the symbol
+ # to receive the same type annotation.
+ definition = self.scope.getval(target_symbol)
+ anno.setanno(node, 'element_type', element_type)
+ anno.setanno(definition, 'element_type', element_type)
+ # TODO(mdan): Should we update references between definition and here?
+ return self.generic_visit(node)
+
def resolve(node, context):
return TypeInfoResolver(context).visit(node)
diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py
index a3e78202c8..0d9d5a85f0 100644
--- a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py
+++ b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.py2tf import utils
from tensorflow.contrib.py2tf.pyct import anno
from tensorflow.contrib.py2tf.pyct import context
from tensorflow.contrib.py2tf.pyct import parser
@@ -56,7 +57,10 @@ class ScopeTest(test.TestCase):
class TypeInfoResolverTest(test.TestCase):
- def _parse_and_analyze(self, test_fn, namespace, arg_types=None):
+ def _parse_and_analyze(self,
+ test_fn,
+ namespace,
+ arg_types=None):
node, source = parser.parse_entity(test_fn)
ctx = context.EntityContext(
namer=None,
@@ -66,7 +70,8 @@ class TypeInfoResolverTest(test.TestCase):
arg_values=None,
arg_types=arg_types,
owner_type=None,
- recursive=True)
+ recursive=True,
+ type_annotation_func=utils.set_element_type)
node = qual_names.resolve(node)
node = activity.resolve(node, ctx)
node = live_values.resolve(node, ctx, {})
@@ -175,6 +180,22 @@ class TypeInfoResolverTest(test.TestCase):
method_call = node.body[0].body[1].value.func
self.assertFalse(anno.hasanno(method_call, 'live_val'))
+ def test_type_annotation(self):
+
+ class Foo(object):
+ pass
+
+ def test_fn():
+ f = []
+ f = utils.set_element_type(f, Foo)
+ return f
+
+ node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'utils': utils})
+ f_def = node.body[0].body[0].value
+ self.assertEqual(anno.getanno(f_def, 'element_type'), Foo)
+ f_ref = node.body[0].body[1].value
+ self.assertEqual(anno.getanno(f_ref, 'element_type'), Foo)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/py2tf/utils/BUILD b/tensorflow/contrib/py2tf/utils/BUILD
index 63261d5043..c6a894b508 100644
--- a/tensorflow/contrib/py2tf/utils/BUILD
+++ b/tensorflow/contrib/py2tf/utils/BUILD
@@ -28,6 +28,7 @@ py_library(
"tensor_list.py",
"testing.py",
"type_check.py",
+ "type_hints.py",
],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"],
diff --git a/tensorflow/contrib/py2tf/utils/__init__.py b/tensorflow/contrib/py2tf/utils/__init__.py
index 313e5c97cc..997c815887 100644
--- a/tensorflow/contrib/py2tf/utils/__init__.py
+++ b/tensorflow/contrib/py2tf/utils/__init__.py
@@ -27,3 +27,4 @@ from tensorflow.contrib.py2tf.utils.multiple_dispatch import run_while
from tensorflow.contrib.py2tf.utils.py_func import wrap_py_func
from tensorflow.contrib.py2tf.utils.testing import fake_tf
from tensorflow.contrib.py2tf.utils.type_check import is_tensor
+from tensorflow.contrib.py2tf.utils.type_hints import set_element_type
diff --git a/tensorflow/contrib/py2tf/utils/type_hints.py b/tensorflow/contrib/py2tf/utils/type_hints.py
new file mode 100644
index 0000000000..aeb9e54561
--- /dev/null
+++ b/tensorflow/contrib/py2tf/utils/type_hints.py
@@ -0,0 +1,41 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""No-op utilities that provide static type hints.
+
+These are used when the data type is not known at creation, for instance in the
+case of empty lists.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+def set_element_type(entity, dtype, shape=None):
+ """Indicates that the entity is expected to hold items of the specified type.
+
+ This function is a no-op. Its presence merely marks the data type of its
+ argument. The staged TensorFlow ops will reflect and assert this data type.
+
+ Args:
+ entity: A Tensor or TensorArray.
+ dtype: TensorFlow dtype value to assert for entity.
+ shape: Optional shape to assert for entity.
+ Returns:
+ The value of entity, unchanged.
+ """
+ del dtype
+ del shape
+ return entity