diff options
author | 2018-03-12 07:26:13 -0700 | |
---|---|---|
committer | 2018-03-12 07:30:43 -0700 | |
commit | 12496b26049384b78f63940907078f9269c9866f (patch) | |
tree | 423958b64a69a2aa8debbf8b3c4125ca772a613a | |
parent | cd67e8eb088537874b53b4fa52d02ff50c4a66fa (diff) |
Reuse the linear index when broadcasting a contiguous range of dimensions.
This potentially allows us to get rid of additional mod and div operations.
PiperOrigin-RevId: 188719238
-rw-r--r-- | tensorflow/compiler/xla/service/elemental_ir_emitter.cc | 11 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/llvm_ir/ir_array.cc | 63 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/llvm_ir/ir_array.h | 9 |
3 files changed, 75 insertions, 8 deletions
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 111c29593e..b6a0903b0e 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -1522,15 +1522,12 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kBroadcast: return [this, hlo, &operand_to_generator]( const IrArray::Index& target_index) -> StatusOr<llvm::Value*> { + const HloInstruction* operand = hlo->operand(0); // The `dimensions` member of the broadcast instruction maps from // input dimensions to output dimensions. - const HloInstruction* operand = hlo->operand(0); - int64 rank = ShapeUtil::Rank(operand->shape()); - IrArray::Index source_index(rank); - for (int64 i = 0; i < rank; ++i) { - source_index[i] = target_index[hlo->dimensions(i)]; - } - return operand_to_generator.at(operand)(source_index); + return operand_to_generator.at( + operand)(target_index.SourceIndexOfBroadcast( + hlo->shape(), operand->shape(), hlo->dimensions(), ir_builder_)); }; case HloOpcode::kSlice: return [this, hlo, &operand_to_generator]( diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index d444c1d49d..3312a88844 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -241,6 +241,69 @@ IrArray::Index IrArray::Index::SourceIndexOfBitcast( return Index(multi_index, linear_index, operand_shape); } +IrArray::Index IrArray::Index::SourceIndexOfBroadcast( + const Shape& shape, const Shape& operand_shape, + tensorflow::gtl::ArraySlice<int64> dimension_mapping, + llvm::IRBuilder<>* builder) const { + int64 rank = ShapeUtil::Rank(operand_shape); + std::vector<llvm::Value*> source_index(rank); + for (int64 i = 0; i < rank; ++i) { + source_index[i] = multidim_[dimension_mapping[i]]; + } + if (linear_ == nullptr || !LayoutUtil::HasLayout(operand_shape) || + !LayoutUtil::HasLayout(shape)) { + return Index(source_index); + } + // High-level idea: we can reuse the linear index if the broadcasted + // dimensions are contiguous, and this part of the operation is a bitcast. + // The other dimensions can be masked out with a div and a mod operation. + std::vector<int64> logical_to_physical = + LayoutUtil::MakeLogicalToPhysical(shape.layout()); + int64 output_rank = ShapeUtil::Rank(shape); + // The minimum physical dimension that is broadcasted. + int64 min_broadcasted_dimension = output_rank; + // The maximum physical dimension that is broadcasted. + int64 max_broadcasted_dimension = -1; + for (int64 i = 0; i < rank; ++i) { + int64 physical_dim = logical_to_physical[dimension_mapping[i]]; + min_broadcasted_dimension = + std::min(min_broadcasted_dimension, physical_dim); + max_broadcasted_dimension = + std::max(max_broadcasted_dimension, physical_dim); + } + bool contiguous_broadcast_dimensions = + max_broadcasted_dimension - min_broadcasted_dimension == rank - 1; + if (!contiguous_broadcast_dimensions) { + return Index(source_index); + } + // Check if the mapped dimensions are a bitcast. + std::vector<int64> operand_logical_to_physical = + LayoutUtil::MakeLogicalToPhysical(operand_shape.layout()); + for (int64 i = 0; i < rank; ++i) { + if (operand_logical_to_physical[i] != + logical_to_physical[dimension_mapping[i]] - min_broadcasted_dimension) { + return Index(source_index); + } + } + llvm::Value* linear = linear_; + int64 divisor = 1; + for (int64 i = max_broadcasted_dimension + 1; i < output_rank; ++i) { + divisor *= shape.dimensions(LayoutUtil::Major(shape.layout(), i)); + } + if (divisor > 1) { + linear = builder->CreateUDiv(linear, builder->getInt64(divisor)); + } + if (min_broadcasted_dimension > 0) { + int64 mod = 1; + for (int64 i = min_broadcasted_dimension; i <= max_broadcasted_dimension; + ++i) { + mod *= shape.dimensions(LayoutUtil::Major(shape.layout(), i)); + } + linear = builder->CreateURem(linear, builder->getInt64(mod)); + } + return Index(source_index, linear, operand_shape); +} + llvm::Value* IrArray::Index::Linearize( tensorflow::gtl::ArraySlice<int64> dimensions, llvm::IRBuilder<>* builder) const { diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index faa92d608c..06cfb2a36c 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -134,10 +134,17 @@ class IrArray { llvm::IRBuilder<>* builder) const; // Given that "this" is the target index of a bitcast from `operand_shape` - // to `shape` with the given dimension mapping, returns the source index. + // to `shape`, returns the source index. Index SourceIndexOfBitcast(const Shape& shape, const Shape& operand_shape, llvm::IRBuilder<>* builder) const; + // Given that "this" is the target index of a broadcast from `operand_shape` + // to `shape` with the given dimension mapping, returns the source index. + Index SourceIndexOfBroadcast( + const Shape& shape, const Shape& operand_shape, + tensorflow::gtl::ArraySlice<int64> dimension_mapping, + llvm::IRBuilder<>* builder) const; + // Linearizes the index into the given shape, i.e. reshapes it to rank-1 and // returns the index into the sole dimension 0 of the new shape. llvm::Value* Linearize(tensorflow::gtl::ArraySlice<int64> dimensions, |