diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/where_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/where_op_test.py | 43 |
1 files changed, 43 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/where_op_test.py b/tensorflow/python/kernel_tests/where_op_test.py new file mode 100644 index 0000000000..263f98f622 --- /dev/null +++ b/tensorflow/python/kernel_tests/where_op_test.py @@ -0,0 +1,43 @@ +"""Tests for tensorflow.ops.reverse_sequence_op.""" +import tensorflow.python.platform + +import numpy as np +import tensorflow as tf + + +class WhereOpTest(tf.test.TestCase): + + def _testWhere(self, x, truth, expected_err_re=None): + with self.test_session(): + ans = tf.where(x) + self.assertEqual([None, x.ndim], ans.get_shape().as_list()) + if expected_err_re is None: + tf_ans = ans.eval() + self.assertAllClose(tf_ans, truth, atol=1e-10) + else: + with self.assertRaisesOpError(expected_err_re): + ans.eval() + + def testBasicMat(self): + x = np.asarray([[True, False], [True, False]]) + + # Ensure RowMajor mode + truth = np.asarray([[0, 0], [1, 0]], dtype=np.int64) + + self._testWhere(x, truth) + + def testBasic3Tensor(self): + x = np.asarray( + [[[True, False], [True, False]], [[False, True], [False, True]], + [[False, False], [False, True]]]) + + # Ensure RowMajor mode + truth = np.asarray( + [[0, 0, 0], [0, 1, 0], [1, 0, 1], [1, 1, 1], [2, 1, 1]], + dtype=np.int64) + + self._testWhere(x, truth) + + +if __name__ == "__main__": + tf.test.main() |