diff options
Diffstat (limited to 'tensorflow/python/ops/tensor_array_ops.py')
-rw-r--r-- | tensorflow/python/ops/tensor_array_ops.py | 22 |
1 files changed, 14 insertions, 8 deletions
diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py index c12506694f..901dfbe913 100644 --- a/tensorflow/python/ops/tensor_array_ops.py +++ b/tensorflow/python/ops/tensor_array_ops.py @@ -46,8 +46,8 @@ class TensorArray(object): @@grad """ - def __init__( - self, dtype, size=None, tensor_array_name=None, handle=None, name=None): + def __init__(self, dtype, size=None, tensor_array_name=None, + handle=None, flow=None, name=None): """Construct a new TensorArray or wrap an existing TensorArray handle. Args: @@ -59,6 +59,8 @@ class TensorArray(object): set, handle should be None. handle: (optional) A `Tensor` handle to an existing TensorArray. If this is set, tensor_array_name should be None. + flow: (optional) A float `Tensor` scalar coming from an existing + TensorArray.flow. name: A name for the operation (optional). Raises: @@ -73,16 +75,15 @@ class TensorArray(object): if handle is None and size is None: raise ValueError("Size must be provided if handle is not provided") - with ops.op_scope([handle, size], name, "TensorArray") as scope: + self._dtype = dtype + with ops.op_scope([handle, size, flow], name, "TensorArray") as scope: if handle: self._handle = handle else: self._handle = gen_data_flow_ops._tensor_array( dtype=dtype, size=size, tensor_array_name=tensor_array_name, name=scope) - - self._flow = constant_op.constant(0, dtype=_dtypes.float32) - self._dtype = dtype + self._flow = flow or constant_op.constant(0, dtype=_dtypes.float32) @property def flow(self): @@ -90,14 +91,19 @@ class TensorArray(object): return self._flow @property + def dtype(self): + """The data type of this TensorArray.""" + return self._dtype + + @property def handle(self): """The reference to the TensorArray.""" return self._handle - def grad(self, source): + def grad(self, source, flow=None): g_handle = gen_data_flow_ops._tensor_array_grad( handle=self._handle, source=source) - g = TensorArray(dtype=self._dtype, size=None, handle=g_handle) + g = TensorArray(dtype=self._dtype, size=None, handle=g_handle, flow=flow) return g def read(self, index, name=None): |