aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/python/lite_test.py
diff options
context:
space:
mode:
authorGravatar Suharsh Sivakumar <suharshs@google.com>2018-06-06 16:18:44 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-06 16:21:40 -0700
commitc4a3763539dbdb2ee08cca99074d78ce3b6d54de (patch)
tree8d613e0f5941427e1e3f6e287747950cc3c3a513 /tensorflow/contrib/lite/python/lite_test.py
parent64204dd0addea52368400eea6c67616c312b594d (diff)
quantize_weights flag for tflite_convert.
PiperOrigin-RevId: 199549093
Diffstat (limited to 'tensorflow/contrib/lite/python/lite_test.py')
-rw-r--r--tensorflow/contrib/lite/python/lite_test.py32
1 files changed, 32 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py
index 019a3a5f69..bbb00021f9 100644
--- a/tensorflow/contrib/lite/python/lite_test.py
+++ b/tensorflow/contrib/lite/python/lite_test.py
@@ -25,9 +25,11 @@ from tensorflow.contrib.lite.python import lite
from tensorflow.contrib.lite.python import lite_constants
from tensorflow.contrib.lite.python.interpreter import Interpreter
from tensorflow.python.client import session
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
@@ -291,6 +293,36 @@ class FromSessionTest(test_util.TensorFlowTestCase):
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
self.assertTrue(output_details[0]['quantization'][0] > 0) # scale
+ def testQuantizeWeights(self):
+ np.random.seed(0)
+ # We need the tensor to have more than 1024 elements for quantize_weights
+ # to kick in. Thus, the [33, 33] shape.
+ in_tensor_1 = array_ops.placeholder(
+ shape=[33, 33], dtype=dtypes.float32, name='inputA')
+ in_tensor_2 = constant_op.constant(
+ np.random.uniform(low=-10., high=10., size=(33, 33)),
+ shape=[33, 33],
+ dtype=dtypes.float32,
+ name='inputB')
+ out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output')
+ sess = session.Session()
+
+ # Convert float model.
+ float_converter = lite.TocoConverter.from_session(sess, [in_tensor_1],
+ [out_tensor])
+ float_tflite = float_converter.convert()
+ self.assertTrue(float_tflite)
+
+ # Convert quantized weights model.
+ quantized_weights_converter = lite.TocoConverter.from_session(
+ sess, [in_tensor_1], [out_tensor])
+ quantized_weights_converter.quantize_weights = True
+ quantized_weights_tflite = quantized_weights_converter.convert()
+ self.assertTrue(quantized_weights_tflite)
+
+ # Ensure that the quantized weights tflite model is smaller.
+ self.assertTrue(len(quantized_weights_tflite) < len(float_tflite))
+
class FromFrozenGraphFile(test_util.TensorFlowTestCase):