aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc22
1 files changed, 14 insertions, 8 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc
index e760d08e5a..5a838168de 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc
@@ -24,29 +24,35 @@ limitations under the License.
namespace toco {
-bool ResolveSliceAttributes::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveSliceAttributes::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto slice_it = model->operators.begin() + op_index;
auto* slice_op = slice_it->get();
- if (slice_op->type != OperatorType::kSlice) return false;
+ if (slice_op->type != OperatorType::kSlice) return ::tensorflow::Status::OK();
auto* op = static_cast<SliceOperator*>(slice_op);
- if (!op->begin.empty()) return false;
+ if (!op->begin.empty()) return ::tensorflow::Status::OK();
CHECK_EQ(op->inputs.size(), 3);
- if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
- if (!IsConstantParameterArray(*model, op->inputs[2])) return false;
+ if (!IsConstantParameterArray(*model, op->inputs[1]))
+ return ::tensorflow::Status::OK();
+ if (!IsConstantParameterArray(*model, op->inputs[2]))
+ return ::tensorflow::Status::OK();
const auto& begin_array = model->GetArray(op->inputs[1]);
- if (!begin_array.has_shape()) return false;
+ if (!begin_array.has_shape()) return ::tensorflow::Status::OK();
const auto& size_array = model->GetArray(op->inputs[2]);
- if (!size_array.has_shape()) return false;
+ if (!size_array.has_shape()) return ::tensorflow::Status::OK();
op->begin = begin_array.GetBuffer<ArrayDataType::kInt32>().data;
op->size = size_array.GetBuffer<ArrayDataType::kInt32>().data;
// TODO(dkalenichenko): Delete the extra inputs?
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco