aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph/operators/data_structures_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/autograph/operators/data_structures_test.py')
-rw-r--r--tensorflow/contrib/autograph/operators/data_structures_test.py41
1 files changed, 41 insertions, 0 deletions
diff --git a/tensorflow/contrib/autograph/operators/data_structures_test.py b/tensorflow/contrib/autograph/operators/data_structures_test.py
index 8bbb52d6c1..7ea11a839b 100644
--- a/tensorflow/contrib/autograph/operators/data_structures_test.py
+++ b/tensorflow/contrib/autograph/operators/data_structures_test.py
@@ -37,10 +37,51 @@ class ListTest(test.TestCase):
def test_new_list_tensor(self):
l = data_structures.new_list([3, 4, 5])
+ self.assertAllEqual(l, [3, 4, 5])
+
+ def test_tf_tensor_list_new(self):
+ l = data_structures.tf_tensor_list_new([3, 4, 5])
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
with self.test_session() as sess:
self.assertAllEqual(sess.run(t), [3, 4, 5])
+ def test_tf_tensor_list_new_illegal_input(self):
+ with self.assertRaises(ValueError):
+ data_structures.tf_tensor_list_new([3, 4.0])
+ # TODO(mdan): It might make more sense to type cast in this case.
+ with self.assertRaises(ValueError):
+ data_structures.tf_tensor_list_new([3, 4], element_dtype=dtypes.float32)
+ # Tensor lists do support heterogeneous lists.
+ self.assertIsNot(data_structures.tf_tensor_list_new([3, [4, 5]]), None)
+ with self.assertRaises(ValueError):
+ data_structures.tf_tensor_list_new([3, 4], element_shape=(2,))
+ with self.assertRaises(ValueError):
+ data_structures.tf_tensor_list_new([], element_shape=(2,))
+ with self.assertRaises(ValueError):
+ data_structures.tf_tensor_list_new([], element_dtype=dtypes.float32)
+
+ def test_tf_tensor_array_new(self):
+ l = data_structures.tf_tensor_array_new([3, 4, 5])
+ t = l.stack()
+ with self.test_session() as sess:
+ self.assertAllEqual(sess.run(t), [3, 4, 5])
+
+ def test_tf_tensor_array_new_illegal_input(self):
+ with self.assertRaises(ValueError):
+ data_structures.tf_tensor_array_new([3, 4.0])
+ with self.assertRaises(ValueError):
+ data_structures.tf_tensor_array_new([3, 4], element_dtype=dtypes.float32)
+ with self.assertRaises(ValueError):
+ data_structures.tf_tensor_array_new([3, [4, 5]])
+ with self.assertRaises(ValueError):
+ data_structures.tf_tensor_array_new([3, 4], element_shape=(2,))
+ with self.assertRaises(ValueError):
+ data_structures.tf_tensor_array_new([], element_shape=(2,))
+ # TAs can infer the shape.
+ self.assertIsNot(
+ data_structures.tf_tensor_array_new([], element_dtype=dtypes.float32),
+ None)
+
def test_append_tensor_list(self):
l = data_structures.new_list()
x = constant_op.constant([1, 2, 3])