aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph/operators/data_structures.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/autograph/operators/data_structures.py')
-rw-r--r--tensorflow/contrib/autograph/operators/data_structures.py91
1 files changed, 81 insertions, 10 deletions
diff --git a/tensorflow/contrib/autograph/operators/data_structures.py b/tensorflow/contrib/autograph/operators/data_structures.py
index 06d8727b0f..cc0a3c3544 100644
--- a/tensorflow/contrib/autograph/operators/data_structures.py
+++ b/tensorflow/contrib/autograph/operators/data_structures.py
@@ -28,7 +28,6 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import tensor_array_ops
-from tensorflow.python.ops import variables
# TODO(mdan): Once control flow supports objects, repackage as a class.
@@ -48,29 +47,101 @@ def new_list(iterable=None):
else:
elements = ()
- # TODO(mdan): Extend these criteria.
- if any(isinstance(el, variables.Variable) for el in elements):
+ if elements:
+ # When the list contains elements, it is assumed to be a "Python" lvalue
+ # list.
return _py_list_new(elements)
- return _tf_tensor_list_new(elements)
+ return tf_tensor_list_new(elements)
-def _tf_tensor_list_new(elements):
+def tf_tensor_array_new(elements, element_dtype=None, element_shape=None):
"""Overload of new_list that stages a Tensor list creation."""
elements = tuple(ops.convert_to_tensor(el) for el in elements)
+
+ all_dtypes = set(el.dtype for el in elements)
+ if len(all_dtypes) == 1:
+ inferred_dtype, = tuple(all_dtypes)
+ if element_dtype is not None and element_dtype != inferred_dtype:
+ raise ValueError(
+ 'incompatible dtype; specified: {}, inferred from {}: {}'.format(
+ element_dtype, elements, inferred_dtype))
+ elif len(all_dtypes) > 1:
+ raise ValueError(
+ 'TensorArray requires all elements to have the same dtype:'
+ ' {}'.format(elements))
+ else:
+ if element_dtype is None:
+ raise ValueError('dtype is required to create an empty TensorArray')
+
+ all_shapes = set(tuple(el.shape.as_list()) for el in elements)
+ if len(all_shapes) == 1:
+ inferred_shape, = tuple(all_shapes)
+ if element_shape is not None and element_shape != inferred_shape:
+ raise ValueError(
+ 'incompatible shape; specified: {}, inferred from {}: {}'.format(
+ element_shape, elements, inferred_shape))
+ elif len(all_shapes) > 1:
+ raise ValueError(
+ 'TensorArray requires all elements to have the same shape:'
+ ' {}'.format(elements))
+ # TODO(mdan): We may want to allow different shapes with infer_shape=False.
+ else:
+ inferred_shape = None
+
+ if element_dtype is None:
+ element_dtype = inferred_dtype
+ if element_shape is None:
+ element_shape = inferred_shape
+
+ l = tensor_array_ops.TensorArray(
+ dtype=element_dtype,
+ size=len(elements),
+ dynamic_size=True,
+ infer_shape=(element_shape is None),
+ element_shape=element_shape)
+ for i, el in enumerate(elements):
+ l = l.write(i, el)
+ return l
+
+
+def tf_tensor_list_new(elements, element_dtype=None, element_shape=None):
+ """Overload of new_list that stages a Tensor list creation."""
+ elements = tuple(ops.convert_to_tensor(el) for el in elements)
+
all_dtypes = set(el.dtype for el in elements)
if len(all_dtypes) == 1:
- element_dtype = tuple(all_dtypes)[0]
+ inferred_dtype = tuple(all_dtypes)[0]
+ if element_dtype is not None and element_dtype != inferred_dtype:
+ raise ValueError(
+ 'incompatible dtype; specified: {}, inferred from {}: {}'.format(
+ element_dtype, elements, inferred_dtype))
else:
# Heterogeneous lists are ok.
- element_dtype = dtypes.variant
+ if element_dtype is not None:
+ raise ValueError(
+ 'specified dtype {} is inconsistent with that of elements {}'.format(
+ element_dtype, elements))
+ inferred_dtype = dtypes.variant
- # TODO(mdan): This may fail for elements of variable shapes.
all_shapes = set(tuple(el.shape.as_list()) for el in elements)
if len(all_shapes) == 1:
- element_shape = array_ops.shape(elements[0])
+ inferred_shape = array_ops.shape(elements[0])
+ if element_shape is not None and element_shape != inferred_shape:
+ raise ValueError(
+ 'incompatible shape; specified: {}, inferred from {}: {}'.format(
+ element_shape, elements, inferred_shape))
else:
# Heterogeneous lists are ok.
- element_shape = constant_op.constant(-1) # unknown shape, by convention
+ if element_shape is not None:
+ raise ValueError(
+ 'specified shape {} is inconsistent with that of elements {}'.format(
+ element_shape, elements))
+ inferred_shape = constant_op.constant(-1) # unknown shape, by convention
+
+ if element_dtype is None:
+ element_dtype = inferred_dtype
+ if element_shape is None:
+ element_shape = inferred_shape
l = list_ops.empty_tensor_list(
element_shape=element_shape, element_dtype=element_dtype)