aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/lib
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-06 13:19:55 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-06 13:24:14 -0700
commit1132333db732548e85c2c860239421b649dcc5d6 (patch)
tree3157419c6179d191360964296f91ade431fd880a /tensorflow/python/lib
parent598f043fee12daf711cd65f36487483c722a489d (diff)
Add nullptr checks so that Write() and Flush() fail instead of segfault
PiperOrigin-RevId: 207596283
Diffstat (limited to 'tensorflow/python/lib')
-rw-r--r--tensorflow/python/lib/io/py_record_writer.cc29
-rw-r--r--tensorflow/python/lib/io/tf_record.py1
-rw-r--r--tensorflow/python/lib/io/tf_record_test.py29
3 files changed, 48 insertions, 11 deletions
diff --git a/tensorflow/python/lib/io/py_record_writer.cc b/tensorflow/python/lib/io/py_record_writer.cc
index ae76fcceba..3c64813735 100644
--- a/tensorflow/python/lib/io/py_record_writer.cc
+++ b/tensorflow/python/lib/io/py_record_writer.cc
@@ -59,6 +59,11 @@ bool PyRecordWriter::WriteRecord(tensorflow::StringPiece record) {
}
void PyRecordWriter::Flush(TF_Status* out_status) {
+ if (writer_ == nullptr) {
+ TF_SetStatus(out_status, TF_FAILED_PRECONDITION,
+ "Writer not initialized or previously closed");
+ return;
+ }
Status s = writer_->Flush();
if (!s.ok()) {
Set_TF_Status_from_Status(out_status, s);
@@ -67,18 +72,22 @@ void PyRecordWriter::Flush(TF_Status* out_status) {
}
void PyRecordWriter::Close(TF_Status* out_status) {
- Status s = writer_->Close();
- if (!s.ok()) {
- Set_TF_Status_from_Status(out_status, s);
- return;
+ if (writer_ != nullptr) {
+ Status s = writer_->Close();
+ if (!s.ok()) {
+ Set_TF_Status_from_Status(out_status, s);
+ return;
+ }
+ writer_.reset(nullptr);
}
- writer_.reset(nullptr);
- s = file_->Close();
- if (!s.ok()) {
- Set_TF_Status_from_Status(out_status, s);
- return;
+ if (file_ != nullptr) {
+ Status s = file_->Close();
+ if (!s.ok()) {
+ Set_TF_Status_from_Status(out_status, s);
+ return;
+ }
+ file_.reset(nullptr);
}
- file_.reset(nullptr);
}
} // namespace io
diff --git a/tensorflow/python/lib/io/tf_record.py b/tensorflow/python/lib/io/tf_record.py
index bf2d6f68b5..941d6cd67c 100644
--- a/tensorflow/python/lib/io/tf_record.py
+++ b/tensorflow/python/lib/io/tf_record.py
@@ -125,6 +125,7 @@ class TFRecordWriter(object):
Args:
record: str
"""
+ # TODO(sethtroisi): Failures are currently swallowed, change that.
self._writer.WriteRecord(record)
def flush(self):
diff --git a/tensorflow/python/lib/io/tf_record_test.py b/tensorflow/python/lib/io/tf_record_test.py
index 4a13e8c428..4743c037ec 100644
--- a/tensorflow/python/lib/io/tf_record_test.py
+++ b/tensorflow/python/lib/io/tf_record_test.py
@@ -318,7 +318,6 @@ class TFRecordIteratorTest(TFCompressionTestCase):
for _ in tf_record.tf_record_iterator(fn_truncated):
pass
-
class TFRecordWriterCloseAndFlushTests(test.TestCase):
def setUp(self, compression_type=TFRecordCompressionType.NONE):
@@ -338,6 +337,34 @@ class TFRecordWriterCloseAndFlushTests(test.TestCase):
# Verify no segfault if writer isn't explicitly closed.
+ def testWriteAndRead(self):
+ records = list(map(self._Record, range(self._num_records)))
+ for record in records:
+ self._writer.write(record)
+ self._writer.close()
+
+ actual = list(tf_record.tf_record_iterator(self._fn, self._options))
+ self.assertListEqual(actual, records)
+
+ def testDoubleClose(self):
+ self._writer.write(self._Record(0))
+ self._writer.close()
+ self._writer.close()
+
+ def testFlushAfterCloseIsError(self):
+ self._writer.write(self._Record(0))
+ self._writer.close()
+
+ with self.assertRaises(errors_impl.FailedPreconditionError):
+ self._writer.flush()
+
+ def testWriteAfterClose(self):
+ self._writer.write(self._Record(0))
+ self._writer.close()
+
+ # TODO(sethtroisi): No way to know this failed, changed that.
+ self._writer.write(self._Record(1))
+
class TFRecordWriterCloseAndFlushGzipTests(TFRecordWriterCloseAndFlushTests):