diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-03 18:25:09 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-03 18:32:18 -0700 |
commit | 524799058dcc2fb25bf9a5bed49fa38e9b9ac387 (patch) | |
tree | c302a5e823442b4e6abbc4e058efed3b391bdb95 /tensorflow/python/lib | |
parent | 76c9af7e37709015ed51ee828010dcee925eb12e (diff) |
Change PyRecordWriter destructor order so that file_ is still available for writing when writer_ destructor is called
PiperOrigin-RevId: 207355721
Diffstat (limited to 'tensorflow/python/lib')
-rw-r--r-- | tensorflow/python/lib/io/py_record_writer.cc | 3 | ||||
-rw-r--r-- | tensorflow/python/lib/io/tf_record_test.py | 35 |
2 files changed, 38 insertions, 0 deletions
diff --git a/tensorflow/python/lib/io/py_record_writer.cc b/tensorflow/python/lib/io/py_record_writer.cc index ba749da47a..ae76fcceba 100644 --- a/tensorflow/python/lib/io/py_record_writer.cc +++ b/tensorflow/python/lib/io/py_record_writer.cc @@ -47,6 +47,9 @@ PyRecordWriter* PyRecordWriter::New(const string& filename, } PyRecordWriter::~PyRecordWriter() { + // Writer depends on file during close for zlib flush, so destruct first. + writer_.reset(); + file_.reset(); } bool PyRecordWriter::WriteRecord(tensorflow::StringPiece record) { diff --git a/tensorflow/python/lib/io/tf_record_test.py b/tensorflow/python/lib/io/tf_record_test.py index dcc1a25f42..4a13e8c428 100644 --- a/tensorflow/python/lib/io/tf_record_test.py +++ b/tensorflow/python/lib/io/tf_record_test.py @@ -318,5 +318,40 @@ class TFRecordIteratorTest(TFCompressionTestCase): for _ in tf_record.tf_record_iterator(fn_truncated): pass + +class TFRecordWriterCloseAndFlushTests(test.TestCase): + + def setUp(self, compression_type=TFRecordCompressionType.NONE): + super(TFRecordWriterCloseAndFlushTests, self).setUp() + self._fn = os.path.join(self.get_temp_dir(), "tf_record_writer_test.txt") + self._options = tf_record.TFRecordOptions(compression_type) + self._writer = tf_record.TFRecordWriter(self._fn, self._options) + self._num_records = 20 + + def _Record(self, r): + return compat.as_bytes("Record %d" % r) + + def testWriteAndLeaveOpen(self): + records = list(map(self._Record, range(self._num_records))) + for record in records: + self._writer.write(record) + + # Verify no segfault if writer isn't explicitly closed. + + +class TFRecordWriterCloseAndFlushGzipTests(TFRecordWriterCloseAndFlushTests): + + def setUp(self): + super(TFRecordWriterCloseAndFlushGzipTests, + self).setUp(TFRecordCompressionType.GZIP) + + +class TFRecordWriterCloseAndFlushZlibTests(TFRecordWriterCloseAndFlushTests): + + def setUp(self): + super(TFRecordWriterCloseAndFlushZlibTests, + self).setUp(TFRecordCompressionType.ZLIB) + + if __name__ == "__main__": test.main() |