author    A. Unique TensorFlower <gardener@tensorflow.org>    2018-05-31 15:11:26 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-05-31 15:13:58 -0700
commit 269a4ed1c27251b55cffe578b7bd969ec5975487 (patch)
tree   83dd602a71ad69b3fcb7b5ff5adc59c7adac3758 /tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
parent f21816ecefe3f6e554d3b7daae3bb7f7a03bad20 (diff)
Internal change.
PiperOrigin-RevId: 198787391
Diffstat (limited to 'tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc')
-rw-r--r--  tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc | 41
1 file changed, 26 insertions(+), 15 deletions(-)
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
index 8429dba54b..164a0cbd08 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
@@ -41,7 +41,7 @@ constexpr int kOutputTensor = 1;
 
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   auto* scratch_tensor_index = new int;
-  context->AddTensors(context, /*tensors_to_add=*/2, scratch_tensor_index);
+  context->AddTensors(context, /*tensors_to_add=*/3, scratch_tensor_index);
   return scratch_tensor_index;
 }
 
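Taken together with the Prepare hunk below, the three slots reserved here line up one-to-one with the node's temporaries. A sketch of the mapping (the data[1] assignment lives in unchanged code and is inferred, not shown in this diff):

// *scratch_tensor_index is the first of three consecutive tensor slots
// reserved by AddTensors; Prepare wires them up as temporaries:
//   node->temporaries->data[0] = *scratch_tensor_index;      // input_quantized
//   node->temporaries->data[1] = *scratch_tensor_index + 1;  // hidden_state_quantized (inferred)
//   node->temporaries->data[2] = *scratch_tensor_index + 2;  // scaling_factors (new)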
@@ -102,7 +102,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   if (input->type == kTfLiteFloat32 && input_weights->type == kTfLiteUInt8) {
     int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
     TfLiteIntArrayFree(node->temporaries);
-    node->temporaries = TfLiteIntArrayCreate(2);
+    node->temporaries = TfLiteIntArrayCreate(3);
     node->temporaries->data[0] = *scratch_tensor_index;
     TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
     input_quantized->type = kTfLiteUInt8;
@@ -125,6 +125,16 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                         context->ResizeTensor(context, hidden_state_quantized,
                                               hidden_state_quantized_size));
     }
+    node->temporaries->data[2] = *scratch_tensor_index + 2;
+    TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2);
+    scaling_factors->type = kTfLiteFloat32;
+    scaling_factors->allocation_type = kTfLiteArenaRw;
+    TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
+    scaling_factors_size->data[0] = batch_size;
+    if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
+      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
+                                                       scaling_factors_size));
+    }
   }
   return kTfLiteOk;
 }
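The new scaling_factors temporary is sized [batch_size] because hybrid kernels quantize each batch row of the float input symmetrically to int8 and keep one scale per row for dequantizing the accumulated products. A self-contained sketch of that per-row step, using a hypothetical SymmetricQuantizeRow rather than TF Lite's internal tensor utilities:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Symmetrically quantize one batch row to int8. The returned scaling factor
// is what a hybrid kernel would store in scaling_factors[batch].
void SymmetricQuantizeRow(const float* values, int size, int8_t* quantized,
                          float* scaling_factor) {
  float max_abs = 0.f;
  for (int i = 0; i < size; ++i) {
    max_abs = std::max(max_abs, std::fabs(values[i]));
  }
  if (max_abs == 0.f) {
    // All-zero row: nothing to scale.
    std::fill(quantized, quantized + size, 0);
    *scaling_factor = 1.f;
    return;
  }
  *scaling_factor = max_abs / 127.f;
  const float inverse_scale = 127.f / max_abs;
  for (int i = 0; i < size; ++i) {
    quantized[i] = static_cast<int8_t>(std::round(values[i] * inverse_scale));
  }
}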
@@ -187,14 +197,12 @@ TfLiteStatus EvalFloat(const TfLiteTensor* input,
   return kTfLiteOk;
 }
 
-TfLiteStatus EvalQuantized(const TfLiteTensor* input,
-                           const TfLiteTensor* input_weights,
-                           const TfLiteTensor* recurrent_weights,
-                           const TfLiteTensor* bias,
-                           const TfLiteSequenceRNNParams* params,
-                           TfLiteTensor* input_scratch,
-                           TfLiteTensor* hidden_state_scratch,
-                           TfLiteTensor* hidden_state, TfLiteTensor* output) {
+TfLiteStatus EvalHybrid(
+    const TfLiteTensor* input, const TfLiteTensor* input_weights,
+    const TfLiteTensor* recurrent_weights, const TfLiteTensor* bias,
+    const TfLiteSequenceRNNParams* params, TfLiteTensor* input_scratch,
+    TfLiteTensor* hidden_state_scratch, TfLiteTensor* scaling_factors,
+    TfLiteTensor* hidden_state, TfLiteTensor* output) {
   const bool time_major = params->time_major;
   const int batch_size =
       (time_major) ? input->dims->data[1] : input->dims->data[0];
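A small worked restatement of the batch_size selection above, under the usual layout assumption that input is [max_time, batch, input_size] when time_major and [batch, max_time, input_size] otherwise (hypothetical helper, for illustration only):

// For input dims {16, 4, 32}: time-major gives batch_size = 4,
// batch-major gives batch_size = 16.
int BatchSizeOf(const int* dims, bool time_major) {
  return time_major ? dims[1] : dims[0];
}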
@@ -218,6 +226,7 @@ TfLiteStatus EvalQuantized(const TfLiteTensor* input,
       reinterpret_cast<int8_t*>(input_scratch->data.uint8);
   int8_t* quantized_hidden_state_ptr =
       reinterpret_cast<int8_t*>(hidden_state_scratch->data.uint8);
+  float* scaling_factors_ptr = scaling_factors->data.f;
 
   if (time_major) {
     // Initialize the pointer to hidden state.
@@ -233,7 +242,8 @@ TfLiteStatus EvalQuantized(const TfLiteTensor* input,
           input_ptr_batch, input_weights_ptr, input_weights_scale,
           recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size,
           num_units, batch_size, params->activation, quantized_input_ptr,
-          quantized_hidden_state_ptr, hidden_state_ptr_batch, output_ptr_batch);
+          quantized_hidden_state_ptr, scaling_factors_ptr,
+          hidden_state_ptr_batch, output_ptr_batch);
     }
   } else {
     // For each batch
@@ -252,7 +262,7 @@ TfLiteStatus EvalQuantized(const TfLiteTensor* input,
             recurrent_weights_ptr, recurrent_weights_scale, bias_ptr,
             input_size, num_units, /*batch_size=*/1, params->activation,
             quantized_input_ptr, quantized_hidden_state_ptr,
-            hidden_state_ptr_batch, output_ptr_batch);
+            scaling_factors_ptr, hidden_state_ptr_batch, output_ptr_batch);
       }
     }
   }
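The scaling_factors_ptr threaded into the batch-step calls above is consumed where the int8 products are rescaled back to float. A self-contained sketch of that accumulate-and-rescale idea (a hypothetical HybridMatVec, not the actual kernel-internals routine):

#include <cstdint>

// Multiply an int8-quantized weight matrix by an int8-quantized input row,
// accumulating in int32, then rescale into the float output using the
// per-row input scale times the static weight scale.
void HybridMatVec(const int8_t* weights, float weight_scale, int num_units,
                  int input_size, const int8_t* input, float input_scale,
                  float* output /* accumulated into */) {
  for (int r = 0; r < num_units; ++r) {
    int32_t acc = 0;
    for (int c = 0; c < input_size; ++c) {
      acc += static_cast<int32_t>(weights[r * input_size + c]) *
             static_cast<int32_t>(input[c]);
    }
    output[r] += static_cast<float>(acc) * input_scale * weight_scale;
  }
}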
@@ -278,9 +288,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       // TODO(mirkov): implement eval with quantized inputs as well.
       TfLiteTensor* input_quantized = GetTemporary(context, node, 0);
       TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1);
-      return EvalQuantized(input, input_weights, recurrent_weights, bias,
-                           params, input_quantized, hidden_state_quantized,
-                           hidden_state, output);
+      TfLiteTensor* scaling_factors = GetTemporary(context, node, 2);
+      return EvalHybrid(input, input_weights, recurrent_weights, bias, params,
+                        input_quantized, hidden_state_quantized,
+                        scaling_factors, hidden_state, output);
     }
     default:
       context->ReportError(context, "Type %d not currently supported.",