aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/python/lite_test.py
diff options
context:
space:
mode:
authorGravatar Nupur Garg <nupurgarg@google.com>2018-08-09 16:49:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-09 16:57:34 -0700
commit37bfe7a9f290c267ff7a804038fb6c8979975dee (patch)
tree622146198493c9594896c0b5bfb7f6cb0a60378c /tensorflow/contrib/lite/python/lite_test.py
parent01da694c97ac8d51974fdbc32d9a37da42642ed8 (diff)
Fix from_from_keras_model_file function in TocoConverter.
PiperOrigin-RevId: 208133502
Diffstat (limited to 'tensorflow/contrib/lite/python/lite_test.py')
-rw-r--r--tensorflow/contrib/lite/python/lite_test.py62
1 files changed, 49 insertions, 13 deletions
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py
index ca2af5aaed..2f13684228 100644
--- a/tensorflow/contrib/lite/python/lite_test.py
+++ b/tensorflow/contrib/lite/python/lite_test.py
@@ -33,6 +33,7 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops.variables import global_variables_initializer as _global_variables_initializer
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.saved_model import saved_model
@@ -198,6 +199,7 @@ class FromSessionTest(test_util.TensorFlowTestCase):
'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32)
out_tensor = in_tensor + var
sess = session.Session()
+ sess.run(_global_variables_initializer())
# Convert model and ensure model is not None.
converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
@@ -655,9 +657,7 @@ class FromKerasFile(test_util.TensorFlowTestCase):
tflite_model = converter.convert()
self.assertTrue(tflite_model)
- os.remove(keras_file)
-
- # Check values from converted model.
+ # Check tensor details of converted model.
interpreter = Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
@@ -675,6 +675,18 @@ class FromKerasFile(test_util.TensorFlowTestCase):
self.assertTrue(([1, 3, 3] == output_details[0]['shape']).all())
self.assertEqual((0., 0.), output_details[0]['quantization'])
+ # Check inference of converted model.
+ input_data = np.array([[1, 2, 3]], dtype=np.float32)
+ interpreter.set_tensor(input_details[0]['index'], input_data)
+ interpreter.invoke()
+ tflite_result = interpreter.get_tensor(output_details[0]['index'])
+
+ keras_model = keras.models.load_model(keras_file)
+ keras_result = keras_model.predict(input_data)
+
+ np.testing.assert_almost_equal(tflite_result, keras_result, 5)
+ os.remove(keras_file)
+
def testSequentialModelInputArray(self):
"""Test a Sequential tf.keras model testing input arrays argument."""
keras_file = self._getSequentialModel()
@@ -755,17 +767,17 @@ class FromKerasFile(test_util.TensorFlowTestCase):
model.predict(x)
fd, keras_file = tempfile.mkstemp('.h5')
- keras.models.save_model(model, keras_file)
+ try:
+ keras.models.save_model(model, keras_file)
+ finally:
+ os.close(fd)
# Convert to TFLite model.
converter = lite.TocoConverter.from_keras_model_file(keras_file)
tflite_model = converter.convert()
self.assertTrue(tflite_model)
- os.close(fd)
- os.remove(keras_file)
-
- # Check values from converted model.
+ # Check tensor details of converted model.
interpreter = Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
@@ -783,6 +795,18 @@ class FromKerasFile(test_util.TensorFlowTestCase):
self.assertTrue(([1, 3] == output_details[0]['shape']).all())
self.assertEqual((0., 0.), output_details[0]['quantization'])
+ # Check inference of converted model.
+ input_data = np.array([[1, 2, 3]], dtype=np.float32)
+ interpreter.set_tensor(input_details[0]['index'], input_data)
+ interpreter.invoke()
+ tflite_result = interpreter.get_tensor(output_details[0]['index'])
+
+ keras_model = keras.models.load_model(keras_file)
+ keras_result = keras_model.predict(input_data)
+
+ np.testing.assert_almost_equal(tflite_result, keras_result, 5)
+ os.remove(keras_file)
+
def testFunctionalModelMultipleInputs(self):
"""Test a Functional tf.keras model with multiple inputs and outputs."""
a = keras.layers.Input(shape=(3,), name='input_a')
@@ -865,17 +889,17 @@ class FromKerasFile(test_util.TensorFlowTestCase):
model.predict(x)
fd, keras_file = tempfile.mkstemp('.h5')
- keras.models.save_model(model, keras_file)
+ try:
+ keras.models.save_model(model, keras_file)
+ finally:
+ os.close(fd)
# Convert to TFLite model.
converter = lite.TocoConverter.from_keras_model_file(keras_file)
tflite_model = converter.convert()
self.assertTrue(tflite_model)
- os.close(fd)
- os.remove(keras_file)
-
- # Check values from converted model.
+ # Check tensor details of converted model.
interpreter = Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
@@ -893,6 +917,18 @@ class FromKerasFile(test_util.TensorFlowTestCase):
self.assertTrue(([1, 3, 3] == output_details[0]['shape']).all())
self.assertEqual((0., 0.), output_details[0]['quantization'])
+ # Check inference of converted model.
+ input_data = np.array([[1, 2, 3]], dtype=np.float32)
+ interpreter.set_tensor(input_details[0]['index'], input_data)
+ interpreter.invoke()
+ tflite_result = interpreter.get_tensor(output_details[0]['index'])
+
+ keras_model = keras.models.load_model(keras_file)
+ keras_result = keras_model.predict(input_data)
+
+ np.testing.assert_almost_equal(tflite_result, keras_result, 5)
+ os.remove(keras_file)
+
if __name__ == '__main__':
test.main()