path: root/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-05-10 12:14:52 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-05-10 12:17:27 -0700
commit    bd95d55a2886677ba194351197d93c8b1408cc85 (patch)
tree      dab3692368df669482f035bbd97726a1980bca37 /tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
parent    3ffa132c03ff02decc86a31d8bf888e9381278a7 (diff)
Implementation of the unidirectional_sequence_rnn TFLite Op using symmetric quantization.
PiperOrigin-RevId: 196152754
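
Note: "symmetric quantization" here is per-tensor. Each weight tensor gets a
single float scale derived from its maximum absolute value, values are mapped
to int8 with no zero point, and dequantization is a multiply by the scale. A
minimal sketch of that mapping (an illustration only, not the TFLite code):

    // scale = max|x| / 127, q = round(x / scale), x ~= q * scale
    const float scale = max_abs / 127.0f;
    const int8_t q = static_cast<int8_t>(std::round(x / scale));
    const float x_approx = q * scale;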
Diffstat (limited to 'tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc')
-rw-r--r--  tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc | 184
1 file changed, 159 insertions(+), 25 deletions(-)
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
index ac00c37b67..5ae635bfda 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
namespace tflite {
@@ -38,17 +39,26 @@ constexpr int kBiasTensor = 3;
constexpr int kHiddenStateTensor = 0;
constexpr int kOutputTensor = 1;
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* scratch_tensor_index = new int;
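+ // Reserve two scratch tensor slots with the interpreter: one for the
+ // quantized input and one for the quantized hidden state. The index of the
+ // first slot comes back in scratch_tensor_index and is kept in user_data.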
+ context->AddTensors(context, /*tensors_to_add=*/2, scratch_tensor_index);
+ return scratch_tensor_index;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<int*>(buffer);
+}
+
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Check we have all the inputs and outputs we need.
TF_LITE_ENSURE_EQ(context, node->inputs->size, 4);
TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
- TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
- TfLiteTensor* input_weights =
- &context->tensors[node->inputs->data[kWeightsTensor]];
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor);
TfLiteTensor* recurrent_weights =
- &context->tensors[node->inputs->data[kRecurrentWeightsTensor]];
- TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]];
+ GetInput(context, node, kRecurrentWeightsTensor);
+ TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
// Check that the tensor parameters are consistent with each other and with
// the input configuration.
@@ -64,9 +74,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[0], bias->dims->data[0]);
TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[1], bias->dims->data[0]);
- TfLiteTensor* hidden_state =
- &context->tensors[node->outputs->data[kHiddenStateTensor]];
- TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]];
+ TfLiteTensor* hidden_state = GetOutput(context, node, kHiddenStateTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
// Resize state.
TfLiteIntArray* hidden_state_size_array = TfLiteIntArrayCreate(2);
@@ -86,22 +95,44 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, output, output_size_array));
+ // Allocate temporary tensors to store quantized values of input and
+ // hidden_state tensors.
+ if (input->type == kTfLiteFloat32 && input_weights->type == kTfLiteUInt8) {
+ int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
+ TfLiteIntArrayFree(node->temporaries);
+ node->temporaries = TfLiteIntArrayCreate(2);
+ node->temporaries->data[0] = *scratch_tensor_index;
+ TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
+ input_quantized->type = kTfLiteUInt8;
+ input_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
+ TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
+ input_quantized_size));
+ }
+ node->temporaries->data[1] = *scratch_tensor_index + 1;
+ TfLiteTensor* hidden_state_quantized =
+ GetTemporary(context, node, /*index=*/1);
+ hidden_state_quantized->type = kTfLiteUInt8;
+ hidden_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(hidden_state_quantized->dims,
+ hidden_state->dims)) {
+ TfLiteIntArray* hidden_state_quantized_size =
+ TfLiteIntArrayCopy(hidden_state->dims);
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, hidden_state_quantized,
+ hidden_state_quantized_size));
+ }
+ }
return kTfLiteOk;
}
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);
-
- TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
- TfLiteTensor* input_weights =
- &context->tensors[node->inputs->data[kWeightsTensor]];
- TfLiteTensor* recurrent_weights =
- &context->tensors[node->inputs->data[kRecurrentWeightsTensor]];
- TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]];
- TfLiteTensor* hidden_state =
- &context->tensors[node->outputs->data[kHiddenStateTensor]];
- TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]];
-
+TfLiteStatus EvalFloat(const TfLiteTensor* input,
+ const TfLiteTensor* input_weights,
+ const TfLiteTensor* recurrent_weights,
+ const TfLiteTensor* bias,
+ const TfLiteSequenceRNNParams* params,
+ TfLiteTensor* hidden_state, TfLiteTensor* output) {
// Initialize the pointer to the bias.
const float* bias_ptr = bias->data.f;
@@ -120,7 +151,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
if (time_major) {
// Initialize the pointer to hidden state.
float* hidden_state_ptr_batch = hidden_state->data.f;
- // Unroll the sequence and use batch batch operations for efficiency.
+ // Unroll the sequence and use batch operations for efficiency.
for (int s = 0; s < max_time; s++) {
// Initialize the pointer to input and output.
const float* input_ptr_batch =
@@ -154,12 +185,115 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
+TfLiteStatus EvalQuantized(const TfLiteTensor* input,
+ const TfLiteTensor* input_weights,
+ const TfLiteTensor* recurrent_weights,
+ const TfLiteTensor* bias,
+ const TfLiteSequenceRNNParams* params,
+ TfLiteTensor* input_scratch,
+ TfLiteTensor* hidden_state_scratch,
+ TfLiteTensor* hidden_state, TfLiteTensor* output) {
+ const bool time_major = params->time_major;
+ const int batch_size =
+ (time_major) ? input->dims->data[1] : input->dims->data[0];
+ const int max_time =
+ (time_major) ? input->dims->data[0] : input->dims->data[1];
+ const int num_units = input_weights->dims->data[0];
+ const int input_size = input->dims->data[2];
+
+ // Initialize the pointer to the bias.
+ const float* bias_ptr = bias->data.f;
+ // Initialize pointers to input_weights and recurrent_weights.
+ const int8_t* input_weights_ptr =
+ reinterpret_cast<const int8_t*>(input_weights->data.uint8);
+ const int8_t* recurrent_weights_ptr =
+ reinterpret_cast<const int8_t*>(recurrent_weights->data.uint8);
+ // Get the scale of the quantized weights.
+ float input_weights_scale = input_weights->params.scale;
+ float recurrent_weights_scale = recurrent_weights->params.scale;
+ // Initialize temporary storage for quantized values.
+ int8_t* quantized_input_ptr =
+ reinterpret_cast<int8_t*>(input_scratch->data.uint8);
+ int8_t* quantized_hidden_state_ptr =
+ reinterpret_cast<int8_t*>(hidden_state_scratch->data.uint8);
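+ // Hybrid scheme: the weights stay int8 with per-tensor float scales, while
+ // the float input and hidden state are re-quantized into these scratch
+ // buffers at every step; outputs remain float.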
+
+ if (time_major) {
+ // Initialize the pointer to hidden state.
+ float* hidden_state_ptr_batch = hidden_state->data.f;
+ // Unroll the sequence and use batch operations for efficiency.
+ for (int s = 0; s < max_time; s++) {
+ // Initialize the pointer to input and output.
+ const float* input_ptr_batch =
+ input->data.f + s * input_size * batch_size;
+ float* output_ptr_batch = output->data.f + s * num_units * batch_size;
+
+ kernel_utils::RnnBatchStep(
+ input_ptr_batch, input_weights_ptr, input_weights_scale,
+ recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size,
+ num_units, batch_size, params->activation, quantized_input_ptr,
+ quantized_hidden_state_ptr, hidden_state_ptr_batch, output_ptr_batch);
+ }
+ } else {
+ // For each batch
+ for (int b = 0; b < batch_size; b++) {
+ // Initialize the pointer to hidden state.
+ float* hidden_state_ptr_batch = hidden_state->data.f + b * num_units;
+ for (int s = 0; s < max_time; s++) {
+ // Initialize the pointer to input and output.
+ const float* input_ptr_batch =
+ input->data.f + b * input_size * max_time + s * input_size;
+ float* output_ptr_batch =
+ output->data.f + b * num_units * max_time + s * num_units;
+
+ kernel_utils::RnnBatchStep(
+ input_ptr_batch, input_weights_ptr, input_weights_scale,
+ recurrent_weights_ptr, recurrent_weights_scale, bias_ptr,
+ input_size, num_units, /*batch_size=*/1, params->activation,
+ quantized_input_ptr, quantized_hidden_state_ptr,
+ hidden_state_ptr_batch, output_ptr_batch);
+ }
+ }
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);
+
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor);
+ TfLiteTensor* recurrent_weights =
+ GetInput(context, node, kRecurrentWeightsTensor);
+ TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
+ TfLiteTensor* hidden_state = GetOutput(context, node, kHiddenStateTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
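+ // Dispatch on the weights type: float32 weights take the plain float path;
+ // uint8-stored weights (symmetric int8 values) take the hybrid quantized
+ // path, which still requires float32 activations.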
+ switch (input_weights->type) {
+ case kTfLiteFloat32:
+ return EvalFloat(input, input_weights, recurrent_weights, bias, params,
+ hidden_state, output);
+ case kTfLiteUInt8: {
+ // TODO(mirkov): implement eval with quantized inputs as well.
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+ TfLiteTensor* input_quantized = GetTemporary(context, node, 0);
+ TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1);
+ return EvalQuantized(input, input_weights, recurrent_weights, bias,
+ params, input_quantized, hidden_state_quantized,
+ hidden_state, output);
+ }
+ default:
+ context->ReportError(context, "Type not currently supported.");
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
} // namespace unidirectional_sequence_rnn
TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_RNN() {
- static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
- unidirectional_sequence_rnn::Prepare,
- unidirectional_sequence_rnn::Eval};
+ static TfLiteRegistration r = {
+ unidirectional_sequence_rnn::Init, unidirectional_sequence_rnn::Free,
+ unidirectional_sequence_rnn::Prepare, unidirectional_sequence_rnn::Eval};
return &r;
}
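
For orientation, a minimal self-contained sketch of what one hybrid step does
conceptually: quantize the float activations symmetrically, accumulate the
int8 products in int32, then rescale back to float before adding the bias and
applying the activation. This is an illustration only; the real logic lives in
kernel_utils::RnnBatchStep, and all helper names below are hypothetical.

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    // Symmetric per-tensor quantization: map floats to int8 with no zero point.
    static void SymmetricQuantize(const float* in, int n, int8_t* out,
                                  float* scale) {
      float max_abs = 0.f;
      for (int i = 0; i < n; ++i) max_abs = std::max(max_abs, std::fabs(in[i]));
      *scale = max_abs / 127.f;
      const float inv = (*scale > 0.f) ? 1.f / *scale : 0.f;
      for (int i = 0; i < n; ++i)
        out[i] = static_cast<int8_t>(std::round(in[i] * inv));
    }

    // One step for a single batch entry: output = act(W_x * x + W_h * h + b),
    // with both matmuls accumulated in int32 and rescaled by the product of
    // the weight scale and the on-the-fly activation scale.
    static void HybridRnnStep(const int8_t* input_weights,
                              float input_weights_scale,
                              const int8_t* recurrent_weights,
                              float recurrent_weights_scale, const float* bias,
                              const float* input, int input_size, int num_units,
                              float* hidden_state, float* output) {
      std::vector<int8_t> q_input(input_size);
      std::vector<int8_t> q_hidden(num_units);
      float input_scale, hidden_scale;
      SymmetricQuantize(input, input_size, q_input.data(), &input_scale);
      SymmetricQuantize(hidden_state, num_units, q_hidden.data(), &hidden_scale);
      for (int u = 0; u < num_units; ++u) {
        int32_t acc_in = 0, acc_rec = 0;
        for (int i = 0; i < input_size; ++i)
          acc_in += input_weights[u * input_size + i] * q_input[i];
        for (int i = 0; i < num_units; ++i)
          acc_rec += recurrent_weights[u * num_units + i] * q_hidden[i];
        const float pre_act = bias[u] +
                              acc_in * input_weights_scale * input_scale +
                              acc_rec * recurrent_weights_scale * hidden_scale;
        output[u] = std::max(0.f, pre_act);  // e.g. kTfLiteActRelu
      }
      // The new activations become the hidden state for the next time step.
      std::copy(output, output + num_units, hidden_state);
    }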