aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/svdf_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/kernels/svdf_test.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/svdf_test.cc186
1 files changed, 135 insertions, 51 deletions
diff --git a/tensorflow/contrib/lite/kernels/svdf_test.cc b/tensorflow/contrib/lite/kernels/svdf_test.cc
index 0f166dc69b..5af3ff8500 100644
--- a/tensorflow/contrib/lite/kernels/svdf_test.cc
+++ b/tensorflow/contrib/lite/kernels/svdf_test.cc
@@ -126,17 +126,20 @@ static float svdf_golden_output_rank_2[] = {
};
// Derived class of SingleOpModel, which is used to test SVDF TFLite op.
-class SVDFOpModel : public SingleOpModel {
+class BaseSVDFOpModel : public SingleOpModel {
public:
- SVDFOpModel(int batches, int units, int input_size, int memory_size, int rank)
+ BaseSVDFOpModel(int batches, int units, int input_size, int memory_size,
+ int rank,
+ TensorType weights_feature_type = TensorType_FLOAT32,
+ TensorType weights_time_type = TensorType_FLOAT32)
: batches_(batches),
units_(units),
input_size_(input_size),
memory_size_(memory_size),
rank_(rank) {
input_ = AddInput(TensorType_FLOAT32);
- weights_feature_ = AddInput(TensorType_FLOAT32);
- weights_time_ = AddInput(TensorType_FLOAT32);
+ weights_feature_ = AddInput(weights_feature_type);
+ weights_time_ = AddInput(weights_time_type);
bias_ = AddNullInput();
state_ = AddOutput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
@@ -182,7 +185,7 @@ class SVDFOpModel : public SingleOpModel {
int num_units() { return units_; }
int num_batches() { return batches_; }
- private:
+ protected:
int input_;
int weights_feature_;
int weights_time_;
@@ -197,7 +200,61 @@ class SVDFOpModel : public SingleOpModel {
int rank_;
};
-TEST(SVDFOpTest, BlackBoxTestRank1) {
+class SVDFOpModel : public BaseSVDFOpModel {
+ public:
+ using BaseSVDFOpModel::BaseSVDFOpModel;
+};
+
+class HybridSVDFOpModel : public BaseSVDFOpModel {
+ public:
+ HybridSVDFOpModel(int batches, int units, int input_size, int memory_size,
+ int rank)
+ : BaseSVDFOpModel(batches, units, input_size, memory_size, rank,
+ TensorType_UINT8, TensorType_UINT8) {}
+
+ void SetWeightsFeature(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(weights_feature_, f);
+ }
+
+ void SetWeightsTime(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(weights_time_, f);
+ }
+};
+
+class SVDFOpTest : public ::testing::Test {
+ protected:
+ void VerifyGoldens(float golden_input[], float golden_output[],
+ int golden_size, BaseSVDFOpModel* svdf,
+ float tolerance = 1e-5) {
+ const int svdf_num_batches = svdf->num_batches();
+ const int svdf_input_size = svdf->input_size();
+ const int svdf_num_units = svdf->num_units();
+ const int input_sequence_size =
+ golden_size / sizeof(float) / (svdf_input_size * svdf_num_batches);
+ // Going over each input batch, setting the input tensor, invoking the SVDF
+ // op and checking the output with the expected golden values.
+ for (int i = 0; i < input_sequence_size; i++) {
+ float* batch_start =
+ golden_input + i * svdf_input_size * svdf_num_batches;
+ float* batch_end = batch_start + svdf_input_size * svdf_num_batches;
+ svdf->SetInput(0, batch_start, batch_end);
+
+ svdf->Invoke();
+
+ const float* golden_start =
+ golden_output + i * svdf_num_units * svdf_num_batches;
+ const float* golden_end =
+ golden_start + svdf_num_units * svdf_num_batches;
+ std::vector<float> expected;
+ expected.insert(expected.end(), golden_start, golden_end);
+
+ EXPECT_THAT(svdf->GetOutput(),
+ ElementsAreArray(ArrayFloatNear(expected, tolerance)));
+ }
+ }
+};
+
+TEST_F(SVDFOpTest, BlackBoxTestRank1) {
SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
/*memory_size=*/10, /*rank=*/1);
svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347,
@@ -218,31 +275,11 @@ TEST(SVDFOpTest, BlackBoxTestRank1) {
-0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657});
svdf.ResetState();
- const int svdf_num_batches = svdf.num_batches();
- const int svdf_input_size = svdf.input_size();
- const int svdf_num_units = svdf.num_units();
- const int input_sequence_size =
- sizeof(svdf_input) / sizeof(float) / (svdf_input_size * svdf_num_batches);
- // Going over each input batch, setting the input tensor, invoking the SVDF op
- // and checking the output with the expected golden values.
- for (int i = 0; i < input_sequence_size; i++) {
- float* batch_start = svdf_input + i * svdf_input_size * svdf_num_batches;
- float* batch_end = batch_start + svdf_input_size * svdf_num_batches;
- svdf.SetInput(0, batch_start, batch_end);
-
- svdf.Invoke();
-
- float* golden_start =
- svdf_golden_output_rank_1 + i * svdf_num_units * svdf_num_batches;
- float* golden_end = golden_start + svdf_num_units * svdf_num_batches;
- std::vector<float> expected;
- expected.insert(expected.end(), golden_start, golden_end);
-
- EXPECT_THAT(svdf.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
- }
+ VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input),
+ &svdf);
}
-TEST(SVDFOpTest, BlackBoxTestRank2) {
+TEST_F(SVDFOpTest, BlackBoxTestRank2) {
SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
/*memory_size=*/10, /*rank=*/2);
svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347,
@@ -278,28 +315,75 @@ TEST(SVDFOpTest, BlackBoxTestRank2) {
0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763});
svdf.ResetState();
- const int svdf_num_batches = svdf.num_batches();
- const int svdf_input_size = svdf.input_size();
- const int svdf_num_units = svdf.num_units();
- const int input_sequence_size =
- sizeof(svdf_input) / sizeof(float) / (svdf_input_size * svdf_num_batches);
- // Going over each input batch, setting the input tensor, invoking the SVDF op
- // and checking the output with the expected golden values.
- for (int i = 0; i < input_sequence_size; i++) {
- float* batch_start = svdf_input + i * svdf_input_size * svdf_num_batches;
- float* batch_end = batch_start + svdf_input_size * svdf_num_batches;
- svdf.SetInput(0, batch_start, batch_end);
-
- svdf.Invoke();
-
- float* golden_start =
- svdf_golden_output_rank_2 + i * svdf_num_units * svdf_num_batches;
- float* golden_end = golden_start + svdf_num_units * svdf_num_batches;
- std::vector<float> expected;
- expected.insert(expected.end(), golden_start, golden_end);
-
- EXPECT_THAT(svdf.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
- }
+ VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input),
+ &svdf);
+}
+
+TEST_F(SVDFOpTest, BlackBoxTestHybridRank1) {
+ HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
+ /*memory_size=*/10, /*rank=*/1);
+ svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347,
+ 0.22197971, 0.12416199, 0.27901134, 0.27557442,
+ 0.3905206, -0.36137494, -0.06634006, -0.10640851});
+
+ svdf.SetWeightsTime(
+ {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156,
+ 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199,
+
+ 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518,
+ -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296,
+
+ -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236,
+ 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846,
+
+ -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166,
+ -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657});
+
+ svdf.ResetState();
+ VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input),
+ &svdf,
+ /*tolerance=*/0.002945);
+}
+
+TEST_F(SVDFOpTest, BlackBoxTestHybridRank2) {
+ HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
+ /*memory_size=*/10, /*rank=*/2);
+ svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347,
+ 0.12416199, 0.15785322, 0.27901134, 0.3905206,
+ 0.21931258, -0.36137494, -0.10640851, 0.31053296,
+ -0.36118156, -0.0976817, -0.36916667, 0.22197971,
+ 0.15294972, 0.38031587, 0.27557442, 0.39635518,
+ -0.21580373, -0.06634006, -0.02702999, 0.27072677});
+
+ svdf.SetWeightsTime(
+ {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156,
+ 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199,
+
+ 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518,
+ -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296,
+
+ -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236,
+ 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846,
+
+ -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166,
+ -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657,
+
+ -0.14884081, 0.19931212, -0.36002168, 0.34663299, -0.11405486,
+ 0.12672701, 0.39463779, -0.07886535, -0.06384811, 0.08249187,
+
+ -0.26816407, -0.19905911, 0.29211238, 0.31264046, -0.28664589,
+ 0.05698794, 0.11613581, 0.14078894, 0.02187902, -0.21781836,
+
+ -0.15567942, 0.08693647, -0.38256618, 0.36580828, -0.22922277,
+ -0.0226903, 0.12878349, -0.28122205, -0.10850525, -0.11955214,
+
+ 0.27179423, -0.04710215, 0.31069002, 0.22672787, 0.09580326,
+ 0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763});
+
+ svdf.ResetState();
+ VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input),
+ &svdf,
+ /*tolerance=*/0.00625109);
}
} // namespace