aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/identity_op_py_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/identity_op_py_test.py')
-rw-r--r--tensorflow/python/kernel_tests/identity_op_py_test.py47
1 files changed, 47 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/identity_op_py_test.py b/tensorflow/python/kernel_tests/identity_op_py_test.py
new file mode 100644
index 0000000000..2209cf08ad
--- /dev/null
+++ b/tensorflow/python/kernel_tests/identity_op_py_test.py
@@ -0,0 +1,47 @@
+"""Tests for IdentityOp."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.ops import gen_array_ops
+
+
+class IdentityOpTest(tf.test.TestCase):
+
+ def testInt32_6(self):
+ with self.test_session():
+ value = tf.identity([1, 2, 3, 4, 5, 6]).eval()
+ self.assertAllEqual(np.array([1, 2, 3, 4, 5, 6]), value)
+
+ def testInt32_2_3(self):
+ with self.test_session():
+ inp = tf.constant([10, 20, 30, 40, 50, 60], shape=[2, 3])
+ value = tf.identity(inp).eval()
+ self.assertAllEqual(np.array([[10, 20, 30], [40, 50, 60]]), value)
+
+ def testString(self):
+ with self.test_session():
+ value = tf.identity(["A", "b", "C", "d", "E", "f"]).eval()
+ self.assertAllEqual(["A", "b", "C", "d", "E", "f"], value)
+
+ def testIdentityShape(self):
+ with self.test_session():
+ shape = [2, 3]
+ array_2x3 = [[1, 2, 3], [6, 5, 4]]
+ tensor = tf.constant(array_2x3)
+ self.assertEquals(shape, tensor.get_shape())
+ self.assertEquals(shape, tf.identity(tensor).get_shape())
+ self.assertEquals(shape, tf.identity(array_2x3).get_shape())
+ self.assertEquals(shape, tf.identity(np.array(array_2x3)).get_shape())
+
+ def testRefIdentityShape(self):
+ with self.test_session():
+ shape = [2, 3]
+ tensor = tf.Variable(tf.constant([[1, 2, 3], [6, 5, 4]], dtype=tf.int32))
+ self.assertEquals(shape, tensor.get_shape())
+ self.assertEquals(shape, gen_array_ops._ref_identity(tensor).get_shape())
+
+
+if __name__ == "__main__":
+ tf.test.main()