Diffstat (limited to 'tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc')
-rw-r--r-- tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc | 471
1 file changed, 471 insertions(+), 0 deletions(-)
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
new file mode 100644
index 0000000000..d7a231db61
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
@@ -0,0 +1,471 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+
+#include <algorithm>
+#include <vector>
+
+#include "external/llvm/include/llvm/IR/MDBuilder.h"
+#include "external/llvm/include/llvm/IR/Operator.h"
+#include "external/llvm/include/llvm/Target/TargetOptions.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/legacy_flags/llvm_backend_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/core/casts.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace llvm_ir {
+
+string AsString(const std::string& str) {
+ return string(str.data(), str.length());
+}
+
+llvm::StringRef AsStringRef(tensorflow::StringPiece str) {
+ return llvm::StringRef(str.data(), str.size());
+}
+
+string DumpModuleToString(const llvm::Module& module) {
+ std::string buffer_string;
+ llvm::raw_string_ostream ostream(buffer_string);
+ module.print(ostream, nullptr);
+ ostream.flush();
+ return AsString(buffer_string);
+}
+
+llvm::Value* EmitCallToIntrinsic(
+ llvm::Intrinsic::ID intrinsic_id,
+ tensorflow::gtl::ArraySlice<llvm::Value*> operands,
+ tensorflow::gtl::ArraySlice<llvm::Type*> overloaded_types,
+ llvm::IRBuilder<>* ir_builder) {
+ std::vector<llvm::Type*> types;
+ for (auto type : overloaded_types) {
+ types.push_back(type);
+ }
+ llvm::Module* module = ir_builder->GetInsertBlock()->getParent()->getParent();
+ llvm::Function* intrinsic =
+ llvm::Intrinsic::getDeclaration(module, intrinsic_id, types);
+ std::vector<llvm::Value*> operands_vec;
+ for (auto operand : operands) {
+ operands_vec.push_back(operand);
+ }
+ return ir_builder->CreateCall(intrinsic, operands_vec);
+}
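+
+// Example use (illustrative sketch; `x` is a hypothetical float value and `b`
+// an llvm::IRBuilder<> positioned inside a function). llvm.sqrt is overloaded
+// on its operand type, hence the overloaded_types argument:
+//
+//   llvm::Value* root = EmitCallToIntrinsic(
+//       llvm::Intrinsic::sqrt, /*operands=*/{x},
+//       /*overloaded_types=*/{x->getType()}, &b);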
+
+llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, llvm::Value* index,
+ llvm::IRBuilder<>* ir_builder) {
+ llvm::Type* array_type = array->getType();
+ CHECK(array_type->isPointerTy());
+ llvm::PointerType* array_type_as_pointer =
+ llvm::cast<llvm::PointerType>(array_type);
+ VLOG(2) << "EmitBufferIndexingGEP with type="
+ << llvm_ir::DumpToString(*array_type)
+ << " array=" << llvm_ir::DumpToString(*array)
+ << " index=" << llvm_ir::DumpToString(*index);
+
+ return ir_builder->CreateInBoundsGEP(
+ array_type_as_pointer->getElementType(), array,
+ llvm::isa<llvm::GlobalVariable>(array)
+ ? llvm::ArrayRef<llvm::Value*>({ir_builder->getInt64(0), index})
+ : index);
+}
+
+llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, int64 index,
+ llvm::IRBuilder<>* ir_builder) {
+ return EmitBufferIndexingGEP(array, ir_builder->getInt64(index), ir_builder);
+}
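+
+// Example use (illustrative; `buffer` is a hypothetical pointer to a float
+// array and `b` an IRBuilder):
+//
+//   llvm::Value* slot = EmitBufferIndexingGEP(buffer, /*index=*/3, &b);
+//   llvm::Value* elem = b.CreateLoad(slot);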
+
+llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
+ llvm::IRBuilder<>* ir_builder) {
+ switch (element_type) {
+ case PRED:
+ case S8:
+ case U8:
+ return ir_builder->getInt8Ty();
+ case S16:
+ case U16:
+ return ir_builder->getInt16Ty();
+ case S32:
+ case U32:
+ return ir_builder->getInt32Ty();
+ case S64:
+ case U64:
+ return ir_builder->getInt64Ty();
+ case F32:
+ return ir_builder->getFloatTy();
+ case F64:
+ return ir_builder->getDoubleTy();
+ // A Tuple contains an array of pointers. Use i8*.
+ case TUPLE:
+ // An Opaque is like a void*, use i8*.
+ case OPAQUE:
+ return ir_builder->getInt8PtrTy();
+ default:
+ LOG(FATAL) << "unsupported type " << element_type;
+ }
+}
+
+llvm::Type* ShapeToIrType(const Shape& shape, llvm::IRBuilder<>* ir_builder) {
+ llvm::Type* result_type =
+ PrimitiveTypeToIrType(shape.element_type(), ir_builder);
+ if (ShapeUtil::IsTuple(shape)) {
+ // A tuple buffer is an array of pointers.
+ result_type = llvm::ArrayType::get(result_type, shape.tuple_shapes_size());
+ } else {
+ for (int64 dimension : shape.layout().minor_to_major()) {
+ result_type =
+ llvm::ArrayType::get(result_type, shape.dimensions(dimension));
+ }
+ }
+ return result_type;
+}
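+
+// For example (illustrative; using XLA's usual shape notation), a shape
+// f32[2,3] with layout minor_to_major = {1, 0} maps to the LLVM type
+// [2 x [3 x float]]; the most-minor dimension (size 3) becomes the innermost
+// array:
+//
+//   llvm::Type* ty = ShapeToIrType(ShapeUtil::MakeShapeWithLayout(
+//       F32, /*dimensions=*/{2, 3}, /*minor_to_major=*/{1, 0}), &b);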
+
+namespace {
+
+// Recursively construct a multidimensional LLVM constant which represents the
+// given literal. The minor-to-major dimension ordering in the constant matches
+// that of the literal. For example, given a [2 x 3 x 4] Literal (dimension 0
+// has size 2, dimension 1 has size 3, dimension 2 has size 4) of primitive
+// type F32 with a minor_to_major value of [0, 1, 2] (column major, so
+// dimension 0 is most minor), an LLVM constant of type
+// [4 x [3 x [2 x float]]] will be returned.
+//
+// multi_index is a multidimensional index into the array. dimension_index is an
+// index into the minor_to_major field in the literal shape. This determines
+// which dimension is iterated over in this level of the recursion. Dimensions
+// are iterated from most major down to most minor (highest dimension_index
+// value down to zero).
+llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index,
+ std::vector<int64>* multi_index,
+ llvm::IRBuilder<>* ir_builder) {
+ const Shape& shape = literal.shape();
+ llvm::Type* ir_element_type =
+ llvm_ir::PrimitiveTypeToIrType(shape.element_type(), ir_builder);
+ if (dimension_index == -1) {
+ // Base case of the recursion. Index into the data field of the protobuf
+ // with the multi index.
+ llvm::Constant* value;
+ switch (shape.element_type()) {
+ case PRED:
+ value = llvm::ConstantInt::get(
+ ir_element_type, LiteralUtil::Get<bool>(literal, *multi_index));
+ break;
+ case U8:
+ value = llvm::ConstantInt::get(
+ ir_element_type, LiteralUtil::Get<uint8>(literal, *multi_index));
+ break;
+ case S32:
+ value = llvm::ConstantInt::get(
+ ir_element_type, LiteralUtil::Get<int32>(literal, *multi_index));
+ break;
+ case U32:
+ value = llvm::ConstantInt::get(
+ ir_element_type, LiteralUtil::Get<uint32>(literal, *multi_index));
+ break;
+ case S64:
+ value = llvm::ConstantInt::get(
+ ir_element_type, LiteralUtil::Get<int64>(literal, *multi_index));
+ break;
+ case U64:
+ value = llvm::ConstantInt::get(
+ ir_element_type, LiteralUtil::Get<uint64>(literal, *multi_index));
+ break;
+ case F32:
+ value = llvm::ConstantFP::get(
+ ir_element_type, LiteralUtil::Get<float>(literal, *multi_index));
+ break;
+ case F64:
+ value = llvm::ConstantFP::get(
+ ir_element_type, LiteralUtil::Get<double>(literal, *multi_index));
+ break;
+ default:
+ LOG(FATAL) << "unsupported type " << shape.element_type();
+ }
+ return value;
+ }
+
+  // The dimension index starts at one less than the rank of the array and
+  // decrements with each recursive call. We want to iterate through the
+  // dimensions in major-to-minor order as we recurse, so we simply index into
+  // minor_to_major to get the dimension number for this level of the
+  // recursion.
+ int64 dimension = shape.layout().minor_to_major(dimension_index);
+
+ // Recursively call LiteralToConstant to construct subarrays for the
+ // more-minor dimensions. Gather the subarrays into a vector for bundling into
+ // a new (higher-dimensional) ConstantArray.
+ std::vector<llvm::Constant*> elements;
+ for (int64 i = 0; i < shape.dimensions(dimension); ++i) {
+ (*multi_index)[dimension] = i;
+ elements.push_back(LiteralToConstant(literal, dimension_index - 1,
+ multi_index, ir_builder));
+ }
+
+ llvm::Type* element_type;
+ if (elements.empty()) {
+ element_type = ir_element_type;
+ for (int i = 0; i < dimension_index; ++i) {
+ int64 index = shape.layout().minor_to_major(i);
+ element_type =
+ llvm::ArrayType::get(element_type, shape.dimensions(index));
+ }
+ } else {
+ element_type = elements[0]->getType();
+ }
+ llvm::ArrayType* aggregate_type =
+ llvm::ArrayType::get(element_type, shape.dimensions(dimension));
+ return llvm::ConstantArray::get(aggregate_type, elements);
+}
+
+} // namespace
+
+llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
+ llvm::IRBuilder<>* ir_builder) {
+ std::vector<int64> multi_index(ShapeUtil::Rank(literal.shape()), 0);
+ llvm::Constant* value = LiteralToConstant(
+ literal, /*dimension_index=*/ShapeUtil::Rank(literal.shape()) - 1,
+ &multi_index, ir_builder);
+ return value;
+}
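+
+// Example use (illustrative sketch; `b` is a hypothetical IRBuilder):
+//
+//   std::unique_ptr<Literal> literal =
+//       LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f});
+//   llvm::Constant* c = ConvertLiteralToIrConstant(*literal, &b);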
+
+llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type,
+ tensorflow::StringPiece name,
+ llvm::IRBuilder<>* ir_builder,
+ int alignment) {
+ return EmitAllocaAtFunctionEntryWithCount(type, nullptr, name, ir_builder,
+ alignment);
+}
+
+llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(
+ llvm::Type* type, llvm::Value* element_count, tensorflow::StringPiece name,
+ llvm::IRBuilder<>* ir_builder, int alignment) {
+ llvm::IRBuilder<>::InsertPoint insert_point = ir_builder->saveIP();
+ llvm::Function* function = ir_builder->GetInsertBlock()->getParent();
+ ir_builder->SetInsertPoint(&function->getEntryBlock(),
+ function->getEntryBlock().getFirstInsertionPt());
+ llvm::AllocaInst* alloca =
+ ir_builder->CreateAlloca(type, element_count, AsStringRef(name));
+ if (alignment != 0) {
+ alloca->setAlignment(alignment);
+ }
+ ir_builder->restoreIP(insert_point);
+ return alloca;
+}
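+
+// Example use (illustrative): allocate a float accumulator in the entry
+// block, where LLVM's mem2reg pass can promote it to a register:
+//
+//   llvm::AllocaInst* accum = EmitAllocaAtFunctionEntry(
+//       b.getFloatTy(), "accum", &b, /*alignment=*/4);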
+
+llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before,
+ tensorflow::StringPiece name,
+ llvm::IRBuilder<>* ir_builder) {
+ return llvm::BasicBlock::Create(
+ /*Context=*/ir_builder->getContext(),
+ /*Name=*/AsStringRef(name),
+ /*Parent=*/ir_builder->GetInsertBlock()->getParent(),
+      /*InsertBefore=*/insert_before);
+}
+
+LlvmIfData EmitIfThenElse(llvm::Value* condition, tensorflow::StringPiece name,
+ llvm::IRBuilder<>* ir_builder, bool emit_else) {
+ llvm_ir::LlvmIfData if_data;
+ if_data.if_block = ir_builder->GetInsertBlock();
+ if_data.true_block = CreateBasicBlock(
+ nullptr, tensorflow::strings::StrCat(name, "-true"), ir_builder);
+ if_data.false_block =
+ emit_else ? CreateBasicBlock(nullptr,
+ tensorflow::strings::StrCat(name, "-false"),
+ ir_builder)
+ : nullptr;
+
+  // This function could in principle handle a block without a terminator, but
+  // that case has not been implemented: splitBasicBlock, used below, requires
+  // a terminator.
+ CHECK_NE(nullptr, if_data.if_block->getTerminator());
+ if_data.after_block = if_data.if_block->splitBasicBlock(
+ ir_builder->GetInsertPoint(),
+ AsStringRef(tensorflow::strings::StrCat(name, "-after")));
+
+ // splitBasicBlock inserts an unconditional terminator that we have
+ // to remove as we want a conditional branch there.
+ if_data.if_block->getTerminator()->eraseFromParent();
+
+ ir_builder->SetInsertPoint(if_data.if_block);
+ ir_builder->CreateCondBr(
+ condition, if_data.true_block,
+ emit_else ? if_data.false_block : if_data.after_block);
+
+ ir_builder->SetInsertPoint(if_data.true_block);
+ ir_builder->CreateBr(if_data.after_block);
+
+ if (emit_else) {
+ ir_builder->SetInsertPoint(if_data.false_block);
+ ir_builder->CreateBr(if_data.after_block);
+ }
+
+ ir_builder->SetInsertPoint(if_data.after_block,
+ if_data.after_block->getFirstInsertionPt());
+
+ return if_data;
+}
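+
+// Example use (illustrative sketch; `cond` is a hypothetical i1 value). The
+// returned true/false blocks already end in branches to after_block, so new
+// code is inserted in front of their terminators:
+//
+//   LlvmIfData if_data = EmitIfThenElse(cond, "guard", &b);
+//   b.SetInsertPoint(if_data.true_block->getTerminator());
+//   /* ... emit code for the true branch here ... */
+//   b.SetInsertPoint(if_data.after_block,
+//                    if_data.after_block->getFirstInsertionPt());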
+
+llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate,
+ llvm::Value* lhs_value, llvm::Value* rhs_value,
+ llvm::IRBuilder<>* ir_builder) {
+ llvm::Value* comparison_result;
+ if (lhs_value->getType()->isIntegerTy()) {
+ comparison_result = ir_builder->CreateICmp(predicate, lhs_value, rhs_value);
+ } else {
+ comparison_result = ir_builder->CreateFCmp(predicate, lhs_value, rhs_value);
+ }
+  // comparison_result is an i1, but the NVPTX codegen incorrectly lowers i1
+  // arrays, so we zero-extend the result to i8 to make it addressable.
+ return ir_builder->CreateZExt(
+ comparison_result, llvm_ir::PrimitiveTypeToIrType(PRED, ir_builder));
+}
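+
+// Example use (illustrative; `lhs` and `rhs` are hypothetical float values):
+// yields an i8 that is 1 when lhs is ordered-less-than rhs:
+//
+//   llvm::Value* pred = EmitComparison(llvm::CmpInst::FCMP_OLT, lhs, rhs, &b);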
+
+// Internal helper that is called from emitted code to log an int64 value with a
+// tag.
+static void LogS64(const char* tag, int64 value) {
+ LOG(INFO) << tag << " (int64): " << value;
+}
+
+void EmitLogging(const char* tag, llvm::Value* value,
+ llvm::IRBuilder<>* ir_builder) {
+ llvm::FunctionType* log_function_type = llvm::FunctionType::get(
+ ir_builder->getVoidTy(),
+ {ir_builder->getInt64Ty(), ir_builder->getInt64Ty()}, /*isVarArg=*/false);
+ ir_builder->CreateCall(
+ log_function_type,
+ ir_builder->CreateIntToPtr(
+ ir_builder->getInt64(tensorflow::bit_cast<int64>(&LogS64)),
+ log_function_type->getPointerTo()),
+ {ir_builder->getInt64(tensorflow::bit_cast<int64>(tag)), value});
+}
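+
+// Example use (illustrative; `index` is a hypothetical i64 value):
+//
+//   EmitLogging("loop index", index, &b);
+//
+// Note that the address of LogS64 is baked into the IR as an integer
+// constant, so the emitted module is only valid inside the process that
+// generated it (fine for JIT execution, not for ahead-of-time serialization).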
+
+void SetTbaaForInstruction(llvm::Instruction* instruction, Shape shape,
+ bool is_pointer_to) {
+ legacy_flags::LlvmUtilFlags* flags = legacy_flags::GetLlvmUtilFlags();
+ if (!flags->xla_emit_tbaa) {
+ return;
+ }
+
+ llvm::MDBuilder metadata_builder(instruction->getContext());
+ llvm::MDNode* root = metadata_builder.createTBAARoot("XLA TBAA");
+ string type_name;
+ if (is_pointer_to) {
+ type_name += "pointer-to ";
+ }
+  // Scalars do not have a layout, so an explicit layout may be omitted. To
+  // make sure that equivalent scalar shapes get the same TBAA node, remove
+  // the (meaningless) explicit layout if one is present.
+ if (ShapeUtil::Rank(shape) == 0) {
+ LayoutUtil::ClearLayout(&shape);
+ } else {
+ CHECK(shape.has_layout());
+ }
+ type_name += shape.ShortDebugString();
+ llvm::MDNode* tbaa_node =
+ metadata_builder.createTBAANode(llvm_ir::AsStringRef(type_name), root);
+ instruction->setMetadata(llvm::LLVMContext::MD_tbaa,
+ metadata_builder.createTBAAStructTagNode(
+ tbaa_node, tbaa_node, /*Offset=*/0));
+}
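+
+// Example use (illustrative): tag a load of an f32[8] buffer so the optimizer
+// can assume it does not alias buffers of other XLA types:
+//
+//   SetTbaaForInstruction(load, ShapeUtil::MakeShape(F32, {8}),
+//                         /*is_pointer_to=*/false);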
+
+void SetAlignmentMetadataForLoad(llvm::LoadInst* load, uint64_t alignment) {
+ llvm::LLVMContext& context = load->getContext();
+ llvm::Type* int64_ty = llvm::Type::getInt64Ty(context);
+ llvm::Constant* alignment_constant =
+ llvm::ConstantInt::get(int64_ty, alignment);
+ llvm::MDBuilder metadata_builder(context);
+ auto* alignment_metadata =
+ metadata_builder.createConstant(alignment_constant);
+ load->setMetadata(llvm::LLVMContext::MD_align,
+ llvm::MDNode::get(context, alignment_metadata));
+}
+
+void SetDereferenceableMetadataForLoad(llvm::LoadInst* load,
+ uint64_t dereferenceable_bytes) {
+ llvm::LLVMContext& context = load->getContext();
+ llvm::Type* int64_ty = llvm::Type::getInt64Ty(context);
+ llvm::Constant* dereferenceable_bytes_constant =
+ llvm::ConstantInt::get(int64_ty, dereferenceable_bytes);
+ llvm::MDBuilder metadata_builder(context);
+ auto* dereferenceable_bytes_metadata =
+ metadata_builder.createConstant(dereferenceable_bytes_constant);
+ load->setMetadata(llvm::LLVMContext::MD_dereferenceable,
+ llvm::MDNode::get(context, dereferenceable_bytes_metadata));
+}
+
+llvm::Instruction* AddRangeMetadata(int64 lower, int64 upper,
+ llvm::Instruction* inst) {
+ llvm::LLVMContext& context = inst->getParent()->getContext();
+ llvm::IntegerType* i32 = llvm::Type::getInt32Ty(context);
+ inst->setMetadata(
+ llvm::LLVMContext::MD_range,
+ llvm::MDNode::get(
+ context,
+ {llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(i32, lower)),
+ llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(i32, upper))}));
+ return inst;
+}
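+
+// Example use (illustrative; `thread_id_call` is a hypothetical instruction
+// whose result is known to lie in the half-open interval [0, 1024)):
+//
+//   AddRangeMetadata(/*lower=*/0, /*upper=*/1024, thread_id_call);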
+
+string SanitizeIrName(string function_name) {
+  // Replace some characters that cannot occur in LLVM names with '_'.
+ std::replace(function_name.begin(), function_name.end(), '.', '_');
+ std::replace(function_name.begin(), function_name.end(), '%', '_');
+ std::replace(function_name.begin(), function_name.end(), '-', '_');
+ return function_name;
+}
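+
+// For example (illustrative), SanitizeIrName("fusion.3%param-1") returns
+// "fusion_3_param_1".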
+
+void SetToFirstInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) {
+ builder->SetInsertPoint(blk, blk->getFirstInsertionPt());
+}
+
+llvm::Value* CreateRor(llvm::Value* rotand, llvm::Value* rotor,
+ llvm::IRBuilder<>* builder) {
+ auto size = rotand->getType()->getPrimitiveSizeInBits();
+ auto size_value = builder->getIntN(size, size);
+ auto mod = [=](llvm::Value* x) { return builder->CreateURem(x, size_value); };
+ return builder->CreateOr(
+ builder->CreateShl(rotand, mod(builder->CreateSub(size_value, rotor))),
+ builder->CreateLShr(rotand, mod(rotor)));
+}
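+
+// CreateRor computes ror(x, n) = (x >> (n mod w)) | (x << ((w - n) mod w))
+// for a w-bit rotand; e.g. for i8, ror(0b00000001, 1) == 0b10000000. The
+// URem wraps rotation amounts >= w correctly because w is a power of two.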
+
+int64 ByteSizeOf(const Shape& shape, const llvm::DataLayout& data_layout) {
+ unsigned pointer_size = data_layout.getPointerSize();
+ return ShapeUtil::ByteSizeOf(shape, pointer_size);
+}
+
+void SetFastMathFlags(llvm::FastMathFlags* fast_math_flags) {
+ auto* flags = legacy_flags::GetLlvmBackendFlags();
+ if (flags->xla_precision_losing_optimizations) {
+ fast_math_flags->setAllowReciprocal();
+ }
+ if (flags->xla_fast_math) {
+ fast_math_flags->setUnsafeAlgebra();
+ }
+}
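+
+// Example use (illustrative): apply the flag-controlled settings to an
+// IRBuilder so that subsequent floating-point instructions pick them up:
+//
+//   llvm::FastMathFlags fmf;
+//   SetFastMathFlags(&fmf);
+//   b.setFastMathFlags(fmf);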
+
+void SetTargetOptions(llvm::TargetOptions* options) {
+ auto* flags = legacy_flags::GetLlvmBackendFlags();
+ options->LessPreciseFPMADOption = options->UnsafeFPMath =
+ flags->xla_fast_math || flags->xla_precision_losing_optimizations;
+ options->NoInfsFPMath = options->NoNaNsFPMath = flags->xla_fast_math;
+}
+
+} // namespace llvm_ir
+} // namespace xla