diff options
Diffstat (limited to 'src')
34 files changed, 869 insertions, 895 deletions
diff --git a/src/gpu/vk/GrVkPipelineStateBuilder.cpp b/src/gpu/vk/GrVkPipelineStateBuilder.cpp index d9d1b6cfb8..323ea66946 100644 --- a/src/gpu/vk/GrVkPipelineStateBuilder.cpp +++ b/src/gpu/vk/GrVkPipelineStateBuilder.cpp @@ -93,6 +93,8 @@ shaderc_shader_kind vk_shader_stage_to_shaderc_kind(VkShaderStageFlagBits stage) } #endif +#include <fstream> +#include <sstream> bool GrVkPipelineStateBuilder::CreateVkShaderModule(const GrVkGpu* gpu, VkShaderStageFlagBits stage, const GrGLSLShaderBuilder& builder, diff --git a/src/sksl/SkSLCompiler.cpp b/src/sksl/SkSLCompiler.cpp index ff125f8fdc..2b4adc1026 100644 --- a/src/sksl/SkSLCompiler.cpp +++ b/src/sksl/SkSLCompiler.cpp @@ -43,7 +43,7 @@ Compiler::Compiler() auto symbols = std::shared_ptr<SymbolTable>(new SymbolTable(types, *this)); fIRGenerator = new IRGenerator(symbols, *this); fTypes = types; - #define ADD_TYPE(t) types->addWithoutOwnership(k ## t ## _Type->fName, k ## t ## _Type) + #define ADD_TYPE(t) types->add(k ## t ## _Type->fName, k ## t ## _Type) ADD_TYPE(Void); ADD_TYPE(Float); ADD_TYPE(Vec2); @@ -185,21 +185,19 @@ std::unique_ptr<Program> Compiler::convertProgram(Program::Kind kind, std::strin fErrorText = ""; fErrorCount = 0; fIRGenerator->pushSymbolTable(); - std::vector<std::unique_ptr<ProgramElement>> elements; + std::vector<std::unique_ptr<ProgramElement>> result; switch (kind) { case Program::kVertex_Kind: - this->internalConvertProgram(SKSL_VERT_INCLUDE, &elements); + this->internalConvertProgram(SKSL_VERT_INCLUDE, &result); break; case Program::kFragment_Kind: - this->internalConvertProgram(SKSL_FRAG_INCLUDE, &elements); + this->internalConvertProgram(SKSL_FRAG_INCLUDE, &result); break; } - this->internalConvertProgram(text, &elements); - auto result = std::unique_ptr<Program>(new Program(kind, std::move(elements), - fIRGenerator->fSymbolTable));; + this->internalConvertProgram(text, &result); fIRGenerator->popSymbolTable(); this->writeErrorCount(); - return result; + return std::unique_ptr<Program>(new Program(kind, std::move(result)));; } void Compiler::error(Position position, std::string msg) { diff --git a/src/sksl/SkSLIRGenerator.cpp b/src/sksl/SkSLIRGenerator.cpp index 6efaad0d58..2cc7eacb4d 100644 --- a/src/sksl/SkSLIRGenerator.cpp +++ b/src/sksl/SkSLIRGenerator.cpp @@ -68,9 +68,9 @@ public: IRGenerator::IRGenerator(std::shared_ptr<SymbolTable> symbolTable, ErrorReporter& errorReporter) -: fCurrentFunction(nullptr) -, fSymbolTable(std::move(symbolTable)) -, fErrors(errorReporter) {} +: fSymbolTable(std::move(symbolTable)) +, fErrors(errorReporter) { +} void IRGenerator::pushSymbolTable() { fSymbolTable.reset(new SymbolTable(std::move(fSymbolTable), fErrors)); @@ -123,7 +123,7 @@ std::unique_ptr<Block> IRGenerator::convertBlock(const ASTBlock& block) { } statements.push_back(std::move(statement)); } - return std::unique_ptr<Block>(new Block(block.fPosition, std::move(statements), fSymbolTable)); + return std::unique_ptr<Block>(new Block(block.fPosition, std::move(statements))); } std::unique_ptr<Statement> IRGenerator::convertVarDeclarationStatement( @@ -141,22 +141,22 @@ Modifiers IRGenerator::convertModifiers(const ASTModifiers& modifiers) { std::unique_ptr<VarDeclaration> IRGenerator::convertVarDeclaration(const ASTVarDeclaration& decl, Variable::Storage storage) { - std::vector<const Variable*> variables; + std::vector<std::shared_ptr<Variable>> variables; std::vector<std::vector<std::unique_ptr<Expression>>> sizes; std::vector<std::unique_ptr<Expression>> values; - const Type* baseType = this->convertType(*decl.fType); + std::shared_ptr<Type> baseType = this->convertType(*decl.fType); if (!baseType) { return nullptr; } for (size_t i = 0; i < decl.fNames.size(); i++) { Modifiers modifiers = this->convertModifiers(decl.fModifiers); - const Type* type = baseType; + std::shared_ptr<Type> type = baseType; ASSERT(type->kind() != Type::kArray_Kind); std::vector<std::unique_ptr<Expression>> currentVarSizes; for (size_t j = 0; j < decl.fSizes[i].size(); j++) { if (decl.fSizes[i][j]) { ASTExpression& rawSize = *decl.fSizes[i][j]; - auto size = this->coerce(this->convertExpression(rawSize), *kInt_Type); + auto size = this->coerce(this->convertExpression(rawSize), kInt_Type); if (!size) { return nullptr; } @@ -172,28 +172,27 @@ std::unique_ptr<VarDeclaration> IRGenerator::convertVarDeclaration(const ASTVarD count = -1; name += "[]"; } - type = new Type(name, Type::kArray_Kind, *type, (int) count); - fSymbolTable->takeOwnership((Type*) type); + type = std::shared_ptr<Type>(new Type(name, Type::kArray_Kind, type, (int) count)); currentVarSizes.push_back(std::move(size)); } else { - type = new Type(type->fName + "[]", Type::kArray_Kind, *type, -1); - fSymbolTable->takeOwnership((Type*) type); + type = std::shared_ptr<Type>(new Type(type->fName + "[]", Type::kArray_Kind, type, + -1)); currentVarSizes.push_back(nullptr); } } sizes.push_back(std::move(currentVarSizes)); - auto var = std::unique_ptr<Variable>(new Variable(decl.fPosition, modifiers, decl.fNames[i], - *type, storage)); + auto var = std::make_shared<Variable>(decl.fPosition, modifiers, decl.fNames[i], type, + storage); + variables.push_back(var); std::unique_ptr<Expression> value; if (decl.fValues[i]) { value = this->convertExpression(*decl.fValues[i]); if (!value) { return nullptr; } - value = this->coerce(std::move(value), *type); + value = this->coerce(std::move(value), type); } - variables.push_back(var.get()); - fSymbolTable->add(decl.fNames[i], std::move(var)); + fSymbolTable->add(var->fName, var); values.push_back(std::move(value)); } return std::unique_ptr<VarDeclaration>(new VarDeclaration(decl.fPosition, std::move(variables), @@ -201,7 +200,7 @@ std::unique_ptr<VarDeclaration> IRGenerator::convertVarDeclaration(const ASTVarD } std::unique_ptr<Statement> IRGenerator::convertIf(const ASTIfStatement& s) { - std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*s.fTest), *kBool_Type); + std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*s.fTest), kBool_Type); if (!test) { return nullptr; } @@ -226,7 +225,7 @@ std::unique_ptr<Statement> IRGenerator::convertFor(const ASTForStatement& f) { if (!initializer) { return nullptr; } - std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*f.fTest), *kBool_Type); + std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*f.fTest), kBool_Type); if (!test) { return nullptr; } @@ -241,11 +240,11 @@ std::unique_ptr<Statement> IRGenerator::convertFor(const ASTForStatement& f) { } return std::unique_ptr<Statement>(new ForStatement(f.fPosition, std::move(initializer), std::move(test), std::move(next), - std::move(statement), fSymbolTable)); + std::move(statement))); } std::unique_ptr<Statement> IRGenerator::convertWhile(const ASTWhileStatement& w) { - std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*w.fTest), *kBool_Type); + std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*w.fTest), kBool_Type); if (!test) { return nullptr; } @@ -258,7 +257,7 @@ std::unique_ptr<Statement> IRGenerator::convertWhile(const ASTWhileStatement& w) } std::unique_ptr<Statement> IRGenerator::convertDo(const ASTDoStatement& d) { - std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*d.fTest), *kBool_Type); + std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*d.fTest), kBool_Type); if (!test) { return nullptr; } @@ -287,7 +286,7 @@ std::unique_ptr<Statement> IRGenerator::convertReturn(const ASTReturnStatement& if (!result) { return nullptr; } - if (fCurrentFunction->fReturnType == *kVoid_Type) { + if (fCurrentFunction->fReturnType == kVoid_Type) { fErrors.error(result->fPosition, "may not return a value from a void function"); } else { result = this->coerce(std::move(result), fCurrentFunction->fReturnType); @@ -297,9 +296,9 @@ std::unique_ptr<Statement> IRGenerator::convertReturn(const ASTReturnStatement& } return std::unique_ptr<Statement>(new ReturnStatement(std::move(result))); } else { - if (fCurrentFunction->fReturnType != *kVoid_Type) { + if (fCurrentFunction->fReturnType != kVoid_Type) { fErrors.error(r.fPosition, "expected function to return '" + - fCurrentFunction->fReturnType.description() + "'"); + fCurrentFunction->fReturnType->description() + "'"); } return std::unique_ptr<Statement>(new ReturnStatement(r.fPosition)); } @@ -317,74 +316,80 @@ std::unique_ptr<Statement> IRGenerator::convertDiscard(const ASTDiscardStatement return std::unique_ptr<Statement>(new DiscardStatement(d.fPosition)); } -static const Type& expand_generics(const Type& type, int i) { - if (type.kind() == Type::kGeneric_Kind) { - return *type.coercibleTypes()[i]; +static std::shared_ptr<Type> expand_generics(std::shared_ptr<Type> type, int i) { + if (type->kind() == Type::kGeneric_Kind) { + return type->coercibleTypes()[i]; } return type; } -static void expand_generics(const FunctionDeclaration& decl, - std::shared_ptr<SymbolTable> symbolTable) { +static void expand_generics(FunctionDeclaration& decl, + SymbolTable& symbolTable) { for (int i = 0; i < 4; i++) { - const Type& returnType = expand_generics(decl.fReturnType, i); - std::vector<const Variable*> parameters; + std::shared_ptr<Type> returnType = expand_generics(decl.fReturnType, i); + std::vector<std::shared_ptr<Variable>> arguments; for (const auto& p : decl.fParameters) { - Variable* var = new Variable(p->fPosition, Modifiers(p->fModifiers), p->fName, - expand_generics(p->fType, i), - Variable::kParameter_Storage); - symbolTable->takeOwnership(var); - parameters.push_back(var); + arguments.push_back(std::shared_ptr<Variable>(new Variable( + p->fPosition, + Modifiers(p->fModifiers), + p->fName, + expand_generics(p->fType, i), + Variable::kParameter_Storage))); } - symbolTable->add(decl.fName, std::unique_ptr<FunctionDeclaration>(new FunctionDeclaration( - decl.fPosition, - decl.fName, - std::move(parameters), - std::move(returnType)))); + std::shared_ptr<FunctionDeclaration> expanded(new FunctionDeclaration( + decl.fPosition, + decl.fName, + std::move(arguments), + std::move(returnType))); + symbolTable.add(expanded->fName, expanded); } } std::unique_ptr<FunctionDefinition> IRGenerator::convertFunction(const ASTFunction& f) { + std::shared_ptr<SymbolTable> old = fSymbolTable; + AutoSymbolTable table(this); bool isGeneric; - const Type* returnType = this->convertType(*f.fReturnType); + std::shared_ptr<Type> returnType = this->convertType(*f.fReturnType); if (!returnType) { return nullptr; } isGeneric = returnType->kind() == Type::kGeneric_Kind; - std::vector<const Variable*> parameters; + std::vector<std::shared_ptr<Variable>> parameters; for (const auto& param : f.fParameters) { - const Type* type = this->convertType(*param->fType); + std::shared_ptr<Type> type = this->convertType(*param->fType); if (!type) { return nullptr; } for (int j = (int) param->fSizes.size() - 1; j >= 0; j--) { int size = param->fSizes[j]; std::string name = type->name() + "[" + to_string(size) + "]"; - Type* newType = new Type(std::move(name), Type::kArray_Kind, *type, size); - fSymbolTable->takeOwnership(newType); - type = newType; + type = std::shared_ptr<Type>(new Type(std::move(name), Type::kArray_Kind, + std::move(type), size)); } std::string name = param->fName; Modifiers modifiers = this->convertModifiers(param->fModifiers); Position pos = param->fPosition; - Variable* var = new Variable(pos, modifiers, std::move(name), *type, - Variable::kParameter_Storage); - fSymbolTable->takeOwnership(var); - parameters.push_back(var); + std::shared_ptr<Variable> var = std::shared_ptr<Variable>(new Variable( + pos, + modifiers, + std::move(name), + type, + Variable::kParameter_Storage)); + parameters.push_back(std::move(var)); isGeneric |= type->kind() == Type::kGeneric_Kind; } // find existing declaration - const FunctionDeclaration* decl = nullptr; - auto entry = (*fSymbolTable)[f.fName]; + std::shared_ptr<FunctionDeclaration> decl; + auto entry = (*old)[f.fName]; if (entry) { - std::vector<const FunctionDeclaration*> functions; + std::vector<std::shared_ptr<FunctionDeclaration>> functions; switch (entry->fKind) { case Symbol::kUnresolvedFunction_Kind: - functions = ((UnresolvedFunction*) entry)->fFunctions; + functions = std::static_pointer_cast<UnresolvedFunction>(entry)->fFunctions; break; case Symbol::kFunctionDeclaration_Kind: - functions.push_back((FunctionDeclaration*) entry); + functions.push_back(std::static_pointer_cast<FunctionDeclaration>(entry)); break; default: fErrors.error(f.fPosition, "symbol '" + f.fName + "' was already defined"); @@ -401,8 +406,11 @@ std::unique_ptr<FunctionDefinition> IRGenerator::convertFunction(const ASTFuncti } } if (match) { - if (*returnType != other->fReturnType) { - FunctionDeclaration newDecl(f.fPosition, f.fName, parameters, *returnType); + if (returnType != other->fReturnType) { + FunctionDeclaration newDecl = FunctionDeclaration(f.fPosition, + f.fName, + parameters, + returnType); fErrors.error(f.fPosition, "functions '" + newDecl.description() + "' and '" + other->description() + "' differ only in return type"); @@ -416,6 +424,7 @@ std::unique_ptr<FunctionDefinition> IRGenerator::convertFunction(const ASTFuncti "declaration and definition"); return nullptr; } + fSymbolTable->add(parameters[i]->fName, decl->fParameters[i]); } if (other->fDefined) { fErrors.error(f.fPosition, "duplicate definition of " + @@ -428,36 +437,28 @@ std::unique_ptr<FunctionDefinition> IRGenerator::convertFunction(const ASTFuncti } if (!decl) { // couldn't find an existing declaration - if (isGeneric) { - ASSERT(!f.fBody); - expand_generics(FunctionDeclaration(f.fPosition, f.fName, parameters, *returnType), - fSymbolTable); - } else { - auto newDecl = std::unique_ptr<FunctionDeclaration>(new FunctionDeclaration( - f.fPosition, - f.fName, - parameters, - *returnType)); - decl = newDecl.get(); - fSymbolTable->add(decl->fName, std::move(newDecl)); + decl.reset(new FunctionDeclaration(f.fPosition, f.fName, parameters, returnType)); + for (auto var : parameters) { + fSymbolTable->add(var->fName, var); } } - if (f.fBody) { - ASSERT(!fCurrentFunction); - fCurrentFunction = decl; - decl->fDefined = true; - std::shared_ptr<SymbolTable> old = fSymbolTable; - AutoSymbolTable table(this); - for (size_t i = 0; i < parameters.size(); i++) { - fSymbolTable->addWithoutOwnership(parameters[i]->fName, decl->fParameters[i]); - } - std::unique_ptr<Block> body = this->convertBlock(*f.fBody); - fCurrentFunction = nullptr; - if (!body) { - return nullptr; + if (isGeneric) { + ASSERT(!f.fBody); + expand_generics(*decl, *old); + } else { + old->add(decl->fName, decl); + if (f.fBody) { + ASSERT(!fCurrentFunction); + fCurrentFunction = decl; + decl->fDefined = true; + std::unique_ptr<Block> body = this->convertBlock(*f.fBody); + fCurrentFunction = nullptr; + if (!body) { + return nullptr; + } + return std::unique_ptr<FunctionDefinition>(new FunctionDefinition(f.fPosition, decl, + std::move(body))); } - return std::unique_ptr<FunctionDefinition>(new FunctionDefinition(f.fPosition, *decl, - std::move(body))); } return nullptr; } @@ -487,26 +488,28 @@ std::unique_ptr<InterfaceBlock> IRGenerator::convertInterfaceBlock(const ASTInte } } } - Type* type = new Type(intf.fInterfaceName, fields); - fSymbolTable->takeOwnership(type); + std::shared_ptr<Type> type = std::shared_ptr<Type>(new Type(intf.fInterfaceName, fields)); std::string name = intf.fValueName.length() > 0 ? intf.fValueName : intf.fInterfaceName; - Variable* var = new Variable(intf.fPosition, mods, name, *type, Variable::kGlobal_Storage); - fSymbolTable->takeOwnership(var); + std::shared_ptr<Variable> var = std::shared_ptr<Variable>(new Variable(intf.fPosition, mods, + name, type, + Variable::kGlobal_Storage)); if (intf.fValueName.length()) { - old->addWithoutOwnership(intf.fValueName, var); + old->add(intf.fValueName, var); + } else { for (size_t i = 0; i < fields.size(); i++) { - old->add(fields[i].fName, std::unique_ptr<Field>(new Field(intf.fPosition, *var, - (int) i))); + std::shared_ptr<Field> field = std::shared_ptr<Field>(new Field(intf.fPosition, var, + (int) i)); + old->add(fields[i].fName, field); } } - return std::unique_ptr<InterfaceBlock>(new InterfaceBlock(intf.fPosition, *var, fSymbolTable)); + return std::unique_ptr<InterfaceBlock>(new InterfaceBlock(intf.fPosition, var)); } -const Type* IRGenerator::convertType(const ASTType& type) { - const Symbol* result = (*fSymbolTable)[type.fName]; +std::shared_ptr<Type> IRGenerator::convertType(const ASTType& type) { + std::shared_ptr<Symbol> result = (*fSymbolTable)[type.fName]; if (result && result->fKind == Symbol::kType_Kind) { - return (const Type*) result; + return std::static_pointer_cast<Type>(result); } fErrors.error(type.fPosition, "unknown type '" + type.fName + "'"); return nullptr; @@ -539,40 +542,40 @@ std::unique_ptr<Expression> IRGenerator::convertExpression(const ASTExpression& } std::unique_ptr<Expression> IRGenerator::convertIdentifier(const ASTIdentifier& identifier) { - const Symbol* result = (*fSymbolTable)[identifier.fText]; + std::shared_ptr<Symbol> result = (*fSymbolTable)[identifier.fText]; if (!result) { fErrors.error(identifier.fPosition, "unknown identifier '" + identifier.fText + "'"); return nullptr; } switch (result->fKind) { case Symbol::kFunctionDeclaration_Kind: { - std::vector<const FunctionDeclaration*> f = { - (const FunctionDeclaration*) result + std::vector<std::shared_ptr<FunctionDeclaration>> f = { + std::static_pointer_cast<FunctionDeclaration>(result) }; return std::unique_ptr<FunctionReference>(new FunctionReference(identifier.fPosition, - f)); + std::move(f))); } case Symbol::kUnresolvedFunction_Kind: { - const UnresolvedFunction* f = (const UnresolvedFunction*) result; + auto f = std::static_pointer_cast<UnresolvedFunction>(result); return std::unique_ptr<FunctionReference>(new FunctionReference(identifier.fPosition, f->fFunctions)); } case Symbol::kVariable_Kind: { - const Variable* var = (const Variable*) result; - this->markReadFrom(*var); + std::shared_ptr<Variable> var = std::static_pointer_cast<Variable>(result); + this->markReadFrom(var); return std::unique_ptr<VariableReference>(new VariableReference(identifier.fPosition, - *var)); + std::move(var))); } case Symbol::kField_Kind: { - const Field* field = (const Field*) result; + std::shared_ptr<Field> field = std::static_pointer_cast<Field>(result); VariableReference* base = new VariableReference(identifier.fPosition, field->fOwner); return std::unique_ptr<Expression>(new FieldAccess(std::unique_ptr<Expression>(base), field->fFieldIndex)); } case Symbol::kType_Kind: { - const Type* t = (const Type*) result; + auto t = std::static_pointer_cast<Type>(result); return std::unique_ptr<TypeReference>(new TypeReference(identifier.fPosition, - *t)); + std::move(t))); } default: ABORT("unsupported symbol type %d\n", result->fKind); @@ -581,42 +584,43 @@ std::unique_ptr<Expression> IRGenerator::convertIdentifier(const ASTIdentifier& } std::unique_ptr<Expression> IRGenerator::coerce(std::unique_ptr<Expression> expr, - const Type& type) { + std::shared_ptr<Type> type) { if (!expr) { return nullptr; } - if (expr->fType == type) { + if (*expr->fType == *type) { return expr; } this->checkValid(*expr); - if (expr->fType == *kInvalid_Type) { + if (*expr->fType == *kInvalid_Type) { return nullptr; } - if (!expr->fType.canCoerceTo(type)) { - fErrors.error(expr->fPosition, "expected '" + type.description() + "', but found '" + - expr->fType.description() + "'"); + if (!expr->fType->canCoerceTo(type)) { + fErrors.error(expr->fPosition, "expected '" + type->description() + "', but found '" + + expr->fType->description() + "'"); return nullptr; } - if (type.kind() == Type::kScalar_Kind) { + if (type->kind() == Type::kScalar_Kind) { std::vector<std::unique_ptr<Expression>> args; args.push_back(std::move(expr)); - ASTIdentifier id(Position(), type.description()); + ASTIdentifier id(Position(), type->description()); std::unique_ptr<Expression> ctor = this->convertIdentifier(id); ASSERT(ctor); return this->call(Position(), std::move(ctor), std::move(args)); } - ABORT("cannot coerce %s to %s", expr->fType.description().c_str(), - type.description().c_str()); + ABORT("cannot coerce %s to %s", expr->fType->description().c_str(), + type->description().c_str()); } /** * Determines the operand and result types of a binary expression. Returns true if the expression is * legal, false otherwise. If false, the values of the out parameters are undefined. */ -static bool determine_binary_type(Token::Kind op, const Type& left, const Type& right, - const Type** outLeftType, - const Type** outRightType, - const Type** outResultType, +static bool determine_binary_type(Token::Kind op, std::shared_ptr<Type> left, + std::shared_ptr<Type> right, + std::shared_ptr<Type>* outLeftType, + std::shared_ptr<Type>* outRightType, + std::shared_ptr<Type>* outResultType, bool tryFlipped) { bool isLogical; switch (op) { @@ -637,21 +641,21 @@ static bool determine_binary_type(Token::Kind op, const Type& left, const Type& *outLeftType = kBool_Type; *outRightType = kBool_Type; *outResultType = kBool_Type; - return left.canCoerceTo(*kBool_Type) && right.canCoerceTo(*kBool_Type); + return left->canCoerceTo(kBool_Type) && right->canCoerceTo(kBool_Type); case Token::STAR: // fall through case Token::STAREQ: // FIXME need to handle non-square matrices - if (left.kind() == Type::kMatrix_Kind && right.kind() == Type::kVector_Kind) { - *outLeftType = &left; - *outRightType = &right; - *outResultType = &right; - return left.rows() == right.columns(); + if (left->kind() == Type::kMatrix_Kind && right->kind() == Type::kVector_Kind) { + *outLeftType = left; + *outRightType = right; + *outResultType = right; + return left->rows() == right->columns(); } - if (left.kind() == Type::kVector_Kind && right.kind() == Type::kMatrix_Kind) { - *outLeftType = &left; - *outRightType = &right; - *outResultType = &left; - return left.columns() == right.columns(); + if (left->kind() == Type::kVector_Kind && right->kind() == Type::kMatrix_Kind) { + *outLeftType = left; + *outRightType = right; + *outResultType = left; + return left->columns() == right->columns(); } // fall through default: @@ -660,33 +664,33 @@ static bool determine_binary_type(Token::Kind op, const Type& left, const Type& // FIXME: need to disallow illegal operations like vec3 > vec3. Also do not currently have // full support for numbers other than float. if (left == right) { - *outLeftType = &left; - *outRightType = &left; + *outLeftType = left; + *outRightType = left; if (isLogical) { *outResultType = kBool_Type; } else { - *outResultType = &left; + *outResultType = left; } return true; } // FIXME: incorrect for shift operations - if (left.canCoerceTo(right)) { - *outLeftType = &right; - *outRightType = &right; + if (left->canCoerceTo(right)) { + *outLeftType = right; + *outRightType = right; if (isLogical) { *outResultType = kBool_Type; } else { - *outResultType = &right; + *outResultType = right; } return true; } - if ((left.kind() == Type::kVector_Kind || left.kind() == Type::kMatrix_Kind) && - (right.kind() == Type::kScalar_Kind)) { - if (determine_binary_type(op, left.componentType(), right, outLeftType, outRightType, + if ((left->kind() == Type::kVector_Kind || left->kind() == Type::kMatrix_Kind) && + (right->kind() == Type::kScalar_Kind)) { + if (determine_binary_type(op, left->componentType(), right, outLeftType, outRightType, outResultType, false)) { - *outLeftType = &(*outLeftType)->toCompound(left.columns(), left.rows()); + *outLeftType = (*outLeftType)->toCompound(left->columns(), left->rows()); if (!isLogical) { - *outResultType = &(*outResultType)->toCompound(left.columns(), left.rows()); + *outResultType = (*outResultType)->toCompound(left->columns(), left->rows()); } return true; } @@ -709,15 +713,15 @@ std::unique_ptr<Expression> IRGenerator::convertBinaryExpression( if (!right) { return nullptr; } - const Type* leftType; - const Type* rightType; - const Type* resultType; + std::shared_ptr<Type> leftType; + std::shared_ptr<Type> rightType; + std::shared_ptr<Type> resultType; if (!determine_binary_type(expression.fOperator, left->fType, right->fType, &leftType, &rightType, &resultType, true)) { fErrors.error(expression.fPosition, "type mismatch: '" + Token::OperatorName(expression.fOperator) + - "' cannot operate on '" + left->fType.fName + - "', '" + right->fType.fName + "'"); + "' cannot operate on '" + left->fType->fName + + "', '" + right->fType->fName + "'"); return nullptr; } switch (expression.fOperator) { @@ -740,18 +744,17 @@ std::unique_ptr<Expression> IRGenerator::convertBinaryExpression( break; } return std::unique_ptr<Expression>(new BinaryExpression(expression.fPosition, - this->coerce(std::move(left), - *leftType), + this->coerce(std::move(left), leftType), expression.fOperator, this->coerce(std::move(right), - *rightType), - *resultType)); + rightType), + resultType)); } std::unique_ptr<Expression> IRGenerator::convertTernaryExpression( const ASTTernaryExpression& expression) { std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*expression.fTest), - *kBool_Type); + kBool_Type); if (!test) { return nullptr; } @@ -763,33 +766,34 @@ std::unique_ptr<Expression> IRGenerator::convertTernaryExpression( if (!ifFalse) { return nullptr; } - const Type* trueType; - const Type* falseType; - const Type* resultType; + std::shared_ptr<Type> trueType; + std::shared_ptr<Type> falseType; + std::shared_ptr<Type> resultType; if (!determine_binary_type(Token::EQEQ, ifTrue->fType, ifFalse->fType, &trueType, &falseType, &resultType, true)) { fErrors.error(expression.fPosition, "ternary operator result mismatch: '" + - ifTrue->fType.fName + "', '" + - ifFalse->fType.fName + "'"); + ifTrue->fType->fName + "', '" + + ifFalse->fType->fName + "'"); return nullptr; } ASSERT(trueType == falseType); - ifTrue = this->coerce(std::move(ifTrue), *trueType); - ifFalse = this->coerce(std::move(ifFalse), *falseType); + ifTrue = this->coerce(std::move(ifTrue), trueType); + ifFalse = this->coerce(std::move(ifFalse), falseType); return std::unique_ptr<Expression>(new TernaryExpression(expression.fPosition, std::move(test), std::move(ifTrue), std::move(ifFalse))); } -std::unique_ptr<Expression> IRGenerator::call(Position position, - const FunctionDeclaration& function, - std::vector<std::unique_ptr<Expression>> arguments) { - if (function.fParameters.size() != arguments.size()) { - std::string msg = "call to '" + function.fName + "' expected " + - to_string(function.fParameters.size()) + +std::unique_ptr<Expression> IRGenerator::call( + Position position, + std::shared_ptr<FunctionDeclaration> function, + std::vector<std::unique_ptr<Expression>> arguments) { + if (function->fParameters.size() != arguments.size()) { + std::string msg = "call to '" + function->fName + "' expected " + + to_string(function->fParameters.size()) + " argument"; - if (function.fParameters.size() != 1) { + if (function->fParameters.size() != 1) { msg += "s"; } msg += ", but found " + to_string(arguments.size()); @@ -797,12 +801,12 @@ std::unique_ptr<Expression> IRGenerator::call(Position position, return nullptr; } for (size_t i = 0; i < arguments.size(); i++) { - arguments[i] = this->coerce(std::move(arguments[i]), function.fParameters[i]->fType); - if (arguments[i] && (function.fParameters[i]->fModifiers.fFlags & Modifiers::kOut_Flag)) { + arguments[i] = this->coerce(std::move(arguments[i]), function->fParameters[i]->fType); + if (arguments[i] && (function->fParameters[i]->fModifiers.fFlags & Modifiers::kOut_Flag)) { this->markWrittenTo(*arguments[i]); } } - return std::unique_ptr<FunctionCall>(new FunctionCall(position, function, + return std::unique_ptr<FunctionCall>(new FunctionCall(position, std::move(function), std::move(arguments))); } @@ -811,16 +815,16 @@ std::unique_ptr<Expression> IRGenerator::call(Position position, * if the cost could be computed, false if the call is not valid. Cost has no particular meaning * other than "lower costs are preferred". */ -bool IRGenerator::determineCallCost(const FunctionDeclaration& function, +bool IRGenerator::determineCallCost(std::shared_ptr<FunctionDeclaration> function, const std::vector<std::unique_ptr<Expression>>& arguments, int* outCost) { - if (function.fParameters.size() != arguments.size()) { + if (function->fParameters.size() != arguments.size()) { return false; } int total = 0; for (size_t i = 0; i < arguments.size(); i++) { int cost; - if (arguments[i]->fType.determineCoercionCost(function.fParameters[i]->fType, &cost)) { + if (arguments[i]->fType->determineCoercionCost(function->fParameters[i]->fType, &cost)) { total += cost; } else { return false; @@ -844,43 +848,43 @@ std::unique_ptr<Expression> IRGenerator::call(Position position, } FunctionReference* ref = (FunctionReference*) functionValue.get(); int bestCost = INT_MAX; - const FunctionDeclaration* best = nullptr; + std::shared_ptr<FunctionDeclaration> best; if (ref->fFunctions.size() > 1) { for (const auto& f : ref->fFunctions) { int cost; - if (this->determineCallCost(*f, arguments, &cost) && cost < bestCost) { + if (this->determineCallCost(f, arguments, &cost) && cost < bestCost) { bestCost = cost; best = f; } } if (best) { - return this->call(position, *best, std::move(arguments)); + return this->call(position, std::move(best), std::move(arguments)); } std::string msg = "no match for " + ref->fFunctions[0]->fName + "("; std::string separator = ""; for (size_t i = 0; i < arguments.size(); i++) { msg += separator; separator = ", "; - msg += arguments[i]->fType.description(); + msg += arguments[i]->fType->description(); } msg += ")"; fErrors.error(position, msg); return nullptr; } - return this->call(position, *ref->fFunctions[0], std::move(arguments)); + return this->call(position, ref->fFunctions[0], std::move(arguments)); } std::unique_ptr<Expression> IRGenerator::convertConstructor( Position position, - const Type& type, + std::shared_ptr<Type> type, std::vector<std::unique_ptr<Expression>> args) { // FIXME: add support for structs and arrays - Type::Kind kind = type.kind(); - if (!type.isNumber() && kind != Type::kVector_Kind && kind != Type::kMatrix_Kind) { - fErrors.error(position, "cannot construct '" + type.description() + "'"); + Type::Kind kind = type->kind(); + if (!type->isNumber() && kind != Type::kVector_Kind && kind != Type::kMatrix_Kind) { + fErrors.error(position, "cannot construct '" + type->description() + "'"); return nullptr; } - if (type == *kFloat_Type && args.size() == 1 && + if (type == kFloat_Type && args.size() == 1 && args[0]->fKind == Expression::kIntLiteral_Kind) { int64_t value = ((IntLiteral&) *args[0]).fValue; return std::unique_ptr<Expression>(new FloatLiteral(position, (double) value)); @@ -889,13 +893,13 @@ std::unique_ptr<Expression> IRGenerator::convertConstructor( // argument is already the right type, just return it return std::move(args[0]); } - if (type.isNumber()) { + if (type->isNumber()) { if (args.size() != 1) { - fErrors.error(position, "invalid arguments to '" + type.description() + + fErrors.error(position, "invalid arguments to '" + type->description() + "' constructor, (expected exactly 1 argument, but found " + to_string(args.size()) + ")"); } - if (args[0]->fType == *kBool_Type) { + if (args[0]->fType == kBool_Type) { std::unique_ptr<IntLiteral> zero(new IntLiteral(position, 0)); std::unique_ptr<IntLiteral> one(new IntLiteral(position, 1)); return std::unique_ptr<Expression>( @@ -903,38 +907,38 @@ std::unique_ptr<Expression> IRGenerator::convertConstructor( this->coerce(std::move(one), type), this->coerce(std::move(zero), type))); - } else if (!args[0]->fType.isNumber()) { - fErrors.error(position, "invalid argument to '" + type.description() + + } else if (!args[0]->fType->isNumber()) { + fErrors.error(position, "invalid argument to '" + type->description() + "' constructor (expected a number or bool, but found '" + - args[0]->fType.description() + "')"); + args[0]->fType->description() + "')"); } } else { ASSERT(kind == Type::kVector_Kind || kind == Type::kMatrix_Kind); int actual = 0; for (size_t i = 0; i < args.size(); i++) { - if (args[i]->fType.kind() == Type::kVector_Kind || - args[i]->fType.kind() == Type::kMatrix_Kind) { - int columns = args[i]->fType.columns(); - int rows = args[i]->fType.rows(); + if (args[i]->fType->kind() == Type::kVector_Kind || + args[i]->fType->kind() == Type::kMatrix_Kind) { + int columns = args[i]->fType->columns(); + int rows = args[i]->fType->rows(); args[i] = this->coerce(std::move(args[i]), - type.componentType().toCompound(columns, rows)); - actual += args[i]->fType.rows() * args[i]->fType.columns(); - } else if (args[i]->fType.kind() == Type::kScalar_Kind) { + type->componentType()->toCompound(columns, rows)); + actual += args[i]->fType->rows() * args[i]->fType->columns(); + } else if (args[i]->fType->kind() == Type::kScalar_Kind) { actual += 1; - if (type.kind() != Type::kScalar_Kind) { - args[i] = this->coerce(std::move(args[i]), type.componentType()); + if (type->kind() != Type::kScalar_Kind) { + args[i] = this->coerce(std::move(args[i]), type->componentType()); } } else { - fErrors.error(position, "'" + args[i]->fType.description() + "' is not a valid " - "parameter to '" + type.description() + "' constructor"); + fErrors.error(position, "'" + args[i]->fType->description() + "' is not a valid " + "parameter to '" + type->description() + "' constructor"); return nullptr; } } - int min = type.rows() * type.columns(); - int max = type.columns() > 1 ? INT_MAX : min; + int min = type->rows() * type->columns(); + int max = type->columns() > 1 ? INT_MAX : min; if ((actual < min || actual > max) && !((kind == Type::kVector_Kind || kind == Type::kMatrix_Kind) && (actual == 1))) { - fErrors.error(position, "invalid arguments to '" + type.description() + + fErrors.error(position, "invalid arguments to '" + type->description() + "' constructor (expected " + to_string(min) + " scalar" + (min == 1 ? "" : "s") + ", but found " + to_string(actual) + ")"); @@ -952,16 +956,16 @@ std::unique_ptr<Expression> IRGenerator::convertPrefixExpression( } switch (expression.fOperator) { case Token::PLUS: - if (!base->fType.isNumber() && base->fType.kind() != Type::kVector_Kind) { + if (!base->fType->isNumber() && base->fType->kind() != Type::kVector_Kind) { fErrors.error(expression.fPosition, - "'+' cannot operate on '" + base->fType.description() + "'"); + "'+' cannot operate on '" + base->fType->description() + "'"); return nullptr; } return base; case Token::MINUS: - if (!base->fType.isNumber() && base->fType.kind() != Type::kVector_Kind) { + if (!base->fType->isNumber() && base->fType->kind() != Type::kVector_Kind) { fErrors.error(expression.fPosition, - "'-' cannot operate on '" + base->fType.description() + "'"); + "'-' cannot operate on '" + base->fType->description() + "'"); return nullptr; } if (base->fKind == Expression::kIntLiteral_Kind) { @@ -974,28 +978,28 @@ std::unique_ptr<Expression> IRGenerator::convertPrefixExpression( } return std::unique_ptr<Expression>(new PrefixExpression(Token::MINUS, std::move(base))); case Token::PLUSPLUS: - if (!base->fType.isNumber()) { + if (!base->fType->isNumber()) { fErrors.error(expression.fPosition, "'" + Token::OperatorName(expression.fOperator) + - "' cannot operate on '" + base->fType.description() + "'"); + "' cannot operate on '" + base->fType->description() + "'"); return nullptr; } this->markWrittenTo(*base); break; case Token::MINUSMINUS: - if (!base->fType.isNumber()) { + if (!base->fType->isNumber()) { fErrors.error(expression.fPosition, "'" + Token::OperatorName(expression.fOperator) + - "' cannot operate on '" + base->fType.description() + "'"); + "' cannot operate on '" + base->fType->description() + "'"); return nullptr; } this->markWrittenTo(*base); break; case Token::NOT: - if (base->fType != *kBool_Type) { + if (base->fType != kBool_Type) { fErrors.error(expression.fPosition, "'" + Token::OperatorName(expression.fOperator) + - "' cannot operate on '" + base->fType.description() + "'"); + "' cannot operate on '" + base->fType->description() + "'"); return nullptr; } break; @@ -1008,8 +1012,8 @@ std::unique_ptr<Expression> IRGenerator::convertPrefixExpression( std::unique_ptr<Expression> IRGenerator::convertIndex(std::unique_ptr<Expression> base, const ASTExpression& index) { - if (base->fType.kind() != Type::kArray_Kind && base->fType.kind() != Type::kMatrix_Kind) { - fErrors.error(base->fPosition, "expected array, but found '" + base->fType.description() + + if (base->fType->kind() != Type::kArray_Kind && base->fType->kind() != Type::kMatrix_Kind) { + fErrors.error(base->fPosition, "expected array, but found '" + base->fType->description() + "'"); return nullptr; } @@ -1017,7 +1021,7 @@ std::unique_ptr<Expression> IRGenerator::convertIndex(std::unique_ptr<Expression if (!converted) { return nullptr; } - converted = this->coerce(std::move(converted), *kInt_Type); + converted = this->coerce(std::move(converted), kInt_Type); if (!converted) { return nullptr; } @@ -1026,21 +1030,21 @@ std::unique_ptr<Expression> IRGenerator::convertIndex(std::unique_ptr<Expression std::unique_ptr<Expression> IRGenerator::convertField(std::unique_ptr<Expression> base, const std::string& field) { - auto fields = base->fType.fields(); + auto fields = base->fType->fields(); for (size_t i = 0; i < fields.size(); i++) { if (fields[i].fName == field) { return std::unique_ptr<Expression>(new FieldAccess(std::move(base), (int) i)); } } - fErrors.error(base->fPosition, "type '" + base->fType.description() + "' does not have a " + fErrors.error(base->fPosition, "type '" + base->fType->description() + "' does not have a " "field named '" + field + ""); return nullptr; } std::unique_ptr<Expression> IRGenerator::convertSwizzle(std::unique_ptr<Expression> base, const std::string& fields) { - if (base->fType.kind() != Type::kVector_Kind) { - fErrors.error(base->fPosition, "cannot swizzle type '" + base->fType.description() + "'"); + if (base->fType->kind() != Type::kVector_Kind) { + fErrors.error(base->fPosition, "cannot swizzle type '" + base->fType->description() + "'"); return nullptr; } std::vector<int> swizzleComponents; @@ -1054,7 +1058,7 @@ std::unique_ptr<Expression> IRGenerator::convertSwizzle(std::unique_ptr<Expressi case 'y': // fall through case 'g': // fall through case 't': - if (base->fType.columns() >= 2) { + if (base->fType->columns() >= 2) { swizzleComponents.push_back(1); break; } @@ -1062,7 +1066,7 @@ std::unique_ptr<Expression> IRGenerator::convertSwizzle(std::unique_ptr<Expressi case 'z': // fall through case 'b': // fall through case 'p': - if (base->fType.columns() >= 3) { + if (base->fType->columns() >= 3) { swizzleComponents.push_back(2); break; } @@ -1070,7 +1074,7 @@ std::unique_ptr<Expression> IRGenerator::convertSwizzle(std::unique_ptr<Expressi case 'w': // fall through case 'a': // fall through case 'q': - if (base->fType.columns() >= 4) { + if (base->fType->columns() >= 4) { swizzleComponents.push_back(3); break; } @@ -1113,7 +1117,7 @@ std::unique_ptr<Expression> IRGenerator::convertSuffixExpression( return this->call(expression.fPosition, std::move(base), std::move(arguments)); } case ASTSuffix::kField_Kind: { - switch (base->fType.kind()) { + switch (base->fType->kind()) { case Type::kVector_Kind: return this->convertSwizzle(std::move(base), ((ASTFieldSuffix&) *expression.fSuffix).fField); @@ -1122,23 +1126,23 @@ std::unique_ptr<Expression> IRGenerator::convertSuffixExpression( ((ASTFieldSuffix&) *expression.fSuffix).fField); default: fErrors.error(base->fPosition, "cannot swizzle value of type '" + - base->fType.description() + "'"); + base->fType->description() + "'"); return nullptr; } } case ASTSuffix::kPostIncrement_Kind: - if (!base->fType.isNumber()) { + if (!base->fType->isNumber()) { fErrors.error(expression.fPosition, - "'++' cannot operate on '" + base->fType.description() + "'"); + "'++' cannot operate on '" + base->fType->description() + "'"); return nullptr; } this->markWrittenTo(*base); return std::unique_ptr<Expression>(new PostfixExpression(std::move(base), Token::PLUSPLUS)); case ASTSuffix::kPostDecrement_Kind: - if (!base->fType.isNumber()) { + if (!base->fType->isNumber()) { fErrors.error(expression.fPosition, - "'--' cannot operate on '" + base->fType.description() + "'"); + "'--' cannot operate on '" + base->fType->description() + "'"); return nullptr; } this->markWrittenTo(*base); @@ -1158,13 +1162,13 @@ void IRGenerator::checkValid(const Expression& expr) { fErrors.error(expr.fPosition, "expected '(' to begin constructor invocation"); break; default: - ASSERT(expr.fType != *kInvalid_Type); + ASSERT(expr.fType != kInvalid_Type); break; } } -void IRGenerator::markReadFrom(const Variable& var) { - var.fIsReadFrom = true; +void IRGenerator::markReadFrom(std::shared_ptr<Variable> var) { + var->fIsReadFrom = true; } static bool has_duplicates(const Swizzle& swizzle) { @@ -1183,7 +1187,7 @@ static bool has_duplicates(const Swizzle& swizzle) { void IRGenerator::markWrittenTo(const Expression& expr) { switch (expr.fKind) { case Expression::kVariableReference_Kind: { - const Variable& var = ((VariableReference&) expr).fVariable; + const Variable& var = *((VariableReference&) expr).fVariable; if (var.fModifiers.fFlags & (Modifiers::kConst_Flag | Modifiers::kUniform_Flag)) { fErrors.error(expr.fPosition, "cannot modify immutable variable '" + var.fName + "'"); diff --git a/src/sksl/SkSLIRGenerator.h b/src/sksl/SkSLIRGenerator.h index 7cf38f3fe5..d23e5a1bdb 100644 --- a/src/sksl/SkSLIRGenerator.h +++ b/src/sksl/SkSLIRGenerator.h @@ -65,20 +65,21 @@ private: void pushSymbolTable(); void popSymbolTable(); - const Type* convertType(const ASTType& type); + std::shared_ptr<Type> convertType(const ASTType& type); std::unique_ptr<Expression> call(Position position, - const FunctionDeclaration& function, + std::shared_ptr<FunctionDeclaration> function, std::vector<std::unique_ptr<Expression>> arguments); - bool determineCallCost(const FunctionDeclaration& function, + bool determineCallCost(std::shared_ptr<FunctionDeclaration> function, const std::vector<std::unique_ptr<Expression>>& arguments, int* outCost); std::unique_ptr<Expression> call(Position position, std::unique_ptr<Expression> function, std::vector<std::unique_ptr<Expression>> arguments); - std::unique_ptr<Expression> coerce(std::unique_ptr<Expression> expr, const Type& type); + std::unique_ptr<Expression> coerce(std::unique_ptr<Expression> expr, + std::shared_ptr<Type> type); std::unique_ptr<Block> convertBlock(const ASTBlock& block); std::unique_ptr<Statement> convertBreak(const ASTBreakStatement& b); std::unique_ptr<Expression> convertConstructor(Position position, - const Type& type, + std::shared_ptr<Type> type, std::vector<std::unique_ptr<Expression>> params); std::unique_ptr<Statement> convertContinue(const ASTContinueStatement& c); std::unique_ptr<Statement> convertDiscard(const ASTDiscardStatement& d); @@ -105,10 +106,10 @@ private: std::unique_ptr<Statement> convertWhile(const ASTWhileStatement& w); void checkValid(const Expression& expr); - void markReadFrom(const Variable& var); + void markReadFrom(std::shared_ptr<Variable> var); void markWrittenTo(const Expression& expr); - const FunctionDeclaration* fCurrentFunction; + std::shared_ptr<FunctionDeclaration> fCurrentFunction; std::shared_ptr<SymbolTable> fSymbolTable; ErrorReporter& fErrors; diff --git a/src/sksl/SkSLParser.cpp b/src/sksl/SkSLParser.cpp index edff0c67d1..fa302af0d3 100644 --- a/src/sksl/SkSLParser.cpp +++ b/src/sksl/SkSLParser.cpp @@ -52,7 +52,6 @@ #include "ast/SkSLASTVarDeclarationStatement.h" #include "ast/SkSLASTWhileStatement.h" #include "ir/SkSLSymbolTable.h" -#include "ir/SkSLType.h" namespace SkSL { @@ -291,17 +290,17 @@ std::unique_ptr<ASTType> Parser::structDeclaration() { return nullptr; } for (size_t i = 0; i < decl->fNames.size(); i++) { - auto type = (const Type*) fTypes[decl->fType->fName]; + auto type = std::static_pointer_cast<Type>(fTypes[decl->fType->fName]); for (int j = (int) decl->fSizes[i].size() - 1; j >= 0; j--) { - if (decl->fSizes[i][j]->fKind != ASTExpression::kInt_Kind) { + if (decl->fSizes[i][j]->fKind == ASTExpression::kInt_Kind) { this->error(decl->fPosition, "array size in struct field must be a constant"); } uint64_t columns = ((ASTIntLiteral&) *decl->fSizes[i][j]).fValue; std::string name = type->name() + "[" + to_string(columns) + "]"; - type = new Type(name, Type::kArray_Kind, *type, (int) columns); - fTypes.takeOwnership((Type*) type); + type = std::shared_ptr<Type>(new Type(name, Type::kArray_Kind, std::move(type), + (int) columns)); } - fields.push_back(Type::Field(decl->fModifiers, decl->fNames[i], *type)); + fields.push_back(Type::Field(decl->fModifiers, decl->fNames[i], std::move(type))); if (decl->fValues[i]) { this->error(decl->fPosition, "initializers are not permitted on struct fields"); } @@ -310,8 +309,9 @@ std::unique_ptr<ASTType> Parser::structDeclaration() { if (!this->expect(Token::RBRACE, "'}'")) { return nullptr; } - fTypes.add(name.fText, std::unique_ptr<Type>(new Type(name.fText, fields))); - return std::unique_ptr<ASTType>(new ASTType(name.fPosition, name.fText, + std::shared_ptr<Type> type(new Type(name.fText, fields)); + fTypes.add(type->fName, type); + return std::unique_ptr<ASTType>(new ASTType(name.fPosition, type->fName, ASTType::kStruct_Kind)); } diff --git a/src/sksl/SkSLSPIRVCodeGenerator.cpp b/src/sksl/SkSLSPIRVCodeGenerator.cpp index 3823bc7961..0a2dab3adf 100644 --- a/src/sksl/SkSLSPIRVCodeGenerator.cpp +++ b/src/sksl/SkSLSPIRVCodeGenerator.cpp @@ -143,34 +143,34 @@ void SPIRVCodeGenerator::writeWord(int32_t word, std::ostream& out) { static bool is_float(const Type& type) { if (type.kind() == Type::kVector_Kind) { - return is_float(type.componentType()); + return is_float(*type.componentType()); } return type == *kFloat_Type || type == *kDouble_Type; } static bool is_signed(const Type& type) { if (type.kind() == Type::kVector_Kind) { - return is_signed(type.componentType()); + return is_signed(*type.componentType()); } return type == *kInt_Type; } static bool is_unsigned(const Type& type) { if (type.kind() == Type::kVector_Kind) { - return is_unsigned(type.componentType()); + return is_unsigned(*type.componentType()); } return type == *kUInt_Type; } static bool is_bool(const Type& type) { if (type.kind() == Type::kVector_Kind) { - return is_bool(type.componentType()); + return is_bool(*type.componentType()); } return type == *kBool_Type; } -static bool is_out(const Variable& var) { - return (var.fModifiers.fFlags & Modifiers::kOut_Flag) != 0; +static bool is_out(std::shared_ptr<Variable> var) { + return (var->fModifiers.fFlags & Modifiers::kOut_Flag) != 0; } #if SPIRV_DEBUG @@ -973,7 +973,7 @@ void SPIRVCodeGenerator::writeStruct(const Type& type, SpvId resultId) { // in the middle of writing the struct instruction std::vector<SpvId> types; for (const auto& f : type.fields()) { - types.push_back(this->getType(f.fType)); + types.push_back(this->getType(*f.fType)); } this->writeOpCode(SpvOpTypeStruct, 2 + (int32_t) types.size(), fConstantBuffer); this->writeWord(resultId, fConstantBuffer); @@ -982,8 +982,8 @@ void SPIRVCodeGenerator::writeStruct(const Type& type, SpvId resultId) { } size_t offset = 0; for (int32_t i = 0; i < (int32_t) type.fields().size(); i++) { - size_t size = type.fields()[i].fType.size(); - size_t alignment = type.fields()[i].fType.alignment(); + size_t size = type.fields()[i].fType->size(); + size_t alignment = type.fields()[i].fType->alignment(); size_t mod = offset % alignment; if (mod != 0) { offset += alignment - mod; @@ -995,14 +995,14 @@ void SPIRVCodeGenerator::writeStruct(const Type& type, SpvId resultId) { this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i, SpvDecorationOffset, (SpvId) offset, fDecorationBuffer); } - if (type.fields()[i].fType.kind() == Type::kMatrix_Kind) { + if (type.fields()[i].fType->kind() == Type::kMatrix_Kind) { this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationColMajor, fDecorationBuffer); this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationMatrixStride, - (SpvId) type.fields()[i].fType.stride(), fDecorationBuffer); + (SpvId) type.fields()[i].fType->stride(), fDecorationBuffer); } offset += size; - Type::Kind kind = type.fields()[i].fType.kind(); + Type::Kind kind = type.fields()[i].fType->kind(); if ((kind == Type::kArray_Kind || kind == Type::kStruct_Kind) && offset % alignment != 0) { offset += alignment - offset % alignment; } @@ -1032,11 +1032,11 @@ SpvId SPIRVCodeGenerator::getType(const Type& type) { break; case Type::kVector_Kind: this->writeInstruction(SpvOpTypeVector, result, - this->getType(type.componentType()), + this->getType(*type.componentType()), type.columns(), fConstantBuffer); break; case Type::kMatrix_Kind: - this->writeInstruction(SpvOpTypeMatrix, result, this->getType(index_type(type)), + this->writeInstruction(SpvOpTypeMatrix, result, this->getType(*index_type(type)), type.columns(), fConstantBuffer); break; case Type::kStruct_Kind: @@ -1046,14 +1046,14 @@ SpvId SPIRVCodeGenerator::getType(const Type& type) { if (type.columns() > 0) { IntLiteral count(Position(), type.columns()); this->writeInstruction(SpvOpTypeArray, result, - this->getType(type.componentType()), + this->getType(*type.componentType()), this->writeIntLiteral(count), fConstantBuffer); this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride, (int32_t) type.stride(), fDecorationBuffer); } else { ABORT("runtime-sized arrays are not yet supported"); this->writeInstruction(SpvOpTypeRuntimeArray, result, - this->getType(type.componentType()), fConstantBuffer); + this->getType(*type.componentType()), fConstantBuffer); } break; } @@ -1079,22 +1079,22 @@ SpvId SPIRVCodeGenerator::getType(const Type& type) { return entry->second; } -SpvId SPIRVCodeGenerator::getFunctionType(const FunctionDeclaration& function) { - std::string key = function.fReturnType.description() + "("; +SpvId SPIRVCodeGenerator::getFunctionType(std::shared_ptr<FunctionDeclaration> function) { + std::string key = function->fReturnType->description() + "("; std::string separator = ""; - for (size_t i = 0; i < function.fParameters.size(); i++) { + for (size_t i = 0; i < function->fParameters.size(); i++) { key += separator; separator = ", "; - key += function.fParameters[i]->fType.description(); + key += function->fParameters[i]->fType->description(); } key += ")"; auto entry = fTypeMap.find(key); if (entry == fTypeMap.end()) { SpvId result = this->nextId(); - int32_t length = 3 + (int32_t) function.fParameters.size(); - SpvId returnType = this->getType(function.fReturnType); + int32_t length = 3 + (int32_t) function->fParameters.size(); + SpvId returnType = this->getType(*function->fReturnType); std::vector<SpvId> parameterTypes; - for (size_t i = 0; i < function.fParameters.size(); i++) { + for (size_t i = 0; i < function->fParameters.size(); i++) { // glslang seems to treat all function arguments as pointers whether they need to be or // not. I was initially puzzled by this until I ran bizarre failures with certain // patterns of function calls and control constructs, as exemplified by this minimal @@ -1118,10 +1118,10 @@ SpvId SPIRVCodeGenerator::getFunctionType(const FunctionDeclaration& function) { // as glslang does, fixes it. It's entirely possible I simply missed whichever part of // the spec makes this make sense. // if (is_out(function->fParameters[i])) { - parameterTypes.push_back(this->getPointerType(function.fParameters[i]->fType, + parameterTypes.push_back(this->getPointerType(function->fParameters[i]->fType, SpvStorageClassFunction)); // } else { -// parameterTypes.push_back(this->getType(function.fParameters[i]->fType)); +// parameterTypes.push_back(this->getType(*function->fParameters[i]->fType)); // } } this->writeOpCode(SpvOpTypeFunction, length, fConstantBuffer); @@ -1136,14 +1136,14 @@ SpvId SPIRVCodeGenerator::getFunctionType(const FunctionDeclaration& function) { return entry->second; } -SpvId SPIRVCodeGenerator::getPointerType(const Type& type, +SpvId SPIRVCodeGenerator::getPointerType(std::shared_ptr<Type> type, SpvStorageClass_ storageClass) { - std::string key = type.description() + "*" + to_string(storageClass); + std::string key = type->description() + "*" + to_string(storageClass); auto entry = fTypeMap.find(key); if (entry == fTypeMap.end()) { SpvId result = this->nextId(); this->writeInstruction(SpvOpTypePointer, result, storageClass, - this->getType(type), fConstantBuffer); + this->getType(*type), fConstantBuffer); fTypeMap[key] = result; return result; } @@ -1185,21 +1185,21 @@ SpvId SPIRVCodeGenerator::writeExpression(Expression& expr, std::ostream& out) { } SpvId SPIRVCodeGenerator::writeIntrinsicCall(FunctionCall& c, std::ostream& out) { - auto intrinsic = fIntrinsicMap.find(c.fFunction.fName); + auto intrinsic = fIntrinsicMap.find(c.fFunction->fName); ASSERT(intrinsic != fIntrinsicMap.end()); - const Type& type = c.fArguments[0]->fType; + std::shared_ptr<Type> type = c.fArguments[0]->fType; int32_t intrinsicId; - if (std::get<0>(intrinsic->second) == kSpecial_IntrinsicKind || is_float(type)) { + if (std::get<0>(intrinsic->second) == kSpecial_IntrinsicKind || is_float(*type)) { intrinsicId = std::get<1>(intrinsic->second); - } else if (is_signed(type)) { + } else if (is_signed(*type)) { intrinsicId = std::get<2>(intrinsic->second); - } else if (is_unsigned(type)) { + } else if (is_unsigned(*type)) { intrinsicId = std::get<3>(intrinsic->second); - } else if (is_bool(type)) { + } else if (is_bool(*type)) { intrinsicId = std::get<4>(intrinsic->second); } else { ABORT("invalid call %s, cannot operate on '%s'", c.description().c_str(), - type.description().c_str()); + type->description().c_str()); } switch (std::get<0>(intrinsic->second)) { case kGLSL_STD_450_IntrinsicKind: { @@ -1209,7 +1209,7 @@ SpvId SPIRVCodeGenerator::writeIntrinsicCall(FunctionCall& c, std::ostream& out) arguments.push_back(this->writeExpression(*c.fArguments[i], out)); } this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out); - this->writeWord(this->getType(c.fType), out); + this->writeWord(this->getType(*c.fType), out); this->writeWord(result, out); this->writeWord(fGLSLExtendedInstructions, out); this->writeWord(intrinsicId, out); @@ -1225,7 +1225,7 @@ SpvId SPIRVCodeGenerator::writeIntrinsicCall(FunctionCall& c, std::ostream& out) arguments.push_back(this->writeExpression(*c.fArguments[i], out)); } this->writeOpCode((SpvOp_) intrinsicId, 3 + (int32_t) arguments.size(), out); - this->writeWord(this->getType(c.fType), out); + this->writeWord(this->getType(*c.fType), out); this->writeWord(result, out); for (SpvId id : arguments) { this->writeWord(id, out); @@ -1249,7 +1249,7 @@ SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(FunctionCall& c, SpecialIntrinsi arguments.push_back(this->writeExpression(*c.fArguments[i], out)); } this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out); - this->writeWord(this->getType(c.fType), out); + this->writeWord(this->getType(*c.fType), out); this->writeWord(result, out); this->writeWord(fGLSLExtendedInstructions, out); this->writeWord(arguments.size() == 2 ? GLSLstd450Atan2 : GLSLstd450Atan, out); @@ -1259,7 +1259,7 @@ SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(FunctionCall& c, SpecialIntrinsi return result; } case kTexture_SpecialIntrinsic: { - SpvId type = this->getType(c.fType); + SpvId type = this->getType(*c.fType); SpvId sampler = this->writeExpression(*c.fArguments[0], out); SpvId uv = this->writeExpression(*c.fArguments[1], out); if (c.fArguments.size() == 3) { @@ -1274,7 +1274,7 @@ SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(FunctionCall& c, SpecialIntrinsi break; } case kTextureProj_SpecialIntrinsic: { - SpvId type = this->getType(c.fType); + SpvId type = this->getType(*c.fType); SpvId sampler = this->writeExpression(*c.fArguments[0], out); SpvId uv = this->writeExpression(*c.fArguments[1], out); if (c.fArguments.size() == 3) { @@ -1293,7 +1293,7 @@ SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(FunctionCall& c, SpecialIntrinsi SpvId img = this->writeExpression(*c.fArguments[0], out); SpvId coords = this->writeExpression(*c.fArguments[1], out); this->writeInstruction(SpvOpImageSampleImplicitLod, - this->getType(c.fType), + this->getType(*c.fType), result, img, coords, @@ -1305,7 +1305,7 @@ SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(FunctionCall& c, SpecialIntrinsi } SpvId SPIRVCodeGenerator::writeFunctionCall(FunctionCall& c, std::ostream& out) { - const auto& entry = fFunctionMap.find(&c.fFunction); + const auto& entry = fFunctionMap.find(c.fFunction); if (entry == fFunctionMap.end()) { return this->writeIntrinsicCall(c, out); } @@ -1318,7 +1318,7 @@ SpvId SPIRVCodeGenerator::writeFunctionCall(FunctionCall& c, std::ostream& out) SpvId tmpVar; // if we need a temporary var to store this argument, this is the value to store in the var SpvId tmpValueId; - if (is_out(*c.fFunction.fParameters[i])) { + if (is_out(c.fFunction->fParameters[i])) { std::unique_ptr<LValue> lv = this->getLValue(*c.fArguments[i], out); SpvId ptr = lv->getPointer(); if (ptr) { @@ -1330,7 +1330,7 @@ SpvId SPIRVCodeGenerator::writeFunctionCall(FunctionCall& c, std::ostream& out) // update the lvalue. tmpValueId = lv->load(out); tmpVar = this->nextId(); - lvalues.push_back(std::make_tuple(tmpVar, this->getType(c.fArguments[i]->fType), + lvalues.push_back(std::make_tuple(tmpVar, this->getType(*c.fArguments[i]->fType), std::move(lv))); } } else { @@ -1343,13 +1343,13 @@ SpvId SPIRVCodeGenerator::writeFunctionCall(FunctionCall& c, std::ostream& out) SpvStorageClassFunction), tmpVar, SpvStorageClassFunction, - fVariableBuffer); + out); this->writeInstruction(SpvOpStore, tmpVar, tmpValueId, out); arguments.push_back(tmpVar); } SpvId result = this->nextId(); this->writeOpCode(SpvOpFunctionCall, 4 + (int32_t) c.fArguments.size(), out); - this->writeWord(this->getType(c.fType), out); + this->writeWord(this->getType(*c.fType), out); this->writeWord(result, out); this->writeWord(entry->second, out); for (SpvId id : arguments) { @@ -1366,19 +1366,19 @@ SpvId SPIRVCodeGenerator::writeFunctionCall(FunctionCall& c, std::ostream& out) } SpvId SPIRVCodeGenerator::writeConstantVector(Constructor& c) { - ASSERT(c.fType.kind() == Type::kVector_Kind && c.isConstant()); + ASSERT(c.fType->kind() == Type::kVector_Kind && c.isConstant()); SpvId result = this->nextId(); std::vector<SpvId> arguments; for (size_t i = 0; i < c.fArguments.size(); i++) { arguments.push_back(this->writeExpression(*c.fArguments[i], fConstantBuffer)); } - SpvId type = this->getType(c.fType); + SpvId type = this->getType(*c.fType); if (c.fArguments.size() == 1) { // with a single argument, a vector will have all of its entries equal to the argument - this->writeOpCode(SpvOpConstantComposite, 3 + c.fType.columns(), fConstantBuffer); + this->writeOpCode(SpvOpConstantComposite, 3 + c.fType->columns(), fConstantBuffer); this->writeWord(type, fConstantBuffer); this->writeWord(result, fConstantBuffer); - for (int i = 0; i < c.fType.columns(); i++) { + for (int i = 0; i < c.fType->columns(); i++) { this->writeWord(arguments[0], fConstantBuffer); } } else { @@ -1394,43 +1394,43 @@ SpvId SPIRVCodeGenerator::writeConstantVector(Constructor& c) { } SpvId SPIRVCodeGenerator::writeFloatConstructor(Constructor& c, std::ostream& out) { - ASSERT(c.fType == *kFloat_Type); + ASSERT(c.fType == kFloat_Type); ASSERT(c.fArguments.size() == 1); - ASSERT(c.fArguments[0]->fType.isNumber()); + ASSERT(c.fArguments[0]->fType->isNumber()); SpvId result = this->nextId(); SpvId parameter = this->writeExpression(*c.fArguments[0], out); - if (c.fArguments[0]->fType == *kInt_Type) { - this->writeInstruction(SpvOpConvertSToF, this->getType(c.fType), result, parameter, + if (c.fArguments[0]->fType == kInt_Type) { + this->writeInstruction(SpvOpConvertSToF, this->getType(*c.fType), result, parameter, out); - } else if (c.fArguments[0]->fType == *kUInt_Type) { - this->writeInstruction(SpvOpConvertUToF, this->getType(c.fType), result, parameter, + } else if (c.fArguments[0]->fType == kUInt_Type) { + this->writeInstruction(SpvOpConvertUToF, this->getType(*c.fType), result, parameter, out); - } else if (c.fArguments[0]->fType == *kFloat_Type) { + } else if (c.fArguments[0]->fType == kFloat_Type) { return parameter; } return result; } SpvId SPIRVCodeGenerator::writeIntConstructor(Constructor& c, std::ostream& out) { - ASSERT(c.fType == *kInt_Type); + ASSERT(c.fType == kInt_Type); ASSERT(c.fArguments.size() == 1); - ASSERT(c.fArguments[0]->fType.isNumber()); + ASSERT(c.fArguments[0]->fType->isNumber()); SpvId result = this->nextId(); SpvId parameter = this->writeExpression(*c.fArguments[0], out); - if (c.fArguments[0]->fType == *kFloat_Type) { - this->writeInstruction(SpvOpConvertFToS, this->getType(c.fType), result, parameter, + if (c.fArguments[0]->fType == kFloat_Type) { + this->writeInstruction(SpvOpConvertFToS, this->getType(*c.fType), result, parameter, out); - } else if (c.fArguments[0]->fType == *kUInt_Type) { - this->writeInstruction(SpvOpSatConvertUToS, this->getType(c.fType), result, parameter, + } else if (c.fArguments[0]->fType == kUInt_Type) { + this->writeInstruction(SpvOpSatConvertUToS, this->getType(*c.fType), result, parameter, out); - } else if (c.fArguments[0]->fType == *kInt_Type) { + } else if (c.fArguments[0]->fType == kInt_Type) { return parameter; } return result; } SpvId SPIRVCodeGenerator::writeMatrixConstructor(Constructor& c, std::ostream& out) { - ASSERT(c.fType.kind() == Type::kMatrix_Kind); + ASSERT(c.fType->kind() == Type::kMatrix_Kind); // go ahead and write the arguments so we don't try to write new instructions in the middle of // an instruction std::vector<SpvId> arguments; @@ -1438,8 +1438,8 @@ SpvId SPIRVCodeGenerator::writeMatrixConstructor(Constructor& c, std::ostream& o arguments.push_back(this->writeExpression(*c.fArguments[i], out)); } SpvId result = this->nextId(); - int rows = c.fType.rows(); - int columns = c.fType.columns(); + int rows = c.fType->rows(); + int columns = c.fType->columns(); // FIXME this won't work to create a matrix from another matrix if (arguments.size() == 1) { // with a single argument, a matrix will have all of its diagonal entries equal to the @@ -1449,19 +1449,19 @@ SpvId SPIRVCodeGenerator::writeMatrixConstructor(Constructor& c, std::ostream& o SpvId zeroId = this->writeFloatLiteral(zero); std::vector<SpvId> columnIds; for (int column = 0; column < columns; column++) { - this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType.rows(), + this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType->rows(), out); - this->writeWord(this->getType(c.fType.componentType().toCompound(rows, 1)), out); + this->writeWord(this->getType(*c.fType->componentType()->toCompound(rows, 1)), out); SpvId columnId = this->nextId(); this->writeWord(columnId, out); columnIds.push_back(columnId); - for (int row = 0; row < c.fType.columns(); row++) { + for (int row = 0; row < c.fType->columns(); row++) { this->writeWord(row == column ? arguments[0] : zeroId, out); } } this->writeOpCode(SpvOpCompositeConstruct, 3 + columns, out); - this->writeWord(this->getType(c.fType), out); + this->writeWord(this->getType(*c.fType), out); this->writeWord(result, out); for (SpvId id : columnIds) { this->writeWord(id, out); @@ -1470,15 +1470,15 @@ SpvId SPIRVCodeGenerator::writeMatrixConstructor(Constructor& c, std::ostream& o std::vector<SpvId> columnIds; int currentCount = 0; for (size_t i = 0; i < arguments.size(); i++) { - if (c.fArguments[i]->fType.kind() == Type::kVector_Kind) { + if (c.fArguments[i]->fType->kind() == Type::kVector_Kind) { ASSERT(currentCount == 0); columnIds.push_back(arguments[i]); currentCount = 0; } else { - ASSERT(c.fArguments[i]->fType.kind() == Type::kScalar_Kind); + ASSERT(c.fArguments[i]->fType->kind() == Type::kScalar_Kind); if (currentCount == 0) { - this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType.rows(), out); - this->writeWord(this->getType(c.fType.componentType().toCompound(rows, 1)), + this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType->rows(), out); + this->writeWord(this->getType(*c.fType->componentType()->toCompound(rows, 1)), out); SpvId id = this->nextId(); this->writeWord(id, out); @@ -1490,7 +1490,7 @@ SpvId SPIRVCodeGenerator::writeMatrixConstructor(Constructor& c, std::ostream& o } ASSERT(columnIds.size() == (size_t) columns); this->writeOpCode(SpvOpCompositeConstruct, 3 + columns, out); - this->writeWord(this->getType(c.fType), out); + this->writeWord(this->getType(*c.fType), out); this->writeWord(result, out); for (SpvId id : columnIds) { this->writeWord(id, out); @@ -1500,7 +1500,7 @@ SpvId SPIRVCodeGenerator::writeMatrixConstructor(Constructor& c, std::ostream& o } SpvId SPIRVCodeGenerator::writeVectorConstructor(Constructor& c, std::ostream& out) { - ASSERT(c.fType.kind() == Type::kVector_Kind); + ASSERT(c.fType->kind() == Type::kVector_Kind); if (c.isConstant()) { return this->writeConstantVector(c); } @@ -1511,16 +1511,16 @@ SpvId SPIRVCodeGenerator::writeVectorConstructor(Constructor& c, std::ostream& o arguments.push_back(this->writeExpression(*c.fArguments[i], out)); } SpvId result = this->nextId(); - if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) { - this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType.columns(), out); - this->writeWord(this->getType(c.fType), out); + if (arguments.size() == 1 && c.fArguments[0]->fType->kind() == Type::kScalar_Kind) { + this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType->columns(), out); + this->writeWord(this->getType(*c.fType), out); this->writeWord(result, out); - for (int i = 0; i < c.fType.columns(); i++) { + for (int i = 0; i < c.fType->columns(); i++) { this->writeWord(arguments[0], out); } } else { this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) c.fArguments.size(), out); - this->writeWord(this->getType(c.fType), out); + this->writeWord(this->getType(*c.fType), out); this->writeWord(result, out); for (SpvId id : arguments) { this->writeWord(id, out); @@ -1530,12 +1530,12 @@ SpvId SPIRVCodeGenerator::writeVectorConstructor(Constructor& c, std::ostream& o } SpvId SPIRVCodeGenerator::writeConstructor(Constructor& c, std::ostream& out) { - if (c.fType == *kFloat_Type) { + if (c.fType == kFloat_Type) { return this->writeFloatConstructor(c, out); - } else if (c.fType == *kInt_Type) { + } else if (c.fType == kInt_Type) { return this->writeIntConstructor(c, out); } - switch (c.fType.kind()) { + switch (c.fType->kind()) { case Type::kVector_Kind: return this->writeVectorConstructor(c, out); case Type::kMatrix_Kind: @@ -1560,7 +1560,7 @@ SpvStorageClass_ get_storage_class(const Modifiers& modifiers) { SpvStorageClass_ get_storage_class(Expression& expr) { switch (expr.fKind) { case Expression::kVariableReference_Kind: - return get_storage_class(((VariableReference&) expr).fVariable.fModifiers); + return get_storage_class(((VariableReference&) expr).fVariable->fModifiers); case Expression::kFieldAccess_Kind: return get_storage_class(*((FieldAccess&) expr).fBase); case Expression::kIndex_Kind: @@ -1698,13 +1698,13 @@ std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(Expres std::ostream& out) { switch (expr.fKind) { case Expression::kVariableReference_Kind: { - const Variable& var = ((VariableReference&) expr).fVariable; - auto entry = fVariableMap.find(&var); + std::shared_ptr<Variable> var = ((VariableReference&) expr).fVariable; + auto entry = fVariableMap.find(var); ASSERT(entry != fVariableMap.end()); return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue( *this, entry->second, - this->getType(expr.fType))); + this->getType(*expr.fType))); } case Expression::kIndex_Kind: // fall through case Expression::kFieldAccess_Kind: { @@ -1719,7 +1719,7 @@ std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(Expres return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue( *this, member, - this->getType(expr.fType))); + this->getType(*expr.fType))); } case Expression::kSwizzle_Kind: { @@ -1740,14 +1740,14 @@ std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(Expres return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue( *this, member, - this->getType(expr.fType))); + this->getType(*expr.fType))); } else { return std::unique_ptr<SPIRVCodeGenerator::LValue>(new SwizzleLValue( *this, base, swizzle.fComponents, - swizzle.fBase->fType, - expr.fType)); + *swizzle.fBase->fType, + *expr.fType)); } } @@ -1758,22 +1758,21 @@ std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(Expres // caught by IRGenerator SpvId result = this->nextId(); SpvId type = this->getPointerType(expr.fType, SpvStorageClassFunction); - this->writeInstruction(SpvOpVariable, type, result, SpvStorageClassFunction, - fVariableBuffer); + this->writeInstruction(SpvOpVariable, type, result, SpvStorageClassFunction, out); this->writeInstruction(SpvOpStore, result, this->writeExpression(expr, out), out); return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue( *this, result, - this->getType(expr.fType))); + this->getType(*expr.fType))); } } SpvId SPIRVCodeGenerator::writeVariableReference(VariableReference& ref, std::ostream& out) { - auto entry = fVariableMap.find(&ref.fVariable); + auto entry = fVariableMap.find(ref.fVariable); ASSERT(entry != fVariableMap.end()); SpvId var = entry->second; SpvId result = this->nextId(); - this->writeInstruction(SpvOpLoad, this->getType(ref.fVariable.fType), result, var, out); + this->writeInstruction(SpvOpLoad, this->getType(*ref.fVariable->fType), result, var, out); return result; } @@ -1790,11 +1789,11 @@ SpvId SPIRVCodeGenerator::writeSwizzle(Swizzle& swizzle, std::ostream& out) { SpvId result = this->nextId(); size_t count = swizzle.fComponents.size(); if (count == 1) { - this->writeInstruction(SpvOpCompositeExtract, this->getType(swizzle.fType), result, base, + this->writeInstruction(SpvOpCompositeExtract, this->getType(*swizzle.fType), result, base, swizzle.fComponents[0], out); } else { this->writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) count, out); - this->writeWord(this->getType(swizzle.fType), out); + this->writeWord(this->getType(*swizzle.fType), out); this->writeWord(result, out); this->writeWord(base, out); this->writeWord(base, out); @@ -1863,7 +1862,7 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(BinaryExpression& b, std::ostrea } // "normal" operators - const Type& resultType = b.fType; + const Type& resultType = *b.fType; std::unique_ptr<LValue> lvalue; SpvId lhs; if (is_assignment(b.fOperator)) { @@ -1879,23 +1878,23 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(BinaryExpression& b, std::ostrea // IR allows mismatched types in expressions (e.g. vec2 * float), but they need special handling // in SPIR-V if (b.fLeft->fType != b.fRight->fType) { - if (b.fLeft->fType.kind() == Type::kVector_Kind && - b.fRight->fType.isNumber()) { + if (b.fLeft->fType->kind() == Type::kVector_Kind && + b.fRight->fType->isNumber()) { // promote number to vector SpvId vec = this->nextId(); - this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType.columns(), out); + this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType->columns(), out); this->writeWord(this->getType(resultType), out); this->writeWord(vec, out); for (int i = 0; i < resultType.columns(); i++) { this->writeWord(rhs, out); } rhs = vec; - operandType = &b.fRight->fType; - } else if (b.fRight->fType.kind() == Type::kVector_Kind && - b.fLeft->fType.isNumber()) { + operandType = b.fRight->fType.get(); + } else if (b.fRight->fType->kind() == Type::kVector_Kind && + b.fLeft->fType->isNumber()) { // promote number to vector SpvId vec = this->nextId(); - this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType.columns(), out); + this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType->columns(), out); this->writeWord(this->getType(resultType), out); this->writeWord(vec, out); for (int i = 0; i < resultType.columns(); i++) { @@ -1903,33 +1902,33 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(BinaryExpression& b, std::ostrea } lhs = vec; ASSERT(!lvalue); - operandType = &b.fLeft->fType; - } else if (b.fLeft->fType.kind() == Type::kMatrix_Kind) { + operandType = b.fLeft->fType.get(); + } else if (b.fLeft->fType->kind() == Type::kMatrix_Kind) { SpvOp_ op; - if (b.fRight->fType.kind() == Type::kMatrix_Kind) { + if (b.fRight->fType->kind() == Type::kMatrix_Kind) { op = SpvOpMatrixTimesMatrix; - } else if (b.fRight->fType.kind() == Type::kVector_Kind) { + } else if (b.fRight->fType->kind() == Type::kVector_Kind) { op = SpvOpMatrixTimesVector; } else { - ASSERT(b.fRight->fType.kind() == Type::kScalar_Kind); + ASSERT(b.fRight->fType->kind() == Type::kScalar_Kind); op = SpvOpMatrixTimesScalar; } SpvId result = this->nextId(); - this->writeInstruction(op, this->getType(b.fType), result, lhs, rhs, out); + this->writeInstruction(op, this->getType(*b.fType), result, lhs, rhs, out); if (b.fOperator == Token::STAREQ) { lvalue->store(result, out); } else { ASSERT(b.fOperator == Token::STAR); } return result; - } else if (b.fRight->fType.kind() == Type::kMatrix_Kind) { + } else if (b.fRight->fType->kind() == Type::kMatrix_Kind) { SpvId result = this->nextId(); - if (b.fLeft->fType.kind() == Type::kVector_Kind) { - this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(b.fType), result, + if (b.fLeft->fType->kind() == Type::kVector_Kind) { + this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(*b.fType), result, lhs, rhs, out); } else { - ASSERT(b.fLeft->fType.kind() == Type::kScalar_Kind); - this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(b.fType), result, rhs, + ASSERT(b.fLeft->fType->kind() == Type::kScalar_Kind); + this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(*b.fType), result, rhs, lhs, out); } if (b.fOperator == Token::STAREQ) { @@ -1942,8 +1941,8 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(BinaryExpression& b, std::ostrea ABORT("unsupported binary expression: %s", b.description().c_str()); } } else { - operandType = &b.fLeft->fType; - ASSERT(*operandType == b.fRight->fType); + operandType = b.fLeft->fType.get(); + ASSERT(*operandType == *b.fRight->fType); } switch (b.fOperator) { case Token::EQEQ: @@ -1981,8 +1980,8 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(BinaryExpression& b, std::ostrea return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub, SpvOpISub, SpvOpISub, SpvOpUndef, out); case Token::STAR: - if (b.fLeft->fType.kind() == Type::kMatrix_Kind && - b.fRight->fType.kind() == Type::kMatrix_Kind) { + if (b.fLeft->fType->kind() == Type::kMatrix_Kind && + b.fRight->fType->kind() == Type::kMatrix_Kind) { // matrix multiply SpvId result = this->nextId(); this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result, @@ -2009,8 +2008,8 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(BinaryExpression& b, std::ostrea return result; } case Token::STAREQ: { - if (b.fLeft->fType.kind() == Type::kMatrix_Kind && - b.fRight->fType.kind() == Type::kMatrix_Kind) { + if (b.fLeft->fType->kind() == Type::kMatrix_Kind && + b.fRight->fType->kind() == Type::kMatrix_Kind) { // matrix multiply SpvId result = this->nextId(); this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result, @@ -2087,7 +2086,7 @@ SpvId SPIRVCodeGenerator::writeTernaryExpression(TernaryExpression& t, std::ostr SpvId result = this->nextId(); SpvId trueId = this->writeExpression(*t.fIfTrue, out); SpvId falseId = this->writeExpression(*t.fIfFalse, out); - this->writeInstruction(SpvOpSelect, this->getType(t.fType), result, test, trueId, falseId, + this->writeInstruction(SpvOpSelect, this->getType(*t.fType), result, test, trueId, falseId, out); return result; } @@ -2095,7 +2094,7 @@ SpvId SPIRVCodeGenerator::writeTernaryExpression(TernaryExpression& t, std::ostr // Adreno. Switched to storing the result in a temp variable as glslang does. SpvId var = this->nextId(); this->writeInstruction(SpvOpVariable, this->getPointerType(t.fType, SpvStorageClassFunction), - var, SpvStorageClassFunction, fVariableBuffer); + var, SpvStorageClassFunction, out); SpvId trueLabel = this->nextId(); SpvId falseLabel = this->nextId(); SpvId end = this->nextId(); @@ -2109,7 +2108,7 @@ SpvId SPIRVCodeGenerator::writeTernaryExpression(TernaryExpression& t, std::ostr this->writeInstruction(SpvOpBranch, end, out); this->writeLabel(end, out); SpvId result = this->nextId(); - this->writeInstruction(SpvOpLoad, this->getType(t.fType), result, var, out); + this->writeInstruction(SpvOpLoad, this->getType(*t.fType), result, var, out); return result; } @@ -2129,11 +2128,11 @@ Expression* literal_1(const Type& type) { SpvId SPIRVCodeGenerator::writePrefixExpression(PrefixExpression& p, std::ostream& out) { if (p.fOperator == Token::MINUS) { SpvId result = this->nextId(); - SpvId typeId = this->getType(p.fType); + SpvId typeId = this->getType(*p.fType); SpvId expr = this->writeExpression(*p.fOperand, out); - if (is_float(p.fType)) { + if (is_float(*p.fType)) { this->writeInstruction(SpvOpFNegate, typeId, result, expr, out); - } else if (is_signed(p.fType)) { + } else if (is_signed(*p.fType)) { this->writeInstruction(SpvOpSNegate, typeId, result, expr, out); } else { ABORT("unsupported prefix expression %s", p.description().c_str()); @@ -2145,8 +2144,8 @@ SpvId SPIRVCodeGenerator::writePrefixExpression(PrefixExpression& p, std::ostrea return this->writeExpression(*p.fOperand, out); case Token::PLUSPLUS: { std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out); - SpvId one = this->writeExpression(*literal_1(p.fType), out); - SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one, + SpvId one = this->writeExpression(*literal_1(*p.fType), out); + SpvId result = this->writeBinaryOperation(*p.fType, *p.fType, lv->load(out), one, SpvOpFAdd, SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out); lv->store(result, out); @@ -2154,17 +2153,17 @@ SpvId SPIRVCodeGenerator::writePrefixExpression(PrefixExpression& p, std::ostrea } case Token::MINUSMINUS: { std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out); - SpvId one = this->writeExpression(*literal_1(p.fType), out); - SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one, + SpvId one = this->writeExpression(*literal_1(*p.fType), out); + SpvId result = this->writeBinaryOperation(*p.fType, *p.fType, lv->load(out), one, SpvOpFSub, SpvOpISub, SpvOpISub, SpvOpUndef, out); lv->store(result, out); return result; } case Token::NOT: { - ASSERT(p.fOperand->fType == *kBool_Type); + ASSERT(p.fOperand->fType == kBool_Type); SpvId result = this->nextId(); - this->writeInstruction(SpvOpLogicalNot, this->getType(p.fOperand->fType), result, + this->writeInstruction(SpvOpLogicalNot, this->getType(*p.fOperand->fType), result, this->writeExpression(*p.fOperand, out), out); return result; } @@ -2176,16 +2175,16 @@ SpvId SPIRVCodeGenerator::writePrefixExpression(PrefixExpression& p, std::ostrea SpvId SPIRVCodeGenerator::writePostfixExpression(PostfixExpression& p, std::ostream& out) { std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out); SpvId result = lv->load(out); - SpvId one = this->writeExpression(*literal_1(p.fType), out); + SpvId one = this->writeExpression(*literal_1(*p.fType), out); switch (p.fOperator) { case Token::PLUSPLUS: { - SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFAdd, + SpvId temp = this->writeBinaryOperation(*p.fType, *p.fType, result, one, SpvOpFAdd, SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out); lv->store(temp, out); return result; } case Token::MINUSMINUS: { - SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFSub, + SpvId temp = this->writeBinaryOperation(*p.fType, *p.fType, result, one, SpvOpFSub, SpvOpISub, SpvOpISub, SpvOpUndef, out); lv->store(temp, out); return result; @@ -2199,14 +2198,14 @@ SpvId SPIRVCodeGenerator::writeBoolLiteral(BoolLiteral& b) { if (b.fValue) { if (fBoolTrue == 0) { fBoolTrue = this->nextId(); - this->writeInstruction(SpvOpConstantTrue, this->getType(b.fType), fBoolTrue, + this->writeInstruction(SpvOpConstantTrue, this->getType(*b.fType), fBoolTrue, fConstantBuffer); } return fBoolTrue; } else { if (fBoolFalse == 0) { fBoolFalse = this->nextId(); - this->writeInstruction(SpvOpConstantFalse, this->getType(b.fType), fBoolFalse, + this->writeInstruction(SpvOpConstantFalse, this->getType(*b.fType), fBoolFalse, fConstantBuffer); } return fBoolFalse; @@ -2214,22 +2213,22 @@ SpvId SPIRVCodeGenerator::writeBoolLiteral(BoolLiteral& b) { } SpvId SPIRVCodeGenerator::writeIntLiteral(IntLiteral& i) { - if (i.fType == *kInt_Type) { + if (i.fType == kInt_Type) { auto entry = fIntConstants.find(i.fValue); if (entry == fIntConstants.end()) { SpvId result = this->nextId(); - this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue, + this->writeInstruction(SpvOpConstant, this->getType(*i.fType), result, (SpvId) i.fValue, fConstantBuffer); fIntConstants[i.fValue] = result; return result; } return entry->second; } else { - ASSERT(i.fType == *kUInt_Type); + ASSERT(i.fType == kUInt_Type); auto entry = fUIntConstants.find(i.fValue); if (entry == fUIntConstants.end()) { SpvId result = this->nextId(); - this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue, + this->writeInstruction(SpvOpConstant, this->getType(*i.fType), result, (SpvId) i.fValue, fConstantBuffer); fUIntConstants[i.fValue] = result; return result; @@ -2239,7 +2238,7 @@ SpvId SPIRVCodeGenerator::writeIntLiteral(IntLiteral& i) { } SpvId SPIRVCodeGenerator::writeFloatLiteral(FloatLiteral& f) { - if (f.fType == *kFloat_Type) { + if (f.fType == kFloat_Type) { float value = (float) f.fValue; auto entry = fFloatConstants.find(value); if (entry == fFloatConstants.end()) { @@ -2247,21 +2246,21 @@ SpvId SPIRVCodeGenerator::writeFloatLiteral(FloatLiteral& f) { uint32_t bits; ASSERT(sizeof(bits) == sizeof(value)); memcpy(&bits, &value, sizeof(bits)); - this->writeInstruction(SpvOpConstant, this->getType(f.fType), result, bits, + this->writeInstruction(SpvOpConstant, this->getType(*f.fType), result, bits, fConstantBuffer); fFloatConstants[value] = result; return result; } return entry->second; } else { - ASSERT(f.fType == *kDouble_Type); + ASSERT(f.fType == kDouble_Type); auto entry = fDoubleConstants.find(f.fValue); if (entry == fDoubleConstants.end()) { SpvId result = this->nextId(); uint64_t bits; ASSERT(sizeof(bits) == sizeof(f.fValue)); memcpy(&bits, &f.fValue, sizeof(bits)); - this->writeInstruction(SpvOpConstant, this->getType(f.fType), result, + this->writeInstruction(SpvOpConstant, this->getType(*f.fType), result, bits & 0xffffffff, bits >> 32, fConstantBuffer); fDoubleConstants[f.fValue] = result; return result; @@ -2270,25 +2269,26 @@ SpvId SPIRVCodeGenerator::writeFloatLiteral(FloatLiteral& f) { } } -SpvId SPIRVCodeGenerator::writeFunctionStart(const FunctionDeclaration& f, std::ostream& out) { - SpvId result = fFunctionMap[&f]; - this->writeInstruction(SpvOpFunction, this->getType(f.fReturnType), result, +SpvId SPIRVCodeGenerator::writeFunctionStart(std::shared_ptr<FunctionDeclaration> f, + std::ostream& out) { + SpvId result = fFunctionMap[f]; + this->writeInstruction(SpvOpFunction, this->getType(*f->fReturnType), result, SpvFunctionControlMaskNone, this->getFunctionType(f), out); - this->writeInstruction(SpvOpName, result, f.fName.c_str(), fNameBuffer); - for (size_t i = 0; i < f.fParameters.size(); i++) { + this->writeInstruction(SpvOpName, result, f->fName.c_str(), fNameBuffer); + for (size_t i = 0; i < f->fParameters.size(); i++) { SpvId id = this->nextId(); - fVariableMap[f.fParameters[i]] = id; + fVariableMap[f->fParameters[i]] = id; SpvId type; - type = this->getPointerType(f.fParameters[i]->fType, SpvStorageClassFunction); + type = this->getPointerType(f->fParameters[i]->fType, SpvStorageClassFunction); this->writeInstruction(SpvOpFunctionParameter, type, id, out); } return result; } -SpvId SPIRVCodeGenerator::writeFunction(const FunctionDefinition& f, std::ostream& out) { +SpvId SPIRVCodeGenerator::writeFunction(FunctionDefinition& f, std::ostream& out) { SpvId result = this->writeFunctionStart(f.fDeclaration, out); this->writeLabel(this->nextId(), out); - if (f.fDeclaration.fName == "main") { + if (f.fDeclaration->fName == "main") { out << fGlobalInitializersBuffer.str(); } std::stringstream bodyBuffer; @@ -2350,26 +2350,21 @@ void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target, int mem } SpvId SPIRVCodeGenerator::writeInterfaceBlock(InterfaceBlock& intf) { - SpvId type = this->getType(intf.fVariable.fType); + SpvId type = this->getType(*intf.fVariable->fType); SpvId result = this->nextId(); this->writeInstruction(SpvOpDecorate, type, SpvDecorationBlock, fDecorationBuffer); - SpvStorageClass_ storageClass = get_storage_class(intf.fVariable.fModifiers); + SpvStorageClass_ storageClass = get_storage_class(intf.fVariable->fModifiers); SpvId ptrType = this->nextId(); this->writeInstruction(SpvOpTypePointer, ptrType, storageClass, type, fConstantBuffer); this->writeInstruction(SpvOpVariable, ptrType, result, storageClass, fConstantBuffer); - this->writeLayout(intf.fVariable.fModifiers.fLayout, result); - fVariableMap[&intf.fVariable] = result; + this->writeLayout(intf.fVariable->fModifiers.fLayout, result); + fVariableMap[intf.fVariable] = result; return result; } void SPIRVCodeGenerator::writeGlobalVars(VarDeclaration& decl, std::ostream& out) { for (size_t i = 0; i < decl.fVars.size(); i++) { - if (!decl.fVars[i]->fIsReadFrom && !decl.fVars[i]->fIsWrittenTo && - !(decl.fVars[i]->fModifiers.fFlags & (Modifiers::kIn_Flag | - Modifiers::kOut_Flag | - Modifiers::kUniform_Flag))) { - // variable is dead and not an input / output var (the Vulkan debug layers complain if - // we elide an interface var, even if it's dead) + if (!decl.fVars[i]->fIsReadFrom && !decl.fVars[i]->fIsWrittenTo) { continue; } SpvStorageClass_ storageClass; @@ -2378,7 +2373,7 @@ void SPIRVCodeGenerator::writeGlobalVars(VarDeclaration& decl, std::ostream& out } else if (decl.fVars[i]->fModifiers.fFlags & Modifiers::kOut_Flag) { storageClass = SpvStorageClassOutput; } else if (decl.fVars[i]->fModifiers.fFlags & Modifiers::kUniform_Flag) { - if (decl.fVars[i]->fType.kind() == Type::kSampler_Kind) { + if (decl.fVars[i]->fType->kind() == Type::kSampler_Kind) { storageClass = SpvStorageClassUniformConstant; } else { storageClass = SpvStorageClassUniform; @@ -2391,11 +2386,11 @@ void SPIRVCodeGenerator::writeGlobalVars(VarDeclaration& decl, std::ostream& out SpvId type = this->getPointerType(decl.fVars[i]->fType, storageClass); this->writeInstruction(SpvOpVariable, type, id, storageClass, fConstantBuffer); this->writeInstruction(SpvOpName, id, decl.fVars[i]->fName.c_str(), fNameBuffer); - if (decl.fVars[i]->fType.kind() == Type::kMatrix_Kind) { + if (decl.fVars[i]->fType->kind() == Type::kMatrix_Kind) { this->writeInstruction(SpvOpMemberDecorate, id, (SpvId) i, SpvDecorationColMajor, fDecorationBuffer); this->writeInstruction(SpvOpMemberDecorate, id, (SpvId) i, SpvDecorationMatrixStride, - (SpvId) decl.fVars[i]->fType.stride(), fDecorationBuffer); + (SpvId) decl.fVars[i]->fType->stride(), fDecorationBuffer); } if (decl.fValues[i]) { ASSERT(!fCurrentBlock); @@ -2543,15 +2538,15 @@ void SPIRVCodeGenerator::writeInstructions(Program& program, std::ostream& out) for (size_t i = 0; i < program.fElements.size(); i++) { if (program.fElements[i]->fKind == ProgramElement::kFunction_Kind) { FunctionDefinition& f = (FunctionDefinition&) *program.fElements[i]; - fFunctionMap[&f.fDeclaration] = this->nextId(); + fFunctionMap[f.fDeclaration] = this->nextId(); } } for (size_t i = 0; i < program.fElements.size(); i++) { if (program.fElements[i]->fKind == ProgramElement::kInterfaceBlock_Kind) { InterfaceBlock& intf = (InterfaceBlock&) *program.fElements[i]; SpvId id = this->writeInterfaceBlock(intf); - if ((intf.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) || - (intf.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag)) { + if ((intf.fVariable->fModifiers.fFlags & Modifiers::kIn_Flag) || + (intf.fVariable->fModifiers.fFlags & Modifiers::kOut_Flag)) { interfaceVars.push_back(id); } } @@ -2566,7 +2561,7 @@ void SPIRVCodeGenerator::writeInstructions(Program& program, std::ostream& out) this->writeFunction(((FunctionDefinition&) *program.fElements[i]), body); } } - const FunctionDeclaration* main = nullptr; + std::shared_ptr<FunctionDeclaration> main = nullptr; for (auto entry : fFunctionMap) { if (entry.first->fName == "main") { main = entry.first; @@ -2574,7 +2569,7 @@ void SPIRVCodeGenerator::writeInstructions(Program& program, std::ostream& out) } ASSERT(main); for (auto entry : fVariableMap) { - const Variable* var = entry.first; + std::shared_ptr<Variable> var = entry.first; if (var->fStorage == Variable::kGlobal_Storage && ((var->fModifiers.fFlags & Modifiers::kIn_Flag) || (var->fModifiers.fFlags & Modifiers::kOut_Flag))) { diff --git a/src/sksl/SkSLSPIRVCodeGenerator.h b/src/sksl/SkSLSPIRVCodeGenerator.h index dd301778a5..885c6b8b70 100644 --- a/src/sksl/SkSLSPIRVCodeGenerator.h +++ b/src/sksl/SkSLSPIRVCodeGenerator.h @@ -92,9 +92,9 @@ private: SpvId getType(const Type& type); - SpvId getFunctionType(const FunctionDeclaration& function); + SpvId getFunctionType(std::shared_ptr<FunctionDeclaration> function); - SpvId getPointerType(const Type& type, SpvStorageClass_ storageClass); + SpvId getPointerType(std::shared_ptr<Type> type, SpvStorageClass_ storageClass); std::vector<SpvId> getAccessChain(Expression& expr, std::ostream& out); @@ -108,11 +108,11 @@ private: SpvId writeInterfaceBlock(InterfaceBlock& intf); - SpvId writeFunctionStart(const FunctionDeclaration& f, std::ostream& out); + SpvId writeFunctionStart(std::shared_ptr<FunctionDeclaration> f, std::ostream& out); - SpvId writeFunctionDeclaration(const FunctionDeclaration& f, std::ostream& out); + SpvId writeFunctionDeclaration(std::shared_ptr<FunctionDeclaration> f, std::ostream& out); - SpvId writeFunction(const FunctionDefinition& f, std::ostream& out); + SpvId writeFunction(FunctionDefinition& f, std::ostream& out); void writeGlobalVars(VarDeclaration& v, std::ostream& out); @@ -232,9 +232,9 @@ private: SpvId fGLSLExtendedInstructions; typedef std::tuple<IntrinsicKind, int32_t, int32_t, int32_t, int32_t> Intrinsic; std::unordered_map<std::string, Intrinsic> fIntrinsicMap; - std::unordered_map<const FunctionDeclaration*, SpvId> fFunctionMap; - std::unordered_map<const Variable*, SpvId> fVariableMap; - std::unordered_map<const Variable*, int32_t> fInterfaceBlockMap; + std::unordered_map<std::shared_ptr<FunctionDeclaration>, SpvId> fFunctionMap; + std::unordered_map<std::shared_ptr<Variable>, SpvId> fVariableMap; + std::unordered_map<std::shared_ptr<Variable>, int32_t> fInterfaceBlockMap; std::unordered_map<std::string, SpvId> fTypeMap; std::stringstream fCapabilitiesBuffer; std::stringstream fGlobalInitializersBuffer; diff --git a/src/sksl/ir/SkSLBinaryExpression.h b/src/sksl/ir/SkSLBinaryExpression.h index 9ecdbc717c..bd89d6c602 100644 --- a/src/sksl/ir/SkSLBinaryExpression.h +++ b/src/sksl/ir/SkSLBinaryExpression.h @@ -18,7 +18,7 @@ namespace SkSL { */ struct BinaryExpression : public Expression { BinaryExpression(Position position, std::unique_ptr<Expression> left, Token::Kind op, - std::unique_ptr<Expression> right, const Type& type) + std::unique_ptr<Expression> right, std::shared_ptr<Type> type) : INHERITED(position, kBinary_Kind, type) , fLeft(std::move(left)) , fOperator(op) diff --git a/src/sksl/ir/SkSLBlock.h b/src/sksl/ir/SkSLBlock.h index a53d13d169..56ed77a0ba 100644 --- a/src/sksl/ir/SkSLBlock.h +++ b/src/sksl/ir/SkSLBlock.h @@ -9,7 +9,6 @@ #define SKSL_BLOCK #include "SkSLStatement.h" -#include "SkSLSymbolTable.h" namespace SkSL { @@ -17,11 +16,9 @@ namespace SkSL { * A block of multiple statements functioning as a single statement. */ struct Block : public Statement { - Block(Position position, std::vector<std::unique_ptr<Statement>> statements, - const std::shared_ptr<SymbolTable> symbols) + Block(Position position, std::vector<std::unique_ptr<Statement>> statements) : INHERITED(position, kBlock_Kind) - , fStatements(std::move(statements)) - , fSymbols(std::move(symbols)) {} + , fStatements(std::move(statements)) {} std::string description() const override { std::string result = "{"; @@ -34,7 +31,6 @@ struct Block : public Statement { } const std::vector<std::unique_ptr<Statement>> fStatements; - const std::shared_ptr<SymbolTable> fSymbols; typedef Statement INHERITED; }; diff --git a/src/sksl/ir/SkSLBoolLiteral.h b/src/sksl/ir/SkSLBoolLiteral.h index 31cb817548..3c40e59514 100644 --- a/src/sksl/ir/SkSLBoolLiteral.h +++ b/src/sksl/ir/SkSLBoolLiteral.h @@ -17,7 +17,7 @@ namespace SkSL { */ struct BoolLiteral : public Expression { BoolLiteral(Position position, bool value) - : INHERITED(position, kBoolLiteral_Kind, *kBool_Type) + : INHERITED(position, kBoolLiteral_Kind, kBool_Type) , fValue(value) {} std::string description() const override { diff --git a/src/sksl/ir/SkSLConstructor.h b/src/sksl/ir/SkSLConstructor.h index 0501b651ea..c58da7e5b8 100644 --- a/src/sksl/ir/SkSLConstructor.h +++ b/src/sksl/ir/SkSLConstructor.h @@ -16,13 +16,13 @@ namespace SkSL { * Represents the construction of a compound type, such as "vec2(x, y)". */ struct Constructor : public Expression { - Constructor(Position position, const Type& type, + Constructor(Position position, std::shared_ptr<Type> type, std::vector<std::unique_ptr<Expression>> arguments) - : INHERITED(position, kConstructor_Kind, type) + : INHERITED(position, kConstructor_Kind, std::move(type)) , fArguments(std::move(arguments)) {} std::string description() const override { - std::string result = fType.description() + "("; + std::string result = fType->description() + "("; std::string separator = ""; for (size_t i = 0; i < fArguments.size(); i++) { result += separator; diff --git a/src/sksl/ir/SkSLExpression.h b/src/sksl/ir/SkSLExpression.h index 92cb37de77..1e42c7a475 100644 --- a/src/sksl/ir/SkSLExpression.h +++ b/src/sksl/ir/SkSLExpression.h @@ -35,7 +35,7 @@ struct Expression : public IRNode { kTypeReference_Kind, }; - Expression(Position position, Kind kind, const Type& type) + Expression(Position position, Kind kind, std::shared_ptr<Type> type) : INHERITED(position) , fKind(kind) , fType(std::move(type)) {} @@ -45,7 +45,7 @@ struct Expression : public IRNode { } const Kind fKind; - const Type& fType; + const std::shared_ptr<Type> fType; typedef IRNode INHERITED; }; diff --git a/src/sksl/ir/SkSLField.h b/src/sksl/ir/SkSLField.h index a01df2943d..f2b68bc2bc 100644 --- a/src/sksl/ir/SkSLField.h +++ b/src/sksl/ir/SkSLField.h @@ -21,16 +21,16 @@ namespace SkSL { * result of declaring anonymous interface blocks. */ struct Field : public Symbol { - Field(Position position, const Variable& owner, int fieldIndex) - : INHERITED(position, kField_Kind, owner.fType.fields()[fieldIndex].fName) + Field(Position position, std::shared_ptr<Variable> owner, int fieldIndex) + : INHERITED(position, kField_Kind, owner->fType->fields()[fieldIndex].fName) , fOwner(owner) , fFieldIndex(fieldIndex) {} virtual std::string description() const override { - return fOwner.description() + "." + fOwner.fType.fields()[fFieldIndex].fName; + return fOwner->description() + "." + fOwner->fType->fields()[fFieldIndex].fName; } - const Variable& fOwner; + const std::shared_ptr<Variable> fOwner; const int fFieldIndex; typedef Symbol INHERITED; diff --git a/src/sksl/ir/SkSLFieldAccess.h b/src/sksl/ir/SkSLFieldAccess.h index f09c3a3447..053498e154 100644 --- a/src/sksl/ir/SkSLFieldAccess.h +++ b/src/sksl/ir/SkSLFieldAccess.h @@ -18,12 +18,12 @@ namespace SkSL { */ struct FieldAccess : public Expression { FieldAccess(std::unique_ptr<Expression> base, int fieldIndex) - : INHERITED(base->fPosition, kFieldAccess_Kind, base->fType.fields()[fieldIndex].fType) + : INHERITED(base->fPosition, kFieldAccess_Kind, base->fType->fields()[fieldIndex].fType) , fBase(std::move(base)) , fFieldIndex(fieldIndex) {} virtual std::string description() const override { - return fBase->description() + "." + fBase->fType.fields()[fFieldIndex].fName; + return fBase->description() + "." + fBase->fType->fields()[fFieldIndex].fName; } const std::unique_ptr<Expression> fBase; diff --git a/src/sksl/ir/SkSLFloatLiteral.h b/src/sksl/ir/SkSLFloatLiteral.h index 8f5ec43299..deb5b27144 100644 --- a/src/sksl/ir/SkSLFloatLiteral.h +++ b/src/sksl/ir/SkSLFloatLiteral.h @@ -17,7 +17,7 @@ namespace SkSL { */ struct FloatLiteral : public Expression { FloatLiteral(Position position, double value) - : INHERITED(position, kFloatLiteral_Kind, *kFloat_Type) + : INHERITED(position, kFloatLiteral_Kind, kFloat_Type) , fValue(value) {} virtual std::string description() const override { diff --git a/src/sksl/ir/SkSLForStatement.h b/src/sksl/ir/SkSLForStatement.h index 642d15125e..70bb4014c8 100644 --- a/src/sksl/ir/SkSLForStatement.h +++ b/src/sksl/ir/SkSLForStatement.h @@ -10,7 +10,6 @@ #include "SkSLExpression.h" #include "SkSLStatement.h" -#include "SkSLSymbolTable.h" namespace SkSL { @@ -20,13 +19,12 @@ namespace SkSL { struct ForStatement : public Statement { ForStatement(Position position, std::unique_ptr<Statement> initializer, std::unique_ptr<Expression> test, std::unique_ptr<Expression> next, - std::unique_ptr<Statement> statement, std::shared_ptr<SymbolTable> symbols) + std::unique_ptr<Statement> statement) : INHERITED(position, kFor_Kind) , fInitializer(std::move(initializer)) , fTest(std::move(test)) , fNext(std::move(next)) - , fStatement(std::move(statement)) - , fSymbols(symbols) {} + , fStatement(std::move(statement)) {} std::string description() const override { std::string result = "for ("; @@ -49,7 +47,6 @@ struct ForStatement : public Statement { const std::unique_ptr<Expression> fTest; const std::unique_ptr<Expression> fNext; const std::unique_ptr<Statement> fStatement; - const std::shared_ptr<SymbolTable> fSymbols; typedef Statement INHERITED; }; diff --git a/src/sksl/ir/SkSLFunctionCall.h b/src/sksl/ir/SkSLFunctionCall.h index 85dba40f2a..78d2566227 100644 --- a/src/sksl/ir/SkSLFunctionCall.h +++ b/src/sksl/ir/SkSLFunctionCall.h @@ -17,14 +17,14 @@ namespace SkSL { * A function invocation. */ struct FunctionCall : public Expression { - FunctionCall(Position position, const FunctionDeclaration& function, + FunctionCall(Position position, std::shared_ptr<FunctionDeclaration> function, std::vector<std::unique_ptr<Expression>> arguments) - : INHERITED(position, kFunctionCall_Kind, function.fReturnType) + : INHERITED(position, kFunctionCall_Kind, function->fReturnType) , fFunction(std::move(function)) , fArguments(std::move(arguments)) {} std::string description() const override { - std::string result = fFunction.fName + "("; + std::string result = fFunction->fName + "("; std::string separator = ""; for (size_t i = 0; i < fArguments.size(); i++) { result += separator; @@ -35,7 +35,7 @@ struct FunctionCall : public Expression { return result; } - const FunctionDeclaration& fFunction; + const std::shared_ptr<FunctionDeclaration> fFunction; const std::vector<std::unique_ptr<Expression>> fArguments; typedef Expression INHERITED; diff --git a/src/sksl/ir/SkSLFunctionDeclaration.h b/src/sksl/ir/SkSLFunctionDeclaration.h index 16a184a6d7..32c23f545e 100644 --- a/src/sksl/ir/SkSLFunctionDeclaration.h +++ b/src/sksl/ir/SkSLFunctionDeclaration.h @@ -10,7 +10,6 @@ #include "SkSLModifiers.h" #include "SkSLSymbol.h" -#include "SkSLSymbolTable.h" #include "SkSLType.h" #include "SkSLVariable.h" @@ -21,14 +20,15 @@ namespace SkSL { */ struct FunctionDeclaration : public Symbol { FunctionDeclaration(Position position, std::string name, - std::vector<const Variable*> parameters, const Type& returnType) + std::vector<std::shared_ptr<Variable>> parameters, + std::shared_ptr<Type> returnType) : INHERITED(position, kFunctionDeclaration_Kind, std::move(name)) , fDefined(false) - , fParameters(std::move(parameters)) + , fParameters(parameters) , fReturnType(returnType) {} std::string description() const override { - std::string result = fReturnType.description() + " " + fName + "("; + std::string result = fReturnType->description() + " " + fName + "("; std::string separator = ""; for (auto p : fParameters) { result += separator; @@ -39,24 +39,13 @@ struct FunctionDeclaration : public Symbol { return result; } - bool matches(const FunctionDeclaration& f) const { - if (fName != f.fName) { - return false; - } - if (fParameters.size() != f.fParameters.size()) { - return false; - } - for (size_t i = 0; i < fParameters.size(); i++) { - if (fParameters[i]->fType != f.fParameters[i]->fType) { - return false; - } - } - return true; + bool matches(FunctionDeclaration& f) { + return fName == f.fName && fParameters == f.fParameters; } mutable bool fDefined; - const std::vector<const Variable*> fParameters; - const Type& fReturnType; + const std::vector<std::shared_ptr<Variable>> fParameters; + const std::shared_ptr<Type> fReturnType; typedef Symbol INHERITED; }; diff --git a/src/sksl/ir/SkSLFunctionDefinition.h b/src/sksl/ir/SkSLFunctionDefinition.h index ace27a3ed8..fceb5474cb 100644 --- a/src/sksl/ir/SkSLFunctionDefinition.h +++ b/src/sksl/ir/SkSLFunctionDefinition.h @@ -18,17 +18,17 @@ namespace SkSL { * A function definition (a declaration plus an associated block of code). */ struct FunctionDefinition : public ProgramElement { - FunctionDefinition(Position position, const FunctionDeclaration& declaration, + FunctionDefinition(Position position, std::shared_ptr<FunctionDeclaration> declaration, std::unique_ptr<Block> body) : INHERITED(position, kFunction_Kind) - , fDeclaration(declaration) + , fDeclaration(std::move(declaration)) , fBody(std::move(body)) {} std::string description() const override { - return fDeclaration.description() + " " + fBody->description(); + return fDeclaration->description() + " " + fBody->description(); } - const FunctionDeclaration& fDeclaration; + const std::shared_ptr<FunctionDeclaration> fDeclaration; const std::unique_ptr<Block> fBody; typedef ProgramElement INHERITED; diff --git a/src/sksl/ir/SkSLFunctionReference.h b/src/sksl/ir/SkSLFunctionReference.h index 8afcbb1e32..d5cc444000 100644 --- a/src/sksl/ir/SkSLFunctionReference.h +++ b/src/sksl/ir/SkSLFunctionReference.h @@ -17,8 +17,8 @@ namespace SkSL { * always eventually replaced by FunctionCalls in valid programs. */ struct FunctionReference : public Expression { - FunctionReference(Position position, std::vector<const FunctionDeclaration*> function) - : INHERITED(position, kFunctionReference_Kind, *kInvalid_Type) + FunctionReference(Position position, std::vector<std::shared_ptr<FunctionDeclaration>> function) + : INHERITED(position, kFunctionReference_Kind, kInvalid_Type) , fFunctions(function) {} virtual std::string description() const override { @@ -26,7 +26,7 @@ struct FunctionReference : public Expression { return "<function>"; } - const std::vector<const FunctionDeclaration*> fFunctions; + const std::vector<std::shared_ptr<FunctionDeclaration>> fFunctions; typedef Expression INHERITED; }; diff --git a/src/sksl/ir/SkSLIndexExpression.h b/src/sksl/ir/SkSLIndexExpression.h index 78727aa5c5..538c656153 100644 --- a/src/sksl/ir/SkSLIndexExpression.h +++ b/src/sksl/ir/SkSLIndexExpression.h @@ -16,21 +16,21 @@ namespace SkSL { /** * Given a type, returns the type that will result from extracting an array value from it. */ -static const Type& index_type(const Type& type) { +static std::shared_ptr<Type> index_type(const Type& type) { if (type.kind() == Type::kMatrix_Kind) { - if (type.componentType() == *kFloat_Type) { + if (type.componentType() == kFloat_Type) { switch (type.columns()) { - case 2: return *kVec2_Type; - case 3: return *kVec3_Type; - case 4: return *kVec4_Type; + case 2: return kVec2_Type; + case 3: return kVec3_Type; + case 4: return kVec4_Type; default: ASSERT(false); } } else { - ASSERT(type.componentType() == *kDouble_Type); + ASSERT(type.componentType() == kDouble_Type); switch (type.columns()) { - case 2: return *kDVec2_Type; - case 3: return *kDVec3_Type; - case 4: return *kDVec4_Type; + case 2: return kDVec2_Type; + case 3: return kDVec3_Type; + case 4: return kDVec4_Type; default: ASSERT(false); } } @@ -43,10 +43,10 @@ static const Type& index_type(const Type& type) { */ struct IndexExpression : public Expression { IndexExpression(std::unique_ptr<Expression> base, std::unique_ptr<Expression> index) - : INHERITED(base->fPosition, kIndex_Kind, index_type(base->fType)) + : INHERITED(base->fPosition, kIndex_Kind, index_type(*base->fType)) , fBase(std::move(base)) , fIndex(std::move(index)) { - ASSERT(fIndex->fType == *kInt_Type); + ASSERT(fIndex->fType == kInt_Type); } std::string description() const override { diff --git a/src/sksl/ir/SkSLIntLiteral.h b/src/sksl/ir/SkSLIntLiteral.h index 63b0069784..80b30d7c05 100644 --- a/src/sksl/ir/SkSLIntLiteral.h +++ b/src/sksl/ir/SkSLIntLiteral.h @@ -19,7 +19,7 @@ struct IntLiteral : public Expression { // FIXME: we will need to revisit this if/when we add full support for both signed and unsigned // 64-bit integers, but for right now an int64_t will hold every value we care about IntLiteral(Position position, int64_t value) - : INHERITED(position, kIntLiteral_Kind, *kInt_Type) + : INHERITED(position, kIntLiteral_Kind, kInt_Type) , fValue(value) {} virtual std::string description() const override { diff --git a/src/sksl/ir/SkSLInterfaceBlock.h b/src/sksl/ir/SkSLInterfaceBlock.h index f1121ed707..baedb5864c 100644 --- a/src/sksl/ir/SkSLInterfaceBlock.h +++ b/src/sksl/ir/SkSLInterfaceBlock.h @@ -24,24 +24,22 @@ namespace SkSL { * At the IR level, this is represented by a single variable of struct type. */ struct InterfaceBlock : public ProgramElement { - InterfaceBlock(Position position, const Variable& var, std::shared_ptr<SymbolTable> typeOwner) + InterfaceBlock(Position position, std::shared_ptr<Variable> var) : INHERITED(position, kInterfaceBlock_Kind) - , fVariable(std::move(var)) - , fTypeOwner(typeOwner) { - ASSERT(fVariable.fType.kind() == Type::kStruct_Kind); + , fVariable(std::move(var)) { + ASSERT(fVariable->fType->kind() == Type::kStruct_Kind); } std::string description() const override { - std::string result = fVariable.fModifiers.description() + fVariable.fName + " {\n"; - for (size_t i = 0; i < fVariable.fType.fields().size(); i++) { - result += fVariable.fType.fields()[i].description() + "\n"; + std::string result = fVariable->fModifiers.description() + fVariable->fName + " {\n"; + for (size_t i = 0; i < fVariable->fType->fields().size(); i++) { + result += fVariable->fType->fields()[i].description() + "\n"; } result += "};"; return result; } - const Variable& fVariable; - const std::shared_ptr<SymbolTable> fTypeOwner; + const std::shared_ptr<Variable> fVariable; typedef ProgramElement INHERITED; }; diff --git a/src/sksl/ir/SkSLProgram.h b/src/sksl/ir/SkSLProgram.h index 205db6e932..5edcfded42 100644 --- a/src/sksl/ir/SkSLProgram.h +++ b/src/sksl/ir/SkSLProgram.h @@ -12,7 +12,6 @@ #include <memory> #include "SkSLProgramElement.h" -#include "SkSLSymbolTable.h" namespace SkSL { @@ -25,16 +24,13 @@ struct Program { kVertex_Kind }; - Program(Kind kind, std::vector<std::unique_ptr<ProgramElement>> elements, - std::shared_ptr<SymbolTable> symbols) + Program(Kind kind, std::vector<std::unique_ptr<ProgramElement>> elements) : fKind(kind) - , fElements(std::move(elements)) - , fSymbols(symbols) {} + , fElements(std::move(elements)) {} Kind fKind; std::vector<std::unique_ptr<ProgramElement>> fElements; - std::shared_ptr<SymbolTable> fSymbols; }; } // namespace diff --git a/src/sksl/ir/SkSLSwizzle.h b/src/sksl/ir/SkSLSwizzle.h index fad71b8114..ce360d1847 100644 --- a/src/sksl/ir/SkSLSwizzle.h +++ b/src/sksl/ir/SkSLSwizzle.h @@ -18,40 +18,41 @@ namespace SkSL { * instance, swizzling a vec3 with two components will result in a vec2. It is possible to swizzle * with more components than the source vector, as in 'vec2(1).xxxx'. */ -static const Type& get_type(Expression& value, size_t count) { - const Type& base = value.fType.componentType(); +static std::shared_ptr<Type> get_type(Expression& value, + size_t count) { + std::shared_ptr<Type> base = value.fType->componentType(); if (count == 1) { return base; } - if (base == *kFloat_Type) { + if (base == kFloat_Type) { switch (count) { - case 2: return *kVec2_Type; - case 3: return *kVec3_Type; - case 4: return *kVec4_Type; + case 2: return kVec2_Type; + case 3: return kVec3_Type; + case 4: return kVec4_Type; } - } else if (base == *kDouble_Type) { + } else if (base == kDouble_Type) { switch (count) { - case 2: return *kDVec2_Type; - case 3: return *kDVec3_Type; - case 4: return *kDVec4_Type; + case 2: return kDVec2_Type; + case 3: return kDVec3_Type; + case 4: return kDVec4_Type; } - } else if (base == *kInt_Type) { + } else if (base == kInt_Type) { switch (count) { - case 2: return *kIVec2_Type; - case 3: return *kIVec3_Type; - case 4: return *kIVec4_Type; + case 2: return kIVec2_Type; + case 3: return kIVec3_Type; + case 4: return kIVec4_Type; } - } else if (base == *kUInt_Type) { + } else if (base == kUInt_Type) { switch (count) { - case 2: return *kUVec2_Type; - case 3: return *kUVec3_Type; - case 4: return *kUVec4_Type; + case 2: return kUVec2_Type; + case 3: return kUVec3_Type; + case 4: return kUVec4_Type; } - } else if (base == *kBool_Type) { + } else if (base == kBool_Type) { switch (count) { - case 2: return *kBVec2_Type; - case 3: return *kBVec3_Type; - case 4: return *kBVec4_Type; + case 2: return kBVec2_Type; + case 3: return kBVec3_Type; + case 4: return kBVec4_Type; } } ABORT("cannot swizzle %s\n", value.description().c_str()); diff --git a/src/sksl/ir/SkSLSymbolTable.cpp b/src/sksl/ir/SkSLSymbolTable.cpp index 80e22da009..af83f7a456 100644 --- a/src/sksl/ir/SkSLSymbolTable.cpp +++ b/src/sksl/ir/SkSLSymbolTable.cpp @@ -5,23 +5,23 @@ * found in the LICENSE file. */ -#include "SkSLSymbolTable.h" -#include "SkSLUnresolvedFunction.h" + #include "SkSLSymbolTable.h" namespace SkSL { -std::vector<const FunctionDeclaration*> SymbolTable::GetFunctions(const Symbol& s) { - switch (s.fKind) { +std::vector<std::shared_ptr<FunctionDeclaration>> SymbolTable::GetFunctions( + const std::shared_ptr<Symbol>& s) { + switch (s->fKind) { case Symbol::kFunctionDeclaration_Kind: - return { &((FunctionDeclaration&) s) }; + return { std::static_pointer_cast<FunctionDeclaration>(s) }; case Symbol::kUnresolvedFunction_Kind: - return ((UnresolvedFunction&) s).fFunctions; + return ((UnresolvedFunction&) *s).fFunctions; default: return { }; } } -const Symbol* SymbolTable::operator[](const std::string& name) { +std::shared_ptr<Symbol> SymbolTable::operator[](const std::string& name) { const auto& entry = fSymbols.find(name); if (entry == fSymbols.end()) { if (fParent) { @@ -30,15 +30,15 @@ const Symbol* SymbolTable::operator[](const std::string& name) { return nullptr; } if (fParent) { - auto functions = GetFunctions(*entry->second); + auto functions = GetFunctions(entry->second); if (functions.size() > 0) { bool modified = false; - const Symbol* previous = (*fParent)[name]; + std::shared_ptr<Symbol> previous = (*fParent)[name]; if (previous) { - auto previousFunctions = GetFunctions(*previous); - for (const FunctionDeclaration* prev : previousFunctions) { + auto previousFunctions = GetFunctions(previous); + for (const std::shared_ptr<FunctionDeclaration>& prev : previousFunctions) { bool found = false; - for (const FunctionDeclaration* current : functions) { + for (const std::shared_ptr<FunctionDeclaration>& current : functions) { if (current->matches(*prev)) { found = true; break; @@ -51,7 +51,7 @@ const Symbol* SymbolTable::operator[](const std::string& name) { } if (modified) { ASSERT(functions.size() > 1); - return this->takeOwnership(new UnresolvedFunction(functions)); + return std::shared_ptr<Symbol>(new UnresolvedFunction(functions)); } } } @@ -59,42 +59,27 @@ const Symbol* SymbolTable::operator[](const std::string& name) { return entry->second; } -Symbol* SymbolTable::takeOwnership(Symbol* s) { - fOwnedPointers.push_back(std::unique_ptr<Symbol>(s)); - return s; -} - -void SymbolTable::add(const std::string& name, std::unique_ptr<Symbol> symbol) { - this->addWithoutOwnership(name, symbol.get()); - fOwnedPointers.push_back(std::move(symbol)); -} - -void SymbolTable::addWithoutOwnership(const std::string& name, const Symbol* symbol) { - const auto& existing = fSymbols.find(name); - if (existing == fSymbols.end()) { - fSymbols[name] = symbol; - } else if (symbol->fKind == Symbol::kFunctionDeclaration_Kind) { - const Symbol* oldSymbol = existing->second; - if (oldSymbol->fKind == Symbol::kFunctionDeclaration_Kind) { - std::vector<const FunctionDeclaration*> functions; - functions.push_back((const FunctionDeclaration*) oldSymbol); - functions.push_back((const FunctionDeclaration*) symbol); - UnresolvedFunction* u = new UnresolvedFunction(std::move(functions)); - fSymbols[name] = u; - this->takeOwnership(u); - } else if (oldSymbol->fKind == Symbol::kUnresolvedFunction_Kind) { - std::vector<const FunctionDeclaration*> functions; - for (const auto* f : ((UnresolvedFunction&) *oldSymbol).fFunctions) { - functions.push_back(f); +void SymbolTable::add(const std::string& name, std::shared_ptr<Symbol> symbol) { + const auto& existing = fSymbols.find(name); + if (existing == fSymbols.end()) { + fSymbols[name] = symbol; + } else if (symbol->fKind == Symbol::kFunctionDeclaration_Kind) { + const std::shared_ptr<Symbol>& oldSymbol = existing->second; + if (oldSymbol->fKind == Symbol::kFunctionDeclaration_Kind) { + std::vector<std::shared_ptr<FunctionDeclaration>> functions; + functions.push_back(std::static_pointer_cast<FunctionDeclaration>(oldSymbol)); + functions.push_back(std::static_pointer_cast<FunctionDeclaration>(symbol)); + fSymbols[name].reset(new UnresolvedFunction(std::move(functions))); + } else if (oldSymbol->fKind == Symbol::kUnresolvedFunction_Kind) { + std::vector<std::shared_ptr<FunctionDeclaration>> functions; + for (const auto& f : ((UnresolvedFunction&) *oldSymbol).fFunctions) { + functions.push_back(f); + } + functions.push_back(std::static_pointer_cast<FunctionDeclaration>(symbol)); + fSymbols[name].reset(new UnresolvedFunction(std::move(functions))); } - functions.push_back((const FunctionDeclaration*) symbol); - UnresolvedFunction* u = new UnresolvedFunction(std::move(functions)); - fSymbols[name] = u; - this->takeOwnership(u); + } else { + fErrorReporter.error(symbol->fPosition, "symbol '" + name + "' was already defined"); } - } else { - fErrorReporter.error(symbol->fPosition, "symbol '" + name + "' was already defined"); } -} - } // namespace diff --git a/src/sksl/ir/SkSLSymbolTable.h b/src/sksl/ir/SkSLSymbolTable.h index d732023ff0..151475d642 100644 --- a/src/sksl/ir/SkSLSymbolTable.h +++ b/src/sksl/ir/SkSLSymbolTable.h @@ -10,14 +10,12 @@ #include <memory> #include <unordered_map> -#include <vector> #include "SkSLErrorReporter.h" #include "SkSLSymbol.h" +#include "SkSLUnresolvedFunction.h" namespace SkSL { -struct FunctionDeclaration; - /** * Maps identifiers to symbols. Functions, in particular, are mapped to either FunctionDeclaration * or UnresolvedFunction depending on whether they are overloaded or not. @@ -31,22 +29,17 @@ public: : fParent(parent) , fErrorReporter(errorReporter) {} - const Symbol* operator[](const std::string& name); - - void add(const std::string& name, std::unique_ptr<Symbol> symbol); + std::shared_ptr<Symbol> operator[](const std::string& name); - void addWithoutOwnership(const std::string& name, const Symbol* symbol); - - Symbol* takeOwnership(Symbol* s); + void add(const std::string& name, std::shared_ptr<Symbol> symbol); const std::shared_ptr<SymbolTable> fParent; private: - static std::vector<const FunctionDeclaration*> GetFunctions(const Symbol& s); - - std::vector<std::unique_ptr<Symbol>> fOwnedPointers; + static std::vector<std::shared_ptr<FunctionDeclaration>> GetFunctions( + const std::shared_ptr<Symbol>& s); - std::unordered_map<std::string, const Symbol*> fSymbols; + std::unordered_map<std::string, std::shared_ptr<Symbol>> fSymbols; ErrorReporter& fErrorReporter; }; diff --git a/src/sksl/ir/SkSLType.cpp b/src/sksl/ir/SkSLType.cpp index 671f40b79d..27cbd39e44 100644 --- a/src/sksl/ir/SkSLType.cpp +++ b/src/sksl/ir/SkSLType.cpp @@ -9,26 +9,26 @@ namespace SkSL { -bool Type::determineCoercionCost(const Type& other, int* outCost) const { - if (*this == other) { +bool Type::determineCoercionCost(std::shared_ptr<Type> other, int* outCost) const { + if (this == other.get()) { *outCost = 0; return true; } - if (this->kind() == kVector_Kind && other.kind() == kVector_Kind) { - if (this->columns() == other.columns()) { - return this->componentType().determineCoercionCost(other.componentType(), outCost); + if (this->kind() == kVector_Kind && other->kind() == kVector_Kind) { + if (this->columns() == other->columns()) { + return this->componentType()->determineCoercionCost(other->componentType(), outCost); } return false; } if (this->kind() == kMatrix_Kind) { - if (this->columns() == other.columns() && - this->rows() == other.rows()) { - return this->componentType().determineCoercionCost(other.componentType(), outCost); + if (this->columns() == other->columns() && + this->rows() == other->rows()) { + return this->componentType()->determineCoercionCost(other->componentType(), outCost); } return false; } for (size_t i = 0; i < fCoercibleTypes.size(); i++) { - if (*fCoercibleTypes[i] == other) { + if (fCoercibleTypes[i] == other) { *outCost = (int) i + 1; return true; } @@ -36,39 +36,39 @@ bool Type::determineCoercionCost(const Type& other, int* outCost) const { return false; } -const Type& Type::toCompound(int columns, int rows) const { +std::shared_ptr<Type> Type::toCompound(int columns, int rows) { ASSERT(this->kind() == Type::kScalar_Kind); if (columns == 1 && rows == 1) { - return *this; + return std::shared_ptr<Type>(this); } if (*this == *kFloat_Type) { switch (rows) { case 1: switch (columns) { - case 2: return *kVec2_Type; - case 3: return *kVec3_Type; - case 4: return *kVec4_Type; + case 2: return kVec2_Type; + case 3: return kVec3_Type; + case 4: return kVec4_Type; default: ABORT("unsupported vector column count (%d)", columns); } case 2: switch (columns) { - case 2: return *kMat2x2_Type; - case 3: return *kMat3x2_Type; - case 4: return *kMat4x2_Type; + case 2: return kMat2x2_Type; + case 3: return kMat3x2_Type; + case 4: return kMat4x2_Type; default: ABORT("unsupported matrix column count (%d)", columns); } case 3: switch (columns) { - case 2: return *kMat2x3_Type; - case 3: return *kMat3x3_Type; - case 4: return *kMat4x3_Type; + case 2: return kMat2x3_Type; + case 3: return kMat3x3_Type; + case 4: return kMat4x3_Type; default: ABORT("unsupported matrix column count (%d)", columns); } case 4: switch (columns) { - case 2: return *kMat2x4_Type; - case 3: return *kMat3x4_Type; - case 4: return *kMat4x4_Type; + case 2: return kMat2x4_Type; + case 3: return kMat3x4_Type; + case 4: return kMat4x4_Type; default: ABORT("unsupported matrix column count (%d)", columns); } default: ABORT("unsupported row count (%d)", rows); @@ -77,30 +77,30 @@ const Type& Type::toCompound(int columns, int rows) const { switch (rows) { case 1: switch (columns) { - case 2: return *kDVec2_Type; - case 3: return *kDVec3_Type; - case 4: return *kDVec4_Type; + case 2: return kDVec2_Type; + case 3: return kDVec3_Type; + case 4: return kDVec4_Type; default: ABORT("unsupported vector column count (%d)", columns); } case 2: switch (columns) { - case 2: return *kDMat2x2_Type; - case 3: return *kDMat3x2_Type; - case 4: return *kDMat4x2_Type; + case 2: return kDMat2x2_Type; + case 3: return kDMat3x2_Type; + case 4: return kDMat4x2_Type; default: ABORT("unsupported matrix column count (%d)", columns); } case 3: switch (columns) { - case 2: return *kDMat2x3_Type; - case 3: return *kDMat3x3_Type; - case 4: return *kDMat4x3_Type; + case 2: return kDMat2x3_Type; + case 3: return kDMat3x3_Type; + case 4: return kDMat4x3_Type; default: ABORT("unsupported matrix column count (%d)", columns); } case 4: switch (columns) { - case 2: return *kDMat2x4_Type; - case 3: return *kDMat3x4_Type; - case 4: return *kDMat4x4_Type; + case 2: return kDMat2x4_Type; + case 3: return kDMat3x4_Type; + case 4: return kDMat4x4_Type; default: ABORT("unsupported matrix column count (%d)", columns); } default: ABORT("unsupported row count (%d)", rows); @@ -109,9 +109,9 @@ const Type& Type::toCompound(int columns, int rows) const { switch (rows) { case 1: switch (columns) { - case 2: return *kIVec2_Type; - case 3: return *kIVec3_Type; - case 4: return *kIVec4_Type; + case 2: return kIVec2_Type; + case 3: return kIVec3_Type; + case 4: return kIVec4_Type; default: ABORT("unsupported vector column count (%d)", columns); } default: ABORT("unsupported row count (%d)", rows); @@ -120,9 +120,9 @@ const Type& Type::toCompound(int columns, int rows) const { switch (rows) { case 1: switch (columns) { - case 2: return *kUVec2_Type; - case 3: return *kUVec3_Type; - case 4: return *kUVec4_Type; + case 2: return kUVec2_Type; + case 3: return kUVec3_Type; + case 4: return kUVec4_Type; default: ABORT("unsupported vector column count (%d)", columns); } default: ABORT("unsupported row count (%d)", rows); @@ -131,118 +131,128 @@ const Type& Type::toCompound(int columns, int rows) const { ABORT("unsupported scalar_to_compound type %s", this->description().c_str()); } -const Type* kVoid_Type = new Type("void"); - -const Type* kDouble_Type = new Type("double", true); -const Type* kDVec2_Type = new Type("dvec2", *kDouble_Type, 2); -const Type* kDVec3_Type = new Type("dvec3", *kDouble_Type, 3); -const Type* kDVec4_Type = new Type("dvec4", *kDouble_Type, 4); - -const Type* kFloat_Type = new Type("float", true, { kDouble_Type }); -const Type* kVec2_Type = new Type("vec2", *kFloat_Type, 2); -const Type* kVec3_Type = new Type("vec3", *kFloat_Type, 3); -const Type* kVec4_Type = new Type("vec4", *kFloat_Type, 4); - -const Type* kUInt_Type = new Type("uint", true, { kFloat_Type, kDouble_Type }); -const Type* kUVec2_Type = new Type("uvec2", *kUInt_Type, 2); -const Type* kUVec3_Type = new Type("uvec3", *kUInt_Type, 3); -const Type* kUVec4_Type = new Type("uvec4", *kUInt_Type, 4); - -const Type* kInt_Type = new Type("int", true, { kUInt_Type, kFloat_Type, kDouble_Type }); -const Type* kIVec2_Type = new Type("ivec2", *kInt_Type, 2); -const Type* kIVec3_Type = new Type("ivec3", *kInt_Type, 3); -const Type* kIVec4_Type = new Type("ivec4", *kInt_Type, 4); - -const Type* kBool_Type = new Type("bool", false); -const Type* kBVec2_Type = new Type("bvec2", *kBool_Type, 2); -const Type* kBVec3_Type = new Type("bvec3", *kBool_Type, 3); -const Type* kBVec4_Type = new Type("bvec4", *kBool_Type, 4); - -const Type* kMat2x2_Type = new Type("mat2", *kFloat_Type, 2, 2); -const Type* kMat2x3_Type = new Type("mat2x3", *kFloat_Type, 2, 3); -const Type* kMat2x4_Type = new Type("mat2x4", *kFloat_Type, 2, 4); -const Type* kMat3x2_Type = new Type("mat3x2", *kFloat_Type, 3, 2); -const Type* kMat3x3_Type = new Type("mat3", *kFloat_Type, 3, 3); -const Type* kMat3x4_Type = new Type("mat3x4", *kFloat_Type, 3, 4); -const Type* kMat4x2_Type = new Type("mat4x2", *kFloat_Type, 4, 2); -const Type* kMat4x3_Type = new Type("mat4x3", *kFloat_Type, 4, 3); -const Type* kMat4x4_Type = new Type("mat4", *kFloat_Type, 4, 4); - -const Type* kDMat2x2_Type = new Type("dmat2", *kFloat_Type, 2, 2); -const Type* kDMat2x3_Type = new Type("dmat2x3", *kFloat_Type, 2, 3); -const Type* kDMat2x4_Type = new Type("dmat2x4", *kFloat_Type, 2, 4); -const Type* kDMat3x2_Type = new Type("dmat3x2", *kFloat_Type, 3, 2); -const Type* kDMat3x3_Type = new Type("dmat3", *kFloat_Type, 3, 3); -const Type* kDMat3x4_Type = new Type("dmat3x4", *kFloat_Type, 3, 4); -const Type* kDMat4x2_Type = new Type("dmat4x2", *kFloat_Type, 4, 2); -const Type* kDMat4x3_Type = new Type("dmat4x3", *kFloat_Type, 4, 3); -const Type* kDMat4x4_Type = new Type("dmat4", *kFloat_Type, 4, 4); - -const Type* kSampler1D_Type = new Type("sampler1D", SpvDim1D, false, false, false, true); -const Type* kSampler2D_Type = new Type("sampler2D", SpvDim2D, false, false, false, true); -const Type* kSampler3D_Type = new Type("sampler3D", SpvDim3D, false, false, false, true); -const Type* kSamplerCube_Type = new Type("samplerCube"); -const Type* kSampler2DRect_Type = new Type("sampler2DRect"); -const Type* kSampler1DArray_Type = new Type("sampler1DArray"); -const Type* kSampler2DArray_Type = new Type("sampler2DArray"); -const Type* kSamplerCubeArray_Type = new Type("samplerCubeArray"); -const Type* kSamplerBuffer_Type = new Type("samplerBuffer"); -const Type* kSampler2DMS_Type = new Type("sampler2DMS"); -const Type* kSampler2DMSArray_Type = new Type("sampler2DMSArray"); -const Type* kSampler1DShadow_Type = new Type("sampler1DShadow"); -const Type* kSampler2DShadow_Type = new Type("sampler2DShadow"); -const Type* kSamplerCubeShadow_Type = new Type("samplerCubeShadow"); -const Type* kSampler2DRectShadow_Type = new Type("sampler2DRectShadow"); -const Type* kSampler1DArrayShadow_Type = new Type("sampler1DArrayShadow"); -const Type* kSampler2DArrayShadow_Type = new Type("sampler2DArrayShadow"); -const Type* kSamplerCubeArrayShadow_Type = new Type("samplerCubeArrayShadow"); - -static std::vector<const Type*> type(const Type* t) { +const std::shared_ptr<Type> kVoid_Type(new Type("void")); + +const std::shared_ptr<Type> kDouble_Type(new Type("double", true)); +const std::shared_ptr<Type> kDVec2_Type(new Type("dvec2", kDouble_Type, 2)); +const std::shared_ptr<Type> kDVec3_Type(new Type("dvec3", kDouble_Type, 3)); +const std::shared_ptr<Type> kDVec4_Type(new Type("dvec4", kDouble_Type, 4)); + +const std::shared_ptr<Type> kFloat_Type(new Type("float", true, { kDouble_Type })); +const std::shared_ptr<Type> kVec2_Type(new Type("vec2", kFloat_Type, 2)); +const std::shared_ptr<Type> kVec3_Type(new Type("vec3", kFloat_Type, 3)); +const std::shared_ptr<Type> kVec4_Type(new Type("vec4", kFloat_Type, 4)); + +const std::shared_ptr<Type> kUInt_Type(new Type("uint", true, { kFloat_Type, kDouble_Type })); +const std::shared_ptr<Type> kUVec2_Type(new Type("uvec2", kUInt_Type, 2)); +const std::shared_ptr<Type> kUVec3_Type(new Type("uvec3", kUInt_Type, 3)); +const std::shared_ptr<Type> kUVec4_Type(new Type("uvec4", kUInt_Type, 4)); + +const std::shared_ptr<Type> kInt_Type(new Type("int", true, { kUInt_Type, kFloat_Type, + kDouble_Type })); +const std::shared_ptr<Type> kIVec2_Type(new Type("ivec2", kInt_Type, 2)); +const std::shared_ptr<Type> kIVec3_Type(new Type("ivec3", kInt_Type, 3)); +const std::shared_ptr<Type> kIVec4_Type(new Type("ivec4", kInt_Type, 4)); + +const std::shared_ptr<Type> kBool_Type(new Type("bool", false)); +const std::shared_ptr<Type> kBVec2_Type(new Type("bvec2", kBool_Type, 2)); +const std::shared_ptr<Type> kBVec3_Type(new Type("bvec3", kBool_Type, 3)); +const std::shared_ptr<Type> kBVec4_Type(new Type("bvec4", kBool_Type, 4)); + +const std::shared_ptr<Type> kMat2x2_Type(new Type("mat2", kFloat_Type, 2, 2)); +const std::shared_ptr<Type> kMat2x3_Type(new Type("mat2x3", kFloat_Type, 2, 3)); +const std::shared_ptr<Type> kMat2x4_Type(new Type("mat2x4", kFloat_Type, 2, 4)); +const std::shared_ptr<Type> kMat3x2_Type(new Type("mat3x2", kFloat_Type, 3, 2)); +const std::shared_ptr<Type> kMat3x3_Type(new Type("mat3", kFloat_Type, 3, 3)); +const std::shared_ptr<Type> kMat3x4_Type(new Type("mat3x4", kFloat_Type, 3, 4)); +const std::shared_ptr<Type> kMat4x2_Type(new Type("mat4x2", kFloat_Type, 4, 2)); +const std::shared_ptr<Type> kMat4x3_Type(new Type("mat4x3", kFloat_Type, 4, 3)); +const std::shared_ptr<Type> kMat4x4_Type(new Type("mat4", kFloat_Type, 4, 4)); + +const std::shared_ptr<Type> kDMat2x2_Type(new Type("dmat2", kFloat_Type, 2, 2)); +const std::shared_ptr<Type> kDMat2x3_Type(new Type("dmat2x3", kFloat_Type, 2, 3)); +const std::shared_ptr<Type> kDMat2x4_Type(new Type("dmat2x4", kFloat_Type, 2, 4)); +const std::shared_ptr<Type> kDMat3x2_Type(new Type("dmat3x2", kFloat_Type, 3, 2)); +const std::shared_ptr<Type> kDMat3x3_Type(new Type("dmat3", kFloat_Type, 3, 3)); +const std::shared_ptr<Type> kDMat3x4_Type(new Type("dmat3x4", kFloat_Type, 3, 4)); +const std::shared_ptr<Type> kDMat4x2_Type(new Type("dmat4x2", kFloat_Type, 4, 2)); +const std::shared_ptr<Type> kDMat4x3_Type(new Type("dmat4x3", kFloat_Type, 4, 3)); +const std::shared_ptr<Type> kDMat4x4_Type(new Type("dmat4", kFloat_Type, 4, 4)); + +const std::shared_ptr<Type> kSampler1D_Type(new Type("sampler1D", SpvDim1D, false, false, false, true)); +const std::shared_ptr<Type> kSampler2D_Type(new Type("sampler2D", SpvDim2D, false, false, false, true)); +const std::shared_ptr<Type> kSampler3D_Type(new Type("sampler3D", SpvDim3D, false, false, false, true)); +const std::shared_ptr<Type> kSamplerCube_Type(new Type("samplerCube")); +const std::shared_ptr<Type> kSampler2DRect_Type(new Type("sampler2DRect")); +const std::shared_ptr<Type> kSampler1DArray_Type(new Type("sampler1DArray")); +const std::shared_ptr<Type> kSampler2DArray_Type(new Type("sampler2DArray")); +const std::shared_ptr<Type> kSamplerCubeArray_Type(new Type("samplerCubeArray")); +const std::shared_ptr<Type> kSamplerBuffer_Type(new Type("samplerBuffer")); +const std::shared_ptr<Type> kSampler2DMS_Type(new Type("sampler2DMS")); +const std::shared_ptr<Type> kSampler2DMSArray_Type(new Type("sampler2DMSArray")); +const std::shared_ptr<Type> kSampler1DShadow_Type(new Type("sampler1DShadow")); +const std::shared_ptr<Type> kSampler2DShadow_Type(new Type("sampler2DShadow")); +const std::shared_ptr<Type> kSamplerCubeShadow_Type(new Type("samplerCubeShadow")); +const std::shared_ptr<Type> kSampler2DRectShadow_Type(new Type("sampler2DRectShadow")); +const std::shared_ptr<Type> kSampler1DArrayShadow_Type(new Type("sampler1DArrayShadow")); +const std::shared_ptr<Type> kSampler2DArrayShadow_Type(new Type("sampler2DArrayShadow")); +const std::shared_ptr<Type> kSamplerCubeArrayShadow_Type(new Type("samplerCubeArrayShadow")); + +static std::vector<std::shared_ptr<Type>> type(std::shared_ptr<Type> t) { return { t, t, t, t }; } // FIXME figure out what we're supposed to do with the gsampler et al. types -const Type* kGSampler1D_Type = new Type("$gsampler1D", type(kSampler1D_Type)); -const Type* kGSampler2D_Type = new Type("$gsampler2D", type(kSampler2D_Type)); -const Type* kGSampler3D_Type = new Type("$gsampler3D", type(kSampler3D_Type)); -const Type* kGSamplerCube_Type = new Type("$gsamplerCube", type(kSamplerCube_Type)); -const Type* kGSampler2DRect_Type = new Type("$gsampler2DRect", type(kSampler2DRect_Type)); -const Type* kGSampler1DArray_Type = new Type("$gsampler1DArray", type(kSampler1DArray_Type)); -const Type* kGSampler2DArray_Type = new Type("$gsampler2DArray", type(kSampler2DArray_Type)); -const Type* kGSamplerCubeArray_Type = new Type("$gsamplerCubeArray", type(kSamplerCubeArray_Type)); -const Type* kGSamplerBuffer_Type = new Type("$gsamplerBuffer", type(kSamplerBuffer_Type)); -const Type* kGSampler2DMS_Type = new Type("$gsampler2DMS", type(kSampler2DMS_Type)); -const Type* kGSampler2DMSArray_Type = new Type("$gsampler2DMSArray", type(kSampler2DMSArray_Type)); -const Type* kGSampler2DArrayShadow_Type = new Type("$gsampler2DArrayShadow", - type(kSampler2DArrayShadow_Type)); -const Type* kGSamplerCubeArrayShadow_Type = new Type("$gsamplerCubeArrayShadow", - type(kSamplerCubeArrayShadow_Type)); - -const Type* kGenType_Type = new Type("$genType", { kFloat_Type, kVec2_Type, kVec3_Type, - kVec4_Type }); -const Type* kGenDType_Type = new Type("$genDType", { kDouble_Type, kDVec2_Type, kDVec3_Type, - kDVec4_Type }); -const Type* kGenIType_Type = new Type("$genIType", { kInt_Type, kIVec2_Type, kIVec3_Type, - kIVec4_Type }); -const Type* kGenUType_Type = new Type("$genUType", { kUInt_Type, kUVec2_Type, kUVec3_Type, - kUVec4_Type }); -const Type* kGenBType_Type = new Type("$genBType", { kBool_Type, kBVec2_Type, kBVec3_Type, - kBVec4_Type }); - -const Type* kMat_Type = new Type("$mat"); - -const Type* kVec_Type = new Type("$vec", { kVec2_Type, kVec2_Type, kVec3_Type, kVec4_Type }); - -const Type* kGVec_Type = new Type("$gvec"); -const Type* kGVec2_Type = new Type("$gvec2"); -const Type* kGVec3_Type = new Type("$gvec3"); -const Type* kGVec4_Type = new Type("$gvec4", type(kVec4_Type)); -const Type* kDVec_Type = new Type("$dvec"); -const Type* kIVec_Type = new Type("$ivec"); -const Type* kUVec_Type = new Type("$uvec"); - -const Type* kBVec_Type = new Type("$bvec", { kBVec2_Type, kBVec2_Type, kBVec3_Type, kBVec4_Type }); - -const Type* kInvalid_Type = new Type("<INVALID>"); +const std::shared_ptr<Type> kGSampler1D_Type(new Type("$gsampler1D", type(kSampler1D_Type))); +const std::shared_ptr<Type> kGSampler2D_Type(new Type("$gsampler2D", type(kSampler2D_Type))); +const std::shared_ptr<Type> kGSampler3D_Type(new Type("$gsampler3D", type(kSampler3D_Type))); +const std::shared_ptr<Type> kGSamplerCube_Type(new Type("$gsamplerCube", type(kSamplerCube_Type))); +const std::shared_ptr<Type> kGSampler2DRect_Type(new Type("$gsampler2DRect", + type(kSampler2DRect_Type))); +const std::shared_ptr<Type> kGSampler1DArray_Type(new Type("$gsampler1DArray", + type(kSampler1DArray_Type))); +const std::shared_ptr<Type> kGSampler2DArray_Type(new Type("$gsampler2DArray", + type(kSampler2DArray_Type))); +const std::shared_ptr<Type> kGSamplerCubeArray_Type(new Type("$gsamplerCubeArray", + type(kSamplerCubeArray_Type))); +const std::shared_ptr<Type> kGSamplerBuffer_Type(new Type("$gsamplerBuffer", + type(kSamplerBuffer_Type))); +const std::shared_ptr<Type> kGSampler2DMS_Type(new Type("$gsampler2DMS", + type(kSampler2DMS_Type))); +const std::shared_ptr<Type> kGSampler2DMSArray_Type(new Type("$gsampler2DMSArray", + type(kSampler2DMSArray_Type))); +const std::shared_ptr<Type> kGSampler2DArrayShadow_Type(new Type("$gsampler2DArrayShadow", + type(kSampler2DArrayShadow_Type))); +const std::shared_ptr<Type> kGSamplerCubeArrayShadow_Type(new Type("$gsamplerCubeArrayShadow", + type(kSamplerCubeArrayShadow_Type))); + +const std::shared_ptr<Type> kGenType_Type(new Type("$genType", { kFloat_Type, kVec2_Type, + kVec3_Type, kVec4_Type })); +const std::shared_ptr<Type> kGenDType_Type(new Type("$genDType", { kDouble_Type, kDVec2_Type, + kDVec3_Type, kDVec4_Type })); +const std::shared_ptr<Type> kGenIType_Type(new Type("$genIType", { kInt_Type, kIVec2_Type, + kIVec3_Type, kIVec4_Type })); +const std::shared_ptr<Type> kGenUType_Type(new Type("$genUType", { kUInt_Type, kUVec2_Type, + kUVec3_Type, kUVec4_Type })); +const std::shared_ptr<Type> kGenBType_Type(new Type("$genBType", { kBool_Type, kBVec2_Type, + kBVec3_Type, kBVec4_Type })); + +const std::shared_ptr<Type> kMat_Type(new Type("$mat")); + +const std::shared_ptr<Type> kVec_Type(new Type("$vec", { kVec2_Type, kVec2_Type, kVec3_Type, + kVec4_Type })); + +const std::shared_ptr<Type> kGVec_Type(new Type("$gvec")); +const std::shared_ptr<Type> kGVec2_Type(new Type("$gvec2")); +const std::shared_ptr<Type> kGVec3_Type(new Type("$gvec3")); +const std::shared_ptr<Type> kGVec4_Type(new Type("$gvec4", type(kVec4_Type))); +const std::shared_ptr<Type> kDVec_Type(new Type("$dvec")); +const std::shared_ptr<Type> kIVec_Type(new Type("$ivec")); +const std::shared_ptr<Type> kUVec_Type(new Type("$uvec")); + +const std::shared_ptr<Type> kBVec_Type(new Type("$bvec", { kBVec2_Type, kBVec2_Type, + kBVec3_Type, kBVec4_Type })); + +const std::shared_ptr<Type> kInvalid_Type(new Type("<INVALID>")); } // namespace diff --git a/src/sksl/ir/SkSLType.h b/src/sksl/ir/SkSLType.h index 8a73b139ba..e17bae68db 100644 --- a/src/sksl/ir/SkSLType.h +++ b/src/sksl/ir/SkSLType.h @@ -24,18 +24,18 @@ namespace SkSL { class Type : public Symbol { public: struct Field { - Field(Modifiers modifiers, std::string name, const Type& type) + Field(Modifiers modifiers, std::string name, std::shared_ptr<Type> type) : fModifiers(modifiers) , fName(std::move(name)) , fType(std::move(type)) {} - const std::string description() const { - return fType.description() + " " + fName + ";"; + const std::string description() { + return fType->description() + " " + fName + ";"; } const Modifiers fModifiers; const std::string fName; - const Type& fType; + const std::shared_ptr<Type> fType; }; enum Kind { @@ -56,7 +56,7 @@ public: , fTypeKind(kOther_Kind) {} // Create a generic type which maps to the listed types. - Type(std::string name, std::vector<const Type*> types) + Type(std::string name, std::vector<std::shared_ptr<Type>> types) : INHERITED(Position(), kType_Kind, std::move(name)) , fTypeKind(kGeneric_Kind) , fCoercibleTypes(std::move(types)) { @@ -78,7 +78,7 @@ public: , fRows(1) {} // Create a scalar type which can be coerced to the listed types. - Type(std::string name, bool isNumber, std::vector<const Type*> coercibleTypes) + Type(std::string name, bool isNumber, std::vector<std::shared_ptr<Type>> coercibleTypes) : INHERITED(Position(), kType_Kind, std::move(name)) , fTypeKind(kScalar_Kind) , fIsNumber(isNumber) @@ -87,23 +87,23 @@ public: , fRows(1) {} // Create a vector type. - Type(std::string name, const Type& componentType, int columns) + Type(std::string name, std::shared_ptr<Type> componentType, int columns) : Type(name, kVector_Kind, componentType, columns) {} // Create a vector or array type. - Type(std::string name, Kind kind, const Type& componentType, int columns) + Type(std::string name, Kind kind, std::shared_ptr<Type> componentType, int columns) : INHERITED(Position(), kType_Kind, std::move(name)) , fTypeKind(kind) - , fComponentType(&componentType) + , fComponentType(std::move(componentType)) , fColumns(columns) , fRows(1) , fDimensions(SpvDim1D) {} // Create a matrix type. - Type(std::string name, const Type& componentType, int columns, int rows) + Type(std::string name, std::shared_ptr<Type> componentType, int columns, int rows) : INHERITED(Position(), kType_Kind, std::move(name)) , fTypeKind(kMatrix_Kind) - , fComponentType(&componentType) + , fComponentType(std::move(componentType)) , fColumns(columns) , fRows(rows) , fDimensions(SpvDim1D) {} @@ -153,7 +153,7 @@ public: * Returns true if an instance of this type can be freely coerced (implicitly converted) to * another type. */ - bool canCoerceTo(const Type& other) const { + bool canCoerceTo(std::shared_ptr<Type> other) const { int cost; return determineCoercionCost(other, &cost); } @@ -164,15 +164,15 @@ public: * costs. Returns true if a conversion is possible, false otherwise. The value of the out * parameter is undefined if false is returned. */ - bool determineCoercionCost(const Type& other, int* outCost) const; + bool determineCoercionCost(std::shared_ptr<Type> other, int* outCost) const; /** * For matrices and vectors, returns the type of individual cells (e.g. mat2 has a component * type of kFloat_Type). For all other types, causes an assertion failure. */ - const Type& componentType() const { + std::shared_ptr<Type> componentType() const { ASSERT(fComponentType); - return *fComponentType; + return fComponentType; } /** @@ -195,7 +195,7 @@ public: return fRows; } - const std::vector<Field>& fields() const { + std::vector<Field> fields() const { ASSERT(fTypeKind == kStruct_Kind); return fFields; } @@ -204,7 +204,7 @@ public: * For generic types, returns the types that this generic type can substitute for. For other * types, returns a list of other types that this type can be coerced into. */ - const std::vector<const Type*>& coercibleTypes() const { + std::vector<std::shared_ptr<Type>> coercibleTypes() const { ASSERT(fCoercibleTypes.size() > 0); return fCoercibleTypes; } @@ -257,7 +257,7 @@ public: case kStruct_Kind: { size_t result = 16; for (size_t i = 0; i < fFields.size(); i++) { - size_t alignment = fFields[i].fType.alignment(); + size_t alignment = fFields[i].fType->alignment(); if (alignment > result) { result = alignment; } @@ -300,13 +300,13 @@ public: case kStruct_Kind: { size_t total = 0; for (size_t i = 0; i < fFields.size(); i++) { - size_t alignment = fFields[i].fType.alignment(); + size_t alignment = fFields[i].fType->alignment(); if (total % alignment != 0) { total += alignment - total % alignment; } ASSERT(false); ASSERT(total % alignment == 0); - total += fFields[i].fType.size(); + total += fFields[i].fType->size(); } return total; } @@ -319,15 +319,15 @@ public: * Returns the corresponding vector or matrix type with the specified number of columns and * rows. */ - const Type& toCompound(int columns, int rows) const; + std::shared_ptr<Type> toCompound(int columns, int rows); private: typedef Symbol INHERITED; const Kind fTypeKind; const bool fIsNumber = false; - const Type* fComponentType = nullptr; - const std::vector<const Type*> fCoercibleTypes = { }; + const std::shared_ptr<Type> fComponentType = nullptr; + const std::vector<std::shared_ptr<Type>> fCoercibleTypes = { }; const int fColumns = -1; const int fRows = -1; const std::vector<Field> fFields = { }; @@ -338,100 +338,100 @@ private: const bool fIsSampled = false; }; -extern const Type* kVoid_Type; - -extern const Type* kFloat_Type; -extern const Type* kVec2_Type; -extern const Type* kVec3_Type; -extern const Type* kVec4_Type; -extern const Type* kDouble_Type; -extern const Type* kDVec2_Type; -extern const Type* kDVec3_Type; -extern const Type* kDVec4_Type; -extern const Type* kInt_Type; -extern const Type* kIVec2_Type; -extern const Type* kIVec3_Type; -extern const Type* kIVec4_Type; -extern const Type* kUInt_Type; -extern const Type* kUVec2_Type; -extern const Type* kUVec3_Type; -extern const Type* kUVec4_Type; -extern const Type* kBool_Type; -extern const Type* kBVec2_Type; -extern const Type* kBVec3_Type; -extern const Type* kBVec4_Type; - -extern const Type* kMat2x2_Type; -extern const Type* kMat2x3_Type; -extern const Type* kMat2x4_Type; -extern const Type* kMat3x2_Type; -extern const Type* kMat3x3_Type; -extern const Type* kMat3x4_Type; -extern const Type* kMat4x2_Type; -extern const Type* kMat4x3_Type; -extern const Type* kMat4x4_Type; - -extern const Type* kDMat2x2_Type; -extern const Type* kDMat2x3_Type; -extern const Type* kDMat2x4_Type; -extern const Type* kDMat3x2_Type; -extern const Type* kDMat3x3_Type; -extern const Type* kDMat3x4_Type; -extern const Type* kDMat4x2_Type; -extern const Type* kDMat4x3_Type; -extern const Type* kDMat4x4_Type; - -extern const Type* kSampler1D_Type; -extern const Type* kSampler2D_Type; -extern const Type* kSampler3D_Type; -extern const Type* kSamplerCube_Type; -extern const Type* kSampler2DRect_Type; -extern const Type* kSampler1DArray_Type; -extern const Type* kSampler2DArray_Type; -extern const Type* kSamplerCubeArray_Type; -extern const Type* kSamplerBuffer_Type; -extern const Type* kSampler2DMS_Type; -extern const Type* kSampler2DMSArray_Type; - -extern const Type* kGSampler1D_Type; -extern const Type* kGSampler2D_Type; -extern const Type* kGSampler3D_Type; -extern const Type* kGSamplerCube_Type; -extern const Type* kGSampler2DRect_Type; -extern const Type* kGSampler1DArray_Type; -extern const Type* kGSampler2DArray_Type; -extern const Type* kGSamplerCubeArray_Type; -extern const Type* kGSamplerBuffer_Type; -extern const Type* kGSampler2DMS_Type; -extern const Type* kGSampler2DMSArray_Type; - -extern const Type* kSampler1DShadow_Type; -extern const Type* kSampler2DShadow_Type; -extern const Type* kSamplerCubeShadow_Type; -extern const Type* kSampler2DRectShadow_Type; -extern const Type* kSampler1DArrayShadow_Type; -extern const Type* kSampler2DArrayShadow_Type; -extern const Type* kSamplerCubeArrayShadow_Type; -extern const Type* kGSampler2DArrayShadow_Type; -extern const Type* kGSamplerCubeArrayShadow_Type; - -extern const Type* kGenType_Type; -extern const Type* kGenDType_Type; -extern const Type* kGenIType_Type; -extern const Type* kGenUType_Type; -extern const Type* kGenBType_Type; -extern const Type* kMat_Type; -extern const Type* kVec_Type; -extern const Type* kGVec_Type; -extern const Type* kGVec2_Type; -extern const Type* kGVec3_Type; -extern const Type* kGVec4_Type; -extern const Type* kDVec_Type; -extern const Type* kIVec_Type; -extern const Type* kUVec_Type; -extern const Type* kBVec_Type; - -extern const Type* kInvalid_Type; +extern const std::shared_ptr<Type> kVoid_Type; + +extern const std::shared_ptr<Type> kFloat_Type; +extern const std::shared_ptr<Type> kVec2_Type; +extern const std::shared_ptr<Type> kVec3_Type; +extern const std::shared_ptr<Type> kVec4_Type; +extern const std::shared_ptr<Type> kDouble_Type; +extern const std::shared_ptr<Type> kDVec2_Type; +extern const std::shared_ptr<Type> kDVec3_Type; +extern const std::shared_ptr<Type> kDVec4_Type; +extern const std::shared_ptr<Type> kInt_Type; +extern const std::shared_ptr<Type> kIVec2_Type; +extern const std::shared_ptr<Type> kIVec3_Type; +extern const std::shared_ptr<Type> kIVec4_Type; +extern const std::shared_ptr<Type> kUInt_Type; +extern const std::shared_ptr<Type> kUVec2_Type; +extern const std::shared_ptr<Type> kUVec3_Type; +extern const std::shared_ptr<Type> kUVec4_Type; +extern const std::shared_ptr<Type> kBool_Type; +extern const std::shared_ptr<Type> kBVec2_Type; +extern const std::shared_ptr<Type> kBVec3_Type; +extern const std::shared_ptr<Type> kBVec4_Type; + +extern const std::shared_ptr<Type> kMat2x2_Type; +extern const std::shared_ptr<Type> kMat2x3_Type; +extern const std::shared_ptr<Type> kMat2x4_Type; +extern const std::shared_ptr<Type> kMat3x2_Type; +extern const std::shared_ptr<Type> kMat3x3_Type; +extern const std::shared_ptr<Type> kMat3x4_Type; +extern const std::shared_ptr<Type> kMat4x2_Type; +extern const std::shared_ptr<Type> kMat4x3_Type; +extern const std::shared_ptr<Type> kMat4x4_Type; + +extern const std::shared_ptr<Type> kDMat2x2_Type; +extern const std::shared_ptr<Type> kDMat2x3_Type; +extern const std::shared_ptr<Type> kDMat2x4_Type; +extern const std::shared_ptr<Type> kDMat3x2_Type; +extern const std::shared_ptr<Type> kDMat3x3_Type; +extern const std::shared_ptr<Type> kDMat3x4_Type; +extern const std::shared_ptr<Type> kDMat4x2_Type; +extern const std::shared_ptr<Type> kDMat4x3_Type; +extern const std::shared_ptr<Type> kDMat4x4_Type; + +extern const std::shared_ptr<Type> kSampler1D_Type; +extern const std::shared_ptr<Type> kSampler2D_Type; +extern const std::shared_ptr<Type> kSampler3D_Type; +extern const std::shared_ptr<Type> kSamplerCube_Type; +extern const std::shared_ptr<Type> kSampler2DRect_Type; +extern const std::shared_ptr<Type> kSampler1DArray_Type; +extern const std::shared_ptr<Type> kSampler2DArray_Type; +extern const std::shared_ptr<Type> kSamplerCubeArray_Type; +extern const std::shared_ptr<Type> kSamplerBuffer_Type; +extern const std::shared_ptr<Type> kSampler2DMS_Type; +extern const std::shared_ptr<Type> kSampler2DMSArray_Type; + +extern const std::shared_ptr<Type> kGSampler1D_Type; +extern const std::shared_ptr<Type> kGSampler2D_Type; +extern const std::shared_ptr<Type> kGSampler3D_Type; +extern const std::shared_ptr<Type> kGSamplerCube_Type; +extern const std::shared_ptr<Type> kGSampler2DRect_Type; +extern const std::shared_ptr<Type> kGSampler1DArray_Type; +extern const std::shared_ptr<Type> kGSampler2DArray_Type; +extern const std::shared_ptr<Type> kGSamplerCubeArray_Type; +extern const std::shared_ptr<Type> kGSamplerBuffer_Type; +extern const std::shared_ptr<Type> kGSampler2DMS_Type; +extern const std::shared_ptr<Type> kGSampler2DMSArray_Type; + +extern const std::shared_ptr<Type> kSampler1DShadow_Type; +extern const std::shared_ptr<Type> kSampler2DShadow_Type; +extern const std::shared_ptr<Type> kSamplerCubeShadow_Type; +extern const std::shared_ptr<Type> kSampler2DRectShadow_Type; +extern const std::shared_ptr<Type> kSampler1DArrayShadow_Type; +extern const std::shared_ptr<Type> kSampler2DArrayShadow_Type; +extern const std::shared_ptr<Type> kSamplerCubeArrayShadow_Type; +extern const std::shared_ptr<Type> kGSampler2DArrayShadow_Type; +extern const std::shared_ptr<Type> kGSamplerCubeArrayShadow_Type; + +extern const std::shared_ptr<Type> kGenType_Type; +extern const std::shared_ptr<Type> kGenDType_Type; +extern const std::shared_ptr<Type> kGenIType_Type; +extern const std::shared_ptr<Type> kGenUType_Type; +extern const std::shared_ptr<Type> kGenBType_Type; +extern const std::shared_ptr<Type> kMat_Type; +extern const std::shared_ptr<Type> kVec_Type; +extern const std::shared_ptr<Type> kGVec_Type; +extern const std::shared_ptr<Type> kGVec2_Type; +extern const std::shared_ptr<Type> kGVec3_Type; +extern const std::shared_ptr<Type> kGVec4_Type; +extern const std::shared_ptr<Type> kDVec_Type; +extern const std::shared_ptr<Type> kIVec_Type; +extern const std::shared_ptr<Type> kUVec_Type; +extern const std::shared_ptr<Type> kBVec_Type; + +extern const std::shared_ptr<Type> kInvalid_Type; } // namespace diff --git a/src/sksl/ir/SkSLTypeReference.h b/src/sksl/ir/SkSLTypeReference.h index 8f2ab4fe0e..5f4990f35d 100644 --- a/src/sksl/ir/SkSLTypeReference.h +++ b/src/sksl/ir/SkSLTypeReference.h @@ -17,16 +17,16 @@ namespace SkSL { * always eventually replaced by Constructors in valid programs. */ struct TypeReference : public Expression { - TypeReference(Position position, const Type& type) - : INHERITED(position, kTypeReference_Kind, *kInvalid_Type) - , fValue(type) {} + TypeReference(Position position, std::shared_ptr<Type> type) + : INHERITED(position, kTypeReference_Kind, kInvalid_Type) + , fValue(std::move(type)) {} std::string description() const override { ASSERT(false); return "<type>"; } - const Type& fValue; + const std::shared_ptr<Type> fValue; typedef Expression INHERITED; }; diff --git a/src/sksl/ir/SkSLUnresolvedFunction.h b/src/sksl/ir/SkSLUnresolvedFunction.h index 3a368ad8d3..a6cee0d072 100644 --- a/src/sksl/ir/SkSLUnresolvedFunction.h +++ b/src/sksl/ir/SkSLUnresolvedFunction.h @@ -16,21 +16,19 @@ namespace SkSL { * A symbol representing multiple functions with the same name. */ struct UnresolvedFunction : public Symbol { - UnresolvedFunction(std::vector<const FunctionDeclaration*> funcs) + UnresolvedFunction(std::vector<std::shared_ptr<FunctionDeclaration>> funcs) : INHERITED(Position(), kUnresolvedFunction_Kind, funcs[0]->fName) , fFunctions(std::move(funcs)) { -#ifdef DEBUG for (auto func : funcs) { ASSERT(func->fName == fName); } -#endif } virtual std::string description() const override { return fName; } - const std::vector<const FunctionDeclaration*> fFunctions; + const std::vector<std::shared_ptr<FunctionDeclaration>> fFunctions; typedef Symbol INHERITED; }; diff --git a/src/sksl/ir/SkSLVarDeclaration.h b/src/sksl/ir/SkSLVarDeclaration.h index b234231b86..400f430e4c 100644 --- a/src/sksl/ir/SkSLVarDeclaration.h +++ b/src/sksl/ir/SkSLVarDeclaration.h @@ -20,7 +20,7 @@ namespace SkSL { * names ['x', 'y', 'z'], sizes of [[], [], [4, 2]], and values of [null, 1, null]. */ struct VarDeclaration : public ProgramElement { - VarDeclaration(Position position, std::vector<const Variable*> vars, + VarDeclaration(Position position, std::vector<std::shared_ptr<Variable>> vars, std::vector<std::vector<std::unique_ptr<Expression>>> sizes, std::vector<std::unique_ptr<Expression>> values) : INHERITED(position, kVar_Kind) @@ -30,9 +30,9 @@ struct VarDeclaration : public ProgramElement { std::string description() const override { std::string result = fVars[0]->fModifiers.description(); - const Type* baseType = &fVars[0]->fType; + std::shared_ptr<Type> baseType = fVars[0]->fType; while (baseType->kind() == Type::kArray_Kind) { - baseType = &baseType->componentType(); + baseType = baseType->componentType(); } result += baseType->description(); std::string separator = " "; @@ -55,7 +55,7 @@ struct VarDeclaration : public ProgramElement { return result; } - const std::vector<const Variable*> fVars; + const std::vector<std::shared_ptr<Variable>> fVars; const std::vector<std::vector<std::unique_ptr<Expression>>> fSizes; const std::vector<std::unique_ptr<Expression>> fValues; diff --git a/src/sksl/ir/SkSLVariable.h b/src/sksl/ir/SkSLVariable.h index 39af3093b6..d4ea2c4a43 100644 --- a/src/sksl/ir/SkSLVariable.h +++ b/src/sksl/ir/SkSLVariable.h @@ -27,7 +27,7 @@ struct Variable : public Symbol { kParameter_Storage }; - Variable(Position position, Modifiers modifiers, std::string name, const Type& type, + Variable(Position position, Modifiers modifiers, std::string name, std::shared_ptr<Type> type, Storage storage) : INHERITED(position, kVariable_Kind, std::move(name)) , fModifiers(modifiers) @@ -37,11 +37,12 @@ struct Variable : public Symbol { , fIsWrittenTo(false) {} virtual std::string description() const override { - return fModifiers.description() + fType.fName + " " + fName; + return fModifiers.description() + fType->fName + " " + fName; } const Modifiers fModifiers; - const Type& fType; + const std::string fValue; + const std::shared_ptr<Type> fType; const Storage fStorage; mutable bool fIsReadFrom; @@ -52,4 +53,14 @@ struct Variable : public Symbol { } // namespace SkSL +namespace std { + template <> + struct hash<SkSL::Variable> { + public : + size_t operator()(const SkSL::Variable &var) const{ + return hash<std::string>()(var.fName) ^ hash<std::string>()(var.fType->description()); + } + }; +} // namespace std + #endif diff --git a/src/sksl/ir/SkSLVariableReference.h b/src/sksl/ir/SkSLVariableReference.h index b443da1f22..8499511a1b 100644 --- a/src/sksl/ir/SkSLVariableReference.h +++ b/src/sksl/ir/SkSLVariableReference.h @@ -20,15 +20,15 @@ namespace SkSL { * there is only one Variable 'x', but two VariableReferences to it. */ struct VariableReference : public Expression { - VariableReference(Position position, const Variable& variable) - : INHERITED(position, kVariableReference_Kind, variable.fType) - , fVariable(variable) {} + VariableReference(Position position, std::shared_ptr<Variable> variable) + : INHERITED(position, kVariableReference_Kind, variable->fType) + , fVariable(std::move(variable)) {} std::string description() const override { - return fVariable.fName; + return fVariable->fName; } - const Variable& fVariable; + const std::shared_ptr<Variable> fVariable; typedef Expression INHERITED; }; |