aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/tensor_util_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/tensor_util_test.py')
-rw-r--r--tensorflow/python/framework/tensor_util_test.py16
1 files changed, 15 insertions, 1 deletions
diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py
index 8f9af29247..20e8601d73 100644
--- a/tensorflow/python/framework/tensor_util_test.py
+++ b/tensorflow/python/framework/tensor_util_test.py
@@ -514,9 +514,23 @@ class TensorUtilTest(tf.test.TestCase):
self.assertEquals(np.complex128, a.dtype)
self.assertAllEqual(np.array([[(1+2j), (3+4j)], [(5+6j), (7+8j)]]), a)
- def testUnsupportedDType(self):
+ def testUnsupportedDTypes(self):
with self.assertRaises(TypeError):
tensor_util.make_tensor_proto(np.array([1]), 0)
+ with self.assertRaises(TypeError):
+ tensor_util.make_tensor_proto(3, dtype=tf.qint8)
+ with self.assertRaises(TypeError):
+ tensor_util.make_tensor_proto([3], dtype=tf.qint8)
+
+ def testTensorShapeVerification(self):
+ array = np.array([[1], [2]])
+ correct_shape = (2, 1)
+ incorrect_shape = (1, 2)
+ tensor_util.make_tensor_proto(array, shape=correct_shape,
+ verify_shape=True)
+ with self.assertRaises(TypeError):
+ tensor_util.make_tensor_proto(array, shape=incorrect_shape,
+ verify_shape=True)
def testShapeTooLarge(self):
with self.assertRaises(ValueError):