diff options
31 files changed, 424 insertions, 223 deletions
diff --git a/src/sksl/SkSLCFGGenerator.cpp b/src/sksl/SkSLCFGGenerator.cpp index 964a8dc84a..31bace9fb7 100644 --- a/src/sksl/SkSLCFGGenerator.cpp +++ b/src/sksl/SkSLCFGGenerator.cpp @@ -54,8 +54,8 @@ void CFG::dump() { printf("Block %d\n-------\nBefore: ", (int) i); const char* separator = ""; for (auto iter = fBlocks[i].fBefore.begin(); iter != fBlocks[i].fBefore.end(); iter++) { - printf("%s%s = %s", separator, iter->first->description().c_str(), - iter->second ? iter->second->description().c_str() : "<undefined>"); + printf("%s%s = %s", separator, iter->first->description().c_str(), + *iter->second ? (*iter->second)->description().c_str() : "<undefined>"); separator = ", "; } printf("\nEntrances: "); @@ -66,7 +66,10 @@ void CFG::dump() { } printf("\n"); for (size_t j = 0; j < fBlocks[i].fNodes.size(); j++) { - printf("Node %d: %s\n", (int) j, fBlocks[i].fNodes[j].fNode->description().c_str()); + BasicBlock::Node& n = fBlocks[i].fNodes[j]; + printf("Node %d: %s\n", (int) j, n.fKind == BasicBlock::Node::kExpression_Kind + ? (*n.fExpression)->description().c_str() + : n.fStatement->description().c_str()); } printf("Exits: "); separator = ""; @@ -78,96 +81,109 @@ void CFG::dump() { } } -void CFGGenerator::addExpression(CFG& cfg, const Expression* e) { - switch (e->fKind) { +void CFGGenerator::addExpression(CFG& cfg, std::unique_ptr<Expression>* e, bool constantPropagate) { + ASSERT(e); + switch ((*e)->fKind) { case Expression::kBinary_Kind: { - const BinaryExpression* b = (const BinaryExpression*) e; + BinaryExpression* b = (BinaryExpression*) e->get(); switch (b->fOperator) { case Token::LOGICALAND: // fall through case Token::LOGICALOR: { // this isn't as precise as it could be -- we don't bother to track that if we // early exit from a logical and/or, we know which branch of an 'if' we're going // to hit -- but it won't make much difference in practice. - this->addExpression(cfg, b->fLeft.get()); + this->addExpression(cfg, &b->fLeft, constantPropagate); BlockId start = cfg.fCurrent; cfg.newBlock(); - this->addExpression(cfg, b->fRight.get()); + this->addExpression(cfg, &b->fRight, constantPropagate); cfg.newBlock(); cfg.addExit(start, cfg.fCurrent); break; } case Token::EQ: { - this->addExpression(cfg, b->fRight.get()); - this->addLValue(cfg, b->fLeft.get()); - cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ - BasicBlock::Node::kExpression_Kind, - b + this->addExpression(cfg, &b->fRight, constantPropagate); + this->addLValue(cfg, &b->fLeft); + cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ + BasicBlock::Node::kExpression_Kind, + constantPropagate, + e, + nullptr }); break; } default: - this->addExpression(cfg, b->fLeft.get()); - this->addExpression(cfg, b->fRight.get()); - cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ - BasicBlock::Node::kExpression_Kind, - b + this->addExpression(cfg, &b->fLeft, !Token::IsAssignment(b->fOperator)); + this->addExpression(cfg, &b->fRight, constantPropagate); + cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ + BasicBlock::Node::kExpression_Kind, + constantPropagate, + e, + nullptr }); } break; } case Expression::kConstructor_Kind: { - const Constructor* c = (const Constructor*) e; - for (const auto& arg : c->fArguments) { - this->addExpression(cfg, arg.get()); + Constructor* c = (Constructor*) e->get(); + for (auto& arg : c->fArguments) { + this->addExpression(cfg, &arg, constantPropagate); } - cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, c }); + cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, + constantPropagate, e, nullptr }); break; } case Expression::kFunctionCall_Kind: { - const FunctionCall* c = (const FunctionCall*) e; - for (const auto& arg : c->fArguments) { - this->addExpression(cfg, arg.get()); + FunctionCall* c = (FunctionCall*) e->get(); + for (auto& arg : c->fArguments) { + this->addExpression(cfg, &arg, constantPropagate); } - cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, c }); + cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, + constantPropagate, e, nullptr }); break; } case Expression::kFieldAccess_Kind: - this->addExpression(cfg, ((const FieldAccess*) e)->fBase.get()); - cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, e }); + this->addExpression(cfg, &((FieldAccess*) e->get())->fBase, constantPropagate); + cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, + constantPropagate, e, nullptr }); break; case Expression::kIndex_Kind: - this->addExpression(cfg, ((const IndexExpression*) e)->fBase.get()); - this->addExpression(cfg, ((const IndexExpression*) e)->fIndex.get()); - cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, e }); + this->addExpression(cfg, &((IndexExpression*) e->get())->fBase, constantPropagate); + this->addExpression(cfg, &((IndexExpression*) e->get())->fIndex, constantPropagate); + cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, + constantPropagate, e, nullptr }); break; case Expression::kPrefix_Kind: - this->addExpression(cfg, ((const PrefixExpression*) e)->fOperand.get()); - cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, e }); + this->addExpression(cfg, &((PrefixExpression*) e->get())->fOperand, constantPropagate); + cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, + constantPropagate, e, nullptr }); break; case Expression::kPostfix_Kind: - this->addExpression(cfg, ((const PostfixExpression*) e)->fOperand.get()); - cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, e }); + this->addExpression(cfg, &((PostfixExpression*) e->get())->fOperand, constantPropagate); + cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, + constantPropagate, e, nullptr }); break; case Expression::kSwizzle_Kind: - this->addExpression(cfg, ((const Swizzle*) e)->fBase.get()); - cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, e }); + this->addExpression(cfg, &((Swizzle*) e->get())->fBase, constantPropagate); + cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, + constantPropagate, e, nullptr }); break; case Expression::kBoolLiteral_Kind: // fall through case Expression::kFloatLiteral_Kind: // fall through case Expression::kIntLiteral_Kind: // fall through case Expression::kVariableReference_Kind: - cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, e }); + cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, + constantPropagate, e, nullptr }); break; case Expression::kTernary_Kind: { - const TernaryExpression* t = (const TernaryExpression*) e; - this->addExpression(cfg, t->fTest.get()); + TernaryExpression* t = (TernaryExpression*) e->get(); + this->addExpression(cfg, &t->fTest, constantPropagate); BlockId start = cfg.fCurrent; cfg.newBlock(); - this->addExpression(cfg, t->fIfTrue.get()); + this->addExpression(cfg, &t->fIfTrue, constantPropagate); BlockId next = cfg.newBlock(); cfg.fCurrent = start; cfg.newBlock(); - this->addExpression(cfg, t->fIfFalse.get()); + this->addExpression(cfg, &t->fIfFalse, constantPropagate); cfg.addExit(cfg.fCurrent, next); cfg.fCurrent = next; break; @@ -181,17 +197,17 @@ void CFGGenerator::addExpression(CFG& cfg, const Expression* e) { } // adds expressions that are evaluated as part of resolving an lvalue -void CFGGenerator::addLValue(CFG& cfg, const Expression* e) { - switch (e->fKind) { +void CFGGenerator::addLValue(CFG& cfg, std::unique_ptr<Expression>* e) { + switch ((*e)->fKind) { case Expression::kFieldAccess_Kind: - this->addLValue(cfg, ((const FieldAccess*) e)->fBase.get()); + this->addLValue(cfg, &((FieldAccess&) **e).fBase); break; case Expression::kIndex_Kind: - this->addLValue(cfg, ((const IndexExpression*) e)->fBase.get()); - this->addExpression(cfg, ((const IndexExpression*) e)->fIndex.get()); + this->addLValue(cfg, &((IndexExpression&) **e).fBase); + this->addExpression(cfg, &((IndexExpression&) **e).fIndex, true); break; case Expression::kSwizzle_Kind: - this->addLValue(cfg, ((const Swizzle*) e)->fBase.get()); + this->addLValue(cfg, &((Swizzle&) **e).fBase); break; case Expression::kVariableReference_Kind: break; @@ -210,8 +226,8 @@ void CFGGenerator::addStatement(CFG& cfg, const Statement* s) { } break; case Statement::kIf_Kind: { - const IfStatement* ifs = (const IfStatement*) s; - this->addExpression(cfg, ifs->fTest.get()); + IfStatement* ifs = (IfStatement*) s; + this->addExpression(cfg, &ifs->fTest, true); BlockId start = cfg.fCurrent; cfg.newBlock(); this->addStatement(cfg, ifs->fIfTrue.get()); @@ -228,49 +244,54 @@ void CFGGenerator::addStatement(CFG& cfg, const Statement* s) { break; } case Statement::kExpression_Kind: { - this->addExpression(cfg, ((ExpressionStatement&) *s).fExpression.get()); + this->addExpression(cfg, &((ExpressionStatement&) *s).fExpression, true); break; } case Statement::kVarDeclarations_Kind: { - const VarDeclarationsStatement& decls = ((VarDeclarationsStatement&) *s); - for (const auto& vd : decls.fDeclaration->fVars) { + VarDeclarationsStatement& decls = ((VarDeclarationsStatement&) *s); + for (auto& vd : decls.fDeclaration->fVars) { if (vd.fValue) { - this->addExpression(cfg, vd.fValue.get()); + this->addExpression(cfg, &vd.fValue, true); } } - cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kStatement_Kind, s }); + cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kStatement_Kind, false, + nullptr, s }); break; } case Statement::kDiscard_Kind: - cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kStatement_Kind, s }); + cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kStatement_Kind, false, + nullptr, s }); cfg.fCurrent = cfg.newIsolatedBlock(); break; case Statement::kReturn_Kind: { - const ReturnStatement& r = ((ReturnStatement&) *s); + ReturnStatement& r = ((ReturnStatement&) *s); if (r.fExpression) { - this->addExpression(cfg, r.fExpression.get()); + this->addExpression(cfg, &r.fExpression, true); } - cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kStatement_Kind, s }); + cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kStatement_Kind, false, + nullptr, s }); cfg.fCurrent = cfg.newIsolatedBlock(); break; } case Statement::kBreak_Kind: - cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kStatement_Kind, s }); + cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kStatement_Kind, false, + nullptr, s }); cfg.addExit(cfg.fCurrent, fLoopExits.top()); cfg.fCurrent = cfg.newIsolatedBlock(); break; case Statement::kContinue_Kind: - cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kStatement_Kind, s }); + cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kStatement_Kind, false, + nullptr, s }); cfg.addExit(cfg.fCurrent, fLoopContinues.top()); cfg.fCurrent = cfg.newIsolatedBlock(); break; case Statement::kWhile_Kind: { - const WhileStatement* w = (const WhileStatement*) s; + WhileStatement* w = (WhileStatement*) s; BlockId loopStart = cfg.newBlock(); fLoopContinues.push(loopStart); BlockId loopExit = cfg.newIsolatedBlock(); fLoopExits.push(loopExit); - this->addExpression(cfg, w->fTest.get()); + this->addExpression(cfg, &w->fTest, true); BlockId test = cfg.fCurrent; cfg.addExit(test, loopExit); cfg.newBlock(); @@ -282,13 +303,13 @@ void CFGGenerator::addStatement(CFG& cfg, const Statement* s) { break; } case Statement::kDo_Kind: { - const DoStatement* d = (const DoStatement*) s; + DoStatement* d = (DoStatement*) s; BlockId loopStart = cfg.newBlock(); fLoopContinues.push(loopStart); BlockId loopExit = cfg.newIsolatedBlock(); fLoopExits.push(loopExit); this->addStatement(cfg, d->fStatement.get()); - this->addExpression(cfg, d->fTest.get()); + this->addExpression(cfg, &d->fTest, true); cfg.addExit(cfg.fCurrent, loopExit); cfg.addExit(cfg.fCurrent, loopStart); fLoopContinues.pop(); @@ -297,7 +318,7 @@ void CFGGenerator::addStatement(CFG& cfg, const Statement* s) { break; } case Statement::kFor_Kind: { - const ForStatement* f = (const ForStatement*) s; + ForStatement* f = (ForStatement*) s; if (f->fInitializer) { this->addStatement(cfg, f->fInitializer.get()); } @@ -307,7 +328,7 @@ void CFGGenerator::addStatement(CFG& cfg, const Statement* s) { BlockId loopExit = cfg.newIsolatedBlock(); fLoopExits.push(loopExit); if (f->fTest) { - this->addExpression(cfg, f->fTest.get()); + this->addExpression(cfg, &f->fTest, true); BlockId test = cfg.fCurrent; cfg.addExit(test, loopExit); } @@ -316,9 +337,9 @@ void CFGGenerator::addStatement(CFG& cfg, const Statement* s) { cfg.addExit(cfg.fCurrent, next); cfg.fCurrent = next; if (f->fNext) { - this->addExpression(cfg, f->fNext.get()); + this->addExpression(cfg, &f->fNext, true); } - cfg.addExit(next, loopStart); + cfg.addExit(cfg.fCurrent, loopStart); fLoopContinues.pop(); fLoopExits.pop(); cfg.fCurrent = loopExit; diff --git a/src/sksl/SkSLCFGGenerator.h b/src/sksl/SkSLCFGGenerator.h index c37850112c..337fdfac35 100644 --- a/src/sksl/SkSLCFGGenerator.h +++ b/src/sksl/SkSLCFGGenerator.h @@ -27,14 +27,23 @@ struct BasicBlock { }; Kind fKind; - const IRNode* fNode; + // if false, this node should not be subject to constant propagation. This happens with + // compound assignment (i.e. x *= 2), in which the value x is used as an rvalue for + // multiplication by 2 and then as an lvalue for assignment purposes. Since there is only + // one "x" node, replacing it with a constant would break the assignment and we suppress + // it. Down the road, we should handle this more elegantly by substituting a regular + // assignment if the target is constant (i.e. x = 1; x *= 2; should become x = 1; x = 1 * 2; + // and then collapse down to a simple x = 2;). + bool fConstantPropagation; + std::unique_ptr<Expression>* fExpression; + const Statement* fStatement; }; - + std::vector<Node> fNodes; std::set<BlockId> fEntrances; std::set<BlockId> fExits; // variable definitions upon entering this basic block (null expression = undefined) - std::unordered_map<const Variable*, const Expression*> fBefore; + DefinitionMap fBefore; }; struct CFG { @@ -77,9 +86,9 @@ public: private: void addStatement(CFG& cfg, const Statement* s); - void addExpression(CFG& cfg, const Expression* e); + void addExpression(CFG& cfg, std::unique_ptr<Expression>* e, bool constantPropagate); - void addLValue(CFG& cfg, const Expression* e); + void addLValue(CFG& cfg, std::unique_ptr<Expression>* e); std::stack<BlockId> fLoopContinues; std::stack<BlockId> fLoopExits; diff --git a/src/sksl/SkSLCompiler.cpp b/src/sksl/SkSLCompiler.cpp index 9faf836156..743745ad14 100644 --- a/src/sksl/SkSLCompiler.cpp +++ b/src/sksl/SkSLCompiler.cpp @@ -156,8 +156,8 @@ Compiler::~Compiler() { } // add the definition created by assigning to the lvalue to the definition set -void Compiler::addDefinition(const Expression* lvalue, const Expression* expr, - std::unordered_map<const Variable*, const Expression*>* definitions) { +void Compiler::addDefinition(const Expression* lvalue, std::unique_ptr<Expression>* expr, + DefinitionMap* definitions) { switch (lvalue->fKind) { case Expression::kVariableReference_Kind: { const Variable& var = ((VariableReference*) lvalue)->fVariable; @@ -174,19 +174,19 @@ void Compiler::addDefinition(const Expression* lvalue, const Expression* expr, // but since we pass foo as a whole it is flagged as an error) unless we perform a much // more complicated whole-program analysis. This is probably good enough. this->addDefinition(((Swizzle*) lvalue)->fBase.get(), - fContext.fDefined_Expression.get(), + (std::unique_ptr<Expression>*) &fContext.fDefined_Expression, definitions); break; case Expression::kIndex_Kind: // see comments in Swizzle this->addDefinition(((IndexExpression*) lvalue)->fBase.get(), - fContext.fDefined_Expression.get(), + (std::unique_ptr<Expression>*) &fContext.fDefined_Expression, definitions); break; case Expression::kFieldAccess_Kind: // see comments in Swizzle this->addDefinition(((FieldAccess*) lvalue)->fBase.get(), - fContext.fDefined_Expression.get(), + (std::unique_ptr<Expression>*) &fContext.fDefined_Expression, definitions); break; default: @@ -197,25 +197,58 @@ void Compiler::addDefinition(const Expression* lvalue, const Expression* expr, // add local variables defined by this node to the set void Compiler::addDefinitions(const BasicBlock::Node& node, - std::unordered_map<const Variable*, const Expression*>* definitions) { + DefinitionMap* definitions) { switch (node.fKind) { case BasicBlock::Node::kExpression_Kind: { - const Expression* expr = (Expression*) node.fNode; - if (expr->fKind == Expression::kBinary_Kind) { - const BinaryExpression* b = (BinaryExpression*) expr; - if (b->fOperator == Token::EQ) { - this->addDefinition(b->fLeft.get(), b->fRight.get(), definitions); + ASSERT(node.fExpression); + const Expression* expr = (Expression*) node.fExpression->get(); + switch (expr->fKind) { + case Expression::kBinary_Kind: { + BinaryExpression* b = (BinaryExpression*) expr; + if (b->fOperator == Token::EQ) { + this->addDefinition(b->fLeft.get(), &b->fRight, definitions); + } else if (Token::IsAssignment(b->fOperator)) { + this->addDefinition( + b->fLeft.get(), + (std::unique_ptr<Expression>*) &fContext.fDefined_Expression, + definitions); + + } + break; + } + case Expression::kPrefix_Kind: { + const PrefixExpression* p = (PrefixExpression*) expr; + if (p->fOperator == Token::MINUSMINUS || p->fOperator == Token::PLUSPLUS) { + this->addDefinition( + p->fOperand.get(), + (std::unique_ptr<Expression>*) &fContext.fDefined_Expression, + definitions); + } + break; } + case Expression::kPostfix_Kind: { + const PostfixExpression* p = (PostfixExpression*) expr; + if (p->fOperator == Token::MINUSMINUS || p->fOperator == Token::PLUSPLUS) { + this->addDefinition( + p->fOperand.get(), + (std::unique_ptr<Expression>*) &fContext.fDefined_Expression, + definitions); + + } + break; + } + default: + break; } break; } case BasicBlock::Node::kStatement_Kind: { - const Statement* stmt = (Statement*) node.fNode; + const Statement* stmt = (Statement*) node.fStatement; if (stmt->fKind == Statement::kVarDeclarations_Kind) { - const VarDeclarationsStatement* vd = (VarDeclarationsStatement*) stmt; - for (const VarDeclaration& decl : vd->fDeclaration->fVars) { + VarDeclarationsStatement* vd = (VarDeclarationsStatement*) stmt; + for (VarDeclaration& decl : vd->fDeclaration->fVars) { if (decl.fValue) { - (*definitions)[decl.fVar] = decl.fValue.get(); + (*definitions)[decl.fVar] = &decl.fValue; } } } @@ -228,7 +261,7 @@ void Compiler::scanCFG(CFG* cfg, BlockId blockId, std::set<BlockId>* workList) { BasicBlock& block = cfg->fBlocks[blockId]; // compute definitions after this block - std::unordered_map<const Variable*, const Expression*> after = block.fBefore; + DefinitionMap after = block.fBefore; for (const BasicBlock::Node& n : block.fNodes) { this->addDefinitions(n, &after); } @@ -237,19 +270,20 @@ void Compiler::scanCFG(CFG* cfg, BlockId blockId, std::set<BlockId>* workList) { for (BlockId exitId : block.fExits) { BasicBlock& exit = cfg->fBlocks[exitId]; for (const auto& pair : after) { - const Expression* e1 = pair.second; - if (exit.fBefore.find(pair.first) == exit.fBefore.end()) { + std::unique_ptr<Expression>* e1 = pair.second; + auto found = exit.fBefore.find(pair.first); + if (found == exit.fBefore.end()) { + // exit has no definition for it, just copy it + workList->insert(exitId); exit.fBefore[pair.first] = e1; } else { - const Expression* e2 = exit.fBefore[pair.first]; + // exit has a (possibly different) value already defined + std::unique_ptr<Expression>* e2 = exit.fBefore[pair.first]; if (e1 != e2) { // definition has changed, merge and add exit block to worklist workList->insert(exitId); - if (!e1 || !e2) { - exit.fBefore[pair.first] = nullptr; - } else { - exit.fBefore[pair.first] = fContext.fDefined_Expression.get(); - } + exit.fBefore[pair.first] = + (std::unique_ptr<Expression>*) &fContext.fDefined_Expression; } } } @@ -258,12 +292,13 @@ void Compiler::scanCFG(CFG* cfg, BlockId blockId, std::set<BlockId>* workList) { // returns a map which maps all local variables in the function to null, indicating that their value // is initially unknown -static std::unordered_map<const Variable*, const Expression*> compute_start_state(const CFG& cfg) { - std::unordered_map<const Variable*, const Expression*> result; +static DefinitionMap compute_start_state(const CFG& cfg) { + DefinitionMap result; for (const auto& block : cfg.fBlocks) { for (const auto& node : block.fNodes) { if (node.fKind == BasicBlock::Node::kStatement_Kind) { - const Statement* s = (Statement*) node.fNode; + ASSERT(node.fStatement); + const Statement* s = node.fStatement; if (s->fKind == Statement::kVarDeclarations_Kind) { const VarDeclarationsStatement* vd = (const VarDeclarationsStatement*) s; for (const VarDeclaration& decl : vd->fDeclaration->fVars) { @@ -295,19 +330,37 @@ void Compiler::scanCFG(const FunctionDefinition& f) { for (size_t i = 0; i < cfg.fBlocks.size(); i++) { if (i != cfg.fStart && !cfg.fBlocks[i].fEntrances.size() && cfg.fBlocks[i].fNodes.size()) { - this->error(cfg.fBlocks[i].fNodes[0].fNode->fPosition, SkString("unreachable")); + Position p; + switch (cfg.fBlocks[i].fNodes[0].fKind) { + case BasicBlock::Node::kStatement_Kind: + p = cfg.fBlocks[i].fNodes[0].fStatement->fPosition; + break; + case BasicBlock::Node::kExpression_Kind: + p = (*cfg.fBlocks[i].fNodes[0].fExpression)->fPosition; + break; + } + this->error(p, SkString("unreachable")); } } if (fErrorCount) { return; } - // check for undefined variables - for (const BasicBlock& b : cfg.fBlocks) { - std::unordered_map<const Variable*, const Expression*> definitions = b.fBefore; - for (const BasicBlock::Node& n : b.fNodes) { + // check for undefined variables, perform constant propagation + for (BasicBlock& b : cfg.fBlocks) { + DefinitionMap definitions = b.fBefore; + for (BasicBlock::Node& n : b.fNodes) { if (n.fKind == BasicBlock::Node::kExpression_Kind) { - const Expression* expr = (const Expression*) n.fNode; + ASSERT(n.fExpression); + Expression* expr = n.fExpression->get(); + if (n.fConstantPropagation) { + std::unique_ptr<Expression> optimized = expr->constantPropagate(*fIRGenerator, + definitions); + if (optimized) { + n.fExpression->reset(optimized.release()); + expr = n.fExpression->get(); + } + } if (expr->fKind == Expression::kVariableReference_Kind) { const Variable& var = ((VariableReference*) expr)->fVariable; if (var.fStorage == Variable::kLocal_Storage && diff --git a/src/sksl/SkSLCompiler.h b/src/sksl/SkSLCompiler.h index 0f893f7e64..fdca12d2cf 100644 --- a/src/sksl/SkSLCompiler.h +++ b/src/sksl/SkSLCompiler.h @@ -60,11 +60,10 @@ public: } private: - void addDefinition(const Expression* lvalue, const Expression* expr, - std::unordered_map<const Variable*, const Expression*>* definitions); + void addDefinition(const Expression* lvalue, std::unique_ptr<Expression>* expr, + DefinitionMap* definitions); - void addDefinitions(const BasicBlock::Node& node, - std::unordered_map<const Variable*, const Expression*>* definitions); + void addDefinitions(const BasicBlock::Node& node, DefinitionMap* definitions); void scanCFG(CFG* cfg, BlockId block, std::set<BlockId>* workList); diff --git a/src/sksl/SkSLIRGenerator.cpp b/src/sksl/SkSLIRGenerator.cpp index 9f06c97d11..55d9d2c8d6 100644 --- a/src/sksl/SkSLIRGenerator.cpp +++ b/src/sksl/SkSLIRGenerator.cpp @@ -551,11 +551,11 @@ std::unique_ptr<InterfaceBlock> IRGenerator::convertInterfaceBlock(const ASTInte } } Type* type = new Type(intf.fPosition, intf.fInterfaceName, fields); - fSymbolTable->takeOwnership(type); + old->takeOwnership(type); SkString name = intf.fValueName.size() > 0 ? intf.fValueName : intf.fInterfaceName; Variable* var = new Variable(intf.fPosition, intf.fModifiers, name, *type, Variable::kGlobal_Storage); - fSymbolTable->takeOwnership(var); + old->takeOwnership(var); if (intf.fValueName.size()) { old->addWithoutOwnership(intf.fValueName, var); } else { @@ -624,19 +624,22 @@ std::unique_ptr<Expression> IRGenerator::convertIdentifier(const ASTIdentifier& f->fFunctions)); } case Symbol::kVariable_Kind: { - const Variable* var = (const Variable*) result; - this->markReadFrom(*var); + Variable* var = (Variable*) result; if (var->fModifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN && fSettings->fFlipY && (!fSettings->fCaps || !fSettings->fCaps->fragCoordConventionsExtensionString())) { fInputs.fRTHeight = true; } - return std::unique_ptr<VariableReference>(new VariableReference(identifier.fPosition, - *var)); + // default to kRead_RefKind; this will be corrected later if the variable is written to + return std::unique_ptr<VariableReference>(new VariableReference( + identifier.fPosition, + *var, + VariableReference::kRead_RefKind)); } case Symbol::kField_Kind: { const Field* field = (const Field*) result; - VariableReference* base = new VariableReference(identifier.fPosition, field->fOwner); + VariableReference* base = new VariableReference(identifier.fPosition, field->fOwner, + VariableReference::kRead_RefKind); return std::unique_ptr<Expression>(new FieldAccess( std::unique_ptr<Expression>(base), field->fFieldIndex, @@ -690,28 +693,6 @@ static bool is_matrix_multiply(const Type& left, const Type& right) { return left.kind() == Type::kVector_Kind && right.kind() == Type::kMatrix_Kind; } -static bool is_assignment(Token::Kind op) { - switch (op) { - case Token::EQ: // fall through - case Token::PLUSEQ: // fall through - case Token::MINUSEQ: // fall through - case Token::STAREQ: // fall through - case Token::SLASHEQ: // fall through - case Token::PERCENTEQ: // fall through - case Token::SHLEQ: // fall through - case Token::SHREQ: // fall through - case Token::BITWISEOREQ: // fall through - case Token::BITWISEXOREQ: // fall through - case Token::BITWISEANDEQ: // fall through - case Token::LOGICALOREQ: // fall through - case Token::LOGICALXOREQ: // fall through - case Token::LOGICALANDEQ: - return true; - default: - return false; - } -} - /** * Determines the operand and result types of a binary expression. Returns true if the expression is * legal, false otherwise. If false, the values of the out parameters are undefined. @@ -842,14 +823,9 @@ static bool determine_binary_type(const Context& context, return false; } -/** - * If both operands are compile-time constants and can be folded, returns an expression representing - * the folded value. Otherwise, returns null. Note that unlike most other functions here, null does - * not represent a compilation error. - */ std::unique_ptr<Expression> IRGenerator::constantFold(const Expression& left, Token::Kind op, - const Expression& right) { + const Expression& right) const { // Note that we expressly do not worry about precision and overflow here -- we use the maximum // precision to calculate the results and hope the result makes sense. The plan is to move the // Skia caps into SkSL, so we have access to all of them including the precisions of the various @@ -943,15 +919,16 @@ std::unique_ptr<Expression> IRGenerator::convertBinaryExpression( const Type* rightType; const Type* resultType; if (!determine_binary_type(fContext, expression.fOperator, left->fType, right->fType, &leftType, - &rightType, &resultType, !is_assignment(expression.fOperator))) { + &rightType, &resultType, + !Token::IsAssignment(expression.fOperator))) { fErrors.error(expression.fPosition, "type mismatch: '" + Token::OperatorName(expression.fOperator) + "' cannot operate on '" + left->fType.fName + "', '" + right->fType.fName + "'"); return nullptr; } - if (is_assignment(expression.fOperator)) { - this->markWrittenTo(*left); + if (Token::IsAssignment(expression.fOperator)) { + this->markWrittenTo(*left, expression.fOperator != Token::EQ); } left = this->coerce(std::move(left), *leftType); right = this->coerce(std::move(right), *rightType); @@ -1051,7 +1028,7 @@ std::unique_ptr<Expression> IRGenerator::call(Position position, return nullptr; } if (arguments[i] && (function.fParameters[i]->fModifiers.fFlags & Modifiers::kOut_Flag)) { - this->markWrittenTo(*arguments[i]); + this->markWrittenTo(*arguments[i], true); } } return std::unique_ptr<FunctionCall>(new FunctionCall(position, *returnType, function, @@ -1261,7 +1238,7 @@ std::unique_ptr<Expression> IRGenerator::convertPrefixExpression( "' cannot operate on '" + base->fType.description() + "'"); return nullptr; } - this->markWrittenTo(*base); + this->markWrittenTo(*base, true); break; case Token::MINUSMINUS: if (!base->fType.isNumber()) { @@ -1270,7 +1247,7 @@ std::unique_ptr<Expression> IRGenerator::convertPrefixExpression( "' cannot operate on '" + base->fType.description() + "'"); return nullptr; } - this->markWrittenTo(*base); + this->markWrittenTo(*base, true); break; case Token::LOGICALNOT: if (base->fType != *fContext.fBool_Type) { @@ -1464,7 +1441,7 @@ std::unique_ptr<Expression> IRGenerator::convertSuffixExpression( "'++' cannot operate on '" + base->fType.description() + "'"); return nullptr; } - this->markWrittenTo(*base); + this->markWrittenTo(*base, true); return std::unique_ptr<Expression>(new PostfixExpression(std::move(base), Token::PLUSPLUS)); case ASTSuffix::kPostDecrement_Kind: @@ -1473,7 +1450,7 @@ std::unique_ptr<Expression> IRGenerator::convertSuffixExpression( "'--' cannot operate on '" + base->fType.description() + "'"); return nullptr; } - this->markWrittenTo(*base); + this->markWrittenTo(*base, true); return std::unique_ptr<Expression>(new PostfixExpression(std::move(base), Token::MINUSMINUS)); default: @@ -1496,10 +1473,6 @@ void IRGenerator::checkValid(const Expression& expr) { } } -void IRGenerator::markReadFrom(const Variable& var) { - var.fIsReadFrom = true; -} - static bool has_duplicates(const Swizzle& swizzle) { int bits = 0; for (int idx : swizzle.fComponents) { @@ -1513,7 +1486,7 @@ static bool has_duplicates(const Swizzle& swizzle) { return false; } -void IRGenerator::markWrittenTo(const Expression& expr) { +void IRGenerator::markWrittenTo(const Expression& expr, bool readWrite) { switch (expr.fKind) { case Expression::kVariableReference_Kind: { const Variable& var = ((VariableReference&) expr).fVariable; @@ -1521,21 +1494,22 @@ void IRGenerator::markWrittenTo(const Expression& expr) { fErrors.error(expr.fPosition, "cannot modify immutable variable '" + var.fName + "'"); } - var.fIsWrittenTo = true; + ((VariableReference&) expr).setRefKind(readWrite ? VariableReference::kReadWrite_RefKind + : VariableReference::kWrite_RefKind); break; } case Expression::kFieldAccess_Kind: - this->markWrittenTo(*((FieldAccess&) expr).fBase); + this->markWrittenTo(*((FieldAccess&) expr).fBase, readWrite); break; case Expression::kSwizzle_Kind: if (has_duplicates((Swizzle&) expr)) { fErrors.error(expr.fPosition, "cannot write to the same swizzle field more than once"); } - this->markWrittenTo(*((Swizzle&) expr).fBase); + this->markWrittenTo(*((Swizzle&) expr).fBase, readWrite); break; case Expression::kIndex_Kind: - this->markWrittenTo(*((IndexExpression&) expr).fBase); + this->markWrittenTo(*((IndexExpression&) expr).fBase, readWrite); break; default: fErrors.error(expr.fPosition, "cannot assign to '" + expr.description() + "'"); diff --git a/src/sksl/SkSLIRGenerator.h b/src/sksl/SkSLIRGenerator.h index 13b20fbbcc..2ffcb0df26 100644 --- a/src/sksl/SkSLIRGenerator.h +++ b/src/sksl/SkSLIRGenerator.h @@ -88,7 +88,16 @@ public: std::unique_ptr<ModifiersDeclaration> convertModifiersDeclaration( const ASTModifiersDeclaration& m); + /** + * If both operands are compile-time constants and can be folded, returns an expression + * representing the folded value. Otherwise, returns null. Note that unlike most other functions + * here, null does not represent a compilation error. + */ + std::unique_ptr<Expression> constantFold(const Expression& left, + Token::Kind op, + const Expression& right) const; Program::Inputs fInputs; + const Context& fContext; private: /** @@ -124,11 +133,6 @@ private: std::unique_ptr<Statement> convertDiscard(const ASTDiscardStatement& d); std::unique_ptr<Statement> convertDo(const ASTDoStatement& d); std::unique_ptr<Expression> convertBinaryExpression(const ASTBinaryExpression& expression); - // Returns null if it cannot fold the expression. Note that unlike most other functions here, a - // null return does not represent a compilation error. - std::unique_ptr<Expression> constantFold(const Expression& left, - Token::Kind op, - const Expression& right); std::unique_ptr<Extension> convertExtension(const ASTExtension& e); std::unique_ptr<Statement> convertExpressionStatement(const ASTExpressionStatement& s); std::unique_ptr<Statement> convertFor(const ASTForStatement& f); @@ -151,10 +155,8 @@ private: std::unique_ptr<Statement> convertWhile(const ASTWhileStatement& w); void checkValid(const Expression& expr); - void markReadFrom(const Variable& var); - void markWrittenTo(const Expression& expr); + void markWrittenTo(const Expression& expr, bool readWrite); - const Context& fContext; const FunctionDeclaration* fCurrentFunction; const Program::Settings* fSettings; std::unordered_map<SkString, CapValue> fCapsMap; diff --git a/src/sksl/SkSLSPIRVCodeGenerator.cpp b/src/sksl/SkSLSPIRVCodeGenerator.cpp index 8afd13688c..d43e4c4035 100644 --- a/src/sksl/SkSLSPIRVCodeGenerator.cpp +++ b/src/sksl/SkSLSPIRVCodeGenerator.cpp @@ -2540,7 +2540,7 @@ void SPIRVCodeGenerator::writeGlobalVars(Program::Kind kind, const VarDeclaratio kind != Program::kFragment_Kind) { continue; } - if (!var->fIsReadFrom && !var->fIsWrittenTo && + if (!var->fReadCount && !var->fWriteCount && !(var->fModifiers.fFlags & (Modifiers::kIn_Flag | Modifiers::kOut_Flag | Modifiers::kUniform_Flag))) { diff --git a/src/sksl/SkSLToken.h b/src/sksl/SkSLToken.h index 5c8c2bd215..197781f2a0 100644 --- a/src/sksl/SkSLToken.h +++ b/src/sksl/SkSLToken.h @@ -160,6 +160,28 @@ struct Token { , fKind(kind) , fText(std::move(text)) {} + static bool IsAssignment(Token::Kind op) { + switch (op) { + case Token::EQ: // fall through + case Token::PLUSEQ: // fall through + case Token::MINUSEQ: // fall through + case Token::STAREQ: // fall through + case Token::SLASHEQ: // fall through + case Token::PERCENTEQ: // fall through + case Token::SHLEQ: // fall through + case Token::SHREQ: // fall through + case Token::BITWISEOREQ: // fall through + case Token::BITWISEXOREQ: // fall through + case Token::BITWISEANDEQ: // fall through + case Token::LOGICALOREQ: // fall through + case Token::LOGICALXOREQ: // fall through + case Token::LOGICALANDEQ: + return true; + default: + return false; + } + } + Position fPosition; Kind fKind; // will be the empty string unless the token has variable text content (identifiers, numeric diff --git a/src/sksl/ir/SkSLBinaryExpression.h b/src/sksl/ir/SkSLBinaryExpression.h index 132513e7f7..de85e4812b 100644 --- a/src/sksl/ir/SkSLBinaryExpression.h +++ b/src/sksl/ir/SkSLBinaryExpression.h @@ -4,17 +4,19 @@ * Use of this source code is governed by a BSD-style license that can be * found in the LICENSE file. */ - + #ifndef SKSL_BINARYEXPRESSION #define SKSL_BINARYEXPRESSION #include "SkSLExpression.h" +#include "SkSLExpression.h" +#include "../SkSLIRGenerator.h" #include "../SkSLToken.h" namespace SkSL { /** - * A binary operation. + * A binary operation. */ struct BinaryExpression : public Expression { BinaryExpression(Position position, std::unique_ptr<Expression> left, Token::Kind op, @@ -24,14 +26,22 @@ struct BinaryExpression : public Expression { , fOperator(op) , fRight(std::move(right)) {} + virtual std::unique_ptr<Expression> constantPropagate( + const IRGenerator& irGenerator, + const DefinitionMap& definitions) override { + return irGenerator.constantFold(*fLeft, + fOperator, + *fRight); + } + virtual SkString description() const override { return "(" + fLeft->description() + " " + Token::OperatorName(fOperator) + " " + fRight->description() + ")"; } - const std::unique_ptr<Expression> fLeft; + std::unique_ptr<Expression> fLeft; const Token::Kind fOperator; - const std::unique_ptr<Expression> fRight; + std::unique_ptr<Expression> fRight; typedef Expression INHERITED; }; diff --git a/src/sksl/ir/SkSLBlock.h b/src/sksl/ir/SkSLBlock.h index f975d160a0..17970fd561 100644 --- a/src/sksl/ir/SkSLBlock.h +++ b/src/sksl/ir/SkSLBlock.h @@ -20,8 +20,8 @@ struct Block : public Statement { Block(Position position, std::vector<std::unique_ptr<Statement>> statements, const std::shared_ptr<SymbolTable> symbols) : INHERITED(position, kBlock_Kind) - , fStatements(std::move(statements)) - , fSymbols(std::move(symbols)) {} + , fSymbols(std::move(symbols)) + , fStatements(std::move(statements)) {} SkString description() const override { SkString result("{"); @@ -33,8 +33,10 @@ struct Block : public Statement { return result; } - const std::vector<std::unique_ptr<Statement>> fStatements; + // it's important to keep fStatements defined after (and thus destroyed before) fSymbols, + // because destroying statements can modify reference counts in symbols const std::shared_ptr<SymbolTable> fSymbols; + const std::vector<std::unique_ptr<Statement>> fStatements; typedef Statement INHERITED; }; diff --git a/src/sksl/ir/SkSLConstructor.h b/src/sksl/ir/SkSLConstructor.h index 63c692b88e..691bea123a 100644 --- a/src/sksl/ir/SkSLConstructor.h +++ b/src/sksl/ir/SkSLConstructor.h @@ -9,6 +9,9 @@ #define SKSL_CONSTRUCTOR #include "SkSLExpression.h" +#include "SkSLFloatLiteral.h" +#include "SkSLIntLiteral.h" +#include "SkSLIRGenerator.h" namespace SkSL { @@ -21,6 +24,20 @@ struct Constructor : public Expression { : INHERITED(position, kConstructor_Kind, type) , fArguments(std::move(arguments)) {} + virtual std::unique_ptr<Expression> constantPropagate( + const IRGenerator& irGenerator, + const DefinitionMap& definitions) override { + if (fArguments.size() == 1 && fArguments[0]->fKind == Expression::kIntLiteral_Kind && + // promote float(1) to 1.0 + fType == *irGenerator.fContext.fFloat_Type) { + int64_t intValue = ((IntLiteral&) *fArguments[0]).fValue; + return std::unique_ptr<Expression>(new FloatLiteral(irGenerator.fContext, + fPosition, + intValue)); + } + return nullptr; + } + SkString description() const override { SkString result = fType.description() + "("; SkString separator; @@ -42,7 +59,7 @@ struct Constructor : public Expression { return true; } - const std::vector<std::unique_ptr<Expression>> fArguments; + std::vector<std::unique_ptr<Expression>> fArguments; typedef Expression INHERITED; }; diff --git a/src/sksl/ir/SkSLDoStatement.h b/src/sksl/ir/SkSLDoStatement.h index 78c0a1b768..e26d3dc974 100644 --- a/src/sksl/ir/SkSLDoStatement.h +++ b/src/sksl/ir/SkSLDoStatement.h @@ -28,7 +28,7 @@ struct DoStatement : public Statement { } const std::unique_ptr<Statement> fStatement; - const std::unique_ptr<Expression> fTest; + std::unique_ptr<Expression> fTest; typedef Statement INHERITED; }; diff --git a/src/sksl/ir/SkSLExpression.h b/src/sksl/ir/SkSLExpression.h index b4ed37c09a..f87d810fc0 100644 --- a/src/sksl/ir/SkSLExpression.h +++ b/src/sksl/ir/SkSLExpression.h @@ -4,17 +4,24 @@ * Use of this source code is governed by a BSD-style license that can be * found in the LICENSE file. */ - + #ifndef SKSL_EXPRESSION #define SKSL_EXPRESSION -#include "SkSLIRNode.h" #include "SkSLType.h" +#include "SkSLVariable.h" + +#include <unordered_map> namespace SkSL { +struct Expression; +class IRGenerator; + +typedef std::unordered_map<const Variable*, std::unique_ptr<Expression>*> DefinitionMap; + /** - * Abstract supertype of all expressions. + * Abstract supertype of all expressions. */ struct Expression : public IRNode { enum Kind { @@ -45,6 +52,18 @@ struct Expression : public IRNode { return false; } + /** + * Given a map of known constant variable values, substitute them in for references to those + * variables occurring in this expression and its subexpressions. Similar simplifications, such + * as folding a constant binary expression down to a single value, may also be performed. + * Returns a new expression which replaces this expression, or null if no replacements were + * made. If a new expression is returned, this expression is no longer valid. + */ + virtual std::unique_ptr<Expression> constantPropagate(const IRGenerator& irGenerator, + const DefinitionMap& definitions) { + return nullptr; + } + const Kind fKind; const Type& fType; diff --git a/src/sksl/ir/SkSLExpressionStatement.h b/src/sksl/ir/SkSLExpressionStatement.h index 677c647587..088b1c9ad1 100644 --- a/src/sksl/ir/SkSLExpressionStatement.h +++ b/src/sksl/ir/SkSLExpressionStatement.h @@ -25,7 +25,7 @@ struct ExpressionStatement : public Statement { return fExpression->description() + ";"; } - const std::unique_ptr<Expression> fExpression; + std::unique_ptr<Expression> fExpression; typedef Statement INHERITED; }; diff --git a/src/sksl/ir/SkSLFieldAccess.h b/src/sksl/ir/SkSLFieldAccess.h index fb727e017e..de26a3f626 100644 --- a/src/sksl/ir/SkSLFieldAccess.h +++ b/src/sksl/ir/SkSLFieldAccess.h @@ -35,7 +35,7 @@ struct FieldAccess : public Expression { return fBase->description() + "." + fBase->fType.fields()[fFieldIndex].fName; } - const std::unique_ptr<Expression> fBase; + std::unique_ptr<Expression> fBase; const int fFieldIndex; const OwnerKind fOwnerKind; diff --git a/src/sksl/ir/SkSLForStatement.h b/src/sksl/ir/SkSLForStatement.h index ff03d0d7f9..6f03e2bb36 100644 --- a/src/sksl/ir/SkSLForStatement.h +++ b/src/sksl/ir/SkSLForStatement.h @@ -46,8 +46,8 @@ struct ForStatement : public Statement { } const std::unique_ptr<Statement> fInitializer; - const std::unique_ptr<Expression> fTest; - const std::unique_ptr<Expression> fNext; + std::unique_ptr<Expression> fTest; + std::unique_ptr<Expression> fNext; const std::unique_ptr<Statement> fStatement; const std::shared_ptr<SymbolTable> fSymbols; diff --git a/src/sksl/ir/SkSLFunctionCall.h b/src/sksl/ir/SkSLFunctionCall.h index 971af366b9..1838076796 100644 --- a/src/sksl/ir/SkSLFunctionCall.h +++ b/src/sksl/ir/SkSLFunctionCall.h @@ -36,7 +36,7 @@ struct FunctionCall : public Expression { } const FunctionDeclaration& fFunction; - const std::vector<std::unique_ptr<Expression>> fArguments; + std::vector<std::unique_ptr<Expression>> fArguments; typedef Expression INHERITED; }; diff --git a/src/sksl/ir/SkSLIfStatement.h b/src/sksl/ir/SkSLIfStatement.h index f8beded9e8..8667e932ec 100644 --- a/src/sksl/ir/SkSLIfStatement.h +++ b/src/sksl/ir/SkSLIfStatement.h @@ -32,7 +32,7 @@ struct IfStatement : public Statement { return result; } - const std::unique_ptr<Expression> fTest; + std::unique_ptr<Expression> fTest; const std::unique_ptr<Statement> fIfTrue; const std::unique_ptr<Statement> fIfFalse; diff --git a/src/sksl/ir/SkSLIndexExpression.h b/src/sksl/ir/SkSLIndexExpression.h index 079dde5e53..d255c7daf6 100644 --- a/src/sksl/ir/SkSLIndexExpression.h +++ b/src/sksl/ir/SkSLIndexExpression.h @@ -55,8 +55,8 @@ struct IndexExpression : public Expression { return fBase->description() + "[" + fIndex->description() + "]"; } - const std::unique_ptr<Expression> fBase; - const std::unique_ptr<Expression> fIndex; + std::unique_ptr<Expression> fBase; + std::unique_ptr<Expression> fIndex; typedef Expression INHERITED; }; diff --git a/src/sksl/ir/SkSLPostfixExpression.h b/src/sksl/ir/SkSLPostfixExpression.h index 01671b5b88..6c9fafe5a0 100644 --- a/src/sksl/ir/SkSLPostfixExpression.h +++ b/src/sksl/ir/SkSLPostfixExpression.h @@ -25,7 +25,7 @@ struct PostfixExpression : public Expression { return fOperand->description() + Token::OperatorName(fOperator); } - const std::unique_ptr<Expression> fOperand; + std::unique_ptr<Expression> fOperand; const Token::Kind fOperator; typedef Expression INHERITED; diff --git a/src/sksl/ir/SkSLPrefixExpression.h b/src/sksl/ir/SkSLPrefixExpression.h index 790c5ab47a..b7db99a0a4 100644 --- a/src/sksl/ir/SkSLPrefixExpression.h +++ b/src/sksl/ir/SkSLPrefixExpression.h @@ -25,7 +25,7 @@ struct PrefixExpression : public Expression { return Token::OperatorName(fOperator) + fOperand->description(); } - const std::unique_ptr<Expression> fOperand; + std::unique_ptr<Expression> fOperand; const Token::Kind fOperator; typedef Expression INHERITED; diff --git a/src/sksl/ir/SkSLProgram.h b/src/sksl/ir/SkSLProgram.h index ac49d6dcc7..6a73be6983 100644 --- a/src/sksl/ir/SkSLProgram.h +++ b/src/sksl/ir/SkSLProgram.h @@ -59,8 +59,8 @@ struct Program { , fSettings(settings) , fDefaultPrecision(defaultPrecision) , fContext(context) - , fElements(std::move(elements)) , fSymbols(symbols) + , fElements(std::move(elements)) , fInputs(inputs) {} Kind fKind; @@ -68,8 +68,10 @@ struct Program { // FIXME handle different types; currently it assumes this is for floats Modifiers::Flag fDefaultPrecision; Context* fContext; - std::vector<std::unique_ptr<ProgramElement>> fElements; + // it's important to keep fElements defined after (and thus destroyed before) fSymbols, + // because destroying elements can modify reference counts in symbols std::shared_ptr<SymbolTable> fSymbols; + std::vector<std::unique_ptr<ProgramElement>> fElements; Inputs fInputs; }; diff --git a/src/sksl/ir/SkSLReturnStatement.h b/src/sksl/ir/SkSLReturnStatement.h index c83b45066e..dc5ec9aa9c 100644 --- a/src/sksl/ir/SkSLReturnStatement.h +++ b/src/sksl/ir/SkSLReturnStatement.h @@ -32,7 +32,7 @@ struct ReturnStatement : public Statement { } } - const std::unique_ptr<Expression> fExpression; + std::unique_ptr<Expression> fExpression; typedef Statement INHERITED; }; diff --git a/src/sksl/ir/SkSLSwizzle.h b/src/sksl/ir/SkSLSwizzle.h index c9397aec7f..8ad9001ada 100644 --- a/src/sksl/ir/SkSLSwizzle.h +++ b/src/sksl/ir/SkSLSwizzle.h @@ -76,7 +76,7 @@ struct Swizzle : public Expression { return result; } - const std::unique_ptr<Expression> fBase; + std::unique_ptr<Expression> fBase; const std::vector<int> fComponents; typedef Expression INHERITED; diff --git a/src/sksl/ir/SkSLTernaryExpression.h b/src/sksl/ir/SkSLTernaryExpression.h index 4a352536e3..02750049d4 100644 --- a/src/sksl/ir/SkSLTernaryExpression.h +++ b/src/sksl/ir/SkSLTernaryExpression.h @@ -31,9 +31,9 @@ struct TernaryExpression : public Expression { fIfFalse->description() + ")"; } - const std::unique_ptr<Expression> fTest; - const std::unique_ptr<Expression> fIfTrue; - const std::unique_ptr<Expression> fIfFalse; + std::unique_ptr<Expression> fTest; + std::unique_ptr<Expression> fIfTrue; + std::unique_ptr<Expression> fIfFalse; typedef Expression INHERITED; }; diff --git a/src/sksl/ir/SkSLVarDeclarations.h b/src/sksl/ir/SkSLVarDeclarations.h index 295c0b6997..490259a081 100644 --- a/src/sksl/ir/SkSLVarDeclarations.h +++ b/src/sksl/ir/SkSLVarDeclarations.h @@ -72,7 +72,7 @@ struct VarDeclarations : public ProgramElement { } const Type& fBaseType; - const std::vector<VarDeclaration> fVars; + std::vector<VarDeclaration> fVars; typedef ProgramElement INHERITED; }; diff --git a/src/sksl/ir/SkSLVarDeclarationsStatement.h b/src/sksl/ir/SkSLVarDeclarationsStatement.h index 7a29656593..66b570f853 100644 --- a/src/sksl/ir/SkSLVarDeclarationsStatement.h +++ b/src/sksl/ir/SkSLVarDeclarationsStatement.h @@ -18,14 +18,14 @@ namespace SkSL { */ struct VarDeclarationsStatement : public Statement { VarDeclarationsStatement(std::unique_ptr<VarDeclarations> decl) - : INHERITED(decl->fPosition, kVarDeclarations_Kind) + : INHERITED(decl->fPosition, kVarDeclarations_Kind) , fDeclaration(std::move(decl)) {} SkString description() const override { return fDeclaration->description(); } - const std::shared_ptr<VarDeclarations> fDeclaration; + std::shared_ptr<VarDeclarations> fDeclaration; typedef Statement INHERITED; }; diff --git a/src/sksl/ir/SkSLVariable.h b/src/sksl/ir/SkSLVariable.h index 39b8482a7b..2c3391dfa2 100644 --- a/src/sksl/ir/SkSLVariable.h +++ b/src/sksl/ir/SkSLVariable.h @@ -33,8 +33,8 @@ struct Variable : public Symbol { , fModifiers(modifiers) , fType(type) , fStorage(storage) - , fIsReadFrom(false) - , fIsWrittenTo(false) {} + , fReadCount(0) + , fWriteCount(0) {} virtual SkString description() const override { return fModifiers.description() + fType.fName + " " + fName; @@ -44,8 +44,12 @@ struct Variable : public Symbol { const Type& fType; const Storage fStorage; - mutable bool fIsReadFrom; - mutable bool fIsWrittenTo; + // Tracks how many sites read from the variable. If this is zero for a non-out variable (or + // becomes zero during optimization), the variable is dead and may be eliminated. + mutable int fReadCount; + // Tracks how many sites write to the variable. If this is zero, the variable is dead and may be + // eliminated. + mutable int fWriteCount; typedef Symbol INHERITED; }; diff --git a/src/sksl/ir/SkSLVariableReference.h b/src/sksl/ir/SkSLVariableReference.h index c6a2ea0511..fecb04e2e5 100644 --- a/src/sksl/ir/SkSLVariableReference.h +++ b/src/sksl/ir/SkSLVariableReference.h @@ -20,16 +20,83 @@ namespace SkSL { * there is only one Variable 'x', but two VariableReferences to it. */ struct VariableReference : public Expression { - VariableReference(Position position, const Variable& variable) + enum RefKind { + kRead_RefKind, + kWrite_RefKind, + kReadWrite_RefKind + }; + + VariableReference(Position position, const Variable& variable, RefKind refKind = kRead_RefKind) : INHERITED(position, kVariableReference_Kind, variable.fType) - , fVariable(variable) {} + , fVariable(variable) + , fRefKind(refKind) { + if (refKind != kRead_RefKind) { + fVariable.fWriteCount++; + } + if (refKind != kWrite_RefKind) { + fVariable.fReadCount++; + } + } + + virtual ~VariableReference() override { + if (fRefKind != kWrite_RefKind) { + fVariable.fReadCount--; + } + } + + RefKind refKind() { + return fRefKind; + } + + void setRefKind(RefKind refKind) { + if (fRefKind != kRead_RefKind) { + fVariable.fWriteCount--; + } + if (fRefKind != kWrite_RefKind) { + fVariable.fReadCount--; + } + if (refKind != kRead_RefKind) { + fVariable.fWriteCount++; + } + if (refKind != kWrite_RefKind) { + fVariable.fReadCount++; + } + fRefKind = refKind; + } SkString description() const override { return fVariable.fName; } + virtual std::unique_ptr<Expression> constantPropagate( + const IRGenerator& irGenerator, + const DefinitionMap& definitions) override { + auto exprIter = definitions.find(&fVariable); + if (exprIter != definitions.end() && exprIter->second) { + const Expression* expr = exprIter->second->get(); + switch (expr->fKind) { + case Expression::kIntLiteral_Kind: + return std::unique_ptr<Expression>(new IntLiteral( + irGenerator.fContext, + Position(), + ((IntLiteral*) expr)->fValue)); + case Expression::kFloatLiteral_Kind: + return std::unique_ptr<Expression>(new FloatLiteral( + irGenerator.fContext, + Position(), + ((FloatLiteral*) expr)->fValue)); + default: + break; + } + } + return nullptr; + } + const Variable& fVariable; +private: + RefKind fRefKind; + typedef Expression INHERITED; }; diff --git a/src/sksl/ir/SkSLWhileStatement.h b/src/sksl/ir/SkSLWhileStatement.h index 7c6a2907c4..a741a0441d 100644 --- a/src/sksl/ir/SkSLWhileStatement.h +++ b/src/sksl/ir/SkSLWhileStatement.h @@ -27,7 +27,7 @@ struct WhileStatement : public Statement { return "while (" + fTest->description() + ") " + fStatement->description(); } - const std::unique_ptr<Expression> fTest; + std::unique_ptr<Expression> fTest; const std::unique_ptr<Statement> fStatement; typedef Statement INHERITED; diff --git a/tests/SkSLGLSLTest.cpp b/tests/SkSLGLSLTest.cpp index 12ac4d1101..1501dc5677 100644 --- a/tests/SkSLGLSLTest.cpp +++ b/tests/SkSLGLSLTest.cpp @@ -61,7 +61,7 @@ DEF_TEST(SkSLControl, r) { "while (i < 10) sk_FragColor *= 0.5;" "do { sk_FragColor += 0.01; } while (sk_FragColor.x < 0.75);" "for (int i = 0; i < 10; i++) {" - "if (i % 0 == 1) break; else continue;" + "if (i % 2 == 1) break; else continue;" "}" "return;" "}", @@ -75,12 +75,12 @@ DEF_TEST(SkSLControl, r) { " discard;\n" " }\n" " int i = 0;\n" - " while (i < 10) sk_FragColor *= 0.5;\n" + " while (true) sk_FragColor *= 0.5;\n" " do {\n" " sk_FragColor += 0.01;\n" " } while (sk_FragColor.x < 0.75);\n" " for (int i = 0;i < 10; i++) {\n" - " if (i % 0 == 1) break; else continue;\n" + " if (i % 2 == 1) break; else continue;\n" " }\n" " return;\n" "}\n"); @@ -106,8 +106,8 @@ DEF_TEST(SkSLFunctions, r) { "}\n" "void main() {\n" " float x = 10.0;\n" - " bar(x);\n" - " sk_FragColor = vec4(x);\n" + " bar(10.0);\n" + " sk_FragColor = vec4(10.0);\n" "}\n"); } @@ -116,7 +116,7 @@ DEF_TEST(SkSLOperators, r) { "void main() {" "float x = 1, y = 2;" "int z = 3;" - "x = x + y * z * x * (y - z);" + "x = x - x + y * z * x * (y - z);" "y = x / y / z;" "z = (z / 2 % 3 << 4) >> 2 << 1;" "bool b = (x > 4) == x < 2 || 2 >= sqrt(2) && y <= z;" @@ -139,10 +139,10 @@ DEF_TEST(SkSLOperators, r) { "void main() {\n" " float x = 1.0, y = 2.0;\n" " int z = 3;\n" - " x = x + ((y * float(z)) * x) * (y - float(z));\n" - " y = (x / y) / float(z);\n" - " z = (((z / 2) % 3 << 4) >> 2) << 1;\n" - " bool b = x > 4.0 == x < 2.0 || 2.0 >= sqrt(2.0) && y <= float(z);\n" + " x = -6.0;\n" + " y = -1.0;\n" + " z = 8;\n" + " bool b = false == true || 2.0 >= sqrt(2.0) && true;\n" " x += 12.0;\n" " x -= 12.0;\n" " x *= (y /= float(z = 10));\n" @@ -287,7 +287,7 @@ DEF_TEST(SkSLMinAbs, r) { "out vec4 sk_FragColor;\n" "void main() {\n" " float x = -5.0;\n" - " x = min(abs(x), 6.0);\n" + " x = min(abs(-5.0), 6.0);\n" "}\n"); test(r, @@ -302,7 +302,7 @@ DEF_TEST(SkSLMinAbs, r) { " float minAbsHackVar0;\n" " float minAbsHackVar1;\n" " float x = -5.0;\n" - " x = ((minAbsHackVar0 = abs(x)) < (minAbsHackVar1 = 6.0) ? minAbsHackVar0 : " + " x = ((minAbsHackVar0 = abs(-5.0)) < (minAbsHackVar1 = 6.0) ? minAbsHackVar0 : " "minAbsHackVar1);\n" "}\n"); } |