aboutsummaryrefslogtreecommitdiffhomepage
path: root/src/sksl/SkSLJIT.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/sksl/SkSLJIT.cpp')
-rw-r--r--src/sksl/SkSLJIT.cpp291
1 files changed, 238 insertions, 53 deletions
diff --git a/src/sksl/SkSLJIT.cpp b/src/sksl/SkSLJIT.cpp
index 115a6be3ae..4120c4aa8c 100644
--- a/src/sksl/SkSLJIT.cpp
+++ b/src/sksl/SkSLJIT.cpp
@@ -14,11 +14,13 @@
#include "SkCpu.h"
#include "SkRasterPipeline.h"
#include "../jumper/SkJumper.h"
+#include "ir/SkSLAppendStage.h"
#include "ir/SkSLExpressionStatement.h"
#include "ir/SkSLFunctionCall.h"
#include "ir/SkSLFunctionReference.h"
#include "ir/SkSLIndexExpression.h"
#include "ir/SkSLProgram.h"
+#include "ir/SkSLUnresolvedFunction.h"
#include "llvm/ExecutionEngine/RTDyldMemoryManager.h"
static constexpr int MAX_VECTOR_COUNT = 16;
@@ -37,6 +39,27 @@ extern "C" void sksl_debug_print(float f) {
printf("Debug: %f\n", f);
}
+extern "C" float sksl_clamp1(float f, float min, float max) {
+ return SkTPin(f, min, max);
+}
+
+using float2 = __attribute__((vector_size(8))) float;
+using float3 = __attribute__((vector_size(16))) float;
+using float4 = __attribute__((vector_size(16))) float;
+
+extern "C" float2 sksl_clamp2(float2 f, float min, float max) {
+ return float2 { SkTPin(f[0], min, max), SkTPin(f[1], min, max) };
+}
+
+extern "C" float3 sksl_clamp3(float3 f, float min, float max) {
+ return float3 { SkTPin(f[0], min, max), SkTPin(f[1], min, max), SkTPin(f[2], min, max) };
+}
+
+extern "C" float4 sksl_clamp4(float4 f, float min, float max) {
+ return float4 { SkTPin(f[0], min, max), SkTPin(f[1], min, max), SkTPin(f[2], min, max),
+ SkTPin(f[3], min, max) };
+}
+
namespace SkSL {
static constexpr int STAGE_PARAM_COUNT = 12;
@@ -78,6 +101,10 @@ JIT::JIT(Compiler* compiler)
fContext = LLVMContextCreate();
fVoidType = LLVMVoidTypeInContext(fContext);
fInt1Type = LLVMInt1TypeInContext(fContext);
+ fInt1VectorType = LLVMVectorType(fInt1Type, fVectorCount);
+ fInt1Vector2Type = LLVMVectorType(fInt1Type, 2);
+ fInt1Vector3Type = LLVMVectorType(fInt1Type, 3);
+ fInt1Vector4Type = LLVMVectorType(fInt1Type, 4);
fInt8Type = LLVMInt8TypeInContext(fContext);
fInt8PtrType = LLVMPointerType(fInt8Type, 0);
fInt32Type = LLVMInt32TypeInContext(fContext);
@@ -101,6 +128,7 @@ JIT::~JIT() {
void JIT::addBuiltinFunction(const char* ourName, const char* realName, LLVMTypeRef returnType,
std::vector<LLVMTypeRef> parameters) {
+ bool found = false;
for (const auto& pair : *fProgram->fSymbols) {
if (Symbol::kFunctionDeclaration_Kind == pair.second->fKind) {
const FunctionDeclaration& f = (const FunctionDeclaration&) *pair.second;
@@ -117,9 +145,31 @@ void JIT::addBuiltinFunction(const char* ourName, const char* realName, LLVMType
parameters.data(),
parameters.size(),
false));
+ found = true;
+ }
+ if (Symbol::kUnresolvedFunction_Kind == pair.second->fKind) {
+ // FIXME consolidate this with the code above
+ for (const auto& f : ((const UnresolvedFunction&) *pair.second).fFunctions) {
+ if (pair.first != ourName || returnType != this->getType(f->fReturnType) ||
+ parameters.size() != f->fParameters.size()) {
+ continue;
+ }
+ for (size_t i = 0; i < parameters.size(); ++i) {
+ if (parameters[i] != this->getType(f->fParameters[i]->fType)) {
+ goto next;
+ }
+ }
+ fFunctions[f] = LLVMAddFunction(fModule, realName, LLVMFunctionType(
+ returnType,
+ parameters.data(),
+ parameters.size(),
+ false));
+ found = true;
+ }
}
next:;
}
+ SkASSERT(found);
}
void JIT::loadBuiltinFunctions() {
@@ -128,6 +178,18 @@ void JIT::loadBuiltinFunctions() {
this->addBuiltinFunction("cos", "cosf", fFloat32Type, { fFloat32Type });
this->addBuiltinFunction("tan", "tanf", fFloat32Type, { fFloat32Type });
this->addBuiltinFunction("sqrt", "sqrtf", fFloat32Type, { fFloat32Type });
+ this->addBuiltinFunction("clamp", "sksl_clamp1", fFloat32Type, { fFloat32Type,
+ fFloat32Type,
+ fFloat32Type });
+ this->addBuiltinFunction("clamp", "sksl_clamp2", fFloat32Vector2Type, { fFloat32Vector2Type,
+ fFloat32Type,
+ fFloat32Type });
+ this->addBuiltinFunction("clamp", "sksl_clamp3", fFloat32Vector3Type, { fFloat32Vector3Type,
+ fFloat32Type,
+ fFloat32Type });
+ this->addBuiltinFunction("clamp", "sksl_clamp4", fFloat32Vector4Type, { fFloat32Vector4Type,
+ fFloat32Type,
+ fFloat32Type });
this->addBuiltinFunction("print", "sksl_debug_print", fVoidType, { fFloat32Type });
}
@@ -138,6 +200,14 @@ uint64_t JIT::resolveSymbol(const char* name, JIT* jit) {
result = (uint64_t) &sksl_pipeline_append;
} else if (!strcmp(name, "_sksl_pipeline_append_callback")) {
result = (uint64_t) &sksl_pipeline_append_callback;
+ } else if (!strcmp(name, "_sksl_clamp1")) {
+ result = (uint64_t) &sksl_clamp1;
+ } else if (!strcmp(name, "_sksl_clamp2")) {
+ result = (uint64_t) &sksl_clamp2;
+ } else if (!strcmp(name, "_sksl_clamp3")) {
+ result = (uint64_t) &sksl_clamp3;
+ } else if (!strcmp(name, "_sksl_clamp4")) {
+ result = (uint64_t) &sksl_clamp4;
} else if (!strcmp(name, "_sksl_debug_print")) {
result = (uint64_t) &sksl_debug_print;
} else {
@@ -406,7 +476,7 @@ JIT::TypeKind JIT::typeKind(const Type& type) {
return JIT::kInt_TypeKind;
} else if (type.fName == "uint" || type.fName == "ushort" || type.fName == "ubyte") {
return JIT::kUInt_TypeKind;
- } else if (type.fName == "float" || type.fName == "double") {
+ } else if (type.fName == "float" || type.fName == "double" || type.fName == "half") {
return JIT::kFloat_TypeKind;
}
ABORT("unsupported type: %s\n", type.description().c_str());
@@ -441,7 +511,7 @@ LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression&
LLVMValueRef left = this->compileExpression(builder, *b.fLeft); \
LLVMValueRef right = this->compileExpression(builder, *b.fRight); \
this->vectorize(builder, b, &left, &right); \
- switch (this->typeKind(b.fLeft->fType)) { \
+ switch (this->typeKind(b.fLeft->fType)) { \
case kInt_TypeKind: \
return SFunc(builder, left, right, "binary"); \
case kUInt_TypeKind: \
@@ -449,7 +519,7 @@ LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression&
case kFloat_TypeKind: \
return FFunc(builder, left, right, "binary"); \
default: \
- ABORT("unsupported typeKind"); \
+ ABORT("unsupported typeKind"); \
} \
}
#define COMPOUND(SFunc, UFunc, FFunc) { \
@@ -458,7 +528,7 @@ LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression&
LLVMValueRef right = this->compileExpression(builder, *b.fRight); \
this->vectorize(builder, b, &left, &right); \
LLVMValueRef result; \
- switch (this->typeKind(b.fLeft->fType)) { \
+ switch (this->typeKind(b.fLeft->fType)) { \
case kInt_TypeKind: \
result = SFunc(builder, left, right, "binary"); \
break; \
@@ -469,7 +539,7 @@ LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression&
result = FFunc(builder, left, right, "binary"); \
break; \
default: \
- ABORT("unsupported typeKind"); \
+ ABORT("unsupported typeKind"); \
} \
lvalue->store(builder, result); \
return result; \
@@ -510,6 +580,10 @@ LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression&
BINARY(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd);
case Token::BITWISEOR:
BINARY(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr);
+ case Token::SHL:
+ BINARY(LLVMBuildShl, LLVMBuildShl, LLVMBuildShl);
+ case Token::SHR:
+ BINARY(LLVMBuildAShr, LLVMBuildLShr, LLVMBuildAShr);
case Token::PLUSEQ:
COMPOUND(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd);
case Token::MINUSEQ:
@@ -523,13 +597,83 @@ LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression&
case Token::BITWISEOREQ:
COMPOUND(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr);
case Token::EQEQ:
- COMPARE(LLVMBuildICmp, LLVMIntEQ,
- LLVMBuildICmp, LLVMIntEQ,
- LLVMBuildFCmp, LLVMRealOEQ);
+ switch (b.fLeft->fType.kind()) {
+ case Type::kScalar_Kind:
+ COMPARE(LLVMBuildICmp, LLVMIntEQ,
+ LLVMBuildICmp, LLVMIntEQ,
+ LLVMBuildFCmp, LLVMRealOEQ);
+ case Type::kVector_Kind: {
+ LLVMValueRef left = this->compileExpression(builder, *b.fLeft);
+ LLVMValueRef right = this->compileExpression(builder, *b.fRight);
+ this->vectorize(builder, b, &left, &right);
+ LLVMValueRef value;
+ switch (this->typeKind(b.fLeft->fType)) {
+ case kInt_TypeKind:
+ value = LLVMBuildICmp(builder, LLVMIntEQ, left, right, "binary");
+ break;
+ case kUInt_TypeKind:
+ value = LLVMBuildICmp(builder, LLVMIntEQ, left, right, "binary");
+ break;
+ case kFloat_TypeKind:
+ value = LLVMBuildFCmp(builder, LLVMRealOEQ, left, right, "binary");
+ break;
+ default:
+ ABORT("unsupported typeKind");
+ }
+ LLVMValueRef args[1] = { value };
+ LLVMValueRef func;
+ switch (b.fLeft->fType.columns()) {
+ case 2: func = fFoldAnd2Func; break;
+ case 3: func = fFoldAnd3Func; break;
+ case 4: func = fFoldAnd4Func; break;
+ default:
+ SkASSERT(false);
+ func = fFoldAnd2Func;
+ }
+ return LLVMBuildCall(builder, func, args, 1, "all");
+ }
+ default:
+ SkASSERT(false);
+ }
case Token::NEQ:
- COMPARE(LLVMBuildICmp, LLVMIntNE,
- LLVMBuildICmp, LLVMIntNE,
- LLVMBuildFCmp, LLVMRealONE);
+ switch (b.fLeft->fType.kind()) {
+ case Type::kScalar_Kind:
+ COMPARE(LLVMBuildICmp, LLVMIntNE,
+ LLVMBuildICmp, LLVMIntNE,
+ LLVMBuildFCmp, LLVMRealONE);
+ case Type::kVector_Kind: {
+ LLVMValueRef left = this->compileExpression(builder, *b.fLeft);
+ LLVMValueRef right = this->compileExpression(builder, *b.fRight);
+ this->vectorize(builder, b, &left, &right);
+ LLVMValueRef value;
+ switch (this->typeKind(b.fLeft->fType)) {
+ case kInt_TypeKind:
+ value = LLVMBuildICmp(builder, LLVMIntNE, left, right, "binary");
+ break;
+ case kUInt_TypeKind:
+ value = LLVMBuildICmp(builder, LLVMIntNE, left, right, "binary");
+ break;
+ case kFloat_TypeKind:
+ value = LLVMBuildFCmp(builder, LLVMRealONE, left, right, "binary");
+ break;
+ default:
+ ABORT("unsupported typeKind");
+ }
+ LLVMValueRef args[1] = { value };
+ LLVMValueRef func;
+ switch (b.fLeft->fType.columns()) {
+ case 2: func = fFoldOr2Func; break;
+ case 3: func = fFoldOr3Func; break;
+ case 4: func = fFoldOr4Func; break;
+ default:
+ SkASSERT(false);
+ func = fFoldOr2Func;
+ }
+ return LLVMBuildCall(builder, func, args, 1, "all");
+ }
+ default:
+ SkASSERT(false);
+ }
case Token::LT:
COMPARE(LLVMBuildICmp, LLVMIntSLT,
LLVMBuildICmp, LLVMIntULT,
@@ -583,6 +727,7 @@ LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression&
return phi;
}
default:
+ printf("%s\n", b.description().c_str());
ABORT("unsupported binary operator");
}
}
@@ -702,9 +847,9 @@ void JIT::appendStage(LLVMBuilderRef builder, const AppendStage& a) {
const FunctionDeclaration& functionDecl =
*((FunctionReference&) *a.fArguments[1]).fFunctions[0];
bool found = false;
- for (const auto& pe : fProgram->fElements) {
- if (ProgramElement::kFunction_Kind == pe->fKind) {
- const FunctionDefinition& def = (const FunctionDefinition&) *pe;
+ for (const auto& pe : *fProgram) {
+ if (ProgramElement::kFunction_Kind == pe.fKind) {
+ const FunctionDefinition& def = (const FunctionDefinition&) pe;
if (&def.fDeclaration == &functionDecl) {
LLVMValueRef fn = this->compileStageFunction(def);
LLVMValueRef args[2] = {
@@ -747,49 +892,74 @@ LLVMValueRef JIT::compileConstructor(LLVMBuilderRef builder, const Constructor&
TypeKind from = this->typeKind(c.fArguments[0]->fType);
TypeKind to = this->typeKind(c.fType);
LLVMValueRef base = this->compileExpression(builder, *c.fArguments[0]);
- if (kFloat_TypeKind == to) {
- if (kInt_TypeKind == from) {
- return LLVMBuildSIToFP(builder, base, this->getType(c.fType), "cast");
- }
- if (kUInt_TypeKind == from) {
- return LLVMBuildUIToFP(builder, base, this->getType(c.fType), "cast");
- }
- }
- if (kInt_TypeKind == to) {
- if (kFloat_TypeKind == from) {
- return LLVMBuildFPToSI(builder, base, this->getType(c.fType), "cast");
- }
- if (kUInt_TypeKind == from) {
- return base;
- }
- }
- if (kUInt_TypeKind == to) {
- if (kFloat_TypeKind == from) {
- return LLVMBuildFPToUI(builder, base, this->getType(c.fType), "cast");
- }
- if (kInt_TypeKind == from) {
- return base;
- }
+ switch (to) {
+ case kFloat_TypeKind:
+ switch (from) {
+ case kInt_TypeKind:
+ return LLVMBuildSIToFP(builder, base, this->getType(c.fType), "cast");
+ case kUInt_TypeKind:
+ return LLVMBuildUIToFP(builder, base, this->getType(c.fType), "cast");
+ case kFloat_TypeKind:
+ return base;
+ case kBool_TypeKind:
+ SkASSERT(false);
+ }
+ case kInt_TypeKind:
+ switch (from) {
+ case kInt_TypeKind:
+ return base;
+ case kUInt_TypeKind:
+ return base;
+ case kFloat_TypeKind:
+ return LLVMBuildFPToSI(builder, base, this->getType(c.fType), "cast");
+ case kBool_TypeKind:
+ SkASSERT(false);
+ }
+ case kUInt_TypeKind:
+ switch (from) {
+ case kInt_TypeKind:
+ return base;
+ case kUInt_TypeKind:
+ return base;
+ case kFloat_TypeKind:
+ return LLVMBuildFPToUI(builder, base, this->getType(c.fType), "cast");
+ case kBool_TypeKind:
+ SkASSERT(false);
+ }
+ case kBool_TypeKind:
+ SkASSERT(false);
}
- ABORT("unsupported constructor");
}
case Type::kVector_Kind: {
LLVMValueRef vec = LLVMGetUndef(this->getType(c.fType));
- if (c.fArguments.size() == 1) {
+ if (c.fArguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
LLVMValueRef value = this->compileExpression(builder, *c.fArguments[0]);
for (int i = 0; i < c.fType.columns(); ++i) {
vec = LLVMBuildInsertElement(builder, vec, value,
LLVMConstInt(fInt32Type, i, false),
- "vec build");
+ "vec build 1");
}
} else {
- SkASSERT(c.fArguments.size() == (size_t) c.fType.columns());
- for (int i = 0; i < c.fType.columns(); ++i) {
- vec = LLVMBuildInsertElement(builder, vec,
- this->compileExpression(builder,
- *c.fArguments[i]),
- LLVMConstInt(fInt32Type, i, false),
- "vec build");
+ int index = 0;
+ for (const auto& arg : c.fArguments) {
+ LLVMValueRef value = this->compileExpression(builder, *arg);
+ if (arg->fType.kind() == Type::kVector_Kind) {
+ for (int i = 0; i < arg->fType.columns(); ++i) {
+ LLVMValueRef column = LLVMBuildExtractElement(builder,
+ vec,
+ LLVMConstInt(fInt32Type,
+ i,
+ false),
+ "construct extract");
+ vec = LLVMBuildInsertElement(builder, vec, column,
+ LLVMConstInt(fInt32Type, index++, false),
+ "vec build 2");
+ }
+ } else {
+ vec = LLVMBuildInsertElement(builder, vec, value,
+ LLVMConstInt(fInt32Type, index++, false),
+ "vec build 3");
+ }
}
}
return vec;
@@ -1460,7 +1630,6 @@ bool JIT::compileVectorExpression(LLVMBuilderRef builder, const Expression& expr
return this->compileVectorVariableReference(builder, (const VariableReference&) expr,
out);
default:
- printf("failed expression: %s\n", expr.description().c_str());
return false;
}
}
@@ -1480,7 +1649,6 @@ bool JIT::compileVectorStatement(LLVMBuilderRef builder, const Statement& stmt)
*((const ExpressionStatement&) stmt).fExpression,
&result);
default:
- printf("failed statement: %s\n", stmt.description().c_str());
return false;
}
}
@@ -1582,7 +1750,7 @@ bool JIT::hasStageSignature(const FunctionDeclaration& f) {
f.fParameters[0]->fModifiers.fFlags == 0 &&
f.fParameters[1]->fType == *fProgram->fContext->fInt_Type &&
f.fParameters[1]->fModifiers.fFlags == 0 &&
- f.fParameters[2]->fType == *fProgram->fContext->fFloat4_Type &&
+ f.fParameters[2]->fType == *fProgram->fContext->fHalf4_Type &&
f.fParameters[2]->fModifiers.fFlags == (Modifiers::kIn_Flag | Modifiers::kOut_Flag);
}
@@ -1639,6 +1807,21 @@ void JIT::createModule() {
fPromotedParameters.clear();
fModule = LLVMModuleCreateWithNameInContext("skslmodule", fContext);
this->loadBuiltinFunctions();
+ LLVMTypeRef fold2Params[1] = { fInt1Vector2Type };
+ fFoldAnd2Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.and.i1.v2i1",
+ LLVMFunctionType(fInt1Type, fold2Params, 1, false));
+ fFoldOr2Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.or.i1.v2i1",
+ LLVMFunctionType(fInt1Type, fold2Params, 1, false));
+ LLVMTypeRef fold3Params[1] = { fInt1Vector3Type };
+ fFoldAnd3Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.and.i1.v3i1",
+ LLVMFunctionType(fInt1Type, fold3Params, 1, false));
+ fFoldOr3Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.or.i1.v3i1",
+ LLVMFunctionType(fInt1Type, fold3Params, 1, false));
+ LLVMTypeRef fold4Params[1] = { fInt1Vector4Type };
+ fFoldAnd4Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.and.i1.v4i1",
+ LLVMFunctionType(fInt1Type, fold4Params, 1, false));
+ fFoldOr4Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.or.i1.v4i1",
+ LLVMFunctionType(fInt1Type, fold4Params, 1, false));
// LLVM doesn't do void*, have to declare it as int8*
LLVMTypeRef appendParams[3] = { fInt8PtrType, fInt32Type, fInt8PtrType };
fAppendFunc = LLVMAddFunction(fModule, "sksl_pipeline_append", LLVMFunctionType(fVoidType,
@@ -1656,13 +1839,15 @@ void JIT::createModule() {
1,
false));
- for (const auto& e : fProgram->fElements) {
- SkASSERT(e->fKind == ProgramElement::kFunction_Kind);
- this->compileFunction((FunctionDefinition&) *e);
+ for (const auto& e : *fProgram) {
+ if (e.fKind == ProgramElement::kFunction_Kind) {
+ this->compileFunction((FunctionDefinition&) e);
+ }
}
}
std::unique_ptr<JIT::Module> JIT::compile(std::unique_ptr<Program> program) {
+ fCompiler.optimize(*program);
fProgram = std::move(program);
this->createModule();
this->optimize();