diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/array_ops_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/array_ops_test.py | 35 |
1 files changed, 35 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py index 9c9c8d4675..c9c8ecb72c 100644 --- a/tensorflow/python/kernel_tests/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops_test.py @@ -173,5 +173,40 @@ class ReverseTest(test_util.TensorFlowTestCase): tf.reverse(data_2d_t, dims_3d_t) +class MeshgridTest(test_util.TensorFlowTestCase): + + def _compare(self, n, np_dtype, use_gpu): + inputs = [] + for i in range(n): + x = np.linspace(-10, 10, 5).astype(np_dtype) + if np_dtype in (np.complex64, np.complex128): + x += 1j + inputs.append(x) + + numpy_out = np.meshgrid(*inputs) + with self.test_session(use_gpu=use_gpu): + tf_out = array_ops.meshgrid(*inputs) + for X, _X in zip(numpy_out, tf_out): + self.assertAllEqual(X, _X.eval()) + + def testCompare(self): + for t in (np.float16, np.float32, np.float64, np.int32, np.int64, + np.complex64, np.complex128): + # Don't test the one-dimensional case, as + # old numpy versions don't support it + self._compare(2, t, False) + self._compare(3, t, False) + self._compare(4, t, False) + self._compare(5, t, False) + + # Test for inputs with rank not equal to 1 + x = [[1, 1], [1, 1]] + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "needs to have rank 1"): + with self.test_session(): + X, _ = array_ops.meshgrid(x, x) + X.eval() + + if __name__ == "__main__": googletest.main() |