aboutsummaryrefslogtreecommitdiffhomepage
path: root/src/sksl
diff options
context:
space:
mode:
authorGravatar Ethan Nicholas <ethannicholas@google.com>2018-03-27 14:10:52 -0400
committerGravatar Skia Commit-Bot <skia-commit-bot@chromium.org>2018-03-27 18:39:13 +0000
commit26a9aad63b60c9cbbdfa87c212a4e76ce55e7373 (patch)
treed4a42afd75bdff6c7815be7bf78b42cab742df19 /src/sksl
parent3560b58de36988e1fba54d9ac341735ab849e913 (diff)
initial SkSLJIT checkin
Docs-Preview: https://skia.org/?cl=112204 Bug: skia: Change-Id: I10042a0200db00bd8ff8078467c409b1cf191f50 Reviewed-on: https://skia-review.googlesource.com/112204 Commit-Queue: Ethan Nicholas <ethannicholas@google.com> Reviewed-by: Mike Klein <mtklein@chromium.org>
Diffstat (limited to 'src/sksl')
-rw-r--r--src/sksl/SkSLCFGGenerator.cpp1
-rw-r--r--src/sksl/SkSLCompiler.cpp77
-rw-r--r--src/sksl/SkSLCompiler.h6
-rw-r--r--src/sksl/SkSLContext.h2
-rw-r--r--src/sksl/SkSLIRGenerator.cpp90
-rw-r--r--src/sksl/SkSLIRGenerator.h2
-rw-r--r--src/sksl/SkSLInterpreter.cpp473
-rw-r--r--src/sksl/SkSLInterpreter.h89
-rw-r--r--src/sksl/SkSLJIT.cpp1747
-rw-r--r--src/sksl/SkSLJIT.h344
-rw-r--r--src/sksl/ir/SkSLAppendStage.h53
-rw-r--r--src/sksl/ir/SkSLExpression.h1
-rw-r--r--src/sksl/ir/SkSLFunctionReference.h1
-rw-r--r--src/sksl/ir/SkSLProgram.h7
-rw-r--r--src/sksl/sksl_cpu.inc12
15 files changed, 2868 insertions, 37 deletions
diff --git a/src/sksl/SkSLCFGGenerator.cpp b/src/sksl/SkSLCFGGenerator.cpp
index c56d0c1b0f..852f48d71d 100644
--- a/src/sksl/SkSLCFGGenerator.cpp
+++ b/src/sksl/SkSLCFGGenerator.cpp
@@ -387,6 +387,7 @@ void CFGGenerator::addExpression(CFG& cfg, std::unique_ptr<Expression>* e, bool
cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind,
constantPropagate, e, nullptr });
break;
+ case Expression::kAppendStage_Kind: // fall through
case Expression::kBoolLiteral_Kind: // fall through
case Expression::kFloatLiteral_Kind: // fall through
case Expression::kIntLiteral_Kind: // fall through
diff --git a/src/sksl/SkSLCompiler.cpp b/src/sksl/SkSLCompiler.cpp
index bbaaf407d2..4650e5c643 100644
--- a/src/sksl/SkSLCompiler.cpp
+++ b/src/sksl/SkSLCompiler.cpp
@@ -54,17 +54,22 @@ static const char* SKSL_FP_INCLUDE =
#include "sksl_fp.inc"
;
+static const char* SKSL_CPU_INCLUDE =
+#include "sksl_cpu.inc"
+;
+
namespace SkSL {
Compiler::Compiler(Flags flags)
: fFlags(flags)
+, fContext(new Context())
, fErrorCount(0) {
auto types = std::shared_ptr<SymbolTable>(new SymbolTable(this));
auto symbols = std::shared_ptr<SymbolTable>(new SymbolTable(types, this));
- fIRGenerator = new IRGenerator(&fContext, symbols, *this);
+ fIRGenerator = new IRGenerator(fContext.get(), symbols, *this);
fTypes = types;
- #define ADD_TYPE(t) types->addWithoutOwnership(fContext.f ## t ## _Type->fName, \
- fContext.f ## t ## _Type.get())
+ #define ADD_TYPE(t) types->addWithoutOwnership(fContext->f ## t ## _Type->fName, \
+ fContext->f ## t ## _Type.get())
ADD_TYPE(Void);
ADD_TYPE(Float);
ADD_TYPE(Float2);
@@ -188,15 +193,16 @@ Compiler::Compiler(Flags flags)
ADD_TYPE(GSampler2DArrayShadow);
ADD_TYPE(GSamplerCubeArrayShadow);
ADD_TYPE(FragmentProcessor);
+ ADD_TYPE(SkRasterPipeline);
StringFragment skCapsName("sk_Caps");
Variable* skCaps = new Variable(-1, Modifiers(), skCapsName,
- *fContext.fSkCaps_Type, Variable::kGlobal_Storage);
+ *fContext->fSkCaps_Type, Variable::kGlobal_Storage);
fIRGenerator->fSymbolTable->add(skCapsName, std::unique_ptr<Symbol>(skCaps));
StringFragment skArgsName("sk_Args");
Variable* skArgs = new Variable(-1, Modifiers(), skArgsName,
- *fContext.fSkArgs_Type, Variable::kGlobal_Storage);
+ *fContext->fSkArgs_Type, Variable::kGlobal_Storage);
fIRGenerator->fSymbolTable->add(skArgsName, std::unique_ptr<Symbol>(skArgs));
std::vector<std::unique_ptr<ProgramElement>> ignored;
@@ -232,19 +238,19 @@ void Compiler::addDefinition(const Expression* lvalue, std::unique_ptr<Expressio
// but since we pass foo as a whole it is flagged as an error) unless we perform a much
// more complicated whole-program analysis. This is probably good enough.
this->addDefinition(((Swizzle*) lvalue)->fBase.get(),
- (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
+ (std::unique_ptr<Expression>*) &fContext->fDefined_Expression,
definitions);
break;
case Expression::kIndex_Kind:
// see comments in Swizzle
this->addDefinition(((IndexExpression*) lvalue)->fBase.get(),
- (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
+ (std::unique_ptr<Expression>*) &fContext->fDefined_Expression,
definitions);
break;
case Expression::kFieldAccess_Kind:
// see comments in Swizzle
this->addDefinition(((FieldAccess*) lvalue)->fBase.get(),
- (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
+ (std::unique_ptr<Expression>*) &fContext->fDefined_Expression,
definitions);
break;
case Expression::kTernary_Kind:
@@ -252,10 +258,10 @@ void Compiler::addDefinition(const Expression* lvalue, std::unique_ptr<Expressio
// This allows for false positives (meaning we fail to detect that a variable might not
// have been assigned), but is preferable to false negatives.
this->addDefinition(((TernaryExpression*) lvalue)->fIfTrue.get(),
- (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
+ (std::unique_ptr<Expression>*) &fContext->fDefined_Expression,
definitions);
this->addDefinition(((TernaryExpression*) lvalue)->fIfFalse.get(),
- (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
+ (std::unique_ptr<Expression>*) &fContext->fDefined_Expression,
definitions);
break;
default:
@@ -278,9 +284,9 @@ void Compiler::addDefinitions(const BasicBlock::Node& node,
this->addDefinition(b->fLeft.get(), &b->fRight, definitions);
} else if (Compiler::IsAssignment(b->fOperator)) {
this->addDefinition(
- b->fLeft.get(),
- (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
- definitions);
+ b->fLeft.get(),
+ (std::unique_ptr<Expression>*) &fContext->fDefined_Expression,
+ definitions);
}
break;
@@ -289,9 +295,9 @@ void Compiler::addDefinitions(const BasicBlock::Node& node,
const PrefixExpression* p = (PrefixExpression*) expr;
if (p->fOperator == Token::MINUSMINUS || p->fOperator == Token::PLUSPLUS) {
this->addDefinition(
- p->fOperand.get(),
- (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
- definitions);
+ p->fOperand.get(),
+ (std::unique_ptr<Expression>*) &fContext->fDefined_Expression,
+ definitions);
}
break;
}
@@ -299,9 +305,9 @@ void Compiler::addDefinitions(const BasicBlock::Node& node,
const PostfixExpression* p = (PostfixExpression*) expr;
if (p->fOperator == Token::MINUSMINUS || p->fOperator == Token::PLUSPLUS) {
this->addDefinition(
- p->fOperand.get(),
- (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
- definitions);
+ p->fOperand.get(),
+ (std::unique_ptr<Expression>*) &fContext->fDefined_Expression,
+ definitions);
}
break;
}
@@ -309,9 +315,9 @@ void Compiler::addDefinitions(const BasicBlock::Node& node,
const VariableReference* v = (VariableReference*) expr;
if (v->fRefKind != VariableReference::kRead_RefKind) {
this->addDefinition(
- v,
- (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
- definitions);
+ v,
+ (std::unique_ptr<Expression>*) &fContext->fDefined_Expression,
+ definitions);
}
}
default:
@@ -343,6 +349,9 @@ void Compiler::scanCFG(CFG* cfg, BlockId blockId, std::set<BlockId>* workList) {
// propagate definitions to exits
for (BlockId exitId : block.fExits) {
+ if (exitId == blockId) {
+ continue;
+ }
BasicBlock& exit = cfg->fBlocks[exitId];
for (const auto& pair : after) {
std::unique_ptr<Expression>* e1 = pair.second;
@@ -359,7 +368,7 @@ void Compiler::scanCFG(CFG* cfg, BlockId blockId, std::set<BlockId>* workList) {
workList->insert(exitId);
if (e1 && e2) {
exit.fBefore[pair.first] =
- (std::unique_ptr<Expression>*) &fContext.fDefined_Expression;
+ (std::unique_ptr<Expression>*) &fContext->fDefined_Expression;
} else {
exit.fBefore[pair.first] = nullptr;
}
@@ -990,7 +999,7 @@ void Compiler::simplifyStatement(DefinitionMap& definitions,
continue;
}
ASSERT(c->fValue->fKind == s.fValue->fKind);
- found = c->fValue->compareConstant(fContext, *s.fValue);
+ found = c->fValue->compareConstant(*fContext, *s.fValue);
if (found) {
std::unique_ptr<Statement> newBlock = block_for_case(&s, c.get());
if (newBlock) {
@@ -1153,7 +1162,7 @@ void Compiler::scanCFG(FunctionDefinition& f) {
}
// check for missing return
- if (f.fDeclaration.fReturnType != *fContext.fVoid_Type) {
+ if (f.fDeclaration.fReturnType != *fContext->fVoid_Type) {
if (cfg.fBlocks[cfg.fExit].fEntrances.size()) {
this->error(f.fOffset, String("function can exit without returning a value"));
}
@@ -1183,6 +1192,10 @@ std::unique_ptr<Program> Compiler::convertProgram(Program::Kind kind, String tex
fIRGenerator->convertProgram(kind, SKSL_FP_INCLUDE, strlen(SKSL_FP_INCLUDE), *fTypes,
&elements);
break;
+ case Program::kCPU_Kind:
+ fIRGenerator->convertProgram(kind, SKSL_CPU_INCLUDE, strlen(SKSL_CPU_INCLUDE),
+ *fTypes, &elements);
+ break;
}
fIRGenerator->fSymbolTable->markAllFunctionsBuiltin();
for (auto& element : elements) {
@@ -1203,7 +1216,7 @@ std::unique_ptr<Program> Compiler::convertProgram(Program::Kind kind, String tex
auto result = std::unique_ptr<Program>(new Program(kind,
std::move(textPtr),
settings,
- &fContext,
+ fContext,
std::move(elements),
fIRGenerator->fSymbolTable,
fIRGenerator->fInputs));
@@ -1220,7 +1233,7 @@ bool Compiler::toSPIRV(const Program& program, OutputStream& out) {
#ifdef SK_ENABLE_SPIRV_VALIDATION
StringStream buffer;
fSource = program.fSource.get();
- SPIRVCodeGenerator cg(&fContext, &program, this, &buffer);
+ SPIRVCodeGenerator cg(fContext.get(), &program, this, &buffer);
bool result = cg.generateCode();
fSource = nullptr;
if (result) {
@@ -1238,7 +1251,7 @@ bool Compiler::toSPIRV(const Program& program, OutputStream& out) {
}
#else
fSource = program.fSource.get();
- SPIRVCodeGenerator cg(&fContext, &program, this, &out);
+ SPIRVCodeGenerator cg(fContext.get(), &program, this, &out);
bool result = cg.generateCode();
fSource = nullptr;
#endif
@@ -1257,7 +1270,7 @@ bool Compiler::toSPIRV(const Program& program, String* out) {
bool Compiler::toGLSL(const Program& program, OutputStream& out) {
fSource = program.fSource.get();
- GLSLCodeGenerator cg(&fContext, &program, this, &out);
+ GLSLCodeGenerator cg(fContext.get(), &program, this, &out);
bool result = cg.generateCode();
fSource = nullptr;
this->writeErrorCount();
@@ -1274,7 +1287,7 @@ bool Compiler::toGLSL(const Program& program, String* out) {
}
bool Compiler::toMetal(const Program& program, OutputStream& out) {
- MetalCodeGenerator cg(&fContext, &program, this, &out);
+ MetalCodeGenerator cg(fContext.get(), &program, this, &out);
bool result = cg.generateCode();
this->writeErrorCount();
return result;
@@ -1282,7 +1295,7 @@ bool Compiler::toMetal(const Program& program, OutputStream& out) {
bool Compiler::toCPP(const Program& program, String name, OutputStream& out) {
fSource = program.fSource.get();
- CPPCodeGenerator cg(&fContext, &program, this, name, &out);
+ CPPCodeGenerator cg(fContext.get(), &program, this, name, &out);
bool result = cg.generateCode();
fSource = nullptr;
this->writeErrorCount();
@@ -1291,7 +1304,7 @@ bool Compiler::toCPP(const Program& program, String name, OutputStream& out) {
bool Compiler::toH(const Program& program, String name, OutputStream& out) {
fSource = program.fSource.get();
- HCodeGenerator cg(&fContext, &program, this, name, &out);
+ HCodeGenerator cg(fContext.get(), &program, this, name, &out);
bool result = cg.generateCode();
fSource = nullptr;
this->writeErrorCount();
diff --git a/src/sksl/SkSLCompiler.h b/src/sksl/SkSLCompiler.h
index 9f8ef4f18a..0ed6a3bdbb 100644
--- a/src/sksl/SkSLCompiler.h
+++ b/src/sksl/SkSLCompiler.h
@@ -88,6 +88,10 @@ public:
return fErrorCount;
}
+ Context& context() {
+ return *fContext;
+ }
+
static const char* OperatorName(Token::Kind token);
static bool IsAssignment(Token::Kind token);
@@ -134,7 +138,7 @@ private:
int fFlags;
const String* fSource;
- Context fContext;
+ std::shared_ptr<Context> fContext;
int fErrorCount;
String fErrorText;
};
diff --git a/src/sksl/SkSLContext.h b/src/sksl/SkSLContext.h
index 407dbf8e82..61dd8042e4 100644
--- a/src/sksl/SkSLContext.h
+++ b/src/sksl/SkSLContext.h
@@ -185,6 +185,7 @@ public:
, fSkCaps_Type(new Type("$sk_Caps"))
, fSkArgs_Type(new Type("$sk_Args"))
, fFragmentProcessor_Type(new Type("fragmentProcessor"))
+ , fSkRasterPipeline_Type(new Type("SkRasterPipeline"))
, fDefined_Expression(new Defined(*fInvalid_Type)) {}
static std::vector<const Type*> static_type(const Type& t) {
@@ -333,6 +334,7 @@ public:
const std::unique_ptr<Type> fSkCaps_Type;
const std::unique_ptr<Type> fSkArgs_Type;
const std::unique_ptr<Type> fFragmentProcessor_Type;
+ const std::unique_ptr<Type> fSkRasterPipeline_Type;
// dummy expression used to mark that a variable has a value during dataflow analysis (when it
// could have several different values, or the analyzer is otherwise unable to assign it a
diff --git a/src/sksl/SkSLIRGenerator.cpp b/src/sksl/SkSLIRGenerator.cpp
index 68d9f9c0ca..815ec152cd 100644
--- a/src/sksl/SkSLIRGenerator.cpp
+++ b/src/sksl/SkSLIRGenerator.cpp
@@ -17,6 +17,7 @@
#include "ast/SkSLASTFloatLiteral.h"
#include "ast/SkSLASTIndexSuffix.h"
#include "ast/SkSLASTIntLiteral.h"
+#include "ir/SkSLAppendStage.h"
#include "ir/SkSLBinaryExpression.h"
#include "ir/SkSLBoolLiteral.h"
#include "ir/SkSLBreakStatement.h"
@@ -1971,6 +1972,91 @@ std::unique_ptr<Expression> IRGenerator::convertTypeField(int offset, const Type
return result;
}
+std::unique_ptr<Expression> IRGenerator::convertAppend(int offset,
+ const std::vector<std::unique_ptr<ASTExpression>>& args) {
+#ifndef SKSL_STANDALONE
+ if (args.size() < 2) {
+ fErrors.error(offset, "'append' requires at least two arguments");
+ return nullptr;
+ }
+ std::unique_ptr<Expression> pipeline = this->convertExpression(*args[0]);
+ if (!pipeline) {
+ return nullptr;
+ }
+ if (pipeline->fType != *fContext.fSkRasterPipeline_Type) {
+ fErrors.error(offset, "first argument of 'append' must have type 'SkRasterPipeline'");
+ return nullptr;
+ }
+ if (ASTExpression::kIdentifier_Kind != args[1]->fKind) {
+ fErrors.error(offset, "'" + args[1]->description() + "' is not a valid stage");
+ return nullptr;
+ }
+ StringFragment name = ((const ASTIdentifier&) *args[1]).fText;
+ SkRasterPipeline::StockStage stage = SkRasterPipeline::premul;
+ std::vector<std::unique_ptr<Expression>> stageArgs;
+ stageArgs.push_back(std::move(pipeline));
+ for (size_t i = 2; i < args.size(); ++i) {
+ std::unique_ptr<Expression> arg = this->convertExpression(*args[i]);
+ if (!arg) {
+ return nullptr;
+ }
+ stageArgs.push_back(std::move(arg));
+ }
+ size_t expectedArgs = 0;
+ // FIXME use a map
+ if ("premul" == name) {
+ stage = SkRasterPipeline::premul;
+ }
+ else if ("unpremul" == name) {
+ stage = SkRasterPipeline::unpremul;
+ }
+ else if ("clamp_0" == name) {
+ stage = SkRasterPipeline::clamp_0;
+ }
+ else if ("clamp_1" == name) {
+ stage = SkRasterPipeline::clamp_1;
+ }
+ else if ("matrix_4x5" == name) {
+ expectedArgs = 1;
+ stage = SkRasterPipeline::matrix_4x5;
+ if (1 == stageArgs.size() && stageArgs[0]->fType.fName != "float[20]") {
+ fErrors.error(offset, "pipeline stage '" + name + "' expected a float[20] argument");
+ return nullptr;
+ }
+ }
+ else {
+ bool found = false;
+ for (const auto& e : *fProgramElements) {
+ if (ProgramElement::kFunction_Kind == e->fKind) {
+ const FunctionDefinition& f = (const FunctionDefinition&) *e;
+ if (f.fDeclaration.fName == name) {
+ stage = SkRasterPipeline::callback;
+ std::vector<const FunctionDeclaration*> functions = { &f.fDeclaration };
+ stageArgs.emplace_back(new FunctionReference(fContext, offset, functions));
+ found = true;
+ break;
+ }
+ }
+ }
+ if (!found) {
+ fErrors.error(offset, "'" + name + "' is not a valid pipeline stage");
+ return nullptr;
+ }
+ }
+ if (args.size() != expectedArgs + 2) {
+ fErrors.error(offset, "pipeline stage '" + name + "' expected an additional argument " +
+ "count of " + to_string((int) expectedArgs) + ", but found " +
+ to_string((int) args.size() - 1));
+ return nullptr;
+ }
+ return std::unique_ptr<Expression>(new AppendStage(fContext, offset, stage,
+ std::move(stageArgs)));
+#else
+ ASSERT(false);
+ return nullptr;
+#endif
+}
+
std::unique_ptr<Expression> IRGenerator::convertSuffixExpression(
const ASTSuffixExpression& expression) {
std::unique_ptr<Expression> base = this->convertExpression(*expression.fBase);
@@ -1996,6 +2082,10 @@ std::unique_ptr<Expression> IRGenerator::convertSuffixExpression(
}
case ASTSuffix::kCall_Kind: {
auto rawArguments = &((ASTCallSuffix&) *expression.fSuffix).fArguments;
+ if (Expression::kFunctionReference_Kind == base->fKind &&
+ "append" == ((const FunctionReference&) *base).fFunctions[0]->fName) {
+ return convertAppend(expression.fOffset, *rawArguments);
+ }
std::vector<std::unique_ptr<Expression>> arguments;
for (size_t i = 0; i < rawArguments->size(); i++) {
std::unique_ptr<Expression> converted =
diff --git a/src/sksl/SkSLIRGenerator.h b/src/sksl/SkSLIRGenerator.h
index e8ff1d21c8..c78c195c0b 100644
--- a/src/sksl/SkSLIRGenerator.h
+++ b/src/sksl/SkSLIRGenerator.h
@@ -114,6 +114,8 @@ private:
std::vector<std::unique_ptr<Expression>> arguments);
int coercionCost(const Expression& expr, const Type& type);
std::unique_ptr<Expression> coerce(std::unique_ptr<Expression> expr, const Type& type);
+ std::unique_ptr<Expression> convertAppend(int offset,
+ const std::vector<std::unique_ptr<ASTExpression>>& args);
std::unique_ptr<Block> convertBlock(const ASTBlock& block);
std::unique_ptr<Statement> convertBreak(const ASTBreakStatement& b);
std::unique_ptr<Expression> convertNumberConstructor(
diff --git a/src/sksl/SkSLInterpreter.cpp b/src/sksl/SkSLInterpreter.cpp
new file mode 100644
index 0000000000..c9b7ceb44c
--- /dev/null
+++ b/src/sksl/SkSLInterpreter.cpp
@@ -0,0 +1,473 @@
+/*
+ * Copyright 2018 Google Inc.
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#ifndef SKSL_STANDALONE
+
+#include "SkSLInterpreter.h"
+#include "ir/SkSLBinaryExpression.h"
+#include "ir/SkSLExpressionStatement.h"
+#include "ir/SkSLForStatement.h"
+#include "ir/SkSLFunctionCall.h"
+#include "ir/SkSLFunctionReference.h"
+#include "ir/SkSLIfStatement.h"
+#include "ir/SkSLIndexExpression.h"
+#include "ir/SkSLPostfixExpression.h"
+#include "ir/SkSLPrefixExpression.h"
+#include "ir/SkSLProgram.h"
+#include "ir/SkSLStatement.h"
+#include "ir/SkSLTernaryExpression.h"
+#include "ir/SkSLVarDeclarations.h"
+#include "ir/SkSLVarDeclarationsStatement.h"
+#include "ir/SkSLVariableReference.h"
+#include "SkRasterPipeline.h"
+#include "../jumper/SkJumper.h"
+
+namespace SkSL {
+
+void Interpreter::run() {
+ for (const auto& e : fProgram->fElements) {
+ if (ProgramElement::kFunction_Kind == e->fKind) {
+ const FunctionDefinition& f = (const FunctionDefinition&) *e;
+ if ("appendStages" == f.fDeclaration.fName) {
+ this->run(f);
+ return;
+ }
+ }
+ }
+ ASSERT(false);
+}
+
+static int SizeOf(const Type& type) {
+ return 1;
+}
+
+void Interpreter::run(const FunctionDefinition& f) {
+ fVars.emplace_back();
+ StackIndex current = (StackIndex) fStack.size();
+ for (int i = f.fDeclaration.fParameters.size() - 1; i >= 0; --i) {
+ current -= SizeOf(f.fDeclaration.fParameters[i]->fType);
+ fVars.back()[f.fDeclaration.fParameters[i]] = current;
+ }
+ fCurrentIndex.push_back({ f.fBody.get(), 0 });
+ while (fCurrentIndex.size()) {
+ this->runStatement();
+ }
+}
+
+void Interpreter::push(Value value) {
+ fStack.push_back(value);
+}
+
+Interpreter::Value Interpreter::pop() {
+ auto iter = fStack.end() - 1;
+ Value result = *iter;
+ fStack.erase(iter);
+ return result;
+}
+
+ Interpreter::StackIndex Interpreter::stackAlloc(int count) {
+ int result = fStack.size();
+ for (int i = 0; i < count; ++i) {
+ fStack.push_back(Value((int) 0xDEADBEEF));
+ }
+ return result;
+}
+
+void Interpreter::runStatement() {
+ const Statement& stmt = *fCurrentIndex.back().fStatement;
+ const size_t index = fCurrentIndex.back().fIndex;
+ fCurrentIndex.pop_back();
+ switch (stmt.fKind) {
+ case Statement::kBlock_Kind: {
+ const Block& b = (const Block&) stmt;
+ if (!b.fStatements.size()) {
+ break;
+ }
+ ASSERT(index < b.fStatements.size());
+ if (index < b.fStatements.size() - 1) {
+ fCurrentIndex.push_back({ &b, index + 1 });
+ }
+ fCurrentIndex.push_back({ b.fStatements[index].get(), 0 });
+ break;
+ }
+ case Statement::kBreak_Kind:
+ ASSERT(index == 0);
+ abort();
+ case Statement::kContinue_Kind:
+ ASSERT(index == 0);
+ abort();
+ case Statement::kDiscard_Kind:
+ ASSERT(index == 0);
+ abort();
+ case Statement::kDo_Kind:
+ abort();
+ case Statement::kExpression_Kind:
+ ASSERT(index == 0);
+ this->evaluate(*((const ExpressionStatement&) stmt).fExpression);
+ break;
+ case Statement::kFor_Kind: {
+ ForStatement& f = (ForStatement&) stmt;
+ switch (index) {
+ case 0:
+ // initializer
+ fCurrentIndex.push_back({ &f, 1 });
+ if (f.fInitializer) {
+ fCurrentIndex.push_back({ f.fInitializer.get(), 0 });
+ }
+ break;
+ case 1:
+ // test & body
+ if (f.fTest && !evaluate(*f.fTest).fBool) {
+ break;
+ } else {
+ fCurrentIndex.push_back({ &f, 2 });
+ fCurrentIndex.push_back({ f.fStatement.get(), 0 });
+ }
+ break;
+ case 2:
+ // next
+ if (f.fNext) {
+ this->evaluate(*f.fNext);
+ }
+ fCurrentIndex.push_back({ &f, 1 });
+ break;
+ default:
+ ASSERT(false);
+ }
+ break;
+ }
+ case Statement::kGroup_Kind:
+ abort();
+ case Statement::kIf_Kind: {
+ IfStatement& i = (IfStatement&) stmt;
+ if (evaluate(*i.fTest).fBool) {
+ fCurrentIndex.push_back({ i.fIfTrue.get(), 0 });
+ } else if (i.fIfFalse) {
+ fCurrentIndex.push_back({ i.fIfFalse.get(), 0 });
+ }
+ break;
+ }
+ case Statement::kNop_Kind:
+ ASSERT(index == 0);
+ break;
+ case Statement::kReturn_Kind:
+ ASSERT(index == 0);
+ abort();
+ case Statement::kSwitch_Kind:
+ abort();
+ case Statement::kVarDeclarations_Kind:
+ ASSERT(index == 0);
+ for (const auto& decl :((const VarDeclarationsStatement&) stmt).fDeclaration->fVars) {
+ const Variable* var = ((VarDeclaration&) *decl).fVar;
+ StackIndex pos = this->stackAlloc(SizeOf(var->fType));
+ fVars.back()[var] = pos;
+ if (var->fInitialValue) {
+ fStack[pos] = this->evaluate(*var->fInitialValue);
+ }
+ }
+ break;
+ case Statement::kWhile_Kind:
+ abort();
+ default:
+ abort();
+ }
+}
+
+static Interpreter::TypeKind type_kind(const Type& type) {
+ if (type.fName == "int") {
+ return Interpreter::kInt_TypeKind;
+ } else if (type.fName == "float") {
+ return Interpreter::kFloat_TypeKind;
+ }
+ ABORT("unsupported type: %s\n", type.description().c_str());
+}
+
+Interpreter::StackIndex Interpreter::getLValue(const Expression& expr) {
+ switch (expr.fKind) {
+ case Expression::kFieldAccess_Kind:
+ break;
+ case Expression::kIndex_Kind: {
+ const IndexExpression& idx = (const IndexExpression&) expr;
+ return this->evaluate(*idx.fBase).fInt + this->evaluate(*idx.fIndex).fInt;
+ }
+ case Expression::kSwizzle_Kind:
+ break;
+ case Expression::kVariableReference_Kind:
+ ASSERT(fVars.size());
+ ASSERT(fVars.back().find(&((VariableReference&) expr).fVariable) !=
+ fVars.back().end());
+ return fVars.back()[&((VariableReference&) expr).fVariable];
+ case Expression::kTernary_Kind: {
+ const TernaryExpression& t = (const TernaryExpression&) expr;
+ return this->getLValue(this->evaluate(*t.fTest).fBool ? *t.fIfTrue : *t.fIfFalse);
+ }
+ case Expression::kTypeReference_Kind:
+ break;
+ default:
+ break;
+ }
+ ABORT("unsupported lvalue");
+}
+
+struct CallbackCtx : public SkJumper_CallbackCtx {
+ Interpreter* fInterpreter;
+ const FunctionDefinition* fFunction;
+};
+
+static void do_callback(SkJumper_CallbackCtx* raw, int activePixels) {
+ CallbackCtx& ctx = (CallbackCtx&) *raw;
+ for (int i = 0; i < activePixels; ++i) {
+ ctx.fInterpreter->push(Interpreter::Value(ctx.rgba[i * 4 + 0]));
+ ctx.fInterpreter->push(Interpreter::Value(ctx.rgba[i * 4 + 1]));
+ ctx.fInterpreter->push(Interpreter::Value(ctx.rgba[i * 4 + 2]));
+ ctx.fInterpreter->run(*ctx.fFunction);
+ ctx.read_from[i * 4 + 2] = ctx.fInterpreter->pop().fFloat;
+ ctx.read_from[i * 4 + 1] = ctx.fInterpreter->pop().fFloat;
+ ctx.read_from[i * 4 + 0] = ctx.fInterpreter->pop().fFloat;
+ }
+}
+
+void Interpreter::appendStage(const AppendStage& a) {
+ switch (a.fStage) {
+ case SkRasterPipeline::matrix_4x5: {
+ ASSERT(a.fArguments.size() == 1);
+ StackIndex transpose = evaluate(*a.fArguments[0]).fInt;
+ fPipeline.append(SkRasterPipeline::matrix_4x5, &fStack[transpose]);
+ break;
+ }
+ case SkRasterPipeline::callback: {
+ ASSERT(a.fArguments.size() == 1);
+ CallbackCtx* ctx = new CallbackCtx();
+ ctx->fInterpreter = this;
+ ctx->fn = do_callback;
+ for (const auto& e : fProgram->fElements) {
+ if (ProgramElement::kFunction_Kind == e->fKind) {
+ const FunctionDefinition& f = (const FunctionDefinition&) *e;
+ if (&f.fDeclaration ==
+ ((const FunctionReference&) *a.fArguments[0]).fFunctions[0]) {
+ ctx->fFunction = &f;
+ }
+ }
+ }
+ fPipeline.append(SkRasterPipeline::callback, ctx);
+ break;
+ }
+ default:
+ fPipeline.append(a.fStage);
+ }
+}
+
+Interpreter::Value Interpreter::call(const FunctionCall& c) {
+ abort();
+}
+
+Interpreter::Value Interpreter::evaluate(const Expression& expr) {
+ switch (expr.fKind) {
+ case Expression::kAppendStage_Kind:
+ this->appendStage((const AppendStage&) expr);
+ return Value((int) 0xDEADBEEF);
+ case Expression::kBinary_Kind: {
+ #define ARITHMETIC(op) { \
+ Value left = this->evaluate(*b.fLeft); \
+ Value right = this->evaluate(*b.fRight); \
+ switch (type_kind(b.fLeft->fType)) { \
+ case kFloat_TypeKind: \
+ return Value(left.fFloat op right.fFloat); \
+ case kInt_TypeKind: \
+ return Value(left.fInt op right.fInt); \
+ default: \
+ abort(); \
+ } \
+ }
+ #define BITWISE(op) { \
+ Value left = this->evaluate(*b.fLeft); \
+ Value right = this->evaluate(*b.fRight); \
+ switch (type_kind(b.fLeft->fType)) { \
+ case kInt_TypeKind: \
+ return Value(left.fInt op right.fInt); \
+ default: \
+ abort(); \
+ } \
+ }
+ #define LOGIC(op) { \
+ Value left = this->evaluate(*b.fLeft); \
+ Value right = this->evaluate(*b.fRight); \
+ switch (type_kind(b.fLeft->fType)) { \
+ case kFloat_TypeKind: \
+ return Value(left.fFloat op right.fFloat); \
+ case kInt_TypeKind: \
+ return Value(left.fInt op right.fInt); \
+ default: \
+ abort(); \
+ } \
+ }
+ #define COMPOUND_ARITHMETIC(op) { \
+ StackIndex left = this->getLValue(*b.fLeft); \
+ Value right = this->evaluate(*b.fRight); \
+ Value result = fStack[left]; \
+ switch (type_kind(b.fLeft->fType)) { \
+ case kFloat_TypeKind: \
+ result.fFloat op right.fFloat; \
+ break; \
+ case kInt_TypeKind: \
+ result.fInt op right.fInt; \
+ break; \
+ default: \
+ abort(); \
+ } \
+ fStack[left] = result; \
+ return result; \
+ }
+ #define COMPOUND_BITWISE(op) { \
+ StackIndex left = this->getLValue(*b.fLeft); \
+ Value right = this->evaluate(*b.fRight); \
+ Value result = fStack[left]; \
+ switch (type_kind(b.fLeft->fType)) { \
+ case kInt_TypeKind: \
+ result.fInt op right.fInt; \
+ break; \
+ default: \
+ abort(); \
+ } \
+ fStack[left] = result; \
+ return result; \
+ }
+ const BinaryExpression& b = (const BinaryExpression&) expr;
+ switch (b.fOperator) {
+ case Token::PLUS: ARITHMETIC(+)
+ case Token::MINUS: ARITHMETIC(-)
+ case Token::STAR: ARITHMETIC(*)
+ case Token::SLASH: ARITHMETIC(/)
+ case Token::BITWISEAND: BITWISE(&)
+ case Token::BITWISEOR: BITWISE(|)
+ case Token::BITWISEXOR: BITWISE(^)
+ case Token::LT: LOGIC(<)
+ case Token::GT: LOGIC(>)
+ case Token::LTEQ: LOGIC(<=)
+ case Token::GTEQ: LOGIC(>=)
+ case Token::LOGICALAND: {
+ Value result = this->evaluate(*b.fLeft);
+ if (result.fBool) {
+ result = this->evaluate(*b.fRight);
+ }
+ return result;
+ }
+ case Token::LOGICALOR: {
+ Value result = this->evaluate(*b.fLeft);
+ if (!result.fBool) {
+ result = this->evaluate(*b.fRight);
+ }
+ return result;
+ }
+ case Token::EQ: {
+ StackIndex left = this->getLValue(*b.fLeft);
+ Value right = this->evaluate(*b.fRight);
+ fStack[left] = right;
+ return right;
+ }
+ case Token::PLUSEQ: COMPOUND_ARITHMETIC(+=)
+ case Token::MINUSEQ: COMPOUND_ARITHMETIC(-=)
+ case Token::STAREQ: COMPOUND_ARITHMETIC(*=)
+ case Token::SLASHEQ: COMPOUND_ARITHMETIC(/=)
+ case Token::BITWISEANDEQ: COMPOUND_BITWISE(&=)
+ case Token::BITWISEOREQ: COMPOUND_BITWISE(|=)
+ case Token::BITWISEXOREQ: COMPOUND_BITWISE(^=)
+ default:
+ ABORT("unsupported operator: %s\n", expr.description().c_str());
+ }
+ break;
+ }
+ case Expression::kBoolLiteral_Kind:
+ return Value(((const BoolLiteral&) expr).fValue);
+ case Expression::kConstructor_Kind:
+ break;
+ case Expression::kIntLiteral_Kind:
+ return Value((int) ((const IntLiteral&) expr).fValue);
+ case Expression::kFieldAccess_Kind:
+ break;
+ case Expression::kFloatLiteral_Kind:
+ return Value((float) ((const FloatLiteral&) expr).fValue);
+ case Expression::kFunctionCall_Kind:
+ return this->call((const FunctionCall&) expr);
+ case Expression::kIndex_Kind: {
+ const IndexExpression& idx = (const IndexExpression&) expr;
+ StackIndex pos = this->evaluate(*idx.fBase).fInt +
+ this->evaluate(*idx.fIndex).fInt;
+ return fStack[pos];
+ }
+ case Expression::kPrefix_Kind: {
+ const PrefixExpression& p = (const PrefixExpression&) expr;
+ switch (p.fOperator) {
+ case Token::MINUS: {
+ Value base = this->evaluate(*p.fOperand);
+ switch (type_kind(p.fType)) {
+ case kFloat_TypeKind:
+ return Value(-base.fFloat);
+ case kInt_TypeKind:
+ return Value(-base.fInt);
+ default:
+ abort();
+ }
+ }
+ case Token::LOGICALNOT: {
+ Value base = this->evaluate(*p.fOperand);
+ return Value(!base.fBool);
+ }
+ default:
+ abort();
+ }
+ }
+ case Expression::kPostfix_Kind: {
+ const PostfixExpression& p = (const PostfixExpression&) expr;
+ StackIndex lvalue = this->getLValue(*p.fOperand);
+ Value result = fStack[lvalue];
+ switch (type_kind(p.fType)) {
+ case kFloat_TypeKind:
+ if (Token::PLUSPLUS == p.fOperator) {
+ ++fStack[lvalue].fFloat;
+ } else {
+ ASSERT(Token::MINUSMINUS == p.fOperator);
+ --fStack[lvalue].fFloat;
+ }
+ break;
+ case kInt_TypeKind:
+ if (Token::PLUSPLUS == p.fOperator) {
+ ++fStack[lvalue].fInt;
+ } else {
+ ASSERT(Token::MINUSMINUS == p.fOperator);
+ --fStack[lvalue].fInt;
+ }
+ break;
+ default:
+ abort();
+ }
+ return result;
+ }
+ case Expression::kSetting_Kind:
+ break;
+ case Expression::kSwizzle_Kind:
+ break;
+ case Expression::kVariableReference_Kind:
+ ASSERT(fVars.size());
+ ASSERT(fVars.back().find(&((VariableReference&) expr).fVariable) !=
+ fVars.back().end());
+ return fStack[fVars.back()[&((VariableReference&) expr).fVariable]];
+ case Expression::kTernary_Kind: {
+ const TernaryExpression& t = (const TernaryExpression&) expr;
+ return this->evaluate(this->evaluate(*t.fTest).fBool ? *t.fIfTrue : *t.fIfFalse);
+ }
+ case Expression::kTypeReference_Kind:
+ break;
+ default:
+ break;
+ }
+ ABORT("unsupported expression: %s\n", expr.description().c_str());
+}
+
+} // namespace
+
+#endif
diff --git a/src/sksl/SkSLInterpreter.h b/src/sksl/SkSLInterpreter.h
new file mode 100644
index 0000000000..93f32a1617
--- /dev/null
+++ b/src/sksl/SkSLInterpreter.h
@@ -0,0 +1,89 @@
+/*
+ * Copyright 2018 Google Inc.
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#ifndef SKSL_INTERPRETER
+#define SKSL_INTERPRETER
+
+#include "ir/SkSLAppendStage.h"
+#include "ir/SkSLExpression.h"
+#include "ir/SkSLFunctionCall.h"
+#include "ir/SkSLFunctionDefinition.h"
+#include "ir/SkSLProgram.h"
+#include "ir/SkSLStatement.h"
+
+#include <stack>
+
+class SkRasterPipeline;
+
+namespace SkSL {
+
+class Interpreter {
+ typedef int StackIndex;
+
+ struct StatementIndex {
+ const Statement* fStatement;
+ size_t fIndex;
+ };
+
+public:
+ union Value {
+ Value(float f)
+ : fFloat(f) {}
+
+ Value(int i)
+ : fInt(i) {}
+
+ Value(bool b)
+ : fBool(b) {}
+
+ float fFloat;
+ int fInt;
+ bool fBool;
+ };
+
+ enum TypeKind {
+ kFloat_TypeKind,
+ kInt_TypeKind,
+ kBool_TypeKind
+ };
+
+ Interpreter(std::unique_ptr<Program> program, SkRasterPipeline* pipeline, std::vector<Value>* stack)
+ : fProgram(std::move(program))
+ , fPipeline(*pipeline)
+ , fStack(*stack) {}
+
+ void run();
+
+ void run(const FunctionDefinition& f);
+
+ void push(Value value);
+
+ Value pop();
+
+ StackIndex stackAlloc(int count);
+
+ void runStatement();
+
+ StackIndex getLValue(const Expression& expr);
+
+ Value call(const FunctionCall& c);
+
+ void appendStage(const AppendStage& c);
+
+ Value evaluate(const Expression& expr);
+
+private:
+ std::unique_ptr<Program> fProgram;
+ SkRasterPipeline& fPipeline;
+ std::vector<StatementIndex> fCurrentIndex;
+ std::vector<std::map<const Variable*, StackIndex>> fVars;
+ std::vector<Value> &fStack;
+};
+
+} // namespace
+
+#endif
diff --git a/src/sksl/SkSLJIT.cpp b/src/sksl/SkSLJIT.cpp
new file mode 100644
index 0000000000..06b6e2c94b
--- /dev/null
+++ b/src/sksl/SkSLJIT.cpp
@@ -0,0 +1,1747 @@
+/*
+ * Copyright 2018 Google Inc.
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#ifndef SKSL_STANDALONE
+
+#ifdef SK_LLVM_AVAILABLE
+
+#include "SkSLJIT.h"
+
+#include "SkCpu.h"
+#include "SkRasterPipeline.h"
+#include "../jumper/SkJumper.h"
+#include "ir/SkSLExpressionStatement.h"
+#include "ir/SkSLFunctionCall.h"
+#include "ir/SkSLFunctionReference.h"
+#include "ir/SkSLIndexExpression.h"
+#include "ir/SkSLProgram.h"
+#include "llvm/ExecutionEngine/RTDyldMemoryManager.h"
+
+static constexpr int MAX_VECTOR_COUNT = 16;
+
+extern "C" void sksl_pipeline_append(SkRasterPipeline* p, int stage, void* ctx) {
+ p->append((SkRasterPipeline::StockStage) stage, ctx);
+}
+
+#define PTR_SIZE sizeof(void*)
+
+extern "C" void sksl_pipeline_append_callback(SkRasterPipeline* p, void* fn) {
+ p->append(fn, nullptr);
+}
+
+extern "C" void sksl_debug_print(float f) {
+ printf("Debug: %f\n", f);
+}
+
+namespace SkSL {
+
+static constexpr int STAGE_PARAM_COUNT = 12;
+
+static bool ends_with_branch(const Statement& stmt) {
+ switch (stmt.fKind) {
+ case Statement::kBlock_Kind: {
+ const Block& b = (const Block&) stmt;
+ if (b.fStatements.size()) {
+ return ends_with_branch(*b.fStatements.back());
+ }
+ return false;
+ }
+ case Statement::kBreak_Kind: // fall through
+ case Statement::kContinue_Kind: // fall through
+ case Statement::kReturn_Kind: // fall through
+ return true;
+ default:
+ return false;
+ }
+}
+
+JIT::JIT(Compiler* compiler)
+: fCompiler(*compiler) {
+ LLVMInitializeNativeTarget();
+ LLVMInitializeNativeAsmPrinter();
+ LLVMLinkInMCJIT();
+ ASSERT(!SkCpu::Supports(SkCpu::SKX)); // not yet supported
+ if (SkCpu::Supports(SkCpu::HSW)) {
+ fVectorCount = 8;
+ fCPU = "haswell";
+ } else if (SkCpu::Supports(SkCpu::AVX)) {
+ fVectorCount = 8;
+ fCPU = "ivybridge";
+ } else {
+ fVectorCount = 4;
+ fCPU = nullptr;
+ }
+ fContext = LLVMContextCreate();
+ fVoidType = LLVMVoidTypeInContext(fContext);
+ fInt1Type = LLVMInt1TypeInContext(fContext);
+ fInt8Type = LLVMInt8TypeInContext(fContext);
+ fInt8PtrType = LLVMPointerType(fInt8Type, 0);
+ fInt32Type = LLVMInt32TypeInContext(fContext);
+ fInt64Type = LLVMInt64TypeInContext(fContext);
+ fSizeTType = LLVMInt64TypeInContext(fContext);
+ fInt32VectorType = LLVMVectorType(fInt32Type, fVectorCount);
+ fInt32Vector2Type = LLVMVectorType(fInt32Type, 2);
+ fInt32Vector3Type = LLVMVectorType(fInt32Type, 3);
+ fInt32Vector4Type = LLVMVectorType(fInt32Type, 4);
+ fFloat32Type = LLVMFloatTypeInContext(fContext);
+ fFloat32VectorType = LLVMVectorType(fFloat32Type, fVectorCount);
+ fFloat32Vector2Type = LLVMVectorType(fFloat32Type, 2);
+ fFloat32Vector3Type = LLVMVectorType(fFloat32Type, 3);
+ fFloat32Vector4Type = LLVMVectorType(fFloat32Type, 4);
+}
+
+JIT::~JIT() {
+ LLVMOrcDisposeInstance(fJITStack);
+ LLVMContextDispose(fContext);
+}
+
+void JIT::addBuiltinFunction(const char* ourName, const char* realName, LLVMTypeRef returnType,
+ std::vector<LLVMTypeRef> parameters) {
+ for (const auto& pair : *fProgram->fSymbols) {
+ if (Symbol::kFunctionDeclaration_Kind == pair.second->fKind) {
+ const FunctionDeclaration& f = (const FunctionDeclaration&) *pair.second;
+ if (pair.first != ourName || returnType != this->getType(f.fReturnType) ||
+ parameters.size() != f.fParameters.size()) {
+ continue;
+ }
+ for (size_t i = 0; i < parameters.size(); ++i) {
+ if (parameters[i] != this->getType(f.fParameters[i]->fType)) {
+ goto next;
+ }
+ }
+ fFunctions[&f] = LLVMAddFunction(fModule, realName, LLVMFunctionType(returnType,
+ parameters.data(),
+ parameters.size(),
+ false));
+ }
+ next:;
+ }
+}
+
+void JIT::loadBuiltinFunctions() {
+ this->addBuiltinFunction("abs", "fabs", fFloat32Type, { fFloat32Type });
+ this->addBuiltinFunction("sin", "sinf", fFloat32Type, { fFloat32Type });
+ this->addBuiltinFunction("cos", "cosf", fFloat32Type, { fFloat32Type });
+ this->addBuiltinFunction("tan", "tanf", fFloat32Type, { fFloat32Type });
+ this->addBuiltinFunction("sqrt", "sqrtf", fFloat32Type, { fFloat32Type });
+ this->addBuiltinFunction("print", "sksl_debug_print", fVoidType, { fFloat32Type });
+}
+
+uint64_t JIT::resolveSymbol(const char* name, JIT* jit) {
+ LLVMOrcTargetAddress result;
+ if (!LLVMOrcGetSymbolAddress(jit->fJITStack, &result, name)) {
+ if (!strcmp(name, "_sksl_pipeline_append")) {
+ result = (uint64_t) &sksl_pipeline_append;
+ } else if (!strcmp(name, "_sksl_pipeline_append_callback")) {
+ result = (uint64_t) &sksl_pipeline_append_callback;
+ } else if (!strcmp(name, "_sksl_debug_print")) {
+ result = (uint64_t) &sksl_debug_print;
+ } else {
+ result = llvm::RTDyldMemoryManager::getSymbolAddressInProcess(name);
+ }
+ }
+ ASSERT(result);
+ return result;
+}
+
+LLVMValueRef JIT::compileFunctionCall(LLVMBuilderRef builder, const FunctionCall& fc) {
+ LLVMValueRef func = fFunctions[&fc.fFunction];
+ ASSERT(func);
+ std::vector<LLVMValueRef> parameters;
+ for (const auto& a : fc.fArguments) {
+ parameters.push_back(this->compileExpression(builder, *a));
+ }
+ return LLVMBuildCall(builder, func, parameters.data(), parameters.size(), "");
+}
+
+LLVMTypeRef JIT::getType(const Type& type) {
+ switch (type.kind()) {
+ case Type::kOther_Kind:
+ if (type.name() == "void") {
+ return fVoidType;
+ }
+ ASSERT(type.name() == "SkRasterPipeline");
+ return fInt8PtrType;
+ case Type::kScalar_Kind:
+ if (type.isSigned() || type.isUnsigned()) {
+ return fInt32Type;
+ }
+ if (type.isUnsigned()) {
+ return fInt32Type;
+ }
+ if (type.isFloat()) {
+ return fFloat32Type;
+ }
+ ASSERT(type.name() == "bool");
+ return fInt1Type;
+ case Type::kArray_Kind:
+ return LLVMPointerType(this->getType(type.componentType()), 0);
+ case Type::kVector_Kind:
+ if (type.name() == "float2" || type.name() == "half2") {
+ return fFloat32Vector2Type;
+ }
+ if (type.name() == "float3" || type.name() == "half3") {
+ return fFloat32Vector3Type;
+ }
+ if (type.name() == "float4" || type.name() == "half4") {
+ return fFloat32Vector4Type;
+ }
+ if (type.name() == "int2" || type.name() == "short2") {
+ return fInt32Vector2Type;
+ }
+ if (type.name() == "int3" || type.name() == "short3") {
+ return fInt32Vector3Type;
+ }
+ if (type.name() == "int4" || type.name() == "short4") {
+ return fInt32Vector4Type;
+ }
+ // fall through
+ default:
+ ABORT("unsupported type");
+ }
+}
+
+void JIT::setBlock(LLVMBuilderRef builder, LLVMBasicBlockRef block) {
+ fCurrentBlock = block;
+ LLVMPositionBuilderAtEnd(builder, block);
+}
+
+std::unique_ptr<JIT::LValue> JIT::getLValue(LLVMBuilderRef builder, const Expression& expr) {
+ switch (expr.fKind) {
+ case Expression::kVariableReference_Kind: {
+ class PointerLValue : public LValue {
+ public:
+ PointerLValue(LLVMValueRef ptr)
+ : fPointer(ptr) {}
+
+ LLVMValueRef load(LLVMBuilderRef builder) override {
+ return LLVMBuildLoad(builder, fPointer, "lvalue load");
+ }
+
+ void store(LLVMBuilderRef builder, LLVMValueRef value) override {
+ LLVMBuildStore(builder, value, fPointer);
+ }
+
+ private:
+ LLVMValueRef fPointer;
+ };
+ const Variable* var = &((VariableReference&) expr).fVariable;
+ if (var->fStorage == Variable::kParameter_Storage &&
+ !(var->fModifiers.fFlags & Modifiers::kOut_Flag) &&
+ fPromotedParameters.find(var) == fPromotedParameters.end()) {
+ // promote parameter to variable
+ fPromotedParameters.insert(var);
+ LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
+ LLVMValueRef alloca = LLVMBuildAlloca(builder, this->getType(var->fType),
+ String(var->fName).c_str());
+ LLVMBuildStore(builder, fVariables[var], alloca);
+ LLVMPositionBuilderAtEnd(builder, fCurrentBlock);
+ fVariables[var] = alloca;
+ }
+ LLVMValueRef ptr = fVariables[var];
+ return std::unique_ptr<LValue>(new PointerLValue(ptr));
+ }
+ case Expression::kTernary_Kind: {
+ class TernaryLValue : public LValue {
+ public:
+ TernaryLValue(JIT* jit, LLVMValueRef test, std::unique_ptr<LValue> ifTrue,
+ std::unique_ptr<LValue> ifFalse)
+ : fJIT(*jit)
+ , fTest(test)
+ , fIfTrue(std::move(ifTrue))
+ , fIfFalse(std::move(ifFalse)) {}
+
+ LLVMValueRef load(LLVMBuilderRef builder) override {
+ LLVMBasicBlockRef trueBlock = LLVMAppendBasicBlockInContext(
+ fJIT.fContext,
+ fJIT.fCurrentFunction,
+ "true ? ...");
+ LLVMBasicBlockRef falseBlock = LLVMAppendBasicBlockInContext(
+ fJIT.fContext,
+ fJIT.fCurrentFunction,
+ "false ? ...");
+ LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fJIT.fContext,
+ fJIT.fCurrentFunction,
+ "ternary merge");
+ LLVMBuildCondBr(builder, fTest, trueBlock, falseBlock);
+ fJIT.setBlock(builder, trueBlock);
+ LLVMValueRef ifTrue = fIfTrue->load(builder);
+ LLVMBuildBr(builder, merge);
+ fJIT.setBlock(builder, falseBlock);
+ LLVMValueRef ifFalse = fIfTrue->load(builder);
+ LLVMBuildBr(builder, merge);
+ fJIT.setBlock(builder, merge);
+ LLVMTypeRef type = LLVMPointerType(LLVMTypeOf(ifTrue), 0);
+ LLVMValueRef phi = LLVMBuildPhi(builder, type, "?");
+ LLVMValueRef incomingValues[2] = { ifTrue, ifFalse };
+ LLVMBasicBlockRef incomingBlocks[2] = { trueBlock, falseBlock };
+ LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2);
+ return phi;
+ }
+
+ void store(LLVMBuilderRef builder, LLVMValueRef value) override {
+ LLVMBasicBlockRef trueBlock = LLVMAppendBasicBlockInContext(
+ fJIT.fContext,
+ fJIT.fCurrentFunction,
+ "true ? ...");
+ LLVMBasicBlockRef falseBlock = LLVMAppendBasicBlockInContext(
+ fJIT.fContext,
+ fJIT.fCurrentFunction,
+ "false ? ...");
+ LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fJIT.fContext,
+ fJIT.fCurrentFunction,
+ "ternary merge");
+ LLVMBuildCondBr(builder, fTest, trueBlock, falseBlock);
+ fJIT.setBlock(builder, trueBlock);
+ fIfTrue->store(builder, value);
+ LLVMBuildBr(builder, merge);
+ fJIT.setBlock(builder, falseBlock);
+ fIfTrue->store(builder, value);
+ LLVMBuildBr(builder, merge);
+ fJIT.setBlock(builder, merge);
+ }
+
+ private:
+ JIT& fJIT;
+ LLVMValueRef fTest;
+ std::unique_ptr<LValue> fIfTrue;
+ std::unique_ptr<LValue> fIfFalse;
+ };
+ const TernaryExpression& t = (const TernaryExpression&) expr;
+ LLVMValueRef test = this->compileExpression(builder, *t.fTest);
+ return std::unique_ptr<LValue>(new TernaryLValue(this,
+ test,
+ this->getLValue(builder,
+ *t.fIfTrue),
+ this->getLValue(builder,
+ *t.fIfFalse)));
+ }
+ case Expression::kSwizzle_Kind: {
+ class SwizzleLValue : public LValue {
+ public:
+ SwizzleLValue(JIT* jit, LLVMTypeRef type, std::unique_ptr<LValue> base,
+ std::vector<int> components)
+ : fJIT(*jit)
+ , fType(type)
+ , fBase(std::move(base))
+ , fComponents(components) {}
+
+ LLVMValueRef load(LLVMBuilderRef builder) override {
+ LLVMValueRef base = fBase->load(builder);
+ if (fComponents.size() > 1) {
+ LLVMValueRef result = LLVMGetUndef(fType);
+ for (size_t i = 0; i < fComponents.size(); ++i) {
+ LLVMValueRef element = LLVMBuildExtractElement(
+ builder,
+ base,
+ LLVMConstInt(fJIT.fInt32Type,
+ fComponents[i],
+ false),
+ "swizzle extract");
+ result = LLVMBuildInsertElement(builder, result, element,
+ LLVMConstInt(fJIT.fInt32Type, i, false),
+ "swizzle insert");
+ }
+ return result;
+ }
+ ASSERT(fComponents.size() == 1);
+ return LLVMBuildExtractElement(builder, base,
+ LLVMConstInt(fJIT.fInt32Type,
+ fComponents[0],
+ false),
+ "swizzle extract");
+ }
+
+ void store(LLVMBuilderRef builder, LLVMValueRef value) override {
+ LLVMValueRef result = fBase->load(builder);
+ if (fComponents.size() > 1) {
+ for (size_t i = 0; i < fComponents.size(); ++i) {
+ LLVMValueRef element = LLVMBuildExtractElement(builder, value,
+ LLVMConstInt(
+ fJIT.fInt32Type,
+ i,
+ false),
+ "swizzle extract");
+ result = LLVMBuildInsertElement(builder, result, element,
+ LLVMConstInt(fJIT.fInt32Type,
+ fComponents[i],
+ false),
+ "swizzle insert");
+ }
+ } else {
+ result = LLVMBuildInsertElement(builder, result, value,
+ LLVMConstInt(fJIT.fInt32Type,
+ fComponents[0],
+ false),
+ "swizzle insert");
+ }
+ fBase->store(builder, result);
+ }
+
+ private:
+ JIT& fJIT;
+ LLVMTypeRef fType;
+ std::unique_ptr<LValue> fBase;
+ std::vector<int> fComponents;
+ };
+ const Swizzle& s = (const Swizzle&) expr;
+ return std::unique_ptr<LValue>(new SwizzleLValue(this, this->getType(s.fType),
+ this->getLValue(builder, *s.fBase),
+ s.fComponents));
+ }
+ default:
+ ABORT("unsupported lvalue");
+ }
+}
+
+JIT::TypeKind JIT::typeKind(const Type& type) {
+ if (type.kind() == Type::kVector_Kind) {
+ return this->typeKind(type.componentType());
+ }
+ if (type.fName == "int" || type.fName == "short") {
+ return JIT::kInt_TypeKind;
+ } else if (type.fName == "uint" || type.fName == "ushort") {
+ return JIT::kUInt_TypeKind;
+ } else if (type.fName == "float" || type.fName == "double") {
+ return JIT::kFloat_TypeKind;
+ }
+ ABORT("unsupported type: %s\n", type.description().c_str());
+}
+
+void JIT::vectorize(LLVMBuilderRef builder, LLVMValueRef* value, int columns) {
+ LLVMValueRef result = LLVMGetUndef(LLVMVectorType(LLVMTypeOf(*value), columns));
+ for (int i = 0; i < columns; ++i) {
+ result = LLVMBuildInsertElement(builder,
+ result,
+ *value,
+ LLVMConstInt(fInt32Type, i, false),
+ "vectorize");
+ }
+ *value = result;
+}
+
+void JIT::vectorize(LLVMBuilderRef builder, const BinaryExpression& b, LLVMValueRef* left,
+ LLVMValueRef* right) {
+ if (b.fLeft->fType.kind() == Type::kScalar_Kind &&
+ b.fRight->fType.kind() == Type::kVector_Kind) {
+ this->vectorize(builder, left, b.fRight->fType.columns());
+ } else if (b.fLeft->fType.kind() == Type::kVector_Kind &&
+ b.fRight->fType.kind() == Type::kScalar_Kind) {
+ this->vectorize(builder, right, b.fLeft->fType.columns());
+ }
+}
+
+
+LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression& b) {
+ #define BINARY(SFunc, UFunc, FFunc) { \
+ LLVMValueRef left = this->compileExpression(builder, *b.fLeft); \
+ LLVMValueRef right = this->compileExpression(builder, *b.fRight); \
+ this->vectorize(builder, b, &left, &right); \
+ switch (this->typeKind(b.fLeft->fType)) { \
+ case kInt_TypeKind: \
+ return SFunc(builder, left, right, "binary"); \
+ case kUInt_TypeKind: \
+ return UFunc(builder, left, right, "binary"); \
+ case kFloat_TypeKind: \
+ return FFunc(builder, left, right, "binary"); \
+ default: \
+ ABORT("unsupported typeKind"); \
+ } \
+ }
+ #define COMPOUND(SFunc, UFunc, FFunc) { \
+ std::unique_ptr<LValue> lvalue = this->getLValue(builder, *b.fLeft); \
+ LLVMValueRef left = lvalue->load(builder); \
+ LLVMValueRef right = this->compileExpression(builder, *b.fRight); \
+ this->vectorize(builder, b, &left, &right); \
+ LLVMValueRef result; \
+ switch (this->typeKind(b.fLeft->fType)) { \
+ case kInt_TypeKind: \
+ result = SFunc(builder, left, right, "binary"); \
+ break; \
+ case kUInt_TypeKind: \
+ result = UFunc(builder, left, right, "binary"); \
+ break; \
+ case kFloat_TypeKind: \
+ result = FFunc(builder, left, right, "binary"); \
+ break; \
+ default: \
+ ABORT("unsupported typeKind"); \
+ } \
+ lvalue->store(builder, result); \
+ return result; \
+ }
+ #define COMPARE(SFunc, SOp, UFunc, UOp, FFunc, FOp) { \
+ LLVMValueRef left = this->compileExpression(builder, *b.fLeft); \
+ LLVMValueRef right = this->compileExpression(builder, *b.fRight); \
+ this->vectorize(builder, b, &left, &right); \
+ switch (this->typeKind(b.fLeft->fType)) { \
+ case kInt_TypeKind: \
+ return SFunc(builder, SOp, left, right, "binary"); \
+ case kUInt_TypeKind: \
+ return UFunc(builder, UOp, left, right, "binary"); \
+ case kFloat_TypeKind: \
+ return FFunc(builder, FOp, left, right, "binary"); \
+ default: \
+ ABORT("unsupported typeKind"); \
+ } \
+ }
+ switch (b.fOperator) {
+ case Token::EQ: {
+ std::unique_ptr<LValue> lvalue = this->getLValue(builder, *b.fLeft);
+ LLVMValueRef result = this->compileExpression(builder, *b.fRight);
+ lvalue->store(builder, result);
+ return result;
+ }
+ case Token::PLUS:
+ BINARY(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd);
+ case Token::MINUS:
+ BINARY(LLVMBuildSub, LLVMBuildSub, LLVMBuildFSub);
+ case Token::STAR:
+ BINARY(LLVMBuildMul, LLVMBuildMul, LLVMBuildFMul);
+ case Token::SLASH:
+ BINARY(LLVMBuildSDiv, LLVMBuildUDiv, LLVMBuildFDiv);
+ case Token::PERCENT:
+ BINARY(LLVMBuildSRem, LLVMBuildURem, LLVMBuildSRem);
+ case Token::BITWISEAND:
+ BINARY(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd);
+ case Token::BITWISEOR:
+ BINARY(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr);
+ case Token::PLUSEQ:
+ COMPOUND(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd);
+ case Token::MINUSEQ:
+ COMPOUND(LLVMBuildSub, LLVMBuildSub, LLVMBuildFSub);
+ case Token::STAREQ:
+ COMPOUND(LLVMBuildMul, LLVMBuildMul, LLVMBuildFMul);
+ case Token::SLASHEQ:
+ COMPOUND(LLVMBuildSDiv, LLVMBuildUDiv, LLVMBuildFDiv);
+ case Token::BITWISEANDEQ:
+ COMPOUND(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd);
+ case Token::BITWISEOREQ:
+ COMPOUND(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr);
+ case Token::EQEQ:
+ COMPARE(LLVMBuildICmp, LLVMIntEQ,
+ LLVMBuildICmp, LLVMIntEQ,
+ LLVMBuildFCmp, LLVMRealOEQ);
+ case Token::NEQ:
+ COMPARE(LLVMBuildICmp, LLVMIntNE,
+ LLVMBuildICmp, LLVMIntNE,
+ LLVMBuildFCmp, LLVMRealONE);
+ case Token::LT:
+ COMPARE(LLVMBuildICmp, LLVMIntSLT,
+ LLVMBuildICmp, LLVMIntULT,
+ LLVMBuildFCmp, LLVMRealOLT);
+ case Token::LTEQ:
+ COMPARE(LLVMBuildICmp, LLVMIntSLE,
+ LLVMBuildICmp, LLVMIntULE,
+ LLVMBuildFCmp, LLVMRealOLE);
+ case Token::GT:
+ COMPARE(LLVMBuildICmp, LLVMIntSGT,
+ LLVMBuildICmp, LLVMIntUGT,
+ LLVMBuildFCmp, LLVMRealOGT);
+ case Token::GTEQ:
+ COMPARE(LLVMBuildICmp, LLVMIntSGE,
+ LLVMBuildICmp, LLVMIntUGE,
+ LLVMBuildFCmp, LLVMRealOGE);
+ case Token::LOGICALAND: {
+ LLVMValueRef left = this->compileExpression(builder, *b.fLeft);
+ LLVMBasicBlockRef ifFalse = fCurrentBlock;
+ LLVMBasicBlockRef ifTrue = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
+ "true && ...");
+ LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
+ "&& merge");
+ LLVMBuildCondBr(builder, left, ifTrue, merge);
+ this->setBlock(builder, ifTrue);
+ LLVMValueRef right = this->compileExpression(builder, *b.fRight);
+ LLVMBuildBr(builder, merge);
+ this->setBlock(builder, merge);
+ LLVMValueRef phi = LLVMBuildPhi(builder, fInt1Type, "&&");
+ LLVMValueRef incomingValues[2] = { right, LLVMConstInt(fInt1Type, 0, false) };
+ LLVMBasicBlockRef incomingBlocks[2] = { ifTrue, ifFalse };
+ LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2);
+ return phi;
+ }
+ case Token::LOGICALOR: {
+ LLVMValueRef left = this->compileExpression(builder, *b.fLeft);
+ LLVMBasicBlockRef ifTrue = fCurrentBlock;
+ LLVMBasicBlockRef ifFalse = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
+ "false || ...");
+ LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
+ "|| merge");
+ LLVMBuildCondBr(builder, left, merge, ifFalse);
+ this->setBlock(builder, ifFalse);
+ LLVMValueRef right = this->compileExpression(builder, *b.fRight);
+ LLVMBuildBr(builder, merge);
+ this->setBlock(builder, merge);
+ LLVMValueRef phi = LLVMBuildPhi(builder, fInt1Type, "||");
+ LLVMValueRef incomingValues[2] = { right, LLVMConstInt(fInt1Type, 1, false) };
+ LLVMBasicBlockRef incomingBlocks[2] = { ifFalse, ifTrue };
+ LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2);
+ return phi;
+ }
+ default:
+ ABORT("unsupported binary operator");
+ }
+}
+
+LLVMValueRef JIT::compileIndex(LLVMBuilderRef builder, const IndexExpression& idx) {
+ LLVMValueRef base = this->compileExpression(builder, *idx.fBase);
+ LLVMValueRef index = this->compileExpression(builder, *idx.fIndex);
+ LLVMValueRef ptr = LLVMBuildGEP(builder, base, &index, 1, "index ptr");
+ return LLVMBuildLoad(builder, ptr, "index load");
+}
+
+LLVMValueRef JIT::compilePostfix(LLVMBuilderRef builder, const PostfixExpression& p) {
+ std::unique_ptr<LValue> lvalue = this->getLValue(builder, *p.fOperand);
+ LLVMValueRef result = lvalue->load(builder);
+ LLVMValueRef mod;
+ LLVMValueRef one = LLVMConstInt(this->getType(p.fType), 1, false);
+ switch (p.fOperator) {
+ case Token::PLUSPLUS:
+ switch (this->typeKind(p.fType)) {
+ case kInt_TypeKind: // fall through
+ case kUInt_TypeKind:
+ mod = LLVMBuildAdd(builder, result, one, "++");
+ break;
+ case kFloat_TypeKind:
+ mod = LLVMBuildFAdd(builder, result, one, "++");
+ break;
+ default:
+ ABORT("unsupported typeKind");
+ }
+ break;
+ case Token::MINUSMINUS:
+ switch (this->typeKind(p.fType)) {
+ case kInt_TypeKind: // fall through
+ case kUInt_TypeKind:
+ mod = LLVMBuildSub(builder, result, one, "--");
+ break;
+ case kFloat_TypeKind:
+ mod = LLVMBuildFSub(builder, result, one, "--");
+ break;
+ default:
+ ABORT("unsupported typeKind");
+ }
+ break;
+ default:
+ ABORT("unsupported postfix op");
+ }
+ lvalue->store(builder, mod);
+ return result;
+}
+
+LLVMValueRef JIT::compilePrefix(LLVMBuilderRef builder, const PrefixExpression& p) {
+ LLVMValueRef one = LLVMConstInt(this->getType(p.fType), 1, false);
+ if (Token::LOGICALNOT == p.fOperator) {
+ LLVMValueRef base = this->compileExpression(builder, *p.fOperand);
+ return LLVMBuildXor(builder, base, one, "!");
+ }
+ if (Token::MINUS == p.fOperator) {
+ LLVMValueRef base = this->compileExpression(builder, *p.fOperand);
+ return LLVMBuildSub(builder, LLVMConstInt(this->getType(p.fType), 0, false), base, "-");
+ }
+ std::unique_ptr<LValue> lvalue = this->getLValue(builder, *p.fOperand);
+ LLVMValueRef raw = lvalue->load(builder);
+ LLVMValueRef result;
+ switch (p.fOperator) {
+ case Token::PLUSPLUS:
+ switch (this->typeKind(p.fType)) {
+ case kInt_TypeKind: // fall through
+ case kUInt_TypeKind:
+ result = LLVMBuildAdd(builder, raw, one, "++");
+ break;
+ case kFloat_TypeKind:
+ result = LLVMBuildFAdd(builder, raw, one, "++");
+ break;
+ default:
+ ABORT("unsupported typeKind");
+ }
+ break;
+ case Token::MINUSMINUS:
+ switch (this->typeKind(p.fType)) {
+ case kInt_TypeKind: // fall through
+ case kUInt_TypeKind:
+ result = LLVMBuildSub(builder, raw, one, "--");
+ break;
+ case kFloat_TypeKind:
+ result = LLVMBuildFSub(builder, raw, one, "--");
+ break;
+ default:
+ ABORT("unsupported typeKind");
+ }
+ break;
+ default:
+ ABORT("unsupported prefix op");
+ }
+ lvalue->store(builder, result);
+ return result;
+}
+
+LLVMValueRef JIT::compileVariableReference(LLVMBuilderRef builder, const VariableReference& v) {
+ const Variable& var = v.fVariable;
+ if (Variable::kParameter_Storage == var.fStorage &&
+ !(var.fModifiers.fFlags & Modifiers::kOut_Flag) &&
+ fPromotedParameters.find(&var) == fPromotedParameters.end()) {
+ return fVariables[&var];
+ }
+ return LLVMBuildLoad(builder, fVariables[&var], String(var.fName).c_str());
+}
+
+void JIT::appendStage(LLVMBuilderRef builder, const AppendStage& a) {
+ ASSERT(a.fArguments.size() >= 1);
+ ASSERT(a.fArguments[0]->fType == *fCompiler.context().fSkRasterPipeline_Type);
+ LLVMValueRef pipeline = this->compileExpression(builder, *a.fArguments[0]);
+ LLVMValueRef stage = LLVMConstInt(fInt32Type, a.fStage, 0);
+ switch (a.fStage) {
+ case SkRasterPipeline::callback: {
+ ASSERT(a.fArguments.size() == 2);
+ ASSERT(a.fArguments[1]->fKind == Expression::kFunctionReference_Kind);
+ const FunctionDeclaration& functionDecl =
+ *((FunctionReference&) *a.fArguments[1]).fFunctions[0];
+ bool found = false;
+ for (const auto& pe : fProgram->fElements) {
+ if (ProgramElement::kFunction_Kind == pe->fKind) {
+ const FunctionDefinition& def = (const FunctionDefinition&) *pe;
+ if (&def.fDeclaration == &functionDecl) {
+ LLVMValueRef fn = this->compileStageFunction(def);
+ LLVMValueRef args[2] = {
+ pipeline,
+ LLVMBuildBitCast(builder, fn, fInt8PtrType, "callback cast")
+ };
+ LLVMBuildCall(builder, fAppendCallbackFunc, args, 2, "");
+ found = true;
+ break;
+ }
+ }
+ }
+ ASSERT(found);
+ break;
+ }
+ default: {
+ LLVMValueRef ctx;
+ if (a.fArguments.size() == 2) {
+ ctx = this->compileExpression(builder, *a.fArguments[1]);
+ ctx = LLVMBuildBitCast(builder, ctx, fInt8PtrType, "context cast");
+ } else {
+ ASSERT(a.fArguments.size() == 1);
+ ctx = LLVMConstNull(fInt8PtrType);
+ }
+ LLVMValueRef args[3] = {
+ pipeline,
+ stage,
+ ctx
+ };
+ LLVMBuildCall(builder, fAppendFunc, args, 3, "");
+ break;
+ }
+ }
+}
+
+LLVMValueRef JIT::compileConstructor(LLVMBuilderRef builder, const Constructor& c) {
+ switch (c.fType.kind()) {
+ case Type::kScalar_Kind: {
+ ASSERT(c.fArguments.size() == 1);
+ TypeKind from = this->typeKind(c.fArguments[0]->fType);
+ TypeKind to = this->typeKind(c.fType);
+ LLVMValueRef base = this->compileExpression(builder, *c.fArguments[0]);
+ if (kFloat_TypeKind == to) {
+ if (kInt_TypeKind == from) {
+ return LLVMBuildSIToFP(builder, base, this->getType(c.fType), "cast");
+ }
+ if (kUInt_TypeKind == from) {
+ return LLVMBuildUIToFP(builder, base, this->getType(c.fType), "cast");
+ }
+ }
+ if (kInt_TypeKind == to) {
+ if (kFloat_TypeKind == from) {
+ return LLVMBuildFPToSI(builder, base, this->getType(c.fType), "cast");
+ }
+ if (kUInt_TypeKind == from) {
+ return base;
+ }
+ }
+ if (kUInt_TypeKind == to) {
+ if (kFloat_TypeKind == from) {
+ return LLVMBuildFPToUI(builder, base, this->getType(c.fType), "cast");
+ }
+ if (kInt_TypeKind == from) {
+ return base;
+ }
+ }
+ ABORT("unsupported constructor");
+ }
+ case Type::kVector_Kind: {
+ LLVMValueRef vec = LLVMGetUndef(this->getType(c.fType));
+ if (c.fArguments.size() == 1) {
+ LLVMValueRef value = this->compileExpression(builder, *c.fArguments[0]);
+ for (int i = 0; i < c.fType.columns(); ++i) {
+ vec = LLVMBuildInsertElement(builder, vec, value,
+ LLVMConstInt(fInt32Type, i, false),
+ "vec build");
+ }
+ } else {
+ ASSERT(c.fArguments.size() == (size_t) c.fType.columns());
+ for (int i = 0; i < c.fType.columns(); ++i) {
+ vec = LLVMBuildInsertElement(builder, vec,
+ this->compileExpression(builder,
+ *c.fArguments[i]),
+ LLVMConstInt(fInt32Type, i, false),
+ "vec build");
+ }
+ }
+ return vec;
+ }
+ default:
+ break;
+ }
+ ABORT("unsupported constructor");
+}
+
+LLVMValueRef JIT::compileSwizzle(LLVMBuilderRef builder, const Swizzle& s) {
+ LLVMValueRef base = this->compileExpression(builder, *s.fBase);
+ if (s.fComponents.size() > 1) {
+ LLVMValueRef result = LLVMGetUndef(this->getType(s.fType));
+ for (size_t i = 0; i < s.fComponents.size(); ++i) {
+ LLVMValueRef element = LLVMBuildExtractElement(
+ builder,
+ base,
+ LLVMConstInt(fInt32Type,
+ s.fComponents[i],
+ false),
+ "swizzle extract");
+ result = LLVMBuildInsertElement(builder, result, element,
+ LLVMConstInt(fInt32Type, i, false),
+ "swizzle insert");
+ }
+ return result;
+ }
+ ASSERT(s.fComponents.size() == 1);
+ return LLVMBuildExtractElement(builder, base,
+ LLVMConstInt(fInt32Type,
+ s.fComponents[0],
+ false),
+ "swizzle extract");
+}
+
+LLVMValueRef JIT::compileTernary(LLVMBuilderRef builder, const TernaryExpression& t) {
+ LLVMValueRef test = this->compileExpression(builder, *t.fTest);
+ LLVMBasicBlockRef trueBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
+ "if true");
+ LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
+ "if merge");
+ LLVMBasicBlockRef falseBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
+ "if false");
+ LLVMBuildCondBr(builder, test, trueBlock, falseBlock);
+ this->setBlock(builder, trueBlock);
+ LLVMValueRef ifTrue = this->compileExpression(builder, *t.fIfTrue);
+ trueBlock = fCurrentBlock;
+ LLVMBuildBr(builder, merge);
+ this->setBlock(builder, falseBlock);
+ LLVMValueRef ifFalse = this->compileExpression(builder, *t.fIfFalse);
+ falseBlock = fCurrentBlock;
+ LLVMBuildBr(builder, merge);
+ this->setBlock(builder, merge);
+ LLVMValueRef phi = LLVMBuildPhi(builder, this->getType(t.fType), "?");
+ LLVMValueRef incomingValues[2] = { ifTrue, ifFalse };
+ LLVMBasicBlockRef incomingBlocks[2] = { trueBlock, falseBlock };
+ LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2);
+ return phi;
+}
+
+LLVMValueRef JIT::compileExpression(LLVMBuilderRef builder, const Expression& expr) {
+ switch (expr.fKind) {
+ case Expression::kAppendStage_Kind: {
+ this->appendStage(builder, (const AppendStage&) expr);
+ return LLVMValueRef();
+ }
+ case Expression::kBinary_Kind:
+ return this->compileBinary(builder, (BinaryExpression&) expr);
+ case Expression::kBoolLiteral_Kind:
+ return LLVMConstInt(fInt1Type, ((BoolLiteral&) expr).fValue, false);
+ case Expression::kConstructor_Kind:
+ return this->compileConstructor(builder, (Constructor&) expr);
+ case Expression::kIntLiteral_Kind:
+ return LLVMConstInt(this->getType(expr.fType), ((IntLiteral&) expr).fValue, true);
+ case Expression::kFieldAccess_Kind:
+ abort();
+ case Expression::kFloatLiteral_Kind:
+ return LLVMConstReal(this->getType(expr.fType), ((FloatLiteral&) expr).fValue);
+ case Expression::kFunctionCall_Kind:
+ return this->compileFunctionCall(builder, (FunctionCall&) expr);
+ case Expression::kIndex_Kind:
+ return this->compileIndex(builder, (IndexExpression&) expr);
+ case Expression::kPrefix_Kind:
+ return this->compilePrefix(builder, (PrefixExpression&) expr);
+ case Expression::kPostfix_Kind:
+ return this->compilePostfix(builder, (PostfixExpression&) expr);
+ case Expression::kSetting_Kind:
+ abort();
+ case Expression::kSwizzle_Kind:
+ return this->compileSwizzle(builder, (Swizzle&) expr);
+ case Expression::kVariableReference_Kind:
+ return this->compileVariableReference(builder, (VariableReference&) expr);
+ case Expression::kTernary_Kind:
+ return this->compileTernary(builder, (TernaryExpression&) expr);
+ case Expression::kTypeReference_Kind:
+ abort();
+ default:
+ abort();
+ }
+ ABORT("unsupported expression: %s\n", expr.description().c_str());
+}
+
+void JIT::compileBlock(LLVMBuilderRef builder, const Block& block) {
+ for (const auto& stmt : block.fStatements) {
+ this->compileStatement(builder, *stmt);
+ }
+}
+
+void JIT::compileVarDeclarations(LLVMBuilderRef builder, const VarDeclarationsStatement& decls) {
+ for (const auto& declStatement : decls.fDeclaration->fVars) {
+ const VarDeclaration& decl = (VarDeclaration&) *declStatement;
+ LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
+ LLVMValueRef alloca = LLVMBuildAlloca(builder, this->getType(decl.fVar->fType),
+ String(decl.fVar->fName).c_str());
+ fVariables[decl.fVar] = alloca;
+ LLVMPositionBuilderAtEnd(builder, fCurrentBlock);
+ if (decl.fValue) {
+ LLVMValueRef result = this->compileExpression(builder, *decl.fValue);
+ LLVMBuildStore(builder, result, alloca);
+ }
+ }
+}
+
+void JIT::compileIf(LLVMBuilderRef builder, const IfStatement& i) {
+ LLVMValueRef test = this->compileExpression(builder, *i.fTest);
+ LLVMBasicBlockRef ifTrue = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "if true");
+ LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
+ "if merge");
+ LLVMBasicBlockRef ifFalse;
+ if (i.fIfFalse) {
+ ifFalse = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "if false");
+ } else {
+ ifFalse = merge;
+ }
+ LLVMBuildCondBr(builder, test, ifTrue, ifFalse);
+ this->setBlock(builder, ifTrue);
+ this->compileStatement(builder, *i.fIfTrue);
+ if (!ends_with_branch(*i.fIfTrue)) {
+ LLVMBuildBr(builder, merge);
+ }
+ if (i.fIfFalse) {
+ this->setBlock(builder, ifFalse);
+ this->compileStatement(builder, *i.fIfFalse);
+ if (!ends_with_branch(*i.fIfFalse)) {
+ LLVMBuildBr(builder, merge);
+ }
+ }
+ this->setBlock(builder, merge);
+}
+
+void JIT::compileFor(LLVMBuilderRef builder, const ForStatement& f) {
+ if (f.fInitializer) {
+ this->compileStatement(builder, *f.fInitializer);
+ }
+ LLVMBasicBlockRef start;
+ LLVMBasicBlockRef body = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "for body");
+ LLVMBasicBlockRef next = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "for next");
+ LLVMBasicBlockRef end = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "for end");
+ if (f.fTest) {
+ start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "for test");
+ LLVMBuildBr(builder, start);
+ this->setBlock(builder, start);
+ LLVMValueRef test = this->compileExpression(builder, *f.fTest);
+ LLVMBuildCondBr(builder, test, body, end);
+ } else {
+ start = body;
+ LLVMBuildBr(builder, body);
+ }
+ this->setBlock(builder, body);
+ fBreakTarget.push_back(end);
+ fContinueTarget.push_back(next);
+ this->compileStatement(builder, *f.fStatement);
+ fBreakTarget.pop_back();
+ fContinueTarget.pop_back();
+ if (!ends_with_branch(*f.fStatement)) {
+ LLVMBuildBr(builder, next);
+ }
+ this->setBlock(builder, next);
+ if (f.fNext) {
+ this->compileExpression(builder, *f.fNext);
+ }
+ LLVMBuildBr(builder, start);
+ this->setBlock(builder, end);
+}
+
+void JIT::compileDo(LLVMBuilderRef builder, const DoStatement& d) {
+ LLVMBasicBlockRef testBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
+ "do test");
+ LLVMBasicBlockRef body = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
+ "do body");
+ LLVMBasicBlockRef end = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
+ "do end");
+ LLVMBuildBr(builder, body);
+ this->setBlock(builder, testBlock);
+ LLVMValueRef test = this->compileExpression(builder, *d.fTest);
+ LLVMBuildCondBr(builder, test, body, end);
+ this->setBlock(builder, body);
+ fBreakTarget.push_back(end);
+ fContinueTarget.push_back(body);
+ this->compileStatement(builder, *d.fStatement);
+ fBreakTarget.pop_back();
+ fContinueTarget.pop_back();
+ if (!ends_with_branch(*d.fStatement)) {
+ LLVMBuildBr(builder, testBlock);
+ }
+ this->setBlock(builder, end);
+}
+
+void JIT::compileWhile(LLVMBuilderRef builder, const WhileStatement& w) {
+ LLVMBasicBlockRef testBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
+ "while test");
+ LLVMBasicBlockRef body = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
+ "while body");
+ LLVMBasicBlockRef end = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
+ "while end");
+ LLVMBuildBr(builder, testBlock);
+ this->setBlock(builder, testBlock);
+ LLVMValueRef test = this->compileExpression(builder, *w.fTest);
+ LLVMBuildCondBr(builder, test, body, end);
+ this->setBlock(builder, body);
+ fBreakTarget.push_back(end);
+ fContinueTarget.push_back(testBlock);
+ this->compileStatement(builder, *w.fStatement);
+ fBreakTarget.pop_back();
+ fContinueTarget.pop_back();
+ if (!ends_with_branch(*w.fStatement)) {
+ LLVMBuildBr(builder, testBlock);
+ }
+ this->setBlock(builder, end);
+}
+
+void JIT::compileBreak(LLVMBuilderRef builder, const BreakStatement& b) {
+ LLVMBuildBr(builder, fBreakTarget.back());
+}
+
+void JIT::compileContinue(LLVMBuilderRef builder, const ContinueStatement& b) {
+ LLVMBuildBr(builder, fContinueTarget.back());
+}
+
+void JIT::compileReturn(LLVMBuilderRef builder, const ReturnStatement& r) {
+ if (r.fExpression) {
+ LLVMBuildRet(builder, this->compileExpression(builder, *r.fExpression));
+ } else {
+ LLVMBuildRetVoid(builder);
+ }
+}
+
+void JIT::compileStatement(LLVMBuilderRef builder, const Statement& stmt) {
+ switch (stmt.fKind) {
+ case Statement::kBlock_Kind:
+ this->compileBlock(builder, (Block&) stmt);
+ break;
+ case Statement::kBreak_Kind:
+ this->compileBreak(builder, (BreakStatement&) stmt);
+ break;
+ case Statement::kContinue_Kind:
+ this->compileContinue(builder, (ContinueStatement&) stmt);
+ break;
+ case Statement::kDiscard_Kind:
+ abort();
+ case Statement::kDo_Kind:
+ this->compileDo(builder, (DoStatement&) stmt);
+ break;
+ case Statement::kExpression_Kind:
+ this->compileExpression(builder, *((ExpressionStatement&) stmt).fExpression);
+ break;
+ case Statement::kFor_Kind:
+ this->compileFor(builder, (ForStatement&) stmt);
+ break;
+ case Statement::kGroup_Kind:
+ abort();
+ case Statement::kIf_Kind:
+ this->compileIf(builder, (IfStatement&) stmt);
+ break;
+ case Statement::kNop_Kind:
+ break;
+ case Statement::kReturn_Kind:
+ this->compileReturn(builder, (ReturnStatement&) stmt);
+ break;
+ case Statement::kSwitch_Kind:
+ abort();
+ case Statement::kVarDeclarations_Kind:
+ this->compileVarDeclarations(builder, (VarDeclarationsStatement&) stmt);
+ break;
+ case Statement::kWhile_Kind:
+ this->compileWhile(builder, (WhileStatement&) stmt);
+ break;
+ default:
+ abort();
+ }
+}
+
+void JIT::compileStageFunctionLoop(const FunctionDefinition& f, LLVMValueRef newFunc) {
+ // loop over fVectorCount pixels, running the body of the stage function for each of them
+ LLVMValueRef oldFunction = fCurrentFunction;
+ fCurrentFunction = newFunc;
+ std::unique_ptr<LLVMValueRef[]> params(new LLVMValueRef[STAGE_PARAM_COUNT]);
+ LLVMGetParams(fCurrentFunction, params.get());
+ LLVMValueRef programParam = params.get()[1];
+ LLVMBuilderRef builder = LLVMCreateBuilderInContext(fContext);
+ LLVMBasicBlockRef oldAllocaBlock = fAllocaBlock;
+ LLVMBasicBlockRef oldCurrentBlock = fCurrentBlock;
+ fAllocaBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "alloca");
+ this->setBlock(builder, fAllocaBlock);
+ // temporaries to store the color channel vectors
+ LLVMValueRef rVec = LLVMBuildAlloca(builder, fFloat32VectorType, "rVec");
+ LLVMBuildStore(builder, params.get()[4], rVec);
+ LLVMValueRef gVec = LLVMBuildAlloca(builder, fFloat32VectorType, "gVec");
+ LLVMBuildStore(builder, params.get()[5], gVec);
+ LLVMValueRef bVec = LLVMBuildAlloca(builder, fFloat32VectorType, "bVec");
+ LLVMBuildStore(builder, params.get()[6], bVec);
+ LLVMValueRef aVec = LLVMBuildAlloca(builder, fFloat32VectorType, "aVec");
+ LLVMBuildStore(builder, params.get()[7], aVec);
+ LLVMValueRef color = LLVMBuildAlloca(builder, fFloat32Vector4Type, "color");
+ fVariables[f.fDeclaration.fParameters[1]] = LLVMBuildTrunc(builder, params.get()[3], fInt32Type,
+ "y->Int32");
+ fVariables[f.fDeclaration.fParameters[2]] = color;
+ LLVMValueRef ivar = LLVMBuildAlloca(builder, fInt32Type, "i");
+ LLVMBuildStore(builder, LLVMConstInt(fInt32Type, 0, false), ivar);
+ LLVMBasicBlockRef start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "start");
+ this->setBlock(builder, start);
+ LLVMValueRef iload = LLVMBuildLoad(builder, ivar, "load i");
+ fVariables[f.fDeclaration.fParameters[0]] = LLVMBuildAdd(builder,
+ LLVMBuildTrunc(builder,
+ params.get()[2],
+ fInt32Type,
+ "x->Int32"),
+ iload,
+ "x");
+ LLVMValueRef vectorSize = LLVMConstInt(fInt32Type, fVectorCount, false);
+ LLVMValueRef test = LLVMBuildICmp(builder, LLVMIntSLT, iload, vectorSize, "i < vectorSize");
+ LLVMBasicBlockRef loopBody = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "body");
+ LLVMBasicBlockRef loopEnd = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "end");
+ LLVMBuildCondBr(builder, test, loopBody, loopEnd);
+ this->setBlock(builder, loopBody);
+ LLVMValueRef vec = LLVMGetUndef(fFloat32Vector4Type);
+ // extract the r, g, b, and a values from the color channel vectors and store them into "color"
+ for (int i = 0; i < 4; ++i) {
+ vec = LLVMBuildInsertElement(builder, vec,
+ LLVMBuildExtractElement(builder,
+ params.get()[4 + i],
+ iload, "initial"),
+ LLVMConstInt(fInt32Type, i, false),
+ "vec build");
+ }
+ LLVMBuildStore(builder, vec, color);
+ // write actual loop body
+ this->compileStatement(builder, *f.fBody);
+ // extract the r, g, b, and a values from "color" and stick them back into the color channel
+ // vectors
+ LLVMValueRef colorLoad = LLVMBuildLoad(builder, color, "color load");
+ LLVMBuildStore(builder,
+ LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, rVec, "rVec"),
+ LLVMBuildExtractElement(builder, colorLoad,
+ LLVMConstInt(fInt32Type, 0,
+ false),
+ "rExtract"),
+ iload, "rInsert"),
+ rVec);
+ LLVMBuildStore(builder,
+ LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, gVec, "gVec"),
+ LLVMBuildExtractElement(builder, colorLoad,
+ LLVMConstInt(fInt32Type, 1,
+ false),
+ "gExtract"),
+ iload, "gInsert"),
+ gVec);
+ LLVMBuildStore(builder,
+ LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, bVec, "bVec"),
+ LLVMBuildExtractElement(builder, colorLoad,
+ LLVMConstInt(fInt32Type, 2,
+ false),
+ "bExtract"),
+ iload, "bInsert"),
+ bVec);
+ LLVMBuildStore(builder,
+ LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, aVec, "aVec"),
+ LLVMBuildExtractElement(builder, colorLoad,
+ LLVMConstInt(fInt32Type, 3,
+ false),
+ "aExtract"),
+ iload, "aInsert"),
+ aVec);
+ LLVMValueRef inc = LLVMBuildAdd(builder, iload, LLVMConstInt(fInt32Type, 1, false), "inc i");
+ LLVMBuildStore(builder, inc, ivar);
+ LLVMBuildBr(builder, start);
+ this->setBlock(builder, loopEnd);
+ // increment program pointer, call the next stage
+ LLVMValueRef rawNextPtr = LLVMBuildLoad(builder, programParam, "next load");
+ LLVMTypeRef stageFuncType = LLVMTypeOf(newFunc);
+ LLVMValueRef nextPtr = LLVMBuildBitCast(builder, rawNextPtr, stageFuncType, "cast next->func");
+ LLVMValueRef nextInc = LLVMBuildIntToPtr(builder,
+ LLVMBuildAdd(builder,
+ LLVMBuildPtrToInt(builder,
+ programParam,
+ fInt64Type,
+ "cast 1"),
+ LLVMConstInt(fInt64Type, PTR_SIZE, false),
+ "add"),
+ LLVMPointerType(fInt8PtrType, 0), "cast 2");
+ LLVMValueRef args[STAGE_PARAM_COUNT] = {
+ params.get()[0],
+ nextInc,
+ params.get()[2],
+ params.get()[3],
+ LLVMBuildLoad(builder, rVec, "rVec"),
+ LLVMBuildLoad(builder, gVec, "gVec"),
+ LLVMBuildLoad(builder, bVec, "bVec"),
+ LLVMBuildLoad(builder, aVec, "aVec"),
+ params.get()[8],
+ params.get()[9],
+ params.get()[10],
+ params.get()[11]
+ };
+ LLVMBuildCall(builder, nextPtr, args, STAGE_PARAM_COUNT, "");
+ LLVMBuildRetVoid(builder);
+ // finish
+ LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
+ LLVMBuildBr(builder, start);
+ LLVMDisposeBuilder(builder);
+ if (LLVMVerifyFunction(fCurrentFunction, LLVMPrintMessageAction)) {
+ ABORT("verify failed\n");
+ }
+ fAllocaBlock = oldAllocaBlock;
+ fCurrentBlock = oldCurrentBlock;
+ fCurrentFunction = oldFunction;
+}
+
+// FIXME maybe pluggable code generators? Need to do something to separate all
+// of the normal codegen from the vector codegen and break this up into multiple
+// classes.
+
+bool JIT::getVectorLValue(LLVMBuilderRef builder, const Expression& e,
+ LLVMValueRef out[CHANNELS]) {
+ switch (e.fKind) {
+ case Expression::kVariableReference_Kind:
+ if (fColorParam == &((VariableReference&) e).fVariable) {
+ memcpy(out, fChannels, sizeof(fChannels));
+ return true;
+ }
+ return false;
+ case Expression::kSwizzle_Kind: {
+ const Swizzle& s = (const Swizzle&) e;
+ LLVMValueRef base[CHANNELS];
+ if (!this->getVectorLValue(builder, *s.fBase, base)) {
+ return false;
+ }
+ for (size_t i = 0; i < s.fComponents.size(); ++i) {
+ out[i] = base[s.fComponents[i]];
+ }
+ return true;
+ }
+ default:
+ return false;
+ }
+}
+
+bool JIT::getVectorBinaryOperands(LLVMBuilderRef builder, const Expression& left,
+ LLVMValueRef outLeft[CHANNELS], const Expression& right,
+ LLVMValueRef outRight[CHANNELS]) {
+ if (!this->compileVectorExpression(builder, left, outLeft)) {
+ return false;
+ }
+ int leftColumns = left.fType.columns();
+ int rightColumns = right.fType.columns();
+ if (leftColumns == 1 && rightColumns > 1) {
+ for (int i = 1; i < rightColumns; ++i) {
+ outLeft[i] = outLeft[0];
+ }
+ }
+ if (!this->compileVectorExpression(builder, right, outRight)) {
+ return false;
+ }
+ if (rightColumns == 1 && leftColumns > 1) {
+ for (int i = 1; i < leftColumns; ++i) {
+ outRight[i] = outRight[0];
+ }
+ }
+ return true;
+}
+
+bool JIT::compileVectorBinary(LLVMBuilderRef builder, const BinaryExpression& b,
+ LLVMValueRef out[CHANNELS]) {
+ LLVMValueRef left[CHANNELS];
+ LLVMValueRef right[CHANNELS];
+ #define VECTOR_BINARY(signedOp, unsignedOp, floatOp) { \
+ if (!this->getVectorBinaryOperands(builder, *b.fLeft, left, *b.fRight, right)) { \
+ return false; \
+ } \
+ for (int i = 0; i < b.fLeft->fType.columns(); ++i) { \
+ switch (this->typeKind(b.fLeft->fType)) { \
+ case kInt_TypeKind: \
+ out[i] = signedOp(builder, left[i], right[i], "binary"); \
+ break; \
+ case kUInt_TypeKind: \
+ out[i] = unsignedOp(builder, left[i], right[i], "binary"); \
+ break; \
+ case kFloat_TypeKind: \
+ out[i] = floatOp(builder, left[i], right[i], "binary"); \
+ break; \
+ case kBool_TypeKind: \
+ ASSERT(false); \
+ break; \
+ } \
+ } \
+ return true; \
+ }
+ switch (b.fOperator) {
+ case Token::EQ: {
+ if (!this->getVectorLValue(builder, *b.fLeft, left)) {
+ return false;
+ }
+ if (!this->compileVectorExpression(builder, *b.fRight, right)) {
+ return false;
+ }
+ int columns = b.fRight->fType.columns();
+ for (int i = 0; i < columns; ++i) {
+ LLVMBuildStore(builder, right[i], left[i]);
+ }
+ return true;
+ }
+ case Token::PLUS:
+ VECTOR_BINARY(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd);
+ case Token::MINUS:
+ VECTOR_BINARY(LLVMBuildSub, LLVMBuildSub, LLVMBuildFSub);
+ case Token::STAR:
+ VECTOR_BINARY(LLVMBuildMul, LLVMBuildMul, LLVMBuildFMul);
+ case Token::SLASH:
+ VECTOR_BINARY(LLVMBuildSDiv, LLVMBuildUDiv, LLVMBuildFDiv);
+ case Token::PERCENT:
+ VECTOR_BINARY(LLVMBuildSRem, LLVMBuildURem, LLVMBuildSRem);
+ case Token::BITWISEAND:
+ VECTOR_BINARY(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd);
+ case Token::BITWISEOR:
+ VECTOR_BINARY(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr);
+ default:
+ printf("unsupported operator: %s\n", b.description().c_str());
+ return false;
+ }
+}
+
+bool JIT::compileVectorConstructor(LLVMBuilderRef builder, const Constructor& c,
+ LLVMValueRef out[CHANNELS]) {
+ switch (c.fType.kind()) {
+ case Type::kScalar_Kind: {
+ ASSERT(c.fArguments.size() == 1);
+ TypeKind from = this->typeKind(c.fArguments[0]->fType);
+ TypeKind to = this->typeKind(c.fType);
+ LLVMValueRef base[CHANNELS];
+ if (!this->compileVectorExpression(builder, *c.fArguments[0], base)) {
+ return false;
+ }
+ #define CONSTRUCT(fn) \
+ out[0] = LLVMGetUndef(LLVMVectorType(this->getType(c.fType), fVectorCount)); \
+ for (int i = 0; i < fVectorCount; ++i) { \
+ LLVMValueRef index = LLVMConstInt(fInt32Type, i, false); \
+ LLVMValueRef baseVal = LLVMBuildExtractElement(builder, base[0], index, \
+ "construct extract"); \
+ out[0] = LLVMBuildInsertElement(builder, out[0], \
+ fn(builder, baseVal, this->getType(c.fType), \
+ "cast"), \
+ index, "construct insert"); \
+ } \
+ return true;
+ if (kFloat_TypeKind == to) {
+ if (kInt_TypeKind == from) {
+ CONSTRUCT(LLVMBuildSIToFP);
+ }
+ if (kUInt_TypeKind == from) {
+ CONSTRUCT(LLVMBuildUIToFP);
+ }
+ }
+ if (kInt_TypeKind == to) {
+ if (kFloat_TypeKind == from) {
+ CONSTRUCT(LLVMBuildFPToSI);
+ }
+ if (kUInt_TypeKind == from) {
+ return true;
+ }
+ }
+ if (kUInt_TypeKind == to) {
+ if (kFloat_TypeKind == from) {
+ CONSTRUCT(LLVMBuildFPToUI);
+ }
+ if (kInt_TypeKind == from) {
+ return base;
+ }
+ }
+ printf("%s\n", c.description().c_str());
+ ABORT("unsupported constructor");
+ }
+ case Type::kVector_Kind: {
+ if (c.fArguments.size() == 1) {
+ LLVMValueRef base[CHANNELS];
+ if (!this->compileVectorExpression(builder, *c.fArguments[0], base)) {
+ return false;
+ }
+ for (int i = 0; i < c.fType.columns(); ++i) {
+ out[i] = base[0];
+ }
+ } else {
+ ASSERT(c.fArguments.size() == (size_t) c.fType.columns());
+ for (int i = 0; i < c.fType.columns(); ++i) {
+ LLVMValueRef base[CHANNELS];
+ if (!this->compileVectorExpression(builder, *c.fArguments[i], base)) {
+ return false;
+ }
+ out[i] = base[0];
+ }
+ }
+ return true;
+ }
+ default:
+ break;
+ }
+ ABORT("unsupported constructor");
+}
+
+bool JIT::compileVectorFloatLiteral(LLVMBuilderRef builder,
+ const FloatLiteral& f,
+ LLVMValueRef out[CHANNELS]) {
+ LLVMValueRef value = LLVMConstReal(this->getType(f.fType), f.fValue);
+ LLVMValueRef values[MAX_VECTOR_COUNT];
+ for (int i = 0; i < fVectorCount; ++i) {
+ values[i] = value;
+ }
+ out[0] = LLVMConstVector(values, fVectorCount);
+ return true;
+}
+
+
+bool JIT::compileVectorSwizzle(LLVMBuilderRef builder, const Swizzle& s,
+ LLVMValueRef out[CHANNELS]) {
+ LLVMValueRef all[CHANNELS];
+ if (!this->compileVectorExpression(builder, *s.fBase, all)) {
+ return false;
+ }
+ for (size_t i = 0; i < s.fComponents.size(); ++i) {
+ out[i] = all[s.fComponents[i]];
+ }
+ return true;
+}
+
+bool JIT::compileVectorVariableReference(LLVMBuilderRef builder, const VariableReference& v,
+ LLVMValueRef out[CHANNELS]) {
+ if (&v.fVariable == fColorParam) {
+ for (int i = 0; i < CHANNELS; ++i) {
+ out[i] = LLVMBuildLoad(builder, fChannels[i], "variable reference");
+ }
+ return true;
+ }
+ return false;
+}
+
+bool JIT::compileVectorExpression(LLVMBuilderRef builder, const Expression& expr,
+ LLVMValueRef out[CHANNELS]) {
+ switch (expr.fKind) {
+ case Expression::kBinary_Kind:
+ return this->compileVectorBinary(builder, (const BinaryExpression&) expr, out);
+ case Expression::kConstructor_Kind:
+ return this->compileVectorConstructor(builder, (const Constructor&) expr, out);
+ case Expression::kFloatLiteral_Kind:
+ return this->compileVectorFloatLiteral(builder, (const FloatLiteral&) expr, out);
+ case Expression::kSwizzle_Kind:
+ return this->compileVectorSwizzle(builder, (const Swizzle&) expr, out);
+ case Expression::kVariableReference_Kind:
+ return this->compileVectorVariableReference(builder, (const VariableReference&) expr,
+ out);
+ default:
+ printf("failed expression: %s\n", expr.description().c_str());
+ return false;
+ }
+}
+
+bool JIT::compileVectorStatement(LLVMBuilderRef builder, const Statement& stmt) {
+ switch (stmt.fKind) {
+ case Statement::kBlock_Kind:
+ for (const auto& s : ((const Block&) stmt).fStatements) {
+ if (!this->compileVectorStatement(builder, *s)) {
+ return false;
+ }
+ }
+ return true;
+ case Statement::kExpression_Kind:
+ LLVMValueRef result;
+ return this->compileVectorExpression(builder,
+ *((const ExpressionStatement&) stmt).fExpression,
+ &result);
+ default:
+ printf("failed statement: %s\n", stmt.description().c_str());
+ return false;
+ }
+}
+
+bool JIT::compileStageFunctionVector(const FunctionDefinition& f, LLVMValueRef newFunc) {
+ LLVMValueRef oldFunction = fCurrentFunction;
+ fCurrentFunction = newFunc;
+ std::unique_ptr<LLVMValueRef[]> params(new LLVMValueRef[STAGE_PARAM_COUNT]);
+ LLVMGetParams(fCurrentFunction, params.get());
+ LLVMValueRef programParam = params.get()[1];
+ LLVMBuilderRef builder = LLVMCreateBuilderInContext(fContext);
+ LLVMBasicBlockRef oldAllocaBlock = fAllocaBlock;
+ LLVMBasicBlockRef oldCurrentBlock = fCurrentBlock;
+ fAllocaBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "alloca");
+ this->setBlock(builder, fAllocaBlock);
+ fChannels[0] = LLVMBuildAlloca(builder, fFloat32VectorType, "rVec");
+ LLVMBuildStore(builder, params.get()[4], fChannels[0]);
+ fChannels[1] = LLVMBuildAlloca(builder, fFloat32VectorType, "gVec");
+ LLVMBuildStore(builder, params.get()[5], fChannels[1]);
+ fChannels[2] = LLVMBuildAlloca(builder, fFloat32VectorType, "bVec");
+ LLVMBuildStore(builder, params.get()[6], fChannels[2]);
+ fChannels[3] = LLVMBuildAlloca(builder, fFloat32VectorType, "aVec");
+ LLVMBuildStore(builder, params.get()[7], fChannels[3]);
+ LLVMBasicBlockRef start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "start");
+ this->setBlock(builder, start);
+ bool success = this->compileVectorStatement(builder, *f.fBody);
+ if (success) {
+ // increment program pointer, call next
+ LLVMValueRef rawNextPtr = LLVMBuildLoad(builder, programParam, "next load");
+ LLVMTypeRef stageFuncType = LLVMTypeOf(newFunc);
+ LLVMValueRef nextPtr = LLVMBuildBitCast(builder, rawNextPtr, stageFuncType,
+ "cast next->func");
+ LLVMValueRef nextInc = LLVMBuildIntToPtr(builder,
+ LLVMBuildAdd(builder,
+ LLVMBuildPtrToInt(builder,
+ programParam,
+ fInt64Type,
+ "cast 1"),
+ LLVMConstInt(fInt64Type, PTR_SIZE,
+ false),
+ "add"),
+ LLVMPointerType(fInt8PtrType, 0), "cast 2");
+ LLVMValueRef args[STAGE_PARAM_COUNT] = {
+ params.get()[0],
+ nextInc,
+ params.get()[2],
+ params.get()[3],
+ LLVMBuildLoad(builder, fChannels[0], "rVec"),
+ LLVMBuildLoad(builder, fChannels[1], "gVec"),
+ LLVMBuildLoad(builder, fChannels[2], "bVec"),
+ LLVMBuildLoad(builder, fChannels[3], "aVec"),
+ params.get()[8],
+ params.get()[9],
+ params.get()[10],
+ params.get()[11]
+ };
+ LLVMBuildCall(builder, nextPtr, args, STAGE_PARAM_COUNT, "");
+ LLVMBuildRetVoid(builder);
+ // finish
+ LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
+ LLVMBuildBr(builder, start);
+ LLVMDisposeBuilder(builder);
+ if (LLVMVerifyFunction(fCurrentFunction, LLVMPrintMessageAction)) {
+ ABORT("verify failed\n");
+ }
+ } else {
+ LLVMDeleteBasicBlock(fAllocaBlock);
+ LLVMDeleteBasicBlock(start);
+ }
+
+ fAllocaBlock = oldAllocaBlock;
+ fCurrentBlock = oldCurrentBlock;
+ fCurrentFunction = oldFunction;
+ return success;
+}
+
+LLVMValueRef JIT::compileStageFunction(const FunctionDefinition& f) {
+ LLVMTypeRef returnType = fVoidType;
+ LLVMTypeRef parameterTypes[12] = { fSizeTType, LLVMPointerType(fInt8PtrType, 0), fSizeTType,
+ fSizeTType, fFloat32VectorType, fFloat32VectorType,
+ fFloat32VectorType, fFloat32VectorType, fFloat32VectorType,
+ fFloat32VectorType, fFloat32VectorType, fFloat32VectorType };
+ LLVMTypeRef stageFuncType = LLVMFunctionType(returnType, parameterTypes, 12, false);
+ LLVMValueRef result = LLVMAddFunction(fModule,
+ (String(f.fDeclaration.fName) + "$stage").c_str(),
+ stageFuncType);
+ fColorParam = f.fDeclaration.fParameters[2];
+ if (!this->compileStageFunctionVector(f, result)) {
+ // vectorization failed, fall back to looping over the pixels
+ this->compileStageFunctionLoop(f, result);
+ }
+ return result;
+}
+
+bool JIT::hasStageSignature(const FunctionDeclaration& f) {
+ return f.fReturnType == *fProgram->fContext->fVoid_Type &&
+ f.fParameters.size() == 3 &&
+ f.fParameters[0]->fType == *fProgram->fContext->fInt_Type &&
+ f.fParameters[0]->fModifiers.fFlags == 0 &&
+ f.fParameters[1]->fType == *fProgram->fContext->fInt_Type &&
+ f.fParameters[1]->fModifiers.fFlags == 0 &&
+ f.fParameters[2]->fType == *fProgram->fContext->fFloat4_Type &&
+ f.fParameters[2]->fModifiers.fFlags == (Modifiers::kIn_Flag | Modifiers::kOut_Flag);
+}
+
+LLVMValueRef JIT::compileFunction(const FunctionDefinition& f) {
+ if (this->hasStageSignature(f.fDeclaration)) {
+ this->compileStageFunction(f);
+ // we compile foo$stage *in addition* to compiling foo, as we can't be sure that the intent
+ // was to produce an SkJumper stage just because the signature matched or that the function
+ // is not otherwise called. May need a better way to handle this.
+ }
+ LLVMTypeRef returnType = this->getType(f.fDeclaration.fReturnType);
+ std::vector<LLVMTypeRef> parameterTypes;
+ for (const auto& p : f.fDeclaration.fParameters) {
+ LLVMTypeRef type = this->getType(p->fType);
+ if (p->fModifiers.fFlags & Modifiers::kOut_Flag) {
+ type = LLVMPointerType(type, 0);
+ }
+ parameterTypes.push_back(type);
+ }
+ fCurrentFunction = LLVMAddFunction(fModule,
+ String(f.fDeclaration.fName).c_str(),
+ LLVMFunctionType(returnType, parameterTypes.data(),
+ parameterTypes.size(), false));
+ fFunctions[&f.fDeclaration] = fCurrentFunction;
+
+ std::unique_ptr<LLVMValueRef[]> params(new LLVMValueRef[parameterTypes.size()]);
+ LLVMGetParams(fCurrentFunction, params.get());
+ for (size_t i = 0; i < f.fDeclaration.fParameters.size(); ++i) {
+ fVariables[f.fDeclaration.fParameters[i]] = params.get()[i];
+ }
+ LLVMBuilderRef builder = LLVMCreateBuilderInContext(fContext);
+ fAllocaBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "alloca");
+ LLVMBasicBlockRef start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "start");
+ fCurrentBlock = start;
+ LLVMPositionBuilderAtEnd(builder, fCurrentBlock);
+ this->compileStatement(builder, *f.fBody);
+ if (!ends_with_branch(*f.fBody)) {
+ if (f.fDeclaration.fReturnType == *fProgram->fContext->fVoid_Type) {
+ LLVMBuildRetVoid(builder);
+ } else {
+ LLVMBuildUnreachable(builder);
+ }
+ }
+ LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
+ LLVMBuildBr(builder, start);
+ LLVMDisposeBuilder(builder);
+ if (LLVMVerifyFunction(fCurrentFunction, LLVMPrintMessageAction)) {
+ ABORT("verify failed\n");
+ }
+ return fCurrentFunction;
+}
+
+void JIT::createModule() {
+ fPromotedParameters.clear();
+ fModule = LLVMModuleCreateWithNameInContext("skslmodule", fContext);
+ this->loadBuiltinFunctions();
+ // LLVM doesn't do void*, have to declare it as int8*
+ LLVMTypeRef appendParams[3] = { fInt8PtrType, fInt32Type, fInt8PtrType };
+ fAppendFunc = LLVMAddFunction(fModule, "sksl_pipeline_append", LLVMFunctionType(fVoidType,
+ appendParams,
+ 3,
+ false));
+ LLVMTypeRef appendCallbackParams[2] = { fInt8PtrType, fInt8PtrType };
+ fAppendCallbackFunc = LLVMAddFunction(fModule, "sksl_pipeline_append_callback",
+ LLVMFunctionType(fVoidType, appendCallbackParams, 2,
+ false));
+
+ LLVMTypeRef debugParams[3] = { fFloat32Type };
+ fDebugFunc = LLVMAddFunction(fModule, "sksl_debug_print", LLVMFunctionType(fVoidType,
+ debugParams,
+ 1,
+ false));
+
+ for (const auto& e : fProgram->fElements) {
+ ASSERT(e->fKind == ProgramElement::kFunction_Kind);
+ this->compileFunction((FunctionDefinition&) *e);
+ }
+}
+
+std::unique_ptr<JIT::Module> JIT::compile(std::unique_ptr<Program> program) {
+ fProgram = std::move(program);
+ this->createModule();
+ this->optimize();
+ return std::unique_ptr<Module>(new Module(std::move(fProgram), fSharedModule, fJITStack));
+}
+
+void JIT::optimize() {
+ LLVMPassManagerBuilderRef pmb = LLVMPassManagerBuilderCreate();
+ LLVMPassManagerBuilderSetOptLevel(pmb, 3);
+ LLVMPassManagerRef functionPM = LLVMCreateFunctionPassManagerForModule(fModule);
+ LLVMPassManagerBuilderPopulateFunctionPassManager(pmb, functionPM);
+ LLVMPassManagerRef modulePM = LLVMCreatePassManager();
+ LLVMPassManagerBuilderPopulateModulePassManager(pmb, modulePM);
+ LLVMInitializeFunctionPassManager(functionPM);
+
+ LLVMValueRef func = LLVMGetFirstFunction(fModule);
+ for (;;) {
+ if (!func) {
+ break;
+ }
+ LLVMRunFunctionPassManager(functionPM, func);
+ func = LLVMGetNextFunction(func);
+ }
+ LLVMRunPassManager(modulePM, fModule);
+ LLVMDisposePassManager(functionPM);
+ LLVMDisposePassManager(modulePM);
+ LLVMPassManagerBuilderDispose(pmb);
+
+ std::string error_string;
+ if (LLVMLoadLibraryPermanently(nullptr)) {
+ ABORT("LLVMLoadLibraryPermanently failed");
+ }
+ char* defaultTriple = LLVMGetDefaultTargetTriple();
+ char* error;
+ LLVMTargetRef target;
+ if (LLVMGetTargetFromTriple(defaultTriple, &target, &error)) {
+ ABORT("LLVMGetTargetFromTriple failed");
+ }
+
+ if (!LLVMTargetHasJIT(target)) {
+ ABORT("!LLVMTargetHasJIT");
+ }
+ LLVMTargetMachineRef targetMachine = LLVMCreateTargetMachine(target,
+ defaultTriple,
+ fCPU,
+ nullptr,
+ LLVMCodeGenLevelDefault,
+ LLVMRelocDefault,
+ LLVMCodeModelJITDefault);
+ LLVMDisposeMessage(defaultTriple);
+ LLVMTargetDataRef dataLayout = LLVMCreateTargetDataLayout(targetMachine);
+ LLVMSetModuleDataLayout(fModule, dataLayout);
+ LLVMDisposeTargetData(dataLayout);
+
+ fJITStack = LLVMOrcCreateInstance(targetMachine);
+ fSharedModule = LLVMOrcMakeSharedModule(fModule);
+ LLVMOrcModuleHandle orcModule;
+ LLVMOrcAddEagerlyCompiledIR(fJITStack, &orcModule, fSharedModule,
+ (LLVMOrcSymbolResolverFn) resolveSymbol, this);
+ LLVMDisposeTargetMachine(targetMachine);
+}
+
+void* JIT::Module::getSymbol(const char* name) {
+ LLVMOrcTargetAddress result;
+ if (LLVMOrcGetSymbolAddress(fJITStack, &result, name)) {
+ ABORT("GetSymbolAddress error");
+ }
+ if (!result) {
+ ABORT("symbol not found");
+ }
+ return (void*) result;
+}
+
+void* JIT::Module::getJumperStage(const char* name) {
+ return this->getSymbol((String(name) + "$stage").c_str());
+}
+
+} // namespace
+
+#endif // SK_LLVM_AVAILABLE
+
+#endif // SKSL_STANDALONE
diff --git a/src/sksl/SkSLJIT.h b/src/sksl/SkSLJIT.h
new file mode 100644
index 0000000000..b23e31237f
--- /dev/null
+++ b/src/sksl/SkSLJIT.h
@@ -0,0 +1,344 @@
+/*
+ * Copyright 2018 Google Inc.
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#ifndef SKSL_JIT
+#define SKSL_JIT
+
+#ifdef SK_LLVM_AVAILABLE
+
+#include "ir/SkSLAppendStage.h"
+#include "ir/SkSLBinaryExpression.h"
+#include "ir/SkSLBreakStatement.h"
+#include "ir/SkSLContinueStatement.h"
+#include "ir/SkSLExpression.h"
+#include "ir/SkSLDoStatement.h"
+#include "ir/SkSLForStatement.h"
+#include "ir/SkSLFunctionCall.h"
+#include "ir/SkSLFunctionDefinition.h"
+#include "ir/SkSLIfStatement.h"
+#include "ir/SkSLIndexExpression.h"
+#include "ir/SkSLPrefixExpression.h"
+#include "ir/SkSLPostfixExpression.h"
+#include "ir/SkSLProgram.h"
+#include "ir/SkSLReturnStatement.h"
+#include "ir/SkSLStatement.h"
+#include "ir/SkSLSwizzle.h"
+#include "ir/SkSLTernaryExpression.h"
+#include "ir/SkSLVarDeclarationsStatement.h"
+#include "ir/SkSLVariableReference.h"
+#include "ir/SkSLWhileStatement.h"
+
+#include "llvm-c/Analysis.h"
+#include "llvm-c/Core.h"
+#include "llvm-c/OrcBindings.h"
+#include "llvm-c/Support.h"
+#include "llvm-c/Target.h"
+#include "llvm-c/Transforms/PassManagerBuilder.h"
+#include "llvm-c/Types.h"
+#include <stack>
+
+class SkRasterPipeline;
+
+namespace SkSL {
+
+/**
+ * A just-in-time compiler for SkSL code which uses an LLVM backend. Only available when the
+ * skia_llvm_path gn arg is set.
+ *
+ * Example of using SkSLJIT to set up an SkJumper pipeline stage:
+ *
+ * #ifdef SK_LLVM_AVAILABLE
+ * SkSL::Compiler compiler;
+ * SkSL::Program::Settings settings;
+ * std::unique_ptr<SkSL::Program> program = compiler.convertProgram(SkSL::Program::kCPU_Kind,
+ * "void swap(int x, int y, inout float4 color) {"
+ * " color.rb = color.br;"
+ * "}",
+ * settings);
+ * if (!program) {
+ * printf("%s\n", compiler.errorText().c_str());
+ * abort();
+ * }
+ * SkSL::JIT& jit = *scratch->make<SkSL::JIT>(&compiler);
+ * std::unique_ptr<SkSL::JIT::Module> module = jit.compile(std::move(program));
+ * void* func = module->getJumperStage("swap");
+ * p->append(func, nullptr);
+ * #endif
+ */
+class JIT {
+ typedef int StackIndex;
+
+public:
+ class Module {
+ public:
+ /**
+ * Returns the address of a symbol in the module.
+ */
+ void* getSymbol(const char* name);
+
+ /**
+ * Returns the address of a function as an SkJumper pipeline stage. The function must have
+ * the signature void <name>(int x, int y, inout float4 color). The returned function will
+ * have the correct signature to function as an SkJumper stage (meaning it will actually
+ * have a different signature at runtime, accepting vector parameters and operating on
+ * multiple pixels simultaneously as is normal for SkJumper stages).
+ */
+ void* getJumperStage(const char* name);
+
+ ~Module() {
+ LLVMOrcDisposeSharedModuleRef(fSharedModule);
+ }
+
+ private:
+ Module(std::unique_ptr<Program> program,
+ LLVMSharedModuleRef sharedModule,
+ LLVMOrcJITStackRef jitStack)
+ : fProgram(std::move(program))
+ , fSharedModule(sharedModule)
+ , fJITStack(jitStack) {}
+
+ std::unique_ptr<Program> fProgram;
+ LLVMSharedModuleRef fSharedModule;
+ LLVMOrcJITStackRef fJITStack;
+
+ friend class JIT;
+ };
+
+ JIT(Compiler* compiler);
+
+ ~JIT();
+
+ /**
+ * Just-in-time compiles an SkSL program and returns the resulting Module. The JIT must not be
+ * destroyed before all of its Modules are destroyed.
+ */
+ std::unique_ptr<Module> compile(std::unique_ptr<Program> program);
+
+private:
+ static constexpr int CHANNELS = 4;
+
+ enum TypeKind {
+ kFloat_TypeKind,
+ kInt_TypeKind,
+ kUInt_TypeKind,
+ kBool_TypeKind
+ };
+
+ class LValue {
+ public:
+ virtual ~LValue() {}
+
+ virtual LLVMValueRef load(LLVMBuilderRef builder) = 0;
+
+ virtual void store(LLVMBuilderRef builder, LLVMValueRef value) = 0;
+ };
+
+ void addBuiltinFunction(const char* ourName, const char* realName, LLVMTypeRef returnType,
+ std::vector<LLVMTypeRef> parameters);
+
+ void loadBuiltinFunctions();
+
+ void setBlock(LLVMBuilderRef builder, LLVMBasicBlockRef block);
+
+ LLVMTypeRef getType(const Type& type);
+
+ TypeKind typeKind(const Type& type);
+
+ std::unique_ptr<LValue> getLValue(LLVMBuilderRef builder, const Expression& expr);
+
+ void vectorize(LLVMBuilderRef builder, LLVMValueRef* value, int columns);
+
+ void vectorize(LLVMBuilderRef builder, const BinaryExpression& b, LLVMValueRef* left,
+ LLVMValueRef* right);
+
+ LLVMValueRef compileBinary(LLVMBuilderRef builder, const BinaryExpression& b);
+
+ LLVMValueRef compileConstructor(LLVMBuilderRef builder, const Constructor& c);
+
+ LLVMValueRef compileFunctionCall(LLVMBuilderRef builder, const FunctionCall& fc);
+
+ LLVMValueRef compileIndex(LLVMBuilderRef builder, const IndexExpression& v);
+
+ LLVMValueRef compilePostfix(LLVMBuilderRef builder, const PostfixExpression& p);
+
+ LLVMValueRef compilePrefix(LLVMBuilderRef builder, const PrefixExpression& p);
+
+ LLVMValueRef compileSwizzle(LLVMBuilderRef builder, const Swizzle& s);
+
+ LLVMValueRef compileVariableReference(LLVMBuilderRef builder, const VariableReference& v);
+
+ LLVMValueRef compileTernary(LLVMBuilderRef builder, const TernaryExpression& t);
+
+ LLVMValueRef compileExpression(LLVMBuilderRef builder, const Expression& expr);
+
+ void appendStage(LLVMBuilderRef builder, const AppendStage& a);
+
+ void compileBlock(LLVMBuilderRef builder, const Block& block);
+
+ void compileBreak(LLVMBuilderRef builder, const BreakStatement& b);
+
+ void compileContinue(LLVMBuilderRef builder, const ContinueStatement& c);
+
+ void compileDo(LLVMBuilderRef builder, const DoStatement& d);
+
+ void compileFor(LLVMBuilderRef builder, const ForStatement& f);
+
+ void compileIf(LLVMBuilderRef builder, const IfStatement& i);
+
+ void compileReturn(LLVMBuilderRef builder, const ReturnStatement& r);
+
+ void compileVarDeclarations(LLVMBuilderRef builder, const VarDeclarationsStatement& decls);
+
+ void compileWhile(LLVMBuilderRef builder, const WhileStatement& w);
+
+ void compileStatement(LLVMBuilderRef builder, const Statement& stmt);
+
+ // The "Vector" variants of functions attempt to compile a given expression or statement as part
+ // of a vectorized SkJumper stage function - that is, with r, g, b, and a each being vectors of
+ // fVectorCount floats. So a statement like "color.r = 0;" looks like it modifies a single
+ // channel of a single pixel, but the compiled code will actually modify the red channel of
+ // fVectorCount pixels at once.
+ //
+ // As not everything can be vectorized, these calls return a bool to indicate whether they were
+ // successful. If anything anywhere in the function cannot be vectorized, the JIT will fall back
+ // to looping over the pixels instead.
+ //
+ // Since we process multiple pixels at once, and each pixel consists of multiple color channels,
+ // expressions may effectively result in a vector-of-vectors. We produce zero to four outputs
+ // when compiling expression, each of which is a vector, so that e.g. float2(1, 0) actually
+ // produces two vectors, one containing all 1s, the other all 0s. The out parameter always
+ // allows for 4 channels, but the functions produce 0 to 4 channels depending on the type they
+ // are operating on. Thus evaluating "color.rgb" actually fills in out[0] through out[2],
+ // leaving out[3] uninitialized.
+ // As the number of outputs can be inferred from the type of the expression, it is not
+ // explicitly signalled anywhere.
+ bool compileVectorBinary(LLVMBuilderRef builder, const BinaryExpression& b,
+ LLVMValueRef out[CHANNELS]);
+
+ bool compileVectorConstructor(LLVMBuilderRef builder, const Constructor& c,
+ LLVMValueRef out[CHANNELS]);
+
+ bool compileVectorFloatLiteral(LLVMBuilderRef builder, const FloatLiteral& f,
+ LLVMValueRef out[CHANNELS]);
+
+ bool compileVectorSwizzle(LLVMBuilderRef builder, const Swizzle& s,
+ LLVMValueRef out[CHANNELS]);
+
+ bool compileVectorVariableReference(LLVMBuilderRef builder, const VariableReference& v,
+ LLVMValueRef out[CHANNELS]);
+
+ bool compileVectorExpression(LLVMBuilderRef builder, const Expression& expr,
+ LLVMValueRef out[CHANNELS]);
+
+ bool getVectorLValue(LLVMBuilderRef builder, const Expression& e, LLVMValueRef out[CHANNELS]);
+
+ /**
+ * Evaluates the left and right operands of a binary operation, promoting one of them to a
+ * vector if necessary to make the types match.
+ */
+ bool getVectorBinaryOperands(LLVMBuilderRef builder, const Expression& left,
+ LLVMValueRef outLeft[CHANNELS], const Expression& right,
+ LLVMValueRef outRight[CHANNELS]);
+
+ bool compileVectorStatement(LLVMBuilderRef builder, const Statement& stmt);
+
+ /**
+ * Returns true if this function has the signature void(int, int, inout float4) and thus can be
+ * used as an SkJumper stage.
+ */
+ bool hasStageSignature(const FunctionDeclaration& f);
+
+ /**
+ * Attempts to compile a vectorized stage function, returning true on success. A stage function
+ * of e.g. "color.r = 0;" will produce code which sets the entire red vector to zeros in a
+ * single instruction, thus calculating several pixels at once.
+ */
+ bool compileStageFunctionVector(const FunctionDefinition& f, LLVMValueRef newFunc);
+
+ /**
+ * Fallback function which loops over the pixels, for when vectorization fails. A stage function
+ * of e.g. "color.r = 0;" will produce a loop which iterates over the entries in the red vector,
+ * setting each one to zero individually.
+ */
+ void compileStageFunctionLoop(const FunctionDefinition& f, LLVMValueRef newFunc);
+
+ /**
+ * Called when compiling a function which has the signature of an SkJumper stage. Produces a
+ * version of the function which can be plugged into SkJumper (thus having a signature which
+ * accepts four vectors, one for each color channel, containing the color data of multiple
+ * pixels at once). To go from SkSL code which operates on a single pixel at a time to CPU code
+ * which operates on multiple pixels at once, the code is either vectorized using
+ * compileStageFunctionVector or wrapped in a loop using compileStageFunctionLoop.
+ */
+ LLVMValueRef compileStageFunction(const FunctionDefinition& f);
+
+ /**
+ * Compiles an SkSL function to an LLVM function. If the function has the signature of an
+ * SkJumper stage, it will *also* be compiled by compileStageFunction, resulting in both a stage
+ * and non-stage version of the function.
+ */
+ LLVMValueRef compileFunction(const FunctionDefinition& f);
+
+ void createModule();
+
+ void optimize();
+
+ bool isColorRef(const Expression& expr);
+
+ static uint64_t resolveSymbol(const char* name, JIT* jit);
+
+ const char* fCPU;
+ int fVectorCount;
+ Compiler& fCompiler;
+ std::unique_ptr<Program> fProgram;
+ LLVMContextRef fContext;
+ LLVMModuleRef fModule;
+ LLVMSharedModuleRef fSharedModule;
+ LLVMOrcJITStackRef fJITStack;
+ LLVMValueRef fCurrentFunction;
+ LLVMBasicBlockRef fAllocaBlock;
+ LLVMBasicBlockRef fCurrentBlock;
+ LLVMTypeRef fVoidType;
+ LLVMTypeRef fInt1Type;
+ LLVMTypeRef fInt8Type;
+ LLVMTypeRef fInt8PtrType;
+ LLVMTypeRef fInt32Type;
+ LLVMTypeRef fInt32VectorType;
+ LLVMTypeRef fInt32Vector2Type;
+ LLVMTypeRef fInt32Vector3Type;
+ LLVMTypeRef fInt32Vector4Type;
+ LLVMTypeRef fInt64Type;
+ LLVMTypeRef fSizeTType;
+ LLVMTypeRef fFloat32Type;
+ LLVMTypeRef fFloat32VectorType;
+ LLVMTypeRef fFloat32Vector2Type;
+ LLVMTypeRef fFloat32Vector3Type;
+ LLVMTypeRef fFloat32Vector4Type;
+ // Our SkSL stage functions have a single float4 for color, but the actual SkJumper stage
+ // function has four separate vectors, one for each channel. These four values are references to
+ // the red, green, blue, and alpha vectors respectively.
+ LLVMValueRef fChannels[CHANNELS];
+ // when processing a stage function, this points to the SkSL color parameter (an inout float4)
+ const Variable* fColorParam;
+ std::map<const FunctionDeclaration*, LLVMValueRef> fFunctions;
+ std::map<const Variable*, LLVMValueRef> fVariables;
+ // LLVM function parameters are read-only, so when modifying function parameters we need to
+ // first promote them to variables. This keeps track of which parameters have been promoted.
+ std::set<const Variable*> fPromotedParameters;
+ std::vector<LLVMBasicBlockRef> fBreakTarget;
+ std::vector<LLVMBasicBlockRef> fContinueTarget;
+
+ LLVMValueRef fAppendFunc;
+ LLVMValueRef fAppendCallbackFunc;
+ LLVMValueRef fDebugFunc;
+};
+
+} // namespace
+
+#endif // SK_LLVM_AVAILABLE
+
+#endif // SKSL_JIT
diff --git a/src/sksl/ir/SkSLAppendStage.h b/src/sksl/ir/SkSLAppendStage.h
new file mode 100644
index 0000000000..87a8210a83
--- /dev/null
+++ b/src/sksl/ir/SkSLAppendStage.h
@@ -0,0 +1,53 @@
+/*
+ * Copyright 2018 Google Inc.
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#ifndef SKSL_APPENDSTAGE
+#define SKSL_APPENDSTAGE
+
+#ifndef SKSL_STANDALONE
+
+#include "SkRasterPipeline.h"
+#include "SkSLContext.h"
+#include "SkSLExpression.h"
+
+namespace SkSL {
+
+struct AppendStage : public Expression {
+ AppendStage(const Context& context, int offset, SkRasterPipeline::StockStage stage,
+ std::vector<std::unique_ptr<Expression>> arguments)
+ : INHERITED(offset, kAppendStage_Kind, *context.fVoid_Type)
+ , fStage(stage)
+ , fArguments(std::move(arguments)) {}
+
+ String description() const {
+ String result = "append(";
+ const char* separator = "";
+ for (const auto& a : fArguments) {
+ result += separator;
+ result += a->description();
+ separator = ", ";
+ }
+ result += ")";
+ return result;
+ }
+
+ bool hasSideEffects() const {
+ return true;
+ }
+
+ SkRasterPipeline::StockStage fStage;
+
+ std::vector<std::unique_ptr<Expression>> fArguments;
+
+ typedef Expression INHERITED;
+};
+
+} // namespace
+
+#endif // SKSL_STANDALONE
+
+#endif // SKSL_APPENDSTAGE
diff --git a/src/sksl/ir/SkSLExpression.h b/src/sksl/ir/SkSLExpression.h
index 52fa5d297f..c8ad1380e7 100644
--- a/src/sksl/ir/SkSLExpression.h
+++ b/src/sksl/ir/SkSLExpression.h
@@ -25,6 +25,7 @@ typedef std::unordered_map<const Variable*, std::unique_ptr<Expression>*> Defini
*/
struct Expression : public IRNode {
enum Kind {
+ kAppendStage_Kind,
kBinary_Kind,
kBoolLiteral_Kind,
kConstructor_Kind,
diff --git a/src/sksl/ir/SkSLFunctionReference.h b/src/sksl/ir/SkSLFunctionReference.h
index 58831c5e99..58fefce801 100644
--- a/src/sksl/ir/SkSLFunctionReference.h
+++ b/src/sksl/ir/SkSLFunctionReference.h
@@ -29,7 +29,6 @@ struct FunctionReference : public Expression {
}
String description() const override {
- ASSERT(false);
return String("<function>");
}
diff --git a/src/sksl/ir/SkSLProgram.h b/src/sksl/ir/SkSLProgram.h
index a63cd237ac..cbb9dfe1a7 100644
--- a/src/sksl/ir/SkSLProgram.h
+++ b/src/sksl/ir/SkSLProgram.h
@@ -103,13 +103,14 @@ struct Program {
kFragment_Kind,
kVertex_Kind,
kGeometry_Kind,
- kFragmentProcessor_Kind
+ kFragmentProcessor_Kind,
+ kCPU_Kind
};
Program(Kind kind,
std::unique_ptr<String> source,
Settings settings,
- Context* context,
+ std::shared_ptr<Context> context,
std::vector<std::unique_ptr<ProgramElement>> elements,
std::shared_ptr<SymbolTable> symbols,
Inputs inputs)
@@ -124,7 +125,7 @@ struct Program {
Kind fKind;
std::unique_ptr<String> fSource;
Settings fSettings;
- Context* fContext;
+ std::shared_ptr<Context> fContext;
// it's important to keep fElements defined after (and thus destroyed before) fSymbols,
// because destroying elements can modify reference counts in symbols
std::shared_ptr<SymbolTable> fSymbols;
diff --git a/src/sksl/sksl_cpu.inc b/src/sksl/sksl_cpu.inc
new file mode 100644
index 0000000000..479450bd33
--- /dev/null
+++ b/src/sksl/sksl_cpu.inc
@@ -0,0 +1,12 @@
+STRINGIFY(
+ // special-cased within the compiler - append takes various arguments depending on what kind of
+ // stage is being appended
+ sk_has_side_effects void append();
+
+ float abs(float x);
+ float sin(float x);
+ float cos(float x);
+ float tan(float x);
+ float sqrt(float x);
+ sk_has_side_effects void print(float x);
+)