aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/framework/tensor_util.py2
-rw-r--r--tensorflow/python/framework/tensor_util_test.py40
2 files changed, 33 insertions, 9 deletions
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py
index 7bb85ab81a..896100f2ff 100644
--- a/tensorflow/python/framework/tensor_util.py
+++ b/tensorflow/python/framework/tensor_util.py
@@ -182,7 +182,7 @@ def _FlattenToStrings(nested_strings):
_TENSOR_CONTENT_TYPES = frozenset([
dtypes.float32, dtypes.float64, dtypes.int32, dtypes.uint8, dtypes.int16,
- dtypes.int8, dtypes.int64
+ dtypes.int8, dtypes.int64, dtypes.qint8, dtypes.quint8, dtypes.qint32,
])
diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py
index 10d2beccc3..f64e66ac37 100644
--- a/tensorflow/python/framework/tensor_util_test.py
+++ b/tensorflow/python/framework/tensor_util_test.py
@@ -231,14 +231,38 @@ class TensorUtilTest(tf.test.TestCase):
self.assertAllClose(np.array([10, 20, 30], dtype=np.int64), a)
def testQuantizedTypes(self):
- for dtype in [tf.qint32, tf.quint8, tf.qint8]:
- # Test with array.
- t = tensor_util.make_tensor_proto([(10,), (20,), (30,)], dtype=dtype)
- self.assertEquals(dtype, t.dtype)
- self.assertProtoEquals("dim { size: 3 }", t.tensor_shape)
- self.assertEquals(10, t.int_val[0])
- self.assertEquals(20, t.int_val[1])
- self.assertEquals(30, t.int_val[2])
+ # Test with array.
+ data = [(21,), (22,), (23,)]
+
+ t = tensor_util.make_tensor_proto(data, dtype=tf.qint32)
+ self.assertProtoEquals("""
+ dtype: DT_QINT32
+ tensor_shape { dim { size: 3 } }
+ tensor_content: "\025\000\000\000\026\000\000\000\027\000\000\000"
+ """, t)
+ a = tensor_util.MakeNdarray(t)
+ self.assertEquals(tf.qint32.as_numpy_dtype, a.dtype)
+ self.assertAllEqual(np.array(data, dtype=a.dtype), a)
+
+ t = tensor_util.make_tensor_proto(data, dtype=tf.quint8)
+ self.assertProtoEquals("""
+ dtype: DT_QUINT8
+ tensor_shape { dim { size: 3 } }
+ tensor_content: "\025\026\027"
+ """, t)
+ a = tensor_util.MakeNdarray(t)
+ self.assertEquals(tf.quint8.as_numpy_dtype, a.dtype)
+ self.assertAllEqual(np.array(data, dtype=a.dtype), a)
+
+ t = tensor_util.make_tensor_proto(data, dtype=tf.qint8)
+ self.assertProtoEquals("""
+ dtype: DT_QINT8
+ tensor_shape { dim { size: 3 } }
+ tensor_content: "\025\026\027"
+ """, t)
+ a = tensor_util.MakeNdarray(t)
+ self.assertEquals(tf.qint8.as_numpy_dtype, a.dtype)
+ self.assertAllEqual(np.array(data, dtype=a.dtype), a)
def testString(self):
t = tensor_util.make_tensor_proto("foo")