diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-01 01:40:14 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-01 01:43:12 -0700 |
commit | c9fb2a51307ca8597b7d2d436fcdd28a88e78ba5 (patch) | |
tree | 7367d118cfe84dc9c23f41b0e80aeaefd689de8f /tensorflow/compiler/xla/service/llvm_ir | |
parent | 73e5438b725b46e745e6e910c6557b51a321c70f (diff) |
Use ConstantDataArray to lower arrays of constants.
For large constants, creating an llvm::Constant for each element can lead to prohibitively long compile times.
PiperOrigin-RevId: 198843141
Diffstat (limited to 'tensorflow/compiler/xla/service/llvm_ir')
-rw-r--r-- | tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc | 4 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc | 47 |
2 files changed, 45 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index f172b1d87c..d909845a3a 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -80,8 +80,10 @@ Status FusedIrEmitter::HandleConstant(HloInstruction* constant) { *ir_builder_->GetInsertBlock()->getModule(), initializer->getType(), /*isConstant=*/true, llvm::GlobalValue::ExternalLinkage, initializer, /*Name=*/""); + llvm::Constant* shape_constant = llvm::ConstantExpr::getBitCast( + global, llvm_ir::ShapeToIrType(literal.shape(), module_)->getPointerTo()); generators_[constant] = [=](const IrArray::Index& index) { - return IrArray(global, constant->shape()) + return IrArray(shape_constant, constant->shape()) .EmitReadArrayElement(index, ir_builder_); }; diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index ec04239b4f..bd45f83fb1 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -368,15 +368,52 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index, return llvm::ConstantArray::get(aggregate_type, elements); } +template <typename T> +llvm::Constant* GetConstantDataArray(const Literal& literal, + llvm::Module* module) { + const T* data = static_cast<const T*>(literal.untyped_data()); + int64 num_elements = literal.size_bytes() / sizeof(T); + return llvm::ConstantDataArray::get(module->getContext(), + llvm::makeArrayRef(data, num_elements)); +} + } // namespace llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, llvm::Module* module) { - std::vector<int64> multi_index(ShapeUtil::Rank(literal.shape()), 0); - llvm::Constant* value = LiteralToConstant( - literal, /*dimension_index=*/ShapeUtil::Rank(literal.shape()) - 1, - &multi_index, module); - return value; + 
const Shape& shape = literal.shape(); + // TODO(b/29904935): We can get rid of this switch by exposing a + // ConstantDataArray factory method that takes a llvm::Type and a StringRef. + switch (shape.element_type()) { + case U64: + return GetConstantDataArray<uint64>(literal, module); + case U32: + return GetConstantDataArray<uint32>(literal, module); + case U8: + return GetConstantDataArray<uint8>(literal, module); + case S64: + return GetConstantDataArray<int64>(literal, module); + case S32: + return GetConstantDataArray<int32>(literal, module); + case F64: + return GetConstantDataArray<double>(literal, module); + case F32: + return GetConstantDataArray<float>(literal, module); + case BF16: + case F16: + return GetConstantDataArray<uint16>(literal, module); + case PRED: + return GetConstantDataArray<bool>(literal, module); + // TODO(b/29904935): Also use ConstantDataArray for complex numbers. + case C64: { + int64 dimensions = ShapeUtil::Rank(shape); + std::vector<int64> multi_index(dimensions, 0); + return LiteralToConstant(literal, /*dimension_index=*/dimensions - 1, + &multi_index, module); + } + default: + LOG(FATAL) << "unsupported type " << shape.element_type(); + } } llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type, |