diff options
Diffstat (limited to 'tensorflow/core/kernels/take_dataset_op.cc')
-rw-r--r-- | tensorflow/core/kernels/take_dataset_op.cc | 59 |
1 files changed, 4 insertions, 55 deletions
diff --git a/tensorflow/core/kernels/take_dataset_op.cc b/tensorflow/core/kernels/take_dataset_op.cc index fb294a96b1..c3f33d663c 100644 --- a/tensorflow/core/kernels/take_dataset_op.cc +++ b/tensorflow/core/kernels/take_dataset_op.cc @@ -35,14 +35,14 @@ class TakeDatasetOp : public UnaryDatasetOpKernel { // Create a new TakeDatasetOp::Dataset, and return it as the output. int64 count; OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "count", &count)); - *output = new Dataset(ctx, count, input); + *output = new Dataset(count, input); } private: - class Dataset : public GraphDatasetBase { + class Dataset : public DatasetBase { public: - Dataset(OpKernelContext* ctx, int64 count, const DatasetBase* input) - : GraphDatasetBase(ctx), count_(count), input_(input) { + Dataset(int64 count, const DatasetBase* input) + : count_(count), input_(input) { input_->Ref(); } @@ -72,18 +72,6 @@ class TakeDatasetOp : public UnaryDatasetOpKernel { string DebugString() override { return "TakeDatasetOp::Dataset"; } - protected: - Status AsGraphDefInternal(DatasetGraphDefBuilder* b, - Node** output) const override { - Node* input_graph_node = nullptr; - TF_RETURN_IF_ERROR(b->AddParentDataset(input_, &input_graph_node)); - Node* count = nullptr; - TF_RETURN_IF_ERROR(b->AddScalar(count_, &count)); - TF_RETURN_IF_ERROR( - b->AddDataset(this, {input_graph_node, count}, output)); - return Status::OK(); - } - private: class EmptyIterator : public DatasetIterator<Dataset> { public: @@ -95,16 +83,6 @@ class TakeDatasetOp : public UnaryDatasetOpKernel { *end_of_sequence = true; return Status::OK(); } - - protected: - Status SaveInternal(IteratorStateWriter* writer) override { - return Status::OK(); - } - - Status RestoreInternal(OpKernelContext* ctx, - IteratorStateReader* reader) override { - return Status::OK(); - } }; class FiniteIterator : public DatasetIterator<Dataset> { @@ -118,10 +96,6 @@ class TakeDatasetOp : public UnaryDatasetOpKernel { std::vector<Tensor>* out_tensors, bool* end_of_sequence) override { mutex_lock l(mu_); // TODO(mrry): Make locking less conservative. - if (!input_impl_) { - *end_of_sequence = true; - return Status::OK(); - } while (i_ < dataset()->count_) { TF_RETURN_IF_ERROR( input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); @@ -136,31 +110,6 @@ class TakeDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } - protected: - Status SaveInternal(IteratorStateWriter* writer) override { - mutex_lock l(mu_); - TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_)); - if (input_impl_) { - TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); - } else { - TF_RETURN_IF_ERROR( - writer->WriteScalar(full_name("input_impl_empty"), "")); - } - return Status::OK(); - } - - Status RestoreInternal(OpKernelContext* ctx, - IteratorStateReader* reader) override { - mutex_lock l(mu_); - TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_)); - if (!reader->Contains(full_name("input_impl_empty"))) { - TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); - } else { - input_impl_.reset(); - } - return Status::OK(); - } - private: mutex mu_; int64 i_ GUARDED_BY(mu_); |