diff options
Diffstat (limited to 'tensorflow/python/framework/tensor_util_test.py')
-rw-r--r-- | tensorflow/python/framework/tensor_util_test.py | 16 |
1 files changed, 15 insertions, 1 deletions
diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py index 8f9af29247..20e8601d73 100644 --- a/tensorflow/python/framework/tensor_util_test.py +++ b/tensorflow/python/framework/tensor_util_test.py @@ -514,9 +514,23 @@ class TensorUtilTest(tf.test.TestCase): self.assertEquals(np.complex128, a.dtype) self.assertAllEqual(np.array([[(1+2j), (3+4j)], [(5+6j), (7+8j)]]), a) - def testUnsupportedDType(self): + def testUnsupportedDTypes(self): with self.assertRaises(TypeError): tensor_util.make_tensor_proto(np.array([1]), 0) + with self.assertRaises(TypeError): + tensor_util.make_tensor_proto(3, dtype=tf.qint8) + with self.assertRaises(TypeError): + tensor_util.make_tensor_proto([3], dtype=tf.qint8) + + def testTensorShapeVerification(self): + array = np.array([[1], [2]]) + correct_shape = (2, 1) + incorrect_shape = (1, 2) + tensor_util.make_tensor_proto(array, shape=correct_shape, + verify_shape=True) + with self.assertRaises(TypeError): + tensor_util.make_tensor_proto(array, shape=incorrect_shape, + verify_shape=True) def testShapeTooLarge(self): with self.assertRaises(ValueError): |