diff options
author | 2017-05-05 18:10:11 -0800 | |
---|---|---|
committer | 2017-05-05 19:32:06 -0700 | |
commit | b04d0985f34b15657cb179731871aee02f138962 (patch) | |
tree | 35a2a973b5a6312c41aa9a7cb8f4f2739cdfab9a /tensorflow/compiler/xla/util.h | |
parent | 87ba9f5370c0f7068760f9536979d9183f6dfe9c (diff) |
[TF:XLA] Optimize the literal transpose operation
Optimize the literal transpose operation by avoiding item by item copies.
Transposing a F32{128, 64, 64, 32} with a {0, 3, 2, 1} permutation, on a Xeon E5-1650 v3, took ~40s before, and ~130ms after.
Made literal Reshape support not MonotonicDim0Major layouts.
Optimized the literal Relayout operation to use the new Copy() operation, and to hence cover all the primitive types.
Added unittest for the LiteralUtil::Populate() API.
Change: 155265178
Diffstat (limited to 'tensorflow/compiler/xla/util.h')
-rw-r--r-- | tensorflow/compiler/xla/util.h | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 236728f417..15a6ef404e 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -177,6 +177,9 @@ Status Unavailable(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); string Reindent(tensorflow::StringPiece original, tensorflow::StringPiece indentation); +// Checks whether permutation is a permutation of the [0, rank) integer range. +bool IsPermutation(tensorflow::gtl::ArraySlice<int64> permutation, int64 rank); + // Applies `permutation` on `input` and returns the permuted array. // For each i, output[permutation[i]] = input[i]. // @@ -187,12 +190,11 @@ template <template <typename...> class C, typename T> std::vector<T> Permute(tensorflow::gtl::ArraySlice<int64> permutation, C<T> input_) { tensorflow::gtl::ArraySlice<T> input(input_); - CHECK_EQ(permutation.size(), input.size()); + CHECK(IsPermutation(permutation, input.size())); std::vector<T> output(input.size()); for (size_t i = 0; i < permutation.size(); ++i) { output[permutation[i]] = input[i]; } - DCHECK(std::is_permutation(input.begin(), input.end(), output.begin())); return output; } |