aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc33
1 files changed, 18 insertions, 15 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc
index c0b014b45e..7fd8f906e2 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc
@@ -132,7 +132,9 @@ bool MatchOperatorInputs(const Operator& op, const Model& model,
} // namespace
-bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status IdentifyLstmCell::Run(Model* model, std::size_t op_index,
+ bool* modified) {
+ *modified = false;
// This LSTM cell identification method is not invariant to commutation of
// commutative operator inputs. For example, if input[0] and input[1] of the
// final output multiplication were swapped, this method would not identify it
@@ -143,13 +145,13 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
auto op_it = model->operators.begin() + op_index;
Operator* final_output_mul = op_it->get();
if (final_output_mul->type != OperatorType::kMul) {
- return false;
+ return ::tensorflow::Status::OK();
}
Operator *state_output_tanh, *fc_output_sig;
if (!MatchOperatorInputs(*final_output_mul, *model, OperatorType::kTanh,
&state_output_tanh, OperatorType::kLogistic,
&fc_output_sig)) {
- return false;
+ return ::tensorflow::Status::OK();
}
// State output TanH
@@ -158,7 +160,7 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
Operator* state_combine_add;
if (!MatchOperatorInputs(*state_output_tanh, *model, OperatorType::kAdd,
&state_combine_add)) {
- return false;
+ return ::tensorflow::Status::OK();
}
// State forget & remember addition
@@ -166,7 +168,7 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
if (!MatchOperatorInputs(*state_combine_add, *model, OperatorType::kMul,
&state_forget_mul, OperatorType::kMul,
&state_remember_mul)) {
- return false;
+ return ::tensorflow::Status::OK();
}
const string prev_state = state_forget_mul->inputs[0];
@@ -175,7 +177,7 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
if (!MatchOperatorInputs(*state_forget_mul, *model, OperatorType::kNone,
nullptr, OperatorType::kLogistic,
&state_forget_sig)) {
- return false;
+ return ::tensorflow::Status::OK();
}
// State remember gate
@@ -183,40 +185,40 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
if (!MatchOperatorInputs(*state_remember_mul, *model, OperatorType::kLogistic,
&state_remember_sig, OperatorType::kTanh,
&state_info_tanh)) {
- return false;
+ return ::tensorflow::Status::OK();
}
// State remember "information" activation function
Operator* fc_output_split;
if (!MatchOperatorInputs(*state_info_tanh, *model, OperatorType::kSplit,
&fc_output_split)) {
- return false;
+ return ::tensorflow::Status::OK();
}
// State remember gate activation function
Operator* tmp;
if (!MatchOperatorInputs(*state_remember_sig, *model, OperatorType::kSplit,
&tmp) ||
(tmp != fc_output_split)) {
- return false;
+ return ::tensorflow::Status::OK();
}
// State forget gate activation function
if (!MatchOperatorInputs(*state_forget_sig, *model, OperatorType::kSplit,
&tmp) ||
(tmp != fc_output_split)) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Fully connected output activation function
if (!MatchOperatorInputs(*fc_output_sig, *model, OperatorType::kSplit,
&tmp) ||
(tmp != fc_output_split)) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Fully connected output split
Operator* fully_connected;
if (!MatchOperatorInputs(*fc_output_split, *model, OperatorType::kNone,
nullptr, OperatorType::kFullyConnected,
&fully_connected)) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Fully connected op
@@ -225,13 +227,13 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
OperatorType::kConcatenation, &concat_inputs,
OperatorType::kNone, nullptr, OperatorType::kNone,
nullptr)) {
- return false;
+ return ::tensorflow::Status::OK();
}
if (static_cast<FullyConnectedOperator*>(fully_connected)->weights_format !=
FullyConnectedWeightsFormat::kDefault) {
// Not yet implemented: experimental shuffled weights in fused LSTM cell.
- return false;
+ return ::tensorflow::Status::OK();
}
// Emplace a new LSTM cell operator
@@ -300,7 +302,8 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
model->operators.erase(FindOperator(model, *fully_connected));
DeleteArrayIfUnused(concat_inputs->outputs[0], model);
model->operators.erase(FindOperator(model, *concat_inputs));
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco