diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-25 07:03:40 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-25 07:09:12 -0700 |
commit | 2dff919b48799171c3a95acaea9e790cdcadb0c3 (patch) | |
tree | 942495387da778c1d4b7de2f44c44bd9cbae6bc1 /tensorflow/contrib/lite/kernels/lstm.cc | |
parent | a9d0bf9afc323be9ca52e1a23c52c3238a9b17cf (diff) |
Update kernel evals to use new kernel signatures.
PiperOrigin-RevId: 214432840
Diffstat (limited to 'tensorflow/contrib/lite/kernels/lstm.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/lstm.cc | 48 |
1 files changed, 28 insertions, 20 deletions
diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc index aaa3ce966e..5b996d00bc 100644 --- a/tensorflow/contrib/lite/kernels/lstm.cc +++ b/tensorflow/contrib/lite/kernels/lstm.cc @@ -893,18 +893,21 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { activation_out->type == kTfLiteFloat32 && concat_temp->type == kTfLiteFloat32 && activation_temp->type == kTfLiteFloat32) { + tflite::LstmCellParams op_params; + // Float LSTM cell does not need parameters to be set: leave untouched. optimized_ops::LstmCell( + op_params, // Inputs. - GetTensorData<float>(input), GetTensorDims(input), - GetTensorData<float>(prev_activation), GetTensorDims(prev_activation), - GetTensorData<float>(weights), GetTensorDims(weights), - GetTensorData<float>(bias), GetTensorDims(bias), - GetTensorData<float>(prev_state), GetTensorDims(prev_state), + GetTensorShape(input), GetTensorData<float>(input), + GetTensorShape(prev_activation), GetTensorData<float>(prev_activation), + GetTensorShape(weights), GetTensorData<float>(weights), + GetTensorShape(bias), GetTensorData<float>(bias), + GetTensorShape(prev_state), GetTensorData<float>(prev_state), // Outputs. - GetTensorData<float>(state_out), GetTensorDims(state_out), - GetTensorData<float>(activation_out), GetTensorDims(activation_out), - GetTensorData<float>(concat_temp), GetTensorDims(concat_temp), - GetTensorData<float>(activation_temp), GetTensorDims(activation_temp)); + GetTensorShape(state_out), GetTensorData<float>(state_out), + GetTensorShape(activation_out), GetTensorData<float>(activation_out), + GetTensorShape(concat_temp), GetTensorData<float>(concat_temp), + GetTensorShape(activation_temp), GetTensorData<float>(activation_temp)); } else if (input->type == kTfLiteUInt8 && prev_activation->type == kTfLiteUInt8 && weights->type == kTfLiteUInt8 && bias->type == kTfLiteInt32 && @@ -934,20 +937,25 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { int accum_shift; tflite::QuantizeMultiplier(real_accum_multiplier, &accum_multiplier, &accum_shift); + tflite::LstmCellParams op_params; + op_params.weights_zero_point = weights->params.zero_point; + op_params.accum_multiplier = accum_multiplier; + op_params.accum_shift = accum_shift; optimized_ops::LstmCell<4>( + op_params, // Inputs. - GetTensorData<uint8_t>(input), GetTensorDims(input), - GetTensorData<uint8_t>(prev_activation), GetTensorDims(prev_activation), - GetTensorData<uint8_t>(weights), GetTensorDims(weights), - GetTensorData<int32_t>(bias), GetTensorDims(bias), - GetTensorData<int16_t>(prev_state), GetTensorDims(prev_state), + GetTensorShape(input), GetTensorData<uint8_t>(input), + GetTensorShape(prev_activation), + GetTensorData<uint8_t>(prev_activation), GetTensorShape(weights), + GetTensorData<uint8_t>(weights), GetTensorShape(bias), + GetTensorData<int32_t>(bias), GetTensorShape(prev_state), + GetTensorData<int16_t>(prev_state), // Outputs. - GetTensorData<int16_t>(state_out), GetTensorDims(state_out), - GetTensorData<uint8_t>(activation_out), GetTensorDims(activation_out), - GetTensorData<uint8_t>(concat_temp), GetTensorDims(concat_temp), - GetTensorData<int16_t>(activation_temp), GetTensorDims(activation_temp), - weights->params.zero_point, accum_multiplier, accum_shift, - gemm_context); + GetTensorShape(state_out), GetTensorData<int16_t>(state_out), + GetTensorShape(activation_out), GetTensorData<uint8_t>(activation_out), + GetTensorShape(concat_temp), GetTensorData<uint8_t>(concat_temp), + GetTensorShape(activation_temp), + GetTensorData<int16_t>(activation_temp), gemm_context); } else { context->ReportError(context, "Unsupported combination of data types for LstmCell"); |