diff options
Diffstat (limited to 'tensorflow/core/kernels/tensor_array.h')
-rw-r--r-- | tensorflow/core/kernels/tensor_array.h | 17 |
1 files changed, 12 insertions, 5 deletions
diff --git a/tensorflow/core/kernels/tensor_array.h b/tensorflow/core/kernels/tensor_array.h index ae1700cd0a..4704130994 100644 --- a/tensorflow/core/kernels/tensor_array.h +++ b/tensorflow/core/kernels/tensor_array.h @@ -132,11 +132,12 @@ class TensorArray : public ResourceBase { // 'N' elements. While the underlying storage is a std::vector and // can hold more than MAX_INT entries, in practice we do not expect // users to construct this many Tensors for storage in a TensorArray. - TensorArray(const DataType& dtype, const Tensor& handle, int32 N, - const PartialTensorShape& element_shape, bool dynamic_size, - bool multiple_writes_aggregate, bool is_grad, int32 marked_size, - bool clear_after_read) - : dtype_(dtype), + TensorArray(const string& key, const DataType& dtype, const Tensor& handle, + int32 N, const PartialTensorShape& element_shape, + bool dynamic_size, bool multiple_writes_aggregate, bool is_grad, + int32 marked_size, bool clear_after_read) + : key_(key), + dtype_(dtype), handle_(handle), closed_(false), dynamic_size_(dynamic_size), @@ -334,6 +335,10 @@ class TensorArray : public ResourceBase { mutex* mu() { return &mu_; } Tensor* handle() { return &handle_; } + ResourceHandle resource_handle(OpKernelContext* ctx) { + return MakePerStepResourceHandle<TensorArray>(ctx, key_); + } + private: Status LockedWrite(OpKernelContext* ctx, const int32 index, PersistentTensor* value) EXCLUSIVE_LOCKS_REQUIRED(mu_); @@ -355,6 +360,8 @@ class TensorArray : public ResourceBase { return Status::OK(); } + const string key_; + const DataType dtype_; Tensor handle_; |