From 0d0f37aef1f87c4c31c12ce876864022f624cd4c Mon Sep 17 00:00:00 2001
From: Derek Murray
Date: Tue, 14 Jun 2016 22:47:37 -0800
Subject: Add a size check before attempting to serialize a variable.

This prevents the TensorSliceWriter from attempting to serialize
variables that are larger than 2GB, which could otherwise cause memory
corruption and segmentation faults.

Fixes #2447.
Change: 124921899
---
 tensorflow/core/util/tensor_slice_writer.cc      | 65 +++++++++++++++++++
 tensorflow/core/util/tensor_slice_writer.h       | 39 ++++++++++--
 tensorflow/core/util/tensor_slice_writer_test.cc | 81 ++++++++++++++++++++++++
 tensorflow/python/training/saver_test.py         | 27 ++++++++
 4 files changed, 207 insertions(+), 5 deletions(-)
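Note (an illustrative sketch for this writeup, not part of the patch; text in
this slot, between the diffstat and the first diff, is ignored by git-am): the
new check computes a conservative upper bound on the serialized SavedSlice
size and rejects the write if it could exceed the 2GB protobuf ceiling. The
constants below copy the values the patch introduces; the program itself is a
standalone sketch (assuming a 64-bit size_t), not TensorFlow code, showing why
the 300M-element int8 slice in the new test is rejected.

// Sketch of the SaveData bound for the int8 test case.
#include <cstddef>
#include <iostream>

int main() {
  const size_t kMaxMessageBytes = 1ULL << 31;      // 2GB protobuf limit
  const size_t kTensorProtoHeaderBytes = 1 << 10;  // 1KB of header slack
  const size_t kMaxBytesPerInt8 = 10;              // MaxBytesPerElement(DT_INT8)

  // The 300MB int8 slice from the new test: 300 * 1000000 elements.
  const size_t num_elements = 300000000ULL;
  const size_t size_bound =
      kTensorProtoHeaderBytes + num_elements * kMaxBytesPerInt8;

  std::cout << size_bound << " > " << kMaxMessageBytes << " ? "
            << (size_bound > kMaxMessageBytes ? "rejected" : "accepted")
            << std::endl;  // prints: 3000001024 > 2147483648 ? rejected
  return 0;
}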
diff --git a/tensorflow/core/util/tensor_slice_writer.cc b/tensorflow/core/util/tensor_slice_writer.cc
index 204d3b164a..74fcbbe649 100644
--- a/tensorflow/core/util/tensor_slice_writer.cc
+++ b/tensorflow/core/util/tensor_slice_writer.cc
@@ -126,6 +126,71 @@ Status TensorSliceWriter::Finish() {
   return s;
 }
 
+/* static */
+size_t TensorSliceWriter::MaxBytesPerElement(DataType dt) {
+  switch (dt) {
+    case DT_FLOAT:
+      return 4;
+    case DT_DOUBLE:
+      return 8;
+    case DT_INT32:
+      return 10;
+    case DT_UINT8:
+      return 2;
+    case DT_INT16:
+      return 10;
+    case DT_INT8:
+      return 10;
+    case DT_COMPLEX64:
+      return 8;
+    case DT_INT64:
+      return 10;
+    case DT_BOOL:
+      return 1;
+    case DT_QINT8:
+      return 10;
+    case DT_QUINT8:
+      return 2;
+    case DT_QINT32:
+      return 10;
+    case DT_QINT16:
+      return 10;
+    case DT_QUINT16:
+      return 3;
+    case DT_UINT16:
+      return 3;
+    case DT_COMPLEX128:
+      return 16;
+    case DT_HALF:
+      return 3;
+    case DT_INVALID:
+    case DT_STRING:
+    case DT_BFLOAT16:
+    default:
+      CHECK(false) << "MaxBytesPerElement not implemented for dtype: " << dt;
+  }
+  return 0;
+}
+
+template <>
+Status TensorSliceWriter::SaveData(const string* data, int num_elements,
+                                   SavedSlice* ss) {
+  size_t size_bound = ss->ByteSize() + kTensorProtoHeaderBytes +
+                      (num_elements * MaxBytesPerElement(DT_INT32));
+  for (int i = 0; i < num_elements; ++i) {
+    size_bound += data[i].size();
+  }
+  if (size_bound > kMaxMessageBytes) {
+    return errors::InvalidArgument(
+        "Tensor slice is too large to serialize (conservative estimate: ",
+        size_bound, " bytes)");
+  }
+  Fill(data, num_elements, ss->mutable_data());
+  DCHECK_GE(ss->ByteSize(), 0);
+  DCHECK_LE(ss->ByteSize(), size_bound);
+  return Status::OK();
+}
+
 }  // namespace checkpoint
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/util/tensor_slice_writer.h b/tensorflow/core/util/tensor_slice_writer.h
index d93cddaebe..d103cda4aa 100644
--- a/tensorflow/core/util/tensor_slice_writer.h
+++ b/tensorflow/core/util/tensor_slice_writer.h
@@ -29,6 +29,7 @@ limitations under the License.
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/types.h"
@@ -61,11 +62,24 @@ class TensorSliceWriter {
                 const TensorSlice& slice, const T* data);
   Status Finish();
 
- private:
   // Allocate "num_elements" elements in "ss" and save the data in "data"
   // there.
   template <typename T>
-  static void SaveData(const T* data, int num_elements, SavedSlice* ss);
+  static Status SaveData(const T* data, int num_elements, SavedSlice* ss);
+
+  static size_t MaxBytesPerElement(DataType dt);
+
+ private:
+  static const size_t kMaxMessageBytes = 1LL << 31;
+  // Filling in the TensorProto in a SavedSlice will add the following
+  // header bytes, in addition to the data:
+  // - 1 byte: TensorProto tag and wire format
+  // - <= 5 bytes: TensorProto length
+  // - 1 byte: Repeated *_val tag and wire format
+  // - <= 5 bytes: *_val length
+  // However, we add 1KB of slack, to be conservative and guard
+  // against other additions to the TensorProto.
+  static const size_t kTensorProtoHeaderBytes = 1 << 10;
 
   const string filename_;
   const CreateBuilderFunction create_builder_;
@@ -132,7 +146,7 @@ Status TensorSliceWriter::Add(const string& name, const TensorShape& shape,
   TensorShape saved_shape(ssm->shape());
   TensorShape sliced_shape;
   TF_RETURN_IF_ERROR(slice.SliceTensorShape(saved_shape, &sliced_shape));
-  SaveData(data, sliced_shape.num_elements(), ss);
+  TF_RETURN_IF_ERROR(SaveData(data, sliced_shape.num_elements(), ss));
   string key = EncodeTensorNameSlice(name, slice);
   // TODO(yangke): consider doing a two-pass thing where the first pass just
   // list the tensor slices we want to save and then another pass to actually
@@ -148,11 +162,26 @@ Status TensorSliceWriter::Add(const string& name, const TensorShape& shape,
 }
 
 template <typename T>
-void TensorSliceWriter::SaveData(const T* data, int num_elements,
-                                 SavedSlice* ss) {
+Status TensorSliceWriter::SaveData(const T* data, int num_elements,
+                                   SavedSlice* ss) {
+  size_t size_bound =
+      ss->ByteSize() + kTensorProtoHeaderBytes +
+      (MaxBytesPerElement(DataTypeToEnum<T>::value) * num_elements);
+  if (size_bound > kMaxMessageBytes) {
+    return errors::InvalidArgument(
+        "Tensor slice is too large to serialize (conservative estimate: ",
+        size_bound, " bytes)");
+  }
   Fill(data, num_elements, ss->mutable_data());
+  DCHECK_GE(ss->ByteSize(), 0);
+  DCHECK_LE(ss->ByteSize(), size_bound);
+  return Status::OK();
 }
 
+template <>
+Status TensorSliceWriter::SaveData(const string* data, int num_elements,
+                                   SavedSlice* ss);
+
 // Create a table builder that will write to "filename" in
 // tensorflow::io::Table format. If successful, return OK
 // and set "*builder" to the allocated builder. Otherwise, return a
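Aside (an illustrative sketch, not part of the patch): the per-element maxima
in MaxBytesPerElement follow from TensorProto's wire format. Small integer
types are carried in a repeated int32 field, and protobuf varints encode a
negative int32 as its 64-bit sign extension, which occupies the full 10 bytes;
unsigned bytes fit in at most two varint bytes. A standalone sketch of that
worst case:

// Sketch: worst-case varint widths behind the MaxBytesPerElement table.
#include <cstddef>
#include <cstdint>
#include <iostream>

size_t VarintSize(uint64_t v) {
  size_t n = 1;
  while (v >= 0x80) {  // each varint byte carries 7 bits of payload
    v >>= 7;
    ++n;
  }
  return n;
}

int main() {
  // A negative int32 is sign-extended to 64 bits on the wire, so any
  // negative DT_INT8/DT_INT16/DT_INT32 value costs the full 10 bytes.
  int32_t worst = -1;
  uint64_t wire = static_cast<uint64_t>(static_cast<int64_t>(worst));
  std::cout << VarintSize(wire) << std::endl;  // prints 10

  // 255 < 2^14, so DT_UINT8/DT_QUINT8 values need at most 2 bytes.
  std::cout << VarintSize(255) << std::endl;   // prints 2
  return 0;
}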
#include "tensorflow/core/util/tensor_slice_writer.h" +#include + #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/logging.h" @@ -263,6 +265,85 @@ void TensorSliceWriteTestHelper::CheckEntries(const string& fname) { } } +template +size_t BytesPerElementHelper(DT value) { + SavedSlice ss; + std::array lo_data; + std::fill(lo_data.begin(), lo_data.end(), value); + TensorSliceWriter::SaveData(lo_data.data(), lo_data.size(), &ss); + int lo_byte_size = ss.ByteSize(); + + std::array hi_data; + std::fill(hi_data.begin(), hi_data.end(), value); + TensorSliceWriter::SaveData(hi_data.data(), hi_data.size(), &ss); + int hi_byte_size = ss.ByteSize(); + + return (hi_byte_size - lo_byte_size) / (hi_data.size() - lo_data.size()); +} + +TEST(TensorSliceWriteTest, CheckpointSize) { + EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_BOOL), + BytesPerElementHelper(false)); + EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_BOOL), + BytesPerElementHelper(true)); + EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_FLOAT), + BytesPerElementHelper(-1.0)); + EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_DOUBLE), + BytesPerElementHelper(-1.0)); + EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_COMPLEX64), + BytesPerElementHelper(-1.0)); + EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_COMPLEX128), + BytesPerElementHelper(-1.0)); + EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_INT32), + BytesPerElementHelper(-1)); + EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_INT64), + BytesPerElementHelper(-1)); + EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_UINT16), + BytesPerElementHelper(std::numeric_limits::max())); + EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_UINT8), + BytesPerElementHelper(std::numeric_limits::max())); + EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_INT8), + BytesPerElementHelper(-1)); + EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_INT16), + BytesPerElementHelper(-1)); + EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_QINT8), + BytesPerElementHelper(-1)); + EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_QUINT8), + BytesPerElementHelper(std::numeric_limits::max())); + EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_QINT32), + BytesPerElementHelper(-1)); + EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_HALF), + BytesPerElementHelper(Eigen::half(-1.0))); +} + +TEST(TensorSliceWriteTest, SizeErrors) { + const string filename = io::JoinPath(testing::TmpDir(), "checkpoint"); + + TensorSliceWriter writer(filename, CreateTableTensorSliceBuilder); + + // Add a 300MB int8 tensor slice, which will fail because it expands to 3GB. + { + TensorShape shape({300, 1000000}); + TensorSlice slice = TensorSlice::ParseOrDie("-:-"); + const std::vector data(300000000, -1); + Status s = writer.Add("test1", shape, slice, data.data()); + EXPECT_EQ(s.code(), error::INVALID_ARGUMENT); + EXPECT_TRUE(StringPiece(s.error_message()) + .contains("Tensor slice is too large to serialize")); + } + + // Add a large string tensor slice, which will fail. 
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index a975506781..f44ec417a5 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -287,6 +287,33 @@ class SaverTest(tf.test.TestCase):
       expected_save_path = "%s-%d" % (save_path, global_step_int)
       self.assertEqual(expected_save_path, val)
 
+  def testLargeVariable(self):
+    save_path = os.path.join(self.get_temp_dir(), "large_variable")
+    with tf.Session("", graph=tf.Graph()) as sess:
+      # Declare a variable larger than 2GB.
+      with tf.device("/cpu:0"):
+        var = tf.Variable(tf.constant(-1, shape=[300, 1000000], dtype=tf.int8))
+      save = tf.train.Saver({var.op.name: var})
+      var.initializer.run()
+      with self.assertRaisesRegexp(
+          tf.errors.InvalidArgumentError,
+          "Tensor slice is too large to serialize"):
+        save.save(sess, save_path)
+
+    with tf.Session("", graph=tf.Graph()) as sess:
+      # Declare a variable that is exactly 2GB. This should fail,
+      # because a serialized checkpoint includes other header
+      # metadata.
+      with tf.device("/cpu:0"):
+        var = tf.Variable(
+            tf.constant(False, shape=[2, 1024, 1024, 1024], dtype=tf.bool))
+      save = tf.train.Saver({var.op.name: var})
+      var.initializer.run()
+      with self.assertRaisesRegexp(
+          tf.errors.InvalidArgumentError,
+          "Tensor slice is too large to serialize"):
+        save.save(sess, save_path)
+
 
 class SaveRestoreShardedTest(tf.test.TestCase):
-- 
cgit v1.2.3