diff options
Diffstat (limited to 'tensorflow/python/ops/array_ops.py')
-rw-r--r-- | tensorflow/python/ops/array_ops.py | 14 |
1 files changed, 10 insertions, 4 deletions
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 9f5149d5ac..4be9c532f4 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -2716,16 +2716,22 @@ def batch_gather(params, indices, name=None): params = ops.convert_to_tensor(params, name="params") indices_shape = shape(indices) params_shape = shape(params) + ndims = indices.shape.ndims if ndims is None: raise ValueError("batch_gather does not allow indices with unknown " "shape.") batch_indices = indices - accum_dim_value = 1 + indices_dtype = indices.dtype.base_dtype + accum_dim_value = ones((), dtype=indices_dtype) + # Use correct type for offset index computation + casted_params_shape = gen_math_ops.cast(params_shape, indices_dtype) for dim in range(ndims-1, 0, -1): - dim_value = params_shape[dim-1] - accum_dim_value *= params_shape[dim] - dim_indices = gen_math_ops._range(0, dim_value, 1) + dim_value = casted_params_shape[dim-1] + accum_dim_value *= casted_params_shape[dim] + start = zeros((), dtype=indices_dtype) + step = ones((), dtype=indices_dtype) + dim_indices = gen_math_ops._range(start, dim_value, step) dim_indices *= accum_dim_value dim_shape = stack([1] * (dim - 1) + [dim_value] + [1] * (ndims - dim), axis=0) |