aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/take_dataset_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/take_dataset_op.cc')
-rw-r--r--tensorflow/core/kernels/take_dataset_op.cc59
1 files changed, 4 insertions, 55 deletions
diff --git a/tensorflow/core/kernels/take_dataset_op.cc b/tensorflow/core/kernels/take_dataset_op.cc
index fb294a96b1..c3f33d663c 100644
--- a/tensorflow/core/kernels/take_dataset_op.cc
+++ b/tensorflow/core/kernels/take_dataset_op.cc
@@ -35,14 +35,14 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
// Create a new TakeDatasetOp::Dataset, and return it as the output.
int64 count;
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "count", &count));
- *output = new Dataset(ctx, count, input);
+ *output = new Dataset(count, input);
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
- Dataset(OpKernelContext* ctx, int64 count, const DatasetBase* input)
- : GraphDatasetBase(ctx), count_(count), input_(input) {
+ Dataset(int64 count, const DatasetBase* input)
+ : count_(count), input_(input) {
input_->Ref();
}
@@ -72,18 +72,6 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
string DebugString() override { return "TakeDatasetOp::Dataset"; }
- protected:
- Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
- Node** output) const override {
- Node* input_graph_node = nullptr;
- TF_RETURN_IF_ERROR(b->AddParentDataset(input_, &input_graph_node));
- Node* count = nullptr;
- TF_RETURN_IF_ERROR(b->AddScalar(count_, &count));
- TF_RETURN_IF_ERROR(
- b->AddDataset(this, {input_graph_node, count}, output));
- return Status::OK();
- }
-
private:
class EmptyIterator : public DatasetIterator<Dataset> {
public:
@@ -95,16 +83,6 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
*end_of_sequence = true;
return Status::OK();
}
-
- protected:
- Status SaveInternal(IteratorStateWriter* writer) override {
- return Status::OK();
- }
-
- Status RestoreInternal(OpKernelContext* ctx,
- IteratorStateReader* reader) override {
- return Status::OK();
- }
};
class FiniteIterator : public DatasetIterator<Dataset> {
@@ -118,10 +96,6 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_); // TODO(mrry): Make locking less conservative.
- if (!input_impl_) {
- *end_of_sequence = true;
- return Status::OK();
- }
while (i_ < dataset()->count_) {
TF_RETURN_IF_ERROR(
input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
@@ -136,31 +110,6 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
- protected:
- Status SaveInternal(IteratorStateWriter* writer) override {
- mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
- if (input_impl_) {
- TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
- } else {
- TF_RETURN_IF_ERROR(
- writer->WriteScalar(full_name("input_impl_empty"), ""));
- }
- return Status::OK();
- }
-
- Status RestoreInternal(OpKernelContext* ctx,
- IteratorStateReader* reader) override {
- mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
- if (!reader->Contains(full_name("input_impl_empty"))) {
- TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
- } else {
- input_impl_.reset();
- }
- return Status::OK();
- }
-
private:
mutex mu_;
int64 i_ GUARDED_BY(mu_);