author    Alan Chiao <alanchiao@google.com>  2018-06-08 17:19:46 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-06-08 17:23:47 -0700
commit 49a729901484a413fd605be735da9a563c24336a (patch)
tree   77fc992757d2f3700d3e1eacbd654bf0941b769d
parent cf042e7e90c00d639904e2a5fad8a9cd9d6962da (diff)
Hybrid embedding lookup op
PiperOrigin-RevId: 199874482
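
The hybrid path keeps the embedding weights quantized as uint8 and dequantizes each looked-up row to float32 at evaluation time; the existing float path still copies raw bytes. As a rough standalone sketch of that dequantizing copy (the function and parameter names here are illustrative, not part of the kernel; the arithmetic mirrors EvalHybrid in the diff below):

#include <cstdint>
#include <vector>

// Sketch only: copy the looked-up rows of a flattened uint8 embedding table
// into a float output, scaling each byte the same way EvalHybrid does:
//   output[i][j] = value[idx][j] * (1.0 / scale)
// Bounds checking of the lookup indices is omitted here; the kernel reports
// an error and returns kTfLiteError for out-of-range indices.
std::vector<float> DequantizeLookups(const std::vector<uint8_t>& value,
                                     const std::vector<int>& lookup,
                                     int col_size, float scale) {
  const double scaling_factor = 1.0 / scale;
  std::vector<float> output(lookup.size() * col_size);
  for (int i = 0; i < static_cast<int>(lookup.size()); ++i) {
    const int idx = lookup[i];
    for (int j = 0; j < col_size; ++j) {
      output[i * col_size + j] = value[idx * col_size + j] * scaling_factor;
    }
  }
  return output;
}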
-rw-r--r-- tensorflow/contrib/lite/kernels/embedding_lookup.cc       |  57
-rw-r--r-- tensorflow/contrib/lite/kernels/embedding_lookup_test.cc  | 110
2 files changed, 147 insertions, 20 deletions
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup.cc b/tensorflow/contrib/lite/kernels/embedding_lookup.cc
index 7539c0b30d..9410bead5e 100644
--- a/tensorflow/contrib/lite/kernels/embedding_lookup.cc
+++ b/tensorflow/contrib/lite/kernels/embedding_lookup.cc
@@ -24,7 +24,8 @@ limitations under the License.
// Output:
// Output.dim[0] == Tensor[0].dim[0], num of lookups
// Output.dim[1] == Tensor[1].dim[1], num of items per row
-// Each item in output is a raw bytes copy of corresponding item in input.
+// Each item in output is a raw bytes copy of the corresponding item in input,
+// or a dequantized value in the case of a uint8 input.
// When indices are out of bound, the ops will not succeed.
//
@@ -69,11 +70,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return context->ResizeTensor(context, output, outputSize);
}
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- TfLiteTensor* output = GetOutput(context, node, 0);
- const TfLiteTensor* lookup = GetInput(context, node, 0);
- const TfLiteTensor* value = GetInput(context, node, 1);
-
+TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
+ const TfLiteTensor* lookup, const TfLiteTensor* value,
+ TfLiteTensor* output) {
const int row_size = SizeOfDimension(value, 0);
const int row_bytes = value->bytes / row_size;
@@ -91,6 +90,52 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
+TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
+ const TfLiteTensor* lookup, const TfLiteTensor* value,
+ TfLiteTensor* output) {
+ const int row_size = SizeOfDimension(value, 0);
+ const double scaling_factor = 1.0 / value->params.scale;
+
+ // col_size after we flatten tensor into 2D.
+ int col_size = 1;
+ for (int i = 1; i < NumDimensions(value); i++) {
+ col_size *= SizeOfDimension(value, i);
+ }
+
+ for (int i = 0; i < SizeOfDimension(lookup, 0); i++) {
+ int idx = lookup->data.i32[i];
+ if (idx >= row_size || idx < 0) {
+ context->ReportError(context, "Embedding Lookup: index out of bounds.");
+ return kTfLiteError;
+ } else {
+ // Dequantize embedding values.
+ // TODO(alanchiao): refactor scalar multiply into separate function
+ // for ease of adding a neon equivalent if ever necessary.
+ for (int j = 0; j < col_size; j++) {
+ output->data.f[j + i * col_size] =
+ value->data.uint8[j + idx * col_size] * scaling_factor;
+ }
+ }
+ }
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* lookup = GetInput(context, node, 0);
+ const TfLiteTensor* value = GetInput(context, node, 1);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ switch (value->type) {
+ case kTfLiteFloat32:
+ return EvalFloat(context, node, lookup, value, output);
+ case kTfLiteUInt8:
+ return EvalHybrid(context, node, lookup, value, output);
+ default:
+ context->ReportError(context, "Type not currently supported.");
+ return kTfLiteError;
+ }
+}
+
} // namespace embedding_lookup
TfLiteRegistration* Register_EMBEDDING_LOOKUP() {
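
The test changes that follow add a HybridEmbeddingLookupOpModel that feeds float weights through SymmetricQuantizeAndPopulate and checks the dequantized lookup results against the original values within a small tolerance (7.41e-03 in these tests). A tolerance is needed because an 8-bit quantize/dequantize round trip loses up to roughly half a quantization step per value. A purely illustrative example of that round-trip error, with an assumed step size that is not necessarily the one the test harness computes:

#include <cmath>
#include <cstdio>

int main() {
  const float original = 1.07f;
  const float step = 2.13f / 127.0f;  // assumed quantization step for this sketch
  const int q = static_cast<int>(std::lround(original / step));
  const float recovered = q * step;
  // recovered differs from original by at most about step / 2.
  std::printf("original=%f quantized=%d recovered=%f\n", original, q, recovered);
  return 0;
}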
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc
index 9b501878f1..04657fd863 100644
--- a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc
+++ b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc
@@ -7,13 +7,14 @@ You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
+distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License
+for the specific language governing permissions and limitations under the
+License.
==============================================================================*/
// Unit test for TFLite Lookup op.
+#include <initializer_list>
#include <iomanip>
#include <vector>
@@ -29,12 +30,13 @@ namespace {
using ::testing::ElementsAreArray;
-class EmbeddingLookupOpModel : public SingleOpModel {
+class BaseEmbeddingLookupOpModel : public SingleOpModel {
public:
- EmbeddingLookupOpModel(std::initializer_list<int> index_shape,
- std::initializer_list<int> weight_shape) {
+ BaseEmbeddingLookupOpModel(std::initializer_list<int> index_shape,
+ std::initializer_list<int> weight_shape,
+ TensorType weight_type = TensorType_FLOAT32) {
input_ = AddInput(TensorType_INT32);
- weight_ = AddInput(TensorType_FLOAT32);
+ weight_ = AddInput(weight_type);
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(BuiltinOperator_EMBEDDING_LOOKUP, BuiltinOptions_NONE, 0);
BuildInterpreter({index_shape, weight_shape});
@@ -44,6 +46,18 @@ class EmbeddingLookupOpModel : public SingleOpModel {
PopulateTensor(input_, data);
}
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ protected:
+ int input_;
+ int weight_;
+ int output_;
+};
+
+class EmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel {
+ public:
+ using BaseEmbeddingLookupOpModel::BaseEmbeddingLookupOpModel;
+
void Set3DWeightMatrix(const std::function<float(int, int, int)>& function) {
TfLiteTensor* tensor = interpreter_->tensor(weight_);
int rows = tensor->dims->data[0];
@@ -57,20 +71,25 @@ class EmbeddingLookupOpModel : public SingleOpModel {
}
}
}
+};
- std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+class HybridEmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel {
+ public:
+ HybridEmbeddingLookupOpModel(std::initializer_list<int> index_shape,
+ std::initializer_list<int> weight_shape)
+ : BaseEmbeddingLookupOpModel(index_shape, weight_shape,
+ TensorType_UINT8) {}
- private:
- int input_;
- int weight_;
- int output_;
+ void SetWeight(std::initializer_list<float> data) {
+ SymmetricQuantizeAndPopulate(weight_, data);
+ }
};
// TODO(ahentz): write more tests that exercise the details of the op, such as
// lookup errors and variable input shapes.
TEST(EmbeddingLookupOpTest, SimpleTest) {
EmbeddingLookupOpModel m({3}, {3, 2, 4});
- m.PopulateTensor<int>(0, {1, 0, 2});
+ m.SetInput({1, 0, 2});
m.Set3DWeightMatrix(
[](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; });
@@ -84,6 +103,69 @@ TEST(EmbeddingLookupOpTest, SimpleTest) {
})));
}
+TEST(HybridEmbeddingLookupHybridOpTest, Simple2DTest) {
+ HybridEmbeddingLookupOpModel m({3}, {3, 8});
+ m.SetInput({1, 0, 2});
+ m.SetWeight({
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ });
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ },
+ 7.41e-03)));
+}
+
+TEST(HybridEmbeddingLookupHybridOpTest, Simple3DTest) {
+ HybridEmbeddingLookupOpModel m({3}, {3, 2, 4});
+ m.SetInput({1, 0, 2});
+ m.SetWeight({
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ });
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ },
+ 7.41e-03)));
+}
+
+TEST(HybridEmbeddingLookupHybridOpTest, Simple4DTest) {
+ HybridEmbeddingLookupOpModel m({3}, {3, 2, 2, 2});
+ m.SetInput({1, 0, 2});
+ m.SetWeight({
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ });
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ },
+ 7.41e-03)));
+}
+
} // namespace
} // namespace tflite
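
All three hybrid tests use the same 24 weight values; only the declared weight shape changes, which exercises the kernel's flattening of every dimension after the first into a single col_size. A purely illustrative summary of how that flattening treats the shapes above:

// {3, 8}       -> col_size = 8
// {3, 2, 4}    -> col_size = 2 * 4 = 8
// {3, 2, 2, 2} -> col_size = 2 * 2 * 2 = 8
// In every case each of the 3 rows is read as 8 contiguous uint8 values.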