diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-07-25 10:14:24 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-25 10:23:11 -0700 |
commit | 0584b943f4eca8a5761480ebb524c930aa808f0d (patch) | |
tree | a3a2a611af1c6089add18f3c7e8e10bb60363ca9 /tensorflow/contrib/lite | |
parent | 67b9e8eafd826d3f430eb7f6780e815fcac5859e (diff) |
An ErrorReporter to be used in tests.
PiperOrigin-RevId: 206012444
Diffstat (limited to 'tensorflow/contrib/lite')
-rw-r--r-- | tensorflow/contrib/lite/interpreter.cc | 5 | ||||
-rw-r--r-- | tensorflow/contrib/lite/interpreter_test.cc | 17 | ||||
-rw-r--r-- | tensorflow/contrib/lite/model_test.cc | 10 | ||||
-rw-r--r-- | tensorflow/contrib/lite/testing/BUILD | 4 | ||||
-rw-r--r-- | tensorflow/contrib/lite/testing/util.h | 31 |
5 files changed, 42 insertions, 25 deletions
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc index 5a5c907b6e..e38597495d 100644 --- a/tensorflow/contrib/lite/interpreter.cc +++ b/tensorflow/contrib/lite/interpreter.cc @@ -279,8 +279,9 @@ TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels( int node_index; TfLiteDelegateParams* params = CreateDelegateParams(delegate, subgraph); - AddNodeWithParameters(subgraph.input_tensors, subgraph.output_tensors, - nullptr, 0, params, ®istration, &node_index); + TF_LITE_ENSURE_STATUS(AddNodeWithParameters( + subgraph.input_tensors, subgraph.output_tensors, nullptr, 0, params, + ®istration, &node_index)); // Initialize the output tensors's delegate-related fields. for (int tensor_index : subgraph.output_tensors) { diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc index 10119903fe..2bf598bad7 100644 --- a/tensorflow/contrib/lite/interpreter_test.cc +++ b/tensorflow/contrib/lite/interpreter_test.cc @@ -647,18 +647,6 @@ TEST(BasicInterpreter, AllocateTwice) { ASSERT_EQ(old_tensor1_ptr, interpreter.tensor(1)->data.raw); } -struct TestErrorReporter : public ErrorReporter { - int Report(const char* format, va_list args) override { - char buffer[1024]; - int size = vsnprintf(buffer, sizeof(buffer), format, args); - all_reports += buffer; - calls++; - return size; - } - int calls = 0; - std::string all_reports; -}; - TEST(BasicInterpreter, TestNullErrorReporter) { TestErrorReporter reporter; Interpreter interpreter; @@ -668,8 +656,9 @@ TEST(BasicInterpreter, TestCustomErrorReporter) { TestErrorReporter reporter; Interpreter interpreter(&reporter); ASSERT_NE(interpreter.Invoke(), kTfLiteOk); - ASSERT_EQ(reporter.all_reports, "Invoke called on model that is not ready."); - ASSERT_EQ(reporter.calls, 1); + ASSERT_EQ(reporter.error_messages(), + "Invoke called on model that is not ready."); + ASSERT_EQ(reporter.num_calls(), 1); } TEST(BasicInterpreter, TestUnsupportedDelegateFunctions) { diff --git a/tensorflow/contrib/lite/model_test.cc b/tensorflow/contrib/lite/model_test.cc index edfdec9315..df4f60d4ad 100644 --- a/tensorflow/contrib/lite/model_test.cc +++ b/tensorflow/contrib/lite/model_test.cc @@ -241,14 +241,6 @@ TEST(BasicFlatBufferModel, TestWithNullVerifier) { "tensorflow/contrib/lite/testdata/test_model.bin", nullptr)); } -struct TestErrorReporter : public ErrorReporter { - int Report(const char* format, va_list args) override { - calls++; - return 0; - } - int calls = 0; -}; - // This makes sure the ErrorReporter is marshalled from FlatBufferModel to // the Interpreter. TEST(BasicFlatBufferModel, TestCustomErrorReporter) { @@ -262,7 +254,7 @@ TEST(BasicFlatBufferModel, TestCustomErrorReporter) { TrivialResolver resolver; InterpreterBuilder(*model, resolver)(&interpreter); ASSERT_NE(interpreter->Invoke(), kTfLiteOk); - ASSERT_EQ(reporter.calls, 1); + ASSERT_EQ(reporter.num_calls(), 1); } // This makes sure the ErrorReporter is marshalled from FlatBufferModel to diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD index 6c7f494e9b..8a2705950d 100644 --- a/tensorflow/contrib/lite/testing/BUILD +++ b/tensorflow/contrib/lite/testing/BUILD @@ -210,6 +210,10 @@ cc_library( cc_library( name = "util", hdrs = ["util.h"], + deps = [ + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string", + ], ) cc_test( diff --git a/tensorflow/contrib/lite/testing/util.h b/tensorflow/contrib/lite/testing/util.h index 6d20aec141..8aa639157b 100644 --- a/tensorflow/contrib/lite/testing/util.h +++ b/tensorflow/contrib/lite/testing/util.h @@ -15,8 +15,39 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_TESTING_UTIL_H_ #define TENSORFLOW_CONTRIB_LITE_TESTING_UTIL_H_ +#include <cstdio> + +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/string.h" + namespace tflite { +// An ErrorReporter that collects error message in a string, in addition +// to printing to stderr. +class TestErrorReporter : public ErrorReporter { + public: + int Report(const char* format, va_list args) override { + char buffer[1024]; + int size = vsnprintf(buffer, sizeof(buffer), format, args); + fprintf(stderr, "%s", buffer); + error_messages_ += buffer; + num_calls_++; + return size; + } + + void Reset() { + num_calls_ = 0; + error_messages_.clear(); + } + + int num_calls() const { return num_calls_; } + const string& error_messages() const { return error_messages_; } + + private: + int num_calls_ = 0; + string error_messages_; +}; + inline void LogToStderr() { #ifdef PLATFORM_GOOGLE FLAGS_logtostderr = true; |