diff options
Diffstat (limited to 'tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc')
-rw-r--r-- | tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc | 16 |
1 files changed, 9 insertions, 7 deletions
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc index 6292b4536e..d2b83f9eab 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc @@ -134,11 +134,13 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { - TF_RETURN_IF_ERROR(b->AddFunction(ctx, interleave_func_.name())); + TF_RETURN_IF_ERROR( + b->AddFunction(ctx->flib_def(), interleave_func_.name())); Node* input_node; - TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_node)); + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node)); Node* cycle_length_node; TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node)); Node* block_length_node; @@ -358,7 +360,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { mutex_lock l(mu_); mutex_lock ckpt_l(ckpt_mu_); if (input_impl_) { - TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); } else { TF_RETURN_IF_ERROR( writer->WriteScalar(full_name("input_exhausted"), "")); @@ -402,7 +404,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { mutex_lock l(mu_); mutex_lock ckpt_l(ckpt_mu_); if (!reader->Contains(full_name("input_exhausted"))) { - TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); } else { input_impl_.reset(); } @@ -858,7 +860,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { string prefix = strings::StrCat("worker_thread_", index); if (worker_thread_states_[index].iterator != nullptr) { TF_RETURN_IF_ERROR( - SaveParent(writer, worker_thread_states_[index].iterator)); + SaveInput(writer, worker_thread_states_[index].iterator)); } else { TF_RETURN_IF_ERROR(writer->WriteScalar( full_name(strings::StrCat(prefix, "_iterator_exhausted")), "")); @@ -909,7 +911,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { Status s = dataset::MakeIteratorFromInputElement( ctx, worker_thread_states_[index].input, index, dataset()->captured_func_.get(), prefix(), &iterator); - TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, iterator)); + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, iterator)); worker_thread_states_[index].iterator.swap(iterator); } TF_RETURN_IF_ERROR(ReadStatusLocked( |