aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/lib
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-03 18:25:09 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-03 18:32:18 -0700
commit524799058dcc2fb25bf9a5bed49fa38e9b9ac387 (patch)
treec302a5e823442b4e6abbc4e058efed3b391bdb95 /tensorflow/python/lib
parent76c9af7e37709015ed51ee828010dcee925eb12e (diff)
Change PyRecordWriter destructor order so that file_ is still available for writing when writer_ destructor is called
PiperOrigin-RevId: 207355721
Diffstat (limited to 'tensorflow/python/lib')
-rw-r--r--tensorflow/python/lib/io/py_record_writer.cc3
-rw-r--r--tensorflow/python/lib/io/tf_record_test.py35
2 files changed, 38 insertions, 0 deletions
diff --git a/tensorflow/python/lib/io/py_record_writer.cc b/tensorflow/python/lib/io/py_record_writer.cc
index ba749da47a..ae76fcceba 100644
--- a/tensorflow/python/lib/io/py_record_writer.cc
+++ b/tensorflow/python/lib/io/py_record_writer.cc
@@ -47,6 +47,9 @@ PyRecordWriter* PyRecordWriter::New(const string& filename,
}
PyRecordWriter::~PyRecordWriter() {
+ // Writer depends on file during close for zlib flush, so destruct first.
+ writer_.reset();
+ file_.reset();
}
bool PyRecordWriter::WriteRecord(tensorflow::StringPiece record) {
diff --git a/tensorflow/python/lib/io/tf_record_test.py b/tensorflow/python/lib/io/tf_record_test.py
index dcc1a25f42..4a13e8c428 100644
--- a/tensorflow/python/lib/io/tf_record_test.py
+++ b/tensorflow/python/lib/io/tf_record_test.py
@@ -318,5 +318,40 @@ class TFRecordIteratorTest(TFCompressionTestCase):
for _ in tf_record.tf_record_iterator(fn_truncated):
pass
+
+class TFRecordWriterCloseAndFlushTests(test.TestCase):
+
+ def setUp(self, compression_type=TFRecordCompressionType.NONE):
+ super(TFRecordWriterCloseAndFlushTests, self).setUp()
+ self._fn = os.path.join(self.get_temp_dir(), "tf_record_writer_test.txt")
+ self._options = tf_record.TFRecordOptions(compression_type)
+ self._writer = tf_record.TFRecordWriter(self._fn, self._options)
+ self._num_records = 20
+
+ def _Record(self, r):
+ return compat.as_bytes("Record %d" % r)
+
+ def testWriteAndLeaveOpen(self):
+ records = list(map(self._Record, range(self._num_records)))
+ for record in records:
+ self._writer.write(record)
+
+ # Verify no segfault if writer isn't explicitly closed.
+
+
+class TFRecordWriterCloseAndFlushGzipTests(TFRecordWriterCloseAndFlushTests):
+
+ def setUp(self):
+ super(TFRecordWriterCloseAndFlushGzipTests,
+ self).setUp(TFRecordCompressionType.GZIP)
+
+
+class TFRecordWriterCloseAndFlushZlibTests(TFRecordWriterCloseAndFlushTests):
+
+ def setUp(self):
+ super(TFRecordWriterCloseAndFlushZlibTests,
+ self).setUp(TFRecordCompressionType.ZLIB)
+
+
if __name__ == "__main__":
test.main()