#ifndef TENSORFLOW_KERNELS_CWISE_OPS_H_ #define TENSORFLOW_KERNELS_CWISE_OPS_H_ #include #include #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/tensor_types.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" // The following functors (sign, tanh, sigmoid, etc.) are not defined // by Eigen. When their equivalent are added into the Eigen, we can // replace them using type aliases. namespace Eigen { namespace internal { template struct scalar_sign_op { // TODO(zhifengc): this only works for real types. In theory, // sign(x) = x / |x| works for both real and complex values. EIGEN_EMPTY_STRUCT_CTOR(scalar_sign_op); EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const { return T(x > T(0)) - T(x < T(0)); } }; // TODO(zhifengc): Eigen::internal::pow_impl does not have proper // EIGEN host/device decoration. We duplicate code here for now. template struct pow { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, const T& y) const { return std::pow(x, y); } }; template struct pow { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(T x, T y) const { T res(1); if (y & 1) res *= x; y >>= 1; while (y) { x *= x; if (y & 1) res *= x; y >>= 1; } return res; } }; template struct scalar_pow2_op : pow::IsInteger> {}; template struct functor_traits > { enum { Cost = 5 * NumTraits::MulCost, PacketAccess = false, }; }; template struct scalar_fmod2_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_fmod2_op) EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a, const T& b) const { return fmod(a, b); } }; template struct scalar_mod2_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_mod2_op) EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a, const T& b) const { return a % b; } }; template struct functor_traits > { enum { Cost = 5, // Roughly the cost of a div PacketAccess = false, }; }; // scalar_left and scalar_right are template helpers to partially // apply a binary function. // // Suppose Binary is a binary functor f(x, y), scalar_left<> is a // unary functor g_x(y) = f(x, y), where x is provided via the // constructor. Similarly, scalar_right<> is a unary functor g_y(x) = // f(x, y). template ::PacketAccess> struct scalar_left { typedef Tout result_type; const Tin* left; EIGEN_DEVICE_FUNC inline scalar_left( const scalar_left& other) // NOLINT(runtime/explicit) : left(other.left) {} EIGEN_DEVICE_FUNC inline explicit scalar_left(const Tin* c) : left(c) {} EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& right) const { return Binary()(*left, right); } }; template struct scalar_left { typedef Tout result_type; const Tin* left; EIGEN_DEVICE_FUNC inline scalar_left( const scalar_left& other) // NOLINT(runtime/explicit) : left(other.left) {} EIGEN_DEVICE_FUNC inline explicit scalar_left(const Tin* c) : left(c) {} EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& right) const { return Binary()(*left, right); } template EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& right_packet) const { const Packet left_packet = Eigen::internal::pset1(*left); return Binary().packetOp(left_packet, right_packet); } }; template struct functor_traits > { enum { Cost = functor_traits::Cost, PacketAccess = functor_traits::PacketAccess, }; }; template ::PacketAccess> struct scalar_right { typedef Tout result_type; const Tin* right; EIGEN_DEVICE_FUNC inline scalar_right( const scalar_right& other) // NOLINT(runtime/explicit) : right(other.right) {} EIGEN_DEVICE_FUNC inline explicit scalar_right(const Tin* c) : right(c) {} EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& left) const { return Binary()(left, *right); } }; template struct scalar_right { typedef Tout result_type; const Tin* right; EIGEN_DEVICE_FUNC inline scalar_right( const scalar_right& other) // NOLINT(runtime/explicit) : right(other.right) {} EIGEN_DEVICE_FUNC inline explicit scalar_right(const Tin* c) : right(c) {} EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& left) const { return Binary()(left, *right); } template EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& left_packet) const { const Packet right_packet = Eigen::internal::pset1(*right); return Binary().packetOp(left_packet, right_packet); } }; template struct functor_traits > { enum { Cost = functor_traits::Cost, PacketAccess = functor_traits::PacketAccess, }; }; // similar to std::equal_to, but with the DEVICE_FUNC qualifier template struct equal_to : std::binary_function { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, const T& y) const { return x == y; } }; // similar to std::not_equal_to, but with the DEVICE_FUNC qualifier template struct not_equal_to : std::binary_function { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, const T& y) const { return x != y; } }; // similar to std::greater, but with the DEVICE_FUNC qualifier template struct greater : std::binary_function { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, const T& y) const { return x > y; } }; // similar to std::less, but with the DEVICE_FUNC qualifier template struct less : std::binary_function { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, const T& y) const { return x < y; } }; // similar to std::greater_equal, but with the DEVICE_FUNC qualifier template struct greater_equal : std::binary_function { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, const T& y) const { return x >= y; } }; // similar to std::less_equal, but with the DEVICE_FUNC qualifier template struct less_equal : std::binary_function { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, const T& y) const { return x <= y; } }; } // end namespace internal } // end namespace Eigen namespace tensorflow { namespace functor { //////////////////////////////////////////////////////////////////////////////// // Helpers //////////////////////////////////////////////////////////////////////////////// // Base template for functors whose input scalar type is T and // output scalar type is R. template struct base { // func defines operator() and its vectorized version packetOp(). typedef F func; // If true, the functor's corresponding binary op will instantiate // specialized kernels to perform an optimized broadcast // operation. Each functor for which this is enabled increases the // code size, so by default this is disabled for binary functors and // is enabled on a per-op basis as needed. static const bool use_bcast_optimization = false; // operator() has the signature: // out_type operator()(in_type in0, in_type in1 ...) typedef R out_type; typedef T in_type; // TensorFlow provides tensor-ized version of "func". Roughly // speaking, the tensorflow operation has the signature: // tout_type op(tin_type in0) // tout_type op(tin_type in0, tin_type in1) // tout_type op(tin_type in0, in_type scalar) typedef typename TTypes::Flat tout_type; typedef typename TTypes::ConstFlat tin_type; typedef typename TTypes::ConstScalar tscalar_type; }; // For now, we only apply certain speed optimization for // float/double's broadcast binary op. template struct use_bcast_optimization { static const bool value = false; }; template <> struct use_bcast_optimization { static const bool value = true; }; template <> struct use_bcast_optimization { static const bool value = true; }; //////////////////////////////////////////////////////////////////////////////// // Unary functors //////////////////////////////////////////////////////////////////////////////// // abs(x) = |x| // neg(x) = - x // inverse(x) = 1 / x // square(x) = x^2 // sqrt(x) = x^(1/2) // rsqrt(x) = x^(-1/2) // exp(x) = e^x // log(x) = natural logrithm of x // tanh = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) // sigmoid = 1 / (1 + exp(-x)) // a.k.a, logistic // // NOTE: We may eventually implement common functions used in NN // here. E.g., rectifier, softplus, derivatives of tanh, sigmod, etc. // For reference, see speech/lstm/eigen_functors.h. template struct abs : base, typename Eigen::internal::scalar_abs_op::result_type> {}; template struct neg : base > {}; template struct inverse : base > {}; template struct square : base > {}; template struct sqrt : base > {}; template struct rsqrt : base > {}; template struct exp : base > {}; template struct log : base > {}; template struct sign : base > {}; template struct tanh : base > {}; template struct sigmoid : base > {}; template struct sin : base > {}; template struct cos : base > {}; struct logical_not : base > {}; namespace impl { #ifndef __CUDACC__ // Uses STL std cmath functions. template bool isinf(T v) { return std::isinf(v); } template bool isnan(T v) { return std::isnan(v); } template bool isfinite(T v) { return std::isfinite(v); } template T floor(T v) { return std::floor(v); } template T ceil(T v) { return std::ceil(v); } #else // Uses CUDA's functions for float and double. template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool isinf(T v) { return ::isinf(v); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool isnan(T v) { return ::isnan(v); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool isfinite(T v) { return ::isfinite(v); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T floor(T v) { return ::floor(v); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T ceil(T v) { return ::ceil(v); } #endif } // end namespace impl // NOTE: std::isinf, std::isnan, std::isfinite are plain function. // Therefore we need to wrap them in functors to be used with Eigen's // type system. template struct isinf_func { typedef bool result_type; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(T x) const { return impl::isinf(x); } }; template struct isinf : base, bool> {}; template struct isnan_func { typedef bool result_type; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(T x) const { return impl::isnan(x); } }; template struct isnan : base, bool> {}; template struct isfinite_func { typedef bool result_type; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(T x) const { return impl::isfinite(x); } }; template struct isfinite : base, bool> {}; template struct floor_func { typedef T result_type; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(T x) const { return impl::floor(x); } }; template struct floor : base > {}; template struct ceil_func { typedef T result_type; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(T x) const { return impl::ceil(x); } }; template struct ceil : base > {}; //////////////////////////////////////////////////////////////////////////////// // Binary functors //////////////////////////////////////////////////////////////////////////////// // Binary functors: // // add(x, y) = x + y // sub(x, y) = x - y // mul(x, y) = x * y // div(x, y) = x / y // mod(x, y) = x % y (int32 and int64 only) // fmod(x, y) = fmod(x, y) (float and double only) // pow(x, y) = x ^ y // maximum(x, y) = x > y ? x : y // minimum(x, y) = x < y ? x : y template struct add : base > { static const bool use_bcast_optimization = true; }; template struct sub : base > { static const bool use_bcast_optimization = true; }; template struct mul : base > {}; template struct div : base > {}; template struct fmod : base > {}; template struct mod : base > {}; template struct pow : base > {}; template struct maximum : base > {}; template struct minimum : base > {}; template struct less : base, bool> {}; template struct less_equal : base, bool> {}; template struct greater : base, bool> {}; template struct greater_equal : base, bool> {}; template struct equal_to : base, bool> {}; template struct not_equal_to : base, bool> {}; struct logical_and : base {}; struct logical_or : base {}; template struct make_complex_func { typedef std::complex result_type; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(T real, T imag) const { return std::complex(real, imag); } }; template struct make_complex : base, std::complex > {}; template struct get_real : base, typename T::value_type> {}; template struct get_imag : base, typename T::value_type> {}; template struct conj : base > {}; //////////////////////////////////////////////////////////////////////////////// // Functors takes 1 or 2 tensors, computes the base functor on // coefficient of the input tensors and puts the results in the output // tensor. //////////////////////////////////////////////////////////////////////////////// template struct UnaryFunctor { // Computes on device "d": out[i] = Functor(in[i]) void operator()(const Device& d, typename Functor::tout_type out, typename Functor::tin_type in); }; template struct BinaryFunctor { // Computes on device "d": out[i] = Functor(in0[i], in1[i]) void operator()(const Device& d, typename Functor::tout_type out, typename Functor::tin_type in0, typename Functor::tin_type in1); // Computes on device "d": out[i] = Functor(scalar[0], in[i]) void Left(const Device& d, typename Functor::tout_type out, typename Functor::tscalar_type scalar, typename Functor::tin_type in); // Computes on device "d": out[i] = Functor(in[i], scalar[0]) void Right(const Device& d, typename Functor::tout_type out, typename Functor::tin_type in, typename Functor::tscalar_type scalar); // Computes on device "d": // out = Functor(in0.broadcast(bcast0), in1.broadcast(bcast01)) // // TODO(zhifengc): makes BCast a template member function on NDIMS // instead making BinaryFunctor templates on NDIMS. void BCast(const Device& d, typename TTypes::Tensor out, typename TTypes::ConstTensor in0, typename Eigen::array bcast0, typename TTypes::ConstTensor in1, typename Eigen::array bcast1); }; template bool AllOne(const typename Eigen::array& a) { for (int i = 0; i < a.size(); ++i) { if (a[i] != 1) return false; } return true; } template struct SelectFunctor { void operator()(const Device& d, typename TTypes::Flat out, typename TTypes::ConstFlat cond_flat, typename TTypes::ConstFlat then_flat, typename TTypes::ConstFlat else_flat); }; } // end namespace functor } // end namespace tensorflow #endif // TENSORFLOW_KERNELS_CWISE_OPS_H_