aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-18 12:10:51 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-18 12:14:11 -0700
commitf0aabfa0139cb83c857e6142286d025515fbf9a1 (patch)
treeb9fb13fda3ec820e545be902e4042c2c5c829793
parent03d18ae232c3cff4c56d1efec7bf29f9b16c4f68 (diff)
Make toco generate uint8 weights that are safe for fast int8 kernels.
PiperOrigin-RevId: 193395910
-rw-r--r--tensorflow/contrib/lite/toco/BUILD1
-rw-r--r--tensorflow/contrib/lite/toco/args.h1
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc209
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h13
-rw-r--r--tensorflow/contrib/lite/toco/toco_cmdline_flags.cc9
-rw-r--r--tensorflow/contrib/lite/toco/toco_flags.proto7
-rw-r--r--tensorflow/contrib/lite/toco/toco_tooling.cc5
7 files changed, 244 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index f696f4b845..3f73ef620e 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -219,6 +219,7 @@ cc_library(
"graph_transformations/drop_fake_quant.cc",
"graph_transformations/drop_im2col_arrays.cc",
"graph_transformations/ensure_bias_vectors.cc",
+ "graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc",
"graph_transformations/experimental_shuffle_fc_weights.cc",
"graph_transformations/fuse_activation_functions.cc",
"graph_transformations/fuse_binary_into_following_affine.cc",
diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h
index c9662d05ce..fe30b88344 100644
--- a/tensorflow/contrib/lite/toco/args.h
+++ b/tensorflow/contrib/lite/toco/args.h
@@ -240,6 +240,7 @@ struct ParsedTocoFlags {
Arg<bool> debug_disable_recurrent_cell_fusion = Arg<bool>(false);
Arg<bool> drop_control_dependency = Arg<bool>(false);
Arg<bool> propagate_fake_quant_num_bits = Arg<bool>(false);
+ Arg<bool> allow_nudging_weights_to_use_fast_gemm_kernel = Arg<bool>(false);
};
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc b/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc
new file mode 100644
index 0000000000..394fa349e2
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc
@@ -0,0 +1,209 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+// === Summary ===
+//
+// TLDR: Some of our 8-bit arithmetic operations require uint8 weight values
+// to avoid the value 0, thus ranging only in [1, 255]. This enables faster
+// runtime arithmetic kernels on ARM NEON. This is not relevant on most
+// other hardware architectures, and will cease to be relevant on ARM NEON
+// in the future. These topics are elaborated below ("Context").
+//
+// Having just one isolated uint8 value equal to 0 is fine. The bad case is when
+// two uint8 values are both zero and are less than 16 bytes apart.
+//
+// By default, toco generates a fatal error when that happens. The user may opt
+// in to more lax behavior by passing
+// --allow_nudging_weights_to_use_fast_gemm_kernel.
+// This causes toco to nudge such bad 0 values into the value 1, thus avoiding
+// the problem in exchange for compromising on accuracy.
+//
+// The present graph transformation implements both the default fatal-erroring
+// behavior, and, when allow_nudging_weights is set, also the lax nudging
+// behavior.
+//
+//
+// === Context ===
+//
+// Since March 2017, we have been using a trick to perform faster
+// 8bit matrix multiplications, to our knowledge first implemented in gemmlowp
+// here:
+// https://github.com/google/gemmlowp/commit/25b2989415b99e797e1ab977837111b2e231f81f
+//
+// This trick is explained in Appendix B of our paper,
+// https://arxiv.org/abs/1712.05877
+//
+// Here is the relevant paragraph:
+//
+// For efficient NEON implementation of the matrix multiplication’s
+// core accumulation, we use the following trick.
+// In the multiply-add operation in (10), we first change the
+// operands’ type from uint8 to int8 (which can be done by
+// subtracting 128 from the quantized values and zero-points).
+// Thus the core multiply-add becomes
+//
+// int32 += int8 * int8. (B.1)
+//
+// As mentioned in section 3, with a minor tweak of the quantized
+// training process, we can ensure that the weights, once
+// quantized as int8 values, never take the value −128. Hence,
+// the product in (B.1) is never −128 ∗ −128, and is therefore
+// always less than 2^14 in absolute value. Hence, (B.1)
+// can accumulate two products on a local int16 accumulator
+// before that needs to be accumulated into the true int32 accumulator.
+// This allows the use of an 8-way SIMD multiplication
+// (SMULL on int8 operands), followed by an 8-way
+// SIMD multiply-add (SMLAL on int8 operands), followed
+// by a pairwise-add-and-accumulate into the int32 accumulators
+// (SADALP).
+//
+// As that paragraph notes, quantized training should be suitably modified to
+// ensure that quantized uint8 weights values only range in [1, 255]. So the
+// problem that we are dealing with is only about the existing 8-bit quantized
+// models that haven't been trained specifically to get 8-bit weights only in
+// [1, 255].
+//
+// This spreadsheet shows the speed benefit of this trick across many existing
+// ARM-architecture CPUs:
+//
+// https://docs.google.com/spreadsheets/d/1-0LjdMvW0XtH1bYknC0bQINoFaxjTuL9eplZZcitykI/edit?usp=sharing
+//
+// Compare Row 18 (fast int8 trick) to Row 20 (regular uint8 kernel).
+//
+// The introduction of the 'dotprod' extension to ARM NEON, specifically the
+// SDOT instruction, renders this eventually moot. See the experimental
+// kernels contributed by ARM here,
+//
+// https://github.com/google/gemmlowp/pull/116
+//
+// However, as of April 2018, there don't seem to be any commercially available
+// CPUs supporting these instructions (yet); we are waiting for
+// Cortex-A{75,55}-r1 to become available; the "-r1" is key here. Even if such
+// CPUs become available soon, it will presumably take years for them to
+// overtake the large volume of existing CPUs not supporting these new
+// instructions, especially in current and future low-end devices. All in all,
+// we can foresee these 'fast int8 kernels' to remain important to have into
+// the 2020s.
+//
+bool EnsureUint8WeightsSafeForFastInt8Kernels::Run(Model* model,
+ std::size_t op_index) {
+ const auto& op = *model->operators[op_index];
+ int weights_index = 0;
+ switch (op.type) {
+ case OperatorType::kConv:
+ weights_index = 1;
+ break;
+ case OperatorType::kLstmCell:
+ weights_index = 2;
+ break;
+ case OperatorType::kFullyConnected: {
+ weights_index = 1;
+ const auto& fc_op = static_cast<const toco::FullyConnectedOperator&>(op);
+ CHECK(!fc_op.experimental_shuffled_weights)
+ << "This graph transformation expects to run before FC weights get "
+ "shuffled.";
+ break;
+ }
+ default:
+ // Other operator types are unaffected by this graph transformation,
+ // because their runtime implementations don't use the fast int8 trick.
+ // In particular that's the case of DepthwiseConv at the moment.
+ // We have to update this logic when that changes, e.g. if in the future
+ // some DepthwiseConv kernel wants to use the trick.
+ //
+ // The reason why that's not so likely, hence why it's fairly safe to
+ // stay conservative in the list of operators that we handle here, is that
+ // the fast int8 kernel trick is only applicable to ops that either are
+ // implemented as a GEMM, or use symmetric ranges for both weights and
+ // activations. The reason why GEMM is special (can use the trick even
+ // without symmetric ranges) is that it is so arithmetic-intense that
+ // it can use techniques reducing its implementation to the symmetric
+ // ranges case, with limited relative overhead (O(N^2) overhead vs
+ // O(N^3) GEMM cost). See https://arxiv.org/pdf/1712.05877, section
+ // 2.3 Efficient handling of zero-points.
+ //
+ // That's why at the moment we only handle operators that use a GEMM
+ // (Conv, fully-connected --- note that LSTM merely wraps a
+ // fully-connected operator).
+ return false;
+ }
+
+ const string& name = op.inputs[weights_index];
+ auto& array = model->GetArray(name);
+ if (!array.buffer) {
+ return false;
+ }
+ if (array.data_type != ArrayDataType::kUint8) {
+ return false;
+ }
+ auto& buffer_data = array.GetMutableBuffer<ArrayDataType::kUint8>().data;
+
+ int count_bad = 0;
+ int index_of_previous_bad_value = 0;
+ bool changed = false;
+
+ for (int i = 0; i < buffer_data.size(); i++) {
+ if (buffer_data[i] == 0) {
+ count_bad++;
+ if (count_bad > 1) {
+ const int distance = i - index_of_previous_bad_value;
+ // Semi-arbitrary threshold. The idea is that trouble only occurs
+ // when two bad values are very close to each other so that they
+ // are jointly used within registers inside some GEMM kernel.
+ // The details of that depend on the kernel. Our current fast ARM64
+ // kernel, for instance, only has an issue when the distance between
+ // consecutive bad values is exactly 8. We do not want to track such
+ // kernel details too closely here, so we pick a threshold that's
+ // a bit larger than that, to give us room to change kernels in the
+ // future without worrying.
+ static constexpr int kMinDistanceBetweenBadValues = 16;
+ if (distance < kMinDistanceBetweenBadValues) {
+ if (allow_nudging_weights()) {
+ buffer_data[i] = 1;
+ changed = true;
+ continue;
+ }
+ LOG(FATAL) << "Bad value for " << name << " at index " << i
+ << ", previous bad value at index "
+ << index_of_previous_bad_value << ", distance=" << distance
+ << ", kMinDistanceBetweenBadValues="
+ << kMinDistanceBetweenBadValues << ". Consider passing "
+ << "--allow_nudging_weights_to_use_fast_gemm_kernel "
+ << "if you don't care about accuracy.";
+ }
+ }
+ index_of_previous_bad_value = i;
+ }
+ }
+
+ if (changed) {
+ AddMessageF("Tweaked weights values for %s", LogName(op));
+ }
+
+ return changed;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
index 8075d0205d..72ffd51db4 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -246,6 +246,19 @@ class ResolveConstantFakeQuant : public GraphTransformation {
bool propagate_fake_quant_num_bits_ = false;
};
+class EnsureUint8WeightsSafeForFastInt8Kernels : public GraphTransformation {
+ public:
+ bool Run(Model* model, std::size_t op_index) override;
+ const char* Name() const override {
+ return "EnsureUint8WeightsSafeForFastInt8Kernels";
+ }
+ bool allow_nudging_weights() const { return allow_nudging_weights_; }
+ void set_allow_nudging_weights(bool val) { allow_nudging_weights_ = val; }
+
+ private:
+ bool allow_nudging_weights_ = false;
+};
+
#undef DECLARE_GRAPH_TRANSFORMATION
} // end namespace toco
diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
index 74f98c8452..1611c4d0c0 100644
--- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
+++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
@@ -141,6 +141,13 @@ bool ParseTocoFlagsFromCommandLineFlags(
parsed_flags.propagate_fake_quant_num_bits.default_value(),
"If true, use FakeQuant* operator num_bits attributes to adjust "
"array data_types."),
+ Flag("allow_nudging_weights_to_use_fast_gemm_kernel",
+ parsed_flags.allow_nudging_weights_to_use_fast_gemm_kernel.bind(),
+ parsed_flags.allow_nudging_weights_to_use_fast_gemm_kernel
+ .default_value(),
+ "Some fast uint8 GEMM kernels require uint8 weights to avoid the "
+ "value 0. This flag allows nudging them to 1 to allow proceeding, "
+ "with moderate inaccuracy."),
};
bool asked_for_help =
*argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
@@ -230,6 +237,8 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
READ_TOCO_FLAG(drop_control_dependency, FlagRequirement::kNone);
READ_TOCO_FLAG(debug_disable_recurrent_cell_fusion, FlagRequirement::kNone);
READ_TOCO_FLAG(propagate_fake_quant_num_bits, FlagRequirement::kNone);
+ READ_TOCO_FLAG(allow_nudging_weights_to_use_fast_gemm_kernel,
+ FlagRequirement::kNone);
// Deprecated flag handling.
if (parsed_toco_flags.input_type.specified()) {
diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/contrib/lite/toco/toco_flags.proto
index 869c512d93..a04017a6bf 100644
--- a/tensorflow/contrib/lite/toco/toco_flags.proto
+++ b/tensorflow/contrib/lite/toco/toco_flags.proto
@@ -37,7 +37,7 @@ enum FileFormat {
// of as properties of models, instead describing how models are to be
// processed in the context of the present tooling job.
//
-// Next ID to use: 17.
+// Next ID to use: 18.
message TocoFlags {
// Input file format
optional FileFormat input_format = 1;
@@ -156,4 +156,9 @@ message TocoFlags {
// Input and output array data types may change because of this propagation
// and users must be sure to query the final data_type values.
optional bool propagate_fake_quant_num_bits = 14;
+
+ // Some fast uint8 GEMM kernels require uint8 weights to avoid the value 0.
+ // This flag allows nudging them to 1 to allow proceeding, with moderate
+ // inaccuracy.
+ optional bool allow_nudging_weights_to_use_fast_gemm_kernel = 17;
}
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index 89cb2f85f8..7252ec2ea4 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -317,12 +317,17 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
}
CheckIsReadyForQuantization(*model);
+ auto* ensure_safe_for_int8_kernels =
+ new EnsureUint8WeightsSafeForFastInt8Kernels;
+ ensure_safe_for_int8_kernels->set_allow_nudging_weights(
+ toco_flags.allow_nudging_weights_to_use_fast_gemm_kernel());
RunGraphTransformations(model, "quantization graph transformations",
{
new RemoveTrivialQuantizedActivationFunc,
new RemoveTrivialQuantizedMinMax,
new Quantize,
new RemoveFinalDequantizeOp,
+ ensure_safe_for_int8_kernels,
});
} else {
GraphTransformationsSet dequantization_transformations{new Dequantize};