aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/kernels/internal/kernel_utils.h')
-rw-r--r--tensorflow/contrib/lite/kernels/internal/kernel_utils.h83
1 files changed, 83 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
index f3f42f0840..2a11b37a60 100644
--- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
@@ -92,6 +92,89 @@ void LstmStep(
float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
float* output_ptr_batch);
+// Same as above but with quantized weight matrices. In detail:
+// Input of size 'n_batch * n_input':
+// input_ptr_batch
+//
+// LSTM weights:
+// Quantized input weights of size 'n_cell * n_input':
+// input_to_input_weights - optional (can be nullptr)
+// input_to_forget_weights
+// input_to_cell_weights
+// input_to_input_weights
+// Quantized recurrent weights of size 'n_cell * n_output':
+// recurrent_to_input_weights - optional
+// recurrent_to_forget_weights
+// recurrent_to_cell_weights
+// recurrent_to_input_weights
+// Quantized peephole weights of size 'n_cell', representing diagonal matrices.
+// cell_to_input_weights - optional
+// cell_to_cell_weights - optional
+// cell_to_output_weights - optional
+// Quantized projection weights of size 'n_output * n_cell'
+// projection_weights_ptr - optional
+// Weight scales (scalars) for each of the weights above.
+// input_to_input_weights_scale - optional
+// input_to_forget_weights_scale
+// input_to_cell_weights_scale
+// input_to_output_weights_scale
+// recurrent_to_input_weights_scale - optional
+// recurrent_to_forget_weights_scale
+// recurrent_to_cell_weights_scale
+// recurrent_to_output_weights_scale
+// cell_to_input_weights_scale,
+// cell_to_forget_weights_scale,
+// cell_to_output_weights_scale,
+// projection_weights_scale - optional
+// Gate biases of size 'n_cell':
+// input_gate_bias_ptr - optional
+// forget_gate_bias_ptr
+// cell_gate_bias_ptr
+// output_gate_bias_ptr
+//
+// Temporary pre-allocated storage for quantized values:
+// quantized_input_ptr_batch (same size as input_ptr_batch)
+// quantized_output_state_ptr (same size as output_state_ptr)
+// quantized_cell_state_ptr (same size as cell_state_ptr)
+// Temporary pre-allocated storage for recovered values:
+// recovered_cell_weights (same size as cell_to_*_weights)
+//
+// Outputs:
+// output_state_ptr - size 'n_batch * n_output'
+// cell_state_ptr - size 'n_batch * n_cell'
+// output_ptr_batch - size 'n_batch * n_output'
+void LstmStep(
+ const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
+ float input_to_input_weights_scale,
+ const int8_t* input_to_forget_weights_ptr,
+ float input_to_forget_weights_scale,
+ const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
+ const int8_t* input_to_output_weights_ptr,
+ float input_to_output_weights_scale,
+ const int8_t* recurrent_to_input_weights_ptr,
+ float recurrent_to_input_weights_scale,
+ const int8_t* recurrent_to_forget_weights_ptr,
+ float recurrent_to_forget_weights_scale,
+ const int8_t* recurrent_to_cell_weights_ptr,
+ float recurrent_to_cell_weights_scale,
+ const int8_t* recurrent_to_output_weights_ptr,
+ float recurrent_to_output_weights_scale,
+ const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
+ const int8_t* cell_to_forget_weights_ptr,
+ float cell_to_forget_weights_scale,
+ const int8_t* cell_to_output_weights_ptr,
+ float cell_to_output_weights_scale, const float* input_gate_bias_ptr,
+ const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
+ const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
+ float projection_weights_scale, const float* projection_bias_ptr,
+ const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
+ int n_output, float* input_gate_scratch, float* forget_gate_scratch,
+ float* cell_scratch, float* output_gate_scratch, float* scaling_factors,
+ float* product_scaling_factors, float* recovered_cell_weights,
+ int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr,
+ int8_t* quantized_cell_state_ptr, float* output_state_ptr,
+ float* cell_state_ptr, float* output_ptr_batch);
+
} // namespace kernel_utils
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_