diff options
author | Ethan Nicholas <ethannicholas@google.com> | 2018-03-27 14:10:52 -0400 |
---|---|---|
committer | Skia Commit-Bot <skia-commit-bot@chromium.org> | 2018-03-27 18:39:13 +0000 |
commit | 26a9aad63b60c9cbbdfa87c212a4e76ce55e7373 (patch) | |
tree | d4a42afd75bdff6c7815be7bf78b42cab742df19 /src/sksl/SkSLJIT.cpp | |
parent | 3560b58de36988e1fba54d9ac341735ab849e913 (diff) |
initial SkSLJIT checkin
Docs-Preview: https://skia.org/?cl=112204
Bug: skia:
Change-Id: I10042a0200db00bd8ff8078467c409b1cf191f50
Reviewed-on: https://skia-review.googlesource.com/112204
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
Reviewed-by: Mike Klein <mtklein@chromium.org>
Diffstat (limited to 'src/sksl/SkSLJIT.cpp')
-rw-r--r-- | src/sksl/SkSLJIT.cpp | 1747 |
1 files changed, 1747 insertions, 0 deletions
diff --git a/src/sksl/SkSLJIT.cpp b/src/sksl/SkSLJIT.cpp new file mode 100644 index 0000000000..06b6e2c94b --- /dev/null +++ b/src/sksl/SkSLJIT.cpp @@ -0,0 +1,1747 @@ +/* + * Copyright 2018 Google Inc. + * + * Use of this source code is governed by a BSD-style license that can be + * found in the LICENSE file. + */ + +#ifndef SKSL_STANDALONE + +#ifdef SK_LLVM_AVAILABLE + +#include "SkSLJIT.h" + +#include "SkCpu.h" +#include "SkRasterPipeline.h" +#include "../jumper/SkJumper.h" +#include "ir/SkSLExpressionStatement.h" +#include "ir/SkSLFunctionCall.h" +#include "ir/SkSLFunctionReference.h" +#include "ir/SkSLIndexExpression.h" +#include "ir/SkSLProgram.h" +#include "llvm/ExecutionEngine/RTDyldMemoryManager.h" + +static constexpr int MAX_VECTOR_COUNT = 16; + +extern "C" void sksl_pipeline_append(SkRasterPipeline* p, int stage, void* ctx) { + p->append((SkRasterPipeline::StockStage) stage, ctx); +} + +#define PTR_SIZE sizeof(void*) + +extern "C" void sksl_pipeline_append_callback(SkRasterPipeline* p, void* fn) { + p->append(fn, nullptr); +} + +extern "C" void sksl_debug_print(float f) { + printf("Debug: %f\n", f); +} + +namespace SkSL { + +static constexpr int STAGE_PARAM_COUNT = 12; + +static bool ends_with_branch(const Statement& stmt) { + switch (stmt.fKind) { + case Statement::kBlock_Kind: { + const Block& b = (const Block&) stmt; + if (b.fStatements.size()) { + return ends_with_branch(*b.fStatements.back()); + } + return false; + } + case Statement::kBreak_Kind: // fall through + case Statement::kContinue_Kind: // fall through + case Statement::kReturn_Kind: // fall through + return true; + default: + return false; + } +} + +JIT::JIT(Compiler* compiler) +: fCompiler(*compiler) { + LLVMInitializeNativeTarget(); + LLVMInitializeNativeAsmPrinter(); + LLVMLinkInMCJIT(); + ASSERT(!SkCpu::Supports(SkCpu::SKX)); // not yet supported + if (SkCpu::Supports(SkCpu::HSW)) { + fVectorCount = 8; + fCPU = "haswell"; + } else if (SkCpu::Supports(SkCpu::AVX)) { + fVectorCount = 8; + fCPU = "ivybridge"; + } else { + fVectorCount = 4; + fCPU = nullptr; + } + fContext = LLVMContextCreate(); + fVoidType = LLVMVoidTypeInContext(fContext); + fInt1Type = LLVMInt1TypeInContext(fContext); + fInt8Type = LLVMInt8TypeInContext(fContext); + fInt8PtrType = LLVMPointerType(fInt8Type, 0); + fInt32Type = LLVMInt32TypeInContext(fContext); + fInt64Type = LLVMInt64TypeInContext(fContext); + fSizeTType = LLVMInt64TypeInContext(fContext); + fInt32VectorType = LLVMVectorType(fInt32Type, fVectorCount); + fInt32Vector2Type = LLVMVectorType(fInt32Type, 2); + fInt32Vector3Type = LLVMVectorType(fInt32Type, 3); + fInt32Vector4Type = LLVMVectorType(fInt32Type, 4); + fFloat32Type = LLVMFloatTypeInContext(fContext); + fFloat32VectorType = LLVMVectorType(fFloat32Type, fVectorCount); + fFloat32Vector2Type = LLVMVectorType(fFloat32Type, 2); + fFloat32Vector3Type = LLVMVectorType(fFloat32Type, 3); + fFloat32Vector4Type = LLVMVectorType(fFloat32Type, 4); +} + +JIT::~JIT() { + LLVMOrcDisposeInstance(fJITStack); + LLVMContextDispose(fContext); +} + +void JIT::addBuiltinFunction(const char* ourName, const char* realName, LLVMTypeRef returnType, + std::vector<LLVMTypeRef> parameters) { + for (const auto& pair : *fProgram->fSymbols) { + if (Symbol::kFunctionDeclaration_Kind == pair.second->fKind) { + const FunctionDeclaration& f = (const FunctionDeclaration&) *pair.second; + if (pair.first != ourName || returnType != this->getType(f.fReturnType) || + parameters.size() != f.fParameters.size()) { + continue; + } + for (size_t i = 0; i < parameters.size(); ++i) { + if (parameters[i] != this->getType(f.fParameters[i]->fType)) { + goto next; + } + } + fFunctions[&f] = LLVMAddFunction(fModule, realName, LLVMFunctionType(returnType, + parameters.data(), + parameters.size(), + false)); + } + next:; + } +} + +void JIT::loadBuiltinFunctions() { + this->addBuiltinFunction("abs", "fabs", fFloat32Type, { fFloat32Type }); + this->addBuiltinFunction("sin", "sinf", fFloat32Type, { fFloat32Type }); + this->addBuiltinFunction("cos", "cosf", fFloat32Type, { fFloat32Type }); + this->addBuiltinFunction("tan", "tanf", fFloat32Type, { fFloat32Type }); + this->addBuiltinFunction("sqrt", "sqrtf", fFloat32Type, { fFloat32Type }); + this->addBuiltinFunction("print", "sksl_debug_print", fVoidType, { fFloat32Type }); +} + +uint64_t JIT::resolveSymbol(const char* name, JIT* jit) { + LLVMOrcTargetAddress result; + if (!LLVMOrcGetSymbolAddress(jit->fJITStack, &result, name)) { + if (!strcmp(name, "_sksl_pipeline_append")) { + result = (uint64_t) &sksl_pipeline_append; + } else if (!strcmp(name, "_sksl_pipeline_append_callback")) { + result = (uint64_t) &sksl_pipeline_append_callback; + } else if (!strcmp(name, "_sksl_debug_print")) { + result = (uint64_t) &sksl_debug_print; + } else { + result = llvm::RTDyldMemoryManager::getSymbolAddressInProcess(name); + } + } + ASSERT(result); + return result; +} + +LLVMValueRef JIT::compileFunctionCall(LLVMBuilderRef builder, const FunctionCall& fc) { + LLVMValueRef func = fFunctions[&fc.fFunction]; + ASSERT(func); + std::vector<LLVMValueRef> parameters; + for (const auto& a : fc.fArguments) { + parameters.push_back(this->compileExpression(builder, *a)); + } + return LLVMBuildCall(builder, func, parameters.data(), parameters.size(), ""); +} + +LLVMTypeRef JIT::getType(const Type& type) { + switch (type.kind()) { + case Type::kOther_Kind: + if (type.name() == "void") { + return fVoidType; + } + ASSERT(type.name() == "SkRasterPipeline"); + return fInt8PtrType; + case Type::kScalar_Kind: + if (type.isSigned() || type.isUnsigned()) { + return fInt32Type; + } + if (type.isUnsigned()) { + return fInt32Type; + } + if (type.isFloat()) { + return fFloat32Type; + } + ASSERT(type.name() == "bool"); + return fInt1Type; + case Type::kArray_Kind: + return LLVMPointerType(this->getType(type.componentType()), 0); + case Type::kVector_Kind: + if (type.name() == "float2" || type.name() == "half2") { + return fFloat32Vector2Type; + } + if (type.name() == "float3" || type.name() == "half3") { + return fFloat32Vector3Type; + } + if (type.name() == "float4" || type.name() == "half4") { + return fFloat32Vector4Type; + } + if (type.name() == "int2" || type.name() == "short2") { + return fInt32Vector2Type; + } + if (type.name() == "int3" || type.name() == "short3") { + return fInt32Vector3Type; + } + if (type.name() == "int4" || type.name() == "short4") { + return fInt32Vector4Type; + } + // fall through + default: + ABORT("unsupported type"); + } +} + +void JIT::setBlock(LLVMBuilderRef builder, LLVMBasicBlockRef block) { + fCurrentBlock = block; + LLVMPositionBuilderAtEnd(builder, block); +} + +std::unique_ptr<JIT::LValue> JIT::getLValue(LLVMBuilderRef builder, const Expression& expr) { + switch (expr.fKind) { + case Expression::kVariableReference_Kind: { + class PointerLValue : public LValue { + public: + PointerLValue(LLVMValueRef ptr) + : fPointer(ptr) {} + + LLVMValueRef load(LLVMBuilderRef builder) override { + return LLVMBuildLoad(builder, fPointer, "lvalue load"); + } + + void store(LLVMBuilderRef builder, LLVMValueRef value) override { + LLVMBuildStore(builder, value, fPointer); + } + + private: + LLVMValueRef fPointer; + }; + const Variable* var = &((VariableReference&) expr).fVariable; + if (var->fStorage == Variable::kParameter_Storage && + !(var->fModifiers.fFlags & Modifiers::kOut_Flag) && + fPromotedParameters.find(var) == fPromotedParameters.end()) { + // promote parameter to variable + fPromotedParameters.insert(var); + LLVMPositionBuilderAtEnd(builder, fAllocaBlock); + LLVMValueRef alloca = LLVMBuildAlloca(builder, this->getType(var->fType), + String(var->fName).c_str()); + LLVMBuildStore(builder, fVariables[var], alloca); + LLVMPositionBuilderAtEnd(builder, fCurrentBlock); + fVariables[var] = alloca; + } + LLVMValueRef ptr = fVariables[var]; + return std::unique_ptr<LValue>(new PointerLValue(ptr)); + } + case Expression::kTernary_Kind: { + class TernaryLValue : public LValue { + public: + TernaryLValue(JIT* jit, LLVMValueRef test, std::unique_ptr<LValue> ifTrue, + std::unique_ptr<LValue> ifFalse) + : fJIT(*jit) + , fTest(test) + , fIfTrue(std::move(ifTrue)) + , fIfFalse(std::move(ifFalse)) {} + + LLVMValueRef load(LLVMBuilderRef builder) override { + LLVMBasicBlockRef trueBlock = LLVMAppendBasicBlockInContext( + fJIT.fContext, + fJIT.fCurrentFunction, + "true ? ..."); + LLVMBasicBlockRef falseBlock = LLVMAppendBasicBlockInContext( + fJIT.fContext, + fJIT.fCurrentFunction, + "false ? ..."); + LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fJIT.fContext, + fJIT.fCurrentFunction, + "ternary merge"); + LLVMBuildCondBr(builder, fTest, trueBlock, falseBlock); + fJIT.setBlock(builder, trueBlock); + LLVMValueRef ifTrue = fIfTrue->load(builder); + LLVMBuildBr(builder, merge); + fJIT.setBlock(builder, falseBlock); + LLVMValueRef ifFalse = fIfTrue->load(builder); + LLVMBuildBr(builder, merge); + fJIT.setBlock(builder, merge); + LLVMTypeRef type = LLVMPointerType(LLVMTypeOf(ifTrue), 0); + LLVMValueRef phi = LLVMBuildPhi(builder, type, "?"); + LLVMValueRef incomingValues[2] = { ifTrue, ifFalse }; + LLVMBasicBlockRef incomingBlocks[2] = { trueBlock, falseBlock }; + LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2); + return phi; + } + + void store(LLVMBuilderRef builder, LLVMValueRef value) override { + LLVMBasicBlockRef trueBlock = LLVMAppendBasicBlockInContext( + fJIT.fContext, + fJIT.fCurrentFunction, + "true ? ..."); + LLVMBasicBlockRef falseBlock = LLVMAppendBasicBlockInContext( + fJIT.fContext, + fJIT.fCurrentFunction, + "false ? ..."); + LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fJIT.fContext, + fJIT.fCurrentFunction, + "ternary merge"); + LLVMBuildCondBr(builder, fTest, trueBlock, falseBlock); + fJIT.setBlock(builder, trueBlock); + fIfTrue->store(builder, value); + LLVMBuildBr(builder, merge); + fJIT.setBlock(builder, falseBlock); + fIfTrue->store(builder, value); + LLVMBuildBr(builder, merge); + fJIT.setBlock(builder, merge); + } + + private: + JIT& fJIT; + LLVMValueRef fTest; + std::unique_ptr<LValue> fIfTrue; + std::unique_ptr<LValue> fIfFalse; + }; + const TernaryExpression& t = (const TernaryExpression&) expr; + LLVMValueRef test = this->compileExpression(builder, *t.fTest); + return std::unique_ptr<LValue>(new TernaryLValue(this, + test, + this->getLValue(builder, + *t.fIfTrue), + this->getLValue(builder, + *t.fIfFalse))); + } + case Expression::kSwizzle_Kind: { + class SwizzleLValue : public LValue { + public: + SwizzleLValue(JIT* jit, LLVMTypeRef type, std::unique_ptr<LValue> base, + std::vector<int> components) + : fJIT(*jit) + , fType(type) + , fBase(std::move(base)) + , fComponents(components) {} + + LLVMValueRef load(LLVMBuilderRef builder) override { + LLVMValueRef base = fBase->load(builder); + if (fComponents.size() > 1) { + LLVMValueRef result = LLVMGetUndef(fType); + for (size_t i = 0; i < fComponents.size(); ++i) { + LLVMValueRef element = LLVMBuildExtractElement( + builder, + base, + LLVMConstInt(fJIT.fInt32Type, + fComponents[i], + false), + "swizzle extract"); + result = LLVMBuildInsertElement(builder, result, element, + LLVMConstInt(fJIT.fInt32Type, i, false), + "swizzle insert"); + } + return result; + } + ASSERT(fComponents.size() == 1); + return LLVMBuildExtractElement(builder, base, + LLVMConstInt(fJIT.fInt32Type, + fComponents[0], + false), + "swizzle extract"); + } + + void store(LLVMBuilderRef builder, LLVMValueRef value) override { + LLVMValueRef result = fBase->load(builder); + if (fComponents.size() > 1) { + for (size_t i = 0; i < fComponents.size(); ++i) { + LLVMValueRef element = LLVMBuildExtractElement(builder, value, + LLVMConstInt( + fJIT.fInt32Type, + i, + false), + "swizzle extract"); + result = LLVMBuildInsertElement(builder, result, element, + LLVMConstInt(fJIT.fInt32Type, + fComponents[i], + false), + "swizzle insert"); + } + } else { + result = LLVMBuildInsertElement(builder, result, value, + LLVMConstInt(fJIT.fInt32Type, + fComponents[0], + false), + "swizzle insert"); + } + fBase->store(builder, result); + } + + private: + JIT& fJIT; + LLVMTypeRef fType; + std::unique_ptr<LValue> fBase; + std::vector<int> fComponents; + }; + const Swizzle& s = (const Swizzle&) expr; + return std::unique_ptr<LValue>(new SwizzleLValue(this, this->getType(s.fType), + this->getLValue(builder, *s.fBase), + s.fComponents)); + } + default: + ABORT("unsupported lvalue"); + } +} + +JIT::TypeKind JIT::typeKind(const Type& type) { + if (type.kind() == Type::kVector_Kind) { + return this->typeKind(type.componentType()); + } + if (type.fName == "int" || type.fName == "short") { + return JIT::kInt_TypeKind; + } else if (type.fName == "uint" || type.fName == "ushort") { + return JIT::kUInt_TypeKind; + } else if (type.fName == "float" || type.fName == "double") { + return JIT::kFloat_TypeKind; + } + ABORT("unsupported type: %s\n", type.description().c_str()); +} + +void JIT::vectorize(LLVMBuilderRef builder, LLVMValueRef* value, int columns) { + LLVMValueRef result = LLVMGetUndef(LLVMVectorType(LLVMTypeOf(*value), columns)); + for (int i = 0; i < columns; ++i) { + result = LLVMBuildInsertElement(builder, + result, + *value, + LLVMConstInt(fInt32Type, i, false), + "vectorize"); + } + *value = result; +} + +void JIT::vectorize(LLVMBuilderRef builder, const BinaryExpression& b, LLVMValueRef* left, + LLVMValueRef* right) { + if (b.fLeft->fType.kind() == Type::kScalar_Kind && + b.fRight->fType.kind() == Type::kVector_Kind) { + this->vectorize(builder, left, b.fRight->fType.columns()); + } else if (b.fLeft->fType.kind() == Type::kVector_Kind && + b.fRight->fType.kind() == Type::kScalar_Kind) { + this->vectorize(builder, right, b.fLeft->fType.columns()); + } +} + + +LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression& b) { + #define BINARY(SFunc, UFunc, FFunc) { \ + LLVMValueRef left = this->compileExpression(builder, *b.fLeft); \ + LLVMValueRef right = this->compileExpression(builder, *b.fRight); \ + this->vectorize(builder, b, &left, &right); \ + switch (this->typeKind(b.fLeft->fType)) { \ + case kInt_TypeKind: \ + return SFunc(builder, left, right, "binary"); \ + case kUInt_TypeKind: \ + return UFunc(builder, left, right, "binary"); \ + case kFloat_TypeKind: \ + return FFunc(builder, left, right, "binary"); \ + default: \ + ABORT("unsupported typeKind"); \ + } \ + } + #define COMPOUND(SFunc, UFunc, FFunc) { \ + std::unique_ptr<LValue> lvalue = this->getLValue(builder, *b.fLeft); \ + LLVMValueRef left = lvalue->load(builder); \ + LLVMValueRef right = this->compileExpression(builder, *b.fRight); \ + this->vectorize(builder, b, &left, &right); \ + LLVMValueRef result; \ + switch (this->typeKind(b.fLeft->fType)) { \ + case kInt_TypeKind: \ + result = SFunc(builder, left, right, "binary"); \ + break; \ + case kUInt_TypeKind: \ + result = UFunc(builder, left, right, "binary"); \ + break; \ + case kFloat_TypeKind: \ + result = FFunc(builder, left, right, "binary"); \ + break; \ + default: \ + ABORT("unsupported typeKind"); \ + } \ + lvalue->store(builder, result); \ + return result; \ + } + #define COMPARE(SFunc, SOp, UFunc, UOp, FFunc, FOp) { \ + LLVMValueRef left = this->compileExpression(builder, *b.fLeft); \ + LLVMValueRef right = this->compileExpression(builder, *b.fRight); \ + this->vectorize(builder, b, &left, &right); \ + switch (this->typeKind(b.fLeft->fType)) { \ + case kInt_TypeKind: \ + return SFunc(builder, SOp, left, right, "binary"); \ + case kUInt_TypeKind: \ + return UFunc(builder, UOp, left, right, "binary"); \ + case kFloat_TypeKind: \ + return FFunc(builder, FOp, left, right, "binary"); \ + default: \ + ABORT("unsupported typeKind"); \ + } \ + } + switch (b.fOperator) { + case Token::EQ: { + std::unique_ptr<LValue> lvalue = this->getLValue(builder, *b.fLeft); + LLVMValueRef result = this->compileExpression(builder, *b.fRight); + lvalue->store(builder, result); + return result; + } + case Token::PLUS: + BINARY(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd); + case Token::MINUS: + BINARY(LLVMBuildSub, LLVMBuildSub, LLVMBuildFSub); + case Token::STAR: + BINARY(LLVMBuildMul, LLVMBuildMul, LLVMBuildFMul); + case Token::SLASH: + BINARY(LLVMBuildSDiv, LLVMBuildUDiv, LLVMBuildFDiv); + case Token::PERCENT: + BINARY(LLVMBuildSRem, LLVMBuildURem, LLVMBuildSRem); + case Token::BITWISEAND: + BINARY(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd); + case Token::BITWISEOR: + BINARY(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr); + case Token::PLUSEQ: + COMPOUND(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd); + case Token::MINUSEQ: + COMPOUND(LLVMBuildSub, LLVMBuildSub, LLVMBuildFSub); + case Token::STAREQ: + COMPOUND(LLVMBuildMul, LLVMBuildMul, LLVMBuildFMul); + case Token::SLASHEQ: + COMPOUND(LLVMBuildSDiv, LLVMBuildUDiv, LLVMBuildFDiv); + case Token::BITWISEANDEQ: + COMPOUND(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd); + case Token::BITWISEOREQ: + COMPOUND(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr); + case Token::EQEQ: + COMPARE(LLVMBuildICmp, LLVMIntEQ, + LLVMBuildICmp, LLVMIntEQ, + LLVMBuildFCmp, LLVMRealOEQ); + case Token::NEQ: + COMPARE(LLVMBuildICmp, LLVMIntNE, + LLVMBuildICmp, LLVMIntNE, + LLVMBuildFCmp, LLVMRealONE); + case Token::LT: + COMPARE(LLVMBuildICmp, LLVMIntSLT, + LLVMBuildICmp, LLVMIntULT, + LLVMBuildFCmp, LLVMRealOLT); + case Token::LTEQ: + COMPARE(LLVMBuildICmp, LLVMIntSLE, + LLVMBuildICmp, LLVMIntULE, + LLVMBuildFCmp, LLVMRealOLE); + case Token::GT: + COMPARE(LLVMBuildICmp, LLVMIntSGT, + LLVMBuildICmp, LLVMIntUGT, + LLVMBuildFCmp, LLVMRealOGT); + case Token::GTEQ: + COMPARE(LLVMBuildICmp, LLVMIntSGE, + LLVMBuildICmp, LLVMIntUGE, + LLVMBuildFCmp, LLVMRealOGE); + case Token::LOGICALAND: { + LLVMValueRef left = this->compileExpression(builder, *b.fLeft); + LLVMBasicBlockRef ifFalse = fCurrentBlock; + LLVMBasicBlockRef ifTrue = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, + "true && ..."); + LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, + "&& merge"); + LLVMBuildCondBr(builder, left, ifTrue, merge); + this->setBlock(builder, ifTrue); + LLVMValueRef right = this->compileExpression(builder, *b.fRight); + LLVMBuildBr(builder, merge); + this->setBlock(builder, merge); + LLVMValueRef phi = LLVMBuildPhi(builder, fInt1Type, "&&"); + LLVMValueRef incomingValues[2] = { right, LLVMConstInt(fInt1Type, 0, false) }; + LLVMBasicBlockRef incomingBlocks[2] = { ifTrue, ifFalse }; + LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2); + return phi; + } + case Token::LOGICALOR: { + LLVMValueRef left = this->compileExpression(builder, *b.fLeft); + LLVMBasicBlockRef ifTrue = fCurrentBlock; + LLVMBasicBlockRef ifFalse = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, + "false || ..."); + LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, + "|| merge"); + LLVMBuildCondBr(builder, left, merge, ifFalse); + this->setBlock(builder, ifFalse); + LLVMValueRef right = this->compileExpression(builder, *b.fRight); + LLVMBuildBr(builder, merge); + this->setBlock(builder, merge); + LLVMValueRef phi = LLVMBuildPhi(builder, fInt1Type, "||"); + LLVMValueRef incomingValues[2] = { right, LLVMConstInt(fInt1Type, 1, false) }; + LLVMBasicBlockRef incomingBlocks[2] = { ifFalse, ifTrue }; + LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2); + return phi; + } + default: + ABORT("unsupported binary operator"); + } +} + +LLVMValueRef JIT::compileIndex(LLVMBuilderRef builder, const IndexExpression& idx) { + LLVMValueRef base = this->compileExpression(builder, *idx.fBase); + LLVMValueRef index = this->compileExpression(builder, *idx.fIndex); + LLVMValueRef ptr = LLVMBuildGEP(builder, base, &index, 1, "index ptr"); + return LLVMBuildLoad(builder, ptr, "index load"); +} + +LLVMValueRef JIT::compilePostfix(LLVMBuilderRef builder, const PostfixExpression& p) { + std::unique_ptr<LValue> lvalue = this->getLValue(builder, *p.fOperand); + LLVMValueRef result = lvalue->load(builder); + LLVMValueRef mod; + LLVMValueRef one = LLVMConstInt(this->getType(p.fType), 1, false); + switch (p.fOperator) { + case Token::PLUSPLUS: + switch (this->typeKind(p.fType)) { + case kInt_TypeKind: // fall through + case kUInt_TypeKind: + mod = LLVMBuildAdd(builder, result, one, "++"); + break; + case kFloat_TypeKind: + mod = LLVMBuildFAdd(builder, result, one, "++"); + break; + default: + ABORT("unsupported typeKind"); + } + break; + case Token::MINUSMINUS: + switch (this->typeKind(p.fType)) { + case kInt_TypeKind: // fall through + case kUInt_TypeKind: + mod = LLVMBuildSub(builder, result, one, "--"); + break; + case kFloat_TypeKind: + mod = LLVMBuildFSub(builder, result, one, "--"); + break; + default: + ABORT("unsupported typeKind"); + } + break; + default: + ABORT("unsupported postfix op"); + } + lvalue->store(builder, mod); + return result; +} + +LLVMValueRef JIT::compilePrefix(LLVMBuilderRef builder, const PrefixExpression& p) { + LLVMValueRef one = LLVMConstInt(this->getType(p.fType), 1, false); + if (Token::LOGICALNOT == p.fOperator) { + LLVMValueRef base = this->compileExpression(builder, *p.fOperand); + return LLVMBuildXor(builder, base, one, "!"); + } + if (Token::MINUS == p.fOperator) { + LLVMValueRef base = this->compileExpression(builder, *p.fOperand); + return LLVMBuildSub(builder, LLVMConstInt(this->getType(p.fType), 0, false), base, "-"); + } + std::unique_ptr<LValue> lvalue = this->getLValue(builder, *p.fOperand); + LLVMValueRef raw = lvalue->load(builder); + LLVMValueRef result; + switch (p.fOperator) { + case Token::PLUSPLUS: + switch (this->typeKind(p.fType)) { + case kInt_TypeKind: // fall through + case kUInt_TypeKind: + result = LLVMBuildAdd(builder, raw, one, "++"); + break; + case kFloat_TypeKind: + result = LLVMBuildFAdd(builder, raw, one, "++"); + break; + default: + ABORT("unsupported typeKind"); + } + break; + case Token::MINUSMINUS: + switch (this->typeKind(p.fType)) { + case kInt_TypeKind: // fall through + case kUInt_TypeKind: + result = LLVMBuildSub(builder, raw, one, "--"); + break; + case kFloat_TypeKind: + result = LLVMBuildFSub(builder, raw, one, "--"); + break; + default: + ABORT("unsupported typeKind"); + } + break; + default: + ABORT("unsupported prefix op"); + } + lvalue->store(builder, result); + return result; +} + +LLVMValueRef JIT::compileVariableReference(LLVMBuilderRef builder, const VariableReference& v) { + const Variable& var = v.fVariable; + if (Variable::kParameter_Storage == var.fStorage && + !(var.fModifiers.fFlags & Modifiers::kOut_Flag) && + fPromotedParameters.find(&var) == fPromotedParameters.end()) { + return fVariables[&var]; + } + return LLVMBuildLoad(builder, fVariables[&var], String(var.fName).c_str()); +} + +void JIT::appendStage(LLVMBuilderRef builder, const AppendStage& a) { + ASSERT(a.fArguments.size() >= 1); + ASSERT(a.fArguments[0]->fType == *fCompiler.context().fSkRasterPipeline_Type); + LLVMValueRef pipeline = this->compileExpression(builder, *a.fArguments[0]); + LLVMValueRef stage = LLVMConstInt(fInt32Type, a.fStage, 0); + switch (a.fStage) { + case SkRasterPipeline::callback: { + ASSERT(a.fArguments.size() == 2); + ASSERT(a.fArguments[1]->fKind == Expression::kFunctionReference_Kind); + const FunctionDeclaration& functionDecl = + *((FunctionReference&) *a.fArguments[1]).fFunctions[0]; + bool found = false; + for (const auto& pe : fProgram->fElements) { + if (ProgramElement::kFunction_Kind == pe->fKind) { + const FunctionDefinition& def = (const FunctionDefinition&) *pe; + if (&def.fDeclaration == &functionDecl) { + LLVMValueRef fn = this->compileStageFunction(def); + LLVMValueRef args[2] = { + pipeline, + LLVMBuildBitCast(builder, fn, fInt8PtrType, "callback cast") + }; + LLVMBuildCall(builder, fAppendCallbackFunc, args, 2, ""); + found = true; + break; + } + } + } + ASSERT(found); + break; + } + default: { + LLVMValueRef ctx; + if (a.fArguments.size() == 2) { + ctx = this->compileExpression(builder, *a.fArguments[1]); + ctx = LLVMBuildBitCast(builder, ctx, fInt8PtrType, "context cast"); + } else { + ASSERT(a.fArguments.size() == 1); + ctx = LLVMConstNull(fInt8PtrType); + } + LLVMValueRef args[3] = { + pipeline, + stage, + ctx + }; + LLVMBuildCall(builder, fAppendFunc, args, 3, ""); + break; + } + } +} + +LLVMValueRef JIT::compileConstructor(LLVMBuilderRef builder, const Constructor& c) { + switch (c.fType.kind()) { + case Type::kScalar_Kind: { + ASSERT(c.fArguments.size() == 1); + TypeKind from = this->typeKind(c.fArguments[0]->fType); + TypeKind to = this->typeKind(c.fType); + LLVMValueRef base = this->compileExpression(builder, *c.fArguments[0]); + if (kFloat_TypeKind == to) { + if (kInt_TypeKind == from) { + return LLVMBuildSIToFP(builder, base, this->getType(c.fType), "cast"); + } + if (kUInt_TypeKind == from) { + return LLVMBuildUIToFP(builder, base, this->getType(c.fType), "cast"); + } + } + if (kInt_TypeKind == to) { + if (kFloat_TypeKind == from) { + return LLVMBuildFPToSI(builder, base, this->getType(c.fType), "cast"); + } + if (kUInt_TypeKind == from) { + return base; + } + } + if (kUInt_TypeKind == to) { + if (kFloat_TypeKind == from) { + return LLVMBuildFPToUI(builder, base, this->getType(c.fType), "cast"); + } + if (kInt_TypeKind == from) { + return base; + } + } + ABORT("unsupported constructor"); + } + case Type::kVector_Kind: { + LLVMValueRef vec = LLVMGetUndef(this->getType(c.fType)); + if (c.fArguments.size() == 1) { + LLVMValueRef value = this->compileExpression(builder, *c.fArguments[0]); + for (int i = 0; i < c.fType.columns(); ++i) { + vec = LLVMBuildInsertElement(builder, vec, value, + LLVMConstInt(fInt32Type, i, false), + "vec build"); + } + } else { + ASSERT(c.fArguments.size() == (size_t) c.fType.columns()); + for (int i = 0; i < c.fType.columns(); ++i) { + vec = LLVMBuildInsertElement(builder, vec, + this->compileExpression(builder, + *c.fArguments[i]), + LLVMConstInt(fInt32Type, i, false), + "vec build"); + } + } + return vec; + } + default: + break; + } + ABORT("unsupported constructor"); +} + +LLVMValueRef JIT::compileSwizzle(LLVMBuilderRef builder, const Swizzle& s) { + LLVMValueRef base = this->compileExpression(builder, *s.fBase); + if (s.fComponents.size() > 1) { + LLVMValueRef result = LLVMGetUndef(this->getType(s.fType)); + for (size_t i = 0; i < s.fComponents.size(); ++i) { + LLVMValueRef element = LLVMBuildExtractElement( + builder, + base, + LLVMConstInt(fInt32Type, + s.fComponents[i], + false), + "swizzle extract"); + result = LLVMBuildInsertElement(builder, result, element, + LLVMConstInt(fInt32Type, i, false), + "swizzle insert"); + } + return result; + } + ASSERT(s.fComponents.size() == 1); + return LLVMBuildExtractElement(builder, base, + LLVMConstInt(fInt32Type, + s.fComponents[0], + false), + "swizzle extract"); +} + +LLVMValueRef JIT::compileTernary(LLVMBuilderRef builder, const TernaryExpression& t) { + LLVMValueRef test = this->compileExpression(builder, *t.fTest); + LLVMBasicBlockRef trueBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, + "if true"); + LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, + "if merge"); + LLVMBasicBlockRef falseBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, + "if false"); + LLVMBuildCondBr(builder, test, trueBlock, falseBlock); + this->setBlock(builder, trueBlock); + LLVMValueRef ifTrue = this->compileExpression(builder, *t.fIfTrue); + trueBlock = fCurrentBlock; + LLVMBuildBr(builder, merge); + this->setBlock(builder, falseBlock); + LLVMValueRef ifFalse = this->compileExpression(builder, *t.fIfFalse); + falseBlock = fCurrentBlock; + LLVMBuildBr(builder, merge); + this->setBlock(builder, merge); + LLVMValueRef phi = LLVMBuildPhi(builder, this->getType(t.fType), "?"); + LLVMValueRef incomingValues[2] = { ifTrue, ifFalse }; + LLVMBasicBlockRef incomingBlocks[2] = { trueBlock, falseBlock }; + LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2); + return phi; +} + +LLVMValueRef JIT::compileExpression(LLVMBuilderRef builder, const Expression& expr) { + switch (expr.fKind) { + case Expression::kAppendStage_Kind: { + this->appendStage(builder, (const AppendStage&) expr); + return LLVMValueRef(); + } + case Expression::kBinary_Kind: + return this->compileBinary(builder, (BinaryExpression&) expr); + case Expression::kBoolLiteral_Kind: + return LLVMConstInt(fInt1Type, ((BoolLiteral&) expr).fValue, false); + case Expression::kConstructor_Kind: + return this->compileConstructor(builder, (Constructor&) expr); + case Expression::kIntLiteral_Kind: + return LLVMConstInt(this->getType(expr.fType), ((IntLiteral&) expr).fValue, true); + case Expression::kFieldAccess_Kind: + abort(); + case Expression::kFloatLiteral_Kind: + return LLVMConstReal(this->getType(expr.fType), ((FloatLiteral&) expr).fValue); + case Expression::kFunctionCall_Kind: + return this->compileFunctionCall(builder, (FunctionCall&) expr); + case Expression::kIndex_Kind: + return this->compileIndex(builder, (IndexExpression&) expr); + case Expression::kPrefix_Kind: + return this->compilePrefix(builder, (PrefixExpression&) expr); + case Expression::kPostfix_Kind: + return this->compilePostfix(builder, (PostfixExpression&) expr); + case Expression::kSetting_Kind: + abort(); + case Expression::kSwizzle_Kind: + return this->compileSwizzle(builder, (Swizzle&) expr); + case Expression::kVariableReference_Kind: + return this->compileVariableReference(builder, (VariableReference&) expr); + case Expression::kTernary_Kind: + return this->compileTernary(builder, (TernaryExpression&) expr); + case Expression::kTypeReference_Kind: + abort(); + default: + abort(); + } + ABORT("unsupported expression: %s\n", expr.description().c_str()); +} + +void JIT::compileBlock(LLVMBuilderRef builder, const Block& block) { + for (const auto& stmt : block.fStatements) { + this->compileStatement(builder, *stmt); + } +} + +void JIT::compileVarDeclarations(LLVMBuilderRef builder, const VarDeclarationsStatement& decls) { + for (const auto& declStatement : decls.fDeclaration->fVars) { + const VarDeclaration& decl = (VarDeclaration&) *declStatement; + LLVMPositionBuilderAtEnd(builder, fAllocaBlock); + LLVMValueRef alloca = LLVMBuildAlloca(builder, this->getType(decl.fVar->fType), + String(decl.fVar->fName).c_str()); + fVariables[decl.fVar] = alloca; + LLVMPositionBuilderAtEnd(builder, fCurrentBlock); + if (decl.fValue) { + LLVMValueRef result = this->compileExpression(builder, *decl.fValue); + LLVMBuildStore(builder, result, alloca); + } + } +} + +void JIT::compileIf(LLVMBuilderRef builder, const IfStatement& i) { + LLVMValueRef test = this->compileExpression(builder, *i.fTest); + LLVMBasicBlockRef ifTrue = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "if true"); + LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, + "if merge"); + LLVMBasicBlockRef ifFalse; + if (i.fIfFalse) { + ifFalse = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "if false"); + } else { + ifFalse = merge; + } + LLVMBuildCondBr(builder, test, ifTrue, ifFalse); + this->setBlock(builder, ifTrue); + this->compileStatement(builder, *i.fIfTrue); + if (!ends_with_branch(*i.fIfTrue)) { + LLVMBuildBr(builder, merge); + } + if (i.fIfFalse) { + this->setBlock(builder, ifFalse); + this->compileStatement(builder, *i.fIfFalse); + if (!ends_with_branch(*i.fIfFalse)) { + LLVMBuildBr(builder, merge); + } + } + this->setBlock(builder, merge); +} + +void JIT::compileFor(LLVMBuilderRef builder, const ForStatement& f) { + if (f.fInitializer) { + this->compileStatement(builder, *f.fInitializer); + } + LLVMBasicBlockRef start; + LLVMBasicBlockRef body = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "for body"); + LLVMBasicBlockRef next = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "for next"); + LLVMBasicBlockRef end = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "for end"); + if (f.fTest) { + start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "for test"); + LLVMBuildBr(builder, start); + this->setBlock(builder, start); + LLVMValueRef test = this->compileExpression(builder, *f.fTest); + LLVMBuildCondBr(builder, test, body, end); + } else { + start = body; + LLVMBuildBr(builder, body); + } + this->setBlock(builder, body); + fBreakTarget.push_back(end); + fContinueTarget.push_back(next); + this->compileStatement(builder, *f.fStatement); + fBreakTarget.pop_back(); + fContinueTarget.pop_back(); + if (!ends_with_branch(*f.fStatement)) { + LLVMBuildBr(builder, next); + } + this->setBlock(builder, next); + if (f.fNext) { + this->compileExpression(builder, *f.fNext); + } + LLVMBuildBr(builder, start); + this->setBlock(builder, end); +} + +void JIT::compileDo(LLVMBuilderRef builder, const DoStatement& d) { + LLVMBasicBlockRef testBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, + "do test"); + LLVMBasicBlockRef body = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, + "do body"); + LLVMBasicBlockRef end = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, + "do end"); + LLVMBuildBr(builder, body); + this->setBlock(builder, testBlock); + LLVMValueRef test = this->compileExpression(builder, *d.fTest); + LLVMBuildCondBr(builder, test, body, end); + this->setBlock(builder, body); + fBreakTarget.push_back(end); + fContinueTarget.push_back(body); + this->compileStatement(builder, *d.fStatement); + fBreakTarget.pop_back(); + fContinueTarget.pop_back(); + if (!ends_with_branch(*d.fStatement)) { + LLVMBuildBr(builder, testBlock); + } + this->setBlock(builder, end); +} + +void JIT::compileWhile(LLVMBuilderRef builder, const WhileStatement& w) { + LLVMBasicBlockRef testBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, + "while test"); + LLVMBasicBlockRef body = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, + "while body"); + LLVMBasicBlockRef end = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, + "while end"); + LLVMBuildBr(builder, testBlock); + this->setBlock(builder, testBlock); + LLVMValueRef test = this->compileExpression(builder, *w.fTest); + LLVMBuildCondBr(builder, test, body, end); + this->setBlock(builder, body); + fBreakTarget.push_back(end); + fContinueTarget.push_back(testBlock); + this->compileStatement(builder, *w.fStatement); + fBreakTarget.pop_back(); + fContinueTarget.pop_back(); + if (!ends_with_branch(*w.fStatement)) { + LLVMBuildBr(builder, testBlock); + } + this->setBlock(builder, end); +} + +void JIT::compileBreak(LLVMBuilderRef builder, const BreakStatement& b) { + LLVMBuildBr(builder, fBreakTarget.back()); +} + +void JIT::compileContinue(LLVMBuilderRef builder, const ContinueStatement& b) { + LLVMBuildBr(builder, fContinueTarget.back()); +} + +void JIT::compileReturn(LLVMBuilderRef builder, const ReturnStatement& r) { + if (r.fExpression) { + LLVMBuildRet(builder, this->compileExpression(builder, *r.fExpression)); + } else { + LLVMBuildRetVoid(builder); + } +} + +void JIT::compileStatement(LLVMBuilderRef builder, const Statement& stmt) { + switch (stmt.fKind) { + case Statement::kBlock_Kind: + this->compileBlock(builder, (Block&) stmt); + break; + case Statement::kBreak_Kind: + this->compileBreak(builder, (BreakStatement&) stmt); + break; + case Statement::kContinue_Kind: + this->compileContinue(builder, (ContinueStatement&) stmt); + break; + case Statement::kDiscard_Kind: + abort(); + case Statement::kDo_Kind: + this->compileDo(builder, (DoStatement&) stmt); + break; + case Statement::kExpression_Kind: + this->compileExpression(builder, *((ExpressionStatement&) stmt).fExpression); + break; + case Statement::kFor_Kind: + this->compileFor(builder, (ForStatement&) stmt); + break; + case Statement::kGroup_Kind: + abort(); + case Statement::kIf_Kind: + this->compileIf(builder, (IfStatement&) stmt); + break; + case Statement::kNop_Kind: + break; + case Statement::kReturn_Kind: + this->compileReturn(builder, (ReturnStatement&) stmt); + break; + case Statement::kSwitch_Kind: + abort(); + case Statement::kVarDeclarations_Kind: + this->compileVarDeclarations(builder, (VarDeclarationsStatement&) stmt); + break; + case Statement::kWhile_Kind: + this->compileWhile(builder, (WhileStatement&) stmt); + break; + default: + abort(); + } +} + +void JIT::compileStageFunctionLoop(const FunctionDefinition& f, LLVMValueRef newFunc) { + // loop over fVectorCount pixels, running the body of the stage function for each of them + LLVMValueRef oldFunction = fCurrentFunction; + fCurrentFunction = newFunc; + std::unique_ptr<LLVMValueRef[]> params(new LLVMValueRef[STAGE_PARAM_COUNT]); + LLVMGetParams(fCurrentFunction, params.get()); + LLVMValueRef programParam = params.get()[1]; + LLVMBuilderRef builder = LLVMCreateBuilderInContext(fContext); + LLVMBasicBlockRef oldAllocaBlock = fAllocaBlock; + LLVMBasicBlockRef oldCurrentBlock = fCurrentBlock; + fAllocaBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "alloca"); + this->setBlock(builder, fAllocaBlock); + // temporaries to store the color channel vectors + LLVMValueRef rVec = LLVMBuildAlloca(builder, fFloat32VectorType, "rVec"); + LLVMBuildStore(builder, params.get()[4], rVec); + LLVMValueRef gVec = LLVMBuildAlloca(builder, fFloat32VectorType, "gVec"); + LLVMBuildStore(builder, params.get()[5], gVec); + LLVMValueRef bVec = LLVMBuildAlloca(builder, fFloat32VectorType, "bVec"); + LLVMBuildStore(builder, params.get()[6], bVec); + LLVMValueRef aVec = LLVMBuildAlloca(builder, fFloat32VectorType, "aVec"); + LLVMBuildStore(builder, params.get()[7], aVec); + LLVMValueRef color = LLVMBuildAlloca(builder, fFloat32Vector4Type, "color"); + fVariables[f.fDeclaration.fParameters[1]] = LLVMBuildTrunc(builder, params.get()[3], fInt32Type, + "y->Int32"); + fVariables[f.fDeclaration.fParameters[2]] = color; + LLVMValueRef ivar = LLVMBuildAlloca(builder, fInt32Type, "i"); + LLVMBuildStore(builder, LLVMConstInt(fInt32Type, 0, false), ivar); + LLVMBasicBlockRef start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "start"); + this->setBlock(builder, start); + LLVMValueRef iload = LLVMBuildLoad(builder, ivar, "load i"); + fVariables[f.fDeclaration.fParameters[0]] = LLVMBuildAdd(builder, + LLVMBuildTrunc(builder, + params.get()[2], + fInt32Type, + "x->Int32"), + iload, + "x"); + LLVMValueRef vectorSize = LLVMConstInt(fInt32Type, fVectorCount, false); + LLVMValueRef test = LLVMBuildICmp(builder, LLVMIntSLT, iload, vectorSize, "i < vectorSize"); + LLVMBasicBlockRef loopBody = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "body"); + LLVMBasicBlockRef loopEnd = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "end"); + LLVMBuildCondBr(builder, test, loopBody, loopEnd); + this->setBlock(builder, loopBody); + LLVMValueRef vec = LLVMGetUndef(fFloat32Vector4Type); + // extract the r, g, b, and a values from the color channel vectors and store them into "color" + for (int i = 0; i < 4; ++i) { + vec = LLVMBuildInsertElement(builder, vec, + LLVMBuildExtractElement(builder, + params.get()[4 + i], + iload, "initial"), + LLVMConstInt(fInt32Type, i, false), + "vec build"); + } + LLVMBuildStore(builder, vec, color); + // write actual loop body + this->compileStatement(builder, *f.fBody); + // extract the r, g, b, and a values from "color" and stick them back into the color channel + // vectors + LLVMValueRef colorLoad = LLVMBuildLoad(builder, color, "color load"); + LLVMBuildStore(builder, + LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, rVec, "rVec"), + LLVMBuildExtractElement(builder, colorLoad, + LLVMConstInt(fInt32Type, 0, + false), + "rExtract"), + iload, "rInsert"), + rVec); + LLVMBuildStore(builder, + LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, gVec, "gVec"), + LLVMBuildExtractElement(builder, colorLoad, + LLVMConstInt(fInt32Type, 1, + false), + "gExtract"), + iload, "gInsert"), + gVec); + LLVMBuildStore(builder, + LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, bVec, "bVec"), + LLVMBuildExtractElement(builder, colorLoad, + LLVMConstInt(fInt32Type, 2, + false), + "bExtract"), + iload, "bInsert"), + bVec); + LLVMBuildStore(builder, + LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, aVec, "aVec"), + LLVMBuildExtractElement(builder, colorLoad, + LLVMConstInt(fInt32Type, 3, + false), + "aExtract"), + iload, "aInsert"), + aVec); + LLVMValueRef inc = LLVMBuildAdd(builder, iload, LLVMConstInt(fInt32Type, 1, false), "inc i"); + LLVMBuildStore(builder, inc, ivar); + LLVMBuildBr(builder, start); + this->setBlock(builder, loopEnd); + // increment program pointer, call the next stage + LLVMValueRef rawNextPtr = LLVMBuildLoad(builder, programParam, "next load"); + LLVMTypeRef stageFuncType = LLVMTypeOf(newFunc); + LLVMValueRef nextPtr = LLVMBuildBitCast(builder, rawNextPtr, stageFuncType, "cast next->func"); + LLVMValueRef nextInc = LLVMBuildIntToPtr(builder, + LLVMBuildAdd(builder, + LLVMBuildPtrToInt(builder, + programParam, + fInt64Type, + "cast 1"), + LLVMConstInt(fInt64Type, PTR_SIZE, false), + "add"), + LLVMPointerType(fInt8PtrType, 0), "cast 2"); + LLVMValueRef args[STAGE_PARAM_COUNT] = { + params.get()[0], + nextInc, + params.get()[2], + params.get()[3], + LLVMBuildLoad(builder, rVec, "rVec"), + LLVMBuildLoad(builder, gVec, "gVec"), + LLVMBuildLoad(builder, bVec, "bVec"), + LLVMBuildLoad(builder, aVec, "aVec"), + params.get()[8], + params.get()[9], + params.get()[10], + params.get()[11] + }; + LLVMBuildCall(builder, nextPtr, args, STAGE_PARAM_COUNT, ""); + LLVMBuildRetVoid(builder); + // finish + LLVMPositionBuilderAtEnd(builder, fAllocaBlock); + LLVMBuildBr(builder, start); + LLVMDisposeBuilder(builder); + if (LLVMVerifyFunction(fCurrentFunction, LLVMPrintMessageAction)) { + ABORT("verify failed\n"); + } + fAllocaBlock = oldAllocaBlock; + fCurrentBlock = oldCurrentBlock; + fCurrentFunction = oldFunction; +} + +// FIXME maybe pluggable code generators? Need to do something to separate all +// of the normal codegen from the vector codegen and break this up into multiple +// classes. + +bool JIT::getVectorLValue(LLVMBuilderRef builder, const Expression& e, + LLVMValueRef out[CHANNELS]) { + switch (e.fKind) { + case Expression::kVariableReference_Kind: + if (fColorParam == &((VariableReference&) e).fVariable) { + memcpy(out, fChannels, sizeof(fChannels)); + return true; + } + return false; + case Expression::kSwizzle_Kind: { + const Swizzle& s = (const Swizzle&) e; + LLVMValueRef base[CHANNELS]; + if (!this->getVectorLValue(builder, *s.fBase, base)) { + return false; + } + for (size_t i = 0; i < s.fComponents.size(); ++i) { + out[i] = base[s.fComponents[i]]; + } + return true; + } + default: + return false; + } +} + +bool JIT::getVectorBinaryOperands(LLVMBuilderRef builder, const Expression& left, + LLVMValueRef outLeft[CHANNELS], const Expression& right, + LLVMValueRef outRight[CHANNELS]) { + if (!this->compileVectorExpression(builder, left, outLeft)) { + return false; + } + int leftColumns = left.fType.columns(); + int rightColumns = right.fType.columns(); + if (leftColumns == 1 && rightColumns > 1) { + for (int i = 1; i < rightColumns; ++i) { + outLeft[i] = outLeft[0]; + } + } + if (!this->compileVectorExpression(builder, right, outRight)) { + return false; + } + if (rightColumns == 1 && leftColumns > 1) { + for (int i = 1; i < leftColumns; ++i) { + outRight[i] = outRight[0]; + } + } + return true; +} + +bool JIT::compileVectorBinary(LLVMBuilderRef builder, const BinaryExpression& b, + LLVMValueRef out[CHANNELS]) { + LLVMValueRef left[CHANNELS]; + LLVMValueRef right[CHANNELS]; + #define VECTOR_BINARY(signedOp, unsignedOp, floatOp) { \ + if (!this->getVectorBinaryOperands(builder, *b.fLeft, left, *b.fRight, right)) { \ + return false; \ + } \ + for (int i = 0; i < b.fLeft->fType.columns(); ++i) { \ + switch (this->typeKind(b.fLeft->fType)) { \ + case kInt_TypeKind: \ + out[i] = signedOp(builder, left[i], right[i], "binary"); \ + break; \ + case kUInt_TypeKind: \ + out[i] = unsignedOp(builder, left[i], right[i], "binary"); \ + break; \ + case kFloat_TypeKind: \ + out[i] = floatOp(builder, left[i], right[i], "binary"); \ + break; \ + case kBool_TypeKind: \ + ASSERT(false); \ + break; \ + } \ + } \ + return true; \ + } + switch (b.fOperator) { + case Token::EQ: { + if (!this->getVectorLValue(builder, *b.fLeft, left)) { + return false; + } + if (!this->compileVectorExpression(builder, *b.fRight, right)) { + return false; + } + int columns = b.fRight->fType.columns(); + for (int i = 0; i < columns; ++i) { + LLVMBuildStore(builder, right[i], left[i]); + } + return true; + } + case Token::PLUS: + VECTOR_BINARY(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd); + case Token::MINUS: + VECTOR_BINARY(LLVMBuildSub, LLVMBuildSub, LLVMBuildFSub); + case Token::STAR: + VECTOR_BINARY(LLVMBuildMul, LLVMBuildMul, LLVMBuildFMul); + case Token::SLASH: + VECTOR_BINARY(LLVMBuildSDiv, LLVMBuildUDiv, LLVMBuildFDiv); + case Token::PERCENT: + VECTOR_BINARY(LLVMBuildSRem, LLVMBuildURem, LLVMBuildSRem); + case Token::BITWISEAND: + VECTOR_BINARY(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd); + case Token::BITWISEOR: + VECTOR_BINARY(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr); + default: + printf("unsupported operator: %s\n", b.description().c_str()); + return false; + } +} + +bool JIT::compileVectorConstructor(LLVMBuilderRef builder, const Constructor& c, + LLVMValueRef out[CHANNELS]) { + switch (c.fType.kind()) { + case Type::kScalar_Kind: { + ASSERT(c.fArguments.size() == 1); + TypeKind from = this->typeKind(c.fArguments[0]->fType); + TypeKind to = this->typeKind(c.fType); + LLVMValueRef base[CHANNELS]; + if (!this->compileVectorExpression(builder, *c.fArguments[0], base)) { + return false; + } + #define CONSTRUCT(fn) \ + out[0] = LLVMGetUndef(LLVMVectorType(this->getType(c.fType), fVectorCount)); \ + for (int i = 0; i < fVectorCount; ++i) { \ + LLVMValueRef index = LLVMConstInt(fInt32Type, i, false); \ + LLVMValueRef baseVal = LLVMBuildExtractElement(builder, base[0], index, \ + "construct extract"); \ + out[0] = LLVMBuildInsertElement(builder, out[0], \ + fn(builder, baseVal, this->getType(c.fType), \ + "cast"), \ + index, "construct insert"); \ + } \ + return true; + if (kFloat_TypeKind == to) { + if (kInt_TypeKind == from) { + CONSTRUCT(LLVMBuildSIToFP); + } + if (kUInt_TypeKind == from) { + CONSTRUCT(LLVMBuildUIToFP); + } + } + if (kInt_TypeKind == to) { + if (kFloat_TypeKind == from) { + CONSTRUCT(LLVMBuildFPToSI); + } + if (kUInt_TypeKind == from) { + return true; + } + } + if (kUInt_TypeKind == to) { + if (kFloat_TypeKind == from) { + CONSTRUCT(LLVMBuildFPToUI); + } + if (kInt_TypeKind == from) { + return base; + } + } + printf("%s\n", c.description().c_str()); + ABORT("unsupported constructor"); + } + case Type::kVector_Kind: { + if (c.fArguments.size() == 1) { + LLVMValueRef base[CHANNELS]; + if (!this->compileVectorExpression(builder, *c.fArguments[0], base)) { + return false; + } + for (int i = 0; i < c.fType.columns(); ++i) { + out[i] = base[0]; + } + } else { + ASSERT(c.fArguments.size() == (size_t) c.fType.columns()); + for (int i = 0; i < c.fType.columns(); ++i) { + LLVMValueRef base[CHANNELS]; + if (!this->compileVectorExpression(builder, *c.fArguments[i], base)) { + return false; + } + out[i] = base[0]; + } + } + return true; + } + default: + break; + } + ABORT("unsupported constructor"); +} + +bool JIT::compileVectorFloatLiteral(LLVMBuilderRef builder, + const FloatLiteral& f, + LLVMValueRef out[CHANNELS]) { + LLVMValueRef value = LLVMConstReal(this->getType(f.fType), f.fValue); + LLVMValueRef values[MAX_VECTOR_COUNT]; + for (int i = 0; i < fVectorCount; ++i) { + values[i] = value; + } + out[0] = LLVMConstVector(values, fVectorCount); + return true; +} + + +bool JIT::compileVectorSwizzle(LLVMBuilderRef builder, const Swizzle& s, + LLVMValueRef out[CHANNELS]) { + LLVMValueRef all[CHANNELS]; + if (!this->compileVectorExpression(builder, *s.fBase, all)) { + return false; + } + for (size_t i = 0; i < s.fComponents.size(); ++i) { + out[i] = all[s.fComponents[i]]; + } + return true; +} + +bool JIT::compileVectorVariableReference(LLVMBuilderRef builder, const VariableReference& v, + LLVMValueRef out[CHANNELS]) { + if (&v.fVariable == fColorParam) { + for (int i = 0; i < CHANNELS; ++i) { + out[i] = LLVMBuildLoad(builder, fChannels[i], "variable reference"); + } + return true; + } + return false; +} + +bool JIT::compileVectorExpression(LLVMBuilderRef builder, const Expression& expr, + LLVMValueRef out[CHANNELS]) { + switch (expr.fKind) { + case Expression::kBinary_Kind: + return this->compileVectorBinary(builder, (const BinaryExpression&) expr, out); + case Expression::kConstructor_Kind: + return this->compileVectorConstructor(builder, (const Constructor&) expr, out); + case Expression::kFloatLiteral_Kind: + return this->compileVectorFloatLiteral(builder, (const FloatLiteral&) expr, out); + case Expression::kSwizzle_Kind: + return this->compileVectorSwizzle(builder, (const Swizzle&) expr, out); + case Expression::kVariableReference_Kind: + return this->compileVectorVariableReference(builder, (const VariableReference&) expr, + out); + default: + printf("failed expression: %s\n", expr.description().c_str()); + return false; + } +} + +bool JIT::compileVectorStatement(LLVMBuilderRef builder, const Statement& stmt) { + switch (stmt.fKind) { + case Statement::kBlock_Kind: + for (const auto& s : ((const Block&) stmt).fStatements) { + if (!this->compileVectorStatement(builder, *s)) { + return false; + } + } + return true; + case Statement::kExpression_Kind: + LLVMValueRef result; + return this->compileVectorExpression(builder, + *((const ExpressionStatement&) stmt).fExpression, + &result); + default: + printf("failed statement: %s\n", stmt.description().c_str()); + return false; + } +} + +bool JIT::compileStageFunctionVector(const FunctionDefinition& f, LLVMValueRef newFunc) { + LLVMValueRef oldFunction = fCurrentFunction; + fCurrentFunction = newFunc; + std::unique_ptr<LLVMValueRef[]> params(new LLVMValueRef[STAGE_PARAM_COUNT]); + LLVMGetParams(fCurrentFunction, params.get()); + LLVMValueRef programParam = params.get()[1]; + LLVMBuilderRef builder = LLVMCreateBuilderInContext(fContext); + LLVMBasicBlockRef oldAllocaBlock = fAllocaBlock; + LLVMBasicBlockRef oldCurrentBlock = fCurrentBlock; + fAllocaBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "alloca"); + this->setBlock(builder, fAllocaBlock); + fChannels[0] = LLVMBuildAlloca(builder, fFloat32VectorType, "rVec"); + LLVMBuildStore(builder, params.get()[4], fChannels[0]); + fChannels[1] = LLVMBuildAlloca(builder, fFloat32VectorType, "gVec"); + LLVMBuildStore(builder, params.get()[5], fChannels[1]); + fChannels[2] = LLVMBuildAlloca(builder, fFloat32VectorType, "bVec"); + LLVMBuildStore(builder, params.get()[6], fChannels[2]); + fChannels[3] = LLVMBuildAlloca(builder, fFloat32VectorType, "aVec"); + LLVMBuildStore(builder, params.get()[7], fChannels[3]); + LLVMBasicBlockRef start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "start"); + this->setBlock(builder, start); + bool success = this->compileVectorStatement(builder, *f.fBody); + if (success) { + // increment program pointer, call next + LLVMValueRef rawNextPtr = LLVMBuildLoad(builder, programParam, "next load"); + LLVMTypeRef stageFuncType = LLVMTypeOf(newFunc); + LLVMValueRef nextPtr = LLVMBuildBitCast(builder, rawNextPtr, stageFuncType, + "cast next->func"); + LLVMValueRef nextInc = LLVMBuildIntToPtr(builder, + LLVMBuildAdd(builder, + LLVMBuildPtrToInt(builder, + programParam, + fInt64Type, + "cast 1"), + LLVMConstInt(fInt64Type, PTR_SIZE, + false), + "add"), + LLVMPointerType(fInt8PtrType, 0), "cast 2"); + LLVMValueRef args[STAGE_PARAM_COUNT] = { + params.get()[0], + nextInc, + params.get()[2], + params.get()[3], + LLVMBuildLoad(builder, fChannels[0], "rVec"), + LLVMBuildLoad(builder, fChannels[1], "gVec"), + LLVMBuildLoad(builder, fChannels[2], "bVec"), + LLVMBuildLoad(builder, fChannels[3], "aVec"), + params.get()[8], + params.get()[9], + params.get()[10], + params.get()[11] + }; + LLVMBuildCall(builder, nextPtr, args, STAGE_PARAM_COUNT, ""); + LLVMBuildRetVoid(builder); + // finish + LLVMPositionBuilderAtEnd(builder, fAllocaBlock); + LLVMBuildBr(builder, start); + LLVMDisposeBuilder(builder); + if (LLVMVerifyFunction(fCurrentFunction, LLVMPrintMessageAction)) { + ABORT("verify failed\n"); + } + } else { + LLVMDeleteBasicBlock(fAllocaBlock); + LLVMDeleteBasicBlock(start); + } + + fAllocaBlock = oldAllocaBlock; + fCurrentBlock = oldCurrentBlock; + fCurrentFunction = oldFunction; + return success; +} + +LLVMValueRef JIT::compileStageFunction(const FunctionDefinition& f) { + LLVMTypeRef returnType = fVoidType; + LLVMTypeRef parameterTypes[12] = { fSizeTType, LLVMPointerType(fInt8PtrType, 0), fSizeTType, + fSizeTType, fFloat32VectorType, fFloat32VectorType, + fFloat32VectorType, fFloat32VectorType, fFloat32VectorType, + fFloat32VectorType, fFloat32VectorType, fFloat32VectorType }; + LLVMTypeRef stageFuncType = LLVMFunctionType(returnType, parameterTypes, 12, false); + LLVMValueRef result = LLVMAddFunction(fModule, + (String(f.fDeclaration.fName) + "$stage").c_str(), + stageFuncType); + fColorParam = f.fDeclaration.fParameters[2]; + if (!this->compileStageFunctionVector(f, result)) { + // vectorization failed, fall back to looping over the pixels + this->compileStageFunctionLoop(f, result); + } + return result; +} + +bool JIT::hasStageSignature(const FunctionDeclaration& f) { + return f.fReturnType == *fProgram->fContext->fVoid_Type && + f.fParameters.size() == 3 && + f.fParameters[0]->fType == *fProgram->fContext->fInt_Type && + f.fParameters[0]->fModifiers.fFlags == 0 && + f.fParameters[1]->fType == *fProgram->fContext->fInt_Type && + f.fParameters[1]->fModifiers.fFlags == 0 && + f.fParameters[2]->fType == *fProgram->fContext->fFloat4_Type && + f.fParameters[2]->fModifiers.fFlags == (Modifiers::kIn_Flag | Modifiers::kOut_Flag); +} + +LLVMValueRef JIT::compileFunction(const FunctionDefinition& f) { + if (this->hasStageSignature(f.fDeclaration)) { + this->compileStageFunction(f); + // we compile foo$stage *in addition* to compiling foo, as we can't be sure that the intent + // was to produce an SkJumper stage just because the signature matched or that the function + // is not otherwise called. May need a better way to handle this. + } + LLVMTypeRef returnType = this->getType(f.fDeclaration.fReturnType); + std::vector<LLVMTypeRef> parameterTypes; + for (const auto& p : f.fDeclaration.fParameters) { + LLVMTypeRef type = this->getType(p->fType); + if (p->fModifiers.fFlags & Modifiers::kOut_Flag) { + type = LLVMPointerType(type, 0); + } + parameterTypes.push_back(type); + } + fCurrentFunction = LLVMAddFunction(fModule, + String(f.fDeclaration.fName).c_str(), + LLVMFunctionType(returnType, parameterTypes.data(), + parameterTypes.size(), false)); + fFunctions[&f.fDeclaration] = fCurrentFunction; + + std::unique_ptr<LLVMValueRef[]> params(new LLVMValueRef[parameterTypes.size()]); + LLVMGetParams(fCurrentFunction, params.get()); + for (size_t i = 0; i < f.fDeclaration.fParameters.size(); ++i) { + fVariables[f.fDeclaration.fParameters[i]] = params.get()[i]; + } + LLVMBuilderRef builder = LLVMCreateBuilderInContext(fContext); + fAllocaBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "alloca"); + LLVMBasicBlockRef start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "start"); + fCurrentBlock = start; + LLVMPositionBuilderAtEnd(builder, fCurrentBlock); + this->compileStatement(builder, *f.fBody); + if (!ends_with_branch(*f.fBody)) { + if (f.fDeclaration.fReturnType == *fProgram->fContext->fVoid_Type) { + LLVMBuildRetVoid(builder); + } else { + LLVMBuildUnreachable(builder); + } + } + LLVMPositionBuilderAtEnd(builder, fAllocaBlock); + LLVMBuildBr(builder, start); + LLVMDisposeBuilder(builder); + if (LLVMVerifyFunction(fCurrentFunction, LLVMPrintMessageAction)) { + ABORT("verify failed\n"); + } + return fCurrentFunction; +} + +void JIT::createModule() { + fPromotedParameters.clear(); + fModule = LLVMModuleCreateWithNameInContext("skslmodule", fContext); + this->loadBuiltinFunctions(); + // LLVM doesn't do void*, have to declare it as int8* + LLVMTypeRef appendParams[3] = { fInt8PtrType, fInt32Type, fInt8PtrType }; + fAppendFunc = LLVMAddFunction(fModule, "sksl_pipeline_append", LLVMFunctionType(fVoidType, + appendParams, + 3, + false)); + LLVMTypeRef appendCallbackParams[2] = { fInt8PtrType, fInt8PtrType }; + fAppendCallbackFunc = LLVMAddFunction(fModule, "sksl_pipeline_append_callback", + LLVMFunctionType(fVoidType, appendCallbackParams, 2, + false)); + + LLVMTypeRef debugParams[3] = { fFloat32Type }; + fDebugFunc = LLVMAddFunction(fModule, "sksl_debug_print", LLVMFunctionType(fVoidType, + debugParams, + 1, + false)); + + for (const auto& e : fProgram->fElements) { + ASSERT(e->fKind == ProgramElement::kFunction_Kind); + this->compileFunction((FunctionDefinition&) *e); + } +} + +std::unique_ptr<JIT::Module> JIT::compile(std::unique_ptr<Program> program) { + fProgram = std::move(program); + this->createModule(); + this->optimize(); + return std::unique_ptr<Module>(new Module(std::move(fProgram), fSharedModule, fJITStack)); +} + +void JIT::optimize() { + LLVMPassManagerBuilderRef pmb = LLVMPassManagerBuilderCreate(); + LLVMPassManagerBuilderSetOptLevel(pmb, 3); + LLVMPassManagerRef functionPM = LLVMCreateFunctionPassManagerForModule(fModule); + LLVMPassManagerBuilderPopulateFunctionPassManager(pmb, functionPM); + LLVMPassManagerRef modulePM = LLVMCreatePassManager(); + LLVMPassManagerBuilderPopulateModulePassManager(pmb, modulePM); + LLVMInitializeFunctionPassManager(functionPM); + + LLVMValueRef func = LLVMGetFirstFunction(fModule); + for (;;) { + if (!func) { + break; + } + LLVMRunFunctionPassManager(functionPM, func); + func = LLVMGetNextFunction(func); + } + LLVMRunPassManager(modulePM, fModule); + LLVMDisposePassManager(functionPM); + LLVMDisposePassManager(modulePM); + LLVMPassManagerBuilderDispose(pmb); + + std::string error_string; + if (LLVMLoadLibraryPermanently(nullptr)) { + ABORT("LLVMLoadLibraryPermanently failed"); + } + char* defaultTriple = LLVMGetDefaultTargetTriple(); + char* error; + LLVMTargetRef target; + if (LLVMGetTargetFromTriple(defaultTriple, &target, &error)) { + ABORT("LLVMGetTargetFromTriple failed"); + } + + if (!LLVMTargetHasJIT(target)) { + ABORT("!LLVMTargetHasJIT"); + } + LLVMTargetMachineRef targetMachine = LLVMCreateTargetMachine(target, + defaultTriple, + fCPU, + nullptr, + LLVMCodeGenLevelDefault, + LLVMRelocDefault, + LLVMCodeModelJITDefault); + LLVMDisposeMessage(defaultTriple); + LLVMTargetDataRef dataLayout = LLVMCreateTargetDataLayout(targetMachine); + LLVMSetModuleDataLayout(fModule, dataLayout); + LLVMDisposeTargetData(dataLayout); + + fJITStack = LLVMOrcCreateInstance(targetMachine); + fSharedModule = LLVMOrcMakeSharedModule(fModule); + LLVMOrcModuleHandle orcModule; + LLVMOrcAddEagerlyCompiledIR(fJITStack, &orcModule, fSharedModule, + (LLVMOrcSymbolResolverFn) resolveSymbol, this); + LLVMDisposeTargetMachine(targetMachine); +} + +void* JIT::Module::getSymbol(const char* name) { + LLVMOrcTargetAddress result; + if (LLVMOrcGetSymbolAddress(fJITStack, &result, name)) { + ABORT("GetSymbolAddress error"); + } + if (!result) { + ABORT("symbol not found"); + } + return (void*) result; +} + +void* JIT::Module::getJumperStage(const char* name) { + return this->getSymbol((String(name) + "$stage").c_str()); +} + +} // namespace + +#endif // SK_LLVM_AVAILABLE + +#endif // SKSL_STANDALONE |