diff options
Diffstat (limited to 'tensorflow/contrib/lite/testing/util.h')
-rw-r--r-- | tensorflow/contrib/lite/testing/util.h | 31 |
1 files changed, 31 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/testing/util.h b/tensorflow/contrib/lite/testing/util.h index 6d20aec141..8aa639157b 100644 --- a/tensorflow/contrib/lite/testing/util.h +++ b/tensorflow/contrib/lite/testing/util.h @@ -15,8 +15,39 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_TESTING_UTIL_H_ #define TENSORFLOW_CONTRIB_LITE_TESTING_UTIL_H_ +#include <cstdio> + +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/string.h" + namespace tflite { +// An ErrorReporter that collects error message in a string, in addition +// to printing to stderr. +class TestErrorReporter : public ErrorReporter { + public: + int Report(const char* format, va_list args) override { + char buffer[1024]; + int size = vsnprintf(buffer, sizeof(buffer), format, args); + fprintf(stderr, "%s", buffer); + error_messages_ += buffer; + num_calls_++; + return size; + } + + void Reset() { + num_calls_ = 0; + error_messages_.clear(); + } + + int num_calls() const { return num_calls_; } + const string& error_messages() const { return error_messages_; } + + private: + int num_calls_ = 0; + string error_messages_; +}; + inline void LogToStderr() { #ifdef PLATFORM_GOOGLE FLAGS_logtostderr = true; |