aboutsummaryrefslogtreecommitdiffhomepage
path: root/src/sksl/SkSLJIT.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/sksl/SkSLJIT.cpp')
-rw-r--r--src/sksl/SkSLJIT.cpp291
1 files changed, 53 insertions, 238 deletions
diff --git a/src/sksl/SkSLJIT.cpp b/src/sksl/SkSLJIT.cpp
index 4120c4aa8c..115a6be3ae 100644
--- a/src/sksl/SkSLJIT.cpp
+++ b/src/sksl/SkSLJIT.cpp
@@ -14,13 +14,11 @@
#include "SkCpu.h"
#include "SkRasterPipeline.h"
#include "../jumper/SkJumper.h"
-#include "ir/SkSLAppendStage.h"
#include "ir/SkSLExpressionStatement.h"
#include "ir/SkSLFunctionCall.h"
#include "ir/SkSLFunctionReference.h"
#include "ir/SkSLIndexExpression.h"
#include "ir/SkSLProgram.h"
-#include "ir/SkSLUnresolvedFunction.h"
#include "llvm/ExecutionEngine/RTDyldMemoryManager.h"
static constexpr int MAX_VECTOR_COUNT = 16;
@@ -39,27 +37,6 @@ extern "C" void sksl_debug_print(float f) {
printf("Debug: %f\n", f);
}
-extern "C" float sksl_clamp1(float f, float min, float max) {
- return SkTPin(f, min, max);
-}
-
-using float2 = __attribute__((vector_size(8))) float;
-using float3 = __attribute__((vector_size(16))) float;
-using float4 = __attribute__((vector_size(16))) float;
-
-extern "C" float2 sksl_clamp2(float2 f, float min, float max) {
- return float2 { SkTPin(f[0], min, max), SkTPin(f[1], min, max) };
-}
-
-extern "C" float3 sksl_clamp3(float3 f, float min, float max) {
- return float3 { SkTPin(f[0], min, max), SkTPin(f[1], min, max), SkTPin(f[2], min, max) };
-}
-
-extern "C" float4 sksl_clamp4(float4 f, float min, float max) {
- return float4 { SkTPin(f[0], min, max), SkTPin(f[1], min, max), SkTPin(f[2], min, max),
- SkTPin(f[3], min, max) };
-}
-
namespace SkSL {
static constexpr int STAGE_PARAM_COUNT = 12;
@@ -101,10 +78,6 @@ JIT::JIT(Compiler* compiler)
fContext = LLVMContextCreate();
fVoidType = LLVMVoidTypeInContext(fContext);
fInt1Type = LLVMInt1TypeInContext(fContext);
- fInt1VectorType = LLVMVectorType(fInt1Type, fVectorCount);
- fInt1Vector2Type = LLVMVectorType(fInt1Type, 2);
- fInt1Vector3Type = LLVMVectorType(fInt1Type, 3);
- fInt1Vector4Type = LLVMVectorType(fInt1Type, 4);
fInt8Type = LLVMInt8TypeInContext(fContext);
fInt8PtrType = LLVMPointerType(fInt8Type, 0);
fInt32Type = LLVMInt32TypeInContext(fContext);
@@ -128,7 +101,6 @@ JIT::~JIT() {
void JIT::addBuiltinFunction(const char* ourName, const char* realName, LLVMTypeRef returnType,
std::vector<LLVMTypeRef> parameters) {
- bool found = false;
for (const auto& pair : *fProgram->fSymbols) {
if (Symbol::kFunctionDeclaration_Kind == pair.second->fKind) {
const FunctionDeclaration& f = (const FunctionDeclaration&) *pair.second;
@@ -145,31 +117,9 @@ void JIT::addBuiltinFunction(const char* ourName, const char* realName, LLVMType
parameters.data(),
parameters.size(),
false));
- found = true;
- }
- if (Symbol::kUnresolvedFunction_Kind == pair.second->fKind) {
- // FIXME consolidate this with the code above
- for (const auto& f : ((const UnresolvedFunction&) *pair.second).fFunctions) {
- if (pair.first != ourName || returnType != this->getType(f->fReturnType) ||
- parameters.size() != f->fParameters.size()) {
- continue;
- }
- for (size_t i = 0; i < parameters.size(); ++i) {
- if (parameters[i] != this->getType(f->fParameters[i]->fType)) {
- goto next;
- }
- }
- fFunctions[f] = LLVMAddFunction(fModule, realName, LLVMFunctionType(
- returnType,
- parameters.data(),
- parameters.size(),
- false));
- found = true;
- }
}
next:;
}
- SkASSERT(found);
}
void JIT::loadBuiltinFunctions() {
@@ -178,18 +128,6 @@ void JIT::loadBuiltinFunctions() {
this->addBuiltinFunction("cos", "cosf", fFloat32Type, { fFloat32Type });
this->addBuiltinFunction("tan", "tanf", fFloat32Type, { fFloat32Type });
this->addBuiltinFunction("sqrt", "sqrtf", fFloat32Type, { fFloat32Type });
- this->addBuiltinFunction("clamp", "sksl_clamp1", fFloat32Type, { fFloat32Type,
- fFloat32Type,
- fFloat32Type });
- this->addBuiltinFunction("clamp", "sksl_clamp2", fFloat32Vector2Type, { fFloat32Vector2Type,
- fFloat32Type,
- fFloat32Type });
- this->addBuiltinFunction("clamp", "sksl_clamp3", fFloat32Vector3Type, { fFloat32Vector3Type,
- fFloat32Type,
- fFloat32Type });
- this->addBuiltinFunction("clamp", "sksl_clamp4", fFloat32Vector4Type, { fFloat32Vector4Type,
- fFloat32Type,
- fFloat32Type });
this->addBuiltinFunction("print", "sksl_debug_print", fVoidType, { fFloat32Type });
}
@@ -200,14 +138,6 @@ uint64_t JIT::resolveSymbol(const char* name, JIT* jit) {
result = (uint64_t) &sksl_pipeline_append;
} else if (!strcmp(name, "_sksl_pipeline_append_callback")) {
result = (uint64_t) &sksl_pipeline_append_callback;
- } else if (!strcmp(name, "_sksl_clamp1")) {
- result = (uint64_t) &sksl_clamp1;
- } else if (!strcmp(name, "_sksl_clamp2")) {
- result = (uint64_t) &sksl_clamp2;
- } else if (!strcmp(name, "_sksl_clamp3")) {
- result = (uint64_t) &sksl_clamp3;
- } else if (!strcmp(name, "_sksl_clamp4")) {
- result = (uint64_t) &sksl_clamp4;
} else if (!strcmp(name, "_sksl_debug_print")) {
result = (uint64_t) &sksl_debug_print;
} else {
@@ -476,7 +406,7 @@ JIT::TypeKind JIT::typeKind(const Type& type) {
return JIT::kInt_TypeKind;
} else if (type.fName == "uint" || type.fName == "ushort" || type.fName == "ubyte") {
return JIT::kUInt_TypeKind;
- } else if (type.fName == "float" || type.fName == "double" || type.fName == "half") {
+ } else if (type.fName == "float" || type.fName == "double") {
return JIT::kFloat_TypeKind;
}
ABORT("unsupported type: %s\n", type.description().c_str());
@@ -511,7 +441,7 @@ LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression&
LLVMValueRef left = this->compileExpression(builder, *b.fLeft); \
LLVMValueRef right = this->compileExpression(builder, *b.fRight); \
this->vectorize(builder, b, &left, &right); \
- switch (this->typeKind(b.fLeft->fType)) { \
+ switch (this->typeKind(b.fLeft->fType)) { \
case kInt_TypeKind: \
return SFunc(builder, left, right, "binary"); \
case kUInt_TypeKind: \
@@ -519,7 +449,7 @@ LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression&
case kFloat_TypeKind: \
return FFunc(builder, left, right, "binary"); \
default: \
- ABORT("unsupported typeKind"); \
+ ABORT("unsupported typeKind"); \
} \
}
#define COMPOUND(SFunc, UFunc, FFunc) { \
@@ -528,7 +458,7 @@ LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression&
LLVMValueRef right = this->compileExpression(builder, *b.fRight); \
this->vectorize(builder, b, &left, &right); \
LLVMValueRef result; \
- switch (this->typeKind(b.fLeft->fType)) { \
+ switch (this->typeKind(b.fLeft->fType)) { \
case kInt_TypeKind: \
result = SFunc(builder, left, right, "binary"); \
break; \
@@ -539,7 +469,7 @@ LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression&
result = FFunc(builder, left, right, "binary"); \
break; \
default: \
- ABORT("unsupported typeKind"); \
+ ABORT("unsupported typeKind"); \
} \
lvalue->store(builder, result); \
return result; \
@@ -580,10 +510,6 @@ LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression&
BINARY(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd);
case Token::BITWISEOR:
BINARY(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr);
- case Token::SHL:
- BINARY(LLVMBuildShl, LLVMBuildShl, LLVMBuildShl);
- case Token::SHR:
- BINARY(LLVMBuildAShr, LLVMBuildLShr, LLVMBuildAShr);
case Token::PLUSEQ:
COMPOUND(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd);
case Token::MINUSEQ:
@@ -597,83 +523,13 @@ LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression&
case Token::BITWISEOREQ:
COMPOUND(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr);
case Token::EQEQ:
- switch (b.fLeft->fType.kind()) {
- case Type::kScalar_Kind:
- COMPARE(LLVMBuildICmp, LLVMIntEQ,
- LLVMBuildICmp, LLVMIntEQ,
- LLVMBuildFCmp, LLVMRealOEQ);
- case Type::kVector_Kind: {
- LLVMValueRef left = this->compileExpression(builder, *b.fLeft);
- LLVMValueRef right = this->compileExpression(builder, *b.fRight);
- this->vectorize(builder, b, &left, &right);
- LLVMValueRef value;
- switch (this->typeKind(b.fLeft->fType)) {
- case kInt_TypeKind:
- value = LLVMBuildICmp(builder, LLVMIntEQ, left, right, "binary");
- break;
- case kUInt_TypeKind:
- value = LLVMBuildICmp(builder, LLVMIntEQ, left, right, "binary");
- break;
- case kFloat_TypeKind:
- value = LLVMBuildFCmp(builder, LLVMRealOEQ, left, right, "binary");
- break;
- default:
- ABORT("unsupported typeKind");
- }
- LLVMValueRef args[1] = { value };
- LLVMValueRef func;
- switch (b.fLeft->fType.columns()) {
- case 2: func = fFoldAnd2Func; break;
- case 3: func = fFoldAnd3Func; break;
- case 4: func = fFoldAnd4Func; break;
- default:
- SkASSERT(false);
- func = fFoldAnd2Func;
- }
- return LLVMBuildCall(builder, func, args, 1, "all");
- }
- default:
- SkASSERT(false);
- }
+ COMPARE(LLVMBuildICmp, LLVMIntEQ,
+ LLVMBuildICmp, LLVMIntEQ,
+ LLVMBuildFCmp, LLVMRealOEQ);
case Token::NEQ:
- switch (b.fLeft->fType.kind()) {
- case Type::kScalar_Kind:
- COMPARE(LLVMBuildICmp, LLVMIntNE,
- LLVMBuildICmp, LLVMIntNE,
- LLVMBuildFCmp, LLVMRealONE);
- case Type::kVector_Kind: {
- LLVMValueRef left = this->compileExpression(builder, *b.fLeft);
- LLVMValueRef right = this->compileExpression(builder, *b.fRight);
- this->vectorize(builder, b, &left, &right);
- LLVMValueRef value;
- switch (this->typeKind(b.fLeft->fType)) {
- case kInt_TypeKind:
- value = LLVMBuildICmp(builder, LLVMIntNE, left, right, "binary");
- break;
- case kUInt_TypeKind:
- value = LLVMBuildICmp(builder, LLVMIntNE, left, right, "binary");
- break;
- case kFloat_TypeKind:
- value = LLVMBuildFCmp(builder, LLVMRealONE, left, right, "binary");
- break;
- default:
- ABORT("unsupported typeKind");
- }
- LLVMValueRef args[1] = { value };
- LLVMValueRef func;
- switch (b.fLeft->fType.columns()) {
- case 2: func = fFoldOr2Func; break;
- case 3: func = fFoldOr3Func; break;
- case 4: func = fFoldOr4Func; break;
- default:
- SkASSERT(false);
- func = fFoldOr2Func;
- }
- return LLVMBuildCall(builder, func, args, 1, "all");
- }
- default:
- SkASSERT(false);
- }
+ COMPARE(LLVMBuildICmp, LLVMIntNE,
+ LLVMBuildICmp, LLVMIntNE,
+ LLVMBuildFCmp, LLVMRealONE);
case Token::LT:
COMPARE(LLVMBuildICmp, LLVMIntSLT,
LLVMBuildICmp, LLVMIntULT,
@@ -727,7 +583,6 @@ LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression&
return phi;
}
default:
- printf("%s\n", b.description().c_str());
ABORT("unsupported binary operator");
}
}
@@ -847,9 +702,9 @@ void JIT::appendStage(LLVMBuilderRef builder, const AppendStage& a) {
const FunctionDeclaration& functionDecl =
*((FunctionReference&) *a.fArguments[1]).fFunctions[0];
bool found = false;
- for (const auto& pe : *fProgram) {
- if (ProgramElement::kFunction_Kind == pe.fKind) {
- const FunctionDefinition& def = (const FunctionDefinition&) pe;
+ for (const auto& pe : fProgram->fElements) {
+ if (ProgramElement::kFunction_Kind == pe->fKind) {
+ const FunctionDefinition& def = (const FunctionDefinition&) *pe;
if (&def.fDeclaration == &functionDecl) {
LLVMValueRef fn = this->compileStageFunction(def);
LLVMValueRef args[2] = {
@@ -892,74 +747,49 @@ LLVMValueRef JIT::compileConstructor(LLVMBuilderRef builder, const Constructor&
TypeKind from = this->typeKind(c.fArguments[0]->fType);
TypeKind to = this->typeKind(c.fType);
LLVMValueRef base = this->compileExpression(builder, *c.fArguments[0]);
- switch (to) {
- case kFloat_TypeKind:
- switch (from) {
- case kInt_TypeKind:
- return LLVMBuildSIToFP(builder, base, this->getType(c.fType), "cast");
- case kUInt_TypeKind:
- return LLVMBuildUIToFP(builder, base, this->getType(c.fType), "cast");
- case kFloat_TypeKind:
- return base;
- case kBool_TypeKind:
- SkASSERT(false);
- }
- case kInt_TypeKind:
- switch (from) {
- case kInt_TypeKind:
- return base;
- case kUInt_TypeKind:
- return base;
- case kFloat_TypeKind:
- return LLVMBuildFPToSI(builder, base, this->getType(c.fType), "cast");
- case kBool_TypeKind:
- SkASSERT(false);
- }
- case kUInt_TypeKind:
- switch (from) {
- case kInt_TypeKind:
- return base;
- case kUInt_TypeKind:
- return base;
- case kFloat_TypeKind:
- return LLVMBuildFPToUI(builder, base, this->getType(c.fType), "cast");
- case kBool_TypeKind:
- SkASSERT(false);
- }
- case kBool_TypeKind:
- SkASSERT(false);
+ if (kFloat_TypeKind == to) {
+ if (kInt_TypeKind == from) {
+ return LLVMBuildSIToFP(builder, base, this->getType(c.fType), "cast");
+ }
+ if (kUInt_TypeKind == from) {
+ return LLVMBuildUIToFP(builder, base, this->getType(c.fType), "cast");
+ }
+ }
+ if (kInt_TypeKind == to) {
+ if (kFloat_TypeKind == from) {
+ return LLVMBuildFPToSI(builder, base, this->getType(c.fType), "cast");
+ }
+ if (kUInt_TypeKind == from) {
+ return base;
+ }
+ }
+ if (kUInt_TypeKind == to) {
+ if (kFloat_TypeKind == from) {
+ return LLVMBuildFPToUI(builder, base, this->getType(c.fType), "cast");
+ }
+ if (kInt_TypeKind == from) {
+ return base;
+ }
}
+ ABORT("unsupported constructor");
}
case Type::kVector_Kind: {
LLVMValueRef vec = LLVMGetUndef(this->getType(c.fType));
- if (c.fArguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
+ if (c.fArguments.size() == 1) {
LLVMValueRef value = this->compileExpression(builder, *c.fArguments[0]);
for (int i = 0; i < c.fType.columns(); ++i) {
vec = LLVMBuildInsertElement(builder, vec, value,
LLVMConstInt(fInt32Type, i, false),
- "vec build 1");
+ "vec build");
}
} else {
- int index = 0;
- for (const auto& arg : c.fArguments) {
- LLVMValueRef value = this->compileExpression(builder, *arg);
- if (arg->fType.kind() == Type::kVector_Kind) {
- for (int i = 0; i < arg->fType.columns(); ++i) {
- LLVMValueRef column = LLVMBuildExtractElement(builder,
- vec,
- LLVMConstInt(fInt32Type,
- i,
- false),
- "construct extract");
- vec = LLVMBuildInsertElement(builder, vec, column,
- LLVMConstInt(fInt32Type, index++, false),
- "vec build 2");
- }
- } else {
- vec = LLVMBuildInsertElement(builder, vec, value,
- LLVMConstInt(fInt32Type, index++, false),
- "vec build 3");
- }
+ SkASSERT(c.fArguments.size() == (size_t) c.fType.columns());
+ for (int i = 0; i < c.fType.columns(); ++i) {
+ vec = LLVMBuildInsertElement(builder, vec,
+ this->compileExpression(builder,
+ *c.fArguments[i]),
+ LLVMConstInt(fInt32Type, i, false),
+ "vec build");
}
}
return vec;
@@ -1630,6 +1460,7 @@ bool JIT::compileVectorExpression(LLVMBuilderRef builder, const Expression& expr
return this->compileVectorVariableReference(builder, (const VariableReference&) expr,
out);
default:
+ printf("failed expression: %s\n", expr.description().c_str());
return false;
}
}
@@ -1649,6 +1480,7 @@ bool JIT::compileVectorStatement(LLVMBuilderRef builder, const Statement& stmt)
*((const ExpressionStatement&) stmt).fExpression,
&result);
default:
+ printf("failed statement: %s\n", stmt.description().c_str());
return false;
}
}
@@ -1750,7 +1582,7 @@ bool JIT::hasStageSignature(const FunctionDeclaration& f) {
f.fParameters[0]->fModifiers.fFlags == 0 &&
f.fParameters[1]->fType == *fProgram->fContext->fInt_Type &&
f.fParameters[1]->fModifiers.fFlags == 0 &&
- f.fParameters[2]->fType == *fProgram->fContext->fHalf4_Type &&
+ f.fParameters[2]->fType == *fProgram->fContext->fFloat4_Type &&
f.fParameters[2]->fModifiers.fFlags == (Modifiers::kIn_Flag | Modifiers::kOut_Flag);
}
@@ -1807,21 +1639,6 @@ void JIT::createModule() {
fPromotedParameters.clear();
fModule = LLVMModuleCreateWithNameInContext("skslmodule", fContext);
this->loadBuiltinFunctions();
- LLVMTypeRef fold2Params[1] = { fInt1Vector2Type };
- fFoldAnd2Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.and.i1.v2i1",
- LLVMFunctionType(fInt1Type, fold2Params, 1, false));
- fFoldOr2Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.or.i1.v2i1",
- LLVMFunctionType(fInt1Type, fold2Params, 1, false));
- LLVMTypeRef fold3Params[1] = { fInt1Vector3Type };
- fFoldAnd3Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.and.i1.v3i1",
- LLVMFunctionType(fInt1Type, fold3Params, 1, false));
- fFoldOr3Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.or.i1.v3i1",
- LLVMFunctionType(fInt1Type, fold3Params, 1, false));
- LLVMTypeRef fold4Params[1] = { fInt1Vector4Type };
- fFoldAnd4Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.and.i1.v4i1",
- LLVMFunctionType(fInt1Type, fold4Params, 1, false));
- fFoldOr4Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.or.i1.v4i1",
- LLVMFunctionType(fInt1Type, fold4Params, 1, false));
// LLVM doesn't do void*, have to declare it as int8*
LLVMTypeRef appendParams[3] = { fInt8PtrType, fInt32Type, fInt8PtrType };
fAppendFunc = LLVMAddFunction(fModule, "sksl_pipeline_append", LLVMFunctionType(fVoidType,
@@ -1839,15 +1656,13 @@ void JIT::createModule() {
1,
false));
- for (const auto& e : *fProgram) {
- if (e.fKind == ProgramElement::kFunction_Kind) {
- this->compileFunction((FunctionDefinition&) e);
- }
+ for (const auto& e : fProgram->fElements) {
+ SkASSERT(e->fKind == ProgramElement::kFunction_Kind);
+ this->compileFunction((FunctionDefinition&) *e);
}
}
std::unique_ptr<JIT::Module> JIT::compile(std::unique_ptr<Program> program) {
- fCompiler.optimize(*program);
fProgram = std::move(program);
this->createModule();
this->optimize();