diff options
author | Nupur Garg <nupurgarg@google.com> | 2018-08-09 16:49:43 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-09 16:57:34 -0700 |
commit | 37bfe7a9f290c267ff7a804038fb6c8979975dee (patch) | |
tree | 622146198493c9594896c0b5bfb7f6cb0a60378c /tensorflow/contrib/lite/python/lite_test.py | |
parent | 01da694c97ac8d51974fdbc32d9a37da42642ed8 (diff) |
Fix from_from_keras_model_file function in TocoConverter.
PiperOrigin-RevId: 208133502
Diffstat (limited to 'tensorflow/contrib/lite/python/lite_test.py')
-rw-r--r-- | tensorflow/contrib/lite/python/lite_test.py | 62 |
1 files changed, 49 insertions, 13 deletions
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py index ca2af5aaed..2f13684228 100644 --- a/tensorflow/contrib/lite/python/lite_test.py +++ b/tensorflow/contrib/lite/python/lite_test.py @@ -33,6 +33,7 @@ from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.ops.variables import global_variables_initializer as _global_variables_initializer from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.saved_model import saved_model @@ -198,6 +199,7 @@ class FromSessionTest(test_util.TensorFlowTestCase): 'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32) out_tensor = in_tensor + var sess = session.Session() + sess.run(_global_variables_initializer()) # Convert model and ensure model is not None. converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) @@ -655,9 +657,7 @@ class FromKerasFile(test_util.TensorFlowTestCase): tflite_model = converter.convert() self.assertTrue(tflite_model) - os.remove(keras_file) - - # Check values from converted model. + # Check tensor details of converted model. interpreter = Interpreter(model_content=tflite_model) interpreter.allocate_tensors() @@ -675,6 +675,18 @@ class FromKerasFile(test_util.TensorFlowTestCase): self.assertTrue(([1, 3, 3] == output_details[0]['shape']).all()) self.assertEqual((0., 0.), output_details[0]['quantization']) + # Check inference of converted model. + input_data = np.array([[1, 2, 3]], dtype=np.float32) + interpreter.set_tensor(input_details[0]['index'], input_data) + interpreter.invoke() + tflite_result = interpreter.get_tensor(output_details[0]['index']) + + keras_model = keras.models.load_model(keras_file) + keras_result = keras_model.predict(input_data) + + np.testing.assert_almost_equal(tflite_result, keras_result, 5) + os.remove(keras_file) + def testSequentialModelInputArray(self): """Test a Sequential tf.keras model testing input arrays argument.""" keras_file = self._getSequentialModel() @@ -755,17 +767,17 @@ class FromKerasFile(test_util.TensorFlowTestCase): model.predict(x) fd, keras_file = tempfile.mkstemp('.h5') - keras.models.save_model(model, keras_file) + try: + keras.models.save_model(model, keras_file) + finally: + os.close(fd) # Convert to TFLite model. converter = lite.TocoConverter.from_keras_model_file(keras_file) tflite_model = converter.convert() self.assertTrue(tflite_model) - os.close(fd) - os.remove(keras_file) - - # Check values from converted model. + # Check tensor details of converted model. interpreter = Interpreter(model_content=tflite_model) interpreter.allocate_tensors() @@ -783,6 +795,18 @@ class FromKerasFile(test_util.TensorFlowTestCase): self.assertTrue(([1, 3] == output_details[0]['shape']).all()) self.assertEqual((0., 0.), output_details[0]['quantization']) + # Check inference of converted model. + input_data = np.array([[1, 2, 3]], dtype=np.float32) + interpreter.set_tensor(input_details[0]['index'], input_data) + interpreter.invoke() + tflite_result = interpreter.get_tensor(output_details[0]['index']) + + keras_model = keras.models.load_model(keras_file) + keras_result = keras_model.predict(input_data) + + np.testing.assert_almost_equal(tflite_result, keras_result, 5) + os.remove(keras_file) + def testFunctionalModelMultipleInputs(self): """Test a Functional tf.keras model with multiple inputs and outputs.""" a = keras.layers.Input(shape=(3,), name='input_a') @@ -865,17 +889,17 @@ class FromKerasFile(test_util.TensorFlowTestCase): model.predict(x) fd, keras_file = tempfile.mkstemp('.h5') - keras.models.save_model(model, keras_file) + try: + keras.models.save_model(model, keras_file) + finally: + os.close(fd) # Convert to TFLite model. converter = lite.TocoConverter.from_keras_model_file(keras_file) tflite_model = converter.convert() self.assertTrue(tflite_model) - os.close(fd) - os.remove(keras_file) - - # Check values from converted model. + # Check tensor details of converted model. interpreter = Interpreter(model_content=tflite_model) interpreter.allocate_tensors() @@ -893,6 +917,18 @@ class FromKerasFile(test_util.TensorFlowTestCase): self.assertTrue(([1, 3, 3] == output_details[0]['shape']).all()) self.assertEqual((0., 0.), output_details[0]['quantization']) + # Check inference of converted model. + input_data = np.array([[1, 2, 3]], dtype=np.float32) + interpreter.set_tensor(input_details[0]['index'], input_data) + interpreter.invoke() + tflite_result = interpreter.get_tensor(output_details[0]['index']) + + keras_model = keras.models.load_model(keras_file) + keras_result = keras_model.predict(input_data) + + np.testing.assert_almost_equal(tflite_result, keras_result, 5) + os.remove(keras_file) + if __name__ == '__main__': test.main() |