aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-18 12:10:51 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-18 12:14:11 -0700
commitf0aabfa0139cb83c857e6142286d025515fbf9a1 (patch)
treeb9fb13fda3ec820e545be902e4042c2c5c829793 /tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
parent03d18ae232c3cff4c56d1efec7bf29f9b16c4f68 (diff)
Make toco generate uint8 weights that are safe for fast int8 kernels.
PiperOrigin-RevId: 193395910
Diffstat (limited to 'tensorflow/contrib/lite/toco/toco_cmdline_flags.cc')
-rw-r--r--tensorflow/contrib/lite/toco/toco_cmdline_flags.cc9
1 files changed, 9 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
index 74f98c8452..1611c4d0c0 100644
--- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
+++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
@@ -141,6 +141,13 @@ bool ParseTocoFlagsFromCommandLineFlags(
parsed_flags.propagate_fake_quant_num_bits.default_value(),
"If true, use FakeQuant* operator num_bits attributes to adjust "
"array data_types."),
+ Flag("allow_nudging_weights_to_use_fast_gemm_kernel",
+ parsed_flags.allow_nudging_weights_to_use_fast_gemm_kernel.bind(),
+ parsed_flags.allow_nudging_weights_to_use_fast_gemm_kernel
+ .default_value(),
+ "Some fast uint8 GEMM kernels require uint8 weights to avoid the "
+ "value 0. This flag allows nudging them to 1 to allow proceeding, "
+ "with moderate inaccuracy."),
};
bool asked_for_help =
*argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
@@ -230,6 +237,8 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
READ_TOCO_FLAG(drop_control_dependency, FlagRequirement::kNone);
READ_TOCO_FLAG(debug_disable_recurrent_cell_fusion, FlagRequirement::kNone);
READ_TOCO_FLAG(propagate_fake_quant_num_bits, FlagRequirement::kNone);
+ READ_TOCO_FLAG(allow_nudging_weights_to_use_fast_gemm_kernel,
+ FlagRequirement::kNone);
// Deprecated flag handling.
if (parsed_toco_flags.input_type.specified()) {