aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/llvm_ir
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2018-06-01 01:40:14 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-06-01 01:43:12 -0700
commit: c9fb2a51307ca8597b7d2d436fcdd28a88e78ba5 (patch)
tree: 7367d118cfe84dc9c23f41b0e80aeaefd689de8f /tensorflow/compiler/xla/service/llvm_ir
parent: 73e5438b725b46e745e6e910c6557b51a321c70f (diff)
Use ConstantDataArray to lower arrays of constants.
For large constants, creating an llvm::Constant for each element can lead to prohibitively long compile times.
PiperOrigin-RevId: 198843141
Diffstat (limited to 'tensorflow/compiler/xla/service/llvm_ir')
-rw-r--r-- tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc 4
-rw-r--r-- tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc 47
2 files changed, 45 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
index f172b1d87c..d909845a3a 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
@@ -80,8 +80,10 @@ Status FusedIrEmitter::HandleConstant(HloInstruction* constant) {
*ir_builder_->GetInsertBlock()->getModule(), initializer->getType(),
/*isConstant=*/true, llvm::GlobalValue::ExternalLinkage, initializer,
/*Name=*/"");
+ llvm::Constant* shape_constant = llvm::ConstantExpr::getBitCast(
+ global, llvm_ir::ShapeToIrType(literal.shape(), module_)->getPointerTo());
generators_[constant] = [=](const IrArray::Index& index) {
- return IrArray(global, constant->shape())
+ return IrArray(shape_constant, constant->shape())
.EmitReadArrayElement(index, ir_builder_);
};
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
index ec04239b4f..bd45f83fb1 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
@@ -368,15 +368,52 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index,
return llvm::ConstantArray::get(aggregate_type, elements);
}
+template <typename T>
+llvm::Constant* GetConstantDataArray(const Literal& literal,
+ llvm::Module* module) {
+ const T* data = static_cast<const T*>(literal.untyped_data());
+ int64 num_elements = literal.size_bytes() / sizeof(T);
+ return llvm::ConstantDataArray::get(module->getContext(),
+ llvm::makeArrayRef(data, num_elements));
+}
+
} // namespace
llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
llvm::Module* module) {
- std::vector<int64> multi_index(ShapeUtil::Rank(literal.shape()), 0);
- llvm::Constant* value = LiteralToConstant(
- literal, /*dimension_index=*/ShapeUtil::Rank(literal.shape()) - 1,
- &multi_index, module);
- return value;
+ const Shape& shape = literal.shape();
+ // TODO(b/29904935): We can get rid of this switch by exposing a
+ // ConstantDataArray factory method that takes a llvm::Type and a StringRef.
+ switch (shape.element_type()) {
+ case U64:
+ return GetConstantDataArray<uint64>(literal, module);
+ case U32:
+ return GetConstantDataArray<uint32>(literal, module);
+ case U8:
+ return GetConstantDataArray<uint8>(literal, module);
+ case S64:
+ return GetConstantDataArray<int64>(literal, module);
+ case S32:
+ return GetConstantDataArray<int32>(literal, module);
+ case F64:
+ return GetConstantDataArray<double>(literal, module);
+ case F32:
+ return GetConstantDataArray<float>(literal, module);
+ case BF16:
+ case F16:
+ return GetConstantDataArray<uint16>(literal, module);
+ case PRED:
+ return GetConstantDataArray<bool>(literal, module);
+ // TODO(b/29904935): Also use ConstantDataArray for complex numbers.
+ case C64: {
+ int64 dimensions = ShapeUtil::Rank(shape);
+ std::vector<int64> multi_index(dimensions, 0);
+ return LiteralToConstant(literal, /*dimension_index=*/dimensions - 1,
+ &multi_index, module);
+ }
+ default:
+ LOG(FATAL) << "unsupported type " << shape.element_type();
+ }
}
llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type,