author    Jiri Simsa <jsimsa@google.com>                      2018-09-17 16:31:24 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-09-17 16:35:28 -0700
commit    8ef1ece7d0ecdec633a22a8100fdae05cfbacb3e (patch)
tree      7123d7e44983f26da690ac511ceb09b77c067114
parent    f5116dd366a5bb1d679e1682c13b8fa3c4830a84 (diff)
[tf.data] Introducing `tf.data.Dataset.window(size, shift, stride, drop_remainder)`, which can be used for combining elements of the input dataset into "windows". A window is itself a finite dataset and, among other things, can be used for generalized batching (see https://github.com/tensorflow/community/pull/5 for details).

PiperOrigin-RevId: 213360134
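For orientation, a minimal sketch of the new API in use (TF 1.x vintage; the pipeline here is illustrative, not taken from this change):

    import tensorflow as tf

    # Each element produced by `window()` is itself a finite, nested
    # Dataset of up to `size` input elements; `shift` is how far the
    # window advances per iteration, `stride` subsamples within it.
    dataset = tf.data.Dataset.range(7).window(
        size=3, shift=2, stride=1, drop_remainder=True)
    # Generalized batching: flatten each window back into a dense batch.
    dataset = dataset.flat_map(lambda window: window.batch(3))
    # Produces [0, 1, 2], [2, 3, 4], [4, 5, 6].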
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py               7
-rw-r--r--  tensorflow/contrib/data/python/ops/grouping.py                                      51
-rw-r--r--  tensorflow/contrib/data/python/ops/sliding.py                                       4
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt                        23
-rw-r--r--  tensorflow/core/kernels/data/window_dataset_op.cc                                   215
-rw-r--r--  tensorflow/core/ops/compat/ops_history.v1.pbtxt                                     14
-rw-r--r--  tensorflow/core/ops/dataset_ops.cc                                                  10
-rw-r--r--  tensorflow/python/data/kernel_tests/BUILD                                           17
-rw-r--r--  tensorflow/python/data/kernel_tests/window_dataset_op_test.py                       295
-rw-r--r--  tensorflow/python/data/ops/dataset_ops.py                                           93
-rw-r--r--  tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt                       4
-rw-r--r--  tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt   4
-rw-r--r--  tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt            4
-rw-r--r--  tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt             4
-rw-r--r--  tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt                       4
-rw-r--r--  tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt   4
-rw-r--r--  tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt            4
-rw-r--r--  tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt             4
18 files changed, 679 insertions(+), 82 deletions(-)
diff --git a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
index 6eaa0b1959..8b7b3ac0f7 100644
--- a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
@@ -89,13 +89,14 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
return dataset_ops.Dataset.zip(
tuple([fn(*arg) if isinstance(arg, tuple) else arg for arg in args]))
- dataset = self._structuredDataset(structure, shape, dtype).apply(
+ dataset = self._structuredDataset(structure, shape, dtype).repeat(5).apply(
grouping.window_dataset(5)).flat_map(fn)
get_next = dataset.make_one_shot_iterator().get_next()
with self.cached_session() as sess:
expected = sess.run(self._structuredElement(structure, shape, dtype))
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
+ for _ in range(5):
+ actual = sess.run(get_next)
+ self._assertEqual(expected, actual)
@parameterized.named_parameters(
("1", None, np.int32([]), dtypes.bool),
diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py
index 099e10db92..020167e4d1 100644
--- a/tensorflow/contrib/data/python/ops/grouping.py
+++ b/tensorflow/contrib/data/python/ops/grouping.py
@@ -255,6 +255,7 @@ def _map_x_dataset(map_func):
return _apply_fn
+# TODO(b/115382007) Remove this once canned reducers move to core.
def window_dataset(window_size):
"""A transformation that creates window datasets from the input dataset.
@@ -271,7 +272,12 @@ def window_dataset(window_size):
"""
def _apply_fn(dataset):
- return _WindowDataset(dataset, window_size)
+ return dataset_ops.WindowDataset(
+ dataset,
+ size=window_size,
+ shift=window_size,
+ stride=1,
+ drop_remainder=False)
return _apply_fn
@@ -556,46 +562,3 @@ class _MapXDataset(dataset_ops.Dataset):
@property
def output_types(self):
return self._output_types
-
-
-class _WindowDataset(dataset_ops.Dataset):
- """A dataset that creates window datasets from the input elements."""
-
- def __init__(self, input_dataset, window_size):
- """See `window_dataset()` for more details."""
- super(_WindowDataset, self).__init__()
- self._input_dataset = input_dataset
- self._window_size = ops.convert_to_tensor(
- window_size, dtype=dtypes.int64, name="window_size")
- self._output_classes = nest.pack_sequence_as(
- input_dataset.output_classes,
- [
- dataset_ops._NestedDatasetComponent( # pylint: disable=protected-access
- output_classes=output_class,
- output_shapes=output_shape,
- output_types=output_type)
- for output_class, output_shape, output_type in zip(
- nest.flatten(input_dataset.output_classes),
- nest.flatten(input_dataset.output_shapes),
- nest.flatten(input_dataset.output_types))
- ])
- self._output_shapes = self._output_classes
- self._output_types = self._output_classes
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.window_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- self._window_size,
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return self._output_classes
-
- @property
- def output_shapes(self):
- return self._output_shapes
-
- @property
- def output_types(self):
- return self._output_types
diff --git a/tensorflow/contrib/data/python/ops/sliding.py b/tensorflow/contrib/data/python/ops/sliding.py
index 8025dcdd16..b0d6a16c20 100644
--- a/tensorflow/contrib/data/python/ops/sliding.py
+++ b/tensorflow/contrib/data/python/ops/sliding.py
@@ -67,6 +67,10 @@ class _SlideDataset(dataset_ops.Dataset):
@deprecation.deprecated_args(
None, "stride is deprecated, use window_shift instead", "stride")
+@deprecation.deprecated(
+    None, "Use `tf.data.Dataset.window(size=window_size, shift=window_shift, "
+    "stride=window_stride).flat_map(lambda x: x.batch(window_size))` "
+    "instead.")
def sliding_window_batch(window_size,
stride=None,
window_shift=None,
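To make the deprecation message above concrete, a hedged sketch of the migration path (the surrounding dataset and parameter values are illustrative):

    # Before (deprecated contrib transformation):
    dataset = dataset.apply(
        sliding.sliding_window_batch(window_size=4, window_shift=1))

    # After, using the core API introduced in this change:
    dataset = dataset.window(size=4, shift=1, stride=1).flat_map(
        lambda window: window.batch(4))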
diff --git a/tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt
index 1bc3660479..01387b7527 100644
--- a/tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt
@@ -2,10 +2,31 @@ op {
visibility: HIDDEN
graph_op_name: "WindowDataset"
in_arg {
- name: "window_size"
+ name: "size"
description: <<END
A scalar representing the number of elements to accumulate in a window.
END
}
+ in_arg {
+ name: "shift"
+ description: <<END
+A scalar representing the number of input elements by which the sliding window
+moves forward in each iteration. It must be positive.
+END
+ }
+ in_arg {
+ name: "stride"
+ description: <<END
+A scalar representing the stride of the input elements of the sliding window.
+It must be positive.
+END
+ }
+ in_arg {
+ name: "drop_remainder"
+ description: <<END
+A scalar representing whether a window should be dropped if its size is
+smaller than `size`.
+END
+ }
summary: "A dataset that creates window datasets from the input dataset."
}
diff --git a/tensorflow/core/kernels/data/window_dataset_op.cc b/tensorflow/core/kernels/data/window_dataset_op.cc
index 3975086841..ac44623ce2 100644
--- a/tensorflow/core/kernels/data/window_dataset_op.cc
+++ b/tensorflow/core/kernels/data/window_dataset_op.cc
@@ -33,22 +33,44 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
int64 window_size = 0;
- OP_REQUIRES_OK(
- ctx, ParseScalarArgument<int64>(ctx, "window_size", &window_size));
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "size", &window_size));
OP_REQUIRES(
ctx, window_size > 0,
errors::InvalidArgument("Window size must be greater than zero."));
- *output = new Dataset(ctx, window_size, input);
+ int64 window_shift = 0;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<int64>(ctx, "shift", &window_shift));
+ OP_REQUIRES(
+ ctx, window_shift > 0,
+ errors::InvalidArgument("Window shift must be greater than zero."));
+
+ int64 window_stride = 0;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<int64>(ctx, "stride", &window_stride));
+ OP_REQUIRES(
+ ctx, window_stride > 0,
+ errors::InvalidArgument("Window stride must be greater than zero."));
+
+ bool drop_remainder;
+ OP_REQUIRES_OK(
+ ctx, ParseScalarArgument<bool>(ctx, "drop_remainder", &drop_remainder));
+
+ *output = new Dataset(ctx, input, window_size, window_shift, window_stride,
+ drop_remainder);
}
private:
class Dataset : public DatasetBase {
public:
- Dataset(OpKernelContext* ctx, int64 window_size, const DatasetBase* input)
+ Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 window_size,
+ int64 window_shift, int64 window_stride, bool drop_remainder)
: DatasetBase(DatasetContext(ctx)),
+ input_(input),
window_size_(window_size),
- input_(input) {
+ window_shift_(window_shift),
+ window_stride_(window_stride),
+ drop_remainder_(drop_remainder) {
input_->Ref();
}
@@ -72,7 +94,8 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
}
string DebugString() const override {
- return strings::StrCat("WindowDatasetOp(", window_size_, ")::Dataset");
+    return strings::StrCat("WindowDatasetOp(", window_size_, ", ",
+                           window_shift_, ", ", window_stride_, ", ",
+                           drop_remainder_, ")::Dataset");
}
protected:
@@ -81,10 +104,19 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
Node** output) const override {
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
- Node* window_size = nullptr;
- TF_RETURN_IF_ERROR(b->AddScalar(window_size_, &window_size));
+ Node* window_size_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(window_size_, &window_size_node));
+ Node* window_shift_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(window_shift_, &window_shift_node));
+ Node* window_stride_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(window_stride_, &window_stride_node));
+ Node* drop_remainder_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder_node));
TF_RETURN_IF_ERROR(
- b->AddDataset(this, {input_graph_node, window_size}, output));
+ b->AddDataset(this,
+ {input_graph_node, window_size_node, window_shift_node,
+ window_stride_node, drop_remainder_node},
+ output));
return Status::OK();
}
@@ -101,37 +133,79 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
- // Each row of `window_elements` is a tuple of tensors from the
- // input iterator.
+ const int64 window_size = dataset()->window_size_;
+ const int64 window_shift = dataset()->window_shift_;
+ const int64 window_stride = dataset()->window_stride_;
std::vector<std::vector<Tensor>> window_elements;
+ Status status = Status::OK();
{
mutex_lock l(mu_);
- if (!input_impl_) {
+ if (!input_impl_ && buffer_.empty()) {
*end_of_sequence = true;
return Status::OK();
}
- window_elements.reserve(dataset()->window_size_);
- *end_of_sequence = false;
- for (int i = 0; i < dataset()->window_size_ && !*end_of_sequence;
- ++i) {
- std::vector<Tensor> window_element_tuple;
- TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &window_element_tuple,
- end_of_sequence));
- if (!*end_of_sequence) {
- window_elements.emplace_back(std::move(window_element_tuple));
- } else {
- input_impl_.reset();
+
+ // Add elements to the buffer.
+ size_t target_size = TargetBufferSize(window_size, window_stride);
+ if (input_impl_) {
+ *end_of_sequence = false;
+ for (size_t i = buffer_.size();
+ i < target_size && !*end_of_sequence; ++i) {
+ std::vector<Tensor> element;
+ Status status =
+ input_impl_->GetNext(ctx, &element, end_of_sequence);
+ if (!*end_of_sequence) {
+ buffer_.emplace_back(std::move(element), status);
+ } else {
+ input_impl_.reset();
+ }
}
}
+
+ // If there are not enough elements and `drop_remainder` is set, we do
+ // not wish to return a smaller window.
+ if (buffer_.empty() ||
+ (dataset()->drop_remainder_ && buffer_.size() < target_size)) {
+ DCHECK(*end_of_sequence);
+ return Status::OK();
+ }
+
+ int num_elements = 1 + (buffer_.size() - 1) / window_stride;
+ window_elements.reserve(num_elements);
+ for (size_t i = 0; i < num_elements; ++i) {
+ status.Update(buffer_[window_stride * i].status);
+ if (!status.ok()) {
+ break;
+ }
+ window_elements.emplace_back(buffer_[window_stride * i].result);
+ }
+
+ // Shift the window, discarding elements if necessary.
+ int buffer_size = buffer_.size();
+ if (window_shift >= buffer_size) {
+ for (size_t i = buffer_size; input_impl_ && i < window_shift; ++i) {
+ bool end_of_input;
+ std::vector<Tensor> element;
+ // Ignore non-error status of discarded elements.
+ input_impl_->GetNext(ctx, &element, &end_of_input).IgnoreError();
+ if (end_of_input) {
+ input_impl_.reset();
+ }
+ }
+ buffer_.clear();
+ } else {
+ buffer_.erase(buffer_.begin(), buffer_.begin() + window_shift);
+ }
}
- if (window_elements.empty()) {
- DCHECK(*end_of_sequence);
- return Status::OK();
+ if (!status.ok()) {
+ return status;
}
+ // Construct output tensors.
const size_t num_tuple_components = window_elements[0].size();
const int64 num_window_elements = window_elements.size();
+ *end_of_sequence = false;
for (size_t idx = 0; idx < num_tuple_components; ++idx) {
DatasetBase* window_dataset;
std::vector<std::vector<Tensor>> window_component_elements;
@@ -154,7 +228,6 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
TF_RETURN_IF_ERROR(StoreDatasetInVariantTensor(window_dataset,
&out_tensors->back()));
}
- *end_of_sequence = false;
return Status::OK();
}
@@ -167,6 +240,20 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
} else {
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
}
+ // Save buffer.
+ TF_RETURN_IF_ERROR(writer->WriteScalar(strings::StrCat("buffer_size"),
+ buffer_.size()));
+ for (int64 i = 0; i < buffer_.size(); i++) {
+ TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, buffer_[i].status));
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(strings::StrCat("buffer[", i, "].size"),
+ buffer_[i].result.size()));
+ for (int64 j = 0; j < buffer_[i].result.size(); j++) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteTensor(strings::StrCat("buffer[", i, "][", j, "]"),
+ buffer_[i].result[j]));
+ }
+ }
return Status::OK();
}
@@ -178,22 +265,92 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
} else {
input_impl_.reset();
}
+ // Restore buffer.
+ int64 buffer_size;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(strings::StrCat("buffer_size"), &buffer_size));
+ buffer_.resize(buffer_size);
+ for (int64 i = 0; i < buffer_size; i++) {
+ int64 vector_size;
+ TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &buffer_[i].status));
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ strings::StrCat("buffer[", i, "].size"), &vector_size));
+ buffer_[i].result.resize(vector_size);
+ for (int64 j = 0; j < vector_size; j++) {
+ TF_RETURN_IF_ERROR(
+ reader->ReadTensor(strings::StrCat("buffer[", i, "][", j, "]"),
+ &buffer_[i].result[j]));
+ }
+ }
return Status::OK();
}
private:
+ struct InvocationResult {
+ InvocationResult() = default;
+ InvocationResult(std::vector<Tensor>&& result, const Status& status)
+ : result(result), status(status) {}
+
+ std::vector<Tensor> result;
+ Status status;
+ };
+
+ Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
+ const Status& status)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ CodeKey(index), static_cast<int64>(status.code())));
+ if (!status.ok()) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index),
+ status.error_message()));
+ }
+ return Status::OK();
+ }
+
+ Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
+ Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ int64 code_int;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
+ error::Code code = static_cast<error::Code>(code_int);
+
+ if (code != error::Code::OK) {
+ string error_message;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(ErrorMessageKey(index), &error_message));
+ *status = Status(code, error_message);
+ } else {
+ *status = Status::OK();
+ }
+ return Status::OK();
+ }
+
+ string CodeKey(size_t index) {
+ return full_name(strings::StrCat("buffer[", index, "].code"));
+ }
+
+ string ErrorMessageKey(size_t index) {
+ return full_name(strings::StrCat("buffer[", index, "].error_message"));
+ }
+
+ size_t TargetBufferSize(int64 window_size, int64 window_stride) {
+ return (window_size - 1) * window_stride + 1;
+ }
+
mutex mu_;
+ std::deque<InvocationResult> buffer_ GUARDED_BY(mu_);
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
};
- const int64 window_size_;
const DatasetBase* const input_;
+ const int64 window_size_;
+ const int64 window_shift_;
+ const int64 window_stride_;
+ const bool drop_remainder_;
};
};
REGISTER_KERNEL_BUILDER(Name("WindowDataset").Device(DEVICE_CPU),
WindowDatasetOp);
-
} // namespace
} // namespace data
} // namespace tensorflow
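The iterator logic in GetNextInternal above reduces to a small algorithm; here is an illustrative Python sketch of it (not part of the change, and it elides the per-element error statuses that the kernel buffers alongside each result):

    def window_iterator(input_iter, size, shift, stride, drop_remainder):
        # Mirrors the kernel: keep (size - 1) * stride + 1 input elements
        # buffered, emit every stride-th one as a window, then advance the
        # buffer by shift elements.
        buffer, exhausted = [], False
        target = (size - 1) * stride + 1
        while True:
            while not exhausted and len(buffer) < target:
                try:
                    buffer.append(next(input_iter))
                except StopIteration:
                    exhausted = True
            if not buffer or (drop_remainder and len(buffer) < target):
                return
            yield buffer[::stride]
            if shift >= len(buffer):
                # Discard input elements skipped over by a large shift.
                for _ in range(shift - len(buffer)):
                    if exhausted:
                        break
                    try:
                        next(input_iter)
                    except StopIteration:
                        exhausted = True
                buffer = []
            else:
                buffer = buffer[shift:]

    # list(window_iterator(iter(range(7)), 3, 2, 1, True))
    # -> [[0, 1, 2], [2, 3, 4], [4, 5, 6]]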
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 57c6bda98b..e59958749c 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -75602,9 +75602,21 @@ op {
type: DT_VARIANT
}
input_arg {
- name: "window_size"
+ name: "size"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "shift"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "stride"
type: DT_INT64
}
+ input_arg {
+ name: "drop_remainder"
+ type: DT_BOOL
+ }
output_arg {
name: "handle"
type: DT_VARIANT
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 7d9e7b2d3f..4d3f272c1b 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -396,14 +396,20 @@ REGISTER_OP("FilterByLastComponentDataset")
REGISTER_OP("WindowDataset")
.Input("input_dataset: variant")
- .Input("window_size: int64")
+ .Input("size: int64")
+ .Input("shift: int64")
+ .Input("stride: int64")
+ .Input("drop_remainder: bool")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
- // batch_size should be a scalar.
+ // size, shift, stride, and drop_remainder should be scalars.
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
return shape_inference::ScalarShape(c);
});
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index 631b87a718..17d4fec662 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -407,3 +407,20 @@ cuda_py_test(
"//tensorflow/python:tensor_shape",
],
)
+
+tf_py_test(
+ name = "window_dataset_op_test",
+ size = "small",
+ srcs = ["window_dataset_op_test.py"],
+ additional_deps = [
+ "@absl_py//absl/testing:parameterized",
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
diff --git a/tensorflow/python/data/kernel_tests/window_dataset_op_test.py b/tensorflow/python/data/kernel_tests/window_dataset_op_test.py
new file mode 100644
index 0000000000..fd4348426d
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/window_dataset_op_test.py
@@ -0,0 +1,295 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class WindowDatasetTest(test.TestCase, parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ ("1", 20, 14, 7, 1),
+ ("2", 20, 17, 9, 1),
+ ("3", 20, 14, 14, 1),
+ ("4", 20, 10, 14, 1),
+ ("5", 20, 14, 19, 1),
+ ("6", 20, 4, 1, 2),
+ ("7", 20, 2, 1, 6),
+ ("8", 20, 4, 7, 2),
+ ("9", 20, 2, 7, 6),
+ ("10", 1, 10, 4, 1),
+ ("11", 0, 10, 4, 1),
+ ("12", 20, 14, 7, 1, False),
+ ("13", 20, 17, 9, 1, False),
+ ("14", 20, 14, 14, 1, False),
+ ("15", 20, 10, 14, 1, False),
+ ("16", 20, 14, 19, 1, False),
+ ("17", 20, 4, 1, 2, False),
+ ("18", 20, 2, 1, 6, False),
+ ("19", 20, 4, 7, 2, False),
+ ("20", 20, 2, 7, 6, False),
+ ("21", 1, 10, 4, 1, False),
+ ("22", 0, 10, 4, 1, False),
+ )
+ def testWindowDataset(self, count, size, shift, stride, drop_remainder=True):
+ """Tests a dataset that slides a window its input elements."""
+ components = (np.arange(7),
+ np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
+ np.array(37.0) * np.arange(7))
+
+ count_t = array_ops.placeholder(dtypes.int64, shape=[])
+ size_t = array_ops.placeholder(dtypes.int64, shape=[])
+ shift_t = array_ops.placeholder(dtypes.int64, shape=[])
+ stride_t = array_ops.placeholder(dtypes.int64, shape=[])
+ drop_remainder_t = array_ops.placeholder(dtypes.bool, shape=[])
+
+ def _map_fn(x, y, z):
+ return math_ops.square(x), math_ops.square(y), math_ops.square(z)
+
+ def _flat_map_fn(x, y, z):
+ return dataset_ops.Dataset.zip((x.batch(batch_size=size_t),
+ y.batch(batch_size=size_t),
+ z.batch(batch_size=size_t)))
+
+ iterator = dataset_ops.Dataset.from_tensor_slices(components).map(
+ _map_fn).repeat(count).window(
+ size=size_t,
+ shift=shift_t,
+ stride=stride_t,
+ drop_remainder=drop_remainder_t).flat_map(
+ _flat_map_fn).make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ self.assertEqual([[None] + list(c.shape[1:]) for c in components],
+ [t.shape.as_list() for t in get_next])
+
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ count_t: count,
+ size_t: size,
+ shift_t: shift,
+ stride_t: stride,
+ drop_remainder_t: drop_remainder
+ })
+ num_full_batches = max(
+ 0, (count * 7 - ((size - 1) * stride + 1)) // shift + 1)
+ for i in range(num_full_batches):
+ result = sess.run(get_next)
+ for component, result_component in zip(components, result):
+ for j in range(size):
+ self.assertAllEqual(component[(i * shift + j * stride) % 7]**2,
+ result_component[j])
+ if not drop_remainder:
+ num_partial_batches = (count * 7) // shift + (
+ (count * 7) % shift > 0) - num_full_batches
+ for i in range(num_partial_batches):
+ result = sess.run(get_next)
+ for component, result_component in zip(components, result):
+ remaining = (count * 7) - ((num_full_batches + i) * shift)
+ num_elements = remaining // stride + ((remaining % stride) > 0)
+ for j in range(num_elements):
+ self.assertAllEqual(
+ component[((num_full_batches + i) * shift + j * stride) % 7]
+ **2, result_component[j])
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ @parameterized.named_parameters(
+ ("1", 14, 0, 3, 1),
+ ("2", 14, 3, 0, 1),
+ ("3", 14, 3, 3, 0),
+ )
+ def testWindowDatasetInvalid(self, count, size, shift, stride):
+ count_t = array_ops.placeholder(dtypes.int64, shape=[])
+ size_t = array_ops.placeholder(dtypes.int64, shape=[])
+ shift_t = array_ops.placeholder(dtypes.int64, shape=[])
+ stride_t = array_ops.placeholder(dtypes.int64, shape=[])
+
+ iterator = dataset_ops.Dataset.range(10).map(lambda x: x).repeat(
+ count_t).window(
+ size=size_t, shift=shift_t,
+ stride=stride_t).flat_map(lambda x: x.batch(batch_size=size_t)
+ ).make_initializable_iterator()
+ init_op = iterator.initializer
+
+ with self.cached_session() as sess:
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(
+ init_op,
+ feed_dict={
+ count_t: count,
+ size_t: size,
+ shift_t: shift,
+ stride_t: stride
+ })
+
+ def assertSparseValuesEqual(self, a, b):
+ self.assertAllEqual(a.indices, b.indices)
+ self.assertAllEqual(a.values, b.values)
+ self.assertAllEqual(a.dense_shape, b.dense_shape)
+
+ def testWindowSparse(self):
+
+ def _sparse(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0]], values=(i * [1]), dense_shape=[1])
+
+ iterator = dataset_ops.Dataset.range(10).map(_sparse).window(
+ size=5, shift=3, drop_remainder=True).flat_map(
+ lambda x: x.batch(batch_size=5)).make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ num_batches = (10 - 5) // 3 + 1
+ for i in range(num_batches):
+ actual = sess.run(get_next)
+ expected = sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
+ values=[i * 3, i * 3 + 1, i * 3 + 2, i * 3 + 3, i * 3 + 4],
+ dense_shape=[5, 1])
+ self.assertTrue(sparse_tensor.is_sparse(actual))
+ self.assertSparseValuesEqual(actual, expected)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testWindowSparseWithDifferentDenseShapes(self):
+
+ def _sparse(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=array_ops.expand_dims(
+ math_ops.range(i, dtype=dtypes.int64), 1),
+ values=array_ops.fill([math_ops.to_int32(i)], i),
+ dense_shape=[i])
+
+ iterator = dataset_ops.Dataset.range(10).map(_sparse).window(
+ size=5, shift=3, drop_remainder=True).flat_map(
+ lambda x: x.batch(batch_size=5)).make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ num_batches = (10 - 5) // 3 + 1
+ for i in range(num_batches):
+ actual = sess.run(get_next)
+ expected_indices = []
+ expected_values = []
+ for j in range(5):
+ for k in range(i * 3 + j):
+ expected_indices.append([j, k])
+ expected_values.append(i * 3 + j)
+ expected = sparse_tensor.SparseTensorValue(
+ indices=expected_indices,
+ values=expected_values,
+ dense_shape=[5, i * 3 + 5 - 1])
+ self.assertTrue(sparse_tensor.is_sparse(actual))
+ self.assertSparseValuesEqual(actual, expected)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testNestedWindowSparse(self):
+
+ def _sparse(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0]], values=(i * [1]), dense_shape=[1])
+
+ iterator = dataset_ops.Dataset.range(10).map(_sparse).window(
+ size=4, shift=2,
+ drop_remainder=True).flat_map(lambda x: x.batch(batch_size=4)).window(
+ size=3, shift=1, drop_remainder=True).flat_map(
+ lambda x: x.batch(batch_size=3)).make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ # Slide: 1st batch.
+ actual = sess.run(get_next)
+ expected = sparse_tensor.SparseTensorValue(
+ indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [1, 0, 0],
+ [1, 1, 0], [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 1, 0],
+ [2, 2, 0], [2, 3, 0]],
+ values=[0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7],
+ dense_shape=[3, 4, 1])
+ self.assertTrue(sparse_tensor.is_sparse(actual))
+ self.assertSparseValuesEqual(actual, expected)
+ # Slide: 2nd batch.
+ actual = sess.run(get_next)
+ expected = sparse_tensor.SparseTensorValue(
+ indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [1, 0, 0],
+ [1, 1, 0], [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 1, 0],
+ [2, 2, 0], [2, 3, 0]],
+ values=[2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9],
+ dense_shape=[3, 4, 1])
+ self.assertTrue(sparse_tensor.is_sparse(actual))
+ self.assertSparseValuesEqual(actual, expected)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testWindowShapeError(self):
+
+ def generator():
+ yield [1.0, 2.0, 3.0]
+ yield [4.0, 5.0, 6.0]
+ yield [7.0, 8.0, 9.0, 10.0]
+
+ iterator = dataset_ops.Dataset.from_generator(
+ generator, dtypes.float32, output_shapes=[None]).window(
+ size=3, shift=1).flat_map(
+ lambda x: x.batch(batch_size=3)).make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ r"Cannot batch tensors with different shapes in component 0. "
+ r"First element had shape \[3\] and element 2 had shape \[4\]."):
+ sess.run(next_element)
+
+ def testWindowIgnoreErrors(self):
+ input_values = np.float32([1., np.nan, 2., np.nan, 3.])
+ dataset = dataset_ops.Dataset.from_tensor_slices(input_values).map(
+ lambda x: array_ops.check_numerics(x, "message")).window(
+ size=2, shift=2, stride=2,
+ drop_remainder=True).flat_map(lambda x: x.batch(batch_size=2))
+ get_next = dataset.make_one_shot_iterator().get_next()
+
+ with self.cached_session() as sess:
+ self.assertAllEqual(np.float32([1., 2.]), sess.run(get_next))
+ self.assertAllEqual(np.float32([2., 3.]), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
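As a sanity check on the expected-count arithmetic in testWindowDataset above: for named case "1" (count=20, size=14, shift=7, stride=1), a full window consumes (14 - 1) * 1 + 1 = 14 of the 20 * 7 = 140 input elements, so num_full_batches = max(0, (140 - 14) // 7 + 1) = 19.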
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index c985e00dd1..93b3a7b93b 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -1115,7 +1115,7 @@ class Dataset(object):
return FilterDataset(self, predicate)
def apply(self, transformation_func):
- """Apply a transformation function to this dataset.
+ """Applies a transformation function to this dataset.
`apply` enables chaining of custom `Dataset` transformations, which are
represented as functions that take one `Dataset` argument and return a
@@ -1131,7 +1131,7 @@ class Dataset(object):
Args:
transformation_func: A function that takes one `Dataset` argument and
- returns a `Dataset`.
+ returns a `Dataset`.
Returns:
Dataset: The `Dataset` returned by applying `transformation_func` to this
@@ -1142,6 +1142,45 @@ class Dataset(object):
raise TypeError("`transformation_func` must return a Dataset.")
return dataset
+ def window(self, size, shift=None, stride=1, drop_remainder=False):
+ """Combines input elements into a dataset of windows.
+
+ Each window is a dataset itself and contains `size` elements (or
+ possibly fewer if there are not enough input elements to fill the window
+ and `drop_remainder` evaluates to false).
+
+    The `stride` argument determines the stride between input elements within
+    a window, and the `shift` argument determines how many input elements the
+    window moves forward in each iteration.
+
+ For example:
+ - `tf.data.Dataset.range(7).window(2)` produces
+ `{{0, 1}, {2, 3}, {4, 5}, {6}}`
+ - `tf.data.Dataset.range(7).window(3, 2, 1, True)` produces
+ `{{0, 1, 2}, {2, 3, 4}, {4, 5, 6}}`
+ - `tf.data.Dataset.range(7).window(3, 1, 2, True)` produces
+ `{{0, 2, 4}, {1, 3, 5}, {2, 4, 6}}`
+
+ Args:
+ size: A `tf.int64` scalar `tf.Tensor`, representing the number of elements
+ of the input dataset to combine into a window.
+ shift: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
+ forward shift of the sliding window in each iteration. Defaults to
+ `size`.
+ stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
+ stride of the input elements in the sliding window.
+      drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
+        whether a window should be dropped in case its size is smaller than
+        `size`.
+
+ Returns:
+ Dataset: A `Dataset` of windows, each of which is a nested `Dataset` with
+ the same structure as this dataset, but a finite subsequence of its
+ elements.
+ """
+ if shift is None:
+ shift = size
+ return WindowDataset(self, size, shift, stride, drop_remainder)
+
class TensorDataset(Dataset):
"""A `Dataset` with a single element, viz. a nested structure of tensors."""
@@ -2442,3 +2481,53 @@ class PrefetchDataset(Dataset):
@property
def output_types(self):
return self._input_dataset.output_types
+
+
+class WindowDataset(Dataset):
+ """A dataset that creates window datasets from the input elements."""
+
+ def __init__(self, input_dataset, size, shift, stride, drop_remainder):
+ """See `window_dataset()` for more details."""
+ super(WindowDataset, self).__init__()
+ self._input_dataset = input_dataset
+ self._size = ops.convert_to_tensor(size, dtype=dtypes.int64, name="size")
+ self._shift = ops.convert_to_tensor(shift, dtype=dtypes.int64, name="shift")
+ self._stride = ops.convert_to_tensor(
+ stride, dtype=dtypes.int64, name="stride")
+ self._drop_remainder = ops.convert_to_tensor(
+ drop_remainder, dtype=dtypes.bool, name="drop_remainder")
+ self._output_classes = nest.pack_sequence_as(
+ input_dataset.output_classes,
+ [
+ _NestedDatasetComponent( # pylint: disable=protected-access
+ output_classes=output_class,
+ output_shapes=output_shape,
+ output_types=output_type)
+ for output_class, output_shape, output_type in zip(
+ nest.flatten(input_dataset.output_classes),
+ nest.flatten(input_dataset.output_shapes),
+ nest.flatten(input_dataset.output_types))
+ ])
+ self._output_shapes = self._output_classes
+ self._output_types = self._output_classes
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.window_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._size,
+ self._shift,
+ self._stride,
+ self._drop_remainder,
+ **flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ return self._output_types
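Since each window is a nested Dataset, reductions other than a plain batch can be applied per window, which is the "generalized batching" the commit message refers to. A hedged sketch (element shapes and parameters here are illustrative):

    import tensorflow as tf

    # Variable-length elements: element i has shape [i % 3 + 1].
    ds = tf.data.Dataset.range(10).map(
        lambda i: tf.fill([tf.cast(i % 3 + 1, tf.int32)], i))
    # Pad each window of 4 elements into a dense [4, None] batch,
    # which a plain batch of unequal shapes would reject.
    ds = ds.window(size=4, shift=4, stride=1, drop_remainder=True).flat_map(
        lambda window: window.padded_batch(4, padded_shapes=[None]))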
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt
index 87745420ee..c3ba2dba57 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt
@@ -111,6 +111,10 @@ tf_class {
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "window"
+ argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+ }
+ member_method {
name: "zip"
argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt
index 6dd46365b0..3541671bee 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt
@@ -112,6 +112,10 @@ tf_class {
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "window"
+ argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+ }
+ member_method {
name: "zip"
argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt
index 35b7105eba..b113c18ee0 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt
@@ -112,6 +112,10 @@ tf_class {
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "window"
+ argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+ }
+ member_method {
name: "zip"
argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt
index 8ae370af98..7210bf5db4 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt
@@ -112,6 +112,10 @@ tf_class {
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "window"
+ argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+ }
+ member_method {
name: "zip"
argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
index 87745420ee..c3ba2dba57 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
@@ -111,6 +111,10 @@ tf_class {
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "window"
+ argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+ }
+ member_method {
name: "zip"
argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
index 6dd46365b0..3541671bee 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
@@ -112,6 +112,10 @@ tf_class {
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "window"
+ argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+ }
+ member_method {
name: "zip"
argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
index 35b7105eba..b113c18ee0 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
@@ -112,6 +112,10 @@ tf_class {
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "window"
+ argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+ }
+ member_method {
name: "zip"
argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
index 8ae370af98..7210bf5db4 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
@@ -112,6 +112,10 @@ tf_class {
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "window"
+ argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+ }
+ member_method {
name: "zip"
argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
}