aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph/lang/special_functions_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/autograph/lang/special_functions_test.py')
-rw-r--r--tensorflow/contrib/autograph/lang/special_functions_test.py18
1 files changed, 17 insertions, 1 deletions
diff --git a/tensorflow/contrib/autograph/lang/special_functions_test.py b/tensorflow/contrib/autograph/lang/special_functions_test.py
index a49cb64075..db492cc5c6 100644
--- a/tensorflow/contrib/autograph/lang/special_functions_test.py
+++ b/tensorflow/contrib/autograph/lang/special_functions_test.py
@@ -28,7 +28,23 @@ from tensorflow.python.platform import test
class SpecialFunctionsTest(test.TestCase):
- def test_basic(self):
+ def test_tensor_list_from_elements(self):
+ elements = [constant_op.constant([1, 2]), constant_op.constant([3, 4])]
+
+ l = special_functions.tensor_list(elements)
+ sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
+ with self.test_session() as sess:
+ self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
+
+ def test_tensor_list_array_from_elements(self):
+ elements = [constant_op.constant([1, 2]), constant_op.constant([3, 4])]
+
+ l = special_functions.tensor_list(elements, use_tensor_array=True)
+ sl = l.stack()
+ with self.test_session() as sess:
+ self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
+
+ def test_stack(self):
self.assertEqual(special_functions.stack(1, strict=False), 1)
self.assertListEqual(
special_functions.stack([1, 2, 3], strict=False), [1, 2, 3])