author Derek Murray <mrry@google.com> 2016-06-14 22:47:37 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2016-06-15 00:02:42 -0700
commit cc40cd3b0a8b83f5ee071b7ee32c17b56815a89c (patch)
tree ebbfdda9dfda2d2ebbe96e5f93053e4597c1a37d /tensorflow/core/util/tensor_slice_writer.cc
parent 9e27c607dc6ab118eb4fe11ffdadfd79fd9eb3b4 (diff)
Add a size check before attempting to serialize a variable.
This prevents the TensorSliceWriter from attempting to serialize variables that are larger than 2GB, which could otherwise lead to memory corruption and segmentation faults. Fixes #2447. Change: 124921899
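The guard works by computing a conservative upper bound on the serialized size before filling the SavedSlice proto, and returning an error if that bound would exceed protobuf's 2GB message limit. Below is a minimal standalone sketch of the same pattern for string data; the constant values are assumptions chosen to mirror kMaxMessageBytes and kTensorProtoHeaderBytes from tensor_slice_writer.h, and FitsInOneMessage is a hypothetical helper for illustration, not TensorFlow API.

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Assumed to mirror tensor_slice_writer.h: protobuf cannot reliably
// serialize messages of 2GB or more, so cap the estimate at 2^31 bytes.
static const size_t kMaxMessageBytes = 1ULL << 31;
// Assumed fixed overhead for the TensorProto wrapper around the data.
static const size_t kTensorProtoHeaderBytes = 1 << 10;

// Conservative bound for a string slice: proto header, plus a
// worst-case varint length prefix per element, plus the payload bytes.
// Overestimating is fine; the check only needs to err on the safe side.
bool FitsInOneMessage(const std::vector<size_t>& element_sizes) {
  const size_t kMaxBytesPerLengthPrefix = 10;  // worst-case varint size
  size_t size_bound = kTensorProtoHeaderBytes +
                      element_sizes.size() * kMaxBytesPerLengthPrefix;
  for (size_t n : element_sizes) size_bound += n;
  return size_bound <= kMaxMessageBytes;
}

int main() {
  const size_t kGB = size_t{1} << 30;
  // One 1GB string fits; three of them exceed the 2GB message limit.
  std::printf("%d\n", FitsInOneMessage({kGB}));            // prints 1
  std::printf("%d\n", FitsInOneMessage({kGB, kGB, kGB}));  // prints 0
  return 0;
}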
Diffstat (limited to 'tensorflow/core/util/tensor_slice_writer.cc')
-rw-r--r-- tensorflow/core/util/tensor_slice_writer.cc | 65
1 file changed, 65 insertions(+), 0 deletions(-)
diff --git a/tensorflow/core/util/tensor_slice_writer.cc b/tensorflow/core/util/tensor_slice_writer.cc
index 204d3b164a..74fcbbe649 100644
--- a/tensorflow/core/util/tensor_slice_writer.cc
+++ b/tensorflow/core/util/tensor_slice_writer.cc
@@ -126,6 +126,71 @@ Status TensorSliceWriter::Finish() {
return s;
}
+/* static */
+size_t TensorSliceWriter::MaxBytesPerElement(DataType dt) {
+ switch (dt) {
+ case DT_FLOAT:
+ return 4;
+ case DT_DOUBLE:
+ return 8;
+ case DT_INT32:
+ return 10;
+ case DT_UINT8:
+ return 2;
+ case DT_INT16:
+ return 10;
+ case DT_INT8:
+ return 10;
+ case DT_COMPLEX64:
+ return 8;
+ case DT_INT64:
+ return 10;
+ case DT_BOOL:
+ return 1;
+ case DT_QINT8:
+ return 10;
+ case DT_QUINT8:
+ return 2;
+ case DT_QINT32:
+ return 10;
+ case DT_QINT16:
+ return 10;
+ case DT_QUINT16:
+ return 3;
+ case DT_UINT16:
+ return 3;
+ case DT_COMPLEX128:
+ return 16;
+ case DT_HALF:
+ return 3;
+ case DT_INVALID:
+ case DT_STRING:
+ case DT_BFLOAT16:
+ default:
+ CHECK(false) << "MaxBytesPerElement not implemented for dtype: " << dt;
+ }
+ return 0;
+}
+
+template <>
+Status TensorSliceWriter::SaveData(const string* data, int num_elements,
+ SavedSlice* ss) {
+ size_t size_bound = ss->ByteSize() + kTensorProtoHeaderBytes +
+ (num_elements * MaxBytesPerElement(DT_INT32));
+ for (int i = 0; i < num_elements; ++i) {
+ size_bound += data[i].size();
+ }
+ if (size_bound > kMaxMessageBytes) {
+ return errors::InvalidArgument(
+ "Tensor slice is too large to serialize (conservative estimate: ",
+ size_bound, " bytes)");
+ }
+ Fill(data, num_elements, ss->mutable_data());
+ DCHECK_GE(ss->ByteSize(), 0);
+ DCHECK_LE(ss->ByteSize(), size_bound);
+ return Status::OK();
+}
+
} // namespace checkpoint
} // namespace tensorflow
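A note on the MaxBytesPerElement table in the patch: protobuf encodes integer fields as varints, and negative values are sign-extended to 64 bits before encoding, so DT_INT32, DT_INT16, DT_INT8, DT_INT64, and the signed quantized types all share the 10-byte worst case regardless of their in-memory width, while nonnegative types bound out at roughly one byte per 7 bits (2 for uint8, 3 for uint16 and half). A self-contained sketch of that arithmetic follows; VarintSize is an illustrative helper written here, not TensorFlow or protobuf API.

#include <cstdint>
#include <cstdio>

// Bytes needed to encode v as a protobuf varint: 7 payload bits per
// byte, so the size is the significant bit count divided by 7,
// rounded up (minimum one byte).
int VarintSize(uint64_t v) {
  int n = 1;
  while (v >= 0x80) {
    v >>= 7;
    ++n;
  }
  return n;
}

int main() {
  // Negative integers are sign-extended to 64 bits before encoding,
  // so -1 hits the full 10-byte worst case for any signed dtype.
  std::printf("%d\n", VarintSize(static_cast<uint64_t>(int64_t{-1})));  // 10
  std::printf("%d\n", VarintSize(255));    // 2: matches the DT_UINT8 entry
  std::printf("%d\n", VarintSize(65535));  // 3: matches the DT_UINT16 entry
  return 0;
}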