aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/zip_dataset_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/zip_dataset_op.cc')
-rw-r--r--tensorflow/core/kernels/zip_dataset_op.cc63
1 files changed, 6 insertions, 57 deletions
diff --git a/tensorflow/core/kernels/zip_dataset_op.cc b/tensorflow/core/kernels/zip_dataset_op.cc
index f466c8b268..a80b9edbe4 100644
--- a/tensorflow/core/kernels/zip_dataset_op.cc
+++ b/tensorflow/core/kernels/zip_dataset_op.cc
@@ -35,15 +35,14 @@ class ZipDatasetOp : public DatasetOpKernel {
OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(i), &input));
inputs.push_back(input);
}
- *output = new Dataset(ctx, inputs);
+ *output = new Dataset(inputs);
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
- explicit Dataset(OpKernelContext* ctx,
- const std::vector<DatasetBase*>& inputs)
- : GraphDatasetBase(ctx), inputs_(inputs) {
+ explicit Dataset(const std::vector<DatasetBase*>& inputs)
+ : inputs_(inputs) {
for (const auto& input : inputs_) {
input->Ref();
for (DataType dt : input->output_dtypes()) {
@@ -77,21 +76,6 @@ class ZipDatasetOp : public DatasetOpKernel {
string DebugString() override { return "ZipDatasetOp::Dataset"; }
- protected:
- Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
- Node** output) const override {
- std::vector<NodeBuilder::NodeOut> input_graph_nodes;
- input_graph_nodes.reserve(inputs_.size());
- for (const auto& input : inputs_) {
- Node* input_node;
- TF_RETURN_IF_ERROR(b->AddParentDataset(input, &input_node));
- input_graph_nodes.emplace_back(input_node);
- }
- TF_RETURN_IF_ERROR(
- b->AddDatasetWithInputAsList(this, input_graph_nodes, output));
- return Status::OK();
- }
-
private:
class Iterator : public DatasetIterator<Dataset> {
public:
@@ -109,10 +93,6 @@ class ZipDatasetOp : public DatasetOpKernel {
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
- if (input_impls_.empty()) {
- *end_of_sequence = true;
- return Status::OK();
- }
out_tensors->clear();
out_tensors->reserve(dataset()->output_dtypes().size());
for (const auto& input_impl : input_impls_) {
@@ -120,43 +100,12 @@ class ZipDatasetOp : public DatasetOpKernel {
TF_RETURN_IF_ERROR(
input_impl->GetNext(ctx, &input_tensors, end_of_sequence));
if (*end_of_sequence) {
- break;
+ return Status::OK();
}
out_tensors->insert(out_tensors->end(), input_tensors.begin(),
input_tensors.end());
}
- if (*end_of_sequence) {
- out_tensors->clear();
- input_impls_.clear();
- } else {
- *end_of_sequence = false;
- }
- return Status::OK();
- }
-
- protected:
- Status SaveInternal(IteratorStateWriter* writer) override {
- mutex_lock l(mu_);
- if (input_impls_.empty()) {
- TF_RETURN_IF_ERROR(
- writer->WriteScalar(full_name("input_impls_empty"), ""));
- } else {
- for (auto& input_impl : input_impls_)
- TF_RETURN_IF_ERROR(SaveParent(writer, input_impl));
- }
- return Status::OK();
- }
-
- Status RestoreInternal(OpKernelContext* ctx,
- IteratorStateReader* reader) override {
- mutex_lock l(mu_);
- if (reader->Contains(full_name("input_impls_empty"))) {
- input_impls_.clear();
- } else {
- DCHECK_EQ(input_impls_.size(), dataset()->inputs_.size());
- for (auto& input_impl : input_impls_)
- TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl));
- }
+ *end_of_sequence = false;
return Status::OK();
}