aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc9
1 files changed, 9 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc
index 404f27e067..5295eeccec 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc
@@ -59,6 +59,15 @@ bool IsReshapeTrivial(const Model& model, const Operator& op,
if (CountOpsWithInput(model, op.outputs[0]) == 1) {
const auto* next_op = GetOpWithInput(model, op.outputs[0]);
if (next_op->type == OperatorType::kReshape) {
+ if (!IsDiscardableArray(model, next_op->outputs[0])) {
+ // If the |next_op| output is used as a model output we need to preserve
+ // its shape.
+ transformation->AddMessageF(
+ "%s cannot be merged into following reshape %s as it is "
+ "non-discardable and must keep the specified shape",
+ LogName(op), LogName(*next_op));
+ return false;
+ }
transformation->AddMessageF(
"%s is trivial because its output is only consumed by another "
"Reshape op %s",