diff options
Diffstat (limited to 'tensorflow/core/kernels/cwise_ops.h')
-rw-r--r-- | tensorflow/core/kernels/cwise_ops.h | 35 |
1 files changed, 35 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h index da70b1e314..06918075a4 100644 --- a/tensorflow/core/kernels/cwise_ops.h +++ b/tensorflow/core/kernels/cwise_ops.h @@ -21,6 +21,7 @@ limitations under the License. #include <type_traits> #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/kernels/bounds_check.h" @@ -115,6 +116,35 @@ struct functor_traits<scalar_binary_pow_op_google<Scalar, Exponent>> { enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false }; }; +template <typename Scalar, typename Exponent> +struct safe_scalar_binary_pow_op { + static_assert(std::is_integral<Scalar>::value, "Integer type expected"); + static_assert(std::is_integral<Exponent>::value && + std::is_signed<Exponent>::value, + "Signed integer type expected"); + + bool* const error; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_scalar_binary_pow_op(bool* error) + : error(error) {} + + EIGEN_DEVICE_FUNC inline Scalar operator()(const Scalar& a, + const Exponent& b) const { + const Exponent safe_b = tensorflow::internal::SubtleMustCopy(b); + if (TF_PREDICT_TRUE(safe_b >= 0)) { + return numext::pow(a, safe_b); + } else { + *error = true; + return 0; + } + } +}; + +template <typename Scalar, typename Exponent> +struct functor_traits<safe_scalar_binary_pow_op<Scalar, Exponent>> { + enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false }; +}; + template <typename T, typename DivOrMod> struct safe_div_or_mod_op { static_assert(std::is_integral<T>::value, "Integer type expected"); @@ -742,6 +772,11 @@ template <typename T> struct pow : base<T, Eigen::internal::scalar_binary_pow_op_google<T, T>> {}; template <typename T> +struct safe_pow : base<T, Eigen::internal::safe_scalar_binary_pow_op<T, T>> { + static const bool has_errors = true; +}; + +template <typename T> struct maximum : base<T, Eigen::internal::scalar_max_op<T>> {}; template <typename T> |