diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-09-21 09:21:48 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-09-21 10:32:29 -0700 |
commit | 412ae46894b42c76e8222a3dfa52fe333ce58706 (patch) | |
tree | 90019095b2fcd8a1e346cabe3663119f125d1f56 /tensorflow/core/kernels/gather_op.cc | |
parent | ee73700d148a271dbe407ca092b2ae879febb1a9 (diff) |
Move gather_op's CPU implementation out of the kernel file.
Mkae test_util.py print the actual exception type when there is a mismatch
between actual exception and expected exception.
Change: 133845037
Diffstat (limited to 'tensorflow/core/kernels/gather_op.cc')
-rw-r--r-- | tensorflow/core/kernels/gather_op.cc | 71 |
1 files changed, 2 insertions, 69 deletions
diff --git a/tensorflow/core/kernels/gather_op.cc b/tensorflow/core/kernels/gather_op.cc index edce9a3197..179931fd4e 100644 --- a/tensorflow/core/kernels/gather_op.cc +++ b/tensorflow/core/kernels/gather_op.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/gather_op_cpu_impl.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/types.h" @@ -92,81 +93,13 @@ class GatherOp : public OpKernel { namespace functor { -// Helper method to copy using memcpy. -template <typename T, typename Index, typename SliceIndex, - SliceIndex static_slice_elems> -SliceIndex HandleCopies(typename TTypes<T>::ConstMatrix params, - typename TTypes<Index>::ConstFlat indices, - SliceIndex slice_elems, - typename TTypes<T>::Matrix out) { - const SliceIndex first_dim_size = - static_cast<SliceIndex>(indices.dimension(0)); - const Index limit = static_cast<Index>(params.dimension(0)); - T* out_base = &out(0, 0); - const T* params_base = ¶ms(0, 0); - if (static_slice_elems >= 0) { - // Give compiler static knowledge of the number of elements/bytes - CHECK_EQ(static_slice_elems, slice_elems); - slice_elems = static_slice_elems; - } - // Compute slice_bytes here so that static knowledge is available - const size_t slice_bytes = slice_elems * sizeof(T); - for (SliceIndex i = 0; i < first_dim_size; i++) { - const SliceIndex j = i + 1; - if (j < first_dim_size) { - port::prefetch<port::PREFETCH_HINT_T0>(¶ms(indices(j), 0)); - port::prefetch<port::PREFETCH_HINT_T0>(&out(j, 0)); - } - // Grab the index and check its validity. An earlier version of the - // code checked it and then grabbed it from memory a second time, which - // was a security risk since it could have changed in between. - const Index index = internal::SubtleMustCopy(indices(i)); - if (!FastBoundsCheck(index, limit)) return i; - // Copy using memcpy if possible, otherwise an Eigen loop - if (Allocator::is_simple<T>::value) { - memcpy(out_base + i * slice_elems, params_base + index * slice_elems, - slice_bytes); - } else { - out.template chip<0>(i) = params.template chip<0>(index); - } - } - return -1; -} - // Specialization gather functor for CPU. template <typename T, typename Index> struct Gather<CPUDevice, T, Index> { int64 operator()(const CPUDevice& d, typename TTypes<T>::ConstMatrix params, typename TTypes<Index>::ConstFlat indices, typename TTypes<T>::Matrix out) { - const int64 N = indices.size(); - const int64 slice_size = out.size() / N; - int64 bad_i; - - bool use_large = (slice_size > std::numeric_limits<int32>::max() || - params.size() > std::numeric_limits<int32>::max() || - N > std::numeric_limits<int32>::max()); -#define CALL(elems) \ - do { \ - if (use_large) { \ - bad_i = HandleCopies<T, Index, int64, elems>(params, indices, \ - slice_size, out); \ - } else { \ - const int32 small_slice = static_cast<int32>(slice_size); \ - bad_i = HandleCopies<T, Index, int32, elems>(params, indices, \ - small_slice, out); \ - } \ - } while (0) - - if (slice_size == 10) - CALL(10); - else if (slice_size == 20) - CALL(20); - else - CALL(-1); -#undef CALL - - return bad_i; + return GatherCpu<T, Index>()(params, indices, out); } }; } // namespace functor |