diff options
author | Sanjoy Das <sanjoy@google.com> | 2018-05-11 14:05:38 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-11 14:07:50 -0700 |
commit | e8dbaff96389ecefd8f84d4c3ce3fce18e876cca (patch) | |
tree | 5f2f66314c1060e212eebdffb6ab2c2897af7a0e /tensorflow/compiler/xla/service/elemental_ir_emitter.cc | |
parent | 815e02963bbec52626bf86b88773cdbb0aeb25a6 (diff) |
Make the elemental IR emitter for dot operations respect contraction dims
PiperOrigin-RevId: 196305803
Diffstat (limited to 'tensorflow/compiler/xla/service/elemental_ir_emitter.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/elemental_ir_emitter.cc | 16 |
1 files changed, 10 insertions, 6 deletions
```diff
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index f2ad6eaf3a..0a400e982a 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -1863,8 +1863,13 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
     const llvm_ir::IrArray::Index& dot_result_index) const {
   auto lhs_generator = operand_to_generator.at(hlo->operand(0));
   auto rhs_generator = operand_to_generator.at(hlo->operand(1));
-  int64 contracted_dim_size = hlo->operand(0)->shape().dimensions(
-      hlo->operand(0)->shape().dimensions_size() - 1);
+
+  const DotDimensionNumbers& dim_numbers = hlo->dot_dimension_numbers();
+  int64 lhs_contracting_dim = dim_numbers.lhs_contracting_dimensions(0);
+  int64 rhs_contracting_dim = dim_numbers.rhs_contracting_dimensions(0);
+
+  int64 contracted_dim_size =
+      hlo->operand(0)->shape().dimensions(lhs_contracting_dim);
   int64 lhs_dims = hlo->operand(0)->shape().dimensions_size();
   int64 rhs_dims = hlo->operand(1)->shape().dimensions_size();
 
@@ -1895,13 +1900,12 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
   for (int64 i = 0; i < lhs_dims - 1; i++) {
     lhs_index.push_back(dot_result_index[i]);
   }
-  lhs_index.push_back(inner_loop->GetIndVarValue());
+  lhs_index.InsertAt(lhs_contracting_dim, inner_loop->GetIndVarValue());
 
-  for (int64 i = 0; i < rhs_dims - 2; i++) {
+  for (int64 i = 0; i < rhs_dims - 1; i++) {
     rhs_index.push_back(dot_result_index[lhs_dims - 1 + i]);
   }
-  rhs_index.push_back(inner_loop->GetIndVarValue());
-  rhs_index.push_back(dot_result_index.back());
+  rhs_index.InsertAt(rhs_contracting_dim, inner_loop->GetIndVarValue());
 
   llvm::Value* current_accumulator =
       ir_builder_->CreateLoad(accumulator_alloca);
```