author    Alan Chiao <alanchiao@google.com>  2018-06-08 17:19:46 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-06-08 17:23:47 -0700
commit    49a729901484a413fd605be735da9a563c24336a (patch)
tree      77fc992757d2f3700d3e1eacbd654bf0941b769d /tensorflow/contrib/lite/kernels/embedding_lookup.cc
parent    cf042e7e90c00d639904e2a5fad8a9cd9d6962da (diff)
Hybrid embedding lookup op
PiperOrigin-RevId: 199874482
Diffstat (limited to 'tensorflow/contrib/lite/kernels/embedding_lookup.cc')
-rw-r--r--  tensorflow/contrib/lite/kernels/embedding_lookup.cc | 57
1 file changed, 51 insertions(+), 6 deletions(-)
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup.cc b/tensorflow/contrib/lite/kernels/embedding_lookup.cc
index 7539c0b30d..9410bead5e 100644
--- a/tensorflow/contrib/lite/kernels/embedding_lookup.cc
+++ b/tensorflow/contrib/lite/kernels/embedding_lookup.cc
@@ -24,7 +24,8 @@ limitations under the License.
// Output:
// Output.dim[0] == Tensor[0].dim[0], num of lookups
// Output.dim[1] == Tensor[1].dim[1], num of items per row
-// Each item in output is a raw bytes copy of corresponding item in input.
+// Each item in output is a raw bytes copy of the corresponding item in input,
+// or a dequantized value in the case of a uint8 input.
// When indices are out of bound, the ops will not succeed.
//
@@ -69,11 +70,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return context->ResizeTensor(context, output, outputSize);
}
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- TfLiteTensor* output = GetOutput(context, node, 0);
- const TfLiteTensor* lookup = GetInput(context, node, 0);
- const TfLiteTensor* value = GetInput(context, node, 1);
-
+TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
+ const TfLiteTensor* lookup, const TfLiteTensor* value,
+ TfLiteTensor* output) {
const int row_size = SizeOfDimension(value, 0);
const int row_bytes = value->bytes / row_size;
@@ -91,6 +90,52 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
+TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
+ const TfLiteTensor* lookup, const TfLiteTensor* value,
+ TfLiteTensor* output) {
+ const int row_size = SizeOfDimension(value, 0);
+ const double scaling_factor = 1.0 / value->params.scale;
+
+ // col_size after we flatten tensor into 2D.
+ int col_size = 1;
+ for (int i = 1; i < NumDimensions(value); i++) {
+ col_size *= SizeOfDimension(value, i);
+ }
+
+ for (int i = 0; i < SizeOfDimension(lookup, 0); i++) {
+ int idx = lookup->data.i32[i];
+ if (idx >= row_size || idx < 0) {
+ context->ReportError(context, "Embedding Lookup: index out of bounds.");
+ return kTfLiteError;
+ } else {
+ // Dequantize embedding values.
+ // TODO(alanchiao): refactor scalar multiply into separate function
+ // for ease of adding a neon equivalent if ever necessary.
+ for (int j = 0; j < col_size; j++) {
+ output->data.f[j + i * col_size] =
+ value->data.uint8[j + idx * col_size] * scaling_factor;
+ }
+ }
+ }
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* lookup = GetInput(context, node, 0);
+ const TfLiteTensor* value = GetInput(context, node, 1);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ switch (value->type) {
+ case kTfLiteFloat32:
+ return EvalFloat(context, node, lookup, value, output);
+ case kTfLiteUInt8:
+ return EvalHybrid(context, node, lookup, value, output);
+ default:
+ context->ReportError(context, "Type not currently supported.");
+ return kTfLiteError;
+ }
+}
+
} // namespace embedding_lookup
TfLiteRegistration* Register_EMBEDDING_LOOKUP() {
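
For reference, below is a minimal standalone sketch of the dequantizing lookup that the new EvalHybrid path performs, written against plain arrays instead of the TfLiteTensor API. The function name HybridLookup and its parameters are illustrative only and are not part of the commit; the scaling factor is computed the same way as in the diff above (1.0 / scale), and out-of-range indices fail the lookup just like the kernel's error path.

#include <cstdint>
#include <vector>

// Illustrative sketch (not part of the commit): dequantizing embedding lookup
// over plain arrays. `value` holds row_size x col_size uint8 weights; each
// looked-up row is written to `output` as floats scaled by scaling_factor,
// mirroring EvalHybrid above (which uses 1.0 / value->params.scale).
bool HybridLookup(const std::vector<int32_t>& lookup,
                  const std::vector<uint8_t>& value, int row_size,
                  int col_size, float scale, std::vector<float>* output) {
  const double scaling_factor = 1.0 / scale;  // as in the diff above
  output->resize(lookup.size() * static_cast<size_t>(col_size));
  for (size_t i = 0; i < lookup.size(); ++i) {
    const int32_t idx = lookup[i];
    if (idx < 0 || idx >= row_size) {
      return false;  // index out of bounds, as in the kernel's error path
    }
    for (int j = 0; j < col_size; ++j) {
      (*output)[i * col_size + j] =
          static_cast<float>(value[idx * col_size + j] * scaling_factor);
    }
  }
  return true;
}

In the kernel itself this path is reached through the kTfLiteUInt8 case of the Eval switch shown in the diff, while float value tensors continue through EvalFloat.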