diff options
author | 2018-03-27 16:26:12 -0700 | |
---|---|---|
committer | 2018-03-27 16:29:08 -0700 | |
commit | 496840acbdd8b8b7688c257793e09a02229d21f6 (patch) | |
tree | b9f731fe7dec90b59234377c5e751f5e22531c90 /tensorflow/contrib/lite/kernels/fully_connected_test.cc | |
parent | a16761483ec55095158b1b11118d93ea00a538f4 (diff) |
Test all TFLite kernel implementations for fully connected.
PiperOrigin-RevId: 190693455
Diffstat (limited to 'tensorflow/contrib/lite/kernels/fully_connected_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/fully_connected_test.cc | 57 |
1 files changed, 47 insertions, 10 deletions
diff --git a/tensorflow/contrib/lite/kernels/fully_connected_test.cc b/tensorflow/contrib/lite/kernels/fully_connected_test.cc index a0f766c4f4..87413000a9 100644 --- a/tensorflow/contrib/lite/kernels/fully_connected_test.cc +++ b/tensorflow/contrib/lite/kernels/fully_connected_test.cc @@ -19,12 +19,25 @@ limitations under the License. #include <gmock/gmock.h> #include <gtest/gtest.h> +#include "absl/memory/memory.h" #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/kernels/test_util.h" #include "tensorflow/contrib/lite/model.h" namespace tflite { + +namespace ops { +namespace builtin { + +TfLiteRegistration* Register_FULLY_CONNECTED_REF(); +TfLiteRegistration* Register_FULLY_CONNECTED_NEON_OPT(); +TfLiteRegistration* Register_FULLY_CONNECTED_GENERIC_OPT(); +TfLiteRegistration* Register_FULLY_CONNECTED_PIE(); + +} // namespace builtin +} // namespace ops + namespace { using ::testing::ElementsAre; @@ -119,7 +132,8 @@ static float fully_connected_golden_output[] = { class BaseFullyConnectedOpModel : public SingleOpModel { public: // TODO(ahentz): test different activation types too. - BaseFullyConnectedOpModel(int units, int batches, const TensorData& input, + BaseFullyConnectedOpModel(TfLiteRegistration* registration, int units, + int batches, const TensorData& input, const TensorData& output = {TensorType_FLOAT32}) : batches_(batches), units_(units) { int total_input_size = 1; @@ -149,6 +163,8 @@ class BaseFullyConnectedOpModel : public SingleOpModel { BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions, CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU) .Union()); + resolver_ = absl::make_unique<SingleOpResolver>( + BuiltinOperator_FULLY_CONNECTED, registration); BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)}); } @@ -208,10 +224,25 @@ class QuantizedFullyConnectedOpModel : public BaseFullyConnectedOpModel { } }; +const auto kKernelMap = new std::map<string, TfLiteRegistration*>({ + {"Reference", ops::builtin::Register_FULLY_CONNECTED_REF()}, + {"NeonOptimized", ops::builtin::Register_FULLY_CONNECTED_NEON_OPT()}, + {"GenericOptimized", ops::builtin::Register_FULLY_CONNECTED_GENERIC_OPT()}, + {"Pie", ops::builtin::Register_FULLY_CONNECTED_PIE()}, +}); + +class FullyConnectedOpTest : public SingleOpTest { + protected: + const std::map<string, TfLiteRegistration*>& GetKernelMap() override { + return *kKernelMap; + } +}; + // TODO(ahentz): add more small tests like this one, focused on making sure the // calculations are correct. -TEST(FullyConnectedOpTest, SimpleTest) { - FloatFullyConnectedOpModel m(3, 2, {TensorType_FLOAT32, {2, 10}}); +TEST_P(FullyConnectedOpTest, SimpleTest) { + FloatFullyConnectedOpModel m(GetRegistration(), 3, 2, + {TensorType_FLOAT32, {2, 10}}); m.SetWeights({ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 @@ -229,9 +260,9 @@ TEST(FullyConnectedOpTest, SimpleTest) { EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60)); } -TEST(FullyConnectedOpTest, SimpleTestQuantized) { +TEST_P(FullyConnectedOpTest, SimpleTestQuantized) { QuantizedFullyConnectedOpModel m( - 3, 2, + GetRegistration(), 3, 2, /*input=*/{TensorType_UINT8, {2, 10}, -63.5, 64}, /*output=*/{TensorType_UINT8, {}, -127, 128}); @@ -261,7 +292,8 @@ TEST(FullyConnectedOpTest, SimpleTest4DInput) { // Note that it is not required that the first dimension be the number of // batches. All we care is that the input can be evenly distributed in // batches. In this case, we need the input to have multiples of '2'. - FloatFullyConnectedOpModel m(/*units=*/3, + FloatFullyConnectedOpModel m(ops::builtin::Register_FULLY_CONNECTED_PIE(), + /*units=*/3, /*batches=*/2, /*input=*/{TensorType_FLOAT32, {4, 1, 5, 1}}); m.SetWeights({ @@ -284,9 +316,9 @@ TEST(FullyConnectedOpTest, SimpleTest4DInput) { })); } -TEST(FullyConnectedOpTest, SimpleTest4dInputQuantized) { +TEST_P(FullyConnectedOpTest, SimpleTest4dInputQuantized) { QuantizedFullyConnectedOpModel m( - 3, 2, + GetRegistration(), 3, 2, /*input=*/{TensorType_UINT8, {4, 1, 5, 1}, -63.5, 64}, /*output=*/{TensorType_UINT8, {}, -127, 128}); @@ -312,10 +344,15 @@ TEST(FullyConnectedOpTest, SimpleTest4dInputQuantized) { EXPECT_THAT(m.GetOutput(), ElementsAre(151, 152, 153, 185, 186, 187)); } +INSTANTIATE_TEST_CASE_P( + FullyConnectedOpTest, FullyConnectedOpTest, + ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap))); + // TODO(ahentz): Reconsider this test. Having arbitrary weights makes it hard // to debug errors and doesn't necessarily test all the important details. -TEST(FullyConnectedOpTest, BlackBoxTest) { - FloatFullyConnectedOpModel m(16, 2, {TensorType_FLOAT32, {2, 8}}); +TEST_P(FullyConnectedOpTest, BlackBoxTest) { + FloatFullyConnectedOpModel m(GetRegistration(), 16, 2, + {TensorType_FLOAT32, {2, 8}}); m.SetWeights( {0.091327, 0.103366, -0.316505, -0.083120, 0.149366, -0.196636, -0.123672, 0.062800, 0.063031, 0.191670, -0.062001, -0.061504, |