diff options
author | 2018-05-31 15:11:26 -0700 | |
---|---|---|
committer | 2018-05-31 15:13:58 -0700 | |
commit | 269a4ed1c27251b55cffe578b7bd969ec5975487 (patch) | |
tree | 83dd602a71ad69b3fcb7b5ff5adc59c7adac3758 /tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc | |
parent | f21816ecefe3f6e554d3b7daae3bb7f7a03bad20 (diff) |
Internal change.
PiperOrigin-RevId: 198787391
Diffstat (limited to 'tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc | 41 |
1 files changed, 26 insertions, 15 deletions
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc index 8429dba54b..164a0cbd08 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc @@ -41,7 +41,7 @@ constexpr int kOutputTensor = 1; void* Init(TfLiteContext* context, const char* buffer, size_t length) { auto* scratch_tensor_index = new int; - context->AddTensors(context, /*tensors_to_add=*/2, scratch_tensor_index); + context->AddTensors(context, /*tensors_to_add=*/3, scratch_tensor_index); return scratch_tensor_index; } @@ -102,7 +102,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { if (input->type == kTfLiteFloat32 && input_weights->type == kTfLiteUInt8) { int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data); TfLiteIntArrayFree(node->temporaries); - node->temporaries = TfLiteIntArrayCreate(2); + node->temporaries = TfLiteIntArrayCreate(3); node->temporaries->data[0] = *scratch_tensor_index; TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0); input_quantized->type = kTfLiteUInt8; @@ -125,6 +125,16 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { context->ResizeTensor(context, hidden_state_quantized, hidden_state_quantized_size)); } + node->temporaries->data[2] = *scratch_tensor_index + 2; + TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2); + scaling_factors->type = kTfLiteFloat32; + scaling_factors->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1); + scaling_factors_size->data[0] = batch_size; + if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) { + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, + scaling_factors_size)); + } } return kTfLiteOk; } @@ -187,14 +197,12 @@ TfLiteStatus EvalFloat(const TfLiteTensor* input, return kTfLiteOk; } -TfLiteStatus EvalQuantized(const TfLiteTensor* input, - const TfLiteTensor* input_weights, - const TfLiteTensor* recurrent_weights, - const TfLiteTensor* bias, - const TfLiteSequenceRNNParams* params, - TfLiteTensor* input_scratch, - TfLiteTensor* hidden_state_scratch, - TfLiteTensor* hidden_state, TfLiteTensor* output) { +TfLiteStatus EvalHybrid( + const TfLiteTensor* input, const TfLiteTensor* input_weights, + const TfLiteTensor* recurrent_weights, const TfLiteTensor* bias, + const TfLiteSequenceRNNParams* params, TfLiteTensor* input_scratch, + TfLiteTensor* hidden_state_scratch, TfLiteTensor* scaling_factors, + TfLiteTensor* hidden_state, TfLiteTensor* output) { const bool time_major = params->time_major; const int batch_size = (time_major) ? input->dims->data[1] : input->dims->data[0]; @@ -218,6 +226,7 @@ TfLiteStatus EvalQuantized(const TfLiteTensor* input, reinterpret_cast<int8_t*>(input_scratch->data.uint8); int8_t* quantized_hidden_state_ptr = reinterpret_cast<int8_t*>(hidden_state_scratch->data.uint8); + float* scaling_factors_ptr = scaling_factors->data.f; if (time_major) { // Initialize the pointer to hidden state. @@ -233,7 +242,8 @@ TfLiteStatus EvalQuantized(const TfLiteTensor* input, input_ptr_batch, input_weights_ptr, input_weights_scale, recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size, num_units, batch_size, params->activation, quantized_input_ptr, - quantized_hidden_state_ptr, hidden_state_ptr_batch, output_ptr_batch); + quantized_hidden_state_ptr, scaling_factors_ptr, + hidden_state_ptr_batch, output_ptr_batch); } } else { // For each batch @@ -252,7 +262,7 @@ TfLiteStatus EvalQuantized(const TfLiteTensor* input, recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size, num_units, /*batch_size=*/1, params->activation, quantized_input_ptr, quantized_hidden_state_ptr, - hidden_state_ptr_batch, output_ptr_batch); + scaling_factors_ptr, hidden_state_ptr_batch, output_ptr_batch); } } } @@ -278,9 +288,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // TODO(mirkov): implement eval with quantized inputs as well. TfLiteTensor* input_quantized = GetTemporary(context, node, 0); TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1); - return EvalQuantized(input, input_weights, recurrent_weights, bias, - params, input_quantized, hidden_state_quantized, - hidden_state, output); + TfLiteTensor* scaling_factors = GetTemporary(context, node, 2); + return EvalHybrid(input, input_weights, recurrent_weights, bias, params, + input_quantized, hidden_state_quantized, + scaling_factors, hidden_state, output); } default: context->ReportError(context, "Type %d not currently supported.", |