author    Yu-Cheng Ling <ycling@google.com>  2018-06-19 12:35:44 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-06-19 12:38:27 -0700
commit    5fab6df2788937bee1cce3a4e8f5b9d1db7497ec (patch)
tree      ba18594841593a0b2a3eda55c076ca78c7bf0d4e /tensorflow/contrib/lite/kernels/lstm.cc
parent    8f19772410ec20010e9930f9765dbd3aaeb06111 (diff)
Support Variable Tensor API in LSTM Full kernel.
The TFLite LSTM op now accepts 5 inputs (basic kernel), or 18 or 20 inputs (full kernel). PiperOrigin-RevId: 201222516
Diffstat (limited to 'tensorflow/contrib/lite/kernels/lstm.cc')
-rw-r--r--  tensorflow/contrib/lite/kernels/lstm.cc | 161
1 file changed, 106 insertions(+), 55 deletions(-)
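In brief: Prepare() now derives the state tensor indices from the node's input count instead of always reading state through the outputs. A condensed sketch of that dispatch, assembled from the hunks below (error reporting elided; not a complete excerpt):

if (node->inputs->size == 20) {
  // The last 2 inputs are variable state tensors, modified in place by the op.
  op_data->activation_state_tensor_index =
      node->inputs->data[kInputActivationStateTensor];  // input 18
  op_data->cell_state_tensor_index =
      node->inputs->data[kInputCellStateTensor];  // input 19
} else if (node->inputs->size == 18) {
  // Legacy form: outputs 0 and 1 keep serving as state storage.
  op_data->activation_state_tensor_index =
      node->outputs->data[kOutputStateTensor];
  op_data->cell_state_tensor_index = node->outputs->data[kCellStateTensor];
} else {
  // Any other input count is rejected with kTfLiteError.
}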
diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc
index eb26a02455..1dda97c101 100644
--- a/tensorflow/contrib/lite/kernels/lstm.cc
+++ b/tensorflow/contrib/lite/kernels/lstm.cc
@@ -37,14 +37,17 @@ namespace builtin {
namespace lstm {
struct OpData {
- // Which kernel type to use. Full kernel (18-inputs) or basic kernel
- // (5-inputs).
+ // Which kernel type to use. Full kernel (18 or 20 inputs) or basic kernel
+ // (5 inputs).
TfLiteLSTMKernelType kernel_type;
- // Only used by full kernel.
+
+ // These fields are only used by the full kernel.
+ int activation_state_tensor_index;
+ int cell_state_tensor_index;
int scratch_tensor_index;
};
-// For full inputs kernel (18-inputs).
+// For full inputs kernel (18 or 20 inputs).
namespace full {
// Input Tensors of size {n_batch, n_input}
@@ -78,7 +81,16 @@ constexpr int kProjectionWeightsTensor = 16; // Optional
// Projection bias tensor of size {n_output}
constexpr int kProjectionBiasTensor = 17; // Optional
+// If the node has 20 inputs, the following 2 tensors are used as state tensors.
+// These are defined as variable tensors, and will be modified by this op.
+constexpr int kInputActivationStateTensor = 18;
+constexpr int kInputCellStateTensor = 19;
+
// Output tensors.
+// * If the node has 18 inputs, these 2 tensors are used as state tensors.
+// * If the node has 20 inputs, these 2 tensors are ignored.
+// TODO(ycling): Make the 2 output state tensors optional, and propagate the
+// state to output tensors when the 2 tensors are present.
constexpr int kOutputStateTensor = 0;
constexpr int kCellStateTensor = 1;
constexpr int kOutputTensor = 2;
@@ -246,10 +258,31 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
- // Check we have all the inputs and outputs we need.
- TF_LITE_ENSURE_EQ(context, node->inputs->size, 18);
TF_LITE_ENSURE_EQ(context, node->outputs->size, 3);
+ // True if the node uses input variable state tensors. It means:
+ // * If true, the state tensors are defined as inputs (the 19th and 20th
+ //   input tensors).
+ // * If false, the output tensors are used to store states.
+ bool use_input_variable_states;
+ if (node->inputs->size == 20) {
+ use_input_variable_states = true;
+ op_data->activation_state_tensor_index =
+ node->inputs->data[kInputActivationStateTensor];
+ op_data->cell_state_tensor_index =
+ node->inputs->data[kInputCellStateTensor];
+ } else if (node->inputs->size == 18) {
+ use_input_variable_states = false;
+ op_data->activation_state_tensor_index =
+ node->outputs->data[kOutputStateTensor];
+ op_data->cell_state_tensor_index = node->outputs->data[kCellStateTensor];
+ } else {
+ context->ReportError(
+ context, "The LSTM Full kernel expects 18 or 20 inputs. Got %d inputs",
+ node->inputs->size);
+ return kTfLiteError;
+ }
+
// Inferring batch size, number of outputs and number of cells from the
// input tensors.
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
@@ -274,34 +307,47 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Check that input tensor dimensions matches with each other.
CheckInputTensorDimensions(context, node, n_input, n_output, n_cell);
- // Get the pointer to output, output_state and cell_state tensors.
+ // Get the pointer to output, activation_state and cell_state tensors.
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
- TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
- // Resize the output, output_state and cell_state tensors.
+ TfLiteTensor* activation_state =
+ &context->tensors[op_data->activation_state_tensor_index];
+ TfLiteTensor* cell_state =
+ &context->tensors[op_data->cell_state_tensor_index];
+
+ if (use_input_variable_states) {
+ // Check the shape of input state tensors.
+ // These tensors may be 1D or 2D. It's fine as long as the total size is
+ // correct.
+ TF_LITE_ENSURE_EQ(context, NumElements(activation_state),
+ n_batch * n_output);
+ TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);
+ } else {
+ // If the state tensors are outputs, this function is responsible for
+ // resizing them.
+ TfLiteIntArray* activation_state_size = TfLiteIntArrayCreate(2);
+ activation_state_size->data[0] = n_batch;
+ activation_state_size->data[1] = n_output;
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, activation_state,
+ activation_state_size));
+
+ TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2);
+ cell_size->data[0] = n_batch;
+ cell_size->data[1] = n_cell;
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, cell_state, cell_size));
+ // Mark state tensors as persistent tensors.
+ activation_state->allocation_type = kTfLiteArenaRwPersistent;
+ cell_state->allocation_type = kTfLiteArenaRwPersistent;
+ }
+
+ // Resize the output tensors.
TfLiteIntArray* output_size = TfLiteIntArrayCreate(2);
output_size->data[0] = n_batch;
output_size->data[1] = n_output;
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, output, output_size));
- TfLiteIntArray* output_state_size = TfLiteIntArrayCreate(2);
- output_state_size->data[0] = n_batch;
- output_state_size->data[1] = n_output;
- TF_LITE_ENSURE_OK(
- context, context->ResizeTensor(context, output_state, output_state_size));
-
- TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2);
- cell_size->data[0] = n_batch;
- cell_size->data[1] = n_cell;
- TF_LITE_ENSURE_OK(context,
- context->ResizeTensor(context, cell_state, cell_size));
-
- // Mark state tensors as persistent tensors.
- output_state->allocation_type = kTfLiteArenaRwPersistent;
- cell_state->allocation_type = kTfLiteArenaRwPersistent;
-
// The weights are of consistent type, so it suffices to check one.
// TODO(mirkov): create a utility/macro for this check, so all Ops can use it.
const bool is_hybrid_op = (input_to_output_weights->type == kTfLiteUInt8 &&
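As a worked example of the shape check in the 20-input branch above: only the element count of an input state tensor is validated, so 1D and 2D layouts are both accepted. A minimal, self-contained sketch (the n_batch and n_output values are hypothetical, and NumElementsSketch stands in for the NumElements helper from kernel_util.h):

#include <cassert>
#include <vector>

// Stand-in for NumElements(): the product of all dimensions.
int NumElementsSketch(const std::vector<int>& dims) {
  int count = 1;
  for (int d : dims) count *= d;
  return count;
}

int main() {
  const int n_batch = 2, n_output = 16;  // hypothetical sizes
  // Both a flat {32} tensor and a {2, 16} tensor pass
  // TF_LITE_ENSURE_EQ(context, NumElements(activation_state), n_batch * n_output),
  // since only the total size matters.
  assert(NumElementsSketch({32}) == n_batch * n_output);
  assert(NumElementsSketch({2, 16}) == n_batch * n_output);
  return 0;
}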
@@ -337,7 +383,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
if (is_hybrid_op) {
// Allocate temporary tensors to store quantized values of input,
- // output_state and cell_state tensors.
+ // activation_state and cell_state tensors.
node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
input_quantized->type = kTfLiteUInt8;
@@ -348,17 +394,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
input_quantized_size));
}
node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
- TfLiteTensor* output_state_quantized =
+ TfLiteTensor* activation_state_quantized =
GetTemporary(context, node, /*index=*/2);
- output_state_quantized->type = kTfLiteUInt8;
- output_state_quantized->allocation_type = kTfLiteArenaRw;
- if (!TfLiteIntArrayEqual(output_state_quantized->dims,
- output_state->dims)) {
- TfLiteIntArray* output_state_quantized_size =
- TfLiteIntArrayCopy(output_state->dims);
- TF_LITE_ENSURE_OK(context,
- context->ResizeTensor(context, output_state_quantized,
- output_state_quantized_size));
+ activation_state_quantized->type = kTfLiteUInt8;
+ activation_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(activation_state_quantized->dims,
+ activation_state->dims)) {
+ TfLiteIntArray* activation_state_quantized_size =
+ TfLiteIntArrayCopy(activation_state->dims);
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, activation_state_quantized,
+ activation_state_quantized_size));
}
node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
TfLiteTensor* cell_state_quantized =
@@ -438,7 +484,7 @@ TfLiteStatus EvalFloat(
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer,
- TfLiteTensor* output_state, TfLiteTensor* cell_state,
+ TfLiteTensor* activation_state, TfLiteTensor* cell_state,
TfLiteTensor* output) {
const int n_batch = input->dims->data[0];
const int n_input = input->dims->data[1];
@@ -499,7 +545,7 @@ TfLiteStatus EvalFloat(
const float* cell_bias_ptr = cell_bias->data.f;
const float* output_gate_bias_ptr = output_gate_bias->data.f;
- float* output_state_ptr = output_state->data.f;
+ float* activation_state_ptr = activation_state->data.f;
float* cell_state_ptr = cell_state->data.f;
float* output_ptr_batch = output->data.f;
@@ -512,8 +558,8 @@ TfLiteStatus EvalFloat(
cell_to_output_weights_ptr, input_gate_bias_ptr, forget_gate_bias_ptr,
cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
projection_bias_ptr, params, n_batch, n_cell, n_input, n_output,
- output_state_ptr, cell_state_ptr, input_gate_scratch, forget_gate_scratch,
- cell_scratch, output_gate_scratch, output_ptr_batch);
+ activation_state_ptr, cell_state_ptr, input_gate_scratch,
+ forget_gate_scratch, cell_scratch, output_gate_scratch, output_ptr_batch);
return kTfLiteOk;
}
@@ -536,9 +582,9 @@ TfLiteStatus EvalHybrid(
const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer,
TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors,
TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
- TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
- TfLiteTensor* output_state, TfLiteTensor* cell_state,
- TfLiteTensor* output) {
+ TfLiteTensor* activation_state_quantized,
+ TfLiteTensor* cell_state_quantized, TfLiteTensor* activation_state,
+ TfLiteTensor* cell_state, TfLiteTensor* output) {
const int n_batch = input->dims->data[0];
const int n_input = input->dims->data[1];
// n_cell and n_output will be the same size when there is no projection.
@@ -639,15 +685,15 @@ TfLiteStatus EvalHybrid(
const float* cell_bias_ptr = cell_bias->data.f;
const float* output_gate_bias_ptr = output_gate_bias->data.f;
- float* output_state_ptr = output_state->data.f;
+ float* activation_state_ptr = activation_state->data.f;
float* cell_state_ptr = cell_state->data.f;
float* output_ptr_batch = output->data.f;
// Temporary storage for quantized values and scaling factors.
int8_t* quantized_input_ptr =
reinterpret_cast<int8_t*>(input_quantized->data.uint8);
- int8_t* quantized_output_state_ptr =
- reinterpret_cast<int8_t*>(output_state_quantized->data.uint8);
+ int8_t* quantized_activation_state_ptr =
+ reinterpret_cast<int8_t*>(activation_state_quantized->data.uint8);
int8_t* quantized_cell_state_ptr =
reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
float* scaling_factors_ptr = scaling_factors->data.f;
@@ -672,14 +718,16 @@ TfLiteStatus EvalHybrid(
input_gate_scratch, forget_gate_scratch, cell_scratch,
output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
recovered_cell_weights_ptr, quantized_input_ptr,
- quantized_output_state_ptr, quantized_cell_state_ptr, output_state_ptr,
- cell_state_ptr, output_ptr_batch);
+ quantized_activation_state_ptr, quantized_cell_state_ptr,
+ activation_state_ptr, cell_state_ptr, output_ptr_batch);
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+ OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
+
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input_to_input_weights =
@@ -723,8 +771,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// Index the scratch buffers pointers to the global scratch buffer.
TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
- TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
- TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
+ TfLiteTensor* activation_state =
+ &context->tensors[op_data->activation_state_tensor_index];
+ TfLiteTensor* cell_state =
+ &context->tensors[op_data->cell_state_tensor_index];
+
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
// TODO(mirkov): add a check that weights are all uint8s or all floats.
@@ -738,11 +789,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
cell_to_output_weights, input_gate_bias,
forget_gate_bias, cell_bias, output_gate_bias,
projection_weights, projection_bias, params,
- scratch_buffer, output_state, cell_state, output);
+ scratch_buffer, activation_state, cell_state, output);
}
case kTfLiteUInt8: {
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
- TfLiteTensor* output_state_quantized =
+ TfLiteTensor* activation_state_quantized =
GetTemporary(context, node, /*index=*/2);
TfLiteTensor* cell_state_quantized =
GetTemporary(context, node, /*index=*/3);
@@ -760,8 +811,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias,
projection_weights, projection_bias, params, scratch_buffer,
scaling_factors, prod_scaling_factors, recovered_cell_weights,
- input_quantized, output_state_quantized, cell_state_quantized,
- output_state, cell_state, output);
+ input_quantized, activation_state_quantized, cell_state_quantized,
+ activation_state, cell_state, output);
}
default:
context->ReportError(context, "Type %d is not currently supported.",