aboutsummaryrefslogtreecommitdiffhomepage
path: root/src
diff options
context:
space:
mode:
authorGravatar ethannicholas <ethannicholas@google.com>2016-07-25 10:08:54 -0700
committerGravatar Commit bot <commit-bot@chromium.org>2016-07-25 10:08:54 -0700
commitd598f7981f34811e6f2a949207dc13638852f3f7 (patch)
tree83dd4cf4983f90125651a0ab380f4f71cb3e27f2 /src
parentd9ddad2952cdfd0809249abbd94a285abdb6b1d0 (diff)
SkSL performance improvements (plus a couple of minor warning fixes)
Diffstat (limited to 'src')
-rw-r--r--src/gpu/vk/GrVkPipelineStateBuilder.cpp2
-rw-r--r--src/sksl/SkSLCompiler.cpp19
-rw-r--r--src/sksl/SkSLCompiler.h2
-rw-r--r--src/sksl/SkSLContext.h227
-rw-r--r--src/sksl/SkSLIRGenerator.cpp536
-rw-r--r--src/sksl/SkSLIRGenerator.h19
-rw-r--r--src/sksl/SkSLParser.cpp16
-rw-r--r--src/sksl/SkSLSPIRVCodeGenerator.cpp434
-rw-r--r--src/sksl/SkSLSPIRVCodeGenerator.h23
-rw-r--r--src/sksl/ir/SkSLBinaryExpression.h2
-rw-r--r--src/sksl/ir/SkSLBlock.h8
-rw-r--r--src/sksl/ir/SkSLBoolLiteral.h5
-rw-r--r--src/sksl/ir/SkSLConstructor.h6
-rw-r--r--src/sksl/ir/SkSLExpression.h4
-rw-r--r--src/sksl/ir/SkSLField.h8
-rw-r--r--src/sksl/ir/SkSLFieldAccess.h4
-rw-r--r--src/sksl/ir/SkSLFloatLiteral.h5
-rw-r--r--src/sksl/ir/SkSLForStatement.h7
-rw-r--r--src/sksl/ir/SkSLFunctionCall.h8
-rw-r--r--src/sksl/ir/SkSLFunctionDeclaration.h27
-rw-r--r--src/sksl/ir/SkSLFunctionDefinition.h8
-rw-r--r--src/sksl/ir/SkSLFunctionReference.h8
-rw-r--r--src/sksl/ir/SkSLIndexExpression.h25
-rw-r--r--src/sksl/ir/SkSLIntLiteral.h4
-rw-r--r--src/sksl/ir/SkSLInterfaceBlock.h16
-rw-r--r--src/sksl/ir/SkSLProgram.h8
-rw-r--r--src/sksl/ir/SkSLSwizzle.h49
-rw-r--r--src/sksl/ir/SkSLSymbolTable.cpp81
-rw-r--r--src/sksl/ir/SkSLSymbolTable.h19
-rw-r--r--src/sksl/ir/SkSLType.cpp215
-rw-r--r--src/sksl/ir/SkSLType.h143
-rw-r--r--src/sksl/ir/SkSLTypeReference.h9
-rw-r--r--src/sksl/ir/SkSLUnresolvedFunction.h6
-rw-r--r--src/sksl/ir/SkSLVarDeclaration.h8
-rw-r--r--src/sksl/ir/SkSLVariable.h17
-rw-r--r--src/sksl/ir/SkSLVariableReference.h10
36 files changed, 1032 insertions, 956 deletions
diff --git a/src/gpu/vk/GrVkPipelineStateBuilder.cpp b/src/gpu/vk/GrVkPipelineStateBuilder.cpp
index 323ea66946..d9d1b6cfb8 100644
--- a/src/gpu/vk/GrVkPipelineStateBuilder.cpp
+++ b/src/gpu/vk/GrVkPipelineStateBuilder.cpp
@@ -93,8 +93,6 @@ shaderc_shader_kind vk_shader_stage_to_shaderc_kind(VkShaderStageFlagBits stage)
}
#endif
-#include <fstream>
-#include <sstream>
bool GrVkPipelineStateBuilder::CreateVkShaderModule(const GrVkGpu* gpu,
VkShaderStageFlagBits stage,
const GrGLSLShaderBuilder& builder,
diff --git a/src/sksl/SkSLCompiler.cpp b/src/sksl/SkSLCompiler.cpp
index 2b4adc1026..0d65b107ec 100644
--- a/src/sksl/SkSLCompiler.cpp
+++ b/src/sksl/SkSLCompiler.cpp
@@ -41,9 +41,10 @@ Compiler::Compiler()
: fErrorCount(0) {
auto types = std::shared_ptr<SymbolTable>(new SymbolTable(*this));
auto symbols = std::shared_ptr<SymbolTable>(new SymbolTable(types, *this));
- fIRGenerator = new IRGenerator(symbols, *this);
+ fIRGenerator = new IRGenerator(&fContext, symbols, *this);
fTypes = types;
- #define ADD_TYPE(t) types->add(k ## t ## _Type->fName, k ## t ## _Type)
+ #define ADD_TYPE(t) types->addWithoutOwnership(fContext.f ## t ## _Type->fName, \
+ fContext.f ## t ## _Type.get())
ADD_TYPE(Void);
ADD_TYPE(Float);
ADD_TYPE(Vec2);
@@ -185,19 +186,21 @@ std::unique_ptr<Program> Compiler::convertProgram(Program::Kind kind, std::strin
fErrorText = "";
fErrorCount = 0;
fIRGenerator->pushSymbolTable();
- std::vector<std::unique_ptr<ProgramElement>> result;
+ std::vector<std::unique_ptr<ProgramElement>> elements;
switch (kind) {
case Program::kVertex_Kind:
- this->internalConvertProgram(SKSL_VERT_INCLUDE, &result);
+ this->internalConvertProgram(SKSL_VERT_INCLUDE, &elements);
break;
case Program::kFragment_Kind:
- this->internalConvertProgram(SKSL_FRAG_INCLUDE, &result);
+ this->internalConvertProgram(SKSL_FRAG_INCLUDE, &elements);
break;
}
- this->internalConvertProgram(text, &result);
+ this->internalConvertProgram(text, &elements);
+ auto result = std::unique_ptr<Program>(new Program(kind, std::move(elements),
+ fIRGenerator->fSymbolTable));;
fIRGenerator->popSymbolTable();
this->writeErrorCount();
- return std::unique_ptr<Program>(new Program(kind, std::move(result)));;
+ return result;
}
void Compiler::error(Position position, std::string msg) {
@@ -224,7 +227,7 @@ void Compiler::writeErrorCount() {
bool Compiler::toSPIRV(Program::Kind kind, std::string text, std::ostream& out) {
auto program = this->convertProgram(kind, text);
if (fErrorCount == 0) {
- SkSL::SPIRVCodeGenerator cg;
+ SkSL::SPIRVCodeGenerator cg(&fContext);
cg.generateCode(*program.get(), out);
ASSERT(!out.rdstate());
}
diff --git a/src/sksl/SkSLCompiler.h b/src/sksl/SkSLCompiler.h
index 2209427eef..e63d5f4ed8 100644
--- a/src/sksl/SkSLCompiler.h
+++ b/src/sksl/SkSLCompiler.h
@@ -11,6 +11,7 @@
#include <vector>
#include "ir/SkSLProgram.h"
#include "ir/SkSLSymbolTable.h"
+#include "SkSLContext.h"
#include "SkSLErrorReporter.h"
namespace SkSL {
@@ -50,6 +51,7 @@ private:
IRGenerator* fIRGenerator;
std::string fSkiaVertText; // FIXME store parsed version instead
+ Context fContext;
int fErrorCount;
std::string fErrorText;
};
diff --git a/src/sksl/SkSLContext.h b/src/sksl/SkSLContext.h
new file mode 100644
index 0000000000..1f124d05eb
--- /dev/null
+++ b/src/sksl/SkSLContext.h
@@ -0,0 +1,227 @@
+/*
+ * Copyright 2016 Google Inc.
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#ifndef SKSL_CONTEXT
+#define SKSL_CONTEXT
+
+#include "ir/SkSLType.h"
+
+namespace SkSL {
+
+/**
+ * Contains compiler-wide objects, which currently means the core types.
+ */
+class Context {
+public:
+ Context()
+ : fVoid_Type(new Type("void"))
+ , fDouble_Type(new Type("double", true))
+ , fDVec2_Type(new Type("dvec2", *fDouble_Type, 2))
+ , fDVec3_Type(new Type("dvec3", *fDouble_Type, 3))
+ , fDVec4_Type(new Type("dvec4", *fDouble_Type, 4))
+ , fFloat_Type(new Type("float", true, { fDouble_Type.get() }))
+ , fVec2_Type(new Type("vec2", *fFloat_Type, 2))
+ , fVec3_Type(new Type("vec3", *fFloat_Type, 3))
+ , fVec4_Type(new Type("vec4", *fFloat_Type, 4))
+ , fUInt_Type(new Type("uint", true, { fFloat_Type.get(), fDouble_Type.get() }))
+ , fUVec2_Type(new Type("uvec2", *fUInt_Type, 2))
+ , fUVec3_Type(new Type("uvec3", *fUInt_Type, 3))
+ , fUVec4_Type(new Type("uvec4", *fUInt_Type, 4))
+ , fInt_Type(new Type("int", true, { fUInt_Type.get(), fFloat_Type.get(), fDouble_Type.get() }))
+ , fIVec2_Type(new Type("ivec2", *fInt_Type, 2))
+ , fIVec3_Type(new Type("ivec3", *fInt_Type, 3))
+ , fIVec4_Type(new Type("ivec4", *fInt_Type, 4))
+ , fBool_Type(new Type("bool", false))
+ , fBVec2_Type(new Type("bvec2", *fBool_Type, 2))
+ , fBVec3_Type(new Type("bvec3", *fBool_Type, 3))
+ , fBVec4_Type(new Type("bvec4", *fBool_Type, 4))
+ , fMat2x2_Type(new Type("mat2", *fFloat_Type, 2, 2))
+ , fMat2x3_Type(new Type("mat2x3", *fFloat_Type, 2, 3))
+ , fMat2x4_Type(new Type("mat2x4", *fFloat_Type, 2, 4))
+ , fMat3x2_Type(new Type("mat3x2", *fFloat_Type, 3, 2))
+ , fMat3x3_Type(new Type("mat3", *fFloat_Type, 3, 3))
+ , fMat3x4_Type(new Type("mat3x4", *fFloat_Type, 3, 4))
+ , fMat4x2_Type(new Type("mat4x2", *fFloat_Type, 4, 2))
+ , fMat4x3_Type(new Type("mat4x3", *fFloat_Type, 4, 3))
+ , fMat4x4_Type(new Type("mat4", *fFloat_Type, 4, 4))
+ , fDMat2x2_Type(new Type("dmat2", *fFloat_Type, 2, 2))
+ , fDMat2x3_Type(new Type("dmat2x3", *fFloat_Type, 2, 3))
+ , fDMat2x4_Type(new Type("dmat2x4", *fFloat_Type, 2, 4))
+ , fDMat3x2_Type(new Type("dmat3x2", *fFloat_Type, 3, 2))
+ , fDMat3x3_Type(new Type("dmat3", *fFloat_Type, 3, 3))
+ , fDMat3x4_Type(new Type("dmat3x4", *fFloat_Type, 3, 4))
+ , fDMat4x2_Type(new Type("dmat4x2", *fFloat_Type, 4, 2))
+ , fDMat4x3_Type(new Type("dmat4x3", *fFloat_Type, 4, 3))
+ , fDMat4x4_Type(new Type("dmat4", *fFloat_Type, 4, 4))
+ , fSampler1D_Type(new Type("sampler1D", SpvDim1D, false, false, false, true))
+ , fSampler2D_Type(new Type("sampler2D", SpvDim2D, false, false, false, true))
+ , fSampler3D_Type(new Type("sampler3D", SpvDim3D, false, false, false, true))
+ , fSamplerCube_Type(new Type("samplerCube"))
+ , fSampler2DRect_Type(new Type("sampler2DRect"))
+ , fSampler1DArray_Type(new Type("sampler1DArray"))
+ , fSampler2DArray_Type(new Type("sampler2DArray"))
+ , fSamplerCubeArray_Type(new Type("samplerCubeArray"))
+ , fSamplerBuffer_Type(new Type("samplerBuffer"))
+ , fSampler2DMS_Type(new Type("sampler2DMS"))
+ , fSampler2DMSArray_Type(new Type("sampler2DMSArray"))
+ , fSampler1DShadow_Type(new Type("sampler1DShadow"))
+ , fSampler2DShadow_Type(new Type("sampler2DShadow"))
+ , fSamplerCubeShadow_Type(new Type("samplerCubeShadow"))
+ , fSampler2DRectShadow_Type(new Type("sampler2DRectShadow"))
+ , fSampler1DArrayShadow_Type(new Type("sampler1DArrayShadow"))
+ , fSampler2DArrayShadow_Type(new Type("sampler2DArrayShadow"))
+ , fSamplerCubeArrayShadow_Type(new Type("samplerCubeArrayShadow"))
+ // FIXME figure out what we're supposed to do with the gsampler et al. types)
+ , fGSampler1D_Type(new Type("$gsampler1D", static_type(*fSampler1D_Type)))
+ , fGSampler2D_Type(new Type("$gsampler2D", static_type(*fSampler2D_Type)))
+ , fGSampler3D_Type(new Type("$gsampler3D", static_type(*fSampler3D_Type)))
+ , fGSamplerCube_Type(new Type("$gsamplerCube", static_type(*fSamplerCube_Type)))
+ , fGSampler2DRect_Type(new Type("$gsampler2DRect", static_type(*fSampler2DRect_Type)))
+ , fGSampler1DArray_Type(new Type("$gsampler1DArray", static_type(*fSampler1DArray_Type)))
+ , fGSampler2DArray_Type(new Type("$gsampler2DArray", static_type(*fSampler2DArray_Type)))
+ , fGSamplerCubeArray_Type(new Type("$gsamplerCubeArray", static_type(*fSamplerCubeArray_Type)))
+ , fGSamplerBuffer_Type(new Type("$gsamplerBuffer", static_type(*fSamplerBuffer_Type)))
+ , fGSampler2DMS_Type(new Type("$gsampler2DMS", static_type(*fSampler2DMS_Type)))
+ , fGSampler2DMSArray_Type(new Type("$gsampler2DMSArray", static_type(*fSampler2DMSArray_Type)))
+ , fGSampler2DArrayShadow_Type(new Type("$gsampler2DArrayShadow",
+ static_type(*fSampler2DArrayShadow_Type)))
+ , fGSamplerCubeArrayShadow_Type(new Type("$gsamplerCubeArrayShadow",
+ static_type(*fSamplerCubeArrayShadow_Type)))
+ , fGenType_Type(new Type("$genType", { fFloat_Type.get(), fVec2_Type.get(), fVec3_Type.get(),
+ fVec4_Type.get() }))
+ , fGenDType_Type(new Type("$genDType", { fDouble_Type.get(), fDVec2_Type.get(),
+ fDVec3_Type.get(), fDVec4_Type.get() }))
+ , fGenIType_Type(new Type("$genIType", { fInt_Type.get(), fIVec2_Type.get(), fIVec3_Type.get(),
+ fIVec4_Type.get() }))
+ , fGenUType_Type(new Type("$genUType", { fUInt_Type.get(), fUVec2_Type.get(), fUVec3_Type.get(),
+ fUVec4_Type.get() }))
+ , fGenBType_Type(new Type("$genBType", { fBool_Type.get(), fBVec2_Type.get(), fBVec3_Type.get(),
+ fBVec4_Type.get() }))
+ , fMat_Type(new Type("$mat"))
+ , fVec_Type(new Type("$vec", { fVec2_Type.get(), fVec2_Type.get(), fVec3_Type.get(),
+ fVec4_Type.get() }))
+ , fGVec_Type(new Type("$gvec"))
+ , fGVec2_Type(new Type("$gvec2"))
+ , fGVec3_Type(new Type("$gvec3"))
+ , fGVec4_Type(new Type("$gvec4", static_type(*fVec4_Type)))
+ , fDVec_Type(new Type("$dvec"))
+ , fIVec_Type(new Type("$ivec"))
+ , fUVec_Type(new Type("$uvec"))
+ , fBVec_Type(new Type("$bvec", { fBVec2_Type.get(), fBVec2_Type.get(), fBVec3_Type.get(),
+ fBVec4_Type.get() }))
+ , fInvalid_Type(new Type("<INVALID>")) {}
+
+ static std::vector<const Type*> static_type(const Type& t) {
+ return { &t, &t, &t, &t };
+ }
+
+ const std::unique_ptr<Type> fVoid_Type;
+
+ const std::unique_ptr<Type> fDouble_Type;
+ const std::unique_ptr<Type> fDVec2_Type;
+ const std::unique_ptr<Type> fDVec3_Type;
+ const std::unique_ptr<Type> fDVec4_Type;
+
+ const std::unique_ptr<Type> fFloat_Type;
+ const std::unique_ptr<Type> fVec2_Type;
+ const std::unique_ptr<Type> fVec3_Type;
+ const std::unique_ptr<Type> fVec4_Type;
+
+ const std::unique_ptr<Type> fUInt_Type;
+ const std::unique_ptr<Type> fUVec2_Type;
+ const std::unique_ptr<Type> fUVec3_Type;
+ const std::unique_ptr<Type> fUVec4_Type;
+
+ const std::unique_ptr<Type> fInt_Type;
+ const std::unique_ptr<Type> fIVec2_Type;
+ const std::unique_ptr<Type> fIVec3_Type;
+ const std::unique_ptr<Type> fIVec4_Type;
+
+ const std::unique_ptr<Type> fBool_Type;
+ const std::unique_ptr<Type> fBVec2_Type;
+ const std::unique_ptr<Type> fBVec3_Type;
+ const std::unique_ptr<Type> fBVec4_Type;
+
+ const std::unique_ptr<Type> fMat2x2_Type;
+ const std::unique_ptr<Type> fMat2x3_Type;
+ const std::unique_ptr<Type> fMat2x4_Type;
+ const std::unique_ptr<Type> fMat3x2_Type;
+ const std::unique_ptr<Type> fMat3x3_Type;
+ const std::unique_ptr<Type> fMat3x4_Type;
+ const std::unique_ptr<Type> fMat4x2_Type;
+ const std::unique_ptr<Type> fMat4x3_Type;
+ const std::unique_ptr<Type> fMat4x4_Type;
+
+ const std::unique_ptr<Type> fDMat2x2_Type;
+ const std::unique_ptr<Type> fDMat2x3_Type;
+ const std::unique_ptr<Type> fDMat2x4_Type;
+ const std::unique_ptr<Type> fDMat3x2_Type;
+ const std::unique_ptr<Type> fDMat3x3_Type;
+ const std::unique_ptr<Type> fDMat3x4_Type;
+ const std::unique_ptr<Type> fDMat4x2_Type;
+ const std::unique_ptr<Type> fDMat4x3_Type;
+ const std::unique_ptr<Type> fDMat4x4_Type;
+
+ const std::unique_ptr<Type> fSampler1D_Type;
+ const std::unique_ptr<Type> fSampler2D_Type;
+ const std::unique_ptr<Type> fSampler3D_Type;
+ const std::unique_ptr<Type> fSamplerCube_Type;
+ const std::unique_ptr<Type> fSampler2DRect_Type;
+ const std::unique_ptr<Type> fSampler1DArray_Type;
+ const std::unique_ptr<Type> fSampler2DArray_Type;
+ const std::unique_ptr<Type> fSamplerCubeArray_Type;
+ const std::unique_ptr<Type> fSamplerBuffer_Type;
+ const std::unique_ptr<Type> fSampler2DMS_Type;
+ const std::unique_ptr<Type> fSampler2DMSArray_Type;
+ const std::unique_ptr<Type> fSampler1DShadow_Type;
+ const std::unique_ptr<Type> fSampler2DShadow_Type;
+ const std::unique_ptr<Type> fSamplerCubeShadow_Type;
+ const std::unique_ptr<Type> fSampler2DRectShadow_Type;
+ const std::unique_ptr<Type> fSampler1DArrayShadow_Type;
+ const std::unique_ptr<Type> fSampler2DArrayShadow_Type;
+ const std::unique_ptr<Type> fSamplerCubeArrayShadow_Type;
+
+ const std::unique_ptr<Type> fGSampler1D_Type;
+ const std::unique_ptr<Type> fGSampler2D_Type;
+ const std::unique_ptr<Type> fGSampler3D_Type;
+ const std::unique_ptr<Type> fGSamplerCube_Type;
+ const std::unique_ptr<Type> fGSampler2DRect_Type;
+ const std::unique_ptr<Type> fGSampler1DArray_Type;
+ const std::unique_ptr<Type> fGSampler2DArray_Type;
+ const std::unique_ptr<Type> fGSamplerCubeArray_Type;
+ const std::unique_ptr<Type> fGSamplerBuffer_Type;
+ const std::unique_ptr<Type> fGSampler2DMS_Type;
+ const std::unique_ptr<Type> fGSampler2DMSArray_Type;
+ const std::unique_ptr<Type> fGSampler2DArrayShadow_Type;
+ const std::unique_ptr<Type> fGSamplerCubeArrayShadow_Type;
+
+ const std::unique_ptr<Type> fGenType_Type;
+ const std::unique_ptr<Type> fGenDType_Type;
+ const std::unique_ptr<Type> fGenIType_Type;
+ const std::unique_ptr<Type> fGenUType_Type;
+ const std::unique_ptr<Type> fGenBType_Type;
+
+ const std::unique_ptr<Type> fMat_Type;
+
+ const std::unique_ptr<Type> fVec_Type;
+
+ const std::unique_ptr<Type> fGVec_Type;
+ const std::unique_ptr<Type> fGVec2_Type;
+ const std::unique_ptr<Type> fGVec3_Type;
+ const std::unique_ptr<Type> fGVec4_Type;
+ const std::unique_ptr<Type> fDVec_Type;
+ const std::unique_ptr<Type> fIVec_Type;
+ const std::unique_ptr<Type> fUVec_Type;
+
+ const std::unique_ptr<Type> fBVec_Type;
+
+ const std::unique_ptr<Type> fInvalid_Type;
+};
+
+} // namespace
+
+#endif
diff --git a/src/sksl/SkSLIRGenerator.cpp b/src/sksl/SkSLIRGenerator.cpp
index 2cc7eacb4d..f250c4bb0c 100644
--- a/src/sksl/SkSLIRGenerator.cpp
+++ b/src/sksl/SkSLIRGenerator.cpp
@@ -66,11 +66,12 @@ public:
std::shared_ptr<SymbolTable> fPrevious;
};
-IRGenerator::IRGenerator(std::shared_ptr<SymbolTable> symbolTable,
+IRGenerator::IRGenerator(const Context* context, std::shared_ptr<SymbolTable> symbolTable,
ErrorReporter& errorReporter)
-: fSymbolTable(std::move(symbolTable))
-, fErrors(errorReporter) {
-}
+: fContext(*context)
+, fCurrentFunction(nullptr)
+, fSymbolTable(std::move(symbolTable))
+, fErrors(errorReporter) {}
void IRGenerator::pushSymbolTable() {
fSymbolTable.reset(new SymbolTable(std::move(fSymbolTable), fErrors));
@@ -123,7 +124,7 @@ std::unique_ptr<Block> IRGenerator::convertBlock(const ASTBlock& block) {
}
statements.push_back(std::move(statement));
}
- return std::unique_ptr<Block>(new Block(block.fPosition, std::move(statements)));
+ return std::unique_ptr<Block>(new Block(block.fPosition, std::move(statements), fSymbolTable));
}
std::unique_ptr<Statement> IRGenerator::convertVarDeclarationStatement(
@@ -141,22 +142,22 @@ Modifiers IRGenerator::convertModifiers(const ASTModifiers& modifiers) {
std::unique_ptr<VarDeclaration> IRGenerator::convertVarDeclaration(const ASTVarDeclaration& decl,
Variable::Storage storage) {
- std::vector<std::shared_ptr<Variable>> variables;
+ std::vector<const Variable*> variables;
std::vector<std::vector<std::unique_ptr<Expression>>> sizes;
std::vector<std::unique_ptr<Expression>> values;
- std::shared_ptr<Type> baseType = this->convertType(*decl.fType);
+ const Type* baseType = this->convertType(*decl.fType);
if (!baseType) {
return nullptr;
}
for (size_t i = 0; i < decl.fNames.size(); i++) {
Modifiers modifiers = this->convertModifiers(decl.fModifiers);
- std::shared_ptr<Type> type = baseType;
+ const Type* type = baseType;
ASSERT(type->kind() != Type::kArray_Kind);
std::vector<std::unique_ptr<Expression>> currentVarSizes;
for (size_t j = 0; j < decl.fSizes[i].size(); j++) {
if (decl.fSizes[i][j]) {
ASTExpression& rawSize = *decl.fSizes[i][j];
- auto size = this->coerce(this->convertExpression(rawSize), kInt_Type);
+ auto size = this->coerce(this->convertExpression(rawSize), *fContext.fInt_Type);
if (!size) {
return nullptr;
}
@@ -172,27 +173,28 @@ std::unique_ptr<VarDeclaration> IRGenerator::convertVarDeclaration(const ASTVarD
count = -1;
name += "[]";
}
- type = std::shared_ptr<Type>(new Type(name, Type::kArray_Kind, type, (int) count));
+ type = new Type(name, Type::kArray_Kind, *type, (int) count);
+ fSymbolTable->takeOwnership((Type*) type);
currentVarSizes.push_back(std::move(size));
} else {
- type = std::shared_ptr<Type>(new Type(type->fName + "[]", Type::kArray_Kind, type,
- -1));
+ type = new Type(type->fName + "[]", Type::kArray_Kind, *type, -1);
+ fSymbolTable->takeOwnership((Type*) type);
currentVarSizes.push_back(nullptr);
}
}
sizes.push_back(std::move(currentVarSizes));
- auto var = std::make_shared<Variable>(decl.fPosition, modifiers, decl.fNames[i], type,
- storage);
- variables.push_back(var);
+ auto var = std::unique_ptr<Variable>(new Variable(decl.fPosition, modifiers, decl.fNames[i],
+ *type, storage));
std::unique_ptr<Expression> value;
if (decl.fValues[i]) {
value = this->convertExpression(*decl.fValues[i]);
if (!value) {
return nullptr;
}
- value = this->coerce(std::move(value), type);
+ value = this->coerce(std::move(value), *type);
}
- fSymbolTable->add(var->fName, var);
+ variables.push_back(var.get());
+ fSymbolTable->add(decl.fNames[i], std::move(var));
values.push_back(std::move(value));
}
return std::unique_ptr<VarDeclaration>(new VarDeclaration(decl.fPosition, std::move(variables),
@@ -200,7 +202,8 @@ std::unique_ptr<VarDeclaration> IRGenerator::convertVarDeclaration(const ASTVarD
}
std::unique_ptr<Statement> IRGenerator::convertIf(const ASTIfStatement& s) {
- std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*s.fTest), kBool_Type);
+ std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*s.fTest),
+ *fContext.fBool_Type);
if (!test) {
return nullptr;
}
@@ -225,7 +228,8 @@ std::unique_ptr<Statement> IRGenerator::convertFor(const ASTForStatement& f) {
if (!initializer) {
return nullptr;
}
- std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*f.fTest), kBool_Type);
+ std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*f.fTest),
+ *fContext.fBool_Type);
if (!test) {
return nullptr;
}
@@ -240,11 +244,12 @@ std::unique_ptr<Statement> IRGenerator::convertFor(const ASTForStatement& f) {
}
return std::unique_ptr<Statement>(new ForStatement(f.fPosition, std::move(initializer),
std::move(test), std::move(next),
- std::move(statement)));
+ std::move(statement), fSymbolTable));
}
std::unique_ptr<Statement> IRGenerator::convertWhile(const ASTWhileStatement& w) {
- std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*w.fTest), kBool_Type);
+ std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*w.fTest),
+ *fContext.fBool_Type);
if (!test) {
return nullptr;
}
@@ -257,7 +262,8 @@ std::unique_ptr<Statement> IRGenerator::convertWhile(const ASTWhileStatement& w)
}
std::unique_ptr<Statement> IRGenerator::convertDo(const ASTDoStatement& d) {
- std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*d.fTest), kBool_Type);
+ std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*d.fTest),
+ *fContext.fBool_Type);
if (!test) {
return nullptr;
}
@@ -286,7 +292,7 @@ std::unique_ptr<Statement> IRGenerator::convertReturn(const ASTReturnStatement&
if (!result) {
return nullptr;
}
- if (fCurrentFunction->fReturnType == kVoid_Type) {
+ if (fCurrentFunction->fReturnType == *fContext.fVoid_Type) {
fErrors.error(result->fPosition, "may not return a value from a void function");
} else {
result = this->coerce(std::move(result), fCurrentFunction->fReturnType);
@@ -296,9 +302,9 @@ std::unique_ptr<Statement> IRGenerator::convertReturn(const ASTReturnStatement&
}
return std::unique_ptr<Statement>(new ReturnStatement(std::move(result)));
} else {
- if (fCurrentFunction->fReturnType != kVoid_Type) {
+ if (fCurrentFunction->fReturnType != *fContext.fVoid_Type) {
fErrors.error(r.fPosition, "expected function to return '" +
- fCurrentFunction->fReturnType->description() + "'");
+ fCurrentFunction->fReturnType.description() + "'");
}
return std::unique_ptr<Statement>(new ReturnStatement(r.fPosition));
}
@@ -316,80 +322,74 @@ std::unique_ptr<Statement> IRGenerator::convertDiscard(const ASTDiscardStatement
return std::unique_ptr<Statement>(new DiscardStatement(d.fPosition));
}
-static std::shared_ptr<Type> expand_generics(std::shared_ptr<Type> type, int i) {
- if (type->kind() == Type::kGeneric_Kind) {
- return type->coercibleTypes()[i];
+static const Type& expand_generics(const Type& type, int i) {
+ if (type.kind() == Type::kGeneric_Kind) {
+ return *type.coercibleTypes()[i];
}
return type;
}
-static void expand_generics(FunctionDeclaration& decl,
- SymbolTable& symbolTable) {
+static void expand_generics(const FunctionDeclaration& decl,
+ std::shared_ptr<SymbolTable> symbolTable) {
for (int i = 0; i < 4; i++) {
- std::shared_ptr<Type> returnType = expand_generics(decl.fReturnType, i);
- std::vector<std::shared_ptr<Variable>> arguments;
+ const Type& returnType = expand_generics(decl.fReturnType, i);
+ std::vector<const Variable*> parameters;
for (const auto& p : decl.fParameters) {
- arguments.push_back(std::shared_ptr<Variable>(new Variable(
- p->fPosition,
- Modifiers(p->fModifiers),
- p->fName,
- expand_generics(p->fType, i),
- Variable::kParameter_Storage)));
+ Variable* var = new Variable(p->fPosition, Modifiers(p->fModifiers), p->fName,
+ expand_generics(p->fType, i),
+ Variable::kParameter_Storage);
+ symbolTable->takeOwnership(var);
+ parameters.push_back(var);
}
- std::shared_ptr<FunctionDeclaration> expanded(new FunctionDeclaration(
- decl.fPosition,
- decl.fName,
- std::move(arguments),
- std::move(returnType)));
- symbolTable.add(expanded->fName, expanded);
+ symbolTable->add(decl.fName, std::unique_ptr<FunctionDeclaration>(new FunctionDeclaration(
+ decl.fPosition,
+ decl.fName,
+ std::move(parameters),
+ std::move(returnType))));
}
}
std::unique_ptr<FunctionDefinition> IRGenerator::convertFunction(const ASTFunction& f) {
- std::shared_ptr<SymbolTable> old = fSymbolTable;
- AutoSymbolTable table(this);
bool isGeneric;
- std::shared_ptr<Type> returnType = this->convertType(*f.fReturnType);
+ const Type* returnType = this->convertType(*f.fReturnType);
if (!returnType) {
return nullptr;
}
isGeneric = returnType->kind() == Type::kGeneric_Kind;
- std::vector<std::shared_ptr<Variable>> parameters;
+ std::vector<const Variable*> parameters;
for (const auto& param : f.fParameters) {
- std::shared_ptr<Type> type = this->convertType(*param->fType);
+ const Type* type = this->convertType(*param->fType);
if (!type) {
return nullptr;
}
for (int j = (int) param->fSizes.size() - 1; j >= 0; j--) {
int size = param->fSizes[j];
std::string name = type->name() + "[" + to_string(size) + "]";
- type = std::shared_ptr<Type>(new Type(std::move(name), Type::kArray_Kind,
- std::move(type), size));
+ Type* newType = new Type(std::move(name), Type::kArray_Kind, *type, size);
+ fSymbolTable->takeOwnership(newType);
+ type = newType;
}
std::string name = param->fName;
Modifiers modifiers = this->convertModifiers(param->fModifiers);
Position pos = param->fPosition;
- std::shared_ptr<Variable> var = std::shared_ptr<Variable>(new Variable(
- pos,
- modifiers,
- std::move(name),
- type,
- Variable::kParameter_Storage));
- parameters.push_back(std::move(var));
+ Variable* var = new Variable(pos, modifiers, std::move(name), *type,
+ Variable::kParameter_Storage);
+ fSymbolTable->takeOwnership(var);
+ parameters.push_back(var);
isGeneric |= type->kind() == Type::kGeneric_Kind;
}
// find existing declaration
- std::shared_ptr<FunctionDeclaration> decl;
- auto entry = (*old)[f.fName];
+ const FunctionDeclaration* decl = nullptr;
+ auto entry = (*fSymbolTable)[f.fName];
if (entry) {
- std::vector<std::shared_ptr<FunctionDeclaration>> functions;
+ std::vector<const FunctionDeclaration*> functions;
switch (entry->fKind) {
case Symbol::kUnresolvedFunction_Kind:
- functions = std::static_pointer_cast<UnresolvedFunction>(entry)->fFunctions;
+ functions = ((UnresolvedFunction*) entry)->fFunctions;
break;
case Symbol::kFunctionDeclaration_Kind:
- functions.push_back(std::static_pointer_cast<FunctionDeclaration>(entry));
+ functions.push_back((FunctionDeclaration*) entry);
break;
default:
fErrors.error(f.fPosition, "symbol '" + f.fName + "' was already defined");
@@ -406,11 +406,8 @@ std::unique_ptr<FunctionDefinition> IRGenerator::convertFunction(const ASTFuncti
}
}
if (match) {
- if (returnType != other->fReturnType) {
- FunctionDeclaration newDecl = FunctionDeclaration(f.fPosition,
- f.fName,
- parameters,
- returnType);
+ if (*returnType != other->fReturnType) {
+ FunctionDeclaration newDecl(f.fPosition, f.fName, parameters, *returnType);
fErrors.error(f.fPosition, "functions '" + newDecl.description() +
"' and '" + other->description() +
"' differ only in return type");
@@ -424,7 +421,6 @@ std::unique_ptr<FunctionDefinition> IRGenerator::convertFunction(const ASTFuncti
"declaration and definition");
return nullptr;
}
- fSymbolTable->add(parameters[i]->fName, decl->fParameters[i]);
}
if (other->fDefined) {
fErrors.error(f.fPosition, "duplicate definition of " +
@@ -437,28 +433,36 @@ std::unique_ptr<FunctionDefinition> IRGenerator::convertFunction(const ASTFuncti
}
if (!decl) {
// couldn't find an existing declaration
- decl.reset(new FunctionDeclaration(f.fPosition, f.fName, parameters, returnType));
- for (auto var : parameters) {
- fSymbolTable->add(var->fName, var);
+ if (isGeneric) {
+ ASSERT(!f.fBody);
+ expand_generics(FunctionDeclaration(f.fPosition, f.fName, parameters, *returnType),
+ fSymbolTable);
+ } else {
+ auto newDecl = std::unique_ptr<FunctionDeclaration>(new FunctionDeclaration(
+ f.fPosition,
+ f.fName,
+ parameters,
+ *returnType));
+ decl = newDecl.get();
+ fSymbolTable->add(decl->fName, std::move(newDecl));
}
}
- if (isGeneric) {
- ASSERT(!f.fBody);
- expand_generics(*decl, *old);
- } else {
- old->add(decl->fName, decl);
- if (f.fBody) {
- ASSERT(!fCurrentFunction);
- fCurrentFunction = decl;
- decl->fDefined = true;
- std::unique_ptr<Block> body = this->convertBlock(*f.fBody);
- fCurrentFunction = nullptr;
- if (!body) {
- return nullptr;
- }
- return std::unique_ptr<FunctionDefinition>(new FunctionDefinition(f.fPosition, decl,
- std::move(body)));
+ if (f.fBody) {
+ ASSERT(!fCurrentFunction);
+ fCurrentFunction = decl;
+ decl->fDefined = true;
+ std::shared_ptr<SymbolTable> old = fSymbolTable;
+ AutoSymbolTable table(this);
+ for (size_t i = 0; i < parameters.size(); i++) {
+ fSymbolTable->addWithoutOwnership(parameters[i]->fName, decl->fParameters[i]);
+ }
+ std::unique_ptr<Block> body = this->convertBlock(*f.fBody);
+ fCurrentFunction = nullptr;
+ if (!body) {
+ return nullptr;
}
+ return std::unique_ptr<FunctionDefinition>(new FunctionDefinition(f.fPosition, *decl,
+ std::move(body)));
}
return nullptr;
}
@@ -488,28 +492,26 @@ std::unique_ptr<InterfaceBlock> IRGenerator::convertInterfaceBlock(const ASTInte
}
}
}
- std::shared_ptr<Type> type = std::shared_ptr<Type>(new Type(intf.fInterfaceName, fields));
+ Type* type = new Type(intf.fInterfaceName, fields);
+ fSymbolTable->takeOwnership(type);
std::string name = intf.fValueName.length() > 0 ? intf.fValueName : intf.fInterfaceName;
- std::shared_ptr<Variable> var = std::shared_ptr<Variable>(new Variable(intf.fPosition, mods,
- name, type,
- Variable::kGlobal_Storage));
+ Variable* var = new Variable(intf.fPosition, mods, name, *type, Variable::kGlobal_Storage);
+ fSymbolTable->takeOwnership(var);
if (intf.fValueName.length()) {
- old->add(intf.fValueName, var);
-
+ old->addWithoutOwnership(intf.fValueName, var);
} else {
for (size_t i = 0; i < fields.size(); i++) {
- std::shared_ptr<Field> field = std::shared_ptr<Field>(new Field(intf.fPosition, var,
- (int) i));
- old->add(fields[i].fName, field);
+ old->add(fields[i].fName, std::unique_ptr<Field>(new Field(intf.fPosition, *var,
+ (int) i)));
}
}
- return std::unique_ptr<InterfaceBlock>(new InterfaceBlock(intf.fPosition, var));
+ return std::unique_ptr<InterfaceBlock>(new InterfaceBlock(intf.fPosition, *var, fSymbolTable));
}
-std::shared_ptr<Type> IRGenerator::convertType(const ASTType& type) {
- std::shared_ptr<Symbol> result = (*fSymbolTable)[type.fName];
+const Type* IRGenerator::convertType(const ASTType& type) {
+ const Symbol* result = (*fSymbolTable)[type.fName];
if (result && result->fKind == Symbol::kType_Kind) {
- return std::static_pointer_cast<Type>(result);
+ return (const Type*) result;
}
fErrors.error(type.fPosition, "unknown type '" + type.fName + "'");
return nullptr;
@@ -520,13 +522,13 @@ std::unique_ptr<Expression> IRGenerator::convertExpression(const ASTExpression&
case ASTExpression::kIdentifier_Kind:
return this->convertIdentifier((ASTIdentifier&) expr);
case ASTExpression::kBool_Kind:
- return std::unique_ptr<Expression>(new BoolLiteral(expr.fPosition,
+ return std::unique_ptr<Expression>(new BoolLiteral(fContext, expr.fPosition,
((ASTBoolLiteral&) expr).fValue));
case ASTExpression::kInt_Kind:
- return std::unique_ptr<Expression>(new IntLiteral(expr.fPosition,
+ return std::unique_ptr<Expression>(new IntLiteral(fContext, expr.fPosition,
((ASTIntLiteral&) expr).fValue));
case ASTExpression::kFloat_Kind:
- return std::unique_ptr<Expression>(new FloatLiteral(expr.fPosition,
+ return std::unique_ptr<Expression>(new FloatLiteral(fContext, expr.fPosition,
((ASTFloatLiteral&) expr).fValue));
case ASTExpression::kBinary_Kind:
return this->convertBinaryExpression((ASTBinaryExpression&) expr);
@@ -542,40 +544,42 @@ std::unique_ptr<Expression> IRGenerator::convertExpression(const ASTExpression&
}
std::unique_ptr<Expression> IRGenerator::convertIdentifier(const ASTIdentifier& identifier) {
- std::shared_ptr<Symbol> result = (*fSymbolTable)[identifier.fText];
+ const Symbol* result = (*fSymbolTable)[identifier.fText];
if (!result) {
fErrors.error(identifier.fPosition, "unknown identifier '" + identifier.fText + "'");
return nullptr;
}
switch (result->fKind) {
case Symbol::kFunctionDeclaration_Kind: {
- std::vector<std::shared_ptr<FunctionDeclaration>> f = {
- std::static_pointer_cast<FunctionDeclaration>(result)
+ std::vector<const FunctionDeclaration*> f = {
+ (const FunctionDeclaration*) result
};
- return std::unique_ptr<FunctionReference>(new FunctionReference(identifier.fPosition,
- std::move(f)));
+ return std::unique_ptr<FunctionReference>(new FunctionReference(fContext,
+ identifier.fPosition,
+ f));
}
case Symbol::kUnresolvedFunction_Kind: {
- auto f = std::static_pointer_cast<UnresolvedFunction>(result);
- return std::unique_ptr<FunctionReference>(new FunctionReference(identifier.fPosition,
+ const UnresolvedFunction* f = (const UnresolvedFunction*) result;
+ return std::unique_ptr<FunctionReference>(new FunctionReference(fContext,
+ identifier.fPosition,
f->fFunctions));
}
case Symbol::kVariable_Kind: {
- std::shared_ptr<Variable> var = std::static_pointer_cast<Variable>(result);
- this->markReadFrom(var);
+ const Variable* var = (const Variable*) result;
+ this->markReadFrom(*var);
return std::unique_ptr<VariableReference>(new VariableReference(identifier.fPosition,
- std::move(var)));
+ *var));
}
case Symbol::kField_Kind: {
- std::shared_ptr<Field> field = std::static_pointer_cast<Field>(result);
+ const Field* field = (const Field*) result;
VariableReference* base = new VariableReference(identifier.fPosition, field->fOwner);
return std::unique_ptr<Expression>(new FieldAccess(std::unique_ptr<Expression>(base),
field->fFieldIndex));
}
case Symbol::kType_Kind: {
- auto t = std::static_pointer_cast<Type>(result);
- return std::unique_ptr<TypeReference>(new TypeReference(identifier.fPosition,
- std::move(t)));
+ const Type* t = (const Type*) result;
+ return std::unique_ptr<TypeReference>(new TypeReference(fContext, identifier.fPosition,
+ *t));
}
default:
ABORT("unsupported symbol type %d\n", result->fKind);
@@ -584,43 +588,45 @@ std::unique_ptr<Expression> IRGenerator::convertIdentifier(const ASTIdentifier&
}
std::unique_ptr<Expression> IRGenerator::coerce(std::unique_ptr<Expression> expr,
- std::shared_ptr<Type> type) {
+ const Type& type) {
if (!expr) {
return nullptr;
}
- if (*expr->fType == *type) {
+ if (expr->fType == type) {
return expr;
}
this->checkValid(*expr);
- if (*expr->fType == *kInvalid_Type) {
+ if (expr->fType == *fContext.fInvalid_Type) {
return nullptr;
}
- if (!expr->fType->canCoerceTo(type)) {
- fErrors.error(expr->fPosition, "expected '" + type->description() + "', but found '" +
- expr->fType->description() + "'");
+ if (!expr->fType.canCoerceTo(type)) {
+ fErrors.error(expr->fPosition, "expected '" + type.description() + "', but found '" +
+ expr->fType.description() + "'");
return nullptr;
}
- if (type->kind() == Type::kScalar_Kind) {
+ if (type.kind() == Type::kScalar_Kind) {
std::vector<std::unique_ptr<Expression>> args;
args.push_back(std::move(expr));
- ASTIdentifier id(Position(), type->description());
+ ASTIdentifier id(Position(), type.description());
std::unique_ptr<Expression> ctor = this->convertIdentifier(id);
ASSERT(ctor);
return this->call(Position(), std::move(ctor), std::move(args));
}
- ABORT("cannot coerce %s to %s", expr->fType->description().c_str(),
- type->description().c_str());
+ ABORT("cannot coerce %s to %s", expr->fType.description().c_str(),
+ type.description().c_str());
}
/**
* Determines the operand and result types of a binary expression. Returns true if the expression is
* legal, false otherwise. If false, the values of the out parameters are undefined.
*/
-static bool determine_binary_type(Token::Kind op, std::shared_ptr<Type> left,
- std::shared_ptr<Type> right,
- std::shared_ptr<Type>* outLeftType,
- std::shared_ptr<Type>* outRightType,
- std::shared_ptr<Type>* outResultType,
+static bool determine_binary_type(const Context& context,
+ Token::Kind op,
+ const Type& left,
+ const Type& right,
+ const Type** outLeftType,
+ const Type** outRightType,
+ const Type** outResultType,
bool tryFlipped) {
bool isLogical;
switch (op) {
@@ -638,24 +644,25 @@ static bool determine_binary_type(Token::Kind op, std::shared_ptr<Type> left,
case Token::LOGICALOREQ: // fall through
case Token::LOGICALANDEQ: // fall through
case Token::LOGICALXOREQ:
- *outLeftType = kBool_Type;
- *outRightType = kBool_Type;
- *outResultType = kBool_Type;
- return left->canCoerceTo(kBool_Type) && right->canCoerceTo(kBool_Type);
+ *outLeftType = context.fBool_Type.get();
+ *outRightType = context.fBool_Type.get();
+ *outResultType = context.fBool_Type.get();
+ return left.canCoerceTo(*context.fBool_Type) &&
+ right.canCoerceTo(*context.fBool_Type);
case Token::STAR: // fall through
case Token::STAREQ:
// FIXME need to handle non-square matrices
- if (left->kind() == Type::kMatrix_Kind && right->kind() == Type::kVector_Kind) {
- *outLeftType = left;
- *outRightType = right;
- *outResultType = right;
- return left->rows() == right->columns();
+ if (left.kind() == Type::kMatrix_Kind && right.kind() == Type::kVector_Kind) {
+ *outLeftType = &left;
+ *outRightType = &right;
+ *outResultType = &right;
+ return left.rows() == right.columns();
}
- if (left->kind() == Type::kVector_Kind && right->kind() == Type::kMatrix_Kind) {
- *outLeftType = left;
- *outRightType = right;
- *outResultType = left;
- return left->columns() == right->columns();
+ if (left.kind() == Type::kVector_Kind && right.kind() == Type::kMatrix_Kind) {
+ *outLeftType = &left;
+ *outRightType = &right;
+ *outResultType = &left;
+ return left.columns() == right.columns();
}
// fall through
default:
@@ -664,41 +671,42 @@ static bool determine_binary_type(Token::Kind op, std::shared_ptr<Type> left,
// FIXME: need to disallow illegal operations like vec3 > vec3. Also do not currently have
// full support for numbers other than float.
if (left == right) {
- *outLeftType = left;
- *outRightType = left;
+ *outLeftType = &left;
+ *outRightType = &left;
if (isLogical) {
- *outResultType = kBool_Type;
+ *outResultType = context.fBool_Type.get();
} else {
- *outResultType = left;
+ *outResultType = &left;
}
return true;
}
// FIXME: incorrect for shift operations
- if (left->canCoerceTo(right)) {
- *outLeftType = right;
- *outRightType = right;
+ if (left.canCoerceTo(right)) {
+ *outLeftType = &right;
+ *outRightType = &right;
if (isLogical) {
- *outResultType = kBool_Type;
+ *outResultType = context.fBool_Type.get();
} else {
- *outResultType = right;
+ *outResultType = &right;
}
return true;
}
- if ((left->kind() == Type::kVector_Kind || left->kind() == Type::kMatrix_Kind) &&
- (right->kind() == Type::kScalar_Kind)) {
- if (determine_binary_type(op, left->componentType(), right, outLeftType, outRightType,
- outResultType, false)) {
- *outLeftType = (*outLeftType)->toCompound(left->columns(), left->rows());
+ if ((left.kind() == Type::kVector_Kind || left.kind() == Type::kMatrix_Kind) &&
+ (right.kind() == Type::kScalar_Kind)) {
+ if (determine_binary_type(context, op, left.componentType(), right, outLeftType,
+ outRightType, outResultType, false)) {
+ *outLeftType = &(*outLeftType)->toCompound(context, left.columns(), left.rows());
if (!isLogical) {
- *outResultType = (*outResultType)->toCompound(left->columns(), left->rows());
+ *outResultType = &(*outResultType)->toCompound(context, left.columns(),
+ left.rows());
}
return true;
}
return false;
}
if (tryFlipped) {
- return determine_binary_type(op, right, left, outRightType, outLeftType, outResultType,
- false);
+ return determine_binary_type(context, op, right, left, outRightType, outLeftType,
+ outResultType, false);
}
return false;
}
@@ -713,15 +721,15 @@ std::unique_ptr<Expression> IRGenerator::convertBinaryExpression(
if (!right) {
return nullptr;
}
- std::shared_ptr<Type> leftType;
- std::shared_ptr<Type> rightType;
- std::shared_ptr<Type> resultType;
- if (!determine_binary_type(expression.fOperator, left->fType, right->fType, &leftType,
+ const Type* leftType;
+ const Type* rightType;
+ const Type* resultType;
+ if (!determine_binary_type(fContext, expression.fOperator, left->fType, right->fType, &leftType,
&rightType, &resultType, true)) {
fErrors.error(expression.fPosition, "type mismatch: '" +
Token::OperatorName(expression.fOperator) +
- "' cannot operate on '" + left->fType->fName +
- "', '" + right->fType->fName + "'");
+ "' cannot operate on '" + left->fType.fName +
+ "', '" + right->fType.fName + "'");
return nullptr;
}
switch (expression.fOperator) {
@@ -744,17 +752,18 @@ std::unique_ptr<Expression> IRGenerator::convertBinaryExpression(
break;
}
return std::unique_ptr<Expression>(new BinaryExpression(expression.fPosition,
- this->coerce(std::move(left), leftType),
+ this->coerce(std::move(left),
+ *leftType),
expression.fOperator,
this->coerce(std::move(right),
- rightType),
- resultType));
+ *rightType),
+ *resultType));
}
std::unique_ptr<Expression> IRGenerator::convertTernaryExpression(
const ASTTernaryExpression& expression) {
std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*expression.fTest),
- kBool_Type);
+ *fContext.fBool_Type);
if (!test) {
return nullptr;
}
@@ -766,34 +775,33 @@ std::unique_ptr<Expression> IRGenerator::convertTernaryExpression(
if (!ifFalse) {
return nullptr;
}
- std::shared_ptr<Type> trueType;
- std::shared_ptr<Type> falseType;
- std::shared_ptr<Type> resultType;
- if (!determine_binary_type(Token::EQEQ, ifTrue->fType, ifFalse->fType, &trueType,
+ const Type* trueType;
+ const Type* falseType;
+ const Type* resultType;
+ if (!determine_binary_type(fContext, Token::EQEQ, ifTrue->fType, ifFalse->fType, &trueType,
&falseType, &resultType, true)) {
fErrors.error(expression.fPosition, "ternary operator result mismatch: '" +
- ifTrue->fType->fName + "', '" +
- ifFalse->fType->fName + "'");
+ ifTrue->fType.fName + "', '" +
+ ifFalse->fType.fName + "'");
return nullptr;
}
ASSERT(trueType == falseType);
- ifTrue = this->coerce(std::move(ifTrue), trueType);
- ifFalse = this->coerce(std::move(ifFalse), falseType);
+ ifTrue = this->coerce(std::move(ifTrue), *trueType);
+ ifFalse = this->coerce(std::move(ifFalse), *falseType);
return std::unique_ptr<Expression>(new TernaryExpression(expression.fPosition,
std::move(test),
std::move(ifTrue),
std::move(ifFalse)));
}
-std::unique_ptr<Expression> IRGenerator::call(
- Position position,
- std::shared_ptr<FunctionDeclaration> function,
- std::vector<std::unique_ptr<Expression>> arguments) {
- if (function->fParameters.size() != arguments.size()) {
- std::string msg = "call to '" + function->fName + "' expected " +
- to_string(function->fParameters.size()) +
+std::unique_ptr<Expression> IRGenerator::call(Position position,
+ const FunctionDeclaration& function,
+ std::vector<std::unique_ptr<Expression>> arguments) {
+ if (function.fParameters.size() != arguments.size()) {
+ std::string msg = "call to '" + function.fName + "' expected " +
+ to_string(function.fParameters.size()) +
" argument";
- if (function->fParameters.size() != 1) {
+ if (function.fParameters.size() != 1) {
msg += "s";
}
msg += ", but found " + to_string(arguments.size());
@@ -801,12 +809,12 @@ std::unique_ptr<Expression> IRGenerator::call(
return nullptr;
}
for (size_t i = 0; i < arguments.size(); i++) {
- arguments[i] = this->coerce(std::move(arguments[i]), function->fParameters[i]->fType);
- if (arguments[i] && (function->fParameters[i]->fModifiers.fFlags & Modifiers::kOut_Flag)) {
+ arguments[i] = this->coerce(std::move(arguments[i]), function.fParameters[i]->fType);
+ if (arguments[i] && (function.fParameters[i]->fModifiers.fFlags & Modifiers::kOut_Flag)) {
this->markWrittenTo(*arguments[i]);
}
}
- return std::unique_ptr<FunctionCall>(new FunctionCall(position, std::move(function),
+ return std::unique_ptr<FunctionCall>(new FunctionCall(position, function,
std::move(arguments)));
}
@@ -815,16 +823,16 @@ std::unique_ptr<Expression> IRGenerator::call(
* if the cost could be computed, false if the call is not valid. Cost has no particular meaning
* other than "lower costs are preferred".
*/
-bool IRGenerator::determineCallCost(std::shared_ptr<FunctionDeclaration> function,
+bool IRGenerator::determineCallCost(const FunctionDeclaration& function,
const std::vector<std::unique_ptr<Expression>>& arguments,
int* outCost) {
- if (function->fParameters.size() != arguments.size()) {
+ if (function.fParameters.size() != arguments.size()) {
return false;
}
int total = 0;
for (size_t i = 0; i < arguments.size(); i++) {
int cost;
- if (arguments[i]->fType->determineCoercionCost(function->fParameters[i]->fType, &cost)) {
+ if (arguments[i]->fType.determineCoercionCost(function.fParameters[i]->fType, &cost)) {
total += cost;
} else {
return false;
@@ -848,97 +856,97 @@ std::unique_ptr<Expression> IRGenerator::call(Position position,
}
FunctionReference* ref = (FunctionReference*) functionValue.get();
int bestCost = INT_MAX;
- std::shared_ptr<FunctionDeclaration> best;
+ const FunctionDeclaration* best = nullptr;
if (ref->fFunctions.size() > 1) {
for (const auto& f : ref->fFunctions) {
int cost;
- if (this->determineCallCost(f, arguments, &cost) && cost < bestCost) {
+ if (this->determineCallCost(*f, arguments, &cost) && cost < bestCost) {
bestCost = cost;
best = f;
}
}
if (best) {
- return this->call(position, std::move(best), std::move(arguments));
+ return this->call(position, *best, std::move(arguments));
}
std::string msg = "no match for " + ref->fFunctions[0]->fName + "(";
std::string separator = "";
for (size_t i = 0; i < arguments.size(); i++) {
msg += separator;
separator = ", ";
- msg += arguments[i]->fType->description();
+ msg += arguments[i]->fType.description();
}
msg += ")";
fErrors.error(position, msg);
return nullptr;
}
- return this->call(position, ref->fFunctions[0], std::move(arguments));
+ return this->call(position, *ref->fFunctions[0], std::move(arguments));
}
std::unique_ptr<Expression> IRGenerator::convertConstructor(
Position position,
- std::shared_ptr<Type> type,
+ const Type& type,
std::vector<std::unique_ptr<Expression>> args) {
// FIXME: add support for structs and arrays
- Type::Kind kind = type->kind();
- if (!type->isNumber() && kind != Type::kVector_Kind && kind != Type::kMatrix_Kind) {
- fErrors.error(position, "cannot construct '" + type->description() + "'");
+ Type::Kind kind = type.kind();
+ if (!type.isNumber() && kind != Type::kVector_Kind && kind != Type::kMatrix_Kind) {
+ fErrors.error(position, "cannot construct '" + type.description() + "'");
return nullptr;
}
- if (type == kFloat_Type && args.size() == 1 &&
+ if (type == *fContext.fFloat_Type && args.size() == 1 &&
args[0]->fKind == Expression::kIntLiteral_Kind) {
int64_t value = ((IntLiteral&) *args[0]).fValue;
- return std::unique_ptr<Expression>(new FloatLiteral(position, (double) value));
+ return std::unique_ptr<Expression>(new FloatLiteral(fContext, position, (double) value));
}
if (args.size() == 1 && args[0]->fType == type) {
// argument is already the right type, just return it
return std::move(args[0]);
}
- if (type->isNumber()) {
+ if (type.isNumber()) {
if (args.size() != 1) {
- fErrors.error(position, "invalid arguments to '" + type->description() +
+ fErrors.error(position, "invalid arguments to '" + type.description() +
"' constructor, (expected exactly 1 argument, but found " +
to_string(args.size()) + ")");
}
- if (args[0]->fType == kBool_Type) {
- std::unique_ptr<IntLiteral> zero(new IntLiteral(position, 0));
- std::unique_ptr<IntLiteral> one(new IntLiteral(position, 1));
+ if (args[0]->fType == *fContext.fBool_Type) {
+ std::unique_ptr<IntLiteral> zero(new IntLiteral(fContext, position, 0));
+ std::unique_ptr<IntLiteral> one(new IntLiteral(fContext, position, 1));
return std::unique_ptr<Expression>(
new TernaryExpression(position, std::move(args[0]),
this->coerce(std::move(one), type),
this->coerce(std::move(zero),
type)));
- } else if (!args[0]->fType->isNumber()) {
- fErrors.error(position, "invalid argument to '" + type->description() +
+ } else if (!args[0]->fType.isNumber()) {
+ fErrors.error(position, "invalid argument to '" + type.description() +
"' constructor (expected a number or bool, but found '" +
- args[0]->fType->description() + "')");
+ args[0]->fType.description() + "')");
}
} else {
ASSERT(kind == Type::kVector_Kind || kind == Type::kMatrix_Kind);
int actual = 0;
for (size_t i = 0; i < args.size(); i++) {
- if (args[i]->fType->kind() == Type::kVector_Kind ||
- args[i]->fType->kind() == Type::kMatrix_Kind) {
- int columns = args[i]->fType->columns();
- int rows = args[i]->fType->rows();
+ if (args[i]->fType.kind() == Type::kVector_Kind ||
+ args[i]->fType.kind() == Type::kMatrix_Kind) {
+ int columns = args[i]->fType.columns();
+ int rows = args[i]->fType.rows();
args[i] = this->coerce(std::move(args[i]),
- type->componentType()->toCompound(columns, rows));
- actual += args[i]->fType->rows() * args[i]->fType->columns();
- } else if (args[i]->fType->kind() == Type::kScalar_Kind) {
+ type.componentType().toCompound(fContext, columns, rows));
+ actual += args[i]->fType.rows() * args[i]->fType.columns();
+ } else if (args[i]->fType.kind() == Type::kScalar_Kind) {
actual += 1;
- if (type->kind() != Type::kScalar_Kind) {
- args[i] = this->coerce(std::move(args[i]), type->componentType());
+ if (type.kind() != Type::kScalar_Kind) {
+ args[i] = this->coerce(std::move(args[i]), type.componentType());
}
} else {
- fErrors.error(position, "'" + args[i]->fType->description() + "' is not a valid "
- "parameter to '" + type->description() + "' constructor");
+ fErrors.error(position, "'" + args[i]->fType.description() + "' is not a valid "
+ "parameter to '" + type.description() + "' constructor");
return nullptr;
}
}
- int min = type->rows() * type->columns();
- int max = type->columns() > 1 ? INT_MAX : min;
+ int min = type.rows() * type.columns();
+ int max = type.columns() > 1 ? INT_MAX : min;
if ((actual < min || actual > max) &&
!((kind == Type::kVector_Kind || kind == Type::kMatrix_Kind) && (actual == 1))) {
- fErrors.error(position, "invalid arguments to '" + type->description() +
+ fErrors.error(position, "invalid arguments to '" + type.description() +
"' constructor (expected " + to_string(min) + " scalar" +
(min == 1 ? "" : "s") + ", but found " + to_string(actual) +
")");
@@ -956,50 +964,51 @@ std::unique_ptr<Expression> IRGenerator::convertPrefixExpression(
}
switch (expression.fOperator) {
case Token::PLUS:
- if (!base->fType->isNumber() && base->fType->kind() != Type::kVector_Kind) {
+ if (!base->fType.isNumber() && base->fType.kind() != Type::kVector_Kind) {
fErrors.error(expression.fPosition,
- "'+' cannot operate on '" + base->fType->description() + "'");
+ "'+' cannot operate on '" + base->fType.description() + "'");
return nullptr;
}
return base;
case Token::MINUS:
- if (!base->fType->isNumber() && base->fType->kind() != Type::kVector_Kind) {
+ if (!base->fType.isNumber() && base->fType.kind() != Type::kVector_Kind) {
fErrors.error(expression.fPosition,
- "'-' cannot operate on '" + base->fType->description() + "'");
+ "'-' cannot operate on '" + base->fType.description() + "'");
return nullptr;
}
if (base->fKind == Expression::kIntLiteral_Kind) {
- return std::unique_ptr<Expression>(new IntLiteral(base->fPosition,
+ return std::unique_ptr<Expression>(new IntLiteral(fContext, base->fPosition,
-((IntLiteral&) *base).fValue));
}
if (base->fKind == Expression::kFloatLiteral_Kind) {
double value = -((FloatLiteral&) *base).fValue;
- return std::unique_ptr<Expression>(new FloatLiteral(base->fPosition, value));
+ return std::unique_ptr<Expression>(new FloatLiteral(fContext, base->fPosition,
+ value));
}
return std::unique_ptr<Expression>(new PrefixExpression(Token::MINUS, std::move(base)));
case Token::PLUSPLUS:
- if (!base->fType->isNumber()) {
+ if (!base->fType.isNumber()) {
fErrors.error(expression.fPosition,
"'" + Token::OperatorName(expression.fOperator) +
- "' cannot operate on '" + base->fType->description() + "'");
+ "' cannot operate on '" + base->fType.description() + "'");
return nullptr;
}
this->markWrittenTo(*base);
break;
case Token::MINUSMINUS:
- if (!base->fType->isNumber()) {
+ if (!base->fType.isNumber()) {
fErrors.error(expression.fPosition,
"'" + Token::OperatorName(expression.fOperator) +
- "' cannot operate on '" + base->fType->description() + "'");
+ "' cannot operate on '" + base->fType.description() + "'");
return nullptr;
}
this->markWrittenTo(*base);
break;
case Token::NOT:
- if (base->fType != kBool_Type) {
+ if (base->fType != *fContext.fBool_Type) {
fErrors.error(expression.fPosition,
"'" + Token::OperatorName(expression.fOperator) +
- "' cannot operate on '" + base->fType->description() + "'");
+ "' cannot operate on '" + base->fType.description() + "'");
return nullptr;
}
break;
@@ -1012,8 +1021,8 @@ std::unique_ptr<Expression> IRGenerator::convertPrefixExpression(
std::unique_ptr<Expression> IRGenerator::convertIndex(std::unique_ptr<Expression> base,
const ASTExpression& index) {
- if (base->fType->kind() != Type::kArray_Kind && base->fType->kind() != Type::kMatrix_Kind) {
- fErrors.error(base->fPosition, "expected array, but found '" + base->fType->description() +
+ if (base->fType.kind() != Type::kArray_Kind && base->fType.kind() != Type::kMatrix_Kind) {
+ fErrors.error(base->fPosition, "expected array, but found '" + base->fType.description() +
"'");
return nullptr;
}
@@ -1021,30 +1030,31 @@ std::unique_ptr<Expression> IRGenerator::convertIndex(std::unique_ptr<Expression
if (!converted) {
return nullptr;
}
- converted = this->coerce(std::move(converted), kInt_Type);
+ converted = this->coerce(std::move(converted), *fContext.fInt_Type);
if (!converted) {
return nullptr;
}
- return std::unique_ptr<Expression>(new IndexExpression(std::move(base), std::move(converted)));
+ return std::unique_ptr<Expression>(new IndexExpression(fContext, std::move(base),
+ std::move(converted)));
}
std::unique_ptr<Expression> IRGenerator::convertField(std::unique_ptr<Expression> base,
const std::string& field) {
- auto fields = base->fType->fields();
+ auto fields = base->fType.fields();
for (size_t i = 0; i < fields.size(); i++) {
if (fields[i].fName == field) {
return std::unique_ptr<Expression>(new FieldAccess(std::move(base), (int) i));
}
}
- fErrors.error(base->fPosition, "type '" + base->fType->description() + "' does not have a "
+ fErrors.error(base->fPosition, "type '" + base->fType.description() + "' does not have a "
"field named '" + field + "");
return nullptr;
}
std::unique_ptr<Expression> IRGenerator::convertSwizzle(std::unique_ptr<Expression> base,
const std::string& fields) {
- if (base->fType->kind() != Type::kVector_Kind) {
- fErrors.error(base->fPosition, "cannot swizzle type '" + base->fType->description() + "'");
+ if (base->fType.kind() != Type::kVector_Kind) {
+ fErrors.error(base->fPosition, "cannot swizzle type '" + base->fType.description() + "'");
return nullptr;
}
std::vector<int> swizzleComponents;
@@ -1058,7 +1068,7 @@ std::unique_ptr<Expression> IRGenerator::convertSwizzle(std::unique_ptr<Expressi
case 'y': // fall through
case 'g': // fall through
case 't':
- if (base->fType->columns() >= 2) {
+ if (base->fType.columns() >= 2) {
swizzleComponents.push_back(1);
break;
}
@@ -1066,7 +1076,7 @@ std::unique_ptr<Expression> IRGenerator::convertSwizzle(std::unique_ptr<Expressi
case 'z': // fall through
case 'b': // fall through
case 'p':
- if (base->fType->columns() >= 3) {
+ if (base->fType.columns() >= 3) {
swizzleComponents.push_back(2);
break;
}
@@ -1074,7 +1084,7 @@ std::unique_ptr<Expression> IRGenerator::convertSwizzle(std::unique_ptr<Expressi
case 'w': // fall through
case 'a': // fall through
case 'q':
- if (base->fType->columns() >= 4) {
+ if (base->fType.columns() >= 4) {
swizzleComponents.push_back(3);
break;
}
@@ -1090,7 +1100,7 @@ std::unique_ptr<Expression> IRGenerator::convertSwizzle(std::unique_ptr<Expressi
fErrors.error(base->fPosition, "too many components in swizzle mask '" + fields + "'");
return nullptr;
}
- return std::unique_ptr<Expression>(new Swizzle(std::move(base), swizzleComponents));
+ return std::unique_ptr<Expression>(new Swizzle(fContext, std::move(base), swizzleComponents));
}
std::unique_ptr<Expression> IRGenerator::convertSuffixExpression(
@@ -1117,7 +1127,7 @@ std::unique_ptr<Expression> IRGenerator::convertSuffixExpression(
return this->call(expression.fPosition, std::move(base), std::move(arguments));
}
case ASTSuffix::kField_Kind: {
- switch (base->fType->kind()) {
+ switch (base->fType.kind()) {
case Type::kVector_Kind:
return this->convertSwizzle(std::move(base),
((ASTFieldSuffix&) *expression.fSuffix).fField);
@@ -1126,23 +1136,23 @@ std::unique_ptr<Expression> IRGenerator::convertSuffixExpression(
((ASTFieldSuffix&) *expression.fSuffix).fField);
default:
fErrors.error(base->fPosition, "cannot swizzle value of type '" +
- base->fType->description() + "'");
+ base->fType.description() + "'");
return nullptr;
}
}
case ASTSuffix::kPostIncrement_Kind:
- if (!base->fType->isNumber()) {
+ if (!base->fType.isNumber()) {
fErrors.error(expression.fPosition,
- "'++' cannot operate on '" + base->fType->description() + "'");
+ "'++' cannot operate on '" + base->fType.description() + "'");
return nullptr;
}
this->markWrittenTo(*base);
return std::unique_ptr<Expression>(new PostfixExpression(std::move(base),
Token::PLUSPLUS));
case ASTSuffix::kPostDecrement_Kind:
- if (!base->fType->isNumber()) {
+ if (!base->fType.isNumber()) {
fErrors.error(expression.fPosition,
- "'--' cannot operate on '" + base->fType->description() + "'");
+ "'--' cannot operate on '" + base->fType.description() + "'");
return nullptr;
}
this->markWrittenTo(*base);
@@ -1162,13 +1172,13 @@ void IRGenerator::checkValid(const Expression& expr) {
fErrors.error(expr.fPosition, "expected '(' to begin constructor invocation");
break;
default:
- ASSERT(expr.fType != kInvalid_Type);
+ ASSERT(expr.fType != *fContext.fInvalid_Type);
break;
}
}
-void IRGenerator::markReadFrom(std::shared_ptr<Variable> var) {
- var->fIsReadFrom = true;
+void IRGenerator::markReadFrom(const Variable& var) {
+ var.fIsReadFrom = true;
}
static bool has_duplicates(const Swizzle& swizzle) {
@@ -1187,7 +1197,7 @@ static bool has_duplicates(const Swizzle& swizzle) {
void IRGenerator::markWrittenTo(const Expression& expr) {
switch (expr.fKind) {
case Expression::kVariableReference_Kind: {
- const Variable& var = *((VariableReference&) expr).fVariable;
+ const Variable& var = ((VariableReference&) expr).fVariable;
if (var.fModifiers.fFlags & (Modifiers::kConst_Flag | Modifiers::kUniform_Flag)) {
fErrors.error(expr.fPosition,
"cannot modify immutable variable '" + var.fName + "'");
diff --git a/src/sksl/SkSLIRGenerator.h b/src/sksl/SkSLIRGenerator.h
index d23e5a1bdb..2384b2dabc 100644
--- a/src/sksl/SkSLIRGenerator.h
+++ b/src/sksl/SkSLIRGenerator.h
@@ -53,7 +53,8 @@ namespace SkSL {
*/
class IRGenerator {
public:
- IRGenerator(std::shared_ptr<SymbolTable> root, ErrorReporter& errorReporter);
+ IRGenerator(const Context* context, std::shared_ptr<SymbolTable> root,
+ ErrorReporter& errorReporter);
std::unique_ptr<VarDeclaration> convertVarDeclaration(const ASTVarDeclaration& decl,
Variable::Storage storage);
@@ -65,21 +66,20 @@ private:
void pushSymbolTable();
void popSymbolTable();
- std::shared_ptr<Type> convertType(const ASTType& type);
+ const Type* convertType(const ASTType& type);
std::unique_ptr<Expression> call(Position position,
- std::shared_ptr<FunctionDeclaration> function,
+ const FunctionDeclaration& function,
std::vector<std::unique_ptr<Expression>> arguments);
- bool determineCallCost(std::shared_ptr<FunctionDeclaration> function,
+ bool determineCallCost(const FunctionDeclaration& function,
const std::vector<std::unique_ptr<Expression>>& arguments,
int* outCost);
std::unique_ptr<Expression> call(Position position, std::unique_ptr<Expression> function,
std::vector<std::unique_ptr<Expression>> arguments);
- std::unique_ptr<Expression> coerce(std::unique_ptr<Expression> expr,
- std::shared_ptr<Type> type);
+ std::unique_ptr<Expression> coerce(std::unique_ptr<Expression> expr, const Type& type);
std::unique_ptr<Block> convertBlock(const ASTBlock& block);
std::unique_ptr<Statement> convertBreak(const ASTBreakStatement& b);
std::unique_ptr<Expression> convertConstructor(Position position,
- std::shared_ptr<Type> type,
+ const Type& type,
std::vector<std::unique_ptr<Expression>> params);
std::unique_ptr<Statement> convertContinue(const ASTContinueStatement& c);
std::unique_ptr<Statement> convertDiscard(const ASTDiscardStatement& d);
@@ -106,10 +106,11 @@ private:
std::unique_ptr<Statement> convertWhile(const ASTWhileStatement& w);
void checkValid(const Expression& expr);
- void markReadFrom(std::shared_ptr<Variable> var);
+ void markReadFrom(const Variable& var);
void markWrittenTo(const Expression& expr);
- std::shared_ptr<FunctionDeclaration> fCurrentFunction;
+ const Context& fContext;
+ const FunctionDeclaration* fCurrentFunction;
std::shared_ptr<SymbolTable> fSymbolTable;
ErrorReporter& fErrors;
diff --git a/src/sksl/SkSLParser.cpp b/src/sksl/SkSLParser.cpp
index fa302af0d3..edff0c67d1 100644
--- a/src/sksl/SkSLParser.cpp
+++ b/src/sksl/SkSLParser.cpp
@@ -52,6 +52,7 @@
#include "ast/SkSLASTVarDeclarationStatement.h"
#include "ast/SkSLASTWhileStatement.h"
#include "ir/SkSLSymbolTable.h"
+#include "ir/SkSLType.h"
namespace SkSL {
@@ -290,17 +291,17 @@ std::unique_ptr<ASTType> Parser::structDeclaration() {
return nullptr;
}
for (size_t i = 0; i < decl->fNames.size(); i++) {
- auto type = std::static_pointer_cast<Type>(fTypes[decl->fType->fName]);
+ auto type = (const Type*) fTypes[decl->fType->fName];
for (int j = (int) decl->fSizes[i].size() - 1; j >= 0; j--) {
- if (decl->fSizes[i][j]->fKind == ASTExpression::kInt_Kind) {
+ if (decl->fSizes[i][j]->fKind != ASTExpression::kInt_Kind) {
this->error(decl->fPosition, "array size in struct field must be a constant");
}
uint64_t columns = ((ASTIntLiteral&) *decl->fSizes[i][j]).fValue;
std::string name = type->name() + "[" + to_string(columns) + "]";
- type = std::shared_ptr<Type>(new Type(name, Type::kArray_Kind, std::move(type),
- (int) columns));
+ type = new Type(name, Type::kArray_Kind, *type, (int) columns);
+ fTypes.takeOwnership((Type*) type);
}
- fields.push_back(Type::Field(decl->fModifiers, decl->fNames[i], std::move(type)));
+ fields.push_back(Type::Field(decl->fModifiers, decl->fNames[i], *type));
if (decl->fValues[i]) {
this->error(decl->fPosition, "initializers are not permitted on struct fields");
}
@@ -309,9 +310,8 @@ std::unique_ptr<ASTType> Parser::structDeclaration() {
if (!this->expect(Token::RBRACE, "'}'")) {
return nullptr;
}
- std::shared_ptr<Type> type(new Type(name.fText, fields));
- fTypes.add(type->fName, type);
- return std::unique_ptr<ASTType>(new ASTType(name.fPosition, type->fName,
+ fTypes.add(name.fText, std::unique_ptr<Type>(new Type(name.fText, fields)));
+ return std::unique_ptr<ASTType>(new ASTType(name.fPosition, name.fText,
ASTType::kStruct_Kind));
}
diff --git a/src/sksl/SkSLSPIRVCodeGenerator.cpp b/src/sksl/SkSLSPIRVCodeGenerator.cpp
index 0a2dab3adf..2771e0291b 100644
--- a/src/sksl/SkSLSPIRVCodeGenerator.cpp
+++ b/src/sksl/SkSLSPIRVCodeGenerator.cpp
@@ -141,36 +141,36 @@ void SPIRVCodeGenerator::writeWord(int32_t word, std::ostream& out) {
#endif
}
-static bool is_float(const Type& type) {
+static bool is_float(const Context& context, const Type& type) {
if (type.kind() == Type::kVector_Kind) {
- return is_float(*type.componentType());
+ return is_float(context, type.componentType());
}
- return type == *kFloat_Type || type == *kDouble_Type;
+ return type == *context.fFloat_Type || type == *context.fDouble_Type;
}
-static bool is_signed(const Type& type) {
+static bool is_signed(const Context& context, const Type& type) {
if (type.kind() == Type::kVector_Kind) {
- return is_signed(*type.componentType());
+ return is_signed(context, type.componentType());
}
- return type == *kInt_Type;
+ return type == *context.fInt_Type;
}
-static bool is_unsigned(const Type& type) {
+static bool is_unsigned(const Context& context, const Type& type) {
if (type.kind() == Type::kVector_Kind) {
- return is_unsigned(*type.componentType());
+ return is_unsigned(context, type.componentType());
}
- return type == *kUInt_Type;
+ return type == *context.fUInt_Type;
}
-static bool is_bool(const Type& type) {
+static bool is_bool(const Context& context, const Type& type) {
if (type.kind() == Type::kVector_Kind) {
- return is_bool(*type.componentType());
+ return is_bool(context, type.componentType());
}
- return type == *kBool_Type;
+ return type == *context.fBool_Type;
}
-static bool is_out(std::shared_ptr<Variable> var) {
- return (var->fModifiers.fFlags & Modifiers::kOut_Flag) != 0;
+static bool is_out(const Variable& var) {
+ return (var.fModifiers.fFlags & Modifiers::kOut_Flag) != 0;
}
#if SPIRV_DEBUG
@@ -973,7 +973,7 @@ void SPIRVCodeGenerator::writeStruct(const Type& type, SpvId resultId) {
// in the middle of writing the struct instruction
std::vector<SpvId> types;
for (const auto& f : type.fields()) {
- types.push_back(this->getType(*f.fType));
+ types.push_back(this->getType(f.fType));
}
this->writeOpCode(SpvOpTypeStruct, 2 + (int32_t) types.size(), fConstantBuffer);
this->writeWord(resultId, fConstantBuffer);
@@ -982,8 +982,8 @@ void SPIRVCodeGenerator::writeStruct(const Type& type, SpvId resultId) {
}
size_t offset = 0;
for (int32_t i = 0; i < (int32_t) type.fields().size(); i++) {
- size_t size = type.fields()[i].fType->size();
- size_t alignment = type.fields()[i].fType->alignment();
+ size_t size = type.fields()[i].fType.size();
+ size_t alignment = type.fields()[i].fType.alignment();
size_t mod = offset % alignment;
if (mod != 0) {
offset += alignment - mod;
@@ -995,14 +995,14 @@ void SPIRVCodeGenerator::writeStruct(const Type& type, SpvId resultId) {
this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i, SpvDecorationOffset,
(SpvId) offset, fDecorationBuffer);
}
- if (type.fields()[i].fType->kind() == Type::kMatrix_Kind) {
+ if (type.fields()[i].fType.kind() == Type::kMatrix_Kind) {
this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationColMajor,
fDecorationBuffer);
this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationMatrixStride,
- (SpvId) type.fields()[i].fType->stride(), fDecorationBuffer);
+ (SpvId) type.fields()[i].fType.stride(), fDecorationBuffer);
}
offset += size;
- Type::Kind kind = type.fields()[i].fType->kind();
+ Type::Kind kind = type.fields()[i].fType.kind();
if ((kind == Type::kArray_Kind || kind == Type::kStruct_Kind) && offset % alignment != 0) {
offset += alignment - offset % alignment;
}
@@ -1016,15 +1016,15 @@ SpvId SPIRVCodeGenerator::getType(const Type& type) {
SpvId result = this->nextId();
switch (type.kind()) {
case Type::kScalar_Kind:
- if (type == *kBool_Type) {
+ if (type == *fContext.fBool_Type) {
this->writeInstruction(SpvOpTypeBool, result, fConstantBuffer);
- } else if (type == *kInt_Type) {
+ } else if (type == *fContext.fInt_Type) {
this->writeInstruction(SpvOpTypeInt, result, 32, 1, fConstantBuffer);
- } else if (type == *kUInt_Type) {
+ } else if (type == *fContext.fUInt_Type) {
this->writeInstruction(SpvOpTypeInt, result, 32, 0, fConstantBuffer);
- } else if (type == *kFloat_Type) {
+ } else if (type == *fContext.fFloat_Type) {
this->writeInstruction(SpvOpTypeFloat, result, 32, fConstantBuffer);
- } else if (type == *kDouble_Type) {
+ } else if (type == *fContext.fDouble_Type) {
this->writeInstruction(SpvOpTypeFloat, result, 64, fConstantBuffer);
} else {
ASSERT(false);
@@ -1032,11 +1032,12 @@ SpvId SPIRVCodeGenerator::getType(const Type& type) {
break;
case Type::kVector_Kind:
this->writeInstruction(SpvOpTypeVector, result,
- this->getType(*type.componentType()),
+ this->getType(type.componentType()),
type.columns(), fConstantBuffer);
break;
case Type::kMatrix_Kind:
- this->writeInstruction(SpvOpTypeMatrix, result, this->getType(*index_type(type)),
+ this->writeInstruction(SpvOpTypeMatrix, result,
+ this->getType(index_type(fContext, type)),
type.columns(), fConstantBuffer);
break;
case Type::kStruct_Kind:
@@ -1044,22 +1045,22 @@ SpvId SPIRVCodeGenerator::getType(const Type& type) {
break;
case Type::kArray_Kind: {
if (type.columns() > 0) {
- IntLiteral count(Position(), type.columns());
+ IntLiteral count(fContext, Position(), type.columns());
this->writeInstruction(SpvOpTypeArray, result,
- this->getType(*type.componentType()),
+ this->getType(type.componentType()),
this->writeIntLiteral(count), fConstantBuffer);
this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride,
(int32_t) type.stride(), fDecorationBuffer);
} else {
ABORT("runtime-sized arrays are not yet supported");
this->writeInstruction(SpvOpTypeRuntimeArray, result,
- this->getType(*type.componentType()), fConstantBuffer);
+ this->getType(type.componentType()), fConstantBuffer);
}
break;
}
case Type::kSampler_Kind: {
SpvId image = this->nextId();
- this->writeInstruction(SpvOpTypeImage, image, this->getType(*kFloat_Type),
+ this->writeInstruction(SpvOpTypeImage, image, this->getType(*fContext.fFloat_Type),
type.dimensions(), type.isDepth(), type.isArrayed(),
type.isMultisampled(), type.isSampled(),
SpvImageFormatUnknown, fConstantBuffer);
@@ -1067,7 +1068,7 @@ SpvId SPIRVCodeGenerator::getType(const Type& type) {
break;
}
default:
- if (type == *kVoid_Type) {
+ if (type == *fContext.fVoid_Type) {
this->writeInstruction(SpvOpTypeVoid, result, fConstantBuffer);
} else {
ABORT("invalid type: %s", type.description().c_str());
@@ -1079,22 +1080,22 @@ SpvId SPIRVCodeGenerator::getType(const Type& type) {
return entry->second;
}
-SpvId SPIRVCodeGenerator::getFunctionType(std::shared_ptr<FunctionDeclaration> function) {
- std::string key = function->fReturnType->description() + "(";
+SpvId SPIRVCodeGenerator::getFunctionType(const FunctionDeclaration& function) {
+ std::string key = function.fReturnType.description() + "(";
std::string separator = "";
- for (size_t i = 0; i < function->fParameters.size(); i++) {
+ for (size_t i = 0; i < function.fParameters.size(); i++) {
key += separator;
separator = ", ";
- key += function->fParameters[i]->fType->description();
+ key += function.fParameters[i]->fType.description();
}
key += ")";
auto entry = fTypeMap.find(key);
if (entry == fTypeMap.end()) {
SpvId result = this->nextId();
- int32_t length = 3 + (int32_t) function->fParameters.size();
- SpvId returnType = this->getType(*function->fReturnType);
+ int32_t length = 3 + (int32_t) function.fParameters.size();
+ SpvId returnType = this->getType(function.fReturnType);
std::vector<SpvId> parameterTypes;
- for (size_t i = 0; i < function->fParameters.size(); i++) {
+ for (size_t i = 0; i < function.fParameters.size(); i++) {
// glslang seems to treat all function arguments as pointers whether they need to be or
// not. I was initially puzzled by this until I ran bizarre failures with certain
// patterns of function calls and control constructs, as exemplified by this minimal
@@ -1118,10 +1119,10 @@ SpvId SPIRVCodeGenerator::getFunctionType(std::shared_ptr<FunctionDeclaration> f
// as glslang does, fixes it. It's entirely possible I simply missed whichever part of
// the spec makes this make sense.
// if (is_out(function->fParameters[i])) {
- parameterTypes.push_back(this->getPointerType(function->fParameters[i]->fType,
+ parameterTypes.push_back(this->getPointerType(function.fParameters[i]->fType,
SpvStorageClassFunction));
// } else {
-// parameterTypes.push_back(this->getType(*function->fParameters[i]->fType));
+// parameterTypes.push_back(this->getType(function.fParameters[i]->fType));
// }
}
this->writeOpCode(SpvOpTypeFunction, length, fConstantBuffer);
@@ -1136,14 +1137,14 @@ SpvId SPIRVCodeGenerator::getFunctionType(std::shared_ptr<FunctionDeclaration> f
return entry->second;
}
-SpvId SPIRVCodeGenerator::getPointerType(std::shared_ptr<Type> type,
+SpvId SPIRVCodeGenerator::getPointerType(const Type& type,
SpvStorageClass_ storageClass) {
- std::string key = type->description() + "*" + to_string(storageClass);
+ std::string key = type.description() + "*" + to_string(storageClass);
auto entry = fTypeMap.find(key);
if (entry == fTypeMap.end()) {
SpvId result = this->nextId();
this->writeInstruction(SpvOpTypePointer, result, storageClass,
- this->getType(*type), fConstantBuffer);
+ this->getType(type), fConstantBuffer);
fTypeMap[key] = result;
return result;
}
@@ -1185,21 +1186,21 @@ SpvId SPIRVCodeGenerator::writeExpression(Expression& expr, std::ostream& out) {
}
SpvId SPIRVCodeGenerator::writeIntrinsicCall(FunctionCall& c, std::ostream& out) {
- auto intrinsic = fIntrinsicMap.find(c.fFunction->fName);
+ auto intrinsic = fIntrinsicMap.find(c.fFunction.fName);
ASSERT(intrinsic != fIntrinsicMap.end());
- std::shared_ptr<Type> type = c.fArguments[0]->fType;
+ const Type& type = c.fArguments[0]->fType;
int32_t intrinsicId;
- if (std::get<0>(intrinsic->second) == kSpecial_IntrinsicKind || is_float(*type)) {
+ if (std::get<0>(intrinsic->second) == kSpecial_IntrinsicKind || is_float(fContext, type)) {
intrinsicId = std::get<1>(intrinsic->second);
- } else if (is_signed(*type)) {
+ } else if (is_signed(fContext, type)) {
intrinsicId = std::get<2>(intrinsic->second);
- } else if (is_unsigned(*type)) {
+ } else if (is_unsigned(fContext, type)) {
intrinsicId = std::get<3>(intrinsic->second);
- } else if (is_bool(*type)) {
+ } else if (is_bool(fContext, type)) {
intrinsicId = std::get<4>(intrinsic->second);
} else {
ABORT("invalid call %s, cannot operate on '%s'", c.description().c_str(),
- type->description().c_str());
+ type.description().c_str());
}
switch (std::get<0>(intrinsic->second)) {
case kGLSL_STD_450_IntrinsicKind: {
@@ -1209,7 +1210,7 @@ SpvId SPIRVCodeGenerator::writeIntrinsicCall(FunctionCall& c, std::ostream& out)
arguments.push_back(this->writeExpression(*c.fArguments[i], out));
}
this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out);
- this->writeWord(this->getType(*c.fType), out);
+ this->writeWord(this->getType(c.fType), out);
this->writeWord(result, out);
this->writeWord(fGLSLExtendedInstructions, out);
this->writeWord(intrinsicId, out);
@@ -1225,7 +1226,7 @@ SpvId SPIRVCodeGenerator::writeIntrinsicCall(FunctionCall& c, std::ostream& out)
arguments.push_back(this->writeExpression(*c.fArguments[i], out));
}
this->writeOpCode((SpvOp_) intrinsicId, 3 + (int32_t) arguments.size(), out);
- this->writeWord(this->getType(*c.fType), out);
+ this->writeWord(this->getType(c.fType), out);
this->writeWord(result, out);
for (SpvId id : arguments) {
this->writeWord(id, out);
@@ -1249,7 +1250,7 @@ SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(FunctionCall& c, SpecialIntrinsi
arguments.push_back(this->writeExpression(*c.fArguments[i], out));
}
this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out);
- this->writeWord(this->getType(*c.fType), out);
+ this->writeWord(this->getType(c.fType), out);
this->writeWord(result, out);
this->writeWord(fGLSLExtendedInstructions, out);
this->writeWord(arguments.size() == 2 ? GLSLstd450Atan2 : GLSLstd450Atan, out);
@@ -1259,7 +1260,7 @@ SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(FunctionCall& c, SpecialIntrinsi
return result;
}
case kTexture_SpecialIntrinsic: {
- SpvId type = this->getType(*c.fType);
+ SpvId type = this->getType(c.fType);
SpvId sampler = this->writeExpression(*c.fArguments[0], out);
SpvId uv = this->writeExpression(*c.fArguments[1], out);
if (c.fArguments.size() == 3) {
@@ -1274,7 +1275,7 @@ SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(FunctionCall& c, SpecialIntrinsi
break;
}
case kTextureProj_SpecialIntrinsic: {
- SpvId type = this->getType(*c.fType);
+ SpvId type = this->getType(c.fType);
SpvId sampler = this->writeExpression(*c.fArguments[0], out);
SpvId uv = this->writeExpression(*c.fArguments[1], out);
if (c.fArguments.size() == 3) {
@@ -1293,7 +1294,7 @@ SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(FunctionCall& c, SpecialIntrinsi
SpvId img = this->writeExpression(*c.fArguments[0], out);
SpvId coords = this->writeExpression(*c.fArguments[1], out);
this->writeInstruction(SpvOpImageSampleImplicitLod,
- this->getType(*c.fType),
+ this->getType(c.fType),
result,
img,
coords,
@@ -1305,7 +1306,7 @@ SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(FunctionCall& c, SpecialIntrinsi
}
SpvId SPIRVCodeGenerator::writeFunctionCall(FunctionCall& c, std::ostream& out) {
- const auto& entry = fFunctionMap.find(c.fFunction);
+ const auto& entry = fFunctionMap.find(&c.fFunction);
if (entry == fFunctionMap.end()) {
return this->writeIntrinsicCall(c, out);
}
@@ -1318,7 +1319,7 @@ SpvId SPIRVCodeGenerator::writeFunctionCall(FunctionCall& c, std::ostream& out)
SpvId tmpVar;
// if we need a temporary var to store this argument, this is the value to store in the var
SpvId tmpValueId;
- if (is_out(c.fFunction->fParameters[i])) {
+ if (is_out(*c.fFunction.fParameters[i])) {
std::unique_ptr<LValue> lv = this->getLValue(*c.fArguments[i], out);
SpvId ptr = lv->getPointer();
if (ptr) {
@@ -1330,7 +1331,7 @@ SpvId SPIRVCodeGenerator::writeFunctionCall(FunctionCall& c, std::ostream& out)
// update the lvalue.
tmpValueId = lv->load(out);
tmpVar = this->nextId();
- lvalues.push_back(std::make_tuple(tmpVar, this->getType(*c.fArguments[i]->fType),
+ lvalues.push_back(std::make_tuple(tmpVar, this->getType(c.fArguments[i]->fType),
std::move(lv)));
}
} else {
@@ -1343,13 +1344,13 @@ SpvId SPIRVCodeGenerator::writeFunctionCall(FunctionCall& c, std::ostream& out)
SpvStorageClassFunction),
tmpVar,
SpvStorageClassFunction,
- out);
+ fVariableBuffer);
this->writeInstruction(SpvOpStore, tmpVar, tmpValueId, out);
arguments.push_back(tmpVar);
}
SpvId result = this->nextId();
this->writeOpCode(SpvOpFunctionCall, 4 + (int32_t) c.fArguments.size(), out);
- this->writeWord(this->getType(*c.fType), out);
+ this->writeWord(this->getType(c.fType), out);
this->writeWord(result, out);
this->writeWord(entry->second, out);
for (SpvId id : arguments) {
@@ -1366,19 +1367,19 @@ SpvId SPIRVCodeGenerator::writeFunctionCall(FunctionCall& c, std::ostream& out)
}
SpvId SPIRVCodeGenerator::writeConstantVector(Constructor& c) {
- ASSERT(c.fType->kind() == Type::kVector_Kind && c.isConstant());
+ ASSERT(c.fType.kind() == Type::kVector_Kind && c.isConstant());
SpvId result = this->nextId();
std::vector<SpvId> arguments;
for (size_t i = 0; i < c.fArguments.size(); i++) {
arguments.push_back(this->writeExpression(*c.fArguments[i], fConstantBuffer));
}
- SpvId type = this->getType(*c.fType);
+ SpvId type = this->getType(c.fType);
if (c.fArguments.size() == 1) {
// with a single argument, a vector will have all of its entries equal to the argument
- this->writeOpCode(SpvOpConstantComposite, 3 + c.fType->columns(), fConstantBuffer);
+ this->writeOpCode(SpvOpConstantComposite, 3 + c.fType.columns(), fConstantBuffer);
this->writeWord(type, fConstantBuffer);
this->writeWord(result, fConstantBuffer);
- for (int i = 0; i < c.fType->columns(); i++) {
+ for (int i = 0; i < c.fType.columns(); i++) {
this->writeWord(arguments[0], fConstantBuffer);
}
} else {
@@ -1394,43 +1395,43 @@ SpvId SPIRVCodeGenerator::writeConstantVector(Constructor& c) {
}
SpvId SPIRVCodeGenerator::writeFloatConstructor(Constructor& c, std::ostream& out) {
- ASSERT(c.fType == kFloat_Type);
+ ASSERT(c.fType == *fContext.fFloat_Type);
ASSERT(c.fArguments.size() == 1);
- ASSERT(c.fArguments[0]->fType->isNumber());
+ ASSERT(c.fArguments[0]->fType.isNumber());
SpvId result = this->nextId();
SpvId parameter = this->writeExpression(*c.fArguments[0], out);
- if (c.fArguments[0]->fType == kInt_Type) {
- this->writeInstruction(SpvOpConvertSToF, this->getType(*c.fType), result, parameter,
+ if (c.fArguments[0]->fType == *fContext.fInt_Type) {
+ this->writeInstruction(SpvOpConvertSToF, this->getType(c.fType), result, parameter,
out);
- } else if (c.fArguments[0]->fType == kUInt_Type) {
- this->writeInstruction(SpvOpConvertUToF, this->getType(*c.fType), result, parameter,
+ } else if (c.fArguments[0]->fType == *fContext.fUInt_Type) {
+ this->writeInstruction(SpvOpConvertUToF, this->getType(c.fType), result, parameter,
out);
- } else if (c.fArguments[0]->fType == kFloat_Type) {
+ } else if (c.fArguments[0]->fType == *fContext.fFloat_Type) {
return parameter;
}
return result;
}
SpvId SPIRVCodeGenerator::writeIntConstructor(Constructor& c, std::ostream& out) {
- ASSERT(c.fType == kInt_Type);
+ ASSERT(c.fType == *fContext.fInt_Type);
ASSERT(c.fArguments.size() == 1);
- ASSERT(c.fArguments[0]->fType->isNumber());
+ ASSERT(c.fArguments[0]->fType.isNumber());
SpvId result = this->nextId();
SpvId parameter = this->writeExpression(*c.fArguments[0], out);
- if (c.fArguments[0]->fType == kFloat_Type) {
- this->writeInstruction(SpvOpConvertFToS, this->getType(*c.fType), result, parameter,
+ if (c.fArguments[0]->fType == *fContext.fFloat_Type) {
+ this->writeInstruction(SpvOpConvertFToS, this->getType(c.fType), result, parameter,
out);
- } else if (c.fArguments[0]->fType == kUInt_Type) {
- this->writeInstruction(SpvOpSatConvertUToS, this->getType(*c.fType), result, parameter,
+ } else if (c.fArguments[0]->fType == *fContext.fUInt_Type) {
+ this->writeInstruction(SpvOpSatConvertUToS, this->getType(c.fType), result, parameter,
out);
- } else if (c.fArguments[0]->fType == kInt_Type) {
+ } else if (c.fArguments[0]->fType == *fContext.fInt_Type) {
return parameter;
}
return result;
}
SpvId SPIRVCodeGenerator::writeMatrixConstructor(Constructor& c, std::ostream& out) {
- ASSERT(c.fType->kind() == Type::kMatrix_Kind);
+ ASSERT(c.fType.kind() == Type::kMatrix_Kind);
// go ahead and write the arguments so we don't try to write new instructions in the middle of
// an instruction
std::vector<SpvId> arguments;
@@ -1438,30 +1439,31 @@ SpvId SPIRVCodeGenerator::writeMatrixConstructor(Constructor& c, std::ostream& o
arguments.push_back(this->writeExpression(*c.fArguments[i], out));
}
SpvId result = this->nextId();
- int rows = c.fType->rows();
- int columns = c.fType->columns();
+ int rows = c.fType.rows();
+ int columns = c.fType.columns();
// FIXME this won't work to create a matrix from another matrix
if (arguments.size() == 1) {
// with a single argument, a matrix will have all of its diagonal entries equal to the
// argument and its other values equal to zero
// FIXME this won't work for int matrices
- FloatLiteral zero(Position(), 0);
+ FloatLiteral zero(fContext, Position(), 0);
SpvId zeroId = this->writeFloatLiteral(zero);
std::vector<SpvId> columnIds;
for (int column = 0; column < columns; column++) {
- this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType->rows(),
+ this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType.rows(),
out);
- this->writeWord(this->getType(*c.fType->componentType()->toCompound(rows, 1)), out);
+ this->writeWord(this->getType(c.fType.componentType().toCompound(fContext, rows, 1)),
+ out);
SpvId columnId = this->nextId();
this->writeWord(columnId, out);
columnIds.push_back(columnId);
- for (int row = 0; row < c.fType->columns(); row++) {
+ for (int row = 0; row < c.fType.columns(); row++) {
this->writeWord(row == column ? arguments[0] : zeroId, out);
}
}
this->writeOpCode(SpvOpCompositeConstruct, 3 + columns,
out);
- this->writeWord(this->getType(*c.fType), out);
+ this->writeWord(this->getType(c.fType), out);
this->writeWord(result, out);
for (SpvId id : columnIds) {
this->writeWord(id, out);
@@ -1470,15 +1472,16 @@ SpvId SPIRVCodeGenerator::writeMatrixConstructor(Constructor& c, std::ostream& o
std::vector<SpvId> columnIds;
int currentCount = 0;
for (size_t i = 0; i < arguments.size(); i++) {
- if (c.fArguments[i]->fType->kind() == Type::kVector_Kind) {
+ if (c.fArguments[i]->fType.kind() == Type::kVector_Kind) {
ASSERT(currentCount == 0);
columnIds.push_back(arguments[i]);
currentCount = 0;
} else {
- ASSERT(c.fArguments[i]->fType->kind() == Type::kScalar_Kind);
+ ASSERT(c.fArguments[i]->fType.kind() == Type::kScalar_Kind);
if (currentCount == 0) {
- this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType->rows(), out);
- this->writeWord(this->getType(*c.fType->componentType()->toCompound(rows, 1)),
+ this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType.rows(), out);
+ this->writeWord(this->getType(c.fType.componentType().toCompound(fContext, rows,
+ 1)),
out);
SpvId id = this->nextId();
this->writeWord(id, out);
@@ -1490,7 +1493,7 @@ SpvId SPIRVCodeGenerator::writeMatrixConstructor(Constructor& c, std::ostream& o
}
ASSERT(columnIds.size() == (size_t) columns);
this->writeOpCode(SpvOpCompositeConstruct, 3 + columns, out);
- this->writeWord(this->getType(*c.fType), out);
+ this->writeWord(this->getType(c.fType), out);
this->writeWord(result, out);
for (SpvId id : columnIds) {
this->writeWord(id, out);
@@ -1500,7 +1503,7 @@ SpvId SPIRVCodeGenerator::writeMatrixConstructor(Constructor& c, std::ostream& o
}
SpvId SPIRVCodeGenerator::writeVectorConstructor(Constructor& c, std::ostream& out) {
- ASSERT(c.fType->kind() == Type::kVector_Kind);
+ ASSERT(c.fType.kind() == Type::kVector_Kind);
if (c.isConstant()) {
return this->writeConstantVector(c);
}
@@ -1511,16 +1514,16 @@ SpvId SPIRVCodeGenerator::writeVectorConstructor(Constructor& c, std::ostream& o
arguments.push_back(this->writeExpression(*c.fArguments[i], out));
}
SpvId result = this->nextId();
- if (arguments.size() == 1 && c.fArguments[0]->fType->kind() == Type::kScalar_Kind) {
- this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType->columns(), out);
- this->writeWord(this->getType(*c.fType), out);
+ if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
+ this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType.columns(), out);
+ this->writeWord(this->getType(c.fType), out);
this->writeWord(result, out);
- for (int i = 0; i < c.fType->columns(); i++) {
+ for (int i = 0; i < c.fType.columns(); i++) {
this->writeWord(arguments[0], out);
}
} else {
this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) c.fArguments.size(), out);
- this->writeWord(this->getType(*c.fType), out);
+ this->writeWord(this->getType(c.fType), out);
this->writeWord(result, out);
for (SpvId id : arguments) {
this->writeWord(id, out);
@@ -1530,12 +1533,12 @@ SpvId SPIRVCodeGenerator::writeVectorConstructor(Constructor& c, std::ostream& o
}
SpvId SPIRVCodeGenerator::writeConstructor(Constructor& c, std::ostream& out) {
- if (c.fType == kFloat_Type) {
+ if (c.fType == *fContext.fFloat_Type) {
return this->writeFloatConstructor(c, out);
- } else if (c.fType == kInt_Type) {
+ } else if (c.fType == *fContext.fInt_Type) {
return this->writeIntConstructor(c, out);
}
- switch (c.fType->kind()) {
+ switch (c.fType.kind()) {
case Type::kVector_Kind:
return this->writeVectorConstructor(c, out);
case Type::kMatrix_Kind:
@@ -1560,7 +1563,7 @@ SpvStorageClass_ get_storage_class(const Modifiers& modifiers) {
SpvStorageClass_ get_storage_class(Expression& expr) {
switch (expr.fKind) {
case Expression::kVariableReference_Kind:
- return get_storage_class(((VariableReference&) expr).fVariable->fModifiers);
+ return get_storage_class(((VariableReference&) expr).fVariable.fModifiers);
case Expression::kFieldAccess_Kind:
return get_storage_class(*((FieldAccess&) expr).fBase);
case Expression::kIndex_Kind:
@@ -1582,7 +1585,7 @@ std::vector<SpvId> SPIRVCodeGenerator::getAccessChain(Expression& expr, std::ost
case Expression::kFieldAccess_Kind: {
FieldAccess& fieldExpr = (FieldAccess&) expr;
chain = this->getAccessChain(*fieldExpr.fBase, out);
- IntLiteral index(Position(), fieldExpr.fFieldIndex);
+ IntLiteral index(fContext, Position(), fieldExpr.fFieldIndex);
chain.push_back(this->writeIntLiteral(index));
break;
}
@@ -1698,13 +1701,13 @@ std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(Expres
std::ostream& out) {
switch (expr.fKind) {
case Expression::kVariableReference_Kind: {
- std::shared_ptr<Variable> var = ((VariableReference&) expr).fVariable;
- auto entry = fVariableMap.find(var);
+ const Variable& var = ((VariableReference&) expr).fVariable;
+ auto entry = fVariableMap.find(&var);
ASSERT(entry != fVariableMap.end());
return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
*this,
entry->second,
- this->getType(*expr.fType)));
+ this->getType(expr.fType)));
}
case Expression::kIndex_Kind: // fall through
case Expression::kFieldAccess_Kind: {
@@ -1719,7 +1722,7 @@ std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(Expres
return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
*this,
member,
- this->getType(*expr.fType)));
+ this->getType(expr.fType)));
}
case Expression::kSwizzle_Kind: {
@@ -1728,7 +1731,7 @@ std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(Expres
SpvId base = this->getLValue(*swizzle.fBase, out)->getPointer();
ASSERT(base);
if (count == 1) {
- IntLiteral index(Position(), swizzle.fComponents[0]);
+ IntLiteral index(fContext, Position(), swizzle.fComponents[0]);
SpvId member = this->nextId();
this->writeInstruction(SpvOpAccessChain,
this->getPointerType(swizzle.fType,
@@ -1740,14 +1743,14 @@ std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(Expres
return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
*this,
member,
- this->getType(*expr.fType)));
+ this->getType(expr.fType)));
} else {
return std::unique_ptr<SPIRVCodeGenerator::LValue>(new SwizzleLValue(
*this,
base,
swizzle.fComponents,
- *swizzle.fBase->fType,
- *expr.fType));
+ swizzle.fBase->fType,
+ expr.fType));
}
}
@@ -1758,21 +1761,22 @@ std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(Expres
// caught by IRGenerator
SpvId result = this->nextId();
SpvId type = this->getPointerType(expr.fType, SpvStorageClassFunction);
- this->writeInstruction(SpvOpVariable, type, result, SpvStorageClassFunction, out);
+ this->writeInstruction(SpvOpVariable, type, result, SpvStorageClassFunction,
+ fVariableBuffer);
this->writeInstruction(SpvOpStore, result, this->writeExpression(expr, out), out);
return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
*this,
result,
- this->getType(*expr.fType)));
+ this->getType(expr.fType)));
}
}
SpvId SPIRVCodeGenerator::writeVariableReference(VariableReference& ref, std::ostream& out) {
- auto entry = fVariableMap.find(ref.fVariable);
+ auto entry = fVariableMap.find(&ref.fVariable);
ASSERT(entry != fVariableMap.end());
SpvId var = entry->second;
SpvId result = this->nextId();
- this->writeInstruction(SpvOpLoad, this->getType(*ref.fVariable->fType), result, var, out);
+ this->writeInstruction(SpvOpLoad, this->getType(ref.fVariable.fType), result, var, out);
return result;
}
@@ -1789,11 +1793,11 @@ SpvId SPIRVCodeGenerator::writeSwizzle(Swizzle& swizzle, std::ostream& out) {
SpvId result = this->nextId();
size_t count = swizzle.fComponents.size();
if (count == 1) {
- this->writeInstruction(SpvOpCompositeExtract, this->getType(*swizzle.fType), result, base,
+ this->writeInstruction(SpvOpCompositeExtract, this->getType(swizzle.fType), result, base,
swizzle.fComponents[0], out);
} else {
this->writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) count, out);
- this->writeWord(this->getType(*swizzle.fType), out);
+ this->writeWord(this->getType(swizzle.fType), out);
this->writeWord(result, out);
this->writeWord(base, out);
this->writeWord(base, out);
@@ -1809,13 +1813,13 @@ SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType,
SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt,
SpvOp_ ifUInt, SpvOp_ ifBool, std::ostream& out) {
SpvId result = this->nextId();
- if (is_float(operandType)) {
+ if (is_float(fContext, operandType)) {
this->writeInstruction(ifFloat, this->getType(resultType), result, lhs, rhs, out);
- } else if (is_signed(operandType)) {
+ } else if (is_signed(fContext, operandType)) {
this->writeInstruction(ifInt, this->getType(resultType), result, lhs, rhs, out);
- } else if (is_unsigned(operandType)) {
+ } else if (is_unsigned(fContext, operandType)) {
this->writeInstruction(ifUInt, this->getType(resultType), result, lhs, rhs, out);
- } else if (operandType == *kBool_Type) {
+ } else if (operandType == *fContext.fBool_Type) {
this->writeInstruction(ifBool, this->getType(resultType), result, lhs, rhs, out);
} else {
ABORT("invalid operandType: %s", operandType.description().c_str());
@@ -1862,7 +1866,7 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(BinaryExpression& b, std::ostrea
}
// "normal" operators
- const Type& resultType = *b.fType;
+ const Type& resultType = b.fType;
std::unique_ptr<LValue> lvalue;
SpvId lhs;
if (is_assignment(b.fOperator)) {
@@ -1878,23 +1882,23 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(BinaryExpression& b, std::ostrea
// IR allows mismatched types in expressions (e.g. vec2 * float), but they need special handling
// in SPIR-V
if (b.fLeft->fType != b.fRight->fType) {
- if (b.fLeft->fType->kind() == Type::kVector_Kind &&
- b.fRight->fType->isNumber()) {
+ if (b.fLeft->fType.kind() == Type::kVector_Kind &&
+ b.fRight->fType.isNumber()) {
// promote number to vector
SpvId vec = this->nextId();
- this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType->columns(), out);
+ this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType.columns(), out);
this->writeWord(this->getType(resultType), out);
this->writeWord(vec, out);
for (int i = 0; i < resultType.columns(); i++) {
this->writeWord(rhs, out);
}
rhs = vec;
- operandType = b.fRight->fType.get();
- } else if (b.fRight->fType->kind() == Type::kVector_Kind &&
- b.fLeft->fType->isNumber()) {
+ operandType = &b.fRight->fType;
+ } else if (b.fRight->fType.kind() == Type::kVector_Kind &&
+ b.fLeft->fType.isNumber()) {
// promote number to vector
SpvId vec = this->nextId();
- this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType->columns(), out);
+ this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType.columns(), out);
this->writeWord(this->getType(resultType), out);
this->writeWord(vec, out);
for (int i = 0; i < resultType.columns(); i++) {
@@ -1902,33 +1906,33 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(BinaryExpression& b, std::ostrea
}
lhs = vec;
ASSERT(!lvalue);
- operandType = b.fLeft->fType.get();
- } else if (b.fLeft->fType->kind() == Type::kMatrix_Kind) {
+ operandType = &b.fLeft->fType;
+ } else if (b.fLeft->fType.kind() == Type::kMatrix_Kind) {
SpvOp_ op;
- if (b.fRight->fType->kind() == Type::kMatrix_Kind) {
+ if (b.fRight->fType.kind() == Type::kMatrix_Kind) {
op = SpvOpMatrixTimesMatrix;
- } else if (b.fRight->fType->kind() == Type::kVector_Kind) {
+ } else if (b.fRight->fType.kind() == Type::kVector_Kind) {
op = SpvOpMatrixTimesVector;
} else {
- ASSERT(b.fRight->fType->kind() == Type::kScalar_Kind);
+ ASSERT(b.fRight->fType.kind() == Type::kScalar_Kind);
op = SpvOpMatrixTimesScalar;
}
SpvId result = this->nextId();
- this->writeInstruction(op, this->getType(*b.fType), result, lhs, rhs, out);
+ this->writeInstruction(op, this->getType(b.fType), result, lhs, rhs, out);
if (b.fOperator == Token::STAREQ) {
lvalue->store(result, out);
} else {
ASSERT(b.fOperator == Token::STAR);
}
return result;
- } else if (b.fRight->fType->kind() == Type::kMatrix_Kind) {
+ } else if (b.fRight->fType.kind() == Type::kMatrix_Kind) {
SpvId result = this->nextId();
- if (b.fLeft->fType->kind() == Type::kVector_Kind) {
- this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(*b.fType), result,
+ if (b.fLeft->fType.kind() == Type::kVector_Kind) {
+ this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(b.fType), result,
lhs, rhs, out);
} else {
- ASSERT(b.fLeft->fType->kind() == Type::kScalar_Kind);
- this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(*b.fType), result, rhs,
+ ASSERT(b.fLeft->fType.kind() == Type::kScalar_Kind);
+ this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(b.fType), result, rhs,
lhs, out);
}
if (b.fOperator == Token::STAREQ) {
@@ -1941,35 +1945,35 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(BinaryExpression& b, std::ostrea
ABORT("unsupported binary expression: %s", b.description().c_str());
}
} else {
- operandType = b.fLeft->fType.get();
- ASSERT(*operandType == *b.fRight->fType);
+ operandType = &b.fLeft->fType;
+ ASSERT(*operandType == b.fRight->fType);
}
switch (b.fOperator) {
case Token::EQEQ:
- ASSERT(resultType == *kBool_Type);
+ ASSERT(resultType == *fContext.fBool_Type);
return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdEqual,
SpvOpIEqual, SpvOpIEqual, SpvOpLogicalEqual, out);
case Token::NEQ:
- ASSERT(resultType == *kBool_Type);
+ ASSERT(resultType == *fContext.fBool_Type);
return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdNotEqual,
SpvOpINotEqual, SpvOpINotEqual, SpvOpLogicalNotEqual,
out);
case Token::GT:
- ASSERT(resultType == *kBool_Type);
+ ASSERT(resultType == *fContext.fBool_Type);
return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
SpvOpFOrdGreaterThan, SpvOpSGreaterThan,
SpvOpUGreaterThan, SpvOpUndef, out);
case Token::LT:
- ASSERT(resultType == *kBool_Type);
+ ASSERT(resultType == *fContext.fBool_Type);
return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdLessThan,
SpvOpSLessThan, SpvOpULessThan, SpvOpUndef, out);
case Token::GTEQ:
- ASSERT(resultType == *kBool_Type);
+ ASSERT(resultType == *fContext.fBool_Type);
return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
SpvOpFOrdGreaterThanEqual, SpvOpSGreaterThanEqual,
SpvOpUGreaterThanEqual, SpvOpUndef, out);
case Token::LTEQ:
- ASSERT(resultType == *kBool_Type);
+ ASSERT(resultType == *fContext.fBool_Type);
return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual,
SpvOpULessThanEqual, SpvOpUndef, out);
@@ -1980,8 +1984,8 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(BinaryExpression& b, std::ostrea
return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
SpvOpISub, SpvOpISub, SpvOpUndef, out);
case Token::STAR:
- if (b.fLeft->fType->kind() == Type::kMatrix_Kind &&
- b.fRight->fType->kind() == Type::kMatrix_Kind) {
+ if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
+ b.fRight->fType.kind() == Type::kMatrix_Kind) {
// matrix multiply
SpvId result = this->nextId();
this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
@@ -2008,8 +2012,8 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(BinaryExpression& b, std::ostrea
return result;
}
case Token::STAREQ: {
- if (b.fLeft->fType->kind() == Type::kMatrix_Kind &&
- b.fRight->fType->kind() == Type::kMatrix_Kind) {
+ if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
+ b.fRight->fType.kind() == Type::kMatrix_Kind) {
// matrix multiply
SpvId result = this->nextId();
this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
@@ -2039,7 +2043,7 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(BinaryExpression& b, std::ostrea
SpvId SPIRVCodeGenerator::writeLogicalAnd(BinaryExpression& a, std::ostream& out) {
ASSERT(a.fOperator == Token::LOGICALAND);
- BoolLiteral falseLiteral(Position(), false);
+ BoolLiteral falseLiteral(fContext, Position(), false);
SpvId falseConstant = this->writeBoolLiteral(falseLiteral);
SpvId lhs = this->writeExpression(*a.fLeft, out);
SpvId rhsLabel = this->nextId();
@@ -2053,14 +2057,14 @@ SpvId SPIRVCodeGenerator::writeLogicalAnd(BinaryExpression& a, std::ostream& out
this->writeInstruction(SpvOpBranch, end, out);
this->writeLabel(end, out);
SpvId result = this->nextId();
- this->writeInstruction(SpvOpPhi, this->getType(*kBool_Type), result, falseConstant, lhsBlock,
- rhs, rhsBlock, out);
+ this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, falseConstant,
+ lhsBlock, rhs, rhsBlock, out);
return result;
}
SpvId SPIRVCodeGenerator::writeLogicalOr(BinaryExpression& o, std::ostream& out) {
ASSERT(o.fOperator == Token::LOGICALOR);
- BoolLiteral trueLiteral(Position(), true);
+ BoolLiteral trueLiteral(fContext, Position(), true);
SpvId trueConstant = this->writeBoolLiteral(trueLiteral);
SpvId lhs = this->writeExpression(*o.fLeft, out);
SpvId rhsLabel = this->nextId();
@@ -2074,8 +2078,8 @@ SpvId SPIRVCodeGenerator::writeLogicalOr(BinaryExpression& o, std::ostream& out)
this->writeInstruction(SpvOpBranch, end, out);
this->writeLabel(end, out);
SpvId result = this->nextId();
- this->writeInstruction(SpvOpPhi, this->getType(*kBool_Type), result, trueConstant, lhsBlock,
- rhs, rhsBlock, out);
+ this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, trueConstant,
+ lhsBlock, rhs, rhsBlock, out);
return result;
}
@@ -2086,7 +2090,7 @@ SpvId SPIRVCodeGenerator::writeTernaryExpression(TernaryExpression& t, std::ostr
SpvId result = this->nextId();
SpvId trueId = this->writeExpression(*t.fIfTrue, out);
SpvId falseId = this->writeExpression(*t.fIfFalse, out);
- this->writeInstruction(SpvOpSelect, this->getType(*t.fType), result, test, trueId, falseId,
+ this->writeInstruction(SpvOpSelect, this->getType(t.fType), result, test, trueId, falseId,
out);
return result;
}
@@ -2094,7 +2098,7 @@ SpvId SPIRVCodeGenerator::writeTernaryExpression(TernaryExpression& t, std::ostr
// Adreno. Switched to storing the result in a temp variable as glslang does.
SpvId var = this->nextId();
this->writeInstruction(SpvOpVariable, this->getPointerType(t.fType, SpvStorageClassFunction),
- var, SpvStorageClassFunction, out);
+ var, SpvStorageClassFunction, fVariableBuffer);
SpvId trueLabel = this->nextId();
SpvId falseLabel = this->nextId();
SpvId end = this->nextId();
@@ -2108,18 +2112,16 @@ SpvId SPIRVCodeGenerator::writeTernaryExpression(TernaryExpression& t, std::ostr
this->writeInstruction(SpvOpBranch, end, out);
this->writeLabel(end, out);
SpvId result = this->nextId();
- this->writeInstruction(SpvOpLoad, this->getType(*t.fType), result, var, out);
+ this->writeInstruction(SpvOpLoad, this->getType(t.fType), result, var, out);
return result;
}
-Expression* literal_1(const Type& type) {
- static IntLiteral int1(Position(), 1);
- static FloatLiteral float1(Position(), 1.0);
- if (type == *kInt_Type) {
- return &int1;
+std::unique_ptr<Expression> create_literal_1(const Context& context, const Type& type) {
+ if (type == *context.fInt_Type) {
+ return std::unique_ptr<Expression>(new IntLiteral(context, Position(), 1));
}
- else if (type == *kFloat_Type) {
- return &float1;
+ else if (type == *context.fFloat_Type) {
+ return std::unique_ptr<Expression>(new FloatLiteral(context, Position(), 1.0));
} else {
ABORT("math is unsupported on type '%s'")
}
@@ -2128,11 +2130,11 @@ Expression* literal_1(const Type& type) {
SpvId SPIRVCodeGenerator::writePrefixExpression(PrefixExpression& p, std::ostream& out) {
if (p.fOperator == Token::MINUS) {
SpvId result = this->nextId();
- SpvId typeId = this->getType(*p.fType);
+ SpvId typeId = this->getType(p.fType);
SpvId expr = this->writeExpression(*p.fOperand, out);
- if (is_float(*p.fType)) {
+ if (is_float(fContext, p.fType)) {
this->writeInstruction(SpvOpFNegate, typeId, result, expr, out);
- } else if (is_signed(*p.fType)) {
+ } else if (is_signed(fContext, p.fType)) {
this->writeInstruction(SpvOpSNegate, typeId, result, expr, out);
} else {
ABORT("unsupported prefix expression %s", p.description().c_str());
@@ -2144,8 +2146,8 @@ SpvId SPIRVCodeGenerator::writePrefixExpression(PrefixExpression& p, std::ostrea
return this->writeExpression(*p.fOperand, out);
case Token::PLUSPLUS: {
std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
- SpvId one = this->writeExpression(*literal_1(*p.fType), out);
- SpvId result = this->writeBinaryOperation(*p.fType, *p.fType, lv->load(out), one,
+ SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
+ SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one,
SpvOpFAdd, SpvOpIAdd, SpvOpIAdd, SpvOpUndef,
out);
lv->store(result, out);
@@ -2153,17 +2155,17 @@ SpvId SPIRVCodeGenerator::writePrefixExpression(PrefixExpression& p, std::ostrea
}
case Token::MINUSMINUS: {
std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
- SpvId one = this->writeExpression(*literal_1(*p.fType), out);
- SpvId result = this->writeBinaryOperation(*p.fType, *p.fType, lv->load(out), one,
+ SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
+ SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one,
SpvOpFSub, SpvOpISub, SpvOpISub, SpvOpUndef,
out);
lv->store(result, out);
return result;
}
case Token::NOT: {
- ASSERT(p.fOperand->fType == kBool_Type);
+ ASSERT(p.fOperand->fType == *fContext.fBool_Type);
SpvId result = this->nextId();
- this->writeInstruction(SpvOpLogicalNot, this->getType(*p.fOperand->fType), result,
+ this->writeInstruction(SpvOpLogicalNot, this->getType(p.fOperand->fType), result,
this->writeExpression(*p.fOperand, out), out);
return result;
}
@@ -2175,16 +2177,16 @@ SpvId SPIRVCodeGenerator::writePrefixExpression(PrefixExpression& p, std::ostrea
SpvId SPIRVCodeGenerator::writePostfixExpression(PostfixExpression& p, std::ostream& out) {
std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
SpvId result = lv->load(out);
- SpvId one = this->writeExpression(*literal_1(*p.fType), out);
+ SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
switch (p.fOperator) {
case Token::PLUSPLUS: {
- SpvId temp = this->writeBinaryOperation(*p.fType, *p.fType, result, one, SpvOpFAdd,
+ SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFAdd,
SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
lv->store(temp, out);
return result;
}
case Token::MINUSMINUS: {
- SpvId temp = this->writeBinaryOperation(*p.fType, *p.fType, result, one, SpvOpFSub,
+ SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFSub,
SpvOpISub, SpvOpISub, SpvOpUndef, out);
lv->store(temp, out);
return result;
@@ -2198,14 +2200,14 @@ SpvId SPIRVCodeGenerator::writeBoolLiteral(BoolLiteral& b) {
if (b.fValue) {
if (fBoolTrue == 0) {
fBoolTrue = this->nextId();
- this->writeInstruction(SpvOpConstantTrue, this->getType(*b.fType), fBoolTrue,
+ this->writeInstruction(SpvOpConstantTrue, this->getType(b.fType), fBoolTrue,
fConstantBuffer);
}
return fBoolTrue;
} else {
if (fBoolFalse == 0) {
fBoolFalse = this->nextId();
- this->writeInstruction(SpvOpConstantFalse, this->getType(*b.fType), fBoolFalse,
+ this->writeInstruction(SpvOpConstantFalse, this->getType(b.fType), fBoolFalse,
fConstantBuffer);
}
return fBoolFalse;
@@ -2213,22 +2215,22 @@ SpvId SPIRVCodeGenerator::writeBoolLiteral(BoolLiteral& b) {
}
SpvId SPIRVCodeGenerator::writeIntLiteral(IntLiteral& i) {
- if (i.fType == kInt_Type) {
+ if (i.fType == *fContext.fInt_Type) {
auto entry = fIntConstants.find(i.fValue);
if (entry == fIntConstants.end()) {
SpvId result = this->nextId();
- this->writeInstruction(SpvOpConstant, this->getType(*i.fType), result, (SpvId) i.fValue,
+ this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue,
fConstantBuffer);
fIntConstants[i.fValue] = result;
return result;
}
return entry->second;
} else {
- ASSERT(i.fType == kUInt_Type);
+ ASSERT(i.fType == *fContext.fUInt_Type);
auto entry = fUIntConstants.find(i.fValue);
if (entry == fUIntConstants.end()) {
SpvId result = this->nextId();
- this->writeInstruction(SpvOpConstant, this->getType(*i.fType), result, (SpvId) i.fValue,
+ this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue,
fConstantBuffer);
fUIntConstants[i.fValue] = result;
return result;
@@ -2238,7 +2240,7 @@ SpvId SPIRVCodeGenerator::writeIntLiteral(IntLiteral& i) {
}
SpvId SPIRVCodeGenerator::writeFloatLiteral(FloatLiteral& f) {
- if (f.fType == kFloat_Type) {
+ if (f.fType == *fContext.fFloat_Type) {
float value = (float) f.fValue;
auto entry = fFloatConstants.find(value);
if (entry == fFloatConstants.end()) {
@@ -2246,21 +2248,21 @@ SpvId SPIRVCodeGenerator::writeFloatLiteral(FloatLiteral& f) {
uint32_t bits;
ASSERT(sizeof(bits) == sizeof(value));
memcpy(&bits, &value, sizeof(bits));
- this->writeInstruction(SpvOpConstant, this->getType(*f.fType), result, bits,
+ this->writeInstruction(SpvOpConstant, this->getType(f.fType), result, bits,
fConstantBuffer);
fFloatConstants[value] = result;
return result;
}
return entry->second;
} else {
- ASSERT(f.fType == kDouble_Type);
+ ASSERT(f.fType == *fContext.fDouble_Type);
auto entry = fDoubleConstants.find(f.fValue);
if (entry == fDoubleConstants.end()) {
SpvId result = this->nextId();
uint64_t bits;
ASSERT(sizeof(bits) == sizeof(f.fValue));
memcpy(&bits, &f.fValue, sizeof(bits));
- this->writeInstruction(SpvOpConstant, this->getType(*f.fType), result,
+ this->writeInstruction(SpvOpConstant, this->getType(f.fType), result,
bits & 0xffffffff, bits >> 32, fConstantBuffer);
fDoubleConstants[f.fValue] = result;
return result;
@@ -2269,26 +2271,25 @@ SpvId SPIRVCodeGenerator::writeFloatLiteral(FloatLiteral& f) {
}
}
-SpvId SPIRVCodeGenerator::writeFunctionStart(std::shared_ptr<FunctionDeclaration> f,
- std::ostream& out) {
- SpvId result = fFunctionMap[f];
- this->writeInstruction(SpvOpFunction, this->getType(*f->fReturnType), result,
+SpvId SPIRVCodeGenerator::writeFunctionStart(const FunctionDeclaration& f, std::ostream& out) {
+ SpvId result = fFunctionMap[&f];
+ this->writeInstruction(SpvOpFunction, this->getType(f.fReturnType), result,
SpvFunctionControlMaskNone, this->getFunctionType(f), out);
- this->writeInstruction(SpvOpName, result, f->fName.c_str(), fNameBuffer);
- for (size_t i = 0; i < f->fParameters.size(); i++) {
+ this->writeInstruction(SpvOpName, result, f.fName.c_str(), fNameBuffer);
+ for (size_t i = 0; i < f.fParameters.size(); i++) {
SpvId id = this->nextId();
- fVariableMap[f->fParameters[i]] = id;
+ fVariableMap[f.fParameters[i]] = id;
SpvId type;
- type = this->getPointerType(f->fParameters[i]->fType, SpvStorageClassFunction);
+ type = this->getPointerType(f.fParameters[i]->fType, SpvStorageClassFunction);
this->writeInstruction(SpvOpFunctionParameter, type, id, out);
}
return result;
}
-SpvId SPIRVCodeGenerator::writeFunction(FunctionDefinition& f, std::ostream& out) {
+SpvId SPIRVCodeGenerator::writeFunction(const FunctionDefinition& f, std::ostream& out) {
SpvId result = this->writeFunctionStart(f.fDeclaration, out);
this->writeLabel(this->nextId(), out);
- if (f.fDeclaration->fName == "main") {
+ if (f.fDeclaration.fName == "main") {
out << fGlobalInitializersBuffer.str();
}
std::stringstream bodyBuffer;
@@ -2350,21 +2351,26 @@ void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target, int mem
}
SpvId SPIRVCodeGenerator::writeInterfaceBlock(InterfaceBlock& intf) {
- SpvId type = this->getType(*intf.fVariable->fType);
+ SpvId type = this->getType(intf.fVariable.fType);
SpvId result = this->nextId();
this->writeInstruction(SpvOpDecorate, type, SpvDecorationBlock, fDecorationBuffer);
- SpvStorageClass_ storageClass = get_storage_class(intf.fVariable->fModifiers);
+ SpvStorageClass_ storageClass = get_storage_class(intf.fVariable.fModifiers);
SpvId ptrType = this->nextId();
this->writeInstruction(SpvOpTypePointer, ptrType, storageClass, type, fConstantBuffer);
this->writeInstruction(SpvOpVariable, ptrType, result, storageClass, fConstantBuffer);
- this->writeLayout(intf.fVariable->fModifiers.fLayout, result);
- fVariableMap[intf.fVariable] = result;
+ this->writeLayout(intf.fVariable.fModifiers.fLayout, result);
+ fVariableMap[&intf.fVariable] = result;
return result;
}
void SPIRVCodeGenerator::writeGlobalVars(VarDeclaration& decl, std::ostream& out) {
for (size_t i = 0; i < decl.fVars.size(); i++) {
- if (!decl.fVars[i]->fIsReadFrom && !decl.fVars[i]->fIsWrittenTo) {
+ if (!decl.fVars[i]->fIsReadFrom && !decl.fVars[i]->fIsWrittenTo &&
+ !(decl.fVars[i]->fModifiers.fFlags & (Modifiers::kIn_Flag |
+ Modifiers::kOut_Flag |
+ Modifiers::kUniform_Flag))) {
+ // variable is dead and not an input / output var (the Vulkan debug layers complain if
+ // we elide an interface var, even if it's dead)
continue;
}
SpvStorageClass_ storageClass;
@@ -2373,7 +2379,7 @@ void SPIRVCodeGenerator::writeGlobalVars(VarDeclaration& decl, std::ostream& out
} else if (decl.fVars[i]->fModifiers.fFlags & Modifiers::kOut_Flag) {
storageClass = SpvStorageClassOutput;
} else if (decl.fVars[i]->fModifiers.fFlags & Modifiers::kUniform_Flag) {
- if (decl.fVars[i]->fType->kind() == Type::kSampler_Kind) {
+ if (decl.fVars[i]->fType.kind() == Type::kSampler_Kind) {
storageClass = SpvStorageClassUniformConstant;
} else {
storageClass = SpvStorageClassUniform;
@@ -2386,11 +2392,11 @@ void SPIRVCodeGenerator::writeGlobalVars(VarDeclaration& decl, std::ostream& out
SpvId type = this->getPointerType(decl.fVars[i]->fType, storageClass);
this->writeInstruction(SpvOpVariable, type, id, storageClass, fConstantBuffer);
this->writeInstruction(SpvOpName, id, decl.fVars[i]->fName.c_str(), fNameBuffer);
- if (decl.fVars[i]->fType->kind() == Type::kMatrix_Kind) {
+ if (decl.fVars[i]->fType.kind() == Type::kMatrix_Kind) {
this->writeInstruction(SpvOpMemberDecorate, id, (SpvId) i, SpvDecorationColMajor,
fDecorationBuffer);
this->writeInstruction(SpvOpMemberDecorate, id, (SpvId) i, SpvDecorationMatrixStride,
- (SpvId) decl.fVars[i]->fType->stride(), fDecorationBuffer);
+ (SpvId) decl.fVars[i]->fType.stride(), fDecorationBuffer);
}
if (decl.fValues[i]) {
ASSERT(!fCurrentBlock);
@@ -2538,15 +2544,15 @@ void SPIRVCodeGenerator::writeInstructions(Program& program, std::ostream& out)
for (size_t i = 0; i < program.fElements.size(); i++) {
if (program.fElements[i]->fKind == ProgramElement::kFunction_Kind) {
FunctionDefinition& f = (FunctionDefinition&) *program.fElements[i];
- fFunctionMap[f.fDeclaration] = this->nextId();
+ fFunctionMap[&f.fDeclaration] = this->nextId();
}
}
for (size_t i = 0; i < program.fElements.size(); i++) {
if (program.fElements[i]->fKind == ProgramElement::kInterfaceBlock_Kind) {
InterfaceBlock& intf = (InterfaceBlock&) *program.fElements[i];
SpvId id = this->writeInterfaceBlock(intf);
- if ((intf.fVariable->fModifiers.fFlags & Modifiers::kIn_Flag) ||
- (intf.fVariable->fModifiers.fFlags & Modifiers::kOut_Flag)) {
+ if ((intf.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) ||
+ (intf.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag)) {
interfaceVars.push_back(id);
}
}
@@ -2561,7 +2567,7 @@ void SPIRVCodeGenerator::writeInstructions(Program& program, std::ostream& out)
this->writeFunction(((FunctionDefinition&) *program.fElements[i]), body);
}
}
- std::shared_ptr<FunctionDeclaration> main = nullptr;
+ const FunctionDeclaration* main = nullptr;
for (auto entry : fFunctionMap) {
if (entry.first->fName == "main") {
main = entry.first;
@@ -2569,7 +2575,7 @@ void SPIRVCodeGenerator::writeInstructions(Program& program, std::ostream& out)
}
ASSERT(main);
for (auto entry : fVariableMap) {
- std::shared_ptr<Variable> var = entry.first;
+ const Variable* var = entry.first;
if (var->fStorage == Variable::kGlobal_Storage &&
((var->fModifiers.fFlags & Modifiers::kIn_Flag) ||
(var->fModifiers.fFlags & Modifiers::kOut_Flag))) {
diff --git a/src/sksl/SkSLSPIRVCodeGenerator.h b/src/sksl/SkSLSPIRVCodeGenerator.h
index 885c6b8b70..a20ad9f40b 100644
--- a/src/sksl/SkSLSPIRVCodeGenerator.h
+++ b/src/sksl/SkSLSPIRVCodeGenerator.h
@@ -61,8 +61,9 @@ public:
virtual void store(SpvId value, std::ostream& out) = 0;
};
- SPIRVCodeGenerator()
- : fCapabilities(1 << SpvCapabilityShader)
+ SPIRVCodeGenerator(const Context* context)
+ : fContext(*context)
+ , fCapabilities(1 << SpvCapabilityShader)
, fIdCount(1)
, fBoolTrue(0)
, fBoolFalse(0)
@@ -92,9 +93,9 @@ private:
SpvId getType(const Type& type);
- SpvId getFunctionType(std::shared_ptr<FunctionDeclaration> function);
+ SpvId getFunctionType(const FunctionDeclaration& function);
- SpvId getPointerType(std::shared_ptr<Type> type, SpvStorageClass_ storageClass);
+ SpvId getPointerType(const Type& type, SpvStorageClass_ storageClass);
std::vector<SpvId> getAccessChain(Expression& expr, std::ostream& out);
@@ -108,11 +109,11 @@ private:
SpvId writeInterfaceBlock(InterfaceBlock& intf);
- SpvId writeFunctionStart(std::shared_ptr<FunctionDeclaration> f, std::ostream& out);
+ SpvId writeFunctionStart(const FunctionDeclaration& f, std::ostream& out);
- SpvId writeFunctionDeclaration(std::shared_ptr<FunctionDeclaration> f, std::ostream& out);
+ SpvId writeFunctionDeclaration(const FunctionDeclaration& f, std::ostream& out);
- SpvId writeFunction(FunctionDefinition& f, std::ostream& out);
+ SpvId writeFunction(const FunctionDefinition& f, std::ostream& out);
void writeGlobalVars(VarDeclaration& v, std::ostream& out);
@@ -227,14 +228,16 @@ private:
int32_t word5, int32_t word6, int32_t word7, int32_t word8,
std::ostream& out);
+ const Context& fContext;
+
uint64_t fCapabilities;
SpvId fIdCount;
SpvId fGLSLExtendedInstructions;
typedef std::tuple<IntrinsicKind, int32_t, int32_t, int32_t, int32_t> Intrinsic;
std::unordered_map<std::string, Intrinsic> fIntrinsicMap;
- std::unordered_map<std::shared_ptr<FunctionDeclaration>, SpvId> fFunctionMap;
- std::unordered_map<std::shared_ptr<Variable>, SpvId> fVariableMap;
- std::unordered_map<std::shared_ptr<Variable>, int32_t> fInterfaceBlockMap;
+ std::unordered_map<const FunctionDeclaration*, SpvId> fFunctionMap;
+ std::unordered_map<const Variable*, SpvId> fVariableMap;
+ std::unordered_map<const Variable*, int32_t> fInterfaceBlockMap;
std::unordered_map<std::string, SpvId> fTypeMap;
std::stringstream fCapabilitiesBuffer;
std::stringstream fGlobalInitializersBuffer;
diff --git a/src/sksl/ir/SkSLBinaryExpression.h b/src/sksl/ir/SkSLBinaryExpression.h
index bd89d6c602..9ecdbc717c 100644
--- a/src/sksl/ir/SkSLBinaryExpression.h
+++ b/src/sksl/ir/SkSLBinaryExpression.h
@@ -18,7 +18,7 @@ namespace SkSL {
*/
struct BinaryExpression : public Expression {
BinaryExpression(Position position, std::unique_ptr<Expression> left, Token::Kind op,
- std::unique_ptr<Expression> right, std::shared_ptr<Type> type)
+ std::unique_ptr<Expression> right, const Type& type)
: INHERITED(position, kBinary_Kind, type)
, fLeft(std::move(left))
, fOperator(op)
diff --git a/src/sksl/ir/SkSLBlock.h b/src/sksl/ir/SkSLBlock.h
index 56ed77a0ba..a53d13d169 100644
--- a/src/sksl/ir/SkSLBlock.h
+++ b/src/sksl/ir/SkSLBlock.h
@@ -9,6 +9,7 @@
#define SKSL_BLOCK
#include "SkSLStatement.h"
+#include "SkSLSymbolTable.h"
namespace SkSL {
@@ -16,9 +17,11 @@ namespace SkSL {
* A block of multiple statements functioning as a single statement.
*/
struct Block : public Statement {
- Block(Position position, std::vector<std::unique_ptr<Statement>> statements)
+ Block(Position position, std::vector<std::unique_ptr<Statement>> statements,
+ const std::shared_ptr<SymbolTable> symbols)
: INHERITED(position, kBlock_Kind)
- , fStatements(std::move(statements)) {}
+ , fStatements(std::move(statements))
+ , fSymbols(std::move(symbols)) {}
std::string description() const override {
std::string result = "{";
@@ -31,6 +34,7 @@ struct Block : public Statement {
}
const std::vector<std::unique_ptr<Statement>> fStatements;
+ const std::shared_ptr<SymbolTable> fSymbols;
typedef Statement INHERITED;
};
diff --git a/src/sksl/ir/SkSLBoolLiteral.h b/src/sksl/ir/SkSLBoolLiteral.h
index 3c40e59514..8f55a69311 100644
--- a/src/sksl/ir/SkSLBoolLiteral.h
+++ b/src/sksl/ir/SkSLBoolLiteral.h
@@ -8,6 +8,7 @@
#ifndef SKSL_BOOLLITERAL
#define SKSL_BOOLLITERAL
+#include "SkSLContext.h"
#include "SkSLExpression.h"
namespace SkSL {
@@ -16,8 +17,8 @@ namespace SkSL {
* Represents 'true' or 'false'.
*/
struct BoolLiteral : public Expression {
- BoolLiteral(Position position, bool value)
- : INHERITED(position, kBoolLiteral_Kind, kBool_Type)
+ BoolLiteral(const Context& context, Position position, bool value)
+ : INHERITED(position, kBoolLiteral_Kind, *context.fBool_Type)
, fValue(value) {}
std::string description() const override {
diff --git a/src/sksl/ir/SkSLConstructor.h b/src/sksl/ir/SkSLConstructor.h
index c58da7e5b8..0501b651ea 100644
--- a/src/sksl/ir/SkSLConstructor.h
+++ b/src/sksl/ir/SkSLConstructor.h
@@ -16,13 +16,13 @@ namespace SkSL {
* Represents the construction of a compound type, such as "vec2(x, y)".
*/
struct Constructor : public Expression {
- Constructor(Position position, std::shared_ptr<Type> type,
+ Constructor(Position position, const Type& type,
std::vector<std::unique_ptr<Expression>> arguments)
- : INHERITED(position, kConstructor_Kind, std::move(type))
+ : INHERITED(position, kConstructor_Kind, type)
, fArguments(std::move(arguments)) {}
std::string description() const override {
- std::string result = fType->description() + "(";
+ std::string result = fType.description() + "(";
std::string separator = "";
for (size_t i = 0; i < fArguments.size(); i++) {
result += separator;
diff --git a/src/sksl/ir/SkSLExpression.h b/src/sksl/ir/SkSLExpression.h
index 1e42c7a475..92cb37de77 100644
--- a/src/sksl/ir/SkSLExpression.h
+++ b/src/sksl/ir/SkSLExpression.h
@@ -35,7 +35,7 @@ struct Expression : public IRNode {
kTypeReference_Kind,
};
- Expression(Position position, Kind kind, std::shared_ptr<Type> type)
+ Expression(Position position, Kind kind, const Type& type)
: INHERITED(position)
, fKind(kind)
, fType(std::move(type)) {}
@@ -45,7 +45,7 @@ struct Expression : public IRNode {
}
const Kind fKind;
- const std::shared_ptr<Type> fType;
+ const Type& fType;
typedef IRNode INHERITED;
};
diff --git a/src/sksl/ir/SkSLField.h b/src/sksl/ir/SkSLField.h
index f2b68bc2bc..a01df2943d 100644
--- a/src/sksl/ir/SkSLField.h
+++ b/src/sksl/ir/SkSLField.h
@@ -21,16 +21,16 @@ namespace SkSL {
* result of declaring anonymous interface blocks.
*/
struct Field : public Symbol {
- Field(Position position, std::shared_ptr<Variable> owner, int fieldIndex)
- : INHERITED(position, kField_Kind, owner->fType->fields()[fieldIndex].fName)
+ Field(Position position, const Variable& owner, int fieldIndex)
+ : INHERITED(position, kField_Kind, owner.fType.fields()[fieldIndex].fName)
, fOwner(owner)
, fFieldIndex(fieldIndex) {}
virtual std::string description() const override {
- return fOwner->description() + "." + fOwner->fType->fields()[fFieldIndex].fName;
+ return fOwner.description() + "." + fOwner.fType.fields()[fFieldIndex].fName;
}
- const std::shared_ptr<Variable> fOwner;
+ const Variable& fOwner;
const int fFieldIndex;
typedef Symbol INHERITED;
diff --git a/src/sksl/ir/SkSLFieldAccess.h b/src/sksl/ir/SkSLFieldAccess.h
index 053498e154..f09c3a3447 100644
--- a/src/sksl/ir/SkSLFieldAccess.h
+++ b/src/sksl/ir/SkSLFieldAccess.h
@@ -18,12 +18,12 @@ namespace SkSL {
*/
struct FieldAccess : public Expression {
FieldAccess(std::unique_ptr<Expression> base, int fieldIndex)
- : INHERITED(base->fPosition, kFieldAccess_Kind, base->fType->fields()[fieldIndex].fType)
+ : INHERITED(base->fPosition, kFieldAccess_Kind, base->fType.fields()[fieldIndex].fType)
, fBase(std::move(base))
, fFieldIndex(fieldIndex) {}
virtual std::string description() const override {
- return fBase->description() + "." + fBase->fType->fields()[fFieldIndex].fName;
+ return fBase->description() + "." + fBase->fType.fields()[fFieldIndex].fName;
}
const std::unique_ptr<Expression> fBase;
diff --git a/src/sksl/ir/SkSLFloatLiteral.h b/src/sksl/ir/SkSLFloatLiteral.h
index deb5b27144..d9c8b6538a 100644
--- a/src/sksl/ir/SkSLFloatLiteral.h
+++ b/src/sksl/ir/SkSLFloatLiteral.h
@@ -8,6 +8,7 @@
#ifndef SKSL_FLOATLITERAL
#define SKSL_FLOATLITERAL
+#include "SkSLContext.h"
#include "SkSLExpression.h"
namespace SkSL {
@@ -16,8 +17,8 @@ namespace SkSL {
* A literal floating point number.
*/
struct FloatLiteral : public Expression {
- FloatLiteral(Position position, double value)
- : INHERITED(position, kFloatLiteral_Kind, kFloat_Type)
+ FloatLiteral(const Context& context, Position position, double value)
+ : INHERITED(position, kFloatLiteral_Kind, *context.fFloat_Type)
, fValue(value) {}
virtual std::string description() const override {
diff --git a/src/sksl/ir/SkSLForStatement.h b/src/sksl/ir/SkSLForStatement.h
index 70bb4014c8..642d15125e 100644
--- a/src/sksl/ir/SkSLForStatement.h
+++ b/src/sksl/ir/SkSLForStatement.h
@@ -10,6 +10,7 @@
#include "SkSLExpression.h"
#include "SkSLStatement.h"
+#include "SkSLSymbolTable.h"
namespace SkSL {
@@ -19,12 +20,13 @@ namespace SkSL {
struct ForStatement : public Statement {
ForStatement(Position position, std::unique_ptr<Statement> initializer,
std::unique_ptr<Expression> test, std::unique_ptr<Expression> next,
- std::unique_ptr<Statement> statement)
+ std::unique_ptr<Statement> statement, std::shared_ptr<SymbolTable> symbols)
: INHERITED(position, kFor_Kind)
, fInitializer(std::move(initializer))
, fTest(std::move(test))
, fNext(std::move(next))
- , fStatement(std::move(statement)) {}
+ , fStatement(std::move(statement))
+ , fSymbols(symbols) {}
std::string description() const override {
std::string result = "for (";
@@ -47,6 +49,7 @@ struct ForStatement : public Statement {
const std::unique_ptr<Expression> fTest;
const std::unique_ptr<Expression> fNext;
const std::unique_ptr<Statement> fStatement;
+ const std::shared_ptr<SymbolTable> fSymbols;
typedef Statement INHERITED;
};
diff --git a/src/sksl/ir/SkSLFunctionCall.h b/src/sksl/ir/SkSLFunctionCall.h
index 78d2566227..85dba40f2a 100644
--- a/src/sksl/ir/SkSLFunctionCall.h
+++ b/src/sksl/ir/SkSLFunctionCall.h
@@ -17,14 +17,14 @@ namespace SkSL {
* A function invocation.
*/
struct FunctionCall : public Expression {
- FunctionCall(Position position, std::shared_ptr<FunctionDeclaration> function,
+ FunctionCall(Position position, const FunctionDeclaration& function,
std::vector<std::unique_ptr<Expression>> arguments)
- : INHERITED(position, kFunctionCall_Kind, function->fReturnType)
+ : INHERITED(position, kFunctionCall_Kind, function.fReturnType)
, fFunction(std::move(function))
, fArguments(std::move(arguments)) {}
std::string description() const override {
- std::string result = fFunction->fName + "(";
+ std::string result = fFunction.fName + "(";
std::string separator = "";
for (size_t i = 0; i < fArguments.size(); i++) {
result += separator;
@@ -35,7 +35,7 @@ struct FunctionCall : public Expression {
return result;
}
- const std::shared_ptr<FunctionDeclaration> fFunction;
+ const FunctionDeclaration& fFunction;
const std::vector<std::unique_ptr<Expression>> fArguments;
typedef Expression INHERITED;
diff --git a/src/sksl/ir/SkSLFunctionDeclaration.h b/src/sksl/ir/SkSLFunctionDeclaration.h
index 32c23f545e..16a184a6d7 100644
--- a/src/sksl/ir/SkSLFunctionDeclaration.h
+++ b/src/sksl/ir/SkSLFunctionDeclaration.h
@@ -10,6 +10,7 @@
#include "SkSLModifiers.h"
#include "SkSLSymbol.h"
+#include "SkSLSymbolTable.h"
#include "SkSLType.h"
#include "SkSLVariable.h"
@@ -20,15 +21,14 @@ namespace SkSL {
*/
struct FunctionDeclaration : public Symbol {
FunctionDeclaration(Position position, std::string name,
- std::vector<std::shared_ptr<Variable>> parameters,
- std::shared_ptr<Type> returnType)
+ std::vector<const Variable*> parameters, const Type& returnType)
: INHERITED(position, kFunctionDeclaration_Kind, std::move(name))
, fDefined(false)
- , fParameters(parameters)
+ , fParameters(std::move(parameters))
, fReturnType(returnType) {}
std::string description() const override {
- std::string result = fReturnType->description() + " " + fName + "(";
+ std::string result = fReturnType.description() + " " + fName + "(";
std::string separator = "";
for (auto p : fParameters) {
result += separator;
@@ -39,13 +39,24 @@ struct FunctionDeclaration : public Symbol {
return result;
}
- bool matches(FunctionDeclaration& f) {
- return fName == f.fName && fParameters == f.fParameters;
+ bool matches(const FunctionDeclaration& f) const {
+ if (fName != f.fName) {
+ return false;
+ }
+ if (fParameters.size() != f.fParameters.size()) {
+ return false;
+ }
+ for (size_t i = 0; i < fParameters.size(); i++) {
+ if (fParameters[i]->fType != f.fParameters[i]->fType) {
+ return false;
+ }
+ }
+ return true;
}
mutable bool fDefined;
- const std::vector<std::shared_ptr<Variable>> fParameters;
- const std::shared_ptr<Type> fReturnType;
+ const std::vector<const Variable*> fParameters;
+ const Type& fReturnType;
typedef Symbol INHERITED;
};
diff --git a/src/sksl/ir/SkSLFunctionDefinition.h b/src/sksl/ir/SkSLFunctionDefinition.h
index fceb5474cb..ace27a3ed8 100644
--- a/src/sksl/ir/SkSLFunctionDefinition.h
+++ b/src/sksl/ir/SkSLFunctionDefinition.h
@@ -18,17 +18,17 @@ namespace SkSL {
* A function definition (a declaration plus an associated block of code).
*/
struct FunctionDefinition : public ProgramElement {
- FunctionDefinition(Position position, std::shared_ptr<FunctionDeclaration> declaration,
+ FunctionDefinition(Position position, const FunctionDeclaration& declaration,
std::unique_ptr<Block> body)
: INHERITED(position, kFunction_Kind)
- , fDeclaration(std::move(declaration))
+ , fDeclaration(declaration)
, fBody(std::move(body)) {}
std::string description() const override {
- return fDeclaration->description() + " " + fBody->description();
+ return fDeclaration.description() + " " + fBody->description();
}
- const std::shared_ptr<FunctionDeclaration> fDeclaration;
+ const FunctionDeclaration& fDeclaration;
const std::unique_ptr<Block> fBody;
typedef ProgramElement INHERITED;
diff --git a/src/sksl/ir/SkSLFunctionReference.h b/src/sksl/ir/SkSLFunctionReference.h
index d5cc444000..5d97a5879f 100644
--- a/src/sksl/ir/SkSLFunctionReference.h
+++ b/src/sksl/ir/SkSLFunctionReference.h
@@ -8,6 +8,7 @@
#ifndef SKSL_FUNCTIONREFERENCE
#define SKSL_FUNCTIONREFERENCE
+#include "SkSLContext.h"
#include "SkSLExpression.h"
namespace SkSL {
@@ -17,8 +18,9 @@ namespace SkSL {
* always eventually replaced by FunctionCalls in valid programs.
*/
struct FunctionReference : public Expression {
- FunctionReference(Position position, std::vector<std::shared_ptr<FunctionDeclaration>> function)
- : INHERITED(position, kFunctionReference_Kind, kInvalid_Type)
+ FunctionReference(const Context& context, Position position,
+ std::vector<const FunctionDeclaration*> function)
+ : INHERITED(position, kFunctionReference_Kind, *context.fInvalid_Type)
, fFunctions(function) {}
virtual std::string description() const override {
@@ -26,7 +28,7 @@ struct FunctionReference : public Expression {
return "<function>";
}
- const std::vector<std::shared_ptr<FunctionDeclaration>> fFunctions;
+ const std::vector<const FunctionDeclaration*> fFunctions;
typedef Expression INHERITED;
};
diff --git a/src/sksl/ir/SkSLIndexExpression.h b/src/sksl/ir/SkSLIndexExpression.h
index 538c656153..f5b0d09c2c 100644
--- a/src/sksl/ir/SkSLIndexExpression.h
+++ b/src/sksl/ir/SkSLIndexExpression.h
@@ -16,21 +16,21 @@ namespace SkSL {
/**
* Given a type, returns the type that will result from extracting an array value from it.
*/
-static std::shared_ptr<Type> index_type(const Type& type) {
+static const Type& index_type(const Context& context, const Type& type) {
if (type.kind() == Type::kMatrix_Kind) {
- if (type.componentType() == kFloat_Type) {
+ if (type.componentType() == *context.fFloat_Type) {
switch (type.columns()) {
- case 2: return kVec2_Type;
- case 3: return kVec3_Type;
- case 4: return kVec4_Type;
+ case 2: return *context.fVec2_Type;
+ case 3: return *context.fVec3_Type;
+ case 4: return *context.fVec4_Type;
default: ASSERT(false);
}
} else {
- ASSERT(type.componentType() == kDouble_Type);
+ ASSERT(type.componentType() == *context.fDouble_Type);
switch (type.columns()) {
- case 2: return kDVec2_Type;
- case 3: return kDVec3_Type;
- case 4: return kDVec4_Type;
+ case 2: return *context.fDVec2_Type;
+ case 3: return *context.fDVec3_Type;
+ case 4: return *context.fDVec4_Type;
default: ASSERT(false);
}
}
@@ -42,11 +42,12 @@ static std::shared_ptr<Type> index_type(const Type& type) {
* An expression which extracts a value from an array or matrix, as in 'm[2]'.
*/
struct IndexExpression : public Expression {
- IndexExpression(std::unique_ptr<Expression> base, std::unique_ptr<Expression> index)
- : INHERITED(base->fPosition, kIndex_Kind, index_type(*base->fType))
+ IndexExpression(const Context& context, std::unique_ptr<Expression> base,
+ std::unique_ptr<Expression> index)
+ : INHERITED(base->fPosition, kIndex_Kind, index_type(context, base->fType))
, fBase(std::move(base))
, fIndex(std::move(index)) {
- ASSERT(fIndex->fType == kInt_Type);
+ ASSERT(fIndex->fType == *context.fInt_Type);
}
std::string description() const override {
diff --git a/src/sksl/ir/SkSLIntLiteral.h b/src/sksl/ir/SkSLIntLiteral.h
index 80b30d7c05..f2bf40b590 100644
--- a/src/sksl/ir/SkSLIntLiteral.h
+++ b/src/sksl/ir/SkSLIntLiteral.h
@@ -18,8 +18,8 @@ namespace SkSL {
struct IntLiteral : public Expression {
// FIXME: we will need to revisit this if/when we add full support for both signed and unsigned
// 64-bit integers, but for right now an int64_t will hold every value we care about
- IntLiteral(Position position, int64_t value)
- : INHERITED(position, kIntLiteral_Kind, kInt_Type)
+ IntLiteral(const Context& context, Position position, int64_t value)
+ : INHERITED(position, kIntLiteral_Kind, *context.fInt_Type)
, fValue(value) {}
virtual std::string description() const override {
diff --git a/src/sksl/ir/SkSLInterfaceBlock.h b/src/sksl/ir/SkSLInterfaceBlock.h
index baedb5864c..f1121ed707 100644
--- a/src/sksl/ir/SkSLInterfaceBlock.h
+++ b/src/sksl/ir/SkSLInterfaceBlock.h
@@ -24,22 +24,24 @@ namespace SkSL {
* At the IR level, this is represented by a single variable of struct type.
*/
struct InterfaceBlock : public ProgramElement {
- InterfaceBlock(Position position, std::shared_ptr<Variable> var)
+ InterfaceBlock(Position position, const Variable& var, std::shared_ptr<SymbolTable> typeOwner)
: INHERITED(position, kInterfaceBlock_Kind)
- , fVariable(std::move(var)) {
- ASSERT(fVariable->fType->kind() == Type::kStruct_Kind);
+ , fVariable(std::move(var))
+ , fTypeOwner(typeOwner) {
+ ASSERT(fVariable.fType.kind() == Type::kStruct_Kind);
}
std::string description() const override {
- std::string result = fVariable->fModifiers.description() + fVariable->fName + " {\n";
- for (size_t i = 0; i < fVariable->fType->fields().size(); i++) {
- result += fVariable->fType->fields()[i].description() + "\n";
+ std::string result = fVariable.fModifiers.description() + fVariable.fName + " {\n";
+ for (size_t i = 0; i < fVariable.fType.fields().size(); i++) {
+ result += fVariable.fType.fields()[i].description() + "\n";
}
result += "};";
return result;
}
- const std::shared_ptr<Variable> fVariable;
+ const Variable& fVariable;
+ const std::shared_ptr<SymbolTable> fTypeOwner;
typedef ProgramElement INHERITED;
};
diff --git a/src/sksl/ir/SkSLProgram.h b/src/sksl/ir/SkSLProgram.h
index 5edcfded42..205db6e932 100644
--- a/src/sksl/ir/SkSLProgram.h
+++ b/src/sksl/ir/SkSLProgram.h
@@ -12,6 +12,7 @@
#include <memory>
#include "SkSLProgramElement.h"
+#include "SkSLSymbolTable.h"
namespace SkSL {
@@ -24,13 +25,16 @@ struct Program {
kVertex_Kind
};
- Program(Kind kind, std::vector<std::unique_ptr<ProgramElement>> elements)
+ Program(Kind kind, std::vector<std::unique_ptr<ProgramElement>> elements,
+ std::shared_ptr<SymbolTable> symbols)
: fKind(kind)
- , fElements(std::move(elements)) {}
+ , fElements(std::move(elements))
+ , fSymbols(symbols) {}
Kind fKind;
std::vector<std::unique_ptr<ProgramElement>> fElements;
+ std::shared_ptr<SymbolTable> fSymbols;
};
} // namespace
diff --git a/src/sksl/ir/SkSLSwizzle.h b/src/sksl/ir/SkSLSwizzle.h
index ce360d1847..0eb4a00dca 100644
--- a/src/sksl/ir/SkSLSwizzle.h
+++ b/src/sksl/ir/SkSLSwizzle.h
@@ -18,41 +18,40 @@ namespace SkSL {
* instance, swizzling a vec3 with two components will result in a vec2. It is possible to swizzle
* with more components than the source vector, as in 'vec2(1).xxxx'.
*/
-static std::shared_ptr<Type> get_type(Expression& value,
- size_t count) {
- std::shared_ptr<Type> base = value.fType->componentType();
+static const Type& get_type(const Context& context, Expression& value, size_t count) {
+ const Type& base = value.fType.componentType();
if (count == 1) {
return base;
}
- if (base == kFloat_Type) {
+ if (base == *context.fFloat_Type) {
switch (count) {
- case 2: return kVec2_Type;
- case 3: return kVec3_Type;
- case 4: return kVec4_Type;
+ case 2: return *context.fVec2_Type;
+ case 3: return *context.fVec3_Type;
+ case 4: return *context.fVec4_Type;
}
- } else if (base == kDouble_Type) {
+ } else if (base == *context.fDouble_Type) {
switch (count) {
- case 2: return kDVec2_Type;
- case 3: return kDVec3_Type;
- case 4: return kDVec4_Type;
+ case 2: return *context.fDVec2_Type;
+ case 3: return *context.fDVec3_Type;
+ case 4: return *context.fDVec4_Type;
}
- } else if (base == kInt_Type) {
+ } else if (base == *context.fInt_Type) {
switch (count) {
- case 2: return kIVec2_Type;
- case 3: return kIVec3_Type;
- case 4: return kIVec4_Type;
+ case 2: return *context.fIVec2_Type;
+ case 3: return *context.fIVec3_Type;
+ case 4: return *context.fIVec4_Type;
}
- } else if (base == kUInt_Type) {
+ } else if (base == *context.fUInt_Type) {
switch (count) {
- case 2: return kUVec2_Type;
- case 3: return kUVec3_Type;
- case 4: return kUVec4_Type;
+ case 2: return *context.fUVec2_Type;
+ case 3: return *context.fUVec3_Type;
+ case 4: return *context.fUVec4_Type;
}
- } else if (base == kBool_Type) {
+ } else if (base == *context.fBool_Type) {
switch (count) {
- case 2: return kBVec2_Type;
- case 3: return kBVec3_Type;
- case 4: return kBVec4_Type;
+ case 2: return *context.fBVec2_Type;
+ case 3: return *context.fBVec3_Type;
+ case 4: return *context.fBVec4_Type;
}
}
ABORT("cannot swizzle %s\n", value.description().c_str());
@@ -62,8 +61,8 @@ static std::shared_ptr<Type> get_type(Expression& value,
* Represents a vector swizzle operation such as 'vec2(1, 2, 3).zyx'.
*/
struct Swizzle : public Expression {
- Swizzle(std::unique_ptr<Expression> base, std::vector<int> components)
- : INHERITED(base->fPosition, kSwizzle_Kind, get_type(*base, components.size()))
+ Swizzle(const Context& context, std::unique_ptr<Expression> base, std::vector<int> components)
+ : INHERITED(base->fPosition, kSwizzle_Kind, get_type(context, *base, components.size()))
, fBase(std::move(base))
, fComponents(std::move(components)) {
ASSERT(fComponents.size() >= 1 && fComponents.size() <= 4);
diff --git a/src/sksl/ir/SkSLSymbolTable.cpp b/src/sksl/ir/SkSLSymbolTable.cpp
index af83f7a456..80e22da009 100644
--- a/src/sksl/ir/SkSLSymbolTable.cpp
+++ b/src/sksl/ir/SkSLSymbolTable.cpp
@@ -5,23 +5,23 @@
* found in the LICENSE file.
*/
- #include "SkSLSymbolTable.h"
+#include "SkSLSymbolTable.h"
+#include "SkSLUnresolvedFunction.h"
namespace SkSL {
-std::vector<std::shared_ptr<FunctionDeclaration>> SymbolTable::GetFunctions(
- const std::shared_ptr<Symbol>& s) {
- switch (s->fKind) {
+std::vector<const FunctionDeclaration*> SymbolTable::GetFunctions(const Symbol& s) {
+ switch (s.fKind) {
case Symbol::kFunctionDeclaration_Kind:
- return { std::static_pointer_cast<FunctionDeclaration>(s) };
+ return { &((FunctionDeclaration&) s) };
case Symbol::kUnresolvedFunction_Kind:
- return ((UnresolvedFunction&) *s).fFunctions;
+ return ((UnresolvedFunction&) s).fFunctions;
default:
return { };
}
}
-std::shared_ptr<Symbol> SymbolTable::operator[](const std::string& name) {
+const Symbol* SymbolTable::operator[](const std::string& name) {
const auto& entry = fSymbols.find(name);
if (entry == fSymbols.end()) {
if (fParent) {
@@ -30,15 +30,15 @@ std::shared_ptr<Symbol> SymbolTable::operator[](const std::string& name) {
return nullptr;
}
if (fParent) {
- auto functions = GetFunctions(entry->second);
+ auto functions = GetFunctions(*entry->second);
if (functions.size() > 0) {
bool modified = false;
- std::shared_ptr<Symbol> previous = (*fParent)[name];
+ const Symbol* previous = (*fParent)[name];
if (previous) {
- auto previousFunctions = GetFunctions(previous);
- for (const std::shared_ptr<FunctionDeclaration>& prev : previousFunctions) {
+ auto previousFunctions = GetFunctions(*previous);
+ for (const FunctionDeclaration* prev : previousFunctions) {
bool found = false;
- for (const std::shared_ptr<FunctionDeclaration>& current : functions) {
+ for (const FunctionDeclaration* current : functions) {
if (current->matches(*prev)) {
found = true;
break;
@@ -51,7 +51,7 @@ std::shared_ptr<Symbol> SymbolTable::operator[](const std::string& name) {
}
if (modified) {
ASSERT(functions.size() > 1);
- return std::shared_ptr<Symbol>(new UnresolvedFunction(functions));
+ return this->takeOwnership(new UnresolvedFunction(functions));
}
}
}
@@ -59,27 +59,42 @@ std::shared_ptr<Symbol> SymbolTable::operator[](const std::string& name) {
return entry->second;
}
-void SymbolTable::add(const std::string& name, std::shared_ptr<Symbol> symbol) {
- const auto& existing = fSymbols.find(name);
- if (existing == fSymbols.end()) {
- fSymbols[name] = symbol;
- } else if (symbol->fKind == Symbol::kFunctionDeclaration_Kind) {
- const std::shared_ptr<Symbol>& oldSymbol = existing->second;
- if (oldSymbol->fKind == Symbol::kFunctionDeclaration_Kind) {
- std::vector<std::shared_ptr<FunctionDeclaration>> functions;
- functions.push_back(std::static_pointer_cast<FunctionDeclaration>(oldSymbol));
- functions.push_back(std::static_pointer_cast<FunctionDeclaration>(symbol));
- fSymbols[name].reset(new UnresolvedFunction(std::move(functions)));
- } else if (oldSymbol->fKind == Symbol::kUnresolvedFunction_Kind) {
- std::vector<std::shared_ptr<FunctionDeclaration>> functions;
- for (const auto& f : ((UnresolvedFunction&) *oldSymbol).fFunctions) {
- functions.push_back(f);
- }
- functions.push_back(std::static_pointer_cast<FunctionDeclaration>(symbol));
- fSymbols[name].reset(new UnresolvedFunction(std::move(functions)));
+Symbol* SymbolTable::takeOwnership(Symbol* s) {
+ fOwnedPointers.push_back(std::unique_ptr<Symbol>(s));
+ return s;
+}
+
+void SymbolTable::add(const std::string& name, std::unique_ptr<Symbol> symbol) {
+ this->addWithoutOwnership(name, symbol.get());
+ fOwnedPointers.push_back(std::move(symbol));
+}
+
+void SymbolTable::addWithoutOwnership(const std::string& name, const Symbol* symbol) {
+ const auto& existing = fSymbols.find(name);
+ if (existing == fSymbols.end()) {
+ fSymbols[name] = symbol;
+ } else if (symbol->fKind == Symbol::kFunctionDeclaration_Kind) {
+ const Symbol* oldSymbol = existing->second;
+ if (oldSymbol->fKind == Symbol::kFunctionDeclaration_Kind) {
+ std::vector<const FunctionDeclaration*> functions;
+ functions.push_back((const FunctionDeclaration*) oldSymbol);
+ functions.push_back((const FunctionDeclaration*) symbol);
+ UnresolvedFunction* u = new UnresolvedFunction(std::move(functions));
+ fSymbols[name] = u;
+ this->takeOwnership(u);
+ } else if (oldSymbol->fKind == Symbol::kUnresolvedFunction_Kind) {
+ std::vector<const FunctionDeclaration*> functions;
+ for (const auto* f : ((UnresolvedFunction&) *oldSymbol).fFunctions) {
+ functions.push_back(f);
}
- } else {
- fErrorReporter.error(symbol->fPosition, "symbol '" + name + "' was already defined");
+ functions.push_back((const FunctionDeclaration*) symbol);
+ UnresolvedFunction* u = new UnresolvedFunction(std::move(functions));
+ fSymbols[name] = u;
+ this->takeOwnership(u);
}
+ } else {
+ fErrorReporter.error(symbol->fPosition, "symbol '" + name + "' was already defined");
}
+}
+
} // namespace
diff --git a/src/sksl/ir/SkSLSymbolTable.h b/src/sksl/ir/SkSLSymbolTable.h
index 151475d642..d732023ff0 100644
--- a/src/sksl/ir/SkSLSymbolTable.h
+++ b/src/sksl/ir/SkSLSymbolTable.h
@@ -10,12 +10,14 @@
#include <memory>
#include <unordered_map>
+#include <vector>
#include "SkSLErrorReporter.h"
#include "SkSLSymbol.h"
-#include "SkSLUnresolvedFunction.h"
namespace SkSL {
+struct FunctionDeclaration;
+
/**
* Maps identifiers to symbols. Functions, in particular, are mapped to either FunctionDeclaration
* or UnresolvedFunction depending on whether they are overloaded or not.
@@ -29,17 +31,22 @@ public:
: fParent(parent)
, fErrorReporter(errorReporter) {}
- std::shared_ptr<Symbol> operator[](const std::string& name);
+ const Symbol* operator[](const std::string& name);
+
+ void add(const std::string& name, std::unique_ptr<Symbol> symbol);
- void add(const std::string& name, std::shared_ptr<Symbol> symbol);
+ void addWithoutOwnership(const std::string& name, const Symbol* symbol);
+
+ Symbol* takeOwnership(Symbol* s);
const std::shared_ptr<SymbolTable> fParent;
private:
- static std::vector<std::shared_ptr<FunctionDeclaration>> GetFunctions(
- const std::shared_ptr<Symbol>& s);
+ static std::vector<const FunctionDeclaration*> GetFunctions(const Symbol& s);
+
+ std::vector<std::unique_ptr<Symbol>> fOwnedPointers;
- std::unordered_map<std::string, std::shared_ptr<Symbol>> fSymbols;
+ std::unordered_map<std::string, const Symbol*> fSymbols;
ErrorReporter& fErrorReporter;
};
diff --git a/src/sksl/ir/SkSLType.cpp b/src/sksl/ir/SkSLType.cpp
index 27cbd39e44..d28c4f0666 100644
--- a/src/sksl/ir/SkSLType.cpp
+++ b/src/sksl/ir/SkSLType.cpp
@@ -6,29 +6,30 @@
*/
#include "SkSLType.h"
+#include "SkSLContext.h"
namespace SkSL {
-bool Type::determineCoercionCost(std::shared_ptr<Type> other, int* outCost) const {
- if (this == other.get()) {
+bool Type::determineCoercionCost(const Type& other, int* outCost) const {
+ if (*this == other) {
*outCost = 0;
return true;
}
- if (this->kind() == kVector_Kind && other->kind() == kVector_Kind) {
- if (this->columns() == other->columns()) {
- return this->componentType()->determineCoercionCost(other->componentType(), outCost);
+ if (this->kind() == kVector_Kind && other.kind() == kVector_Kind) {
+ if (this->columns() == other.columns()) {
+ return this->componentType().determineCoercionCost(other.componentType(), outCost);
}
return false;
}
if (this->kind() == kMatrix_Kind) {
- if (this->columns() == other->columns() &&
- this->rows() == other->rows()) {
- return this->componentType()->determineCoercionCost(other->componentType(), outCost);
+ if (this->columns() == other.columns() &&
+ this->rows() == other.rows()) {
+ return this->componentType().determineCoercionCost(other.componentType(), outCost);
}
return false;
}
for (size_t i = 0; i < fCoercibleTypes.size(); i++) {
- if (fCoercibleTypes[i] == other) {
+ if (*fCoercibleTypes[i] == other) {
*outCost = (int) i + 1;
return true;
}
@@ -36,93 +37,93 @@ bool Type::determineCoercionCost(std::shared_ptr<Type> other, int* outCost) cons
return false;
}
-std::shared_ptr<Type> Type::toCompound(int columns, int rows) {
+const Type& Type::toCompound(const Context& context, int columns, int rows) const {
ASSERT(this->kind() == Type::kScalar_Kind);
if (columns == 1 && rows == 1) {
- return std::shared_ptr<Type>(this);
+ return *this;
}
- if (*this == *kFloat_Type) {
+ if (*this == *context.fFloat_Type) {
switch (rows) {
case 1:
switch (columns) {
- case 2: return kVec2_Type;
- case 3: return kVec3_Type;
- case 4: return kVec4_Type;
+ case 2: return *context.fVec2_Type;
+ case 3: return *context.fVec3_Type;
+ case 4: return *context.fVec4_Type;
default: ABORT("unsupported vector column count (%d)", columns);
}
case 2:
switch (columns) {
- case 2: return kMat2x2_Type;
- case 3: return kMat3x2_Type;
- case 4: return kMat4x2_Type;
+ case 2: return *context.fMat2x2_Type;
+ case 3: return *context.fMat3x2_Type;
+ case 4: return *context.fMat4x2_Type;
default: ABORT("unsupported matrix column count (%d)", columns);
}
case 3:
switch (columns) {
- case 2: return kMat2x3_Type;
- case 3: return kMat3x3_Type;
- case 4: return kMat4x3_Type;
+ case 2: return *context.fMat2x3_Type;
+ case 3: return *context.fMat3x3_Type;
+ case 4: return *context.fMat4x3_Type;
default: ABORT("unsupported matrix column count (%d)", columns);
}
case 4:
switch (columns) {
- case 2: return kMat2x4_Type;
- case 3: return kMat3x4_Type;
- case 4: return kMat4x4_Type;
+ case 2: return *context.fMat2x4_Type;
+ case 3: return *context.fMat3x4_Type;
+ case 4: return *context.fMat4x4_Type;
default: ABORT("unsupported matrix column count (%d)", columns);
}
default: ABORT("unsupported row count (%d)", rows);
}
- } else if (*this == *kDouble_Type) {
+ } else if (*this == *context.fDouble_Type) {
switch (rows) {
case 1:
switch (columns) {
- case 2: return kDVec2_Type;
- case 3: return kDVec3_Type;
- case 4: return kDVec4_Type;
+ case 2: return *context.fDVec2_Type;
+ case 3: return *context.fDVec3_Type;
+ case 4: return *context.fDVec4_Type;
default: ABORT("unsupported vector column count (%d)", columns);
}
case 2:
switch (columns) {
- case 2: return kDMat2x2_Type;
- case 3: return kDMat3x2_Type;
- case 4: return kDMat4x2_Type;
+ case 2: return *context.fDMat2x2_Type;
+ case 3: return *context.fDMat3x2_Type;
+ case 4: return *context.fDMat4x2_Type;
default: ABORT("unsupported matrix column count (%d)", columns);
}
case 3:
switch (columns) {
- case 2: return kDMat2x3_Type;
- case 3: return kDMat3x3_Type;
- case 4: return kDMat4x3_Type;
+ case 2: return *context.fDMat2x3_Type;
+ case 3: return *context.fDMat3x3_Type;
+ case 4: return *context.fDMat4x3_Type;
default: ABORT("unsupported matrix column count (%d)", columns);
}
case 4:
switch (columns) {
- case 2: return kDMat2x4_Type;
- case 3: return kDMat3x4_Type;
- case 4: return kDMat4x4_Type;
+ case 2: return *context.fDMat2x4_Type;
+ case 3: return *context.fDMat3x4_Type;
+ case 4: return *context.fDMat4x4_Type;
default: ABORT("unsupported matrix column count (%d)", columns);
}
default: ABORT("unsupported row count (%d)", rows);
}
- } else if (*this == *kInt_Type) {
+ } else if (*this == *context.fInt_Type) {
switch (rows) {
case 1:
switch (columns) {
- case 2: return kIVec2_Type;
- case 3: return kIVec3_Type;
- case 4: return kIVec4_Type;
+ case 2: return *context.fIVec2_Type;
+ case 3: return *context.fIVec3_Type;
+ case 4: return *context.fIVec4_Type;
default: ABORT("unsupported vector column count (%d)", columns);
}
default: ABORT("unsupported row count (%d)", rows);
}
- } else if (*this == *kUInt_Type) {
+ } else if (*this == *context.fUInt_Type) {
switch (rows) {
case 1:
switch (columns) {
- case 2: return kUVec2_Type;
- case 3: return kUVec3_Type;
- case 4: return kUVec4_Type;
+ case 2: return *context.fUVec2_Type;
+ case 3: return *context.fUVec3_Type;
+ case 4: return *context.fUVec4_Type;
default: ABORT("unsupported vector column count (%d)", columns);
}
default: ABORT("unsupported row count (%d)", rows);
@@ -131,128 +132,4 @@ std::shared_ptr<Type> Type::toCompound(int columns, int rows) {
ABORT("unsupported scalar_to_compound type %s", this->description().c_str());
}
-const std::shared_ptr<Type> kVoid_Type(new Type("void"));
-
-const std::shared_ptr<Type> kDouble_Type(new Type("double", true));
-const std::shared_ptr<Type> kDVec2_Type(new Type("dvec2", kDouble_Type, 2));
-const std::shared_ptr<Type> kDVec3_Type(new Type("dvec3", kDouble_Type, 3));
-const std::shared_ptr<Type> kDVec4_Type(new Type("dvec4", kDouble_Type, 4));
-
-const std::shared_ptr<Type> kFloat_Type(new Type("float", true, { kDouble_Type }));
-const std::shared_ptr<Type> kVec2_Type(new Type("vec2", kFloat_Type, 2));
-const std::shared_ptr<Type> kVec3_Type(new Type("vec3", kFloat_Type, 3));
-const std::shared_ptr<Type> kVec4_Type(new Type("vec4", kFloat_Type, 4));
-
-const std::shared_ptr<Type> kUInt_Type(new Type("uint", true, { kFloat_Type, kDouble_Type }));
-const std::shared_ptr<Type> kUVec2_Type(new Type("uvec2", kUInt_Type, 2));
-const std::shared_ptr<Type> kUVec3_Type(new Type("uvec3", kUInt_Type, 3));
-const std::shared_ptr<Type> kUVec4_Type(new Type("uvec4", kUInt_Type, 4));
-
-const std::shared_ptr<Type> kInt_Type(new Type("int", true, { kUInt_Type, kFloat_Type,
- kDouble_Type }));
-const std::shared_ptr<Type> kIVec2_Type(new Type("ivec2", kInt_Type, 2));
-const std::shared_ptr<Type> kIVec3_Type(new Type("ivec3", kInt_Type, 3));
-const std::shared_ptr<Type> kIVec4_Type(new Type("ivec4", kInt_Type, 4));
-
-const std::shared_ptr<Type> kBool_Type(new Type("bool", false));
-const std::shared_ptr<Type> kBVec2_Type(new Type("bvec2", kBool_Type, 2));
-const std::shared_ptr<Type> kBVec3_Type(new Type("bvec3", kBool_Type, 3));
-const std::shared_ptr<Type> kBVec4_Type(new Type("bvec4", kBool_Type, 4));
-
-const std::shared_ptr<Type> kMat2x2_Type(new Type("mat2", kFloat_Type, 2, 2));
-const std::shared_ptr<Type> kMat2x3_Type(new Type("mat2x3", kFloat_Type, 2, 3));
-const std::shared_ptr<Type> kMat2x4_Type(new Type("mat2x4", kFloat_Type, 2, 4));
-const std::shared_ptr<Type> kMat3x2_Type(new Type("mat3x2", kFloat_Type, 3, 2));
-const std::shared_ptr<Type> kMat3x3_Type(new Type("mat3", kFloat_Type, 3, 3));
-const std::shared_ptr<Type> kMat3x4_Type(new Type("mat3x4", kFloat_Type, 3, 4));
-const std::shared_ptr<Type> kMat4x2_Type(new Type("mat4x2", kFloat_Type, 4, 2));
-const std::shared_ptr<Type> kMat4x3_Type(new Type("mat4x3", kFloat_Type, 4, 3));
-const std::shared_ptr<Type> kMat4x4_Type(new Type("mat4", kFloat_Type, 4, 4));
-
-const std::shared_ptr<Type> kDMat2x2_Type(new Type("dmat2", kFloat_Type, 2, 2));
-const std::shared_ptr<Type> kDMat2x3_Type(new Type("dmat2x3", kFloat_Type, 2, 3));
-const std::shared_ptr<Type> kDMat2x4_Type(new Type("dmat2x4", kFloat_Type, 2, 4));
-const std::shared_ptr<Type> kDMat3x2_Type(new Type("dmat3x2", kFloat_Type, 3, 2));
-const std::shared_ptr<Type> kDMat3x3_Type(new Type("dmat3", kFloat_Type, 3, 3));
-const std::shared_ptr<Type> kDMat3x4_Type(new Type("dmat3x4", kFloat_Type, 3, 4));
-const std::shared_ptr<Type> kDMat4x2_Type(new Type("dmat4x2", kFloat_Type, 4, 2));
-const std::shared_ptr<Type> kDMat4x3_Type(new Type("dmat4x3", kFloat_Type, 4, 3));
-const std::shared_ptr<Type> kDMat4x4_Type(new Type("dmat4", kFloat_Type, 4, 4));
-
-const std::shared_ptr<Type> kSampler1D_Type(new Type("sampler1D", SpvDim1D, false, false, false, true));
-const std::shared_ptr<Type> kSampler2D_Type(new Type("sampler2D", SpvDim2D, false, false, false, true));
-const std::shared_ptr<Type> kSampler3D_Type(new Type("sampler3D", SpvDim3D, false, false, false, true));
-const std::shared_ptr<Type> kSamplerCube_Type(new Type("samplerCube"));
-const std::shared_ptr<Type> kSampler2DRect_Type(new Type("sampler2DRect"));
-const std::shared_ptr<Type> kSampler1DArray_Type(new Type("sampler1DArray"));
-const std::shared_ptr<Type> kSampler2DArray_Type(new Type("sampler2DArray"));
-const std::shared_ptr<Type> kSamplerCubeArray_Type(new Type("samplerCubeArray"));
-const std::shared_ptr<Type> kSamplerBuffer_Type(new Type("samplerBuffer"));
-const std::shared_ptr<Type> kSampler2DMS_Type(new Type("sampler2DMS"));
-const std::shared_ptr<Type> kSampler2DMSArray_Type(new Type("sampler2DMSArray"));
-const std::shared_ptr<Type> kSampler1DShadow_Type(new Type("sampler1DShadow"));
-const std::shared_ptr<Type> kSampler2DShadow_Type(new Type("sampler2DShadow"));
-const std::shared_ptr<Type> kSamplerCubeShadow_Type(new Type("samplerCubeShadow"));
-const std::shared_ptr<Type> kSampler2DRectShadow_Type(new Type("sampler2DRectShadow"));
-const std::shared_ptr<Type> kSampler1DArrayShadow_Type(new Type("sampler1DArrayShadow"));
-const std::shared_ptr<Type> kSampler2DArrayShadow_Type(new Type("sampler2DArrayShadow"));
-const std::shared_ptr<Type> kSamplerCubeArrayShadow_Type(new Type("samplerCubeArrayShadow"));
-
-static std::vector<std::shared_ptr<Type>> type(std::shared_ptr<Type> t) {
- return { t, t, t, t };
-}
-
-// FIXME figure out what we're supposed to do with the gsampler et al. types
-const std::shared_ptr<Type> kGSampler1D_Type(new Type("$gsampler1D", type(kSampler1D_Type)));
-const std::shared_ptr<Type> kGSampler2D_Type(new Type("$gsampler2D", type(kSampler2D_Type)));
-const std::shared_ptr<Type> kGSampler3D_Type(new Type("$gsampler3D", type(kSampler3D_Type)));
-const std::shared_ptr<Type> kGSamplerCube_Type(new Type("$gsamplerCube", type(kSamplerCube_Type)));
-const std::shared_ptr<Type> kGSampler2DRect_Type(new Type("$gsampler2DRect",
- type(kSampler2DRect_Type)));
-const std::shared_ptr<Type> kGSampler1DArray_Type(new Type("$gsampler1DArray",
- type(kSampler1DArray_Type)));
-const std::shared_ptr<Type> kGSampler2DArray_Type(new Type("$gsampler2DArray",
- type(kSampler2DArray_Type)));
-const std::shared_ptr<Type> kGSamplerCubeArray_Type(new Type("$gsamplerCubeArray",
- type(kSamplerCubeArray_Type)));
-const std::shared_ptr<Type> kGSamplerBuffer_Type(new Type("$gsamplerBuffer",
- type(kSamplerBuffer_Type)));
-const std::shared_ptr<Type> kGSampler2DMS_Type(new Type("$gsampler2DMS",
- type(kSampler2DMS_Type)));
-const std::shared_ptr<Type> kGSampler2DMSArray_Type(new Type("$gsampler2DMSArray",
- type(kSampler2DMSArray_Type)));
-const std::shared_ptr<Type> kGSampler2DArrayShadow_Type(new Type("$gsampler2DArrayShadow",
- type(kSampler2DArrayShadow_Type)));
-const std::shared_ptr<Type> kGSamplerCubeArrayShadow_Type(new Type("$gsamplerCubeArrayShadow",
- type(kSamplerCubeArrayShadow_Type)));
-
-const std::shared_ptr<Type> kGenType_Type(new Type("$genType", { kFloat_Type, kVec2_Type,
- kVec3_Type, kVec4_Type }));
-const std::shared_ptr<Type> kGenDType_Type(new Type("$genDType", { kDouble_Type, kDVec2_Type,
- kDVec3_Type, kDVec4_Type }));
-const std::shared_ptr<Type> kGenIType_Type(new Type("$genIType", { kInt_Type, kIVec2_Type,
- kIVec3_Type, kIVec4_Type }));
-const std::shared_ptr<Type> kGenUType_Type(new Type("$genUType", { kUInt_Type, kUVec2_Type,
- kUVec3_Type, kUVec4_Type }));
-const std::shared_ptr<Type> kGenBType_Type(new Type("$genBType", { kBool_Type, kBVec2_Type,
- kBVec3_Type, kBVec4_Type }));
-
-const std::shared_ptr<Type> kMat_Type(new Type("$mat"));
-
-const std::shared_ptr<Type> kVec_Type(new Type("$vec", { kVec2_Type, kVec2_Type, kVec3_Type,
- kVec4_Type }));
-
-const std::shared_ptr<Type> kGVec_Type(new Type("$gvec"));
-const std::shared_ptr<Type> kGVec2_Type(new Type("$gvec2"));
-const std::shared_ptr<Type> kGVec3_Type(new Type("$gvec3"));
-const std::shared_ptr<Type> kGVec4_Type(new Type("$gvec4", type(kVec4_Type)));
-const std::shared_ptr<Type> kDVec_Type(new Type("$dvec"));
-const std::shared_ptr<Type> kIVec_Type(new Type("$ivec"));
-const std::shared_ptr<Type> kUVec_Type(new Type("$uvec"));
-
-const std::shared_ptr<Type> kBVec_Type(new Type("$bvec", { kBVec2_Type, kBVec2_Type,
- kBVec3_Type, kBVec4_Type }));
-
-const std::shared_ptr<Type> kInvalid_Type(new Type("<INVALID>"));
-
} // namespace
diff --git a/src/sksl/ir/SkSLType.h b/src/sksl/ir/SkSLType.h
index e17bae68db..a929c8e4c7 100644
--- a/src/sksl/ir/SkSLType.h
+++ b/src/sksl/ir/SkSLType.h
@@ -18,24 +18,26 @@
namespace SkSL {
+class Context;
+
/**
* Represents a type, such as int or vec4.
*/
class Type : public Symbol {
public:
struct Field {
- Field(Modifiers modifiers, std::string name, std::shared_ptr<Type> type)
+ Field(Modifiers modifiers, std::string name, const Type& type)
: fModifiers(modifiers)
, fName(std::move(name))
, fType(std::move(type)) {}
- const std::string description() {
- return fType->description() + " " + fName + ";";
+ const std::string description() const {
+ return fType.description() + " " + fName + ";";
}
const Modifiers fModifiers;
const std::string fName;
- const std::shared_ptr<Type> fType;
+ const Type& fType;
};
enum Kind {
@@ -56,7 +58,7 @@ public:
, fTypeKind(kOther_Kind) {}
// Create a generic type which maps to the listed types.
- Type(std::string name, std::vector<std::shared_ptr<Type>> types)
+ Type(std::string name, std::vector<const Type*> types)
: INHERITED(Position(), kType_Kind, std::move(name))
, fTypeKind(kGeneric_Kind)
, fCoercibleTypes(std::move(types)) {
@@ -78,7 +80,7 @@ public:
, fRows(1) {}
// Create a scalar type which can be coerced to the listed types.
- Type(std::string name, bool isNumber, std::vector<std::shared_ptr<Type>> coercibleTypes)
+ Type(std::string name, bool isNumber, std::vector<const Type*> coercibleTypes)
: INHERITED(Position(), kType_Kind, std::move(name))
, fTypeKind(kScalar_Kind)
, fIsNumber(isNumber)
@@ -87,23 +89,23 @@ public:
, fRows(1) {}
// Create a vector type.
- Type(std::string name, std::shared_ptr<Type> componentType, int columns)
+ Type(std::string name, const Type& componentType, int columns)
: Type(name, kVector_Kind, componentType, columns) {}
// Create a vector or array type.
- Type(std::string name, Kind kind, std::shared_ptr<Type> componentType, int columns)
+ Type(std::string name, Kind kind, const Type& componentType, int columns)
: INHERITED(Position(), kType_Kind, std::move(name))
, fTypeKind(kind)
- , fComponentType(std::move(componentType))
+ , fComponentType(&componentType)
, fColumns(columns)
, fRows(1)
, fDimensions(SpvDim1D) {}
// Create a matrix type.
- Type(std::string name, std::shared_ptr<Type> componentType, int columns, int rows)
+ Type(std::string name, const Type& componentType, int columns, int rows)
: INHERITED(Position(), kType_Kind, std::move(name))
, fTypeKind(kMatrix_Kind)
- , fComponentType(std::move(componentType))
+ , fComponentType(&componentType)
, fColumns(columns)
, fRows(rows)
, fDimensions(SpvDim1D) {}
@@ -153,7 +155,7 @@ public:
* Returns true if an instance of this type can be freely coerced (implicitly converted) to
* another type.
*/
- bool canCoerceTo(std::shared_ptr<Type> other) const {
+ bool canCoerceTo(const Type& other) const {
int cost;
return determineCoercionCost(other, &cost);
}
@@ -164,15 +166,15 @@ public:
* costs. Returns true if a conversion is possible, false otherwise. The value of the out
* parameter is undefined if false is returned.
*/
- bool determineCoercionCost(std::shared_ptr<Type> other, int* outCost) const;
+ bool determineCoercionCost(const Type& other, int* outCost) const;
/**
* For matrices and vectors, returns the type of individual cells (e.g. mat2 has a component
* type of kFloat_Type). For all other types, causes an assertion failure.
*/
- std::shared_ptr<Type> componentType() const {
+ const Type& componentType() const {
ASSERT(fComponentType);
- return fComponentType;
+ return *fComponentType;
}
/**
@@ -195,7 +197,7 @@ public:
return fRows;
}
- std::vector<Field> fields() const {
+ const std::vector<Field>& fields() const {
ASSERT(fTypeKind == kStruct_Kind);
return fFields;
}
@@ -204,7 +206,7 @@ public:
* For generic types, returns the types that this generic type can substitute for. For other
* types, returns a list of other types that this type can be coerced into.
*/
- std::vector<std::shared_ptr<Type>> coercibleTypes() const {
+ const std::vector<const Type*>& coercibleTypes() const {
ASSERT(fCoercibleTypes.size() > 0);
return fCoercibleTypes;
}
@@ -257,7 +259,7 @@ public:
case kStruct_Kind: {
size_t result = 16;
for (size_t i = 0; i < fFields.size(); i++) {
- size_t alignment = fFields[i].fType->alignment();
+ size_t alignment = fFields[i].fType.alignment();
if (alignment > result) {
result = alignment;
}
@@ -300,13 +302,13 @@ public:
case kStruct_Kind: {
size_t total = 0;
for (size_t i = 0; i < fFields.size(); i++) {
- size_t alignment = fFields[i].fType->alignment();
+ size_t alignment = fFields[i].fType.alignment();
if (total % alignment != 0) {
total += alignment - total % alignment;
}
ASSERT(false);
ASSERT(total % alignment == 0);
- total += fFields[i].fType->size();
+ total += fFields[i].fType.size();
}
return total;
}
@@ -319,15 +321,15 @@ public:
* Returns the corresponding vector or matrix type with the specified number of columns and
* rows.
*/
- std::shared_ptr<Type> toCompound(int columns, int rows);
+ const Type& toCompound(const Context& context, int columns, int rows) const;
private:
typedef Symbol INHERITED;
const Kind fTypeKind;
const bool fIsNumber = false;
- const std::shared_ptr<Type> fComponentType = nullptr;
- const std::vector<std::shared_ptr<Type>> fCoercibleTypes = { };
+ const Type* fComponentType = nullptr;
+ const std::vector<const Type*> fCoercibleTypes = { };
const int fColumns = -1;
const int fRows = -1;
const std::vector<Field> fFields = { };
@@ -338,101 +340,6 @@ private:
const bool fIsSampled = false;
};
-extern const std::shared_ptr<Type> kVoid_Type;
-
-extern const std::shared_ptr<Type> kFloat_Type;
-extern const std::shared_ptr<Type> kVec2_Type;
-extern const std::shared_ptr<Type> kVec3_Type;
-extern const std::shared_ptr<Type> kVec4_Type;
-extern const std::shared_ptr<Type> kDouble_Type;
-extern const std::shared_ptr<Type> kDVec2_Type;
-extern const std::shared_ptr<Type> kDVec3_Type;
-extern const std::shared_ptr<Type> kDVec4_Type;
-extern const std::shared_ptr<Type> kInt_Type;
-extern const std::shared_ptr<Type> kIVec2_Type;
-extern const std::shared_ptr<Type> kIVec3_Type;
-extern const std::shared_ptr<Type> kIVec4_Type;
-extern const std::shared_ptr<Type> kUInt_Type;
-extern const std::shared_ptr<Type> kUVec2_Type;
-extern const std::shared_ptr<Type> kUVec3_Type;
-extern const std::shared_ptr<Type> kUVec4_Type;
-extern const std::shared_ptr<Type> kBool_Type;
-extern const std::shared_ptr<Type> kBVec2_Type;
-extern const std::shared_ptr<Type> kBVec3_Type;
-extern const std::shared_ptr<Type> kBVec4_Type;
-
-extern const std::shared_ptr<Type> kMat2x2_Type;
-extern const std::shared_ptr<Type> kMat2x3_Type;
-extern const std::shared_ptr<Type> kMat2x4_Type;
-extern const std::shared_ptr<Type> kMat3x2_Type;
-extern const std::shared_ptr<Type> kMat3x3_Type;
-extern const std::shared_ptr<Type> kMat3x4_Type;
-extern const std::shared_ptr<Type> kMat4x2_Type;
-extern const std::shared_ptr<Type> kMat4x3_Type;
-extern const std::shared_ptr<Type> kMat4x4_Type;
-
-extern const std::shared_ptr<Type> kDMat2x2_Type;
-extern const std::shared_ptr<Type> kDMat2x3_Type;
-extern const std::shared_ptr<Type> kDMat2x4_Type;
-extern const std::shared_ptr<Type> kDMat3x2_Type;
-extern const std::shared_ptr<Type> kDMat3x3_Type;
-extern const std::shared_ptr<Type> kDMat3x4_Type;
-extern const std::shared_ptr<Type> kDMat4x2_Type;
-extern const std::shared_ptr<Type> kDMat4x3_Type;
-extern const std::shared_ptr<Type> kDMat4x4_Type;
-
-extern const std::shared_ptr<Type> kSampler1D_Type;
-extern const std::shared_ptr<Type> kSampler2D_Type;
-extern const std::shared_ptr<Type> kSampler3D_Type;
-extern const std::shared_ptr<Type> kSamplerCube_Type;
-extern const std::shared_ptr<Type> kSampler2DRect_Type;
-extern const std::shared_ptr<Type> kSampler1DArray_Type;
-extern const std::shared_ptr<Type> kSampler2DArray_Type;
-extern const std::shared_ptr<Type> kSamplerCubeArray_Type;
-extern const std::shared_ptr<Type> kSamplerBuffer_Type;
-extern const std::shared_ptr<Type> kSampler2DMS_Type;
-extern const std::shared_ptr<Type> kSampler2DMSArray_Type;
-
-extern const std::shared_ptr<Type> kGSampler1D_Type;
-extern const std::shared_ptr<Type> kGSampler2D_Type;
-extern const std::shared_ptr<Type> kGSampler3D_Type;
-extern const std::shared_ptr<Type> kGSamplerCube_Type;
-extern const std::shared_ptr<Type> kGSampler2DRect_Type;
-extern const std::shared_ptr<Type> kGSampler1DArray_Type;
-extern const std::shared_ptr<Type> kGSampler2DArray_Type;
-extern const std::shared_ptr<Type> kGSamplerCubeArray_Type;
-extern const std::shared_ptr<Type> kGSamplerBuffer_Type;
-extern const std::shared_ptr<Type> kGSampler2DMS_Type;
-extern const std::shared_ptr<Type> kGSampler2DMSArray_Type;
-
-extern const std::shared_ptr<Type> kSampler1DShadow_Type;
-extern const std::shared_ptr<Type> kSampler2DShadow_Type;
-extern const std::shared_ptr<Type> kSamplerCubeShadow_Type;
-extern const std::shared_ptr<Type> kSampler2DRectShadow_Type;
-extern const std::shared_ptr<Type> kSampler1DArrayShadow_Type;
-extern const std::shared_ptr<Type> kSampler2DArrayShadow_Type;
-extern const std::shared_ptr<Type> kSamplerCubeArrayShadow_Type;
-extern const std::shared_ptr<Type> kGSampler2DArrayShadow_Type;
-extern const std::shared_ptr<Type> kGSamplerCubeArrayShadow_Type;
-
-extern const std::shared_ptr<Type> kGenType_Type;
-extern const std::shared_ptr<Type> kGenDType_Type;
-extern const std::shared_ptr<Type> kGenIType_Type;
-extern const std::shared_ptr<Type> kGenUType_Type;
-extern const std::shared_ptr<Type> kGenBType_Type;
-extern const std::shared_ptr<Type> kMat_Type;
-extern const std::shared_ptr<Type> kVec_Type;
-extern const std::shared_ptr<Type> kGVec_Type;
-extern const std::shared_ptr<Type> kGVec2_Type;
-extern const std::shared_ptr<Type> kGVec3_Type;
-extern const std::shared_ptr<Type> kGVec4_Type;
-extern const std::shared_ptr<Type> kDVec_Type;
-extern const std::shared_ptr<Type> kIVec_Type;
-extern const std::shared_ptr<Type> kUVec_Type;
-extern const std::shared_ptr<Type> kBVec_Type;
-
-extern const std::shared_ptr<Type> kInvalid_Type;
-
} // namespace
#endif
diff --git a/src/sksl/ir/SkSLTypeReference.h b/src/sksl/ir/SkSLTypeReference.h
index 5f4990f35d..76923aae4d 100644
--- a/src/sksl/ir/SkSLTypeReference.h
+++ b/src/sksl/ir/SkSLTypeReference.h
@@ -8,6 +8,7 @@
#ifndef SKSL_TYPEREFERENCE
#define SKSL_TYPEREFERENCE
+#include "SkSLContext.h"
#include "SkSLExpression.h"
namespace SkSL {
@@ -17,16 +18,16 @@ namespace SkSL {
* always eventually replaced by Constructors in valid programs.
*/
struct TypeReference : public Expression {
- TypeReference(Position position, std::shared_ptr<Type> type)
- : INHERITED(position, kTypeReference_Kind, kInvalid_Type)
- , fValue(std::move(type)) {}
+ TypeReference(const Context& context, Position position, const Type& type)
+ : INHERITED(position, kTypeReference_Kind, *context.fInvalid_Type)
+ , fValue(type) {}
std::string description() const override {
ASSERT(false);
return "<type>";
}
- const std::shared_ptr<Type> fValue;
+ const Type& fValue;
typedef Expression INHERITED;
};
diff --git a/src/sksl/ir/SkSLUnresolvedFunction.h b/src/sksl/ir/SkSLUnresolvedFunction.h
index a6cee0d072..3a368ad8d3 100644
--- a/src/sksl/ir/SkSLUnresolvedFunction.h
+++ b/src/sksl/ir/SkSLUnresolvedFunction.h
@@ -16,19 +16,21 @@ namespace SkSL {
* A symbol representing multiple functions with the same name.
*/
struct UnresolvedFunction : public Symbol {
- UnresolvedFunction(std::vector<std::shared_ptr<FunctionDeclaration>> funcs)
+ UnresolvedFunction(std::vector<const FunctionDeclaration*> funcs)
: INHERITED(Position(), kUnresolvedFunction_Kind, funcs[0]->fName)
, fFunctions(std::move(funcs)) {
+#ifdef DEBUG
for (auto func : funcs) {
ASSERT(func->fName == fName);
}
+#endif
}
virtual std::string description() const override {
return fName;
}
- const std::vector<std::shared_ptr<FunctionDeclaration>> fFunctions;
+ const std::vector<const FunctionDeclaration*> fFunctions;
typedef Symbol INHERITED;
};
diff --git a/src/sksl/ir/SkSLVarDeclaration.h b/src/sksl/ir/SkSLVarDeclaration.h
index 400f430e4c..b234231b86 100644
--- a/src/sksl/ir/SkSLVarDeclaration.h
+++ b/src/sksl/ir/SkSLVarDeclaration.h
@@ -20,7 +20,7 @@ namespace SkSL {
* names ['x', 'y', 'z'], sizes of [[], [], [4, 2]], and values of [null, 1, null].
*/
struct VarDeclaration : public ProgramElement {
- VarDeclaration(Position position, std::vector<std::shared_ptr<Variable>> vars,
+ VarDeclaration(Position position, std::vector<const Variable*> vars,
std::vector<std::vector<std::unique_ptr<Expression>>> sizes,
std::vector<std::unique_ptr<Expression>> values)
: INHERITED(position, kVar_Kind)
@@ -30,9 +30,9 @@ struct VarDeclaration : public ProgramElement {
std::string description() const override {
std::string result = fVars[0]->fModifiers.description();
- std::shared_ptr<Type> baseType = fVars[0]->fType;
+ const Type* baseType = &fVars[0]->fType;
while (baseType->kind() == Type::kArray_Kind) {
- baseType = baseType->componentType();
+ baseType = &baseType->componentType();
}
result += baseType->description();
std::string separator = " ";
@@ -55,7 +55,7 @@ struct VarDeclaration : public ProgramElement {
return result;
}
- const std::vector<std::shared_ptr<Variable>> fVars;
+ const std::vector<const Variable*> fVars;
const std::vector<std::vector<std::unique_ptr<Expression>>> fSizes;
const std::vector<std::unique_ptr<Expression>> fValues;
diff --git a/src/sksl/ir/SkSLVariable.h b/src/sksl/ir/SkSLVariable.h
index d4ea2c4a43..39af3093b6 100644
--- a/src/sksl/ir/SkSLVariable.h
+++ b/src/sksl/ir/SkSLVariable.h
@@ -27,7 +27,7 @@ struct Variable : public Symbol {
kParameter_Storage
};
- Variable(Position position, Modifiers modifiers, std::string name, std::shared_ptr<Type> type,
+ Variable(Position position, Modifiers modifiers, std::string name, const Type& type,
Storage storage)
: INHERITED(position, kVariable_Kind, std::move(name))
, fModifiers(modifiers)
@@ -37,12 +37,11 @@ struct Variable : public Symbol {
, fIsWrittenTo(false) {}
virtual std::string description() const override {
- return fModifiers.description() + fType->fName + " " + fName;
+ return fModifiers.description() + fType.fName + " " + fName;
}
const Modifiers fModifiers;
- const std::string fValue;
- const std::shared_ptr<Type> fType;
+ const Type& fType;
const Storage fStorage;
mutable bool fIsReadFrom;
@@ -53,14 +52,4 @@ struct Variable : public Symbol {
} // namespace SkSL
-namespace std {
- template <>
- struct hash<SkSL::Variable> {
- public :
- size_t operator()(const SkSL::Variable &var) const{
- return hash<std::string>()(var.fName) ^ hash<std::string>()(var.fType->description());
- }
- };
-} // namespace std
-
#endif
diff --git a/src/sksl/ir/SkSLVariableReference.h b/src/sksl/ir/SkSLVariableReference.h
index 8499511a1b..b443da1f22 100644
--- a/src/sksl/ir/SkSLVariableReference.h
+++ b/src/sksl/ir/SkSLVariableReference.h
@@ -20,15 +20,15 @@ namespace SkSL {
* there is only one Variable 'x', but two VariableReferences to it.
*/
struct VariableReference : public Expression {
- VariableReference(Position position, std::shared_ptr<Variable> variable)
- : INHERITED(position, kVariableReference_Kind, variable->fType)
- , fVariable(std::move(variable)) {}
+ VariableReference(Position position, const Variable& variable)
+ : INHERITED(position, kVariableReference_Kind, variable.fType)
+ , fVariable(variable) {}
std::string description() const override {
- return fVariable->fName;
+ return fVariable.fName;
}
- const std::shared_ptr<Variable> fVariable;
+ const Variable& fVariable;
typedef Expression INHERITED;
};