aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph/converters/lists.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/autograph/converters/lists.py')
-rw-r--r--tensorflow/contrib/autograph/converters/lists.py30
1 files changed, 21 insertions, 9 deletions
diff --git a/tensorflow/contrib/autograph/converters/lists.py b/tensorflow/contrib/autograph/converters/lists.py
index d77a044798..a02fc827b8 100644
--- a/tensorflow/contrib/autograph/converters/lists.py
+++ b/tensorflow/contrib/autograph/converters/lists.py
@@ -33,6 +33,7 @@ from __future__ import print_function
import gast
from tensorflow.contrib.autograph.core import converter
+from tensorflow.contrib.autograph.lang import directives
from tensorflow.contrib.autograph.pyct import anno
from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.pyct import templates
@@ -88,12 +89,12 @@ class ListTransformer(converter.Base):
scope = anno.getanno(node, NodeAnno.ARGS_SCOPE)
target_node = node.func.value
- # Attempt to use a related name if can get one. Otherwise use something
+ # Attempt to use a related name if one exists. Otherwise use something
# generic.
if anno.hasanno(target_node, anno.Basic.QN):
target_name = anno.getanno(target_node, anno.Basic.QN).ssf()
else:
- target_name = 'list'
+ target_name = 'list_'
pop_var_name = self.ctx.namer.new_symbol(target_name, scope.referenced)
pop_uses = self.get_local(POP_USES, [])
@@ -104,9 +105,10 @@ class ListTransformer(converter.Base):
def _replace_stack_call(self, node):
assert len(node.args) == 1
- dtype = anno.getanno(
+ dtype = self.get_definition_directive(
node.args[0],
- 'element_type',
+ directives.set_element_type,
+ 'dtype',
default=templates.replace_as_expression('None'))
template = """
ag__.list_stack(
@@ -134,7 +136,10 @@ class ListTransformer(converter.Base):
node = self._replace_append_call(node)
elif func_name == 'pop' and (len(node.args) <= 1):
node = self._replace_pop_call(node)
- elif func_name == 'stack' and (len(node.args) == 1):
+ elif (func_name == 'stack' and (len(node.args) == 1) and
+ (not node.keywords or node.keywords[0].arg == 'strict')):
+ # This avoids false positives with keyword args.
+ # TODO(mdan): handle kwargs properly.
node = self._replace_stack_call(node)
return node
@@ -146,15 +151,22 @@ class ListTransformer(converter.Base):
pop_element = original_call_node.args[0]
else:
pop_element = parser.parse_expression('None')
+
# The call will be something like "target.pop()", and the dtype is hooked to
# target, hence the func.value.
- dtype = anno.getanno(
+ # TODO(mdan): For lists of lists, this won't work.
+ # The reason why it won't work is because it's unclear how to annotate
+ # the list as a "list of lists with a certain element type" when using
+ # operations like `l.pop().pop()`.
+ dtype = self.get_definition_directive(
original_call_node.func.value,
- 'element_type',
+ directives.set_element_type,
+ 'dtype',
default=templates.replace_as_expression('None'))
- shape = anno.getanno(
+ shape = self.get_definition_directive(
original_call_node.func.value,
- 'element_shape',
+ directives.set_element_type,
+ 'shape',
default=templates.replace_as_expression('None'))
template = """