author    Jiri Simsa <jsimsa@google.com>  2018-09-06 16:09:15 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-09-06 16:17:43 -0700
commit    9a6ab2af59f3b21ffa2b74093ccc9af4edaf7f98 (patch)
tree      748882485661e750cf19f1d9ca182a590bbb7c8b /tensorflow
parent    33d2a0e7064cd14540121e38457d4a81aa57a650 (diff)
[tf.data] Adding support for `num_parallel_calls` to `tf.data.Dataset.interleave`.
Unlike `tf.contrib.data.parallel_interleave`, whose parallelism is tied to the `cycle_length` argument, the newly introduced `num_parallel_calls` argument of `tf.data.Dataset.interleave` is decoupled from `cycle_length` and identifies the degree of parallelism to use for fetching output elements.

PiperOrigin-RevId: 211886816
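
As a minimal usage sketch (TF 1.x graph API, mirroring the tests in this change; the element values are illustrative, not from the patch):

    import tensorflow as tf

    # Cycle over 2 input elements at a time, taking 1 output from each
    # before moving on, while fetching outputs from up to 2 cycle elements
    # in parallel. The output order is identical to the sequential case
    # (num_parallel_calls=None); only the fetching is parallelized.
    dataset = tf.data.Dataset.from_tensor_slices([4, 5, 6]).interleave(
        lambda x: tf.data.Dataset.from_tensors(x).repeat(x),
        cycle_length=2, block_length=1, num_parallel_calls=2)
    get_next = dataset.make_one_shot_iterator().get_next()

    with tf.Session() as sess:
      print(sess.run(get_next))  # 4
      print(sess.run(get_next))  # 5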
Diffstat (limited to 'tensorflow')
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/BUILD | 1
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py | 45
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDatasetV2.pbtxt | 13
-rw-r--r-- tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc | 594
-rw-r--r-- tensorflow/core/ops/dataset_ops.cc | 13
-rw-r--r-- tensorflow/python/data/kernel_tests/BUILD | 2
-rw-r--r-- tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py | 167
-rw-r--r-- tensorflow/python/data/ops/dataset_ops.py | 46
-rw-r--r-- tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt | 2
-rw-r--r-- tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt | 2
-rw-r--r-- tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt | 2
-rw-r--r-- tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt | 2
-rw-r--r-- tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt | 2
-rw-r--r-- tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt | 2
-rw-r--r-- tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt | 2
-rw-r--r-- tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt | 2
16 files changed, 771 insertions, 126 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
index 4881f63ab9..aa89674c6e 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
@@ -210,6 +210,7 @@ py_test(
"//tensorflow/python:sparse_tensor",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
],
)
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py
index ac3892fe81..243f6405a1 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py
@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
@@ -27,42 +28,38 @@ from tensorflow.python.platform import test
class InterleaveDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
+ dataset_serialization_test_base.DatasetSerializationTestBase,
+ parameterized.TestCase):
- def _build_iterator_graph(self, input_values, cycle_length, block_length):
+ def _build_iterator_graph(self, input_values, cycle_length, block_length,
+ num_parallel_calls):
repeat_count = 2
return dataset_ops.Dataset.from_tensor_slices(input_values).repeat(
repeat_count).interleave(
lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
- cycle_length, block_length)
+ cycle_length, block_length, num_parallel_calls)
- def testSerializationCore(self):
+ @parameterized.named_parameters(
+ ("1", 2, 3, None),
+ ("2", 2, 3, 1),
+ ("3", 2, 3, 2),
+ ("4", 1, 3, None),
+ ("5", 1, 3, 1),
+ ("6", 2, 1, None),
+ ("7", 2, 1, 1),
+ ("8", 2, 1, 2),
+ )
+ def testSerializationCore(self, cycle_length, block_length,
+ num_parallel_calls):
input_values = np.array([4, 5, 6], dtype=np.int64)
num_outputs = np.sum(input_values) * 2
- # cycle_length > 1, block_length > 1
- cycle_length = 2
- block_length = 3
# pylint: disable=g-long-lambda
self.run_core_tests(
lambda: self._build_iterator_graph(
- input_values, cycle_length, block_length),
+ input_values, cycle_length, block_length, num_parallel_calls),
lambda: self._build_iterator_graph(
- input_values, cycle_length * 2, block_length * 1),
+ input_values, cycle_length * 2, block_length, num_parallel_calls),
num_outputs)
- # cycle_length = 1
- cycle_length = 1
- block_length = 3
- self.run_core_tests(
- lambda: self._build_iterator_graph(
- input_values, cycle_length, block_length),
- None, num_outputs)
- # block_length = 1
- cycle_length = 2
- block_length = 1
- self.run_core_tests(
- lambda: self._build_iterator_graph(
- input_values, cycle_length, block_length),
- None, num_outputs)
# pylint: enable=g-long-lambda
def testSparseCore(self):
@@ -82,5 +79,5 @@ class InterleaveDatasetSerializationTest(
self.run_core_tests(_build_dataset, None, 20)
-if __name__ == '__main__':
+if __name__ == "__main__":
test.main()
diff --git a/tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDatasetV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDatasetV2.pbtxt
new file mode 100644
index 0000000000..27bc4013c3
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDatasetV2.pbtxt
@@ -0,0 +1,13 @@
+op {
+ graph_op_name: "ParallelInterleaveDatasetV2"
+ visibility: HIDDEN
+ attr {
+ name: "f"
+ description: <<END
+A function mapping elements of `input_dataset`, concatenated with
+`other_arguments`, to a Dataset variant that contains elements matching
+`output_types` and `output_shapes`.
+END
+ }
+ summary: "Creates a dataset that applies `f` to the outputs of `input_dataset`."
+}
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index f8287cf0e3..640f1565b7 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <deque>
+#include <utility>
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
@@ -21,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
@@ -34,8 +36,7 @@ namespace {
class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
public:
explicit ParallelInterleaveDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()) {
+ : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &interleave_func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
@@ -125,6 +126,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
const DataTypeVector& output_dtypes() const override {
return output_types_;
}
+
const std::vector<PartialTensorShape>& output_shapes() const override {
return output_shapes_;
}
@@ -1058,7 +1060,6 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
const std::vector<PartialTensorShape> output_shapes_;
};
- const int graph_def_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
NameAttrList interleave_func_;
@@ -1067,6 +1068,593 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDataset").Device(DEVICE_CPU),
ParallelInterleaveDatasetOp);
+// The motivation for creating an alternative implementation of parallel
+// interleave is to decouple the degree of parallelism from the cycle length.
+// This makes it possible to change the degree of parallelism (e.g. through
+// auto-tuning) without changing the cycle length (which would change the order
+// in which elements are produced).
+//
+// Furthermore, this class favors modularity over extended functionality. In
+// particular, it refrains from implementing configurable buffering of output
+// elements and prefetching of input iterators, relying on other parts of
+// tf.data to provide this functionality if necessary.
+//
+// The above design choices were made with automated optimizations in mind,
+// isolating the degree of parallelism as the single tunable knob of this
+// implementation.
+class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
+ public:
+ explicit ParallelInterleaveDatasetV2Op(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &interleave_func_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ }
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ OpInputList inputs;
+ OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
+
+ int64 cycle_length = 0;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument(ctx, "cycle_length", &cycle_length));
+ OP_REQUIRES(ctx, cycle_length > 0,
+ errors::InvalidArgument("`cycle_length` must be > 0"));
+
+ int64 block_length = 0;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument(ctx, "block_length", &block_length));
+ OP_REQUIRES(ctx, block_length > 0,
+ errors::InvalidArgument("`block_length` must be > 0"));
+
+ int64 num_parallel_calls;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
+ &num_parallel_calls));
+ OP_REQUIRES(ctx, num_parallel_calls > 0,
+ errors::InvalidArgument(
+ "num_parallel_calls must be greater than zero."));
+ OP_REQUIRES(
+ ctx, num_parallel_calls <= cycle_length,
+ errors::InvalidArgument(
+ "num_parallel_calls must less than or equal to cycle_length."));
+
+ // TODO(b/114267189): Use `other_arguments(inputs.begin(), inputs.end());`.
+ std::vector<Tensor> other_arguments;
+ other_arguments.reserve(inputs.size());
+ for (const Tensor& t : inputs) {
+ other_arguments.push_back(t);
+ }
+ std::unique_ptr<CapturedFunction> captured_func;
+ OP_REQUIRES_OK(
+ ctx, CapturedFunction::Create(
+ interleave_func_, std::move(other_arguments), &captured_func));
+
+ *output = new Dataset(ctx, input, interleave_func_,
+ std::move(captured_func), cycle_length, block_length,
+ num_parallel_calls, output_types_, output_shapes_);
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ const NameAttrList& func,
+ std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
+ int64 block_length, int64 num_parallel_calls,
+ const DataTypeVector& output_types,
+ const std::vector<PartialTensorShape>& output_shapes)
+ : DatasetBase(DatasetContext(ctx)),
+ input_(input),
+ interleave_func_(func),
+ captured_func_(std::move(captured_func)),
+ cycle_length_(cycle_length),
+ block_length_(block_length),
+ num_parallel_calls_(num_parallel_calls),
+ output_types_(output_types),
+ output_shapes_(output_shapes) {
+ input_->Ref();
+ }
+
+ ~Dataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(new Iterator(
+ {this, strings::StrCat(prefix, "::ParallelInterleaveV2")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return output_types_;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return output_shapes_;
+ }
+
+ string DebugString() const override {
+ return "ParallelInterleaveDatasetV2Op::Dataset";
+ }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, interleave_func_.name()));
+ Node* input_node;
+ TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
+ Node* cycle_length_node;
+ TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
+ Node* block_length_node;
+ TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
+ Node* num_parallel_calls_node;
+ TF_RETURN_IF_ERROR(
+ b->AddScalar(num_parallel_calls_, &num_parallel_calls_node));
+ DataTypeVector other_arguments_types;
+ other_arguments_types.reserve(captured_func_->captured_inputs().size());
+ std::vector<Node*> other_arguments;
+ other_arguments.reserve(captured_func_->captured_inputs().size());
+ for (const Tensor& t : captured_func_->captured_inputs()) {
+ Node* node;
+ TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+ other_arguments.emplace_back(node);
+ other_arguments_types.emplace_back(t.dtype());
+ }
+ AttrValue f;
+ b->BuildAttrValue(interleave_func_, &f);
+ AttrValue other_arguments_types_attr;
+ b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
+
+ TF_RETURN_IF_ERROR(b->AddDataset(
+ this,
+ {{0, input_node},
+ {2, cycle_length_node},
+ {3, block_length_node},
+ {4, num_parallel_calls_node}},
+ {{1, other_arguments}},
+ {{"f", f}, {"Targuments", other_arguments_types_attr}}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params),
+ args_list_(params.dataset->cycle_length_),
+ current_elements_(params.dataset->cycle_length_),
+ element_in_use_(params.dataset->cycle_length_, false),
+ thread_pool_(new thread::ThreadPool(
+ Env::Default(), ThreadOptions(), "parallel_interleave",
+ dataset()->cycle_length_ /* num_threads */,
+ false /* low_latency_hint */)) {}
+
+ ~Iterator() override {
+ mutex_lock l(mu_);
+ // Cancel the runner thread.
+ cancelled_ = true;
+ cond_var_.notify_all();
+ // Wait for all in-flight calls to complete.
+ while (num_calls_ > 0) {
+ cond_var_.wait(l);
+ }
+ }
+
+ Status Initialize(IteratorContext* ctx) override {
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+ return dataset()->captured_func_->Instantiate(ctx);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ std::shared_ptr<InvocationResult> result;
+ do {
+ {
+ mutex_lock l(mu_);
+ EnsureRunnerThreadStarted(ctx);
+ while (invocation_results_.empty() &&
+ (!end_of_input_ || num_open_ > 0)) {
+ cond_var_.wait(l);
+ }
+ if (!invocation_results_.empty()) {
+ std::swap(result, invocation_results_.front());
+ invocation_results_.pop_front();
+ } else {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+ }
+ cond_var_.notify_all();
+ result->notification.WaitForNotification();
+ } while (result->skip);
+
+ if (result->status.ok()) {
+ *out_tensors = std::move(result->return_values);
+ }
+ *end_of_sequence = false;
+ return result->status;
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ // Wait for all in-flight calls to complete.
+ while (num_calls_ > 0) {
+ cond_var_.wait(l);
+ }
+ CHECK_EQ(num_calls_, 0);
+ TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name("invocation_results.size"), invocation_results_.size()));
+ for (size_t i = 0; i < invocation_results_.size(); i++) {
+ std::shared_ptr<InvocationResult> result = invocation_results_[i];
+ TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result->status));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat("invocation_results[", i, "].size")),
+ result->return_values.size()));
+ for (size_t j = 0; j < result->return_values.size(); j++) {
+ TF_RETURN_IF_ERROR(writer->WriteTensor(
+ full_name(
+ strings::StrCat("invocation_results[", i, "][", j, "]")),
+ result->return_values[j]));
+ }
+ if (result->skip) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat("invocation_results[", i, "].skip")),
+ ""));
+ }
+ }
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("cycle_index"), cycle_index_));
+ if (end_of_input_) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("end_of_input"), ""));
+ }
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("num_open"), num_open_));
+ TF_RETURN_IF_ERROR(WriteCurrentElements(writer));
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+ int64 invocation_results_size;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ full_name("invocation_results.size"), &invocation_results_size));
+ for (size_t i = 0; i < invocation_results_size; i++) {
+ std::shared_ptr<InvocationResult> result(new InvocationResult());
+ invocation_results_.push_back(result);
+ TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result->status));
+ size_t num_return_values;
+ {
+ int64 size;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ full_name(strings::StrCat("invocation_results[", i, "].size")),
+ &size));
+ num_return_values = static_cast<size_t>(size);
+ if (num_return_values != size) {
+ return errors::InvalidArgument(strings::StrCat(
+ full_name(
+ strings::StrCat("invocation_results[", i, "].size")),
+ ": ", size, " is not a valid value of type size_t."));
+ }
+ }
+ result->return_values.reserve(num_return_values);
+ for (size_t j = 0; j < num_return_values; j++) {
+ result->return_values.emplace_back();
+ TF_RETURN_IF_ERROR(
+ reader->ReadTensor(full_name(strings::StrCat(
+ "invocation_results[", i, "][", j, "]")),
+ &result->return_values.back()));
+ }
+ result->skip = reader->Contains(
+ full_name(strings::StrCat("invocation_results[", i, "].skip")));
+ result->notification.Notify();
+ }
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(full_name("cycle_index"), &cycle_index_));
+ if (reader->Contains(full_name("end_of_input"))) end_of_input_ = true;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(full_name("num_open"), &num_open_));
+ TF_RETURN_IF_ERROR(ReadCurrentElements(ctx, reader));
+ return Status::OK();
+ }
+
+ private:
+ struct InvocationResult {
+ Notification notification; // used for coordination with the consumer
+ Status status; // the invocation status
+ std::vector<Tensor> return_values; // the invocation result values
+ bool skip; // if set the result should be skipped
+ };
+
+ void EnsureRunnerThreadStarted(IteratorContext* ctx)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (!runner_thread_) {
+ std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
+ runner_thread_.reset(ctx->env()->StartThread(
+ {}, "runner_thread",
+ [this, new_ctx]() { RunnerThread(new_ctx); }));
+ }
+ }
+
+ // Fetches up to `results.size()` outputs from the cycle element at
+ // position `cycle_index`.
+ //
+ // If end of input is encountered, the `skip` field of the invocation
+ // result is used to identify results that should be skipped.
+ void FetchOutputs(
+ const std::shared_ptr<IteratorContext>& ctx, int64 cycle_index,
+ const std::vector<std::shared_ptr<InvocationResult>>& results)
+ LOCKS_EXCLUDED(mu_) {
+ bool end_of_input = false;
+ for (auto& result : results) {
+ if (!end_of_input) {
+ result->status = current_elements_[cycle_index]->GetNext(
+ ctx.get(), &result->return_values, &end_of_input);
+ }
+ if (end_of_input) {
+ result->skip = true;
+ }
+ result->notification.Notify();
+ if (!result->status.ok()) {
+ break;
+ }
+ }
+
+ // Release the ownership of the cycle element iterator, closing the
+ // iterator if end of input was encountered.
+ {
+ if (end_of_input) {
+ current_elements_[cycle_index].reset();
+ }
+ mutex_lock l(mu_);
+ element_in_use_[cycle_index] = false;
+ num_calls_--;
+ if (end_of_input) {
+ args_list_[cycle_index].clear();
+ num_open_--;
+ }
+ }
+ cond_var_.notify_all();
+ }
+
+ int64 MaxInvocationResults() {
+ return dataset()->cycle_length_ * dataset()->block_length_;
+ }
+
+ // Method responsible for 1) creating iterators out of input elements, 2)
+ // determining the order in which elements are fetched from the iterators,
+ // and 3) scheduling the fetching of the elements to a threadpool.
+ //
+ // This method runs in the `runner_thread` background thread.
+ void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
+ while (true) {
+ {
+ mutex_lock l(mu_);
+ // Wait until this thread is cancelled, the end of input has been
+ // reached, or the cycle element at the `cycle_index_` position is
+ // not in use and there is space in the `invocation_results_` queue.
+ while (!cancelled_ && (!end_of_input_ || num_open_ > 0) &&
+ (element_in_use_[cycle_index_] ||
+ num_calls_ >= dataset()->num_parallel_calls_ ||
+ invocation_results_.size() >= MaxInvocationResults())) {
+ cond_var_.wait(l);
+ }
+
+ if (cancelled_ || (end_of_input_ && num_open_ == 0)) {
+ return;
+ }
+
+ while (!element_in_use_[cycle_index_] &&
+ (!end_of_input_ || num_open_ > 0) &&
+ num_calls_ < dataset()->num_parallel_calls_ &&
+ invocation_results_.size() < MaxInvocationResults()) {
+ if (!current_elements_[cycle_index_]) {
+ // Try to create a new iterator from the next input element.
+ Status status = input_impl_->GetNext(
+ ctx.get(), &args_list_[cycle_index_], &end_of_input_);
+ if (!status.ok()) {
+ invocation_results_.emplace_back(new InvocationResult());
+ std::shared_ptr<InvocationResult>& result =
+ invocation_results_.back();
+ result->status.Update(status);
+ result->notification.Notify();
+ break;
+ }
+ if (!end_of_input_) {
+ Status status = MakeIteratorFromInputElement(
+ ctx.get(), args_list_[cycle_index_], cycle_index_,
+ dataset()->captured_func_.get(), prefix(),
+ &current_elements_[cycle_index_]);
+ if (!status.ok()) {
+ invocation_results_.emplace_back(new InvocationResult());
+ std::shared_ptr<InvocationResult>& result =
+ invocation_results_.back();
+ result->status.Update(status);
+ result->notification.Notify();
+ break;
+ }
+ ++num_open_;
+ }
+ }
+ if (current_elements_[cycle_index_]) {
+ // Pre-allocate invocation results for outputs to be fetched
+ // and then fetch the outputs asynchronously.
+ std::vector<std::shared_ptr<InvocationResult>> results;
+ results.reserve(dataset()->block_length_);
+ for (int i = 0; i < dataset()->block_length_; ++i) {
+ invocation_results_.emplace_back(new InvocationResult());
+ results.push_back(invocation_results_.back());
+ }
+ num_calls_++;
+ element_in_use_[cycle_index_] = true;
+ thread_pool_->Schedule(std::bind(&Iterator::FetchOutputs, this,
+ ctx, cycle_index_,
+ std::move(results)));
+ }
+ cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_;
+ }
+ }
+ cond_var_.notify_all();
+ }
+ }
+
+ Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
+ const Status& status)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ CodeKey(index), static_cast<int64>(status.code())));
+ if (!status.ok()) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index),
+ status.error_message()));
+ }
+ return Status::OK();
+ }
+
+ Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
+ Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ int64 code_int;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
+ error::Code code = static_cast<error::Code>(code_int);
+
+ if (code != error::Code::OK) {
+ string error_message;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(ErrorMessageKey(index), &error_message));
+ *status = Status(code, error_message);
+ } else {
+ *status = Status::OK();
+ }
+ return Status::OK();
+ }
+
+ string CodeKey(size_t index) {
+ return full_name(
+ strings::StrCat("invocation_results[", index, "].code"));
+ }
+
+ string ErrorMessageKey(size_t index) {
+ return full_name(
+ strings::StrCat("invocation_results[", index, "].error_message"));
+ }
+
+ Status WriteCurrentElements(IteratorStateWriter* writer)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ for (int idx = 0; idx < current_elements_.size(); idx++) {
+ if (current_elements_[idx]) {
+ TF_RETURN_IF_ERROR(SaveInput(writer, current_elements_[idx]));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat("args_size[", idx, "]")),
+ args_list_[idx].size()));
+ for (int i = 0; i < args_list_[idx].size(); i++) {
+ TF_RETURN_IF_ERROR(writer->WriteTensor(
+ full_name(strings::StrCat("args_list_[", idx, "][", i, "]")),
+ args_list_[idx][i]));
+ }
+ }
+ }
+ return Status::OK();
+ }
+
+ Status ReadCurrentElements(IteratorContext* ctx,
+ IteratorStateReader* reader)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ for (int idx = 0; idx < current_elements_.size(); idx++) {
+ if (reader->Contains(
+ full_name(strings::StrCat("args_size[", idx, "]")))) {
+ int64 args_size;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ full_name(strings::StrCat("args_size[", idx, "]")),
+ &args_size));
+ args_list_[idx].resize(args_size);
+ for (int i = 0; i < args_size; i++) {
+ TF_RETURN_IF_ERROR(reader->ReadTensor(
+ full_name(strings::StrCat("args_list_[", idx, "][", i, "]")),
+ &args_list_[idx][i]));
+ }
+ TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
+ ctx, args_list_[idx], idx, dataset()->captured_func_.get(),
+ prefix(), &current_elements_[idx]));
+ TF_RETURN_IF_ERROR(
+ RestoreInput(ctx, reader, current_elements_[idx]));
+ } else {
+ current_elements_[idx].reset();
+ }
+ }
+ return Status::OK();
+ }
+
+ // Used for coordination between the main thread, the runner thread, and
+ // the worker threads.
+ mutex mu_;
+
+ // Used for coordination between the main thread, the runner thread, and
+ // the worker threads. In particular, the runner thread should only
+ // schedule new calls when the number of in-flight calls is less than the
+ // user specified level of parallelism, there are slots available in the
+ // `invocation_results_` buffer, the current cycle element is not in use,
+ // and there are elements left to be fetched.
+ condition_variable cond_var_;
+
+ // Iterator for input elements.
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+
+ // Identifies current cycle element.
+ int64 cycle_index_ = 0;
+
+ // Arguments for creating an iterator for cycle elements.
+ std::vector<std::vector<Tensor>> args_list_ GUARDED_BY(mu_);
+
+ // Iterators for the current cycle elements. Concurrent access is
+ // protected by `element_in_use_`.
+ std::vector<std::unique_ptr<IteratorBase>> current_elements_;
+
+ // Identifies cycle elements that are in use by worker threads.
+ std::vector<bool> element_in_use_ GUARDED_BY(mu_);
+
+ // Buffer for storing the invocation results.
+ std::deque<std::shared_ptr<InvocationResult>> invocation_results_
+ GUARDED_BY(mu_);
+
+ // Identifies whether end of input has been reached.
+ bool end_of_input_ GUARDED_BY(mu_) = false;
+
+ // Identifies the number of open iterators.
+ int64 num_open_ GUARDED_BY(mu_) = 0;
+
+ // Identifies the number of outstanding calls.
+ int64 num_calls_ GUARDED_BY(mu_) = 0;
+
+ std::unique_ptr<thread::ThreadPool> thread_pool_;
+ std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_);
+
+ // Identifies whether background activity should be cancelled.
+ bool cancelled_ GUARDED_BY(mu_) = false;
+ };
+
+ const DatasetBase* const input_;
+ const NameAttrList interleave_func_;
+ const std::unique_ptr<CapturedFunction> captured_func_;
+ const int64 cycle_length_;
+ const int64 block_length_;
+ const int64 num_parallel_calls_;
+ const DataTypeVector output_types_;
+ const std::vector<PartialTensorShape> output_shapes_;
+ };
+
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+ NameAttrList interleave_func_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDatasetV2").Device(DEVICE_CPU),
+ ParallelInterleaveDatasetV2Op);
+
} // namespace
} // namespace data
} // namespace tensorflow
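
The kernel above validates its scalar inputs when the dataset op executes. A hedged sketch of how the `num_parallel_calls <= cycle_length` constraint surfaces from Python (with a one-shot iterator, the check fires at run time, on the first evaluation):

    import tensorflow as tf
    from tensorflow.python.framework import errors

    # Invalid: num_parallel_calls (4) exceeds cycle_length (2).
    dataset = tf.data.Dataset.range(10).interleave(
        tf.data.Dataset.from_tensors, cycle_length=2, num_parallel_calls=4)
    get_next = dataset.make_one_shot_iterator().get_next()

    with tf.Session() as sess:
      try:
        sess.run(get_next)
      except errors.InvalidArgumentError as e:
        print(e.message)  # num_parallel_calls must be less than or equal ...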
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 1a5ad8f421..145f4941c8 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -326,6 +326,19 @@ REGISTER_OP("ParallelInterleaveDataset")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
+REGISTER_OP("ParallelInterleaveDatasetV2")
+ .Input("input_dataset: variant")
+ .Input("other_arguments: Targuments")
+ .Input("cycle_length: int64")
+ .Input("block_length: int64")
+ .Input("num_parallel_calls: int64")
+ .Output("handle: variant")
+ .Attr("f: func")
+ .Attr("Targuments: list(type) >= 0")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
REGISTER_OP("GroupByReducerDataset")
.Input("input_dataset: variant")
.Input("key_func_other_arguments: Tkey_func_other_arguments")
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index 23c98247bf..5cd1484084 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -137,6 +137,8 @@ tf_py_test(
size = "small",
srcs = ["interleave_dataset_op_test.py"],
additional_deps = [
+ "@absl_py//absl/testing:parameterized",
+ "//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
diff --git a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py b/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
index 7dbf7268d7..a35cee594a 100644
--- a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
@@ -19,8 +19,10 @@ from __future__ import print_function
import itertools
+from absl.testing import parameterized
+import numpy as np
+
from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
@@ -28,7 +30,7 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
-class InterleaveDatasetTest(test.TestCase):
+class InterleaveDatasetTest(test.TestCase, parameterized.TestCase):
def _interleave(self, lists, cycle_length, block_length):
num_open = 0
@@ -97,84 +99,85 @@ class InterleaveDatasetTest(test.TestCase):
expected_elements, self._interleave(input_lists, 7, 2)):
self.assertEqual(expected, produced)
- def testInterleaveDataset(self):
- input_values = array_ops.placeholder(dtypes.int64, shape=[None])
- cycle_length = array_ops.placeholder(dtypes.int64, shape=[])
- block_length = array_ops.placeholder(dtypes.int64, shape=[])
-
- repeat_count = 2
-
- dataset = (
- dataset_ops.Dataset.from_tensor_slices(input_values)
- .repeat(repeat_count)
- .interleave(lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
- cycle_length, block_length))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- next_element = iterator.get_next()
+ @parameterized.named_parameters(
+ ("1", np.int64([4, 5, 6]), 1, 3, None),
+ ("2", np.int64([4, 5, 6]), 1, 3, 1),
+ ("3", np.int64([4, 5, 6]), 2, 1, None),
+ ("4", np.int64([4, 5, 6]), 2, 1, 1),
+ ("5", np.int64([4, 5, 6]), 2, 1, 2),
+ ("6", np.int64([4, 5, 6]), 2, 3, None),
+ ("7", np.int64([4, 5, 6]), 2, 3, 1),
+ ("8", np.int64([4, 5, 6]), 2, 3, 2),
+ ("9", np.int64([4, 5, 6]), 7, 2, None),
+ ("10", np.int64([4, 5, 6]), 7, 2, 1),
+ ("11", np.int64([4, 5, 6]), 7, 2, 3),
+ ("12", np.int64([4, 5, 6]), 7, 2, 5),
+ ("13", np.int64([4, 5, 6]), 7, 2, 7),
+ ("14", np.int64([]), 2, 3, None),
+ ("15", np.int64([0, 0, 0]), 2, 3, None),
+ ("16", np.int64([4, 0, 6]), 2, 3, None),
+ ("17", np.int64([4, 0, 6]), 2, 3, 1),
+ ("18", np.int64([4, 0, 6]), 2, 3, 2),
+ )
+ def testInterleaveDataset(self, input_values, cycle_length, block_length,
+ num_parallel_calls):
+ count = 2
+ dataset = dataset_ops.Dataset.from_tensor_slices(input_values).repeat(
+ count).interleave(
+ lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
+ cycle_length, block_length, num_parallel_calls)
+ get_next = dataset.make_one_shot_iterator().get_next()
+
+ def repeat(values, count):
+ result = []
+ for value in values:
+ result.append([value] * value)
+ return result * count
with self.test_session() as sess:
- # Cycle length 1 acts like `Dataset.flat_map()`.
- sess.run(init_op, feed_dict={input_values: [4, 5, 6],
- cycle_length: 1, block_length: 3})
-
- for expected_element in self._interleave(
- [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 1, 3):
- self.assertEqual(expected_element, sess.run(next_element))
-
- # Cycle length > 1.
- # expected: [4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 5,
- # 6, 5, 6, 5, 6, 5, 6, 5]
- sess.run(init_op, feed_dict={input_values: [4, 5, 6],
- cycle_length: 2, block_length: 1})
for expected_element in self._interleave(
- [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 1):
- self.assertEqual(expected_element, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- # Cycle length > 1 and block length > 1.
- # expected: [4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6, 4, 5,
- # 5, 5, 6, 6, 6, 5, 5, 6, 6, 6]
- sess.run(init_op, feed_dict={input_values: [4, 5, 6],
- cycle_length: 2, block_length: 3})
- for expected_element in self._interleave(
- [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 3):
- self.assertEqual(expected_element, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- # Cycle length > len(input_values) * repeat_count.
- # expected: [4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4,
- # 5, 5, 6, 6, 5, 6, 6, 5, 6, 6]
- sess.run(init_op, feed_dict={input_values: [4, 5, 6],
- cycle_length: 7, block_length: 2})
- for expected_element in self._interleave(
- [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 7, 2):
- self.assertEqual(expected_element, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- # Empty input.
- sess.run(init_op, feed_dict={input_values: [],
- cycle_length: 2, block_length: 3})
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ repeat(input_values, count), cycle_length, block_length):
+ self.assertEqual(expected_element, sess.run(get_next))
+
+ for _ in range(2):
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ @parameterized.named_parameters(
+ ("1", np.float32([1., np.nan, 2., np.nan, 3.]), 1, 3, None),
+ ("2", np.float32([1., np.nan, 2., np.nan, 3.]), 1, 3, 1),
+ ("3", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 1, None),
+ ("4", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 1, 1),
+ ("5", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 1, 2),
+ ("6", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 3, None),
+ ("7", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 3, 1),
+ ("8", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 3, 2),
+ ("9", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, None),
+ ("10", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 1),
+ ("11", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 3),
+ ("12", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 5),
+ ("13", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 7),
+ )
+ def testInterleaveErrorDataset(self,
+ input_values,
+ cycle_length,
+ block_length,
+ num_parallel_calls):
+ dataset = dataset_ops.Dataset.from_tensor_slices(input_values).map(
+ lambda x: array_ops.check_numerics(x, "message")).interleave(
+ dataset_ops.Dataset.from_tensors, cycle_length, block_length,
+ num_parallel_calls)
+ get_next = dataset.make_one_shot_iterator().get_next()
- # Non-empty input leading to empty output.
- sess.run(init_op, feed_dict={input_values: [0, 0, 0],
- cycle_length: 2, block_length: 3})
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- # Mixture of non-empty and empty interleaved datasets.
- sess.run(init_op, feed_dict={input_values: [4, 0, 6],
- cycle_length: 2, block_length: 3})
- for expected_element in self._interleave(
- [[4] * 4, [], [6] * 6] * repeat_count, 2, 3):
- self.assertEqual(expected_element, sess.run(next_element))
+ with self.test_session() as sess:
+ for value in input_values:
+ if np.isnan(value):
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(get_next)
+ else:
+ self.assertEqual(value, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ sess.run(get_next)
def testSparse(self):
@@ -201,20 +204,6 @@ class InterleaveDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def testEmptyInput(self):
- iterator = (
- dataset_ops.Dataset.from_tensor_slices([])
- .repeat(None)
- .interleave(dataset_ops.Dataset.from_tensors, cycle_length=2)
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.test_session() as sess:
- sess.run(init_op)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 6205ee392e..2c1aa22116 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -1019,7 +1019,11 @@ class Dataset(object):
"""
return FlatMapDataset(self, map_func)
- def interleave(self, map_func, cycle_length, block_length=1):
+ def interleave(self,
+ map_func,
+ cycle_length,
+ block_length=1,
+ num_parallel_calls=None):
"""Maps `map_func` across this dataset, and interleaves the results.
For example, you can use `Dataset.interleave()` to process many input files
@@ -1082,11 +1086,19 @@ class Dataset(object):
processed concurrently.
block_length: The number of consecutive elements to produce from each
input element before cycling to another input element.
+ num_parallel_calls: (Optional.) If specified, the implementation creates
+ a threadpool, which is used to fetch inputs from cycle elements
+ asynchronously and in parallel. The default behavior is to fetch inputs
+ from cycle elements synchronously with no parallelism.
Returns:
Dataset: A `Dataset`.
"""
- return InterleaveDataset(self, map_func, cycle_length, block_length)
+ if num_parallel_calls is None:
+ return InterleaveDataset(self, map_func, cycle_length, block_length)
+ else:
+ return ParallelInterleaveDataset(self, map_func, cycle_length,
+ block_length, num_parallel_calls)
def filter(self, predicate):
"""Filters this dataset according to `predicate`.
@@ -2330,6 +2342,36 @@ class InterleaveDataset(FlatMapDataset):
return "Dataset.interleave()"
+class ParallelInterleaveDataset(FlatMapDataset):
+ """A `Dataset` that maps a function over its input and interleaves the result.
+
+ """
+
+ def __init__(self, input_dataset, map_func, cycle_length, block_length,
+ num_parallel_calls):
+ """See `Dataset.interleave()` for details."""
+ super(ParallelInterleaveDataset, self).__init__(input_dataset, map_func)
+ self._cycle_length = ops.convert_to_tensor(
+ cycle_length, dtype=dtypes.int64, name="cycle_length")
+ self._block_length = ops.convert_to_tensor(
+ block_length, dtype=dtypes.int64, name="block_length")
+ self._num_parallel_calls = ops.convert_to_tensor(
+ num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.parallel_interleave_dataset_v2(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._map_func.captured_inputs, # pylint: disable=protected-access
+ self._cycle_length,
+ self._block_length,
+ self._num_parallel_calls,
+ f=self._map_func, # pylint: disable=protected-access
+ **flat_structure(self))
+
+ def _transformation_name(self):
+ return "Dataset.interleave()"
+
+
class FilterDataset(Dataset):
"""A `Dataset` that filters its input according to a predicate function."""
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt
index 834f0954d5..87745420ee 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt
@@ -60,7 +60,7 @@ tf_class {
}
member_method {
name: "interleave"
- argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], "
+ argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
}
member_method {
name: "list_files"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt
index 4d854a4cee..6dd46365b0 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt
@@ -61,7 +61,7 @@ tf_class {
}
member_method {
name: "interleave"
- argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], "
+ argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
}
member_method {
name: "list_files"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt
index 601f095a60..35b7105eba 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt
@@ -61,7 +61,7 @@ tf_class {
}
member_method {
name: "interleave"
- argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], "
+ argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
}
member_method {
name: "list_files"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt
index 587829a4c0..8ae370af98 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt
@@ -61,7 +61,7 @@ tf_class {
}
member_method {
name: "interleave"
- argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], "
+ argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
}
member_method {
name: "list_files"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
index 834f0954d5..87745420ee 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
@@ -60,7 +60,7 @@ tf_class {
}
member_method {
name: "interleave"
- argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], "
+ argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
}
member_method {
name: "list_files"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
index 4d854a4cee..6dd46365b0 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
@@ -61,7 +61,7 @@ tf_class {
}
member_method {
name: "interleave"
- argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], "
+ argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
}
member_method {
name: "list_files"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
index 601f095a60..35b7105eba 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
@@ -61,7 +61,7 @@ tf_class {
}
member_method {
name: "interleave"
- argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], "
+ argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
}
member_method {
name: "list_files"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
index 587829a4c0..8ae370af98 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
@@ -61,7 +61,7 @@ tf_class {
}
member_method {
name: "interleave"
- argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], "
+ argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
}
member_method {
name: "list_files"