diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc | 18 |
1 files changed, 17 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc index d26c3b2878..502de88f7c 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc @@ -274,6 +274,19 @@ bool PropagateMinMaxAmongArrays(Model* model, return changed; } +bool HardcodeMinMaxForReshape(Model* model, Operator* op) { + Array& input = model->GetArray(op->inputs[0]); + Array& output = model->GetArray(op->outputs[0]); + + // If input and output both exist or do not exist, do nothing. + if ((!input.minmax && !output.minmax) || (input.minmax && output.minmax)) { + return false; + } + + // Otherwise propagate info amongst the input and output array. + return PropagateMinMaxAmongArrays(model, {op->inputs[0], op->outputs[0]}); +} + bool HardcodeMinMaxForLstmCell(Model* model, Operator* op) { CHECK_EQ(op->inputs.size(), LstmCellOperator::NUM_INPUTS); CHECK_EQ(op->outputs.size(), LstmCellOperator::NUM_OUTPUTS); @@ -370,7 +383,6 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) { case OperatorType::kSlice: case OperatorType::kStridedSlice: case OperatorType::kSqueeze: - case OperatorType::kReshape: case OperatorType::kExpandDims: case OperatorType::kPad: case OperatorType::kGather: @@ -416,6 +428,10 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) { changed = HardcodeMinMaxForLstmCell(model, op); break; + case OperatorType::kReshape: + changed = HardcodeMinMaxForReshape(model, op); + break; + default: break; } |