aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/autograph
diff options
context:
space:
mode:
authorGravatar Dan Moldovan <mdan@google.com>2018-10-02 12:15:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-02 12:24:39 -0700
commit8d4ef71f06a06a093419bf0f80562a1941059029 (patch)
treedfca34248365279eefa5293351d172e4909ccc5d /tensorflow/python/autograph
parent16b44d48d485dbb62b9922e172df4cc460174046 (diff)
Allow creating a list from a tensor. Fix a few inconsistencies in the tensor list constructors.
PiperOrigin-RevId: 215435720
Diffstat (limited to 'tensorflow/python/autograph')
-rw-r--r--tensorflow/python/autograph/lang/special_functions.py24
-rw-r--r--tensorflow/python/autograph/lang/special_functions_test.py37
-rw-r--r--tensorflow/python/autograph/operators/data_structures.py17
-rw-r--r--tensorflow/python/autograph/operators/data_structures_test.py31
4 files changed, 99 insertions, 10 deletions
diff --git a/tensorflow/python/autograph/lang/special_functions.py b/tensorflow/python/autograph/lang/special_functions.py
index e4838d1b6d..62ac018ac4 100644
--- a/tensorflow/python/autograph/lang/special_functions.py
+++ b/tensorflow/python/autograph/lang/special_functions.py
@@ -24,6 +24,26 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.autograph.operators import data_structures
+from tensorflow.python.framework import tensor_util
+
+
+def _validate_list_constructor(elements, element_dtype, element_shape):
+ """Validates the inputs of tensor_list."""
+ if element_dtype is not None and element_shape is not None:
+ return
+ if tensor_util.is_tensor(elements):
+ return
+ if isinstance(elements, (list, tuple)):
+ if elements:
+ return
+ else:
+ raise ValueError(
+ 'element_dtype and element_shape are required when elements are'
+ ' empty')
+
+ raise ValueError(
+ 'unknown type for elements: {}; only Tensor, list and tuple are'
+ ' allowed'.format(type(elements)))
def tensor_list(elements,
@@ -52,9 +72,7 @@ def tensor_list(elements,
Raises:
ValueError: for invalid arguments
"""
- if not (elements or (element_dtype and element_shape)):
- raise ValueError(
- 'element_dtype and element_shape are required for empty lists')
+ _validate_list_constructor(elements, element_dtype, element_shape)
if use_tensor_array:
return data_structures.tf_tensor_array_new(elements, element_dtype,
element_shape)
diff --git a/tensorflow/python/autograph/lang/special_functions_test.py b/tensorflow/python/autograph/lang/special_functions_test.py
index 545dd11729..206a32d07c 100644
--- a/tensorflow/python/autograph/lang/special_functions_test.py
+++ b/tensorflow/python/autograph/lang/special_functions_test.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+
from tensorflow.python.autograph.lang import special_functions
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -28,12 +30,43 @@ from tensorflow.python.platform import test
class SpecialFunctionsTest(test.TestCase):
+ def test_tensor_list_empty_list(self):
+ l = special_functions.tensor_list([],
+ element_dtype=dtypes.int32,
+ element_shape=())
+ sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
+ with self.test_session() as sess:
+ self.assertAllEqual(sess.run(sl), [])
+
+ l = special_functions.tensor_list((),
+ element_dtype=dtypes.int32,
+ element_shape=())
+ sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
+ with self.test_session() as sess:
+ self.assertAllEqual(sess.run(sl), [])
+
+ def test_tensor_list_tensor(self):
+ l = special_functions.tensor_list(
+ constant_op.constant([], dtype=dtypes.int32))
+ sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
+ with self.test_session() as sess:
+ self.assertAllEqual(sess.run(sl), [])
+
+ def test_tensor_list_unsupported_initializer(self):
+ with self.assertRaisesRegexp(ValueError, 'unknown type'):
+ special_functions.tensor_list(np.array([1, 2, 3]))
+
+ def test_tensor_list_empty_list_no_type(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'element_dtype and element_shape are required'):
+ special_functions.tensor_list([])
+
def test_tensor_list_from_elements(self):
elements = [constant_op.constant([1, 2]), constant_op.constant([3, 4])]
l = special_functions.tensor_list(elements)
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
- with self.cached_session() as sess:
+ with self.test_session() as sess:
self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
def test_tensor_list_array_from_elements(self):
@@ -41,7 +74,7 @@ class SpecialFunctionsTest(test.TestCase):
l = special_functions.tensor_list(elements, use_tensor_array=True)
sl = l.stack()
- with self.cached_session() as sess:
+ with self.test_session() as sess:
self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
def test_stack(self):
diff --git a/tensorflow/python/autograph/operators/data_structures.py b/tensorflow/python/autograph/operators/data_structures.py
index cc0a3c3544..b3a3851333 100644
--- a/tensorflow/python/autograph/operators/data_structures.py
+++ b/tensorflow/python/autograph/operators/data_structures.py
@@ -106,6 +106,14 @@ def tf_tensor_array_new(elements, element_dtype=None, element_shape=None):
def tf_tensor_list_new(elements, element_dtype=None, element_shape=None):
"""Overload of new_list that stages a Tensor list creation."""
+ if tensor_util.is_tensor(elements):
+ if element_shape is not None:
+ raise ValueError(
+ 'element shape may not be specified when creating list from tensor')
+ element_shape = array_ops.shape(elements)[1:]
+ l = list_ops.tensor_list_from_tensor(elements, element_shape=element_shape)
+ return l
+
elements = tuple(ops.convert_to_tensor(el) for el in elements)
all_dtypes = set(el.dtype for el in elements)
@@ -115,13 +123,15 @@ def tf_tensor_list_new(elements, element_dtype=None, element_shape=None):
raise ValueError(
'incompatible dtype; specified: {}, inferred from {}: {}'.format(
element_dtype, elements, inferred_dtype))
- else:
+ elif all_dtypes:
# Heterogeneous lists are ok.
if element_dtype is not None:
raise ValueError(
'specified dtype {} is inconsistent with that of elements {}'.format(
element_dtype, elements))
inferred_dtype = dtypes.variant
+ else:
+ inferred_dtype = dtypes.variant
all_shapes = set(tuple(el.shape.as_list()) for el in elements)
if len(all_shapes) == 1:
@@ -130,19 +140,22 @@ def tf_tensor_list_new(elements, element_dtype=None, element_shape=None):
raise ValueError(
'incompatible shape; specified: {}, inferred from {}: {}'.format(
element_shape, elements, inferred_shape))
- else:
+ elif all_shapes:
# Heterogeneous lists are ok.
if element_shape is not None:
raise ValueError(
'specified shape {} is inconsistent with that of elements {}'.format(
element_shape, elements))
inferred_shape = constant_op.constant(-1) # unknown shape, by convention
+ else:
+ inferred_shape = constant_op.constant(-1) # unknown shape, by convention
if element_dtype is None:
element_dtype = inferred_dtype
if element_shape is None:
element_shape = inferred_shape
+ element_shape = ops.convert_to_tensor(element_shape, dtype=dtypes.int32)
l = list_ops.empty_tensor_list(
element_shape=element_shape, element_dtype=element_dtype)
for el in elements:
diff --git a/tensorflow/python/autograph/operators/data_structures_test.py b/tensorflow/python/autograph/operators/data_structures_test.py
index 8532dbe466..6039b07982 100644
--- a/tensorflow/python/autograph/operators/data_structures_test.py
+++ b/tensorflow/python/autograph/operators/data_structures_test.py
@@ -45,6 +45,20 @@ class ListTest(test.TestCase):
with self.cached_session() as sess:
self.assertAllEqual(sess.run(t), [3, 4, 5])
+ def test_tf_tensor_list_new_empty(self):
+ l = data_structures.tf_tensor_list_new([],
+ element_dtype=dtypes.int32,
+ element_shape=())
+ t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
+ with self.cached_session() as sess:
+ self.assertAllEqual(sess.run(t), [])
+
+ def test_tf_tensor_list_new_from_tensor(self):
+ l = data_structures.tf_tensor_list_new(constant_op.constant([3, 4, 5]))
+ t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
+ with self.cached_session() as sess:
+ self.assertAllEqual(sess.run(t), [3, 4, 5])
+
def test_tf_tensor_list_new_illegal_input(self):
with self.assertRaises(ValueError):
data_structures.tf_tensor_list_new([3, 4.0])
@@ -56,9 +70,8 @@ class ListTest(test.TestCase):
with self.assertRaises(ValueError):
data_structures.tf_tensor_list_new([3, 4], element_shape=(2,))
with self.assertRaises(ValueError):
- data_structures.tf_tensor_list_new([], element_shape=(2,))
- with self.assertRaises(ValueError):
- data_structures.tf_tensor_list_new([], element_dtype=dtypes.float32)
+ data_structures.tf_tensor_list_new(
+ constant_op.constant([1, 2, 3]), element_shape=[1])
def test_tf_tensor_array_new(self):
l = data_structures.tf_tensor_array_new([3, 4, 5])
@@ -141,6 +154,18 @@ class ListTest(test.TestCase):
t = data_structures.list_stack(l, opts)
self.assertAllEqual(sess.run(t), sess.run(initial_list))
+ def test_stack_tensor_list_empty(self):
+ l = list_ops.empty_tensor_list(
+ element_shape=-1,
+ element_dtype=dtypes.variant)
+
+ opts = data_structures.ListStackOpts(
+ element_dtype=dtypes.int32, original_call=None)
+
+ # TODO(mdan): Allow stacking empty lists if the dtype and shape are known.
+ with self.assertRaises(ValueError):
+ data_structures.list_stack(l, opts)
+
def test_stack_fallback(self):
def dummy_function(l):