diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc index 45335fd78c..3f768bfee1 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc @@ -146,16 +146,19 @@ bool MergeLstmCellInputs::Run(Model* model, std::size_t op_index) { lstm_cell_op->inputs[LstmCellOperator::PREV_ACTIV_INPUT] = prev_activ_input; lstm_cell_op->inputs[LstmCellOperator::PREV_STATE_INPUT] = prev_state_input; - // Reorder LstmCell's 4 outputs. + // Reorder LstmCell's 3 outputs. lstm_cell_op->outputs.resize(LstmCellOperator::NUM_OUTPUTS); lstm_cell_op->outputs[LstmCellOperator::ACTIV_OUTPUT] = src_op->outputs[kOutputTensor]; lstm_cell_op->outputs[LstmCellOperator::STATE_OUTPUT] = src_op->outputs[kCellStateTensor]; - lstm_cell_op->outputs[LstmCellOperator::CONCAT_TEMP] = - src_op->outputs[kScratchBufferTensor]; lstm_cell_op->outputs[LstmCellOperator::ACTIV_TEMP] = src_op->outputs[kOutputStateTensor]; + // Create a new temp array for the fourth output. + const string& concat_temp_array_name = + AvailableArrayName(*model, base_name + "concat_temp"); + model->GetOrCreateArray(concat_temp_array_name); + lstm_cell_op->outputs[LstmCellOperator::CONCAT_TEMP] = concat_temp_array_name; // Add the op into model. model->operators.emplace(op_it, std::move(lstm_cell_op)); |