aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/data/iterator_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/data/iterator_ops.cc')
-rw-r--r--tensorflow/core/kernels/data/iterator_ops.cc28
1 files changed, 14 insertions, 14 deletions
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index da9d29dd76..61a6c06135 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -130,7 +130,7 @@ class IteratorResource : public ResourceBase {
Status Restore(OpKernelContext* ctx, IteratorStateReader* reader) {
string serialized_graph_def;
- TF_RETURN_IF_ERROR(reader->ReadScalar(GraphDatasetBase::kDatasetGraphKey,
+ TF_RETURN_IF_ERROR(reader->ReadScalar(DatasetBase::kDatasetGraphKey,
&serialized_graph_def));
GraphDef graph_def;
if (!graph_def.ParseFromString(serialized_graph_def)) {
@@ -138,7 +138,7 @@ class IteratorResource : public ResourceBase {
}
string output_node;
TF_RETURN_IF_ERROR(reader->ReadScalar(
- GraphDatasetBase::kDatasetGraphOutputNodeKey, &output_node));
+ DatasetBase::kDatasetGraphOutputNodeKey, &output_node));
DatasetBase* dataset = nullptr;
Graph graph(OpRegistry::Global());
TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
@@ -161,9 +161,9 @@ class IteratorResource : public ResourceBase {
graph_runner.Run(&graph, lib, {}, {output_node}, &outputs));
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset));
- IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
std::unique_ptr<IteratorBase> iterator;
- TF_RETURN_IF_ERROR(dataset->MakeIterator(&iter_ctx, "Iterator", &iterator));
+ TF_RETURN_IF_ERROR(
+ dataset->MakeIterator(IteratorContext(ctx), "Iterator", &iterator));
TF_RETURN_IF_ERROR(set_iterator(std::move(iterator)));
std::shared_ptr<IteratorBase> captured_iterator(iterator_);
@@ -611,9 +611,9 @@ void MakeIteratorOp::Compute(OpKernelContext* ctx) {
ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &iterator_resource));
core::ScopedUnref unref(iterator_resource);
- IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
std::unique_ptr<IteratorBase> iterator;
- OP_REQUIRES_OK(ctx, dataset->MakeIterator(&iter_ctx, "Iterator", &iterator));
+ OP_REQUIRES_OK(
+ ctx, dataset->MakeIterator(IteratorContext(ctx), "Iterator", &iterator));
OP_REQUIRES_OK(ctx, iterator_resource->set_iterator(std::move(iterator)));
}
@@ -633,11 +633,11 @@ class ToSingleElementOp : public AsyncOpKernel {
DatasetBase* dataset;
OP_REQUIRES_OK_ASYNC(
ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done);
- IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
std::unique_ptr<IteratorBase> iterator;
OP_REQUIRES_OK_ASYNC(
ctx,
- dataset->MakeIterator(&iter_ctx, "SingleElementIterator", &iterator),
+ dataset->MakeIterator(IteratorContext(ctx), "SingleElementIterator",
+ &iterator),
done);
// NOTE(jsimsa): We must destroy the iterator before calling `done()`, to
@@ -651,8 +651,8 @@ class ToSingleElementOp : public AsyncOpKernel {
components.reserve(dataset->output_dtypes().size());
bool end_of_sequence = false;
- Status s =
- raw_iterator->GetNext(&iter_ctx, &components, &end_of_sequence);
+ Status s = raw_iterator->GetNext(IteratorContext(ctx), &components,
+ &end_of_sequence);
if (!s.ok()) {
ctx->SetStatus(s);
return;
@@ -667,8 +667,8 @@ class ToSingleElementOp : public AsyncOpKernel {
}
components.clear();
- Status s2 =
- raw_iterator->GetNext(&iter_ctx, &components, &end_of_sequence);
+ Status s2 = raw_iterator->GetNext(IteratorContext(ctx), &components,
+ &end_of_sequence);
if (!s2.ok()) {
ctx->SetStatus(s2);
return;
@@ -836,9 +836,9 @@ class OneShotIteratorOp : public AsyncOpKernel {
// factory function.
DatasetBase* dataset;
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(return_values[0], &dataset));
- IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
std::unique_ptr<IteratorBase> iter;
- TF_RETURN_IF_ERROR(dataset->MakeIterator(&iter_ctx, "Iterator", &iter));
+ TF_RETURN_IF_ERROR(
+ dataset->MakeIterator(IteratorContext(ctx), "Iterator", &iter));
TF_RETURN_IF_ERROR((*iterator)->set_iterator(std::move(iter)));
(*iterator)->Ref();