author Dan Moldovan <mdan@google.com> 2018-06-08 11:20:56 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-06-08 11:24:23 -0700
commit 7eaf8941930c8b1a099b7ec626134b67179c07e3 (patch)
tree 0c550be1da7c0de61f5d61eb0f8d2aeeb2518f1e
parent ebb67e0d7da53b3b848630e63aaa80f1283d83bd (diff)
Use the new operators for list conversion. Includes list creation, append, pop, stack. Simplify the type annotation mechanism by having it literally copy its arguments, instead of attempting to resolve them.
PiperOrigin-RevId: 199822771
-rw-r--r--  tensorflow/contrib/autograph/converters/lists.py                       233
-rw-r--r--  tensorflow/contrib/autograph/converters/lists_test.py                  130
-rw-r--r--  tensorflow/contrib/autograph/pyct/static_analysis/type_info.py          40
-rw-r--r--  tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py     18
4 files changed, 291 insertions, 130 deletions
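
For orientation, here is a hand-written sketch of the rewrite this commit performs, pieced together from the templates in lists.py below. It is not generated output: the ag__ calls mirror the templates, and the temporary name l_pop is illustrative, since the real symbol is chosen by the namer.

    # Original user code:
    def f(x):
      l = []
      l.append(x)
      s = l.pop()
      return s, l

    # Approximate converted form, following the lists.py templates:
    def f(x):
      l = ag__.new_list([])
      l = ag__.list_append(l, x)
      l, l_pop = ag__.list_pop(
          l, None,
          opts=ag__.ListPopOpts(element_dtype=None, element_shape=None))
      s = l_pop
      return s, l
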
diff --git a/tensorflow/contrib/autograph/converters/lists.py b/tensorflow/contrib/autograph/converters/lists.py
index b49521b2c3..c15dfff9e8 100644
--- a/tensorflow/contrib/autograph/converters/lists.py
+++ b/tensorflow/contrib/autograph/converters/lists.py
@@ -33,82 +33,193 @@ from __future__ import print_function
import gast
from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.pyct import templates
from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.python.framework import dtypes
+from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
+
+
+# Tags for local state.
+POP_USES = 'pop_uses'
class ListTransformer(transformer.Base):
"""Converts lists and related operations to their TF counterpart."""
- def _empty_list(self, node):
- if not anno.hasanno(node, 'element_type'):
- raise NotImplementedError(
- 'type inference for empty lists is not yet supported; '
- 'use set_element_type(<list>, <dtype>) to continue')
- dtype = anno.getanno(node, 'element_type')
- if not isinstance(dtype, dtypes.DType):
- # TODO(mdan): Allow non-TF dtypes?
- # That would be consistent with the dynamic dispatch pattern, but
- # we must make sure that doesn't become confusing.
- raise NotImplementedError('element type "%s" not yet supported' % dtype)
-
- dtype_name = dtype.name
- # TODO(mdan): Does it ever make sense not to use tensor lists?
+ def visit_List(self, node):
+ node = self.generic_visit(node)
template = """
- tf.TensorArray(tf.dtype_name, size=0, dynamic_size=True)
+ ag__.new_list(elements)
"""
- return templates.replace_as_expression(template, dtype_name=dtype_name)
+ return templates.replace_as_expression(template, elements=node)
- def _pre_populated_list(self, node):
- raise NotImplementedError('pre-populated lists')
+ def _replace_append_call(self, node):
+ assert len(node.args) == 1
+ assert isinstance(node.func, gast.Attribute)
+ template = """
+ target = ag__.list_append(target, element)
+ """
+ return templates.replace(
+ template,
+ target=node.func.value,
+ element=node.args[0])
+
+ def _replace_pop_call(self, node):
+ # Expressions that use pop() are converted to a statement + expression.
+ #
+ # For example:
+ #
+ # print(target.pop())
+ #
+ # ... is converted to:
+ #
+ # target, target_pop = ag__.list_pop(target)
+ # print(target_pop)
+ #
+ # Here, we just generate the variable name and swap it in,
+ # and _generate_pop_operation will handle the rest.
+ #
+ # Multiple uses of pop() are allowed:
+ #
+  #   print(target.pop(), target.pop())
+  #   print(target.pop().pop())
+ #
+ assert isinstance(node.func, gast.Attribute)
+ scope = anno.getanno(node, NodeAnno.ARGS_SCOPE)
+ target_node = node.func.value
+
+    # Attempt to use a related name if we can get one. Otherwise use something
+    # generic.
+ if anno.hasanno(target_node, anno.Basic.QN):
+ target_name = anno.getanno(target_node, anno.Basic.QN).ssf()
+ else:
+ target_name = 'list'
+ pop_var_name = self.context.namer.new_symbol(target_name, scope.referenced)
+
+ pop_uses = self.get_local(POP_USES, [])
+ pop_uses.append((node, pop_var_name))
+ self.set_local(POP_USES, pop_uses)
+
+ return templates.replace_as_expression('var_name', var_name=pop_var_name)
+
+ def _replace_stack_call(self, node):
+ assert len(node.args) == 1
+ dtype = anno.getanno(
+ node.args[0],
+ 'element_type',
+ default=templates.replace_as_expression('None'))
+ template = """
+ ag__.list_stack(
+ target,
+ opts=ag__.ListStackOpts(
+ element_dtype=dtype,
+ original_call=orig_call))
+ """
+ return templates.replace_as_expression(
+ template,
+ dtype=dtype,
+ target=node.args[0],
+ orig_call=node.func)
- def visit_Expr(self, node):
+ def visit_Call(self, node):
node = self.generic_visit(node)
- if isinstance(node.value, gast.Call):
- call_node = node.value
-
- if not anno.hasanno(call_node.func, anno.Basic.QN):
- return node
- qn = anno.getanno(call_node.func, anno.Basic.QN)
-
- if qn.qn[-1] == 'append' and (len(call_node.args) == 1):
- template = """
- target = ag__.utils.dynamic_list_append(target, element)
- """
- node = templates.replace(
- template,
- target=qn.parent.ast(),
- element=call_node.args[0])
+
+ # TODO(mdan): This is insufficient if target is a function argument.
+ # In the case of function arguments, we need to add the list to the
+ # function's return value, because it is being modified.
+ # TODO(mdan): Checking just the name is brittle, can it be improved?
+ if isinstance(node.func, gast.Attribute):
+ func_name = node.func.attr
+ if func_name == 'append' and (len(node.args) == 1):
+ node = self._replace_append_call(node)
+ elif func_name == 'pop' and (len(node.args) <= 1):
+ node = self._replace_pop_call(node)
+ elif func_name == 'stack' and (len(node.args) == 1):
+ node = self._replace_stack_call(node)
+
return node
- def _replace_list_constructors(self, targets, values):
- for target in targets:
- if (isinstance(target, (gast.Tuple, gast.List)) and
- isinstance(values, (gast.Tuple, gast.List))):
- n_targets = len(target.elts)
- for i in range(n_targets):
- target_el, value_el = target.elts[i], values.elts[i]
- values.elts[i] = self._replace_list_constructors(
- (target_el,), value_el)
- return values
- if isinstance(values, gast.List):
- if values.elts:
- return self._pre_populated_list(values)
- else:
- return self._empty_list(values)
- return values
-
- def visit_Assign(self, node):
- node = self.generic_visit(node)
+ def _generate_pop_operation(self, original_call_node, pop_var_name):
+ assert isinstance(original_call_node.func, gast.Attribute)
+
+ if original_call_node.args:
+ pop_element = original_call_node.args[0]
+ else:
+ pop_element = parser.parse_expression('None')
+ # The call will be something like "target.pop()", and the dtype is hooked to
+ # target, hence the func.value.
+ dtype = anno.getanno(
+ original_call_node.func.value,
+ 'element_type',
+ default=templates.replace_as_expression('None'))
+ shape = anno.getanno(
+ original_call_node.func.value,
+ 'element_shape',
+ default=templates.replace_as_expression('None'))
+
+ template = """
+ target, pop_var_name = ag__.list_pop(
+ target, element,
+ opts=ag__.ListPopOpts(element_dtype=dtype, element_shape=shape))
+ """
+ return templates.replace(
+ template,
+ target=original_call_node.func.value,
+ pop_var_name=pop_var_name,
+ element=pop_element,
+ dtype=dtype,
+ shape=shape)
+
+ def _postprocess_statement(self, node):
+ """Inserts any separate pop() calls that node may use."""
+ pop_uses = self.get_local(POP_USES, None)
+ if pop_uses:
+ replacements = []
+ for original_call_node, pop_var_name in pop_uses:
+ replacements.extend(
+ self._generate_pop_operation(original_call_node, pop_var_name))
+ replacements.append(node)
+ node = replacements
+ self.exit_local_scope()
+ return node, None
+
+ # TODO(mdan): Should we have a generic visit_block instead?
+ # Right now it feels that a visit_block would add too much magic that's
+ # hard to follow.
+
+ def _visit_and_process_block(self, block):
+ return self.visit_block(
+ block,
+ before_visit=self.enter_local_scope,
+ after_visit=self._postprocess_statement)
+
+ def visit_FunctionDef(self, node):
+ node.args = self.generic_visit(node.args)
+ node.decorator_list = self.visit_block(node.decorator_list)
+ node.body = self._visit_and_process_block(node.body)
+ return node
+
+ def visit_For(self, node):
+ node.target = self.visit(node.target)
+ node.body = self._visit_and_process_block(node.body)
+ node.orelse = self._visit_and_process_block(node.orelse)
+ return node
+
+ def visit_While(self, node):
+ node.test = self.visit(node.test)
+ node.body = self._visit_and_process_block(node.body)
+ node.orelse = self._visit_and_process_block(node.orelse)
+ return node
+
+ def visit_If(self, node):
+ node.test = self.visit(node.test)
+ node.body = self._visit_and_process_block(node.body)
+ node.orelse = self._visit_and_process_block(node.orelse)
+ return node
- # Only convert lists when they are assigned to a variable, e.g.:
- # l = []
- # TODO(mdan): A similar pattern exists in type_info.py
- # We should add a generic "unpack_assignment" function to the base
- # transformer, that has the same effect as applying some logic to the SSA
- # form.
- node.value = self._replace_list_constructors(node.targets, node.value)
+ def visit_With(self, node):
+ node.items = self.visit_block(node.items)
+ node.body = self._visit_and_process_block(node.body)
return node
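
As a companion to the converter above, a rough sketch of how a stack() call is rewritten when the list carries a set_element_type annotation. The shape of the call follows the _replace_stack_call template; element_dtype falls back to None when no annotation is present, and original_call presumably lets the runtime fall back to the user's own function when the target turns out not to be a tensor list.

    # Original user code, where l was annotated with set_element_type(l, dtypes.int32):
    s = tf.stack(l)

    # Approximate converted form:
    s = ag__.list_stack(
        l,
        opts=ag__.ListStackOpts(
            element_dtype=dtypes.int32,
            original_call=tf.stack))
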
diff --git a/tensorflow/contrib/autograph/converters/lists_test.py b/tensorflow/contrib/autograph/converters/lists_test.py
index 74c6dc64f1..9f18ab9f44 100644
--- a/tensorflow/contrib/autograph/converters/lists_test.py
+++ b/tensorflow/contrib/autograph/converters/lists_test.py
@@ -22,74 +22,126 @@ from tensorflow.contrib.autograph import utils
from tensorflow.contrib.autograph.converters import converter_test_base
from tensorflow.contrib.autograph.converters import lists
from tensorflow.python.framework import dtypes
-from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import list_ops
from tensorflow.python.platform import test
class ListTest(converter_test_base.TestCase):
- def test_empty_annotated_list(self):
+ def test_empty_list(self):
def test_fn():
- l = []
- utils.set_element_type(l, dtypes.int32)
- l.append(1)
- return l
+ return []
- node = self.parse_and_analyze(test_fn, {'dtypes': dtypes, 'utils': utils})
+ node = self.parse_and_analyze(test_fn, {})
node = lists.transform(node, self.ctx)
- with self.compiled(node, tensor_array_ops.TensorArray,
- dtypes.int32) as result:
- # TODO(mdan): Attach these additional modules automatically.
- result.utils = utils
- result.dtypes = dtypes
+ with self.compiled(node) as result:
+ tl = result.test_fn()
+ # Empty tensor lists cannot be evaluated or stacked.
+ self.assertTrue(isinstance(tl, ops.Tensor))
+ self.assertEqual(tl.dtype, dtypes.variant)
+
+ def test_initialized_list(self):
+
+ def test_fn():
+ return [1, 2, 3]
+
+ node = self.parse_and_analyze(test_fn, {})
+ node = lists.transform(node, self.ctx)
+
+ with self.compiled(node) as result:
with self.test_session() as sess:
- self.assertAllEqual([1], sess.run(result.test_fn().stack()))
+ tl = result.test_fn()
+ r = list_ops.tensor_list_stack(tl, dtypes.int32)
+ self.assertAllEqual(sess.run(r), [1, 2, 3])
- def test_empty_annotated_lists_unpacked(self):
+ def test_list_append(self):
def test_fn():
- l, m = [], []
- utils.set_element_type(l, dtypes.int32)
- utils.set_element_type(m, dtypes.int32)
- l.append(1)
- m.append(2)
- return l, m
+ l = [1]
+ l.append(2)
+ l.append(3)
+ return l
- node = self.parse_and_analyze(test_fn, {'dtypes': dtypes, 'utils': utils})
+ node = self.parse_and_analyze(test_fn, {})
node = lists.transform(node, self.ctx)
- with self.compiled(node, tensor_array_ops.TensorArray,
- dtypes.int32) as result:
+ with self.compiled(node) as result:
+ with self.test_session() as sess:
+ tl = result.test_fn()
+ r = list_ops.tensor_list_stack(tl, dtypes.int32)
+ self.assertAllEqual(sess.run(r), [1, 2, 3])
+
+ def test_list_pop(self):
+
+ def test_fn():
+ l = [1, 2, 3]
+ utils.set_element_type(l, dtypes.int32, ())
+ s = l.pop()
+ return s, l
+
+ node = self.parse_and_analyze(
+ test_fn,
+ {
+ 'utils': utils,
+ 'dtypes': dtypes
+ },
+ include_type_analysis=True,
+ )
+ node = lists.transform(node, self.ctx)
+
+ with self.compiled(node) as result:
result.utils = utils
result.dtypes = dtypes
with self.test_session() as sess:
- res_l, res_m = result.test_fn()
- self.assertEqual([1], sess.run(res_l.stack()))
- self.assertEqual([2], sess.run(res_m.stack()))
+ ts, tl = result.test_fn()
+ r = list_ops.tensor_list_stack(tl, dtypes.int32)
+ self.assertAllEqual(sess.run(r), [1, 2])
+ self.assertAllEqual(sess.run(ts), 3)
+
+ def test_double_list_pop(self):
- def test_empty_annotated_lists_list_unpacked(self):
+ def test_fn(l):
+ s = l.pop().pop()
+ return s
+
+ node = self.parse_and_analyze(test_fn, {})
+ node = lists.transform(node, self.ctx)
+
+ with self.compiled(node) as result:
+ test_input = [1, 2, [1, 2, 3]]
+ # TODO(mdan): Pass a list of lists of tensor when we fully support that.
+      # For now, we just pass a regular Python list of lists to verify that
+      # the two pop calls are sequenced properly.
+ self.assertAllEqual(result.test_fn(test_input), 3)
+
+ def test_list_stack(self):
+
+ tf = None # Will be replaced with a mock.
def test_fn():
- [l, m] = [], []
+ l = [1, 2, 3]
utils.set_element_type(l, dtypes.int32)
- utils.set_element_type(m, dtypes.int32)
- l.append(1)
- m.append(2)
- return l, m
-
- node = self.parse_and_analyze(test_fn, {'dtypes': dtypes, 'utils': utils})
+ return tf.stack(l)
+
+ node = self.parse_and_analyze(
+ test_fn,
+ {
+ 'utils': utils,
+ 'dtypes': dtypes
+ },
+ include_type_analysis=True,
+ )
node = lists.transform(node, self.ctx)
- with self.compiled(node, tensor_array_ops.TensorArray,
- dtypes.int32) as result:
+ with self.compiled(node, array_ops.stack, dtypes.int32) as result:
result.utils = utils
result.dtypes = dtypes
with self.test_session() as sess:
- res_l, res_m = result.test_fn()
- self.assertEqual([1], sess.run(res_l.stack()))
- self.assertEqual([2], sess.run(res_m.stack()))
+ self.assertAllEqual(sess.run(result.test_fn()), [1, 2, 3])
if __name__ == '__main__':
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py
index d6555dc7e0..7d1e65c958 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py
@@ -17,8 +17,8 @@
This analyzer uses known live values to further infer object types. This
may include for instance constructed objects and object member functions.
-In addition, the analyzer will also process annotations for TF (staged) type
-annotations.
+In addition, the analyzer also handles user annotations made in the code (for
+example, the autograph.set_element_type function).
Requires annotations generated by LiveValuesResolver.
"""
@@ -44,6 +44,7 @@ from __future__ import print_function
import gast
from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.pyct import transformer
from tensorflow.python.util import tf_inspect
@@ -159,12 +160,10 @@ class TypeInfoResolver(transformer.Base):
# a = b
# then for future references to `a` we should have definition = `b`
definition = self.scope.getval(qn)
- if anno.hasanno(definition, 'type'):
- anno.setanno(node, 'type', anno.getanno(definition, 'type'))
- anno.setanno(node, 'type_fqn', anno.getanno(definition, 'type_fqn'))
- if anno.hasanno(definition, 'element_type'):
- anno.setanno(node, 'element_type',
- anno.getanno(definition, 'element_type'))
+ anno.copyanno(definition, node, 'type')
+ anno.copyanno(definition, node, 'type_fqn')
+ anno.copyanno(definition, node, 'element_type')
+ anno.copyanno(definition, node, 'element_shape')
return node
def _process_variable_assignment(self, target, value):
@@ -211,23 +210,20 @@ class TypeInfoResolver(transformer.Base):
if (anno.getanno(node.func, 'live_val') is
self.context.type_annotation_func):
- if len(node.args) != 2:
- raise ValueError('"%s" must have exactly two parameters'
+ if len(node.args) < 2 or len(node.args) > 3:
+ raise ValueError('"%s" must have either two or three parameters'
% self.context.type_annotation_func)
- target_arg, type_arg = node.args
+ if len(node.args) == 2:
+ target_arg, type_arg = node.args
+ shape_arg = parser.parse_expression('None')
+ else:
+ target_arg, type_arg, shape_arg = node.args
if not anno.hasanno(target_arg, anno.Basic.QN):
raise ValueError('the first argument of "%s" must be a symbol'
% self.context.type_annotation_func)
- if isinstance(type_arg, gast.Str):
- element_type = type_arg.s
- elif isinstance(type_arg, gast.Num):
- element_type = type_arg.n
- else:
- if not anno.hasanno(type_arg, 'live_val'):
- raise ValueError(
- 'the second argument of "%s" must be statically resolvable' %
- self.context.type_annotation_func)
- element_type = anno.getanno(type_arg, 'live_val')
+ # TODO(mdan): This is vulnerable to symbol renaming.
+ element_type = type_arg
+ element_shape = shape_arg
target_symbol = anno.getanno(target_arg, anno.Basic.QN)
# Find the definition of this symbol and annotate it with the given
@@ -235,7 +231,9 @@ class TypeInfoResolver(transformer.Base):
# to receive the same type annotation.
definition = self.scope.getval(target_symbol)
anno.setanno(node, 'element_type', element_type)
+ anno.setanno(node, 'element_shape', element_shape)
anno.setanno(definition, 'element_type', element_type)
+ anno.setanno(definition, 'element_shape', element_shape)
# TODO(mdan): Should we update references between definition and here?
return self.generic_visit(node)
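
The hunk above relaxes set_element_type to accept either two or three arguments and stores the argument expressions verbatim as annotations, which is why the updated tests below inspect .id and .n on the annotation instead of comparing live values. A minimal usage sketch, following the calls made in lists_test.py:

    from tensorflow.contrib.autograph import utils
    from tensorflow.python.framework import dtypes

    l = []
    utils.set_element_type(l, dtypes.int32)        # two arguments: element dtype only
    m = [1, 2, 3]
    utils.set_element_type(m, dtypes.int32, ())    # three arguments: dtype and element shape
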
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py
index 95cbf5ca79..484562f294 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py
@@ -187,14 +187,14 @@ class TypeInfoResolverTest(test.TestCase):
def test_fn():
f = []
- f = utils.set_element_type(f, Foo)
+ f = utils.set_element_type(f, Foo, (1, 2, 3))
return f
node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'utils': utils})
f_def = node.body[0].body[0].value
- self.assertEqual(anno.getanno(f_def, 'element_type'), Foo)
+ self.assertEqual(anno.getanno(f_def, 'element_type').id, 'Foo')
f_ref = node.body[0].body[1].value
- self.assertEqual(anno.getanno(f_ref, 'element_type'), Foo)
+ self.assertEqual(anno.getanno(f_ref, 'element_type').id, 'Foo')
def test_type_annotation_args(self):
@@ -207,7 +207,7 @@ class TypeInfoResolverTest(test.TestCase):
node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'utils': utils})
f_ref = node.body[0].body[1].value
- self.assertEqual(anno.getanno(f_ref, 'element_type'), Foo)
+ self.assertEqual(anno.getanno(f_ref, 'element_type').id, 'Foo')
def test_nested_unpacking(self):
@@ -223,9 +223,9 @@ class TypeInfoResolverTest(test.TestCase):
node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'Bar': Bar})
a, b, c = node.body[0].body[1].value.elts
- self.assertEquals(Foo, anno.getanno(a, 'type'))
- self.assertEquals(Bar, anno.getanno(b, 'type'))
- self.assertEquals(Foo, anno.getanno(c, 'type'))
+ self.assertEquals(anno.getanno(a, 'type'), Foo)
+ self.assertEquals(anno.getanno(b, 'type'), Bar)
+ self.assertEquals(anno.getanno(c, 'type'), Foo)
self.assertFalse(anno.hasanno(a, 'live_val'))
self.assertFalse(anno.hasanno(b, 'live_val'))
self.assertFalse(anno.hasanno(c, 'live_val'))
@@ -242,8 +242,8 @@ class TypeInfoResolverTest(test.TestCase):
node = self._parse_and_analyze(test_fn, {'utils': utils})
a, b = node.body[0].body[2].body[2].value.elts
- self.assertEquals(1, anno.getanno(a, 'element_type'))
- self.assertEquals(2, anno.getanno(b, 'element_type'))
+ self.assertEquals(anno.getanno(a, 'element_type').n, 1)
+ self.assertEquals(anno.getanno(b, 'element_type').n, 2)
self.assertFalse(anno.hasanno(a, 'type'))
self.assertFalse(anno.hasanno(b, 'type'))
self.assertFalse(anno.hasanno(a, 'live_val'))