diff options
author | 2016-06-14 22:47:37 -0800 | |
---|---|---|
committer | 2016-06-15 00:02:42 -0700 | |
commit | cc40cd3b0a8b83f5ee071b7ee32c17b56815a89c (patch) | |
tree | ebbfdda9dfda2d2ebbe96e5f93053e4597c1a37d /tensorflow/core/util/tensor_slice_writer.cc | |
parent | 9e27c607dc6ab118eb4fe11ffdadfd79fd9eb3b4 (diff) |
Add a size check before attempting to serialize a variable.
This prevents the TensorSliceWriter from attempting to serialize
variables that are larger than 2GB. It prevents potential memory
corruption and segmentation faults.
Fixes #2447.
Change: 124921899
Diffstat (limited to 'tensorflow/core/util/tensor_slice_writer.cc')
-rw-r--r-- | tensorflow/core/util/tensor_slice_writer.cc | 65 |
1 files changed, 65 insertions, 0 deletions
diff --git a/tensorflow/core/util/tensor_slice_writer.cc b/tensorflow/core/util/tensor_slice_writer.cc index 204d3b164a..74fcbbe649 100644 --- a/tensorflow/core/util/tensor_slice_writer.cc +++ b/tensorflow/core/util/tensor_slice_writer.cc @@ -126,6 +126,71 @@ Status TensorSliceWriter::Finish() { return s; } +/* static */ +size_t TensorSliceWriter::MaxBytesPerElement(DataType dt) { + switch (dt) { + case DT_FLOAT: + return 4; + case DT_DOUBLE: + return 8; + case DT_INT32: + return 10; + case DT_UINT8: + return 2; + case DT_INT16: + return 10; + case DT_INT8: + return 10; + case DT_COMPLEX64: + return 8; + case DT_INT64: + return 10; + case DT_BOOL: + return 1; + case DT_QINT8: + return 10; + case DT_QUINT8: + return 2; + case DT_QINT32: + return 10; + case DT_QINT16: + return 10; + case DT_QUINT16: + return 3; + case DT_UINT16: + return 3; + case DT_COMPLEX128: + return 16; + case DT_HALF: + return 3; + case DT_INVALID: + case DT_STRING: + case DT_BFLOAT16: + default: + CHECK(false) << "MaxBytesPerElement not implemented for dtype: " << dt; + } + return 0; +} + +template <> +Status TensorSliceWriter::SaveData(const string* data, int num_elements, + SavedSlice* ss) { + size_t size_bound = ss->ByteSize() + kTensorProtoHeaderBytes + + (num_elements * MaxBytesPerElement(DT_INT32)); + for (int i = 0; i < num_elements; ++i) { + size_bound += data[i].size(); + } + if (size_bound > kMaxMessageBytes) { + return errors::InvalidArgument( + "Tensor slice is too large to serialize (conservative estimate: ", + size_bound, " bytes)"); + } + Fill(data, num_elements, ss->mutable_data()); + DCHECK_GE(ss->ByteSize(), 0); + DCHECK_LE(ss->ByteSize(), size_bound); + return Status::OK(); +} + } // namespace checkpoint } // namespace tensorflow |