diff options
author | 2018-03-22 00:26:31 -0700 | |
---|---|---|
committer | 2018-03-22 00:29:17 -0700 | |
commit | 585fb74541ed914845eccd3da4b1a2c94a99779e (patch) | |
tree | 566d8b4a0e4f791ed725b1b83f96e2039489347b /tensorflow/contrib/lite/python/interpreter_test.py | |
parent | f83711104b64a108ac43213c92f13827343d09ef (diff) |
Minor style improvement in TFLite interpreter_test.py
PiperOrigin-RevId: 190027161
Diffstat (limited to 'tensorflow/contrib/lite/python/interpreter_test.py')
-rw-r--r-- | tensorflow/contrib/lite/python/interpreter_test.py | 49 |
1 files changed, 25 insertions, 24 deletions
diff --git a/tensorflow/contrib/lite/python/interpreter_test.py b/tensorflow/contrib/lite/python/interpreter_test.py index bf124410f3..cd2386f526 100644 --- a/tensorflow/contrib/lite/python/interpreter_test.py +++ b/tensorflow/contrib/lite/python/interpreter_test.py @@ -61,30 +61,31 @@ class InterpreterTest(test_util.TensorFlowTestCase): 'testdata/permute_uint8.tflite') with io.open(model_path, 'rb') as model_file: data = model_file.read() - interpreter = interpreter_wrapper.Interpreter(model_content=data) - interpreter.allocate_tensors() - - input_details = interpreter.get_input_details() - self.assertEqual(1, len(input_details)) - self.assertEqual('input', input_details[0]['name']) - self.assertEqual(np.uint8, input_details[0]['dtype']) - self.assertTrue(([1, 4] == input_details[0]['shape']).all()) - self.assertEqual((1.0, 0), input_details[0]['quantization']) - - output_details = interpreter.get_output_details() - self.assertEqual(1, len(output_details)) - self.assertEqual('output', output_details[0]['name']) - self.assertEqual(np.uint8, output_details[0]['dtype']) - self.assertTrue(([1, 4] == output_details[0]['shape']).all()) - self.assertEqual((1.0, 0), output_details[0]['quantization']) - - test_input = np.array([[1, 2, 3, 4]], dtype=np.uint8) - expected_output = np.array([[4, 3, 2, 1]], dtype=np.uint8) - interpreter.set_tensor(input_details[0]['index'], test_input) - interpreter.invoke() - - output_data = interpreter.get_tensor(output_details[0]['index']) - self.assertTrue((expected_output == output_data).all()) + + interpreter = interpreter_wrapper.Interpreter(model_content=data) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('input', input_details[0]['name']) + self.assertEqual(np.uint8, input_details[0]['dtype']) + self.assertTrue(([1, 4] == input_details[0]['shape']).all()) + self.assertEqual((1.0, 0), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('output', output_details[0]['name']) + self.assertEqual(np.uint8, output_details[0]['dtype']) + self.assertTrue(([1, 4] == output_details[0]['shape']).all()) + self.assertEqual((1.0, 0), output_details[0]['quantization']) + + test_input = np.array([[1, 2, 3, 4]], dtype=np.uint8) + expected_output = np.array([[4, 3, 2, 1]], dtype=np.uint8) + interpreter.set_tensor(input_details[0]['index'], test_input) + interpreter.invoke() + + output_data = interpreter.get_tensor(output_details[0]['index']) + self.assertTrue((expected_output == output_data).all()) if __name__ == '__main__': |