diff options
author | Jiri Simsa <jsimsa@google.com> | 2018-09-06 16:09:15 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-06 16:17:43 -0700 |
commit | 9a6ab2af59f3b21ffa2b74093ccc9af4edaf7f98 (patch) | |
tree | 748882485661e750cf19f1d9ca182a590bbb7c8b /tensorflow | |
parent | 33d2a0e7064cd14540121e38457d4a81aa57a650 (diff) |
[tf.data] Adding support for `num_parallel_calls` to `tf.data.Dataset.interleave`.
Unlike `tf.contrib.data.parallel_interleave`, whose parallelism is tied to the `cycle_length` argument, the newly introduced `num_parallel_calls` argument of `tf.data.Dataset.interleave` is decoupled from the `cycle_length` argument and identifies the degree of parallelism to use for fetching output elements.
PiperOrigin-RevId: 211886816
Diffstat (limited to 'tensorflow')
16 files changed, 771 insertions, 126 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD index 4881f63ab9..aa89674c6e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD @@ -210,6 +210,7 @@ py_test( "//tensorflow/python:sparse_tensor", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py index ac3892fe81..243f6405a1 100644 --- a/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base @@ -27,42 +28,38 @@ from tensorflow.python.platform import test class InterleaveDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): + dataset_serialization_test_base.DatasetSerializationTestBase, + parameterized.TestCase): - def _build_iterator_graph(self, input_values, cycle_length, block_length): + def _build_iterator_graph(self, input_values, cycle_length, block_length, + num_parallel_calls): repeat_count = 2 return dataset_ops.Dataset.from_tensor_slices(input_values).repeat( repeat_count).interleave( lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x), - cycle_length, block_length) + cycle_length, block_length, num_parallel_calls) - def testSerializationCore(self): + 
@parameterized.named_parameters( + ("1", 2, 3, None), + ("2", 2, 3, 1), + ("3", 2, 3, 2), + ("4", 1, 3, None), + ("5", 1, 3, 1), + ("6", 2, 1, None), + ("7", 2, 1, 1), + ("8", 2, 1, 2), + ) + def testSerializationCore(self, cycle_length, block_length, + num_parallel_calls): input_values = np.array([4, 5, 6], dtype=np.int64) num_outputs = np.sum(input_values) * 2 - # cycle_length > 1, block_length > 1 - cycle_length = 2 - block_length = 3 # pylint: disable=g-long-lambda self.run_core_tests( lambda: self._build_iterator_graph( - input_values, cycle_length, block_length), + input_values, cycle_length, block_length, num_parallel_calls), lambda: self._build_iterator_graph( - input_values, cycle_length * 2, block_length * 1), + input_values, cycle_length * 2, block_length, num_parallel_calls), num_outputs) - # cycle_length = 1 - cycle_length = 1 - block_length = 3 - self.run_core_tests( - lambda: self._build_iterator_graph( - input_values, cycle_length, block_length), - None, num_outputs) - # block_length = 1 - cycle_length = 2 - block_length = 1 - self.run_core_tests( - lambda: self._build_iterator_graph( - input_values, cycle_length, block_length), - None, num_outputs) # pylint: enable=g-long-lambda def testSparseCore(self): @@ -82,5 +79,5 @@ class InterleaveDatasetSerializationTest( self.run_core_tests(_build_dataset, None, 20) -if __name__ == '__main__': +if __name__ == "__main__": test.main() diff --git a/tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDatasetV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDatasetV2.pbtxt new file mode 100644 index 0000000000..27bc4013c3 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDatasetV2.pbtxt @@ -0,0 +1,13 @@ +op { + graph_op_name: "ParallelInterleaveDatasetV2" + visibility: HIDDEN + attr { + name: "f" + description: <<END +A function mapping elements of `input_dataset`, concatenated with +`other_arguments`, to a Dataset variant that contains elements matching 
+`output_types` and `output_shapes`. +END + } + summary: "Creates a dataset that applies `f` to the outputs of `input_dataset`." +} diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc index f8287cf0e3..640f1565b7 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include <deque> +#include <utility> #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/partial_tensor_shape.h" @@ -21,6 +22,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/dataset.h" #include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/random/random.h" @@ -34,8 +36,7 @@ namespace { class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { public: explicit ParallelInterleaveDatasetOp(OpKernelConstruction* ctx) - : UnaryDatasetOpKernel(ctx), - graph_def_version_(ctx->graph_def_version()) { + : UnaryDatasetOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &interleave_func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); @@ -125,6 +126,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { const DataTypeVector& output_dtypes() const override { return output_types_; } + const std::vector<PartialTensorShape>& output_shapes() const override { return output_shapes_; } @@ -1058,7 +1060,6 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { const 
std::vector<PartialTensorShape> output_shapes_; }; - const int graph_def_version_; DataTypeVector output_types_; std::vector<PartialTensorShape> output_shapes_; NameAttrList interleave_func_; @@ -1067,6 +1068,593 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDataset").Device(DEVICE_CPU), ParallelInterleaveDatasetOp); +// The motivation for creating an alternative implementation of parallel +// interleave is to decouple the degree of parallelism from the cycle length. +// This makes it possible to change the degree of parallelism (e.g. through +// auto-tuning) without changing the cycle length (which would change the order +// in which elements are produced). +// +// Furthermore, this class favors modularity over extended functionality. In +// particular, it refrains from implementing configurable buffering of output +// elements and prefetching of input iterators, relying on other parts of +// tf.data to provide this functionality if necessary. +// +// The above design choices were made with automated optimizations in mind, +// isolating the degree of parallelism as the single tunable knob of this +// implementation. 
+class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { + public: + explicit ParallelInterleaveDatasetV2Op(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &interleave_func_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } + + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + OpInputList inputs; + OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs)); + + int64 cycle_length = 0; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "cycle_length", &cycle_length)); + OP_REQUIRES(ctx, cycle_length > 0, + errors::InvalidArgument("`cycle_length` must be > 0")); + + int64 block_length = 0; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "block_length", &block_length)); + OP_REQUIRES(ctx, block_length > 0, + errors::InvalidArgument("`block_length` must be > 0")); + + int64 num_parallel_calls; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls", + &num_parallel_calls)); + OP_REQUIRES(ctx, num_parallel_calls > 0, + errors::InvalidArgument( + "num_parallel_calls must be greater than zero.")); + OP_REQUIRES( + ctx, num_parallel_calls <= cycle_length, + errors::InvalidArgument( + "num_parallel_calls must less than or equal to cycle_length.")); + + // TODO(b/114267189): Use `other_arguments(inputs.begin(), inputs.end());`. 
+ std::vector<Tensor> other_arguments; + other_arguments.reserve(inputs.size()); + for (const Tensor& t : inputs) { + other_arguments.push_back(t); + } + std::unique_ptr<CapturedFunction> captured_func; + OP_REQUIRES_OK( + ctx, CapturedFunction::Create( + interleave_func_, std::move(other_arguments), &captured_func)); + + *output = new Dataset(ctx, input, interleave_func_, + std::move(captured_func), cycle_length, block_length, + num_parallel_calls, output_types_, output_shapes_); + } + + private: + class Dataset : public DatasetBase { + public: + Dataset(OpKernelContext* ctx, const DatasetBase* input, + const NameAttrList& func, + std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length, + int64 block_length, int64 num_parallel_calls, + const DataTypeVector& output_types, + const std::vector<PartialTensorShape>& output_shapes) + : DatasetBase(DatasetContext(ctx)), + input_(input), + interleave_func_(func), + captured_func_(std::move(captured_func)), + cycle_length_(cycle_length), + block_length_(block_length), + num_parallel_calls_(num_parallel_calls), + output_types_(output_types), + output_shapes_(output_shapes) { + input_->Ref(); + } + + ~Dataset() override { input_->Unref(); } + + std::unique_ptr<IteratorBase> MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr<IteratorBase>(new Iterator( + {this, strings::StrCat(prefix, "::ParallelInterleaveV2")})); + } + + const DataTypeVector& output_dtypes() const override { + return output_types_; + } + + const std::vector<PartialTensorShape>& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { + return "ParallelInterleaveDatasetV2Op::Dataset"; + } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + TF_RETURN_IF_ERROR(b->AddFunction(ctx, interleave_func_.name())); + Node* input_node; + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, 
&input_node)); + Node* cycle_length_node; + TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node)); + Node* block_length_node; + TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node)); + Node* num_parallel_calls_node; + TF_RETURN_IF_ERROR( + b->AddScalar(num_parallel_calls_, &num_parallel_calls_node)); + DataTypeVector other_arguments_types; + other_arguments_types.reserve(captured_func_->captured_inputs().size()); + std::vector<Node*> other_arguments; + other_arguments.reserve(captured_func_->captured_inputs().size()); + for (const Tensor& t : captured_func_->captured_inputs()) { + Node* node; + TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + other_arguments.emplace_back(node); + other_arguments_types.emplace_back(t.dtype()); + } + AttrValue f; + b->BuildAttrValue(interleave_func_, &f); + AttrValue other_arguments_types_attr; + b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr); + + TF_RETURN_IF_ERROR(b->AddDataset( + this, + {{0, input_node}, + {2, cycle_length_node}, + {3, block_length_node}, + {4, num_parallel_calls_node}}, + {{1, other_arguments}}, + {{"f", f}, {"Targuments", other_arguments_types_attr}}, output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator<Dataset> { + public: + explicit Iterator(const Params& params) + : DatasetIterator<Dataset>(params), + args_list_(params.dataset->cycle_length_), + current_elements_(params.dataset->cycle_length_), + element_in_use_(params.dataset->cycle_length_, false), + thread_pool_(new thread::ThreadPool( + Env::Default(), ThreadOptions(), "parallel_interleave", + dataset()->cycle_length_ /* num_threads */, + false /* low_latency_hint */)) {} + + ~Iterator() override { + mutex_lock l(mu_); + // Cancel the runner thread. + cancelled_ = true; + cond_var_.notify_all(); + // Wait for all in-flight calls to complete. 
+ while (num_calls_ > 0) { + cond_var_.wait(l); + } + } + + Status Initialize(IteratorContext* ctx) override { + TF_RETURN_IF_ERROR( + dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); + return dataset()->captured_func_->Instantiate(ctx); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + std::shared_ptr<InvocationResult> result; + do { + { + mutex_lock l(mu_); + EnsureRunnerThreadStarted(ctx); + while (invocation_results_.empty() && + (!end_of_input_ || num_open_ > 0)) { + cond_var_.wait(l); + } + if (!invocation_results_.empty()) { + std::swap(result, invocation_results_.front()); + invocation_results_.pop_front(); + } else { + *end_of_sequence = true; + return Status::OK(); + } + } + cond_var_.notify_all(); + result->notification.WaitForNotification(); + } while (result->skip); + + if (result->status.ok()) { + *out_tensors = std::move(result->return_values); + } + *end_of_sequence = false; + return result->status; + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + // Wait for all in-flight calls to complete. 
+ while (num_calls_ > 0) { + cond_var_.wait(l); + } + CHECK_EQ(num_calls_, 0); + TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name("invocation_results.size"), invocation_results_.size())); + for (size_t i = 0; i < invocation_results_.size(); i++) { + std::shared_ptr<InvocationResult> result = invocation_results_[i]; + TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result->status)); + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name(strings::StrCat("invocation_results[", i, "].size")), + result->return_values.size())); + for (size_t j = 0; j < result->return_values.size(); j++) { + TF_RETURN_IF_ERROR(writer->WriteTensor( + full_name( + strings::StrCat("invocation_results[", i, "][", j, "]")), + result->return_values[j])); + } + if (result->skip) { + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name(strings::StrCat("invocation_results[", i, "].skip")), + "")); + } + } + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("cycle_index"), cycle_index_)); + if (end_of_input_) { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("end_of_input"), "")); + } + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("num_open"), num_open_)); + TF_RETURN_IF_ERROR(WriteCurrentElements(writer)); + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); + int64 invocation_results_size; + TF_RETURN_IF_ERROR(reader->ReadScalar( + full_name("invocation_results.size"), &invocation_results_size)); + for (size_t i = 0; i < invocation_results_size; i++) { + std::shared_ptr<InvocationResult> result(new InvocationResult()); + invocation_results_.push_back(result); + TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result->status)); + size_t num_return_values; + { + int64 size; + TF_RETURN_IF_ERROR(reader->ReadScalar( + full_name(strings::StrCat("invocation_results[", i, "].size")), + 
&size)); + num_return_values = static_cast<size_t>(size); + if (num_return_values != size) { + return errors::InvalidArgument(strings::StrCat( + full_name( + strings::StrCat("invocation_results[", i, "].size")), + ": ", size, " is not a valid value of type size_t.")); + } + } + result->return_values.reserve(num_return_values); + for (size_t j = 0; j < num_return_values; j++) { + result->return_values.emplace_back(); + TF_RETURN_IF_ERROR( + reader->ReadTensor(full_name(strings::StrCat( + "invocation_results[", i, "][", j, "]")), + &result->return_values.back())); + } + result->skip = reader->Contains( + full_name(strings::StrCat("invocation_results[", i, "].skip"))); + result->notification.Notify(); + } + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name("cycle_index"), &cycle_index_)); + if (reader->Contains(full_name("end_of_input"))) end_of_input_ = true; + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name("num_open"), &num_open_)); + TF_RETURN_IF_ERROR(ReadCurrentElements(ctx, reader)); + return Status::OK(); + } + + private: + struct InvocationResult { + Notification notification; // used for coordination with the consumer + Status status; // the invocation status + std::vector<Tensor> return_values; // the invocation result values + bool skip; // if set the result should be skipped + }; + + void EnsureRunnerThreadStarted(IteratorContext* ctx) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (!runner_thread_) { + std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx)); + runner_thread_.reset(ctx->env()->StartThread( + {}, "runner_thread", + [this, new_ctx]() { RunnerThread(new_ctx); })); + } + } + + // Fetches up to `results.size()` outputs from the cycle element at + // position `cycle_index`. + // + // If end of input is encountered, the `skip` field of the invocation + // result is used to identify results that should be skipped. 
+ void FetchOutputs( + const std::shared_ptr<IteratorContext>& ctx, int64 cycle_index, + const std::vector<std::shared_ptr<InvocationResult>>& results) + LOCKS_EXCLUDED(mu_) { + bool end_of_input = false; + for (auto& result : results) { + if (!end_of_input) { + result->status = current_elements_[cycle_index]->GetNext( + ctx.get(), &result->return_values, &end_of_input); + } + if (end_of_input) { + result->skip = true; + } + result->notification.Notify(); + if (!result->status.ok()) { + break; + } + } + + // Release the ownership of the cycle element iterator, closing the + // iterator if end of input was encountered. + { + if (end_of_input) { + current_elements_[cycle_index].reset(); + } + mutex_lock l(mu_); + element_in_use_[cycle_index] = false; + num_calls_--; + if (end_of_input) { + args_list_[cycle_index].clear(); + num_open_--; + } + } + cond_var_.notify_all(); + } + + int64 MaxInvocationResults() { + return dataset()->cycle_length_ * dataset()->block_length_; + } + + // Method responsible for 1) creating iterators out of input elements, 2) + // determining the order in which elements are fetched from the iterators, + // and 3) scheduling the fetching of the elements to a threadpool. + // + // This method runs in the `runner_thread` background thread. + void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) { + while (true) { + { + mutex_lock l(mu_); + // Wait until this thread is cancelled, the end of input has been + // reached, or the cycle element at the `cycle_index_` position is + // not in use and there is space in the `invocation_results_` queue. 
+ while (!cancelled_ && (!end_of_input_ || num_open_ > 0) && + (element_in_use_[cycle_index_] || + num_calls_ >= dataset()->num_parallel_calls_ || + invocation_results_.size() >= MaxInvocationResults())) { + cond_var_.wait(l); + } + + if (cancelled_ || (end_of_input_ && num_open_ == 0)) { + return; + } + + while (!element_in_use_[cycle_index_] && + (!end_of_input_ || num_open_ > 0) && + num_calls_ < dataset()->num_parallel_calls_ && + invocation_results_.size() < MaxInvocationResults()) { + if (!current_elements_[cycle_index_]) { + // Try to create a new iterator from the next input element. + Status status = input_impl_->GetNext( + ctx.get(), &args_list_[cycle_index_], &end_of_input_); + if (!status.ok()) { + invocation_results_.emplace_back(new InvocationResult()); + std::shared_ptr<InvocationResult>& result = + invocation_results_.back(); + result->status.Update(status); + result->notification.Notify(); + break; + } + if (!end_of_input_) { + Status status = MakeIteratorFromInputElement( + ctx.get(), args_list_[cycle_index_], cycle_index_, + dataset()->captured_func_.get(), prefix(), + &current_elements_[cycle_index_]); + if (!status.ok()) { + invocation_results_.emplace_back(new InvocationResult()); + std::shared_ptr<InvocationResult>& result = + invocation_results_.back(); + result->status.Update(status); + result->notification.Notify(); + break; + } + ++num_open_; + } + } + if (current_elements_[cycle_index_]) { + // Pre-allocate invocation results for outputs to be fetched + // and then fetch the outputs asynchronously. 
+ std::vector<std::shared_ptr<InvocationResult>> results; + results.reserve(dataset()->block_length_); + for (int i = 0; i < dataset()->block_length_; ++i) { + invocation_results_.emplace_back(new InvocationResult()); + results.push_back(invocation_results_.back()); + } + num_calls_++; + element_in_use_[cycle_index_] = true; + thread_pool_->Schedule(std::bind(&Iterator::FetchOutputs, this, + ctx, cycle_index_, + std::move(results))); + } + cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_; + } + } + cond_var_.notify_all(); + } + } + + Status WriteStatusLocked(IteratorStateWriter* writer, size_t index, + const Status& status) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + TF_RETURN_IF_ERROR(writer->WriteScalar( + CodeKey(index), static_cast<int64>(status.code()))); + if (!status.ok()) { + TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index), + status.error_message())); + } + return Status::OK(); + } + + Status ReadStatusLocked(IteratorStateReader* reader, size_t index, + Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + int64 code_int; + TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int)); + error::Code code = static_cast<error::Code>(code_int); + + if (code != error::Code::OK) { + string error_message; + TF_RETURN_IF_ERROR( + reader->ReadScalar(ErrorMessageKey(index), &error_message)); + *status = Status(code, error_message); + } else { + *status = Status::OK(); + } + return Status::OK(); + } + + string CodeKey(size_t index) { + return full_name( + strings::StrCat("invocation_results[", index, "].code")); + } + + string ErrorMessageKey(size_t index) { + return full_name( + strings::StrCat("invocation_results[", index, "].error_message")); + } + + Status WriteCurrentElements(IteratorStateWriter* writer) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + for (int idx = 0; idx < current_elements_.size(); idx++) { + if (current_elements_[idx]) { + TF_RETURN_IF_ERROR(SaveInput(writer, current_elements_[idx])); + TF_RETURN_IF_ERROR(writer->WriteScalar( + 
+ full_name(strings::StrCat("args_size[", idx, "]")), + args_list_[idx].size())); + for (int i = 0; i < args_list_[idx].size(); i++) { + TF_RETURN_IF_ERROR(writer->WriteTensor( + full_name(strings::StrCat("args_list_[", idx, "][", i, "]")), + args_list_[idx][i])); + } + } + } + return Status::OK(); + } + + Status ReadCurrentElements(IteratorContext* ctx, + IteratorStateReader* reader) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + for (int idx = 0; idx < current_elements_.size(); idx++) { + if (reader->Contains( + full_name(strings::StrCat("args_size[", idx, "]")))) { + int64 args_size; + TF_RETURN_IF_ERROR(reader->ReadScalar( + full_name(strings::StrCat("args_size[", idx, "]")), + &args_size)); + args_list_[idx].resize(args_size); + for (int i = 0; i < args_size; i++) { + TF_RETURN_IF_ERROR(reader->ReadTensor( + full_name(strings::StrCat("args_list_[", idx, "][", i, "]")), + &args_list_[idx][i])); + } + TF_RETURN_IF_ERROR(MakeIteratorFromInputElement( + ctx, args_list_[idx], idx, dataset()->captured_func_.get(), + prefix(), &current_elements_[idx])); + TF_RETURN_IF_ERROR( + RestoreInput(ctx, reader, current_elements_[idx])); + } else { + current_elements_[idx].reset(); + } + } + return Status::OK(); + } + + // Used for coordination between the main thread, the runner thread, and + // the worker threads. + mutex mu_; + + // Used for coordination between the main thread, the runner thread, and + // the worker threads. In particular, the runner thread should only + // schedule new calls when the number of in-flight calls is less than the + // user specified level of parallelism, there are slots available in the + // `invocation_results_` buffer, the current cycle element is not in use, + // and there are elements left to be fetched. + condition_variable cond_var_; + + // Iterator for input elements. + std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); + + // Identifies current cycle element. + int64 cycle_index_ = 0; + + // Arguments for creating an iterator for cycle elements. 
+ std::vector<std::vector<Tensor>> args_list_ GUARDED_BY(mu_); + + // Iterators for the current cycle elements. Concurrent access is + // protected by `element_in_use_`. + std::vector<std::unique_ptr<IteratorBase>> current_elements_; + + // Identifies cycle elements that are in use by worker threads. + std::vector<bool> element_in_use_ GUARDED_BY(mu_); + + // Buffer for storing the invocation results. + std::deque<std::shared_ptr<InvocationResult>> invocation_results_ + GUARDED_BY(mu_); + + // Identifies whether end of input has been reached. + bool end_of_input_ GUARDED_BY(mu_) = false; + + // Identifies the number of open iterators. + int64 num_open_ GUARDED_BY(mu_) = 0; + + // Identifies the number of outstanding calls. + int64 num_calls_ GUARDED_BY(mu_) = 0; + + std::unique_ptr<thread::ThreadPool> thread_pool_; + std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_); + + // Identifies whether background activity should be cancelled. + bool cancelled_ GUARDED_BY(mu_) = false; + }; + + const DatasetBase* const input_; + const NameAttrList interleave_func_; + const std::unique_ptr<CapturedFunction> captured_func_; + const int64 cycle_length_; + const int64 block_length_; + const int64 num_parallel_calls_; + const DataTypeVector output_types_; + const std::vector<PartialTensorShape> output_shapes_; + }; + + DataTypeVector output_types_; + std::vector<PartialTensorShape> output_shapes_; + NameAttrList interleave_func_; +}; + +REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDatasetV2").Device(DEVICE_CPU), + ParallelInterleaveDatasetV2Op); + } // namespace } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index 1a5ad8f421..145f4941c8 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -326,6 +326,19 @@ REGISTER_OP("ParallelInterleaveDataset") .Attr("output_shapes: list(shape) >= 1") .SetShapeFn(shape_inference::ScalarShape); 
+REGISTER_OP("ParallelInterleaveDatasetV2") + .Input("input_dataset: variant") + .Input("other_arguments: Targuments") + .Input("cycle_length: int64") + .Input("block_length: int64") + .Input("num_parallel_calls: int64") + .Output("handle: variant") + .Attr("f: func") + .Attr("Targuments: list(type) >= 0") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape); + REGISTER_OP("GroupByReducerDataset") .Input("input_dataset: variant") .Input("key_func_other_arguments: Tkey_func_other_arguments") diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD index 23c98247bf..5cd1484084 100644 --- a/tensorflow/python/data/kernel_tests/BUILD +++ b/tensorflow/python/data/kernel_tests/BUILD @@ -137,6 +137,8 @@ tf_py_test( size = "small", srcs = ["interleave_dataset_op_test.py"], additional_deps = [ + "@absl_py//absl/testing:parameterized", + "//third_party/py/numpy", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", diff --git a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py b/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py index 7dbf7268d7..a35cee594a 100644 --- a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py @@ -19,8 +19,10 @@ from __future__ import print_function import itertools +from absl.testing import parameterized +import numpy as np + from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops @@ -28,7 +30,7 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test -class InterleaveDatasetTest(test.TestCase): +class InterleaveDatasetTest(test.TestCase, 
parameterized.TestCase): def _interleave(self, lists, cycle_length, block_length): num_open = 0 @@ -97,84 +99,85 @@ class InterleaveDatasetTest(test.TestCase): expected_elements, self._interleave(input_lists, 7, 2)): self.assertEqual(expected, produced) - def testInterleaveDataset(self): - input_values = array_ops.placeholder(dtypes.int64, shape=[None]) - cycle_length = array_ops.placeholder(dtypes.int64, shape=[]) - block_length = array_ops.placeholder(dtypes.int64, shape=[]) - - repeat_count = 2 - - dataset = ( - dataset_ops.Dataset.from_tensor_slices(input_values) - .repeat(repeat_count) - .interleave(lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x), - cycle_length, block_length)) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - next_element = iterator.get_next() + @parameterized.named_parameters( + ("1", np.int64([4, 5, 6]), 1, 3, None), + ("2", np.int64([4, 5, 6]), 1, 3, 1), + ("3", np.int64([4, 5, 6]), 2, 1, None), + ("4", np.int64([4, 5, 6]), 2, 1, 1), + ("5", np.int64([4, 5, 6]), 2, 1, 2), + ("6", np.int64([4, 5, 6]), 2, 3, None), + ("7", np.int64([4, 5, 6]), 2, 3, 1), + ("8", np.int64([4, 5, 6]), 2, 3, 2), + ("9", np.int64([4, 5, 6]), 7, 2, None), + ("10", np.int64([4, 5, 6]), 7, 2, 1), + ("11", np.int64([4, 5, 6]), 7, 2, 3), + ("12", np.int64([4, 5, 6]), 7, 2, 5), + ("13", np.int64([4, 5, 6]), 7, 2, 7), + ("14", np.int64([]), 2, 3, None), + ("15", np.int64([0, 0, 0]), 2, 3, None), + ("16", np.int64([4, 0, 6]), 2, 3, None), + ("17", np.int64([4, 0, 6]), 2, 3, 1), + ("18", np.int64([4, 0, 6]), 2, 3, 2), + ) + def testInterleaveDataset(self, input_values, cycle_length, block_length, + num_parallel_calls): + count = 2 + dataset = dataset_ops.Dataset.from_tensor_slices(input_values).repeat( + count).interleave( + lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x), + cycle_length, block_length, num_parallel_calls) + get_next = dataset.make_one_shot_iterator().get_next() + + def repeat(values, count): + result = 
[] + for value in values: + result.append([value] * value) + return result * count with self.test_session() as sess: - # Cycle length 1 acts like `Dataset.flat_map()`. - sess.run(init_op, feed_dict={input_values: [4, 5, 6], - cycle_length: 1, block_length: 3}) - - for expected_element in self._interleave( - [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 1, 3): - self.assertEqual(expected_element, sess.run(next_element)) - - # Cycle length > 1. - # expected: [4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 5, - # 6, 5, 6, 5, 6, 5, 6, 5] - sess.run(init_op, feed_dict={input_values: [4, 5, 6], - cycle_length: 2, block_length: 1}) for expected_element in self._interleave( - [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 1): - self.assertEqual(expected_element, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Cycle length > 1 and block length > 1. - # expected: [4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6, 4, 5, - # 5, 5, 6, 6, 6, 5, 5, 6, 6, 6] - sess.run(init_op, feed_dict={input_values: [4, 5, 6], - cycle_length: 2, block_length: 3}) - for expected_element in self._interleave( - [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 3): - self.assertEqual(expected_element, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Cycle length > len(input_values) * repeat_count. - # expected: [4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, - # 5, 5, 6, 6, 5, 6, 6, 5, 6, 6] - sess.run(init_op, feed_dict={input_values: [4, 5, 6], - cycle_length: 7, block_length: 2}) - for expected_element in self._interleave( - [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 7, 2): - self.assertEqual(expected_element, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Empty input. 
- sess.run(init_op, feed_dict={input_values: [], - cycle_length: 2, block_length: 3}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) + repeat(input_values, count), cycle_length, block_length): + self.assertEqual(expected_element, sess.run(get_next)) + + for _ in range(2): + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + @parameterized.named_parameters( + ("1", np.float32([1., np.nan, 2., np.nan, 3.]), 1, 3, None), + ("2", np.float32([1., np.nan, 2., np.nan, 3.]), 1, 3, 1), + ("3", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 1, None), + ("4", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 1, 1), + ("5", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 1, 2), + ("6", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 3, None), + ("7", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 3, 1), + ("8", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 3, 2), + ("9", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, None), + ("10", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 1), + ("11", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 3), + ("12", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 5), + ("13", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 7), + ) + def testInterleaveErrorDataset(self, + input_values, + cycle_length, + block_length, + num_parallel_calls): + dataset = dataset_ops.Dataset.from_tensor_slices(input_values).map( + lambda x: array_ops.check_numerics(x, "message")).interleave( + dataset_ops.Dataset.from_tensors, cycle_length, block_length, + num_parallel_calls) + get_next = dataset.make_one_shot_iterator().get_next() - # Non-empty input leading to empty output. - sess.run(init_op, feed_dict={input_values: [0, 0, 0], - cycle_length: 2, block_length: 3}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Mixture of non-empty and empty interleaved datasets. 
- sess.run(init_op, feed_dict={input_values: [4, 0, 6], - cycle_length: 2, block_length: 3}) - for expected_element in self._interleave( - [[4] * 4, [], [6] * 6] * repeat_count, 2, 3): - self.assertEqual(expected_element, sess.run(next_element)) + with self.test_session() as sess: + for value in input_values: + if np.isnan(value): + with self.assertRaises(errors.InvalidArgumentError): + sess.run(get_next) + else: + self.assertEqual(value, sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) + sess.run(get_next) def testSparse(self): @@ -201,20 +204,6 @@ class InterleaveDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - def testEmptyInput(self): - iterator = ( - dataset_ops.Dataset.from_tensor_slices([]) - .repeat(None) - .interleave(dataset_ops.Dataset.from_tensors, cycle_length=2) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 6205ee392e..2c1aa22116 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -1019,7 +1019,11 @@ class Dataset(object): """ return FlatMapDataset(self, map_func) - def interleave(self, map_func, cycle_length, block_length=1): + def interleave(self, + map_func, + cycle_length, + block_length=1, + num_parallel_calls=None): """Maps `map_func` across this dataset, and interleaves the results. For example, you can use `Dataset.interleave()` to process many input files @@ -1082,11 +1086,19 @@ class Dataset(object): processed concurrently. block_length: The number of consecutive elements to produce from each input element before cycling to another input element. 
+ num_parallel_calls: (Optional.) If specified, the implementation creates + a threadpool, which is used to fetch inputs from cycle elements + asynchronously and in parallel. The default behavior is to fetch inputs + from cycle elements synchronously with no parallelism. Returns: Dataset: A `Dataset`. """ - return InterleaveDataset(self, map_func, cycle_length, block_length) + if num_parallel_calls is None: + return InterleaveDataset(self, map_func, cycle_length, block_length) + else: + return ParallelInterleaveDataset(self, map_func, cycle_length, + block_length, num_parallel_calls) def filter(self, predicate): """Filters this dataset according to `predicate`. @@ -2330,6 +2342,36 @@ class InterleaveDataset(FlatMapDataset): return "Dataset.interleave()" +class ParallelInterleaveDataset(FlatMapDataset): + """A `Dataset` that maps a function over its input and interleaves the result. + + """ + + def __init__(self, input_dataset, map_func, cycle_length, block_length, + num_parallel_calls): + """See `Dataset.interleave()` for details.""" + super(ParallelInterleaveDataset, self).__init__(input_dataset, map_func) + self._cycle_length = ops.convert_to_tensor( + cycle_length, dtype=dtypes.int64, name="cycle_length") + self._block_length = ops.convert_to_tensor( + block_length, dtype=dtypes.int64, name="block_length") + self._num_parallel_calls = ops.convert_to_tensor( + num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls") + + def _as_variant_tensor(self): + return gen_dataset_ops.parallel_interleave_dataset_v2( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._map_func.captured_inputs, # pylint: disable=protected-access + self._cycle_length, + self._block_length, + self._num_parallel_calls, + f=self._map_func, # pylint: disable=protected-access + **flat_structure(self)) + + def _transformation_name(self): + return "Dataset.interleave()" + + class FilterDataset(Dataset): """A `Dataset` that filters its input according to a 
predicate function.""" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt index 834f0954d5..87745420ee 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt @@ -60,7 +60,7 @@ tf_class { } member_method { name: "interleave" - argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], " + argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], " } member_method { name: "list_files" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt index 4d854a4cee..6dd46365b0 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt @@ -61,7 +61,7 @@ tf_class { } member_method { name: "interleave" - argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], " + argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], " } member_method { name: "list_files" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt index 601f095a60..35b7105eba 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt @@ -61,7 +61,7 @@ tf_class { } member_method { name: "interleave" - argspec: "args=[\'self\', \'map_func\', \'cycle_length\', 
\'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], " + argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], " } member_method { name: "list_files" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt index 587829a4c0..8ae370af98 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt @@ -61,7 +61,7 @@ tf_class { } member_method { name: "interleave" - argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], " + argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], " } member_method { name: "list_files" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt index 834f0954d5..87745420ee 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt @@ -60,7 +60,7 @@ tf_class { } member_method { name: "interleave" - argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], " + argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], " } member_method { name: "list_files" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt index 4d854a4cee..6dd46365b0 100644 --- 
a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt @@ -61,7 +61,7 @@ tf_class { } member_method { name: "interleave" - argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], " + argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], " } member_method { name: "list_files" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt index 601f095a60..35b7105eba 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt @@ -61,7 +61,7 @@ tf_class { } member_method { name: "interleave" - argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], " + argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], " } member_method { name: "list_files" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt index 587829a4c0..8ae370af98 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt @@ -61,7 +61,7 @@ tf_class { } member_method { name: "interleave" - argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], " + argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, 
defaults=[\'1\', \'None\'], " } member_method { name: "list_files" |