aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/fully_connected_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-27 16:26:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-27 16:29:08 -0700
commit496840acbdd8b8b7688c257793e09a02229d21f6 (patch)
treeb9f731fe7dec90b59234377c5e751f5e22531c90 /tensorflow/contrib/lite/kernels/fully_connected_test.cc
parenta16761483ec55095158b1b11118d93ea00a538f4 (diff)
Test all TFLite kernel implementations for fully connected.
PiperOrigin-RevId: 190693455
Diffstat (limited to 'tensorflow/contrib/lite/kernels/fully_connected_test.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/fully_connected_test.cc57
1 files changed, 47 insertions, 10 deletions
diff --git a/tensorflow/contrib/lite/kernels/fully_connected_test.cc b/tensorflow/contrib/lite/kernels/fully_connected_test.cc
index a0f766c4f4..87413000a9 100644
--- a/tensorflow/contrib/lite/kernels/fully_connected_test.cc
+++ b/tensorflow/contrib/lite/kernels/fully_connected_test.cc
@@ -19,12 +19,25 @@ limitations under the License.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
+#include "absl/memory/memory.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
#include "tensorflow/contrib/lite/model.h"
namespace tflite {
+
+namespace ops {
+namespace builtin {
+
+TfLiteRegistration* Register_FULLY_CONNECTED_REF();
+TfLiteRegistration* Register_FULLY_CONNECTED_NEON_OPT();
+TfLiteRegistration* Register_FULLY_CONNECTED_GENERIC_OPT();
+TfLiteRegistration* Register_FULLY_CONNECTED_PIE();
+
+} // namespace builtin
+} // namespace ops
+
namespace {
using ::testing::ElementsAre;
@@ -119,7 +132,8 @@ static float fully_connected_golden_output[] = {
class BaseFullyConnectedOpModel : public SingleOpModel {
public:
// TODO(ahentz): test different activation types too.
- BaseFullyConnectedOpModel(int units, int batches, const TensorData& input,
+ BaseFullyConnectedOpModel(TfLiteRegistration* registration, int units,
+ int batches, const TensorData& input,
const TensorData& output = {TensorType_FLOAT32})
: batches_(batches), units_(units) {
int total_input_size = 1;
@@ -149,6 +163,8 @@ class BaseFullyConnectedOpModel : public SingleOpModel {
BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions,
CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU)
.Union());
+ resolver_ = absl::make_unique<SingleOpResolver>(
+ BuiltinOperator_FULLY_CONNECTED, registration);
BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)});
}
@@ -208,10 +224,25 @@ class QuantizedFullyConnectedOpModel : public BaseFullyConnectedOpModel {
}
};
+const auto kKernelMap = new std::map<string, TfLiteRegistration*>({
+ {"Reference", ops::builtin::Register_FULLY_CONNECTED_REF()},
+ {"NeonOptimized", ops::builtin::Register_FULLY_CONNECTED_NEON_OPT()},
+ {"GenericOptimized", ops::builtin::Register_FULLY_CONNECTED_GENERIC_OPT()},
+ {"Pie", ops::builtin::Register_FULLY_CONNECTED_PIE()},
+});
+
+class FullyConnectedOpTest : public SingleOpTest {
+ protected:
+ const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
+ return *kKernelMap;
+ }
+};
+
// TODO(ahentz): add more small tests like this one, focused on making sure the
// calculations are correct.
-TEST(FullyConnectedOpTest, SimpleTest) {
- FloatFullyConnectedOpModel m(3, 2, {TensorType_FLOAT32, {2, 10}});
+TEST_P(FullyConnectedOpTest, SimpleTest) {
+ FloatFullyConnectedOpModel m(GetRegistration(), 3, 2,
+ {TensorType_FLOAT32, {2, 10}});
m.SetWeights({
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
@@ -229,9 +260,9 @@ TEST(FullyConnectedOpTest, SimpleTest) {
EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60));
}
-TEST(FullyConnectedOpTest, SimpleTestQuantized) {
+TEST_P(FullyConnectedOpTest, SimpleTestQuantized) {
QuantizedFullyConnectedOpModel m(
- 3, 2,
+ GetRegistration(), 3, 2,
/*input=*/{TensorType_UINT8, {2, 10}, -63.5, 64},
/*output=*/{TensorType_UINT8, {}, -127, 128});
@@ -261,7 +292,8 @@ TEST(FullyConnectedOpTest, SimpleTest4DInput) {
// Note that it is not required that the first dimension be the number of
// batches. All we care is that the input can be evenly distributed in
// batches. In this case, we need the input to have multiples of '2'.
- FloatFullyConnectedOpModel m(/*units=*/3,
+ FloatFullyConnectedOpModel m(ops::builtin::Register_FULLY_CONNECTED_PIE(),
+ /*units=*/3,
/*batches=*/2,
/*input=*/{TensorType_FLOAT32, {4, 1, 5, 1}});
m.SetWeights({
@@ -284,9 +316,9 @@ TEST(FullyConnectedOpTest, SimpleTest4DInput) {
}));
}
-TEST(FullyConnectedOpTest, SimpleTest4dInputQuantized) {
+TEST_P(FullyConnectedOpTest, SimpleTest4dInputQuantized) {
QuantizedFullyConnectedOpModel m(
- 3, 2,
+ GetRegistration(), 3, 2,
/*input=*/{TensorType_UINT8, {4, 1, 5, 1}, -63.5, 64},
/*output=*/{TensorType_UINT8, {}, -127, 128});
@@ -312,10 +344,15 @@ TEST(FullyConnectedOpTest, SimpleTest4dInputQuantized) {
EXPECT_THAT(m.GetOutput(), ElementsAre(151, 152, 153, 185, 186, 187));
}
+INSTANTIATE_TEST_CASE_P(
+ FullyConnectedOpTest, FullyConnectedOpTest,
+ ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));
+
// TODO(ahentz): Reconsider this test. Having arbitrary weights makes it hard
// to debug errors and doesn't necessarily test all the important details.
-TEST(FullyConnectedOpTest, BlackBoxTest) {
- FloatFullyConnectedOpModel m(16, 2, {TensorType_FLOAT32, {2, 8}});
+TEST_P(FullyConnectedOpTest, BlackBoxTest) {
+ FloatFullyConnectedOpModel m(GetRegistration(), 16, 2,
+ {TensorType_FLOAT32, {2, 8}});
m.SetWeights(
{0.091327, 0.103366, -0.316505, -0.083120, 0.149366, -0.196636,
-0.123672, 0.062800, 0.063031, 0.191670, -0.062001, -0.061504,