Diffstat (limited to 'tensorflow/core/kernels/data/slide_dataset_op.cc')
-rw-r--r--  tensorflow/core/kernels/data/slide_dataset_op.cc | 157
1 file changed, 84 insertions(+), 73 deletions(-)
diff --git a/tensorflow/core/kernels/data/slide_dataset_op.cc b/tensorflow/core/kernels/data/slide_dataset_op.cc
index c17e9343ea..5765c61f30 100644
--- a/tensorflow/core/kernels/data/slide_dataset_op.cc
+++ b/tensorflow/core/kernels/data/slide_dataset_op.cc
@@ -12,6 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+
+#include <deque>
+#include <vector>
+
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/dataset.h"
@@ -33,37 +37,40 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
int64 window_size = 0;
- int64 stride = 0;
OP_REQUIRES_OK(
ctx, ParseScalarArgument<int64>(ctx, "window_size", &window_size));
- OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "stride", &stride));
OP_REQUIRES(
ctx, window_size > 0,
errors::InvalidArgument("Window size must be greater than zero."));
+ int64 window_shift = 0;
+ OP_REQUIRES_OK(
+ ctx, ParseScalarArgument<int64>(ctx, "window_shift", &window_shift));
+ OP_REQUIRES(
+ ctx, window_shift > 0,
+ errors::InvalidArgument("Window shift must be greater than zero."));
+ int64 window_stride = 0;
+ OP_REQUIRES_OK(
+ ctx, ParseScalarArgument<int64>(ctx, "window_stride", &window_stride));
OP_REQUIRES(
- ctx, stride > 0,
- errors::InvalidArgument("Stride must be greater than zero."));
- if (stride == window_size) {
- LOG(WARNING) << "stride: " << stride
+ ctx, window_stride > 0,
+ errors::InvalidArgument("window_stride must be greater than zero."));
+ if (window_size == window_shift && window_stride == 1) {
+ LOG(WARNING) << "window_shift: " << window_shift
<< " is equal to window_size: " << window_size
- << ", to use `batch` instead.";
- } else if (stride > window_size) {
- LOG(WARNING) << "stride: " << stride
- << " is greater than window_size: " << window_size
- << ", you will lose some data.";
+ << " and window_stride is 1, use `batch` instead.";
}
-
- *output = new Dataset(ctx, window_size, stride, input);
+ *output = new Dataset(ctx, window_size, window_shift, window_stride, input);
}
private:
class Dataset : public GraphDatasetBase {
public:
- Dataset(OpKernelContext* ctx, int64 window_size, int64 stride,
- const DatasetBase* input)
+ Dataset(OpKernelContext* ctx, int64 window_size, int64 window_shift,
+ int64 window_stride, const DatasetBase* input)
: GraphDatasetBase(ctx),
window_size_(window_size),
- stride_(stride),
+ window_shift_(window_shift),
+ window_stride_(window_stride),
input_(input) {
input_->Ref();
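
Note: the old (window_size, stride) pair becomes three parameters, and element
i of output window k is drawn from input position
k * window_shift + i * window_stride. A minimal standalone sketch of that
indexing (illustrative only, not part of the patch):

    // Illustrative only: prints the input indices each output window samples.
    #include <cstdio>

    int main() {
      const long window_size = 3, window_shift = 2, window_stride = 2;
      for (long k = 0; k < 3; ++k) {  // first three windows
        std::printf("window %ld:", k);
        for (long i = 0; i < window_size; ++i) {
          std::printf(" %ld", k * window_shift + i * window_stride);
        }
        std::printf("\n");
      }
      return 0;
    }

With these values the sketch prints windows {0, 2, 4}, {2, 4, 6}, {4, 6, 8}:
consecutive windows overlap because the shift (2) is smaller than the window's
span in the input ((3 - 1) * 2 + 1 = 5).
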
@@ -92,8 +99,8 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
}
string DebugString() const override {
- return strings::StrCat("SlideDatasetOp(", window_size_, ", ", stride_,
- ")::Dataset");
+ return strings::StrCat("SlideDatasetOp(", window_size_, ", ",
+ window_shift_, ", ", window_stride_, ")::Dataset");
}
protected:
@@ -102,16 +109,18 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
Node* window_size = nullptr;
- Node* stride = nullptr;
+ Node* window_shift = nullptr;
+ Node* window_stride = nullptr;
TF_RETURN_IF_ERROR(b->AddScalar(window_size_, &window_size));
- TF_RETURN_IF_ERROR(b->AddScalar(stride_, &stride));
- TF_RETURN_IF_ERROR(
- b->AddDataset(this, {input_graph_node, window_size, stride}, output));
+ TF_RETURN_IF_ERROR(b->AddScalar(window_shift_, &window_shift));
+ TF_RETURN_IF_ERROR(b->AddScalar(window_stride_, &window_stride));
+ TF_RETURN_IF_ERROR(b->AddDataset(
+ this, {input_graph_node, window_size, window_shift, window_stride},
+ output));
return Status::OK();
}
private:
-
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
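
Note: AsGraphDefInternal now serializes four inputs instead of three. Assuming
the registered op name is "SlideDataset" (not shown in this diff), the
resulting graph node is shaped roughly like this sketch:

    node {
      op: "SlideDataset"
      input: "input_dataset"
      input: "window_size"    # scalar int64
      input: "window_shift"   # scalar int64, replaces "stride"
      input: "window_stride"  # scalar int64, new
    }
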
@@ -125,7 +134,8 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
const int64 window_size = dataset()->window_size_;
- const int64 stride = dataset()->stride_;
+ const int64 window_shift = dataset()->window_shift_;
+ const int64 window_stride = dataset()->window_stride_;
std::vector<std::vector<Tensor>> batch_elements;
{
mutex_lock l(mu_);
@@ -134,55 +144,51 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
batch_elements.reserve(window_size);
- // Use cache if stride < window_size.
- if (stride < window_size) {
- const bool first_call = cache_.empty();
- if (first_call) {
- cache_.reserve(window_size);
- } else {
- // Reuse cache in the previous iteration.
- cache_.swap(batch_elements);
- }
- }
- // Fill up with new elements.
+
+ // Fill up buffer.
+      size_t target_size = TargetBufferSize(window_size, window_stride);
*end_of_sequence = false;
- for (size_t i = batch_elements.size(); i < window_size && !*end_of_sequence;
- ++i) {
- std::vector<Tensor> batch_element_tuple;
- TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &batch_element_tuple,
- end_of_sequence));
+ for (size_t i = buffer_.size(); i < target_size && !*end_of_sequence;
+ ++i) {
+ std::vector<Tensor> element;
+ TF_RETURN_IF_ERROR(
+ input_impl_->GetNext(ctx, &element, end_of_sequence));
if (!*end_of_sequence) {
- batch_elements.push_back(std::move(batch_element_tuple));
+ buffer_.push_back(std::move(element));
} else {
input_impl_.reset();
}
}
- // Drop the final smaller blocks.
- if (batch_elements.size() < window_size) {
+
+ // Drop the final smaller batch.
+ if (buffer_.size() < target_size) {
DCHECK(*end_of_sequence);
return Status::OK();
}
- if (stride < window_size) {
- // Cache the data used for the next iteration.
- for (size_t i = stride; i < window_size; ++i) {
- cache_.emplace_back(batch_elements[i]);
- }
- } else if (stride > window_size) {
- // Drop the data before the next iteration.
- std::vector<Tensor> batch_element_tuple;
- for (size_t i = window_size; i < stride && !*end_of_sequence; ++i) {
- TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &batch_element_tuple,
- end_of_sequence));
- if (*end_of_sequence) {
+ for (size_t i = 0; i < window_size; ++i) {
+ batch_elements.emplace_back(buffer_[window_stride * i]);
+ }
+
+ // Drop the data before the next iteration.
+ if (window_shift >= buffer_.size()) {
+ for (size_t i = buffer_.size(); i < window_shift; ++i) {
+ bool end_of_input;
+ std::vector<Tensor> element;
+ TF_RETURN_IF_ERROR(
+ input_impl_->GetNext(ctx, &element, &end_of_input));
+ if (end_of_input) {
input_impl_.reset();
+ break;
}
}
+ buffer_.clear();
+ } else {
+ buffer_.erase(buffer_.begin(), buffer_.begin() + window_shift);
}
}
// Construct output tensors.
- // Those codes below are copied from batch_dataset_op.cc.
const size_t num_tuple_components = batch_elements[0].size();
const int64 num_batch_elements = batch_elements.size();
for (size_t component_index = 0; component_index < num_tuple_components;
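
Note: the rewritten GetNextInternal keeps a rolling buffer of
(window_size - 1) * window_stride + 1 elements, emits every window_stride-th
buffered element as one window, then advances the buffer by window_shift.
A standalone sketch of that loop, with ints standing in for tensor tuples
(illustrative only, and assuming the input stream never ends):

    #include <cstdio>
    #include <deque>

    int main() {
      const long window_size = 2, window_shift = 3, window_stride = 2;
      const long target = (window_size - 1) * window_stride + 1;  // == 3
      std::deque<long> buffer;
      long next = 0;  // simulated input stream: 0, 1, 2, ...
      for (int step = 0; step < 3; ++step) {
        // Fill up buffer.
        while (static_cast<long>(buffer.size()) < target) buffer.push_back(next++);
        // Emit every window_stride-th buffered element.
        std::printf("window:");
        for (long i = 0; i < window_size; ++i) {
          std::printf(" %ld", buffer[window_stride * i]);
        }
        std::printf("\n");
        // Drop the data consumed before the next iteration.
        if (window_shift >= static_cast<long>(buffer.size())) {
          for (long i = buffer.size(); i < window_shift; ++i) ++next;  // discard
          buffer.clear();
        } else {
          buffer.erase(buffer.begin(), buffer.begin() + window_shift);
        }
      }
      return 0;
    }

With these values the sketch prints {0, 2}, {3, 5}, {6, 8}: each window starts
window_shift = 3 positions after the previous one.
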
@@ -224,15 +230,15 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
} else {
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
}
- // Save cache.
- TF_RETURN_IF_ERROR(
- writer->WriteScalar(strings::StrCat("cache_size"), cache_.size()));
- for (int64 i = 0; i < cache_.size(); i++) {
+ // Save buffer.
+ TF_RETURN_IF_ERROR(writer->WriteScalar(strings::StrCat("buffer_size"),
+ buffer_.size()));
+ for (int64 i = 0; i < buffer_.size(); i++) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
- strings::StrCat("cache[", i, "]_size"), cache_[i].size()));
- for (int64 j = 0; j < cache_[i].size(); j++) {
+ strings::StrCat("buffer[", i, "]_size"), buffer_[i].size()));
+ for (int64 j = 0; j < buffer_[i].size(); j++) {
TF_RETURN_IF_ERROR(writer->WriteTensor(
- strings::StrCat("cache[", i, "][", j, "]"), cache_[i][j]));
+ strings::StrCat("buffer[", i, "][", j, "]"), buffer_[i][j]));
}
}
return Status::OK();
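
Note: for a checkpointed buffer holding two elements of one tensor each,
SaveInternal above writes keys shaped like the following (illustrative; the
iterator's key prefix is applied by the writer):

    buffer_size     = 2
    buffer[0]_size  = 1
    buffer[0][0]    = <Tensor>
    buffer[1]_size  = 1
    buffer[1][0]    = <Tensor>
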
@@ -246,32 +252,37 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
} else {
input_impl_.reset();
}
- // Restore cache.
- int64 cache_size;
+ // Restore buffer.
+ int64 buffer_size;
TF_RETURN_IF_ERROR(
- reader->ReadScalar(strings::StrCat("cache_size"), &cache_size));
- cache_.resize(cache_size);
- for (int64 i = 0; i < cache_size; i++) {
+ reader->ReadScalar(strings::StrCat("buffer_size"), &buffer_size));
+ buffer_.resize(buffer_size);
+ for (int64 i = 0; i < buffer_size; i++) {
int64 vector_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
- strings::StrCat("cache[", i, "]_size"), &vector_size));
- cache_[i].resize(vector_size);
+ strings::StrCat("buffer[", i, "]_size"), &vector_size));
+ buffer_[i].resize(vector_size);
for (int64 j = 0; j < vector_size; j++) {
TF_RETURN_IF_ERROR(reader->ReadTensor(
- strings::StrCat("cache[", i, "][", j, "]"), &cache_[i][j]));
+ strings::StrCat("buffer[", i, "][", j, "]"), &buffer_[i][j]));
}
}
return Status::OK();
}
private:
+ size_t TargetBufferSize(int64 window_size, int64 window_stride) {
+ return (window_size - 1) * window_stride + 1;
+ }
+
mutex mu_;
- std::vector<std::vector<Tensor>> cache_ GUARDED_BY(mu_);
+ std::deque<std::vector<Tensor>> buffer_ GUARDED_BY(mu_);
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
};
const int64 window_size_;
- const int64 stride_;
+ const int64 window_shift_;
+ const int64 window_stride_;
const DatasetBase* const input_;
std::vector<PartialTensorShape> output_shapes_;
};
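
Note: TargetBufferSize follows from the last index a window touches. The final
sampled element sits at buffer offset (window_size - 1) * window_stride, so the
buffer needs exactly one element more than that. For example, window_size = 4
with window_stride = 3 samples offsets 0, 3, 6, 9 and therefore needs
(4 - 1) * 3 + 1 = 10 buffered elements.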