aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/gather_nd_op_cpu_impl.h
blob: 1c78de253e702f5e546467bbed0758c24dbe0443 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_CPU_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_CPU_IMPL_H_

// CPU (Eigen::ThreadPoolDevice) specialization of the GatherNdSlice functor.

#define EIGEN_USE_THREADS

#include <atomic>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/gather_nd_op.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"

namespace tensorflow {

// Alias for Eigen's thread-pool device; it is the Device parameter of the
// CPU GatherNdSlice specialization defined in this header.
using CPUDevice = Eigen::ThreadPoolDevice;

namespace generator {

// Per-row worker used by the CPU GatherNdSlice functor.
//
// Each call handles one row `loc` of the index matrix:
//   * It reads the IXDIM indices stored in that row.
//   * It bounds-checks each index against the matching dimension of the
//     params tensor.
//   * If all indices pass, it copies `slice_size` elements starting at
//     params(i_0, ..., i_{IXDIM-1}, 0) into row `loc` of the output.
//   * Otherwise it fills row `loc` of the output with T() and stores `loc`
//     into `*error_loc`.
// The return value is always 0. It only exists so the generator can feed an
// Eigen reduction.
template <typename T, typename Index, int IXDIM>
class GatherNdSliceGenerator {
 public:
  // `error_loc` is held as a raw, non-owned pointer and is dereferenced on
  // every out-of-bounds row, so it must outlive this generator.
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE GatherNdSliceGenerator(
      const Index slice_size, typename TTypes<Index>::ConstMatrix Tindices,
      typename TTypes<T, IXDIM + 1>::ConstTensor Tparams,
      typename TTypes<T>::Matrix Tout, std::atomic<Index>* error_loc)
      : slice_len_(slice_size),
        indices_(Tindices),
        params_(Tparams),
        out_(Tout),
        bad_loc_(error_loc) {}

  // Writes the IXDIM indices found in row `loc` into (*ix)[0..IXDIM-1] and
  // sets (*ix)[IXDIM] to 0, the first element of the slice. All indices are
  // written even after one fails its check. Returns true iff at least one
  // index fails FastBoundsCheck (0 <= idx < dim) against its params dimension.
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool GenerateIndices(
      const Index loc, Eigen::array<Eigen::DenseIndex, IXDIM + 1>* ix) const {
    Eigen::array<Eigen::DenseIndex, IXDIM + 1>& coords = *ix;
    bool all_in_range = true;
    for (int dim = 0; dim < IXDIM; ++dim) {
      // Copy the index once so the value that is bounds-checked is exactly
      // the value that gets used (the indices buffer is read, not trusted).
      const Index idx = internal::SubtleMustCopy(indices_(loc, dim));
      coords[dim] = idx;
      all_in_range &= FastBoundsCheck(idx, params_.dimension(dim));
    }
    coords[IXDIM] = 0;
    return !all_in_range;
  }

  // Gathers (or zero-fills) the slice for row loc_array[0]; always returns 0.
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int32
  operator()(const Eigen::array<Eigen::DenseIndex, 1>& loc_array) const {
    const Index row = loc_array[0];

    Eigen::array<Eigen::DenseIndex, 2> dst_ix;
    dst_ix[0] = row;
    dst_ix[1] = 0;
    T* const dst = &out_(dst_ix);

    Eigen::array<Eigen::DenseIndex, IXDIM + 1> src_ix;
    if (TF_PREDICT_FALSE(GenerateIndices(row, &src_ix))) {
      // Bad index: report this row and emit a default-valued slice instead.
      bad_loc_->store(row);
      std::fill_n(dst, slice_len_, T());
      return static_cast<int32>(0);
    }

    std::copy_n(&params_(src_ix), slice_len_, dst);
    return static_cast<int32>(0);
  }

