about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/compiler/xla/service/llvm_ir
diff options
context:
space:
mode:
author: Adrian Kuegel <akuegel@google.com> 2018-06-20 05:47:25 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-06-20 05:50:01 -0700
commit 352461a3228b13a6b5cc511487580ab4878d07dc (patch)
tree 45b69e75e23ca20ee989d8d86a5b2d36c2737ff0 /tensorflow/compiler/xla/service/llvm_ir
parent 18fd25c19c5c7111d1ba4a1c58718b87a63ad82c (diff)
Simplify ConvertLiteralToIrConstant()
Also use ConstantDataArray for C64 types. This allows deleting the old LiteralToDataConstant() method. PiperOrigin-RevId: 201339634
Diffstat (limited to 'tensorflow/compiler/xla/service/llvm_ir')
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc  165
1 file changed, 7 insertions, 158 deletions
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
index d18c9dee82..e61a2fd12d 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
@@ -249,167 +249,16 @@ StatusOr<Shape> DecodeSelfDescribingShapeConstant(const void* shape_ptr,
return shape;
}
-namespace {
-
-// Recursively construct a multidimensional LLVM constant which represents the
-// given literal. The minor-to-major dimension ordering in the constant matches
-// that of the literal. For example, given a [2 x 3 x 4] Literal (dimension 0
-// has size 4, dimension 1 has size 3, etc) of primitive type F32 with a
-// minor_to_major value of [2, 1, 0] (column major), a LLVM constant of type
-// [4 x [3 x [2 x float]] will be returned.
-//
-// multi_index is a multidimensional index into the array. dimension_index is an
-// index into the minor_to_major field in the literal shape. This determines
-// which dimension is iterated over in this level of the recursion. Dimensions
-// are iterated from most major down to most minor (highest dimension_index
-// value down to zero).
-llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index,
- std::vector<int64>* multi_index,
- llvm::Module* module) {
- const Shape& shape = literal.shape();
- llvm::Type* ir_element_type =
- llvm_ir::PrimitiveTypeToIrType(shape.element_type(), module);
- if (dimension_index == -1) {
- // Base case of the recursion. Index into the data field of the protobuf
- // with the multi index.
- llvm::Constant* value;
- switch (shape.element_type()) {
- case PRED:
- value = llvm::ConstantInt::get(ir_element_type,
- literal.Get<bool>(*multi_index));
- break;
- case U8:
- value = llvm::ConstantInt::get(ir_element_type,
- literal.Get<uint8>(*multi_index));
- break;
- case S32:
- value = llvm::ConstantInt::get(ir_element_type,
- literal.Get<int32>(*multi_index));
- break;
- case U32:
- value = llvm::ConstantInt::get(ir_element_type,
- literal.Get<uint32>(*multi_index));
- break;
- case S64:
- value = llvm::ConstantInt::get(ir_element_type,
- literal.Get<int64>(*multi_index));
- break;
- case U64:
- value = llvm::ConstantInt::get(ir_element_type,
- literal.Get<uint64>(*multi_index));
- break;
- case F32:
- value = llvm::ConstantFP::get(ir_element_type,
- literal.Get<float>(*multi_index));
- break;
- case BF16:
- value = llvm::ConstantInt::get(
- ir_element_type,
- tensorflow::bit_cast<uint16>(literal.Get<bfloat16>(*multi_index)));
- break;
- case F16:
- value = llvm::ConstantFP::get(
- ir_element_type,
- static_cast<float>(literal.Get<half>(*multi_index)));
- break;
- case F64:
- value = llvm::ConstantFP::get(ir_element_type,
- literal.Get<double>(*multi_index));
- break;
- case C64: {
- complex64 x = literal.Get<complex64>(*multi_index);
- value = llvm::ConstantStruct::get(
- static_cast<llvm::StructType*>(ir_element_type),
- llvm::ConstantFP::get(llvm_ir::PrimitiveTypeToIrType(F32, module),
- x.real()),
- llvm::ConstantFP::get(llvm_ir::PrimitiveTypeToIrType(F32, module),
- x.imag()));
- break;
- }
- default:
- LOG(FATAL) << "unsupported type " << shape.element_type();
- }
- return value;
- }
-
- // The dimension index starts at the one less than the rank of the array and
- // decrements with each recursive call. We want to iterate through the
- // dimensions in major-to-minor order as we recurse so just index into
- // minor_to_major to get the dimension number for this level of the recursion.
- int64 dimension = LayoutUtil::Minor(shape.layout(), dimension_index);
-
- // Recursively call LiteralToConstant to construct subarrays for the
- // more-minor dimensions. Gather the subarrays into a vector for bundling into
- // a new (higher-dimensional) ConstantArray.
- std::vector<llvm::Constant*> elements;
- for (int64 i = 0; i < shape.dimensions(dimension); ++i) {
- (*multi_index)[dimension] = i;
- elements.push_back(
- LiteralToConstant(literal, dimension_index - 1, multi_index, module));
- }
-
- llvm::Type* element_type;
- if (elements.empty()) {
- element_type = ir_element_type;
- for (int i = 0; i < dimension_index; ++i) {
- int64 index = LayoutUtil::Minor(shape.layout(), i);
- element_type =
- llvm::ArrayType::get(element_type, shape.dimensions(index));
- }
- } else {
- element_type = elements[0]->getType();
- }
- llvm::ArrayType* aggregate_type =
- llvm::ArrayType::get(element_type, shape.dimensions(dimension));
- return llvm::ConstantArray::get(aggregate_type, elements);
-}
-
-template <typename T>
-llvm::Constant* GetConstantDataArray(const Literal& literal,
- llvm::Module* module) {
- const T* data = static_cast<const T*>(literal.untyped_data());
- int64 num_elements = literal.size_bytes() / sizeof(T);
- return llvm::ConstantDataArray::get(module->getContext(),
- llvm::makeArrayRef(data, num_elements));
-}
-
-} // namespace
-
llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
llvm::Module* module) {
const Shape& shape = literal.shape();
- // TODO(b/29904935): We can get rid of this switch by exposing a
- // ConstantDataArray factory method that takes a llvm::Type and a StringRef.
- switch (shape.element_type()) {
- case U64:
- return GetConstantDataArray<uint64>(literal, module);
- case U32:
- return GetConstantDataArray<uint32>(literal, module);
- case U8:
- return GetConstantDataArray<uint8>(literal, module);
- case S64:
- return GetConstantDataArray<int64>(literal, module);
- case S32:
- return GetConstantDataArray<int32>(literal, module);
- case F64:
- return GetConstantDataArray<double>(literal, module);
- case F32:
- return GetConstantDataArray<float>(literal, module);
- case BF16:
- case F16:
- return GetConstantDataArray<uint16>(literal, module);
- case PRED:
- return GetConstantDataArray<bool>(literal, module);
- // TODO(b/29904935): Also use ConstantDataArray for complex numbers.
- case C64: {
- int64 dimensions = ShapeUtil::Rank(shape);
- std::vector<int64> multi_index(dimensions, 0);
- return LiteralToConstant(literal, /*dimension_index=*/dimensions - 1,
- &multi_index, module);
- }
- default:
- LOG(FATAL) << "unsupported type " << shape.element_type();
- }
+ llvm::Type* type = shape.element_type() == C64
+ ? llvm::Type::getFloatTy(module->getContext())
+ : PrimitiveTypeToIrType(shape.element_type(), module);
+ const char* data = static_cast<const char*>(literal.untyped_data());
+ uint64 num_elements = literal.size_bytes() * 8 / GetSizeInBits(type);
+ return llvm::ConstantDataArray::getRaw(
+ llvm::StringRef(data, literal.size_bytes()), num_elements, type);
}
llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type,