diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/argmax_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/argmax_op_test.py | 25 |
1 files changed, 12 insertions, 13 deletions
diff --git a/tensorflow/python/kernel_tests/argmax_op_test.py b/tensorflow/python/kernel_tests/argmax_op_test.py index ae171dec55..ce06769902 100644 --- a/tensorflow/python/kernel_tests/argmax_op_test.py +++ b/tensorflow/python/kernel_tests/argmax_op_test.py @@ -29,12 +29,12 @@ class ArgMaxTest(test.TestCase): def _testArg(self, method, x, - dimension, + axis, expected_values, use_gpu=False, expected_err_re=None): with self.test_session(use_gpu=use_gpu): - ans = method(x, dimension=dimension) + ans = method(x, axis=axis) if expected_err_re is None: tf_ans = ans.eval() # Defaults to int64 output. @@ -48,27 +48,26 @@ class ArgMaxTest(test.TestCase): def _testBothArg(self, method, x, - dimension, + axis, expected_values, expected_err_re=None): - self._testArg(method, x, dimension, expected_values, True, expected_err_re) - self._testArg(method, x, dimension, expected_values, False, expected_err_re) + self._testArg(method, x, axis, expected_values, True, expected_err_re) + self._testArg(method, x, axis, expected_values, False, expected_err_re) def _testBasic(self, dtype): x = np.asarray(100 * np.random.randn(200), dtype=dtype) - # Check that argmin and argmax match numpy along the primary - # dimension + # Check that argmin and argmax match numpy along the primary axis self._testBothArg(math_ops.argmax, x, 0, x.argmax()) self._testBothArg(math_ops.argmin, x, 0, x.argmin()) def _testDim(self, dtype): x = np.asarray(100 * np.random.randn(3, 2, 4, 5, 6), dtype=dtype) - # Check that argmin and argmax match numpy along all dimensions - for dim in range(-5, 5): - self._testBothArg(math_ops.argmax, x, dim, x.argmax(dim)) - self._testBothArg(math_ops.argmin, x, dim, x.argmin(dim)) + # Check that argmin and argmax match numpy along all axes + for axis in range(-5, 5): + self._testBothArg(math_ops.argmax, x, axis, x.argmax(axis)) + self._testBothArg(math_ops.argmin, x, axis, x.argmin(axis)) def testFloat(self): self._testBasic(np.float32) @@ -78,7 +77,7 @@ class ArgMaxTest(test.TestCase): x = np.asarray(100 * np.random.randn(200), dtype=np.float32) expected_values = x.argmax() with self.test_session(use_gpu=True): - ans = math_ops.argmax(x, dimension=0, output_type=dtypes.int32) + ans = math_ops.argmax(x, axis=0, output_type=dtypes.int32) tf_ans = ans.eval() self.assertEqual(np.int32, tf_ans.dtype) # The values are equal when comparing int32 to int64 because @@ -86,7 +85,7 @@ class ArgMaxTest(test.TestCase): self.assertAllEqual(tf_ans, expected_values) expected_values = x.argmin() with self.test_session(use_gpu=True): - ans = math_ops.argmin(x, dimension=0, output_type=dtypes.int32) + ans = math_ops.argmin(x, axis=0, output_type=dtypes.int32) tf_ans = ans.eval() self.assertEqual(np.int32, tf_ans.dtype) self.assertAllEqual(tf_ans, expected_values) |