diff options
author | Suharsh Sivakumar <suharshs@google.com> | 2018-06-06 16:18:44 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-06 16:21:40 -0700 |
commit | c4a3763539dbdb2ee08cca99074d78ce3b6d54de (patch) | |
tree | 8d613e0f5941427e1e3f6e287747950cc3c3a513 /tensorflow/contrib/lite/python/lite_test.py | |
parent | 64204dd0addea52368400eea6c67616c312b594d (diff) |
quantize_weights flag for tflite_convert.
PiperOrigin-RevId: 199549093
Diffstat (limited to 'tensorflow/contrib/lite/python/lite_test.py')
-rw-r--r-- | tensorflow/contrib/lite/python/lite_test.py | 32 |
1 files changed, 32 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py index 019a3a5f69..bbb00021f9 100644 --- a/tensorflow/contrib/lite/python/lite_test.py +++ b/tensorflow/contrib/lite/python/lite_test.py @@ -25,9 +25,11 @@ from tensorflow.contrib.lite.python import lite from tensorflow.contrib.lite.python import lite_constants from tensorflow.contrib.lite.python.interpreter import Interpreter from tensorflow.python.client import session +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import gfile from tensorflow.python.platform import test @@ -291,6 +293,36 @@ class FromSessionTest(test_util.TensorFlowTestCase): self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) self.assertTrue(output_details[0]['quantization'][0] > 0) # scale + def testQuantizeWeights(self): + np.random.seed(0) + # We need the tensor to have more than 1024 elements for quantize_weights + # to kick in. Thus, the [33, 33] shape. + in_tensor_1 = array_ops.placeholder( + shape=[33, 33], dtype=dtypes.float32, name='inputA') + in_tensor_2 = constant_op.constant( + np.random.uniform(low=-10., high=10., size=(33, 33)), + shape=[33, 33], + dtype=dtypes.float32, + name='inputB') + out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output') + sess = session.Session() + + # Convert float model. + float_converter = lite.TocoConverter.from_session(sess, [in_tensor_1], + [out_tensor]) + float_tflite = float_converter.convert() + self.assertTrue(float_tflite) + + # Convert quantized weights model. + quantized_weights_converter = lite.TocoConverter.from_session( + sess, [in_tensor_1], [out_tensor]) + quantized_weights_converter.quantize_weights = True + quantized_weights_tflite = quantized_weights_converter.convert() + self.assertTrue(quantized_weights_tflite) + + # Ensure that the quantized weights tflite model is smaller. + self.assertTrue(len(quantized_weights_tflite) < len(float_tflite)) + class FromFrozenGraphFile(test_util.TensorFlowTestCase): |