/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/toco/tooling_util.h" #include #include #include #include #include #include #include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/str_replace.h" #include "absl/strings/str_split.h" #include "re2/re2.h" #include "tensorflow/contrib/lite/toco/dump_graphviz.h" #include "tensorflow/contrib/lite/toco/model_flags.pb.h" #include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" namespace toco { // Find the longest common prefix of two strings. 
// Returns the longest common prefix of `a` and `b`. The result is a view into
// `a`'s storage (no copy is made). Returns an empty view when either input is
// empty.
absl::string_view FindLongestCommonPrefix(absl::string_view a,
                                          absl::string_view b) {
  if (a.empty() || b.empty()) return absl::string_view();

  const char* pa = a.data();
  const char* pb = b.data();
  size_t count = 0;
  // Bound the scan by the shorter input so neither pointer runs off the end.
  const size_t limit = std::min(a.size(), b.size());
  while (count < limit && *pa == *pb) {
    ++pa;
    ++pb;
    ++count;
  }

  return absl::string_view(a.data(), count);
}

// Builds a short label for `op` for use in log messages: the helpful operator
// type name, plus the name of the first output array when there is one.
string LogName(const Operator& op) {
  const string& opname = HelpfulOperatorTypeName(op);
  if (op.outputs.empty()) {
    return toco::port::StringF("{%s operator}", opname);
  } else {
    return toco::port::StringF("{%s operator with output %s}", opname,
                               op.outputs[0]);
  }
}

// Returns the display name of `data_type`. Any enumerator that is not listed
// below ends in LOG(FATAL).
string ArrayDataTypeName(ArrayDataType data_type) {
  switch (data_type) {
    case ArrayDataType::kFloat:
      return "Float";
    case ArrayDataType::kInt8:
      return "Int8";
    case ArrayDataType::kUint8:
      return "Uint8";
    case ArrayDataType::kInt16:
      return "Int16";
    case ArrayDataType::kUint16:
      return "Uint16";
    case ArrayDataType::kInt32:
      return "Int32";
    case ArrayDataType::kUint32:
      return "Uint32";
    case ArrayDataType::kInt64:
      return "Int64";
    case ArrayDataType::kUint64:
      return "Uint64";
    case ArrayDataType::kString:
      return "String";
    case ArrayDataType::kBool:
      return "Bool";
    case ArrayDataType::kNone:
      return "None";
    default:
      // NOTE(review): the cast below has no target type in this view of the
      // file; the template argument looks like it was lost in extraction --
      // confirm against the real source.
      LOG(FATAL) << "Unhandled array data type "
                 << static_cast(data_type);
  }
}

// Returns true if `array_name` matches the name of one of the input arrays
// listed in model.flags.
bool IsInputArray(const Model& model, const string& array_name) {
  for (const auto& input_array : model.flags.input_arrays()) {
    if (array_name == input_array.name()) {
      return true;
    }
  }
  return false;
}

// Returns true if `array_name` is one of the output arrays listed in
// model.flags.
bool IsOutputArray(const Model& model, const string& array_name) {
  for (const auto& output_array : model.flags.output_arrays()) {
    if (array_name == output_array) {
      return true;
    }
  }
  return false;
}

// Returns true if array `name` has a known consumer: an operator that takes it
// as input, the model's output list, or an RNN state that names it as its
// back-edge source array.
bool IsArrayConsumed(const Model& model, const string& name) {
  if (GetOpWithInput(model, name)) {
    return true;
  }
  if (IsOutputArray(model, name)) {
    return true;
  }
  for (const auto& rnn_state : model.flags.rnn_states()) {
    if (rnn_state.back_edge_source_array() == name) {
      return true;
    }
  }
  return false;
}

// Returns how many of `op`'s outputs are consumed in the sense of
// IsArrayConsumed.
int CountTrueOutputs(const Model& model, const Operator& op) {
  int count = 0;
  for (const string& output : op.outputs) {
    if (IsArrayConsumed(model, output)) {
      ++count;
    }
  }
  return count;
}

// Returns the number of operators that take `array_name` as an input. Each
// operator is counted at most once, however many of its inputs match.
int CountOpsWithInput(const Model& model, const string& array_name) {
  int count = 0;
  for (const auto& op : model.operators) {
    for (auto& input : op->inputs) {
      if (input == array_name) {
        count++;
        // Breaking here is important: some graphs have ops that use the
        // same array as more than one of their inputs, and in that case
        // we want it counted only once.
        break;
      }
    }
  }
  return count;
}

// Erases `array_name` from the model if it is discardable and no operator
// takes it as input. Returns true if the array was erased.
bool DeleteArrayIfUnused(const string& array_name, Model* model) {
  if (IsDiscardableArray(*model, array_name) &&
      CountOpsWithInput(*model, array_name) == 0) {
    model->EraseArray(array_name);
    return true;
  }
  return false;
}

// Erases `array_name` from the model if it is discardable and exactly one
// operator takes it as input. Returns true if the array was erased.
bool DeleteArrayIfUsedOnce(const string& array_name, Model* model) {
  if (IsDiscardableArray(*model, array_name) &&
      CountOpsWithInput(*model, array_name) == 1) {
    model->EraseArray(array_name);
    return true;
  }
  return false;
}

// Removes `op` from model->operators, first erasing each of its input arrays
// that passes DeleteArrayIfUsedOnce. CHECK-fails if `op` is not in the model.
void DeleteOpAndArraysIfUnused(Model* model, Operator* op) {
  // `op` is still in the model at this point, so an input with a use count of
  // one is consumed by `op` alone.
  for (const string& array_name : op->inputs) {
    DeleteArrayIfUsedOnce(array_name, model);
  }
  auto op_it = FindOp(*model, op);
  CHECK(op_it != model->operators.end());
  model->operators.erase(op_it);
}

// Returns an iterator to the first operator in model.operators that lists
// `array_name` among its outputs, or model.operators.end() if there is none.
// NOTE(review): the return type's template arguments appear to be missing
// (lost in extraction?) -- confirm against the real source.
std::vector>::const_iterator FindOpWithOutput(
    const Model& model, const string& array_name) {
  for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
    for (auto& output : it->get()->outputs) {
      if (output == array_name) {
        return it;
      }
    }
  }
  return model.operators.end();
}

// Non-const overload of FindOpWithOutput; same search, mutable iterator.
// NOTE(review): the return type's template arguments appear to be missing
// (lost in extraction?) -- confirm against the real source.
std::vector>::iterator FindOpWithOutput(
    Model& model, const string& array_name) {
  for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
    for (auto& output : it->get()->outputs) {
      if (output == array_name) {
        return it;
      }
    }
  }
  return model.operators.end();
}

// Pointer-returning convenience wrapper around FindOpWithOutput: the first
// operator producing `array_name`, or nullptr if no operator does.
Operator* GetOpWithOutput(const Model& model, const string& array_name) {
  auto it = FindOpWithOutput(model, array_name);
  return it == model.operators.end() ? nullptr : it->get();
}

// Returns an iterator to the first operator in model.operators that lists
// `array_name` among its inputs, or model.operators.end() if there is none.
// GetFirstOpWithInput assumes that this finds the first op.
// NOTE(review): the return type's template arguments appear to be missing
// (lost in extraction?) -- confirm against the real source.
std::vector>::const_iterator FindOpWithInput(
    const Model& model, const string& array_name) {
  for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
    for (auto& input : it->get()->inputs) {
      if (input == array_name) {
        return it;
      }
    }
  }
  return model.operators.end();
}

// Non-const overload of FindOpWithInput; same search, mutable iterator.
// NOTE(review): the return type's template arguments appear to be missing
// (lost in extraction?) -- confirm against the real source.
std::vector>::iterator FindOpWithInput(
    Model& model, const string& array_name) {
  for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
    for (auto& input : it->get()->inputs) {
      if (input == array_name) {
        return it;
      }
    }
  }
  return model.operators.end();
}

// Returns an iterator to the entry of model.operators whose pointer equals
// `op` (identity comparison), or model.operators.end() if `op` is not there.
// NOTE(review): the return type's template arguments appear to be missing
// (lost in extraction?) -- confirm against the real source.
std::vector>::const_iterator FindOp(
    const Model& model, const Operator* op) {
  for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
    if (it->get() == op) {
      return it;
    }
  }
  return model.operators.end();
}

// Non-const overload of FindOp; same search, mutable iterator.
// NOTE(review): the return type's template arguments appear to be missing
// (lost in extraction?) -- confirm against the real source.
std::vector>::iterator FindOp(Model& model,
                              const Operator* op) {
  for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
    if (it->get() == op) {
      return it;
    }
  }
  return model.operators.end();
}

// Pointer-returning convenience wrapper around FindOpWithInput: the first
// operator consuming `array_name`, or nullptr if no operator does.
Operator* GetOpWithInput(const Model& model, const string& array_name) {
  auto it = FindOpWithInput(model, array_name);
  return it == model.operators.end() ? nullptr : it->get();
}

// Same lookup as GetOpWithInput: returns the first operator (in
// model.operators order) consuming `array_name`, or nullptr if none does.
Operator* GetFirstOpWithInput(const Model& model, const string& array_name) {
  auto it = FindOpWithInput(model, array_name);
  return it == model.operators.end() ? nullptr : it->get();
}

// Rewrites every operator input and output equal to `old_array_name` so that
// it refers to `new_array_name` instead. Only operator references are edited;
// nothing else in the model is touched by this function.
void ReplaceArrayUsage(Model* model, const string& old_array_name,
                       const string& new_array_name) {
  for (auto& op_it : model->operators) {
    Operator* op = op_it.get();
    for (size_t i = 0; i < op->inputs.size(); ++i) {
      if (op->inputs[i] == old_array_name) {
        op->inputs[i] = new_array_name;
      }
    }
    for (size_t i = 0; i < op->outputs.size(); ++i) {
      if (op->outputs[i] == old_array_name) {
        op->outputs[i] = new_array_name;
      }
    }
  }
}

// Formats `list` for logging: "[]" when empty, the bare name for a single
// element, and "[ a, b, ... ]" for two or more. `model` is not used.
// NOTE(review): the element type of `list` appears to be missing (lost in
// extraction?) -- confirm against the real source.
string FormatArraysList(const Model& model, const std::vector& list) {
  if (list.empty()) {
    return "[]";
  }
  string result = "";
  if (list.size() > 1) {
    result += "[ ";
  }
  for (std::size_t i = 0; i < list.size(); i++) {
    if (i > 0) {
      result += ", ";
    }
    result += list[i];
  }
  if (list.size() > 1) {
    result += " ]";
  }
  return result;
}

