aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc8
1 files changed, 5 insertions, 3 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc
index 3f768bfee1..5b6a984ee1 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc
@@ -33,9 +33,10 @@ bool MergeLstmCellInputs::Run(Model* model, std::size_t op_index) {
return false;
}
- // Already a compact LstmCell with LstmCellOperator::NUM_INPUTS of inputs,
- // do not need to merge cell inputs.
- if (src_op->inputs.size() == LstmCellOperator::NUM_INPUTS) {
+ // Already a compact LstmCell. Do not need to merge cell inputs.
+ const auto* src_lstm_op = static_cast<LstmCellOperator*>(src_op);
+ if (src_lstm_op->kernel_type != LstmCellOperator::KERNEL_FULL ||
+ src_lstm_op->inputs.size() != kExtendedLstmInputCount) {
return false;
}
@@ -136,6 +137,7 @@ bool MergeLstmCellInputs::Run(Model* model, std::size_t op_index) {
// Emplace a new LSTM cell operator (use basic 5 inputs kernel).
auto lstm_cell_op = absl::make_unique<LstmCellOperator>();
+ lstm_cell_op->kernel_type = LstmCellOperator::KERNEL_BASIC;
// Compact LstmCell's 5 inputs.
lstm_cell_op->inputs.resize(LstmCellOperator::NUM_INPUTS);