aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/gather_op.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-09-21 09:21:48 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-09-21 10:32:29 -0700
commit412ae46894b42c76e8222a3dfa52fe333ce58706 (patch)
tree90019095b2fcd8a1e346cabe3663119f125d1f56 /tensorflow/core/kernels/gather_op.cc
parentee73700d148a271dbe407ca092b2ae879febb1a9 (diff)
Move gather_op's CPU implementation out of the kernel file.
Mkae test_util.py print the actual exception type when there is a mismatch between actual exception and expected exception. Change: 133845037
Diffstat (limited to 'tensorflow/core/kernels/gather_op.cc')
-rw-r--r--tensorflow/core/kernels/gather_op.cc71
1 files changed, 2 insertions, 69 deletions
diff --git a/tensorflow/core/kernels/gather_op.cc b/tensorflow/core/kernels/gather_op.cc
index edce9a3197..179931fd4e 100644
--- a/tensorflow/core/kernels/gather_op.cc
+++ b/tensorflow/core/kernels/gather_op.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/gather_op_cpu_impl.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/types.h"
@@ -92,81 +93,13 @@ class GatherOp : public OpKernel {
namespace functor {
-// Helper method to copy using memcpy.
-template <typename T, typename Index, typename SliceIndex,
- SliceIndex static_slice_elems>
-SliceIndex HandleCopies(typename TTypes<T>::ConstMatrix params,
- typename TTypes<Index>::ConstFlat indices,
- SliceIndex slice_elems,
- typename TTypes<T>::Matrix out) {
- const SliceIndex first_dim_size =
- static_cast<SliceIndex>(indices.dimension(0));
- const Index limit = static_cast<Index>(params.dimension(0));
- T* out_base = &out(0, 0);
- const T* params_base = &params(0, 0);
- if (static_slice_elems >= 0) {
- // Give compiler static knowledge of the number of elements/bytes
- CHECK_EQ(static_slice_elems, slice_elems);
- slice_elems = static_slice_elems;
- }
- // Compute slice_bytes here so that static knowledge is available
- const size_t slice_bytes = slice_elems * sizeof(T);
- for (SliceIndex i = 0; i < first_dim_size; i++) {
- const SliceIndex j = i + 1;
- if (j < first_dim_size) {
- port::prefetch<port::PREFETCH_HINT_T0>(&params(indices(j), 0));
- port::prefetch<port::PREFETCH_HINT_T0>(&out(j, 0));
- }
- // Grab the index and check its validity. An earlier version of the
- // code checked it and then grabbed it from memory a second time, which
- // was a security risk since it could have changed in between.
- const Index index = internal::SubtleMustCopy(indices(i));
- if (!FastBoundsCheck(index, limit)) return i;
- // Copy using memcpy if possible, otherwise an Eigen loop
- if (Allocator::is_simple<T>::value) {
- memcpy(out_base + i * slice_elems, params_base + index * slice_elems,
- slice_bytes);
- } else {
- out.template chip<0>(i) = params.template chip<0>(index);
- }
- }
- return -1;
-}
-
// Specialization gather functor for CPU.
template <typename T, typename Index>
struct Gather<CPUDevice, T, Index> {
int64 operator()(const CPUDevice& d, typename TTypes<T>::ConstMatrix params,
typename TTypes<Index>::ConstFlat indices,
typename TTypes<T>::Matrix out) {
- const int64 N = indices.size();
- const int64 slice_size = out.size() / N;
- int64 bad_i;
-
- bool use_large = (slice_size > std::numeric_limits<int32>::max() ||
- params.size() > std::numeric_limits<int32>::max() ||
- N > std::numeric_limits<int32>::max());
-#define CALL(elems) \
- do { \
- if (use_large) { \
- bad_i = HandleCopies<T, Index, int64, elems>(params, indices, \
- slice_size, out); \
- } else { \
- const int32 small_slice = static_cast<int32>(slice_size); \
- bad_i = HandleCopies<T, Index, int32, elems>(params, indices, \
- small_slice, out); \
- } \
- } while (0)
-
- if (slice_size == 10)
- CALL(10);
- else if (slice_size == 20)
- CALL(20);
- else
- CALL(-1);
-#undef CALL
-
- return bad_i;
+ return GatherCpu<T, Index>()(params, indices, out);
}
};
} // namespace functor