// Returns the name of `type`: the enumerator's identifier without its leading
// 'k' (the macro stringizes the suffix). Enumerators that are not listed below
// end in LOG(FATAL).
const char* OperatorTypeName(OperatorType type) {
  switch (type) {
// Expands to `case OperatorType::k<c>: return "<c>";`.
#define HANDLE_OPERATORTYPENAME_CASE(c) \
  case OperatorType::k##c:              \
    return #c;
    HANDLE_OPERATORTYPENAME_CASE(Add)
    HANDLE_OPERATORTYPENAME_CASE(AddN)
    HANDLE_OPERATORTYPENAME_CASE(AveragePool)
    HANDLE_OPERATORTYPENAME_CASE(BatchMatMul)
    HANDLE_OPERATORTYPENAME_CASE(BatchNormalization)
    HANDLE_OPERATORTYPENAME_CASE(Conv)
    HANDLE_OPERATORTYPENAME_CASE(Concatenation)
    HANDLE_OPERATORTYPENAME_CASE(DepthwiseConv)
    HANDLE_OPERATORTYPENAME_CASE(DepthToSpace)
    HANDLE_OPERATORTYPENAME_CASE(SpaceToDepth)
    HANDLE_OPERATORTYPENAME_CASE(FullyConnected)
    HANDLE_OPERATORTYPENAME_CASE(Dequantize)
    HANDLE_OPERATORTYPENAME_CASE(L2Normalization)
    HANDLE_OPERATORTYPENAME_CASE(LocalResponseNormalization)
    HANDLE_OPERATORTYPENAME_CASE(Log)
    HANDLE_OPERATORTYPENAME_CASE(Logistic)
    HANDLE_OPERATORTYPENAME_CASE(LstmCell)
    HANDLE_OPERATORTYPENAME_CASE(MaxPool)
    HANDLE_OPERATORTYPENAME_CASE(L2Pool)
    HANDLE_OPERATORTYPENAME_CASE(FakeQuant)
    HANDLE_OPERATORTYPENAME_CASE(Mul)
    HANDLE_OPERATORTYPENAME_CASE(RandomUniform)
    HANDLE_OPERATORTYPENAME_CASE(Relu)
    HANDLE_OPERATORTYPENAME_CASE(Relu1)
    HANDLE_OPERATORTYPENAME_CASE(Relu6)
    HANDLE_OPERATORTYPENAME_CASE(PRelu)
    HANDLE_OPERATORTYPENAME_CASE(ReorderAxes)
    HANDLE_OPERATORTYPENAME_CASE(Softmax)
    HANDLE_OPERATORTYPENAME_CASE(LogSoftmax)
    HANDLE_OPERATORTYPENAME_CASE(Div)
    HANDLE_OPERATORTYPENAME_CASE(Tanh)
    HANDLE_OPERATORTYPENAME_CASE(Sin)
    HANDLE_OPERATORTYPENAME_CASE(All)
    HANDLE_OPERATORTYPENAME_CASE(Assert)
    HANDLE_OPERATORTYPENAME_CASE(ExpandDims)
    HANDLE_OPERATORTYPENAME_CASE(Fill)
    HANDLE_OPERATORTYPENAME_CASE(FloorMod)
    HANDLE_OPERATORTYPENAME_CASE(FloorDiv)
    HANDLE_OPERATORTYPENAME_CASE(Greater)
    HANDLE_OPERATORTYPENAME_CASE(GreaterEqual)
    HANDLE_OPERATORTYPENAME_CASE(Identity)
    HANDLE_OPERATORTYPENAME_CASE(Less)
    HANDLE_OPERATORTYPENAME_CASE(LessEqual)
    HANDLE_OPERATORTYPENAME_CASE(MatMul)
    HANDLE_OPERATORTYPENAME_CASE(ReduceMax)  //  Reduction Max
    HANDLE_OPERATORTYPENAME_CASE(Maximum)    //  Element-wise Maximum
    HANDLE_OPERATORTYPENAME_CASE(Merge)
    HANDLE_OPERATORTYPENAME_CASE(ReduceMin)  //  Reduction Min
    HANDLE_OPERATORTYPENAME_CASE(Minimum)    //  Element-wise Minimum
    HANDLE_OPERATORTYPENAME_CASE(Neg)
    HANDLE_OPERATORTYPENAME_CASE(OneHot)
    HANDLE_OPERATORTYPENAME_CASE(Pack)
    HANDLE_OPERATORTYPENAME_CASE(Pad)
    HANDLE_OPERATORTYPENAME_CASE(PadV2)
    HANDLE_OPERATORTYPENAME_CASE(StridedSlice)
    HANDLE_OPERATORTYPENAME_CASE(Range)
    HANDLE_OPERATORTYPENAME_CASE(Rank)
    HANDLE_OPERATORTYPENAME_CASE(Reshape)
    HANDLE_OPERATORTYPENAME_CASE(Squeeze)
    HANDLE_OPERATORTYPENAME_CASE(Rsqrt)
    HANDLE_OPERATORTYPENAME_CASE(Shape)
    HANDLE_OPERATORTYPENAME_CASE(Slice)
    HANDLE_OPERATORTYPENAME_CASE(Split)
    HANDLE_OPERATORTYPENAME_CASE(Sqrt)
    HANDLE_OPERATORTYPENAME_CASE(Square)
    HANDLE_OPERATORTYPENAME_CASE(Switch)
    HANDLE_OPERATORTYPENAME_CASE(Sub)
    HANDLE_OPERATORTYPENAME_CASE(Sum)
    HANDLE_OPERATORTYPENAME_CASE(Tile)
    HANDLE_OPERATORTYPENAME_CASE(Transpose)
    HANDLE_OPERATORTYPENAME_CASE(TransposeConv)
    HANDLE_OPERATORTYPENAME_CASE(Concat)
    HANDLE_OPERATORTYPENAME_CASE(ConcatV2)
    HANDLE_OPERATORTYPENAME_CASE(Cast)
    HANDLE_OPERATORTYPENAME_CASE(Floor)
    HANDLE_OPERATORTYPENAME_CASE(Gather)
    HANDLE_OPERATORTYPENAME_CASE(ResizeBilinear)
    HANDLE_OPERATORTYPENAME_CASE(SpaceToBatchND)
    HANDLE_OPERATORTYPENAME_CASE(BatchToSpaceND)
    HANDLE_OPERATORTYPENAME_CASE(Mean)
    HANDLE_OPERATORTYPENAME_CASE(ReduceProd)
    HANDLE_OPERATORTYPENAME_CASE(Svdf)
    HANDLE_OPERATORTYPENAME_CASE(ArgMax)
    HANDLE_OPERATORTYPENAME_CASE(ArgMin)
    HANDLE_OPERATORTYPENAME_CASE(TopK_V2)
    HANDLE_OPERATORTYPENAME_CASE(Unsupported)
    HANDLE_OPERATORTYPENAME_CASE(Exp)
    HANDLE_OPERATORTYPENAME_CASE(DynamicPartition)
    HANDLE_OPERATORTYPENAME_CASE(DynamicStitch)
    HANDLE_OPERATORTYPENAME_CASE(Select)
    HANDLE_OPERATORTYPENAME_CASE(SparseToDense)
    HANDLE_OPERATORTYPENAME_CASE(Equal)
    HANDLE_OPERATORTYPENAME_CASE(NotEqual)
    HANDLE_OPERATORTYPENAME_CASE(Pow)
    HANDLE_OPERATORTYPENAME_CASE(Any)
    HANDLE_OPERATORTYPENAME_CASE(LogicalAnd)
    HANDLE_OPERATORTYPENAME_CASE(LogicalNot)
    HANDLE_OPERATORTYPENAME_CASE(LogicalOr)
    HANDLE_OPERATORTYPENAME_CASE(CTCBeamSearchDecoder)
    HANDLE_OPERATORTYPENAME_CASE(Unpack)
    HANDLE_OPERATORTYPENAME_CASE(ZerosLike)
    default:
      LOG(FATAL) << "Unhandled op type";
#undef HANDLE_OPERATORTYPENAME_CASE
  }
}

// Like OperatorTypeName, but for kUnsupported operators returns a string that
// embeds the operator's `tensorflow_op` field instead of the generic
// "Unsupported".
string HelpfulOperatorTypeName(const Operator& op) {
  if (op.type == OperatorType::kUnsupported) {
    // NOTE(review): the cast below has no target type in this view of the
    // file (lost in extraction?). Presumably it downcasts `op` to the operator
    // subclass that declares `tensorflow_op` -- confirm against the real
    // source.
    return toco::port::StringF(
        "(Unsupported TensorFlow op: %s)",
        static_cast(op).tensorflow_op);
  }
  return OperatorTypeName(op.type);
}

// Returns true for the operator types listed below, false for every other
// type.
bool OperatorSupportsFusedActivation(OperatorType type) {
  switch (type) {
    case OperatorType::kAdd:
    case OperatorType::kAveragePool:
    case OperatorType::kBatchNormalization:
    case OperatorType::kConv:
    case OperatorType::kDepthwiseConv:
    case OperatorType::kDiv:
    case OperatorType::kFullyConnected:
    case OperatorType::kL2Pool:
    case OperatorType::kMaxPool:
    case OperatorType::kMul:
    case OperatorType::kSub:
      return true;
    default:
      return false;
  }
}

// Logs, at verbosity `log_level`, the total number of operators in `model`
// followed by one line per distinct operator type with its instance count.
void LogSummary(int log_level, const Model& model) {
  VLOG(log_level) << "Operators summary (" << model.operators.size()
                  << " operators):";
  // NOTE(review): the multiset's element type appears to be missing (lost in
  // extraction?); the values inserted below are `op->type`.
  std::unordered_multiset ops_by_type;
  for (const auto& op : model.operators) {
    ops_by_type.insert(op->type);
  }
  auto it = ops_by_type.begin();
  while (it != ops_by_type.end()) {
    int count = ops_by_type.count(*it);
    VLOG(log_level) << " " << OperatorTypeName(*it) << ": " << count;
    // Equivalent elements of an unordered_multiset are adjacent in iteration
    // order, so skipping `count` entries lands on the next distinct type.
    std::advance(it, count);
  }
}

