diff options
author | Ethan Nicholas <ethannicholas@google.com> | 2018-03-27 14:10:52 -0400 |
---|---|---|
committer | Skia Commit-Bot <skia-commit-bot@chromium.org> | 2018-03-27 18:39:13 +0000 |
commit | 26a9aad63b60c9cbbdfa87c212a4e76ce55e7373 (patch) | |
tree | d4a42afd75bdff6c7815be7bf78b42cab742df19 /src/sksl | |
parent | 3560b58de36988e1fba54d9ac341735ab849e913 (diff) |
initial SkSLJIT checkin
Docs-Preview: https://skia.org/?cl=112204
Bug: skia:
Change-Id: I10042a0200db00bd8ff8078467c409b1cf191f50
Reviewed-on: https://skia-review.googlesource.com/112204
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
Reviewed-by: Mike Klein <mtklein@chromium.org>
Diffstat (limited to 'src/sksl')
-rw-r--r-- | src/sksl/SkSLCFGGenerator.cpp | 1 | ||||
-rw-r--r-- | src/sksl/SkSLCompiler.cpp | 77 | ||||
-rw-r--r-- | src/sksl/SkSLCompiler.h | 6 | ||||
-rw-r--r-- | src/sksl/SkSLContext.h | 2 | ||||
-rw-r--r-- | src/sksl/SkSLIRGenerator.cpp | 90 | ||||
-rw-r--r-- | src/sksl/SkSLIRGenerator.h | 2 | ||||
-rw-r--r-- | src/sksl/SkSLInterpreter.cpp | 473 | ||||
-rw-r--r-- | src/sksl/SkSLInterpreter.h | 89 | ||||
-rw-r--r-- | src/sksl/SkSLJIT.cpp | 1747 | ||||
-rw-r--r-- | src/sksl/SkSLJIT.h | 344 | ||||
-rw-r--r-- | src/sksl/ir/SkSLAppendStage.h | 53 | ||||
-rw-r--r-- | src/sksl/ir/SkSLExpression.h | 1 | ||||
-rw-r--r-- | src/sksl/ir/SkSLFunctionReference.h | 1 | ||||
-rw-r--r-- | src/sksl/ir/SkSLProgram.h | 7 | ||||
-rw-r--r-- | src/sksl/sksl_cpu.inc | 12 |
15 files changed, 2868 insertions, 37 deletions
diff --git a/src/sksl/SkSLCFGGenerator.cpp b/src/sksl/SkSLCFGGenerator.cpp index c56d0c1b0f..852f48d71d 100644 --- a/src/sksl/SkSLCFGGenerator.cpp +++ b/src/sksl/SkSLCFGGenerator.cpp @@ -387,6 +387,7 @@ void CFGGenerator::addExpression(CFG& cfg, std::unique_ptr<Expression>* e, bool cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, constantPropagate, e, nullptr }); break; + case Expression::kAppendStage_Kind: // fall through case Expression::kBoolLiteral_Kind: // fall through case Expression::kFloatLiteral_Kind: // fall through case Expression::kIntLiteral_Kind: // fall through diff --git a/src/sksl/SkSLCompiler.cpp b/src/sksl/SkSLCompiler.cpp index bbaaf407d2..4650e5c643 100644 --- a/src/sksl/SkSLCompiler.cpp +++ b/src/sksl/SkSLCompiler.cpp @@ -54,17 +54,22 @@ static const char* SKSL_FP_INCLUDE = #include "sksl_fp.inc" ; +static const char* SKSL_CPU_INCLUDE = +#include "sksl_cpu.inc" +; + namespace SkSL { Compiler::Compiler(Flags flags) : fFlags(flags) +, fContext(new Context()) , fErrorCount(0) { auto types = std::shared_ptr<SymbolTable>(new SymbolTable(this)); auto symbols = std::shared_ptr<SymbolTable>(new SymbolTable(types, this)); - fIRGenerator = new IRGenerator(&fContext, symbols, *this); + fIRGenerator = new IRGenerator(fContext.get(), symbols, *this); fTypes = types; - #define ADD_TYPE(t) types->addWithoutOwnership(fContext.f ## t ## _Type->fName, \ - fContext.f ## t ## _Type.get()) + #define ADD_TYPE(t) types->addWithoutOwnership(fContext->f ## t ## _Type->fName, \ + fContext->f ## t ## _Type.get()) ADD_TYPE(Void); ADD_TYPE(Float); ADD_TYPE(Float2); @@ -188,15 +193,16 @@ Compiler::Compiler(Flags flags) ADD_TYPE(GSampler2DArrayShadow); ADD_TYPE(GSamplerCubeArrayShadow); ADD_TYPE(FragmentProcessor); + ADD_TYPE(SkRasterPipeline); StringFragment skCapsName("sk_Caps"); Variable* skCaps = new Variable(-1, Modifiers(), skCapsName, - *fContext.fSkCaps_Type, Variable::kGlobal_Storage); + *fContext->fSkCaps_Type, Variable::kGlobal_Storage); fIRGenerator->fSymbolTable->add(skCapsName, std::unique_ptr<Symbol>(skCaps)); StringFragment skArgsName("sk_Args"); Variable* skArgs = new Variable(-1, Modifiers(), skArgsName, - *fContext.fSkArgs_Type, Variable::kGlobal_Storage); + *fContext->fSkArgs_Type, Variable::kGlobal_Storage); fIRGenerator->fSymbolTable->add(skArgsName, std::unique_ptr<Symbol>(skArgs)); std::vector<std::unique_ptr<ProgramElement>> ignored; @@ -232,19 +238,19 @@ void Compiler::addDefinition(const Expression* lvalue, std::unique_ptr<Expressio // but since we pass foo as a whole it is flagged as an error) unless we perform a much // more complicated whole-program analysis. This is probably good enough. this->addDefinition(((Swizzle*) lvalue)->fBase.get(), - (std::unique_ptr<Expression>*) &fContext.fDefined_Expression, + (std::unique_ptr<Expression>*) &fContext->fDefined_Expression, definitions); break; case Expression::kIndex_Kind: // see comments in Swizzle this->addDefinition(((IndexExpression*) lvalue)->fBase.get(), - (std::unique_ptr<Expression>*) &fContext.fDefined_Expression, + (std::unique_ptr<Expression>*) &fContext->fDefined_Expression, definitions); break; case Expression::kFieldAccess_Kind: // see comments in Swizzle this->addDefinition(((FieldAccess*) lvalue)->fBase.get(), - (std::unique_ptr<Expression>*) &fContext.fDefined_Expression, + (std::unique_ptr<Expression>*) &fContext->fDefined_Expression, definitions); break; case Expression::kTernary_Kind: @@ -252,10 +258,10 @@ void Compiler::addDefinition(const Expression* lvalue, std::unique_ptr<Expressio // This allows for false positives (meaning we fail to detect that a variable might not // have been assigned), but is preferable to false negatives. this->addDefinition(((TernaryExpression*) lvalue)->fIfTrue.get(), - (std::unique_ptr<Expression>*) &fContext.fDefined_Expression, + (std::unique_ptr<Expression>*) &fContext->fDefined_Expression, definitions); this->addDefinition(((TernaryExpression*) lvalue)->fIfFalse.get(), - (std::unique_ptr<Expression>*) &fContext.fDefined_Expression, + (std::unique_ptr<Expression>*) &fContext->fDefined_Expression, definitions); break; default: @@ -278,9 +284,9 @@ void Compiler::addDefinitions(const BasicBlock::Node& node, this->addDefinition(b->fLeft.get(), &b->fRight, definitions); } else if (Compiler::IsAssignment(b->fOperator)) { this->addDefinition( - b->fLeft.get(), - (std::unique_ptr<Expression>*) &fContext.fDefined_Expression, - definitions); + b->fLeft.get(), + (std::unique_ptr<Expression>*) &fContext->fDefined_Expression, + definitions); } break; @@ -289,9 +295,9 @@ void Compiler::addDefinitions(const BasicBlock::Node& node, const PrefixExpression* p = (PrefixExpression*) expr; if (p->fOperator == Token::MINUSMINUS || p->fOperator == Token::PLUSPLUS) { this->addDefinition( - p->fOperand.get(), - (std::unique_ptr<Expression>*) &fContext.fDefined_Expression, - definitions); + p->fOperand.get(), + (std::unique_ptr<Expression>*) &fContext->fDefined_Expression, + definitions); } break; } @@ -299,9 +305,9 @@ void Compiler::addDefinitions(const BasicBlock::Node& node, const PostfixExpression* p = (PostfixExpression*) expr; if (p->fOperator == Token::MINUSMINUS || p->fOperator == Token::PLUSPLUS) { this->addDefinition( - p->fOperand.get(), - (std::unique_ptr<Expression>*) &fContext.fDefined_Expression, - definitions); + p->fOperand.get(), + (std::unique_ptr<Expression>*) &fContext->fDefined_Expression, + definitions); } break; } @@ -309,9 +315,9 @@ void Compiler::addDefinitions(const BasicBlock::Node& node, const VariableReference* v = (VariableReference*) expr; if (v->fRefKind != VariableReference::kRead_RefKind) { this->addDefinition( - v, - (std::unique_ptr<Expression>*) &fContext.fDefined_Expression, - definitions); + v, + (std::unique_ptr<Expression>*) &fContext->fDefined_Expression, + definitions); } } default: @@ -343,6 +349,9 @@ void Compiler::scanCFG(CFG* cfg, BlockId blockId, std::set<BlockId>* workList) { // propagate definitions to exits for (BlockId exitId : block.fExits) { + if (exitId == blockId) { + continue; + } BasicBlock& exit = cfg->fBlocks[exitId]; for (const auto& pair : after) { std::unique_ptr<Expression>* e1 = pair.second; @@ -359,7 +368,7 @@ void Compiler::scanCFG(CFG* cfg, BlockId blockId, std::set<BlockId>* workList) { workList->insert(exitId); if (e1 && e2) { exit.fBefore[pair.first] = - (std::unique_ptr<Expression>*) &fContext.fDefined_Expression; + (std::unique_ptr<Expression>*) &fContext->fDefined_Expression; } else { exit.fBefore[pair.first] = nullptr; } @@ -990,7 +999,7 @@ void Compiler::simplifyStatement(DefinitionMap& definitions, continue; } ASSERT(c->fValue->fKind == s.fValue->fKind); - found = c->fValue->compareConstant(fContext, *s.fValue); + found = c->fValue->compareConstant(*fContext, *s.fValue); if (found) { std::unique_ptr<Statement> newBlock = block_for_case(&s, c.get()); if (newBlock) { @@ -1153,7 +1162,7 @@ void Compiler::scanCFG(FunctionDefinition& f) { } // check for missing return - if (f.fDeclaration.fReturnType != *fContext.fVoid_Type) { + if (f.fDeclaration.fReturnType != *fContext->fVoid_Type) { if (cfg.fBlocks[cfg.fExit].fEntrances.size()) { this->error(f.fOffset, String("function can exit without returning a value")); } @@ -1183,6 +1192,10 @@ std::unique_ptr<Program> Compiler::convertProgram(Program::Kind kind, String tex fIRGenerator->convertProgram(kind, SKSL_FP_INCLUDE, strlen(SKSL_FP_INCLUDE), *fTypes, &elements); break; + case Program::kCPU_Kind: + fIRGenerator->convertProgram(kind, SKSL_CPU_INCLUDE, strlen(SKSL_CPU_INCLUDE), + *fTypes, &elements); + break; } fIRGenerator->fSymbolTable->markAllFunctionsBuiltin(); for (auto& element : elements) { @@ -1203,7 +1216,7 @@ std::unique_ptr<Program> Compiler::convertProgram(Program::Kind kind, String tex auto result = std::unique_ptr<Program>(new Program(kind, std::move(textPtr), settings, - &fContext, + fContext, std::move(elements), fIRGenerator->fSymbolTable, fIRGenerator->fInputs)); @@ -1220,7 +1233,7 @@ bool Compiler::toSPIRV(const Program& program, OutputStream& out) { #ifdef SK_ENABLE_SPIRV_VALIDATION StringStream buffer; fSource = program.fSource.get(); - SPIRVCodeGenerator cg(&fContext, &program, this, &buffer); + SPIRVCodeGenerator cg(fContext.get(), &program, this, &buffer); bool result = cg.generateCode(); fSource = nullptr; if (result) { @@ -1238,7 +1251,7 @@ bool Compiler::toSPIRV(const Program& program, OutputStream& out) { } #else fSource = program.fSource.get(); - SPIRVCodeGenerator cg(&fContext, &program, this, &out); + SPIRVCodeGenerator cg(fContext.get(), &program, this, &out); bool result = cg.generateCode(); fSource = nullptr; #endif @@ -1257,7 +1270,7 @@ bool Compiler::toSPIRV(const Program& program, String* out) { bool Compiler::toGLSL(const Program& program, OutputStream& out) { fSource = program.fSource.get(); - GLSLCodeGenerator cg(&fContext, &program, this, &out); + GLSLCodeGenerator cg(fContext.get(), &program, this, &out); bool result = cg.generateCode(); fSource = nullptr; this->writeErrorCount(); @@ -1274,7 +1287,7 @@ bool Compiler::toGLSL(const Program& program, String* out) { } bool Compiler::toMetal(const Program& program, OutputStream& out) { - MetalCodeGenerator cg(&fContext, &program, this, &out); + MetalCodeGenerator cg(fContext.get(), &program, this, &out); bool result = cg.generateCode(); this->writeErrorCount(); return result; @@ -1282,7 +1295,7 @@ bool Compiler::toMetal(const Program& program, OutputStream& out) { bool Compiler::toCPP(const Program& program, String name, OutputStream& out) { fSource = program.fSource.get(); - CPPCodeGenerator cg(&fContext, &program, this, name, &out); + CPPCodeGenerator cg(fContext.get(), &program, this, name, &out); bool result = cg.generateCode(); fSource = nullptr; this->writeErrorCount(); @@ -1291,7 +1304,7 @@ bool Compiler::toCPP(const Program& program, String name, OutputStream& out) { bool Compiler::toH(const Program& program, String name, OutputStream& out) { fSource = program.fSource.get(); - HCodeGenerator cg(&fContext, &program, this, name, &out); + HCodeGenerator cg(fContext.get(), &program, this, name, &out); bool result = cg.generateCode(); fSource = nullptr; this->writeErrorCount(); diff --git a/src/sksl/SkSLCompiler.h b/src/sksl/SkSLCompiler.h index 9f8ef4f18a..0ed6a3bdbb 100644 --- a/src/sksl/SkSLCompiler.h +++ b/src/sksl/SkSLCompiler.h @@ -88,6 +88,10 @@ public: return fErrorCount; } + Context& context() { + return *fContext; + } + static const char* OperatorName(Token::Kind token); static bool IsAssignment(Token::Kind token); @@ -134,7 +138,7 @@ private: int fFlags; const String* fSource; - Context fContext; + std::shared_ptr<Context> fContext; int fErrorCount; String fErrorText; }; diff --git a/src/sksl/SkSLContext.h b/src/sksl/SkSLContext.h index 407dbf8e82..61dd8042e4 100644 --- a/src/sksl/SkSLContext.h +++ b/src/sksl/SkSLContext.h @@ -185,6 +185,7 @@ public: , fSkCaps_Type(new Type("$sk_Caps")) , fSkArgs_Type(new Type("$sk_Args")) , fFragmentProcessor_Type(new Type("fragmentProcessor")) + , fSkRasterPipeline_Type(new Type("SkRasterPipeline")) , fDefined_Expression(new Defined(*fInvalid_Type)) {} static std::vector<const Type*> static_type(const Type& t) { @@ -333,6 +334,7 @@ public: const std::unique_ptr<Type> fSkCaps_Type; const std::unique_ptr<Type> fSkArgs_Type; const std::unique_ptr<Type> fFragmentProcessor_Type; + const std::unique_ptr<Type> fSkRasterPipeline_Type; // dummy expression used to mark that a variable has a value during dataflow analysis (when it // could have several different values, or the analyzer is otherwise unable to assign it a diff --git a/src/sksl/SkSLIRGenerator.cpp b/src/sksl/SkSLIRGenerator.cpp index 68d9f9c0ca..815ec152cd 100644 --- a/src/sksl/SkSLIRGenerator.cpp +++ b/src/sksl/SkSLIRGenerator.cpp @@ -17,6 +17,7 @@ #include "ast/SkSLASTFloatLiteral.h" #include "ast/SkSLASTIndexSuffix.h" #include "ast/SkSLASTIntLiteral.h" +#include "ir/SkSLAppendStage.h" #include "ir/SkSLBinaryExpression.h" #include "ir/SkSLBoolLiteral.h" #include "ir/SkSLBreakStatement.h" @@ -1971,6 +1972,91 @@ std::unique_ptr<Expression> IRGenerator::convertTypeField(int offset, const Type return result; } +std::unique_ptr<Expression> IRGenerator::convertAppend(int offset, + const std::vector<std::unique_ptr<ASTExpression>>& args) { +#ifndef SKSL_STANDALONE + if (args.size() < 2) { + fErrors.error(offset, "'append' requires at least two arguments"); + return nullptr; + } + std::unique_ptr<Expression> pipeline = this->convertExpression(*args[0]); + if (!pipeline) { + return nullptr; + } + if (pipeline->fType != *fContext.fSkRasterPipeline_Type) { + fErrors.error(offset, "first argument of 'append' must have type 'SkRasterPipeline'"); + return nullptr; + } + if (ASTExpression::kIdentifier_Kind != args[1]->fKind) { + fErrors.error(offset, "'" + args[1]->description() + "' is not a valid stage"); + return nullptr; + } + StringFragment name = ((const ASTIdentifier&) *args[1]).fText; + SkRasterPipeline::StockStage stage = SkRasterPipeline::premul; + std::vector<std::unique_ptr<Expression>> stageArgs; + stageArgs.push_back(std::move(pipeline)); + for (size_t i = 2; i < args.size(); ++i) { + std::unique_ptr<Expression> arg = this->convertExpression(*args[i]); + if (!arg) { + return nullptr; + } + stageArgs.push_back(std::move(arg)); + } + size_t expectedArgs = 0; + // FIXME use a map + if ("premul" == name) { + stage = SkRasterPipeline::premul; + } + else if ("unpremul" == name) { + stage = SkRasterPipeline::unpremul; + } + else if ("clamp_0" == name) { + stage = SkRasterPipeline::clamp_0; + } + else if ("clamp_1" == name) { + stage = SkRasterPipeline::clamp_1; + } + else if ("matrix_4x5" == name) { + expectedArgs = 1; + stage = SkRasterPipeline::matrix_4x5; + if (1 == stageArgs.size() && stageArgs[0]->fType.fName != "float[20]") { + fErrors.error(offset, "pipeline stage '" + name + "' expected a float[20] argument"); + return nullptr; + } + } + else { + bool found = false; + for (const auto& e : *fProgramElements) { + if (ProgramElement::kFunction_Kind == e->fKind) { + const FunctionDefinition& f = (const FunctionDefinition&) *e; + if (f.fDeclaration.fName == name) { + stage = SkRasterPipeline::callback; + std::vector<const FunctionDeclaration*> functions = { &f.fDeclaration }; + stageArgs.emplace_back(new FunctionReference(fContext, offset, functions)); + found = true; + break; + } + } + } + if (!found) { + fErrors.error(offset, "'" + name + "' is not a valid pipeline stage"); + return nullptr; + } + } + if (args.size() != expectedArgs + 2) { + fErrors.error(offset, "pipeline stage '" + name + "' expected an additional argument " + + "count of " + to_string((int) expectedArgs) + ", but found " + + to_string((int) args.size() - 1)); + return nullptr; + } + return std::unique_ptr<Expression>(new AppendStage(fContext, offset, stage, + std::move(stageArgs))); +#else + ASSERT(false); + return nullptr; +#endif +} + std::unique_ptr<Expression> IRGenerator::convertSuffixExpression( const ASTSuffixExpression& expression) { std::unique_ptr<Expression> base = this->convertExpression(*expression.fBase); @@ -1996,6 +2082,10 @@ std::unique_ptr<Expression> IRGenerator::convertSuffixExpression( } case ASTSuffix::kCall_Kind: { auto rawArguments = &((ASTCallSuffix&) *expression.fSuffix).fArguments; + if (Expression::kFunctionReference_Kind == base->fKind && + "append" == ((const FunctionReference&) *base).fFunctions[0]->fName) { + return convertAppend(expression.fOffset, *rawArguments); + } std::vector<std::unique_ptr<Expression>> arguments; for (size_t i = 0; i < rawArguments->size(); i++) { std::unique_ptr<Expression> converted = diff --git a/src/sksl/SkSLIRGenerator.h b/src/sksl/SkSLIRGenerator.h index e8ff1d21c8..c78c195c0b 100644 --- a/src/sksl/SkSLIRGenerator.h +++ b/src/sksl/SkSLIRGenerator.h @@ -114,6 +114,8 @@ private: std::vector<std::unique_ptr<Expression>> arguments); int coercionCost(const Expression& expr, const Type& type); std::unique_ptr<Expression> coerce(std::unique_ptr<Expression> expr, const Type& type); + std::unique_ptr<Expression> convertAppend(int offset, + const std::vector<std::unique_ptr<ASTExpression>>& args); std::unique_ptr<Block> convertBlock(const ASTBlock& block); std::unique_ptr<Statement> convertBreak(const ASTBreakStatement& b); std::unique_ptr<Expression> convertNumberConstructor( diff --git a/src/sksl/SkSLInterpreter.cpp b/src/sksl/SkSLInterpreter.cpp new file mode 100644 index 0000000000..c9b7ceb44c --- /dev/null +++ b/src/sksl/SkSLInterpreter.cpp @@ -0,0 +1,473 @@ +/* + * Copyright 2018 Google Inc. + * + * Use of this source code is governed by a BSD-style license that can be + * found in the LICENSE file. + */ + +#ifndef SKSL_STANDALONE + +#include "SkSLInterpreter.h" +#include "ir/SkSLBinaryExpression.h" +#include "ir/SkSLExpressionStatement.h" +#include "ir/SkSLForStatement.h" +#include "ir/SkSLFunctionCall.h" +#include "ir/SkSLFunctionReference.h" +#include "ir/SkSLIfStatement.h" +#include "ir/SkSLIndexExpression.h" +#include "ir/SkSLPostfixExpression.h" +#include "ir/SkSLPrefixExpression.h" +#include "ir/SkSLProgram.h" +#include "ir/SkSLStatement.h" +#include "ir/SkSLTernaryExpression.h" +#include "ir/SkSLVarDeclarations.h" +#include "ir/SkSLVarDeclarationsStatement.h" +#include "ir/SkSLVariableReference.h" +#include "SkRasterPipeline.h" +#include "../jumper/SkJumper.h" + +namespace SkSL { + +void Interpreter::run() { + for (const auto& e : fProgram->fElements) { + if (ProgramElement::kFunction_Kind == e->fKind) { + const FunctionDefinition& f = (const FunctionDefinition&) *e; + if ("appendStages" == f.fDeclaration.fName) { + this->run(f); + return; + } + } + } + ASSERT(false); +} + +static int SizeOf(const Type& type) { + return 1; +} + +void Interpreter::run(const FunctionDefinition& f) { + fVars.emplace_back(); + StackIndex current = (StackIndex) fStack.size(); + for (int i = f.fDeclaration.fParameters.size() - 1; i >= 0; --i) { + current -= SizeOf(f.fDeclaration.fParameters[i]->fType); + fVars.back()[f.fDeclaration.fParameters[i]] = current; + } + fCurrentIndex.push_back({ f.fBody.get(), 0 }); + while (fCurrentIndex.size()) { + this->runStatement(); + } +} + +void Interpreter::push(Value value) { + fStack.push_back(value); +} + +Interpreter::Value Interpreter::pop() { + auto iter = fStack.end() - 1; + Value result = *iter; + fStack.erase(iter); + return result; +} + + Interpreter::StackIndex Interpreter::stackAlloc(int count) { + int result = fStack.size(); + for (int i = 0; i < count; ++i) { + fStack.push_back(Value((int) 0xDEADBEEF)); + } + return result; +} + +void Interpreter::runStatement() { + const Statement& stmt = *fCurrentIndex.back().fStatement; + const size_t index = fCurrentIndex.back().fIndex; + fCurrentIndex.pop_back(); + switch (stmt.fKind) { + case Statement::kBlock_Kind: { + const Block& b = (const Block&) stmt; + if (!b.fStatements.size()) { + break; + } + ASSERT(index < b.fStatements.size()); + if (index < b.fStatements.size() - 1) { + fCurrentIndex.push_back({ &b, index + 1 }); + } + fCurrentIndex.push_back({ b.fStatements[index].get(), 0 }); + break; + } + case Statement::kBreak_Kind: + ASSERT(index == 0); + abort(); + case Statement::kContinue_Kind: + ASSERT(index == 0); + abort(); + case Statement::kDiscard_Kind: + ASSERT(index == 0); + abort(); + case Statement::kDo_Kind: + abort(); + case Statement::kExpression_Kind: + ASSERT(index == 0); + this->evaluate(*((const ExpressionStatement&) stmt).fExpression); + break; + case Statement::kFor_Kind: { + ForStatement& f = (ForStatement&) stmt; + switch (index) { + case 0: + // initializer + fCurrentIndex.push_back({ &f, 1 }); + if (f.fInitializer) { + fCurrentIndex.push_back({ f.fInitializer.get(), 0 }); + } + break; + case 1: + // test & body + if (f.fTest && !evaluate(*f.fTest).fBool) { + break; + } else { + fCurrentIndex.push_back({ &f, 2 }); + fCurrentIndex.push_back({ f.fStatement.get(), 0 }); + } + break; + case 2: + // next + if (f.fNext) { + this->evaluate(*f.fNext); + } + fCurrentIndex.push_back({ &f, 1 }); + break; + default: + ASSERT(false); + } + break; + } + case Statement::kGroup_Kind: + abort(); + case Statement::kIf_Kind: { + IfStatement& i = (IfStatement&) stmt; + if (evaluate(*i.fTest).fBool) { + fCurrentIndex.push_back({ i.fIfTrue.get(), 0 }); + } else if (i.fIfFalse) { + fCurrentIndex.push_back({ i.fIfFalse.get(), 0 }); + } + break; + } + case Statement::kNop_Kind: + ASSERT(index == 0); + break; + case Statement::kReturn_Kind: + ASSERT(index == 0); + abort(); + case Statement::kSwitch_Kind: + abort(); + case Statement::kVarDeclarations_Kind: + ASSERT(index == 0); + for (const auto& decl :((const VarDeclarationsStatement&) stmt).fDeclaration->fVars) { + const Variable* var = ((VarDeclaration&) *decl).fVar; + StackIndex pos = this->stackAlloc(SizeOf(var->fType)); + fVars.back()[var] = pos; + if (var->fInitialValue) { + fStack[pos] = this->evaluate(*var->fInitialValue); + } + } + break; + case Statement::kWhile_Kind: + abort(); + default: + abort(); + } +} + +static Interpreter::TypeKind type_kind(const Type& type) { + if (type.fName == "int") { + return Interpreter::kInt_TypeKind; + } else if (type.fName == "float") { + return Interpreter::kFloat_TypeKind; + } + ABORT("unsupported type: %s\n", type.description().c_str()); +} + +Interpreter::StackIndex Interpreter::getLValue(const Expression& expr) { + switch (expr.fKind) { + case Expression::kFieldAccess_Kind: + break; + case Expression::kIndex_Kind: { + const IndexExpression& idx = (const IndexExpression&) expr; + return this->evaluate(*idx.fBase).fInt + this->evaluate(*idx.fIndex).fInt; + } + case Expression::kSwizzle_Kind: + break; + case Expression::kVariableReference_Kind: + ASSERT(fVars.size()); + ASSERT(fVars.back().find(&((VariableReference&) expr).fVariable) != + fVars.back().end()); + return fVars.back()[&((VariableReference&) expr).fVariable]; + case Expression::kTernary_Kind: { + const TernaryExpression& t = (const TernaryExpression&) expr; + return this->getLValue(this->evaluate(*t.fTest).fBool ? *t.fIfTrue : *t.fIfFalse); + } + case Expression::kTypeReference_Kind: + break; + default: + break; + } + ABORT("unsupported lvalue"); +} + +struct CallbackCtx : public SkJumper_CallbackCtx { + Interpreter* fInterpreter; + const FunctionDefinition* fFunction; +}; + +static void do_callback(SkJumper_CallbackCtx* raw, int activePixels) { + CallbackCtx& ctx = (CallbackCtx&) *raw; + for (int i = 0; i < activePixels; ++i) { + ctx.fInterpreter->push(Interpreter::Value(ctx.rgba[i * 4 + 0])); + ctx.fInterpreter->push(Interpreter::Value(ctx.rgba[i * 4 + 1])); + ctx.fInterpreter->push(Interpreter::Value(ctx.rgba[i * 4 + 2])); + ctx.fInterpreter->run(*ctx.fFunction); + ctx.read_from[i * 4 + 2] = ctx.fInterpreter->pop().fFloat; + ctx.read_from[i * 4 + 1] = ctx.fInterpreter->pop().fFloat; + ctx.read_from[i * 4 + 0] = ctx.fInterpreter->pop().fFloat; + } +} + +void Interpreter::appendStage(const AppendStage& a) { + switch (a.fStage) { + case SkRasterPipeline::matrix_4x5: { + ASSERT(a.fArguments.size() == 1); + StackIndex transpose = evaluate(*a.fArguments[0]).fInt; + fPipeline.append(SkRasterPipeline::matrix_4x5, &fStack[transpose]); + break; + } + case SkRasterPipeline::callback: { + ASSERT(a.fArguments.size() == 1); + CallbackCtx* ctx = new CallbackCtx(); + ctx->fInterpreter = this; + ctx->fn = do_callback; + for (const auto& e : fProgram->fElements) { + if (ProgramElement::kFunction_Kind == e->fKind) { + const FunctionDefinition& f = (const FunctionDefinition&) *e; + if (&f.fDeclaration == + ((const FunctionReference&) *a.fArguments[0]).fFunctions[0]) { + ctx->fFunction = &f; + } + } + } + fPipeline.append(SkRasterPipeline::callback, ctx); + break; + } + default: + fPipeline.append(a.fStage); + } +} + +Interpreter::Value Interpreter::call(const FunctionCall& c) { + abort(); +} + +Interpreter::Value Interpreter::evaluate(const Expression& expr) { + switch (expr.fKind) { + case Expression::kAppendStage_Kind: + this->appendStage((const AppendStage&) expr); + return Value((int) 0xDEADBEEF); + case Expression::kBinary_Kind: { + #define ARITHMETIC(op) { \ + Value left = this->evaluate(*b.fLeft); \ + Value right = this->evaluate(*b.fRight); \ + switch (type_kind(b.fLeft->fType)) { \ + case kFloat_TypeKind: \ + return Value(left.fFloat op right.fFloat); \ + case kInt_TypeKind: \ + return Value(left.fInt op right.fInt); \ + default: \ + abort(); \ + } \ + } + #define BITWISE(op) { \ + Value left = this->evaluate(*b.fLeft); \ + Value right = this->evaluate(*b.fRight); \ + switch (type_kind(b.fLeft->fType)) { \ + case kInt_TypeKind: \ + return Value(left.fInt op right.fInt); \ + default: \ + abort(); \ + } \ + } + #define LOGIC(op) { \ + Value left = this->evaluate(*b.fLeft); \ + Value right = this->evaluate(*b.fRight); \ + switch (type_kind(b.fLeft->fType)) { \ + case kFloat_TypeKind: \ + return Value(left.fFloat op right.fFloat); \ + case kInt_TypeKind: \ + return Value(left.fInt op right.fInt); \ + default: \ + abort(); \ + } \ + } + #define COMPOUND_ARITHMETIC(op) { \ + StackIndex left = this->getLValue(*b.fLeft); \ + Value right = this->evaluate(*b.fRight); \ + Value result = fStack[left]; \ + switch (type_kind(b.fLeft->fType)) { \ + case kFloat_TypeKind: \ + result.fFloat op right.fFloat; \ + break; \ + case kInt_TypeKind: \ + result.fInt op right.fInt; \ + break; \ + default: \ + abort(); \ + } \ + fStack[left] = result; \ + return result; \ + } + #define COMPOUND_BITWISE(op) { \ + StackIndex left = this->getLValue(*b.fLeft); \ + Value right = this->evaluate(*b.fRight); \ + Value result = fStack[left]; \ + switch (type_kind(b.fLeft->fType)) { \ + case kInt_TypeKind: \ + result.fInt op right.fInt; \ + break; \ + default: \ + abort(); \ + } \ + fStack[left] = result; \ + return result; \ + } + const BinaryExpression& b = (const BinaryExpression&) expr; + switch (b.fOperator) { + case Token::PLUS: ARITHMETIC(+) + case Token::MINUS: ARITHMETIC(-) + case Token::STAR: ARITHMETIC(*) + case Token::SLASH: ARITHMETIC(/) + case Token::BITWISEAND: BITWISE(&) + case Token::BITWISEOR: BITWISE(|) + case Token::BITWISEXOR: BITWISE(^) + case Token::LT: LOGIC(<) + case Token::GT: LOGIC(>) + case Token::LTEQ: LOGIC(<=) + case Token::GTEQ: LOGIC(>=) + case Token::LOGICALAND: { + Value result = this->evaluate(*b.fLeft); + if (result.fBool) { + result = this->evaluate(*b.fRight); + } + return result; + } + case Token::LOGICALOR: { + Value result = this->evaluate(*b.fLeft); + if (!result.fBool) { + result = this->evaluate(*b.fRight); + } + return result; + } + case Token::EQ: { + StackIndex left = this->getLValue(*b.fLeft); + Value right = this->evaluate(*b.fRight); + fStack[left] = right; + return right; + } + case Token::PLUSEQ: COMPOUND_ARITHMETIC(+=) + case Token::MINUSEQ: COMPOUND_ARITHMETIC(-=) + case Token::STAREQ: COMPOUND_ARITHMETIC(*=) + case Token::SLASHEQ: COMPOUND_ARITHMETIC(/=) + case Token::BITWISEANDEQ: COMPOUND_BITWISE(&=) + case Token::BITWISEOREQ: COMPOUND_BITWISE(|=) + case Token::BITWISEXOREQ: COMPOUND_BITWISE(^=) + default: + ABORT("unsupported operator: %s\n", expr.description().c_str()); + } + break; + } + case Expression::kBoolLiteral_Kind: + return Value(((const BoolLiteral&) expr).fValue); + case Expression::kConstructor_Kind: + break; + case Expression::kIntLiteral_Kind: + return Value((int) ((const IntLiteral&) expr).fValue); + case Expression::kFieldAccess_Kind: + break; + case Expression::kFloatLiteral_Kind: + return Value((float) ((const FloatLiteral&) expr).fValue); + case Expression::kFunctionCall_Kind: + return this->call((const FunctionCall&) expr); + case Expression::kIndex_Kind: { + const IndexExpression& idx = (const IndexExpression&) expr; + StackIndex pos = this->evaluate(*idx.fBase).fInt + + this->evaluate(*idx.fIndex).fInt; + return fStack[pos]; + } + case Expression::kPrefix_Kind: { + const PrefixExpression& p = (const PrefixExpression&) expr; + switch (p.fOperator) { + case Token::MINUS: { + Value base = this->evaluate(*p.fOperand); + switch (type_kind(p.fType)) { + case kFloat_TypeKind: + return Value(-base.fFloat); + case kInt_TypeKind: + return Value(-base.fInt); + default: + abort(); + } + } + case Token::LOGICALNOT: { + Value base = this->evaluate(*p.fOperand); + return Value(!base.fBool); + } + default: + abort(); + } + } + case Expression::kPostfix_Kind: { + const PostfixExpression& p = (const PostfixExpression&) expr; + StackIndex lvalue = this->getLValue(*p.fOperand); + Value result = fStack[lvalue]; + switch (type_kind(p.fType)) { + case kFloat_TypeKind: + if (Token::PLUSPLUS == p.fOperator) { + ++fStack[lvalue].fFloat; + } else { + ASSERT(Token::MINUSMINUS == p.fOperator); + --fStack[lvalue].fFloat; + } + break; + case kInt_TypeKind: + if (Token::PLUSPLUS == p.fOperator) { + ++fStack[lvalue].fInt; + } else { + ASSERT(Token::MINUSMINUS == p.fOperator); + --fStack[lvalue].fInt; + } + break; + default: + abort(); + } + return result; + } + case Expression::kSetting_Kind: + break; + case Expression::kSwizzle_Kind: + break; + case Expression::kVariableReference_Kind: + ASSERT(fVars.size()); + ASSERT(fVars.back().find(&((VariableReference&) expr).fVariable) != + fVars.back().end()); + return fStack[fVars.back()[&((VariableReference&) expr).fVariable]]; + case Expression::kTernary_Kind: { + const TernaryExpression& t = (const TernaryExpression&) expr; + return this->evaluate(this->evaluate(*t.fTest).fBool ? *t.fIfTrue : *t.fIfFalse); + } + case Expression::kTypeReference_Kind: + break; + default: + break; + } + ABORT("unsupported expression: %s\n", expr.description().c_str()); +} + +} // namespace + +#endif diff --git a/src/sksl/SkSLInterpreter.h b/src/sksl/SkSLInterpreter.h new file mode 100644 index 0000000000..93f32a1617 --- /dev/null +++ b/src/sksl/SkSLInterpreter.h @@ -0,0 +1,89 @@ +/* + * Copyright 2018 Google Inc. + * + * Use of this source code is governed by a BSD-style license that can be + * found in the LICENSE file. + */ + +#ifndef SKSL_INTERPRETER +#define SKSL_INTERPRETER + +#include "ir/SkSLAppendStage.h" +#include "ir/SkSLExpression.h" +#include "ir/SkSLFunctionCall.h" +#include "ir/SkSLFunctionDefinition.h" +#include "ir/SkSLProgram.h" +#include "ir/SkSLStatement.h" + +#include <stack> + +class SkRasterPipeline; + +namespace SkSL { + +class Interpreter { + typedef int StackIndex; + + struct StatementIndex { + const Statement* fStatement; + size_t fIndex; + }; + +public: + union Value { + Value(float f) + : fFloat(f) {} + + Value(int i) + : fInt(i) {} + + Value(bool b) + : fBool(b) {} + + float fFloat; + int fInt; + bool fBool; + }; + + enum TypeKind { + kFloat_TypeKind, + kInt_TypeKind, + kBool_TypeKind + }; + + Interpreter(std::unique_ptr<Program> program, SkRasterPipeline* pipeline, std::vector<Value>* stack) + : fProgram(std::move(program)) + , fPipeline(*pipeline) + , fStack(*stack) {} + + void run(); + + void run(const FunctionDefinition& f); + + void push(Value value); + + Value pop(); + + StackIndex stackAlloc(int count); + + void runStatement(); + + StackIndex getLValue(const Expression& expr); + + Value call(const FunctionCall& c); + + void appendStage(const AppendStage& c); + + Value evaluate(const Expression& expr); + +private: + std::unique_ptr<Program> fProgram; + SkRasterPipeline& fPipeline; + std::vector<StatementIndex> fCurrentIndex; + std::vector<std::map<const Variable*, StackIndex>> fVars; + std::vector<Value> &fStack; +}; + +} // namespace + +#endif diff --git a/src/sksl/SkSLJIT.cpp b/src/sksl/SkSLJIT.cpp new file mode 100644 index 0000000000..06b6e2c94b --- /dev/null +++ b/src/sksl/SkSLJIT.cpp @@ -0,0 +1,1747 @@ +/* + * Copyright 2018 Google Inc. + * + * Use of this source code is governed by a BSD-style license that can be + * found in the LICENSE file. + */ + +#ifndef SKSL_STANDALONE + +#ifdef SK_LLVM_AVAILABLE + +#include "SkSLJIT.h" + +#include "SkCpu.h" +#include "SkRasterPipeline.h" +#include "../jumper/SkJumper.h" +#include "ir/SkSLExpressionStatement.h" +#include "ir/SkSLFunctionCall.h" +#include "ir/SkSLFunctionReference.h" +#include "ir/SkSLIndexExpression.h" +#include "ir/SkSLProgram.h" +#include "llvm/ExecutionEngine/RTDyldMemoryManager.h" + +static constexpr int MAX_VECTOR_COUNT = 16; + +extern "C" void sksl_pipeline_append(SkRasterPipeline* p, int stage, void* ctx) { + p->append((SkRasterPipeline::StockStage) stage, ctx); +} + +#define PTR_SIZE sizeof(void*) + +extern "C" void sksl_pipeline_append_callback(SkRasterPipeline* p, void* fn) { + p->append(fn, nullptr); +} + +extern "C" void sksl_debug_print(float f) { + printf("Debug: %f\n", f); +} + +namespace SkSL { + +static constexpr int STAGE_PARAM_COUNT = 12; + +static bool ends_with_branch(const Statement& stmt) { + switch (stmt.fKind) { + case Statement::kBlock_Kind: { + const Block& b = (const Block&) stmt; + if (b.fStatements.size()) { + return ends_with_branch(*b.fStatements.back()); + } + return false; + } + case Statement::kBreak_Kind: // fall through + case Statement::kContinue_Kind: // fall through + case Statement::kReturn_Kind: // fall through + return true; + default: + return false; + } +} + +JIT::JIT(Compiler* compiler) +: fCompiler(*compiler) { + LLVMInitializeNativeTarget(); + LLVMInitializeNativeAsmPrinter(); + LLVMLinkInMCJIT(); + ASSERT(!SkCpu::Supports(SkCpu::SKX)); // not yet supported + if (SkCpu::Supports(SkCpu::HSW)) { + fVectorCount = 8; + fCPU = "haswell"; + } else if (SkCpu::Supports(SkCpu::AVX)) { + fVectorCount = 8; + fCPU = "ivybridge"; + } else { + fVectorCount = 4; + fCPU = nullptr; + } + fContext = LLVMContextCreate(); + fVoidType = LLVMVoidTypeInContext(fContext); + fInt1Type = LLVMInt1TypeInContext(fContext); + fInt8Type = LLVMInt8TypeInContext(fContext); + fInt8PtrType = LLVMPointerType(fInt8Type, 0); + fInt32Type = LLVMInt32TypeInContext(fContext); + fInt64Type = LLVMInt64TypeInContext(fContext); + fSizeTType = LLVMInt64TypeInContext(fContext); + fInt32VectorType = LLVMVectorType(fInt32Type, fVectorCount); + fInt32Vector2Type = LLVMVectorType(fInt32Type, 2); + fInt32Vector3Type = LLVMVectorType(fInt32Type, 3); + fInt32Vector4Type = LLVMVectorType(fInt32Type, 4); + fFloat32Type = LLVMFloatTypeInContext(fContext); + fFloat32VectorType = LLVMVectorType(fFloat32Type, fVectorCount); + fFloat32Vector2Type = LLVMVectorType(fFloat32Type, 2); + fFloat32Vector3Type = LLVMVectorType(fFloat32Type, 3); + fFloat32Vector4Type = LLVMVectorType(fFloat32Type, 4); +} + +JIT::~JIT() { + LLVMOrcDisposeInstance(fJITStack); + LLVMContextDispose(fContext); +} + +void JIT::addBuiltinFunction(const char* ourName, const char* realName, LLVMTypeRef returnType, + std::vector<LLVMTypeRef> parameters) { + for (const auto& pair : *fProgram->fSymbols) { + if (Symbol::kFunctionDeclaration_Kind == pair.second->fKind) { + const FunctionDeclaration& f = (const FunctionDeclaration&) *pair.second; + if (pair.first != ourName || returnType != this->getType(f.fReturnType) || + parameters.size() != f.fParameters.size()) { + continue; + } + for (size_t i = 0; i < parameters.size(); ++i) { + if (parameters[i] != this->getType(f.fParameters[i]->fType)) { + goto next; + } + } + fFunctions[&f] = LLVMAddFunction(fModule, realName, LLVMFunctionType(returnType, + parameters.data(), + parameters.size(), + false)); + } + next:; + } +} + +void JIT::loadBuiltinFunctions() { + this->addBuiltinFunction("abs", "fabs", fFloat32Type, { fFloat32Type }); + this->addBuiltinFunction("sin", "sinf", fFloat32Type, { fFloat32Type }); + this->addBuiltinFunction("cos", "cosf", fFloat32Type, { fFloat32Type }); + this->addBuiltinFunction("tan", "tanf", fFloat32Type, { fFloat32Type }); + this->addBuiltinFunction("sqrt", "sqrtf", fFloat32Type, { fFloat32Type }); + this->addBuiltinFunction("print", "sksl_debug_print", fVoidType, { fFloat32Type }); +} + +uint64_t JIT::resolveSymbol(const char* name, JIT* jit) { + LLVMOrcTargetAddress result; + if (!LLVMOrcGetSymbolAddress(jit->fJITStack, &result, name)) { + if (!strcmp(name, "_sksl_pipeline_append")) { + result = (uint64_t) &sksl_pipeline_append; + } else if (!strcmp(name, "_sksl_pipeline_append_callback")) { + result = (uint64_t) &sksl_pipeline_append_callback; + } else if (!strcmp(name, "_sksl_debug_print")) { + result = (uint64_t) &sksl_debug_print; + } else { + result = llvm::RTDyldMemoryManager::getSymbolAddressInProcess(name); + } + } + ASSERT(result); + return result; +} + +LLVMValueRef JIT::compileFunctionCall(LLVMBuilderRef builder, const FunctionCall& fc) { + LLVMValueRef func = fFunctions[&fc.fFunction]; + ASSERT(func); + std::vector<LLVMValueRef> parameters; + for (const auto& a : fc.fArguments) { + parameters.push_back(this->compileExpression(builder, *a)); + } + return LLVMBuildCall(builder, func, parameters.data(), parameters.size(), ""); +} + +LLVMTypeRef JIT::getType(const Type& type) { + switch (type.kind()) { + case Type::kOther_Kind: + if (type.name() == "void") { + return fVoidType; + } + ASSERT(type.name() == "SkRasterPipeline"); + return fInt8PtrType; + case Type::kScalar_Kind: + if (type.isSigned() || type.isUnsigned()) { + return fInt32Type; + } + if (type.isUnsigned()) { + return fInt32Type; + } + if (type.isFloat()) { + return fFloat32Type; + } + ASSERT(type.name() == "bool"); + return fInt1Type; + case Type::kArray_Kind: + return LLVMPointerType(this->getType(type.componentType()), 0); + case Type::kVector_Kind: + if (type.name() == "float2" || type.name() == "half2") { + return fFloat32Vector2Type; + } + if (type.name() == "float3" || type.name() == "half3") { + return fFloat32Vector3Type; + } + if (type.name() == "float4" || type.name() == "half4") { + return fFloat32Vector4Type; + } + if (type.name() == "int2" || type.name() == "short2") { + return fInt32Vector2Type; + } + if (type.name() == "int3" || type.name() == "short3") { + return fInt32Vector3Type; + } + if (type.name() == "int4" || type.name() == "short4") { + return fInt32Vector4Type; + } + // fall through + default: + ABORT("unsupported type"); + } +} + +void JIT::setBlock(LLVMBuilderRef builder, LLVMBasicBlockRef block) { + fCurrentBlock = block; + LLVMPositionBuilderAtEnd(builder, block); +} + +std::unique_ptr<JIT::LValue> JIT::getLValue(LLVMBuilderRef builder, const Expression& expr) { + switch (expr.fKind) { + case Expression::kVariableReference_Kind: { + class PointerLValue : public LValue { + public: + PointerLValue(LLVMValueRef ptr) + : fPointer(ptr) {} + + LLVMValueRef load(LLVMBuilderRef builder) override { + return LLVMBuildLoad(builder, fPointer, "lvalue load"); + } + + void store(LLVMBuilderRef builder, LLVMValueRef value) override { + LLVMBuildStore(builder, value, fPointer); + } + + private: + LLVMValueRef fPointer; + }; + const Variable* var = &((VariableReference&) expr).fVariable; + if (var->fStorage == Variable::kParameter_Storage && + !(var->fModifiers.fFlags & Modifiers::kOut_Flag) && + fPromotedParameters.find(var) == fPromotedParameters.end()) { + // promote parameter to variable + fPromotedParameters.insert(var); + LLVMPositionBuilderAtEnd(builder, fAllocaBlock); + LLVMValueRef alloca = LLVMBuildAlloca(builder, this->getType(var->fType), + String(var->fName).c_str()); + LLVMBuildStore(builder, fVariables[var], alloca); + LLVMPositionBuilderAtEnd(builder, fCurrentBlock); + fVariables[var] = alloca; + } + LLVMValueRef ptr = fVariables[var]; + return std::unique_ptr<LValue>(new PointerLValue(ptr)); + } + case Expression::kTernary_Kind: { + class TernaryLValue : public LValue { + public: + TernaryLValue(JIT* jit, LLVMValueRef test, std::unique_ptr<LValue> ifTrue, + std::unique_ptr<LValue> ifFalse) + : fJIT(*jit) + , fTest(test) + , fIfTrue(std::move(ifTrue)) + , fIfFalse(std::move(ifFalse)) {} + + LLVMValueRef load(LLVMBuilderRef builder) override { + LLVMBasicBlockRef trueBlock = LLVMAppendBasicBlockInContext( + fJIT.fContext, + fJIT.fCurrentFunction, + "true ? ..."); + LLVMBasicBlockRef falseBlock = LLVMAppendBasicBlockInContext( + fJIT.fContext, + fJIT.fCurrentFunction, + "false ? ..."); + LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fJIT.fContext, + fJIT.fCurrentFunction, + "ternary merge"); + LLVMBuildCondBr(builder, fTest, trueBlock, falseBlock); + fJIT.setBlock(builder, trueBlock); + LLVMValueRef ifTrue = fIfTrue->load(builder); + LLVMBuildBr(builder, merge); + fJIT.setBlock(builder, falseBlock); + LLVMValueRef ifFalse = fIfTrue->load(builder); + LLVMBuildBr(builder, merge); + fJIT.setBlock(builder, merge); + LLVMTypeRef type = LLVMPointerType(LLVMTypeOf(ifTrue), 0); + LLVMValueRef phi = LLVMBuildPhi(builder, type, "?"); + LLVMValueRef incomingValues[2] = { ifTrue, ifFalse }; + LLVMBasicBlockRef incomingBlocks[2] = { trueBlock, falseBlock }; + LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2); + return phi; + } + + void store(LLVMBuilderRef builder, LLVMValueRef value) override { + LLVMBasicBlockRef trueBlock = LLVMAppendBasicBlockInContext( + fJIT.fContext, + fJIT.fCurrentFunction, + "true ? ..."); + LLVMBasicBlockRef falseBlock = LLVMAppendBasicBlockInContext( + fJIT.fContext, + fJIT.fCurrentFunction, + "false ? ..."); + LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fJIT.fContext, + fJIT.fCurrentFunction, + "ternary merge"); + LLVMBuildCondBr(builder, fTest, trueBlock, falseBlock); + fJIT.setBlock(builder, trueBlock); + fIfTrue->store(builder, value); + LLVMBuildBr(builder, merge); + fJIT.setBlock(builder, falseBlock); + fIfTrue->store(builder, value); + LLVMBuildBr(builder, merge); + fJIT.setBlock(builder, merge); + } + + private: + JIT& fJIT; + LLVMValueRef fTest; + std::unique_ptr<LValue> fIfTrue; + std::unique_ptr<LValue> fIfFalse; + }; + const TernaryExpression& t = (const TernaryExpression&) expr; + LLVMValueRef test = this->compileExpression(builder, *t.fTest); + return std::unique_ptr<LValue>(new TernaryLValue(this, + test, + this->getLValue(builder, + *t.fIfTrue), + this->getLValue(builder, + *t.fIfFalse))); + } + case Expression::kSwizzle_Kind: { + class SwizzleLValue : public LValue { + public: + SwizzleLValue(JIT* jit, LLVMTypeRef type, std::unique_ptr<LValue> base, + std::vector<int> components) + : fJIT(*jit) + , fType(type) + , fBase(std::move(base)) + , fComponents(components) {} + + LLVMValueRef load(LLVMBuilderRef builder) override { + LLVMValueRef base = fBase->load(builder); + if (fComponents.size() > 1) { + LLVMValueRef result = LLVMGetUndef(fType); + for (size_t i = 0; i < fComponents.size(); ++i) { + LLVMValueRef element = LLVMBuildExtractElement( + builder, + base, + LLVMConstInt(fJIT.fInt32Type, + fComponents[i], + false), + "swizzle extract"); + result = LLVMBuildInsertElement(builder, result, element, + LLVMConstInt(fJIT.fInt32Type, i, false), + "swizzle insert"); + } + return result; + } + ASSERT(fComponents.size() == 1); + return LLVMBuildExtractElement(builder, base, + LLVMConstInt(fJIT.fInt32Type, + fComponents[0], + false), + "swizzle extract"); + } + + void store(LLVMBuilderRef builder, LLVMValueRef value) override { + LLVMValueRef result = fBase->load(builder); + if (fComponents.size() > 1) { + for (size_t i = 0; i < fComponents.size(); ++i) { + LLVMValueRef element = LLVMBuildExtractElement(builder, value, + LLVMConstInt( + fJIT.fInt32Type, + i, + false), + "swizzle extract"); + result = LLVMBuildInsertElement(builder, result, element, + LLVMConstInt(fJIT.fInt32Type, + fComponents[i], + false), + "swizzle insert"); + } + } else { + result = LLVMBuildInsertElement(builder, result, value, + LLVMConstInt(fJIT.fInt32Type, + fComponents[0], + false), + "swizzle insert"); + } + fBase->store(builder, result); + } + + private: + JIT& fJIT; + LLVMTypeRef fType; + std::unique_ptr<LValue> fBase; + std::vector<int> fComponents; + }; + const Swizzle& s = (const Swizzle&) expr; + return std::unique_ptr<LValue>(new SwizzleLValue(this, this->getType(s.fType), + this->getLValue(builder, *s.fBase), + s.fComponents)); + } + default: + ABORT("unsupported lvalue"); + } +} + +JIT::TypeKind JIT::typeKind(const Type& type) { + if (type.kind() == Type::kVector_Kind) { + return this->typeKind(type.componentType()); + } + if (type.fName == "int" || type.fName == "short") { + return JIT::kInt_TypeKind; + } else if (type.fName == "uint" || type.fName == "ushort") { + return JIT::kUInt_TypeKind; + } else if (type.fName == "float" || type.fName == "double") { + return JIT::kFloat_TypeKind; + } + ABORT("unsupported type: %s\n", type.description().c_str()); +} + +void JIT::vectorize(LLVMBuilderRef builder, LLVMValueRef* value, int columns) { + LLVMValueRef result = LLVMGetUndef(LLVMVectorType(LLVMTypeOf(*value), columns)); + for (int i = 0; i < columns; ++i) { + result = LLVMBuildInsertElement(builder, + result, + *value, + LLVMConstInt(fInt32Type, i, false), + "vectorize"); + } + *value = result; +} + +void JIT::vectorize(LLVMBuilderRef builder, const BinaryExpression& b, LLVMValueRef* left, + LLVMValueRef* right) { + if (b.fLeft->fType.kind() == Type::kScalar_Kind && + b.fRight->fType.kind() == Type::kVector_Kind) { + this->vectorize(builder, left, b.fRight->fType.columns()); + } else if (b.fLeft->fType.kind() == Type::kVector_Kind && + b.fRight->fType.kind() == Type::kScalar_Kind) { + this->vectorize(builder, right, b.fLeft->fType.columns()); + } +} + + +LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression& b) { + #define BINARY(SFunc, UFunc, FFunc) { \ + LLVMValueRef left = this->compileExpression(builder, *b.fLeft); \ + LLVMValueRef right = this->compileExpression(builder, *b.fRight); \ + this->vectorize(builder, b, &left, &right); \ + switch (this->typeKind(b.fLeft->fType)) { \ + case kInt_TypeKind: \ + return SFunc(builder, left, right, "binary"); \ + case kUInt_TypeKind: \ + return UFunc(builder, left, right, "binary"); \ + case kFloat_TypeKind: \ + return FFunc(builder, left, right, "binary"); \ + default: \ + ABORT("unsupported typeKind"); \ + } \ + } + #define COMPOUND(SFunc, UFunc, FFunc) { \ + std::unique_ptr<LValue> lvalue = this->getLValue(builder, *b.fLeft); \ + LLVMValueRef left = lvalue->load(builder); \ + LLVMValueRef right = this->compileExpression(builder, *b.fRight); \ + this->vectorize(builder, b, &left, &right); \ + LLVMValueRef result; \ + switch (this->typeKind(b.fLeft->fType)) { \ + case kInt_TypeKind: \ + result = SFunc(builder, left, right, "binary"); \ + break; \ + case kUInt_TypeKind: \ + result = UFunc(builder, left, right, "binary"); \ + break; \ + case kFloat_TypeKind: \ + result = FFunc(builder, left, right, "binary"); \ + break; \ + default: \ + ABORT("unsupported typeKind"); \ + } \ + lvalue->store(builder, result); \ + return result; \ + } + #define COMPARE(SFunc, SOp, UFunc, UOp, FFunc, FOp) { \ + LLVMValueRef left = this->compileExpression(builder, *b.fLeft); \ + LLVMValueRef right = this->compileExpression(builder, *b.fRight); \ + this->vectorize(builder, b, &left, &right); \ + switch (this->typeKind(b.fLeft->fType)) { \ + case kInt_TypeKind: \ + return SFunc(builder, SOp, left, right, "binary"); \ + case kUInt_TypeKind: \ + return UFunc(builder, UOp, left, right, "binary"); \ + case kFloat_TypeKind: \ + return FFunc(builder, FOp, left, right, "binary"); \ + default: \ + ABORT("unsupported typeKind"); \ + } \ + } + switch (b.fOperator) { + case Token::EQ: { + std::unique_ptr<LValue> lvalue = this->getLValue(builder, *b.fLeft); + LLVMValueRef result = this->compileExpression(builder, *b.fRight); + lvalue->store(builder, result); + return result; + } + case Token::PLUS: + BINARY(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd); + case Token::MINUS: + BINARY(LLVMBuildSub, LLVMBuildSub, LLVMBuildFSub); + case Token::STAR: + BINARY(LLVMBuildMul, LLVMBuildMul, LLVMBuildFMul); + case Token::SLASH: + BINARY(LLVMBuildSDiv, LLVMBuildUDiv, LLVMBuildFDiv); + case Token::PERCENT: + BINARY(LLVMBuildSRem, LLVMBuildURem, LLVMBuildSRem); + case Token::BITWISEAND: + BINARY(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd); + case Token::BITWISEOR: + BINARY(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr); + case Token::PLUSEQ: + COMPOUND(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd); + case Token::MINUSEQ: + COMPOUND(LLVMBuildSub, LLVMBuildSub, LLVMBuildFSub); + case Token::STAREQ: + COMPOUND(LLVMBuildMul, LLVMBuildMul, LLVMBuildFMul); + case Token::SLASHEQ: + COMPOUND(LLVMBuildSDiv, LLVMBuildUDiv, LLVMBuildFDiv); + case Token::BITWISEANDEQ: + COMPOUND(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd); + case Token::BITWISEOREQ: + COMPOUND(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr); + case Token::EQEQ: + COMPARE(LLVMBuildICmp, LLVMIntEQ, + LLVMBuildICmp, LLVMIntEQ, + LLVMBuildFCmp, LLVMRealOEQ); + case Token::NEQ: + COMPARE(LLVMBuildICmp, LLVMIntNE, + LLVMBuildICmp, LLVMIntNE, + LLVMBuildFCmp, LLVMRealONE); + case Token::LT: + COMPARE(LLVMBuildICmp, LLVMIntSLT, + LLVMBuildICmp, LLVMIntULT, + LLVMBuildFCmp, LLVMRealOLT); + case Token::LTEQ: + COMPARE(LLVMBuildICmp, LLVMIntSLE, + LLVMBuildICmp, LLVMIntULE, + LLVMBuildFCmp, LLVMRealOLE); + case Token::GT: + COMPARE(LLVMBuildICmp, LLVMIntSGT, + LLVMBuildICmp, LLVMIntUGT, + LLVMBuildFCmp, LLVMRealOGT); + case Token::GTEQ: + COMPARE(LLVMBuildICmp, LLVMIntSGE, + LLVMBuildICmp, LLVMIntUGE, + LLVMBuildFCmp, LLVMRealOGE); + case Token::LOGICALAND: { + LLVMValueRef left = this->compileExpression(builder, *b.fLeft); + LLVMBasicBlockRef ifFalse = fCurrentBlock; + LLVMBasicBlockRef ifTrue = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, + "true && ..."); + LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, + "&& merge"); + LLVMBuildCondBr(builder, left, ifTrue, merge); + this->setBlock(builder, ifTrue); + LLVMValueRef right = this->compileExpression(builder, *b.fRight); + LLVMBuildBr(builder, merge); + this->setBlock(builder, merge); + LLVMValueRef phi = LLVMBuildPhi(builder, fInt1Type, "&&"); + LLVMValueRef incomingValues[2] = { right, LLVMConstInt(fInt1Type, 0, false) }; + LLVMBasicBlockRef incomingBlocks[2] = { ifTrue, ifFalse }; + LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2); + return phi; + } + case Token::LOGICALOR: { + LLVMValueRef left = this->compileExpression(builder, *b.fLeft); + LLVMBasicBlockRef ifTrue = fCurrentBlock; + LLVMBasicBlockRef ifFalse = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, + "false || ..."); + LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, + "|| merge"); + LLVMBuildCondBr(builder, left, merge, ifFalse); + this->setBlock(builder, ifFalse); + LLVMValueRef right = this->compileExpression(builder, *b.fRight); + LLVMBuildBr(builder, merge); + this->setBlock(builder, merge); + LLVMValueRef phi = LLVMBuildPhi(builder, fInt1Type, "||"); + LLVMValueRef incomingValues[2] = { right, LLVMConstInt(fInt1Type, 1, false) }; + LLVMBasicBlockRef incomingBlocks[2] = { ifFalse, ifTrue }; + LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2); + return phi; + } + default: + ABORT("unsupported binary operator"); + } +} + +LLVMValueRef JIT::compileIndex(LLVMBuilderRef builder, const IndexExpression& idx) { + LLVMValueRef base = this->compileExpression(builder, *idx.fBase); + LLVMValueRef index = this->compileExpression(builder, *idx.fIndex); + LLVMValueRef ptr = LLVMBuildGEP(builder, base, &index, 1, "index ptr"); + return LLVMBuildLoad(builder, ptr, "index load"); +} + +LLVMValueRef JIT::compilePostfix(LLVMBuilderRef builder, const PostfixExpression& p) { + std::unique_ptr<LValue> lvalue = this->getLValue(builder, *p.fOperand); + LLVMValueRef result = lvalue->load(builder); + LLVMValueRef mod; + LLVMValueRef one = LLVMConstInt(this->getType(p.fType), 1, false); + switch (p.fOperator) { + case Token::PLUSPLUS: + switch (this->typeKind(p.fType)) { + case kInt_TypeKind: // fall through + case kUInt_TypeKind: + mod = LLVMBuildAdd(builder, result, one, "++"); + break; + case kFloat_TypeKind: + mod = LLVMBuildFAdd(builder, result, one, "++"); + break; + default: + ABORT("unsupported typeKind"); + } + break; + case Token::MINUSMINUS: + switch (this->typeKind(p.fType)) { + case kInt_TypeKind: // fall through + case kUInt_TypeKind: + mod = LLVMBuildSub(builder, result, one, "--"); + break; + case kFloat_TypeKind: + mod = LLVMBuildFSub(builder, result, one, "--"); + break; + default: + ABORT("unsupported typeKind"); + } + break; + default: + ABORT("unsupported postfix op"); + } + lvalue->store(builder, mod); + return result; +} + +LLVMValueRef JIT::compilePrefix(LLVMBuilderRef builder, const PrefixExpression& p) { + LLVMValueRef one = LLVMConstInt(this->getType(p.fType), 1, false); + if (Token::LOGICALNOT == p.fOperator) { + LLVMValueRef base = this->compileExpression(builder, *p.fOperand); + return LLVMBuildXor(builder, base, one, "!"); + } + if (Token::MINUS == p.fOperator) { + LLVMValueRef base = this->compileExpression(builder, *p.fOperand); + return LLVMBuildSub(builder, LLVMConstInt(this->getType(p.fType), 0, false), base, "-"); + } + std::unique_ptr<LValue> lvalue = this->getLValue(builder, *p.fOperand); + LLVMValueRef raw = lvalue->load(builder); + LLVMValueRef result; + switch (p.fOperator) { + case Token::PLUSPLUS: + switch (this->typeKind(p.fType)) { + case kInt_TypeKind: // fall through + case kUInt_TypeKind: + result = LLVMBuildAdd(builder, raw, one, "++"); + break; + case kFloat_TypeKind: + result = LLVMBuildFAdd(builder, raw, one, "++"); + break; + default: + ABORT("unsupported typeKind"); + } + break; + case Token::MINUSMINUS: + switch (this->typeKind(p.fType)) { + case kInt_TypeKind: // fall through + case kUInt_TypeKind: + result = LLVMBuildSub(builder, raw, one, "--"); + break; + case kFloat_TypeKind: + result = LLVMBuildFSub(builder, raw, one, "--"); + break; + default: + ABORT("unsupported typeKind"); + } + break; + default: + ABORT("unsupported prefix op"); + } + lvalue->store(builder, result); + return result; +} + +LLVMValueRef JIT::compileVariableReference(LLVMBuilderRef builder, const VariableReference& v) { + const Variable& var = v.fVariable; + if (Variable::kParameter_Storage == var.fStorage && + !(var.fModifiers.fFlags & Modifiers::kOut_Flag) && + fPromotedParameters.find(&var) == fPromotedParameters.end()) { + return fVariables[&var]; + } + return LLVMBuildLoad(builder, fVariables[&var], String(var.fName).c_str()); +} + +void JIT::appendStage(LLVMBuilderRef builder, const AppendStage& a) { + ASSERT(a.fArguments.size() >= 1); + ASSERT(a.fArguments[0]->fType == *fCompiler.context().fSkRasterPipeline_Type); + LLVMValueRef pipeline = this->compileExpression(builder, *a.fArguments[0]); + LLVMValueRef stage = LLVMConstInt(fInt32Type, a.fStage, 0); + switch (a.fStage) { + case SkRasterPipeline::callback: { + ASSERT(a.fArguments.size() == 2); + ASSERT(a.fArguments[1]->fKind == Expression::kFunctionReference_Kind); + const FunctionDeclaration& functionDecl = + *((FunctionReference&) *a.fArguments[1]).fFunctions[0]; + bool found = false; + for (const auto& pe : fProgram->fElements) { + if (ProgramElement::kFunction_Kind == pe->fKind) { + const FunctionDefinition& def = (const FunctionDefinition&) *pe; + if (&def.fDeclaration == &functionDecl) { + LLVMValueRef fn = this->compileStageFunction(def); + LLVMValueRef args[2] = { + pipeline, + LLVMBuildBitCast(builder, fn, fInt8PtrType, "callback cast") + }; + LLVMBuildCall(builder, fAppendCallbackFunc, args, 2, ""); + found = true; + break; + } + } + } + ASSERT(found); + break; + } + default: { + LLVMValueRef ctx; + if (a.fArguments.size() == 2) { + ctx = this->compileExpression(builder, *a.fArguments[1]); + ctx = LLVMBuildBitCast(builder, ctx, fInt8PtrType, "context cast"); + } else { + ASSERT(a.fArguments.size() == 1); + ctx = LLVMConstNull(fInt8PtrType); + } + LLVMValueRef args[3] = { + pipeline, + stage, + ctx + }; + LLVMBuildCall(builder, fAppendFunc, args, 3, ""); + break; + } + } +} + +LLVMValueRef JIT::compileConstructor(LLVMBuilderRef builder, const Constructor& c) { + switch (c.fType.kind()) { + case Type::kScalar_Kind: { + ASSERT(c.fArguments.size() == 1); + TypeKind from = this->typeKind(c.fArguments[0]->fType); + TypeKind to = this->typeKind(c.fType); + LLVMValueRef base = this->compileExpression(builder, *c.fArguments[0]); + if (kFloat_TypeKind == to) { + if (kInt_TypeKind == from) { + return LLVMBuildSIToFP(builder, base, this->getType(c.fType), "cast"); + } + if (kUInt_TypeKind == from) { + return LLVMBuildUIToFP(builder, base, this->getType(c.fType), "cast"); + } + } + if (kInt_TypeKind == to) { + if (kFloat_TypeKind == from) { + return LLVMBuildFPToSI(builder, base, this->getType(c.fType), "cast"); + } + if (kUInt_TypeKind == from) { + return base; + } + } + if (kUInt_TypeKind == to) { + if (kFloat_TypeKind == from) { + return LLVMBuildFPToUI(builder, base, this->getType(c.fType), "cast"); + } + if (kInt_TypeKind == from) { + return base; + } + } + ABORT("unsupported constructor"); + } + case Type::kVector_Kind: { + LLVMValueRef vec = LLVMGetUndef(this->getType(c.fType)); + if (c.fArguments.size() == 1) { + LLVMValueRef value = this->compileExpression(builder, *c.fArguments[0]); + for (int i = 0; i < c.fType.columns(); ++i) { + vec = LLVMBuildInsertElement(builder, vec, value, + LLVMConstInt(fInt32Type, i, false), + "vec build"); + } + } else { + ASSERT(c.fArguments.size() == (size_t) c.fType.columns()); + for (int i = 0; i < c.fType.columns(); ++i) { + vec = LLVMBuildInsertElement(builder, vec, + this->compileExpression(builder, + *c.fArguments[i]), + LLVMConstInt(fInt32Type, i, false), + "vec build"); + } + } + return vec; + } + default: + break; + } + ABORT("unsupported constructor"); +} + +LLVMValueRef JIT::compileSwizzle(LLVMBuilderRef builder, const Swizzle& s) { + LLVMValueRef base = this->compileExpression(builder, *s.fBase); + if (s.fComponents.size() > 1) { + LLVMValueRef result = LLVMGetUndef(this->getType(s.fType)); + for (size_t i = 0; i < s.fComponents.size(); ++i) { + LLVMValueRef element = LLVMBuildExtractElement( + builder, + base, + LLVMConstInt(fInt32Type, + s.fComponents[i], + false), + "swizzle extract"); + result = LLVMBuildInsertElement(builder, result, element, + LLVMConstInt(fInt32Type, i, false), + "swizzle insert"); + } + return result; + } + ASSERT(s.fComponents.size() == 1); + return LLVMBuildExtractElement(builder, base, + LLVMConstInt(fInt32Type, + s.fComponents[0], + false), + "swizzle extract"); +} + +LLVMValueRef JIT::compileTernary(LLVMBuilderRef builder, const TernaryExpression& t) { + LLVMValueRef test = this->compileExpression(builder, *t.fTest); + LLVMBasicBlockRef trueBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, + "if true"); + LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, + "if merge"); + LLVMBasicBlockRef falseBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, + "if false"); + LLVMBuildCondBr(builder, test, trueBlock, falseBlock); + this->setBlock(builder, trueBlock); + LLVMValueRef ifTrue = this->compileExpression(builder, *t.fIfTrue); + trueBlock = fCurrentBlock; + LLVMBuildBr(builder, merge); + this->setBlock(builder, falseBlock); + LLVMValueRef ifFalse = this->compileExpression(builder, *t.fIfFalse); + falseBlock = fCurrentBlock; + LLVMBuildBr(builder, merge); + this->setBlock(builder, merge); + LLVMValueRef phi = LLVMBuildPhi(builder, this->getType(t.fType), "?"); + LLVMValueRef incomingValues[2] = { ifTrue, ifFalse }; + LLVMBasicBlockRef incomingBlocks[2] = { trueBlock, falseBlock }; + LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2); + return phi; +} + +LLVMValueRef JIT::compileExpression(LLVMBuilderRef builder, const Expression& expr) { + switch (expr.fKind) { + case Expression::kAppendStage_Kind: { + this->appendStage(builder, (const AppendStage&) expr); + return LLVMValueRef(); + } + case Expression::kBinary_Kind: + return this->compileBinary(builder, (BinaryExpression&) expr); + case Expression::kBoolLiteral_Kind: + return LLVMConstInt(fInt1Type, ((BoolLiteral&) expr).fValue, false); + case Expression::kConstructor_Kind: + return this->compileConstructor(builder, (Constructor&) expr); + case Expression::kIntLiteral_Kind: + return LLVMConstInt(this->getType(expr.fType), ((IntLiteral&) expr).fValue, true); + case Expression::kFieldAccess_Kind: + abort(); + case Expression::kFloatLiteral_Kind: + return LLVMConstReal(this->getType(expr.fType), ((FloatLiteral&) expr).fValue); + case Expression::kFunctionCall_Kind: + return this->compileFunctionCall(builder, (FunctionCall&) expr); + case Expression::kIndex_Kind: + return this->compileIndex(builder, (IndexExpression&) expr); + case Expression::kPrefix_Kind: + return this->compilePrefix(builder, (PrefixExpression&) expr); + case Expression::kPostfix_Kind: + return this->compilePostfix(builder, (PostfixExpression&) expr); + case Expression::kSetting_Kind: + abort(); + case Expression::kSwizzle_Kind: + return this->compileSwizzle(builder, (Swizzle&) expr); + case Expression::kVariableReference_Kind: + return this->compileVariableReference(builder, (VariableReference&) expr); + case Expression::kTernary_Kind: + return this->compileTernary(builder, (TernaryExpression&) expr); + case Expression::kTypeReference_Kind: + abort(); + default: + abort(); + } + ABORT("unsupported expression: %s\n", expr.description().c_str()); +} + +void JIT::compileBlock(LLVMBuilderRef builder, const Block& block) { + for (const auto& stmt : block.fStatements) { + this->compileStatement(builder, *stmt); + } +} + +void JIT::compileVarDeclarations(LLVMBuilderRef builder, const VarDeclarationsStatement& decls) { + for (const auto& declStatement : decls.fDeclaration->fVars) { + const VarDeclaration& decl = (VarDeclaration&) *declStatement; + LLVMPositionBuilderAtEnd(builder, fAllocaBlock); + LLVMValueRef alloca = LLVMBuildAlloca(builder, this->getType(decl.fVar->fType), + String(decl.fVar->fName).c_str()); + fVariables[decl.fVar] = alloca; + LLVMPositionBuilderAtEnd(builder, fCurrentBlock); + if (decl.fValue) { + LLVMValueRef result = this->compileExpression(builder, *decl.fValue); + LLVMBuildStore(builder, result, alloca); + } + } +} + +void JIT::compileIf(LLVMBuilderRef builder, const IfStatement& i) { + LLVMValueRef test = this->compileExpression(builder, *i.fTest); + LLVMBasicBlockRef ifTrue = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "if true"); + LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, + "if merge"); + LLVMBasicBlockRef ifFalse; + if (i.fIfFalse) { + ifFalse = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "if false"); + } else { + ifFalse = merge; + } + LLVMBuildCondBr(builder, test, ifTrue, ifFalse); + this->setBlock(builder, ifTrue); + this->compileStatement(builder, *i.fIfTrue); + if (!ends_with_branch(*i.fIfTrue)) { + LLVMBuildBr(builder, merge); + } + if (i.fIfFalse) { + this->setBlock(builder, ifFalse); + this->compileStatement(builder, *i.fIfFalse); + if (!ends_with_branch(*i.fIfFalse)) { + LLVMBuildBr(builder, merge); + } + } + this->setBlock(builder, merge); +} + +void JIT::compileFor(LLVMBuilderRef builder, const ForStatement& f) { + if (f.fInitializer) { + this->compileStatement(builder, *f.fInitializer); + } + LLVMBasicBlockRef start; + LLVMBasicBlockRef body = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "for body"); + LLVMBasicBlockRef next = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "for next"); + LLVMBasicBlockRef end = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "for end"); + if (f.fTest) { + start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "for test"); + LLVMBuildBr(builder, start); + this->setBlock(builder, start); + LLVMValueRef test = this->compileExpression(builder, *f.fTest); + LLVMBuildCondBr(builder, test, body, end); + } else { + start = body; + LLVMBuildBr(builder, body); + } + this->setBlock(builder, body); + fBreakTarget.push_back(end); + fContinueTarget.push_back(next); + this->compileStatement(builder, *f.fStatement); + fBreakTarget.pop_back(); + fContinueTarget.pop_back(); + if (!ends_with_branch(*f.fStatement)) { + LLVMBuildBr(builder, next); + } + this->setBlock(builder, next); + if (f.fNext) { + this->compileExpression(builder, *f.fNext); + } + LLVMBuildBr(builder, start); + this->setBlock(builder, end); +} + +void JIT::compileDo(LLVMBuilderRef builder, const DoStatement& d) { + LLVMBasicBlockRef testBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, + "do test"); + LLVMBasicBlockRef body = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, + "do body"); + LLVMBasicBlockRef end = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, + "do end"); + LLVMBuildBr(builder, body); + this->setBlock(builder, testBlock); + LLVMValueRef test = this->compileExpression(builder, *d.fTest); + LLVMBuildCondBr(builder, test, body, end); + this->setBlock(builder, body); + fBreakTarget.push_back(end); + fContinueTarget.push_back(body); + this->compileStatement(builder, *d.fStatement); + fBreakTarget.pop_back(); + fContinueTarget.pop_back(); + if (!ends_with_branch(*d.fStatement)) { + LLVMBuildBr(builder, testBlock); + } + this->setBlock(builder, end); +} + +void JIT::compileWhile(LLVMBuilderRef builder, const WhileStatement& w) { + LLVMBasicBlockRef testBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, + "while test"); + LLVMBasicBlockRef body = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, + "while body"); + LLVMBasicBlockRef end = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, + "while end"); + LLVMBuildBr(builder, testBlock); + this->setBlock(builder, testBlock); + LLVMValueRef test = this->compileExpression(builder, *w.fTest); + LLVMBuildCondBr(builder, test, body, end); + this->setBlock(builder, body); + fBreakTarget.push_back(end); + fContinueTarget.push_back(testBlock); + this->compileStatement(builder, *w.fStatement); + fBreakTarget.pop_back(); + fContinueTarget.pop_back(); + if (!ends_with_branch(*w.fStatement)) { + LLVMBuildBr(builder, testBlock); + } + this->setBlock(builder, end); +} + +void JIT::compileBreak(LLVMBuilderRef builder, const BreakStatement& b) { + LLVMBuildBr(builder, fBreakTarget.back()); +} + +void JIT::compileContinue(LLVMBuilderRef builder, const ContinueStatement& b) { + LLVMBuildBr(builder, fContinueTarget.back()); +} + +void JIT::compileReturn(LLVMBuilderRef builder, const ReturnStatement& r) { + if (r.fExpression) { + LLVMBuildRet(builder, this->compileExpression(builder, *r.fExpression)); + } else { + LLVMBuildRetVoid(builder); + } +} + +void JIT::compileStatement(LLVMBuilderRef builder, const Statement& stmt) { + switch (stmt.fKind) { + case Statement::kBlock_Kind: + this->compileBlock(builder, (Block&) stmt); + break; + case Statement::kBreak_Kind: + this->compileBreak(builder, (BreakStatement&) stmt); + break; + case Statement::kContinue_Kind: + this->compileContinue(builder, (ContinueStatement&) stmt); + break; + case Statement::kDiscard_Kind: + abort(); + case Statement::kDo_Kind: + this->compileDo(builder, (DoStatement&) stmt); + break; + case Statement::kExpression_Kind: + this->compileExpression(builder, *((ExpressionStatement&) stmt).fExpression); + break; + case Statement::kFor_Kind: + this->compileFor(builder, (ForStatement&) stmt); + break; + case Statement::kGroup_Kind: + abort(); + case Statement::kIf_Kind: + this->compileIf(builder, (IfStatement&) stmt); + break; + case Statement::kNop_Kind: + break; + case Statement::kReturn_Kind: + this->compileReturn(builder, (ReturnStatement&) stmt); + break; + case Statement::kSwitch_Kind: + abort(); + case Statement::kVarDeclarations_Kind: + this->compileVarDeclarations(builder, (VarDeclarationsStatement&) stmt); + break; + case Statement::kWhile_Kind: + this->compileWhile(builder, (WhileStatement&) stmt); + break; + default: + abort(); + } +} + +void JIT::compileStageFunctionLoop(const FunctionDefinition& f, LLVMValueRef newFunc) { + // loop over fVectorCount pixels, running the body of the stage function for each of them + LLVMValueRef oldFunction = fCurrentFunction; + fCurrentFunction = newFunc; + std::unique_ptr<LLVMValueRef[]> params(new LLVMValueRef[STAGE_PARAM_COUNT]); + LLVMGetParams(fCurrentFunction, params.get()); + LLVMValueRef programParam = params.get()[1]; + LLVMBuilderRef builder = LLVMCreateBuilderInContext(fContext); + LLVMBasicBlockRef oldAllocaBlock = fAllocaBlock; + LLVMBasicBlockRef oldCurrentBlock = fCurrentBlock; + fAllocaBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "alloca"); + this->setBlock(builder, fAllocaBlock); + // temporaries to store the color channel vectors + LLVMValueRef rVec = LLVMBuildAlloca(builder, fFloat32VectorType, "rVec"); + LLVMBuildStore(builder, params.get()[4], rVec); + LLVMValueRef gVec = LLVMBuildAlloca(builder, fFloat32VectorType, "gVec"); + LLVMBuildStore(builder, params.get()[5], gVec); + LLVMValueRef bVec = LLVMBuildAlloca(builder, fFloat32VectorType, "bVec"); + LLVMBuildStore(builder, params.get()[6], bVec); + LLVMValueRef aVec = LLVMBuildAlloca(builder, fFloat32VectorType, "aVec"); + LLVMBuildStore(builder, params.get()[7], aVec); + LLVMValueRef color = LLVMBuildAlloca(builder, fFloat32Vector4Type, "color"); + fVariables[f.fDeclaration.fParameters[1]] = LLVMBuildTrunc(builder, params.get()[3], fInt32Type, + "y->Int32"); + fVariables[f.fDeclaration.fParameters[2]] = color; + LLVMValueRef ivar = LLVMBuildAlloca(builder, fInt32Type, "i"); + LLVMBuildStore(builder, LLVMConstInt(fInt32Type, 0, false), ivar); + LLVMBasicBlockRef start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "start"); + this->setBlock(builder, start); + LLVMValueRef iload = LLVMBuildLoad(builder, ivar, "load i"); + fVariables[f.fDeclaration.fParameters[0]] = LLVMBuildAdd(builder, + LLVMBuildTrunc(builder, + params.get()[2], + fInt32Type, + "x->Int32"), + iload, + "x"); + LLVMValueRef vectorSize = LLVMConstInt(fInt32Type, fVectorCount, false); + LLVMValueRef test = LLVMBuildICmp(builder, LLVMIntSLT, iload, vectorSize, "i < vectorSize"); + LLVMBasicBlockRef loopBody = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "body"); + LLVMBasicBlockRef loopEnd = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "end"); + LLVMBuildCondBr(builder, test, loopBody, loopEnd); + this->setBlock(builder, loopBody); + LLVMValueRef vec = LLVMGetUndef(fFloat32Vector4Type); + // extract the r, g, b, and a values from the color channel vectors and store them into "color" + for (int i = 0; i < 4; ++i) { + vec = LLVMBuildInsertElement(builder, vec, + LLVMBuildExtractElement(builder, + params.get()[4 + i], + iload, "initial"), + LLVMConstInt(fInt32Type, i, false), + "vec build"); + } + LLVMBuildStore(builder, vec, color); + // write actual loop body + this->compileStatement(builder, *f.fBody); + // extract the r, g, b, and a values from "color" and stick them back into the color channel + // vectors + LLVMValueRef colorLoad = LLVMBuildLoad(builder, color, "color load"); + LLVMBuildStore(builder, + LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, rVec, "rVec"), + LLVMBuildExtractElement(builder, colorLoad, + LLVMConstInt(fInt32Type, 0, + false), + "rExtract"), + iload, "rInsert"), + rVec); + LLVMBuildStore(builder, + LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, gVec, "gVec"), + LLVMBuildExtractElement(builder, colorLoad, + LLVMConstInt(fInt32Type, 1, + false), + "gExtract"), + iload, "gInsert"), + gVec); + LLVMBuildStore(builder, + LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, bVec, "bVec"), + LLVMBuildExtractElement(builder, colorLoad, + LLVMConstInt(fInt32Type, 2, + false), + "bExtract"), + iload, "bInsert"), + bVec); + LLVMBuildStore(builder, + LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, aVec, "aVec"), + LLVMBuildExtractElement(builder, colorLoad, + LLVMConstInt(fInt32Type, 3, + false), + "aExtract"), + iload, "aInsert"), + aVec); + LLVMValueRef inc = LLVMBuildAdd(builder, iload, LLVMConstInt(fInt32Type, 1, false), "inc i"); + LLVMBuildStore(builder, inc, ivar); + LLVMBuildBr(builder, start); + this->setBlock(builder, loopEnd); + // increment program pointer, call the next stage + LLVMValueRef rawNextPtr = LLVMBuildLoad(builder, programParam, "next load"); + LLVMTypeRef stageFuncType = LLVMTypeOf(newFunc); + LLVMValueRef nextPtr = LLVMBuildBitCast(builder, rawNextPtr, stageFuncType, "cast next->func"); + LLVMValueRef nextInc = LLVMBuildIntToPtr(builder, + LLVMBuildAdd(builder, + LLVMBuildPtrToInt(builder, + programParam, + fInt64Type, + "cast 1"), + LLVMConstInt(fInt64Type, PTR_SIZE, false), + "add"), + LLVMPointerType(fInt8PtrType, 0), "cast 2"); + LLVMValueRef args[STAGE_PARAM_COUNT] = { + params.get()[0], + nextInc, + params.get()[2], + params.get()[3], + LLVMBuildLoad(builder, rVec, "rVec"), + LLVMBuildLoad(builder, gVec, "gVec"), + LLVMBuildLoad(builder, bVec, "bVec"), + LLVMBuildLoad(builder, aVec, "aVec"), + params.get()[8], + params.get()[9], + params.get()[10], + params.get()[11] + }; + LLVMBuildCall(builder, nextPtr, args, STAGE_PARAM_COUNT, ""); + LLVMBuildRetVoid(builder); + // finish + LLVMPositionBuilderAtEnd(builder, fAllocaBlock); + LLVMBuildBr(builder, start); + LLVMDisposeBuilder(builder); + if (LLVMVerifyFunction(fCurrentFunction, LLVMPrintMessageAction)) { + ABORT("verify failed\n"); + } + fAllocaBlock = oldAllocaBlock; + fCurrentBlock = oldCurrentBlock; + fCurrentFunction = oldFunction; +} + +// FIXME maybe pluggable code generators? Need to do something to separate all +// of the normal codegen from the vector codegen and break this up into multiple +// classes. + +bool JIT::getVectorLValue(LLVMBuilderRef builder, const Expression& e, + LLVMValueRef out[CHANNELS]) { + switch (e.fKind) { + case Expression::kVariableReference_Kind: + if (fColorParam == &((VariableReference&) e).fVariable) { + memcpy(out, fChannels, sizeof(fChannels)); + return true; + } + return false; + case Expression::kSwizzle_Kind: { + const Swizzle& s = (const Swizzle&) e; + LLVMValueRef base[CHANNELS]; + if (!this->getVectorLValue(builder, *s.fBase, base)) { + return false; + } + for (size_t i = 0; i < s.fComponents.size(); ++i) { + out[i] = base[s.fComponents[i]]; + } + return true; + } + default: + return false; + } +} + +bool JIT::getVectorBinaryOperands(LLVMBuilderRef builder, const Expression& left, + LLVMValueRef outLeft[CHANNELS], const Expression& right, + LLVMValueRef outRight[CHANNELS]) { + if (!this->compileVectorExpression(builder, left, outLeft)) { + return false; + } + int leftColumns = left.fType.columns(); + int rightColumns = right.fType.columns(); + if (leftColumns == 1 && rightColumns > 1) { + for (int i = 1; i < rightColumns; ++i) { + outLeft[i] = outLeft[0]; + } + } + if (!this->compileVectorExpression(builder, right, outRight)) { + return false; + } + if (rightColumns == 1 && leftColumns > 1) { + for (int i = 1; i < leftColumns; ++i) { + outRight[i] = outRight[0]; + } + } + return true; +} + +bool JIT::compileVectorBinary(LLVMBuilderRef builder, const BinaryExpression& b, + LLVMValueRef out[CHANNELS]) { + LLVMValueRef left[CHANNELS]; + LLVMValueRef right[CHANNELS]; + #define VECTOR_BINARY(signedOp, unsignedOp, floatOp) { \ + if (!this->getVectorBinaryOperands(builder, *b.fLeft, left, *b.fRight, right)) { \ + return false; \ + } \ + for (int i = 0; i < b.fLeft->fType.columns(); ++i) { \ + switch (this->typeKind(b.fLeft->fType)) { \ + case kInt_TypeKind: \ + out[i] = signedOp(builder, left[i], right[i], "binary"); \ + break; \ + case kUInt_TypeKind: \ + out[i] = unsignedOp(builder, left[i], right[i], "binary"); \ + break; \ + case kFloat_TypeKind: \ + out[i] = floatOp(builder, left[i], right[i], "binary"); \ + break; \ + case kBool_TypeKind: \ + ASSERT(false); \ + break; \ + } \ + } \ + return true; \ + } + switch (b.fOperator) { + case Token::EQ: { + if (!this->getVectorLValue(builder, *b.fLeft, left)) { + return false; + } + if (!this->compileVectorExpression(builder, *b.fRight, right)) { + return false; + } + int columns = b.fRight->fType.columns(); + for (int i = 0; i < columns; ++i) { + LLVMBuildStore(builder, right[i], left[i]); + } + return true; + } + case Token::PLUS: + VECTOR_BINARY(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd); + case Token::MINUS: + VECTOR_BINARY(LLVMBuildSub, LLVMBuildSub, LLVMBuildFSub); + case Token::STAR: + VECTOR_BINARY(LLVMBuildMul, LLVMBuildMul, LLVMBuildFMul); + case Token::SLASH: + VECTOR_BINARY(LLVMBuildSDiv, LLVMBuildUDiv, LLVMBuildFDiv); + case Token::PERCENT: + VECTOR_BINARY(LLVMBuildSRem, LLVMBuildURem, LLVMBuildSRem); + case Token::BITWISEAND: + VECTOR_BINARY(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd); + case Token::BITWISEOR: + VECTOR_BINARY(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr); + default: + printf("unsupported operator: %s\n", b.description().c_str()); + return false; + } +} + +bool JIT::compileVectorConstructor(LLVMBuilderRef builder, const Constructor& c, + LLVMValueRef out[CHANNELS]) { + switch (c.fType.kind()) { + case Type::kScalar_Kind: { + ASSERT(c.fArguments.size() == 1); + TypeKind from = this->typeKind(c.fArguments[0]->fType); + TypeKind to = this->typeKind(c.fType); + LLVMValueRef base[CHANNELS]; + if (!this->compileVectorExpression(builder, *c.fArguments[0], base)) { + return false; + } + #define CONSTRUCT(fn) \ + out[0] = LLVMGetUndef(LLVMVectorType(this->getType(c.fType), fVectorCount)); \ + for (int i = 0; i < fVectorCount; ++i) { \ + LLVMValueRef index = LLVMConstInt(fInt32Type, i, false); \ + LLVMValueRef baseVal = LLVMBuildExtractElement(builder, base[0], index, \ + "construct extract"); \ + out[0] = LLVMBuildInsertElement(builder, out[0], \ + fn(builder, baseVal, this->getType(c.fType), \ + "cast"), \ + index, "construct insert"); \ + } \ + return true; + if (kFloat_TypeKind == to) { + if (kInt_TypeKind == from) { + CONSTRUCT(LLVMBuildSIToFP); + } + if (kUInt_TypeKind == from) { + CONSTRUCT(LLVMBuildUIToFP); + } + } + if (kInt_TypeKind == to) { + if (kFloat_TypeKind == from) { + CONSTRUCT(LLVMBuildFPToSI); + } + if (kUInt_TypeKind == from) { + return true; + } + } + if (kUInt_TypeKind == to) { + if (kFloat_TypeKind == from) { + CONSTRUCT(LLVMBuildFPToUI); + } + if (kInt_TypeKind == from) { + return base; + } + } + printf("%s\n", c.description().c_str()); + ABORT("unsupported constructor"); + } + case Type::kVector_Kind: { + if (c.fArguments.size() == 1) { + LLVMValueRef base[CHANNELS]; + if (!this->compileVectorExpression(builder, *c.fArguments[0], base)) { + return false; + } + for (int i = 0; i < c.fType.columns(); ++i) { + out[i] = base[0]; + } + } else { + ASSERT(c.fArguments.size() == (size_t) c.fType.columns()); + for (int i = 0; i < c.fType.columns(); ++i) { + LLVMValueRef base[CHANNELS]; + if (!this->compileVectorExpression(builder, *c.fArguments[i], base)) { + return false; + } + out[i] = base[0]; + } + } + return true; + } + default: + break; + } + ABORT("unsupported constructor"); +} + +bool JIT::compileVectorFloatLiteral(LLVMBuilderRef builder, + const FloatLiteral& f, + LLVMValueRef out[CHANNELS]) { + LLVMValueRef value = LLVMConstReal(this->getType(f.fType), f.fValue); + LLVMValueRef values[MAX_VECTOR_COUNT]; + for (int i = 0; i < fVectorCount; ++i) { + values[i] = value; + } + out[0] = LLVMConstVector(values, fVectorCount); + return true; +} + + +bool JIT::compileVectorSwizzle(LLVMBuilderRef builder, const Swizzle& s, + LLVMValueRef out[CHANNELS]) { + LLVMValueRef all[CHANNELS]; + if (!this->compileVectorExpression(builder, *s.fBase, all)) { + return false; + } + for (size_t i = 0; i < s.fComponents.size(); ++i) { + out[i] = all[s.fComponents[i]]; + } + return true; +} + +bool JIT::compileVectorVariableReference(LLVMBuilderRef builder, const VariableReference& v, + LLVMValueRef out[CHANNELS]) { + if (&v.fVariable == fColorParam) { + for (int i = 0; i < CHANNELS; ++i) { + out[i] = LLVMBuildLoad(builder, fChannels[i], "variable reference"); + } + return true; + } + return false; +} + +bool JIT::compileVectorExpression(LLVMBuilderRef builder, const Expression& expr, + LLVMValueRef out[CHANNELS]) { + switch (expr.fKind) { + case Expression::kBinary_Kind: + return this->compileVectorBinary(builder, (const BinaryExpression&) expr, out); + case Expression::kConstructor_Kind: + return this->compileVectorConstructor(builder, (const Constructor&) expr, out); + case Expression::kFloatLiteral_Kind: + return this->compileVectorFloatLiteral(builder, (const FloatLiteral&) expr, out); + case Expression::kSwizzle_Kind: + return this->compileVectorSwizzle(builder, (const Swizzle&) expr, out); + case Expression::kVariableReference_Kind: + return this->compileVectorVariableReference(builder, (const VariableReference&) expr, + out); + default: + printf("failed expression: %s\n", expr.description().c_str()); + return false; + } +} + +bool JIT::compileVectorStatement(LLVMBuilderRef builder, const Statement& stmt) { + switch (stmt.fKind) { + case Statement::kBlock_Kind: + for (const auto& s : ((const Block&) stmt).fStatements) { + if (!this->compileVectorStatement(builder, *s)) { + return false; + } + } + return true; + case Statement::kExpression_Kind: + LLVMValueRef result; + return this->compileVectorExpression(builder, + *((const ExpressionStatement&) stmt).fExpression, + &result); + default: + printf("failed statement: %s\n", stmt.description().c_str()); + return false; + } +} + +bool JIT::compileStageFunctionVector(const FunctionDefinition& f, LLVMValueRef newFunc) { + LLVMValueRef oldFunction = fCurrentFunction; + fCurrentFunction = newFunc; + std::unique_ptr<LLVMValueRef[]> params(new LLVMValueRef[STAGE_PARAM_COUNT]); + LLVMGetParams(fCurrentFunction, params.get()); + LLVMValueRef programParam = params.get()[1]; + LLVMBuilderRef builder = LLVMCreateBuilderInContext(fContext); + LLVMBasicBlockRef oldAllocaBlock = fAllocaBlock; + LLVMBasicBlockRef oldCurrentBlock = fCurrentBlock; + fAllocaBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "alloca"); + this->setBlock(builder, fAllocaBlock); + fChannels[0] = LLVMBuildAlloca(builder, fFloat32VectorType, "rVec"); + LLVMBuildStore(builder, params.get()[4], fChannels[0]); + fChannels[1] = LLVMBuildAlloca(builder, fFloat32VectorType, "gVec"); + LLVMBuildStore(builder, params.get()[5], fChannels[1]); + fChannels[2] = LLVMBuildAlloca(builder, fFloat32VectorType, "bVec"); + LLVMBuildStore(builder, params.get()[6], fChannels[2]); + fChannels[3] = LLVMBuildAlloca(builder, fFloat32VectorType, "aVec"); + LLVMBuildStore(builder, params.get()[7], fChannels[3]); + LLVMBasicBlockRef start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "start"); + this->setBlock(builder, start); + bool success = this->compileVectorStatement(builder, *f.fBody); + if (success) { + // increment program pointer, call next + LLVMValueRef rawNextPtr = LLVMBuildLoad(builder, programParam, "next load"); + LLVMTypeRef stageFuncType = LLVMTypeOf(newFunc); + LLVMValueRef nextPtr = LLVMBuildBitCast(builder, rawNextPtr, stageFuncType, + "cast next->func"); + LLVMValueRef nextInc = LLVMBuildIntToPtr(builder, + LLVMBuildAdd(builder, + LLVMBuildPtrToInt(builder, + programParam, + fInt64Type, + "cast 1"), + LLVMConstInt(fInt64Type, PTR_SIZE, + false), + "add"), + LLVMPointerType(fInt8PtrType, 0), "cast 2"); + LLVMValueRef args[STAGE_PARAM_COUNT] = { + params.get()[0], + nextInc, + params.get()[2], + params.get()[3], + LLVMBuildLoad(builder, fChannels[0], "rVec"), + LLVMBuildLoad(builder, fChannels[1], "gVec"), + LLVMBuildLoad(builder, fChannels[2], "bVec"), + LLVMBuildLoad(builder, fChannels[3], "aVec"), + params.get()[8], + params.get()[9], + params.get()[10], + params.get()[11] + }; + LLVMBuildCall(builder, nextPtr, args, STAGE_PARAM_COUNT, ""); + LLVMBuildRetVoid(builder); + // finish + LLVMPositionBuilderAtEnd(builder, fAllocaBlock); + LLVMBuildBr(builder, start); + LLVMDisposeBuilder(builder); + if (LLVMVerifyFunction(fCurrentFunction, LLVMPrintMessageAction)) { + ABORT("verify failed\n"); + } + } else { + LLVMDeleteBasicBlock(fAllocaBlock); + LLVMDeleteBasicBlock(start); + } + + fAllocaBlock = oldAllocaBlock; + fCurrentBlock = oldCurrentBlock; + fCurrentFunction = oldFunction; + return success; +} + +LLVMValueRef JIT::compileStageFunction(const FunctionDefinition& f) { + LLVMTypeRef returnType = fVoidType; + LLVMTypeRef parameterTypes[12] = { fSizeTType, LLVMPointerType(fInt8PtrType, 0), fSizeTType, + fSizeTType, fFloat32VectorType, fFloat32VectorType, + fFloat32VectorType, fFloat32VectorType, fFloat32VectorType, + fFloat32VectorType, fFloat32VectorType, fFloat32VectorType }; + LLVMTypeRef stageFuncType = LLVMFunctionType(returnType, parameterTypes, 12, false); + LLVMValueRef result = LLVMAddFunction(fModule, + (String(f.fDeclaration.fName) + "$stage").c_str(), + stageFuncType); + fColorParam = f.fDeclaration.fParameters[2]; + if (!this->compileStageFunctionVector(f, result)) { + // vectorization failed, fall back to looping over the pixels + this->compileStageFunctionLoop(f, result); + } + return result; +} + +bool JIT::hasStageSignature(const FunctionDeclaration& f) { + return f.fReturnType == *fProgram->fContext->fVoid_Type && + f.fParameters.size() == 3 && + f.fParameters[0]->fType == *fProgram->fContext->fInt_Type && + f.fParameters[0]->fModifiers.fFlags == 0 && + f.fParameters[1]->fType == *fProgram->fContext->fInt_Type && + f.fParameters[1]->fModifiers.fFlags == 0 && + f.fParameters[2]->fType == *fProgram->fContext->fFloat4_Type && + f.fParameters[2]->fModifiers.fFlags == (Modifiers::kIn_Flag | Modifiers::kOut_Flag); +} + +LLVMValueRef JIT::compileFunction(const FunctionDefinition& f) { + if (this->hasStageSignature(f.fDeclaration)) { + this->compileStageFunction(f); + // we compile foo$stage *in addition* to compiling foo, as we can't be sure that the intent + // was to produce an SkJumper stage just because the signature matched or that the function + // is not otherwise called. May need a better way to handle this. + } + LLVMTypeRef returnType = this->getType(f.fDeclaration.fReturnType); + std::vector<LLVMTypeRef> parameterTypes; + for (const auto& p : f.fDeclaration.fParameters) { + LLVMTypeRef type = this->getType(p->fType); + if (p->fModifiers.fFlags & Modifiers::kOut_Flag) { + type = LLVMPointerType(type, 0); + } + parameterTypes.push_back(type); + } + fCurrentFunction = LLVMAddFunction(fModule, + String(f.fDeclaration.fName).c_str(), + LLVMFunctionType(returnType, parameterTypes.data(), + parameterTypes.size(), false)); + fFunctions[&f.fDeclaration] = fCurrentFunction; + + std::unique_ptr<LLVMValueRef[]> params(new LLVMValueRef[parameterTypes.size()]); + LLVMGetParams(fCurrentFunction, params.get()); + for (size_t i = 0; i < f.fDeclaration.fParameters.size(); ++i) { + fVariables[f.fDeclaration.fParameters[i]] = params.get()[i]; + } + LLVMBuilderRef builder = LLVMCreateBuilderInContext(fContext); + fAllocaBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "alloca"); + LLVMBasicBlockRef start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "start"); + fCurrentBlock = start; + LLVMPositionBuilderAtEnd(builder, fCurrentBlock); + this->compileStatement(builder, *f.fBody); + if (!ends_with_branch(*f.fBody)) { + if (f.fDeclaration.fReturnType == *fProgram->fContext->fVoid_Type) { + LLVMBuildRetVoid(builder); + } else { + LLVMBuildUnreachable(builder); + } + } + LLVMPositionBuilderAtEnd(builder, fAllocaBlock); + LLVMBuildBr(builder, start); + LLVMDisposeBuilder(builder); + if (LLVMVerifyFunction(fCurrentFunction, LLVMPrintMessageAction)) { + ABORT("verify failed\n"); + } + return fCurrentFunction; +} + +void JIT::createModule() { + fPromotedParameters.clear(); + fModule = LLVMModuleCreateWithNameInContext("skslmodule", fContext); + this->loadBuiltinFunctions(); + // LLVM doesn't do void*, have to declare it as int8* + LLVMTypeRef appendParams[3] = { fInt8PtrType, fInt32Type, fInt8PtrType }; + fAppendFunc = LLVMAddFunction(fModule, "sksl_pipeline_append", LLVMFunctionType(fVoidType, + appendParams, + 3, + false)); + LLVMTypeRef appendCallbackParams[2] = { fInt8PtrType, fInt8PtrType }; + fAppendCallbackFunc = LLVMAddFunction(fModule, "sksl_pipeline_append_callback", + LLVMFunctionType(fVoidType, appendCallbackParams, 2, + false)); + + LLVMTypeRef debugParams[3] = { fFloat32Type }; + fDebugFunc = LLVMAddFunction(fModule, "sksl_debug_print", LLVMFunctionType(fVoidType, + debugParams, + 1, + false)); + + for (const auto& e : fProgram->fElements) { + ASSERT(e->fKind == ProgramElement::kFunction_Kind); + this->compileFunction((FunctionDefinition&) *e); + } +} + +std::unique_ptr<JIT::Module> JIT::compile(std::unique_ptr<Program> program) { + fProgram = std::move(program); + this->createModule(); + this->optimize(); + return std::unique_ptr<Module>(new Module(std::move(fProgram), fSharedModule, fJITStack)); +} + +void JIT::optimize() { + LLVMPassManagerBuilderRef pmb = LLVMPassManagerBuilderCreate(); + LLVMPassManagerBuilderSetOptLevel(pmb, 3); + LLVMPassManagerRef functionPM = LLVMCreateFunctionPassManagerForModule(fModule); + LLVMPassManagerBuilderPopulateFunctionPassManager(pmb, functionPM); + LLVMPassManagerRef modulePM = LLVMCreatePassManager(); + LLVMPassManagerBuilderPopulateModulePassManager(pmb, modulePM); + LLVMInitializeFunctionPassManager(functionPM); + + LLVMValueRef func = LLVMGetFirstFunction(fModule); + for (;;) { + if (!func) { + break; + } + LLVMRunFunctionPassManager(functionPM, func); + func = LLVMGetNextFunction(func); + } + LLVMRunPassManager(modulePM, fModule); + LLVMDisposePassManager(functionPM); + LLVMDisposePassManager(modulePM); + LLVMPassManagerBuilderDispose(pmb); + + std::string error_string; + if (LLVMLoadLibraryPermanently(nullptr)) { + ABORT("LLVMLoadLibraryPermanently failed"); + } + char* defaultTriple = LLVMGetDefaultTargetTriple(); + char* error; + LLVMTargetRef target; + if (LLVMGetTargetFromTriple(defaultTriple, &target, &error)) { + ABORT("LLVMGetTargetFromTriple failed"); + } + + if (!LLVMTargetHasJIT(target)) { + ABORT("!LLVMTargetHasJIT"); + } + LLVMTargetMachineRef targetMachine = LLVMCreateTargetMachine(target, + defaultTriple, + fCPU, + nullptr, + LLVMCodeGenLevelDefault, + LLVMRelocDefault, + LLVMCodeModelJITDefault); + LLVMDisposeMessage(defaultTriple); + LLVMTargetDataRef dataLayout = LLVMCreateTargetDataLayout(targetMachine); + LLVMSetModuleDataLayout(fModule, dataLayout); + LLVMDisposeTargetData(dataLayout); + + fJITStack = LLVMOrcCreateInstance(targetMachine); + fSharedModule = LLVMOrcMakeSharedModule(fModule); + LLVMOrcModuleHandle orcModule; + LLVMOrcAddEagerlyCompiledIR(fJITStack, &orcModule, fSharedModule, + (LLVMOrcSymbolResolverFn) resolveSymbol, this); + LLVMDisposeTargetMachine(targetMachine); +} + +void* JIT::Module::getSymbol(const char* name) { + LLVMOrcTargetAddress result; + if (LLVMOrcGetSymbolAddress(fJITStack, &result, name)) { + ABORT("GetSymbolAddress error"); + } + if (!result) { + ABORT("symbol not found"); + } + return (void*) result; +} + +void* JIT::Module::getJumperStage(const char* name) { + return this->getSymbol((String(name) + "$stage").c_str()); +} + +} // namespace + +#endif // SK_LLVM_AVAILABLE + +#endif // SKSL_STANDALONE diff --git a/src/sksl/SkSLJIT.h b/src/sksl/SkSLJIT.h new file mode 100644 index 0000000000..b23e31237f --- /dev/null +++ b/src/sksl/SkSLJIT.h @@ -0,0 +1,344 @@ +/* + * Copyright 2018 Google Inc. + * + * Use of this source code is governed by a BSD-style license that can be + * found in the LICENSE file. + */ + +#ifndef SKSL_JIT +#define SKSL_JIT + +#ifdef SK_LLVM_AVAILABLE + +#include "ir/SkSLAppendStage.h" +#include "ir/SkSLBinaryExpression.h" +#include "ir/SkSLBreakStatement.h" +#include "ir/SkSLContinueStatement.h" +#include "ir/SkSLExpression.h" +#include "ir/SkSLDoStatement.h" +#include "ir/SkSLForStatement.h" +#include "ir/SkSLFunctionCall.h" +#include "ir/SkSLFunctionDefinition.h" +#include "ir/SkSLIfStatement.h" +#include "ir/SkSLIndexExpression.h" +#include "ir/SkSLPrefixExpression.h" +#include "ir/SkSLPostfixExpression.h" +#include "ir/SkSLProgram.h" +#include "ir/SkSLReturnStatement.h" +#include "ir/SkSLStatement.h" +#include "ir/SkSLSwizzle.h" +#include "ir/SkSLTernaryExpression.h" +#include "ir/SkSLVarDeclarationsStatement.h" +#include "ir/SkSLVariableReference.h" +#include "ir/SkSLWhileStatement.h" + +#include "llvm-c/Analysis.h" +#include "llvm-c/Core.h" +#include "llvm-c/OrcBindings.h" +#include "llvm-c/Support.h" +#include "llvm-c/Target.h" +#include "llvm-c/Transforms/PassManagerBuilder.h" +#include "llvm-c/Types.h" +#include <stack> + +class SkRasterPipeline; + +namespace SkSL { + +/** + * A just-in-time compiler for SkSL code which uses an LLVM backend. Only available when the + * skia_llvm_path gn arg is set. + * + * Example of using SkSLJIT to set up an SkJumper pipeline stage: + * + * #ifdef SK_LLVM_AVAILABLE + * SkSL::Compiler compiler; + * SkSL::Program::Settings settings; + * std::unique_ptr<SkSL::Program> program = compiler.convertProgram(SkSL::Program::kCPU_Kind, + * "void swap(int x, int y, inout float4 color) {" + * " color.rb = color.br;" + * "}", + * settings); + * if (!program) { + * printf("%s\n", compiler.errorText().c_str()); + * abort(); + * } + * SkSL::JIT& jit = *scratch->make<SkSL::JIT>(&compiler); + * std::unique_ptr<SkSL::JIT::Module> module = jit.compile(std::move(program)); + * void* func = module->getJumperStage("swap"); + * p->append(func, nullptr); + * #endif + */ +class JIT { + typedef int StackIndex; + +public: + class Module { + public: + /** + * Returns the address of a symbol in the module. + */ + void* getSymbol(const char* name); + + /** + * Returns the address of a function as an SkJumper pipeline stage. The function must have + * the signature void <name>(int x, int y, inout float4 color). The returned function will + * have the correct signature to function as an SkJumper stage (meaning it will actually + * have a different signature at runtime, accepting vector parameters and operating on + * multiple pixels simultaneously as is normal for SkJumper stages). + */ + void* getJumperStage(const char* name); + + ~Module() { + LLVMOrcDisposeSharedModuleRef(fSharedModule); + } + + private: + Module(std::unique_ptr<Program> program, + LLVMSharedModuleRef sharedModule, + LLVMOrcJITStackRef jitStack) + : fProgram(std::move(program)) + , fSharedModule(sharedModule) + , fJITStack(jitStack) {} + + std::unique_ptr<Program> fProgram; + LLVMSharedModuleRef fSharedModule; + LLVMOrcJITStackRef fJITStack; + + friend class JIT; + }; + + JIT(Compiler* compiler); + + ~JIT(); + + /** + * Just-in-time compiles an SkSL program and returns the resulting Module. The JIT must not be + * destroyed before all of its Modules are destroyed. + */ + std::unique_ptr<Module> compile(std::unique_ptr<Program> program); + +private: + static constexpr int CHANNELS = 4; + + enum TypeKind { + kFloat_TypeKind, + kInt_TypeKind, + kUInt_TypeKind, + kBool_TypeKind + }; + + class LValue { + public: + virtual ~LValue() {} + + virtual LLVMValueRef load(LLVMBuilderRef builder) = 0; + + virtual void store(LLVMBuilderRef builder, LLVMValueRef value) = 0; + }; + + void addBuiltinFunction(const char* ourName, const char* realName, LLVMTypeRef returnType, + std::vector<LLVMTypeRef> parameters); + + void loadBuiltinFunctions(); + + void setBlock(LLVMBuilderRef builder, LLVMBasicBlockRef block); + + LLVMTypeRef getType(const Type& type); + + TypeKind typeKind(const Type& type); + + std::unique_ptr<LValue> getLValue(LLVMBuilderRef builder, const Expression& expr); + + void vectorize(LLVMBuilderRef builder, LLVMValueRef* value, int columns); + + void vectorize(LLVMBuilderRef builder, const BinaryExpression& b, LLVMValueRef* left, + LLVMValueRef* right); + + LLVMValueRef compileBinary(LLVMBuilderRef builder, const BinaryExpression& b); + + LLVMValueRef compileConstructor(LLVMBuilderRef builder, const Constructor& c); + + LLVMValueRef compileFunctionCall(LLVMBuilderRef builder, const FunctionCall& fc); + + LLVMValueRef compileIndex(LLVMBuilderRef builder, const IndexExpression& v); + + LLVMValueRef compilePostfix(LLVMBuilderRef builder, const PostfixExpression& p); + + LLVMValueRef compilePrefix(LLVMBuilderRef builder, const PrefixExpression& p); + + LLVMValueRef compileSwizzle(LLVMBuilderRef builder, const Swizzle& s); + + LLVMValueRef compileVariableReference(LLVMBuilderRef builder, const VariableReference& v); + + LLVMValueRef compileTernary(LLVMBuilderRef builder, const TernaryExpression& t); + + LLVMValueRef compileExpression(LLVMBuilderRef builder, const Expression& expr); + + void appendStage(LLVMBuilderRef builder, const AppendStage& a); + + void compileBlock(LLVMBuilderRef builder, const Block& block); + + void compileBreak(LLVMBuilderRef builder, const BreakStatement& b); + + void compileContinue(LLVMBuilderRef builder, const ContinueStatement& c); + + void compileDo(LLVMBuilderRef builder, const DoStatement& d); + + void compileFor(LLVMBuilderRef builder, const ForStatement& f); + + void compileIf(LLVMBuilderRef builder, const IfStatement& i); + + void compileReturn(LLVMBuilderRef builder, const ReturnStatement& r); + + void compileVarDeclarations(LLVMBuilderRef builder, const VarDeclarationsStatement& decls); + + void compileWhile(LLVMBuilderRef builder, const WhileStatement& w); + + void compileStatement(LLVMBuilderRef builder, const Statement& stmt); + + // The "Vector" variants of functions attempt to compile a given expression or statement as part + // of a vectorized SkJumper stage function - that is, with r, g, b, and a each being vectors of + // fVectorCount floats. So a statement like "color.r = 0;" looks like it modifies a single + // channel of a single pixel, but the compiled code will actually modify the red channel of + // fVectorCount pixels at once. + // + // As not everything can be vectorized, these calls return a bool to indicate whether they were + // successful. If anything anywhere in the function cannot be vectorized, the JIT will fall back + // to looping over the pixels instead. + // + // Since we process multiple pixels at once, and each pixel consists of multiple color channels, + // expressions may effectively result in a vector-of-vectors. We produce zero to four outputs + // when compiling expression, each of which is a vector, so that e.g. float2(1, 0) actually + // produces two vectors, one containing all 1s, the other all 0s. The out parameter always + // allows for 4 channels, but the functions produce 0 to 4 channels depending on the type they + // are operating on. Thus evaluating "color.rgb" actually fills in out[0] through out[2], + // leaving out[3] uninitialized. + // As the number of outputs can be inferred from the type of the expression, it is not + // explicitly signalled anywhere. + bool compileVectorBinary(LLVMBuilderRef builder, const BinaryExpression& b, + LLVMValueRef out[CHANNELS]); + + bool compileVectorConstructor(LLVMBuilderRef builder, const Constructor& c, + LLVMValueRef out[CHANNELS]); + + bool compileVectorFloatLiteral(LLVMBuilderRef builder, const FloatLiteral& f, + LLVMValueRef out[CHANNELS]); + + bool compileVectorSwizzle(LLVMBuilderRef builder, const Swizzle& s, + LLVMValueRef out[CHANNELS]); + + bool compileVectorVariableReference(LLVMBuilderRef builder, const VariableReference& v, + LLVMValueRef out[CHANNELS]); + + bool compileVectorExpression(LLVMBuilderRef builder, const Expression& expr, + LLVMValueRef out[CHANNELS]); + + bool getVectorLValue(LLVMBuilderRef builder, const Expression& e, LLVMValueRef out[CHANNELS]); + + /** + * Evaluates the left and right operands of a binary operation, promoting one of them to a + * vector if necessary to make the types match. + */ + bool getVectorBinaryOperands(LLVMBuilderRef builder, const Expression& left, + LLVMValueRef outLeft[CHANNELS], const Expression& right, + LLVMValueRef outRight[CHANNELS]); + + bool compileVectorStatement(LLVMBuilderRef builder, const Statement& stmt); + + /** + * Returns true if this function has the signature void(int, int, inout float4) and thus can be + * used as an SkJumper stage. + */ + bool hasStageSignature(const FunctionDeclaration& f); + + /** + * Attempts to compile a vectorized stage function, returning true on success. A stage function + * of e.g. "color.r = 0;" will produce code which sets the entire red vector to zeros in a + * single instruction, thus calculating several pixels at once. + */ + bool compileStageFunctionVector(const FunctionDefinition& f, LLVMValueRef newFunc); + + /** + * Fallback function which loops over the pixels, for when vectorization fails. A stage function + * of e.g. "color.r = 0;" will produce a loop which iterates over the entries in the red vector, + * setting each one to zero individually. + */ + void compileStageFunctionLoop(const FunctionDefinition& f, LLVMValueRef newFunc); + + /** + * Called when compiling a function which has the signature of an SkJumper stage. Produces a + * version of the function which can be plugged into SkJumper (thus having a signature which + * accepts four vectors, one for each color channel, containing the color data of multiple + * pixels at once). To go from SkSL code which operates on a single pixel at a time to CPU code + * which operates on multiple pixels at once, the code is either vectorized using + * compileStageFunctionVector or wrapped in a loop using compileStageFunctionLoop. + */ + LLVMValueRef compileStageFunction(const FunctionDefinition& f); + + /** + * Compiles an SkSL function to an LLVM function. If the function has the signature of an + * SkJumper stage, it will *also* be compiled by compileStageFunction, resulting in both a stage + * and non-stage version of the function. + */ + LLVMValueRef compileFunction(const FunctionDefinition& f); + + void createModule(); + + void optimize(); + + bool isColorRef(const Expression& expr); + + static uint64_t resolveSymbol(const char* name, JIT* jit); + + const char* fCPU; + int fVectorCount; + Compiler& fCompiler; + std::unique_ptr<Program> fProgram; + LLVMContextRef fContext; + LLVMModuleRef fModule; + LLVMSharedModuleRef fSharedModule; + LLVMOrcJITStackRef fJITStack; + LLVMValueRef fCurrentFunction; + LLVMBasicBlockRef fAllocaBlock; + LLVMBasicBlockRef fCurrentBlock; + LLVMTypeRef fVoidType; + LLVMTypeRef fInt1Type; + LLVMTypeRef fInt8Type; + LLVMTypeRef fInt8PtrType; + LLVMTypeRef fInt32Type; + LLVMTypeRef fInt32VectorType; + LLVMTypeRef fInt32Vector2Type; + LLVMTypeRef fInt32Vector3Type; + LLVMTypeRef fInt32Vector4Type; + LLVMTypeRef fInt64Type; + LLVMTypeRef fSizeTType; + LLVMTypeRef fFloat32Type; + LLVMTypeRef fFloat32VectorType; + LLVMTypeRef fFloat32Vector2Type; + LLVMTypeRef fFloat32Vector3Type; + LLVMTypeRef fFloat32Vector4Type; + // Our SkSL stage functions have a single float4 for color, but the actual SkJumper stage + // function has four separate vectors, one for each channel. These four values are references to + // the red, green, blue, and alpha vectors respectively. + LLVMValueRef fChannels[CHANNELS]; + // when processing a stage function, this points to the SkSL color parameter (an inout float4) + const Variable* fColorParam; + std::map<const FunctionDeclaration*, LLVMValueRef> fFunctions; + std::map<const Variable*, LLVMValueRef> fVariables; + // LLVM function parameters are read-only, so when modifying function parameters we need to + // first promote them to variables. This keeps track of which parameters have been promoted. + std::set<const Variable*> fPromotedParameters; + std::vector<LLVMBasicBlockRef> fBreakTarget; + std::vector<LLVMBasicBlockRef> fContinueTarget; + + LLVMValueRef fAppendFunc; + LLVMValueRef fAppendCallbackFunc; + LLVMValueRef fDebugFunc; +}; + +} // namespace + +#endif // SK_LLVM_AVAILABLE + +#endif // SKSL_JIT diff --git a/src/sksl/ir/SkSLAppendStage.h b/src/sksl/ir/SkSLAppendStage.h new file mode 100644 index 0000000000..87a8210a83 --- /dev/null +++ b/src/sksl/ir/SkSLAppendStage.h @@ -0,0 +1,53 @@ +/* + * Copyright 2018 Google Inc. + * + * Use of this source code is governed by a BSD-style license that can be + * found in the LICENSE file. + */ + +#ifndef SKSL_APPENDSTAGE +#define SKSL_APPENDSTAGE + +#ifndef SKSL_STANDALONE + +#include "SkRasterPipeline.h" +#include "SkSLContext.h" +#include "SkSLExpression.h" + +namespace SkSL { + +struct AppendStage : public Expression { + AppendStage(const Context& context, int offset, SkRasterPipeline::StockStage stage, + std::vector<std::unique_ptr<Expression>> arguments) + : INHERITED(offset, kAppendStage_Kind, *context.fVoid_Type) + , fStage(stage) + , fArguments(std::move(arguments)) {} + + String description() const { + String result = "append("; + const char* separator = ""; + for (const auto& a : fArguments) { + result += separator; + result += a->description(); + separator = ", "; + } + result += ")"; + return result; + } + + bool hasSideEffects() const { + return true; + } + + SkRasterPipeline::StockStage fStage; + + std::vector<std::unique_ptr<Expression>> fArguments; + + typedef Expression INHERITED; +}; + +} // namespace + +#endif // SKSL_STANDALONE + +#endif // SKSL_APPENDSTAGE diff --git a/src/sksl/ir/SkSLExpression.h b/src/sksl/ir/SkSLExpression.h index 52fa5d297f..c8ad1380e7 100644 --- a/src/sksl/ir/SkSLExpression.h +++ b/src/sksl/ir/SkSLExpression.h @@ -25,6 +25,7 @@ typedef std::unordered_map<const Variable*, std::unique_ptr<Expression>*> Defini */ struct Expression : public IRNode { enum Kind { + kAppendStage_Kind, kBinary_Kind, kBoolLiteral_Kind, kConstructor_Kind, diff --git a/src/sksl/ir/SkSLFunctionReference.h b/src/sksl/ir/SkSLFunctionReference.h index 58831c5e99..58fefce801 100644 --- a/src/sksl/ir/SkSLFunctionReference.h +++ b/src/sksl/ir/SkSLFunctionReference.h @@ -29,7 +29,6 @@ struct FunctionReference : public Expression { } String description() const override { - ASSERT(false); return String("<function>"); } diff --git a/src/sksl/ir/SkSLProgram.h b/src/sksl/ir/SkSLProgram.h index a63cd237ac..cbb9dfe1a7 100644 --- a/src/sksl/ir/SkSLProgram.h +++ b/src/sksl/ir/SkSLProgram.h @@ -103,13 +103,14 @@ struct Program { kFragment_Kind, kVertex_Kind, kGeometry_Kind, - kFragmentProcessor_Kind + kFragmentProcessor_Kind, + kCPU_Kind }; Program(Kind kind, std::unique_ptr<String> source, Settings settings, - Context* context, + std::shared_ptr<Context> context, std::vector<std::unique_ptr<ProgramElement>> elements, std::shared_ptr<SymbolTable> symbols, Inputs inputs) @@ -124,7 +125,7 @@ struct Program { Kind fKind; std::unique_ptr<String> fSource; Settings fSettings; - Context* fContext; + std::shared_ptr<Context> fContext; // it's important to keep fElements defined after (and thus destroyed before) fSymbols, // because destroying elements can modify reference counts in symbols std::shared_ptr<SymbolTable> fSymbols; diff --git a/src/sksl/sksl_cpu.inc b/src/sksl/sksl_cpu.inc new file mode 100644 index 0000000000..479450bd33 --- /dev/null +++ b/src/sksl/sksl_cpu.inc @@ -0,0 +1,12 @@ +STRINGIFY( + // special-cased within the compiler - append takes various arguments depending on what kind of + // stage is being appended + sk_has_side_effects void append(); + + float abs(float x); + float sin(float x); + float cos(float x); + float tan(float x); + float sqrt(float x); + sk_has_side_effects void print(float x); +) |