diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc | 26 |
1 files changed, 15 insertions, 11 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc index f6f95481b5..5400d395ff 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc @@ -41,11 +41,14 @@ bool ComputeFillArray(Model* model, FillOperator* op) { return true; } -bool ResolveConstantFill::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveConstantFill::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto fill_it = model->operators.begin() + op_index; auto* base_op = fill_it->get(); if (base_op->type != OperatorType::kFill) { - return false; + return ::tensorflow::Status::OK(); } auto* op = static_cast<FillOperator*>(base_op); @@ -55,44 +58,44 @@ bool ResolveConstantFill::Run(Model* model, std::size_t op_index) { auto& output_array = model->GetArray(op->outputs[0]); if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been set by PropagateArrayDataTypes - return false; + return ::tensorflow::Status::OK(); } if (!output_array.has_shape()) { // Yield until the output shape has been set by PropagateFixedShapes - return false; + return ::tensorflow::Status::OK(); } const auto& val_array = model->GetArray(op->inputs[1]); if (!val_array.has_shape()) { // Yield until the value shape has been resolved. - return false; + return ::tensorflow::Status::OK(); } if (!IsConstantParameterArray(*model, op->inputs[1])) { // Yield until the value is constant. - return false; + return ::tensorflow::Status::OK(); } CHECK_EQ(RequiredBufferSizeForShape(val_array.shape()), 1); switch (output_array.data_type) { case ArrayDataType::kFloat: if (!ComputeFillArray<ArrayDataType::kFloat>(model, op)) { - return false; + return ::tensorflow::Status::OK(); } break; case ArrayDataType::kUint8: if (!ComputeFillArray<ArrayDataType::kUint8>(model, op)) { - return false; + return ::tensorflow::Status::OK(); } break; case ArrayDataType::kInt32: if (!ComputeFillArray<ArrayDataType::kInt32>(model, op)) { - return false; + return ::tensorflow::Status::OK(); } break; case ArrayDataType::kInt64: if (!ComputeFillArray<ArrayDataType::kInt64>(model, op)) { - return false; + return ::tensorflow::Status::OK(); } break; default: @@ -114,7 +117,8 @@ bool ResolveConstantFill::Run(Model* model, std::size_t op_index) { // Erase the operator model->operators.erase(fill_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco |