diff options
6 files changed, 134 insertions, 3 deletions
diff --git a/tensorflow/core/api_def/base_api/api_def_AnonymousIterator.pbtxt b/tensorflow/core/api_def/base_api/api_def_AnonymousIterator.pbtxt new file mode 100644 index 0000000000..d8c2ed40a3 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_AnonymousIterator.pbtxt @@ -0,0 +1,13 @@ +op { + graph_op_name: "AnonymousIterator" + out_arg { + name: "handle" + description: <<END +A handle to the iterator that can be passed to a "MakeIterator" or +"IteratorGetNext" op. In contrast to Iterator, AnonymousIterator prevents +resource sharing by name, and does not keep a reference to the resource +container. +END + } + summary: "A container for an iterator resource." +} diff --git a/tensorflow/core/api_def/python_api/api_def_AnonymousIterator.pbtxt b/tensorflow/core/api_def/python_api/api_def_AnonymousIterator.pbtxt new file mode 100644 index 0000000000..98b7def4d6 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_AnonymousIterator.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "AnonymousIterator" + visibility: HIDDEN +} diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index b6bf0ecd09..87bc8ebefe 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -438,6 +438,9 @@ class IteratorStateVariant { REGISTER_UNARY_VARIANT_DECODE_FUNCTION(IteratorStateVariant, kIteratorVariantTypeName); +// Note that IteratorHandleOp holds a reference to the resource it creates. If +// cleaning up resources with DestroyResourceOp is important, consider creating +// resource containers with AnonymousIteratorHandleOp instead. class IteratorHandleOp : public OpKernel { public: explicit IteratorHandleOp(OpKernelConstruction* ctx) @@ -574,6 +577,75 @@ class IteratorHandleOp : public OpKernel { string name_; }; +// Like IteratorHandleOp, but creates handles which are never shared, and does +// not hold a reference to these handles. The latter is important for eager +// execution, since OpKernel instances generally live as long as the program +// running them. +class AnonymousIteratorHandleOp : public OpKernel { + public: + explicit AnonymousIteratorHandleOp(OpKernelConstruction* context) + : OpKernel(context), graph_def_version_(context->graph_def_version()) { + OP_REQUIRES_OK(context, context->GetAttr("output_types", &output_dtypes_)); + OP_REQUIRES_OK(context, context->GetAttr("output_shapes", &output_shapes_)); + } + + void Compute(OpKernelContext* context) override { + FunctionLibraryRuntime* lib; + std::unique_ptr<DeviceMgr> device_mgr(nullptr); + std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr); + std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr); + OP_REQUIRES_OK(context, + context->function_library()->Clone(&flib_def, &pflr, &lib)); + + ResourceMgr* mgr = context->resource_manager(); + + const string container_name = "AnonymousIterator"; + string unique_name; + { + mutex_lock l(static_resource_lookup_mutex_); + while (true) { // Find an unused name + IteratorResource* existing_resource = nullptr; + unique_name = strings::StrCat("AnonymousIterator", current_id_++); + Status status = mgr->Lookup<IteratorResource>( + container_name, unique_name, &existing_resource); + if (status.code() == error::NOT_FOUND) { + break; + } + OP_REQUIRES_OK(context, status); + existing_resource->Unref(); + } + IteratorResource* new_resource = new IteratorResource( + output_dtypes_, output_shapes_, graph_def_version_, + std::move(device_mgr), std::move(flib_def), std::move(pflr), lib); + // Create the resource with our chosen name under the resource lookup + // mutex to avoid another kernel racily creating a resource with this + // name. + OP_REQUIRES_OK(context, mgr->Create<IteratorResource>( + container_name, unique_name, new_resource)); + } + OP_REQUIRES_OK(context, MakeResourceHandleToOutput( + context, 0, container_name, unique_name, + MakeTypeIndex<IteratorResource>())); + } + + private: + // Coordinates Iterator unique name creation across AnonymousIteratorHandleOp + // instances. + static mutex static_resource_lookup_mutex_; + // current_id_ is just a hint for creating unique names. If it turns out + // there's a collision (e.g. because another AnonymousIteratorHandleOp + // instance is generating handles) we'll just skip that id. + static int64 current_id_ GUARDED_BY(static_resource_lookup_mutex_); + DataTypeVector output_dtypes_; + std::vector<PartialTensorShape> output_shapes_; + const int graph_def_version_; +}; + +// Static initializers for AnonymousIteratorHandleOp id counting. +mutex AnonymousIteratorHandleOp::static_resource_lookup_mutex_{ + LINKER_INITIALIZED}; +int64 AnonymousIteratorHandleOp::current_id_(0); + class MakeIteratorOp : public OpKernel { public: explicit MakeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} @@ -1066,6 +1138,8 @@ class DeserializeIteratorOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp); REGISTER_KERNEL_BUILDER(Name("MakeIterator").Device(DEVICE_CPU), MakeIteratorOp); +REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE_CPU), + AnonymousIteratorHandleOp); REGISTER_KERNEL_BUILDER(Name("DatasetToSingleElement").Device(DEVICE_CPU), ToSingleElementOp); REGISTER_KERNEL_BUILDER(Name("OneShotIterator").Device(DEVICE_CPU), diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index 576946eddd..6d7d8630a7 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -564,6 +564,12 @@ REGISTER_OP("Iterator") .Attr("output_shapes: list(shape) >= 1") .SetShapeFn(shape_inference::ScalarShape); +REGISTER_OP("AnonymousIterator") + .Output("handle: resource") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape); + REGISTER_OP("MakeIterator") .Input("dataset: variant") .Input("iterator: resource") diff --git a/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py b/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py index 1ddedfda4e..e99f0a203b 100644 --- a/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py @@ -24,6 +24,7 @@ import zlib from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.ops import readers +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -38,6 +39,13 @@ from tensorflow.python.platform import test from tensorflow.python.util import compat +try: + import psutil # pylint: disable=g-import-not-at-top + psutil_import_succeeded = True +except ImportError: + psutil_import_succeeded = False + + class TextLineDatasetTest(test.TestCase): def _lineText(self, f, l): @@ -162,6 +170,34 @@ class TextLineDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(iterator.get_next()) + def testIteratorResourceCleanup(self): + filename = os.path.join(self.get_temp_dir(), "text.txt") + with open(filename, "wt") as f: + for i in range(3): + f.write("%d\n" % (i,)) + with context.eager_mode(): + first_iterator = iter(readers.TextLineDataset(filename)) + self.assertEqual(b"0", next(first_iterator).numpy()) + second_iterator = iter(readers.TextLineDataset(filename)) + self.assertEqual(b"0", next(second_iterator).numpy()) + # Eager kernel caching is based on op attributes, which includes the + # Dataset's output shape. Create a different kernel to test that they + # don't create resources with the same names. + different_kernel_iterator = iter( + readers.TextLineDataset(filename).repeat().batch(16)) + self.assertEqual([16], next(different_kernel_iterator).shape) + # Remove our references to the Python Iterator objects, which (assuming no + # reference cycles) is enough to trigger DestroyResourceOp and close the + # partially-read files. + del first_iterator + del second_iterator + del different_kernel_iterator + if not psutil_import_succeeded: + self.skipTest( + "psutil is required to check that we've closed our files.") + open_files = psutil.Process().open_files() + self.assertNotIn(filename, [open_file.path for open_file in open_files]) + class FixedLengthRecordReaderTest(test.TestCase): diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py index fd164277b6..b6dba4e3ca 100644 --- a/tensorflow/python/data/ops/iterator_ops.py +++ b/tensorflow/python/data/ops/iterator_ops.py @@ -471,9 +471,7 @@ class EagerIterator(object): sparse.as_dense_types(self._output_types, self._output_classes)) self._flat_output_shapes = nest.flatten( sparse.as_dense_shapes(self._output_shapes, self._output_classes)) - self._resource = gen_dataset_ops.iterator( - shared_name="", - container=_generate_shared_name("eageriterator"), + self._resource = gen_dataset_ops.anonymous_iterator( output_types=self._flat_output_types, output_shapes=self._flat_output_shapes) gen_dataset_ops.make_iterator(ds_variant, self._resource) |