aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph/converters/slices_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/autograph/converters/slices_test.py')
-rw-r--r--tensorflow/contrib/autograph/converters/slices_test.py47
1 files changed, 32 insertions, 15 deletions
diff --git a/tensorflow/contrib/autograph/converters/slices_test.py b/tensorflow/contrib/autograph/converters/slices_test.py
index df9a4c8bab..3c0f81e8bc 100644
--- a/tensorflow/contrib/autograph/converters/slices_test.py
+++ b/tensorflow/contrib/autograph/converters/slices_test.py
@@ -18,9 +18,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph import utils
from tensorflow.contrib.autograph.converters import slices
from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.contrib.autograph.lang import directives
+from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.contrib.autograph.pyct import parser
+from tensorflow.contrib.autograph.pyct import transformer
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import list_ops
@@ -32,28 +35,42 @@ class SliceTest(converter_testing.TestCase):
def test_index_access(self):
def test_fn(l):
- utils.set_element_type(l, dtypes.int32)
return l[1]
- node = self.parse_and_analyze(
- test_fn,
- {
- 'utils': utils,
- 'dtypes': dtypes
- },
- include_type_analysis=True,
- )
- node = slices.transform(node, self.ctx)
-
- with self.compiled(node, dtypes.int32) as result:
- result.utils = utils
- result.dtypes = dtypes
+ node, ctx = self.prepare(test_fn, {})
+ def_, = anno.getanno(node.body[0].args.args[0], anno.Static.DEFINITIONS)
+ def_.directives[directives.set_element_type] = {
+ 'dtype': parser.parse_expression('tf.int32')
+ }
+ node = slices.transform(node, ctx)
+
+ with self.compiled(node, {}, dtypes.int32) as result:
with self.test_session() as sess:
tl = list_ops.tensor_list_from_tensor(
[1, 2], element_shape=constant_op.constant([], dtype=dtypes.int32))
y = result.test_fn(tl)
self.assertEqual(2, sess.run(y))
+ def test_index_access_multiple_definitions(self):
+
+ def test_fn(l):
+ if l:
+ l = []
+ return l[1]
+
+ node, ctx = self.prepare(test_fn, {})
+ def_, = anno.getanno(node.body[0].args.args[0], anno.Static.DEFINITIONS)
+ def_.directives[directives.set_element_type] = {
+ 'dtype': parser.parse_expression('tf.int32')
+ }
+ def_, = anno.getanno(node.body[0].body[0].body[0].targets[0],
+ anno.Static.DEFINITIONS)
+ def_.directives[directives.set_element_type] = {
+ 'dtype': parser.parse_expression('tf.float32')
+ }
+ with self.assertRaises(transformer.AutographParseError):
+ slices.transform(node, ctx)
+
if __name__ == '__main__':
test.main()