author A. Unique TensorFlower <gardener@tensorflow.org> 2017-07-28 10:52:57 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-07-28 10:56:18 -0700
commit efc63f6248a4a85c885e4a4facabd7242ee3a94c (patch)
tree d93be7518aef1a3c9df0e1bd6ff32b9cfe6724d5
parent a553aff1319a68b1ee8e8084c3915ea1fde8b888 (diff)
Lower concatenate operations to memcpy.
This usually ends up being faster than the elemental IR implementation.
PiperOrigin-RevId: 163489782
-rw-r--r-- tensorflow/compiler/xla/service/cpu/ir_emitter.cc | 167
-rw-r--r-- tensorflow/compiler/xla/service/cpu/ir_emitter.h | 18
-rw-r--r-- tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc | 8
-rw-r--r-- tensorflow/compiler/xla/service/llvm_ir/ir_array.h | 5
-rw-r--r-- tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc | 48
-rw-r--r-- tensorflow/compiler/xla/service/llvm_ir/llvm_util.h | 8
6 files changed, 252 insertions(+), 2 deletions(-)
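
The change is easiest to see outside of IR. Below is a minimal, self-contained C++ sketch of the copy pattern that the new EmitFastConcatenate emits as LLVM IR: when every operand shares the output's layout, each operand contributes one contiguous run of elements per outer index, so the elemental per-element loop can be replaced by one memcpy per operand. The function and variable names are illustrative only and do not appear in this commit.

#include <cstring>
#include <vector>

// Concatenate two row-major [rows x cols_a] and [rows x cols_b] float buffers
// along the minor (column) dimension. Each outer index (row) gets one
// contiguous memcpy per operand instead of an element-by-element copy loop.
std::vector<float> ConcatMinorDim(const std::vector<float>& a,
                                  const std::vector<float>& b, int rows,
                                  int cols_a, int cols_b) {
  std::vector<float> out(static_cast<size_t>(rows) * (cols_a + cols_b));
  for (int row = 0; row < rows; ++row) {  // loop over the outer dimensions
    float* target = out.data() + static_cast<size_t>(row) * (cols_a + cols_b);
    std::memcpy(target, a.data() + static_cast<size_t>(row) * cols_a,
                cols_a * sizeof(float));  // operand 0's contiguous run
    std::memcpy(target + cols_a, b.data() + static_cast<size_t>(row) * cols_b,
                cols_b * sizeof(float));  // operand 1's contiguous run
  }
  return out;
}

In the emitter below, the per-operand run length is inner_dims_product * dimensions(concat_dim) elements, and the row loop corresponds to the ForLoopNest built over outer_dims.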
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 644bf5fd74..8671452e7f 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -2382,6 +2382,173 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
return Status::OK();
}
+StatusOr<bool> IrEmitter::EmitFastConcatenate(
+ HloInstruction* concatenate,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ string* failure_reason) {
+ if (ShouldEmitParallelLoopFor(*concatenate)) {
+ *failure_reason =
+ "cannot generate memcpy-based concat for the parallel CPU backend";
+ return false;
+ }
+
+ const Shape& output_shape = concatenate->shape();
+ for (auto* op : operands) {
+ if (!LayoutUtil::Equal(op->shape().layout(), output_shape.layout())) {
+ *failure_reason = "operand has mismatching layouts";
+ return false;
+ }
+ if (LayoutUtil::IsPadded(op->shape())) {
+ *failure_reason = "operand has padded layout";
+ return false;
+ }
+ }
+
+ CHECK(!LayoutUtil::IsPadded(concatenate->shape()));
+
+ // We split the dimensions into three categories: the dimension over which we
+ // are concatenating (concat_dim), the dimensions that are minor to it
+ // (inner_dims) and the dimensions that are major to it (outer_dims).
+
+ int64 concat_dim = concatenate->dimensions(0);
+ const Layout& output_layout = output_shape.layout();
+ auto concat_dim_layout_itr =
+ std::find(output_layout.minor_to_major().begin(),
+ output_layout.minor_to_major().end(), concat_dim);
+
+ std::vector<int64> inner_dims(output_layout.minor_to_major().begin(),
+ concat_dim_layout_itr);
+ std::vector<int64> outer_dims(std::next(concat_dim_layout_itr),
+ output_layout.minor_to_major().end());
+
+ llvm::Type* i8_ptr_type = ir_builder_.getInt8PtrTy();
+ llvm::Type* i8_type = ir_builder_.getInt8Ty();
+
+ TF_ASSIGN_OR_RETURN(llvm::Value * target_address,
+ EmitTargetAddressForOp(concatenate));
+
+ llvm_ir::IrArray target_array(target_address, output_shape);
+
+ llvm_ir::ForLoopNest loops(&ir_builder_);
+ llvm_ir::IrArray::Index outer_dims_index =
+ loops.AddLoopsForShapeOnDimensions(output_shape, outer_dims, "concat");
+ std::replace(outer_dims_index.begin(), outer_dims_index.end(),
+ static_cast<llvm::Value*>(nullptr),
+ static_cast<llvm::Value*>(ir_builder_.getInt64(0)));
+
+ if (!outer_dims.empty()) {
+ SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_);
+ }
+
+ PrimitiveType primitive_type = output_shape.element_type();
+ unsigned primitive_type_size =
+ ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
+
+ AddAliasingInformationToIrArray(*concatenate, &target_array);
+
+ // Contiguous subregions from each operand to the concatenate contribute to a
+ // contiguous subregion in the target buffer starting at target_region_begin.
+ llvm::Value* target_region_begin = ir_builder_.CreateBitCast(
+ target_array.EmitArrayElementAddress(outer_dims_index, &ir_builder_,
+ "target_region"),
+ i8_ptr_type);
+ int64 byte_offset_into_target_region = 0;
+
+ int64 inner_dims_product =
+ std::accumulate(inner_dims.begin(), inner_dims.end(), 1l,
+ [&](int64 product, int64 inner_dim) {
+ return product * output_shape.dimensions(inner_dim);
+ });
+
+ // For each operand, emit a memcpy from the operand to the target of size
+ // equal to the product of inner dimensions.
+ for (HloInstruction* operand : operands) {
+ const Shape& input_shape = operand->shape();
+ llvm_ir::IrArray source_array(GetEmittedValueFor(operand), input_shape);
+ AddAliasingInformationToIrArray(*operand, &source_array);
+
+ llvm::Value* copy_source_address = ir_builder_.CreateBitCast(
+ source_array.EmitArrayElementAddress(outer_dims_index, &ir_builder_,
+ "src_addr"),
+ i8_ptr_type);
+
+ llvm::Value* copy_target_address = ir_builder_.CreateGEP(
+ i8_type, target_region_begin,
+ ir_builder_.getInt64(byte_offset_into_target_region));
+
+ EmitTransferElements(
+ copy_target_address, copy_source_address,
+ inner_dims_product * input_shape.dimensions(concat_dim), primitive_type,
+ target_array, source_array);
+
+ byte_offset_into_target_region += inner_dims_product *
+ input_shape.dimensions(concat_dim) *
+ primitive_type_size;
+ }
+
+ if (!outer_dims.empty()) {
+ SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_);
+ }
+
+ emitted_value_[concatenate] = target_address;
+
+ return true;
+}
+
+void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source,
+ int64 element_count,
+ PrimitiveType primitive_type,
+ const llvm_ir::IrArray& target_array,
+ const llvm_ir::IrArray& source_array) {
+ unsigned primitive_type_size =
+ ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
+ unsigned element_alignment = GCD(
+ primitive_type_size, MinimumAlignmentForPrimitiveType(primitive_type));
+ llvm::Type* primitive_ptr_type = llvm::PointerType::getUnqual(
+ llvm_ir::PrimitiveTypeToIrType(primitive_type, &ir_builder_));
+
+ if (element_count == 1) {
+ auto* load_instruction = ir_builder_.CreateAlignedLoad(
+ ir_builder_.CreateBitCast(source, primitive_ptr_type),
+ element_alignment);
+ source_array.AnnotateLoadStoreInstructionWithMetadata(load_instruction);
+ auto* store_instruction = ir_builder_.CreateAlignedStore(
+ load_instruction, ir_builder_.CreateBitCast(target, primitive_ptr_type),
+ element_alignment);
+ target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction);
+ } else {
+ auto* memcpy_instruction = ir_builder_.CreateMemCpy(
+ target, source, element_count * primitive_type_size, element_alignment);
+
+ // The memcpy does the load and the store internally. The aliasing related
+ // metadata has to reflect that.
+ std::map<int, llvm::MDNode*> merged_metadata =
+ llvm_ir::MergeMetadata(&module_->getContext(), source_array.metadata(),
+ target_array.metadata());
+ for (const auto& kind_md_pair : merged_metadata) {
+ memcpy_instruction->setMetadata(kind_md_pair.first, kind_md_pair.second);
+ }
+ }
+}
+
+Status IrEmitter::HandleConcatenate(
+ HloInstruction* concatenate,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
+ string failure_reason;
+ TF_ASSIGN_OR_RETURN(
+ bool successful,
+ EmitFastConcatenate(concatenate, operands, &failure_reason));
+ if (successful) {
+ VLOG(1) << "Emitted fast concatenate for " << concatenate->ToString();
+ return Status::OK();
+ }
+
+ VLOG(1) << "Could not emit fast concatenate for " << concatenate->ToString()
+ << ": " << failure_reason;
+
+ return DefaultAction(concatenate);
+}
+
Status IrEmitter::FinishVisit(HloInstruction* root) {
// When this method is called, we should have already emitted an IR value for
// the root (return) op. The IR value holds the address of the buffer holding
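
As a worked example of the size and offset arithmetic in EmitFastConcatenate above (the shapes are chosen purely for illustration and are not taken from this commit): concatenating f32[2,3,5] and f32[2,4,5] along concat_dim = 1 with minor-to-major layout {2,1,0} gives inner_dims = {2}, outer_dims = {0}, and inner_dims_product = 5.

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t inner_dims_product = 5;      // product of dims minor to concat_dim
  const int64_t concat_extents[2] = {3, 4};  // each operand's extent along concat_dim
  const int64_t primitive_type_size = 4;     // sizeof(float)
  int64_t byte_offset_into_target_region = 0;
  for (int i = 0; i < 2; ++i) {
    const int64_t elements_to_copy = inner_dims_product * concat_extents[i];
    std::printf("operand %d: copy %lld elements at byte offset %lld\n", i,
                static_cast<long long>(elements_to_copy),
                static_cast<long long>(byte_offset_into_target_region));
    byte_offset_into_target_region += elements_to_copy * primitive_type_size;
  }
  // Prints: operand 0 copies 15 elements at offset 0, operand 1 copies 20
  // elements at offset 60. The outer loop over outer_dims (extent 2) repeats
  // this pattern once per outer index.
  return 0;
}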
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 4533253680..2fea6846d8 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -191,6 +191,9 @@ class IrEmitter : public DfsHloVisitorWithDefault {
tensorflow::gtl::ArraySlice<HloInstruction*> operands,
tensorflow::StringPiece custom_call_target) override;
Status HandleWhile(HloInstruction* xla_while) override;
+ Status HandleConcatenate(
+ HloInstruction* concatenate,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;
Status FinishVisit(HloInstruction* root) override;
Status Preprocess(HloInstruction* hlo) override;
@@ -407,6 +410,21 @@ class IrEmitter : public DfsHloVisitorWithDefault {
HloInstruction* arg, tensorflow::gtl::ArraySlice<int64> dimensions,
unsigned element_alignment);
+ // Tries to emit a fast concatenate operation using memcpy. Returns true if
+ // successful, and false on failure. On failure, sets "failure_reason" to a
+ // string describing why it could not emit a fast concatenate.
+ StatusOr<bool> EmitFastConcatenate(
+ HloInstruction* concatenate,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ string* failure_reason);
+
+ // Emits LLVM IR to transfer "element_count" elements of type "primitive_type"
+ // from the address "source" to the address "target".
+ void EmitTransferElements(llvm::Value* target, llvm::Value* source,
+ int64 element_count, PrimitiveType primitive_type,
+ const llvm_ir::IrArray& target_array,
+ const llvm_ir::IrArray& source_array);
+
// Name of the computation entry function. This function serves as the
// top-level "main" of the computation and will be invoked by the JIT.
string entry_function_name_;
diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc
index 6bfe8bfc75..5e28e37600 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc
@@ -56,7 +56,9 @@ void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo,
alias_scope_md =
GetAliasScopeMetadataForBuffer(buffer_slice, GetAliasDomain());
}
- array->AddAliasScopeMetadata(alias_scope_md);
+ if (alias_scope_md != nullptr) {
+ array->AddAliasScopeMetadata(alias_scope_md);
+ }
}
if (module_.config().debug_options().xla_llvm_enable_noalias_metadata()) {
@@ -65,7 +67,9 @@ void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo,
noalias_md = GetNoaliasMetadataForBuffer(buffer_slice, GetAliasDomain(),
assignment_, hlo);
}
- array->AddNoaliasMetadata(noalias_md);
+ if (noalias_md != nullptr) {
+ array->AddNoaliasMetadata(noalias_md);
+ }
}
if (module_.config()
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h
index a38cf0e5d9..a6a3ea1adc 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h
@@ -218,17 +218,22 @@ class IrArray {
llvm::IRBuilder<>* ir_builder) const;
void AddAliasScopeMetadata(llvm::MDNode* alias_scope) {
+ CHECK_NE(alias_scope, nullptr);
AddMetadata(llvm::LLVMContext::MD_alias_scope, alias_scope);
}
void AddNoaliasMetadata(llvm::MDNode* noalias) {
+ CHECK_NE(noalias, nullptr);
AddMetadata(llvm::LLVMContext::MD_noalias, noalias);
}
void AddInvariantLoad(llvm::MDNode* invariant_load) {
+ CHECK_NE(invariant_load, nullptr);
AddMetadata(llvm::LLVMContext::MD_invariant_load, invariant_load);
}
+ const std::map<int, llvm::MDNode*>& metadata() const { return metadata_; }
+
// Bumps the "which_dimension" value within the provided index by the provided
// addend.
static Index BumpIndex(const Index& index, int64 which_dimension,
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
index 6d985fba0c..0ae75c5b3c 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
@@ -464,5 +464,53 @@ void SetTargetOptions(bool fast_math_enabled,
target_options->NoSignedZerosFPMath = fast_math_enabled;
}
+std::map<int, llvm::MDNode*> MergeMetadata(
+ llvm::LLVMContext* context, const std::map<int, llvm::MDNode*>& a,
+ const std::map<int, llvm::MDNode*>& b) {
+ // We should extend this as needed to deal with other kinds of metadata like
+ // !dereferenceable and !range.
+
+ std::map<int, llvm::MDNode*> result;
+ for (auto kind_md_pair : a) {
+ if (kind_md_pair.first == llvm::LLVMContext::MD_alias_scope) {
+ llvm::SmallVector<llvm::Metadata*, 8> union_of_scopes;
+ llvm::SmallPtrSet<llvm::Metadata*, 8> scope_set;
+ for (const auto& scope_a : kind_md_pair.second->operands()) {
+ scope_set.insert(llvm::cast<llvm::MDNode>(scope_a.get()));
+ union_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_a.get()));
+ }
+ auto it = b.find(kind_md_pair.first);
+ if (it != b.end()) {
+ for (const auto& scope_b : it->second->operands()) {
+ if (!scope_set.count(llvm::cast<llvm::MDNode>(scope_b.get()))) {
+ union_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_b.get()));
+ }
+ }
+ }
+ result[llvm::LLVMContext::MD_alias_scope] =
+ llvm::MDNode::get(*context, union_of_scopes);
+ } else if (kind_md_pair.first == llvm::LLVMContext::MD_noalias) {
+ llvm::SmallVector<llvm::Metadata*, 8> intersection_of_scopes;
+ llvm::SmallPtrSet<llvm::Metadata*, 8> scope_set;
+ for (const auto& scope_a : kind_md_pair.second->operands()) {
+ scope_set.insert(llvm::cast<llvm::MDNode>(scope_a.get()));
+ }
+ auto it = b.find(kind_md_pair.first);
+ if (it != b.end()) {
+ for (const auto& scope_b : it->second->operands()) {
+ if (scope_set.count(llvm::cast<llvm::MDNode>(scope_b))) {
+ intersection_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_b));
+ }
+ }
+ }
+ if (!intersection_of_scopes.empty()) {
+ result[llvm::LLVMContext::MD_noalias] =
+ llvm::MDNode::get(*context, intersection_of_scopes);
+ }
+ }
+ }
+ return result;
+}
+
} // namespace llvm_ir
} // namespace xla
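
The merge rule implemented above can be summarized as follows: the !alias.scope lists are unioned, because the memcpy performs both the original load and the original store, while the !noalias lists are intersected, because the memcpy may only promise non-aliasing that both original accesses promised. A minimal sketch of that rule using plain string sets (illustrative only; this is not XLA or LLVM code):

#include <algorithm>
#include <iterator>
#include <set>
#include <string>

using ScopeSet = std::set<std::string>;

// !alias.scope of the merged access: the union of the two scope lists.
ScopeSet MergeAliasScope(const ScopeSet& a, const ScopeSet& b) {
  ScopeSet result = a;
  result.insert(b.begin(), b.end());
  return result;
}

// !noalias of the merged access: the intersection of the two scope lists. If
// the intersection is empty, no !noalias metadata is attached at all.
ScopeSet MergeNoalias(const ScopeSet& a, const ScopeSet& b) {
  ScopeSet result;
  std::set_intersection(a.begin(), a.end(), b.begin(), b.end(),
                        std::inserter(result, result.begin()));
  return result;
}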
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
index 96d2c2dba8..6d94603338 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
@@ -238,6 +238,14 @@ llvm::FastMathFlags GetFastMathFlags(bool fast_math_enabled);
void SetTargetOptions(bool fast_math_enabled,
llvm::TargetOptions* target_options);
+// Computes a conservative union of the metadata in "a" and "b". For
+// aliasing-related metadata, this means the result can be applied to
+// instructions whose aliasing relationship can be described either by "a" *or*
+// by "b".
+std::map<int, llvm::MDNode*> MergeMetadata(
+ llvm::LLVMContext* context, const std::map<int, llvm::MDNode*>& a,
+ const std::map<int, llvm::MDNode*>& b);
+
} // namespace llvm_ir
} // namespace xla