about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
diff options
context:
space:
mode:
author: Sanjoy Das <sanjoy@google.com> 2018-05-11 14:05:38 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-05-11 14:07:50 -0700
commit: e8dbaff96389ecefd8f84d4c3ce3fce18e876cca (patch)
tree: 5f2f66314c1060e212eebdffb6ab2c2897af7a0e /tensorflow/compiler/xla/service/elemental_ir_emitter.cc
parent: 815e02963bbec52626bf86b88773cdbb0aeb25a6 (diff)
Make the elemental ir emitter for dot operations respect contraction dims
PiperOrigin-RevId: 196305803
Diffstat (limited to 'tensorflow/compiler/xla/service/elemental_ir_emitter.cc')
-rw-r--r-- tensorflow/compiler/xla/service/elemental_ir_emitter.cc 16
1 files changed, 10 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index f2ad6eaf3a..0a400e982a 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -1863,8 +1863,13 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
const llvm_ir::IrArray::Index& dot_result_index) const {
auto lhs_generator = operand_to_generator.at(hlo->operand(0));
auto rhs_generator = operand_to_generator.at(hlo->operand(1));
- int64 contracted_dim_size = hlo->operand(0)->shape().dimensions(
- hlo->operand(0)->shape().dimensions_size() - 1);
+
+ const DotDimensionNumbers& dim_numbers = hlo->dot_dimension_numbers();
+ int64 lhs_contracting_dim = dim_numbers.lhs_contracting_dimensions(0);
+ int64 rhs_contracting_dim = dim_numbers.rhs_contracting_dimensions(0);
+
+ int64 contracted_dim_size =
+ hlo->operand(0)->shape().dimensions(lhs_contracting_dim);
int64 lhs_dims = hlo->operand(0)->shape().dimensions_size();
int64 rhs_dims = hlo->operand(1)->shape().dimensions_size();
@@ -1895,13 +1900,12 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
for (int64 i = 0; i < lhs_dims - 1; i++) {
lhs_index.push_back(dot_result_index[i]);
}
- lhs_index.push_back(inner_loop->GetIndVarValue());
+ lhs_index.InsertAt(lhs_contracting_dim, inner_loop->GetIndVarValue());
- for (int64 i = 0; i < rhs_dims - 2; i++) {
+ for (int64 i = 0; i < rhs_dims - 1; i++) {
rhs_index.push_back(dot_result_index[lhs_dims - 1 + i]);
}
- rhs_index.push_back(inner_loop->GetIndVarValue());
- rhs_index.push_back(dot_result_index.back());
+ rhs_index.InsertAt(rhs_contracting_dim, inner_loop->GetIndVarValue());
llvm::Value* current_accumulator =
ir_builder_->CreateLoad(accumulator_alloca);