diff options
author | 2018-10-02 12:15:36 -0700 | |
---|---|---|
committer | 2018-10-02 12:24:39 -0700 | |
commit | 8d4ef71f06a06a093419bf0f80562a1941059029 (patch) | |
tree | dfca34248365279eefa5293351d172e4909ccc5d /tensorflow/python/autograph | |
parent | 16b44d48d485dbb62b9922e172df4cc460174046 (diff) |
Allow creating a list from a tensor. Fix a few inconsistencies in the tensor list constructors.
PiperOrigin-RevId: 215435720
Diffstat (limited to 'tensorflow/python/autograph')
4 files changed, 99 insertions, 10 deletions
diff --git a/tensorflow/python/autograph/lang/special_functions.py b/tensorflow/python/autograph/lang/special_functions.py index e4838d1b6d..62ac018ac4 100644 --- a/tensorflow/python/autograph/lang/special_functions.py +++ b/tensorflow/python/autograph/lang/special_functions.py @@ -24,6 +24,26 @@ from __future__ import division from __future__ import print_function from tensorflow.python.autograph.operators import data_structures +from tensorflow.python.framework import tensor_util + + +def _validate_list_constructor(elements, element_dtype, element_shape): + """Validates the inputs of tensor_list.""" + if element_dtype is not None and element_shape is not None: + return + if tensor_util.is_tensor(elements): + return + if isinstance(elements, (list, tuple)): + if elements: + return + else: + raise ValueError( + 'element_dtype and element_shape are required when elements are' + ' empty') + + raise ValueError( + 'unknown type for elements: {}; only Tensor, list and tuple are' + ' allowed'.format(type(elements))) def tensor_list(elements, @@ -52,9 +72,7 @@ def tensor_list(elements, Raises: ValueError: for invalid arguments """ - if not (elements or (element_dtype and element_shape)): - raise ValueError( - 'element_dtype and element_shape are required for empty lists') + _validate_list_constructor(elements, element_dtype, element_shape) if use_tensor_array: return data_structures.tf_tensor_array_new(elements, element_dtype, element_shape) diff --git a/tensorflow/python/autograph/lang/special_functions_test.py b/tensorflow/python/autograph/lang/special_functions_test.py index 545dd11729..206a32d07c 100644 --- a/tensorflow/python/autograph/lang/special_functions_test.py +++ b/tensorflow/python/autograph/lang/special_functions_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.python.autograph.lang import special_functions from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -28,12 +30,43 @@ from tensorflow.python.platform import test class SpecialFunctionsTest(test.TestCase): + def test_tensor_list_empty_list(self): + l = special_functions.tensor_list([], + element_dtype=dtypes.int32, + element_shape=()) + sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32) + with self.test_session() as sess: + self.assertAllEqual(sess.run(sl), []) + + l = special_functions.tensor_list((), + element_dtype=dtypes.int32, + element_shape=()) + sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32) + with self.test_session() as sess: + self.assertAllEqual(sess.run(sl), []) + + def test_tensor_list_tensor(self): + l = special_functions.tensor_list( + constant_op.constant([], dtype=dtypes.int32)) + sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32) + with self.test_session() as sess: + self.assertAllEqual(sess.run(sl), []) + + def test_tensor_list_unsupported_initializer(self): + with self.assertRaisesRegexp(ValueError, 'unknown type'): + special_functions.tensor_list(np.array([1, 2, 3])) + + def test_tensor_list_empty_list_no_type(self): + with self.assertRaisesRegexp( + ValueError, 'element_dtype and element_shape are required'): + special_functions.tensor_list([]) + def test_tensor_list_from_elements(self): elements = [constant_op.constant([1, 2]), constant_op.constant([3, 4])] l = special_functions.tensor_list(elements) sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32) - with self.cached_session() as sess: + with self.test_session() as sess: self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]]) def test_tensor_list_array_from_elements(self): @@ -41,7 +74,7 @@ class SpecialFunctionsTest(test.TestCase): l = special_functions.tensor_list(elements, use_tensor_array=True) sl = l.stack() - with self.cached_session() as sess: + with self.test_session() as sess: self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]]) def test_stack(self): diff --git a/tensorflow/python/autograph/operators/data_structures.py b/tensorflow/python/autograph/operators/data_structures.py index cc0a3c3544..b3a3851333 100644 --- a/tensorflow/python/autograph/operators/data_structures.py +++ b/tensorflow/python/autograph/operators/data_structures.py @@ -106,6 +106,14 @@ def tf_tensor_array_new(elements, element_dtype=None, element_shape=None): def tf_tensor_list_new(elements, element_dtype=None, element_shape=None): """Overload of new_list that stages a Tensor list creation.""" + if tensor_util.is_tensor(elements): + if element_shape is not None: + raise ValueError( + 'element shape may not be specified when creating list from tensor') + element_shape = array_ops.shape(elements)[1:] + l = list_ops.tensor_list_from_tensor(elements, element_shape=element_shape) + return l + elements = tuple(ops.convert_to_tensor(el) for el in elements) all_dtypes = set(el.dtype for el in elements) @@ -115,13 +123,15 @@ def tf_tensor_list_new(elements, element_dtype=None, element_shape=None): raise ValueError( 'incompatible dtype; specified: {}, inferred from {}: {}'.format( element_dtype, elements, inferred_dtype)) - else: + elif all_dtypes: # Heterogeneous lists are ok. if element_dtype is not None: raise ValueError( 'specified dtype {} is inconsistent with that of elements {}'.format( element_dtype, elements)) inferred_dtype = dtypes.variant + else: + inferred_dtype = dtypes.variant all_shapes = set(tuple(el.shape.as_list()) for el in elements) if len(all_shapes) == 1: @@ -130,19 +140,22 @@ def tf_tensor_list_new(elements, element_dtype=None, element_shape=None): raise ValueError( 'incompatible shape; specified: {}, inferred from {}: {}'.format( element_shape, elements, inferred_shape)) - else: + elif all_shapes: # Heterogeneous lists are ok. if element_shape is not None: raise ValueError( 'specified shape {} is inconsistent with that of elements {}'.format( element_shape, elements)) inferred_shape = constant_op.constant(-1) # unknown shape, by convention + else: + inferred_shape = constant_op.constant(-1) # unknown shape, by convention if element_dtype is None: element_dtype = inferred_dtype if element_shape is None: element_shape = inferred_shape + element_shape = ops.convert_to_tensor(element_shape, dtype=dtypes.int32) l = list_ops.empty_tensor_list( element_shape=element_shape, element_dtype=element_dtype) for el in elements: diff --git a/tensorflow/python/autograph/operators/data_structures_test.py b/tensorflow/python/autograph/operators/data_structures_test.py index 8532dbe466..6039b07982 100644 --- a/tensorflow/python/autograph/operators/data_structures_test.py +++ b/tensorflow/python/autograph/operators/data_structures_test.py @@ -45,6 +45,20 @@ class ListTest(test.TestCase): with self.cached_session() as sess: self.assertAllEqual(sess.run(t), [3, 4, 5]) + def test_tf_tensor_list_new_empty(self): + l = data_structures.tf_tensor_list_new([], + element_dtype=dtypes.int32, + element_shape=()) + t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32) + with self.cached_session() as sess: + self.assertAllEqual(sess.run(t), []) + + def test_tf_tensor_list_new_from_tensor(self): + l = data_structures.tf_tensor_list_new(constant_op.constant([3, 4, 5])) + t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32) + with self.cached_session() as sess: + self.assertAllEqual(sess.run(t), [3, 4, 5]) + def test_tf_tensor_list_new_illegal_input(self): with self.assertRaises(ValueError): data_structures.tf_tensor_list_new([3, 4.0]) @@ -56,9 +70,8 @@ class ListTest(test.TestCase): with self.assertRaises(ValueError): data_structures.tf_tensor_list_new([3, 4], element_shape=(2,)) with self.assertRaises(ValueError): - data_structures.tf_tensor_list_new([], element_shape=(2,)) - with self.assertRaises(ValueError): - data_structures.tf_tensor_list_new([], element_dtype=dtypes.float32) + data_structures.tf_tensor_list_new( + constant_op.constant([1, 2, 3]), element_shape=[1]) def test_tf_tensor_array_new(self): l = data_structures.tf_tensor_array_new([3, 4, 5]) @@ -141,6 +154,18 @@ class ListTest(test.TestCase): t = data_structures.list_stack(l, opts) self.assertAllEqual(sess.run(t), sess.run(initial_list)) + def test_stack_tensor_list_empty(self): + l = list_ops.empty_tensor_list( + element_shape=-1, + element_dtype=dtypes.variant) + + opts = data_structures.ListStackOpts( + element_dtype=dtypes.int32, original_call=None) + + # TODO(mdan): Allow stacking empty lists if the dtype and shape are known. + with self.assertRaises(ValueError): + data_structures.list_stack(l, opts) + def test_stack_fallback(self): def dummy_function(l): |