diff options
author | 2016-02-03 09:23:32 -0800 | |
---|---|---|
committer | 2016-02-03 09:36:02 -0800 | |
commit | 520247f8a4fb59b163a71482268c830de94be09c (patch) | |
tree | 0c992aab179084683d36e104ae48b442b81d286f /tensorflow/core/kernels/tensor_array.h | |
parent | e1b77eb2931c60c71bc149bb1ec13394c508162b (diff) |
Improve TensorArray: reduce thread contention, add dynamic write and size.
Change: 113749476
Diffstat (limited to 'tensorflow/core/kernels/tensor_array.h')
-rw-r--r-- | tensorflow/core/kernels/tensor_array.h | 117 |
1 files changed, 58 insertions, 59 deletions
diff --git a/tensorflow/core/kernels/tensor_array.h b/tensorflow/core/kernels/tensor_array.h index a6637a2340..82026c3cb4 100644 --- a/tensorflow/core/kernels/tensor_array.h +++ b/tensorflow/core/kernels/tensor_array.h @@ -24,16 +24,12 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/aggregate_ops.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { -typedef Eigen::ThreadPoolDevice CPUDevice; -typedef Eigen::GpuDevice GPUDevice; - // The TensorArray object keeps an array of PersistentTensors. It // allows reading from the array and writing to the array. // @@ -64,14 +60,22 @@ class TensorArray : public ResourceBase { // 'N' elements. While the underlying storage is a std::vector and // can hold more than MAX_INT entries, in practice we do not expect // users to construct this many Tensors for storage in a TensorArray. - TensorArray(const DataType& dtype, const Tensor& handle, int32 N) - : dtype_(dtype), handle_(handle), closed_(false), tensors_(N) {} + TensorArray(const DataType& dtype, const Tensor& handle, int32 N, + bool dynamic_size) + : dtype_(dtype), + handle_(handle), + closed_(false), + dynamic_size_(dynamic_size), + tensors_(N) {} // Write PersistentTensor 'value' to index 'index'. // // Preconditions: // * The TensorArray is not closed - // * The index is in [0, N) + // * If the array has dynamic size: + // The index is >= 0 + // Otherwise: + // The index is in [0, N) where N == Size() // * The dtype of the Tensor in 'value' matches the TensorArray's dtype. // * The Tensor at 'index' has not yet been written to. // @@ -80,34 +84,19 @@ class TensorArray : public ResourceBase { // * Index 'index' is marked as written. // // Note, value is passed as a pointer because we its underlying - // Tesnor's shape is accessed. Otherwise it is not modified. + // Tensor's shape is accessed. Otherwise it is not modified. Status Write(OpKernelContext* ctx, const int32 index, PersistentTensor* value) { mutex_lock l(mu_); - TF_RETURN_IF_ERROR(LockedReturnIfClosed()); - if (index < 0 || static_cast<size_t>(index) >= tensors_.size()) { - return errors::InvalidArgument("TensorArray ", handle_.vec<string>()(1), - ": Tried to write to index ", index, - " but array size is: ", tensors_.size()); - } - TensorAndState& t = tensors_[index]; - if (t.written) { - return errors::InvalidArgument( - "TensorArray ", handle_.vec<string>()(1), - ": Could not write to TensorArray index ", index, - " because it has already been written to."); - } - Tensor* value_t = value->AccessTensor(ctx); - if (value_t->dtype() != dtype_) { - return errors::InvalidArgument( - "TensorArray ", handle_.vec<string>()(1), - ": Could not write to TensorArray index ", index, - " because the value dtype is ", DataTypeString(value_t->dtype()), - " but TensorArray dtype is ", DataTypeString(dtype_), "."); + return LockedWrite(ctx, index, value); + } + + Status WriteMany(OpKernelContext* ctx, + std::vector<PersistentTensor>* values) { + mutex_lock l(mu_); + for (int32 i = values->size() - 1; i >= 0; --i) { + TF_RETURN_IF_ERROR(LockedWrite(ctx, i, &(*values)[i])); } - t.tensor = *value; - t.shape = value_t->shape(); - t.written = true; return Status::OK(); } @@ -126,34 +115,16 @@ class TensorArray : public ResourceBase { // * Index 'index' is marked as read. Status Read(const int32 index, PersistentTensor* value) { mutex_lock l(mu_); - TF_RETURN_IF_ERROR(LockedReturnIfClosed()); - if (index < 0 || static_cast<size_t>(index) >= tensors_.size()) { - return errors::InvalidArgument("Tried to read from index ", index, - " but array size is: ", tensors_.size()); - } - TensorAndState& t = tensors_[index]; - if (t.read) { - return errors::InvalidArgument( - "TensorArray ", handle_.vec<string>()(1), ": Could not read index ", - index, " twice because TensorArray a read-once object."); - } - if (!t.written) { - return errors::InvalidArgument( - "TensorArray ", handle_.vec<string>()(1), - ": Could not read from TensorArray index ", index, - " because it has not yet been written to."); - } - *value = t.tensor; - t.read = true; - t.tensor = PersistentTensor(); - return Status::OK(); + return LockedRead(index, value); } - // Return the Size of the TensorArray. - Status Size(int32* size) { + Status ReadMany(std::vector<PersistentTensor>* values) { mutex_lock l(mu_); - TF_RETURN_IF_ERROR(LockedReturnIfClosed()); - *size = tensors_.size(); + values->clear(); + values->resize(tensors_.size()); + for (int32 i = 0; i < tensors_.size(); ++i) { + TF_RETURN_IF_ERROR(LockedRead(i, &(*values)[i])); + } return Status::OK(); } @@ -165,11 +136,31 @@ class TensorArray : public ResourceBase { return strings::StrCat("TensorArray[", tensors_.size(), "]"); } - inline bool IsClosed() { + bool IsClosed() { mutex_lock l(mu_); return closed_; } + // Return the Size of the TensorArray. + Status Size(int32* size) { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(LockedReturnIfClosed()); + *size = tensors_.size(); + return Status::OK(); + } + + // Once a TensorArray is being used for gradient calculations, it + // should be marked as no longer resizeable. + void DisableDynamicSize() { + mutex_lock l(mu_); + dynamic_size_ = false; + } + + bool HasDynamicSize() { + mutex_lock l(mu_); + return dynamic_size_; + } + // Clear the TensorArray, including any Tensor references, and mark as closed. void ClearAndMarkClosed() { mutex_lock l(mu_); @@ -181,7 +172,13 @@ class TensorArray : public ResourceBase { Tensor* handle() { return &handle_; } private: - Status LockedReturnIfClosed() const { + Status LockedWrite(OpKernelContext* ctx, const int32 index, + PersistentTensor* value) EXCLUSIVE_LOCKS_REQUIRED(mu_); + + Status LockedRead(const int32 index, PersistentTensor* value) + EXCLUSIVE_LOCKS_REQUIRED(mu_); + + Status LockedReturnIfClosed() const EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (closed_) { return errors::InvalidArgument("TensorArray ", handle_.vec<string>()(1), " has already been closed."); @@ -189,7 +186,7 @@ class TensorArray : public ResourceBase { return Status::OK(); } - DataType dtype_; + const DataType dtype_; Tensor handle_; mutex mu_; @@ -197,6 +194,8 @@ class TensorArray : public ResourceBase { bool closed_ GUARDED_BY(mu_); // Marks that the tensor_array_ has been cleared. + bool dynamic_size_; // Determines if Writes are allowed to grow the array. + // TensorAndState is used to keep track of the PersistentTensors // stored in the TensorArray, along with their shapes, and a boolean // that determines whether they have already been read or not. |