diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/reshape_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/reshape_op_test.py | 106 |
1 files changed, 106 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/reshape_op_test.py b/tensorflow/python/kernel_tests/reshape_op_test.py new file mode 100644 index 0000000000..65b0e6d4bf --- /dev/null +++ b/tensorflow/python/kernel_tests/reshape_op_test.py @@ -0,0 +1,106 @@ +"""Tests for tensorflow.ops.reshape_op.""" +import tensorflow.python.platform + +import numpy as np +import tensorflow as tf + +from tensorflow.python.kernel_tests import gradient_checker as gc + + +class ReshapeTest(tf.test.TestCase): + + def _testReshape(self, x, y, use_gpu=False): + with self.test_session(use_gpu=use_gpu): + np_ans = x.reshape(y) + tf_ans = tf.reshape(x, y) + out = tf_ans.eval() + self.assertEqual(tf_ans.get_shape(), out.shape) + self.assertShapeEqual(np_ans, tf_ans) + + def _testBothReshape(self, x, y): + self._testReshape(x, y, False) + self._testReshape(x, y, True) + + def testFloatBasic(self): + x = np.arange(1., 7.).reshape([1, 6]).astype(np.float32) + self._testBothReshape(x, [2, 3]) + + def testDoubleBasic(self): + x = np.arange(1., 7.).reshape([1, 6]).astype(np.float64) + self._testBothReshape(x, [2, 3]) + + def testInt32Basic(self): + x = np.arange(1., 7.).reshape([1, 6]).astype(np.int32) + self._testBothReshape(x, [2, 3]) + + def testSComplexBasic(self): + x = np.arange(1., 7.).reshape([1, 6]).astype(np.complex64) + self._testBothReshape(x, [2, 3]) + + def testFloatReshapeThreeDimensions(self): + x = np.arange(1., 28.).reshape([1, 27]).astype(np.float32) + self._testBothReshape(x, [3, 3, 3]) + + def testFloatUnspecifiedDimOnly(self): + x = np.arange(1., 7.).reshape([6]).astype(np.float32) + self._testBothReshape(x, [-1]) + + def testFloatUnspecifiedDimBegin(self): + x = np.arange(1., 7.).reshape([6]).astype(np.float32) + self._testBothReshape(x, [-1, 2]) + + def testFloatUnspecifiedDimEnd(self): + x = np.arange(1., 7.).reshape([6]).astype(np.float32) + self._testBothReshape(x, [3, -1]) + + # TODO(vrv): Add tests for failure conditions once python test_util + # reports errors. + + def testFloatReshapeGradThreeDimensions(self): + x = np.arange(1., 25.).reshape([1, 24]).astype(np.float32) + s = list(np.shape(x)) + with self.test_session(): + input_tensor = tf.constant(x, shape=[2, 3, 4]) + reshape_out = tf.reshape(input_tensor, [1, 8, 3]) + err = gc.ComputeGradientError(input_tensor, s, + reshape_out, s, x_init_value=x) + print "Reshape gradient error = " % err + self.assertLess(err, 1e-3) + + def testFloatEmpty(self): + x = np.empty((0, 0, 0, 0), dtype=np.float32) + self._testBothReshape(x, [1, 2, 3, 0]) + self._testBothReshape(x, [1, 0, 0, 4]) + self._testBothReshape(x, [0, 0, 0, 0]) + self._testBothReshape(x, [1, 2, 0]) + self._testBothReshape(x, [0, 0, 0]) + self._testBothReshape(x, [1, -1, 5]) + + def testErrors(self): + x = tf.constant(0.0, shape=[1, 0, 3]) + with self.assertRaisesRegexp( + ValueError, "cannot infer the missing input size"): + tf.reshape(x, [0, -1, 5]) + + y = tf.constant(0.0, shape=[23, 29, 31]) + with self.assertRaisesRegexp(ValueError, "isn't divisible by 17"): + tf.reshape(y, [17, -1]) + + def testPartialShapes(self): + x = tf.placeholder(tf.float32) + + # Unknown input shape, partial new shape. + y = tf.reshape(x, [1, 1, -1, 1]) + self.assertEqual([1, 1, None, 1], y.get_shape().as_list()) + + # Unknown input shape, unknown new shape. + y = tf.reshape(x, tf.placeholder(tf.int32)) + self.assertEqual(None, y.get_shape().ndims) + + # Unknown input shape, known rank for new shape. + y = tf.reshape(x, tf.placeholder(tf.int32, shape=(3,))) + self.assertEqual([None, None, None], y.get_shape().as_list()) + + +if __name__ == "__main__": + tf.test.main() |