// Logs, at verbosity `log_level`, what is known about array `name`: data
// types, whether it has a buffer and/or an alloc, its dimensions, and its
// minmax / quantization parameters when present. Logs " DOES NOT EXIST" and
// returns early if the model has no such array.
void LogArray(int log_level, const Model& model, const string& name) {
  VLOG(log_level) << "Array: " << name;
  if (!model.HasArray(name)) {
    VLOG(log_level) << " DOES NOT EXIST";
    return;
  }
  const auto& array = model.GetArray(name);
  VLOG(log_level) << " Data type: " << ArrayDataTypeName(array.data_type);
  VLOG(log_level) << " Final type: "
                  << ArrayDataTypeName(array.final_data_type);
  if (array.buffer) {
    VLOG(log_level) << " Constant Buffer";
  }
  if (array.alloc) {
    VLOG(log_level) << " Transient Alloc";
  }
  if (array.has_shape()) {
    const Shape& array_shape = array.shape();
    if (array_shape.dimensions_count() == 0) {
      VLOG(log_level) << " (Zero dimensions)";
    } else {
      // Build a comma-separated list of the dimensions.
      string message = " Dims: ";
      bool first = true;
      for (const int dim : array_shape.dims()) {
        if (!first) {
          message += ", ";
        }
        first = false;
        toco::port::AppendF(&message, "%d", dim);
      }
      VLOG(log_level) << message;
    }
  }
  if (array.minmax) {
    VLOG(log_level) << " MinMax: " << array.minmax->min << " .. "
                    << array.minmax->max;
  }
  if (array.quantization_params) {
    // NOTE(review): the cast of zero_point has no target type in this view of
    // the file (lost in extraction?) -- confirm against the real source.
    VLOG(log_level) << " QuantizationParams: zero_point="
                    << static_cast(array.quantization_params->zero_point)
                    << ", scale=" << array.quantization_params->scale;
  }
}

// When the dump_graphviz_video option is set, writes the model's graphviz
// dump to "toco_video_<id>.dot" under the dump_graphviz directory, but only if
// the hash of that dump has not been recorded by an earlier call. `<id>` is a
// counter that advances once per file written.
void DumpGraphvizVideoFrame(const Model& model) {
  namespace port = toco::port;

  const auto& dump_options = *GraphVizDumpOptions::singleton();
  if (!dump_options.dump_graphviz_video) {
    return;
  }
  CHECK(!dump_options.dump_graphviz.empty());
  // TODO(benoitjacob): the static data here means that this function
  // is stateful, not reentrant, and effectively leaks memory till exit
  // (since dump_hashes can only grow in size). It also means that it
  // really only is intended to be called for a single model during the
  // process' lifetime. So it's not great design at all. The overriding
  // design aspect here is to make the video-dumping code as unintrusive
  // and self-contained as possible. Eventually, we'll want to have that
  // cleaned-up, but that will require some form of general statefulness
  // in toco (some kind of 'tooling state' data structure) that does
  // not exist at present, and would be premature to design here just for
  // this new video-dumping feature.
  static int dump_id = 0;
  // NOTE(review): the template arguments of `dump_hashes` and of `std::hash`
  // below appear to be missing (lost in extraction?) -- confirm against the
  // real source.
  static std::unordered_set dump_hashes;
  string graphviz_dump;
  DumpGraphviz(model, &graphviz_dump);
  std::size_t hash = std::hash{}(graphviz_dump);
  // Frames are deduplicated by hash only, not by comparing full contents.
  if (!dump_hashes.count(hash)) {
    LOG(INFO) << "DUMPING GRAPHVIZ VIDEO FRAME: " << dump_id;
    dump_hashes.insert(hash);
    CHECK(port::file::SetContents(
              port::file::JoinPath(
                  dump_options.dump_graphviz,
                  toco::port::StringF("toco_video_%05d.dot", dump_id)),
              graphviz_dump, port::file::Defaults())
              .ok());
    dump_id++;
  }
}

// Dumps `model` for debugging, labelled with `message`:
//   1. calls DumpGraphvizVideoFrame;
//   2. if the dump_graphviz option is non-empty, writes the graphviz dump to
//      "toco_<message>.dot" in that directory, with spaces in `message`
//      replaced by underscores;
//   3. if VLOG is on at `log_level`, logs the operator summary and then every
//      operator together with its input and output arrays, logging each array
//      once.
void LogDump(int log_level, const string& message, const Model& model) {
  namespace port = toco::port;
  const auto& dump_options = *GraphVizDumpOptions::singleton();

  DumpGraphvizVideoFrame(model);
  if (!dump_options.dump_graphviz.empty()) {
    string graphviz_dump;

    DumpGraphviz(model, &graphviz_dump);
    CHECK(port::file::SetContents(
              port::file::JoinPath(
                  dump_options.dump_graphviz,
                  absl::StrCat("toco_",
                               absl::StrReplaceAll(message, {{" ", "_"}}),
                               ".dot")),
              graphviz_dump, port::file::Defaults())
              .ok());
  }

  // Skip the (potentially long) textual dump when nobody would see it.
  if (!VLOG_IS_ON(log_level)) {
    return;
  }
  VLOG(log_level) << "BEGIN DUMP OF TOCO MODEL (" << message << ")";
  LogSummary(log_level, model);
  // NOTE(review): the set's element type appears to be missing (lost in
  // extraction?); it holds array names.
  std::unordered_set already_printed_arrays;
  for (const auto& op : model.operators) {
    for (const auto& input : op->inputs) {
      if (!already_printed_arrays.count(input)) {
        already_printed_arrays.insert(input);
        LogArray(log_level, model, input);
      }
    }
    VLOG(log_level) << HelpfulOperatorTypeName(*op) << " :";
    VLOG(log_level) << " " << FormatArraysList(model, op->inputs) << " -> "
                    << FormatArraysList(model, op->outputs);
    if (op->fused_activation_function != FusedActivationFunctionType::kNone) {
      VLOG(log_level) << " (with fused activation function)";
    }
    for (const auto& output : op->outputs) {
      if (!already_printed_arrays.count(output)) {
        already_printed_arrays.insert(output);
        LogArray(log_level, model, output);
      }
    }
  }
  VLOG(log_level) << "END DUMP OF TOCO MODEL (" << message << ")";
}

// Grows `shape` to `new_shape_size` dimensions by inserting dimensions of size
// 1 at the front. CHECK-fails if `new_shape_size` is smaller than the current
// dimension count.
// Note remaining raw-array extension in ProcessTensorFlowReshapeOperator().
void ExtendShape(Shape* shape, int new_shape_size) {
  CHECK_GE(new_shape_size, shape->dimensions_count());
  const int size_increase = new_shape_size - shape->dimensions_count();
  auto* shape_dims = shape->mutable_dims();
  shape_dims->insert(shape_dims->begin(), size_increase, 1);
}

// Shrinks `shape` to `new_shape_size` dimensions by removing dimensions from
// the front. CHECK-fails if `new_shape_size` exceeds the current dimension
// count or if any removed dimension is not 1.
// TODO(b/62904716) Remove along with remaining uses.
void UnextendShape(Shape* shape, int new_shape_size) {
  CHECK_LE(new_shape_size, shape->dimensions_count());
  const int size_reduction = shape->dimensions_count() - new_shape_size;
  for (int i = 0; i < size_reduction; i++) {
    CHECK_EQ(shape->dims(i), 1);
  }
  // NOTE(review): the vector's element type appears to be missing (lost in
  // extraction?) -- confirm against the real source.
  std::vector& shape_dims = *shape->mutable_dims();
  shape_dims.erase(shape_dims.begin(), shape_dims.begin() + size_reduction);
}

