aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/tensor_array_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/tensor_array_ops.py')
-rw-r--r--tensorflow/python/ops/tensor_array_ops.py22
1 files changed, 14 insertions, 8 deletions
diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py
index c12506694f..901dfbe913 100644
--- a/tensorflow/python/ops/tensor_array_ops.py
+++ b/tensorflow/python/ops/tensor_array_ops.py
@@ -46,8 +46,8 @@ class TensorArray(object):
@@grad
"""
- def __init__(
- self, dtype, size=None, tensor_array_name=None, handle=None, name=None):
+ def __init__(self, dtype, size=None, tensor_array_name=None,
+ handle=None, flow=None, name=None):
"""Construct a new TensorArray or wrap an existing TensorArray handle.
Args:
@@ -59,6 +59,8 @@ class TensorArray(object):
set, handle should be None.
handle: (optional) A `Tensor` handle to an existing TensorArray. If this
is set, tensor_array_name should be None.
+ flow: (optional) A float `Tensor` scalar coming from an existing
+ TensorArray.flow.
name: A name for the operation (optional).
Raises:
@@ -73,16 +75,15 @@ class TensorArray(object):
if handle is None and size is None:
raise ValueError("Size must be provided if handle is not provided")
- with ops.op_scope([handle, size], name, "TensorArray") as scope:
+ self._dtype = dtype
+ with ops.op_scope([handle, size, flow], name, "TensorArray") as scope:
if handle:
self._handle = handle
else:
self._handle = gen_data_flow_ops._tensor_array(
dtype=dtype, size=size, tensor_array_name=tensor_array_name,
name=scope)
-
- self._flow = constant_op.constant(0, dtype=_dtypes.float32)
- self._dtype = dtype
+ self._flow = flow or constant_op.constant(0, dtype=_dtypes.float32)
@property
def flow(self):
@@ -90,14 +91,19 @@ class TensorArray(object):
return self._flow
@property
+ def dtype(self):
+ """The data type of this TensorArray."""
+ return self._dtype
+
+ @property
def handle(self):
"""The reference to the TensorArray."""
return self._handle
- def grad(self, source):
+ def grad(self, source, flow=None):
g_handle = gen_data_flow_ops._tensor_array_grad(
handle=self._handle, source=source)
- g = TensorArray(dtype=self._dtype, size=None, handle=g_handle)
+ g = TensorArray(dtype=self._dtype, size=None, handle=g_handle, flow=flow)
return g
def read(self, index, name=None):