aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph/converters/lists_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/autograph/converters/lists_test.py')
-rw-r--r--tensorflow/contrib/autograph/converters/lists_test.py90
1 files changed, 37 insertions, 53 deletions
diff --git a/tensorflow/contrib/autograph/converters/lists_test.py b/tensorflow/contrib/autograph/converters/lists_test.py
index ea04097b28..f906918ac0 100644
--- a/tensorflow/contrib/autograph/converters/lists_test.py
+++ b/tensorflow/contrib/autograph/converters/lists_test.py
@@ -18,9 +18,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph import utils
from tensorflow.contrib.autograph.converters import lists
from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.contrib.autograph.lang import directives
+from tensorflow.contrib.autograph.lang import special_functions
+from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.contrib.autograph.pyct import parser
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -28,6 +31,9 @@ from tensorflow.python.ops import list_ops
from tensorflow.python.platform import test
+tf = None # Will be replaced by a mock.
+
+
class ListTest(converter_testing.TestCase):
def test_empty_list(self):
@@ -35,10 +41,7 @@ class ListTest(converter_testing.TestCase):
def test_fn():
return []
- node = self.parse_and_analyze(test_fn, {})
- node = lists.transform(node, self.ctx)
-
- with self.compiled(node) as result:
+ with self.converted(test_fn, lists, {}) as result:
tl = result.test_fn()
# Empty tensor lists cannot be evaluated or stacked.
self.assertTrue(isinstance(tl, ops.Tensor))
@@ -49,27 +52,19 @@ class ListTest(converter_testing.TestCase):
def test_fn():
return [1, 2, 3]
- node = self.parse_and_analyze(test_fn, {})
- node = lists.transform(node, self.ctx)
-
- with self.compiled(node) as result:
- with self.test_session() as sess:
- tl = result.test_fn()
- r = list_ops.tensor_list_stack(tl, dtypes.int32)
- self.assertAllEqual(sess.run(r), [1, 2, 3])
+ with self.converted(test_fn, lists, {}) as result:
+ self.assertAllEqual(result.test_fn(), [1, 2, 3])
def test_list_append(self):
def test_fn():
- l = [1]
+ l = special_functions.tensor_list([1])
l.append(2)
l.append(3)
return l
- node = self.parse_and_analyze(test_fn, {})
- node = lists.transform(node, self.ctx)
-
- with self.compiled(node) as result:
+ ns = {'special_functions': special_functions}
+ with self.converted(test_fn, lists, ns) as result:
with self.test_session() as sess:
tl = result.test_fn()
r = list_ops.tensor_list_stack(tl, dtypes.int32)
@@ -78,24 +73,21 @@ class ListTest(converter_testing.TestCase):
def test_list_pop(self):
def test_fn():
- l = [1, 2, 3]
- utils.set_element_type(l, dtypes.int32, ())
+ l = special_functions.tensor_list([1, 2, 3])
s = l.pop()
return s, l
- node = self.parse_and_analyze(
- test_fn,
- {
- 'utils': utils,
- 'dtypes': dtypes
- },
- include_type_analysis=True,
- )
- node = lists.transform(node, self.ctx)
-
- with self.compiled(node) as result:
- result.utils = utils
- result.dtypes = dtypes
+ ns = {'special_functions': special_functions}
+ node, ctx = self.prepare(test_fn, ns)
+ def_, = anno.getanno(node.body[0].body[0].targets[0],
+ anno.Static.ORIG_DEFINITIONS)
+ def_.directives[directives.set_element_type] = {
+ 'dtype': parser.parse_expression('tf.int32'),
+ 'shape': parser.parse_expression('()'),
+ }
+ node = lists.transform(node, ctx)
+
+ with self.compiled(node, ns, dtypes.int32) as result:
with self.test_session() as sess:
ts, tl = result.test_fn()
r = list_ops.tensor_list_stack(tl, dtypes.int32)
@@ -108,10 +100,7 @@ class ListTest(converter_testing.TestCase):
s = l.pop().pop()
return s
- node = self.parse_and_analyze(test_fn, {})
- node = lists.transform(node, self.ctx)
-
- with self.compiled(node) as result:
+ with self.converted(test_fn, lists, {}) as result:
test_input = [1, 2, [1, 2, 3]]
# TODO(mdan): Pass a list of lists of tensor when we fully support that.
# For now, we just pass a regular Python list of lists just to verify that
@@ -120,29 +109,24 @@ class ListTest(converter_testing.TestCase):
def test_list_stack(self):
- tf = None # Will be replaced with a mock.
-
def test_fn():
l = [1, 2, 3]
- utils.set_element_type(l, dtypes.int32)
return tf.stack(l)
- node = self.parse_and_analyze(
- test_fn,
- {
- 'utils': utils,
- 'dtypes': dtypes
- },
- include_type_analysis=True,
- )
- node = lists.transform(node, self.ctx)
-
- with self.compiled(node, array_ops.stack, dtypes.int32) as result:
- result.utils = utils
- result.dtypes = dtypes
+ node, ctx = self.prepare(test_fn, {})
+ def_, = anno.getanno(node.body[0].body[0].targets[0],
+ anno.Static.ORIG_DEFINITIONS)
+ def_.directives[directives.set_element_type] = {
+ 'dtype': parser.parse_expression('tf.int32')
+ }
+ node = lists.transform(node, ctx)
+
+ with self.compiled(node, {}, array_ops.stack, dtypes.int32) as result:
with self.test_session() as sess:
self.assertAllEqual(sess.run(result.test_fn()), [1, 2, 3])
+ # TODO(mdan): Add a test with tf.stack with axis kwarg.
+
if __name__ == '__main__':
test.main()