diff options
author | Adrian Kuegel <akuegel@google.com> | 2018-06-20 05:47:25 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-20 05:50:01 -0700 |
commit | 352461a3228b13a6b5cc511487580ab4878d07dc (patch) | |
tree | 45b69e75e23ca20ee989d8d86a5b2d36c2737ff0 /tensorflow/compiler/xla/service/llvm_ir | |
parent | 18fd25c19c5c7111d1ba4a1c58718b87a63ad82c (diff) |
Simplify ConvertLiteralToIrConstant()
Also use ConstantDataArray for C64 types.
This allows to delete the old LiteralToDataConstant() method.
PiperOrigin-RevId: 201339634
Diffstat (limited to 'tensorflow/compiler/xla/service/llvm_ir')
-rw-r--r-- | tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc | 165 |
1 files changed, 7 insertions, 158 deletions
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index d18c9dee82..e61a2fd12d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -249,167 +249,16 @@ StatusOr<Shape> DecodeSelfDescribingShapeConstant(const void* shape_ptr, return shape; } -namespace { - -// Recursively construct a multidimensional LLVM constant which represents the -// given literal. The minor-to-major dimension ordering in the constant matches -// that of the literal. For example, given a [2 x 3 x 4] Literal (dimension 0 -// has size 4, dimension 1 has size 3, etc) of primitive type F32 with a -// minor_to_major value of [2, 1, 0] (column major), a LLVM constant of type -// [4 x [3 x [2 x float]] will be returned. -// -// multi_index is a multidimensional index into the array. dimension_index is an -// index into the minor_to_major field in the literal shape. This determines -// which dimension is iterated over in this level of the recursion. Dimensions -// are iterated from most major down to most minor (highest dimension_index -// value down to zero). -llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index, - std::vector<int64>* multi_index, - llvm::Module* module) { - const Shape& shape = literal.shape(); - llvm::Type* ir_element_type = - llvm_ir::PrimitiveTypeToIrType(shape.element_type(), module); - if (dimension_index == -1) { - // Base case of the recursion. Index into the data field of the protobuf - // with the multi index. - llvm::Constant* value; - switch (shape.element_type()) { - case PRED: - value = llvm::ConstantInt::get(ir_element_type, - literal.Get<bool>(*multi_index)); - break; - case U8: - value = llvm::ConstantInt::get(ir_element_type, - literal.Get<uint8>(*multi_index)); - break; - case S32: - value = llvm::ConstantInt::get(ir_element_type, - literal.Get<int32>(*multi_index)); - break; - case U32: - value = llvm::ConstantInt::get(ir_element_type, - literal.Get<uint32>(*multi_index)); - break; - case S64: - value = llvm::ConstantInt::get(ir_element_type, - literal.Get<int64>(*multi_index)); - break; - case U64: - value = llvm::ConstantInt::get(ir_element_type, - literal.Get<uint64>(*multi_index)); - break; - case F32: - value = llvm::ConstantFP::get(ir_element_type, - literal.Get<float>(*multi_index)); - break; - case BF16: - value = llvm::ConstantInt::get( - ir_element_type, - tensorflow::bit_cast<uint16>(literal.Get<bfloat16>(*multi_index))); - break; - case F16: - value = llvm::ConstantFP::get( - ir_element_type, - static_cast<float>(literal.Get<half>(*multi_index))); - break; - case F64: - value = llvm::ConstantFP::get(ir_element_type, - literal.Get<double>(*multi_index)); - break; - case C64: { - complex64 x = literal.Get<complex64>(*multi_index); - value = llvm::ConstantStruct::get( - static_cast<llvm::StructType*>(ir_element_type), - llvm::ConstantFP::get(llvm_ir::PrimitiveTypeToIrType(F32, module), - x.real()), - llvm::ConstantFP::get(llvm_ir::PrimitiveTypeToIrType(F32, module), - x.imag())); - break; - } - default: - LOG(FATAL) << "unsupported type " << shape.element_type(); - } - return value; - } - - // The dimension index starts at the one less than the rank of the array and - // decrements with each recursive call. We want to iterate through the - // dimensions in major-to-minor order as we recurse so just index into - // minor_to_major to get the dimension number for this level of the recursion. - int64 dimension = LayoutUtil::Minor(shape.layout(), dimension_index); - - // Recursively call LiteralToConstant to construct subarrays for the - // more-minor dimensions. Gather the subarrays into a vector for bundling into - // a new (higher-dimensional) ConstantArray. - std::vector<llvm::Constant*> elements; - for (int64 i = 0; i < shape.dimensions(dimension); ++i) { - (*multi_index)[dimension] = i; - elements.push_back( - LiteralToConstant(literal, dimension_index - 1, multi_index, module)); - } - - llvm::Type* element_type; - if (elements.empty()) { - element_type = ir_element_type; - for (int i = 0; i < dimension_index; ++i) { - int64 index = LayoutUtil::Minor(shape.layout(), i); - element_type = - llvm::ArrayType::get(element_type, shape.dimensions(index)); - } - } else { - element_type = elements[0]->getType(); - } - llvm::ArrayType* aggregate_type = - llvm::ArrayType::get(element_type, shape.dimensions(dimension)); - return llvm::ConstantArray::get(aggregate_type, elements); -} - -template <typename T> -llvm::Constant* GetConstantDataArray(const Literal& literal, - llvm::Module* module) { - const T* data = static_cast<const T*>(literal.untyped_data()); - int64 num_elements = literal.size_bytes() / sizeof(T); - return llvm::ConstantDataArray::get(module->getContext(), - llvm::makeArrayRef(data, num_elements)); -} - -} // namespace - llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, llvm::Module* module) { const Shape& shape = literal.shape(); - // TODO(b/29904935): We can get rid of this switch by exposing a - // ConstantDataArray factory method that takes a llvm::Type and a StringRef. - switch (shape.element_type()) { - case U64: - return GetConstantDataArray<uint64>(literal, module); - case U32: - return GetConstantDataArray<uint32>(literal, module); - case U8: - return GetConstantDataArray<uint8>(literal, module); - case S64: - return GetConstantDataArray<int64>(literal, module); - case S32: - return GetConstantDataArray<int32>(literal, module); - case F64: - return GetConstantDataArray<double>(literal, module); - case F32: - return GetConstantDataArray<float>(literal, module); - case BF16: - case F16: - return GetConstantDataArray<uint16>(literal, module); - case PRED: - return GetConstantDataArray<bool>(literal, module); - // TODO(b/29904935): Also use ConstantDataArray for complex numbers. - case C64: { - int64 dimensions = ShapeUtil::Rank(shape); - std::vector<int64> multi_index(dimensions, 0); - return LiteralToConstant(literal, /*dimension_index=*/dimensions - 1, - &multi_index, module); - } - default: - LOG(FATAL) << "unsupported type " << shape.element_type(); - } + llvm::Type* type = shape.element_type() == C64 + ? llvm::Type::getFloatTy(module->getContext()) + : PrimitiveTypeToIrType(shape.element_type(), module); + const char* data = static_cast<const char*>(literal.untyped_data()); + uint64 num_elements = literal.size_bytes() * 8 / GetSizeInBits(type); + return llvm::ConstantDataArray::getRaw( + llvm::StringRef(data, literal.size_bytes()), num_elements, type); } llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type, |