diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_gather_attributes.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/resolve_gather_attributes.cc | 20 |
1 files changed, 13 insertions, 7 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_gather_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_gather_attributes.cc index ce825c91af..69209b8dec 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_gather_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_gather_attributes.cc @@ -24,20 +24,25 @@ limitations under the License. namespace toco { -bool ResolveGatherAttributes::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveGatherAttributes::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; auto* gather_op = model->operators[op_index].get(); - if (gather_op->type != OperatorType::kGather) return false; + if (gather_op->type != OperatorType::kGather) + return ::tensorflow::Status::OK(); auto* op = static_cast<GatherOperator*>(gather_op); if (op->axis) { // Attributes already resolved - return false; + return ::tensorflow::Status::OK(); } - if (op->inputs.size() != 3) return false; - if (!IsConstantParameterArray(*model, op->inputs[2])) return false; + if (op->inputs.size() != 3) return ::tensorflow::Status::OK(); + if (!IsConstantParameterArray(*model, op->inputs[2])) + return ::tensorflow::Status::OK(); const auto& indices_array = model->GetArray(op->inputs[2]); - if (!indices_array.has_shape()) return false; + if (!indices_array.has_shape()) return ::tensorflow::Status::OK(); const auto& axis_data = indices_array.GetBuffer<ArrayDataType::kInt32>().data; CHECK_EQ(axis_data.size(), 1) << "Multidimensional gather not supported on " << LogName(*op); @@ -47,7 +52,8 @@ bool ResolveGatherAttributes::Run(Model* model, std::size_t op_index) { DeleteArrayIfUsedOnce(op->inputs[2], model); op->inputs.resize(2); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco |