aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph/converters/slices.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/autograph/converters/slices.py')
-rw-r--r--tensorflow/contrib/autograph/converters/slices.py16
1 files changed, 9 insertions, 7 deletions
diff --git a/tensorflow/contrib/autograph/converters/slices.py b/tensorflow/contrib/autograph/converters/slices.py
index 3f5fc57125..c527f98613 100644
--- a/tensorflow/contrib/autograph/converters/slices.py
+++ b/tensorflow/contrib/autograph/converters/slices.py
@@ -21,7 +21,7 @@ from __future__ import print_function
import gast
from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.contrib.autograph.lang import directives
from tensorflow.contrib.autograph.pyct import templates
@@ -36,12 +36,14 @@ class SliceTransformer(converter.Base):
def _process_single_assignment(self, target, value):
if not isinstance(target, gast.Subscript):
return None
+ if not isinstance(target.slice, gast.Index):
+ return None
template = """
target = ag__.set_item(target, key, item)
"""
return templates.replace(
- template, target=target.value, key=target.slice, item=value)
+ template, target=target.value, key=target.slice.value, item=value)
def visit_Assign(self, node):
node = self.generic_visit(node)
@@ -56,17 +58,17 @@ class SliceTransformer(converter.Base):
def visit_Subscript(self, node):
node = self.generic_visit(node)
if not isinstance(node.slice, gast.Index):
- # TODO(mdan): It might make more sense to wave them through.
- raise NotImplementedError('non-index slice')
+ return node
if not isinstance(node.ctx, gast.Load):
# Index writes are handled at a higher level, one at which the rvalue is
# also available.
return node
- dtype = anno.getanno(
+ dtype = self.get_definition_directive(
node.value,
- 'element_type',
+ directives.set_element_type,
+ 'dtype',
default=templates.replace_as_expression('None'))
template = """
@@ -76,7 +78,7 @@ class SliceTransformer(converter.Base):
opts=ag__.GetItemOpts(element_dtype=dtype))
"""
return templates.replace_as_expression(
- template, target=node.value, key=node.slice, dtype=dtype)
+ template, target=node.value, key=node.slice.value, dtype=dtype)
def transform(node, ctx):