aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/array_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/array_ops_test.py')
-rw-r--r--tensorflow/python/kernel_tests/array_ops_test.py35
1 files changed, 35 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index 9c9c8d4675..c9c8ecb72c 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -173,5 +173,40 @@ class ReverseTest(test_util.TensorFlowTestCase):
tf.reverse(data_2d_t, dims_3d_t)
+class MeshgridTest(test_util.TensorFlowTestCase):
+
+ def _compare(self, n, np_dtype, use_gpu):
+ inputs = []
+ for i in range(n):
+ x = np.linspace(-10, 10, 5).astype(np_dtype)
+ if np_dtype in (np.complex64, np.complex128):
+ x += 1j
+ inputs.append(x)
+
+ numpy_out = np.meshgrid(*inputs)
+ with self.test_session(use_gpu=use_gpu):
+ tf_out = array_ops.meshgrid(*inputs)
+ for X, _X in zip(numpy_out, tf_out):
+ self.assertAllEqual(X, _X.eval())
+
+ def testCompare(self):
+ for t in (np.float16, np.float32, np.float64, np.int32, np.int64,
+ np.complex64, np.complex128):
+ # Don't test the one-dimensional case, as
+ # old numpy versions don't support it
+ self._compare(2, t, False)
+ self._compare(3, t, False)
+ self._compare(4, t, False)
+ self._compare(5, t, False)
+
+ # Test for inputs with rank not equal to 1
+ x = [[1, 1], [1, 1]]
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ "needs to have rank 1"):
+ with self.test_session():
+ X, _ = array_ops.meshgrid(x, x)
+ X.eval()
+
+
if __name__ == "__main__":
googletest.main()