path: root/tensorflow/contrib/lite/tools
diff options
authorGravatar Suharsh Sivakumar <suharshs@google.com>2018-08-23 11:53:45 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-23 12:01:33 -0700
commit288b8a4368fe1f35f71911bf2d9055a5170ed890 (patch)
treecdd81c8791ce5635eea34df36521e84e1ddfc95e /tensorflow/contrib/lite/tools
parent15113cd567f630cd8806deeb82e608357ebed8c3 (diff)
TFLite quantize_weights tool.
PiperOrigin-RevId: 209974391
Diffstat (limited to 'tensorflow/contrib/lite/tools')
4 files changed, 459 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/tools/optimize/BUILD b/tensorflow/contrib/lite/tools/optimize/BUILD
new file mode 100644
index 0000000000..01fbce0ac7
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/optimize/BUILD
@@ -0,0 +1,11 @@
+# TODO(suharshs): Write quantize_weights tests that use small exportable files.
+# Then we can remove this file.
+ default_visibility = ["//visibility:public"],
+licenses(["notice"]) # Apache 2.0
+load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts")
diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc
new file mode 100644
index 0000000000..0758514e39
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc
@@ -0,0 +1,280 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
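+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/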
+#include "tensorflow/contrib/lite/tools/optimize/quantize_weights.h"
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include "flatbuffers/flexbuffers.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+#include "tensorflow/core/platform/logging.h"
+namespace tflite {
+namespace optimize {
+namespace {
+// Minimum element count a weight tensor needs before this tool quantizes it;
+// tensors with fewer elements are skipped by the quantizable-tensor check.
+// TODO(suharshs): Make this configurable.
+const int kWeightsMinSize = 1024;
+// Computes asymmetric quantization parameters (scale and zero_point) that map
+// the float range [min, max] onto the integer range [quant_min, quant_max].
+//
+// The float range is first widened to contain 0.0f so that floating point 0
+// falls exactly on a quantized value. The results are stored as the single
+// entries of quantization_params->scale and quantization_params->zero_point.
+//
+// Although this code originates from FakeQuantization in quantized training,
+// we may deviate from that implementation as we please since we do not fine
+// tune the weights with quantized training.
+//
+// Args:
+//   min, max: smallest and largest float value that must be representable.
+//   quant_min, quant_max: bounds of the integer range; they must differ.
+//   quantization_params: output object; must not be null.
+void GetQuantizationParams(const float min, const float max,
+                           const int quant_min, const int quant_max,
+                           QuantizationParametersT* quantization_params) {
+  const float quant_min_float = static_cast<float>(quant_min);
+  const float quant_max_float = static_cast<float>(quant_max);
+  // Adjust the boundaries to guarantee 0 is included.
+  const float range_min = std::min(min, 0.0f);
+  const float range_max = std::max(max, 0.0f);
+  const float scale =
+      (range_max - range_min) / (quant_max_float - quant_min_float);
+  // scale is 0 when the widened range is exactly [0, 0]. Dividing by it would
+  // produce NaN, so the zero point simply stays at quant_min in that case.
+  float zero_point_from_min = quant_min_float;
+  if (scale != 0.0f) {
+    zero_point_from_min = quant_min_float - range_min / scale;
+  }
+  // Keep the zero point inside the representable integer range.
+  int64_t zero_point;
+  if (zero_point_from_min < quant_min_float) {
+    zero_point = static_cast<int64_t>(quant_min);
+  } else if (zero_point_from_min > quant_max_float) {
+    zero_point = static_cast<int64_t>(quant_max);
+  } else {
+    zero_point = static_cast<int64_t>(std::round(zero_point_from_min));
+  }
+  quantization_params->scale = {scale};
+  quantization_params->zero_point = {zero_point};
+}
+// Computes the element count of tensor as the product of all entries in its
+// shape. A tensor with an empty shape is reported through LOG(FATAL).
+uint64 NumElements(const TensorT* tensor) {
+  const auto& dims = tensor->shape;
+  if (dims.empty()) {
+    LOG(FATAL) << "Tensor has no shape information.";
+  }
+  uint64 total = 1;
+  for (auto it = dims.begin(); it != dims.end(); ++it) {
+    // Each dimension is converted to uint64 before it is multiplied in.
+    const uint64 extent = *it;
+    total *= extent;
+  }
+  return total;
+}
+// Counts how many operator input slots in subgraph refer to tensor_idx. Every
+// matching slot is counted, so an operator listing the tensor twice adds two.
+// Null operator entries are ignored. The model argument is not consulted.
+uint64 CountTensorConsumers(const ModelT* model, const SubGraphT* subgraph,
+                            int32_t tensor_idx) {
+  uint64 num_uses = 0;
+  for (const auto& op_ptr : subgraph->operators) {
+    const OperatorT* current = op_ptr.get();
+    if (current != nullptr) {
+      for (const auto input_idx : current->inputs) {
+        if (input_idx == tensor_idx) {
+          ++num_uses;
+        }
+      }
+    }
+  }
+  return num_uses;
+}
+// Returns true if the Operator's weight tensor should be quantized.
+//
+// Only CONV_2D, DEPTHWISE_CONV_2D, FULLY_CONNECTED and SVDF (weights at input
+// 1) and LSTM (input 2) are considered. The weight is rejected when the
+// operator has no such input, the input does not refer to a tensor, the
+// tensor is consumed by more than one operator input, is not FLOAT32, has
+// fewer than kWeightsMinSize elements, or its buffer does not hold exactly one
+// float per element.
+//
+// On a true return, *tensor points at the weight tensor, *tensor_idx is its
+// index in the first subgraph's tensors and *op_input_index is its position
+// in op->inputs. On a false return the outputs may be partially written and
+// must not be used.
+bool GetQuantizableTensorFromOperator(const ModelT* model, const OperatorT* op,
+                                      TensorT** tensor, int32_t* tensor_idx,
+                                      int32_t* op_input_index) {
+  SubGraphT* subgraph = model->subgraphs.at(0).get();
+  const BuiltinOperator op_code =
+      model->operator_codes[op->opcode_index]->builtin_code;
+  if (op_code == BuiltinOperator_CONV_2D ||
+      op_code == BuiltinOperator_DEPTHWISE_CONV_2D ||
+      op_code == BuiltinOperator_FULLY_CONNECTED ||
+      op_code == BuiltinOperator_SVDF) {
+    *op_input_index = 1;
+  } else if (op_code == BuiltinOperator_LSTM) {
+    // TODO(suharshs): Add RNN, and sequential/bidi versions.
+    *op_input_index = 2;
+  } else {
+    return false;
+  }
+  // Guard the lookups below: an operator may carry fewer inputs than expected
+  // and an input slot may hold a value that is not a valid tensor index.
+  if (static_cast<std::size_t>(*op_input_index) >= op->inputs.size()) {
+    LOG(INFO) << "Skipping quantization of operator without a weight input.";
+    return false;
+  }
+  *tensor_idx = op->inputs[*op_input_index];
+  if (*tensor_idx < 0 ||
+      static_cast<std::size_t>(*tensor_idx) >= subgraph->tensors.size() ||
+      subgraph->tensors[*tensor_idx] == nullptr) {
+    LOG(INFO) << "Skipping quantization of operator whose weight input does "
+                 "not refer to a tensor.";
+    return false;
+  }
+  // TODO(suharshs): Support shared weights, i.e. If two tensors share the
+  // same weight array, things may break. (i.e. SSD object detection)
+  if (CountTensorConsumers(model, subgraph, *tensor_idx) != 1) {
+    LOG(INFO) << "Skipping quantization of tensor that is shared between "
+                 "multiple operations.";
+    return false;
+  }
+  *tensor = subgraph->tensors[*tensor_idx].get();
+  if ((*tensor)->type != TensorType_FLOAT32) {
+    LOG(INFO) << "Skipping quantization of tensor that is not type float.";
+    return false;
+  }
+  const uint64 num_elements = NumElements(*tensor);
+  if (num_elements < static_cast<uint64>(kWeightsMinSize)) {
+    LOG(INFO) << "Skipping quantization of tensor because it has fewer than "
+              << kWeightsMinSize << " elements (" << num_elements << ").";
+    return false;
+  }
+  // Quantization reads the floats straight out of the tensor's buffer, so the
+  // buffer has to exist and contain one float per element.
+  const std::size_t buffer_idx = (*tensor)->buffer;
+  if (buffer_idx >= model->buffers.size() ||
+      model->buffers[buffer_idx] == nullptr ||
+      model->buffers[buffer_idx]->data.size() !=
+          num_elements * sizeof(float)) {
+    LOG(INFO) << "Skipping quantization of tensor without constant float "
+                 "data.";
+    return false;
+  }
+  return true;
+}
+// Quantizes tensor using asymmetric quantization with the min and max elements
+// of the tensor. This is needed to pass to Dequantize operations.
+//
+// On success the tensor's buffer in model holds one uint8 value per element,
+// tensor->quantization holds the scale and zero point that were used, and
+// tensor->type is TensorType_UINT8.
+//
+// Returns kTfLiteError, without modifying the model, when the tensor has no
+// elements or its buffer does not contain exactly one float per element;
+// returns kTfLiteOk otherwise.
+TfLiteStatus AsymmetricQuantizeTensor(ModelT* model, TensorT* tensor) {
+  const uint64 num_elements = NumElements(tensor);
+  const std::size_t buffer_idx = tensor->buffer;
+  if (buffer_idx >= model->buffers.size() ||
+      model->buffers[buffer_idx] == nullptr) {
+    LOG(ERROR) << "Tensor refers to a buffer that does not exist.";
+    return kTfLiteError;
+  }
+  BufferT* buffer = model->buffers[buffer_idx].get();
+  // Reading num_elements floats below is only valid if the buffer really
+  // contains them; min/max of an empty range would also be undefined.
+  if (num_elements == 0 ||
+      buffer->data.size() != num_elements * sizeof(float)) {
+    LOG(ERROR) << "Tensor buffer does not hold one float per element.";
+    return kTfLiteError;
+  }
+  const float* float_data =
+      reinterpret_cast<const float*>(buffer->data.data());
+  LOG(INFO) << "Quantizing tensor with " << num_elements << " elements.";
+
+  // Compute the quantization params from a single pass over the data.
+  const auto min_max =
+      std::minmax_element(float_data, float_data + num_elements);
+  // The tensor may not carry a quantization object yet.
+  if (tensor->quantization == nullptr) {
+    tensor->quantization = std::make_unique<QuantizationParametersT>();
+  }
+  GetQuantizationParams(*min_max.first, *min_max.second, 0, 255,
+                        tensor->quantization.get());
+
+  // Quantize the buffer.
+  const double scale = tensor->quantization->scale[0];
+  const double zero_point =
+      static_cast<double>(tensor->quantization->zero_point[0]);
+  const bool has_scale = scale != 0;
+  const double inverse_scale = has_scale ? 1. / scale : 0.;
+  std::vector<uint8_t> quantized_buffer(num_elements);
+  for (std::size_t i = 0; i < num_elements; ++i) {
+    // With a zero scale every element maps onto the zero point.
+    const double scaled_val =
+        has_scale ? zero_point + inverse_scale * float_data[i] : zero_point;
+    // Clamp before narrowing: converting a double outside [0, 255] to uint8_t
+    // is undefined behavior. NaN falls through to 0.
+    const double clamped =
+        std::min(255.0, std::max(0.0, std::round(scaled_val)));
+    quantized_buffer[i] = static_cast<uint8_t>(clamped);
+  }
+  // float_data is not used past this point, so the old storage can be
+  // released by moving the quantized bytes in.
+  buffer->data = std::move(quantized_buffer);
+  // Update the tensor type.
+  tensor->type = TensorType_UINT8;
+  return kTfLiteOk;
+}
+// Looks for an entry in model->operator_codes whose builtin_code is
+// Dequantize and returns its index. When there is none, a new entry is
+// appended and the index of that new entry is returned.
+int32_t GetOrInsertDequantizeOpCodeIndex(ModelT* model) {
+  auto& op_codes = model->operator_codes;
+  int position = 0;
+  for (const auto& code : op_codes) {
+    if (code->builtin_code == BuiltinOperator_DEQUANTIZE) {
+      return position;
+    }
+    ++position;
+  }
+  // Not present yet: the new entry lands at the current end of the list.
+  // TODO(suharshs): How should the version be set in this op_code?
+  const int new_index = op_codes.size();
+  op_codes.push_back(std::make_unique<OperatorCodeT>());
+  op_codes.back()->builtin_code = BuiltinOperator_DEQUANTIZE;
+  return new_index;
+}
+// Creates a Dequantize OperatorT object and stores it in *op.
+//
+// The operator reads the tensor at index input and writes the tensor at index
+// output. Its opcode_index comes from GetOrInsertDequantizeOpCodeIndex, so a
+// Dequantize operator code is added to model if it has none yet. Any object
+// previously held by *op is replaced.
+void MakeDequantizeOperator(ModelT* model, std::unique_ptr<OperatorT>* op,
+                            int32_t input, int32_t output) {
+  // Own the object from the moment it is created instead of holding a raw
+  // pointer until the final hand-over.
+  auto dequantize = std::make_unique<OperatorT>();
+  dequantize->opcode_index = GetOrInsertDequantizeOpCodeIndex(model);
+  dequantize->inputs = {input};
+  dequantize->outputs = {output};
+  *op = std::move(dequantize);
+}
+// Creates a new TensorT object with the given name and shape and stores it in
+// *tensor, replacing any object previously held there. All other fields keep
+// the values TensorT gives them on construction.
+void MakeTensor(const string& name, const std::vector<int32_t>& shape,
+                std::unique_ptr<TensorT>* tensor) {
+  // Own the object from the moment it is created instead of holding a raw
+  // pointer until the final hand-over.
+  auto created = std::make_unique<TensorT>();
+  created->name = name;
+  created->shape = shape;
+  *tensor = std::move(created);
+}
+} // namespace
+// Quantizes the weights of input_model and packs the resulting model into
+// builder.
+//
+// For every operator with a quantizable weight tensor, the weight is
+// converted to uint8 and a Dequantize operator is inserted directly in front
+// of the operator, which then reads the Dequantize output instead of the
+// original weight. All other operators are carried over unchanged.
+//
+// Returns kTfLiteError if builder or input_model is null, if the model does
+// not have exactly one subgraph, or if quantizing a tensor fails; returns
+// kTfLiteOk otherwise.
+TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
+                             const Model* input_model) {
+  if (builder == nullptr || input_model == nullptr) {
+    LOG(ERROR) << "Quantize weights tool requires a builder and a model.";
+    return kTfLiteError;
+  }
+  std::unique_ptr<ModelT> model(input_model->UnPack());
+  // TODO(suharshs): When models support multiple subgraphs, add support.
+  if (model->subgraphs.size() != 1) {
+    LOG(ERROR) << "Quantize weights tool only supports tflite models with one "
+                  "subgraph.";
+    return kTfLiteError;
+  }
+  SubGraphT* subgraph = model->subgraphs.at(0).get();
+  std::vector<std::unique_ptr<OperatorT>> new_operators;
+  new_operators.reserve(subgraph->operators.size());
+  for (auto& op_ptr : subgraph->operators) {
+    OperatorT* op = op_ptr.get();
+    TensorT* tensor = nullptr;
+    // The index of the weight tensor in subgraph->tensors.
+    int32_t tensor_idx = -1;
+    // The index of tensor_idx in the op->inputs.
+    int32_t op_input_idx = -1;
+    // TODO(suharshs): Support hybrid ops that require symmetric quantization.
+    if (op != nullptr &&
+        GetQuantizableTensorFromOperator(model.get(), op, &tensor, &tensor_idx,
+                                         &op_input_idx)) {
+      // Quantize the tensors.
+      TF_LITE_ENSURE_STATUS(AsymmetricQuantizeTensor(model.get(), tensor));
+      // Create a new tensor to be the output of the dequantize op.
+      std::unique_ptr<TensorT> dequantize_output;
+      MakeTensor(tensor->name + "_dequantize", tensor->shape,
+                 &dequantize_output);
+      const int32_t dequantize_output_idx = subgraph->tensors.size();
+      subgraph->tensors.push_back(std::move(dequantize_output));
+      // Create the Dequantize operation.
+      std::unique_ptr<OperatorT> dequantize_op;
+      MakeDequantizeOperator(model.get(), &dequantize_op, tensor_idx,
+                             dequantize_output_idx);
+      // Update the op_input of tensor_idx to dequantize_output_idx.
+      op->inputs[op_input_idx] = dequantize_output_idx;
+      // The Dequantize operation produces the consumer's input, so it has to
+      // come first: operators are stored in execution order.
+      new_operators.push_back(std::move(dequantize_op));
+    }
+    // The operator itself (updated or untouched) is always carried over.
+    new_operators.push_back(std::move(op_ptr));
+  }
+  // At this point all unique_ptrs in the original operators are invalid, and
+  // we need to replace it with the new_operators vector.
+  subgraph->operators = std::move(new_operators);
+  flatbuffers::Offset<Model> output_model_location =
+      Model::Pack(*builder, model.get());
+  FinishModelBuffer(*builder, output_model_location);
+  return kTfLiteOk;
+}
+} // namespace optimize
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights.h b/tensorflow/contrib/lite/tools/optimize/quantize_weights.h
new file mode 100644
index 0000000000..a408c1662d
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights.h
@@ -0,0 +1,38 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
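+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/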
+#include <memory>
+#include "flatbuffers/flexbuffers.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+namespace tflite {
+namespace optimize {
+// Quantizes the weights of input_model and packs the new model into builder.
+// Returns kTfLiteError if the model does not have exactly one subgraph. Usage:
+//   const uint8_t* buffer = builder->GetBufferPointer();
+//   const tflite::Model* model = GetModel(buffer);
+TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
+ const Model* input_model);
+} // namespace optimize
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc b/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc
new file mode 100644
index 0000000000..0e0676e5ff
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc
@@ -0,0 +1,130 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
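+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/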
+#include "tensorflow/contrib/lite/tools/optimize/quantize_weights.h"
+#include <memory>
+#include "flatbuffers/flexbuffers.h"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+namespace tflite {
+namespace optimize {
+namespace {
+// Test fixture with helpers that inspect the structure of a model produced by
+// QuantizeWeights.
+class QuantizeWeightsTest : public ::testing::Test {
+ protected:
+  // Returns the product of all dimensions in the tensor's shape.
+  int GetElementsNum(const TensorT* tensor) {
+    int tensor_size = 1;
+    for (const int dim : tensor->shape) {
+      tensor_size *= dim;
+    }
+    return tensor_size;
+  }
+
+  // Returns the first operator in subgraph that lists output_tensor_idx among
+  // its outputs, or nullptr if no operator produces that tensor.
+  const OperatorT* GetOpWithOutput(const SubGraphT* subgraph,
+                                   int32_t output_tensor_idx) {
+    for (const auto& op_ptr : subgraph->operators) {
+      const OperatorT* op = op_ptr.get();
+      if (std::find(op->outputs.begin(), op->outputs.end(),
+                    output_tensor_idx) != op->outputs.end()) {
+        return op;
+      }
+    }
+    return nullptr;
+  }
+
+  // Verifies the weight input of every operator type the tool handles:
+  // weights with fewer than 1024 elements must be untouched float tensors,
+  // larger ones must be the float output of a Dequantize operator whose own
+  // input is uint8.
+  // NOTE(review): the tool also skips weights shared by several operators;
+  // such weights would fail the second branch - confirm the test model has
+  // none.
+  void CheckWeights(const Model* model_packed) {
+    std::unique_ptr<ModelT> model(model_packed->UnPack());
+    SubGraphT* subgraph = model->subgraphs.at(0).get();
+    for (const auto& op_ptr : subgraph->operators) {
+      const OperatorT* op = op_ptr.get();
+      const BuiltinOperator op_code =
+          model->operator_codes[op->opcode_index]->builtin_code;
+      // These are the operations that should be quantized. The input indices
+      // mirror the ones used by the tool in quantize_weights.cc.
+      int32_t tensor_idx;
+      if (op_code == BuiltinOperator_CONV_2D ||
+          op_code == BuiltinOperator_DEPTHWISE_CONV_2D ||
+          op_code == BuiltinOperator_FULLY_CONNECTED ||
+          op_code == BuiltinOperator_SVDF) {
+        tensor_idx = op->inputs[1];
+      } else if (op_code == BuiltinOperator_LSTM) {
+        // TODO(suharshs): Add tests for LSTMs.
+        tensor_idx = op->inputs[2];
+      } else {
+        continue;
+      }
+      const TensorT* tensor = subgraph->tensors[tensor_idx].get();
+      int tensor_size = GetElementsNum(tensor);
+      // If the tensor_size is less than 1024 we expect the tensor to remain
+      // unquantized.
+      if (tensor_size < 1024) {
+        ASSERT_TRUE(tensor->type == TensorType_FLOAT32) << tensor->name;
+        const OperatorT* preceding_op = GetOpWithOutput(subgraph, tensor_idx);
+        // The weight tensor should not come from a dequantize op.
+        ASSERT_TRUE(preceding_op == nullptr);
+      } else {
+        // The input to the op should still be float.
+        ASSERT_TRUE(tensor->type == TensorType_FLOAT32) << tensor->name;
+        const OperatorT* preceding_op = GetOpWithOutput(subgraph, tensor_idx);
+        ASSERT_TRUE(preceding_op != nullptr);
+        // The float input should be the dequantize output.
+        ASSERT_TRUE(
+            model->operator_codes[preceding_op->opcode_index]->builtin_code ==
+            BuiltinOperator_DEQUANTIZE);
+        // Finally, ensure that the input to the dequantize operation is
+        // quantized.
+        ASSERT_TRUE(subgraph->tensors[preceding_op->inputs[0]]->type ==
+                    TensorType_UINT8);
+        // TODO(suharshs): Add more rigorous testing for the numerical values in
+        // the tensors.
+      }
+    }
+  }
+};
+// Quantizes the weights of a MobileNet model loaded from testdata and checks
+// the structure of the resulting model with CheckWeights.
+TEST_F(QuantizeWeightsTest, SimpleTest) {
+  const string model_path =
+      "third_party/tensorflow/contrib/lite/tools/optimize/testdata/"
+      "mobilenet_v1_0.25_128.tflite";
+  std::unique_ptr<FlatBufferModel> input_fb =
+      FlatBufferModel::BuildFromFile(model_path.c_str());
+  // Stop here instead of dereferencing a null model when loading failed.
+  ASSERT_TRUE(input_fb != nullptr) << "Could not load " << model_path;
+  const Model* input_model = input_fb->GetModel();
+  flatbuffers::FlatBufferBuilder builder;
+  // The builder only holds a finished model on success, so reading it back
+  // after a failure would be meaningless.
+  ASSERT_EQ(QuantizeWeights(&builder, input_model), kTfLiteOk);
+  const uint8_t* buffer = builder.GetBufferPointer();
+  const Model* output_model = GetModel(buffer);
+  CheckWeights(output_model);
+}
+// TODO(suharshs): Add tests that run the resulting model.
+} // namespace
+} // namespace optimize
+} // namespace tflite
+int main(int argc, char** argv) {
+ // On Linux, add: FLAGS_logtostderr = true;
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();