path: root/tensorflow/core/kernels/gather_functor.h
author    RJ Ryan <rjryan@google.com>  2017-07-11 09:51:54 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-07-11 09:55:52 -0700
commit  b1f9e2c89eb007cb4b9483d08dcace1e45e84164 (patch)
tree    b2b82fc0bd6abf3b77a412a251fcfdacf70a21dc /tensorflow/core/kernels/gather_functor.h
parent  18a5510e67ef536c947512b70030c5c995ce7875 (diff)
Add an axis parameter to tf.gather. Fixes GitHub issue #11223.
This brings tf.gather closer to compatibility with numpy.take. Without an axis argument, gathering over an arbitrary axis requires inefficient workarounds such as transpose/gather/transpose. That pattern has become common (hundreds of uses inside and outside of Google), so it is worth supporting efficiently.

For an `[a_0, ..., a_i, ..., a_n]` tensor, gathering `N` elements from axis `i` requires `(a_0 * ... * a_{i-1}) * N` copies of `a_{i+1} * ... * a_n` elements each. The CPU kernel does this with memcpy, which is far more efficient than transpose/gather/transpose since it requires no intermediate allocations or copies. The GPU kernel performs the same number of copies, but in parallel across multiple hardware threads.

Since this is a backwards-incompatible change, this adds a "GatherV2" op with an axis input, and keeps backwards compatibility with existing "Gather" ops by defaulting to axis 0 when the third input is not present.

PiperOrigin-RevId: 161541416
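To make the copy layout above concrete, here is a minimal standalone C++ sketch (not the TensorFlow kernel itself; GatherAxis, outer, axis_size, and inner are illustrative names) of the per-slice memcpy pattern the new CPU path uses for trivially copyable element types, with params viewed as an [outer, axis_size, inner] buffer:

#include <cstdint>
#include <cstring>
#include <vector>

// Gather `indices.size()` slices along the middle axis of a dense
// [outer, axis_size, inner] buffer. Here `outer` plays the role of
// a_0 * ... * a_{i-1}, `axis_size` is a_i, and `inner` is a_{i+1} * ... * a_n,
// so the loop issues outer * indices.size() memcpy calls of inner elements
// each, matching the copy count described above. T must be trivially copyable.
template <typename T>
std::vector<T> GatherAxis(const std::vector<T>& params, int64_t outer,
                          int64_t axis_size, int64_t inner,
                          const std::vector<int64_t>& indices) {
  const int64_t n = static_cast<int64_t>(indices.size());
  std::vector<T> out(outer * n * inner);
  for (int64_t b = 0; b < outer; ++b) {
    for (int64_t i = 0; i < n; ++i) {
      const int64_t idx = indices[i];  // assumed already bounds-checked
      std::memcpy(out.data() + (b * n + i) * inner,
                  params.data() + (b * axis_size + idx) * inner,
                  inner * sizeof(T));
    }
  }
  return out;
}

For example, gathering N = 4 indices along axis 1 of a [2, 5, 3] tensor issues 2 * 4 = 8 copies of 3 elements each, with no intermediate transposes.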
Diffstat (limited to 'tensorflow/core/kernels/gather_functor.h')
-rw-r--r--  tensorflow/core/kernels/gather_functor.h  78
1 file changed, 45 insertions, 33 deletions
diff --git a/tensorflow/core/kernels/gather_functor.h b/tensorflow/core/kernels/gather_functor.h
index 1ad4d2da8d..dfa1a5f1f9 100644
--- a/tensorflow/core/kernels/gather_functor.h
+++ b/tensorflow/core/kernels/gather_functor.h
@@ -32,40 +32,51 @@ namespace functor {
// Helper method to copy using memcpy.
template <typename T, typename Index, typename SliceIndex,
          SliceIndex static_slice_elems>
-SliceIndex HandleCopies(typename TTypes<T>::ConstMatrix params,
+SliceIndex HandleCopies(typename TTypes<T, 3>::ConstTensor params,
                        typename TTypes<Index>::ConstFlat indices,
                        SliceIndex slice_elems,
-                        typename TTypes<T>::Matrix out) {
-  const SliceIndex first_dim_size =
-      static_cast<SliceIndex>(indices.dimension(0));
-  const Index limit = static_cast<Index>(params.dimension(0));
-  T* out_base = &out(0, 0);
-  const T* params_base = &params(0, 0);
+                        typename TTypes<T, 3>::Tensor out) {
+  const SliceIndex indices_size = static_cast<SliceIndex>(indices.dimension(0));
+  const SliceIndex batch_size = static_cast<SliceIndex>(params.dimension(0));
+  const Index limit = static_cast<Index>(params.dimension(1));
+  T* out_base = &out(0, 0, 0);
+  const T* params_base = &params(0, 0, 0);
  if (static_slice_elems >= 0) {
    // Give compiler static knowledge of the number of elements/bytes
    slice_elems = static_slice_elems;
  }
  // Compute slice_bytes here so that static knowledge is available
  const size_t slice_bytes = slice_elems * sizeof(T);
-  for (SliceIndex i = 0; i < first_dim_size; i++) {
-    const SliceIndex j = i + 1;
-    if (j < first_dim_size) {
-      port::prefetch<port::PREFETCH_HINT_T0>(&params(indices(j), 0));
-      port::prefetch<port::PREFETCH_HINT_T0>(&out(j, 0));
-    }
-    // Grab the index and check its validity. An earlier version of the
-    // code checked it and then grabbed it from memory a second time, which
-    // was a security risk since it could have changed in between.
-    const Index index = internal::SubtleMustCopy(indices(i));
-    if (!FastBoundsCheck(index, limit)) return i;
-    // Copy using memcpy if possible, otherwise an Eigen loop
-    // TODO(cwhipkey): avoid linking to framework to get Allocator (to improve
-    // ahead-of-time compilation binary size).
-    if (is_simple_type<T>::value) {
-      memcpy(out_base + i * slice_elems, params_base + index * slice_elems,
-             slice_bytes);
-    } else {
-      out.template chip<0>(i) = params.template chip<0>(index);
+  for (SliceIndex b = 0; b < batch_size; b++) {
+    for (SliceIndex i = 0; i < indices_size; i++) {
+      const SliceIndex i_next = i + 1;
+      const SliceIndex b_next = b + 1;
+      if (i_next < indices_size) {
+        port::prefetch<port::PREFETCH_HINT_T0>(&params(b, indices(i_next), 0));
+        port::prefetch<port::PREFETCH_HINT_T0>(&out(b, i_next, 0));
+      } else if (b_next < batch_size) {
+        port::prefetch<port::PREFETCH_HINT_T0>(&params(b_next, indices(0), 0));
+        port::prefetch<port::PREFETCH_HINT_T0>(&out(b_next, 0, 0));
+      }
+      // Grab the index and check its validity. An earlier version of the
+      // code checked it and then grabbed it from memory a second time, which
+      // was a security risk since it could have changed in between.
+      const Index index = internal::SubtleMustCopy(indices(i));
+      if (!FastBoundsCheck(index, limit)) return i;
+      // Copy using memcpy if possible, otherwise an Eigen loop
+      // TODO(cwhipkey): avoid linking to framework to get Allocator (to improve
+      // ahead-of-time compilation binary size).
+      if (is_simple_type<T>::value) {
+        // Avoid auto-promotion to Index from SliceIndex by casting.
+        memcpy(out_base + (b * indices_size + i) * slice_elems,
+               params_base + (b * static_cast<SliceIndex>(limit) +
                              static_cast<SliceIndex>(index)) *
                                 slice_elems,
+               slice_bytes);
+      } else {
+        // For non-"simple" types (e.g. strings).
+        out.template chip<1>(i) = params.template chip<1>(index);
+      }
    }
  }
  return -1;
@@ -73,11 +84,11 @@ SliceIndex HandleCopies(typename TTypes<T>::ConstMatrix params,
template <typename T, typename Index>
struct GatherFunctorCPU {
-  int64 operator()(typename TTypes<T>::ConstMatrix params,
+  int64 operator()(typename TTypes<T, 3>::ConstTensor params,
                   typename TTypes<Index>::ConstFlat indices,
-                   typename TTypes<T>::Matrix out) {
+                   typename TTypes<T, 3>::Tensor out) {
    const int64 N = indices.size();
-    const int64 slice_size = out.size() / N;
+    const int64 slice_size = out.dimension(2);
    int64 bad_i;
    bool use_large = (slice_size > std::numeric_limits<int32>::max() ||
@@ -109,16 +120,17 @@ struct GatherFunctorCPU {
template <typename Device, typename T, typename Index>
struct GatherFunctor {
-  int64 operator()(const Device& d, typename TTypes<T>::ConstMatrix params,
+  int64 operator()(const Device& d, typename TTypes<T, 3>::ConstTensor params,
                   typename TTypes<Index>::ConstFlat indices,
-                   typename TTypes<T>::Matrix out);
+                   typename TTypes<T, 3>::Tensor out);
};
template <typename T, typename Index>
struct GatherFunctor<CPUDevice, T, Index> {
-  int64 operator()(const CPUDevice& d, typename TTypes<T>::ConstMatrix params,
+  int64 operator()(const CPUDevice& d,
+                   typename TTypes<T, 3>::ConstTensor params,
                   typename TTypes<Index>::ConstFlat indices,
-                   typename TTypes<T>::Matrix out) {
+                   typename TTypes<T, 3>::Tensor out) {
    return GatherFunctorCPU<T, Index>()(params, indices, out);
  }
};
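
The new TTypes<T, 3> signatures imply that the calling op reshapes params into an [outer, axis_size, inner] view before invoking the functor (axis 0 reproducing the old Gather behaviour). The helper below is a rough sketch of that shape arithmetic only; Collapsed3D and CollapseForGather are hypothetical names, not part of this change:

#include <cstdint>
#include <vector>

// Collapse an [a_0, ..., a_n] shape around `axis` into the three sizes used by
// the gather kernel: outer = a_0 * ... * a_{axis-1}, axis_size = a_axis,
// inner = a_{axis+1} * ... * a_n.
struct Collapsed3D {
  int64_t outer;
  int64_t axis_size;
  int64_t inner;
};

inline Collapsed3D CollapseForGather(const std::vector<int64_t>& shape,
                                     int axis) {
  Collapsed3D c{1, shape[axis], 1};
  for (int d = 0; d < axis; ++d) c.outer *= shape[d];
  for (int d = axis + 1; d < static_cast<int>(shape.size()); ++d) {
    c.inner *= shape[d];
  }
  return c;
}

With axis = 0 the outer size collapses to 1, which is how the existing "Gather" behaviour is preserved when no axis input is supplied.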