aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/layout_util.cc
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2017-09-14 16:06:55 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-14 16:10:49 -0700
commitc82a933f449e637ee83244d2c40162e24cdde0e1 (patch)
tree74e0cb414c7c6bbdd67c548c35efbcab20c3e11a /tensorflow/compiler/xla/layout_util.cc
parentdd22dbc7b9be5a47db91fa28557b911dabf720ec (diff)
Lower vector-matrix dot to LLVM IR if the RHS of the dot can be made
column major. The naive dot lowering to LLVM IR (already present in XLA today) is cache efficient if the dot has LHS of shape [1,K]{1,0} and RHS of shape [K x N]{0,1}. This change teaches the layout assignment pass to exploit this property by converting a constant RHS matrix to a column major layout when possible. Couple of related things I had to touch in this change: - In LayoutAssignmentTest.TupleLayout we used to generate a kCopy to satisfy the conflicting constraints between the result and the constant shapes, but with this change we change the layout of the constants themselves. So the EXPECT_FALSE is now an EXPECT_TRUE. - The extra instruction layout constraints added at the end of CpuLayoutAssignment::AddBackendConstraints seemed redundant. The layout assignment pass already tries to make all unconstrained buffers have the default row-major layout. Moreover, they were blocking this optimization in some cases by introducing conflicting constraints. - The changes to literal_util.h have to be made to deal with the Literal::Relayout calls we now get on literals of various types. PiperOrigin-RevId: 168761204
Diffstat (limited to 'tensorflow/compiler/xla/layout_util.cc')
-rw-r--r--tensorflow/compiler/xla/layout_util.cc6
1 files changed, 6 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc
index 6271b59a5b..011fc3c194 100644
--- a/tensorflow/compiler/xla/layout_util.cc
+++ b/tensorflow/compiler/xla/layout_util.cc
@@ -109,6 +109,12 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
}
}
+/* static */ Shape LayoutUtil::GetWithDefaultLayout(const Shape& shape) {
+ Shape copy(shape);
+ LayoutUtil::SetToDefaultLayout(&copy);
+ return copy;
+}
+
/* static */ void LayoutUtil::SetToDefaultLayout(ProgramShape* program_shape) {
for (auto& parameter_shape : *program_shape->mutable_parameters()) {
LayoutUtil::SetToDefaultLayout(&parameter_shape);