aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/cwise_ops.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/cwise_ops.h')
-rw-r--r--tensorflow/core/kernels/cwise_ops.h35
1 files changed, 35 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h
index da70b1e314..06918075a4 100644
--- a/tensorflow/core/kernels/cwise_ops.h
+++ b/tensorflow/core/kernels/cwise_ops.h
@@ -21,6 +21,7 @@ limitations under the License.
#include <type_traits>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/bounds_check.h"
@@ -115,6 +116,35 @@ struct functor_traits<scalar_binary_pow_op_google<Scalar, Exponent>> {
enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false };
};
+template <typename Scalar, typename Exponent>
+struct safe_scalar_binary_pow_op {
+ static_assert(std::is_integral<Scalar>::value, "Integer type expected");
+ static_assert(std::is_integral<Exponent>::value &&
+ std::is_signed<Exponent>::value,
+ "Signed integer type expected");
+
+ bool* const error;
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_scalar_binary_pow_op(bool* error)
+ : error(error) {}
+
+ EIGEN_DEVICE_FUNC inline Scalar operator()(const Scalar& a,
+ const Exponent& b) const {
+ const Exponent safe_b = tensorflow::internal::SubtleMustCopy(b);
+ if (TF_PREDICT_TRUE(safe_b >= 0)) {
+ return numext::pow(a, safe_b);
+ } else {
+ *error = true;
+ return 0;
+ }
+ }
+};
+
+template <typename Scalar, typename Exponent>
+struct functor_traits<safe_scalar_binary_pow_op<Scalar, Exponent>> {
+ enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false };
+};
+
template <typename T, typename DivOrMod>
struct safe_div_or_mod_op {
static_assert(std::is_integral<T>::value, "Integer type expected");
@@ -742,6 +772,11 @@ template <typename T>
struct pow : base<T, Eigen::internal::scalar_binary_pow_op_google<T, T>> {};
template <typename T>
+struct safe_pow : base<T, Eigen::internal::safe_scalar_binary_pow_op<T, T>> {
+ static const bool has_errors = true;
+};
+
+template <typename T>
struct maximum : base<T, Eigen::internal::scalar_max_op<T>> {};
template <typename T>