diff options
Diffstat (limited to 'tensorflow/core/kernels/zip_dataset_op.cc')
-rw-r--r-- | tensorflow/core/kernels/zip_dataset_op.cc | 63 |
1 files changed, 6 insertions, 57 deletions
diff --git a/tensorflow/core/kernels/zip_dataset_op.cc b/tensorflow/core/kernels/zip_dataset_op.cc index f466c8b268..a80b9edbe4 100644 --- a/tensorflow/core/kernels/zip_dataset_op.cc +++ b/tensorflow/core/kernels/zip_dataset_op.cc @@ -35,15 +35,14 @@ class ZipDatasetOp : public DatasetOpKernel { OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(i), &input)); inputs.push_back(input); } - *output = new Dataset(ctx, inputs); + *output = new Dataset(inputs); } private: - class Dataset : public GraphDatasetBase { + class Dataset : public DatasetBase { public: - explicit Dataset(OpKernelContext* ctx, - const std::vector<DatasetBase*>& inputs) - : GraphDatasetBase(ctx), inputs_(inputs) { + explicit Dataset(const std::vector<DatasetBase*>& inputs) + : inputs_(inputs) { for (const auto& input : inputs_) { input->Ref(); for (DataType dt : input->output_dtypes()) { @@ -77,21 +76,6 @@ class ZipDatasetOp : public DatasetOpKernel { string DebugString() override { return "ZipDatasetOp::Dataset"; } - protected: - Status AsGraphDefInternal(DatasetGraphDefBuilder* b, - Node** output) const override { - std::vector<NodeBuilder::NodeOut> input_graph_nodes; - input_graph_nodes.reserve(inputs_.size()); - for (const auto& input : inputs_) { - Node* input_node; - TF_RETURN_IF_ERROR(b->AddParentDataset(input, &input_node)); - input_graph_nodes.emplace_back(input_node); - } - TF_RETURN_IF_ERROR( - b->AddDatasetWithInputAsList(this, input_graph_nodes, output)); - return Status::OK(); - } - private: class Iterator : public DatasetIterator<Dataset> { public: @@ -109,10 +93,6 @@ class ZipDatasetOp : public DatasetOpKernel { std::vector<Tensor>* out_tensors, bool* end_of_sequence) override { mutex_lock l(mu_); - if (input_impls_.empty()) { - *end_of_sequence = true; - return Status::OK(); - } out_tensors->clear(); out_tensors->reserve(dataset()->output_dtypes().size()); for (const auto& input_impl : input_impls_) { @@ -120,43 +100,12 @@ class ZipDatasetOp : public DatasetOpKernel { TF_RETURN_IF_ERROR( input_impl->GetNext(ctx, &input_tensors, end_of_sequence)); if (*end_of_sequence) { - break; + return Status::OK(); } out_tensors->insert(out_tensors->end(), input_tensors.begin(), input_tensors.end()); } - if (*end_of_sequence) { - out_tensors->clear(); - input_impls_.clear(); - } else { - *end_of_sequence = false; - } - return Status::OK(); - } - - protected: - Status SaveInternal(IteratorStateWriter* writer) override { - mutex_lock l(mu_); - if (input_impls_.empty()) { - TF_RETURN_IF_ERROR( - writer->WriteScalar(full_name("input_impls_empty"), "")); - } else { - for (auto& input_impl : input_impls_) - TF_RETURN_IF_ERROR(SaveParent(writer, input_impl)); - } - return Status::OK(); - } - - Status RestoreInternal(OpKernelContext* ctx, - IteratorStateReader* reader) override { - mutex_lock l(mu_); - if (reader->Contains(full_name("input_impls_empty"))) { - input_impls_.clear(); - } else { - DCHECK_EQ(input_impls_.size(), dataset()->inputs_.size()); - for (auto& input_impl : input_impls_) - TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl)); - } + *end_of_sequence = false; return Status::OK(); } |