diff options
Diffstat (limited to 'src/sksl/SkSLJIT.cpp')
-rw-r--r-- | src/sksl/SkSLJIT.cpp | 291 |
1 files changed, 53 insertions, 238 deletions
diff --git a/src/sksl/SkSLJIT.cpp b/src/sksl/SkSLJIT.cpp index 4120c4aa8c..115a6be3ae 100644 --- a/src/sksl/SkSLJIT.cpp +++ b/src/sksl/SkSLJIT.cpp @@ -14,13 +14,11 @@ #include "SkCpu.h" #include "SkRasterPipeline.h" #include "../jumper/SkJumper.h" -#include "ir/SkSLAppendStage.h" #include "ir/SkSLExpressionStatement.h" #include "ir/SkSLFunctionCall.h" #include "ir/SkSLFunctionReference.h" #include "ir/SkSLIndexExpression.h" #include "ir/SkSLProgram.h" -#include "ir/SkSLUnresolvedFunction.h" #include "llvm/ExecutionEngine/RTDyldMemoryManager.h" static constexpr int MAX_VECTOR_COUNT = 16; @@ -39,27 +37,6 @@ extern "C" void sksl_debug_print(float f) { printf("Debug: %f\n", f); } -extern "C" float sksl_clamp1(float f, float min, float max) { - return SkTPin(f, min, max); -} - -using float2 = __attribute__((vector_size(8))) float; -using float3 = __attribute__((vector_size(16))) float; -using float4 = __attribute__((vector_size(16))) float; - -extern "C" float2 sksl_clamp2(float2 f, float min, float max) { - return float2 { SkTPin(f[0], min, max), SkTPin(f[1], min, max) }; -} - -extern "C" float3 sksl_clamp3(float3 f, float min, float max) { - return float3 { SkTPin(f[0], min, max), SkTPin(f[1], min, max), SkTPin(f[2], min, max) }; -} - -extern "C" float4 sksl_clamp4(float4 f, float min, float max) { - return float4 { SkTPin(f[0], min, max), SkTPin(f[1], min, max), SkTPin(f[2], min, max), - SkTPin(f[3], min, max) }; -} - namespace SkSL { static constexpr int STAGE_PARAM_COUNT = 12; @@ -101,10 +78,6 @@ JIT::JIT(Compiler* compiler) fContext = LLVMContextCreate(); fVoidType = LLVMVoidTypeInContext(fContext); fInt1Type = LLVMInt1TypeInContext(fContext); - fInt1VectorType = LLVMVectorType(fInt1Type, fVectorCount); - fInt1Vector2Type = LLVMVectorType(fInt1Type, 2); - fInt1Vector3Type = LLVMVectorType(fInt1Type, 3); - fInt1Vector4Type = LLVMVectorType(fInt1Type, 4); fInt8Type = LLVMInt8TypeInContext(fContext); fInt8PtrType = LLVMPointerType(fInt8Type, 0); fInt32Type = LLVMInt32TypeInContext(fContext); @@ -128,7 +101,6 @@ JIT::~JIT() { void JIT::addBuiltinFunction(const char* ourName, const char* realName, LLVMTypeRef returnType, std::vector<LLVMTypeRef> parameters) { - bool found = false; for (const auto& pair : *fProgram->fSymbols) { if (Symbol::kFunctionDeclaration_Kind == pair.second->fKind) { const FunctionDeclaration& f = (const FunctionDeclaration&) *pair.second; @@ -145,31 +117,9 @@ void JIT::addBuiltinFunction(const char* ourName, const char* realName, LLVMType parameters.data(), parameters.size(), false)); - found = true; - } - if (Symbol::kUnresolvedFunction_Kind == pair.second->fKind) { - // FIXME consolidate this with the code above - for (const auto& f : ((const UnresolvedFunction&) *pair.second).fFunctions) { - if (pair.first != ourName || returnType != this->getType(f->fReturnType) || - parameters.size() != f->fParameters.size()) { - continue; - } - for (size_t i = 0; i < parameters.size(); ++i) { - if (parameters[i] != this->getType(f->fParameters[i]->fType)) { - goto next; - } - } - fFunctions[f] = LLVMAddFunction(fModule, realName, LLVMFunctionType( - returnType, - parameters.data(), - parameters.size(), - false)); - found = true; - } } next:; } - SkASSERT(found); } void JIT::loadBuiltinFunctions() { @@ -178,18 +128,6 @@ void JIT::loadBuiltinFunctions() { this->addBuiltinFunction("cos", "cosf", fFloat32Type, { fFloat32Type }); this->addBuiltinFunction("tan", "tanf", fFloat32Type, { fFloat32Type }); this->addBuiltinFunction("sqrt", "sqrtf", fFloat32Type, { fFloat32Type }); - this->addBuiltinFunction("clamp", "sksl_clamp1", fFloat32Type, { fFloat32Type, - fFloat32Type, - fFloat32Type }); - this->addBuiltinFunction("clamp", "sksl_clamp2", fFloat32Vector2Type, { fFloat32Vector2Type, - fFloat32Type, - fFloat32Type }); - this->addBuiltinFunction("clamp", "sksl_clamp3", fFloat32Vector3Type, { fFloat32Vector3Type, - fFloat32Type, - fFloat32Type }); - this->addBuiltinFunction("clamp", "sksl_clamp4", fFloat32Vector4Type, { fFloat32Vector4Type, - fFloat32Type, - fFloat32Type }); this->addBuiltinFunction("print", "sksl_debug_print", fVoidType, { fFloat32Type }); } @@ -200,14 +138,6 @@ uint64_t JIT::resolveSymbol(const char* name, JIT* jit) { result = (uint64_t) &sksl_pipeline_append; } else if (!strcmp(name, "_sksl_pipeline_append_callback")) { result = (uint64_t) &sksl_pipeline_append_callback; - } else if (!strcmp(name, "_sksl_clamp1")) { - result = (uint64_t) &sksl_clamp1; - } else if (!strcmp(name, "_sksl_clamp2")) { - result = (uint64_t) &sksl_clamp2; - } else if (!strcmp(name, "_sksl_clamp3")) { - result = (uint64_t) &sksl_clamp3; - } else if (!strcmp(name, "_sksl_clamp4")) { - result = (uint64_t) &sksl_clamp4; } else if (!strcmp(name, "_sksl_debug_print")) { result = (uint64_t) &sksl_debug_print; } else { @@ -476,7 +406,7 @@ JIT::TypeKind JIT::typeKind(const Type& type) { return JIT::kInt_TypeKind; } else if (type.fName == "uint" || type.fName == "ushort" || type.fName == "ubyte") { return JIT::kUInt_TypeKind; - } else if (type.fName == "float" || type.fName == "double" || type.fName == "half") { + } else if (type.fName == "float" || type.fName == "double") { return JIT::kFloat_TypeKind; } ABORT("unsupported type: %s\n", type.description().c_str()); @@ -511,7 +441,7 @@ LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression& LLVMValueRef left = this->compileExpression(builder, *b.fLeft); \ LLVMValueRef right = this->compileExpression(builder, *b.fRight); \ this->vectorize(builder, b, &left, &right); \ - switch (this->typeKind(b.fLeft->fType)) { \ + switch (this->typeKind(b.fLeft->fType)) { \ case kInt_TypeKind: \ return SFunc(builder, left, right, "binary"); \ case kUInt_TypeKind: \ @@ -519,7 +449,7 @@ LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression& case kFloat_TypeKind: \ return FFunc(builder, left, right, "binary"); \ default: \ - ABORT("unsupported typeKind"); \ + ABORT("unsupported typeKind"); \ } \ } #define COMPOUND(SFunc, UFunc, FFunc) { \ @@ -528,7 +458,7 @@ LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression& LLVMValueRef right = this->compileExpression(builder, *b.fRight); \ this->vectorize(builder, b, &left, &right); \ LLVMValueRef result; \ - switch (this->typeKind(b.fLeft->fType)) { \ + switch (this->typeKind(b.fLeft->fType)) { \ case kInt_TypeKind: \ result = SFunc(builder, left, right, "binary"); \ break; \ @@ -539,7 +469,7 @@ LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression& result = FFunc(builder, left, right, "binary"); \ break; \ default: \ - ABORT("unsupported typeKind"); \ + ABORT("unsupported typeKind"); \ } \ lvalue->store(builder, result); \ return result; \ @@ -580,10 +510,6 @@ LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression& BINARY(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd); case Token::BITWISEOR: BINARY(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr); - case Token::SHL: - BINARY(LLVMBuildShl, LLVMBuildShl, LLVMBuildShl); - case Token::SHR: - BINARY(LLVMBuildAShr, LLVMBuildLShr, LLVMBuildAShr); case Token::PLUSEQ: COMPOUND(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd); case Token::MINUSEQ: @@ -597,83 +523,13 @@ LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression& case Token::BITWISEOREQ: COMPOUND(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr); case Token::EQEQ: - switch (b.fLeft->fType.kind()) { - case Type::kScalar_Kind: - COMPARE(LLVMBuildICmp, LLVMIntEQ, - LLVMBuildICmp, LLVMIntEQ, - LLVMBuildFCmp, LLVMRealOEQ); - case Type::kVector_Kind: { - LLVMValueRef left = this->compileExpression(builder, *b.fLeft); - LLVMValueRef right = this->compileExpression(builder, *b.fRight); - this->vectorize(builder, b, &left, &right); - LLVMValueRef value; - switch (this->typeKind(b.fLeft->fType)) { - case kInt_TypeKind: - value = LLVMBuildICmp(builder, LLVMIntEQ, left, right, "binary"); - break; - case kUInt_TypeKind: - value = LLVMBuildICmp(builder, LLVMIntEQ, left, right, "binary"); - break; - case kFloat_TypeKind: - value = LLVMBuildFCmp(builder, LLVMRealOEQ, left, right, "binary"); - break; - default: - ABORT("unsupported typeKind"); - } - LLVMValueRef args[1] = { value }; - LLVMValueRef func; - switch (b.fLeft->fType.columns()) { - case 2: func = fFoldAnd2Func; break; - case 3: func = fFoldAnd3Func; break; - case 4: func = fFoldAnd4Func; break; - default: - SkASSERT(false); - func = fFoldAnd2Func; - } - return LLVMBuildCall(builder, func, args, 1, "all"); - } - default: - SkASSERT(false); - } + COMPARE(LLVMBuildICmp, LLVMIntEQ, + LLVMBuildICmp, LLVMIntEQ, + LLVMBuildFCmp, LLVMRealOEQ); case Token::NEQ: - switch (b.fLeft->fType.kind()) { - case Type::kScalar_Kind: - COMPARE(LLVMBuildICmp, LLVMIntNE, - LLVMBuildICmp, LLVMIntNE, - LLVMBuildFCmp, LLVMRealONE); - case Type::kVector_Kind: { - LLVMValueRef left = this->compileExpression(builder, *b.fLeft); - LLVMValueRef right = this->compileExpression(builder, *b.fRight); - this->vectorize(builder, b, &left, &right); - LLVMValueRef value; - switch (this->typeKind(b.fLeft->fType)) { - case kInt_TypeKind: - value = LLVMBuildICmp(builder, LLVMIntNE, left, right, "binary"); - break; - case kUInt_TypeKind: - value = LLVMBuildICmp(builder, LLVMIntNE, left, right, "binary"); - break; - case kFloat_TypeKind: - value = LLVMBuildFCmp(builder, LLVMRealONE, left, right, "binary"); - break; - default: - ABORT("unsupported typeKind"); - } - LLVMValueRef args[1] = { value }; - LLVMValueRef func; - switch (b.fLeft->fType.columns()) { - case 2: func = fFoldOr2Func; break; - case 3: func = fFoldOr3Func; break; - case 4: func = fFoldOr4Func; break; - default: - SkASSERT(false); - func = fFoldOr2Func; - } - return LLVMBuildCall(builder, func, args, 1, "all"); - } - default: - SkASSERT(false); - } + COMPARE(LLVMBuildICmp, LLVMIntNE, + LLVMBuildICmp, LLVMIntNE, + LLVMBuildFCmp, LLVMRealONE); case Token::LT: COMPARE(LLVMBuildICmp, LLVMIntSLT, LLVMBuildICmp, LLVMIntULT, @@ -727,7 +583,6 @@ LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression& return phi; } default: - printf("%s\n", b.description().c_str()); ABORT("unsupported binary operator"); } } @@ -847,9 +702,9 @@ void JIT::appendStage(LLVMBuilderRef builder, const AppendStage& a) { const FunctionDeclaration& functionDecl = *((FunctionReference&) *a.fArguments[1]).fFunctions[0]; bool found = false; - for (const auto& pe : *fProgram) { - if (ProgramElement::kFunction_Kind == pe.fKind) { - const FunctionDefinition& def = (const FunctionDefinition&) pe; + for (const auto& pe : fProgram->fElements) { + if (ProgramElement::kFunction_Kind == pe->fKind) { + const FunctionDefinition& def = (const FunctionDefinition&) *pe; if (&def.fDeclaration == &functionDecl) { LLVMValueRef fn = this->compileStageFunction(def); LLVMValueRef args[2] = { @@ -892,74 +747,49 @@ LLVMValueRef JIT::compileConstructor(LLVMBuilderRef builder, const Constructor& TypeKind from = this->typeKind(c.fArguments[0]->fType); TypeKind to = this->typeKind(c.fType); LLVMValueRef base = this->compileExpression(builder, *c.fArguments[0]); - switch (to) { - case kFloat_TypeKind: - switch (from) { - case kInt_TypeKind: - return LLVMBuildSIToFP(builder, base, this->getType(c.fType), "cast"); - case kUInt_TypeKind: - return LLVMBuildUIToFP(builder, base, this->getType(c.fType), "cast"); - case kFloat_TypeKind: - return base; - case kBool_TypeKind: - SkASSERT(false); - } - case kInt_TypeKind: - switch (from) { - case kInt_TypeKind: - return base; - case kUInt_TypeKind: - return base; - case kFloat_TypeKind: - return LLVMBuildFPToSI(builder, base, this->getType(c.fType), "cast"); - case kBool_TypeKind: - SkASSERT(false); - } - case kUInt_TypeKind: - switch (from) { - case kInt_TypeKind: - return base; - case kUInt_TypeKind: - return base; - case kFloat_TypeKind: - return LLVMBuildFPToUI(builder, base, this->getType(c.fType), "cast"); - case kBool_TypeKind: - SkASSERT(false); - } - case kBool_TypeKind: - SkASSERT(false); + if (kFloat_TypeKind == to) { + if (kInt_TypeKind == from) { + return LLVMBuildSIToFP(builder, base, this->getType(c.fType), "cast"); + } + if (kUInt_TypeKind == from) { + return LLVMBuildUIToFP(builder, base, this->getType(c.fType), "cast"); + } + } + if (kInt_TypeKind == to) { + if (kFloat_TypeKind == from) { + return LLVMBuildFPToSI(builder, base, this->getType(c.fType), "cast"); + } + if (kUInt_TypeKind == from) { + return base; + } + } + if (kUInt_TypeKind == to) { + if (kFloat_TypeKind == from) { + return LLVMBuildFPToUI(builder, base, this->getType(c.fType), "cast"); + } + if (kInt_TypeKind == from) { + return base; + } } + ABORT("unsupported constructor"); } case Type::kVector_Kind: { LLVMValueRef vec = LLVMGetUndef(this->getType(c.fType)); - if (c.fArguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) { + if (c.fArguments.size() == 1) { LLVMValueRef value = this->compileExpression(builder, *c.fArguments[0]); for (int i = 0; i < c.fType.columns(); ++i) { vec = LLVMBuildInsertElement(builder, vec, value, LLVMConstInt(fInt32Type, i, false), - "vec build 1"); + "vec build"); } } else { - int index = 0; - for (const auto& arg : c.fArguments) { - LLVMValueRef value = this->compileExpression(builder, *arg); - if (arg->fType.kind() == Type::kVector_Kind) { - for (int i = 0; i < arg->fType.columns(); ++i) { - LLVMValueRef column = LLVMBuildExtractElement(builder, - vec, - LLVMConstInt(fInt32Type, - i, - false), - "construct extract"); - vec = LLVMBuildInsertElement(builder, vec, column, - LLVMConstInt(fInt32Type, index++, false), - "vec build 2"); - } - } else { - vec = LLVMBuildInsertElement(builder, vec, value, - LLVMConstInt(fInt32Type, index++, false), - "vec build 3"); - } + SkASSERT(c.fArguments.size() == (size_t) c.fType.columns()); + for (int i = 0; i < c.fType.columns(); ++i) { + vec = LLVMBuildInsertElement(builder, vec, + this->compileExpression(builder, + *c.fArguments[i]), + LLVMConstInt(fInt32Type, i, false), + "vec build"); } } return vec; @@ -1630,6 +1460,7 @@ bool JIT::compileVectorExpression(LLVMBuilderRef builder, const Expression& expr return this->compileVectorVariableReference(builder, (const VariableReference&) expr, out); default: + printf("failed expression: %s\n", expr.description().c_str()); return false; } } @@ -1649,6 +1480,7 @@ bool JIT::compileVectorStatement(LLVMBuilderRef builder, const Statement& stmt) *((const ExpressionStatement&) stmt).fExpression, &result); default: + printf("failed statement: %s\n", stmt.description().c_str()); return false; } } @@ -1750,7 +1582,7 @@ bool JIT::hasStageSignature(const FunctionDeclaration& f) { f.fParameters[0]->fModifiers.fFlags == 0 && f.fParameters[1]->fType == *fProgram->fContext->fInt_Type && f.fParameters[1]->fModifiers.fFlags == 0 && - f.fParameters[2]->fType == *fProgram->fContext->fHalf4_Type && + f.fParameters[2]->fType == *fProgram->fContext->fFloat4_Type && f.fParameters[2]->fModifiers.fFlags == (Modifiers::kIn_Flag | Modifiers::kOut_Flag); } @@ -1807,21 +1639,6 @@ void JIT::createModule() { fPromotedParameters.clear(); fModule = LLVMModuleCreateWithNameInContext("skslmodule", fContext); this->loadBuiltinFunctions(); - LLVMTypeRef fold2Params[1] = { fInt1Vector2Type }; - fFoldAnd2Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.and.i1.v2i1", - LLVMFunctionType(fInt1Type, fold2Params, 1, false)); - fFoldOr2Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.or.i1.v2i1", - LLVMFunctionType(fInt1Type, fold2Params, 1, false)); - LLVMTypeRef fold3Params[1] = { fInt1Vector3Type }; - fFoldAnd3Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.and.i1.v3i1", - LLVMFunctionType(fInt1Type, fold3Params, 1, false)); - fFoldOr3Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.or.i1.v3i1", - LLVMFunctionType(fInt1Type, fold3Params, 1, false)); - LLVMTypeRef fold4Params[1] = { fInt1Vector4Type }; - fFoldAnd4Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.and.i1.v4i1", - LLVMFunctionType(fInt1Type, fold4Params, 1, false)); - fFoldOr4Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.or.i1.v4i1", - LLVMFunctionType(fInt1Type, fold4Params, 1, false)); // LLVM doesn't do void*, have to declare it as int8* LLVMTypeRef appendParams[3] = { fInt8PtrType, fInt32Type, fInt8PtrType }; fAppendFunc = LLVMAddFunction(fModule, "sksl_pipeline_append", LLVMFunctionType(fVoidType, @@ -1839,15 +1656,13 @@ void JIT::createModule() { 1, false)); - for (const auto& e : *fProgram) { - if (e.fKind == ProgramElement::kFunction_Kind) { - this->compileFunction((FunctionDefinition&) e); - } + for (const auto& e : fProgram->fElements) { + SkASSERT(e->fKind == ProgramElement::kFunction_Kind); + this->compileFunction((FunctionDefinition&) *e); } } std::unique_ptr<JIT::Module> JIT::compile(std::unique_ptr<Program> program) { - fCompiler.optimize(*program); fProgram = std::move(program); this->createModule(); this->optimize(); |