/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_H_
#define TENSORFLOW_CORE_KERNELS_CWISE_OPS_H_

#include <cmath>
#include <functional>
#include <type_traits>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/bounds_check.h"

namespace Eigen {
namespace internal {

template <typename T>
struct scalar_asinh_op {
  EIGEN_EMPTY_STRUCT_CTOR(scalar_asinh_op)
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a) const {
#if EIGEN_HAS_CXX11_MATH
    return numext::asinh(a);
#else
    return std::asinh(a);
#endif  // EIGEN_HAS_CXX11_MATH
  }
};
template <typename T>
struct functor_traits<scalar_asinh_op<T>> {
  enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false };
};

template <typename T>
struct scalar_acosh_op {
  EIGEN_EMPTY_STRUCT_CTOR(scalar_acosh_op)
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a) const {
#if EIGEN_HAS_CXX11_MATH
    return numext::acosh(a);
#else
    return std::acosh(a);
#endif  // EIGEN_HAS_CXX11_MATH
  }
};
template <typename T>
struct functor_traits<scalar_acosh_op<T>> {
  enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false };
};

template <typename T>
struct scalar_atanh_op {
  EIGEN_EMPTY_STRUCT_CTOR(scalar_atanh_op)
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a) const {
#if EIGEN_HAS_CXX11_MATH
    return numext::atanh(a);
#else
    return std::atanh(a);
#endif  // EIGEN_HAS_CXX11_MATH
  }
};
template <typename T>
struct functor_traits<scalar_atanh_op<T>> {
  enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false };
};
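
// Usage sketch (illustrative only; the tensor below is hypothetical): these
// unary functors can be applied to an Eigen expression via unaryExpr().
// functor_traits<> advertises an approximate cost and, here, the absence of a
// vectorized packetOp() (PacketAccess = false), which Eigen's evaluators
// consult when deciding how to vectorize the expression.
//
//   Eigen::Tensor<float, 1> x(3);
//   x.setValues({0.f, 1.f, 2.f});
//   Eigen::Tensor<float, 1> y =
//       x.unaryExpr(Eigen::internal::scalar_asinh_op<float>());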

// TODO(rmlarsen): This is a workaround for upstream change
// https://bitbucket.org/eigen/eigen/commits/f339468d04d0f87caeb6cab9aef568627e9f6ea9
// that renamed scalar_binary_pow_op to scalar_pow_op and deleted the unary
// version of the latter. Remove once we upgrade to Eigen 3.3.
template <typename Scalar, typename Exponent>
struct scalar_binary_pow_op_google {
  EIGEN_EMPTY_STRUCT_CTOR(scalar_binary_pow_op_google)
  EIGEN_DEVICE_FUNC inline Scalar operator()(const Scalar& a,
                                             const Exponent& b) const {
    return numext::pow(a, b);
  }
};

template <typename Scalar, typename Exponent>
struct functor_traits<scalar_binary_pow_op_google<Scalar, Exponent>> {
  enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false };
};

template <typename Scalar, typename Exponent>
struct safe_scalar_binary_pow_op {
  static_assert(std::is_integral<Scalar>::value, "Integer type expected");
  static_assert(std::is_integral<Exponent>::value &&
                    std::is_signed<Exponent>::value,
                "Signed integer type expected");

  bool* const error;

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_scalar_binary_pow_op(bool* error)
      : error(error) {}

  EIGEN_DEVICE_FUNC inline Scalar operator()(const Scalar& a,
                                             const Exponent& b) const {
    const Exponent safe_b = tensorflow::internal::SubtleMustCopy(b);
    if (TF_PREDICT_TRUE(safe_b >= 0)) {
      return numext::pow(a, safe_b);
    } else {
      *error = true;
      return 0;
    }
  }
};

template <typename Scalar, typename Exponent>
struct functor_traits<safe_scalar_binary_pow_op<Scalar, Exponent>> {
  enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false };
};

template <typename T, typename DivOrMod>
struct safe_div_or_mod_op {
  static_assert(std::is_integral<T>::value, "Integer type expected");

  bool* const error;

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_div_or_mod_op(bool* error)
      : error(error) {}

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a,
                                                           const T& b) const {
    const T safe_b = tensorflow::internal::SubtleMustCopy(b);
    if (TF_PREDICT_TRUE(safe_b != 0)) {
      return DivOrMod()(a, safe_b);
    } else {
      *error = true;
      return 0;
    }
  }
};

template <typename T, typename DivOrMod>
struct functor_traits<safe_div_or_mod_op<T, DivOrMod>> {
  enum {
    Cost = functor_traits<DivOrMod>::Cost + NumTraits<T>::AddCost,
    PacketAccess = false,
  };
};

template <typename T>
struct div_no_nan_op {
  EIGEN_EMPTY_STRUCT_CTOR(div_no_nan_op)
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a,
                                                           const T& b) const {
    if (b != 0) {
      return scalar_quotient_op<T>()(a, b);
    } else {
      return 0;
    }
  }
};

template <typename T>
struct functor_traits<div_no_nan_op<T>> {
  enum {
    Cost = functor_traits<scalar_quotient_op<T>>::Cost + NumTraits<T>::AddCost,
    PacketAccess = false,
  };
};

// scalar_left and scalar_right are template helpers to partially
// apply a binary function.
//
// Suppose Binary is a binary functor f(x, y), scalar_left<> is a
// unary functor g_x(y) = f(x, y), where x is provided via the
// constructor. Similarly, scalar_right<> is a unary functor g_y(x) =
// f(x, y).

template <typename Tout, typename Tin, typename Binary>
struct scalar_left : private Binary {
  typedef Tout result_type;
  const Tin* left;

  EIGEN_DEVICE_FUNC inline scalar_left(const scalar_left& other) = default;

  template <typename... Args>
  EIGEN_DEVICE_FUNC inline explicit scalar_left(const Tin* c, Args... args)
      : Binary(args...), left(c) {}

  EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& right) const {
    return Binary::operator()(*left, right);
  }

  template <typename Packet>
  EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& right_packet) const {
    const Packet left_packet = Eigen::internal::pset1<Packet>(*left);
    return Binary::packetOp(left_packet, right_packet);
  }
};

template <typename Tout, typename Tin, typename Binary>
struct functor_traits<scalar_left<Tout, Tin, Binary>> {
  enum {
    Cost = functor_traits<Binary>::Cost,
    PacketAccess = functor_traits<Binary>::PacketAccess,
  };
};

template <typename Tout, typename Tin, typename Binary>
struct scalar_right : private Binary {
  typedef Tout result_type;
  const Tin* right;

  EIGEN_DEVICE_FUNC inline scalar_right(const scalar_right& other) = default;

  template <typename... Args>
  EIGEN_DEVICE_FUNC inline explicit scalar_right(const Tin* c, Args... args)
      : Binary(args...), right(c) {}

  EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& left) const {
    return Binary::operator()(left, *right);
  }

  template <typename Packet>
  EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& left_packet) const {
    const Packet right_packet = Eigen::internal::pset1<Packet>(*right);
    return Binary::packetOp(left_packet, right_packet);
  }
};

template <typename Tout, typename Tin, typename Binary>
struct functor_traits<scalar_right<Tout, Tin, Binary>> {
  enum {
    Cost = functor_traits<Binary>::Cost,
    PacketAccess = functor_traits<Binary>::PacketAccess,
  };
};
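
// Usage sketch (illustrative only; the names below are hypothetical):
// scalar_right binds the right-hand operand of a binary functor, yielding a
// unary functor that Eigen can map over a tensor. Given an
// Eigen::Tensor<float, 1> x:
//
//   float divisor = 2.f;
//   Eigen::internal::scalar_right<float, float,
//                                 Eigen::internal::scalar_quotient_op<float>>
//       div_by_two(&divisor);
//   Eigen::Tensor<float, 1> halved = x.unaryExpr(div_by_two);  // x / 2
//
// scalar_left works the same way but binds the left-hand operand instead.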

// similar to std::equal_to, but with the DEVICE_FUNC qualifier
template <typename T>
struct equal_to : std::binary_function<T, T, bool> {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
                                                        const T& y) const {
    return x == y;
  }
};

// similar to std::not_equal_to, but with the DEVICE_FUNC qualifier
template <typename T>
struct not_equal_to : std::binary_function<T, T, bool> {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
                                                        const T& y) const {
    return x != y;
  }
};

// similar to std::greater, but with the DEVICE_FUNC qualifier
template <typename T>
struct greater : std::binary_function<T, T, bool> {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
                                                        const T& y) const {
    return x > y;
  }
};

// similar to std::less, but with the DEVICE_FUNC qualifier
template <typename T>
struct less : std::binary_function<T, T, bool> {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
                                                        const T& y) const {
    return x < y;
  }
};

// similar to std::greater_equal, but with the DEVICE_FUNC qualifier
template <typename T>
struct greater_equal : std::binary_function<T, T, bool> {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
                                                        const T& y) const {
    return x >= y;
  }
};

// similar to std::less_equal, but with the DEVICE_FUNC qualifier
template <typename T>
struct less_equal : std::binary_function<T, T, bool> {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
                                                        const T& y) const {
    return x <= y;
  }
};

// Functor that enables composition of multiple Eigen functors.
template <typename Scalar, typename UnaryFunctor, typename BinaryFunctor>
struct scalar_compose_op {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(
      const Scalar& a, const Scalar& b) const {
    return UnaryFunctor()(BinaryFunctor()(a, b));
  }
  template <typename Packet>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(
      const Packet& a, const Packet& b) const {
    return UnaryFunctor().packetOp(BinaryFunctor().packetOp(a, b));
  }
};

template <typename Scalar, typename UnaryFunctor, typename BinaryFunctor>
struct functor_traits<scalar_compose_op<Scalar, UnaryFunctor, BinaryFunctor>> {
  enum {
    Cost = functor_traits<UnaryFunctor>::Cost +
           functor_traits<BinaryFunctor>::Cost,
    PacketAccess = functor_traits<UnaryFunctor>::PacketAccess &&
                   functor_traits<BinaryFunctor>::PacketAccess
  };
};

// TODO(b/32239616): This kernel should be moved into Eigen and vectorized.
template <typename T, typename Enable = void>
struct google_floor_div {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
                                                           const T& y) const {
    if ((x < T(0)) != (y < T(0))) {
      T abs_x = std::abs(x);
      T abs_y = std::abs(y);
      return -(abs_x + abs_y - 1) / abs_y;
    } else {
      return x / y;
    }
  }
};

template <typename T>
struct google_floor_div<
    T, typename std::enable_if<std::is_unsigned<T>::value>::type> {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
                                                           const T& y) const {
    return x / y;
  }
};

template <typename T, typename Enable>
struct functor_traits<google_floor_div<T, Enable>> {
  enum {
    Cost = 2 * Eigen::internal::scalar_div_cost<T, false>::value +
           2 * NumTraits<T>::AddCost,
    PacketAccess = false
  };
};

// TODO(b/32239616): This kernel should be moved into Eigen and vectorized.
template <typename T>
struct google_floor_div_real {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
                                                           const T& y) const {
    return Eigen::numext::floor(x / y);
  }
};

template <typename T>
struct functor_traits<google_floor_div_real<T>> {
  enum {
    Cost = 2 * Eigen::internal::scalar_div_cost<T, false>::value +
           2 * NumTraits<T>::AddCost,
    PacketAccess = false
  };
};
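
// Illustrative note (values chosen for the example): google_floor_div
// implements Python-style floor division for signed integers, rounding toward
// negative infinity rather than toward zero:
//
//   google_floor_div<int>()(7, 2)    //  3
//   google_floor_div<int>()(-7, 2)   // -4  (truncating division would give -3)
//
// google_floor_div_real above applies the same convention to floating-point
// types via floor(x / y).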

// TODO(b/32239616): This kernel should be moved into Eigen and vectorized.
template <typename T>
struct google_floor_fmod {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
                                                           const T& y) const {
    // EIGEN_STATIC_ASSERT(NUMERIC_TYPE_MUST_BE_REAL);
    T trunc_mod = std::fmod(x, y);
    return (x < T(0)) == (y < T(0)) ? trunc_mod : std::fmod(trunc_mod + y, y);
  }
};

template <typename T>
struct functor_traits<google_floor_fmod<T>> {
  enum {
    Cost = 2 * Eigen::internal::scalar_div_cost<T, false>::value +
           2 * NumTraits<T>::AddCost,
    PacketAccess = false
  };
};

// TODO(b/32239616): This kernel should be moved into Eigen and vectorized.
template <typename T>
struct google_floor_mod {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
                                                           const T& y) const {
    // EIGEN_STATIC_ASSERT(!NUMERIC_TYPE_MUST_BE_REAL);
    T trunc_mod = x % y;
    return (x < T(0)) == (y < T(0)) ? trunc_mod : (trunc_mod + y) % y;
  }
};

template <typename T>
struct functor_traits<google_floor_mod<T>> {
  enum {
    Cost = 2 * Eigen::internal::scalar_div_cost<T, false>::value +
           2 * NumTraits<T>::AddCost,
    PacketAccess = false
  };
};

#if EIGEN_COMP_GNUC && __cplusplus > 199711L
#define DISABLE_FLOAT_EQUALITY_WARNING \
  _Pragma("GCC diagnostic push")       \
      _Pragma("GCC diagnostic ignored \"-Wfloat-equal\"")
#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop")
#else
#define DISABLE_FLOAT_EQUALITY_WARNING
#define ENABLE_FLOAT_EQUALITY_WARNING
#endif

template <typename Scalar>
struct scalar_round_op_google {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(
      const Scalar& x) const {
    EIGEN_STATIC_ASSERT((!NumTraits<Scalar>::IsComplex),
                        NUMERIC_TYPE_MUST_BE_REAL)

    // Rounds half to even (banker's rounding): ties at .5 go to the nearest
    // even integer.
    Scalar round_val = Eigen::numext::floor(x);
    const Scalar fraction = x - round_val;
    if (fraction > Scalar(.5)) {
      round_val += Scalar(1.0);
    } else if (fraction == Scalar(.5)) {
      const Scalar nearest_even_int =
          round_val - Scalar(2) * Eigen::numext::floor(Scalar(.5) * x);
      bool is_odd = (nearest_even_int == Scalar(1));
      if (is_odd) {
        round_val += Scalar(1);
      }
    }
    return round_val;
  }
};

template <typename Scalar>
struct functor_traits<scalar_round_op_google<Scalar>> {
  enum { Cost = 4 * NumTraits<Scalar>::AddCost, PacketAccess = false };
};

#undef ENABLE_FLOAT_EQUALITY_WARNING
#undef DISABLE_FLOAT_EQUALITY_WARNING

template <typename Scalar>
struct bitwise_xor_op {
  EIGEN_EMPTY_STRUCT_CTOR(bitwise_xor_op)
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(
      const Scalar& x, const Scalar& y) const {
    return x ^ y;
  }
  typedef typename Eigen::internal::packet_traits<Scalar>::type Packet;
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(
      const Packet& a, const Packet& b) const {
    return Eigen::internal::pxor(a, b);
  }
};

template <typename Scalar>
struct functor_traits<bitwise_xor_op<Scalar>> {
  enum { Cost = Eigen::NumTraits<Scalar>::AddCost, PacketAccess = true };
};

// TODO(srvasude): Add packet versions of this operation.
template <typename Scalar>
struct xlogy_op {
  EIGEN_EMPTY_STRUCT_CTOR(xlogy_op)
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(
      const Scalar& x, const Scalar& y) const {
    if (x == Scalar(0.)) {
      return Scalar(0.);
    }
    return x * numext::log(y);
  }
};

template <typename Scalar>
struct functor_traits<xlogy_op<Scalar>> {
  enum {
    Cost = (sizeof(Scalar) == 4 ? 40 : 85) + Eigen::NumTraits<Scalar>::MulCost,
    PacketAccess = false
  };
};
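
// Illustrative note (values chosen for the example): xlogy_op follows the
// convention that x * log(y) is 0 whenever x == 0, even when y == 0 and
// log(y) alone would be -inf:
//
//   Eigen::internal::xlogy_op<double>()(0.0, 0.0)  // 0.0
//   Eigen::internal::xlogy_op<double>()(2.0, 1.0)  // 0.0  (2 * log(1))
//
// The xdivy_op defined next applies the same convention to x / y.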

template <typename Scalar>
// TODO(srvasude): Add packet versions of this operation.
struct xdivy_op {
  EIGEN_EMPTY_STRUCT_CTOR(xdivy_op)
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(
      const Scalar& x, const Scalar& y) const {
    if (x == Scalar(0.)) {
      return Scalar(0.);
    }
    return x / y;
  }
};

template <typename Scalar>
struct functor_traits<xdivy_op<Scalar>> {
  enum { Cost = Eigen::NumTraits<Scalar>::MulCost, PacketAccess = false };
};

}  // end namespace internal
}  // end namespace Eigen

namespace tensorflow {
namespace functor {

////////////////////////////////////////////////////////////////////////////////
// Helpers
////////////////////////////////////////////////////////////////////////////////

// Base template for functors whose input scalar type is T and
// output scalar type is R.
template <typename T, typename F, typename R = T>
struct base {
  // func defines operator() and its vectorized version packetOp().
  typedef F func;

  // If true, the functor's corresponding binary op will instantiate
  // specialized kernels to perform an optimized broadcast
  // operation. Each functor for which this is enabled increases the
  // code size, so by default this is disabled for binary functors and
  // is enabled on a per-op basis as needed.
  static const bool use_bcast_optimization = false;

  // operator() has the signature:
  //   out_type operator()(in_type in0, in_type in1 ...)
  typedef R out_type;
  typedef T in_type;

  // TensorFlow provides tensor-ized versions of "func". Roughly
  // speaking, the tensorflow operation has the signature:
  //   tout_type op(tin_type in0)
  //   tout_type op(tin_type in0, tin_type in1)
  //   tout_type op(tin_type in0, in_type scalar)
  typedef typename TTypes<out_type>::Flat tout_type;
  typedef typename TTypes<in_type>::ConstFlat tin_type;
  typedef typename TTypes<in_type>::ConstScalar tscalar_type;

  // Whether the functor can error out. Currently applies only to integer
  // div and mod.
  static const bool has_errors = false;
};

// For now, we only apply certain speed optimization for
// float/double's broadcast binary op.
template <typename T>
struct use_bcast_optimization {
  static const bool value = false;
};

template <>
struct use_bcast_optimization<float> {
  static const bool value = true;
};

template <>
struct use_bcast_optimization<double> {
  static const bool value = true;
};
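
// Usage sketch (illustrative only; a rough description, not a definitive
// contract): a concrete functor such as functor::sqrt<float>, defined below,
// derives from base<> and therefore exposes the typedefs a cwise kernel needs:
//
//   using F = tensorflow::functor::sqrt<float>;
//   // F::func     is Eigen::internal::scalar_sqrt_op<float>
//   // F::in_type  is float, F::out_type is float
//   // F::tin_type / F::tout_type are flat Eigen tensor maps, so an
//   // elementwise evaluation is approximately:
//   //   out.device(d) = in.unaryExpr(typename F::func());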

////////////////////////////////////////////////////////////////////////////////
// Unary functors
////////////////////////////////////////////////////////////////////////////////

// abs(x) = |x|
// neg(x) = - x
// inverse(x) = 1 / x
// square(x) = x^2
// sqrt(x) = x^(1/2)
// rsqrt(x) = x^(-1/2)
// exp(x) = e^x
// expm1(x) = e^x - 1
// log(x) = natural logarithm of x
// log1p(x) = natural logarithm of 1 + x
// tanh = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
// sigmoid = 1 / (1 + exp(-x))  // a.k.a, logistic
//
// NOTE: We may eventually implement common functions used in NN
// here. E.g., rectifier, softplus, derivatives of tanh, sigmoid, etc.
// For reference, see speech/lstm/eigen_functors.h.

template <typename T>
struct abs : base<T, Eigen::internal::scalar_abs_op<T>,
                  typename Eigen::internal::scalar_abs_op<T>::result_type> {};

template <typename T>
struct neg : base<T, Eigen::internal::scalar_opposite_op<T>> {};

template <typename T>
struct inverse : base<T, Eigen::internal::scalar_inverse_op<T>> {};

template <typename T>
struct square : base<T, Eigen::internal::scalar_square_op<T>> {};

template <typename T>
struct sqrt : base<T, Eigen::internal::scalar_sqrt_op<T>> {};

template <typename T>
struct rsqrt : base<T, Eigen::internal::scalar_rsqrt_op<T>> {};

template <typename T>
struct exp : base<T, Eigen::internal::scalar_exp_op<T>> {};

template <typename T>
struct expm1 : base<T, Eigen::internal::scalar_expm1_op<T>> {};

template <typename T>
struct log : base<T, Eigen::internal::scalar_log_op<T>> {};

template <typename T>
struct log1p : base<T, Eigen::internal::scalar_log1p_op<T>> {};

template <typename T>
struct sign : base<T, Eigen::internal::scalar_sign_op<T>> {};

template <typename T>
struct sinh : base<T, Eigen::internal::scalar_sinh_op<T>> {};

template <typename T>
struct cosh : base<T, Eigen::internal::scalar_cosh_op<T>> {};

template <typename T>
struct tanh : base<T, Eigen::internal::scalar_tanh_op<T>> {};

template <typename T>
struct asinh : base<T, Eigen::internal::scalar_asinh_op<T>> {};

template <typename T>
struct acosh : base<T, Eigen::internal::scalar_acosh_op<T>> {};

template <typename T>
struct atanh : base<T, Eigen::internal::scalar_atanh_op<T>> {};

template <typename T>
struct lgamma : base<T, Eigen::internal::scalar_lgamma_op<T>> {};

template <typename T>
struct digamma : base<T, Eigen::internal::scalar_digamma_op<T>> {};

template <typename T>
struct erf : base<T, Eigen::internal::scalar_erf_op<T>> {};

template <typename T>
struct erfc : base<T, Eigen::internal::scalar_erfc_op<T>> {};

template <typename T>
struct sigmoid : base<T, Eigen::internal::scalar_sigmoid_op<T>> {};

template <typename T>
struct sin : base<T, Eigen::internal::scalar_sin_op<T>> {};

template <typename T>
struct cos : base<T, Eigen::internal::scalar_cos_op<T>> {};

template <typename T>
struct tan : base<T, Eigen::internal::scalar_tan_op<T>> {};

template <typename T>
struct asin : base<T, Eigen::internal::scalar_asin_op<T>> {};

template <typename T>
struct acos : base<T, Eigen::internal::scalar_acos_op<T>> {};

template <typename T>
struct atan : base<T, Eigen::internal::scalar_atan_op<T>> {};

template <typename T>
struct bessel_i0e : base<T, Eigen::internal::scalar_i0e_op<T>> {};

template <typename T>
struct bessel_i1e : base<T, Eigen::internal::scalar_i1e_op<T>> {};

struct logical_not
    : base<bool, Eigen::internal::scalar_boolean_not_op<bool>> {};

// Flip all bits. Named invert to be consistent with numpy.
template <typename T>
struct invert_op {
  EIGEN_EMPTY_STRUCT_CTOR(invert_op)
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a) const {
    return ~a;
  }
};

template <typename T>
struct invert : base<T, invert_op<T>> {};

// NOTE: std::isinf, std::isnan, std::isfinite are plain functions.
// Therefore we need to wrap them in functors to be used with Eigen's
// type system.
template <typename T>
struct isinf : base<T, Eigen::internal::scalar_isinf_op<T>, bool> {};

template <typename T>
struct isnan : base<T, Eigen::internal::scalar_isnan_op<T>, bool> {};

template <typename T>
struct isfinite : base<T, Eigen::internal::scalar_isfinite_op<T>, bool> {};

template <typename T>
struct floor : base<T, Eigen::internal::scalar_floor_op<T>> {};

template <typename T>
struct round : base<T, Eigen::internal::scalar_round_op_google<T>> {};

template <typename T>
struct ceil : base<T, Eigen::internal::scalar_ceil_op<T>> {};

/** this should go in Eigen
 * \brief Template functor to compute the round to int value of a scalar
 */
template <typename Scalar>
struct scalar_rint_op {
  EIGEN_EMPTY_STRUCT_CTOR(scalar_rint_op)
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(
      const Scalar& a) const {
#if defined(__CUDACC__)
    return ::rint(a);
#elif defined(__ANDROID__)
    return rint(a);
#else
    return std::rint(a);
#endif
  }
};

template <typename T>
struct rint : base<T, scalar_rint_op<T>> {};
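
// Illustrative note (values chosen for the example): both round (via
// scalar_round_op_google) and rint break ties by rounding half to even
// ("banker's rounding"), matching std::rint under the default rounding mode
// rather than rounding half away from zero:
//
//   functor::round<float>::func()(0.5f)  // 0.0f
//   functor::round<float>::func()(1.5f)  // 2.0f
//   functor::round<float>::func()(2.5f)  // 2.0f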

////////////////////////////////////////////////////////////////////////////////
// Binary functors
////////////////////////////////////////////////////////////////////////////////

// Binary functors:
//
// add(x, y) = x + y
// sub(x, y) = x - y
// mul(x, y) = x * y
// div(x, y) = x / y
// mod(x, y) = x % y          (int32 and int64 only)
// fmod(x, y) = fmod(x, y)    (float and double only)
// pow(x, y) = x ^ y
// maximum(x, y) = x > y ? x : y
// minimum(x, y) = x < y ? x : y
// squared_difference(x, y) = (x - y) * (x - y)

template <typename T>
struct add : base<T, Eigen::internal::scalar_sum_op<T>> {
  static const bool use_bcast_optimization = true;
};

template <typename T>
struct sub : base<T, Eigen::internal::scalar_difference_op<T>> {
  static const bool use_bcast_optimization = true;
};

template <typename T>
struct mul : base<T, Eigen::internal::scalar_product_op<T>> {
  static const bool use_bcast_optimization = true;
};

template <typename T>
struct div : base<T, Eigen::internal::scalar_quotient_op<T>> {};

template <typename T>
struct safe_div : base<T, Eigen::internal::safe_div_or_mod_op<
                              T, Eigen::internal::scalar_quotient_op<T>>> {
  static const bool has_errors = true;
};

template <typename T>
struct div_no_nan : base<T, Eigen::internal::div_no_nan_op<T>> {};

template <typename T>
struct fmod : base<T, Eigen::internal::scalar_fmod_op<T>> {};

template <typename T>
struct mod : base<T, Eigen::internal::scalar_mod2_op<T>> {};

template <typename T>
struct safe_mod : base<T, Eigen::internal::safe_div_or_mod_op<
                              T, Eigen::internal::scalar_mod2_op<T>>> {
  static const bool has_errors = true;
};

template <typename T>
struct floor_fmod : base<T, Eigen::internal::google_floor_fmod<T>> {};

template <typename T>
struct safe_floor_mod : base<T, Eigen::internal::safe_div_or_mod_op<
                                    T, Eigen::internal::google_floor_mod<T>>> {
  static const bool has_errors = true;
};

template <typename T>
struct floor_div : base<T, Eigen::internal::google_floor_div<T>> {};

template <typename T>
struct safe_floor_div : base<T, Eigen::internal::safe_div_or_mod_op<
                                    T, Eigen::internal::google_floor_div<T>>> {
  static const bool has_errors = true;
};

template <typename T>
struct floor_div_real : base<T, Eigen::internal::google_floor_div_real<T>> {};

template <typename T>
struct pow : base<T, Eigen::internal::scalar_binary_pow_op_google<T, T>> {};

template <typename T>
struct safe_pow : base<T, Eigen::internal::safe_scalar_binary_pow_op<T, T>> {
  static const bool has_errors = true;
};

template <typename T>
struct maximum : base<T, Eigen::internal::scalar_max_op<T>> {};

template <typename T>
struct minimum : base<T, Eigen::internal::scalar_min_op<T>> {};

template <typename T>
struct igamma : base<T, Eigen::internal::scalar_igamma_op<T>> {};

template <typename T>
struct random_gamma_grad
    : base<T, Eigen::internal::scalar_gamma_sample_der_alpha_op<T>> {};

template <typename T>
struct igammac : base<T, Eigen::internal::scalar_igammac_op<T>> {};

template <typename T>
struct zeta : base<T, Eigen::internal::scalar_zeta_op<T>> {};

template <typename T>
struct polygamma : base<T, Eigen::internal::scalar_polygamma_op<T>> {};

template <typename Scalar>
struct scalar_atan2_op {
  EIGEN_EMPTY_STRUCT_CTOR(scalar_atan2_op)
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(
      const Scalar& y, const Scalar& x) const {
#if GOOGLE_CUDA
    return ::atan2(y, x);
#else
    return std::atan2(y, x);
#endif
  }
};

template <typename T>
struct atan2 : base<T, scalar_atan2_op<T>> {};

template <typename T>
struct squared_difference
    : base<T, Eigen::internal::scalar_compose_op<
                  T, Eigen::internal::scalar_square_op<T>,
                  Eigen::internal::scalar_difference_op<T>>> {};

template <typename T>
struct xdivy : base<T, Eigen::internal::xdivy_op<T>> {};

template <typename T>
struct xlogy : base<T, Eigen::internal::xlogy_op<T>> {};

template <typename T>
struct less : base<T, Eigen::internal::less<T>, bool> {};

template <typename T>
struct less_equal : base<T, Eigen::internal::less_equal<T>, bool> {};

template <typename T>
struct greater : base<T, Eigen::internal::greater<T>, bool> {};

template <typename T>
struct greater_equal : base<T, Eigen::internal::greater_equal<T>, bool> {};

template <typename T>
struct equal_to : base<T, Eigen::internal::equal_to<T>, bool> {};

template <typename T>
struct not_equal_to : base<T, Eigen::internal::not_equal_to<T>, bool> {};

struct logical_and : base<bool, Eigen::internal::scalar_boolean_and_op> {};

struct logical_or : base<bool, Eigen::internal::scalar_boolean_or_op> {};

template <typename T>
struct bitwise_and_op {
  EIGEN_EMPTY_STRUCT_CTOR(bitwise_and_op)
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
                                                           const T& y) const {
    return x & y;
  }
};

template <typename T>
struct bitwise_or_op {
  EIGEN_EMPTY_STRUCT_CTOR(bitwise_or_op)
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
                                                           const T& y) const {
    return x | y;
  }
};

template <typename T>
struct bitwise_and : base<T, bitwise_and_op<T>> {};

template <typename T>
struct bitwise_or : base<T, bitwise_or_op<T>> {};

template <typename T>
struct bitwise_xor : base<T, Eigen::internal::bitwise_xor_op<T>> {};

template <typename T>
struct left_shift_op {
  EIGEN_EMPTY_STRUCT_CTOR(left_shift_op)
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
                                                           const T& y) const {
    // Avoids UB: don't shift by larger than the bitwidth of T, and
    // performs left shifts as unsigned shifts.
    T y_clamped = y;
    if (y_clamped < 0) {
      y_clamped = 0;
    } else if (y_clamped > sizeof(T) * CHAR_BIT - 1) {
      y_clamped = sizeof(T) * CHAR_BIT - 1;
    }
    using U = typename std::make_unsigned<T>::type;
    return static_cast<T>(static_cast<U>(x) << static_cast<U>(y_clamped));
  }
};
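
// Illustrative note (values chosen for the example): left_shift_op above and
// right_shift_op below clamp the shift count into [0, bitwidth - 1] instead of
// invoking the undefined behavior of raw C++ << / >> for out-of-range counts:
//
//   left_shift_op<int8>()(1, 2)    // 4
//   left_shift_op<int8>()(1, 10)   // shift count clamped to 7
//   left_shift_op<int8>()(1, -3)   // shift count clamped to 0, result is 1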

template <typename T>
struct right_shift_op {
  EIGEN_EMPTY_STRUCT_CTOR(right_shift_op)
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
                                                           const T& y) const {
    // Avoids UB: don't shift by larger than the bitwidth of T.
    T y_clamped = y;
    if (y_clamped < 0) {
      y_clamped = 0;
    } else if (y_clamped > sizeof(T) * CHAR_BIT - 1) {
      y_clamped = sizeof(T) * CHAR_BIT - 1;
    }
    // Technically right shifts of signed integers are not necessarily
    // arithmetic shifts according to the C++ standard. However in practice
    // most implementations are arithmetic shifts. If this proves to be a
    // problem in practice, we may need to use an alternative implementation.
    return x >> y_clamped;
  }
};

template <typename T>
struct left_shift : base<T, left_shift_op<T>> {};

template <typename T>
struct right_shift : base<T, right_shift_op<T>> {};

template <typename T>
struct make_complex_func {
  typedef std::complex<T> result_type;
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(T real,
                                                               T imag) const {
    return std::complex<T>(real, imag);
  }
};

template <typename T>
struct make_complex : base<T, make_complex_func<T>, std::complex<T>> {};

template <typename T>
struct get_real
    : base<T, Eigen::internal::scalar_real_op<T>, typename T::value_type> {};

template <typename T>
struct get_imag
    : base<T, Eigen::internal::scalar_imag_op<T>, typename T::value_type> {};

template <typename T>
struct get_angle
    : base<T, Eigen::internal::scalar_arg_op<T>, typename T::value_type> {};

template <typename T>
struct conj : base<T, Eigen::internal::scalar_conjugate_op<T>> {};

////////////////////////////////////////////////////////////////////////////////
// Functors that take 1 or 2 tensors, compute the base functor on the
// coefficients of the input tensors, and put the results in the output
// tensor.
////////////////////////////////////////////////////////////////////////////////
template <typename Device, typename Functor>
struct UnaryFunctor {
  // Computes on device "d": out[i] = Functor(in[i])
  void operator()(const Device& d, typename Functor::tout_type out,
                  typename Functor::tin_type in);
};

template <typename Device, typename Functor, int NDIMS,
          bool has_errors = Functor::has_errors>
struct BinaryFunctor {
  // Computes on device "d": out[i] = Functor(in0[i], in1[i])
  void operator()(const Device& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1, bool* error);

  // Computes on device "d": out[i] = Functor(scalar[0], in[i])
  void Left(const Device& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in, bool* error);

  // Computes on device "d": out[i] = Functor(in[i], scalar[0])
  void Right(const Device& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar, bool* error);

  // Computes on device "d":
  //   out = Functor(in0.broadcast(bcast0), in1.broadcast(bcast1))
  //
  // TODO(zhifengc): makes BCast a template member function on NDIMS
  // instead making BinaryFunctor templates on NDIMS.
  void BCast(const Device& d,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
             bool* error);
};

template <typename Device, typename T>
struct ApproximateEqual {
  void operator()(const Device& d, typename TTypes<T>::ConstFlat x,
                  typename TTypes<T>::ConstFlat y, T tolerance,
                  typename TTypes<bool>::Flat z);
};

template <int NDIMS>
bool AllOne(const typename Eigen::array<Eigen::DenseIndex, NDIMS>& a) {
  for (size_t i = 0; i < a.size(); ++i) {
    if (a[i] != 1) return false;
  }
  return true;
}

template <typename Device, typename T>
struct SelectFunctor {
  void operator()(const Device& d, typename TTypes<T>::Flat out,
                  typename TTypes<bool>::ConstFlat cond_flat,
                  typename TTypes<T>::ConstFlat then_flat,
                  typename TTypes<T>::ConstFlat else_flat);
};

template <typename Device, typename T>
struct SelectScalarFunctor {
  void operator()(const Device& d, typename TTypes<T>::Flat out,
                  typename TTypes<bool>::ConstScalar cond,
                  typename TTypes<T>::ConstFlat then_flat,
                  typename TTypes<T>::ConstFlat else_flat);
};

template <typename Device, typename T>
struct BatchSelectFunctor {
  void operator()(const Device& d,
                  typename TTypes<T>::Matrix output_flat_outer_dims,
                  TTypes<bool>::ConstVec cond_vec,
                  typename TTypes<T>::ConstMatrix then_flat_outer_dims,
                  typename TTypes<T>::ConstMatrix else_flat_outer_dims);
};

}  // end namespace functor
}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_CWISE_OPS_H_