aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph/lang/special_functions.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/autograph/lang/special_functions.py')
-rw-r--r--tensorflow/contrib/autograph/lang/special_functions.py37
1 files changed, 37 insertions, 0 deletions
diff --git a/tensorflow/contrib/autograph/lang/special_functions.py b/tensorflow/contrib/autograph/lang/special_functions.py
index 11135295a7..6149cbbd6c 100644
--- a/tensorflow/contrib/autograph/lang/special_functions.py
+++ b/tensorflow/contrib/autograph/lang/special_functions.py
@@ -26,6 +26,43 @@ from __future__ import print_function
from tensorflow.contrib.autograph.operators import data_structures
+def tensor_list(elements,
+ element_dtype=None,
+ element_shape=None,
+ use_tensor_array=False):
+ """Creates an tensor list and populates it with the given elements.
+
+ This function provides a more uniform access to tensor lists and tensor
+ arrays, and allows optional initialization.
+
+ Note: this function is a simplified wrapper. If you need greater control,
+ it is recommended to use the underlying implementation directly.
+
+ Args:
+ elements: Iterable[tf.Tensor, ...], the elements to initially fill the list
+ with
+ element_dtype: Optional[tf.DType], data type for the elements in the list;
+ required if the list is empty
+ element_shape: Optional[tf.TensorShape], shape for the elements in the list;
+ required if the list is empty
+ use_tensor_array: bool, whether to use the more compatible but restrictive
+ tf.TensorArray implementation
+ Returns:
+ Union[tf.Tensor, tf.TensorArray], the new list.
+ Raises:
+ ValueError: for invalid arguments
+ """
+ if not (elements or (element_dtype and element_shape)):
+ raise ValueError(
+ 'element_dtype and element_shape are required for empty lists')
+ if use_tensor_array:
+ return data_structures.tf_tensor_array_new(elements, element_dtype,
+ element_shape)
+ else:
+ return data_structures.tf_tensor_list_new(elements, element_dtype,
+ element_shape)
+
+
def stack(list_or_tensor, element_dtype=None, strict=True):
"""Stacks the input, if it admits the notion of stacking.