aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/reshape_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/reshape_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/reshape_op_test.py106
1 files changed, 106 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/reshape_op_test.py b/tensorflow/python/kernel_tests/reshape_op_test.py
new file mode 100644
index 0000000000..65b0e6d4bf
--- /dev/null
+++ b/tensorflow/python/kernel_tests/reshape_op_test.py
@@ -0,0 +1,106 @@
+"""Tests for tensorflow.ops.reshape_op."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.kernel_tests import gradient_checker as gc
+
+
+class ReshapeTest(tf.test.TestCase):
+
+ def _testReshape(self, x, y, use_gpu=False):
+ with self.test_session(use_gpu=use_gpu):
+ np_ans = x.reshape(y)
+ tf_ans = tf.reshape(x, y)
+ out = tf_ans.eval()
+ self.assertEqual(tf_ans.get_shape(), out.shape)
+ self.assertShapeEqual(np_ans, tf_ans)
+
+ def _testBothReshape(self, x, y):
+ self._testReshape(x, y, False)
+ self._testReshape(x, y, True)
+
+ def testFloatBasic(self):
+ x = np.arange(1., 7.).reshape([1, 6]).astype(np.float32)
+ self._testBothReshape(x, [2, 3])
+
+ def testDoubleBasic(self):
+ x = np.arange(1., 7.).reshape([1, 6]).astype(np.float64)
+ self._testBothReshape(x, [2, 3])
+
+ def testInt32Basic(self):
+ x = np.arange(1., 7.).reshape([1, 6]).astype(np.int32)
+ self._testBothReshape(x, [2, 3])
+
+ def testSComplexBasic(self):
+ x = np.arange(1., 7.).reshape([1, 6]).astype(np.complex64)
+ self._testBothReshape(x, [2, 3])
+
+ def testFloatReshapeThreeDimensions(self):
+ x = np.arange(1., 28.).reshape([1, 27]).astype(np.float32)
+ self._testBothReshape(x, [3, 3, 3])
+
+ def testFloatUnspecifiedDimOnly(self):
+ x = np.arange(1., 7.).reshape([6]).astype(np.float32)
+ self._testBothReshape(x, [-1])
+
+ def testFloatUnspecifiedDimBegin(self):
+ x = np.arange(1., 7.).reshape([6]).astype(np.float32)
+ self._testBothReshape(x, [-1, 2])
+
+ def testFloatUnspecifiedDimEnd(self):
+ x = np.arange(1., 7.).reshape([6]).astype(np.float32)
+ self._testBothReshape(x, [3, -1])
+
+ # TODO(vrv): Add tests for failure conditions once python test_util
+ # reports errors.
+
+ def testFloatReshapeGradThreeDimensions(self):
+ x = np.arange(1., 25.).reshape([1, 24]).astype(np.float32)
+ s = list(np.shape(x))
+ with self.test_session():
+ input_tensor = tf.constant(x, shape=[2, 3, 4])
+ reshape_out = tf.reshape(input_tensor, [1, 8, 3])
+ err = gc.ComputeGradientError(input_tensor, s,
+ reshape_out, s, x_init_value=x)
+ print "Reshape gradient error = " % err
+ self.assertLess(err, 1e-3)
+
+ def testFloatEmpty(self):
+ x = np.empty((0, 0, 0, 0), dtype=np.float32)
+ self._testBothReshape(x, [1, 2, 3, 0])
+ self._testBothReshape(x, [1, 0, 0, 4])
+ self._testBothReshape(x, [0, 0, 0, 0])
+ self._testBothReshape(x, [1, 2, 0])
+ self._testBothReshape(x, [0, 0, 0])
+ self._testBothReshape(x, [1, -1, 5])
+
+ def testErrors(self):
+ x = tf.constant(0.0, shape=[1, 0, 3])
+ with self.assertRaisesRegexp(
+ ValueError, "cannot infer the missing input size"):
+ tf.reshape(x, [0, -1, 5])
+
+ y = tf.constant(0.0, shape=[23, 29, 31])
+ with self.assertRaisesRegexp(ValueError, "isn't divisible by 17"):
+ tf.reshape(y, [17, -1])
+
+ def testPartialShapes(self):
+ x = tf.placeholder(tf.float32)
+
+ # Unknown input shape, partial new shape.
+ y = tf.reshape(x, [1, 1, -1, 1])
+ self.assertEqual([1, 1, None, 1], y.get_shape().as_list())
+
+ # Unknown input shape, unknown new shape.
+ y = tf.reshape(x, tf.placeholder(tf.int32))
+ self.assertEqual(None, y.get_shape().ndims)
+
+ # Unknown input shape, known rank for new shape.
+ y = tf.reshape(x, tf.placeholder(tf.int32, shape=(3,)))
+ self.assertEqual([None, None, None], y.get_shape().as_list())
+
+
+if __name__ == "__main__":
+ tf.test.main()