diff options
author | 2018-10-04 03:19:46 -0700 | |
---|---|---|
committer | 2018-10-04 03:24:11 -0700 | |
commit | 6cc738da1748e819b9c8ee92dc2f1a7bdb291b50 (patch) | |
tree | 5658ba2e69b29cae520118880f89e35242574fbe /tensorflow/python/ops | |
parent | 6b538d9ce54e878576131cde0c76e43a893180c2 (diff) |
Make batch_gather work with indices of dtype int64.
PiperOrigin-RevId: 215711383
Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r-- | tensorflow/python/ops/array_ops.py | 14 |
1 files changed, 10 insertions, 4 deletions
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 9f5149d5ac..4be9c532f4 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -2716,16 +2716,22 @@ def batch_gather(params, indices, name=None): params = ops.convert_to_tensor(params, name="params") indices_shape = shape(indices) params_shape = shape(params) + ndims = indices.shape.ndims if ndims is None: raise ValueError("batch_gather does not allow indices with unknown " "shape.") batch_indices = indices - accum_dim_value = 1 + indices_dtype = indices.dtype.base_dtype + accum_dim_value = ones((), dtype=indices_dtype) + # Use correct type for offset index computation + casted_params_shape = gen_math_ops.cast(params_shape, indices_dtype) for dim in range(ndims-1, 0, -1): - dim_value = params_shape[dim-1] - accum_dim_value *= params_shape[dim] - dim_indices = gen_math_ops._range(0, dim_value, 1) + dim_value = casted_params_shape[dim-1] + accum_dim_value *= casted_params_shape[dim] + start = zeros((), dtype=indices_dtype) + step = ones((), dtype=indices_dtype) + dim_indices = gen_math_ops._range(start, dim_value, step) dim_indices *= accum_dim_value dim_shape = stack([1] * (dim - 1) + [dim_value] + [1] * (ndims - dim), axis=0) |