aboutsummaryrefslogtreecommitdiffhomepage
path: root/src/sksl/SkSLJIT.cpp
diff options
context:
space:
mode:
authorGravatar Ethan Nicholas <ethannicholas@google.com>2018-03-27 14:10:52 -0400
committerGravatar Skia Commit-Bot <skia-commit-bot@chromium.org>2018-03-27 18:39:13 +0000
commit26a9aad63b60c9cbbdfa87c212a4e76ce55e7373 (patch)
treed4a42afd75bdff6c7815be7bf78b42cab742df19 /src/sksl/SkSLJIT.cpp
parent3560b58de36988e1fba54d9ac341735ab849e913 (diff)
initial SkSLJIT checkin
Docs-Preview: https://skia.org/?cl=112204 Bug: skia: Change-Id: I10042a0200db00bd8ff8078467c409b1cf191f50 Reviewed-on: https://skia-review.googlesource.com/112204 Commit-Queue: Ethan Nicholas <ethannicholas@google.com> Reviewed-by: Mike Klein <mtklein@chromium.org>
Diffstat (limited to 'src/sksl/SkSLJIT.cpp')
-rw-r--r--src/sksl/SkSLJIT.cpp1747
1 files changed, 1747 insertions, 0 deletions
diff --git a/src/sksl/SkSLJIT.cpp b/src/sksl/SkSLJIT.cpp
new file mode 100644
index 0000000000..06b6e2c94b
--- /dev/null
+++ b/src/sksl/SkSLJIT.cpp
@@ -0,0 +1,1747 @@
+/*
+ * Copyright 2018 Google Inc.
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#ifndef SKSL_STANDALONE
+
+#ifdef SK_LLVM_AVAILABLE
+
+#include "SkSLJIT.h"
+
+#include "SkCpu.h"
+#include "SkRasterPipeline.h"
+#include "../jumper/SkJumper.h"
+#include "ir/SkSLExpressionStatement.h"
+#include "ir/SkSLFunctionCall.h"
+#include "ir/SkSLFunctionReference.h"
+#include "ir/SkSLIndexExpression.h"
+#include "ir/SkSLProgram.h"
+#include "llvm/ExecutionEngine/RTDyldMemoryManager.h"
+
+static constexpr int MAX_VECTOR_COUNT = 16;
+
+extern "C" void sksl_pipeline_append(SkRasterPipeline* p, int stage, void* ctx) {
+ p->append((SkRasterPipeline::StockStage) stage, ctx);
+}
+
+#define PTR_SIZE sizeof(void*)
+
+extern "C" void sksl_pipeline_append_callback(SkRasterPipeline* p, void* fn) {
+ p->append(fn, nullptr);
+}
+
+extern "C" void sksl_debug_print(float f) {
+ printf("Debug: %f\n", f);
+}
+
+namespace SkSL {
+
+static constexpr int STAGE_PARAM_COUNT = 12;
+
+static bool ends_with_branch(const Statement& stmt) {
+ switch (stmt.fKind) {
+ case Statement::kBlock_Kind: {
+ const Block& b = (const Block&) stmt;
+ if (b.fStatements.size()) {
+ return ends_with_branch(*b.fStatements.back());
+ }
+ return false;
+ }
+ case Statement::kBreak_Kind: // fall through
+ case Statement::kContinue_Kind: // fall through
+ case Statement::kReturn_Kind: // fall through
+ return true;
+ default:
+ return false;
+ }
+}
+
+JIT::JIT(Compiler* compiler)
+: fCompiler(*compiler) {
+ LLVMInitializeNativeTarget();
+ LLVMInitializeNativeAsmPrinter();
+ LLVMLinkInMCJIT();
+ ASSERT(!SkCpu::Supports(SkCpu::SKX)); // not yet supported
+ if (SkCpu::Supports(SkCpu::HSW)) {
+ fVectorCount = 8;
+ fCPU = "haswell";
+ } else if (SkCpu::Supports(SkCpu::AVX)) {
+ fVectorCount = 8;
+ fCPU = "ivybridge";
+ } else {
+ fVectorCount = 4;
+ fCPU = nullptr;
+ }
+ fContext = LLVMContextCreate();
+ fVoidType = LLVMVoidTypeInContext(fContext);
+ fInt1Type = LLVMInt1TypeInContext(fContext);
+ fInt8Type = LLVMInt8TypeInContext(fContext);
+ fInt8PtrType = LLVMPointerType(fInt8Type, 0);
+ fInt32Type = LLVMInt32TypeInContext(fContext);
+ fInt64Type = LLVMInt64TypeInContext(fContext);
+ fSizeTType = LLVMInt64TypeInContext(fContext);
+ fInt32VectorType = LLVMVectorType(fInt32Type, fVectorCount);
+ fInt32Vector2Type = LLVMVectorType(fInt32Type, 2);
+ fInt32Vector3Type = LLVMVectorType(fInt32Type, 3);
+ fInt32Vector4Type = LLVMVectorType(fInt32Type, 4);
+ fFloat32Type = LLVMFloatTypeInContext(fContext);
+ fFloat32VectorType = LLVMVectorType(fFloat32Type, fVectorCount);
+ fFloat32Vector2Type = LLVMVectorType(fFloat32Type, 2);
+ fFloat32Vector3Type = LLVMVectorType(fFloat32Type, 3);
+ fFloat32Vector4Type = LLVMVectorType(fFloat32Type, 4);
+}
+
+JIT::~JIT() {
+ LLVMOrcDisposeInstance(fJITStack);
+ LLVMContextDispose(fContext);
+}
+
+void JIT::addBuiltinFunction(const char* ourName, const char* realName, LLVMTypeRef returnType,
+ std::vector<LLVMTypeRef> parameters) {
+ for (const auto& pair : *fProgram->fSymbols) {
+ if (Symbol::kFunctionDeclaration_Kind == pair.second->fKind) {
+ const FunctionDeclaration& f = (const FunctionDeclaration&) *pair.second;
+ if (pair.first != ourName || returnType != this->getType(f.fReturnType) ||
+ parameters.size() != f.fParameters.size()) {
+ continue;
+ }
+ for (size_t i = 0; i < parameters.size(); ++i) {
+ if (parameters[i] != this->getType(f.fParameters[i]->fType)) {
+ goto next;
+ }
+ }
+ fFunctions[&f] = LLVMAddFunction(fModule, realName, LLVMFunctionType(returnType,
+ parameters.data(),
+ parameters.size(),
+ false));
+ }
+ next:;
+ }
+}
+
+void JIT::loadBuiltinFunctions() {
+ this->addBuiltinFunction("abs", "fabs", fFloat32Type, { fFloat32Type });
+ this->addBuiltinFunction("sin", "sinf", fFloat32Type, { fFloat32Type });
+ this->addBuiltinFunction("cos", "cosf", fFloat32Type, { fFloat32Type });
+ this->addBuiltinFunction("tan", "tanf", fFloat32Type, { fFloat32Type });
+ this->addBuiltinFunction("sqrt", "sqrtf", fFloat32Type, { fFloat32Type });
+ this->addBuiltinFunction("print", "sksl_debug_print", fVoidType, { fFloat32Type });
+}
+
+uint64_t JIT::resolveSymbol(const char* name, JIT* jit) {
+ LLVMOrcTargetAddress result;
+ if (!LLVMOrcGetSymbolAddress(jit->fJITStack, &result, name)) {
+ if (!strcmp(name, "_sksl_pipeline_append")) {
+ result = (uint64_t) &sksl_pipeline_append;
+ } else if (!strcmp(name, "_sksl_pipeline_append_callback")) {
+ result = (uint64_t) &sksl_pipeline_append_callback;
+ } else if (!strcmp(name, "_sksl_debug_print")) {
+ result = (uint64_t) &sksl_debug_print;
+ } else {
+ result = llvm::RTDyldMemoryManager::getSymbolAddressInProcess(name);
+ }
+ }
+ ASSERT(result);
+ return result;
+}
+
+LLVMValueRef JIT::compileFunctionCall(LLVMBuilderRef builder, const FunctionCall& fc) {
+ LLVMValueRef func = fFunctions[&fc.fFunction];
+ ASSERT(func);
+ std::vector<LLVMValueRef> parameters;
+ for (const auto& a : fc.fArguments) {
+ parameters.push_back(this->compileExpression(builder, *a));
+ }
+ return LLVMBuildCall(builder, func, parameters.data(), parameters.size(), "");
+}
+
+LLVMTypeRef JIT::getType(const Type& type) {
+ switch (type.kind()) {
+ case Type::kOther_Kind:
+ if (type.name() == "void") {
+ return fVoidType;
+ }
+ ASSERT(type.name() == "SkRasterPipeline");
+ return fInt8PtrType;
+ case Type::kScalar_Kind:
+ if (type.isSigned() || type.isUnsigned()) {
+ return fInt32Type;
+ }
+ if (type.isUnsigned()) {
+ return fInt32Type;
+ }
+ if (type.isFloat()) {
+ return fFloat32Type;
+ }
+ ASSERT(type.name() == "bool");
+ return fInt1Type;
+ case Type::kArray_Kind:
+ return LLVMPointerType(this->getType(type.componentType()), 0);
+ case Type::kVector_Kind:
+ if (type.name() == "float2" || type.name() == "half2") {
+ return fFloat32Vector2Type;
+ }
+ if (type.name() == "float3" || type.name() == "half3") {
+ return fFloat32Vector3Type;
+ }
+ if (type.name() == "float4" || type.name() == "half4") {
+ return fFloat32Vector4Type;
+ }
+ if (type.name() == "int2" || type.name() == "short2") {
+ return fInt32Vector2Type;
+ }
+ if (type.name() == "int3" || type.name() == "short3") {
+ return fInt32Vector3Type;
+ }
+ if (type.name() == "int4" || type.name() == "short4") {
+ return fInt32Vector4Type;
+ }
+ // fall through
+ default:
+ ABORT("unsupported type");
+ }
+}
+
+void JIT::setBlock(LLVMBuilderRef builder, LLVMBasicBlockRef block) {
+ fCurrentBlock = block;
+ LLVMPositionBuilderAtEnd(builder, block);
+}
+
+std::unique_ptr<JIT::LValue> JIT::getLValue(LLVMBuilderRef builder, const Expression& expr) {
+ switch (expr.fKind) {
+ case Expression::kVariableReference_Kind: {
+ class PointerLValue : public LValue {
+ public:
+ PointerLValue(LLVMValueRef ptr)
+ : fPointer(ptr) {}
+
+ LLVMValueRef load(LLVMBuilderRef builder) override {
+ return LLVMBuildLoad(builder, fPointer, "lvalue load");
+ }
+
+ void store(LLVMBuilderRef builder, LLVMValueRef value) override {
+ LLVMBuildStore(builder, value, fPointer);
+ }
+
+ private:
+ LLVMValueRef fPointer;
+ };
+ const Variable* var = &((VariableReference&) expr).fVariable;
+ if (var->fStorage == Variable::kParameter_Storage &&
+ !(var->fModifiers.fFlags & Modifiers::kOut_Flag) &&
+ fPromotedParameters.find(var) == fPromotedParameters.end()) {
+ // promote parameter to variable
+ fPromotedParameters.insert(var);
+ LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
+ LLVMValueRef alloca = LLVMBuildAlloca(builder, this->getType(var->fType),
+ String(var->fName).c_str());
+ LLVMBuildStore(builder, fVariables[var], alloca);
+ LLVMPositionBuilderAtEnd(builder, fCurrentBlock);
+ fVariables[var] = alloca;
+ }
+ LLVMValueRef ptr = fVariables[var];
+ return std::unique_ptr<LValue>(new PointerLValue(ptr));
+ }
+ case Expression::kTernary_Kind: {
+ class TernaryLValue : public LValue {
+ public:
+ TernaryLValue(JIT* jit, LLVMValueRef test, std::unique_ptr<LValue> ifTrue,
+ std::unique_ptr<LValue> ifFalse)
+ : fJIT(*jit)
+ , fTest(test)
+ , fIfTrue(std::move(ifTrue))
+ , fIfFalse(std::move(ifFalse)) {}
+
+ LLVMValueRef load(LLVMBuilderRef builder) override {
+ LLVMBasicBlockRef trueBlock = LLVMAppendBasicBlockInContext(
+ fJIT.fContext,
+ fJIT.fCurrentFunction,
+ "true ? ...");
+ LLVMBasicBlockRef falseBlock = LLVMAppendBasicBlockInContext(
+ fJIT.fContext,
+ fJIT.fCurrentFunction,
+ "false ? ...");
+ LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fJIT.fContext,
+ fJIT.fCurrentFunction,
+ "ternary merge");
+ LLVMBuildCondBr(builder, fTest, trueBlock, falseBlock);
+ fJIT.setBlock(builder, trueBlock);
+ LLVMValueRef ifTrue = fIfTrue->load(builder);
+ LLVMBuildBr(builder, merge);
+ fJIT.setBlock(builder, falseBlock);
+ LLVMValueRef ifFalse = fIfTrue->load(builder);
+ LLVMBuildBr(builder, merge);
+ fJIT.setBlock(builder, merge);
+ LLVMTypeRef type = LLVMPointerType(LLVMTypeOf(ifTrue), 0);
+ LLVMValueRef phi = LLVMBuildPhi(builder, type, "?");
+ LLVMValueRef incomingValues[2] = { ifTrue, ifFalse };
+ LLVMBasicBlockRef incomingBlocks[2] = { trueBlock, falseBlock };
+ LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2);
+ return phi;
+ }
+
+ void store(LLVMBuilderRef builder, LLVMValueRef value) override {
+ LLVMBasicBlockRef trueBlock = LLVMAppendBasicBlockInContext(
+ fJIT.fContext,
+ fJIT.fCurrentFunction,
+ "true ? ...");
+ LLVMBasicBlockRef falseBlock = LLVMAppendBasicBlockInContext(
+ fJIT.fContext,
+ fJIT.fCurrentFunction,
+ "false ? ...");
+ LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fJIT.fContext,
+ fJIT.fCurrentFunction,
+ "ternary merge");
+ LLVMBuildCondBr(builder, fTest, trueBlock, falseBlock);
+ fJIT.setBlock(builder, trueBlock);
+ fIfTrue->store(builder, value);
+ LLVMBuildBr(builder, merge);
+ fJIT.setBlock(builder, falseBlock);
+ fIfTrue->store(builder, value);
+ LLVMBuildBr(builder, merge);
+ fJIT.setBlock(builder, merge);
+ }
+
+ private:
+ JIT& fJIT;
+ LLVMValueRef fTest;
+ std::unique_ptr<LValue> fIfTrue;
+ std::unique_ptr<LValue> fIfFalse;
+ };
+ const TernaryExpression& t = (const TernaryExpression&) expr;
+ LLVMValueRef test = this->compileExpression(builder, *t.fTest);
+ return std::unique_ptr<LValue>(new TernaryLValue(this,
+ test,
+ this->getLValue(builder,
+ *t.fIfTrue),
+ this->getLValue(builder,
+ *t.fIfFalse)));
+ }
+ case Expression::kSwizzle_Kind: {
+ class SwizzleLValue : public LValue {
+ public:
+ SwizzleLValue(JIT* jit, LLVMTypeRef type, std::unique_ptr<LValue> base,
+ std::vector<int> components)
+ : fJIT(*jit)
+ , fType(type)
+ , fBase(std::move(base))
+ , fComponents(components) {}
+
+ LLVMValueRef load(LLVMBuilderRef builder) override {
+ LLVMValueRef base = fBase->load(builder);
+ if (fComponents.size() > 1) {
+ LLVMValueRef result = LLVMGetUndef(fType);
+ for (size_t i = 0; i < fComponents.size(); ++i) {
+ LLVMValueRef element = LLVMBuildExtractElement(
+ builder,
+ base,
+ LLVMConstInt(fJIT.fInt32Type,
+ fComponents[i],
+ false),
+ "swizzle extract");
+ result = LLVMBuildInsertElement(builder, result, element,
+ LLVMConstInt(fJIT.fInt32Type, i, false),
+ "swizzle insert");
+ }
+ return result;
+ }
+ ASSERT(fComponents.size() == 1);
+ return LLVMBuildExtractElement(builder, base,
+ LLVMConstInt(fJIT.fInt32Type,
+ fComponents[0],
+ false),
+ "swizzle extract");
+ }
+
+ void store(LLVMBuilderRef builder, LLVMValueRef value) override {
+ LLVMValueRef result = fBase->load(builder);
+ if (fComponents.size() > 1) {
+ for (size_t i = 0; i < fComponents.size(); ++i) {
+ LLVMValueRef element = LLVMBuildExtractElement(builder, value,
+ LLVMConstInt(
+ fJIT.fInt32Type,
+ i,
+ false),
+ "swizzle extract");
+ result = LLVMBuildInsertElement(builder, result, element,
+ LLVMConstInt(fJIT.fInt32Type,
+ fComponents[i],
+ false),
+ "swizzle insert");
+ }
+ } else {
+ result = LLVMBuildInsertElement(builder, result, value,
+ LLVMConstInt(fJIT.fInt32Type,
+ fComponents[0],
+ false),
+ "swizzle insert");
+ }
+ fBase->store(builder, result);
+ }
+
+ private:
+ JIT& fJIT;
+ LLVMTypeRef fType;
+ std::unique_ptr<LValue> fBase;
+ std::vector<int> fComponents;
+ };
+ const Swizzle& s = (const Swizzle&) expr;
+ return std::unique_ptr<LValue>(new SwizzleLValue(this, this->getType(s.fType),
+ this->getLValue(builder, *s.fBase),
+ s.fComponents));
+ }
+ default:
+ ABORT("unsupported lvalue");
+ }
+}
+
+JIT::TypeKind JIT::typeKind(const Type& type) {
+ if (type.kind() == Type::kVector_Kind) {
+ return this->typeKind(type.componentType());
+ }
+ if (type.fName == "int" || type.fName == "short") {
+ return JIT::kInt_TypeKind;
+ } else if (type.fName == "uint" || type.fName == "ushort") {
+ return JIT::kUInt_TypeKind;
+ } else if (type.fName == "float" || type.fName == "double") {
+ return JIT::kFloat_TypeKind;
+ }
+ ABORT("unsupported type: %s\n", type.description().c_str());
+}
+
+void JIT::vectorize(LLVMBuilderRef builder, LLVMValueRef* value, int columns) {
+ LLVMValueRef result = LLVMGetUndef(LLVMVectorType(LLVMTypeOf(*value), columns));
+ for (int i = 0; i < columns; ++i) {
+ result = LLVMBuildInsertElement(builder,
+ result,
+ *value,
+ LLVMConstInt(fInt32Type, i, false),
+ "vectorize");
+ }
+ *value = result;
+}
+
+void JIT::vectorize(LLVMBuilderRef builder, const BinaryExpression& b, LLVMValueRef* left,
+ LLVMValueRef* right) {
+ if (b.fLeft->fType.kind() == Type::kScalar_Kind &&
+ b.fRight->fType.kind() == Type::kVector_Kind) {
+ this->vectorize(builder, left, b.fRight->fType.columns());
+ } else if (b.fLeft->fType.kind() == Type::kVector_Kind &&
+ b.fRight->fType.kind() == Type::kScalar_Kind) {
+ this->vectorize(builder, right, b.fLeft->fType.columns());
+ }
+}
+
+
+LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression& b) {
+ #define BINARY(SFunc, UFunc, FFunc) { \
+ LLVMValueRef left = this->compileExpression(builder, *b.fLeft); \
+ LLVMValueRef right = this->compileExpression(builder, *b.fRight); \
+ this->vectorize(builder, b, &left, &right); \
+ switch (this->typeKind(b.fLeft->fType)) { \
+ case kInt_TypeKind: \
+ return SFunc(builder, left, right, "binary"); \
+ case kUInt_TypeKind: \
+ return UFunc(builder, left, right, "binary"); \
+ case kFloat_TypeKind: \
+ return FFunc(builder, left, right, "binary"); \
+ default: \
+ ABORT("unsupported typeKind"); \
+ } \
+ }
+ #define COMPOUND(SFunc, UFunc, FFunc) { \
+ std::unique_ptr<LValue> lvalue = this->getLValue(builder, *b.fLeft); \
+ LLVMValueRef left = lvalue->load(builder); \
+ LLVMValueRef right = this->compileExpression(builder, *b.fRight); \
+ this->vectorize(builder, b, &left, &right); \
+ LLVMValueRef result; \
+ switch (this->typeKind(b.fLeft->fType)) { \
+ case kInt_TypeKind: \
+ result = SFunc(builder, left, right, "binary"); \
+ break; \
+ case kUInt_TypeKind: \
+ result = UFunc(builder, left, right, "binary"); \
+ break; \
+ case kFloat_TypeKind: \
+ result = FFunc(builder, left, right, "binary"); \
+ break; \
+ default: \
+ ABORT("unsupported typeKind"); \
+ } \
+ lvalue->store(builder, result); \
+ return result; \
+ }
+ #define COMPARE(SFunc, SOp, UFunc, UOp, FFunc, FOp) { \
+ LLVMValueRef left = this->compileExpression(builder, *b.fLeft); \
+ LLVMValueRef right = this->compileExpression(builder, *b.fRight); \
+ this->vectorize(builder, b, &left, &right); \
+ switch (this->typeKind(b.fLeft->fType)) { \
+ case kInt_TypeKind: \
+ return SFunc(builder, SOp, left, right, "binary"); \
+ case kUInt_TypeKind: \
+ return UFunc(builder, UOp, left, right, "binary"); \
+ case kFloat_TypeKind: \
+ return FFunc(builder, FOp, left, right, "binary"); \
+ default: \
+ ABORT("unsupported typeKind"); \
+ } \
+ }
+ switch (b.fOperator) {
+ case Token::EQ: {
+ std::unique_ptr<LValue> lvalue = this->getLValue(builder, *b.fLeft);
+ LLVMValueRef result = this->compileExpression(builder, *b.fRight);
+ lvalue->store(builder, result);
+ return result;
+ }
+ case Token::PLUS:
+ BINARY(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd);
+ case Token::MINUS:
+ BINARY(LLVMBuildSub, LLVMBuildSub, LLVMBuildFSub);
+ case Token::STAR:
+ BINARY(LLVMBuildMul, LLVMBuildMul, LLVMBuildFMul);
+ case Token::SLASH:
+ BINARY(LLVMBuildSDiv, LLVMBuildUDiv, LLVMBuildFDiv);
+ case Token::PERCENT:
+ BINARY(LLVMBuildSRem, LLVMBuildURem, LLVMBuildSRem);
+ case Token::BITWISEAND:
+ BINARY(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd);
+ case Token::BITWISEOR:
+ BINARY(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr);
+ case Token::PLUSEQ:
+ COMPOUND(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd);
+ case Token::MINUSEQ:
+ COMPOUND(LLVMBuildSub, LLVMBuildSub, LLVMBuildFSub);
+ case Token::STAREQ:
+ COMPOUND(LLVMBuildMul, LLVMBuildMul, LLVMBuildFMul);
+ case Token::SLASHEQ:
+ COMPOUND(LLVMBuildSDiv, LLVMBuildUDiv, LLVMBuildFDiv);
+ case Token::BITWISEANDEQ:
+ COMPOUND(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd);
+ case Token::BITWISEOREQ:
+ COMPOUND(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr);
+ case Token::EQEQ:
+ COMPARE(LLVMBuildICmp, LLVMIntEQ,
+ LLVMBuildICmp, LLVMIntEQ,
+ LLVMBuildFCmp, LLVMRealOEQ);
+ case Token::NEQ:
+ COMPARE(LLVMBuildICmp, LLVMIntNE,
+ LLVMBuildICmp, LLVMIntNE,
+ LLVMBuildFCmp, LLVMRealONE);
+ case Token::LT:
+ COMPARE(LLVMBuildICmp, LLVMIntSLT,
+ LLVMBuildICmp, LLVMIntULT,
+ LLVMBuildFCmp, LLVMRealOLT);
+ case Token::LTEQ:
+ COMPARE(LLVMBuildICmp, LLVMIntSLE,
+ LLVMBuildICmp, LLVMIntULE,
+ LLVMBuildFCmp, LLVMRealOLE);
+ case Token::GT:
+ COMPARE(LLVMBuildICmp, LLVMIntSGT,
+ LLVMBuildICmp, LLVMIntUGT,
+ LLVMBuildFCmp, LLVMRealOGT);
+ case Token::GTEQ:
+ COMPARE(LLVMBuildICmp, LLVMIntSGE,
+ LLVMBuildICmp, LLVMIntUGE,
+ LLVMBuildFCmp, LLVMRealOGE);
+ case Token::LOGICALAND: {
+ LLVMValueRef left = this->compileExpression(builder, *b.fLeft);
+ LLVMBasicBlockRef ifFalse = fCurrentBlock;
+ LLVMBasicBlockRef ifTrue = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
+ "true && ...");
+ LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
+ "&& merge");
+ LLVMBuildCondBr(builder, left, ifTrue, merge);
+ this->setBlock(builder, ifTrue);
+ LLVMValueRef right = this->compileExpression(builder, *b.fRight);
+ LLVMBuildBr(builder, merge);
+ this->setBlock(builder, merge);
+ LLVMValueRef phi = LLVMBuildPhi(builder, fInt1Type, "&&");
+ LLVMValueRef incomingValues[2] = { right, LLVMConstInt(fInt1Type, 0, false) };
+ LLVMBasicBlockRef incomingBlocks[2] = { ifTrue, ifFalse };
+ LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2);
+ return phi;
+ }
+ case Token::LOGICALOR: {
+ LLVMValueRef left = this->compileExpression(builder, *b.fLeft);
+ LLVMBasicBlockRef ifTrue = fCurrentBlock;
+ LLVMBasicBlockRef ifFalse = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
+ "false || ...");
+ LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
+ "|| merge");
+ LLVMBuildCondBr(builder, left, merge, ifFalse);
+ this->setBlock(builder, ifFalse);
+ LLVMValueRef right = this->compileExpression(builder, *b.fRight);
+ LLVMBuildBr(builder, merge);
+ this->setBlock(builder, merge);
+ LLVMValueRef phi = LLVMBuildPhi(builder, fInt1Type, "||");
+ LLVMValueRef incomingValues[2] = { right, LLVMConstInt(fInt1Type, 1, false) };
+ LLVMBasicBlockRef incomingBlocks[2] = { ifFalse, ifTrue };
+ LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2);
+ return phi;
+ }
+ default:
+ ABORT("unsupported binary operator");
+ }
+}
+
+LLVMValueRef JIT::compileIndex(LLVMBuilderRef builder, const IndexExpression& idx) {
+ LLVMValueRef base = this->compileExpression(builder, *idx.fBase);
+ LLVMValueRef index = this->compileExpression(builder, *idx.fIndex);
+ LLVMValueRef ptr = LLVMBuildGEP(builder, base, &index, 1, "index ptr");
+ return LLVMBuildLoad(builder, ptr, "index load");
+}
+
+LLVMValueRef JIT::compilePostfix(LLVMBuilderRef builder, const PostfixExpression& p) {
+ std::unique_ptr<LValue> lvalue = this->getLValue(builder, *p.fOperand);
+ LLVMValueRef result = lvalue->load(builder);
+ LLVMValueRef mod;
+ LLVMValueRef one = LLVMConstInt(this->getType(p.fType), 1, false);
+ switch (p.fOperator) {
+ case Token::PLUSPLUS:
+ switch (this->typeKind(p.fType)) {
+ case kInt_TypeKind: // fall through
+ case kUInt_TypeKind:
+ mod = LLVMBuildAdd(builder, result, one, "++");
+ break;
+ case kFloat_TypeKind:
+ mod = LLVMBuildFAdd(builder, result, one, "++");
+ break;
+ default:
+ ABORT("unsupported typeKind");
+ }
+ break;
+ case Token::MINUSMINUS:
+ switch (this->typeKind(p.fType)) {
+ case kInt_TypeKind: // fall through
+ case kUInt_TypeKind:
+ mod = LLVMBuildSub(builder, result, one, "--");
+ break;
+ case kFloat_TypeKind:
+ mod = LLVMBuildFSub(builder, result, one, "--");
+ break;
+ default:
+ ABORT("unsupported typeKind");
+ }
+ break;
+ default:
+ ABORT("unsupported postfix op");
+ }
+ lvalue->store(builder, mod);
+ return result;
+}
+
+LLVMValueRef JIT::compilePrefix(LLVMBuilderRef builder, const PrefixExpression& p) {
+ LLVMValueRef one = LLVMConstInt(this->getType(p.fType), 1, false);
+ if (Token::LOGICALNOT == p.fOperator) {
+ LLVMValueRef base = this->compileExpression(builder, *p.fOperand);
+ return LLVMBuildXor(builder, base, one, "!");
+ }
+ if (Token::MINUS == p.fOperator) {
+ LLVMValueRef base = this->compileExpression(builder, *p.fOperand);
+ return LLVMBuildSub(builder, LLVMConstInt(this->getType(p.fType), 0, false), base, "-");
+ }
+ std::unique_ptr<LValue> lvalue = this->getLValue(builder, *p.fOperand);
+ LLVMValueRef raw = lvalue->load(builder);
+ LLVMValueRef result;
+ switch (p.fOperator) {
+ case Token::PLUSPLUS:
+ switch (this->typeKind(p.fType)) {
+ case kInt_TypeKind: // fall through
+ case kUInt_TypeKind:
+ result = LLVMBuildAdd(builder, raw, one, "++");
+ break;
+ case kFloat_TypeKind:
+ result = LLVMBuildFAdd(builder, raw, one, "++");
+ break;
+ default:
+ ABORT("unsupported typeKind");
+ }
+ break;
+ case Token::MINUSMINUS:
+ switch (this->typeKind(p.fType)) {
+ case kInt_TypeKind: // fall through
+ case kUInt_TypeKind:
+ result = LLVMBuildSub(builder, raw, one, "--");
+ break;
+ case kFloat_TypeKind:
+ result = LLVMBuildFSub(builder, raw, one, "--");
+ break;
+ default:
+ ABORT("unsupported typeKind");
+ }
+ break;
+ default:
+ ABORT("unsupported prefix op");
+ }
+ lvalue->store(builder, result);
+ return result;
+}
+
+LLVMValueRef JIT::compileVariableReference(LLVMBuilderRef builder, const VariableReference& v) {
+ const Variable& var = v.fVariable;
+ if (Variable::kParameter_Storage == var.fStorage &&
+ !(var.fModifiers.fFlags & Modifiers::kOut_Flag) &&
+ fPromotedParameters.find(&var) == fPromotedParameters.end()) {
+ return fVariables[&var];
+ }
+ return LLVMBuildLoad(builder, fVariables[&var], String(var.fName).c_str());
+}
+
+void JIT::appendStage(LLVMBuilderRef builder, const AppendStage& a) {
+ ASSERT(a.fArguments.size() >= 1);
+ ASSERT(a.fArguments[0]->fType == *fCompiler.context().fSkRasterPipeline_Type);
+ LLVMValueRef pipeline = this->compileExpression(builder, *a.fArguments[0]);
+ LLVMValueRef stage = LLVMConstInt(fInt32Type, a.fStage, 0);
+ switch (a.fStage) {
+ case SkRasterPipeline::callback: {
+ ASSERT(a.fArguments.size() == 2);
+ ASSERT(a.fArguments[1]->fKind == Expression::kFunctionReference_Kind);
+ const FunctionDeclaration& functionDecl =
+ *((FunctionReference&) *a.fArguments[1]).fFunctions[0];
+ bool found = false;
+ for (const auto& pe : fProgram->fElements) {
+ if (ProgramElement::kFunction_Kind == pe->fKind) {
+ const FunctionDefinition& def = (const FunctionDefinition&) *pe;
+ if (&def.fDeclaration == &functionDecl) {
+ LLVMValueRef fn = this->compileStageFunction(def);
+ LLVMValueRef args[2] = {
+ pipeline,
+ LLVMBuildBitCast(builder, fn, fInt8PtrType, "callback cast")
+ };
+ LLVMBuildCall(builder, fAppendCallbackFunc, args, 2, "");
+ found = true;
+ break;
+ }
+ }
+ }
+ ASSERT(found);
+ break;
+ }
+ default: {
+ LLVMValueRef ctx;
+ if (a.fArguments.size() == 2) {
+ ctx = this->compileExpression(builder, *a.fArguments[1]);
+ ctx = LLVMBuildBitCast(builder, ctx, fInt8PtrType, "context cast");
+ } else {
+ ASSERT(a.fArguments.size() == 1);
+ ctx = LLVMConstNull(fInt8PtrType);
+ }
+ LLVMValueRef args[3] = {
+ pipeline,
+ stage,
+ ctx
+ };
+ LLVMBuildCall(builder, fAppendFunc, args, 3, "");
+ break;
+ }
+ }
+}
+
+LLVMValueRef JIT::compileConstructor(LLVMBuilderRef builder, const Constructor& c) {
+ switch (c.fType.kind()) {
+ case Type::kScalar_Kind: {
+ ASSERT(c.fArguments.size() == 1);
+ TypeKind from = this->typeKind(c.fArguments[0]->fType);
+ TypeKind to = this->typeKind(c.fType);
+ LLVMValueRef base = this->compileExpression(builder, *c.fArguments[0]);
+ if (kFloat_TypeKind == to) {
+ if (kInt_TypeKind == from) {
+ return LLVMBuildSIToFP(builder, base, this->getType(c.fType), "cast");
+ }
+ if (kUInt_TypeKind == from) {
+ return LLVMBuildUIToFP(builder, base, this->getType(c.fType), "cast");
+ }
+ }
+ if (kInt_TypeKind == to) {
+ if (kFloat_TypeKind == from) {
+ return LLVMBuildFPToSI(builder, base, this->getType(c.fType), "cast");
+ }
+ if (kUInt_TypeKind == from) {
+ return base;
+ }
+ }
+ if (kUInt_TypeKind == to) {
+ if (kFloat_TypeKind == from) {
+ return LLVMBuildFPToUI(builder, base, this->getType(c.fType), "cast");
+ }
+ if (kInt_TypeKind == from) {
+ return base;
+ }
+ }
+ ABORT("unsupported constructor");
+ }
+ case Type::kVector_Kind: {
+ LLVMValueRef vec = LLVMGetUndef(this->getType(c.fType));
+ if (c.fArguments.size() == 1) {
+ LLVMValueRef value = this->compileExpression(builder, *c.fArguments[0]);
+ for (int i = 0; i < c.fType.columns(); ++i) {
+ vec = LLVMBuildInsertElement(builder, vec, value,
+ LLVMConstInt(fInt32Type, i, false),
+ "vec build");
+ }
+ } else {
+ ASSERT(c.fArguments.size() == (size_t) c.fType.columns());
+ for (int i = 0; i < c.fType.columns(); ++i) {
+ vec = LLVMBuildInsertElement(builder, vec,
+ this->compileExpression(builder,
+ *c.fArguments[i]),
+ LLVMConstInt(fInt32Type, i, false),
+ "vec build");
+ }
+ }
+ return vec;
+ }
+ default:
+ break;
+ }
+ ABORT("unsupported constructor");
+}
+
+LLVMValueRef JIT::compileSwizzle(LLVMBuilderRef builder, const Swizzle& s) {
+ LLVMValueRef base = this->compileExpression(builder, *s.fBase);
+ if (s.fComponents.size() > 1) {
+ LLVMValueRef result = LLVMGetUndef(this->getType(s.fType));
+ for (size_t i = 0; i < s.fComponents.size(); ++i) {
+ LLVMValueRef element = LLVMBuildExtractElement(
+ builder,
+ base,
+ LLVMConstInt(fInt32Type,
+ s.fComponents[i],
+ false),
+ "swizzle extract");
+ result = LLVMBuildInsertElement(builder, result, element,
+ LLVMConstInt(fInt32Type, i, false),
+ "swizzle insert");
+ }
+ return result;
+ }
+ ASSERT(s.fComponents.size() == 1);
+ return LLVMBuildExtractElement(builder, base,
+ LLVMConstInt(fInt32Type,
+ s.fComponents[0],
+ false),
+ "swizzle extract");
+}
+
+LLVMValueRef JIT::compileTernary(LLVMBuilderRef builder, const TernaryExpression& t) {
+ LLVMValueRef test = this->compileExpression(builder, *t.fTest);
+ LLVMBasicBlockRef trueBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
+ "if true");
+ LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
+ "if merge");
+ LLVMBasicBlockRef falseBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
+ "if false");
+ LLVMBuildCondBr(builder, test, trueBlock, falseBlock);
+ this->setBlock(builder, trueBlock);
+ LLVMValueRef ifTrue = this->compileExpression(builder, *t.fIfTrue);
+ trueBlock = fCurrentBlock;
+ LLVMBuildBr(builder, merge);
+ this->setBlock(builder, falseBlock);
+ LLVMValueRef ifFalse = this->compileExpression(builder, *t.fIfFalse);
+ falseBlock = fCurrentBlock;
+ LLVMBuildBr(builder, merge);
+ this->setBlock(builder, merge);
+ LLVMValueRef phi = LLVMBuildPhi(builder, this->getType(t.fType), "?");
+ LLVMValueRef incomingValues[2] = { ifTrue, ifFalse };
+ LLVMBasicBlockRef incomingBlocks[2] = { trueBlock, falseBlock };
+ LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2);
+ return phi;
+}
+
+LLVMValueRef JIT::compileExpression(LLVMBuilderRef builder, const Expression& expr) {
+ switch (expr.fKind) {
+ case Expression::kAppendStage_Kind: {
+ this->appendStage(builder, (const AppendStage&) expr);
+ return LLVMValueRef();
+ }
+ case Expression::kBinary_Kind:
+ return this->compileBinary(builder, (BinaryExpression&) expr);
+ case Expression::kBoolLiteral_Kind:
+ return LLVMConstInt(fInt1Type, ((BoolLiteral&) expr).fValue, false);
+ case Expression::kConstructor_Kind:
+ return this->compileConstructor(builder, (Constructor&) expr);
+ case Expression::kIntLiteral_Kind:
+ return LLVMConstInt(this->getType(expr.fType), ((IntLiteral&) expr).fValue, true);
+ case Expression::kFieldAccess_Kind:
+ abort();
+ case Expression::kFloatLiteral_Kind:
+ return LLVMConstReal(this->getType(expr.fType), ((FloatLiteral&) expr).fValue);
+ case Expression::kFunctionCall_Kind:
+ return this->compileFunctionCall(builder, (FunctionCall&) expr);
+ case Expression::kIndex_Kind:
+ return this->compileIndex(builder, (IndexExpression&) expr);
+ case Expression::kPrefix_Kind:
+ return this->compilePrefix(builder, (PrefixExpression&) expr);
+ case Expression::kPostfix_Kind:
+ return this->compilePostfix(builder, (PostfixExpression&) expr);
+ case Expression::kSetting_Kind:
+ abort();
+ case Expression::kSwizzle_Kind:
+ return this->compileSwizzle(builder, (Swizzle&) expr);
+ case Expression::kVariableReference_Kind:
+ return this->compileVariableReference(builder, (VariableReference&) expr);
+ case Expression::kTernary_Kind:
+ return this->compileTernary(builder, (TernaryExpression&) expr);
+ case Expression::kTypeReference_Kind:
+ abort();
+ default:
+ abort();
+ }
+ ABORT("unsupported expression: %s\n", expr.description().c_str());
+}
+
+void JIT::compileBlock(LLVMBuilderRef builder, const Block& block) {
+ for (const auto& stmt : block.fStatements) {
+ this->compileStatement(builder, *stmt);
+ }
+}
+
+void JIT::compileVarDeclarations(LLVMBuilderRef builder, const VarDeclarationsStatement& decls) {
+ for (const auto& declStatement : decls.fDeclaration->fVars) {
+ const VarDeclaration& decl = (VarDeclaration&) *declStatement;
+ LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
+ LLVMValueRef alloca = LLVMBuildAlloca(builder, this->getType(decl.fVar->fType),
+ String(decl.fVar->fName).c_str());
+ fVariables[decl.fVar] = alloca;
+ LLVMPositionBuilderAtEnd(builder, fCurrentBlock);
+ if (decl.fValue) {
+ LLVMValueRef result = this->compileExpression(builder, *decl.fValue);
+ LLVMBuildStore(builder, result, alloca);
+ }
+ }
+}
+
+void JIT::compileIf(LLVMBuilderRef builder, const IfStatement& i) {
+ LLVMValueRef test = this->compileExpression(builder, *i.fTest);
+ LLVMBasicBlockRef ifTrue = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "if true");
+ LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
+ "if merge");
+ LLVMBasicBlockRef ifFalse;
+ if (i.fIfFalse) {
+ ifFalse = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "if false");
+ } else {
+ ifFalse = merge;
+ }
+ LLVMBuildCondBr(builder, test, ifTrue, ifFalse);
+ this->setBlock(builder, ifTrue);
+ this->compileStatement(builder, *i.fIfTrue);
+ if (!ends_with_branch(*i.fIfTrue)) {
+ LLVMBuildBr(builder, merge);
+ }
+ if (i.fIfFalse) {
+ this->setBlock(builder, ifFalse);
+ this->compileStatement(builder, *i.fIfFalse);
+ if (!ends_with_branch(*i.fIfFalse)) {
+ LLVMBuildBr(builder, merge);
+ }
+ }
+ this->setBlock(builder, merge);
+}
+
+void JIT::compileFor(LLVMBuilderRef builder, const ForStatement& f) {
+ if (f.fInitializer) {
+ this->compileStatement(builder, *f.fInitializer);
+ }
+ LLVMBasicBlockRef start;
+ LLVMBasicBlockRef body = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "for body");
+ LLVMBasicBlockRef next = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "for next");
+ LLVMBasicBlockRef end = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "for end");
+ if (f.fTest) {
+ start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "for test");
+ LLVMBuildBr(builder, start);
+ this->setBlock(builder, start);
+ LLVMValueRef test = this->compileExpression(builder, *f.fTest);
+ LLVMBuildCondBr(builder, test, body, end);
+ } else {
+ start = body;
+ LLVMBuildBr(builder, body);
+ }
+ this->setBlock(builder, body);
+ fBreakTarget.push_back(end);
+ fContinueTarget.push_back(next);
+ this->compileStatement(builder, *f.fStatement);
+ fBreakTarget.pop_back();
+ fContinueTarget.pop_back();
+ if (!ends_with_branch(*f.fStatement)) {
+ LLVMBuildBr(builder, next);
+ }
+ this->setBlock(builder, next);
+ if (f.fNext) {
+ this->compileExpression(builder, *f.fNext);
+ }
+ LLVMBuildBr(builder, start);
+ this->setBlock(builder, end);
+}
+
+void JIT::compileDo(LLVMBuilderRef builder, const DoStatement& d) {
+ LLVMBasicBlockRef testBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
+ "do test");
+ LLVMBasicBlockRef body = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
+ "do body");
+ LLVMBasicBlockRef end = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
+ "do end");
+ LLVMBuildBr(builder, body);
+ this->setBlock(builder, testBlock);
+ LLVMValueRef test = this->compileExpression(builder, *d.fTest);
+ LLVMBuildCondBr(builder, test, body, end);
+ this->setBlock(builder, body);
+ fBreakTarget.push_back(end);
+ fContinueTarget.push_back(body);
+ this->compileStatement(builder, *d.fStatement);
+ fBreakTarget.pop_back();
+ fContinueTarget.pop_back();
+ if (!ends_with_branch(*d.fStatement)) {
+ LLVMBuildBr(builder, testBlock);
+ }
+ this->setBlock(builder, end);
+}
+
+void JIT::compileWhile(LLVMBuilderRef builder, const WhileStatement& w) {
+ LLVMBasicBlockRef testBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
+ "while test");
+ LLVMBasicBlockRef body = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
+ "while body");
+ LLVMBasicBlockRef end = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
+ "while end");
+ LLVMBuildBr(builder, testBlock);
+ this->setBlock(builder, testBlock);
+ LLVMValueRef test = this->compileExpression(builder, *w.fTest);
+ LLVMBuildCondBr(builder, test, body, end);
+ this->setBlock(builder, body);
+ fBreakTarget.push_back(end);
+ fContinueTarget.push_back(testBlock);
+ this->compileStatement(builder, *w.fStatement);
+ fBreakTarget.pop_back();
+ fContinueTarget.pop_back();
+ if (!ends_with_branch(*w.fStatement)) {
+ LLVMBuildBr(builder, testBlock);
+ }
+ this->setBlock(builder, end);
+}
+
+void JIT::compileBreak(LLVMBuilderRef builder, const BreakStatement& b) {
+ LLVMBuildBr(builder, fBreakTarget.back());
+}
+
+void JIT::compileContinue(LLVMBuilderRef builder, const ContinueStatement& b) {
+ LLVMBuildBr(builder, fContinueTarget.back());
+}
+
+void JIT::compileReturn(LLVMBuilderRef builder, const ReturnStatement& r) {
+ if (r.fExpression) {
+ LLVMBuildRet(builder, this->compileExpression(builder, *r.fExpression));
+ } else {
+ LLVMBuildRetVoid(builder);
+ }
+}
+
+void JIT::compileStatement(LLVMBuilderRef builder, const Statement& stmt) {
+ switch (stmt.fKind) {
+ case Statement::kBlock_Kind:
+ this->compileBlock(builder, (Block&) stmt);
+ break;
+ case Statement::kBreak_Kind:
+ this->compileBreak(builder, (BreakStatement&) stmt);
+ break;
+ case Statement::kContinue_Kind:
+ this->compileContinue(builder, (ContinueStatement&) stmt);
+ break;
+ case Statement::kDiscard_Kind:
+ abort();
+ case Statement::kDo_Kind:
+ this->compileDo(builder, (DoStatement&) stmt);
+ break;
+ case Statement::kExpression_Kind:
+ this->compileExpression(builder, *((ExpressionStatement&) stmt).fExpression);
+ break;
+ case Statement::kFor_Kind:
+ this->compileFor(builder, (ForStatement&) stmt);
+ break;
+ case Statement::kGroup_Kind:
+ abort();
+ case Statement::kIf_Kind:
+ this->compileIf(builder, (IfStatement&) stmt);
+ break;
+ case Statement::kNop_Kind:
+ break;
+ case Statement::kReturn_Kind:
+ this->compileReturn(builder, (ReturnStatement&) stmt);
+ break;
+ case Statement::kSwitch_Kind:
+ abort();
+ case Statement::kVarDeclarations_Kind:
+ this->compileVarDeclarations(builder, (VarDeclarationsStatement&) stmt);
+ break;
+ case Statement::kWhile_Kind:
+ this->compileWhile(builder, (WhileStatement&) stmt);
+ break;
+ default:
+ abort();
+ }
+}
+
+void JIT::compileStageFunctionLoop(const FunctionDefinition& f, LLVMValueRef newFunc) {
+ // loop over fVectorCount pixels, running the body of the stage function for each of them
+ LLVMValueRef oldFunction = fCurrentFunction;
+ fCurrentFunction = newFunc;
+ std::unique_ptr<LLVMValueRef[]> params(new LLVMValueRef[STAGE_PARAM_COUNT]);
+ LLVMGetParams(fCurrentFunction, params.get());
+ LLVMValueRef programParam = params.get()[1];
+ LLVMBuilderRef builder = LLVMCreateBuilderInContext(fContext);
+ LLVMBasicBlockRef oldAllocaBlock = fAllocaBlock;
+ LLVMBasicBlockRef oldCurrentBlock = fCurrentBlock;
+ fAllocaBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "alloca");
+ this->setBlock(builder, fAllocaBlock);
+ // temporaries to store the color channel vectors
+ LLVMValueRef rVec = LLVMBuildAlloca(builder, fFloat32VectorType, "rVec");
+ LLVMBuildStore(builder, params.get()[4], rVec);
+ LLVMValueRef gVec = LLVMBuildAlloca(builder, fFloat32VectorType, "gVec");
+ LLVMBuildStore(builder, params.get()[5], gVec);
+ LLVMValueRef bVec = LLVMBuildAlloca(builder, fFloat32VectorType, "bVec");
+ LLVMBuildStore(builder, params.get()[6], bVec);
+ LLVMValueRef aVec = LLVMBuildAlloca(builder, fFloat32VectorType, "aVec");
+ LLVMBuildStore(builder, params.get()[7], aVec);
+ LLVMValueRef color = LLVMBuildAlloca(builder, fFloat32Vector4Type, "color");
+ fVariables[f.fDeclaration.fParameters[1]] = LLVMBuildTrunc(builder, params.get()[3], fInt32Type,
+ "y->Int32");
+ fVariables[f.fDeclaration.fParameters[2]] = color;
+ LLVMValueRef ivar = LLVMBuildAlloca(builder, fInt32Type, "i");
+ LLVMBuildStore(builder, LLVMConstInt(fInt32Type, 0, false), ivar);
+ LLVMBasicBlockRef start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "start");
+ this->setBlock(builder, start);
+ LLVMValueRef iload = LLVMBuildLoad(builder, ivar, "load i");
+ fVariables[f.fDeclaration.fParameters[0]] = LLVMBuildAdd(builder,
+ LLVMBuildTrunc(builder,
+ params.get()[2],
+ fInt32Type,
+ "x->Int32"),
+ iload,
+ "x");
+ LLVMValueRef vectorSize = LLVMConstInt(fInt32Type, fVectorCount, false);
+ LLVMValueRef test = LLVMBuildICmp(builder, LLVMIntSLT, iload, vectorSize, "i < vectorSize");
+ LLVMBasicBlockRef loopBody = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "body");
+ LLVMBasicBlockRef loopEnd = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "end");
+ LLVMBuildCondBr(builder, test, loopBody, loopEnd);
+ this->setBlock(builder, loopBody);
+ LLVMValueRef vec = LLVMGetUndef(fFloat32Vector4Type);
+ // extract the r, g, b, and a values from the color channel vectors and store them into "color"
+ for (int i = 0; i < 4; ++i) {
+ vec = LLVMBuildInsertElement(builder, vec,
+ LLVMBuildExtractElement(builder,
+ params.get()[4 + i],
+ iload, "initial"),
+ LLVMConstInt(fInt32Type, i, false),
+ "vec build");
+ }
+ LLVMBuildStore(builder, vec, color);
+ // write actual loop body
+ this->compileStatement(builder, *f.fBody);
+ // extract the r, g, b, and a values from "color" and stick them back into the color channel
+ // vectors
+ LLVMValueRef colorLoad = LLVMBuildLoad(builder, color, "color load");
+ LLVMBuildStore(builder,
+ LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, rVec, "rVec"),
+ LLVMBuildExtractElement(builder, colorLoad,
+ LLVMConstInt(fInt32Type, 0,
+ false),
+ "rExtract"),
+ iload, "rInsert"),
+ rVec);
+ LLVMBuildStore(builder,
+ LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, gVec, "gVec"),
+ LLVMBuildExtractElement(builder, colorLoad,
+ LLVMConstInt(fInt32Type, 1,
+ false),
+ "gExtract"),
+ iload, "gInsert"),
+ gVec);
+ LLVMBuildStore(builder,
+ LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, bVec, "bVec"),
+ LLVMBuildExtractElement(builder, colorLoad,
+ LLVMConstInt(fInt32Type, 2,
+ false),
+ "bExtract"),
+ iload, "bInsert"),
+ bVec);
+ LLVMBuildStore(builder,
+ LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, aVec, "aVec"),
+ LLVMBuildExtractElement(builder, colorLoad,
+ LLVMConstInt(fInt32Type, 3,
+ false),
+ "aExtract"),
+ iload, "aInsert"),
+ aVec);
+ LLVMValueRef inc = LLVMBuildAdd(builder, iload, LLVMConstInt(fInt32Type, 1, false), "inc i");
+ LLVMBuildStore(builder, inc, ivar);
+ LLVMBuildBr(builder, start);
+ this->setBlock(builder, loopEnd);
+ // increment program pointer, call the next stage
+ LLVMValueRef rawNextPtr = LLVMBuildLoad(builder, programParam, "next load");
+ LLVMTypeRef stageFuncType = LLVMTypeOf(newFunc);
+ LLVMValueRef nextPtr = LLVMBuildBitCast(builder, rawNextPtr, stageFuncType, "cast next->func");
+ LLVMValueRef nextInc = LLVMBuildIntToPtr(builder,
+ LLVMBuildAdd(builder,
+ LLVMBuildPtrToInt(builder,
+ programParam,
+ fInt64Type,
+ "cast 1"),
+ LLVMConstInt(fInt64Type, PTR_SIZE, false),
+ "add"),
+ LLVMPointerType(fInt8PtrType, 0), "cast 2");
+ LLVMValueRef args[STAGE_PARAM_COUNT] = {
+ params.get()[0],
+ nextInc,
+ params.get()[2],
+ params.get()[3],
+ LLVMBuildLoad(builder, rVec, "rVec"),
+ LLVMBuildLoad(builder, gVec, "gVec"),
+ LLVMBuildLoad(builder, bVec, "bVec"),
+ LLVMBuildLoad(builder, aVec, "aVec"),
+ params.get()[8],
+ params.get()[9],
+ params.get()[10],
+ params.get()[11]
+ };
+ LLVMBuildCall(builder, nextPtr, args, STAGE_PARAM_COUNT, "");
+ LLVMBuildRetVoid(builder);
+ // finish
+ LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
+ LLVMBuildBr(builder, start);
+ LLVMDisposeBuilder(builder);
+ if (LLVMVerifyFunction(fCurrentFunction, LLVMPrintMessageAction)) {
+ ABORT("verify failed\n");
+ }
+ fAllocaBlock = oldAllocaBlock;
+ fCurrentBlock = oldCurrentBlock;
+ fCurrentFunction = oldFunction;
+}
+
+// FIXME maybe pluggable code generators? Need to do something to separate all
+// of the normal codegen from the vector codegen and break this up into multiple
+// classes.
+
+bool JIT::getVectorLValue(LLVMBuilderRef builder, const Expression& e,
+ LLVMValueRef out[CHANNELS]) {
+ switch (e.fKind) {
+ case Expression::kVariableReference_Kind:
+ if (fColorParam == &((VariableReference&) e).fVariable) {
+ memcpy(out, fChannels, sizeof(fChannels));
+ return true;
+ }
+ return false;
+ case Expression::kSwizzle_Kind: {
+ const Swizzle& s = (const Swizzle&) e;
+ LLVMValueRef base[CHANNELS];
+ if (!this->getVectorLValue(builder, *s.fBase, base)) {
+ return false;
+ }
+ for (size_t i = 0; i < s.fComponents.size(); ++i) {
+ out[i] = base[s.fComponents[i]];
+ }
+ return true;
+ }
+ default:
+ return false;
+ }
+}
+
+bool JIT::getVectorBinaryOperands(LLVMBuilderRef builder, const Expression& left,
+ LLVMValueRef outLeft[CHANNELS], const Expression& right,
+ LLVMValueRef outRight[CHANNELS]) {
+ if (!this->compileVectorExpression(builder, left, outLeft)) {
+ return false;
+ }
+ int leftColumns = left.fType.columns();
+ int rightColumns = right.fType.columns();
+ if (leftColumns == 1 && rightColumns > 1) {
+ for (int i = 1; i < rightColumns; ++i) {
+ outLeft[i] = outLeft[0];
+ }
+ }
+ if (!this->compileVectorExpression(builder, right, outRight)) {
+ return false;
+ }
+ if (rightColumns == 1 && leftColumns > 1) {
+ for (int i = 1; i < leftColumns; ++i) {
+ outRight[i] = outRight[0];
+ }
+ }
+ return true;
+}
+
+bool JIT::compileVectorBinary(LLVMBuilderRef builder, const BinaryExpression& b,
+ LLVMValueRef out[CHANNELS]) {
+ LLVMValueRef left[CHANNELS];
+ LLVMValueRef right[CHANNELS];
+ #define VECTOR_BINARY(signedOp, unsignedOp, floatOp) { \
+ if (!this->getVectorBinaryOperands(builder, *b.fLeft, left, *b.fRight, right)) { \
+ return false; \
+ } \
+ for (int i = 0; i < b.fLeft->fType.columns(); ++i) { \
+ switch (this->typeKind(b.fLeft->fType)) { \
+ case kInt_TypeKind: \
+ out[i] = signedOp(builder, left[i], right[i], "binary"); \
+ break; \
+ case kUInt_TypeKind: \
+ out[i] = unsignedOp(builder, left[i], right[i], "binary"); \
+ break; \
+ case kFloat_TypeKind: \
+ out[i] = floatOp(builder, left[i], right[i], "binary"); \
+ break; \
+ case kBool_TypeKind: \
+ ASSERT(false); \
+ break; \
+ } \
+ } \
+ return true; \
+ }
+ switch (b.fOperator) {
+ case Token::EQ: {
+ if (!this->getVectorLValue(builder, *b.fLeft, left)) {
+ return false;
+ }
+ if (!this->compileVectorExpression(builder, *b.fRight, right)) {
+ return false;
+ }
+ int columns = b.fRight->fType.columns();
+ for (int i = 0; i < columns; ++i) {
+ LLVMBuildStore(builder, right[i], left[i]);
+ }
+ return true;
+ }
+ case Token::PLUS:
+ VECTOR_BINARY(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd);
+ case Token::MINUS:
+ VECTOR_BINARY(LLVMBuildSub, LLVMBuildSub, LLVMBuildFSub);
+ case Token::STAR:
+ VECTOR_BINARY(LLVMBuildMul, LLVMBuildMul, LLVMBuildFMul);
+ case Token::SLASH:
+ VECTOR_BINARY(LLVMBuildSDiv, LLVMBuildUDiv, LLVMBuildFDiv);
+ case Token::PERCENT:
+ VECTOR_BINARY(LLVMBuildSRem, LLVMBuildURem, LLVMBuildSRem);
+ case Token::BITWISEAND:
+ VECTOR_BINARY(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd);
+ case Token::BITWISEOR:
+ VECTOR_BINARY(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr);
+ default:
+ printf("unsupported operator: %s\n", b.description().c_str());
+ return false;
+ }
+}
+
+bool JIT::compileVectorConstructor(LLVMBuilderRef builder, const Constructor& c,
+ LLVMValueRef out[CHANNELS]) {
+ switch (c.fType.kind()) {
+ case Type::kScalar_Kind: {
+ ASSERT(c.fArguments.size() == 1);
+ TypeKind from = this->typeKind(c.fArguments[0]->fType);
+ TypeKind to = this->typeKind(c.fType);
+ LLVMValueRef base[CHANNELS];
+ if (!this->compileVectorExpression(builder, *c.fArguments[0], base)) {
+ return false;
+ }
+ #define CONSTRUCT(fn) \
+ out[0] = LLVMGetUndef(LLVMVectorType(this->getType(c.fType), fVectorCount)); \
+ for (int i = 0; i < fVectorCount; ++i) { \
+ LLVMValueRef index = LLVMConstInt(fInt32Type, i, false); \
+ LLVMValueRef baseVal = LLVMBuildExtractElement(builder, base[0], index, \
+ "construct extract"); \
+ out[0] = LLVMBuildInsertElement(builder, out[0], \
+ fn(builder, baseVal, this->getType(c.fType), \
+ "cast"), \
+ index, "construct insert"); \
+ } \
+ return true;
+ if (kFloat_TypeKind == to) {
+ if (kInt_TypeKind == from) {
+ CONSTRUCT(LLVMBuildSIToFP);
+ }
+ if (kUInt_TypeKind == from) {
+ CONSTRUCT(LLVMBuildUIToFP);
+ }
+ }
+ if (kInt_TypeKind == to) {
+ if (kFloat_TypeKind == from) {
+ CONSTRUCT(LLVMBuildFPToSI);
+ }
+ if (kUInt_TypeKind == from) {
+ return true;
+ }
+ }
+ if (kUInt_TypeKind == to) {
+ if (kFloat_TypeKind == from) {
+ CONSTRUCT(LLVMBuildFPToUI);
+ }
+ if (kInt_TypeKind == from) {
+ return base;
+ }
+ }
+ printf("%s\n", c.description().c_str());
+ ABORT("unsupported constructor");
+ }
+ case Type::kVector_Kind: {
+ if (c.fArguments.size() == 1) {
+ LLVMValueRef base[CHANNELS];
+ if (!this->compileVectorExpression(builder, *c.fArguments[0], base)) {
+ return false;
+ }
+ for (int i = 0; i < c.fType.columns(); ++i) {
+ out[i] = base[0];
+ }
+ } else {
+ ASSERT(c.fArguments.size() == (size_t) c.fType.columns());
+ for (int i = 0; i < c.fType.columns(); ++i) {
+ LLVMValueRef base[CHANNELS];
+ if (!this->compileVectorExpression(builder, *c.fArguments[i], base)) {
+ return false;
+ }
+ out[i] = base[0];
+ }
+ }
+ return true;
+ }
+ default:
+ break;
+ }
+ ABORT("unsupported constructor");
+}
+
+bool JIT::compileVectorFloatLiteral(LLVMBuilderRef builder,
+ const FloatLiteral& f,
+ LLVMValueRef out[CHANNELS]) {
+ LLVMValueRef value = LLVMConstReal(this->getType(f.fType), f.fValue);
+ LLVMValueRef values[MAX_VECTOR_COUNT];
+ for (int i = 0; i < fVectorCount; ++i) {
+ values[i] = value;
+ }
+ out[0] = LLVMConstVector(values, fVectorCount);
+ return true;
+}
+
+
+bool JIT::compileVectorSwizzle(LLVMBuilderRef builder, const Swizzle& s,
+ LLVMValueRef out[CHANNELS]) {
+ LLVMValueRef all[CHANNELS];
+ if (!this->compileVectorExpression(builder, *s.fBase, all)) {
+ return false;
+ }
+ for (size_t i = 0; i < s.fComponents.size(); ++i) {
+ out[i] = all[s.fComponents[i]];
+ }
+ return true;
+}
+
+bool JIT::compileVectorVariableReference(LLVMBuilderRef builder, const VariableReference& v,
+ LLVMValueRef out[CHANNELS]) {
+ if (&v.fVariable == fColorParam) {
+ for (int i = 0; i < CHANNELS; ++i) {
+ out[i] = LLVMBuildLoad(builder, fChannels[i], "variable reference");
+ }
+ return true;
+ }
+ return false;
+}
+
+bool JIT::compileVectorExpression(LLVMBuilderRef builder, const Expression& expr,
+ LLVMValueRef out[CHANNELS]) {
+ switch (expr.fKind) {
+ case Expression::kBinary_Kind:
+ return this->compileVectorBinary(builder, (const BinaryExpression&) expr, out);
+ case Expression::kConstructor_Kind:
+ return this->compileVectorConstructor(builder, (const Constructor&) expr, out);
+ case Expression::kFloatLiteral_Kind:
+ return this->compileVectorFloatLiteral(builder, (const FloatLiteral&) expr, out);
+ case Expression::kSwizzle_Kind:
+ return this->compileVectorSwizzle(builder, (const Swizzle&) expr, out);
+ case Expression::kVariableReference_Kind:
+ return this->compileVectorVariableReference(builder, (const VariableReference&) expr,
+ out);
+ default:
+ printf("failed expression: %s\n", expr.description().c_str());
+ return false;
+ }
+}
+
+bool JIT::compileVectorStatement(LLVMBuilderRef builder, const Statement& stmt) {
+ switch (stmt.fKind) {
+ case Statement::kBlock_Kind:
+ for (const auto& s : ((const Block&) stmt).fStatements) {
+ if (!this->compileVectorStatement(builder, *s)) {
+ return false;
+ }
+ }
+ return true;
+ case Statement::kExpression_Kind:
+ LLVMValueRef result;
+ return this->compileVectorExpression(builder,
+ *((const ExpressionStatement&) stmt).fExpression,
+ &result);
+ default:
+ printf("failed statement: %s\n", stmt.description().c_str());
+ return false;
+ }
+}
+
+bool JIT::compileStageFunctionVector(const FunctionDefinition& f, LLVMValueRef newFunc) {
+ LLVMValueRef oldFunction = fCurrentFunction;
+ fCurrentFunction = newFunc;
+ std::unique_ptr<LLVMValueRef[]> params(new LLVMValueRef[STAGE_PARAM_COUNT]);
+ LLVMGetParams(fCurrentFunction, params.get());
+ LLVMValueRef programParam = params.get()[1];
+ LLVMBuilderRef builder = LLVMCreateBuilderInContext(fContext);
+ LLVMBasicBlockRef oldAllocaBlock = fAllocaBlock;
+ LLVMBasicBlockRef oldCurrentBlock = fCurrentBlock;
+ fAllocaBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "alloca");
+ this->setBlock(builder, fAllocaBlock);
+ fChannels[0] = LLVMBuildAlloca(builder, fFloat32VectorType, "rVec");
+ LLVMBuildStore(builder, params.get()[4], fChannels[0]);
+ fChannels[1] = LLVMBuildAlloca(builder, fFloat32VectorType, "gVec");
+ LLVMBuildStore(builder, params.get()[5], fChannels[1]);
+ fChannels[2] = LLVMBuildAlloca(builder, fFloat32VectorType, "bVec");
+ LLVMBuildStore(builder, params.get()[6], fChannels[2]);
+ fChannels[3] = LLVMBuildAlloca(builder, fFloat32VectorType, "aVec");
+ LLVMBuildStore(builder, params.get()[7], fChannels[3]);
+ LLVMBasicBlockRef start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "start");
+ this->setBlock(builder, start);
+ bool success = this->compileVectorStatement(builder, *f.fBody);
+ if (success) {
+ // increment program pointer, call next
+ LLVMValueRef rawNextPtr = LLVMBuildLoad(builder, programParam, "next load");
+ LLVMTypeRef stageFuncType = LLVMTypeOf(newFunc);
+ LLVMValueRef nextPtr = LLVMBuildBitCast(builder, rawNextPtr, stageFuncType,
+ "cast next->func");
+ LLVMValueRef nextInc = LLVMBuildIntToPtr(builder,
+ LLVMBuildAdd(builder,
+ LLVMBuildPtrToInt(builder,
+ programParam,
+ fInt64Type,
+ "cast 1"),
+ LLVMConstInt(fInt64Type, PTR_SIZE,
+ false),
+ "add"),
+ LLVMPointerType(fInt8PtrType, 0), "cast 2");
+ LLVMValueRef args[STAGE_PARAM_COUNT] = {
+ params.get()[0],
+ nextInc,
+ params.get()[2],
+ params.get()[3],
+ LLVMBuildLoad(builder, fChannels[0], "rVec"),
+ LLVMBuildLoad(builder, fChannels[1], "gVec"),
+ LLVMBuildLoad(builder, fChannels[2], "bVec"),
+ LLVMBuildLoad(builder, fChannels[3], "aVec"),
+ params.get()[8],
+ params.get()[9],
+ params.get()[10],
+ params.get()[11]
+ };
+ LLVMBuildCall(builder, nextPtr, args, STAGE_PARAM_COUNT, "");
+ LLVMBuildRetVoid(builder);
+ // finish
+ LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
+ LLVMBuildBr(builder, start);
+ LLVMDisposeBuilder(builder);
+ if (LLVMVerifyFunction(fCurrentFunction, LLVMPrintMessageAction)) {
+ ABORT("verify failed\n");
+ }
+ } else {
+ LLVMDeleteBasicBlock(fAllocaBlock);
+ LLVMDeleteBasicBlock(start);
+ }
+
+ fAllocaBlock = oldAllocaBlock;
+ fCurrentBlock = oldCurrentBlock;
+ fCurrentFunction = oldFunction;
+ return success;
+}
+
+LLVMValueRef JIT::compileStageFunction(const FunctionDefinition& f) {
+ LLVMTypeRef returnType = fVoidType;
+ LLVMTypeRef parameterTypes[12] = { fSizeTType, LLVMPointerType(fInt8PtrType, 0), fSizeTType,
+ fSizeTType, fFloat32VectorType, fFloat32VectorType,
+ fFloat32VectorType, fFloat32VectorType, fFloat32VectorType,
+ fFloat32VectorType, fFloat32VectorType, fFloat32VectorType };
+ LLVMTypeRef stageFuncType = LLVMFunctionType(returnType, parameterTypes, 12, false);
+ LLVMValueRef result = LLVMAddFunction(fModule,
+ (String(f.fDeclaration.fName) + "$stage").c_str(),
+ stageFuncType);
+ fColorParam = f.fDeclaration.fParameters[2];
+ if (!this->compileStageFunctionVector(f, result)) {
+ // vectorization failed, fall back to looping over the pixels
+ this->compileStageFunctionLoop(f, result);
+ }
+ return result;
+}
+
+bool JIT::hasStageSignature(const FunctionDeclaration& f) {
+ return f.fReturnType == *fProgram->fContext->fVoid_Type &&
+ f.fParameters.size() == 3 &&
+ f.fParameters[0]->fType == *fProgram->fContext->fInt_Type &&
+ f.fParameters[0]->fModifiers.fFlags == 0 &&
+ f.fParameters[1]->fType == *fProgram->fContext->fInt_Type &&
+ f.fParameters[1]->fModifiers.fFlags == 0 &&
+ f.fParameters[2]->fType == *fProgram->fContext->fFloat4_Type &&
+ f.fParameters[2]->fModifiers.fFlags == (Modifiers::kIn_Flag | Modifiers::kOut_Flag);
+}
+
+LLVMValueRef JIT::compileFunction(const FunctionDefinition& f) {
+ if (this->hasStageSignature(f.fDeclaration)) {
+ this->compileStageFunction(f);
+ // we compile foo$stage *in addition* to compiling foo, as we can't be sure that the intent
+ // was to produce an SkJumper stage just because the signature matched or that the function
+ // is not otherwise called. May need a better way to handle this.
+ }
+ LLVMTypeRef returnType = this->getType(f.fDeclaration.fReturnType);
+ std::vector<LLVMTypeRef> parameterTypes;
+ for (const auto& p : f.fDeclaration.fParameters) {
+ LLVMTypeRef type = this->getType(p->fType);
+ if (p->fModifiers.fFlags & Modifiers::kOut_Flag) {
+ type = LLVMPointerType(type, 0);
+ }
+ parameterTypes.push_back(type);
+ }
+ fCurrentFunction = LLVMAddFunction(fModule,
+ String(f.fDeclaration.fName).c_str(),
+ LLVMFunctionType(returnType, parameterTypes.data(),
+ parameterTypes.size(), false));
+ fFunctions[&f.fDeclaration] = fCurrentFunction;
+
+ std::unique_ptr<LLVMValueRef[]> params(new LLVMValueRef[parameterTypes.size()]);
+ LLVMGetParams(fCurrentFunction, params.get());
+ for (size_t i = 0; i < f.fDeclaration.fParameters.size(); ++i) {
+ fVariables[f.fDeclaration.fParameters[i]] = params.get()[i];
+ }
+ LLVMBuilderRef builder = LLVMCreateBuilderInContext(fContext);
+ fAllocaBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "alloca");
+ LLVMBasicBlockRef start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "start");
+ fCurrentBlock = start;
+ LLVMPositionBuilderAtEnd(builder, fCurrentBlock);
+ this->compileStatement(builder, *f.fBody);
+ if (!ends_with_branch(*f.fBody)) {
+ if (f.fDeclaration.fReturnType == *fProgram->fContext->fVoid_Type) {
+ LLVMBuildRetVoid(builder);
+ } else {
+ LLVMBuildUnreachable(builder);
+ }
+ }
+ LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
+ LLVMBuildBr(builder, start);
+ LLVMDisposeBuilder(builder);
+ if (LLVMVerifyFunction(fCurrentFunction, LLVMPrintMessageAction)) {
+ ABORT("verify failed\n");
+ }
+ return fCurrentFunction;
+}
+
+void JIT::createModule() {
+ fPromotedParameters.clear();
+ fModule = LLVMModuleCreateWithNameInContext("skslmodule", fContext);
+ this->loadBuiltinFunctions();
+ // LLVM doesn't do void*, have to declare it as int8*
+ LLVMTypeRef appendParams[3] = { fInt8PtrType, fInt32Type, fInt8PtrType };
+ fAppendFunc = LLVMAddFunction(fModule, "sksl_pipeline_append", LLVMFunctionType(fVoidType,
+ appendParams,
+ 3,
+ false));
+ LLVMTypeRef appendCallbackParams[2] = { fInt8PtrType, fInt8PtrType };
+ fAppendCallbackFunc = LLVMAddFunction(fModule, "sksl_pipeline_append_callback",
+ LLVMFunctionType(fVoidType, appendCallbackParams, 2,
+ false));
+
+ LLVMTypeRef debugParams[3] = { fFloat32Type };
+ fDebugFunc = LLVMAddFunction(fModule, "sksl_debug_print", LLVMFunctionType(fVoidType,
+ debugParams,
+ 1,
+ false));
+
+ for (const auto& e : fProgram->fElements) {
+ ASSERT(e->fKind == ProgramElement::kFunction_Kind);
+ this->compileFunction((FunctionDefinition&) *e);
+ }
+}
+
+std::unique_ptr<JIT::Module> JIT::compile(std::unique_ptr<Program> program) {
+ fProgram = std::move(program);
+ this->createModule();
+ this->optimize();
+ return std::unique_ptr<Module>(new Module(std::move(fProgram), fSharedModule, fJITStack));
+}
+
+void JIT::optimize() {
+ LLVMPassManagerBuilderRef pmb = LLVMPassManagerBuilderCreate();
+ LLVMPassManagerBuilderSetOptLevel(pmb, 3);
+ LLVMPassManagerRef functionPM = LLVMCreateFunctionPassManagerForModule(fModule);
+ LLVMPassManagerBuilderPopulateFunctionPassManager(pmb, functionPM);
+ LLVMPassManagerRef modulePM = LLVMCreatePassManager();
+ LLVMPassManagerBuilderPopulateModulePassManager(pmb, modulePM);
+ LLVMInitializeFunctionPassManager(functionPM);
+
+ LLVMValueRef func = LLVMGetFirstFunction(fModule);
+ for (;;) {
+ if (!func) {
+ break;
+ }
+ LLVMRunFunctionPassManager(functionPM, func);
+ func = LLVMGetNextFunction(func);
+ }
+ LLVMRunPassManager(modulePM, fModule);
+ LLVMDisposePassManager(functionPM);
+ LLVMDisposePassManager(modulePM);
+ LLVMPassManagerBuilderDispose(pmb);
+
+ std::string error_string;
+ if (LLVMLoadLibraryPermanently(nullptr)) {
+ ABORT("LLVMLoadLibraryPermanently failed");
+ }
+ char* defaultTriple = LLVMGetDefaultTargetTriple();
+ char* error;
+ LLVMTargetRef target;
+ if (LLVMGetTargetFromTriple(defaultTriple, &target, &error)) {
+ ABORT("LLVMGetTargetFromTriple failed");
+ }
+
+ if (!LLVMTargetHasJIT(target)) {
+ ABORT("!LLVMTargetHasJIT");
+ }
+ LLVMTargetMachineRef targetMachine = LLVMCreateTargetMachine(target,
+ defaultTriple,
+ fCPU,
+ nullptr,
+ LLVMCodeGenLevelDefault,
+ LLVMRelocDefault,
+ LLVMCodeModelJITDefault);
+ LLVMDisposeMessage(defaultTriple);
+ LLVMTargetDataRef dataLayout = LLVMCreateTargetDataLayout(targetMachine);
+ LLVMSetModuleDataLayout(fModule, dataLayout);
+ LLVMDisposeTargetData(dataLayout);
+
+ fJITStack = LLVMOrcCreateInstance(targetMachine);
+ fSharedModule = LLVMOrcMakeSharedModule(fModule);
+ LLVMOrcModuleHandle orcModule;
+ LLVMOrcAddEagerlyCompiledIR(fJITStack, &orcModule, fSharedModule,
+ (LLVMOrcSymbolResolverFn) resolveSymbol, this);
+ LLVMDisposeTargetMachine(targetMachine);
+}
+
+void* JIT::Module::getSymbol(const char* name) {
+ LLVMOrcTargetAddress result;
+ if (LLVMOrcGetSymbolAddress(fJITStack, &result, name)) {
+ ABORT("GetSymbolAddress error");
+ }
+ if (!result) {
+ ABORT("symbol not found");
+ }
+ return (void*) result;
+}
+
+void* JIT::Module::getJumperStage(const char* name) {
+ return this->getSymbol((String(name) + "$stage").c_str());
+}
+
+} // namespace
+
+#endif // SK_LLVM_AVAILABLE
+
+#endif // SKSL_STANDALONE