aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/tensor_array_ops.py
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2016-12-01 15:15:48 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-01 15:34:47 -0800
commitedc5d96fa11966a9d1cabafd651eacc95c86608a (patch)
tree9b3686938feab2dd9ea83c4db8de1d66b53e6863 /tensorflow/python/ops/tensor_array_ops.py
parent6209ae88ca436b13c5807df3bb237a5613d42215 (diff)
Add element_shape property to TensorArray creation op.
Change: 140784570
Diffstat (limited to 'tensorflow/python/ops/tensor_array_ops.py')
-rw-r--r--tensorflow/python/ops/tensor_array_ops.py90
1 files changed, 49 insertions, 41 deletions
diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py
index 0681728581..f97de8d723 100644
--- a/tensorflow/python/ops/tensor_array_ops.py
+++ b/tensorflow/python/ops/tensor_array_ops.py
@@ -82,7 +82,7 @@ class TensorArray(object):
def __init__(self, dtype, size=None, dynamic_size=None,
clear_after_read=None, tensor_array_name=None, handle=None,
- flow=None, infer_shape=True, elem_shape=None, name=None):
+ flow=None, infer_shape=True, element_shape=None, name=None):
"""Construct a new TensorArray or wrap an existing TensorArray handle.
A note about the parameter `name`:
@@ -110,8 +110,9 @@ class TensorArray(object):
`TensorArray.flow`.
infer_shape: (optional, default: True) If True, shape inference
is enabled. In this case, all elements must have the same shape.
- elem_shape: (optional, default: None) A TensorShape object specifying
- the shape of all the elements of the TensorArray.
+ element_shape: (optional, default: None) A `TensorShape` object specifying
+ the shape constraints of each of the elements of the TensorArray.
+ Need not be fully defined.
name: A name for the operation (optional).
Raises:
@@ -128,6 +129,9 @@ class TensorArray(object):
if handle is not None and size is not None:
raise ValueError("Cannot provide both a handle and size "
"at the same time")
+ if handle is not None and element_shape is not None:
+ raise ValueError("Cannot provide both a handle and element_shape "
+ "at the same time")
if handle is not None and dynamic_size is not None:
raise ValueError("Cannot provide both a handle and dynamic_size "
"at the same time")
@@ -141,15 +145,15 @@ class TensorArray(object):
self._dtype = dtype
# Record the current static shape for the array elements. The element
- # shape is defined either by `elem_shape` or the shape of the tensor
+ # shape is defined either by `element_shape` or the shape of the tensor
# of the first write. If `infer_shape` is true, all writes checks for
# shape equality.
- if elem_shape is None:
+ if element_shape is None:
self._infer_shape = infer_shape
- self._elem_shape = []
+ self._element_shape = []
else:
self._infer_shape = True
- self._elem_shape = [tensor_shape.TensorShape(elem_shape)]
+ self._element_shape = [tensor_shape.TensorShape(element_shape)]
with ops.name_scope(name, "TensorArray", [handle, size, flow]) as scope:
if handle is not None:
self._handle = handle
@@ -157,7 +161,8 @@ class TensorArray(object):
if flow is not None:
with ops.colocate_with(flow):
self._handle = gen_data_flow_ops._tensor_array_v2(
- dtype=dtype, size=size, dynamic_size=dynamic_size,
+ dtype=dtype, size=size, element_shape=element_shape,
+ dynamic_size=dynamic_size,
clear_after_read=clear_after_read,
tensor_array_name=tensor_array_name, name=scope)
else:
@@ -166,7 +171,8 @@ class TensorArray(object):
# will retroactively set the device value of this op.
with ops.device(None), ops.colocate_with(None, ignore_existing=True):
self._handle = gen_data_flow_ops._tensor_array_v2(
- dtype=dtype, size=size, dynamic_size=dynamic_size,
+ dtype=dtype, size=size, element_shape=element_shape,
+ dynamic_size=dynamic_size,
clear_after_read=clear_after_read,
tensor_array_name=tensor_array_name, name=scope)
if flow is not None:
@@ -205,7 +211,7 @@ class TensorArray(object):
flow = array_ops.identity(flow, name="gradient_flow")
g = TensorArray(dtype=self._dtype, handle=g_handle, flow=flow,
infer_shape=self._infer_shape)
- g._elem_shape = self._elem_shape
+ g._element_shape = self._element_shape
return g
def read(self, index, name=None):
@@ -222,8 +228,8 @@ class TensorArray(object):
value = gen_data_flow_ops._tensor_array_read_v2(
handle=self._handle, index=index, flow_in=self._flow,
dtype=self._dtype, name=name)
- if self._elem_shape:
- value.set_shape(self._elem_shape[0].dims)
+ if self._element_shape:
+ value.set_shape(self._element_shape[0].dims)
return value
def write(self, index, value, name=None):
@@ -251,16 +257,16 @@ class TensorArray(object):
ta = TensorArray(dtype=self._dtype, handle=self._handle)
ta._flow = flow_out
ta._infer_shape = self._infer_shape
- ta._elem_shape = self._elem_shape
+ ta._element_shape = self._element_shape
if ta._infer_shape:
- val_shape = flow_out.op.inputs[2].get_shape()
- if ta._elem_shape:
- if not val_shape == ta._elem_shape[0]:
+ val_shape = value.get_shape()
+ if ta._element_shape:
+ if not val_shape == ta._element_shape[0]:
raise ValueError(
"Inconsistent shapes: saw %s but expected %s "
- "(and infer_shape=True)" % (val_shape, ta._elem_shape[0]))
+ "(and infer_shape=True)" % (val_shape, ta._element_shape[0]))
else:
- ta._elem_shape.append(val_shape)
+ ta._element_shape.append(val_shape)
return ta
def stack(self, name=None):
@@ -302,8 +308,8 @@ class TensorArray(object):
The in the `TensorArray` selected by `indices`, packed into one tensor.
"""
with ops.colocate_with(self._handle):
- if self._elem_shape:
- element_shape = self._elem_shape[0]
+ if self._element_shape:
+ element_shape = self._element_shape[0]
else:
element_shape = tensor_shape.TensorShape(None)
value = gen_data_flow_ops._tensor_array_gather_v2(
@@ -313,8 +319,8 @@ class TensorArray(object):
dtype=self._dtype,
name=name,
element_shape=element_shape)
- if self._elem_shape and self._elem_shape[0].dims is not None:
- value.set_shape([None] + self._elem_shape[0].dims)
+ if self._element_shape and self._element_shape[0].dims is not None:
+ value.set_shape([None] + self._element_shape[0].dims)
return value
def concat(self, name=None):
@@ -329,9 +335,9 @@ class TensorArray(object):
Returns:
All the tensors in the TensorArray concatenated into one tensor.
"""
- if self._elem_shape and self._elem_shape[0].dims is not None:
- element_shape_except0 = tensor_shape.TensorShape(self._elem_shape[0].dims[
- 1:])
+ if self._element_shape and self._element_shape[0].dims is not None:
+ element_shape_except0 = (
+ tensor_shape.TensorShape(self._element_shape[0].dims[1:]))
else:
element_shape_except0 = tensor_shape.TensorShape(None)
with ops.colocate_with(self._handle):
@@ -341,8 +347,8 @@ class TensorArray(object):
dtype=self._dtype,
name=name,
element_shape_except0=element_shape_except0)
- if self._elem_shape and self._elem_shape[0].dims is not None:
- value.set_shape([None] + self._elem_shape[0].dims[1:])
+ if self._element_shape and self._element_shape[0].dims is not None:
+ value.set_shape([None] + self._element_shape[0].dims[1:])
return value
def unstack(self, value, name=None):
@@ -401,19 +407,20 @@ class TensorArray(object):
ta = TensorArray(dtype=self._dtype, handle=self._handle)
ta._flow = flow_out
ta._infer_shape = self._infer_shape
- ta._elem_shape = self._elem_shape
+ ta._element_shape = self._element_shape
if ta._infer_shape:
val_shape = flow_out.op.inputs[2].get_shape()
- elem_shape = tensor_shape.unknown_shape()
+ element_shape = tensor_shape.unknown_shape()
if val_shape.dims is not None:
- elem_shape = tensor_shape.TensorShape(val_shape.dims[1:])
- if ta._elem_shape:
- if not elem_shape == ta._elem_shape[0]:
+ element_shape = tensor_shape.TensorShape(val_shape.dims[1:])
+ if ta._element_shape:
+ if not element_shape == ta._element_shape[0]:
raise ValueError(
"Inconsistent shapes: saw %s but expected %s "
- "(and infer_shape=True)" % (elem_shape, ta._elem_shape[0]))
+ "(and infer_shape=True)"
+ % (element_shape, ta._element_shape[0]))
else:
- ta._elem_shape.append(elem_shape)
+ ta._element_shape.append(element_shape)
return ta
def split(self, value, lengths, name=None):
@@ -444,22 +451,23 @@ class TensorArray(object):
ta = TensorArray(dtype=self._dtype, handle=self._handle)
ta._flow = flow_out
ta._infer_shape = self._infer_shape
- ta._elem_shape = self._elem_shape
+ ta._element_shape = self._element_shape
if ta._infer_shape:
val_shape = flow_out.op.inputs[1].get_shape()
clengths = tensor_util.constant_value(flow_out.op.inputs[2])
- elem_shape = tensor_shape.unknown_shape()
+ element_shape = tensor_shape.unknown_shape()
if val_shape.dims is not None:
if clengths is not None and clengths.max() == clengths.min():
- elem_shape = tensor_shape.TensorShape(
+ element_shape = tensor_shape.TensorShape(
[clengths[0]] + val_shape.dims[1:])
- if ta._elem_shape:
- if not elem_shape == ta._elem_shape[0]:
+ if ta._element_shape:
+ if not element_shape == ta._element_shape[0]:
raise ValueError(
"Inconsistent shapes: saw %s but expected %s "
- "(and infer_shape=True)" % (elem_shape, ta._elem_shape[0]))
+ "(and infer_shape=True)"
+ % (element_shape, ta._element_shape[0]))
else:
- ta._elem_shape.append(elem_shape)
+ ta._element_shape.append(element_shape)
return ta
def size(self, name=None):