diff options
author | 2017-07-11 09:51:54 -0700 | |
---|---|---|
committer | 2017-07-11 09:55:52 -0700 | |
commit | b1f9e2c89eb007cb4b9483d08dcace1e45e84164 (patch) | |
tree | b2b82fc0bd6abf3b77a412a251fcfdacf70a21dc /tensorflow/core/kernels/gather_functor.h | |
parent | 18a5510e67ef536c947512b70030c5c995ce7875 (diff) |
Add an axis parameter to tf.gather. Fixes GitHub issue #11223.
This brings tf.gather closer to compatibility with numpy.take.
To emulate gathering over an axis generally requires inefficient workarounds, e.g. transpose/gather/transpose. This technique is gaining popularity (hundreds of uses inside and outside of Google), so it is worth supporting efficiently.
For an `[a_0, ..., a_i, ..., a_n]` tensor, gathering `N` elements from axis `i` requires `(a_0 * ... * a_{i-1}) * N` copies of `a_{i+1} * ... * a_n` elements each. The CPU kernel does this with memcpy, which is far more efficient than transpose/gather/transpose since it requires no intermediate allocations and copies. The GPU kernel does the same number of copies but in parallel across multiple hardware threads.
Since this is a backwards incompatible change, this adds a "GatherV2" op with an axis input, and simultaneously supports backwards compatibility with "Gather" ops by defaulting to axis 0 if a 3rd input is not present.
PiperOrigin-RevId: 161541416
Diffstat (limited to 'tensorflow/core/kernels/gather_functor.h')
-rw-r--r-- | tensorflow/core/kernels/gather_functor.h | 78 |
1 file changed, 45 insertions, 33 deletions
diff --git a/tensorflow/core/kernels/gather_functor.h b/tensorflow/core/kernels/gather_functor.h index 1ad4d2da8d..dfa1a5f1f9 100644 --- a/tensorflow/core/kernels/gather_functor.h +++ b/tensorflow/core/kernels/gather_functor.h @@ -32,40 +32,51 @@ namespace functor { // Helper method to copy using memcpy. template <typename T, typename Index, typename SliceIndex, SliceIndex static_slice_elems> -SliceIndex HandleCopies(typename TTypes<T>::ConstMatrix params, +SliceIndex HandleCopies(typename TTypes<T, 3>::ConstTensor params, typename TTypes<Index>::ConstFlat indices, SliceIndex slice_elems, - typename TTypes<T>::Matrix out) { - const SliceIndex first_dim_size = - static_cast<SliceIndex>(indices.dimension(0)); - const Index limit = static_cast<Index>(params.dimension(0)); - T* out_base = &out(0, 0); - const T* params_base = ¶ms(0, 0); + typename TTypes<T, 3>::Tensor out) { + const SliceIndex indices_size = static_cast<SliceIndex>(indices.dimension(0)); + const SliceIndex batch_size = static_cast<SliceIndex>(params.dimension(0)); + const Index limit = static_cast<Index>(params.dimension(1)); + T* out_base = &out(0, 0, 0); + const T* params_base = ¶ms(0, 0, 0); if (static_slice_elems >= 0) { // Give compiler static knowledge of the number of elements/bytes slice_elems = static_slice_elems; } // Compute slice_bytes here so that static knowledge is available const size_t slice_bytes = slice_elems * sizeof(T); - for (SliceIndex i = 0; i < first_dim_size; i++) { - const SliceIndex j = i + 1; - if (j < first_dim_size) { - port::prefetch<port::PREFETCH_HINT_T0>(¶ms(indices(j), 0)); - port::prefetch<port::PREFETCH_HINT_T0>(&out(j, 0)); - } - // Grab the index and check its validity. An earlier version of the - // code checked it and then grabbed it from memory a second time, which - // was a security risk since it could have changed in between. 
- const Index index = internal::SubtleMustCopy(indices(i)); - if (!FastBoundsCheck(index, limit)) return i; - // Copy using memcpy if possible, otherwise an Eigen loop - // TODO(cwhipkey): avoid linking to framework to get Allocator (to improve - // ahead-of-time compilation binary size). - if (is_simple_type<T>::value) { - memcpy(out_base + i * slice_elems, params_base + index * slice_elems, - slice_bytes); - } else { - out.template chip<0>(i) = params.template chip<0>(index); + for (SliceIndex b = 0; b < batch_size; b++) { + for (SliceIndex i = 0; i < indices_size; i++) { + const SliceIndex i_next = i + 1; + const SliceIndex b_next = b + 1; + if (i_next < indices_size) { + port::prefetch<port::PREFETCH_HINT_T0>(¶ms(b, indices(i_next), 0)); + port::prefetch<port::PREFETCH_HINT_T0>(&out(b, i_next, 0)); + } else if (b_next < batch_size) { + port::prefetch<port::PREFETCH_HINT_T0>(¶ms(b_next, indices(0), 0)); + port::prefetch<port::PREFETCH_HINT_T0>(&out(b_next, 0, 0)); + } + // Grab the index and check its validity. An earlier version of the + // code checked it and then grabbed it from memory a second time, which + // was a security risk since it could have changed in between. + const Index index = internal::SubtleMustCopy(indices(i)); + if (!FastBoundsCheck(index, limit)) return i; + // Copy using memcpy if possible, otherwise an Eigen loop + // TODO(cwhipkey): avoid linking to framework to get Allocator (to improve + // ahead-of-time compilation binary size). + if (is_simple_type<T>::value) { + // Avoid auto-promotion to Index from SliceIndex by casting. + memcpy(out_base + (b * indices_size + i) * slice_elems, + params_base + (b * static_cast<SliceIndex>(limit) + + static_cast<SliceIndex>(index)) * + slice_elems, + slice_bytes); + } else { + // For non-"simple" types (e.g. strings). 
+ out.template chip<1>(i) = params.template chip<1>(index); + } } } return -1; @@ -73,11 +84,11 @@ SliceIndex HandleCopies(typename TTypes<T>::ConstMatrix params, template <typename T, typename Index> struct GatherFunctorCPU { - int64 operator()(typename TTypes<T>::ConstMatrix params, + int64 operator()(typename TTypes<T, 3>::ConstTensor params, typename TTypes<Index>::ConstFlat indices, - typename TTypes<T>::Matrix out) { + typename TTypes<T, 3>::Tensor out) { const int64 N = indices.size(); - const int64 slice_size = out.size() / N; + const int64 slice_size = out.dimension(2); int64 bad_i; bool use_large = (slice_size > std::numeric_limits<int32>::max() || @@ -109,16 +120,17 @@ struct GatherFunctorCPU { template <typename Device, typename T, typename Index> struct GatherFunctor { - int64 operator()(const Device& d, typename TTypes<T>::ConstMatrix params, + int64 operator()(const Device& d, typename TTypes<T, 3>::ConstTensor params, typename TTypes<Index>::ConstFlat indices, - typename TTypes<T>::Matrix out); + typename TTypes<T, 3>::Tensor out); }; template <typename T, typename Index> struct GatherFunctor<CPUDevice, T, Index> { - int64 operator()(const CPUDevice& d, typename TTypes<T>::ConstMatrix params, + int64 operator()(const CPUDevice& d, + typename TTypes<T, 3>::ConstTensor params, typename TTypes<Index>::ConstFlat indices, - typename TTypes<T>::Matrix out) { + typename TTypes<T, 3>::Tensor out) { return GatherFunctorCPU<T, Index>()(params, indices, out); } }; |