aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python
diff options
context:
space:
mode:
authorGravatar Adria Puigdomenech <adriap@google.com>2018-10-04 03:19:46 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-04 03:24:11 -0700
commit6cc738da1748e819b9c8ee92dc2f1a7bdb291b50 (patch)
tree5658ba2e69b29cae520118880f89e35242574fbe /tensorflow/python
parent6b538d9ce54e878576131cde0c76e43a893180c2 (diff)
Make batch_gather work with indices of dtype int64.
PiperOrigin-RevId: 215711383
Diffstat (limited to 'tensorflow/python')
-rw-r--r--tensorflow/python/kernel_tests/BUILD1
-rw-r--r--tensorflow/python/kernel_tests/batch_gather_op_test.py13
-rw-r--r--tensorflow/python/ops/array_ops.py14
3 files changed, 19 insertions, 9 deletions
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 9303c70c60..e055ef1c1b 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -76,6 +76,7 @@ tf_py_test(
name = "batch_gather_op_test",
srcs = ["batch_gather_op_test.py"],
additional_deps = [
+ "@absl_py//absl/testing:parameterized",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
diff --git a/tensorflow/python/kernel_tests/batch_gather_op_test.py b/tensorflow/python/kernel_tests/batch_gather_op_test.py
index 7dd347989a..84e93b8136 100644
--- a/tensorflow/python/kernel_tests/batch_gather_op_test.py
+++ b/tensorflow/python/kernel_tests/batch_gather_op_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
from tensorflow.python.framework import constant_op
@@ -29,7 +30,7 @@ _TEST_TYPES = (dtypes.int64, dtypes.float32,
dtypes.complex64, dtypes.complex128)
-class GatherTest(test.TestCase):
+class GatherTest(test.TestCase, parameterized.TestCase):
def _buildParams(self, data, dtype):
data = data.astype(dtype.as_numpy_dtype)
@@ -39,14 +40,15 @@ class GatherTest(test.TestCase):
return data + 10j * data
return data
- def testSimpleGather(self):
+ @parameterized.parameters(dtypes.int32, dtypes.int64)
+ def testSimpleGather(self, indices_dtype):
data = np.array([0, 1, 2, 3, 7, 5, 8, 9, 10, 11, 15, 13])
indices = [3, 4]
with self.test_session(use_gpu=True):
for dtype in _TEST_TYPES:
params_np = self._buildParams(data, dtype)
params = constant_op.constant(params_np)
- indices_tf = constant_op.constant(indices)
+ indices_tf = constant_op.constant(indices, dtype=indices_dtype)
gather_t = array_ops.batch_gather(params, indices_tf)
expected_result = np.array([3, 7])
np_val = self._buildParams(expected_result, dtype)
@@ -54,14 +56,15 @@ class GatherTest(test.TestCase):
self.assertAllEqual(np_val, gather_val)
self.assertEqual(np_val.shape, gather_t.get_shape())
- def test2DArray(self):
+ @parameterized.parameters(dtypes.int32, dtypes.int64)
+ def test2DArray(self, indices_dtype):
data = np.array([[0, 1, 2, 3, 7, 5], [8, 9, 10, 11, 15, 13]])
indices = [[3], [4]]
with self.test_session(use_gpu=True):
for dtype in _TEST_TYPES:
params_np = self._buildParams(data, dtype)
params = constant_op.constant(params_np)
- indices_tf = constant_op.constant(indices)
+ indices_tf = constant_op.constant(indices, dtype=indices_dtype)
gather_t = array_ops.batch_gather(params, indices_tf)
expected_result = np.array([[3], [15]])
np_val = self._buildParams(expected_result, dtype)
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 9f5149d5ac..4be9c532f4 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -2716,16 +2716,22 @@ def batch_gather(params, indices, name=None):
params = ops.convert_to_tensor(params, name="params")
indices_shape = shape(indices)
params_shape = shape(params)
+
ndims = indices.shape.ndims
if ndims is None:
raise ValueError("batch_gather does not allow indices with unknown "
"shape.")
batch_indices = indices
- accum_dim_value = 1
+ indices_dtype = indices.dtype.base_dtype
+ accum_dim_value = ones((), dtype=indices_dtype)
+ # Use correct type for offset index computation
+ casted_params_shape = gen_math_ops.cast(params_shape, indices_dtype)
for dim in range(ndims-1, 0, -1):
- dim_value = params_shape[dim-1]
- accum_dim_value *= params_shape[dim]
- dim_indices = gen_math_ops._range(0, dim_value, 1)
+ dim_value = casted_params_shape[dim-1]
+ accum_dim_value *= casted_params_shape[dim]
+ start = zeros((), dtype=indices_dtype)
+ step = ones((), dtype=indices_dtype)
+ dim_indices = gen_math_ops._range(start, dim_value, step)
dim_indices *= accum_dim_value
dim_shape = stack([1] * (dim - 1) + [dim_value] + [1] * (ndims - dim),
axis=0)