diff options
author | Andrew Selle <aselle@google.com> | 2017-11-10 10:35:35 -0800 |
---|---|---|
committer | Andrew Selle <aselle@andyselle.com> | 2017-11-10 16:14:42 -0800 |
commit | 0b15439f8f0f2d4755587f4096c3ea04cb199d23 (patch) | |
tree | 9aa4fc8162bf9b4ee50112a7b85703f70ca4df08 /tensorflow/contrib/lite/kernels/embedding_lookup.cc | |
parent | 7ac140a5845553275427162aabd9d54987144b4a (diff) |
Internal Change.
PiperOrigin-RevId: 175307445
Diffstat (limited to 'tensorflow/contrib/lite/kernels/embedding_lookup.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/embedding_lookup.cc | 104 |
1 files changed, 104 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup.cc b/tensorflow/contrib/lite/kernels/embedding_lookup.cc new file mode 100644 index 0000000000..4e8cb396d4 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/embedding_lookup.cc @@ -0,0 +1,104 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Ops that looks up items from matrix. +// +// Input: +// Tensor[0]: Row number to lookup, dim.size == 1, int32 +// Tensor[1]: 2-dimensional matrix of multi-dimensional items +// dim.size >= 2, any data type. +// first dimension is row, second dimension is column. +// +// Output: +// Output.dim[0] == Tensor[0].dim[0], num of lookups +// Output.dim[1] == Tensor[1].dim[1], num of items per row +// Each item in output is a raw bytes copy of corresponding item in input. +// When indices are out of bound, the ops will not succeed. +// + +#include <unistd.h> +#include <cassert> +#include <cmath> +#include <cstdio> +#include <cstdlib> +#include <cstring> +#include <iostream> +#include <limits> + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace embedding_lookup { + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* lookup = GetInput(context, node, 0); + TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1); + TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32); + + TfLiteTensor* value = GetInput(context, node, 1); + TF_LITE_ENSURE(context, NumDimensions(value) >= 2); + + TfLiteTensor* output = GetOutput(context, node, 0); + TfLiteIntArray* outputSize = TfLiteIntArrayCreate(NumDimensions(value)); + + outputSize->data[0] = SizeOfDimension(lookup, 0); + outputSize->data[1] = SizeOfDimension(value, 1); + for (int i = 2; i < NumDimensions(value); i++) { + outputSize->data[i] = SizeOfDimension(value, i); + } + return context->ResizeTensor(context, output, outputSize); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* output = GetOutput(context, node, 0); + TfLiteTensor* lookup = GetInput(context, node, 0); + TfLiteTensor* value = GetInput(context, node, 1); + + const int row_size = SizeOfDimension(value, 0); + const int row_bytes = value->bytes / row_size; + + for (int i = 0; i < SizeOfDimension(lookup, 0); i++) { + int idx = lookup->data.i32[i]; + if (idx >= row_size || idx < 0) { + context->ReportError(context, "Embedding Lookup: index out of bounds."); + return kTfLiteError; + } else { + memcpy(output->data.raw + i * row_bytes, + value->data.raw + idx * row_bytes, row_bytes); + } + } + + return kTfLiteOk; +} + +} // namespace embedding_lookup + +TfLiteRegistration* Register_EMBEDDING_LOOKUP() { + static TfLiteRegistration r = {nullptr, nullptr, embedding_lookup::Prepare, + embedding_lookup::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite |