aboutsummaryrefslogtreecommitdiffhomepage
path: root/src/sksl/SkSLJIT.h
blob: b23e31237f20168d454e418c703c85546601bbd9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
/*
 * Copyright 2018 Google Inc.
 *
 * Use of this source code is governed by a BSD-style license that can be
 * found in the LICENSE file.
 */

#ifndef SKSL_JIT
#define SKSL_JIT

#ifdef SK_LLVM_AVAILABLE

#include "ir/SkSLAppendStage.h"
#include "ir/SkSLBinaryExpression.h"
#include "ir/SkSLBreakStatement.h"
#include "ir/SkSLContinueStatement.h"
#include "ir/SkSLExpression.h"
#include "ir/SkSLDoStatement.h"
#include "ir/SkSLForStatement.h"
#include "ir/SkSLFunctionCall.h"
#include "ir/SkSLFunctionDefinition.h"
#include "ir/SkSLIfStatement.h"
#include "ir/SkSLIndexExpression.h"
#include "ir/SkSLPrefixExpression.h"
#include "ir/SkSLPostfixExpression.h"
#include "ir/SkSLProgram.h"
#include "ir/SkSLReturnStatement.h"
#include "ir/SkSLStatement.h"
#include "ir/SkSLSwizzle.h"
#include "ir/SkSLTernaryExpression.h"
#include "ir/SkSLVarDeclarationsStatement.h"
#include "ir/SkSLVariableReference.h"
#include "ir/SkSLWhileStatement.h"

#include "llvm-c/Analysis.h"
#include "llvm-c/Core.h"
#include "llvm-c/OrcBindings.h"
#include "llvm-c/Support.h"
#include "llvm-c/Target.h"
#include "llvm-c/Transforms/PassManagerBuilder.h"
#include "llvm-c/Types.h"
#include <stack>

class SkRasterPipeline;

namespace SkSL {

/**
 * A just-in-time compiler for SkSL code which uses an LLVM backend. Only available when the
 * skia_llvm_path gn arg is set.
 *
 * Example of using SkSLJIT to set up an SkJumper pipeline stage:
 *
 * #ifdef SK_LLVM_AVAILABLE
 *   SkSL::Compiler compiler;
 *   SkSL::Program::Settings settings;
 *   std::unique_ptr<SkSL::Program> program = compiler.convertProgram(SkSL::Program::kCPU_Kind,
 *       "void swap(int x, int y, inout float4 color) {"
 *       "    color.rb = color.br;"
 *       "}",
 *       settings);
 *   if (!program) {
 *       printf("%s\n", compiler.errorText().c_str());
 *       abort();
 *   }
 *   SkSL::JIT& jit = *scratch->make<SkSL::JIT>(&compiler);
 *   std::unique_ptr<SkSL::JIT::Module> module = jit.compile(std::move(program));
 *   void* func = module->getJumperStage("swap");
 *   p->append(func, nullptr);
 * #endif
 */
class JIT {
    // NOTE(review): not referenced anywhere in this header; presumably used by the implementation
    // file - confirm before removing.
    using StackIndex = int;

public:
    /**
     * The result of JIT-compiling a Program. Modules cannot be constructed directly (the
     * constructor is private and JIT is a friend); they are handed out by JIT::compile().
     */
    class Module {
    public:
        // Not copyable: the destructor disposes fSharedModule, so a memberwise copy would dispose
        // the same reference twice. Copying was already implicitly unavailable because of the
        // std::unique_ptr member; deleting the operations makes that intent explicit.
        Module(const Module&) = delete;
        Module& operator=(const Module&) = delete;

        /**
         * Returns the address of a symbol in the module.
         */
        void* getSymbol(const char* name);

        /**
         * Returns the address of a function as an SkJumper pipeline stage. The function must have
         * the signature void <name>(int x, int y, inout float4 color). The returned function will
         * have the correct signature to function as an SkJumper stage (meaning it will actually
         * have a different signature at runtime, accepting vector parameters and operating on
         * multiple pixels simultaneously as is normal for SkJumper stages).
         */
        void* getJumperStage(const char* name);

        /**
         * Disposes this Module's shared module reference.
         */
        ~Module() {
            LLVMOrcDisposeSharedModuleRef(fSharedModule);
        }

