aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/testing/util.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/testing/util.h')
-rw-r--r--tensorflow/contrib/lite/testing/util.h31
1 files changed, 31 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/testing/util.h b/tensorflow/contrib/lite/testing/util.h
index 6d20aec141..8aa639157b 100644
--- a/tensorflow/contrib/lite/testing/util.h
+++ b/tensorflow/contrib/lite/testing/util.h
@@ -15,8 +15,39 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_UTIL_H_
#define TENSORFLOW_CONTRIB_LITE_TESTING_UTIL_H_
+#include <cstdio>
+
+#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/string.h"
+
namespace tflite {
+// An ErrorReporter that collects error message in a string, in addition
+// to printing to stderr.
+class TestErrorReporter : public ErrorReporter {
+ public:
+ int Report(const char* format, va_list args) override {
+ char buffer[1024];
+ int size = vsnprintf(buffer, sizeof(buffer), format, args);
+ fprintf(stderr, "%s", buffer);
+ error_messages_ += buffer;
+ num_calls_++;
+ return size;
+ }
+
+ void Reset() {
+ num_calls_ = 0;
+ error_messages_.clear();
+ }
+
+ int num_calls() const { return num_calls_; }
+ const string& error_messages() const { return error_messages_; }
+
+ private:
+ int num_calls_ = 0;
+ string error_messages_;
+};
+
inline void LogToStderr() {
#ifdef PLATFORM_GOOGLE
FLAGS_logtostderr = true;