about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/compiler/xla/util.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-05-05 18:10:11 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-05 19:32:06 -0700
commit b04d0985f34b15657cb179731871aee02f138962 (patch)
tree 35a2a973b5a6312c41aa9a7cb8f4f2739cdfab9a /tensorflow/compiler/xla/util.h
parent 87ba9f5370c0f7068760f9536979d9183f6dfe9c (diff)
[TF:XLA] Optimize the literal transpose operation
Optimize the literal transpose operation by avoiding item-by-item copies. Transposing an F32{128, 64, 64, 32} with a {0, 3, 2, 1} permutation on a Xeon E5-1650 v3 took ~40s before and ~130ms after. Made literal Reshape support layouts that are not MonotonicDim0Major. Optimized the literal Relayout operation to use the new Copy() operation, and hence to cover all the primitive types. Added a unit test for the LiteralUtil::Populate() API. Change: 155265178
Diffstat (limited to 'tensorflow/compiler/xla/util.h')
-rw-r--r-- tensorflow/compiler/xla/util.h 6
1 files changed, 4 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h
index 236728f417..15a6ef404e 100644
--- a/tensorflow/compiler/xla/util.h
+++ b/tensorflow/compiler/xla/util.h
@@ -177,6 +177,9 @@ Status Unavailable(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2);
string Reindent(tensorflow::StringPiece original,
tensorflow::StringPiece indentation);
+// Checks whether permutation is a permutation of the [0, rank) integer range.
+bool IsPermutation(tensorflow::gtl::ArraySlice<int64> permutation, int64 rank);
+
// Applies `permutation` on `input` and returns the permuted array.
// For each i, output[permutation[i]] = input[i].
//
@@ -187,12 +190,11 @@ template <template <typename...> class C, typename T>
std::vector<T> Permute(tensorflow::gtl::ArraySlice<int64> permutation,
C<T> input_) {
tensorflow::gtl::ArraySlice<T> input(input_);
- CHECK_EQ(permutation.size(), input.size());
+ CHECK(IsPermutation(permutation, input.size()));
std::vector<T> output(input.size());
for (size_t i = 0; i < permutation.size(); ++i) {
output[permutation[i]] = input[i];
}
- DCHECK(std::is_permutation(input.begin(), input.end(), output.begin()));
return output;
}