From 12e164d1e7c0b197f06d5d3c2ed26318b89b5e4c Mon Sep 17 00:00:00 2001 From: Yu-Cheng Ling Date: Tue, 9 Oct 2018 11:38:15 -0700 Subject: Return ::tensorflow::Status in Toco Graph Transformations. PiperOrigin-RevId: 216392908 --- .../toco/graph_transformations/identify_dilated_conv.cc | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) (limited to 'tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc') diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc index aac77eb39e..9e4a3005a1 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc @@ -168,7 +168,10 @@ bool ResolveDilatedConv(Model* model, Operator* conv_base_op, Operator* stb_op, return true; } -bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) { +::tensorflow::Status IdentifyDilatedConv::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto it = model->operators.begin() + op_index; auto* stb_op = it->get(); @@ -176,17 +179,17 @@ bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) { // *************************************************************************** // SpaceToBatch Op. if (stb_op->type != OperatorType::kSpaceToBatchND) { - return false; + return ::tensorflow::Status::OK(); } if (stb_op->inputs.size() != 3) { - return false; + return ::tensorflow::Status::OK(); } CHECK_EQ(stb_op->outputs.size(), 1); // Extract the dilation factor from Input[1] of SpaceToBatch // TODO(mjmatthews): Support 2D dilation factors. const auto& block_shape_array = model->GetArray(stb_op->inputs[1]); if (!block_shape_array.buffer) { - return false; + return ::tensorflow::Status::OK(); } CHECK_EQ(block_shape_array.shape().dimensions_count(), 1); int dilation_factor = @@ -195,7 +198,7 @@ bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) { // Expand Op auto* post_stb_op = GetOpWithInput(*model, stb_op->outputs[0]); if (!post_stb_op) { - return false; + return ::tensorflow::Status::OK(); } bool has_expand_op = false; if (post_stb_op->type == OperatorType::kExpandDims) { @@ -229,7 +232,8 @@ bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) { } } - return changed; + *modified = changed; + return ::tensorflow::Status::OK(); } } // namespace toco -- cgit v1.2.3