aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-25 10:14:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-25 10:23:11 -0700
commit0584b943f4eca8a5761480ebb524c930aa808f0d (patch)
treea3a2a611af1c6089add18f3c7e8e10bb60363ca9 /tensorflow/contrib/lite
parent67b9e8eafd826d3f430eb7f6780e815fcac5859e (diff)
An ErrorReporter to be used in tests.
PiperOrigin-RevId: 206012444
Diffstat (limited to 'tensorflow/contrib/lite')
-rw-r--r--tensorflow/contrib/lite/interpreter.cc5
-rw-r--r--tensorflow/contrib/lite/interpreter_test.cc17
-rw-r--r--tensorflow/contrib/lite/model_test.cc10
-rw-r--r--tensorflow/contrib/lite/testing/BUILD4
-rw-r--r--tensorflow/contrib/lite/testing/util.h31
5 files changed, 42 insertions, 25 deletions
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
index 5a5c907b6e..e38597495d 100644
--- a/tensorflow/contrib/lite/interpreter.cc
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -279,8 +279,9 @@ TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels(
int node_index;
TfLiteDelegateParams* params = CreateDelegateParams(delegate, subgraph);
- AddNodeWithParameters(subgraph.input_tensors, subgraph.output_tensors,
- nullptr, 0, params, &registration, &node_index);
+ TF_LITE_ENSURE_STATUS(AddNodeWithParameters(
+ subgraph.input_tensors, subgraph.output_tensors, nullptr, 0, params,
+ &registration, &node_index));
// Initialize the output tensors's delegate-related fields.
for (int tensor_index : subgraph.output_tensors) {
diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc
index 10119903fe..2bf598bad7 100644
--- a/tensorflow/contrib/lite/interpreter_test.cc
+++ b/tensorflow/contrib/lite/interpreter_test.cc
@@ -647,18 +647,6 @@ TEST(BasicInterpreter, AllocateTwice) {
ASSERT_EQ(old_tensor1_ptr, interpreter.tensor(1)->data.raw);
}
-struct TestErrorReporter : public ErrorReporter {
- int Report(const char* format, va_list args) override {
- char buffer[1024];
- int size = vsnprintf(buffer, sizeof(buffer), format, args);
- all_reports += buffer;
- calls++;
- return size;
- }
- int calls = 0;
- std::string all_reports;
-};
-
TEST(BasicInterpreter, TestNullErrorReporter) {
TestErrorReporter reporter;
Interpreter interpreter;
@@ -668,8 +656,9 @@ TEST(BasicInterpreter, TestCustomErrorReporter) {
TestErrorReporter reporter;
Interpreter interpreter(&reporter);
ASSERT_NE(interpreter.Invoke(), kTfLiteOk);
- ASSERT_EQ(reporter.all_reports, "Invoke called on model that is not ready.");
- ASSERT_EQ(reporter.calls, 1);
+ ASSERT_EQ(reporter.error_messages(),
+ "Invoke called on model that is not ready.");
+ ASSERT_EQ(reporter.num_calls(), 1);
}
TEST(BasicInterpreter, TestUnsupportedDelegateFunctions) {
diff --git a/tensorflow/contrib/lite/model_test.cc b/tensorflow/contrib/lite/model_test.cc
index edfdec9315..df4f60d4ad 100644
--- a/tensorflow/contrib/lite/model_test.cc
+++ b/tensorflow/contrib/lite/model_test.cc
@@ -241,14 +241,6 @@ TEST(BasicFlatBufferModel, TestWithNullVerifier) {
"tensorflow/contrib/lite/testdata/test_model.bin", nullptr));
}
-struct TestErrorReporter : public ErrorReporter {
- int Report(const char* format, va_list args) override {
- calls++;
- return 0;
- }
- int calls = 0;
-};
-
// This makes sure the ErrorReporter is marshalled from FlatBufferModel to
// the Interpreter.
TEST(BasicFlatBufferModel, TestCustomErrorReporter) {
@@ -262,7 +254,7 @@ TEST(BasicFlatBufferModel, TestCustomErrorReporter) {
TrivialResolver resolver;
InterpreterBuilder(*model, resolver)(&interpreter);
ASSERT_NE(interpreter->Invoke(), kTfLiteOk);
- ASSERT_EQ(reporter.calls, 1);
+ ASSERT_EQ(reporter.num_calls(), 1);
}
// This makes sure the ErrorReporter is marshalled from FlatBufferModel to
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
index 6c7f494e9b..8a2705950d 100644
--- a/tensorflow/contrib/lite/testing/BUILD
+++ b/tensorflow/contrib/lite/testing/BUILD
@@ -210,6 +210,10 @@ cc_library(
cc_library(
name = "util",
hdrs = ["util.h"],
+ deps = [
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:string",
+ ],
)
cc_test(
diff --git a/tensorflow/contrib/lite/testing/util.h b/tensorflow/contrib/lite/testing/util.h
index 6d20aec141..8aa639157b 100644
--- a/tensorflow/contrib/lite/testing/util.h
+++ b/tensorflow/contrib/lite/testing/util.h
@@ -15,8 +15,39 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_UTIL_H_
#define TENSORFLOW_CONTRIB_LITE_TESTING_UTIL_H_
+#include <cstdio>
+
+#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/string.h"
+
namespace tflite {
+// An ErrorReporter that collects error message in a string, in addition
+// to printing to stderr.
+class TestErrorReporter : public ErrorReporter {
+ public:
+ int Report(const char* format, va_list args) override {
+ char buffer[1024];
+ int size = vsnprintf(buffer, sizeof(buffer), format, args);
+ fprintf(stderr, "%s", buffer);
+ error_messages_ += buffer;
+ num_calls_++;
+ return size;
+ }
+
+ void Reset() {
+ num_calls_ = 0;
+ error_messages_.clear();
+ }
+
+ int num_calls() const { return num_calls_; }
+ const string& error_messages() const { return error_messages_; }
+
+ private:
+ int num_calls_ = 0;
+ string error_messages_;
+};
+
inline void LogToStderr() {
#ifdef PLATFORM_GOOGLE
FLAGS_logtostderr = true;