/* * Copyright 2018 Google Inc. * * Use of this source code is governed by a BSD-style license that can be * found in the LICENSE file. */ #ifndef SKSL_STANDALONE #ifdef SK_LLVM_AVAILABLE #include "SkSLJIT.h" #include "SkCpu.h" #include "SkRasterPipeline.h" #include "../jumper/SkJumper.h" #include "ir/SkSLExpressionStatement.h" #include "ir/SkSLFunctionCall.h" #include "ir/SkSLFunctionReference.h" #include "ir/SkSLIndexExpression.h" #include "ir/SkSLProgram.h" #include "llvm/ExecutionEngine/RTDyldMemoryManager.h" static constexpr int MAX_VECTOR_COUNT = 16; extern "C" void sksl_pipeline_append(SkRasterPipeline* p, int stage, void* ctx) { p->append((SkRasterPipeline::StockStage) stage, ctx); } #define PTR_SIZE sizeof(void*) extern "C" void sksl_pipeline_append_callback(SkRasterPipeline* p, void* fn) { p->append(fn, nullptr); } extern "C" void sksl_debug_print(float f) { printf("Debug: %f\n", f); } namespace SkSL { static constexpr int STAGE_PARAM_COUNT = 12; static bool ends_with_branch(const Statement& stmt) { switch (stmt.fKind) { case Statement::kBlock_Kind: { const Block& b = (const Block&) stmt; if (b.fStatements.size()) { return ends_with_branch(*b.fStatements.back()); } return false; } case Statement::kBreak_Kind: // fall through case Statement::kContinue_Kind: // fall through case Statement::kReturn_Kind: // fall through return true; default: return false; } } JIT::JIT(Compiler* compiler) : fCompiler(*compiler) { LLVMInitializeNativeTarget(); LLVMInitializeNativeAsmPrinter(); LLVMLinkInMCJIT(); SkASSERT(!SkCpu::Supports(SkCpu::SKX)); // not yet supported if (SkCpu::Supports(SkCpu::HSW)) { fVectorCount = 8; fCPU = "haswell"; } else if (SkCpu::Supports(SkCpu::AVX)) { fVectorCount = 8; fCPU = "ivybridge"; } else { fVectorCount = 4; fCPU = nullptr; } fContext = LLVMContextCreate(); fVoidType = LLVMVoidTypeInContext(fContext); fInt1Type = LLVMInt1TypeInContext(fContext); fInt8Type = LLVMInt8TypeInContext(fContext); fInt8PtrType = LLVMPointerType(fInt8Type, 0); fInt32Type = LLVMInt32TypeInContext(fContext); fInt64Type = LLVMInt64TypeInContext(fContext); fSizeTType = LLVMInt64TypeInContext(fContext); fInt32VectorType = LLVMVectorType(fInt32Type, fVectorCount); fInt32Vector2Type = LLVMVectorType(fInt32Type, 2); fInt32Vector3Type = LLVMVectorType(fInt32Type, 3); fInt32Vector4Type = LLVMVectorType(fInt32Type, 4); fFloat32Type = LLVMFloatTypeInContext(fContext); fFloat32VectorType = LLVMVectorType(fFloat32Type, fVectorCount); fFloat32Vector2Type = LLVMVectorType(fFloat32Type, 2); fFloat32Vector3Type = LLVMVectorType(fFloat32Type, 3); fFloat32Vector4Type = LLVMVectorType(fFloat32Type, 4); } JIT::~JIT() { LLVMOrcDisposeInstance(fJITStack); LLVMContextDispose(fContext); } void JIT::addBuiltinFunction(const char* ourName, const char* realName, LLVMTypeRef returnType, std::vector parameters) { for (const auto& pair : *fProgram->fSymbols) { if (Symbol::kFunctionDeclaration_Kind == pair.second->fKind) { const FunctionDeclaration& f = (const FunctionDeclaration&) *pair.second; if (pair.first != ourName || returnType != this->getType(f.fReturnType) || parameters.size() != f.fParameters.size()) { continue; } for (size_t i = 0; i < parameters.size(); ++i) { if (parameters[i] != this->getType(f.fParameters[i]->fType)) { goto next; } } fFunctions[&f] = LLVMAddFunction(fModule, realName, LLVMFunctionType(returnType, parameters.data(), parameters.size(), false)); } next:; } } void JIT::loadBuiltinFunctions() { this->addBuiltinFunction("abs", "fabs", fFloat32Type, { fFloat32Type }); this->addBuiltinFunction("sin", "sinf", fFloat32Type, { fFloat32Type }); this->addBuiltinFunction("cos", "cosf", fFloat32Type, { fFloat32Type }); this->addBuiltinFunction("tan", "tanf", fFloat32Type, { fFloat32Type }); this->addBuiltinFunction("sqrt", "sqrtf", fFloat32Type, { fFloat32Type }); this->addBuiltinFunction("print", "sksl_debug_print", fVoidType, { fFloat32Type }); } uint64_t JIT::resolveSymbol(const char* name, JIT* jit) { LLVMOrcTargetAddress result; if (!LLVMOrcGetSymbolAddress(jit->fJITStack, &result, name)) { if (!strcmp(name, "_sksl_pipeline_append")) { result = (uint64_t) &sksl_pipeline_append; } else if (!strcmp(name, "_sksl_pipeline_append_callback")) { result = (uint64_t) &sksl_pipeline_append_callback; } else if (!strcmp(name, "_sksl_debug_print")) { result = (uint64_t) &sksl_debug_print; } else { result = llvm::RTDyldMemoryManager::getSymbolAddressInProcess(name); } } SkASSERT(result); return result; } LLVMValueRef JIT::compileFunctionCall(LLVMBuilderRef builder, const FunctionCall& fc) { LLVMValueRef func = fFunctions[&fc.fFunction]; SkASSERT(func); std::vector parameters; for (const auto& a : fc.fArguments) { parameters.push_back(this->compileExpression(builder, *a)); } return LLVMBuildCall(builder, func, parameters.data(), parameters.size(), ""); } LLVMTypeRef JIT::getType(const Type& type) { switch (type.kind()) { case Type::kOther_Kind: if (type.name() == "void") { return fVoidType; } SkASSERT(type.name() == "SkRasterPipeline"); return fInt8PtrType; case Type::kScalar_Kind: if (type.isSigned() || type.isUnsigned()) { return fInt32Type; } if (type.isUnsigned()) { return fInt32Type; } if (type.isFloat()) { return fFloat32Type; } SkASSERT(type.name() == "bool"); return fInt1Type; case Type::kArray_Kind: return LLVMPointerType(this->getType(type.componentType()), 0); case Type::kVector_Kind: if (type.name() == "float2" || type.name() == "half2") { return fFloat32Vector2Type; } if (type.name() == "float3" || type.name() == "half3") { return fFloat32Vector3Type; } if (type.name() == "float4" || type.name() == "half4") { return fFloat32Vector4Type; } if (type.name() == "int2" || type.name() == "short2" || type.name == "byte2") { return fInt32Vector2Type; } if (type.name() == "int3" || type.name() == "short3" || type.name == "byte3") { return fInt32Vector3Type; } if (type.name() == "int4" || type.name() == "short4" || type.name == "byte3") { return fInt32Vector4Type; } // fall through default: ABORT("unsupported type"); } } void JIT::setBlock(LLVMBuilderRef builder, LLVMBasicBlockRef block) { fCurrentBlock = block; LLVMPositionBuilderAtEnd(builder, block); } std::unique_ptr JIT::getLValue(LLVMBuilderRef builder, const Expression& expr) { switch (expr.fKind) { case Expression::kVariableReference_Kind: { class PointerLValue : public LValue { public: PointerLValue(LLVMValueRef ptr) : fPointer(ptr) {} LLVMValueRef load(LLVMBuilderRef builder) override { return LLVMBuildLoad(builder, fPointer, "lvalue load"); } void store(LLVMBuilderRef builder, LLVMValueRef value) override { LLVMBuildStore(builder, value, fPointer); } private: LLVMValueRef fPointer; }; const Variable* var = &((VariableReference&) expr).fVariable; if (var->fStorage == Variable::kParameter_Storage && !(var->fModifiers.fFlags & Modifiers::kOut_Flag) && fPromotedParameters.find(var) == fPromotedParameters.end()) { // promote parameter to variable fPromotedParameters.insert(var); LLVMPositionBuilderAtEnd(builder, fAllocaBlock); LLVMValueRef alloca = LLVMBuildAlloca(builder, this->getType(var->fType), String(var->fName).c_str()); LLVMBuildStore(builder, fVariables[var], alloca); LLVMPositionBuilderAtEnd(builder, fCurrentBlock); fVariables[var] = alloca; } LLVMValueRef ptr = fVariables[var]; return std::unique_ptr(new PointerLValue(ptr)); } case Expression::kTernary_Kind: { class TernaryLValue : public LValue { public: TernaryLValue(JIT* jit, LLVMValueRef test, std::unique_ptr ifTrue, std::unique_ptr ifFalse) : fJIT(*jit) , fTest(test) , fIfTrue(std::move(ifTrue)) , fIfFalse(std::move(ifFalse)) {} LLVMValueRef load(LLVMBuilderRef builder) override { LLVMBasicBlockRef trueBlock = LLVMAppendBasicBlockInContext( fJIT.fContext, fJIT.fCurrentFunction, "true ? ..."); LLVMBasicBlockRef falseBlock = LLVMAppendBasicBlockInContext( fJIT.fContext, fJIT.fCurrentFunction, "false ? ..."); LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fJIT.fContext, fJIT.fCurrentFunction, "ternary merge"); LLVMBuildCondBr(builder, fTest, trueBlock, falseBlock); fJIT.setBlock(builder, trueBlock); LLVMValueRef ifTrue = fIfTrue->load(builder); LLVMBuildBr(builder, merge); fJIT.setBlock(builder, falseBlock); LLVMValueRef ifFalse = fIfTrue->load(builder); LLVMBuildBr(builder, merge); fJIT.setBlock(builder, merge); LLVMTypeRef type = LLVMPointerType(LLVMTypeOf(ifTrue), 0); LLVMValueRef phi = LLVMBuildPhi(builder, type, "?"); LLVMValueRef incomingValues[2] = { ifTrue, ifFalse }; LLVMBasicBlockRef incomingBlocks[2] = { trueBlock, falseBlock }; LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2); return phi; } void store(LLVMBuilderRef builder, LLVMValueRef value) override { LLVMBasicBlockRef trueBlock = LLVMAppendBasicBlockInContext( fJIT.fContext, fJIT.fCurrentFunction, "true ? ..."); LLVMBasicBlockRef falseBlock = LLVMAppendBasicBlockInContext( fJIT.fContext, fJIT.fCurrentFunction, "false ? ..."); LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fJIT.fContext, fJIT.fCurrentFunction, "ternary merge"); LLVMBuildCondBr(builder, fTest, trueBlock, falseBlock); fJIT.setBlock(builder, trueBlock); fIfTrue->store(builder, value); LLVMBuildBr(builder, merge); fJIT.setBlock(builder, falseBlock); fIfTrue->store(builder, value); LLVMBuildBr(builder, merge); fJIT.setBlock(builder, merge); } private: JIT& fJIT; LLVMValueRef fTest; std::unique_ptr fIfTrue; std::unique_ptr fIfFalse; }; const TernaryExpression& t = (const TernaryExpression&) expr; LLVMValueRef test = this->compileExpression(builder, *t.fTest); return std::unique_ptr(new TernaryLValue(this, test, this->getLValue(builder, *t.fIfTrue), this->getLValue(builder, *t.fIfFalse))); } case Expression::kSwizzle_Kind: { class SwizzleLValue : public LValue { public: SwizzleLValue(JIT* jit, LLVMTypeRef type, std::unique_ptr base, std::vector components) : fJIT(*jit) , fType(type) , fBase(std::move(base)) , fComponents(components) {} LLVMValueRef load(LLVMBuilderRef builder) override { LLVMValueRef base = fBase->load(builder); if (fComponents.size() > 1) { LLVMValueRef result = LLVMGetUndef(fType); for (size_t i = 0; i < fComponents.size(); ++i) { LLVMValueRef element = LLVMBuildExtractElement( builder, base, LLVMConstInt(fJIT.fInt32Type, fComponents[i], false), "swizzle extract"); result = LLVMBuildInsertElement(builder, result, element, LLVMConstInt(fJIT.fInt32Type, i, false), "swizzle insert"); } return result; } SkASSERT(fComponents.size() == 1); return LLVMBuildExtractElement(builder, base, LLVMConstInt(fJIT.fInt32Type, fComponents[0], false), "swizzle extract"); } void store(LLVMBuilderRef builder, LLVMValueRef value) override { LLVMValueRef result = fBase->load(builder); if (fComponents.size() > 1) { for (size_t i = 0; i < fComponents.size(); ++i) { LLVMValueRef element = LLVMBuildExtractElement(builder, value, LLVMConstInt( fJIT.fInt32Type, i, false), "swizzle extract"); result = LLVMBuildInsertElement(builder, result, element, LLVMConstInt(fJIT.fInt32Type, fComponents[i], false), "swizzle insert"); } } else { result = LLVMBuildInsertElement(builder, result, value, LLVMConstInt(fJIT.fInt32Type, fComponents[0], false), "swizzle insert"); } fBase->store(builder, result); } private: JIT& fJIT; LLVMTypeRef fType; std::unique_ptr fBase; std::vector fComponents; }; const Swizzle& s = (const Swizzle&) expr; return std::unique_ptr(new SwizzleLValue(this, this->getType(s.fType), this->getLValue(builder, *s.fBase), s.fComponents)); } default: ABORT("unsupported lvalue"); } } JIT::TypeKind JIT::typeKind(const Type& type) { if (type.kind() == Type::kVector_Kind) { return this->typeKind(type.componentType()); } if (type.fName == "int" || type.fName == "short" || type.fName == "byte") { return JIT::kInt_TypeKind; } else if (type.fName == "uint" || type.fName == "ushort" || type.fName == "ubyte") { return JIT::kUInt_TypeKind; } else if (type.fName == "float" || type.fName == "double") { return JIT::kFloat_TypeKind; } ABORT("unsupported type: %s\n", type.description().c_str()); } void JIT::vectorize(LLVMBuilderRef builder, LLVMValueRef* value, int columns) { LLVMValueRef result = LLVMGetUndef(LLVMVectorType(LLVMTypeOf(*value), columns)); for (int i = 0; i < columns; ++i) { result = LLVMBuildInsertElement(builder, result, *value, LLVMConstInt(fInt32Type, i, false), "vectorize"); } *value = result; } void JIT::vectorize(LLVMBuilderRef builder, const BinaryExpression& b, LLVMValueRef* left, LLVMValueRef* right) { if (b.fLeft->fType.kind() == Type::kScalar_Kind && b.fRight->fType.kind() == Type::kVector_Kind) { this->vectorize(builder, left, b.fRight->fType.columns()); } else if (b.fLeft->fType.kind() == Type::kVector_Kind && b.fRight->fType.kind() == Type::kScalar_Kind) { this->vectorize(builder, right, b.fLeft->fType.columns()); } } LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression& b) { #define BINARY(SFunc, UFunc, FFunc) { \ LLVMValueRef left = this->compileExpression(builder, *b.fLeft); \ LLVMValueRef right = this->compileExpression(builder, *b.fRight); \ this->vectorize(builder, b, &left, &right); \ switch (this->typeKind(b.fLeft->fType)) { \ case kInt_TypeKind: \ return SFunc(builder, left, right, "binary"); \ case kUInt_TypeKind: \ return UFunc(builder, left, right, "binary"); \ case kFloat_TypeKind: \ return FFunc(builder, left, right, "binary"); \ default: \ ABORT("unsupported typeKind"); \ } \ } #define COMPOUND(SFunc, UFunc, FFunc) { \ std::unique_ptr lvalue = this->getLValue(builder, *b.fLeft); \ LLVMValueRef left = lvalue->load(builder); \ LLVMValueRef right = this->compileExpression(builder, *b.fRight); \ this->vectorize(builder, b, &left, &right); \ LLVMValueRef result; \ switch (this->typeKind(b.fLeft->fType)) { \ case kInt_TypeKind: \ result = SFunc(builder, left, right, "binary"); \ break; \ case kUInt_TypeKind: \ result = UFunc(builder, left, right, "binary"); \ break; \ case kFloat_TypeKind: \ result = FFunc(builder, left, right, "binary"); \ break; \ default: \ ABORT("unsupported typeKind"); \ } \ lvalue->store(builder, result); \ return result; \ } #define COMPARE(SFunc, SOp, UFunc, UOp, FFunc, FOp) { \ LLVMValueRef left = this->compileExpression(builder, *b.fLeft); \ LLVMValueRef right = this->compileExpression(builder, *b.fRight); \ this->vectorize(builder, b, &left, &right); \ switch (this->typeKind(b.fLeft->fType)) { \ case kInt_TypeKind: \ return SFunc(builder, SOp, left, right, "binary"); \ case kUInt_TypeKind: \ return UFunc(builder, UOp, left, right, "binary"); \ case kFloat_TypeKind: \ return FFunc(builder, FOp, left, right, "binary"); \ default: \ ABORT("unsupported typeKind"); \ } \ } switch (b.fOperator) { case Token::EQ: { std::unique_ptr lvalue = this->getLValue(builder, *b.fLeft); LLVMValueRef result = this->compileExpression(builder, *b.fRight); lvalue->store(builder, result); return result; } case Token::PLUS: BINARY(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd); case Token::MINUS: BINARY(LLVMBuildSub, LLVMBuildSub, LLVMBuildFSub); case Token::STAR: BINARY(LLVMBuildMul, LLVMBuildMul, LLVMBuildFMul); case Token::SLASH: BINARY(LLVMBuildSDiv, LLVMBuildUDiv, LLVMBuildFDiv); case Token::PERCENT: BINARY(LLVMBuildSRem, LLVMBuildURem, LLVMBuildSRem); case Token::BITWISEAND: BINARY(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd); case Token::BITWISEOR: BINARY(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr); case Token::PLUSEQ: COMPOUND(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd); case Token::MINUSEQ: COMPOUND(LLVMBuildSub, LLVMBuildSub, LLVMBuildFSub); case Token::STAREQ: COMPOUND(LLVMBuildMul, LLVMBuildMul, LLVMBuildFMul); case Token::SLASHEQ: COMPOUND(LLVMBuildSDiv, LLVMBuildUDiv, LLVMBuildFDiv); case Token::BITWISEANDEQ: COMPOUND(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd); case Token::BITWISEOREQ: COMPOUND(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr); case Token::EQEQ: COMPARE(LLVMBuildICmp, LLVMIntEQ, LLVMBuildICmp, LLVMIntEQ, LLVMBuildFCmp, LLVMRealOEQ); case Token::NEQ: COMPARE(LLVMBuildICmp, LLVMIntNE, LLVMBuildICmp, LLVMIntNE, LLVMBuildFCmp, LLVMRealONE); case Token::LT: COMPARE(LLVMBuildICmp, LLVMIntSLT, LLVMBuildICmp, LLVMIntULT, LLVMBuildFCmp, LLVMRealOLT); case Token::LTEQ: COMPARE(LLVMBuildICmp, LLVMIntSLE, LLVMBuildICmp, LLVMIntULE, LLVMBuildFCmp, LLVMRealOLE); case Token::GT: COMPARE(LLVMBuildICmp, LLVMIntSGT, LLVMBuildICmp, LLVMIntUGT, LLVMBuildFCmp, LLVMRealOGT); case Token::GTEQ: COMPARE(LLVMBuildICmp, LLVMIntSGE, LLVMBuildICmp, LLVMIntUGE, LLVMBuildFCmp, LLVMRealOGE); case Token::LOGICALAND: { LLVMValueRef left = this->compileExpression(builder, *b.fLeft); LLVMBasicBlockRef ifFalse = fCurrentBlock; LLVMBasicBlockRef ifTrue = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "true && ..."); LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "&& merge"); LLVMBuildCondBr(builder, left, ifTrue, merge); this->setBlock(builder, ifTrue); LLVMValueRef right = this->compileExpression(builder, *b.fRight); LLVMBuildBr(builder, merge); this->setBlock(builder, merge); LLVMValueRef phi = LLVMBuildPhi(builder, fInt1Type, "&&"); LLVMValueRef incomingValues[2] = { right, LLVMConstInt(fInt1Type, 0, false) }; LLVMBasicBlockRef incomingBlocks[2] = { ifTrue, ifFalse }; LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2); return phi; } case Token::LOGICALOR: { LLVMValueRef left = this->compileExpression(builder, *b.fLeft); LLVMBasicBlockRef ifTrue = fCurrentBlock; LLVMBasicBlockRef ifFalse = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "false || ..."); LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "|| merge"); LLVMBuildCondBr(builder, left, merge, ifFalse); this->setBlock(builder, ifFalse); LLVMValueRef right = this->compileExpression(builder, *b.fRight); LLVMBuildBr(builder, merge); this->setBlock(builder, merge); LLVMValueRef phi = LLVMBuildPhi(builder, fInt1Type, "||"); LLVMValueRef incomingValues[2] = { right, LLVMConstInt(fInt1Type, 1, false) }; LLVMBasicBlockRef incomingBlocks[2] = { ifFalse, ifTrue }; LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2); return phi; } default: ABORT("unsupported binary operator"); } } LLVMValueRef JIT::compileIndex(LLVMBuilderRef builder, const IndexExpression& idx) { LLVMValueRef base = this->compileExpression(builder, *idx.fBase); LLVMValueRef index = this->compileExpression(builder, *idx.fIndex); LLVMValueRef ptr = LLVMBuildGEP(builder, base, &index, 1, "index ptr"); return LLVMBuildLoad(builder, ptr, "index load"); } LLVMValueRef JIT::compilePostfix(LLVMBuilderRef builder, const PostfixExpression& p) { std::unique_ptr lvalue = this->getLValue(builder, *p.fOperand); LLVMValueRef result = lvalue->load(builder); LLVMValueRef mod; LLVMValueRef one = LLVMConstInt(this->getType(p.fType), 1, false); switch (p.fOperator) { case Token::PLUSPLUS: switch (this->typeKind(p.fType)) { case kInt_TypeKind: // fall through case kUInt_TypeKind: mod = LLVMBuildAdd(builder, result, one, "++"); break; case kFloat_TypeKind: mod = LLVMBuildFAdd(builder, result, one, "++"); break; default: ABORT("unsupported typeKind"); } break; case Token::MINUSMINUS: switch (this->typeKind(p.fType)) { case kInt_TypeKind: // fall through case kUInt_TypeKind: mod = LLVMBuildSub(builder, result, one, "--"); break; case kFloat_TypeKind: mod = LLVMBuildFSub(builder, result, one, "--"); break; default: ABORT("unsupported typeKind"); } break; default: ABORT("unsupported postfix op"); } lvalue->store(builder, mod); return result; } LLVMValueRef JIT::compilePrefix(LLVMBuilderRef builder, const PrefixExpression& p) { LLVMValueRef one = LLVMConstInt(this->getType(p.fType), 1, false); if (Token::LOGICALNOT == p.fOperator) { LLVMValueRef base = this->compileExpression(builder, *p.fOperand); return LLVMBuildXor(builder, base, one, "!"); } if (Token::MINUS == p.fOperator) { LLVMValueRef base = this->compileExpression(builder, *p.fOperand); return LLVMBuildSub(builder, LLVMConstInt(this->getType(p.fType), 0, false), base, "-"); } std::unique_ptr lvalue = this->getLValue(builder, *p.fOperand); LLVMValueRef raw = lvalue->load(builder); LLVMValueRef result; switch (p.fOperator) { case Token::PLUSPLUS: switch (this->typeKind(p.fType)) { case kInt_TypeKind: // fall through case kUInt_TypeKind: result = LLVMBuildAdd(builder, raw, one, "++"); break; case kFloat_TypeKind: result = LLVMBuildFAdd(builder, raw, one, "++"); break; default: ABORT("unsupported typeKind"); } break; case Token::MINUSMINUS: switch (this->typeKind(p.fType)) { case kInt_TypeKind: // fall through case kUInt_TypeKind: result = LLVMBuildSub(builder, raw, one, "--"); break; case kFloat_TypeKind: result = LLVMBuildFSub(builder, raw, one, "--"); break; default: ABORT("unsupported typeKind"); } break; default: ABORT("unsupported prefix op"); } lvalue->store(builder, result); return result; } LLVMValueRef JIT::compileVariableReference(LLVMBuilderRef builder, const VariableReference& v) { const Variable& var = v.fVariable; if (Variable::kParameter_Storage == var.fStorage && !(var.fModifiers.fFlags & Modifiers::kOut_Flag) && fPromotedParameters.find(&var) == fPromotedParameters.end()) { return fVariables[&var]; } return LLVMBuildLoad(builder, fVariables[&var], String(var.fName).c_str()); } void JIT::appendStage(LLVMBuilderRef builder, const AppendStage& a) { SkASSERT(a.fArguments.size() >= 1); SkASSERT(a.fArguments[0]->fType == *fCompiler.context().fSkRasterPipeline_Type); LLVMValueRef pipeline = this->compileExpression(builder, *a.fArguments[0]); LLVMValueRef stage = LLVMConstInt(fInt32Type, a.fStage, 0); switch (a.fStage) { case SkRasterPipeline::callback: { SkASSERT(a.fArguments.size() == 2); SkASSERT(a.fArguments[1]->fKind == Expression::kFunctionReference_Kind); const FunctionDeclaration& functionDecl = *((FunctionReference&) *a.fArguments[1]).fFunctions[0]; bool found = false; for (const auto& pe : fProgram->fElements) { if (ProgramElement::kFunction_Kind == pe->fKind) { const FunctionDefinition& def = (const FunctionDefinition&) *pe; if (&def.fDeclaration == &functionDecl) { LLVMValueRef fn = this->compileStageFunction(def); LLVMValueRef args[2] = { pipeline, LLVMBuildBitCast(builder, fn, fInt8PtrType, "callback cast") }; LLVMBuildCall(builder, fAppendCallbackFunc, args, 2, ""); found = true; break; } } } SkASSERT(found); break; } default: { LLVMValueRef ctx; if (a.fArguments.size() == 2) { ctx = this->compileExpression(builder, *a.fArguments[1]); ctx = LLVMBuildBitCast(builder, ctx, fInt8PtrType, "context cast"); } else { SkASSERT(a.fArguments.size() == 1); ctx = LLVMConstNull(fInt8PtrType); } LLVMValueRef args[3] = { pipeline, stage, ctx }; LLVMBuildCall(builder, fAppendFunc, args, 3, ""); break; } } } LLVMValueRef JIT::compileConstructor(LLVMBuilderRef builder, const Constructor& c) { switch (c.fType.kind()) { case Type::kScalar_Kind: { SkASSERT(c.fArguments.size() == 1); TypeKind from = this->typeKind(c.fArguments[0]->fType); TypeKind to = this->typeKind(c.fType); LLVMValueRef base = this->compileExpression(builder, *c.fArguments[0]); if (kFloat_TypeKind == to) { if (kInt_TypeKind == from) { return LLVMBuildSIToFP(builder, base, this->getType(c.fType), "cast"); } if (kUInt_TypeKind == from) { return LLVMBuildUIToFP(builder, base, this->getType(c.fType), "cast"); } } if (kInt_TypeKind == to) { if (kFloat_TypeKind == from) { return LLVMBuildFPToSI(builder, base, this->getType(c.fType), "cast"); } if (kUInt_TypeKind == from) { return base; } } if (kUInt_TypeKind == to) { if (kFloat_TypeKind == from) { return LLVMBuildFPToUI(builder, base, this->getType(c.fType), "cast"); } if (kInt_TypeKind == from) { return base; } } ABORT("unsupported constructor"); } case Type::kVector_Kind: { LLVMValueRef vec = LLVMGetUndef(this->getType(c.fType)); if (c.fArguments.size() == 1) { LLVMValueRef value = this->compileExpression(builder, *c.fArguments[0]); for (int i = 0; i < c.fType.columns(); ++i) { vec = LLVMBuildInsertElement(builder, vec, value, LLVMConstInt(fInt32Type, i, false), "vec build"); } } else { SkASSERT(c.fArguments.size() == (size_t) c.fType.columns()); for (int i = 0; i < c.fType.columns(); ++i) { vec = LLVMBuildInsertElement(builder, vec, this->compileExpression(builder, *c.fArguments[i]), LLVMConstInt(fInt32Type, i, false), "vec build"); } } return vec; } default: break; } ABORT("unsupported constructor"); } LLVMValueRef JIT::compileSwizzle(LLVMBuilderRef builder, const Swizzle& s) { LLVMValueRef base = this->compileExpression(builder, *s.fBase); if (s.fComponents.size() > 1) { LLVMValueRef result = LLVMGetUndef(this->getType(s.fType)); for (size_t i = 0; i < s.fComponents.size(); ++i) { LLVMValueRef element = LLVMBuildExtractElement( builder, base, LLVMConstInt(fInt32Type, s.fComponents[i], false), "swizzle extract"); result = LLVMBuildInsertElement(builder, result, element, LLVMConstInt(fInt32Type, i, false), "swizzle insert"); } return result; } SkASSERT(s.fComponents.size() == 1); return LLVMBuildExtractElement(builder, base, LLVMConstInt(fInt32Type, s.fComponents[0], false), "swizzle extract"); } LLVMValueRef JIT::compileTernary(LLVMBuilderRef builder, const TernaryExpression& t) { LLVMValueRef test = this->compileExpression(builder, *t.fTest); LLVMBasicBlockRef trueBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "if true"); LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "if merge"); LLVMBasicBlockRef falseBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "if false"); LLVMBuildCondBr(builder, test, trueBlock, falseBlock); this->setBlock(builder, trueBlock); LLVMValueRef ifTrue = this->compileExpression(builder, *t.fIfTrue); trueBlock = fCurrentBlock; LLVMBuildBr(builder, merge); this->setBlock(builder, falseBlock); LLVMValueRef ifFalse = this->compileExpression(builder, *t.fIfFalse); falseBlock = fCurrentBlock; LLVMBuildBr(builder, merge); this->setBlock(builder, merge); LLVMValueRef phi = LLVMBuildPhi(builder, this->getType(t.fType), "?"); LLVMValueRef incomingValues[2] = { ifTrue, ifFalse }; LLVMBasicBlockRef incomingBlocks[2] = { trueBlock, falseBlock }; LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2); return phi; } LLVMValueRef JIT::compileExpression(LLVMBuilderRef builder, const Expression& expr) { switch (expr.fKind) { case Expression::kAppendStage_Kind: { this->appendStage(builder, (const AppendStage&) expr); return LLVMValueRef(); } case Expression::kBinary_Kind: return this->compileBinary(builder, (BinaryExpression&) expr); case Expression::kBoolLiteral_Kind: return LLVMConstInt(fInt1Type, ((BoolLiteral&) expr).fValue, false); case Expression::kConstructor_Kind: return this->compileConstructor(builder, (Constructor&) expr); case Expression::kIntLiteral_Kind: return LLVMConstInt(this->getType(expr.fType), ((IntLiteral&) expr).fValue, true); case Expression::kFieldAccess_Kind: abort(); case Expression::kFloatLiteral_Kind: return LLVMConstReal(this->getType(expr.fType), ((FloatLiteral&) expr).fValue); case Expression::kFunctionCall_Kind: return this->compileFunctionCall(builder, (FunctionCall&) expr); case Expression::kIndex_Kind: return this->compileIndex(builder, (IndexExpression&) expr); case Expression::kPrefix_Kind: return this->compilePrefix(builder, (PrefixExpression&) expr); case Expression::kPostfix_Kind: return this->compilePostfix(builder, (PostfixExpression&) expr); case Expression::kSetting_Kind: abort(); case Expression::kSwizzle_Kind: return this->compileSwizzle(builder, (Swizzle&) expr); case Expression::kVariableReference_Kind: return this->compileVariableReference(builder, (VariableReference&) expr); case Expression::kTernary_Kind: return this->compileTernary(builder, (TernaryExpression&) expr); case Expression::kTypeReference_Kind: abort(); default: abort(); } ABORT("unsupported expression: %s\n", expr.description().c_str()); } void JIT::compileBlock(LLVMBuilderRef builder, const Block& block) { for (const auto& stmt : block.fStatements) { this->compileStatement(builder, *stmt); } } void JIT::compileVarDeclarations(LLVMBuilderRef builder, const VarDeclarationsStatement& decls) { for (const auto& declStatement : decls.fDeclaration->fVars) { const VarDeclaration& decl = (VarDeclaration&) *declStatement; LLVMPositionBuilderAtEnd(builder, fAllocaBlock); LLVMValueRef alloca = LLVMBuildAlloca(builder, this->getType(decl.fVar->fType), String(decl.fVar->fName).c_str()); fVariables[decl.fVar] = alloca; LLVMPositionBuilderAtEnd(builder, fCurrentBlock); if (decl.fValue) { LLVMValueRef result = this->compileExpression(builder, *decl.fValue); LLVMBuildStore(builder, result, alloca); } } } void JIT::compileIf(LLVMBuilderRef builder, const IfStatement& i) { LLVMValueRef test = this->compileExpression(builder, *i.fTest); LLVMBasicBlockRef ifTrue = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "if true"); LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "if merge"); LLVMBasicBlockRef ifFalse; if (i.fIfFalse) { ifFalse = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "if false"); } else { ifFalse = merge; } LLVMBuildCondBr(builder, test, ifTrue, ifFalse); this->setBlock(builder, ifTrue); this->compileStatement(builder, *i.fIfTrue); if (!ends_with_branch(*i.fIfTrue)) { LLVMBuildBr(builder, merge); } if (i.fIfFalse) { this->setBlock(builder, ifFalse); this->compileStatement(builder, *i.fIfFalse); if (!ends_with_branch(*i.fIfFalse)) { LLVMBuildBr(builder, merge); } } this->setBlock(builder, merge); } void JIT::compileFor(LLVMBuilderRef builder, const ForStatement& f) { if (f.fInitializer) { this->compileStatement(builder, *f.fInitializer); } LLVMBasicBlockRef start; LLVMBasicBlockRef body = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "for body"); LLVMBasicBlockRef next = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "for next"); LLVMBasicBlockRef end = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "for end"); if (f.fTest) { start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "for test"); LLVMBuildBr(builder, start); this->setBlock(builder, start); LLVMValueRef test = this->compileExpression(builder, *f.fTest); LLVMBuildCondBr(builder, test, body, end); } else { start = body; LLVMBuildBr(builder, body); } this->setBlock(builder, body); fBreakTarget.push_back(end); fContinueTarget.push_back(next); this->compileStatement(builder, *f.fStatement); fBreakTarget.pop_back(); fContinueTarget.pop_back(); if (!ends_with_branch(*f.fStatement)) { LLVMBuildBr(builder, next); } this->setBlock(builder, next); if (f.fNext) { this->compileExpression(builder, *f.fNext); } LLVMBuildBr(builder, start); this->setBlock(builder, end); } void JIT::compileDo(LLVMBuilderRef builder, const DoStatement& d) { LLVMBasicBlockRef testBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "do test"); LLVMBasicBlockRef body = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "do body"); LLVMBasicBlockRef end = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "do end"); LLVMBuildBr(builder, body); this->setBlock(builder, testBlock); LLVMValueRef test = this->compileExpression(builder, *d.fTest); LLVMBuildCondBr(builder, test, body, end); this->setBlock(builder, body); fBreakTarget.push_back(end); fContinueTarget.push_back(body); this->compileStatement(builder, *d.fStatement); fBreakTarget.pop_back(); fContinueTarget.pop_back(); if (!ends_with_branch(*d.fStatement)) { LLVMBuildBr(builder, testBlock); } this->setBlock(builder, end); } void JIT::compileWhile(LLVMBuilderRef builder, const WhileStatement& w) { LLVMBasicBlockRef testBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "while test"); LLVMBasicBlockRef body = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "while body"); LLVMBasicBlockRef end = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "while end"); LLVMBuildBr(builder, testBlock); this->setBlock(builder, testBlock); LLVMValueRef test = this->compileExpression(builder, *w.fTest); LLVMBuildCondBr(builder, test, body, end); this->setBlock(builder, body); fBreakTarget.push_back(end); fContinueTarget.push_back(testBlock); this->compileStatement(builder, *w.fStatement); fBreakTarget.pop_back(); fContinueTarget.pop_back(); if (!ends_with_branch(*w.fStatement)) { LLVMBuildBr(builder, testBlock); } this->setBlock(builder, end); } void JIT::compileBreak(LLVMBuilderRef builder, const BreakStatement& b) { LLVMBuildBr(builder, fBreakTarget.back()); } void JIT::compileContinue(LLVMBuilderRef builder, const ContinueStatement& b) { LLVMBuildBr(builder, fContinueTarget.back()); } void JIT::compileReturn(LLVMBuilderRef builder, const ReturnStatement& r) { if (r.fExpression) { LLVMBuildRet(builder, this->compileExpression(builder, *r.fExpression)); } else { LLVMBuildRetVoid(builder); } } void JIT::compileStatement(LLVMBuilderRef builder, const Statement& stmt) { switch (stmt.fKind) { case Statement::kBlock_Kind: this->compileBlock(builder, (Block&) stmt); break; case Statement::kBreak_Kind: this->compileBreak(builder, (BreakStatement&) stmt); break; case Statement::kContinue_Kind: this->compileContinue(builder, (ContinueStatement&) stmt); break; case Statement::kDiscard_Kind: abort(); case Statement::kDo_Kind: this->compileDo(builder, (DoStatement&) stmt); break; case Statement::kExpression_Kind: this->compileExpression(builder, *((ExpressionStatement&) stmt).fExpression); break; case Statement::kFor_Kind: this->compileFor(builder, (ForStatement&) stmt); break; case Statement::kGroup_Kind: abort(); case Statement::kIf_Kind: this->compileIf(builder, (IfStatement&) stmt); break; case Statement::kNop_Kind: break; case Statement::kReturn_Kind: this->compileReturn(builder, (ReturnStatement&) stmt); break; case Statement::kSwitch_Kind: abort(); case Statement::kVarDeclarations_Kind: this->compileVarDeclarations(builder, (VarDeclarationsStatement&) stmt); break; case Statement::kWhile_Kind: this->compileWhile(builder, (WhileStatement&) stmt); break; default: abort(); } } void JIT::compileStageFunctionLoop(const FunctionDefinition& f, LLVMValueRef newFunc) { // loop over fVectorCount pixels, running the body of the stage function for each of them LLVMValueRef oldFunction = fCurrentFunction; fCurrentFunction = newFunc; std::unique_ptr params(new LLVMValueRef[STAGE_PARAM_COUNT]); LLVMGetParams(fCurrentFunction, params.get()); LLVMValueRef programParam = params.get()[1]; LLVMBuilderRef builder = LLVMCreateBuilderInContext(fContext); LLVMBasicBlockRef oldAllocaBlock = fAllocaBlock; LLVMBasicBlockRef oldCurrentBlock = fCurrentBlock; fAllocaBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "alloca"); this->setBlock(builder, fAllocaBlock); // temporaries to store the color channel vectors LLVMValueRef rVec = LLVMBuildAlloca(builder, fFloat32VectorType, "rVec"); LLVMBuildStore(builder, params.get()[4], rVec); LLVMValueRef gVec = LLVMBuildAlloca(builder, fFloat32VectorType, "gVec"); LLVMBuildStore(builder, params.get()[5], gVec); LLVMValueRef bVec = LLVMBuildAlloca(builder, fFloat32VectorType, "bVec"); LLVMBuildStore(builder, params.get()[6], bVec); LLVMValueRef aVec = LLVMBuildAlloca(builder, fFloat32VectorType, "aVec"); LLVMBuildStore(builder, params.get()[7], aVec); LLVMValueRef color = LLVMBuildAlloca(builder, fFloat32Vector4Type, "color"); fVariables[f.fDeclaration.fParameters[1]] = LLVMBuildTrunc(builder, params.get()[3], fInt32Type, "y->Int32"); fVariables[f.fDeclaration.fParameters[2]] = color; LLVMValueRef ivar = LLVMBuildAlloca(builder, fInt32Type, "i"); LLVMBuildStore(builder, LLVMConstInt(fInt32Type, 0, false), ivar); LLVMBasicBlockRef start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "start"); this->setBlock(builder, start); LLVMValueRef iload = LLVMBuildLoad(builder, ivar, "load i"); fVariables[f.fDeclaration.fParameters[0]] = LLVMBuildAdd(builder, LLVMBuildTrunc(builder, params.get()[2], fInt32Type, "x->Int32"), iload, "x"); LLVMValueRef vectorSize = LLVMConstInt(fInt32Type, fVectorCount, false); LLVMValueRef test = LLVMBuildICmp(builder, LLVMIntSLT, iload, vectorSize, "i < vectorSize"); LLVMBasicBlockRef loopBody = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "body"); LLVMBasicBlockRef loopEnd = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "end"); LLVMBuildCondBr(builder, test, loopBody, loopEnd); this->setBlock(builder, loopBody); LLVMValueRef vec = LLVMGetUndef(fFloat32Vector4Type); // extract the r, g, b, and a values from the color channel vectors and store them into "color" for (int i = 0; i < 4; ++i) { vec = LLVMBuildInsertElement(builder, vec, LLVMBuildExtractElement(builder, params.get()[4 + i], iload, "initial"), LLVMConstInt(fInt32Type, i, false), "vec build"); } LLVMBuildStore(builder, vec, color); // write actual loop body this->compileStatement(builder, *f.fBody); // extract the r, g, b, and a values from "color" and stick them back into the color channel // vectors LLVMValueRef colorLoad = LLVMBuildLoad(builder, color, "color load"); LLVMBuildStore(builder, LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, rVec, "rVec"), LLVMBuildExtractElement(builder, colorLoad, LLVMConstInt(fInt32Type, 0, false), "rExtract"), iload, "rInsert"), rVec); LLVMBuildStore(builder, LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, gVec, "gVec"), LLVMBuildExtractElement(builder, colorLoad, LLVMConstInt(fInt32Type, 1, false), "gExtract"), iload, "gInsert"), gVec); LLVMBuildStore(builder, LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, bVec, "bVec"), LLVMBuildExtractElement(builder, colorLoad, LLVMConstInt(fInt32Type, 2, false), "bExtract"), iload, "bInsert"), bVec); LLVMBuildStore(builder, LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, aVec, "aVec"), LLVMBuildExtractElement(builder, colorLoad, LLVMConstInt(fInt32Type, 3, false), "aExtract"), iload, "aInsert"), aVec); LLVMValueRef inc = LLVMBuildAdd(builder, iload, LLVMConstInt(fInt32Type, 1, false), "inc i"); LLVMBuildStore(builder, inc, ivar); LLVMBuildBr(builder, start); this->setBlock(builder, loopEnd); // increment program pointer, call the next stage LLVMValueRef rawNextPtr = LLVMBuildLoad(builder, programParam, "next load"); LLVMTypeRef stageFuncType = LLVMTypeOf(newFunc); LLVMValueRef nextPtr = LLVMBuildBitCast(builder, rawNextPtr, stageFuncType, "cast next->func"); LLVMValueRef nextInc = LLVMBuildIntToPtr(builder, LLVMBuildAdd(builder, LLVMBuildPtrToInt(builder, programParam, fInt64Type, "cast 1"), LLVMConstInt(fInt64Type, PTR_SIZE, false), "add"), LLVMPointerType(fInt8PtrType, 0), "cast 2"); LLVMValueRef args[STAGE_PARAM_COUNT] = { params.get()[0], nextInc, params.get()[2], params.get()[3], LLVMBuildLoad(builder, rVec, "rVec"), LLVMBuildLoad(builder, gVec, "gVec"), LLVMBuildLoad(builder, bVec, "bVec"), LLVMBuildLoad(builder, aVec, "aVec"), params.get()[8], params.get()[9], params.get()[10], params.get()[11] }; LLVMBuildCall(builder, nextPtr, args, STAGE_PARAM_COUNT, ""); LLVMBuildRetVoid(builder); // finish LLVMPositionBuilderAtEnd(builder, fAllocaBlock); LLVMBuildBr(builder, start); LLVMDisposeBuilder(builder); if (LLVMVerifyFunction(fCurrentFunction, LLVMPrintMessageAction)) { ABORT("verify failed\n"); } fAllocaBlock = oldAllocaBlock; fCurrentBlock = oldCurrentBlock; fCurrentFunction = oldFunction; } // FIXME maybe pluggable code generators? Need to do something to separate all // of the normal codegen from the vector codegen and break this up into multiple // classes. bool JIT::getVectorLValue(LLVMBuilderRef builder, const Expression& e, LLVMValueRef out[CHANNELS]) { switch (e.fKind) { case Expression::kVariableReference_Kind: if (fColorParam == &((VariableReference&) e).fVariable) { memcpy(out, fChannels, sizeof(fChannels)); return true; } return false; case Expression::kSwizzle_Kind: { const Swizzle& s = (const Swizzle&) e; LLVMValueRef base[CHANNELS]; if (!this->getVectorLValue(builder, *s.fBase, base)) { return false; } for (size_t i = 0; i < s.fComponents.size(); ++i) { out[i] = base[s.fComponents[i]]; } return true; } default: return false; } } bool JIT::getVectorBinaryOperands(LLVMBuilderRef builder, const Expression& left, LLVMValueRef outLeft[CHANNELS], const Expression& right, LLVMValueRef outRight[CHANNELS]) { if (!this->compileVectorExpression(builder, left, outLeft)) { return false; } int leftColumns = left.fType.columns(); int rightColumns = right.fType.columns(); if (leftColumns == 1 && rightColumns > 1) { for (int i = 1; i < rightColumns; ++i) { outLeft[i] = outLeft[0]; } } if (!this->compileVectorExpression(builder, right, outRight)) { return false; } if (rightColumns == 1 && leftColumns > 1) { for (int i = 1; i < leftColumns; ++i) { outRight[i] = outRight[0]; } } return true; } bool JIT::compileVectorBinary(LLVMBuilderRef builder, const BinaryExpression& b, LLVMValueRef out[CHANNELS]) { LLVMValueRef left[CHANNELS]; LLVMValueRef right[CHANNELS]; #define VECTOR_BINARY(signedOp, unsignedOp, floatOp) { \ if (!this->getVectorBinaryOperands(builder, *b.fLeft, left, *b.fRight, right)) { \ return false; \ } \ for (int i = 0; i < b.fLeft->fType.columns(); ++i) { \ switch (this->typeKind(b.fLeft->fType)) { \ case kInt_TypeKind: \ out[i] = signedOp(builder, left[i], right[i], "binary"); \ break; \ case kUInt_TypeKind: \ out[i] = unsignedOp(builder, left[i], right[i], "binary"); \ break; \ case kFloat_TypeKind: \ out[i] = floatOp(builder, left[i], right[i], "binary"); \ break; \ case kBool_TypeKind: \ SkASSERT(false); \ break; \ } \ } \ return true; \ } switch (b.fOperator) { case Token::EQ: { if (!this->getVectorLValue(builder, *b.fLeft, left)) { return false; } if (!this->compileVectorExpression(builder, *b.fRight, right)) { return false; } int columns = b.fRight->fType.columns(); for (int i = 0; i < columns; ++i) { LLVMBuildStore(builder, right[i], left[i]); } return true; } case Token::PLUS: VECTOR_BINARY(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd); case Token::MINUS: VECTOR_BINARY(LLVMBuildSub, LLVMBuildSub, LLVMBuildFSub); case Token::STAR: VECTOR_BINARY(LLVMBuildMul, LLVMBuildMul, LLVMBuildFMul); case Token::SLASH: VECTOR_BINARY(LLVMBuildSDiv, LLVMBuildUDiv, LLVMBuildFDiv); case Token::PERCENT: VECTOR_BINARY(LLVMBuildSRem, LLVMBuildURem, LLVMBuildSRem); case Token::BITWISEAND: VECTOR_BINARY(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd); case Token::BITWISEOR: VECTOR_BINARY(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr); default: printf("unsupported operator: %s\n", b.description().c_str()); return false; } } bool JIT::compileVectorConstructor(LLVMBuilderRef builder, const Constructor& c, LLVMValueRef out[CHANNELS]) { switch (c.fType.kind()) { case Type::kScalar_Kind: { SkASSERT(c.fArguments.size() == 1); TypeKind from = this->typeKind(c.fArguments[0]->fType); TypeKind to = this->typeKind(c.fType); LLVMValueRef base[CHANNELS]; if (!this->compileVectorExpression(builder, *c.fArguments[0], base)) { return false; } #define CONSTRUCT(fn) \ out[0] = LLVMGetUndef(LLVMVectorType(this->getType(c.fType), fVectorCount)); \ for (int i = 0; i < fVectorCount; ++i) { \ LLVMValueRef index = LLVMConstInt(fInt32Type, i, false); \ LLVMValueRef baseVal = LLVMBuildExtractElement(builder, base[0], index, \ "construct extract"); \ out[0] = LLVMBuildInsertElement(builder, out[0], \ fn(builder, baseVal, this->getType(c.fType), \ "cast"), \ index, "construct insert"); \ } \ return true; if (kFloat_TypeKind == to) { if (kInt_TypeKind == from) { CONSTRUCT(LLVMBuildSIToFP); } if (kUInt_TypeKind == from) { CONSTRUCT(LLVMBuildUIToFP); } } if (kInt_TypeKind == to) { if (kFloat_TypeKind == from) { CONSTRUCT(LLVMBuildFPToSI); } if (kUInt_TypeKind == from) { return true; } } if (kUInt_TypeKind == to) { if (kFloat_TypeKind == from) { CONSTRUCT(LLVMBuildFPToUI); } if (kInt_TypeKind == from) { return base; } } printf("%s\n", c.description().c_str()); ABORT("unsupported constructor"); } case Type::kVector_Kind: { if (c.fArguments.size() == 1) { LLVMValueRef base[CHANNELS]; if (!this->compileVectorExpression(builder, *c.fArguments[0], base)) { return false; } for (int i = 0; i < c.fType.columns(); ++i) { out[i] = base[0]; } } else { SkASSERT(c.fArguments.size() == (size_t) c.fType.columns()); for (int i = 0; i < c.fType.columns(); ++i) { LLVMValueRef base[CHANNELS]; if (!this->compileVectorExpression(builder, *c.fArguments[i], base)) { return false; } out[i] = base[0]; } } return true; } default: break; } ABORT("unsupported constructor"); } bool JIT::compileVectorFloatLiteral(LLVMBuilderRef builder, const FloatLiteral& f, LLVMValueRef out[CHANNELS]) { LLVMValueRef value = LLVMConstReal(this->getType(f.fType), f.fValue); LLVMValueRef values[MAX_VECTOR_COUNT]; for (int i = 0; i < fVectorCount; ++i) { values[i] = value; } out[0] = LLVMConstVector(values, fVectorCount); return true; } bool JIT::compileVectorSwizzle(LLVMBuilderRef builder, const Swizzle& s, LLVMValueRef out[CHANNELS]) { LLVMValueRef all[CHANNELS]; if (!this->compileVectorExpression(builder, *s.fBase, all)) { return false; } for (size_t i = 0; i < s.fComponents.size(); ++i) { out[i] = all[s.fComponents[i]]; } return true; } bool JIT::compileVectorVariableReference(LLVMBuilderRef builder, const VariableReference& v, LLVMValueRef out[CHANNELS]) { if (&v.fVariable == fColorParam) { for (int i = 0; i < CHANNELS; ++i) { out[i] = LLVMBuildLoad(builder, fChannels[i], "variable reference"); } return true; } return false; } bool JIT::compileVectorExpression(LLVMBuilderRef builder, const Expression& expr, LLVMValueRef out[CHANNELS]) { switch (expr.fKind) { case Expression::kBinary_Kind: return this->compileVectorBinary(builder, (const BinaryExpression&) expr, out); case Expression::kConstructor_Kind: return this->compileVectorConstructor(builder, (const Constructor&) expr, out); case Expression::kFloatLiteral_Kind: return this->compileVectorFloatLiteral(builder, (const FloatLiteral&) expr, out); case Expression::kSwizzle_Kind: return this->compileVectorSwizzle(builder, (const Swizzle&) expr, out); case Expression::kVariableReference_Kind: return this->compileVectorVariableReference(builder, (const VariableReference&) expr, out); default: printf("failed expression: %s\n", expr.description().c_str()); return false; } } bool JIT::compileVectorStatement(LLVMBuilderRef builder, const Statement& stmt) { switch (stmt.fKind) { case Statement::kBlock_Kind: for (const auto& s : ((const Block&) stmt).fStatements) { if (!this->compileVectorStatement(builder, *s)) { return false; } } return true; case Statement::kExpression_Kind: LLVMValueRef result; return this->compileVectorExpression(builder, *((const ExpressionStatement&) stmt).fExpression, &result); default: printf("failed statement: %s\n", stmt.description().c_str()); return false; } } bool JIT::compileStageFunctionVector(const FunctionDefinition& f, LLVMValueRef newFunc) { LLVMValueRef oldFunction = fCurrentFunction; fCurrentFunction = newFunc; std::unique_ptr params(new LLVMValueRef[STAGE_PARAM_COUNT]); LLVMGetParams(fCurrentFunction, params.get()); LLVMValueRef programParam = params.get()[1]; LLVMBuilderRef builder = LLVMCreateBuilderInContext(fContext); LLVMBasicBlockRef oldAllocaBlock = fAllocaBlock; LLVMBasicBlockRef oldCurrentBlock = fCurrentBlock; fAllocaBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "alloca"); this->setBlock(builder, fAllocaBlock); fChannels[0] = LLVMBuildAlloca(builder, fFloat32VectorType, "rVec"); LLVMBuildStore(builder, params.get()[4], fChannels[0]); fChannels[1] = LLVMBuildAlloca(builder, fFloat32VectorType, "gVec"); LLVMBuildStore(builder, params.get()[5], fChannels[1]); fChannels[2] = LLVMBuildAlloca(builder, fFloat32VectorType, "bVec"); LLVMBuildStore(builder, params.get()[6], fChannels[2]); fChannels[3] = LLVMBuildAlloca(builder, fFloat32VectorType, "aVec"); LLVMBuildStore(builder, params.get()[7], fChannels[3]); LLVMBasicBlockRef start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "start"); this->setBlock(builder, start); bool success = this->compileVectorStatement(builder, *f.fBody); if (success) { // increment program pointer, call next LLVMValueRef rawNextPtr = LLVMBuildLoad(builder, programParam, "next load"); LLVMTypeRef stageFuncType = LLVMTypeOf(newFunc); LLVMValueRef nextPtr = LLVMBuildBitCast(builder, rawNextPtr, stageFuncType, "cast next->func"); LLVMValueRef nextInc = LLVMBuildIntToPtr(builder, LLVMBuildAdd(builder, LLVMBuildPtrToInt(builder, programParam, fInt64Type, "cast 1"), LLVMConstInt(fInt64Type, PTR_SIZE, false), "add"), LLVMPointerType(fInt8PtrType, 0), "cast 2"); LLVMValueRef args[STAGE_PARAM_COUNT] = { params.get()[0], nextInc, params.get()[2], params.get()[3], LLVMBuildLoad(builder, fChannels[0], "rVec"), LLVMBuildLoad(builder, fChannels[1], "gVec"), LLVMBuildLoad(builder, fChannels[2], "bVec"), LLVMBuildLoad(builder, fChannels[3], "aVec"), params.get()[8], params.get()[9], params.get()[10], params.get()[11] }; LLVMBuildCall(builder, nextPtr, args, STAGE_PARAM_COUNT, ""); LLVMBuildRetVoid(builder); // finish LLVMPositionBuilderAtEnd(builder, fAllocaBlock); LLVMBuildBr(builder, start); LLVMDisposeBuilder(builder); if (LLVMVerifyFunction(fCurrentFunction, LLVMPrintMessageAction)) { ABORT("verify failed\n"); } } else { LLVMDeleteBasicBlock(fAllocaBlock); LLVMDeleteBasicBlock(start); } fAllocaBlock = oldAllocaBlock; fCurrentBlock = oldCurrentBlock; fCurrentFunction = oldFunction; return success; } LLVMValueRef JIT::compileStageFunction(const FunctionDefinition& f) { LLVMTypeRef returnType = fVoidType; LLVMTypeRef parameterTypes[12] = { fSizeTType, LLVMPointerType(fInt8PtrType, 0), fSizeTType, fSizeTType, fFloat32VectorType, fFloat32VectorType, fFloat32VectorType, fFloat32VectorType, fFloat32VectorType, fFloat32VectorType, fFloat32VectorType, fFloat32VectorType }; LLVMTypeRef stageFuncType = LLVMFunctionType(returnType, parameterTypes, 12, false); LLVMValueRef result = LLVMAddFunction(fModule, (String(f.fDeclaration.fName) + "$stage").c_str(), stageFuncType); fColorParam = f.fDeclaration.fParameters[2]; if (!this->compileStageFunctionVector(f, result)) { // vectorization failed, fall back to looping over the pixels this->compileStageFunctionLoop(f, result); } return result; } bool JIT::hasStageSignature(const FunctionDeclaration& f) { return f.fReturnType == *fProgram->fContext->fVoid_Type && f.fParameters.size() == 3 && f.fParameters[0]->fType == *fProgram->fContext->fInt_Type && f.fParameters[0]->fModifiers.fFlags == 0 && f.fParameters[1]->fType == *fProgram->fContext->fInt_Type && f.fParameters[1]->fModifiers.fFlags == 0 && f.fParameters[2]->fType == *fProgram->fContext->fFloat4_Type && f.fParameters[2]->fModifiers.fFlags == (Modifiers::kIn_Flag | Modifiers::kOut_Flag); } LLVMValueRef JIT::compileFunction(const FunctionDefinition& f) { if (this->hasStageSignature(f.fDeclaration)) { this->compileStageFunction(f); // we compile foo$stage *in addition* to compiling foo, as we can't be sure that the intent // was to produce an SkJumper stage just because the signature matched or that the function // is not otherwise called. May need a better way to handle this. } LLVMTypeRef returnType = this->getType(f.fDeclaration.fReturnType); std::vector parameterTypes; for (const auto& p : f.fDeclaration.fParameters) { LLVMTypeRef type = this->getType(p->fType); if (p->fModifiers.fFlags & Modifiers::kOut_Flag) { type = LLVMPointerType(type, 0); } parameterTypes.push_back(type); } fCurrentFunction = LLVMAddFunction(fModule, String(f.fDeclaration.fName).c_str(), LLVMFunctionType(returnType, parameterTypes.data(), parameterTypes.size(), false)); fFunctions[&f.fDeclaration] = fCurrentFunction; std::unique_ptr params(new LLVMValueRef[parameterTypes.size()]); LLVMGetParams(fCurrentFunction, params.get()); for (size_t i = 0; i < f.fDeclaration.fParameters.size(); ++i) { fVariables[f.fDeclaration.fParameters[i]] = params.get()[i]; } LLVMBuilderRef builder = LLVMCreateBuilderInContext(fContext); fAllocaBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "alloca"); LLVMBasicBlockRef start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "start"); fCurrentBlock = start; LLVMPositionBuilderAtEnd(builder, fCurrentBlock); this->compileStatement(builder, *f.fBody); if (!ends_with_branch(*f.fBody)) { if (f.fDeclaration.fReturnType == *fProgram->fContext->fVoid_Type) { LLVMBuildRetVoid(builder); } else { LLVMBuildUnreachable(builder); } } LLVMPositionBuilderAtEnd(builder, fAllocaBlock); LLVMBuildBr(builder, start); LLVMDisposeBuilder(builder); if (LLVMVerifyFunction(fCurrentFunction, LLVMPrintMessageAction)) { ABORT("verify failed\n"); } return fCurrentFunction; } void JIT::createModule() { fPromotedParameters.clear(); fModule = LLVMModuleCreateWithNameInContext("skslmodule", fContext); this->loadBuiltinFunctions(); // LLVM doesn't do void*, have to declare it as int8* LLVMTypeRef appendParams[3] = { fInt8PtrType, fInt32Type, fInt8PtrType }; fAppendFunc = LLVMAddFunction(fModule, "sksl_pipeline_append", LLVMFunctionType(fVoidType, appendParams, 3, false)); LLVMTypeRef appendCallbackParams[2] = { fInt8PtrType, fInt8PtrType }; fAppendCallbackFunc = LLVMAddFunction(fModule, "sksl_pipeline_append_callback", LLVMFunctionType(fVoidType, appendCallbackParams, 2, false)); LLVMTypeRef debugParams[3] = { fFloat32Type }; fDebugFunc = LLVMAddFunction(fModule, "sksl_debug_print", LLVMFunctionType(fVoidType, debugParams, 1, false)); for (const auto& e : fProgram->fElements) { SkASSERT(e->fKind == ProgramElement::kFunction_Kind); this->compileFunction((FunctionDefinition&) *e); } } std::unique_ptr JIT::compile(std::unique_ptr program) { fProgram = std::move(program); this->createModule(); this->optimize(); return std::unique_ptr(new Module(std::move(fProgram), fSharedModule, fJITStack)); } void JIT::optimize() { LLVMPassManagerBuilderRef pmb = LLVMPassManagerBuilderCreate(); LLVMPassManagerBuilderSetOptLevel(pmb, 3); LLVMPassManagerRef functionPM = LLVMCreateFunctionPassManagerForModule(fModule); LLVMPassManagerBuilderPopulateFunctionPassManager(pmb, functionPM); LLVMPassManagerRef modulePM = LLVMCreatePassManager(); LLVMPassManagerBuilderPopulateModulePassManager(pmb, modulePM); LLVMInitializeFunctionPassManager(functionPM); LLVMValueRef func = LLVMGetFirstFunction(fModule); for (;;) { if (!func) { break; } LLVMRunFunctionPassManager(functionPM, func); func = LLVMGetNextFunction(func); } LLVMRunPassManager(modulePM, fModule); LLVMDisposePassManager(functionPM); LLVMDisposePassManager(modulePM); LLVMPassManagerBuilderDispose(pmb); std::string error_string; if (LLVMLoadLibraryPermanently(nullptr)) { ABORT("LLVMLoadLibraryPermanently failed"); } char* defaultTriple = LLVMGetDefaultTargetTriple(); char* error; LLVMTargetRef target; if (LLVMGetTargetFromTriple(defaultTriple, &target, &error)) { ABORT("LLVMGetTargetFromTriple failed"); } if (!LLVMTargetHasJIT(target)) { ABORT("!LLVMTargetHasJIT"); } LLVMTargetMachineRef targetMachine = LLVMCreateTargetMachine(target, defaultTriple, fCPU, nullptr, LLVMCodeGenLevelDefault, LLVMRelocDefault, LLVMCodeModelJITDefault); LLVMDisposeMessage(defaultTriple); LLVMTargetDataRef dataLayout = LLVMCreateTargetDataLayout(targetMachine); LLVMSetModuleDataLayout(fModule, dataLayout); LLVMDisposeTargetData(dataLayout); fJITStack = LLVMOrcCreateInstance(targetMachine); fSharedModule = LLVMOrcMakeSharedModule(fModule); LLVMOrcModuleHandle orcModule; LLVMOrcAddEagerlyCompiledIR(fJITStack, &orcModule, fSharedModule, (LLVMOrcSymbolResolverFn) resolveSymbol, this); LLVMDisposeTargetMachine(targetMachine); } void* JIT::Module::getSymbol(const char* name) { LLVMOrcTargetAddress result; if (LLVMOrcGetSymbolAddress(fJITStack, &result, name)) { ABORT("GetSymbolAddress error"); } if (!result) { ABORT("symbol not found"); } return (void*) result; } void* JIT::Module::getJumperStage(const char* name) { return this->getSymbol((String(name) + "$stage").c_str()); } } // namespace #endif // SK_LLVM_AVAILABLE #endif // SKSL_STANDALONE