diff options
Diffstat (limited to 'tensorflow/contrib/lite/kernels/svdf_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/svdf_test.cc | 186 |
1 files changed, 135 insertions, 51 deletions
diff --git a/tensorflow/contrib/lite/kernels/svdf_test.cc b/tensorflow/contrib/lite/kernels/svdf_test.cc index 0f166dc69b..5af3ff8500 100644 --- a/tensorflow/contrib/lite/kernels/svdf_test.cc +++ b/tensorflow/contrib/lite/kernels/svdf_test.cc @@ -126,17 +126,20 @@ static float svdf_golden_output_rank_2[] = { }; // Derived class of SingleOpModel, which is used to test SVDF TFLite op. -class SVDFOpModel : public SingleOpModel { +class BaseSVDFOpModel : public SingleOpModel { public: - SVDFOpModel(int batches, int units, int input_size, int memory_size, int rank) + BaseSVDFOpModel(int batches, int units, int input_size, int memory_size, + int rank, + TensorType weights_feature_type = TensorType_FLOAT32, + TensorType weights_time_type = TensorType_FLOAT32) : batches_(batches), units_(units), input_size_(input_size), memory_size_(memory_size), rank_(rank) { input_ = AddInput(TensorType_FLOAT32); - weights_feature_ = AddInput(TensorType_FLOAT32); - weights_time_ = AddInput(TensorType_FLOAT32); + weights_feature_ = AddInput(weights_feature_type); + weights_time_ = AddInput(weights_time_type); bias_ = AddNullInput(); state_ = AddOutput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32); @@ -182,7 +185,7 @@ class SVDFOpModel : public SingleOpModel { int num_units() { return units_; } int num_batches() { return batches_; } - private: + protected: int input_; int weights_feature_; int weights_time_; @@ -197,7 +200,61 @@ class SVDFOpModel : public SingleOpModel { int rank_; }; -TEST(SVDFOpTest, BlackBoxTestRank1) { +class SVDFOpModel : public BaseSVDFOpModel { + public: + using BaseSVDFOpModel::BaseSVDFOpModel; +}; + +class HybridSVDFOpModel : public BaseSVDFOpModel { + public: + HybridSVDFOpModel(int batches, int units, int input_size, int memory_size, + int rank) + : BaseSVDFOpModel(batches, units, input_size, memory_size, rank, + TensorType_UINT8, TensorType_UINT8) {} + + void SetWeightsFeature(std::initializer_list<float> f) { + SymmetricQuantizeAndPopulate(weights_feature_, f); + } + + void SetWeightsTime(std::initializer_list<float> f) { + SymmetricQuantizeAndPopulate(weights_time_, f); + } +}; + +class SVDFOpTest : public ::testing::Test { + protected: + void VerifyGoldens(float golden_input[], float golden_output[], + int golden_size, BaseSVDFOpModel* svdf, + float tolerance = 1e-5) { + const int svdf_num_batches = svdf->num_batches(); + const int svdf_input_size = svdf->input_size(); + const int svdf_num_units = svdf->num_units(); + const int input_sequence_size = + golden_size / sizeof(float) / (svdf_input_size * svdf_num_batches); + // Going over each input batch, setting the input tensor, invoking the SVDF + // op and checking the output with the expected golden values. + for (int i = 0; i < input_sequence_size; i++) { + float* batch_start = + golden_input + i * svdf_input_size * svdf_num_batches; + float* batch_end = batch_start + svdf_input_size * svdf_num_batches; + svdf->SetInput(0, batch_start, batch_end); + + svdf->Invoke(); + + const float* golden_start = + golden_output + i * svdf_num_units * svdf_num_batches; + const float* golden_end = + golden_start + svdf_num_units * svdf_num_batches; + std::vector<float> expected; + expected.insert(expected.end(), golden_start, golden_end); + + EXPECT_THAT(svdf->GetOutput(), + ElementsAreArray(ArrayFloatNear(expected, tolerance))); + } + } +}; + +TEST_F(SVDFOpTest, BlackBoxTestRank1) { SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, /*memory_size=*/10, /*rank=*/1); svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347, @@ -218,31 +275,11 @@ TEST(SVDFOpTest, BlackBoxTestRank1) { -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657}); svdf.ResetState(); - const int svdf_num_batches = svdf.num_batches(); - const int svdf_input_size = svdf.input_size(); - const int svdf_num_units = svdf.num_units(); - const int input_sequence_size = - sizeof(svdf_input) / sizeof(float) / (svdf_input_size * svdf_num_batches); - // Going over each input batch, setting the input tensor, invoking the SVDF op - // and checking the output with the expected golden values. - for (int i = 0; i < input_sequence_size; i++) { - float* batch_start = svdf_input + i * svdf_input_size * svdf_num_batches; - float* batch_end = batch_start + svdf_input_size * svdf_num_batches; - svdf.SetInput(0, batch_start, batch_end); - - svdf.Invoke(); - - float* golden_start = - svdf_golden_output_rank_1 + i * svdf_num_units * svdf_num_batches; - float* golden_end = golden_start + svdf_num_units * svdf_num_batches; - std::vector<float> expected; - expected.insert(expected.end(), golden_start, golden_end); - - EXPECT_THAT(svdf.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); - } + VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input), + &svdf); } -TEST(SVDFOpTest, BlackBoxTestRank2) { +TEST_F(SVDFOpTest, BlackBoxTestRank2) { SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, /*memory_size=*/10, /*rank=*/2); svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347, @@ -278,28 +315,75 @@ TEST(SVDFOpTest, BlackBoxTestRank2) { 0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763}); svdf.ResetState(); - const int svdf_num_batches = svdf.num_batches(); - const int svdf_input_size = svdf.input_size(); - const int svdf_num_units = svdf.num_units(); - const int input_sequence_size = - sizeof(svdf_input) / sizeof(float) / (svdf_input_size * svdf_num_batches); - // Going over each input batch, setting the input tensor, invoking the SVDF op - // and checking the output with the expected golden values. - for (int i = 0; i < input_sequence_size; i++) { - float* batch_start = svdf_input + i * svdf_input_size * svdf_num_batches; - float* batch_end = batch_start + svdf_input_size * svdf_num_batches; - svdf.SetInput(0, batch_start, batch_end); - - svdf.Invoke(); - - float* golden_start = - svdf_golden_output_rank_2 + i * svdf_num_units * svdf_num_batches; - float* golden_end = golden_start + svdf_num_units * svdf_num_batches; - std::vector<float> expected; - expected.insert(expected.end(), golden_start, golden_end); - - EXPECT_THAT(svdf.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); - } + VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input), + &svdf); +} + +TEST_F(SVDFOpTest, BlackBoxTestHybridRank1) { + HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, + /*memory_size=*/10, /*rank=*/1); + svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347, + 0.22197971, 0.12416199, 0.27901134, 0.27557442, + 0.3905206, -0.36137494, -0.06634006, -0.10640851}); + + svdf.SetWeightsTime( + {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156, + 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199, + + 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518, + -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296, + + -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236, + 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846, + + -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166, + -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657}); + + svdf.ResetState(); + VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input), + &svdf, + /*tolerance=*/0.002945); +} + +TEST_F(SVDFOpTest, BlackBoxTestHybridRank2) { + HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, + /*memory_size=*/10, /*rank=*/2); + svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347, + 0.12416199, 0.15785322, 0.27901134, 0.3905206, + 0.21931258, -0.36137494, -0.10640851, 0.31053296, + -0.36118156, -0.0976817, -0.36916667, 0.22197971, + 0.15294972, 0.38031587, 0.27557442, 0.39635518, + -0.21580373, -0.06634006, -0.02702999, 0.27072677}); + + svdf.SetWeightsTime( + {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156, + 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199, + + 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518, + -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296, + + -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236, + 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846, + + -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166, + -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657, + + -0.14884081, 0.19931212, -0.36002168, 0.34663299, -0.11405486, + 0.12672701, 0.39463779, -0.07886535, -0.06384811, 0.08249187, + + -0.26816407, -0.19905911, 0.29211238, 0.31264046, -0.28664589, + 0.05698794, 0.11613581, 0.14078894, 0.02187902, -0.21781836, + + -0.15567942, 0.08693647, -0.38256618, 0.36580828, -0.22922277, + -0.0226903, 0.12878349, -0.28122205, -0.10850525, -0.11955214, + + 0.27179423, -0.04710215, 0.31069002, 0.22672787, 0.09580326, + 0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763}); + + svdf.ResetState(); + VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input), + &svdf, + /*tolerance=*/0.00625109); } } // namespace |