diff options
-rw-r--r-- | tensorflow/python/framework/tensor_util.py | 2 | ||||
-rw-r--r-- | tensorflow/python/framework/tensor_util_test.py | 40 |
2 files changed, 33 insertions, 9 deletions
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index 7bb85ab81a..896100f2ff 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -182,7 +182,7 @@ def _FlattenToStrings(nested_strings): _TENSOR_CONTENT_TYPES = frozenset([ dtypes.float32, dtypes.float64, dtypes.int32, dtypes.uint8, dtypes.int16, - dtypes.int8, dtypes.int64 + dtypes.int8, dtypes.int64, dtypes.qint8, dtypes.quint8, dtypes.qint32, ]) diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py index 10d2beccc3..f64e66ac37 100644 --- a/tensorflow/python/framework/tensor_util_test.py +++ b/tensorflow/python/framework/tensor_util_test.py @@ -231,14 +231,38 @@ class TensorUtilTest(tf.test.TestCase): self.assertAllClose(np.array([10, 20, 30], dtype=np.int64), a) def testQuantizedTypes(self): - for dtype in [tf.qint32, tf.quint8, tf.qint8]: - # Test with array. - t = tensor_util.make_tensor_proto([(10,), (20,), (30,)], dtype=dtype) - self.assertEquals(dtype, t.dtype) - self.assertProtoEquals("dim { size: 3 }", t.tensor_shape) - self.assertEquals(10, t.int_val[0]) - self.assertEquals(20, t.int_val[1]) - self.assertEquals(30, t.int_val[2]) + # Test with array. + data = [(21,), (22,), (23,)] + + t = tensor_util.make_tensor_proto(data, dtype=tf.qint32) + self.assertProtoEquals(""" + dtype: DT_QINT32 + tensor_shape { dim { size: 3 } } + tensor_content: "\025\000\000\000\026\000\000\000\027\000\000\000" + """, t) + a = tensor_util.MakeNdarray(t) + self.assertEquals(tf.qint32.as_numpy_dtype, a.dtype) + self.assertAllEqual(np.array(data, dtype=a.dtype), a) + + t = tensor_util.make_tensor_proto(data, dtype=tf.quint8) + self.assertProtoEquals(""" + dtype: DT_QUINT8 + tensor_shape { dim { size: 3 } } + tensor_content: "\025\026\027" + """, t) + a = tensor_util.MakeNdarray(t) + self.assertEquals(tf.quint8.as_numpy_dtype, a.dtype) + self.assertAllEqual(np.array(data, dtype=a.dtype), a) + + t = tensor_util.make_tensor_proto(data, dtype=tf.qint8) + self.assertProtoEquals(""" + dtype: DT_QINT8 + tensor_shape { dim { size: 3 } } + tensor_content: "\025\026\027" + """, t) + a = tensor_util.MakeNdarray(t) + self.assertEquals(tf.qint8.as_numpy_dtype, a.dtype) + self.assertAllEqual(np.array(data, dtype=a.dtype), a) def testString(self): t = tensor_util.make_tensor_proto("foo") |