// In general, zero-sized dimensions are disallowed, but there are exceptions,
// e.g., if the tensor data itself represents a scalar (rank 0) shape, its
// shape will have dimensions [0]. CheckNonEmptyShapeDimensions is more
// strict, and is appropriate for ops and comparisons where an empty shape
// doesn't make sense.
template void CheckValidShapeDimensions(const Dims& dims) { if (dims.size() == 1 && dims[0] == 0) { return; } for (const auto& dim : dims) { CHECK_GE(dim, 1); } } void CheckValidShape(const Shape& shape) { CheckValidShapeDimensions(shape.dims()); } bool IsNonEmpty(const Shape& shape) { for (int i = 0; i < shape.dimensions_count(); ++i) { if (shape.dims(i) < 1) return false; } return true; } void CheckNonEmptyShapeDimensions(const Shape& shape) { for (int i = 0; i < shape.dimensions_count(); ++i) { CHECK_GE(shape.dims()[i], 1) << "shape has dimension 0 at index << " << i << ". shape = " << ShapeToString(shape); } } bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1) { CheckNonEmptyShapeDimensions(shape0); CheckNonEmptyShapeDimensions(shape1); const Shape* longer = &shape0; const Shape* shorter = &shape1; if (shape1.dimensions_count() > shape0.dimensions_count()) { longer = &shape1; shorter = &shape0; } // Walk dimensions back to front until we run out of dimensions in the shorter // shape. int longer_index = longer->dimensions_count() - 1; int shorter_index = shorter->dimensions_count() - 1; while (shorter_index >= 0) { const int d_long = longer->dims(longer_index); const int d_short = shorter->dims(shorter_index); // Broadcasting fails if the dimensions are different *and* neither is 1. if ((d_long != d_short) && (d_long != 1) && (d_short != 1)) { return false; } longer_index--; shorter_index--; } return true; } bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1) { CheckNonEmptyShapeDimensions(shape0); CheckNonEmptyShapeDimensions(shape1); const Shape* longer = &shape0; const Shape* shorter = &shape1; if (shape1.dimensions_count() > shape0.dimensions_count()) { longer = &shape1; shorter = &shape0; } // Walk dimensions back to front until we run out of dimensions in the shorter // shape. 
int longer_index = longer->dimensions_count() - 1; int shorter_index = shorter->dimensions_count() - 1; while (shorter_index >= 0) { const int d_long = longer->dims(longer_index); const int d_short = shorter->dims(shorter_index); // Extending fails if the dimensions are different. if (d_long != d_short) { return false; } longer_index--; shorter_index--; } // The remaining dimensions in the longer shape must be 1. while (longer_index >= 0) { const int d_long = longer->dims(longer_index); if (d_long != 1) { return false; } longer_index--; } return true; } int RequiredBufferSizeForShape(const Shape& shape) { CheckValidShape(shape); int max_offset = 1; for (const auto& dim : shape.dims()) { max_offset *= dim; } return max_offset; } bool IsConstantParameterArray(const Model& model, const string& name) { if (!model.HasArray(name)) { return false; } return !!model.GetArray(name).buffer; } namespace { template bool CompareArrayBuffers(const Array& lhs_array, const Array& rhs_array) { CHECK(lhs_array.data_type == rhs_array.data_type) << "Data types must match"; CHECK(lhs_array.buffer) << "LHS must be constant"; CHECK(rhs_array.buffer) << "RHS must be constant"; const auto& lhs_data = lhs_array.GetBuffer().data; const auto& rhs_data = rhs_array.GetBuffer().data; CHECK_EQ(lhs_data.size(), rhs_data.size()) << "Buffer sizes must match in element count"; for (int i = 0; i < lhs_data.size(); ++i) { if (lhs_data[i] != rhs_data[i]) { return false; } } return true; } } // namespace bool CompareConstantArrays(const Array& lhs_array, const Array& rhs_array) { bool attrs_equal = lhs_array.shape() == rhs_array.shape() && lhs_array.data_type == rhs_array.data_type && lhs_array.final_data_type == rhs_array.final_data_type && lhs_array.minmax == rhs_array.minmax && lhs_array.quantization_params == rhs_array.quantization_params; if (!attrs_equal) { return false; } switch (lhs_array.data_type) { case ArrayDataType::kBool: return CompareArrayBuffers(lhs_array, rhs_array); case 
ArrayDataType::kFloat: return CompareArrayBuffers(lhs_array, rhs_array); case ArrayDataType::kInt8: return CompareArrayBuffers(lhs_array, rhs_array); case ArrayDataType::kUint8: return CompareArrayBuffers(lhs_array, rhs_array); case ArrayDataType::kInt16: return CompareArrayBuffers(lhs_array, rhs_array); case ArrayDataType::kUint16: return CompareArrayBuffers(lhs_array, rhs_array); case ArrayDataType::kInt32: return CompareArrayBuffers(lhs_array, rhs_array); case ArrayDataType::kUint32: return CompareArrayBuffers(lhs_array, rhs_array); case ArrayDataType::kInt64: return CompareArrayBuffers(lhs_array, rhs_array); case ArrayDataType::kUint64: return CompareArrayBuffers(lhs_array, rhs_array); case ArrayDataType::kString: return CompareArrayBuffers(lhs_array, rhs_array); default: LOG(FATAL) << "Unsupported data type: " << ArrayDataTypeName(lhs_array.data_type); return false; } } namespace { // Take an array name, which may be something like "name:3_5" and make it // acceptable as a TF node name, say "name_3_5"; string SanitizeNameForTFNode(const string& array_name) { auto node_name = array_name; std::replace(node_name.begin(), node_name.end(), ':', '_'); return node_name; } void CheckInputArraysAreNotOutputArrays(const ModelFlags& model_flags) { for (const auto& input_array : model_flags.input_arrays()) { for (const string& output_array : model_flags.output_arrays()) { QCHECK_NE(input_array.name(), output_array) << "The array " << output_array << " is listed in both --input_arrays and --output_arrays."; } } } bool IsAsciiPrintable(const string& name) { for (char c : name) { if (!absl::ascii_isprint(c)) { return false; } } return true; } string DumpAscii(const string& name) { string result; port::AppendF(&result, "ASCII | Hex\n"); port::AppendF(&result, "------+----\n"); for (char c : name) { if (absl::ascii_isprint(c)) { port::AppendF(&result, "%c | %x\n", c, c); } else { port::AppendF(&result, " | %x Not ASCII printable!\n", c); } } return result; } void 
CheckNonAsciiIOArrays(const ModelFlags& model_flags) { if (model_flags.allow_nonascii_arrays()) { return; } for (const auto& input_array : model_flags.input_arrays()) { QCHECK(IsAsciiPrintable(input_array.name())) << "Non-ASCII-printable character found in --input_arrays: " << input_array.name() << ". Pass --allow_nonascii_arrays to allow that. " << "Here is a dump of the string:\n\n" << DumpAscii(input_array.name()); } for (const string& output_array : model_flags.output_arrays()) { QCHECK(IsAsciiPrintable(output_array)) << "Non-ASCII-printable character found in --output_arrays: " << output_array << ". Pass --allow_nonascii_arrays to allow that. " << "Here is a dump of the string:\n\n" << DumpAscii(output_array); } } void CheckNonExistentIOArrays(const Model& model) { // "non-existent" is interpreted in the stronger sense of // "not actually produced/consumed by an op". // Rationale: we have to artificially fix up TensorFlow graphs by creating // any array that it refers to, so just checking that arrays exist isn't // sufficient. The real invariant here is whether arrays are produced/consumed // by something. if (model.flags.allow_nonexistent_arrays()) { return; } static constexpr char general_comment[] = "Is it a typo? To silence this message, pass this flag: " "allow_nonexistent_arrays"; for (const auto& input_array : model.flags.input_arrays()) { QCHECK(GetOpWithInput(model, input_array.name())) << "Specified input array \"" << input_array.name() << "\" is not consumed by any op in this graph. " << general_comment; } for (const string& output_array : model.flags.output_arrays()) { QCHECK(GetOpWithOutput(model, output_array)) << "Specified output array \"" << output_array << "\" is not produced by any op in this graph. 
" << general_comment; } for (const auto& rnn_state : model.flags.rnn_states()) { if (!rnn_state.discardable()) { // Check that all RNN states are consumed QCHECK(GetOpWithInput(model, rnn_state.state_array())) << "Specified RNN state \"" << rnn_state.state_array() << "\" is not consumed by any op in this graph. " << general_comment; // Check that all RNN back-edge source arrays are produced QCHECK(GetOpWithOutput(model, rnn_state.back_edge_source_array())) << "Specified RNN back-edge source array \"" << rnn_state.back_edge_source_array() << "\" is not produced by any op in this graph. " << general_comment; } } } } // namespace void CheckNoMissingArray(const Model& model) { for (const auto& op : model.operators) { for (const auto& input : op->inputs) { CHECK(model.HasArray(input) || model.optional_arrays.count(input)) << "Input: " << input << " missing for op: " << op->outputs[0] << "."; } for (const auto& output : op->outputs) { CHECK(model.HasArray(output)) << "Output: " << output << " missing."; } } CheckNonExistentIOArrays(model); } void FixNoMissingArray(Model* model) { for (const auto& op : model->operators) { for (const auto& input : op->inputs) { if (!model->HasArray(input)) { model->GetOrCreateArray(input); } } for (const auto& output : op->outputs) { if (!model->HasArray(output)) { model->GetOrCreateArray(output); } } } if (model->flags.allow_nonexistent_arrays()) { for (const string& output_array : model->flags.output_arrays()) { model->GetOrCreateArray(output_array); } for (const auto& rnn_state : model->flags.rnn_states()) { model->GetOrCreateArray(rnn_state.state_array()); model->GetOrCreateArray(rnn_state.back_edge_source_array()); } } } void CheckNoOrphanedArray(const Model& model) { std::unordered_set arrays_without_known_use; for (const auto& array : model.GetArrayMap()) { if (IsDiscardableArray(model, array.first)) { arrays_without_known_use.insert(array.first); } } for (const auto& op : model.operators) { for (const auto& input : op->inputs) { 
arrays_without_known_use.erase(input); } for (const auto& output : op->outputs) { arrays_without_known_use.erase(output); } } for (const auto& rnn_state : model.flags.rnn_states()) { arrays_without_known_use.erase(rnn_state.state_array()); arrays_without_known_use.erase(rnn_state.back_edge_source_array()); } if (!arrays_without_known_use.empty()) { for (const auto& array : arrays_without_known_use) { LOG(INFO) << "Error: Orphaned array: " << array; } } CHECK(arrays_without_known_use.empty()); } void FixNoOrphanedArray(Model* model) { std::unordered_set arrays_without_known_use; for (const auto& array : model->GetArrayMap()) { arrays_without_known_use.insert(array.first); } for (const auto& op : model->operators) { for (const auto& input : op->inputs) { arrays_without_known_use.erase(input); } for (const auto& output : op->outputs) { arrays_without_known_use.erase(output); } } for (const auto& rnn_state : model->flags.rnn_states()) { arrays_without_known_use.erase(rnn_state.state_array()); arrays_without_known_use.erase(rnn_state.back_edge_source_array()); } for (const auto& array : arrays_without_known_use) { if (IsDiscardableArray(*model, array)) { model->EraseArray(array); } } } // Apply checks to arrays individually (for-each fashion). // // Check consistency of array fields, check name. void CheckEachArray(const Model& model) { for (const auto& array_entry : model.GetArrayMap()) { const auto& array = array_entry.second; // It's OK to have a buffer or an alloc, but not both. // (Since allocs are for transient arrays without a buffer). CHECK(!array->buffer || !array->alloc); if (array->buffer) { // If there is a buffer, its type should be consistent with data_type. CHECK(array->buffer->type == array->data_type); // The presence of a fixed buffer should imply the presence of a fixed // shape. CHECK(array->has_shape()); // Constant buffer should has a valid shape. CheckValidShape(array->shape()); // The shape flat-size should agree with the buffer length. 
CHECK_EQ(array->buffer->Length(), RequiredBufferSizeForShape(array->shape())); } // Check name. Either "name_with_suffix_8", "name_with_port:3", but not // "name_with_both:3_8". const string& name = array_entry.first; auto colon_pos = name.find_first_of(":"); if (colon_pos != string::npos) { CHECK_EQ(name.substr(colon_pos + 1).find_first_not_of("0123456789"), string::npos) << "Array name must only have digits after colon"; } CHECK_GT(colon_pos, 0) << "First character of array name must not be a colon."; } } void CheckOperatorOrdering(const Model& model) { std::unordered_set arrays_behind_us; for (const auto& array_entry : model.GetArrayMap()) { if (!GetOpWithOutput(model, array_entry.first)) { arrays_behind_us.insert(array_entry.first); } } arrays_behind_us.insert(model.optional_arrays.begin(), model.optional_arrays.end()); for (const auto& op : model.operators) { for (const auto& input : op->inputs) { if (!IsConstantParameterArray(model, input)) { CHECK(arrays_behind_us.count(input)); } } for (const auto& output : op->outputs) { CHECK(!arrays_behind_us.count(output)); arrays_behind_us.insert(output); } } for (const string& output_array : model.flags.output_arrays()) { CHECK(arrays_behind_us.count(output_array)); } } void FixOperatorOrdering(Model* model) { std::unordered_set arrays_behind_us; for (const auto& array_entry : model->GetArrayMap()) { if (!GetOpWithOutput(*model, array_entry.first)) { arrays_behind_us.insert(array_entry.first); } } arrays_behind_us.insert(model->optional_arrays.begin(), model->optional_arrays.end()); std::vector> old_operators; std::swap(old_operators, model->operators); std::set remaining; for (std::size_t i = 0; i < old_operators.size(); i++) { remaining.insert(i); } std::unordered_map reason_why_leftover; while (true) { bool inserted_something = false; for (auto i : remaining) { bool can_insert = true; auto& op = old_operators[i]; CHECK(op); for (const auto& input : op->inputs) { if (!IsConstantParameterArray(*model, input) && 
!arrays_behind_us.count(input)) { for (const string& output : op->outputs) { reason_why_leftover[output] = input; } can_insert = false; break; } } if (can_insert) { model->operators.emplace_back(nullptr); for (const auto& output : op->outputs) { arrays_behind_us.insert(output); } std::swap(op, model->operators.back()); remaining.erase(i); inserted_something = true; break; } } if (!inserted_something) { break; } } if (!remaining.empty()) { LOG(ERROR) << "No viable ordering of operators was found. " << "Here is a 'backtrace' of at least one part of the graph that is " << "problematic. It starts with the first operator that has as " << "problematic input array, and then walks back the graph to " << "the operator that produced that input array, etc., until we find " << "the root cause:"; LOG(ERROR) << "BEGIN TRACE OF OPERATOR WITH BAD INPUT"; LOG(ERROR) << "Here is the first-encountered operator with a bad input: "; const Operator* bad_op = old_operators[*remaining.begin()].get(); std::unordered_set bad_inputs_already_traced; // The following while(true) loop should always end with a LOG(FATAL). while (true) { LOG(ERROR) << HelpfulOperatorTypeName(*bad_op) << " : " << FormatArraysList(*model, bad_op->inputs) << " -> " << FormatArraysList(*model, bad_op->outputs); bool found_bad_output = false; string bad_output; for (const string& output : bad_op->outputs) { if (reason_why_leftover.count(output)) { found_bad_output = true; bad_output = output; break; } } CHECK(found_bad_output); const string& bad_input = reason_why_leftover[bad_output]; LOG(ERROR) << "The bad input here is: " << bad_input; if (bad_inputs_already_traced.count(bad_input)) { LOG(FATAL) << "Cycle found! We already encountered that " << "input array, " << bad_input << ", earlier in the " << "above trace! We expect graphs to be acyclic, even " << "RNNs. Let us know if some graph actually needs to have " << "cycles, but first, please check if it really is " << "an *inference* graph. 
*Training* graphs are out-of-scope " << "for toco."; } bad_inputs_already_traced.insert(bad_input); bad_op = nullptr; for (auto i : remaining) { const Operator* op = old_operators[i].get(); for (const string& output : op->outputs) { if (bad_input == output) { bad_op = op; break; } } if (bad_op) { break; } } if (!bad_op) { LOG(ERROR) << "And that's the root cause: " << "that array, " << bad_input << ", isn't produced by any " << "operator, or provided in any other way."; LOG(ERROR) << "END TRACE OF OPERATOR WITH BAD INPUT"; LOG(FATAL) << "(The above was a multi-line fatal error)"; } LOG(ERROR) << "And that array is the output of the following operator:"; } } CHECK(remaining.empty()) << "Should never get here! In case of bad graph, " << "the above code should have generated a FATAL error already!"; } void CheckInvariants(const Model& model) { CheckInputArraysAreNotOutputArrays(model.flags); CheckNonAsciiIOArrays(model.flags); CheckNoMissingArray(model); CheckNoOrphanedArray(model); CheckEachArray(model); CheckOperatorOrdering(model); } void CheckCountInRange(const ::toco::ModelFlags::ModelCheck& model_check, const int count, const string& count_description) { if (model_check.count_min() >= 0) { CHECK_GE(count, model_check.count_min()) << "Mismatch in " << count_description << ": count was " << count << ", but the specified " << (model_check.count_max() > model_check.count_min() ? 
"minimum" : "value") << " was " << model_check.count_min() << "."; } if (model_check.count_max() > model_check.count_min()) { CHECK_LE(count, model_check.count_max()) << "Mismatch in " << count_description << ": count was " << count << ", but the specified maximum was " << model_check.count_max() << "."; } } void CheckModelCounts(const Model& model) { std::unordered_multiset ops_by_type; std::unordered_map op_type_by_name; if (model.flags.model_checks_size() == 0) { return; } for (const auto& op : model.operators) { ops_by_type.insert(op->type); op_type_by_name[OperatorTypeName(op->type)] = op->type; } for (const auto& model_check : model.flags.model_checks()) { string count_type = model_check.count_type(); if (count_type == "None") { continue; } else if (count_type == "Arrays") { CheckCountInRange(model_check, model.GetArrayMap().size(), "count of arrays"); } else if (count_type == "Total") { CheckCountInRange(model_check, model.operators.size(), "count of all operator instances"); } else { // The check type is not itself checked against the set of valid // operators, mainly because the enum set cannot be iterated in C++. const int found_count = op_type_by_name.count(count_type) > 0 ? ops_by_type.count(op_type_by_name[count_type]) : 0; CheckCountInRange(model_check, found_count, "count of instances of " + count_type + " operator"); } } } void FixEdgeArrays(Model* model) { for (const string& output_array_name : model->flags.output_arrays()) { if (!GetOpWithOutput(*model, output_array_name)) { // Output has no operator producing it. Change that by inserting a copy. LOG(WARNING) << "Fixing constant output array " << output_array_name << " by inserting a copy. 
This is not optimal."; string intermediate_array_name = AvailableArrayName(*model, output_array_name + "_copy"); CloneArray(model, output_array_name, intermediate_array_name); InsertCopyOperator(model, intermediate_array_name, output_array_name); } } } void DedupeConstantArrays(Model* model, size_t min_size) { // Walk all 0..N and compare with the remaining n+1..N. // This lets us avoid N^2 comparisions and erase duplicate arrays while // iterating. const auto& array_map = model->GetArrayMap(); for (auto lhs_array_it = array_map.begin(); lhs_array_it != array_map.end(); ++lhs_array_it) { const auto& lhs_array_name = lhs_array_it->first; const auto& lhs_array = *lhs_array_it->second; if (!IsConstantParameterArray(*model, lhs_array_name)) { // Not a constant array; skip. continue; } ArrayDataType final_data_type = lhs_array.final_data_type != ArrayDataType::kNone ? lhs_array.final_data_type : lhs_array.data_type; // Ignore small arrays, don't check string arrays because it is not possible // to estimate its size. if (final_data_type != ArrayDataType::kString) { size_t array_byte_size = lhs_array.buffer->Length() * ElementSize(final_data_type); if (array_byte_size < min_size) { // Too small; skip. continue; } } auto next_lhs_array_it = lhs_array_it; ++next_lhs_array_it; for (auto rhs_array_it = next_lhs_array_it; rhs_array_it != array_map.end();) { const auto& rhs_array_name = rhs_array_it->first; const auto& rhs_array = *rhs_array_it->second; ++rhs_array_it; if (!IsConstantParameterArray(*model, rhs_array_name)) { // Not a constant array; skip. continue; } if (!IsDiscardableArray(*model, rhs_array_name)) { // Can't remove the array as it's not discardable (such as an IO edge). continue; } if (!CompareConstantArrays(lhs_array, rhs_array)) { // Arrays aren't equal; skip. continue; } // Arrays can be deduped! 
VLOG(1) << "Deduplicating arrays; using " << lhs_array_name << " in place of " << rhs_array_name; ReplaceArrayUsage(model, rhs_array_name, lhs_array_name); // Note: rhs_array_it above is already incremented so this is safe. model->EraseArray(rhs_array_name); } } } namespace { void CopyArrayAttribs(const Array& source_array, Array* target_array) { target_array->data_type = source_array.data_type; target_array->final_data_type = source_array.final_data_type; target_array->copy_shape(source_array.shape()); if (source_array.minmax) { target_array->GetOrCreateMinMax() = source_array.GetMinMax(); } else { target_array->minmax.reset(); } if (source_array.quantization_params) { target_array->GetOrCreateQuantizationParams() = source_array.GetQuantizationParams(); } else { target_array->quantization_params.reset(); } } } // namespace void InsertCopyOperator(Model* model, const string& source_array_name, const string& target_array_name) { // Reshape to the same size. This should be a no-op. const Array& source_array = model->GetArray(source_array_name); std::vector shape = source_array.shape().dims(); // Drop constant data from the target array as the copy will be done at // runtime. Array& target_array = model->GetOrCreateArray(target_array_name); target_array.buffer.reset(); CopyArrayAttribs(source_array, &target_array); // Insert copy operator. 
auto* copy_op = new TensorFlowReshapeOperator; copy_op->inputs = { source_array_name, CreateInt32Array( model, AvailableArrayName(*model, target_array_name + "_copy_shape"), shape)}; copy_op->outputs = {target_array_name}; if (target_array.has_shape()) { copy_op->shape = target_array.shape().dims(); } model->operators.emplace_back(copy_op); } void CloneArray(Model* model, const string& source_array_name, const string& target_array_name) { CHECK(!model->HasArray(target_array_name)); const Array& source_array = model->GetArray(source_array_name); Array& target_array = model->GetOrCreateArray(target_array_name); CopyArrayAttribs(source_array, &target_array); if (source_array.minmax) { const auto& smm = source_array.GetMinMax(); auto& tmm = target_array.GetOrCreateMinMax(); tmm.min = smm.min; tmm.max = smm.max; } if (source_array.quantization_params) { const auto& sqp = source_array.GetQuantizationParams(); auto& tqp = target_array.GetOrCreateQuantizationParams(); tqp.zero_point = sqp.zero_point; tqp.scale = sqp.scale; } target_array.data_type = source_array.data_type; target_array.final_data_type = source_array.final_data_type; target_array.copy_shape(source_array.shape()); switch (source_array.data_type) { case ArrayDataType::kBool: CopyArrayBuffer(source_array, &target_array); break; case ArrayDataType::kFloat: CopyArrayBuffer(source_array, &target_array); break; case ArrayDataType::kInt8: CopyArrayBuffer(source_array, &target_array); break; case ArrayDataType::kUint8: CopyArrayBuffer(source_array, &target_array); break; case ArrayDataType::kInt16: CopyArrayBuffer(source_array, &target_array); break; case ArrayDataType::kUint16: CopyArrayBuffer(source_array, &target_array); break; case ArrayDataType::kInt32: CopyArrayBuffer(source_array, &target_array); break; case ArrayDataType::kUint32: CopyArrayBuffer(source_array, &target_array); break; case ArrayDataType::kInt64: CopyArrayBuffer(source_array, &target_array); break; case ArrayDataType::kUint64: 
CopyArrayBuffer(source_array, &target_array); break; case ArrayDataType::kString: CopyArrayBuffer(source_array, &target_array); break; default: LOG(FATAL) << "Unsupported data type: " << ArrayDataTypeName(source_array.data_type); return; } } void MakeArrayDims(int num_dims, int batch, int height, int width, int depth, std::vector* out_dims) { CHECK(out_dims->empty()); if (num_dims == 0) { return; } else if (num_dims == 1) { CHECK_EQ(batch, 1); *out_dims = {depth}; } else if (num_dims == 2) { *out_dims = {batch, depth}; } else if (num_dims == 3) { CHECK_EQ(batch, 1); *out_dims = {height, width, depth}; } else if (num_dims == 4) { *out_dims = {batch, height, width, depth}; } else { LOG(FATAL) << "Should not get here: " << num_dims; } } void CreateOrCheckRnnStateArray(const string& name, int size, Model* model) { int batch = 1; int num_dims = -1; for (const auto& input_array : model->flags.input_arrays()) { // Pick 'num_dims' and 'batch' from the first input_arrays, unless we find // a better match by name. 
if (input_array.name() == name || num_dims == -1) { num_dims = input_array.shape().dims_size(); if (num_dims > 0) { batch = input_array.shape().dims(0); } } } Array& array = model->GetOrCreateArray(name); if (array.has_shape()) { num_dims = array.shape().dimensions_count(); } if (!array.has_shape() && num_dims >= 0) { Shape* shape = array.mutable_shape(); std::vector dims; MakeArrayDims(num_dims, batch, 1, 1, size, &dims); *shape->mutable_dims() = dims; } } void ResolveModelFlags(const ModelFlags& model_flags, Model* model) { // Merge info about input_arrays from model_flags into model->flags for (const auto& specified_input_array : model_flags.input_arrays()) { toco::InputArray* dst_input_array = nullptr; for (int i = 0; i < model->flags.input_arrays_size(); i++) { toco::InputArray* candidate_dst_input_array = model->flags.mutable_input_arrays(i); if (candidate_dst_input_array->name() == specified_input_array.name()) { // specified_input_array from model_flags maps to dst_input_array // in model->flags dst_input_array = candidate_dst_input_array; break; } } if (!dst_input_array) { // Specified_input_array from model_flags is not found in model->flags. // Match a name-less specified input array when there can be no ambiguity // as there is only 1 input array. if (model->flags.input_arrays_size() == 1 && model_flags.input_arrays_size() == 1 && !specified_input_array.has_name()) { dst_input_array = model->flags.mutable_input_arrays(0); } } if (!dst_input_array) { // Still no match, so create a new input array to copy // specified_input_array into. 
dst_input_array = model->flags.add_input_arrays(); dst_input_array->set_name(specified_input_array.name()); } #define RESOLVE_MODEL_FLAG(field_name) \ if (specified_input_array.has_##field_name()) { \ if (dst_input_array->has_##field_name()) { \ QCHECK_EQ(dst_input_array->field_name(), \ specified_input_array.field_name()) \ << "For input array '" << dst_input_array->name() << "', " \ << "specified " #field_name " flag with value: " \ << specified_input_array.field_name() \ << " does not agree with already defined " #field_name \ " of this model, with value: " \ << specified_input_array.field_name(); \ } else { \ dst_input_array->set_##field_name(specified_input_array.field_name()); \ } \ } RESOLVE_MODEL_FLAG(std_value); RESOLVE_MODEL_FLAG(mean_value); #undef RESOLVE_MODEL_FLAG if (specified_input_array.has_shape()) { if (dst_input_array->has_shape()) { QCHECK_EQ(specified_input_array.shape().dims_size(), dst_input_array->shape().dims_size()) << "For input array '" << specified_input_array.name() << "', " << "size of specified input shape flag with size: " << specified_input_array.shape().dims_size() << " does not agree with already defined input shape" " of this model, with size: " << dst_input_array->shape().dims_size(); // We treat the first dimension as a special case, since it is often // a batch size and the input_shape flag is effectively overriding // the model. 
for (int i = 1; i < specified_input_array.shape().dims_size(); i++) { QCHECK_EQ(specified_input_array.shape().dims(i), dst_input_array->shape().dims(i)) << "At dimension number " << i << " of input array " << specified_input_array.name() << ", the specified shape's " << "dimension flag with dimension: " << specified_input_array.shape().dims(i) << " does not agree with already defined shape" << " of this model, with dimension: " << dst_input_array->shape().dims(i); } } else { *dst_input_array->mutable_shape() = specified_input_array.shape(); } } if (specified_input_array.has_data_type()) { QCHECK(!dst_input_array->has_data_type()); dst_input_array->set_data_type(specified_input_array.data_type()); } } if (model_flags.output_arrays_size() > 0) { model->flags.mutable_output_arrays()->CopyFrom(model_flags.output_arrays()); } #define RESOLVE_MODEL_FLAG(name) \ if (model_flags.has_##name()) { \ if (model->flags.has_##name()) { \ QCHECK_EQ(model_flags.name(), model->flags.name()) \ << "Specified " #name " flag with value: " << model_flags.name() \ << " does not agree with already defined " #name \ " of this model, with value: " \ << model->flags.name(); \ } else { \ model->flags.set_##name(model_flags.name()); \ } \ } RESOLVE_MODEL_FLAG(variable_batch) #undef RESOLVE_MODEL_FLAG if (!model_flags.rnn_states().empty()) { model->flags.mutable_rnn_states()->CopyFrom(model_flags.rnn_states()); } if (model->flags.model_checks_size() == 0) { model->flags.mutable_model_checks()->CopyFrom(model_flags.model_checks()); } QCHECK_GT(model->flags.output_arrays_size(), 0) << "This model does not define output arrays, so a " "--output_arrays flag must be given on the command-line."; for (auto& input_array_proto : *model->flags.mutable_input_arrays()) { auto& input_array = model->GetOrCreateArray(input_array_proto.name()); if (input_array_proto.has_data_type()) { const ArrayDataType specified_type = ConvertIODataTypeToArrayDataType(input_array_proto.data_type()); QCHECK(specified_type != 
ArrayDataType::kNone); if (input_array.data_type != ArrayDataType::kNone) { QCHECK(specified_type == input_array.data_type) << "For input array " << input_array_proto.name() << " the specified input data type " << IODataType_Name(input_array_proto.data_type()) << " conflicts with the existing type."; } input_array.data_type = specified_type; } if (input_array.data_type == ArrayDataType::kNone) { // We start out with a float input array; // that may get replaced by a uint8 array later, by // MakeInitialDequantizeOp. input_array.data_type = ArrayDataType::kFloat; } // Compare/merge the model->flags describing the input_shape with // the actual input array's shape. if (!input_array.has_shape()) { if (input_array_proto.has_shape()) { auto& input_array_dims = *input_array.mutable_shape()->mutable_dims(); CheckValidShapeDimensions(input_array_proto.shape().dims()); for (auto dim : input_array_proto.shape().dims()) { input_array_dims.push_back(dim); } } } else { if (input_array_proto.has_shape()) { // If an input shape was specified on the flags ensure that it matches // the actual shape in the model. 
const auto& input_array_dims = *input_array.mutable_shape()->mutable_dims(); CHECK_EQ(input_array_dims.size(), input_array_proto.shape().dims_size()); for (int i = 0; i < input_array_dims.size(); i++) { CHECK_EQ(input_array_dims[i], input_array_proto.shape().dims(i)); } } else { for (int i = 0; i < input_array.shape().dimensions_count(); i++) { input_array_proto.mutable_shape()->add_dims( input_array.shape().dims(i)); } } } const float mean_value = input_array_proto.mean_value(); const float std_value = input_array_proto.std_value(); MinMax input_minmax; float qmin = 0, qmax = 255; if (input_array.data_type == ArrayDataType::kInt16) { qmin = -32768; qmax = 32767; } input_minmax.min = (qmin - mean_value) / std_value; input_minmax.max = (qmax - mean_value) / std_value; if (!input_array.minmax) { input_array.GetOrCreateMinMax() = input_minmax; } } // Creation of the RNN state arrays for (const auto& rnn_state : model->flags.rnn_states()) { CreateOrCheckRnnStateArray(rnn_state.state_array(), rnn_state.size(), model); } model->flags.set_change_concat_input_ranges( model_flags.change_concat_input_ranges()); model->flags.set_allow_nonascii_arrays(model_flags.allow_nonascii_arrays()); model->flags.set_allow_nonexistent_arrays( model_flags.allow_nonexistent_arrays()); CHECK(!model->flags.has_arrays_extra_info()); *model->flags.mutable_arrays_extra_info() = model_flags.arrays_extra_info(); } void CheckIsReadyForQuantization(const Model& model) { for (const auto& op : model.operators) { for (const auto& input : op->inputs) { const auto& input_array = model.GetArray(input); if (input_array.data_type != ArrayDataType::kFloat) { // The array is not floats, no quantization needed. continue; } if (input_array.minmax) { // The array has minmax, we're good. continue; } if (input_array.buffer) { // The array has a constant buffer, so we can // fall back to computing the minmax from actual array entries // (with a WARNING about possible accuracy implications). 
continue; } LOG(FATAL) << "Array " << input << ", which is an input to the " << HelpfulOperatorTypeName(*op) << " operator producing the output " << "array " << op->outputs[0] << ", is lacking min/max data, " << "which is necessary for quantization. If accuracy matters, either " << "target a non-quantized output format, or run quantized training " << "with your model from a floating point checkpoint to change the " << "input graph to contain min/max information. If you don't care " << "about accuracy, you can pass --default_ranges_min= and " << "--default_ranges_max= for easy experimentation."; } } } int ElementSize(ArrayDataType data_type) { switch (data_type) { case ArrayDataType::kBool: return sizeof(bool); case ArrayDataType::kFloat: return 4; case ArrayDataType::kInt8: return 1; case ArrayDataType::kUint8: return 1; case ArrayDataType::kInt16: return 2; case ArrayDataType::kUint16: return 2; case ArrayDataType::kInt32: return 4; case ArrayDataType::kUint32: return 4; case ArrayDataType::kInt64: return 8; case ArrayDataType::kUint64: return 8; // Usually not critical limitation because strings are only input and/or // output. case ArrayDataType::kString: LOG(FATAL) << "Transient arrays with strings are not supported yet"; return 0; default: LOG(FATAL) << "Unknown data_type = " << static_cast(data_type); return 0; } } void DropMinMax(Model* model, const string& array_name) { auto& array = model->GetArray(array_name); if (!!array.minmax) { LOG(WARNING) << "Dropping MinMax information in array " << array_name << ". Expect inaccuracy in quantized inference."; array.minmax = nullptr; } } bool IsAllocatableTransientArray(const Model& model, const string& array_name) { // Optional array is not transient if (model.IsOptionalArray(array_name)) return false; // The model's input and output arrays are externally allocated. // They are not transient arrays. 
if (IsInputArray(model, array_name) || IsOutputArray(model, array_name)) { return false; } const auto& array = &model.GetArray(array_name); // An array with a constant buffer isn't a transient array. if (!!array->buffer) { return false; } // An array without shape isn't allocatable. if (!array->has_shape()) { return false; } return true; } string AvailableArrayName(const Model& model, const string& name) { string sanitized_name = SanitizeNameForTFNode(name); if (!model.HasArray(sanitized_name) && !model.IsOptionalArray(sanitized_name)) { return sanitized_name; } const int kNumSuffixesToTry = 1000; for (int i = 0; i < kNumSuffixesToTry; i++) { const string& name_with_suffix = toco::port::StringF("%s_%d", sanitized_name, i); if (!model.HasArray(name_with_suffix) && !model.IsOptionalArray(name_with_suffix)) { return name_with_suffix; } } LOG(FATAL) << "Could not find an available array name starting with " << sanitized_name << ". Tried " << kNumSuffixesToTry << " suffixes, all were taken!"; return ""; } string ShapeToString(const Shape& shape) { if (shape.dimensions_count() == 0) { return "[]"; } return absl::StrCat("[ ", absl::StrJoin(shape.dims(), ", "), " ]"); } void PrintArrayShape(Model* model, const string& name) { if (!model->GetArray(name).has_shape()) { LOG(INFO) << name << " has no shape"; return; } LOG(INFO) << name << " has shape: " << ShapeToString(model->GetArray(name).shape()); } bool IsArrayFullyConnectedWeights(const Model& model, const string& name) { bool is_fc_weights = false; bool is_something_else = false; for (const auto& op : model.operators) { for (int input_index = 0; input_index < op->inputs.size(); input_index++) { if (op->inputs[input_index] == name) { if (op->type == OperatorType::kFullyConnected && input_index == 1) { is_fc_weights = true; } else { is_something_else = true; } } } } CHECK(!(is_fc_weights && is_something_else)); return is_fc_weights; } string CreateInt32Array(Model* model, const string& param_name, const std::vector& 
value) { auto param_array_name = AvailableArrayName(*model, param_name); auto& param_array = model->GetOrCreateArray(param_array_name); param_array.mutable_shape()->ReplaceDims({static_cast(value.size())}); param_array.data_type = ArrayDataType::kInt32; auto& param_array_data = param_array.GetMutableBuffer().data; param_array_data.resize(RequiredBufferSizeForShape(param_array.shape())); for (int i = 0; i < value.size(); ++i) { param_array_data[i] = value[i]; } return param_array_name; } bool EstimateArithmeticOpsCount(const Model& model, int64* result) { int64 total = 0; for (const auto& op : model.operators) { switch (op->type) { case OperatorType::kFullyConnected: case OperatorType::kConv: case OperatorType::kDepthwiseConv: { const auto& output_array = model.GetArray(op->outputs[0]); const auto& weights_array = model.GetArray(op->inputs[1]); if (!output_array.has_shape() || !weights_array.has_shape()) { return false; } int cols = 1; for (int i = 0; i < output_array.shape().dimensions_count() - 1; i++) { cols *= output_array.shape().dims(i); } const int64 cost_per_col = 2 * RequiredBufferSizeForShape(weights_array.shape()); total += cost_per_col * cols; if (op->inputs.size() > 2) { // There is a bias vector. One more op per output value. total += RequiredBufferSizeForShape(output_array.shape()); } break; } case OperatorType::kAdd: case OperatorType::kSub: case OperatorType::kMul: { const auto& output_array = model.GetArray(op->outputs[0]); if (!output_array.has_shape()) { return false; } total += RequiredBufferSizeForShape(output_array.shape()); break; } case OperatorType::kAddN: { const auto& output_array = model.GetArray(op->outputs[0]); if (!output_array.has_shape()) { return false; } // AddN cost is roughly the same cost as N-1 Adds. 
const int num_adds = op->inputs.size() - 1; total += num_adds * RequiredBufferSizeForShape(output_array.shape()); break; } case OperatorType::kLogistic: case OperatorType::kSoftmax: case OperatorType::kLogSoftmax: case OperatorType::kTanh: { const auto& output_array = model.GetArray(op->outputs[0]); if (!output_array.has_shape()) { return false; } // As a very rough ballpark, the cost of evaluating a math function // such as tanh or logistic is about 32 multiplications, and about as // many additions/subtractions. (Just a power-of-two order-of-magnitude // from looking at actual implementations that we use in runtime/ code). total += 64 * RequiredBufferSizeForShape(output_array.shape()); break; } case OperatorType::kMaxPool: { const auto& maxpool = *static_cast(op.get()); const auto& output_array = model.GetArray(op->outputs[0]); if (!output_array.has_shape()) { return false; } total += RequiredBufferSizeForShape(output_array.shape()) * maxpool.kheight * maxpool.kwidth; break; } case OperatorType::kAveragePool: { const auto& avgpool = *static_cast(op.get()); const auto& output_array = model.GetArray(op->outputs[0]); if (!output_array.has_shape()) { return false; } total += RequiredBufferSizeForShape(output_array.shape()) * avgpool.kheight * avgpool.kwidth; break; } case OperatorType::kL2Pool: { const auto* maxpool = static_cast(op.get()); const auto& output_array = model.GetArray(op->outputs[0]); if (!output_array.has_shape()) { return false; } // The sum of squares requires (kheight*kwidth) multiply-adds, // and then there is the sqrt which we ballpark at 32 ops. 
const int64 cost_per_val = 2 * maxpool->kheight * maxpool->kwidth + 32; total += RequiredBufferSizeForShape(output_array.shape()) * cost_per_val; break; } case OperatorType::kL2Normalization: { const auto& output_array = model.GetArray(op->outputs[0]); if (!output_array.has_shape()) { return false; } // Computing the squared L2 norm is N multiply-adds so 2N ops, // then the single inverse-sqrt is negligible, then we multiply each // value by the resulting multiplier, so an extra N ops. Total 3N ops. total += 3 * RequiredBufferSizeForShape(output_array.shape()); break; } default: break; } } *result = total; return true; } void GetShuffleShape(AxesOrder input_axes_order, AxesOrder output_axes_order, std::vector* shuffle) { CHECK_EQ(AxesCount(input_axes_order), AxesCount(output_axes_order)); shuffle->resize(4); for (int i = 0; i < 4; i++) { (*shuffle)[i] = i; } if (input_axes_order == output_axes_order) { // nothing to do } else if (AxesCount(input_axes_order) == 2) { shuffle->resize(2); (*shuffle)[0] = 1; (*shuffle)[1] = 0; } else if (input_axes_order == AxesOrder::kOHWI && output_axes_order == AxesOrder::kHWIO) { // 3210 <- 3210 // HWIO <- OHWI *shuffle = {1, 2, 3, 0}; } else if (input_axes_order == AxesOrder::kHWIO && output_axes_order == AxesOrder::kOHWI) { // 3210 <- 3210 // OHWI <- HWIO *shuffle = {3, 0, 1, 2}; } else if (input_axes_order == AxesOrder::kOHWI && output_axes_order == AxesOrder::kHWOI) { *shuffle = {1, 2, 0, 3}; } else { LOG(FATAL) << "Bad shuffle"; } } void ExtendShuffle(const std::vector& input_shuffle, int newdim, std::vector* extended_shuffle) { *extended_shuffle = input_shuffle; CHECK(newdim >= input_shuffle.size()); const int pad_size = newdim - input_shuffle.size(); extended_shuffle->resize(newdim); for (int i = 0; i < pad_size; i++) { (*extended_shuffle)[i] = i; } for (int i = pad_size; i < newdim; i++) { (*extended_shuffle)[i] = input_shuffle[i - pad_size] + pad_size; } } void ShuffleDims(const Shape& input_shape, AxesOrder 
input_axes_order, AxesOrder output_axes_order, Shape* output_shape) { if (input_axes_order == AxesOrder::kHWIM && output_axes_order == AxesOrder::k1HWO) { // This special case isn't just a permutation, the IM pair of dims get // merged into the 3 dim, so we have to special-case it. *output_shape = Shape({1, input_shape.dims(0), input_shape.dims(1), input_shape.dims(3) * input_shape.dims(2)}); } else { std::vector shuffle; GetShuffleShape(input_axes_order, output_axes_order, &shuffle); std::vector* output_dims = output_shape->mutable_dims(); output_dims->resize(input_shape.dimensions_count()); for (int i = 0; i < input_shape.dimensions_count(); i++) { (*output_dims)[i] = input_shape.dims(shuffle[i]); } } } template void ShuffleArrayTemplate(const Shape& input_shape, AxesOrder input_axes_order, AxesOrder output_axes_order, const Shape& output_shape, const T* input_data, T* output_data) { if (input_axes_order == AxesOrder::kHWIM && output_axes_order == AxesOrder::k1HWO) { // This special case isn't just a permutation, the IM pair of dims get // merged into the O dim, so we have to special-case it. Fortunately, // as far as array shuffling is concerned, it's just the identity // transformation. 
memcpy(output_data, input_data, RequiredBufferSizeForShape(input_shape) * sizeof(output_data[0])); return; } CHECK(input_shape.dimensions_count() == output_shape.dimensions_count()); const int dim = input_shape.dimensions_count(); CHECK_LE(dim, 4); std::vector shuffle; GetShuffleShape(input_axes_order, output_axes_order, &shuffle); CHECK(shuffle.size() >= dim); for (int i = 0; i < dim; i++) { CHECK(shuffle[i] >= 0 && shuffle[i] < dim); CHECK(input_shape.dims(shuffle[i]) == output_shape.dims(i)); } Shape extended_input_shape = input_shape; ExtendShape(&extended_input_shape, 4); Shape extended_output_shape = output_shape; ExtendShape(&extended_output_shape, 4); std::vector extended_shuffle; ExtendShuffle(shuffle, 4, &extended_shuffle); const std::vector& extended_input_dims = extended_input_shape.dims(); const std::vector& extended_output_dims = extended_output_shape.dims(); // TODO(starka): Rework to handle different numbers of dimensions. int input_strides[4]; input_strides[3] = 1; input_strides[2] = extended_input_dims[3]; input_strides[1] = input_strides[2] * extended_input_dims[2]; input_strides[0] = input_strides[1] * extended_input_dims[1]; const int input_stride_0 = input_strides[extended_shuffle[3]]; const int input_stride_1 = input_strides[extended_shuffle[2]]; const int input_stride_2 = input_strides[extended_shuffle[1]]; const int input_stride_3 = input_strides[extended_shuffle[0]]; const int output_size_0 = extended_output_dims[3]; const int output_size_1 = extended_output_dims[2]; const int output_size_2 = extended_output_dims[1]; const int output_size_3 = extended_output_dims[0]; const int output_stride_0 = 1; const int output_stride_1 = output_size_0; const int output_stride_2 = output_stride_1 * output_size_1; const int output_stride_3 = output_stride_2 * output_size_2; for (int i3 = 0; i3 < output_size_3; i3++) { const T* const input_ptr_3 = input_data + i3 * input_stride_3; T* const output_ptr_3 = output_data + i3 * output_stride_3; for (int i2 = 
0; i2 < output_size_2; i2++) { const T* const input_ptr_2 = input_ptr_3 + i2 * input_stride_2; T* const output_ptr_2 = output_ptr_3 + i2 * output_stride_2; for (int i1 = 0; i1 < output_size_1; i1++) { const T* input_ptr = input_ptr_2 + i1 * input_stride_1; T* output_ptr = output_ptr_2 + i1 * output_stride_1; T* const output_ptr_end = output_ptr + output_size_0 * output_stride_0; while (output_ptr != output_ptr_end) { *output_ptr = *input_ptr; input_ptr += input_stride_0; output_ptr += output_stride_0; } } } } } void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order, AxesOrder output_axes_order, const Shape& output_shape, const uint8* input_data, uint8* output_data) { ShuffleArrayTemplate(input_shape, input_axes_order, output_axes_order, output_shape, input_data, output_data); } void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order, AxesOrder output_axes_order, const Shape& output_shape, const float* input_data, float* output_data) { ShuffleArrayTemplate(input_shape, input_axes_order, output_axes_order, output_shape, input_data, output_data); } int AxesCount(AxesOrder axes_order) { switch (axes_order) { case AxesOrder::kOneAxis: return 1; case AxesOrder::kRC: return 2; case AxesOrder::kCR: return 2; case AxesOrder::kHWIO: return 4; case AxesOrder::kOHWI: return 4; case AxesOrder::kHWIM: return 4; case AxesOrder::k1HWO: return 4; case AxesOrder::kNHWC: return 4; case AxesOrder::kHWOI: return 4; default: LOG(FATAL) << "Bad AxesOrder"; return 0; } } bool IsDiscardableArray(const Model& model, const string& array_name) { if (IsInputArray(model, array_name) || IsOutputArray(model, array_name)) { return false; } for (const auto& rnn_state : model.flags.rnn_states()) { if (!rnn_state.discardable()) { if (array_name == rnn_state.state_array()) { return false; } if (array_name == rnn_state.back_edge_source_array()) { return false; } } } return true; } bool ReshapeIsEquivalentToTranspose(const Model& model, const TensorFlowReshapeOperator* op, 
bool allow_extra_unary_dims) { CHECK(!op->shape.empty()); CHECK(model.HasArray(op->inputs[0])); CHECK(model.HasArray(op->outputs[0])); const auto& input_array = model.GetArray(op->inputs[0]); const auto& output_array = model.GetArray(op->outputs[0]); CHECK(input_array.has_shape()); CHECK(output_array.has_shape()); std::vector in_shape = input_array.shape().dims(); std::vector out_shape = output_array.shape().dims(); // If the reshape changes the number of dimensions so it cannot be interpreted // as a transpose. if (!allow_extra_unary_dims && in_shape.size() != out_shape.size()) { return false; } in_shape.erase(std::remove(in_shape.begin(), in_shape.end(), 1), in_shape.end()); out_shape.erase(std::remove(out_shape.begin(), out_shape.end(), 1), out_shape.end()); return in_shape == out_shape; } void CheckFinalDataTypesSatisfied(const Model& model) { for (const auto& array_entry : model.GetArrayMap()) { const auto& array = *array_entry.second; if (array.data_type == ArrayDataType::kBool) { // Boolean values are never quantized. continue; } // If the final data type is int16, the data type may be float, for example // after dequantization. 
if (array.final_data_type != ArrayDataType::kNone && array.final_data_type != ArrayDataType::kInt16) { CHECK(array.data_type == array.final_data_type) << "Array \"" << array_entry.first << "\" has mis-matching actual and final data types (data_type=" << ArrayDataTypeName(array.data_type) << ", final_data_type=" << ArrayDataTypeName(array.final_data_type) << ")."; } } } ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type) { switch (type) { case FLOAT: return ArrayDataType::kFloat; case QUANTIZED_UINT8: return ArrayDataType::kUint8; case QUANTIZED_INT16: return ArrayDataType::kInt16; case INT32: return ArrayDataType::kInt32; case INT64: return ArrayDataType::kInt64; case BOOL: return ArrayDataType::kBool; default: return ArrayDataType::kNone; } } void FinishBuildingRNNStates(Model* model) { for (const auto& rnn_state : model->flags.rnn_states()) { if (!model->HasArray(rnn_state.back_edge_source_array()) || !model->HasArray(rnn_state.state_array())) { CHECK(model->HasArray(rnn_state.back_edge_source_array())); CHECK(model->HasArray(rnn_state.state_array())); continue; } const auto& src_array = model->GetArray(rnn_state.back_edge_source_array()); auto& dst_array = model->GetArray(rnn_state.state_array()); if (src_array.data_type == ArrayDataType::kNone && dst_array.data_type == ArrayDataType::kNone) { dst_array.data_type = ArrayDataType::kFloat; } } } // Returns the array names that match the ArraysExtraInfo's name and // name_regexp. The regexp match is for a full match. 
std::unordered_set ScanArrayNames( const Model& model, const toco::ArraysExtraInfo_Entry& entry) { std::unordered_set matches; if (model.HasArray(entry.name())) { matches.insert(entry.name()); } if (!entry.name_regexp().empty()) { const auto& arrays = model.GetArrayMap(); const RE2 name_regexp = {entry.name_regexp()}; for (auto it = arrays.begin(); it != arrays.end(); ++it) { if (RE2::FullMatch(it->first, name_regexp)) { matches.insert(it->first); } } } return matches; } void UseArraysExtraInfo(Model* model, bool quantize_output) { for (const auto& entry : model->flags.arrays_extra_info().entries()) { const auto matches = ScanArrayNames(*model, entry); for (const auto& matched_name : matches) { auto& array = model->GetArray(matched_name); if (entry.has_min() || entry.has_max()) { CHECK_EQ(entry.has_min(), entry.has_max()); auto& minmax = array.GetOrCreateMinMax(); minmax.min = entry.min(); minmax.max = entry.max(); } if (entry.has_data_type() && quantize_output) { array.final_data_type = ConvertIODataTypeToArrayDataType(entry.data_type()); } if (entry.has_shape()) { array.clear_shape(); // Make sure to create the shape even if there are no dims, to // correctly record 0-D shapes. 
array.mutable_shape(); for (int dim : entry.shape().dims()) { array.mutable_shape()->mutable_dims()->push_back(dim); } } if (entry.has_constant_float_value()) { CHECK(array.has_shape()); if (array.data_type == ArrayDataType::kFloat) { auto& data = array.GetMutableBuffer().data; data.resize(RequiredBufferSizeForShape(array.shape())); for (float& f : data) { f = entry.constant_float_value(); } } } } } } void UndoWeightsShuffling(Model* model) { for (const auto& op : model->operators) { if (op->type != toco::OperatorType::kFullyConnected) { continue; } const auto& fc_op = static_cast(*op); if (fc_op.weights_format == FullyConnectedWeightsFormat::kDefault) { continue; } const string& weights_name = fc_op.inputs[1]; QCHECK_EQ(CountOpsWithInput(*model, weights_name), 1); auto& weights_array = model->GetArray(weights_name); QCHECK(weights_array.data_type == ArrayDataType::kUint8); auto& weights_data = weights_array.GetMutableBuffer().data; const auto& weights_shape = weights_array.shape(); QCHECK_EQ(weights_shape.dimensions_count(), 2); const int rows = weights_shape.dims(0); const int cols = weights_shape.dims(1); QCHECK_EQ(rows % 4, 0); QCHECK_EQ(cols % 16, 0); CHECK_EQ(rows * cols, weights_data.size()); // Compute the de-shuffled weights std::vector deshuffled_data(weights_data.size()); uint8* shuffled_data_ptr = weights_data.data(); for (int r = 0; r < rows; r += 4) { for (int c = 0; c < cols; c += 16) { for (int i = 0; i < 4; i++) { uint8* deshuffled_data_ptr = deshuffled_data.data() + (r + i) * cols + c; for (int j = 0; j < 16; j++) { uint8 shuffled_val = *shuffled_data_ptr++; // Deshuffling isn't only about deshuffling the storage layout, // it's also about undoing the flipping of the sign bit, which is // performed on the shuffled weights. uint8 deshuffled_val = shuffled_val ^ 0x80; *deshuffled_data_ptr++ = deshuffled_val; } } } } CHECK_EQ(shuffled_data_ptr, weights_data.data() + rows * cols); // Switch this FC op to using the deshuffled weights. 
weights_data = std::move(deshuffled_data); } } void CopyMinMaxAndQuantizationRelatedFields(const Array& src, Array* dst) { if (src.minmax) { dst->GetOrCreateMinMax() = src.GetMinMax(); } if (src.quantization_params) { dst->GetOrCreateQuantizationParams() = src.GetQuantizationParams(); } dst->narrow_range = src.narrow_range; } } // namespace toco