diff options
author | Alan Chiao <alanchiao@google.com> | 2018-07-12 17:55:49 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-12 17:59:36 -0700 |
commit | 6468c599d4b3b5c6fa94d171c1ba6b4254286c23 (patch) | |
tree | 329833d0916b31f4f27ba5df8284292e915b472e | |
parent | fb9992423c20ed716124f63b9fe7858a0198e897 (diff) |
Internal change.
PiperOrigin-RevId: 204398429
-rw-r--r-- | tensorflow/contrib/lite/kernels/embedding_lookup.cc | 3 | ||||
-rw-r--r-- | tensorflow/contrib/lite/kernels/embedding_lookup_test.cc | 36 |
2 files changed, 20 insertions, 19 deletions
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup.cc b/tensorflow/contrib/lite/kernels/embedding_lookup.cc index 0ba170a4da..f550339d03 100644 --- a/tensorflow/contrib/lite/kernels/embedding_lookup.cc +++ b/tensorflow/contrib/lite/kernels/embedding_lookup.cc @@ -112,8 +112,9 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, // TODO(alanchiao): refactor scalar multiply into separate function // for ease of adding a neon equivalent if ever necessary. for (int j = 0; j < col_size; j++) { + const int8_t* value_ptr = reinterpret_cast<int8_t*>(value->data.uint8); output->data.f[j + i * col_size] = - value->data.uint8[j + idx * col_size] * scaling_factor; + value_ptr[j + idx * col_size] * scaling_factor; } } } diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc index 04657fd863..4a88d168c6 100644 --- a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc +++ b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc @@ -107,9 +107,9 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple2DTest) { HybridEmbeddingLookupOpModel m({3}, {3, 8}); m.SetInput({1, 0, 2}); m.SetWeight({ - 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 - 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 - 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 }); m.Invoke(); @@ -117,9 +117,9 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple2DTest) { EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( { - 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 - 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 - 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 }, 7.41e-03))); } @@ -128,9 +128,9 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple3DTest) { HybridEmbeddingLookupOpModel m({3}, {3, 2, 4}); m.SetInput({1, 0, 2}); m.SetWeight({ - 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 - 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 - 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 }); m.Invoke(); @@ -138,9 +138,9 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple3DTest) { EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( { - 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 - 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 - 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 }, 7.41e-03))); } @@ -149,9 +149,9 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple4DTest) { HybridEmbeddingLookupOpModel m({3}, {3, 2, 2, 2}); m.SetInput({1, 0, 2}); m.SetWeight({ - 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 - 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 - 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 }); m.Invoke(); @@ -159,9 +159,9 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple4DTest) { EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( { - 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 - 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 - 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 }, 7.41e-03))); } |