diff options
Diffstat (limited to 'tensorflow/core/kernels/data/cache_dataset_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/data/cache_dataset_ops.cc | 18 |
1 files changed, 10 insertions, 8 deletions
diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc index 3762e403a9..34c6c86538 100644 --- a/tensorflow/core/kernels/data/cache_dataset_ops.cc +++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level description of @@ -46,11 +46,11 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { } private: - class FileDataset : public GraphDatasetBase { + class FileDataset : public DatasetBase { public: explicit FileDataset(OpKernelContext* ctx, const DatasetBase* input, string filename, Env* env) - : GraphDatasetBase(ctx), + : DatasetBase(DatasetContext(ctx)), input_(input), filename_(std::move(filename)), env_(env), @@ -69,7 +69,7 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr<IteratorBase> MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr<IteratorBase>( - new FileIterator({this, strings::StrCat(prefix, "::FileIterator")})); + new FileIterator({this, strings::StrCat(prefix, "::FileCache")})); } const DataTypeVector& output_dtypes() const override { @@ -539,10 +539,12 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { const string tensor_format_string_; }; // FileDataset - class MemoryDataset : public GraphDatasetBase { + class MemoryDataset : public DatasetBase { public: explicit MemoryDataset(OpKernelContext* ctx, const DatasetBase* input) - : GraphDatasetBase(ctx), input_(input), cache_(new MemoryCache()) { + : DatasetBase(DatasetContext(ctx)), + input_(input), + cache_(new MemoryCache()) { input->Ref(); } @@ -551,7 +553,7 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr<IteratorBase> MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr<IteratorBase>(new MemoryIterator( - {this, strings::StrCat(prefix, "::MemoryIterator")}, cache_)); + {this, strings::StrCat(prefix, "::MemoryCache")}, cache_)); } const DataTypeVector& output_dtypes() const override { @@ -889,5 +891,5 @@ REGISTER_KERNEL_BUILDER(Name("CacheDataset").Device(DEVICE_CPU), CacheDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow |