diff options
author | Alan Chiao <alanchiao@google.com> | 2018-08-23 13:42:01 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-23 13:45:36 -0700 |
commit | 9a774e4d2d31443ea694938bec41237b4d6bcf02 (patch) | |
tree | eba93b9e5a8c96c5c40b2aff448e01690eb8a0db /tensorflow/contrib/lite/kernels/lstm.cc | |
parent | 6fe361f80d4277ea879b3182e1d7148a65a8ca21 (diff) |
Remove 18-input/3-output LSTM in favor of 20-input/1-output LSTM that supports
state API.
PiperOrigin-RevId: 209991722
Diffstat (limited to 'tensorflow/contrib/lite/kernels/lstm.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/lstm.cc | 76 |
1 file changed, 16 insertions, 60 deletions
diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc index ba251c451e..74dc3f25f9 100644 --- a/tensorflow/contrib/lite/kernels/lstm.cc +++ b/tensorflow/contrib/lite/kernels/lstm.cc @@ -37,7 +37,7 @@ namespace builtin { namespace lstm { struct OpData { - // Which kernel type to use. Full kernel (18 or 20 inputs) or basic kernel + // Which kernel type to use. Full kernel (20 inputs) or basic kernel // (5 inputs). TfLiteLSTMKernelType kernel_type; @@ -47,7 +47,7 @@ struct OpData { int scratch_tensor_index; }; -// For full inputs kernel (18 or 20 inputs). +// For full inputs kernel (20-inputs). namespace full { // Input Tensors of size {n_batch, n_input} @@ -81,19 +81,13 @@ constexpr int kProjectionWeightsTensor = 16; // Optional // Projection bias tensor of size {n_output} constexpr int kProjectionBiasTensor = 17; // Optional -// If the node has 20 inputs, the following 2 tensors are used as state tensors. -// These are defined as variable tensors, and will be modified by this op. +// These state tensors are defined as variable tensors, and will be modified by +// this op. constexpr int kInputActivationStateTensor = 18; constexpr int kInputCellStateTensor = 19; // Output tensors. -// * If the node has 18 inputs, these 2 tensors are used as state tensors. -// * If the node has 20 inputs, these 2 tensors are ignored. -// TODO(ycling): Make the 2 output state tensors optional, and propagate the -// state to output tensors when the 2 tensors present. 
-constexpr int kOutputStateTensor = 0; -constexpr int kCellStateTensor = 1; -constexpr int kOutputTensor = 2; +constexpr int kOutputTensor = 0; void* Init(TfLiteContext* context, const char* buffer, size_t length) { auto* op_data = new OpData(); @@ -258,30 +252,12 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { OpData* op_data = reinterpret_cast<OpData*>(node->user_data); - TF_LITE_ENSURE_EQ(context, node->outputs->size, 3); - - // True if the node is using input variable state tensors. It means: - // * The state tensors are defined as inputs. In this case it would be the - // 19th and 20th input tensors. - // * Otherwise, the output tensors are used to store states. - bool use_input_variable_states; - if (node->inputs->size == 20) { - use_input_variable_states = true; - op_data->activation_state_tensor_index = - node->inputs->data[kInputActivationStateTensor]; - op_data->cell_state_tensor_index = - node->inputs->data[kInputCellStateTensor]; - } else if (node->inputs->size == 18) { - use_input_variable_states = false; - op_data->activation_state_tensor_index = - node->outputs->data[kOutputStateTensor]; - op_data->cell_state_tensor_index = node->outputs->data[kCellStateTensor]; - } else { - context->ReportError( - context, "The LSTM Full kernel expects 18 or 20 inputs. Got %d inputs", - node->inputs->size); - return kTfLiteError; - } + TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); + TF_LITE_ENSURE_EQ(context, node->inputs->size, 20); + + op_data->activation_state_tensor_index = + node->inputs->data[kInputActivationStateTensor]; + op_data->cell_state_tensor_index = node->inputs->data[kInputCellStateTensor]; // Inferring batch size, number of outputs and number of cells from the // input tensors. 
@@ -316,31 +292,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* cell_state = &context->tensors[op_data->cell_state_tensor_index]; - if (use_input_variable_states) { - // Check the shape of input state tensors. - // These tensor may be 1D or 2D. It's fine as long as the total size is - // correct. - TF_LITE_ENSURE_EQ(context, NumElements(activation_state), - n_batch * n_output); - TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell); - } else { - // If the state tensors are outputs, this function takes the - // responsibility to resize the state tensors. - TfLiteIntArray* activation_state_size = TfLiteIntArrayCreate(2); - activation_state_size->data[0] = n_batch; - activation_state_size->data[1] = n_output; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, activation_state, - activation_state_size)); - - TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2); - cell_size->data[0] = n_batch; - cell_size->data[1] = n_cell; - TF_LITE_ENSURE_OK(context, - context->ResizeTensor(context, cell_state, cell_size)); - // Mark state tensors as persistent tensors. - activation_state->allocation_type = kTfLiteArenaRwPersistent; - cell_state->allocation_type = kTfLiteArenaRwPersistent; - } + // Check the shape of input state tensors. + // These tensor may be 1D or 2D. It's fine as long as the total size is + // correct. + TF_LITE_ENSURE_EQ(context, NumElements(activation_state), n_batch * n_output); + TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell); // Resize the output tensors. TfLiteIntArray* output_size = TfLiteIntArrayCreate(2); |