Diffstat (limited to 'tensorflow/core/kernels/data/slide_dataset_op.cc')
-rw-r--r-- | tensorflow/core/kernels/data/slide_dataset_op.cc | 157 |
1 file changed, 84 insertions, 73 deletions
diff --git a/tensorflow/core/kernels/data/slide_dataset_op.cc b/tensorflow/core/kernels/data/slide_dataset_op.cc
index c17e9343ea..5765c61f30 100644
--- a/tensorflow/core/kernels/data/slide_dataset_op.cc
+++ b/tensorflow/core/kernels/data/slide_dataset_op.cc
@@ -12,6 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+
+#include <deque>
+#include <vector>
+
 #include "tensorflow/core/framework/partial_tensor_shape.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/kernels/data/dataset.h"
@@ -33,37 +37,40 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                    DatasetBase** output) override {
     int64 window_size = 0;
-    int64 stride = 0;
     OP_REQUIRES_OK(
         ctx, ParseScalarArgument<int64>(ctx, "window_size", &window_size));
-    OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "stride", &stride));
     OP_REQUIRES(
         ctx, window_size > 0,
         errors::InvalidArgument("Window size must be greater than zero."));
+    int64 window_shift = 0;
+    OP_REQUIRES_OK(
+        ctx, ParseScalarArgument<int64>(ctx, "window_shift", &window_shift));
+    OP_REQUIRES(
+        ctx, window_shift > 0,
+        errors::InvalidArgument("Window shift must be greater than zero."));
+    int64 window_stride = 0;
+    OP_REQUIRES_OK(
+        ctx, ParseScalarArgument<int64>(ctx, "window_stride", &window_stride));
     OP_REQUIRES(
-        ctx, stride > 0,
-        errors::InvalidArgument("Stride must be greater than zero."));
-    if (stride == window_size) {
-      LOG(WARNING) << "stride: " << stride
+        ctx, window_stride > 0,
+        errors::InvalidArgument("window_stride must be greater than zero."));
+    if (window_size == window_shift && window_stride == 1) {
+      LOG(WARNING) << "window_shift: " << window_shift
                    << " is equal to window_size: " << window_size
-                   << ", to use `batch` instead.";
-    } else if (stride > window_size) {
-      LOG(WARNING) << "stride: " << stride
-                   << " is greater than window_size: " << window_size
-                   << ", you will lose some data.";
+                   << " and window_stride is 1, use `batch` instead.";
     }
-
-    *output = new Dataset(ctx, window_size, stride, input);
+    *output = new Dataset(ctx, window_size, window_shift, window_stride, input);
   }
 
  private:
  class Dataset : public GraphDatasetBase {
   public:
-    Dataset(OpKernelContext* ctx, int64 window_size, int64 stride,
-            const DatasetBase* input)
+    Dataset(OpKernelContext* ctx, int64 window_size, int64 window_shift,
+            int64 window_stride, const DatasetBase* input)
         : GraphDatasetBase(ctx),
           window_size_(window_size),
-          stride_(stride),
+          window_shift_(window_shift),
+          window_stride_(window_stride),
           input_(input) {
       input_->Ref();
 
@@ -92,8 +99,8 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
     }
 
     string DebugString() const override {
-      return strings::StrCat("SlideDatasetOp(", window_size_, ", ", stride_,
-                             ")::Dataset");
+      return strings::StrCat("SlideDatasetOp(", window_size_, ", ",
+                             window_shift_, ", ", window_stride_, ")::Dataset");
     }
 
    protected:
@@ -102,16 +109,18 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
      Node* input_graph_node = nullptr;
      TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
      Node* window_size = nullptr;
-     Node* stride = nullptr;
+     Node* window_shift = nullptr;
+     Node* window_stride = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(window_size_, &window_size));
-     TF_RETURN_IF_ERROR(b->AddScalar(stride_, &stride));
-     TF_RETURN_IF_ERROR(
-         b->AddDataset(this, {input_graph_node, window_size, stride}, output));
+     TF_RETURN_IF_ERROR(b->AddScalar(window_shift_, &window_shift));
+     TF_RETURN_IF_ERROR(b->AddScalar(window_stride_, &window_stride));
+     TF_RETURN_IF_ERROR(b->AddDataset(
+         this, {input_graph_node, window_size, window_shift, window_stride},
+         output));
      return Status::OK();
    }
 
   private:
-
    class Iterator : public DatasetIterator<Dataset> {
     public:
      explicit Iterator(const Params& params)
@@ -125,7 +134,8 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
                            std::vector<Tensor>* out_tensors,
                            bool* end_of_sequence) override {
        const int64 window_size = dataset()->window_size_;
-       const int64 stride = dataset()->stride_;
+       const int64 window_shift = dataset()->window_shift_;
+       const int64 window_stride = dataset()->window_stride_;
        std::vector<std::vector<Tensor>> batch_elements;
        {
          mutex_lock l(mu_);
@@ -134,55 +144,51 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
            return Status::OK();
          }
          batch_elements.reserve(window_size);
-         // Use cache if stride < window_size.
-         if (stride < window_size) {
-           const bool first_call = cache_.empty();
-           if (first_call) {
-             cache_.reserve(window_size);
-           } else {
-             // Reuse cache in the previous iteration.
-             cache_.swap(batch_elements);
-           }
-         }
-         // Fill up with new elements.
+
+         // Fill up buffer.
+         size_t target_size = TargetBufferSize(window_size, window_stride);
         *end_of_sequence = false;
-         for (size_t i = batch_elements.size(); i < window_size && !*end_of_sequence;
-              ++i) {
-           std::vector<Tensor> batch_element_tuple;
-           TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &batch_element_tuple,
-                                                   end_of_sequence));
+         for (size_t i = buffer_.size(); i < target_size && !*end_of_sequence;
+              ++i) {
+           std::vector<Tensor> element;
+           TF_RETURN_IF_ERROR(
+               input_impl_->GetNext(ctx, &element, end_of_sequence));
           if (!*end_of_sequence) {
-             batch_elements.push_back(std::move(batch_element_tuple));
+             buffer_.push_back(std::move(element));
           } else {
             input_impl_.reset();
           }
         }
-         // Drop the final smaller blocks.
-         if (batch_elements.size() < window_size) {
+
+         // Drop the final smaller batch.
+         if (buffer_.size() < target_size) {
           DCHECK(*end_of_sequence);
           return Status::OK();
         }
-         if (stride < window_size) {
-           // Cache the data used for the next iteration.
-           for (size_t i = stride; i < window_size; ++i) {
-             cache_.emplace_back(batch_elements[i]);
-           }
-         } else if (stride > window_size) {
-           // Drop the data before the next iteration.
-           std::vector<Tensor> batch_element_tuple;
-           for (size_t i = window_size; i < stride && !*end_of_sequence; ++i) {
-             TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &batch_element_tuple,
-                                                     end_of_sequence));
-             if (*end_of_sequence) {
+
+         for (size_t i = 0; i < window_size; ++i) {
+           batch_elements.emplace_back(buffer_[window_stride * i]);
+         }
+
+         // Drop the data before the next iteration.
+         if (window_shift >= buffer_.size()) {
+           for (size_t i = buffer_.size(); i < window_shift; ++i) {
+             bool end_of_input;
+             std::vector<Tensor> element;
+             TF_RETURN_IF_ERROR(
+                 input_impl_->GetNext(ctx, &element, &end_of_input));
+             if (end_of_input) {
               input_impl_.reset();
+               break;
             }
           }
+           buffer_.clear();
+         } else {
+           buffer_.erase(buffer_.begin(), buffer_.begin() + window_shift);
         }
       }
 
       // Construct output tensors.
-       // Those codes below are copied from batch_dataset_op.cc.
       const size_t num_tuple_components = batch_elements[0].size();
       const int64 num_batch_elements = batch_elements.size();
       for (size_t component_index = 0; component_index < num_tuple_components;
@@ -224,15 +230,15 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
       } else {
         TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
       }
-       // Save cache.
-       TF_RETURN_IF_ERROR(
-           writer->WriteScalar(strings::StrCat("cache_size"), cache_.size()));
-       for (int64 i = 0; i < cache_.size(); i++) {
+       // Save buffer.
+       TF_RETURN_IF_ERROR(writer->WriteScalar(strings::StrCat("buffer_size"),
+                                              buffer_.size()));
+       for (int64 i = 0; i < buffer_.size(); i++) {
         TF_RETURN_IF_ERROR(writer->WriteScalar(
-             strings::StrCat("cache[", i, "]_size"), cache_[i].size()));
-         for (int64 j = 0; j < cache_[i].size(); j++) {
+             strings::StrCat("buffer[", i, "]_size"), buffer_[i].size()));
+         for (int64 j = 0; j < buffer_[i].size(); j++) {
           TF_RETURN_IF_ERROR(writer->WriteTensor(
-               strings::StrCat("cache[", i, "][", j, "]"), cache_[i][j]));
+               strings::StrCat("buffer[", i, "][", j, "]"), buffer_[i][j]));
         }
       }
       return Status::OK();
@@ -246,32 +252,37 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
       } else {
         input_impl_.reset();
       }
-       // Restore cache.
-       int64 cache_size;
+       // Restore buffer.
+       int64 buffer_size;
       TF_RETURN_IF_ERROR(
-           reader->ReadScalar(strings::StrCat("cache_size"), &cache_size));
-       cache_.resize(cache_size);
-       for (int64 i = 0; i < cache_size; i++) {
+           reader->ReadScalar(strings::StrCat("buffer_size"), &buffer_size));
+       buffer_.resize(buffer_size);
+       for (int64 i = 0; i < buffer_size; i++) {
         int64 vector_size;
         TF_RETURN_IF_ERROR(reader->ReadScalar(
-             strings::StrCat("cache[", i, "]_size"), &vector_size));
-         cache_[i].resize(vector_size);
+             strings::StrCat("buffer[", i, "]_size"), &vector_size));
+         buffer_[i].resize(vector_size);
         for (int64 j = 0; j < vector_size; j++) {
           TF_RETURN_IF_ERROR(reader->ReadTensor(
-               strings::StrCat("cache[", i, "][", j, "]"), &cache_[i][j]));
+               strings::StrCat("buffer[", i, "][", j, "]"), &buffer_[i][j]));
         }
       }
       return Status::OK();
     }
 
    private:
+     size_t TargetBufferSize(int64 window_size, int64 window_stride) {
+       return (window_size - 1) * window_stride + 1;
+     }
+
     mutex mu_;
-     std::vector<std::vector<Tensor>> cache_ GUARDED_BY(mu_);
+     std::deque<std::vector<Tensor>> buffer_ GUARDED_BY(mu_);
     std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
   };
 
   const int64 window_size_;
-   const int64 stride_;
+   const int64 window_shift_;
+   const int64 window_stride_;
   const DatasetBase* const input_;
   std::vector<PartialTensorShape> output_shapes_;
  };
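For context on what the rewritten iterator computes, the following is a minimal standalone sketch, not TensorFlow code: it buffers (window_size - 1) * window_stride + 1 elements (the same formula as the new TargetBufferSize() helper), emits the buffered elements at positions i * window_stride as one window, and then advances the buffer by window_shift. The parameter values (3, 2, 2) and the toy integer input stream [0, 10) are assumptions chosen purely for illustration.

// Standalone sketch of the sliding-window indexing this change implements.
// It mirrors the iterator's buffer logic with plain ints in place of
// std::vector<Tensor> elements; parameter values and the input stream are
// illustrative assumptions, not part of the kernel.
#include <cstddef>
#include <cstdint>
#include <deque>
#include <iostream>
#include <vector>

int main() {
  const int64_t window_size = 3, window_shift = 2, window_stride = 2;
  // Same formula as the new TargetBufferSize() helper.
  const std::size_t target_size = (window_size - 1) * window_stride + 1;

  std::deque<int> buffer;
  int next = 0;  // toy input stream: 0, 1, ..., 9
  while (true) {
    // Fill up buffer.
    while (buffer.size() < target_size && next < 10) buffer.push_back(next++);
    // Drop the final smaller batch.
    if (buffer.size() < target_size) break;

    // A window is every window_stride-th buffered element.
    std::vector<int> window;
    for (int64_t i = 0; i < window_size; ++i)
      window.push_back(buffer[window_stride * i]);
    for (int v : window) std::cout << v << ' ';
    std::cout << '\n';  // prints: 0 2 4 / 2 4 6 / 4 6 8

    // Drop the data before the next iteration. The kernel additionally
    // handles window_shift exceeding the buffered element count; that case
    // is omitted here because window_shift < target_size for these values.
    buffer.erase(buffer.begin(), buffer.begin() + window_shift);
  }
  return 0;
}

The switch from the std::vector cache_ to a std::deque buffer_ fits this access pattern: each step appends new elements at the back and erases window_shift elements from the front.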