aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc9
1 files changed, 6 insertions, 3 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc
index 45335fd78c..3f768bfee1 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc
@@ -146,16 +146,19 @@ bool MergeLstmCellInputs::Run(Model* model, std::size_t op_index) {
lstm_cell_op->inputs[LstmCellOperator::PREV_ACTIV_INPUT] = prev_activ_input;
lstm_cell_op->inputs[LstmCellOperator::PREV_STATE_INPUT] = prev_state_input;
- // Reorder LstmCell's 4 outputs.
+ // Reorder LstmCell's 3 outputs.
lstm_cell_op->outputs.resize(LstmCellOperator::NUM_OUTPUTS);
lstm_cell_op->outputs[LstmCellOperator::ACTIV_OUTPUT] =
src_op->outputs[kOutputTensor];
lstm_cell_op->outputs[LstmCellOperator::STATE_OUTPUT] =
src_op->outputs[kCellStateTensor];
- lstm_cell_op->outputs[LstmCellOperator::CONCAT_TEMP] =
- src_op->outputs[kScratchBufferTensor];
lstm_cell_op->outputs[LstmCellOperator::ACTIV_TEMP] =
src_op->outputs[kOutputStateTensor];
+ // Create a new temp array for the fourth output.
+ const string& concat_temp_array_name =
+ AvailableArrayName(*model, base_name + "concat_temp");
+ model->GetOrCreateArray(concat_temp_array_name);
+ lstm_cell_op->outputs[LstmCellOperator::CONCAT_TEMP] = concat_temp_array_name;
// Add the op into model.
model->operators.emplace(op_it, std::move(lstm_cell_op));