aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/models/smartreply/ops/extract_feature_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/models/smartreply/ops/extract_feature_test.cc')
-rw-r--r--tensorflow/contrib/lite/models/smartreply/ops/extract_feature_test.cc100
1 files changed, 100 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/models/smartreply/ops/extract_feature_test.cc b/tensorflow/contrib/lite/models/smartreply/ops/extract_feature_test.cc
new file mode 100644
index 0000000000..9b8676bab6
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/ops/extract_feature_test.cc
@@ -0,0 +1,100 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+#include <farmhash.h>
+
+namespace tflite {
+
+namespace ops {
+namespace custom {
+TfLiteRegistration* Register_EXTRACT_FEATURES();
+
+namespace {
+
+using ::testing::ElementsAre;
+
+class ExtractFeatureOpModel : public SingleOpModel {
+ public:
+ explicit ExtractFeatureOpModel(const std::vector<string>& input) {
+ input_ = AddInput(TensorType_STRING);
+ signature_ = AddOutput(TensorType_INT32);
+ weight_ = AddOutput(TensorType_FLOAT32);
+
+ SetCustomOp("ExtractFeatures", {}, Register_EXTRACT_FEATURES);
+ BuildInterpreter({{static_cast<int>(input.size())}});
+ PopulateStringTensor(input_, input);
+ }
+
+ std::vector<int> GetSignature() { return ExtractVector<int>(signature_); }
+ std::vector<float> GetWeight() { return ExtractVector<float>(weight_); }
+
+ private:
+ int input_;
+ int signature_;
+ int weight_;
+};
+
+int CalcFeature(const string& str) {
+ return ::util::Fingerprint64(str) % 1000000;
+}
+
+TEST(ExtractFeatureOpTest, RegularInput) {
+ ExtractFeatureOpModel m({"<S>", "<S> Hi", "Hi", "Hi !", "!", "! <E>", "<E>"});
+ m.Invoke();
+ EXPECT_THAT(m.GetSignature(),
+ ElementsAre(0, CalcFeature("<S> Hi"), CalcFeature("Hi"),
+ CalcFeature("Hi !"), CalcFeature("!"),
+ CalcFeature("! <E>"), 0));
+ EXPECT_THAT(m.GetWeight(), ElementsAre(0, 2, 1, 2, 1, 2, 0));
+}
+
+TEST(ExtractFeatureOpTest, OneInput) {
+ ExtractFeatureOpModel m({"Hi"});
+ m.Invoke();
+ EXPECT_THAT(m.GetSignature(), ElementsAre(CalcFeature("Hi")));
+ EXPECT_THAT(m.GetWeight(), ElementsAre(1));
+}
+
+TEST(ExtractFeatureOpTest, ZeroInput) {
+ ExtractFeatureOpModel m({});
+ m.Invoke();
+ EXPECT_THAT(m.GetSignature(), ElementsAre(0));
+ EXPECT_THAT(m.GetWeight(), ElementsAre(0));
+}
+
+TEST(ExtractFeatureOpTest, AllBlacklistInput) {
+ ExtractFeatureOpModel m({"<S>", "<E>"});
+ m.Invoke();
+ EXPECT_THAT(m.GetSignature(), ElementsAre(0, 0));
+ EXPECT_THAT(m.GetWeight(), ElementsAre(0, 0));
+}
+
+} // namespace
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}