From c9328e51b72f9f906364a523926abdc62095ffe0 Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Fri, 3 Aug 2018 13:59:52 -0700 Subject: [tf.data] Add checkpointing for memory-based `cache()`. PiperOrigin-RevId: 207320100 --- .../data/python/kernel_tests/serialization/BUILD | 1 + .../cache_dataset_serialization_test.py | 189 ++++++---- tensorflow/core/kernels/data/cache_dataset_ops.cc | 392 ++++++++++++++++----- .../data/kernel_tests/cache_dataset_op_test.py | 2 +- 4 files changed, 423 insertions(+), 161 deletions(-) diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD index 3c3f23f9a9..7b9ea191a4 100644 --- a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD @@ -56,6 +56,7 @@ py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:errors", "//tensorflow/python/data/ops:dataset_ops", + "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py index a0a1100893..1b6059ccbc 100644 --- a/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py @@ -19,6 +19,8 @@ from __future__ import print_function import os +from absl.testing import parameterized + from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors @@ -26,7 +28,8 @@ from tensorflow.python.platform import test class CacheDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): + dataset_serialization_test_base.DatasetSerializationTestBase, + parameterized.TestCase): def setUp(self): self.range_size = 10 @@ -34,88 +37,123 @@ class CacheDatasetSerializationTest( self.num_outputs = self.range_size * self.num_repeats self.cache_file_prefix = 'test' - def ds_fn(self): - return dataset_ops.Dataset.range(self.range_size).cache( - os.path.join(self.get_temp_dir(), - self.cache_file_prefix)).repeat(self.num_repeats) + def make_dataset_fn(self, is_memory): + if is_memory: + filename = '' + else: + filename = os.path.join(self.get_temp_dir(), self.cache_file_prefix) + + def ds_fn(): + return dataset_ops.Dataset.range(self.range_size).cache(filename).repeat( + self.num_repeats) + + return ds_fn def expected_outputs(self): return list(range(self.range_size)) * self.num_repeats - def testCheckpointBeforeOneEpoch(self): + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testCheckpointBeforeOneEpoch(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + # Generate 5 entries from iterator and save checkpoint. - outputs = self.gen_outputs(self.ds_fn, [], 5, verify_exhausted=False) + outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False) self.assertSequenceEqual(outputs, range(5)) # Restore from checkpoint and produce the rest of the elements from the # iterator. 
outputs.extend( self.gen_outputs( - self.ds_fn, [], + ds_fn, [], self.num_outputs - 5, ckpt_saved=True, verify_exhausted=False)) self.assertSequenceEqual(outputs, self.expected_outputs()) - def testCheckpointBeforeOneEpochThenRunFewSteps(self): - # Generate 8 entries from iterator but save checkpoint after producing - # 5. + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testCheckpointBeforeOneEpochThenRunFewSteps(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + + # Generate 8 entries from iterator but save checkpoint after producing 5. outputs = self.gen_outputs( - self.ds_fn, [5], - 8, - verify_exhausted=False, - save_checkpoint_at_end=False) + ds_fn, [5], 8, verify_exhausted=False, save_checkpoint_at_end=False) self.assertSequenceEqual(outputs, range(8)) - # Restoring from checkpoint and running GetNext should return a - # `AlreadExistsError` now because the lockfile already exists. - with self.assertRaises(errors.AlreadyExistsError): - self.gen_outputs( - self.ds_fn, [], - self.num_outputs - 5, - ckpt_saved=True, - verify_exhausted=False) + if is_memory: + outputs = outputs[:5] + outputs.extend( + self.gen_outputs( + ds_fn, [], + self.num_outputs - 5, + ckpt_saved=True, + verify_exhausted=False)) + self.assertSequenceEqual(outputs, self.expected_outputs()) + else: + # Restoring from checkpoint and running GetNext should return + # `AlreadExistsError` now because the lockfile already exists. + with self.assertRaises(errors.AlreadyExistsError): + self.gen_outputs( + ds_fn, [], + self.num_outputs - 5, + ckpt_saved=True, + verify_exhausted=False) + + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testCheckpointAfterOneEpoch(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) - def testCheckpointAfterOneEpoch(self): # Generate 15 entries from iterator and save checkpoint. - outputs = self.gen_outputs(self.ds_fn, [], 15, verify_exhausted=False) + outputs = self.gen_outputs(ds_fn, [], 15, verify_exhausted=False) self.assertSequenceEqual(outputs, list(range(10)) + list(range(5))) # Restore from checkpoint and produce the rest of the elements from the # iterator. outputs.extend( self.gen_outputs( - self.ds_fn, [], + ds_fn, [], self.num_outputs - 15, ckpt_saved=True, verify_exhausted=False)) self.assertSequenceEqual(outputs, self.expected_outputs()) - def testCheckpointAfterOneEpochThenRunFewSteps(self): - # Generate 18 entries from iterator but save checkpoint after producing - # 15. + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testCheckpointAfterOneEpochThenRunFewSteps(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + + # Generate 18 entries from iterator but save checkpoint after producing 15. outputs = self.gen_outputs( - self.ds_fn, [15], - 18, - verify_exhausted=False, - save_checkpoint_at_end=False) + ds_fn, [15], 18, verify_exhausted=False, save_checkpoint_at_end=False) self.assertSequenceEqual(outputs, list(range(10)) + list(range(8))) outputs = list(range(10)) + list(range(5)) + self.gen_outputs( - self.ds_fn, [], + ds_fn, [], self.num_outputs - 15, ckpt_saved=True, verify_exhausted=False) self.assertSequenceEqual(outputs, list(range(10)) * 3) - def testCheckpointBeforeOneEpochButRunCompleteEpoch(self): - # Generate 13 entries from iterator but save checkpoint after producing - # 5. 
+ @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testCheckpointBeforeOneEpochButRunCompleteEpoch(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + + # Generate 13 entries from iterator but save checkpoint after producing 5. outputs = self.gen_outputs( - self.ds_fn, [5], - 13, - verify_exhausted=False, - save_checkpoint_at_end=False) + ds_fn, [5], 13, verify_exhausted=False, save_checkpoint_at_end=False) self.assertSequenceEqual(outputs, list(range(10)) + list(range(3))) # Since we ran for more than one epoch, the cache was completely written. @@ -124,65 +162,90 @@ class CacheDatasetSerializationTest( # been completely written. outputs = list(range(5)) + self.gen_outputs( - self.ds_fn, [], + ds_fn, [], self.num_outputs - 5, ckpt_saved=True, verify_exhausted=False) self.assertSequenceEqual(outputs, list(range(10)) * 3) - def testCheckpointUnusedWriterIterator(self): + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testCheckpointUnusedWriterIterator(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + # Checkpoint before get_next is called even once. - outputs = self.gen_outputs(self.ds_fn, [], 0, verify_exhausted=False) + outputs = self.gen_outputs(ds_fn, [], 0, verify_exhausted=False) self.assertSequenceEqual(outputs, []) outputs = self.gen_outputs( - self.ds_fn, [], - self.num_outputs, - ckpt_saved=True, - verify_exhausted=False) + ds_fn, [], self.num_outputs, ckpt_saved=True, verify_exhausted=False) self.assertSequenceEqual(outputs, list(range(10)) * 3) - def testCheckpointUnusedMidwayWriterIterator(self): + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testCheckpointUnusedMidwayWriterIterator(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + # Produce 5 elements and checkpoint. - outputs = self.gen_outputs(self.ds_fn, [], 5, verify_exhausted=False) + outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False) self.assertSequenceEqual(outputs, range(5)) # Restore from checkpoint, then produce no elements and checkpoint. outputs.extend( - self.gen_outputs( - self.ds_fn, [], 0, ckpt_saved=True, verify_exhausted=False)) + self.gen_outputs(ds_fn, [], 0, ckpt_saved=True, verify_exhausted=False)) self.assertSequenceEqual(outputs, range(5)) # Restore from checkpoint and produce rest of the elements. outputs.extend( self.gen_outputs( - self.ds_fn, [], + ds_fn, [], self.num_outputs - 5, ckpt_saved=True, verify_exhausted=False)) self.assertSequenceEqual(outputs, list(range(10)) * 3) - def testUnusedCheckpointError(self): + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testUnusedCheckpointError(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + # Produce 5 elements and save ckpt. - outputs = self.gen_outputs(self.ds_fn, [], 5, verify_exhausted=False) + outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False) self.assertSequenceEqual(outputs, range(5)) - # Since the complete cache has not been written, a new iterator which does - # not restore the checkpoint will throw an error since there is a partial - # cache shard. 
- with self.assertRaises(errors.AlreadyExistsError): + if is_memory: outputs = self.gen_outputs( - self.ds_fn, [], self.num_outputs, verify_exhausted=False) + ds_fn, [], self.num_outputs, verify_exhausted=False) + self.assertSequenceEqual(outputs, self.expected_outputs()) + else: + # Since the complete cache has not been written, a new iterator which does + # not restore the checkpoint will throw an error since there is a partial + # cache shard. + with self.assertRaises(errors.AlreadyExistsError): + outputs = self.gen_outputs( + ds_fn, [], self.num_outputs, verify_exhausted=False) + + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testIgnoreCheckpointIfCacheWritten(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) - def testIgnoreCheckpointIfCacheWritten(self): # Produce 15 elements and save ckpt. This will write the complete cache. - outputs = self.gen_outputs(self.ds_fn, [], 15, verify_exhausted=False) + outputs = self.gen_outputs(ds_fn, [], 15, verify_exhausted=False) self.assertSequenceEqual(outputs, list(range(10)) + list(range(5))) # Build the iterator again but do not restore from ckpt. Since the cache # has already been written we should be able to use it. outputs = self.gen_outputs( - self.ds_fn, [], self.num_outputs, verify_exhausted=False) + ds_fn, [], self.num_outputs, verify_exhausted=False) self.assertSequenceEqual(outputs, list(range(10)) * 3) diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc index ed4932bf32..86b0840aea 100644 --- a/tensorflow/core/kernels/data/cache_dataset_ops.cc +++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc @@ -39,7 +39,7 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { ParseScalarArgument(ctx, "filename", &filename)); if (filename.empty()) { - *output = new MemoryDataset(input); + *output = new MemoryDataset(ctx, input); } else { *output = new FileDataset(ctx, input, filename, ctx->env()); } @@ -68,8 +68,8 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr MakeIteratorInternal( const string& prefix) const override { - return std::unique_ptr(new FileCacheIterator( - {this, strings::StrCat(prefix, "::FileCacheIterator")})); + return std::unique_ptr( + new FileIterator({this, strings::StrCat(prefix, "::FileIterator")})); } const DataTypeVector& output_dtypes() const override { @@ -105,9 +105,9 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { tensor_index); } - class FileCacheIterator : public DatasetIterator { + class FileIterator : public DatasetIterator { public: - explicit FileCacheIterator(const Params& params) + explicit FileIterator(const Params& params) : DatasetIterator(params) { if (params.dataset->env_ ->FileExists(MetaFilename(params.dataset->filename_)) @@ -526,7 +526,7 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { enum Mode { read, write }; Mode mode_ GUARDED_BY(mu_); std::unique_ptr iterator_ GUARDED_BY(mu_); - }; // FileCacheIterator + }; // FileIterator const DatasetBase* const input_; const string filename_; @@ -538,9 +538,10 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { const string tensor_format_string_; }; // FileDataset - class MemoryDataset : public DatasetBase { + class MemoryDataset : public GraphDatasetBase { public: - explicit MemoryDataset(const DatasetBase* input) : input_(input) { + explicit MemoryDataset(OpKernelContext* ctx, const DatasetBase* input) + : GraphDatasetBase(ctx), input_(input), cache_(new MemoryCache()) { input->Ref(); } @@ -548,18 +549,8 
@@ class CacheDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr MakeIteratorInternal( const string& prefix) const override { - mutex_lock l(mu_); - if (cache_) { - return std::unique_ptr(new MemoryReaderIterator( - {this, strings::StrCat(prefix, "::MemoryReader")}, cache_.get())); - } - if (!writer_iterator_created_) { - writer_iterator_created_ = true; - return std::unique_ptr(new MemoryWriterIterator( - {this, strings::StrCat(prefix, "::MemoryWriter")})); - } - return std::unique_ptr(new DuplicateWriterIterator( - {this, strings::StrCat(prefix, "::DuplicateWriter")})); + return std::unique_ptr(new MemoryIterator( + {this, strings::StrCat(prefix, "::MemoryIterator")}, cache_)); } const DataTypeVector& output_dtypes() const override { @@ -574,114 +565,321 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { return "CacheDatasetOp::MemoryDataset"; } + protected: + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Node** output) const override { + Node* input_node = nullptr; + TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_node)); + Node* filename_node = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(string(""), &filename_node)); + TF_RETURN_IF_ERROR( + b->AddDataset(this, {input_node, filename_node}, output)); + return Status::OK(); + } + private: - // MemoryWriterIterator passes through and appends items from the input - // dataset to its vector. + // A thread-safe data structure for caching dataset elements. // - // This iterator is used when dataset->cache_ is null. After buffering - // the tensors in memory, upon exhausing the underlying iterator, they are - // updated into the parent dataset's cache_ pointer. - class MemoryWriterIterator : public DatasetIterator { + // The expected use is that a single `MemoryWriterIterator` populates the + // cache with dataset elements. Once all elements are cached, the cache can + // be used by one or more `MemoryReaderIterator`s. + class MemoryCache { public: - explicit MemoryWriterIterator(const Params& params) - : DatasetIterator(params), - cache_(new std::vector>) {} + MemoryCache() = default; - ~MemoryWriterIterator() override { + // Marks the cache as completed. + void Complete() { mutex_lock l(mu_); - if (cache_) { - LOG(ERROR) - << "The calling iterator did not fully read the dataset we were " - "attempting to cache. In order to avoid unexpected truncation " - "of the sequence, the current [partially cached] sequence " - "will be dropped. This can occur if you have a sequence " - "similar to `dataset.cache().take(k).repeat()`. Instead, swap " - "the order (i.e. `dataset.take(k).cache().repeat()`)"; - mutex_lock l2(dataset()->mu_); - dataset()->writer_iterator_created_ = false; - } + completed_ = true; } - Status Initialize(IteratorContext* ctx) override { - return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + // Returns whether the cache is claimed. + bool IsClaimed() { + tf_shared_lock l(mu_); + return claimed_; } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + // Returns whether the cache is completed. + bool IsCompleted() { + tf_shared_lock l(mu_); + return completed_; + } + + // Attempts to claim the cache, returning whether the cache was claimed. 
+ bool MaybeClaim() { mutex_lock l(mu_); - TF_RETURN_IF_ERROR( - input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); - if (*end_of_sequence) { - // Guard on cache_ to not crash if GetNext is called a second time - // after *end_of_sequence == true - if (cache_) { - mutex_lock l(dataset()->mu_); - DCHECK(dataset()->writer_iterator_created_); - DCHECK(!dataset()->cache_); - cache_.swap(dataset()->cache_); - } - return Status::OK(); + if (!claimed_) { + claimed_ = true; + return true; } - cache_->emplace_back(*out_tensors); - return Status::OK(); + return false; + } + + // Resets the cache. + void Reset() { + mutex_lock l(mu_); + claimed_ = false; + completed_ = false; + cache_.clear(); + } + + // Returns the element at the given index. + const std::vector& at(int64 index) { + tf_shared_lock l(mu_); + DCHECK(index < cache_.size()); + return cache_[index]; + } + + // Adds the element to the cache. + void emplace_back(std::vector element) { + mutex_lock l(mu_); + cache_.emplace_back(std::move(element)); + } + + // Returns the size of the cache. + size_t size() { + tf_shared_lock l(mu_); + return cache_.size(); } private: mutex mu_; - std::unique_ptr input_impl_ GUARDED_BY(mu_); - std::unique_ptr>> cache_ GUARDED_BY(mu_); - }; // MemoryWriterIterator - - class MemoryReaderIterator : public DatasetIterator { + // Determines whether a writer has claimed the cache. + bool claimed_ GUARDED_BY(mu_) = false; + // Determines whether all elements of the dataset have been cached. + bool completed_ GUARDED_BY(mu_) = false; + std::vector> cache_ GUARDED_BY(mu_); + }; + + class MemoryIterator : public DatasetIterator { public: - explicit MemoryReaderIterator( - const Params& params, const std::vector>* cache) - : DatasetIterator(params), cache_(cache), index_(0) { - CHECK(cache); + explicit MemoryIterator(const Params& params, + const std::shared_ptr& cache) + : DatasetIterator(params), cache_(cache) { + mode_ = cache->MaybeClaim() ? 
Mode::write : Mode::read; + InitializeIterator(); + } + + Status Initialize(IteratorContext* ctx) override { + mutex_lock l(mu_); + if (mode_ == Mode::read && !cache_->IsCompleted()) { + return errors::Internal( + "Cache should only be read after it has been completed."); + } + return iterator_->Initialize(ctx); } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) override { mutex_lock l(mu_); - if (index_ < cache_->size()) { - const std::vector& cache_tensors = (*cache_)[index_]; - out_tensors->insert(out_tensors->begin(), cache_tensors.begin(), - cache_tensors.end()); - index_++; - *end_of_sequence = false; - return Status::OK(); - } else { - *end_of_sequence = true; - return Status::OK(); + return iterator_->GetNext(ctx, out_tensors, end_of_sequence); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("mode"), mode_)); + if (cache_->IsClaimed()) { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("cache_claimed"), "")); + size_t cache_size = cache_->size(); + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("cache_size"), cache_size)); + for (size_t i = 0; i < cache_size; i++) { + auto& element = cache_->at(i); + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name(strings::StrCat("cache[", i, "].size")), + element.size())); + for (size_t j = 0; j < element.size(); ++j) { + TF_RETURN_IF_ERROR(writer->WriteTensor( + full_name(strings::StrCat("cache[", i, "][", j, "]")), + element[j])); + } + } + if (cache_->IsCompleted()) { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("cache_completed"), "")); + } } + return SaveParent(writer, iterator_); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + iterator_.reset(); + cache_->Reset(); + { + int64 temp; + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("mode"), &temp)); + mode_ = static_cast(temp); + } + if (reader->Contains(full_name("cache_claimed"))) { + CHECK(cache_->MaybeClaim()); + size_t cache_size; + { + int64 temp; + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name("cache_size"), &temp)); + cache_size = static_cast(temp); + } + for (size_t i = 0; i < cache_size; ++i) { + std::vector element; + size_t element_size; + { + int64 temp; + TF_RETURN_IF_ERROR(reader->ReadScalar( + full_name(strings::StrCat("cache[", i, "].size")), &temp)); + element_size = static_cast(temp); + } + element.reserve(element_size); + for (size_t j = 0; j < element_size; ++j) { + element.emplace_back(); + TF_RETURN_IF_ERROR(reader->ReadTensor( + full_name(strings::StrCat("cache[", i, "][", j, "]")), + &element.back())); + } + cache_->emplace_back(std::move(element)); + } + if (reader->Contains(full_name("cache_completed"))) { + cache_->Complete(); + } + } + InitializeIterator(); + TF_RETURN_IF_ERROR(iterator_->Initialize(ctx)); + return RestoreParent(ctx, reader, iterator_); } private: - mutex mu_; - const std::vector>* const cache_; - size_t index_ GUARDED_BY(mu_); - }; // MemoryReaderIterator + class MemoryWriterIterator : public DatasetIterator { + public: + explicit MemoryWriterIterator(const Params& params, + const std::shared_ptr& cache) + : DatasetIterator(params), cache_(cache) { + CHECK(cache_); + } - class DuplicateWriterIterator : public DatasetIterator { - public: - explicit DuplicateWriterIterator(const Params& params) - : DatasetIterator(params) {} + ~MemoryWriterIterator() override { + mutex_lock l(mu_); + if (cache_->size() > 
0 && !cache_->IsCompleted()) { + LOG(WARNING) + << "The calling iterator did not fully read the dataset being " + "cached. In order to avoid unexpected truncation of the " + "dataset, the partially cached contents of the dataset" + "will be discarded. This can happen if you have an input " + "pipeline similar to `dataset.cache().take(k).repeat()`. " + "You should use `dataset.take(k).cache().repeat()` instead."; + cache_->Reset(); + } + } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { - return errors::AlreadyExists( - "There appears to be a concurrent caching iterator running."); + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR( + input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); + if (*end_of_sequence) { + cache_->Complete(); + return Status::OK(); + } + cache_->emplace_back(*out_tensors); + return Status::OK(); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + return SaveParent(writer, input_impl_); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + return RestoreParent(ctx, reader, input_impl_); + } + + private: + mutex mu_; + std::unique_ptr input_impl_ GUARDED_BY(mu_); + std::shared_ptr cache_; + }; // MemoryWriterIterator + + class MemoryReaderIterator : public DatasetIterator { + public: + explicit MemoryReaderIterator(const Params& params, + const std::shared_ptr& cache) + : DatasetIterator(params), cache_(cache), index_(0) { + CHECK(cache); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("index"), index_)); + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + { + int64 temp; + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("index"), &temp)); + index_ = static_cast(temp); + } + return Status::OK(); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + if (index_ < cache_->size()) { + const std::vector& cache_tensors = cache_->at(index_); + out_tensors->insert(out_tensors->begin(), cache_tensors.begin(), + cache_tensors.end()); + index_++; + *end_of_sequence = false; + return Status::OK(); + } else { + *end_of_sequence = true; + return Status::OK(); + } + } + + private: + mutex mu_; + const std::shared_ptr cache_; + size_t index_ GUARDED_BY(mu_); + }; // MemoryReaderIterator + + void InitializeIterator() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + switch (mode_) { + case Mode::read: + iterator_.reset( + new MemoryReaderIterator({dataset(), prefix()}, cache_)); + break; + case Mode::write: + iterator_.reset( + new MemoryWriterIterator({dataset(), prefix()}, cache_)); + } } - }; // DuplicateWriterIterator + + mutex mu_; + std::shared_ptr cache_; + enum Mode { read, write }; + Mode mode_ GUARDED_BY(mu_); + std::unique_ptr iterator_ GUARDED_BY(mu_); + }; // MemoryIterator const DatasetBase* const input_; - mutable mutex mu_; - mutable std::unique_ptr>> cache_ - GUARDED_BY(mu_); - mutable bool writer_iterator_created_ GUARDED_BY(mu_) = false; + const std::shared_ptr cache_; }; // MemoryDataset }; // 
CacheDatasetOp diff --git a/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py b/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py index 25269dc810..4f7fd3566e 100644 --- a/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py @@ -34,7 +34,7 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import test -class FilesystemCacheDatasetTest(test.TestCase): +class FileCacheDatasetTest(test.TestCase): def setUp(self): self.tmp_dir = tempfile.mkdtemp() -- cgit v1.2.3
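For reference, the change above makes the in-memory cache part of the iterator's checkpointable state, so a partially (or fully) populated cache survives a save/restore cycle instead of being rebuilt or rejected. Below is a minimal sketch of how this can be driven from user code with the TF 1.x graph API of that era; it assumes the contrib helper `tf.contrib.data.make_saveable_from_iterator` (the same mechanism the serialization test harness above relies on) and uses a hypothetical checkpoint prefix "/tmp/cache_ckpt".

# Minimal sketch, assuming TF 1.x graph mode and the contrib saveable-iterator
# helper; "/tmp/cache_ckpt" is a placeholder checkpoint prefix.
import tensorflow as tf

# An empty (default) filename selects the in-memory cache, mirroring the tests above.
dataset = tf.data.Dataset.range(10).cache().repeat(3)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

# Register the iterator state (which now includes the memory cache) with a Saver.
saveable = tf.contrib.data.make_saveable_from_iterator(iterator)
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
saver = tf.train.Saver()

with tf.Session() as sess:
  sess.run(iterator.initializer)
  for _ in range(5):
    sess.run(next_element)                # consume part of the first epoch
  saver.save(sess, "/tmp/cache_ckpt")     # the partially written cache is serialized too

with tf.Session() as sess:
  sess.run(iterator.initializer)
  saver.restore(sess, "/tmp/cache_ckpt")  # resume with the cached prefix intact
  while True:
    try:
      sess.run(next_element)
    except tf.errors.OutOfRangeError:
      break

The design choice worth noting on the C++ side is that `MemoryIterator::SaveInternal` writes the cached tensors themselves into the checkpoint (one `WriteTensor` call per element component, plus the claimed/completed flags), so a restored iterator does not need to recompute the cached prefix from the input pipeline, and a reader restored mid-epoch only records its `index` into the shared cache.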