From 1132333db732548e85c2c860239421b649dcc5d6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 6 Aug 2018 13:19:55 -0700 Subject: Add nullptr checks so that Write() and Flush() fail instead of segfault PiperOrigin-RevId: 207596283 --- .../core/lib/io/record_reader_writer_test.cc | 23 ++++++++++++++++++++++ tensorflow/core/lib/io/record_writer.cc | 9 +++++++++ 2 files changed, 32 insertions(+) (limited to 'tensorflow/core/lib') diff --git a/tensorflow/core/lib/io/record_reader_writer_test.cc b/tensorflow/core/lib/io/record_reader_writer_test.cc index c36c909399..13bea1f8f1 100644 --- a/tensorflow/core/lib/io/record_reader_writer_test.cc +++ b/tensorflow/core/lib/io/record_reader_writer_test.cc @@ -189,4 +189,27 @@ TEST(RecordReaderWriterTest, TestZlib) { } } +TEST(RecordReaderWriterTest, TestUseAfterClose) { + Env* env = Env::Default(); + string fname = testing::TmpDir() + "/record_reader_writer_flush_close_test"; + + { + std::unique_ptr file; + TF_CHECK_OK(env->NewWritableFile(fname, &file)); + + io::RecordWriterOptions options; + options.compression_type = io::RecordWriterOptions::ZLIB_COMPRESSION; + io::RecordWriter writer(file.get(), options); + TF_EXPECT_OK(writer.WriteRecord("abc")); + TF_CHECK_OK(writer.Flush()); + TF_CHECK_OK(writer.Close()); + + CHECK_EQ(writer.WriteRecord("abc").code(), error::FAILED_PRECONDITION); + CHECK_EQ(writer.Flush().code(), error::FAILED_PRECONDITION); + + // Second call to close is fine. + TF_CHECK_OK(writer.Close()); + } +} + } // namespace tensorflow diff --git a/tensorflow/core/lib/io/record_writer.cc b/tensorflow/core/lib/io/record_writer.cc index ebc5648269..6e71d23e71 100644 --- a/tensorflow/core/lib/io/record_writer.cc +++ b/tensorflow/core/lib/io/record_writer.cc @@ -93,6 +93,10 @@ static uint32 MaskedCrc(const char* data, size_t n) { } Status RecordWriter::WriteRecord(StringPiece data) { + if (dest_ == nullptr) { + return Status(::tensorflow::error::FAILED_PRECONDITION, + "Writer not initialized or previously closed"); + } // Format of a single record: // uint64 length // uint32 masked crc of length @@ -111,6 +115,7 @@ Status RecordWriter::WriteRecord(StringPiece data) { } Status RecordWriter::Close() { + if (dest_ == nullptr) return Status::OK(); #if !defined(IS_SLIM_BUILD) if (IsZlibCompressed(options_)) { Status s = dest_->Close(); @@ -123,6 +128,10 @@ Status RecordWriter::Close() { } Status RecordWriter::Flush() { + if (dest_ == nullptr) { + return Status(::tensorflow::error::FAILED_PRECONDITION, + "Writer not initialized or previously closed"); + } if (IsZlibCompressed(options_)) { return dest_->Flush(); } -- cgit v1.2.3