aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/where_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/where_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/where_op_test.py43
1 files changed, 43 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/where_op_test.py b/tensorflow/python/kernel_tests/where_op_test.py
new file mode 100644
index 0000000000..263f98f622
--- /dev/null
+++ b/tensorflow/python/kernel_tests/where_op_test.py
@@ -0,0 +1,43 @@
+"""Tests for tensorflow.ops.reverse_sequence_op."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class WhereOpTest(tf.test.TestCase):
+
+ def _testWhere(self, x, truth, expected_err_re=None):
+ with self.test_session():
+ ans = tf.where(x)
+ self.assertEqual([None, x.ndim], ans.get_shape().as_list())
+ if expected_err_re is None:
+ tf_ans = ans.eval()
+ self.assertAllClose(tf_ans, truth, atol=1e-10)
+ else:
+ with self.assertRaisesOpError(expected_err_re):
+ ans.eval()
+
+ def testBasicMat(self):
+ x = np.asarray([[True, False], [True, False]])
+
+ # Ensure RowMajor mode
+ truth = np.asarray([[0, 0], [1, 0]], dtype=np.int64)
+
+ self._testWhere(x, truth)
+
+ def testBasic3Tensor(self):
+ x = np.asarray(
+ [[[True, False], [True, False]], [[False, True], [False, True]],
+ [[False, False], [False, True]]])
+
+ # Ensure RowMajor mode
+ truth = np.asarray(
+ [[0, 0, 0], [0, 1, 0], [1, 0, 1], [1, 1, 1], [2, 1, 1]],
+ dtype=np.int64)
+
+ self._testWhere(x, truth)
+
+
+if __name__ == "__main__":
+ tf.test.main()