From 0b15439f8f0f2d4755587f4096c3ea04cb199d23 Mon Sep 17 00:00:00 2001 From: Andrew Selle Date: Fri, 10 Nov 2017 10:35:35 -0800 Subject: Internal Change. PiperOrigin-RevId: 175307445 --- .../contrib/lite/kernels/hashtable_lookup.cc | 155 +++++++++++++++++++++ 1 file changed, 155 insertions(+) create mode 100644 tensorflow/contrib/lite/kernels/hashtable_lookup.cc (limited to 'tensorflow/contrib/lite/kernels/hashtable_lookup.cc') diff --git a/tensorflow/contrib/lite/kernels/hashtable_lookup.cc b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc new file mode 100644 index 0000000000..3b82601d11 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc @@ -0,0 +1,155 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Op that looks up items from hashtable. +// +// Input: +// Tensor[0]: Hash key to lookup, dim.size == 1, int32 +// Tensor[1]: Key of hashtable, dim.size == 1, int32 +// *MUST* be sorted in ascending order. +// Tensor[2]: Value of hashtable, dim.size >= 1 +// Tensor[1].Dim[0] == Tensor[2].Dim[0] +// +// Output: +// Output[0].dim[0] == Tensor[0].dim[0], num of lookups +// Each item in output is a raw bytes copy of corresponding item in input. +// When key does not exist in hashtable, the returned bytes are all 0s. +// +// Output[1].dim = { Tensor[0].dim[0] }, num of lookups +// Each item indicates whether the corresponding lookup has a returned value. +// 0 for missing key, 1 for found key. + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/contrib/lite/string_util.h" + +namespace tflite { +namespace ops { +namespace builtin { + +namespace { + +int greater(const void* a, const void* b) { + return *static_cast(a) - *static_cast(b); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2); + + TfLiteTensor* lookup = GetInput(context, node, 0); + TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1); + TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32); + + TfLiteTensor* key = GetInput(context, node, 1); + TF_LITE_ENSURE_EQ(context, NumDimensions(key), 1); + TF_LITE_ENSURE_EQ(context, key->type, kTfLiteInt32); + + TfLiteTensor* value = GetInput(context, node, 2); + TF_LITE_ENSURE(context, NumDimensions(value) >= 1); + TF_LITE_ENSURE_EQ(context, SizeOfDimension(key, 0), + SizeOfDimension(value, 0)); + if (value->type == kTfLiteString) { + TF_LITE_ENSURE_EQ(context, NumDimensions(value), 1); + } + + TfLiteTensor* hits = GetOutput(context, node, 1); + TF_LITE_ENSURE_EQ(context, hits->type, kTfLiteUInt8); + TfLiteIntArray* hitSize = TfLiteIntArrayCreate(1); + hitSize->data[0] = SizeOfDimension(lookup, 0); + + TfLiteTensor* output = GetOutput(context, node, 0); + TF_LITE_ENSURE_EQ(context, value->type, output->type); + + TfLiteStatus status = kTfLiteOk; + if (output->type != kTfLiteString) { + TfLiteIntArray* outputSize = TfLiteIntArrayCreate(NumDimensions(value)); + outputSize->data[0] = SizeOfDimension(lookup, 0); + for (int i = 1; i < NumDimensions(value); i++) { + outputSize->data[i] = SizeOfDimension(value, i); + } + status = context->ResizeTensor(context, output, outputSize); + } + if (context->ResizeTensor(context, hits, hitSize) == kTfLiteError) { + status = kTfLiteError; + } + return status; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* output = GetOutput(context, node, 0); + TfLiteTensor* hits = GetOutput(context, node, 1); + TfLiteTensor* lookup = GetInput(context, node, 0); + TfLiteTensor* key = GetInput(context, node, 1); + TfLiteTensor* value = GetInput(context, node, 2); + + const int num_rows = SizeOfDimension(value, 0); + const int row_bytes = value->bytes / num_rows; + void* pointer = nullptr; + DynamicBuffer buf; + + for (int i = 0; i < SizeOfDimension(lookup, 0); i++) { + int idx = -1; + pointer = bsearch(&(lookup->data.i32[i]), key->data.i32, num_rows, + sizeof(int32_t), greater); + if (pointer != nullptr) { + idx = (reinterpret_cast(pointer) - (key->data.raw)) / + sizeof(int32_t); + } + + if (idx >= num_rows || idx < 0) { + if (output->type == kTfLiteString) { + buf.AddString(nullptr, 0); + } else { + memset(output->data.raw + i * row_bytes, 0, row_bytes); + } + hits->data.uint8[i] = 0; + } else { + if (output->type == kTfLiteString) { + buf.AddString(GetString(value, idx)); + } else { + memcpy(output->data.raw + i * row_bytes, + value->data.raw + idx * row_bytes, row_bytes); + } + hits->data.uint8[i] = 1; + } + } + if (output->type == kTfLiteString) { + buf.WriteToTensor(output); + } + + return kTfLiteOk; +} +} // namespace + +TfLiteRegistration* Register_HASHTABLE_LOOKUP() { + static TfLiteRegistration r = {nullptr, nullptr, Prepare, Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite -- cgit v1.2.3