#ifndef TENSORFLOW_KERNELS_TILE_OPS_H_ #define TENSORFLOW_KERNELS_TILE_OPS_H_ #include "tensorflow/core/platform/port.h" #include "tensorflow/core/framework/tensor_types.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" namespace tensorflow { namespace functor { template struct Tile { void operator()(const Device& d, typename TTypes::Tensor out, typename TTypes::ConstTensor in, const Eigen::array& broadcast_array) const { out.device(d) = in.broadcast(broadcast_array); } }; template struct TileGrad { void operator()(const Device& d, typename TTypes::Tensor out, typename TTypes::ConstTensor in, const Eigen::DSizes& indices, const Eigen::DSizes& sizes, bool first) const { if (first) { out.device(d) = in.slice(indices, sizes); } else { out.device(d) += in.slice(indices, sizes); } } }; template struct ReduceAndReshape { void operator()(const Device& d, typename TTypes::Tensor out, typename TTypes::ConstTensor in, const Eigen::DSizes& reduce_dim, const Eigen::DSizes& reshape_dim) const { out.device(d) = in.sum(reduce_dim).reshape(reshape_dim); } }; } // end namespace functor } // end namespace tensorflow #endif // TENSORFLOW_KERNELS_TILE_OPS_H_