diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc | 569 |
1 files changed, 11 insertions, 558 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 2a447a54b0..e57d49172b 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -25,9 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" -#include "tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -40,450 +38,6 @@ using llvm_ir::SetToFirstInsertPoint; namespace cpu { -namespace { -// Loads a tile of values from a 2D tensor. -class TileLoader { - public: - // Constructs a TileLoader that will load a tile consisting of - // `tile_size_along_major_dim` vectors from the matrix `matrix`, starting at - // `major_dim_offset` in the major dimension. The tile size along the minor - // dimension is the vector size, and that is implicitly determined by `vsl`. - TileLoader(VectorSupportLibrary* vsl, llvm::IRBuilder<>* ir_builder, - llvm::Value* matrix, int64 matrix_size_along_minor_dim, - llvm::Value* major_dim_offset, int64 tile_size_along_major_dim) - : vsl_(vsl) { - pointers_.reserve(tile_size_along_major_dim); - for (int64 i = 0; i < tile_size_along_major_dim; i++) { - llvm::Value* total_offset = ir_builder->CreateMul( - ir_builder->getInt64(matrix_size_along_minor_dim), - ir_builder->CreateAdd(ir_builder->getInt64(i), major_dim_offset)); - pointers_.push_back(vsl_->ComputeOffsetPointer(matrix, total_offset)); - } - } - - // Load a tile consisting of `tile_size_along_major_dim_` vectors starting at - // `major_dim_offset_` in the major dimension and `minor_dim_offset` in the - // minor dimension. - std::vector<llvm::Value*> LoadTile(llvm::Value* minor_dim_offset) const { - std::vector<llvm::Value*> result; - result.reserve(pointers_.size()); - for (const auto& pointer : pointers_) { - result.push_back(vsl_->LoadVector(pointer, minor_dim_offset)); - } - return result; - } - - private: - VectorSupportLibrary* vsl_; - std::vector<llvm::Value*> pointers_; -}; - -// Computes a dot product between "[M,K]{0,1} lhs" with a [K,1] vector (the -// layout of the vector does not matter). This implementation uses a tiling -// scheme to improve performance. -// -// We logically separate the LHS matrix into four segments: -// -// +----------------------+---+ -// | | | -// | | | -// | A | B | -// | | | -// | | | -// | | | -// +----------------------+---+ -// | C | D | -// +----------------------+---+ -// -// where A is the largest submatrix of the LHS that can be evenly dividied into -// tiles. For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have: -// -// +---+---+---+---+ +--+--+--+--+ -// |M00|M10|M20|M30| |V0|V1|V2|V3| -// +---+---+---+---+ +--+--+--+--+ -// |M01|M11|M21|M31| and |V0|V1|V2|V3| -// +---+---+---+---+ +--+--+--+--+ -// |M02|M12|M22|M32| |V0|V1|V2|V3| -// +---+---+---+---+ +--+--+--+--+ -// |M03|M13|M23|M33| |V0|V1|V2|V3| -// +---+---+---+---+ +--+--+--+--+ -// -// (Legend: rows are horizontal and columns are vertical; and each column is one -// llvm::Value of a vector type) -// -// where: -// -// a. The left tile is from the column major left matrix. -// b. The right tile is an elementwise broadcast of a [V0, V1, V2, V3] -// vector loaded from the RHS vector. -// -// As we iterate through the column dimension, we compute the change to the -// result vector by an elementwise multiplication between the two tiles above -// followed by a reduction along the major dimension: -// -// +-----------------------------------+ -// | M00*V0 + M10*V1 + M20*V2 + M30*V3 | -// +-----------------------------------+ -// | M01*V0 + M11*V1 + M21*V2 + M31*V3 | -// Result[R:R+4] += +-----------------------------------+ -// | M02*V0 + M12*V1 + M22*V2 + M32*V3 | -// +-----------------------------------+ -// | M03*V0 + M13*V1 + M23*V2 + M33*V3 | -// +-----------------------------------+ -// -// Where R is the starting row for the tile. -// -// We have an inner epilogue loop to deal with the "C" submatrix and an outer -// epilogue loop to deal with the B,D submarix. -// -// TODO(sanjoy): We should investigate if using gather loads and scatter stores -// can be used here have the same inner loop for both column-major and row-major -// matrix-vector products. -class ColumnMajorMatrixVectorProductEmitter { - public: - ColumnMajorMatrixVectorProductEmitter(PrimitiveType scalar_type, - int64 tile_rows, int64 tile_cols, - int64 m, int64 k, llvm::Value* lhs, - llvm::Value* rhs, llvm::Value* result, - llvm::IRBuilder<>* ir_builder) - : scalar_type_(scalar_type), - tile_rows_(tile_rows), - tile_cols_(tile_cols), - m_(m), - k_(k), - lhs_(lhs), - rhs_(rhs), - result_(result), - ir_builder_(ir_builder), - ksl_(ir_builder_), - vsl_(scalar_type_, /*vector_size=*/tile_rows_, ir_builder_, "") { - CHECK(tile_rows_ > 0 && IsPowerOfTwo(static_cast<uint64>(tile_rows_))); - } - - void Emit(); - - private: - void EmitOuterLoopBody(llvm::Value* column, int64 column_count, - bool is_first_column); - - TileLoader GetLhsTileLoader(llvm::Value* column_start, int64 column_count) { - return TileLoader(&vsl_, ir_builder_, /*matrix=*/lhs_, - /*matrix_size_along_minor_dim=*/m_, - /*major_dim_offset=*/column_start, - /*tile_size_along_major_dim=*/column_count); - } - - // Load a tile of values from the RHS. For the RHS a "tile" is a contiguous - // sequnce of `count` values, each one broadcasted to the vector width. - std::vector<llvm::Value*> LoadRhsTile(llvm::Value* offset, int64 count) { - llvm::Value* base_pointer = vsl_.ComputeOffsetPointer(rhs_, offset); - std::vector<llvm::Value*> result; - result.reserve(count); - for (int64 i = 0; i < count; i++) { - result.push_back(vsl_.LoadBroadcast(base_pointer, i)); - } - return result; - } - - void EmitInnerLoopTiled(TileLoader* lhs_tile_loader, - const std::vector<llvm::Value*>& rhs_tile, - int64 columns, bool is_first_column); - - void EmitInnerLoopEpilogue(llvm::Value* current_tile_col, int64 columns, - bool is_first_tiled_column); - - PrimitiveType scalar_type_; - int64 tile_rows_; - int64 tile_cols_; - int64 m_; - int64 k_; - llvm::Value* lhs_; - llvm::Value* rhs_; - llvm::Value* result_; - llvm::IRBuilder<>* ir_builder_; - KernelSupportLibrary ksl_; - VectorSupportLibrary vsl_; -}; - -void ColumnMajorMatrixVectorProductEmitter::EmitOuterLoopBody( - llvm::Value* column, int64 column_count, bool is_first_column) { - TileLoader lhs_tile_loader = GetLhsTileLoader(/*column_start=*/column, - /*column_count=*/column_count); - - std::vector<llvm::Value*> rhs_tile = - LoadRhsTile(column, /*count=*/column_count); - EmitInnerLoopTiled(&lhs_tile_loader, rhs_tile, - /*columns=*/column_count, is_first_column); - EmitInnerLoopEpilogue(column, /*columns=*/column_count, is_first_column); -} - -void ColumnMajorMatrixVectorProductEmitter::Emit() { - // See the comment on the class declaration for the algorithm used here. - int64 column_remainder = k_ % tile_cols_; - int64 column_limit = k_ - column_remainder; - - ksl_.For("dot.outer.tiled", - /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols_, - [&](llvm::Value* column, bool is_first_column) { - EmitOuterLoopBody(column, tile_cols_, is_first_column); - }); - - if (column_remainder != 0) { - EmitOuterLoopBody(ir_builder_->getInt64(column_limit), column_remainder, - column_limit == 0); - } -} - -void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( - TileLoader* lhs_tile_loader, const std::vector<llvm::Value*>& rhs_tile, - int64 columns, bool is_first_column) { - int64 row_limit = m_ - (m_ % tile_rows_); - - ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/row_limit, - /*step=*/tile_rows_, [&](llvm::Value* row) { - std::vector<llvm::Value*> lhs_tile = - lhs_tile_loader->LoadTile(/*minor_dim_offset=*/row); - llvm::Value* accumulator = is_first_column - ? vsl_.GetZeroVector() - : vsl_.LoadVector(result_, row); - for (int i = 0; i < columns; i++) { - accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator); - } - vsl_.StoreVector(accumulator, result_, row); - }); -} - -void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( - llvm::Value* current_tile_col, int64 columns, bool is_first_tiled_column) { - int64 row_start = m_ - (m_ % tile_rows_); - if (row_start == m_) { - return; - } - - llvm::Value* columns_llvm = ir_builder_->getInt64(columns); - - // for (col = current_tile_col; col < (columns + current_tile_col); col++) - // for (row = row_start, row < m_; row++) { - // result[row] += lhs[row, col] * rhs[col] - // // Also take into account that if col is 0 then result[row] is not - // // initialized. - // } - - ksl_.For( - "dot.inner.epilg.outer", /*start=*/current_tile_col, - /*end=*/ir_builder_->CreateAdd(columns_llvm, current_tile_col), - /*step=*/1, /*peel_first_iteration=*/false, - [&](llvm::Value* col, llvm::Value* is_first_scalar_col) { - llvm::Value* rhs_element = vsl_.LoadScalar(rhs_, col); - llvm::Value* total_offset = - ir_builder_->CreateMul(col, ir_builder_->getInt64(m_)); - llvm::Value* lhs_base_pointer = - vsl_.ComputeOffsetPointer(lhs_, total_offset); - ksl_.For( - "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m_, - /*step=*/1, [&](llvm::Value* scalar_row) { - llvm::Value* product = vsl_.Mul( - vsl_.LoadScalar(lhs_base_pointer, scalar_row), rhs_element); - llvm::Value* setting_result_first_time = ir_builder_->CreateAnd( - is_first_scalar_col, - ir_builder_->getInt1(is_first_tiled_column)); - ksl_.If( - setting_result_first_time, - [&]() { vsl_.StoreScalar(product, result_, scalar_row); }, - [&]() { - vsl_.StoreScalar( - vsl_.Add(vsl_.LoadScalar(result_, scalar_row), product), - result_, scalar_row); - }); - }); - }); -} - -// Computes a dot product between "[M,K]{1,0} lhs" with a [K,1] vector (the -// layout of the vector does not matter). This implementation uses a tiling -// scheme to improve performance. -// -// We logically separate the LHS matrix into four segments: -// -// +----------------------+---+ -// | | | -// | | | -// | A | B | -// | | | -// | | | -// | | | -// +----------------------+---+ -// | C | D | -// +----------------------+---+ -// -// where A is the largest submatrix of the LHS that can be evenly dividied into -// tiles. For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have: -// -// +---+---+---+---+ -// |M00|M10|M20|M30| -// +---+---+---+---+ +--+--+--+--+ -// |M01|M11|M21|M31| and |V0|V1|V2|V3| -// +---+---+---+---+ +--+--+--+--+ -// |M02|M12|M22|M32| -// +---+---+---+---+ -// |M03|M13|M23|M33| -// +---+---+---+---+ -// -// (Legend: rows are horizontal and columns are vertical; and each row is one -// llvm::Value of a vector type) -// -// where: -// -// a. The left tile is loaded from the row major left matrix. -// b. The right vector is loaded from the RHS vector. -// -// We keep 4 vector accumulators accumulating the following four vector -// expressions as we iterate over the row dimension: -// -// +------+------+------+------+ -// |M0I*V0|M1I*V1|M2I*V2|M3I*V3| for I in [0,4) -// +------+------+------+------+ -// -// In the end we do a horizontal reduction over these 4 vector accumulators to -// get 4 values in the result vector. -// -// We have an inner epilogue loop to deal with the "B" sub-matrix and an outer -// epilogue loop to deal with the C,D submatrix. -class RowMajorMatrixVectorProductEmitter { - public: - RowMajorMatrixVectorProductEmitter(PrimitiveType scalar_type, int64 tile_rows, - int64 tile_cols, int64 m, int64 k, - llvm::Value* lhs, llvm::Value* rhs, - llvm::Value* result, - llvm::IRBuilder<>* ir_builder) - : scalar_type_(scalar_type), - tile_rows_(tile_rows), - tile_cols_(tile_cols), - m_(m), - k_(k), - lhs_(lhs), - rhs_(rhs), - result_(result), - ir_builder_(ir_builder), - ksl_(ir_builder_), - vsl_(scalar_type_, /*vector_size=*/tile_cols_, ir_builder_, "") { - CHECK(tile_cols_ > 0 && IsPowerOfTwo(static_cast<uint64>(tile_cols_))); - } - - void Emit(); - - private: - TileLoader GetLhsTileLoader(llvm::Value* row_start, int64 row_count) { - return TileLoader(&vsl_, ir_builder_, /*matrix=*/lhs_, - /*matrix_size_along_minor_dim=*/k_, - /*major_dim_offset=*/row_start, - /*tile_size_along_major_dim=*/row_count); - } - - void EmitOuterLoopBody(llvm::Value* row, int64 row_count); - - void EmitInnerLoopTiled(TileLoader* lhs_tile_loader, int64 rows, - std::vector<VectorVariable>* vector_accumulators); - - void EmitInnerLoopEpilogue(llvm::Value* current_tile_row, int64 rows, - std::vector<ScalarVariable>* scalar_accumulators); - - PrimitiveType scalar_type_; - int64 tile_rows_; - int64 tile_cols_; - int64 m_; - int64 k_; - llvm::Value* lhs_; - llvm::Value* rhs_; - llvm::Value* result_; - llvm::IRBuilder<>* ir_builder_; - KernelSupportLibrary ksl_; - VectorSupportLibrary vsl_; -}; - -void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row, - int64 row_count) { - TileLoader lhs_tile_loader = GetLhsTileLoader(/*row_start=*/row, - /*row_count=*/row_count); - std::vector<VectorVariable> vector_accumulators; - std::vector<ScalarVariable> scalar_accumulators; - for (int i = 0; i < row_count; i++) { - vector_accumulators.emplace_back(&vsl_, vsl_.GetZeroVector()); - scalar_accumulators.emplace_back(&vsl_, vsl_.GetZeroScalar()); - } - EmitInnerLoopTiled(&lhs_tile_loader, /*rows=*/row_count, - &vector_accumulators); - EmitInnerLoopEpilogue(/*current_tile_row=*/row, /*rows=*/row_count, - &scalar_accumulators); - - for (int i = 0; i < row_count; i++) { - llvm::Value* result_value = - vsl_.Add(vsl_.AddReduce(vector_accumulators[i].Get()), - scalar_accumulators[i].Get()); - llvm::Value* offset = ir_builder_->CreateAdd(ir_builder_->getInt64(i), row); - vsl_.StoreScalar(result_value, result_, offset); - } -} - -void RowMajorMatrixVectorProductEmitter::Emit() { - // See the comment on the class declaration for the algorithm used here. - int64 row_remainder = m_ % tile_rows_; - int64 row_limit = m_ - row_remainder; - - ksl_.For("dot.outer.tiled", - /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows_, - [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows_); }); - - if (row_remainder != 0) { - EmitOuterLoopBody(ir_builder_->getInt64(row_limit), row_remainder); - } -} - -void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( - TileLoader* lhs_tile_loader, int64 rows, - std::vector<VectorVariable>* vector_accumulators) { - int64 column_limit = k_ - (k_ % tile_cols_); - - ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/column_limit, - /*step=*/tile_cols_, [&](llvm::Value* col) { - std::vector<llvm::Value*> lhs_tile = - lhs_tile_loader->LoadTile(/*minor_dim_offset=*/col); - llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col); - for (int i = 0; i < rows; i++) { - llvm::Value* old_sum = (*vector_accumulators)[i].Get(); - (*vector_accumulators)[i].Set( - vsl_.Add(old_sum, vsl_.Mul(rhs_value, lhs_tile[i]))); - } - }); -} - -void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( - llvm::Value* current_tile_row, int64 rows, - std::vector<ScalarVariable>* scalar_accumulators) { - int64 column_start = k_ - (k_ % tile_cols_); - if (column_start == k_) { - return; - } - - for (int r = 0; r < rows; r++) { - llvm::Value* total_offset = ir_builder_->CreateMul( - ir_builder_->CreateAdd(ir_builder_->getInt64(r), current_tile_row), - ir_builder_->getInt64(k_)); - llvm::Value* lhs_base_pointer = - vsl_.ComputeOffsetPointer(lhs_, total_offset); - ksl_.For("dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k_, - /*step=*/1, [&](llvm::Value* scalar_col) { - llvm::Value* product = - vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col), - vsl_.LoadScalar(rhs_, scalar_col)); - llvm::Value* old_value = (*scalar_accumulators)[r].Get(); - (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product)); - }); - } -} - -} // namespace - DotOpEmitter::DotOpEmitter(const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs, const llvm_ir::IrArray& target_array, @@ -518,93 +72,6 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, bool transpose_lhs, bool DotOpEmitter::ShapesAreLegalForRuntimeDot() const { return true; } -bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { - if (dot_.shape().dimensions_size() != 2 || - ProfitableToImplementDotInUntiledLlvmIr(dot_) == - DotInLlvmIrProfitable::kYes) { - return false; - } - - if (!primitive_util::IsFloatingPointType(dot_.shape().element_type()) && - !primitive_util::IsIntegralType(dot_.shape().element_type())) { - return false; - } - - MatMultDims mat_mult_dims = GetMatMultDims(); - bool is_column_major_matrix_vector = false; - bool is_row_major_matrix_vector = false; - - int64 m, k; - bool swap_operands; - - if (mat_mult_dims.m == 1) { - bool rhs_effectively_row_major = - transpose_rhs_ ^ !mat_mult_dims.rhs_column_major; - if (rhs_effectively_row_major) { - k = mat_mult_dims.k; - m = mat_mult_dims.n; - is_column_major_matrix_vector = true; - swap_operands = true; - } else { - k = mat_mult_dims.k; - m = mat_mult_dims.n; - is_row_major_matrix_vector = true; - swap_operands = true; - } - } - - if (mat_mult_dims.n == 1) { - bool lhs_effectively_column_major = - transpose_lhs_ ^ mat_mult_dims.lhs_column_major; - if (lhs_effectively_column_major) { - m = mat_mult_dims.m; - k = mat_mult_dims.k; - is_column_major_matrix_vector = true; - swap_operands = false; - } else { - m = mat_mult_dims.m; - k = mat_mult_dims.k; - is_row_major_matrix_vector = true; - swap_operands = false; - } - } - - if (!is_column_major_matrix_vector && !is_row_major_matrix_vector) { - return false; - } - - int64 tiling_factor = GetGemvTilingFactor(); - CHECK_GT(tiling_factor, 0); - - if (is_column_major_matrix_vector) { - VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m - << " and k = " << k; - ColumnMajorMatrixVectorProductEmitter emitter( - dot_.shape().element_type(), /*tile_rows=*/8, - /*tile_cols=*/tiling_factor, m, k, - swap_operands ? rhs_array_.GetBasePointer() - : lhs_array_.GetBasePointer(), - swap_operands ? lhs_array_.GetBasePointer() - : rhs_array_.GetBasePointer(), - target_array_.GetBasePointer(), ir_builder_); - emitter.Emit(); - } else { - VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m - << " and k = " << k; - RowMajorMatrixVectorProductEmitter emitter( - dot_.shape().element_type(), /*tile_rows=*/tiling_factor, - /*tile_cols=*/8, m, k, - swap_operands ? rhs_array_.GetBasePointer() - : lhs_array_.GetBasePointer(), - swap_operands ? lhs_array_.GetBasePointer() - : rhs_array_.GetBasePointer(), - target_array_.GetBasePointer(), ir_builder_); - emitter.Emit(); - } - - return true; -} - tensorflow::Status DotOpEmitter::Emit() { // The dot operation performs a sum of products over dimension 0 of the left // hand side operand and dimension 1 of the right hand side operand. @@ -638,10 +105,6 @@ tensorflow::Status DotOpEmitter::Emit() { return EmitScalarDot(); } - if (EmitLlvmIrDotIfProfitable()) { - return Status::OK(); - } - if (PotentiallyImplementedAsEigenDot(dot_)) { return EmitCallToRuntime(); } @@ -877,17 +340,22 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() { // // Effectively this involves swapping the 'lhs' with 'rhs' and 'm' with 'n'. - MatMultDims mat_mult_dims = GetMatMultDims(); + const Shape& lhs_shape = lhs_array_.GetShape(); + const Shape& rhs_shape = rhs_array_.GetShape(); - CHECK_EQ(mat_mult_dims.lhs_column_major, mat_mult_dims.rhs_column_major); + CHECK(LayoutUtil::Equal(lhs_shape.layout(), rhs_shape.layout())); + int64 m = lhs_shape.dimensions(transpose_lhs_ ? 1 : 0); + int64 k = lhs_shape.dimensions(transpose_lhs_ ? 0 : 1); + int64 n = rhs_shape.dimensions(transpose_rhs_ ? 0 : 1); const llvm_ir::IrArray* lhs = &lhs_array_; const llvm_ir::IrArray* rhs = &rhs_array_; bool transpose_lhs = transpose_lhs_; bool transpose_rhs = transpose_rhs_; - if (!mat_mult_dims.lhs_column_major) { - std::swap(mat_mult_dims.m, mat_mult_dims.n); + bool is_column_major = lhs_shape.layout().minor_to_major(0) == 0; + if (!is_column_major) { + std::swap(m, n); std::swap(lhs, rhs); std::swap(transpose_lhs, transpose_rhs); } @@ -899,27 +367,12 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() { float_ptr_type), ir_builder_->CreateBitCast(lhs->GetBasePointer(), float_ptr_type), ir_builder_->CreateBitCast(rhs->GetBasePointer(), float_ptr_type), - ir_builder_->getInt64(mat_mult_dims.m), - ir_builder_->getInt64(mat_mult_dims.n), - ir_builder_->getInt64(mat_mult_dims.k), - ir_builder_->getInt32(transpose_lhs), + ir_builder_->getInt64(m), ir_builder_->getInt64(n), + ir_builder_->getInt64(k), ir_builder_->getInt32(transpose_lhs), ir_builder_->getInt32(transpose_rhs)}); return tensorflow::Status::OK(); } -DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { - CHECK_EQ(dot_.shape().dimensions_size(), 2); - - const Shape& lhs_shape = lhs_array_.GetShape(); - const Shape& rhs_shape = rhs_array_.GetShape(); - - return {lhs_shape.dimensions(transpose_lhs_ ? 1 : 0), - lhs_shape.dimensions(transpose_lhs_ ? 0 : 1), - rhs_shape.dimensions(transpose_rhs_ ? 0 : 1), - lhs_shape.layout().minor_to_major(0) == 0, - rhs_shape.layout().minor_to_major(0) == 0}; -} - llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest( llvm_ir::ForLoopNest* loop_nest, const llvm_ir::IrArray& operand_array, int64 reduction_dimension, tensorflow::StringPiece name_suffix) { |