aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/python/interpreter_test.py
diff options
context:
space:
mode:
authorGravatar Yu-Cheng Ling <ycling@google.com>2018-03-22 00:26:31 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-22 00:29:17 -0700
commit585fb74541ed914845eccd3da4b1a2c94a99779e (patch)
tree566d8b4a0e4f791ed725b1b83f96e2039489347b /tensorflow/contrib/lite/python/interpreter_test.py
parentf83711104b64a108ac43213c92f13827343d09ef (diff)
Minor style improvement in TFLite interpreter_test.py
PiperOrigin-RevId: 190027161
Diffstat (limited to 'tensorflow/contrib/lite/python/interpreter_test.py')
-rw-r--r--tensorflow/contrib/lite/python/interpreter_test.py49
1 files changed, 25 insertions, 24 deletions
diff --git a/tensorflow/contrib/lite/python/interpreter_test.py b/tensorflow/contrib/lite/python/interpreter_test.py
index bf124410f3..cd2386f526 100644
--- a/tensorflow/contrib/lite/python/interpreter_test.py
+++ b/tensorflow/contrib/lite/python/interpreter_test.py
@@ -61,30 +61,31 @@ class InterpreterTest(test_util.TensorFlowTestCase):
'testdata/permute_uint8.tflite')
with io.open(model_path, 'rb') as model_file:
data = model_file.read()
- interpreter = interpreter_wrapper.Interpreter(model_content=data)
- interpreter.allocate_tensors()
-
- input_details = interpreter.get_input_details()
- self.assertEqual(1, len(input_details))
- self.assertEqual('input', input_details[0]['name'])
- self.assertEqual(np.uint8, input_details[0]['dtype'])
- self.assertTrue(([1, 4] == input_details[0]['shape']).all())
- self.assertEqual((1.0, 0), input_details[0]['quantization'])
-
- output_details = interpreter.get_output_details()
- self.assertEqual(1, len(output_details))
- self.assertEqual('output', output_details[0]['name'])
- self.assertEqual(np.uint8, output_details[0]['dtype'])
- self.assertTrue(([1, 4] == output_details[0]['shape']).all())
- self.assertEqual((1.0, 0), output_details[0]['quantization'])
-
- test_input = np.array([[1, 2, 3, 4]], dtype=np.uint8)
- expected_output = np.array([[4, 3, 2, 1]], dtype=np.uint8)
- interpreter.set_tensor(input_details[0]['index'], test_input)
- interpreter.invoke()
-
- output_data = interpreter.get_tensor(output_details[0]['index'])
- self.assertTrue((expected_output == output_data).all())
+
+ interpreter = interpreter_wrapper.Interpreter(model_content=data)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('input', input_details[0]['name'])
+ self.assertEqual(np.uint8, input_details[0]['dtype'])
+ self.assertTrue(([1, 4] == input_details[0]['shape']).all())
+ self.assertEqual((1.0, 0), input_details[0]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual('output', output_details[0]['name'])
+ self.assertEqual(np.uint8, output_details[0]['dtype'])
+ self.assertTrue(([1, 4] == output_details[0]['shape']).all())
+ self.assertEqual((1.0, 0), output_details[0]['quantization'])
+
+ test_input = np.array([[1, 2, 3, 4]], dtype=np.uint8)
+ expected_output = np.array([[4, 3, 2, 1]], dtype=np.uint8)
+ interpreter.set_tensor(input_details[0]['index'], test_input)
+ interpreter.invoke()
+
+ output_data = interpreter.get_tensor(output_details[0]['index'])
+ self.assertTrue((expected_output == output_data).all())
if __name__ == '__main__':