aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/data/cache_dataset_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/data/cache_dataset_ops.cc')
-rw-r--r--tensorflow/core/kernels/data/cache_dataset_ops.cc18
1 files changed, 10 insertions, 8 deletions
diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc
index 3762e403a9..34c6c86538 100644
--- a/tensorflow/core/kernels/data/cache_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level description of
@@ -46,11 +46,11 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class FileDataset : public GraphDatasetBase {
+ class FileDataset : public DatasetBase {
public:
explicit FileDataset(OpKernelContext* ctx, const DatasetBase* input,
string filename, Env* env)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
input_(input),
filename_(std::move(filename)),
env_(env),
@@ -69,7 +69,7 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
- new FileIterator({this, strings::StrCat(prefix, "::FileIterator")}));
+ new FileIterator({this, strings::StrCat(prefix, "::FileCache")}));
}
const DataTypeVector& output_dtypes() const override {
@@ -539,10 +539,12 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
const string tensor_format_string_;
}; // FileDataset
- class MemoryDataset : public GraphDatasetBase {
+ class MemoryDataset : public DatasetBase {
public:
explicit MemoryDataset(OpKernelContext* ctx, const DatasetBase* input)
- : GraphDatasetBase(ctx), input_(input), cache_(new MemoryCache()) {
+ : DatasetBase(DatasetContext(ctx)),
+ input_(input),
+ cache_(new MemoryCache()) {
input->Ref();
}
@@ -551,7 +553,7 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new MemoryIterator(
- {this, strings::StrCat(prefix, "::MemoryIterator")}, cache_));
+ {this, strings::StrCat(prefix, "::MemoryCache")}, cache_));
}
const DataTypeVector& output_dtypes() const override {
@@ -889,5 +891,5 @@ REGISTER_KERNEL_BUILDER(Name("CacheDataset").Device(DEVICE_CPU),
CacheDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow