aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-01 01:40:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-01 01:43:12 -0700
commitc9fb2a51307ca8597b7d2d436fcdd28a88e78ba5 (patch)
tree7367d118cfe84dc9c23f41b0e80aeaefd689de8f
parent73e5438b725b46e745e6e910c6557b51a321c70f (diff)
Use ConstantDataArray to lower arrays of constants.
For large constants, creating an llvm::Constant for each element can lead to prohibitively long compile times. PiperOrigin-RevId: 198843141
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.cc19
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.h5
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc22
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter.cc5
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc4
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc47
8 files changed, 78 insertions, 30 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index f6c8593632..a4141dee01 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -160,39 +160,44 @@ Status IrEmitter::HandleBitcast(HloInstruction* bitcast) {
return Status::OK();
}
-llvm::GlobalVariable* IrEmitter::EmitGlobalForLiteral(const Literal& literal) {
- llvm::GlobalVariable* result;
+llvm::Constant* IrEmitter::EmitGlobalForLiteral(const Literal& literal) {
+ llvm::Constant* result;
// We avoid creating large constants in the LLVM IR since LLVM is not
// efficient for large constant arrays. We still emit "small enough" constant
// arrays into the Ir, in the off chance the LLVM optimizer can do something
// interesting with it.
+ //
+ // TODO(b/29904935): Remove the large constant pool.
const int kMaxInternalConstantSizeInBytes = 128;
if (external_constant_pool_ &&
ByteSizeOf(literal.shape()) >= kMaxInternalConstantSizeInBytes) {
string global_name = tensorflow::strings::StrCat(
"constant_global_", external_global_constant_counter_++);
- result = new llvm::GlobalVariable(
+ llvm::GlobalVariable* result_global = new llvm::GlobalVariable(
/*Module=*/*module_,
/*Type=*/IrShapeType(literal.shape()),
/*isConstant=*/true,
/*Linkage=*/llvm::GlobalValue::ExternalLinkage,
/*Initializer=*/nullptr,
/*Name=*/AsStringRef(global_name));
- result->setAlignment(MinimumAlignmentForShape(literal.shape()));
+ result_global->setAlignment(MinimumAlignmentForShape(literal.shape()));
external_constant_pool_->Insert(global_name, literal,
MinimumAlignmentForShape(literal.shape()));
+ result = result_global;
} else {
llvm::Constant* initializer =
llvm_ir::ConvertLiteralToIrConstant(literal, module_);
- result = new llvm::GlobalVariable(
+ llvm::GlobalVariable* result_global = new llvm::GlobalVariable(
/*Module=*/*module_,
/*Type=*/initializer->getType(),
/*isConstant=*/true,
/*Linkage=*/llvm::GlobalValue::PrivateLinkage,
/*Initializer=*/initializer,
/*Name=*/"");
- result->setAlignment(MinimumAlignmentForShape(literal.shape()));
+ result_global->setAlignment(MinimumAlignmentForShape(literal.shape()));
+ result = llvm::ConstantExpr::getBitCast(
+ result_global, IrShapeType(literal.shape())->getPointerTo());
}
return result;
}
@@ -200,7 +205,7 @@ llvm::GlobalVariable* IrEmitter::EmitGlobalForLiteral(const Literal& literal) {
Status IrEmitter::HandleConstant(HloInstruction* constant) {
VLOG(2) << "HandleConstant: " << constant->ToString();
const Literal& literal = constant->literal();
- llvm::GlobalVariable* global_for_const;
+ llvm::Constant* global_for_const;
auto it = emitted_literals_.find(&literal);
if (it != emitted_literals_.end()) {
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index f49cfc1dc3..32c536e18f 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -527,7 +527,8 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
llvm::Value* program_buffer_address);
- llvm::GlobalVariable* EmitGlobalForLiteral(const Literal& literal);
+ // Returns a ConstExpr bitcast.
+ llvm::Constant* EmitGlobalForLiteral(const Literal& literal);
const HloModuleConfig& hlo_module_config_;
@@ -548,7 +549,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
}
};
- tensorflow::gtl::FlatMap<const Literal*, llvm::GlobalVariable*,
+ tensorflow::gtl::FlatMap<const Literal*, llvm::Constant*,
LiteralPtrHashFunctor, LiteralPtrEqualityFunctor>
emitted_literals_;
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc
index ed8f375bd6..faac927027 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc
@@ -64,8 +64,8 @@ TEST_F(CpuExternalConstantsTest, BasicNegative) {
// The constant array in this test case is small enough that there is no need
// to externalize it.
TestWithArray(/*rows=*/4, /*cols=*/4, R"(
-CHECK-NOT: @constant_global_0 = external constant [4 x [4 x float]], align 8
-CHECK: @0 = private constant [4 x [4 x float]] {{.*}}, align 8
+CHECK-NOT: @constant_global_0 = external constant [16 x float], align 8
+CHECK: @0 = private constant [16 x float] {{.*}}, align 8
)");
}
} // namespace
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc
index d6e0425c55..3cb25c5c19 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc
@@ -55,8 +55,8 @@ ENTRY main {
)";
string filecheck_pattern = R"(
-CHECK: private constant [2 x [3 x [2 x float]]]
-CHECK-NOT: private constant [2 x [3 x [2 x float]]]
+CHECK: private constant [12 x float]
+CHECK-NOT: private constant [12 x float]
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
@@ -78,30 +78,30 @@ TEST_F(CpuDuplicateConstantsTest, RepeatedTupleConstants) {
HloModule RepeatedConstants
while_body {
- arg_body = (f32[2,1]{1,0}, f32[2]{0}) parameter(0)
- ROOT const = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { { 1 }, { 2 } }, {2, 42} ))
+ arg_body = (f32[2,1]{1,0}, f32[1]{0}) parameter(0)
+ ROOT const = (f32[2,1]{1,0}, f32[1]{0}) constant((f32[2,1], f32[1]) ( f32[2,1] { { 1 }, { 2 } }, {2} ))
}
while_cond {
- arg_cond = (f32[2,1]{1,0}, f32[2]{0}) parameter(0)
+ arg_cond = (f32[2,1]{1,0}, f32[1]{0}) parameter(0)
ROOT unknown = pred[] infeed()
}
ENTRY main {
param = f32[2,3,2] parameter(0)
- const_a = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { { 1 }, { 2 } }, {2, 42} ))
- const_b = (f32[2,1]{1,0}, f32[2]{0}) while((f32[2,1]{1,0}, f32[2]{0}) const_a), condition=while_cond, body=while_body
+ const_a = (f32[2,1]{1,0}, f32[1]{0}) constant((f32[2,1], f32[1]) ( f32[2,1] { { 1 }, { 2 } }, {2} ))
+ const_b = (f32[2,1]{1,0}, f32[1]{0}) while((f32[2,1]{1,0}, f32[1]{0}) const_a), condition=while_cond, body=while_body
- out0 = () outfeed((f32[2,1]{1,0}, f32[2]{0}) const_a)
- ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[2]{0}) const_b)
+ out0 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_a)
+ ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_b)
}
)";
string filecheck_pattern = R"(
+CHECK: private constant [1 x float]
CHECK: private constant [2 x float]
-CHECK: private constant [2 x [1 x float]]
+CHECK-NOT: private constant [1 x float]
CHECK-NOT: private constant [2 x float]
-CHECK-NOT: private constant [2 x [1 x float]]
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc
index 879372eb13..1a948fb4fe 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc
@@ -37,7 +37,7 @@ ENTRY main {
)";
string filecheck_pattern = R"(
-CHECK: private constant [2 x [3 x [2 x float]]]
+CHECK: private constant [12 x float]
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index 1e0db2821a..547af33e9a 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -94,7 +94,10 @@ Status IrEmitter::HandleConstant(HloInstruction* constant) {
<< std::endl
<< " its type: "
<< llvm_ir::DumpToString(*global_for_const->getType());
- bindings_.BindHloToIrValue(*constant, global_for_const);
+ llvm::Constant* shape_constant = llvm::ConstantExpr::getBitCast(
+ global_for_const,
+ llvm_ir::ShapeToIrType(literal.shape(), module_)->getPointerTo());
+ bindings_.BindHloToIrValue(*constant, shape_constant);
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
index f172b1d87c..d909845a3a 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
@@ -80,8 +80,10 @@ Status FusedIrEmitter::HandleConstant(HloInstruction* constant) {
*ir_builder_->GetInsertBlock()->getModule(), initializer->getType(),
/*isConstant=*/true, llvm::GlobalValue::ExternalLinkage, initializer,
/*Name=*/"");
+ llvm::Constant* shape_constant = llvm::ConstantExpr::getBitCast(
+ global, llvm_ir::ShapeToIrType(literal.shape(), module_)->getPointerTo());
generators_[constant] = [=](const IrArray::Index& index) {
- return IrArray(global, constant->shape())
+ return IrArray(shape_constant, constant->shape())
.EmitReadArrayElement(index, ir_builder_);
};
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
index ec04239b4f..bd45f83fb1 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
@@ -368,15 +368,52 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index,
return llvm::ConstantArray::get(aggregate_type, elements);
}
+template <typename T>
+llvm::Constant* GetConstantDataArray(const Literal& literal,
+ llvm::Module* module) {
+ const T* data = static_cast<const T*>(literal.untyped_data());
+ int64 num_elements = literal.size_bytes() / sizeof(T);
+ return llvm::ConstantDataArray::get(module->getContext(),
+ llvm::makeArrayRef(data, num_elements));
+}
+
} // namespace
llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
llvm::Module* module) {
- std::vector<int64> multi_index(ShapeUtil::Rank(literal.shape()), 0);
- llvm::Constant* value = LiteralToConstant(
- literal, /*dimension_index=*/ShapeUtil::Rank(literal.shape()) - 1,
- &multi_index, module);
- return value;
+ const Shape& shape = literal.shape();
+ // TODO(b/29904935): We can get rid of this switch by exposing a
+ // ConstantDataArray factory method that takes a llvm::Type and a StringRef.
+ switch (shape.element_type()) {
+ case U64:
+ return GetConstantDataArray<uint64>(literal, module);
+ case U32:
+ return GetConstantDataArray<uint32>(literal, module);
+ case U8:
+ return GetConstantDataArray<uint8>(literal, module);
+ case S64:
+ return GetConstantDataArray<int64>(literal, module);
+ case S32:
+ return GetConstantDataArray<int32>(literal, module);
+ case F64:
+ return GetConstantDataArray<double>(literal, module);
+ case F32:
+ return GetConstantDataArray<float>(literal, module);
+ case BF16:
+ case F16:
+ return GetConstantDataArray<uint16>(literal, module);
+ case PRED:
+ return GetConstantDataArray<bool>(literal, module);
+ // TODO(b/29904935): Also use ConstantDataArray for complex numbers.
+ case C64: {
+ int64 dimensions = ShapeUtil::Rank(shape);
+ std::vector<int64> multi_index(dimensions, 0);
+ return LiteralToConstant(literal, /*dimension_index=*/dimensions - 1,
+ &multi_index, module);
+ }
+ default:
+ LOG(FATAL) << "unsupported type " << shape.element_type();
+ }
}
llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type,