aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc18
1 files changed, 11 insertions, 7 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
index 323eefcd3a..40cd6dea82 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -32,7 +32,10 @@ void SetDataTypeForAllOutputs(Model* model, Operator* op,
}
} // namespace
-bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status PropagateArrayDataTypes::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto it = model->operators.begin() + op_index;
auto* op = it->get();
@@ -40,7 +43,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
for (const auto& input : op->inputs) {
if (!model->IsOptionalArray(input) &&
model->GetArray(input).data_type == ArrayDataType::kNone) {
- return false;
+ return ::tensorflow::Status::OK();
}
}
// Record data types of output before processing, so we can see at the
@@ -131,7 +134,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
auto* rand_op = static_cast<RandomUniformOperator*>(op);
// The output type of RandomUniform is specified with an attribute
if (rand_op->dtype == ArrayDataType::kNone) {
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK_EQ(op->outputs.size(), 1);
SetDataTypeForAllOutputs(model, op, rand_op->dtype);
@@ -153,7 +156,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
// This can make unsupported_op->output_data_types have more elements than
// op->outputs.
if (unsupported_op->output_data_types.size() < op->outputs.size()) {
- return false;
+ return ::tensorflow::Status::OK();
}
for (int i = 0; i < op->outputs.size(); ++i) {
const string& output = op->outputs[i];
@@ -164,7 +167,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
}
case OperatorType::kExpandDims: {
// Yield on ExpandDim until it is converted to Reshape
- return false;
+ return ::tensorflow::Status::OK();
}
case OperatorType::kSelect: {
// Select produces outputs with the same type as their 2nd input
@@ -248,10 +251,11 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
// Return true if any output data type changed, false if none changed.
for (const auto& output : op->outputs) {
if (old_output_data_types[output] != model->GetArray(output).data_type) {
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
}
- return false;
+ return ::tensorflow::Status::OK();
}
} // namespace toco