aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-29 11:46:29 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-29 11:49:44 -0700
commit406afc79b7b7b8b277e3794137c41b6049199005 (patch)
tree9ea62feb766d0e33fcc9b59c1c6b11d5c0599b5f
parentc8e967357ef0bf040e85e1fb1aa85af54e8d5689 (diff)
More support for fused quantized LSTM in TFLite interpreter
PiperOrigin-RevId: 202682712
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc10
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc108
2 files changed, 35 insertions, 83 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
index 39f55208e4..2f1bb8f0ad 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
@@ -228,6 +228,14 @@ bool HardcodeMinMaxForOutput(Model* model, Operator* op, double min,
return true;
}
+bool MinMaxApproximatelyEqual(const MinMax& minmax1, const MinMax& minmax2) {
+ const double magnitude =
+ std::min(minmax1.max - minmax1.min, minmax2.max - minmax2.min);
+ const double tolerated = 1e-6 * magnitude;
+ return std::abs(minmax1.min - minmax2.min) < tolerated &&
+ std::abs(minmax1.max - minmax2.max) < tolerated;
+}
+
// Propagates MinMax from any of the listed arrays, to all others.
// If multiple of these arrays have MinMax, then these are required
// to agree with each other.
@@ -250,7 +258,7 @@ bool PropagateMinMaxAmongArrays(Model* model,
for (const string& array_name : array_names) {
auto& array = model->GetArray(array_name);
if (array.minmax) {
- CHECK(*array.minmax == *reference_minmax)
+ CHECK(MinMaxApproximatelyEqual(*array.minmax, *reference_minmax))
<< "Both the following arrays have minmax, and they disagree: "
<< reference_array_name << " (" << reference_minmax->min << ","
<< reference_minmax->max << ") and " << array_name << " ("
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc
index 3ca7f53512..c0b014b45e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc
@@ -35,6 +35,26 @@ std::vector<std::unique_ptr<Operator>>::iterator FindOperator(
return it;
}
+bool ValidateSourceOp(const Model& model, const string& array_name,
+ OperatorType op_type, Operator** source_op) {
+ if (op_type == OperatorType::kNone) {
+ CHECK(!source_op);
+ } else {
+ CHECK(source_op);
+ *source_op = GetOpWithOutput(model, array_name);
+ if (*source_op == nullptr) {
+ return false;
+ }
+
+ // Check that first operator, if connected, is of correct type
+ if ((*source_op)->type != op_type) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
// Returns true if the given operator has exactly 1 input, and is connected to
// the given op_type.
// We use kNone to indicate an input unattached to an operator output. Usually
@@ -47,24 +67,10 @@ bool MatchOperatorInputs(const Operator& op, const Model& model,
}
// Check if first input is disconnected/connected to an operator
- Operator* x = GetOpWithOutput(model, op.inputs[0]);
- if ((op_type == OperatorType::kNone) && (x != nullptr)) {
- return false;
- }
- if ((op_type != OperatorType::kNone) && (x == nullptr)) {
+ if (!ValidateSourceOp(model, op.inputs[0], op_type, connected_op)) {
return false;
}
- // Check that first operator, if connected, is of correct type
- if ((x != nullptr) && (x->type != op_type)) {
- return false;
- }
-
- // Successfully matched. Optionally return matching input operators.
- if (connected_op) {
- *connected_op = x;
- }
-
return true;
}
@@ -81,40 +87,15 @@ bool MatchOperatorInputs(const Operator& op, const Model& model,
}
// Check if first input is disconnected/connected to an operator
- Operator* x = GetOpWithOutput(model, op.inputs[0]);
- if ((a_op_type == OperatorType::kNone) && (x != nullptr)) {
- return false;
- }
- if ((a_op_type != OperatorType::kNone) && (x == nullptr)) {
- return false;
- }
-
- // Check that first operator, if connected, is of correct type
- if ((x != nullptr) && (x->type != a_op_type)) {
+ if (!ValidateSourceOp(model, op.inputs[0], a_op_type, a_op)) {
return false;
}
// Check if second input is disconnected/connected to an operator
- Operator* y = GetOpWithOutput(model, op.inputs[1]);
- if ((b_op_type == OperatorType::kNone) && (y != nullptr)) {
- return false;
- }
- if ((b_op_type != OperatorType::kNone) && (y == nullptr)) {
+ if (!ValidateSourceOp(model, op.inputs[1], b_op_type, b_op)) {
return false;
}
- // Check that second operator, if connected, is of correct type
- if ((y != nullptr) && (y->type != b_op_type)) {
- return false;
- }
-
- // Successfully matched. Optionally return matching input operators.
- if (a_op != nullptr) {
- *a_op = x;
- }
- if (b_op != nullptr) {
- *b_op = y;
- }
return true;
}
@@ -132,57 +113,20 @@ bool MatchOperatorInputs(const Operator& op, const Model& model,
}
// Check if first input is disconnected/connected to an operator
- Operator* x = GetOpWithOutput(model, op.inputs[0]);
- if ((a_op_type == OperatorType::kNone) && (x != nullptr)) {
- return false;
- }
- if ((a_op_type != OperatorType::kNone) && (x == nullptr)) {
- return false;
- }
-
- // Check that first operator, if connected, is of correct type
- if ((x != nullptr) && (x->type != a_op_type)) {
+ if (!ValidateSourceOp(model, op.inputs[0], a_op_type, a_op)) {
return false;
}
// Check if second input is disconnected/connected to an operator
- Operator* y = GetOpWithOutput(model, op.inputs[1]);
- if ((b_op_type == OperatorType::kNone) && (y != nullptr)) {
- return false;
- }
- if ((b_op_type != OperatorType::kNone) && (y == nullptr)) {
- return false;
- }
-
- // Check that second operator, if connected, is of correct type
- if ((y != nullptr) && (y->type != b_op_type)) {
+ if (!ValidateSourceOp(model, op.inputs[1], b_op_type, b_op)) {
return false;
}
// Check if third input is disconnected/connected to an operator
- Operator* z = GetOpWithOutput(model, op.inputs[2]);
- if ((c_op_type == OperatorType::kNone) && (z != nullptr)) {
- return false;
- }
- if ((c_op_type != OperatorType::kNone) && (z == nullptr)) {
+ if (!ValidateSourceOp(model, op.inputs[2], c_op_type, c_op)) {
return false;
}
- // Check that third operator, if connected, is of correct type
- if ((z != nullptr) && (z->type != c_op_type)) {
- return false;
- }
-
- // Successfully matched. Optionally return matching input operators.
- if (a_op != nullptr) {
- *a_op = x;
- }
- if (b_op != nullptr) {
- *b_op = y;
- }
- if (c_op != nullptr) {
- *c_op = z;
- }
return true;
}