aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Suharsh Sivakumar <suharshs@google.com>2018-06-01 18:30:32 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-01 18:33:35 -0700
commitdbdd276a05c417963b3f06f71e801540bde9ab7c (patch)
tree7a3d8c875fb393026266603b293eb59869ea7268
parentd81328115bd10de70570c46dbfc683cd0238d779 (diff)
Quantize weights transformation for toco.
Finds float weight tensors, quantizes them to 8 bits, and adds Dequantize operations after them. PiperOrigin-RevId: 198955123
-rw-r--r--tensorflow/contrib/lite/toco/BUILD1
-rw-r--r--tensorflow/contrib/lite/toco/args.h1
-rw-r--r--tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md4
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h1
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc108
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD20
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/tests/quantize_weights_test.cc167
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc4
-rw-r--r--tensorflow/contrib/lite/toco/toco_cmdline_flags.cc11
-rw-r--r--tensorflow/contrib/lite/toco/toco_flags.proto7
-rw-r--r--tensorflow/contrib/lite/toco/toco_tooling.cc3
11 files changed, 319 insertions, 8 deletions
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index b8acc9a8e0..7ea4f32ef6 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -245,6 +245,7 @@ cc_library(
"graph_transformations/quantization_util.cc",
"graph_transformations/quantization_util.h",
"graph_transformations/quantize.cc",
+ "graph_transformations/quantize_weights.cc",
"graph_transformations/read_fake_quant_min_max.cc",
"graph_transformations/remove_final_dequantize_op.cc",
"graph_transformations/remove_tensorflow_assert.cc",
diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h
index 77bc54f191..9f5ca66d05 100644
--- a/tensorflow/contrib/lite/toco/args.h
+++ b/tensorflow/contrib/lite/toco/args.h
@@ -234,6 +234,7 @@ struct ParsedTocoFlags {
Arg<bool> drop_fake_quant = Arg<bool>(false);
Arg<bool> reorder_across_fake_quant = Arg<bool>(false);
Arg<bool> allow_custom_ops = Arg<bool>(false);
+ Arg<bool> quantize_weights = Arg<bool>(false);
// Deprecated flags
Arg<string> input_type;
Arg<string> input_types;
diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md
index 9e99287f82..a8381169b8 100644
--- a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md
+++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md
@@ -203,6 +203,10 @@ have.
graph transformations on them, at the cost of no longer faithfully matching
inference and training arithmetic.
+* `--quantize_weights`. Type: boolean. Default: false. Store weights as
+ quantized weights followed by dequantize operations. Computation is still
+ done in float; this reduces model size (at the cost of accuracy and latency).
+
## Logging flags
The following are standard Google logging flags:
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
index 8da242aa9c..1bc7557d46 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -139,6 +139,7 @@ DECLARE_GRAPH_TRANSFORMATION(PropagateFakeQuantNumBits);
DECLARE_GRAPH_TRANSFORMATION(PropagateFixedSizes)
DECLARE_GRAPH_TRANSFORMATION(HardcodeMinMax)
DECLARE_GRAPH_TRANSFORMATION(Quantize)
+DECLARE_GRAPH_TRANSFORMATION(QuantizeWeights)
DECLARE_GRAPH_TRANSFORMATION(RemoveFinalDequantizeOp)
DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowAssert)
DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowIdentity)
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc
new file mode 100644
index 0000000000..88ea0945e7
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc
@@ -0,0 +1,108 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <iterator>
+#include <string>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+
+namespace toco {
+
+namespace {
+
+// Weights arrays holding fewer elements than this threshold are left
+// unquantized by this transformation.
+// TODO(suharshs): Make this minimum size configurable.
+constexpr int kWeightsMinSize = 1024;
+
+// Computes uint8 quantization parameters from the smallest and largest values
+// stored in the float buffer of `array`, and writes them to `params`.
+//
+// `array` must hold a non-empty float buffer: the iterators returned by
+// std::minmax_element are dereferenced unconditionally.
+void GetQuantizationParamsFromArray(const Array& array,
+                                    QuantizationParams* params) {
+  const auto& values = array.GetBuffer<ArrayDataType::kFloat>().data;
+  const auto extremes = std::minmax_element(values.cbegin(), values.cend());
+  MinMax value_range;
+  value_range.min = *extremes.first;
+  value_range.max = *extremes.second;
+  GetQuantizationParams(ArrayDataType::kUint8, value_range, params);
+}
+
+} // namespace
+
+// Quantizes the constant float weights of the operator at `op_index` to uint8
+// and inserts a Dequantize operator that turns them back into floats for the
+// operators consuming them.
+//
+// Only Conv, DepthwiseConv, FullyConnected and LstmCell operators are
+// considered, and only when their weights input is a constant float array
+// with at least kWeightsMinSize elements.
+//
+// Returns true if the graph was modified, false otherwise.
+bool QuantizeWeights::Run(Model* model, std::size_t op_index) {
+  Operator* op = model->operators[op_index].get();
+
+  // Get the weights tensor, if the current operator has one.
+  std::size_t weights_index;
+  switch (op->type) {
+    case OperatorType::kConv:
+    case OperatorType::kDepthwiseConv:
+    case OperatorType::kFullyConnected:
+      weights_index = 1;
+      break;
+    case OperatorType::kLstmCell:
+      weights_index = LstmCellOperator::WEIGHTS_INPUT;
+      break;
+    default:
+      return false;
+  }
+
+  // Never index past the end of the inputs of a malformed operator.
+  if (weights_index >= op->inputs.size()) {
+    return false;
+  }
+
+  // Return early if the array isn't a constant param, this can happen in early
+  // transformation passes until transpose operations following the weight array
+  // are resolved.
+  // The name is copied on purpose: the operator inputs are rewritten below.
+  const string weights = op->inputs[weights_index];
+  if (!IsConstantParameterArray(*model, weights)) {
+    return false;
+  }
+
+  // Return early if the weight tensor is not type float.
+  Array& weights_array = model->GetArray(weights);
+  if (weights_array.data_type != ArrayDataType::kFloat) {
+    return false;
+  }
+
+  // Return early if the tensor is too small. Small tensors don't take up too
+  // much space and can result in bad quantization results.
+  const std::size_t num_weights =
+      weights_array.GetBuffer<ArrayDataType::kFloat>().data.size();
+  if (num_weights < static_cast<std::size_t>(kWeightsMinSize)) {
+    return false;
+  }
+
+  // Quantize the weight tensor to type kUint8. The array named `weights` is
+  // converted in place.
+  QuantizationParams params;
+  GetQuantizationParamsFromArray(weights_array, &params);
+  QuantizeArray(this, model, weights, ArrayDataType::kUint8, params);
+
+  // Create a new intermediate tensor to connect the Dequantize op to the
+  // consumers of the weights.
+  const string dequantized_output =
+      AvailableArrayName(*model, weights + "_dequantized");
+  model->GetOrCreateArray(dequantized_output).data_type = ArrayDataType::kFloat;
+
+  // Rewire every operator that reads the weights, not just `op`. The array is
+  // now uint8, so a consumer left pointing at it would silently read quantized
+  // data, and it could never be fixed by a later pass because the float check
+  // above would reject it. Remember the earliest consumer so the Dequantize op
+  // can be placed in front of all of them.
+  std::size_t insert_index = op_index;
+  for (std::size_t i = 0; i < model->operators.size(); ++i) {
+    for (string& input : model->operators[i]->inputs) {
+      if (input == weights) {
+        input = dequantized_output;
+        if (i < insert_index) {
+          insert_index = i;
+        }
+      }
+    }
+  }
+
+  // Insert the Dequantize operation ahead of the first consumer. It is wired
+  // up after the loop above so that its own input is not rewritten.
+  auto* dequantize_op = new DequantizeOperator;
+  dequantize_op->inputs = {weights};
+  dequantize_op->outputs = {dequantized_output};
+  model->operators.emplace(model->operators.begin() + insert_index,
+                           dequantize_op);
+
+  return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
index 8dcd4adc90..95e8433be2 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
+++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
@@ -8,8 +8,8 @@ load(
)
tf_cc_test(
- name = "resolve_constant_concatenation_test",
- srcs = ["resolve_constant_concatenation_test.cc"],
+ name = "lstm_utils_test",
+ srcs = ["lstm_utils_test.cc"],
deps = [
"//tensorflow/contrib/lite/toco:graph_transformations",
"//tensorflow/contrib/lite/toco:model",
@@ -19,8 +19,20 @@ tf_cc_test(
)
tf_cc_test(
- name = "lstm_utils_test",
- srcs = ["lstm_utils_test.cc"],
+ name = "quantize_weights_test",
+ srcs = ["quantize_weights_test.cc"],
+ deps = [
+ "//tensorflow/contrib/lite/toco:graph_transformations",
+ "//tensorflow/contrib/lite/toco:model",
+ "//tensorflow/contrib/lite/toco:tooling_util",
+ "@com_google_absl//absl/memory",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+tf_cc_test(
+ name = "resolve_constant_concatenation_test",
+ srcs = ["resolve_constant_concatenation_test.cc"],
deps = [
"//tensorflow/contrib/lite/toco:graph_transformations",
"//tensorflow/contrib/lite/toco:model",
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/quantize_weights_test.cc b/tensorflow/contrib/lite/toco/graph_transformations/tests/quantize_weights_test.cc
new file mode 100644
index 0000000000..c05eb0929f
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/quantize_weights_test.cc
@@ -0,0 +1,167 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <math.h>
+#include <string>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/memory/memory.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+
+namespace toco {
+
+// Test fixture that builds a minimal float model for the QuantizeWeights
+// transformation: a single FullyConnected operator fed by two constant arrays.
+class QuantizeWeightsTest : public ::testing::Test {
+ protected:
+  QuantizeWeightsTest() {}
+
+  // The name of the weights input array.
+  const string kWeightsName = "weights";
+  // The zero_point of the values in the input array.
+  const int kZeroPoint = 128;
+
+  // Prepare a hypothetical TOCO model of a quantizable fully connected float
+  // layer.
+  //
+  // Both input arrays and the output array get a 4-dimensional shape with
+  // `elements_per_dim` elements along each dimension, so each input buffer
+  // holds elements_per_dim^4 floats.
+  void PrepareModel(Model* model, int elements_per_dim) {
+    const std::vector<string> fc_input_names = {"inputs", kWeightsName};
+
+    const int kDim = 4;
+    // Use integer arithmetic for the element count: truncating the result of
+    // a floating-point pow() could come out one short if it were inexact.
+    int buf_size = 1;
+    for (int dim = 0; dim < kDim; ++dim) {
+      buf_size *= elements_per_dim;
+    }
+
+    for (const string& fc_input_name : fc_input_names) {
+      Array& in_array = model->GetOrCreateArray(fc_input_name);
+      in_array.data_type = ArrayDataType::kFloat;
+
+      // Initialize shape for the input array.
+      in_array.mutable_shape()->mutable_dims()->resize(kDim, elements_per_dim);
+
+      // Fill the buffer with values cycling from -128.0 to 127.0, since these
+      // values should be exactly representable by quantization.
+      std::vector<float>& values =
+          in_array.GetMutableBuffer<ArrayDataType::kFloat>().data;
+      values.resize(buf_size);
+      for (int i = 0; i < buf_size; ++i) {
+        values[i] = static_cast<float>(i % 256 - kZeroPoint);
+      }
+    }
+
+    auto* fc_op = new FullyConnectedOperator;
+    fc_op->inputs = fc_input_names;
+    fc_op->outputs = {"fc_op_outputs"};
+    Array& out_array = model->GetOrCreateArray(fc_op->outputs[0]);
+    out_array.data_type = ArrayDataType::kFloat;
+    out_array.mutable_shape()->mutable_dims()->resize(kDim, elements_per_dim);
+    model->operators.push_back(std::unique_ptr<Operator>(fc_op));
+  }
+};
+
+TEST_F(QuantizeWeightsTest, QuantizedFullyConnected) {
+  // Test that weight arrays that are large enough are quantized.
+  Model model;
+  // 6 elements per dim gives us 1296 elements, which is sufficient to be
+  // quantized.
+  PrepareModel(&model, 6);
+
+  // Check the state of the graph before the transformation.
+  const auto& float_array_map = model.GetArrayMap();
+  EXPECT_EQ(float_array_map.size(), 3);
+  // Before the transformation, all arrays should be type float.
+  for (const auto& element : float_array_map) {
+    EXPECT_EQ(element.second->data_type, ArrayDataType::kFloat);
+  }
+  // Take a copy: the transformation replaces the float values.
+  const std::vector<float> float_weight_vals =
+      model.GetArray(kWeightsName).GetBuffer<ArrayDataType::kFloat>().data;
+
+  // Invoke the transformation; it must report that it changed the graph.
+  GraphTransformationsSet graph_transformation_set;
+  graph_transformation_set.Add(new toco::QuantizeWeights);
+  EXPECT_TRUE(
+      (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0));
+
+  // Check the state of the graph after the transformation.
+  const auto& quantized_array_map = model.GetArrayMap();
+  EXPECT_EQ(quantized_array_map.size(), 4);
+  // After the transformation, three arrays should be type float and one array
+  // should be uint8.
+  int num_float = 0;
+  int num_uint8 = 0;
+  for (const auto& element : quantized_array_map) {
+    if (element.second->data_type == ArrayDataType::kFloat) {
+      num_float++;
+    } else if (element.second->data_type == ArrayDataType::kUint8) {
+      num_uint8++;
+    } else {
+      FAIL() << "Unexpected array type.";
+    }
+  }
+  EXPECT_EQ(num_float, 3);
+  EXPECT_EQ(num_uint8, 1);
+
+  // Ensure that the values were quantized correctly. The sizes are asserted
+  // first so the indexed comparison below cannot read out of bounds.
+  const std::vector<uint8>& quantized_weight_vals =
+      model.GetArray(kWeightsName).GetBuffer<ArrayDataType::kUint8>().data;
+  ASSERT_EQ(quantized_weight_vals.size(), float_weight_vals.size());
+  for (std::size_t i = 0; i < quantized_weight_vals.size(); ++i) {
+    EXPECT_EQ(quantized_weight_vals[i], float_weight_vals[i] + kZeroPoint);
+  }
+
+  // Ensure that a Dequantize operator has been inserted before the
+  // FullyConnectedLayer and that it sits between the quantized weights and
+  // the weights input of the FullyConnectedLayer.
+  ASSERT_EQ(model.operators.size(), 2);
+  const Operator& dequantize_op = *model.operators[0];
+  const Operator& fc_op = *model.operators[1];
+  EXPECT_EQ(dequantize_op.type, OperatorType::kDequantize);
+  EXPECT_EQ(fc_op.type, OperatorType::kFullyConnected);
+  ASSERT_EQ(dequantize_op.inputs.size(), 1);
+  EXPECT_EQ(dequantize_op.inputs[0], kWeightsName);
+  ASSERT_EQ(dequantize_op.outputs.size(), 1);
+  ASSERT_EQ(fc_op.inputs.size(), 2);
+  EXPECT_EQ(fc_op.inputs[0], "inputs");
+  EXPECT_EQ(fc_op.inputs[1], dequantize_op.outputs[0]);
+}
+
+TEST_F(QuantizeWeightsTest, NotQuantizedFullyConnected) {
+  // Test that weight arrays that are too small are left untouched.
+  Model model;
+  // 5 elements per dim gives us 625 elements, which is NOT sufficient to be
+  // quantized.
+  PrepareModel(&model, 5);
+
+  // Check the state of the graph before the transformation.
+  const auto& float_array_map = model.GetArrayMap();
+  EXPECT_EQ(float_array_map.size(), 3);
+  // Before the transformation, all arrays should be type float.
+  for (const auto& element : float_array_map) {
+    EXPECT_EQ(element.second->data_type, ArrayDataType::kFloat);
+  }
+  // Snapshot of the weights to compare against after the transformation.
+  const std::vector<float> original_weight_vals =
+      model.GetArray(kWeightsName).GetBuffer<ArrayDataType::kFloat>().data;
+
+  // Invoke the transformation; it must report that nothing was changed.
+  GraphTransformationsSet graph_transformation_set;
+  graph_transformation_set.Add(new toco::QuantizeWeights);
+  EXPECT_FALSE(
+      (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0));
+
+  // Check the state of the graph after the transformation.
+  const auto& post_array_map = model.GetArrayMap();
+  EXPECT_EQ(post_array_map.size(), 3);
+  for (const auto& element : post_array_map) {
+    EXPECT_EQ(element.second->data_type, ArrayDataType::kFloat);
+  }
+  // Ensure that the values remain unchanged.
+  const std::vector<float>& post_weight_vals =
+      model.GetArray(kWeightsName).GetBuffer<ArrayDataType::kFloat>().data;
+  ASSERT_EQ(post_weight_vals.size(), original_weight_vals.size());
+  for (std::size_t i = 0; i < post_weight_vals.size(); ++i) {
+    EXPECT_EQ(post_weight_vals[i], original_weight_vals[i]);
+  }
+
+  // Ensure that no operator was inserted and that the FullyConnectedLayer
+  // still reads the original weights array.
+  ASSERT_EQ(model.operators.size(), 1);
+  EXPECT_EQ(model.operators[0]->type, OperatorType::kFullyConnected);
+  ASSERT_EQ(model.operators[0]->inputs.size(), 2);
+  EXPECT_EQ(model.operators[0]->inputs[1], kWeightsName);
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc
index 3a1d175b98..66cfed4ac2 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc
@@ -12,9 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <memory>
#include <string>
-#include <unordered_map>
#include <vector>
#include <gmock/gmock.h>
@@ -126,7 +124,7 @@ class ResolveConstantConcatenationTest : public ::testing::Test {
Array& in_array = model->GetOrCreateArray(concat_input_name);
in_array.data_type = ArrayDataType::kFloat;
- // Initialize shape for the input array.
+ // Initialize shape for the input array.
Shape* in_array_shape = in_array.mutable_shape();
std::vector<int>* in_array_shape_dim = in_array_shape->mutable_dims();
for (int i = 0; i < kDim; i++) {
diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
index 9c6ad673ab..87a1e429b9 100644
--- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
+++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
@@ -158,6 +158,11 @@ bool ParseTocoFlagsFromCommandLineFlags(
parsed_flags.split_tflite_lstm_inputs.default_value(),
"Split the LSTM inputs from 5 tensors to 18 tensors for TFLite. "
"Ignored if the output format is not TFLite."),
+ Flag("quantize_weights", parsed_flags.quantize_weights.bind(),
+ parsed_flags.quantize_weights.default_value(),
+ "Store weights as quantized weights followed by dequantize "
+ "operations. Computation is still done in float, but reduces model "
+ "size (at the cost of accuracy and latency)."),
};
bool asked_for_help =
*argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
@@ -251,6 +256,7 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
FlagRequirement::kNone);
READ_TOCO_FLAG(dedupe_array_min_size_bytes, FlagRequirement::kNone);
READ_TOCO_FLAG(split_tflite_lstm_inputs, FlagRequirement::kNone);
+ READ_TOCO_FLAG(quantize_weights, FlagRequirement::kNone);
// Deprecated flag handling.
if (parsed_toco_flags.input_type.specified()) {
@@ -284,6 +290,11 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
QCHECK(toco::IODataType_Parse(input_types[0], &input_type));
toco_flags->set_inference_input_type(input_type);
}
+ if (parsed_toco_flags.quantize_weights.value()) {
+ QCHECK_NE(toco_flags->inference_type(), IODataType::QUANTIZED_UINT8)
+ << "quantize_weights is not supported with inference_type "
+ "QUANTIZED_UINT8.";
+ }
#undef READ_TOCO_FLAG
#undef PARSE_TOCO_FLAG
diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/contrib/lite/toco/toco_flags.proto
index 15f755c104..4fe57879fb 100644
--- a/tensorflow/contrib/lite/toco/toco_flags.proto
+++ b/tensorflow/contrib/lite/toco/toco_flags.proto
@@ -37,7 +37,7 @@ enum FileFormat {
// of as properties of models, instead describing how models are to be
// processed in the context of the present tooling job.
//
-// Next ID to use: 20.
+// Next ID to use: 21.
message TocoFlags {
// Input file format
optional FileFormat input_format = 1;
@@ -169,4 +169,9 @@ message TocoFlags {
// Split the LSTM inputs from 5 tensors to 18 tensors for TFLite.
// Ignored if the output format is not TFLite.
optional bool split_tflite_lstm_inputs = 19 [default = true];
+
+ // Store weights as quantized weights followed by dequantize operations.
+ // Computation is still done in float, but reduces model size (at the cost of
+ // accuracy and latency).
+ optional bool quantize_weights = 20 [default = false];
}
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index a648883d1f..1fe76f8163 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -269,6 +269,9 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
transformations.Add(new toco::MergeLstmCellInputs);
}
}
+ if (toco_flags.quantize_weights()) {
+ transformations.Add(new QuantizeWeights);
+ }
transformations.Add(new ResolveConstantConcatenation);
RunGraphTransformations(model, "general graph transformations",
transformations);