From f0aabfa0139cb83c857e6142286d025515fbf9a1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 18 Apr 2018 12:10:51 -0700 Subject: Make toco generate uint8 weights that are safe for fast int8 kernels. PiperOrigin-RevId: 193395910 --- tensorflow/contrib/lite/toco/toco_tooling.cc | 5 +++++ 1 file changed, 5 insertions(+) (limited to 'tensorflow/contrib/lite/toco/toco_tooling.cc') diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc index 89cb2f85f8..7252ec2ea4 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.cc +++ b/tensorflow/contrib/lite/toco/toco_tooling.cc @@ -317,12 +317,17 @@ void Transform(const TocoFlags& toco_flags, Model* model) { } CheckIsReadyForQuantization(*model); + auto* ensure_safe_for_int8_kernels = + new EnsureUint8WeightsSafeForFastInt8Kernels; + ensure_safe_for_int8_kernels->set_allow_nudging_weights( + toco_flags.allow_nudging_weights_to_use_fast_gemm_kernel()); RunGraphTransformations(model, "quantization graph transformations", { new RemoveTrivialQuantizedActivationFunc, new RemoveTrivialQuantizedMinMax, new Quantize, new RemoveFinalDequantizeOp, + ensure_safe_for_int8_kernels, }); } else { GraphTransformationsSet dequantization_transformations{new Dequantize}; -- cgit v1.2.3