    private:
        /**
         * Takes ownership of the program and stores the given LLVM handles. Only reachable from
         * JIT, which is a friend of this class.
         */
        Module(std::unique_ptr<Program> program,
               LLVMSharedModuleRef sharedModule,
               LLVMOrcJITStackRef jitStack)
        : fProgram(std::move(program))
        , fSharedModule(sharedModule)
        , fJITStack(jitStack) {}

        // the SkSL program this module was compiled from; owned by the Module
        std::unique_ptr<Program> fProgram;
        // disposed in ~Module()
        LLVMSharedModuleRef fSharedModule;
        // NOTE(review): never disposed by Module; presumably owned by the JIT which created this
        // Module, which would be why the JIT has to outlive its Modules - confirm.
        LLVMOrcJITStackRef fJITStack;

        friend class JIT;
    };

    /**
     * Creates a JIT for the given compiler. Explicit so that a Compiler* is never silently
     * converted into a JIT.
     *
     * NOTE(review): the compiler is held by reference (fCompiler), so it presumably must be
     * non-null and outlive the JIT - confirm against the constructor's implementation.
     */
    explicit JIT(Compiler* compiler);

    // Not copyable. Copying was already implicitly unavailable (reference and std::unique_ptr
    // members); the JIT holds raw LLVM handles, so the operations are deleted explicitly.
    JIT(const JIT&) = delete;
    JIT& operator=(const JIT&) = delete;

    ~JIT();

    /**
     * Just-in-time compiles an SkSL program and returns the resulting Module. The JIT must not be
     * destroyed before all of its Modules are destroyed.
     */
    std::unique_ptr<Module> compile(std::unique_ptr<Program> program);

private:
    // Number of color channels (red, green, blue, alpha) handled by the "Vector" functions below.
    static constexpr int CHANNELS = 4;

    // Classification of an SkSL type, as returned by typeKind().
    enum TypeKind {
        kFloat_TypeKind,
        kInt_TypeKind,
        kUInt_TypeKind,
        kBool_TypeKind
    };

    /**
     * Interface for a location which can be both read (load) and written (store) while emitting
     * code through an LLVM builder. Instances are produced by getLValue().
     */
    class LValue {
    public:
        virtual ~LValue() = default;

        virtual LLVMValueRef load(LLVMBuilderRef builder) = 0;

        virtual void store(LLVMBuilderRef builder, LLVMValueRef value) = 0;
    };

    // NOTE(review): presumably registers the external function 'realName' so that SkSL code can
    // call it as 'ourName' - confirm in the implementation.
    void addBuiltinFunction(const char* ourName, const char* realName, LLVMTypeRef returnType,
                            std::vector<LLVMTypeRef> parameters);

    void loadBuiltinFunctions();

    void setBlock(LLVMBuilderRef builder, LLVMBasicBlockRef block);

    // Maps an SkSL type to the LLVM type used to represent it.
    LLVMTypeRef getType(const Type& type);

    TypeKind typeKind(const Type& type);

    std::unique_ptr<LValue> getLValue(LLVMBuilderRef builder, const Expression& expr);

    void vectorize(LLVMBuilderRef builder, LLVMValueRef* value, int columns);

    void vectorize(LLVMBuilderRef builder, const BinaryExpression& b, LLVMValueRef* left,
                   LLVMValueRef* right);

    // Scalar (one pixel at a time) expression compilation, one function per SkSL expression node
    // type. The implementations are not part of this header.
    LLVMValueRef compileBinary(LLVMBuilderRef builder, const BinaryExpression& b);

    LLVMValueRef compileConstructor(LLVMBuilderRef builder, const Constructor& c);

    LLVMValueRef compileFunctionCall(LLVMBuilderRef builder, const FunctionCall& fc);

    LLVMValueRef compileIndex(LLVMBuilderRef builder, const IndexExpression& v);

    LLVMValueRef compilePostfix(LLVMBuilderRef builder, const PostfixExpression& p);

    LLVMValueRef compilePrefix(LLVMBuilderRef builder, const PrefixExpression& p);

    LLVMValueRef compileSwizzle(LLVMBuilderRef builder, const Swizzle& s);

    LLVMValueRef compileVariableReference(LLVMBuilderRef builder, const VariableReference& v);

