diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc | 33 |
1 files changed, 18 insertions, 15 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc index c0b014b45e..7fd8f906e2 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc @@ -132,7 +132,9 @@ bool MatchOperatorInputs(const Operator& op, const Model& model, } // namespace -bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) { +::tensorflow::Status IdentifyLstmCell::Run(Model* model, std::size_t op_index, + bool* modified) { + *modified = false; // This LSTM cell identification method is not invariant to commutation of // commutative operator inputs. For example, if input[0] and input[1] of the // final output multiplication were swapped, this method would not identify it @@ -143,13 +145,13 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) { auto op_it = model->operators.begin() + op_index; Operator* final_output_mul = op_it->get(); if (final_output_mul->type != OperatorType::kMul) { - return false; + return ::tensorflow::Status::OK(); } Operator *state_output_tanh, *fc_output_sig; if (!MatchOperatorInputs(*final_output_mul, *model, OperatorType::kTanh, &state_output_tanh, OperatorType::kLogistic, &fc_output_sig)) { - return false; + return ::tensorflow::Status::OK(); } // State output TanH @@ -158,7 +160,7 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) { Operator* state_combine_add; if (!MatchOperatorInputs(*state_output_tanh, *model, OperatorType::kAdd, &state_combine_add)) { - return false; + return ::tensorflow::Status::OK(); } // State forget & remember addition @@ -166,7 +168,7 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) { if (!MatchOperatorInputs(*state_combine_add, *model, OperatorType::kMul, &state_forget_mul, OperatorType::kMul, &state_remember_mul)) { - return false; + return ::tensorflow::Status::OK(); } const string prev_state = state_forget_mul->inputs[0]; @@ -175,7 +177,7 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) { if (!MatchOperatorInputs(*state_forget_mul, *model, OperatorType::kNone, nullptr, OperatorType::kLogistic, &state_forget_sig)) { - return false; + return ::tensorflow::Status::OK(); } // State remember gate @@ -183,40 +185,40 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) { if (!MatchOperatorInputs(*state_remember_mul, *model, OperatorType::kLogistic, &state_remember_sig, OperatorType::kTanh, &state_info_tanh)) { - return false; + return ::tensorflow::Status::OK(); } // State remember "information" activation function Operator* fc_output_split; if (!MatchOperatorInputs(*state_info_tanh, *model, OperatorType::kSplit, &fc_output_split)) { - return false; + return ::tensorflow::Status::OK(); } // State remember gate activation function Operator* tmp; if (!MatchOperatorInputs(*state_remember_sig, *model, OperatorType::kSplit, &tmp) || (tmp != fc_output_split)) { - return false; + return ::tensorflow::Status::OK(); } // State forget gate activation function if (!MatchOperatorInputs(*state_forget_sig, *model, OperatorType::kSplit, &tmp) || (tmp != fc_output_split)) { - return false; + return ::tensorflow::Status::OK(); } // Fully connected output activation function if (!MatchOperatorInputs(*fc_output_sig, *model, OperatorType::kSplit, &tmp) || (tmp != fc_output_split)) { - return false; + return ::tensorflow::Status::OK(); } // Fully connected output split Operator* fully_connected; if (!MatchOperatorInputs(*fc_output_split, *model, OperatorType::kNone, nullptr, OperatorType::kFullyConnected, &fully_connected)) { - return false; + return ::tensorflow::Status::OK(); } // Fully connected op @@ -225,13 +227,13 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) { OperatorType::kConcatenation, &concat_inputs, OperatorType::kNone, nullptr, OperatorType::kNone, nullptr)) { - return false; + return ::tensorflow::Status::OK(); } if (static_cast<FullyConnectedOperator*>(fully_connected)->weights_format != FullyConnectedWeightsFormat::kDefault) { // Not yet implemented: experimental shuffled weights in fused LSTM cell. - return false; + return ::tensorflow::Status::OK(); } // Emplace a new LSTM cell operator @@ -300,7 +302,8 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) { model->operators.erase(FindOperator(model, *fully_connected)); DeleteArrayIfUnused(concat_inputs->outputs[0], model); model->operators.erase(FindOperator(model, *concat_inputs)); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco |