#ifndef TENSORFLOW_KERNELS_DENSE_UPDATE_OPS_H_ #define TENSORFLOW_KERNELS_DENSE_UPDATE_OPS_H_ #include "tensorflow/core/framework/tensor_types.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" namespace tensorflow { enum DenseUpdateType { ADD, SUB, ASSIGN }; namespace functor { template struct DenseUpdate; template struct DenseUpdate { void operator()(const Device& d, typename TTypes::Flat params, typename TTypes::ConstFlat update) { params.device(d) += update; } }; template struct DenseUpdate { void operator()(const Device& d, typename TTypes::Flat params, typename TTypes::ConstFlat update) { params.device(d) -= update; } }; template struct DenseUpdate { void operator()(const Device& d, typename TTypes::Flat params, typename TTypes::ConstFlat update) { params.device(d) = update; } }; } // end namespace functor } // end namespace tensorflow #endif // TENSORFLOW_KERNELS_DENSE_UPDATE_OPS_H_