    LLVMValueRef compileTernary(LLVMBuilderRef builder, const TernaryExpression& t);

    LLVMValueRef compileExpression(LLVMBuilderRef builder, const Expression& expr);

    // Scalar statement compilation, one function per SkSL statement node type. The implementations
    // are not part of this header.
    void appendStage(LLVMBuilderRef builder, const AppendStage& a);

    void compileBlock(LLVMBuilderRef builder, const Block& block);

    void compileBreak(LLVMBuilderRef builder, const BreakStatement& b);

    void compileContinue(LLVMBuilderRef builder, const ContinueStatement& c);

    void compileDo(LLVMBuilderRef builder, const DoStatement& d);

    void compileFor(LLVMBuilderRef builder, const ForStatement& f);

    void compileIf(LLVMBuilderRef builder, const IfStatement& i);

    void compileReturn(LLVMBuilderRef builder, const ReturnStatement& r);

    void compileVarDeclarations(LLVMBuilderRef builder, const VarDeclarationsStatement& decls);

    void compileWhile(LLVMBuilderRef builder, const WhileStatement& w);

    void compileStatement(LLVMBuilderRef builder, const Statement& stmt);

    // The "Vector" variants of functions attempt to compile a given expression or statement as part
    // of a vectorized SkJumper stage function - that is, with r, g, b, and a each being vectors of
    // fVectorCount floats. So a statement like "color.r = 0;" looks like it modifies a single
    // channel of a single pixel, but the compiled code will actually modify the red channel of
    // fVectorCount pixels at once.
    //
    // As not everything can be vectorized, these calls return a bool to indicate whether they were
    // successful. If anything anywhere in the function cannot be vectorized, the JIT will fall back
    // to looping over the pixels instead.
    //
    // Since we process multiple pixels at once, and each pixel consists of multiple color channels,
    // expressions may effectively result in a vector-of-vectors. We produce zero to four outputs
    // when compiling an expression, each of which is a vector, so that e.g. float2(1, 0) actually
    // produces two vectors, one containing all 1s, the other all 0s. The out parameter always
    // allows for 4 channels, but the functions produce 0 to 4 channels depending on the type they
    // are operating on. Thus evaluating "color.rgb" actually fills in out[0] through out[2],
    // leaving out[3] uninitialized.
    // As the number of outputs can be inferred from the type of the expression, it is not
    // explicitly signalled anywhere.
    bool compileVectorBinary(LLVMBuilderRef builder, const BinaryExpression& b,
                             LLVMValueRef out[CHANNELS]);

    bool compileVectorConstructor(LLVMBuilderRef builder, const Constructor& c,
                                  LLVMValueRef out[CHANNELS]);

    bool compileVectorFloatLiteral(LLVMBuilderRef builder, const FloatLiteral& f,
                                   LLVMValueRef out[CHANNELS]);

    bool compileVectorSwizzle(LLVMBuilderRef builder, const Swizzle& s,
                              LLVMValueRef out[CHANNELS]);

    bool compileVectorVariableReference(LLVMBuilderRef builder, const VariableReference& v,
                                        LLVMValueRef out[CHANNELS]);

    bool compileVectorExpression(LLVMBuilderRef builder, const Expression& expr,
                                 LLVMValueRef out[CHANNELS]);

    bool getVectorLValue(LLVMBuilderRef builder, const Expression& e, LLVMValueRef out[CHANNELS]);

    /**
     * Evaluates the left and right operands of a binary operation, promoting one of them to a
     * vector if necessary to make the types match.
     */
    bool getVectorBinaryOperands(LLVMBuilderRef builder, const Expression& left,
                                 LLVMValueRef outLeft[CHANNELS], const Expression& right,
                                 LLVMValueRef outRight[CHANNELS]);

    bool compileVectorStatement(LLVMBuilderRef builder, const Statement& stmt);

    /**
     * Returns true if this function has the signature void(int, int, inout float4) and thus can be
     * used as an SkJumper stage.
     */
    bool hasStageSignature(const FunctionDeclaration& f);

    /**
     * Attempts to compile a vectorized stage function, returning true on success. A stage function
     * of e.g. "color.r = 0;" will produce code which sets the entire red vector to zeros in a
     * single instruction, thus calculating several pixels at once.
     */
    bool compileStageFunctionVector(const FunctionDefinition& f, LLVMValueRef newFunc);

