diff options
Diffstat (limited to 'tensorflow/contrib/autograph/converters/slices.py')
-rw-r--r-- | tensorflow/contrib/autograph/converters/slices.py | 16 |
1 files changed, 9 insertions, 7 deletions
diff --git a/tensorflow/contrib/autograph/converters/slices.py b/tensorflow/contrib/autograph/converters/slices.py index 3f5fc57125..c527f98613 100644 --- a/tensorflow/contrib/autograph/converters/slices.py +++ b/tensorflow/contrib/autograph/converters/slices.py @@ -21,7 +21,7 @@ from __future__ import print_function import gast from tensorflow.contrib.autograph.core import converter -from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.lang import directives from tensorflow.contrib.autograph.pyct import templates @@ -36,12 +36,14 @@ class SliceTransformer(converter.Base): def _process_single_assignment(self, target, value): if not isinstance(target, gast.Subscript): return None + if not isinstance(target.slice, gast.Index): + return None template = """ target = ag__.set_item(target, key, item) """ return templates.replace( - template, target=target.value, key=target.slice, item=value) + template, target=target.value, key=target.slice.value, item=value) def visit_Assign(self, node): node = self.generic_visit(node) @@ -56,17 +58,17 @@ class SliceTransformer(converter.Base): def visit_Subscript(self, node): node = self.generic_visit(node) if not isinstance(node.slice, gast.Index): - # TODO(mdan): It might make more sense to wave them through. - raise NotImplementedError('non-index slice') + return node if not isinstance(node.ctx, gast.Load): # Index writes are handled at a higher level, one at which the rvalue is # also available. return node - dtype = anno.getanno( + dtype = self.get_definition_directive( node.value, - 'element_type', + directives.set_element_type, + 'dtype', default=templates.replace_as_expression('None')) template = """ @@ -76,7 +78,7 @@ class SliceTransformer(converter.Base): opts=ag__.GetItemOpts(element_dtype=dtype)) """ return templates.replace_as_expression( - template, target=node.value, key=node.slice, dtype=dtype) + template, target=node.value, key=node.slice.value, dtype=dtype) def transform(node, ctx): |