From bd95d55a2886677ba194351197d93c8b1408cc85 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Thu, 10 May 2018 12:14:52 -0700
Subject: Implementation of the unidirectional_sequence_rnn TFLite Op using the symmetric quantization.

PiperOrigin-RevId: 196152754
---
 .../lite/kernels/unidirectional_sequence_rnn.cc | 184 ++++++++++++++++++---
 1 file changed, 159 insertions(+), 25 deletions(-)

(limited to 'tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc')

diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
index ac00c37b67..5ae635bfda 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/contrib/lite/context.h"
 #include "tensorflow/contrib/lite/kernels/activation_functor.h"
 #include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
 #include "tensorflow/contrib/lite/kernels/op_macros.h"
 
 namespace tflite {
@@ -38,17 +39,26 @@ constexpr int kBiasTensor = 3;
 constexpr int kHiddenStateTensor = 0;
 constexpr int kOutputTensor = 1;
 
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+  auto* scratch_tensor_index = new int;
+  context->AddTensors(context, /*tensors_to_add=*/2, scratch_tensor_index);
+  return scratch_tensor_index;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+  delete reinterpret_cast<int*>(buffer);
+}
+
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // Check we have all the inputs and outputs we need.
   TF_LITE_ENSURE_EQ(context, node->inputs->size, 4);
   TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
 
-  TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
-  TfLiteTensor* input_weights =
-      &context->tensors[node->inputs->data[kWeightsTensor]];
+  TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor);
   TfLiteTensor* recurrent_weights =
-      &context->tensors[node->inputs->data[kRecurrentWeightsTensor]];
-  TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]];
+      GetInput(context, node, kRecurrentWeightsTensor);
+  TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
 
   // Check all the parameters of tensor match within themselves and match the
   // input configuration.
@@ -64,9 +74,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[0], bias->dims->data[0]);
   TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[1], bias->dims->data[0]);
 
-  TfLiteTensor* hidden_state =
-      &context->tensors[node->outputs->data[kHiddenStateTensor]];
-  TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]];
+  TfLiteTensor* hidden_state = GetOutput(context, node, kHiddenStateTensor);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
 
   // Resize state.
   TfLiteIntArray* hidden_state_size_array = TfLiteIntArrayCreate(2);
@@ -86,22 +95,44 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_OK(context,
                     context->ResizeTensor(context, output, output_size_array));
 
+  // Allocate temporary tensors to store quantized values of input and
+  // hidden_state tensors.
+  if (input->type == kTfLiteFloat32 && input_weights->type == kTfLiteUInt8) {
+    int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
+    TfLiteIntArrayFree(node->temporaries);
+    node->temporaries = TfLiteIntArrayCreate(2);
+    node->temporaries->data[0] = *scratch_tensor_index;
+    TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
+    input_quantized->type = kTfLiteUInt8;
+    input_quantized->allocation_type = kTfLiteArenaRw;
+    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
+      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
+      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
+                                                       input_quantized_size));
+    }
+    node->temporaries->data[1] = *scratch_tensor_index + 1;
+    TfLiteTensor* hidden_state_quantized =
+        GetTemporary(context, node, /*index=*/1);
+    hidden_state_quantized->type = kTfLiteUInt8;
+    hidden_state_quantized->allocation_type = kTfLiteArenaRw;
+    if (!TfLiteIntArrayEqual(hidden_state_quantized->dims,
+                             hidden_state->dims)) {
+      TfLiteIntArray* hidden_state_quantized_size =
+          TfLiteIntArrayCopy(hidden_state->dims);
+      TF_LITE_ENSURE_OK(context,
+                        context->ResizeTensor(context, hidden_state_quantized,
+                                              hidden_state_quantized_size));
+    }
+  }
   return kTfLiteOk;
 }
 
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);
-
-  TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
-  TfLiteTensor* input_weights =
-      &context->tensors[node->inputs->data[kWeightsTensor]];
-  TfLiteTensor* recurrent_weights =
-      &context->tensors[node->inputs->data[kRecurrentWeightsTensor]];
-  TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]];
-  TfLiteTensor* hidden_state =
-      &context->tensors[node->outputs->data[kHiddenStateTensor]];
-  TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]];
-
+TfLiteStatus EvalFloat(const TfLiteTensor* input,
+                       const TfLiteTensor* input_weights,
+                       const TfLiteTensor* recurrent_weights,
+                       const TfLiteTensor* bias,
+                       const TfLiteSequenceRNNParams* params,
+                       TfLiteTensor* hidden_state, TfLiteTensor* output) {
   // Initialize the pointer bias.
   const float* bias_ptr = bias->data.f;
 
@@ -120,7 +151,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   if (time_major) {
     // Initialize the pointer to hidden state.
     float* hidden_state_ptr_batch = hidden_state->data.f;
-    // Unroll the sequence and use batch batch operations for efficiency.
+    // Unroll the sequence and use batch operations for efficiency.
     for (int s = 0; s < max_time; s++) {
       // Initialize the pointer to input and output.
       const float* input_ptr_batch =
@@ -154,12 +185,115 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   return kTfLiteOk;
 }
 
+TfLiteStatus EvalQuantized(const TfLiteTensor* input,
+                           const TfLiteTensor* input_weights,
+                           const TfLiteTensor* recurrent_weights,
+                           const TfLiteTensor* bias,
+                           const TfLiteSequenceRNNParams* params,
+                           TfLiteTensor* input_scratch,
+                           TfLiteTensor* hidden_state_scratch,
+                           TfLiteTensor* hidden_state, TfLiteTensor* output) {
+  const bool time_major = params->time_major;
+  const int batch_size =
+      (time_major) ? input->dims->data[1] : input->dims->data[0];
+  const int max_time =
+      (time_major) ? input->dims->data[0] : input->dims->data[1];
+  const int num_units = input_weights->dims->data[0];
+  const int input_size = input->dims->data[2];
+
+  // Initialize the pointer bias.
+  const float* bias_ptr = bias->data.f;
+  // Initialize input_weights and recurrent_weights.
+  const int8_t* input_weights_ptr =
+      reinterpret_cast<const int8_t*>(input_weights->data.uint8);
+  const int8_t* recurrent_weights_ptr =
+      reinterpret_cast<const int8_t*>(recurrent_weights->data.uint8);
+  // Get the scale of the quantized weights.
+  float input_weights_scale = input_weights->params.scale;
+  float recurrent_weights_scale = recurrent_weights->params.scale;
+  // Initialize temporary storage for quantized values.
+  int8_t* quantized_input_ptr =
+      reinterpret_cast<int8_t*>(input_scratch->data.uint8);
+  int8_t* quantized_hidden_state_ptr =
+      reinterpret_cast<int8_t*>(hidden_state_scratch->data.uint8);
+
+  if (time_major) {
+    // Initialize the pointer to hidden state.
+    float* hidden_state_ptr_batch = hidden_state->data.f;
+    // Unroll the sequence and use batch operations for efficiency.
+    for (int s = 0; s < max_time; s++) {
+      // Initialize the pointer to input and output.
+      const float* input_ptr_batch =
+          input->data.f + s * input_size * batch_size;
+      float* output_ptr_batch = output->data.f + s * num_units * batch_size;
+
+      kernel_utils::RnnBatchStep(
+          input_ptr_batch, input_weights_ptr, input_weights_scale,
+          recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size,
+          num_units, batch_size, params->activation, quantized_input_ptr,
+          quantized_hidden_state_ptr, hidden_state_ptr_batch, output_ptr_batch);
+    }
+  } else {
+    // For each batch
+    for (int b = 0; b < batch_size; b++) {
+      // Initialize the pointer to hidden state.
+      float* hidden_state_ptr_batch = hidden_state->data.f + b * num_units;
+      for (int s = 0; s < max_time; s++) {
+        // Initialize the pointer to input and output.
+        const float* input_ptr_batch =
+            input->data.f + b * input_size * max_time + s * input_size;
+        float* output_ptr_batch =
+            output->data.f + b * num_units * max_time + s * num_units;
+
+        kernel_utils::RnnBatchStep(
+            input_ptr_batch, input_weights_ptr, input_weights_scale,
+            recurrent_weights_ptr, recurrent_weights_scale, bias_ptr,
+            input_size, num_units, /*batch_size=*/1, params->activation,
+            quantized_input_ptr, quantized_hidden_state_ptr,
+            hidden_state_ptr_batch, output_ptr_batch);
+      }
+    }
+  }
+  return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);
+
+  TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor);
+  TfLiteTensor* recurrent_weights =
+      GetInput(context, node, kRecurrentWeightsTensor);
+  TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
+  TfLiteTensor* hidden_state = GetOutput(context, node, kHiddenStateTensor);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  switch (input_weights->type) {
+    case kTfLiteFloat32:
+      return EvalFloat(input, input_weights, recurrent_weights, bias, params,
+                       hidden_state, output);
+    case kTfLiteUInt8: {
+      // TODO(mirkov): implement eval with quantized inputs as well.
+      TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+      TfLiteTensor* input_quantized = GetTemporary(context, node, 0);
+      TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1);
+      return EvalQuantized(input, input_weights, recurrent_weights, bias,
+                           params, input_quantized, hidden_state_quantized,
+                           hidden_state, output);
+    }
+    default:
+      context->ReportError(context, "Type not currently supported.");
+      return kTfLiteError;
+  }
+  return kTfLiteOk;
+}
+
 }  // namespace unidirectional_sequence_rnn
 
 TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_RNN() {
-  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
-                                 unidirectional_sequence_rnn::Prepare,
-                                 unidirectional_sequence_rnn::Eval};
+  static TfLiteRegistration r = {
+      unidirectional_sequence_rnn::Init, unidirectional_sequence_rnn::Free,
+      unidirectional_sequence_rnn::Prepare, unidirectional_sequence_rnn::Eval};
   return &r;
 }
--
cgit v1.2.3
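
Note on the hybrid quantized path added above: the weight tensors arrive as uint8 buffers that the kernel reinterprets as int8 with a single per-tensor scale (input_weights->params.scale, recurrent_weights->params.scale), and on every step the float input and hidden state are quantized into the two scratch tensors and the int32 accumulators are rescaled back to float, both inside kernel_utils::RnnBatchStep. The helpers below are a minimal illustrative sketch under those assumptions, not the TFLite implementation; the helper names are made up.

#include <algorithm>
#include <cmath>
#include <cstdint>

// Sketch only: symmetric per-tensor quantization of a float buffer to int8
// (zero point fixed at 0, scale = max_abs / 127). The real op relies on the
// existing kernel_utils/tensor_utils routines; this helper is hypothetical.
void SymmetricQuantizeSketch(const float* values, int size, int8_t* quantized,
                             float* scale) {
  float max_abs = 0.0f;
  for (int i = 0; i < size; ++i) {
    max_abs = std::max(max_abs, std::abs(values[i]));
  }
  *scale = (max_abs > 0.0f) ? max_abs / 127.0f : 1.0f;  // Guard all-zero input.
  for (int i = 0; i < size; ++i) {
    const int32_t q = static_cast<int32_t>(std::round(values[i] / *scale));
    quantized[i] = static_cast<int8_t>(
        std::min<int32_t>(127, std::max<int32_t>(-127, q)));
  }
}

// Sketch only: one output unit of a hybrid matmul. The int8 weight row and
// the int8 quantized input are multiplied and accumulated in int32, then the
// accumulator is rescaled to float by the product of the two scales. This is
// the arithmetic that motivates passing input_weights_scale and
// recurrent_weights_scale into RnnBatchStep.
float HybridDotSketch(const int8_t* weights_row, float weights_scale,
                      const int8_t* quantized_input, float input_scale,
                      int size) {
  int32_t acc = 0;
  for (int i = 0; i < size; ++i) {
    acc += static_cast<int32_t>(weights_row[i]) *
           static_cast<int32_t>(quantized_input[i]);
  }
  return static_cast<float>(acc) * weights_scale * input_scale;
}

Only the matrix multiplications run in 8-bit under this scheme; the bias add, the activation, and the hidden state and output tensors stay in float, which is why both EvalFloat and EvalQuantized keep float hidden_state and output.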