/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #ifndef TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_CPU_IMPL_H_ #define TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_CPU_IMPL_H_ // Specialization of GatherNdSlice to CPU #define EIGEN_USE_THREADS #include #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/gather_nd_op.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/util.h" namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; namespace generator { template class GatherNdSliceGenerator { public: EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE GatherNdSliceGenerator( const Index slice_size, typename TTypes::ConstMatrix Tindices, typename TTypes::ConstTensor Tparams, typename TTypes::Matrix Tout, std::atomic* error_loc) : slice_size_(slice_size), Tindices_(Tindices), Tparams_(Tparams), Tout_(Tout), error_loc_(error_loc) {} EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool GenerateIndices( const Index loc, Eigen::array* ix) const { (*ix)[IXDIM] = 0; bool out_of_bounds = false; for (int i = 0; i < IXDIM; ++i) { const Index ix_i = internal::SubtleMustCopy(Tindices_(loc, i)); (*ix)[i] = ix_i; out_of_bounds |= !FastBoundsCheck(ix_i, Tparams_.dimension(i)); } return out_of_bounds; } EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int32 operator()(const Eigen::array& loc_array) const { const Index loc = loc_array[0]; Eigen::array ix; Eigen::array ix_out; ix_out[0] = loc; ix_out[1] = 0; const bool out_of_bounds = GenerateIndices(loc, &ix); if (TF_PREDICT_FALSE(out_of_bounds)) { error_loc_->store(loc); std::fill_n(&Tout_(ix_out), slice_size_, T()); } else { std::copy_n(&Tparams_(ix), slice_size_, &Tout_(ix_out)); } return static_cast(0); // Return something... } private: const Index slice_size_; const typename TTypes::ConstMatrix Tindices_; const typename TTypes::ConstTensor Tparams_; mutable typename TTypes::Matrix Tout_; std::atomic* error_loc_; }; } // namespace generator namespace functor { template struct GatherNdSlice { Index operator()(const CPUDevice& d, const Index slice_size, typename TTypes::Scalar Tscratch, typename TTypes::ConstTensor Tparams, typename TTypes::ConstMatrix Tindices, typename TTypes::Matrix Tout) { std::atomic error_loc(-1); const Eigen::DenseIndex batch_size = Tindices.dimension(0); #if !defined(EIGEN_HAS_INDEX_LIST) Eigen::Tensor::Dimensions reshape_dims{{ 1 }}; Eigen::array broadcast_dims{{ batch_size }}; #else Eigen::IndexList > reshape_dims; Eigen::IndexList broadcast_dims; broadcast_dims.set(0, batch_size); #endif generator::GatherNdSliceGenerator gather_nd_generator( slice_size, Tindices, Tparams, Tout, &error_loc); #if defined(INTEL_MKL) && defined(ENABLE_MKL) // Eigen implementation below is not highly performant. gather_nd_generator // does not seem to be called in parallel, leading to very poor performance. // Additionally, since it uses scalar (Tscratch) to invoke 'generate', it // needs to go through redundant operations like 'reshape', 'broadcast' and // 'sum'. OpenMP loop below essentially does same thing as Eigen code, but // is considerably more efficient. #pragma omp parallel for for (Eigen::DenseIndex i = 0; i < batch_size; i++) { const Eigen::array loc{i}; gather_nd_generator(loc); } #else // INTEL_MKL && ENABLE_MKL Tscratch.device(d) = Tscratch.reshape(reshape_dims) .broadcast(broadcast_dims) .generate(gather_nd_generator) .sum(); #endif // INTEL_MKL && ENABLE_MKL // error_loc() returns -1 if there's no out-of-bounds index, // otherwise it returns the location of an OOB index in Tindices. return error_loc.load(); } }; #define REGISTER_GATHER_ND_FULL(T, Index) \ template Index GatherNdSlice:: \ operator()(const CPUDevice& d, const Index slice_size, \ typename TTypes::Scalar Tscratch, \ typename TTypes::ConstTensor Tparams, \ typename TTypes::ConstMatrix Tindices, \ typename TTypes::Matrix Tout); #define REGISTER_GATHER_ND_CPU(type) \ REGISTER_GATHER_ND_FULL(type, int32); \ REGISTER_GATHER_ND_FULL(type, int64) TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU); } // namespace functor } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_CPU_IMPL_H_