diff options
author | ethannicholas <ethannicholas@google.com> | 2016-07-25 10:08:54 -0700 |
---|---|---|
committer | Commit bot <commit-bot@chromium.org> | 2016-07-25 10:08:54 -0700 |
commit | d598f7981f34811e6f2a949207dc13638852f3f7 (patch) | |
tree | 83dd4cf4983f90125651a0ab380f4f71cb3e27f2 /src | |
parent | d9ddad2952cdfd0809249abbd94a285abdb6b1d0 (diff) |
SkSL performance improvements (plus a couple of minor warning fixes)
GOLD_TRYBOT_URL= https://gold.skia.org/search?issue=2131223002
Committed: https://skia.googlesource.com/skia/+/9fd67a1f53809f5eff1210dd107241b450c48acc
Review-Url: https://codereview.chromium.org/2131223002
Diffstat (limited to 'src')
36 files changed, 1032 insertions, 956 deletions
diff --git a/src/gpu/vk/GrVkPipelineStateBuilder.cpp b/src/gpu/vk/GrVkPipelineStateBuilder.cpp index 323ea66946..d9d1b6cfb8 100644 --- a/src/gpu/vk/GrVkPipelineStateBuilder.cpp +++ b/src/gpu/vk/GrVkPipelineStateBuilder.cpp @@ -93,8 +93,6 @@ shaderc_shader_kind vk_shader_stage_to_shaderc_kind(VkShaderStageFlagBits stage) } #endif -#include <fstream> -#include <sstream> bool GrVkPipelineStateBuilder::CreateVkShaderModule(const GrVkGpu* gpu, VkShaderStageFlagBits stage, const GrGLSLShaderBuilder& builder, diff --git a/src/sksl/SkSLCompiler.cpp b/src/sksl/SkSLCompiler.cpp index 2b4adc1026..0d65b107ec 100644 --- a/src/sksl/SkSLCompiler.cpp +++ b/src/sksl/SkSLCompiler.cpp @@ -41,9 +41,10 @@ Compiler::Compiler() : fErrorCount(0) { auto types = std::shared_ptr<SymbolTable>(new SymbolTable(*this)); auto symbols = std::shared_ptr<SymbolTable>(new SymbolTable(types, *this)); - fIRGenerator = new IRGenerator(symbols, *this); + fIRGenerator = new IRGenerator(&fContext, symbols, *this); fTypes = types; - #define ADD_TYPE(t) types->add(k ## t ## _Type->fName, k ## t ## _Type) + #define ADD_TYPE(t) types->addWithoutOwnership(fContext.f ## t ## _Type->fName, \ + fContext.f ## t ## _Type.get()) ADD_TYPE(Void); ADD_TYPE(Float); ADD_TYPE(Vec2); @@ -185,19 +186,21 @@ std::unique_ptr<Program> Compiler::convertProgram(Program::Kind kind, std::strin fErrorText = ""; fErrorCount = 0; fIRGenerator->pushSymbolTable(); - std::vector<std::unique_ptr<ProgramElement>> result; + std::vector<std::unique_ptr<ProgramElement>> elements; switch (kind) { case Program::kVertex_Kind: - this->internalConvertProgram(SKSL_VERT_INCLUDE, &result); + this->internalConvertProgram(SKSL_VERT_INCLUDE, &elements); break; case Program::kFragment_Kind: - this->internalConvertProgram(SKSL_FRAG_INCLUDE, &result); + this->internalConvertProgram(SKSL_FRAG_INCLUDE, &elements); break; } - this->internalConvertProgram(text, &result); + this->internalConvertProgram(text, &elements); + auto result = std::unique_ptr<Program>(new Program(kind, std::move(elements), + fIRGenerator->fSymbolTable));; fIRGenerator->popSymbolTable(); this->writeErrorCount(); - return std::unique_ptr<Program>(new Program(kind, std::move(result)));; + return result; } void Compiler::error(Position position, std::string msg) { @@ -224,7 +227,7 @@ void Compiler::writeErrorCount() { bool Compiler::toSPIRV(Program::Kind kind, std::string text, std::ostream& out) { auto program = this->convertProgram(kind, text); if (fErrorCount == 0) { - SkSL::SPIRVCodeGenerator cg; + SkSL::SPIRVCodeGenerator cg(&fContext); cg.generateCode(*program.get(), out); ASSERT(!out.rdstate()); } diff --git a/src/sksl/SkSLCompiler.h b/src/sksl/SkSLCompiler.h index 2209427eef..e63d5f4ed8 100644 --- a/src/sksl/SkSLCompiler.h +++ b/src/sksl/SkSLCompiler.h @@ -11,6 +11,7 @@ #include <vector> #include "ir/SkSLProgram.h" #include "ir/SkSLSymbolTable.h" +#include "SkSLContext.h" #include "SkSLErrorReporter.h" namespace SkSL { @@ -50,6 +51,7 @@ private: IRGenerator* fIRGenerator; std::string fSkiaVertText; // FIXME store parsed version instead + Context fContext; int fErrorCount; std::string fErrorText; }; diff --git a/src/sksl/SkSLContext.h b/src/sksl/SkSLContext.h new file mode 100644 index 0000000000..1f124d05eb --- /dev/null +++ b/src/sksl/SkSLContext.h @@ -0,0 +1,227 @@ +/* + * Copyright 2016 Google Inc. + * + * Use of this source code is governed by a BSD-style license that can be + * found in the LICENSE file. + */ + +#ifndef SKSL_CONTEXT +#define SKSL_CONTEXT + +#include "ir/SkSLType.h" + +namespace SkSL { + +/** + * Contains compiler-wide objects, which currently means the core types. + */ +class Context { +public: + Context() + : fVoid_Type(new Type("void")) + , fDouble_Type(new Type("double", true)) + , fDVec2_Type(new Type("dvec2", *fDouble_Type, 2)) + , fDVec3_Type(new Type("dvec3", *fDouble_Type, 3)) + , fDVec4_Type(new Type("dvec4", *fDouble_Type, 4)) + , fFloat_Type(new Type("float", true, { fDouble_Type.get() })) + , fVec2_Type(new Type("vec2", *fFloat_Type, 2)) + , fVec3_Type(new Type("vec3", *fFloat_Type, 3)) + , fVec4_Type(new Type("vec4", *fFloat_Type, 4)) + , fUInt_Type(new Type("uint", true, { fFloat_Type.get(), fDouble_Type.get() })) + , fUVec2_Type(new Type("uvec2", *fUInt_Type, 2)) + , fUVec3_Type(new Type("uvec3", *fUInt_Type, 3)) + , fUVec4_Type(new Type("uvec4", *fUInt_Type, 4)) + , fInt_Type(new Type("int", true, { fUInt_Type.get(), fFloat_Type.get(), fDouble_Type.get() })) + , fIVec2_Type(new Type("ivec2", *fInt_Type, 2)) + , fIVec3_Type(new Type("ivec3", *fInt_Type, 3)) + , fIVec4_Type(new Type("ivec4", *fInt_Type, 4)) + , fBool_Type(new Type("bool", false)) + , fBVec2_Type(new Type("bvec2", *fBool_Type, 2)) + , fBVec3_Type(new Type("bvec3", *fBool_Type, 3)) + , fBVec4_Type(new Type("bvec4", *fBool_Type, 4)) + , fMat2x2_Type(new Type("mat2", *fFloat_Type, 2, 2)) + , fMat2x3_Type(new Type("mat2x3", *fFloat_Type, 2, 3)) + , fMat2x4_Type(new Type("mat2x4", *fFloat_Type, 2, 4)) + , fMat3x2_Type(new Type("mat3x2", *fFloat_Type, 3, 2)) + , fMat3x3_Type(new Type("mat3", *fFloat_Type, 3, 3)) + , fMat3x4_Type(new Type("mat3x4", *fFloat_Type, 3, 4)) + , fMat4x2_Type(new Type("mat4x2", *fFloat_Type, 4, 2)) + , fMat4x3_Type(new Type("mat4x3", *fFloat_Type, 4, 3)) + , fMat4x4_Type(new Type("mat4", *fFloat_Type, 4, 4)) + , fDMat2x2_Type(new Type("dmat2", *fFloat_Type, 2, 2)) + , fDMat2x3_Type(new Type("dmat2x3", *fFloat_Type, 2, 3)) + , fDMat2x4_Type(new Type("dmat2x4", *fFloat_Type, 2, 4)) + , fDMat3x2_Type(new Type("dmat3x2", *fFloat_Type, 3, 2)) + , fDMat3x3_Type(new Type("dmat3", *fFloat_Type, 3, 3)) + , fDMat3x4_Type(new Type("dmat3x4", *fFloat_Type, 3, 4)) + , fDMat4x2_Type(new Type("dmat4x2", *fFloat_Type, 4, 2)) + , fDMat4x3_Type(new Type("dmat4x3", *fFloat_Type, 4, 3)) + , fDMat4x4_Type(new Type("dmat4", *fFloat_Type, 4, 4)) + , fSampler1D_Type(new Type("sampler1D", SpvDim1D, false, false, false, true)) + , fSampler2D_Type(new Type("sampler2D", SpvDim2D, false, false, false, true)) + , fSampler3D_Type(new Type("sampler3D", SpvDim3D, false, false, false, true)) + , fSamplerCube_Type(new Type("samplerCube")) + , fSampler2DRect_Type(new Type("sampler2DRect")) + , fSampler1DArray_Type(new Type("sampler1DArray")) + , fSampler2DArray_Type(new Type("sampler2DArray")) + , fSamplerCubeArray_Type(new Type("samplerCubeArray")) + , fSamplerBuffer_Type(new Type("samplerBuffer")) + , fSampler2DMS_Type(new Type("sampler2DMS")) + , fSampler2DMSArray_Type(new Type("sampler2DMSArray")) + , fSampler1DShadow_Type(new Type("sampler1DShadow")) + , fSampler2DShadow_Type(new Type("sampler2DShadow")) + , fSamplerCubeShadow_Type(new Type("samplerCubeShadow")) + , fSampler2DRectShadow_Type(new Type("sampler2DRectShadow")) + , fSampler1DArrayShadow_Type(new Type("sampler1DArrayShadow")) + , fSampler2DArrayShadow_Type(new Type("sampler2DArrayShadow")) + , fSamplerCubeArrayShadow_Type(new Type("samplerCubeArrayShadow")) + // FIXME figure out what we're supposed to do with the gsampler et al. types) + , fGSampler1D_Type(new Type("$gsampler1D", static_type(*fSampler1D_Type))) + , fGSampler2D_Type(new Type("$gsampler2D", static_type(*fSampler2D_Type))) + , fGSampler3D_Type(new Type("$gsampler3D", static_type(*fSampler3D_Type))) + , fGSamplerCube_Type(new Type("$gsamplerCube", static_type(*fSamplerCube_Type))) + , fGSampler2DRect_Type(new Type("$gsampler2DRect", static_type(*fSampler2DRect_Type))) + , fGSampler1DArray_Type(new Type("$gsampler1DArray", static_type(*fSampler1DArray_Type))) + , fGSampler2DArray_Type(new Type("$gsampler2DArray", static_type(*fSampler2DArray_Type))) + , fGSamplerCubeArray_Type(new Type("$gsamplerCubeArray", static_type(*fSamplerCubeArray_Type))) + , fGSamplerBuffer_Type(new Type("$gsamplerBuffer", static_type(*fSamplerBuffer_Type))) + , fGSampler2DMS_Type(new Type("$gsampler2DMS", static_type(*fSampler2DMS_Type))) + , fGSampler2DMSArray_Type(new Type("$gsampler2DMSArray", static_type(*fSampler2DMSArray_Type))) + , fGSampler2DArrayShadow_Type(new Type("$gsampler2DArrayShadow", + static_type(*fSampler2DArrayShadow_Type))) + , fGSamplerCubeArrayShadow_Type(new Type("$gsamplerCubeArrayShadow", + static_type(*fSamplerCubeArrayShadow_Type))) + , fGenType_Type(new Type("$genType", { fFloat_Type.get(), fVec2_Type.get(), fVec3_Type.get(), + fVec4_Type.get() })) + , fGenDType_Type(new Type("$genDType", { fDouble_Type.get(), fDVec2_Type.get(), + fDVec3_Type.get(), fDVec4_Type.get() })) + , fGenIType_Type(new Type("$genIType", { fInt_Type.get(), fIVec2_Type.get(), fIVec3_Type.get(), + fIVec4_Type.get() })) + , fGenUType_Type(new Type("$genUType", { fUInt_Type.get(), fUVec2_Type.get(), fUVec3_Type.get(), + fUVec4_Type.get() })) + , fGenBType_Type(new Type("$genBType", { fBool_Type.get(), fBVec2_Type.get(), fBVec3_Type.get(), + fBVec4_Type.get() })) + , fMat_Type(new Type("$mat")) + , fVec_Type(new Type("$vec", { fVec2_Type.get(), fVec2_Type.get(), fVec3_Type.get(), + fVec4_Type.get() })) + , fGVec_Type(new Type("$gvec")) + , fGVec2_Type(new Type("$gvec2")) + , fGVec3_Type(new Type("$gvec3")) + , fGVec4_Type(new Type("$gvec4", static_type(*fVec4_Type))) + , fDVec_Type(new Type("$dvec")) + , fIVec_Type(new Type("$ivec")) + , fUVec_Type(new Type("$uvec")) + , fBVec_Type(new Type("$bvec", { fBVec2_Type.get(), fBVec2_Type.get(), fBVec3_Type.get(), + fBVec4_Type.get() })) + , fInvalid_Type(new Type("<INVALID>")) {} + + static std::vector<const Type*> static_type(const Type& t) { + return { &t, &t, &t, &t }; + } + + const std::unique_ptr<Type> fVoid_Type; + + const std::unique_ptr<Type> fDouble_Type; + const std::unique_ptr<Type> fDVec2_Type; + const std::unique_ptr<Type> fDVec3_Type; + const std::unique_ptr<Type> fDVec4_Type; + + const std::unique_ptr<Type> fFloat_Type; + const std::unique_ptr<Type> fVec2_Type; + const std::unique_ptr<Type> fVec3_Type; + const std::unique_ptr<Type> fVec4_Type; + + const std::unique_ptr<Type> fUInt_Type; + const std::unique_ptr<Type> fUVec2_Type; + const std::unique_ptr<Type> fUVec3_Type; + const std::unique_ptr<Type> fUVec4_Type; + + const std::unique_ptr<Type> fInt_Type; + const std::unique_ptr<Type> fIVec2_Type; + const std::unique_ptr<Type> fIVec3_Type; + const std::unique_ptr<Type> fIVec4_Type; + + const std::unique_ptr<Type> fBool_Type; + const std::unique_ptr<Type> fBVec2_Type; + const std::unique_ptr<Type> fBVec3_Type; + const std::unique_ptr<Type> fBVec4_Type; + + const std::unique_ptr<Type> fMat2x2_Type; + const std::unique_ptr<Type> fMat2x3_Type; + const std::unique_ptr<Type> fMat2x4_Type; + const std::unique_ptr<Type> fMat3x2_Type; + const std::unique_ptr<Type> fMat3x3_Type; + const std::unique_ptr<Type> fMat3x4_Type; + const std::unique_ptr<Type> fMat4x2_Type; + const std::unique_ptr<Type> fMat4x3_Type; + const std::unique_ptr<Type> fMat4x4_Type; + + const std::unique_ptr<Type> fDMat2x2_Type; + const std::unique_ptr<Type> fDMat2x3_Type; + const std::unique_ptr<Type> fDMat2x4_Type; + const std::unique_ptr<Type> fDMat3x2_Type; + const std::unique_ptr<Type> fDMat3x3_Type; + const std::unique_ptr<Type> fDMat3x4_Type; + const std::unique_ptr<Type> fDMat4x2_Type; + const std::unique_ptr<Type> fDMat4x3_Type; + const std::unique_ptr<Type> fDMat4x4_Type; + + const std::unique_ptr<Type> fSampler1D_Type; + const std::unique_ptr<Type> fSampler2D_Type; + const std::unique_ptr<Type> fSampler3D_Type; + const std::unique_ptr<Type> fSamplerCube_Type; + const std::unique_ptr<Type> fSampler2DRect_Type; + const std::unique_ptr<Type> fSampler1DArray_Type; + const std::unique_ptr<Type> fSampler2DArray_Type; + const std::unique_ptr<Type> fSamplerCubeArray_Type; + const std::unique_ptr<Type> fSamplerBuffer_Type; + const std::unique_ptr<Type> fSampler2DMS_Type; + const std::unique_ptr<Type> fSampler2DMSArray_Type; + const std::unique_ptr<Type> fSampler1DShadow_Type; + const std::unique_ptr<Type> fSampler2DShadow_Type; + const std::unique_ptr<Type> fSamplerCubeShadow_Type; + const std::unique_ptr<Type> fSampler2DRectShadow_Type; + const std::unique_ptr<Type> fSampler1DArrayShadow_Type; + const std::unique_ptr<Type> fSampler2DArrayShadow_Type; + const std::unique_ptr<Type> fSamplerCubeArrayShadow_Type; + + const std::unique_ptr<Type> fGSampler1D_Type; + const std::unique_ptr<Type> fGSampler2D_Type; + const std::unique_ptr<Type> fGSampler3D_Type; + const std::unique_ptr<Type> fGSamplerCube_Type; + const std::unique_ptr<Type> fGSampler2DRect_Type; + const std::unique_ptr<Type> fGSampler1DArray_Type; + const std::unique_ptr<Type> fGSampler2DArray_Type; + const std::unique_ptr<Type> fGSamplerCubeArray_Type; + const std::unique_ptr<Type> fGSamplerBuffer_Type; + const std::unique_ptr<Type> fGSampler2DMS_Type; + const std::unique_ptr<Type> fGSampler2DMSArray_Type; + const std::unique_ptr<Type> fGSampler2DArrayShadow_Type; + const std::unique_ptr<Type> fGSamplerCubeArrayShadow_Type; + + const std::unique_ptr<Type> fGenType_Type; + const std::unique_ptr<Type> fGenDType_Type; + const std::unique_ptr<Type> fGenIType_Type; + const std::unique_ptr<Type> fGenUType_Type; + const std::unique_ptr<Type> fGenBType_Type; + + const std::unique_ptr<Type> fMat_Type; + + const std::unique_ptr<Type> fVec_Type; + + const std::unique_ptr<Type> fGVec_Type; + const std::unique_ptr<Type> fGVec2_Type; + const std::unique_ptr<Type> fGVec3_Type; + const std::unique_ptr<Type> fGVec4_Type; + const std::unique_ptr<Type> fDVec_Type; + const std::unique_ptr<Type> fIVec_Type; + const std::unique_ptr<Type> fUVec_Type; + + const std::unique_ptr<Type> fBVec_Type; + + const std::unique_ptr<Type> fInvalid_Type; +}; + +} // namespace + +#endif diff --git a/src/sksl/SkSLIRGenerator.cpp b/src/sksl/SkSLIRGenerator.cpp index 2cc7eacb4d..f250c4bb0c 100644 --- a/src/sksl/SkSLIRGenerator.cpp +++ b/src/sksl/SkSLIRGenerator.cpp @@ -66,11 +66,12 @@ public: std::shared_ptr<SymbolTable> fPrevious; }; -IRGenerator::IRGenerator(std::shared_ptr<SymbolTable> symbolTable, +IRGenerator::IRGenerator(const Context* context, std::shared_ptr<SymbolTable> symbolTable, ErrorReporter& errorReporter) -: fSymbolTable(std::move(symbolTable)) -, fErrors(errorReporter) { -} +: fContext(*context) +, fCurrentFunction(nullptr) +, fSymbolTable(std::move(symbolTable)) +, fErrors(errorReporter) {} void IRGenerator::pushSymbolTable() { fSymbolTable.reset(new SymbolTable(std::move(fSymbolTable), fErrors)); @@ -123,7 +124,7 @@ std::unique_ptr<Block> IRGenerator::convertBlock(const ASTBlock& block) { } statements.push_back(std::move(statement)); } - return std::unique_ptr<Block>(new Block(block.fPosition, std::move(statements))); + return std::unique_ptr<Block>(new Block(block.fPosition, std::move(statements), fSymbolTable)); } std::unique_ptr<Statement> IRGenerator::convertVarDeclarationStatement( @@ -141,22 +142,22 @@ Modifiers IRGenerator::convertModifiers(const ASTModifiers& modifiers) { std::unique_ptr<VarDeclaration> IRGenerator::convertVarDeclaration(const ASTVarDeclaration& decl, Variable::Storage storage) { - std::vector<std::shared_ptr<Variable>> variables; + std::vector<const Variable*> variables; std::vector<std::vector<std::unique_ptr<Expression>>> sizes; std::vector<std::unique_ptr<Expression>> values; - std::shared_ptr<Type> baseType = this->convertType(*decl.fType); + const Type* baseType = this->convertType(*decl.fType); if (!baseType) { return nullptr; } for (size_t i = 0; i < decl.fNames.size(); i++) { Modifiers modifiers = this->convertModifiers(decl.fModifiers); - std::shared_ptr<Type> type = baseType; + const Type* type = baseType; ASSERT(type->kind() != Type::kArray_Kind); std::vector<std::unique_ptr<Expression>> currentVarSizes; for (size_t j = 0; j < decl.fSizes[i].size(); j++) { if (decl.fSizes[i][j]) { ASTExpression& rawSize = *decl.fSizes[i][j]; - auto size = this->coerce(this->convertExpression(rawSize), kInt_Type); + auto size = this->coerce(this->convertExpression(rawSize), *fContext.fInt_Type); if (!size) { return nullptr; } @@ -172,27 +173,28 @@ std::unique_ptr<VarDeclaration> IRGenerator::convertVarDeclaration(const ASTVarD count = -1; name += "[]"; } - type = std::shared_ptr<Type>(new Type(name, Type::kArray_Kind, type, (int) count)); + type = new Type(name, Type::kArray_Kind, *type, (int) count); + fSymbolTable->takeOwnership((Type*) type); currentVarSizes.push_back(std::move(size)); } else { - type = std::shared_ptr<Type>(new Type(type->fName + "[]", Type::kArray_Kind, type, - -1)); + type = new Type(type->fName + "[]", Type::kArray_Kind, *type, -1); + fSymbolTable->takeOwnership((Type*) type); currentVarSizes.push_back(nullptr); } } sizes.push_back(std::move(currentVarSizes)); - auto var = std::make_shared<Variable>(decl.fPosition, modifiers, decl.fNames[i], type, - storage); - variables.push_back(var); + auto var = std::unique_ptr<Variable>(new Variable(decl.fPosition, modifiers, decl.fNames[i], + *type, storage)); std::unique_ptr<Expression> value; if (decl.fValues[i]) { value = this->convertExpression(*decl.fValues[i]); if (!value) { return nullptr; } - value = this->coerce(std::move(value), type); + value = this->coerce(std::move(value), *type); } - fSymbolTable->add(var->fName, var); + variables.push_back(var.get()); + fSymbolTable->add(decl.fNames[i], std::move(var)); values.push_back(std::move(value)); } return std::unique_ptr<VarDeclaration>(new VarDeclaration(decl.fPosition, std::move(variables), @@ -200,7 +202,8 @@ std::unique_ptr<VarDeclaration> IRGenerator::convertVarDeclaration(const ASTVarD } std::unique_ptr<Statement> IRGenerator::convertIf(const ASTIfStatement& s) { - std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*s.fTest), kBool_Type); + std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*s.fTest), + *fContext.fBool_Type); if (!test) { return nullptr; } @@ -225,7 +228,8 @@ std::unique_ptr<Statement> IRGenerator::convertFor(const ASTForStatement& f) { if (!initializer) { return nullptr; } - std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*f.fTest), kBool_Type); + std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*f.fTest), + *fContext.fBool_Type); if (!test) { return nullptr; } @@ -240,11 +244,12 @@ std::unique_ptr<Statement> IRGenerator::convertFor(const ASTForStatement& f) { } return std::unique_ptr<Statement>(new ForStatement(f.fPosition, std::move(initializer), std::move(test), std::move(next), - std::move(statement))); + std::move(statement), fSymbolTable)); } std::unique_ptr<Statement> IRGenerator::convertWhile(const ASTWhileStatement& w) { - std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*w.fTest), kBool_Type); + std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*w.fTest), + *fContext.fBool_Type); if (!test) { return nullptr; } @@ -257,7 +262,8 @@ std::unique_ptr<Statement> IRGenerator::convertWhile(const ASTWhileStatement& w) } std::unique_ptr<Statement> IRGenerator::convertDo(const ASTDoStatement& d) { - std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*d.fTest), kBool_Type); + std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*d.fTest), + *fContext.fBool_Type); if (!test) { return nullptr; } @@ -286,7 +292,7 @@ std::unique_ptr<Statement> IRGenerator::convertReturn(const ASTReturnStatement& if (!result) { return nullptr; } - if (fCurrentFunction->fReturnType == kVoid_Type) { + if (fCurrentFunction->fReturnType == *fContext.fVoid_Type) { fErrors.error(result->fPosition, "may not return a value from a void function"); } else { result = this->coerce(std::move(result), fCurrentFunction->fReturnType); @@ -296,9 +302,9 @@ std::unique_ptr<Statement> IRGenerator::convertReturn(const ASTReturnStatement& } return std::unique_ptr<Statement>(new ReturnStatement(std::move(result))); } else { - if (fCurrentFunction->fReturnType != kVoid_Type) { + if (fCurrentFunction->fReturnType != *fContext.fVoid_Type) { fErrors.error(r.fPosition, "expected function to return '" + - fCurrentFunction->fReturnType->description() + "'"); + fCurrentFunction->fReturnType.description() + "'"); } return std::unique_ptr<Statement>(new ReturnStatement(r.fPosition)); } @@ -316,80 +322,74 @@ std::unique_ptr<Statement> IRGenerator::convertDiscard(const ASTDiscardStatement return std::unique_ptr<Statement>(new DiscardStatement(d.fPosition)); } -static std::shared_ptr<Type> expand_generics(std::shared_ptr<Type> type, int i) { - if (type->kind() == Type::kGeneric_Kind) { - return type->coercibleTypes()[i]; +static const Type& expand_generics(const Type& type, int i) { + if (type.kind() == Type::kGeneric_Kind) { + return *type.coercibleTypes()[i]; } return type; } -static void expand_generics(FunctionDeclaration& decl, - SymbolTable& symbolTable) { +static void expand_generics(const FunctionDeclaration& decl, + std::shared_ptr<SymbolTable> symbolTable) { for (int i = 0; i < 4; i++) { - std::shared_ptr<Type> returnType = expand_generics(decl.fReturnType, i); - std::vector<std::shared_ptr<Variable>> arguments; + const Type& returnType = expand_generics(decl.fReturnType, i); + std::vector<const Variable*> parameters; for (const auto& p : decl.fParameters) { - arguments.push_back(std::shared_ptr<Variable>(new Variable( - p->fPosition, - Modifiers(p->fModifiers), - p->fName, - expand_generics(p->fType, i), - Variable::kParameter_Storage))); + Variable* var = new Variable(p->fPosition, Modifiers(p->fModifiers), p->fName, + expand_generics(p->fType, i), + Variable::kParameter_Storage); + symbolTable->takeOwnership(var); + parameters.push_back(var); } - std::shared_ptr<FunctionDeclaration> expanded(new FunctionDeclaration( - decl.fPosition, - decl.fName, - std::move(arguments), - std::move(returnType))); - symbolTable.add(expanded->fName, expanded); + symbolTable->add(decl.fName, std::unique_ptr<FunctionDeclaration>(new FunctionDeclaration( + decl.fPosition, + decl.fName, + std::move(parameters), + std::move(returnType)))); } } std::unique_ptr<FunctionDefinition> IRGenerator::convertFunction(const ASTFunction& f) { - std::shared_ptr<SymbolTable> old = fSymbolTable; - AutoSymbolTable table(this); bool isGeneric; - std::shared_ptr<Type> returnType = this->convertType(*f.fReturnType); + const Type* returnType = this->convertType(*f.fReturnType); if (!returnType) { return nullptr; } isGeneric = returnType->kind() == Type::kGeneric_Kind; - std::vector<std::shared_ptr<Variable>> parameters; + std::vector<const Variable*> parameters; for (const auto& param : f.fParameters) { - std::shared_ptr<Type> type = this->convertType(*param->fType); + const Type* type = this->convertType(*param->fType); if (!type) { return nullptr; } for (int j = (int) param->fSizes.size() - 1; j >= 0; j--) { int size = param->fSizes[j]; std::string name = type->name() + "[" + to_string(size) + "]"; - type = std::shared_ptr<Type>(new Type(std::move(name), Type::kArray_Kind, - std::move(type), size)); + Type* newType = new Type(std::move(name), Type::kArray_Kind, *type, size); + fSymbolTable->takeOwnership(newType); + type = newType; } std::string name = param->fName; Modifiers modifiers = this->convertModifiers(param->fModifiers); Position pos = param->fPosition; - std::shared_ptr<Variable> var = std::shared_ptr<Variable>(new Variable( - pos, - modifiers, - std::move(name), - type, - Variable::kParameter_Storage)); - parameters.push_back(std::move(var)); + Variable* var = new Variable(pos, modifiers, std::move(name), *type, + Variable::kParameter_Storage); + fSymbolTable->takeOwnership(var); + parameters.push_back(var); isGeneric |= type->kind() == Type::kGeneric_Kind; } // find existing declaration - std::shared_ptr<FunctionDeclaration> decl; - auto entry = (*old)[f.fName]; + const FunctionDeclaration* decl = nullptr; + auto entry = (*fSymbolTable)[f.fName]; if (entry) { - std::vector<std::shared_ptr<FunctionDeclaration>> functions; + std::vector<const FunctionDeclaration*> functions; switch (entry->fKind) { case Symbol::kUnresolvedFunction_Kind: - functions = std::static_pointer_cast<UnresolvedFunction>(entry)->fFunctions; + functions = ((UnresolvedFunction*) entry)->fFunctions; break; case Symbol::kFunctionDeclaration_Kind: - functions.push_back(std::static_pointer_cast<FunctionDeclaration>(entry)); + functions.push_back((FunctionDeclaration*) entry); break; default: fErrors.error(f.fPosition, "symbol '" + f.fName + "' was already defined"); @@ -406,11 +406,8 @@ std::unique_ptr<FunctionDefinition> IRGenerator::convertFunction(const ASTFuncti } } if (match) { - if (returnType != other->fReturnType) { - FunctionDeclaration newDecl = FunctionDeclaration(f.fPosition, - f.fName, - parameters, - returnType); + if (*returnType != other->fReturnType) { + FunctionDeclaration newDecl(f.fPosition, f.fName, parameters, *returnType); fErrors.error(f.fPosition, "functions '" + newDecl.description() + "' and '" + other->description() + "' differ only in return type"); @@ -424,7 +421,6 @@ std::unique_ptr<FunctionDefinition> IRGenerator::convertFunction(const ASTFuncti "declaration and definition"); return nullptr; } - fSymbolTable->add(parameters[i]->fName, decl->fParameters[i]); } if (other->fDefined) { fErrors.error(f.fPosition, "duplicate definition of " + @@ -437,28 +433,36 @@ std::unique_ptr<FunctionDefinition> IRGenerator::convertFunction(const ASTFuncti } if (!decl) { // couldn't find an existing declaration - decl.reset(new FunctionDeclaration(f.fPosition, f.fName, parameters, returnType)); - for (auto var : parameters) { - fSymbolTable->add(var->fName, var); + if (isGeneric) { + ASSERT(!f.fBody); + expand_generics(FunctionDeclaration(f.fPosition, f.fName, parameters, *returnType), + fSymbolTable); + } else { + auto newDecl = std::unique_ptr<FunctionDeclaration>(new FunctionDeclaration( + f.fPosition, + f.fName, + parameters, + *returnType)); + decl = newDecl.get(); + fSymbolTable->add(decl->fName, std::move(newDecl)); } } - if (isGeneric) { - ASSERT(!f.fBody); - expand_generics(*decl, *old); - } else { - old->add(decl->fName, decl); - if (f.fBody) { - ASSERT(!fCurrentFunction); - fCurrentFunction = decl; - decl->fDefined = true; - std::unique_ptr<Block> body = this->convertBlock(*f.fBody); - fCurrentFunction = nullptr; - if (!body) { - return nullptr; - } - return std::unique_ptr<FunctionDefinition>(new FunctionDefinition(f.fPosition, decl, - std::move(body))); + if (f.fBody) { + ASSERT(!fCurrentFunction); + fCurrentFunction = decl; + decl->fDefined = true; + std::shared_ptr<SymbolTable> old = fSymbolTable; + AutoSymbolTable table(this); + for (size_t i = 0; i < parameters.size(); i++) { + fSymbolTable->addWithoutOwnership(parameters[i]->fName, decl->fParameters[i]); + } + std::unique_ptr<Block> body = this->convertBlock(*f.fBody); + fCurrentFunction = nullptr; + if (!body) { + return nullptr; } + return std::unique_ptr<FunctionDefinition>(new FunctionDefinition(f.fPosition, *decl, + std::move(body))); } return nullptr; } @@ -488,28 +492,26 @@ std::unique_ptr<InterfaceBlock> IRGenerator::convertInterfaceBlock(const ASTInte } } } - std::shared_ptr<Type> type = std::shared_ptr<Type>(new Type(intf.fInterfaceName, fields)); + Type* type = new Type(intf.fInterfaceName, fields); + fSymbolTable->takeOwnership(type); std::string name = intf.fValueName.length() > 0 ? intf.fValueName : intf.fInterfaceName; - std::shared_ptr<Variable> var = std::shared_ptr<Variable>(new Variable(intf.fPosition, mods, - name, type, - Variable::kGlobal_Storage)); + Variable* var = new Variable(intf.fPosition, mods, name, *type, Variable::kGlobal_Storage); + fSymbolTable->takeOwnership(var); if (intf.fValueName.length()) { - old->add(intf.fValueName, var); - + old->addWithoutOwnership(intf.fValueName, var); } else { for (size_t i = 0; i < fields.size(); i++) { - std::shared_ptr<Field> field = std::shared_ptr<Field>(new Field(intf.fPosition, var, - (int) i)); - old->add(fields[i].fName, field); + old->add(fields[i].fName, std::unique_ptr<Field>(new Field(intf.fPosition, *var, + (int) i))); } } - return std::unique_ptr<InterfaceBlock>(new InterfaceBlock(intf.fPosition, var)); + return std::unique_ptr<InterfaceBlock>(new InterfaceBlock(intf.fPosition, *var, fSymbolTable)); } -std::shared_ptr<Type> IRGenerator::convertType(const ASTType& type) { - std::shared_ptr<Symbol> result = (*fSymbolTable)[type.fName]; +const Type* IRGenerator::convertType(const ASTType& type) { + const Symbol* result = (*fSymbolTable)[type.fName]; if (result && result->fKind == Symbol::kType_Kind) { - return std::static_pointer_cast<Type>(result); + return (const Type*) result; } fErrors.error(type.fPosition, "unknown type '" + type.fName + "'"); return nullptr; @@ -520,13 +522,13 @@ std::unique_ptr<Expression> IRGenerator::convertExpression(const ASTExpression& case ASTExpression::kIdentifier_Kind: return this->convertIdentifier((ASTIdentifier&) expr); case ASTExpression::kBool_Kind: - return std::unique_ptr<Expression>(new BoolLiteral(expr.fPosition, + return std::unique_ptr<Expression>(new BoolLiteral(fContext, expr.fPosition, ((ASTBoolLiteral&) expr).fValue)); case ASTExpression::kInt_Kind: - return std::unique_ptr<Expression>(new IntLiteral(expr.fPosition, + return std::unique_ptr<Expression>(new IntLiteral(fContext, expr.fPosition, ((ASTIntLiteral&) expr).fValue)); case ASTExpression::kFloat_Kind: - return std::unique_ptr<Expression>(new FloatLiteral(expr.fPosition, + return std::unique_ptr<Expression>(new FloatLiteral(fContext, expr.fPosition, ((ASTFloatLiteral&) expr).fValue)); case ASTExpression::kBinary_Kind: return this->convertBinaryExpression((ASTBinaryExpression&) expr); @@ -542,40 +544,42 @@ std::unique_ptr<Expression> IRGenerator::convertExpression(const ASTExpression& } std::unique_ptr<Expression> IRGenerator::convertIdentifier(const ASTIdentifier& identifier) { - std::shared_ptr<Symbol> result = (*fSymbolTable)[identifier.fText]; + const Symbol* result = (*fSymbolTable)[identifier.fText]; if (!result) { fErrors.error(identifier.fPosition, "unknown identifier '" + identifier.fText + "'"); return nullptr; } switch (result->fKind) { case Symbol::kFunctionDeclaration_Kind: { - std::vector<std::shared_ptr<FunctionDeclaration>> f = { - std::static_pointer_cast<FunctionDeclaration>(result) + std::vector<const FunctionDeclaration*> f = { + (const FunctionDeclaration*) result }; - return std::unique_ptr<FunctionReference>(new FunctionReference(identifier.fPosition, - std::move(f))); + return std::unique_ptr<FunctionReference>(new FunctionReference(fContext, + identifier.fPosition, + f)); } case Symbol::kUnresolvedFunction_Kind: { - auto f = std::static_pointer_cast<UnresolvedFunction>(result); - return std::unique_ptr<FunctionReference>(new FunctionReference(identifier.fPosition, + const UnresolvedFunction* f = (const UnresolvedFunction*) result; + return std::unique_ptr<FunctionReference>(new FunctionReference(fContext, + identifier.fPosition, f->fFunctions)); } case Symbol::kVariable_Kind: { - std::shared_ptr<Variable> var = std::static_pointer_cast<Variable>(result); - this->markReadFrom(var); + const Variable* var = (const Variable*) result; + this->markReadFrom(*var); return std::unique_ptr<VariableReference>(new VariableReference(identifier.fPosition, - std::move(var))); + *var)); } case Symbol::kField_Kind: { - std::shared_ptr<Field> field = std::static_pointer_cast<Field>(result); + const Field* field = (const Field*) result; VariableReference* base = new VariableReference(identifier.fPosition, field->fOwner); return std::unique_ptr<Expression>(new FieldAccess(std::unique_ptr<Expression>(base), field->fFieldIndex)); } case Symbol::kType_Kind: { - auto t = std::static_pointer_cast<Type>(result); - return std::unique_ptr<TypeReference>(new TypeReference(identifier.fPosition, - std::move(t))); + const Type* t = (const Type*) result; + return std::unique_ptr<TypeReference>(new TypeReference(fContext, identifier.fPosition, + *t)); } default: ABORT("unsupported symbol type %d\n", result->fKind); @@ -584,43 +588,45 @@ std::unique_ptr<Expression> IRGenerator::convertIdentifier(const ASTIdentifier& } std::unique_ptr<Expression> IRGenerator::coerce(std::unique_ptr<Expression> expr, - std::shared_ptr<Type> type) { + const Type& type) { if (!expr) { return nullptr; } - if (*expr->fType == *type) { + if (expr->fType == type) { return expr; } this->checkValid(*expr); - if (*expr->fType == *kInvalid_Type) { + if (expr->fType == *fContext.fInvalid_Type) { return nullptr; } - if (!expr->fType->canCoerceTo(type)) { - fErrors.error(expr->fPosition, "expected '" + type->description() + "', but found '" + - expr->fType->description() + "'"); + if (!expr->fType.canCoerceTo(type)) { + fErrors.error(expr->fPosition, "expected '" + type.description() + "', but found '" + + expr->fType.description() + "'"); return nullptr; } - if (type->kind() == Type::kScalar_Kind) { + if (type.kind() == Type::kScalar_Kind) { std::vector<std::unique_ptr<Expression>> args; args.push_back(std::move(expr)); - ASTIdentifier id(Position(), type->description()); + ASTIdentifier id(Position(), type.description()); std::unique_ptr<Expression> ctor = this->convertIdentifier(id); ASSERT(ctor); return this->call(Position(), std::move(ctor), std::move(args)); } - ABORT("cannot coerce %s to %s", expr->fType->description().c_str(), - type->description().c_str()); + ABORT("cannot coerce %s to %s", expr->fType.description().c_str(), + type.description().c_str()); } /** * Determines the operand and result types of a binary expression. Returns true if the expression is * legal, false otherwise. If false, the values of the out parameters are undefined. */ -static bool determine_binary_type(Token::Kind op, std::shared_ptr<Type> left, - std::shared_ptr<Type> right, - std::shared_ptr<Type>* outLeftType, - std::shared_ptr<Type>* outRightType, - std::shared_ptr<Type>* outResultType, +static bool determine_binary_type(const Context& context, + Token::Kind op, + const Type& left, + const Type& right, + const Type** outLeftType, + const Type** outRightType, + const Type** outResultType, bool tryFlipped) { bool isLogical; switch (op) { @@ -638,24 +644,25 @@ static bool determine_binary_type(Token::Kind op, std::shared_ptr<Type> left, case Token::LOGICALOREQ: // fall through case Token::LOGICALANDEQ: // fall through case Token::LOGICALXOREQ: - *outLeftType = kBool_Type; - *outRightType = kBool_Type; - *outResultType = kBool_Type; - return left->canCoerceTo(kBool_Type) && right->canCoerceTo(kBool_Type); + *outLeftType = context.fBool_Type.get(); + *outRightType = context.fBool_Type.get(); + *outResultType = context.fBool_Type.get(); + return left.canCoerceTo(*context.fBool_Type) && + right.canCoerceTo(*context.fBool_Type); case Token::STAR: // fall through case Token::STAREQ: // FIXME need to handle non-square matrices - if (left->kind() == Type::kMatrix_Kind && right->kind() == Type::kVector_Kind) { - *outLeftType = left; - *outRightType = right; - *outResultType = right; - return left->rows() == right->columns(); + if (left.kind() == Type::kMatrix_Kind && right.kind() == Type::kVector_Kind) { + *outLeftType = &left; + *outRightType = &right; + *outResultType = &right; + return left.rows() == right.columns(); } - if (left->kind() == Type::kVector_Kind && right->kind() == Type::kMatrix_Kind) { - *outLeftType = left; - *outRightType = right; - *outResultType = left; - return left->columns() == right->columns(); + if (left.kind() == Type::kVector_Kind && right.kind() == Type::kMatrix_Kind) { + *outLeftType = &left; + *outRightType = &right; + *outResultType = &left; + return left.columns() == right.columns(); } // fall through default: @@ -664,41 +671,42 @@ static bool determine_binary_type(Token::Kind op, std::shared_ptr<Type> left, // FIXME: need to disallow illegal operations like vec3 > vec3. Also do not currently have // full support for numbers other than float. if (left == right) { - *outLeftType = left; - *outRightType = left; + *outLeftType = &left; + *outRightType = &left; if (isLogical) { - *outResultType = kBool_Type; + *outResultType = context.fBool_Type.get(); } else { - *outResultType = left; + *outResultType = &left; } return true; } // FIXME: incorrect for shift operations - if (left->canCoerceTo(right)) { - *outLeftType = right; - *outRightType = right; + if (left.canCoerceTo(right)) { + *outLeftType = &right; + *outRightType = &right; if (isLogical) { - *outResultType = kBool_Type; + *outResultType = context.fBool_Type.get(); } else { - *outResultType = right; + *outResultType = &right; } return true; } - if ((left->kind() == Type::kVector_Kind || left->kind() == Type::kMatrix_Kind) && - (right->kind() == Type::kScalar_Kind)) { - if (determine_binary_type(op, left->componentType(), right, outLeftType, outRightType, - outResultType, false)) { - *outLeftType = (*outLeftType)->toCompound(left->columns(), left->rows()); + if ((left.kind() == Type::kVector_Kind || left.kind() == Type::kMatrix_Kind) && + (right.kind() == Type::kScalar_Kind)) { + if (determine_binary_type(context, op, left.componentType(), right, outLeftType, + outRightType, outResultType, false)) { + *outLeftType = &(*outLeftType)->toCompound(context, left.columns(), left.rows()); if (!isLogical) { - *outResultType = (*outResultType)->toCompound(left->columns(), left->rows()); + *outResultType = &(*outResultType)->toCompound(context, left.columns(), + left.rows()); } return true; } return false; } if (tryFlipped) { - return determine_binary_type(op, right, left, outRightType, outLeftType, outResultType, - false); + return determine_binary_type(context, op, right, left, outRightType, outLeftType, + outResultType, false); } return false; } @@ -713,15 +721,15 @@ std::unique_ptr<Expression> IRGenerator::convertBinaryExpression( if (!right) { return nullptr; } - std::shared_ptr<Type> leftType; - std::shared_ptr<Type> rightType; - std::shared_ptr<Type> resultType; - if (!determine_binary_type(expression.fOperator, left->fType, right->fType, &leftType, + const Type* leftType; + const Type* rightType; + const Type* resultType; + if (!determine_binary_type(fContext, expression.fOperator, left->fType, right->fType, &leftType, &rightType, &resultType, true)) { fErrors.error(expression.fPosition, "type mismatch: '" + Token::OperatorName(expression.fOperator) + - "' cannot operate on '" + left->fType->fName + - "', '" + right->fType->fName + "'"); + "' cannot operate on '" + left->fType.fName + + "', '" + right->fType.fName + "'"); return nullptr; } switch (expression.fOperator) { @@ -744,17 +752,18 @@ std::unique_ptr<Expression> IRGenerator::convertBinaryExpression( break; } return std::unique_ptr<Expression>(new BinaryExpression(expression.fPosition, - this->coerce(std::move(left), leftType), + this->coerce(std::move(left), + *leftType), expression.fOperator, this->coerce(std::move(right), - rightType), - resultType)); + *rightType), + *resultType)); } std::unique_ptr<Expression> IRGenerator::convertTernaryExpression( const ASTTernaryExpression& expression) { std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*expression.fTest), - kBool_Type); + *fContext.fBool_Type); if (!test) { return nullptr; } @@ -766,34 +775,33 @@ std::unique_ptr<Expression> IRGenerator::convertTernaryExpression( if (!ifFalse) { return nullptr; } - std::shared_ptr<Type> trueType; - std::shared_ptr<Type> falseType; - std::shared_ptr<Type> resultType; - if (!determine_binary_type(Token::EQEQ, ifTrue->fType, ifFalse->fType, &trueType, + const Type* trueType; + const Type* falseType; + const Type* resultType; + if (!determine_binary_type(fContext, Token::EQEQ, ifTrue->fType, ifFalse->fType, &trueType, &falseType, &resultType, true)) { fErrors.error(expression.fPosition, "ternary operator result mismatch: '" + - ifTrue->fType->fName + "', '" + - ifFalse->fType->fName + "'"); + ifTrue->fType.fName + "', '" + + ifFalse->fType.fName + "'"); return nullptr; } ASSERT(trueType == falseType); - ifTrue = this->coerce(std::move(ifTrue), trueType); - ifFalse = this->coerce(std::move(ifFalse), falseType); + ifTrue = this->coerce(std::move(ifTrue), *trueType); + ifFalse = this->coerce(std::move(ifFalse), *falseType); return std::unique_ptr<Expression>(new TernaryExpression(expression.fPosition, std::move(test), std::move(ifTrue), std::move(ifFalse))); } -std::unique_ptr<Expression> IRGenerator::call( - Position position, - std::shared_ptr<FunctionDeclaration> function, - std::vector<std::unique_ptr<Expression>> arguments) { - if (function->fParameters.size() != arguments.size()) { - std::string msg = "call to '" + function->fName + "' expected " + - to_string(function->fParameters.size()) + +std::unique_ptr<Expression> IRGenerator::call(Position position, + const FunctionDeclaration& function, + std::vector<std::unique_ptr<Expression>> arguments) { + if (function.fParameters.size() != arguments.size()) { + std::string msg = "call to '" + function.fName + "' expected " + + to_string(function.fParameters.size()) + " argument"; - if (function->fParameters.size() != 1) { + if (function.fParameters.size() != 1) { msg += "s"; } msg += ", but found " + to_string(arguments.size()); @@ -801,12 +809,12 @@ std::unique_ptr<Expression> IRGenerator::call( return nullptr; } for (size_t i = 0; i < arguments.size(); i++) { - arguments[i] = this->coerce(std::move(arguments[i]), function->fParameters[i]->fType); - if (arguments[i] && (function->fParameters[i]->fModifiers.fFlags & Modifiers::kOut_Flag)) { + arguments[i] = this->coerce(std::move(arguments[i]), function.fParameters[i]->fType); + if (arguments[i] && (function.fParameters[i]->fModifiers.fFlags & Modifiers::kOut_Flag)) { this->markWrittenTo(*arguments[i]); } } - return std::unique_ptr<FunctionCall>(new FunctionCall(position, std::move(function), + return std::unique_ptr<FunctionCall>(new FunctionCall(position, function, std::move(arguments))); } @@ -815,16 +823,16 @@ std::unique_ptr<Expression> IRGenerator::call( * if the cost could be computed, false if the call is not valid. Cost has no particular meaning * other than "lower costs are preferred". */ -bool IRGenerator::determineCallCost(std::shared_ptr<FunctionDeclaration> function, +bool IRGenerator::determineCallCost(const FunctionDeclaration& function, const std::vector<std::unique_ptr<Expression>>& arguments, int* outCost) { - if (function->fParameters.size() != arguments.size()) { + if (function.fParameters.size() != arguments.size()) { return false; } int total = 0; for (size_t i = 0; i < arguments.size(); i++) { int cost; - if (arguments[i]->fType->determineCoercionCost(function->fParameters[i]->fType, &cost)) { + if (arguments[i]->fType.determineCoercionCost(function.fParameters[i]->fType, &cost)) { total += cost; } else { return false; @@ -848,97 +856,97 @@ std::unique_ptr<Expression> IRGenerator::call(Position position, } FunctionReference* ref = (FunctionReference*) functionValue.get(); int bestCost = INT_MAX; - std::shared_ptr<FunctionDeclaration> best; + const FunctionDeclaration* best = nullptr; if (ref->fFunctions.size() > 1) { for (const auto& f : ref->fFunctions) { int cost; - if (this->determineCallCost(f, arguments, &cost) && cost < bestCost) { + if (this->determineCallCost(*f, arguments, &cost) && cost < bestCost) { bestCost = cost; best = f; } } if (best) { - return this->call(position, std::move(best), std::move(arguments)); + return this->call(position, *best, std::move(arguments)); } std::string msg = "no match for " + ref->fFunctions[0]->fName + "("; std::string separator = ""; for (size_t i = 0; i < arguments.size(); i++) { msg += separator; separator = ", "; - msg += arguments[i]->fType->description(); + msg += arguments[i]->fType.description(); } msg += ")"; fErrors.error(position, msg); return nullptr; } - return this->call(position, ref->fFunctions[0], std::move(arguments)); + return this->call(position, *ref->fFunctions[0], std::move(arguments)); } std::unique_ptr<Expression> IRGenerator::convertConstructor( Position position, - std::shared_ptr<Type> type, + const Type& type, std::vector<std::unique_ptr<Expression>> args) { // FIXME: add support for structs and arrays - Type::Kind kind = type->kind(); - if (!type->isNumber() && kind != Type::kVector_Kind && kind != Type::kMatrix_Kind) { - fErrors.error(position, "cannot construct '" + type->description() + "'"); + Type::Kind kind = type.kind(); + if (!type.isNumber() && kind != Type::kVector_Kind && kind != Type::kMatrix_Kind) { + fErrors.error(position, "cannot construct '" + type.description() + "'"); return nullptr; } - if (type == kFloat_Type && args.size() == 1 && + if (type == *fContext.fFloat_Type && args.size() == 1 && args[0]->fKind == Expression::kIntLiteral_Kind) { int64_t value = ((IntLiteral&) *args[0]).fValue; - return std::unique_ptr<Expression>(new FloatLiteral(position, (double) value)); + return std::unique_ptr<Expression>(new FloatLiteral(fContext, position, (double) value)); } if (args.size() == 1 && args[0]->fType == type) { // argument is already the right type, just return it return std::move(args[0]); } - if (type->isNumber()) { + if (type.isNumber()) { if (args.size() != 1) { - fErrors.error(position, "invalid arguments to '" + type->description() + + fErrors.error(position, "invalid arguments to '" + type.description() + "' constructor, (expected exactly 1 argument, but found " + to_string(args.size()) + ")"); } - if (args[0]->fType == kBool_Type) { - std::unique_ptr<IntLiteral> zero(new IntLiteral(position, 0)); - std::unique_ptr<IntLiteral> one(new IntLiteral(position, 1)); + if (args[0]->fType == *fContext.fBool_Type) { + std::unique_ptr<IntLiteral> zero(new IntLiteral(fContext, position, 0)); + std::unique_ptr<IntLiteral> one(new IntLiteral(fContext, position, 1)); return std::unique_ptr<Expression>( new TernaryExpression(position, std::move(args[0]), this->coerce(std::move(one), type), this->coerce(std::move(zero), type))); - } else if (!args[0]->fType->isNumber()) { - fErrors.error(position, "invalid argument to '" + type->description() + + } else if (!args[0]->fType.isNumber()) { + fErrors.error(position, "invalid argument to '" + type.description() + "' constructor (expected a number or bool, but found '" + - args[0]->fType->description() + "')"); + args[0]->fType.description() + "')"); } } else { ASSERT(kind == Type::kVector_Kind || kind == Type::kMatrix_Kind); int actual = 0; for (size_t i = 0; i < args.size(); i++) { - if (args[i]->fType->kind() == Type::kVector_Kind || - args[i]->fType->kind() == Type::kMatrix_Kind) { - int columns = args[i]->fType->columns(); - int rows = args[i]->fType->rows(); + if (args[i]->fType.kind() == Type::kVector_Kind || + args[i]->fType.kind() == Type::kMatrix_Kind) { + int columns = args[i]->fType.columns(); + int rows = args[i]->fType.rows(); args[i] = this->coerce(std::move(args[i]), - type->componentType()->toCompound(columns, rows)); - actual += args[i]->fType->rows() * args[i]->fType->columns(); - } else if (args[i]->fType->kind() == Type::kScalar_Kind) { + type.componentType().toCompound(fContext, columns, rows)); + actual += args[i]->fType.rows() * args[i]->fType.columns(); + } else if (args[i]->fType.kind() == Type::kScalar_Kind) { actual += 1; - if (type->kind() != Type::kScalar_Kind) { - args[i] = this->coerce(std::move(args[i]), type->componentType()); + if (type.kind() != Type::kScalar_Kind) { + args[i] = this->coerce(std::move(args[i]), type.componentType()); } } else { - fErrors.error(position, "'" + args[i]->fType->description() + "' is not a valid " - "parameter to '" + type->description() + "' constructor"); + fErrors.error(position, "'" + args[i]->fType.description() + "' is not a valid " + "parameter to '" + type.description() + "' constructor"); return nullptr; } } - int min = type->rows() * type->columns(); - int max = type->columns() > 1 ? INT_MAX : min; + int min = type.rows() * type.columns(); + int max = type.columns() > 1 ? INT_MAX : min; if ((actual < min || actual > max) && !((kind == Type::kVector_Kind || kind == Type::kMatrix_Kind) && (actual == 1))) { - fErrors.error(position, "invalid arguments to '" + type->description() + + fErrors.error(position, "invalid arguments to '" + type.description() + "' constructor (expected " + to_string(min) + " scalar" + (min == 1 ? "" : "s") + ", but found " + to_string(actual) + ")"); @@ -956,50 +964,51 @@ std::unique_ptr<Expression> IRGenerator::convertPrefixExpression( } switch (expression.fOperator) { case Token::PLUS: - if (!base->fType->isNumber() && base->fType->kind() != Type::kVector_Kind) { + if (!base->fType.isNumber() && base->fType.kind() != Type::kVector_Kind) { fErrors.error(expression.fPosition, - "'+' cannot operate on '" + base->fType->description() + "'"); + "'+' cannot operate on '" + base->fType.description() + "'"); return nullptr; } return base; case Token::MINUS: - if (!base->fType->isNumber() && base->fType->kind() != Type::kVector_Kind) { + if (!base->fType.isNumber() && base->fType.kind() != Type::kVector_Kind) { fErrors.error(expression.fPosition, - "'-' cannot operate on '" + base->fType->description() + "'"); + "'-' cannot operate on '" + base->fType.description() + "'"); return nullptr; } if (base->fKind == Expression::kIntLiteral_Kind) { - return std::unique_ptr<Expression>(new IntLiteral(base->fPosition, + return std::unique_ptr<Expression>(new IntLiteral(fContext, base->fPosition, -((IntLiteral&) *base).fValue)); } if (base->fKind == Expression::kFloatLiteral_Kind) { double value = -((FloatLiteral&) *base).fValue; - return std::unique_ptr<Expression>(new FloatLiteral(base->fPosition, value)); + return std::unique_ptr<Expression>(new FloatLiteral(fContext, base->fPosition, + value)); } return std::unique_ptr<Expression>(new PrefixExpression(Token::MINUS, std::move(base))); case Token::PLUSPLUS: - if (!base->fType->isNumber()) { + if (!base->fType.isNumber()) { fErrors.error(expression.fPosition, "'" + Token::OperatorName(expression.fOperator) + - "' cannot operate on '" + base->fType->description() + "'"); + "' cannot operate on '" + base->fType.description() + "'"); return nullptr; } this->markWrittenTo(*base); break; case Token::MINUSMINUS: - if (!base->fType->isNumber()) { + if (!base->fType.isNumber()) { fErrors.error(expression.fPosition, "'" + Token::OperatorName(expression.fOperator) + - "' cannot operate on '" + base->fType->description() + "'"); + "' cannot operate on '" + base->fType.description() + "'"); return nullptr; } this->markWrittenTo(*base); break; case Token::NOT: - if (base->fType != kBool_Type) { + if (base->fType != *fContext.fBool_Type) { fErrors.error(expression.fPosition, "'" + Token::OperatorName(expression.fOperator) + - "' cannot operate on '" + base->fType->description() + "'"); + "' cannot operate on '" + base->fType.description() + "'"); return nullptr; } break; @@ -1012,8 +1021,8 @@ std::unique_ptr<Expression> IRGenerator::convertPrefixExpression( std::unique_ptr<Expression> IRGenerator::convertIndex(std::unique_ptr<Expression> base, const ASTExpression& index) { - if (base->fType->kind() != Type::kArray_Kind && base->fType->kind() != Type::kMatrix_Kind) { - fErrors.error(base->fPosition, "expected array, but found '" + base->fType->description() + + if (base->fType.kind() != Type::kArray_Kind && base->fType.kind() != Type::kMatrix_Kind) { + fErrors.error(base->fPosition, "expected array, but found '" + base->fType.description() + "'"); return nullptr; } @@ -1021,30 +1030,31 @@ std::unique_ptr<Expression> IRGenerator::convertIndex(std::unique_ptr<Expression if (!converted) { return nullptr; } - converted = this->coerce(std::move(converted), kInt_Type); + converted = this->coerce(std::move(converted), *fContext.fInt_Type); if (!converted) { return nullptr; } - return std::unique_ptr<Expression>(new IndexExpression(std::move(base), std::move(converted))); + return std::unique_ptr<Expression>(new IndexExpression(fContext, std::move(base), + std::move(converted))); } std::unique_ptr<Expression> IRGenerator::convertField(std::unique_ptr<Expression> base, const std::string& field) { - auto fields = base->fType->fields(); + auto fields = base->fType.fields(); for (size_t i = 0; i < fields.size(); i++) { if (fields[i].fName == field) { return std::unique_ptr<Expression>(new FieldAccess(std::move(base), (int) i)); } } - fErrors.error(base->fPosition, "type '" + base->fType->description() + "' does not have a " + fErrors.error(base->fPosition, "type '" + base->fType.description() + "' does not have a " "field named '" + field + ""); return nullptr; } std::unique_ptr<Expression> IRGenerator::convertSwizzle(std::unique_ptr<Expression> base, const std::string& fields) { - if (base->fType->kind() != Type::kVector_Kind) { - fErrors.error(base->fPosition, "cannot swizzle type '" + base->fType->description() + "'"); + if (base->fType.kind() != Type::kVector_Kind) { + fErrors.error(base->fPosition, "cannot swizzle type '" + base->fType.description() + "'"); return nullptr; } std::vector<int> swizzleComponents; @@ -1058,7 +1068,7 @@ std::unique_ptr<Expression> IRGenerator::convertSwizzle(std::unique_ptr<Expressi case 'y': // fall through case 'g': // fall through case 't': - if (base->fType->columns() >= 2) { + if (base->fType.columns() >= 2) { swizzleComponents.push_back(1); break; } @@ -1066,7 +1076,7 @@ std::unique_ptr<Expression> IRGenerator::convertSwizzle(std::unique_ptr<Expressi case 'z': // fall through case 'b': // fall through case 'p': - if (base->fType->columns() >= 3) { + if (base->fType.columns() >= 3) { swizzleComponents.push_back(2); break; } @@ -1074,7 +1084,7 @@ std::unique_ptr<Expression> IRGenerator::convertSwizzle(std::unique_ptr<Expressi case 'w': // fall through case 'a': // fall through case 'q': - if (base->fType->columns() >= 4) { + if (base->fType.columns() >= 4) { swizzleComponents.push_back(3); break; } @@ -1090,7 +1100,7 @@ std::unique_ptr<Expression> IRGenerator::convertSwizzle(std::unique_ptr<Expressi fErrors.error(base->fPosition, "too many components in swizzle mask '" + fields + "'"); return nullptr; } - return std::unique_ptr<Expression>(new Swizzle(std::move(base), swizzleComponents)); + return std::unique_ptr<Expression>(new Swizzle(fContext, std::move(base), swizzleComponents)); } std::unique_ptr<Expression> IRGenerator::convertSuffixExpression( @@ -1117,7 +1127,7 @@ std::unique_ptr<Expression> IRGenerator::convertSuffixExpression( return this->call(expression.fPosition, std::move(base), std::move(arguments)); } case ASTSuffix::kField_Kind: { - switch (base->fType->kind()) { + switch (base->fType.kind()) { case Type::kVector_Kind: return this->convertSwizzle(std::move(base), ((ASTFieldSuffix&) *expression.fSuffix).fField); @@ -1126,23 +1136,23 @@ std::unique_ptr<Expression> IRGenerator::convertSuffixExpression( ((ASTFieldSuffix&) *expression.fSuffix).fField); default: fErrors.error(base->fPosition, "cannot swizzle value of type '" + - base->fType->description() + "'"); + base->fType.description() + "'"); return nullptr; } } case ASTSuffix::kPostIncrement_Kind: - if (!base->fType->isNumber()) { + if (!base->fType.isNumber()) { fErrors.error(expression.fPosition, - "'++' cannot operate on '" + base->fType->description() + "'"); + "'++' cannot operate on '" + base->fType.description() + "'"); return nullptr; } this->markWrittenTo(*base); return std::unique_ptr<Expression>(new PostfixExpression(std::move(base), Token::PLUSPLUS)); case ASTSuffix::kPostDecrement_Kind: - if (!base->fType->isNumber()) { + if (!base->fType.isNumber()) { fErrors.error(expression.fPosition, - "'--' cannot operate on '" + base->fType->description() + "'"); + "'--' cannot operate on '" + base->fType.description() + "'"); return nullptr; } this->markWrittenTo(*base); @@ -1162,13 +1172,13 @@ void IRGenerator::checkValid(const Expression& expr) { fErrors.error(expr.fPosition, "expected '(' to begin constructor invocation"); break; default: - ASSERT(expr.fType != kInvalid_Type); + ASSERT(expr.fType != *fContext.fInvalid_Type); break; } } -void IRGenerator::markReadFrom(std::shared_ptr<Variable> var) { - var->fIsReadFrom = true; +void IRGenerator::markReadFrom(const Variable& var) { + var.fIsReadFrom = true; } static bool has_duplicates(const Swizzle& swizzle) { @@ -1187,7 +1197,7 @@ static bool has_duplicates(const Swizzle& swizzle) { void IRGenerator::markWrittenTo(const Expression& expr) { switch (expr.fKind) { case Expression::kVariableReference_Kind: { - const Variable& var = *((VariableReference&) expr).fVariable; + const Variable& var = ((VariableReference&) expr).fVariable; if (var.fModifiers.fFlags & (Modifiers::kConst_Flag | Modifiers::kUniform_Flag)) { fErrors.error(expr.fPosition, "cannot modify immutable variable '" + var.fName + "'"); diff --git a/src/sksl/SkSLIRGenerator.h b/src/sksl/SkSLIRGenerator.h index d23e5a1bdb..2384b2dabc 100644 --- a/src/sksl/SkSLIRGenerator.h +++ b/src/sksl/SkSLIRGenerator.h @@ -53,7 +53,8 @@ namespace SkSL { */ class IRGenerator { public: - IRGenerator(std::shared_ptr<SymbolTable> root, ErrorReporter& errorReporter); + IRGenerator(const Context* context, std::shared_ptr<SymbolTable> root, + ErrorReporter& errorReporter); std::unique_ptr<VarDeclaration> convertVarDeclaration(const ASTVarDeclaration& decl, Variable::Storage storage); @@ -65,21 +66,20 @@ private: void pushSymbolTable(); void popSymbolTable(); - std::shared_ptr<Type> convertType(const ASTType& type); + const Type* convertType(const ASTType& type); std::unique_ptr<Expression> call(Position position, - std::shared_ptr<FunctionDeclaration> function, + const FunctionDeclaration& function, std::vector<std::unique_ptr<Expression>> arguments); - bool determineCallCost(std::shared_ptr<FunctionDeclaration> function, + bool determineCallCost(const FunctionDeclaration& function, const std::vector<std::unique_ptr<Expression>>& arguments, int* outCost); std::unique_ptr<Expression> call(Position position, std::unique_ptr<Expression> function, std::vector<std::unique_ptr<Expression>> arguments); - std::unique_ptr<Expression> coerce(std::unique_ptr<Expression> expr, - std::shared_ptr<Type> type); + std::unique_ptr<Expression> coerce(std::unique_ptr<Expression> expr, const Type& type); std::unique_ptr<Block> convertBlock(const ASTBlock& block); std::unique_ptr<Statement> convertBreak(const ASTBreakStatement& b); std::unique_ptr<Expression> convertConstructor(Position position, - std::shared_ptr<Type> type, + const Type& type, std::vector<std::unique_ptr<Expression>> params); std::unique_ptr<Statement> convertContinue(const ASTContinueStatement& c); std::unique_ptr<Statement> convertDiscard(const ASTDiscardStatement& d); @@ -106,10 +106,11 @@ private: std::unique_ptr<Statement> convertWhile(const ASTWhileStatement& w); void checkValid(const Expression& expr); - void markReadFrom(std::shared_ptr<Variable> var); + void markReadFrom(const Variable& var); void markWrittenTo(const Expression& expr); - std::shared_ptr<FunctionDeclaration> fCurrentFunction; + const Context& fContext; + const FunctionDeclaration* fCurrentFunction; std::shared_ptr<SymbolTable> fSymbolTable; ErrorReporter& fErrors; diff --git a/src/sksl/SkSLParser.cpp b/src/sksl/SkSLParser.cpp index fa302af0d3..edff0c67d1 100644 --- a/src/sksl/SkSLParser.cpp +++ b/src/sksl/SkSLParser.cpp @@ -52,6 +52,7 @@ #include "ast/SkSLASTVarDeclarationStatement.h" #include "ast/SkSLASTWhileStatement.h" #include "ir/SkSLSymbolTable.h" +#include "ir/SkSLType.h" namespace SkSL { @@ -290,17 +291,17 @@ std::unique_ptr<ASTType> Parser::structDeclaration() { return nullptr; } for (size_t i = 0; i < decl->fNames.size(); i++) { - auto type = std::static_pointer_cast<Type>(fTypes[decl->fType->fName]); + auto type = (const Type*) fTypes[decl->fType->fName]; for (int j = (int) decl->fSizes[i].size() - 1; j >= 0; j--) { - if (decl->fSizes[i][j]->fKind == ASTExpression::kInt_Kind) { + if (decl->fSizes[i][j]->fKind != ASTExpression::kInt_Kind) { this->error(decl->fPosition, "array size in struct field must be a constant"); } uint64_t columns = ((ASTIntLiteral&) *decl->fSizes[i][j]).fValue; std::string name = type->name() + "[" + to_string(columns) + "]"; - type = std::shared_ptr<Type>(new Type(name, Type::kArray_Kind, std::move(type), - (int) columns)); + type = new Type(name, Type::kArray_Kind, *type, (int) columns); + fTypes.takeOwnership((Type*) type); } - fields.push_back(Type::Field(decl->fModifiers, decl->fNames[i], std::move(type))); + fields.push_back(Type::Field(decl->fModifiers, decl->fNames[i], *type)); if (decl->fValues[i]) { this->error(decl->fPosition, "initializers are not permitted on struct fields"); } @@ -309,9 +310,8 @@ std::unique_ptr<ASTType> Parser::structDeclaration() { if (!this->expect(Token::RBRACE, "'}'")) { return nullptr; } - std::shared_ptr<Type> type(new Type(name.fText, fields)); - fTypes.add(type->fName, type); - return std::unique_ptr<ASTType>(new ASTType(name.fPosition, type->fName, + fTypes.add(name.fText, std::unique_ptr<Type>(new Type(name.fText, fields))); + return std::unique_ptr<ASTType>(new ASTType(name.fPosition, name.fText, ASTType::kStruct_Kind)); } diff --git a/src/sksl/SkSLSPIRVCodeGenerator.cpp b/src/sksl/SkSLSPIRVCodeGenerator.cpp index 0a2dab3adf..2771e0291b 100644 --- a/src/sksl/SkSLSPIRVCodeGenerator.cpp +++ b/src/sksl/SkSLSPIRVCodeGenerator.cpp @@ -141,36 +141,36 @@ void SPIRVCodeGenerator::writeWord(int32_t word, std::ostream& out) { #endif } -static bool is_float(const Type& type) { +static bool is_float(const Context& context, const Type& type) { if (type.kind() == Type::kVector_Kind) { - return is_float(*type.componentType()); + return is_float(context, type.componentType()); } - return type == *kFloat_Type || type == *kDouble_Type; + return type == *context.fFloat_Type || type == *context.fDouble_Type; } -static bool is_signed(const Type& type) { +static bool is_signed(const Context& context, const Type& type) { if (type.kind() == Type::kVector_Kind) { - return is_signed(*type.componentType()); + return is_signed(context, type.componentType()); } - return type == *kInt_Type; + return type == *context.fInt_Type; } -static bool is_unsigned(const Type& type) { +static bool is_unsigned(const Context& context, const Type& type) { if (type.kind() == Type::kVector_Kind) { - return is_unsigned(*type.componentType()); + return is_unsigned(context, type.componentType()); } - return type == *kUInt_Type; + return type == *context.fUInt_Type; } -static bool is_bool(const Type& type) { +static bool is_bool(const Context& context, const Type& type) { if (type.kind() == Type::kVector_Kind) { - return is_bool(*type.componentType()); + return is_bool(context, type.componentType()); } - return type == *kBool_Type; + return type == *context.fBool_Type; } -static bool is_out(std::shared_ptr<Variable> var) { - return (var->fModifiers.fFlags & Modifiers::kOut_Flag) != 0; +static bool is_out(const Variable& var) { + return (var.fModifiers.fFlags & Modifiers::kOut_Flag) != 0; } #if SPIRV_DEBUG @@ -973,7 +973,7 @@ void SPIRVCodeGenerator::writeStruct(const Type& type, SpvId resultId) { // in the middle of writing the struct instruction std::vector<SpvId> types; for (const auto& f : type.fields()) { - types.push_back(this->getType(*f.fType)); + types.push_back(this->getType(f.fType)); } this->writeOpCode(SpvOpTypeStruct, 2 + (int32_t) types.size(), fConstantBuffer); this->writeWord(resultId, fConstantBuffer); @@ -982,8 +982,8 @@ void SPIRVCodeGenerator::writeStruct(const Type& type, SpvId resultId) { } size_t offset = 0; for (int32_t i = 0; i < (int32_t) type.fields().size(); i++) { - size_t size = type.fields()[i].fType->size(); - size_t alignment = type.fields()[i].fType->alignment(); + size_t size = type.fields()[i].fType.size(); + size_t alignment = type.fields()[i].fType.alignment(); size_t mod = offset % alignment; if (mod != 0) { offset += alignment - mod; @@ -995,14 +995,14 @@ void SPIRVCodeGenerator::writeStruct(const Type& type, SpvId resultId) { this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i, SpvDecorationOffset, (SpvId) offset, fDecorationBuffer); } - if (type.fields()[i].fType->kind() == Type::kMatrix_Kind) { + if (type.fields()[i].fType.kind() == Type::kMatrix_Kind) { this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationColMajor, fDecorationBuffer); this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationMatrixStride, - (SpvId) type.fields()[i].fType->stride(), fDecorationBuffer); + (SpvId) type.fields()[i].fType.stride(), fDecorationBuffer); } offset += size; - Type::Kind kind = type.fields()[i].fType->kind(); + Type::Kind kind = type.fields()[i].fType.kind(); if ((kind == Type::kArray_Kind || kind == Type::kStruct_Kind) && offset % alignment != 0) { offset += alignment - offset % alignment; } @@ -1016,15 +1016,15 @@ SpvId SPIRVCodeGenerator::getType(const Type& type) { SpvId result = this->nextId(); switch (type.kind()) { case Type::kScalar_Kind: - if (type == *kBool_Type) { + if (type == *fContext.fBool_Type) { this->writeInstruction(SpvOpTypeBool, result, fConstantBuffer); - } else if (type == *kInt_Type) { + } else if (type == *fContext.fInt_Type) { this->writeInstruction(SpvOpTypeInt, result, 32, 1, fConstantBuffer); - } else if (type == *kUInt_Type) { + } else if (type == *fContext.fUInt_Type) { this->writeInstruction(SpvOpTypeInt, result, 32, 0, fConstantBuffer); - } else if (type == *kFloat_Type) { + } else if (type == *fContext.fFloat_Type) { this->writeInstruction(SpvOpTypeFloat, result, 32, fConstantBuffer); - } else if (type == *kDouble_Type) { + } else if (type == *fContext.fDouble_Type) { this->writeInstruction(SpvOpTypeFloat, result, 64, fConstantBuffer); } else { ASSERT(false); @@ -1032,11 +1032,12 @@ SpvId SPIRVCodeGenerator::getType(const Type& type) { break; case Type::kVector_Kind: this->writeInstruction(SpvOpTypeVector, result, - this->getType(*type.componentType()), + this->getType(type.componentType()), type.columns(), fConstantBuffer); break; case Type::kMatrix_Kind: - this->writeInstruction(SpvOpTypeMatrix, result, this->getType(*index_type(type)), + this->writeInstruction(SpvOpTypeMatrix, result, + this->getType(index_type(fContext, type)), type.columns(), fConstantBuffer); break; case Type::kStruct_Kind: @@ -1044,22 +1045,22 @@ SpvId SPIRVCodeGenerator::getType(const Type& type) { break; case Type::kArray_Kind: { if (type.columns() > 0) { - IntLiteral count(Position(), type.columns()); + IntLiteral count(fContext, Position(), type.columns()); this->writeInstruction(SpvOpTypeArray, result, - this->getType(*type.componentType()), + this->getType(type.componentType()), this->writeIntLiteral(count), fConstantBuffer); this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride, (int32_t) type.stride(), fDecorationBuffer); } else { ABORT("runtime-sized arrays are not yet supported"); this->writeInstruction(SpvOpTypeRuntimeArray, result, - this->getType(*type.componentType()), fConstantBuffer); + this->getType(type.componentType()), fConstantBuffer); } break; } case Type::kSampler_Kind: { SpvId image = this->nextId(); - this->writeInstruction(SpvOpTypeImage, image, this->getType(*kFloat_Type), + this->writeInstruction(SpvOpTypeImage, image, this->getType(*fContext.fFloat_Type), type.dimensions(), type.isDepth(), type.isArrayed(), type.isMultisampled(), type.isSampled(), SpvImageFormatUnknown, fConstantBuffer); @@ -1067,7 +1068,7 @@ SpvId SPIRVCodeGenerator::getType(const Type& type) { break; } default: - if (type == *kVoid_Type) { + if (type == *fContext.fVoid_Type) { this->writeInstruction(SpvOpTypeVoid, result, fConstantBuffer); } else { ABORT("invalid type: %s", type.description().c_str()); @@ -1079,22 +1080,22 @@ SpvId SPIRVCodeGenerator::getType(const Type& type) { return entry->second; } -SpvId SPIRVCodeGenerator::getFunctionType(std::shared_ptr<FunctionDeclaration> function) { - std::string key = function->fReturnType->description() + "("; +SpvId SPIRVCodeGenerator::getFunctionType(const FunctionDeclaration& function) { + std::string key = function.fReturnType.description() + "("; std::string separator = ""; - for (size_t i = 0; i < function->fParameters.size(); i++) { + for (size_t i = 0; i < function.fParameters.size(); i++) { key += separator; separator = ", "; - key += function->fParameters[i]->fType->description(); + key += function.fParameters[i]->fType.description(); } key += ")"; auto entry = fTypeMap.find(key); if (entry == fTypeMap.end()) { SpvId result = this->nextId(); - int32_t length = 3 + (int32_t) function->fParameters.size(); - SpvId returnType = this->getType(*function->fReturnType); + int32_t length = 3 + (int32_t) function.fParameters.size(); + SpvId returnType = this->getType(function.fReturnType); std::vector<SpvId> parameterTypes; - for (size_t i = 0; i < function->fParameters.size(); i++) { + for (size_t i = 0; i < function.fParameters.size(); i++) { // glslang seems to treat all function arguments as pointers whether they need to be or // not. I was initially puzzled by this until I ran bizarre failures with certain // patterns of function calls and control constructs, as exemplified by this minimal @@ -1118,10 +1119,10 @@ SpvId SPIRVCodeGenerator::getFunctionType(std::shared_ptr<FunctionDeclaration> f // as glslang does, fixes it. It's entirely possible I simply missed whichever part of // the spec makes this make sense. // if (is_out(function->fParameters[i])) { - parameterTypes.push_back(this->getPointerType(function->fParameters[i]->fType, + parameterTypes.push_back(this->getPointerType(function.fParameters[i]->fType, SpvStorageClassFunction)); // } else { -// parameterTypes.push_back(this->getType(*function->fParameters[i]->fType)); +// parameterTypes.push_back(this->getType(function.fParameters[i]->fType)); // } } this->writeOpCode(SpvOpTypeFunction, length, fConstantBuffer); @@ -1136,14 +1137,14 @@ SpvId SPIRVCodeGenerator::getFunctionType(std::shared_ptr<FunctionDeclaration> f return entry->second; } -SpvId SPIRVCodeGenerator::getPointerType(std::shared_ptr<Type> type, +SpvId SPIRVCodeGenerator::getPointerType(const Type& type, SpvStorageClass_ storageClass) { - std::string key = type->description() + "*" + to_string(storageClass); + std::string key = type.description() + "*" + to_string(storageClass); auto entry = fTypeMap.find(key); if (entry == fTypeMap.end()) { SpvId result = this->nextId(); this->writeInstruction(SpvOpTypePointer, result, storageClass, - this->getType(*type), fConstantBuffer); + this->getType(type), fConstantBuffer); fTypeMap[key] = result; return result; } @@ -1185,21 +1186,21 @@ SpvId SPIRVCodeGenerator::writeExpression(Expression& expr, std::ostream& out) { } SpvId SPIRVCodeGenerator::writeIntrinsicCall(FunctionCall& c, std::ostream& out) { - auto intrinsic = fIntrinsicMap.find(c.fFunction->fName); + auto intrinsic = fIntrinsicMap.find(c.fFunction.fName); ASSERT(intrinsic != fIntrinsicMap.end()); - std::shared_ptr<Type> type = c.fArguments[0]->fType; + const Type& type = c.fArguments[0]->fType; int32_t intrinsicId; - if (std::get<0>(intrinsic->second) == kSpecial_IntrinsicKind || is_float(*type)) { + if (std::get<0>(intrinsic->second) == kSpecial_IntrinsicKind || is_float(fContext, type)) { intrinsicId = std::get<1>(intrinsic->second); - } else if (is_signed(*type)) { + } else if (is_signed(fContext, type)) { intrinsicId = std::get<2>(intrinsic->second); - } else if (is_unsigned(*type)) { + } else if (is_unsigned(fContext, type)) { intrinsicId = std::get<3>(intrinsic->second); - } else if (is_bool(*type)) { + } else if (is_bool(fContext, type)) { intrinsicId = std::get<4>(intrinsic->second); } else { ABORT("invalid call %s, cannot operate on '%s'", c.description().c_str(), - type->description().c_str()); + type.description().c_str()); } switch (std::get<0>(intrinsic->second)) { case kGLSL_STD_450_IntrinsicKind: { @@ -1209,7 +1210,7 @@ SpvId SPIRVCodeGenerator::writeIntrinsicCall(FunctionCall& c, std::ostream& out) arguments.push_back(this->writeExpression(*c.fArguments[i], out)); } this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out); - this->writeWord(this->getType(*c.fType), out); + this->writeWord(this->getType(c.fType), out); this->writeWord(result, out); this->writeWord(fGLSLExtendedInstructions, out); this->writeWord(intrinsicId, out); @@ -1225,7 +1226,7 @@ SpvId SPIRVCodeGenerator::writeIntrinsicCall(FunctionCall& c, std::ostream& out) arguments.push_back(this->writeExpression(*c.fArguments[i], out)); } this->writeOpCode((SpvOp_) intrinsicId, 3 + (int32_t) arguments.size(), out); - this->writeWord(this->getType(*c.fType), out); + this->writeWord(this->getType(c.fType), out); this->writeWord(result, out); for (SpvId id : arguments) { this->writeWord(id, out); @@ -1249,7 +1250,7 @@ SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(FunctionCall& c, SpecialIntrinsi arguments.push_back(this->writeExpression(*c.fArguments[i], out)); } this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out); - this->writeWord(this->getType(*c.fType), out); + this->writeWord(this->getType(c.fType), out); this->writeWord(result, out); this->writeWord(fGLSLExtendedInstructions, out); this->writeWord(arguments.size() == 2 ? GLSLstd450Atan2 : GLSLstd450Atan, out); @@ -1259,7 +1260,7 @@ SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(FunctionCall& c, SpecialIntrinsi return result; } case kTexture_SpecialIntrinsic: { - SpvId type = this->getType(*c.fType); + SpvId type = this->getType(c.fType); SpvId sampler = this->writeExpression(*c.fArguments[0], out); SpvId uv = this->writeExpression(*c.fArguments[1], out); if (c.fArguments.size() == 3) { @@ -1274,7 +1275,7 @@ SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(FunctionCall& c, SpecialIntrinsi break; } case kTextureProj_SpecialIntrinsic: { - SpvId type = this->getType(*c.fType); + SpvId type = this->getType(c.fType); SpvId sampler = this->writeExpression(*c.fArguments[0], out); SpvId uv = this->writeExpression(*c.fArguments[1], out); if (c.fArguments.size() == 3) { @@ -1293,7 +1294,7 @@ SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(FunctionCall& c, SpecialIntrinsi SpvId img = this->writeExpression(*c.fArguments[0], out); SpvId coords = this->writeExpression(*c.fArguments[1], out); this->writeInstruction(SpvOpImageSampleImplicitLod, - this->getType(*c.fType), + this->getType(c.fType), result, img, coords, @@ -1305,7 +1306,7 @@ SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(FunctionCall& c, SpecialIntrinsi } SpvId SPIRVCodeGenerator::writeFunctionCall(FunctionCall& c, std::ostream& out) { - const auto& entry = fFunctionMap.find(c.fFunction); + const auto& entry = fFunctionMap.find(&c.fFunction); if (entry == fFunctionMap.end()) { return this->writeIntrinsicCall(c, out); } @@ -1318,7 +1319,7 @@ SpvId SPIRVCodeGenerator::writeFunctionCall(FunctionCall& c, std::ostream& out) SpvId tmpVar; // if we need a temporary var to store this argument, this is the value to store in the var SpvId tmpValueId; - if (is_out(c.fFunction->fParameters[i])) { + if (is_out(*c.fFunction.fParameters[i])) { std::unique_ptr<LValue> lv = this->getLValue(*c.fArguments[i], out); SpvId ptr = lv->getPointer(); if (ptr) { @@ -1330,7 +1331,7 @@ SpvId SPIRVCodeGenerator::writeFunctionCall(FunctionCall& c, std::ostream& out) // update the lvalue. tmpValueId = lv->load(out); tmpVar = this->nextId(); - lvalues.push_back(std::make_tuple(tmpVar, this->getType(*c.fArguments[i]->fType), + lvalues.push_back(std::make_tuple(tmpVar, this->getType(c.fArguments[i]->fType), std::move(lv))); } } else { @@ -1343,13 +1344,13 @@ SpvId SPIRVCodeGenerator::writeFunctionCall(FunctionCall& c, std::ostream& out) SpvStorageClassFunction), tmpVar, SpvStorageClassFunction, - out); + fVariableBuffer); this->writeInstruction(SpvOpStore, tmpVar, tmpValueId, out); arguments.push_back(tmpVar); } SpvId result = this->nextId(); this->writeOpCode(SpvOpFunctionCall, 4 + (int32_t) c.fArguments.size(), out); - this->writeWord(this->getType(*c.fType), out); + this->writeWord(this->getType(c.fType), out); this->writeWord(result, out); this->writeWord(entry->second, out); for (SpvId id : arguments) { @@ -1366,19 +1367,19 @@ SpvId SPIRVCodeGenerator::writeFunctionCall(FunctionCall& c, std::ostream& out) } SpvId SPIRVCodeGenerator::writeConstantVector(Constructor& c) { - ASSERT(c.fType->kind() == Type::kVector_Kind && c.isConstant()); + ASSERT(c.fType.kind() == Type::kVector_Kind && c.isConstant()); SpvId result = this->nextId(); std::vector<SpvId> arguments; for (size_t i = 0; i < c.fArguments.size(); i++) { arguments.push_back(this->writeExpression(*c.fArguments[i], fConstantBuffer)); } - SpvId type = this->getType(*c.fType); + SpvId type = this->getType(c.fType); if (c.fArguments.size() == 1) { // with a single argument, a vector will have all of its entries equal to the argument - this->writeOpCode(SpvOpConstantComposite, 3 + c.fType->columns(), fConstantBuffer); + this->writeOpCode(SpvOpConstantComposite, 3 + c.fType.columns(), fConstantBuffer); this->writeWord(type, fConstantBuffer); this->writeWord(result, fConstantBuffer); - for (int i = 0; i < c.fType->columns(); i++) { + for (int i = 0; i < c.fType.columns(); i++) { this->writeWord(arguments[0], fConstantBuffer); } } else { @@ -1394,43 +1395,43 @@ SpvId SPIRVCodeGenerator::writeConstantVector(Constructor& c) { } SpvId SPIRVCodeGenerator::writeFloatConstructor(Constructor& c, std::ostream& out) { - ASSERT(c.fType == kFloat_Type); + ASSERT(c.fType == *fContext.fFloat_Type); ASSERT(c.fArguments.size() == 1); - ASSERT(c.fArguments[0]->fType->isNumber()); + ASSERT(c.fArguments[0]->fType.isNumber()); SpvId result = this->nextId(); SpvId parameter = this->writeExpression(*c.fArguments[0], out); - if (c.fArguments[0]->fType == kInt_Type) { - this->writeInstruction(SpvOpConvertSToF, this->getType(*c.fType), result, parameter, + if (c.fArguments[0]->fType == *fContext.fInt_Type) { + this->writeInstruction(SpvOpConvertSToF, this->getType(c.fType), result, parameter, out); - } else if (c.fArguments[0]->fType == kUInt_Type) { - this->writeInstruction(SpvOpConvertUToF, this->getType(*c.fType), result, parameter, + } else if (c.fArguments[0]->fType == *fContext.fUInt_Type) { + this->writeInstruction(SpvOpConvertUToF, this->getType(c.fType), result, parameter, out); - } else if (c.fArguments[0]->fType == kFloat_Type) { + } else if (c.fArguments[0]->fType == *fContext.fFloat_Type) { return parameter; } return result; } SpvId SPIRVCodeGenerator::writeIntConstructor(Constructor& c, std::ostream& out) { - ASSERT(c.fType == kInt_Type); + ASSERT(c.fType == *fContext.fInt_Type); ASSERT(c.fArguments.size() == 1); - ASSERT(c.fArguments[0]->fType->isNumber()); + ASSERT(c.fArguments[0]->fType.isNumber()); SpvId result = this->nextId(); SpvId parameter = this->writeExpression(*c.fArguments[0], out); - if (c.fArguments[0]->fType == kFloat_Type) { - this->writeInstruction(SpvOpConvertFToS, this->getType(*c.fType), result, parameter, + if (c.fArguments[0]->fType == *fContext.fFloat_Type) { + this->writeInstruction(SpvOpConvertFToS, this->getType(c.fType), result, parameter, out); - } else if (c.fArguments[0]->fType == kUInt_Type) { - this->writeInstruction(SpvOpSatConvertUToS, this->getType(*c.fType), result, parameter, + } else if (c.fArguments[0]->fType == *fContext.fUInt_Type) { + this->writeInstruction(SpvOpSatConvertUToS, this->getType(c.fType), result, parameter, out); - } else if (c.fArguments[0]->fType == kInt_Type) { + } else if (c.fArguments[0]->fType == *fContext.fInt_Type) { return parameter; } return result; } SpvId SPIRVCodeGenerator::writeMatrixConstructor(Constructor& c, std::ostream& out) { - ASSERT(c.fType->kind() == Type::kMatrix_Kind); + ASSERT(c.fType.kind() == Type::kMatrix_Kind); // go ahead and write the arguments so we don't try to write new instructions in the middle of // an instruction std::vector<SpvId> arguments; @@ -1438,30 +1439,31 @@ SpvId SPIRVCodeGenerator::writeMatrixConstructor(Constructor& c, std::ostream& o arguments.push_back(this->writeExpression(*c.fArguments[i], out)); } SpvId result = this->nextId(); - int rows = c.fType->rows(); - int columns = c.fType->columns(); + int rows = c.fType.rows(); + int columns = c.fType.columns(); // FIXME this won't work to create a matrix from another matrix if (arguments.size() == 1) { // with a single argument, a matrix will have all of its diagonal entries equal to the // argument and its other values equal to zero // FIXME this won't work for int matrices - FloatLiteral zero(Position(), 0); + FloatLiteral zero(fContext, Position(), 0); SpvId zeroId = this->writeFloatLiteral(zero); std::vector<SpvId> columnIds; for (int column = 0; column < columns; column++) { - this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType->rows(), + this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType.rows(), out); - this->writeWord(this->getType(*c.fType->componentType()->toCompound(rows, 1)), out); + this->writeWord(this->getType(c.fType.componentType().toCompound(fContext, rows, 1)), + out); SpvId columnId = this->nextId(); this->writeWord(columnId, out); columnIds.push_back(columnId); - for (int row = 0; row < c.fType->columns(); row++) { + for (int row = 0; row < c.fType.columns(); row++) { this->writeWord(row == column ? arguments[0] : zeroId, out); } } this->writeOpCode(SpvOpCompositeConstruct, 3 + columns, out); - this->writeWord(this->getType(*c.fType), out); + this->writeWord(this->getType(c.fType), out); this->writeWord(result, out); for (SpvId id : columnIds) { this->writeWord(id, out); @@ -1470,15 +1472,16 @@ SpvId SPIRVCodeGenerator::writeMatrixConstructor(Constructor& c, std::ostream& o std::vector<SpvId> columnIds; int currentCount = 0; for (size_t i = 0; i < arguments.size(); i++) { - if (c.fArguments[i]->fType->kind() == Type::kVector_Kind) { + if (c.fArguments[i]->fType.kind() == Type::kVector_Kind) { ASSERT(currentCount == 0); columnIds.push_back(arguments[i]); currentCount = 0; } else { - ASSERT(c.fArguments[i]->fType->kind() == Type::kScalar_Kind); + ASSERT(c.fArguments[i]->fType.kind() == Type::kScalar_Kind); if (currentCount == 0) { - this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType->rows(), out); - this->writeWord(this->getType(*c.fType->componentType()->toCompound(rows, 1)), + this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType.rows(), out); + this->writeWord(this->getType(c.fType.componentType().toCompound(fContext, rows, + 1)), out); SpvId id = this->nextId(); this->writeWord(id, out); @@ -1490,7 +1493,7 @@ SpvId SPIRVCodeGenerator::writeMatrixConstructor(Constructor& c, std::ostream& o } ASSERT(columnIds.size() == (size_t) columns); this->writeOpCode(SpvOpCompositeConstruct, 3 + columns, out); - this->writeWord(this->getType(*c.fType), out); + this->writeWord(this->getType(c.fType), out); this->writeWord(result, out); for (SpvId id : columnIds) { this->writeWord(id, out); @@ -1500,7 +1503,7 @@ SpvId SPIRVCodeGenerator::writeMatrixConstructor(Constructor& c, std::ostream& o } SpvId SPIRVCodeGenerator::writeVectorConstructor(Constructor& c, std::ostream& out) { - ASSERT(c.fType->kind() == Type::kVector_Kind); + ASSERT(c.fType.kind() == Type::kVector_Kind); if (c.isConstant()) { return this->writeConstantVector(c); } @@ -1511,16 +1514,16 @@ SpvId SPIRVCodeGenerator::writeVectorConstructor(Constructor& c, std::ostream& o arguments.push_back(this->writeExpression(*c.fArguments[i], out)); } SpvId result = this->nextId(); - if (arguments.size() == 1 && c.fArguments[0]->fType->kind() == Type::kScalar_Kind) { - this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType->columns(), out); - this->writeWord(this->getType(*c.fType), out); + if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) { + this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType.columns(), out); + this->writeWord(this->getType(c.fType), out); this->writeWord(result, out); - for (int i = 0; i < c.fType->columns(); i++) { + for (int i = 0; i < c.fType.columns(); i++) { this->writeWord(arguments[0], out); } } else { this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) c.fArguments.size(), out); - this->writeWord(this->getType(*c.fType), out); + this->writeWord(this->getType(c.fType), out); this->writeWord(result, out); for (SpvId id : arguments) { this->writeWord(id, out); @@ -1530,12 +1533,12 @@ SpvId SPIRVCodeGenerator::writeVectorConstructor(Constructor& c, std::ostream& o } SpvId SPIRVCodeGenerator::writeConstructor(Constructor& c, std::ostream& out) { - if (c.fType == kFloat_Type) { + if (c.fType == *fContext.fFloat_Type) { return this->writeFloatConstructor(c, out); - } else if (c.fType == kInt_Type) { + } else if (c.fType == *fContext.fInt_Type) { return this->writeIntConstructor(c, out); } - switch (c.fType->kind()) { + switch (c.fType.kind()) { case Type::kVector_Kind: return this->writeVectorConstructor(c, out); case Type::kMatrix_Kind: @@ -1560,7 +1563,7 @@ SpvStorageClass_ get_storage_class(const Modifiers& modifiers) { SpvStorageClass_ get_storage_class(Expression& expr) { switch (expr.fKind) { case Expression::kVariableReference_Kind: - return get_storage_class(((VariableReference&) expr).fVariable->fModifiers); + return get_storage_class(((VariableReference&) expr).fVariable.fModifiers); case Expression::kFieldAccess_Kind: return get_storage_class(*((FieldAccess&) expr).fBase); case Expression::kIndex_Kind: @@ -1582,7 +1585,7 @@ std::vector<SpvId> SPIRVCodeGenerator::getAccessChain(Expression& expr, std::ost case Expression::kFieldAccess_Kind: { FieldAccess& fieldExpr = (FieldAccess&) expr; chain = this->getAccessChain(*fieldExpr.fBase, out); - IntLiteral index(Position(), fieldExpr.fFieldIndex); + IntLiteral index(fContext, Position(), fieldExpr.fFieldIndex); chain.push_back(this->writeIntLiteral(index)); break; } @@ -1698,13 +1701,13 @@ std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(Expres std::ostream& out) { switch (expr.fKind) { case Expression::kVariableReference_Kind: { - std::shared_ptr<Variable> var = ((VariableReference&) expr).fVariable; - auto entry = fVariableMap.find(var); + const Variable& var = ((VariableReference&) expr).fVariable; + auto entry = fVariableMap.find(&var); ASSERT(entry != fVariableMap.end()); return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue( *this, entry->second, - this->getType(*expr.fType))); + this->getType(expr.fType))); } case Expression::kIndex_Kind: // fall through case Expression::kFieldAccess_Kind: { @@ -1719,7 +1722,7 @@ std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(Expres return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue( *this, member, - this->getType(*expr.fType))); + this->getType(expr.fType))); } case Expression::kSwizzle_Kind: { @@ -1728,7 +1731,7 @@ std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(Expres SpvId base = this->getLValue(*swizzle.fBase, out)->getPointer(); ASSERT(base); if (count == 1) { - IntLiteral index(Position(), swizzle.fComponents[0]); + IntLiteral index(fContext, Position(), swizzle.fComponents[0]); SpvId member = this->nextId(); this->writeInstruction(SpvOpAccessChain, this->getPointerType(swizzle.fType, @@ -1740,14 +1743,14 @@ std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(Expres return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue( *this, member, - this->getType(*expr.fType))); + this->getType(expr.fType))); } else { return std::unique_ptr<SPIRVCodeGenerator::LValue>(new SwizzleLValue( *this, base, swizzle.fComponents, - *swizzle.fBase->fType, - *expr.fType)); + swizzle.fBase->fType, + expr.fType)); } } @@ -1758,21 +1761,22 @@ std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(Expres // caught by IRGenerator SpvId result = this->nextId(); SpvId type = this->getPointerType(expr.fType, SpvStorageClassFunction); - this->writeInstruction(SpvOpVariable, type, result, SpvStorageClassFunction, out); + this->writeInstruction(SpvOpVariable, type, result, SpvStorageClassFunction, + fVariableBuffer); this->writeInstruction(SpvOpStore, result, this->writeExpression(expr, out), out); return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue( *this, result, - this->getType(*expr.fType))); + this->getType(expr.fType))); } } SpvId SPIRVCodeGenerator::writeVariableReference(VariableReference& ref, std::ostream& out) { - auto entry = fVariableMap.find(ref.fVariable); + auto entry = fVariableMap.find(&ref.fVariable); ASSERT(entry != fVariableMap.end()); SpvId var = entry->second; SpvId result = this->nextId(); - this->writeInstruction(SpvOpLoad, this->getType(*ref.fVariable->fType), result, var, out); + this->writeInstruction(SpvOpLoad, this->getType(ref.fVariable.fType), result, var, out); return result; } @@ -1789,11 +1793,11 @@ SpvId SPIRVCodeGenerator::writeSwizzle(Swizzle& swizzle, std::ostream& out) { SpvId result = this->nextId(); size_t count = swizzle.fComponents.size(); if (count == 1) { - this->writeInstruction(SpvOpCompositeExtract, this->getType(*swizzle.fType), result, base, + this->writeInstruction(SpvOpCompositeExtract, this->getType(swizzle.fType), result, base, swizzle.fComponents[0], out); } else { this->writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) count, out); - this->writeWord(this->getType(*swizzle.fType), out); + this->writeWord(this->getType(swizzle.fType), out); this->writeWord(result, out); this->writeWord(base, out); this->writeWord(base, out); @@ -1809,13 +1813,13 @@ SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType, SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt, SpvOp_ ifUInt, SpvOp_ ifBool, std::ostream& out) { SpvId result = this->nextId(); - if (is_float(operandType)) { + if (is_float(fContext, operandType)) { this->writeInstruction(ifFloat, this->getType(resultType), result, lhs, rhs, out); - } else if (is_signed(operandType)) { + } else if (is_signed(fContext, operandType)) { this->writeInstruction(ifInt, this->getType(resultType), result, lhs, rhs, out); - } else if (is_unsigned(operandType)) { + } else if (is_unsigned(fContext, operandType)) { this->writeInstruction(ifUInt, this->getType(resultType), result, lhs, rhs, out); - } else if (operandType == *kBool_Type) { + } else if (operandType == *fContext.fBool_Type) { this->writeInstruction(ifBool, this->getType(resultType), result, lhs, rhs, out); } else { ABORT("invalid operandType: %s", operandType.description().c_str()); @@ -1862,7 +1866,7 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(BinaryExpression& b, std::ostrea } // "normal" operators - const Type& resultType = *b.fType; + const Type& resultType = b.fType; std::unique_ptr<LValue> lvalue; SpvId lhs; if (is_assignment(b.fOperator)) { @@ -1878,23 +1882,23 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(BinaryExpression& b, std::ostrea // IR allows mismatched types in expressions (e.g. vec2 * float), but they need special handling // in SPIR-V if (b.fLeft->fType != b.fRight->fType) { - if (b.fLeft->fType->kind() == Type::kVector_Kind && - b.fRight->fType->isNumber()) { + if (b.fLeft->fType.kind() == Type::kVector_Kind && + b.fRight->fType.isNumber()) { // promote number to vector SpvId vec = this->nextId(); - this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType->columns(), out); + this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType.columns(), out); this->writeWord(this->getType(resultType), out); this->writeWord(vec, out); for (int i = 0; i < resultType.columns(); i++) { this->writeWord(rhs, out); } rhs = vec; - operandType = b.fRight->fType.get(); - } else if (b.fRight->fType->kind() == Type::kVector_Kind && - b.fLeft->fType->isNumber()) { + operandType = &b.fRight->fType; + } else if (b.fRight->fType.kind() == Type::kVector_Kind && + b.fLeft->fType.isNumber()) { // promote number to vector SpvId vec = this->nextId(); - this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType->columns(), out); + this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType.columns(), out); this->writeWord(this->getType(resultType), out); this->writeWord(vec, out); for (int i = 0; i < resultType.columns(); i++) { @@ -1902,33 +1906,33 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(BinaryExpression& b, std::ostrea } lhs = vec; ASSERT(!lvalue); - operandType = b.fLeft->fType.get(); - } else if (b.fLeft->fType->kind() == Type::kMatrix_Kind) { + operandType = &b.fLeft->fType; + } else if (b.fLeft->fType.kind() == Type::kMatrix_Kind) { SpvOp_ op; - if (b.fRight->fType->kind() == Type::kMatrix_Kind) { + if (b.fRight->fType.kind() == Type::kMatrix_Kind) { op = SpvOpMatrixTimesMatrix; - } else if (b.fRight->fType->kind() == Type::kVector_Kind) { + } else if (b.fRight->fType.kind() == Type::kVector_Kind) { op = SpvOpMatrixTimesVector; } else { - ASSERT(b.fRight->fType->kind() == Type::kScalar_Kind); + ASSERT(b.fRight->fType.kind() == Type::kScalar_Kind); op = SpvOpMatrixTimesScalar; } SpvId result = this->nextId(); - this->writeInstruction(op, this->getType(*b.fType), result, lhs, rhs, out); + this->writeInstruction(op, this->getType(b.fType), result, lhs, rhs, out); if (b.fOperator == Token::STAREQ) { lvalue->store(result, out); } else { ASSERT(b.fOperator == Token::STAR); } return result; - } else if (b.fRight->fType->kind() == Type::kMatrix_Kind) { + } else if (b.fRight->fType.kind() == Type::kMatrix_Kind) { SpvId result = this->nextId(); - if (b.fLeft->fType->kind() == Type::kVector_Kind) { - this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(*b.fType), result, + if (b.fLeft->fType.kind() == Type::kVector_Kind) { + this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(b.fType), result, lhs, rhs, out); } else { - ASSERT(b.fLeft->fType->kind() == Type::kScalar_Kind); - this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(*b.fType), result, rhs, + ASSERT(b.fLeft->fType.kind() == Type::kScalar_Kind); + this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(b.fType), result, rhs, lhs, out); } if (b.fOperator == Token::STAREQ) { @@ -1941,35 +1945,35 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(BinaryExpression& b, std::ostrea ABORT("unsupported binary expression: %s", b.description().c_str()); } } else { - operandType = b.fLeft->fType.get(); - ASSERT(*operandType == *b.fRight->fType); + operandType = &b.fLeft->fType; + ASSERT(*operandType == b.fRight->fType); } switch (b.fOperator) { case Token::EQEQ: - ASSERT(resultType == *kBool_Type); + ASSERT(resultType == *fContext.fBool_Type); return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdEqual, SpvOpIEqual, SpvOpIEqual, SpvOpLogicalEqual, out); case Token::NEQ: - ASSERT(resultType == *kBool_Type); + ASSERT(resultType == *fContext.fBool_Type); return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdNotEqual, SpvOpINotEqual, SpvOpINotEqual, SpvOpLogicalNotEqual, out); case Token::GT: - ASSERT(resultType == *kBool_Type); + ASSERT(resultType == *fContext.fBool_Type); return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdGreaterThan, SpvOpSGreaterThan, SpvOpUGreaterThan, SpvOpUndef, out); case Token::LT: - ASSERT(resultType == *kBool_Type); + ASSERT(resultType == *fContext.fBool_Type); return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdLessThan, SpvOpSLessThan, SpvOpULessThan, SpvOpUndef, out); case Token::GTEQ: - ASSERT(resultType == *kBool_Type); + ASSERT(resultType == *fContext.fBool_Type); return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdGreaterThanEqual, SpvOpSGreaterThanEqual, SpvOpUGreaterThanEqual, SpvOpUndef, out); case Token::LTEQ: - ASSERT(resultType == *kBool_Type); + ASSERT(resultType == *fContext.fBool_Type); return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual, SpvOpULessThanEqual, SpvOpUndef, out); @@ -1980,8 +1984,8 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(BinaryExpression& b, std::ostrea return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub, SpvOpISub, SpvOpISub, SpvOpUndef, out); case Token::STAR: - if (b.fLeft->fType->kind() == Type::kMatrix_Kind && - b.fRight->fType->kind() == Type::kMatrix_Kind) { + if (b.fLeft->fType.kind() == Type::kMatrix_Kind && + b.fRight->fType.kind() == Type::kMatrix_Kind) { // matrix multiply SpvId result = this->nextId(); this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result, @@ -2008,8 +2012,8 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(BinaryExpression& b, std::ostrea return result; } case Token::STAREQ: { - if (b.fLeft->fType->kind() == Type::kMatrix_Kind && - b.fRight->fType->kind() == Type::kMatrix_Kind) { + if (b.fLeft->fType.kind() == Type::kMatrix_Kind && + b.fRight->fType.kind() == Type::kMatrix_Kind) { // matrix multiply SpvId result = this->nextId(); this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result, @@ -2039,7 +2043,7 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(BinaryExpression& b, std::ostrea SpvId SPIRVCodeGenerator::writeLogicalAnd(BinaryExpression& a, std::ostream& out) { ASSERT(a.fOperator == Token::LOGICALAND); - BoolLiteral falseLiteral(Position(), false); + BoolLiteral falseLiteral(fContext, Position(), false); SpvId falseConstant = this->writeBoolLiteral(falseLiteral); SpvId lhs = this->writeExpression(*a.fLeft, out); SpvId rhsLabel = this->nextId(); @@ -2053,14 +2057,14 @@ SpvId SPIRVCodeGenerator::writeLogicalAnd(BinaryExpression& a, std::ostream& out this->writeInstruction(SpvOpBranch, end, out); this->writeLabel(end, out); SpvId result = this->nextId(); - this->writeInstruction(SpvOpPhi, this->getType(*kBool_Type), result, falseConstant, lhsBlock, - rhs, rhsBlock, out); + this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, falseConstant, + lhsBlock, rhs, rhsBlock, out); return result; } SpvId SPIRVCodeGenerator::writeLogicalOr(BinaryExpression& o, std::ostream& out) { ASSERT(o.fOperator == Token::LOGICALOR); - BoolLiteral trueLiteral(Position(), true); + BoolLiteral trueLiteral(fContext, Position(), true); SpvId trueConstant = this->writeBoolLiteral(trueLiteral); SpvId lhs = this->writeExpression(*o.fLeft, out); SpvId rhsLabel = this->nextId(); @@ -2074,8 +2078,8 @@ SpvId SPIRVCodeGenerator::writeLogicalOr(BinaryExpression& o, std::ostream& out) this->writeInstruction(SpvOpBranch, end, out); this->writeLabel(end, out); SpvId result = this->nextId(); - this->writeInstruction(SpvOpPhi, this->getType(*kBool_Type), result, trueConstant, lhsBlock, - rhs, rhsBlock, out); + this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, trueConstant, + lhsBlock, rhs, rhsBlock, out); return result; } @@ -2086,7 +2090,7 @@ SpvId SPIRVCodeGenerator::writeTernaryExpression(TernaryExpression& t, std::ostr SpvId result = this->nextId(); SpvId trueId = this->writeExpression(*t.fIfTrue, out); SpvId falseId = this->writeExpression(*t.fIfFalse, out); - this->writeInstruction(SpvOpSelect, this->getType(*t.fType), result, test, trueId, falseId, + this->writeInstruction(SpvOpSelect, this->getType(t.fType), result, test, trueId, falseId, out); return result; } @@ -2094,7 +2098,7 @@ SpvId SPIRVCodeGenerator::writeTernaryExpression(TernaryExpression& t, std::ostr // Adreno. Switched to storing the result in a temp variable as glslang does. SpvId var = this->nextId(); this->writeInstruction(SpvOpVariable, this->getPointerType(t.fType, SpvStorageClassFunction), - var, SpvStorageClassFunction, out); + var, SpvStorageClassFunction, fVariableBuffer); SpvId trueLabel = this->nextId(); SpvId falseLabel = this->nextId(); SpvId end = this->nextId(); @@ -2108,18 +2112,16 @@ SpvId SPIRVCodeGenerator::writeTernaryExpression(TernaryExpression& t, std::ostr this->writeInstruction(SpvOpBranch, end, out); this->writeLabel(end, out); SpvId result = this->nextId(); - this->writeInstruction(SpvOpLoad, this->getType(*t.fType), result, var, out); + this->writeInstruction(SpvOpLoad, this->getType(t.fType), result, var, out); return result; } -Expression* literal_1(const Type& type) { - static IntLiteral int1(Position(), 1); - static FloatLiteral float1(Position(), 1.0); - if (type == *kInt_Type) { - return &int1; +std::unique_ptr<Expression> create_literal_1(const Context& context, const Type& type) { + if (type == *context.fInt_Type) { + return std::unique_ptr<Expression>(new IntLiteral(context, Position(), 1)); } - else if (type == *kFloat_Type) { - return &float1; + else if (type == *context.fFloat_Type) { + return std::unique_ptr<Expression>(new FloatLiteral(context, Position(), 1.0)); } else { ABORT("math is unsupported on type '%s'") } @@ -2128,11 +2130,11 @@ Expression* literal_1(const Type& type) { SpvId SPIRVCodeGenerator::writePrefixExpression(PrefixExpression& p, std::ostream& out) { if (p.fOperator == Token::MINUS) { SpvId result = this->nextId(); - SpvId typeId = this->getType(*p.fType); + SpvId typeId = this->getType(p.fType); SpvId expr = this->writeExpression(*p.fOperand, out); - if (is_float(*p.fType)) { + if (is_float(fContext, p.fType)) { this->writeInstruction(SpvOpFNegate, typeId, result, expr, out); - } else if (is_signed(*p.fType)) { + } else if (is_signed(fContext, p.fType)) { this->writeInstruction(SpvOpSNegate, typeId, result, expr, out); } else { ABORT("unsupported prefix expression %s", p.description().c_str()); @@ -2144,8 +2146,8 @@ SpvId SPIRVCodeGenerator::writePrefixExpression(PrefixExpression& p, std::ostrea return this->writeExpression(*p.fOperand, out); case Token::PLUSPLUS: { std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out); - SpvId one = this->writeExpression(*literal_1(*p.fType), out); - SpvId result = this->writeBinaryOperation(*p.fType, *p.fType, lv->load(out), one, + SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out); + SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one, SpvOpFAdd, SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out); lv->store(result, out); @@ -2153,17 +2155,17 @@ SpvId SPIRVCodeGenerator::writePrefixExpression(PrefixExpression& p, std::ostrea } case Token::MINUSMINUS: { std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out); - SpvId one = this->writeExpression(*literal_1(*p.fType), out); - SpvId result = this->writeBinaryOperation(*p.fType, *p.fType, lv->load(out), one, + SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out); + SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one, SpvOpFSub, SpvOpISub, SpvOpISub, SpvOpUndef, out); lv->store(result, out); return result; } case Token::NOT: { - ASSERT(p.fOperand->fType == kBool_Type); + ASSERT(p.fOperand->fType == *fContext.fBool_Type); SpvId result = this->nextId(); - this->writeInstruction(SpvOpLogicalNot, this->getType(*p.fOperand->fType), result, + this->writeInstruction(SpvOpLogicalNot, this->getType(p.fOperand->fType), result, this->writeExpression(*p.fOperand, out), out); return result; } @@ -2175,16 +2177,16 @@ SpvId SPIRVCodeGenerator::writePrefixExpression(PrefixExpression& p, std::ostrea SpvId SPIRVCodeGenerator::writePostfixExpression(PostfixExpression& p, std::ostream& out) { std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out); SpvId result = lv->load(out); - SpvId one = this->writeExpression(*literal_1(*p.fType), out); + SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out); switch (p.fOperator) { case Token::PLUSPLUS: { - SpvId temp = this->writeBinaryOperation(*p.fType, *p.fType, result, one, SpvOpFAdd, + SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFAdd, SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out); lv->store(temp, out); return result; } case Token::MINUSMINUS: { - SpvId temp = this->writeBinaryOperation(*p.fType, *p.fType, result, one, SpvOpFSub, + SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFSub, SpvOpISub, SpvOpISub, SpvOpUndef, out); lv->store(temp, out); return result; @@ -2198,14 +2200,14 @@ SpvId SPIRVCodeGenerator::writeBoolLiteral(BoolLiteral& b) { if (b.fValue) { if (fBoolTrue == 0) { fBoolTrue = this->nextId(); - this->writeInstruction(SpvOpConstantTrue, this->getType(*b.fType), fBoolTrue, + this->writeInstruction(SpvOpConstantTrue, this->getType(b.fType), fBoolTrue, fConstantBuffer); } return fBoolTrue; } else { if (fBoolFalse == 0) { fBoolFalse = this->nextId(); - this->writeInstruction(SpvOpConstantFalse, this->getType(*b.fType), fBoolFalse, + this->writeInstruction(SpvOpConstantFalse, this->getType(b.fType), fBoolFalse, fConstantBuffer); } return fBoolFalse; @@ -2213,22 +2215,22 @@ SpvId SPIRVCodeGenerator::writeBoolLiteral(BoolLiteral& b) { } SpvId SPIRVCodeGenerator::writeIntLiteral(IntLiteral& i) { - if (i.fType == kInt_Type) { + if (i.fType == *fContext.fInt_Type) { auto entry = fIntConstants.find(i.fValue); if (entry == fIntConstants.end()) { SpvId result = this->nextId(); - this->writeInstruction(SpvOpConstant, this->getType(*i.fType), result, (SpvId) i.fValue, + this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue, fConstantBuffer); fIntConstants[i.fValue] = result; return result; } return entry->second; } else { - ASSERT(i.fType == kUInt_Type); + ASSERT(i.fType == *fContext.fUInt_Type); auto entry = fUIntConstants.find(i.fValue); if (entry == fUIntConstants.end()) { SpvId result = this->nextId(); - this->writeInstruction(SpvOpConstant, this->getType(*i.fType), result, (SpvId) i.fValue, + this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue, fConstantBuffer); fUIntConstants[i.fValue] = result; return result; @@ -2238,7 +2240,7 @@ SpvId SPIRVCodeGenerator::writeIntLiteral(IntLiteral& i) { } SpvId SPIRVCodeGenerator::writeFloatLiteral(FloatLiteral& f) { - if (f.fType == kFloat_Type) { + if (f.fType == *fContext.fFloat_Type) { float value = (float) f.fValue; auto entry = fFloatConstants.find(value); if (entry == fFloatConstants.end()) { @@ -2246,21 +2248,21 @@ SpvId SPIRVCodeGenerator::writeFloatLiteral(FloatLiteral& f) { uint32_t bits; ASSERT(sizeof(bits) == sizeof(value)); memcpy(&bits, &value, sizeof(bits)); - this->writeInstruction(SpvOpConstant, this->getType(*f.fType), result, bits, + this->writeInstruction(SpvOpConstant, this->getType(f.fType), result, bits, fConstantBuffer); fFloatConstants[value] = result; return result; } return entry->second; } else { - ASSERT(f.fType == kDouble_Type); + ASSERT(f.fType == *fContext.fDouble_Type); auto entry = fDoubleConstants.find(f.fValue); if (entry == fDoubleConstants.end()) { SpvId result = this->nextId(); uint64_t bits; ASSERT(sizeof(bits) == sizeof(f.fValue)); memcpy(&bits, &f.fValue, sizeof(bits)); - this->writeInstruction(SpvOpConstant, this->getType(*f.fType), result, + this->writeInstruction(SpvOpConstant, this->getType(f.fType), result, bits & 0xffffffff, bits >> 32, fConstantBuffer); fDoubleConstants[f.fValue] = result; return result; @@ -2269,26 +2271,25 @@ SpvId SPIRVCodeGenerator::writeFloatLiteral(FloatLiteral& f) { } } -SpvId SPIRVCodeGenerator::writeFunctionStart(std::shared_ptr<FunctionDeclaration> f, - std::ostream& out) { - SpvId result = fFunctionMap[f]; - this->writeInstruction(SpvOpFunction, this->getType(*f->fReturnType), result, +SpvId SPIRVCodeGenerator::writeFunctionStart(const FunctionDeclaration& f, std::ostream& out) { + SpvId result = fFunctionMap[&f]; + this->writeInstruction(SpvOpFunction, this->getType(f.fReturnType), result, SpvFunctionControlMaskNone, this->getFunctionType(f), out); - this->writeInstruction(SpvOpName, result, f->fName.c_str(), fNameBuffer); - for (size_t i = 0; i < f->fParameters.size(); i++) { + this->writeInstruction(SpvOpName, result, f.fName.c_str(), fNameBuffer); + for (size_t i = 0; i < f.fParameters.size(); i++) { SpvId id = this->nextId(); - fVariableMap[f->fParameters[i]] = id; + fVariableMap[f.fParameters[i]] = id; SpvId type; - type = this->getPointerType(f->fParameters[i]->fType, SpvStorageClassFunction); + type = this->getPointerType(f.fParameters[i]->fType, SpvStorageClassFunction); this->writeInstruction(SpvOpFunctionParameter, type, id, out); } return result; } -SpvId SPIRVCodeGenerator::writeFunction(FunctionDefinition& f, std::ostream& out) { +SpvId SPIRVCodeGenerator::writeFunction(const FunctionDefinition& f, std::ostream& out) { SpvId result = this->writeFunctionStart(f.fDeclaration, out); this->writeLabel(this->nextId(), out); - if (f.fDeclaration->fName == "main") { + if (f.fDeclaration.fName == "main") { out << fGlobalInitializersBuffer.str(); } std::stringstream bodyBuffer; @@ -2350,21 +2351,26 @@ void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target, int mem } SpvId SPIRVCodeGenerator::writeInterfaceBlock(InterfaceBlock& intf) { - SpvId type = this->getType(*intf.fVariable->fType); + SpvId type = this->getType(intf.fVariable.fType); SpvId result = this->nextId(); this->writeInstruction(SpvOpDecorate, type, SpvDecorationBlock, fDecorationBuffer); - SpvStorageClass_ storageClass = get_storage_class(intf.fVariable->fModifiers); + SpvStorageClass_ storageClass = get_storage_class(intf.fVariable.fModifiers); SpvId ptrType = this->nextId(); this->writeInstruction(SpvOpTypePointer, ptrType, storageClass, type, fConstantBuffer); this->writeInstruction(SpvOpVariable, ptrType, result, storageClass, fConstantBuffer); - this->writeLayout(intf.fVariable->fModifiers.fLayout, result); - fVariableMap[intf.fVariable] = result; + this->writeLayout(intf.fVariable.fModifiers.fLayout, result); + fVariableMap[&intf.fVariable] = result; return result; } void SPIRVCodeGenerator::writeGlobalVars(VarDeclaration& decl, std::ostream& out) { for (size_t i = 0; i < decl.fVars.size(); i++) { - if (!decl.fVars[i]->fIsReadFrom && !decl.fVars[i]->fIsWrittenTo) { + if (!decl.fVars[i]->fIsReadFrom && !decl.fVars[i]->fIsWrittenTo && + !(decl.fVars[i]->fModifiers.fFlags & (Modifiers::kIn_Flag | + Modifiers::kOut_Flag | + Modifiers::kUniform_Flag))) { + // variable is dead and not an input / output var (the Vulkan debug layers complain if + // we elide an interface var, even if it's dead) continue; } SpvStorageClass_ storageClass; @@ -2373,7 +2379,7 @@ void SPIRVCodeGenerator::writeGlobalVars(VarDeclaration& decl, std::ostream& out } else if (decl.fVars[i]->fModifiers.fFlags & Modifiers::kOut_Flag) { storageClass = SpvStorageClassOutput; } else if (decl.fVars[i]->fModifiers.fFlags & Modifiers::kUniform_Flag) { - if (decl.fVars[i]->fType->kind() == Type::kSampler_Kind) { + if (decl.fVars[i]->fType.kind() == Type::kSampler_Kind) { storageClass = SpvStorageClassUniformConstant; } else { storageClass = SpvStorageClassUniform; @@ -2386,11 +2392,11 @@ void SPIRVCodeGenerator::writeGlobalVars(VarDeclaration& decl, std::ostream& out SpvId type = this->getPointerType(decl.fVars[i]->fType, storageClass); this->writeInstruction(SpvOpVariable, type, id, storageClass, fConstantBuffer); this->writeInstruction(SpvOpName, id, decl.fVars[i]->fName.c_str(), fNameBuffer); - if (decl.fVars[i]->fType->kind() == Type::kMatrix_Kind) { + if (decl.fVars[i]->fType.kind() == Type::kMatrix_Kind) { this->writeInstruction(SpvOpMemberDecorate, id, (SpvId) i, SpvDecorationColMajor, fDecorationBuffer); this->writeInstruction(SpvOpMemberDecorate, id, (SpvId) i, SpvDecorationMatrixStride, - (SpvId) decl.fVars[i]->fType->stride(), fDecorationBuffer); + (SpvId) decl.fVars[i]->fType.stride(), fDecorationBuffer); } if (decl.fValues[i]) { ASSERT(!fCurrentBlock); @@ -2538,15 +2544,15 @@ void SPIRVCodeGenerator::writeInstructions(Program& program, std::ostream& out) for (size_t i = 0; i < program.fElements.size(); i++) { if (program.fElements[i]->fKind == ProgramElement::kFunction_Kind) { FunctionDefinition& f = (FunctionDefinition&) *program.fElements[i]; - fFunctionMap[f.fDeclaration] = this->nextId(); + fFunctionMap[&f.fDeclaration] = this->nextId(); } } for (size_t i = 0; i < program.fElements.size(); i++) { if (program.fElements[i]->fKind == ProgramElement::kInterfaceBlock_Kind) { InterfaceBlock& intf = (InterfaceBlock&) *program.fElements[i]; SpvId id = this->writeInterfaceBlock(intf); - if ((intf.fVariable->fModifiers.fFlags & Modifiers::kIn_Flag) || - (intf.fVariable->fModifiers.fFlags & Modifiers::kOut_Flag)) { + if ((intf.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) || + (intf.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag)) { interfaceVars.push_back(id); } } @@ -2561,7 +2567,7 @@ void SPIRVCodeGenerator::writeInstructions(Program& program, std::ostream& out) this->writeFunction(((FunctionDefinition&) *program.fElements[i]), body); } } - std::shared_ptr<FunctionDeclaration> main = nullptr; + const FunctionDeclaration* main = nullptr; for (auto entry : fFunctionMap) { if (entry.first->fName == "main") { main = entry.first; @@ -2569,7 +2575,7 @@ void SPIRVCodeGenerator::writeInstructions(Program& program, std::ostream& out) } ASSERT(main); for (auto entry : fVariableMap) { - std::shared_ptr<Variable> var = entry.first; + const Variable* var = entry.first; if (var->fStorage == Variable::kGlobal_Storage && ((var->fModifiers.fFlags & Modifiers::kIn_Flag) || (var->fModifiers.fFlags & Modifiers::kOut_Flag))) { diff --git a/src/sksl/SkSLSPIRVCodeGenerator.h b/src/sksl/SkSLSPIRVCodeGenerator.h index 885c6b8b70..a20ad9f40b 100644 --- a/src/sksl/SkSLSPIRVCodeGenerator.h +++ b/src/sksl/SkSLSPIRVCodeGenerator.h @@ -61,8 +61,9 @@ public: virtual void store(SpvId value, std::ostream& out) = 0; }; - SPIRVCodeGenerator() - : fCapabilities(1 << SpvCapabilityShader) + SPIRVCodeGenerator(const Context* context) + : fContext(*context) + , fCapabilities(1 << SpvCapabilityShader) , fIdCount(1) , fBoolTrue(0) , fBoolFalse(0) @@ -92,9 +93,9 @@ private: SpvId getType(const Type& type); - SpvId getFunctionType(std::shared_ptr<FunctionDeclaration> function); + SpvId getFunctionType(const FunctionDeclaration& function); - SpvId getPointerType(std::shared_ptr<Type> type, SpvStorageClass_ storageClass); + SpvId getPointerType(const Type& type, SpvStorageClass_ storageClass); std::vector<SpvId> getAccessChain(Expression& expr, std::ostream& out); @@ -108,11 +109,11 @@ private: SpvId writeInterfaceBlock(InterfaceBlock& intf); - SpvId writeFunctionStart(std::shared_ptr<FunctionDeclaration> f, std::ostream& out); + SpvId writeFunctionStart(const FunctionDeclaration& f, std::ostream& out); - SpvId writeFunctionDeclaration(std::shared_ptr<FunctionDeclaration> f, std::ostream& out); + SpvId writeFunctionDeclaration(const FunctionDeclaration& f, std::ostream& out); - SpvId writeFunction(FunctionDefinition& f, std::ostream& out); + SpvId writeFunction(const FunctionDefinition& f, std::ostream& out); void writeGlobalVars(VarDeclaration& v, std::ostream& out); @@ -227,14 +228,16 @@ private: int32_t word5, int32_t word6, int32_t word7, int32_t word8, std::ostream& out); + const Context& fContext; + uint64_t fCapabilities; SpvId fIdCount; SpvId fGLSLExtendedInstructions; typedef std::tuple<IntrinsicKind, int32_t, int32_t, int32_t, int32_t> Intrinsic; std::unordered_map<std::string, Intrinsic> fIntrinsicMap; - std::unordered_map<std::shared_ptr<FunctionDeclaration>, SpvId> fFunctionMap; - std::unordered_map<std::shared_ptr<Variable>, SpvId> fVariableMap; - std::unordered_map<std::shared_ptr<Variable>, int32_t> fInterfaceBlockMap; + std::unordered_map<const FunctionDeclaration*, SpvId> fFunctionMap; + std::unordered_map<const Variable*, SpvId> fVariableMap; + std::unordered_map<const Variable*, int32_t> fInterfaceBlockMap; std::unordered_map<std::string, SpvId> fTypeMap; std::stringstream fCapabilitiesBuffer; std::stringstream fGlobalInitializersBuffer; diff --git a/src/sksl/ir/SkSLBinaryExpression.h b/src/sksl/ir/SkSLBinaryExpression.h index bd89d6c602..9ecdbc717c 100644 --- a/src/sksl/ir/SkSLBinaryExpression.h +++ b/src/sksl/ir/SkSLBinaryExpression.h @@ -18,7 +18,7 @@ namespace SkSL { */ struct BinaryExpression : public Expression { BinaryExpression(Position position, std::unique_ptr<Expression> left, Token::Kind op, - std::unique_ptr<Expression> right, std::shared_ptr<Type> type) + std::unique_ptr<Expression> right, const Type& type) : INHERITED(position, kBinary_Kind, type) , fLeft(std::move(left)) , fOperator(op) diff --git a/src/sksl/ir/SkSLBlock.h b/src/sksl/ir/SkSLBlock.h index 56ed77a0ba..a53d13d169 100644 --- a/src/sksl/ir/SkSLBlock.h +++ b/src/sksl/ir/SkSLBlock.h @@ -9,6 +9,7 @@ #define SKSL_BLOCK #include "SkSLStatement.h" +#include "SkSLSymbolTable.h" namespace SkSL { @@ -16,9 +17,11 @@ namespace SkSL { * A block of multiple statements functioning as a single statement. */ struct Block : public Statement { - Block(Position position, std::vector<std::unique_ptr<Statement>> statements) + Block(Position position, std::vector<std::unique_ptr<Statement>> statements, + const std::shared_ptr<SymbolTable> symbols) : INHERITED(position, kBlock_Kind) - , fStatements(std::move(statements)) {} + , fStatements(std::move(statements)) + , fSymbols(std::move(symbols)) {} std::string description() const override { std::string result = "{"; @@ -31,6 +34,7 @@ struct Block : public Statement { } const std::vector<std::unique_ptr<Statement>> fStatements; + const std::shared_ptr<SymbolTable> fSymbols; typedef Statement INHERITED; }; diff --git a/src/sksl/ir/SkSLBoolLiteral.h b/src/sksl/ir/SkSLBoolLiteral.h index 3c40e59514..8f55a69311 100644 --- a/src/sksl/ir/SkSLBoolLiteral.h +++ b/src/sksl/ir/SkSLBoolLiteral.h @@ -8,6 +8,7 @@ #ifndef SKSL_BOOLLITERAL #define SKSL_BOOLLITERAL +#include "SkSLContext.h" #include "SkSLExpression.h" namespace SkSL { @@ -16,8 +17,8 @@ namespace SkSL { * Represents 'true' or 'false'. */ struct BoolLiteral : public Expression { - BoolLiteral(Position position, bool value) - : INHERITED(position, kBoolLiteral_Kind, kBool_Type) + BoolLiteral(const Context& context, Position position, bool value) + : INHERITED(position, kBoolLiteral_Kind, *context.fBool_Type) , fValue(value) {} std::string description() const override { diff --git a/src/sksl/ir/SkSLConstructor.h b/src/sksl/ir/SkSLConstructor.h index c58da7e5b8..0501b651ea 100644 --- a/src/sksl/ir/SkSLConstructor.h +++ b/src/sksl/ir/SkSLConstructor.h @@ -16,13 +16,13 @@ namespace SkSL { * Represents the construction of a compound type, such as "vec2(x, y)". */ struct Constructor : public Expression { - Constructor(Position position, std::shared_ptr<Type> type, + Constructor(Position position, const Type& type, std::vector<std::unique_ptr<Expression>> arguments) - : INHERITED(position, kConstructor_Kind, std::move(type)) + : INHERITED(position, kConstructor_Kind, type) , fArguments(std::move(arguments)) {} std::string description() const override { - std::string result = fType->description() + "("; + std::string result = fType.description() + "("; std::string separator = ""; for (size_t i = 0; i < fArguments.size(); i++) { result += separator; diff --git a/src/sksl/ir/SkSLExpression.h b/src/sksl/ir/SkSLExpression.h index 1e42c7a475..92cb37de77 100644 --- a/src/sksl/ir/SkSLExpression.h +++ b/src/sksl/ir/SkSLExpression.h @@ -35,7 +35,7 @@ struct Expression : public IRNode { kTypeReference_Kind, }; - Expression(Position position, Kind kind, std::shared_ptr<Type> type) + Expression(Position position, Kind kind, const Type& type) : INHERITED(position) , fKind(kind) , fType(std::move(type)) {} @@ -45,7 +45,7 @@ struct Expression : public IRNode { } const Kind fKind; - const std::shared_ptr<Type> fType; + const Type& fType; typedef IRNode INHERITED; }; diff --git a/src/sksl/ir/SkSLField.h b/src/sksl/ir/SkSLField.h index f2b68bc2bc..a01df2943d 100644 --- a/src/sksl/ir/SkSLField.h +++ b/src/sksl/ir/SkSLField.h @@ -21,16 +21,16 @@ namespace SkSL { * result of declaring anonymous interface blocks. */ struct Field : public Symbol { - Field(Position position, std::shared_ptr<Variable> owner, int fieldIndex) - : INHERITED(position, kField_Kind, owner->fType->fields()[fieldIndex].fName) + Field(Position position, const Variable& owner, int fieldIndex) + : INHERITED(position, kField_Kind, owner.fType.fields()[fieldIndex].fName) , fOwner(owner) , fFieldIndex(fieldIndex) {} virtual std::string description() const override { - return fOwner->description() + "." + fOwner->fType->fields()[fFieldIndex].fName; + return fOwner.description() + "." + fOwner.fType.fields()[fFieldIndex].fName; } - const std::shared_ptr<Variable> fOwner; + const Variable& fOwner; const int fFieldIndex; typedef Symbol INHERITED; diff --git a/src/sksl/ir/SkSLFieldAccess.h b/src/sksl/ir/SkSLFieldAccess.h index 053498e154..f09c3a3447 100644 --- a/src/sksl/ir/SkSLFieldAccess.h +++ b/src/sksl/ir/SkSLFieldAccess.h @@ -18,12 +18,12 @@ namespace SkSL { */ struct FieldAccess : public Expression { FieldAccess(std::unique_ptr<Expression> base, int fieldIndex) - : INHERITED(base->fPosition, kFieldAccess_Kind, base->fType->fields()[fieldIndex].fType) + : INHERITED(base->fPosition, kFieldAccess_Kind, base->fType.fields()[fieldIndex].fType) , fBase(std::move(base)) , fFieldIndex(fieldIndex) {} virtual std::string description() const override { - return fBase->description() + "." + fBase->fType->fields()[fFieldIndex].fName; + return fBase->description() + "." + fBase->fType.fields()[fFieldIndex].fName; } const std::unique_ptr<Expression> fBase; diff --git a/src/sksl/ir/SkSLFloatLiteral.h b/src/sksl/ir/SkSLFloatLiteral.h index deb5b27144..d9c8b6538a 100644 --- a/src/sksl/ir/SkSLFloatLiteral.h +++ b/src/sksl/ir/SkSLFloatLiteral.h @@ -8,6 +8,7 @@ #ifndef SKSL_FLOATLITERAL #define SKSL_FLOATLITERAL +#include "SkSLContext.h" #include "SkSLExpression.h" namespace SkSL { @@ -16,8 +17,8 @@ namespace SkSL { * A literal floating point number. */ struct FloatLiteral : public Expression { - FloatLiteral(Position position, double value) - : INHERITED(position, kFloatLiteral_Kind, kFloat_Type) + FloatLiteral(const Context& context, Position position, double value) + : INHERITED(position, kFloatLiteral_Kind, *context.fFloat_Type) , fValue(value) {} virtual std::string description() const override { diff --git a/src/sksl/ir/SkSLForStatement.h b/src/sksl/ir/SkSLForStatement.h index 70bb4014c8..642d15125e 100644 --- a/src/sksl/ir/SkSLForStatement.h +++ b/src/sksl/ir/SkSLForStatement.h @@ -10,6 +10,7 @@ #include "SkSLExpression.h" #include "SkSLStatement.h" +#include "SkSLSymbolTable.h" namespace SkSL { @@ -19,12 +20,13 @@ namespace SkSL { struct ForStatement : public Statement { ForStatement(Position position, std::unique_ptr<Statement> initializer, std::unique_ptr<Expression> test, std::unique_ptr<Expression> next, - std::unique_ptr<Statement> statement) + std::unique_ptr<Statement> statement, std::shared_ptr<SymbolTable> symbols) : INHERITED(position, kFor_Kind) , fInitializer(std::move(initializer)) , fTest(std::move(test)) , fNext(std::move(next)) - , fStatement(std::move(statement)) {} + , fStatement(std::move(statement)) + , fSymbols(symbols) {} std::string description() const override { std::string result = "for ("; @@ -47,6 +49,7 @@ struct ForStatement : public Statement { const std::unique_ptr<Expression> fTest; const std::unique_ptr<Expression> fNext; const std::unique_ptr<Statement> fStatement; + const std::shared_ptr<SymbolTable> fSymbols; typedef Statement INHERITED; }; diff --git a/src/sksl/ir/SkSLFunctionCall.h b/src/sksl/ir/SkSLFunctionCall.h index 78d2566227..85dba40f2a 100644 --- a/src/sksl/ir/SkSLFunctionCall.h +++ b/src/sksl/ir/SkSLFunctionCall.h @@ -17,14 +17,14 @@ namespace SkSL { * A function invocation. */ struct FunctionCall : public Expression { - FunctionCall(Position position, std::shared_ptr<FunctionDeclaration> function, + FunctionCall(Position position, const FunctionDeclaration& function, std::vector<std::unique_ptr<Expression>> arguments) - : INHERITED(position, kFunctionCall_Kind, function->fReturnType) + : INHERITED(position, kFunctionCall_Kind, function.fReturnType) , fFunction(std::move(function)) , fArguments(std::move(arguments)) {} std::string description() const override { - std::string result = fFunction->fName + "("; + std::string result = fFunction.fName + "("; std::string separator = ""; for (size_t i = 0; i < fArguments.size(); i++) { result += separator; @@ -35,7 +35,7 @@ struct FunctionCall : public Expression { return result; } - const std::shared_ptr<FunctionDeclaration> fFunction; + const FunctionDeclaration& fFunction; const std::vector<std::unique_ptr<Expression>> fArguments; typedef Expression INHERITED; diff --git a/src/sksl/ir/SkSLFunctionDeclaration.h b/src/sksl/ir/SkSLFunctionDeclaration.h index 32c23f545e..16a184a6d7 100644 --- a/src/sksl/ir/SkSLFunctionDeclaration.h +++ b/src/sksl/ir/SkSLFunctionDeclaration.h @@ -10,6 +10,7 @@ #include "SkSLModifiers.h" #include "SkSLSymbol.h" +#include "SkSLSymbolTable.h" #include "SkSLType.h" #include "SkSLVariable.h" @@ -20,15 +21,14 @@ namespace SkSL { */ struct FunctionDeclaration : public Symbol { FunctionDeclaration(Position position, std::string name, - std::vector<std::shared_ptr<Variable>> parameters, - std::shared_ptr<Type> returnType) + std::vector<const Variable*> parameters, const Type& returnType) : INHERITED(position, kFunctionDeclaration_Kind, std::move(name)) , fDefined(false) - , fParameters(parameters) + , fParameters(std::move(parameters)) , fReturnType(returnType) {} std::string description() const override { - std::string result = fReturnType->description() + " " + fName + "("; + std::string result = fReturnType.description() + " " + fName + "("; std::string separator = ""; for (auto p : fParameters) { result += separator; @@ -39,13 +39,24 @@ struct FunctionDeclaration : public Symbol { return result; } - bool matches(FunctionDeclaration& f) { - return fName == f.fName && fParameters == f.fParameters; + bool matches(const FunctionDeclaration& f) const { + if (fName != f.fName) { + return false; + } + if (fParameters.size() != f.fParameters.size()) { + return false; + } + for (size_t i = 0; i < fParameters.size(); i++) { + if (fParameters[i]->fType != f.fParameters[i]->fType) { + return false; + } + } + return true; } mutable bool fDefined; - const std::vector<std::shared_ptr<Variable>> fParameters; - const std::shared_ptr<Type> fReturnType; + const std::vector<const Variable*> fParameters; + const Type& fReturnType; typedef Symbol INHERITED; }; diff --git a/src/sksl/ir/SkSLFunctionDefinition.h b/src/sksl/ir/SkSLFunctionDefinition.h index fceb5474cb..ace27a3ed8 100644 --- a/src/sksl/ir/SkSLFunctionDefinition.h +++ b/src/sksl/ir/SkSLFunctionDefinition.h @@ -18,17 +18,17 @@ namespace SkSL { * A function definition (a declaration plus an associated block of code). */ struct FunctionDefinition : public ProgramElement { - FunctionDefinition(Position position, std::shared_ptr<FunctionDeclaration> declaration, + FunctionDefinition(Position position, const FunctionDeclaration& declaration, std::unique_ptr<Block> body) : INHERITED(position, kFunction_Kind) - , fDeclaration(std::move(declaration)) + , fDeclaration(declaration) , fBody(std::move(body)) {} std::string description() const override { - return fDeclaration->description() + " " + fBody->description(); + return fDeclaration.description() + " " + fBody->description(); } - const std::shared_ptr<FunctionDeclaration> fDeclaration; + const FunctionDeclaration& fDeclaration; const std::unique_ptr<Block> fBody; typedef ProgramElement INHERITED; diff --git a/src/sksl/ir/SkSLFunctionReference.h b/src/sksl/ir/SkSLFunctionReference.h index d5cc444000..5d97a5879f 100644 --- a/src/sksl/ir/SkSLFunctionReference.h +++ b/src/sksl/ir/SkSLFunctionReference.h @@ -8,6 +8,7 @@ #ifndef SKSL_FUNCTIONREFERENCE #define SKSL_FUNCTIONREFERENCE +#include "SkSLContext.h" #include "SkSLExpression.h" namespace SkSL { @@ -17,8 +18,9 @@ namespace SkSL { * always eventually replaced by FunctionCalls in valid programs. */ struct FunctionReference : public Expression { - FunctionReference(Position position, std::vector<std::shared_ptr<FunctionDeclaration>> function) - : INHERITED(position, kFunctionReference_Kind, kInvalid_Type) + FunctionReference(const Context& context, Position position, + std::vector<const FunctionDeclaration*> function) + : INHERITED(position, kFunctionReference_Kind, *context.fInvalid_Type) , fFunctions(function) {} virtual std::string description() const override { @@ -26,7 +28,7 @@ struct FunctionReference : public Expression { return "<function>"; } - const std::vector<std::shared_ptr<FunctionDeclaration>> fFunctions; + const std::vector<const FunctionDeclaration*> fFunctions; typedef Expression INHERITED; }; diff --git a/src/sksl/ir/SkSLIndexExpression.h b/src/sksl/ir/SkSLIndexExpression.h index 538c656153..f5b0d09c2c 100644 --- a/src/sksl/ir/SkSLIndexExpression.h +++ b/src/sksl/ir/SkSLIndexExpression.h @@ -16,21 +16,21 @@ namespace SkSL { /** * Given a type, returns the type that will result from extracting an array value from it. */ -static std::shared_ptr<Type> index_type(const Type& type) { +static const Type& index_type(const Context& context, const Type& type) { if (type.kind() == Type::kMatrix_Kind) { - if (type.componentType() == kFloat_Type) { + if (type.componentType() == *context.fFloat_Type) { switch (type.columns()) { - case 2: return kVec2_Type; - case 3: return kVec3_Type; - case 4: return kVec4_Type; + case 2: return *context.fVec2_Type; + case 3: return *context.fVec3_Type; + case 4: return *context.fVec4_Type; default: ASSERT(false); } } else { - ASSERT(type.componentType() == kDouble_Type); + ASSERT(type.componentType() == *context.fDouble_Type); switch (type.columns()) { - case 2: return kDVec2_Type; - case 3: return kDVec3_Type; - case 4: return kDVec4_Type; + case 2: return *context.fDVec2_Type; + case 3: return *context.fDVec3_Type; + case 4: return *context.fDVec4_Type; default: ASSERT(false); } } @@ -42,11 +42,12 @@ static std::shared_ptr<Type> index_type(const Type& type) { * An expression which extracts a value from an array or matrix, as in 'm[2]'. */ struct IndexExpression : public Expression { - IndexExpression(std::unique_ptr<Expression> base, std::unique_ptr<Expression> index) - : INHERITED(base->fPosition, kIndex_Kind, index_type(*base->fType)) + IndexExpression(const Context& context, std::unique_ptr<Expression> base, + std::unique_ptr<Expression> index) + : INHERITED(base->fPosition, kIndex_Kind, index_type(context, base->fType)) , fBase(std::move(base)) , fIndex(std::move(index)) { - ASSERT(fIndex->fType == kInt_Type); + ASSERT(fIndex->fType == *context.fInt_Type); } std::string description() const override { diff --git a/src/sksl/ir/SkSLIntLiteral.h b/src/sksl/ir/SkSLIntLiteral.h index 80b30d7c05..f2bf40b590 100644 --- a/src/sksl/ir/SkSLIntLiteral.h +++ b/src/sksl/ir/SkSLIntLiteral.h @@ -18,8 +18,8 @@ namespace SkSL { struct IntLiteral : public Expression { // FIXME: we will need to revisit this if/when we add full support for both signed and unsigned // 64-bit integers, but for right now an int64_t will hold every value we care about - IntLiteral(Position position, int64_t value) - : INHERITED(position, kIntLiteral_Kind, kInt_Type) + IntLiteral(const Context& context, Position position, int64_t value) + : INHERITED(position, kIntLiteral_Kind, *context.fInt_Type) , fValue(value) {} virtual std::string description() const override { diff --git a/src/sksl/ir/SkSLInterfaceBlock.h b/src/sksl/ir/SkSLInterfaceBlock.h index baedb5864c..f1121ed707 100644 --- a/src/sksl/ir/SkSLInterfaceBlock.h +++ b/src/sksl/ir/SkSLInterfaceBlock.h @@ -24,22 +24,24 @@ namespace SkSL { * At the IR level, this is represented by a single variable of struct type. */ struct InterfaceBlock : public ProgramElement { - InterfaceBlock(Position position, std::shared_ptr<Variable> var) + InterfaceBlock(Position position, const Variable& var, std::shared_ptr<SymbolTable> typeOwner) : INHERITED(position, kInterfaceBlock_Kind) - , fVariable(std::move(var)) { - ASSERT(fVariable->fType->kind() == Type::kStruct_Kind); + , fVariable(std::move(var)) + , fTypeOwner(typeOwner) { + ASSERT(fVariable.fType.kind() == Type::kStruct_Kind); } std::string description() const override { - std::string result = fVariable->fModifiers.description() + fVariable->fName + " {\n"; - for (size_t i = 0; i < fVariable->fType->fields().size(); i++) { - result += fVariable->fType->fields()[i].description() + "\n"; + std::string result = fVariable.fModifiers.description() + fVariable.fName + " {\n"; + for (size_t i = 0; i < fVariable.fType.fields().size(); i++) { + result += fVariable.fType.fields()[i].description() + "\n"; } result += "};"; return result; } - const std::shared_ptr<Variable> fVariable; + const Variable& fVariable; + const std::shared_ptr<SymbolTable> fTypeOwner; typedef ProgramElement INHERITED; }; diff --git a/src/sksl/ir/SkSLProgram.h b/src/sksl/ir/SkSLProgram.h index 5edcfded42..205db6e932 100644 --- a/src/sksl/ir/SkSLProgram.h +++ b/src/sksl/ir/SkSLProgram.h @@ -12,6 +12,7 @@ #include <memory> #include "SkSLProgramElement.h" +#include "SkSLSymbolTable.h" namespace SkSL { @@ -24,13 +25,16 @@ struct Program { kVertex_Kind }; - Program(Kind kind, std::vector<std::unique_ptr<ProgramElement>> elements) + Program(Kind kind, std::vector<std::unique_ptr<ProgramElement>> elements, + std::shared_ptr<SymbolTable> symbols) : fKind(kind) - , fElements(std::move(elements)) {} + , fElements(std::move(elements)) + , fSymbols(symbols) {} Kind fKind; std::vector<std::unique_ptr<ProgramElement>> fElements; + std::shared_ptr<SymbolTable> fSymbols; }; } // namespace diff --git a/src/sksl/ir/SkSLSwizzle.h b/src/sksl/ir/SkSLSwizzle.h index ce360d1847..0eb4a00dca 100644 --- a/src/sksl/ir/SkSLSwizzle.h +++ b/src/sksl/ir/SkSLSwizzle.h @@ -18,41 +18,40 @@ namespace SkSL { * instance, swizzling a vec3 with two components will result in a vec2. It is possible to swizzle * with more components than the source vector, as in 'vec2(1).xxxx'. */ -static std::shared_ptr<Type> get_type(Expression& value, - size_t count) { - std::shared_ptr<Type> base = value.fType->componentType(); +static const Type& get_type(const Context& context, Expression& value, size_t count) { + const Type& base = value.fType.componentType(); if (count == 1) { return base; } - if (base == kFloat_Type) { + if (base == *context.fFloat_Type) { switch (count) { - case 2: return kVec2_Type; - case 3: return kVec3_Type; - case 4: return kVec4_Type; + case 2: return *context.fVec2_Type; + case 3: return *context.fVec3_Type; + case 4: return *context.fVec4_Type; } - } else if (base == kDouble_Type) { + } else if (base == *context.fDouble_Type) { switch (count) { - case 2: return kDVec2_Type; - case 3: return kDVec3_Type; - case 4: return kDVec4_Type; + case 2: return *context.fDVec2_Type; + case 3: return *context.fDVec3_Type; + case 4: return *context.fDVec4_Type; } - } else if (base == kInt_Type) { + } else if (base == *context.fInt_Type) { switch (count) { - case 2: return kIVec2_Type; - case 3: return kIVec3_Type; - case 4: return kIVec4_Type; + case 2: return *context.fIVec2_Type; + case 3: return *context.fIVec3_Type; + case 4: return *context.fIVec4_Type; } - } else if (base == kUInt_Type) { + } else if (base == *context.fUInt_Type) { switch (count) { - case 2: return kUVec2_Type; - case 3: return kUVec3_Type; - case 4: return kUVec4_Type; + case 2: return *context.fUVec2_Type; + case 3: return *context.fUVec3_Type; + case 4: return *context.fUVec4_Type; } - } else if (base == kBool_Type) { + } else if (base == *context.fBool_Type) { switch (count) { - case 2: return kBVec2_Type; - case 3: return kBVec3_Type; - case 4: return kBVec4_Type; + case 2: return *context.fBVec2_Type; + case 3: return *context.fBVec3_Type; + case 4: return *context.fBVec4_Type; } } ABORT("cannot swizzle %s\n", value.description().c_str()); @@ -62,8 +61,8 @@ static std::shared_ptr<Type> get_type(Expression& value, * Represents a vector swizzle operation such as 'vec2(1, 2, 3).zyx'. */ struct Swizzle : public Expression { - Swizzle(std::unique_ptr<Expression> base, std::vector<int> components) - : INHERITED(base->fPosition, kSwizzle_Kind, get_type(*base, components.size())) + Swizzle(const Context& context, std::unique_ptr<Expression> base, std::vector<int> components) + : INHERITED(base->fPosition, kSwizzle_Kind, get_type(context, *base, components.size())) , fBase(std::move(base)) , fComponents(std::move(components)) { ASSERT(fComponents.size() >= 1 && fComponents.size() <= 4); diff --git a/src/sksl/ir/SkSLSymbolTable.cpp b/src/sksl/ir/SkSLSymbolTable.cpp index af83f7a456..80e22da009 100644 --- a/src/sksl/ir/SkSLSymbolTable.cpp +++ b/src/sksl/ir/SkSLSymbolTable.cpp @@ -5,23 +5,23 @@ * found in the LICENSE file. */ - #include "SkSLSymbolTable.h" +#include "SkSLSymbolTable.h" +#include "SkSLUnresolvedFunction.h" namespace SkSL { -std::vector<std::shared_ptr<FunctionDeclaration>> SymbolTable::GetFunctions( - const std::shared_ptr<Symbol>& s) { - switch (s->fKind) { +std::vector<const FunctionDeclaration*> SymbolTable::GetFunctions(const Symbol& s) { + switch (s.fKind) { case Symbol::kFunctionDeclaration_Kind: - return { std::static_pointer_cast<FunctionDeclaration>(s) }; + return { &((FunctionDeclaration&) s) }; case Symbol::kUnresolvedFunction_Kind: - return ((UnresolvedFunction&) *s).fFunctions; + return ((UnresolvedFunction&) s).fFunctions; default: return { }; } } -std::shared_ptr<Symbol> SymbolTable::operator[](const std::string& name) { +const Symbol* SymbolTable::operator[](const std::string& name) { const auto& entry = fSymbols.find(name); if (entry == fSymbols.end()) { if (fParent) { @@ -30,15 +30,15 @@ std::shared_ptr<Symbol> SymbolTable::operator[](const std::string& name) { return nullptr; } if (fParent) { - auto functions = GetFunctions(entry->second); + auto functions = GetFunctions(*entry->second); if (functions.size() > 0) { bool modified = false; - std::shared_ptr<Symbol> previous = (*fParent)[name]; + const Symbol* previous = (*fParent)[name]; if (previous) { - auto previousFunctions = GetFunctions(previous); - for (const std::shared_ptr<FunctionDeclaration>& prev : previousFunctions) { + auto previousFunctions = GetFunctions(*previous); + for (const FunctionDeclaration* prev : previousFunctions) { bool found = false; - for (const std::shared_ptr<FunctionDeclaration>& current : functions) { + for (const FunctionDeclaration* current : functions) { if (current->matches(*prev)) { found = true; break; @@ -51,7 +51,7 @@ std::shared_ptr<Symbol> SymbolTable::operator[](const std::string& name) { } if (modified) { ASSERT(functions.size() > 1); - return std::shared_ptr<Symbol>(new UnresolvedFunction(functions)); + return this->takeOwnership(new UnresolvedFunction(functions)); } } } @@ -59,27 +59,42 @@ std::shared_ptr<Symbol> SymbolTable::operator[](const std::string& name) { return entry->second; } -void SymbolTable::add(const std::string& name, std::shared_ptr<Symbol> symbol) { - const auto& existing = fSymbols.find(name); - if (existing == fSymbols.end()) { - fSymbols[name] = symbol; - } else if (symbol->fKind == Symbol::kFunctionDeclaration_Kind) { - const std::shared_ptr<Symbol>& oldSymbol = existing->second; - if (oldSymbol->fKind == Symbol::kFunctionDeclaration_Kind) { - std::vector<std::shared_ptr<FunctionDeclaration>> functions; - functions.push_back(std::static_pointer_cast<FunctionDeclaration>(oldSymbol)); - functions.push_back(std::static_pointer_cast<FunctionDeclaration>(symbol)); - fSymbols[name].reset(new UnresolvedFunction(std::move(functions))); - } else if (oldSymbol->fKind == Symbol::kUnresolvedFunction_Kind) { - std::vector<std::shared_ptr<FunctionDeclaration>> functions; - for (const auto& f : ((UnresolvedFunction&) *oldSymbol).fFunctions) { - functions.push_back(f); - } - functions.push_back(std::static_pointer_cast<FunctionDeclaration>(symbol)); - fSymbols[name].reset(new UnresolvedFunction(std::move(functions))); +Symbol* SymbolTable::takeOwnership(Symbol* s) { + fOwnedPointers.push_back(std::unique_ptr<Symbol>(s)); + return s; +} + +void SymbolTable::add(const std::string& name, std::unique_ptr<Symbol> symbol) { + this->addWithoutOwnership(name, symbol.get()); + fOwnedPointers.push_back(std::move(symbol)); +} + +void SymbolTable::addWithoutOwnership(const std::string& name, const Symbol* symbol) { + const auto& existing = fSymbols.find(name); + if (existing == fSymbols.end()) { + fSymbols[name] = symbol; + } else if (symbol->fKind == Symbol::kFunctionDeclaration_Kind) { + const Symbol* oldSymbol = existing->second; + if (oldSymbol->fKind == Symbol::kFunctionDeclaration_Kind) { + std::vector<const FunctionDeclaration*> functions; + functions.push_back((const FunctionDeclaration*) oldSymbol); + functions.push_back((const FunctionDeclaration*) symbol); + UnresolvedFunction* u = new UnresolvedFunction(std::move(functions)); + fSymbols[name] = u; + this->takeOwnership(u); + } else if (oldSymbol->fKind == Symbol::kUnresolvedFunction_Kind) { + std::vector<const FunctionDeclaration*> functions; + for (const auto* f : ((UnresolvedFunction&) *oldSymbol).fFunctions) { + functions.push_back(f); } - } else { - fErrorReporter.error(symbol->fPosition, "symbol '" + name + "' was already defined"); + functions.push_back((const FunctionDeclaration*) symbol); + UnresolvedFunction* u = new UnresolvedFunction(std::move(functions)); + fSymbols[name] = u; + this->takeOwnership(u); } + } else { + fErrorReporter.error(symbol->fPosition, "symbol '" + name + "' was already defined"); } +} + } // namespace diff --git a/src/sksl/ir/SkSLSymbolTable.h b/src/sksl/ir/SkSLSymbolTable.h index 151475d642..d732023ff0 100644 --- a/src/sksl/ir/SkSLSymbolTable.h +++ b/src/sksl/ir/SkSLSymbolTable.h @@ -10,12 +10,14 @@ #include <memory> #include <unordered_map> +#include <vector> #include "SkSLErrorReporter.h" #include "SkSLSymbol.h" -#include "SkSLUnresolvedFunction.h" namespace SkSL { +struct FunctionDeclaration; + /** * Maps identifiers to symbols. Functions, in particular, are mapped to either FunctionDeclaration * or UnresolvedFunction depending on whether they are overloaded or not. @@ -29,17 +31,22 @@ public: : fParent(parent) , fErrorReporter(errorReporter) {} - std::shared_ptr<Symbol> operator[](const std::string& name); + const Symbol* operator[](const std::string& name); + + void add(const std::string& name, std::unique_ptr<Symbol> symbol); - void add(const std::string& name, std::shared_ptr<Symbol> symbol); + void addWithoutOwnership(const std::string& name, const Symbol* symbol); + + Symbol* takeOwnership(Symbol* s); const std::shared_ptr<SymbolTable> fParent; private: - static std::vector<std::shared_ptr<FunctionDeclaration>> GetFunctions( - const std::shared_ptr<Symbol>& s); + static std::vector<const FunctionDeclaration*> GetFunctions(const Symbol& s); + + std::vector<std::unique_ptr<Symbol>> fOwnedPointers; - std::unordered_map<std::string, std::shared_ptr<Symbol>> fSymbols; + std::unordered_map<std::string, const Symbol*> fSymbols; ErrorReporter& fErrorReporter; }; diff --git a/src/sksl/ir/SkSLType.cpp b/src/sksl/ir/SkSLType.cpp index 27cbd39e44..d28c4f0666 100644 --- a/src/sksl/ir/SkSLType.cpp +++ b/src/sksl/ir/SkSLType.cpp @@ -6,29 +6,30 @@ */ #include "SkSLType.h" +#include "SkSLContext.h" namespace SkSL { -bool Type::determineCoercionCost(std::shared_ptr<Type> other, int* outCost) const { - if (this == other.get()) { +bool Type::determineCoercionCost(const Type& other, int* outCost) const { + if (*this == other) { *outCost = 0; return true; } - if (this->kind() == kVector_Kind && other->kind() == kVector_Kind) { - if (this->columns() == other->columns()) { - return this->componentType()->determineCoercionCost(other->componentType(), outCost); + if (this->kind() == kVector_Kind && other.kind() == kVector_Kind) { + if (this->columns() == other.columns()) { + return this->componentType().determineCoercionCost(other.componentType(), outCost); } return false; } if (this->kind() == kMatrix_Kind) { - if (this->columns() == other->columns() && - this->rows() == other->rows()) { - return this->componentType()->determineCoercionCost(other->componentType(), outCost); + if (this->columns() == other.columns() && + this->rows() == other.rows()) { + return this->componentType().determineCoercionCost(other.componentType(), outCost); } return false; } for (size_t i = 0; i < fCoercibleTypes.size(); i++) { - if (fCoercibleTypes[i] == other) { + if (*fCoercibleTypes[i] == other) { *outCost = (int) i + 1; return true; } @@ -36,93 +37,93 @@ bool Type::determineCoercionCost(std::shared_ptr<Type> other, int* outCost) cons return false; } -std::shared_ptr<Type> Type::toCompound(int columns, int rows) { +const Type& Type::toCompound(const Context& context, int columns, int rows) const { ASSERT(this->kind() == Type::kScalar_Kind); if (columns == 1 && rows == 1) { - return std::shared_ptr<Type>(this); + return *this; } - if (*this == *kFloat_Type) { + if (*this == *context.fFloat_Type) { switch (rows) { case 1: switch (columns) { - case 2: return kVec2_Type; - case 3: return kVec3_Type; - case 4: return kVec4_Type; + case 2: return *context.fVec2_Type; + case 3: return *context.fVec3_Type; + case 4: return *context.fVec4_Type; default: ABORT("unsupported vector column count (%d)", columns); } case 2: switch (columns) { - case 2: return kMat2x2_Type; - case 3: return kMat3x2_Type; - case 4: return kMat4x2_Type; + case 2: return *context.fMat2x2_Type; + case 3: return *context.fMat3x2_Type; + case 4: return *context.fMat4x2_Type; default: ABORT("unsupported matrix column count (%d)", columns); } case 3: switch (columns) { - case 2: return kMat2x3_Type; - case 3: return kMat3x3_Type; - case 4: return kMat4x3_Type; + case 2: return *context.fMat2x3_Type; + case 3: return *context.fMat3x3_Type; + case 4: return *context.fMat4x3_Type; default: ABORT("unsupported matrix column count (%d)", columns); } case 4: switch (columns) { - case 2: return kMat2x4_Type; - case 3: return kMat3x4_Type; - case 4: return kMat4x4_Type; + case 2: return *context.fMat2x4_Type; + case 3: return *context.fMat3x4_Type; + case 4: return *context.fMat4x4_Type; default: ABORT("unsupported matrix column count (%d)", columns); } default: ABORT("unsupported row count (%d)", rows); } - } else if (*this == *kDouble_Type) { + } else if (*this == *context.fDouble_Type) { switch (rows) { case 1: switch (columns) { - case 2: return kDVec2_Type; - case 3: return kDVec3_Type; - case 4: return kDVec4_Type; + case 2: return *context.fDVec2_Type; + case 3: return *context.fDVec3_Type; + case 4: return *context.fDVec4_Type; default: ABORT("unsupported vector column count (%d)", columns); } case 2: switch (columns) { - case 2: return kDMat2x2_Type; - case 3: return kDMat3x2_Type; - case 4: return kDMat4x2_Type; + case 2: return *context.fDMat2x2_Type; + case 3: return *context.fDMat3x2_Type; + case 4: return *context.fDMat4x2_Type; default: ABORT("unsupported matrix column count (%d)", columns); } case 3: switch (columns) { - case 2: return kDMat2x3_Type; - case 3: return kDMat3x3_Type; - case 4: return kDMat4x3_Type; + case 2: return *context.fDMat2x3_Type; + case 3: return *context.fDMat3x3_Type; + case 4: return *context.fDMat4x3_Type; default: ABORT("unsupported matrix column count (%d)", columns); } case 4: switch (columns) { - case 2: return kDMat2x4_Type; - case 3: return kDMat3x4_Type; - case 4: return kDMat4x4_Type; + case 2: return *context.fDMat2x4_Type; + case 3: return *context.fDMat3x4_Type; + case 4: return *context.fDMat4x4_Type; default: ABORT("unsupported matrix column count (%d)", columns); } default: ABORT("unsupported row count (%d)", rows); } - } else if (*this == *kInt_Type) { + } else if (*this == *context.fInt_Type) { switch (rows) { case 1: switch (columns) { - case 2: return kIVec2_Type; - case 3: return kIVec3_Type; - case 4: return kIVec4_Type; + case 2: return *context.fIVec2_Type; + case 3: return *context.fIVec3_Type; + case 4: return *context.fIVec4_Type; default: ABORT("unsupported vector column count (%d)", columns); } default: ABORT("unsupported row count (%d)", rows); } - } else if (*this == *kUInt_Type) { + } else if (*this == *context.fUInt_Type) { switch (rows) { case 1: switch (columns) { - case 2: return kUVec2_Type; - case 3: return kUVec3_Type; - case 4: return kUVec4_Type; + case 2: return *context.fUVec2_Type; + case 3: return *context.fUVec3_Type; + case 4: return *context.fUVec4_Type; default: ABORT("unsupported vector column count (%d)", columns); } default: ABORT("unsupported row count (%d)", rows); @@ -131,128 +132,4 @@ std::shared_ptr<Type> Type::toCompound(int columns, int rows) { ABORT("unsupported scalar_to_compound type %s", this->description().c_str()); } -const std::shared_ptr<Type> kVoid_Type(new Type("void")); - -const std::shared_ptr<Type> kDouble_Type(new Type("double", true)); -const std::shared_ptr<Type> kDVec2_Type(new Type("dvec2", kDouble_Type, 2)); -const std::shared_ptr<Type> kDVec3_Type(new Type("dvec3", kDouble_Type, 3)); -const std::shared_ptr<Type> kDVec4_Type(new Type("dvec4", kDouble_Type, 4)); - -const std::shared_ptr<Type> kFloat_Type(new Type("float", true, { kDouble_Type })); -const std::shared_ptr<Type> kVec2_Type(new Type("vec2", kFloat_Type, 2)); -const std::shared_ptr<Type> kVec3_Type(new Type("vec3", kFloat_Type, 3)); -const std::shared_ptr<Type> kVec4_Type(new Type("vec4", kFloat_Type, 4)); - -const std::shared_ptr<Type> kUInt_Type(new Type("uint", true, { kFloat_Type, kDouble_Type })); -const std::shared_ptr<Type> kUVec2_Type(new Type("uvec2", kUInt_Type, 2)); -const std::shared_ptr<Type> kUVec3_Type(new Type("uvec3", kUInt_Type, 3)); -const std::shared_ptr<Type> kUVec4_Type(new Type("uvec4", kUInt_Type, 4)); - -const std::shared_ptr<Type> kInt_Type(new Type("int", true, { kUInt_Type, kFloat_Type, - kDouble_Type })); -const std::shared_ptr<Type> kIVec2_Type(new Type("ivec2", kInt_Type, 2)); -const std::shared_ptr<Type> kIVec3_Type(new Type("ivec3", kInt_Type, 3)); -const std::shared_ptr<Type> kIVec4_Type(new Type("ivec4", kInt_Type, 4)); - -const std::shared_ptr<Type> kBool_Type(new Type("bool", false)); -const std::shared_ptr<Type> kBVec2_Type(new Type("bvec2", kBool_Type, 2)); -const std::shared_ptr<Type> kBVec3_Type(new Type("bvec3", kBool_Type, 3)); -const std::shared_ptr<Type> kBVec4_Type(new Type("bvec4", kBool_Type, 4)); - -const std::shared_ptr<Type> kMat2x2_Type(new Type("mat2", kFloat_Type, 2, 2)); -const std::shared_ptr<Type> kMat2x3_Type(new Type("mat2x3", kFloat_Type, 2, 3)); -const std::shared_ptr<Type> kMat2x4_Type(new Type("mat2x4", kFloat_Type, 2, 4)); -const std::shared_ptr<Type> kMat3x2_Type(new Type("mat3x2", kFloat_Type, 3, 2)); -const std::shared_ptr<Type> kMat3x3_Type(new Type("mat3", kFloat_Type, 3, 3)); -const std::shared_ptr<Type> kMat3x4_Type(new Type("mat3x4", kFloat_Type, 3, 4)); -const std::shared_ptr<Type> kMat4x2_Type(new Type("mat4x2", kFloat_Type, 4, 2)); -const std::shared_ptr<Type> kMat4x3_Type(new Type("mat4x3", kFloat_Type, 4, 3)); -const std::shared_ptr<Type> kMat4x4_Type(new Type("mat4", kFloat_Type, 4, 4)); - -const std::shared_ptr<Type> kDMat2x2_Type(new Type("dmat2", kFloat_Type, 2, 2)); -const std::shared_ptr<Type> kDMat2x3_Type(new Type("dmat2x3", kFloat_Type, 2, 3)); -const std::shared_ptr<Type> kDMat2x4_Type(new Type("dmat2x4", kFloat_Type, 2, 4)); -const std::shared_ptr<Type> kDMat3x2_Type(new Type("dmat3x2", kFloat_Type, 3, 2)); -const std::shared_ptr<Type> kDMat3x3_Type(new Type("dmat3", kFloat_Type, 3, 3)); -const std::shared_ptr<Type> kDMat3x4_Type(new Type("dmat3x4", kFloat_Type, 3, 4)); -const std::shared_ptr<Type> kDMat4x2_Type(new Type("dmat4x2", kFloat_Type, 4, 2)); -const std::shared_ptr<Type> kDMat4x3_Type(new Type("dmat4x3", kFloat_Type, 4, 3)); -const std::shared_ptr<Type> kDMat4x4_Type(new Type("dmat4", kFloat_Type, 4, 4)); - -const std::shared_ptr<Type> kSampler1D_Type(new Type("sampler1D", SpvDim1D, false, false, false, true)); -const std::shared_ptr<Type> kSampler2D_Type(new Type("sampler2D", SpvDim2D, false, false, false, true)); -const std::shared_ptr<Type> kSampler3D_Type(new Type("sampler3D", SpvDim3D, false, false, false, true)); -const std::shared_ptr<Type> kSamplerCube_Type(new Type("samplerCube")); -const std::shared_ptr<Type> kSampler2DRect_Type(new Type("sampler2DRect")); -const std::shared_ptr<Type> kSampler1DArray_Type(new Type("sampler1DArray")); -const std::shared_ptr<Type> kSampler2DArray_Type(new Type("sampler2DArray")); -const std::shared_ptr<Type> kSamplerCubeArray_Type(new Type("samplerCubeArray")); -const std::shared_ptr<Type> kSamplerBuffer_Type(new Type("samplerBuffer")); -const std::shared_ptr<Type> kSampler2DMS_Type(new Type("sampler2DMS")); -const std::shared_ptr<Type> kSampler2DMSArray_Type(new Type("sampler2DMSArray")); -const std::shared_ptr<Type> kSampler1DShadow_Type(new Type("sampler1DShadow")); -const std::shared_ptr<Type> kSampler2DShadow_Type(new Type("sampler2DShadow")); -const std::shared_ptr<Type> kSamplerCubeShadow_Type(new Type("samplerCubeShadow")); -const std::shared_ptr<Type> kSampler2DRectShadow_Type(new Type("sampler2DRectShadow")); -const std::shared_ptr<Type> kSampler1DArrayShadow_Type(new Type("sampler1DArrayShadow")); -const std::shared_ptr<Type> kSampler2DArrayShadow_Type(new Type("sampler2DArrayShadow")); -const std::shared_ptr<Type> kSamplerCubeArrayShadow_Type(new Type("samplerCubeArrayShadow")); - -static std::vector<std::shared_ptr<Type>> type(std::shared_ptr<Type> t) { - return { t, t, t, t }; -} - -// FIXME figure out what we're supposed to do with the gsampler et al. types -const std::shared_ptr<Type> kGSampler1D_Type(new Type("$gsampler1D", type(kSampler1D_Type))); -const std::shared_ptr<Type> kGSampler2D_Type(new Type("$gsampler2D", type(kSampler2D_Type))); -const std::shared_ptr<Type> kGSampler3D_Type(new Type("$gsampler3D", type(kSampler3D_Type))); -const std::shared_ptr<Type> kGSamplerCube_Type(new Type("$gsamplerCube", type(kSamplerCube_Type))); -const std::shared_ptr<Type> kGSampler2DRect_Type(new Type("$gsampler2DRect", - type(kSampler2DRect_Type))); -const std::shared_ptr<Type> kGSampler1DArray_Type(new Type("$gsampler1DArray", - type(kSampler1DArray_Type))); -const std::shared_ptr<Type> kGSampler2DArray_Type(new Type("$gsampler2DArray", - type(kSampler2DArray_Type))); -const std::shared_ptr<Type> kGSamplerCubeArray_Type(new Type("$gsamplerCubeArray", - type(kSamplerCubeArray_Type))); -const std::shared_ptr<Type> kGSamplerBuffer_Type(new Type("$gsamplerBuffer", - type(kSamplerBuffer_Type))); -const std::shared_ptr<Type> kGSampler2DMS_Type(new Type("$gsampler2DMS", - type(kSampler2DMS_Type))); -const std::shared_ptr<Type> kGSampler2DMSArray_Type(new Type("$gsampler2DMSArray", - type(kSampler2DMSArray_Type))); -const std::shared_ptr<Type> kGSampler2DArrayShadow_Type(new Type("$gsampler2DArrayShadow", - type(kSampler2DArrayShadow_Type))); -const std::shared_ptr<Type> kGSamplerCubeArrayShadow_Type(new Type("$gsamplerCubeArrayShadow", - type(kSamplerCubeArrayShadow_Type))); - -const std::shared_ptr<Type> kGenType_Type(new Type("$genType", { kFloat_Type, kVec2_Type, - kVec3_Type, kVec4_Type })); -const std::shared_ptr<Type> kGenDType_Type(new Type("$genDType", { kDouble_Type, kDVec2_Type, - kDVec3_Type, kDVec4_Type })); -const std::shared_ptr<Type> kGenIType_Type(new Type("$genIType", { kInt_Type, kIVec2_Type, - kIVec3_Type, kIVec4_Type })); -const std::shared_ptr<Type> kGenUType_Type(new Type("$genUType", { kUInt_Type, kUVec2_Type, - kUVec3_Type, kUVec4_Type })); -const std::shared_ptr<Type> kGenBType_Type(new Type("$genBType", { kBool_Type, kBVec2_Type, - kBVec3_Type, kBVec4_Type })); - -const std::shared_ptr<Type> kMat_Type(new Type("$mat")); - -const std::shared_ptr<Type> kVec_Type(new Type("$vec", { kVec2_Type, kVec2_Type, kVec3_Type, - kVec4_Type })); - -const std::shared_ptr<Type> kGVec_Type(new Type("$gvec")); -const std::shared_ptr<Type> kGVec2_Type(new Type("$gvec2")); -const std::shared_ptr<Type> kGVec3_Type(new Type("$gvec3")); -const std::shared_ptr<Type> kGVec4_Type(new Type("$gvec4", type(kVec4_Type))); -const std::shared_ptr<Type> kDVec_Type(new Type("$dvec")); -const std::shared_ptr<Type> kIVec_Type(new Type("$ivec")); -const std::shared_ptr<Type> kUVec_Type(new Type("$uvec")); - -const std::shared_ptr<Type> kBVec_Type(new Type("$bvec", { kBVec2_Type, kBVec2_Type, - kBVec3_Type, kBVec4_Type })); - -const std::shared_ptr<Type> kInvalid_Type(new Type("<INVALID>")); - } // namespace diff --git a/src/sksl/ir/SkSLType.h b/src/sksl/ir/SkSLType.h index e17bae68db..a929c8e4c7 100644 --- a/src/sksl/ir/SkSLType.h +++ b/src/sksl/ir/SkSLType.h @@ -18,24 +18,26 @@ namespace SkSL { +class Context; + /** * Represents a type, such as int or vec4. */ class Type : public Symbol { public: struct Field { - Field(Modifiers modifiers, std::string name, std::shared_ptr<Type> type) + Field(Modifiers modifiers, std::string name, const Type& type) : fModifiers(modifiers) , fName(std::move(name)) , fType(std::move(type)) {} - const std::string description() { - return fType->description() + " " + fName + ";"; + const std::string description() const { + return fType.description() + " " + fName + ";"; } const Modifiers fModifiers; const std::string fName; - const std::shared_ptr<Type> fType; + const Type& fType; }; enum Kind { @@ -56,7 +58,7 @@ public: , fTypeKind(kOther_Kind) {} // Create a generic type which maps to the listed types. - Type(std::string name, std::vector<std::shared_ptr<Type>> types) + Type(std::string name, std::vector<const Type*> types) : INHERITED(Position(), kType_Kind, std::move(name)) , fTypeKind(kGeneric_Kind) , fCoercibleTypes(std::move(types)) { @@ -78,7 +80,7 @@ public: , fRows(1) {} // Create a scalar type which can be coerced to the listed types. - Type(std::string name, bool isNumber, std::vector<std::shared_ptr<Type>> coercibleTypes) + Type(std::string name, bool isNumber, std::vector<const Type*> coercibleTypes) : INHERITED(Position(), kType_Kind, std::move(name)) , fTypeKind(kScalar_Kind) , fIsNumber(isNumber) @@ -87,23 +89,23 @@ public: , fRows(1) {} // Create a vector type. - Type(std::string name, std::shared_ptr<Type> componentType, int columns) + Type(std::string name, const Type& componentType, int columns) : Type(name, kVector_Kind, componentType, columns) {} // Create a vector or array type. - Type(std::string name, Kind kind, std::shared_ptr<Type> componentType, int columns) + Type(std::string name, Kind kind, const Type& componentType, int columns) : INHERITED(Position(), kType_Kind, std::move(name)) , fTypeKind(kind) - , fComponentType(std::move(componentType)) + , fComponentType(&componentType) , fColumns(columns) , fRows(1) , fDimensions(SpvDim1D) {} // Create a matrix type. - Type(std::string name, std::shared_ptr<Type> componentType, int columns, int rows) + Type(std::string name, const Type& componentType, int columns, int rows) : INHERITED(Position(), kType_Kind, std::move(name)) , fTypeKind(kMatrix_Kind) - , fComponentType(std::move(componentType)) + , fComponentType(&componentType) , fColumns(columns) , fRows(rows) , fDimensions(SpvDim1D) {} @@ -153,7 +155,7 @@ public: * Returns true if an instance of this type can be freely coerced (implicitly converted) to * another type. */ - bool canCoerceTo(std::shared_ptr<Type> other) const { + bool canCoerceTo(const Type& other) const { int cost; return determineCoercionCost(other, &cost); } @@ -164,15 +166,15 @@ public: * costs. Returns true if a conversion is possible, false otherwise. The value of the out * parameter is undefined if false is returned. */ - bool determineCoercionCost(std::shared_ptr<Type> other, int* outCost) const; + bool determineCoercionCost(const Type& other, int* outCost) const; /** * For matrices and vectors, returns the type of individual cells (e.g. mat2 has a component * type of kFloat_Type). For all other types, causes an assertion failure. */ - std::shared_ptr<Type> componentType() const { + const Type& componentType() const { ASSERT(fComponentType); - return fComponentType; + return *fComponentType; } /** @@ -195,7 +197,7 @@ public: return fRows; } - std::vector<Field> fields() const { + const std::vector<Field>& fields() const { ASSERT(fTypeKind == kStruct_Kind); return fFields; } @@ -204,7 +206,7 @@ public: * For generic types, returns the types that this generic type can substitute for. For other * types, returns a list of other types that this type can be coerced into. */ - std::vector<std::shared_ptr<Type>> coercibleTypes() const { + const std::vector<const Type*>& coercibleTypes() const { ASSERT(fCoercibleTypes.size() > 0); return fCoercibleTypes; } @@ -257,7 +259,7 @@ public: case kStruct_Kind: { size_t result = 16; for (size_t i = 0; i < fFields.size(); i++) { - size_t alignment = fFields[i].fType->alignment(); + size_t alignment = fFields[i].fType.alignment(); if (alignment > result) { result = alignment; } @@ -300,13 +302,13 @@ public: case kStruct_Kind: { size_t total = 0; for (size_t i = 0; i < fFields.size(); i++) { - size_t alignment = fFields[i].fType->alignment(); + size_t alignment = fFields[i].fType.alignment(); if (total % alignment != 0) { total += alignment - total % alignment; } ASSERT(false); ASSERT(total % alignment == 0); - total += fFields[i].fType->size(); + total += fFields[i].fType.size(); } return total; } @@ -319,15 +321,15 @@ public: * Returns the corresponding vector or matrix type with the specified number of columns and * rows. */ - std::shared_ptr<Type> toCompound(int columns, int rows); + const Type& toCompound(const Context& context, int columns, int rows) const; private: typedef Symbol INHERITED; const Kind fTypeKind; const bool fIsNumber = false; - const std::shared_ptr<Type> fComponentType = nullptr; - const std::vector<std::shared_ptr<Type>> fCoercibleTypes = { }; + const Type* fComponentType = nullptr; + const std::vector<const Type*> fCoercibleTypes = { }; const int fColumns = -1; const int fRows = -1; const std::vector<Field> fFields = { }; @@ -338,101 +340,6 @@ private: const bool fIsSampled = false; }; -extern const std::shared_ptr<Type> kVoid_Type; - -extern const std::shared_ptr<Type> kFloat_Type; -extern const std::shared_ptr<Type> kVec2_Type; -extern const std::shared_ptr<Type> kVec3_Type; -extern const std::shared_ptr<Type> kVec4_Type; -extern const std::shared_ptr<Type> kDouble_Type; -extern const std::shared_ptr<Type> kDVec2_Type; -extern const std::shared_ptr<Type> kDVec3_Type; -extern const std::shared_ptr<Type> kDVec4_Type; -extern const std::shared_ptr<Type> kInt_Type; -extern const std::shared_ptr<Type> kIVec2_Type; -extern const std::shared_ptr<Type> kIVec3_Type; -extern const std::shared_ptr<Type> kIVec4_Type; -extern const std::shared_ptr<Type> kUInt_Type; -extern const std::shared_ptr<Type> kUVec2_Type; -extern const std::shared_ptr<Type> kUVec3_Type; -extern const std::shared_ptr<Type> kUVec4_Type; -extern const std::shared_ptr<Type> kBool_Type; -extern const std::shared_ptr<Type> kBVec2_Type; -extern const std::shared_ptr<Type> kBVec3_Type; -extern const std::shared_ptr<Type> kBVec4_Type; - -extern const std::shared_ptr<Type> kMat2x2_Type; -extern const std::shared_ptr<Type> kMat2x3_Type; -extern const std::shared_ptr<Type> kMat2x4_Type; -extern const std::shared_ptr<Type> kMat3x2_Type; -extern const std::shared_ptr<Type> kMat3x3_Type; -extern const std::shared_ptr<Type> kMat3x4_Type; -extern const std::shared_ptr<Type> kMat4x2_Type; -extern const std::shared_ptr<Type> kMat4x3_Type; -extern const std::shared_ptr<Type> kMat4x4_Type; - -extern const std::shared_ptr<Type> kDMat2x2_Type; -extern const std::shared_ptr<Type> kDMat2x3_Type; -extern const std::shared_ptr<Type> kDMat2x4_Type; -extern const std::shared_ptr<Type> kDMat3x2_Type; -extern const std::shared_ptr<Type> kDMat3x3_Type; -extern const std::shared_ptr<Type> kDMat3x4_Type; -extern const std::shared_ptr<Type> kDMat4x2_Type; -extern const std::shared_ptr<Type> kDMat4x3_Type; -extern const std::shared_ptr<Type> kDMat4x4_Type; - -extern const std::shared_ptr<Type> kSampler1D_Type; -extern const std::shared_ptr<Type> kSampler2D_Type; -extern const std::shared_ptr<Type> kSampler3D_Type; -extern const std::shared_ptr<Type> kSamplerCube_Type; -extern const std::shared_ptr<Type> kSampler2DRect_Type; -extern const std::shared_ptr<Type> kSampler1DArray_Type; -extern const std::shared_ptr<Type> kSampler2DArray_Type; -extern const std::shared_ptr<Type> kSamplerCubeArray_Type; -extern const std::shared_ptr<Type> kSamplerBuffer_Type; -extern const std::shared_ptr<Type> kSampler2DMS_Type; -extern const std::shared_ptr<Type> kSampler2DMSArray_Type; - -extern const std::shared_ptr<Type> kGSampler1D_Type; -extern const std::shared_ptr<Type> kGSampler2D_Type; -extern const std::shared_ptr<Type> kGSampler3D_Type; -extern const std::shared_ptr<Type> kGSamplerCube_Type; -extern const std::shared_ptr<Type> kGSampler2DRect_Type; -extern const std::shared_ptr<Type> kGSampler1DArray_Type; -extern const std::shared_ptr<Type> kGSampler2DArray_Type; -extern const std::shared_ptr<Type> kGSamplerCubeArray_Type; -extern const std::shared_ptr<Type> kGSamplerBuffer_Type; -extern const std::shared_ptr<Type> kGSampler2DMS_Type; -extern const std::shared_ptr<Type> kGSampler2DMSArray_Type; - -extern const std::shared_ptr<Type> kSampler1DShadow_Type; -extern const std::shared_ptr<Type> kSampler2DShadow_Type; -extern const std::shared_ptr<Type> kSamplerCubeShadow_Type; -extern const std::shared_ptr<Type> kSampler2DRectShadow_Type; -extern const std::shared_ptr<Type> kSampler1DArrayShadow_Type; -extern const std::shared_ptr<Type> kSampler2DArrayShadow_Type; -extern const std::shared_ptr<Type> kSamplerCubeArrayShadow_Type; -extern const std::shared_ptr<Type> kGSampler2DArrayShadow_Type; -extern const std::shared_ptr<Type> kGSamplerCubeArrayShadow_Type; - -extern const std::shared_ptr<Type> kGenType_Type; -extern const std::shared_ptr<Type> kGenDType_Type; -extern const std::shared_ptr<Type> kGenIType_Type; -extern const std::shared_ptr<Type> kGenUType_Type; -extern const std::shared_ptr<Type> kGenBType_Type; -extern const std::shared_ptr<Type> kMat_Type; -extern const std::shared_ptr<Type> kVec_Type; -extern const std::shared_ptr<Type> kGVec_Type; -extern const std::shared_ptr<Type> kGVec2_Type; -extern const std::shared_ptr<Type> kGVec3_Type; -extern const std::shared_ptr<Type> kGVec4_Type; -extern const std::shared_ptr<Type> kDVec_Type; -extern const std::shared_ptr<Type> kIVec_Type; -extern const std::shared_ptr<Type> kUVec_Type; -extern const std::shared_ptr<Type> kBVec_Type; - -extern const std::shared_ptr<Type> kInvalid_Type; - } // namespace #endif diff --git a/src/sksl/ir/SkSLTypeReference.h b/src/sksl/ir/SkSLTypeReference.h index 5f4990f35d..76923aae4d 100644 --- a/src/sksl/ir/SkSLTypeReference.h +++ b/src/sksl/ir/SkSLTypeReference.h @@ -8,6 +8,7 @@ #ifndef SKSL_TYPEREFERENCE #define SKSL_TYPEREFERENCE +#include "SkSLContext.h" #include "SkSLExpression.h" namespace SkSL { @@ -17,16 +18,16 @@ namespace SkSL { * always eventually replaced by Constructors in valid programs. */ struct TypeReference : public Expression { - TypeReference(Position position, std::shared_ptr<Type> type) - : INHERITED(position, kTypeReference_Kind, kInvalid_Type) - , fValue(std::move(type)) {} + TypeReference(const Context& context, Position position, const Type& type) + : INHERITED(position, kTypeReference_Kind, *context.fInvalid_Type) + , fValue(type) {} std::string description() const override { ASSERT(false); return "<type>"; } - const std::shared_ptr<Type> fValue; + const Type& fValue; typedef Expression INHERITED; }; diff --git a/src/sksl/ir/SkSLUnresolvedFunction.h b/src/sksl/ir/SkSLUnresolvedFunction.h index a6cee0d072..3a368ad8d3 100644 --- a/src/sksl/ir/SkSLUnresolvedFunction.h +++ b/src/sksl/ir/SkSLUnresolvedFunction.h @@ -16,19 +16,21 @@ namespace SkSL { * A symbol representing multiple functions with the same name. */ struct UnresolvedFunction : public Symbol { - UnresolvedFunction(std::vector<std::shared_ptr<FunctionDeclaration>> funcs) + UnresolvedFunction(std::vector<const FunctionDeclaration*> funcs) : INHERITED(Position(), kUnresolvedFunction_Kind, funcs[0]->fName) , fFunctions(std::move(funcs)) { +#ifdef DEBUG for (auto func : funcs) { ASSERT(func->fName == fName); } +#endif } virtual std::string description() const override { return fName; } - const std::vector<std::shared_ptr<FunctionDeclaration>> fFunctions; + const std::vector<const FunctionDeclaration*> fFunctions; typedef Symbol INHERITED; }; diff --git a/src/sksl/ir/SkSLVarDeclaration.h b/src/sksl/ir/SkSLVarDeclaration.h index 400f430e4c..b234231b86 100644 --- a/src/sksl/ir/SkSLVarDeclaration.h +++ b/src/sksl/ir/SkSLVarDeclaration.h @@ -20,7 +20,7 @@ namespace SkSL { * names ['x', 'y', 'z'], sizes of [[], [], [4, 2]], and values of [null, 1, null]. */ struct VarDeclaration : public ProgramElement { - VarDeclaration(Position position, std::vector<std::shared_ptr<Variable>> vars, + VarDeclaration(Position position, std::vector<const Variable*> vars, std::vector<std::vector<std::unique_ptr<Expression>>> sizes, std::vector<std::unique_ptr<Expression>> values) : INHERITED(position, kVar_Kind) @@ -30,9 +30,9 @@ struct VarDeclaration : public ProgramElement { std::string description() const override { std::string result = fVars[0]->fModifiers.description(); - std::shared_ptr<Type> baseType = fVars[0]->fType; + const Type* baseType = &fVars[0]->fType; while (baseType->kind() == Type::kArray_Kind) { - baseType = baseType->componentType(); + baseType = &baseType->componentType(); } result += baseType->description(); std::string separator = " "; @@ -55,7 +55,7 @@ struct VarDeclaration : public ProgramElement { return result; } - const std::vector<std::shared_ptr<Variable>> fVars; + const std::vector<const Variable*> fVars; const std::vector<std::vector<std::unique_ptr<Expression>>> fSizes; const std::vector<std::unique_ptr<Expression>> fValues; diff --git a/src/sksl/ir/SkSLVariable.h b/src/sksl/ir/SkSLVariable.h index d4ea2c4a43..39af3093b6 100644 --- a/src/sksl/ir/SkSLVariable.h +++ b/src/sksl/ir/SkSLVariable.h @@ -27,7 +27,7 @@ struct Variable : public Symbol { kParameter_Storage }; - Variable(Position position, Modifiers modifiers, std::string name, std::shared_ptr<Type> type, + Variable(Position position, Modifiers modifiers, std::string name, const Type& type, Storage storage) : INHERITED(position, kVariable_Kind, std::move(name)) , fModifiers(modifiers) @@ -37,12 +37,11 @@ struct Variable : public Symbol { , fIsWrittenTo(false) {} virtual std::string description() const override { - return fModifiers.description() + fType->fName + " " + fName; + return fModifiers.description() + fType.fName + " " + fName; } const Modifiers fModifiers; - const std::string fValue; - const std::shared_ptr<Type> fType; + const Type& fType; const Storage fStorage; mutable bool fIsReadFrom; @@ -53,14 +52,4 @@ struct Variable : public Symbol { } // namespace SkSL -namespace std { - template <> - struct hash<SkSL::Variable> { - public : - size_t operator()(const SkSL::Variable &var) const{ - return hash<std::string>()(var.fName) ^ hash<std::string>()(var.fType->description()); - } - }; -} // namespace std - #endif diff --git a/src/sksl/ir/SkSLVariableReference.h b/src/sksl/ir/SkSLVariableReference.h index 8499511a1b..b443da1f22 100644 --- a/src/sksl/ir/SkSLVariableReference.h +++ b/src/sksl/ir/SkSLVariableReference.h @@ -20,15 +20,15 @@ namespace SkSL { * there is only one Variable 'x', but two VariableReferences to it. */ struct VariableReference : public Expression { - VariableReference(Position position, std::shared_ptr<Variable> variable) - : INHERITED(position, kVariableReference_Kind, variable->fType) - , fVariable(std::move(variable)) {} + VariableReference(Position position, const Variable& variable) + : INHERITED(position, kVariableReference_Kind, variable.fType) + , fVariable(variable) {} std::string description() const override { - return fVariable->fName; + return fVariable.fName; } - const std::shared_ptr<Variable> fVariable; + const Variable& fVariable; typedef Expression INHERITED; }; |