diff options
Diffstat (limited to 'tensorflow/compiler/xla/util.h')
-rw-r--r-- | tensorflow/compiler/xla/util.h | 18 |
1 files changed, 13 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 42d5c1d155..31f0c3147e 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -195,16 +195,24 @@ bool IsPermutation(tensorflow::gtl::ArraySlice<int64> permutation, int64 rank); // 2. permutation.size() == input.size(). template <template <typename...> class C, typename T> std::vector<T> Permute(tensorflow::gtl::ArraySlice<int64> permutation, - C<T> input_) { - tensorflow::gtl::ArraySlice<T> input(input_); - CHECK(IsPermutation(permutation, input.size())); - std::vector<T> output(input.size()); + C<T> input) { + tensorflow::gtl::ArraySlice<T> data(input); + CHECK(IsPermutation(permutation, data.size())); + std::vector<T> output(data.size()); for (size_t i = 0; i < permutation.size(); ++i) { - output[permutation[i]] = input[i]; + output[permutation[i]] = data[i]; } return output; } +// Override of the above that works around compile failures with gcc 7.1.1. +// For details see https://github.com/tensorflow/tensorflow/issues/10843 +template <typename T> +std::vector<T> Permute(tensorflow::gtl::ArraySlice<int64> permutation, + const std::vector<T>& input) { + return Permute<std::vector, T>(permutation, input); +} + // Inverts a permutation, i.e., output_permutation[input_permutation[i]] = i. std::vector<int64> InversePermutation( tensorflow::gtl::ArraySlice<int64> input_permutation); |