aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/python/lite_test.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-15 16:05:31 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-15 16:08:40 -0700
commit44a854b85e50d0cdf519747cdb3d21de087b0444 (patch)
treec31e500d65b04578fc846d41d019ec6e62a990e4 /tensorflow/contrib/lite/python/lite_test.py
parent1d74a69443f741e69f9f52cb6bc2940b4d4ae3b7 (diff)
Some fixes to testInferenceInputType
PiperOrigin-RevId: 200789288
Diffstat (limited to 'tensorflow/contrib/lite/python/lite_test.py')
-rw-r--r--tensorflow/contrib/lite/python/lite_test.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py
index 8c9d2c1651..a9475de474 100644
--- a/tensorflow/contrib/lite/python/lite_test.py
+++ b/tensorflow/contrib/lite/python/lite_test.py
@@ -267,7 +267,8 @@ class FromSessionTest(test_util.TensorFlowTestCase):
self.assertTrue(num_items_graphviz_video > num_items_graphviz)
def testInferenceInputType(self):
- in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3], dtype=dtypes.uint8)
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
out_tensor = in_tensor + in_tensor
sess = session.Session()
@@ -286,14 +287,13 @@ class FromSessionTest(test_util.TensorFlowTestCase):
self.assertEqual('Placeholder', input_details[0]['name'])
self.assertEqual(np.uint8, input_details[0]['dtype'])
self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
- self.assertEqual((0., 0.), input_details[0]['quantization'])
+ self.assertEqual((1., 0.), input_details[0]['quantization'])
output_details = interpreter.get_output_details()
self.assertEqual(1, len(output_details))
self.assertEqual('add', output_details[0]['name'])
- self.assertEqual(np.uint8, output_details[0]['dtype'])
+ self.assertEqual(np.float32, output_details[0]['dtype'])
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
- self.assertEqual((0., 0.), input_details[0]['quantization'])
def testDefaultRangesStats(self):
in_tensor = array_ops.placeholder(