about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_AnonymousIterator.pbtxt 13
-rw-r--r-- tensorflow/core/api_def/python_api/api_def_AnonymousIterator.pbtxt 4
-rw-r--r-- tensorflow/core/kernels/data/iterator_ops.cc 74
-rw-r--r-- tensorflow/core/ops/dataset_ops.cc 6
-rw-r--r-- tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py 36
-rw-r--r-- tensorflow/python/data/ops/iterator_ops.py 4
6 files changed, 134 insertions, 3 deletions
diff --git a/tensorflow/core/api_def/base_api/api_def_AnonymousIterator.pbtxt b/tensorflow/core/api_def/base_api/api_def_AnonymousIterator.pbtxt
new file mode 100644
index 0000000000..d8c2ed40a3
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_AnonymousIterator.pbtxt
@@ -0,0 +1,13 @@
+op {
+ graph_op_name: "AnonymousIterator"
+ out_arg {
+ name: "handle"
+ description: <<END
+A handle to the iterator that can be passed to a "MakeIterator" or
+"IteratorGetNext" op. In contrast to Iterator, AnonymousIterator prevents
+resource sharing by name, and does not keep a reference to the resource
+container.
+END
+ }
+ summary: "A container for an iterator resource."
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_AnonymousIterator.pbtxt b/tensorflow/core/api_def/python_api/api_def_AnonymousIterator.pbtxt
new file mode 100644
index 0000000000..98b7def4d6
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_AnonymousIterator.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "AnonymousIterator"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index b6bf0ecd09..87bc8ebefe 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -438,6 +438,9 @@ class IteratorStateVariant {
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(IteratorStateVariant,
kIteratorVariantTypeName);
+// Note that IteratorHandleOp holds a reference to the resource it creates. If
+// cleaning up resources with DestroyResourceOp is important, consider creating
+// resource containers with AnonymousIteratorHandleOp instead.
class IteratorHandleOp : public OpKernel {
public:
explicit IteratorHandleOp(OpKernelConstruction* ctx)
@@ -574,6 +577,75 @@ class IteratorHandleOp : public OpKernel {
string name_;
};
+// Like IteratorHandleOp, but creates handles which are never shared, and does
+// not hold a reference to these handles. The latter is important for eager
+// execution, since OpKernel instances generally live as long as the program
+// running them.
+class AnonymousIteratorHandleOp : public OpKernel {
+ public:
+ explicit AnonymousIteratorHandleOp(OpKernelConstruction* context)
+ : OpKernel(context), graph_def_version_(context->graph_def_version()) {
+ OP_REQUIRES_OK(context, context->GetAttr("output_types", &output_dtypes_));  // element dtypes, read once at kernel construction
+ OP_REQUIRES_OK(context, context->GetAttr("output_shapes", &output_shapes_));  // element shapes, read once at kernel construction
+ }
+
+ void Compute(OpKernelContext* context) override {  // creates one fresh IteratorResource per call and emits its handle
+ FunctionLibraryRuntime* lib;  // passed to Clone() below as an out-parameter
+ std::unique_ptr<DeviceMgr> device_mgr(nullptr);  // stays null; still handed to the IteratorResource
+ std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);  // passed to Clone() as an out-parameter, then moved into the resource
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);  // passed to Clone() as an out-parameter, then moved into the resource
+ OP_REQUIRES_OK(context,
+ context->function_library()->Clone(&flib_def, &pflr, &lib));
+
+ ResourceMgr* mgr = context->resource_manager();
+
+ const string container_name = "AnonymousIterator";  // every anonymous iterator is created in this one container
+ string unique_name;
+ {
+ mutex_lock l(static_resource_lookup_mutex_);  // held across both the name search and Create() below
+ while (true) { // Find an unused name
+ IteratorResource* existing_resource = nullptr;
+ unique_name = strings::StrCat("AnonymousIterator", current_id_++);  // post-increment: every attempt consumes one id
+ Status status = mgr->Lookup<IteratorResource>(
+ container_name, unique_name, &existing_resource);
+ if (status.code() == error::NOT_FOUND) {  // nothing registered under this name yet, so use it
+ break;
+ }
+ OP_REQUIRES_OK(context, status);  // any lookup failure other than NOT_FOUND is reported as the op's error
+ existing_resource->Unref();  // name already taken: release the resource Lookup handed back and try the next id
+ }
+ IteratorResource* new_resource = new IteratorResource(
+ output_dtypes_, output_shapes_, graph_def_version_,
+ std::move(device_mgr), std::move(flib_def), std::move(pflr), lib);
+ // Create the resource with our chosen name under the resource lookup
+ // mutex to avoid another kernel racily creating a resource with this
+ // name.
+ OP_REQUIRES_OK(context, mgr->Create<IteratorResource>(
+ container_name, unique_name, new_resource));
+ }
+ OP_REQUIRES_OK(context, MakeResourceHandleToOutput(  // write a handle naming the new resource to output 0
+ context, 0, container_name, unique_name,
+ MakeTypeIndex<IteratorResource>()));
+ }
+
+ private:
+ // Coordinates Iterator unique name creation across AnonymousIteratorHandleOp
+ // instances.
+ static mutex static_resource_lookup_mutex_;
+ // current_id_ is just a hint for creating unique names. If it turns out
+ // there's a collision (e.g. because another AnonymousIteratorHandleOp
+ // instance is generating handles) we'll just skip that id.
+ static int64 current_id_ GUARDED_BY(static_resource_lookup_mutex_);
+ DataTypeVector output_dtypes_;  // value of the "output_types" attr
+ std::vector<PartialTensorShape> output_shapes_;  // value of the "output_shapes" attr
+ const int graph_def_version_;  // captured from the construction context; forwarded to each IteratorResource
+};
+
+// Static initializers for AnonymousIteratorHandleOp id counting.
+mutex AnonymousIteratorHandleOp::static_resource_lookup_mutex_{
+ LINKER_INITIALIZED};
+int64 AnonymousIteratorHandleOp::current_id_(0);
+
class MakeIteratorOp : public OpKernel {
public:
explicit MakeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
@@ -1066,6 +1138,8 @@ class DeserializeIteratorOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("MakeIterator").Device(DEVICE_CPU),
MakeIteratorOp);
+REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE_CPU),
+ AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("DatasetToSingleElement").Device(DEVICE_CPU),
ToSingleElementOp);
REGISTER_KERNEL_BUILDER(Name("OneShotIterator").Device(DEVICE_CPU),
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 576946eddd..6d7d8630a7 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -564,6 +564,12 @@ REGISTER_OP("Iterator")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
+REGISTER_OP("AnonymousIterator")
+ .Output("handle: resource")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
REGISTER_OP("MakeIterator")
.Input("dataset: variant")
.Input("iterator: resource")
diff --git a/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py b/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py
index 1ddedfda4e..e99f0a203b 100644
--- a/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py
@@ -24,6 +24,7 @@ import zlib
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import readers
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -38,6 +39,13 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat
+try:
+ import psutil # pylint: disable=g-import-not-at-top
+ psutil_import_succeeded = True
+except ImportError:
+ psutil_import_succeeded = False
+
+
class TextLineDatasetTest(test.TestCase):
def _lineText(self, f, l):
@@ -162,6 +170,34 @@ class TextLineDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(iterator.get_next())
+ def testIteratorResourceCleanup(self):  # checks that deleting eager iterators closes the files they were reading
+ filename = os.path.join(self.get_temp_dir(), "text.txt")
+ with open(filename, "wt") as f:
+ for i in range(3):  # three lines: "0", "1", "2"
+ f.write("%d\n" % (i,))
+ with context.eager_mode():
+ first_iterator = iter(readers.TextLineDataset(filename))
+ self.assertEqual(b"0", next(first_iterator).numpy())  # read only the first line, leaving the file partially read
+ second_iterator = iter(readers.TextLineDataset(filename))
+ self.assertEqual(b"0", next(second_iterator).numpy())  # a second iterator over the same file starts from the first line again
+ # Eager kernel caching is based on op attributes, which includes the
+ # Dataset's output shape. Create a different kernel to test that they
+ # don't create resources with the same names.
+ different_kernel_iterator = iter(
+ readers.TextLineDataset(filename).repeat().batch(16))
+ self.assertEqual([16], next(different_kernel_iterator).shape)
+ # Remove our references to the Python Iterator objects, which (assuming no
+ # reference cycles) is enough to trigger DestroyResourceOp and close the
+ # partially-read files.
+ del first_iterator
+ del second_iterator
+ del different_kernel_iterator
+ if not psutil_import_succeeded:  # checked only here, so the iterator code above still runs when psutil is missing
+ self.skipTest(
+ "psutil is required to check that we've closed our files.")
+ open_files = psutil.Process().open_files()  # files currently held open by this test process
+ self.assertNotIn(filename, [open_file.path for open_file in open_files])
+
+
class FixedLengthRecordReaderTest(test.TestCase):
diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py
index fd164277b6..b6dba4e3ca 100644
--- a/tensorflow/python/data/ops/iterator_ops.py
+++ b/tensorflow/python/data/ops/iterator_ops.py
@@ -471,9 +471,7 @@ class EagerIterator(object):
sparse.as_dense_types(self._output_types, self._output_classes))
self._flat_output_shapes = nest.flatten(
sparse.as_dense_shapes(self._output_shapes, self._output_classes))
- self._resource = gen_dataset_ops.iterator(
- shared_name="",
- container=_generate_shared_name("eageriterator"),
+ self._resource = gen_dataset_ops.anonymous_iterator(
output_types=self._flat_output_types,
output_shapes=self._flat_output_shapes)
gen_dataset_ops.make_iterator(ds_variant, self._resource)