aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/lstm.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-25 07:03:40 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-25 07:09:12 -0700
commit2dff919b48799171c3a95acaea9e790cdcadb0c3 (patch)
tree942495387da778c1d4b7de2f44c44bd9cbae6bc1 /tensorflow/contrib/lite/kernels/lstm.cc
parenta9d0bf9afc323be9ca52e1a23c52c3238a9b17cf (diff)
Update kernel evals to use new kernel signatures.
PiperOrigin-RevId: 214432840
Diffstat (limited to 'tensorflow/contrib/lite/kernels/lstm.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/lstm.cc48
1 files changed, 28 insertions, 20 deletions
diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc
index aaa3ce966e..5b996d00bc 100644
--- a/tensorflow/contrib/lite/kernels/lstm.cc
+++ b/tensorflow/contrib/lite/kernels/lstm.cc
@@ -893,18 +893,21 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
activation_out->type == kTfLiteFloat32 &&
concat_temp->type == kTfLiteFloat32 &&
activation_temp->type == kTfLiteFloat32) {
+ tflite::LstmCellParams op_params;
+ // Float LSTM cell does not need parameters to be set: leave untouched.
optimized_ops::LstmCell(
+ op_params,
// Inputs.
- GetTensorData<float>(input), GetTensorDims(input),
- GetTensorData<float>(prev_activation), GetTensorDims(prev_activation),
- GetTensorData<float>(weights), GetTensorDims(weights),
- GetTensorData<float>(bias), GetTensorDims(bias),
- GetTensorData<float>(prev_state), GetTensorDims(prev_state),
+ GetTensorShape(input), GetTensorData<float>(input),
+ GetTensorShape(prev_activation), GetTensorData<float>(prev_activation),
+ GetTensorShape(weights), GetTensorData<float>(weights),
+ GetTensorShape(bias), GetTensorData<float>(bias),
+ GetTensorShape(prev_state), GetTensorData<float>(prev_state),
// Outputs.
- GetTensorData<float>(state_out), GetTensorDims(state_out),
- GetTensorData<float>(activation_out), GetTensorDims(activation_out),
- GetTensorData<float>(concat_temp), GetTensorDims(concat_temp),
- GetTensorData<float>(activation_temp), GetTensorDims(activation_temp));
+ GetTensorShape(state_out), GetTensorData<float>(state_out),
+ GetTensorShape(activation_out), GetTensorData<float>(activation_out),
+ GetTensorShape(concat_temp), GetTensorData<float>(concat_temp),
+ GetTensorShape(activation_temp), GetTensorData<float>(activation_temp));
} else if (input->type == kTfLiteUInt8 &&
prev_activation->type == kTfLiteUInt8 &&
weights->type == kTfLiteUInt8 && bias->type == kTfLiteInt32 &&
@@ -934,20 +937,25 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
int accum_shift;
tflite::QuantizeMultiplier(real_accum_multiplier, &accum_multiplier,
&accum_shift);
+ tflite::LstmCellParams op_params;
+ op_params.weights_zero_point = weights->params.zero_point;
+ op_params.accum_multiplier = accum_multiplier;
+ op_params.accum_shift = accum_shift;
optimized_ops::LstmCell<4>(
+ op_params,
// Inputs.
- GetTensorData<uint8_t>(input), GetTensorDims(input),
- GetTensorData<uint8_t>(prev_activation), GetTensorDims(prev_activation),
- GetTensorData<uint8_t>(weights), GetTensorDims(weights),
- GetTensorData<int32_t>(bias), GetTensorDims(bias),
- GetTensorData<int16_t>(prev_state), GetTensorDims(prev_state),
+ GetTensorShape(input), GetTensorData<uint8_t>(input),
+ GetTensorShape(prev_activation),
+ GetTensorData<uint8_t>(prev_activation), GetTensorShape(weights),
+ GetTensorData<uint8_t>(weights), GetTensorShape(bias),
+ GetTensorData<int32_t>(bias), GetTensorShape(prev_state),
+ GetTensorData<int16_t>(prev_state),
// Outputs.
- GetTensorData<int16_t>(state_out), GetTensorDims(state_out),
- GetTensorData<uint8_t>(activation_out), GetTensorDims(activation_out),
- GetTensorData<uint8_t>(concat_temp), GetTensorDims(concat_temp),
- GetTensorData<int16_t>(activation_temp), GetTensorDims(activation_temp),
- weights->params.zero_point, accum_multiplier, accum_shift,
- gemm_context);
+ GetTensorShape(state_out), GetTensorData<int16_t>(state_out),
+ GetTensorShape(activation_out), GetTensorData<uint8_t>(activation_out),
+ GetTensorShape(concat_temp), GetTensorData<uint8_t>(concat_temp),
+ GetTensorShape(activation_temp),
+ GetTensorData<int16_t>(activation_temp), gemm_context);
} else {
context->ReportError(context,
"Unsupported combination of data types for LstmCell");