aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Nupur Garg <nupurgarg@google.com>2018-07-27 12:15:09 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-27 12:18:17 -0700
commitab9f0a628f61fcb19b6b09cb51bf05ff8c702a80 (patch)
tree10bf419251e8d089d9690e62abe8d5d0f03356e2
parent78d225ef8a6a32423febc67803fabdff05b378c0 (diff)
Update functionality of --allow_nudging_weights_to_use_fast_gemm_kernel.
PiperOrigin-RevId: 206354203
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc11
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h4
-rw-r--r--tensorflow/contrib/lite/toco/toco_tooling.cc7
3 files changed, 19 insertions, 3 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc b/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc
index 75642bbc37..c13fc0de75 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc
@@ -181,7 +181,7 @@ bool EnsureUint8WeightsSafeForFastInt8Kernels::Run(Model* model,
// future without worrying.
static constexpr int kMinDistanceBetweenBadValues = 16;
if (distance < kMinDistanceBetweenBadValues) {
- if (allow_nudging_weights()) {
+ if (allow_nudging_weights() || has_default_ranges_flag()) {
buffer_data[i] = 1;
changed = true;
continue;
@@ -200,6 +200,15 @@ bool EnsureUint8WeightsSafeForFastInt8Kernels::Run(Model* model,
}
if (changed) {
+ if (has_default_ranges_flag()) {
+ std::cerr
+ << "Since the specified values of --default_ranges_min and "
+ "--default_ranges_max result in values incompatible with TFLite's "
+ "fast int8 kernels, "
+ "--allow_nudging_weights_to_use_fast_gemm_kernel "
+ "has been enabled. This may affect the accuracy of the model."
+ << std::endl;
+ }
AddMessageF("Tweaked weights values for %s", LogName(op));
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
index b7634e28c6..8d9a4c4700 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -262,8 +262,12 @@ class EnsureUint8WeightsSafeForFastInt8Kernels : public GraphTransformation {
bool allow_nudging_weights() const { return allow_nudging_weights_; }
void set_allow_nudging_weights(bool val) { allow_nudging_weights_ = val; }
+ bool has_default_ranges_flag() const { return has_default_ranges_flag_; }
+ void set_has_default_ranges_flag(bool val) { has_default_ranges_flag_ = val; }
+
private:
bool allow_nudging_weights_ = false;
+ bool has_default_ranges_flag_ = false;
};
#undef DECLARE_GRAPH_TRANSFORMATION
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index aa7f6996eb..fcd3cbab07 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -309,8 +309,9 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
// HardcodeMinMax to move changes through the graph as we make changes.
auto propagate_default_min_max =
absl::make_unique<PropagateDefaultMinMax>();
- if (toco_flags.has_default_ranges_min() &&
- toco_flags.has_default_ranges_max()) {
+ bool has_default_ranges_flag = (toco_flags.has_default_ranges_min() &&
+ toco_flags.has_default_ranges_max());
+ if (has_default_ranges_flag) {
propagate_default_min_max->DefineTypeRange(
ArrayDataType::kUint8, toco_flags.default_ranges_min(),
toco_flags.default_ranges_max());
@@ -335,6 +336,8 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
new EnsureUint8WeightsSafeForFastInt8Kernels;
ensure_safe_for_int8_kernels->set_allow_nudging_weights(
toco_flags.allow_nudging_weights_to_use_fast_gemm_kernel());
+ ensure_safe_for_int8_kernels->set_has_default_ranges_flag(
+ has_default_ranges_flag);
RunGraphTransformations(model, "quantization graph transformations",
{
new RemoveTrivialQuantizedActivationFunc,