diff options
Diffstat (limited to 'src/sksl/SkSLJIT.cpp')
-rw-r--r-- | src/sksl/SkSLJIT.cpp | 291 |
1 files changed, 238 insertions, 53 deletions
diff --git a/src/sksl/SkSLJIT.cpp b/src/sksl/SkSLJIT.cpp index 115a6be3ae..4120c4aa8c 100644 --- a/src/sksl/SkSLJIT.cpp +++ b/src/sksl/SkSLJIT.cpp @@ -14,11 +14,13 @@ #include "SkCpu.h" #include "SkRasterPipeline.h" #include "../jumper/SkJumper.h" +#include "ir/SkSLAppendStage.h" #include "ir/SkSLExpressionStatement.h" #include "ir/SkSLFunctionCall.h" #include "ir/SkSLFunctionReference.h" #include "ir/SkSLIndexExpression.h" #include "ir/SkSLProgram.h" +#include "ir/SkSLUnresolvedFunction.h" #include "llvm/ExecutionEngine/RTDyldMemoryManager.h" static constexpr int MAX_VECTOR_COUNT = 16; @@ -37,6 +39,27 @@ extern "C" void sksl_debug_print(float f) { printf("Debug: %f\n", f); } +extern "C" float sksl_clamp1(float f, float min, float max) { + return SkTPin(f, min, max); +} + +using float2 = __attribute__((vector_size(8))) float; +using float3 = __attribute__((vector_size(16))) float; +using float4 = __attribute__((vector_size(16))) float; + +extern "C" float2 sksl_clamp2(float2 f, float min, float max) { + return float2 { SkTPin(f[0], min, max), SkTPin(f[1], min, max) }; +} + +extern "C" float3 sksl_clamp3(float3 f, float min, float max) { + return float3 { SkTPin(f[0], min, max), SkTPin(f[1], min, max), SkTPin(f[2], min, max) }; +} + +extern "C" float4 sksl_clamp4(float4 f, float min, float max) { + return float4 { SkTPin(f[0], min, max), SkTPin(f[1], min, max), SkTPin(f[2], min, max), + SkTPin(f[3], min, max) }; +} + namespace SkSL { static constexpr int STAGE_PARAM_COUNT = 12; @@ -78,6 +101,10 @@ JIT::JIT(Compiler* compiler) fContext = LLVMContextCreate(); fVoidType = LLVMVoidTypeInContext(fContext); fInt1Type = LLVMInt1TypeInContext(fContext); + fInt1VectorType = LLVMVectorType(fInt1Type, fVectorCount); + fInt1Vector2Type = LLVMVectorType(fInt1Type, 2); + fInt1Vector3Type = LLVMVectorType(fInt1Type, 3); + fInt1Vector4Type = LLVMVectorType(fInt1Type, 4); fInt8Type = LLVMInt8TypeInContext(fContext); fInt8PtrType = LLVMPointerType(fInt8Type, 0); fInt32Type = LLVMInt32TypeInContext(fContext); @@ -101,6 +128,7 @@ JIT::~JIT() { void JIT::addBuiltinFunction(const char* ourName, const char* realName, LLVMTypeRef returnType, std::vector<LLVMTypeRef> parameters) { + bool found = false; for (const auto& pair : *fProgram->fSymbols) { if (Symbol::kFunctionDeclaration_Kind == pair.second->fKind) { const FunctionDeclaration& f = (const FunctionDeclaration&) *pair.second; @@ -117,9 +145,31 @@ void JIT::addBuiltinFunction(const char* ourName, const char* realName, LLVMType parameters.data(), parameters.size(), false)); + found = true; + } + if (Symbol::kUnresolvedFunction_Kind == pair.second->fKind) { + // FIXME consolidate this with the code above + for (const auto& f : ((const UnresolvedFunction&) *pair.second).fFunctions) { + if (pair.first != ourName || returnType != this->getType(f->fReturnType) || + parameters.size() != f->fParameters.size()) { + continue; + } + for (size_t i = 0; i < parameters.size(); ++i) { + if (parameters[i] != this->getType(f->fParameters[i]->fType)) { + goto next; + } + } + fFunctions[f] = LLVMAddFunction(fModule, realName, LLVMFunctionType( + returnType, + parameters.data(), + parameters.size(), + false)); + found = true; + } } next:; } + SkASSERT(found); } void JIT::loadBuiltinFunctions() { @@ -128,6 +178,18 @@ void JIT::loadBuiltinFunctions() { this->addBuiltinFunction("cos", "cosf", fFloat32Type, { fFloat32Type }); this->addBuiltinFunction("tan", "tanf", fFloat32Type, { fFloat32Type }); this->addBuiltinFunction("sqrt", "sqrtf", fFloat32Type, { fFloat32Type }); + this->addBuiltinFunction("clamp", "sksl_clamp1", fFloat32Type, { fFloat32Type, + fFloat32Type, + fFloat32Type }); + this->addBuiltinFunction("clamp", "sksl_clamp2", fFloat32Vector2Type, { fFloat32Vector2Type, + fFloat32Type, + fFloat32Type }); + this->addBuiltinFunction("clamp", "sksl_clamp3", fFloat32Vector3Type, { fFloat32Vector3Type, + fFloat32Type, + fFloat32Type }); + this->addBuiltinFunction("clamp", "sksl_clamp4", fFloat32Vector4Type, { fFloat32Vector4Type, + fFloat32Type, + fFloat32Type }); this->addBuiltinFunction("print", "sksl_debug_print", fVoidType, { fFloat32Type }); } @@ -138,6 +200,14 @@ uint64_t JIT::resolveSymbol(const char* name, JIT* jit) { result = (uint64_t) &sksl_pipeline_append; } else if (!strcmp(name, "_sksl_pipeline_append_callback")) { result = (uint64_t) &sksl_pipeline_append_callback; + } else if (!strcmp(name, "_sksl_clamp1")) { + result = (uint64_t) &sksl_clamp1; + } else if (!strcmp(name, "_sksl_clamp2")) { + result = (uint64_t) &sksl_clamp2; + } else if (!strcmp(name, "_sksl_clamp3")) { + result = (uint64_t) &sksl_clamp3; + } else if (!strcmp(name, "_sksl_clamp4")) { + result = (uint64_t) &sksl_clamp4; } else if (!strcmp(name, "_sksl_debug_print")) { result = (uint64_t) &sksl_debug_print; } else { @@ -406,7 +476,7 @@ JIT::TypeKind JIT::typeKind(const Type& type) { return JIT::kInt_TypeKind; } else if (type.fName == "uint" || type.fName == "ushort" || type.fName == "ubyte") { return JIT::kUInt_TypeKind; - } else if (type.fName == "float" || type.fName == "double") { + } else if (type.fName == "float" || type.fName == "double" || type.fName == "half") { return JIT::kFloat_TypeKind; } ABORT("unsupported type: %s\n", type.description().c_str()); @@ -441,7 +511,7 @@ LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression& LLVMValueRef left = this->compileExpression(builder, *b.fLeft); \ LLVMValueRef right = this->compileExpression(builder, *b.fRight); \ this->vectorize(builder, b, &left, &right); \ - switch (this->typeKind(b.fLeft->fType)) { \ + switch (this->typeKind(b.fLeft->fType)) { \ case kInt_TypeKind: \ return SFunc(builder, left, right, "binary"); \ case kUInt_TypeKind: \ @@ -449,7 +519,7 @@ LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression& case kFloat_TypeKind: \ return FFunc(builder, left, right, "binary"); \ default: \ - ABORT("unsupported typeKind"); \ + ABORT("unsupported typeKind"); \ } \ } #define COMPOUND(SFunc, UFunc, FFunc) { \ @@ -458,7 +528,7 @@ LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression& LLVMValueRef right = this->compileExpression(builder, *b.fRight); \ this->vectorize(builder, b, &left, &right); \ LLVMValueRef result; \ - switch (this->typeKind(b.fLeft->fType)) { \ + switch (this->typeKind(b.fLeft->fType)) { \ case kInt_TypeKind: \ result = SFunc(builder, left, right, "binary"); \ break; \ @@ -469,7 +539,7 @@ LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression& result = FFunc(builder, left, right, "binary"); \ break; \ default: \ - ABORT("unsupported typeKind"); \ + ABORT("unsupported typeKind"); \ } \ lvalue->store(builder, result); \ return result; \ @@ -510,6 +580,10 @@ LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression& BINARY(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd); case Token::BITWISEOR: BINARY(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr); + case Token::SHL: + BINARY(LLVMBuildShl, LLVMBuildShl, LLVMBuildShl); + case Token::SHR: + BINARY(LLVMBuildAShr, LLVMBuildLShr, LLVMBuildAShr); case Token::PLUSEQ: COMPOUND(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd); case Token::MINUSEQ: @@ -523,13 +597,83 @@ LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression& case Token::BITWISEOREQ: COMPOUND(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr); case Token::EQEQ: - COMPARE(LLVMBuildICmp, LLVMIntEQ, - LLVMBuildICmp, LLVMIntEQ, - LLVMBuildFCmp, LLVMRealOEQ); + switch (b.fLeft->fType.kind()) { + case Type::kScalar_Kind: + COMPARE(LLVMBuildICmp, LLVMIntEQ, + LLVMBuildICmp, LLVMIntEQ, + LLVMBuildFCmp, LLVMRealOEQ); + case Type::kVector_Kind: { + LLVMValueRef left = this->compileExpression(builder, *b.fLeft); + LLVMValueRef right = this->compileExpression(builder, *b.fRight); + this->vectorize(builder, b, &left, &right); + LLVMValueRef value; + switch (this->typeKind(b.fLeft->fType)) { + case kInt_TypeKind: + value = LLVMBuildICmp(builder, LLVMIntEQ, left, right, "binary"); + break; + case kUInt_TypeKind: + value = LLVMBuildICmp(builder, LLVMIntEQ, left, right, "binary"); + break; + case kFloat_TypeKind: + value = LLVMBuildFCmp(builder, LLVMRealOEQ, left, right, "binary"); + break; + default: + ABORT("unsupported typeKind"); + } + LLVMValueRef args[1] = { value }; + LLVMValueRef func; + switch (b.fLeft->fType.columns()) { + case 2: func = fFoldAnd2Func; break; + case 3: func = fFoldAnd3Func; break; + case 4: func = fFoldAnd4Func; break; + default: + SkASSERT(false); + func = fFoldAnd2Func; + } + return LLVMBuildCall(builder, func, args, 1, "all"); + } + default: + SkASSERT(false); + } case Token::NEQ: - COMPARE(LLVMBuildICmp, LLVMIntNE, - LLVMBuildICmp, LLVMIntNE, - LLVMBuildFCmp, LLVMRealONE); + switch (b.fLeft->fType.kind()) { + case Type::kScalar_Kind: + COMPARE(LLVMBuildICmp, LLVMIntNE, + LLVMBuildICmp, LLVMIntNE, + LLVMBuildFCmp, LLVMRealONE); + case Type::kVector_Kind: { + LLVMValueRef left = this->compileExpression(builder, *b.fLeft); + LLVMValueRef right = this->compileExpression(builder, *b.fRight); + this->vectorize(builder, b, &left, &right); + LLVMValueRef value; + switch (this->typeKind(b.fLeft->fType)) { + case kInt_TypeKind: + value = LLVMBuildICmp(builder, LLVMIntNE, left, right, "binary"); + break; + case kUInt_TypeKind: + value = LLVMBuildICmp(builder, LLVMIntNE, left, right, "binary"); + break; + case kFloat_TypeKind: + value = LLVMBuildFCmp(builder, LLVMRealONE, left, right, "binary"); + break; + default: + ABORT("unsupported typeKind"); + } + LLVMValueRef args[1] = { value }; + LLVMValueRef func; + switch (b.fLeft->fType.columns()) { + case 2: func = fFoldOr2Func; break; + case 3: func = fFoldOr3Func; break; + case 4: func = fFoldOr4Func; break; + default: + SkASSERT(false); + func = fFoldOr2Func; + } + return LLVMBuildCall(builder, func, args, 1, "all"); + } + default: + SkASSERT(false); + } case Token::LT: COMPARE(LLVMBuildICmp, LLVMIntSLT, LLVMBuildICmp, LLVMIntULT, @@ -583,6 +727,7 @@ LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression& return phi; } default: + printf("%s\n", b.description().c_str()); ABORT("unsupported binary operator"); } } @@ -702,9 +847,9 @@ void JIT::appendStage(LLVMBuilderRef builder, const AppendStage& a) { const FunctionDeclaration& functionDecl = *((FunctionReference&) *a.fArguments[1]).fFunctions[0]; bool found = false; - for (const auto& pe : fProgram->fElements) { - if (ProgramElement::kFunction_Kind == pe->fKind) { - const FunctionDefinition& def = (const FunctionDefinition&) *pe; + for (const auto& pe : *fProgram) { + if (ProgramElement::kFunction_Kind == pe.fKind) { + const FunctionDefinition& def = (const FunctionDefinition&) pe; if (&def.fDeclaration == &functionDecl) { LLVMValueRef fn = this->compileStageFunction(def); LLVMValueRef args[2] = { @@ -747,49 +892,74 @@ LLVMValueRef JIT::compileConstructor(LLVMBuilderRef builder, const Constructor& TypeKind from = this->typeKind(c.fArguments[0]->fType); TypeKind to = this->typeKind(c.fType); LLVMValueRef base = this->compileExpression(builder, *c.fArguments[0]); - if (kFloat_TypeKind == to) { - if (kInt_TypeKind == from) { - return LLVMBuildSIToFP(builder, base, this->getType(c.fType), "cast"); - } - if (kUInt_TypeKind == from) { - return LLVMBuildUIToFP(builder, base, this->getType(c.fType), "cast"); - } - } - if (kInt_TypeKind == to) { - if (kFloat_TypeKind == from) { - return LLVMBuildFPToSI(builder, base, this->getType(c.fType), "cast"); - } - if (kUInt_TypeKind == from) { - return base; - } - } - if (kUInt_TypeKind == to) { - if (kFloat_TypeKind == from) { - return LLVMBuildFPToUI(builder, base, this->getType(c.fType), "cast"); - } - if (kInt_TypeKind == from) { - return base; - } + switch (to) { + case kFloat_TypeKind: + switch (from) { + case kInt_TypeKind: + return LLVMBuildSIToFP(builder, base, this->getType(c.fType), "cast"); + case kUInt_TypeKind: + return LLVMBuildUIToFP(builder, base, this->getType(c.fType), "cast"); + case kFloat_TypeKind: + return base; + case kBool_TypeKind: + SkASSERT(false); + } + case kInt_TypeKind: + switch (from) { + case kInt_TypeKind: + return base; + case kUInt_TypeKind: + return base; + case kFloat_TypeKind: + return LLVMBuildFPToSI(builder, base, this->getType(c.fType), "cast"); + case kBool_TypeKind: + SkASSERT(false); + } + case kUInt_TypeKind: + switch (from) { + case kInt_TypeKind: + return base; + case kUInt_TypeKind: + return base; + case kFloat_TypeKind: + return LLVMBuildFPToUI(builder, base, this->getType(c.fType), "cast"); + case kBool_TypeKind: + SkASSERT(false); + } + case kBool_TypeKind: + SkASSERT(false); } - ABORT("unsupported constructor"); } case Type::kVector_Kind: { LLVMValueRef vec = LLVMGetUndef(this->getType(c.fType)); - if (c.fArguments.size() == 1) { + if (c.fArguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) { LLVMValueRef value = this->compileExpression(builder, *c.fArguments[0]); for (int i = 0; i < c.fType.columns(); ++i) { vec = LLVMBuildInsertElement(builder, vec, value, LLVMConstInt(fInt32Type, i, false), - "vec build"); + "vec build 1"); } } else { - SkASSERT(c.fArguments.size() == (size_t) c.fType.columns()); - for (int i = 0; i < c.fType.columns(); ++i) { - vec = LLVMBuildInsertElement(builder, vec, - this->compileExpression(builder, - *c.fArguments[i]), - LLVMConstInt(fInt32Type, i, false), - "vec build"); + int index = 0; + for (const auto& arg : c.fArguments) { + LLVMValueRef value = this->compileExpression(builder, *arg); + if (arg->fType.kind() == Type::kVector_Kind) { + for (int i = 0; i < arg->fType.columns(); ++i) { + LLVMValueRef column = LLVMBuildExtractElement(builder, + vec, + LLVMConstInt(fInt32Type, + i, + false), + "construct extract"); + vec = LLVMBuildInsertElement(builder, vec, column, + LLVMConstInt(fInt32Type, index++, false), + "vec build 2"); + } + } else { + vec = LLVMBuildInsertElement(builder, vec, value, + LLVMConstInt(fInt32Type, index++, false), + "vec build 3"); + } } } return vec; @@ -1460,7 +1630,6 @@ bool JIT::compileVectorExpression(LLVMBuilderRef builder, const Expression& expr return this->compileVectorVariableReference(builder, (const VariableReference&) expr, out); default: - printf("failed expression: %s\n", expr.description().c_str()); return false; } } @@ -1480,7 +1649,6 @@ bool JIT::compileVectorStatement(LLVMBuilderRef builder, const Statement& stmt) *((const ExpressionStatement&) stmt).fExpression, &result); default: - printf("failed statement: %s\n", stmt.description().c_str()); return false; } } @@ -1582,7 +1750,7 @@ bool JIT::hasStageSignature(const FunctionDeclaration& f) { f.fParameters[0]->fModifiers.fFlags == 0 && f.fParameters[1]->fType == *fProgram->fContext->fInt_Type && f.fParameters[1]->fModifiers.fFlags == 0 && - f.fParameters[2]->fType == *fProgram->fContext->fFloat4_Type && + f.fParameters[2]->fType == *fProgram->fContext->fHalf4_Type && f.fParameters[2]->fModifiers.fFlags == (Modifiers::kIn_Flag | Modifiers::kOut_Flag); } @@ -1639,6 +1807,21 @@ void JIT::createModule() { fPromotedParameters.clear(); fModule = LLVMModuleCreateWithNameInContext("skslmodule", fContext); this->loadBuiltinFunctions(); + LLVMTypeRef fold2Params[1] = { fInt1Vector2Type }; + fFoldAnd2Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.and.i1.v2i1", + LLVMFunctionType(fInt1Type, fold2Params, 1, false)); + fFoldOr2Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.or.i1.v2i1", + LLVMFunctionType(fInt1Type, fold2Params, 1, false)); + LLVMTypeRef fold3Params[1] = { fInt1Vector3Type }; + fFoldAnd3Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.and.i1.v3i1", + LLVMFunctionType(fInt1Type, fold3Params, 1, false)); + fFoldOr3Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.or.i1.v3i1", + LLVMFunctionType(fInt1Type, fold3Params, 1, false)); + LLVMTypeRef fold4Params[1] = { fInt1Vector4Type }; + fFoldAnd4Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.and.i1.v4i1", + LLVMFunctionType(fInt1Type, fold4Params, 1, false)); + fFoldOr4Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.or.i1.v4i1", + LLVMFunctionType(fInt1Type, fold4Params, 1, false)); // LLVM doesn't do void*, have to declare it as int8* LLVMTypeRef appendParams[3] = { fInt8PtrType, fInt32Type, fInt8PtrType }; fAppendFunc = LLVMAddFunction(fModule, "sksl_pipeline_append", LLVMFunctionType(fVoidType, @@ -1656,13 +1839,15 @@ void JIT::createModule() { 1, false)); - for (const auto& e : fProgram->fElements) { - SkASSERT(e->fKind == ProgramElement::kFunction_Kind); - this->compileFunction((FunctionDefinition&) *e); + for (const auto& e : *fProgram) { + if (e.fKind == ProgramElement::kFunction_Kind) { + this->compileFunction((FunctionDefinition&) e); + } } } std::unique_ptr<JIT::Module> JIT::compile(std::unique_ptr<Program> program) { + fCompiler.optimize(*program); fProgram = std::move(program); this->createModule(); this->optimize(); |