aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAntonio Sanchez <cantonios@google.com>2021-06-25 14:22:19 -0700
committerAntonio Sanchez <cantonios@google.com>2021-06-29 10:36:20 -0700
commit3a087ccb99b454dc34484333e608e836e7032213 (patch)
treec14d2a34f8582ead6d7e9b24ceb3ba980bba3d4f
parent2d132d17365ffc84c0cc7a7da9b8f7090e94b476 (diff)
Modify tensor argmin/argmax to always return first occurence.
As written, depending on multithreading/gpu, the returned index from `argmin`/`argmax` is not currently stable. Here we modify the functors to always keep the first occurence (i.e. if the value is equal to the current min/max, then keep the one with the smallest index). This is otherwise causing unpredictable results in some TF tests.
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h16
1 files changed, 12 insertions, 4 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
index fd8fa00fa..3b2100ab0 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
@@ -365,12 +365,16 @@ struct reducer_traits<OrReducer, Device> {
};
};
-
-// Argmin/Argmax reducers
+// Argmin/Argmax reducers. Returns the first occurrence if multiple locations
+// contain the same min/max value.
template <typename T> struct ArgMaxTupleReducer
{
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const {
- if (t.second > accum->second) { *accum = t; }
+ if (t.second < accum->second) {
+ return;
+ } else if (t.second > accum->second || t.first < accum->first) {
+ *accum = t;
+ }
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
return T(0, NumTraits<typename T::second_type>::lowest());
@@ -394,7 +398,11 @@ struct reducer_traits<ArgMaxTupleReducer<T>, Device> {
template <typename T> struct ArgMinTupleReducer
{
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T& t, T* accum) const {
- if (t.second < accum->second) { *accum = t; }
+ if (t.second > accum->second) {
+ return;
+ } else if (t.second < accum->second || t.first < accum->first) {
+ *accum = t;
+ }
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
return T(0, NumTraits<typename T::second_type>::highest());