author    Derek Murray <mrry@google.com>  2016-06-14 22:47:37 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2016-06-15 00:02:42 -0700
commit    cc40cd3b0a8b83f5ee071b7ee32c17b56815a89c (patch)
tree      ebbfdda9dfda2d2ebbe96e5f93053e4597c1a37d
parent    9e27c607dc6ab118eb4fe11ffdadfd79fd9eb3b4 (diff)
Add a size check before attempting to serialize a variable.
This prevents the TensorSliceWriter from attempting to serialize variables that are larger than 2GB, which could otherwise lead to memory corruption and segmentation faults. Fixes #2447. Change: 124921899
-rw-r--r--  tensorflow/core/util/tensor_slice_writer.cc       65
-rw-r--r--  tensorflow/core/util/tensor_slice_writer.h        39
-rw-r--r--  tensorflow/core/util/tensor_slice_writer_test.cc  81
-rw-r--r--  tensorflow/python/training/saver_test.py          27
4 files changed, 207 insertions(+), 5 deletions(-)
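The core of the change: before filling in a SavedSlice's TensorProto, SaveData now computes a conservative upper bound on the serialized size and refuses to proceed if the bound exceeds the 2GB protobuf message limit. A minimal standalone sketch of that check, assuming the two constants introduced in tensor_slice_writer.h below (FitsInProto is a hypothetical name, not part of the patch):

#include <cstddef>
#include <cstdint>

constexpr size_t kMaxMessageBytes = 1LL << 31;       // 2GB protobuf hard limit
constexpr size_t kTensorProtoHeaderBytes = 1 << 10;  // 1KB of framing slack

// Returns true iff a slice of num_elements values, each costing at most
// max_bytes_per_element on the wire, is safe to serialize into the proto.
bool FitsInProto(size_t bytes_already_in_proto, int64_t num_elements,
                 size_t max_bytes_per_element) {
  const size_t size_bound =
      bytes_already_in_proto + kTensorProtoHeaderBytes +
      static_cast<size_t>(num_elements) * max_bytes_per_element;
  return size_bound <= kMaxMessageBytes;
}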
diff --git a/tensorflow/core/util/tensor_slice_writer.cc b/tensorflow/core/util/tensor_slice_writer.cc
index 204d3b164a..74fcbbe649 100644
--- a/tensorflow/core/util/tensor_slice_writer.cc
+++ b/tensorflow/core/util/tensor_slice_writer.cc
@@ -126,6 +126,71 @@ Status TensorSliceWriter::Finish() {
return s;
}
+/* static */
+size_t TensorSliceWriter::MaxBytesPerElement(DataType dt) {
+ switch (dt) {
+ case DT_FLOAT:
+ return 4;
+ case DT_DOUBLE:
+ return 8;
+ case DT_INT32:
+ return 10;
+ case DT_UINT8:
+ return 2;
+ case DT_INT16:
+ return 10;
+ case DT_INT8:
+ return 10;
+ case DT_COMPLEX64:
+ return 8;
+ case DT_INT64:
+ return 10;
+ case DT_BOOL:
+ return 1;
+ case DT_QINT8:
+ return 10;
+ case DT_QUINT8:
+ return 2;
+ case DT_QINT32:
+ return 10;
+ case DT_QINT16:
+ return 10;
+ case DT_QUINT16:
+ return 3;
+ case DT_UINT16:
+ return 3;
+ case DT_COMPLEX128:
+ return 16;
+ case DT_HALF:
+ return 3;
+ case DT_INVALID:
+ case DT_STRING:
+ case DT_BFLOAT16:
+ default:
+ CHECK(false) << "MaxBytesPerElement not implemented for dtype: " << dt;
+ }
+ return 0;
+}
+
+template <>
+Status TensorSliceWriter::SaveData(const string* data, int num_elements,
+ SavedSlice* ss) {
+ size_t size_bound = ss->ByteSize() + kTensorProtoHeaderBytes +
+ (num_elements * MaxBytesPerElement(DT_INT32));
+ for (int i = 0; i < num_elements; ++i) {
+ size_bound += data[i].size();
+ }
+ if (size_bound > kMaxMessageBytes) {
+ return errors::InvalidArgument(
+ "Tensor slice is too large to serialize (conservative estimate: ",
+ size_bound, " bytes)");
+ }
+ Fill(data, num_elements, ss->mutable_data());
+ DCHECK_GE(ss->ByteSize(), 0);
+ DCHECK_LE(ss->ByteSize(), size_bound);
+ return Status::OK();
+}
+
} // namespace checkpoint
} // namespace tensorflow
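Why MaxBytesPerElement returns 10 bytes for the signed integer types above: TensorProto stores these values as protobuf varints, and a negative value is sign-extended to 64 bits before encoding, so its varint takes the maximum 10 bytes. A small illustrative helper (VarintSize is hypothetical, written here only to show the arithmetic):

#include <cstddef>
#include <cstdint>

// Number of bytes a value occupies in protobuf varint encoding:
// each byte carries 7 payload bits plus a continuation bit.
size_t VarintSize(uint64_t v) {
  size_t bytes = 1;
  while (v >= 0x80) {
    v >>= 7;
    ++bytes;
  }
  return bytes;
}

// VarintSize(static_cast<uint64_t>(int64_t{-1})) == 10: all 64 bits are set,
// and ceil(64 / 7) = 10. The small unsigned bounds follow the same rule:
// 255 needs 2 bytes (DT_UINT8) and 65535 needs 3 bytes (DT_UINT16).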
diff --git a/tensorflow/core/util/tensor_slice_writer.h b/tensorflow/core/util/tensor_slice_writer.h
index d93cddaebe..d103cda4aa 100644
--- a/tensorflow/core/util/tensor_slice_writer.h
+++ b/tensorflow/core/util/tensor_slice_writer.h
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -61,11 +62,24 @@ class TensorSliceWriter {
const TensorSlice& slice, const T* data);
Status Finish();
- private:
// Allocate "num_elements" elements in "ss" and save the data in "data"
// there.
template <typename T>
- static void SaveData(const T* data, int num_elements, SavedSlice* ss);
+ static Status SaveData(const T* data, int num_elements, SavedSlice* ss);
+
+ static size_t MaxBytesPerElement(DataType dt);
+
+ private:
+ static const size_t kMaxMessageBytes = 1LL << 31;
+ // Filling in the TensorProto in a SavedSlice will add the following
+ // header bytes, in addition to the data:
+ // - 1 byte: TensorProto tag and wire format
+ // - <= 5 bytes: TensorProto length
+ // - 1 byte: Repeated *_val tag and wire format
+ // - <= 5 bytes: *_val length
+ // However, we add 1KB of slack, to be conservative and guard
+ // against other additions to the TensorProto.
+ static const size_t kTensorProtoHeaderBytes = 1 << 10;
const string filename_;
const CreateBuilderFunction create_builder_;
@@ -132,7 +146,7 @@ Status TensorSliceWriter::Add(const string& name, const TensorShape& shape,
TensorShape saved_shape(ssm->shape());
TensorShape sliced_shape;
TF_RETURN_IF_ERROR(slice.SliceTensorShape(saved_shape, &sliced_shape));
- SaveData(data, sliced_shape.num_elements(), ss);
+ TF_RETURN_IF_ERROR(SaveData(data, sliced_shape.num_elements(), ss));
string key = EncodeTensorNameSlice(name, slice);
// TODO(yangke): consider doing a two-pass thing where the first pass just
// list the tensor slices we want to save and then another pass to actually
@@ -148,11 +162,26 @@ Status TensorSliceWriter::Add(const string& name, const TensorShape& shape,
}
template <typename T>
-void TensorSliceWriter::SaveData(const T* data, int num_elements,
- SavedSlice* ss) {
+Status TensorSliceWriter::SaveData(const T* data, int num_elements,
+ SavedSlice* ss) {
+ size_t size_bound =
+ ss->ByteSize() + kTensorProtoHeaderBytes +
+ (MaxBytesPerElement(DataTypeToEnum<T>::value) * num_elements);
+ if (size_bound > kMaxMessageBytes) {
+ return errors::InvalidArgument(
+ "Tensor slice is too large to serialize (conservative estimate: ",
+ size_bound, " bytes)");
+ }
Fill(data, num_elements, ss->mutable_data());
+ DCHECK_GE(ss->ByteSize(), 0);
+ DCHECK_LE(ss->ByteSize(), size_bound);
+ return Status::OK();
}
+template <>
+Status TensorSliceWriter::SaveData(const string* data, int num_elements,
+ SavedSlice* ss);
+
// Create a table builder that will write to "filename" in
// tensorflow::io::Table format. If successful, return OK
// and set "*builder" to the allocated builder. Otherwise, return a
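A note on the header accounting in the comment above: a protobuf length prefix is itself a varint, and any length below 2^35 fits in 5 varint bytes, so with messages capped at 2^31 the true framing overhead is at most 1 + 5 + 1 + 5 = 12 bytes. Rounding that up to 1KB costs nothing at this scale and guards against future additions to the proto. A compile-time restatement of that bound, under the same assumptions:

#include <cstddef>

constexpr size_t kExactHeaderBound = 1 + 5 + 1 + 5;  // tags + length varints
constexpr size_t kTensorProtoHeaderBytes = 1 << 10;  // the 1KB slack constant
static_assert(kTensorProtoHeaderBytes >= kExactHeaderBound,
              "the 1KB slack must cover the worst-case proto framing");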
diff --git a/tensorflow/core/util/tensor_slice_writer_test.cc b/tensorflow/core/util/tensor_slice_writer_test.cc
index cfa4b04803..ce8398913d 100644
--- a/tensorflow/core/util/tensor_slice_writer_test.cc
+++ b/tensorflow/core/util/tensor_slice_writer_test.cc
@@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/core/util/tensor_slice_writer.h"
+#include <array>
+
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/logging.h"
@@ -263,6 +265,85 @@ void TensorSliceWriteTestHelper::CheckEntries(const string& fname) {
}
}
+template <typename DT>
+size_t BytesPerElementHelper(DT value) {
+ SavedSlice ss;
+ std::array<DT, 1> lo_data;
+ std::fill(lo_data.begin(), lo_data.end(), value);
+ TensorSliceWriter::SaveData(lo_data.data(), lo_data.size(), &ss);
+ int lo_byte_size = ss.ByteSize();
+
+ std::array<DT, 1001> hi_data;
+ std::fill(hi_data.begin(), hi_data.end(), value);
+ TensorSliceWriter::SaveData(hi_data.data(), hi_data.size(), &ss);
+ int hi_byte_size = ss.ByteSize();
+
+ return (hi_byte_size - lo_byte_size) / (hi_data.size() - lo_data.size());
+}
+
+TEST(TensorSliceWriteTest, CheckpointSize) {
+ EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_BOOL),
+ BytesPerElementHelper<bool>(false));
+ EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_BOOL),
+ BytesPerElementHelper<bool>(true));
+ EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_FLOAT),
+ BytesPerElementHelper<float>(-1.0));
+ EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_DOUBLE),
+ BytesPerElementHelper<double>(-1.0));
+ EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_COMPLEX64),
+ BytesPerElementHelper<complex64>(-1.0));
+ EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_COMPLEX128),
+ BytesPerElementHelper<complex128>(-1.0));
+ EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_INT32),
+ BytesPerElementHelper<int32>(-1));
+ EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_INT64),
+ BytesPerElementHelper<int64>(-1));
+ EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_UINT16),
+ BytesPerElementHelper<uint16>(std::numeric_limits<uint16>::max()));
+ EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_UINT8),
+ BytesPerElementHelper<uint8>(std::numeric_limits<uint8>::max()));
+ EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_INT8),
+ BytesPerElementHelper<int8>(-1));
+ EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_INT16),
+ BytesPerElementHelper<int16>(-1));
+ EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_QINT8),
+ BytesPerElementHelper<qint8>(-1));
+ EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_QUINT8),
+ BytesPerElementHelper<quint8>(std::numeric_limits<uint8>::max()));
+ EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_QINT32),
+ BytesPerElementHelper<qint32>(-1));
+ EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_HALF),
+ BytesPerElementHelper<Eigen::half>(Eigen::half(-1.0)));
+}
+
+TEST(TensorSliceWriteTest, SizeErrors) {
+ const string filename = io::JoinPath(testing::TmpDir(), "checkpoint");
+
+ TensorSliceWriter writer(filename, CreateTableTensorSliceBuilder);
+
+ // Add a 300MB int8 tensor slice, which will fail because it expands to 3GB.
+ {
+ TensorShape shape({300, 1000000});
+ TensorSlice slice = TensorSlice::ParseOrDie("-:-");
+ const std::vector<int8> data(300000000, -1);
+ Status s = writer.Add("test1", shape, slice, data.data());
+ EXPECT_EQ(s.code(), error::INVALID_ARGUMENT);
+ EXPECT_TRUE(StringPiece(s.error_message())
+ .contains("Tensor slice is too large to serialize"));
+ }
+
+ // Add a large string tensor slice, which will fail.
+ {
+ TensorShape shape({100, 1000000});
+ TensorSlice slice = TensorSlice::ParseOrDie("-:-");
+ const std::vector<string> data(100000000, "rhubarbrhubarb");
+ Status s = writer.Add("test2", shape, slice, data.data());
+ EXPECT_EQ(s.code(), error::INVALID_ARGUMENT);
+ EXPECT_TRUE(StringPiece(s.error_message())
+ .contains("Tensor slice is too large to serialize"));
+ }
+}
+
} // namespace checkpoint
} // namespace tensorflow
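Two notes on the tests above. BytesPerElementHelper measures the marginal wire cost of one element by serializing the same value 1 and 1001 times and dividing the byte-size difference by 1000; the fixed SavedSlice framing appears in both measurements and cancels out. The test values are deliberate worst cases (-1 for the signed types, the type maximum for the unsigned types), so the measured cost should equal MaxBytesPerElement exactly rather than merely stay below it. For SizeErrors, the first case trips the bound because each int8 element is budgeted at the 10-byte varint worst case; a compile-time restatement of that arithmetic (kMaxMessageBytes mirrors the constant in tensor_slice_writer.h):

#include <cstdint>

constexpr uint64_t kMaxMessageBytes = uint64_t{1} << 31;  // 2,147,483,648
constexpr uint64_t kInt8Elements = 300ULL * 1000000ULL;   // 300MB of int8 data
constexpr uint64_t kSizeBound = kInt8Elements * 10;       // 3,000,000,000
static_assert(kSizeBound > kMaxMessageBytes,
              "a 300M-element int8 slice must be rejected");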
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index a975506781..f44ec417a5 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -287,6 +287,33 @@ class SaverTest(tf.test.TestCase):
expected_save_path = "%s-%d" % (save_path, global_step_int)
self.assertEqual(expected_save_path, val)
+ def testLargeVariable(self):
+ save_path = os.path.join(self.get_temp_dir(), "large_variable")
+ with tf.Session("", graph=tf.Graph()) as sess:
+ # Declare a variable larger than 2GB.
+ with tf.device("/cpu:0"):
+ var = tf.Variable(tf.constant(-1, shape=[300, 1000000], dtype=tf.int8))
+ save = tf.train.Saver({var.op.name: var})
+ var.initializer.run()
+ with self.assertRaisesRegexp(
+ tf.errors.InvalidArgumentError,
+ "Tensor slice is too large to serialize"):
+ save.save(sess, save_path)
+
+ with tf.Session("", graph=tf.Graph()) as sess:
+ # Declare a variable that is exactly 2GB. This should fail,
+ # because a serialized checkpoint includes other header
+ # metadata.
+ with tf.device("/cpu:0"):
+ var = tf.Variable(
+ tf.constant(False, shape=[2, 1024, 1024, 1024], dtype=tf.bool))
+ save = tf.train.Saver({var.op.name: var})
+ var.initializer.run()
+ with self.assertRaisesRegexp(
+ tf.errors.InvalidArgumentError,
+ "Tensor slice is too large to serialize"):
+ save.save(sess, save_path)
+
class SaveRestoreShardedTest(tf.test.TestCase):
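The second case in testLargeVariable above is the boundary check: 2 * 1024^3 bool elements serialize to exactly 2^31 data bytes, which equals kMaxMessageBytes, so the additional kTensorProtoHeaderBytes of framing slack pushes the conservative bound strictly past the limit and the save is rejected, as the in-test comment explains.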