 private:
  const Index slice_len_;  // Elements copied (or filled) per row.
  const typename TTypes<Index>::ConstMatrix indices_;
  const typename TTypes<T, IXDIM + 1>::ConstTensor params_;
  // mutable: operator() is const (presumably required by Eigen's generate())
  // yet it writes into the output matrix.
  mutable typename TTypes<T>::Matrix out_;
  std::atomic<Index>* bad_loc_;  // Not owned.
};

}  // namespace generator

namespace functor {

template <typename T, typename Index, int IXDIM>
struct GatherNdSlice<CPUDevice, T, Index, IXDIM> {
  Index operator()(const CPUDevice& d, const Index slice_size,
                   typename TTypes<int32>::Scalar Tscratch,
                   typename TTypes<T, IXDIM + 1>::ConstTensor Tparams,
                   typename TTypes<Index>::ConstMatrix Tindices,
                   typename TTypes<T>::Matrix Tout) {
    std::atomic<Index> error_loc(-1);

    const Eigen::DenseIndex batch_size = Tindices.dimension(0);
#if !defined(EIGEN_HAS_INDEX_LIST)
    Eigen::Tensor<Eigen::DenseIndex, 1>::Dimensions reshape_dims{{ 1 }};
    Eigen::array<Eigen::DenseIndex, 1> broadcast_dims{{ batch_size }};
#else
    Eigen::IndexList<Eigen::type2index<1> > reshape_dims;
    Eigen::IndexList<Eigen::DenseIndex> broadcast_dims;
    broadcast_dims.set(0, batch_size);
#endif
    generator::GatherNdSliceGenerator<T, Index, IXDIM> gather_nd_generator(
        slice_size, Tindices, Tparams, Tout, &error_loc);

#if defined(INTEL_MKL) && defined(ENABLE_MKL)
// Eigen implementation below is not highly performant. gather_nd_generator
// does not seem to be called in parallel, leading to very poor performance.
// Additionally, since it uses scalar (Tscratch) to invoke 'generate', it
// needs to go through redundant operations like 'reshape', 'broadcast' and
// 'sum'. OpenMP loop below essentially does same thing as Eigen code, but
// is considerably more efficient.
#pragma omp parallel for
    for (Eigen::DenseIndex i = 0; i < batch_size; i++) {
      const Eigen::array<Eigen::DenseIndex, 1> loc{i};
      gather_nd_generator(loc);
    }
#else   // INTEL_MKL && ENABLE_MKL
    Tscratch.device(d) = Tscratch.reshape(reshape_dims)
                             .broadcast(broadcast_dims)
                             .generate(gather_nd_generator)
                             .sum();
#endif  // INTEL_MKL && ENABLE_MKL

    // error_loc() returns -1 if there's no out-of-bounds index,
    // otherwise it returns the location of an OOB index in Tindices.
    return error_loc.load();
  }
};

// Explicitly instantiates GatherNdSlice<CPUDevice, T, Index, IXDIM>::operator()
// for IXDIM == CPU_PROVIDED_IXDIM. That macro is not defined in this header.
// Presumably each including .cc file #defines it before the #include; verify
// against the includers.
#define REGISTER_GATHER_ND_FULL(T, Index)                              \
  template Index                                                       \
  GatherNdSlice<CPUDevice, T, Index, CPU_PROVIDED_IXDIM>::operator()(  \
      const CPUDevice&, const Index, typename TTypes<int32>::Scalar,   \
      typename TTypes<T, CPU_PROVIDED_IXDIM + 1>::ConstTensor,         \
      typename TTypes<Index>::ConstMatrix, typename TTypes<T>::Matrix);

// Instantiates both supported index widths (int32 and int64) for `type`.
#define REGISTER_GATHER_ND_CPU(type)    \
  REGISTER_GATHER_ND_FULL(type, int32); \
  REGISTER_GATHER_ND_FULL(type, int64)

// One int32/int64 instantiation pair for every type in TF_CALL_ALL_TYPES.
TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU);

}  // namespace functor

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_CPU_IMPL_H_