diff options
Diffstat (limited to 'tensorflow/core/kernels/data/iterator_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/data/iterator_ops.cc | 28 |
1 files changed, 14 insertions, 14 deletions
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index da9d29dd76..61a6c06135 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -130,7 +130,7 @@ class IteratorResource : public ResourceBase { Status Restore(OpKernelContext* ctx, IteratorStateReader* reader) { string serialized_graph_def; - TF_RETURN_IF_ERROR(reader->ReadScalar(GraphDatasetBase::kDatasetGraphKey, + TF_RETURN_IF_ERROR(reader->ReadScalar(DatasetBase::kDatasetGraphKey, &serialized_graph_def)); GraphDef graph_def; if (!graph_def.ParseFromString(serialized_graph_def)) { @@ -138,7 +138,7 @@ class IteratorResource : public ResourceBase { } string output_node; TF_RETURN_IF_ERROR(reader->ReadScalar( - GraphDatasetBase::kDatasetGraphOutputNodeKey, &output_node)); + DatasetBase::kDatasetGraphOutputNodeKey, &output_node)); DatasetBase* dataset = nullptr; Graph graph(OpRegistry::Global()); TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr)); @@ -161,9 +161,9 @@ class IteratorResource : public ResourceBase { graph_runner.Run(&graph, lib, {}, {output_node}, &outputs)); TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset)); - IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx); std::unique_ptr<IteratorBase> iterator; - TF_RETURN_IF_ERROR(dataset->MakeIterator(&iter_ctx, "Iterator", &iterator)); + TF_RETURN_IF_ERROR( + dataset->MakeIterator(IteratorContext(ctx), "Iterator", &iterator)); TF_RETURN_IF_ERROR(set_iterator(std::move(iterator))); std::shared_ptr<IteratorBase> captured_iterator(iterator_); @@ -611,9 +611,9 @@ void MakeIteratorOp::Compute(OpKernelContext* ctx) { ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &iterator_resource)); core::ScopedUnref unref(iterator_resource); - IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx); std::unique_ptr<IteratorBase> iterator; - OP_REQUIRES_OK(ctx, dataset->MakeIterator(&iter_ctx, "Iterator", &iterator)); + OP_REQUIRES_OK( + ctx, dataset->MakeIterator(IteratorContext(ctx), "Iterator", &iterator)); OP_REQUIRES_OK(ctx, iterator_resource->set_iterator(std::move(iterator))); } @@ -633,11 +633,11 @@ class ToSingleElementOp : public AsyncOpKernel { DatasetBase* dataset; OP_REQUIRES_OK_ASYNC( ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done); - IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx); std::unique_ptr<IteratorBase> iterator; OP_REQUIRES_OK_ASYNC( ctx, - dataset->MakeIterator(&iter_ctx, "SingleElementIterator", &iterator), + dataset->MakeIterator(IteratorContext(ctx), "SingleElementIterator", + &iterator), done); // NOTE(jsimsa): We must destroy the iterator before calling `done()`, to @@ -651,8 +651,8 @@ class ToSingleElementOp : public AsyncOpKernel { components.reserve(dataset->output_dtypes().size()); bool end_of_sequence = false; - Status s = - raw_iterator->GetNext(&iter_ctx, &components, &end_of_sequence); + Status s = raw_iterator->GetNext(IteratorContext(ctx), &components, + &end_of_sequence); if (!s.ok()) { ctx->SetStatus(s); return; @@ -667,8 +667,8 @@ class ToSingleElementOp : public AsyncOpKernel { } components.clear(); - Status s2 = - raw_iterator->GetNext(&iter_ctx, &components, &end_of_sequence); + Status s2 = raw_iterator->GetNext(IteratorContext(ctx), &components, + &end_of_sequence); if (!s2.ok()) { ctx->SetStatus(s2); return; @@ -836,9 +836,9 @@ class OneShotIteratorOp : public AsyncOpKernel { // factory function. DatasetBase* dataset; TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(return_values[0], &dataset)); - IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx); std::unique_ptr<IteratorBase> iter; - TF_RETURN_IF_ERROR(dataset->MakeIterator(&iter_ctx, "Iterator", &iter)); + TF_RETURN_IF_ERROR( + dataset->MakeIterator(IteratorContext(ctx), "Iterator", &iter)); TF_RETURN_IF_ERROR((*iterator)->set_iterator(std::move(iter))); (*iterator)->Ref(); |