diff options
author | Eugene Brevdo <ebrevdo@google.com> | 2016-12-01 15:15:48 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-12-01 15:34:47 -0800 |
commit | edc5d96fa11966a9d1cabafd651eacc95c86608a (patch) | |
tree | 9b3686938feab2dd9ea83c4db8de1d66b53e6863 /tensorflow/python/ops/tensor_array_ops.py | |
parent | 6209ae88ca436b13c5807df3bb237a5613d42215 (diff) |
Add element_shape property to TensorArray creation op.
Change: 140784570
Diffstat (limited to 'tensorflow/python/ops/tensor_array_ops.py')
-rw-r--r-- | tensorflow/python/ops/tensor_array_ops.py | 90 |
1 files changed, 49 insertions, 41 deletions
diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py index 0681728581..f97de8d723 100644 --- a/tensorflow/python/ops/tensor_array_ops.py +++ b/tensorflow/python/ops/tensor_array_ops.py @@ -82,7 +82,7 @@ class TensorArray(object): def __init__(self, dtype, size=None, dynamic_size=None, clear_after_read=None, tensor_array_name=None, handle=None, - flow=None, infer_shape=True, elem_shape=None, name=None): + flow=None, infer_shape=True, element_shape=None, name=None): """Construct a new TensorArray or wrap an existing TensorArray handle. A note about the parameter `name`: @@ -110,8 +110,9 @@ class TensorArray(object): `TensorArray.flow`. infer_shape: (optional, default: True) If True, shape inference is enabled. In this case, all elements must have the same shape. - elem_shape: (optional, default: None) A TensorShape object specifying - the shape of all the elements of the TensorArray. + element_shape: (optional, default: None) A `TensorShape` object specifying + the shape constraints of each of the elements of the TensorArray. + Need not be fully defined. name: A name for the operation (optional). Raises: @@ -128,6 +129,9 @@ class TensorArray(object): if handle is not None and size is not None: raise ValueError("Cannot provide both a handle and size " "at the same time") + if handle is not None and element_shape is not None: + raise ValueError("Cannot provide both a handle and element_shape " + "at the same time") if handle is not None and dynamic_size is not None: raise ValueError("Cannot provide both a handle and dynamic_size " "at the same time") @@ -141,15 +145,15 @@ class TensorArray(object): self._dtype = dtype # Record the current static shape for the array elements. The element - # shape is defined either by `elem_shape` or the shape of the tensor + # shape is defined either by `element_shape` or the shape of the tensor # of the first write. If `infer_shape` is true, all writes checks for # shape equality. - if elem_shape is None: + if element_shape is None: self._infer_shape = infer_shape - self._elem_shape = [] + self._element_shape = [] else: self._infer_shape = True - self._elem_shape = [tensor_shape.TensorShape(elem_shape)] + self._element_shape = [tensor_shape.TensorShape(element_shape)] with ops.name_scope(name, "TensorArray", [handle, size, flow]) as scope: if handle is not None: self._handle = handle @@ -157,7 +161,8 @@ class TensorArray(object): if flow is not None: with ops.colocate_with(flow): self._handle = gen_data_flow_ops._tensor_array_v2( - dtype=dtype, size=size, dynamic_size=dynamic_size, + dtype=dtype, size=size, element_shape=element_shape, + dynamic_size=dynamic_size, clear_after_read=clear_after_read, tensor_array_name=tensor_array_name, name=scope) else: @@ -166,7 +171,8 @@ class TensorArray(object): # will retroactively set the device value of this op. with ops.device(None), ops.colocate_with(None, ignore_existing=True): self._handle = gen_data_flow_ops._tensor_array_v2( - dtype=dtype, size=size, dynamic_size=dynamic_size, + dtype=dtype, size=size, element_shape=element_shape, + dynamic_size=dynamic_size, clear_after_read=clear_after_read, tensor_array_name=tensor_array_name, name=scope) if flow is not None: @@ -205,7 +211,7 @@ class TensorArray(object): flow = array_ops.identity(flow, name="gradient_flow") g = TensorArray(dtype=self._dtype, handle=g_handle, flow=flow, infer_shape=self._infer_shape) - g._elem_shape = self._elem_shape + g._element_shape = self._element_shape return g def read(self, index, name=None): @@ -222,8 +228,8 @@ class TensorArray(object): value = gen_data_flow_ops._tensor_array_read_v2( handle=self._handle, index=index, flow_in=self._flow, dtype=self._dtype, name=name) - if self._elem_shape: - value.set_shape(self._elem_shape[0].dims) + if self._element_shape: + value.set_shape(self._element_shape[0].dims) return value def write(self, index, value, name=None): @@ -251,16 +257,16 @@ class TensorArray(object): ta = TensorArray(dtype=self._dtype, handle=self._handle) ta._flow = flow_out ta._infer_shape = self._infer_shape - ta._elem_shape = self._elem_shape + ta._element_shape = self._element_shape if ta._infer_shape: - val_shape = flow_out.op.inputs[2].get_shape() - if ta._elem_shape: - if not val_shape == ta._elem_shape[0]: + val_shape = value.get_shape() + if ta._element_shape: + if not val_shape == ta._element_shape[0]: raise ValueError( "Inconsistent shapes: saw %s but expected %s " - "(and infer_shape=True)" % (val_shape, ta._elem_shape[0])) + "(and infer_shape=True)" % (val_shape, ta._element_shape[0])) else: - ta._elem_shape.append(val_shape) + ta._element_shape.append(val_shape) return ta def stack(self, name=None): @@ -302,8 +308,8 @@ class TensorArray(object): The in the `TensorArray` selected by `indices`, packed into one tensor. """ with ops.colocate_with(self._handle): - if self._elem_shape: - element_shape = self._elem_shape[0] + if self._element_shape: + element_shape = self._element_shape[0] else: element_shape = tensor_shape.TensorShape(None) value = gen_data_flow_ops._tensor_array_gather_v2( @@ -313,8 +319,8 @@ class TensorArray(object): dtype=self._dtype, name=name, element_shape=element_shape) - if self._elem_shape and self._elem_shape[0].dims is not None: - value.set_shape([None] + self._elem_shape[0].dims) + if self._element_shape and self._element_shape[0].dims is not None: + value.set_shape([None] + self._element_shape[0].dims) return value def concat(self, name=None): @@ -329,9 +335,9 @@ class TensorArray(object): Returns: All the tensors in the TensorArray concatenated into one tensor. """ - if self._elem_shape and self._elem_shape[0].dims is not None: - element_shape_except0 = tensor_shape.TensorShape(self._elem_shape[0].dims[ - 1:]) + if self._element_shape and self._element_shape[0].dims is not None: + element_shape_except0 = ( + tensor_shape.TensorShape(self._element_shape[0].dims[1:])) else: element_shape_except0 = tensor_shape.TensorShape(None) with ops.colocate_with(self._handle): @@ -341,8 +347,8 @@ class TensorArray(object): dtype=self._dtype, name=name, element_shape_except0=element_shape_except0) - if self._elem_shape and self._elem_shape[0].dims is not None: - value.set_shape([None] + self._elem_shape[0].dims[1:]) + if self._element_shape and self._element_shape[0].dims is not None: + value.set_shape([None] + self._element_shape[0].dims[1:]) return value def unstack(self, value, name=None): @@ -401,19 +407,20 @@ class TensorArray(object): ta = TensorArray(dtype=self._dtype, handle=self._handle) ta._flow = flow_out ta._infer_shape = self._infer_shape - ta._elem_shape = self._elem_shape + ta._element_shape = self._element_shape if ta._infer_shape: val_shape = flow_out.op.inputs[2].get_shape() - elem_shape = tensor_shape.unknown_shape() + element_shape = tensor_shape.unknown_shape() if val_shape.dims is not None: - elem_shape = tensor_shape.TensorShape(val_shape.dims[1:]) - if ta._elem_shape: - if not elem_shape == ta._elem_shape[0]: + element_shape = tensor_shape.TensorShape(val_shape.dims[1:]) + if ta._element_shape: + if not element_shape == ta._element_shape[0]: raise ValueError( "Inconsistent shapes: saw %s but expected %s " - "(and infer_shape=True)" % (elem_shape, ta._elem_shape[0])) + "(and infer_shape=True)" + % (element_shape, ta._element_shape[0])) else: - ta._elem_shape.append(elem_shape) + ta._element_shape.append(element_shape) return ta def split(self, value, lengths, name=None): @@ -444,22 +451,23 @@ class TensorArray(object): ta = TensorArray(dtype=self._dtype, handle=self._handle) ta._flow = flow_out ta._infer_shape = self._infer_shape - ta._elem_shape = self._elem_shape + ta._element_shape = self._element_shape if ta._infer_shape: val_shape = flow_out.op.inputs[1].get_shape() clengths = tensor_util.constant_value(flow_out.op.inputs[2]) - elem_shape = tensor_shape.unknown_shape() + element_shape = tensor_shape.unknown_shape() if val_shape.dims is not None: if clengths is not None and clengths.max() == clengths.min(): - elem_shape = tensor_shape.TensorShape( + element_shape = tensor_shape.TensorShape( [clengths[0]] + val_shape.dims[1:]) - if ta._elem_shape: - if not elem_shape == ta._elem_shape[0]: + if ta._element_shape: + if not element_shape == ta._element_shape[0]: raise ValueError( "Inconsistent shapes: saw %s but expected %s " - "(and infer_shape=True)" % (elem_shape, ta._elem_shape[0])) + "(and infer_shape=True)" + % (element_shape, ta._element_shape[0])) else: - ta._elem_shape.append(elem_shape) + ta._element_shape.append(element_shape) return ta def size(self, name=None): |