aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc18
1 files changed, 17 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
index d26c3b2878..502de88f7c 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
@@ -274,6 +274,19 @@ bool PropagateMinMaxAmongArrays(Model* model,
return changed;
}
+bool HardcodeMinMaxForReshape(Model* model, Operator* op) {
+ Array& input = model->GetArray(op->inputs[0]);
+ Array& output = model->GetArray(op->outputs[0]);
+
+ // If input and output both exist or do not exist, do nothing.
+ if ((!input.minmax && !output.minmax) || (input.minmax && output.minmax)) {
+ return false;
+ }
+
+ // Otherwise propagate info amongst the input and output array.
+ return PropagateMinMaxAmongArrays(model, {op->inputs[0], op->outputs[0]});
+}
+
bool HardcodeMinMaxForLstmCell(Model* model, Operator* op) {
CHECK_EQ(op->inputs.size(), LstmCellOperator::NUM_INPUTS);
CHECK_EQ(op->outputs.size(), LstmCellOperator::NUM_OUTPUTS);
@@ -370,7 +383,6 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) {
case OperatorType::kSlice:
case OperatorType::kStridedSlice:
case OperatorType::kSqueeze:
- case OperatorType::kReshape:
case OperatorType::kExpandDims:
case OperatorType::kPad:
case OperatorType::kGather:
@@ -416,6 +428,10 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) {
changed = HardcodeMinMaxForLstmCell(model, op);
break;
+ case OperatorType::kReshape:
+ changed = HardcodeMinMaxForReshape(model, op);
+ break;
+
default:
break;
}