aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/argmax_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/argmax_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/argmax_op_test.py25
1 files changed, 12 insertions, 13 deletions
diff --git a/tensorflow/python/kernel_tests/argmax_op_test.py b/tensorflow/python/kernel_tests/argmax_op_test.py
index ae171dec55..ce06769902 100644
--- a/tensorflow/python/kernel_tests/argmax_op_test.py
+++ b/tensorflow/python/kernel_tests/argmax_op_test.py
@@ -29,12 +29,12 @@ class ArgMaxTest(test.TestCase):
def _testArg(self,
method,
x,
- dimension,
+ axis,
expected_values,
use_gpu=False,
expected_err_re=None):
with self.test_session(use_gpu=use_gpu):
- ans = method(x, dimension=dimension)
+ ans = method(x, axis=axis)
if expected_err_re is None:
tf_ans = ans.eval()
# Defaults to int64 output.
@@ -48,27 +48,26 @@ class ArgMaxTest(test.TestCase):
def _testBothArg(self,
method,
x,
- dimension,
+ axis,
expected_values,
expected_err_re=None):
- self._testArg(method, x, dimension, expected_values, True, expected_err_re)
- self._testArg(method, x, dimension, expected_values, False, expected_err_re)
+ self._testArg(method, x, axis, expected_values, True, expected_err_re)
+ self._testArg(method, x, axis, expected_values, False, expected_err_re)
def _testBasic(self, dtype):
x = np.asarray(100 * np.random.randn(200), dtype=dtype)
- # Check that argmin and argmax match numpy along the primary
- # dimension
+ # Check that argmin and argmax match numpy along the primary axis
self._testBothArg(math_ops.argmax, x, 0, x.argmax())
self._testBothArg(math_ops.argmin, x, 0, x.argmin())
def _testDim(self, dtype):
x = np.asarray(100 * np.random.randn(3, 2, 4, 5, 6), dtype=dtype)
- # Check that argmin and argmax match numpy along all dimensions
- for dim in range(-5, 5):
- self._testBothArg(math_ops.argmax, x, dim, x.argmax(dim))
- self._testBothArg(math_ops.argmin, x, dim, x.argmin(dim))
+ # Check that argmin and argmax match numpy along all axes
+ for axis in range(-5, 5):
+ self._testBothArg(math_ops.argmax, x, axis, x.argmax(axis))
+ self._testBothArg(math_ops.argmin, x, axis, x.argmin(axis))
def testFloat(self):
self._testBasic(np.float32)
@@ -78,7 +77,7 @@ class ArgMaxTest(test.TestCase):
x = np.asarray(100 * np.random.randn(200), dtype=np.float32)
expected_values = x.argmax()
with self.test_session(use_gpu=True):
- ans = math_ops.argmax(x, dimension=0, output_type=dtypes.int32)
+ ans = math_ops.argmax(x, axis=0, output_type=dtypes.int32)
tf_ans = ans.eval()
self.assertEqual(np.int32, tf_ans.dtype)
# The values are equal when comparing int32 to int64 because
@@ -86,7 +85,7 @@ class ArgMaxTest(test.TestCase):
self.assertAllEqual(tf_ans, expected_values)
expected_values = x.argmin()
with self.test_session(use_gpu=True):
- ans = math_ops.argmin(x, dimension=0, output_type=dtypes.int32)
+ ans = math_ops.argmin(x, axis=0, output_type=dtypes.int32)
tf_ans = ans.eval()
self.assertEqual(np.int32, tf_ans.dtype)
self.assertAllEqual(tf_ans, expected_values)