about summary refs log tree commit diff homepage
path: root/tensorflow/contrib/lite/python/lite_test.py
diff options
context:
space:
mode:
author    Nupur Garg <nupurgarg@google.com>  2018-05-30 17:54:02 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-05-30 17:56:47 -0700
commit    316549d36f6ab3d250ce9e33b768bbfb1a4d7362 (patch)
tree      cef32a4c8ace3dedac532c14fd39944d5bc4ed2b /tensorflow/contrib/lite/python/lite_test.py
parent    2a484497062677f5cf0205ee3b9c28a64f03fe04 (diff)
Enable TOCO pip command line binding.
PiperOrigin-RevId: 198649827
Diffstat (limited to 'tensorflow/contrib/lite/python/lite_test.py')
-rw-r--r--  tensorflow/contrib/lite/python/lite_test.py | 180
1 file changed, 171 insertions(+), 9 deletions(-)
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py
index 2f3105f3e6..28386ecb1a 100644
--- a/tensorflow/contrib/lite/python/lite_test.py
+++ b/tensorflow/contrib/lite/python/lite_test.py
@@ -29,8 +29,10 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.saved_model import saved_model
+from tensorflow.python.training.training_util import write_graph
class FromSessionTest(test_util.TensorFlowTestCase):
@@ -65,16 +67,22 @@ class FromSessionTest(test_util.TensorFlowTestCase):
self.assertEqual((0., 0.), output_details[0]['quantization'])
def testQuantization(self):
- in_tensor = array_ops.placeholder(
- shape=[1, 16, 16, 3], dtype=dtypes.float32, name='input')
+ in_tensor_1 = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
+ in_tensor_2 = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
out_tensor = array_ops.fake_quant_with_min_max_args(
- in_tensor + in_tensor, min=0., max=1., name='output')
+ in_tensor_1 + in_tensor_2, min=0., max=1., name='output')
sess = session.Session()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter = lite.TocoConverter.from_session(
+ sess, [in_tensor_1, in_tensor_2], [out_tensor])
converter.inference_type = lite_constants.QUANTIZED_UINT8
- converter.quantized_input_stats = [(0., 1.)] # mean, std_dev
+ converter.quantized_input_stats = {
+ 'inputA': (0., 1.),
+ 'inputB': (0., 1.)
+ } # mean, std_dev
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -83,13 +91,19 @@ class FromSessionTest(test_util.TensorFlowTestCase):
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
- self.assertEqual(1, len(input_details))
- self.assertEqual('input', input_details[0]['name'])
+ self.assertEqual(2, len(input_details))
+ self.assertEqual('inputA', input_details[0]['name'])
self.assertEqual(np.uint8, input_details[0]['dtype'])
self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
self.assertEqual((1., 0.),
input_details[0]['quantization']) # scale, zero_point
+ self.assertEqual('inputB', input_details[1]['name'])
+ self.assertEqual(np.uint8, input_details[1]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all())
+ self.assertEqual((1., 0.),
+ input_details[1]['quantization']) # scale, zero_point
+
output_details = interpreter.get_output_details()
self.assertEqual(1, len(output_details))
self.assertEqual('output', output_details[0]['name'])
@@ -97,6 +111,26 @@ class FromSessionTest(test_util.TensorFlowTestCase):
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
self.assertTrue(output_details[0]['quantization'][0] > 0) # scale
+ def testQuantizationInvalid(self):
+ in_tensor_1 = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
+ in_tensor_2 = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
+ out_tensor = array_ops.fake_quant_with_min_max_args(
+ in_tensor_1 + in_tensor_2, min=0., max=1., name='output')
+ sess = session.Session()
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_session(
+ sess, [in_tensor_1, in_tensor_2], [out_tensor])
+ converter.inference_type = lite_constants.QUANTIZED_UINT8
+ converter.quantized_input_stats = {'inputA': (0., 1.)} # mean, std_dev
+ with self.assertRaises(ValueError) as error:
+ converter.convert()
+ self.assertEqual(
+ 'Quantization input stats are not available for input tensors '
+ '\'inputB\'.', str(error.exception))
+
def testBatchSizeInvalid(self):
in_tensor = array_ops.placeholder(
shape=[None, 16, 16, 3], dtype=dtypes.float32)
@@ -152,8 +186,7 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(
- sess, [in_tensor], [out_tensor], freeze_variables=True)
+ converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -188,6 +221,135 @@ class FromSessionTest(test_util.TensorFlowTestCase):
self.assertTrue(graphviz_output)
+class FromFlatbufferFile(test_util.TensorFlowTestCase):
+
+ def testFloat(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ _ = in_tensor + in_tensor
+ sess = session.Session()
+
+ # Write graph to file.
+ graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
+ write_graph(sess.graph_def, '', graph_def_file, False)
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_flatbuffer_file(
+ graph_def_file, ['Placeholder'], ['add'])
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('Placeholder', input_details[0]['name'])
+ self.assertEqual(np.float32, input_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual('add', output_details[0]['name'])
+ self.assertEqual(np.float32, output_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
+ self.assertEqual((0., 0.), output_details[0]['quantization'])
+
+ def testFloatWithShapesArray(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ _ = in_tensor + in_tensor
+ sess = session.Session()
+
+ # Write graph to file.
+ graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
+ write_graph(sess.graph_def, '', graph_def_file, False)
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_flatbuffer_file(
+ graph_def_file, ['Placeholder'], ['add'],
+ input_shapes={'Placeholder': [1, 16, 16, 3]})
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
+
+ def testFreezeGraph(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ var = variable_scope.get_variable(
+ 'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ _ = in_tensor + var
+ sess = session.Session()
+
+ # Write graph to file.
+ graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
+ write_graph(sess.graph_def, '', graph_def_file, False)
+
+ # Ensure the graph with variables cannot be converted.
+ with self.assertRaises(ValueError) as error:
+ lite.TocoConverter.from_flatbuffer_file(graph_def_file, ['Placeholder'],
+ ['add'])
+ self.assertEqual('Please freeze the graph using freeze_graph.py',
+ str(error.exception))
+
+ def testPbtxt(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ _ = in_tensor + in_tensor
+ sess = session.Session()
+
+ # Write graph to file.
+ graph_def_file = os.path.join(self.get_temp_dir(), 'model.pbtxt')
+ write_graph(sess.graph_def, '', graph_def_file, True)
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_flatbuffer_file(
+ graph_def_file, ['Placeholder'], ['add'])
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('Placeholder', input_details[0]['name'])
+ self.assertEqual(np.float32, input_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual('add', output_details[0]['name'])
+ self.assertEqual(np.float32, output_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
+ self.assertEqual((0., 0.), output_details[0]['quantization'])
+
+ def testInvalidFile(self):
+ graph_def_file = os.path.join(self.get_temp_dir(), 'invalid_file')
+ with gfile.Open(graph_def_file, 'wb') as temp_file:
+ temp_file.write('bad data')
+ temp_file.flush()
+
+ # Attempts to convert the invalid model.
+ with self.assertRaises(ValueError) as error:
+ lite.TocoConverter.from_flatbuffer_file(graph_def_file, ['Placeholder'],
+ ['add'])
+ self.assertEqual(
+ 'Unable to parse input file \'{}\'.'.format(graph_def_file),
+ str(error.exception))
+
+
class FromSavedModelTest(test_util.TensorFlowTestCase):
def _createSavedModel(self, shape):