    /**
     * Fallback function which loops over the pixels, for when vectorization fails. A stage function
     * of e.g. "color.r = 0;" will produce a loop which iterates over the entries in the red vector,
     * setting each one to zero individually.
     */
    void compileStageFunctionLoop(const FunctionDefinition& f, LLVMValueRef newFunc);

    /**
     * Called when compiling a function which has the signature of an SkJumper stage. Produces a
     * version of the function which can be plugged into SkJumper (thus having a signature which
     * accepts four vectors, one for each color channel, containing the color data of multiple
     * pixels at once). To go from SkSL code which operates on a single pixel at a time to CPU code
     * which operates on multiple pixels at once, the code is either vectorized using
     * compileStageFunctionVector or wrapped in a loop using compileStageFunctionLoop.
     */
    LLVMValueRef compileStageFunction(const FunctionDefinition& f);

    /**
     * Compiles an SkSL function to an LLVM function. If the function has the signature of an
     * SkJumper stage, it will *also* be compiled by compileStageFunction, resulting in both a stage
     * and non-stage version of the function.
     */
    LLVMValueRef compileFunction(const FunctionDefinition& f);

    void createModule();

    void optimize();

    // NOTE(review): presumably reports whether expr refers to the stage's color parameter
    // (fColorParam) - confirm in the implementation.
    bool isColorRef(const Expression& expr);

    static uint64_t resolveSymbol(const char* name, JIT* jit);

    // NOTE: the declaration order of the data members below is also their initialization order;
    // do not reorder them without checking the constructor.
    const char* fCPU;
    // number of pixels processed at once by a vectorized stage function (see the notes on the
    // "Vector" functions above)
    int fVectorCount;
    Compiler& fCompiler;
    std::unique_ptr<Program> fProgram;
    LLVMContextRef fContext;
    LLVMModuleRef fModule;
    LLVMSharedModuleRef fSharedModule;
    LLVMOrcJITStackRef fJITStack;
    LLVMValueRef fCurrentFunction;
    LLVMBasicBlockRef fAllocaBlock;
    LLVMBasicBlockRef fCurrentBlock;
    // cached LLVM type handles
    LLVMTypeRef fVoidType;
    LLVMTypeRef fInt1Type;
    LLVMTypeRef fInt8Type;
    LLVMTypeRef fInt8PtrType;
    LLVMTypeRef fInt32Type;
    LLVMTypeRef fInt32VectorType;
    LLVMTypeRef fInt32Vector2Type;
    LLVMTypeRef fInt32Vector3Type;
    LLVMTypeRef fInt32Vector4Type;
    LLVMTypeRef fInt64Type;
    LLVMTypeRef fSizeTType;
    LLVMTypeRef fFloat32Type;
    LLVMTypeRef fFloat32VectorType;
    LLVMTypeRef fFloat32Vector2Type;
    LLVMTypeRef fFloat32Vector3Type;
    LLVMTypeRef fFloat32Vector4Type;
    // Our SkSL stage functions have a single float4 for color, but the actual SkJumper stage
    // function has four separate vectors, one for each channel. These four values are references to
    // the red, green, blue, and alpha vectors respectively.
    LLVMValueRef fChannels[CHANNELS];
    // When processing a stage function, this points to the SkSL color parameter (an inout float4).
    const Variable* fColorParam;
    std::map<const FunctionDeclaration*, LLVMValueRef> fFunctions;
    std::map<const Variable*, LLVMValueRef> fVariables;
    // LLVM function parameters are read-only, so when modifying function parameters we need to
    // first promote them to variables. This keeps track of which parameters have been promoted.
    std::set<const Variable*> fPromotedParameters;
    // NOTE(review): presumably the blocks that 'break' / 'continue' jump to, with the innermost
    // enclosing loop last - confirm in the implementation.
    std::vector<LLVMBasicBlockRef> fBreakTarget;
    std::vector<LLVMBasicBlockRef> fContinueTarget;

    LLVMValueRef fAppendFunc;
    LLVMValueRef fAppendCallbackFunc;
    LLVMValueRef fDebugFunc;
};

} // namespace

#endif // SK_LLVM_AVAILABLE

#endif // SKSL_JIT