diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-01 01:40:14 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-01 01:43:12 -0700 |
commit | c9fb2a51307ca8597b7d2d436fcdd28a88e78ba5 (patch) | |
tree | 7367d118cfe84dc9c23f41b0e80aeaefd689de8f | |
parent | 73e5438b725b46e745e6e910c6557b51a321c70f (diff) |
Use ConstantDataArray to lower arrays of constants.
For large constants, creating an llvm::Constant for each element can lead to prohibitively large compile times.
PiperOrigin-RevId: 198843141
8 files changed, 78 insertions, 30 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index f6c8593632..a4141dee01 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -160,39 +160,44 @@ Status IrEmitter::HandleBitcast(HloInstruction* bitcast) { return Status::OK(); } -llvm::GlobalVariable* IrEmitter::EmitGlobalForLiteral(const Literal& literal) { - llvm::GlobalVariable* result; +llvm::Constant* IrEmitter::EmitGlobalForLiteral(const Literal& literal) { + llvm::Constant* result; // We avoid creating large constants in the LLVM IR since LLVM is not // efficient for large constant arrays. We still emit "small enough" constant // arrays into the Ir, in the off chance the LLVM optimizer can do something // interesting with it. + // + // TODO(b/29904935): Remove the large constant pool. const int kMaxInternalConstantSizeInBytes = 128; if (external_constant_pool_ && ByteSizeOf(literal.shape()) >= kMaxInternalConstantSizeInBytes) { string global_name = tensorflow::strings::StrCat( "constant_global_", external_global_constant_counter_++); - result = new llvm::GlobalVariable( + llvm::GlobalVariable* result_global = new llvm::GlobalVariable( /*Module=*/*module_, /*Type=*/IrShapeType(literal.shape()), /*isConstant=*/true, /*Linkage=*/llvm::GlobalValue::ExternalLinkage, /*Initializer=*/nullptr, /*Name=*/AsStringRef(global_name)); - result->setAlignment(MinimumAlignmentForShape(literal.shape())); + result_global->setAlignment(MinimumAlignmentForShape(literal.shape())); external_constant_pool_->Insert(global_name, literal, MinimumAlignmentForShape(literal.shape())); + result = result_global; } else { llvm::Constant* initializer = llvm_ir::ConvertLiteralToIrConstant(literal, module_); - result = new llvm::GlobalVariable( + llvm::GlobalVariable* result_global = new llvm::GlobalVariable( /*Module=*/*module_, /*Type=*/initializer->getType(), /*isConstant=*/true, 
/*Linkage=*/llvm::GlobalValue::PrivateLinkage, /*Initializer=*/initializer, /*Name=*/""); - result->setAlignment(MinimumAlignmentForShape(literal.shape())); + result_global->setAlignment(MinimumAlignmentForShape(literal.shape())); + result = llvm::ConstantExpr::getBitCast( + result_global, IrShapeType(literal.shape())->getPointerTo()); } return result; } @@ -200,7 +205,7 @@ llvm::GlobalVariable* IrEmitter::EmitGlobalForLiteral(const Literal& literal) { Status IrEmitter::HandleConstant(HloInstruction* constant) { VLOG(2) << "HandleConstant: " << constant->ToString(); const Literal& literal = constant->literal(); - llvm::GlobalVariable* global_for_const; + llvm::Constant* global_for_const; auto it = emitted_literals_.find(&literal); if (it != emitted_literals_.end()) { diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index f49cfc1dc3..32c536e18f 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -527,7 +527,8 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status EmitXfeedTransfer(XfeedKind kind, const Shape& shape, llvm::Value* program_buffer_address); - llvm::GlobalVariable* EmitGlobalForLiteral(const Literal& literal); + // Returns a ConstExpr bitcast. 
+ llvm::Constant* EmitGlobalForLiteral(const Literal& literal); const HloModuleConfig& hlo_module_config_; @@ -548,7 +549,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { } }; - tensorflow::gtl::FlatMap<const Literal*, llvm::GlobalVariable*, + tensorflow::gtl::FlatMap<const Literal*, llvm::Constant*, LiteralPtrHashFunctor, LiteralPtrEqualityFunctor> emitted_literals_; diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc index ed8f375bd6..faac927027 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc @@ -64,8 +64,8 @@ TEST_F(CpuExternalConstantsTest, BasicNegative) { // The constant array in this test case is small enough that there is no need // to externalize it. TestWithArray(/*rows=*/4, /*cols=*/4, R"( -CHECK-NOT: @constant_global_0 = external constant [4 x [4 x float]], align 8 -CHECK: @0 = private constant [4 x [4 x float]] {{.*}}, align 8 +CHECK-NOT: @constant_global_0 = external constant [16 x float], align 8 +CHECK: @0 = private constant [16 x float] {{.*}}, align 8 )"); } } // namespace diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc index d6e0425c55..3cb25c5c19 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc @@ -55,8 +55,8 @@ ENTRY main { )"; string filecheck_pattern = R"( -CHECK: private constant [2 x [3 x [2 x float]]] -CHECK-NOT: private constant [2 x [3 x [2 x float]]] +CHECK: private constant [12 x float] +CHECK-NOT: private constant [12 x float] )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, @@ -78,30 +78,30 @@ TEST_F(CpuDuplicateConstantsTest, RepeatedTupleConstants) { HloModule 
RepeatedConstants while_body { - arg_body = (f32[2,1]{1,0}, f32[2]{0}) parameter(0) - ROOT const = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { { 1 }, { 2 } }, {2, 42} )) + arg_body = (f32[2,1]{1,0}, f32[1]{0}) parameter(0) + ROOT const = (f32[2,1]{1,0}, f32[1]{0}) constant((f32[2,1], f32[1]) ( f32[2,1] { { 1 }, { 2 } }, {2} )) } while_cond { - arg_cond = (f32[2,1]{1,0}, f32[2]{0}) parameter(0) + arg_cond = (f32[2,1]{1,0}, f32[1]{0}) parameter(0) ROOT unknown = pred[] infeed() } ENTRY main { param = f32[2,3,2] parameter(0) - const_a = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { { 1 }, { 2 } }, {2, 42} )) - const_b = (f32[2,1]{1,0}, f32[2]{0}) while((f32[2,1]{1,0}, f32[2]{0}) const_a), condition=while_cond, body=while_body + const_a = (f32[2,1]{1,0}, f32[1]{0}) constant((f32[2,1], f32[1]) ( f32[2,1] { { 1 }, { 2 } }, {2} )) + const_b = (f32[2,1]{1,0}, f32[1]{0}) while((f32[2,1]{1,0}, f32[1]{0}) const_a), condition=while_cond, body=while_body - out0 = () outfeed((f32[2,1]{1,0}, f32[2]{0}) const_a) - ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[2]{0}) const_b) + out0 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_a) + ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_b) } )"; string filecheck_pattern = R"( +CHECK: private constant [1 x float] CHECK: private constant [2 x float] -CHECK: private constant [2 x [1 x float]] +CHECK-NOT: private constant [1 x float] CHECK-NOT: private constant [2 x float] -CHECK-NOT: private constant [2 x [1 x float]] )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc index 879372eb13..1a948fb4fe 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc @@ -37,7 +37,7 @@ ENTRY main { )"; string filecheck_pattern = R"( -CHECK: private constant [2 x [3 x [2 x 
float]]] +CHECK: private constant [12 x float] )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 1e0db2821a..547af33e9a 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -94,7 +94,10 @@ Status IrEmitter::HandleConstant(HloInstruction* constant) { << std::endl << " its type: " << llvm_ir::DumpToString(*global_for_const->getType()); - bindings_.BindHloToIrValue(*constant, global_for_const); + llvm::Constant* shape_constant = llvm::ConstantExpr::getBitCast( + global_for_const, + llvm_ir::ShapeToIrType(literal.shape(), module_)->getPointerTo()); + bindings_.BindHloToIrValue(*constant, shape_constant); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index f172b1d87c..d909845a3a 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -80,8 +80,10 @@ Status FusedIrEmitter::HandleConstant(HloInstruction* constant) { *ir_builder_->GetInsertBlock()->getModule(), initializer->getType(), /*isConstant=*/true, llvm::GlobalValue::ExternalLinkage, initializer, /*Name=*/""); + llvm::Constant* shape_constant = llvm::ConstantExpr::getBitCast( + global, llvm_ir::ShapeToIrType(literal.shape(), module_)->getPointerTo()); generators_[constant] = [=](const IrArray::Index& index) { - return IrArray(global, constant->shape()) + return IrArray(shape_constant, constant->shape()) .EmitReadArrayElement(index, ir_builder_); }; diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index ec04239b4f..bd45f83fb1 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ 
-368,15 +368,52 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index, return llvm::ConstantArray::get(aggregate_type, elements); } +template <typename T> +llvm::Constant* GetConstantDataArray(const Literal& literal, + llvm::Module* module) { + const T* data = static_cast<const T*>(literal.untyped_data()); + int64 num_elements = literal.size_bytes() / sizeof(T); + return llvm::ConstantDataArray::get(module->getContext(), + llvm::makeArrayRef(data, num_elements)); +} + } // namespace llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, llvm::Module* module) { - std::vector<int64> multi_index(ShapeUtil::Rank(literal.shape()), 0); - llvm::Constant* value = LiteralToConstant( - literal, /*dimension_index=*/ShapeUtil::Rank(literal.shape()) - 1, - &multi_index, module); - return value; + const Shape& shape = literal.shape(); + // TODO(b/29904935): We can get rid of this switch by exposing a + // ConstantDataArray factory method that takes a llvm::Type and a StringRef. + switch (shape.element_type()) { + case U64: + return GetConstantDataArray<uint64>(literal, module); + case U32: + return GetConstantDataArray<uint32>(literal, module); + case U8: + return GetConstantDataArray<uint8>(literal, module); + case S64: + return GetConstantDataArray<int64>(literal, module); + case S32: + return GetConstantDataArray<int32>(literal, module); + case F64: + return GetConstantDataArray<double>(literal, module); + case F32: + return GetConstantDataArray<float>(literal, module); + case BF16: + case F16: + return GetConstantDataArray<uint16>(literal, module); + case PRED: + return GetConstantDataArray<bool>(literal, module); + // TODO(b/29904935): Also use ConstantDataArray for complex numbers. 
+ case C64: { + int64 dimensions = ShapeUtil::Rank(shape); + std::vector<int64> multi_index(dimensions, 0); + return LiteralToConstant(literal, /*dimension_index=*/dimensions - 1, + &multi_index, module); + } + default: + LOG(FATAL) << "unsupported type " << shape.element_type(); + } } llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type, |