diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-06 13:19:55 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-06 13:24:14 -0700 |
commit | 1132333db732548e85c2c860239421b649dcc5d6 (patch) | |
tree | 3157419c6179d191360964296f91ade431fd880a /tensorflow/python/lib | |
parent | 598f043fee12daf711cd65f36487483c722a489d (diff) |
Add nullptr checks so that Write() and Flush() fail instead of segfault
PiperOrigin-RevId: 207596283
Diffstat (limited to 'tensorflow/python/lib')
-rw-r--r-- | tensorflow/python/lib/io/py_record_writer.cc | 29 | ||||
-rw-r--r-- | tensorflow/python/lib/io/tf_record.py | 1 | ||||
-rw-r--r-- | tensorflow/python/lib/io/tf_record_test.py | 29 |
3 files changed, 48 insertions, 11 deletions
diff --git a/tensorflow/python/lib/io/py_record_writer.cc b/tensorflow/python/lib/io/py_record_writer.cc index ae76fcceba..3c64813735 100644 --- a/tensorflow/python/lib/io/py_record_writer.cc +++ b/tensorflow/python/lib/io/py_record_writer.cc @@ -59,6 +59,11 @@ bool PyRecordWriter::WriteRecord(tensorflow::StringPiece record) { } void PyRecordWriter::Flush(TF_Status* out_status) { + if (writer_ == nullptr) { + TF_SetStatus(out_status, TF_FAILED_PRECONDITION, + "Writer not initialized or previously closed"); + return; + } Status s = writer_->Flush(); if (!s.ok()) { Set_TF_Status_from_Status(out_status, s); @@ -67,18 +72,22 @@ void PyRecordWriter::Flush(TF_Status* out_status) { } void PyRecordWriter::Close(TF_Status* out_status) { - Status s = writer_->Close(); - if (!s.ok()) { - Set_TF_Status_from_Status(out_status, s); - return; + if (writer_ != nullptr) { + Status s = writer_->Close(); + if (!s.ok()) { + Set_TF_Status_from_Status(out_status, s); + return; + } + writer_.reset(nullptr); } - writer_.reset(nullptr); - s = file_->Close(); - if (!s.ok()) { - Set_TF_Status_from_Status(out_status, s); - return; + if (file_ != nullptr) { + Status s = file_->Close(); + if (!s.ok()) { + Set_TF_Status_from_Status(out_status, s); + return; + } + file_.reset(nullptr); } - file_.reset(nullptr); } } // namespace io diff --git a/tensorflow/python/lib/io/tf_record.py b/tensorflow/python/lib/io/tf_record.py index bf2d6f68b5..941d6cd67c 100644 --- a/tensorflow/python/lib/io/tf_record.py +++ b/tensorflow/python/lib/io/tf_record.py @@ -125,6 +125,7 @@ class TFRecordWriter(object): Args: record: str """ + # TODO(sethtroisi): Failures are currently swallowed, change that. self._writer.WriteRecord(record) def flush(self): diff --git a/tensorflow/python/lib/io/tf_record_test.py b/tensorflow/python/lib/io/tf_record_test.py index 4a13e8c428..4743c037ec 100644 --- a/tensorflow/python/lib/io/tf_record_test.py +++ b/tensorflow/python/lib/io/tf_record_test.py @@ -318,7 +318,6 @@ class TFRecordIteratorTest(TFCompressionTestCase): for _ in tf_record.tf_record_iterator(fn_truncated): pass - class TFRecordWriterCloseAndFlushTests(test.TestCase): def setUp(self, compression_type=TFRecordCompressionType.NONE): @@ -338,6 +337,34 @@ class TFRecordWriterCloseAndFlushTests(test.TestCase): # Verify no segfault if writer isn't explicitly closed. + def testWriteAndRead(self): + records = list(map(self._Record, range(self._num_records))) + for record in records: + self._writer.write(record) + self._writer.close() + + actual = list(tf_record.tf_record_iterator(self._fn, self._options)) + self.assertListEqual(actual, records) + + def testDoubleClose(self): + self._writer.write(self._Record(0)) + self._writer.close() + self._writer.close() + + def testFlushAfterCloseIsError(self): + self._writer.write(self._Record(0)) + self._writer.close() + + with self.assertRaises(errors_impl.FailedPreconditionError): + self._writer.flush() + + def testWriteAfterClose(self): + self._writer.write(self._Record(0)) + self._writer.close() + + # TODO(sethtroisi): No way to know this failed, changed that. + self._writer.write(self._Record(1)) + class TFRecordWriterCloseAndFlushGzipTests(TFRecordWriterCloseAndFlushTests): |