aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc16
1 files changed, 10 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc
index aac77eb39e..9e4a3005a1 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc
@@ -168,7 +168,10 @@ bool ResolveDilatedConv(Model* model, Operator* conv_base_op, Operator* stb_op,
return true;
}
-bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status IdentifyDilatedConv::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto it = model->operators.begin() + op_index;
auto* stb_op = it->get();
@@ -176,17 +179,17 @@ bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
// ***************************************************************************
// SpaceToBatch Op.
if (stb_op->type != OperatorType::kSpaceToBatchND) {
- return false;
+ return ::tensorflow::Status::OK();
}
if (stb_op->inputs.size() != 3) {
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK_EQ(stb_op->outputs.size(), 1);
// Extract the dilation factor from Input[1] of SpaceToBatch
// TODO(mjmatthews): Support 2D dilation factors.
const auto& block_shape_array = model->GetArray(stb_op->inputs[1]);
if (!block_shape_array.buffer) {
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK_EQ(block_shape_array.shape().dimensions_count(), 1);
int dilation_factor =
@@ -195,7 +198,7 @@ bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
// Expand Op
auto* post_stb_op = GetOpWithInput(*model, stb_op->outputs[0]);
if (!post_stb_op) {
- return false;
+ return ::tensorflow::Status::OK();
}
bool has_expand_op = false;
if (post_stb_op->type == OperatorType::kExpandDims) {
@@ -229,7 +232,8 @@ bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
}
}
- return changed;
+ *modified = changed;
+ return ::tensorflow::Status::OK();
}
} // namespace toco