aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/skip_gram_test.cc
diff options
context:
space:
mode:
authorGravatar Andrew Selle <aselle@google.com>2017-11-10 10:35:35 -0800
committerGravatar Andrew Selle <aselle@andyselle.com>2017-11-10 16:14:42 -0800
commit0b15439f8f0f2d4755587f4096c3ea04cb199d23 (patch)
tree9aa4fc8162bf9b4ee50112a7b85703f70ca4df08 /tensorflow/contrib/lite/kernels/skip_gram_test.cc
parent7ac140a5845553275427162aabd9d54987144b4a (diff)
Internal Change.
PiperOrigin-RevId: 175307445
Diffstat (limited to 'tensorflow/contrib/lite/kernels/skip_gram_test.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/skip_gram_test.cc257
1 files changed, 257 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/skip_gram_test.cc b/tensorflow/contrib/lite/kernels/skip_gram_test.cc
new file mode 100644
index 0000000000..e7f6bc904b
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/skip_gram_test.cc
@@ -0,0 +1,257 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/string_util.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAre;
+
+static char kSentence[] = "The quick\t brown fox\n jumps over\n the lazy dog!";
+
+class SkipGramOp : public SingleOpModel {
+ public:
+ SkipGramOp(int ngram_size, int max_skip_size, bool include_all_ngrams) {
+ input_ = AddInput(TensorType_STRING);
+ output_ = AddOutput(TensorType_STRING);
+
+ SetBuiltinOp(BuiltinOperator_SKIP_GRAM, BuiltinOptions_SkipGramOptions,
+ CreateSkipGramOptions(builder_, ngram_size, max_skip_size,
+ include_all_ngrams)
+ .Union());
+ BuildInterpreter({{1}});
+ }
+ void SetInput(const string& content) {
+ PopulateStringTensor(input_, {content});
+ }
+
+ std::vector<string> GetOutput() {
+ std::vector<string> ans;
+ TfLiteTensor* tensor = interpreter_->tensor(output_);
+
+ int num = GetStringCount(tensor);
+ for (int i = 0; i < num; i++) {
+ StringRef strref = GetString(tensor, i);
+ ans.push_back(string(strref.str, strref.len));
+ }
+ return ans;
+ }
+
+ private:
+ int input_;
+ int output_;
+};
+
+TEST(SkipGramTest, TestUnigram) {
+ SkipGramOp m(1, 0, false);
+
+ m.SetInput(kSentence);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), testing::UnorderedElementsAreArray(
+ {"The", "quick", "brown", "fox", "jumps",
+ "over", "the", "lazy", "dog!"}));
+}
+
+TEST(SkipGramTest, TestBigram) {
+ SkipGramOp m(2, 0, false);
+ m.SetInput(kSentence);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(),
+ testing::UnorderedElementsAreArray(
+ {"The quick", "quick brown", "brown fox", "fox jumps",
+ "jumps over", "over the", "the lazy", "lazy dog!"}));
+}
+
+TEST(SkipGramTest, TestAllBigram) {
+ SkipGramOp m(2, 0, true);
+ m.SetInput(kSentence);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(),
+ testing::UnorderedElementsAreArray(
+ {// Unigram
+ "The", "quick", "brown", "fox", "jumps", "over", "the",
+ "lazy", "dog!",
+ // Bigram
+ "The quick", "quick brown", "brown fox", "fox jumps",
+ "jumps over", "over the", "the lazy", "lazy dog!"}));
+}
+
+TEST(SkipGramTest, TestAllTrigram) {
+ SkipGramOp m(3, 0, true);
+ m.SetInput(kSentence);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(),
+ testing::UnorderedElementsAreArray(
+ {// Unigram
+ "The", "quick", "brown", "fox", "jumps", "over", "the",
+ "lazy", "dog!",
+ // Bigram
+ "The quick", "quick brown", "brown fox", "fox jumps",
+ "jumps over", "over the", "the lazy", "lazy dog!",
+ // Trigram
+ "The quick brown", "quick brown fox", "brown fox jumps",
+ "fox jumps over", "jumps over the", "over the lazy",
+ "the lazy dog!"}));
+}
+
+TEST(SkipGramTest, TestSkip1Bigram) {
+ SkipGramOp m(2, 1, false);
+ m.SetInput(kSentence);
+ m.Invoke();
+ EXPECT_THAT(
+ m.GetOutput(),
+ testing::UnorderedElementsAreArray(
+ {"The quick", "The brown", "quick brown", "quick fox", "brown fox",
+ "brown jumps", "fox jumps", "fox over", "jumps over", "jumps the",
+ "over the", "over lazy", "the lazy", "the dog!", "lazy dog!"}));
+}
+
+TEST(SkipGramTest, TestSkip2Bigram) {
+ SkipGramOp m(2, 2, false);
+ m.SetInput(kSentence);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(),
+ testing::UnorderedElementsAreArray(
+ {"The quick", "The brown", "The fox", "quick brown",
+ "quick fox", "quick jumps", "brown fox", "brown jumps",
+ "brown over", "fox jumps", "fox over", "fox the",
+ "jumps over", "jumps the", "jumps lazy", "over the",
+ "over lazy", "over dog!", "the lazy", "the dog!",
+ "lazy dog!"}));
+}
+
+TEST(SkipGramTest, TestSkip1Trigram) {
+ SkipGramOp m(3, 1, false);
+ m.SetInput(kSentence);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(),
+ testing::UnorderedElementsAreArray(
+ {"The quick brown", "The quick fox", "The brown fox",
+ "The brown jumps", "quick brown fox", "quick brown jumps",
+ "quick fox jumps", "quick fox over", "brown fox jumps",
+ "brown fox over", "brown jumps over", "brown jumps the",
+ "fox jumps over", "fox jumps the", "fox over the",
+ "fox over lazy", "jumps over the", "jumps over lazy",
+ "jumps the lazy", "jumps the dog!", "over the lazy",
+ "over the dog!", "over lazy dog!", "the lazy dog!"}));
+}
+
+TEST(SkipGramTest, TestSkip2Trigram) {
+ SkipGramOp m(3, 2, false);
+ m.SetInput(kSentence);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(),
+ testing::UnorderedElementsAreArray(
+ {"The quick brown", "The quick fox", "The quick jumps",
+ "The brown fox", "The brown jumps", "The brown over",
+ "The fox jumps", "The fox over", "The fox the",
+ "quick brown fox", "quick brown jumps", "quick brown over",
+ "quick fox jumps", "quick fox over", "quick fox the",
+ "quick jumps over", "quick jumps the", "quick jumps lazy",
+ "brown fox jumps", "brown fox over", "brown fox the",
+ "brown jumps over", "brown jumps the", "brown jumps lazy",
+ "brown over the", "brown over lazy", "brown over dog!",
+ "fox jumps over", "fox jumps the", "fox jumps lazy",
+ "fox over the", "fox over lazy", "fox over dog!",
+ "fox the lazy", "fox the dog!", "jumps over the",
+ "jumps over lazy", "jumps over dog!", "jumps the lazy",
+ "jumps the dog!", "jumps lazy dog!", "over the lazy",
+ "over the dog!", "over lazy dog!", "the lazy dog!"}));
+}
+
+TEST(SkipGramTest, TestAllSkip2Trigram) {
+ SkipGramOp m(3, 2, true);
+ m.SetInput(kSentence);
+ m.Invoke();
+ EXPECT_THAT(
+ m.GetOutput(),
+ testing::UnorderedElementsAreArray(
+ {// Unigram
+ "The", "quick", "brown", "fox", "jumps", "over", "the", "lazy",
+ "dog!",
+ // Bigram
+ "The quick", "The brown", "The fox", "quick brown", "quick fox",
+ "quick jumps", "brown fox", "brown jumps", "brown over", "fox jumps",
+ "fox over", "fox the", "jumps over", "jumps the", "jumps lazy",
+ "over the", "over lazy", "over dog!", "the lazy", "the dog!",
+ "lazy dog!",
+ // Trigram
+ "The quick brown", "The quick fox", "The quick jumps",
+ "The brown fox", "The brown jumps", "The brown over",
+ "The fox jumps", "The fox over", "The fox the", "quick brown fox",
+ "quick brown jumps", "quick brown over", "quick fox jumps",
+ "quick fox over", "quick fox the", "quick jumps over",
+ "quick jumps the", "quick jumps lazy", "brown fox jumps",
+ "brown fox over", "brown fox the", "brown jumps over",
+ "brown jumps the", "brown jumps lazy", "brown over the",
+ "brown over lazy", "brown over dog!", "fox jumps over",
+ "fox jumps the", "fox jumps lazy", "fox over the", "fox over lazy",
+ "fox over dog!", "fox the lazy", "fox the dog!", "jumps over the",
+ "jumps over lazy", "jumps over dog!", "jumps the lazy",
+ "jumps the dog!", "jumps lazy dog!", "over the lazy",
+ "over the dog!", "over lazy dog!", "the lazy dog!"}));
+}
+
+TEST(SkipGramTest, TestSingleWord) {
+ SkipGramOp m(1, 1, false);
+ m.SetInput("Hi");
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAre("Hi"));
+}
+
+TEST(SkipGramTest, TestWordsLessThanGram) {
+ SkipGramOp m(3, 1, false);
+ m.SetInput("Hi hi");
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), std::vector<string>());
+}
+
+TEST(SkipGramTest, TestEmptyInput) {
+ SkipGramOp m(1, 1, false);
+ m.SetInput("");
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAre());
+}
+
+TEST(SkipGramTest, TestWhitespaceInput) {
+ SkipGramOp m(1, 1, false);
+ m.SetInput(" ");
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAre());
+}
+
+TEST(SkipGramTest, TestInputWithExtraSpace) {
+ SkipGramOp m(1, 1, false);
+ m.SetInput(" Hello world ! ");
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAre("Hello", "world", "!"));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}