#ifndef TENSORFLOW_KERNELS_BIAS_OP_H_ #define TENSORFLOW_KERNELS_BIAS_OP_H_ // Functor definition for BiasOp, must be compilable by nvcc. #include "tensorflow/core/framework/tensor_types.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" namespace tensorflow { namespace functor { // Functor used by BiasOp to do the computations. template struct Bias { // Add "bias" to "input", broadcasting it on all dimensions but the last one. void operator()(const Device& d, typename TTypes::ConstTensor input, typename TTypes::ConstVec bias, typename TTypes::Tensor output) { const int bias_size = bias.dimension(0); const int rest_size = input.size() / bias_size; Eigen::DSizes rest_by_bias(rest_size, bias_size); #if !defined(EIGEN_HAS_INDEX_LIST) Eigen::DSizes rest_by_one(rest_size, 1); Eigen::DSizes one_by_bias(1, bias_size); #else Eigen::IndexList > rest_by_one; rest_by_one.set(0, rest_size); Eigen::IndexList, int> one_by_bias; one_by_bias.set(1, bias_size); #endif output.reshape(rest_by_bias).device(d) = input.reshape(rest_by_bias) + bias.reshape(one_by_bias).broadcast(rest_by_one); } }; } // namespace functor } // namespace tensorflow #endif // TENSORFLOW_KERNELS_BIAS_OP_H_