diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-15 16:05:31 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-15 16:08:40 -0700 |
commit | 44a854b85e50d0cdf519747cdb3d21de087b0444 (patch) | |
tree | c31e500d65b04578fc846d41d019ec6e62a990e4 /tensorflow/contrib/lite/python/lite_test.py | |
parent | 1d74a69443f741e69f9f52cb6bc2940b4d4ae3b7 (diff) |
Some fixes to testInferenceInputType
PiperOrigin-RevId: 200789288
Diffstat (limited to 'tensorflow/contrib/lite/python/lite_test.py')
-rw-r--r-- | tensorflow/contrib/lite/python/lite_test.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py index 8c9d2c1651..a9475de474 100644 --- a/tensorflow/contrib/lite/python/lite_test.py +++ b/tensorflow/contrib/lite/python/lite_test.py @@ -267,7 +267,8 @@ class FromSessionTest(test_util.TensorFlowTestCase): self.assertTrue(num_items_graphviz_video > num_items_graphviz) def testInferenceInputType(self): - in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3], dtype=dtypes.uint8) + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) out_tensor = in_tensor + in_tensor sess = session.Session() @@ -286,14 +287,13 @@ class FromSessionTest(test_util.TensorFlowTestCase): self.assertEqual('Placeholder', input_details[0]['name']) self.assertEqual(np.uint8, input_details[0]['dtype']) self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) - self.assertEqual((0., 0.), input_details[0]['quantization']) + self.assertEqual((1., 0.), input_details[0]['quantization']) output_details = interpreter.get_output_details() self.assertEqual(1, len(output_details)) self.assertEqual('add', output_details[0]['name']) - self.assertEqual(np.uint8, output_details[0]['dtype']) + self.assertEqual(np.float32, output_details[0]['dtype']) self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) - self.assertEqual((0., 0.), input_details[0]['quantization']) def testDefaultRangesStats(self): in_tensor = array_ops.placeholder( |