aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc
blob: 3cd5d06baebc5a7a1807f156a52a24675343be7d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"

namespace toco {

// Deletes the operator at |op_index| when nothing needs the arrays it
// produces, together with its discardable outputs and any graph-input arrays
// that only this operator consumed.
//
// An output does not count as "needed" when it is:
//   - one of the model's input arrays (the graph is being cropped there), or
//   - an RNN state array (we zero-initialize RNN states ourselves, so e.g. a
//     Fill/Identity op feeding one is redundant).
// An output is "needed" when it is a model output array, when it feeds an RNN
// back-edge that cannot be dropped, or when any operator consumes it.
//
// Sets *modified to true only when the operator was actually erased.
::tensorflow::Status RemoveUnusedOp::Run(Model* model, std::size_t op_index,
                                         bool* modified) {
  *modified = false;
  const auto op_it = model->operators.begin() + op_index;
  const auto* op = op_it->get();

  // Returns true when |output| must still be produced by |op|. Any true
  // result means the operator has to stay in the graph.
  const auto output_is_needed = [model, op](const string& output) -> bool {
    CHECK(model->HasArray(output));
    // Supplied directly as a model input: nobody needs |op| to compute it.
    if (IsInputArray(*model, output)) {
      return false;
    }
    // Supplied as an RNN state array: also not needed from |op|. So far this
    // has only been seen with TensorFlow ops zero-initializing RNN states.
    for (const auto& rnn_state : model->flags.rnn_states()) {
      if (output != rnn_state.state_array()) {
        continue;
      }
      CHECK(op->type == OperatorType::kFill ||
            op->type == OperatorType::kIdentity);
      return false;
    }
    // Explicitly requested as a model output.
    for (const string& output_array : model->flags.output_arrays()) {
      if (output == output_array) {
        return true;
      }
    }
    // Feeds an RNN back-edge. The back-edge can only be dropped when both of
    // its endpoint arrays are discardable and the state array has no readers.
    for (const auto& rnn_state : model->flags.rnn_states()) {
      if (output != rnn_state.back_edge_source_array()) {
        continue;
      }
      const bool back_edge_droppable =
          IsDiscardableArray(*model, rnn_state.back_edge_source_array()) &&
          IsDiscardableArray(*model, rnn_state.state_array()) &&
          CountOpsWithInput(*model, rnn_state.state_array()) == 0;
      if (!back_edge_droppable) {
        return true;
      }
    }
    // Otherwise it is needed exactly when some operator reads it.
    return CountOpsWithInput(*model, output) != 0;
  };

  for (const auto& output : op->outputs) {
    if (output_is_needed(output)) {
      return ::tensorflow::Status::OK();
    }
  }

  if (op->unresolved_outputs) {
    AddMessageF("Not discarding %s because it has unresolved outputs.",
                LogName(*op));
    return ::tensorflow::Status::OK();
  }

  AddMessageF("Discarding %s because none of its outputs is used.",
              LogName(*op));

  // From here on the operator is going away for sure.

  // Inputs that no other operator produces are dropped when this operator was
  // their only consumer (DeleteArrayIfUsedOnce does that check).
  for (const auto& input : op->inputs) {
    if (GetOpWithOutput(*model, input)) {
      continue;
    }
    DeleteArrayIfUsedOnce(input, model);
  }

  // Drop the now-unused outputs, but keep non-discardable ones such as a model
  // input array the graph is being cropped at (--input_array).
  for (const auto& output : op->outputs) {
    if (!IsDiscardableArray(*model, output)) {
      continue;
    }
    model->EraseArray(output);
  }

  model->operators.erase(op_it);
  *modified = true;
  return ::tensorflow::Status::OK();
}

}  // namespace toco