aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops
diff options
context:
space:
mode:
authorGravatar Adria Puigdomenech <adriap@google.com>2018-10-04 03:19:46 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-04 03:24:11 -0700
commit6cc738da1748e819b9c8ee92dc2f1a7bdb291b50 (patch)
tree5658ba2e69b29cae520118880f89e35242574fbe /tensorflow/python/ops
parent6b538d9ce54e878576131cde0c76e43a893180c2 (diff)
Make batch_gather work with indices of dtype int64.
PiperOrigin-RevId: 215711383
Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r--tensorflow/python/ops/array_ops.py14
1 files changed, 10 insertions, 4 deletions
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 9f5149d5ac..4be9c532f4 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -2716,16 +2716,22 @@ def batch_gather(params, indices, name=None):
params = ops.convert_to_tensor(params, name="params")
indices_shape = shape(indices)
params_shape = shape(params)
+
ndims = indices.shape.ndims
if ndims is None:
raise ValueError("batch_gather does not allow indices with unknown "
"shape.")
batch_indices = indices
- accum_dim_value = 1
+ indices_dtype = indices.dtype.base_dtype
+ accum_dim_value = ones((), dtype=indices_dtype)
+ # Use correct type for offset index computation
+ casted_params_shape = gen_math_ops.cast(params_shape, indices_dtype)
for dim in range(ndims-1, 0, -1):
- dim_value = params_shape[dim-1]
- accum_dim_value *= params_shape[dim]
- dim_indices = gen_math_ops._range(0, dim_value, 1)
+ dim_value = casted_params_shape[dim-1]
+ accum_dim_value *= casted_params_shape[dim]
+ start = zeros((), dtype=indices_dtype)
+ step = ones((), dtype=indices_dtype)
+ dim_indices = gen_math_ops._range(start, dim_value, step)
dim_indices *= accum_dim_value
dim_shape = stack([1] * (dim - 1) + [dim_value] + [1] * (ndims - dim),
axis=0)