aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc32
1 files changed, 30 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc b/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc
index c90ad2cfeb..ada1235449 100644
--- a/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc
+++ b/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc
@@ -31,9 +31,37 @@ class FuzzParseTensor : public FuzzSession {
}
void FuzzImpl(const uint8_t* data, size_t size) final {
+ // We need to be sure that we don't request too many elements (i.e., we
+ // don't make ASAN OOM). In theory, a tensor shape can have arbitrary large
+ // number of elements, up to the limit of the memory available to the OS.
+ // However, due to the tracing done in ASAN, after 2^32 bytes of requested
+ // memory we would get a crash in the fuzzer (see b/34190148). Hence, let's
+ // try parsing the proto here, check that the size (if valid) is below a
+ // maximum threshold (using 2^20 for convenience), and then run the
+ // remainder of the fuzzer testing. Of course, this duplicates some work
+ // but it's better than repeating the investigation whenever Autofuzz
+ // detects another similar OOM.
+ string as_string = string(reinterpret_cast<const char*>(data), size);
+ TensorProto proto;
+ if (!ParseProtoUnlimited(&proto, as_string)) {
+ LOG(WARNING) << "Unable to parse proto of tensor\n";
+ return;
+ }
+ if (!TensorShape::IsValid(proto.tensor_shape())) {
+ LOG(WARNING) << "Invalid tensor shape\n";
+ return;
+ }
+ TensorShape shape(proto.tensor_shape());
+ const int64 num_elements = shape.num_elements();
+ const int64 max_num_elements = 1 << 20;
+ if (num_elements > max_num_elements) {
+ LOG(WARNING) << "Requiring a tensor with too many elements\n";
+ return;
+ }
+
+ // Now we can do the actual fuzz implementation
Tensor input_tensor(tensorflow::DT_STRING, TensorShape({}));
- input_tensor.scalar<string>()() =
- string(reinterpret_cast<const char*>(data), size);
+ input_tensor.scalar<string>()() = as_string;
// TODO(b/32704451): Don't just ignore the ::tensorflow::Status object!
RunOneInput(input_tensor).IgnoreError();
}