diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/tooling_util.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/tooling_util.cc | 1552 |
1 files changed, 1552 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc new file mode 100644 index 0000000000..bcbfed62d3 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -0,0 +1,1552 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/tooling_util.h" + +#include <functional> +#include <iterator> +#include <set> +#include <unordered_map> +#include <unordered_set> +#include <utility> + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_replace.h" +#include "tensorflow/contrib/lite/toco/dump_graphviz.h" +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h" +#include "tensorflow/contrib/lite/toco/toco_port.h" +#include "tensorflow/core/platform/logging.h" + + +namespace toco { + +string LogName(const Operator& op) { + const string& opname = HelpfulOperatorTypeName(op); + if (op.outputs.empty()) { + return toco::port::StringF("{%s operator}", opname); + } else { + return toco::port::StringF("{%s operator with output %s}", opname, + op.outputs[0]); + } +} + +bool IsInputArray(const Model& model, const string& name) { + for (const auto& input_array : model.flags.input_arrays()) { + if (input_array.name() == name) { + return true; + } + } + 
return false; +} + +bool IsArrayConsumed(const Model& model, const string& name) { + if (GetOpWithInput(model, name)) { + return true; + } + for (const string& model_output : model.flags.output_arrays()) { + if (model_output == name) { + return true; + } + } + for (const auto& rnn_state : model.flags.rnn_states()) { + if (rnn_state.back_edge_source_array() == name) { + return true; + } + } + return false; +} + +int CountTrueOutputs(const Model& model, const Operator& op) { + int count = 0; + for (const string& output : op.outputs) { + if (IsArrayConsumed(model, output)) { + ++count; + } + } + return count; +} + +int CountOpsWithInput(const Model& model, const string& array_name) { + int count = 0; + for (const auto& op : model.operators) { + for (auto& input : op->inputs) { + if (input == array_name) { + count++; + } + } + } + return count; +} + +bool DeleteArrayIfUnused(const string& array_name, Model* model) { + if (CountOpsWithInput(*model, array_name) == 0) { + model->arrays.erase(array_name); + return true; + } + return false; +} + +std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithOutput( + const Model& model, const string& array_name) { + for (auto it = model.operators.begin(); it != model.operators.end(); ++it) { + for (auto& output : it->get()->outputs) { + if (output == array_name) { + return it; + } + } + } + return model.operators.end(); +} + +std::vector<std::unique_ptr<Operator>>::iterator FindOpWithOutput( + Model& model, const string& array_name) { + for (auto it = model.operators.begin(); it != model.operators.end(); ++it) { + for (auto& output : it->get()->outputs) { + if (output == array_name) { + return it; + } + } + } + return model.operators.end(); +} + +Operator* GetOpWithOutput(const Model& model, const string& array_name) { + auto it = FindOpWithOutput(model, array_name); + return it == model.operators.end() ? nullptr : it->get(); +} + +// GetFirstOpWithInput assumes that this finds the first op. 
+std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithInput( + const Model& model, const string& array_name) { + for (auto it = model.operators.begin(); it != model.operators.end(); ++it) { + for (auto& input : it->get()->inputs) { + if (input == array_name) { + return it; + } + } + } + return model.operators.end(); +} + +std::vector<std::unique_ptr<Operator>>::const_iterator FindOp( + const Model& model, const Operator* op) { + for (auto it = model.operators.begin(); it != model.operators.end(); ++it) { + if (it->get() == op) { + return it; + } + } + return model.operators.end(); +} + +std::vector<std::unique_ptr<Operator>>::iterator FindOp(Model& model, + const Operator* op) { + for (auto it = model.operators.begin(); it != model.operators.end(); ++it) { + if (it->get() == op) { + return it; + } + } + return model.operators.end(); +} + +Operator* GetOpWithInput(const Model& model, const string& array_name) { + auto it = FindOpWithInput(model, array_name); + return it == model.operators.end() ? nullptr : it->get(); +} + +Operator* GetFirstOpWithInput(const Model& model, const string& array_name) { + auto it = FindOpWithInput(model, array_name); + return it == model.operators.end() ? 
nullptr : it->get(); +} + +string FormatArraysList(const Model& model, const std::vector<string>& list) { + if (list.empty()) { + return "[]"; + } + string result = ""; + if (list.size() > 1) { + result += "[ "; + } + for (std::size_t i = 0; i < list.size(); i++) { + if (i > 0) { + result += ", "; + } + result += list[i]; + } + if (list.size() > 1) { + result += " ]"; + } + return result; +} + +const char* OperatorTypeName(OperatorType type) { + switch (type) { +#define HANDLE_OPERATORTYPENAME_CASE(c) \ + case OperatorType::k##c: \ + return #c; + HANDLE_OPERATORTYPENAME_CASE(Add) + HANDLE_OPERATORTYPENAME_CASE(AveragePool) + HANDLE_OPERATORTYPENAME_CASE(BatchNormalization) + HANDLE_OPERATORTYPENAME_CASE(Conv) + HANDLE_OPERATORTYPENAME_CASE(Concatenation) + HANDLE_OPERATORTYPENAME_CASE(DepthwiseConv) + HANDLE_OPERATORTYPENAME_CASE(DepthToSpace) + HANDLE_OPERATORTYPENAME_CASE(SpaceToDepth) + HANDLE_OPERATORTYPENAME_CASE(FullyConnected) + HANDLE_OPERATORTYPENAME_CASE(Dequantize) + HANDLE_OPERATORTYPENAME_CASE(L2Normalization) + HANDLE_OPERATORTYPENAME_CASE(LocalResponseNormalization) + HANDLE_OPERATORTYPENAME_CASE(Logistic) + HANDLE_OPERATORTYPENAME_CASE(LstmCell) + HANDLE_OPERATORTYPENAME_CASE(MaxPool) + HANDLE_OPERATORTYPENAME_CASE(L2Pool) + HANDLE_OPERATORTYPENAME_CASE(FakeQuant) + HANDLE_OPERATORTYPENAME_CASE(Mul) + HANDLE_OPERATORTYPENAME_CASE(Relu) + HANDLE_OPERATORTYPENAME_CASE(Relu1) + HANDLE_OPERATORTYPENAME_CASE(Relu6) + HANDLE_OPERATORTYPENAME_CASE(ReorderAxes) + HANDLE_OPERATORTYPENAME_CASE(Softmax) + HANDLE_OPERATORTYPENAME_CASE(Div) + HANDLE_OPERATORTYPENAME_CASE(Tanh) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowAll) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowAssert) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowGreater) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowGreaterEqual) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowIdentity) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowLess) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowLessEqual) + 
HANDLE_OPERATORTYPENAME_CASE(TensorFlowMatMul) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowMax) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowMaximum) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowMerge) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowMin) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowMinimum) + HANDLE_OPERATORTYPENAME_CASE(Pad) + HANDLE_OPERATORTYPENAME_CASE(StridedSlice) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowReshape) + HANDLE_OPERATORTYPENAME_CASE(Squeeze) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowRsqrt) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowShape) + HANDLE_OPERATORTYPENAME_CASE(Slice) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowSplit) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowSqrt) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowSquare) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowSwitch) + HANDLE_OPERATORTYPENAME_CASE(Sub) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowSum) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowTile) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowConcat) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowConcatV2) + HANDLE_OPERATORTYPENAME_CASE(Cast) + HANDLE_OPERATORTYPENAME_CASE(Floor) + HANDLE_OPERATORTYPENAME_CASE(Gather) + HANDLE_OPERATORTYPENAME_CASE(ResizeBilinear) + HANDLE_OPERATORTYPENAME_CASE(SpaceToBatchND) + HANDLE_OPERATORTYPENAME_CASE(BatchToSpaceND) + HANDLE_OPERATORTYPENAME_CASE(Mean) + HANDLE_OPERATORTYPENAME_CASE(Svdf) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowUnsupported) + default: + LOG(FATAL) << "Unhandled op type"; +#undef HANDLE_OPERATORTYPENAME_CASE + } +} + +string HelpfulOperatorTypeName(const Operator& op) { + if (op.type == OperatorType::kTensorFlowUnsupported) { + return toco::port::StringF( + "(Unsupported TensorFlow op: %s)", + static_cast<const TensorFlowUnsupportedOperator&>(op).tensorflow_op); + } + return OperatorTypeName(op.type); +} + +void LogSummary(int log_level, const Model& model) { + VLOG(log_level) << "Operators summary (" << model.operators.size() + << " operators): "; + std::unordered_multiset<OperatorType> ops_by_type; + for (const 
auto& op : model.operators) { + ops_by_type.insert(op->type); + } + auto it = ops_by_type.begin(); + while (it != ops_by_type.end()) { + int count = ops_by_type.count(*it); + VLOG(log_level) << " " << OperatorTypeName(*it) << ": " << count; + std::advance(it, count); + } +} + +void LogArray(int log_level, const Model& model, const string& name) { + const auto& array = model.GetArray(name); + VLOG(log_level) << "Array: " << name; + switch (array.data_type) { + case ArrayDataType::kNone: + break; + case ArrayDataType::kFloat: + VLOG(log_level) << " Data type: kFloat"; + break; + case ArrayDataType::kInt32: + VLOG(log_level) << " Data type: kInt32"; + break; + case ArrayDataType::kUint8: + VLOG(log_level) << " Data type: kUint8"; + break; + default: + VLOG(log_level) << " Data type: other (numerical value: " + << static_cast<int>(array.data_type) << ")"; + break; + } + if (array.buffer) { + VLOG(log_level) << " Constant Buffer"; + } + if (array.alloc) { + VLOG(log_level) << " Transient Alloc"; + } + if (array.has_shape()) { + const Shape& array_shape = array.shape(); + if (array_shape.dimensions_count() == 0) { + VLOG(log_level) << " (Zero dimensions)"; + } else { + string message = " Dims: "; + bool first = true; + for (const int dim : array_shape.dims()) { + if (!first) { + message += ", "; + } + first = false; + toco::port::AppendF(&message, "%d", dim); + } + VLOG(log_level) << message; + } + } + if (array.minmax) { + VLOG(log_level) << " MinMax: " << array.minmax->min << " .. 
" + << array.minmax->max; + } + if (array.quantization_params) { + VLOG(log_level) << " QuantizationParams: zero_point=" + << array.quantization_params->zero_point + << ", scale=" << array.quantization_params->scale; + } +} + +void DumpGraphvizVideoFrame(const Model& model) { + namespace port = toco::port; + + const auto& dump_options = *GraphVizDumpOptions::singleton(); + if (!dump_options.dump_graphviz_video) { + return; + } + CHECK(!dump_options.dump_graphviz.empty()); + // TODO(benoitjacob): the static data here means that this function + // is stateful, not reentrant, and effectively leaks memory till exit + // (since dump_hashes can only grow in size). It also means that it + // really only is intended to be called for a single model during the + // process' lifetime. So it's not great design at all. The overriding + // design aspect here is to make the video-dumping code as unintrusive + // and self-contained as possible. Eventually, we'll want to have that + // cleaned-up, but that will require some form of general statefulness + // in toco (some kind of 'tooling state' data structure) that does + // not exist at present, and would be premature to design here just for + // this new video-dumping feature. 
+ static int dump_id = 0; + static std::unordered_set<std::size_t> dump_hashes; + string graphviz_dump; + DumpGraphviz(model, &graphviz_dump); + std::size_t hash = std::hash<string>{}(graphviz_dump); + if (!dump_hashes.count(hash)) { + dump_hashes.insert(hash); + CHECK(port::file::SetContents( + port::file::JoinPath( + dump_options.dump_graphviz, + toco::port::StringF("toco_video_%05d.dot", dump_id)), + graphviz_dump, port::file::Defaults()) + .ok()); + dump_id++; + } +} + +void LogDump(int log_level, const string& message, const Model& model) { + namespace port = toco::port; + const auto& dump_options = *GraphVizDumpOptions::singleton(); + + DumpGraphvizVideoFrame(model); + if (!dump_options.dump_graphviz.empty()) { + string graphviz_dump; + + DumpGraphviz(model, &graphviz_dump); + CHECK(port::file::SetContents( + port::file::JoinPath( + dump_options.dump_graphviz, + absl::StrCat("toco_", + absl::StrReplaceAll(message, {{" ", "_"}}), + ".dot")), + graphviz_dump, port::file::Defaults()) + .ok()); + } + + if (!VLOG_IS_ON(log_level)) { + return; + } + VLOG(log_level) << "BEGIN DUMP OF TOCO MODEL (" << message << ")"; + LogSummary(log_level, model); + std::unordered_set<string> already_printed_arrays; + for (const auto& op : model.operators) { + for (const auto& input : op->inputs) { + if (!already_printed_arrays.count(input)) { + already_printed_arrays.insert(input); + LogArray(log_level, model, input); + } + } + VLOG(log_level) << HelpfulOperatorTypeName(*op) << " : "; + VLOG(log_level) << " " << FormatArraysList(model, op->inputs) << " -> " + << FormatArraysList(model, op->outputs); + if (op->fused_activation_function != FusedActivationFunctionType::kNone) { + VLOG(log_level) << " (with fused activation function)"; + } + for (const auto& output : op->outputs) { + if (!already_printed_arrays.count(output)) { + already_printed_arrays.insert(output); + LogArray(log_level, model, output); + } + } + } + VLOG(log_level) << "END DUMP OF TOCO MODEL (" << message << ")"; +} 
+ +// Note remaining raw-array extension in ProcessTensorFlowReshapeOperator(). +void ExtendShape(Shape* shape, int new_shape_size) { + CHECK_GE(new_shape_size, shape->dimensions_count()); + const int size_increase = new_shape_size - shape->dimensions_count(); + auto* shape_dims = shape->mutable_dims(); + shape_dims->insert(shape_dims->begin(), size_increase, 1); +} + +// TODO(b/62904716) Remove along with remaining uses. +void UnextendShape(Shape* shape, int new_shape_size) { + CHECK_LE(new_shape_size, shape->dimensions_count()); + const int size_reduction = shape->dimensions_count() - new_shape_size; + for (int i = 0; i < size_reduction; i++) { + CHECK_EQ(shape->dims(i), 1); + } + std::vector<int>& shape_dims = *shape->mutable_dims(); + shape_dims.erase(shape_dims.begin(), shape_dims.begin() + size_reduction); +} + +void CheckShapeDimensions(const Shape& shape) { + for (int i = 0; i < shape.dimensions_count(); ++i) { + CHECK_GE(shape.dims()[i], 1) << "shape has dimension 0 at index << " << i + << ". shape = " << ShapeToString(shape); + } +} + +bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1) { + CheckShapeDimensions(shape0); + CheckShapeDimensions(shape1); + + const Shape* longer = &shape0; + const Shape* shorter = &shape1; + if (shape1.dimensions_count() > shape0.dimensions_count()) { + longer = &shape1; + shorter = &shape0; + } + + // Walk dimensions back to front until we run out of dimensions in the shorter + // shape. + int longer_index = longer->dimensions_count() - 1; + int shorter_index = shorter->dimensions_count() - 1; + while (shorter_index >= 0) { + const int d_long = longer->dims(longer_index); + const int d_short = shorter->dims(shorter_index); + // Broadcasting fails if the dimensions are different *and* neither is 1. 
+ if ((d_long != d_short) && (d_long != 1) && (d_short != 1)) { + return false; + } + longer_index--; + shorter_index--; + } + return true; +} + +bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1) { + CheckShapeDimensions(shape0); + CheckShapeDimensions(shape1); + + const Shape* longer = &shape0; + const Shape* shorter = &shape1; + if (shape1.dimensions_count() > shape0.dimensions_count()) { + longer = &shape1; + shorter = &shape0; + } + + // Walk dimensions back to front until we run out of dimensions in the shorter + // shape. + int longer_index = longer->dimensions_count() - 1; + int shorter_index = shorter->dimensions_count() - 1; + while (shorter_index >= 0) { + const int d_long = longer->dims(longer_index); + const int d_short = shorter->dims(shorter_index); + // Extending fails if the dimensions are different. + if (d_long != d_short) { + return false; + } + longer_index--; + shorter_index--; + } + + // The remaining dimensions in the longer shape must be 1. + while (longer_index >= 0) { + const int d_long = longer->dims(longer_index); + if (d_long != 1) { + return false; + } + longer_index--; + } + + return true; +} + +int RequiredBufferSizeForShape(const Shape& shape) { + int max_offset = 1; + for (const auto& dim : shape.dims()) { + CHECK_GE(dim, 1); + max_offset *= dim; + } + return max_offset; +} + +bool IsConstantParameterArray(const Model& model, const string& name) { + if (!model.arrays.count(name)) { + return false; + } + + return !!model.arrays.at(name)->buffer; +} + +void CheckNoMissingArray(const Model& model) { + for (const auto& op : model.operators) { + for (const auto& input : op->inputs) { + CHECK(model.arrays.count(input)); + } + for (const auto& output : op->outputs) { + CHECK(model.arrays.count(output)); + } + } + for (const auto& input_array : model.flags.input_arrays()) { + CHECK(model.arrays.count(input_array.name())) + << "Input array not found: " << input_array.name(); + } + for (const string& output_array : 
model.flags.output_arrays()) { + CHECK(model.arrays.count(output_array)) + << "Output array not found: " << output_array; + } + for (const auto& rnn_state : model.flags.rnn_states()) { + CHECK(model.arrays.count(rnn_state.state_array())); + CHECK(model.arrays.count(rnn_state.back_edge_source_array())); + } +} + +void FixNoMissingArray(Model* model) { + for (const auto& op : model->operators) { + for (const auto& input : op->inputs) { + if (!model->arrays.count(input)) { + model->GetOrCreateArray(input); + } + } + for (const auto& output : op->outputs) { + if (!model->arrays.count(output)) { + model->GetOrCreateArray(output); + } + } + } + for (const string& output_array : model->flags.output_arrays()) { + if (!model->arrays.count(output_array)) { + model->GetOrCreateArray(output_array); + } + } +} + +void CheckNoOrphanedArray(const Model& model) { + std::unordered_set<string> arrays_without_known_use; + for (const auto& array : model.arrays) { + arrays_without_known_use.insert(array.first); + } + for (const auto& op : model.operators) { + for (const auto& input : op->inputs) { + arrays_without_known_use.erase(input); + } + for (const auto& output : op->outputs) { + arrays_without_known_use.erase(output); + } + } + if (!arrays_without_known_use.empty()) { + for (const auto& array : arrays_without_known_use) { + LOG(INFO) << "Error: Orphaned array: " << array; + } + } + CHECK(arrays_without_known_use.empty()); +} + +void FixNoOrphanedArray(Model* model) { + std::unordered_set<string> arrays_without_known_use; + for (const auto& array : model->arrays) { + arrays_without_known_use.insert(array.first); + } + for (const auto& op : model->operators) { + for (const auto& input : op->inputs) { + arrays_without_known_use.erase(input); + } + for (const auto& output : op->outputs) { + arrays_without_known_use.erase(output); + } + } + for (const auto& array : arrays_without_known_use) { + model->arrays.erase(array); + } +} + +void CheckArrayFieldsConsistent(const Model& model) 
{ + for (const auto& array_entry : model.arrays) { + const auto& array = array_entry.second; + if (array->has_shape()) { + for (int d : array->shape().dims()) { + CHECK_GE(d, 1); + } + } + // It's OK to have a buffer or an alloc, but not both. + // (Since allocs are for transient arrays without a buffer). + CHECK(!array->buffer || !array->alloc); + // If there is a buffer, its type should be consistent with data_type. + if (array->buffer) { + CHECK(array->buffer->type == array->data_type); + } + } +} + +void CheckOperatorOrdering(const Model& model) { + std::unordered_set<string> arrays_behind_us; + for (const auto& array_entry : model.arrays) { + if (!GetOpWithOutput(model, array_entry.first)) { + arrays_behind_us.insert(array_entry.first); + } + } + for (const auto& op : model.operators) { + for (const auto& input : op->inputs) { + if (!IsConstantParameterArray(model, input)) { + CHECK(arrays_behind_us.count(input)); + } + } + for (const auto& output : op->outputs) { + CHECK(!arrays_behind_us.count(output)); + arrays_behind_us.insert(output); + } + } + for (const string& output_array : model.flags.output_arrays()) { + CHECK(arrays_behind_us.count(output_array)); + } +} + +void FixOperatorOrdering(Model* model) { + std::unordered_set<string> arrays_behind_us; + for (const auto& array_entry : model->arrays) { + if (!GetOpWithOutput(*model, array_entry.first)) { + arrays_behind_us.insert(array_entry.first); + } + } + std::vector<std::unique_ptr<Operator>> old_operators; + std::swap(old_operators, model->operators); + std::set<std::size_t> remaining; + for (std::size_t i = 0; i < old_operators.size(); i++) { + remaining.insert(i); + } + std::unordered_map<string, string> reason_why_leftover; + while (true) { + bool inserted_something = false; + for (auto i : remaining) { + bool can_insert = true; + auto& op = old_operators[i]; + CHECK(op.get()); + for (const auto& input : op->inputs) { + if (!IsConstantParameterArray(*model, input) && + 
!arrays_behind_us.count(input)) { + for (const string& output : op->outputs) { + reason_why_leftover[output] = input; + } + can_insert = false; + break; + } + } + if (can_insert) { + model->operators.emplace_back(nullptr); + for (const auto& output : op->outputs) { + arrays_behind_us.insert(output); + } + std::swap(op, model->operators.back()); + remaining.erase(i); + inserted_something = true; + break; + } + } + if (!inserted_something) { + break; + } + } + if (!remaining.empty()) { + LOG(ERROR) + << "No viable ordering of operators was found. " + << "Here is a 'backtrace' of at least one part of the graph that is " + << "problematic. It starts with the first operator that has as " + << "problematic input array, and then walks back the graph to " + << "the operator that produced that input array, etc., until we find " + << "the root cause:"; + LOG(ERROR) << "BEGIN TRACE OF OPERATOR WITH BAD INPUT"; + LOG(ERROR) << "Here is the first-encountered operator with a bad input: "; + const Operator* bad_op = old_operators[*remaining.begin()].get(); + std::unordered_set<string> bad_inputs_already_traced; + // The following while(true) loop should always end with a LOG(FATAL). + while (true) { + LOG(ERROR) << HelpfulOperatorTypeName(*bad_op) << " : " + << FormatArraysList(*model, bad_op->inputs) << " -> " + << FormatArraysList(*model, bad_op->outputs); + bool found_bad_output = false; + string bad_output; + for (const string& output : bad_op->outputs) { + if (reason_why_leftover.count(output)) { + found_bad_output = true; + bad_output = output; + break; + } + } + CHECK(found_bad_output); + const string& bad_input = reason_why_leftover[bad_output]; + LOG(ERROR) << "The bad input here is: " << bad_input; + if (bad_inputs_already_traced.count(bad_input)) { + LOG(FATAL) + << "Cycle found! We already encountered that " + << "input array, " << bad_input << ", earlier in the " + << "above trace! We expect graphs to be acyclic, even " + << "RNNs. 
Let us know if some graph actually needs to have " + << "cycles, but first, please check if it really is " + << "an *inference* graph. *Training* graphs are out-of-scope " + << "for toco."; + } + bad_inputs_already_traced.insert(bad_input); + bad_op = nullptr; + for (auto i : remaining) { + const Operator* op = old_operators[i].get(); + for (const string& output : op->outputs) { + if (bad_input == output) { + bad_op = op; + break; + } + } + if (bad_op) { + break; + } + } + if (!bad_op) { + LOG(ERROR) << "And that's the root cause: " + << "that array, " << bad_input << ", isn't produced by any " + << "operator, or provided in any other way."; + LOG(ERROR) << "END TRACE OF OPERATOR WITH BAD INPUT"; + LOG(FATAL) << "(The above was a multi-line fatal error)"; + } + LOG(ERROR) << "And that array is the output of the following operator:"; + } + } + CHECK(remaining.empty()) + << "Should never get here! In case of bad graph, " + << "the above code should have generated a FATAL error already!"; +} + +// Checks that the --input_arrays of the Model are actually used by at least +// one of the --output_arrays i.e. that the graph contains a path from each one +// of the inputs to at least one of the outputs. This catches cases where the +// user passed the wrong --input_arrays or --output_arrays, which otherwise may +// result in cryptic error messages. 
+void CheckInputUsedByOutputs(const Model& model) { + std::set<string> used_arrays; + for (const string& output : model.flags.output_arrays()) { + used_arrays.insert(output); + } + for (int i = model.operators.size() - 1; i >= 0; i--) { + bool is_op_used = false; + for (const string& op_output : model.operators[i]->outputs) { + if (used_arrays.count(op_output)) { + is_op_used = true; + break; + } + } + if (!is_op_used) { + continue; + } + for (const string& op_input : model.operators[i]->inputs) { + used_arrays.insert(op_input); + } + } + for (const auto& input_array : model.flags.input_arrays()) { + QCHECK(used_arrays.count(input_array.name())) + << "The graph does not connect the input (" << input_array.name() + << ") specified by --input_arrays to any of the specified " + << "--output_arrays (" + << absl::StrJoin(model.flags.output_arrays(), ", ") + << "). Did you pass the wrong flags for this model, " + << "or is that model's graph actually incomplete?"; + } +} + +void CheckInvariants(const Model& model) { + CheckNoMissingArray(model); + CheckNoOrphanedArray(model); + CheckArrayFieldsConsistent(model); + CheckOperatorOrdering(model); + CheckInputUsedByOutputs(model); +} + +void CheckCountInRange(const ::toco::ModelFlags::ModelCheck& model_check, + const int count, const string& count_description) { + if (model_check.count_min() >= 0) { + CHECK_GE(count, model_check.count_min()) + << "Mismatch in " << count_description << ": count was " << count + << ", but the specified " + << (model_check.count_max() > model_check.count_min() ? 
"minimum" + : "value") + << " was " << model_check.count_min() << "."; + } + if (model_check.count_max() > model_check.count_min()) { + CHECK_LE(count, model_check.count_max()) + << "Mismatch in " << count_description << ": count was " << count + << ", but the specified maximum was " << model_check.count_max() << "."; + } +} + +void CheckModelCounts(const Model& model) { + std::unordered_multiset<OperatorType> ops_by_type; + std::unordered_map<string, OperatorType> op_type_by_name; + if (model.flags.model_checks_size() == 0) { + return; + } + + for (const auto& op : model.operators) { + ops_by_type.insert(op->type); + op_type_by_name[OperatorTypeName(op->type)] = op->type; + } + for (const auto& model_check : model.flags.model_checks()) { + string count_type = model_check.count_type(); + if (count_type == "None") { + continue; + } else if (count_type == "Arrays") { + CheckCountInRange(model_check, model.arrays.size(), "count of arrays"); + } else if (count_type == "Total") { + CheckCountInRange(model_check, model.operators.size(), + "count of all operator instances"); + } else { + // The check type is not itself checked against the set of valid + // operators, mainly because the enum set cannot be iterated in C++. + const int found_count = + op_type_by_name.count(count_type) > 0 + ? 
ops_by_type.count(op_type_by_name[count_type]) + : 0; + CheckCountInRange(model_check, found_count, + "count of instances of " + count_type + " operator"); + } + } +} + +void MakeArrayDims(int num_dims, int batch, int height, int width, int depth, + std::vector<int>* out_dims) { + CHECK(out_dims->empty()); + if (num_dims == 1) { + CHECK_EQ(batch, 1); + *out_dims = {depth}; + } else if (num_dims == 2) { + *out_dims = {batch, depth}; + } else if (num_dims == 3) { + CHECK_EQ(batch, 1); + *out_dims = {height, width, depth}; + } else if (num_dims == 4) { + *out_dims = {batch, height, width, depth}; + } else { + LOG(FATAL) << "Should not get here: " << num_dims; + } +} + +void CreateOrCheckRnnStateArray(const string& name, int size, Model* model) { + int batch = 1; + int num_dims = -1; + for (const auto& input_array : model->flags.input_arrays()) { + // Pick 'num_dims' and 'batch' from the first input_arrays, unless we find + // a better match by name. + if (input_array.name() == name || num_dims == -1) { + num_dims = input_array.shape_size(); + if (num_dims != 0) { + batch = input_array.shape(0); + } + } + } + Array& array = model->GetOrCreateArray(name); + if (array.has_shape()) { + num_dims = array.shape().dimensions_count(); + } + std::vector<int> dims; + MakeArrayDims(num_dims, batch, 1, 1, size, &dims); + CHECK(array.data_type == ArrayDataType::kFloat || + array.data_type == ArrayDataType::kNone); + array.data_type = ArrayDataType::kFloat; + if (!array.has_shape()) { + Shape* shape = array.mutable_shape(); + *shape->mutable_dims() = dims; + } +} + +void ResolveModelFlags(const ModelFlags& model_flags, Model* model) { + // Merge info about input_arrays from model_flags into model->flags + for (const auto& specified_input_array : model_flags.input_arrays()) { + toco::InputArray* dst_input_array = nullptr; + for (int i = 0; i < model->flags.input_arrays_size(); i++) { + toco::InputArray* candidate_dst_input_array = + model->flags.mutable_input_arrays(i); + if 
(candidate_dst_input_array->name() == specified_input_array.name()) { + // specified_input_array from model_flags maps to dst_input_array + // in model->flags + dst_input_array = candidate_dst_input_array; + break; + } + } + if (!dst_input_array) { + // specified_input_array from model_flags is not found in model->flags. + // Match a name-less specified input array when there can be no ambiguity + // as there is only 1 input array. + if (model->flags.input_arrays_size() == 1 && + model_flags.input_arrays_size() == 1 && + !specified_input_array.has_name()) { + dst_input_array = model->flags.mutable_input_arrays(0); + } + } + if (!dst_input_array) { + // Still no match, so create a new input array to copy + // specified_input_array into. + dst_input_array = model->flags.add_input_arrays(); + dst_input_array->set_name(specified_input_array.name()); + } + +#define RESOLVE_MODEL_FLAG(field_name) \ + if (specified_input_array.has_##field_name()) { \ + if (dst_input_array->has_##field_name()) { \ + QCHECK_EQ(dst_input_array->field_name(), \ + specified_input_array.field_name()) \ + << "For input array '" << dst_input_array->name() << "', " \ + << "specified " #field_name " flag with value: " \ + << specified_input_array.field_name() \ + << " does not agree with already defined " #field_name \ + " of this model, with value: " \ + << specified_input_array.field_name(); \ + } else { \ + dst_input_array->set_##field_name(specified_input_array.field_name()); \ + } \ + } + RESOLVE_MODEL_FLAG(std_value); + RESOLVE_MODEL_FLAG(mean_value); +#undef RESOLVE_MODEL_FLAG + + if (!specified_input_array.shape().empty()) { + if (!dst_input_array->shape().empty()) { + QCHECK_EQ(specified_input_array.shape().size(), + dst_input_array->shape().size()) + << "For input array '" << specified_input_array.name() << "', " + << "size of specified input shape flag with size: " + << specified_input_array.shape().size() + << " does not agree with already defined input shape" + " of this model, with 
size: " + << dst_input_array->shape().size(); + // We treat the first dimension as a special case, since it is often + // a batch size and the input_shape flag is effectively overriding + // the model. + for (int i = 1; i < specified_input_array.shape().size(); i++) { + QCHECK_EQ(specified_input_array.shape().Get(i), + dst_input_array->shape().Get(i)) + << "At dimension number " << i << " of input array " + << specified_input_array.name() << ", the specified shape's " + << "dimension flag with dimension: " + << specified_input_array.shape().Get(i) + << " does not agree with already defined shape" + << " of this model, with dimension: " + << dst_input_array->shape().Get(i); + } + } else { + dst_input_array->mutable_shape()->CopyFrom( + specified_input_array.shape()); + } + } + } + + if (model_flags.output_arrays_size() > 0) { + model->flags.mutable_output_arrays()->CopyFrom(model_flags.output_arrays()); + } + +#define RESOLVE_MODEL_FLAG(name) \ + if (model_flags.has_##name()) { \ + if (model->flags.has_##name()) { \ + QCHECK_EQ(model_flags.name(), model->flags.name()) \ + << "Specified " #name " flag with value: " << model_flags.name() \ + << " does not agree with already defined " #name \ + " of this model, with value: " \ + << model->flags.name(); \ + } else { \ + model->flags.set_##name(model_flags.name()); \ + } \ + } + + RESOLVE_MODEL_FLAG(variable_batch) + RESOLVE_MODEL_FLAG(drop_control_dependency) + +#undef RESOLVE_MODEL_FLAG + + if (model->flags.rnn_states_size() == 0) { + model->flags.mutable_rnn_states()->CopyFrom(model_flags.rnn_states()); + } else { + CHECK_EQ(model->flags.rnn_states_size(), model_flags.rnn_states_size()); + for (int i = 0; i < model->flags.rnn_states_size(); i++) { + CHECK_EQ(model->flags.rnn_states(i).state_array(), + model_flags.rnn_states(i).state_array()); + CHECK_EQ(model->flags.rnn_states(i).back_edge_source_array(), + model_flags.rnn_states(i).back_edge_source_array()); + } + } + + if (model->flags.model_checks_size() == 0) { + 
model->flags.mutable_model_checks()->CopyFrom(model_flags.model_checks()); + } + + QCHECK_GT(model->flags.input_arrays_size(), 0) + << "This model does not define input arrays, so a " + "--input_arrays flag must be given on the command-line."; + QCHECK_GT(model->flags.output_arrays_size(), 0) + << "This model does not define output arrays, so a " + "--output_arrays flag must be given on the command-line."; + + for (const auto& input_array_proto : model->flags.input_arrays()) { + QCHECK(!input_array_proto.shape().empty()) + << "This model does not have shape defined for input array " + << input_array_proto.name() + << ", so one must be specified by a non-empty --input_shape " + "command-line flag."; + + auto& input_array = model->GetOrCreateArray(input_array_proto.name()); + if (input_array.data_type == ArrayDataType::kNone) { + // We start out with a float input array; + // that may get replaced by a uint8 array later, by + // MakeInitialDequantizeOp. + input_array.data_type = ArrayDataType::kFloat; + } + + // Compare/merge the model->flags describing the input_shape with + // the actual input array's shape. 
    auto& input_array_dims = *input_array.mutable_shape()->mutable_dims();
    if (input_array_dims.empty()) {
      // The array has no shape yet: take it from the flags, requiring
      // strictly positive dimensions.
      for (auto dim : input_array_proto.shape()) {
        CHECK_GE(dim, 1);
        input_array_dims.push_back(dim);
      }
    } else {
      // The array already has a shape: it must match the flags exactly.
      CHECK_EQ(input_array_dims.size(), input_array_proto.shape_size());
      for (int i = 0; i < input_array_dims.size(); i++) {
        CHECK_EQ(input_array_dims[i], input_array_proto.shape(i));
      }
    }

    // Derive the float min/max from the uint8 [0, 255] range through the
    // affine (mean, std) dequantization parameters.
    // NOTE(review): divides by std_value — presumably the proto default is a
    // nonzero std (1); verify against model_flags.proto.
    const float mean_value = input_array_proto.mean_value();
    const float std_value = input_array_proto.std_value();
    MinMax input_minmax;
    input_minmax.min = (0.f - mean_value) / std_value;
    input_minmax.max = (255.f - mean_value) / std_value;
    if (input_array.minmax) {
      // Only enforce agreement if mean/std were explicitly specified;
      // otherwise keep the array's existing minmax.
      if (input_array_proto.has_mean_value() ||
          input_array_proto.has_std_value()) {
        CHECK(input_minmax == *input_array.minmax)
            << input_minmax.min << ", " << input_minmax.max
            << " != " << input_array.minmax->min << ", "
            << input_array.minmax->max;
      }
    } else {
      input_array.GetOrCreateMinMax() = input_minmax;
    }
  }
  // Creation of the RNN state arrays
  for (const auto& rnn_state : model->flags.rnn_states()) {
    if (!rnn_state.manually_create()) {
      continue;
    }
    CreateOrCheckRnnStateArray(rnn_state.state_array(), rnn_state.size(),
                               model);
  }
}

// LOG(FATAL)s if any float operator input lacks both min/max information and
// a constant buffer, since quantization would then be impossible.
void CheckIsReadyForQuantization(const Model& model) {
  for (const auto& op : model.operators) {
    for (const auto& input : op->inputs) {
      const auto& input_array = model.GetArray(input);
      if (input_array.data_type != ArrayDataType::kFloat) {
        // The array is not floats, no quantization needed.
        continue;
      }
      if (input_array.minmax) {
        // The array has minmax, we're good.
        continue;
      }
      if (input_array.buffer) {
        // The array has a constant buffer, so we can
        // fall back to computing the minmax from actual array entries
        // (with a WARNING about possible accuracy implications).
        continue;
      }
      LOG(FATAL)
          << "Array " << input << ", which is an input to the "
          << HelpfulOperatorTypeName(*op) << " operator producing the output "
          << "array " << op->outputs[0] << ", is lacking min/max data, "
          << "which is necessary for quantization. Either target a "
          << "non-quantized output format, or change the input graph to "
          << "contain min/max information, or pass --default_ranges_min= and "
          << "--default_ranges_max= if you do not care about the accuracy of "
          << "results.";
    }
  }
}

// Assigns [default_ranges_min, default_ranges_max] to every operator input
// and output array that has neither min/max info nor a constant buffer.
void UseDefaultMinMaxRangeValues(Model* model, double default_ranges_min,
                                 double default_ranges_max) {
  for (const auto& op : model->operators) {
    for (const auto& input : op->inputs) {
      auto& input_array = model->GetArray(input);
      if (!input_array.minmax && !input_array.buffer) {
        auto& minmax = input_array.GetOrCreateMinMax();
        minmax.min = default_ranges_min;
        minmax.max = default_ranges_max;
      }
    }
    for (const auto& output : op->outputs) {
      auto& output_array = model->GetArray(output);
      if (!output_array.minmax && !output_array.buffer) {
        auto& minmax = output_array.GetOrCreateMinMax();
        minmax.min = default_ranges_min;
        minmax.max = default_ranges_max;
      }
    }
  }
}

// Returns the per-element size in bytes for the given data type.
// Dies on data types not handled here.
int ElementSize(ArrayDataType data_type) {
  switch (data_type) {
    case ArrayDataType::kFloat:
      return 4;
    case ArrayDataType::kInt32:
      return 4;
    case ArrayDataType::kUint8:
      return 1;
    default:
      LOG(FATAL) << "Should not get here.";
      return 0;
  }
}

// Discards any MinMax info attached to 'array_name', with a WARNING since
// quantized inference accuracy may suffer as a result.
void DropMinMax(Model* model, const string& array_name) {
  auto& array = model->GetArray(array_name);
  if (!!array.minmax) {
    LOG(WARNING) << "Dropping MinMax information in array " << array_name
                 << ". Expect inaccuracy in quantized inference.";
    array.minmax = nullptr;
  }
}

// Returns whether 'array_name' is a transient array whose buffer should be
// allocated by the transient-allocation pass: i.e. not a model input/output,
// not constant, and with a known shape.
bool IsAllocatableTransientArray(const Model& model, const string& array_name) {
  // The model's input and output arrays are externally allocated.
  // They are not transient arrays.
  if (IsInputArray(model, array_name)) {
    return false;
  }
  for (const string& output_array : model.flags.output_arrays()) {
    if (array_name == output_array) {
      return false;
    }
  }
  const auto& array = model.arrays.at(array_name);
  // An array with a constant buffer isn't a transient array.
  if (!!array->buffer) {
    return false;
  }
  // An array without shape isn't allocatable.
  if (!array->has_shape()) {
    return false;
  }
  return true;
}

// Returns 'name' if no array by that name exists yet, otherwise the first
// available "name_<i>" variant. Dies if 1000 suffixes are all taken.
string AvailableArrayName(const Model& model, const string& name) {
  if (!model.arrays.count(name)) {
    return name;
  }
  const int kNumSuffixesToTry = 1000;
  for (int i = 0; i < kNumSuffixesToTry; i++) {
    const string& name_with_suffix = toco::port::StringF("%s_%d", name, i);
    if (!model.arrays.count(name_with_suffix)) {
      return name_with_suffix;
    }
  }
  LOG(FATAL) << "Could not find an available array name starting with " << name
             << ". Tried " << kNumSuffixesToTry << " suffixes, all were taken!";
  return "";
}

// Pretty-prints a shape, e.g. "[ 1, 224, 224, 3 ]" or "[]" for rank 0.
string ShapeToString(const Shape& shape) {
  if (shape.dimensions_count() == 0) {
    return "[]";
  }

  return absl::StrCat("[ ", absl::StrJoin(shape.dims(), ", "), " ]");
}

// Logs the shape of the named array, or a note if it has none.
void PrintArrayShape(Model* model, const string& name) {
  if (!model->arrays[name]->has_shape()) {
    LOG(INFO) << name << " has no shape";
    return;
  }
  LOG(INFO) << name
            << " has shape: " << ShapeToString(model->arrays[name]->shape());
}

// Returns whether 'name' is used exclusively as the weights input (input #1)
// of FullyConnected operators. Dies if it is used both as FC weights and as
// anything else, since that mixed usage is unsupported.
bool IsArrayFullyConnectedWeights(const Model& model, const string& name) {
  bool is_fc_weights = false;
  bool is_something_else = false;
  for (const auto& op : model.operators) {
    for (int input_index = 0; input_index < op->inputs.size(); input_index++) {
      if (op->inputs[input_index] == name) {
        if (op->type == OperatorType::kFullyConnected && input_index == 1) {
          is_fc_weights = true;
        } else {
          is_something_else = true;
        }
      }
    }
  }
  CHECK(!(is_fc_weights && is_something_else));
  return is_fc_weights;
}

// Estimates the number of arithmetic ops in one inference pass; returns false
// if some shape needed for the estimate is not yet known.
bool EstimateArithmeticOpsCount(const Model&
model, int64* result) { + int64 total = 0; + for (const auto& op : model.operators) { + switch (op->type) { + case OperatorType::kFullyConnected: + case OperatorType::kConv: + case OperatorType::kDepthwiseConv: { + const auto& output_array = model.GetArray(op->outputs[0]); + const auto& weights_array = model.GetArray(op->inputs[1]); + if (!output_array.has_shape() || !weights_array.has_shape()) { + return false; + } + int cols = 1; + for (int i = 0; i < output_array.shape().dimensions_count() - 1; i++) { + cols *= output_array.shape().dims(i); + } + const int64 cost_per_col = + 2 * RequiredBufferSizeForShape(weights_array.shape()); + total += cost_per_col * cols; + if (op->inputs.size() > 2) { + // There is a bias vector. One more op per output value. + total += RequiredBufferSizeForShape(output_array.shape()); + } + break; + } + case OperatorType::kAdd: + case OperatorType::kSub: + case OperatorType::kMul: { + const auto& output_array = model.GetArray(op->outputs[0]); + if (!output_array.has_shape()) { + return false; + } + total += RequiredBufferSizeForShape(output_array.shape()); + break; + } + case OperatorType::kLogistic: + case OperatorType::kSoftmax: + case OperatorType::kTanh: { + const auto& output_array = model.GetArray(op->outputs[0]); + if (!output_array.has_shape()) { + return false; + } + // As a very rough ballpark, the cost of evaluating a math function + // such as tanh or logistic is about 32 multiplications, and about as + // many additions/subtractions. (Just a power-of-two order-of-magnitude + // from looking at actual implementations that we use in runtime/ code). 
+ total += 64 * RequiredBufferSizeForShape(output_array.shape()); + break; + } + case OperatorType::kMaxPool: { + const auto& maxpool = *static_cast<const MaxPoolOperator*>(op.get()); + const auto& output_array = model.GetArray(op->outputs[0]); + if (!output_array.has_shape()) { + return false; + } + total += RequiredBufferSizeForShape(output_array.shape()) * + maxpool.kheight * maxpool.kwidth; + break; + } + case OperatorType::kAveragePool: { + const auto& avgpool = + *static_cast<const AveragePoolOperator*>(op.get()); + const auto& output_array = model.GetArray(op->outputs[0]); + if (!output_array.has_shape()) { + return false; + } + total += RequiredBufferSizeForShape(output_array.shape()) * + avgpool.kheight * avgpool.kwidth; + break; + } + case OperatorType::kL2Pool: { + const auto* maxpool = static_cast<const MaxPoolOperator*>(op.get()); + const auto& output_array = model.GetArray(op->outputs[0]); + if (!output_array.has_shape()) { + return false; + } + // The sum of squares requires (kheight*kwidth) multiply-adds, + // and then there is the sqrt which we ballpark at 32 ops. + const int64 cost_per_val = 2 * maxpool->kheight * maxpool->kwidth + 32; + total += + RequiredBufferSizeForShape(output_array.shape()) * cost_per_val; + break; + } + case OperatorType::kL2Normalization: { + const auto& output_array = model.GetArray(op->outputs[0]); + if (!output_array.has_shape()) { + return false; + } + // Computing the squared L2 norm is N multiply-adds so 2N ops, + // then the single inverse-sqrt is negligible, then we multiply each + // value by the resulting multiplier, so an extra N ops. Total 3N ops. 
+ total += 3 * RequiredBufferSizeForShape(output_array.shape()); + break; + } + default: + break; + } + } + *result = total; + return true; +} + +namespace { + +void GetShuffleShape(AxesOrder input_axes_order, AxesOrder output_axes_order, + std::vector<int>* shuffle) { + CHECK_EQ(AxesCount(input_axes_order), AxesCount(output_axes_order)); + shuffle->resize(4); + for (int i = 0; i < 4; i++) { + (*shuffle)[i] = i; + } + if (input_axes_order == output_axes_order) { + // nothing to do + } else if (AxesCount(input_axes_order) == 2) { + shuffle->resize(2); + (*shuffle)[0] = 1; + (*shuffle)[1] = 0; + } else if (input_axes_order == AxesOrder::kOHWI && + output_axes_order == AxesOrder::kHWIO) { + // 3210 <- 3210 + // HWIO <- OHWI + (*shuffle)[0] = 1; + (*shuffle)[1] = 2; + (*shuffle)[2] = 3; + (*shuffle)[3] = 0; + } else if (input_axes_order == AxesOrder::kHWIO && + output_axes_order == AxesOrder::kOHWI) { + // 3210 <- 3210 + // OHWI <- HWIO + (*shuffle)[0] = 3; + (*shuffle)[1] = 0; + (*shuffle)[2] = 1; + (*shuffle)[3] = 2; + } else { + LOG(FATAL) << "Bad shuffle"; + } +} + +// Extend shuffle is designed to match ExtendShape, which pads the shape with +// unit dimensions at the beginning. 
// Pads 'input_shuffle' to a 'newdim'-long permutation: the new leading
// dimensions map to themselves, and the original permutation is shifted to
// the trailing dimensions.
void ExtendShuffle(const std::vector<int>& input_shuffle, int newdim,
                   std::vector<int>* extended_shuffle) {
  *extended_shuffle = input_shuffle;
  CHECK(newdim >= input_shuffle.size());
  const int pad_size = newdim - input_shuffle.size();
  extended_shuffle->resize(newdim);
  // The padded leading axes are identity-mapped...
  for (int i = 0; i < pad_size; i++) {
    (*extended_shuffle)[i] = i;
  }
  // ...and the original permutation entries are offset by the pad amount.
  for (int i = pad_size; i < newdim; i++) {
    (*extended_shuffle)[i] = input_shuffle[i - pad_size] + pad_size;
  }
}

}  // end anonymous namespace

// Computes the shape resulting from reordering the axes of 'input_shape'
// from 'input_axes_order' to 'output_axes_order'.
void ShuffleDims(const Shape& input_shape, AxesOrder input_axes_order,
                 AxesOrder output_axes_order, Shape* output_shape) {
  if (input_axes_order == AxesOrder::kHWIM &&
      output_axes_order == AxesOrder::k1HWO) {
    // This special case isn't just a permutation, the IM pair of dims get
    // merged into the 3 dim, so we have to special-case it.
    *output_shape = Shape({1, input_shape.dims(0), input_shape.dims(1),
                           input_shape.dims(3) * input_shape.dims(2)});
  } else {
    std::vector<int> shuffle;
    GetShuffleShape(input_axes_order, output_axes_order, &shuffle);
    std::vector<int>* output_dims = output_shape->mutable_dims();
    output_dims->resize(input_shape.dimensions_count());
    for (int i = 0; i < input_shape.dimensions_count(); i++) {
      (*output_dims)[i] = input_shape.dims(shuffle[i]);
    }
  }
}

// Transposes the float buffer 'input_data' (laid out per 'input_shape' in
// 'input_axes_order') into 'output_data' (laid out per 'output_shape' in
// 'output_axes_order'). Supports up to 4 dimensions.
void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order,
                  AxesOrder output_axes_order, const Shape& output_shape,
                  const float* input_data, float* output_data) {
  if (input_axes_order == AxesOrder::kHWIM &&
      output_axes_order == AxesOrder::k1HWO) {
    // This special case isn't just a permutation, the IM pair of dims get
    // merged into the O dim, so we have to special-case it. Fortunately,
    // as far as array shuffling is concerned, it's just the identity
    // transformation.
    memcpy(output_data, input_data,
           RequiredBufferSizeForShape(input_shape) * sizeof(output_data[0]));
    return;
  }
  CHECK(input_shape.dimensions_count() == output_shape.dimensions_count());
  const int dim = input_shape.dimensions_count();
  CHECK_LE(dim, 4);
  std::vector<int> shuffle;
  GetShuffleShape(input_axes_order, output_axes_order, &shuffle);
  CHECK(shuffle.size() >= dim);
  // Sanity-check that the permutation is valid and consistent with the
  // declared output shape.
  for (int i = 0; i < dim; i++) {
    CHECK(shuffle[i] >= 0 && shuffle[i] < dim);
    CHECK(input_shape.dims(shuffle[i]) == output_shape.dims(i));
  }
  // Work in a canonical 4D space: pad both shapes and the permutation with
  // leading unit dimensions so a single 4-deep loop handles all ranks.
  Shape extended_input_shape = input_shape;
  ExtendShape(&extended_input_shape, 4);
  Shape extended_output_shape = output_shape;
  ExtendShape(&extended_output_shape, 4);
  std::vector<int> extended_shuffle;
  ExtendShuffle(shuffle, 4, &extended_shuffle);

  const std::vector<int>& extended_input_dims = extended_input_shape.dims();
  const std::vector<int>& extended_output_dims = extended_output_shape.dims();

  // TODO(starka): Rework to handle different numbers of dimensions.
  // Row-major input strides, then permuted so that input_stride_k advances
  // the input pointer in step with output loop level k (level 0 innermost).
  int input_strides[4];
  input_strides[3] = 1;
  input_strides[2] = extended_input_dims[3];
  input_strides[1] = input_strides[2] * extended_input_dims[2];
  input_strides[0] = input_strides[1] * extended_input_dims[1];
  const int input_stride_0 = input_strides[extended_shuffle[3]];
  const int input_stride_1 = input_strides[extended_shuffle[2]];
  const int input_stride_2 = input_strides[extended_shuffle[1]];
  const int input_stride_3 = input_strides[extended_shuffle[0]];

  const int output_size_0 = extended_output_dims[3];
  const int output_size_1 = extended_output_dims[2];
  const int output_size_2 = extended_output_dims[1];
  const int output_size_3 = extended_output_dims[0];
  const int output_stride_0 = 1;
  const int output_stride_1 = output_size_0;
  const int output_stride_2 = output_stride_1 * output_size_1;
  const int output_stride_3 = output_stride_2 * output_size_2;

  // Iterate over the output in order, gathering from the permuted input.
  for (int i3 = 0; i3 < output_size_3; i3++) {
    const float* const input_ptr_3 = input_data + i3 * input_stride_3;
    float* const output_ptr_3 = output_data + i3 * output_stride_3;
    for (int i2 = 0; i2 < output_size_2; i2++) {
      const float* const input_ptr_2 = input_ptr_3 + i2 * input_stride_2;
      float* const output_ptr_2 = output_ptr_3 + i2 * output_stride_2;
      for (int i1 = 0; i1 < output_size_1; i1++) {
        const float* input_ptr = input_ptr_2 + i1 * input_stride_1;
        float* output_ptr = output_ptr_2 + i1 * output_stride_1;
        float* const output_ptr_end =
            output_ptr + output_size_0 * output_stride_0;
        while (output_ptr != output_ptr_end) {
          *output_ptr = *input_ptr;
          input_ptr += input_stride_0;
          output_ptr += output_stride_0;
        }
      }
    }
  }
}

// Returns the number of axes implied by an AxesOrder; dies on unknown values.
int AxesCount(AxesOrder axes_order) {
  switch (axes_order) {
    case AxesOrder::kOneAxis:
      return 1;
    case AxesOrder::kRC:
      return 2;
    case AxesOrder::kCR:
      return 2;
    case AxesOrder::kHWIO:
      return 4;
    case AxesOrder::kOHWI:
      return 4;
    case AxesOrder::kHWIM:
      return 4;
    case AxesOrder::k1HWO:
      return 4;
    case
        AxesOrder::kNHWC:
      return 4;
    default:
      LOG(FATAL) << "Bad AxesOrder";
      return 0;
  }
}

// Returns whether 'array_name' may be dropped from the model: true unless it
// is one of the model's input arrays, output arrays, or an RNN state
// endpoint (state array or back-edge source).
bool IsDiscardableArray(const Model& model, const string& array_name) {
  for (const auto& input_array : model.flags.input_arrays()) {
    if (array_name == input_array.name()) {
      return false;
    }
  }
  for (const string& output_array : model.flags.output_arrays()) {
    if (array_name == output_array) {
      return false;
    }
  }
  for (const auto& rnn_state : model.flags.rnn_states()) {
    if (array_name == rnn_state.state_array()) {
      return false;
    }
    if (array_name == rnn_state.back_edge_source_array()) {
      return false;
    }
  }
  return true;
}

// CHECKs that every array with a requested final_data_type has actually been
// converted to that type.
void CheckFinalDataTypesSatisfied(const Model& model) {
  for (const auto& array_entry : model.arrays) {
    const auto& array = *array_entry.second;
    if (array.final_data_type != ArrayDataType::kNone) {
      CHECK(array.final_data_type == array.data_type);
    }
  }
}

}  // namespace toco