diff options
author | Adria Puigdomenech <adriap@google.com> | 2018-10-04 03:19:46 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-04 03:24:11 -0700 |
commit | 6cc738da1748e819b9c8ee92dc2f1a7bdb291b50 (patch) | |
tree | 5658ba2e69b29cae520118880f89e35242574fbe /tensorflow/python | |
parent | 6b538d9ce54e878576131cde0c76e43a893180c2 (diff) |
Make batch_gather work with indices of dtype int64.
PiperOrigin-RevId: 215711383
Diffstat (limited to 'tensorflow/python')
-rw-r--r-- | tensorflow/python/kernel_tests/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/batch_gather_op_test.py | 13 | ||||
-rw-r--r-- | tensorflow/python/ops/array_ops.py | 14 |
3 files changed, 19 insertions, 9 deletions
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 9303c70c60..e055ef1c1b 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -76,6 +76,7 @@ tf_py_test( name = "batch_gather_op_test", srcs = ["batch_gather_op_test.py"], additional_deps = [ + "@absl_py//absl/testing:parameterized", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", diff --git a/tensorflow/python/kernel_tests/batch_gather_op_test.py b/tensorflow/python/kernel_tests/batch_gather_op_test.py index 7dd347989a..84e93b8136 100644 --- a/tensorflow/python/kernel_tests/batch_gather_op_test.py +++ b/tensorflow/python/kernel_tests/batch_gather_op_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from tensorflow.python.framework import constant_op @@ -29,7 +30,7 @@ _TEST_TYPES = (dtypes.int64, dtypes.float32, dtypes.complex64, dtypes.complex128) -class GatherTest(test.TestCase): +class GatherTest(test.TestCase, parameterized.TestCase): def _buildParams(self, data, dtype): data = data.astype(dtype.as_numpy_dtype) @@ -39,14 +40,15 @@ class GatherTest(test.TestCase): return data + 10j * data return data - def testSimpleGather(self): + @parameterized.parameters(dtypes.int32, dtypes.int64) + def testSimpleGather(self, indices_dtype): data = np.array([0, 1, 2, 3, 7, 5, 8, 9, 10, 11, 15, 13]) indices = [3, 4] with self.test_session(use_gpu=True): for dtype in _TEST_TYPES: params_np = self._buildParams(data, dtype) params = constant_op.constant(params_np) - indices_tf = constant_op.constant(indices) + indices_tf = constant_op.constant(indices, dtype=indices_dtype) gather_t = array_ops.batch_gather(params, indices_tf) expected_result = np.array([3, 7]) np_val = self._buildParams(expected_result, dtype) @@ -54,14 +56,15 @@ class GatherTest(test.TestCase): self.assertAllEqual(np_val, gather_val) self.assertEqual(np_val.shape, gather_t.get_shape()) - def test2DArray(self): + @parameterized.parameters(dtypes.int32, dtypes.int64) + def test2DArray(self, indices_dtype): data = np.array([[0, 1, 2, 3, 7, 5], [8, 9, 10, 11, 15, 13]]) indices = [[3], [4]] with self.test_session(use_gpu=True): for dtype in _TEST_TYPES: params_np = self._buildParams(data, dtype) params = constant_op.constant(params_np) - indices_tf = constant_op.constant(indices) + indices_tf = constant_op.constant(indices, dtype=indices_dtype) gather_t = array_ops.batch_gather(params, indices_tf) expected_result = np.array([[3], [15]]) np_val = self._buildParams(expected_result, dtype) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 9f5149d5ac..4be9c532f4 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -2716,16 +2716,22 @@ def batch_gather(params, indices, name=None): params = ops.convert_to_tensor(params, name="params") indices_shape = shape(indices) params_shape = shape(params) + ndims = indices.shape.ndims if ndims is None: raise ValueError("batch_gather does not allow indices with unknown " "shape.") batch_indices = indices - accum_dim_value = 1 + indices_dtype = indices.dtype.base_dtype + accum_dim_value = ones((), dtype=indices_dtype) + # Use correct type for offset index computation + casted_params_shape = gen_math_ops.cast(params_shape, indices_dtype) for dim in range(ndims-1, 0, -1): - dim_value = params_shape[dim-1] - accum_dim_value *= params_shape[dim] - dim_indices = gen_math_ops._range(0, dim_value, 1) + dim_value = casted_params_shape[dim-1] + accum_dim_value *= casted_params_shape[dim] + start = zeros((), dtype=indices_dtype) + step = ones((), dtype=indices_dtype) + dim_indices = gen_math_ops._range(start, dim_value, step) dim_indices *= accum_dim_value dim_shape = stack([1] * (dim - 1) + [dim_value] + [1] * (ndims - dim), axis=0) |