diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc | 18 |
1 files changed, 11 insertions, 7 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc index 323eefcd3a..40cd6dea82 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -32,7 +32,10 @@ void SetDataTypeForAllOutputs(Model* model, Operator* op, } } // namespace -bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { +::tensorflow::Status PropagateArrayDataTypes::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; auto it = model->operators.begin() + op_index; auto* op = it->get(); @@ -40,7 +43,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { for (const auto& input : op->inputs) { if (!model->IsOptionalArray(input) && model->GetArray(input).data_type == ArrayDataType::kNone) { - return false; + return ::tensorflow::Status::OK(); } } // Record data types of output before processing, so we can see at the @@ -131,7 +134,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { auto* rand_op = static_cast<RandomUniformOperator*>(op); // The output type of RandomUniform is specified with an attribute if (rand_op->dtype == ArrayDataType::kNone) { - return false; + return ::tensorflow::Status::OK(); } CHECK_EQ(op->outputs.size(), 1); SetDataTypeForAllOutputs(model, op, rand_op->dtype); @@ -153,7 +156,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { // This can make unsupported_op->output_data_types have more elements than // op->outputs. if (unsupported_op->output_data_types.size() < op->outputs.size()) { - return false; + return ::tensorflow::Status::OK(); } for (int i = 0; i < op->outputs.size(); ++i) { const string& output = op->outputs[i]; @@ -164,7 +167,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { } case OperatorType::kExpandDims: { // Yield on ExpandDim until it is converted to Reshape - return false; + return ::tensorflow::Status::OK(); } case OperatorType::kSelect: { // Select produces outputs with the same type as their 2nd input @@ -248,10 +251,11 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { // Return true if any output data type changed, false if none changed. for (const auto& output : op->outputs) { if (old_output_data_types[output] != model->GetArray(output).data_type) { - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } - return false; + return ::tensorflow::Status::OK(); } } // namespace toco |